// auth_framework/authentication/mfa.rs

//! Multi-Factor Authentication (MFA) implementation.
//!
//! This module provides MFA support for TOTP, SMS, email and backup codes. WebAuthn
//! method and challenge data types are defined here, but challenge creation and
//! verification for WebAuthn are not implemented yet.
6use crate::errors::{AuthError, Result};
7use crate::security::MfaConfig;
8use async_trait::async_trait;
9use rand::Rng;
10use serde::{Deserialize, Serialize};
11use std::time::{SystemTime, UNIX_EPOCH};
12use totp_lite::{Sha1, totp};
13
14/// MFA method types
15#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
16pub enum MfaMethodType {
17    Totp,
18    Sms,
19    Email,
20    WebAuthn,
21    BackupCodes,
22}
23
24/// MFA challenge that must be completed
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct MfaChallenge {
27    /// Unique challenge ID
28    pub id: String,
29    /// User ID this challenge belongs to
30    pub user_id: String,
31    /// Type of MFA method
32    pub method_type: MfaMethodType,
33    /// Challenge data (varies by method type)
34    pub challenge_data: MfaChallengeData,
35    /// When the challenge was created
36    pub created_at: SystemTime,
37    /// When the challenge expires
38    pub expires_at: SystemTime,
39    /// Number of attempts made
40    pub attempts: u32,
41    /// Maximum allowed attempts
42    pub max_attempts: u32,
43}
44
45/// Challenge data specific to each MFA method
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub enum MfaChallengeData {
48    Totp {
49        /// Current time window
50        time_window: u64,
51    },
52    Sms {
53        /// Phone number (masked)
54        phone_number: String,
55        /// Generated code
56        code: String,
57    },
58    Email {
59        /// Email address (masked)
60        email: String,
61        /// Generated code
62        code: String,
63    },
64    WebAuthn {
65        /// Challenge bytes
66        challenge: Vec<u8>,
67        /// Allowed credential IDs
68        allowed_credentials: Vec<String>,
69    },
70    BackupCodes {
71        /// Remaining backup codes count
72        remaining_codes: u32,
73    },
74}
75
76/// MFA method configuration for a user
77#[derive(Debug, Clone, Serialize, Deserialize)]
78pub struct UserMfaMethod {
79    /// Unique method ID
80    pub id: String,
81    /// User ID
82    pub user_id: String,
83    /// Method type
84    pub method_type: MfaMethodType,
85    /// Method-specific data
86    pub method_data: MfaMethodData,
87    /// Display name for the method
88    pub display_name: String,
89    /// Whether this is the primary method
90    pub is_primary: bool,
91    /// Whether this method is enabled
92    pub is_enabled: bool,
93    /// When the method was created
94    pub created_at: SystemTime,
95    /// When the method was last used
96    pub last_used_at: Option<SystemTime>,
97}
98
99/// Method-specific configuration data
100#[derive(Debug, Clone, Serialize, Deserialize)]
101pub enum MfaMethodData {
102    Totp {
103        /// Base32-encoded secret key
104        secret_key: String,
105        /// QR code URL for setup
106        qr_code_url: String,
107    },
108    Sms {
109        /// Phone number
110        phone_number: String,
111        /// Whether phone number is verified
112        is_verified: bool,
113    },
114    Email {
115        /// Email address
116        email: String,
117        /// Whether email is verified
118        is_verified: bool,
119    },
120    WebAuthn {
121        /// Credential ID
122        credential_id: String,
123        /// Public key
124        public_key: Vec<u8>,
125        /// Counter for replay protection
126        counter: u32,
127    },
128    BackupCodes {
129        /// List of backup codes (hashed)
130        codes: Vec<String>,
131        /// Number of codes used
132        used_count: u32,
133    },
134}
135
136/// MFA verification result
137#[derive(Debug, Clone)]
138pub struct MfaVerificationResult {
139    /// Whether verification succeeded
140    pub success: bool,
141    /// Method that was used
142    pub method_type: MfaMethodType,
143    /// Remaining attempts (if failed)
144    pub remaining_attempts: Option<u32>,
145    /// Error message (if failed)
146    pub error_message: Option<String>,
147}
148
149/// TOTP (Time-based One-Time Password) implementation
150pub struct TotpProvider {
151    config: crate::security::TotpConfig,
152}
153
154impl TotpProvider {
155    pub fn new(config: crate::security::TotpConfig) -> Self {
156        Self { config }
157    }
158
159    /// Generate a new TOTP secret using cryptographically secure random
160    pub fn generate_secret(&self) -> crate::Result<String> {
161        use ring::rand::{SecureRandom, SystemRandom};
162        let rng = SystemRandom::new();
163        let mut secret = [0u8; 20];
164        rng.fill(&mut secret).map_err(|_| {
165            crate::errors::AuthError::crypto("Failed to generate secure TOTP secret".to_string())
166        })?;
167        Ok(base32::encode(
168            base32::Alphabet::Rfc4648 { padding: true },
169            &secret,
170        ))
171    }
172
173    /// Generate QR code URL for TOTP setup
174    pub fn generate_qr_code_url(&self, secret: &str, user_identifier: &str) -> String {
175        format!(
176            "otpauth://totp/{}:{}?secret={}&issuer={}&digits={}&period={}",
177            urlencoding::encode(&self.config.issuer),
178            urlencoding::encode(user_identifier),
179            secret,
180            urlencoding::encode(&self.config.issuer),
181            self.config.digits,
182            self.config.period
183        )
184    }
185
186    /// Generate TOTP code for the current time window
187    pub fn generate_code(&self, secret: &str, time_step: Option<u64>) -> Result<String> {
188        if secret.trim().is_empty() {
189            return Err(AuthError::validation("TOTP secret cannot be empty"));
190        }
191
192        let secret_bytes = base32::decode(base32::Alphabet::Rfc4648 { padding: true }, secret)
193            .ok_or_else(|| AuthError::validation("Invalid TOTP secret"))?;
194
195        let time_step = time_step.unwrap_or_else(|| {
196            SystemTime::now()
197                .duration_since(UNIX_EPOCH)
198                .unwrap()
199                .as_secs()
200                / self.config.period
201        });
202
203        // Convert time step to Unix timestamp for totp-lite
204        // totp-lite expects Unix timestamp, not time step
205        let unix_timestamp = time_step.checked_mul(self.config.period).ok_or_else(|| {
206            AuthError::InvalidInput("Time step too large for conversion".to_string())
207        })?;
208
209        // Use totp-lite for proper TOTP generation
210        let totp_value = totp::<Sha1>(&secret_bytes, unix_timestamp);
211
212        // totp-lite returns variable length string, parse and format according to config
213        let parsed_value: u32 = totp_value
214            .parse()
215            .map_err(|_| AuthError::validation("TOTP generation error"))?;
216
217        // Format to the specified number of digits
218        Ok(format!(
219            "{:0width$}",
220            parsed_value % 10_u32.pow(self.config.digits.into()),
221            width = self.config.digits as usize
222        ))
223    }
224
225    /// Verify TOTP code with time window tolerance
226    pub fn verify_code(&self, secret: &str, code: &str, time_window: Option<u64>) -> Result<bool> {
227        // First validate the secret by trying to decode it
228        let _secret_bytes = base32::decode(base32::Alphabet::Rfc4648 { padding: true }, secret)
229            .ok_or_else(|| AuthError::validation("Invalid TOTP secret"))?;
230
231        let current_time_step = if let Some(time) = time_window {
232            time / self.config.period
233        } else {
234            SystemTime::now()
235                .duration_since(UNIX_EPOCH)
236                .unwrap()
237                .as_secs()
238                / self.config.period
239        };
240
241        // Check current time step and ±1 time step for clock skew tolerance
242        for step_offset in [-1i64, 0, 1] {
243            let time_step_i64 = current_time_step as i64 + step_offset;
244            // Skip negative time steps to avoid u64 overflow
245            if time_step_i64 < 0 {
246                continue;
247            }
248            let time_step = time_step_i64 as u64;
249            let expected_code = self.generate_code(secret, Some(time_step))?;
250            if expected_code == code {
251                return Ok(true);
252            }
253        }
254
255        Ok(false)
256    }
257
258    /// Verify TOTP code with configurable time window
259    pub fn verify_totp(&self, secret: &str, token: &str, window: u8) -> Result<bool> {
260        let now = SystemTime::now()
261            .duration_since(UNIX_EPOCH)
262            .map_err(|_| AuthError::validation("System time error"))?
263            .as_secs()
264            / self.config.period;
265
266        // Check within the specified time window using constant-time comparison
267        use subtle::ConstantTimeEq;
268
269        for i in 0..=window {
270            // Check current and positive offset
271            if i == 0 {
272                if let Ok(expected_code) = self.generate_code(secret, Some(now))
273                    && expected_code.as_bytes().ct_eq(token.as_bytes()).into()
274                {
275                    return Ok(true);
276                }
277            } else {
278                // Check both positive and negative offsets
279                for offset in [i as i64, -(i as i64)] {
280                    let time_step_i64 = now as i64 + offset;
281                    // Skip negative time steps to avoid u64 overflow
282                    if time_step_i64 < 0 {
283                        continue;
284                    }
285                    let time_step = time_step_i64 as u64;
286                    if let Ok(expected_code) = self.generate_code(secret, Some(time_step))
287                        && expected_code.as_bytes().ct_eq(token.as_bytes()).into()
288                    {
289                        return Ok(true);
290                    }
291                }
292            }
293        }
294        Ok(false)
295    }
296}
297
298/// SMS provider for sending verification codes
299#[async_trait]
300pub trait SmsProvider: Send + Sync {
301    async fn send_code(&self, phone_number: &str, code: &str) -> Result<()>;
302}
303
304/// Email provider for sending verification codes
305#[async_trait]
306pub trait EmailProvider: Send + Sync {
307    async fn send_code(&self, email: &str, code: &str) -> Result<()>;
308}
309
/// Stateless helper for creating, storing and checking backup (recovery) codes.
pub struct BackupCodesProvider;
312
313impl BackupCodesProvider {
314    /// Generate backup codes
315    pub fn generate_codes(count: u8) -> Vec<String> {
316        let mut rng = rand::rng();
317        (0..count)
318            .map(|_| {
319                format!(
320                    "{:04}-{:04}",
321                    rng.random_range(1000..9999),
322                    rng.random_range(1000..9999)
323                )
324            })
325            .collect()
326    }
327
328    /// Hash backup codes for storage
329    pub fn hash_codes(codes: &[String]) -> Result<Vec<String>> {
330        codes
331            .iter()
332            .map(|code| {
333                // In production, use a proper password hashing function
334                Ok(format!("hashed_{}", code))
335            })
336            .collect()
337    }
338
339    /// Verify backup code
340    pub fn verify_code(hashed_codes: &[String], provided_code: &str) -> bool {
341        let expected_hash = format!("hashed_{}", provided_code);
342        hashed_codes.contains(&expected_hash)
343    }
344}
345
346/// MFA storage trait
347#[async_trait]
348pub trait MfaStorage: Send + Sync {
349    /// Store user MFA method
350    async fn store_mfa_method(&self, method: &UserMfaMethod) -> Result<()>;
351
352    /// Get user's MFA methods
353    async fn get_user_mfa_methods(&self, user_id: &str) -> Result<Vec<UserMfaMethod>>;
354
355    /// Update MFA method
356    async fn update_mfa_method(&self, method: &UserMfaMethod) -> Result<()>;
357
358    /// Delete MFA method
359    async fn delete_mfa_method(&self, method_id: &str) -> Result<()>;
360
361    /// Store MFA challenge
362    async fn store_mfa_challenge(&self, challenge: &MfaChallenge) -> Result<()>;
363
364    /// Get MFA challenge
365    async fn get_mfa_challenge(&self, challenge_id: &str) -> Result<Option<MfaChallenge>>;
366
367    /// Update MFA challenge (for attempt counting)
368    async fn update_mfa_challenge(&self, challenge: &MfaChallenge) -> Result<()>;
369
370    /// Delete MFA challenge
371    async fn delete_mfa_challenge(&self, challenge_id: &str) -> Result<()>;
372
373    /// Clean up expired challenges
374    async fn cleanup_expired_challenges(&self) -> Result<()>;
375}
376
377/// MFA manager for handling multi-factor authentication
378pub struct MfaManager<S: MfaStorage> {
379    storage: S,
380    config: MfaConfig,
381    totp_provider: TotpProvider,
382    sms_provider: Option<Box<dyn SmsProvider>>,
383    email_provider: Option<Box<dyn EmailProvider>>,
384}
385
386impl<S: MfaStorage> MfaManager<S> {
387    /// Create a new MFA manager
388    pub fn new(storage: S, config: MfaConfig) -> Self {
389        let totp_provider = TotpProvider::new(config.totp_config.clone());
390
391        Self {
392            storage,
393            config,
394            totp_provider,
395            sms_provider: None,
396            email_provider: None,
397        }
398    }
399
400    /// Set SMS provider
401    pub fn with_sms_provider(mut self, provider: Box<dyn SmsProvider>) -> Self {
402        self.sms_provider = Some(provider);
403        self
404    }
405
406    /// Set email provider
407    pub fn with_email_provider(mut self, provider: Box<dyn EmailProvider>) -> Self {
408        self.email_provider = Some(provider);
409        self
410    }
411
412    /// Setup TOTP for a user
413    pub async fn setup_totp(&self, user_id: &str, user_identifier: &str) -> Result<UserMfaMethod> {
414        let secret = self.totp_provider.generate_secret()?;
415        let qr_code_url = self
416            .totp_provider
417            .generate_qr_code_url(&secret, user_identifier);
418
419        let method = UserMfaMethod {
420            id: uuid::Uuid::new_v4().to_string(),
421            user_id: user_id.to_string(),
422            method_type: MfaMethodType::Totp,
423            method_data: MfaMethodData::Totp {
424                secret_key: secret,
425                qr_code_url,
426            },
427            display_name: "Authenticator App".to_string(),
428            is_primary: false,
429            is_enabled: false, // Will be enabled after verification
430            created_at: SystemTime::now(),
431            last_used_at: None,
432        };
433
434        self.storage.store_mfa_method(&method).await?;
435        Ok(method)
436    }
437
438    /// Setup SMS MFA for a user
439    pub async fn setup_sms(&self, user_id: &str, phone_number: &str) -> Result<UserMfaMethod> {
440        let method = UserMfaMethod {
441            id: uuid::Uuid::new_v4().to_string(),
442            user_id: user_id.to_string(),
443            method_type: MfaMethodType::Sms,
444            method_data: MfaMethodData::Sms {
445                phone_number: phone_number.to_string(),
446                is_verified: false,
447            },
448            display_name: format!("SMS to {}", mask_phone_number(phone_number)),
449            is_primary: false,
450            is_enabled: false,
451            created_at: SystemTime::now(),
452            last_used_at: None,
453        };
454
455        self.storage.store_mfa_method(&method).await?;
456        Ok(method)
457    }
458
459    /// Generate backup codes for a user
460    pub async fn generate_backup_codes(
461        &self,
462        user_id: &str,
463    ) -> Result<(UserMfaMethod, Vec<String>)> {
464        let codes = BackupCodesProvider::generate_codes(10);
465        let hashed_codes = BackupCodesProvider::hash_codes(&codes)?;
466
467        let method = UserMfaMethod {
468            id: uuid::Uuid::new_v4().to_string(),
469            user_id: user_id.to_string(),
470            method_type: MfaMethodType::BackupCodes,
471            method_data: MfaMethodData::BackupCodes {
472                codes: hashed_codes,
473                used_count: 0,
474            },
475            display_name: "Backup Codes".to_string(),
476            is_primary: false,
477            is_enabled: true,
478            created_at: SystemTime::now(),
479            last_used_at: None,
480        };
481
482        self.storage.store_mfa_method(&method).await?;
483        Ok((method, codes))
484    }
485
486    /// Create MFA challenge for user
487    pub async fn create_challenge(
488        &self,
489        user_id: &str,
490        method_type: MfaMethodType,
491    ) -> Result<MfaChallenge> {
492        let user_methods = self.storage.get_user_mfa_methods(user_id).await?;
493        let method = user_methods
494            .iter()
495            .find(|m| m.method_type == method_type && m.is_enabled)
496            .ok_or_else(|| AuthError::validation("MFA method not found or not enabled"))?;
497
498        let challenge_data = match &method.method_data {
499            MfaMethodData::Totp { .. } => MfaChallengeData::Totp {
500                time_window: SystemTime::now()
501                    .duration_since(UNIX_EPOCH)
502                    .unwrap()
503                    .as_secs()
504                    / self.config.totp_config.period,
505            },
506            MfaMethodData::Sms { phone_number, .. } => {
507                let code = generate_numeric_code(6);
508                if let Some(sms_provider) = &self.sms_provider {
509                    sms_provider.send_code(phone_number, &code).await?;
510                }
511                MfaChallengeData::Sms {
512                    phone_number: mask_phone_number(phone_number),
513                    code,
514                }
515            }
516            MfaMethodData::Email { email, .. } => {
517                let code = generate_numeric_code(6);
518                if let Some(email_provider) = &self.email_provider {
519                    email_provider.send_code(email, &code).await?;
520                }
521                MfaChallengeData::Email {
522                    email: mask_email(email),
523                    code,
524                }
525            }
526            MfaMethodData::BackupCodes { .. } => {
527                MfaChallengeData::BackupCodes { remaining_codes: 8 } // Default backup codes count
528            }
529            _ => return Err(AuthError::validation("Unsupported MFA method type")),
530        };
531
532        let challenge = MfaChallenge {
533            id: uuid::Uuid::new_v4().to_string(),
534            user_id: user_id.to_string(),
535            method_type,
536            challenge_data,
537            created_at: SystemTime::now(),
538            expires_at: SystemTime::now() + std::time::Duration::from_secs(300), // 5 minutes
539            attempts: 0,
540            max_attempts: 3,
541        };
542
543        self.storage.store_mfa_challenge(&challenge).await?;
544        Ok(challenge)
545    }
546
547    /// Verify MFA challenge
548    pub async fn verify_challenge(
549        &self,
550        challenge_id: &str,
551        response: &str,
552    ) -> Result<MfaVerificationResult> {
553        let mut challenge = self
554            .storage
555            .get_mfa_challenge(challenge_id)
556            .await?
557            .ok_or_else(|| AuthError::validation("MFA challenge not found"))?;
558
559        // Check if challenge has expired
560        if SystemTime::now() > challenge.expires_at {
561            self.storage.delete_mfa_challenge(challenge_id).await?;
562            return Ok(MfaVerificationResult {
563                success: false,
564                method_type: challenge.method_type,
565                remaining_attempts: None,
566                error_message: Some("Challenge has expired".to_string()),
567            });
568        }
569
570        // Check if max attempts exceeded
571        if challenge.attempts >= challenge.max_attempts {
572            self.storage.delete_mfa_challenge(challenge_id).await?;
573            return Ok(MfaVerificationResult {
574                success: false,
575                method_type: challenge.method_type,
576                remaining_attempts: Some(0),
577                error_message: Some("Maximum attempts exceeded".to_string()),
578            });
579        }
580
581        challenge.attempts += 1;
582
583        let success = match &challenge.challenge_data {
584            MfaChallengeData::Totp { time_window } => {
585                let user_methods = self
586                    .storage
587                    .get_user_mfa_methods(&challenge.user_id)
588                    .await?;
589                if let Some(method) = user_methods
590                    .iter()
591                    .find(|m| m.method_type == MfaMethodType::Totp)
592                {
593                    if let MfaMethodData::Totp { secret_key, .. } = &method.method_data {
594                        self.totp_provider
595                            .verify_code(secret_key, response, Some(*time_window))?
596                    } else {
597                        false
598                    }
599                } else {
600                    false
601                }
602            }
603            MfaChallengeData::Sms { code, .. } => code == response,
604            MfaChallengeData::Email { code, .. } => code == response,
605            MfaChallengeData::BackupCodes { .. } => {
606                let user_methods = self
607                    .storage
608                    .get_user_mfa_methods(&challenge.user_id)
609                    .await?;
610                if let Some(method) = user_methods
611                    .iter()
612                    .find(|m| m.method_type == MfaMethodType::BackupCodes)
613                {
614                    if let MfaMethodData::BackupCodes { codes, .. } = &method.method_data {
615                        BackupCodesProvider::verify_code(codes, response)
616                    } else {
617                        false
618                    }
619                } else {
620                    false
621                }
622            }
623            _ => false,
624        };
625
626        if success {
627            self.storage.delete_mfa_challenge(challenge_id).await?;
628            Ok(MfaVerificationResult {
629                success: true,
630                method_type: challenge.method_type,
631                remaining_attempts: None,
632                error_message: None,
633            })
634        } else {
635            let remaining = challenge.max_attempts.saturating_sub(challenge.attempts);
636            self.storage.update_mfa_challenge(&challenge).await?;
637
638            Ok(MfaVerificationResult {
639                success: false,
640                method_type: challenge.method_type,
641                remaining_attempts: Some(remaining),
642                error_message: Some("Invalid code".to_string()),
643            })
644        }
645    }
646
647    /// Check if user has MFA enabled
648    pub async fn has_mfa_enabled(&self, user_id: &str) -> Result<bool> {
649        let methods = self.storage.get_user_mfa_methods(user_id).await?;
650        Ok(methods.iter().any(|m| m.is_enabled))
651    }
652
653    /// Get user's enabled MFA methods
654    pub async fn get_enabled_methods(&self, user_id: &str) -> Result<Vec<MfaMethodType>> {
655        let methods = self.storage.get_user_mfa_methods(user_id).await?;
656        Ok(methods
657            .iter()
658            .filter(|m| m.is_enabled)
659            .map(|m| m.method_type.clone())
660            .collect())
661    }
662}
663
664/// Generate a numeric code of specified length
665fn generate_numeric_code(length: u8) -> String {
666    let mut rng = rand::rng();
667    (0..length)
668        .map(|_| rng.random_range(0..10).to_string())
669        .collect()
670}
671
/// Masks a phone number for display, keeping only its last four characters.
///
/// Numbers of four characters or fewer are fully masked. The function works on `char`s
/// rather than bytes, so non-ASCII input cannot cause a string-slicing panic.
fn mask_phone_number(phone: &str) -> String {
    let char_count = phone.chars().count();
    if char_count > 4 {
        let last_four: String = phone.chars().skip(char_count - 4).collect();
        format!("***-***-{}", last_four)
    } else {
        "***-***-****".to_string()
    }
}
680
/// Masks an email address for display.
///
/// Keeps the first character of the local part when it has more than two characters, and
/// always keeps the domain. Input without `@` is fully masked. Characters (not bytes) are
/// counted, so non-ASCII local parts cannot cause a slicing panic.
fn mask_email(email: &str) -> String {
    // The domain starts after the *last* '@'; a quoted local part may itself contain '@'.
    let Some(at_pos) = email.rfind('@') else {
        return "***@***.***".to_string();
    };
    let (local, domain) = email.split_at(at_pos);
    match local.chars().next() {
        Some(first) if local.chars().count() > 2 => format!("{first}***{domain}"),
        _ => format!("***{domain}"),
    }
}
694
#[cfg(test)]
mod tests {
    use super::*;

