use std::time::Duration;
use crate::signed_url::{sign, verify, SignedUrlError};
/// Failures produced while issuing or redeeming signed auth-flow links.
///
/// Signature-level failures are converted from `crate::signed_url::SignedUrlError`
/// through the `From` impl below; the remaining variants are raised by the flows in
/// this module.
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum AuthFlowError {
    /// A required query parameter is absent or unparsable, or the signature itself is
    /// missing / syntactically malformed.
    #[error("token is missing or malformed")]
    Malformed,

    /// The signature does not match the URL contents for the supplied secret.
    #[error("token signature does not match")]
    InvalidSignature,

    /// The signed URL reported itself as expired.
    #[error("token has expired")]
    Expired,

    /// The link was signed for a different flow; carries the `purpose` value found.
    #[error("token is for the wrong purpose ({0})")]
    WrongPurpose(String),

    /// The replacement password failed validation; carries a user-facing reason.
    #[error("password rejected: {0}")]
    WeakPassword(String),

    /// Storing the new credential failed. This is also used for password-hashing failures.
    #[error("database error: {0}")]
    Database(String),

    /// A single-use link was already redeemed, or the replay cache could not be read.
    #[error("token already used")]
    AlreadyUsed,
}
/// Maps low-level signed-URL failures onto the auth-flow error vocabulary.
///
/// A missing signature and a syntactically malformed one both become
/// [`AuthFlowError::Malformed`]; the other variants map one-to-one.
impl From<SignedUrlError> for AuthFlowError {
    fn from(err: SignedUrlError) -> Self {
        match err {
            SignedUrlError::InvalidSignature => Self::InvalidSignature,
            SignedUrlError::Expired => Self::Expired,
            // Callers do not need to distinguish "absent" from "garbled".
            SignedUrlError::MissingSignature | SignedUrlError::MalformedSignature => {
                Self::Malformed
            }
        }
    }
}
/// Issues and verifies signed password-reset links.
///
/// A link carries `user_id` and `purpose=pwreset` query parameters and is signed with
/// `crate::signed_url::sign`. Verification rejects links signed for any other purpose.
pub struct PasswordReset;

impl PasswordReset {
    /// Value of the `purpose` query parameter that marks a password-reset link.
    const PURPOSE: &'static str = "pwreset";

    /// Builds a signed password-reset URL for `user_id` that is valid for `ttl`.
    ///
    /// `base_url` may already contain a query string (e.g. `https://app/reset?lang=en`).
    /// In that case the reset parameters are appended with `&` rather than a second `?`,
    /// so `verify` can still find them.
    #[must_use]
    pub fn issue(base_url: &str, user_id: i64, secret: &[u8], ttl: Duration) -> String {
        let base = base_url.trim_end_matches('?');
        // Join onto an existing query string instead of starting a second one.
        let sep = if !base.contains('?') {
            "?"
        } else if base.ends_with('&') {
            ""
        } else {
            "&"
        };
        let url = format!("{base}{sep}user_id={user_id}&purpose={}", Self::PURPOSE);
        sign(&url, secret, Some(ttl))
    }

    /// Checks the signature on `url` and returns the `user_id` it was issued for.
    ///
    /// # Errors
    ///
    /// - Any signed-URL failure from `crate::signed_url::verify`, converted via `From`.
    /// - [`AuthFlowError::Malformed`] if `purpose` or `user_id` is missing, or if
    ///   `user_id` does not parse as `i64`.
    /// - [`AuthFlowError::WrongPurpose`] if the link was issued for another flow.
    pub fn verify(url: &str, secret: &[u8]) -> Result<i64, AuthFlowError> {
        // Signature first, so that every parameter read below is issuer-controlled.
        verify(url, secret)?;
        let purpose = extract_query(url, "purpose").ok_or(AuthFlowError::Malformed)?;
        if purpose != Self::PURPOSE {
            return Err(AuthFlowError::WrongPurpose(purpose));
        }
        let user_id_str = extract_query(url, "user_id").ok_or(AuthFlowError::Malformed)?;
        user_id_str
            .parse::<i64>()
            .map_err(|_| AuthFlowError::Malformed)
    }

