use crate::pool::Pool;
use crate::pool::SessionPool;
use async_trait::async_trait;
use axum_session_auth::*;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
/// Per-request auth session: users are [`User`], session user ids are `i64`
/// (the id passed to `load_user`), `SessionPool` is the session store and
/// `Pool` is the database handle handed to `load_user`.
pub type Session = axum_session_auth::AuthSession<User, i64, SessionPool, Pool>;
/// Tower layer that provides [`Session`] to handlers, with the same type parameters.
pub type AuthLayer = axum_session_auth::AuthSessionLayer<User, i64, SessionPool, Pool>;
/// The user attached to a session, as built by `load_user` from the `users`
/// row plus the union of direct and role-derived permission tokens.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct User {
/// `users.id`. NOTE(review): `i32` here, while every query in this module binds
/// and decodes user ids as `i64`.
pub id: i32,
/// Marks the anonymous user; drives `is_authenticated` / `is_active` / `is_anonymous`.
pub anonymous: bool,
pub username: String,
pub display_name: Option<String>,
pub email: Option<String>,
pub avatar_url: Option<String>,
pub html_url: Option<String>,
/// Effective permission tokens checked by `HasPermission::has`.
pub permissions: HashSet<String>,
}
/// Single-column row (`token`) decoded from the permission union query in `load_user`.
#[derive(sqlx::FromRow, Clone)]
pub struct SqlPermissionTokens {
pub token: String,
}
#[async_trait]
impl Authentication<User, i64, Pool> for User {
    /// Loads user `userid` together with its effective permissions.
    ///
    /// Permissions are the union of tokens granted directly
    /// (`user_permissions`) and tokens granted through any assigned role
    /// (`role_permissions` joined via `user_roles`).
    ///
    /// # Errors
    /// Fails when no pool is supplied, the user row does not exist, a query
    /// fails, or the stored id does not fit in [`User::id`] (`i32`).
    async fn load_user(userid: i64, pool: Option<&Pool>) -> Result<User, anyhow::Error> {
        let db = pool.ok_or_else(|| anyhow::anyhow!("load_user called without a database pool"))?;

        // `id` is decoded as i64 like every other `users.id` query in this
        // module; the narrowing to `User::id` is checked explicitly below
        // instead of being left to the driver's integer decoding.
        #[derive(sqlx::FromRow, Clone)]
        struct SqlUser {
            id: i64,
            anonymous: bool,
            username: String,
            display_name: Option<String>,
            email: Option<String>,
            avatar_url: Option<String>,
            html_url: Option<String>,
        }

        let sqluser = sqlx::query_as::<_, SqlUser>(
            "SELECT id, anonymous, username, display_name, email, avatar_url, html_url \
             FROM users WHERE id = $1",
        )
        .bind(userid)
        .fetch_one(db)
        .await?;

        // `$1` is referenced twice; both occurrences take the single bound value.
        let sql_user_perms = sqlx::query_as::<_, SqlPermissionTokens>(
            "SELECT token FROM user_permissions WHERE user_id = $1 \
             UNION \
             SELECT rp.token FROM role_permissions rp \
             JOIN user_roles ur ON ur.role_id = rp.role_id \
             WHERE ur.user_id = $1",
        )
        .bind(userid)
        .fetch_all(db)
        .await?;

        let id = i32::try_from(sqluser.id).map_err(|_| {
            anyhow::anyhow!("user id {} does not fit in User::id (i32)", sqluser.id)
        })?;

        Ok(User {
            id,
            anonymous: sqluser.anonymous,
            username: sqluser.username,
            display_name: sqluser.display_name,
            email: sqluser.email,
            avatar_url: sqluser.avatar_url,
            html_url: sqluser.html_url,
            permissions: sql_user_perms.into_iter().map(|x| x.token).collect(),
        })
    }

    /// Every non-anonymous user counts as authenticated.
    fn is_authenticated(&self) -> bool {
        !self.anonymous
    }

    /// Every non-anonymous user counts as active.
    fn is_active(&self) -> bool {
        !self.anonymous
    }

