pub mod backup_codes;
pub mod email;
pub mod sms_kit;
pub mod totp;
use crate::errors::Result;
use crate::methods::MfaChallenge;
use crate::storage::AuthStorage;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::debug;
pub use backup_codes::BackupCodesManager;
pub use email::EmailManager;
pub use totp::TotpManager;
pub use sms_kit::{
RateLimitConfig as SmsKitRateLimitConfig, SmsKitConfig, SmsKitManager, SmsKitProvider,
SmsKitProviderConfig, WebhookConfig,
};
pub use sms_kit::SmsKitManager as SmsManager;
/// Facade that groups the per-method MFA managers and tracks outstanding
/// MFA challenges.
pub struct MfaManager {
    /// TOTP (authenticator app) manager.
    pub totp: TotpManager,
    /// SMS manager (SMSKit-based).
    pub sms: SmsKitManager,
    /// Email verification-code manager.
    pub email: EmailManager,
    /// Backup-code manager.
    pub backup_codes: BackupCodesManager,
    /// Outstanding challenges keyed by challenge ID, kept in this in-process map
    /// (this file never writes them to `storage`).
    challenges: Arc<RwLock<HashMap<String, MfaChallenge>>>,
    /// Storage backend; the constructors also hand a clone to every sub-manager.
    storage: Arc<dyn AuthStorage>,
}
impl MfaManager {
pub fn new(storage: Arc<dyn AuthStorage>) -> Self {
Self {
totp: TotpManager::new(storage.clone()),
sms: SmsKitManager::new(storage.clone()),
email: EmailManager::new(storage.clone()),
backup_codes: BackupCodesManager::new(storage.clone()),
challenges: Arc::new(RwLock::new(HashMap::new())),
storage,
}
}
pub fn new_with_smskit_config(
storage: Arc<dyn AuthStorage>,
smskit_config: SmsKitConfig,
) -> Result<Self> {
Ok(Self {
totp: TotpManager::new(storage.clone()),
sms: SmsKitManager::new_with_config(storage.clone(), smskit_config)?,
email: EmailManager::new(storage.clone()),
backup_codes: BackupCodesManager::new(storage.clone()),
challenges: Arc::new(RwLock::new(HashMap::new())),
storage,
})
}
pub async fn store_challenge(&self, challenge: MfaChallenge) -> Result<()> {
debug!("Storing MFA challenge '{}'", challenge.id);
let mut challenges = self.challenges.write().await;
challenges.insert(challenge.id.clone(), challenge);
Ok(())
}
pub async fn get_challenge(&self, challenge_id: &str) -> Result<Option<MfaChallenge>> {
let challenges = self.challenges.read().await;
Ok(challenges.get(challenge_id).cloned())
}
pub async fn remove_challenge(&self, challenge_id: &str) -> Result<()> {
debug!("Removing MFA challenge '{}'", challenge_id);
let mut challenges = self.challenges.write().await;
challenges.remove(challenge_id);
Ok(())
}
pub async fn cleanup_expired_challenges(&self) -> Result<()> {
debug!("Cleaning up expired MFA challenges");
let mut challenges = self.challenges.write().await;
let now = chrono::Utc::now();
challenges.retain(|_, challenge| challenge.expires_at > now);
Ok(())
}
pub async fn get_active_challenge_count(&self) -> usize {
self.challenges.read().await.len()
}
pub async fn initiate_step_up_authentication(
&self,
user_id: &str,
required_methods: &[MfaMethod],
risk_level: RiskLevel,
) -> Result<CrossMethodChallenge> {
tracing::info!(
"Initiating step-up authentication for user: {} with risk level: {:?}",
user_id,
risk_level
);
let adaptive_methods = self
.adapt_required_methods(required_methods, risk_level.clone())
.await?;
let challenge_id = uuid::Uuid::new_v4().to_string();
let mut method_challenges = HashMap::new();
let mut completion_status = HashMap::new();
for method in &adaptive_methods {
let method_challenge = match method {
MfaMethod::Totp => {
completion_status.insert(method.clone(), false);
self.create_totp_challenge(user_id, &challenge_id).await?
}
MfaMethod::Sms => {
completion_status.insert(method.clone(), false);
self.create_sms_challenge(user_id, &challenge_id).await?
}
MfaMethod::Email => {
completion_status.insert(method.clone(), false);
self.create_email_challenge(user_id, &challenge_id).await?
}
MfaMethod::BackupCode => {
completion_status.insert(method.clone(), false);
MethodChallenge::BackupCode {
challenge_id: format!("{}-backup", challenge_id),
instructions: "Enter one of your backup codes".to_string(),
}
}
};
method_challenges.insert(method.clone(), method_challenge);
}
let cross_method_challenge = CrossMethodChallenge {
id: challenge_id,
user_id: user_id.to_string(),
required_methods: adaptive_methods.clone(),
method_challenges,
completion_status,
risk_level,
expires_at: chrono::Utc::now() + chrono::Duration::minutes(10),
created_at: chrono::Utc::now(),
};
{
let mut challenges = self.challenges.write().await;
challenges.insert(
cross_method_challenge.id.clone(),
MfaChallenge {
id: cross_method_challenge.id.clone(),
mfa_type: crate::methods::MfaType::Totp, user_id: user_id.to_string(),
expires_at: cross_method_challenge.expires_at,
message: Some("Complete all required authentication methods".to_string()),
data: {
let mut data = HashMap::new();
data.insert(
"cross_method_data".to_string(),
serde_json::to_value(&cross_method_challenge)?,
);
data
},
},
);
}
tracing::info!(
"Step-up authentication initiated with {} methods",
adaptive_methods.len()
);
Ok(cross_method_challenge)
}
pub async fn complete_cross_method_step(
&self,
challenge_id: &str,
method: MfaMethod,
response: &str,
) -> Result<CrossMethodCompletionResult> {
tracing::debug!(
"Completing cross-method step: {:?} for challenge: {}",
method,
challenge_id
);
let mut cross_challenge = self.get_cross_method_challenge(challenge_id).await?;
if cross_challenge.completion_status.get(&method) == Some(&true) {
return Ok(CrossMethodCompletionResult {
method,
success: true,
remaining_methods: self.get_remaining_methods(&cross_challenge),
all_completed: false,
error: Some("Method already completed".to_string()),
});
}
let verification_result = match method {
MfaMethod::Totp => {
self.totp
.verify_code(&cross_challenge.user_id, response)
.await
}
MfaMethod::Sms => {
self.sms
.verify_code(&cross_challenge.user_id, response)
.await
}
MfaMethod::Email => {
self.email
.verify_code(&cross_challenge.user_id, response)
.await
}
MfaMethod::BackupCode => {
self.backup_codes
.verify_code(&cross_challenge.user_id, response)
.await
}
};
let success = verification_result.is_ok();
if success {
cross_challenge
.completion_status
.insert(method.clone(), true);
self.update_cross_method_challenge(&cross_challenge).await?;
tracing::info!("Cross-method step completed successfully: {:?}", method);
} else {
tracing::warn!(
"Cross-method step failed: {:?} - {:?}",
method,
verification_result
);
}
let remaining_methods = self.get_remaining_methods(&cross_challenge);
let all_completed = remaining_methods.is_empty();
if all_completed {
tracing::info!(
"All cross-method authentication steps completed for challenge: {}",
challenge_id
);
self.remove_challenge(challenge_id).await?;
}
Ok(CrossMethodCompletionResult {
method,
success,
remaining_methods,
all_completed,
error: if success {
None
} else {
Some(format!(
"Verification failed: {:?}",
verification_result.unwrap_err()
))
},
})
}
pub async fn get_available_methods(&self, user_id: &str) -> Result<Vec<MfaMethod>> {
tracing::debug!("Getting available MFA methods for user: {}", user_id);
let mut available_methods = Vec::new();
if self.totp.has_totp_secret(user_id).await.unwrap_or(false) {
available_methods.push(MfaMethod::Totp);
}
if self.sms.has_phone_number(user_id).await.unwrap_or(false) {
available_methods.push(MfaMethod::Sms);
}
if self.email.has_email(user_id).await.unwrap_or(false) {
available_methods.push(MfaMethod::Email);
}
if self
.backup_codes
.has_backup_codes(user_id)
.await
.unwrap_or(false)
{
available_methods.push(MfaMethod::BackupCode);
}
tracing::debug!(
"Available methods for user {}: {:?}",
user_id,
available_methods
);
Ok(available_methods)
}
pub async fn perform_method_fallback(
&self,
user_id: &str,
failed_method: MfaMethod,
fallback_order: &[MfaMethod],
) -> Result<MethodFallbackResult> {
tracing::info!(
"Performing method fallback for user: {} after failed method: {:?}",
user_id,
failed_method
);
let available_methods = self.get_available_methods(user_id).await?;
for fallback_method in fallback_order {
if available_methods.contains(fallback_method) && fallback_method != &failed_method {
let fallback_challenge = match fallback_method {
MfaMethod::Totp => self.create_totp_challenge(user_id, "fallback").await?,
MfaMethod::Sms => self.create_sms_challenge(user_id, "fallback").await?,
MfaMethod::Email => self.create_email_challenge(user_id, "fallback").await?,
MfaMethod::BackupCode => MethodChallenge::BackupCode {
challenge_id: "fallback-backup".to_string(),
instructions: "Enter one of your backup codes".to_string(),
},
};
tracing::info!(
"Fallback method activated: {:?} for user: {}",
fallback_method,
user_id
);
return Ok(MethodFallbackResult {
fallback_method: fallback_method.clone(),
challenge: fallback_challenge,
remaining_fallbacks: fallback_order
.iter()
.skip_while(|&m| m != fallback_method)
.skip(1)
.filter(|&m| available_methods.contains(m))
.cloned()
.collect(),
});
}
}
Err(crate::errors::AuthError::validation(
"No fallback methods available",
))
}
async fn adapt_required_methods(
&self,
base_methods: &[MfaMethod],
risk_level: RiskLevel,
) -> Result<Vec<MfaMethod>> {
let mut adapted_methods = base_methods.to_vec();
match risk_level {
RiskLevel::Low => {
adapted_methods.truncate(1);
}
RiskLevel::Medium => {
}
RiskLevel::High => {
if !adapted_methods.contains(&MfaMethod::Email) {
adapted_methods.push(MfaMethod::Email);
}
if !adapted_methods.contains(&MfaMethod::Sms) {
adapted_methods.push(MfaMethod::Sms);
}
}
RiskLevel::Critical => {
adapted_methods = vec![MfaMethod::Totp, MfaMethod::Sms, MfaMethod::Email];
}
}
Ok(adapted_methods)
}
async fn get_cross_method_challenge(&self, challenge_id: &str) -> Result<CrossMethodChallenge> {
let challenges = self.challenges.read().await;
let challenge = challenges
.get(challenge_id)
.ok_or_else(|| crate::errors::AuthError::validation("Challenge not found"))?;
let cross_challenge: CrossMethodChallenge =
if let Some(cross_method_value) = challenge.data.get("cross_method_data") {
serde_json::from_value(cross_method_value.clone())?
} else {
return Err(crate::errors::AuthError::validation(
"Invalid cross-method challenge data",
));
};
Ok(cross_challenge)
}
async fn update_cross_method_challenge(
&self,
cross_challenge: &CrossMethodChallenge,
) -> Result<()> {
let mut challenges = self.challenges.write().await;
if let Some(challenge) = challenges.get_mut(&cross_challenge.id) {
challenge.data.insert(
"cross_method_data".to_string(),
serde_json::to_value(cross_challenge)?,
);
}
Ok(())
}
fn get_remaining_methods(&self, cross_challenge: &CrossMethodChallenge) -> Vec<MfaMethod> {
cross_challenge
.completion_status
.iter()
.filter_map(|(method, &completed)| {
if !completed {
Some(method.clone())
} else {
None
}
})
.collect()
}
async fn create_totp_challenge(
&self,
_user_id: &str,
challenge_prefix: &str,
) -> Result<MethodChallenge> {
Ok(MethodChallenge::Totp {
challenge_id: format!("{}-totp", challenge_prefix),
instructions: "Enter the 6-digit code from your authenticator app".to_string(),
})
}
async fn create_sms_challenge(
&self,
user_id: &str,
challenge_prefix: &str,
) -> Result<MethodChallenge> {
let _code = self.sms.send_verification_code(user_id).await?;
Ok(MethodChallenge::Sms {
challenge_id: format!("{}-sms", challenge_prefix),
instructions: "Enter the verification code sent to your phone".to_string(),
phone_hint: self
.get_phone_hint(user_id)
.await
.unwrap_or_else(|_| "***-***-****".to_string()),
})
}
async fn create_email_challenge(
&self,
user_id: &str,
challenge_prefix: &str,
) -> Result<MethodChallenge> {
let _code = self.email.send_email_code(user_id).await?;
Ok(MethodChallenge::Email {
challenge_id: format!("{}-email", challenge_prefix),
instructions: "Enter the verification code sent to your email".to_string(),
email_hint: self
.get_email_hint(user_id)
.await
.unwrap_or_else(|_| "****@****.com".to_string()),
})
}
async fn get_phone_hint(&self, user_id: &str) -> Result<String> {
Ok(format!("***-***-{}", &user_id[..4]))
}
async fn get_email_hint(&self, user_id: &str) -> Result<String> {
Ok(format!("{}****@****.com", &user_id[..2]))
}
pub async fn emergency_mfa_bypass(&self, user_id: &str, admin_token: &str) -> Result<bool> {
tracing::warn!("Emergency MFA bypass requested for user: {}", user_id);
let admin_key = format!("emergency_admin:{}", admin_token);
if let Some(_admin_data) = self.storage.get_kv(&admin_key).await? {
tracing::info!("Emergency MFA bypass granted for user: {}", user_id);
let bypass_key = format!("mfa_bypass:{}:{}", user_id, chrono::Utc::now().timestamp());
let bypass_data = format!(
"Emergency bypass by admin token at {}",
chrono::Utc::now().to_rfc3339()
);
self.storage
.store_kv(
&bypass_key,
bypass_data.as_bytes(),
Some(std::time::Duration::from_secs(86400)),
)
.await?;
Ok(true)
} else {
tracing::error!("Invalid admin token for emergency MFA bypass");
Ok(false)
}
}
}
/// A second-factor method that can be challenged and verified.
///
/// Used as a `HashMap` key in cross-method challenges, hence the `Eq + Hash` derives.
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub enum MfaMethod {
    /// Time-based one-time password from an authenticator app.
    Totp,
    /// Verification code delivered by SMS.
    Sms,
    /// Verification code delivered by email.
    Email,
    /// One of the user's backup codes.
    BackupCode,
}
/// Assessed risk of a step-up request.
///
/// `MfaManager` scales the required methods with it:
/// - Low keeps only the first requested method.
/// - Medium keeps the request unchanged.
/// - High adds Email and SMS.
/// - Critical requires TOTP, SMS and Email.
///
/// No ordering is derived, so compare with `==` / `match`, not `<`.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum RiskLevel {
    Low,
    Medium,
    High,
    Critical,
}
/// A step-up challenge that requires completing several MFA methods.
///
/// It is serialized into the `data` map of a stored `MfaChallenge`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct CrossMethodChallenge {
    /// Challenge identifier (a UUID v4 string when created by `MfaManager`).
    pub id: String,
    /// User being challenged.
    pub user_id: String,
    /// Methods the user must complete, after risk-based adaptation.
    pub required_methods: Vec<MfaMethod>,
    /// User-facing prompt for each required method.
    pub method_challenges: HashMap<MfaMethod, MethodChallenge>,
    /// Whether each required method has been verified yet (`false` at creation).
    pub completion_status: HashMap<MfaMethod, bool>,
    /// Risk level the challenge was created for.
    pub risk_level: RiskLevel,
    /// Deadline; the challenge is dropped by cleanup once this is not in the future.
    pub expires_at: chrono::DateTime<chrono::Utc>,
    /// Creation time.
    pub created_at: chrono::DateTime<chrono::Utc>,
}
/// User-facing prompt for a single MFA method within a challenge.
///
/// Each variant carries a sub-challenge ID and instruction text.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum MethodChallenge {
    /// Enter a code from an authenticator app.
    Totp {
        challenge_id: String,
        instructions: String,
    },
    /// Enter a code sent by SMS; `phone_hint` is a masked display string.
    Sms {
        challenge_id: String,
        instructions: String,
        phone_hint: String,
    },
    /// Enter a code sent by email; `email_hint` is a masked display string.
    Email {
        challenge_id: String,
        instructions: String,
        email_hint: String,
    },
    /// Enter one of the user's backup codes.
    BackupCode {
        challenge_id: String,
        instructions: String,
    },
}
/// Outcome of attempting one step of a cross-method challenge.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct CrossMethodCompletionResult {
    /// Method that was attempted.
    pub method: MfaMethod,
    /// Whether this step counts as done (also `true` for an already-completed method).
    pub success: bool,
    /// Required methods still outstanding after this step.
    pub remaining_methods: Vec<MfaMethod>,
    /// `true` once no required methods remain.
    pub all_completed: bool,
    /// Human-readable reason when verification failed, or a note for repeats.
    pub error: Option<String>,
}
/// Result of switching a user to a fallback MFA method.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct MethodFallbackResult {
    /// Method chosen as the fallback.
    pub fallback_method: MfaMethod,
    /// Prompt issued for the fallback method.
    pub challenge: MethodChallenge,
    /// Later entries of the caller's fallback order that the user has available.
    pub remaining_fallbacks: Vec<MfaMethod>,
}