1use crate::errors::{AuthError, Result};
5use crate::storage::AuthStorage;
6use base64::Engine;
7use dashmap::DashMap;
8use ring::rand::{SecureRandom, SystemRandom};
9use serde::{Deserialize, Serialize};
10use std::sync::Arc;
11use std::time::{Duration, SystemTime};
12use subtle::ConstantTimeEq;
13use zeroize::ZeroizeOnDrop;
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct MfaChallenge {
18 pub challenge_id: String,
19 pub user_id: String,
20 pub challenge_type: MfaChallengeType,
21 pub code_hash: String,
22 pub created_at: SystemTime,
23 pub expires_at: SystemTime,
24 pub attempts: u32,
25 pub max_attempts: u32,
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub enum MfaChallengeType {
30 Sms,
31 Email,
32 Totp,
33}
34
35#[derive(Debug, Clone, ZeroizeOnDrop)]
37pub struct SecureMfaCode {
38 code: String,
39}
40
41impl SecureMfaCode {
42 pub fn as_str(&self) -> &str {
43 &self.code
44 }
45}
46
47pub struct SecureMfaService {
49 storage: Box<dyn AuthStorage>,
50 rng: SystemRandom,
51 rate_limits: Arc<DashMap<String, (u32, SystemTime)>>,
53}
54
55impl SecureMfaService {
56 pub fn new(storage: Box<dyn AuthStorage>) -> Self {
57 Self {
58 storage,
59 rng: SystemRandom::new(),
60 rate_limits: Arc::new(DashMap::new()),
61 }
62 }
63
64 pub fn generate_secure_code(&self, length: usize) -> Result<SecureMfaCode> {
66 if !(4..=12).contains(&length) {
67 return Err(AuthError::validation(
68 "MFA code length must be between 4 and 12",
69 ));
70 }
71
72 let mut code = String::with_capacity(length);
73
74 for _ in 0..length {
76 let mut byte = [0u8; 1];
77 loop {
78 self.rng.fill(&mut byte).map_err(|_| {
79 AuthError::crypto("Failed to generate secure random bytes".to_string())
80 })?;
81
82 let digit = byte[0] % 250; if digit < 250 {
85 code.push(char::from(b'0' + (digit % 10)));
86 break;
87 }
88 }
89 }
90
91 Ok(SecureMfaCode { code })
92 }
93
94 fn hash_code(&self, code: &str, salt: &[u8]) -> Result<String> {
96 use ring::digest;
97
98 let mut context = digest::Context::new(&digest::SHA256);
99 context.update(salt);
100 context.update(code.as_bytes());
101 let hash = context.finish();
102
103 Ok(base64::engine::general_purpose::STANDARD.encode(hash.as_ref()))
104 }
105
106 fn generate_salt(&self) -> Result<Vec<u8>> {
108 let mut salt = vec![0u8; 32];
109 self.rng
110 .fill(&mut salt)
111 .map_err(|_| AuthError::crypto("Failed to generate salt".to_string()))?;
112 Ok(salt)
113 }
114
115 fn check_rate_limit(&self, user_id: &str) -> Result<()> {
117 let now = SystemTime::now();
118 let window = Duration::from_secs(60); let max_attempts = 5; let (attempts, last_attempt) = self
122 .rate_limits
123 .get(user_id)
124 .map(|entry| *entry.value())
125 .unwrap_or((0, now));
126
127 if now.duration_since(last_attempt).unwrap_or(Duration::ZERO) > window {
129 self.rate_limits.insert(user_id.to_string(), (1, now));
130 return Ok(());
131 }
132
133 if attempts >= max_attempts {
134 return Err(AuthError::rate_limit(
135 "Too many MFA attempts. Please wait.".to_string(),
136 ));
137 }
138
139 self.rate_limits
140 .insert(user_id.to_string(), (attempts + 1, now));
141 Ok(())
142 }
143
144 pub async fn create_challenge(
146 &self,
147 user_id: &str,
148 challenge_type: MfaChallengeType,
149 code_length: usize,
150 ) -> Result<(String, SecureMfaCode)> {
151 self.check_rate_limit(user_id)?;
153
154 let challenge_id = self.generate_secure_id("mfa")?;
156
157 let secure_code = self.generate_secure_code(code_length)?;
159
160 let salt = self.generate_salt()?;
162 let code_hash = self.hash_code(secure_code.as_str(), &salt)?;
163
164 let now = SystemTime::now();
166 let challenge = MfaChallenge {
167 challenge_id: challenge_id.clone(),
168 user_id: user_id.to_string(),
169 challenge_type,
170 code_hash,
171 created_at: now,
172 expires_at: now + Duration::from_secs(300), attempts: 0,
174 max_attempts: 3,
175 };
176
177 let challenge_data = serde_json::to_vec(&challenge)
179 .map_err(|e| AuthError::crypto(format!("Failed to serialize challenge: {}", e)))?;
180
181 self.storage
182 .store_kv(
183 &format!("mfa_challenge:{}", challenge_id),
184 &challenge_data,
185 Some(Duration::from_secs(300)),
186 )
187 .await?;
188
189 self.storage
190 .store_kv(
191 &format!("mfa_salt:{}", challenge_id),
192 &salt,
193 Some(Duration::from_secs(300)),
194 )
195 .await?;
196
197 tracing::info!("Created secure MFA challenge for user: {}", user_id);
198 Ok((challenge_id, secure_code))
199 }
200
201 pub async fn verify_challenge(&self, challenge_id: &str, provided_code: &str) -> Result<bool> {
203 if provided_code.is_empty() || provided_code.len() > 12 {
205 return Ok(false);
206 }
207
208 if !provided_code.chars().all(|c| c.is_ascii_digit()) {
209 return Ok(false);
210 }
211
212 let challenge_data = self
214 .storage
215 .get_kv(&format!("mfa_challenge:{}", challenge_id))
216 .await?;
217
218 let mut challenge: MfaChallenge = match challenge_data {
219 Some(data) => serde_json::from_slice(&data)
220 .map_err(|_| AuthError::validation("Invalid challenge data"))?,
221 None => return Ok(false), };
223
224 if SystemTime::now() > challenge.expires_at {
226 self.cleanup_challenge(challenge_id).await?;
228 return Ok(false);
229 }
230
231 if challenge.attempts >= challenge.max_attempts {
233 self.cleanup_challenge(challenge_id).await?;
234 return Ok(false);
235 }
236
237 challenge.attempts += 1;
239 let challenge_data = serde_json::to_vec(&challenge)
240 .map_err(|e| AuthError::crypto(format!("Failed to serialize challenge: {}", e)))?;
241 self.storage
242 .store_kv(
243 &format!("mfa_challenge:{}", challenge_id),
244 &challenge_data,
245 Some(Duration::from_secs(300)),
246 )
247 .await?;
248
249 let salt = match self
251 .storage
252 .get_kv(&format!("mfa_salt:{}", challenge_id))
253 .await?
254 {
255 Some(salt) => salt,
256 None => return Ok(false),
257 };
258
259 let provided_hash = self.hash_code(provided_code, &salt)?;
261
262 let is_valid = challenge
264 .code_hash
265 .as_bytes()
266 .ct_eq(provided_hash.as_bytes())
267 .into();
268
269 if is_valid {
270 self.cleanup_challenge(challenge_id).await?;
272 tracing::info!(
273 "MFA challenge verified successfully for user: {}",
274 challenge.user_id
275 );
276 }
277
278 Ok(is_valid)
279 }
280
281 async fn cleanup_challenge(&self, challenge_id: &str) -> Result<()> {
283 let _ = self
284 .storage
285 .delete_kv(&format!("mfa_challenge:{}", challenge_id))
286 .await;
287 let _ = self
288 .storage
289 .delete_kv(&format!("mfa_salt:{}", challenge_id))
290 .await;
291 Ok(())
292 }
293
294 fn generate_secure_id(&self, prefix: &str) -> Result<String> {
296 let mut bytes = vec![0u8; 16];
297 self.rng
298 .fill(&mut bytes)
299 .map_err(|_| AuthError::crypto("Failed to generate secure ID".to_string()))?;
300
301 let id = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(&bytes);
302 Ok(format!("{}_{}", prefix, id))
303 }
304
305 pub fn generate_backup_codes(
307 &self,
308 count: u8,
309 ) -> Result<Vec<String>, Box<dyn std::error::Error>> {
310 let mut codes = Vec::with_capacity(count as usize);
311
312 for _ in 0..count {
313 let mut code_bytes = [0u8; 10]; self.rng.fill(&mut code_bytes)?;
316
317 let code = base32::encode(base32::Alphabet::Rfc4648 { padding: false }, &code_bytes);
319
320 let formatted_code = format!(
322 "{}-{}-{}-{}",
323 &code[0..4],
324 &code[4..8],
325 &code[8..12],
326 &code[12..16]
327 );
328
329 codes.push(formatted_code);
330 }
331
332 Ok(codes)
333 }
334
335 pub fn hash_backup_codes(
337 &self,
338 codes: &[String],
339 ) -> Result<Vec<String>, Box<dyn std::error::Error>> {
340 let mut hashed_codes = Vec::with_capacity(codes.len());
341
342 for code in codes {
343 let salt = self.generate_salt()?;
345 let mut hash = [0u8; 32];
346
347 ring::pbkdf2::derive(
348 ring::pbkdf2::PBKDF2_HMAC_SHA256,
349 std::num::NonZeroU32::new(100_000).unwrap(), &salt,
351 code.as_bytes(),
352 &mut hash,
353 );
354
355 let salt_hex = hex::encode(&salt);
357 let hash_hex = hex::encode(hash);
358 hashed_codes.push(format!("{}:{}", salt_hex, hash_hex));
359 }
360
361 Ok(hashed_codes)
362 }
363
364 pub fn verify_backup_code(
366 &self,
367 hashed_codes: &[String],
368 provided_code: &str,
369 ) -> Result<bool, Box<dyn std::error::Error>> {
370 if provided_code.len() != 19 || provided_code.chars().filter(|&c| c == '-').count() != 3 {
372 return Ok(false);
373 }
374
375 let clean_code = provided_code.replace("-", "");
377 if clean_code.len() != 16 || !clean_code.chars().all(|c| c.is_ascii_alphanumeric()) {
378 return Ok(false);
379 }
380
381 for hashed_code in hashed_codes {
382 let parts: Vec<&str> = hashed_code.split(':').collect();
383 if parts.len() != 2 {
384 continue;
385 }
386
387 let salt = match hex::decode(parts[0]) {
388 Ok(s) => s,
389 Err(_) => continue,
390 };
391
392 let stored_hash = match hex::decode(parts[1]) {
393 Ok(h) => h,
394 Err(_) => continue,
395 };
396
397 let mut derived_hash = [0u8; 32];
399 ring::pbkdf2::derive(
400 ring::pbkdf2::PBKDF2_HMAC_SHA256,
401 std::num::NonZeroU32::new(100_000).unwrap(),
402 &salt,
403 provided_code.as_bytes(),
404 &mut derived_hash,
405 );
406
407 if subtle::ConstantTimeEq::ct_eq(&stored_hash[..], &derived_hash[..]).into() {
409 return Ok(true);
410 }
411 }
412
413 Ok(false)
414 }
415}
416
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testing::MockStorage;

    /// Builds a service backed by a fresh, empty mock store.
    fn fresh_service() -> SecureMfaService {
        SecureMfaService::new(Box::new(MockStorage::new()))
    }

    #[tokio::test]
    async fn test_secure_code_generation() {
        let service = fresh_service();

        let generated = service.generate_secure_code(6).unwrap();
        let digits = generated.as_str();
        assert_eq!(digits.len(), 6);
        assert!(digits.bytes().all(|b| b.is_ascii_digit()));
    }

    #[tokio::test]
    async fn test_mfa_challenge_flow() {
        let service = fresh_service();

        let (id, code) = service
            .create_challenge("user123", MfaChallengeType::Sms, 6)
            .await
            .unwrap();

        let first = service.verify_challenge(&id, code.as_str()).await.unwrap();
        assert!(first, "the issued code should verify");

        let replay = service.verify_challenge(&id, code.as_str()).await.unwrap();
        assert!(!replay, "a verified challenge must not be reusable");
    }

    #[tokio::test]
    async fn test_invalid_code_rejection() {
        let service = fresh_service();

        let (id, _code) = service
            .create_challenge("user123", MfaChallengeType::Sms, 6)
            .await
            .unwrap();

        // NOTE: "000000" is a well-formed code, so this check would fail in the
        // roughly one-in-a-million case where it is the generated code.
        for bad in ["000000", "123abc", "", "12345678901234"] {
            let accepted = service.verify_challenge(&id, bad).await.unwrap();
            assert!(!accepted, "unexpectedly accepted {:?}", bad);
        }
    }

    #[tokio::test]
    async fn test_rate_limiting() {
        let service = fresh_service();

        for attempt in 1..=5 {
            let outcome = service
                .create_challenge("user123", MfaChallengeType::Sms, 6)
                .await;
            assert!(outcome.is_ok(), "attempt {} should be allowed", attempt);
        }

        let blocked = service
            .create_challenge("user123", MfaChallengeType::Sms, 6)
            .await;
        assert!(blocked.is_err());
    }
}