Skip to main content

fraiseql_auth/oauth/
failover.rs

1//! Multi-provider failover management.
2
3use std::sync::Arc;
4
5use chrono::{DateTime, Duration, Utc};
6
7use super::super::error::AuthError;
8
9/// Multi-provider failover manager
10#[derive(Debug, Clone)]
11pub struct ProviderFailoverManager {
12    /// Primary provider name
13    primary_provider:   String,
14    /// Fallback providers in priority order
15    fallback_providers: Vec<String>,
16    /// Providers currently unavailable
17    // std::sync::Mutex is intentional: this lock is never held across .await.
18    // Switch to tokio::sync::Mutex if that constraint ever changes.
19    unavailable: Arc<std::sync::Mutex<Vec<(String, DateTime<Utc>)>>>,
20}
21
22impl ProviderFailoverManager {
23    /// Create new failover manager
24    pub fn new(primary: String, fallbacks: Vec<String>) -> Self {
25        Self {
26            primary_provider:   primary,
27            fallback_providers: fallbacks,
28            unavailable:        Arc::new(std::sync::Mutex::new(Vec::new())),
29        }
30    }
31
32    /// Get next available provider
33    ///
34    /// # Errors
35    ///
36    /// Returns `AuthError::Internal` if the mutex is poisoned or no providers are available.
37    pub fn get_available_provider(&self) -> std::result::Result<String, AuthError> {
38        let unavailable = self.unavailable.lock().map_err(|_| AuthError::Internal {
39            message: "failover manager mutex poisoned".to_string(),
40        })?;
41        let now = Utc::now();
42
43        // Check if primary is available
44        if !unavailable
45            .iter()
46            .any(|(name, exp)| name == &self.primary_provider && *exp > now)
47        {
48            return Ok(self.primary_provider.clone());
49        }
50
51        // Find first available fallback
52        for fallback in &self.fallback_providers {
53            if !unavailable.iter().any(|(name, exp)| name == fallback && *exp > now) {
54                return Ok(fallback.clone());
55            }
56        }
57
58        Err(AuthError::Internal {
59            message: "no OAuth providers available".to_string(),
60        })
61    }
62
63    /// Mark provider as unavailable
64    ///
65    /// # Errors
66    ///
67    /// Returns `AuthError::Internal` if the mutex is poisoned.
68    pub fn mark_unavailable(
69        &self,
70        provider: String,
71        duration_seconds: u64,
72    ) -> std::result::Result<(), AuthError> {
73        let mut unavailable = self.unavailable.lock().map_err(|_| AuthError::Internal {
74            message: "failover manager mutex poisoned".to_string(),
75        })?;
76        unavailable
77            .push((provider, Utc::now() + Duration::seconds(duration_seconds.cast_signed())));
78        Ok(())
79    }
80
81    /// Mark provider as available
82    ///
83    /// # Errors
84    ///
85    /// Returns `AuthError::Internal` if the mutex is poisoned.
86    pub fn mark_available(&self, provider: &str) -> std::result::Result<(), AuthError> {
87        let mut unavailable = self.unavailable.lock().map_err(|_| AuthError::Internal {
88            message: "failover manager mutex poisoned".to_string(),
89        })?;
90        unavailable.retain(|(name, _)| name != provider);
91        Ok(())
92    }
93}
94
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_primary_available_by_default() {
        let mgr = ProviderFailoverManager::new("primary".to_string(), vec!["fallback".to_string()]);
        let available = mgr.get_available_provider().expect("must succeed");
        assert_eq!(available, "primary");
    }

    #[test]
    fn test_fallback_used_when_primary_unavailable() {
        let mgr = ProviderFailoverManager::new("primary".to_string(), vec!["fallback".to_string()]);
        mgr.mark_unavailable("primary".to_string(), 300)
            .expect("mark_unavailable must succeed");
        let available = mgr.get_available_provider().expect("must succeed");
        assert_eq!(available, "fallback");
    }

    #[test]
    fn test_all_unavailable_returns_error() {
        let mgr = ProviderFailoverManager::new("primary".to_string(), vec!["fallback".to_string()]);
        mgr.mark_unavailable("primary".to_string(), 300).expect("must succeed");
        mgr.mark_unavailable("fallback".to_string(), 300).expect("must succeed");
        let result = mgr.get_available_provider();
        assert!(result.is_err(), "must return error when no providers are available");
    }

    #[test]
    fn test_mark_available_restores_provider() {
        let mgr = ProviderFailoverManager::new("primary".to_string(), vec!["fallback".to_string()]);
        mgr.mark_unavailable("primary".to_string(), 300).expect("must succeed");
        mgr.mark_available("primary").expect("must succeed");
        let available = mgr.get_available_provider().expect("must succeed");
        assert_eq!(available, "primary", "primary must be available after mark_available");
    }

    #[test]
    fn test_mark_available_on_unknown_provider_is_noop() {
        let mgr = ProviderFailoverManager::new("primary".to_string(), vec!["fallback".to_string()]);
        mgr.mark_available("never-marked").expect("must succeed");
        let available = mgr.get_available_provider().expect("must succeed");
        assert_eq!(available, "primary");
    }

    #[test]
    fn test_no_fallbacks_returns_primary() {
        let mgr = ProviderFailoverManager::new("only".to_string(), vec![]);
        let available = mgr.get_available_provider().expect("must succeed");
        assert_eq!(available, "only");
    }

    #[test]
    fn test_no_fallbacks_primary_unavailable_returns_error() {
        let mgr = ProviderFailoverManager::new("only".to_string(), vec![]);
        mgr.mark_unavailable("only".to_string(), 300).expect("must succeed");
        let result = mgr.get_available_provider();
        assert!(result.is_err());
    }

    #[test]
    fn test_multiple_fallbacks_in_order() {
        let mgr = ProviderFailoverManager::new(
            "primary".to_string(),
            vec!["fb1".to_string(), "fb2".to_string()],
        );
        mgr.mark_unavailable("primary".to_string(), 300).expect("must succeed");
        let available = mgr.get_available_provider().expect("must succeed");
        assert_eq!(available, "fb1", "first fallback must be selected");

        mgr.mark_unavailable("fb1".to_string(), 300).expect("must succeed");
        let available = mgr.get_available_provider().expect("must succeed");
        assert_eq!(available, "fb2", "second fallback must be selected when first is unavailable");
    }

    #[test]
    fn test_zero_duration_window_has_already_lapsed() {
        let mgr = ProviderFailoverManager::new("primary".to_string(), vec!["fallback".to_string()]);
        mgr.mark_unavailable("primary".to_string(), 0).expect("must succeed");
        let available = mgr.get_available_provider().expect("must succeed");
        assert_eq!(available, "primary", "a zero-length window must not block the provider");
    }

    #[test]
    fn test_remarking_with_shorter_window_keeps_longer_window() {
        let mgr = ProviderFailoverManager::new("primary".to_string(), vec!["fallback".to_string()]);
        mgr.mark_unavailable("primary".to_string(), 300).expect("must succeed");
        mgr.mark_unavailable("primary".to_string(), 0).expect("must succeed");
        let available = mgr.get_available_provider().expect("must succeed");
        assert_eq!(available, "fallback", "a shorter re-mark must not cut an active window short");
    }

    #[test]
    fn test_clones_share_availability_state() {
        let mgr = ProviderFailoverManager::new("primary".to_string(), vec!["fallback".to_string()]);
        let clone = mgr.clone();
        clone.mark_unavailable("primary".to_string(), 300).expect("must succeed");
        let available = mgr.get_available_provider().expect("must succeed");
        assert_eq!(available, "fallback", "clones must observe each other's updates");
    }
}