fraiseql_auth/oauth/
refresh.rs1use std::{sync::Arc, time::Duration as StdDuration};
4
5use chrono::{DateTime, Duration, Utc};
6
7use super::super::error::AuthError;
8
9#[derive(Debug, Clone)]
11pub struct TokenRefreshScheduler {
12 refresh_queue: Arc<std::sync::Mutex<Vec<(String, DateTime<Utc>)>>>,
16}
17
18impl TokenRefreshScheduler {
19 pub fn new() -> Self {
21 Self {
22 refresh_queue: Arc::new(std::sync::Mutex::new(Vec::new())),
23 }
24 }
25
26 pub fn schedule_refresh(
32 &self,
33 session_id: String,
34 refresh_time: DateTime<Utc>,
35 ) -> std::result::Result<(), AuthError> {
36 let mut queue = self.refresh_queue.lock().map_err(|_| AuthError::Internal {
37 message: "token refresh scheduler mutex poisoned".to_string(),
38 })?;
39 queue.push((session_id, refresh_time));
40 queue.sort_by_key(|(_, time)| *time);
41 Ok(())
42 }
43
44 pub fn get_next_refresh(&self) -> std::result::Result<Option<String>, AuthError> {
50 let mut queue = self.refresh_queue.lock().map_err(|_| AuthError::Internal {
51 message: "token refresh scheduler mutex poisoned".to_string(),
52 })?;
53 if let Some((_, refresh_time)) = queue.first() {
54 if *refresh_time <= Utc::now() {
55 let (id, _) = queue.remove(0);
56 return Ok(Some(id));
57 }
58 }
59 Ok(None)
60 }
61
62 pub fn cancel_refresh(&self, session_id: &str) -> std::result::Result<bool, AuthError> {
68 let mut queue = self.refresh_queue.lock().map_err(|_| AuthError::Internal {
69 message: "token refresh scheduler mutex poisoned".to_string(),
70 })?;
71 let len_before = queue.len();
72 queue.retain(|(id, _)| id != session_id);
73 Ok(queue.len() < len_before)
74 }
75}
76
77impl Default for TokenRefreshScheduler {
78 fn default() -> Self {
79 Self::new()
80 }
81}
82
83#[async_trait::async_trait]
86pub trait TokenRefresher: Send + Sync {
87 async fn refresh_session(
93 &self,
94 session_id: &str,
95 ) -> std::result::Result<Option<DateTime<Utc>>, AuthError>;
96}
97
98pub struct TokenRefreshWorker {
101 scheduler: Arc<TokenRefreshScheduler>,
102 refresher: Arc<dyn TokenRefresher>,
103 cancel_rx: tokio::sync::watch::Receiver<bool>,
104 poll_interval: StdDuration,
105}
106
107impl TokenRefreshWorker {
108 pub fn new(
113 scheduler: Arc<TokenRefreshScheduler>,
114 refresher: Arc<dyn TokenRefresher>,
115 poll_interval: StdDuration,
116 ) -> (Self, tokio::sync::watch::Sender<bool>) {
117 let (cancel_tx, cancel_rx) = tokio::sync::watch::channel(false);
118 (
119 Self {
120 scheduler,
121 refresher,
122 cancel_rx,
123 poll_interval,
124 },
125 cancel_tx,
126 )
127 }
128
129 pub async fn run(mut self) {
131 tracing::info!(
132 interval_secs = self.poll_interval.as_secs(),
133 "Token refresh worker started"
134 );
135 loop {
136 tokio::select! {
137 result = self.cancel_rx.changed() => {
138 if result.is_err() || *self.cancel_rx.borrow() {
139 tracing::info!("Token refresh worker stopped");
140 break;
141 }
142 },
143 () = tokio::time::sleep(self.poll_interval) => {
144 self.process_due_refreshes().await;
145 }
146 }
147 }
148 }
149
150 async fn process_due_refreshes(&self) {
151 while let Ok(Some(session_id)) = self.scheduler.get_next_refresh() {
152 match self.refresher.refresh_session(&session_id).await {
153 Ok(Some(new_expiry)) => {
154 let remaining = new_expiry - Utc::now();
156 #[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
157 let next_refresh_secs = (remaining.num_seconds() as f64 * 0.8) as i64;
160 let next_refresh = Utc::now() + Duration::seconds(next_refresh_secs);
161 if let Err(e) =
162 self.scheduler.schedule_refresh(session_id.clone(), next_refresh)
163 {
164 tracing::warn!(
165 session_id = %session_id,
166 error = %e,
167 "Failed to re-schedule token refresh"
168 );
169 }
170 },
171 Ok(None) => {
172 tracing::debug!(
173 session_id = %session_id,
174 "Session no longer exists, skipping refresh"
175 );
176 },
177 Err(e) => {
178 tracing::warn!(
179 session_id = %session_id,
180 error = %e,
181 "Token refresh failed"
182 );
183 },
184 }
185 }
186 }
187}
188
#[cfg(test)]
mod tests {
    use chrono::Duration;