    /// Like [`PasswordReset::verify`], then records the link as consumed in `cache`
    /// so that a second redemption is rejected.
    ///
    /// # Errors
    ///
    /// Everything [`PasswordReset::verify`] returns, plus [`AuthFlowError::AlreadyUsed`]
    /// when the link was already consumed or the cache cannot be used.
    #[cfg(feature = "cache")]
    pub async fn verify_single_use(
        url: &str,
        secret: &[u8],
        cache: &std::sync::Arc<dyn crate::cache::Cache>,
    ) -> Result<i64, AuthFlowError> {
        let user_id = Self::verify(url, secret)?;
        consume_single_use(url, cache).await?;
        Ok(user_id)
    }
}
/// Completes a password reset against the default user table.
///
/// This is a convenience wrapper around `confirm_password_reset_pool_into` that targets
/// table `rustango_users`, keyed by `id`, and writes the hash to `password_hash`.
///
/// # Errors
///
/// Propagates every error returned by `confirm_password_reset_pool_into`.
#[cfg(feature = "passwords")]
pub async fn confirm_password_reset_pool(
    pool: &crate::sql::Pool,
    url: &str,
    new_password: &str,
    secret: &[u8],
) -> Result<i64, AuthFlowError> {
    const USER_TABLE: &str = "rustango_users";
    const PK_COLUMN: &str = "id";
    const PASSWORD_COLUMN: &str = "password_hash";

    confirm_password_reset_pool_into(
        pool,
        url,
        new_password,
        secret,
        USER_TABLE,
        PK_COLUMN,
        PASSWORD_COLUMN,
    )
    .await
}
/// Verifies a password-reset `url` and, if it is valid, stores a hash of `new_password`
/// for the user the link names. Returns that user's id.
///
/// The statement executed is `UPDATE <user_table> SET <password_column> = ? WHERE
/// <pk_column> = ?`. Identifiers are passed through the pool dialect's `quote_ident`, and
/// the hash and user id are bound as parameters.
///
/// NOTE(review): `user_table` / `pk_column` / `password_column` are presumably trusted
/// configuration. Confirm that `quote_ident` escapes embedded quote characters before
/// passing untrusted names.
///
/// NOTE(review): this function does not mark the link as used, so it can be redeemed
/// again until it expires. It also does not inspect how many rows the UPDATE affected.
///
/// # Errors
///
/// - Any error from `PasswordReset::verify` (bad signature, expired, wrong purpose, ...).
/// - [`AuthFlowError::WeakPassword`] if `new_password` has fewer than 8 characters.
/// - [`AuthFlowError::Database`] if hashing or the UPDATE fails.
#[cfg(feature = "passwords")]
pub async fn confirm_password_reset_pool_into(
    pool: &crate::sql::Pool,
    url: &str,
    new_password: &str,
    secret: &[u8],
    user_table: &str,
    pk_column: &str,
    password_column: &str,
) -> Result<i64, AuthFlowError> {
    /// Minimum password length, counted in Unicode scalar values rather than bytes.
    const MIN_PASSWORD_CHARS: usize = 8;

    let user_id = PasswordReset::verify(url, secret)?;
    // `str::len` is the UTF-8 *byte* length. A 3-character CJK password is 9 bytes and
    // would pass a byte-based "at least 8 characters" check, so count chars instead.
    if new_password.chars().count() < MIN_PASSWORD_CHARS {
        return Err(AuthFlowError::WeakPassword(
            "Password must be at least 8 characters.".into(),
        ));
    }
    let hash = crate::passwords::hash(new_password)
        .map_err(|e| AuthFlowError::Database(e.to_string()))?;

    let dialect = pool.dialect();
    let t = dialect.quote_ident(user_table);
    let pw = dialect.quote_ident(password_column);
    let pk = dialect.quote_ident(pk_column);
    let p1 = dialect.placeholder(1);
    let p2 = dialect.placeholder(2);
    let sql = format!("UPDATE {t} SET {pw} = {p1} WHERE {pk} = {p2}");

    crate::sql::raw_execute_pool(
        pool,
        &sql,
        vec![
            crate::core::SqlValue::String(hash),
            crate::core::SqlValue::I64(user_id),
        ],
    )
    .await
    .map_err(|e| AuthFlowError::Database(e.to_string()))?;
    Ok(user_id)
}
/// Issues and verifies signed email-verification links.
///
/// A link carries `user_id`, the URL-encoded `email`, and `purpose=verify_email`, and is
/// signed with `crate::signed_url::sign`.
pub struct EmailVerification;

