auth_framework/
methods.rs

1//! Authentication method implementations.
2
3// Enhanced device flow support (optional)
4#[cfg(feature = "enhanced-device-flow")]
5pub mod enhanced_device;
6
7#[cfg(feature = "enhanced-device-flow")]
8pub use enhanced_device::{EnhancedDeviceFlowMethod, DeviceFlowInstructions};
9
10// Enhanced device flow comprehensive tests
11#[cfg(test)]
12#[cfg(feature = "enhanced-device-flow")]
13mod enhanced_device_tests;
14
15use crate::credentials::{Credential, CredentialMetadata};
16use crate::errors::{AuthError, Result};
17use crate::providers::{OAuthProvider, generate_state, generate_pkce};
18use crate::tokens::{AuthToken, TokenManager};
19use async_trait::async_trait;
20use serde::{Deserialize, Serialize};
21use std::collections::HashMap;
22use std::time::Duration;
23
24/// Result of an authentication attempt.
25#[derive(Debug, Clone)]
26pub enum MethodResult {
27    /// Authentication was successful
28    Success(Box<AuthToken>),
29    
30    /// Multi-factor authentication is required
31    MfaRequired(Box<MfaChallenge>),
32    
33    /// Authentication failed
34    Failure { reason: String },
35}
36
37/// Multi-factor authentication challenge.
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct MfaChallenge {
40    /// Unique challenge ID
41    pub id: String,
42    
43    /// Type of MFA required
44    pub mfa_type: MfaType,
45    
46    /// User ID this challenge is for
47    pub user_id: String,
48    
49    /// When the challenge expires
50    pub expires_at: chrono::DateTime<chrono::Utc>,
51    
52    /// Optional message or instructions
53    pub message: Option<String>,
54    
55    /// Additional challenge data
56    pub data: HashMap<String, serde_json::Value>,
57}
58
59/// Types of multi-factor authentication.
60#[derive(Debug, Clone, Serialize, Deserialize)]
61pub enum MfaType {
62    /// Time-based one-time password (TOTP)
63    Totp,
64    
65    /// SMS verification code
66    Sms { phone_number: String },
67    
68    /// Email verification code
69    Email { email_address: String },
70    
71    /// Push notification
72    Push { device_id: String },
73    
74    /// Hardware security key
75    SecurityKey,
76    
77    /// Backup codes
78    BackupCode,
79}
80
81/// Trait for authentication methods.
82#[async_trait]
83pub trait AuthMethod: Send + Sync {
84    /// Get the name of this authentication method.
85    fn name(&self) -> &str;
86    
87    /// Authenticate using the provided credentials.
88    async fn authenticate(
89        &self,
90        credential: &Credential,
91        metadata: &CredentialMetadata,
92    ) -> Result<MethodResult>;
93    
94    /// Validate configuration for this method.
95    fn validate_config(&self) -> Result<()>;
96    
97    /// Check if this method supports refresh tokens.
98    fn supports_refresh(&self) -> bool {
99        false
100    }
101    
102    /// Refresh a token if supported.
103    async fn refresh_token(&self, _refresh_token: &str) -> Result<AuthToken> {
104        Err(AuthError::auth_method(
105            self.name(),
106            "Token refresh not supported by this method".to_string(),
107        ))
108    }
109}
110
111/// Password-based authentication method.
112pub struct PasswordMethod {
113    name: String,
114    password_verifier: Box<dyn PasswordVerifier>,
115    token_manager: TokenManager,
116    mfa_enabled: bool,
117    user_lookup: Box<dyn UserLookup>,
118}
119
120/// JWT-based authentication method.
121pub struct JwtMethod {
122    name: String,
123    token_manager: TokenManager,
124    issuer: String,
125    audience: String,
126}
127
128/// API key authentication method.
129pub struct ApiKeyMethod {
130    name: String,
131    key_prefix: Option<String>,
132    header_name: String,
133    key_validator: Box<dyn ApiKeyValidator>,
134    token_manager: TokenManager,
135}
136
137/// OAuth 2.0 authentication method.
138pub struct OAuth2Method {
139    name: String,
140    provider: OAuthProvider,
141    client_id: String,
142    client_secret: String,
143    redirect_uri: String,
144    scopes: Vec<String>,
145    use_pkce: bool,
146    token_manager: TokenManager,
147}
148
149/// Trait for password verification.
150#[async_trait]
151pub trait PasswordVerifier: Send + Sync {
152    /// Verify a password against a hash.
153    async fn verify_password(&self, username: &str, password: &str) -> Result<bool>;
154    
155    /// Hash a password.
156    async fn hash_password(&self, password: &str) -> Result<String>;
157}
158
159/// Trait for user lookup operations.
160#[async_trait]
161pub trait UserLookup: Send + Sync {
162    /// Look up a user by username.
163    async fn lookup_user(&self, username: &str) -> Result<Option<UserInfo>>;
164    
165    /// Check if a user requires MFA.
166    async fn requires_mfa(&self, user_id: &str) -> Result<bool>;
167}
168
169/// Trait for API key validation.
170#[async_trait]
171pub trait ApiKeyValidator: Send + Sync {
172    /// Validate an API key and return associated user info.
173    async fn validate_key(&self, api_key: &str) -> Result<Option<UserInfo>>;
174    
175    /// Create a new API key for a user.
176    async fn create_key(&self, user_id: &str, expires_in: Option<Duration>) -> Result<String>;
177    
178    /// Revoke an API key.
179    async fn revoke_key(&self, api_key: &str) -> Result<()>;
180}
181
182/// Basic user information.
183#[derive(Debug, Clone, Serialize, Deserialize)]
184pub struct UserInfo {
185    /// User ID
186    pub id: String,
187    
188    /// Username
189    pub username: String,
190    
191    /// Email address
192    pub email: Option<String>,
193    
194    /// Display name
195    pub name: Option<String>,
196    
197    /// User roles
198    pub roles: Vec<String>,
199    
200    /// Whether the user is active
201    pub active: bool,
202    
203    /// Additional user attributes
204    pub attributes: HashMap<String, serde_json::Value>,
205}
206
207impl MfaChallenge {
208    /// Create a new MFA challenge.
209    pub fn new(
210        mfa_type: MfaType,
211        user_id: impl Into<String>,
212        expires_in: Duration,
213    ) -> Self {
214        Self {
215            id: uuid::Uuid::new_v4().to_string(),
216            mfa_type,
217            user_id: user_id.into(),
218            expires_at: chrono::Utc::now() + chrono::Duration::from_std(expires_in).unwrap(),
219            message: None,
220            data: HashMap::new(),
221        }
222    }
223
224    /// Get the challenge ID.
225    pub fn id(&self) -> &str {
226        &self.id
227    }
228
229    /// Check if the challenge has expired.
230    pub fn is_expired(&self) -> bool {
231        chrono::Utc::now() > self.expires_at
232    }
233
234    /// Set a message for the challenge.
235    pub fn with_message(mut self, message: impl Into<String>) -> Self {
236        self.message = Some(message.into());
237        self
238    }
239}
240
241impl PasswordMethod {
242    /// Create a new password authentication method.
243    pub fn new(
244        password_verifier: Box<dyn PasswordVerifier>,
245        user_lookup: Box<dyn UserLookup>,
246        token_manager: TokenManager,
247    ) -> Self {
248        Self {
249            name: "password".to_string(),
250            password_verifier,
251            token_manager,
252            mfa_enabled: false,
253            user_lookup,
254        }
255    }
256
257    /// Enable or disable MFA for this method.
258    pub fn with_mfa(mut self, enabled: bool) -> Self {
259        self.mfa_enabled = enabled;
260        self
261    }
262}
263
264#[async_trait]
265impl AuthMethod for PasswordMethod {
266    fn name(&self) -> &str {
267        &self.name
268    }
269
270    async fn authenticate(
271        &self,
272        credential: &Credential,
273        _metadata: &CredentialMetadata,
274    ) -> Result<MethodResult> {
275        let (username, password) = match credential {
276            Credential::Password { username, password } => (username, password),
277            _ => return Err(AuthError::auth_method(
278                self.name(),
279                "Invalid credential type for password authentication".to_string(),
280            )),
281        };
282
283        // Verify password
284        if !self.password_verifier.verify_password(username, password).await? {
285            return Ok(MethodResult::Failure {
286                reason: "Invalid username or password".to_string(),
287            });
288        }
289
290        // Look up user
291        let user = self.user_lookup.lookup_user(username).await?
292            .ok_or_else(|| AuthError::auth_method(
293                self.name(),
294                "User not found".to_string(),
295            ))?;
296
297        if !user.active {
298            return Ok(MethodResult::Failure {
299                reason: "User account is disabled".to_string(),
300            });
301        }
302
303        // Check if MFA is required
304        if self.mfa_enabled && self.user_lookup.requires_mfa(&user.id).await? {
305            let challenge = MfaChallenge::new(
306                MfaType::Totp, // Default to TOTP, could be configurable
307                &user.id,
308                Duration::from_secs(300), // 5 minutes
309            ).with_message("Please enter your MFA code");
310
311            return Ok(MethodResult::MfaRequired(Box::new(challenge)));
312        }
313
314        // Create token
315        let token = self.token_manager.create_auth_token(
316            &user.id,
317            vec![], // Scopes would be determined by user roles
318            self.name(),
319            None,
320        )?;
321
322        Ok(MethodResult::Success(Box::new(token)))
323    }
324
325    fn validate_config(&self) -> Result<()> {
326        // Validation would depend on the specific implementation
327        Ok(())
328    }
329}
330
331impl Default for JwtMethod {
332    fn default() -> Self {
333        Self::new()
334    }
335}
336
337impl JwtMethod {
338    /// Create a new JWT authentication method.
339    pub fn new() -> Self {
340        let token_manager = TokenManager::new_hmac(
341            b"default-secret", // This should be configurable
342            "default-issuer",
343            "default-audience",
344        );
345
346        Self {
347            name: "jwt".to_string(),
348            token_manager,
349            issuer: "default-issuer".to_string(),
350            audience: "default-audience".to_string(),
351        }
352    }
353
354    /// Set the secret key for JWT signing.
355    pub fn secret_key(mut self, secret: impl Into<String>) -> Self {
356        let secret = secret.into();
357        self.token_manager = TokenManager::new_hmac(
358            secret.as_bytes(),
359            &self.issuer,
360            &self.audience,
361        );
362        self
363    }
364
365    /// Set the issuer.
366    pub fn issuer(mut self, issuer: impl Into<String>) -> Self {
367        self.issuer = issuer.into();
368        self.token_manager = TokenManager::new_hmac(
369            b"default-secret", // This should use the actual secret
370            &self.issuer,
371            &self.audience,
372        );
373        self
374    }
375
376    /// Set the audience.
377    pub fn audience(mut self, audience: impl Into<String>) -> Self {
378        self.audience = audience.into();
379        self.token_manager = TokenManager::new_hmac(
380            b"default-secret", // This should use the actual secret
381            &self.issuer,
382            &self.audience,
383        );
384        self
385    }
386
387    /// Set the JWT algorithm.
388    pub fn algorithm(self, _algorithm: impl Into<String>) -> Self {
389        // In a real implementation, this would set the JWT algorithm
390        self
391    }
392}
393
394#[async_trait]
395impl AuthMethod for JwtMethod {
396    fn name(&self) -> &str {
397        &self.name
398    }
399
400    async fn authenticate(
401        &self,
402        credential: &Credential,
403        _metadata: &CredentialMetadata,
404    ) -> Result<MethodResult> {
405        let token_str = match credential {
406            Credential::Jwt { token } => token,
407            Credential::Bearer { token } => token,
408            _ => return Err(AuthError::auth_method(
409                self.name(),
410                "Invalid credential type for JWT authentication".to_string(),
411            )),
412        };
413
414        // Validate JWT
415        let claims = self.token_manager.validate_jwt_token(token_str)?;
416        
417        // Create auth token from JWT claims
418        let remaining_seconds = (claims.exp - chrono::Utc::now().timestamp()).max(0) as u64;
419        let token = AuthToken::new(
420            claims.sub,
421            token_str.clone(),
422            std::time::Duration::from_secs(remaining_seconds),
423            self.name(),
424        ).with_scopes(claims.scope.split_whitespace().map(|s| s.to_string()).collect());
425
426        Ok(MethodResult::Success(Box::new(token)))
427    }
428
429    fn validate_config(&self) -> Result<()> {
430        // Validate JWT configuration
431        Ok(())
432    }
433}
434
435impl Default for ApiKeyMethod {
436    fn default() -> Self {
437        Self::new()
438    }
439}
440
441impl ApiKeyMethod {
442    /// Create a new API key authentication method.
443    pub fn new() -> Self {
444        let token_manager = TokenManager::new_hmac(
445            b"default-secret",
446            "api-key-issuer",
447            "api-key-audience",
448        );
449
450        Self {
451            name: "api-key".to_string(),
452            key_prefix: None,
453            header_name: "X-API-Key".to_string(),
454            key_validator: Box::new(DefaultApiKeyValidator),
455            token_manager,
456        }
457    }
458
459    /// Set the key prefix.
460    pub fn key_prefix(mut self, prefix: impl Into<String>) -> Self {
461        self.key_prefix = Some(prefix.into());
462        self
463    }
464
465    /// Set the key length (for compatibility with examples).
466    pub fn key_length(self, _length: usize) -> Self {
467        // In a real implementation, this would affect key generation
468        self
469    }
470
471    /// Set the header name.
472    pub fn header_name(mut self, name: impl Into<String>) -> Self {
473        self.header_name = name.into();
474        self
475    }
476
477    /// Set the key validator.
478    pub fn key_validator(mut self, validator: Box<dyn ApiKeyValidator>) -> Self {
479        self.key_validator = validator;
480        self
481    }
482}
483
484#[async_trait]
485impl AuthMethod for ApiKeyMethod {
486    fn name(&self) -> &str {
487        &self.name
488    }
489
490    async fn authenticate(
491        &self,
492        credential: &Credential,
493        _metadata: &CredentialMetadata,
494    ) -> Result<MethodResult> {
495        let api_key = match credential {
496            Credential::ApiKey { key } => key,
497            _ => return Err(AuthError::auth_method(
498                self.name(),
499                "Invalid credential type for API key authentication".to_string(),
500            )),
501        };
502
503        // Validate prefix if configured
504        if let Some(prefix) = &self.key_prefix {
505            if !api_key.starts_with(prefix) {
506                return Ok(MethodResult::Failure {
507                    reason: "Invalid API key format".to_string(),
508                });
509            }
510        }
511
512        // Validate key
513        let user = self.key_validator.validate_key(api_key).await?
514            .ok_or_else(|| AuthError::auth_method(
515                self.name(),
516                "Invalid API key".to_string(),
517            ))?;
518
519        // Create token
520        let token = self.token_manager.create_auth_token(
521            &user.id,
522            vec!["api".to_string()], // Default API scope
523            self.name(),
524            Some(std::time::Duration::from_secs(3600)), // 1 hour default
525        )?;
526
527        Ok(MethodResult::Success(Box::new(token)))
528    }
529
530    fn validate_config(&self) -> Result<()> {
531        Ok(())
532    }
533}
534
535impl Default for OAuth2Method {
536    fn default() -> Self {
537        Self::new()
538    }
539}
540
541impl OAuth2Method {
542    /// Create a new OAuth 2.0 authentication method.
543    pub fn new() -> Self {
544        let token_manager = TokenManager::new_hmac(
545            b"oauth-secret",
546            "oauth-issuer",
547            "oauth-audience",
548        );
549
550        Self {
551            name: "oauth2".to_string(),
552            provider: OAuthProvider::GitHub, // Default provider
553            client_id: String::new(),
554            client_secret: String::new(),
555            redirect_uri: String::new(),
556            scopes: Vec::new(),
557            use_pkce: true,
558            token_manager,
559        }
560    }
561
562    /// Set the OAuth provider.
563    pub fn provider(mut self, provider: OAuthProvider) -> Self {
564        self.provider = provider;
565        self
566    }
567
568    /// Set the client ID.
569    pub fn client_id(mut self, client_id: impl Into<String>) -> Self {
570        self.client_id = client_id.into();
571        self
572    }
573
574    /// Set the client secret.
575    pub fn client_secret(mut self, client_secret: impl Into<String>) -> Self {
576        self.client_secret = client_secret.into();
577        self
578    }
579
580    /// Set the redirect URI.
581    pub fn redirect_uri(mut self, redirect_uri: impl Into<String>) -> Self {
582        self.redirect_uri = redirect_uri.into();
583        self
584    }
585
586    /// Set the scopes.
587    pub fn scopes(mut self, scopes: Vec<String>) -> Self {
588        self.scopes = scopes;
589        self
590    }
591
592    /// Enable or disable PKCE.
593    pub fn use_pkce(mut self, use_pkce: bool) -> Self {
594        self.use_pkce = use_pkce;
595        self
596    }
597
598    /// Generate authorization URL.
599    pub fn authorization_url(&self) -> Result<AuthorizationUrlResult> {
600        let state = generate_state();
601        let pkce = if self.use_pkce {
602            Some(generate_pkce())
603        } else {
604            None
605        };
606
607        let url = self.provider.build_authorization_url(
608            &self.client_id,
609            &self.redirect_uri,
610            &state,
611            if self.scopes.is_empty() { None } else { Some(&self.scopes) },
612            pkce.as_ref().map(|(_, challenge)| challenge.as_str()),
613        )?;
614
615        Ok((url, state, pkce))
616    }
617}
618
619#[async_trait]
620impl AuthMethod for OAuth2Method {
621    fn name(&self) -> &str {
622        &self.name
623    }
624
625    async fn authenticate(
626        &self,
627        credential: &Credential,
628        _metadata: &CredentialMetadata,
629    ) -> Result<MethodResult> {
630        let (authorization_code, code_verifier) = match credential {
631            Credential::OAuth { authorization_code, code_verifier, .. } => {
632                (authorization_code, code_verifier.as_deref())
633            }
634            _ => return Err(AuthError::auth_method(
635                self.name(),
636                "Invalid credential type for OAuth authentication".to_string(),
637            )),
638        };
639
640        // Exchange authorization code for tokens
641        let token_response = self.provider.exchange_code(
642            &self.client_id,
643            &self.client_secret,
644            authorization_code,
645            &self.redirect_uri,
646            code_verifier,
647        ).await?;
648
649        // Get user info
650        let user_info = self.provider.get_user_info(&token_response.access_token).await?;
651
652        // Create auth token
653        // Convert to duration with proper type
654        let expires_in = token_response.expires_in
655            .map(std::time::Duration::from_secs)
656            .unwrap_or_else(|| std::time::Duration::from_secs(3600));
657
658        let mut token = self.token_manager.create_auth_token(
659            &user_info.id,
660            token_response.scope
661                .unwrap_or_default()
662                .split_whitespace()
663                .map(|s| s.to_string())
664                .collect(),
665            self.name(),
666            Some(expires_in),
667        )?;
668
669        // Set refresh token if available
670        if let Some(refresh_token) = token_response.refresh_token {
671            token = token.with_refresh_token(refresh_token);
672        }
673
674        Ok(MethodResult::Success(Box::new(token)))
675    }
676
677    fn validate_config(&self) -> Result<()> {
678        if self.client_id.is_empty() {
679            return Err(AuthError::config("OAuth client ID is required"));
680        }
681        if self.client_secret.is_empty() {
682            return Err(AuthError::config("OAuth client secret is required"));
683        }
684        if self.redirect_uri.is_empty() {
685            return Err(AuthError::config("OAuth redirect URI is required"));
686        }
687        Ok(())
688    }
689
690    fn supports_refresh(&self) -> bool {
691        self.provider.config().supports_refresh
692    }
693
694    async fn refresh_token(&self, refresh_token: &str) -> Result<AuthToken> {
695        let token_response = self.provider.refresh_token(
696            &self.client_id,
697            &self.client_secret,
698            refresh_token,
699        ).await?;
700
701        let expires_in = token_response.expires_in
702            .map(Duration::from_secs)
703            .unwrap_or_else(|| std::time::Duration::from_secs(3600));
704
705        // We need user info to create the token, but we don't have it from refresh
706        // In a real implementation, we'd store user ID with the refresh token
707        let token = self.token_manager.create_auth_token(
708            "unknown", // This would need to be resolved from the refresh token
709            token_response.scope
710                .unwrap_or_default()
711                .split_whitespace()
712                .map(|s| s.to_string())
713                .collect(),
714            self.name(),
715            Some(expires_in),
716        )?;
717
718        Ok(token)
719    }
720}
721
/// PKCE pair produced by `generate_pkce`: `(code_verifier, code_challenge)`.
///
/// Only the challenge (second element) goes into the authorization URL. The
/// verifier must be kept by the caller and sent with the code exchange.
pub type PkceParams = (String, String);

