use axum::{extract::State, http::HeaderMap, Json};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use zeroize::{Zeroize, ZeroizeOnDrop};
use crate::callback::AuthCallback;
use crate::errors::AppError;
use crate::models::MessageResponse;
use crate::repositories::AuditEventType;
use crate::services::{EmailService, TotpService};
use crate::utils::authenticate;
use crate::AppState;
/// Response body for `setup_mfa`: the freshly generated TOTP secret, the
/// otpauth URI built from it, and the plaintext recovery codes.
///
/// `Debug` is implemented by hand and redacts the sensitive fields. A derived
/// impl would print the secret and every recovery code to any log line that
/// formats this value with `{:?}`.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct MfaSetupResponse {
    pub secret: String,
    pub otpauth_uri: String,
    pub recovery_codes: Vec<String>,
}

impl std::fmt::Debug for MfaSetupResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MfaSetupResponse")
            .field("secret", &"[REDACTED]")
            // Built from the secret (and presumably embeds it), so hide it too.
            .field("otpauth_uri", &"[REDACTED]")
            .field("recovery_codes", &"[REDACTED]")
            .field("recovery_code_count", &self.recovery_codes.len())
            .finish()
    }
}
/// Body of `enable_mfa`: the verification code to check against the TOTP
/// secret stored by `setup_mfa`.
#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct EnableMfaRequest {
    /// Code submitted by the user (JSON key `code`).
    pub code: String,
}
/// Body of `disable_mfa`: the account password, re-entered to confirm.
///
/// The password buffer is wiped on drop (`ZeroizeOnDrop`). `Debug` is
/// implemented by hand so the plaintext password can never reach logs or
/// panic messages via `{:?}`, which a derived impl would allow.
#[derive(Deserialize, Zeroize, ZeroizeOnDrop)]
#[serde(rename_all = "camelCase")]
pub struct DisableMfaRequest {
    pub password: String,
}

impl std::fmt::Debug for DisableMfaRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DisableMfaRequest")
            .field("password", &"[REDACTED]")
            .finish()
    }
}
/// Body returned by `mfa_status`, serialized with camelCase keys
/// (`enabled`, `recoveryCodesRemaining`).
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct MfaStatusResponse {
    /// Whether MFA is currently enabled for the user.
    pub enabled: bool,
    /// Number of recovery codes the repository returns for the user.
    pub recovery_codes_remaining: usize,
}
/// Body of `verify_mfa`: a TOTP code for an account that already has MFA
/// enabled.
#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct VerifyMfaRequest {
    /// Code submitted by the user (JSON key `code`).
    pub code: String,
}
/// Body of `use_recovery_code`: a recovery code submitted in place of a TOTP
/// code.
///
/// Recovery codes are only ever persisted as hashes (see
/// `TotpService::hash_recovery_code`), so the plaintext must not leak through
/// logs either. `Debug` is therefore hand-written to redact it.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RecoveryCodeRequest {
    pub code: String,
}

