1use crate::credentials::{Credential, CredentialMetadata};
4use crate::errors::{AuthError, Result};
5use crate::providers::{OAuthProvider, generate_state, generate_pkce};
6use crate::tokens::{AuthToken, TokenManager};
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::time::Duration;
11
/// Outcome of a single [`AuthMethod::authenticate`] attempt.
#[derive(Debug, Clone)]
pub enum MethodResult {
    /// The credential was accepted and a token was issued.
    Success(Box<AuthToken>),

    /// The primary credential was accepted, but a second factor must be
    /// completed before a token is issued.
    MfaRequired(Box<MfaChallenge>),

    /// The credential was rejected; `reason` is a human-readable explanation.
    Failure { reason: String },
}
24
/// A pending multi-factor authentication challenge issued to a user.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MfaChallenge {
    /// Challenge identifier (a random v4 UUID when built via `MfaChallenge::new`).
    pub id: String,

    /// Which second factor the user must supply.
    pub mfa_type: MfaType,

    /// Identifier of the user the challenge was issued for.
    pub user_id: String,

    /// Instant (UTC) after which `is_expired` reports the challenge as expired.
    pub expires_at: chrono::DateTime<chrono::Utc>,

    /// Optional prompt to show the user.
    pub message: Option<String>,

    /// Additional challenge data as arbitrary JSON values (empty when built
    /// via `MfaChallenge::new`).
    pub data: HashMap<String, serde_json::Value>,
}
46
/// Kind of second factor requested by an [`MfaChallenge`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MfaType {
    /// Time-based one-time password.
    Totp,

    /// Code delivered by SMS; `phone_number` is presumably the destination.
    Sms { phone_number: String },

    /// Code delivered by email; `email_address` is presumably the destination.
    Email { email_address: String },

    /// Push approval on the device identified by `device_id`.
    Push { device_id: String },

    /// Hardware security key.
    SecurityKey,

    /// One of the user's backup codes.
    BackupCode,
}
68
69#[async_trait]
71pub trait AuthMethod: Send + Sync {
72 fn name(&self) -> &str;
74
75 async fn authenticate(
77 &self,
78 credential: &Credential,
79 metadata: &CredentialMetadata,
80 ) -> Result<MethodResult>;
81
82 fn validate_config(&self) -> Result<()>;
84
85 fn supports_refresh(&self) -> bool {
87 false
88 }
89
90 async fn refresh_token(&self, _refresh_token: &str) -> Result<AuthToken> {
92 Err(AuthError::auth_method(
93 self.name(),
94 "Token refresh not supported by this method".to_string(),
95 ))
96 }
97}
98
/// Username/password authentication with an optional MFA step.
pub struct PasswordMethod {
    /// Method name reported by `AuthMethod::name` (`"password"` as built by `new`).
    name: String,
    /// Checks username/password pairs.
    password_verifier: Box<dyn PasswordVerifier>,
    /// Issues tokens after successful authentication.
    token_manager: TokenManager,
    /// When `true`, users for whom `UserLookup::requires_mfa` returns `true`
    /// receive an MFA challenge instead of a token.
    mfa_enabled: bool,
    /// Resolves user records and MFA requirements.
    user_lookup: Box<dyn UserLookup>,
}
107
108pub struct JwtMethod {
110 name: String,
111 token_manager: TokenManager,
112 issuer: String,
113 audience: String,
114}
115
116pub struct ApiKeyMethod {
118 name: String,
119 key_prefix: Option<String>,
120 header_name: String,
121 key_validator: Box<dyn ApiKeyValidator>,
122 token_manager: TokenManager,
123}
124
125pub struct OAuth2Method {
127 name: String,
128 provider: OAuthProvider,
129 client_id: String,
130 client_secret: String,
131 redirect_uri: String,
132 scopes: Vec<String>,
133 use_pkce: bool,
134 token_manager: TokenManager,
135}
136
137#[async_trait]
139pub trait PasswordVerifier: Send + Sync {
140 async fn verify_password(&self, username: &str, password: &str) -> Result<bool>;
142
143 async fn hash_password(&self, password: &str) -> Result<String>;
145}
146
147#[async_trait]
149pub trait UserLookup: Send + Sync {
150 async fn lookup_user(&self, username: &str) -> Result<Option<UserInfo>>;
152
153 async fn requires_mfa(&self, user_id: &str) -> Result<bool>;
155}
156
157#[async_trait]
159pub trait ApiKeyValidator: Send + Sync {
160 async fn validate_key(&self, api_key: &str) -> Result<Option<UserInfo>>;
162
163 async fn create_key(&self, user_id: &str, expires_in: Option<Duration>) -> Result<String>;
165
166 async fn revoke_key(&self, api_key: &str) -> Result<()>;
168}
169
170#[derive(Debug, Clone, Serialize, Deserialize)]
172pub struct UserInfo {
173 pub id: String,
175
176 pub username: String,
178
179 pub email: Option<String>,
181
182 pub name: Option<String>,
184
185 pub roles: Vec<String>,
187
188 pub active: bool,
190
191 pub attributes: HashMap<String, serde_json::Value>,
193}
194
195impl MfaChallenge {
196 pub fn new(
198 mfa_type: MfaType,
199 user_id: impl Into<String>,
200 expires_in: Duration,
201 ) -> Self {
202 Self {
203 id: uuid::Uuid::new_v4().to_string(),
204 mfa_type,
205 user_id: user_id.into(),
206 expires_at: chrono::Utc::now() + chrono::Duration::from_std(expires_in).unwrap(),
207 message: None,
208 data: HashMap::new(),
209 }
210 }
211
212 pub fn id(&self) -> &str {
214 &self.id
215 }
216
217 pub fn is_expired(&self) -> bool {
219 chrono::Utc::now() > self.expires_at
220 }
221
222 pub fn with_message(mut self, message: impl Into<String>) -> Self {
224 self.message = Some(message.into());
225 self
226 }
227}
228
229impl PasswordMethod {
230 pub fn new(
232 password_verifier: Box<dyn PasswordVerifier>,
233 user_lookup: Box<dyn UserLookup>,
234 token_manager: TokenManager,
235 ) -> Self {
236 Self {
237 name: "password".to_string(),
238 password_verifier,
239 token_manager,
240 mfa_enabled: false,
241 user_lookup,
242 }
243 }
244
245 pub fn with_mfa(mut self, enabled: bool) -> Self {
247 self.mfa_enabled = enabled;
248 self
249 }
250}
251
252#[async_trait]
253impl AuthMethod for PasswordMethod {
254 fn name(&self) -> &str {
255 &self.name
256 }
257
258 async fn authenticate(
259 &self,
260 credential: &Credential,
261 _metadata: &CredentialMetadata,
262 ) -> Result<MethodResult> {
263 let (username, password) = match credential {
264 Credential::Password { username, password } => (username, password),
265 _ => return Err(AuthError::auth_method(
266 self.name(),
267 "Invalid credential type for password authentication".to_string(),
268 )),
269 };
270
271 if !self.password_verifier.verify_password(username, password).await? {
273 return Ok(MethodResult::Failure {
274 reason: "Invalid username or password".to_string(),
275 });
276 }
277
278 let user = self.user_lookup.lookup_user(username).await?
280 .ok_or_else(|| AuthError::auth_method(
281 self.name(),
282 "User not found".to_string(),
283 ))?;
284
285 if !user.active {
286 return Ok(MethodResult::Failure {
287 reason: "User account is disabled".to_string(),
288 });
289 }
290
291 if self.mfa_enabled && self.user_lookup.requires_mfa(&user.id).await? {
293 let challenge = MfaChallenge::new(
294 MfaType::Totp, &user.id,
296 Duration::from_secs(300), ).with_message("Please enter your MFA code");
298
299 return Ok(MethodResult::MfaRequired(Box::new(challenge)));
300 }
301
302 let token = self.token_manager.create_auth_token(
304 &user.id,
305 vec![], self.name(),
307 None,
308 )?;
309
310 Ok(MethodResult::Success(Box::new(token)))
311 }
312
313 fn validate_config(&self) -> Result<()> {
314 Ok(())
316 }
317}
318
319impl Default for JwtMethod {
320 fn default() -> Self {
321 Self::new()
322 }
323}
324
325impl JwtMethod {
326 pub fn new() -> Self {
328 let token_manager = TokenManager::new_hmac(
329 b"default-secret", "default-issuer",
331 "default-audience",
332 );
333
334 Self {
335 name: "jwt".to_string(),
336 token_manager,
337 issuer: "default-issuer".to_string(),
338 audience: "default-audience".to_string(),
339 }
340 }
341
342 pub fn secret_key(mut self, secret: impl Into<String>) -> Self {
344 let secret = secret.into();
345 self.token_manager = TokenManager::new_hmac(
346 secret.as_bytes(),
347 &self.issuer,
348 &self.audience,
349 );
350 self
351 }
352
353 pub fn issuer(mut self, issuer: impl Into<String>) -> Self {
355 self.issuer = issuer.into();
356 self.token_manager = TokenManager::new_hmac(
357 b"default-secret", &self.issuer,
359 &self.audience,
360 );
361 self
362 }
363
364 pub fn audience(mut self, audience: impl Into<String>) -> Self {
366 self.audience = audience.into();
367 self.token_manager = TokenManager::new_hmac(
368 b"default-secret", &self.issuer,
370 &self.audience,
371 );
372 self
373 }
374}
375
376#[async_trait]
377impl AuthMethod for JwtMethod {
378 fn name(&self) -> &str {
379 &self.name
380 }
381
382 async fn authenticate(
383 &self,
384 credential: &Credential,
385 _metadata: &CredentialMetadata,
386 ) -> Result<MethodResult> {
387 let token_str = match credential {
388 Credential::Jwt { token } => token,
389 Credential::Bearer { token } => token,
390 _ => return Err(AuthError::auth_method(
391 self.name(),
392 "Invalid credential type for JWT authentication".to_string(),
393 )),
394 };
395
396 let claims = self.token_manager.validate_jwt_token(token_str)?;
398
399 let remaining_seconds = (claims.exp - chrono::Utc::now().timestamp()).max(0) as u64;
401 let token = AuthToken::new(
402 claims.sub,
403 token_str.clone(),
404 std::time::Duration::from_secs(remaining_seconds),
405 self.name(),
406 ).with_scopes(claims.scope.split_whitespace().map(|s| s.to_string()).collect());
407
408 Ok(MethodResult::Success(Box::new(token)))
409 }
410
411 fn validate_config(&self) -> Result<()> {
412 Ok(())
414 }
415}
416
417impl Default for ApiKeyMethod {
418 fn default() -> Self {
419 Self::new()
420 }
421}
422
423impl ApiKeyMethod {
424 pub fn new() -> Self {
426 let token_manager = TokenManager::new_hmac(
427 b"default-secret",
428 "api-key-issuer",
429 "api-key-audience",
430 );
431
432 Self {
433 name: "api-key".to_string(),
434 key_prefix: None,
435 header_name: "X-API-Key".to_string(),
436 key_validator: Box::new(DefaultApiKeyValidator),
437 token_manager,
438 }
439 }
440
441 pub fn key_prefix(mut self, prefix: impl Into<String>) -> Self {
443 self.key_prefix = Some(prefix.into());
444 self
445 }
446
447 pub fn header_name(mut self, name: impl Into<String>) -> Self {
449 self.header_name = name.into();
450 self
451 }
452
453 pub fn key_validator(mut self, validator: Box<dyn ApiKeyValidator>) -> Self {
455 self.key_validator = validator;
456 self
457 }
458}
459
460#[async_trait]
461impl AuthMethod for ApiKeyMethod {
462 fn name(&self) -> &str {
463 &self.name
464 }
465
466 async fn authenticate(
467 &self,
468 credential: &Credential,
469 _metadata: &CredentialMetadata,
470 ) -> Result<MethodResult> {
471 let api_key = match credential {
472 Credential::ApiKey { key } => key,
473 _ => return Err(AuthError::auth_method(
474 self.name(),
475 "Invalid credential type for API key authentication".to_string(),
476 )),
477 };
478
479 if let Some(prefix) = &self.key_prefix {
481 if !api_key.starts_with(prefix) {
482 return Ok(MethodResult::Failure {
483 reason: "Invalid API key format".to_string(),
484 });
485 }
486 }
487
488 let user = self.key_validator.validate_key(api_key).await?
490 .ok_or_else(|| AuthError::auth_method(
491 self.name(),
492 "Invalid API key".to_string(),
493 ))?;
494
495 let token = self.token_manager.create_auth_token(
497 &user.id,
498 vec!["api".to_string()], self.name(),
500 Some(std::time::Duration::from_secs(3600)), )?;
502
503 Ok(MethodResult::Success(Box::new(token)))
504 }
505
506 fn validate_config(&self) -> Result<()> {
507 Ok(())
508 }
509}
510
511impl Default for OAuth2Method {
512 fn default() -> Self {
513 Self::new()
514 }
515}
516
517impl OAuth2Method {
518 pub fn new() -> Self {
520 let token_manager = TokenManager::new_hmac(
521 b"oauth-secret",
522 "oauth-issuer",
523 "oauth-audience",
524 );
525
526 Self {
527 name: "oauth2".to_string(),
528 provider: OAuthProvider::GitHub, client_id: String::new(),
530 client_secret: String::new(),
531 redirect_uri: String::new(),
532 scopes: Vec::new(),
533 use_pkce: true,
534 token_manager,
535 }
536 }
537
538 pub fn provider(mut self, provider: OAuthProvider) -> Self {
540 self.provider = provider;
541 self
542 }
543
544 pub fn client_id(mut self, client_id: impl Into<String>) -> Self {
546 self.client_id = client_id.into();
547 self
548 }
549
550 pub fn client_secret(mut self, client_secret: impl Into<String>) -> Self {
552 self.client_secret = client_secret.into();
553 self
554 }
555
556 pub fn redirect_uri(mut self, redirect_uri: impl Into<String>) -> Self {
558 self.redirect_uri = redirect_uri.into();
559 self
560 }
561
562 pub fn scopes(mut self, scopes: Vec<String>) -> Self {
564 self.scopes = scopes;
565 self
566 }
567
568 pub fn use_pkce(mut self, use_pkce: bool) -> Self {
570 self.use_pkce = use_pkce;
571 self
572 }
573
574 pub fn authorization_url(&self) -> Result<AuthorizationUrlResult> {
576 let state = generate_state();
577 let pkce = if self.use_pkce {
578 Some(generate_pkce())
579 } else {
580 None
581 };
582
583 let url = self.provider.build_authorization_url(
584 &self.client_id,
585 &self.redirect_uri,
586 &state,
587 if self.scopes.is_empty() { None } else { Some(&self.scopes) },
588 pkce.as_ref().map(|(_, challenge)| challenge.as_str()),
589 )?;
590
591 Ok((url, state, pkce))
592 }
593}
594
595#[async_trait]
596impl AuthMethod for OAuth2Method {
597 fn name(&self) -> &str {
598 &self.name
599 }
600
601 async fn authenticate(
602 &self,
603 credential: &Credential,
604 _metadata: &CredentialMetadata,
605 ) -> Result<MethodResult> {
606 let (authorization_code, code_verifier) = match credential {
607 Credential::OAuth { authorization_code, code_verifier, .. } => {
608 (authorization_code, code_verifier.as_deref())
609 }
610 _ => return Err(AuthError::auth_method(
611 self.name(),
612 "Invalid credential type for OAuth authentication".to_string(),
613 )),
614 };
615
616 let token_response = self.provider.exchange_code(
618 &self.client_id,
619 &self.client_secret,
620 authorization_code,
621 &self.redirect_uri,
622 code_verifier,
623 ).await?;
624
625 let user_info = self.provider.get_user_info(&token_response.access_token).await?;
627
628 let expires_in = token_response.expires_in
631 .map(std::time::Duration::from_secs)
632 .unwrap_or_else(|| std::time::Duration::from_secs(3600));
633
634 let mut token = self.token_manager.create_auth_token(
635 &user_info.id,
636 token_response.scope
637 .unwrap_or_default()
638 .split_whitespace()
639 .map(|s| s.to_string())
640 .collect(),
641 self.name(),
642 Some(expires_in),
643 )?;
644
645 if let Some(refresh_token) = token_response.refresh_token {
647 token = token.with_refresh_token(refresh_token);
648 }
649
650 Ok(MethodResult::Success(Box::new(token)))
651 }
652
653 fn validate_config(&self) -> Result<()> {
654 if self.client_id.is_empty() {
655 return Err(AuthError::config("OAuth client ID is required"));
656 }
657 if self.client_secret.is_empty() {
658 return Err(AuthError::config("OAuth client secret is required"));
659 }
660 if self.redirect_uri.is_empty() {
661 return Err(AuthError::config("OAuth redirect URI is required"));
662 }
663 Ok(())
664 }
665
666 fn supports_refresh(&self) -> bool {
667 self.provider.config().supports_refresh
668 }
669
670 async fn refresh_token(&self, refresh_token: &str) -> Result<AuthToken> {
671 let token_response = self.provider.refresh_token(
672 &self.client_id,
673 &self.client_secret,
674 refresh_token,
675 ).await?;
676
677 let expires_in = token_response.expires_in
678 .map(Duration::from_secs)
679 .unwrap_or_else(|| std::time::Duration::from_secs(3600));
680
681 let token = self.token_manager.create_auth_token(
684 "unknown", token_response.scope
686 .unwrap_or_default()
687 .split_whitespace()
688 .map(|s| s.to_string())
689 .collect(),
690 self.name(),
691 Some(expires_in),
692 )?;
693
694 Ok(token)
695 }
696}
697
/// PKCE pair produced by `generate_pkce`. The second element is the code
/// challenge placed in the authorization URL; the first is presumably the
/// code verifier.
type PkceParams = (String, String);

