// auth_framework/security/secure_mfa.rs

// Secure MFA implementation with cryptographically strong code generation
// Fixes critical security vulnerabilities in MFA code generation and validation

4use crate::errors::{AuthError, Result};
5use crate::storage::AuthStorage;
6use base64::Engine;
7use dashmap::DashMap;
8use ring::rand::{SecureRandom, SystemRandom};
9use serde::{Deserialize, Serialize};
10use std::sync::Arc;
11use std::time::{Duration, SystemTime};
12use subtle::ConstantTimeEq;
13use zeroize::ZeroizeOnDrop;
14
15/// Secure MFA challenge with proper entropy and expiration
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct MfaChallenge {
18    pub challenge_id: String,
19    pub user_id: String,
20    pub challenge_type: MfaChallengeType,
21    pub code_hash: String,
22    pub created_at: SystemTime,
23    pub expires_at: SystemTime,
24    pub attempts: u32,
25    pub max_attempts: u32,
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub enum MfaChallengeType {
30    Sms,
31    Email,
32    Totp,
33}
34
35/// Secure MFA code that zeros itself when dropped
36#[derive(Debug, Clone, ZeroizeOnDrop)]
37pub struct SecureMfaCode {
38    code: String,
39}
40
41impl SecureMfaCode {
42    pub fn as_str(&self) -> &str {
43        &self.code
44    }
45}
46
47/// Secure MFA service with proper cryptographic implementations
48pub struct SecureMfaService {
49    storage: Box<dyn AuthStorage>,
50    rng: SystemRandom,
51    /// Rate limiting: user_id -> (attempts, last_attempt)
52    rate_limits: Arc<DashMap<String, (u32, SystemTime)>>,
53}
54
55impl SecureMfaService {
56    pub fn new(storage: Box<dyn AuthStorage>) -> Self {
57        Self {
58            storage,
59            rng: SystemRandom::new(),
60            rate_limits: Arc::new(DashMap::new()),
61        }
62    }
63
64    /// Generate cryptographically secure MFA code
65    pub fn generate_secure_code(&self, length: usize) -> Result<SecureMfaCode> {
66        if !(4..=12).contains(&length) {
67            return Err(AuthError::validation(
68                "MFA code length must be between 4 and 12",
69            ));
70        }
71
72        let mut code = String::with_capacity(length);
73
74        // Generate each digit individually to avoid modulo bias
75        for _ in 0..length {
76            let mut byte = [0u8; 1];
77            loop {
78                self.rng.fill(&mut byte).map_err(|_| {
79                    AuthError::crypto("Failed to generate secure random bytes".to_string())
80                })?;
81
82                // Use rejection sampling for uniform distribution (0-9)
83                let digit = byte[0] % 250; // Use 250 to avoid bias (250 is divisible by 10)
84                if digit < 250 {
85                    code.push(char::from(b'0' + (digit % 10)));
86                    break;
87                }
88            }
89        }
90
91        Ok(SecureMfaCode { code })
92    }
93
94    /// Hash MFA code for secure storage
95    fn hash_code(&self, code: &str, salt: &[u8]) -> Result<String> {
96        use ring::digest;
97
98        let mut context = digest::Context::new(&digest::SHA256);
99        context.update(salt);
100        context.update(code.as_bytes());
101        let hash = context.finish();
102
103        Ok(base64::engine::general_purpose::STANDARD.encode(hash.as_ref()))
104    }
105
106    /// Generate secure salt
107    fn generate_salt(&self) -> Result<Vec<u8>> {
108        let mut salt = vec![0u8; 32];
109        self.rng
110            .fill(&mut salt)
111            .map_err(|_| AuthError::crypto("Failed to generate salt".to_string()))?;
112        Ok(salt)
113    }
114
115    /// Check rate limiting for user
116    fn check_rate_limit(&self, user_id: &str) -> Result<()> {
117        let now = SystemTime::now();
118        let window = Duration::from_secs(60); // 1 minute window
119        let max_attempts = 5; // Max 5 attempts per minute
120
121        let (attempts, last_attempt) = self
122            .rate_limits
123            .get(user_id)
124            .map(|entry| *entry.value())
125            .unwrap_or((0, now));
126
127        // Reset counter if window has passed
128        if now.duration_since(last_attempt).unwrap_or(Duration::ZERO) > window {
129            self.rate_limits.insert(user_id.to_string(), (1, now));
130            return Ok(());
131        }
132
133        if attempts >= max_attempts {
134            return Err(AuthError::rate_limit(
135                "Too many MFA attempts. Please wait.".to_string(),
136            ));
137        }
138
139        self.rate_limits
140            .insert(user_id.to_string(), (attempts + 1, now));
141        Ok(())
142    }
143
144    /// Create secure MFA challenge
145    pub async fn create_challenge(
146        &self,
147        user_id: &str,
148        challenge_type: MfaChallengeType,
149        code_length: usize,
150    ) -> Result<(String, SecureMfaCode)> {
151        // Check rate limiting
152        self.check_rate_limit(user_id)?;
153
154        // Generate secure challenge ID
155        let challenge_id = self.generate_secure_id("mfa")?;
156
157        // Generate secure code
158        let secure_code = self.generate_secure_code(code_length)?;
159
160        // Generate salt and hash the code
161        let salt = self.generate_salt()?;
162        let code_hash = self.hash_code(secure_code.as_str(), &salt)?;
163
164        // Create challenge
165        let now = SystemTime::now();
166        let challenge = MfaChallenge {
167            challenge_id: challenge_id.clone(),
168            user_id: user_id.to_string(),
169            challenge_type,
170            code_hash,
171            created_at: now,
172            expires_at: now + Duration::from_secs(300), // 5 minutes
173            attempts: 0,
174            max_attempts: 3,
175        };
176
177        // Store challenge and salt
178        let challenge_data = serde_json::to_vec(&challenge)
179            .map_err(|e| AuthError::crypto(format!("Failed to serialize challenge: {}", e)))?;
180
181        self.storage
182            .store_kv(
183                &format!("mfa_challenge:{}", challenge_id),
184                &challenge_data,
185                Some(Duration::from_secs(300)),
186            )
187            .await?;
188
189        self.storage
190            .store_kv(
191                &format!("mfa_salt:{}", challenge_id),
192                &salt,
193                Some(Duration::from_secs(300)),
194            )
195            .await?;
196
197        tracing::info!("Created secure MFA challenge for user: {}", user_id);
198        Ok((challenge_id, secure_code))
199    }
200
201    /// Verify MFA code with constant-time comparison
202    pub async fn verify_challenge(&self, challenge_id: &str, provided_code: &str) -> Result<bool> {
203        // Validate input format
204        if provided_code.is_empty() || provided_code.len() > 12 {
205            return Ok(false);
206        }
207
208        if !provided_code.chars().all(|c| c.is_ascii_digit()) {
209            return Ok(false);
210        }
211
212        // Retrieve challenge
213        let challenge_data = self
214            .storage
215            .get_kv(&format!("mfa_challenge:{}", challenge_id))
216            .await?;
217
218        let mut challenge: MfaChallenge = match challenge_data {
219            Some(data) => serde_json::from_slice(&data)
220                .map_err(|_| AuthError::validation("Invalid challenge data"))?,
221            None => return Ok(false), // Challenge not found or expired
222        };
223
224        // Check if challenge is expired
225        if SystemTime::now() > challenge.expires_at {
226            // Clean up expired challenge
227            self.cleanup_challenge(challenge_id).await?;
228            return Ok(false);
229        }
230
231        // Check attempt limits
232        if challenge.attempts >= challenge.max_attempts {
233            self.cleanup_challenge(challenge_id).await?;
234            return Ok(false);
235        }
236
237        // Increment attempt counter
238        challenge.attempts += 1;
239        let challenge_data = serde_json::to_vec(&challenge)
240            .map_err(|e| AuthError::crypto(format!("Failed to serialize challenge: {}", e)))?;
241        self.storage
242            .store_kv(
243                &format!("mfa_challenge:{}", challenge_id),
244                &challenge_data,
245                Some(Duration::from_secs(300)),
246            )
247            .await?;
248
249        // Retrieve salt
250        let salt = match self
251            .storage
252            .get_kv(&format!("mfa_salt:{}", challenge_id))
253            .await?
254        {
255            Some(salt) => salt,
256            None => return Ok(false),
257        };
258
259        // Hash provided code with same salt
260        let provided_hash = self.hash_code(provided_code, &salt)?;
261
262        // Constant-time comparison
263        let is_valid = challenge
264            .code_hash
265            .as_bytes()
266            .ct_eq(provided_hash.as_bytes())
267            .into();
268
269        if is_valid {
270            // Clean up successful challenge
271            self.cleanup_challenge(challenge_id).await?;
272            tracing::info!(
273                "MFA challenge verified successfully for user: {}",
274                challenge.user_id
275            );
276        }
277
278        Ok(is_valid)
279    }
280
281    /// Clean up challenge data
282    async fn cleanup_challenge(&self, challenge_id: &str) -> Result<()> {
283        let _ = self
284            .storage
285            .delete_kv(&format!("mfa_challenge:{}", challenge_id))
286            .await;
287        let _ = self
288            .storage
289            .delete_kv(&format!("mfa_salt:{}", challenge_id))
290            .await;
291        Ok(())
292    }
293
294    /// Generate secure ID
295    fn generate_secure_id(&self, prefix: &str) -> Result<String> {
296        let mut bytes = vec![0u8; 16];
297        self.rng
298            .fill(&mut bytes)
299            .map_err(|_| AuthError::crypto("Failed to generate secure ID".to_string()))?;
300
301        let id = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(&bytes);
302        Ok(format!("{}_{}", prefix, id))
303    }
304
305    /// Generate cryptographically secure backup codes
306    pub fn generate_backup_codes(
307        &self,
308        count: u8,
309    ) -> Result<Vec<String>, Box<dyn std::error::Error>> {
310        let mut codes = Vec::with_capacity(count as usize);
311
312        for _ in 0..count {
313            // Generate a secure 16-character alphanumeric backup code
314            let mut code_bytes = [0u8; 10]; // 10 bytes = 80 bits of entropy
315            self.rng.fill(&mut code_bytes)?;
316
317            // Convert to base32 for human readability (no ambiguous characters)
318            let code = base32::encode(base32::Alphabet::Rfc4648 { padding: false }, &code_bytes);
319
320            // Format as XXXX-XXXX-XXXX-XXXX for readability
321            let formatted_code = format!(
322                "{}-{}-{}-{}",
323                &code[0..4],
324                &code[4..8],
325                &code[8..12],
326                &code[12..16]
327            );
328
329            codes.push(formatted_code);
330        }
331
332        Ok(codes)
333    }
334
335    /// Securely hash backup codes for storage
336    pub fn hash_backup_codes(
337        &self,
338        codes: &[String],
339    ) -> Result<Vec<String>, Box<dyn std::error::Error>> {
340        let mut hashed_codes = Vec::with_capacity(codes.len());
341
342        for code in codes {
343            // Use ring's PBKDF2 for secure hashing
344            let salt = self.generate_salt()?;
345            let mut hash = [0u8; 32];
346
347            ring::pbkdf2::derive(
348                ring::pbkdf2::PBKDF2_HMAC_SHA256,
349                std::num::NonZeroU32::new(100_000).unwrap(), // 100k iterations
350                &salt,
351                code.as_bytes(),
352                &mut hash,
353            );
354
355            // Store as salt:hash for verification
356            let salt_hex = hex::encode(&salt);
357            let hash_hex = hex::encode(hash);
358            hashed_codes.push(format!("{}:{}", salt_hex, hash_hex));
359        }
360
361        Ok(hashed_codes)
362    }
363
364    /// Verify a backup code against stored hashes
365    pub fn verify_backup_code(
366        &self,
367        hashed_codes: &[String],
368        provided_code: &str,
369    ) -> Result<bool, Box<dyn std::error::Error>> {
370        // Input validation
371        if provided_code.len() != 19 || provided_code.chars().filter(|&c| c == '-').count() != 3 {
372            return Ok(false);
373        }
374
375        // Remove dashes for processing
376        let clean_code = provided_code.replace("-", "");
377        if clean_code.len() != 16 || !clean_code.chars().all(|c| c.is_ascii_alphanumeric()) {
378            return Ok(false);
379        }
380
381        for hashed_code in hashed_codes {
382            let parts: Vec<&str> = hashed_code.split(':').collect();
383            if parts.len() != 2 {
384                continue;
385            }
386
387            let salt = match hex::decode(parts[0]) {
388                Ok(s) => s,
389                Err(_) => continue,
390            };
391
392            let stored_hash = match hex::decode(parts[1]) {
393                Ok(h) => h,
394                Err(_) => continue,
395            };
396
397            // Derive hash from provided code
398            let mut derived_hash = [0u8; 32];
399            ring::pbkdf2::derive(
400                ring::pbkdf2::PBKDF2_HMAC_SHA256,
401                std::num::NonZeroU32::new(100_000).unwrap(),
402                &salt,
403                provided_code.as_bytes(),
404                &mut derived_hash,
405            );
406
407            // Constant-time comparison
408            if subtle::ConstantTimeEq::ct_eq(&stored_hash[..], &derived_hash[..]).into() {
409                return Ok(true);
410            }
411        }
412
413        Ok(false)
414    }
415}
416
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testing::MockStorage;

    fn new_service() -> SecureMfaService {
        SecureMfaService::new(Box::new(MockStorage::new()))
    }

    /// Same-length code that differs from `code` in every digit.
    ///
    /// Unlike a fixed guess such as "000000", it can never equal the real code.
    fn wrong_code_for(code: &str) -> String {
        code.bytes()
            .map(|b| char::from(b'0' + (b - b'0' + 1) % 10))
            .collect()
    }

    #[tokio::test]
    async fn test_secure_code_generation() {
        let mfa_service = new_service();

        let code = mfa_service.generate_secure_code(6).unwrap();
        assert_eq!(code.as_str().len(), 6);
        assert!(code.as_str().chars().all(|c| c.is_ascii_digit()));
    }

    #[tokio::test]
    async fn test_code_length_bounds() {
        let mfa_service = new_service();

        assert!(mfa_service.generate_secure_code(3).is_err());
        assert!(mfa_service.generate_secure_code(13).is_err());
        assert_eq!(mfa_service.generate_secure_code(4).unwrap().as_str().len(), 4);
        assert_eq!(mfa_service.generate_secure_code(12).unwrap().as_str().len(), 12);
    }

    #[tokio::test]
    async fn test_mfa_challenge_flow() {
        let mfa_service = new_service();

        let (challenge_id, code) = mfa_service
            .create_challenge("user123", MfaChallengeType::Sms, 6)
            .await
            .unwrap();

        let result = mfa_service
            .verify_challenge(&challenge_id, code.as_str())
            .await
            .unwrap();
        assert!(result);

        // Challenge should be cleaned up after successful verification
        let result2 = mfa_service
            .verify_challenge(&challenge_id, code.as_str())
            .await
            .unwrap();
        assert!(!result2);
    }

    #[tokio::test]
    async fn test_invalid_code_rejection() {
        let mfa_service = new_service();

        let (challenge_id, code) = mfa_service
            .create_challenge("user123", MfaChallengeType::Sms, 6)
            .await
            .unwrap();
        let wrong = wrong_code_for(code.as_str());

        for candidate in [wrong.as_str(), "123abc", "", "12345678901234"] {
            assert!(
                !mfa_service
                    .verify_challenge(&challenge_id, candidate)
                    .await
                    .unwrap(),
                "unexpectedly accepted {:?}",
                candidate
            );
        }
    }

    #[tokio::test]
    async fn test_attempt_limit_exhausts_challenge() {
        let mfa_service = new_service();

        let (challenge_id, code) = mfa_service
            .create_challenge("user123", MfaChallengeType::Email, 6)
            .await
            .unwrap();
        let wrong = wrong_code_for(code.as_str());

        for _ in 0..3 {
            assert!(!mfa_service
                .verify_challenge(&challenge_id, &wrong)
                .await
                .unwrap());
        }

        // The attempt budget is spent, so even the correct code is rejected.
        assert!(!mfa_service
            .verify_challenge(&challenge_id, code.as_str())
            .await
            .unwrap());
    }

    #[tokio::test]
    async fn test_rate_limiting() {
        let mfa_service = new_service();

        for _ in 0..5 {
            let result = mfa_service
                .create_challenge("user123", MfaChallengeType::Sms, 6)
                .await;
            assert!(result.is_ok());
        }

        let result = mfa_service
            .create_challenge("user123", MfaChallengeType::Sms, 6)
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_backup_code_round_trip() {
        let mfa_service = new_service();

        let codes = mfa_service.generate_backup_codes(2).unwrap();
        assert_eq!(codes.len(), 2);
        for code in &codes {
            assert_eq!(code.len(), 19);
            assert_eq!(code.matches('-').count(), 3);
        }

        let hashed = mfa_service.hash_backup_codes(&codes).unwrap();
        assert_eq!(hashed.len(), 2);
        assert!(mfa_service.verify_backup_code(&hashed, &codes[1]).unwrap());
        assert!(!mfa_service
            .verify_backup_code(&hashed, "not-a-backup-code")
            .unwrap());
    }
}