Skip to main content

allowthem_server/
login_routes.rs

1use std::net::{IpAddr, SocketAddr};
2use std::sync::Arc;
3use std::time::Instant;
4
5use axum::Form;
6use axum::Router;
7use axum::extract::{ConnectInfo, Extension, Query};
8use axum::http::header::{COOKIE, SET_COOKIE, USER_AGENT};
9use axum::http::{HeaderMap, StatusCode};
10use axum::response::{Html, IntoResponse, Response};
11use axum::routing::get;
12use axum_htmx::{HxBoosted, HxRequest};
13use chrono::Utc;
14use dashmap::DashMap;
15use serde::Deserialize;
16
17use allowthem_core::applications::BrandingConfig;
18#[cfg(test)]
19use allowthem_core::applications::CreateApplicationParams;
20use allowthem_core::password::verify_password;
21use allowthem_core::sessions;
22use allowthem_core::types::ClientId;
23use allowthem_core::{AllowThem, AuditEvent, PasswordHash, SessionToken};
24
25use crate::auth_views::{LoginView, login_fragment, login_page};
26use crate::branding::{DefaultBranding, default_branding_ref, resolve_branding};
27use crate::browser_error::BrowserError;
28use crate::csrf::CsrfToken;
29
/// Message rendered for every credential failure: unknown account, wrong
/// password, inactive account, or a blank identifier. Using one string for all
/// of them avoids telling the caller which case occurred.
const LOGIN_ERROR: &str = "Invalid email or password.";

