1use std::net::{IpAddr, SocketAddr};
2use std::sync::Arc;
3use std::time::Instant;
4
5use axum::Form;
6use axum::Router;
7use axum::extract::{ConnectInfo, Extension, Query};
8use axum::http::header::{COOKIE, SET_COOKIE, USER_AGENT};
9use axum::http::{HeaderMap, StatusCode};
10use axum::response::{Html, IntoResponse, Response};
11use axum::routing::get;
12use axum_htmx::{HxBoosted, HxRequest};
13use chrono::Utc;
14use dashmap::DashMap;
15use serde::Deserialize;
16
17use allowthem_core::applications::BrandingConfig;
18#[cfg(test)]
19use allowthem_core::applications::CreateApplicationParams;
20use allowthem_core::password::verify_password;
21use allowthem_core::sessions;
22use allowthem_core::types::ClientId;
23use allowthem_core::{AllowThem, AuditEvent, PasswordHash, SessionToken};
24
25use crate::auth_views::{LoginView, login_fragment, login_page};
26use crate::branding::{DefaultBranding, default_branding_ref, resolve_branding};
27use crate::browser_error::BrowserError;
28use crate::csrf::CsrfToken;
29
/// Generic failure message used for every credential problem in `post_login`
/// (blank identifier, unknown account, wrong password, inactive account), so the
/// response text does not reveal which check failed.
const LOGIN_ERROR: &str = "Invalid email or password.";

