elif_auth/providers/
mfa.rs

//! Multi-factor authentication provider.
//!
//! Provides TOTP (Time-based One-Time Password) verification and single-use
//! backup codes as a second factor in the authentication flow.
5
6use crate::{AuthError, AuthResult};
7use chrono::{DateTime, Utc};
8use rand::{thread_rng, Rng};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use uuid::Uuid;
12
13#[cfg(feature = "mfa")]
14use base32;
15#[cfg(feature = "mfa")]
16use qrcode::{render::unicode, QrCode};
17#[cfg(feature = "mfa")]
18use totp_lite::{totp, Sha1};
19#[cfg(feature = "mfa")]
20use urlencoding;
21
/// Configuration for [`MfaProvider`].
///
/// All fields are public, so a config can be built with struct-update syntax,
/// e.g. `MfaConfig { issuer: "my-app".into(), ..Default::default() }`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MfaConfig {
    /// TOTP time step in seconds (usually 30). Also advertised to
    /// authenticator apps as the `period` parameter of the provisioning URI.
    pub time_step: u64,
    /// Number of extra time steps accepted on *each side* of the current one
    /// to tolerate clock drift. `1` accepts the previous, current and next
    /// step, i.e. `2 * window_tolerance + 1` steps in total.
    pub window_tolerance: u8,
    /// Length in bytes of the randomly generated TOTP secret
    /// (20 is the conventional size for HMAC-SHA-1).
    pub secret_length: usize,
    /// Issuer name embedded in generated `otpauth://` URIs.
    pub issuer: String,
    /// Number of one-time backup codes generated per user.
    pub backup_codes_count: usize,
}
36
37impl Default for MfaConfig {
38    fn default() -> Self {
39        Self {
40            time_step: 30,
41            window_tolerance: 1,
42            secret_length: 20,
43            issuer: "elif.rs".to_string(),
44            backup_codes_count: 10,
45        }
46    }
47}
48
49/// MFA secret containing TOTP secret and metadata
50#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct MfaSecret {
52    /// User ID this secret belongs to
53    pub user_id: Uuid,
54    /// Base32-encoded TOTP secret
55    pub secret: String,
56    /// Backup codes (hashed)
57    pub backup_codes: Vec<String>,
58    /// Used backup codes (to prevent reuse)
59    pub used_backup_codes: Vec<String>,
60    /// MFA setup completion timestamp
61    pub setup_completed_at: Option<DateTime<Utc>>,
62    /// Last successful verification timestamp
63    pub last_verified_at: Option<DateTime<Utc>>,
64    /// Creation timestamp
65    pub created_at: DateTime<Utc>,
66}
67
68/// MFA setup information for user enrollment
69#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct MfaSetup {
71    /// QR code as text/ASCII art for terminal display
72    pub qr_code: String,
73    /// Manual entry key (base32 secret)
74    pub manual_key: String,
75    /// TOTP URI for QR code
76    pub totp_uri: String,
77    /// Generated backup codes (plaintext - show only once)
78    pub backup_codes: Vec<String>,
79}
80
/// Outcome of checking a user's second-factor code.
#[derive(Debug, Clone, PartialEq)]
pub enum MfaVerificationResult {
    /// A valid time-based one-time password was supplied.
    TotpSuccess,
    /// An unused backup code was supplied (and has now been consumed).
    BackupCodeSuccess,
    /// The code matched neither a TOTP window nor an unused backup code.
    Failed,
    /// The user has no MFA secret enrolled.
    NotSetup,
}
93
94/// Multi-factor authentication provider
95pub struct MfaProvider {
96    config: MfaConfig,
97    #[allow(dead_code)]
98    secrets: HashMap<Uuid, MfaSecret>,
99}
100
101impl MfaProvider {
102    /// Create a new MFA provider with default configuration
103    pub fn new() -> AuthResult<Self> {
104        Self::with_config(MfaConfig::default())
105    }
106
107    /// Create a new MFA provider with custom configuration
108    pub fn with_config(config: MfaConfig) -> AuthResult<Self> {
109        Ok(Self {
110            config,
111            secrets: HashMap::new(),
112        })
113    }
114
115    /// Generate a new MFA setup for a user
116    pub fn generate_setup(&self, _user_id: Uuid, username: &str) -> AuthResult<MfaSetup> {
117        #[cfg(not(feature = "mfa"))]
118        {
119            let _ = (_user_id, username);
120            return Err(AuthError::generic_error(
121                "MFA feature not enabled - compile with 'mfa' feature",
122            ));
123        }
124
125        #[cfg(feature = "mfa")]
126        {
127            // Generate random secret
128            let secret_bytes = self.generate_secret();
129            let secret_base32 =
130                base32::encode(base32::Alphabet::Rfc4648 { padding: true }, &secret_bytes);
131
132            // Generate backup codes
133            let backup_codes = self.generate_backup_codes();
134
135            // Create TOTP URI
136            let totp_uri = format!(
137                "otpauth://totp/{}:{}?secret={}&issuer={}&digits=6&period={}",
138                urlencoding::encode(&self.config.issuer),
139                urlencoding::encode(username),
140                secret_base32,
141                urlencoding::encode(&self.config.issuer),
142                self.config.time_step
143            );
144
145            // Generate QR code
146            let qr_code = self.generate_qr_code(&totp_uri)?;
147
148            Ok(MfaSetup {
149                qr_code,
150                manual_key: secret_base32,
151                totp_uri,
152                backup_codes,
153            })
154        }
155    }
156
157    /// Complete MFA setup for a user by verifying the first TOTP code
158    pub fn complete_setup(
159        &mut self,
160        user_id: Uuid,
161        setup: &MfaSetup,
162        totp_code: &str,
163    ) -> AuthResult<()> {
164        #[cfg(not(feature = "mfa"))]
165        {
166            let _ = (user_id, setup, totp_code);
167            return Err(AuthError::generic_error(
168                "MFA feature not enabled - compile with 'mfa' feature",
169            ));
170        }
171
172        #[cfg(feature = "mfa")]
173        {
174            // Verify the TOTP code
175            if !self.verify_totp_code(&setup.manual_key, totp_code)? {
176                return Err(AuthError::invalid_credentials("Invalid TOTP code"));
177            }
178
179            // Hash backup codes for storage
180            let hashed_backup_codes = setup
181                .backup_codes
182                .iter()
183                .map(|code| self.hash_backup_code(code))
184                .collect::<Result<Vec<_>, _>>()?;
185
186            // Store the secret
187            let secret = MfaSecret {
188                user_id,
189                secret: setup.manual_key.clone(),
190                backup_codes: hashed_backup_codes,
191                used_backup_codes: Vec::new(),
192                setup_completed_at: Some(Utc::now()),
193                last_verified_at: Some(Utc::now()),
194                created_at: Utc::now(),
195            };
196
197            self.secrets.insert(user_id, secret);
198            Ok(())
199        }
200    }
201
202    /// Verify MFA for a user
203    pub fn verify_mfa(&mut self, user_id: Uuid, code: &str) -> AuthResult<MfaVerificationResult> {
204        #[cfg(not(feature = "mfa"))]
205        {
206            let _ = (user_id, code);
207            return Ok(MfaVerificationResult::NotSetup);
208        }
209
210        #[cfg(feature = "mfa")]
211        {
212            if !self.secrets.contains_key(&user_id) {
213                return Ok(MfaVerificationResult::NotSetup);
214            }
215
216            // Try TOTP verification first - use a separate scope to avoid borrowing conflicts
217            let totp_secret = {
218                let secret = self.secrets.get(&user_id).unwrap();
219                secret.secret.clone()
220            };
221
222            if self.verify_totp_code(&totp_secret, code)? {
223                // Update last verified time after verification succeeds
224                if let Some(secret) = self.secrets.get_mut(&user_id) {
225                    secret.last_verified_at = Some(Utc::now());
226                }
227                return Ok(MfaVerificationResult::TotpSuccess);
228            }
229
230            // Try backup code verification - need to separate the operations to avoid borrowing conflicts
231            let backup_code_valid = if let Some(secret) = self.secrets.get_mut(&user_id) {
232                // Clone the necessary data to avoid borrowing conflicts
233                let backup_codes = secret.backup_codes.clone();
234                let used_codes = secret.used_backup_codes.clone();
235
236                let code_hash = self.hash_backup_code(code)?;
237
238                // Check if this backup code exists and hasn't been used
239                backup_codes.contains(&code_hash) && !used_codes.contains(&code_hash)
240            } else {
241                false
242            };
243
244            if backup_code_valid {
245                // Compute hash before mutable borrow
246                let code_hash = self.hash_backup_code(code)?;
247
248                // Update the secret after verification
249                if let Some(secret) = self.secrets.get_mut(&user_id) {
250                    secret.used_backup_codes.push(code_hash);
251                    secret.last_verified_at = Some(Utc::now());
252                }
253                return Ok(MfaVerificationResult::BackupCodeSuccess);
254            }
255
256            Ok(MfaVerificationResult::Failed)
257        }
258    }
259
260    /// Check if MFA is enabled for a user
261    pub fn is_mfa_enabled(&self, user_id: Uuid) -> bool {
262        self.secrets
263            .get(&user_id)
264            .map(|secret| secret.setup_completed_at.is_some())
265            .unwrap_or(false)
266    }
267
268    /// Disable MFA for a user
269    pub fn disable_mfa(&mut self, user_id: Uuid) -> AuthResult<bool> {
270        Ok(self.secrets.remove(&user_id).is_some())
271    }
272
273    /// Get remaining backup codes count for a user
274    pub fn get_remaining_backup_codes_count(&self, user_id: Uuid) -> AuthResult<usize> {
275        match self.secrets.get(&user_id) {
276            Some(secret) => {
277                // Use saturating_sub to avoid potential underflow if used codes exceed total
278                let remaining = secret
279                    .backup_codes
280                    .len()
281                    .saturating_sub(secret.used_backup_codes.len());
282                Ok(remaining)
283            }
284            None => Ok(0),
285        }
286    }
287
288    /// Generate new backup codes (invalidates old ones)
289    pub fn regenerate_backup_codes(&mut self, user_id: Uuid) -> AuthResult<Vec<String>> {
290        #[cfg(not(feature = "mfa"))]
291        {
292            let _ = user_id;
293            return Err(AuthError::generic_error(
294                "MFA feature not enabled - compile with 'mfa' feature",
295            ));
296        }
297
298        #[cfg(feature = "mfa")]
299        {
300            if !self.secrets.contains_key(&user_id) {
301                return Err(AuthError::not_found("MFA not setup for user"));
302            }
303
304            // Generate new backup codes
305            let new_backup_codes = self.generate_backup_codes();
306            let hashed_codes = new_backup_codes
307                .iter()
308                .map(|code| self.hash_backup_code(code))
309                .collect::<Result<Vec<_>, _>>()?;
310
311            // Replace old backup codes
312            if let Some(secret) = self.secrets.get_mut(&user_id) {
313                secret.backup_codes = hashed_codes;
314                secret.used_backup_codes.clear();
315            }
316
317            Ok(new_backup_codes)
318        }
319    }
320
321    // Private helper methods
322
323    #[cfg(feature = "mfa")]
324    fn generate_secret(&self) -> Vec<u8> {
325        let mut secret = vec![0u8; self.config.secret_length];
326        thread_rng().fill(&mut secret[..]);
327        secret
328    }
329
330    #[cfg(feature = "mfa")]
331    fn generate_backup_codes(&self) -> Vec<String> {
332        let mut codes = Vec::with_capacity(self.config.backup_codes_count);
333        let mut rng = thread_rng();
334
335        for _ in 0..self.config.backup_codes_count {
336            // Generate 8-digit backup code
337            let code = format!("{:08}", rng.gen_range(10000000..99999999));
338            codes.push(code);
339        }
340
341        codes
342    }
343
344    #[cfg(feature = "mfa")]
345    fn generate_qr_code(&self, totp_uri: &str) -> AuthResult<String> {
346        let qr_code = QrCode::new(totp_uri)
347            .map_err(|e| AuthError::generic_error(format!("Failed to generate QR code: {}", e)))?;
348
349        let qr_string = qr_code
350            .render::<unicode::Dense1x2>()
351            .dark_color(unicode::Dense1x2::Light)
352            .light_color(unicode::Dense1x2::Dark)
353            .build();
354
355        Ok(qr_string)
356    }
357
358    #[cfg(feature = "mfa")]
359    fn verify_totp_code(&self, secret_base32: &str, code: &str) -> AuthResult<bool> {
360        // Decode the base32 secret
361        let secret = base32::decode(base32::Alphabet::Rfc4648 { padding: true }, secret_base32)
362            .ok_or_else(|| AuthError::generic_error("Invalid secret format"))?;
363
364        // Validate code format (should be 6 digits)
365        if code.len() != 6 || !code.chars().all(|c| c.is_ascii_digit()) {
366            return Err(AuthError::invalid_credentials("Invalid TOTP code format"));
367        }
368
369        let current_time = Utc::now().timestamp() as u64;
370
371        // Check current time window and tolerance windows
372        for i in 0..=(self.config.window_tolerance * 2) {
373            let time_offset = (i as i64) - (self.config.window_tolerance as i64);
374            let time_window =
375                ((current_time as i64) + (time_offset * self.config.time_step as i64)) as u64;
376
377            // Generate TOTP code for this time window
378            let expected_code = totp::<Sha1>(&secret, time_window);
379
380            // The standard totp function returns 8 digits, but we need 6
381            // So we need to take the last 6 digits to match standard authenticator apps
382            let expected_code_6digits = if expected_code.len() >= 6 {
383                &expected_code[expected_code.len() - 6..]
384            } else {
385                &expected_code
386            };
387
388            if expected_code_6digits == code {
389                return Ok(true);
390            }
391        }
392
393        Ok(false)
394    }
395
396    #[cfg(feature = "mfa")]
397    pub fn verify_backup_code(&mut self, secret: &mut MfaSecret, code: &str) -> AuthResult<bool> {
398        let code_hash = self.hash_backup_code(code)?;
399
400        // Check if this backup code exists and hasn't been used
401        if secret.backup_codes.contains(&code_hash)
402            && !secret.used_backup_codes.contains(&code_hash)
403        {
404            secret.used_backup_codes.push(code_hash);
405            return Ok(true);
406        }
407
408        Ok(false)
409    }
410
411    #[cfg(feature = "mfa")]
412    pub fn hash_backup_code(&self, code: &str) -> AuthResult<String> {
413        // Simple hash for backup codes - in production, use proper hashing
414        use std::collections::hash_map::DefaultHasher;
415        use std::hash::{Hash, Hasher};
416
417        let mut hasher = DefaultHasher::new();
418        code.hash(&mut hasher);
419        Ok(format!("{:x}", hasher.finish()))
420    }
421}
422
423impl Default for MfaProvider {
424    fn default() -> Self {
425        Self::new().expect("Failed to create default MFA provider")
426    }
427}
428
#[cfg(test)]
mod tests {
    // Unit tests for the MFA provider. Tests that need real TOTP / QR support
    // are gated on the `mfa` feature. None of them await anything, so they run
    // as plain synchronous tests.
    use super::*;

