use std::time::Duration;
use crate::signed_url::{sign, verify, SignedUrlError};
/// Errors produced by the signed-token auth flows in this module
/// (`PasswordReset`, `EmailVerification`, `MagicLink`) and by the
/// password-reset confirmation helpers.
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum AuthFlowError {
    /// The signature is missing or unparseable.
    ///
    /// Also returned by the `verify` methods when a required query parameter
    /// is absent or `user_id` does not parse as an `i64`.
    #[error("token is missing or malformed")]
    Malformed,
    /// The signature does not match the URL under the supplied secret.
    #[error("token signature does not match")]
    InvalidSignature,
    /// The signed URL is past its expiry (mapped from `SignedUrlError::Expired`).
    #[error("token has expired")]
    Expired,
    /// The token is validly signed but was issued for a different flow.
    ///
    /// Carries the `purpose` value found in the URL.
    #[error("token is for the wrong purpose ({0})")]
    WrongPurpose(String),
    /// The proposed new password failed validation; carries a human-readable reason.
    #[error("password rejected: {0}")]
    WeakPassword(String),
    /// A storage-layer failure; carries the underlying error's message.
    ///
    /// The reset helpers also use this variant for password-hashing failures.
    #[error("database error: {0}")]
    Database(String),
}
impl From<SignedUrlError> for AuthFlowError {
fn from(e: SignedUrlError) -> Self {
match e {
SignedUrlError::MissingSignature | SignedUrlError::MalformedSignature => {
Self::Malformed
}
SignedUrlError::InvalidSignature => Self::InvalidSignature,
SignedUrlError::Expired => Self::Expired,
}
}
}
/// Signed, time-limited password-reset links.
///
/// The link carries the user's id and a `purpose=pwreset` marker. Because of
/// the marker, a token minted for another flow in this module is rejected here
/// with [`AuthFlowError::WrongPurpose`].
pub struct PasswordReset;

impl PasswordReset {
    /// Value of the `purpose` query parameter that tags reset tokens.
    const PURPOSE: &'static str = "pwreset";

    /// Builds a signed reset URL for `user_id` that expires after `ttl`.
    ///
    /// `base_url` may already carry a query string (e.g.
    /// `https://app/reset?lang=en`). In that case the token parameters are
    /// appended with `&` instead of starting a second `?`. `base_url` should
    /// not itself contain `user_id` or `purpose`, because lookup takes the
    /// first matching parameter.
    #[must_use]
    pub fn issue(base_url: &str, user_id: i64, secret: &[u8], ttl: Duration) -> String {
        let base = base_url.trim_end_matches('?');
        // A second `?` would fold our parameters into the value of the base
        // URL's last parameter, making the link impossible to verify.
        let (base, sep) = if base.contains('?') {
            (base.trim_end_matches('&'), '&')
        } else {
            (base, '?')
        };
        let url = format!("{base}{sep}user_id={user_id}&purpose={}", Self::PURPOSE);
        sign(&url, secret, Some(ttl))
    }

