1use std::sync::Arc;
2
3use axum::Extension;
4use axum::Form;
5use axum::Router;
6use axum::extract::{Query, State};
7use axum::http::HeaderMap;
8use axum::http::StatusCode;
9use axum::http::header::COOKIE;
10use axum::response::{IntoResponse, Response};
11use axum::routing::get;
12use minijinja::{Environment, context};
13use serde::Deserialize;
14
15use allowthem_core::{AllowThem, Email, EmailSender};
16
17use crate::browser_error::BrowserError;
18use crate::csrf::CsrfToken;
19
/// Minimum length a new password must have before the reset form accepts it.
const MIN_PASSWORD_LEN: usize = 8;
21
22#[derive(Clone)]
23struct PasswordResetPageConfig {
24 templates: Arc<Environment<'static>>,
25 is_production: bool,
26 email_sender: Arc<dyn EmailSender>,
27 base_url: String,
28}
29
30#[derive(Deserialize)]
31pub struct ResetTokenQuery {
32 token: Option<String>,
33}
34
35#[derive(Deserialize)]
36pub struct ForgotPasswordForm {
37 email: String,
38 #[allow(dead_code)]
39 csrf_token: String,
40}
41
42#[derive(Deserialize)]
43pub struct ResetPasswordForm {
44 token: String,
45 new_password: String,
46 confirm_password: String,
47 #[allow(dead_code)]
48 csrf_token: String,
49}
50
51async fn get_forgot_password(
53 State(ath): State<AllowThem>,
54 Extension(config): Extension<PasswordResetPageConfig>,
55 headers: HeaderMap,
56 csrf: CsrfToken,
57) -> Result<Response, BrowserError> {
58 if is_authenticated(&ath, &headers).await {
59 return Ok((StatusCode::SEE_OTHER, [(axum::http::header::LOCATION, "/")]).into_response());
60 }
61
62 let html = crate::browser_templates::render(
63 &config.templates,
64 "forgot_password.html",
65 context! {
66 csrf_token => csrf.as_str(),
67 success => false,
68 error => "",
69 is_production => config.is_production,
70 },
71 )?;
72 Ok(html.into_response())
73}
74
75async fn post_forgot_password(
77 State(ath): State<AllowThem>,
78 Extension(config): Extension<PasswordResetPageConfig>,
79 csrf: CsrfToken,
80 Form(form): Form<ForgotPasswordForm>,
81) -> Result<Response, BrowserError> {
82 let email = match Email::new(form.email.clone()) {
83 Ok(e) => e,
84 Err(_) => {
85 let html = crate::browser_templates::render(
86 &config.templates,
87 "forgot_password.html",
88 context! {
89 csrf_token => csrf.as_str(),
90 success => false,
91 error => "Please enter a valid email address.",
92 is_production => config.is_production,
93 },
94 )?;
95 return Ok(html.into_response());
96 }
97 };
98
99 let sender: &dyn EmailSender = &*config.email_sender;
100 if let Err(err) = ath
101 .db()
102 .send_password_reset(&email, &config.base_url, sender)
103 .await
104 {
105 tracing::error!("password reset email error: {err}");
106 }
107
108 let html = crate::browser_templates::render(
109 &config.templates,
110 "forgot_password.html",
111 context! {
112 csrf_token => csrf.as_str(),
113 success => true,
114 error => "",
115 is_production => config.is_production,
116 },
117 )?;
118 Ok(html.into_response())
119}
120
121async fn get_reset_password(
123 State(ath): State<AllowThem>,
124 Extension(config): Extension<PasswordResetPageConfig>,
125 csrf: CsrfToken,
126 Query(query): Query<ResetTokenQuery>,
127) -> Result<Response, BrowserError> {
128 let token = match query.token {
129 Some(ref t) if !t.is_empty() => t.clone(),
130 _ => {
131 let html = crate::browser_templates::render(
132 &config.templates,
133 "reset_password.html",
134 context! {
135 csrf_token => csrf.as_str(),
136 token => "",
137 invalid_token => true,
138 success => false,
139 error => "",
140 is_production => config.is_production,
141 },
142 )?;
143 return Ok(html.into_response());
144 }
145 };
146
147 let valid = ath.db().validate_reset_token(&token).await?;
148
149 if valid.is_some() {
150 let html = crate::browser_templates::render(
151 &config.templates,
152 "reset_password.html",
153 context! {
154 csrf_token => csrf.as_str(),
155 token,
156 invalid_token => false,
157 success => false,
158 error => "",
159 is_production => config.is_production,
160 },
161 )?;
162 Ok(html.into_response())
163 } else {
164 let html = crate::browser_templates::render(
165 &config.templates,
166 "reset_password.html",
167 context! {
168 csrf_token => csrf.as_str(),
169 token => "",
170 invalid_token => true,
171 success => false,
172 error => "",
173 is_production => config.is_production,
174 },
175 )?;
176 Ok(html.into_response())
177 }
178}
179
180async fn post_reset_password(
182 State(ath): State<AllowThem>,
183 Extension(config): Extension<PasswordResetPageConfig>,
184 csrf: CsrfToken,
185 Form(form): Form<ResetPasswordForm>,
186) -> Result<Response, BrowserError> {
187 if form.new_password != form.confirm_password {
189 let html = crate::browser_templates::render(
190 &config.templates,
191 "reset_password.html",
192 context! {
193 csrf_token => csrf.as_str(),
194 token => form.token,
195 invalid_token => false,
196 success => false,
197 error => "Passwords do not match",
198 is_production => config.is_production,
199 },
200 )?;
201 return Ok(html.into_response());
202 }
203
204 if form.new_password.len() < MIN_PASSWORD_LEN {
206 let html = crate::browser_templates::render(
207 &config.templates,
208 "reset_password.html",
209 context! {
210 csrf_token => csrf.as_str(),
211 token => form.token,
212 invalid_token => false,
213 success => false,
214 error => "Password must be at least 8 characters",
215 is_production => config.is_production,
216 },
217 )?;
218 return Ok(html.into_response());
219 }
220
221 match ath
222 .db()
223 .execute_reset(&form.token, &form.new_password)
224 .await?
225 {
226 true => {
227 let html = crate::browser_templates::render(
228 &config.templates,
229 "reset_password.html",
230 context! {
231 csrf_token => csrf.as_str(),
232 token => "",
233 invalid_token => false,
234 success => true,
235 error => "",
236 is_production => config.is_production,
237 },
238 )?;
239 Ok(html.into_response())
240 }
241 false => {
242 let html = crate::browser_templates::render(
243 &config.templates,
244 "reset_password.html",
245 context! {
246 csrf_token => csrf.as_str(),
247 token => "",
248 invalid_token => true,
249 success => false,
250 error => "",
251 is_production => config.is_production,
252 },
253 )?;
254 Ok(html.into_response())
255 }
256 }
257}
258
259async fn is_authenticated(ath: &AllowThem, headers: &HeaderMap) -> bool {
261 let Some(cookie_header) = headers.get(COOKIE).and_then(|v| v.to_str().ok()) else {
262 return false;
263 };
264 let Some(token) = ath.parse_session_cookie(cookie_header) else {
265 return false;
266 };
267 let ttl = ath.session_config().ttl;
268 ath.db()
269 .validate_session(&token, ttl)
270 .await
271 .unwrap_or(None)
272 .is_some()
273}
274
275pub fn password_reset_page_routes(
276 templates: Arc<Environment<'static>>,
277 is_production: bool,
278 email_sender: Arc<dyn EmailSender>,
279 base_url: String,
280) -> Router<AllowThem> {
281 let cfg = PasswordResetPageConfig {
282 templates,
283 is_production,
284 email_sender,
285 base_url,
286 };
287 Router::new()
288 .route(
289 "/forgot-password",
290 get(get_forgot_password).post(post_forgot_password),
291 )
292 .route(
293 "/auth/reset-password",
294 get(get_reset_password).post(post_reset_password),
295 )
296 .layer(Extension(cfg))
297}
298
#[cfg(test)]
mod tests {
    //! Router-level tests for the password reset pages.
    //!
    //! Each test builds a fresh app, drives it with `oneshot` requests and
    //! checks the status code and the rendered HTML.