    fn is_anonymous(&self) -> bool {
        self.anonymous
    }
}
#[async_trait]
impl HasPermission<Pool> for User {
    /// Reports whether `perm` is one of the tokens loaded into
    /// [`User::permissions`]; the pool is never consulted.
    async fn has(&self, perm: &str, _pool: &Option<&Pool>) -> bool {
        let granted = &self.permissions;
        granted.contains(perm)
    }
}
/// Ids of the built-in roles in the `roles` table.
/// NOTE(review): assumes rows with these ids are seeded by migrations — confirm.
pub mod role {
/// Administrator role; granted by the bootstrap/first-admin helpers.
pub const ADMIN: i64 = 1;
/// Default role granted to new accounts by `assign_default_role`.
pub const MEMBER: i64 = 2;
/// Guest role (not referenced elsewhere in this file).
pub const GUEST: i64 = 3;
}
/// Grants the admin role to `user_id` when `email` matches the configured
/// bootstrap-admin address.
///
/// The comparison is ASCII case-insensitive and ignores surrounding
/// whitespace on both sides. This matches `sync_bootstrap_admin`, which trims
/// before matching. Does nothing when `email` is `None` or no non-blank
/// bootstrap address is configured.
///
/// # Errors
/// Propagates database errors from [`grant_role`].
pub async fn maybe_bootstrap_admin(
    db: &Pool,
    user_id: i64,
    email: Option<&str>,
) -> anyhow::Result<()> {
    let Some(email) = email else { return Ok(()) };
    let Some(target) = bootstrap_admin_email() else {
        return Ok(());
    };
    // Env values often carry stray whitespace/newlines; a blank target must
    // never match (it would otherwise equal a blank email).
    let target = target.trim();
    if target.is_empty() {
        return Ok(());
    }
    if target.eq_ignore_ascii_case(email.trim()) {
        grant_role(db, user_id, role::ADMIN).await?;
    }
    Ok(())
}
/// Promotes `user_id` to admin when no user currently holds the admin role
/// (first-signup bootstrap).
///
/// The "are there any admins?" check and the insert are one statement. Two
/// concurrent first signups therefore cannot each observe "no admins" in a
/// separate round-trip and both be promoted.
/// NOTE(review): under Postgres READ COMMITTED a narrow window remains between
/// concurrent statements; closing it fully needs a lock or a constraint.
///
/// # Errors
/// Propagates database errors.
pub async fn maybe_grant_first_admin(db: &Pool, user_id: i64) -> anyhow::Result<()> {
    // `$2` (the admin role id) is used twice; both occurrences take one bind.
    let promoted = sqlx::query(
        "INSERT INTO user_roles (user_id, role_id) \
         SELECT $1, $2 \
         WHERE NOT EXISTS (SELECT 1 FROM user_roles WHERE role_id = $2)",
    )
    .bind(user_id)
    .bind(role::ADMIN)
    .execute(db)
    .await?
    .rows_affected();
    if promoted > 0 {
        eprintln!(
            "[startup] bootstrap-admin: promoted user {user_id} to admin (first signup, \
             no existing admins)"
        );
    }
    Ok(())
}
/// Startup hook: makes sure the live (not soft-deleted) account whose email
/// matches the configured bootstrap-admin address holds the admin role.
///
/// Returns `Ok(())` without changes when no address is configured. When no
/// matching account exists yet, it logs that fact and returns `Ok(())`.
///
/// # Errors
/// Propagates database errors.
pub async fn sync_bootstrap_admin(db: &Pool) -> anyhow::Result<()> {
    let email = match bootstrap_admin_email() {
        Some(configured) => configured,
        None => return Ok(()),
    };

    let found: Option<(i64,)> = sqlx::query_as(
        "SELECT id FROM users \
         WHERE LOWER(email) = LOWER($1) AND deleted_at IS NULL \
         LIMIT 1",
    )
    .bind(email.trim())
    .fetch_optional(db)
    .await?;

    let user_id = match found {
        Some((id,)) => id,
        None => {
            eprintln!(
                "[startup] bootstrap-admin: no account for {email} yet — they'll be promoted on signup"
            );
            return Ok(());
        }
    };

    let existing_grants: i64 =
        sqlx::query_scalar("SELECT COUNT(*) FROM user_roles WHERE user_id = $1 AND role_id = $2")
            .bind(user_id)
            .bind(role::ADMIN)
            .fetch_one(db)
            .await?;
    if existing_grants != 0 {
        // Already an admin: nothing to do, nothing to log.
        return Ok(());
    }

    grant_role(db, user_id, role::ADMIN).await?;
    eprintln!("[startup] bootstrap-admin: granted admin role to {email} (user id {user_id})");
    Ok(())
}
/// Returns the configured bootstrap-admin email, if any.
///
/// `DX_AUTH_BOOTSTRAP_ADMIN_EMAIL` is checked first, then
/// `BOOTSTRAP_ADMIN_EMAIL`. Values are trimmed. A variable that is unset,
/// empty or whitespace-only is skipped, so it no longer shadows the fallback.
fn bootstrap_admin_email() -> Option<String> {
    ["DX_AUTH_BOOTSTRAP_ADMIN_EMAIL", "BOOTSTRAP_ADMIN_EMAIL"]
        .iter()
        .filter_map(|key| std::env::var(key).ok())
        .map(|value| value.trim().to_string())
        .find(|value| !value.is_empty())
}
/// Gives `user_id` the default [`role::MEMBER`] role; a no-op if already held.
///
/// # Errors
/// Propagates database errors from [`grant_role`].
pub async fn assign_default_role(db: &Pool, user_id: i64) -> anyhow::Result<()> {
    let default_role = role::MEMBER;
    grant_role(db, user_id, default_role).await
}
/// Adds `role_id` to `user_id`; an existing `(user_id, role_id)` pair is left
/// untouched, so the call is idempotent.
///
/// # Errors
/// Propagates database errors.
pub async fn grant_role(db: &Pool, user_id: i64, role_id: i64) -> anyhow::Result<()> {
    const SQL: &str = "INSERT INTO user_roles (user_id, role_id) VALUES ($1, $2) \
                       ON CONFLICT (user_id, role_id) DO NOTHING";
    let _ = sqlx::query(SQL)
        .bind(user_id)
        .bind(role_id)
        .execute(db)
        .await?;
    Ok(())
}
/// Removes `role_id` from `user_id`; succeeds even if the user did not hold it.
///
/// # Errors
/// Propagates database errors.
pub async fn revoke_role(db: &Pool, user_id: i64, role_id: i64) -> anyhow::Result<()> {
    const SQL: &str = "DELETE FROM user_roles WHERE user_id = $1 AND role_id = $2";
    let _ = sqlx::query(SQL)
        .bind(user_id)
        .bind(role_id)
        .execute(db)
        .await?;
    Ok(())
}
/// Replaces the complete role set of `user_id` with `role_ids`, atomically.
///
/// Duplicate ids in `role_ids` are tolerated. The `(user_id, role_id)`
/// uniqueness that `grant_role` relies on would otherwise make a repeated id
/// fail the insert and roll back the whole replacement.
///
/// # Errors
/// Propagates database errors (e.g. an unknown role id violating a foreign key).
pub async fn set_user_roles(db: &Pool, user_id: i64, role_ids: &[i64]) -> anyhow::Result<()> {
    let mut tx = db.begin().await?;
    sqlx::query("DELETE FROM user_roles WHERE user_id = $1")
        .bind(user_id)
        .execute(&mut *tx)
        .await?;
    for &rid in role_ids {
        sqlx::query(
            "INSERT INTO user_roles (user_id, role_id) VALUES ($1, $2) \
             ON CONFLICT (user_id, role_id) DO NOTHING",
        )
        .bind(user_id)
        .bind(rid)
        .execute(&mut *tx)
        .await?;
    }
    tx.commit().await?;
    Ok(())
}
/// A row of the `roles` table, as returned by `list_roles`.
#[derive(Debug, Clone, sqlx::FromRow, Serialize, Deserialize, PartialEq)]
pub struct RoleRow {
pub id: i64,
pub name: String,
pub description: Option<String>,
/// Built-in role; `update_role` and `delete_role` refuse to modify it.
pub is_system: bool,
}
/// Returns every role, ordered by id.
///
/// # Errors
/// Propagates database errors.
pub async fn list_roles(db: &Pool) -> anyhow::Result<Vec<RoleRow>> {
    sqlx::query_as::<_, RoleRow>("SELECT id, name, description, is_system FROM roles ORDER BY id")
        .fetch_all(db)
        .await
        .map_err(Into::into)
}
pub async fn create_role(
db: &Pool,
name: &str,
description: Option<&str>,
permissions: &[String],
) -> anyhow::Result<i64> {
let name = name.trim();
if name.is_empty() {
anyhow::bail!("Role name is required.");
}
let mut tx = db.begin().await?;
let inserted: Result<(i64,), sqlx::Error> = sqlx::query_as(
"INSERT INTO roles (name, description, is_system) VALUES ($1, $2, false) RETURNING id",
)
.bind(name)
.bind(description)
.fetch_one(&mut *tx)
.await;
let (role_id,) = match inserted {
Ok(row) => row,
Err(sqlx::Error::Database(dberr)) if dberr.is_unique_violation() => {
anyhow::bail!("A role with that name already exists.");
}
Err(e) => return Err(e.into()),
};
for token in dedup_tokens(permissions) {
sqlx::query("INSERT INTO role_permissions (role_id, token) VALUES ($1, $2)")
.bind(role_id)
.bind(token)
.execute(&mut *tx)
.await?;
}
tx.commit().await?;
Ok(role_id)
}
pub async fn update_role(
db: &Pool,
role_id: i64,
name: &str,
description: Option<&str>,
permissions: &[String],
) -> anyhow::Result<()> {
let name = name.trim();
if name.is_empty() {
anyhow::bail!("Role name is required.");
}
let mut tx = db.begin().await?;
let row: Option<(bool,)> = sqlx::query_as("SELECT is_system FROM roles WHERE id = $1")
.bind(role_id)
.fetch_optional(&mut *tx)
.await?;
match row {
None => anyhow::bail!("Role not found."),
Some((true,)) => anyhow::bail!("System roles are read-only."),
_ => {}
}
match sqlx::query("UPDATE roles SET name = $1, description = $2 WHERE id = $3")
.bind(name)
.bind(description)
.bind(role_id)
.execute(&mut *tx)
.await
{
Ok(_) => {}
Err(sqlx::Error::Database(dberr)) if dberr.is_unique_violation() => {
anyhow::bail!("A role with that name already exists.");
}
Err(e) => return Err(e.into()),
}
sqlx::query("DELETE FROM role_permissions WHERE role_id = $1")
.bind(role_id)
.execute(&mut *tx)
.await?;
for token in dedup_tokens(permissions) {
sqlx::query("INSERT INTO role_permissions (role_id, token) VALUES ($1, $2)")
.bind(role_id)
.bind(token)
.execute(&mut *tx)
.await?;
}
tx.commit().await?;
Ok(())
}
/// Deletes a custom role, its user assignments and its permission tokens, in
/// one transaction.
///
/// # Errors
/// User-facing messages for an unknown role or a system role; database
/// errors are propagated.
pub async fn delete_role(db: &Pool, role_id: i64) -> anyhow::Result<()> {
    let mut tx = db.begin().await?;
    let existing: Option<(bool,)> = sqlx::query_as("SELECT is_system FROM roles WHERE id = $1")
        .bind(role_id)
        .fetch_optional(&mut *tx)
        .await?;
    let (is_system,) = existing.ok_or_else(|| anyhow::anyhow!("Role not found."))?;
    if is_system {
        anyhow::bail!("System roles are read-only.");
    }

    // Dependents first, then the role row itself.
    for statement in [
        "DELETE FROM user_roles WHERE role_id = $1",
        "DELETE FROM role_permissions WHERE role_id = $1",
        "DELETE FROM roles WHERE id = $1",
    ] {
        sqlx::query(statement)
            .bind(role_id)
            .execute(&mut *tx)
            .await?;
    }
    tx.commit().await?;
    Ok(())
}
/// Normalises permission tokens: trims each one, drops blanks and removes
/// duplicates, keeping the order of first occurrence.
fn dedup_tokens(tokens: &[String]) -> Vec<String> {
    let mut seen: HashSet<&str> = HashSet::with_capacity(tokens.len());
    tokens
        .iter()
        .map(|raw| raw.trim())
        .filter(|token| !token.is_empty() && seen.insert(*token))
        .map(str::to_owned)
        .collect()
}
/// Returns the ids of every role assigned to `user_id`, ascending.
///
/// # Errors
/// Propagates database errors.
pub async fn get_user_role_ids(db: &Pool, user_id: i64) -> anyhow::Result<Vec<i64>> {
    let rows = sqlx::query_as::<_, (i64,)>(
        "SELECT role_id FROM user_roles WHERE user_id = $1 ORDER BY role_id",
    )
    .bind(user_id)
    .fetch_all(db)
    .await?;
    Ok(rows.into_iter().map(|row| row.0).collect())
}
/// Soft-deletes `user_id`: scrubs personal data and credentials, stamps
/// `deleted_at`, and removes every grant and pending credential token, in one
/// transaction.
///
/// The `users` update only applies to rows not already deleted. The
/// dependent-row deletes always run.
///
/// Outstanding password-reset and email-verification tokens are removed too.
/// Otherwise a reset link issued before deletion could later be redeemed by
/// `consume_password_reset`, which writes a fresh `password_hash` onto the
/// scrubbed row.
/// NOTE(review): feature-gated `api_keys` / `mfa_recovery_codes` rows are not
/// touched here — confirm whether they should be revoked as well.
///
/// # Errors
/// Propagates database errors.
pub async fn soft_delete_user(db: &Pool, user_id: i64) -> anyhow::Result<()> {
    let mut tx = db.begin().await?;
    sqlx::query(
        "UPDATE users SET \
         display_name = NULL, \
         email = NULL, \
         avatar_url = NULL, \
         html_url = NULL, \
         password_hash = NULL, \
         mfa_secret = NULL, \
         mfa_enabled_at = NULL, \
         email_verified_at = NULL, \
         deleted_at = $1 \
         WHERE id = $2 AND deleted_at IS NULL",
    )
    .bind(unix_now())
    .bind(user_id)
    .execute(&mut *tx)
    .await?;
    for statement in [
        "DELETE FROM user_roles WHERE user_id = $1",
        "DELETE FROM oauth_accounts WHERE user_id = $1",
        "DELETE FROM user_permissions WHERE user_id = $1",
        "DELETE FROM password_reset_tokens WHERE user_id = $1",
        "DELETE FROM email_verification_tokens WHERE user_id = $1",
    ] {
        sqlx::query(statement)
            .bind(user_id)
            .execute(&mut *tx)
            .await?;
    }
    tx.commit().await?;
    Ok(())
}
/// Sets (or, with `None`, clears) the display name of `user_id`.
///
/// # Errors
/// Propagates database errors.
pub async fn update_display_name(
    db: &Pool,
    user_id: i64,
    new_name: Option<&str>,
) -> anyhow::Result<()> {
    let statement = sqlx::query("UPDATE users SET display_name = $1 WHERE id = $2")
        .bind(new_name)
        .bind(user_id);
    statement.execute(db).await?;
    Ok(())
}
/// Row shape for the admin user views (`list_users_for_admin`, `get_user_for_admin`).
///
/// The `*_at` fields hold the Unix-seconds values this module writes via
/// `unix_now()`; `None` means "never" / "not set".
#[derive(Debug, Clone, sqlx::FromRow)]
pub struct AdminUserRow {
pub id: i64,
pub username: String,
pub display_name: Option<String>,
pub email: Option<String>,
pub email_verified_at: Option<i64>,
pub mfa_enabled_at: Option<i64>,
pub anonymous: bool,
/// Set by `soft_delete_user`.
pub deleted_at: Option<i64>,
pub avatar_url: Option<String>,
pub html_url: Option<String>,
}
/// Returns one page of users (including soft-deleted ones) ordered by id.
///
/// `limit` is clamped to `1..=500`; a negative `offset` is treated as 0.
///
/// # Errors
/// Propagates database errors.
pub async fn list_users_for_admin(
    db: &Pool,
    limit: i64,
    offset: i64,
) -> anyhow::Result<Vec<AdminUserRow>> {
    let page_size = limit.clamp(1, 500);
    let skip = offset.max(0);
    let page = sqlx::query_as::<_, AdminUserRow>(
        "SELECT id, username, display_name, email, email_verified_at, \
         mfa_enabled_at, anonymous, deleted_at, avatar_url, html_url \
         FROM users \
         ORDER BY id \
         LIMIT $1 OFFSET $2",
    )
    .bind(page_size)
    .bind(skip)
    .fetch_all(db)
    .await?;
    Ok(page)
}
/// Fetches a single user (soft-deleted included) for the admin detail view,
/// or `None` if no such id exists.
///
/// # Errors
/// Propagates database errors.
pub async fn get_user_for_admin(db: &Pool, user_id: i64) -> anyhow::Result<Option<AdminUserRow>> {
    sqlx::query_as::<_, AdminUserRow>(
        "SELECT id, username, display_name, email, email_verified_at, \
         mfa_enabled_at, anonymous, deleted_at, avatar_url, html_url \
         FROM users WHERE id = $1",
    )
    .bind(user_id)
    .fetch_optional(db)
    .await
    .map_err(Into::into)
}
/// Returns the effective permission tokens of `user_id` (direct grants plus
/// role grants), de-duplicated and sorted.
///
/// # Errors
/// Propagates database errors.
pub async fn list_permissions_for_user(db: &Pool, user_id: i64) -> anyhow::Result<Vec<String>> {
    let rows = sqlx::query_as::<_, (String,)>(
        "SELECT token FROM user_permissions WHERE user_id = $1 \
         UNION \
         SELECT rp.token FROM role_permissions rp \
         JOIN user_roles ur ON ur.role_id = rp.role_id \
         WHERE ur.user_id = $1 \
         ORDER BY 1",
    )
    .bind(user_id)
    .fetch_all(db)
    .await?;
    Ok(rows.into_iter().map(|row| row.0).collect())
}
/// Returns the permission tokens attached to `role_id`, sorted.
///
/// # Errors
/// Propagates database errors.
pub async fn list_permissions_for_role(db: &Pool, role_id: i64) -> anyhow::Result<Vec<String>> {
    let rows = sqlx::query_as::<_, (String,)>(
        "SELECT token FROM role_permissions WHERE role_id = $1 ORDER BY token",
    )
    .bind(role_id)
    .fetch_all(db)
    .await?;
    Ok(rows.into_iter().map(|row| row.0).collect())
}
/// Returns the stored password hash of `user_id`.
///
/// `None` means the user does not exist or has no password (e.g. OAuth-only).
///
/// # Errors
/// Propagates database errors.
pub async fn get_password_hash(db: &Pool, user_id: i64) -> anyhow::Result<Option<String>> {
    let fetched = sqlx::query_as::<_, (Option<String>,)>(
        "SELECT password_hash FROM users WHERE id = $1",
    )
    .bind(user_id)
    .fetch_optional(db)
    .await?;
    Ok(match fetched {
        Some((hash,)) => hash,
        None => None,
    })
}
/// Hashes `new_password` with Argon2 and stores it as the password of `user_id`.
///
/// # Errors
/// A user-facing message when `new_password` has fewer than 8 characters
/// (Unicode scalar values, not UTF-8 bytes); hashing and database errors are
/// propagated.
pub async fn replace_password_hash(
    db: &Pool,
    user_id: i64,
    new_password: &str,
) -> anyhow::Result<()> {
    // `len()` would count bytes, letting e.g. four 2-byte characters pass an
    // "at least 8 characters" rule.
    if new_password.chars().count() < 8 {
        anyhow::bail!("Password must be at least 8 characters.");
    }
    let hash = hash_password(new_password)?;
    sqlx::query("UPDATE users SET password_hash = $1 WHERE id = $2")
        .bind(hash)
        .bind(user_id)
        .execute(db)
        .await?;
    Ok(())
}
/// Checks `candidate` against an Argon2 PHC-format `stored_hash`.
///
/// Returns `false` both for a wrong password and for a hash string that
/// cannot be parsed.
pub fn verify_password_against_hash(stored_hash: &str, candidate: &str) -> bool {
    use argon2::Argon2;
    use argon2::password_hash::{PasswordHash, PasswordVerifier};
    match PasswordHash::new(stored_hash) {
        Ok(parsed) => Argon2::default()
            .verify_password(candidate.as_bytes(), &parsed)
            .is_ok(),
        Err(_) => false,
    }
}
/// Returns the distinct OAuth provider names linked to `user_id`, sorted.
///
/// # Errors
/// Propagates database errors.
pub async fn linked_oauth_providers(db: &Pool, user_id: i64) -> anyhow::Result<Vec<String>> {
    let providers = sqlx::query_as::<_, (String,)>(
        "SELECT DISTINCT provider FROM oauth_accounts WHERE user_id = $1 ORDER BY provider",
    )
    .bind(user_id)
    .fetch_all(db)
    .await?;
    Ok(providers.into_iter().map(|row| row.0).collect())
}
/// Picks a username not yet used by any account (case-insensitive).
///
/// Tries the trimmed `desired` name first, or `"user"` if it is blank. It then
/// tries `base2`, `base3`, … up to `base10000`.
///
/// NOTE(review): check-then-insert; a concurrent signup may still claim the
/// same name before the caller inserts it.
///
/// # Errors
/// Fails when every candidate is taken, or on database errors.
pub async fn unique_username(db: &Pool, desired: &str) -> anyhow::Result<String> {
    let trimmed = desired.trim();
    let base: &str = if trimmed.is_empty() { "user" } else { trimmed };
    let mut candidate = base.to_string();
    let mut suffix: u32 = 1;
    loop {
        let taken: bool = sqlx::query_scalar(
            "SELECT EXISTS(SELECT 1 FROM users WHERE LOWER(username) = LOWER($1))",
        )
        .bind(&candidate)
        .fetch_one(db)
        .await?;
        if !taken {
            break Ok(candidate);
        }
        suffix = suffix.saturating_add(1);
        if suffix > 10_000 {
            anyhow::bail!("could not allocate a unique username for {base:?}");
        }
        candidate = format!("{base}{suffix}");
    }
}
pub async fn create_password_user(db: &Pool, email: &str, password: &str) -> anyhow::Result<i64> {
use argon2::Argon2;
use argon2::password_hash::{PasswordHasher, SaltString, rand_core::OsRng};
let email = email.trim();
if email.is_empty() || !email.contains('@') {
anyhow::bail!("Please enter a valid email address.");
}
if password.len() < 8 {
anyhow::bail!("Password must be at least 8 characters.");
}
let salt = SaltString::generate(&mut OsRng);
let hash = Argon2::default()
.hash_password(password.as_bytes(), &salt)
.map_err(|e| anyhow::anyhow!("hashing failed: {e}"))?
.to_string();
let desired = email.split('@').next().unwrap_or(email);
let username = unique_username(db, desired).await?;
let inserted: Result<(i64,), sqlx::Error> = sqlx::query_as(
"INSERT INTO users (anonymous, username, email, password_hash) \
VALUES (false, $1, $2, $3) RETURNING id",
)
.bind(username)
.bind(email)
.bind(&hash)
.fetch_one(db)
.await;
let (user_id,) = match inserted {
Ok(row) => row,
Err(sqlx::Error::Database(dberr)) if dberr.is_unique_violation() => {
anyhow::bail!("An account with that email already exists.");
}
Err(e) => return Err(e.into()),
};
assign_default_role(db, user_id).await?;
maybe_bootstrap_admin(db, user_id, Some(email)).await?;
maybe_grant_first_admin(db, user_id).await?;
Ok(user_id)
}
/// Issues a one-hour password-reset token for the account with `email`
/// (case-insensitive, trimmed), if that account has a password.
///
/// Returns `None` when no such account exists, so callers can respond
/// identically either way. The token is 16 random bytes rendered as 32
/// lowercase hex characters.
///
/// # Errors
/// Propagates database errors.
pub async fn request_password_reset(db: &Pool, email: &str) -> anyhow::Result<Option<String>> {
    use argon2::password_hash::rand_core::{OsRng, RngCore};
    let lookup: Option<(i64,)> = sqlx::query_as(
        "SELECT id FROM users \
         WHERE LOWER(email) = LOWER($1) AND password_hash IS NOT NULL \
         LIMIT 1",
    )
    .bind(email.trim())
    .fetch_optional(db)
    .await?;
    let user_id = match lookup {
        Some((id,)) => id,
        None => return Ok(None),
    };

    let mut entropy = OsRng;
    let mut raw = [0u8; 16];
    entropy.fill_bytes(&mut raw);
    let token = raw
        .iter()
        .map(|byte| format!("{byte:02x}"))
        .collect::<String>();
    let expires_at = unix_now().saturating_add(3600);

    sqlx::query(
        "INSERT INTO password_reset_tokens (token, user_id, expires_at) VALUES ($1, $2, $3)",
    )
    .bind(&token)
    .bind(user_id)
    .bind(expires_at)
    .execute(db)
    .await?;
    Ok(Some(token))
}
pub async fn consume_password_reset(
db: &Pool,
token: &str,
new_password: &str,
) -> anyhow::Result<i64> {
if new_password.len() < 8 {
anyhow::bail!("Password must be at least 8 characters.");
}
let row: Option<(i64,)> = sqlx::query_as(
"SELECT user_id FROM password_reset_tokens WHERE token = $1 AND expires_at > $2 LIMIT 1",
)
.bind(token)
.bind(unix_now())
.fetch_optional(db)
.await?;
let Some((user_id,)) = row else {
anyhow::bail!("This reset link has expired or already been used.");
};
let hash = hash_password(new_password)?;
let mut tx = db.begin().await?;
sqlx::query("UPDATE users SET password_hash = $1 WHERE id = $2")
.bind(&hash)
.bind(user_id)
.execute(&mut *tx)
.await?;
sqlx::query("DELETE FROM password_reset_tokens WHERE user_id = $1")
.bind(user_id)
.execute(&mut *tx)
.await?;
tx.commit().await?;
Ok(user_id)
}
/// Hashes `plaintext` with default-parameter Argon2 and a fresh random salt.
///
/// Returns the PHC string encoding.
///
/// # Errors
/// Fails if the Argon2 hasher reports an error.
fn hash_password(plaintext: &str) -> anyhow::Result<String> {
    use argon2::Argon2;
    use argon2::password_hash::{PasswordHasher, SaltString, rand_core::OsRng};
    let salt = SaltString::generate(&mut OsRng);
    let phc = Argon2::default()
        .hash_password(plaintext.as_bytes(), &salt)
        .map_err(|e| anyhow::anyhow!("hashing failed: {e}"))?;
    Ok(phc.to_string())
}
/// Current time as whole seconds since the Unix epoch.
///
/// Falls back to 0 if the system clock reads earlier than the epoch.
fn unix_now() -> i64 {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs() as i64,
        Err(_) => 0,
    }
}
/// Result of `verify_password_user`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VerifyOutcome {
/// Password correct and email verified; carries the user id.
Verified(i64),
/// Password correct but `email_verified_at` is not set.
Unverified,
/// Unknown email, no password set, unparsable stored hash, or wrong password.
Invalid,
}
/// Fixed, well-formed Argon2id PHC string (m=19456, t=2, p=1) used only to
/// spend verification time; it corresponds to no real account.
const DUMMY_PASSWORD_HASH: &str = "$argon2id$v=19$m=19456,t=2,p=1$OSdEu4xJYj4c5XviuP4CTQ$8LSH0M1A859epUylwUTwZJUp5O8rAtv0wURpMnvMbE4";

