auth_framework/
methods.rs

1//! Authentication method implementations.
2
3use crate::credentials::{Credential, CredentialMetadata};
4use crate::errors::{AuthError, Result};
5use crate::providers::{OAuthProvider, generate_state, generate_pkce};
6use crate::tokens::{AuthToken, TokenManager};
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::time::Duration;
11
12/// Result of an authentication attempt.
13#[derive(Debug, Clone)]
14pub enum MethodResult {
15    /// Authentication was successful
16    Success(Box<AuthToken>),
17    
18    /// Multi-factor authentication is required
19    MfaRequired(Box<MfaChallenge>),
20    
21    /// Authentication failed
22    Failure { reason: String },
23}
24
25/// Multi-factor authentication challenge.
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct MfaChallenge {
28    /// Unique challenge ID
29    pub id: String,
30    
31    /// Type of MFA required
32    pub mfa_type: MfaType,
33    
34    /// User ID this challenge is for
35    pub user_id: String,
36    
37    /// When the challenge expires
38    pub expires_at: chrono::DateTime<chrono::Utc>,
39    
40    /// Optional message or instructions
41    pub message: Option<String>,
42    
43    /// Additional challenge data
44    pub data: HashMap<String, serde_json::Value>,
45}
46
47/// Types of multi-factor authentication.
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub enum MfaType {
50    /// Time-based one-time password (TOTP)
51    Totp,
52    
53    /// SMS verification code
54    Sms { phone_number: String },
55    
56    /// Email verification code
57    Email { email_address: String },
58    
59    /// Push notification
60    Push { device_id: String },
61    
62    /// Hardware security key
63    SecurityKey,
64    
65    /// Backup codes
66    BackupCode,
67}
68
69/// Trait for authentication methods.
70#[async_trait]
71pub trait AuthMethod: Send + Sync {
72    /// Get the name of this authentication method.
73    fn name(&self) -> &str;
74    
75    /// Authenticate using the provided credentials.
76    async fn authenticate(
77        &self,
78        credential: &Credential,
79        metadata: &CredentialMetadata,
80    ) -> Result<MethodResult>;
81    
82    /// Validate configuration for this method.
83    fn validate_config(&self) -> Result<()>;
84    
85    /// Check if this method supports refresh tokens.
86    fn supports_refresh(&self) -> bool {
87        false
88    }
89    
90    /// Refresh a token if supported.
91    async fn refresh_token(&self, _refresh_token: &str) -> Result<AuthToken> {
92        Err(AuthError::auth_method(
93            self.name(),
94            "Token refresh not supported by this method".to_string(),
95        ))
96    }
97}
98
99/// Password-based authentication method.
100pub struct PasswordMethod {
101    name: String,
102    password_verifier: Box<dyn PasswordVerifier>,
103    token_manager: TokenManager,
104    mfa_enabled: bool,
105    user_lookup: Box<dyn UserLookup>,
106}
107
108/// JWT-based authentication method.
109pub struct JwtMethod {
110    name: String,
111    token_manager: TokenManager,
112    issuer: String,
113    audience: String,
114}
115
116/// API key authentication method.
117pub struct ApiKeyMethod {
118    name: String,
119    key_prefix: Option<String>,
120    header_name: String,
121    key_validator: Box<dyn ApiKeyValidator>,
122    token_manager: TokenManager,
123}
124
125/// OAuth 2.0 authentication method.
126pub struct OAuth2Method {
127    name: String,
128    provider: OAuthProvider,
129    client_id: String,
130    client_secret: String,
131    redirect_uri: String,
132    scopes: Vec<String>,
133    use_pkce: bool,
134    token_manager: TokenManager,
135}
136
137/// Trait for password verification.
138#[async_trait]
139pub trait PasswordVerifier: Send + Sync {
140    /// Verify a password against a hash.
141    async fn verify_password(&self, username: &str, password: &str) -> Result<bool>;
142    
143    /// Hash a password.
144    async fn hash_password(&self, password: &str) -> Result<String>;
145}
146
147/// Trait for user lookup operations.
148#[async_trait]
149pub trait UserLookup: Send + Sync {
150    /// Look up a user by username.
151    async fn lookup_user(&self, username: &str) -> Result<Option<UserInfo>>;
152    
153    /// Check if a user requires MFA.
154    async fn requires_mfa(&self, user_id: &str) -> Result<bool>;
155}
156
157/// Trait for API key validation.
158#[async_trait]
159pub trait ApiKeyValidator: Send + Sync {
160    /// Validate an API key and return associated user info.
161    async fn validate_key(&self, api_key: &str) -> Result<Option<UserInfo>>;
162    
163    /// Create a new API key for a user.
164    async fn create_key(&self, user_id: &str, expires_in: Option<Duration>) -> Result<String>;
165    
166    /// Revoke an API key.
167    async fn revoke_key(&self, api_key: &str) -> Result<()>;
168}
169
170/// Basic user information.
171#[derive(Debug, Clone, Serialize, Deserialize)]
172pub struct UserInfo {
173    /// User ID
174    pub id: String,
175    
176    /// Username
177    pub username: String,
178    
179    /// Email address
180    pub email: Option<String>,
181    
182    /// Display name
183    pub name: Option<String>,
184    
185    /// User roles
186    pub roles: Vec<String>,
187    
188    /// Whether the user is active
189    pub active: bool,
190    
191    /// Additional user attributes
192    pub attributes: HashMap<String, serde_json::Value>,
193}
194
195impl MfaChallenge {
196    /// Create a new MFA challenge.
197    pub fn new(
198        mfa_type: MfaType,
199        user_id: impl Into<String>,
200        expires_in: Duration,
201    ) -> Self {
202        Self {
203            id: uuid::Uuid::new_v4().to_string(),
204            mfa_type,
205            user_id: user_id.into(),
206            expires_at: chrono::Utc::now() + chrono::Duration::from_std(expires_in).unwrap(),
207            message: None,
208            data: HashMap::new(),
209        }
210    }
211
212    /// Get the challenge ID.
213    pub fn id(&self) -> &str {
214        &self.id
215    }
216
217    /// Check if the challenge has expired.
218    pub fn is_expired(&self) -> bool {
219        chrono::Utc::now() > self.expires_at
220    }
221
222    /// Set a message for the challenge.
223    pub fn with_message(mut self, message: impl Into<String>) -> Self {
224        self.message = Some(message.into());
225        self
226    }
227}
228
229impl PasswordMethod {
230    /// Create a new password authentication method.
231    pub fn new(
232        password_verifier: Box<dyn PasswordVerifier>,
233        user_lookup: Box<dyn UserLookup>,
234        token_manager: TokenManager,
235    ) -> Self {
236        Self {
237            name: "password".to_string(),
238            password_verifier,
239            token_manager,
240            mfa_enabled: false,
241            user_lookup,
242        }
243    }
244
245    /// Enable or disable MFA for this method.
246    pub fn with_mfa(mut self, enabled: bool) -> Self {
247        self.mfa_enabled = enabled;
248        self
249    }
250}
251
252#[async_trait]
253impl AuthMethod for PasswordMethod {
254    fn name(&self) -> &str {
255        &self.name
256    }
257
258    async fn authenticate(
259        &self,
260        credential: &Credential,
261        _metadata: &CredentialMetadata,
262    ) -> Result<MethodResult> {
263        let (username, password) = match credential {
264            Credential::Password { username, password } => (username, password),
265            _ => return Err(AuthError::auth_method(
266                self.name(),
267                "Invalid credential type for password authentication".to_string(),
268            )),
269        };
270
271        // Verify password
272        if !self.password_verifier.verify_password(username, password).await? {
273            return Ok(MethodResult::Failure {
274                reason: "Invalid username or password".to_string(),
275            });
276        }
277
278        // Look up user
279        let user = self.user_lookup.lookup_user(username).await?
280            .ok_or_else(|| AuthError::auth_method(
281                self.name(),
282                "User not found".to_string(),
283            ))?;
284
285        if !user.active {
286            return Ok(MethodResult::Failure {
287                reason: "User account is disabled".to_string(),
288            });
289        }
290
291        // Check if MFA is required
292        if self.mfa_enabled && self.user_lookup.requires_mfa(&user.id).await? {
293            let challenge = MfaChallenge::new(
294                MfaType::Totp, // Default to TOTP, could be configurable
295                &user.id,
296                Duration::from_secs(300), // 5 minutes
297            ).with_message("Please enter your MFA code");
298
299            return Ok(MethodResult::MfaRequired(Box::new(challenge)));
300        }
301
302        // Create token
303        let token = self.token_manager.create_auth_token(
304            &user.id,
305            vec![], // Scopes would be determined by user roles
306            self.name(),
307            None,
308        )?;
309
310        Ok(MethodResult::Success(Box::new(token)))
311    }
312
313    fn validate_config(&self) -> Result<()> {
314        // Validation would depend on the specific implementation
315        Ok(())
316    }
317}
318
319impl Default for JwtMethod {
320    fn default() -> Self {
321        Self::new()
322    }
323}
324
325impl JwtMethod {
326    /// Create a new JWT authentication method.
327    pub fn new() -> Self {
328        let token_manager = TokenManager::new_hmac(
329            b"default-secret", // This should be configurable
330            "default-issuer",
331            "default-audience",
332        );
333
334        Self {
335            name: "jwt".to_string(),
336            token_manager,
337            issuer: "default-issuer".to_string(),
338            audience: "default-audience".to_string(),
339        }
340    }
341
342    /// Set the secret key for JWT signing.
343    pub fn secret_key(mut self, secret: impl Into<String>) -> Self {
344        let secret = secret.into();
345        self.token_manager = TokenManager::new_hmac(
346            secret.as_bytes(),
347            &self.issuer,
348            &self.audience,
349        );
350        self
351    }
352
353    /// Set the issuer.
354    pub fn issuer(mut self, issuer: impl Into<String>) -> Self {
355        self.issuer = issuer.into();
356        self.token_manager = TokenManager::new_hmac(
357            b"default-secret", // This should use the actual secret
358            &self.issuer,
359            &self.audience,
360        );
361        self
362    }
363
364    /// Set the audience.
365    pub fn audience(mut self, audience: impl Into<String>) -> Self {
366        self.audience = audience.into();
367        self.token_manager = TokenManager::new_hmac(
368            b"default-secret", // This should use the actual secret
369            &self.issuer,
370            &self.audience,
371        );
372        self
373    }
374}
375
376#[async_trait]
377impl AuthMethod for JwtMethod {
378    fn name(&self) -> &str {
379        &self.name
380    }
381
382    async fn authenticate(
383        &self,
384        credential: &Credential,
385        _metadata: &CredentialMetadata,
386    ) -> Result<MethodResult> {
387        let token_str = match credential {
388            Credential::Jwt { token } => token,
389            Credential::Bearer { token } => token,
390            _ => return Err(AuthError::auth_method(
391                self.name(),
392                "Invalid credential type for JWT authentication".to_string(),
393            )),
394        };
395
396        // Validate JWT
397        let claims = self.token_manager.validate_jwt_token(token_str)?;
398        
399        // Create auth token from JWT claims
400        let remaining_seconds = (claims.exp - chrono::Utc::now().timestamp()).max(0) as u64;
401        let token = AuthToken::new(
402            claims.sub,
403            token_str.clone(),
404            std::time::Duration::from_secs(remaining_seconds),
405            self.name(),
406        ).with_scopes(claims.scope.split_whitespace().map(|s| s.to_string()).collect());
407
408        Ok(MethodResult::Success(Box::new(token)))
409    }
410
411    fn validate_config(&self) -> Result<()> {
412        // Validate JWT configuration
413        Ok(())
414    }
415}
416
417impl Default for ApiKeyMethod {
418    fn default() -> Self {
419        Self::new()
420    }
421}
422
423impl ApiKeyMethod {
424    /// Create a new API key authentication method.
425    pub fn new() -> Self {
426        let token_manager = TokenManager::new_hmac(
427            b"default-secret",
428            "api-key-issuer",
429            "api-key-audience",
430        );
431
432        Self {
433            name: "api-key".to_string(),
434            key_prefix: None,
435            header_name: "X-API-Key".to_string(),
436            key_validator: Box::new(DefaultApiKeyValidator),
437            token_manager,
438        }
439    }
440
441    /// Set the key prefix.
442    pub fn key_prefix(mut self, prefix: impl Into<String>) -> Self {
443        self.key_prefix = Some(prefix.into());
444        self
445    }
446
447    /// Set the header name.
448    pub fn header_name(mut self, name: impl Into<String>) -> Self {
449        self.header_name = name.into();
450        self
451    }
452
453    /// Set the key validator.
454    pub fn key_validator(mut self, validator: Box<dyn ApiKeyValidator>) -> Self {
455        self.key_validator = validator;
456        self
457    }
458}
459
460#[async_trait]
461impl AuthMethod for ApiKeyMethod {
462    fn name(&self) -> &str {
463        &self.name
464    }
465
466    async fn authenticate(
467        &self,
468        credential: &Credential,
469        _metadata: &CredentialMetadata,
470    ) -> Result<MethodResult> {
471        let api_key = match credential {
472            Credential::ApiKey { key } => key,
473            _ => return Err(AuthError::auth_method(
474                self.name(),
475                "Invalid credential type for API key authentication".to_string(),
476            )),
477        };
478
479        // Validate prefix if configured
480        if let Some(prefix) = &self.key_prefix {
481            if !api_key.starts_with(prefix) {
482                return Ok(MethodResult::Failure {
483                    reason: "Invalid API key format".to_string(),
484                });
485            }
486        }
487
488        // Validate key
489        let user = self.key_validator.validate_key(api_key).await?
490            .ok_or_else(|| AuthError::auth_method(
491                self.name(),
492                "Invalid API key".to_string(),
493            ))?;
494
495        // Create token
496        let token = self.token_manager.create_auth_token(
497            &user.id,
498            vec!["api".to_string()], // Default API scope
499            self.name(),
500            Some(std::time::Duration::from_secs(3600)), // 1 hour default
501        )?;
502
503        Ok(MethodResult::Success(Box::new(token)))
504    }
505
506    fn validate_config(&self) -> Result<()> {
507        Ok(())
508    }
509}
510
511impl Default for OAuth2Method {
512    fn default() -> Self {
513        Self::new()
514    }
515}
516
517impl OAuth2Method {
518    /// Create a new OAuth 2.0 authentication method.
519    pub fn new() -> Self {
520        let token_manager = TokenManager::new_hmac(
521            b"oauth-secret",
522            "oauth-issuer",
523            "oauth-audience",
524        );
525
526        Self {
527            name: "oauth2".to_string(),
528            provider: OAuthProvider::GitHub, // Default provider
529            client_id: String::new(),
530            client_secret: String::new(),
531            redirect_uri: String::new(),
532            scopes: Vec::new(),
533            use_pkce: true,
534            token_manager,
535        }
536    }
537
538    /// Set the OAuth provider.
539    pub fn provider(mut self, provider: OAuthProvider) -> Self {
540        self.provider = provider;
541        self
542    }
543
544    /// Set the client ID.
545    pub fn client_id(mut self, client_id: impl Into<String>) -> Self {
546        self.client_id = client_id.into();
547        self
548    }
549
550    /// Set the client secret.
551    pub fn client_secret(mut self, client_secret: impl Into<String>) -> Self {
552        self.client_secret = client_secret.into();
553        self
554    }
555
556    /// Set the redirect URI.
557    pub fn redirect_uri(mut self, redirect_uri: impl Into<String>) -> Self {
558        self.redirect_uri = redirect_uri.into();
559        self
560    }
561
562    /// Set the scopes.
563    pub fn scopes(mut self, scopes: Vec<String>) -> Self {
564        self.scopes = scopes;
565        self
566    }
567
568    /// Enable or disable PKCE.
569    pub fn use_pkce(mut self, use_pkce: bool) -> Self {
570        self.use_pkce = use_pkce;
571        self
572    }
573
574    /// Generate authorization URL.
575    pub fn authorization_url(&self) -> Result<AuthorizationUrlResult> {
576        let state = generate_state();
577        let pkce = if self.use_pkce {
578            Some(generate_pkce())
579        } else {
580            None
581        };
582
583        let url = self.provider.build_authorization_url(
584            &self.client_id,
585            &self.redirect_uri,
586            &state,
587            if self.scopes.is_empty() { None } else { Some(&self.scopes) },
588            pkce.as_ref().map(|(_, challenge)| challenge.as_str()),
589        )?;
590
591        Ok((url, state, pkce))
592    }
593}
594
595#[async_trait]
596impl AuthMethod for OAuth2Method {
597    fn name(&self) -> &str {
598        &self.name
599    }
600
601    async fn authenticate(
602        &self,
603        credential: &Credential,
604        _metadata: &CredentialMetadata,
605    ) -> Result<MethodResult> {
606        let (authorization_code, code_verifier) = match credential {
607            Credential::OAuth { authorization_code, code_verifier, .. } => {
608                (authorization_code, code_verifier.as_deref())
609            }
610            _ => return Err(AuthError::auth_method(
611                self.name(),
612                "Invalid credential type for OAuth authentication".to_string(),
613            )),
614        };
615
616        // Exchange authorization code for tokens
617        let token_response = self.provider.exchange_code(
618            &self.client_id,
619            &self.client_secret,
620            authorization_code,
621            &self.redirect_uri,
622            code_verifier,
623        ).await?;
624
625        // Get user info
626        let user_info = self.provider.get_user_info(&token_response.access_token).await?;
627
628        // Create auth token
629        // Convert to duration with proper type
630        let expires_in = token_response.expires_in
631            .map(std::time::Duration::from_secs)
632            .unwrap_or_else(|| std::time::Duration::from_secs(3600));
633
634        let mut token = self.token_manager.create_auth_token(
635            &user_info.id,
636            token_response.scope
637                .unwrap_or_default()
638                .split_whitespace()
639                .map(|s| s.to_string())
640                .collect(),
641            self.name(),
642            Some(expires_in),
643        )?;
644
645        // Set refresh token if available
646        if let Some(refresh_token) = token_response.refresh_token {
647            token = token.with_refresh_token(refresh_token);
648        }
649
650        Ok(MethodResult::Success(Box::new(token)))
651    }
652
653    fn validate_config(&self) -> Result<()> {
654        if self.client_id.is_empty() {
655            return Err(AuthError::config("OAuth client ID is required"));
656        }
657        if self.client_secret.is_empty() {
658            return Err(AuthError::config("OAuth client secret is required"));
659        }
660        if self.redirect_uri.is_empty() {
661            return Err(AuthError::config("OAuth redirect URI is required"));
662        }
663        Ok(())
664    }
665
666    fn supports_refresh(&self) -> bool {
667        self.provider.config().supports_refresh
668    }
669
670    async fn refresh_token(&self, refresh_token: &str) -> Result<AuthToken> {
671        let token_response = self.provider.refresh_token(
672            &self.client_id,
673            &self.client_secret,
674            refresh_token,
675        ).await?;
676
677        let expires_in = token_response.expires_in
678            .map(Duration::from_secs)
679            .unwrap_or_else(|| std::time::Duration::from_secs(3600));
680
681        // We need user info to create the token, but we don't have it from refresh
682        // In a real implementation, we'd store user ID with the refresh token
683        let token = self.token_manager.create_auth_token(
684            "unknown", // This would need to be resolved from the refresh token
685            token_response.scope
686                .unwrap_or_default()
687                .split_whitespace()
688                .map(|s| s.to_string())
689                .collect(),
690            self.name(),
691            Some(expires_in),
692        )?;
693
694        Ok(token)
695    }
696}
697
/// PKCE parameter pair as returned by `generate_pkce`.
///
/// `OAuth2Method::authorization_url` sends the SECOND element to the provider
/// as the code challenge. The first element is presumably the code verifier,
/// which the caller must keep for the token exchange — TODO confirm against
/// `generate_pkce`.
type PkceParams = (String, String);