/// Argon2id hash used as a stand-in when there is no real hash to check against.
///
/// `post_login` verifies the submitted password against it when the account has
/// no password hash, and (discarding the result) when the account is not found,
/// so those paths still perform a hash verification.
// NOTE(review): presumably its cost parameters (m=19456, t=2, p=1) should match
// the parameters used for real password hashes so timing stays comparable — confirm.
const DUMMY_HASH: &str = "$argon2id$v=19$m=19456,t=2,p=1$ldQz3PJVzDn06G+Bzin5Ew$IaOeOaTQjgM1uJpHDULCxq8r6pj2OqvY/lcKo6Fv3IM";
37
38#[derive(Clone)]
39struct LoginConfig {
40 is_production: bool,
41 login_attempts: Arc<DashMap<IpAddr, (u32, Instant)>>,
42 max_login_attempts: u32,
43 rate_limit_window_secs: u64,
44 oauth_providers: Vec<String>,
45 signup_url: Option<String>,
46 terms_url: Option<String>,
47 privacy_url: Option<String>,
48}
49
50#[derive(Deserialize)]
51struct LoginQuery {
52 next: Option<String>,
53 client_id: Option<ClientId>,
54}
55
56#[derive(Deserialize)]
57struct LoginForm {
58 identifier: String,
59 password: String,
60 next: Option<String>,
61 client_id: Option<ClientId>,
62 #[allow(dead_code)]
63 csrf_token: String,
64}
65
/// Sanitise a post-login redirect target so it can only point back into this site.
///
/// Returns `next` unchanged when it is a plain same-origin path, otherwise `"/"`.
/// Rejected inputs:
/// - anything not starting with `/` (absolute URLs, relative paths, the empty string);
/// - protocol-relative `//host` targets and any embedded `://`;
/// - backslashes: browsers treat `\` like `/` in http(s) URLs, so `/\evil.com`
///   would otherwise be followed as `//evil.com`;
/// - control characters: the URL parser silently strips tab/CR/LF, so
///   `/\t/evil.com` would otherwise collapse to `//evil.com`.
fn validate_next(next: &str) -> &str {
    let same_origin_path = next.starts_with('/')
        && !next.starts_with("//")
        && !next.contains("://")
        && !next.contains('\\')
        && !next.chars().any(char::is_control);
    if same_origin_path { next } else { "/" }
}
75
76fn extract_session_token(ath: &AllowThem, headers: &HeaderMap) -> Option<SessionToken> {
77 headers
78 .get(COOKIE)
79 .and_then(|v| v.to_str().ok())
80 .and_then(|v| ath.parse_session_cookie(v))
81}
82
83fn render_login_form(
84 config: &LoginConfig,
85 csrf_token: &str,
86 identifier: &str,
87 next: Option<&str>,
88 error: &str,
89 client_id: Option<&ClientId>,
90 branding: Option<&BrandingConfig>,
91) -> Result<Html<String>, BrowserError> {
92 let next_val = next.map(validate_next).unwrap_or("");
93
94 login_page(&LoginView {
95 csrf_token,
96 identifier,
97 next: Some(next_val).filter(|value| !value.is_empty()),
98 error,
99 client_id: client_id.map(|c| c.as_str()),
100 oauth_providers: &config.oauth_providers,
101 signup_url: config.signup_url.as_deref(),
102 terms_url: config.terms_url.as_deref(),
103 privacy_url: config.privacy_url.as_deref(),
104 branding,
105 is_production: config.is_production,
106 })
107}
108
109fn render_login_fragment(
110 config: &LoginConfig,
111 csrf_token: &str,
112 identifier: &str,
113 next: Option<&str>,
114 error: &str,
115 client_id: Option<&ClientId>,
116 branding: Option<&BrandingConfig>,
117) -> Result<Html<String>, BrowserError> {
118 let next_val = next.map(validate_next).unwrap_or("");
119
120 login_fragment(&LoginView {
121 csrf_token,
122 identifier,
123 next: Some(next_val).filter(|value| !value.is_empty()),
124 error,
125 client_id: client_id.map(|c| c.as_str()),
126 oauth_providers: &config.oauth_providers,
127 signup_url: config.signup_url.as_deref(),
128 terms_url: config.terms_url.as_deref(),
129 privacy_url: config.privacy_url.as_deref(),
130 branding,
131 is_production: config.is_production,
132 })
133}
134
135fn is_rate_limited(config: &LoginConfig, ip: IpAddr) -> bool {
136 if let Some(entry) = config.login_attempts.get(&ip) {
137 let (count, window_start) = *entry;
138 if window_start.elapsed().as_secs() > config.rate_limit_window_secs {
139 return false;
140 }
141 count >= config.max_login_attempts
142 } else {
143 false
144 }
145}
146
147fn record_login_failure(config: &LoginConfig, ip: IpAddr) {
148 let now = Instant::now();
149 config
150 .login_attempts
151 .entry(ip)
152 .and_modify(|(count, window_start)| {
153 if window_start.elapsed().as_secs() > config.rate_limit_window_secs {
154 *count = 1;
155 *window_start = now;
156 } else {
157 *count += 1;
158 }
159 })
160 .or_insert((1, now));
161}
162
163fn record_login_success(config: &LoginConfig, ip: IpAddr) {
164 config.login_attempts.remove(&ip);
165}
166
167#[allow(clippy::too_many_arguments)]
169async fn get_login(
170 Extension(ath): Extension<AllowThem>,
171 Extension(config): Extension<LoginConfig>,
172 default_branding: Option<Extension<Arc<DefaultBranding>>>,
173 csrf: CsrfToken,
174 Query(query): Query<LoginQuery>,
175 headers: HeaderMap,
176 HxBoosted(boosted): HxBoosted,
177 HxRequest(request): HxRequest,
178) -> Result<Response, BrowserError> {
179 if let Some(token) = extract_session_token(&ath, &headers)
181 && ath.db().lookup_session(&token).await?.is_some()
182 {
183 let dest = query.next.as_deref().map(validate_next).unwrap_or("/");
184 return Ok((
185 StatusCode::SEE_OTHER,
186 [(axum::http::header::LOCATION, dest.to_string())],
187 )
188 .into_response());
189 }
190
191 let default = default_branding_ref(&default_branding);
192 let branding = resolve_branding(&ath, query.client_id.as_ref(), default).await;
193
194 if request && !boosted {
195 let html = render_login_fragment(
196 &config,
197 csrf.as_str(),
198 "",
199 query.next.as_deref(),
200 "",
201 query.client_id.as_ref(),
202 branding.as_ref(),
203 )?;
204 return Ok(html.into_response());
205 }
206
207 let html = render_login_form(
208 &config,
209 csrf.as_str(),
210 "",
211 query.next.as_deref(),
212 "",
213 query.client_id.as_ref(),
214 branding.as_ref(),
215 )?;
216 Ok(html.into_response())
217}
218
/// `POST /login`: authenticate an identifier/password pair and start a session.
///
/// Flow, in order:
/// 1. If the client IP is rate limited, re-render the form with HTTP 429.
/// 2. A blank (after trimming) identifier re-renders the form with the generic
///    `LOGIN_ERROR`, without a database lookup and without counting a failure.
/// 3. The account is looked up and the password verified (against `DUMMY_HASH`
///    when the account has no password hash).
/// 4. On success for an active account the IP's failure record is cleared. If
///    MFA is enabled, the response is a 303 to `/mfa/challenge` with no session
///    cookie. Otherwise a session is created, a `Login` audit event is logged, and
///    the response is a 303 to the sanitised `next` with the session cookie set.
/// 5. Wrong password, inactive account, or unknown account all log `LoginFailed`,
///    count a failure for the IP, and re-render the form with `LOGIN_ERROR`.
///
/// Audit logging is best-effort: `log_audit` results are discarded.
async fn post_login(
    Extension(ath): Extension<AllowThem>,
    Extension(config): Extension<LoginConfig>,
    default_branding: Option<Extension<Arc<DefaultBranding>>>,
    csrf: CsrfToken,
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
    headers: HeaderMap,
    Form(form): Form<LoginForm>,
) -> Result<Response, BrowserError> {
    let ip = addr.ip();
    let ua = headers.get(USER_AGENT).and_then(|v| v.to_str().ok());
    let ip_str = ip.to_string();
    let default = default_branding_ref(&default_branding);
    let branding = resolve_branding(&ath, form.client_id.as_ref(), default).await;

    // NOTE(review): the rate-limit check here and `record_login_failure` below are
    // separate steps with awaits in between, so concurrent requests from one IP can
    // all pass this check before any failure is recorded — confirm this is acceptable.
    if is_rate_limited(&config, ip) {
        let html = render_login_form(
            &config,
            csrf.as_str(),
            &form.identifier,
            form.next.as_deref(),
            "Too many login attempts. Please try again later.",
            form.client_id.as_ref(),
            branding.as_ref(),
        )?;
        return Ok((StatusCode::TOO_MANY_REQUESTS, html).into_response());
    }

    let identifier = form.identifier.trim();
    // Blank identifiers are rejected before touching the database and are not
    // counted against the IP's failure budget.
    if identifier.is_empty() {
        let html = render_login_form(
            &config,
            csrf.as_str(),
            "",
            form.next.as_deref(),
            LOGIN_ERROR,
            form.client_id.as_ref(),
            branding.as_ref(),
        )?;
        return Ok(html.into_response());
    }

    // Stand-in hash for accounts without a password hash and for unknown accounts.
    let dummy = PasswordHash::new_unchecked(DUMMY_HASH.to_string());
    let user = ath.db().find_for_login(identifier).await;

    match user {
        Ok(user) => {
            let hash = user.password_hash.as_ref().unwrap_or(&dummy);
            // Verification errors are treated as a wrong password.
            // NOTE(review): this runs synchronously on the async task; if
            // `verify_password` is Argon2 (as `DUMMY_HASH` suggests) it may be worth
            // moving to a blocking thread — confirm its cost.
            let password_ok = verify_password(&form.password, hash).unwrap_or(false);

            if password_ok && user.is_active {
                // The failure record is cleared as soon as the password checks out,
                // even when an MFA challenge still follows.
                record_login_success(&config, ip);

                // MFA gate: hand off to the challenge page without issuing a session.
                if ath.has_mfa_enabled(user.id).await? {
                    let mfa_token = ath.db().create_mfa_challenge(user.id).await?;
                    let dest = format!("/mfa/challenge?token={mfa_token}");
                    return Ok((
                        StatusCode::SEE_OTHER,
                        [(axum::http::header::LOCATION, dest)],
                    )
                        .into_response());
                }

                // The database stores the token's hash; the raw token goes into
                // the cookie built by `session_cookie`.
                let token = sessions::generate_token();
                let token_hash = sessions::hash_token(&token);
                let ttl = ath.session_config().ttl;
                let expires_at = Utc::now() + ttl;
                ath.db()
                    .create_session(user.id, token_hash, Some(&ip_str), ua, expires_at)
                    .await?;

                let cookie = ath.session_cookie(&token);
                let _ = ath
                    .db()
                    .log_audit(
                        AuditEvent::Login,
                        Some(&user.id),
                        None,
                        Some(&ip_str),
                        ua,
                        None,
                    )
                    .await;

                ath.notify_user_active(user.id);

                let dest = form.next.as_deref().map(validate_next).unwrap_or("/");
                Ok((
                    StatusCode::SEE_OTHER,
                    [
                        (SET_COOKIE, cookie),
                        (axum::http::header::LOCATION, dest.to_string()),
                    ],
                )
                    .into_response())
            } else {
                // Wrong password and inactive account share the same generic response.
                let _ = ath
                    .db()
                    .log_audit(
                        AuditEvent::LoginFailed,
                        Some(&user.id),
                        None,
                        Some(&ip_str),
                        ua,
                        Some(identifier),
                    )
                    .await;
                record_login_failure(&config, ip);

                let html = render_login_form(
                    &config,
                    csrf.as_str(),
                    identifier,
                    form.next.as_deref(),
                    LOGIN_ERROR,
                    form.client_id.as_ref(),
                    branding.as_ref(),
                )?;
                Ok(html.into_response())
            }
        }
        Err(allowthem_core::AuthError::NotFound) => {
            // Result deliberately discarded: this spends a hash verification so an
            // unknown account is not answered noticeably faster than a known one.
            let _ = verify_password(&form.password, &dummy);

            let _ = ath
                .db()
                .log_audit(
                    AuditEvent::LoginFailed,
                    None,
                    None,
                    Some(&ip_str),
                    ua,
                    Some(identifier),
                )
                .await;
            record_login_failure(&config, ip);

            let html = render_login_form(
                &config,
                csrf.as_str(),
                identifier,
                form.next.as_deref(),
                LOGIN_ERROR,
                form.client_id.as_ref(),
                branding.as_ref(),
            )?;
            Ok(html.into_response())
        }
        // Any other lookup failure propagates as an error response.
        Err(e) => Err(BrowserError::Auth(e)),
    }
}
377
/// Optional link targets for the login page, passed to [`login_routes`].
///
/// Every field defaults to `None`; [`LoginOverrides::default`] is equivalent to
/// passing `None` for the whole struct.
#[derive(Debug, Clone, Default)]
pub struct LoginOverrides {
    /// Forwarded to the login view as `signup_url`.
    pub signup_url: Option<String>,
    /// Forwarded to the login view as `terms_url`.
    pub terms_url: Option<String>,
    /// Forwarded to the login view as `privacy_url`.
    pub privacy_url: Option<String>,
}
383
384pub fn login_routes(
385 is_production: bool,
386 max_login_attempts: u32,
387 rate_limit_window_secs: u64,
388 oauth_providers: Vec<String>,
389 overrides: Option<LoginOverrides>,
390) -> Router<()> {
391 let ov = overrides.unwrap_or(LoginOverrides {
392 signup_url: None,
393 terms_url: None,
394 privacy_url: None,
395 });
396 let cfg = LoginConfig {
397 is_production,
398 login_attempts: Arc::new(DashMap::new()),
399 max_login_attempts,
400 rate_limit_window_secs,
401 oauth_providers,
402 signup_url: ov.signup_url,
403 terms_url: ov.terms_url,
404 privacy_url: ov.privacy_url,
405 };
406 Router::new()
407 .route("/login", get(get_login).post(post_login))
408 .layer(Extension(cfg))
409}
410
#[cfg(test)]
mod tests {
    use super::*;