/// Runs one Argon2 verification of `password` against [`DUMMY_PASSWORD_HASH`].
/// Callers use it on paths with no real hash to check, so that those paths
/// do comparable work. The outcome is discarded.
fn burn_password_verify(password: &str) {
    use argon2::Argon2;
    use argon2::password_hash::{PasswordHash, PasswordVerifier};
    let Ok(dummy) = PasswordHash::new(DUMMY_PASSWORD_HASH) else {
        return;
    };
    let _ = Argon2::default().verify_password(password.as_bytes(), &dummy);
}
/// Checks an email/password login attempt.
///
/// Paths with no usable stored hash (unknown email, no password,
/// unparsable hash) still run a dummy Argon2 verification, so they do
/// comparable work to a real check. `Unverified` is only reported after the
/// password has matched.
///
/// # Errors
/// Propagates database errors; credential failures are `Ok(Invalid)`.
pub async fn verify_password_user(
    db: &Pool,
    email: &str,
    password: &str,
) -> anyhow::Result<VerifyOutcome> {
    use argon2::Argon2;
    use argon2::password_hash::{PasswordHash, PasswordVerifier};
    let found: Option<(i64, String, Option<i64>)> = sqlx::query_as(
        "SELECT id, password_hash, email_verified_at FROM users \
         WHERE LOWER(email) = LOWER($1) AND password_hash IS NOT NULL \
         LIMIT 1",
    )
    .bind(email.trim())
    .fetch_optional(db)
    .await?;

    let (user_id, stored_hash, verified_at) = match found {
        Some(row) => row,
        None => {
            burn_password_verify(password);
            return Ok(VerifyOutcome::Invalid);
        }
    };
    let parsed = match PasswordHash::new(&stored_hash) {
        Ok(parsed) => parsed,
        Err(_) => {
            burn_password_verify(password);
            return Ok(VerifyOutcome::Invalid);
        }
    };

    let password_ok = Argon2::default()
        .verify_password(password.as_bytes(), &parsed)
        .is_ok();
    Ok(match (password_ok, verified_at) {
        (false, _) => VerifyOutcome::Invalid,
        (true, None) => VerifyOutcome::Unverified,
        (true, Some(_)) => VerifyOutcome::Verified(user_id),
    })
}
/// Stores and returns a new 24-hour email-verification token for `user_id`.
///
/// The token is 16 random bytes rendered as 32 lowercase hex characters.
///
/// # Errors
/// Propagates database errors.
pub async fn issue_verification_token(db: &Pool, user_id: i64) -> anyhow::Result<String> {
    use argon2::password_hash::rand_core::{OsRng, RngCore};
    let mut entropy = OsRng;
    let mut raw = [0u8; 16];
    entropy.fill_bytes(&mut raw);
    let token = raw
        .iter()
        .map(|byte| format!("{byte:02x}"))
        .collect::<String>();
    let expires_at = unix_now().saturating_add(24 * 3600);

    sqlx::query(
        "INSERT INTO email_verification_tokens (token, user_id, expires_at) VALUES ($1, $2, $3)",
    )
    .bind(&token)
    .bind(user_id)
    .bind(expires_at)
    .execute(db)
    .await?;
    Ok(token)
}
/// Redeems an email-verification token.
///
/// If the token is unexpired, this marks the user's email as verified and
/// removes all of that user's verification tokens.
///
/// Returns the verified user's id, or `None` for an unknown/expired token.
///
/// # Errors
/// Propagates database errors.
pub async fn consume_verification_token(db: &Pool, token: &str) -> anyhow::Result<Option<i64>> {
    let pending: Option<(i64,)> = sqlx::query_as(
        "SELECT user_id FROM email_verification_tokens WHERE token = $1 AND expires_at > $2 LIMIT 1",
    )
    .bind(token)
    .bind(unix_now())
    .fetch_optional(db)
    .await?;
    let user_id = match pending {
        Some((id,)) => id,
        None => return Ok(None),
    };

    let mut tx = db.begin().await?;
    sqlx::query("UPDATE users SET email_verified_at = $1 WHERE id = $2")
        .bind(unix_now())
        .bind(user_id)
        .execute(&mut *tx)
        .await?;
    sqlx::query("DELETE FROM email_verification_tokens WHERE user_id = $1")
        .bind(user_id)
        .execute(&mut *tx)
        .await?;
    tx.commit().await?;
    Ok(Some(user_id))
}
/// Marks the email of `user_id` as verified as of now, without a token.
///
/// # Errors
/// Propagates database errors.
pub async fn mark_email_verified(db: &Pool, user_id: i64) -> anyhow::Result<()> {
    let verified_at = unix_now();
    let statement = sqlx::query("UPDATE users SET email_verified_at = $1 WHERE id = $2")
        .bind(verified_at)
        .bind(user_id);
    statement.execute(db).await?;
    Ok(())
}
/// Issuer string passed to `TOTP::new` when `setup_mfa_secret` provisions a secret.
#[cfg(feature = "mfa")]
const MFA_ISSUER: &str = "dx-auth example";
/// Number of single-use recovery codes generated per `setup_mfa_secret` call.
#[cfg(feature = "mfa")]
const RECOVERY_CODE_COUNT: usize = 10;
/// Enrollment material returned once by `setup_mfa_secret`; only the secret
/// and the recovery-code hashes are persisted.
#[cfg(feature = "mfa")]
pub struct MfaSetupInfo {
/// TOTP secret in base32, for manual entry into an authenticator app.
pub secret_base32: String,
/// Provisioning QR code from `totp_rs`'s `get_qr_base64` (base64 PNG per the name — confirm).
pub qr_png_base64: String,
/// Plaintext recovery codes; shown to the user once, stored only as Argon2 hashes.
pub recovery_codes: Vec<String>,
}
/// MFA state derived from the `mfa_secret` / `mfa_enabled_at` columns by `mfa_status`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg(feature = "mfa")]
pub enum MfaStatus {
/// No secret stored (or no such user).
Disabled,
/// Secret stored by `setup_mfa_secret` but not yet confirmed via `enable_mfa`.
Pending,
/// Secret stored and `mfa_enabled_at` set.
Enabled,
}
#[cfg(feature = "mfa")]
/// Starts (or restarts) MFA enrollment for `user_id`.
///
/// This generates a fresh TOTP secret (SHA-1, 6 digits, 30 s step, skew 1),
/// a provisioning QR code labelled with `account_label`, and
/// [`RECOVERY_CODE_COUNT`] recovery codes. In one transaction it stores the
/// secret with `mfa_enabled_at` cleared (MFA becomes "pending") and replaces
/// all recovery codes with Argon2 hashes of the new ones.
///
/// # Errors
/// Propagates secret/QR/hashing failures and database errors.
pub async fn setup_mfa_secret(
    db: &Pool,
    user_id: i64,
    account_label: &str,
) -> anyhow::Result<MfaSetupInfo> {
    use argon2::Argon2;
    use argon2::password_hash::{PasswordHasher, SaltString, rand_core::OsRng};
    use totp_rs::{Algorithm, Secret, TOTP};

    let secret = Secret::generate_secret();
    let secret_base32 = if let Secret::Encoded(encoded) = &secret {
        encoded.clone()
    } else {
        secret.to_encoded().to_string()
    };
    let qr_png_base64 = TOTP::new(
        Algorithm::SHA1,
        6,
        1,
        30,
        secret.to_bytes()?,
        Some(MFA_ISSUER.to_string()),
        account_label.to_string(),
    )?
    .get_qr_base64()
    .map_err(|e| anyhow::anyhow!("QR generation failed: {e}"))?;

    let mut code_rng = OsRng;
    let mut recovery_codes = Vec::with_capacity(RECOVERY_CODE_COUNT);
    let mut recovery_hashes = Vec::with_capacity(RECOVERY_CODE_COUNT);
    while recovery_codes.len() < RECOVERY_CODE_COUNT {
        let code = generate_recovery_code(&mut code_rng);
        let salt = SaltString::generate(&mut OsRng);
        let digest = Argon2::default()
            .hash_password(code.as_bytes(), &salt)
            .map_err(|e| anyhow::anyhow!("hashing recovery code failed: {e}"))?
            .to_string();
        recovery_hashes.push(digest);
        recovery_codes.push(code);
    }

    let mut tx = db.begin().await?;
    sqlx::query("UPDATE users SET mfa_secret = $1, mfa_enabled_at = NULL WHERE id = $2")
        .bind(&secret_base32)
        .bind(user_id)
        .execute(&mut *tx)
        .await?;
    sqlx::query("DELETE FROM mfa_recovery_codes WHERE user_id = $1")
        .bind(user_id)
        .execute(&mut *tx)
        .await?;
    for digest in &recovery_hashes {
        sqlx::query("INSERT INTO mfa_recovery_codes (user_id, code_hash) VALUES ($1, $2)")
            .bind(user_id)
            .bind(digest)
            .execute(&mut *tx)
            .await?;
    }
    tx.commit().await?;

    Ok(MfaSetupInfo {
        secret_base32,
        qr_png_base64,
        recovery_codes,
    })
}
#[cfg(feature = "mfa")]
/// Confirms a pending MFA enrollment.
///
/// Accepts a TOTP code generated from the stored secret and sets
/// `mfa_enabled_at` to now. The code is trimmed first, consistent with
/// `verify_mfa_challenge`, so pasted codes with stray whitespace are accepted.
///
/// Returns `false` (and changes nothing) when no secret is stored or the code
/// does not match.
///
/// # Errors
/// Propagates database errors.
pub async fn enable_mfa(db: &Pool, user_id: i64, totp_code: &str) -> anyhow::Result<bool> {
    let totp_code = totp_code.trim();
    let Some(secret) = load_mfa_secret(db, user_id).await? else {
        return Ok(false);
    };
    if !check_totp(&secret, totp_code) {
        return Ok(false);
    }
    sqlx::query("UPDATE users SET mfa_enabled_at = $1 WHERE id = $2")
        .bind(unix_now())
        .bind(user_id)
        .execute(db)
        .await?;
    Ok(true)
}
#[cfg(feature = "mfa")]
/// Checks a login-time MFA code for `user_id`.
///
/// The trimmed code is tried first as a current TOTP code. Failing that, it
/// is tried as an unused recovery code, which is consumed on success.
///
/// # Errors
/// Propagates database errors.
pub async fn verify_mfa_challenge(db: &Pool, user_id: i64, code: &str) -> anyhow::Result<bool> {
    let code = code.trim();
    let totp_ok = load_mfa_secret(db, user_id)
        .await?
        .is_some_and(|secret| check_totp(&secret, code));
    if totp_ok {
        Ok(true)
    } else {
        consume_recovery_code(db, user_id, code).await
    }
}
#[cfg(feature = "mfa")]
/// Turns MFA off for `user_id`.
///
/// Clears the secret and enabled timestamp and deletes all recovery codes,
/// atomically.
///
/// # Errors
/// Propagates database errors.
pub async fn disable_mfa(db: &Pool, user_id: i64) -> anyhow::Result<()> {
    let mut tx = db.begin().await?;
    for statement in [
        "UPDATE users SET mfa_secret = NULL, mfa_enabled_at = NULL WHERE id = $1",
        "DELETE FROM mfa_recovery_codes WHERE user_id = $1",
    ] {
        sqlx::query(statement)
            .bind(user_id)
            .execute(&mut *tx)
            .await?;
    }
    tx.commit().await?;
    Ok(())
}
#[cfg(feature = "mfa")]
/// Reports whether `user_id` has confirmed MFA (`mfa_enabled_at` is set).
///
/// Unknown users report `false`.
///
/// # Errors
/// Propagates database errors.
pub async fn user_has_mfa(db: &Pool, user_id: i64) -> anyhow::Result<bool> {
    let enabled_at = sqlx::query_as::<_, (Option<i64>,)>(
        "SELECT mfa_enabled_at FROM users WHERE id = $1",
    )
    .bind(user_id)
    .fetch_optional(db)
    .await?;
    Ok(matches!(enabled_at, Some((Some(_),))))
}
#[cfg(feature = "mfa")]
/// Classifies the MFA state of `user_id` from its stored secret and enabled
/// timestamp.
///
/// Unknown users, and users with no secret, are `Disabled`.
///
/// # Errors
/// Propagates database errors.
pub async fn mfa_status(db: &Pool, user_id: i64) -> anyhow::Result<MfaStatus> {
    let fetched = sqlx::query_as::<_, (Option<String>, Option<i64>)>(
        "SELECT mfa_secret, mfa_enabled_at FROM users WHERE id = $1",
    )
    .bind(user_id)
    .fetch_optional(db)
    .await?;
    let (secret, enabled_at) = fetched.unwrap_or((None, None));
    Ok(match (secret.is_some(), enabled_at.is_some()) {
        (true, true) => MfaStatus::Enabled,
        (true, false) => MfaStatus::Pending,
        (false, _) => MfaStatus::Disabled,
    })
}
#[cfg(feature = "mfa")]
/// Returns the stored base32 TOTP secret of `user_id`, if any.
///
/// # Errors
/// Propagates database errors.
async fn load_mfa_secret(db: &Pool, user_id: i64) -> anyhow::Result<Option<String>> {
    let fetched = sqlx::query_as::<_, (Option<String>,)>(
        "SELECT mfa_secret FROM users WHERE id = $1",
    )
    .bind(user_id)
    .fetch_optional(db)
    .await?;
    Ok(match fetched {
        Some((secret,)) => secret,
        None => None,
    })
}
#[cfg(feature = "mfa")]
/// Checks `code` against the current TOTP window for a base32 secret.
///
/// Uses SHA-1, 6 digits, 30 s step and skew 1, the same parameters as
/// `setup_mfa_secret`. Any decoding, construction or clock error yields
/// `false`.
fn check_totp(secret_base32: &str, code: &str) -> bool {
    use totp_rs::{Algorithm, Secret, TOTP};
    Secret::Encoded(secret_base32.to_string())
        .to_bytes()
        .ok()
        .and_then(|bytes| {
            TOTP::new(Algorithm::SHA1, 6, 1, 30, bytes, None, "".to_string()).ok()
        })
        .and_then(|totp| totp.check_current(code).ok())
        .unwrap_or(false)
}
#[cfg(feature = "mfa")]
/// Tries `candidate` against each unused recovery code of `user_id`.
///
/// On a match, the code is marked used. The claim only succeeds if the row
/// is still unused at update time, so a code can be redeemed at most once
/// even under concurrent requests.
///
/// Returns `true` only if this call consumed the code.
///
/// # Errors
/// Propagates database errors.
async fn consume_recovery_code(db: &Pool, user_id: i64, candidate: &str) -> anyhow::Result<bool> {
    use argon2::Argon2;
    use argon2::password_hash::{PasswordHash, PasswordVerifier};
    let rows: Vec<(String,)> = sqlx::query_as(
        "SELECT code_hash FROM mfa_recovery_codes WHERE user_id = $1 AND used_at IS NULL",
    )
    .bind(user_id)
    .fetch_all(db)
    .await?;
    for (hash,) in rows {
        let Ok(parsed) = PasswordHash::new(&hash) else {
            continue;
        };
        if Argon2::default()
            .verify_password(candidate.as_bytes(), &parsed)
            .is_err()
        {
            continue;
        }
        // `used_at IS NULL` makes this a compare-and-set: if a concurrent
        // request redeemed the same code first, zero rows change and this
        // attempt is rejected.
        let claimed = sqlx::query(
            "UPDATE mfa_recovery_codes SET used_at = $1 \
             WHERE user_id = $2 AND code_hash = $3 AND used_at IS NULL",
        )
        .bind(unix_now())
        .bind(user_id)
        .bind(&hash)
        .execute(db)
        .await?
        .rows_affected();
        return Ok(claimed > 0);
    }
    Ok(false)
}
#[cfg(feature = "mfa")]
/// Generates one 10-symbol recovery code from an unambiguous uppercase
/// alphabet (no `I`, `L`, `O`, `0`, `1`).
///
/// Random bytes are rejection-sampled so every symbol is equally likely. A
/// plain `byte % 31` would favour the first `256 % 31 = 8` symbols.
fn generate_recovery_code<R: argon2::password_hash::rand_core::RngCore>(rng: &mut R) -> String {
    const ALPHABET: &[u8] = b"ABCDEFGHJKMNPQRSTUVWXYZ23456789";
    const CODE_LEN: usize = 10;
    // Largest multiple of the alphabet size not exceeding 256 (= 248 here);
    // bytes at or above it are discarded.
    const ACCEPT_BELOW: usize = 256 - 256 % ALPHABET.len();

    let mut code = String::with_capacity(CODE_LEN);
    let mut buf = [0u8; CODE_LEN];
    // The alphabet is ASCII, so `code.len()` (bytes) equals the symbol count.
    while code.len() < CODE_LEN {
        rng.fill_bytes(&mut buf);
        for &byte in &buf {
            if code.len() == CODE_LEN {
                break;
            }
            let byte = usize::from(byte);
            if byte >= ACCEPT_BELOW {
                continue;
            }
            if let Some(&symbol) = byte
                .checked_rem(ALPHABET.len())
                .and_then(|i| ALPHABET.get(i))
            {
                code.push(char::from(symbol));
            }
        }
    }
    code
}
/// Finds the password account for `email` (case-insensitive, trimmed) whose
/// email is not yet verified, e.g. to resend a verification link.
///
/// # Errors
/// Propagates database errors.
pub async fn find_unverified_user_id(db: &Pool, email: &str) -> anyhow::Result<Option<i64>> {
    let found = sqlx::query_as::<_, (i64,)>(
        "SELECT id FROM users \
         WHERE LOWER(email) = LOWER($1) \
         AND password_hash IS NOT NULL \
         AND email_verified_at IS NULL \
         LIMIT 1",
    )
    .bind(email.trim())
    .fetch_optional(db)
    .await?;
    Ok(found.map(|row| row.0))
}
/// Security audit log backed by the `audit_events` table.
///
/// Provides event-type constants, recording helpers, an admin search API and
/// retention pruning.
pub mod audit {
    use super::*;
    use crate::wire::{AuditEventView, AuditQuery};

