use std::sync::Arc;
use std::time::Duration as StdDuration;
use chrono::Duration as ChronoDuration;
use crate::admin::audit::{record as audit_record, ActionType, AuditEvent, LogEntry};
use crate::admin::redact::redact_token;
use crate::admin::Admin;
use crate::auth::sessions::{hash_token_for_storage, random_token};
use crate::auth::users::{find_user_by_email, Identity};
use crate::auth::{invalidate_sessions, set_password, SessionInvalidationReason, SessionTarget};
use crate::email::Mail;
use crate::error::Result;
use crate::http::Request;
use crate::middleware::RateLimiter;
use crate::orm::Db;
/// Creates the password-reset token table and its supporting indexes.
///
/// Every statement uses `IF NOT EXISTS`, so re-running this is a no-op at the
/// schema level. Statements execute one at a time in the order listed; the first
/// database error aborts the remainder and is returned to the caller.
///
/// Schema notes:
/// - rows are removed automatically when the owning `rustio_users` row is deleted
///   (`ON DELETE CASCADE`);
/// - the unique index on `token_hash` is partial (`WHERE consumed_at IS NULL`), so
///   uniqueness is only enforced among tokens that have not been consumed yet;
/// - `mail_status` is restricted by a CHECK constraint to `pending`, `sent`, `failed`.
pub(crate) async fn init_recovery_tables(db: &Db) -> Result<()> {
    const DDL: [&str; 4] = [
        "CREATE TABLE IF NOT EXISTS rustio_password_reset_tokens (
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL REFERENCES rustio_users(id) ON DELETE CASCADE,
token_hash TEXT NOT NULL,
requested_ip TEXT,
requested_user_agent TEXT,
requested_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
expires_at TIMESTAMPTZ NOT NULL,
consumed_at TIMESTAMPTZ,
mail_status TEXT NOT NULL DEFAULT 'pending'
CHECK (mail_status IN ('pending', 'sent', 'failed')),
correlation_id TEXT
)",
        "CREATE UNIQUE INDEX IF NOT EXISTS rustio_password_reset_tokens_active_uq \
         ON rustio_password_reset_tokens (token_hash) \
         WHERE consumed_at IS NULL",
        "CREATE INDEX IF NOT EXISTS rustio_password_reset_tokens_user_idx \
         ON rustio_password_reset_tokens (user_id)",
        "CREATE INDEX IF NOT EXISTS rustio_password_reset_tokens_expires_idx \
         ON rustio_password_reset_tokens (expires_at) \
         WHERE consumed_at IS NULL",
    ];

    // Sequential on purpose: the indexes depend on the table existing first.
    for statement in DDL {
        sqlx::query(statement).execute(db.pool()).await?;
    }
    Ok(())
}
/// Adds the recovery-related columns to `rustio_users` if they are missing.
///
/// - `must_change_password` (`BOOLEAN NOT NULL DEFAULT FALSE`)
/// - `password_changed_at` (nullable `TIMESTAMPTZ`)
///
/// Both statements use `ADD COLUMN IF NOT EXISTS`, so repeated runs are harmless.
/// They run in order; the first failure is returned and the rest are skipped.
pub(crate) async fn migrate_user_recovery_schema(db: &Db) -> Result<()> {
    const ALTERATIONS: [&str; 2] = [
        "ALTER TABLE rustio_users \
         ADD COLUMN IF NOT EXISTS must_change_password BOOLEAN NOT NULL DEFAULT FALSE",
        "ALTER TABLE rustio_users ADD COLUMN IF NOT EXISTS password_changed_at TIMESTAMPTZ",
    ];

    for alteration in ALTERATIONS {
        sqlx::query(alteration).execute(db.pool()).await?;
    }
    Ok(())
}
/// Pluggable rule set a candidate password must satisfy before it is stored.
///
/// Implementations must be `Send + Sync` so they can be shared behind an `Arc`
/// (see [`SharedPasswordPolicy`]).
pub trait PasswordPolicy: Send + Sync {
    /// Returns `Ok(())` when `candidate` is acceptable, otherwise the reason it is not.
    fn validate(&self, candidate: &str) -> std::result::Result<(), PasswordPolicyError>;
    /// The minimum length this policy reports.
    fn min_length(&self) -> usize;
}

/// Thread-safe, shareable handle to any [`PasswordPolicy`].
pub type SharedPasswordPolicy = Arc<dyn PasswordPolicy>;

/// Why a candidate password was rejected.
///
/// Variants carry only lengths or a policy-supplied message, never the candidate
/// password itself, so `Display`/`Debug` output is safe to show or log.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum PasswordPolicyError {
    /// Fewer characters than the configured minimum.
    TooShort { min: usize, actual: usize },
    /// Free-form rejection from a custom policy; rendered verbatim by `Display`.
    Custom(String),
}

impl std::fmt::Display for PasswordPolicyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Custom(message) => f.write_str(message),
            Self::TooShort { min, actual } => write!(
                f,
                "This password is too short. It must contain at least {} characters \
                 (you entered {}).",
                min, actual
            ),
        }
    }
}

impl std::error::Error for PasswordPolicyError {}

/// Length-only password policy; the default floor is 10 characters.
#[derive(Debug, Clone, Copy)]
pub struct DefaultPasswordPolicy {
    /// Minimum number of characters (Unicode scalar values, not bytes).
    pub min_len: usize,
}

impl DefaultPasswordPolicy {
    /// Policy with the default 10-character floor.
    pub const fn new() -> Self {
        Self::with_min_len(10)
    }

    /// Policy with a caller-chosen floor.
    pub const fn with_min_len(min_len: usize) -> Self {
        Self { min_len }
    }
}

impl Default for DefaultPasswordPolicy {
    fn default() -> Self {
        Self::new()
    }
}

impl PasswordPolicy for DefaultPasswordPolicy {
    fn validate(&self, candidate: &str) -> std::result::Result<(), PasswordPolicyError> {
        // `chars()` counts scalar values, so multi-byte scripts are not penalised
        // or favoured by their UTF-8 byte length.
        let actual = candidate.chars().count();
        if actual >= self.min_len {
            Ok(())
        } else {
            Err(PasswordPolicyError::TooShort {
                min: self.min_len,
                actual,
            })
        }
    }