    use super::*;

    #[test]
    fn test_scheduler_schedule_and_get_due_refresh() {
        let scheduler = TokenRefreshScheduler::new();
        let past = Utc::now() - Duration::seconds(10);
        scheduler
            .schedule_refresh("session_a".to_string(), past)
            .expect("schedule_refresh must succeed");

        let next = scheduler.get_next_refresh().expect("get_next_refresh must succeed");
        assert_eq!(next, Some("session_a".to_string()));
    }

    #[test]
    fn test_scheduler_future_refresh_not_returned() {
        let scheduler = TokenRefreshScheduler::new();
        let future = Utc::now() + Duration::hours(1);
        scheduler
            .schedule_refresh("session_b".to_string(), future)
            .expect("schedule_refresh must succeed");

        let next = scheduler.get_next_refresh().expect("get_next_refresh must succeed");
        assert!(next.is_none(), "future refresh must not be returned as next");
    }

    #[test]
    fn test_scheduler_ordering_by_time() {
        let scheduler = TokenRefreshScheduler::new();
        let now = Utc::now();
        scheduler
            .schedule_refresh("later".to_string(), now - Duration::seconds(5))
            .expect("schedule must succeed");
        scheduler
            .schedule_refresh("earlier".to_string(), now - Duration::seconds(10))
            .expect("schedule must succeed");

        let first = scheduler.get_next_refresh().expect("must succeed");
        assert_eq!(first, Some("earlier".to_string()));
        let second = scheduler.get_next_refresh().expect("must succeed");
        assert_eq!(second, Some("later".to_string()));
    }

    #[test]
    fn test_scheduler_equal_times_are_returned_in_schedule_order() {
        let scheduler = TokenRefreshScheduler::new();
        let due = Utc::now() - Duration::seconds(1);
        for id in ["first", "second", "third"] {
            scheduler
                .schedule_refresh(id.to_string(), due)
                .expect("schedule must succeed");
        }

        for expected in ["first", "second", "third"] {
            let next = scheduler.get_next_refresh().expect("must succeed");
            assert_eq!(next.as_deref(), Some(expected));
        }
        assert!(scheduler.get_next_refresh().expect("must succeed").is_none());
    }

    #[test]
    fn test_scheduler_due_entry_is_returned_only_once() {
        let scheduler = TokenRefreshScheduler::new();
        let past = Utc::now() - Duration::seconds(1);
        scheduler
            .schedule_refresh("session_once".to_string(), past)
            .expect("schedule must succeed");

        let first = scheduler.get_next_refresh().expect("must succeed");
        assert_eq!(first, Some("session_once".to_string()));
        let second = scheduler.get_next_refresh().expect("must succeed");
        assert!(second.is_none(), "a returned entry must be removed from the queue");
    }

    #[test]
    fn test_scheduler_cancel_refresh() {
        let scheduler = TokenRefreshScheduler::new();
        let future = Utc::now() + Duration::hours(1);
        scheduler
            .schedule_refresh("session_c".to_string(), future)
            .expect("schedule must succeed");

        let cancelled = scheduler.cancel_refresh("session_c").expect("cancel must succeed");
        assert!(cancelled, "cancel_refresh must return true for existing session");

        let cancelled_again = scheduler.cancel_refresh("session_c").expect("cancel must succeed");
        assert!(!cancelled_again, "cancel_refresh must return false for already-removed session");
    }

    #[test]
    fn test_scheduler_cancel_leaves_other_sessions_scheduled() {
        let scheduler = TokenRefreshScheduler::new();
        let past = Utc::now() - Duration::seconds(1);
        scheduler
            .schedule_refresh("keep".to_string(), past)
            .expect("schedule must succeed");
        scheduler
            .schedule_refresh("drop".to_string(), past)
            .expect("schedule must succeed");

        assert!(scheduler.cancel_refresh("drop").expect("cancel must succeed"));

        let next = scheduler.get_next_refresh().expect("must succeed");
        assert_eq!(next, Some("keep".to_string()));
        assert!(scheduler.get_next_refresh().expect("must succeed").is_none());
    }

    #[test]
    fn test_scheduler_cancel_nonexistent_returns_false() {
        let scheduler = TokenRefreshScheduler::new();
        let cancelled = scheduler.cancel_refresh("nonexistent").expect("cancel must succeed");
        assert!(!cancelled);
    }

    #[test]
    fn test_scheduler_empty_returns_none() {
        let scheduler = TokenRefreshScheduler::new();
        let next = scheduler.get_next_refresh().expect("must succeed");
        assert!(next.is_none());
    }
}