impl EmailVerification {
    /// Value of the `purpose` query parameter that marks an email-verification link.
    const PURPOSE: &'static str = "verify_email";

    /// Builds a signed URL confirming that `user_id` owns `email`, valid for `ttl`.
    ///
    /// `base_url` may already contain a query string. In that case the parameters are
    /// appended with `&` rather than a second `?`, so `verify` can still find them.
    #[must_use]
    pub fn issue(
        base_url: &str,
        user_id: i64,
        email: &str,
        secret: &[u8],
        ttl: Duration,
    ) -> String {
        let base = base_url.trim_end_matches('?');
        // Join onto an existing query string instead of starting a second one.
        let sep = if !base.contains('?') {
            "?"
        } else if base.ends_with('&') {
            ""
        } else {
            "&"
        };
        let url = format!(
            "{base}{sep}user_id={user_id}&email={}&purpose={}",
            url_encode(email),
            Self::PURPOSE,
        );
        sign(&url, secret, Some(ttl))
    }

    /// Checks the signature on `url` and returns the `(user_id, email)` pair it carries.
    /// The email is URL-decoded.
    ///
    /// # Errors
    ///
    /// - Any signed-URL failure from `crate::signed_url::verify`, converted via `From`.
    /// - [`AuthFlowError::Malformed`] if `purpose`, `user_id` or `email` is missing, or if
    ///   `user_id` does not parse as `i64`.
    /// - [`AuthFlowError::WrongPurpose`] if the link was issued for another flow.
    pub fn verify(url: &str, secret: &[u8]) -> Result<(i64, String), AuthFlowError> {
        // Signature first, so that every parameter read below is issuer-controlled.
        verify(url, secret)?;
        let purpose = extract_query(url, "purpose").ok_or(AuthFlowError::Malformed)?;
        if purpose != Self::PURPOSE {
            return Err(AuthFlowError::WrongPurpose(purpose));
        }
        let user_id_str = extract_query(url, "user_id").ok_or(AuthFlowError::Malformed)?;
        let user_id = user_id_str
            .parse::<i64>()
            .map_err(|_| AuthFlowError::Malformed)?;
        let email = extract_query(url, "email").ok_or(AuthFlowError::Malformed)?;
        Ok((user_id, email))
    }

    /// Like [`EmailVerification::verify`], then records the link as consumed in `cache`
    /// so that a second redemption is rejected.
    ///
    /// # Errors
    ///
    /// Everything [`EmailVerification::verify`] returns, plus
    /// [`AuthFlowError::AlreadyUsed`] when the link was already consumed or the cache
    /// cannot be used.
    #[cfg(feature = "cache")]
    pub async fn verify_single_use(
        url: &str,
        secret: &[u8],
        cache: &std::sync::Arc<dyn crate::cache::Cache>,
    ) -> Result<(i64, String), AuthFlowError> {
        let out = Self::verify(url, secret)?;
        consume_single_use(url, cache).await?;
        Ok(out)
    }
}
/// Issues and verifies signed passwordless-login ("magic") links.
///
/// A link carries the URL-encoded `email` and `purpose=magic_link`. It embeds no user
/// id, so verification returns only the email.
pub struct MagicLink;

