fraiseql-auth 2.2.0

Authentication, authorization, and session management for FraiseQL
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
//! External OAuth provider registry and session management.

use std::{collections::HashMap, sync::Arc};

use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};

use super::{super::error::AuthError, client::OIDCProviderConfig};

/// External authentication provider type.
///
/// Note: `Display` renders lowercase labels (`oauth2` / `oidc`), while the serde
/// derive has no `rename` attributes, so serialized values use the Rust variant
/// names (`OAuth2` / `OIDC`). The two text forms are not interchangeable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
// Downstream crates must use a wildcard arm when matching, so variants can be added later.
#[non_exhaustive]
pub enum ProviderType {
    /// OAuth2 provider
    OAuth2,
    /// OIDC (OpenID Connect) provider
    OIDC,
}

impl std::fmt::Display for ProviderType {
    /// Writes the lowercase label for the provider type: `oauth2` or `oidc`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::OAuth2 => "oauth2",
            Self::OIDC => "oidc",
        };
        f.write_str(label)
    }
}

/// OAuth session linking a local user to an external provider account.
///
/// NOTE(review): described as "stored in database"; this type only derives
/// `Serialize`/`Deserialize` and does no persistence itself — confirm the store.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthSession {
    /// Session ID (a random UUID v4 string when built with [`OAuthSession::new`])
    pub id:               String,
    /// User ID (local system)
    pub user_id:          String,
    /// Provider type; see [`ProviderType`] for its text representations
    pub provider_type:    ProviderType,
    /// Provider name (Auth0, Google, etc.)
    pub provider_name:    String,
    /// Provider's user ID (sub claim)
    pub provider_user_id: String,
    /// Access token. Intended to hold an encrypted value; this type stores an
    /// opaque `String` and does not encrypt anything itself.
    pub access_token:     String,
    /// Refresh token, if available (same encryption caveat as `access_token`).
    /// `None` after `new`; `refresh_tokens` does not modify it.
    pub refresh_token:    Option<String>,
    /// When the access token expires; compared against `Utc::now()` by the expiry checks
    pub token_expiry:     DateTime<Utc>,
    /// Session creation time (`Utc::now()` at construction via `new`)
    pub created_at:       DateTime<Utc>,
    /// Last time the tokens were refreshed; `None` until `refresh_tokens` is called
    pub last_refreshed:   Option<DateTime<Utc>>,
}

impl OAuthSession {
    /// Create a new OAuth session for `user_id` bound to a provider account.
    ///
    /// The session gets a random UUID v4 `id` and `created_at = Utc::now()`.
    /// It starts without a refresh token and without a `last_refreshed` timestamp.
    pub fn new(
        user_id: String,
        provider_type: ProviderType,
        provider_name: String,
        provider_user_id: String,
        access_token: String,
        token_expiry: DateTime<Utc>,
    ) -> Self {
        Self {
            id: uuid::Uuid::new_v4().to_string(),
            user_id,
            provider_type,
            provider_name,
            provider_user_id,
            access_token,
            refresh_token: None,
            token_expiry,
            created_at: Utc::now(),
            last_refreshed: None,
        }
    }

    /// Returns `true` once the access token's expiry time has been reached.
    ///
    /// An expiry exactly equal to the current instant counts as expired.
    pub fn is_expired(&self) -> bool {
        self.token_expiry <= Utc::now()
    }

    /// Returns `true` if the access token expires within `grace_seconds` from now.
    ///
    /// For a non-negative grace this includes tokens that have already expired.
    /// A negative grace only matches tokens that expired at least that long ago.
    ///
    /// Extreme values do not panic. A grace too large for `chrono::Duration`, or
    /// one that pushes the deadline past the representable `DateTime` range,
    /// treats every token as expiring soon (positive) or none (negative).
    pub fn is_expiring_soon(&self, grace_seconds: i64) -> bool {
        // `Duration::seconds` panics outside ±i64::MAX milliseconds, so clamp first.
        const MAX_GRACE_SECONDS: i64 = i64::MAX / 1_000;
        let grace = Duration::seconds(grace_seconds.clamp(-MAX_GRACE_SECONDS, MAX_GRACE_SECONDS));
        // `DateTime + Duration` panics on overflow; `checked_add_signed` reports it instead.
        Utc::now()
            .checked_add_signed(grace)
            .map_or(grace_seconds > 0, |deadline| self.token_expiry <= deadline)
    }