/// Result of `OAuth2Method::authorization_url`: `(url, state, pkce)`.
///
/// `pkce` is `None` when PKCE is disabled.
pub type AuthorizationUrlResult = (String, String, Option<PkceParams>);
727
728/// Default API key validator (placeholder implementation).
729#[derive(Debug, Clone)]
730struct DefaultApiKeyValidator;
731
732#[async_trait]
733impl ApiKeyValidator for DefaultApiKeyValidator {
734    async fn validate_key(&self, _api_key: &str) -> Result<Option<UserInfo>> {
735        // This is a placeholder - real implementation would check against a database
736        Ok(None)
737    }
738
739    async fn create_key(&self, _user_id: &str, _expires_in: Option<Duration>) -> Result<String> {
740        // Generate a new API key
741        Ok(format!("api-{}", uuid::Uuid::new_v4()))
742    }
743
744    async fn revoke_key(&self, _api_key: &str) -> Result<()> {
745        // Mark key as revoked in database
746        Ok(())
747    }
748}
749
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mfa_challenge() {
        let challenge = MfaChallenge::new(MfaType::Totp, "user123", Duration::from_secs(5 * 60));

        assert_eq!(challenge.user_id, "user123");
        assert!(!challenge.is_expired());
        // A hyphenated UUID string is 36 characters long.
        assert_eq!(challenge.id().len(), 36);
    }

    #[test]
    fn test_jwt_method_creation() {
        let method = JwtMethod::new()
            .secret_key("test-secret")
            .issuer("test-issuer")
            .audience("test-audience");

        assert_eq!(method.name(), "jwt");
        assert_eq!(method.issuer, "test-issuer");
        assert_eq!(method.audience, "test-audience");
    }

    #[test]
    fn test_oauth2_method_creation() {
        let method = OAuth2Method::new()
            .provider(OAuthProvider::GitHub)
            .client_id("test-client")
            .client_secret("test-secret")
            .redirect_uri("https://example.com/callback");

        assert_eq!(method.name(), "oauth2");
        assert_eq!(method.client_id, "test-client");
        assert!(method.use_pkce, "PKCE should be enabled by default");
    }
}