    #[test]
    fn test_mfa_provider_creation() {
        assert!(MfaProvider::new().is_ok());
    }

    #[test]
    fn test_mfa_config_defaults() {
        let MfaConfig {
            time_step,
            window_tolerance,
            secret_length,
            issuer,
            backup_codes_count,
        } = MfaConfig::default();

        assert_eq!(time_step, 30);
        assert_eq!(window_tolerance, 1);
        assert_eq!(secret_length, 20);
        assert_eq!(issuer, "elif.rs");
        assert_eq!(backup_codes_count, 10);
    }

    #[test]
    fn test_mfa_provider_with_custom_config() {
        let custom = MfaConfig {
            time_step: 60,
            window_tolerance: 2,
            secret_length: 32,
            issuer: "test-app".to_string(),
            backup_codes_count: 12,
        };

        let built = MfaProvider::with_config(custom);
        assert!(built.is_ok());

        let provider = built.unwrap();
        assert_eq!(provider.config.time_step, 60);
        assert_eq!(provider.config.issuer, "test-app");
    }

    #[cfg(feature = "mfa")]
    #[test]
    fn test_mfa_setup_generation() {
        let provider = MfaProvider::new().unwrap();
        let username = "testuser";

        let built = provider.generate_setup(Uuid::new_v4(), username);
        assert!(built.is_ok());

        let MfaSetup {
            qr_code,
            manual_key,
            totp_uri,
            backup_codes,
        } = built.unwrap();
        assert!(!qr_code.is_empty());
        assert!(!manual_key.is_empty());
        assert!(totp_uri.contains("otpauth://totp/"));
        assert!(totp_uri.contains(username));
        assert_eq!(backup_codes.len(), 10);
    }

