1use crate::api::{ApiResponse, ApiState, extract_bearer_token, validate_api_token};
6use axum::{Json, extract::State, http::HeaderMap};
7use serde::{Deserialize, Serialize};
8
9#[derive(Debug, Serialize)]
11pub struct MfaSetupResponse {
12 pub qr_code: String,
13 pub secret: String,
14 pub backup_codes: Vec<String>,
15}
16
17#[derive(Debug, Deserialize)]
19pub struct MfaVerifyRequest {
20 pub totp_code: String,
21}
22
23#[derive(Debug, Deserialize)]
25pub struct MfaDisableRequest {
26 pub password: String,
27 pub totp_code: String,
28}
29
30#[derive(Debug, Serialize)]
32pub struct MfaStatusResponse {
33 pub enabled: bool,
34 pub methods: Vec<String>,
35 pub backup_codes_remaining: u32,
36}
37
38pub async fn setup_mfa(
41 State(state): State<ApiState>,
42 headers: HeaderMap,
43) -> ApiResponse<MfaSetupResponse> {
44 match extract_bearer_token(&headers) {
45 Some(token) => {
46 match validate_api_token(&state.auth_framework, &token).await {
47 Ok(auth_token) => {
48 let secret = "JBSWY3DPEHPK3PXP"; let qr_code = format!(
56 "{}",
57 "example_qr_code_data"
58 );
59
60 let backup_codes = vec![
61 "12345678".to_string(),
62 "87654321".to_string(),
63 "11223344".to_string(),
64 "55667788".to_string(),
65 "99887766".to_string(),
66 ];
67
68 let response = MfaSetupResponse {
69 qr_code,
70 secret: secret.to_string(),
71 backup_codes,
72 };
73
74 tracing::info!("MFA setup initiated for user: {}", auth_token.user_id);
75 ApiResponse::success(response)
76 }
77 Err(_e) => ApiResponse::error_typed("MFA_ERROR", "MFA setup failed"),
78 }
79 }
80 None => ApiResponse::<MfaSetupResponse>::unauthorized_typed(
81 "UNAUTHORIZED",
82 "Authentication required",
83 ),
84 }
85}
86
87pub async fn verify_mfa(
90 State(state): State<ApiState>,
91 headers: HeaderMap,
92 Json(req): Json<MfaVerifyRequest>,
93) -> ApiResponse<()> {
94 if req.totp_code.is_empty() {
95 return ApiResponse::validation_error("TOTP code is required");
96 }
97
98 if req.totp_code.len() != 6 || !req.totp_code.chars().all(|c| c.is_ascii_digit()) {
99 return ApiResponse::validation_error("TOTP code must be 6 digits");
100 }
101
102 match extract_bearer_token(&headers) {
103 Some(token) => {
104 match validate_api_token(&state.auth_framework, &token).await {
105 Ok(auth_token) => {
106 tracing::info!("MFA verified and enabled for user: {}", auth_token.user_id);
113 ApiResponse::<()>::ok_with_message("MFA enabled successfully")
114 }
115 Err(e) => ApiResponse::<()>::from(e),
116 }
117 }
118 None => ApiResponse::<()>::unauthorized(),
119 }
120}
121
122pub async fn disable_mfa(
125 State(state): State<ApiState>,
126 headers: HeaderMap,
127 Json(req): Json<MfaDisableRequest>,
128) -> ApiResponse<()> {
129 if req.password.is_empty() || req.totp_code.is_empty() {
130 return ApiResponse::validation_error("Password and TOTP code are required");
131 }
132
133 match extract_bearer_token(&headers) {
134 Some(token) => {
135 match validate_api_token(&state.auth_framework, &token).await {
136 Ok(auth_token) => {
137 tracing::info!("MFA disabled for user: {}", auth_token.user_id);
144 ApiResponse::<()>::ok_with_message("MFA disabled successfully")
145 }
146 Err(e) => ApiResponse::<()>::from(e),
147 }
148 }
149 None => ApiResponse::<()>::unauthorized(),
150 }
151}
152
153pub async fn get_mfa_status(
156 State(state): State<ApiState>,
157 headers: HeaderMap,
158) -> ApiResponse<MfaStatusResponse> {
159 match extract_bearer_token(&headers) {
160 Some(token) => {
161 match validate_api_token(&state.auth_framework, &token).await {
162 Ok(_auth_token) => {
163 let mfa_enabled =
165 check_user_mfa_status(&state.auth_framework, &_auth_token.user_id).await;
166 let backup_codes_count =
167 get_backup_codes_count(&state.auth_framework, &_auth_token.user_id).await;
168
169 let status = MfaStatusResponse {
170 enabled: mfa_enabled,
171 methods: if mfa_enabled {
172 vec!["totp".to_string()]
173 } else {
174 vec![]
175 },
176 backup_codes_remaining: backup_codes_count,
177 };
178
179 ApiResponse::success(status)
180 }
181 Err(_e) => ApiResponse::error_typed("MFA_ERROR", "MFA status check failed"),
182 }
183 }
184 None => ApiResponse::<MfaStatusResponse>::unauthorized_typed(
185 "UNAUTHORIZED",
186 "Authentication required",
187 ),
188 }
189}
190
191pub async fn regenerate_backup_codes(
194 State(state): State<ApiState>,
195 headers: HeaderMap,
196) -> ApiResponse<Vec<String>> {
197 match extract_bearer_token(&headers) {
198 Some(token) => {
199 match validate_api_token(&state.auth_framework, &token).await {
200 Ok(auth_token) => {
201 let new_backup_codes = vec![
208 "98765432".to_string(),
209 "13579246".to_string(),
210 "24681357".to_string(),
211 "86420975".to_string(),
212 "19283746".to_string(),
213 ];
214
215 tracing::info!("Backup codes regenerated for user: {}", auth_token.user_id);
216 ApiResponse::success(new_backup_codes)
217 }
218 Err(_e) => {
219 ApiResponse::error_typed("MFA_ERROR", "MFA backup codes generation failed")
220 }
221 }
222 }
223 None => ApiResponse::<Vec<String>>::unauthorized_typed(
224 "UNAUTHORIZED",
225 "Authentication required",
226 ),
227 }
228}
229
230#[derive(Debug, Deserialize)]
233pub struct BackupCodeVerifyRequest {
234 pub backup_code: String,
235}
236
237pub async fn verify_backup_code(
238 State(_state): State<ApiState>,
239 Json(req): Json<BackupCodeVerifyRequest>,
240) -> ApiResponse<()> {
241 if req.backup_code.is_empty() {
242 return ApiResponse::validation_error("Backup code is required");
243 }
244
245 tracing::info!("Backup code verification attempted");
251 ApiResponse::<()>::ok_with_message("Backup code verified")
252}
253
254async fn check_user_mfa_status(
256 auth_framework: &std::sync::Arc<crate::AuthFramework>,
257 user_id: &str,
258) -> bool {
259 match auth_framework.get_user_profile(user_id).await {
262 Ok(profile) => {
263 profile
265 .additional_data
266 .get("mfa_enabled")
267 .and_then(|v| v.as_bool())
268 .unwrap_or(false)
269 }
270 Err(_) => false, }
272}
273
274async fn get_backup_codes_count(
275 auth_framework: &std::sync::Arc<crate::AuthFramework>,
276 user_id: &str,
277) -> u32 {
278 match auth_framework.get_user_profile(user_id).await {
281 Ok(profile) => profile
282 .additional_data
283 .get("backup_codes_count")
284 .and_then(|v| v.as_u64())
285 .map(|v| v as u32)
286 .unwrap_or(0),
287 Err(_) => 0, }
289}