    fn min_length(&self) -> usize {
        self.min_len
    }
}
/// Sign-in throttling parameters exposed by a recovery policy.
///
/// The enforcement logic is not in this file; the meanings below follow the field
/// names and should be confirmed against the consumer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LoginThrottle {
    /// Attempt budget per window (presumably failed sign-ins — TODO confirm).
    pub max_attempts: u32,
    /// Length of the counting window, in minutes.
    pub window_minutes: i64,
    /// Lock-out duration once the budget is exhausted, in minutes.
    pub lock_minutes: i64,
}

impl LoginThrottle {
    /// 5 attempts per 10-minute window, then a 15-minute lock.
    pub const DEFAULT: Self = Self {
        lock_minutes: 15,
        window_minutes: 10,
        max_attempts: 5,
    };
}

impl Default for LoginThrottle {
    /// Same value as [`LoginThrottle::DEFAULT`].
    fn default() -> Self {
        Self::DEFAULT
    }
}
/// Tunables for account recovery, re-authentication and MFA.
///
/// Only `reset_token_ttl` and `public_site_url` are read in this file (by
/// `issue_reset_token`); the remaining knobs are consumed elsewhere.
pub trait RecoveryPolicy: Send + Sync {
    /// How long an issued reset token stays redeemable.
    fn reset_token_ttl(&self) -> ChronoDuration;
    /// `(capacity, window)` for reset *requests* (presumably used to build the
    /// request-side `RateLimiter` — TODO confirm).
    fn request_rate_limit(&self) -> (u32, StdDuration);
    /// `(capacity, window)` for reset *redemptions*.
    fn consume_rate_limit(&self) -> (u32, StdDuration);
    /// Whether a working mailer is mandatory; not read in this file.
    fn strict_mailer_required(&self) -> bool;

    /// Origin (`scheme://host[:port]`) used to build emailed reset links.
    ///
    /// The default derives it from the `Forwarded`, `X-Forwarded-Proto/Host` and
    /// `Host` request headers.
    ///
    /// NOTE(review): unless a trusted proxy rewrites those headers, a client can
    /// choose them and so choose the domain in a reset link sent to someone else's
    /// inbox (reset-link poisoning). Production deployments should override this
    /// with a configured canonical URL.
    fn public_site_url(&self, req: &Request) -> Option<String> {
        derive_public_site_url(|header_name| req.header(header_name).map(|value| value.to_string()))
    }

    /// Sign-in throttling parameters; defaults to [`LoginThrottle::default`].
    fn login_throttle(&self) -> LoginThrottle {
        LoginThrottle::default()
    }

    /// Re-authentication window; defaults to 15 minutes.
    fn reauth_window(&self) -> ChronoDuration {
        ChronoDuration::minutes(15)
    }

    /// MFA time-step length in seconds; defaults to 30.
    fn mfa_step_seconds(&self) -> u64 {
        30
    }

    /// Accepted MFA clock skew, in steps; defaults to 1.
    fn mfa_skew_steps(&self) -> u32 {
        1
    }

    /// Optional per-identity override of this policy; `None` by default.
    fn scope_for(&self, _identity: &Identity) -> Option<SharedRecoveryPolicy> {
        None
    }
}

/// Thread-safe, shareable handle to any [`RecoveryPolicy`].
pub type SharedRecoveryPolicy = Arc<dyn RecoveryPolicy>;
/// Field-backed [`RecoveryPolicy`] with builder-style overrides.
///
/// Defaults: 1-hour token TTL, 5 requests per 15 minutes, 10 redemptions per
/// 5 minutes, strict mailer not required. All other trait knobs use the trait's
/// default implementations.
#[derive(Debug, Clone)]
pub struct DefaultRecoveryPolicy {
    pub reset_token_ttl: ChronoDuration,
    pub request_rate_limit: (u32, StdDuration),
    pub consume_rate_limit: (u32, StdDuration),
    pub strict_mailer_required: bool,
}

impl DefaultRecoveryPolicy {
    pub fn new() -> Self {
        const MINUTE_SECS: u64 = 60;
        Self {
            reset_token_ttl: ChronoDuration::hours(1),
            request_rate_limit: (5, StdDuration::from_secs(15 * MINUTE_SECS)),
            consume_rate_limit: (10, StdDuration::from_secs(5 * MINUTE_SECS)),
            strict_mailer_required: false,
        }
    }

    /// Replaces the reset-token lifetime.
    pub fn with_reset_token_ttl(self, ttl: ChronoDuration) -> Self {
        Self {
            reset_token_ttl: ttl,
            ..self
        }
    }

    /// Replaces the `(capacity, window)` limit for reset requests.
    pub fn with_request_rate_limit(self, capacity: u32, window: StdDuration) -> Self {
        Self {
            request_rate_limit: (capacity, window),
            ..self
        }
    }

    /// Replaces the `(capacity, window)` limit for reset redemptions.
    pub fn with_consume_rate_limit(self, capacity: u32, window: StdDuration) -> Self {
        Self {
            consume_rate_limit: (capacity, window),
            ..self
        }
    }

    /// Sets whether a working mailer is mandatory.
    pub fn with_strict_mailer_required(self, required: bool) -> Self {
        Self {
            strict_mailer_required: required,
            ..self
        }
    }
}

impl Default for DefaultRecoveryPolicy {
    fn default() -> Self {
        Self::new()
    }
}

impl RecoveryPolicy for DefaultRecoveryPolicy {
    fn reset_token_ttl(&self) -> ChronoDuration {
        self.reset_token_ttl
    }

    fn request_rate_limit(&self) -> (u32, StdDuration) {
        self.request_rate_limit
    }

    fn consume_rate_limit(&self) -> (u32, StdDuration) {
        self.consume_rate_limit
    }