    use axum::Router;
    use axum::body::Body;
    use axum::extract::connect_info::MockConnectInfo;
    use axum::http::{Request, StatusCode, header};
    use chrono::Duration;
    use tower::ServiceExt;

    use allowthem_core::types::ClientType;
    use allowthem_core::{AllowThemBuilder, Email, generate_token, hash_token};

    /// Login settings shared by the tests: 10 failures per 900-second window, no extra links.
    fn test_config() -> LoginConfig {
        LoginConfig {
            is_production: false,
            login_attempts: Arc::new(DashMap::new()),
            max_login_attempts: 10,
            rate_limit_window_secs: 900,
            oauth_providers: Vec::new(),
            signup_url: None,
            terms_url: None,
            privacy_url: None,
        }
    }

    /// In-memory `AllowThem` instance plus the default test config.
    async fn setup() -> (AllowThem, LoginConfig) {
        let ath = AllowThemBuilder::new("sqlite::memory:")
            .cookie_secure(false)
            .csrf_key(*b"test-csrf-key-for-binary-tests!!")
            .build()
            .await
            .unwrap();
        (ath, test_config())
    }

    /// Wrap `login_routes` with the CSRF middleware, a fake peer address, and
    /// `AllowThem` injection.
    fn test_app(ath: AllowThem, config: LoginConfig) -> Router {
        login_routes(
            config.is_production,
            config.max_login_attempts,
            config.rate_limit_window_secs,
            config.oauth_providers.clone(),
            None,
        )
        .layer(axum::middleware::from_fn(crate::csrf::csrf_middleware))
        .layer(MockConnectInfo(SocketAddr::from(([127, 0, 0, 1], 0))))
        .layer(axum::middleware::from_fn_with_state(
            ath.clone(),
            crate::cors::inject_ath_into_extensions,
        ))
    }