    /// Verifies a reset URL and returns the user id it was issued for.
    ///
    /// # Errors
    ///
    /// - Signature and expiry failures, converted from `SignedUrlError`.
    /// - [`AuthFlowError::WrongPurpose`] if the token was issued for another flow.
    /// - [`AuthFlowError::Malformed`] if `purpose` or `user_id` is missing, or
    ///   `user_id` is not a valid `i64`.
    pub fn verify(url: &str, secret: &[u8]) -> Result<i64, AuthFlowError> {
        verify(url, secret)?;
        let purpose = extract_query(url, "purpose").ok_or(AuthFlowError::Malformed)?;
        if purpose != Self::PURPOSE {
            return Err(AuthFlowError::WrongPurpose(purpose));
        }
        extract_query(url, "user_id")
            .ok_or(AuthFlowError::Malformed)?
            .parse::<i64>()
            .map_err(|_| AuthFlowError::Malformed)
    }
}
/// Confirms a password reset against the default user table.
///
/// This is shorthand for [`confirm_password_reset_pool_into`] with table
/// `rustango_users`, primary-key column `id` and password column
/// `password_hash`. It returns the id of the user whose password was updated.
///
/// # Errors
///
/// Propagates every error from [`confirm_password_reset_pool_into`].
#[cfg(feature = "passwords")]
pub async fn confirm_password_reset_pool(
    pool: &crate::sql::Pool,
    url: &str,
    new_password: &str,
    secret: &[u8],
) -> Result<i64, AuthFlowError> {
    const USER_TABLE: &str = "rustango_users";
    const PK_COLUMN: &str = "id";
    const PASSWORD_COLUMN: &str = "password_hash";

    confirm_password_reset_pool_into(
        pool,
        url,
        new_password,
        secret,
        USER_TABLE,
        PK_COLUMN,
        PASSWORD_COLUMN,
    )
    .await
}
/// Verifies a password-reset URL and stores a hash of `new_password` for the user it names.
///
/// The row is located by `pk_column = <user_id from the token>` in `user_table`.
/// `password_column` is then overwritten with `crate::passwords::hash(new_password)`.
/// All three identifiers are quoted with the pool's dialect, and both values are
/// bound as placeholders rather than interpolated. Returns the user id taken from
/// the token.
///
/// # Errors
///
/// - Any error from [`PasswordReset::verify`] (bad signature, expired, wrong
///   purpose, malformed).
/// - [`AuthFlowError::WeakPassword`] if `new_password` has fewer than 8 characters.
/// - [`AuthFlowError::Database`] if hashing or the `UPDATE` fails.
///
/// NOTE(review): nothing here invalidates the token, so presumably it can be
/// replayed until its TTL expires. The result of `raw_execute_pool` is also not
/// inspected, so an `UPDATE` that matches no row (e.g. a deleted user) still
/// returns `Ok`. Confirm both are acceptable for callers.
#[cfg(feature = "passwords")]
pub async fn confirm_password_reset_pool_into(
    pool: &crate::sql::Pool,
    url: &str,
    new_password: &str,
    secret: &[u8],
    user_table: &str,
    pk_column: &str,
    password_column: &str,
) -> Result<i64, AuthFlowError> {
    // Minimum accepted password length, counted in Unicode scalar values.
    const MIN_PASSWORD_CHARS: usize = 8;

    let user_id = PasswordReset::verify(url, secret)?;
    // Count characters, not bytes. `len()` would let a 4-character password
    // made of 2-byte characters (e.g. Cyrillic) pass the 8-character minimum.
    if new_password.chars().count() < MIN_PASSWORD_CHARS {
        return Err(AuthFlowError::WeakPassword(format!(
            "Password must be at least {MIN_PASSWORD_CHARS} characters."
        )));
    }
    // Hashing failures are reported as `Database`; the enum has no dedicated variant.
    let hash = crate::passwords::hash(new_password)
        .map_err(|e| AuthFlowError::Database(e.to_string()))?;

    let dialect = pool.dialect();
    let t = dialect.quote_ident(user_table);
    let pw = dialect.quote_ident(password_column);
    let pk = dialect.quote_ident(pk_column);
    let p1 = dialect.placeholder(1);
    let p2 = dialect.placeholder(2);
    let sql = format!("UPDATE {t} SET {pw} = {p1} WHERE {pk} = {p2}");
    crate::sql::raw_execute_pool(
        pool,
        &sql,
        vec![
            crate::core::SqlValue::String(hash),
            crate::core::SqlValue::I64(user_id),
        ],
    )
    .await
    .map_err(|e| AuthFlowError::Database(e.to_string()))?;
    Ok(user_id)
}
/// Signed, time-limited email-verification links.
///
/// The link carries the user id, the percent-encoded address being verified,
/// and a `purpose=verify_email` marker. The marker stops the link from being
/// accepted by the other flows in this module.
pub struct EmailVerification;

impl EmailVerification {
    /// Value of the `purpose` query parameter that tags verification tokens.
    const PURPOSE: &'static str = "verify_email";

    /// Builds a signed verification URL for `user_id` / `email` that expires after `ttl`.
    ///
    /// `email` is percent-encoded. `base_url` may already carry a query string;
    /// in that case the token parameters are appended with `&` instead of
    /// starting a second `?`.
    #[must_use]
    pub fn issue(
        base_url: &str,
        user_id: i64,
        email: &str,
        secret: &[u8],
        ttl: Duration,
    ) -> String {
        let base = base_url.trim_end_matches('?');
        // A second `?` would fold our parameters into the value of the base
        // URL's last parameter, making the link impossible to verify.
        let (base, sep) = if base.contains('?') {
            (base.trim_end_matches('&'), '&')
        } else {
            (base, '?')
        };
        let url = format!(
            "{base}{sep}user_id={user_id}&email={}&purpose={}",
            url_encode(email),
            Self::PURPOSE,
        );
        sign(&url, secret, Some(ttl))
    }

