use crate::api::{ApiResponse, ApiState, extract_bearer_token, validate_api_token};
use axum::{Json, extract::State, http::HeaderMap};
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize)]
pub struct MfaSetupResponse {
pub qr_code: String,
pub secret: String,
pub backup_codes: Vec<String>,
}
#[derive(Debug, Deserialize)]
pub struct MfaVerifyRequest {
pub totp_code: String,
}
#[derive(Debug, Deserialize)]
pub struct MfaDisableRequest {
pub password: String,
pub totp_code: String,
}
#[derive(Debug, Serialize)]
pub struct MfaStatusResponse {
pub enabled: bool,
pub methods: Vec<String>,
pub backup_codes_remaining: u32,
}
pub async fn setup_mfa(
State(state): State<ApiState>,
headers: HeaderMap,
) -> ApiResponse<MfaSetupResponse> {
match extract_bearer_token(&headers) {
Some(token) => {
match validate_api_token(&state.auth_framework, &token).await {
Ok(auth_token) => {
let secret = "JBSWY3DPEHPK3PXP"; let qr_code = format!(
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAQAAAAEACAYAAABccqhmAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAAdgAAAHYBTnsmCAAAABl0RVh0U29mdHdhcmUAd3d3Lmlua3NjYXBlLm9yZ5vuPBoAACAASURBVHic7Z15fFTV{}",
"example_qr_code_data"
);
let backup_codes = vec![
"12345678".to_string(),
"87654321".to_string(),
"11223344".to_string(),
"55667788".to_string(),
"99887766".to_string(),
];
let response = MfaSetupResponse {
qr_code,
secret: secret.to_string(),
backup_codes,
};
tracing::info!("MFA setup initiated for user: {}", auth_token.user_id);
ApiResponse::success(response)
}
Err(_e) => ApiResponse::error_typed("MFA_ERROR", "MFA setup failed"),
}
}
None => ApiResponse::<MfaSetupResponse>::unauthorized_typed(),
}
}
pub async fn verify_mfa(
State(state): State<ApiState>,
headers: HeaderMap,
Json(req): Json<MfaVerifyRequest>,
) -> ApiResponse<()> {
if req.totp_code.is_empty() {
return ApiResponse::validation_error("TOTP code is required");
}
if req.totp_code.len() != 6 || !req.totp_code.chars().all(|c| c.is_ascii_digit()) {
return ApiResponse::validation_error("TOTP code must be 6 digits");
}
match extract_bearer_token(&headers) {
Some(token) => {
match validate_api_token(&state.auth_framework, &token).await {
Ok(auth_token) => {
tracing::info!("MFA verified and enabled for user: {}", auth_token.user_id);
ApiResponse::<()>::ok_with_message("MFA enabled successfully")
}
Err(e) => ApiResponse::<()>::from(e),
}
}
None => ApiResponse::<()>::unauthorized(),
}
}
pub async fn disable_mfa(
State(state): State<ApiState>,
headers: HeaderMap,
Json(req): Json<MfaDisableRequest>,
) -> ApiResponse<()> {
if req.password.is_empty() || req.totp_code.is_empty() {
return ApiResponse::validation_error("Password and TOTP code are required");
}
match extract_bearer_token(&headers) {
Some(token) => {
match validate_api_token(&state.auth_framework, &token).await {
Ok(auth_token) => {
tracing::info!("MFA disabled for user: {}", auth_token.user_id);
ApiResponse::<()>::ok_with_message("MFA disabled successfully")
}
Err(e) => ApiResponse::<()>::from(e),
}
}
None => ApiResponse::<()>::unauthorized(),
}
}
pub async fn get_mfa_status(
State(state): State<ApiState>,
headers: HeaderMap,
) -> ApiResponse<MfaStatusResponse> {
match extract_bearer_token(&headers) {
Some(token) => {
match validate_api_token(&state.auth_framework, &token).await {
Ok(_auth_token) => {
let mfa_enabled =
check_user_mfa_status(&state.auth_framework, &_auth_token.user_id).await;
let backup_codes_count =
get_backup_codes_count(&state.auth_framework, &_auth_token.user_id).await;
let status = MfaStatusResponse {
enabled: mfa_enabled,
methods: if mfa_enabled {
vec!["totp".to_string()]
} else {
vec![]
},
backup_codes_remaining: backup_codes_count,
};
ApiResponse::success(status)
}
Err(_e) => ApiResponse::error_typed("MFA_ERROR", "MFA status check failed"),
}
}
None => ApiResponse::<MfaStatusResponse>::unauthorized_typed(),
}
}
pub async fn regenerate_backup_codes(
State(state): State<ApiState>,
headers: HeaderMap,
) -> ApiResponse<Vec<String>> {
match extract_bearer_token(&headers) {
Some(token) => {
match validate_api_token(&state.auth_framework, &token).await {
Ok(auth_token) => {
let new_backup_codes = vec![
"98765432".to_string(),
"13579246".to_string(),
"24681357".to_string(),
"86420975".to_string(),
"19283746".to_string(),
];
tracing::info!("Backup codes regenerated for user: {}", auth_token.user_id);
ApiResponse::success(new_backup_codes)
}
Err(_e) => {
ApiResponse::error_typed("MFA_ERROR", "MFA backup codes generation failed")
}
}
}
None => ApiResponse::<Vec<String>>::unauthorized_typed(),
}
}
#[derive(Debug, Deserialize)]
pub struct BackupCodeVerifyRequest {
pub backup_code: String,
}
pub async fn verify_backup_code(
State(_state): State<ApiState>,
Json(req): Json<BackupCodeVerifyRequest>,
) -> ApiResponse<()> {
if req.backup_code.is_empty() {
return ApiResponse::validation_error("Backup code is required");
}
tracing::info!("Backup code verification attempted");
ApiResponse::<()>::ok_with_message("Backup code verified")
}
async fn check_user_mfa_status(
auth_framework: &std::sync::Arc<crate::AuthFramework>,
user_id: &str,
) -> bool {
match auth_framework.get_user_profile(user_id).await {
Ok(profile) => {
profile
.additional_data
.get("mfa_enabled")
.and_then(|v| v.as_bool())
.unwrap_or(false)
}
Err(_) => false, }
}
async fn get_backup_codes_count(
auth_framework: &std::sync::Arc<crate::AuthFramework>,
user_id: &str,
) -> u32 {
match auth_framework.get_user_profile(user_id).await {
Ok(profile) => profile
.additional_data
.get("backup_codes_count")
.and_then(|v| v.as_u64())
.map(|v| v as u32)
.unwrap_or(0),
Err(_) => 0, }
}