/// Result of `OAuth2Method::authorization_url`: `(url, state, pkce)`.
///
/// Keep `state` to compare with the value echoed back on the redirect.
/// `pkce` is `None` when PKCE is disabled.
type AuthorizationUrlResult = (String, String, Option<PkceParams>);
703
704/// Default API key validator (placeholder implementation).
705#[derive(Debug, Clone)]
706struct DefaultApiKeyValidator;
707
708#[async_trait]
709impl ApiKeyValidator for DefaultApiKeyValidator {
710    async fn validate_key(&self, _api_key: &str) -> Result<Option<UserInfo>> {
711        // This is a placeholder - real implementation would check against a database
712        Ok(None)
713    }
714
715    async fn create_key(&self, _user_id: &str, _expires_in: Option<Duration>) -> Result<String> {
716        // Generate a new API key
717        Ok(format!("api-{}", uuid::Uuid::new_v4()))
718    }
719
720    async fn revoke_key(&self, _api_key: &str) -> Result<()> {
721        // Mark key as revoked in database
722        Ok(())
723    }
724}
725
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh challenge carries the user, a UUID id and is not yet expired.
    #[test]
    fn test_mfa_challenge() {
        let challenge = MfaChallenge::new(
            MfaType::Totp,
            "user123",
            Duration::from_secs(300),
        );

        assert_eq!(challenge.user_id, "user123");
        assert!(!challenge.is_expired());
        assert_eq!(challenge.id().len(), 36); // UUID length
    }

    /// `with_message` stores the message and leaves the data map empty.
    #[test]
    fn test_mfa_challenge_with_message() {
        let challenge = MfaChallenge::new(MfaType::Totp, "user123", Duration::from_secs(300))
            .with_message("Enter code");

        assert_eq!(challenge.message.as_deref(), Some("Enter code"));
        assert!(challenge.data.is_empty());
    }

    #[test]
    fn test_jwt_method_creation() {
        let jwt_method = JwtMethod::new()
            .secret_key("test-secret")
            .issuer("test-issuer")
            .audience("test-audience");

        assert_eq!(jwt_method.name(), "jwt");
        assert_eq!(jwt_method.issuer, "test-issuer");
        assert_eq!(jwt_method.audience, "test-audience");
    }

    /// Builder setters on `ApiKeyMethod` update the stored configuration.
    #[test]
    fn test_api_key_method_builder() {
        let method = ApiKeyMethod::new()
            .key_prefix("ak_")
            .header_name("Authorization");

        assert_eq!(method.name(), "api-key");
        assert_eq!(method.key_prefix.as_deref(), Some("ak_"));
        assert_eq!(method.header_name, "Authorization");
    }

    #[test]
    fn test_oauth2_method_creation() {
        let oauth_method = OAuth2Method::new()
            .provider(OAuthProvider::GitHub)
            .client_id("test-client")
            .client_secret("test-secret")
            .redirect_uri("https://example.com/callback");

        assert_eq!(oauth_method.name(), "oauth2");
        assert_eq!(oauth_method.client_id, "test-client");
        assert!(oauth_method.use_pkce);
    }

    /// `validate_config` rejects the unconfigured default and accepts a
    /// method with client id, secret and redirect URI set.
    #[test]
    fn test_oauth2_validate_config() {
        assert!(OAuth2Method::new().validate_config().is_err());

        let configured = OAuth2Method::new()
            .client_id("c")
            .client_secret("s")
            .redirect_uri("https://example.com/cb");
        assert!(configured.validate_config().is_ok());

        assert!(!OAuth2Method::new().use_pkce(false).use_pkce);
    }
}