Skip to main content

fraiseql_auth/oauth/
refresh.rs

1//! Token refresh scheduler and background worker.
2
3use std::{sync::Arc, time::Duration as StdDuration};
4
5use chrono::{DateTime, Duration, Utc};
6
7use super::super::error::AuthError;
8
9/// Token refresh scheduler
10#[derive(Debug, Clone)]
11pub struct TokenRefreshScheduler {
12    /// Sessions needing refresh
13    // std::sync::Mutex is intentional: this lock is never held across .await.
14    // Switch to tokio::sync::Mutex if that constraint ever changes.
15    refresh_queue: Arc<std::sync::Mutex<Vec<(String, DateTime<Utc>)>>>,
16}
17
18impl TokenRefreshScheduler {
19    /// Create new refresh scheduler
20    pub fn new() -> Self {
21        Self {
22            refresh_queue: Arc::new(std::sync::Mutex::new(Vec::new())),
23        }
24    }
25
26    /// Schedule token refresh for session
27    ///
28    /// # Errors
29    ///
30    /// Returns `AuthError::Internal` if the mutex is poisoned.
31    pub fn schedule_refresh(
32        &self,
33        session_id: String,
34        refresh_time: DateTime<Utc>,
35    ) -> std::result::Result<(), AuthError> {
36        let mut queue = self.refresh_queue.lock().map_err(|_| AuthError::Internal {
37            message: "token refresh scheduler mutex poisoned".to_string(),
38        })?;
39        queue.push((session_id, refresh_time));
40        queue.sort_by_key(|(_, time)| *time);
41        Ok(())
42    }
43
44    /// Get next session to refresh
45    ///
46    /// # Errors
47    ///
48    /// Returns `AuthError::Internal` if the mutex is poisoned.
49    pub fn get_next_refresh(&self) -> std::result::Result<Option<String>, AuthError> {
50        let mut queue = self.refresh_queue.lock().map_err(|_| AuthError::Internal {
51            message: "token refresh scheduler mutex poisoned".to_string(),
52        })?;
53        if let Some((_, refresh_time)) = queue.first() {
54            if *refresh_time <= Utc::now() {
55                let (id, _) = queue.remove(0);
56                return Ok(Some(id));
57            }
58        }
59        Ok(None)
60    }
61
62    /// Cancel scheduled refresh
63    ///
64    /// # Errors
65    ///
66    /// Returns `AuthError::Internal` if the mutex is poisoned.
67    pub fn cancel_refresh(&self, session_id: &str) -> std::result::Result<bool, AuthError> {
68        let mut queue = self.refresh_queue.lock().map_err(|_| AuthError::Internal {
69            message: "token refresh scheduler mutex poisoned".to_string(),
70        })?;
71        let len_before = queue.len();
72        queue.retain(|(id, _)| id != session_id);
73        Ok(queue.len() < len_before)
74    }
75}
76
77impl Default for TokenRefreshScheduler {
78    fn default() -> Self {
79        Self::new()
80    }
81}
82
83/// Callback trait for the token refresh worker to perform provider-specific
84/// token refresh and session updates.
85#[async_trait::async_trait]
86pub trait TokenRefresher: Send + Sync {
87    /// Refresh the token for the given session ID.
88    ///
89    /// Should look up the session, call the appropriate OAuth2 provider's
90    /// `refresh_token()`, update the stored session, and return the new expiry.
91    /// Returns `None` if the session no longer exists or has no refresh token.
92    async fn refresh_session(
93        &self,
94        session_id: &str,
95    ) -> std::result::Result<Option<DateTime<Utc>>, AuthError>;
96}
97
98/// Background worker that polls the `TokenRefreshScheduler` and refreshes
99/// expiring OAuth tokens.
100pub struct TokenRefreshWorker {
101    scheduler:     Arc<TokenRefreshScheduler>,
102    refresher:     Arc<dyn TokenRefresher>,
103    cancel_rx:     tokio::sync::watch::Receiver<bool>,
104    poll_interval: StdDuration,
105}
106
107impl TokenRefreshWorker {
108    /// Create a new token refresh worker.
109    ///
110    /// Returns the worker and a sender to trigger cancellation (send `true` to
111    /// stop).
112    pub fn new(
113        scheduler: Arc<TokenRefreshScheduler>,
114        refresher: Arc<dyn TokenRefresher>,
115        poll_interval: StdDuration,
116    ) -> (Self, tokio::sync::watch::Sender<bool>) {
117        let (cancel_tx, cancel_rx) = tokio::sync::watch::channel(false);
118        (
119            Self {
120                scheduler,
121                refresher,
122                cancel_rx,
123                poll_interval,
124            },
125            cancel_tx,
126        )
127    }
128
129    /// Run the refresh loop until cancelled.
130    pub async fn run(mut self) {
131        tracing::info!(
132            interval_secs = self.poll_interval.as_secs(),
133            "Token refresh worker started"
134        );
135        loop {
136            tokio::select! {
137                result = self.cancel_rx.changed() => {
138                    if result.is_err() || *self.cancel_rx.borrow() {
139                        tracing::info!("Token refresh worker stopped");
140                        break;
141                    }
142                },
143                () = tokio::time::sleep(self.poll_interval) => {
144                    self.process_due_refreshes().await;
145                }
146            }
147        }
148    }
149
150    async fn process_due_refreshes(&self) {
151        while let Ok(Some(session_id)) = self.scheduler.get_next_refresh() {
152            match self.refresher.refresh_session(&session_id).await {
153                Ok(Some(new_expiry)) => {
154                    // Re-schedule at 80% of the remaining time
155                    let remaining = new_expiry - Utc::now();
156                    #[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
157                    // Reason: intentional 80% f64 scaling; sub-second precision loss acceptable for
158                    // scheduling
159                    let next_refresh_secs = (remaining.num_seconds() as f64 * 0.8) as i64;
160                    let next_refresh = Utc::now() + Duration::seconds(next_refresh_secs);
161                    if let Err(e) =
162                        self.scheduler.schedule_refresh(session_id.clone(), next_refresh)
163                    {
164                        tracing::warn!(
165                            session_id = %session_id,
166                            error = %e,
167                            "Failed to re-schedule token refresh"
168                        );
169                    }
170                },
171                Ok(None) => {
172                    tracing::debug!(
173                        session_id = %session_id,
174                        "Session no longer exists, skipping refresh"
175                    );
176                },
177                Err(e) => {
178                    tracing::warn!(
179                        session_id = %session_id,
180                        error = %e,
181                        "Token refresh failed"
182                    );
183                },
184            }
185        }
186    }
187}
188
#[cfg(test)]
mod tests {
    use chrono::Duration;

