server/auth/token_refresh_service.rs

1use super::auth_state::AuthStateManager;
2use super::errors::TokenRefreshError;
3use super::provider::AuthProvider;
4use super::types::CachedToken;
5use crate::service_bus_manager::ServiceBusError;
6use std::sync::Arc;
7use std::time::Duration;
8use tokio::sync::RwLock;
9use tokio::time::{interval, sleep};
10
11/// Callback for handling refresh failures
12pub type RefreshFailureCallback = Arc<dyn Fn(TokenRefreshError) + Send + Sync>;
13
14/// Service that periodically checks and refreshes tokens before they expire
15pub struct TokenRefreshService {
16    auth_state: Arc<AuthStateManager>,
17    check_interval: Duration,
18    shutdown_signal: Arc<RwLock<bool>>,
19    failure_callback: Option<RefreshFailureCallback>,
20}
21
22impl TokenRefreshService {
23    /// Create a new token refresh service
24    pub fn new(auth_state: Arc<AuthStateManager>) -> Self {
25        Self {
26            auth_state,
27            check_interval: Duration::from_secs(120), // Check every 2 minutes
28            shutdown_signal: Arc::new(RwLock::new(false)),
29            failure_callback: None,
30        }
31    }
32
33    /// Set a callback to be invoked when token refresh fails
34    pub fn with_failure_callback(mut self, callback: RefreshFailureCallback) -> Self {
35        self.failure_callback = Some(callback);
36        self
37    }
38
39    /// Start the background refresh service
40    pub fn start(self: Arc<Self>) -> tokio::task::JoinHandle<()> {
41        tokio::spawn(async move {
42            self.run().await;
43        })
44    }
45
46    /// Signal the service to shutdown
47    pub async fn shutdown(&self) {
48        let mut shutdown = self.shutdown_signal.write().await;
49        *shutdown = true;
50    }
51
52    /// Run the refresh service loop
53    async fn run(&self) {
54        let mut check_interval = interval(self.check_interval);
55        check_interval.tick().await; // Skip the first immediate tick
56
57        loop {
58            // Check if we should shutdown
59            if *self.shutdown_signal.read().await {
60                log::info!("Token refresh service shutting down");
61                break;
62            }
63
64            check_interval.tick().await;
65
66            // Check and refresh tokens
67            if let Err(e) = self.check_and_refresh_tokens().await {
68                log::error!("Error during token refresh check: {e}");
69            }
70        }
71    }
72
73    /// Check all cached tokens and refresh those that need it
74    async fn check_and_refresh_tokens(&self) -> Result<(), ServiceBusError> {
75        log::debug!("Checking tokens for refresh...");
76
77        // Check Service Bus token
78        if let Some(provider) = self.auth_state.get_service_bus_provider().await {
79            self.refresh_if_needed("service_bus", provider).await?;
80        }
81
82        // Check Management API token
83        if let Some(provider) = self.auth_state.get_management_provider().await {
84            self.refresh_if_needed("management_api", provider).await?;
85        }
86
87        Ok(())
88    }
89
90    /// Refresh a specific token if it needs refreshing
91    async fn refresh_if_needed(
92        &self,
93        cache_key: &str,
94        provider: Arc<dyn AuthProvider>,
95    ) -> Result<(), ServiceBusError> {
96        let token_cache = self.auth_state.get_token_cache();
97
98        if token_cache.needs_refresh(cache_key).await {
99            log::info!("Token for '{cache_key}' needs refresh, attempting refresh...");
100
101            match self.refresh_with_retry(provider, 3).await {
102                Ok(auth_token) => {
103                    // Store the refreshed token
104                    let cached_token = CachedToken::new(
105                        auth_token.token,
106                        Duration::from_secs(auth_token.expires_in_secs.unwrap_or(3600)),
107                        auth_token.token_type,
108                    );
109
110                    token_cache.set(cache_key.to_string(), cached_token).await;
111                    log::info!("Successfully refreshed token for '{cache_key}'");
112                }
113                Err(e) => {
114                    log::error!("Failed to refresh token for '{cache_key}': {e}");
115
116                    // Invalidate the token so next access will trigger re-authentication
117                    token_cache.invalidate(cache_key).await;
118
119                    // Invoke failure callback if set
120                    if let Some(callback) = &self.failure_callback {
121                        callback(e.clone());
122                    }
123
124                    // Convert to ServiceBusError
125                    return Err(e.into());
126                }
127            }
128        }
129
130        Ok(())
131    }
132
133    /// Attempt to refresh a token with retry logic
134    async fn refresh_with_retry(
135        &self,
136        provider: Arc<dyn AuthProvider>,
137        max_attempts: u32,
138    ) -> Result<super::provider::AuthToken, TokenRefreshError> {
139        let mut last_error = None;
140
141        for attempt in 1..=max_attempts {
142            match provider.refresh().await {
143                Ok(token) => return Ok(token),
144                Err(e) => {
145                    // Convert ServiceBusError to TokenRefreshError
146                    let refresh_error = match &e {
147                        ServiceBusError::AuthenticationFailed(_) => {
148                            TokenRefreshError::InvalidRefreshToken
149                        }
150                        ServiceBusError::AuthenticationError(msg) if msg.contains("expired") => {
151                            TokenRefreshError::RefreshTokenExpired
152                        }
153                        ServiceBusError::ConnectionFailed(reason) => {
154                            TokenRefreshError::NetworkError {
155                                reason: reason.clone(),
156                            }
157                        }
158                        ServiceBusError::OperationTimeout(msg) => {
159                            if msg.contains("rate") {
160                                TokenRefreshError::RateLimited {
161                                    retry_after_seconds: None,
162                                }
163                            } else {
164                                TokenRefreshError::ServiceUnavailable {
165                                    reason: msg.clone(),
166                                }
167                            }
168                        }
169                        _ => TokenRefreshError::Internal(e.to_string()),
170                    };
171
172                    last_error = Some(refresh_error);
173
174                    if attempt < max_attempts {
175                        let delay = Duration::from_secs(2u64.pow(attempt - 1)); // Exponential backoff: 1s, 2s, 4s
176                        log::warn!(
177                            "Token refresh attempt {attempt} failed, retrying in {delay:?}..."
178                        );
179                        sleep(delay).await;
180                    }
181                }
182            }
183        }
184
185        Err(last_error.unwrap_or(TokenRefreshError::MaxRetriesExceeded {
186            attempts: max_attempts,
187        }))
188    }
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::auth_state::AuthStateManager;
    use crate::auth::provider::{AuthProvider, AuthToken};
    use crate::auth::types::AuthType;
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Test double: counts `refresh` calls and either always succeeds or always fails.
    struct MockAuthProvider {
        refresh_count: Arc<AtomicU32>,
        should_fail: bool,
    }

    impl MockAuthProvider {
        /// Build a shared mock whose refreshes fail iff `should_fail` is set.
        fn new(should_fail: bool) -> Arc<Self> {
            Arc::new(Self {
                refresh_count: Arc::new(AtomicU32::new(0)),
                should_fail,
            })
        }

        /// Number of times `refresh` has been invoked so far.
        fn refresh_calls(&self) -> u32 {
            self.refresh_count.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl AuthProvider for MockAuthProvider {
        async fn authenticate(&self) -> Result<AuthToken, ServiceBusError> {
            let token = AuthToken {
                token: String::from("test_token"),
                token_type: String::from("Bearer"),
                expires_in_secs: Some(3600),
            };
            Ok(token)
        }

        async fn refresh(&self) -> Result<AuthToken, ServiceBusError> {
            self.refresh_count.fetch_add(1, Ordering::SeqCst);

            if !self.should_fail {
                return self.authenticate().await;
            }

            Err(ServiceBusError::AuthenticationError(String::from(
                "Mock refresh failure",
            )))
        }

        fn auth_type(&self) -> AuthType {
            AuthType::ConnectionString
        }
    }

    /// Service wired to a fresh, empty auth state.
    fn make_service() -> TokenRefreshService {
        TokenRefreshService::new(Arc::new(AuthStateManager::new()))
    }

    #[tokio::test]
    async fn test_refresh_with_retry_success() {
        let service = make_service();
        let provider = MockAuthProvider::new(false);

        let outcome = service.refresh_with_retry(provider.clone(), 3).await;

        assert!(outcome.is_ok());
        assert_eq!(provider.refresh_calls(), 1);
    }

    #[tokio::test]
    async fn test_refresh_with_retry_failure() {
        let service = make_service();
        let provider = MockAuthProvider::new(true);

        let outcome = service.refresh_with_retry(provider.clone(), 3).await;

        assert!(outcome.is_err());
        // Every one of the three attempts must have reached the provider.
        assert_eq!(provider.refresh_calls(), 3);
    }
}