    use std::sync::Arc;

    use axum::Router;
    use axum::body::Body;
    use axum::http::{Request, StatusCode, header};
    use tower::ServiceExt;

    use allowthem_core::{AllowThem, AllowThemBuilder, Email, LogEmailSender};

    use super::{PasswordResetPageConfig, password_reset_page_routes};

    /// Content type of an HTML form post.
    const FORM_CONTENT_TYPE: &str = "application/x-www-form-urlencoded";

    /// Creates an `AllowThem` instance on `sqlite::memory:` and a page config
    /// that uses the default browser templates and `LogEmailSender`.
    async fn setup() -> (AllowThem, PasswordResetPageConfig) {
        let ath = AllowThemBuilder::new("sqlite::memory:")
            .cookie_secure(false)
            .csrf_key(*b"test-csrf-key-for-binary-tests!!")
            .build()
            .await
            .unwrap();
        let config = PasswordResetPageConfig {
            templates: crate::browser_templates::build_default_browser_env(),
            is_production: false,
            email_sender: Arc::new(LogEmailSender),
            base_url: "http://localhost:3000".into(),
        };
        (ath, config)
    }

    /// Assembles the routes under test with the CSRF middleware layered on.
    fn test_app(ath: AllowThem, config: PasswordResetPageConfig) -> Router {
        password_reset_page_routes(
            config.templates.clone(),
            config.is_production,
            config.email_sender.clone(),
            config.base_url.clone(),
        )
        .layer(axum::middleware::from_fn_with_state(
            ath.clone(),
            crate::csrf::csrf_middleware,
        ))
        .with_state(ath)
    }