    use super::*;

    #[test]
    fn test_scheduler_schedule_and_get_due_refresh() {
        let sched = TokenRefreshScheduler::new();
        // Already due: scheduled ten seconds ago.
        let due_at = Utc::now() - Duration::seconds(10);
        sched
            .schedule_refresh(String::from("session_a"), due_at)
            .expect("schedule_refresh must succeed");

        assert_eq!(
            sched.get_next_refresh().expect("get_next_refresh must succeed"),
            Some(String::from("session_a"))
        );
    }

    #[test]
    fn test_scheduler_future_refresh_not_returned() {
        let sched = TokenRefreshScheduler::new();
        // Not due for another hour.
        let due_at = Utc::now() + Duration::hours(1);
        sched
            .schedule_refresh(String::from("session_b"), due_at)
            .expect("schedule_refresh must succeed");

        let popped = sched.get_next_refresh().expect("get_next_refresh must succeed");
        assert!(popped.is_none(), "future refresh must not be returned as next");
    }

    #[test]
    fn test_scheduler_ordering_by_time() {
        let sched = TokenRefreshScheduler::new();
        let base = Utc::now();
        // Insert the later entry first to prove ordering is by time, not insertion.
        for (id, secs_ago) in [("later", 5), ("earlier", 10)] {
            sched
                .schedule_refresh(id.to_owned(), base - Duration::seconds(secs_ago))
                .expect("schedule must succeed");
        }

        for expected in ["earlier", "later"] {
            let popped = sched.get_next_refresh().expect("must succeed");
            assert_eq!(popped, Some(expected.to_owned()));
        }
    }

    #[test]
    fn test_scheduler_cancel_refresh() {
        let sched = TokenRefreshScheduler::new();
        let due_at = Utc::now() + Duration::hours(1);
        sched
            .schedule_refresh(String::from("session_c"), due_at)
            .expect("schedule must succeed");

        let first = sched.cancel_refresh("session_c").expect("cancel must succeed");
        assert!(first, "cancel_refresh must return true for existing session");

        let second = sched.cancel_refresh("session_c").expect("cancel must succeed");
        assert!(!second, "cancel_refresh must return false for already-removed session");
    }

    #[test]
    fn test_scheduler_cancel_nonexistent_returns_false() {
        let sched = TokenRefreshScheduler::new();
        assert!(!sched.cancel_refresh("nonexistent").expect("cancel must succeed"));
    }

    #[test]
    fn test_scheduler_empty_returns_none() {
        let sched = TokenRefreshScheduler::new();
        assert!(sched.get_next_refresh().expect("must succeed").is_none());
    }
}