impl MagicLink {
    /// Value of the `purpose` query parameter that marks a magic-link login.
    const PURPOSE: &'static str = "magic_link";

    /// Builds a signed login URL for `email` that is valid for `ttl`.
    ///
    /// `base_url` may already contain a query string. In that case the parameters are
    /// appended with `&` rather than a second `?`, so `verify` can still find them.
    #[must_use]
    pub fn issue(base_url: &str, email: &str, secret: &[u8], ttl: Duration) -> String {
        let base = base_url.trim_end_matches('?');
        // Join onto an existing query string instead of starting a second one.
        let sep = if !base.contains('?') {
            "?"
        } else if base.ends_with('&') {
            ""
        } else {
            "&"
        };
        let url = format!(
            "{base}{sep}email={}&purpose={}",
            url_encode(email),
            Self::PURPOSE,
        );
        sign(&url, secret, Some(ttl))
    }

    /// Checks the signature on `url` and returns the URL-decoded email it was issued for.
    ///
    /// # Errors
    ///
    /// - Any signed-URL failure from `crate::signed_url::verify`, converted via `From`.
    /// - [`AuthFlowError::Malformed`] if `purpose` or `email` is missing.
    /// - [`AuthFlowError::WrongPurpose`] if the link was issued for another flow.
    pub fn verify(url: &str, secret: &[u8]) -> Result<String, AuthFlowError> {
        // Signature first, so that every parameter read below is issuer-controlled.
        verify(url, secret)?;
        let purpose = extract_query(url, "purpose").ok_or(AuthFlowError::Malformed)?;
        if purpose != Self::PURPOSE {
            return Err(AuthFlowError::WrongPurpose(purpose));
        }
        extract_query(url, "email").ok_or(AuthFlowError::Malformed)
    }

