1use std::sync::Arc;
2
3use axum::Form;
4use axum::Router;
5use axum::extract::{Extension, Query};
6use axum::http::HeaderMap;
7use axum::http::StatusCode;
8use axum::http::Uri;
9use axum::http::header::{LOCATION, SET_COOKIE, USER_AGENT};
10use axum::response::{IntoResponse, Response};
11use axum::routing::{get, post};
12use axum_htmx::{HxBoosted, HxRequest};
13use chrono::Utc;
14use serde::Deserialize;
15
16use allowthem_core::applications::BrandingConfig;
17use allowthem_core::totp::totp_uri;
18use allowthem_core::{AllowThem, AuditEvent, AuthError, sessions};
19use qrcode::QrCode;
20use qrcode::render::svg;
21
22use crate::auth_views::{
23 MfaChallengeView, MfaRecoveryView, MfaSetupView, mfa_challenge_fragment, mfa_challenge_page,
24 mfa_recovery_fragment, mfa_recovery_page, mfa_setup_fragment, mfa_setup_page,
25};
26use crate::branding::{DefaultBranding, default_branding_ref, resolve_branding};
27use crate::browser_error::BrowserError;
28use crate::csrf::CsrfToken;
29use crate::error::BrowserAuthRedirect;
30
/// Error shown on the MFA setup screen when `enable_mfa` rejects the confirmation code
/// (`AuthError::InvalidTotpCode`).
const SETUP_INVALID_CODE: &str = "Invalid TOTP code";

/// Error shown on the MFA challenge screen when a submitted TOTP code fails verification.
// NOTE(review): the wording also mentions recovery codes although it is only shown on the
// TOTP path (recovery failures use `CHALLENGE_INVALID_RECOVERY`) — confirm this is intended.
const CHALLENGE_INVALID_TOTP: &str = "Invalid TOTP or recovery code";

