use std::sync::Arc;
use serde::Serialize;
use crate::auth::recovery::{
check_reset_token_valid, consume_reset_token, issue_reset_token, ConsumeOutcome,
PasswordPolicyError,
};
use crate::error::Result;
use crate::http::{Request, Response};
use crate::middleware::{CorrelationId, RateLimiter};
use super::handlers::{csrf_token, AdminCtx};
use super::render::{BaseContext, FlashCtx, FormField, FormSection};
use super::types::Admin;
/// Rate limiters used by the password-recovery handlers in this module.
///
/// Both limiters sit behind an `Arc`, so cloning this struct yields handles to
/// the same limiter instances and does not create fresh ones.
#[derive(Clone)]
pub(crate) struct RecoveryState {
/// Limiter passed to `issue_reset_token` by `do_forgot_password`.
pub request_limiter: Arc<RateLimiter>,
/// Limiter passed to `consume_reset_token` by `do_reset_password`.
pub consume_limiter: Arc<RateLimiter>,
}
impl RecoveryState {
    /// Builds the request and consume limiters from the admin's active
    /// recovery policy.
    ///
    /// Each pair returned by `request_rate_limit()` / `consume_rate_limit()`
    /// is forwarded unchanged, in the same order, to `RateLimiter::new`.
    pub fn from_admin(admin: &Admin) -> Self {
        let recovery_policy = admin.active_recovery_policy();
        let (request_cap, request_window) = recovery_policy.request_rate_limit();
        let (consume_cap, consume_window) = recovery_policy.consume_rate_limit();

        let request_limiter = Arc::new(RateLimiter::new(request_cap, request_window));
        let consume_limiter = Arc::new(RateLimiter::new(consume_cap, consume_window));

        Self {
            request_limiter,
            consume_limiter,
        }
    }
}
/// Returns an owned copy of the correlation id stored in the request context.
///
/// Yields `None` when the context holds no `CorrelationId` entry.
fn correlation_id(req: &Request) -> Option<String> {
    req.ctx()
        .get::<CorrelationId>()
        .map(|stored| stored.0.clone())
}
/// Template context for `admin/forgot_password.html`.
///
/// Field names and order are part of the serialized output the template
/// receives.
#[derive(Serialize)]
struct ForgotPasswordCtx {
/// Shared page context; flattened so its keys serialize at the top level.
#[serde(flatten)]
base: BaseContext,
/// Page heading / title text.
page_title: &'static str,
/// Form layout produced by `forgot_password_form_sections`.
sections: Vec<FormSection>,
/// Optional page-level error message; every construction visible in this
/// file sets it to `None`.
error: Option<String>,
/// Correlation id of the request, when one is present in its context.
correlation_id: Option<String>,
}
/// Template context for `admin/forgot_password_sent.html`, the page the
/// forgot-password POST handler redirects to.
#[derive(Serialize)]
struct ForgotPasswordSentCtx {
/// Shared page context; flattened so its keys serialize at the top level.
#[serde(flatten)]
base: BaseContext,
/// Page heading / title text.
page_title: &'static str,
/// Correlation id of the request, when one is present in its context.
correlation_id: Option<String>,
}
/// Template context for `admin/reset_password.html`.
///
/// The same template serves the live form, the form re-rendered with
/// validation errors, and the "link no longer valid" page; `invalid` is the
/// flag that distinguishes the last case.
#[derive(Serialize)]
struct ResetPasswordCtx {
/// Shared page context; flattened so its keys serialize at the top level.
#[serde(flatten)]
base: BaseContext,
/// Page heading / title text.
page_title: &'static str,
/// `true` when the token was rejected; `sections` is left empty in that case.
invalid: bool,
/// The reset token taken from the request, echoed back to the template.
token: String,
/// Minimum password length reported by the active password policy.
min_length: usize,
/// Form layout produced by `reset_password_form_sections` (empty when `invalid`).
sections: Vec<FormSection>,
/// Flat list of every field error message, in addition to the per-field
/// copies carried inside `sections`.
errors: Vec<String>,
/// Optional flash message; every construction visible in this file sets it
/// to `None`.
flash: Option<FlashCtx>,
/// Correlation id of the request, when one is present in its context.
correlation_id: Option<String>,
}
/// Builds the single-section, single-field layout of the forgot-password form.
///
/// `prefilled_email` becomes the initial value of the required, autofocused
/// `email` input.
fn forgot_password_form_sections(prefilled_email: &str) -> Vec<FormSection> {
    let email_field = FormField {
        name: "email",
        label: String::from("Email"),
        widget: "input",
        input_type: "email",
        value: prefilled_email.to_owned(),
        hint: None,
        placeholder: None,
        required: true,
        options: None,
        multiple: false,
        span: 2,
        autocomplete: Some("username"),
        autofocus: true,
        disabled: false,
        maxlength: None,
        searchable: false,
        has_more: false,
        search_url: None,
        errors: Vec::new(),
        target_model: None,
        checked: false,
    };

    let only_section = FormSection {
        title: None,
        fields: vec![email_field],
    };
    vec![only_section]
}
/// Builds the two-field layout of the reset-password form.
///
/// * `min_length` is interpolated into the hint under the first password field.
/// * `field_errors` is a list of `(field name, message)` pairs; each message
///   is attached to the field whose name matches, in input order.
///
/// Both inputs are always rendered empty: submitted passwords are never
/// echoed back into the form.
fn reset_password_form_sections(
    min_length: usize,
    field_errors: &[(&'static str, String)],
) -> Vec<FormSection> {
    // Every message recorded against `field`, preserving input order.
    let messages_for = |field: &str| -> Vec<String> {
        field_errors
            .iter()
            .filter_map(|(name, message)| (*name == field).then(|| message.clone()))
            .collect()
    };

    // The two inputs differ only in name, label, hint, autofocus and errors.
    let password_field =
        |name: &'static str, label: &str, hint: Option<String>, autofocus: bool| FormField {
            name,
            label: label.to_string(),
            widget: "input",
            input_type: "password",
            value: String::new(),
            hint,
            placeholder: None,
            required: true,
            options: None,
            multiple: false,
            span: 2,
            autocomplete: Some("new-password"),
            autofocus,
            disabled: false,
            maxlength: None,
            searchable: false,
            has_more: false,
            search_url: None,
            errors: messages_for(name),
            target_model: None,
            checked: false,
        };

    let length_hint = format!("At least {min_length} characters.");
    let new_password = password_field("new_password1", "New password", Some(length_hint), true);
    let confirmation = password_field("new_password2", "Confirm", None, false);

    vec![FormSection {
        title: None,
        fields: vec![new_password, confirmation],
    }]
}
/// Renders the blank forgot-password form (`admin/forgot_password.html`).
///
/// # Errors
/// Propagates any error returned by the template renderer.
pub(crate) async fn show_forgot_password(ctx: &AdminCtx, req: &Request) -> Result<Response> {
    let base = BaseContext::new(None, csrf_token(req), &ctx.admin);
    let sections = forgot_password_form_sections("");
    let page = ForgotPasswordCtx {
        base,
        page_title: "Reset your password",
        sections,
        error: None,
        correlation_id: correlation_id(req),
    };

    let html = ctx.templates.render("admin/forgot_password.html", &page)?;
    Ok(Response::html(html))
}
/// Handles the forgot-password form submission.
///
/// Reads the `email` form field (a missing field is treated as an empty
/// string), asks `issue_reset_token` to process it using the request limiter,
/// and redirects to the "sent" page.
///
/// # Errors
/// Propagates form-parsing errors and any error from `issue_reset_token`.
pub(crate) async fn do_forgot_password(
    ctx: &AdminCtx,
    state: &RecoveryState,
    req: Request,
) -> Result<Response> {
    let cid = correlation_id(&req);
    let form = req.form()?;
    let submitted_email = form
        .get("email")
        .map(str::to_owned)
        .unwrap_or_default();

    // The outcome is deliberately unused: every successful call ends in the
    // same redirect. NOTE(review): presumably so the response does not reveal
    // whether the address belongs to an account — confirm.
    let _outcome = issue_reset_token(
        &ctx.db,
        &ctx.admin,
        &state.request_limiter,
        &req,
        &submitted_email,
        cid.as_deref(),
    )
    .await?;

    Ok(Response::redirect("/admin/forgot-password/sent"))
}
/// Renders the "check your email" confirmation page
/// (`admin/forgot_password_sent.html`).
///
/// # Errors
/// Propagates any error returned by the template renderer.
pub(crate) async fn show_forgot_password_sent(ctx: &AdminCtx, req: &Request) -> Result<Response> {
    let base = BaseContext::new(None, csrf_token(req), &ctx.admin);
    let page = ForgotPasswordSentCtx {
        base,
        page_title: "Check your email",
        correlation_id: correlation_id(req),
    };

    let html = ctx
        .templates
        .render("admin/forgot_password_sent.html", &page)?;
    Ok(Response::html(html))
}
/// Renders the reset-password page for `token`.
///
/// When `check_reset_token_valid` accepts the token the password form is
/// shown; otherwise the same template is rendered in its "invalid link" state
/// with no form sections.
///
/// # Errors
/// Propagates errors from the token check and from the template renderer.
pub(crate) async fn show_reset_password(
    ctx: &AdminCtx,
    req: &Request,
    token: &str,
) -> Result<Response> {
    let token_is_live = check_reset_token_valid(&ctx.db, token).await?;
    let min_length = ctx.admin.active_password_policy().min_length();
    let base = BaseContext::new(None, csrf_token(req), &ctx.admin);

    // Title and form layout are decided together from the token state.
    let (page_title, sections) = if token_is_live {
        (
            "Set a new password",
            reset_password_form_sections(min_length, &[]),
        )
    } else {
        ("This link is no longer valid", Vec::new())
    };

    let page = ResetPasswordCtx {
        base,
        page_title,
        invalid: !token_is_live,
        token: token.to_owned(),
        min_length,
        sections,
        errors: Vec::new(),
        flash: None,
        correlation_id: correlation_id(req),
    };

    let html = ctx.templates.render("admin/reset_password.html", &page)?;
    Ok(Response::html(html))
}
/// Handles the reset-password form submission for `token`.
///
/// Flow:
/// 1. Read `new_password1` / `new_password2` (missing fields count as empty).
/// 2. If they differ, re-render the form with an error on the confirmation
///    field; `consume_reset_token` is not called in that case.
/// 3. Otherwise call `consume_reset_token` and map its outcome:
///    * `Consumed` → redirect to the login page with a success marker,
///    * `Invalid` / `RateLimited` → the "invalid link" page (both look the same),
///    * `PolicyRejected` → the form again, with the policy message on the
///      first password field.
///
/// # Errors
/// Propagates form-parsing errors, errors from `consume_reset_token`, and
/// template rendering errors.
pub(crate) async fn do_reset_password(
    ctx: &AdminCtx,
    state: &RecoveryState,
    req: Request,
    token: &str,
) -> Result<Response> {
    let cid = correlation_id(&req);
    let form = req.form()?;
    let password = form.get("new_password1").unwrap_or("").to_string();
    let confirmation = form.get("new_password2").unwrap_or("").to_string();
    let min_length = ctx.admin.active_password_policy().min_length();

    if password != confirmation {
        let mismatch = [(
            "new_password2",
            String::from("The two password fields didn't match."),
        )];
        return render_reset_password_form_error(ctx, &req, token, min_length, &mismatch, cid);
    }

    let outcome = consume_reset_token(
        &ctx.db,
        &ctx.admin,
        &state.consume_limiter,
        &req,
        token,
        &password,
        cid.as_deref(),
    )
    .await?;

    match outcome {
        ConsumeOutcome::Consumed { .. } => {
            Ok(Response::redirect("/admin/login?password_reset=success"))
        }
        ConsumeOutcome::PolicyRejected(err) => {
            let rejection = [("new_password1", render_password_policy_error(&err))];
            render_reset_password_form_error(ctx, &req, token, min_length, &rejection, cid)
        }
        ConsumeOutcome::Invalid | ConsumeOutcome::RateLimited => {
            render_reset_password_invalid(ctx, &req, token, min_length, cid)
        }
    }
}
/// Re-renders the reset-password form with validation errors and a
/// `400 Bad Request` status.
///
/// Each `(field name, message)` pair in `field_errors` is attached to its
/// field by `reset_password_form_sections`; the messages are also exposed to
/// the template as one flat `errors` list.
///
/// # Errors
/// Propagates any error returned by the template renderer.
fn render_reset_password_form_error(
    ctx: &AdminCtx,
    req: &Request,
    token: &str,
    min_length: usize,
    field_errors: &[(&'static str, String)],
    cid: Option<String>,
) -> Result<Response> {
    let base = BaseContext::new(None, csrf_token(req), &ctx.admin);
    let sections = reset_password_form_sections(min_length, field_errors);
    let errors = field_errors
        .iter()
        .map(|(_, message)| message.clone())
        .collect();

    let page = ResetPasswordCtx {
        base,
        page_title: "Set a new password",
        invalid: false,
        token: token.to_owned(),
        min_length,
        sections,
        errors,
        flash: None,
        correlation_id: cid,
    };

    let html = ctx.templates.render("admin/reset_password.html", &page)?;
    Ok(Response::html(html).with_status(hyper::StatusCode::BAD_REQUEST))
}
/// Renders the reset-password template in its "link no longer valid" state.
///
/// No form sections or error messages are supplied, and the response keeps
/// the status that `Response::html` assigns (no override is applied here).
///
/// # Errors
/// Propagates any error returned by the template renderer.
fn render_reset_password_invalid(
    ctx: &AdminCtx,
    req: &Request,
    token: &str,
    min_length: usize,
    cid: Option<String>,
) -> Result<Response> {
    let base = BaseContext::new(None, csrf_token(req), &ctx.admin);
    let page = ResetPasswordCtx {
        base,
        page_title: "This link is no longer valid",
        invalid: true,
        token: token.to_owned(),
        min_length,
        sections: Vec::new(),
        errors: Vec::new(),
        flash: None,
        correlation_id: cid,
    };

    let html = ctx.templates.render("admin/reset_password.html", &page)?;
    Ok(Response::html(html))
}
/// Turns a password-policy rejection into the message shown on the reset form.
///
/// The result is exactly what `to_string()` yields for the error; nothing is
/// added or removed here. `do_reset_password` attaches the returned text to
/// the `new_password1` field. A unit test in this file asserts that the
/// rendered text does not contain the rejected password.
fn render_password_policy_error(err: &PasswordPolicyError) -> String {
err.to_string()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::{DefaultPasswordPolicy, PasswordPolicy};

    /// A default `Admin` must yield limiters that admit a first request.
    #[test]
    fn recovery_state_constructs_from_default_admin() {
        let admin = Admin::new();
        let state = RecoveryState::from_admin(&admin);
        assert!(state.request_limiter.allow("test-ip-1"));
        assert!(state.consume_limiter.allow("test-ip-1"));
    }

    /// The hint under the first password field tracks the `min_length`
    /// argument and contains no hard-coded number.
    #[test]
    fn reset_password_form_renders_live_policy_minimum() {
        let sections = reset_password_form_sections(10, &[]);
        assert_eq!(sections.len(), 1);
        // The borrows below had been mangled to `§ions` / `§ion` (the `&sect`
        // prefix decoded as an HTML entity), which does not compile.
        let section = &sections[0];
        assert_eq!(section.fields.len(), 2);
        let pw1 = &section.fields[0];
        assert_eq!(pw1.name, "new_password1");
        assert_eq!(
            pw1.hint.as_deref(),
            Some("At least 10 characters."),
            "hint must reflect the live policy minimum"
        );

        let sections = reset_password_form_sections(16, &[]);
        assert_eq!(
            sections[0].fields[0].hint.as_deref(),
            Some("At least 16 characters."),
        );
    }

    /// Each `(field, message)` pair lands on the field it names.
    #[test]
    fn reset_password_form_routes_field_errors_to_named_field() {
        let sections = reset_password_form_sections(
            10,
            &[
                ("new_password1", "policy err".into()),
                ("new_password2", "match err".into()),
            ],
        );
        assert_eq!(sections[0].fields[0].errors, vec!["policy err"]);
        assert_eq!(sections[0].fields[1].errors, vec!["match err"]);
    }

    /// The forgot-password form exposes exactly one autofocused email input.
    #[test]
    fn forgot_password_form_has_only_email_field() {
        let sections = forgot_password_form_sections("");
        assert_eq!(sections.len(), 1);
        assert_eq!(sections[0].fields.len(), 1);
        assert_eq!(sections[0].fields[0].name, "email");
        assert_eq!(sections[0].fields[0].input_type, "email");
        assert!(sections[0].fields[0].autofocus);
    }

    /// The user-facing policy message must never echo the rejected password.
    #[test]
    fn render_password_policy_error_does_not_leak_plaintext() {
        let policy = DefaultPasswordPolicy::new();
        let plaintext = "Pwn4Ge#zZ";
        let err = policy
            .validate(plaintext)
            .expect_err("9-char password should fail the 10-char floor");
        let rendered = render_password_policy_error(&err);
        assert!(
            !rendered.contains(plaintext),
            "rendered policy error leaked plaintext: {rendered}"
        );
    }

    /// A missing correlation id serializes as an explicit JSON `null`.
    #[test]
    fn render_contexts_tolerate_absent_correlation_id() {
        let admin = Admin::new();
        let base = BaseContext::new(None, "csrf-test".into(), &admin);
        let ctx = ForgotPasswordCtx {
            base,
            page_title: "Reset your password",
            sections: forgot_password_form_sections(""),
            error: None,
            correlation_id: None,
        };
        let json = serde_json::to_string(&ctx).expect("serializes cleanly with absent cid");
        assert!(json.contains("\"correlation_id\":null"));
    }
}