1pub mod backup_codes;
4pub mod email;
5pub mod sms_kit;
6pub mod totp;
7
8use crate::errors::Result;
9use crate::methods::MfaChallenge;
10use crate::storage::AuthStorage;
11use std::collections::HashMap;
12use std::sync::Arc;
13use tokio::sync::RwLock;
14use tracing::debug;
15
16pub use backup_codes::BackupCodesManager;
17pub use email::EmailManager;
18pub use totp::TotpManager;
19
20pub use sms_kit::{
22 RateLimitConfig as SmsKitRateLimitConfig, SmsKitConfig, SmsKitManager, SmsKitProvider,
23 SmsKitProviderConfig, WebhookConfig,
24};
25
26pub use sms_kit::SmsKitManager as SmsManager;
28
29pub struct MfaManager {
95 pub totp: TotpManager,
97
98 pub sms: SmsKitManager,
100
101 pub email: EmailManager,
103
104 pub backup_codes: BackupCodesManager,
106
107 challenges: Arc<RwLock<HashMap<String, MfaChallenge>>>,
109
110 storage: Arc<dyn AuthStorage>,
113}
114
115impl MfaManager {
116 pub fn new(storage: Arc<dyn AuthStorage>) -> Self {
118 Self {
119 totp: TotpManager::new(storage.clone()),
120 sms: SmsKitManager::new(storage.clone()),
121 email: EmailManager::new(storage.clone()),
122 backup_codes: BackupCodesManager::new(storage.clone()),
123 challenges: Arc::new(RwLock::new(HashMap::new())),
124 storage,
125 }
126 }
127
128 pub fn new_with_smskit_config(
130 storage: Arc<dyn AuthStorage>,
131 smskit_config: SmsKitConfig,
132 ) -> Result<Self> {
133 Ok(Self {
134 totp: TotpManager::new(storage.clone()),
135 sms: SmsKitManager::new_with_config(storage.clone(), smskit_config)?,
136 email: EmailManager::new(storage.clone()),
137 backup_codes: BackupCodesManager::new(storage.clone()),
138 challenges: Arc::new(RwLock::new(HashMap::new())),
139 storage,
140 })
141 }
142
143 pub async fn store_challenge(&self, challenge: MfaChallenge) -> Result<()> {
145 debug!("Storing MFA challenge '{}'", challenge.id);
146
147 let mut challenges = self.challenges.write().await;
148 challenges.insert(challenge.id.clone(), challenge);
149
150 Ok(())
151 }
152
153 pub async fn get_challenge(&self, challenge_id: &str) -> Result<Option<MfaChallenge>> {
155 let challenges = self.challenges.read().await;
156 Ok(challenges.get(challenge_id).cloned())
157 }
158
159 pub async fn remove_challenge(&self, challenge_id: &str) -> Result<()> {
161 debug!("Removing MFA challenge '{}'", challenge_id);
162
163 let mut challenges = self.challenges.write().await;
164 challenges.remove(challenge_id);
165
166 Ok(())
167 }
168
169 pub async fn cleanup_expired_challenges(&self) -> Result<()> {
171 debug!("Cleaning up expired MFA challenges");
172
173 let mut challenges = self.challenges.write().await;
174 let now = chrono::Utc::now();
175 challenges.retain(|_, challenge| challenge.expires_at > now);
176
177 Ok(())
178 }
179
180 pub async fn get_active_challenge_count(&self) -> usize {
182 self.challenges.read().await.len()
183 }
184
185 pub async fn initiate_step_up_authentication(
187 &self,
188 user_id: &str,
189 required_methods: &[MfaMethod],
190 risk_level: RiskLevel,
191 ) -> Result<CrossMethodChallenge> {
192 tracing::info!(
193 "Initiating step-up authentication for user: {} with risk level: {:?}",
194 user_id,
195 risk_level
196 );
197
198 let adaptive_methods = self
200 .adapt_required_methods(required_methods, risk_level.clone())
201 .await?;
202
203 let challenge_id = uuid::Uuid::new_v4().to_string();
205
206 let mut method_challenges = HashMap::new();
208 let mut completion_status = HashMap::new();
209
210 for method in &adaptive_methods {
211 let method_challenge = match method {
212 MfaMethod::Totp => {
213 completion_status.insert(method.clone(), false);
214 self.create_totp_challenge(user_id, &challenge_id).await?
215 }
216 MfaMethod::Sms => {
217 completion_status.insert(method.clone(), false);
218 self.create_sms_challenge(user_id, &challenge_id).await?
219 }
220 MfaMethod::Email => {
221 completion_status.insert(method.clone(), false);
222 self.create_email_challenge(user_id, &challenge_id).await?
223 }
224 MfaMethod::BackupCode => {
225 completion_status.insert(method.clone(), false);
226 MethodChallenge::BackupCode {
227 challenge_id: format!("{}-backup", challenge_id),
228 instructions: "Enter one of your backup codes".to_string(),
229 }
230 }
231 };
232
233 method_challenges.insert(method.clone(), method_challenge);
234 }
235
236 let cross_method_challenge = CrossMethodChallenge {
237 id: challenge_id,
238 user_id: user_id.to_string(),
239 required_methods: adaptive_methods.clone(),
240 method_challenges,
241 completion_status,
242 risk_level,
243 expires_at: chrono::Utc::now() + chrono::Duration::minutes(10),
244 created_at: chrono::Utc::now(),
245 };
246
247 {
249 let mut challenges = self.challenges.write().await;
250 challenges.insert(
251 cross_method_challenge.id.clone(),
252 MfaChallenge {
253 id: cross_method_challenge.id.clone(),
254 mfa_type: crate::methods::MfaType::Totp, user_id: user_id.to_string(),
256 expires_at: cross_method_challenge.expires_at,
257 message: Some("Complete all required authentication methods".to_string()),
258 data: {
259 let mut data = HashMap::new();
260 data.insert(
261 "cross_method_data".to_string(),
262 serde_json::to_value(&cross_method_challenge)?,
263 );
264 data
265 },
266 },
267 );
268 }
269
270 tracing::info!(
271 "Step-up authentication initiated with {} methods",
272 adaptive_methods.len()
273 );
274 Ok(cross_method_challenge)
275 }
276
277 pub async fn complete_cross_method_step(
279 &self,
280 challenge_id: &str,
281 method: MfaMethod,
282 response: &str,
283 ) -> Result<CrossMethodCompletionResult> {
284 tracing::debug!(
285 "Completing cross-method step: {:?} for challenge: {}",
286 method,
287 challenge_id
288 );
289
290 let mut cross_challenge = self.get_cross_method_challenge(challenge_id).await?;
292
293 if cross_challenge.completion_status.get(&method) == Some(&true) {
294 return Ok(CrossMethodCompletionResult {
295 method,
296 success: true,
297 remaining_methods: self.get_remaining_methods(&cross_challenge),
298 all_completed: false,
299 error: Some("Method already completed".to_string()),
300 });
301 }
302
303 let verification_result = match method {
305 MfaMethod::Totp => {
306 self.totp
307 .verify_code(&cross_challenge.user_id, response)
308 .await
309 }
310 MfaMethod::Sms => {
311 self.sms
312 .verify_code(&cross_challenge.user_id, response)
313 .await
314 }
315 MfaMethod::Email => {
316 self.email
317 .verify_code(&cross_challenge.user_id, response)
318 .await
319 }
320 MfaMethod::BackupCode => {
321 self.backup_codes
322 .verify_code(&cross_challenge.user_id, response)
323 .await
324 }
325 };
326
327 let success = verification_result.is_ok();
328
329 if success {
330 cross_challenge
332 .completion_status
333 .insert(method.clone(), true);
334
335 self.update_cross_method_challenge(&cross_challenge).await?;
337
338 tracing::info!("Cross-method step completed successfully: {:?}", method);
339 } else {
340 tracing::warn!(
341 "Cross-method step failed: {:?} - {:?}",
342 method,
343 verification_result
344 );
345 }
346
347 let remaining_methods = self.get_remaining_methods(&cross_challenge);
348 let all_completed = remaining_methods.is_empty();
349
350 if all_completed {
351 tracing::info!(
352 "All cross-method authentication steps completed for challenge: {}",
353 challenge_id
354 );
355 self.remove_challenge(challenge_id).await?;
357 }
358
359 Ok(CrossMethodCompletionResult {
360 method,
361 success,
362 remaining_methods,
363 all_completed,
364 error: if success {
365 None
366 } else {
367 Some(format!(
368 "Verification failed: {:?}",
369 verification_result.unwrap_err()
370 ))
371 },
372 })
373 }
374
375 pub async fn get_available_methods(&self, user_id: &str) -> Result<Vec<MfaMethod>> {
377 tracing::debug!("Getting available MFA methods for user: {}", user_id);
378
379 let mut available_methods = Vec::new();
380
381 if self.totp.has_totp_secret(user_id).await.unwrap_or(false) {
383 available_methods.push(MfaMethod::Totp);
384 }
385
386 if self.sms.has_phone_number(user_id).await.unwrap_or(false) {
388 available_methods.push(MfaMethod::Sms);
389 }
390
391 if self.email.has_email(user_id).await.unwrap_or(false) {
393 available_methods.push(MfaMethod::Email);
394 }
395
396 if self
398 .backup_codes
399 .has_backup_codes(user_id)
400 .await
401 .unwrap_or(false)
402 {
403 available_methods.push(MfaMethod::BackupCode);
404 }
405
406 tracing::debug!(
407 "Available methods for user {}: {:?}",
408 user_id,
409 available_methods
410 );
411 Ok(available_methods)
412 }
413
414 pub async fn perform_method_fallback(
416 &self,
417 user_id: &str,
418 failed_method: MfaMethod,
419 fallback_order: &[MfaMethod],
420 ) -> Result<MethodFallbackResult> {
421 tracing::info!(
422 "Performing method fallback for user: {} after failed method: {:?}",
423 user_id,
424 failed_method
425 );
426
427 let available_methods = self.get_available_methods(user_id).await?;
428
429 for fallback_method in fallback_order {
431 if available_methods.contains(fallback_method) && fallback_method != &failed_method {
432 let fallback_challenge = match fallback_method {
434 MfaMethod::Totp => self.create_totp_challenge(user_id, "fallback").await?,
435 MfaMethod::Sms => self.create_sms_challenge(user_id, "fallback").await?,
436 MfaMethod::Email => self.create_email_challenge(user_id, "fallback").await?,
437 MfaMethod::BackupCode => MethodChallenge::BackupCode {
438 challenge_id: "fallback-backup".to_string(),
439 instructions: "Enter one of your backup codes".to_string(),
440 },
441 };
442
443 tracing::info!(
444 "Fallback method activated: {:?} for user: {}",
445 fallback_method,
446 user_id
447 );
448
449 return Ok(MethodFallbackResult {
450 fallback_method: fallback_method.clone(),
451 challenge: fallback_challenge,
452 remaining_fallbacks: fallback_order
453 .iter()
454 .skip_while(|&m| m != fallback_method)
455 .skip(1)
456 .filter(|&m| available_methods.contains(m))
457 .cloned()
458 .collect(),
459 });
460 }
461 }
462
463 Err(crate::errors::AuthError::validation(
464 "No fallback methods available",
465 ))
466 }
467
468 async fn adapt_required_methods(
470 &self,
471 base_methods: &[MfaMethod],
472 risk_level: RiskLevel,
473 ) -> Result<Vec<MfaMethod>> {
474 let mut adapted_methods = base_methods.to_vec();
475
476 match risk_level {
477 RiskLevel::Low => {
478 adapted_methods.truncate(1);
480 }
481 RiskLevel::Medium => {
482 }
485 RiskLevel::High => {
486 if !adapted_methods.contains(&MfaMethod::Email) {
488 adapted_methods.push(MfaMethod::Email);
489 }
490 if !adapted_methods.contains(&MfaMethod::Sms) {
491 adapted_methods.push(MfaMethod::Sms);
492 }
493 }
494 RiskLevel::Critical => {
495 adapted_methods = vec![MfaMethod::Totp, MfaMethod::Sms, MfaMethod::Email];
497 }
498 }
499
500 Ok(adapted_methods)
501 }
502
503 async fn get_cross_method_challenge(&self, challenge_id: &str) -> Result<CrossMethodChallenge> {
505 let challenges = self.challenges.read().await;
506 let challenge = challenges
507 .get(challenge_id)
508 .ok_or_else(|| crate::errors::AuthError::validation("Challenge not found"))?;
509
510 let cross_challenge: CrossMethodChallenge =
511 if let Some(cross_method_value) = challenge.data.get("cross_method_data") {
512 serde_json::from_value(cross_method_value.clone())?
513 } else {
514 return Err(crate::errors::AuthError::validation(
515 "Invalid cross-method challenge data",
516 ));
517 };
518 Ok(cross_challenge)
519 }
520
521 async fn update_cross_method_challenge(
522 &self,
523 cross_challenge: &CrossMethodChallenge,
524 ) -> Result<()> {
525 let mut challenges = self.challenges.write().await;
526 if let Some(challenge) = challenges.get_mut(&cross_challenge.id) {
527 challenge.data.insert(
528 "cross_method_data".to_string(),
529 serde_json::to_value(cross_challenge)?,
530 );
531 }
532 Ok(())
533 }
534
535 fn get_remaining_methods(&self, cross_challenge: &CrossMethodChallenge) -> Vec<MfaMethod> {
536 cross_challenge
537 .completion_status
538 .iter()
539 .filter_map(|(method, &completed)| {
540 if !completed {
541 Some(method.clone())
542 } else {
543 None
544 }
545 })
546 .collect()
547 }
548
549 async fn create_totp_challenge(
551 &self,
552 _user_id: &str,
553 challenge_prefix: &str,
554 ) -> Result<MethodChallenge> {
555 Ok(MethodChallenge::Totp {
556 challenge_id: format!("{}-totp", challenge_prefix),
557 instructions: "Enter the 6-digit code from your authenticator app".to_string(),
558 })
559 }
560
561 async fn create_sms_challenge(
562 &self,
563 user_id: &str,
564 challenge_prefix: &str,
565 ) -> Result<MethodChallenge> {
566 let _code = self.sms.send_verification_code(user_id).await?;
567 Ok(MethodChallenge::Sms {
568 challenge_id: format!("{}-sms", challenge_prefix),
569 instructions: "Enter the verification code sent to your phone".to_string(),
570 phone_hint: self
571 .get_phone_hint(user_id)
572 .await
573 .unwrap_or_else(|_| "***-***-****".to_string()),
574 })
575 }
576
577 async fn create_email_challenge(
578 &self,
579 user_id: &str,
580 challenge_prefix: &str,
581 ) -> Result<MethodChallenge> {
582 let _code = self.email.send_email_code(user_id).await?;
583 Ok(MethodChallenge::Email {
584 challenge_id: format!("{}-email", challenge_prefix),
585 instructions: "Enter the verification code sent to your email".to_string(),
586 email_hint: self
587 .get_email_hint(user_id)
588 .await
589 .unwrap_or_else(|_| "****@****.com".to_string()),
590 })
591 }
592
593 async fn get_phone_hint(&self, user_id: &str) -> Result<String> {
594 Ok(format!("***-***-{}", &user_id[..4]))
596 }
597
598 async fn get_email_hint(&self, user_id: &str) -> Result<String> {
599 Ok(format!("{}****@****.com", &user_id[..2]))
601 }
602
603 pub async fn emergency_mfa_bypass(&self, user_id: &str, admin_token: &str) -> Result<bool> {
606 tracing::warn!("Emergency MFA bypass requested for user: {}", user_id);
607
608 let admin_key = format!("emergency_admin:{}", admin_token);
610 if let Some(_admin_data) = self.storage.get_kv(&admin_key).await? {
611 tracing::info!("Emergency MFA bypass granted for user: {}", user_id);
612
613 let bypass_key = format!("mfa_bypass:{}:{}", user_id, chrono::Utc::now().timestamp());
615 let bypass_data = format!(
616 "Emergency bypass by admin token at {}",
617 chrono::Utc::now().to_rfc3339()
618 );
619 self.storage
620 .store_kv(
621 &bypass_key,
622 bypass_data.as_bytes(),
623 Some(std::time::Duration::from_secs(86400)),
624 )
625 .await?;
626
627 Ok(true)
628 } else {
629 tracing::error!("Invalid admin token for emergency MFA bypass");
630 Ok(false)
631 }
632 }
633}
634
635#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
637pub enum MfaMethod {
638 Totp,
639 Sms,
640 Email,
641 BackupCode,
642}
643
644#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
646pub enum RiskLevel {
647 Low,
648 Medium,
649 High,
650 Critical,
651}
652
653#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
655pub struct CrossMethodChallenge {
656 pub id: String,
657 pub user_id: String,
658 pub required_methods: Vec<MfaMethod>,
659 pub method_challenges: HashMap<MfaMethod, MethodChallenge>,
660 pub completion_status: HashMap<MfaMethod, bool>,
661 pub risk_level: RiskLevel,
662 pub expires_at: chrono::DateTime<chrono::Utc>,
663 pub created_at: chrono::DateTime<chrono::Utc>,
664}
665
666#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
668pub enum MethodChallenge {
669 Totp {
670 challenge_id: String,
671 instructions: String,
672 },
673 Sms {
674 challenge_id: String,
675 instructions: String,
676 phone_hint: String,
677 },
678 Email {
679 challenge_id: String,
680 instructions: String,
681 email_hint: String,
682 },
683 BackupCode {
684 challenge_id: String,
685 instructions: String,
686 },
687}
688
689#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
691pub struct CrossMethodCompletionResult {
692 pub method: MfaMethod,
693 pub success: bool,
694 pub remaining_methods: Vec<MfaMethod>,
695 pub all_completed: bool,
696 pub error: Option<String>,
697}
698
699#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
701pub struct MethodFallbackResult {
702 pub fallback_method: MfaMethod,
703 pub challenge: MethodChallenge,
704 pub remaining_fallbacks: Vec<MfaMethod>,
705}