use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::{Mutex, OnceLock};
use std::time::{Duration as StdDuration, Instant};
use argon2::password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString};
use argon2::Argon2;
use chrono::{DateTime, Duration, Utc};
use rand::rngs::OsRng;
use rand::RngCore;
use sqlx::Row as _;
use crate::context::Context;
use crate::error::Error;
use crate::http::{Request, Response};
use crate::middleware::Next;
use crate::orm::Db;
/// Name of the cookie that carries the session id; read by [`authenticate`].
pub const SESSION_COOKIE: &str = "rustio_session";
/// Session lifetime in days, counted from creation in [`session::create`].
pub const SESSION_TTL_DAYS: i64 = 7;
/// Random bytes per generated token. Tokens are hex-encoded, so they are twice this many
/// characters long.
const SESSION_TOKEN_BYTES: usize = 32;
/// Role string for administrators; one of the two roles accepted by [`user::create`].
pub const ROLE_ADMIN: &str = "admin";
/// Role string for regular users; one of the two roles accepted by [`user::create`].
pub const ROLE_USER: &str = "user";
/// Heap-allocated `Send` future; the return type of the [`authenticate`] middleware closure.
type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send>>;
/// A row of the `rustio_users` table.
///
/// `Debug` is implemented by hand so that `password_hash` is never written to logs or panic
/// messages; every other field is printed normally.
#[derive(Clone, PartialEq)]
pub struct User {
    /// Primary key (`INTEGER PRIMARY KEY AUTOINCREMENT`).
    pub id: i64,
    /// Email address as stored; [`user::create`] stores it normalised.
    pub email: String,
    /// Argon2 PHC string produced by [`password::hash`].
    pub password_hash: String,
    /// Inactive users never resolve to an [`Identity`].
    pub is_active: bool,
    /// Role string; see [`User::is_admin`] for how it is interpreted.
    pub role: String,
}

impl std::fmt::Debug for User {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("User")
            .field("id", &self.id)
            .field("email", &self.email)
            // Secret material: show that a value exists without revealing it.
            .field("password_hash", &"<redacted>")
            .field("is_active", &self.is_active)
            .field("role", &self.role)
            .finish()
    }
}
impl User {
    /// Returns `true` when `role` parses as a role known to `crate::admin::rbac::Role`.
    ///
    /// This covers every admin-side role the RBAC module recognises, not only the literal
    /// [`ROLE_ADMIN`] string.
    pub fn is_admin(&self) -> bool {
        let parsed_role = crate::admin::rbac::Role::from_role_string(&self.role);
        parsed_role.is_some()
    }
}
/// The authenticated principal that [`authenticate`] attaches to a request's [`Context`].
#[derive(Debug, Clone, PartialEq)]
pub struct Identity {
/// Primary key of the backing `rustio_users` row.
pub user_id: i64,
/// The user's email address as stored in `rustio_users`.
pub email: String,
/// Result of [`User::is_admin`] at the time the identity was resolved.
pub is_admin: bool,
}
impl From<&User> for Identity {
fn from(u: &User) -> Self {
Self {
user_id: u.id,
email: u.email.clone(),
is_admin: u.is_admin(),
}
}
}
/// A server-side login session stored in `rustio_sessions`.
///
/// `Debug` is implemented by hand so the secrets it holds are never written to logs or panic
/// messages. `id` is the value of the session cookie, so leaking it would let anyone reuse the
/// session. `csrf_token` is redacted as well.
#[non_exhaustive]
#[derive(Clone, PartialEq)]
pub struct Session {
    /// Random hex token; primary key and the value of the [`SESSION_COOKIE`] cookie.
    pub id: String,
    /// Owner of the session (`rustio_users.id`).
    pub user_id: i64,
    /// The session is treated as expired once this instant is reached.
    pub expires_at: DateTime<Utc>,
    /// Per-session CSRF token, generated independently of `id`.
    pub csrf_token: String,
}

impl std::fmt::Debug for Session {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Session")
            .field("id", &"<redacted>")
            .field("user_id", &self.user_id)
            .field("expires_at", &self.expires_at)
            .field("csrf_token", &"<redacted>")
            .finish()
    }
}
pub mod password {
use super::*;
pub fn hash(password: &str) -> Result<String, Error> {
if password.is_empty() {
return Err(Error::BadRequest("password must not be empty".into()));
}
let salt = SaltString::generate(&mut OsRng);
let hash = Argon2::default()
.hash_password(password.as_bytes(), &salt)
.map_err(|e| Error::Internal(format!("password hashing failed: {e}")))?;
Ok(hash.to_string())
}
pub fn verify(password: &str, stored: &str) -> bool {
let Ok(parsed) = PasswordHash::new(stored) else {
return false;
};
Argon2::default()
.verify_password(password.as_bytes(), &parsed)
.is_ok()
}
}
/// Returns a fresh random token: [`SESSION_TOKEN_BYTES`] bytes from the OS CSPRNG, encoded as
/// lowercase hex (two characters per byte).
fn generate_token() -> String {
    use std::fmt::Write;
    let mut bytes = [0u8; SESSION_TOKEN_BYTES];
    OsRng.fill_bytes(&mut bytes);
    bytes
        .iter()
        .fold(String::with_capacity(bytes.len() * 2), |mut hex, byte| {
            // Formatting into a `String` cannot fail.
            let _ = write!(hex, "{byte:02x}");
            hex
        })
}
/// Newtype for the current session's CSRF token, inserted into the request context by
/// [`authenticate`] next to the [`Identity`].
///
/// NOTE(review): the derived `Debug` prints the token verbatim.
#[derive(Debug, Clone, PartialEq)]
pub struct CsrfToken(pub String);
pub mod csrf {
    /// Generates a new CSRF token in the same format as session ids (random lowercase hex).
    pub fn generate_token() -> String {
        super::generate_token()
    }