    pub const USER_LOGIN_SUCCESS: &str = "user.login.success";
    pub const USER_LOGIN_FAILED: &str = "user.login.failed";
    pub const USER_LOGOUT: &str = "user.logout";
    pub const USER_SIGNUP: &str = "user.signup";
    pub const USER_EMAIL_VERIFIED: &str = "user.email_verified";
    pub const USER_PWD_RESET_REQUESTED: &str = "user.password_reset.requested";
    pub const USER_PWD_RESET_CONSUMED: &str = "user.password_reset.consumed";
    pub const USER_MFA_ENABLED: &str = "user.mfa.enabled";
    pub const USER_MFA_DISABLED: &str = "user.mfa.disabled";
    pub const USER_API_TOKEN_CREATED: &str = "user.api_token.created";
    pub const USER_API_TOKEN_REVOKED: &str = "user.api_token.revoked";
    pub const ACCOUNT_PASSWORD_CHANGED: &str = "account.password_changed";
    pub const ACCOUNT_DISPLAY_NAME_CHANGED: &str = "account.display_name_changed";
    pub const ACCOUNT_SELF_DELETED: &str = "account.self_deleted";
    pub const ADMIN_ROLES_CHANGED: &str = "admin.user.roles_changed";
    pub const ADMIN_USER_DELETED: &str = "admin.user.soft_deleted";
    pub const ADMIN_ROLE_CREATED: &str = "admin.role.created";
    pub const ADMIN_ROLE_UPDATED: &str = "admin.role.updated";
    pub const ADMIN_ROLE_DELETED: &str = "admin.role.deleted";
    pub const RESOURCE_ACCESS_DENIED: &str = "resource.access.denied";
    pub const RESOURCE_ACCESS_GRANTED: &str = "resource.access.granted";