impl std::fmt::Debug for RecoveryCodeRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RecoveryCodeRequest")
            .field("code", &"[REDACTED]")
            .finish()
    }
}
/// Authenticates the request and returns `(user_id, session_id)`.
///
/// # Errors
///
/// Propagates any error from `authenticate`. If authentication succeeds but
/// carries no session id, returns `AppError::StepUpRequired`.
async fn get_authenticated_session<C: AuthCallback, E: EmailService>(
    state: &Arc<AppState<C, E>>,
    headers: &HeaderMap,
) -> Result<(uuid::Uuid, uuid::Uuid), AppError> {
    let authenticated = authenticate(state, headers).await?;
    match authenticated.session_id {
        Some(session_id) => Ok((authenticated.user_id, session_id)),
        None => Err(AppError::StepUpRequired),
    }
}
/// Starts MFA enrollment for the authenticated user.
///
/// The session must have passed step-up authentication. The handler:
/// 1. generates a new TOTP secret and its otpauth URI;
/// 2. generates a fresh set of recovery codes;
/// 3. stores the secret and the *hashes* of the codes;
/// 4. returns the plaintext secret, URI and codes to the client.
///
/// Writing the audit event is best-effort: a failure is logged at `warn` and
/// does not fail the request.
///
/// # Errors
///
/// - `StepUpRequired` when the request carries no session id. Errors from
///   `require_step_up` are propagated.
/// - `Validation` if MFA is already enabled or the user has no email.
/// - `NotFound` if the user record is missing.
/// - Errors from the TOTP service and repositories are propagated.
pub async fn setup_mfa<C: AuthCallback, E: EmailService>(
    State(state): State<Arc<AppState<C, E>>>,
    headers: HeaderMap,
) -> Result<Json<MfaSetupResponse>, AppError> {
    let (user_id, session_id) = get_authenticated_session(&state, &headers).await?;
    state.step_up_service.require_step_up(session_id).await?;
    if state.totp_repo.has_mfa_enabled(user_id).await? {
        return Err(AppError::Validation("MFA is already enabled".into()));
    }
    let db_user = state
        .user_repo
        .find_by_id(user_id)
        .await?
        .ok_or(AppError::NotFound("User not found".into()))?;
    let email = db_user
        .email
        .ok_or(AppError::Validation("Email required for MFA setup".into()))?;

    // Do every fallible, side-effect-free step (URI building, hashing) before
    // writing anything. An error there must not leave a new secret persisted
    // next to the previous set of recovery codes.
    let secret = state.totp_service.generate_secret();
    let otpauth_uri = state.totp_service.get_otpauth_uri(&secret, &email)?;
    let recovery_codes = state.totp_service.generate_recovery_codes();
    let code_hashes: Vec<String> = recovery_codes
        .iter()
        .map(|c| TotpService::hash_recovery_code(c))
        .collect::<Result<Vec<_>, _>>()?;

    state.totp_repo.upsert_secret(user_id, &secret).await?;
    state
        .totp_repo
        .store_recovery_codes(user_id, code_hashes)
        .await?;

    if let Err(e) = state
        .audit_service
        .log_user_event(AuditEventType::MfaSetupStarted, user_id, Some(&headers))
        .await
    {
        tracing::warn!(error = %e, user_id = %user_id, "Failed to log MFA setup audit event");
    }
    Ok(Json(MfaSetupResponse {
        secret,
        otpauth_uri,
        recovery_codes,
    }))
}
/// Completes MFA enrollment by checking one code against the pending secret.
///
/// The session must have passed step-up authentication. The submitted code is
/// verified against the secret stored by `setup_mfa`. On success, its time
/// step is recorded and MFA is switched on for the user. Writing the audit
/// event is best-effort.
///
/// # Errors
///
/// - `StepUpRequired` when the request carries no session id. Errors from
///   `require_step_up` are propagated.
/// - `Validation` when setup was never started, MFA is already enabled, the
///   user has no email, or the code is rejected.
/// - `NotFound` if the user record is missing.
pub async fn enable_mfa<C: AuthCallback, E: EmailService>(
    State(state): State<Arc<AppState<C, E>>>,
    headers: HeaderMap,
    Json(req): Json<EnableMfaRequest>,
) -> Result<Json<MessageResponse>, AppError> {
    // One rejection for both "does not verify" and "time step not newer",
    // so callers cannot tell the two cases apart.
    let invalid_code = || AppError::Validation("Invalid verification code".into());

    let (user_id, session_id) = get_authenticated_session(&state, &headers).await?;
    state.step_up_service.require_step_up(session_id).await?;

    let Some(pending) = state.totp_repo.find_by_user(user_id).await? else {
        return Err(AppError::Validation("MFA setup not started".into()));
    };
    if pending.enabled {
        return Err(AppError::Validation("MFA is already enabled".into()));
    }

    let Some(user) = state.user_repo.find_by_id(user_id).await? else {
        return Err(AppError::NotFound("User not found".into()));
    };
    let Some(email) = user.email else {
        return Err(AppError::Validation("Email required for MFA".into()));
    };

    let checked = state.totp_service.verify_with_replay_check(
        &pending.secret,
        &req.code,
        &email,
        pending.last_used_time_step,
    )?;
    let Some(time_step) = checked else {
        return Err(invalid_code());
    };

    // Presumably `false` means a step at least this new was already recorded
    // (e.g. by a concurrent request). TODO confirm against the repository.
    let recorded = state
        .totp_repo
        .record_used_time_step_if_newer(user_id, time_step)
        .await?;
    if !recorded {
        return Err(invalid_code());
    }

    state.totp_repo.enable_mfa(user_id).await?;

    if let Err(e) = state
        .audit_service
        .log_user_event(AuditEventType::MfaEnabled, user_id, Some(&headers))
        .await
    {
        tracing::warn!(error = %e, user_id = %user_id, "Failed to log MFA enabled audit event");
    }

    Ok(Json(MessageResponse {
        message: "MFA enabled successfully".into(),
    }))
}
/// Turns MFA off for the authenticated user after re-checking their password.
///
/// Only a session id is required (no step-up); the password check serves as
/// the confirmation. Writing the audit event is best-effort: a failure is
/// logged at `warn` and does not fail the request.
///
/// NOTE(review): failed password checks here are not fed into any attempt
/// limiter visible in this module. Consider rate limiting this endpoint.
///
/// # Errors
///
/// - `StepUpRequired` when the request carries no session id.
/// - `NotFound` if the user record is missing.
/// - `Validation` if the account has no password hash.
/// - `InvalidCredentials` if the password does not match.
pub async fn disable_mfa<C: AuthCallback, E: EmailService>(
    State(state): State<Arc<AppState<C, E>>>,
    headers: HeaderMap,
    Json(mut req): Json<DisableMfaRequest>,
) -> Result<Json<MessageResponse>, AppError> {
    let (user_id, _) = get_authenticated_session(&state, &headers).await?;
    let db_user = state
        .user_repo
        .find_by_id(user_id)
        .await?
        .ok_or(AppError::NotFound("User not found".into()))?;
    let password_hash = db_user.password_hash.ok_or(AppError::Validation(
        "Password required to disable MFA".into(),
    ))?;
    // Hand the request's own buffer to `verify` instead of cloning it, so no
    // second plaintext copy of the password is allocated. The (now empty)
    // field is still zeroized when `req` drops.
    let password = std::mem::take(&mut req.password);
    if !state
        .password_service
        .verify(password, password_hash)
        .await?
    {
        return Err(AppError::InvalidCredentials);
    }
    state.totp_repo.disable_mfa(user_id).await?;
    if let Err(e) = state
        .audit_service
        .log_user_event(AuditEventType::MfaDisabled, user_id, Some(&headers))
        .await
    {
        tracing::warn!(error = %e, user_id = %user_id, "Failed to log MFA disabled audit event");
    }
    Ok(Json(MessageResponse {
        message: "MFA disabled successfully".into(),
    }))
}
/// Reports whether MFA is enabled for the authenticated user and how many
/// recovery codes the repository returns for them.
///
/// Only a session id is required; no step-up check is performed.
pub async fn mfa_status<C: AuthCallback, E: EmailService>(
    State(state): State<Arc<AppState<C, E>>>,
    headers: HeaderMap,
) -> Result<Json<MfaStatusResponse>, AppError> {
    let (user_id, _session) = get_authenticated_session(&state, &headers).await?;
    let enabled = state.totp_repo.has_mfa_enabled(user_id).await?;
    let recovery_codes_remaining = state
        .totp_repo
        .get_recovery_codes(user_id)
        .await?
        .len();
    let status = MfaStatusResponse {
        enabled,
        recovery_codes_remaining,
    };
    Ok(Json(status))
}
/// Verifies a TOTP code for a user with MFA enabled and, on success, marks the
/// session as strongly authenticated.
///
/// Checks go through the per-user MFA attempt limiter:
/// - a lockout is reported before any work is done;
/// - a code is accepted only if it verifies *and* its time step is newly
///   recorded;
/// - every other outcome is recorded as a failed attempt.
///
/// # Errors
///
/// - `TooManyRequests` while locked out, or when this failure triggers a
///   lockout.
/// - `Validation` if MFA is not configured or enabled, or the code is rejected.
/// - `NotFound` if the user record is missing. `Internal` if the user has MFA
///   but no email.
pub async fn verify_mfa<C: AuthCallback, E: EmailService>(
    State(state): State<Arc<AppState<C, E>>>,
    headers: HeaderMap,
    Json(req): Json<VerifyMfaRequest>,
) -> Result<Json<MessageResponse>, AppError> {
    /// Builds the rejection returned when the attempt limiter imposes a wait.
    fn lockout_error(wait_secs: u64) -> AppError {
        let minutes = wait_secs.div_ceil(60);
        AppError::TooManyRequests(format!(
            "Too many failed attempts. Try again in {} minute{}",
            minutes,
            if minutes == 1 { "" } else { "s" }
        ))
    }

    let (user_id, session_id) = get_authenticated_session(&state, &headers).await?;
    if let Err(remaining) = state.mfa_attempt_service.check_allowed(user_id).await {
        return Err(lockout_error(remaining.as_secs()));
    }
    let totp_secret = state
        .totp_repo
        .find_by_user(user_id)
        .await?
        .ok_or(AppError::Validation("MFA not configured".into()))?;
    if !totp_secret.enabled {
        return Err(AppError::Validation("MFA not enabled".into()));
    }
    let db_user = state
        .user_repo
        .find_by_id(user_id)
        .await?
        .ok_or(AppError::NotFound("User not found".into()))?;
    let email = db_user.email.ok_or(AppError::Internal(anyhow::anyhow!(
        "User has MFA but no email"
    )))?;
    let verification_result = state.totp_service.verify_with_replay_check(
        &totp_secret.secret,
        &req.code,
        &email,
        totp_secret.last_used_time_step,
    )?;

    // A code counts only once its time step has also been recorded. A code
    // that verifies but loses that check is a replay and is treated like any
    // other wrong code. In particular, it must not reset the failure counter.
    let accepted = match verification_result {
        Some(time_step) => {
            state
                .totp_repo
                .record_used_time_step_if_newer(user_id, time_step)
                .await?
        }
        None => false,
    };
    if !accepted {
        if let Err(lockout_duration) = state.mfa_attempt_service.record_failed(user_id).await {
            return Err(lockout_error(lockout_duration.as_secs()));
        }
        return Err(AppError::Validation("Invalid verification code".into()));
    }

    state.mfa_attempt_service.record_success(user_id).await;
    state.step_up_service.record_strong_auth(session_id).await?;
    Ok(Json(MessageResponse {
        message: "MFA verification successful".into(),
    }))
}
/// Accepts a recovery code in place of a TOTP code.
///
/// Goes through the per-user MFA attempt limiter: a lockout is reported first,
/// and a code the repository does not accept counts as a failed attempt. When
/// the repository accepts the code, the handler:
/// - reports success to the limiter;
/// - marks the session as strongly authenticated;
/// - writes an audit event on a best-effort basis.
///
/// # Errors
///
/// - `TooManyRequests` while locked out, or when this failure triggers a
///   lockout.
/// - `Validation` if MFA is not enabled or the code is rejected.
pub async fn use_recovery_code<C: AuthCallback, E: EmailService>(
    State(state): State<Arc<AppState<C, E>>>,
    headers: HeaderMap,
    Json(req): Json<RecoveryCodeRequest>,
) -> Result<Json<MessageResponse>, AppError> {
    /// Rejection returned when the attempt limiter imposes a wait.
    fn lockout_error(wait_secs: u64) -> AppError {
        let minutes = wait_secs.div_ceil(60);
        let suffix = if minutes == 1 { "" } else { "s" };
        AppError::TooManyRequests(format!(
            "Too many failed attempts. Try again in {} minute{}",
            minutes, suffix
        ))
    }

    let (user_id, session_id) = get_authenticated_session(&state, &headers).await?;
    if let Err(wait) = state.mfa_attempt_service.check_allowed(user_id).await {
        return Err(lockout_error(wait.as_secs()));
    }

    let mfa_enabled = state.totp_repo.has_mfa_enabled(user_id).await?;
    if !mfa_enabled {
        return Err(AppError::Validation("MFA not enabled".into()));
    }

    let accepted = state
        .totp_repo
        .use_recovery_code(user_id, &req.code)
        .await?;
    if !accepted {
        let rejection = match state.mfa_attempt_service.record_failed(user_id).await {
            Err(wait) => lockout_error(wait.as_secs()),
            Ok(_) => AppError::Validation("Invalid recovery code".into()),
        };
        return Err(rejection);
    }

    state.mfa_attempt_service.record_success(user_id).await;
    state.step_up_service.record_strong_auth(session_id).await?;

    if let Err(e) = state
        .audit_service
        .log_user_event(AuditEventType::MfaRecoveryCodeUsed, user_id, Some(&headers))
        .await
    {
        tracing::warn!(error = %e, user_id = %user_id, "Failed to log MFA recovery code audit event");
    }

    Ok(Json(MessageResponse {
        message: "Recovery code accepted".into(),
    }))
}
/// Response body for `regenerate_recovery_codes`: the newly issued plaintext
/// recovery codes (only their hashes are persisted).
///
/// `Debug` is hand-written to print only the number of codes, so the
/// plaintext codes cannot leak through `{:?}` logging.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct RegenerateRecoveryCodesResponse {
    pub recovery_codes: Vec<String>,
}

