1#[cfg(feature = "enhanced-device-flow")]
5pub mod enhanced_device;
6
7#[cfg(feature = "enhanced-device-flow")]
8pub use enhanced_device::{EnhancedDeviceFlowMethod, DeviceFlowInstructions};
9
10#[cfg(test)]
12#[cfg(feature = "enhanced-device-flow")]
13mod enhanced_device_tests;
14
15use crate::credentials::{Credential, CredentialMetadata};
16use crate::errors::{AuthError, Result};
17use crate::providers::{OAuthProvider, generate_state, generate_pkce};
18use crate::tokens::{AuthToken, TokenManager};
19use async_trait::async_trait;
20use serde::{Deserialize, Serialize};
21use std::collections::HashMap;
22use std::time::Duration;
23
24#[derive(Debug, Clone)]
26pub enum MethodResult {
27 Success(Box<AuthToken>),
29
30 MfaRequired(Box<MfaChallenge>),
32
33 Failure { reason: String },
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct MfaChallenge {
40 pub id: String,
42
43 pub mfa_type: MfaType,
45
46 pub user_id: String,
48
49 pub expires_at: chrono::DateTime<chrono::Utc>,
51
52 pub message: Option<String>,
54
55 pub data: HashMap<String, serde_json::Value>,
57}
58
59#[derive(Debug, Clone, Serialize, Deserialize)]
61pub enum MfaType {
62 Totp,
64
65 Sms { phone_number: String },
67
68 Email { email_address: String },
70
71 Push { device_id: String },
73
74 SecurityKey,
76
77 BackupCode,
79}
80
81#[async_trait]
83pub trait AuthMethod: Send + Sync {
84 fn name(&self) -> &str;
86
87 async fn authenticate(
89 &self,
90 credential: &Credential,
91 metadata: &CredentialMetadata,
92 ) -> Result<MethodResult>;
93
94 fn validate_config(&self) -> Result<()>;
96
97 fn supports_refresh(&self) -> bool {
99 false
100 }
101
102 async fn refresh_token(&self, _refresh_token: &str) -> Result<AuthToken> {
104 Err(AuthError::auth_method(
105 self.name(),
106 "Token refresh not supported by this method".to_string(),
107 ))
108 }
109}
110
111pub struct PasswordMethod {
113 name: String,
114 password_verifier: Box<dyn PasswordVerifier>,
115 token_manager: TokenManager,
116 mfa_enabled: bool,
117 user_lookup: Box<dyn UserLookup>,
118}
119
120pub struct JwtMethod {
122 name: String,
123 token_manager: TokenManager,
124 issuer: String,
125 audience: String,
126}
127
128pub struct ApiKeyMethod {
130 name: String,
131 key_prefix: Option<String>,
132 header_name: String,
133 key_validator: Box<dyn ApiKeyValidator>,
134 token_manager: TokenManager,
135}
136
137pub struct OAuth2Method {
139 name: String,
140 provider: OAuthProvider,
141 client_id: String,
142 client_secret: String,
143 redirect_uri: String,
144 scopes: Vec<String>,
145 use_pkce: bool,
146 token_manager: TokenManager,
147}
148
149#[async_trait]
151pub trait PasswordVerifier: Send + Sync {
152 async fn verify_password(&self, username: &str, password: &str) -> Result<bool>;
154
155 async fn hash_password(&self, password: &str) -> Result<String>;
157}
158
159#[async_trait]
161pub trait UserLookup: Send + Sync {
162 async fn lookup_user(&self, username: &str) -> Result<Option<UserInfo>>;
164
165 async fn requires_mfa(&self, user_id: &str) -> Result<bool>;
167}
168
169#[async_trait]
171pub trait ApiKeyValidator: Send + Sync {
172 async fn validate_key(&self, api_key: &str) -> Result<Option<UserInfo>>;
174
175 async fn create_key(&self, user_id: &str, expires_in: Option<Duration>) -> Result<String>;
177
178 async fn revoke_key(&self, api_key: &str) -> Result<()>;
180}
181
182#[derive(Debug, Clone, Serialize, Deserialize)]
184pub struct UserInfo {
185 pub id: String,
187
188 pub username: String,
190
191 pub email: Option<String>,
193
194 pub name: Option<String>,
196
197 pub roles: Vec<String>,
199
200 pub active: bool,
202
203 pub attributes: HashMap<String, serde_json::Value>,
205}
206
207impl MfaChallenge {
208 pub fn new(
210 mfa_type: MfaType,
211 user_id: impl Into<String>,
212 expires_in: Duration,
213 ) -> Self {
214 Self {
215 id: uuid::Uuid::new_v4().to_string(),
216 mfa_type,
217 user_id: user_id.into(),
218 expires_at: chrono::Utc::now() + chrono::Duration::from_std(expires_in).unwrap(),
219 message: None,
220 data: HashMap::new(),
221 }
222 }
223
224 pub fn id(&self) -> &str {
226 &self.id
227 }
228
229 pub fn is_expired(&self) -> bool {
231 chrono::Utc::now() > self.expires_at
232 }
233
234 pub fn with_message(mut self, message: impl Into<String>) -> Self {
236 self.message = Some(message.into());
237 self
238 }
239}
240
241impl PasswordMethod {
242 pub fn new(
244 password_verifier: Box<dyn PasswordVerifier>,
245 user_lookup: Box<dyn UserLookup>,
246 token_manager: TokenManager,
247 ) -> Self {
248 Self {
249 name: "password".to_string(),
250 password_verifier,
251 token_manager,
252 mfa_enabled: false,
253 user_lookup,
254 }
255 }
256
257 pub fn with_mfa(mut self, enabled: bool) -> Self {
259 self.mfa_enabled = enabled;
260 self
261 }
262}
263
264#[async_trait]
265impl AuthMethod for PasswordMethod {
266 fn name(&self) -> &str {
267 &self.name
268 }
269
270 async fn authenticate(
271 &self,
272 credential: &Credential,
273 _metadata: &CredentialMetadata,
274 ) -> Result<MethodResult> {
275 let (username, password) = match credential {
276 Credential::Password { username, password } => (username, password),
277 _ => return Err(AuthError::auth_method(
278 self.name(),
279 "Invalid credential type for password authentication".to_string(),
280 )),
281 };
282
283 if !self.password_verifier.verify_password(username, password).await? {
285 return Ok(MethodResult::Failure {
286 reason: "Invalid username or password".to_string(),
287 });
288 }
289
290 let user = self.user_lookup.lookup_user(username).await?
292 .ok_or_else(|| AuthError::auth_method(
293 self.name(),
294 "User not found".to_string(),
295 ))?;
296
297 if !user.active {
298 return Ok(MethodResult::Failure {
299 reason: "User account is disabled".to_string(),
300 });
301 }
302
303 if self.mfa_enabled && self.user_lookup.requires_mfa(&user.id).await? {
305 let challenge = MfaChallenge::new(
306 MfaType::Totp, &user.id,
308 Duration::from_secs(300), ).with_message("Please enter your MFA code");
310
311 return Ok(MethodResult::MfaRequired(Box::new(challenge)));
312 }
313
314 let token = self.token_manager.create_auth_token(
316 &user.id,
317 vec![], self.name(),
319 None,
320 )?;
321
322 Ok(MethodResult::Success(Box::new(token)))
323 }
324
325 fn validate_config(&self) -> Result<()> {
326 Ok(())
328 }
329}
330
331impl Default for JwtMethod {
332 fn default() -> Self {
333 Self::new()
334 }
335}
336
337impl JwtMethod {
338 pub fn new() -> Self {
340 let token_manager = TokenManager::new_hmac(
341 b"default-secret", "default-issuer",
343 "default-audience",
344 );
345
346 Self {
347 name: "jwt".to_string(),
348 token_manager,
349 issuer: "default-issuer".to_string(),
350 audience: "default-audience".to_string(),
351 }
352 }
353
354 pub fn secret_key(mut self, secret: impl Into<String>) -> Self {
356 let secret = secret.into();
357 self.token_manager = TokenManager::new_hmac(
358 secret.as_bytes(),
359 &self.issuer,
360 &self.audience,
361 );
362 self
363 }
364
365 pub fn issuer(mut self, issuer: impl Into<String>) -> Self {
367 self.issuer = issuer.into();
368 self.token_manager = TokenManager::new_hmac(
369 b"default-secret", &self.issuer,
371 &self.audience,
372 );
373 self
374 }
375
376 pub fn audience(mut self, audience: impl Into<String>) -> Self {
378 self.audience = audience.into();
379 self.token_manager = TokenManager::new_hmac(
380 b"default-secret", &self.issuer,
382 &self.audience,
383 );
384 self
385 }
386
387 pub fn algorithm(self, _algorithm: impl Into<String>) -> Self {
389 self
391 }
392}
393
394#[async_trait]
395impl AuthMethod for JwtMethod {
396 fn name(&self) -> &str {
397 &self.name
398 }
399
400 async fn authenticate(
401 &self,
402 credential: &Credential,
403 _metadata: &CredentialMetadata,
404 ) -> Result<MethodResult> {
405 let token_str = match credential {
406 Credential::Jwt { token } => token,
407 Credential::Bearer { token } => token,
408 _ => return Err(AuthError::auth_method(
409 self.name(),
410 "Invalid credential type for JWT authentication".to_string(),
411 )),
412 };
413
414 let claims = self.token_manager.validate_jwt_token(token_str)?;
416
417 let remaining_seconds = (claims.exp - chrono::Utc::now().timestamp()).max(0) as u64;
419 let token = AuthToken::new(
420 claims.sub,
421 token_str.clone(),
422 std::time::Duration::from_secs(remaining_seconds),
423 self.name(),
424 ).with_scopes(claims.scope.split_whitespace().map(|s| s.to_string()).collect());
425
426 Ok(MethodResult::Success(Box::new(token)))
427 }
428
429 fn validate_config(&self) -> Result<()> {
430 Ok(())
432 }
433}
434
435impl Default for ApiKeyMethod {
436 fn default() -> Self {
437 Self::new()
438 }
439}
440
441impl ApiKeyMethod {
442 pub fn new() -> Self {
444 let token_manager = TokenManager::new_hmac(
445 b"default-secret",
446 "api-key-issuer",
447 "api-key-audience",
448 );
449
450 Self {
451 name: "api-key".to_string(),
452 key_prefix: None,
453 header_name: "X-API-Key".to_string(),
454 key_validator: Box::new(DefaultApiKeyValidator),
455 token_manager,
456 }
457 }
458
459 pub fn key_prefix(mut self, prefix: impl Into<String>) -> Self {
461 self.key_prefix = Some(prefix.into());
462 self
463 }
464
465 pub fn key_length(self, _length: usize) -> Self {
467 self
469 }
470
471 pub fn header_name(mut self, name: impl Into<String>) -> Self {
473 self.header_name = name.into();
474 self
475 }
476
477 pub fn key_validator(mut self, validator: Box<dyn ApiKeyValidator>) -> Self {
479 self.key_validator = validator;
480 self
481 }
482}
483
484#[async_trait]
485impl AuthMethod for ApiKeyMethod {
486 fn name(&self) -> &str {
487 &self.name
488 }
489
490 async fn authenticate(
491 &self,
492 credential: &Credential,
493 _metadata: &CredentialMetadata,
494 ) -> Result<MethodResult> {
495 let api_key = match credential {
496 Credential::ApiKey { key } => key,
497 _ => return Err(AuthError::auth_method(
498 self.name(),
499 "Invalid credential type for API key authentication".to_string(),
500 )),
501 };
502
503 if let Some(prefix) = &self.key_prefix {
505 if !api_key.starts_with(prefix) {
506 return Ok(MethodResult::Failure {
507 reason: "Invalid API key format".to_string(),
508 });
509 }
510 }
511
512 let user = self.key_validator.validate_key(api_key).await?
514 .ok_or_else(|| AuthError::auth_method(
515 self.name(),
516 "Invalid API key".to_string(),
517 ))?;
518
519 let token = self.token_manager.create_auth_token(
521 &user.id,
522 vec!["api".to_string()], self.name(),
524 Some(std::time::Duration::from_secs(3600)), )?;
526
527 Ok(MethodResult::Success(Box::new(token)))
528 }
529
530 fn validate_config(&self) -> Result<()> {
531 Ok(())
532 }
533}
534
535impl Default for OAuth2Method {
536 fn default() -> Self {
537 Self::new()
538 }
539}
540
541impl OAuth2Method {
542 pub fn new() -> Self {
544 let token_manager = TokenManager::new_hmac(
545 b"oauth-secret",
546 "oauth-issuer",
547 "oauth-audience",
548 );
549
550 Self {
551 name: "oauth2".to_string(),
552 provider: OAuthProvider::GitHub, client_id: String::new(),
554 client_secret: String::new(),
555 redirect_uri: String::new(),
556 scopes: Vec::new(),
557 use_pkce: true,
558 token_manager,
559 }
560 }
561
562 pub fn provider(mut self, provider: OAuthProvider) -> Self {
564 self.provider = provider;
565 self
566 }
567
568 pub fn client_id(mut self, client_id: impl Into<String>) -> Self {
570 self.client_id = client_id.into();
571 self
572 }
573
574 pub fn client_secret(mut self, client_secret: impl Into<String>) -> Self {
576 self.client_secret = client_secret.into();
577 self
578 }
579
580 pub fn redirect_uri(mut self, redirect_uri: impl Into<String>) -> Self {
582 self.redirect_uri = redirect_uri.into();
583 self
584 }
585
586 pub fn scopes(mut self, scopes: Vec<String>) -> Self {
588 self.scopes = scopes;
589 self
590 }
591
592 pub fn use_pkce(mut self, use_pkce: bool) -> Self {
594 self.use_pkce = use_pkce;
595 self
596 }
597
598 pub fn authorization_url(&self) -> Result<AuthorizationUrlResult> {
600 let state = generate_state();
601 let pkce = if self.use_pkce {
602 Some(generate_pkce())
603 } else {
604 None
605 };
606
607 let url = self.provider.build_authorization_url(
608 &self.client_id,
609 &self.redirect_uri,
610 &state,
611 if self.scopes.is_empty() { None } else { Some(&self.scopes) },
612 pkce.as_ref().map(|(_, challenge)| challenge.as_str()),
613 )?;
614
615 Ok((url, state, pkce))
616 }
617}
618
619#[async_trait]
620impl AuthMethod for OAuth2Method {
621 fn name(&self) -> &str {
622 &self.name
623 }
624
625 async fn authenticate(
626 &self,
627 credential: &Credential,
628 _metadata: &CredentialMetadata,
629 ) -> Result<MethodResult> {
630 let (authorization_code, code_verifier) = match credential {
631 Credential::OAuth { authorization_code, code_verifier, .. } => {
632 (authorization_code, code_verifier.as_deref())
633 }
634 _ => return Err(AuthError::auth_method(
635 self.name(),
636 "Invalid credential type for OAuth authentication".to_string(),
637 )),
638 };
639
640 let token_response = self.provider.exchange_code(
642 &self.client_id,
643 &self.client_secret,
644 authorization_code,
645 &self.redirect_uri,
646 code_verifier,
647 ).await?;
648
649 let user_info = self.provider.get_user_info(&token_response.access_token).await?;
651
652 let expires_in = token_response.expires_in
655 .map(std::time::Duration::from_secs)
656 .unwrap_or_else(|| std::time::Duration::from_secs(3600));
657
658 let mut token = self.token_manager.create_auth_token(
659 &user_info.id,
660 token_response.scope
661 .unwrap_or_default()
662 .split_whitespace()
663 .map(|s| s.to_string())
664 .collect(),
665 self.name(),
666 Some(expires_in),
667 )?;
668
669 if let Some(refresh_token) = token_response.refresh_token {
671 token = token.with_refresh_token(refresh_token);
672 }
673
674 Ok(MethodResult::Success(Box::new(token)))
675 }
676
677 fn validate_config(&self) -> Result<()> {
678 if self.client_id.is_empty() {
679 return Err(AuthError::config("OAuth client ID is required"));
680 }
681 if self.client_secret.is_empty() {
682 return Err(AuthError::config("OAuth client secret is required"));
683 }
684 if self.redirect_uri.is_empty() {
685 return Err(AuthError::config("OAuth redirect URI is required"));
686 }
687 Ok(())
688 }
689
690 fn supports_refresh(&self) -> bool {
691 self.provider.config().supports_refresh
692 }
693
694 async fn refresh_token(&self, refresh_token: &str) -> Result<AuthToken> {
695 let token_response = self.provider.refresh_token(
696 &self.client_id,
697 &self.client_secret,
698 refresh_token,
699 ).await?;
700
701 let expires_in = token_response.expires_in
702 .map(Duration::from_secs)
703 .unwrap_or_else(|| std::time::Duration::from_secs(3600));
704
705 let token = self.token_manager.create_auth_token(
708 "unknown", token_response.scope
710 .unwrap_or_default()
711 .split_whitespace()
712 .map(|s| s.to_string())
713 .collect(),
714 self.name(),
715 Some(expires_in),
716 )?;
717
718 Ok(token)
719 }
720}
721
/// PKCE pair produced by `generate_pkce`; `authorization_url` sends the second
/// element to the provider as the challenge (the first is presumably the
/// verifier; confirm against `generate_pkce`).
type PkceParams = (String, String);

/// Return value of `OAuth2Method::authorization_url`:
/// `(authorization URL, state, optional PKCE pair)`.
type AuthorizationUrlResult = (String, String, Option<PkceParams>);
727
728#[derive(Debug, Clone)]
730struct DefaultApiKeyValidator;
731
732#[async_trait]
733impl ApiKeyValidator for DefaultApiKeyValidator {
734 async fn validate_key(&self, _api_key: &str) -> Result<Option<UserInfo>> {
735 Ok(None)
737 }
738
739 async fn create_key(&self, _user_id: &str, _expires_in: Option<Duration>) -> Result<String> {
740 Ok(format!("api-{}", uuid::Uuid::new_v4()))
742 }
743
744 async fn revoke_key(&self, _api_key: &str) -> Result<()> {
745 Ok(())
747 }
748}
749
#[cfg(test)]
mod tests {
    use super::*;

    /// A new challenge keeps its user id, is not yet expired, and has a 36-character id.
    #[test]
    fn test_mfa_challenge() {
        let ttl = Duration::from_secs(300);
        let fresh = MfaChallenge::new(MfaType::Totp, "user123", ttl);

        assert_eq!(fresh.user_id, "user123");
        assert!(!fresh.is_expired());
        assert_eq!(fresh.id().len(), 36);
    }

    /// The JWT builder records the issuer and audience it was given.
    #[test]
    fn test_jwt_method_creation() {
        let method = JwtMethod::new()
            .secret_key("test-secret")
            .issuer("test-issuer")
            .audience("test-audience");

        assert_eq!(method.name(), "jwt");
        assert_eq!(method.issuer, "test-issuer");
        assert_eq!(method.audience, "test-audience");
    }

    /// The OAuth2 builder records the client id and leaves PKCE enabled by default.
    #[test]
    fn test_oauth2_method_creation() {
        let method = OAuth2Method::new()
            .provider(OAuthProvider::GitHub)
            .client_id("test-client")
            .client_secret("test-secret")
            .redirect_uri("https://example.com/callback");

        assert_eq!(method.name(), "oauth2");
        assert_eq!(method.client_id, "test-client");
        assert!(method.use_pkce);
    }
}