use crate::AuthUser;
use crate::mailer::{OutgoingMail, active_mailer};
use base64::Engine;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use chrono::{DateTime, Utc};
use rand::{Rng, RngCore};
use serde::{Deserialize, Serialize};
use std::time::Duration;
use umbral::db::transaction;
use umbral::orm::{F, ForeignKey};
use umbral::templates::{context, render};
/// `purpose` value for challenges created by [`start_email_verification`] and
/// redeemed by [`verify_email`].
pub const PURPOSE_EMAIL_VERIFY: &str = "email_verify";
/// `purpose` value for challenges created by [`start_password_reset`] and
/// redeemed by [`reset_password`].
pub const PURPOSE_PASSWORD_RESET: &str = "password_reset";
/// A persisted, expiring secret (emailed code or reset token) bound to one user.
///
/// Only a digest of the secret is stored in `secret_hash`; the plaintext is sent
/// to the user and never written to this table.
#[derive(Debug, Clone, sqlx::FromRow, Serialize, Deserialize, umbral::orm::Model)]
pub struct AuthChallenge {
    /// Primary key. `AuthChallenge::issue` passes `0` on insert; presumably the
    /// database assigns the real id (TODO confirm against the ORM).
    pub id: i64,
    /// Owning user; declared with `on_delete = "cascade"`.
    #[umbral(on_delete = "cascade")]
    pub user_id: ForeignKey<AuthUser>,
    /// One of the `PURPOSE_*` constants above.
    #[umbral(max_length = 32)]
    pub purpose: String,
    /// Digest of the plaintext secret, as produced by `hash_secret`.
    #[umbral(max_length = 64)]
    pub secret_hash: String,
    /// Instant after which the challenge is no longer live (see `is_live`).
    pub expires_at: DateTime<Utc>,
    /// Attempt counter, incremented via `AuthChallenge::bump_attempts`.
    pub attempts: i32,
    /// Set once the challenge has been redeemed or burned; `None` while usable.
    pub used_at: Option<DateTime<Utc>>,
    /// Issue time; used to pick the newest challenge for a user.
    pub created_at: DateTime<Utc>,
}
/// Returns a uniformly random six-digit numeric code, zero-padded (e.g. `"004821"`).
///
/// The value is drawn from the operating-system CSPRNG (`OsRng`).
pub(crate) fn generate_code() -> String {
    const UPPER_EXCLUSIVE: u32 = 1_000_000;
    let value = rand::rngs::OsRng.gen_range(0..UPPER_EXCLUSIVE);
    format!("{value:06}")
}
/// Returns a fresh password-reset token.
///
/// The token is the literal prefix `umbral_` followed by 32 bytes from the
/// operating-system CSPRNG, encoded as unpadded URL-safe base64 (43 chars),
/// for 50 chars in total.
pub(crate) fn generate_reset_token() -> String {
    let mut entropy = [0u8; 32];
    rand::rngs::OsRng.fill_bytes(&mut entropy);
    let encoded = URL_SAFE_NO_PAD.encode(entropy);
    let mut token = String::from("umbral_");
    token.push_str(&encoded);
    token
}
pub(crate) fn hash_secret(plaintext: &str) -> String {
crate::token::digest_token(plaintext)
}
impl AuthChallenge {
pub async fn issue(
user_id: i64,
purpose: &str,
plaintext: &str,
ttl: Duration,
) -> Result<AuthChallenge, crate::AuthError> {
let now = Utc::now();
let expires_at =
now + chrono::Duration::from_std(ttl).unwrap_or_else(|_| chrono::Duration::minutes(15));
let row = AuthChallenge::objects()
.create(AuthChallenge {
id: 0, user_id: ForeignKey::new(user_id),
purpose: purpose.to_string(),
secret_hash: hash_secret(plaintext),
expires_at,
attempts: 0,
used_at: None,
created_at: now,
})
.await?;
Ok(row)
}
pub fn is_live(&self) -> bool {
self.used_at.is_none() && self.expires_at > Utc::now()
}
pub async fn find_active_for_user(
user_id: i64,
purpose: &str,
) -> Result<Option<AuthChallenge>, crate::AuthError> {
let row = AuthChallenge::objects()
.filter(
auth_challenge::USER_ID.eq(user_id)
& auth_challenge::PURPOSE.eq(purpose)
& auth_challenge::USED_AT.is_null(),
)
.order_by(auth_challenge::CREATED_AT.desc())
.first()
.await?;
Ok(row.filter(|c| c.is_live()))
}
pub async fn find_active_by_secret(
plaintext: &str,
purpose: &str,
) -> Result<Option<AuthChallenge>, crate::AuthError> {
let row = AuthChallenge::objects()
.filter(
auth_challenge::SECRET_HASH.eq(hash_secret(plaintext))
& auth_challenge::PURPOSE.eq(purpose)
& auth_challenge::USED_AT.is_null(),
)
.first()
.await?;
Ok(row.filter(|c| c.is_live()))
}
pub async fn mark_used(&self) -> Result<(), crate::AuthError> {
let mut delta = serde_json::Map::new();
delta.insert("used_at".to_string(), serde_json::json!(Utc::now()));
AuthChallenge::objects()
.filter(auth_challenge::ID.eq(self.id))
.update_values(delta)
.await?;
Ok(())
}
pub async fn bump_attempts(&self) -> Result<(), crate::AuthError> {
AuthChallenge::objects()
.filter(auth_challenge::ID.eq(self.id))
.update_expr("attempts", F::col("attempts").add(1))
.await?;
Ok(())
}
}
/// Lifetime of an emailed verification code.
const CODE_TTL: Duration = Duration::from_secs(15 * 60);
/// Guesses allowed against one verification code before it is burned.
const MAX_CODE_ATTEMPTS: i32 = 5;
/// Issues a fresh six-digit email-verification code for `user` and emails it.
///
/// The code's digest is stored as a [`PURPOSE_EMAIL_VERIFY`] challenge that expires
/// after `CODE_TTL`; the plaintext only travels in the outgoing email.
///
/// # Errors
/// Database errors from issuing the challenge, `AuthError::Template` when either
/// template fails to render, and `AuthError::Mail` when sending fails.
pub async fn start_email_verification(user: &AuthUser) -> Result<(), crate::AuthError> {
    let code = generate_code();
    AuthChallenge::issue(user.id, PURPOSE_EMAIL_VERIFY, &code, CODE_TTL).await?;

    let ctx = context! { code => code.clone(), username => user.username.clone() };
    let render_with = |template: &'static str| {
        render(template, &ctx).map_err(|e| crate::AuthError::Template(e.to_string()))
    };
    let html = render_with("auth/email/verify_code.html")?;
    let text = render_with("auth/email/verify_code.txt")?;

    let mail = OutgoingMail {
        to: user.email.clone(),
        username: user.username.clone(),
        kind: crate::mailer::MailKind::EmailVerification { code },
        subject: "Verify your email".into(),
        html,
        text,
    };
    active_mailer()
        .send(mail)
        .await
        .map_err(|e| crate::AuthError::Mail(e.to_string()))?;
    Ok(())
}
/// Confirms ownership of `email` by checking `code` against the user's newest live
/// [`PURPOSE_EMAIL_VERIFY`] challenge. On success it consumes the challenge and stamps
/// the user's `email_verified_at` in one transaction.
///
/// At most `MAX_CODE_ATTEMPTS` guesses are accepted per challenge. The attempt counter
/// is incremented atomically *before* the code is compared and the row is then re-read.
/// This means concurrent requests cannot all observe a stale low count and together
/// exceed the limit.
///
/// # Errors
/// `AuthError::InvalidChallenge` for an unknown email, a missing/expired challenge,
/// an exhausted challenge, or a wrong code; database errors are propagated.
pub async fn verify_email(email: &str, code: &str) -> Result<(), crate::AuthError> {
    let Some(user) = AuthUser::objects()
        .filter(crate::auth_user::EMAIL.eq(email.to_string()))
        .first()
        .await?
    else {
        return Err(crate::AuthError::InvalidChallenge);
    };
    let Some(challenge) =
        AuthChallenge::find_active_for_user(user.id, PURPOSE_EMAIL_VERIFY).await?
    else {
        return Err(crate::AuthError::InvalidChallenge);
    };
    // Already exhausted: burn it without spending another counter write.
    if challenge.attempts >= MAX_CODE_ATTEMPTS {
        challenge.mark_used().await?;
        return Err(crate::AuthError::InvalidChallenge);
    }
    // Reserve this attempt with an atomic `attempts = attempts + 1`, then re-read.
    // A request's re-read sees at least its own increment, so no more than
    // MAX_CODE_ATTEMPTS requests can observe a count within the limit, however many
    // race in parallel (assumes ordinary read-committed visibility — confirm per DB).
    challenge.bump_attempts().await?;
    let Some(challenge) = AuthChallenge::objects()
        .filter(auth_challenge::ID.eq(challenge.id))
        .first()
        .await?
        .filter(AuthChallenge::is_live)
    else {
        return Err(crate::AuthError::InvalidChallenge);
    };
    if challenge.attempts > MAX_CODE_ATTEMPTS {
        challenge.mark_used().await?;
        return Err(crate::AuthError::InvalidChallenge);
    }
    if hash_secret(code) != challenge.secret_hash {
        // The attempt was already counted above.
        return Err(crate::AuthError::InvalidChallenge);
    }
    let challenge_id = challenge.id;
    let user_id = user.id;
    // Consume the challenge and mark the email verified atomically.
    transaction(|tx| {
        Box::pin(async move {
            let mut mark_delta = serde_json::Map::new();
            mark_delta.insert("used_at".to_string(), serde_json::json!(Utc::now()));
            AuthChallenge::objects()
                .filter(auth_challenge::ID.eq(challenge_id))
                .on_tx(tx)
                .update_values(mark_delta)
                .await?;
            let mut verify_delta = serde_json::Map::new();
            verify_delta.insert(
                "email_verified_at".to_string(),
                serde_json::json!(Utc::now()),
            );
            AuthUser::objects()
                .filter(crate::auth_user::ID.eq(user_id))
                .on_tx(tx)
                .update_values(verify_delta)
                .await?;
            Ok::<_, crate::AuthError>(())
        })
    })
    .await?;
    Ok(())
}
/// Lifetime of an emailed password-reset link.
const RESET_TTL: Duration = Duration::from_secs(60 * 60);
/// Emails a single-use password-reset link to the account registered under `email`.
///
/// A fresh token is stored (as a digest) as a [`PURPOSE_PASSWORD_RESET`] challenge
/// valid for `RESET_TTL`. The token is appended to `reset_url_base` as a `token`
/// query parameter, and an existing query string on the base is respected. If no
/// user has `email`, this returns `Ok(())` without sending anything.
///
/// # Errors
/// Database errors, `AuthError::Template` when a template fails to render, and
/// `AuthError::Mail` when sending fails.
pub async fn start_password_reset(
    email: &str,
    reset_url_base: &str,
) -> Result<(), crate::AuthError> {
    let Some(user) = crate::AuthUser::objects()
        .filter(crate::auth_user::EMAIL.eq(email.to_string()))
        .first()
        .await?
    else {
        // Unknown address: report success anyway, presumably so callers cannot
        // use this endpoint to discover which emails are registered.
        return Ok(());
    };
    let token = generate_reset_token();
    AuthChallenge::issue(user.id, PURPOSE_PASSWORD_RESET, &token, RESET_TTL).await?;
    let reset_url = append_reset_token(reset_url_base, &token);
    let ctx = context! { reset_url => reset_url.clone(), username => user.username.clone() };
    let html = render("auth/email/reset_link.html", &ctx)
        .map_err(|e| crate::AuthError::Template(e.to_string()))?;
    let text = render("auth/email/reset_link.txt", &ctx)
        .map_err(|e| crate::AuthError::Template(e.to_string()))?;
    active_mailer()
        .send(OutgoingMail {
            to: user.email.clone(),
            username: user.username.clone(),
            kind: crate::mailer::MailKind::PasswordReset { reset_url },
            subject: "Reset your password".into(),
            html,
            text,
        })
        .await
        .map_err(|e| crate::AuthError::Mail(e.to_string()))?;
    Ok(())
}