    /// Compares a stored CSRF token with the one a client submitted.
    ///
    /// Returns `false` if either side is empty or the lengths differ; token length is not
    /// treated as secret. Otherwise every byte pair is XOR-accumulated without short-circuiting,
    /// so the loop does not stop at the first mismatch.
    pub fn verify_token(expected: &str, provided: &str) -> bool {
        let (lhs, rhs) = (expected.as_bytes(), provided.as_bytes());
        if lhs.is_empty() || rhs.is_empty() || lhs.len() != rhs.len() {
            return false;
        }
        let difference = lhs
            .iter()
            .zip(rhs)
            .fold(0u8, |acc, (x, y)| acc | (x ^ y));
        difference == 0
    }
}
pub mod session {
    use super::*;

    /// Creates and stores a new session for `user_id`.
    ///
    /// The session id and the CSRF token are two independently generated random tokens. The
    /// session expires [`SESSION_TTL_DAYS`] days from now.
    ///
    /// # Errors
    /// Propagates any database error from the insert.
    pub async fn create(db: &Db, user_id: i64) -> Result<Session, Error> {
        let fresh = Session {
            id: generate_token(),
            csrf_token: csrf::generate_token(),
            user_id,
            expires_at: Utc::now() + Duration::days(SESSION_TTL_DAYS),
        };
        sqlx::query(
            "INSERT INTO rustio_sessions (id, user_id, expires_at, csrf_token)\n\
             VALUES (?, ?, ?, ?)",
        )
        .bind(&fresh.id)
        .bind(fresh.user_id)
        .bind(fresh.expires_at)
        .bind(&fresh.csrf_token)
        .execute(db.pool())
        .await?;
        Ok(fresh)
    }

    /// Looks up the session with the given `id` and returns it only if it has not expired.
    ///
    /// A row whose `expires_at` is at or before now is deleted on a best-effort basis, and
    /// `Ok(None)` is returned. Unknown ids also yield `Ok(None)`.
    ///
    /// # Errors
    /// Propagates database and column-decoding errors.
    pub async fn find_valid(db: &Db, id: &str) -> Result<Option<Session>, Error> {
        let maybe_row = sqlx::query(
            "SELECT id, user_id, expires_at, csrf_token\n\
             FROM rustio_sessions WHERE id = ?",
        )
        .bind(id)
        .fetch_optional(db.pool())
        .await?;
        let row = match maybe_row {
            Some(row) => row,
            None => return Ok(None),
        };
        let expires_at: DateTime<Utc> = row.try_get("expires_at")?;
        if Utc::now() >= expires_at {
            // Inline cleanup is opportunistic: a failed delete must not turn a plain
            // "expired" answer into an error.
            let _ = delete(db, id).await;
            return Ok(None);
        }
        Ok(Some(Session {
            id: row.try_get("id")?,
            user_id: row.try_get("user_id")?,
            csrf_token: row.try_get("csrf_token")?,
            expires_at,
        }))
    }

    /// Deletes the session with the given `id`. Deleting an unknown id is not an error.
    pub async fn delete(db: &Db, id: &str) -> Result<(), Error> {
        let statement = sqlx::query("DELETE FROM rustio_sessions WHERE id = ?").bind(id);
        statement.execute(db.pool()).await?;
        Ok(())
    }