    /// Sends a plain `GET` for `uri` and returns the response.
    async fn get_page(app: Router, uri: &str) -> axum::http::Response<Body> {
        let req = Request::builder().uri(uri).body(Body::empty()).unwrap();
        app.oneshot(req).await.unwrap()
    }

    /// Posts a urlencoded form to `uri`.
    ///
    /// `fields` is the already-encoded body without the CSRF field; the token
    /// is appended as `csrf_token` and also sent as the `csrf_pre` cookie.
    async fn post_form(
        app: Router,
        uri: &str,
        csrf: &str,
        fields: &str,
    ) -> axum::http::Response<Body> {
        let req = Request::builder()
            .method("POST")
            .uri(uri)
            .header(header::CONTENT_TYPE, FORM_CONTENT_TYPE)
            .header(header::COOKIE, format!("csrf_pre={csrf}"))
            .body(Body::from(format!("{fields}&csrf_token={csrf}")))
            .unwrap();
        app.oneshot(req).await.unwrap()
    }

    /// Fetches `path` and returns the value of the cookie set by the response.
    ///
    /// Only the first `Set-Cookie` header is inspected. The value is taken as
    /// everything after the first `=` of the `name=value` pair, so a value
    /// that itself contains `=` (for example base64 padding) is kept whole.
    async fn get_csrf_token(app: &Router, path: &str) -> String {
        let resp = get_page(app.clone(), path).await;
        let set_cookie = resp
            .headers()
            .get(header::SET_COOKIE)
            .expect("response should set a cookie")
            .to_str()
            .expect("Set-Cookie header should be valid text");
        let name_value = set_cookie.split(';').next().unwrap_or(set_cookie);
        let (_name, value) = name_value
            .split_once('=')
            .expect("cookie should have the form name=value");
        value.to_string()
    }

