Skip to main content

allowthem_server/
password_reset_page_routes.rs

1use std::sync::Arc;
2
3use axum::Extension;
4use axum::Form;
5use axum::Router;
6use axum::extract::{Query, State};
7use axum::http::HeaderMap;
8use axum::http::StatusCode;
9use axum::http::header::COOKIE;
10use axum::response::{IntoResponse, Response};
11use axum::routing::get;
12use minijinja::{Environment, context};
13use serde::Deserialize;
14
15use allowthem_core::{AllowThem, Email, EmailSender};
16
17use crate::browser_error::BrowserError;
18use crate::csrf::CsrfToken;
19
20const MIN_PASSWORD_LEN: usize = 8;
21
22#[derive(Clone)]
23struct PasswordResetPageConfig {
24    templates: Arc<Environment<'static>>,
25    is_production: bool,
26    email_sender: Arc<dyn EmailSender>,
27    base_url: String,
28}
29
30#[derive(Deserialize)]
31pub struct ResetTokenQuery {
32    token: Option<String>,
33}
34
35#[derive(Deserialize)]
36pub struct ForgotPasswordForm {
37    email: String,
38    #[allow(dead_code)]
39    csrf_token: String,
40}
41
42#[derive(Deserialize)]
43pub struct ResetPasswordForm {
44    token: String,
45    new_password: String,
46    confirm_password: String,
47    #[allow(dead_code)]
48    csrf_token: String,
49}
50
51/// GET /forgot-password — render the email input form.
52async fn get_forgot_password(
53    State(ath): State<AllowThem>,
54    Extension(config): Extension<PasswordResetPageConfig>,
55    headers: HeaderMap,
56    csrf: CsrfToken,
57) -> Result<Response, BrowserError> {
58    if is_authenticated(&ath, &headers).await {
59        return Ok((StatusCode::SEE_OTHER, [(axum::http::header::LOCATION, "/")]).into_response());
60    }
61
62    let html = crate::browser_templates::render(
63        &config.templates,
64        "forgot_password.html",
65        context! {
66            csrf_token => csrf.as_str(),
67            success => false,
68            error => "",
69            is_production => config.is_production,
70        },
71    )?;
72    Ok(html.into_response())
73}
74
75/// POST /forgot-password — initiate reset; always render success to prevent enumeration.
76async fn post_forgot_password(
77    State(ath): State<AllowThem>,
78    Extension(config): Extension<PasswordResetPageConfig>,
79    csrf: CsrfToken,
80    Form(form): Form<ForgotPasswordForm>,
81) -> Result<Response, BrowserError> {
82    let email = match Email::new(form.email.clone()) {
83        Ok(e) => e,
84        Err(_) => {
85            let html = crate::browser_templates::render(
86                &config.templates,
87                "forgot_password.html",
88                context! {
89                    csrf_token => csrf.as_str(),
90                    success => false,
91                    error => "Please enter a valid email address.",
92                    is_production => config.is_production,
93                },
94            )?;
95            return Ok(html.into_response());
96        }
97    };
98
99    let sender: &dyn EmailSender = &*config.email_sender;
100    if let Err(err) = ath
101        .db()
102        .send_password_reset(&email, &config.base_url, sender)
103        .await
104    {
105        tracing::error!("password reset email error: {err}");
106    }
107
108    let html = crate::browser_templates::render(
109        &config.templates,
110        "forgot_password.html",
111        context! {
112            csrf_token => csrf.as_str(),
113            success => true,
114            error => "",
115            is_production => config.is_production,
116        },
117    )?;
118    Ok(html.into_response())
119}
120
121/// GET /auth/reset-password?token=... — validate token and render form or error state.
122async fn get_reset_password(
123    State(ath): State<AllowThem>,
124    Extension(config): Extension<PasswordResetPageConfig>,
125    csrf: CsrfToken,
126    Query(query): Query<ResetTokenQuery>,
127) -> Result<Response, BrowserError> {
128    let token = match query.token {
129        Some(ref t) if !t.is_empty() => t.clone(),
130        _ => {
131            let html = crate::browser_templates::render(
132                &config.templates,
133                "reset_password.html",
134                context! {
135                    csrf_token => csrf.as_str(),
136                    token => "",
137                    invalid_token => true,
138                    success => false,
139                    error => "",
140                    is_production => config.is_production,
141                },
142            )?;
143            return Ok(html.into_response());
144        }
145    };
146
147    let valid = ath.db().validate_reset_token(&token).await?;
148
149    if valid.is_some() {
150        let html = crate::browser_templates::render(
151            &config.templates,
152            "reset_password.html",
153            context! {
154                csrf_token => csrf.as_str(),
155                token,
156                invalid_token => false,
157                success => false,
158                error => "",
159                is_production => config.is_production,
160            },
161        )?;
162        Ok(html.into_response())
163    } else {
164        let html = crate::browser_templates::render(
165            &config.templates,
166            "reset_password.html",
167            context! {
168                csrf_token => csrf.as_str(),
169                token => "",
170                invalid_token => true,
171                success => false,
172                error => "",
173                is_production => config.is_production,
174            },
175        )?;
176        Ok(html.into_response())
177    }
178}
179
180/// POST /auth/reset-password — execute the password reset.
181async fn post_reset_password(
182    State(ath): State<AllowThem>,
183    Extension(config): Extension<PasswordResetPageConfig>,
184    csrf: CsrfToken,
185    Form(form): Form<ResetPasswordForm>,
186) -> Result<Response, BrowserError> {
187    // Validate: passwords match
188    if form.new_password != form.confirm_password {
189        let html = crate::browser_templates::render(
190            &config.templates,
191            "reset_password.html",
192            context! {
193                csrf_token => csrf.as_str(),
194                token => form.token,
195                invalid_token => false,
196                success => false,
197                error => "Passwords do not match",
198                is_production => config.is_production,
199            },
200        )?;
201        return Ok(html.into_response());
202    }
203
204    // Validate: password length
205    if form.new_password.len() < MIN_PASSWORD_LEN {
206        let html = crate::browser_templates::render(
207            &config.templates,
208            "reset_password.html",
209            context! {
210                csrf_token => csrf.as_str(),
211                token => form.token,
212                invalid_token => false,
213                success => false,
214                error => "Password must be at least 8 characters",
215                is_production => config.is_production,
216            },
217        )?;
218        return Ok(html.into_response());
219    }
220
221    match ath
222        .db()
223        .execute_reset(&form.token, &form.new_password)
224        .await?
225    {
226        true => {
227            let html = crate::browser_templates::render(
228                &config.templates,
229                "reset_password.html",
230                context! {
231                    csrf_token => csrf.as_str(),
232                    token => "",
233                    invalid_token => false,
234                    success => true,
235                    error => "",
236                    is_production => config.is_production,
237                },
238            )?;
239            Ok(html.into_response())
240        }
241        false => {
242            let html = crate::browser_templates::render(
243                &config.templates,
244                "reset_password.html",
245                context! {
246                    csrf_token => csrf.as_str(),
247                    token => "",
248                    invalid_token => true,
249                    success => false,
250                    error => "",
251                    is_production => config.is_production,
252                },
253            )?;
254            Ok(html.into_response())
255        }
256    }
257}
258
259/// Returns true if the request carries a valid session cookie.
260async fn is_authenticated(ath: &AllowThem, headers: &HeaderMap) -> bool {
261    let Some(cookie_header) = headers.get(COOKIE).and_then(|v| v.to_str().ok()) else {
262        return false;
263    };
264    let Some(token) = ath.parse_session_cookie(cookie_header) else {
265        return false;
266    };
267    let ttl = ath.session_config().ttl;
268    ath.db()
269        .validate_session(&token, ttl)
270        .await
271        .unwrap_or(None)
272        .is_some()
273}
274
275pub fn password_reset_page_routes(
276    templates: Arc<Environment<'static>>,
277    is_production: bool,
278    email_sender: Arc<dyn EmailSender>,
279    base_url: String,
280) -> Router<AllowThem> {
281    let cfg = PasswordResetPageConfig {
282        templates,
283        is_production,
284        email_sender,
285        base_url,
286    };
287    Router::new()
288        .route(
289            "/forgot-password",
290            get(get_forgot_password).post(post_forgot_password),
291        )
292        .route(
293            "/auth/reset-password",
294            get(get_reset_password).post(post_reset_password),
295        )
296        .layer(Extension(cfg))
297}
298
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use axum::Router;
    use axum::body::Body;
    use axum::http::{Request, StatusCode, header};
    use tower::ServiceExt;

    use allowthem_core::{AllowThem, AllowThemBuilder, Email, LogEmailSender};

    use super::{PasswordResetPageConfig, password_reset_page_routes};

    /// In-memory `AllowThem` plus a page config using the default browser
    /// templates and a logging email sender.
    async fn setup() -> (AllowThem, PasswordResetPageConfig) {
        let ath = AllowThemBuilder::new("sqlite::memory:")
            .cookie_secure(false)
            .csrf_key(*b"test-csrf-key-for-binary-tests!!")
            .build()
            .await
            .unwrap();
        let templates = crate::browser_templates::build_default_browser_env();
        let config = PasswordResetPageConfig {
            templates,
            is_production: false,
            email_sender: Arc::new(LogEmailSender),
            base_url: "http://localhost:3000".into(),
        };
        (ath, config)
    }

    /// The routes under test wrapped in the CSRF middleware, with `ath` as state.
    fn test_app(ath: AllowThem, config: PasswordResetPageConfig) -> Router {
        password_reset_page_routes(
            config.templates.clone(),
            config.is_production,
            config.email_sender.clone(),
            config.base_url.clone(),
        )
        .layer(axum::middleware::from_fn_with_state(
            ath.clone(),
            crate::csrf::csrf_middleware,
        ))
        .with_state(ath)
    }

    /// Body-less GET request for `uri`.
    fn get_request(uri: &str) -> Request<Body> {
        Request::builder().uri(uri).body(Body::empty()).unwrap()
    }

    /// URL-encoded POST to `uri`. `csrf` is sent both as the `csrf_pre` cookie
    /// and as a `csrf_token` field appended after `fields`.
    fn form_post(uri: &str, csrf: &str, fields: &str) -> Request<Body> {
        Request::builder()
            .method("POST")
            .uri(uri)
            .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
            .header(header::COOKIE, format!("csrf_pre={csrf}"))
            .body(Body::from(format!("{fields}&csrf_token={csrf}")))
            .unwrap()
    }

    /// GETs `path` and returns the value of the first `Set-Cookie` header
    /// (presumably the `csrf_pre` cookie set by the CSRF middleware).
    async fn get_csrf_token(app: &Router, path: &str) -> String {
        let resp = app.clone().oneshot(get_request(path)).await.unwrap();
        let set_cookie = resp
            .headers()
            .get(header::SET_COOKIE)
            .expect("response should set a CSRF cookie")
            .to_str()
            .unwrap();
        // `name=value; attrs…` — split on the first '=' only, so a value that
        // itself contains '=' (e.g. base64 padding) is kept intact.
        let name_value = set_cookie.split(';').next().unwrap();
        let (_, value) = name_value
            .split_once('=')
            .expect("cookie pair should contain '='");
        value.to_string()
    }

    /// Collects a response body as UTF-8 text.
    async fn body_string(resp: axum::http::Response<Body>) -> String {
        let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    /// Creates a user for `email_str` and returns a fresh password-reset token for it.
    async fn create_user_and_token(ath: &AllowThem, email_str: &str) -> String {
        let email = Email::new(email_str.into()).unwrap();
        ath.db()
            .create_user(email.clone(), "OldPass123!", None, None)
            .await
            .unwrap();
        ath.db()
            .create_password_reset(&email)
            .await
            .unwrap()
            .unwrap()
    }

    #[tokio::test]
    async fn get_forgot_password_renders_form() {
        let (ath, config) = setup().await;
        let app = test_app(ath, config);
        let resp = app.oneshot(get_request("/forgot-password")).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let html = body_string(resp).await;
        assert!(html.contains("<form"));
        assert!(html.contains("name=\"email\""));
    }

    #[tokio::test]
    async fn post_forgot_password_valid_email_shows_success() {
        let (ath, config) = setup().await;
        let email = Email::new("reset@example.com".into()).unwrap();
        ath.db()
            .create_user(email, "Pass123!", None, None)
            .await
            .unwrap();
        let app = test_app(ath, config);
        let csrf = get_csrf_token(&app, "/forgot-password").await;

        let req = form_post("/forgot-password", &csrf, "email=reset%40example.com");
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let html = body_string(resp).await;
        assert!(html.contains("If an account with that email exists"));
    }

    #[tokio::test]
    async fn post_forgot_password_unknown_email_shows_success() {
        let (ath, config) = setup().await;
        let app = test_app(ath, config);
        let csrf = get_csrf_token(&app, "/forgot-password").await;

        let req = form_post("/forgot-password", &csrf, "email=nobody%40example.com");
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let html = body_string(resp).await;
        assert!(html.contains("If an account with that email exists"));
    }

    #[tokio::test]
    async fn post_forgot_password_invalid_email_shows_error() {
        let (ath, config) = setup().await;
        let app = test_app(ath, config);
        let csrf = get_csrf_token(&app, "/forgot-password").await;

        let req = form_post("/forgot-password", &csrf, "email=notanemail");
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let html = body_string(resp).await;
        assert!(html.contains("Please enter a valid email address."));
    }

    #[tokio::test]
    async fn get_reset_password_valid_token_renders_form() {
        let (ath, config) = setup().await;
        let token = create_user_and_token(&ath, "tok@example.com").await;
        let app = test_app(ath, config);

        let uri = format!("/auth/reset-password?token={token}");
        let resp = app.oneshot(get_request(&uri)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let html = body_string(resp).await;
        assert!(html.contains("name=\"new_password\""));
        assert!(html.contains("name=\"confirm_password\""));
    }

    #[tokio::test]
    async fn get_reset_password_invalid_token_shows_error() {
        let (ath, config) = setup().await;
        let app = test_app(ath, config);

        let resp = app
            .oneshot(get_request("/auth/reset-password?token=invalidtoken"))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let html = body_string(resp).await;
        assert!(html.contains("invalid or has expired"));
        assert!(!html.contains("name=\"new_password\""));
    }

    #[tokio::test]
    async fn get_reset_password_missing_or_empty_token_shows_error() {
        let (ath, config) = setup().await;
        let app = test_app(ath, config);

        for uri in ["/auth/reset-password", "/auth/reset-password?token="] {
            let resp = app.clone().oneshot(get_request(uri)).await.unwrap();
            assert_eq!(resp.status(), StatusCode::OK, "{uri}");
            let html = body_string(resp).await;
            assert!(html.contains("invalid or has expired"), "{uri}");
            assert!(!html.contains("name=\"new_password\""), "{uri}");
        }
    }

    #[tokio::test]
    async fn post_reset_password_passwords_mismatch_shows_error() {
        let (ath, config) = setup().await;
        let token = create_user_and_token(&ath, "mismatch@example.com").await;
        let app = test_app(ath, config);
        let csrf = get_csrf_token(&app, &format!("/auth/reset-password?token={token}")).await;

        let fields =
            format!("token={token}&new_password=NewPass999!&confirm_password=Different1!");
        let resp = app
            .oneshot(form_post("/auth/reset-password", &csrf, &fields))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let html = body_string(resp).await;
        assert!(html.contains("Passwords do not match"));
    }

    #[tokio::test]
    async fn post_reset_password_too_short_shows_error() {
        let (ath, config) = setup().await;
        let token = create_user_and_token(&ath, "short@example.com").await;
        let app = test_app(ath, config);
        let csrf = get_csrf_token(&app, &format!("/auth/reset-password?token={token}")).await;

        let fields = format!("token={token}&new_password=short&confirm_password=short");
        let resp = app
            .oneshot(form_post("/auth/reset-password", &csrf, &fields))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let html = body_string(resp).await;
        assert!(html.contains("Password must be at least 8 characters"));
    }

    #[tokio::test]
    async fn post_reset_password_success_shows_confirmation() {
        let (ath, config) = setup().await;
        let token = create_user_and_token(&ath, "success@example.com").await;
        let app = test_app(ath, config);
        let csrf = get_csrf_token(&app, &format!("/auth/reset-password?token={token}")).await;

        let fields =
            format!("token={token}&new_password=NewPass999!&confirm_password=NewPass999!");
        let resp = app
            .oneshot(form_post("/auth/reset-password", &csrf, &fields))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let html = body_string(resp).await;
        assert!(html.contains("Your password has been reset"));
    }

    #[tokio::test]
    async fn post_reset_password_used_token_shows_invalid() {
        let (ath, config) = setup().await;
        let token = create_user_and_token(&ath, "used@example.com").await;
        // Consume the token directly via the DB before the HTTP request.
        ath.db()
            .execute_reset(&token, "AlreadyUsed1!")
            .await
            .unwrap();

        let app = test_app(ath, config);
        let csrf = get_csrf_token(&app, "/forgot-password").await;

        let fields =
            format!("token={token}&new_password=NewPass999!&confirm_password=NewPass999!");
        let resp = app
            .oneshot(form_post("/auth/reset-password", &csrf, &fields))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let html = body_string(resp).await;
        assert!(html.contains("invalid or has expired"));
    }

    #[tokio::test]
    async fn post_reset_password_token_cannot_be_reused() {
        let (ath, config) = setup().await;
        let token = create_user_and_token(&ath, "reuse@example.com").await;
        let app = test_app(ath, config);
        let fields =
            format!("token={token}&new_password=NewPass999!&confirm_password=NewPass999!");

        let csrf = get_csrf_token(&app, "/forgot-password").await;
        let first = app
            .clone()
            .oneshot(form_post("/auth/reset-password", &csrf, &fields))
            .await
            .unwrap();
        assert_eq!(first.status(), StatusCode::OK);
        assert!(body_string(first).await.contains("Your password has been reset"));

        let csrf = get_csrf_token(&app, "/forgot-password").await;
        let second = app
            .oneshot(form_post("/auth/reset-password", &csrf, &fields))
            .await
            .unwrap();
        assert_eq!(second.status(), StatusCode::OK);
        assert!(body_string(second).await.contains("invalid or has expired"));
    }

    #[tokio::test]
    async fn get_forgot_password_logged_in_redirects_to_root() {
        use allowthem_core::{generate_token, hash_token};
        use chrono::{Duration, Utc};

        let (ath, config) = setup().await;

        // Create a user and an active session.
        let email = Email::new("loggedin@example.com".into()).unwrap();
        let user = ath
            .db()
            .create_user(email, "password123", None, None)
            .await
            .unwrap();
        let token = generate_token();
        let token_hash = hash_token(&token);
        ath.db()
            .create_session(
                user.id,
                token_hash,
                None,
                None,
                Utc::now() + Duration::hours(24),
            )
            .await
            .unwrap();
        let session_cookie = ath.session_cookie(&token);
        let cookie_value = session_cookie.split(';').next().unwrap().to_string();

        let app = test_app(ath, config);
        let req = Request::builder()
            .uri("/forgot-password")
            .header(header::COOKIE, cookie_value)
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();

        assert_eq!(resp.status(), StatusCode::SEE_OTHER);
        assert_eq!(resp.headers().get("location").unwrap(), "/");
    }
}