    /// Update the access token and its expiry after a refresh, and stamp `last_refreshed`.
    ///
    /// The stored `refresh_token` is left unchanged.
    pub fn refresh_tokens(&mut self, access_token: String, token_expiry: DateTime<Utc>) {
        self.access_token = access_token;
        self.token_expiry = token_expiry;
        self.last_refreshed = Some(Utc::now());
    }
}

/// External auth provider configuration.
///
/// NOTE(review): nothing here ties `oidc_config` / `oauth2_config` to
/// `provider_type`. Either, both, or neither may be set; callers must check.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ExternalAuthProvider {
    /// Provider ID (a random UUID v4 string when created via `new`)
    pub id: String,
    /// Provider type (oauth2, oidc)
    pub provider_type: ProviderType,
    /// Provider name (Auth0, Google, Microsoft, Okta); `ProviderRegistry` keys
    /// providers by this value
    pub provider_name: String,
    /// OAuth client ID
    pub client_id: String,
    /// Vault path of the client secret. The secret itself is not stored here
    /// and should be fetched from the vault when needed.
    pub client_secret_vault_path: String,
    /// Optional OIDC provider configuration (`None` after `new`)
    pub oidc_config: Option<OIDCProviderConfig>,
    /// Optional OAuth2 endpoint configuration (`None` after `new`)
    pub oauth2_config: Option<OAuth2ClientConfig>,
    /// Whether the provider is enabled; `ProviderRegistry::list_enabled` skips disabled ones
    pub enabled: bool,
    /// Scopes to request (`new` defaults them to `openid`, `profile`, `email`)
    pub scopes: Vec<String>,
}

/// OAuth2 client configuration: endpoint URLs plus the PKCE toggle.
///
/// The endpoint strings are stored as given; nothing here validates them as URLs.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct OAuth2ClientConfig {
    /// Authorization endpoint
    pub authorization_endpoint: String,
    /// Token endpoint
    pub token_endpoint:         String,
    /// Whether to use PKCE (Proof Key for Code Exchange, RFC 7636)
    pub use_pkce:               bool,
}

impl ExternalAuthProvider {
    /// Build a provider entry identified by a freshly generated UUID v4.
    ///
    /// The entry starts enabled and has no OIDC or OAuth2 endpoint
    /// configuration attached. It requests the `openid`, `profile` and `email`
    /// scopes until [`Self::set_scopes`] replaces them.
    pub fn new(
        provider_type: ProviderType,
        provider_name: impl Into<String>,
        client_id: impl Into<String>,
        client_secret_vault_path: impl Into<String>,
    ) -> Self {
        let id = uuid::Uuid::new_v4().to_string();
        let scopes = ["openid", "profile", "email"]
            .iter()
            .copied()
            .map(String::from)
            .collect();

        Self {
            id,
            provider_type,
            provider_name: provider_name.into(),
            client_id: client_id.into(),
            client_secret_vault_path: client_secret_vault_path.into(),
            oidc_config: None,
            oauth2_config: None,
            enabled: true,
            scopes,
        }
    }

    /// Switch the provider on (`true`) or off (`false`).
    pub const fn set_enabled(&mut self, enabled: bool) {
        self.enabled = enabled;
    }

    /// Replace the full list of requested scopes.
    pub fn set_scopes(&mut self, scopes: Vec<String>) {
        self.scopes = scopes;
    }
}