    /// Bodiless GET request for `uri`.
    fn get_request(uri: &str) -> Request<Body> {
        Request::builder().uri(uri).body(Body::empty()).unwrap()
    }

    /// Collect a response body into a UTF-8 string.
    async fn body_text(resp: axum::response::Response) -> String {
        let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    /// Fetch `/login` and return the value of the first `Set-Cookie` pair.
    async fn get_csrf_token(app: &Router) -> String {
        let resp = app.clone().oneshot(get_request("/login")).await.unwrap();
        let set_cookie = resp
            .headers()
            .get(header::SET_COOKIE)
            .unwrap()
            .to_str()
            .unwrap();
        let pair = set_cookie.split(';').next().unwrap();
        // `split_once` keeps any '=' that appears inside the token value itself.
        let (_, value) = pair.split_once('=').unwrap();
        value.to_string()
    }

    /// URL-encoded login POST carrying the CSRF token in both form and cookie.
    fn login_request(
        csrf: &str,
        identifier: &str,
        password: &str,
        next: Option<&str>,
    ) -> Request<Body> {
        let mut body = format!(
            "identifier={}&password={}&csrf_token={}",
            identifier, password, csrf
        );
        if let Some(n) = next {
            body.push_str(&format!("&next={}", n));
        }
        Request::builder()
            .method("POST")
            .uri("/login")
            .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
            .header(header::COOKIE, format!("csrf_pre={}", csrf))
            .body(Body::from(body))
            .unwrap()
    }

    async fn create_user(ath: &AllowThem, email: &str, password: &str) {
        let email = Email::new(email.into()).unwrap();
        ath.db()
            .create_user(email, password, None, None)
            .await
            .unwrap();
    }

    /// Confidential application named `name` with every optional branding field unset.
    fn app_params(name: &str) -> CreateApplicationParams {
        CreateApplicationParams {
            name: name.into(),
            client_type: ClientType::Confidential,
            redirect_uris: vec!["https://example.com/cb".into()],
            is_trusted: false,
            created_by: None,
            logo_url: None,
            primary_color: None,
            accent_hex: None,
            accent_ink: None,
            forced_mode: None,
            font_css_url: None,
            font_family: None,
            splash_text: None,
            splash_image_url: None,
            splash_primitive: None,
            splash_url: None,
            shader_cell_scale: None,
        }
    }

    #[test]
    fn validate_next_allows_simple_paths() {
        assert_eq!(validate_next("/dashboard"), "/dashboard");
        assert_eq!(validate_next("/search?q=foo"), "/search?q=foo");
        assert_eq!(validate_next("/a/b/c"), "/a/b/c");
    }

    #[test]
    fn validate_next_rejects_open_redirects() {
        assert_eq!(validate_next("https://evil.com"), "/");
        assert_eq!(validate_next("//evil.com"), "/");
        assert_eq!(validate_next(""), "/");
        assert_eq!(validate_next("relative/path"), "/");
        assert_eq!(validate_next("/ok/path://thing"), "/");
        assert_eq!(validate_next("http://evil.com/foo"), "/");
    }

    #[tokio::test]
    async fn get_login_renders_form() {
        let (ath, config) = setup().await;
        let app = test_app(ath, config);

        let resp = app.oneshot(get_request("/login")).await.unwrap();

        assert_eq!(resp.status(), StatusCode::OK);
        let html = body_text(resp).await;
        assert!(html.contains("<form"), "should contain a form");
        assert!(html.contains("csrf_token"), "should contain csrf_token");
        assert!(
            html.contains("identifier"),
            "should contain identifier input"
        );
    }

    #[tokio::test]
    async fn get_login_redirects_when_authenticated() {
        let (ath, config) = setup().await;

        let email = Email::new("auth@example.com".into()).unwrap();
        let user = ath
            .db()
            .create_user(email, "pass123", None, None)
            .await
            .unwrap();
        let token = generate_token();
        let token_hash = hash_token(&token);
        let expires = Utc::now() + Duration::hours(24);
        ath.db()
            .create_session(user.id, token_hash, None, None, expires)
            .await
            .unwrap();
        let cookie = ath.session_cookie(&token);
        let cookie_val = cookie.split(';').next().unwrap();

        let app = test_app(ath, config);
        let req = Request::builder()
            .uri("/login")
            .header(header::COOKIE, cookie_val)
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();

        assert_eq!(resp.status(), StatusCode::SEE_OTHER);
        assert_eq!(resp.headers().get("location").unwrap(), "/");
    }

    #[tokio::test]
    async fn get_login_preserves_next_param() {
        let (ath, config) = setup().await;
        let app = test_app(ath, config);

        let resp = app
            .oneshot(get_request("/login?next=/dashboard"))
            .await
            .unwrap();

        assert_eq!(resp.status(), StatusCode::OK);
        let html = body_text(resp).await;
        assert!(
            html.contains("name=\"next\""),
            "should contain next hidden field"
        );
        assert!(
            html.contains("dashboard"),
            "next field should contain dashboard"
        );
    }

    #[tokio::test]
    async fn post_login_success_redirects() {
        let (ath, config) = setup().await;
        create_user(&ath, "login@example.com", "correcthorse").await;
        let app = test_app(ath, config);

        let csrf = get_csrf_token(&app).await;
        let req = login_request(&csrf, "login@example.com", "correcthorse", None);
        let resp = app.oneshot(req).await.unwrap();

        assert_eq!(resp.status(), StatusCode::SEE_OTHER);
        assert_eq!(resp.headers().get("location").unwrap(), "/");
        assert!(
            resp.headers().get(SET_COOKIE).is_some(),
            "should set session cookie"
        );
    }

    #[tokio::test]
    async fn post_login_success_redirects_to_next() {
        let (ath, config) = setup().await;
        create_user(&ath, "next@example.com", "correcthorse").await;
        let app = test_app(ath, config);

        let csrf = get_csrf_token(&app).await;
        let req = login_request(
            &csrf,
            "next@example.com",
            "correcthorse",
            Some("/dashboard"),
        );
        let resp = app.oneshot(req).await.unwrap();

        assert_eq!(resp.status(), StatusCode::SEE_OTHER);
        assert_eq!(resp.headers().get("location").unwrap(), "/dashboard");
    }

    #[tokio::test]
    async fn post_login_wrong_password_shows_error() {
        let (ath, config) = setup().await;
        create_user(&ath, "wrong@example.com", "correcthorse").await;
        let app = test_app(ath, config);

        let csrf = get_csrf_token(&app).await;
        let req = login_request(&csrf, "wrong@example.com", "wrongpassword", None);
        let resp = app.oneshot(req).await.unwrap();

        assert_eq!(resp.status(), StatusCode::OK);
        let html = body_text(resp).await;
        assert!(html.contains(LOGIN_ERROR), "should show generic error");
        assert!(
            html.contains("wrong@example.com"),
            "should pre-fill identifier"
        );
    }

    #[tokio::test]
    async fn post_login_nonexistent_user_shows_error() {
        let (ath, config) = setup().await;
        let app = test_app(ath, config);

        let csrf = get_csrf_token(&app).await;
        let req = login_request(&csrf, "nobody@example.com", "anypassword", None);
        let resp = app.oneshot(req).await.unwrap();

        assert_eq!(resp.status(), StatusCode::OK);
        let html = body_text(resp).await;
        assert!(
            html.contains(LOGIN_ERROR),
            "should show same generic error as wrong password"
        );
    }

    #[tokio::test]
    async fn post_login_inactive_user_shows_error() {
        let (ath, config) = setup().await;
        let email = Email::new("inactive@example.com".into()).unwrap();
        let user = ath
            .db()
            .create_user(email, "correcthorse", None, None)
            .await
            .unwrap();
        ath.db().update_user_active(user.id, false).await.unwrap();

        let app = test_app(ath, config);
        let csrf = get_csrf_token(&app).await;
        let req = login_request(&csrf, "inactive@example.com", "correcthorse", None);
        let resp = app.oneshot(req).await.unwrap();

        assert_eq!(resp.status(), StatusCode::OK);
        let html = body_text(resp).await;
        assert!(
            html.contains(LOGIN_ERROR),
            "inactive user should get generic error"
        );
    }

    #[tokio::test]
    async fn post_login_rate_limit() {
        let (ath, config) = setup().await;
        let app = test_app(ath, config);

        let csrf = get_csrf_token(&app).await;

        // Use up the 10-failure budget configured by `test_config`.
        for _ in 0..10_u32 {
            let req = login_request(&csrf, "nobody@example.com", "wrong", None);
            let _ = app.clone().oneshot(req).await.unwrap();
        }

        // The next attempt from the same IP must be refused.
        let req = login_request(&csrf, "nobody@example.com", "wrong", None);
        let resp = app.oneshot(req).await.unwrap();

        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
        let html = body_text(resp).await;
        assert!(html.contains("Too many login attempts"));
    }

    #[tokio::test]
    async fn post_login_csrf_required() {
        let (ath, config) = setup().await;
        let app = test_app(ath, config);

        let req = Request::builder()
            .method("POST")
            .uri("/login")
            .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
            .body(Body::from("identifier=test&password=test"))
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();

        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn login_with_client_id_shows_branding() {
        let (ath, config) = setup().await;
        let (app, _) = ath
            .db()
            .create_application(CreateApplicationParams {
                logo_url: Some("https://cdn.example.com/logo.png".into()),
                primary_color: Some("#ff6600".into()),
                ..app_params("BrandedApp")
            })
            .await
            .unwrap();
        let router = test_app(ath, config);

        let resp = router
            .oneshot(get_request(&format!("/login?client_id={}", app.client_id)))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let html = body_text(resp).await;
        assert!(html.contains("BrandedApp"), "should show app name");
        assert!(
            html.contains("<title>Log in — BrandedApp</title>"),
            "default title brand should use the application name"
        );
        assert!(
            html.contains("--accent: #ff6600"),
            "primary_color should flow to --accent"
        );
        assert!(
            html.contains("--accent-ink:"),
            "accent_ink should be emitted in template"
        );
    }

    #[tokio::test]
    async fn login_without_client_id_shows_default() {
        let (ath, config) = setup().await;
        let router = test_app(ath, config);

        let resp = router.oneshot(get_request("/login")).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let html = body_text(resp).await;
        assert!(!html.contains("<img"), "no logo without client_id");
        assert!(
            html.contains("--accent: #ffffff"),
            "should have default white accent"
        );
        assert!(
            html.contains("--accent-ink: #000000"),
            "should have default black ink"
        );
    }

    #[tokio::test]
    async fn login_with_invalid_client_id_shows_default() {
        let (ath, config) = setup().await;
        let router = test_app(ath, config);

        let resp = router
            .oneshot(get_request("/login?client_id=ath_nonexistent"))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let html = body_text(resp).await;
        assert!(!html.contains("<img"), "no logo for invalid client_id");
        assert!(
            html.contains("--accent: #ffffff"),
            "should fall back to default white accent"
        );
        assert!(
            html.contains("--accent-ink: #000000"),
            "should fall back to default black ink"
        );
    }

    #[tokio::test]
    async fn login_default_emits_light_mode_accent_override() {
        let (ath, config) = setup().await;
        // Unbranded page: the default accent must flip between dark (:root) and light mode.
        let router = test_app(ath, config);

        let resp = router.oneshot(get_request("/login")).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let html = body_text(resp).await;

        let root_idx = html
            .find(":root {")
            .expect(":root block missing from base.html");
        let light_idx = html
            .find("html[data-mode=\"light\"] {")
            .expect("html[data-mode=\"light\"] block missing from base.html");
        assert!(
            root_idx < light_idx,
            ":root block should come before the light-mode override"
        );

        let root_block = &html[root_idx..light_idx];
        assert!(
            root_block.contains("--accent: #ffffff;"),
            ":root should pin the default white accent"
        );
        assert!(
            root_block.contains("--accent-ink: #000000;"),
            ":root should pin the default black ink"
        );

        let light_block = &html[light_idx..];
        assert!(
            light_block.contains("--accent: #000000;"),
            "html[data-mode=\"light\"] should override accent to black"
        );
        assert!(
            light_block.contains("--accent-ink: #ffffff;"),
            "html[data-mode=\"light\"] should override ink to white"
        );
    }

    #[tokio::test]
    async fn login_branded_emits_same_accent_in_both_modes() {
        let (ath, config) = setup().await;
        // An integrator-supplied accent must not be replaced by the light-mode default.
        let (app, _) = ath
            .db()
            .create_application(CreateApplicationParams {
                accent_hex: Some("#ff6600".into()),
                ..app_params("BrandedThemeApp")
            })
            .await
            .unwrap();
        let router = test_app(ath, config);

        let resp = router
            .oneshot(get_request(&format!("/login?client_id={}", app.client_id)))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let html = body_text(resp).await;

        let root_idx = html
            .find(":root {")
            .expect(":root block missing from base.html");
        let light_idx = html
            .find("html[data-mode=\"light\"] {")
            .expect("html[data-mode=\"light\"] block missing from base.html");
        let root_block = &html[root_idx..light_idx];
        let light_block = &html[light_idx..];
        assert!(
            root_block.contains("--accent: #ff6600;"),
            ":root should use the integrator's accent"
        );
        assert!(
            light_block.contains("--accent: #ff6600;"),
            "html[data-mode=\"light\"] should also use the integrator's accent"
        );
    }

    #[tokio::test]
    async fn branded_login_post_failure_preserves_branding() {
        let (ath, config) = setup().await;
        create_user(&ath, "branded@example.com", "correcthorse").await;
        let (app, _) = ath
            .db()
            .create_application(CreateApplicationParams {
                primary_color: Some("#ff6600".into()),
                ..app_params("BrandedPost")
            })
            .await
            .unwrap();
        let router = test_app(ath, config);

        let csrf = get_csrf_token(&router).await;
        let body_str = format!(
            "identifier=branded%40example.com&password=wrong&csrf_token={}&client_id={}",
            csrf, app.client_id,
        );
        let req = Request::builder()
            .method("POST")
            .uri("/login")
            .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
            .header(header::COOKIE, format!("csrf_pre={}", csrf))
            .body(Body::from(body_str))
            .unwrap();
        let resp = router.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let html = body_text(resp).await;
        assert!(
            html.contains("BrandedPost"),
            "app name preserved after error"
        );
        assert!(
            html.contains("--accent: #ff6600"),
            "accent color preserved after error"
        );
    }

    #[tokio::test]
    async fn post_login_with_mfa_enabled_redirects_to_challenge_without_session() {
        use totp_rs::{Algorithm, Secret, TOTP};

        const MFA_KEY: [u8; 32] = [0x42; 32];
        let ath = AllowThemBuilder::new("sqlite::memory:")
            .cookie_secure(false)
            .mfa_key(MFA_KEY)
            .csrf_key(*b"test-csrf-key-for-binary-tests!!")
            .build()
            .await
            .unwrap();

        create_user(&ath, "mfa-gate@example.com", "correcthorse").await;
        let user = ath
            .db()
            .find_for_login("mfa-gate@example.com")
            .await
            .unwrap();
        // Enrol a TOTP secret and confirm it with a current code so MFA is enabled.
        let secret = ath.create_mfa_secret(user.id).await.unwrap();
        let totp = TOTP::new(
            Algorithm::SHA1,
            6,
            1,
            30,
            Secret::Encoded(secret).to_bytes().unwrap(),
            None,
            String::new(),
        )
        .unwrap();
        let code = totp.generate_current().unwrap();
        ath.enable_mfa(user.id, &code).await.unwrap();

        let app = test_app(ath, test_config());
        let csrf = get_csrf_token(&app).await;
        let req = login_request(&csrf, "mfa-gate@example.com", "correcthorse", None);
        let resp = app.oneshot(req).await.unwrap();

        assert_eq!(resp.status(), StatusCode::SEE_OTHER);
        let location = resp.headers().get("location").unwrap().to_str().unwrap();
        assert!(
            location.starts_with("/mfa/challenge?token="),
            "MFA gate must redirect to /mfa/challenge, got: {location}"
        );
        assert!(
            resp.headers().get(SET_COOKIE).is_none(),
            "MFA gate must not set a session cookie before TOTP is verified"
        );
    }

    #[tokio::test]
    async fn render_login_fragment_composes_main_and_oob_head() {
        let config = test_config();
        let html = render_login_fragment(&config, "tok", "", None, "", None, None)
            .unwrap()
            .0;
        assert!(
            html.contains("<main class=\"wf-auth-form\">"),
            "fragment must include the <main> root"
        );
        assert!(
            html.contains("<title hx-swap-oob=\"true\">"),
            "fragment must include the OOB <title> tag"
        );
        assert!(
            html.contains("id=\"wf-screen-label\""),
            "fragment must include the OOB #wf-screen-label span"
        );
    }

    #[tokio::test]
    async fn login_register_link_carries_client_id() {
        let (ath, config) = setup().await;
        let (app, _) = ath
            .db()
            .create_application(app_params("LinkApp"))
            .await
            .unwrap();
        let router = test_app(ath, config);

        let resp = router
            .oneshot(get_request(&format!("/login?client_id={}", app.client_id)))
            .await
            .unwrap();
        let html = body_text(resp).await;
        let id = app.client_id.as_str();
        let register_link = format!("/register?client_id={id}");
        assert!(
            html.contains(&register_link),
            "register link should carry client_id"
        );
    }
}