1use crate::errors::{AuthError, Result};
5use crate::methods::{MfaChallenge, MfaType};
6use crate::storage::AuthStorage;
7use base64::Engine;
8use dashmap::DashMap;
9use ring::rand::{SecureRandom, SystemRandom};
10use std::collections::HashMap;
11use std::sync::Arc;
12use std::time::{Duration, SystemTime};
13use subtle::ConstantTimeEq;
14use zeroize::ZeroizeOnDrop;
15
16#[derive(Debug, Clone, ZeroizeOnDrop)]
18pub struct SecureMfaCode {
19 code: String,
20}
21
22impl SecureMfaCode {
23 pub fn as_str(&self) -> &str {
24 &self.code
25 }
26}
27
/// MFA service: issues and verifies one-time challenges, produces backup
/// codes, and rate-limits attempts per user.
pub struct SecureMfaService {
    /// Key-value backend where challenges and their salts are persisted.
    storage: Box<dyn AuthStorage>,
    /// Randomness source for codes, salts, and identifiers.
    rng: SystemRandom,
    /// Per-user rate-limit state: attempt count and the time of the most
    /// recent counted attempt, keyed by user id.
    rate_limits: Arc<DashMap<String, (u32, SystemTime)>>,
}
35
36impl SecureMfaService {
37 pub fn new(storage: Box<dyn AuthStorage>) -> Self {
38 Self {
39 storage,
40 rng: SystemRandom::new(),
41 rate_limits: Arc::new(DashMap::new()),
42 }
43 }
44
45 pub fn generate_secure_code(&self, length: usize) -> Result<SecureMfaCode> {
47 if !(4..=12).contains(&length) {
48 return Err(AuthError::validation(
49 "MFA code length must be between 4 and 12",
50 ));
51 }
52
53 let mut code = String::with_capacity(length);
54
55 for _ in 0..length {
57 let mut byte = [0u8; 1];
58 loop {
59 self.rng.fill(&mut byte).map_err(|_| {
60 AuthError::crypto("Failed to generate secure random bytes".to_string())
61 })?;
62
63 if byte[0] < 250 {
66 code.push(char::from(b'0' + (byte[0] % 10)));
67 break;
68 }
69 }
70 }
71
72 Ok(SecureMfaCode { code })
73 }
74
75 fn hash_code(&self, code: &str, salt: &[u8]) -> Result<String> {
80 use ring::pbkdf2;
81
82 let mut out = [0u8; 32];
83 pbkdf2::derive(
84 pbkdf2::PBKDF2_HMAC_SHA256,
85 std::num::NonZeroU32::new(10_000).unwrap(),
86 salt,
87 code.as_bytes(),
88 &mut out,
89 );
90
91 Ok(base64::engine::general_purpose::STANDARD.encode(&out))
92 }
93
94 fn generate_salt(&self) -> Result<Vec<u8>> {
96 let mut salt = vec![0u8; 32];
97 self.rng
98 .fill(&mut salt)
99 .map_err(|_| AuthError::crypto("Failed to generate salt".to_string()))?;
100 Ok(salt)
101 }
102
103 fn check_rate_limit(&self, user_id: &str) -> Result<()> {
105 let now = SystemTime::now();
106 let window = Duration::from_secs(60); let max_attempts = 5; let (attempts, last_attempt) = self
110 .rate_limits
111 .get(user_id)
112 .map(|entry| *entry.value())
113 .unwrap_or((0, now));
114
115 if now.duration_since(last_attempt).unwrap_or(Duration::ZERO) > window {
117 self.rate_limits.insert(user_id.to_string(), (1, now));
118 return Ok(());
119 }
120
121 if attempts >= max_attempts {
122 return Err(AuthError::rate_limit(
123 "Too many MFA attempts. Please wait.".to_string(),
124 ));
125 }
126
127 self.rate_limits
128 .insert(user_id.to_string(), (attempts + 1, now));
129 Ok(())
130 }
131
132 pub async fn create_challenge(
134 &self,
135 user_id: &str,
136 mfa_type: MfaType,
137 code_length: usize,
138 ) -> Result<(String, SecureMfaCode)> {
139 self.check_rate_limit(user_id)?;
141
142 let challenge_id = self.generate_secure_id("mfa")?;
144
145 let secure_code = self.generate_secure_code(code_length)?;
147
148 let salt = self.generate_salt()?;
150 let code_hash = self.hash_code(secure_code.as_str(), &salt)?;
151
152 let now = chrono::Utc::now();
154 let challenge = MfaChallenge {
155 id: challenge_id.clone(),
156 user_id: user_id.to_string(),
157 mfa_type,
158 created_at: now,
159 expires_at: now + chrono::Duration::seconds(300), attempts: 0,
161 max_attempts: 3,
162 code_hash: Some(code_hash),
163 message: None,
164 data: HashMap::new(),
165 };
166
167 let challenge_data = serde_json::to_vec(&challenge)
169 .map_err(|e| AuthError::crypto(format!("Failed to serialize challenge: {}", e)))?;
170
171 self.storage
172 .store_kv(
173 &format!("mfa_challenge:{}", challenge_id),
174 &challenge_data,
175 Some(Duration::from_secs(300)),
176 )
177 .await?;
178
179 self.storage
180 .store_kv(
181 &format!("mfa_salt:{}", challenge_id),
182 &salt,
183 Some(Duration::from_secs(300)),
184 )
185 .await?;
186
187 tracing::info!("Created secure MFA challenge for user: {}", user_id);
188 Ok((challenge_id, secure_code))
189 }
190
191 pub async fn verify_challenge(&self, challenge_id: &str, provided_code: &str) -> Result<bool> {
193 let format_valid = !provided_code.is_empty()
197 && provided_code.len() <= 12
198 && provided_code.chars().all(|c| c.is_ascii_digit());
199
200 let challenge_data = self
202 .storage
203 .get_kv(&format!("mfa_challenge:{}", challenge_id))
204 .await?;
205
206 let mut challenge: MfaChallenge = match challenge_data {
207 Some(data) => serde_json::from_slice(&data)
208 .map_err(|_| AuthError::validation("Invalid challenge data"))?,
209 None => {
210 let dummy_salt = [0u8; 32];
212 let _ = self.hash_code("000000", &dummy_salt);
213 return Ok(false);
214 }
215 };
216
217 self.check_rate_limit(&challenge.user_id)?;
219
220 if chrono::Utc::now() > challenge.expires_at {
222 self.cleanup_challenge(challenge_id).await?;
224 let dummy_salt = [0u8; 32];
226 let _ = self.hash_code("000000", &dummy_salt);
227 return Ok(false);
228 }
229
230 if challenge.attempts >= challenge.max_attempts {
232 self.cleanup_challenge(challenge_id).await?;
233 let dummy_salt = [0u8; 32];
234 let _ = self.hash_code("000000", &dummy_salt);
235 return Ok(false);
236 }
237
238 challenge.attempts += 1;
240 let challenge_data = serde_json::to_vec(&challenge)
241 .map_err(|e| AuthError::crypto(format!("Failed to serialize challenge: {}", e)))?;
242 self.storage
243 .store_kv(
244 &format!("mfa_challenge:{}", challenge_id),
245 &challenge_data,
246 Some(Duration::from_secs(300)),
247 )
248 .await?;
249
250 let salt = match self
252 .storage
253 .get_kv(&format!("mfa_salt:{}", challenge_id))
254 .await?
255 {
256 Some(salt) => salt,
257 None => return Ok(false),
258 };
259
260 let provided_hash = self.hash_code(
262 if format_valid {
263 provided_code
264 } else {
265 "000000"
266 },
267 &salt,
268 )?;
269
270 let hash_matches = challenge.code_hash.as_ref().is_some_and(|stored_hash| {
272 stored_hash
273 .as_bytes()
274 .ct_eq(provided_hash.as_bytes())
275 .into()
276 });
277
278 let is_valid = format_valid && hash_matches;
280
281 if is_valid {
282 self.cleanup_challenge(challenge_id).await?;
284 tracing::info!(
285 "MFA challenge verified successfully for user: {}",
286 challenge.user_id
287 );
288 }
289
290 Ok(is_valid)
291 }
292
293 async fn cleanup_challenge(&self, challenge_id: &str) -> Result<()> {
295 let _ = self
296 .storage
297 .delete_kv(&format!("mfa_challenge:{}", challenge_id))
298 .await;
299 let _ = self
300 .storage
301 .delete_kv(&format!("mfa_salt:{}", challenge_id))
302 .await;
303 Ok(())
304 }
305
306 fn generate_secure_id(&self, prefix: &str) -> Result<String> {
308 let mut bytes = vec![0u8; 16];
309 self.rng
310 .fill(&mut bytes)
311 .map_err(|_| AuthError::crypto("Failed to generate secure ID".to_string()))?;
312
313 let id = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(&bytes);
314 Ok(format!("{}_{}", prefix, id))
315 }
316
317 pub fn generate_backup_codes(
319 &self,
320 count: u8,
321 ) -> Result<Vec<String>, Box<dyn std::error::Error>> {
322 let mut codes = Vec::with_capacity(count as usize);
323
324 for _ in 0..count {
325 let mut code_bytes = [0u8; 10]; self.rng
328 .fill(&mut code_bytes)
329 .map_err(|_| "Failed to generate random bytes".to_string())?;
330
331 let code = base32::encode(base32::Alphabet::Rfc4648 { padding: false }, &code_bytes);
333
334 let formatted_code = format!(
336 "{}-{}-{}-{}",
337 &code[0..4],
338 &code[4..8],
339 &code[8..12],
340 &code[12..16]
341 );
342
343 codes.push(formatted_code);
344 }
345
346 Ok(codes)
347 }
348
349 pub fn hash_backup_codes(
351 &self,
352 codes: &[String],
353 ) -> Result<Vec<String>, Box<dyn std::error::Error>> {
354 let mut hashed_codes = Vec::with_capacity(codes.len());
355
356 for code in codes {
357 let salt = self.generate_salt()?;
359 let mut hash = [0u8; 32];
360
361 ring::pbkdf2::derive(
362 ring::pbkdf2::PBKDF2_HMAC_SHA256,
363 std::num::NonZeroU32::new(100_000).expect("100_000 is non-zero"), &salt,
365 code.as_bytes(),
366 &mut hash,
367 );
368
369 let salt_hex = hex::encode(&salt);
371 let hash_hex = hex::encode(hash);
372 hashed_codes.push(format!("{}:{}", salt_hex, hash_hex));
373 }
374
375 Ok(hashed_codes)
376 }
377
378 pub fn verify_backup_code(
380 &self,
381 hashed_codes: &[String],
382 provided_code: &str,
383 ) -> Result<bool, Box<dyn std::error::Error>> {
384 if provided_code.len() != 19 || provided_code.chars().filter(|&c| c == '-').count() != 3 {
386 return Ok(false);
387 }
388
389 let clean_code = provided_code.replace("-", "");
391 if clean_code.len() != 16 || !clean_code.chars().all(|c| c.is_ascii_alphanumeric()) {
392 return Ok(false);
393 }
394
395 let mut found = false;
398
399 for hashed_code in hashed_codes {
400 let parts: Vec<&str> = hashed_code.split(':').collect();
401 if parts.len() != 2 {
402 continue;
403 }
404
405 let salt = match hex::decode(parts[0]) {
406 Ok(s) => s,
407 Err(_) => continue,
408 };
409
410 let stored_hash = match hex::decode(parts[1]) {
411 Ok(h) => h,
412 Err(_) => continue,
413 };
414
415 let mut derived_hash = [0u8; 32];
417 ring::pbkdf2::derive(
418 ring::pbkdf2::PBKDF2_HMAC_SHA256,
419 std::num::NonZeroU32::new(100_000).expect("100_000 is non-zero"),
420 &salt,
421 provided_code.as_bytes(),
422 &mut derived_hash,
423 );
424
425 let matches: bool =
428 subtle::ConstantTimeEq::ct_eq(&stored_hash[..], &derived_hash[..]).into();
429 if matches {
430 found = true;
431 }
432 }
433
434 Ok(found)
435 }
436}
437
438#[cfg(test)]
439mod tests {
440 use super::*;
441 use crate::testing::MockStorage;
442
443 #[tokio::test]
444 async fn test_secure_code_generation() {
445 let storage = Box::new(MockStorage::new());
446 let mfa_service = SecureMfaService::new(storage);
447
448 let code = mfa_service.generate_secure_code(6).unwrap();
449 assert_eq!(code.as_str().len(), 6);
450 assert!(code.as_str().chars().all(|c| c.is_ascii_digit()));
451 }
452
453 #[tokio::test]
454 async fn test_mfa_challenge_flow() {
455 let storage = Box::new(MockStorage::new());
456 let mfa_service = SecureMfaService::new(storage);
457
458 let (challenge_id, code) = mfa_service
460 .create_challenge(
461 "user123",
462 MfaType::Sms {
463 phone_number: String::new(),
464 },
465 6,
466 )
467 .await
468 .unwrap();
469
470 let result = mfa_service
472 .verify_challenge(&challenge_id, code.as_str())
473 .await
474 .unwrap();
475 assert!(result);
476
477 let result2 = mfa_service
479 .verify_challenge(&challenge_id, code.as_str())
480 .await
481 .unwrap();
482 assert!(!result2);
483 }
484
485 #[tokio::test]
486 async fn test_invalid_code_rejection() {
487 let storage = Box::new(MockStorage::new());
488 let mfa_service = SecureMfaService::new(storage);
489
490 let (challenge_id, _code) = mfa_service
491 .create_challenge(
492 "user123",
493 MfaType::Sms {
494 phone_number: String::new(),
495 },
496 6,
497 )
498 .await
499 .unwrap();
500
501 assert!(
503 !mfa_service
504 .verify_challenge(&challenge_id, "000000")
505 .await
506 .unwrap()
507 );
508 assert!(
509 !mfa_service
510 .verify_challenge(&challenge_id, "123abc")
511 .await
512 .unwrap()
513 );
514 assert!(
515 !mfa_service
516 .verify_challenge(&challenge_id, "")
517 .await
518 .unwrap()
519 );
520 assert!(
521 !mfa_service
522 .verify_challenge(&challenge_id, "12345678901234")
523 .await
524 .unwrap()
525 );
526 }
527
528 #[tokio::test]
529 async fn test_rate_limiting() {
530 let storage = Box::new(MockStorage::new());
531 let mfa_service = SecureMfaService::new(storage);
532
533 for _ in 0..5 {
535 let result = mfa_service
536 .create_challenge(
537 "user123",
538 MfaType::Sms {
539 phone_number: String::new(),
540 },
541 6,
542 )
543 .await;
544 assert!(result.is_ok());
545 }
546
547 let result = mfa_service
549 .create_challenge(
550 "user123",
551 MfaType::Sms {
552 phone_number: String::new(),
553 },
554 6,
555 )
556 .await;
557 assert!(result.is_err());
558 }
559}