    /// One event to append; the timestamp is stamped by [`record`].
    pub struct RecordInput<'a> {
        /// Event type, normally one of the constants in this module.
        pub event_type: &'a str,
        /// User who performed the action, if known.
        pub actor_id: Option<i64>,
        /// User the action was performed on, if any.
        pub target_id: Option<i64>,
        pub ip: Option<&'a str>,
        pub user_agent: Option<&'a str>,
        /// Free-form extra context. NOTE(review): format is up to callers — confirm.
        pub details: Option<&'a str>,
    }

    /// Appends `input` to `audit_events`, timestamped with the current Unix
    /// time in seconds.
    ///
    /// # Errors
    /// Propagates database errors.
    pub async fn record(db: &Pool, input: RecordInput<'_>) -> anyhow::Result<()> {
        sqlx::query(
            "INSERT INTO audit_events \
             (occurred_at, event_type, actor_id, target_id, ip, user_agent, details) \
             VALUES ($1, $2, $3, $4, $5, $6, $7)",
        )
        .bind(unix_now())
        .bind(input.event_type)
        .bind(input.actor_id)
        .bind(input.target_id)
        .bind(input.ip)
        .bind(input.user_agent)
        .bind(input.details)
        .execute(db)
        .await?;
        Ok(())
    }

    /// Best-effort [`record`]: a failure is logged to stderr and swallowed,
    /// so auditing never fails the request being audited.
    pub async fn record_or_log(db: &Pool, input: RecordInput<'_>) {
        if let Err(err) = record(db, input).await {
            eprintln!("[audit] WARN: failed to record event: {err}");
        }
    }

    /// Searches the audit log for the admin UI.
    ///
    /// `event_type` filters exactly, or by prefix when it ends with `.`
    /// (e.g. `"user.mfa."`). Prefix characters match literally: `\`, `%`
    /// and `_` are escaped rather than treated as LIKE wildcards. The
    /// optional actor/target/time-range filters are ANDed together.
    ///
    /// Results are newest first. `limit` is clamped to `1..=500` and a
    /// negative `offset` is treated as 0.
    ///
    /// # Errors
    /// Propagates database errors.
    pub async fn query(db: &Pool, q: &AuditQuery) -> anyhow::Result<Vec<AuditEventView>> {
        let limit = q.limit.clamp(1, 500);
        let offset = q.offset.max(0);

        // Hands out `$1`, `$2`, ... in order. Each clause takes its
        // placeholder in exactly the order its value is bound further down.
        let mut next_idx: i32 = 1;
        let mut placeholder = || {
            let current = next_idx;
            next_idx = next_idx.saturating_add(1);
            current
        };
        let mut clauses = String::new();

        let type_value = if q.event_type.is_empty() {
            None
        } else if q.event_type.ends_with('.') {
            // Event types contain `_` (e.g. `user.api_token.`), which LIKE
            // would otherwise treat as "any character".
            clauses.push_str(&format!(
                " AND e.event_type LIKE ${} ESCAPE '\\'",
                placeholder()
            ));
            Some(format!("{}%", escape_like(&q.event_type)))
        } else {
            clauses.push_str(&format!(" AND e.event_type = ${}", placeholder()));
            Some(q.event_type.clone())
        };
        if q.actor_id.is_some() {
            clauses.push_str(&format!(" AND e.actor_id = ${}", placeholder()));
        }
        if q.target_id.is_some() {
            clauses.push_str(&format!(" AND e.target_id = ${}", placeholder()));
        }
        if q.since.is_some() {
            clauses.push_str(&format!(" AND e.occurred_at >= ${}", placeholder()));
        }
        if q.until.is_some() {
            clauses.push_str(&format!(" AND e.occurred_at <= ${}", placeholder()));
        }
        let limit_idx = placeholder();
        let offset_idx = placeholder();

        let sql = format!(
            "SELECT e.id, e.occurred_at, e.event_type, \
             e.actor_id, ua.email AS actor_email, \
             e.target_id, ut.email AS target_email, \
             e.ip, e.user_agent, e.details \
             FROM audit_events e \
             LEFT JOIN users ua ON ua.id = e.actor_id \
             LEFT JOIN users ut ON ut.id = e.target_id \
             WHERE 1 = 1{clauses} \
             ORDER BY e.occurred_at DESC, e.id DESC \
             LIMIT ${limit_idx} OFFSET ${offset_idx}",
        );

        let mut qb = sqlx::query_as::<_, AuditRow>(&sql);
        if let Some(value) = type_value {
            qb = qb.bind(value);
        }
        if let Some(actor) = q.actor_id {
            qb = qb.bind(actor);
        }
        if let Some(target) = q.target_id {
            qb = qb.bind(target);
        }
        if let Some(since) = q.since {
            qb = qb.bind(since);
        }
        if let Some(until) = q.until {
            qb = qb.bind(until);
        }
        let rows = qb.bind(limit).bind(offset).fetch_all(db).await?;

        Ok(rows
            .into_iter()
            .map(|r| AuditEventView {
                id: r.id,
                occurred_at: r.occurred_at,
                occurred_at_iso: format_unix(r.occurred_at),
                event_type: r.event_type,
                actor_id: r.actor_id,
                actor_email: r.actor_email,
                target_id: r.target_id,
                target_email: r.target_email,
                ip: r.ip,
                user_agent: r.user_agent,
                details: r.details,
            })
            .collect())
    }

    /// Escapes LIKE metacharacters (`\`, `%`, `_`) so `raw` matches literally
    /// under `ESCAPE '\'`.
    fn escape_like(raw: &str) -> String {
        let mut out = String::with_capacity(raw.len());
        for ch in raw.chars() {
            if matches!(ch, '\\' | '%' | '_') {
                out.push('\\');
            }
            out.push(ch);
        }
        out
    }

    /// Formats Unix seconds as `YYYY-MM-DD HH:MM:SS UTC`.
    ///
    /// Falls back to the raw number when the value is out of chrono's range.
    fn format_unix(secs: i64) -> String {
        use chrono::TimeZone;
        chrono::Utc
            .timestamp_opt(secs, 0)
            .single()
            .map(|dt| dt.format("%Y-%m-%d %H:%M:%S UTC").to_string())
            .unwrap_or_else(|| secs.to_string())
    }

    /// Deletes events older than `retention_days` and returns how many rows
    /// were removed.
    ///
    /// `0` disables pruning.
    ///
    /// # Errors
    /// Propagates database errors.
    pub async fn prune(db: &Pool, retention_days: u64) -> anyhow::Result<u64> {
        if retention_days == 0 {
            return Ok(0);
        }
        // `retention_days as i64` would wrap huge values negative, putting the
        // cutoff in the future and deleting every event; saturate instead.
        let seconds = i64::try_from(retention_days)
            .unwrap_or(i64::MAX)
            .saturating_mul(86_400);
        let cutoff = unix_now().saturating_sub(seconds);
        let res = sqlx::query("DELETE FROM audit_events WHERE occurred_at < $1")
            .bind(cutoff)
            .execute(db)
            .await?;
        Ok(res.rows_affected())
    }

    /// Raw row of the [`query`] SELECT, before conversion to `AuditEventView`.
    #[derive(sqlx::FromRow)]
    struct AuditRow {
        id: i64,
        occurred_at: i64,
        event_type: String,
        actor_id: Option<i64>,
        actor_email: Option<String>,
        target_id: Option<i64>,
        target_email: Option<String>,
        ip: Option<String>,
        user_agent: Option<String>,
        details: Option<String>,
    }
}
#[cfg(feature = "tokens")]
/// Personal API tokens.
///
/// Plaintext tokens are shown to the user once. The `api_keys` table stores
/// only a SHA-256 hex digest plus a short display prefix.
pub mod tokens {
    use super::*;
    use crate::wire::ApiTokenView;

