1use crate::errors::{AuthError, Result};
7use crate::security::MfaConfig;
8use async_trait::async_trait;
9use ring::rand::SecureRandom;
10use serde::{Deserialize, Serialize};
11use std::time::{SystemTime, UNIX_EPOCH};
12use subtle::ConstantTimeEq;
13use totp_lite::{Sha1, totp};
14
/// Kind of second factor a user can enroll in or be challenged with.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum MfaMethodType {
    /// Time-based one-time password from an authenticator app.
    Totp,
    /// One-time code delivered by SMS.
    Sms,
    /// One-time code delivered by email.
    Email,
    /// WebAuthn credential (no challenge flow for it is implemented in this module).
    WebAuthn,
    /// Pre-generated recovery codes.
    BackupCodes,
}
24
/// A pending second-factor challenge issued to a user and persisted through [`MfaStorage`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MfaChallenge {
    /// Challenge identifier (a random UUID v4 when created by `MfaManager::create_challenge`).
    pub id: String,
    /// Identifier of the user being challenged.
    pub user_id: String,
    /// Which MFA method this challenge is for.
    pub method_type: MfaMethodType,
    /// Method-specific state needed to check the user's response.
    pub challenge_data: MfaChallengeData,
    /// When the challenge was created.
    pub created_at: SystemTime,
    /// After this instant `MfaManager::verify_challenge` deletes the challenge and
    /// reports it as expired.
    pub expires_at: SystemTime,
    /// Number of responses checked so far; incremented by `verify_challenge`.
    pub attempts: u32,
    /// Once `attempts` reaches this value, `verify_challenge` refuses further responses.
    pub max_attempts: u32,
}
45
/// Method-specific state stored with an [`MfaChallenge`].
///
/// NOTE(review): the `Sms` and `Email` variants keep the one-time code in plaintext, and
/// this type derives both `Debug` and `Serialize`. Make sure challenges are never sent
/// back to clients or written to logs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MfaChallengeData {
    /// TOTP challenge.
    Totp {
        /// TOTP time step (Unix seconds divided by the configured period) at creation time.
        /// This is a step counter, not a Unix timestamp.
        time_window: u64,
    },
    /// SMS one-time-code challenge.
    Sms {
        /// Masked destination number, suitable for display.
        phone_number: String,
        /// The one-time code generated for this challenge.
        code: String,
    },
    /// Email one-time-code challenge.
    Email {
        /// Masked destination address, suitable for display.
        email: String,
        /// The one-time code generated for this challenge.
        code: String,
    },
    /// WebAuthn challenge (never constructed in this module).
    WebAuthn {
        /// Challenge bytes; presumably to be signed by the authenticator. TODO confirm.
        challenge: Vec<u8>,
        /// Credential ids presumably allowed to answer. TODO confirm.
        allowed_credentials: Vec<String>,
    },
    /// Backup-code challenge.
    BackupCodes {
        /// Number of remaining backup codes reported to the user.
        remaining_codes: u32,
    },
}
76
/// An MFA method enrolled by a user, as persisted through [`MfaStorage`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserMfaMethod {
    /// Method identifier (a random UUID v4 when created by `MfaManager`).
    pub id: String,
    /// Owning user.
    pub user_id: String,
    /// Kind of method; expected to agree with the `method_data` variant.
    pub method_type: MfaMethodType,
    /// Secrets and settings specific to the method.
    pub method_data: MfaMethodData,
    /// Label shown to the user (e.g. "Authenticator App", or a masked phone number).
    pub display_name: String,
    /// Whether this is the user's preferred method (always `false` on creation here).
    pub is_primary: bool,
    /// Only enabled methods can be challenged by `MfaManager::create_challenge`.
    pub is_enabled: bool,
    /// When the method was created.
    pub created_at: SystemTime,
    /// Last successful use, if any.
    pub last_used_at: Option<SystemTime>,
}
99
/// Per-method secrets and settings stored in a [`UserMfaMethod`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MfaMethodData {
    /// Authenticator-app enrollment.
    Totp {
        /// Base32-encoded shared secret (see `TotpProvider::generate_secret`), stored as-is.
        secret_key: String,
        /// `otpauth://totp/...` provisioning URI for QR-code display.
        qr_code_url: String,
    },
    /// SMS enrollment.
    Sms {
        /// Destination phone number (unmasked).
        phone_number: String,
        /// Whether the number has been verified (`false` when created by `setup_sms`).
        is_verified: bool,
    },
    /// Email enrollment.
    Email {
        /// Destination email address (unmasked).
        email: String,
        /// Whether the address has been verified.
        is_verified: bool,
    },
    /// WebAuthn credential.
    WebAuthn {
        /// Credential identifier.
        credential_id: String,
        /// Credential public key bytes (encoding not established here; TODO confirm).
        public_key: Vec<u8>,
        /// Presumably the authenticator's signature counter. TODO confirm.
        counter: u32,
    },
    /// Backup-code set.
    BackupCodes {
        /// Hex-encoded SHA-256 hashes of the codes (see `BackupCodesProvider::hash_codes`),
        /// not the codes themselves.
        codes: Vec<String>,
        /// Used-code counter; starts at 0 in `MfaManager::generate_backup_codes`.
        used_count: u32,
    },
}
136
/// Outcome of `MfaManager::verify_challenge`.
#[derive(Debug, Clone)]
pub struct MfaVerificationResult {
    /// Whether the response was accepted.
    pub success: bool,
    /// Method type of the challenge that was checked.
    pub method_type: MfaMethodType,
    /// Attempts left after a rejected response. `Some(0)` once attempts are exhausted;
    /// `None` on success or when the challenge had expired.
    pub remaining_attempts: Option<u32>,
    /// Human-readable failure reason; `None` on success.
    pub error_message: Option<String>,
}
149
/// Generates TOTP secrets, provisioning URIs and codes, and verifies submitted codes.
pub struct TotpProvider {
    /// TOTP settings; this provider reads `issuer`, `digits` and `period` from it.
    config: crate::security::TotpConfig,
}
154
155impl TotpProvider {
156 pub fn new(config: crate::security::TotpConfig) -> Self {
157 Self { config }
158 }
159
160 pub fn generate_secret(&self) -> crate::Result<String> {
162 use ring::rand::{SecureRandom, SystemRandom};
163 let rng = SystemRandom::new();
164 let mut secret = [0u8; 20];
165 rng.fill(&mut secret).map_err(|_| {
166 crate::errors::AuthError::crypto("Failed to generate secure TOTP secret".to_string())
167 })?;
168 Ok(base32::encode(
169 base32::Alphabet::Rfc4648 { padding: true },
170 &secret,
171 ))
172 }
173
174 pub fn generate_qr_code_url(&self, secret: &str, user_identifier: &str) -> String {
176 format!(
177 "otpauth://totp/{}:{}?secret={}&issuer={}&digits={}&period={}",
178 urlencoding::encode(&self.config.issuer),
179 urlencoding::encode(user_identifier),
180 secret,
181 urlencoding::encode(&self.config.issuer),
182 self.config.digits,
183 self.config.period
184 )
185 }
186
187 pub fn generate_code(&self, secret: &str, time_step: Option<u64>) -> Result<String> {
189 if secret.trim().is_empty() {
190 return Err(AuthError::validation("TOTP secret cannot be empty"));
191 }
192
193 let secret_bytes = base32::decode(base32::Alphabet::Rfc4648 { padding: true }, secret)
194 .ok_or_else(|| AuthError::validation("Invalid TOTP secret"))?;
195
196 let time_step = time_step.unwrap_or_else(|| {
197 SystemTime::now()
198 .duration_since(UNIX_EPOCH)
199 .unwrap_or_default()
200 .as_secs()
201 / self.config.period
202 });
203
204 let unix_timestamp = time_step.checked_mul(self.config.period).ok_or_else(|| {
207 AuthError::InvalidInput("Time step too large for conversion".to_string())
208 })?;
209
210 let totp_value = totp::<Sha1>(&secret_bytes, unix_timestamp);
212
213 let parsed_value: u32 = totp_value
215 .parse()
216 .map_err(|_| AuthError::validation("TOTP generation error"))?;
217
218 Ok(format!(
220 "{:0width$}",
221 parsed_value % 10_u32.pow(self.config.digits.into()),
222 width = self.config.digits as usize
223 ))
224 }
225
226 pub fn verify_code(&self, secret: &str, code: &str, time_window: Option<u64>) -> Result<bool> {
228 let _secret_bytes = base32::decode(base32::Alphabet::Rfc4648 { padding: true }, secret)
230 .ok_or_else(|| AuthError::validation("Invalid TOTP secret"))?;
231
232 let current_time_step = if let Some(time) = time_window {
233 time / self.config.period
234 } else {
235 SystemTime::now()
236 .duration_since(UNIX_EPOCH)
237 .unwrap_or_default()
238 .as_secs()
239 / self.config.period
240 };
241
242 let mut matched = false;
245 for step_offset in [-1i64, 0, 1] {
246 let time_step_i64 = current_time_step as i64 + step_offset;
247 if time_step_i64 < 0 {
249 continue;
250 }
251 let time_step = time_step_i64 as u64;
252 let expected_code = self.generate_code(secret, Some(time_step))?;
253 let eq: bool = expected_code.as_bytes().ct_eq(code.as_bytes()).into();
254 matched |= eq;
255 }
256
257 Ok(matched)
258 }
259
260 pub fn verify_totp(&self, secret: &str, token: &str, window: u8) -> Result<bool> {
262 let now = SystemTime::now()
263 .duration_since(UNIX_EPOCH)
264 .map_err(|_| AuthError::validation("System time error"))?
265 .as_secs()
266 / self.config.period;
267
268 for i in 0..=window {
271 if i == 0 {
273 if let Ok(expected_code) = self.generate_code(secret, Some(now))
274 && expected_code.as_bytes().ct_eq(token.as_bytes()).into()
275 {
276 return Ok(true);
277 }
278 } else {
279 for offset in [i as i64, -(i as i64)] {
281 let time_step_i64 = now as i64 + offset;
282 if time_step_i64 < 0 {
284 continue;
285 }
286 let time_step = time_step_i64 as u64;
287 if let Ok(expected_code) = self.generate_code(secret, Some(time_step))
288 && expected_code.as_bytes().ct_eq(token.as_bytes()).into()
289 {
290 return Ok(true);
291 }
292 }
293 }
294 }
295 Ok(false)
296 }
297}
298
/// Delivers SMS one-time codes on behalf of [`MfaManager`].
#[async_trait]
pub trait SmsProvider: Send + Sync {
    /// Sends `code` to `phone_number` (the unmasked number stored on the user's method).
    async fn send_code(&self, phone_number: &str, code: &str) -> Result<()>;
}
304
/// Delivers email one-time codes on behalf of [`MfaManager`].
#[async_trait]
pub trait EmailProvider: Send + Sync {
    /// Sends `code` to `email` (the unmasked address stored on the user's method).
    async fn send_code(&self, email: &str, code: &str) -> Result<()>;
}
310
/// Stateless helpers for generating, hashing and verifying backup codes.
pub struct BackupCodesProvider;
313
314impl BackupCodesProvider {
315 pub fn generate_codes(count: u8) -> Vec<String> {
317 let rng = ring::rand::SystemRandom::new();
318 (0..count)
319 .map(|_| {
320 let mut buf = [0u8; 4];
321 rng.fill(&mut buf).expect("system RNG failure");
322 let val1 = u16::from_le_bytes([buf[0], buf[1]]) % 8999 + 1000;
323 let val2 = u16::from_le_bytes([buf[2], buf[3]]) % 8999 + 1000;
324 format!("{:04}-{:04}", val1, val2)
325 })
326 .collect()
327 }
328
329 pub fn hash_codes(codes: &[String]) -> Result<Vec<String>> {
331 use sha2::{Digest, Sha256};
332 codes
333 .iter()
334 .map(|code| {
335 let hash = Sha256::digest(code.as_bytes());
336 Ok(hex::encode(hash))
337 })
338 .collect()
339 }
340
341 pub fn verify_code(hashed_codes: &[String], provided_code: &str) -> bool {
343 use sha2::{Digest, Sha256};
344 let provided_hash = hex::encode(Sha256::digest(provided_code.as_bytes()));
345 let provided_bytes = provided_hash.as_bytes();
346 hashed_codes
347 .iter()
348 .any(|h| h.as_bytes().ct_eq(provided_bytes).into())
349 }
350}
351
/// Persistence backend for MFA methods and challenges used by [`MfaManager`].
#[async_trait]
pub trait MfaStorage: Send + Sync {
    /// Persists a newly created MFA method.
    async fn store_mfa_method(&self, method: &UserMfaMethod) -> Result<()>;