    /// Like [`MagicLink::verify`], then records the link as consumed in `cache` so that
    /// a second redemption is rejected.
    ///
    /// # Errors
    ///
    /// Everything [`MagicLink::verify`] returns, plus [`AuthFlowError::AlreadyUsed`]
    /// when the link was already consumed or the cache cannot be used.
    #[cfg(feature = "cache")]
    pub async fn verify_single_use(
        url: &str,
        secret: &[u8],
        cache: &std::sync::Arc<dyn crate::cache::Cache>,
    ) -> Result<String, AuthFlowError> {
        let email = Self::verify(url, secret)?;
        consume_single_use(url, cache).await?;
        Ok(email)
    }
}
/// Returns the URL-decoded value of the first query parameter named `key` in `url`.
///
/// The query string is everything after the first `?`, split on `&`. Keys and values are
/// decoded with `url_decode` before comparison.
///
/// A parameter with no `=` (e.g. the `flag` in `?flag&a=1`) is treated as having an
/// empty value and the scan continues. Previously such a pair aborted the whole lookup.
///
/// Returns `None` when `url` has no `?` or `key` does not occur.
fn extract_query(url: &str, key: &str) -> Option<String> {
    let (_, query) = url.split_once('?')?;
    query.split('&').find_map(|pair| {
        let (raw_key, raw_value) = pair.split_once('=').unwrap_or((pair, ""));
        (url_decode(raw_key) == key).then(|| url_decode(raw_value))
    })
}
/// Records the signed `url` as consumed in `cache`, rejecting it if it was consumed before.
///
/// The replay key is `authflow_used:<signature>`, built from the URL's (decoded)
/// `signature` query parameter. The marker is stored with the TTL from `single_use_ttl`.
///
/// # Errors
///
/// - [`AuthFlowError::Malformed`] if `url` has no `signature` parameter.
/// - [`AuthFlowError::AlreadyUsed`] if the marker already exists, or if the cache cannot
///   be read or written. Both cache paths fail closed, so a cache outage cannot re-open a
///   consumed link.
#[cfg(feature = "cache")]
async fn consume_single_use(
    url: &str,
    cache: &std::sync::Arc<dyn crate::cache::Cache>,
) -> Result<(), AuthFlowError> {
    let sig = extract_query(url, "signature").ok_or(AuthFlowError::Malformed)?;
    let key = format!("authflow_used:{sig}");

    // NOTE(review): `exists` followed by `set` is not atomic. Two concurrent requests
    // carrying the same link can both pass this check. If `Cache` offers an atomic
    // set-if-absent operation, prefer it here.
    match cache.exists(&key).await {
        Ok(false) => {}
        Ok(true) | Err(_) => return Err(AuthFlowError::AlreadyUsed),
    }

    // Fail closed: if the marker cannot be stored, the link would stay replayable.
    if cache
        .set(&key, "1", Some(single_use_ttl(url)))
        .await
        .is_err()
    {
        return Err(AuthFlowError::AlreadyUsed);
    }
    Ok(())
}
/// Lifetime for the "already used" marker of a signed `url`.
///
/// Uses the time left until the URL's `expires` query parameter (Unix seconds). It falls
/// back to one hour when `expires` is absent, unparsable, or leaves no whole second.
#[cfg(feature = "cache")]
fn single_use_ttl(url: &str) -> Duration {
    const FALLBACK_TTL: Duration = Duration::from_secs(3600);

    // A clock before the epoch is treated as time zero.
    let now_secs = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    };

    let remaining = match extract_query(url, "expires").and_then(|raw| raw.parse::<u64>().ok())
    {
        Some(expires_at) => Duration::from_secs(expires_at.saturating_sub(now_secs)),
        None => return FALLBACK_TTL,
    };

    if remaining.is_zero() {
        FALLBACK_TTL
    } else {
        remaining
    }
}
use crate::url_codec::{url_decode, url_encode};
// Unit tests for the auth-flow links: issue/verify round-trips, signature and
// cross-purpose rejection, single-use replay protection (behind the `cache` feature),
// and the `extract_query` helper.
#[cfg(test)]
mod tests {
use super::*;
// Shared signing key for every test in this module.
const SECRET: &[u8] = b"my-test-secret";
// A freshly issued reset link verifies and yields the embedded user id.
#[test]
fn password_reset_issue_and_verify_roundtrip() {
let url = PasswordReset::issue(
"https://app.example.com/reset",
42,
SECRET,
Duration::from_secs(3600),
);
let user_id = PasswordReset::verify(&url, SECRET).unwrap();
assert_eq!(user_id, 42);
}
// Verifying with a different secret is reported as a signature mismatch.
#[test]
fn password_reset_wrong_secret_fails() {
let url = PasswordReset::issue("https://x/r", 42, SECRET, Duration::from_secs(3600));
let r = PasswordReset::verify(&url, b"different");
assert_eq!(r.unwrap_err(), AuthFlowError::InvalidSignature);
}
// Editing a signed parameter after issuance invalidates the signature.
#[test]
fn password_reset_tampered_user_id_fails() {
let url = PasswordReset::issue("https://x/r", 42, SECRET, Duration::from_secs(3600));
let tampered = url.replace("user_id=42", "user_id=99");
let r = PasswordReset::verify(&tampered, SECRET);
assert_eq!(r.unwrap_err(), AuthFlowError::InvalidSignature);
}
// A validly signed email-verification link cannot be redeemed as a password reset.
#[test]
fn password_reset_rejects_email_verification_token() {
let url = EmailVerification::issue(
"https://x/r",
42,
"alice@x.com",
SECRET,
Duration::from_secs(3600),
);
let r = PasswordReset::verify(&url, SECRET);
assert!(matches!(r, Err(AuthFlowError::WrongPurpose(_))));
}
// Email-verification links round-trip both the user id and the email.
#[test]
fn email_verification_roundtrip() {
let url = EmailVerification::issue(
"https://x/v",
42,
"alice@example.com",
SECRET,
Duration::from_secs(86_400),
);
let (uid, email) = EmailVerification::verify(&url, SECRET).unwrap();
assert_eq!(uid, 42);
assert_eq!(email, "alice@example.com");
}
// A literal `+` in the email survives the encode/decode round-trip unchanged.
#[test]
fn email_verification_handles_special_chars() {
let url = EmailVerification::issue(
"https://x/v",
42,
"a+b@example.com",
SECRET,
Duration::from_secs(86_400),
);
let (_, email) = EmailVerification::verify(&url, SECRET).unwrap();
assert_eq!(email, "a+b@example.com");
}
// A password-reset link cannot be redeemed as an email verification.
#[test]
fn email_verification_rejects_password_reset_token() {
let url = PasswordReset::issue("https://x/v", 42, SECRET, Duration::from_secs(3600));
let r = EmailVerification::verify(&url, SECRET);
assert!(matches!(r, Err(AuthFlowError::WrongPurpose(_))));
}
// Magic links round-trip the email.
#[test]
fn magic_link_roundtrip() {
let url = MagicLink::issue(
"https://x/login",
"alice@example.com",
SECRET,
Duration::from_secs(900),
);
let email = MagicLink::verify(&url, SECRET).unwrap();
assert_eq!(email, "alice@example.com");
}
// A password-reset link cannot be redeemed as a magic link.
#[test]
fn magic_link_rejects_password_reset_token() {
let url = PasswordReset::issue("https://x/r", 42, SECRET, Duration::from_secs(3600));
let r = MagicLink::verify(&url, SECRET);
assert!(matches!(r, Err(AuthFlowError::WrongPurpose(_))));
}
// Replaying a consumed link fails with `AlreadyUsed`. A different link that shares
// the same cache is still accepted.
#[cfg(feature = "cache")]
#[tokio::test]
async fn magic_link_single_use_rejects_replay() {
use crate::cache::InMemoryCache;
let cache: std::sync::Arc<dyn crate::cache::Cache> =
std::sync::Arc::new(InMemoryCache::new());
let url = MagicLink::issue(
"https://x/login",
"alice@example.com",
SECRET,
Duration::from_secs(900),
);
let email = MagicLink::verify_single_use(&url, SECRET, &cache)
.await
.unwrap();
assert_eq!(email, "alice@example.com");
let replay = MagicLink::verify_single_use(&url, SECRET, &cache).await;
assert!(
matches!(replay, Err(AuthFlowError::AlreadyUsed)),
"{replay:?}"
);
let other = MagicLink::issue(
"https://x/login",
"bob@example.com",
SECRET,
Duration::from_secs(900),
);
assert!(MagicLink::verify_single_use(&other, SECRET, &cache)
.await
.is_ok());
}
// The requested key is found among several; an absent key yields `None`.
#[test]
fn extract_query_picks_right_param() {
let url = "https://x/path?a=1&b=2&c=3";
assert_eq!(extract_query(url, "b"), Some("2".to_owned()));
assert_eq!(extract_query(url, "missing"), None);
}
// Values are percent-decoded (`%40` -> `@`).
#[test]
fn extract_query_handles_url_encoded_value() {
let url = "https://x/path?email=alice%40x.com";
assert_eq!(extract_query(url, "email"), Some("alice@x.com".to_owned()));
}
// A correctly signed URL without a `purpose` parameter is rejected as `Malformed`.
#[test]
fn missing_purpose_param_treated_as_malformed() {
let url = sign(
"https://x/r?user_id=42",
SECRET,
Some(Duration::from_secs(60)),
);
let r = PasswordReset::verify(&url, SECRET);
assert_eq!(r.unwrap_err(), AuthFlowError::Malformed);
}
}