    /// Verifies a verification URL and returns `(user_id, email)`.
    ///
    /// The email is returned percent-decoded.
    ///
    /// # Errors
    ///
    /// - Signature and expiry failures, converted from `SignedUrlError`.
    /// - [`AuthFlowError::WrongPurpose`] if the token was issued for another flow.
    /// - [`AuthFlowError::Malformed`] if `purpose`, `user_id` or `email` is
    ///   missing, or `user_id` is not a valid `i64`.
    pub fn verify(url: &str, secret: &[u8]) -> Result<(i64, String), AuthFlowError> {
        verify(url, secret)?;
        let purpose = extract_query(url, "purpose").ok_or(AuthFlowError::Malformed)?;
        if purpose != Self::PURPOSE {
            return Err(AuthFlowError::WrongPurpose(purpose));
        }
        let user_id = extract_query(url, "user_id")
            .ok_or(AuthFlowError::Malformed)?
            .parse::<i64>()
            .map_err(|_| AuthFlowError::Malformed)?;
        let email = extract_query(url, "email").ok_or(AuthFlowError::Malformed)?;
        Ok((user_id, email))
    }
}
/// Signed, time-limited passwordless-login links.
///
/// The link carries the percent-encoded email address and a `purpose=magic_link`
/// marker. The marker stops the link from being accepted by the other flows in
/// this module.
pub struct MagicLink;

impl MagicLink {
    /// Value of the `purpose` query parameter that tags magic-link tokens.
    const PURPOSE: &'static str = "magic_link";

    /// Builds a signed login URL for `email` that expires after `ttl`.
    ///
    /// `base_url` may already carry a query string; in that case the token
    /// parameters are appended with `&` instead of starting a second `?`.
    #[must_use]
    pub fn issue(base_url: &str, email: &str, secret: &[u8], ttl: Duration) -> String {
        let base = base_url.trim_end_matches('?');
        // A second `?` would fold our parameters into the value of the base
        // URL's last parameter, making the link impossible to verify.
        let (base, sep) = if base.contains('?') {
            (base.trim_end_matches('&'), '&')
        } else {
            (base, '?')
        };
        let url = format!(
            "{base}{sep}email={}&purpose={}",
            url_encode(email),
            Self::PURPOSE,
        );
        sign(&url, secret, Some(ttl))
    }

    /// Verifies a magic-link URL and returns the percent-decoded email it was issued for.
    ///
    /// # Errors
    ///
    /// - Signature and expiry failures, converted from `SignedUrlError`.
    /// - [`AuthFlowError::WrongPurpose`] if the token was issued for another flow.
    /// - [`AuthFlowError::Malformed`] if `purpose` or `email` is missing.
    pub fn verify(url: &str, secret: &[u8]) -> Result<String, AuthFlowError> {
        verify(url, secret)?;
        let purpose = extract_query(url, "purpose").ok_or(AuthFlowError::Malformed)?;
        if purpose != Self::PURPOSE {
            return Err(AuthFlowError::WrongPurpose(purpose));
        }
        extract_query(url, "email").ok_or(AuthFlowError::Malformed)
    }
}
/// Returns the decoded value of the first query parameter in `url` whose
/// decoded key equals `key`.
///
/// Everything after the first `?` is treated as the query string, split on `&`.
/// Segments without an `=` (bare flags, or empty segments produced by `&&`) are
/// skipped rather than ending the search. Returns `None` when `url` has no `?`
/// or no segment matches. A `#fragment` is not stripped.
fn extract_query(url: &str, key: &str) -> Option<String> {
    let (_, query) = url.split_once('?')?;
    query
        .split('&')
        .filter_map(|pair| pair.split_once('='))
        .find(|&(k, _)| url_decode(k) == key)
        .map(|(_, v)| url_decode(v))
}
/// Percent-encodes `s` for use as a URL query value.
///
/// Each UTF-8 byte outside the RFC 3986 unreserved set
/// (`A-Z a-z 0-9 - _ . ~`) becomes `%XX` with uppercase hex digits.
/// Unreserved bytes are copied through unchanged.
fn url_encode(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";

    // Write straight into one buffer instead of allocating a temporary
    // `String` for every input byte.
    let mut out = String::with_capacity(s.len());
    for b in s.bytes() {
        if b.is_ascii_alphanumeric() || matches!(b, b'-' | b'_' | b'.' | b'~') {
            out.push(char::from(b));
        } else {
            out.push('%');
            out.push(char::from(HEX[usize::from(b >> 4)]));
            out.push(char::from(HEX[usize::from(b & 0x0F)]));
        }
    }
    out
}
use crate::url_codec::url_decode;
#[cfg(test)]
mod tests {
    use super::*;