    /// Reads the whole response body as UTF-8 text.
    async fn body_string(resp: axum::http::Response<Body>) -> String {
        let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    /// Asserts a `200 OK` status and returns the rendered HTML.
    async fn assert_ok_html(resp: axum::http::Response<Body>) -> String {
        assert_eq!(resp.status(), StatusCode::OK);
        body_string(resp).await
    }

    /// Registers a user with `email_str` and returns a password reset token
    /// created for that user.
    async fn create_user_and_token(ath: &AllowThem, email_str: &str) -> String {
        let email = Email::new(email_str.into()).unwrap();
        ath.db()
            .create_user(email.clone(), "OldPass123!", None, None)
            .await
            .unwrap();
        ath.db()
            .create_password_reset(&email)
            .await
            .unwrap()
            .unwrap()
    }

    /// Builds the reset form fields (without the CSRF field).
    fn reset_fields(token: &str, new_password: &str, confirm_password: &str) -> String {
        format!("token={token}&new_password={new_password}&confirm_password={confirm_password}")
    }

    #[tokio::test]
    async fn get_forgot_password_renders_form() {
        let (ath, config) = setup().await;
        let app = test_app(ath, config);

        let resp = get_page(app, "/forgot-password").await;

        let html = assert_ok_html(resp).await;
        assert!(html.contains("<form"));
        assert!(html.contains("name=\"email\""));
    }

    #[tokio::test]
    async fn post_forgot_password_valid_email_shows_success() {
        let (ath, config) = setup().await;
        let email = Email::new("reset@example.com".into()).unwrap();
        ath.db()
            .create_user(email, "Pass123!", None, None)
            .await
            .unwrap();
        let app = test_app(ath, config);
        let csrf = get_csrf_token(&app, "/forgot-password").await;

        let resp = post_form(app, "/forgot-password", &csrf, "email=reset%40example.com").await;

        let html = assert_ok_html(resp).await;
        assert!(html.contains("If an account with that email exists"));
    }

    #[tokio::test]
    async fn post_forgot_password_unknown_email_shows_success() {
        let (ath, config) = setup().await;
        let app = test_app(ath, config);
        let csrf = get_csrf_token(&app, "/forgot-password").await;

        let resp = post_form(app, "/forgot-password", &csrf, "email=nobody%40example.com").await;

        // Same page as for a known address: the response must not reveal
        // whether an account exists.
        let html = assert_ok_html(resp).await;
        assert!(html.contains("If an account with that email exists"));
    }

    #[tokio::test]
    async fn post_forgot_password_invalid_email_shows_error() {
        let (ath, config) = setup().await;
        let app = test_app(ath, config);
        let csrf = get_csrf_token(&app, "/forgot-password").await;

        let resp = post_form(app, "/forgot-password", &csrf, "email=notanemail").await;

        let html = assert_ok_html(resp).await;
        assert!(html.contains("Please enter a valid email address."));
    }

    #[tokio::test]
    async fn get_reset_password_valid_token_renders_form() {
        let (ath, config) = setup().await;
        let token = create_user_and_token(&ath, "tok@example.com").await;
        let app = test_app(ath, config);

        let resp = get_page(app, &format!("/auth/reset-password?token={token}")).await;

        let html = assert_ok_html(resp).await;
        assert!(html.contains("name=\"new_password\""));
        assert!(html.contains("name=\"confirm_password\""));
    }

    #[tokio::test]
    async fn get_reset_password_invalid_token_shows_error() {
        let (ath, config) = setup().await;
        let app = test_app(ath, config);

        let resp = get_page(app, "/auth/reset-password?token=invalidtoken").await;

        let html = assert_ok_html(resp).await;
        assert!(html.contains("invalid or has expired"));
        assert!(!html.contains("name=\"new_password\""));
    }

    #[tokio::test]
    async fn post_reset_password_passwords_mismatch_shows_error() {
        let (ath, config) = setup().await;
        let token = create_user_and_token(&ath, "mismatch@example.com").await;
        let app = test_app(ath, config);
        let csrf = get_csrf_token(&app, &format!("/auth/reset-password?token={token}")).await;

        let fields = reset_fields(&token, "NewPass999!", "Different1!");
        let resp = post_form(app, "/auth/reset-password", &csrf, &fields).await;

        let html = assert_ok_html(resp).await;
        assert!(html.contains("Passwords do not match"));
    }

    #[tokio::test]
    async fn post_reset_password_too_short_shows_error() {
        let (ath, config) = setup().await;
        let token = create_user_and_token(&ath, "short@example.com").await;
        let app = test_app(ath, config);
        let csrf = get_csrf_token(&app, &format!("/auth/reset-password?token={token}")).await;

        let fields = reset_fields(&token, "short", "short");
        let resp = post_form(app, "/auth/reset-password", &csrf, &fields).await;

        let html = assert_ok_html(resp).await;
        assert!(html.contains("Password must be at least 8 characters"));
    }

    #[tokio::test]
    async fn post_reset_password_success_shows_confirmation() {
        let (ath, config) = setup().await;
        let token = create_user_and_token(&ath, "success@example.com").await;
        let app = test_app(ath, config);
        let csrf = get_csrf_token(&app, &format!("/auth/reset-password?token={token}")).await;

        let fields = reset_fields(&token, "NewPass999!", "NewPass999!");
        let resp = post_form(app, "/auth/reset-password", &csrf, &fields).await;

        let html = assert_ok_html(resp).await;
        assert!(html.contains("Your password has been reset"));
    }

    #[tokio::test]
    async fn post_reset_password_used_token_shows_invalid() {
        let (ath, config) = setup().await;
        let token = create_user_and_token(&ath, "used@example.com").await;
        // Spend the token once so the request below replays it.
        ath.db()
            .execute_reset(&token, "AlreadyUsed1!")
            .await
            .unwrap();

        let app = test_app(ath, config);
        let csrf = get_csrf_token(&app, "/forgot-password").await;

        let fields = reset_fields(&token, "NewPass999!", "NewPass999!");
        let resp = post_form(app, "/auth/reset-password", &csrf, &fields).await;

        let html = assert_ok_html(resp).await;
        assert!(html.contains("invalid or has expired"));
    }

    #[tokio::test]
    async fn get_forgot_password_logged_in_redirects_to_root() {
        use allowthem_core::{generate_token, hash_token};
        use chrono::{Duration, Utc};

        let (ath, config) = setup().await;

        let email = Email::new("loggedin@example.com".into()).unwrap();
        let user = ath
            .db()
            .create_user(email, "password123", None, None)
            .await
            .unwrap();
        let token = generate_token();
        ath.db()
            .create_session(
                user.id,
                hash_token(&token),
                None,
                None,
                Utc::now() + Duration::hours(24),
            )
            .await
            .unwrap();
        // Keep only the leading `name=value` pair; the attributes after the
        // first `;` do not belong in a request `Cookie` header.
        let session_cookie = ath.session_cookie(&token);
        let cookie_value = session_cookie.split(';').next().unwrap().to_string();

        let app = test_app(ath, config);
        let req = Request::builder()
            .uri("/forgot-password")
            .header(header::COOKIE, cookie_value)
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();

        assert_eq!(resp.status(), StatusCode::SEE_OTHER);
        assert_eq!(resp.headers().get("location").unwrap(), "/");
    }
}