/// Provider registry managing multiple OAuth providers.
///
/// The map sits behind `Arc<Mutex<..>>`. Cloning a registry therefore yields
/// another handle to the same providers, not an independent copy.
#[derive(Debug, Clone)]
pub struct ProviderRegistry {
    /// Map of providers keyed by `provider_name`
    // std::sync::Mutex is intentional: this lock is never held across .await.
    // Switch to tokio::sync::Mutex if that constraint ever changes.
    providers: Arc<std::sync::Mutex<HashMap<String, ExternalAuthProvider>>>,
}

impl ProviderRegistry {
    /// Create an empty provider registry.
    pub fn new() -> Self {
        Self {
            providers: Arc::new(std::sync::Mutex::new(HashMap::new())),
        }
    }

    /// Lock the provider map, turning a poisoned mutex into `AuthError::Internal`.
    ///
    /// The returned guard is a `std::sync::MutexGuard` and must not be held across `.await`.
    fn lock_providers(
        &self,
    ) -> std::result::Result<
        std::sync::MutexGuard<'_, HashMap<String, ExternalAuthProvider>>,
        AuthError,
    > {
        self.providers.lock().map_err(|_| AuthError::Internal {
            message: "provider registry mutex poisoned".to_string(),
        })
    }

    /// Set the `enabled` flag of the named provider.
    ///
    /// Returns `Ok(true)` if the provider exists and `Ok(false)` otherwise.
    fn set_enabled_by_name(&self, name: &str, enabled: bool) -> std::result::Result<bool, AuthError> {
        let mut providers = self.lock_providers()?;
        match providers.get_mut(name) {
            Some(provider) => {
                provider.set_enabled(enabled);
                Ok(true)
            },
            None => Ok(false),
        }
    }

    /// Register a provider under its `provider_name`.
    ///
    /// A provider already registered under the same name is replaced.
    ///
    /// # Errors
    ///
    /// Returns `AuthError::Internal` if the mutex is poisoned.
    pub fn register(&self, provider: ExternalAuthProvider) -> std::result::Result<(), AuthError> {
        let mut providers = self.lock_providers()?;
        providers.insert(provider.provider_name.clone(), provider);
        Ok(())
    }

    /// Get a copy of the provider registered under `name`, if any.
    ///
    /// # Errors
    ///
    /// Returns `AuthError::Internal` if the mutex is poisoned.
    pub fn get(&self, name: &str) -> std::result::Result<Option<ExternalAuthProvider>, AuthError> {
        let providers = self.lock_providers()?;
        Ok(providers.get(name).cloned())
    }

    /// List copies of all enabled providers, sorted by `provider_name`.
    ///
    /// Sorting gives a stable order; raw `HashMap` iteration order varies between
    /// processes.
    ///
    /// # Errors
    ///
    /// Returns `AuthError::Internal` if the mutex is poisoned.
    pub fn list_enabled(&self) -> std::result::Result<Vec<ExternalAuthProvider>, AuthError> {
        // Collect inside a scope so the lock is released before sorting.
        let mut enabled: Vec<ExternalAuthProvider> = {
            let providers = self.lock_providers()?;
            providers.values().filter(|p| p.enabled).cloned().collect()
        };
        enabled.sort_unstable_by(|a, b| a.provider_name.cmp(&b.provider_name));
        Ok(enabled)
    }

    /// Disable a provider.
    ///
    /// Returns `Ok(true)` if the provider exists, `Ok(false)` if no provider has that name.
    ///
    /// # Errors
    ///
    /// Returns `AuthError::Internal` if the mutex is poisoned.
    pub fn disable(&self, name: &str) -> std::result::Result<bool, AuthError> {
        self.set_enabled_by_name(name, false)
    }

    /// Enable a provider.
    ///
    /// Returns `Ok(true)` if the provider exists, `Ok(false)` if no provider has that name.
    ///
    /// # Errors
    ///
    /// Returns `AuthError::Internal` if the mutex is poisoned.
    pub fn enable(&self, name: &str) -> std::result::Result<bool, AuthError> {
        self.set_enabled_by_name(name, true)
    }
}

impl Default for ProviderRegistry {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // --- ProviderType tests ---