/// Argon2id PHC string verified against when there is no real hash to check
/// (unknown identifier, or an account that has no password hash).
///
/// Running a full `verify_password()` on those paths keeps their response time
/// close to that of a real password check, so timing does not reveal whether an
/// account exists. NOTE(review): the equalization only holds if these cost
/// parameters (`m=19456,t=2,p=1`) match the ones used for stored hashes —
/// confirm against the hasher configuration in `allowthem_core::password`.
const DUMMY_HASH: &str = "$argon2id$v=19$m=19456,t=2,p=1$ldQz3PJVzDn06G+Bzin5Ew$IaOeOaTQjgM1uJpHDULCxq8r6pj2OqvY/lcKo6Fv3IM";
37
38#[derive(Clone)]
39struct LoginConfig {
40    is_production: bool,
41    login_attempts: Arc<DashMap<IpAddr, (u32, Instant)>>,
42    max_login_attempts: u32,
43    rate_limit_window_secs: u64,
44    oauth_providers: Vec<String>,
45    signup_url: Option<String>,
46    terms_url: Option<String>,
47    privacy_url: Option<String>,
48}
49
50#[derive(Deserialize)]
51struct LoginQuery {
52    next: Option<String>,
53    client_id: Option<ClientId>,
54}
55
56#[derive(Deserialize)]
57struct LoginForm {
58    identifier: String,
59    password: String,
60    next: Option<String>,
61    client_id: Option<ClientId>,
62    #[allow(dead_code)]
63    csrf_token: String,
64}
65
/// Open-redirect protection for the `next` parameter.
///
/// Returns `next` unchanged only when it is a same-origin absolute path.
/// Anything else collapses to `"/"`. Rejected inputs:
///
/// * anything not starting with `/` (relative paths, the empty string, full URLs);
/// * protocol-relative `//host`;
/// * anything containing `://`, even after the first segment;
/// * any backslash — browsers normalise `\` to `/` in http(s) URLs, so
///   `/\evil.com` would otherwise be followed as `//evil.com`;
/// * any control character — URL parsers strip tab/CR/LF, so `/\t/evil.com`
///   would otherwise be followed as `//evil.com`.
fn validate_next(next: &str) -> &str {
    let is_local_path = next.starts_with('/')
        && !next.starts_with("//")
        && !next.contains("://")
        && !next.contains('\\')
        && !next.chars().any(char::is_control);
    if is_local_path { next } else { "/" }
}
75
76fn extract_session_token(ath: &AllowThem, headers: &HeaderMap) -> Option<SessionToken> {
77    headers
78        .get(COOKIE)
79        .and_then(|v| v.to_str().ok())
80        .and_then(|v| ath.parse_session_cookie(v))
81}
82
83fn render_login_form(
84    config: &LoginConfig,
85    csrf_token: &str,
86    identifier: &str,
87    next: Option<&str>,
88    error: &str,
89    client_id: Option<&ClientId>,
90    branding: Option<&BrandingConfig>,
91) -> Result<Html<String>, BrowserError> {
92    let next_val = next.map(validate_next).unwrap_or("");
93
94    login_page(&LoginView {
95        csrf_token,
96        identifier,
97        next: Some(next_val).filter(|value| !value.is_empty()),
98        error,
99        client_id: client_id.map(|c| c.as_str()),
100        oauth_providers: &config.oauth_providers,
101        signup_url: config.signup_url.as_deref(),
102        terms_url: config.terms_url.as_deref(),
103        privacy_url: config.privacy_url.as_deref(),
104        branding,
105        is_production: config.is_production,
106    })
107}
108
109fn render_login_fragment(
110    config: &LoginConfig,
111    csrf_token: &str,
112    identifier: &str,
113    next: Option<&str>,
114    error: &str,
115    client_id: Option<&ClientId>,
116    branding: Option<&BrandingConfig>,
117) -> Result<Html<String>, BrowserError> {
118    let next_val = next.map(validate_next).unwrap_or("");
119
120    login_fragment(&LoginView {
121        csrf_token,
122        identifier,
123        next: Some(next_val).filter(|value| !value.is_empty()),
124        error,
125        client_id: client_id.map(|c| c.as_str()),
126        oauth_providers: &config.oauth_providers,
127        signup_url: config.signup_url.as_deref(),
128        terms_url: config.terms_url.as_deref(),
129        privacy_url: config.privacy_url.as_deref(),
130        branding,
131        is_production: config.is_production,
132    })
133}
134
135fn is_rate_limited(config: &LoginConfig, ip: IpAddr) -> bool {
136    if let Some(entry) = config.login_attempts.get(&ip) {
137        let (count, window_start) = *entry;
138        if window_start.elapsed().as_secs() > config.rate_limit_window_secs {
139            return false;
140        }
141        count >= config.max_login_attempts
142    } else {
143        false
144    }
145}
146
147fn record_login_failure(config: &LoginConfig, ip: IpAddr) {
148    let now = Instant::now();
149    config
150        .login_attempts
151        .entry(ip)
152        .and_modify(|(count, window_start)| {
153            if window_start.elapsed().as_secs() > config.rate_limit_window_secs {
154                *count = 1;
155                *window_start = now;
156            } else {
157                *count += 1;
158            }
159        })
160        .or_insert((1, now));
161}
162
163fn record_login_success(config: &LoginConfig, ip: IpAddr) {
164    config.login_attempts.remove(&ip);
165}
166
167/// GET /login — render the login form, or redirect if already authenticated.
168#[allow(clippy::too_many_arguments)]
169async fn get_login(
170    Extension(ath): Extension<AllowThem>,
171    Extension(config): Extension<LoginConfig>,
172    default_branding: Option<Extension<Arc<DefaultBranding>>>,
173    csrf: CsrfToken,
174    Query(query): Query<LoginQuery>,
175    headers: HeaderMap,
176    HxBoosted(boosted): HxBoosted,
177    HxRequest(request): HxRequest,
178) -> Result<Response, BrowserError> {
179    // If already authenticated, redirect
180    if let Some(token) = extract_session_token(&ath, &headers)
181        && ath.db().lookup_session(&token).await?.is_some()
182    {
183        let dest = query.next.as_deref().map(validate_next).unwrap_or("/");
184        return Ok((
185            StatusCode::SEE_OTHER,
186            [(axum::http::header::LOCATION, dest.to_string())],
187        )
188            .into_response());
189    }
190
191    let default = default_branding_ref(&default_branding);
192    let branding = resolve_branding(&ath, query.client_id.as_ref(), default).await;
193
194    if request && !boosted {
195        let html = render_login_fragment(
196            &config,
197            csrf.as_str(),
198            "",
199            query.next.as_deref(),
200            "",
201            query.client_id.as_ref(),
202            branding.as_ref(),
203        )?;
204        return Ok(html.into_response());
205    }
206
207    let html = render_login_form(
208        &config,
209        csrf.as_str(),
210        "",
211        query.next.as_deref(),
212        "",
213        query.client_id.as_ref(),
214        branding.as_ref(),
215    )?;
216    Ok(html.into_response())
217}
218
219/// POST /login — validate credentials, create session on success.
220async fn post_login(
221    Extension(ath): Extension<AllowThem>,
222    Extension(config): Extension<LoginConfig>,
223    default_branding: Option<Extension<Arc<DefaultBranding>>>,
224    csrf: CsrfToken,
225    ConnectInfo(addr): ConnectInfo<SocketAddr>,
226    headers: HeaderMap,
227    Form(form): Form<LoginForm>,
228) -> Result<Response, BrowserError> {
229    let ip = addr.ip();
230    let ua = headers.get(USER_AGENT).and_then(|v| v.to_str().ok());
231    let ip_str = ip.to_string();
232    let default = default_branding_ref(&default_branding);
233    let branding = resolve_branding(&ath, form.client_id.as_ref(), default).await;
234
235    // 1. Rate limit check
236    if is_rate_limited(&config, ip) {
237        let html = render_login_form(
238            &config,
239            csrf.as_str(),
240            &form.identifier,
241            form.next.as_deref(),
242            "Too many login attempts. Please try again later.",
243            form.client_id.as_ref(),
244            branding.as_ref(),
245        )?;
246        return Ok((StatusCode::TOO_MANY_REQUESTS, html).into_response());
247    }
248
249    let identifier = form.identifier.trim();
250    if identifier.is_empty() {
251        let html = render_login_form(
252            &config,
253            csrf.as_str(),
254            "",
255            form.next.as_deref(),
256            LOGIN_ERROR,
257            form.client_id.as_ref(),
258            branding.as_ref(),
259        )?;
260        return Ok(html.into_response());
261    }
262
263    // 2. Look up user
264    let dummy = PasswordHash::new_unchecked(DUMMY_HASH.to_string());
265    let user = ath.db().find_for_login(identifier).await;
266
267    match user {
268        Ok(user) => {
269            let hash = user.password_hash.as_ref().unwrap_or(&dummy);
270            let password_ok = verify_password(&form.password, hash).unwrap_or(false);
271
272            if password_ok && user.is_active {
273                // Success
274                record_login_success(&config, ip);
275
276                // MFA gate: if user has MFA enabled, redirect to challenge page
277                if ath.has_mfa_enabled(user.id).await? {
278                    let mfa_token = ath.db().create_mfa_challenge(user.id).await?;
279                    let dest = format!("/mfa/challenge?token={mfa_token}");
280                    return Ok((
281                        StatusCode::SEE_OTHER,
282                        [(axum::http::header::LOCATION, dest)],
283                    )
284                        .into_response());
285                }
286
287                let token = sessions::generate_token();
288                let token_hash = sessions::hash_token(&token);
289                let ttl = ath.session_config().ttl;
290                let expires_at = Utc::now() + ttl;
291                ath.db()
292                    .create_session(user.id, token_hash, Some(&ip_str), ua, expires_at)
293                    .await?;
294
295                let cookie = ath.session_cookie(&token);
296                let _ = ath
297                    .db()
298                    .log_audit(
299                        AuditEvent::Login,
300                        Some(&user.id),
301                        None,
302                        Some(&ip_str),
303                        ua,
304                        None,
305                    )
306                    .await;
307
308                ath.notify_user_active(user.id);
309
310                let dest = form.next.as_deref().map(validate_next).unwrap_or("/");
311                Ok((
312                    StatusCode::SEE_OTHER,
313                    [
314                        (SET_COOKIE, cookie),
315                        (axum::http::header::LOCATION, dest.to_string()),
316                    ],
317                )
318                    .into_response())
319            } else {
320                // Wrong password or inactive user
321                let _ = ath
322                    .db()
323                    .log_audit(
324                        AuditEvent::LoginFailed,
325                        Some(&user.id),
326                        None,
327                        Some(&ip_str),
328                        ua,
329                        Some(identifier),
330                    )
331                    .await;
332                record_login_failure(&config, ip);
333
334                let html = render_login_form(
335                    &config,
336                    csrf.as_str(),
337                    identifier,
338                    form.next.as_deref(),
339                    LOGIN_ERROR,
340                    form.client_id.as_ref(),
341                    branding.as_ref(),
342                )?;
343                Ok(html.into_response())
344            }
345        }
346        Err(allowthem_core::AuthError::NotFound) => {
347            // Timing equalization: run verify against dummy hash
348            let _ = verify_password(&form.password, &dummy);
349
350            let _ = ath
351                .db()
352                .log_audit(
353                    AuditEvent::LoginFailed,
354                    None,
355                    None,
356                    Some(&ip_str),
357                    ua,
358                    Some(identifier),
359                )
360                .await;
361            record_login_failure(&config, ip);
362
363            let html = render_login_form(
364                &config,
365                csrf.as_str(),
366                identifier,
367                form.next.as_deref(),
368                LOGIN_ERROR,
369                form.client_id.as_ref(),
370                branding.as_ref(),
371            )?;
372            Ok(html.into_response())
373        }
374        Err(e) => Err(BrowserError::Auth(e)),
375    }
376}
377
/// Optional URLs forwarded to the login view (sign-up, terms and privacy links).
///
/// Every field defaults to `None`, so `LoginOverrides::default()` behaves the
/// same as passing `None` to [`login_routes`]. Individual fields can be set
/// with struct-update syntax:
/// `LoginOverrides { signup_url: Some(url), ..Default::default() }`.
#[derive(Debug, Clone, Default)]
pub struct LoginOverrides {
    /// Target of the sign-up link, if any.
    pub signup_url: Option<String>,
    /// Target of the terms-of-service link, if any.
    pub terms_url: Option<String>,
    /// Target of the privacy-policy link, if any.
    pub privacy_url: Option<String>,
}
383
384pub fn login_routes(
385    is_production: bool,
386    max_login_attempts: u32,
387    rate_limit_window_secs: u64,
388    oauth_providers: Vec<String>,
389    overrides: Option<LoginOverrides>,
390) -> Router<()> {
391    let ov = overrides.unwrap_or(LoginOverrides {
392        signup_url: None,
393        terms_url: None,
394        privacy_url: None,
395    });
396    let cfg = LoginConfig {
397        is_production,
398        login_attempts: Arc::new(DashMap::new()),
399        max_login_attempts,
400        rate_limit_window_secs,
401        oauth_providers,
402        signup_url: ov.signup_url,
403        terms_url: ov.terms_url,
404        privacy_url: ov.privacy_url,
405    };
406    Router::new()
407        .route("/login", get(get_login).post(post_login))
408        .layer(Extension(cfg))
409}
410
411#[cfg(test)]
412mod tests {
413    use super::*;
414
415    use axum::Router;
416    use axum::body::Body;
417    use axum::extract::connect_info::MockConnectInfo;
418    use axum::http::{Request, StatusCode, header};
419    use chrono::Duration;
420    use tower::ServiceExt;
421
422    use allowthem_core::types::ClientType;
423    use allowthem_core::{AllowThemBuilder, Email, generate_token, hash_token};
424
425    async fn setup() -> (AllowThem, LoginConfig) {
426        let ath = AllowThemBuilder::new("sqlite::memory:")
427            .cookie_secure(false)
428            .csrf_key(*b"test-csrf-key-for-binary-tests!!")
429            .build()
430            .await
431            .unwrap();
432        let config = LoginConfig {
433            is_production: false,
434            login_attempts: Arc::new(DashMap::new()),
435            max_login_attempts: 10,
436            rate_limit_window_secs: 900,
437            oauth_providers: Vec::new(),
438            signup_url: None,
439            terms_url: None,
440            privacy_url: None,
441        };
442        (ath, config)
443    }
444
445    fn test_app(ath: AllowThem, config: LoginConfig) -> Router {
446        login_routes(
447            config.is_production,
448            config.max_login_attempts,
449            config.rate_limit_window_secs,
450            config.oauth_providers.clone(),
451            None,
452        )
453        .layer(axum::middleware::from_fn(crate::csrf::csrf_middleware))
454        .layer(MockConnectInfo(SocketAddr::from(([127, 0, 0, 1], 0))))
455        .layer(axum::middleware::from_fn_with_state(
456            ath.clone(),
457            crate::cors::inject_ath_into_extensions,
458        ))
459    }
460
461    async fn get_csrf_token(app: &Router) -> String {
462        let req = Request::builder()
463            .uri("/login")
464            .body(Body::empty())
465            .unwrap();
466        let resp = app.clone().oneshot(req).await.unwrap();
467        let set_cookie = resp
468            .headers()
469            .get(header::SET_COOKIE)
470            .unwrap()
471            .to_str()
472            .unwrap()
473            .to_string();
474        set_cookie
475            .split(';')
476            .next()
477            .unwrap()
478            .split('=')
479            .nth(1)
480            .unwrap()
481            .to_string()
482    }
483
484    fn login_request(
485        csrf: &str,
486        identifier: &str,
487        password: &str,
488        next: Option<&str>,
489    ) -> Request<Body> {
490        let mut body = format!(
491            "identifier={}&password={}&csrf_token={}",
492            identifier, password, csrf
493        );
494        if let Some(n) = next {
495            body.push_str(&format!("&next={}", n));
496        }
497        Request::builder()
498            .method("POST")
499            .uri("/login")
500            .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
501            .header(header::COOKIE, format!("csrf_pre={}", csrf))
502            .body(Body::from(body))
503            .unwrap()
504    }
505
506    async fn create_user(ath: &AllowThem, email: &str, password: &str) {
507        let email = Email::new(email.into()).unwrap();
508        ath.db()
509            .create_user(email, password, None, None)
510            .await
511            .unwrap();
512    }
513
514    // --- Unit tests ---
515
516    #[test]
517    fn validate_next_allows_simple_paths() {
518        assert_eq!(validate_next("/dashboard"), "/dashboard");
519        assert_eq!(validate_next("/search?q=foo"), "/search?q=foo");
520        assert_eq!(validate_next("/a/b/c"), "/a/b/c");
521    }
522
523    #[test]
524    fn validate_next_rejects_open_redirects() {
525        assert_eq!(validate_next("https://evil.com"), "/");
526        assert_eq!(validate_next("//evil.com"), "/");
527        assert_eq!(validate_next(""), "/");
528        assert_eq!(validate_next("relative/path"), "/");
529        assert_eq!(validate_next("/ok/path://thing"), "/");
530        assert_eq!(validate_next("http://evil.com/foo"), "/");
531    }
532
533    // --- Integration tests ---
534
535    #[tokio::test]
536    async fn get_login_renders_form() {
537        let (ath, config) = setup().await;
538        let app = test_app(ath, config);
539
540        let resp = app
541            .oneshot(
542                Request::builder()
543                    .uri("/login")
544                    .body(Body::empty())
545                    .unwrap(),
546            )
547            .await
548            .unwrap();
549
550        assert_eq!(resp.status(), StatusCode::OK);
551        let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
552            .await
553            .unwrap();
554        let html = String::from_utf8(body.to_vec()).unwrap();
555        assert!(html.contains("<form"), "should contain a form");
556        assert!(html.contains("csrf_token"), "should contain csrf_token");
557        assert!(
558            html.contains("identifier"),
559            "should contain identifier input"
560        );
561    }
562
563    #[tokio::test]
564    async fn get_login_redirects_when_authenticated() {
565        let (ath, config) = setup().await;
566
567        let email = Email::new("auth@example.com".into()).unwrap();
568        let user = ath
569            .db()
570            .create_user(email, "pass123", None, None)
571            .await
572            .unwrap();
573        let token = generate_token();
574        let token_hash = hash_token(&token);
575        let expires = Utc::now() + Duration::hours(24);
576        ath.db()
577            .create_session(user.id, token_hash, None, None, expires)
578            .await
579            .unwrap();
580        let cookie = ath.session_cookie(&token);
581        let cookie_val = cookie.split(';').next().unwrap();
582
583        let app = test_app(ath, config);
584        let req = Request::builder()
585            .uri("/login")
586            .header(header::COOKIE, cookie_val)
587            .body(Body::empty())
588            .unwrap();
589        let resp = app.oneshot(req).await.unwrap();
590
591        assert_eq!(resp.status(), StatusCode::SEE_OTHER);
592        assert_eq!(resp.headers().get("location").unwrap(), "/");
593    }
594
595    #[tokio::test]
596    async fn get_login_preserves_next_param() {
597        let (ath, config) = setup().await;
598        let app = test_app(ath, config);
599
600        let resp = app
601            .oneshot(
602                Request::builder()
603                    .uri("/login?next=/dashboard")
604                    .body(Body::empty())
605                    .unwrap(),
606            )
607            .await
608            .unwrap();
609
610        assert_eq!(resp.status(), StatusCode::OK);
611        let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
612            .await
613            .unwrap();
614        let html = String::from_utf8(body.to_vec()).unwrap();
615        assert!(
616            html.contains("name=\"next\""),
617            "should contain next hidden field"
618        );
619        assert!(
620            html.contains("dashboard"),
621            "next field should contain dashboard"
622        );
623    }
624
625    #[tokio::test]
626    async fn post_login_success_redirects() {
627        let (ath, config) = setup().await;
628        create_user(&ath, "login@example.com", "correcthorse").await;
629        let app = test_app(ath, config);
630
631        let csrf = get_csrf_token(&app).await;
632        let req = login_request(&csrf, "login@example.com", "correcthorse", None);
633        let resp = app.oneshot(req).await.unwrap();
634
635        assert_eq!(resp.status(), StatusCode::SEE_OTHER);
636        assert_eq!(resp.headers().get("location").unwrap(), "/");
637        assert!(
638            resp.headers().get(SET_COOKIE).is_some(),
639            "should set session cookie"
640        );
641    }
642
643    #[tokio::test]
644    async fn post_login_success_redirects_to_next() {
645        let (ath, config) = setup().await;
646        create_user(&ath, "next@example.com", "correcthorse").await;
647        let app = test_app(ath, config);
648
649        let csrf = get_csrf_token(&app).await;
650        let req = login_request(
651            &csrf,
652            "next@example.com",
653            "correcthorse",
654            Some("/dashboard"),
655        );
656        let resp = app.oneshot(req).await.unwrap();
657
658        assert_eq!(resp.status(), StatusCode::SEE_OTHER);
659        assert_eq!(resp.headers().get("location").unwrap(), "/dashboard");
660    }
661
662    #[tokio::test]
663    async fn post_login_wrong_password_shows_error() {
664        let (ath, config) = setup().await;
665        create_user(&ath, "wrong@example.com", "correcthorse").await;
666        let app = test_app(ath, config);
667
668        let csrf = get_csrf_token(&app).await;
669        let req = login_request(&csrf, "wrong@example.com", "wrongpassword", None);
670        let resp = app.oneshot(req).await.unwrap();
671
672        assert_eq!(resp.status(), StatusCode::OK);
673        let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
674            .await
675            .unwrap();
676        let html = String::from_utf8(body.to_vec()).unwrap();
677        assert!(html.contains(LOGIN_ERROR), "should show generic error");
678        assert!(
679            html.contains("wrong@example.com"),
680            "should pre-fill identifier"
681        );
682    }
683
684    #[tokio::test]
685    async fn post_login_nonexistent_user_shows_error() {
686        let (ath, config) = setup().await;
687        let app = test_app(ath, config);
688
689        let csrf = get_csrf_token(&app).await;
690        let req = login_request(&csrf, "nobody@example.com", "anypassword", None);
691        let resp = app.oneshot(req).await.unwrap();
692
693        assert_eq!(resp.status(), StatusCode::OK);
694        let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
695            .await
696            .unwrap();
697        let html = String::from_utf8(body.to_vec()).unwrap();
698        assert!(
699            html.contains(LOGIN_ERROR),
700            "should show same generic error as wrong password"
701        );
702    }
703
704    #[tokio::test]
705    async fn post_login_inactive_user_shows_error() {
706        let (ath, config) = setup().await;
707        let email = Email::new("inactive@example.com".into()).unwrap();
708        let user = ath
709            .db()
710            .create_user(email, "correcthorse", None, None)
711            .await
712            .unwrap();
713        ath.db().update_user_active(user.id, false).await.unwrap();
714
715        let app = test_app(ath, config);
716        let csrf = get_csrf_token(&app).await;
717        let req = login_request(&csrf, "inactive@example.com", "correcthorse", None);
718        let resp = app.oneshot(req).await.unwrap();
719
720        assert_eq!(resp.status(), StatusCode::OK);
721        let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
722            .await
723            .unwrap();
724        let html = String::from_utf8(body.to_vec()).unwrap();
725        assert!(
726            html.contains(LOGIN_ERROR),
727            "inactive user should get generic error"
728        );
729    }
730
731    #[tokio::test]
732    async fn post_login_rate_limit() {
733        let (ath, config) = setup().await;
734        let app = test_app(ath, config);
735
736        let csrf = get_csrf_token(&app).await;
737
738        // Exhaust rate limit
739        for _ in 0..10_u32 {
740            let req = login_request(&csrf, "nobody@example.com", "wrong", None);
741            let _ = app.clone().oneshot(req).await.unwrap();
742        }
743
744        // Next attempt should be rate limited
745        let req = login_request(&csrf, "nobody@example.com", "wrong", None);
746        let resp = app.oneshot(req).await.unwrap();
747
748        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
749        let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
750            .await
751            .unwrap();
752        let html = String::from_utf8(body.to_vec()).unwrap();
753        assert!(html.contains("Too many login attempts"));
754    }
755
756    #[tokio::test]
757    async fn post_login_csrf_required() {
758        let (ath, config) = setup().await;
759        let app = test_app(ath, config);
760
761        let req = Request::builder()
762            .method("POST")
763            .uri("/login")
764            .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
765            .body(Body::from("identifier=test&password=test"))
766            .unwrap();
767        let resp = app.oneshot(req).await.unwrap();
768
769        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
770    }
771
772    #[tokio::test]
773    async fn login_with_client_id_shows_branding() {
774        let (ath, config) = setup().await;
775        let (app, _) = ath
776            .db()
777            .create_application(CreateApplicationParams {
778                name: "BrandedApp".into(),
779                client_type: ClientType::Confidential,
780                redirect_uris: vec!["https://example.com/cb".into()],
781                is_trusted: false,
782                created_by: None,
783                logo_url: Some("https://cdn.example.com/logo.png".into()),
784                primary_color: Some("#ff6600".into()),
785                accent_hex: None,
786                accent_ink: None,
787                forced_mode: None,
788                font_css_url: None,
789                font_family: None,
790                splash_text: None,
791                splash_image_url: None,
792                splash_primitive: None,
793                splash_url: None,
794                shader_cell_scale: None,
795            })
796            .await
797            .unwrap();
798        let router = test_app(ath, config);
799
800        let req = Request::builder()
801            .uri(&format!("/login?client_id={}", app.client_id))
802            .body(Body::empty())
803            .unwrap();
804        let resp = router.oneshot(req).await.unwrap();
805        assert_eq!(resp.status(), StatusCode::OK);
806        let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
807            .await
808            .unwrap();
809        let html = String::from_utf8(body.to_vec()).unwrap();
810        assert!(html.contains("BrandedApp"), "should show app name");
811        assert!(
812            html.contains("<title>Log in — BrandedApp</title>"),
813            "default title brand should use the application name"
814        );
815        assert!(
816            html.contains("--accent: #ff6600"),
817            "primary_color should flow to --accent"
818        );
819        assert!(
820            html.contains("--accent-ink:"),
821            "accent_ink should be emitted in template"
822        );
823    }
824
825    #[tokio::test]
826    async fn login_without_client_id_shows_default() {
827        let (ath, config) = setup().await;
828        let router = test_app(ath, config);
829
830        let req = Request::builder()
831            .uri("/login")
832            .body(Body::empty())
833            .unwrap();
834        let resp = router.oneshot(req).await.unwrap();
835        assert_eq!(resp.status(), StatusCode::OK);
836        let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
837            .await
838            .unwrap();
839        let html = String::from_utf8(body.to_vec()).unwrap();
840        assert!(!html.contains("<img"), "no logo without client_id");
841        assert!(
842            html.contains("--accent: #ffffff"),
843            "should have default white accent"
844        );
845        assert!(
846            html.contains("--accent-ink: #000000"),
847            "should have default black ink"
848        );
849    }
850
851    #[tokio::test]
852    async fn login_with_invalid_client_id_shows_default() {
853        let (ath, config) = setup().await;
854        let router = test_app(ath, config);
855
856        let req = Request::builder()
857            .uri("/login?client_id=ath_nonexistent")
858            .body(Body::empty())
859            .unwrap();
860        let resp = router.oneshot(req).await.unwrap();
861        assert_eq!(resp.status(), StatusCode::OK);
862        let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
863            .await
864            .unwrap();
865        let html = String::from_utf8(body.to_vec()).unwrap();
866        assert!(!html.contains("<img"), "no logo for invalid client_id");
867        assert!(
868            html.contains("--accent: #ffffff"),
869            "should fall back to default white accent"
870        );
871        assert!(
872            html.contains("--accent-ink: #000000"),
873            "should fall back to default black ink"
874        );
875    }
876
877    #[tokio::test]
878    async fn login_default_emits_light_mode_accent_override() {
879        // Without branding, base.html must emit both the dark-mode :root
880        // block (white/black) AND an html[data-mode="light"] override
881        // (black/white), so theme toggles stay on-brand without loading
882        // the upstream Catppuccin default.
883        let (ath, config) = setup().await;
884        let router = test_app(ath, config);
885
886        let req = Request::builder()
887            .uri("/login")
888            .body(Body::empty())
889            .unwrap();
890        let resp = router.oneshot(req).await.unwrap();
891        assert_eq!(resp.status(), StatusCode::OK);
892        let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
893            .await
894            .unwrap();
895        let html = String::from_utf8(body.to_vec()).unwrap();
896
897        let root_idx = html
898            .find(":root {")
899            .expect(":root block missing from base.html");
900        let light_idx = html
901            .find("html[data-mode=\"light\"] {")
902            .expect("html[data-mode=\"light\"] block missing from base.html");
903        assert!(
904            root_idx < light_idx,
905            ":root block should come before the light-mode override"
906        );
907
908        let root_block = &html[root_idx..light_idx];
909        assert!(
910            root_block.contains("--accent: #ffffff;"),
911            ":root should pin the default white accent"
912        );
913        assert!(
914            root_block.contains("--accent-ink: #000000;"),
915            ":root should pin the default black ink"
916        );
917
918        let light_block = &html[light_idx..];
919        assert!(
920            light_block.contains("--accent: #000000;"),
921            "html[data-mode=\"light\"] should override accent to black"
922        );
923        assert!(
924            light_block.contains("--accent-ink: #ffffff;"),
925            "html[data-mode=\"light\"] should override ink to white"
926        );
927    }
928
929    #[tokio::test]
930    async fn login_branded_emits_same_accent_in_both_modes() {
931        // When branding is configured, both :root and
932        // html[data-mode="light"] must pin the integrator's accent so a
933        // theme toggle doesn't silently revert to the upstream default.
934        let (ath, config) = setup().await;
935        let (app, _) = ath
936            .db()
937            .create_application(CreateApplicationParams {
938                name: "BrandedThemeApp".into(),
939                client_type: ClientType::Confidential,
940                redirect_uris: vec!["https://example.com/cb".into()],
941                is_trusted: false,
942                created_by: None,
943                logo_url: None,
944                primary_color: None,
945                accent_hex: Some("#ff6600".into()),
946                accent_ink: None,
947                forced_mode: None,
948                font_css_url: None,
949                font_family: None,
950                splash_text: None,
951                splash_image_url: None,
952                splash_primitive: None,
953                splash_url: None,
954                shader_cell_scale: None,
955            })
956            .await
957            .unwrap();
958        let router = test_app(ath, config);
959
960        let req = Request::builder()
961            .uri(&format!("/login?client_id={}", app.client_id))
962            .body(Body::empty())
963            .unwrap();
964        let resp = router.oneshot(req).await.unwrap();
965        assert_eq!(resp.status(), StatusCode::OK);
966        let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
967            .await
968            .unwrap();
969        let html = String::from_utf8(body.to_vec()).unwrap();
970
971        let root_idx = html
972            .find(":root {")
973            .expect(":root block missing from base.html");
974        let light_idx = html
975            .find("html[data-mode=\"light\"] {")
976            .expect("html[data-mode=\"light\"] block missing from base.html");
977        let root_block = &html[root_idx..light_idx];
978        let light_block = &html[light_idx..];
979        assert!(
980            root_block.contains("--accent: #ff6600;"),
981            ":root should use the integrator's accent"
982        );
983        assert!(
984            light_block.contains("--accent: #ff6600;"),
985            "html[data-mode=\"light\"] should also use the integrator's accent"
986        );
987    }
988
989    #[tokio::test]
990    async fn branded_login_post_failure_preserves_branding() {
991        let (ath, config) = setup().await;
992        create_user(&ath, "branded@example.com", "correcthorse").await;
993        let (app, _) = ath
994            .db()
995            .create_application(CreateApplicationParams {
996                name: "BrandedPost".into(),
997                client_type: ClientType::Confidential,
998                redirect_uris: vec!["https://example.com/cb".into()],
999                is_trusted: false,
1000                created_by: None,
1001                logo_url: None,
1002                primary_color: Some("#ff6600".into()),
1003                accent_hex: None,
1004                accent_ink: None,
1005                forced_mode: None,
1006                font_css_url: None,
1007                font_family: None,
1008                splash_text: None,
1009                splash_image_url: None,
1010                splash_primitive: None,
1011                splash_url: None,
1012                shader_cell_scale: None,
1013            })
1014            .await
1015            .unwrap();
1016        let router = test_app(ath, config);
1017
1018        let csrf = get_csrf_token(&router).await;
1019        let body_str = format!(
1020            "identifier=branded%40example.com&password=wrong&csrf_token={}&client_id={}",
1021            csrf, app.client_id,
1022        );
1023        let req = Request::builder()
1024            .method("POST")
1025            .uri("/login")
1026            .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
1027            .header(header::COOKIE, format!("csrf_pre={}", csrf))
1028            .body(Body::from(body_str))
1029            .unwrap();
1030        let resp = router.oneshot(req).await.unwrap();
1031        assert_eq!(resp.status(), StatusCode::OK);
1032        let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
1033            .await
1034            .unwrap();
1035        let html = String::from_utf8(body.to_vec()).unwrap();
1036        assert!(
1037            html.contains("BrandedPost"),
1038            "app name preserved after error"
1039        );
1040        assert!(
1041            html.contains("--accent: #ff6600"),
1042            "accent color preserved after error"
1043        );
1044    }
1045
1046    #[tokio::test]
1047    async fn post_login_with_mfa_enabled_redirects_to_challenge_without_session() {
1048        // When a user has MFA enabled, correct credentials must NOT create a session —
1049        // they must redirect to /mfa/challenge?token=... instead.
1050        const MFA_KEY: [u8; 32] = [0x42; 32];
1051        let ath = AllowThemBuilder::new("sqlite::memory:")
1052            .cookie_secure(false)
1053            .mfa_key(MFA_KEY)
1054            .csrf_key(*b"test-csrf-key-for-binary-tests!!")
1055            .build()
1056            .await
1057            .unwrap();
1058        let config = LoginConfig {
1059            is_production: false,
1060            login_attempts: Arc::new(DashMap::new()),
1061            max_login_attempts: 10,
1062            rate_limit_window_secs: 900,
1063            oauth_providers: Vec::new(),
1064            signup_url: None,
1065            terms_url: None,
1066            privacy_url: None,
1067        };
1068
1069        // Create user and enable MFA
1070        create_user(&ath, "mfa-gate@example.com", "correcthorse").await;
1071        let user = ath
1072            .db()
1073            .find_for_login("mfa-gate@example.com")
1074            .await
1075            .unwrap();
1076        let secret = ath.create_mfa_secret(user.id).await.unwrap();
1077        use totp_rs::{Algorithm, Secret, TOTP};
1078        let totp = TOTP::new(
1079            Algorithm::SHA1,
1080            6,
1081            1,
1082            30,
1083            Secret::Encoded(secret).to_bytes().unwrap(),
1084            None,
1085            String::new(),
1086        )
1087        .unwrap();
1088        let code = totp.generate_current().unwrap();
1089        ath.enable_mfa(user.id, &code).await.unwrap();
1090
1091        let app = test_app(ath, config);
1092        let csrf = get_csrf_token(&app).await;
1093        let req = login_request(&csrf, "mfa-gate@example.com", "correcthorse", None);
1094        let resp = app.oneshot(req).await.unwrap();
1095
1096        // Must redirect to MFA challenge page, not to /
1097        assert_eq!(resp.status(), StatusCode::SEE_OTHER);
1098        let location = resp.headers().get("location").unwrap().to_str().unwrap();
1099        assert!(
1100            location.starts_with("/mfa/challenge?token="),
1101            "MFA gate must redirect to /mfa/challenge, got: {location}"
1102        );
1103        // Must NOT set a session cookie
1104        assert!(
1105            resp.headers().get(SET_COOKIE).is_none(),
1106            "MFA gate must not set a session cookie before TOTP is verified"
1107        );
1108    }
1109
1110    #[tokio::test]
1111    async fn render_login_fragment_composes_main_and_oob_head() {
1112        let (_ath, config) = setup().await;
1113        let html = render_login_fragment(&config, "tok", "", None, "", None, None)
1114            .unwrap()
1115            .0;
1116        assert!(
1117            html.contains("<main class=\"wf-auth-form\">"),
1118            "fragment must include the <main> root"
1119        );
1120        assert!(
1121            html.contains("<title hx-swap-oob=\"true\">"),
1122            "fragment must include the OOB <title> tag"
1123        );
1124        assert!(
1125            html.contains("id=\"wf-screen-label\""),
1126            "fragment must include the OOB #wf-screen-label span"
1127        );
1128    }
1129
1130    #[tokio::test]
1131    async fn login_register_link_carries_client_id() {
1132        let (ath, config) = setup().await;
1133        let (app, _) = ath
1134            .db()
1135            .create_application(CreateApplicationParams {
1136                name: "LinkApp".into(),
1137                client_type: ClientType::Confidential,
1138                redirect_uris: vec!["https://example.com/cb".into()],
1139                is_trusted: false,
1140                created_by: None,
1141                logo_url: None,
1142                primary_color: None,
1143                accent_hex: None,
1144                accent_ink: None,
1145                forced_mode: None,
1146                font_css_url: None,
1147                font_family: None,
1148                splash_text: None,
1149                splash_image_url: None,
1150                splash_primitive: None,
1151                splash_url: None,
1152                shader_cell_scale: None,
1153            })
1154            .await
1155            .unwrap();
1156        let router = test_app(ath, config);
1157
1158        let req = Request::builder()
1159            .uri(&format!("/login?client_id={}", app.client_id))
1160            .body(Body::empty())
1161            .unwrap();
1162        let resp = router.oneshot(req).await.unwrap();
1163        let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
1164            .await
1165            .unwrap();
1166        let html = String::from_utf8(body.to_vec()).unwrap();
1167        let id = app.client_id.as_str();
1168        let register_link = format!("/register?client_id={id}");
1169        assert!(
1170            html.contains(&register_link),
1171            "register link should carry client_id"
1172        );
1173    }
1174}