1use crate::credentials::{Credential, CredentialMetadata};
4use crate::errors::{AuthError, Result};
5use crate::providers::{OAuthProvider, generate_state, generate_pkce};
6use crate::tokens::{AuthToken, TokenManager};
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::time::Duration;
11
/// Outcome of a single [`AuthMethod::authenticate`] call.
#[derive(Debug, Clone)]
pub enum MethodResult {
    /// Credentials accepted; carries the issued token.
    Success(Box<AuthToken>),

    /// First factor accepted, but a second factor must still be completed.
    MfaRequired(Box<MfaChallenge>),

    /// Credentials rejected; `reason` is a human-readable explanation.
    Failure { reason: String },
}
24
/// A pending second-factor challenge, returned inside
/// [`MethodResult::MfaRequired`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MfaChallenge {
    /// Challenge identifier (a UUID v4 string when built with `MfaChallenge::new`).
    pub id: String,

    /// Which kind of second factor is expected.
    pub mfa_type: MfaType,

    /// Identifier of the user the challenge was issued for.
    pub user_id: String,

    /// Instant after which `is_expired` reports `true`.
    pub expires_at: chrono::DateTime<chrono::Utc>,

    /// Optional prompt text, set via `with_message`.
    pub message: Option<String>,

    /// Extra factor-specific data; empty when built with `MfaChallenge::new`.
    pub data: HashMap<String, serde_json::Value>,
}
46
/// The kind of second factor an [`MfaChallenge`] asks for.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MfaType {
    /// Time-based one-time password.
    Totp,

    /// SMS factor; `phone_number` is presumably the delivery target.
    Sms { phone_number: String },

    /// Email factor; `email_address` is presumably the delivery target.
    Email { email_address: String },

    /// Push-notification factor sent to the device identified by `device_id`.
    Push { device_id: String },

    /// Hardware security key.
    SecurityKey,

    /// One of the user's backup codes.
    BackupCode,
}
68
69#[async_trait]
71pub trait AuthMethod: Send + Sync {
72 fn name(&self) -> &str;
74
75 async fn authenticate(
77 &self,
78 credential: &Credential,
79 metadata: &CredentialMetadata,
80 ) -> Result<MethodResult>;
81
82 fn validate_config(&self) -> Result<()>;
84
85 fn supports_refresh(&self) -> bool {
87 false
88 }
89
90 async fn refresh_token(&self, _refresh_token: &str) -> Result<AuthToken> {
92 Err(AuthError::auth_method(
93 self.name(),
94 "Token refresh not supported by this method".to_string(),
95 ))
96 }
97}
98
/// Username/password authentication with optional MFA.
pub struct PasswordMethod {
    name: String,
    password_verifier: Box<dyn PasswordVerifier>,
    token_manager: TokenManager,
    // Off by default; when on, users for whom `UserLookup::requires_mfa`
    // returns true receive a TOTP challenge instead of a token.
    mfa_enabled: bool,
    user_lookup: Box<dyn UserLookup>,
}
107
108pub struct JwtMethod {
110 name: String,
111 token_manager: TokenManager,
112 issuer: String,
113 audience: String,
114}
115
116pub struct ApiKeyMethod {
118 name: String,
119 key_prefix: Option<String>,
120 header_name: String,
121 key_validator: Box<dyn ApiKeyValidator>,
122 token_manager: TokenManager,
123}
124
125pub struct OAuth2Method {
127 name: String,
128 provider: OAuthProvider,
129 client_id: String,
130 client_secret: String,
131 redirect_uri: String,
132 scopes: Vec<String>,
133 use_pkce: bool,
134 token_manager: TokenManager,
135}
136
137#[async_trait]
139pub trait PasswordVerifier: Send + Sync {
140 async fn verify_password(&self, username: &str, password: &str) -> Result<bool>;
142
143 async fn hash_password(&self, password: &str) -> Result<String>;
145}
146
147#[async_trait]
149pub trait UserLookup: Send + Sync {
150 async fn lookup_user(&self, username: &str) -> Result<Option<UserInfo>>;
152
153 async fn requires_mfa(&self, user_id: &str) -> Result<bool>;
155}
156
157#[async_trait]
159pub trait ApiKeyValidator: Send + Sync {
160 async fn validate_key(&self, api_key: &str) -> Result<Option<UserInfo>>;
162
163 async fn create_key(&self, user_id: &str, expires_in: Option<Duration>) -> Result<String>;
165
166 async fn revoke_key(&self, api_key: &str) -> Result<()>;
168}
169
170#[derive(Debug, Clone, Serialize, Deserialize)]
172pub struct UserInfo {
173 pub id: String,
175
176 pub username: String,
178
179 pub email: Option<String>,
181
182 pub name: Option<String>,
184
185 pub roles: Vec<String>,
187
188 pub active: bool,
190
191 pub attributes: HashMap<String, serde_json::Value>,
193}
194
195impl MfaChallenge {
196 pub fn new(
198 mfa_type: MfaType,
199 user_id: impl Into<String>,
200 expires_in: Duration,
201 ) -> Self {
202 Self {
203 id: uuid::Uuid::new_v4().to_string(),
204 mfa_type,
205 user_id: user_id.into(),
206 expires_at: chrono::Utc::now() + chrono::Duration::from_std(expires_in).unwrap(),
207 message: None,
208 data: HashMap::new(),
209 }
210 }
211
212 pub fn id(&self) -> &str {
214 &self.id
215 }
216
217 pub fn is_expired(&self) -> bool {
219 chrono::Utc::now() > self.expires_at
220 }
221
222 pub fn with_message(mut self, message: impl Into<String>) -> Self {
224 self.message = Some(message.into());
225 self
226 }
227}
228
229impl PasswordMethod {
230 pub fn new(
232 password_verifier: Box<dyn PasswordVerifier>,
233 user_lookup: Box<dyn UserLookup>,
234 token_manager: TokenManager,
235 ) -> Self {
236 Self {
237 name: "password".to_string(),
238 password_verifier,
239 token_manager,
240 mfa_enabled: false,
241 user_lookup,
242 }
243 }
244
245 pub fn with_mfa(mut self, enabled: bool) -> Self {
247 self.mfa_enabled = enabled;
248 self
249 }
250}
251
252#[async_trait]
253impl AuthMethod for PasswordMethod {
254 fn name(&self) -> &str {
255 &self.name
256 }
257
258 async fn authenticate(
259 &self,
260 credential: &Credential,
261 _metadata: &CredentialMetadata,
262 ) -> Result<MethodResult> {
263 let (username, password) = match credential {
264 Credential::Password { username, password } => (username, password),
265 _ => return Err(AuthError::auth_method(
266 self.name(),
267 "Invalid credential type for password authentication".to_string(),
268 )),
269 };
270
271 if !self.password_verifier.verify_password(username, password).await? {
273 return Ok(MethodResult::Failure {
274 reason: "Invalid username or password".to_string(),
275 });
276 }
277
278 let user = self.user_lookup.lookup_user(username).await?
280 .ok_or_else(|| AuthError::auth_method(
281 self.name(),
282 "User not found".to_string(),
283 ))?;
284
285 if !user.active {
286 return Ok(MethodResult::Failure {
287 reason: "User account is disabled".to_string(),
288 });
289 }
290
291 if self.mfa_enabled && self.user_lookup.requires_mfa(&user.id).await? {
293 let challenge = MfaChallenge::new(
294 MfaType::Totp, &user.id,
296 Duration::from_secs(300), ).with_message("Please enter your MFA code");
298
299 return Ok(MethodResult::MfaRequired(Box::new(challenge)));
300 }
301
302 let token = self.token_manager.create_auth_token(
304 &user.id,
305 vec![], self.name(),
307 None,
308 )?;
309
310 Ok(MethodResult::Success(Box::new(token)))
311 }
312
313 fn validate_config(&self) -> Result<()> {
314 Ok(())
316 }
317}
318
319impl Default for JwtMethod {
320 fn default() -> Self {
321 Self::new()
322 }
323}
324
325impl JwtMethod {
326 pub fn new() -> Self {
328 let token_manager = TokenManager::new_hmac(
329 b"default-secret", "default-issuer",
331 "default-audience",
332 );
333
334 Self {
335 name: "jwt".to_string(),
336 token_manager,
337 issuer: "default-issuer".to_string(),
338 audience: "default-audience".to_string(),
339 }
340 }
341
342 pub fn secret_key(mut self, secret: impl Into<String>) -> Self {
344 let secret = secret.into();
345 self.token_manager = TokenManager::new_hmac(
346 secret.as_bytes(),
347 &self.issuer,
348 &self.audience,
349 );
350 self
351 }
352
353 pub fn issuer(mut self, issuer: impl Into<String>) -> Self {
355 self.issuer = issuer.into();
356 self.token_manager = TokenManager::new_hmac(
357 b"default-secret", &self.issuer,
359 &self.audience,
360 );
361 self
362 }
363
364 pub fn audience(mut self, audience: impl Into<String>) -> Self {
366 self.audience = audience.into();
367 self.token_manager = TokenManager::new_hmac(
368 b"default-secret", &self.issuer,
370 &self.audience,
371 );
372 self
373 }
374
375 pub fn algorithm(self, _algorithm: impl Into<String>) -> Self {
377 self
379 }
380}
381
382#[async_trait]
383impl AuthMethod for JwtMethod {
384 fn name(&self) -> &str {
385 &self.name
386 }
387
388 async fn authenticate(
389 &self,
390 credential: &Credential,
391 _metadata: &CredentialMetadata,
392 ) -> Result<MethodResult> {
393 let token_str = match credential {
394 Credential::Jwt { token } => token,
395 Credential::Bearer { token } => token,
396 _ => return Err(AuthError::auth_method(
397 self.name(),
398 "Invalid credential type for JWT authentication".to_string(),
399 )),
400 };
401
402 let claims = self.token_manager.validate_jwt_token(token_str)?;
404
405 let remaining_seconds = (claims.exp - chrono::Utc::now().timestamp()).max(0) as u64;
407 let token = AuthToken::new(
408 claims.sub,
409 token_str.clone(),
410 std::time::Duration::from_secs(remaining_seconds),
411 self.name(),
412 ).with_scopes(claims.scope.split_whitespace().map(|s| s.to_string()).collect());
413
414 Ok(MethodResult::Success(Box::new(token)))
415 }
416
417 fn validate_config(&self) -> Result<()> {
418 Ok(())
420 }
421}
422
423impl Default for ApiKeyMethod {
424 fn default() -> Self {
425 Self::new()
426 }
427}
428
429impl ApiKeyMethod {
430 pub fn new() -> Self {
432 let token_manager = TokenManager::new_hmac(
433 b"default-secret",
434 "api-key-issuer",
435 "api-key-audience",
436 );
437
438 Self {
439 name: "api-key".to_string(),
440 key_prefix: None,
441 header_name: "X-API-Key".to_string(),
442 key_validator: Box::new(DefaultApiKeyValidator),
443 token_manager,
444 }
445 }
446
447 pub fn key_prefix(mut self, prefix: impl Into<String>) -> Self {
449 self.key_prefix = Some(prefix.into());
450 self
451 }
452
453 pub fn key_length(self, _length: usize) -> Self {
455 self
457 }
458
459 pub fn header_name(mut self, name: impl Into<String>) -> Self {
461 self.header_name = name.into();
462 self
463 }
464
465 pub fn key_validator(mut self, validator: Box<dyn ApiKeyValidator>) -> Self {
467 self.key_validator = validator;
468 self
469 }
470}
471
472#[async_trait]
473impl AuthMethod for ApiKeyMethod {
474 fn name(&self) -> &str {
475 &self.name
476 }
477
478 async fn authenticate(
479 &self,
480 credential: &Credential,
481 _metadata: &CredentialMetadata,
482 ) -> Result<MethodResult> {
483 let api_key = match credential {
484 Credential::ApiKey { key } => key,
485 _ => return Err(AuthError::auth_method(
486 self.name(),
487 "Invalid credential type for API key authentication".to_string(),
488 )),
489 };
490
491 if let Some(prefix) = &self.key_prefix {
493 if !api_key.starts_with(prefix) {
494 return Ok(MethodResult::Failure {
495 reason: "Invalid API key format".to_string(),
496 });
497 }
498 }
499
500 let user = self.key_validator.validate_key(api_key).await?
502 .ok_or_else(|| AuthError::auth_method(
503 self.name(),
504 "Invalid API key".to_string(),
505 ))?;
506
507 let token = self.token_manager.create_auth_token(
509 &user.id,
510 vec!["api".to_string()], self.name(),
512 Some(std::time::Duration::from_secs(3600)), )?;
514
515 Ok(MethodResult::Success(Box::new(token)))
516 }
517
518 fn validate_config(&self) -> Result<()> {
519 Ok(())
520 }
521}
522
523impl Default for OAuth2Method {
524 fn default() -> Self {
525 Self::new()
526 }
527}
528
529impl OAuth2Method {
530 pub fn new() -> Self {
532 let token_manager = TokenManager::new_hmac(
533 b"oauth-secret",
534 "oauth-issuer",
535 "oauth-audience",
536 );
537
538 Self {
539 name: "oauth2".to_string(),
540 provider: OAuthProvider::GitHub, client_id: String::new(),
542 client_secret: String::new(),
543 redirect_uri: String::new(),
544 scopes: Vec::new(),
545 use_pkce: true,
546 token_manager,
547 }
548 }
549
550 pub fn provider(mut self, provider: OAuthProvider) -> Self {
552 self.provider = provider;
553 self
554 }
555
556 pub fn client_id(mut self, client_id: impl Into<String>) -> Self {
558 self.client_id = client_id.into();
559 self
560 }
561
562 pub fn client_secret(mut self, client_secret: impl Into<String>) -> Self {
564 self.client_secret = client_secret.into();
565 self
566 }
567
568 pub fn redirect_uri(mut self, redirect_uri: impl Into<String>) -> Self {
570 self.redirect_uri = redirect_uri.into();
571 self
572 }
573
574 pub fn scopes(mut self, scopes: Vec<String>) -> Self {
576 self.scopes = scopes;
577 self
578 }
579
580 pub fn use_pkce(mut self, use_pkce: bool) -> Self {
582 self.use_pkce = use_pkce;
583 self
584 }
585
586 pub fn authorization_url(&self) -> Result<AuthorizationUrlResult> {
588 let state = generate_state();
589 let pkce = if self.use_pkce {
590 Some(generate_pkce())
591 } else {
592 None
593 };
594
595 let url = self.provider.build_authorization_url(
596 &self.client_id,
597 &self.redirect_uri,
598 &state,
599 if self.scopes.is_empty() { None } else { Some(&self.scopes) },
600 pkce.as_ref().map(|(_, challenge)| challenge.as_str()),
601 )?;
602
603 Ok((url, state, pkce))
604 }
605}
606
607#[async_trait]
608impl AuthMethod for OAuth2Method {
609 fn name(&self) -> &str {
610 &self.name
611 }
612
613 async fn authenticate(
614 &self,
615 credential: &Credential,
616 _metadata: &CredentialMetadata,
617 ) -> Result<MethodResult> {
618 let (authorization_code, code_verifier) = match credential {
619 Credential::OAuth { authorization_code, code_verifier, .. } => {
620 (authorization_code, code_verifier.as_deref())
621 }
622 _ => return Err(AuthError::auth_method(
623 self.name(),
624 "Invalid credential type for OAuth authentication".to_string(),
625 )),
626 };
627
628 let token_response = self.provider.exchange_code(
630 &self.client_id,
631 &self.client_secret,
632 authorization_code,
633 &self.redirect_uri,
634 code_verifier,
635 ).await?;
636
637 let user_info = self.provider.get_user_info(&token_response.access_token).await?;
639
640 let expires_in = token_response.expires_in
643 .map(std::time::Duration::from_secs)
644 .unwrap_or_else(|| std::time::Duration::from_secs(3600));
645
646 let mut token = self.token_manager.create_auth_token(
647 &user_info.id,
648 token_response.scope
649 .unwrap_or_default()
650 .split_whitespace()
651 .map(|s| s.to_string())
652 .collect(),
653 self.name(),
654 Some(expires_in),
655 )?;
656
657 if let Some(refresh_token) = token_response.refresh_token {
659 token = token.with_refresh_token(refresh_token);
660 }
661
662 Ok(MethodResult::Success(Box::new(token)))
663 }
664
665 fn validate_config(&self) -> Result<()> {
666 if self.client_id.is_empty() {
667 return Err(AuthError::config("OAuth client ID is required"));
668 }
669 if self.client_secret.is_empty() {
670 return Err(AuthError::config("OAuth client secret is required"));
671 }
672 if self.redirect_uri.is_empty() {
673 return Err(AuthError::config("OAuth redirect URI is required"));
674 }
675 Ok(())
676 }
677
678 fn supports_refresh(&self) -> bool {
679 self.provider.config().supports_refresh
680 }
681
682 async fn refresh_token(&self, refresh_token: &str) -> Result<AuthToken> {
683 let token_response = self.provider.refresh_token(
684 &self.client_id,
685 &self.client_secret,
686 refresh_token,
687 ).await?;
688
689 let expires_in = token_response.expires_in
690 .map(Duration::from_secs)
691 .unwrap_or_else(|| std::time::Duration::from_secs(3600));
692
693 let token = self.token_manager.create_auth_token(
696 "unknown", token_response.scope
698 .unwrap_or_default()
699 .split_whitespace()
700 .map(|s| s.to_string())
701 .collect(),
702 self.name(),
703 Some(expires_in),
704 )?;
705
706 Ok(token)
707 }
708}
709
/// PKCE pair returned by `generate_pkce`. The second element is the code
/// challenge placed in the authorization URL. The first is presumably the
/// code verifier to present during the code exchange.
pub type PkceParams = (String, String);