/// Error shown on the MFA challenge screen when a submitted recovery code fails verification.
const CHALLENGE_INVALID_RECOVERY: &str = "Invalid recovery code";
39
/// Per-router settings handed to the MFA handlers through an `Extension` layer.
#[derive(Clone)]
struct MfaPageConfig {
    /// Copied into every rendered view's `is_production` field.
    is_production: bool,
    /// Public base URL; the setup handlers derive the TOTP issuer from it.
    /// `mfa_challenge_routes` leaves it empty because the challenge handlers never read it.
    base_url: String,
}
45
46fn client_ip(headers: &HeaderMap) -> Option<String> {
51 headers
52 .get("x-forwarded-for")
53 .and_then(|v| v.to_str().ok())
54 .and_then(|s| s.split(',').next())
55 .map(|s| s.trim().to_string())
56}
57
58fn qr_data_uri(text: &str) -> String {
64 let code = match QrCode::new(text.as_bytes()) {
65 Ok(c) => c,
66 Err(_) => return String::new(),
67 };
68 let svg_str = code
69 .render()
70 .min_dimensions(200, 200)
71 .dark_color(svg::Color("#000000"))
72 .light_color(svg::Color("#ffffff"))
73 .build();
74 let encoded = svg_str
76 .replace('#', "%23")
77 .replace('<', "%3C")
78 .replace('>', "%3E")
79 .replace('"', "'");
80 format!("data:image/svg+xml,{encoded}")
81}
82
/// Derives the TOTP issuer label from the public base URL.
///
/// The issuer is the bare host. The following are removed:
/// - a leading `http://` or `https://` (matched case-insensitively, stripped once);
/// - any `user:pass@` userinfo;
/// - the port;
/// - everything from the first `/`, `?` or `#` onwards.
///
/// A bracketed IPv6 literal such as `[::1]:3100` yields the address without brackets. When no
/// host can be extracted (for example an empty base URL), `"allowthem"` is returned.
///
/// NOTE(review): an IPv6 issuer contains `:`, which the otpauth label uses as the
/// issuer/account separator. Confirm that `totp_uri` escapes it.
fn derive_issuer(base_url: &str) -> String {
    const FALLBACK: &str = "allowthem";

    let url = base_url.trim();
    let without_scheme = ["https://", "http://"]
        .into_iter()
        .find_map(|scheme| {
            url.get(..scheme.len())
                .filter(|prefix| prefix.eq_ignore_ascii_case(scheme))
                .map(|_| &url[scheme.len()..])
        })
        .unwrap_or(url);

    // The authority component ends at the first path, query, or fragment delimiter.
    let authority = without_scheme
        .split(['/', '?', '#'])
        .next()
        .unwrap_or_default();
    // The host follows the last '@' when userinfo is present.
    let host_port = authority.rsplit('@').next().unwrap_or_default();
    let host = match host_port.strip_prefix('[') {
        // IPv6 literal: the host runs up to the closing bracket. Its colons are not a port.
        Some(bracketed) => bracketed.split(']').next().unwrap_or_default(),
        None => host_port.split(':').next().unwrap_or_default(),
    };

    if host.is_empty() {
        FALLBACK.to_string()
    } else {
        host.to_string()
    }
}
99
100async fn require_browser_user(
106 ath: &AllowThem,
107 headers: &HeaderMap,
108 path: &str,
109) -> Result<allowthem_core::types::User, Response> {
110 let cookie_header = headers
111 .get(axum::http::header::COOKIE)
112 .and_then(|v| v.to_str().ok())
113 .ok_or_else(|| BrowserAuthRedirect::new(path).into_response())?;
114
115 let token = ath
116 .parse_session_cookie(cookie_header)
117 .ok_or_else(|| BrowserAuthRedirect::new(path).into_response())?;
118
119 let ttl = ath.session_config().ttl;
120 let session = ath
121 .db()
122 .validate_session(&token, ttl)
123 .await
124 .map_err(|err| {
125 tracing::error!("session validation error: {err}");
126 BrowserAuthRedirect::new(path).into_response()
127 })?
128 .ok_or_else(|| BrowserAuthRedirect::new(path).into_response())?;
129
130 match ath.db().get_user(session.user_id).await {
131 Ok(user) if user.is_active => Ok(user),
132 Ok(_) => Err(BrowserAuthRedirect::new(path).into_response()),
133 Err(AuthError::NotFound) => Err(BrowserAuthRedirect::new(path).into_response()),
134 Err(err) => {
135 tracing::error!("user lookup error: {err}");
136 Err(BrowserAuthRedirect::new(path).into_response())
137 }
138 }
139}
140
141fn render_mfa_setup_fragment(
153 config: &MfaPageConfig,
154 csrf_token: &str,
155 totp_uri: &str,
156 qr_data_uri: &str,
157 secret: &str,
158 error: &str,
159 branding: Option<&BrandingConfig>,
160) -> Result<axum::response::Html<String>, BrowserError> {
161 mfa_setup_fragment(&MfaSetupView {
162 csrf_token,
163 totp_uri,
164 qr_data_uri,
165 secret,
166 error,
167 branding,
168 is_production: config.is_production,
169 })
170}
171
172fn render_mfa_recovery_fragment(
175 config: &MfaPageConfig,
176 recovery_codes: &[String],
177 branding: Option<&BrandingConfig>,
178) -> Result<axum::response::Html<String>, BrowserError> {
179 mfa_recovery_fragment(&MfaRecoveryView {
180 recovery_codes,
181 branding,
182 is_production: config.is_production,
183 })
184}
185
186#[allow(clippy::too_many_arguments)]
191async fn get_mfa_setup(
192 Extension(ath): Extension<AllowThem>,
193 Extension(config): Extension<MfaPageConfig>,
194 default_branding: Option<Extension<Arc<DefaultBranding>>>,
195 uri: Uri,
196 csrf: CsrfToken,
197 headers: HeaderMap,
198 HxBoosted(boosted): HxBoosted,
199 HxRequest(request): HxRequest,
200) -> Result<Response, BrowserError> {
201 let user = match require_browser_user(&ath, &headers, uri.path()).await {
202 Ok(u) => u,
203 Err(redirect) => return Ok(redirect),
204 };
205
206 let default = default_branding_ref(&default_branding);
207 let branding = resolve_branding(&ath, None, default).await;
208
209 let secret = match ath.get_pending_mfa_secret(user.id).await? {
211 Some(s) => s,
212 None => ath.create_mfa_secret(user.id).await?,
213 };
214
215 let issuer = derive_issuer(&config.base_url);
216 let uri = totp_uri(&secret, user.email.as_str(), &issuer);
217 let qr = qr_data_uri(&uri);
218
219 if request && !boosted {
220 let html = render_mfa_setup_fragment(
221 &config,
222 csrf.as_str(),
223 &uri,
224 &qr,
225 &secret,
226 "",
227 branding.as_ref(),
228 )?;
229 return Ok(html.into_response());
230 }
231
232 let html = mfa_setup_page(&MfaSetupView {
233 csrf_token: csrf.as_str(),
234 totp_uri: &uri,
235 qr_data_uri: &qr,
236 secret: &secret,
237 error: "",
238 branding: branding.as_ref(),
239 is_production: config.is_production,
240 })?;
241 Ok(html.into_response())
242}
243
/// Form body for `POST /settings/mfa/confirm`.
#[derive(Deserialize)]
pub struct MfaConfirmForm {
    /// Code the user typed from their authenticator app; passed to `enable_mfa`.
    code: String,
    // Never read by the handler. Declaring it as a required `String` makes form
    // deserialization reject bodies that omit `csrf_token`; `#[allow(dead_code)]` silences the
    // unused-field warning.
    // NOTE(review): the value is presumably checked by `crate::csrf::csrf_middleware` — confirm.
    #[allow(dead_code)]
    csrf_token: String,
}
250
251#[allow(clippy::too_many_arguments)]
256async fn post_mfa_confirm(
257 Extension(ath): Extension<AllowThem>,
258 Extension(config): Extension<MfaPageConfig>,
259 default_branding: Option<Extension<Arc<DefaultBranding>>>,
260 uri: Uri,
261 csrf: CsrfToken,
262 headers: HeaderMap,
263 HxBoosted(boosted): HxBoosted,
264 HxRequest(request): HxRequest,
265 Form(form): Form<MfaConfirmForm>,
266) -> Result<Response, BrowserError> {
267 let user = match require_browser_user(&ath, &headers, uri.path()).await {
268 Ok(u) => u,
269 Err(redirect) => return Ok(redirect),
270 };
271
272 let default = default_branding_ref(&default_branding);
273 let branding = resolve_branding(&ath, None, default).await;
274
275 let ip = client_ip(&headers);
276 let ua = headers.get(USER_AGENT).and_then(|v| v.to_str().ok());
277
278 match ath.enable_mfa(user.id, &form.code).await {
279 Ok(recovery_codes) => {
280 let _ = ath
281 .db()
282 .log_audit(
283 AuditEvent::MfaEnabled,
284 Some(&user.id),
285 None,
286 ip.as_deref(),
287 ua,
288 None,
289 )
290 .await;
291
292 if request && !boosted {
293 let html =
294 render_mfa_recovery_fragment(&config, &recovery_codes, branding.as_ref())?;
295 return Ok(html.into_response());
296 }
297
298 let html = mfa_recovery_page(&MfaRecoveryView {
299 recovery_codes: &recovery_codes,
300 branding: branding.as_ref(),
301 is_production: config.is_production,
302 })?;
303 Ok(html.into_response())
304 }
305 Err(allowthem_core::AuthError::InvalidTotpCode) => {
306 let secret = ath
308 .get_pending_mfa_secret(user.id)
309 .await?
310 .unwrap_or_default();
311 let issuer = derive_issuer(&config.base_url);
312 let uri = totp_uri(&secret, user.email.as_str(), &issuer);
313 let qr = qr_data_uri(&uri);
314
315 let html = if request && !boosted {
316 render_mfa_setup_fragment(
317 &config,
318 csrf.as_str(),
319 &uri,
320 &qr,
321 &secret,
322 SETUP_INVALID_CODE,
323 branding.as_ref(),
324 )?
325 } else {
326 mfa_setup_page(&MfaSetupView {
327 csrf_token: csrf.as_str(),
328 totp_uri: &uri,
329 qr_data_uri: &qr,
330 secret: &secret,
331 error: SETUP_INVALID_CODE,
332 branding: branding.as_ref(),
333 is_production: config.is_production,
334 })?
335 };
336 Ok(html.into_response())
337 }
338 Err(e) => Err(BrowserError::Auth(e)),
339 }
340}
341
/// Form body for `POST /settings/mfa/disable`.
#[derive(Deserialize)]
pub struct MfaDisableForm {
    // Never read by the handler. Declaring it as a required `String` makes form
    // deserialization reject bodies that omit `csrf_token`; `#[allow(dead_code)]` silences the
    // unused-field warning.
    // NOTE(review): the value is presumably checked by `crate::csrf::csrf_middleware` — confirm.
    #[allow(dead_code)]
    csrf_token: String,
}
347
348async fn post_mfa_disable(
350 Extension(ath): Extension<AllowThem>,
351 uri: Uri,
352 headers: HeaderMap,
353 Form(_form): Form<MfaDisableForm>,
354) -> Result<Response, BrowserError> {
355 let user = match require_browser_user(&ath, &headers, uri.path()).await {
356 Ok(u) => u,
357 Err(redirect) => return Ok(redirect),
358 };
359
360 let ip = client_ip(&headers);
361 let ua = headers.get(USER_AGENT).and_then(|v| v.to_str().ok());
362
363 ath.disable_mfa(user.id).await?;
364
365 let _ = ath
366 .db()
367 .log_audit(
368 AuditEvent::MfaDisabled,
369 Some(&user.id),
370 None,
371 ip.as_deref(),
372 ua,
373 None,
374 )
375 .await;
376
377 Ok((StatusCode::SEE_OTHER, [(LOCATION, "/settings".to_string())]).into_response())
378}
379
/// Form body for `POST /settings/mfa/recovery-codes/regenerate`.
#[derive(Deserialize)]
struct RegenerateCodesForm {
    // Never read by the handler. Declaring it as a required `String` makes form
    // deserialization reject bodies that omit `csrf_token`; `#[allow(dead_code)]` silences the
    // unused-field warning.
    // NOTE(review): the value is presumably checked by `crate::csrf::csrf_middleware` — confirm.
    #[allow(dead_code)]
    csrf_token: String,
}
385
386#[allow(clippy::too_many_arguments)]
388async fn post_regenerate_recovery_codes(
389 Extension(ath): Extension<AllowThem>,
390 Extension(config): Extension<MfaPageConfig>,
391 default_branding: Option<Extension<Arc<DefaultBranding>>>,
392 uri: Uri,
393 headers: HeaderMap,
394 HxBoosted(boosted): HxBoosted,
395 HxRequest(request): HxRequest,
396 Form(_form): Form<RegenerateCodesForm>,
397) -> Result<Response, BrowserError> {
398 let user = match require_browser_user(&ath, &headers, uri.path()).await {
399 Ok(u) => u,
400 Err(redirect) => return Ok(redirect),
401 };
402
403 let has_mfa = ath.db().has_mfa_enabled(user.id).await?;
404 if !has_mfa {
405 return Ok((StatusCode::SEE_OTHER, [(LOCATION, "/settings".to_string())]).into_response());
406 }
407
408 let recovery_codes = ath.regenerate_recovery_codes(user.id).await?;
409
410 let default = default_branding_ref(&default_branding);
411 let branding = resolve_branding(&ath, None, default).await;
412
413 if request && !boosted {
414 let html = render_mfa_recovery_fragment(&config, &recovery_codes, branding.as_ref())?;
415 return Ok(html.into_response());
416 }
417
418 let html = mfa_recovery_page(&MfaRecoveryView {
419 recovery_codes: &recovery_codes,
420 branding: branding.as_ref(),
421 is_production: config.is_production,
422 })?;
423 Ok(html.into_response())
424}
425
/// Query string for `GET /mfa/challenge`.
#[derive(Deserialize)]
pub struct ChallengeQuery {
    /// MFA challenge token; checked with `validate_mfa_challenge` and embedded in the form.
    token: String,
}
434
435fn render_mfa_challenge_fragment(
444 config: &MfaPageConfig,
445 mfa_token: &str,
446 error: &str,
447 branding: Option<&BrandingConfig>,
448) -> Result<axum::response::Html<String>, BrowserError> {
449 mfa_challenge_fragment(&MfaChallengeView {
450 mfa_token,
451 error,
452 branding,
453 is_production: config.is_production,
454 })
455}
456
457async fn get_mfa_challenge(
459 Extension(ath): Extension<AllowThem>,
460 Extension(config): Extension<MfaPageConfig>,
461 default_branding: Option<Extension<Arc<DefaultBranding>>>,
462 Query(query): Query<ChallengeQuery>,
463 HxBoosted(boosted): HxBoosted,
464 HxRequest(request): HxRequest,
465) -> Result<Response, BrowserError> {
466 let user_id = ath.db().validate_mfa_challenge(&query.token).await?;
468 if user_id.is_none() {
469 return Ok((StatusCode::SEE_OTHER, [(LOCATION, "/login".to_string())]).into_response());
471 }
472
473 let default = default_branding_ref(&default_branding);
474 let branding = resolve_branding(&ath, None, default).await;
475
476 if request && !boosted {
477 let html = render_mfa_challenge_fragment(&config, &query.token, "", branding.as_ref())?;
478 return Ok(html.into_response());
479 }
480
481 let html = mfa_challenge_page(&MfaChallengeView {
482 mfa_token: &query.token,
483 error: "",
484 branding: branding.as_ref(),
485 is_production: config.is_production,
486 })?;
487 Ok(html.into_response())
488}
489
/// Form body for `POST /mfa/challenge`.
///
/// There is no `csrf_token` field. The handler only ties the submission to a login through
/// `mfa_token`.
/// NOTE(review): confirm the challenge token is the intended CSRF defence for this route.
#[derive(Deserialize)]
pub struct MfaChallengeForm {
    /// Challenge token issued for the pending login; must pass `validate_mfa_challenge`.
    mfa_token: String,
    /// TOTP code, used when `use_recovery` is absent. A missing value is treated as `""`.
    #[serde(default)]
    code: Option<String>,
    /// Recovery code, used when `use_recovery` is present. A missing value is treated as `""`.
    #[serde(default)]
    recovery_code: Option<String>,
    /// Only its presence matters: when set, the recovery code is verified instead of the TOTP.
    #[serde(default)]
    use_recovery: Option<String>,
}
500
501async fn post_mfa_challenge(
503 Extension(ath): Extension<AllowThem>,
504 Extension(config): Extension<MfaPageConfig>,
505 default_branding: Option<Extension<Arc<DefaultBranding>>>,
506 headers: HeaderMap,
507 Form(form): Form<MfaChallengeForm>,
508) -> Result<Response, BrowserError> {
509 let default = default_branding_ref(&default_branding);
510 let branding = resolve_branding(&ath, None, default).await;
511 let ip = headers
512 .get("x-forwarded-for")
513 .and_then(|v| v.to_str().ok())
514 .and_then(|s| s.split(',').next())
515 .map(|s| s.trim().to_string());
516 let ua = headers.get(USER_AGENT).and_then(|v| v.to_str().ok());
517
518 let user_id = match ath.db().validate_mfa_challenge(&form.mfa_token).await? {
520 Some(uid) => uid,
521 None => {
522 return Ok((StatusCode::SEE_OTHER, [(LOCATION, "/login".to_string())]).into_response());
523 }
524 };
525
526 let use_recovery = form.use_recovery.is_some();
528 let verified = if use_recovery {
529 let code = form.recovery_code.as_deref().unwrap_or("");
530 ath.verify_recovery_code(user_id, code).await?
531 } else {
532 let code = form.code.as_deref().unwrap_or("");
533 ath.verify_totp(user_id, code).await?
534 };
535
536 if !verified {
537 let _ = ath
539 .db()
540 .log_audit(
541 AuditEvent::MfaChallengeFailed,
542 Some(&user_id),
543 None,
544 ip.as_deref(),
545 ua,
546 None,
547 )
548 .await;
549
550 let error_msg = if use_recovery {
551 CHALLENGE_INVALID_RECOVERY
552 } else {
553 CHALLENGE_INVALID_TOTP
554 };
555
556 let html = mfa_challenge_page(&MfaChallengeView {
557 mfa_token: &form.mfa_token,
558 error: error_msg,
559 branding: branding.as_ref(),
560 is_production: config.is_production,
561 })?;
562 return Ok(html.into_response());
563 }
564
565 ath.db().consume_mfa_challenge(&form.mfa_token).await?;
567
568 let _ = ath
569 .db()
570 .log_audit(
571 AuditEvent::MfaChallengeSuccess,
572 Some(&user_id),
573 None,
574 ip.as_deref(),
575 ua,
576 None,
577 )
578 .await;
579
580 let _ = ath
583 .db()
584 .log_audit(
585 AuditEvent::Login,
586 Some(&user_id),
587 None,
588 ip.as_deref(),
589 ua,
590 None,
591 )
592 .await;
593
594 let token = sessions::generate_token();
595 let token_hash = sessions::hash_token(&token);
596 let ttl = ath.session_config().ttl;
597 let expires_at = Utc::now() + ttl;
598 ath.db()
599 .create_session(user_id, token_hash, ip.as_deref(), ua, expires_at)
600 .await?;
601
602 ath.notify_user_active(user_id);
603 ath.emit_event(allowthem_core::AuthEvent::new(
604 "session.created",
605 Some(user_id),
606 serde_json::json!({ "user_id": user_id }),
607 ))
608 .await;
609
610 let cookie = ath.session_cookie(&token);
611
612 Ok((
613 StatusCode::SEE_OTHER,
614 [(SET_COOKIE, cookie), (LOCATION, "/".to_string())],
615 )
616 .into_response())
617}
618
619pub fn mfa_setup_routes(is_production: bool, base_url: String) -> Router<()> {
630 let cfg = MfaPageConfig {
631 is_production,
632 base_url,
633 };
634 Router::new()
635 .route("/settings/mfa/setup", get(get_mfa_setup))
636 .route("/settings/mfa/confirm", post(post_mfa_confirm))
637 .route("/settings/mfa/disable", post(post_mfa_disable))
638 .route(
639 "/settings/mfa/recovery-codes/regenerate",
640 post(post_regenerate_recovery_codes),
641 )
642 .layer(Extension(cfg))
643}
644
645pub fn mfa_challenge_routes(is_production: bool) -> Router<()> {
651 let cfg = MfaPageConfig {
652 is_production,
653 base_url: String::new(),
654 };
655 Router::new()
656 .route(
657 "/mfa/challenge",
658 get(get_mfa_challenge).post(post_mfa_challenge),
659 )
660 .layer(Extension(cfg))
661}
662
663#[cfg(test)]
664mod tests {
665 use super::*;
666
667 use axum::body::Body;
668 use axum::http::{Request, StatusCode, header};
669 use chrono::{Duration, Utc};
670 use totp_rs::{Algorithm, Secret, TOTP};
671 use tower::ServiceExt;
672
673 use allowthem_core::{AllowThemBuilder, Email, generate_token, hash_token};
674
675 const TEST_MFA_KEY: [u8; 32] = [0x42; 32];
676
677 async fn setup() -> AllowThem {
682 AllowThemBuilder::new("sqlite::memory:")
683 .cookie_secure(false)
684 .mfa_key(TEST_MFA_KEY)
685 .csrf_key(*b"test-csrf-key-for-binary-tests!!")
686 .build()
687 .await
688 .unwrap()
689 }
690
691 fn test_app(ath: AllowThem) -> Router {
694 Router::new()
695 .merge(mfa_setup_routes(false, "http://127.0.0.1:3100".into()))
696 .layer(axum::middleware::from_fn(crate::csrf::csrf_middleware))
697 .merge(mfa_challenge_routes(false))
698 .layer(axum::middleware::from_fn_with_state(
699 ath.clone(),
700 crate::cors::inject_ath_into_extensions,
701 ))
702 }
703
704 async fn create_session(ath: &AllowThem) -> (allowthem_core::types::UserId, String) {
705 let email = Email::new("mfa-test@example.com".into()).unwrap();
706 let user = ath
707 .db()
708 .create_user(email, "pass", None, None)
709 .await
710 .unwrap();
711 let token = generate_token();
712 let token_hash = hash_token(&token);
713 let expires = Utc::now() + Duration::hours(24);
714 ath.db()
715 .create_session(user.id, token_hash, None, None, expires)
716 .await
717 .unwrap();
718 let cookie = ath.session_cookie(&token);
719 let cookie_val = cookie.split(';').next().unwrap().to_string();
720 (user.id, cookie_val)
721 }
722
723 async fn get_csrf(app: &Router, session_cookie: &str) -> String {
725 let req = Request::builder()
726 .uri("/settings/mfa/setup")
727 .header(header::COOKIE, session_cookie)
728 .body(Body::empty())
729 .unwrap();
730 let resp = app.clone().oneshot(req).await.unwrap();
731 let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
732 .await
733 .unwrap();
734 let html = String::from_utf8(bytes.to_vec()).unwrap();
735 let marker = "name=\"csrf_token\" value=\"";
736 let start = html.find(marker).expect("csrf_token not found in HTML") + marker.len();
737 let end = html[start..].find('"').unwrap() + start;
738 html[start..end].to_string()
739 }
740
741 async fn enable_mfa_for_user(
743 ath: &AllowThem,
744 user_id: allowthem_core::types::UserId,
745 ) -> (TOTP, Vec<String>) {
746 let secret_b32 = ath.create_mfa_secret(user_id).await.unwrap();
747 let totp = TOTP::new(
748 Algorithm::SHA1,
749 6,
750 1,
751 30,
752 Secret::Encoded(secret_b32).to_bytes().unwrap(),
753 None,
754 String::new(),
755 )
756 .unwrap();
757 let code = totp.generate_current().unwrap();
758 let recovery_codes = ath.enable_mfa(user_id, &code).await.unwrap();
759 (totp, recovery_codes)
760 }
761
762 #[test]
767 fn qr_data_uri_produces_svg_data_uri() {
768 let uri = qr_data_uri("otpauth://totp/test?secret=ABC&issuer=test");
769 assert!(
770 uri.starts_with("data:image/svg+xml,"),
771 "must produce an SVG data URI"
772 );
773 assert!(uri.contains("svg"), "must contain SVG content");
774 assert!(
777 !uri.contains('&'),
778 "data URI must not contain raw '&' characters"
779 );
780 }
781
782 #[test]
783 fn qr_data_uri_empty_input_still_works() {
784 let uri = qr_data_uri("");
785 assert!(uri.starts_with("data:image/svg+xml,"));
787 }
788
789 #[test]
794 fn derive_issuer_strips_http_scheme() {
795 assert_eq!(derive_issuer("http://example.com"), "example.com");
796 }
797
798 #[test]
799 fn derive_issuer_strips_https_scheme() {
800 assert_eq!(
801 derive_issuer("https://auth.example.com"),
802 "auth.example.com"
803 );
804 }
805
806 #[test]
807 fn derive_issuer_strips_port() {
808 assert_eq!(derive_issuer("http://127.0.0.1:3100"), "127.0.0.1");
810 }
811
812 #[test]
813 fn derive_issuer_strips_path() {
814 assert_eq!(
815 derive_issuer("https://auth.example.com/some/path"),
816 "auth.example.com"
817 );
818 }
819
820 #[tokio::test]
825 async fn get_mfa_setup_renders_secret() {
826 let ath = setup().await;
827 let app = test_app(ath.clone());
828 let (_, cookie) = create_session(&ath).await;
829
830 let csrf = get_csrf(&app, &cookie).await;
831 let req = Request::builder()
832 .uri("/settings/mfa/setup")
833 .header(header::COOKIE, format!("{cookie}; csrf_token={csrf}"))
834 .body(Body::empty())
835 .unwrap();
836 let resp = app.oneshot(req).await.unwrap();
837
838 assert_eq!(resp.status(), StatusCode::OK);
839 let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
840 .await
841 .unwrap();
842 let html = String::from_utf8(body.to_vec()).unwrap();
843 assert!(
844 html.contains("totp-secret"),
845 "setup page must show secret element"
846 );
847 assert!(
848 html.contains("totp-uri"),
849 "setup page must show QR URI container"
850 );
851 assert!(
852 html.contains("data:image/svg+xml,"),
853 "setup page must include a QR code data URI"
854 );
855 }
856
857 #[tokio::test]
858 async fn get_mfa_setup_is_idempotent() {
859 let ath = setup().await;
861 let app = test_app(ath.clone());
862 let (_, cookie) = create_session(&ath).await;
863 let csrf = get_csrf(&app, &cookie).await;
864
865 let secret_of = |html: String| -> String {
866 let after_attr = html
868 .split("<code id=\"totp-secret\"")
869 .nth(1)
870 .expect("totp-secret element not found in HTML");
871 let after_tag_close = after_attr
872 .splitn(2, '>')
873 .nth(1)
874 .expect("closing > of totp-secret element not found");
875 after_tag_close.split('<').next().unwrap_or("").to_string()
876 };
877
878 let req1 = Request::builder()
879 .uri("/settings/mfa/setup")
880 .header(header::COOKIE, format!("{cookie}; csrf_token={csrf}"))
881 .body(Body::empty())
882 .unwrap();
883 let resp1 = app.clone().oneshot(req1).await.unwrap();
884 let html1 = String::from_utf8(
885 axum::body::to_bytes(resp1.into_body(), usize::MAX)
886 .await
887 .unwrap()
888 .to_vec(),
889 )
890 .unwrap();
891
892 let req2 = Request::builder()
893 .uri("/settings/mfa/setup")
894 .header(header::COOKIE, format!("{cookie}; csrf_token={csrf}"))
895 .body(Body::empty())
896 .unwrap();
897 let resp2 = app.clone().oneshot(req2).await.unwrap();
898 let html2 = String::from_utf8(
899 axum::body::to_bytes(resp2.into_body(), usize::MAX)
900 .await
901 .unwrap()
902 .to_vec(),
903 )
904 .unwrap();
905
906 assert_eq!(
907 secret_of(html1),
908 secret_of(html2),
909 "repeated GET /settings/mfa/setup must return the same pending secret"
910 );
911 }
912
913 #[tokio::test]
918 async fn post_mfa_confirm_invalid_code_shows_error_and_does_not_enable() {
919 let ath = setup().await;
920 let app = test_app(ath.clone());
921 let (user_id, cookie) = create_session(&ath).await;
922
923 let csrf = get_csrf(&app, &cookie).await;
925
926 let body_str = format!("code=000000&csrf_token={csrf}");
927 let req = Request::builder()
928 .method("POST")
929 .uri("/settings/mfa/confirm")
930 .header(header::COOKIE, format!("{cookie}; csrf_token={csrf}"))
931 .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
932 .body(Body::from(body_str))
933 .unwrap();
934 let resp = app.oneshot(req).await.unwrap();
935
936 assert_eq!(resp.status(), StatusCode::OK);
937 let html = String::from_utf8(
938 axum::body::to_bytes(resp.into_body(), usize::MAX)
939 .await
940 .unwrap()
941 .to_vec(),
942 )
943 .unwrap();
944 assert!(
945 html.contains(SETUP_INVALID_CODE),
946 "wrong code must show setup error"
947 );
948 assert!(
949 !ath.has_mfa_enabled(user_id).await.unwrap(),
950 "MFA must not be enabled after wrong code"
951 );
952 }
953
954 #[tokio::test]
955 async fn post_mfa_confirm_valid_code_enables_mfa_and_renders_recovery_codes() {
956 let ath = setup().await;
957 let app = test_app(ath.clone());
958 let (user_id, cookie) = create_session(&ath).await;
959
960 let csrf = get_csrf(&app, &cookie).await;
961
962 let secret = ath.create_mfa_secret(user_id).await.unwrap();
964 let totp = TOTP::new(
965 Algorithm::SHA1,
966 6,
967 1,
968 30,
969 Secret::Encoded(secret).to_bytes().unwrap(),
970 None,
971 String::new(),
972 )
973 .unwrap();
974 let code = totp.generate_current().unwrap();
975
976 let body_str = format!("code={code}&csrf_token={csrf}");
977 let req = Request::builder()
978 .method("POST")
979 .uri("/settings/mfa/confirm")
980 .header(header::COOKIE, format!("{cookie}; csrf_token={csrf}"))
981 .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
982 .body(Body::from(body_str))
983 .unwrap();
984 let resp = app.oneshot(req).await.unwrap();
985
986 assert_eq!(resp.status(), StatusCode::OK);
987 let html = String::from_utf8(
988 axum::body::to_bytes(resp.into_body(), usize::MAX)
989 .await
990 .unwrap()
991 .to_vec(),
992 )
993 .unwrap();
994 assert!(
995 html.contains("recovery-code"),
996 "success must render recovery codes"
997 );
998 assert!(
999 ath.has_mfa_enabled(user_id).await.unwrap(),
1000 "MFA must be enabled after valid confirm"
1001 );
1002 }
1003
1004 #[tokio::test]
1009 async fn post_mfa_disable_removes_mfa_and_redirects() {
1010 let ath = setup().await;
1011 let app = test_app(ath.clone());
1012 let (user_id, cookie) = create_session(&ath).await;
1013 enable_mfa_for_user(&ath, user_id).await;
1014
1015 let session_token_val = cookie.split('=').nth(1).unwrap().to_string();
1017 let session_token = allowthem_core::types::SessionToken::from_encoded(session_token_val);
1018 let csrf =
1019 allowthem_core::derive_csrf_token(&session_token, b"test-csrf-key-for-binary-tests!!");
1020
1021 let body_str = format!("csrf_token={csrf}");
1022 let req = Request::builder()
1023 .method("POST")
1024 .uri("/settings/mfa/disable")
1025 .header(header::COOKIE, &cookie)
1026 .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
1027 .body(Body::from(body_str))
1028 .unwrap();
1029 let resp = app.oneshot(req).await.unwrap();
1030
1031 assert_eq!(resp.status(), StatusCode::SEE_OTHER);
1032 assert_eq!(resp.headers().get("location").unwrap(), "/settings");
1033 assert!(
1034 !ath.has_mfa_enabled(user_id).await.unwrap(),
1035 "MFA must be disabled after disable POST"
1036 );
1037 }
1038
1039 #[tokio::test]
1044 async fn post_regenerate_recovery_codes_renders_new_codes() {
1045 let ath = setup().await;
1046 let app = test_app(ath.clone());
1047 let (user_id, cookie) = create_session(&ath).await;
1048 let (_, old_codes) = enable_mfa_for_user(&ath, user_id).await;
1049
1050 let session_token_val = cookie.split('=').nth(1).unwrap().to_string();
1051 let session_token = allowthem_core::types::SessionToken::from_encoded(session_token_val);
1052 let csrf =
1053 allowthem_core::derive_csrf_token(&session_token, b"test-csrf-key-for-binary-tests!!");
1054
1055 let body_str = format!("csrf_token={csrf}");
1056 let req = Request::builder()
1057 .method("POST")
1058 .uri("/settings/mfa/recovery-codes/regenerate")
1059 .header(header::COOKIE, &cookie)
1060 .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
1061 .body(Body::from(body_str))
1062 .unwrap();
1063 let resp = app.oneshot(req).await.unwrap();
1064
1065 assert_eq!(resp.status(), StatusCode::OK);
1066 let html = String::from_utf8(
1067 axum::body::to_bytes(resp.into_body(), usize::MAX)
1068 .await
1069 .unwrap()
1070 .to_vec(),
1071 )
1072 .unwrap();
1073 assert!(
1074 html.contains("recovery-code"),
1075 "regeneration must render recovery codes"
1076 );
1077 for old_code in &old_codes {
1079 let valid = ath.verify_recovery_code(user_id, old_code).await.unwrap();
1080 assert!(
1081 !valid,
1082 "old recovery code must be invalidated after regeneration"
1083 );
1084 }
1085 }
1086
1087 #[tokio::test]
1088 async fn post_regenerate_recovery_codes_without_mfa_redirects() {
1089 let ath = setup().await;
1090 let app = test_app(ath.clone());
1091 let (_, cookie) = create_session(&ath).await;
1092
1093 let session_token_val = cookie.split('=').nth(1).unwrap().to_string();
1094 let session_token = allowthem_core::types::SessionToken::from_encoded(session_token_val);
1095 let csrf =
1096 allowthem_core::derive_csrf_token(&session_token, b"test-csrf-key-for-binary-tests!!");
1097
1098 let body_str = format!("csrf_token={csrf}");
1099 let req = Request::builder()
1100 .method("POST")
1101 .uri("/settings/mfa/recovery-codes/regenerate")
1102 .header(header::COOKIE, &cookie)
1103 .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
1104 .body(Body::from(body_str))
1105 .unwrap();
1106 let resp = app.oneshot(req).await.unwrap();
1107
1108 assert_eq!(
1109 resp.status(),
1110 StatusCode::SEE_OTHER,
1111 "must redirect when MFA is not enabled"
1112 );
1113 assert_eq!(resp.headers().get("location").unwrap(), "/settings");
1114 }
1115
1116 #[tokio::test]
1121 async fn get_mfa_challenge_with_invalid_token_redirects_to_login() {
1122 let ath = setup().await;
1123 let app = test_app(ath);
1124
1125 let req = Request::builder()
1126 .uri("/mfa/challenge?token=not-a-real-token")
1127 .body(Body::empty())
1128 .unwrap();
1129 let resp = app.oneshot(req).await.unwrap();
1130
1131 assert_eq!(resp.status(), StatusCode::SEE_OTHER);
1132 assert_eq!(resp.headers().get("location").unwrap(), "/login");
1133 }
1134
1135 #[tokio::test]
1136 async fn get_mfa_challenge_with_valid_token_renders_form() {
1137 let ath = setup().await;
1138 let app = test_app(ath.clone());
1139 let (user_id, _) = create_session(&ath).await;
1140 enable_mfa_for_user(&ath, user_id).await;
1141
1142 let token = ath.db().create_mfa_challenge(user_id).await.unwrap();
1143 let req = Request::builder()
1144 .uri(format!("/mfa/challenge?token={token}"))
1145 .body(Body::empty())
1146 .unwrap();
1147 let resp = app.oneshot(req).await.unwrap();
1148
1149 assert_eq!(resp.status(), StatusCode::OK);
1150 let html = String::from_utf8(
1151 axum::body::to_bytes(resp.into_body(), usize::MAX)
1152 .await
1153 .unwrap()
1154 .to_vec(),
1155 )
1156 .unwrap();
1157 assert!(
1158 html.contains("name=\"code\""),
1159 "challenge form must have code input"
1160 );
1161 assert!(
1162 html.contains("mfa_token"),
1163 "challenge form must embed mfa_token hidden field"
1164 );
1165 }
1166
1167 #[tokio::test]
1168 async fn get_mfa_challenge_hx_request_returns_fragment() {
1169 let ath = setup().await;
1170 let app = test_app(ath.clone());
1171 let (user_id, _) = create_session(&ath).await;
1172 enable_mfa_for_user(&ath, user_id).await;
1173
1174 let token = ath.db().create_mfa_challenge(user_id).await.unwrap();
1175 let req = Request::builder()
1176 .uri(format!("/mfa/challenge?token={token}"))
1177 .header("HX-Request", "true")
1178 .body(Body::empty())
1179 .unwrap();
1180 let resp = app.oneshot(req).await.unwrap();
1181
1182 assert_eq!(resp.status(), StatusCode::OK);
1183 let html = String::from_utf8(
1184 axum::body::to_bytes(resp.into_body(), usize::MAX)
1185 .await
1186 .unwrap()
1187 .to_vec(),
1188 )
1189 .unwrap();
1190 assert!(
1191 html.contains("<main class=\"wf-auth-form\">"),
1192 "HX response must be a fragment starting at <main>"
1193 );
1194 assert!(
1195 !html.contains("<html"),
1196 "HX response must not render the full shell"
1197 );
1198 }
1199
1200 #[test]
1201 fn render_mfa_setup_fragment_composes_main_and_oob_head() {
1202 let config = MfaPageConfig {
1203 is_production: false,
1204 base_url: "http://127.0.0.1:3100".into(),
1205 };
1206 let totp =
1207 "otpauth://totp/allowthem:user@example.com?secret=JBSWY3DPEHPK3PXP&issuer=allowthem";
1208 let html = render_mfa_setup_fragment(
1209 &config,
1210 "csrf-tok",
1211 totp,
1212 &qr_data_uri(totp),
1213 "JBSWY3DPEHPK3PXP",
1214 "",
1215 None,
1216 )
1217 .unwrap()
1218 .0;
1219 assert!(
1220 html.contains("<main class=\"wf-auth-form\">"),
1221 "fragment must include the <main> root"
1222 );
1223 assert!(
1224 html.contains("<title hx-swap-oob=\"true\">"),
1225 "fragment must include the OOB <title> tag"
1226 );
1227 assert!(
1228 html.contains("id=\"wf-screen-label\""),
1229 "fragment must include the OOB #wf-screen-label span"
1230 );
1231 assert!(
1232 html.contains("ENABLE 2FA"),
1233 "fragment must include the ENABLE 2FA status hint"
1234 );
1235 assert!(
1236 html.contains("JBSWY3DPEHPK3PXP"),
1237 "fragment must include the base32 secret"
1238 );
1239 }
1240
1241 #[tokio::test]
1242 async fn get_mfa_setup_hx_request_returns_fragment() {
1243 let ath = setup().await;
1244 let app = test_app(ath.clone());
1245 let (_, cookie) = create_session(&ath).await;
1246 let csrf = get_csrf(&app, &cookie).await;
1247
1248 let req = Request::builder()
1249 .uri("/settings/mfa/setup")
1250 .header(header::COOKIE, format!("{cookie}; csrf_token={csrf}"))
1251 .header("HX-Request", "true")
1252 .body(Body::empty())
1253 .unwrap();
1254 let resp = app.oneshot(req).await.unwrap();
1255
1256 assert_eq!(resp.status(), StatusCode::OK);
1257 let html = String::from_utf8(
1258 axum::body::to_bytes(resp.into_body(), usize::MAX)
1259 .await
1260 .unwrap()
1261 .to_vec(),
1262 )
1263 .unwrap();
1264 assert!(
1265 html.contains("<main class=\"wf-auth-form\">"),
1266 "HX response must be a fragment starting at <main>"
1267 );
1268 assert!(
1269 !html.contains("<html"),
1270 "HX response must not render the full shell"
1271 );
1272 }
1273
1274 #[test]
1275 fn render_mfa_recovery_fragment_composes_main_and_oob_head() {
1276 let config = MfaPageConfig {
1277 is_production: false,
1278 base_url: "http://127.0.0.1:3100".into(),
1279 };
1280 let codes = vec!["AAAA-BBBB".to_string(), "CCCC-DDDD".to_string()];
1281 let html = render_mfa_recovery_fragment(&config, &codes, None)
1282 .unwrap()
1283 .0;
1284 assert!(
1285 html.contains("<main class=\"wf-auth-form\">"),
1286 "fragment must include the <main> root"
1287 );
1288 assert!(
1289 html.contains("<title hx-swap-oob=\"true\">"),
1290 "fragment must include the OOB <title> tag"
1291 );
1292 assert!(
1293 html.contains("id=\"wf-screen-label\""),
1294 "fragment must include the OOB #wf-screen-label span"
1295 );
1296 assert!(
1297 html.contains("RECOVERY CODES"),
1298 "fragment must include the RECOVERY CODES status hint"
1299 );
1300 assert!(
1301 html.contains("AAAA-BBBB"),
1302 "fragment must include the rendered recovery codes"
1303 );
1304 assert!(
1305 html.contains(r#"data-testid="recovery-code-grid""#),
1306 "fragment must include the recovery code grid"
1307 );
1308 }
1309
1310 #[tokio::test]
1311 async fn post_mfa_confirm_hx_request_returns_recovery_fragment() {
1312 let ath = setup().await;
1313 let app = test_app(ath.clone());
1314 let (user_id, cookie) = create_session(&ath).await;
1315 let csrf = get_csrf(&app, &cookie).await;
1316
1317 let secret = ath.create_mfa_secret(user_id).await.unwrap();
1318 let totp = TOTP::new(
1319 Algorithm::SHA1,
1320 6,
1321 1,
1322 30,
1323 Secret::Encoded(secret).to_bytes().unwrap(),
1324 None,
1325 String::new(),
1326 )
1327 .unwrap();
1328 let code = totp.generate_current().unwrap();
1329
1330 let body_str = format!("code={code}&csrf_token={csrf}");
1331 let req = Request::builder()
1332 .method("POST")
1333 .uri("/settings/mfa/confirm")
1334 .header(header::COOKIE, format!("{cookie}; csrf_token={csrf}"))
1335 .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
1336 .header("HX-Request", "true")
1337 .body(Body::from(body_str))
1338 .unwrap();
1339 let resp = app.oneshot(req).await.unwrap();
1340
1341 assert_eq!(resp.status(), StatusCode::OK);
1342 let html = String::from_utf8(
1343 axum::body::to_bytes(resp.into_body(), usize::MAX)
1344 .await
1345 .unwrap()
1346 .to_vec(),
1347 )
1348 .unwrap();
1349 assert!(
1350 html.contains("<main class=\"wf-auth-form\">"),
1351 "HX response must be a fragment starting at <main>"
1352 );
1353 assert!(
1354 !html.contains("<html"),
1355 "HX response must not render the full shell"
1356 );
1357 assert!(
1358 html.contains("recovery-code"),
1359 "HX response must render the recovery codes"
1360 );
1361 }
1362
1363 #[test]
1364 fn render_mfa_challenge_fragment_composes_main_and_oob_head() {
1365 let config = MfaPageConfig {
1366 is_production: false,
1367 base_url: String::new(),
1368 };
1369 let html = render_mfa_challenge_fragment(&config, "mfa-token-abc", "", None)
1370 .unwrap()
1371 .0;
1372 assert!(
1373 html.contains("<main class=\"wf-auth-form\">"),
1374 "fragment must include the <main> root"
1375 );
1376 assert!(
1377 html.contains("<title hx-swap-oob=\"true\">"),
1378 "fragment must include the OOB <title> tag"
1379 );
1380 assert!(
1381 html.contains("id=\"wf-screen-label\""),
1382 "fragment must include the OOB #wf-screen-label span"
1383 );
1384 assert!(
1385 html.contains("TWO-FACTOR"),
1386 "fragment must include the TWO-FACTOR status hint"
1387 );
1388 }
1389
1390 #[tokio::test]
1395 async fn post_mfa_challenge_invalid_token_redirects_to_login() {
1396 let ath = setup().await;
1397 let app = test_app(ath);
1398
1399 let body_str = "mfa_token=garbage&code=123456";
1400 let req = Request::builder()
1401 .method("POST")
1402 .uri("/mfa/challenge")
1403 .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
1404 .body(Body::from(body_str))
1405 .unwrap();
1406 let resp = app.oneshot(req).await.unwrap();
1407
1408 assert_eq!(resp.status(), StatusCode::SEE_OTHER);
1409 assert_eq!(resp.headers().get("location").unwrap(), "/login");
1410 }
1411
1412 #[tokio::test]
1413 async fn post_mfa_challenge_wrong_totp_does_not_consume_challenge() {
1414 let ath = setup().await;
1416 let app = test_app(ath.clone());
1417 let (user_id, _) = create_session(&ath).await;
1418 enable_mfa_for_user(&ath, user_id).await;
1419
1420 let token = ath.db().create_mfa_challenge(user_id).await.unwrap();
1421
1422 let body_str = format!("mfa_token={token}&code=000000");
1423 let req = Request::builder()
1424 .method("POST")
1425 .uri("/mfa/challenge")
1426 .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
1427 .body(Body::from(body_str))
1428 .unwrap();
1429 let resp = app.oneshot(req).await.unwrap();
1430
1431 assert_eq!(resp.status(), StatusCode::OK);
1432 let html = String::from_utf8(
1433 axum::body::to_bytes(resp.into_body(), usize::MAX)
1434 .await
1435 .unwrap()
1436 .to_vec(),
1437 )
1438 .unwrap();
1439 assert!(
1440 html.contains(CHALLENGE_INVALID_TOTP),
1441 "wrong code must show TOTP error"
1442 );
1443
1444 let still_valid = ath.db().validate_mfa_challenge(&token).await.unwrap();
1446 assert!(
1447 still_valid.is_some(),
1448 "challenge must survive a failed attempt"
1449 );
1450 }
1451
1452 #[tokio::test]
1453 async fn post_mfa_challenge_valid_totp_creates_session_and_emits_login() {
1454 let ath = setup().await;
1455 let app = test_app(ath.clone());
1456 let (user_id, _) = create_session(&ath).await;
1457 let (totp, _) = enable_mfa_for_user(&ath, user_id).await;
1458
1459 let token = ath.db().create_mfa_challenge(user_id).await.unwrap();
1460 let code = totp.generate_current().unwrap();
1461
1462 let body_str = format!("mfa_token={token}&code={code}");
1463 let req = Request::builder()
1464 .method("POST")
1465 .uri("/mfa/challenge")
1466 .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
1467 .body(Body::from(body_str))
1468 .unwrap();
1469 let resp = app.oneshot(req).await.unwrap();
1470
1471 assert_eq!(resp.status(), StatusCode::SEE_OTHER);
1472 assert_eq!(resp.headers().get("location").unwrap(), "/");
1473 assert!(
1474 resp.headers().get(header::SET_COOKIE).is_some(),
1475 "session cookie must be set on success"
1476 );
1477
1478 let consumed = ath.db().validate_mfa_challenge(&token).await.unwrap();
1480 assert!(
1481 consumed.is_none(),
1482 "challenge must be consumed after success"
1483 );
1484
1485 let entries = ath.db().get_audit_log(Some(&user_id), 50, 0).await.unwrap();
1487 let event_types: Vec<&allowthem_core::AuditEvent> =
1488 entries.iter().map(|e| &e.event_type).collect();
1489 assert!(
1490 event_types.contains(&&allowthem_core::AuditEvent::MfaChallengeSuccess),
1491 "MfaChallengeSuccess must be in audit log"
1492 );
1493 assert!(
1494 event_types.contains(&&allowthem_core::AuditEvent::Login),
1495 "Login must be in audit log after MFA challenge success"
1496 );
1497 }
1498
1499 #[tokio::test]
1500 async fn post_mfa_challenge_wrong_recovery_code_shows_error() {
1501 let ath = setup().await;
1502 let app = test_app(ath.clone());
1503 let (user_id, _) = create_session(&ath).await;
1504 enable_mfa_for_user(&ath, user_id).await;
1505
1506 let token = ath.db().create_mfa_challenge(user_id).await.unwrap();
1507
1508 let body_str = format!("mfa_token={token}&recovery_code=AAAAAAAA&use_recovery=on");
1509 let req = Request::builder()
1510 .method("POST")
1511 .uri("/mfa/challenge")
1512 .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
1513 .body(Body::from(body_str))
1514 .unwrap();
1515 let resp = app.oneshot(req).await.unwrap();
1516
1517 assert_eq!(resp.status(), StatusCode::OK);
1518 let html = String::from_utf8(
1519 axum::body::to_bytes(resp.into_body(), usize::MAX)
1520 .await
1521 .unwrap()
1522 .to_vec(),
1523 )
1524 .unwrap();
1525 assert!(
1526 html.contains(CHALLENGE_INVALID_RECOVERY),
1527 "wrong recovery code must show recovery error"
1528 );
1529 }
1530}