    /// Random bytes per token (128 bits), rendered as 32 hex characters.
    const TOKEN_BYTES: usize = 16;
    /// Characters kept in clear for display: `TOKEN_PREFIX` plus 4 hex digits.
    const PREFIX_LEN: usize = 9;
    /// Marker prepended to every plaintext token.
    const TOKEN_PREFIX: &str = "dxsk_";
    /// Maximum token-name length, in characters.
    const MAX_NAME_LEN: usize = 64;

    /// Generates a new token and returns `(plaintext, display_prefix, sha256_hex)`.
    pub fn generate_api_token() -> (String, String, String) {
        use argon2::password_hash::rand_core::{OsRng, RngCore};
        let mut bytes = [0u8; TOKEN_BYTES];
        let mut rng = OsRng;
        rng.fill_bytes(&mut bytes);
        let hex: String = bytes.iter().map(|b| format!("{b:02x}")).collect();
        let plaintext = format!("{TOKEN_PREFIX}{hex}");
        let prefix = plaintext.chars().take(PREFIX_LEN).collect::<String>();
        // Reuse the lookup-time digest so generation and verification can never drift.
        let hash_hex = hash_api_token(&plaintext);
        (plaintext, prefix, hash_hex)
    }

    /// Lowercase hex SHA-256 of `plaintext`, the form stored in `api_keys.token_hash`.
    pub fn hash_api_token(plaintext: &str) -> String {
        use sha2::{Digest, Sha256};
        let mut hasher = Sha256::new();
        hasher.update(plaintext.as_bytes());
        hasher
            .finalize()
            .iter()
            .map(|b| format!("{b:02x}"))
            .collect()
    }