/// `(authorization_url, state, pkce)` as returned by `OAuth2Method::authorization_url`.
type AuthorizationUrlResult = (String, String, Option<PkceParams>);
703
704#[derive(Debug, Clone)]
706struct DefaultApiKeyValidator;
707
708#[async_trait]
709impl ApiKeyValidator for DefaultApiKeyValidator {
710 async fn validate_key(&self, _api_key: &str) -> Result<Option<UserInfo>> {
711 Ok(None)
713 }
714
715 async fn create_key(&self, _user_id: &str, _expires_in: Option<Duration>) -> Result<String> {
716 Ok(format!("api-{}", uuid::Uuid::new_v4()))
718 }
719
720 async fn revoke_key(&self, _api_key: &str) -> Result<()> {
721 Ok(())
723 }
724}
725
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mfa_challenge() {
        let challenge = MfaChallenge::new(MfaType::Totp, "user123", Duration::from_secs(300));

        assert_eq!(challenge.user_id, "user123");
        assert!(!challenge.is_expired());
        // The hyphenated UUID string form is 36 characters long.
        assert_eq!(challenge.id().len(), 36);
    }

    #[test]
    fn test_mfa_challenge_with_message() {
        let challenge = MfaChallenge::new(MfaType::Totp, "user123", Duration::from_secs(60))
            .with_message("Enter code");

        assert_eq!(challenge.message.as_deref(), Some("Enter code"));
        assert!(challenge.data.is_empty());
    }

    #[test]
    fn test_jwt_method_creation() {
        let jwt_method = JwtMethod::new()
            .secret_key("test-secret")
            .issuer("test-issuer")
            .audience("test-audience");

        assert_eq!(jwt_method.name(), "jwt");
        assert_eq!(jwt_method.issuer, "test-issuer");
        assert_eq!(jwt_method.audience, "test-audience");
    }

    #[test]
    fn test_oauth2_method_creation() {
        let oauth_method = OAuth2Method::new()
            .provider(OAuthProvider::GitHub)
            .client_id("test-client")
            .client_secret("test-secret")
            .redirect_uri("https://example.com/callback");

        assert_eq!(oauth_method.name(), "oauth2");
        assert_eq!(oauth_method.client_id, "test-client");
        assert!(oauth_method.use_pkce);
    }

    #[test]
    fn test_oauth2_validate_config_requires_credentials() {
        assert!(OAuth2Method::new().validate_config().is_err());
        assert!(OAuth2Method::new()
            .client_id("test-client")
            .validate_config()
            .is_err());

        let configured = OAuth2Method::new()
            .client_id("test-client")
            .client_secret("test-secret")
            .redirect_uri("https://example.com/callback");
        assert!(configured.validate_config().is_ok());
    }

    #[test]
    fn test_api_key_method_builder() {
        let method = ApiKeyMethod::new()
            .key_prefix("sk_")
            .header_name("Authorization");

        assert_eq!(method.name(), "api-key");
        assert_eq!(method.key_prefix.as_deref(), Some("sk_"));
        assert_eq!(method.header_name, "Authorization");
        assert!(method.validate_config().is_ok());
    }
}
767}