    /// Deletes every session whose `expires_at` is at or before now and returns how many rows
    /// were removed.
    pub async fn sweep_expired(db: &Db) -> Result<u64, Error> {
        let cutoff = Utc::now();
        let outcome = sqlx::query("DELETE FROM rustio_sessions WHERE expires_at <= ?")
            .bind(cutoff)
            .execute(db.pool())
            .await?;
        Ok(outcome.rows_affected())
    }
}
pub mod user {
use super::*;
pub async fn create(db: &Db, email: &str, password: &str, role: &str) -> Result<User, Error> {
let email = normalise_email(email);
validate_email(&email)?;
if role != ROLE_ADMIN && role != ROLE_USER {
return Err(Error::BadRequest(format!(
"role must be `{ROLE_ADMIN}` or `{ROLE_USER}`, got `{role}`"
)));
}
let hash = password::hash(password)?;
let result = sqlx::query(
"INSERT INTO rustio_users (email, password_hash, is_active, role)
VALUES (?, ?, 1, ?)",
)
.bind(&email)
.bind(&hash)
.bind(role)
.execute(db.pool())
.await
.map_err(|e| match &e {
sqlx::Error::Database(de) if de.is_unique_violation() => {
Error::BadRequest(format!("a user with email `{email}` already exists"))
}
_ => Error::from(e),
})?;
Ok(User {
id: result.last_insert_rowid(),
email,
password_hash: hash,
is_active: true,
role: role.to_string(),
})
}
pub async fn find_by_email(db: &Db, email: &str) -> Result<Option<User>, Error> {
let email = normalise_email(email);
let row = sqlx::query(
"SELECT id, email, password_hash, is_active, role
FROM rustio_users WHERE email = ?",
)
.bind(&email)
.fetch_optional(db.pool())
.await?;
match row {
Some(r) => Ok(Some(user_from_row(&r)?)),
None => Ok(None),
}
}
pub async fn find_by_id(db: &Db, id: i64) -> Result<Option<User>, Error> {
let row = sqlx::query(
"SELECT id, email, password_hash, is_active, role
FROM rustio_users WHERE id = ?",
)
.bind(id)
.fetch_optional(db.pool())
.await?;
match row {
Some(r) => Ok(Some(user_from_row(&r)?)),
None => Ok(None),
}
}
pub async fn set_password(db: &Db, id: i64, password: &str) -> Result<(), Error> {
let hash = password::hash(password)?;
let mut tx = db.pool().begin().await?;
sqlx::query("UPDATE rustio_users SET password_hash = ? WHERE id = ?")
.bind(&hash)
.bind(id)
.execute(&mut *tx)
.await?;
sqlx::query("DELETE FROM rustio_sessions WHERE user_id = ?")
.bind(id)
.execute(&mut *tx)
.await?;
tx.commit().await?;
Ok(())
}
pub async fn set_active(db: &Db, id: i64, is_active: bool) -> Result<(), Error> {
sqlx::query("UPDATE rustio_users SET is_active = ? WHERE id = ?")
.bind(is_active)
.bind(id)
.execute(db.pool())
.await?;
Ok(())
}
pub async fn count(db: &Db) -> Result<i64, Error> {
let row = sqlx::query("SELECT COUNT(*) FROM rustio_users")
.fetch_one(db.pool())
.await?;
Ok(row.try_get(0)?)
}
fn user_from_row(r: &sqlx::sqlite::SqliteRow) -> Result<User, Error> {
Ok(User {
id: r.try_get("id")?,
email: r.try_get("email")?,
password_hash: r.try_get("password_hash")?,
is_active: r.try_get("is_active")?,
role: r.try_get("role")?,
})
}
}
/// Canonical form used for storing and looking up emails.
///
/// Strips surrounding whitespace, then applies full Unicode lowercasing (`str::to_lowercase`).
pub fn normalise_email(email: &str) -> String {
    let trimmed = email.trim();
    trimmed.to_lowercase()
}
/// Performs a lightweight syntactic sanity check on an (already normalised) email address.
///
/// This is deliberately not a full RFC 5322 parser. An address is accepted when all of the
/// following hold:
/// - it is non-empty and contains no whitespace or control characters;
/// - it contains exactly one `@`, with a non-empty local part before it;
/// - the domain after the `@` has at least two dot-separated labels, none of them empty.
///
/// So `a@b.co` passes, while `a@localhost`, `a@.com`, `a@b.`, `a@b..co` and `a@b@c.com` do not.
///
/// # Errors
/// Returns `Error::BadRequest` with a message naming the rejected address.
pub fn validate_email(email: &str) -> Result<(), Error> {
    if email.is_empty() {
        return Err(Error::BadRequest("email must not be empty".into()));
    }
    let Some((local, domain)) = email.split_once('@') else {
        return Err(Error::BadRequest(format!(
            "`{email}` is not a valid email (missing @)"
        )));
    };
    let has_forbidden_chars = email.chars().any(|c| c.is_whitespace() || c.is_control());
    // `split_once` splits at the first '@', so any further '@' ends up in `domain`.
    let domain_ok = domain.contains('.')
        && !domain.contains('@')
        && !domain.split('.').any(str::is_empty);
    if local.is_empty() || !domain_ok || has_forbidden_chars {
        return Err(Error::BadRequest(format!("`{email}` is not a valid email")));
    }
    Ok(())
}
/// Returns an Argon2 PHC hash of a fixed filler password that is never issued to anyone.
///
/// The hash is computed on first use and cached for the life of the process, so every call
/// returns the same `&'static str`. As the filler text suggests, it is meant to be verified
/// against when no real hash exists, so that request still pays the cost of an Argon2
/// verification. NOTE(review): confirm this usage against the login handler.
///
/// # Panics
/// Panics if hashing the filler fails. [`password::hash`] only fails for a non-empty input on an
/// internal Argon2 error.
pub fn dummy_password_hash() -> &'static str {
    static FILLER_HASH: OnceLock<String> = OnceLock::new();
    let cached: &'static String = FILLER_HASH.get_or_init(|| {
        password::hash("timing-attack-filler-not-a-real-password")
            .expect("dummy hash must succeed")
    });
    cached.as_str()
}
/// Per-key failure bookkeeping for [`LoginRateLimiter`].
struct FailureEntry {
/// Failures recorded for this key since it was last cleared (saturating).
count: u32,
/// Once `count` reaches the limiter's threshold, the instant the lockout ends. Below the
/// threshold it is only the time the entry was created and is not consulted.
locked_until: Instant,
}
/// In-memory, per-key login throttle.
///
/// After `max_failures` recorded failures, a key is locked out for `lockout`. State is held in
/// a `Mutex<HashMap>` inside this process, so it is not shared between processes and is lost
/// on restart.
pub struct LoginRateLimiter {
/// Failure state keyed by caller-chosen strings (see [`LoginRateLimiter::compose_key`]).
failures: Mutex<HashMap<String, FailureEntry>>,
/// Number of failures that triggers a lockout.
max_failures: u32,
/// How long a triggered lockout lasts.
lockout: StdDuration,
}
impl LoginRateLimiter {
pub const MAX_FAILURES: u32 = 5;
pub const LOCKOUT: StdDuration = StdDuration::from_secs(60);
pub fn new() -> Self {
Self::with_params(Self::MAX_FAILURES, Self::LOCKOUT)
}
pub fn with_params(max_failures: u32, lockout: StdDuration) -> Self {
Self {
failures: Mutex::new(HashMap::new()),
max_failures,
lockout,
}
}
pub fn global() -> &'static Self {
static INSTANCE: OnceLock<LoginRateLimiter> = OnceLock::new();
INSTANCE.get_or_init(LoginRateLimiter::new)
}
pub fn check(&self, key: &str) -> Result<(), StdDuration> {
let mut map = self.failures.lock().expect("rate-limiter mutex poisoned");
match map.get(key) {
Some(entry) if entry.count >= self.max_failures => {
let now = Instant::now();
if entry.locked_until > now {
Err(entry.locked_until - now)
} else {
map.remove(key);
Ok(())
}
}
_ => Ok(()),
}
}
pub fn record_failure(&self, key: &str) {
let mut map = self.failures.lock().expect("rate-limiter mutex poisoned");
let entry = map.entry(key.to_string()).or_insert(FailureEntry {
count: 0,
locked_until: Instant::now(),
});
entry.count = entry.count.saturating_add(1);
if entry.count >= self.max_failures {
entry.locked_until = Instant::now() + self.lockout;
}
}
pub fn record_success(&self, key: &str) {
self.failures
.lock()
.expect("rate-limiter mutex poisoned")
.remove(key);
}
pub fn compose_key(email: &str, ip: Option<&str>) -> String {
match ip {
Some(ip) => format!("email:{email}|ip:{ip}"),
None => format!("email:{email}"),
}
}
}
impl Default for LoginRateLimiter {
fn default() -> Self {
Self::new()
}
}
/// Resolves a session token to the owning user's [`Identity`] plus the [`Session`] itself.
///
/// Returns `None` in each of these cases:
/// - there is no token;
/// - the session is unknown or expired;
/// - the user no longer exists or is inactive;
/// - a database error occurs. Failures are treated as "anonymous" rather than propagated.
pub async fn resolve_identity_with_session(
    db: &Db,
    token: Option<&str>,
) -> Option<(Identity, Session)> {
    let Some(token) = token else {
        return None;
    };
    let Ok(Some(sess)) = session::find_valid(db, token).await else {
        return None;
    };
    let Ok(Some(owner)) = user::find_by_id(db, sess.user_id).await else {
        return None;
    };
    if owner.is_active {
        Some((Identity::from(&owner), sess))
    } else {
        None
    }
}
/// Like [`resolve_identity_with_session`], but returns only the [`Identity`].
pub async fn resolve_identity(db: &Db, token: Option<&str>) -> Option<Identity> {
    let (identity, _session) = resolve_identity_with_session(db, token).await?;
    Some(identity)
}
/// Builds middleware that attaches the caller's identity to each request.
///
/// For every request it reads the [`SESSION_COOKIE`] cookie. When the cookie resolves to a live
/// session of an active user (see [`resolve_identity_with_session`]), the session's
/// [`CsrfToken`] and the [`Identity`] are inserted into the request context. Any other request
/// is forwarded unchanged, i.e. as anonymous.
///
/// This middleware never rejects a request itself; handlers use [`require_auth`] /
/// [`require_admin`] for that.
pub fn authenticate(
    db: Db,
) -> impl Fn(Request, Next) -> BoxFuture<Result<Response, Error>> + Send + Sync + Clone + 'static {
    move |mut req, next| {
        // Each call gets its own handle so the boxed future owns everything it uses.
        let db = db.clone();
        Box::pin(async move {
            let cookie = req.cookie(SESSION_COOKIE);
            let resolved = resolve_identity_with_session(&db, cookie.as_deref()).await;
            if let Some((identity, session)) = resolved {
                let csrf = CsrfToken(session.csrf_token);
                req.ctx_mut().insert(csrf);
                req.ctx_mut().insert(identity);
            }
            next.run(req).await
        })
    }
}
/// Creates, if missing, the tables and indexes the auth and admin layers depend on.
///
/// Statements run in this order:
/// 1. `rustio_users` and `rustio_sessions`;
/// 2. a migration adding `rustio_sessions.csrf_token` to databases created before that column
///    existed;
/// 3. the `rustio_admin_actions` table and its two indexes.
///
/// Every step is guarded (`IF NOT EXISTS` or a column check), so the function is safe to call
/// repeatedly.
///
/// NOTE(review): the `ON DELETE CASCADE` clauses only take effect if SQLite foreign-key
/// enforcement is enabled on the connection. Presumably `Db` does this; confirm.
pub async fn ensure_core_tables(db: &Db) -> Result<(), Error> {
    const BASE_TABLES: [&str; 2] = [
        "CREATE TABLE IF NOT EXISTS rustio_users (\n\
         id INTEGER PRIMARY KEY AUTOINCREMENT,\n\
         email TEXT NOT NULL UNIQUE,\n\
         password_hash TEXT NOT NULL,\n\
         is_active INTEGER NOT NULL DEFAULT 1,\n\
         role TEXT NOT NULL DEFAULT 'user',\n\
         created_at TEXT NOT NULL DEFAULT (datetime('now'))\n\
         )",
        "CREATE TABLE IF NOT EXISTS rustio_sessions (\n\
         id TEXT PRIMARY KEY,\n\
         user_id INTEGER NOT NULL,\n\
         expires_at TEXT NOT NULL,\n\
         csrf_token TEXT NOT NULL DEFAULT '',\n\
         created_at TEXT NOT NULL DEFAULT (datetime('now')),\n\
         FOREIGN KEY (user_id) REFERENCES rustio_users(id) ON DELETE CASCADE\n\
         )",
    ];
    const ADMIN_AUDIT: [&str; 3] = [
        "CREATE TABLE IF NOT EXISTS rustio_admin_actions (\n\
         id INTEGER PRIMARY KEY AUTOINCREMENT,\n\
         user_id INTEGER NOT NULL,\n\
         action_type TEXT NOT NULL,\n\
         model_name TEXT NOT NULL,\n\
         object_id INTEGER NOT NULL,\n\
         timestamp TEXT NOT NULL,\n\
         ip_address TEXT NULL,\n\
         summary TEXT NOT NULL,\n\
         FOREIGN KEY (user_id) REFERENCES rustio_users(id) ON DELETE CASCADE\n\
         )",
        "CREATE INDEX IF NOT EXISTS idx_rustio_admin_actions_model_object\n\
         ON rustio_admin_actions(model_name, object_id)",
        "CREATE INDEX IF NOT EXISTS idx_rustio_admin_actions_timestamp\n\
         ON rustio_admin_actions(timestamp DESC)",
    ];

    for ddl in BASE_TABLES {
        db.execute(ddl).await?;
    }

    // Older databases may already have `rustio_sessions` without the CSRF column, in which
    // case the CREATE above was a no-op.
    let session_columns: Vec<String> =
        sqlx::query_scalar::<_, String>("SELECT name FROM pragma_table_info('rustio_sessions')")
            .fetch_all(db.pool())
            .await?;
    let has_csrf_column = session_columns.iter().any(|name| name == "csrf_token");
    if !has_csrf_column {
        db.execute("ALTER TABLE rustio_sessions ADD COLUMN csrf_token TEXT NOT NULL DEFAULT ''")
            .await?;
    }

    for ddl in ADMIN_AUDIT {
        db.execute(ddl).await?;
    }
    Ok(())
}
/// Reports whether the process is configured to run in production.
///
/// Reads `RUSTIO_ENV` and returns `true` when its value is `production` or `prod`. The
/// comparison is ASCII case-insensitive and ignores surrounding whitespace, so a value such as
/// `"Production\r"` from a CRLF-terminated `.env` file still counts. An unset or non-Unicode
/// variable yields `false`.
pub fn in_production() -> bool {
    match std::env::var("RUSTIO_ENV") {
        Ok(raw) => {
            let value = raw.trim();
            value.eq_ignore_ascii_case("production") || value.eq_ignore_ascii_case("prod")
        }
        Err(_) => false,
    }
}
/// Extracts the token from an `Authorization: Bearer <token>` request header.
///
/// The scheme name is matched case-insensitively, since auth schemes are case-insensitive per
/// RFC 9110 §11.1. One or more spaces may separate the scheme from the token, and surrounding
/// whitespace is trimmed from the token.
///
/// Returns `None` in each of these cases:
/// - the header is missing or its value cannot be read as a string;
/// - it uses a different scheme;
/// - the token is empty.
pub fn bearer_token(req: &Request) -> Option<&str> {
    let value = req.headers().get("authorization")?.to_str().ok()?;
    let (scheme, token) = value.trim_start().split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("bearer") {
        return None;
    }
    let token = token.trim();
    if token.is_empty() {
        None
    } else {
        Some(token)
    }
}
/// Returns the [`Identity`] that [`authenticate`] attached to `ctx`, or `None` for anonymous
/// requests.
pub fn identity(ctx: &Context) -> Option<&Identity> {
ctx.get::<Identity>()
}
pub fn require_auth(ctx: &Context) -> Result<&Identity, Error> {
identity(ctx).ok_or(Error::Unauthorized)
}
pub fn require_admin(ctx: &Context) -> Result<&Identity, Error> {
let id = require_auth(ctx)?;
if !id.is_admin {
return Err(Error::Forbidden);
}
Ok(id)
}
// Tests for the auth module. Pure-logic tests are synchronous. The `#[tokio::test]` tests each
// run against a fresh in-memory database created by `setup`.
#[cfg(test)]
mod tests {
use super::*;
// --- Fixtures for the Context-based guard tests ---
fn admin_identity() -> Identity {
Identity {
user_id: 1,
email: "admin@example.com".into(),
is_admin: true,
}
}
fn user_identity() -> Identity {
Identity {
user_id: 2,
email: "user@example.com".into(),
is_admin: false,
}
}
// --- identity / require_auth / require_admin ---
#[test]
fn identity_returns_none_when_absent() {
let ctx = Context::new();
assert!(identity(&ctx).is_none());
}
#[test]
fn identity_returns_reference_when_attached() {
let mut ctx = Context::new();
ctx.insert(user_identity());
assert_eq!(
identity(&ctx).map(|i| i.email.as_str()),
Some("user@example.com")
);
}
#[test]
fn require_auth_missing_returns_unauthorized() {
let ctx = Context::new();
assert!(matches!(require_auth(&ctx), Err(Error::Unauthorized)));
}
#[test]
fn require_admin_non_admin_returns_forbidden() {
let mut ctx = Context::new();
ctx.insert(user_identity());
assert!(matches!(require_admin(&ctx), Err(Error::Forbidden)));
}
#[test]
fn require_admin_admin_returns_identity() {
let mut ctx = Context::new();
ctx.insert(admin_identity());
let id = require_admin(&ctx).unwrap();
assert!(id.is_admin);
}
// --- password hashing and verification ---
#[test]
fn hash_then_verify_succeeds() {
let h = password::hash("correct horse battery staple").unwrap();
assert!(password::verify("correct horse battery staple", &h));
}
#[test]
fn verify_wrong_password_fails() {
let h = password::hash("real").unwrap();
assert!(!password::verify("fake", &h));
}
#[test]
fn verify_invalid_hash_returns_false_without_panic() {
assert!(!password::verify("anything", ""));
assert!(!password::verify("anything", "not a phc string"));
assert!(!password::verify("anything", "$argon2id$v=19$m=1"));
}
#[test]
fn hash_rejects_empty_password() {
assert!(matches!(password::hash(""), Err(Error::BadRequest(_))));
}
#[test]
fn hash_is_salted_so_same_input_produces_different_hash() {
let a = password::hash("same").unwrap();
let b = password::hash("same").unwrap();
assert_ne!(a, b, "identical inputs must produce different hashes");
assert!(password::verify("same", &a));
assert!(password::verify("same", &b));
}
// --- email normalisation and validation ---
#[test]
fn normalise_email_trims_and_lowercases() {
assert_eq!(
normalise_email(" Alice@EXAMPLE.com "),
"alice@example.com"
);
}
#[test]
fn validate_email_accepts_reasonable_forms() {
assert!(validate_email("a@b.co").is_ok());
assert!(validate_email("alice.smith+tag@example.co.uk").is_ok());
}
#[test]
fn validate_email_rejects_bad_forms() {
assert!(validate_email("").is_err());
assert!(validate_email("no-at-sign").is_err());
assert!(validate_email("@no-local").is_err());
assert!(validate_email("no-domain@").is_err());
assert!(validate_email("no-dot@localhost").is_err());
}
// --- internal token generator ---
#[test]
fn generate_token_is_stable_length_and_hex() {
let t = generate_token();
assert_eq!(t.len(), SESSION_TOKEN_BYTES * 2);
assert!(t.chars().all(|c| c.is_ascii_hexdigit()));
}
#[test]
fn generate_token_does_not_repeat() {
let a = generate_token();
let b = generate_token();
assert_ne!(a, b);
}
// Fresh in-memory database with the core tables already created.
async fn setup() -> Db {
let db = Db::memory().await.unwrap();
ensure_core_tables(&db).await.unwrap();
db
}
// --- DB-backed user and session tests ---
#[tokio::test]
async fn user_create_round_trips() {
let db = setup().await;
let u = user::create(&db, "Admin@Example.com", "hunter2", ROLE_ADMIN)
.await
.unwrap();
assert_eq!(u.email, "admin@example.com");
assert!(u.is_admin());
assert!(u.is_active);
let lookup = user::find_by_email(&db, "ADMIN@example.com")
.await
.unwrap()
.unwrap();
assert_eq!(lookup.id, u.id);
assert!(password::verify("hunter2", &lookup.password_hash));
}
// Any role string recognised by the RBAC module (not only "admin") counts as admin here.
#[test]
fn is_admin_recognises_0_10_role_strings() {
assert!(user_with_role("admin").is_admin());
assert!(!user_with_role("user").is_admin());
assert!(!user_with_role("").is_admin());
assert!(user_with_role("superadmin").is_admin());
assert!(user_with_role("restricted_admin").is_admin());
assert!(user_with_role("editor").is_admin());
assert!(user_with_role("viewer").is_admin());
assert!(!user_with_role("nobody").is_admin());
}
// In-memory `User` with the given role; not persisted.
fn user_with_role(role: &str) -> User {
User {
id: 1,
email: "t@example.com".into(),
password_hash: "x".into(),
is_active: true,
role: role.into(),
}
}
#[tokio::test]
async fn user_create_rejects_duplicate_email() {
let db = setup().await;
user::create(&db, "a@b.co", "pw", ROLE_USER).await.unwrap();
let err = user::create(&db, "a@b.co", "pw2", ROLE_USER).await;
assert!(matches!(err, Err(Error::BadRequest(_))));
}
#[tokio::test]
async fn user_create_rejects_unknown_role() {
let db = setup().await;
let err = user::create(&db, "a@b.co", "pw", "emperor").await;
assert!(matches!(err, Err(Error::BadRequest(_))));
}
#[tokio::test]
async fn set_password_changes_verifiable_hash() {
let db = setup().await;
let u = user::create(&db, "a@b.co", "old", ROLE_USER).await.unwrap();
user::set_password(&db, u.id, "new").await.unwrap();
let reloaded = user::find_by_id(&db, u.id).await.unwrap().unwrap();
assert!(!password::verify("old", &reloaded.password_hash));
assert!(password::verify("new", &reloaded.password_hash));
}
#[tokio::test]
async fn set_active_toggles_flag() {
let db = setup().await;
let u = user::create(&db, "a@b.co", "pw", ROLE_USER).await.unwrap();
user::set_active(&db, u.id, false).await.unwrap();
let reloaded = user::find_by_id(&db, u.id).await.unwrap().unwrap();
assert!(!reloaded.is_active);
}
#[tokio::test]
async fn session_create_and_find_returns_live_session() {
let db = setup().await;
let u = user::create(&db, "a@b.co", "pw", ROLE_USER).await.unwrap();
let s = session::create(&db, u.id).await.unwrap();
let found = session::find_valid(&db, &s.id).await.unwrap().unwrap();
assert_eq!(found.user_id, u.id);
assert_eq!(found.id, s.id);
assert!(found.expires_at > Utc::now());
}
#[tokio::test]
async fn session_lookup_rejects_unknown_token() {
let db = setup().await;
let out = session::find_valid(&db, "deadbeef").await.unwrap();
assert!(out.is_none());
}
#[tokio::test]
async fn session_lookup_rejects_expired_session() {
let db = setup().await;
let u = user::create(&db, "a@b.co", "pw", ROLE_USER).await.unwrap();
let token = generate_token();
// Inserted directly so the row is already expired. csrf_token falls back to the column
// default (''). The same pattern is used by the other expired-session tests below.
sqlx::query("INSERT INTO rustio_sessions (id, user_id, expires_at) VALUES (?, ?, ?)")
.bind(&token)
.bind(u.id)
.bind(Utc::now() - Duration::seconds(1))
.execute(db.pool())
.await
.unwrap();
let out = session::find_valid(&db, &token).await.unwrap();
assert!(out.is_none(), "expired sessions must not validate");
}
#[tokio::test]
async fn session_delete_invalidates_lookup() {
let db = setup().await;
let u = user::create(&db, "a@b.co", "pw", ROLE_USER).await.unwrap();
let s = session::create(&db, u.id).await.unwrap();
session::delete(&db, &s.id).await.unwrap();
assert!(session::find_valid(&db, &s.id).await.unwrap().is_none());
}
#[tokio::test]
async fn sweep_expired_removes_only_expired() {
let db = setup().await;
let u = user::create(&db, "a@b.co", "pw", ROLE_USER).await.unwrap();
let live = session::create(&db, u.id).await.unwrap();
let dead_token = generate_token();
sqlx::query("INSERT INTO rustio_sessions (id, user_id, expires_at) VALUES (?, ?, ?)")
.bind(&dead_token)
.bind(u.id)
.bind(Utc::now() - Duration::seconds(1))
.execute(db.pool())
.await
.unwrap();
let removed = session::sweep_expired(&db).await.unwrap();
assert_eq!(removed, 1);
assert!(session::find_valid(&db, &live.id).await.unwrap().is_some());
assert!(session::find_valid(&db, &dead_token)
.await
.unwrap()
.is_none());
}
// Relies on SQLite foreign-key enforcement being enabled on the connection.
#[tokio::test]
async fn deleting_user_cascades_to_sessions() {
let db = setup().await;
let u = user::create(&db, "a@b.co", "pw", ROLE_USER).await.unwrap();
let s = session::create(&db, u.id).await.unwrap();
assert!(session::find_valid(&db, &s.id).await.unwrap().is_some());
sqlx::query("DELETE FROM rustio_users WHERE id = ?")
.bind(u.id)
.execute(db.pool())
.await
.unwrap();
let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM rustio_sessions")
.fetch_one(db.pool())
.await
.unwrap();
assert_eq!(
count, 0,
"FK cascade must have removed the orphan session; is PRAGMA foreign_keys on?"
);
}
// `setup` already ran ensure_core_tables once; it is run twice more here.
#[tokio::test]
async fn ensure_core_tables_is_idempotent() {
let db = setup().await; ensure_core_tables(&db).await.unwrap();
ensure_core_tables(&db).await.unwrap();
assert_eq!(user::count(&db).await.unwrap(), 0);
}
// Creates the single fixture user `u@example.com` with the given role.
async fn seeded_user(db: &Db, role: &str) -> User {
user::create(db, "u@example.com", "pw", role).await.unwrap()
}
// --- resolve_identity ---
#[tokio::test]
async fn resolve_identity_none_cookie_returns_none() {
let db = setup().await;
assert!(resolve_identity(&db, None).await.is_none());
}
#[tokio::test]
async fn resolve_identity_unknown_token_returns_none() {
let db = setup().await;
assert!(resolve_identity(&db, Some("not-a-real-token"))
.await
.is_none());
}
#[tokio::test]
async fn resolve_identity_expired_session_returns_none() {
let db = setup().await;
let u = seeded_user(&db, ROLE_USER).await;
let token = generate_token();
sqlx::query("INSERT INTO rustio_sessions (id, user_id, expires_at) VALUES (?, ?, ?)")
.bind(&token)
.bind(u.id)
.bind(Utc::now() - Duration::seconds(1))
.execute(db.pool())
.await
.unwrap();
assert!(resolve_identity(&db, Some(&token)).await.is_none());
}
#[tokio::test]
async fn resolve_identity_inactive_user_returns_none() {
let db = setup().await;
let u = seeded_user(&db, ROLE_USER).await;
user::set_active(&db, u.id, false).await.unwrap();
let s = session::create(&db, u.id).await.unwrap();
assert!(
resolve_identity(&db, Some(&s.id)).await.is_none(),
"inactive users must not resolve to an Identity"
);
}
#[tokio::test]
async fn resolve_identity_deleted_user_returns_none() {
let db = setup().await;
let u = seeded_user(&db, ROLE_USER).await;
let s = session::create(&db, u.id).await.unwrap();
sqlx::query("DELETE FROM rustio_users WHERE id = ?")
.bind(u.id)
.execute(db.pool())
.await
.unwrap();
assert!(resolve_identity(&db, Some(&s.id)).await.is_none());
}
#[tokio::test]
async fn resolve_identity_valid_admin_session_attaches_admin_identity() {
let db = setup().await;
let u = seeded_user(&db, ROLE_ADMIN).await;
let s = session::create(&db, u.id).await.unwrap();
let id = resolve_identity(&db, Some(&s.id)).await.unwrap();
assert_eq!(id.user_id, u.id);
assert!(id.is_admin);
}
#[tokio::test]
async fn resolve_identity_valid_user_session_attaches_non_admin_identity() {
let db = setup().await;
let u = seeded_user(&db, ROLE_USER).await;
let s = session::create(&db, u.id).await.unwrap();
let id = resolve_identity(&db, Some(&s.id)).await.unwrap();
assert_eq!(id.user_id, u.id);
assert!(!id.is_admin);
}
#[tokio::test]
async fn changing_password_invalidates_all_user_sessions() {
let db = setup().await;
let u = seeded_user(&db, ROLE_USER).await;
let s1 = session::create(&db, u.id).await.unwrap();
let s2 = session::create(&db, u.id).await.unwrap();
assert!(session::find_valid(&db, &s1.id).await.unwrap().is_some());
assert!(session::find_valid(&db, &s2.id).await.unwrap().is_some());
user::set_password(&db, u.id, "new password").await.unwrap();
let remaining: i64 =
sqlx::query_scalar("SELECT COUNT(*) FROM rustio_sessions WHERE user_id = ?")
.bind(u.id)
.fetch_one(db.pool())
.await
.unwrap();
assert_eq!(
remaining, 0,
"password change must wipe every live session for the user"
);
assert!(session::find_valid(&db, &s1.id).await.unwrap().is_none());
assert!(session::find_valid(&db, &s2.id).await.unwrap().is_none());
}
#[tokio::test]
async fn find_valid_cleans_up_expired_row_inline() {
let db = setup().await;
let u = seeded_user(&db, ROLE_USER).await;
let token = generate_token();
sqlx::query("INSERT INTO rustio_sessions (id, user_id, expires_at) VALUES (?, ?, ?)")
.bind(&token)
.bind(u.id)
.bind(Utc::now() - Duration::seconds(1))
.execute(db.pool())
.await
.unwrap();
assert!(session::find_valid(&db, &token).await.unwrap().is_none());
let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM rustio_sessions WHERE id = ?")
.bind(&token)
.fetch_one(db.pool())
.await
.unwrap();
assert_eq!(count, 0, "find_valid must purge expired rows inline");
}
// --- LoginRateLimiter ---
#[test]
fn rate_limiter_allows_up_to_threshold() {
let limiter = LoginRateLimiter::with_params(3, StdDuration::from_secs(60));
assert!(limiter.check("alice@example.com").is_ok());
limiter.record_failure("alice@example.com");
limiter.record_failure("alice@example.com");
assert!(limiter.check("alice@example.com").is_ok());
}
#[test]
fn rate_limiter_locks_out_at_threshold() {
let limiter = LoginRateLimiter::with_params(3, StdDuration::from_secs(60));
for _ in 0..3 {
limiter.record_failure("alice@example.com");
}
let result = limiter.check("alice@example.com");
assert!(result.is_err(), "3rd failure must trip the lockout");
let remaining = result.unwrap_err();
assert!(remaining > StdDuration::ZERO);
assert!(remaining <= StdDuration::from_secs(60));
}
#[test]
fn rate_limiter_resets_on_successful_login() {
let limiter = LoginRateLimiter::with_params(3, StdDuration::from_secs(60));
for _ in 0..3 {
limiter.record_failure("alice@example.com");
}
assert!(limiter.check("alice@example.com").is_err());
limiter.record_success("alice@example.com");
assert!(
limiter.check("alice@example.com").is_ok(),
"a successful login must clear the lockout counter"
);
}
// Uses a 50 ms lockout and an 80 ms sleep to keep the test fast.
#[tokio::test]
async fn rate_limiter_lockout_expires_after_duration() {
let limiter = LoginRateLimiter::with_params(3, StdDuration::from_millis(50));
for _ in 0..3 {
limiter.record_failure("bob@example.com");
}
assert!(limiter.check("bob@example.com").is_err());
tokio::time::sleep(StdDuration::from_millis(80)).await;
assert!(
limiter.check("bob@example.com").is_ok(),
"lockout must lift after the configured duration"
);
}
// --- compose_key ---
#[test]
fn compose_key_email_only_is_stable() {
let k = LoginRateLimiter::compose_key("alice@example.com", None);
assert_eq!(k, "email:alice@example.com");
}
#[test]
fn compose_key_with_ip_is_distinct_from_email_only() {
let a = LoginRateLimiter::compose_key("alice@example.com", None);
let b = LoginRateLimiter::compose_key("alice@example.com", Some("203.0.113.5"));
assert_ne!(a, b);
assert_eq!(b, "email:alice@example.com|ip:203.0.113.5");
}
#[test]
fn compose_key_distinct_ips_produce_distinct_keys() {
let a = LoginRateLimiter::compose_key("a@b.co", Some("10.0.0.1"));
let b = LoginRateLimiter::compose_key("a@b.co", Some("10.0.0.2"));
assert_ne!(a, b);
}
// --- CSRF tokens ---
#[test]
fn csrf_generate_returns_hex_of_expected_length() {
let t = csrf::generate_token();
// 64 == SESSION_TOKEN_BYTES * 2 (hex encoding).
assert_eq!(t.len(), 64);
assert!(t.chars().all(|c| c.is_ascii_hexdigit()));
}
#[test]
fn csrf_generate_produces_unique_tokens() {
let a = csrf::generate_token();
let b = csrf::generate_token();
assert_ne!(a, b);
}
#[test]
fn csrf_verify_matching_returns_true() {
let t = csrf::generate_token();
assert!(csrf::verify_token(&t, &t));
}
#[test]
fn csrf_verify_mismatched_returns_false() {
let t = csrf::generate_token();
let other = csrf::generate_token();
assert!(!csrf::verify_token(&t, &other));
}
#[test]
fn csrf_verify_empty_either_side_returns_false() {
let t = csrf::generate_token();
assert!(!csrf::verify_token("", &t));
assert!(!csrf::verify_token(&t, ""));
assert!(!csrf::verify_token("", ""));
}
#[test]
fn csrf_verify_rejects_different_lengths() {
assert!(!csrf::verify_token("abc", "abcd"));
assert!(!csrf::verify_token("abcd", "abc"));
}
#[test]
fn csrf_verify_rejects_single_byte_difference() {
let a = "deadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef";
let mut b = String::from(a);
// Replace only the last character ('f' -> '0').
b.pop();
b.push('0');
assert!(!csrf::verify_token(a, &b));
}
#[tokio::test]
async fn session_create_generates_unique_csrf_per_session() {
let db = setup().await;
let u = seeded_user(&db, ROLE_USER).await;
let s1 = session::create(&db, u.id).await.unwrap();
let s2 = session::create(&db, u.id).await.unwrap();
assert_eq!(s1.csrf_token.len(), 64);
assert_ne!(
s1.csrf_token, s2.csrf_token,
"each session must get an independent CSRF token"
);
assert_ne!(
s1.csrf_token, s1.id,
"session id and csrf token must not be the same value"
);
}
#[tokio::test]
async fn session_find_valid_returns_csrf_token() {
let db = setup().await;
let u = seeded_user(&db, ROLE_USER).await;
let s = session::create(&db, u.id).await.unwrap();
let found = session::find_valid(&db, &s.id).await.unwrap().unwrap();
assert_eq!(found.csrf_token, s.csrf_token);
}
#[tokio::test]
async fn resolve_identity_with_session_exposes_csrf() {
let db = setup().await;
let u = seeded_user(&db, ROLE_ADMIN).await;
let s = session::create(&db, u.id).await.unwrap();
let (id, sess) = resolve_identity_with_session(&db, Some(&s.id))
.await
.unwrap();
assert_eq!(id.user_id, u.id);
assert_eq!(sess.csrf_token, s.csrf_token);
}
#[test]
fn rate_limiter_tracks_keys_independently() {
let limiter = LoginRateLimiter::with_params(2, StdDuration::from_secs(60));
limiter.record_failure("alice@example.com");
limiter.record_failure("alice@example.com");
assert!(limiter.check("alice@example.com").is_err());
assert!(limiter.check("bob@example.com").is_ok());
}
// --- dummy password hash ---
#[test]
fn dummy_password_hash_is_stable_across_calls() {
let a = dummy_password_hash();
let b = dummy_password_hash();
// Same pointer: the hash is computed once and cached.
assert!(std::ptr::eq(a, b));
}
#[test]
fn dummy_password_hash_is_a_valid_phc_string() {
assert!(PasswordHash::new(dummy_password_hash()).is_ok());
}
#[test]
fn verify_against_dummy_hash_rejects_arbitrary_inputs() {
assert!(!password::verify("", dummy_password_hash()));
assert!(!password::verify("wrong password", dummy_password_hash()));
assert!(!password::verify("admin", dummy_password_hash()));
}
#[tokio::test]
async fn logout_deletes_session_so_later_requests_are_anonymous() {
let db = setup().await;
let u = seeded_user(&db, ROLE_USER).await;
let s = session::create(&db, u.id).await.unwrap();
assert!(resolve_identity(&db, Some(&s.id)).await.is_some());
session::delete(&db, &s.id).await.unwrap();
assert!(
resolve_identity(&db, Some(&s.id)).await.is_none(),
"deleted session must not resolve"
);
}
}