    /// Creates a token named `name` (trimmed, 1..=64 characters) for `user_id`.
    ///
    /// Returns the plaintext (to show once) and its list view.
    ///
    /// # Errors
    /// User-facing messages for an empty or over-long name; database errors
    /// are propagated.
    pub async fn create_for_user(
        db: &Pool,
        user_id: i64,
        name: &str,
    ) -> anyhow::Result<(String, ApiTokenView)> {
        let trimmed = name.trim();
        if trimmed.is_empty() {
            anyhow::bail!("Token name is required.");
        }
        if trimmed.chars().count() > MAX_NAME_LEN {
            anyhow::bail!("Token name is too long (max {MAX_NAME_LEN} characters).");
        }
        let (plaintext, prefix, token_hash) = generate_api_token();
        let now = unix_now();
        let (id,): (i64,) = sqlx::query_as(
            "INSERT INTO api_keys (user_id, name, token_hash, prefix, created_at) \
             VALUES ($1, $2, $3, $4, $5) \
             RETURNING id",
        )
        .bind(user_id)
        .bind(trimmed)
        .bind(&token_hash)
        .bind(&prefix)
        .bind(now)
        .fetch_one(db)
        .await?;
        let view = ApiTokenView {
            id,
            name: trimmed.to_string(),
            prefix,
            created_at_iso: format_unix_date(now),
            last_used_at_iso: None,
        };
        Ok((plaintext, view))
    }