impl std::fmt::Debug for RegenerateRecoveryCodesResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RegenerateRecoveryCodesResponse")
            .field("recovery_codes", &"[REDACTED]")
            .field("recovery_code_count", &self.recovery_codes.len())
            .finish()
    }
}
/// Body of `regenerate_recovery_codes`: a TOTP code that must verify before
/// new recovery codes are issued.
#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct RegenerateRecoveryCodesRequest {
    /// Code submitted by the user (JSON key `code`).
    pub code: String,
}
/// Replaces the user's recovery codes after verifying a current TOTP code.
///
/// A successful verification also marks the session as strongly
/// authenticated. The handler then generates fresh codes, stores their hashes
/// and returns the plaintext codes.
///
/// Verification goes through the same per-user MFA attempt limiter as
/// `verify_mfa`:
/// - a lockout is reported before any work is done;
/// - every rejected or replayed code is recorded as a failure;
/// - an accepted code is reported via `record_success`.
///
/// # Errors
///
/// - `TooManyRequests` while locked out, or when this failure triggers a
///   lockout.
/// - `Validation` if MFA is not configured or enabled, or the code is rejected.
/// - `NotFound` if the user record is missing. `Internal` if the user has MFA
///   but no email.
pub async fn regenerate_recovery_codes<C: AuthCallback, E: EmailService>(
    State(state): State<Arc<AppState<C, E>>>,
    headers: HeaderMap,
    Json(req): Json<RegenerateRecoveryCodesRequest>,
) -> Result<Json<RegenerateRecoveryCodesResponse>, AppError> {
    /// Builds the rejection returned when the attempt limiter imposes a wait.
    fn lockout_error(wait_secs: u64) -> AppError {
        let minutes = wait_secs.div_ceil(60);
        AppError::TooManyRequests(format!(
            "Too many failed attempts. Try again in {} minute{}",
            minutes,
            if minutes == 1 { "" } else { "s" }
        ))
    }

    let (user_id, session_id) = get_authenticated_session(&state, &headers).await?;
    // Without the limiter, this endpoint would allow unbounded TOTP guessing.
    // A correct guess grants step-up via `record_strong_auth`, which would
    // bypass the lockout enforced on `verify_mfa`.
    if let Err(remaining) = state.mfa_attempt_service.check_allowed(user_id).await {
        return Err(lockout_error(remaining.as_secs()));
    }
    let totp_secret = state
        .totp_repo
        .find_by_user(user_id)
        .await?
        .ok_or(AppError::Validation("MFA not configured".into()))?;
    if !totp_secret.enabled {
        return Err(AppError::Validation("MFA not enabled".into()));
    }
    let db_user = state
        .user_repo
        .find_by_id(user_id)
        .await?
        .ok_or(AppError::NotFound("User not found".into()))?;
    let email = db_user.email.ok_or(AppError::Internal(anyhow::anyhow!(
        "User has MFA but no email"
    )))?;
    let verification_result = state.totp_service.verify_with_replay_check(
        &totp_secret.secret,
        &req.code,
        &email,
        totp_secret.last_used_time_step,
    )?;
    let accepted = match verification_result {
        Some(time_step) => {
            state
                .totp_repo
                .record_used_time_step_if_newer(user_id, time_step)
                .await?
        }
        None => false,
    };
    if !accepted {
        if let Err(lockout_duration) = state.mfa_attempt_service.record_failed(user_id).await {
            return Err(lockout_error(lockout_duration.as_secs()));
        }
        return Err(AppError::Validation("Invalid verification code".into()));
    }
    state.mfa_attempt_service.record_success(user_id).await;
    state.step_up_service.record_strong_auth(session_id).await?;

    let recovery_codes = state.totp_service.generate_recovery_codes();
    let code_hashes: Vec<String> = recovery_codes
        .iter()
        .map(|c| TotpService::hash_recovery_code(c))
        .collect::<Result<Vec<_>, _>>()?;
    state
        .totp_repo
        .store_recovery_codes(user_id, code_hashes)
        .await?;
    if let Err(e) = state
        .audit_service
        .log_user_event(
            AuditEventType::MfaRecoveryCodesRegenerated,
            user_id,
            Some(&headers),
        )
        .await
    {
        tracing::warn!(error = %e, user_id = %user_id, "Failed to log MFA recovery codes regenerated audit event");
    }
    Ok(Json(RegenerateRecoveryCodesResponse { recovery_codes }))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `code` is read from a plain JSON object.
    #[test]
    fn test_enable_mfa_request_deserialize() {
        let parsed = serde_json::from_str::<EnableMfaRequest>(r#"{"code": "123456"}"#).unwrap();
        assert_eq!(parsed.code, "123456");
    }

    /// `password` is read from a plain JSON object.
    #[test]
    fn test_disable_mfa_request_deserialize() {
        let parsed =
            serde_json::from_str::<DisableMfaRequest>(r#"{"password": "mypassword"}"#).unwrap();
        assert_eq!(parsed.password, "mypassword");
    }

    /// Field names are serialized in camelCase.
    #[test]
    fn test_mfa_status_response_serialize() {
        let serialized = serde_json::to_string(&MfaStatusResponse {
            enabled: true,
            recovery_codes_remaining: 8,
        })
        .unwrap();
        assert!(serialized.contains("\"enabled\":true"));
        assert!(serialized.contains("\"recoveryCodesRemaining\":8"));
    }
}