1use std::{collections::HashMap, sync::Arc};
4
5use chrono::{DateTime, Duration, Utc};
6use serde::{Deserialize, Serialize};
7
8use super::{super::error::AuthError, client::OIDCProviderConfig};
9
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
12#[non_exhaustive]
13pub enum ProviderType {
14 OAuth2,
16 OIDC,
18}
19
20impl std::fmt::Display for ProviderType {
21 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22 match self {
23 Self::OAuth2 => write!(f, "oauth2"),
24 Self::OIDC => write!(f, "oidc"),
25 }
26 }
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct OAuthSession {
32 pub id: String,
34 pub user_id: String,
36 pub provider_type: ProviderType,
38 pub provider_name: String,
40 pub provider_user_id: String,
42 pub access_token: String,
44 pub refresh_token: Option<String>,
46 pub token_expiry: DateTime<Utc>,
48 pub created_at: DateTime<Utc>,
50 pub last_refreshed: Option<DateTime<Utc>>,
52}
53
54impl OAuthSession {
55 pub fn new(
57 user_id: String,
58 provider_type: ProviderType,
59 provider_name: String,
60 provider_user_id: String,
61 access_token: String,
62 token_expiry: DateTime<Utc>,
63 ) -> Self {
64 Self {
65 id: uuid::Uuid::new_v4().to_string(),
66 user_id,
67 provider_type,
68 provider_name,
69 provider_user_id,
70 access_token,
71 refresh_token: None,
72 token_expiry,
73 created_at: Utc::now(),
74 last_refreshed: None,
75 }
76 }
77
78 pub fn is_expired(&self) -> bool {
80 self.token_expiry <= Utc::now()
81 }
82
83 pub fn is_expiring_soon(&self, grace_seconds: i64) -> bool {
85 self.token_expiry <= (Utc::now() + Duration::seconds(grace_seconds))
86 }
87
88 pub fn refresh_tokens(&mut self, access_token: String, token_expiry: DateTime<Utc>) {
90 self.access_token = access_token;
91 self.token_expiry = token_expiry;
92 self.last_refreshed = Some(Utc::now());
93 }
94}
95
96#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
98pub struct ExternalAuthProvider {
99 pub id: String,
101 pub provider_type: ProviderType,
103 pub provider_name: String,
105 pub client_id: String,
107 pub client_secret_vault_path: String,
109 pub oidc_config: Option<OIDCProviderConfig>,
111 pub oauth2_config: Option<OAuth2ClientConfig>,
113 pub enabled: bool,
115 pub scopes: Vec<String>,
117}
118
119#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
121pub struct OAuth2ClientConfig {
122 pub authorization_endpoint: String,
124 pub token_endpoint: String,
126 pub use_pkce: bool,
128}
129
130impl ExternalAuthProvider {
131 pub fn new(
133 provider_type: ProviderType,
134 provider_name: impl Into<String>,
135 client_id: impl Into<String>,
136 client_secret_vault_path: impl Into<String>,
137 ) -> Self {
138 Self {
139 id: uuid::Uuid::new_v4().to_string(),
140 provider_type,
141 provider_name: provider_name.into(),
142 client_id: client_id.into(),
143 client_secret_vault_path: client_secret_vault_path.into(),
144 oidc_config: None,
145 oauth2_config: None,
146 enabled: true,
147 scopes: vec![
148 "openid".to_string(),
149 "profile".to_string(),
150 "email".to_string(),
151 ],
152 }
153 }
154
155 pub const fn set_enabled(&mut self, enabled: bool) {
157 self.enabled = enabled;
158 }
159
160 pub fn set_scopes(&mut self, scopes: Vec<String>) {
162 self.scopes = scopes;
163 }
164}
165
166#[derive(Debug, Clone)]
168pub struct ProviderRegistry {
169 providers: Arc<std::sync::Mutex<HashMap<String, ExternalAuthProvider>>>,
173}
174
175impl ProviderRegistry {
176 pub fn new() -> Self {
178 Self {
179 providers: Arc::new(std::sync::Mutex::new(HashMap::new())),
180 }
181 }
182
183 pub fn register(&self, provider: ExternalAuthProvider) -> std::result::Result<(), AuthError> {
189 let mut providers = self.providers.lock().map_err(|_| AuthError::Internal {
190 message: "provider registry mutex poisoned".to_string(),
191 })?;
192 providers.insert(provider.provider_name.clone(), provider);
193 Ok(())
194 }
195
196 pub fn get(&self, name: &str) -> std::result::Result<Option<ExternalAuthProvider>, AuthError> {
202 let providers = self.providers.lock().map_err(|_| AuthError::Internal {
203 message: "provider registry mutex poisoned".to_string(),
204 })?;
205 Ok(providers.get(name).cloned())
206 }
207
208 pub fn list_enabled(&self) -> std::result::Result<Vec<ExternalAuthProvider>, AuthError> {
214 let providers = self.providers.lock().map_err(|_| AuthError::Internal {
215 message: "provider registry mutex poisoned".to_string(),
216 })?;
217 Ok(providers.values().filter(|p| p.enabled).cloned().collect())
218 }
219
220 pub fn disable(&self, name: &str) -> std::result::Result<bool, AuthError> {
226 let mut providers = self.providers.lock().map_err(|_| AuthError::Internal {
227 message: "provider registry mutex poisoned".to_string(),
228 })?;
229 if let Some(provider) = providers.get_mut(name) {
230 provider.set_enabled(false);
231 Ok(true)
232 } else {
233 Ok(false)
234 }
235 }
236
237 pub fn enable(&self, name: &str) -> std::result::Result<bool, AuthError> {
243 let mut providers = self.providers.lock().map_err(|_| AuthError::Internal {
244 message: "provider registry mutex poisoned".to_string(),
245 })?;
246 if let Some(provider) = providers.get_mut(name) {
247 provider.set_enabled(true);
248 Ok(true)
249 } else {
250 Ok(false)
251 }
252 }
253}
254
255impl Default for ProviderRegistry {
256 fn default() -> Self {
257 Self::new()
258 }
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the OIDC `auth0` session for `user_1` used by the session tests.
    fn auth0_session(access_token: &str, token_expiry: DateTime<Utc>) -> OAuthSession {
        OAuthSession::new(
            "user_1".to_string(),
            ProviderType::OIDC,
            "auth0".to_string(),
            "auth0|sub".to_string(),
            access_token.to_string(),
            token_expiry,
        )
    }

    /// Creates a fresh registry and registers `providers` in order.
    fn registry_with(providers: Vec<ExternalAuthProvider>) -> ProviderRegistry {
        let registry = ProviderRegistry::new();
        for provider in providers {
            registry.register(provider).expect("register must succeed");
        }
        registry
    }

    /// Looks up `name`, failing the test if the lookup errors or finds nothing.
    fn fetch(registry: &ProviderRegistry, name: &str) -> ExternalAuthProvider {
        registry
            .get(name)
            .expect("get must succeed")
            .expect("provider must exist")
    }

    #[test]
    fn test_provider_type_display() {
        let cases = [(ProviderType::OAuth2, "oauth2"), (ProviderType::OIDC, "oidc")];
        for (kind, expected) in cases {
            assert_eq!(kind.to_string(), expected);
        }
    }

    #[test]
    fn test_session_is_not_expired_on_creation() {
        let session = auth0_session("at_fresh", Utc::now() + Duration::hours(1));
        assert!(!session.is_expired(), "freshly created session must not be expired");
    }

    #[test]
    fn test_session_is_expiring_soon() {
        let session = auth0_session("at", Utc::now() + Duration::seconds(30));
        assert!(
            session.is_expiring_soon(60),
            "session expiring in 30s must be considered expiring soon with grace=60"
        );
        assert!(
            !session.is_expiring_soon(10),
            "session expiring in 30s must not be considered expiring soon with grace=10"
        );
    }

    #[test]
    fn test_session_refresh_tokens_updates_fields() {
        let mut session = auth0_session("old_at", Utc::now() + Duration::hours(1));
        let new_expiry = Utc::now() + Duration::hours(2);
        session.refresh_tokens("new_at".to_string(), new_expiry);

        assert_eq!(session.access_token, "new_at");
        assert_eq!(session.token_expiry, new_expiry);
        assert!(session.last_refreshed.is_some());
    }

    #[test]
    fn test_external_provider_defaults() {
        let provider =
            ExternalAuthProvider::new(ProviderType::OIDC, "auth0", "client_id", "vault/secret");
        assert!(provider.enabled, "new provider must be enabled by default");
        assert_eq!(provider.scopes, ["openid", "profile", "email"]);
        assert!(provider.oidc_config.is_none());
        assert!(provider.oauth2_config.is_none());
    }

    #[test]
    fn test_external_provider_set_enabled() {
        let mut provider =
            ExternalAuthProvider::new(ProviderType::OAuth2, "google", "id", "vault/path");
        for flag in [false, true] {
            provider.set_enabled(flag);
            assert_eq!(provider.enabled, flag);
        }
    }

    #[test]
    fn test_external_provider_set_scopes() {
        let mut provider =
            ExternalAuthProvider::new(ProviderType::OAuth2, "google", "id", "vault/path");
        provider.set_scopes(vec!["openid".to_string()]);
        assert_eq!(provider.scopes, ["openid"]);
    }

    #[test]
    fn test_registry_register_and_get() {
        let provider = ExternalAuthProvider::new(ProviderType::OIDC, "auth0", "id", "vault/path");
        let registry = registry_with(vec![provider.clone()]);
        assert_eq!(registry.get("auth0").expect("get must succeed"), Some(provider));
    }

    #[test]
    fn test_registry_get_nonexistent_returns_none() {
        let registry = ProviderRegistry::new();
        assert!(registry.get("nonexistent").expect("get must succeed").is_none());
    }

    #[test]
    fn test_registry_list_enabled_filters_disabled() {
        let active = ExternalAuthProvider::new(ProviderType::OIDC, "auth0", "id1", "path1");
        let mut inactive =
            ExternalAuthProvider::new(ProviderType::OAuth2, "google", "id2", "path2");
        inactive.set_enabled(false);
        let registry = registry_with(vec![active, inactive]);

        let enabled = registry.list_enabled().expect("list_enabled must succeed");
        assert_eq!(enabled.len(), 1);
        assert_eq!(enabled[0].provider_name, "auth0");
    }

    #[test]
    fn test_registry_disable_and_enable() {
        let registry = registry_with(vec![ExternalAuthProvider::new(
            ProviderType::OIDC,
            "auth0",
            "id",
            "path",
        )]);

        assert!(registry.disable("auth0").expect("disable must succeed"));
        assert!(!fetch(&registry, "auth0").enabled);

        assert!(registry.enable("auth0").expect("enable must succeed"));
        assert!(fetch(&registry, "auth0").enabled);
    }

    #[test]
    fn test_registry_disable_nonexistent_returns_false() {
        let registry = ProviderRegistry::new();
        assert!(!registry.disable("nonexistent").expect("disable must succeed"));
    }
}