1use crate::api::{ApiResponse, ApiState, extract_bearer_token, validate_api_token};
13use axum::{Json, extract::State, http::HeaderMap};
14use base32::Alphabet;
15use rand::Rng as _;
16use serde::{Deserialize, Serialize};
17use sha2::{Digest as _, Sha256};
18use subtle::ConstantTimeEq as _;
19
20#[derive(Debug, Serialize)]
25pub struct MfaSetupResponse {
26 pub qr_code: String,
28 pub secret: String,
30 pub backup_codes: Vec<String>,
32}
33
34#[derive(Debug, Deserialize)]
36pub struct MfaVerifyRequest {
37 pub totp_code: String,
39}
40
41#[derive(Debug, Deserialize)]
43pub struct MfaDisableRequest {
44 pub password: String,
46 pub totp_code: String,
48}
49
50#[derive(Debug, Serialize)]
52pub struct MfaStatusResponse {
53 pub enabled: bool,
55 pub methods: Vec<String>,
57 pub backup_codes_remaining: u32,
59}
60
61fn generate_backup_codes() -> (Vec<String>, Vec<String>) {
68 const ALPHABET: &[u8] = b"ABCDEFGHJKLMNPQRSTUVWXYZ23456789"; let mut plaintext = Vec::with_capacity(10);
70 let mut hashed = Vec::with_capacity(10);
71 let mut buf = [0u8; 8];
72 for _ in 0..10 {
73 rand::rng().fill_bytes(&mut buf);
74 let code: String = buf
75 .iter()
76 .map(|b| ALPHABET[(*b as usize) % ALPHABET.len()] as char)
77 .collect();
78 let hash = hex::encode(Sha256::digest(code.as_bytes()));
79 plaintext.push(code);
80 hashed.push(hash);
81 }
82 (plaintext, hashed)
83}
84
85fn hash_backup_code(code: &str) -> String {
87 hex::encode(Sha256::digest(code.as_bytes()))
88}
89
90fn verify_totp_code(provided: &str, secret_bytes: &[u8], now: u64) -> bool {
94 use subtle::ConstantTimeEq as _;
95 use totp_lite::{Sha1, totp_custom};
96 const STEP: u64 = 30;
97 const DIGITS: u32 = 6;
98
99 let mut matched = false;
103 for offset in [0u64, STEP, STEP.wrapping_neg()] {
104 let t = now.wrapping_add(offset);
105 let expected = totp_custom::<Sha1>(STEP, DIGITS, secret_bytes, t);
106 let eq: bool = expected.as_bytes().ct_eq(provided.as_bytes()).into();
108 matched |= eq;
109 }
110 matched
111}
112
113pub async fn setup_mfa(
125 State(state): State<ApiState>,
126 headers: HeaderMap,
127) -> ApiResponse<MfaSetupResponse> {
128 match extract_bearer_token(&headers) {
129 Some(token) => {
130 match validate_api_token(&state.auth_framework, &token).await {
131 Ok(auth_token) => {
132 let mut secret_bytes = [0u8; 20];
134 rand::rng().fill_bytes(&mut secret_bytes);
135 let secret_b32 =
136 base32::encode(Alphabet::Rfc4648 { padding: false }, &secret_bytes);
137
138 let (plaintext_codes, hashed_codes) = generate_backup_codes();
140
141 let storage = state.auth_framework.storage();
144 let pending_secret_key = format!("mfa_pending_secret:{}", auth_token.user_id);
145 let pending_backup_key =
146 format!("mfa_pending_backup_codes:{}", auth_token.user_id);
147 let ttl = std::time::Duration::from_secs(600);
148
149 if let Err(e) = storage
150 .store_kv(&pending_secret_key, secret_b32.as_bytes(), Some(ttl))
151 .await
152 {
153 tracing::error!("Failed to store pending MFA secret: {}", e);
154 return ApiResponse::error_typed(
155 "MFA_ERROR",
156 "Failed to initiate MFA setup",
157 );
158 }
159
160 let hashed_json =
161 serde_json::to_string(&hashed_codes).unwrap_or_else(|_| "[]".to_string());
162 let _ = storage
163 .store_kv(&pending_backup_key, hashed_json.as_bytes(), Some(ttl))
164 .await;
165
166 let issuer = "AuthFramework";
169 let account = urlencoding::encode(&auth_token.user_id);
170 let qr_code = format!(
171 "otpauth://totp/{issuer}:{account}?secret={secret_b32}&issuer={issuer}&digits=6&period=30"
172 );
173
174 tracing::info!("MFA setup initiated for user: {}", auth_token.user_id);
175 ApiResponse::success(MfaSetupResponse {
176 qr_code,
177 secret: secret_b32,
178 backup_codes: plaintext_codes,
179 })
180 }
181 Err(_e) => ApiResponse::error_typed("MFA_ERROR", "MFA setup failed"),
182 }
183 }
184 None => ApiResponse::<MfaSetupResponse>::unauthorized_typed(),
185 }
186}
187
188pub async fn verify_mfa(
195 State(state): State<ApiState>,
196 headers: HeaderMap,
197 Json(req): Json<MfaVerifyRequest>,
198) -> ApiResponse<()> {
199 if req.totp_code.is_empty() {
200 return ApiResponse::validation_error("TOTP code is required");
201 }
202
203 if req.totp_code.len() != 6 || !req.totp_code.chars().all(|c| c.is_ascii_digit()) {
204 return ApiResponse::validation_error("TOTP code must be 6 digits");
205 }
206
207 match extract_bearer_token(&headers) {
208 Some(token) => {
209 match validate_api_token(&state.auth_framework, &token).await {
210 Ok(auth_token) => {
211 let storage = state.auth_framework.storage();
212 let pending_key = format!("mfa_pending_secret:{}", auth_token.user_id);
213
214 let secret_b32 = match storage.get_kv(&pending_key).await {
216 Ok(Some(data)) => String::from_utf8_lossy(&data).to_string(),
217 _ => {
218 return ApiResponse::error_typed(
219 "MFA_NOT_PENDING",
220 "No pending MFA setup found. Call /mfa/setup first.",
221 );
222 }
223 };
224
225 let secret_bytes =
226 match base32::decode(Alphabet::Rfc4648 { padding: false }, &secret_b32) {
227 Some(b) => b,
228 None => {
229 return ApiResponse::error_typed(
230 "MFA_ERROR",
231 "Invalid stored secret",
232 );
233 }
234 };
235
236 let now = chrono::Utc::now().timestamp() as u64;
238 if !verify_totp_code(&req.totp_code, &secret_bytes, now) {
239 return ApiResponse::error_typed("MFA_INVALID_CODE", "Invalid TOTP code");
240 }
241
242 let active_key = format!("mfa_secret:{}", auth_token.user_id);
244 if let Err(e) = storage
245 .store_kv(&active_key, secret_b32.as_bytes(), None)
246 .await
247 {
248 tracing::error!(
249 "Failed to persist MFA secret for user {}: {}",
250 auth_token.user_id,
251 e
252 );
253 return ApiResponse::error_typed("MFA_ERROR", "Failed to activate MFA");
254 }
255
256 let pending_backup_key =
258 format!("mfa_pending_backup_codes:{}", auth_token.user_id);
259 if let Ok(Some(data)) = storage.get_kv(&pending_backup_key).await {
260 let active_backup_key = format!("mfa_backup_codes:{}", auth_token.user_id);
261 if let Err(e) = storage.store_kv(&active_backup_key, &data, None).await {
262 tracing::warn!("Failed to promote MFA backup codes for user {}: {}", auth_token.user_id, e);
263 }
264 if let Err(e) = storage.delete_kv(&pending_backup_key).await {
265 tracing::warn!("Failed to clean up pending MFA backup codes for user {}: {}", auth_token.user_id, e);
266 }
267 }
268
269 if let Err(e) = storage.delete_kv(&pending_key).await {
271 tracing::warn!("Failed to clean up pending MFA secret for user {}: {}", auth_token.user_id, e);
272 }
273
274 let flag_key = format!("mfa_enabled:{}", auth_token.user_id);
276 if let Err(e) = storage.store_kv(&flag_key, b"true", None).await {
277 tracing::warn!("Failed to set MFA enabled flag for user {}: {}", auth_token.user_id, e);
278 }
279
280 tracing::info!("MFA enabled for user: {}", auth_token.user_id);
281 ApiResponse::<()>::ok_with_message("MFA enabled successfully")
282 }
283 Err(e) => ApiResponse::<()>::from(e),
284 }
285 }
286 None => ApiResponse::<()>::unauthorized(),
287 }
288}
289
290pub async fn disable_mfa(
295 State(state): State<ApiState>,
296 headers: HeaderMap,
297 Json(req): Json<MfaDisableRequest>,
298) -> ApiResponse<()> {
299 if req.password.is_empty() || req.totp_code.is_empty() {
300 return ApiResponse::validation_error("Password and TOTP code are required");
301 }
302
303 match extract_bearer_token(&headers) {
304 Some(token) => {
305 match validate_api_token(&state.auth_framework, &token).await {
306 Ok(auth_token) => {
307 match state
310 .auth_framework
311 .verify_user_password(&auth_token.user_id, &req.password)
312 .await
313 {
314 Ok(true) => {}
315 Ok(false) => {
316 return ApiResponse::error_typed(
317 "MFA_UNAUTHORIZED",
318 "Incorrect password",
319 );
320 }
321 Err(_) => {
322 return ApiResponse::error_typed(
323 "MFA_UNAUTHORIZED",
324 "Password verification failed",
325 );
326 }
327 }
328
329 let storage = state.auth_framework.storage();
330 let active_key = format!("mfa_secret:{}", auth_token.user_id);
331
332 let secret_b32 = match storage.get_kv(&active_key).await {
334 Ok(Some(data)) => String::from_utf8_lossy(&data).to_string(),
335 _ => {
336 return ApiResponse::error_typed(
337 "MFA_NOT_ENABLED",
338 "MFA is not enabled for this account",
339 );
340 }
341 };
342
343 let secret_bytes =
344 match base32::decode(Alphabet::Rfc4648 { padding: false }, &secret_b32) {
345 Some(b) => b,
346 None => {
347 return ApiResponse::error_typed(
348 "MFA_ERROR",
349 "Invalid stored secret",
350 );
351 }
352 };
353
354 let now = chrono::Utc::now().timestamp() as u64;
355 if !verify_totp_code(&req.totp_code, &secret_bytes, now) {
356 return ApiResponse::error_typed("MFA_INVALID_CODE", "Invalid TOTP code");
357 }
358
359 let backup_key = format!("mfa_backup_codes:{}", auth_token.user_id);
361 let flag_key = format!("mfa_enabled:{}", auth_token.user_id);
362
363 if let Err(e) = storage.delete_kv(&active_key).await {
364 tracing::warn!("Failed to delete MFA secret for user {}: {}", auth_token.user_id, e);
365 }
366 if let Err(e) = storage.delete_kv(&backup_key).await {
367 tracing::warn!("Failed to delete MFA backup codes for user {}: {}", auth_token.user_id, e);
368 }
369 if let Err(e) = storage.delete_kv(&flag_key).await {
370 tracing::warn!("Failed to delete MFA enabled flag for user {}: {}", auth_token.user_id, e);
371 }
372
373 tracing::info!("MFA disabled for user: {}", auth_token.user_id);
374 ApiResponse::<()>::ok_with_message("MFA disabled successfully")
375 }
376 Err(e) => ApiResponse::<()>::from(e),
377 }
378 }
379 None => ApiResponse::<()>::unauthorized(),
380 }
381}
382
383pub async fn get_mfa_status(
385 State(state): State<ApiState>,
386 headers: HeaderMap,
387) -> ApiResponse<MfaStatusResponse> {
388 match extract_bearer_token(&headers) {
389 Some(token) => match validate_api_token(&state.auth_framework, &token).await {
390 Ok(auth_token) => {
391 let storage = state.auth_framework.storage();
392 let mfa_enabled = check_mfa_enabled(storage.as_ref(), &auth_token.user_id).await;
393 let backup_codes_remaining =
394 count_backup_codes(storage.as_ref(), &auth_token.user_id).await;
395
396 let status = MfaStatusResponse {
397 enabled: mfa_enabled,
398 methods: if mfa_enabled {
399 vec!["totp".to_string()]
400 } else {
401 vec![]
402 },
403 backup_codes_remaining,
404 };
405
406 ApiResponse::success(status)
407 }
408 Err(_e) => ApiResponse::error_typed("MFA_ERROR", "MFA status check failed"),
409 },
410 None => ApiResponse::<MfaStatusResponse>::unauthorized_typed(),
411 }
412}
413
414pub async fn regenerate_backup_codes(
419 State(state): State<ApiState>,
420 headers: HeaderMap,
421) -> ApiResponse<Vec<String>> {
422 match extract_bearer_token(&headers) {
423 Some(token) => {
424 match validate_api_token(&state.auth_framework, &token).await {
425 Ok(auth_token) => {
426 let storage = state.auth_framework.storage();
427
428 if !check_mfa_enabled(storage.as_ref(), &auth_token.user_id).await {
430 return ApiResponse::error_typed(
431 "MFA_NOT_ENABLED",
432 "MFA is not enabled for this account",
433 );
434 }
435
436 let (plaintext, hashed) = generate_backup_codes();
437 let backup_key = format!("mfa_backup_codes:{}", auth_token.user_id);
438 let hashed_json =
439 serde_json::to_string(&hashed).unwrap_or_else(|_| "[]".to_string());
440
441 if let Err(e) = storage
442 .store_kv(&backup_key, hashed_json.as_bytes(), None)
443 .await
444 {
445 tracing::error!(
446 "Failed to store backup codes for user {}: {}",
447 auth_token.user_id,
448 e
449 );
450 return ApiResponse::error_typed(
451 "MFA_ERROR",
452 "Failed to regenerate backup codes",
453 );
454 }
455
456 tracing::info!("Backup codes regenerated for user: {}", auth_token.user_id);
457 ApiResponse::success(plaintext)
458 }
459 Err(_e) => {
460 ApiResponse::error_typed("MFA_ERROR", "MFA backup codes generation failed")
461 }
462 }
463 }
464 None => ApiResponse::<Vec<String>>::unauthorized_typed(),
465 }
466}
467
468#[derive(Debug, Deserialize)]
474pub struct BackupCodeVerifyRequest {
475 pub backup_code: String,
477}
478
479pub async fn verify_backup_code(
481 State(state): State<ApiState>,
482 headers: HeaderMap,
483 Json(req): Json<BackupCodeVerifyRequest>,
484) -> ApiResponse<()> {
485 if req.backup_code.is_empty() {
486 return ApiResponse::validation_error("Backup code is required");
487 }
488
489 match extract_bearer_token(&headers) {
490 Some(token) => {
491 match validate_api_token(&state.auth_framework, &token).await {
492 Ok(auth_token) => {
493 let storage = state.auth_framework.storage();
494 let backup_key = format!("mfa_backup_codes:{}", auth_token.user_id);
495
496 let codes: Vec<String> = match storage.get_kv(&backup_key).await {
498 Ok(Some(data)) => serde_json::from_slice(&data).unwrap_or_default(),
499 _ => {
500 return ApiResponse::error_typed(
501 "MFA_ERROR",
502 "No backup codes found for this account",
503 );
504 }
505 };
506
507 let provided_hash_hex = hash_backup_code(req.backup_code.trim());
512 let provided_bytes = hex::decode(&provided_hash_hex).unwrap_or_default();
513
514 let mut found_idx: Option<usize> = None;
515 for (i, stored_hex) in codes.iter().enumerate() {
516 let stored_bytes = hex::decode(stored_hex).unwrap_or_default();
517 if stored_bytes.len() == provided_bytes.len()
518 && bool::from(stored_bytes.ct_eq(&provided_bytes))
519 {
520 found_idx = Some(i);
523 }
524 }
525
526 match found_idx {
527 Some(idx) => {
528 let mut remaining = codes;
530 remaining.remove(idx);
531 let updated = serde_json::to_string(&remaining)
532 .unwrap_or_else(|_| "[]".to_string());
533 let _ = storage
534 .store_kv(&backup_key, updated.as_bytes(), None)
535 .await;
536
537 tracing::info!(
538 "Backup code used for user: {}. {} codes remaining.",
539 auth_token.user_id,
540 remaining.len()
541 );
542 ApiResponse::<()>::ok_with_message("Backup code verified")
543 }
544 None => ApiResponse::error_typed(
545 "MFA_INVALID_CODE",
546 "Invalid or already-used backup code",
547 ),
548 }
549 }
550 Err(e) => ApiResponse::<()>::from(e),
551 }
552 }
553 None => ApiResponse::<()>::unauthorized(),
554 }
555}
556
557pub async fn check_user_mfa_status(
565 auth_framework: &std::sync::Arc<crate::AuthFramework>,
566 user_id: &str,
567) -> bool {
568 check_mfa_enabled(auth_framework.storage().as_ref(), user_id).await
569}
570
571async fn check_mfa_enabled(storage: &dyn crate::storage::AuthStorage, user_id: &str) -> bool {
573 let flag_key = format!("mfa_enabled:{}", user_id);
574 matches!(storage.get_kv(&flag_key).await, Ok(Some(_)))
575}
576
577async fn count_backup_codes(storage: &dyn crate::storage::AuthStorage, user_id: &str) -> u32 {
579 let backup_key = format!("mfa_backup_codes:{}", user_id);
580 match storage.get_kv(&backup_key).await {
581 Ok(Some(data)) => serde_json::from_slice::<Vec<String>>(&data)
582 .map(|v| v.len() as u32)
583 .unwrap_or(0),
584 _ => 0,
585 }
586}