1use crate::errors::{AuthError, Result};
7use crate::security::MfaConfig;
8use async_trait::async_trait;
9use rand::Rng;
10use serde::{Deserialize, Serialize};
11use std::time::{SystemTime, UNIX_EPOCH};
12use totp_lite::{Sha1, totp};
13
14#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
16pub enum MfaMethodType {
17 Totp,
18 Sms,
19 Email,
20 WebAuthn,
21 BackupCodes,
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct MfaChallenge {
27 pub id: String,
29 pub user_id: String,
31 pub method_type: MfaMethodType,
33 pub challenge_data: MfaChallengeData,
35 pub created_at: SystemTime,
37 pub expires_at: SystemTime,
39 pub attempts: u32,
41 pub max_attempts: u32,
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
47pub enum MfaChallengeData {
48 Totp {
49 time_window: u64,
51 },
52 Sms {
53 phone_number: String,
55 code: String,
57 },
58 Email {
59 email: String,
61 code: String,
63 },
64 WebAuthn {
65 challenge: Vec<u8>,
67 allowed_credentials: Vec<String>,
69 },
70 BackupCodes {
71 remaining_codes: u32,
73 },
74}
75
76#[derive(Debug, Clone, Serialize, Deserialize)]
78pub struct UserMfaMethod {
79 pub id: String,
81 pub user_id: String,
83 pub method_type: MfaMethodType,
85 pub method_data: MfaMethodData,
87 pub display_name: String,
89 pub is_primary: bool,
91 pub is_enabled: bool,
93 pub created_at: SystemTime,
95 pub last_used_at: Option<SystemTime>,
97}
98
99#[derive(Debug, Clone, Serialize, Deserialize)]
101pub enum MfaMethodData {
102 Totp {
103 secret_key: String,
105 qr_code_url: String,
107 },
108 Sms {
109 phone_number: String,
111 is_verified: bool,
113 },
114 Email {
115 email: String,
117 is_verified: bool,
119 },
120 WebAuthn {
121 credential_id: String,
123 public_key: Vec<u8>,
125 counter: u32,
127 },
128 BackupCodes {
129 codes: Vec<String>,
131 used_count: u32,
133 },
134}
135
136#[derive(Debug, Clone)]
138pub struct MfaVerificationResult {
139 pub success: bool,
141 pub method_type: MfaMethodType,
143 pub remaining_attempts: Option<u32>,
145 pub error_message: Option<String>,
147}
148
149pub struct TotpProvider {
151 config: crate::security::TotpConfig,
152}
153
154impl TotpProvider {
155 pub fn new(config: crate::security::TotpConfig) -> Self {
156 Self { config }
157 }
158
159 pub fn generate_secret(&self) -> crate::Result<String> {
161 use ring::rand::{SecureRandom, SystemRandom};
162 let rng = SystemRandom::new();
163 let mut secret = [0u8; 20];
164 rng.fill(&mut secret).map_err(|_| {
165 crate::errors::AuthError::crypto("Failed to generate secure TOTP secret".to_string())
166 })?;
167 Ok(base32::encode(
168 base32::Alphabet::Rfc4648 { padding: true },
169 &secret,
170 ))
171 }
172
173 pub fn generate_qr_code_url(&self, secret: &str, user_identifier: &str) -> String {
175 format!(
176 "otpauth://totp/{}:{}?secret={}&issuer={}&digits={}&period={}",
177 urlencoding::encode(&self.config.issuer),
178 urlencoding::encode(user_identifier),
179 secret,
180 urlencoding::encode(&self.config.issuer),
181 self.config.digits,
182 self.config.period
183 )
184 }
185
186 pub fn generate_code(&self, secret: &str, time_step: Option<u64>) -> Result<String> {
188 if secret.trim().is_empty() {
189 return Err(AuthError::validation("TOTP secret cannot be empty"));
190 }
191
192 let secret_bytes = base32::decode(base32::Alphabet::Rfc4648 { padding: true }, secret)
193 .ok_or_else(|| AuthError::validation("Invalid TOTP secret"))?;
194
195 let time_step = time_step.unwrap_or_else(|| {
196 SystemTime::now()
197 .duration_since(UNIX_EPOCH)
198 .unwrap()
199 .as_secs()
200 / self.config.period
201 });
202
203 let unix_timestamp = time_step.checked_mul(self.config.period).ok_or_else(|| {
206 AuthError::InvalidInput("Time step too large for conversion".to_string())
207 })?;
208
209 let totp_value = totp::<Sha1>(&secret_bytes, unix_timestamp);
211
212 let parsed_value: u32 = totp_value
214 .parse()
215 .map_err(|_| AuthError::validation("TOTP generation error"))?;
216
217 Ok(format!(
219 "{:0width$}",
220 parsed_value % 10_u32.pow(self.config.digits.into()),
221 width = self.config.digits as usize
222 ))
223 }
224
225 pub fn verify_code(&self, secret: &str, code: &str, time_window: Option<u64>) -> Result<bool> {
227 let _secret_bytes = base32::decode(base32::Alphabet::Rfc4648 { padding: true }, secret)
229 .ok_or_else(|| AuthError::validation("Invalid TOTP secret"))?;
230
231 let current_time_step = if let Some(time) = time_window {
232 time / self.config.period
233 } else {
234 SystemTime::now()
235 .duration_since(UNIX_EPOCH)
236 .unwrap()
237 .as_secs()
238 / self.config.period
239 };
240
241 for step_offset in [-1i64, 0, 1] {
243 let time_step_i64 = current_time_step as i64 + step_offset;
244 if time_step_i64 < 0 {
246 continue;
247 }
248 let time_step = time_step_i64 as u64;
249 let expected_code = self.generate_code(secret, Some(time_step))?;
250 if expected_code == code {
251 return Ok(true);
252 }
253 }
254
255 Ok(false)
256 }
257
258 pub fn verify_totp(&self, secret: &str, token: &str, window: u8) -> Result<bool> {
260 let now = SystemTime::now()
261 .duration_since(UNIX_EPOCH)
262 .map_err(|_| AuthError::validation("System time error"))?
263 .as_secs()
264 / self.config.period;
265
266 use subtle::ConstantTimeEq;
268
269 for i in 0..=window {
270 if i == 0 {
272 if let Ok(expected_code) = self.generate_code(secret, Some(now))
273 && expected_code.as_bytes().ct_eq(token.as_bytes()).into()
274 {
275 return Ok(true);
276 }
277 } else {
278 for offset in [i as i64, -(i as i64)] {
280 let time_step_i64 = now as i64 + offset;
281 if time_step_i64 < 0 {
283 continue;
284 }
285 let time_step = time_step_i64 as u64;
286 if let Ok(expected_code) = self.generate_code(secret, Some(time_step))
287 && expected_code.as_bytes().ct_eq(token.as_bytes()).into()
288 {
289 return Ok(true);
290 }
291 }
292 }
293 }
294 Ok(false)
295 }
296}
297
298#[async_trait]
300pub trait SmsProvider: Send + Sync {
301 async fn send_code(&self, phone_number: &str, code: &str) -> Result<()>;
302}
303
304#[async_trait]
306pub trait EmailProvider: Send + Sync {
307 async fn send_code(&self, email: &str, code: &str) -> Result<()>;
308}
309
310pub struct BackupCodesProvider;
312
313impl BackupCodesProvider {
314 pub fn generate_codes(count: u8) -> Vec<String> {
316 let mut rng = rand::rng();
317 (0..count)
318 .map(|_| {
319 format!(
320 "{:04}-{:04}",
321 rng.random_range(1000..9999),
322 rng.random_range(1000..9999)
323 )
324 })
325 .collect()
326 }
327
328 pub fn hash_codes(codes: &[String]) -> Result<Vec<String>> {
330 codes
331 .iter()
332 .map(|code| {
333 Ok(format!("hashed_{}", code))
335 })
336 .collect()
337 }
338
339 pub fn verify_code(hashed_codes: &[String], provided_code: &str) -> bool {
341 let expected_hash = format!("hashed_{}", provided_code);
342 hashed_codes.contains(&expected_hash)
343 }
344}
345
346#[async_trait]
348pub trait MfaStorage: Send + Sync {
349 async fn store_mfa_method(&self, method: &UserMfaMethod) -> Result<()>;
351
352 async fn get_user_mfa_methods(&self, user_id: &str) -> Result<Vec<UserMfaMethod>>;
354
355 async fn update_mfa_method(&self, method: &UserMfaMethod) -> Result<()>;
357
358 async fn delete_mfa_method(&self, method_id: &str) -> Result<()>;
360
361 async fn store_mfa_challenge(&self, challenge: &MfaChallenge) -> Result<()>;
363
364 async fn get_mfa_challenge(&self, challenge_id: &str) -> Result<Option<MfaChallenge>>;
366
367 async fn update_mfa_challenge(&self, challenge: &MfaChallenge) -> Result<()>;
369
370 async fn delete_mfa_challenge(&self, challenge_id: &str) -> Result<()>;
372
373 async fn cleanup_expired_challenges(&self) -> Result<()>;
375}
376
377pub struct MfaManager<S: MfaStorage> {
379 storage: S,
380 config: MfaConfig,
381 totp_provider: TotpProvider,
382 sms_provider: Option<Box<dyn SmsProvider>>,
383 email_provider: Option<Box<dyn EmailProvider>>,
384}
385
386impl<S: MfaStorage> MfaManager<S> {
387 pub fn new(storage: S, config: MfaConfig) -> Self {
389 let totp_provider = TotpProvider::new(config.totp_config.clone());
390
391 Self {
392 storage,
393 config,
394 totp_provider,
395 sms_provider: None,
396 email_provider: None,
397 }
398 }
399
400 pub fn with_sms_provider(mut self, provider: Box<dyn SmsProvider>) -> Self {
402 self.sms_provider = Some(provider);
403 self
404 }
405
406 pub fn with_email_provider(mut self, provider: Box<dyn EmailProvider>) -> Self {
408 self.email_provider = Some(provider);
409 self
410 }
411
412 pub async fn setup_totp(&self, user_id: &str, user_identifier: &str) -> Result<UserMfaMethod> {
414 let secret = self.totp_provider.generate_secret()?;
415 let qr_code_url = self
416 .totp_provider
417 .generate_qr_code_url(&secret, user_identifier);
418
419 let method = UserMfaMethod {
420 id: uuid::Uuid::new_v4().to_string(),
421 user_id: user_id.to_string(),
422 method_type: MfaMethodType::Totp,
423 method_data: MfaMethodData::Totp {
424 secret_key: secret,
425 qr_code_url,
426 },
427 display_name: "Authenticator App".to_string(),
428 is_primary: false,
429 is_enabled: false, created_at: SystemTime::now(),
431 last_used_at: None,
432 };
433
434 self.storage.store_mfa_method(&method).await?;
435 Ok(method)
436 }
437
438 pub async fn setup_sms(&self, user_id: &str, phone_number: &str) -> Result<UserMfaMethod> {
440 let method = UserMfaMethod {
441 id: uuid::Uuid::new_v4().to_string(),
442 user_id: user_id.to_string(),
443 method_type: MfaMethodType::Sms,
444 method_data: MfaMethodData::Sms {
445 phone_number: phone_number.to_string(),
446 is_verified: false,
447 },
448 display_name: format!("SMS to {}", mask_phone_number(phone_number)),
449 is_primary: false,
450 is_enabled: false,
451 created_at: SystemTime::now(),
452 last_used_at: None,
453 };
454
455 self.storage.store_mfa_method(&method).await?;
456 Ok(method)
457 }
458
459 pub async fn generate_backup_codes(
461 &self,
462 user_id: &str,
463 ) -> Result<(UserMfaMethod, Vec<String>)> {
464 let codes = BackupCodesProvider::generate_codes(10);
465 let hashed_codes = BackupCodesProvider::hash_codes(&codes)?;
466
467 let method = UserMfaMethod {
468 id: uuid::Uuid::new_v4().to_string(),
469 user_id: user_id.to_string(),
470 method_type: MfaMethodType::BackupCodes,
471 method_data: MfaMethodData::BackupCodes {
472 codes: hashed_codes,
473 used_count: 0,
474 },
475 display_name: "Backup Codes".to_string(),
476 is_primary: false,
477 is_enabled: true,
478 created_at: SystemTime::now(),
479 last_used_at: None,
480 };
481
482 self.storage.store_mfa_method(&method).await?;
483 Ok((method, codes))
484 }
485
486 pub async fn create_challenge(
488 &self,
489 user_id: &str,
490 method_type: MfaMethodType,
491 ) -> Result<MfaChallenge> {
492 let user_methods = self.storage.get_user_mfa_methods(user_id).await?;
493 let method = user_methods
494 .iter()
495 .find(|m| m.method_type == method_type && m.is_enabled)
496 .ok_or_else(|| AuthError::validation("MFA method not found or not enabled"))?;
497
498 let challenge_data = match &method.method_data {
499 MfaMethodData::Totp { .. } => MfaChallengeData::Totp {
500 time_window: SystemTime::now()
501 .duration_since(UNIX_EPOCH)
502 .unwrap()
503 .as_secs()
504 / self.config.totp_config.period,
505 },
506 MfaMethodData::Sms { phone_number, .. } => {
507 let code = generate_numeric_code(6);
508 if let Some(sms_provider) = &self.sms_provider {
509 sms_provider.send_code(phone_number, &code).await?;
510 }
511 MfaChallengeData::Sms {
512 phone_number: mask_phone_number(phone_number),
513 code,
514 }
515 }
516 MfaMethodData::Email { email, .. } => {
517 let code = generate_numeric_code(6);
518 if let Some(email_provider) = &self.email_provider {
519 email_provider.send_code(email, &code).await?;
520 }
521 MfaChallengeData::Email {
522 email: mask_email(email),
523 code,
524 }
525 }
526 MfaMethodData::BackupCodes { .. } => {
527 MfaChallengeData::BackupCodes { remaining_codes: 8 } }
529 _ => return Err(AuthError::validation("Unsupported MFA method type")),
530 };
531
532 let challenge = MfaChallenge {
533 id: uuid::Uuid::new_v4().to_string(),
534 user_id: user_id.to_string(),
535 method_type,
536 challenge_data,
537 created_at: SystemTime::now(),
538 expires_at: SystemTime::now() + std::time::Duration::from_secs(300), attempts: 0,
540 max_attempts: 3,
541 };
542
543 self.storage.store_mfa_challenge(&challenge).await?;
544 Ok(challenge)
545 }
546
547 pub async fn verify_challenge(
549 &self,
550 challenge_id: &str,
551 response: &str,
552 ) -> Result<MfaVerificationResult> {
553 let mut challenge = self
554 .storage
555 .get_mfa_challenge(challenge_id)
556 .await?
557 .ok_or_else(|| AuthError::validation("MFA challenge not found"))?;
558
559 if SystemTime::now() > challenge.expires_at {
561 self.storage.delete_mfa_challenge(challenge_id).await?;
562 return Ok(MfaVerificationResult {
563 success: false,
564 method_type: challenge.method_type,
565 remaining_attempts: None,
566 error_message: Some("Challenge has expired".to_string()),
567 });
568 }
569
570 if challenge.attempts >= challenge.max_attempts {
572 self.storage.delete_mfa_challenge(challenge_id).await?;
573 return Ok(MfaVerificationResult {
574 success: false,
575 method_type: challenge.method_type,
576 remaining_attempts: Some(0),
577 error_message: Some("Maximum attempts exceeded".to_string()),
578 });
579 }
580
581 challenge.attempts += 1;
582
583 let success = match &challenge.challenge_data {
584 MfaChallengeData::Totp { time_window } => {
585 let user_methods = self
586 .storage
587 .get_user_mfa_methods(&challenge.user_id)
588 .await?;
589 if let Some(method) = user_methods
590 .iter()
591 .find(|m| m.method_type == MfaMethodType::Totp)
592 {
593 if let MfaMethodData::Totp { secret_key, .. } = &method.method_data {
594 self.totp_provider
595 .verify_code(secret_key, response, Some(*time_window))?
596 } else {
597 false
598 }
599 } else {
600 false
601 }
602 }
603 MfaChallengeData::Sms { code, .. } => code == response,
604 MfaChallengeData::Email { code, .. } => code == response,
605 MfaChallengeData::BackupCodes { .. } => {
606 let user_methods = self
607 .storage
608 .get_user_mfa_methods(&challenge.user_id)
609 .await?;
610 if let Some(method) = user_methods
611 .iter()
612 .find(|m| m.method_type == MfaMethodType::BackupCodes)
613 {
614 if let MfaMethodData::BackupCodes { codes, .. } = &method.method_data {
615 BackupCodesProvider::verify_code(codes, response)
616 } else {
617 false
618 }
619 } else {
620 false
621 }
622 }
623 _ => false,
624 };
625
626 if success {
627 self.storage.delete_mfa_challenge(challenge_id).await?;
628 Ok(MfaVerificationResult {
629 success: true,
630 method_type: challenge.method_type,
631 remaining_attempts: None,
632 error_message: None,
633 })
634 } else {
635 let remaining = challenge.max_attempts.saturating_sub(challenge.attempts);
636 self.storage.update_mfa_challenge(&challenge).await?;
637
638 Ok(MfaVerificationResult {
639 success: false,
640 method_type: challenge.method_type,
641 remaining_attempts: Some(remaining),
642 error_message: Some("Invalid code".to_string()),
643 })
644 }
645 }
646
647 pub async fn has_mfa_enabled(&self, user_id: &str) -> Result<bool> {
649 let methods = self.storage.get_user_mfa_methods(user_id).await?;
650 Ok(methods.iter().any(|m| m.is_enabled))
651 }
652
653 pub async fn get_enabled_methods(&self, user_id: &str) -> Result<Vec<MfaMethodType>> {
655 let methods = self.storage.get_user_mfa_methods(user_id).await?;
656 Ok(methods
657 .iter()
658 .filter(|m| m.is_enabled)
659 .map(|m| m.method_type.clone())
660 .collect())
661 }
662}
663
664fn generate_numeric_code(length: u8) -> String {
666 let mut rng = rand::rng();
667 (0..length)
668 .map(|_| rng.random_range(0..10).to_string())
669 .collect()
670}
671
/// Masks a phone number for display, keeping only its last four characters
/// (`"+1234567890"` becomes `"***-***-7890"`).
///
/// Inputs of four characters or fewer are masked completely. The cut point is taken on a
/// `char` boundary, so non-ASCII input (e.g. full-width digits) cannot cause a slicing panic.
fn mask_phone_number(phone: &str) -> String {
    // Byte offset where the last four characters begin; `start > 0` means at least one
    // character precedes them, i.e. the input is longer than four characters.
    match phone.char_indices().rev().nth(3) {
        Some((start, _)) if start > 0 => format!("***-***-{}", &phone[start..]),
        _ => "***-***-****".to_string(),
    }
}
680
/// Masks an email address for display as `<first char>***@domain`.
///
/// Local parts of two characters or fewer are hidden entirely (`***@domain`), and input
/// without `@` becomes `***@***.***`. Works on `char`s, so a multi-byte first character
/// cannot cause a slicing panic.
fn mask_email(email: &str) -> String {
    let Some(at_pos) = email.find('@') else {
        return "***@***.***".to_string();
    };
    // `domain` keeps the leading '@'.
    let (local, domain) = email.split_at(at_pos);
    match local.chars().next() {
        // `nth(2).is_some()` holds exactly when the local part has more than two characters.
        Some(first) if local.chars().nth(2).is_some() => format!("{first}***{domain}"),
        _ => format!("***{domain}"),
    }
}
694
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_totp_generation() {
        let provider = TotpProvider::new(crate::security::TotpConfig::default());

        let secret = provider
            .generate_secret()
            .expect("secret generation should succeed");
        assert!(!secret.is_empty());

        let code = provider
            .generate_code(&secret, Some(1))
            .expect("code generation should succeed");
        assert_eq!(code.len(), 6);

        // `generate_code` takes a time-step index while `verify_code` takes Unix seconds;
        // step 1 still lies inside verify_code's ±1-step window around second 1.
        let accepted = provider.verify_code(&secret, &code, Some(1)).unwrap();
        assert!(accepted);

        let rejected = provider.verify_code(&secret, "000000", Some(1)).unwrap();
        assert!(!rejected);
    }

    #[test]
    fn test_backup_codes() {
        let plain_codes = BackupCodesProvider::generate_codes(5);
        assert_eq!(plain_codes.len(), 5);

        let stored_codes = BackupCodesProvider::hash_codes(&plain_codes).unwrap();
        assert_eq!(stored_codes.len(), 5);

        assert!(BackupCodesProvider::verify_code(&stored_codes, &plain_codes[0]));
        assert!(!BackupCodesProvider::verify_code(&stored_codes, "1234-5678"));
    }

    #[test]
    fn test_masking() {
        assert_eq!(mask_phone_number("+1234567890"), "***-***-7890");
        assert_eq!(mask_email("user@example.com"), "u***@example.com");
    }
}