    /// RFC 6238 Appendix B SHA-1 seed (ASCII "12345678901234567890"), Base32-encoded.
    const RFC6238_SHA1_SECRET: &str = "GEZDGNBVGY3TQOJQGEZDGNBVGY3TQOJQ";

    /// Provider with the RFC reference parameters: 6 digits and a 30-second period.
    fn reference_provider() -> TotpProvider {
        let mut config = crate::security::TotpConfig::default();
        config.digits = 6;
        config.period = 30;
        TotpProvider::new(config)
    }

    #[test]
    fn test_totp_generation() {
        let config = crate::security::TotpConfig::default();
        let provider = TotpProvider::new(config);

        let secret = provider.generate_secret().unwrap();
        assert!(!secret.is_empty());

        let code = provider.generate_code(&secret, Some(1)).unwrap();
        assert_eq!(code.len(), 6);

        // Verify the same code
        assert!(provider.verify_code(&secret, &code, Some(1)).unwrap());

        // Verify wrong code
        assert!(!provider.verify_code(&secret, "000000", Some(1)).unwrap());
    }

    #[test]
    fn test_totp_matches_rfc_reference_values() {
        let provider = reference_provider();
        // HOTP values for counters 0 and 1 (RFC 4226 Appendix D).
        assert_eq!(provider.generate_code(RFC6238_SHA1_SECRET, Some(0)).unwrap(), "755224");
        assert_eq!(provider.generate_code(RFC6238_SHA1_SECRET, Some(1)).unwrap(), "287082");
        // T = 1111111109 s is step 37037036 -> 07081804; the leading zero must be kept.
        assert_eq!(
            provider
                .generate_code(RFC6238_SHA1_SECRET, Some(37_037_036))
                .unwrap(),
            "081804"
        );
        assert!(provider.generate_code("", Some(1)).is_err());
    }