    fn strict_mailer_required(&self) -> bool {
        self.strict_mailer_required
    }
}
/// Derives the public origin (`scheme://host[:port]`) of the current request.
///
/// `header` looks up a request header by lower-case name. Sources are tried in
/// order, and the first one that yields a safe scheme + host wins:
/// 1. the first element of an RFC 7239 `Forwarded` header (`proto` + `host`);
/// 2. the first comma-separated entries of `X-Forwarded-Proto` *and*
///    `X-Forwarded-Host` (both must be present);
/// 3. the `Host` header, assumed to be plain `http`.
///
/// The scheme is lower-cased; the host is returned as received. Returns `None`
/// when no source passes the `http`/`https` scheme and host-charset checks.
pub(crate) fn derive_public_site_url<F>(header: F) -> Option<String>
where
    F: Fn(&str) -> Option<String>,
{
    if let Some(url) = header("forwarded").and_then(|raw| parse_forwarded_first_hop(&raw)) {
        return Some(url);
    }

    // Both lookups always happen, mirroring the header-read order callers expect.
    let forwarded_proto =
        header("x-forwarded-proto").and_then(|raw| first_csv(&raw).map(str::to_owned));
    let forwarded_host =
        header("x-forwarded-host").and_then(|raw| first_csv(&raw).map(str::to_owned));
    if let (Some(scheme), Some(authority)) = (forwarded_proto, forwarded_host) {
        if is_safe_proto(&scheme) && is_safe_host(&authority) {
            return Some(format!("{}://{}", scheme.to_ascii_lowercase(), authority));
        }
    }

    header("host")
        .filter(|authority| is_safe_host(authority))
        .map(|authority| format!("http://{authority}"))
}

/// First comma-separated entry of `s`, trimmed; `None` if that entry is blank.
fn first_csv(s: &str) -> Option<&str> {
    s.split(',')
        .next()
        .map(str::trim)
        .filter(|entry| !entry.is_empty())
}

/// Accepts only `http` / `https`, case-insensitively.
fn is_safe_proto(p: &str) -> bool {
    ["http", "https"]
        .iter()
        .any(|scheme| p.eq_ignore_ascii_case(scheme))
}

/// Accepts 1..=253 bytes of ASCII alphanumerics plus `.`, `:`, `-`, `_`, `[`, `]`.
///
/// This is a character whitelist (rejecting whitespace, control characters, `/`,
/// `@`, …), not a full host-syntax validator.
fn is_safe_host(h: &str) -> bool {
    const MAX_HOST_LEN: usize = 253;
    let length_ok = !h.is_empty() && h.len() <= MAX_HOST_LEN;
    length_ok
        && h
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || ".:-_[]".contains(c))
}

/// Extracts `proto` and `host` from the first element of a `Forwarded` header.
///
/// Pairs are `;`-separated `key=value`; keys match case-insensitively, values may
/// be double-quoted, empty values are ignored, and a later duplicate key wins.
/// Returns `None` unless both parameters are present and pass the safety checks.
fn parse_forwarded_first_hop(value: &str) -> Option<String> {
    let first_element = value.split(',').next()?;
    let mut proto = None;
    let mut host = None;

    let pairs = first_element
        .split(';')
        .map(str::trim)
        .filter(|pair| !pair.is_empty())
        .filter_map(|pair| pair.split_once('='));
    for (raw_key, raw_val) in pairs {
        let val = raw_val.trim().trim_matches('"');
        if val.is_empty() {
            continue;
        }
        let key = raw_key.trim();
        if key.eq_ignore_ascii_case("proto") {
            proto = Some(val);
        } else if key.eq_ignore_ascii_case("host") {
            host = Some(val);
        }
    }

    match (proto, host) {
        (Some(scheme), Some(authority)) if is_safe_proto(scheme) && is_safe_host(authority) => {
            Some(format!("{}://{}", scheme.to_ascii_lowercase(), authority))
        }
        _ => None,
    }
}
/// Result of `issue_reset_token`.
///
/// Carries only a row id and a delivery status, never the plaintext token, so it
/// is safe to log with `Debug`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum IssueOutcome {
    /// A token row was inserted; `email_status` reports whether the link went out.
    Issued { token_id: i64, email_status: MailerEmailStatus },
    /// Empty, unknown, or inactive address; no token was created.
    UnknownOrInactive,
    /// The per-IP request limiter refused the attempt.
    RateLimited,
}