    #[test]
    fn test_mfa_not_enabled_by_default() {
        let mut provider = MfaProvider::new().unwrap();
        let user = Uuid::new_v4();

        assert!(!provider.is_mfa_enabled(user));

        let outcome = provider.verify_mfa(user, "123456");
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), MfaVerificationResult::NotSetup);
    }

    #[test]
    fn test_mfa_disable() {
        let mut provider = MfaProvider::new().unwrap();

        // A user who never enrolled has nothing to remove.
        let removed = provider.disable_mfa(Uuid::new_v4()).unwrap();
        assert!(!removed);

        // TODO: Test with actual MFA setup once setup is working
    }

    #[test]
    fn test_backup_codes_count() {
        let provider = MfaProvider::new().unwrap();

        let remaining = provider.get_remaining_backup_codes_count(Uuid::new_v4());
        assert!(remaining.is_ok());
        assert_eq!(remaining.unwrap(), 0);
    }

    #[test]
    fn test_backup_codes_count_underflow() {
        let mut provider = MfaProvider::new().unwrap();
        let user = Uuid::new_v4();

        // Corrupted state: more codes marked used than were ever issued.
        let corrupted = MfaSecret {
            user_id: user,
            secret: "dummy".to_string(),
            backup_codes: Vec::new(),
            used_backup_codes: vec!["used".to_string()],
            setup_completed_at: None,
            last_verified_at: None,
            created_at: Utc::now(),
        };
        provider.secrets.insert(user, corrupted);

        let remaining = provider.get_remaining_backup_codes_count(user);
        assert!(remaining.is_ok());
        assert_eq!(remaining.unwrap(), 0);
    }

    #[cfg(not(feature = "mfa"))]
    #[test]
    fn test_mfa_disabled_error_messages() {
        let provider = MfaProvider::new().unwrap();

        let err = provider
            .generate_setup(Uuid::new_v4(), "testuser")
            .expect_err("setup must fail without the mfa feature");
        assert!(err.to_string().contains("MFA feature not enabled"));
    }

    #[cfg(feature = "mfa")]
    #[test]
    fn test_backup_code_generation() {
        let provider = MfaProvider::new().unwrap();
        let codes = provider.generate_backup_codes();

        assert_eq!(codes.len(), 10);
        for code in codes.iter() {
            assert_eq!(code.len(), 8);
            assert!(code.bytes().all(|b| b.is_ascii_digit()));
        }

        // Every code must be distinct.
        let distinct: std::collections::HashSet<&String> = codes.iter().collect();
        assert_eq!(distinct.len(), codes.len());
    }

    #[cfg(feature = "mfa")]
    #[test]
    fn test_secret_generation() {
        let provider = MfaProvider::new().unwrap();
        let (first, second) = (provider.generate_secret(), provider.generate_secret());

        assert_eq!(first.len(), 20);
        assert_eq!(second.len(), 20);
        assert_ne!(first, second); // two random draws should differ
    }

    #[cfg(feature = "mfa")]
    #[test]
    fn test_totp_uri_format() {
        let provider = MfaProvider::new().unwrap();
        let uri = provider
            .generate_setup(Uuid::new_v4(), "test@example.com")
            .unwrap()
            .totp_uri;

        assert!(uri.starts_with("otpauth://totp/"));
        assert!(uri.contains("elif.rs"));
        assert!(uri.contains("test%40example.com")); // URL encoded
        assert!(uri.contains("digits=6"));
        assert!(uri.contains("period=30"));
    }
}
599        assert!(setup.totp_uri.contains("period=30"));
600    }
601}