    #[test]
    fn test_provider_type_display() {
        assert_eq!(ProviderType::OAuth2.to_string(), "oauth2");
        assert_eq!(ProviderType::OIDC.to_string(), "oidc");
    }

    // --- OAuthSession tests ---

    #[test]
    fn test_session_is_not_expired_on_creation() {
        let session = OAuthSession::new(
            "user_1".to_string(),
            ProviderType::OIDC,
            "auth0".to_string(),
            "auth0|sub".to_string(),
            "at_fresh".to_string(),
            Utc::now() + Duration::hours(1),
        );
        assert!(!session.is_expired(), "freshly created session must not be expired");
    }

    #[test]
    fn test_session_is_expired_when_expiry_in_past() {
        let session = OAuthSession::new(
            "user_1".to_string(),
            ProviderType::OIDC,
            "auth0".to_string(),
            "auth0|sub".to_string(),
            "at_stale".to_string(),
            Utc::now() - Duration::seconds(1),
        );
        assert!(session.is_expired(), "session past its expiry must be expired");
        assert!(
            session.is_expiring_soon(0),
            "an already-expired session is also expiring soon with grace=0"
        );
    }

    #[test]
    fn test_session_new_defaults() {
        let make = || {
            OAuthSession::new(
                "user_1".to_string(),
                ProviderType::OAuth2,
                "google".to_string(),
                "google-sub".to_string(),
                "at".to_string(),
                Utc::now() + Duration::hours(1),
            )
        };
        let a = make();
        let b = make();
        assert!(a.refresh_token.is_none());
        assert!(a.last_refreshed.is_none());
        assert_ne!(a.id, b.id, "each session must get its own id");
    }

    #[test]
    fn test_session_is_expiring_soon() {
        let session = OAuthSession::new(
            "user_1".to_string(),
            ProviderType::OIDC,
            "auth0".to_string(),
            "auth0|sub".to_string(),
            "at".to_string(),
            Utc::now() + Duration::seconds(30),
        );
        assert!(
            session.is_expiring_soon(60),
            "session expiring in 30s must be considered expiring soon with grace=60"
        );
        assert!(
            !session.is_expiring_soon(10),
            "session expiring in 30s must not be considered expiring soon with grace=10"
        );
    }

    #[test]
    fn test_session_refresh_tokens_updates_fields() {
        let mut session = OAuthSession::new(
            "user_1".to_string(),
            ProviderType::OIDC,
            "auth0".to_string(),
            "auth0|sub".to_string(),
            "old_at".to_string(),
            Utc::now() + Duration::hours(1),
        );
        let new_expiry = Utc::now() + Duration::hours(2);
        session.refresh_tokens("new_at".to_string(), new_expiry);

        assert_eq!(session.access_token, "new_at");
        assert_eq!(session.token_expiry, new_expiry);
        assert!(session.last_refreshed.is_some());
    }

    #[test]
    fn test_session_refresh_tokens_keeps_refresh_token() {
        let mut session = OAuthSession::new(
            "user_1".to_string(),
            ProviderType::OIDC,
            "auth0".to_string(),
            "auth0|sub".to_string(),
            "old_at".to_string(),
            Utc::now() + Duration::hours(1),
        );
        session.refresh_token = Some("rt".to_string());
        session.refresh_tokens("new_at".to_string(), Utc::now() + Duration::hours(2));
        assert_eq!(session.refresh_token.as_deref(), Some("rt"));
    }

    // --- ExternalAuthProvider tests ---

    #[test]
    fn test_external_provider_defaults() {
        let provider =
            ExternalAuthProvider::new(ProviderType::OIDC, "auth0", "client_id", "vault/secret");
        assert!(provider.enabled, "new provider must be enabled by default");
        assert_eq!(provider.scopes, vec!["openid", "profile", "email"]);
        assert!(provider.oidc_config.is_none());
        assert!(provider.oauth2_config.is_none());
    }