/// Whether the reset email was handed to the mailer successfully.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MailerEmailStatus {
    /// The mailer accepted the message.
    Sent,
    /// The mailer errored, or no reset link could be built.
    Failed,
}
/// Result of `consume_reset_token`.
///
/// Contains no token or password material, so it is safe to log with `Debug`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum ConsumeOutcome {
    /// Password changed for `user_id`; `revoked_session_count` sessions were invalidated.
    Consumed { user_id: i64, revoked_session_count: usize },
    /// Token unknown, expired, or already used.
    Invalid,
    /// The new password failed the active password policy; the token was not used.
    PolicyRejected(PasswordPolicyError),
    /// The per-IP consume limiter refused the attempt.
    RateLimited,
}
/// Handles a self-service "forgot password" request for `email`.
///
/// Steps:
/// 1. Per-IP rate limit via `request_limiter` (key from `extract_request_ip`).
/// 2. Normalise the address (trim + ASCII lower-case). Empty, unknown, and inactive
///    addresses all map to `IssueOutcome::UnknownOrInactive`.
/// 3. Generate a random token and persist only its storage hash, with
///    `mail_status = 'pending'` and an expiry of now + `reset_token_ttl`.
/// 4. Build the reset link from `RecoveryPolicy::public_site_url` and send it.
///    The row's `mail_status` then becomes `sent` or `failed`. A mailer error, or a
///    missing site URL, is logged and reported as `MailerEmailStatus::Failed`
///    rather than returned as an error.
/// 5. Write an audit entry holding the token fingerprint (never the plaintext).
///
/// # Errors
/// Propagates failures from the user lookup, token insert, status update, and
/// audit write.
///
/// NOTE(review): with the default `public_site_url`, the link origin comes from
/// request headers (see `derive_public_site_url`).
/// NOTE(review): unknown/inactive addresses return before any token or mail work,
/// so their response time may differ from the known-address path.
pub(crate) async fn issue_reset_token(
    db: &Db,
    admin: &Admin,
    request_limiter: &RateLimiter,
    request: &Request,
    email: &str,
    correlation_id: Option<&str>,
) -> Result<IssueOutcome> {
    const LOG_TARGET: &str = "rustio_admin::recovery::issue";

    let ip = extract_request_ip(request);
    if !request_limiter.allow(&ip) {
        log::info!(
            target: LOG_TARGET,
            "rate-limit exhausted ip={} correlation_id={:?}",
            ip,
            correlation_id,
        );
        return Ok(IssueOutcome::RateLimited);
    }

    let normalized_email = email.trim().to_ascii_lowercase();
    if normalized_email.is_empty() {
        log::info!(
            target: LOG_TARGET,
            "empty-email submission ip={} correlation_id={:?}",
            ip,
            correlation_id,
        );
        return Ok(IssueOutcome::UnknownOrInactive);
    }

    let user = match find_user_by_email(db, &normalized_email).await? {
        None => {
            log::info!(
                target: LOG_TARGET,
                "unknown-email submission ip={} correlation_id={:?}",
                ip,
                correlation_id,
            );
            return Ok(IssueOutcome::UnknownOrInactive);
        }
        Some(found) if !found.is_active => {
            log::info!(
                target: LOG_TARGET,
                "inactive-user submission user_id={} ip={} correlation_id={:?}",
                found.id,
                ip,
                correlation_id,
            );
            return Ok(IssueOutcome::UnknownOrInactive);
        }
        Some(found) => found,
    };

    // Only the hash is stored; the plaintext token exists solely in the email.
    let token = random_token();
    let token_hash = hash_token_for_storage(&token);
    let policy = admin.active_recovery_policy();
    let ttl = policy.reset_token_ttl();
    let expires_at = chrono::Utc::now() + ttl;
    let user_agent = request.header("user-agent").map(|s| s.to_string());

    let token_id: i64 = sqlx::query_scalar(
        "INSERT INTO rustio_password_reset_tokens
(user_id, token_hash, requested_ip, requested_user_agent,
expires_at, mail_status, correlation_id)
VALUES ($1, $2, $3, $4, $5, 'pending', $6)
RETURNING id",
    )
    .bind(user.id)
    .bind(&token_hash)
    .bind(&ip)
    .bind(user_agent.as_deref())
    .bind(expires_at)
    .bind(correlation_id)
    .fetch_one(db.pool())
    .await?;

    let mail_status = match policy.public_site_url(request) {
        None => {
            log::error!(
                target: LOG_TARGET,
                "public_site_url derivation returned None — reset link cannot be built. \
                 user_id={} fingerprint={} correlation_id={:?}",
                user.id,
                redact_token(&token),
                correlation_id,
            );
            MailerEmailStatus::Failed
        }
        Some(site_url) => {
            let reset_link = format!(
                "{}/admin/reset-password/{}",
                site_url.trim_end_matches('/'),
                token,
            );
            let when = chrono::Utc::now();
            let body = format!(
                "We received a request to sign you back in to {site_header}.\n\n\
                 Click the link below to set a new password:\n\n\
                 {reset_link}\n\n\
                 The link expires {ttl_human}. If you didn't request this, you can \
                 safely ignore this email.\n",
                site_header = admin.branding().site_header,
                ttl_human = humanize_ttl(ttl),
            );
            let mail = Mail::framework_envelope(
                user.email.clone(),
                format!("{} — sign-in link", admin.branding().site_header),
                body,
                &admin.branding().site_header,
                Some(&ip),
                user_agent.as_deref(),
                when,
            );
            match admin.active_mailer().send(mail).await {
                Ok(()) => MailerEmailStatus::Sent,
                Err(e) => {
                    log::error!(
                        target: LOG_TARGET,
                        "mailer send failed user_id={} fingerprint={} correlation_id={:?}: {}",
                        user.id,
                        redact_token(&token),
                        correlation_id,
                        e,
                    );
                    MailerEmailStatus::Failed
                }
            }
        }
    };

    // Same strings the table's CHECK constraint accepts.
    let mail_status_label = match mail_status {
        MailerEmailStatus::Sent => "sent",
        MailerEmailStatus::Failed => "failed",
    };
    set_token_mail_status(db, token_id, mail_status_label).await?;

    let metadata = serde_json::json!({
        "token_fingerprint": redact_token(&token),
        "email_send_status": mail_status_label,
        "requested_ip": ip,
        "requested_user_agent": user_agent,
        "expires_at": expires_at.to_rfc3339(),
    });
    let mut entry = LogEntry::new(user.id, ActionType::Update, "user", user.id)
        .with_event(AuditEvent::PasswordResetSelfRequest);
    entry.correlation_id = correlation_id;
    entry.ip_address = Some(&ip);
    entry.metadata = Some(metadata);
    entry.summary = format!("password reset requested; mail {mail_status_label}");
    audit_record(db, entry).await?;

    Ok(IssueOutcome::Issued {
        token_id,
        email_status: mail_status,
    })
}
/// Redeems a password-reset token and sets `new_password` for its owner.
///
/// Steps:
/// 1. Per-IP rate limit via `consume_limiter`.
/// 2. Validate `new_password` against the active password policy *before* touching
///    the token, so a rejected password leaves the link usable.
/// 3. Mark the token consumed with one conditional `UPDATE … RETURNING`. The token
///    is accepted only if it is unconsumed and unexpired, so only one caller can
///    redeem a given token.
/// 4. Store the new password.
/// 5. Retire every other still-unconsumed reset token for the same user. Without
///    this step, an older reset email (for example an intercepted one) would still
///    be able to overwrite the password that was just set.
/// 6. Invalidate all of the user's sessions and write an audit entry.
///
/// # Errors
/// Propagates failures from the database, password storage, session invalidation,
/// and audit write.
///
/// NOTE(review): steps 3–6 are separate statements, not one transaction. If
/// `set_password` fails, the redeemed token is already burnt and the user must
/// request a new link. The owning account's active flag is also not re-checked here.
pub(crate) async fn consume_reset_token(
    db: &Db,
    admin: &Admin,
    consume_limiter: &RateLimiter,
    request: &Request,
    token: &str,
    new_password: &str,
    correlation_id: Option<&str>,
) -> Result<ConsumeOutcome> {
    let ip = extract_request_ip(request);
    if !consume_limiter.allow(&ip) {
        log::info!(
            target: "rustio_admin::recovery::consume",
            "rate-limit exhausted ip={} correlation_id={:?}",
            ip,
            correlation_id,
        );
        return Ok(ConsumeOutcome::RateLimited);
    }

    // Checked before consumption so a policy rejection does not burn the link.
    if let Err(e) = admin.active_password_policy().validate(new_password) {
        return Ok(ConsumeOutcome::PolicyRejected(e));
    }

    let token_hash = hash_token_for_storage(token);
    let user_id: Option<i64> = sqlx::query_scalar(
        "UPDATE rustio_password_reset_tokens
SET consumed_at = NOW()
WHERE token_hash = $1
AND consumed_at IS NULL
AND expires_at > NOW()
RETURNING user_id",
    )
    .bind(&token_hash)
    .fetch_optional(db.pool())
    .await?;
    let user_id = match user_id {
        Some(uid) => uid,
        None => {
            log::info!(
                target: "rustio_admin::recovery::consume",
                "consume on invalid/expired/consumed token ip={} fingerprint={} correlation_id={:?}",
                ip,
                redact_token(token),
                correlation_id,
            );
            return Ok(ConsumeOutcome::Invalid);
        }
    };

    set_password(db, user_id, new_password).await?;

    // The password just changed, so every other outstanding link for this account
    // must stop working. This runs after `set_password`, so a storage failure there
    // still leaves the sibling links available for a retry.
    let revoked_token_count = sqlx::query(
        "UPDATE rustio_password_reset_tokens \
         SET consumed_at = NOW() \
         WHERE user_id = $1 \
         AND consumed_at IS NULL",
    )
    .bind(user_id)
    .execute(db.pool())
    .await?
    .rows_affected();

    let outcome = invalidate_sessions(
        db,
        SessionTarget::User { user_id },
        SessionInvalidationReason::PasswordReset,
    )
    .await?;
    let revoked_session_count = outcome.revoked_session_ids.len();

    let user_agent_owned = request.header("user-agent").map(|s| s.to_string());
    let metadata = serde_json::json!({
        "token_fingerprint": redact_token(token),
        "invalidated_session_count": revoked_session_count,
        "revoked_reset_token_count": revoked_token_count,
        "ip": ip,
        "user_agent": user_agent_owned,
    });
    let mut entry = LogEntry::new(user_id, ActionType::Update, "user", user_id)
        .with_event(AuditEvent::PasswordResetSelfConsume);
    entry.correlation_id = correlation_id;
    entry.ip_address = Some(&ip);
    entry.metadata = Some(metadata);
    entry.summary =
        format!("password reset self-consumed; {revoked_session_count} session(s) revoked");
    audit_record(db, entry).await?;

    Ok(ConsumeOutcome::Consumed {
        user_id,
        revoked_session_count,
    })
}
/// Reports whether `token` is currently redeemable, without consuming it.
///
/// "Redeemable" means a stored row matches the token's storage hash, has not been
/// consumed, and expires in the future (as judged by the database clock).
pub(crate) async fn check_reset_token_valid(db: &Db, token: &str) -> Result<bool> {
    let token_hash = hash_token_for_storage(token);
    let matching_row = sqlx::query_scalar::<_, i64>(
        "SELECT id FROM rustio_password_reset_tokens
WHERE token_hash = $1
AND consumed_at IS NULL
AND expires_at > NOW()
LIMIT 1",
    )
    .bind(&token_hash)
    .fetch_optional(db.pool())
    .await?;
    Ok(matching_row.is_some())
}
/// Days an expired reset-token row is kept before `purge_expired_reset_tokens` deletes it.
const RESET_TOKEN_RETENTION_DAYS: i64 = 7;