    /// Lists the non-revoked tokens of `user_id`, newest first.
    ///
    /// # Errors
    /// Propagates database errors.
    pub async fn list_for_user(db: &Pool, user_id: i64) -> anyhow::Result<Vec<ApiTokenView>> {
        let rows: Vec<(i64, String, String, i64, Option<i64>)> = sqlx::query_as(
            "SELECT id, name, prefix, created_at, last_used_at \
             FROM api_keys \
             WHERE user_id = $1 AND revoked_at IS NULL \
             ORDER BY created_at DESC, id DESC",
        )
        .bind(user_id)
        .fetch_all(db)
        .await?;
        Ok(rows
            .into_iter()
            .map(
                |(id, name, prefix, created_at, last_used_at)| ApiTokenView {
                    id,
                    name,
                    prefix,
                    created_at_iso: format_unix_date(created_at),
                    last_used_at_iso: last_used_at.map(format_unix_date),
                },
            )
            .collect())
    }

    /// Formats Unix seconds as a `YYYY-MM-DD` UTC date.
    ///
    /// Falls back to the raw number when the value is out of range.
    fn format_unix_date(secs: i64) -> String {
        use chrono::TimeZone;
        chrono::Utc
            .timestamp_opt(secs, 0)
            .single()
            .map(|dt| dt.format("%Y-%m-%d").to_string())
            .unwrap_or_else(|| secs.to_string())
    }

    /// Revokes token `token_id` if it belongs to `user_id` and is still active.
    ///
    /// Returns whether a token was revoked.
    ///
    /// # Errors
    /// Propagates database errors.
    pub async fn revoke_for_user(db: &Pool, user_id: i64, token_id: i64) -> anyhow::Result<bool> {
        let res = sqlx::query(
            "UPDATE api_keys SET revoked_at = $1 \
             WHERE id = $2 AND user_id = $3 AND revoked_at IS NULL",
        )
        .bind(unix_now())
        .bind(token_id)
        .bind(user_id)
        .execute(db)
        .await?;
        Ok(res.rows_affected() > 0)
    }
}