/// Appends `token=<token>` to the end of `base` as a query parameter.
///
/// The separator is chosen from the text after the last `#` (or all of `base` when
/// there is no `#`). It is `?` when that segment has no query yet, `&` when it already
/// has one, and nothing when it already ends in `?` or `&`. Looking only past the last
/// `#` keeps hash-routed bases such as `https://app/#/reset` producing
/// `#/reset?token=…` exactly as before. `token` is inserted verbatim: tokens from
/// `generate_reset_token` use only URL-safe characters.
fn append_reset_token(base: &str, token: &str) -> String {
    let tail = base.rsplit('#').next().unwrap_or(base);
    let separator = if !tail.contains('?') {
        "?"
    } else if tail.ends_with('?') || tail.ends_with('&') {
        ""
    } else {
        "&"
    };
    format!("{base}{separator}token={token}")
}
pub async fn reset_password(token: &str, new_password: &str) -> Result<(), crate::AuthError> {
let Some(challenge) =
AuthChallenge::find_active_by_secret(token, PURPOSE_PASSWORD_RESET).await?
else {
return Err(crate::AuthError::InvalidChallenge);
};
let user_id: i64 = challenge.user_id.id();
let Some(user) = crate::AuthUser::objects()
.filter(crate::auth_user::ID.eq(user_id))
.first()
.await?
else {
return Err(crate::AuthError::InvalidChallenge);
};
crate::validate_password(
new_password,
&crate::PasswordContext::new(Some(&user.username), Some(&user.email)),
)
.map_err(crate::AuthError::WeakPassword)?;
let hash = crate::hash_password(new_password)?;
let challenge_id = challenge.id;
transaction(|tx| {
let hash = hash.clone();
Box::pin(async move {
let mut pw_delta = serde_json::Map::new();
pw_delta.insert("password_hash".to_string(), serde_json::json!(hash));
crate::AuthUser::objects()
.filter(crate::auth_user::ID.eq(user_id))
.on_tx(tx)
.update_values(pw_delta)
.await?;
let mut mark_delta = serde_json::Map::new();
mark_delta.insert("used_at".to_string(), serde_json::json!(Utc::now()));
AuthChallenge::objects()
.filter(auth_challenge::ID.eq(challenge_id))
.on_tx(tx)
.update_values(mark_delta)
.await?;
Ok::<_, crate::AuthError>(())
})
})
.await?;
if let Err(e) = crate::token::AuthToken::objects()
.filter(crate::token::auth_token::USER_ID.eq(user_id))
.delete()
.await
{
tracing::error!(user_id, error = %e, "password reset: failed to revoke bearer tokens");
}
if let Err(e) = umbral_sessions::revoke_user_sessions(&user_id.to_string()).await {
tracing::error!(user_id, error = %e, "password reset: failed to revoke sessions");
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every generated code is exactly six ASCII digits.
    #[test]
    fn generate_code_is_six_ascii_digits() {
        (0..20).map(|_| generate_code()).for_each(|code| {
            assert_eq!(
                code.len(),
                6,
                "generate_code must produce exactly 6 chars; got {code:?}"
            );
            assert!(
                code.bytes().all(|b| b.is_ascii_digit()),
                "generate_code must contain only ASCII digits; got {code:?}"
            );
        });
    }

    /// The `{:06}` spec used by `generate_code` left-pads small values with zeros.
    #[test]
    fn generate_code_zero_pads_small_numbers() {
        let padded = format!("{:06}", 48u32);
        assert_eq!(padded, "000048");
        assert_eq!(generate_code().len(), 6);
    }

    /// Reset tokens carry the `umbral_` prefix plus 43 chars of unpadded base64.
    #[test]
    fn generate_reset_token_has_prefix_and_length() {
        let tok = generate_reset_token();
        let expected_len = "umbral_".len() + 43;
        assert!(
            tok.starts_with("umbral_"),
            "token must start with 'umbral_'; got {tok:?}"
        );
        assert_eq!(
            tok.len(),
            expected_len,
            "token must be 50 chars (7 prefix + 43 base64); got {tok:?}"
        );
    }

    /// Two back-to-back tokens differ (256 bits of entropy each).
    #[test]
    fn generate_reset_token_is_unique_per_call() {
        let first = generate_reset_token();
        let second = generate_reset_token();
        assert_ne!(first, second, "two consecutive tokens must not collide");
    }

    /// `hash_secret` is a deterministic, input-sensitive 43-char digest.
    #[test]
    fn hash_secret_is_deterministic_and_delegates_to_digest_token() {
        let first = hash_secret("483920");
        let again = hash_secret("483920");
        let other = hash_secret("999999");
        assert_eq!(first, again, "hash_secret must be deterministic");
        assert_ne!(first, other, "different inputs must produce different digests");
        assert_eq!(
            first.len(),
            43,
            "SHA-256 in URL-safe base64 (no pad) is 43 chars"
        );
    }
}