    /// Returns the MFA methods registered for `user_id`.
    /// `MfaManager` filters on `is_enabled` itself, so disabled methods may be included.
    async fn get_user_mfa_methods(&self, user_id: &str) -> Result<Vec<UserMfaMethod>>;

    /// Overwrites a stored method (presumably matched by `method.id`; confirm in implementations).
    async fn update_mfa_method(&self, method: &UserMfaMethod) -> Result<()>;

    /// Deletes the method with id `method_id`.
    async fn delete_mfa_method(&self, method_id: &str) -> Result<()>;

    /// Persists a newly created challenge.
    async fn store_mfa_challenge(&self, challenge: &MfaChallenge) -> Result<()>;

    /// Looks up a challenge; `Ok(None)` is treated by `MfaManager` as "not found".
    async fn get_mfa_challenge(&self, challenge_id: &str) -> Result<Option<MfaChallenge>>;

    /// Overwrites a stored challenge (used to persist the incremented attempt counter).
    async fn update_mfa_challenge(&self, challenge: &MfaChallenge) -> Result<()>;

    /// Deletes the challenge with id `challenge_id`.
    async fn delete_mfa_challenge(&self, challenge_id: &str) -> Result<()>;

    /// Presumably removes challenges whose `expires_at` has passed. Not called in this module.
    async fn cleanup_expired_challenges(&self) -> Result<()>;
}
382
/// Coordinates MFA enrollment, challenge creation and verification on top of an
/// [`MfaStorage`] backend.
pub struct MfaManager<S: MfaStorage> {
    /// Backend holding methods and challenges.
    storage: S,
    /// MFA configuration; `totp_config` also seeds `totp_provider`.
    config: MfaConfig,
    /// Generates and verifies TOTP codes.
    totp_provider: TotpProvider,
    /// Optional SMS delivery; when `None`, SMS challenge codes are not sent.
    sms_provider: Option<Box<dyn SmsProvider>>,
    /// Optional email delivery; when `None`, email challenge codes are not sent.
    email_provider: Option<Box<dyn EmailProvider>>,
}
391
392impl<S: MfaStorage> MfaManager<S> {
393 pub fn new(storage: S, config: MfaConfig) -> Self {
395 let totp_provider = TotpProvider::new(config.totp_config.clone());
396
397 Self {
398 storage,
399 config,
400 totp_provider,
401 sms_provider: None,
402 email_provider: None,
403 }
404 }
405
406 pub fn with_sms_provider(mut self, provider: Box<dyn SmsProvider>) -> Self {
408 self.sms_provider = Some(provider);
409 self
410 }
411
412 pub fn with_email_provider(mut self, provider: Box<dyn EmailProvider>) -> Self {
414 self.email_provider = Some(provider);
415 self
416 }
417
418 pub async fn setup_totp(&self, user_id: &str, user_identifier: &str) -> Result<UserMfaMethod> {
420 let secret = self.totp_provider.generate_secret()?;
421 let qr_code_url = self
422 .totp_provider
423 .generate_qr_code_url(&secret, user_identifier);
424
425 let method = UserMfaMethod {
426 id: uuid::Uuid::new_v4().to_string(),
427 user_id: user_id.to_string(),
428 method_type: MfaMethodType::Totp,
429 method_data: MfaMethodData::Totp {
430 secret_key: secret,
431 qr_code_url,
432 },
433 display_name: "Authenticator App".to_string(),
434 is_primary: false,
435 is_enabled: false, created_at: SystemTime::now(),
437 last_used_at: None,
438 };
439
440 self.storage.store_mfa_method(&method).await?;
441 Ok(method)
442 }
443
444 pub async fn setup_sms(&self, user_id: &str, phone_number: &str) -> Result<UserMfaMethod> {
446 let method = UserMfaMethod {
447 id: uuid::Uuid::new_v4().to_string(),
448 user_id: user_id.to_string(),
449 method_type: MfaMethodType::Sms,
450 method_data: MfaMethodData::Sms {
451 phone_number: phone_number.to_string(),
452 is_verified: false,
453 },
454 display_name: format!("SMS to {}", mask_phone_number(phone_number)),
455 is_primary: false,
456 is_enabled: false,
457 created_at: SystemTime::now(),
458 last_used_at: None,
459 };
460
461 self.storage.store_mfa_method(&method).await?;
462 Ok(method)
463 }
464
465 pub async fn generate_backup_codes(
467 &self,
468 user_id: &str,
469 ) -> Result<(UserMfaMethod, Vec<String>)> {
470 let codes = BackupCodesProvider::generate_codes(10);
471 let hashed_codes = BackupCodesProvider::hash_codes(&codes)?;
472
473 let method = UserMfaMethod {
474 id: uuid::Uuid::new_v4().to_string(),
475 user_id: user_id.to_string(),
476 method_type: MfaMethodType::BackupCodes,
477 method_data: MfaMethodData::BackupCodes {
478 codes: hashed_codes,
479 used_count: 0,
480 },
481 display_name: "Backup Codes".to_string(),
482 is_primary: false,
483 is_enabled: true,
484 created_at: SystemTime::now(),
485 last_used_at: None,
486 };
487
488 self.storage.store_mfa_method(&method).await?;
489 Ok((method, codes))
490 }
491
492 pub async fn create_challenge(
494 &self,
495 user_id: &str,
496 method_type: MfaMethodType,
497 ) -> Result<MfaChallenge> {
498 let user_methods = self.storage.get_user_mfa_methods(user_id).await?;
499 let method = user_methods
500 .iter()
501 .find(|m| m.method_type == method_type && m.is_enabled)
502 .ok_or_else(|| AuthError::validation("MFA method not found or not enabled"))?;
503
504 let challenge_data = match &method.method_data {
505 MfaMethodData::Totp { .. } => MfaChallengeData::Totp {
506 time_window: SystemTime::now()
507 .duration_since(UNIX_EPOCH)
508 .unwrap_or_default()
509 .as_secs()
510 / self.config.totp_config.period,
511 },
512 MfaMethodData::Sms { phone_number, .. } => {
513 let code = generate_numeric_code(6);
514 if let Some(sms_provider) = &self.sms_provider {
515 sms_provider.send_code(phone_number, &code).await?;
516 }
517 MfaChallengeData::Sms {
518 phone_number: mask_phone_number(phone_number),
519 code,
520 }
521 }
522 MfaMethodData::Email { email, .. } => {
523 let code = generate_numeric_code(6);
524 if let Some(email_provider) = &self.email_provider {
525 email_provider.send_code(email, &code).await?;
526 }
527 MfaChallengeData::Email {
528 email: mask_email(email),
529 code,
530 }
531 }
532 MfaMethodData::BackupCodes { .. } => {
533 MfaChallengeData::BackupCodes { remaining_codes: 8 } }
535 _ => return Err(AuthError::validation("Unsupported MFA method type")),
536 };
537
538 let challenge = MfaChallenge {
539 id: uuid::Uuid::new_v4().to_string(),
540 user_id: user_id.to_string(),
541 method_type,
542 challenge_data,
543 created_at: SystemTime::now(),
544 expires_at: SystemTime::now() + std::time::Duration::from_secs(300), attempts: 0,
546 max_attempts: 3,
547 };
548
549 self.storage.store_mfa_challenge(&challenge).await?;
550 Ok(challenge)
551 }
552
553 pub async fn verify_challenge(
555 &self,
556 challenge_id: &str,
557 response: &str,
558 ) -> Result<MfaVerificationResult> {
559 let mut challenge = self
560 .storage
561 .get_mfa_challenge(challenge_id)
562 .await?
563 .ok_or_else(|| AuthError::validation("MFA challenge not found"))?;
564
565 if SystemTime::now() > challenge.expires_at {
567 self.storage.delete_mfa_challenge(challenge_id).await?;
568 return Ok(MfaVerificationResult {
569 success: false,
570 method_type: challenge.method_type,
571 remaining_attempts: None,
572 error_message: Some("Challenge has expired".to_string()),
573 });
574 }
575
576 if challenge.attempts >= challenge.max_attempts {
578 self.storage.delete_mfa_challenge(challenge_id).await?;
579 return Ok(MfaVerificationResult {
580 success: false,
581 method_type: challenge.method_type,
582 remaining_attempts: Some(0),
583 error_message: Some("Maximum attempts exceeded".to_string()),
584 });
585 }
586
587 challenge.attempts += 1;
588
589 let success = match &challenge.challenge_data {
590 MfaChallengeData::Totp { time_window } => {
591 let user_methods = self
592 .storage
593 .get_user_mfa_methods(&challenge.user_id)
594 .await?;
595 if let Some(method) = user_methods
596 .iter()
597 .find(|m| m.method_type == MfaMethodType::Totp)
598 {
599 if let MfaMethodData::Totp { secret_key, .. } = &method.method_data {
600 self.totp_provider
601 .verify_code(secret_key, response, Some(*time_window))?
602 } else {
603 false
604 }
605 } else {
606 false
607 }
608 }
609 MfaChallengeData::Sms { code, .. } => code.as_bytes().ct_eq(response.as_bytes()).into(),
610 MfaChallengeData::Email { code, .. } => {
611 code.as_bytes().ct_eq(response.as_bytes()).into()
612 }
613 MfaChallengeData::BackupCodes { .. } => {
614 let user_methods = self
615 .storage
616 .get_user_mfa_methods(&challenge.user_id)
617 .await?;
618 if let Some(method) = user_methods
619 .iter()
620 .find(|m| m.method_type == MfaMethodType::BackupCodes)
621 {
622 if let MfaMethodData::BackupCodes { codes, .. } = &method.method_data {
623 BackupCodesProvider::verify_code(codes, response)
624 } else {
625 false
626 }
627 } else {
628 false
629 }
630 }
631 _ => false,
632 };
633
634 if success {
635 self.storage.delete_mfa_challenge(challenge_id).await?;
636 Ok(MfaVerificationResult {
637 success: true,
638 method_type: challenge.method_type,
639 remaining_attempts: None,
640 error_message: None,
641 })
642 } else {
643 let remaining = challenge.max_attempts.saturating_sub(challenge.attempts);
644 self.storage.update_mfa_challenge(&challenge).await?;
645
646 Ok(MfaVerificationResult {
647 success: false,
648 method_type: challenge.method_type,
649 remaining_attempts: Some(remaining),
650 error_message: Some("Invalid code".to_string()),
651 })
652 }
653 }
654
655 pub async fn has_mfa_enabled(&self, user_id: &str) -> Result<bool> {
657 let methods = self.storage.get_user_mfa_methods(user_id).await?;
658 Ok(methods.iter().any(|m| m.is_enabled))
659 }
660
661 pub async fn get_enabled_methods(&self, user_id: &str) -> Result<Vec<MfaMethodType>> {
663 let methods = self.storage.get_user_mfa_methods(user_id).await?;
664 Ok(methods
665 .iter()
666 .filter(|m| m.is_enabled)
667 .map(|m| m.method_type.clone())
668 .collect())
669 }
670}
671
672fn generate_numeric_code(length: u8) -> String {
674 let rng = ring::rand::SystemRandom::new();
675 (0..length)
676 .map(|_| {
677 let mut buf = [0u8; 1];
678 rng.fill(&mut buf).expect("system RNG failure");
679 (buf[0] % 10).to_string()
680 })
681 .collect()
682}
683
/// Masks a phone number for display, keeping only its last four characters.
///
/// Inputs of four or fewer characters are fully masked. The function counts `char`s
/// rather than bytes, so non-ASCII input (e.g. full-width digits) cannot cause a
/// char-boundary slicing panic.
fn mask_phone_number(phone: &str) -> String {
    let char_count = phone.chars().count();
    if char_count > 4 {
        let last_four: String = phone.chars().skip(char_count - 4).collect();
        format!("***-***-{}", last_four)
    } else {
        "***-***-****".to_string()
    }
}
692
/// Masks an email address for display.
///
/// If the local part is longer than two characters, its first character is kept and
/// the rest becomes `***`; otherwise the whole local part becomes `***`. The domain
/// (from the first `@`) is kept. Input without `@` yields `***@***.***`.
///
/// Works on `char`s, so a multibyte first character no longer causes a slicing panic.
fn mask_email(email: &str) -> String {
    let Some(at_pos) = email.find('@') else {
        return "***@***.***".to_string();
    };
    // `domain` keeps its leading '@'.
    let (local, domain) = email.split_at(at_pos);
    match local.chars().next() {
        Some(first) if local.chars().count() > 2 => format!("{first}***{domain}"),
        _ => format!("***{domain}"),
    }
}
706
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips a generated code through `verify_code` at a fixed time step.
    #[test]
    fn test_totp_generation() {
        let config = crate::security::TotpConfig::default();
        let provider = TotpProvider::new(config);

        let secret = provider.generate_secret().unwrap();
        assert!(!secret.is_empty());

        let code = provider.generate_code(&secret, Some(1)).unwrap();
        assert_eq!(code.len(), 6);

        assert!(provider.verify_code(&secret, &code, Some(1)).unwrap());

        assert!(!provider.verify_code(&secret, "000000", Some(1)).unwrap());
    }

    /// Empty, whitespace-only and non-Base32 secrets are rejected with an error.
    #[test]
    fn test_totp_rejects_bad_secrets() {
        let provider = TotpProvider::new(crate::security::TotpConfig::default());
        assert!(provider.generate_code("", Some(1)).is_err());
        assert!(provider.generate_code("   ", Some(1)).is_err());
        assert!(provider.generate_code("not base32!", Some(1)).is_err());
        assert!(provider.verify_code("not base32!", "123456", Some(1)).is_err());
    }

    /// A code generated "now" verifies against the current time.
    #[test]
    fn test_totp_accepts_current_code() {
        let provider = TotpProvider::new(crate::security::TotpConfig::default());
        let secret = provider.generate_secret().unwrap();
        let code = provider.generate_code(&secret, None).unwrap();
        // A one-step window tolerates crossing a step boundary between the two calls.
        assert!(provider.verify_totp(&secret, &code, 1).unwrap());
        assert!(provider.verify_code(&secret, &code, None).unwrap());
    }

    /// Generated backup codes hash to 64-char hex digests and verify only themselves.
    #[test]
    fn test_backup_codes() {
        let codes = BackupCodesProvider::generate_codes(5);
        assert_eq!(codes.len(), 5);

        let hashed = BackupCodesProvider::hash_codes(&codes).unwrap();
        assert_eq!(hashed.len(), 5);
        assert!(hashed.iter().all(|h| h.len() == 64));

        assert!(BackupCodesProvider::verify_code(&hashed, &codes[0]));

        assert!(!BackupCodesProvider::verify_code(&hashed, "1234-5678"));
    }

    /// Every backup code is `NNNN-NNNN` with each half in `1000..=9999`.
    #[test]
    fn test_backup_code_format() {
        for code in BackupCodesProvider::generate_codes(50) {
            let (left, right) = code.split_once('-').expect("code contains a dash");
            for half in [left, right] {
                let value: u16 = half.parse().expect("half is numeric");
                assert!((1000..=9999).contains(&value), "out of range: {code}");
            }
        }
    }

    /// One-time codes have the requested length and contain only ASCII digits.
    #[test]
    fn test_numeric_code_format() {
        for length in [0u8, 1, 6, 8] {
            let code = generate_numeric_code(length);
            assert_eq!(code.len(), usize::from(length));
            assert!(code.bytes().all(|b| b.is_ascii_digit()));
        }
    }

    /// Masking keeps only the documented visible parts, including edge cases.
    #[test]
    fn test_masking() {
        assert_eq!(mask_phone_number("+1234567890"), "***-***-7890");
        assert_eq!(mask_phone_number("1234"), "***-***-****");
        assert_eq!(mask_email("user@example.com"), "u***@example.com");
        assert_eq!(mask_email("ab@example.com"), "***@example.com");
        assert_eq!(mask_email("no-at-sign"), "***@***.***");
    }
}