server/auth/
auth_state.rs

1use super::provider::AuthProvider;
2use super::token_cache::TokenCache;
3use super::token_refresh_service::TokenRefreshService;
4use super::types::DeviceCodeInfo;
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7use tokio::sync::RwLock;
8use tokio::task::JoinHandle;
9
10/// Authentication state tracking for the application.
11///
12/// Represents the current authentication status and provides context
13/// for ongoing authentication processes like device code flows.
14#[derive(Clone, Debug, Default)]
15pub enum AuthenticationState {
16    /// No authentication is currently active
17    #[default]
18    NotAuthenticated,
19    /// Device code authentication is in progress
20    AwaitingDeviceCode {
21        /// Device code information for user interaction
22        info: DeviceCodeInfo,
23        /// When the device code flow was initiated
24        started_at: Instant,
25    },
26    /// Authentication completed successfully
27    Authenticated {
28        /// The authentication token
29        token: String,
30        /// When the token expires
31        expires_at: Instant,
32        /// Optional connection string for Service Bus operations
33        connection_string: Option<String>,
34    },
35    /// Authentication failed with error message
36    Failed(String),
37}
38
39// Consolidated state structure to prevent deadlocks
40#[derive(Default)]
41struct AuthState {
42    authentication_state: AuthenticationState,
43    azure_ad_token: Option<(String, Instant)>,
44    sas_token: Option<(String, Instant)>,
45    service_bus_provider: Option<Arc<dyn AuthProvider>>,
46    management_provider: Option<Arc<dyn AuthProvider>>,
47    refresh_service: Option<Arc<TokenRefreshService>>,
48    refresh_handle: Option<JoinHandle<()>>,
49}
50
51/// Centralized authentication state management for the application.
52///
53/// Manages authentication state, token caching, and refresh services across
54/// the entire application. Provides thread-safe access to authentication
55/// providers and tokens with automatic expiration handling.
56///
57/// # Features
58///
59/// - Thread-safe state management with RwLock
60/// - Token caching with automatic expiration
61/// - Authentication provider management
62/// - Token refresh service integration
63/// - Device code flow support
64///
65/// # Examples
66///
67/// ```no_run
68/// use quetty_server::auth::AuthStateManager;
69/// use std::sync::Arc;
70///
71/// let auth_manager = Arc::new(AuthStateManager::new());
72///
73/// // Check authentication status
74/// if !auth_manager.is_authenticated().await {
75///     // Start authentication process
76/// }
77///
78/// // Get cached tokens
79/// if let Some(token) = auth_manager.get_azure_ad_token().await {
80///     // Use token for API calls
81/// }
82/// ```
83pub struct AuthStateManager {
84    inner: Arc<RwLock<AuthState>>,
85    token_cache: TokenCache,
86}
87
88impl AuthStateManager {
89    /// Creates a new authentication state manager.
90    ///
91    /// # Returns
92    ///
93    /// A new AuthStateManager with clean state and empty token cache
94    pub fn new() -> Self {
95        Self {
96            inner: Arc::new(RwLock::new(AuthState::default())),
97            token_cache: TokenCache::new(),
98        }
99    }
100
101    /// Gets the current authentication state.
102    ///
103    /// # Returns
104    ///
105    /// The current [`AuthenticationState`] indicating the authentication status
106    pub async fn get_state(&self) -> AuthenticationState {
107        self.inner.read().await.authentication_state.clone()
108    }
109
110    /// Sets the authentication state to indicate device code flow is in progress.
111    ///
112    /// This method is called when a device code authentication flow has been initiated
113    /// and is waiting for user interaction to complete the authentication process.
114    ///
115    /// # Arguments
116    ///
117    /// * `info` - Device code information including user code and verification URL
118    ///
119    /// # Examples
120    ///
121    /// ```no_run
122    /// use quetty_server::auth::{AuthStateManager, DeviceCodeInfo};
123    /// use std::sync::Arc;
124    ///
125    /// let auth_manager = Arc::new(AuthStateManager::new());
126    /// let device_info = DeviceCodeInfo {
127    ///     device_code: "device123".to_string(),
128    ///     user_code: "ABC123".to_string(),
129    ///     verification_uri: "https://microsoft.com/devicelogin".to_string(),
130    ///     expires_in: 900,
131    ///     interval: 5,
132    ///     message: "Enter code ABC123 at https://microsoft.com/devicelogin".to_string(),
133    /// };
134    ///
135    /// auth_manager.set_device_code_pending(device_info).await;
136    /// ```
137    pub async fn set_device_code_pending(&self, info: DeviceCodeInfo) {
138        let mut state = self.inner.write().await;
139        state.authentication_state = AuthenticationState::AwaitingDeviceCode {
140            info,
141            started_at: Instant::now(),
142        };
143    }
144
145    pub async fn set_authenticated(
146        &self,
147        token: String,
148        expires_in: Duration,
149        connection_string: Option<String>,
150    ) {
151        let mut state = self.inner.write().await;
152        let expires_at = Instant::now() + expires_in;
153
154        state.authentication_state = AuthenticationState::Authenticated {
155            token: token.clone(),
156            expires_at,
157            connection_string,
158        };
159
160        // Store Azure AD token
161        state.azure_ad_token = Some((token, expires_at));
162    }
163
164    /// Sets the authentication state to failed with an error message.
165    ///
166    /// This method is called when authentication attempts fail, providing
167    /// detailed error information that can be displayed to the user.
168    ///
169    /// # Arguments
170    ///
171    /// * `error` - Human-readable error message describing the authentication failure
172    ///
173    /// # Examples
174    ///
175    /// ```no_run
176    /// use quetty_server::auth::AuthStateManager;
177    /// use std::sync::Arc;
178    ///
179    /// let auth_manager = Arc::new(AuthStateManager::new());
180    /// auth_manager.set_failed("Invalid credentials provided".to_string()).await;
181    /// ```
182    pub async fn set_failed(&self, error: String) {
183        let mut state = self.inner.write().await;
184        state.authentication_state = AuthenticationState::Failed(error);
185    }
186
187    /// Logs out the user and clears all authentication state.
188    ///
189    /// This method resets the authentication state to `NotAuthenticated` and
190    /// clears all cached tokens and authentication providers. It also stops
191    /// any running token refresh services.
192    ///
193    /// # Examples
194    ///
195    /// ```no_run
196    /// use quetty_server::auth::AuthStateManager;
197    /// use std::sync::Arc;
198    ///
199    /// let auth_manager = Arc::new(AuthStateManager::new());
200    ///
201    /// // After authentication...
202    /// auth_manager.logout().await;
203    ///
204    /// // State is now reset
205    /// assert!(!auth_manager.is_authenticated().await);
206    /// ```
207    pub async fn logout(&self) {
208        let mut state = self.inner.write().await;
209        state.authentication_state = AuthenticationState::NotAuthenticated;
210        state.azure_ad_token = None;
211        state.sas_token = None;
212    }
213
214    /// Checks if the user is currently authenticated.
215    ///
216    /// # Returns
217    ///
218    /// `true` if authentication is successful and active, `false` otherwise
219    pub async fn is_authenticated(&self) -> bool {
220        let state = self.inner.read().await;
221        matches!(
222            state.authentication_state,
223            AuthenticationState::Authenticated { .. }
224        )
225    }
226
227    /// Checks if reauthentication is needed.
228    ///
229    /// Returns `true` if the user is not authenticated or if the current
230    /// authentication token expires within 5 minutes.
231    ///
232    /// # Returns
233    ///
234    /// `true` if reauthentication is required, `false` if current auth is still valid
235    pub async fn needs_reauthentication(&self) -> bool {
236        let state = self.inner.read().await;
237        match &state.authentication_state {
238            AuthenticationState::Authenticated { expires_at, .. } => {
239                // Check if token expires in less than 5 minutes
240                Instant::now() + Duration::from_secs(300) >= *expires_at
241            }
242            _ => true,
243        }
244    }
245
246    /// Retrieves a valid Azure AD access token if available.
247    ///
248    /// Returns the cached Azure AD token if it exists and hasn't expired.
249    /// This token can be used for authenticating with Azure Service Bus
250    /// and other Azure resources.
251    ///
252    /// # Returns
253    ///
254    /// * `Some(token)` - Valid Azure AD access token
255    /// * `None` - No token available or token has expired
256    ///
257    /// # Examples
258    ///
259    /// ```no_run
260    /// use quetty_server::auth::AuthStateManager;
261    /// use std::sync::Arc;
262    ///
263    /// let auth_manager = Arc::new(AuthStateManager::new());
264    ///
265    /// if let Some(token) = auth_manager.get_azure_ad_token().await {
266    ///     println!("Using Azure AD token: {}", token);
267    ///     // Use token for Service Bus operations
268    /// } else {
269    ///     println!("No valid Azure AD token available");
270    /// }
271    /// ```
272    pub async fn get_azure_ad_token(&self) -> Option<String> {
273        let state = self.inner.read().await;
274        if let Some((token_str, expires_at)) = &state.azure_ad_token {
275            if Instant::now() < *expires_at {
276                return Some(token_str.clone());
277            }
278        }
279        None
280    }
281
282    /// Retrieves a valid SAS token if available.
283    ///
284    /// Returns the cached SAS (Shared Access Signature) token if it exists
285    /// and hasn't expired. SAS tokens are used for connection string-based
286    /// authentication with Azure Service Bus.
287    ///
288    /// # Returns
289    ///
290    /// * `Some(token)` - Valid SAS token
291    /// * `None` - No token available or token has expired
292    ///
293    /// # Examples
294    ///
295    /// ```no_run
296    /// use quetty_server::auth::AuthStateManager;
297    /// use std::sync::Arc;
298    ///
299    /// let auth_manager = Arc::new(AuthStateManager::new());
300    ///
301    /// if let Some(sas_token) = auth_manager.get_sas_token().await {
302    ///     println!("Using SAS token: {}", sas_token);
303    ///     // Use token for Service Bus operations
304    /// } else {
305    ///     println!("No valid SAS token available");
306    /// }
307    /// ```
308    pub async fn get_sas_token(&self) -> Option<String> {
309        let state = self.inner.read().await;
310        if let Some((token_str, expires_at)) = &state.sas_token {
311            if Instant::now() < *expires_at {
312                return Some(token_str.clone());
313            }
314        }
315        None
316    }
317
318    /// Stores a SAS token with its expiration time.
319    ///
320    /// Caches a SAS token for future use with automatic expiration handling.
321    /// The token will be considered invalid after the specified duration.
322    ///
323    /// # Arguments
324    ///
325    /// * `token` - The SAS token string to cache
326    /// * `expires_in` - Duration until the token expires
327    ///
328    /// # Examples
329    ///
330    /// ```no_run
331    /// use quetty_server::auth::AuthStateManager;
332    /// use std::sync::Arc;
333    /// use std::time::Duration;
334    ///
335    /// let auth_manager = Arc::new(AuthStateManager::new());
336    /// let token = "SharedAccessSignature sr=...".to_string();
337    /// let expires_in = Duration::from_secs(24 * 3600); // 24 hours
338    ///
339    /// auth_manager.set_sas_token(token, expires_in).await;
340    /// ```
341    pub async fn set_sas_token(&self, token: String, expires_in: Duration) {
342        let mut state = self.inner.write().await;
343        state.sas_token = Some((token, Instant::now() + expires_in));
344    }
345
346    /// Retrieves the connection string from the current authentication state.
347    ///
348    /// Returns the connection string if the user is authenticated and a
349    /// connection string is available in the authentication state.
350    ///
351    /// # Returns
352    ///
353    /// * `Some(connection_string)` - Valid connection string for Service Bus
354    /// * `None` - No connection string available or not authenticated
355    ///
356    /// # Examples
357    ///
358    /// ```no_run
359    /// use quetty_server::auth::AuthStateManager;
360    /// use std::sync::Arc;
361    ///
362    /// let auth_manager = Arc::new(AuthStateManager::new());
363    ///
364    /// if let Some(conn_str) = auth_manager.get_connection_string().await {
365    ///     println!("Using connection string: {}", conn_str);
366    ///     // Use connection string for Service Bus operations
367    /// }
368    /// ```
369    pub async fn get_connection_string(&self) -> Option<String> {
370        let state = self.inner.read().await;
371        match &state.authentication_state {
372            AuthenticationState::Authenticated {
373                connection_string, ..
374            } => connection_string.clone(),
375            _ => None,
376        }
377    }
378
379    /// Retrieves device code information if device code flow is in progress.
380    ///
381    /// Returns the device code information (user code, verification URL, etc.)
382    /// if a device code authentication flow is currently active.
383    ///
384    /// # Returns
385    ///
386    /// * `Some(DeviceCodeInfo)` - Device code flow information
387    /// * `None` - No device code flow is currently active
388    ///
389    /// # Examples
390    ///
391    /// ```no_run
392    /// use quetty_server::auth::AuthStateManager;
393    /// use std::sync::Arc;
394    ///
395    /// let auth_manager = Arc::new(AuthStateManager::new());
396    ///
397    /// if let Some(device_info) = auth_manager.get_device_code_info().await {
398    ///     println!("Go to: {}", device_info.verification_uri);
399    ///     println!("Enter code: {}", device_info.user_code);
400    /// }
401    /// ```
402    pub async fn get_device_code_info(&self) -> Option<DeviceCodeInfo> {
403        let state = self.inner.read().await;
404        match &state.authentication_state {
405            AuthenticationState::AwaitingDeviceCode { info, .. } => Some(info.clone()),
406            _ => None,
407        }
408    }
409
410    // Provider management methods
411
412    /// Sets the authentication provider for Service Bus operations.
413    ///
414    /// Configures the authentication provider that will be used for
415    /// Service Bus data plane operations (sending/receiving messages).
416    ///
417    /// # Arguments
418    ///
419    /// * `provider` - Authentication provider for Service Bus operations
420    ///
421    /// # Examples
422    ///
423    /// ```no_run
424    /// use quetty_server::auth::{AuthStateManager, AzureAdProvider};
425    /// use std::sync::Arc;
426    ///
427    /// let auth_manager = Arc::new(AuthStateManager::new());
428    /// let provider = Arc::new(AzureAdProvider::new(config, client)?);
429    ///
430    /// auth_manager.set_service_bus_provider(provider).await;
431    /// ```
432    pub async fn set_service_bus_provider(&self, provider: Arc<dyn AuthProvider>) {
433        let mut state = self.inner.write().await;
434        state.service_bus_provider = Some(provider);
435    }
436
437    /// Retrieves the current Service Bus authentication provider.
438    ///
439    /// Returns the authentication provider configured for Service Bus
440    /// data plane operations if one has been set.
441    ///
442    /// # Returns
443    ///
444    /// * `Some(provider)` - Configured Service Bus authentication provider
445    /// * `None` - No provider has been configured
446    ///
447    /// # Examples
448    ///
449    /// ```no_run
450    /// use quetty_server::auth::AuthStateManager;
451    /// use std::sync::Arc;
452    ///
453    /// let auth_manager = Arc::new(AuthStateManager::new());
454    ///
455    /// if let Some(provider) = auth_manager.get_service_bus_provider().await {
456    ///     let token = provider.authenticate().await?;
457    ///     // Use token for Service Bus operations
458    /// }
459    /// ```
460    pub async fn get_service_bus_provider(&self) -> Option<Arc<dyn AuthProvider>> {
461        self.inner.read().await.service_bus_provider.clone()
462    }
463
464    /// Sets the authentication provider for Service Bus management operations.
465    ///
466    /// Configures the authentication provider that will be used for
467    /// Service Bus management plane operations (creating queues, topics, etc.).
468    ///
469    /// # Arguments
470    ///
471    /// * `provider` - Authentication provider for management operations
472    ///
473    /// # Examples
474    ///
475    /// ```no_run
476    /// use quetty_server::auth::{AuthStateManager, AzureAdProvider};
477    /// use std::sync::Arc;
478    ///
479    /// let auth_manager = Arc::new(AuthStateManager::new());
480    /// let provider = Arc::new(AzureAdProvider::new(config, client)?);
481    ///
482    /// auth_manager.set_management_provider(provider).await;
483    /// ```
484    pub async fn set_management_provider(&self, provider: Arc<dyn AuthProvider>) {
485        let mut state = self.inner.write().await;
486        state.management_provider = Some(provider);
487    }
488
489    /// Retrieves the current Service Bus management authentication provider.
490    ///
491    /// Returns the authentication provider configured for Service Bus
492    /// management plane operations if one has been set.
493    ///
494    /// # Returns
495    ///
496    /// * `Some(provider)` - Configured management authentication provider
497    /// * `None` - No provider has been configured
498    ///
499    /// # Examples
500    ///
501    /// ```no_run
502    /// use quetty_server::auth::AuthStateManager;
503    /// use std::sync::Arc;
504    ///
505    /// let auth_manager = Arc::new(AuthStateManager::new());
506    ///
507    /// if let Some(provider) = auth_manager.get_management_provider().await {
508    ///     let token = provider.authenticate().await?;
509    ///     // Use token for management operations
510    /// }
511    /// ```
512    pub async fn get_management_provider(&self) -> Option<Arc<dyn AuthProvider>> {
513        self.inner.read().await.management_provider.clone()
514    }
515
516    /// Gets a reference to the token cache.
517    ///
518    /// # Returns
519    ///
520    /// A reference to the [`TokenCache`] for manual token management
521    pub fn get_token_cache(&self) -> &TokenCache {
522        &self.token_cache
523    }
524
525    // Token refresh service management
526
527    /// Starts the automatic token refresh service.
528    ///
529    /// Initiates a background service that automatically refreshes tokens
530    /// before they expire, ensuring continuous authentication without
531    /// user intervention.
532    ///
533    /// # Examples
534    ///
535    /// ```no_run
536    /// use quetty_server::auth::AuthStateManager;
537    /// use std::sync::Arc;
538    ///
539    /// let auth_manager = Arc::new(AuthStateManager::new());
540    ///
541    /// // Start automatic token refresh
542    /// auth_manager.clone().start_refresh_service().await;
543    ///
544    /// // Tokens will now be refreshed automatically in the background
545    /// ```
546    pub async fn start_refresh_service(self: Arc<Self>) {
547        self.start_refresh_service_with_callback(None).await;
548    }
549
550    pub async fn start_refresh_service_with_callback(
551        self: Arc<Self>,
552        failure_callback: Option<super::token_refresh_service::RefreshFailureCallback>,
553    ) {
554        // Stop any existing service
555        self.stop_refresh_service().await;
556
557        // Create and start new service
558        let mut refresh_service = TokenRefreshService::new(self.clone());
559        if let Some(callback) = failure_callback {
560            refresh_service = refresh_service.with_failure_callback(callback);
561        }
562
563        let refresh_service = Arc::new(refresh_service);
564        let handle = refresh_service.clone().start();
565
566        // Store service and handle in consolidated state
567        let mut state = self.inner.write().await;
568        state.refresh_service = Some(refresh_service);
569        state.refresh_handle = Some(handle);
570
571        log::info!("Token refresh service started");
572    }
573
574    /// Stops the automatic token refresh service.
575    ///
576    /// Gracefully shuts down the background token refresh service,
577    /// stopping automatic token renewal. Tokens will no longer be
578    /// refreshed automatically after calling this method.
579    ///
580    /// # Examples
581    ///
582    /// ```no_run
583    /// use quetty_server::auth::AuthStateManager;
584    /// use std::sync::Arc;
585    ///
586    /// let auth_manager = Arc::new(AuthStateManager::new());
587    ///
588    /// // Start refresh service
589    /// auth_manager.clone().start_refresh_service().await;
590    ///
591    /// // Later, stop the service
592    /// auth_manager.stop_refresh_service().await;
593    /// ```
594    pub async fn stop_refresh_service(&self) {
595        // Get service reference and signal shutdown
596        let service_ref = {
597            let state = self.inner.read().await;
598            state.refresh_service.clone()
599        };
600
601        if let Some(service) = service_ref {
602            service.shutdown().await;
603        }
604
605        // Wait for service to stop and clear references
606        let mut state = self.inner.write().await;
607        if let Some(handle) = state.refresh_handle.take() {
608            // Drop the write lock before waiting
609            drop(state);
610            let _ = handle.await;
611
612            // Re-acquire write lock to clear service reference
613            let mut state = self.inner.write().await;
614            state.refresh_service = None;
615        } else {
616            state.refresh_service = None;
617        }
618
619        log::info!("Token refresh service stopped");
620    }
621}
622
623impl Default for AuthStateManager {
624    fn default() -> Self {
625        Self::new()
626    }
627}