    #[test]
    fn test_totp_verification_window() {
        let provider = reference_provider();
        // T = 59 s is step 1; steps 0 and 2 are accepted for clock skew, step 3 is not.
        assert!(provider.verify_code(RFC6238_SHA1_SECRET, "287082", Some(59)).unwrap());
        assert!(provider.verify_code(RFC6238_SHA1_SECRET, "755224", Some(59)).unwrap());
        assert!(provider.verify_code(RFC6238_SHA1_SECRET, "359152", Some(59)).unwrap());
        assert!(!provider.verify_code(RFC6238_SHA1_SECRET, "969429", Some(59)).unwrap());
        assert!(provider.verify_code("not base32!", "000000", Some(59)).is_err());
    }

    #[test]
    fn test_verify_totp_accepts_current_code() {
        let provider = reference_provider();
        let code = provider.generate_code(RFC6238_SHA1_SECRET, None).unwrap();
        assert!(provider.verify_totp(RFC6238_SHA1_SECRET, &code, 1).unwrap());
    }

    #[test]
    fn test_backup_codes() {
        let codes = BackupCodesProvider::generate_codes(5);
        assert_eq!(codes.len(), 5);

        // Every code has the `NNNN-NNNN` shape.
        for code in &codes {
            assert_eq!(code.len(), 9);
            let (left, right) = code.split_at(4);
            assert!(left.chars().all(|c| c.is_ascii_digit()));
            assert!(right.starts_with('-'));
            assert!(right[1..].chars().all(|c| c.is_ascii_digit()));
        }

        let hashed = BackupCodesProvider::hash_codes(&codes).unwrap();
        assert_eq!(hashed.len(), 5);

        // Should verify correctly
        assert!(BackupCodesProvider::verify_code(&hashed, &codes[0]));

        // Should not verify wrong code
        assert!(!BackupCodesProvider::verify_code(&hashed, "1234-5678"));
    }

    #[test]
    fn test_masking() {
        assert_eq!(mask_phone_number("+1234567890"), "***-***-7890");
        assert_eq!(mask_phone_number("1234"), "***-***-****");
        assert_eq!(mask_email("user@example.com"), "u***@example.com");
        assert_eq!(mask_email("ab@example.com"), "***@example.com");
        assert_eq!(mask_email("not-an-email"), "***@***.***");
    }
}