Skip to main content

auth_framework/security/
secure_mfa.rs

1// Secure MFA implementation with cryptographically strong code generation
2// Fixes critical security vulnerabilities in MFA code generation and validation
3
4use crate::errors::{AuthError, Result};
5use crate::methods::{MfaChallenge, MfaType};
6use crate::storage::AuthStorage;
7use base64::Engine;
8use dashmap::DashMap;
9use ring::rand::{SecureRandom, SystemRandom};
10use std::collections::HashMap;
11use std::sync::Arc;
12use std::time::{Duration, SystemTime};
13use subtle::ConstantTimeEq;
14use zeroize::ZeroizeOnDrop;
15
16/// Secure MFA code that zeros itself when dropped
17#[derive(Debug, Clone, ZeroizeOnDrop)]
18pub struct SecureMfaCode {
19    code: String,
20}
21
22impl SecureMfaCode {
23    pub fn as_str(&self) -> &str {
24        &self.code
25    }
26}
27
28/// Secure MFA service with proper cryptographic implementations
29pub struct SecureMfaService {
30    storage: Box<dyn AuthStorage>,
31    rng: SystemRandom,
32    /// Rate limiting: user_id -> (attempts, last_attempt)
33    rate_limits: Arc<DashMap<String, (u32, SystemTime)>>,
34}
35
36impl SecureMfaService {
37    pub fn new(storage: Box<dyn AuthStorage>) -> Self {
38        Self {
39            storage,
40            rng: SystemRandom::new(),
41            rate_limits: Arc::new(DashMap::new()),
42        }
43    }
44
45    /// Generate cryptographically secure MFA code
46    pub fn generate_secure_code(&self, length: usize) -> Result<SecureMfaCode> {
47        if !(4..=12).contains(&length) {
48            return Err(AuthError::validation(
49                "MFA code length must be between 4 and 12",
50            ));
51        }
52
53        let mut code = String::with_capacity(length);
54
55        // Generate each digit individually to avoid modulo bias
56        for _ in 0..length {
57            let mut byte = [0u8; 1];
58            loop {
59                self.rng.fill(&mut byte).map_err(|_| {
60                    AuthError::crypto("Failed to generate secure random bytes".to_string())
61                })?;
62
63                // Use rejection sampling for uniform distribution (0-9)
64                // Reject values >= 250 to avoid modulo bias (250 is divisible by 10)
65                if byte[0] < 250 {
66                    code.push(char::from(b'0' + (byte[0] % 10)));
67                    break;
68                }
69            }
70        }
71
72        Ok(SecureMfaCode { code })
73    }
74
75    /// Hash MFA code for secure storage using PBKDF2-HMAC-SHA256.
76    ///
77    /// Although MFA codes are short-lived, we use a proper KDF to resist
78    /// brute-force attacks on the 1M possible 6-digit code values.
79    fn hash_code(&self, code: &str, salt: &[u8]) -> Result<String> {
80        use ring::pbkdf2;
81
82        let mut out = [0u8; 32];
83        pbkdf2::derive(
84            pbkdf2::PBKDF2_HMAC_SHA256,
85            std::num::NonZeroU32::new(10_000).unwrap(),
86            salt,
87            code.as_bytes(),
88            &mut out,
89        );
90
91        Ok(base64::engine::general_purpose::STANDARD.encode(&out))
92    }
93
94    /// Generate secure salt
95    fn generate_salt(&self) -> Result<Vec<u8>> {
96        let mut salt = vec![0u8; 32];
97        self.rng
98            .fill(&mut salt)
99            .map_err(|_| AuthError::crypto("Failed to generate salt".to_string()))?;
100        Ok(salt)
101    }
102
103    /// Check rate limiting for user
104    fn check_rate_limit(&self, user_id: &str) -> Result<()> {
105        let now = SystemTime::now();
106        let window = Duration::from_secs(60); // 1 minute window
107        let max_attempts = 5; // Max 5 attempts per minute
108
109        let (attempts, last_attempt) = self
110            .rate_limits
111            .get(user_id)
112            .map(|entry| *entry.value())
113            .unwrap_or((0, now));
114
115        // Reset counter if window has passed
116        if now.duration_since(last_attempt).unwrap_or(Duration::ZERO) > window {
117            self.rate_limits.insert(user_id.to_string(), (1, now));
118            return Ok(());
119        }
120
121        if attempts >= max_attempts {
122            return Err(AuthError::rate_limit(
123                "Too many MFA attempts. Please wait.".to_string(),
124            ));
125        }
126
127        self.rate_limits
128            .insert(user_id.to_string(), (attempts + 1, now));
129        Ok(())
130    }
131
132    /// Create secure MFA challenge
133    pub async fn create_challenge(
134        &self,
135        user_id: &str,
136        mfa_type: MfaType,
137        code_length: usize,
138    ) -> Result<(String, SecureMfaCode)> {
139        // Check rate limiting
140        self.check_rate_limit(user_id)?;
141
142        // Generate secure challenge ID
143        let challenge_id = self.generate_secure_id("mfa")?;
144
145        // Generate secure code
146        let secure_code = self.generate_secure_code(code_length)?;
147
148        // Generate salt and hash the code
149        let salt = self.generate_salt()?;
150        let code_hash = self.hash_code(secure_code.as_str(), &salt)?;
151
152        // Create challenge using the canonical methods::MfaChallenge type
153        let now = chrono::Utc::now();
154        let challenge = MfaChallenge {
155            id: challenge_id.clone(),
156            user_id: user_id.to_string(),
157            mfa_type,
158            created_at: now,
159            expires_at: now + chrono::Duration::seconds(300), // 5 minutes
160            attempts: 0,
161            max_attempts: 3,
162            code_hash: Some(code_hash),
163            message: None,
164            data: HashMap::new(),
165        };
166
167        // Store challenge and salt
168        let challenge_data = serde_json::to_vec(&challenge)
169            .map_err(|e| AuthError::crypto(format!("Failed to serialize challenge: {}", e)))?;
170
171        self.storage
172            .store_kv(
173                &format!("mfa_challenge:{}", challenge_id),
174                &challenge_data,
175                Some(Duration::from_secs(300)),
176            )
177            .await?;
178
179        self.storage
180            .store_kv(
181                &format!("mfa_salt:{}", challenge_id),
182                &salt,
183                Some(Duration::from_secs(300)),
184            )
185            .await?;
186
187        tracing::info!("Created secure MFA challenge for user: {}", user_id);
188        Ok((challenge_id, secure_code))
189    }
190
191    /// Verify MFA code with constant-time comparison
192    pub async fn verify_challenge(&self, challenge_id: &str, provided_code: &str) -> Result<bool> {
193        // Always perform the full verification path to prevent timing oracles.
194        // Input format validation is deferred until after the hash comparison
195        // to avoid leaking information about valid challenge IDs via response timing.
196        let format_valid = !provided_code.is_empty()
197            && provided_code.len() <= 12
198            && provided_code.chars().all(|c| c.is_ascii_digit());
199
200        // Retrieve challenge
201        let challenge_data = self
202            .storage
203            .get_kv(&format!("mfa_challenge:{}", challenge_id))
204            .await?;
205
206        let mut challenge: MfaChallenge = match challenge_data {
207            Some(data) => serde_json::from_slice(&data)
208                .map_err(|_| AuthError::validation("Invalid challenge data"))?,
209            None => {
210                // Challenge not found — still perform dummy work to prevent timing leak
211                let dummy_salt = [0u8; 32];
212                let _ = self.hash_code("000000", &dummy_salt);
213                return Ok(false);
214            }
215        };
216
217        // Rate-limit verification attempts per user
218        self.check_rate_limit(&challenge.user_id)?;
219
220        // Check if challenge is expired
221        if chrono::Utc::now() > challenge.expires_at {
222            // Clean up expired challenge
223            self.cleanup_challenge(challenge_id).await?;
224            // Still perform dummy hash to keep constant timing
225            let dummy_salt = [0u8; 32];
226            let _ = self.hash_code("000000", &dummy_salt);
227            return Ok(false);
228        }
229
230        // Check attempt limits
231        if challenge.attempts >= challenge.max_attempts {
232            self.cleanup_challenge(challenge_id).await?;
233            let dummy_salt = [0u8; 32];
234            let _ = self.hash_code("000000", &dummy_salt);
235            return Ok(false);
236        }
237
238        // Increment attempt counter
239        challenge.attempts += 1;
240        let challenge_data = serde_json::to_vec(&challenge)
241            .map_err(|e| AuthError::crypto(format!("Failed to serialize challenge: {}", e)))?;
242        self.storage
243            .store_kv(
244                &format!("mfa_challenge:{}", challenge_id),
245                &challenge_data,
246                Some(Duration::from_secs(300)),
247            )
248            .await?;
249
250        // Retrieve salt
251        let salt = match self
252            .storage
253            .get_kv(&format!("mfa_salt:{}", challenge_id))
254            .await?
255        {
256            Some(salt) => salt,
257            None => return Ok(false),
258        };
259
260        // Hash provided code with same salt (always performed regardless of format_valid)
261        let provided_hash = self.hash_code(
262            if format_valid {
263                provided_code
264            } else {
265                "000000"
266            },
267            &salt,
268        )?;
269
270        // Constant-time comparison (code_hash is Option<String>)
271        let hash_matches = challenge.code_hash.as_ref().is_some_and(|stored_hash| {
272            stored_hash
273                .as_bytes()
274                .ct_eq(provided_hash.as_bytes())
275                .into()
276        });
277
278        // Both format AND hash must be valid
279        let is_valid = format_valid && hash_matches;
280
281        if is_valid {
282            // Clean up successful challenge
283            self.cleanup_challenge(challenge_id).await?;
284            tracing::info!(
285                "MFA challenge verified successfully for user: {}",
286                challenge.user_id
287            );
288        }
289
290        Ok(is_valid)
291    }
292
293    /// Clean up challenge data
294    async fn cleanup_challenge(&self, challenge_id: &str) -> Result<()> {
295        let _ = self
296            .storage
297            .delete_kv(&format!("mfa_challenge:{}", challenge_id))
298            .await;
299        let _ = self
300            .storage
301            .delete_kv(&format!("mfa_salt:{}", challenge_id))
302            .await;
303        Ok(())
304    }
305
306    /// Generate secure ID
307    fn generate_secure_id(&self, prefix: &str) -> Result<String> {
308        let mut bytes = vec![0u8; 16];
309        self.rng
310            .fill(&mut bytes)
311            .map_err(|_| AuthError::crypto("Failed to generate secure ID".to_string()))?;
312
313        let id = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(&bytes);
314        Ok(format!("{}_{}", prefix, id))
315    }
316
317    /// Generate cryptographically secure backup codes
318    pub fn generate_backup_codes(
319        &self,
320        count: u8,
321    ) -> Result<Vec<String>, Box<dyn std::error::Error>> {
322        let mut codes = Vec::with_capacity(count as usize);
323
324        for _ in 0..count {
325            // Generate a secure 16-character alphanumeric backup code
326            let mut code_bytes = [0u8; 10]; // 10 bytes = 80 bits of entropy
327            self.rng
328                .fill(&mut code_bytes)
329                .map_err(|_| "Failed to generate random bytes".to_string())?;
330
331            // Convert to base32 for human readability (no ambiguous characters)
332            let code = base32::encode(base32::Alphabet::Rfc4648 { padding: false }, &code_bytes);
333
334            // Format as XXXX-XXXX-XXXX-XXXX for readability
335            let formatted_code = format!(
336                "{}-{}-{}-{}",
337                &code[0..4],
338                &code[4..8],
339                &code[8..12],
340                &code[12..16]
341            );
342
343            codes.push(formatted_code);
344        }
345
346        Ok(codes)
347    }
348
349    /// Securely hash backup codes for storage
350    pub fn hash_backup_codes(
351        &self,
352        codes: &[String],
353    ) -> Result<Vec<String>, Box<dyn std::error::Error>> {
354        let mut hashed_codes = Vec::with_capacity(codes.len());
355
356        for code in codes {
357            // Use ring's PBKDF2 for secure hashing
358            let salt = self.generate_salt()?;
359            let mut hash = [0u8; 32];
360
361            ring::pbkdf2::derive(
362                ring::pbkdf2::PBKDF2_HMAC_SHA256,
363                std::num::NonZeroU32::new(100_000).expect("100_000 is non-zero"), // 100k iterations
364                &salt,
365                code.as_bytes(),
366                &mut hash,
367            );
368
369            // Store as salt:hash for verification
370            let salt_hex = hex::encode(&salt);
371            let hash_hex = hex::encode(hash);
372            hashed_codes.push(format!("{}:{}", salt_hex, hash_hex));
373        }
374
375        Ok(hashed_codes)
376    }
377
378    /// Verify a backup code against stored hashes
379    pub fn verify_backup_code(
380        &self,
381        hashed_codes: &[String],
382        provided_code: &str,
383    ) -> Result<bool, Box<dyn std::error::Error>> {
384        // Input validation
385        if provided_code.len() != 19 || provided_code.chars().filter(|&c| c == '-').count() != 3 {
386            return Ok(false);
387        }
388
389        // Remove dashes for processing
390        let clean_code = provided_code.replace("-", "");
391        if clean_code.len() != 16 || !clean_code.chars().all(|c| c.is_ascii_alphanumeric()) {
392            return Ok(false);
393        }
394
395        // Iterate ALL codes without early return to prevent timing attacks
396        // that could reveal which position the valid code occupies.
397        let mut found = false;
398
399        for hashed_code in hashed_codes {
400            let parts: Vec<&str> = hashed_code.split(':').collect();
401            if parts.len() != 2 {
402                continue;
403            }
404
405            let salt = match hex::decode(parts[0]) {
406                Ok(s) => s,
407                Err(_) => continue,
408            };
409
410            let stored_hash = match hex::decode(parts[1]) {
411                Ok(h) => h,
412                Err(_) => continue,
413            };
414
415            // Derive hash from provided code
416            let mut derived_hash = [0u8; 32];
417            ring::pbkdf2::derive(
418                ring::pbkdf2::PBKDF2_HMAC_SHA256,
419                std::num::NonZeroU32::new(100_000).expect("100_000 is non-zero"),
420                &salt,
421                provided_code.as_bytes(),
422                &mut derived_hash,
423            );
424
425            // Constant-time comparison — do NOT early-return on match to prevent
426            // leaking which position the valid code is at via timing analysis.
427            let matches: bool =
428                subtle::ConstantTimeEq::ct_eq(&stored_hash[..], &derived_hash[..]).into();
429            if matches {
430                found = true;
431            }
432        }
433
434        Ok(found)
435    }
436}
437
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testing::MockStorage;

    /// SMS challenge type used throughout; the phone number is irrelevant here.
    fn sms() -> MfaType {
        MfaType::Sms {
            phone_number: String::new(),
        }
    }

    /// Build a digit string of the same length as `code` that differs from it
    /// in every position, so it is a guaranteed-wrong but well-formed code.
    fn wrong_code_for(code: &str) -> String {
        code.chars()
            .map(|c| if c == '9' { '0' } else { char::from(c as u8 + 1) })
            .collect()
    }

    #[tokio::test]
    async fn test_secure_code_generation() {
        let storage = Box::new(MockStorage::new());
        let mfa_service = SecureMfaService::new(storage);

        let code = mfa_service.generate_secure_code(6).unwrap();
        assert_eq!(code.as_str().len(), 6);
        assert!(code.as_str().chars().all(|c| c.is_ascii_digit()));

        // Length bounds are inclusive 4..=12.
        assert!(mfa_service.generate_secure_code(4).is_ok());
        assert!(mfa_service.generate_secure_code(12).is_ok());
        assert!(mfa_service.generate_secure_code(3).is_err());
        assert!(mfa_service.generate_secure_code(13).is_err());
    }

    #[tokio::test]
    async fn test_mfa_challenge_flow() {
        let storage = Box::new(MockStorage::new());
        let mfa_service = SecureMfaService::new(storage);

        let (challenge_id, code) = mfa_service
            .create_challenge("user123", sms(), 6)
            .await
            .unwrap();

        let result = mfa_service
            .verify_challenge(&challenge_id, code.as_str())
            .await
            .unwrap();
        assert!(result);

        // Challenge should be cleaned up after successful verification
        let result2 = mfa_service
            .verify_challenge(&challenge_id, code.as_str())
            .await
            .unwrap();
        assert!(!result2);
    }

    #[tokio::test]
    async fn test_invalid_code_rejection() {
        let storage = Box::new(MockStorage::new());
        let mfa_service = SecureMfaService::new(storage);

        // Well-formed but wrong code, derived from the real one so the test can
        // never flake on the real code happening to equal a fixed guess.
        let (challenge_id, code) = mfa_service
            .create_challenge("user123", sms(), 6)
            .await
            .unwrap();
        let wrong = wrong_code_for(code.as_str());
        assert!(
            !mfa_service
                .verify_challenge(&challenge_id, &wrong)
                .await
                .unwrap()
        );
        // One wrong attempt must not burn the challenge.
        assert!(
            mfa_service
                .verify_challenge(&challenge_id, code.as_str())
                .await
                .unwrap()
        );

        // Malformed inputs: each gets a fresh challenge (and user, to stay under
        // the per-user rate limit) so the format check is what rejects it, not
        // the per-challenge attempt limit.
        for (i, malformed) in ["123abc", "", "12345678901234"].iter().enumerate() {
            let user = format!("user_malformed_{}", i);
            let (cid, _code) = mfa_service
                .create_challenge(&user, sms(), 6)
                .await
                .unwrap();
            assert!(
                !mfa_service.verify_challenge(&cid, malformed).await.unwrap(),
                "malformed code {:?} was accepted",
                malformed
            );
        }
    }

    #[tokio::test]
    async fn test_rate_limiting() {
        let storage = Box::new(MockStorage::new());
        let mfa_service = SecureMfaService::new(storage);

        // Should succeed first few times
        for _ in 0..5 {
            let result = mfa_service.create_challenge("user123", sms(), 6).await;
            assert!(result.is_ok());
        }

        // Should fail due to rate limiting
        let result = mfa_service.create_challenge("user123", sms(), 6).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_backup_code_round_trip() {
        let storage = Box::new(MockStorage::new());
        let mfa_service = SecureMfaService::new(storage);

        let codes = mfa_service.generate_backup_codes(2).unwrap();
        assert_eq!(codes.len(), 2);
        for code in &codes {
            assert_eq!(code.len(), 19);
            assert_eq!(code.matches('-').count(), 3);
        }

        let hashed = mfa_service.hash_backup_codes(&codes).unwrap();
        for code in &codes {
            assert!(mfa_service.verify_backup_code(&hashed, code).unwrap());
        }
        assert!(
            !mfa_service
                .verify_backup_code(&hashed, "AAAA-AAAA-AAAA-AAAA")
                .unwrap()
        );
        assert!(!mfa_service.verify_backup_code(&hashed, "not-a-code").unwrap());
    }
}