    #[test]
    fn test_external_provider_set_enabled() {
        let mut provider =
            ExternalAuthProvider::new(ProviderType::OAuth2, "google", "id", "vault/path");
        provider.set_enabled(false);
        assert!(!provider.enabled);
        provider.set_enabled(true);
        assert!(provider.enabled);
    }

    #[test]
    fn test_external_provider_set_scopes() {
        let mut provider =
            ExternalAuthProvider::new(ProviderType::OAuth2, "google", "id", "vault/path");
        provider.set_scopes(vec!["openid".to_string()]);
        assert_eq!(provider.scopes, vec!["openid"]);
    }

    // --- ProviderRegistry tests ---

    #[test]
    fn test_registry_register_and_get() {
        let registry = ProviderRegistry::new();
        let provider = ExternalAuthProvider::new(ProviderType::OIDC, "auth0", "id", "vault/path");
        registry.register(provider.clone()).expect("register must succeed");
        let retrieved = registry.get("auth0").expect("get must succeed");
        assert_eq!(retrieved, Some(provider));
    }

    #[test]
    fn test_registry_register_same_name_replaces() {
        let registry = ProviderRegistry::new();
        registry
            .register(ExternalAuthProvider::new(ProviderType::OIDC, "auth0", "id1", "path1"))
            .expect("register must succeed");
        registry
            .register(ExternalAuthProvider::new(ProviderType::OIDC, "auth0", "id2", "path2"))
            .expect("register must succeed");

        let p = registry.get("auth0").expect("get must succeed").expect("provider must exist");
        assert_eq!(p.client_id, "id2", "later registration must replace the earlier one");
        assert_eq!(registry.list_enabled().expect("list_enabled must succeed").len(), 1);
    }

    #[test]
    fn test_registry_get_nonexistent_returns_none() {
        let registry = ProviderRegistry::new();
        let retrieved = registry.get("nonexistent").expect("get must succeed");
        assert!(retrieved.is_none());
    }

    #[test]
    fn test_registry_list_enabled_filters_disabled() {
        let registry = ProviderRegistry::new();
        let p1 = ExternalAuthProvider::new(ProviderType::OIDC, "auth0", "id1", "path1");
        let mut p2 = ExternalAuthProvider::new(ProviderType::OAuth2, "google", "id2", "path2");
        p2.set_enabled(false);
        registry.register(p1).expect("register must succeed");
        registry.register(p2).expect("register must succeed");

        let enabled = registry.list_enabled().expect("list_enabled must succeed");
        assert_eq!(enabled.len(), 1);
        assert_eq!(enabled[0].provider_name, "auth0");
    }

    #[test]
    fn test_registry_disable_and_enable() {
        let registry = ProviderRegistry::new();
        let provider = ExternalAuthProvider::new(ProviderType::OIDC, "auth0", "id", "path");
        registry.register(provider).expect("register must succeed");

        assert!(registry.disable("auth0").expect("disable must succeed"));
        let p = registry.get("auth0").expect("get must succeed").expect("provider must exist");
        assert!(!p.enabled);

        assert!(registry.enable("auth0").expect("enable must succeed"));
        let p = registry.get("auth0").expect("get must succeed").expect("provider must exist");
        assert!(p.enabled);
    }

    #[test]
    fn test_registry_disable_nonexistent_returns_false() {
        let registry = ProviderRegistry::new();
        assert!(!registry.disable("nonexistent").expect("disable must succeed"));
    }

    #[test]
    fn test_registry_enable_nonexistent_returns_false() {
        let registry = ProviderRegistry::new();
        assert!(!registry.enable("nonexistent").expect("enable must succeed"));
    }

    #[test]
    fn test_registry_clone_shares_providers() {
        let registry = ProviderRegistry::new();
        let handle = registry.clone();
        handle
            .register(ExternalAuthProvider::new(ProviderType::OIDC, "auth0", "id", "path"))
            .expect("register must succeed");
        assert!(
            registry.get("auth0").expect("get must succeed").is_some(),
            "a cloned registry must share the same provider map"
        );
    }
}