auth_framework/api/
mfa.rs1use crate::api::{ApiResponse, ApiState, extract_bearer_token, validate_api_token};
6use axum::{Json, extract::State, http::HeaderMap};
7use serde::{Deserialize, Serialize};
8
9#[derive(Debug, Serialize)]
11pub struct MfaSetupResponse {
12 pub qr_code: String,
13 pub secret: String,
14 pub backup_codes: Vec<String>,
15}
16
17#[derive(Debug, Deserialize)]
19pub struct MfaVerifyRequest {
20 pub totp_code: String,
21}
22
23#[derive(Debug, Deserialize)]
25pub struct MfaDisableRequest {
26 pub password: String,
27 pub totp_code: String,
28}
29
30#[derive(Debug, Serialize)]
32pub struct MfaStatusResponse {
33 pub enabled: bool,
34 pub methods: Vec<String>,
35 pub backup_codes_remaining: u32,
36}
37
38pub async fn setup_mfa(
41 State(state): State<ApiState>,
42 headers: HeaderMap,
43) -> ApiResponse<MfaSetupResponse> {
44 match extract_bearer_token(&headers) {
45 Some(token) => {
46 match validate_api_token(&state.auth_framework, &token).await {
47 Ok(auth_token) => {
48 let secret = "JBSWY3DPEHPK3PXP"; let qr_code = format!(
56 "{}",
57 "example_qr_code_data"
58 );
59
60 let backup_codes = vec![
61 "12345678".to_string(),
62 "87654321".to_string(),
63 "11223344".to_string(),
64 "55667788".to_string(),
65 "99887766".to_string(),
66 ];
67
68 let response = MfaSetupResponse {
69 qr_code,
70 secret: secret.to_string(),
71 backup_codes,
72 };
73
74 tracing::info!("MFA setup initiated for user: {}", auth_token.user_id);
75 ApiResponse::success(response)
76 }
77 Err(_e) => ApiResponse::error_typed("MFA_ERROR", "MFA setup failed"),
78 }
79 }
80 None => ApiResponse::<MfaSetupResponse>::unauthorized_typed(),
81 }
82}
83
84pub async fn verify_mfa(
87 State(state): State<ApiState>,
88 headers: HeaderMap,
89 Json(req): Json<MfaVerifyRequest>,
90) -> ApiResponse<()> {
91 if req.totp_code.is_empty() {
92 return ApiResponse::validation_error("TOTP code is required");
93 }
94
95 if req.totp_code.len() != 6 || !req.totp_code.chars().all(|c| c.is_ascii_digit()) {
96 return ApiResponse::validation_error("TOTP code must be 6 digits");
97 }
98
99 match extract_bearer_token(&headers) {
100 Some(token) => {
101 match validate_api_token(&state.auth_framework, &token).await {
102 Ok(auth_token) => {
103 tracing::info!("MFA verified and enabled for user: {}", auth_token.user_id);
110 ApiResponse::<()>::ok_with_message("MFA enabled successfully")
111 }
112 Err(e) => ApiResponse::<()>::from(e),
113 }
114 }
115 None => ApiResponse::<()>::unauthorized(),
116 }
117}
118
119pub async fn disable_mfa(
122 State(state): State<ApiState>,
123 headers: HeaderMap,
124 Json(req): Json<MfaDisableRequest>,
125) -> ApiResponse<()> {
126 if req.password.is_empty() || req.totp_code.is_empty() {
127 return ApiResponse::validation_error("Password and TOTP code are required");
128 }
129
130 match extract_bearer_token(&headers) {
131 Some(token) => {
132 match validate_api_token(&state.auth_framework, &token).await {
133 Ok(auth_token) => {
134 tracing::info!("MFA disabled for user: {}", auth_token.user_id);
141 ApiResponse::<()>::ok_with_message("MFA disabled successfully")
142 }
143 Err(e) => ApiResponse::<()>::from(e),
144 }
145 }
146 None => ApiResponse::<()>::unauthorized(),
147 }
148}
149
150pub async fn get_mfa_status(
153 State(state): State<ApiState>,
154 headers: HeaderMap,
155) -> ApiResponse<MfaStatusResponse> {
156 match extract_bearer_token(&headers) {
157 Some(token) => {
158 match validate_api_token(&state.auth_framework, &token).await {
159 Ok(_auth_token) => {
160 let mfa_enabled =
162 check_user_mfa_status(&state.auth_framework, &_auth_token.user_id).await;
163 let backup_codes_count =
164 get_backup_codes_count(&state.auth_framework, &_auth_token.user_id).await;
165
166 let status = MfaStatusResponse {
167 enabled: mfa_enabled,
168 methods: if mfa_enabled {
169 vec!["totp".to_string()]
170 } else {
171 vec![]
172 },
173 backup_codes_remaining: backup_codes_count,
174 };
175
176 ApiResponse::success(status)
177 }
178 Err(_e) => ApiResponse::error_typed("MFA_ERROR", "MFA status check failed"),
179 }
180 }
181 None => ApiResponse::<MfaStatusResponse>::unauthorized_typed(),
182 }
183}
184
185pub async fn regenerate_backup_codes(
188 State(state): State<ApiState>,
189 headers: HeaderMap,
190) -> ApiResponse<Vec<String>> {
191 match extract_bearer_token(&headers) {
192 Some(token) => {
193 match validate_api_token(&state.auth_framework, &token).await {
194 Ok(auth_token) => {
195 let new_backup_codes = vec![
202 "98765432".to_string(),
203 "13579246".to_string(),
204 "24681357".to_string(),
205 "86420975".to_string(),
206 "19283746".to_string(),
207 ];
208
209 tracing::info!("Backup codes regenerated for user: {}", auth_token.user_id);
210 ApiResponse::success(new_backup_codes)
211 }
212 Err(_e) => {
213 ApiResponse::error_typed("MFA_ERROR", "MFA backup codes generation failed")
214 }
215 }
216 }
217 None => ApiResponse::<Vec<String>>::unauthorized_typed(),
218 }
219}
220
221#[derive(Debug, Deserialize)]
224pub struct BackupCodeVerifyRequest {
225 pub backup_code: String,
226}
227
228pub async fn verify_backup_code(
229 State(_state): State<ApiState>,
230 Json(req): Json<BackupCodeVerifyRequest>,
231) -> ApiResponse<()> {
232 if req.backup_code.is_empty() {
233 return ApiResponse::validation_error("Backup code is required");
234 }
235
236 tracing::info!("Backup code verification attempted");
242 ApiResponse::<()>::ok_with_message("Backup code verified")
243}
244
245async fn check_user_mfa_status(
247 auth_framework: &std::sync::Arc<crate::AuthFramework>,
248 user_id: &str,
249) -> bool {
250 match auth_framework.get_user_profile(user_id).await {
253 Ok(profile) => {
254 profile
256 .additional_data
257 .get("mfa_enabled")
258 .and_then(|v| v.as_bool())
259 .unwrap_or(false)
260 }
261 Err(_) => false, }
263}
264
265async fn get_backup_codes_count(
266 auth_framework: &std::sync::Arc<crate::AuthFramework>,
267 user_id: &str,
268) -> u32 {
269 match auth_framework.get_user_profile(user_id).await {
272 Ok(profile) => profile
273 .additional_data
274 .get("backup_codes_count")
275 .and_then(|v| v.as_u64())
276 .map(|v| v as u32)
277 .unwrap_or(0),
278 Err(_) => 0, }
280}