    const SECRET: &[u8] = b"my-test-secret";

    /// Shorthand for a TTL of `secs` seconds.
    fn ttl(secs: u64) -> Duration {
        Duration::from_secs(secs)
    }

    #[test]
    fn password_reset_issue_and_verify_roundtrip() {
        let link = PasswordReset::issue("https://app.example.com/reset", 42, SECRET, ttl(3600));
        let recovered = PasswordReset::verify(&link, SECRET).unwrap();
        assert_eq!(recovered, 42);
    }

    #[test]
    fn password_reset_wrong_secret_fails() {
        let link = PasswordReset::issue("https://x/r", 42, SECRET, ttl(3600));
        let err = PasswordReset::verify(&link, b"different").unwrap_err();
        assert_eq!(err, AuthFlowError::InvalidSignature);
    }

    #[test]
    fn password_reset_tampered_user_id_fails() {
        let link = PasswordReset::issue("https://x/r", 42, SECRET, ttl(3600));
        let forged = link.replace("user_id=42", "user_id=99");
        let err = PasswordReset::verify(&forged, SECRET).unwrap_err();
        assert_eq!(err, AuthFlowError::InvalidSignature);
    }

    #[test]
    fn password_reset_rejects_email_verification_token() {
        let link = EmailVerification::issue("https://x/r", 42, "alice@x.com", SECRET, ttl(3600));
        let outcome = PasswordReset::verify(&link, SECRET);
        assert!(matches!(outcome, Err(AuthFlowError::WrongPurpose(_))));
    }

    #[test]
    fn email_verification_roundtrip() {
        let link =
            EmailVerification::issue("https://x/v", 42, "alice@example.com", SECRET, ttl(86_400));
        let (uid, address) = EmailVerification::verify(&link, SECRET).unwrap();
        assert_eq!(uid, 42);
        assert_eq!(address, "alice@example.com");
    }

    #[test]
    fn email_verification_handles_special_chars() {
        let link =
            EmailVerification::issue("https://x/v", 42, "a+b@example.com", SECRET, ttl(86_400));
        let (_, address) = EmailVerification::verify(&link, SECRET).unwrap();
        assert_eq!(address, "a+b@example.com");
    }

    #[test]
    fn email_verification_rejects_password_reset_token() {
        let link = PasswordReset::issue("https://x/v", 42, SECRET, ttl(3600));
        let outcome = EmailVerification::verify(&link, SECRET);
        assert!(matches!(outcome, Err(AuthFlowError::WrongPurpose(_))));
    }

    #[test]
    fn magic_link_roundtrip() {
        let link = MagicLink::issue("https://x/login", "alice@example.com", SECRET, ttl(900));
        let address = MagicLink::verify(&link, SECRET).unwrap();
        assert_eq!(address, "alice@example.com");
    }

    #[test]
    fn magic_link_rejects_password_reset_token() {
        let link = PasswordReset::issue("https://x/r", 42, SECRET, ttl(3600));
        let outcome = MagicLink::verify(&link, SECRET);
        assert!(matches!(outcome, Err(AuthFlowError::WrongPurpose(_))));
    }

    #[test]
    fn extract_query_picks_right_param() {
        let target = "https://x/path?a=1&b=2&c=3";
        assert_eq!(extract_query(target, "b"), Some("2".to_owned()));
        assert_eq!(extract_query(target, "missing"), None);
    }

    #[test]
    fn extract_query_handles_url_encoded_value() {
        let target = "https://x/path?email=alice%40x.com";
        assert_eq!(extract_query(target, "email"), Some("alice@x.com".to_owned()));
    }

    #[test]
    fn missing_purpose_param_treated_as_malformed() {
        // Validly signed, but without the `purpose` marker.
        let link = sign("https://x/r?user_id=42", SECRET, Some(ttl(60)));
        let err = PasswordReset::verify(&link, SECRET).unwrap_err();
        assert_eq!(err, AuthFlowError::Malformed);
    }
}