/// Deletes reset-token rows whose expiry is older than the retention window.
///
/// Consumed and unconsumed rows are purged alike. Returns the number of rows
/// deleted. The interval is interpolated from a compile-time constant, so no
/// untrusted input reaches the SQL text.
pub(crate) async fn purge_expired_reset_tokens(db: &Db) -> Result<u64> {
    let sql = format!(
        "DELETE FROM rustio_password_reset_tokens WHERE expires_at < NOW() - INTERVAL '{} days'",
        RESET_TOKEN_RETENTION_DAYS
    );
    let deleted = sqlx::query(&sql).execute(db.pool()).await?;
    Ok(deleted.rows_affected())
}
/// Records the delivery state of a token's reset email.
///
/// `status` must be one of the values allowed by the table's CHECK constraint
/// (`pending`, `sent`, `failed`); any other value makes the UPDATE fail. An
/// unknown `token_id` matches no rows and is not treated as an error.
async fn set_token_mail_status(db: &Db, token_id: i64, status: &str) -> Result<()> {
    let statement = sqlx::query(
        "UPDATE rustio_password_reset_tokens
SET mail_status = $1
WHERE id = $2",
    )
    .bind(status)
    .bind(token_id);
    statement.execute(db.pool()).await?;
    Ok(())
}
/// Client identifier used as the rate-limit key and in logs/audit rows.
///
/// Returns the left-most, trimmed entry of `X-Forwarded-For`, or `"anon"` when
/// the header is missing or that entry is blank.
///
/// NOTE(review): the left-most `X-Forwarded-For` entry is whatever the client
/// sent unless a trusted proxy overwrites the header. A caller can therefore rotate
/// it to evade the per-IP limiters, and all header-less requests share the single
/// `"anon"` bucket. Confirm the proxy setup, or key on the peer address instead.
fn extract_request_ip(request: &Request) -> String {
    let leftmost = request
        .header("x-forwarded-for")
        .and_then(|list| list.split(',').next())
        .map(|entry| entry.trim());
    match leftmost {
        Some(addr) if !addr.is_empty() => addr.to_string(),
        _ => String::from("anon"),
    }
}
/// Renders a token lifetime for the reset email, e.g. `"in 1 hour"`.
///
/// Uses the largest whole unit among hours, minutes, and seconds, truncating the
/// remainder (90 minutes becomes `"in 1 hour"`), and pluralises when the count is
/// not 1. Anything under one whole second, including zero or negative values,
/// becomes `"very soon"`.
fn humanize_ttl(ttl: ChronoDuration) -> String {
    let secs = ttl.num_seconds();
    if secs <= 0 {
        return String::from("very soon");
    }
    let (count, unit) = match (ttl.num_hours(), ttl.num_minutes()) {
        (hours, _) if hours >= 1 => (hours, "hour"),
        (_, minutes) if minutes >= 1 => (minutes, "minute"),
        _ => (secs, "second"),
    };
    match count {
        1 => format!("in 1 {unit}"),
        n => format!("in {n} {unit}s"),
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the pure parts of the recovery module: password policy,
    //! recovery-policy defaults, public-URL derivation, TTL rendering, and the
    //! no-plaintext guarantees of the outcome types.
    use super::*;

    #[test]
    fn default_policy_floor_is_ten() {
        assert_eq!(DefaultPasswordPolicy::new().min_length(), 10);
        assert_eq!(DefaultPasswordPolicy::default().min_length(), 10);
    }

    #[test]
    fn default_policy_accepts_password_at_floor() {
        let p = DefaultPasswordPolicy::new();
        assert!(p.validate("aaaaaaaaaa").is_ok());
        assert!(p.validate("correct horse battery staple").is_ok());
    }

    #[test]
    fn default_policy_rejects_short_password() {
        let p = DefaultPasswordPolicy::new();
        let err = p.validate("nine_char").unwrap_err();
        assert_eq!(err, PasswordPolicyError::TooShort { min: 10, actual: 9 });
    }

    #[test]
    fn default_policy_rejects_empty_password() {
        let p = DefaultPasswordPolicy::new();
        let err = p.validate("").unwrap_err();
        assert_eq!(err, PasswordPolicyError::TooShort { min: 10, actual: 0 });
    }

    #[test]
    fn default_policy_with_min_len_overrides_floor() {
        let p = DefaultPasswordPolicy::with_min_len(16);
        assert_eq!(p.min_length(), 16);
        assert!(p.validate("fifteen_chars__").is_err());
        assert!(p.validate("sixteen_chars___").is_ok());
    }

    #[test]
    fn default_policy_counts_chars_not_bytes() {
        let p = DefaultPasswordPolicy::new();
        let pw = "пароль1234";
        assert_eq!(pw.chars().count(), 10);
        assert!(pw.len() > 10);
        assert!(p.validate(pw).is_ok());
        let pw = "пароль123";
        let err = p.validate(pw).unwrap_err();
        assert_eq!(err, PasswordPolicyError::TooShort { min: 10, actual: 9 });
    }

    #[test]
    fn error_renderings_do_not_leak_plaintext() {
        let p = DefaultPasswordPolicy::new();
        let plaintext = "Pwn4Ge#xy";
        let err = p.validate(plaintext).unwrap_err();
        let display = format!("{err}");
        let debug = format!("{err:?}");
        assert!(
            !display.contains(plaintext),
            "Display leaked plaintext: {display}"
        );
        assert!(!debug.contains(plaintext), "Debug leaked plaintext: {debug}");
    }

    #[test]
    fn too_short_error_renders_fixed_message() {
        let err = PasswordPolicyError::TooShort { min: 10, actual: 9 };
        assert_eq!(
            err.to_string(),
            "This password is too short. It must contain at least 10 characters (you entered 9)."
        );
    }

    #[test]
    fn custom_error_renders_message_verbatim() {
        let err = PasswordPolicyError::Custom("breached password rejected".into());
        assert_eq!(format!("{err}"), "breached password rejected");
    }

    #[test]
    fn shared_password_policy_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<SharedPasswordPolicy>();
    }

    #[test]
    fn default_recovery_policy_ttl_is_one_hour() {
        let p = DefaultRecoveryPolicy::new();
        assert_eq!(p.reset_token_ttl(), ChronoDuration::hours(1));
    }

    #[test]
    fn default_recovery_policy_request_rate_limit_is_five_per_fifteen_min() {
        let p = DefaultRecoveryPolicy::new();
        assert_eq!(p.request_rate_limit(), (5, StdDuration::from_secs(15 * 60)));
    }

    #[test]
    fn default_recovery_policy_consume_rate_limit_is_ten_per_five_min() {
        let p = DefaultRecoveryPolicy::new();
        assert_eq!(p.consume_rate_limit(), (10, StdDuration::from_secs(5 * 60)));
    }

    #[test]
    fn default_recovery_policy_strict_mailer_required_is_false() {
        let p = DefaultRecoveryPolicy::new();
        assert!(!p.strict_mailer_required());
    }

    #[test]
    fn default_recovery_policy_with_overrides_apply_field_by_field() {
        let p = DefaultRecoveryPolicy::new()
            .with_reset_token_ttl(ChronoDuration::hours(2))
            .with_request_rate_limit(3, StdDuration::from_secs(60))
            .with_consume_rate_limit(20, StdDuration::from_secs(30))
            .with_strict_mailer_required(true);
        assert_eq!(p.reset_token_ttl(), ChronoDuration::hours(2));
        assert_eq!(p.request_rate_limit(), (3, StdDuration::from_secs(60)));
        assert_eq!(p.consume_rate_limit(), (20, StdDuration::from_secs(30)));
        assert!(p.strict_mailer_required());
    }

    #[test]
    fn shared_recovery_policy_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<SharedRecoveryPolicy>();
    }

    #[test]
    fn login_throttle_default_is_five_ten_fifteen() {
        let t = LoginThrottle::default();
        assert_eq!(t.max_attempts, 5);
        assert_eq!(t.window_minutes, 10);
        assert_eq!(t.lock_minutes, 15);
        assert_eq!(t, LoginThrottle::DEFAULT);
    }

    #[test]
    fn default_recovery_policy_login_throttle_is_default() {
        let p = DefaultRecoveryPolicy::new();
        assert_eq!(p.login_throttle(), LoginThrottle::DEFAULT);
    }

    #[test]
    fn default_recovery_policy_reauth_window_is_fifteen_minutes() {
        let p = DefaultRecoveryPolicy::new();
        assert_eq!(p.reauth_window(), ChronoDuration::minutes(15));
    }

    #[test]
    fn default_recovery_policy_mfa_step_seconds_is_thirty() {
        let p = DefaultRecoveryPolicy::new();
        assert_eq!(p.mfa_step_seconds(), 30);
    }

    #[test]
    fn default_recovery_policy_mfa_skew_steps_is_one() {
        let p = DefaultRecoveryPolicy::new();
        assert_eq!(p.mfa_skew_steps(), 1);
    }

    #[test]
    fn default_recovery_policy_scope_for_returns_none() {
        use crate::auth::Role;
        let identity = Identity {
            user_id: 42,
            email: "test@example.com".into(),
            role: Role::User,
            is_active: true,
            is_demo: false,
            demo_label: None,
            must_change_password: false,
            mfa_enabled: false,
            trust_level: crate::auth::SessionTrust::Authenticated,
        };
        let p = DefaultRecoveryPolicy::new();
        assert!(p.scope_for(&identity).is_none());
    }

    #[test]
    fn login_throttle_is_send_sync_copy() {
        fn assert_send_sync_copy<T: Send + Sync + Copy>() {}
        assert_send_sync_copy::<LoginThrottle>();
    }

    /// Case-insensitive header lookup over a static table, mimicking `Request::header`.
    fn header_lookup(
        pairs: &'static [(&'static str, &'static str)],
    ) -> impl Fn(&str) -> Option<String> + 'static {
        move |name| {
            pairs
                .iter()
                .find(|(k, _)| k.eq_ignore_ascii_case(name))
                .map(|(_, v)| (*v).to_string())
        }
    }

    #[test]
    fn site_url_prefers_rfc7239_forwarded_first_hop() {
        let h = header_lookup(&[
            ("forwarded", "for=1.2.3.4;proto=https;host=admin.example.com"),
            ("x-forwarded-proto", "http"),
            ("x-forwarded-host", "wrong.example.com"),
            ("host", "internal.local"),
        ]);
        assert_eq!(
            derive_public_site_url(&h),
            Some("https://admin.example.com".to_string())
        );
    }

    #[test]
    fn site_url_uses_only_first_forwarded_element() {
        let h = header_lookup(&[(
            "forwarded",
            "proto=https;host=a.example.com, proto=http;host=b.example.com",
        )]);
        assert_eq!(
            derive_public_site_url(&h),
            Some("https://a.example.com".to_string())
        );
    }

    #[test]
    fn site_url_falls_through_to_x_forwarded_pair() {
        let h = header_lookup(&[
            ("x-forwarded-proto", "https"),
            ("x-forwarded-host", "admin.example.com"),
            ("host", "internal.local"),
        ]);
        assert_eq!(
            derive_public_site_url(&h),
            Some("https://admin.example.com".to_string())
        );
    }

    #[test]
    fn site_url_x_forwarded_takes_first_csv_entry() {
        let h = header_lookup(&[
            ("x-forwarded-proto", "https, http"),
            ("x-forwarded-host", "admin.example.com, internal.local"),
        ]);
        assert_eq!(
            derive_public_site_url(&h),
            Some("https://admin.example.com".to_string())
        );
    }

    #[test]
    fn site_url_ignores_x_forwarded_host_without_proto() {
        let h = header_lookup(&[
            ("x-forwarded-host", "evil.example.com"),
            ("host", "admin.example.com"),
        ]);
        assert_eq!(
            derive_public_site_url(&h),
            Some("http://admin.example.com".to_string())
        );
    }

    #[test]
    fn site_url_falls_back_to_host_header_with_http() {
        let h = header_lookup(&[("host", "admin.example.com")]);
        assert_eq!(
            derive_public_site_url(&h),
            Some("http://admin.example.com".to_string())
        );
    }

    #[test]
    fn site_url_returns_none_when_no_headers_resolve() {
        let h = header_lookup(&[]);
        assert_eq!(derive_public_site_url(&h), None);
    }

    #[test]
    fn site_url_rejects_non_http_proto() {
        let h = header_lookup(&[
            ("forwarded", "for=1.2.3.4;proto=javascript;host=evil.example.com"),
            ("host", "fallback.example.com"),
        ]);
        assert_eq!(
            derive_public_site_url(&h),
            Some("http://fallback.example.com".to_string())
        );
    }

    #[test]
    fn site_url_rejects_host_with_whitespace_or_control() {
        let h = header_lookup(&[("host", "example.com\r\nX-Injected: yes")]);
        assert_eq!(derive_public_site_url(&h), None);
    }

    #[test]
    fn site_url_handles_quoted_forwarded_values() {
        let h = header_lookup(&[(
            "forwarded",
            "for=\"_obfuscated\";proto=\"https\";host=\"admin.example.com\"",
        )]);
        assert_eq!(
            derive_public_site_url(&h),
            Some("https://admin.example.com".to_string())
        );
    }

    #[test]
    fn site_url_handles_ipv6_bracketed_host() {
        let h = header_lookup(&[
            ("x-forwarded-proto", "https"),
            ("x-forwarded-host", "[2001:db8::1]:8443"),
        ]);
        assert_eq!(
            derive_public_site_url(&h),
            Some("https://[2001:db8::1]:8443".to_string())
        );
    }

    #[test]
    fn first_csv_rejects_blank_leading_entry() {
        assert_eq!(first_csv(" a , b"), Some("a"));
        assert_eq!(first_csv(" , b"), None);
        assert_eq!(first_csv(""), None);
    }

    #[test]
    fn is_safe_host_enforces_length_and_charset() {
        assert!(is_safe_host(&"a".repeat(253)));
        assert!(!is_safe_host(&"a".repeat(254)));
        assert!(!is_safe_host(""));
        assert!(!is_safe_host("user@evil.example.com"));
        assert!(is_safe_host("[2001:db8::1]:8443"));
    }

    #[test]
    fn humanize_ttl_one_hour_default() {
        assert_eq!(humanize_ttl(ChronoDuration::hours(1)), "in 1 hour");
    }

    #[test]
    fn humanize_ttl_two_hours_pluralises() {
        assert_eq!(humanize_ttl(ChronoDuration::hours(2)), "in 2 hours");
    }

    #[test]
    fn humanize_ttl_minutes() {
        assert_eq!(humanize_ttl(ChronoDuration::minutes(30)), "in 30 minutes");
        assert_eq!(humanize_ttl(ChronoDuration::minutes(1)), "in 1 minute");
    }

    #[test]
    fn humanize_ttl_seconds_for_short_windows() {
        assert_eq!(humanize_ttl(ChronoDuration::seconds(45)), "in 45 seconds");
        assert_eq!(humanize_ttl(ChronoDuration::seconds(1)), "in 1 second");
    }

    #[test]
    fn humanize_ttl_truncates_to_largest_whole_unit() {
        assert_eq!(humanize_ttl(ChronoDuration::minutes(90)), "in 1 hour");
        assert_eq!(humanize_ttl(ChronoDuration::seconds(119)), "in 1 minute");
        assert_eq!(humanize_ttl(ChronoDuration::hours(48)), "in 48 hours");
    }

    #[test]
    fn reset_token_retention_window_is_seven_days() {
        assert_eq!(RESET_TOKEN_RETENTION_DAYS, 7);
    }

    #[test]
    fn purge_query_includes_retention_window_and_table() {
        // NOTE(review): this rebuilds the query locally rather than reading it from
        // `purge_expired_reset_tokens`; keep the two texts in sync.
        let query = format!(
            "DELETE FROM rustio_password_reset_tokens \
             WHERE expires_at < NOW() - INTERVAL '{RESET_TOKEN_RETENTION_DAYS} days'"
        );
        assert!(
            query.contains("rustio_password_reset_tokens"),
            "purge must target the recovery table"
        );
        assert!(
            query.contains("INTERVAL '7 days'"),
            "purge must use the locked 7-day retention window"
        );
        assert!(
            !query.contains("consumed_at"),
            "purge must apply to BOTH consumed and unconsumed expired rows; \
             a `consumed_at` filter would leak old consumed rows indefinitely"
        );
        assert!(
            query.starts_with("DELETE FROM"),
            "purge must be a DELETE statement"
        );
    }

    #[test]
    fn humanize_ttl_zero_or_negative_returns_safe_string() {
        assert_eq!(humanize_ttl(ChronoDuration::zero()), "very soon");
        assert_eq!(humanize_ttl(ChronoDuration::seconds(-30)), "very soon");
    }

    #[test]
    fn issue_outcome_debug_never_carries_plaintext_token() {
        let synthetic = "Pwn4Ge_ZZ_token_plaintext_1234567890";
        for outcome in [
            IssueOutcome::Issued {
                token_id: 42,
                email_status: MailerEmailStatus::Sent,
            },
            IssueOutcome::Issued {
                token_id: 7,
                email_status: MailerEmailStatus::Failed,
            },
            IssueOutcome::UnknownOrInactive,
            IssueOutcome::RateLimited,
        ] {
            let debug = format!("{outcome:?}");
            assert!(
                !debug.contains(synthetic),
                "IssueOutcome Debug leaked plaintext: {debug}",
            );
        }
    }

    #[test]
    fn consume_outcome_debug_never_carries_plaintext_token() {
        let synthetic = "Pwn4Ge_ZZ_token_plaintext_1234567890";
        for outcome in [
            ConsumeOutcome::Consumed {
                user_id: 1,
                revoked_session_count: 3,
            },
            ConsumeOutcome::Invalid,
            ConsumeOutcome::PolicyRejected(PasswordPolicyError::TooShort { min: 10, actual: 4 }),
            ConsumeOutcome::PolicyRejected(PasswordPolicyError::Custom("stub rejected".into())),
            ConsumeOutcome::RateLimited,
        ] {
            let debug = format!("{outcome:?}");
            assert!(
                !debug.contains(synthetic),
                "ConsumeOutcome Debug leaked plaintext: {debug}",
            );
        }
    }

    #[test]
    fn mailer_email_status_round_trip_strings() {
        assert_eq!(format!("{:?}", MailerEmailStatus::Sent), "Sent");
        assert_eq!(format!("{:?}", MailerEmailStatus::Failed), "Failed");
    }

    #[test]
    fn malformed_forwarded_inputs_never_panic() {
        // None of these yields a first hop with both a safe proto and a safe host,
        // and no X-Forwarded-* pair is supplied, so every case must fall back to
        // the Host header. An exact match catches a malformed input that
        // accidentally produces some other http(s) URL.
        for input in &[
            "",
            "garbage",
            "for=",
            "proto=;host=",
            "proto=javascript:alert(1);host=evil",
            "host=example com",
            "proto=https;host=",
            ";;;",
            ",,,",
            "proto=https",
            "host=example.com",
            "for=\"unterminated",
            "=value",
            "key=",
            "key==value=",
        ] {
            let value = (*input).to_string();
            let h = move |name: &str| match name {
                "forwarded" => Some(value.clone()),
                "host" => Some("fallback.example.com".to_string()),
                _ => None,
            };
            let result = derive_public_site_url(h);
            assert_eq!(
                result.as_deref(),
                Some("http://fallback.example.com"),
                "input {input:?} produced unexpected url {result:?}"
            );
        }
    }
}