/// Return value of [`OAuth2Method::authorization_url`]: `(url, state, pkce)`.
/// Public so callers of that public method can name the type.
pub type AuthorizationUrlResult = (String, String, Option<PkceParams>);
715
716#[derive(Debug, Clone)]
718struct DefaultApiKeyValidator;
719
720#[async_trait]
721impl ApiKeyValidator for DefaultApiKeyValidator {
722 async fn validate_key(&self, _api_key: &str) -> Result<Option<UserInfo>> {
723 Ok(None)
725 }
726
727 async fn create_key(&self, _user_id: &str, _expires_in: Option<Duration>) -> Result<String> {
728 Ok(format!("api-{}", uuid::Uuid::new_v4()))
730 }
731
732 async fn revoke_key(&self, _api_key: &str) -> Result<()> {
733 Ok(())
735 }
736}
737
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mfa_challenge() {
        let challenge = MfaChallenge::new(
            MfaType::Totp,
            "user123",
            Duration::from_secs(300),
        );

        assert_eq!(challenge.user_id, "user123");
        assert!(!challenge.is_expired());
        // Hyphenated UUID v4 string.
        assert_eq!(challenge.id().len(), 36);
    }

    #[test]
    fn test_mfa_challenge_message() {
        let challenge = MfaChallenge::new(MfaType::BackupCode, "user123", Duration::from_secs(60))
            .with_message("Enter a backup code");

        assert_eq!(challenge.message.as_deref(), Some("Enter a backup code"));
        assert!(challenge.data.is_empty());
    }

    #[test]
    fn test_jwt_method_creation() {
        let jwt_method = JwtMethod::new()
            .secret_key("test-secret")
            .issuer("test-issuer")
            .audience("test-audience");

        assert_eq!(jwt_method.name(), "jwt");
        assert_eq!(jwt_method.issuer, "test-issuer");
        assert_eq!(jwt_method.audience, "test-audience");
    }

    #[test]
    fn test_api_key_method_builder() {
        let method = ApiKeyMethod::new()
            .key_prefix("sk_")
            .header_name("Authorization");

        assert_eq!(method.name(), "api-key");
        assert_eq!(method.key_prefix.as_deref(), Some("sk_"));
        assert_eq!(method.header_name, "Authorization");
        assert!(method.validate_config().is_ok());
    }

    #[test]
    fn test_oauth2_method_creation() {
        let oauth_method = OAuth2Method::new()
            .provider(OAuthProvider::GitHub)
            .client_id("test-client")
            .client_secret("test-secret")
            .redirect_uri("https://example.com/callback");

        assert_eq!(oauth_method.name(), "oauth2");
        assert_eq!(oauth_method.client_id, "test-client");
        assert!(oauth_method.use_pkce);
        assert!(oauth_method.validate_config().is_ok());
    }

    #[test]
    fn test_oauth2_validate_config_requires_credentials() {
        assert!(OAuth2Method::new().validate_config().is_err());

        let missing_redirect = OAuth2Method::new()
            .client_id("test-client")
            .client_secret("test-secret");
        assert!(missing_redirect.validate_config().is_err());
    }
}