Skip to main content

fraiseql_auth/oauth/
provider.rs

1//! External OAuth provider registry and session management.
2
3use std::{collections::HashMap, sync::Arc};
4
5use chrono::{DateTime, Duration, Utc};
6use serde::{Deserialize, Serialize};
7
8use super::{super::error::AuthError, client::OIDCProviderConfig};
9
10/// External authentication provider type
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
12#[non_exhaustive]
13pub enum ProviderType {
14    /// OAuth2 provider
15    OAuth2,
16    /// OIDC provider
17    OIDC,
18}
19
20impl std::fmt::Display for ProviderType {
21    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22        match self {
23            Self::OAuth2 => write!(f, "oauth2"),
24            Self::OIDC => write!(f, "oidc"),
25        }
26    }
27}
28
29/// OAuth session stored in database
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct OAuthSession {
32    /// Session ID
33    pub id:               String,
34    /// User ID (local system)
35    pub user_id:          String,
36    /// Provider type (oauth2, oidc)
37    pub provider_type:    ProviderType,
38    /// Provider name (Auth0, Google, etc.)
39    pub provider_name:    String,
40    /// Provider's user ID (sub claim)
41    pub provider_user_id: String,
42    /// Access token (encrypted)
43    pub access_token:     String,
44    /// Refresh token (encrypted), if available
45    pub refresh_token:    Option<String>,
46    /// When access token expires
47    pub token_expiry:     DateTime<Utc>,
48    /// Session creation time
49    pub created_at:       DateTime<Utc>,
50    /// Last time token was refreshed
51    pub last_refreshed:   Option<DateTime<Utc>>,
52}
53
54impl OAuthSession {
55    /// Create new OAuth session
56    pub fn new(
57        user_id: String,
58        provider_type: ProviderType,
59        provider_name: String,
60        provider_user_id: String,
61        access_token: String,
62        token_expiry: DateTime<Utc>,
63    ) -> Self {
64        Self {
65            id: uuid::Uuid::new_v4().to_string(),
66            user_id,
67            provider_type,
68            provider_name,
69            provider_user_id,
70            access_token,
71            refresh_token: None,
72            token_expiry,
73            created_at: Utc::now(),
74            last_refreshed: None,
75        }
76    }
77
78    /// Check if session is expired
79    pub fn is_expired(&self) -> bool {
80        self.token_expiry <= Utc::now()
81    }
82
83    /// Check if session will be expired within grace period
84    pub fn is_expiring_soon(&self, grace_seconds: i64) -> bool {
85        self.token_expiry <= (Utc::now() + Duration::seconds(grace_seconds))
86    }
87
88    /// Update tokens after refresh
89    pub fn refresh_tokens(&mut self, access_token: String, token_expiry: DateTime<Utc>) {
90        self.access_token = access_token;
91        self.token_expiry = token_expiry;
92        self.last_refreshed = Some(Utc::now());
93    }
94}
95
96/// External auth provider configuration
97#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
98pub struct ExternalAuthProvider {
99    /// Provider ID
100    pub id: String,
101    /// Provider type (oauth2, oidc)
102    pub provider_type: ProviderType,
103    /// Provider name (Auth0, Google, Microsoft, Okta)
104    pub provider_name: String,
105    /// Client ID
106    pub client_id: String,
107    /// Client secret (should be fetched from vault)
108    pub client_secret_vault_path: String,
109    /// Provider configuration (OIDC)
110    pub oidc_config: Option<OIDCProviderConfig>,
111    /// OAuth2 configuration
112    pub oauth2_config: Option<OAuth2ClientConfig>,
113    /// Enabled flag
114    pub enabled: bool,
115    /// Requested scopes
116    pub scopes: Vec<String>,
117}
118
119/// OAuth2 client configuration
120#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
121pub struct OAuth2ClientConfig {
122    /// Authorization endpoint
123    pub authorization_endpoint: String,
124    /// Token endpoint
125    pub token_endpoint:         String,
126    /// Use PKCE
127    pub use_pkce:               bool,
128}
129
130impl ExternalAuthProvider {
131    /// Create new external auth provider
132    pub fn new(
133        provider_type: ProviderType,
134        provider_name: impl Into<String>,
135        client_id: impl Into<String>,
136        client_secret_vault_path: impl Into<String>,
137    ) -> Self {
138        Self {
139            id: uuid::Uuid::new_v4().to_string(),
140            provider_type,
141            provider_name: provider_name.into(),
142            client_id: client_id.into(),
143            client_secret_vault_path: client_secret_vault_path.into(),
144            oidc_config: None,
145            oauth2_config: None,
146            enabled: true,
147            scopes: vec![
148                "openid".to_string(),
149                "profile".to_string(),
150                "email".to_string(),
151            ],
152        }
153    }
154
155    /// Enable or disable provider
156    pub const fn set_enabled(&mut self, enabled: bool) {
157        self.enabled = enabled;
158    }
159
160    /// Set requested scopes
161    pub fn set_scopes(&mut self, scopes: Vec<String>) {
162        self.scopes = scopes;
163    }
164}
165
166/// Provider registry managing multiple OAuth providers
167#[derive(Debug, Clone)]
168pub struct ProviderRegistry {
169    /// Map of providers by name
170    // std::sync::Mutex is intentional: this lock is never held across .await.
171    // Switch to tokio::sync::Mutex if that constraint ever changes.
172    providers: Arc<std::sync::Mutex<HashMap<String, ExternalAuthProvider>>>,
173}
174
175impl ProviderRegistry {
176    /// Create new provider registry
177    pub fn new() -> Self {
178        Self {
179            providers: Arc::new(std::sync::Mutex::new(HashMap::new())),
180        }
181    }
182
183    /// Register provider
184    ///
185    /// # Errors
186    ///
187    /// Returns `AuthError::Internal` if the mutex is poisoned.
188    pub fn register(&self, provider: ExternalAuthProvider) -> std::result::Result<(), AuthError> {
189        let mut providers = self.providers.lock().map_err(|_| AuthError::Internal {
190            message: "provider registry mutex poisoned".to_string(),
191        })?;
192        providers.insert(provider.provider_name.clone(), provider);
193        Ok(())
194    }
195
196    /// Get provider by name
197    ///
198    /// # Errors
199    ///
200    /// Returns `AuthError::Internal` if the mutex is poisoned.
201    pub fn get(&self, name: &str) -> std::result::Result<Option<ExternalAuthProvider>, AuthError> {
202        let providers = self.providers.lock().map_err(|_| AuthError::Internal {
203            message: "provider registry mutex poisoned".to_string(),
204        })?;
205        Ok(providers.get(name).cloned())
206    }
207
208    /// List all enabled providers
209    ///
210    /// # Errors
211    ///
212    /// Returns `AuthError::Internal` if the mutex is poisoned.
213    pub fn list_enabled(&self) -> std::result::Result<Vec<ExternalAuthProvider>, AuthError> {
214        let providers = self.providers.lock().map_err(|_| AuthError::Internal {
215            message: "provider registry mutex poisoned".to_string(),
216        })?;
217        Ok(providers.values().filter(|p| p.enabled).cloned().collect())
218    }
219
220    /// Disable provider
221    ///
222    /// # Errors
223    ///
224    /// Returns `AuthError::Internal` if the mutex is poisoned.
225    pub fn disable(&self, name: &str) -> std::result::Result<bool, AuthError> {
226        let mut providers = self.providers.lock().map_err(|_| AuthError::Internal {
227            message: "provider registry mutex poisoned".to_string(),
228        })?;
229        if let Some(provider) = providers.get_mut(name) {
230            provider.set_enabled(false);
231            Ok(true)
232        } else {
233            Ok(false)
234        }
235    }
236
237    /// Enable provider
238    ///
239    /// # Errors
240    ///
241    /// Returns `AuthError::Internal` if the mutex is poisoned.
242    pub fn enable(&self, name: &str) -> std::result::Result<bool, AuthError> {
243        let mut providers = self.providers.lock().map_err(|_| AuthError::Internal {
244            message: "provider registry mutex poisoned".to_string(),
245        })?;
246        if let Some(provider) = providers.get_mut(name) {
247            provider.set_enabled(true);
248            Ok(true)
249        } else {
250            Ok(false)
251        }
252    }
253}
254
255impl Default for ProviderRegistry {
256    fn default() -> Self {
257        Self::new()
258    }
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a session for `user_1` at `auth0` with the given token and expiry.
    fn session_expiring_at(access_token: &str, token_expiry: DateTime<Utc>) -> OAuthSession {
        OAuthSession::new(
            "user_1".to_string(),
            ProviderType::OIDC,
            "auth0".to_string(),
            "auth0|sub".to_string(),
            access_token.to_string(),
            token_expiry,
        )
    }

    // --- ProviderType tests ---

    #[test]
    fn test_provider_type_display() {
        assert_eq!(ProviderType::OAuth2.to_string(), "oauth2");
        assert_eq!(ProviderType::OIDC.to_string(), "oidc");
    }

    // --- OAuthSession tests ---

    #[test]
    fn test_session_is_not_expired_on_creation() {
        let session = session_expiring_at("at_fresh", Utc::now() + Duration::hours(1));
        assert!(!session.is_expired(), "freshly created session must not be expired");
    }

    #[test]
    fn test_session_is_expired_when_expiry_in_past() {
        let session = session_expiring_at("at_old", Utc::now() - Duration::seconds(5));
        assert!(session.is_expired(), "session with past expiry must be expired");
        assert!(session.is_expiring_soon(0), "expired session is also expiring soon");
    }

    #[test]
    fn test_session_is_expiring_soon() {
        let session = session_expiring_at("at", Utc::now() + Duration::seconds(30));
        assert!(
            session.is_expiring_soon(60),
            "session expiring in 30s must be considered expiring soon with grace=60"
        );
        assert!(
            !session.is_expiring_soon(10),
            "session expiring in 30s must not be considered expiring soon with grace=10"
        );
    }

    #[test]
    fn test_session_refresh_tokens_updates_fields() {
        let mut session = session_expiring_at("old_at", Utc::now() + Duration::hours(1));
        let new_expiry = Utc::now() + Duration::hours(2);
        session.refresh_tokens("new_at".to_string(), new_expiry);

        assert_eq!(session.access_token, "new_at");
        assert_eq!(session.token_expiry, new_expiry);
        assert!(session.last_refreshed.is_some());
    }

    #[test]
    fn test_session_refresh_tokens_keeps_refresh_token() {
        let mut session = session_expiring_at("old_at", Utc::now() + Duration::hours(1));
        session.refresh_token = Some("rt".to_string());
        session.refresh_tokens("new_at".to_string(), Utc::now() + Duration::hours(2));
        assert_eq!(session.refresh_token.as_deref(), Some("rt"));
    }

    // --- ExternalAuthProvider tests ---

    #[test]
    fn test_external_provider_defaults() {
        let provider =
            ExternalAuthProvider::new(ProviderType::OIDC, "auth0", "client_id", "vault/secret");
        assert!(provider.enabled, "new provider must be enabled by default");
        assert_eq!(provider.scopes, vec!["openid", "profile", "email"]);
        assert!(provider.oidc_config.is_none());
        assert!(provider.oauth2_config.is_none());
    }

    #[test]
    fn test_external_provider_set_enabled() {
        let mut provider =
            ExternalAuthProvider::new(ProviderType::OAuth2, "google", "id", "vault/path");
        provider.set_enabled(false);
        assert!(!provider.enabled);
        provider.set_enabled(true);
        assert!(provider.enabled);
    }

    #[test]
    fn test_external_provider_set_scopes() {
        let mut provider =
            ExternalAuthProvider::new(ProviderType::OAuth2, "google", "id", "vault/path");
        provider.set_scopes(vec!["openid".to_string()]);
        assert_eq!(provider.scopes, vec!["openid"]);
    }

    // --- ProviderRegistry tests ---

    #[test]
    fn test_registry_register_and_get() {
        let registry = ProviderRegistry::new();
        let provider = ExternalAuthProvider::new(ProviderType::OIDC, "auth0", "id", "vault/path");
        registry.register(provider.clone()).expect("register must succeed");
        let retrieved = registry.get("auth0").expect("get must succeed");
        assert_eq!(retrieved, Some(provider));
    }

    #[test]
    fn test_registry_register_same_name_replaces() {
        let registry = ProviderRegistry::new();
        registry
            .register(ExternalAuthProvider::new(ProviderType::OIDC, "auth0", "old_id", "path"))
            .expect("register must succeed");
        registry
            .register(ExternalAuthProvider::new(ProviderType::OIDC, "auth0", "new_id", "path"))
            .expect("register must succeed");

        let p = registry.get("auth0").expect("get must succeed").expect("provider must exist");
        assert_eq!(p.client_id, "new_id");
        assert_eq!(registry.list_enabled().expect("list_enabled must succeed").len(), 1);
    }

    #[test]
    fn test_registry_get_nonexistent_returns_none() {
        let registry = ProviderRegistry::new();
        let retrieved = registry.get("nonexistent").expect("get must succeed");
        assert!(retrieved.is_none());
    }

    #[test]
    fn test_registry_list_enabled_filters_disabled() {
        let registry = ProviderRegistry::new();
        let p1 = ExternalAuthProvider::new(ProviderType::OIDC, "auth0", "id1", "path1");
        let mut p2 = ExternalAuthProvider::new(ProviderType::OAuth2, "google", "id2", "path2");
        p2.set_enabled(false);
        registry.register(p1).expect("register must succeed");
        registry.register(p2).expect("register must succeed");

        let enabled = registry.list_enabled().expect("list_enabled must succeed");
        assert_eq!(enabled.len(), 1);
        assert_eq!(enabled[0].provider_name, "auth0");
    }

    #[test]
    fn test_registry_disable_and_enable() {
        let registry = ProviderRegistry::new();
        let provider = ExternalAuthProvider::new(ProviderType::OIDC, "auth0", "id", "path");
        registry.register(provider).expect("register must succeed");

        assert!(registry.disable("auth0").expect("disable must succeed"));
        let p = registry.get("auth0").expect("get must succeed").expect("provider must exist");
        assert!(!p.enabled);

        assert!(registry.enable("auth0").expect("enable must succeed"));
        let p = registry.get("auth0").expect("get must succeed").expect("provider must exist");
        assert!(p.enabled);
    }

    #[test]
    fn test_registry_disable_nonexistent_returns_false() {
        let registry = ProviderRegistry::new();
        assert!(!registry.disable("nonexistent").expect("disable must succeed"));
    }

    #[test]
    fn test_registry_enable_nonexistent_returns_false() {
        let registry = ProviderRegistry::new();
        assert!(!registry.enable("nonexistent").expect("enable must succeed"));
    }

    #[test]
    fn test_registry_clone_shares_state() {
        let registry = ProviderRegistry::new();
        let handle = registry.clone();
        handle
            .register(ExternalAuthProvider::new(ProviderType::OIDC, "auth0", "id", "path"))
            .expect("register must succeed");
        assert!(registry.get("auth0").expect("get must succeed").is_some());
    }
}