use aes_gcm::aead::{Aead, KeyInit};
use aes_gcm::{Aes256Gcm, Key as GcmKey, Nonce};
use argon2::password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString};
use argon2::{Algorithm, Argon2, Params, Version};
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use base64::Engine;
use chrono::{Duration as ChronoDuration, Utc};
use hmac::{Hmac, Mac};
use rand::{Rng, RngCore};
use sha1::Sha1;
use crate::admin::audit::{record as audit_record, ActionType, AuditEvent, LogEntry};
use crate::admin::builtin::client_ip;
use crate::auth::sessions::{
hash_token_for_storage, invalidate_sessions, random_token, SessionInvalidationReason,
SessionTarget,
};
use crate::auth::Role;
use crate::error::{Error, Result};
use crate::http::Request;
use crate::orm::Db;
/// Lifetime, in days, of the replacement session minted by
/// `promote_session_to_mfa_verified` once the user has completed MFA.
const MFA_VERIFIED_SESSION_DAYS: i64 = 14;
/// HMAC-SHA1, the MAC `generate_totp` uses (matching the `algorithm=SHA1`
/// advertised by `build_otpauth_url`).
type HmacSha1 = Hmac<Sha1>;
#[derive(Clone)]
#[allow(dead_code)] pub struct MfaKey([u8; 32]);
#[allow(dead_code)] impl MfaKey {
pub(crate) fn from_env() -> Result<Self> {
let raw = std::env::var("RUSTIO_SECRET_KEY").map_err(|_| {
Error::Internal(
"RUSTIO_SECRET_KEY env var is unset; required when MfaPolicy != Disabled".into(),
)
})?;
let decoded = URL_SAFE_NO_PAD.decode(raw.trim()).map_err(|e| {
Error::Internal(format!(
"RUSTIO_SECRET_KEY is not valid URL-safe-base64 (no padding): {e}"
))
})?;
let bytes: [u8; 32] = decoded.as_slice().try_into().map_err(|_| {
Error::Internal(format!(
"RUSTIO_SECRET_KEY decodes to {} bytes; AES-256 requires exactly 32",
decoded.len()
))
})?;
Ok(Self(bytes))
}
pub fn from_bytes(bytes: [u8; 32]) -> Self {
Self(bytes)
}
fn as_bytes(&self) -> &[u8; 32] {
&self.0
}
}
/// Encrypts `plaintext` under `key` with AES-256-GCM using a fresh random nonce.
///
/// Output layout is `nonce (12 bytes) || ciphertext || tag (16 bytes)`, which is
/// the layout `unwrap_secret` parses.
///
/// # Panics
/// Only if the AEAD implementation reports an encryption error, which the
/// `expect` message treats as impossible for in-memory input.
#[allow(dead_code)]
pub(crate) fn wrap_secret(plaintext: &[u8], key: &MfaKey) -> Vec<u8> {
    // A new random 96-bit nonce on every call: a nonce must never repeat under one key.
    let mut nonce_bytes = [0u8; 12];
    rand::thread_rng().fill_bytes(&mut nonce_bytes);
    let cipher = Aes256Gcm::new(GcmKey::<Aes256Gcm>::from_slice(key.as_bytes()));
    let sealed = cipher
        .encrypt(Nonce::from_slice(&nonce_bytes), plaintext)
        .expect("AES-256-GCM encrypt cannot fail for in-memory plaintext");
    [nonce_bytes.as_slice(), sealed.as_slice()].concat()
}
#[allow(dead_code)] pub(crate) fn unwrap_secret(input: &[u8], key: &MfaKey) -> Result<Vec<u8>> {
if input.len() < 12 + 16 {
return Err(Error::Internal(format!(
"MFA ciphertext too short ({} bytes); minimum is 28 (nonce + tag)",
input.len()
)));
}
let (nonce_bytes, ciphertext) = input.split_at(12);
let nonce = Nonce::from_slice(nonce_bytes);
let cipher = Aes256Gcm::new(GcmKey::<Aes256Gcm>::from_slice(key.as_bytes()));
cipher
.decrypt(nonce, ciphertext)
.map_err(|_| Error::Internal("MFA ciphertext failed AEAD verification".into()))
}
/// Number of backup codes issued at enrolment and on regeneration.
#[allow(dead_code)] pub const BACKUP_CODE_COUNT: usize = 8;
/// Number of alphabet characters in a backup code. The display form produced by
/// `generate_backup_codes` adds one hyphen after the fourth character.
pub(crate) const BACKUP_CODE_LEN: usize = 8;
/// Characters backup codes are drawn from: digits 2-9 and upper-case letters
/// other than `I`, `L` and `O` (31 symbols). `0`, `1`, `I`, `L` and `O` are
/// excluded as easily-confused glyphs, which the tests enforce.
const BACKUP_CODE_ALPHABET: &[u8] = b"23456789ABCDEFGHJKMNPQRSTUVWXYZ";
fn backup_code_argon2() -> Result<Argon2<'static>> {
let params = Params::new(16 * 1024, 2, 1, None)
.map_err(|e| Error::Internal(format!("argon2 params: {e}")))?;
Ok(Argon2::new(Algorithm::Argon2id, Version::V0x13, params))
}
/// Generates `count` random backup codes in display form `XXXX-XXXX`.
///
/// Each code has `BACKUP_CODE_LEN` characters drawn uniformly from
/// `BACKUP_CODE_ALPHABET`, plus a hyphen inserted before the fifth character.
#[allow(dead_code)]
pub(crate) fn generate_backup_codes(count: usize) -> Vec<String> {
    let mut rng = rand::thread_rng();
    let mut codes = Vec::with_capacity(count);
    for _ in 0..count {
        let mut code = String::with_capacity(BACKUP_CODE_LEN + 1);
        for position in 0..BACKUP_CODE_LEN {
            if position == 4 {
                code.push('-');
            }
            let symbol = BACKUP_CODE_ALPHABET[rng.gen_range(0..BACKUP_CODE_ALPHABET.len())];
            code.push(char::from(symbol));
        }
        codes.push(code);
    }
    codes
}
/// Canonicalises user-typed backup-code input.
///
/// Keeps only ASCII letters and digits, dropping hyphens, spaces, punctuation and
/// any non-ASCII characters, and upper-cases the result.
#[allow(dead_code)]
pub(crate) fn normalise_backup_code(input: &str) -> String {
    input
        .chars()
        .filter(char::is_ascii_alphanumeric)
        .map(|c| c.to_ascii_uppercase())
        .collect()
}
#[allow(dead_code)] pub(crate) fn hash_backup_code(plaintext: &str) -> Result<String> {
let argon2 = backup_code_argon2()?;
let salt = SaltString::generate(&mut rand::thread_rng());
let hash = argon2
.hash_password(plaintext.as_bytes(), &salt)
.map_err(|e| Error::Internal(format!("argon2 hash: {e}")))?;
Ok(hash.to_string())
}
/// Returns `true` iff `plaintext` matches the PHC-format `hash`.
///
/// A string that does not parse as PHC yields `false` rather than an error. A
/// default `Argon2` instance is used for verification. The round-trip tests in this
/// file show that this accepts hashes produced with the backup-code parameters.
#[allow(dead_code)]
pub(crate) fn verify_backup_code(plaintext: &str, hash: &str) -> bool {
    PasswordHash::new(hash)
        .map(|parsed| {
            Argon2::default()
                .verify_password(plaintext.as_bytes(), &parsed)
                .is_ok()
        })
        .unwrap_or(false)
}
/// Returns the TOTP time-step counter for `now_unix`: the number of whole
/// `step_seconds` intervals since the Unix epoch (RFC 6238's `T`).
///
/// # Panics
/// If `step_seconds` is 0. In debug builds the debug assertion fires; otherwise the
/// division by zero panics.
#[allow(dead_code)]
pub fn current_step(now_unix: u64, step_seconds: u64) -> u64 {
    debug_assert!(step_seconds > 0, "step_seconds must be > 0");
    now_unix.div_euclid(step_seconds)
}
/// Computes the 6-digit HOTP/TOTP value for `step` using HMAC-SHA1 and RFC 4226
/// dynamic truncation.
///
/// The result is always below 1,000,000.
#[allow(dead_code)]
pub fn generate_totp(secret: &[u8], step: u64) -> u32 {
    let mut mac = <HmacSha1 as Mac>::new_from_slice(secret).expect("HMAC accepts any key length");
    mac.update(&step.to_be_bytes());
    let digest = mac.finalize().into_bytes();
    // Dynamic truncation: the low nibble of the last byte (index 19 of the 20-byte
    // SHA1 MAC) selects a 4-byte window. The window's top bit is masked off.
    let offset = usize::from(digest[19] & 0x0F);
    let mut window = [0u8; 4];
    window.copy_from_slice(&digest[offset..offset + 4]);
    (u32::from_be_bytes(window) & 0x7FFF_FFFF) % 1_000_000
}
/// Checks `candidate` against the TOTP codes for the current step and for up to
/// `skew_steps` steps on either side of it.
///
/// Steps are tried from oldest to newest, and the first matching step is returned
/// so callers can record it for replay protection. Any step that would be negative
/// is clamped to 0. Returns `None` if nothing in the window matches.
#[allow(dead_code)]
pub(crate) fn verify_totp(
    secret: &[u8],
    candidate: u32,
    now_unix: u64,
    step_seconds: u64,
    skew_steps: u32,
) -> Option<u64> {
    let centre = current_step(now_unix, step_seconds) as i64;
    let reach = i64::from(skew_steps);
    (-reach..=reach)
        .map(|offset| centre.saturating_add(offset).max(0) as u64)
        .find(|&step| generate_totp(secret, step) == candidate)
}
/// A freshly generated TOTP secret, in both raw and display form.
#[allow(dead_code)] pub struct ProvisionedSecret {
    /// Raw secret bytes (`provision_secret` fills 20 random bytes).
    pub secret_bytes: Vec<u8>,
    /// `secret_bytes` as upper-case RFC 4648 base32 without padding, the form
    /// placed in an otpauth URL or typed into an authenticator app.
    pub base32: String,
}
/// Generates a new random 20-byte TOTP secret (the size of the RFC 6238 SHA1
/// reference secret) together with its unpadded base32 encoding.
#[allow(dead_code)]
pub fn provision_secret() -> ProvisionedSecret {
    let mut secret_bytes = vec![0u8; 20];
    rand::thread_rng().fill_bytes(secret_bytes.as_mut_slice());
    ProvisionedSecret {
        base32: base32_encode_no_pad(&secret_bytes),
        secret_bytes,
    }
}
/// Builds an `otpauth://totp/` key URI for QR-code enrolment.
///
/// The label is `issuer:account`, with both parts percent-encoded, and the issuer
/// is repeated as a query parameter. `base32_secret` is inserted verbatim, so it
/// must already be base32 (URL-safe characters only). The URI advertises SHA1,
/// 6 digits and `step_seconds` as the period.
#[allow(dead_code)]
pub(crate) fn build_otpauth_url(
    issuer: &str,
    account: &str,
    base32_secret: &str,
    step_seconds: u64,
) -> String {
    let issuer = urlencoding::encode(issuer);
    let label = format!("{issuer}:{}", urlencoding::encode(account));
    format!(
        "otpauth://totp/{label}?secret={base32_secret}\
         &issuer={issuer}&algorithm=SHA1&digits=6&period={step_seconds}"
    )
}
/// Encodes `bytes` as RFC 4648 base32 (upper-case alphabet) without `=` padding.
///
/// A final partial group of fewer than 5 bits is left-aligned and zero-filled into
/// one last symbol.
fn base32_encode_no_pad(bytes: &[u8]) -> String {
    const ALPHABET: &[u8; 32] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567";
    // Maps the low five bits of `bits` to its base32 symbol.
    let symbol = |bits: u32| char::from(ALPHABET[(bits & 0x1F) as usize]);

    let mut encoded = String::with_capacity(bytes.len().div_ceil(5) * 8);
    // Only the low `pending` bits of `acc` are meaningful; shifting left discards
    // the rest.
    let mut acc: u32 = 0;
    let mut pending: u32 = 0;
    for byte in bytes.iter().copied() {
        acc = (acc << 8) | u32::from(byte);
        pending += 8;
        while pending >= 5 {
            pending -= 5;
            encoded.push(symbol(acc >> pending));
        }
    }
    if pending != 0 {
        encoded.push(symbol(acc << (5 - pending)));
    }
    encoded
}
/// Decodes RFC 4648 base32, for example a secret that was typed or pasted in.
///
/// The decoder is lenient about presentation:
/// - ASCII whitespace, `-` and `=` are skipped wherever they appear.
/// - Lower-case letters are accepted.
///
/// Returns `None` when either:
/// - any other character is outside the base32 alphabet, or
/// - the number of significant characters cannot be a complete encoding
///   (length ≡ 1, 3 or 6 mod 8). Such input is truncated or mistyped, and was
///   previously decoded silently, e.g. `"A"` yielded an empty secret.
#[allow(dead_code)]
pub(crate) fn base32_decode_no_pad(input: &str) -> Option<Vec<u8>> {
    let mut buffer: u32 = 0;
    let mut bits_in_buffer: u8 = 0;
    let mut out = Vec::with_capacity(input.len() * 5 / 8 + 1);
    for c in input.chars() {
        if c.is_ascii_whitespace() || c == '-' || c == '=' {
            continue;
        }
        let upper = c.to_ascii_uppercase();
        let value: u32 = match upper {
            'A'..='Z' => u32::from(upper) - u32::from('A'),
            '2'..='7' => u32::from(upper) - u32::from('2') + 26,
            _ => return None,
        };
        // Only the low `bits_in_buffer` bits matter; older bits shift out harmlessly.
        buffer = (buffer << 5) | value;
        bits_in_buffer += 5;
        if bits_in_buffer >= 8 {
            bits_in_buffer -= 8;
            out.push(((buffer >> bits_in_buffer) & 0xFF) as u8);
        }
    }
    // After n symbols, 5n mod 8 bits are left over. A valid encoding leaves at
    // most 4 (n mod 8 in {0, 2, 4, 5, 7}). Five or more leftover bits means a
    // whole symbol carried no byte, so the input is not a complete encoding.
    if bits_in_buffer >= 5 {
        return None;
    }
    Some(out)
}
/// Result of `confirm_enrolment`.
#[allow(dead_code)] pub enum EnrolOutcome {
    /// MFA is now enabled. Carries the freshly generated backup codes in display
    /// form. `confirm_enrolment` persists only argon2 hashes of them, so this is the
    /// sole plaintext copy.
    Enrolled { plain_backup_codes: Vec<String> },
    /// The submitted TOTP code did not match the pending secret within the skew window.
    InvalidCode,
    /// The user already has MFA enabled; nothing was changed.
    AlreadyEnrolled,
}
#[allow(dead_code)] #[allow(clippy::too_many_arguments)]
pub async fn confirm_enrolment(
db: &Db,
request: &Request,
user_id: i64,
secret_bytes: &[u8],
candidate_code: u32,
step_seconds: u64,
skew_steps: u32,
key: &MfaKey,
key_id: u32,
correlation_id: Option<&str>,
) -> Result<EnrolOutcome> {
let already: Option<bool> =
sqlx::query_scalar("SELECT mfa_enabled FROM rustio_users WHERE id = $1")
.bind(user_id)
.fetch_optional(db.pool())
.await?;
let Some(already) = already else {
return Err(Error::NotFound(format!("user {user_id} not found")));
};
if already {
return Ok(EnrolOutcome::AlreadyEnrolled);
}
let now_unix = Utc::now().timestamp().max(0) as u64;
let step = match verify_totp(
secret_bytes,
candidate_code,
now_unix,
step_seconds,
skew_steps,
) {
Some(step) => step,
None => return Ok(EnrolOutcome::InvalidCode),
};
let ciphertext = wrap_secret(secret_bytes, key);
sqlx::query(
"UPDATE rustio_users \
SET mfa_enabled = TRUE, \
mfa_secret_ciphertext = $1, \
mfa_secret_key_id = $2, \
mfa_last_used_step = $3 \
WHERE id = $4",
)
.bind(&ciphertext)
.bind(key_id as i32)
.bind(step as i64)
.bind(user_id)
.execute(db.pool())
.await?;
let plain_codes = generate_backup_codes(BACKUP_CODE_COUNT);
for code in &plain_codes {
let normalised = normalise_backup_code(code);
let hash = hash_backup_code(&normalised)?;
sqlx::query("INSERT INTO rustio_mfa_backup_codes (user_id, code_hash) VALUES ($1, $2)")
.bind(user_id)
.bind(&hash)
.execute(db.pool())
.await?;
}
let metadata = serde_json::json!({
"backup_codes_count": BACKUP_CODE_COUNT,
"key_id": key_id,
});
let ip = client_ip(request);
let mut entry = LogEntry::new(user_id, ActionType::Update, "users", user_id)
.with_event(AuditEvent::MfaEnabled)
.with_actor(user_id);
entry.correlation_id = correlation_id;
entry.ip_address = ip.as_deref();
entry.metadata = Some(metadata);
entry.summary = "MFA enabled (TOTP + 8 backup codes)".to_string();
audit_record(db, entry).await?;
Ok(EnrolOutcome::Enrolled {
plain_backup_codes: plain_codes,
})
}
/// Result of `verify_totp_for_user`.
#[allow(dead_code)] pub enum VerifyOutcome {
    /// The code matched, and `step_used` was recorded as the user's last used step.
    Verified { step_used: u64 },
    /// The code matched a step no newer than the last accepted one, so it was already
    /// used or is older than one that was. `last_used_step` is clamped to >= 0.
    Replay { last_used_step: u64 },
    /// The input was not a 6-digit number, or did not match within the skew window.
    Invalid,
    /// The user does not have MFA enabled.
    NotEnrolled,
}
/// Checks a user-submitted TOTP code against the user's stored secret, with replay
/// protection.
///
/// `candidate_code_str` is trimmed and must parse as an integer below 1,000,000.
/// The stored secret is decrypted with `key`. A match within `±skew_steps` of now
/// is accepted only if its step is strictly newer than `mfa_last_used_step`.
///
/// The new step is claimed with a conditional `UPDATE ... WHERE
/// mfa_last_used_step < step`. Two concurrent requests carrying the same code
/// therefore cannot both be `Verified`: the loser sees zero affected rows and
/// gets `Replay`.
///
/// # Errors
/// - `Error::NotFound` for an unknown user.
/// - `Error::Internal` if MFA is enabled but the ciphertext is missing or fails
///   to decrypt.
/// - Database errors.
#[allow(dead_code)]
pub async fn verify_totp_for_user(
    db: &Db,
    user_id: i64,
    candidate_code_str: &str,
    step_seconds: u64,
    skew_steps: u32,
    key: &MfaKey,
) -> Result<VerifyOutcome> {
    use sqlx::Row as _;
    let candidate = match candidate_code_str.trim().parse::<u32>() {
        Ok(n) if n < 1_000_000 => n,
        _ => return Ok(VerifyOutcome::Invalid),
    };
    let row = sqlx::query(
        "SELECT mfa_enabled, mfa_secret_ciphertext, mfa_last_used_step \
         FROM rustio_users WHERE id = $1",
    )
    .bind(user_id)
    .fetch_optional(db.pool())
    .await?;
    let row = row.ok_or_else(|| Error::NotFound(format!("user {user_id} not found")))?;
    let mfa_enabled: bool = row.try_get("mfa_enabled")?;
    if !mfa_enabled {
        return Ok(VerifyOutcome::NotEnrolled);
    }
    let ciphertext: Option<Vec<u8>> = row.try_get("mfa_secret_ciphertext")?;
    let last_used_step: Option<i64> = row.try_get("mfa_last_used_step")?;
    let ciphertext = ciphertext.ok_or_else(|| {
        Error::Internal(format!(
            "user {user_id} has mfa_enabled=TRUE but mfa_secret_ciphertext IS NULL"
        ))
    })?;
    let secret_bytes = unwrap_secret(&ciphertext, key)?;
    let now_unix = Utc::now().timestamp().max(0) as u64;
    let step = match verify_totp(&secret_bytes, candidate, now_unix, step_seconds, skew_steps) {
        Some(step) => step,
        None => return Ok(VerifyOutcome::Invalid),
    };
    // NULL means no code has been used yet; -1 is below every real step.
    let last = last_used_step.unwrap_or(-1);
    // Fast path: an obvious replay needs no write.
    if (step as i64) <= last {
        return Ok(VerifyOutcome::Replay {
            last_used_step: last.max(0) as u64,
        });
    }
    // Atomically claim the step: only succeeds if nobody recorded it (or a later
    // step) since the read above.
    let claimed = sqlx::query(
        "UPDATE rustio_users SET mfa_last_used_step = $1 \
         WHERE id = $2 AND (mfa_last_used_step IS NULL OR mfa_last_used_step < $1)",
    )
    .bind(step as i64)
    .bind(user_id)
    .execute(db.pool())
    .await?;
    if claimed.rows_affected() == 0 {
        // Lost a race with a concurrent verification; report the step now on file.
        let on_file: Option<Option<i64>> =
            sqlx::query_scalar("SELECT mfa_last_used_step FROM rustio_users WHERE id = $1")
                .bind(user_id)
                .fetch_optional(db.pool())
                .await?;
        let last = on_file.flatten().unwrap_or(step as i64);
        return Ok(VerifyOutcome::Replay {
            last_used_step: last.max(0) as u64,
        });
    }
    Ok(VerifyOutcome::Verified { step_used: step })
}
/// Result of `consume_backup_code`.
#[allow(dead_code)] pub enum BackupConsumeOutcome {
    /// A code matched and was marked used. `code_id` is its row id in
    /// `rustio_mfa_backup_codes`; `remaining` counts the user's still-unused codes.
    Consumed { code_id: i64, remaining: u32 },
    /// The input did not match any of the user's unused codes.
    Invalid,
    /// The user does not have MFA enabled.
    NotEnrolled,
    /// The submitted code had already been used.
    #[allow(dead_code)]
    AlreadyUsed,
}
#[allow(dead_code)] pub async fn consume_backup_code(
db: &Db,
request: &Request,
user_id: i64,
candidate_str: &str,
via: &'static str,
correlation_id: Option<&str>,
) -> Result<BackupConsumeOutcome> {
use sqlx::Row as _;
let candidate = normalise_backup_code(candidate_str);
if candidate.is_empty() {
return Ok(BackupConsumeOutcome::Invalid);
}
let mfa_enabled: Option<bool> =
sqlx::query_scalar("SELECT mfa_enabled FROM rustio_users WHERE id = $1")
.bind(user_id)
.fetch_optional(db.pool())
.await?;
let mfa_enabled =
mfa_enabled.ok_or_else(|| Error::NotFound(format!("user {user_id} not found")))?;
if !mfa_enabled {
return Ok(BackupConsumeOutcome::NotEnrolled);
}
let rows = sqlx::query(
"SELECT id, code_hash FROM rustio_mfa_backup_codes \
WHERE user_id = $1 AND used_at IS NULL \
ORDER BY id",
)
.bind(user_id)
.fetch_all(db.pool())
.await?;
let mut matched_id: Option<i64> = None;
for row in &rows {
let id: i64 = row.try_get("id")?;
let hash: String = row.try_get("code_hash")?;
if verify_backup_code(&candidate, &hash) && matched_id.is_none() {
matched_id = Some(id);
}
}
let Some(matched_id) = matched_id else {
return Ok(BackupConsumeOutcome::Invalid);
};
let result = sqlx::query(
"UPDATE rustio_mfa_backup_codes \
SET used_at = NOW() \
WHERE id = $1 AND used_at IS NULL",
)
.bind(matched_id)
.execute(db.pool())
.await?;
if result.rows_affected() == 0 {
return Ok(BackupConsumeOutcome::Invalid);
}
let remaining: i64 = sqlx::query_scalar(
"SELECT COUNT(*) FROM rustio_mfa_backup_codes \
WHERE user_id = $1 AND used_at IS NULL",
)
.bind(user_id)
.fetch_one(db.pool())
.await?;
let remaining = remaining.max(0) as u32;
let metadata = serde_json::json!({
"code_id": matched_id,
"remaining_codes": remaining,
"via": via,
});
let ip = client_ip(request);
let mut entry = LogEntry::new(user_id, ActionType::Update, "users", user_id)
.with_event(AuditEvent::MfaCodeConsumed)
.with_actor(user_id);
entry.correlation_id = correlation_id;
entry.ip_address = ip.as_deref();
entry.metadata = Some(metadata);
entry.summary = format!("backup code consumed via {via}; {remaining} remaining");
audit_record(db, entry).await?;
Ok(BackupConsumeOutcome::Consumed {
code_id: matched_id,
remaining,
})
}
/// Result of `disable_mfa`.
#[allow(dead_code)] pub enum DisableOutcome {
    /// MFA was turned off. `sessions_revoked` is how many of the user's sessions
    /// were invalidated as a result.
    Disabled { sessions_revoked: usize },
    /// The user did not have MFA enabled; nothing changed.
    NotEnrolled,
    /// Disabling refused because policy requires MFA.
    /// NOTE(review): `disable_mfa` never returns this variant; presumably callers
    /// check `MfaPolicy` themselves. Confirm.
    #[allow(dead_code)]
    PolicyRequired,
}
/// Turns MFA off for `user_id` at the user's own request.
///
/// This does the following:
/// - Clears the stored secret, key id and last-used step.
/// - Deletes all of the user's backup codes.
/// - Revokes every session of the user (`MfaDisabled`).
/// - Writes an `MfaDisabled` audit entry.
///
/// The enrolment check, secret clearing and code deletion run in one transaction
/// with the user row locked (`FOR UPDATE`). A failure therefore cannot leave MFA
/// disabled with backup codes still on file, and this is serialised against
/// regeneration and concurrent disables. Sessions are revoked after the commit.
///
/// # Errors
/// - `Error::NotFound` for an unknown user.
/// - Database errors.
/// - Errors from `invalidate_sessions` or the audit write, by which point MFA has
///   already been disabled.
#[allow(dead_code)]
pub async fn disable_mfa(
    db: &Db,
    request: &Request,
    user_id: i64,
    correlation_id: Option<&str>,
) -> Result<DisableOutcome> {
    let mut tx = db.pool().begin().await?;
    let mfa_enabled: Option<bool> =
        sqlx::query_scalar("SELECT mfa_enabled FROM rustio_users WHERE id = $1 FOR UPDATE")
            .bind(user_id)
            .fetch_optional(&mut *tx)
            .await?;
    let mfa_enabled =
        mfa_enabled.ok_or_else(|| Error::NotFound(format!("user {user_id} not found")))?;
    if !mfa_enabled {
        // Returning drops `tx` without committing; nothing was written.
        return Ok(DisableOutcome::NotEnrolled);
    }
    // Counts all code rows, used or not, since all of them are deleted below.
    let previous_count: i64 =
        sqlx::query_scalar("SELECT COUNT(*) FROM rustio_mfa_backup_codes WHERE user_id = $1")
            .bind(user_id)
            .fetch_one(&mut *tx)
            .await?;
    let previous_count = previous_count.max(0) as u32;
    sqlx::query(
        "UPDATE rustio_users \
         SET mfa_enabled = FALSE, \
         mfa_secret_ciphertext = NULL, \
         mfa_secret_key_id = NULL, \
         mfa_last_used_step = NULL \
         WHERE id = $1",
    )
    .bind(user_id)
    .execute(&mut *tx)
    .await?;
    sqlx::query("DELETE FROM rustio_mfa_backup_codes WHERE user_id = $1")
        .bind(user_id)
        .execute(&mut *tx)
        .await?;
    tx.commit().await?;

    let invalidation = invalidate_sessions(
        db,
        SessionTarget::User { user_id },
        SessionInvalidationReason::MfaDisabled,
    )
    .await?;
    let sessions_revoked = invalidation.revoked_session_ids.len();
    let metadata = serde_json::json!({
        "reason": "self_disabled",
        "previous_backup_codes_count": previous_count,
        "sessions_revoked": sessions_revoked,
    });
    let ip = client_ip(request);
    let mut entry = LogEntry::new(user_id, ActionType::Update, "users", user_id)
        .with_event(AuditEvent::MfaDisabled)
        .with_actor(user_id);
    entry.correlation_id = correlation_id;
    entry.ip_address = ip.as_deref();
    entry.metadata = Some(metadata);
    entry.summary = format!(
        "MFA self-disabled; {previous_count} backup codes deleted; \
         {sessions_revoked} sessions revoked"
    );
    audit_record(db, entry).await?;
    Ok(DisableOutcome::Disabled { sessions_revoked })
}
/// Result of `regenerate_backup_codes`.
#[allow(dead_code)] pub enum RegenOutcome {
    /// A fresh code set replaced every existing code row.
    Regenerated {
        /// The new codes in display form; only their argon2 hashes are persisted.
        plain_backup_codes: Vec<String>,
        /// How many code rows (used or unused) were deleted.
        previous_codes_invalidated: u32,
    },
    /// The user does not have MFA enabled; nothing changed.
    NotEnrolled,
}
/// Replaces all of the user's backup codes with `BACKUP_CODE_COUNT` fresh ones.
///
/// The new codes are generated and argon2-hashed before the transaction starts.
/// Inside a single transaction:
/// - the user row is locked with `FOR UPDATE` and enrolment is checked;
/// - every existing code row, used or not, is counted and deleted;
/// - the new hashes are inserted.
///
/// A `BackupCodesRegenerated` audit entry is written after the commit.
///
/// # Errors
/// - `Error::NotFound` for an unknown user.
/// - `Error::Internal` if hashing fails.
/// - Database errors.
/// - Audit write errors, by which point the new codes are committed.
#[allow(dead_code)] pub async fn regenerate_backup_codes(
    db: &Db,
    request: &Request,
    user_id: i64,
    correlation_id: Option<&str>,
) -> Result<RegenOutcome> {
    // Hash up front, outside the transaction, so the row lock is not held during argon2.
    let plain_codes = generate_backup_codes(BACKUP_CODE_COUNT);
    let hashes: Vec<String> = plain_codes
        .iter()
        .map(|c| {
            let normalised = normalise_backup_code(c);
            hash_backup_code(&normalised)
        })
        .collect::<Result<Vec<String>>>()?;
    let mut tx = db.pool().begin().await?;
    // Lock the user row for the rest of the transaction.
    let mfa_enabled: Option<bool> =
        sqlx::query_scalar("SELECT mfa_enabled FROM rustio_users WHERE id = $1 FOR UPDATE")
            .bind(user_id)
            .fetch_optional(&mut *tx)
            .await?;
    let mfa_enabled =
        mfa_enabled.ok_or_else(|| Error::NotFound(format!("user {user_id} not found")))?;
    if !mfa_enabled {
        // Returning drops `tx` without committing; nothing is written.
        return Ok(RegenOutcome::NotEnrolled);
    }
    let previous_count: i64 =
        sqlx::query_scalar("SELECT COUNT(*) FROM rustio_mfa_backup_codes WHERE user_id = $1")
            .bind(user_id)
            .fetch_one(&mut *tx)
            .await?;
    let previous_count = previous_count.max(0) as u32;
    sqlx::query("DELETE FROM rustio_mfa_backup_codes WHERE user_id = $1")
        .bind(user_id)
        .execute(&mut *tx)
        .await?;
    for hash in &hashes {
        sqlx::query("INSERT INTO rustio_mfa_backup_codes (user_id, code_hash) VALUES ($1, $2)")
            .bind(user_id)
            .bind(hash)
            .execute(&mut *tx)
            .await?;
    }
    tx.commit().await?;
    let metadata = serde_json::json!({
        "previous_codes_invalidated": previous_count,
        "new_codes_count": BACKUP_CODE_COUNT,
    });
    let ip = client_ip(request);
    let mut entry = LogEntry::new(user_id, ActionType::Update, "users", user_id)
        .with_event(AuditEvent::BackupCodesRegenerated)
        .with_actor(user_id);
    entry.correlation_id = correlation_id;
    entry.ip_address = ip.as_deref();
    entry.metadata = Some(metadata);
    entry.summary = format!(
        "backup codes regenerated; {previous_count} previous invalidated; \
         {BACKUP_CODE_COUNT} new codes issued"
    );
    audit_record(db, entry).await?;
    Ok(RegenOutcome::Regenerated {
        plain_backup_codes: plain_codes,
        previous_codes_invalidated: previous_count,
    })
}
/// Replaces session `current_session_id` with a new `mfa_verified` session for
/// `user_id`, and returns the new session's token.
///
/// The new row records `current_session_id` as its `parent_session_id` and expires
/// `MFA_VERIFIED_SESSION_DAYS` days from now. The old session is then invalidated
/// with `TrustEscalation`, so the caller must hand the returned token to the client.
///
/// NOTE(review): the raw `token` is written to the `token` column alongside
/// `token_hash`. If the schema allows it, storing only the hash would avoid
/// keeping usable tokens at rest.
///
/// NOTE(review): the insert and the invalidation are separate statements. If
/// invalidation fails, both sessions stay live. Confirm intent.
///
/// # Errors
/// Database errors from the insert or from `invalidate_sessions`.
#[allow(dead_code)]
pub async fn promote_session_to_mfa_verified(
    db: &Db,
    current_session_id: i64,
    user_id: i64,
) -> Result<String> {
    const INSERT_ELEVATED_SESSION: &str = "INSERT INTO rustio_sessions \
         (token, token_hash, user_id, expires_at, trust_level, parent_session_id) \
         VALUES ($1, $2, $3, $4, 'mfa_verified', $5)";

    let token = random_token();
    let token_hash = hash_token_for_storage(&token);
    let expires_at = Utc::now() + ChronoDuration::days(MFA_VERIFIED_SESSION_DAYS);
    sqlx::query(INSERT_ELEVATED_SESSION)
        .bind(&token)
        .bind(&token_hash)
        .bind(user_id)
        .bind(expires_at)
        .bind(current_session_id)
        .execute(db.pool())
        .await?;

    // Only after the replacement exists is the lower-trust parent session retired.
    let superseded = SessionTarget::Single {
        session_id: current_session_id,
    };
    invalidate_sessions(db, superseded, SessionInvalidationReason::TrustEscalation).await?;
    Ok(token)
}
/// Marks session `session_id` as `mfa_verified` and elevated until `ttl` from now
/// (measured on the database clock via `NOW()`), provided it is not revoked.
///
/// A revoked or unknown session is left untouched, and the call still returns
/// `Ok(())`. `ttl` is truncated to whole seconds.
///
/// # Errors
/// Database errors from the update.
#[allow(dead_code)]
pub(crate) async fn promote_session_mfa_elevated(
    db: &Db,
    session_id: i64,
    ttl: ChronoDuration,
) -> Result<()> {
    let ttl_seconds: i64 = ttl.num_seconds();
    let elevate = sqlx::query(
        "UPDATE rustio_sessions \
         SET elevated_until = NOW() + (INTERVAL '1 second' * $2::bigint), \
         trust_level = 'mfa_verified' \
         WHERE session_id = $1 AND revoked_at IS NULL",
    )
    .bind(session_id)
    .bind(ttl_seconds);
    elevate.execute(db.pool()).await?;
    Ok(())
}
/// How strictly multi-factor authentication is enforced.
///
/// Defaults to `Optional`. Enforcement is not implemented in this file. The
/// per-variant notes below follow the variant names and should be confirmed
/// against the callers.
#[derive(Debug, Clone, Copy, Default)]
pub enum MfaPolicy {
    /// MFA is off. `MfaKey::from_env`'s error message states that
    /// `RUSTIO_SECRET_KEY` is required whenever the policy is not `Disabled`.
    Disabled,
    /// Presumably: users may enrol but are not forced to.
    #[default]
    Optional,
    /// Presumably: every user must use MFA.
    Required,
    /// Presumably: only users holding one of these roles must use MFA.
    RequiredForRoles(&'static [Role]),
}
/// Idempotently adds MFA storage to the schema. It creates:
/// - four `mfa_*` columns on `rustio_users`;
/// - the `rustio_mfa_backup_codes` table, whose rows cascade-delete with their
///   user;
/// - a partial index over each user's unused codes.
///
/// Statements run in order, each as its own statement on the pool, with no
/// wrapping transaction. The first failure is returned and later statements are
/// skipped.
///
/// # Errors
/// Database errors from any DDL statement.
pub(crate) async fn migrate_user_mfa_schema(db: &Db) -> Result<()> {
    const STATEMENTS: [&str; 6] = [
        "ALTER TABLE rustio_users \
         ADD COLUMN IF NOT EXISTS mfa_enabled BOOLEAN NOT NULL DEFAULT FALSE",
        "ALTER TABLE rustio_users ADD COLUMN IF NOT EXISTS mfa_secret_ciphertext BYTEA",
        "ALTER TABLE rustio_users ADD COLUMN IF NOT EXISTS mfa_secret_key_id INT",
        "ALTER TABLE rustio_users ADD COLUMN IF NOT EXISTS mfa_last_used_step BIGINT",
        "CREATE TABLE IF NOT EXISTS rustio_mfa_backup_codes ( \
         id BIGSERIAL PRIMARY KEY, \
         user_id BIGINT NOT NULL REFERENCES rustio_users(id) ON DELETE CASCADE, \
         code_hash TEXT NOT NULL, \
         created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), \
         used_at TIMESTAMPTZ \
         )",
        "CREATE INDEX IF NOT EXISTS rustio_mfa_backup_codes_user_unused_idx \
         ON rustio_mfa_backup_codes (user_id) \
         WHERE used_at IS NULL",
    ];
    for statement in STATEMENTS {
        sqlx::query(statement).execute(db.pool()).await?;
    }
    Ok(())
}
// Unit tests for the pure helpers in this module. These cover policy defaults,
// AES-GCM wrapping, backup-code generation and hashing, TOTP (RFC 6238 vectors)
// and base32. The async DB functions are not exercised here.
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn default_is_optional() {
        assert!(matches!(MfaPolicy::default(), MfaPolicy::Optional));
    }
    #[test]
    fn policy_is_copy() {
        const ROLES: &[Role] = &[Role::Administrator];
        let original = MfaPolicy::RequiredForRoles(ROLES);
        // `original` must remain usable after this copy, which requires `Copy`.
        let copy = original;
        assert!(matches!(original, MfaPolicy::RequiredForRoles(_)));
        assert!(matches!(copy, MfaPolicy::RequiredForRoles(_)));
    }
    // Deterministic, non-trivial 32-byte key so the tests do not depend on env vars.
    fn fixed_test_key() -> MfaKey {
        let mut bytes = [0u8; 32];
        for (i, b) in bytes.iter_mut().enumerate() {
            *b = (i as u8).wrapping_mul(7).wrapping_add(13);
        }
        MfaKey::from_bytes(bytes)
    }
    #[test]
    fn wrap_unwrap_round_trip_recovers_plaintext() {
        let key = fixed_test_key();
        let plaintext = b"hello-mfa-secret-20-bytes";
        let ciphertext = wrap_secret(plaintext, &key);
        // Layout: 12-byte nonce + ciphertext (same length as plaintext) + 16-byte tag.
        assert_eq!(ciphertext.len(), 12 + plaintext.len() + 16);
        let recovered = unwrap_secret(&ciphertext, &key).expect("round-trip must decrypt");
        assert_eq!(recovered, plaintext);
    }
    #[test]
    fn wrap_uses_fresh_nonce_per_call() {
        let key = fixed_test_key();
        let plaintext = b"identical-plaintext";
        let a = wrap_secret(plaintext, &key);
        let b = wrap_secret(plaintext, &key);
        assert_ne!(a, b, "fresh nonce per call must yield different ciphertext");
    }
    #[test]
    fn tampered_ciphertext_fails_aead_verification() {
        let key = fixed_test_key();
        let plaintext = b"sensitive-mfa-secret";
        let mut ciphertext = wrap_secret(plaintext, &key);
        // Byte 20 lies in the ciphertext body, just past the 12-byte nonce.
        ciphertext[20] ^= 0x01;
        let result = unwrap_secret(&ciphertext, &key);
        assert!(
            result.is_err(),
            "tampered ciphertext must fail AEAD verification"
        );
    }
    #[test]
    fn wrong_key_fails_decryption() {
        let key_enc = fixed_test_key();
        let key_dec = MfaKey::from_bytes([0xFFu8; 32]);
        let plaintext = b"wrong-key-test";
        let ciphertext = wrap_secret(plaintext, &key_enc);
        let result = unwrap_secret(&ciphertext, &key_dec);
        assert!(result.is_err(), "decrypt with wrong key must fail");
    }
    #[test]
    fn truncated_input_rejects_explicitly() {
        let key = fixed_test_key();
        // One byte short of nonce (12) + tag (16).
        let too_short = [0u8; 27];
        let result = unwrap_secret(&too_short, &key);
        assert!(result.is_err(), "input below 28 bytes must reject");
    }
    #[test]
    fn alphabet_is_31_chars_no_ambiguous() {
        assert_eq!(BACKUP_CODE_ALPHABET.len(), 31);
        for &b in BACKUP_CODE_ALPHABET {
            let c = b as char;
            assert!(c.is_ascii_alphanumeric(), "non-alphanumeric: {c:?}");
            assert!(
                !matches!(c, '0' | 'O' | '1' | 'I' | 'L'),
                "ambiguous char in alphabet: {c:?}"
            );
        }
    }
    #[test]
    fn generate_returns_count_codes() {
        let codes = generate_backup_codes(BACKUP_CODE_COUNT);
        assert_eq!(codes.len(), BACKUP_CODE_COUNT);
    }
    #[test]
    fn each_code_is_xxxx_dash_xxxx_shape() {
        let codes = generate_backup_codes(8);
        for code in &codes {
            // BACKUP_CODE_LEN alphabet characters plus the single display hyphen.
            assert_eq!(code.len(), BACKUP_CODE_LEN + 1, "wrong length: {code:?}");
            assert_eq!(
                code.chars().nth(4),
                Some('-'),
                "hyphen missing at position 4: {code:?}"
            );
            for (i, c) in code.chars().enumerate() {
                if i == 4 {
                    continue;
                }
                assert!(
                    BACKUP_CODE_ALPHABET.contains(&(c as u8)),
                    "char {c:?} at position {i} not in alphabet"
                );
            }
        }
    }
    #[test]
    fn generated_codes_are_unique_within_batch() {
        // 31^8 possible codes makes a collision among 64 vanishingly unlikely.
        let codes = generate_backup_codes(64);
        let unique: std::collections::HashSet<_> = codes.iter().cloned().collect();
        assert_eq!(unique.len(), 64, "batch contained duplicates");
    }
    #[test]
    fn normalise_strips_hyphens_and_uppercases() {
        assert_eq!(normalise_backup_code("ABCD-EFGH"), "ABCDEFGH");
        assert_eq!(normalise_backup_code("abcd-efgh"), "ABCDEFGH");
        assert_eq!(normalise_backup_code("AbCdEfGh"), "ABCDEFGH");
        assert_eq!(normalise_backup_code(" abcd efgh "), "ABCDEFGH");
        assert_eq!(normalise_backup_code("abcdefgh"), "ABCDEFGH");
    }
    #[test]
    fn normalise_is_idempotent() {
        let once = normalise_backup_code("xxxx-yyyy");
        let twice = normalise_backup_code(&once);
        assert_eq!(once, twice);
    }
    #[test]
    fn hash_verify_round_trip() {
        let code = "ABCDEFGH";
        let hash = hash_backup_code(code).expect("hashing must succeed");
        assert!(verify_backup_code(code, &hash), "round-trip must verify");
    }
    #[test]
    fn hash_uses_argon2id_low_memory_params() {
        let hash = hash_backup_code("ABCDEFGH").expect("hash succeeds");
        // The PHC string encodes the algorithm and cost parameters used.
        assert!(hash.starts_with("$argon2id$"), "wrong algorithm: {hash}");
        assert!(
            hash.contains("m=16384,t=2,p=1"),
            "params drifted from locked m=16MB/t=2/p=1: {hash}"
        );
    }
    #[test]
    fn verify_rejects_wrong_code() {
        let hash = hash_backup_code("ABCDEFGH").expect("hash succeeds");
        assert!(!verify_backup_code("WRONGCDE", &hash));
    }
    #[test]
    fn verify_rejects_invalid_phc_string() {
        assert!(!verify_backup_code("ABCDEFGH", "not-a-phc-hash"));
        assert!(!verify_backup_code("ABCDEFGH", ""));
    }
    #[test]
    fn separate_hash_calls_yield_different_phc_strings() {
        let a = hash_backup_code("ABCDEFGH").expect("a");
        let b = hash_backup_code("ABCDEFGH").expect("b");
        assert_ne!(a, b, "fresh salt must produce different hashes");
        assert!(verify_backup_code("ABCDEFGH", &a));
        assert!(verify_backup_code("ABCDEFGH", &b));
    }
    // RFC 6238 Appendix B seed for the SHA1 test vectors (20 ASCII bytes).
    const RFC6238_SECRET: &[u8] = b"12345678901234567890";
    #[test]
    fn current_step_at_canonical_30s_interval() {
        assert_eq!(current_step(0, 30), 0);
        assert_eq!(current_step(29, 30), 0);
        assert_eq!(current_step(30, 30), 1);
        assert_eq!(current_step(59, 30), 1);
        assert_eq!(current_step(60, 30), 2);
    }
    #[test]
    fn rfc6238_appendix_b_test_vectors_truncated_to_6_digits() {
        // The RFC publishes 8-digit SHA1 values; these are their last 6 digits
        // (value mod 10^6), which is what `generate_totp` returns.
        let cases: &[(u64, u32)] = &[
            (59, 287_082),
            (1_111_111_109, 81_804),
            (1_111_111_111, 50_471),
            (1_234_567_890, 5_924),
            (2_000_000_000, 279_037),
            (20_000_000_000, 353_130),
        ];
        for &(t, expected) in cases {
            let step = current_step(t, 30);
            let got = generate_totp(RFC6238_SECRET, step);
            assert_eq!(got, expected, "RFC 6238 vector at T={t} mismatched");
        }
    }
    #[test]
    fn generate_totp_returns_six_digit_range() {
        for step in [0u64, 1, 100, 12_345, u64::MAX] {
            let code = generate_totp(RFC6238_SECRET, step);
            assert!(
                code < 1_000_000,
                "code out of range for step {step}: {code}"
            );
        }
    }
    #[test]
    fn verify_accepts_current_step() {
        let t = 1_111_111_111u64;
        let step = current_step(t, 30);
        let code = generate_totp(RFC6238_SECRET, step);
        assert_eq!(verify_totp(RFC6238_SECRET, code, t, 30, 1), Some(step));
    }
    #[test]
    fn verify_accepts_one_step_skew() {
        let t_gen = 1_111_111_111u64;
        let step_gen = current_step(t_gen, 30);
        let code = generate_totp(RFC6238_SECRET, step_gen);
        // Verify one full step later; the ±1 window must still find `step_gen`.
        let t_verify = t_gen + 30;
        let result = verify_totp(RFC6238_SECRET, code, t_verify, 30, 1);
        assert_eq!(result, Some(step_gen), "skew ±1 must accept previous step");
    }
    #[test]
    fn verify_rejects_two_step_skew_when_window_is_one() {
        let t_gen = 1_111_111_111u64;
        let step_gen = current_step(t_gen, 30);
        let code = generate_totp(RFC6238_SECRET, step_gen);
        // Two steps later is outside a ±1 window.
        let t_verify = t_gen + 60;
        let result = verify_totp(RFC6238_SECRET, code, t_verify, 30, 1);
        assert_eq!(result, None, "skew=1 must reject two-step drift");
    }
    #[test]
    fn totp_verify_rejects_wrong_code() {
        let t = 1_111_111_111u64;
        let result = verify_totp(RFC6238_SECRET, 999_999, t, 30, 1);
        assert_eq!(result, None);
    }
    #[test]
    fn verify_does_not_underflow_at_t_zero() {
        // At step 0 the -1 offset clamps to 0 rather than wrapping.
        let code = generate_totp(RFC6238_SECRET, 0);
        let result = verify_totp(RFC6238_SECRET, code, 0, 30, 1);
        assert_eq!(result, Some(0));
    }
    #[test]
    fn base32_rfc4648_test_vector_foobar() {
        assert_eq!(base32_encode_no_pad(b"foobar"), "MZXW6YTBOI");
    }
    #[test]
    fn base32_rfc4648_progressive_test_vectors() {
        // RFC 4648 section 10 vectors with the trailing `=` padding removed.
        assert_eq!(base32_encode_no_pad(b"f"), "MY");
        assert_eq!(base32_encode_no_pad(b"fo"), "MZXQ");
        assert_eq!(base32_encode_no_pad(b"foo"), "MZXW6");
        assert_eq!(base32_encode_no_pad(b"foob"), "MZXW6YQ");
        assert_eq!(base32_encode_no_pad(b"fooba"), "MZXW6YTB");
    }
    #[test]
    fn provision_secret_returns_20_bytes() {
        let secret = provision_secret();
        assert_eq!(
            secret.secret_bytes.len(),
            20,
            "RFC 6238 default + universal authenticator-app interop"
        );
    }
    #[test]
    fn provision_secret_base32_length_matches_secret() {
        let secret = provision_secret();
        // 20 bytes = 160 bits = exactly 32 five-bit symbols, so no partial group.
        assert_eq!(secret.base32.len(), 32);
        for c in secret.base32.chars() {
            assert!(
                c.is_ascii_uppercase() || ('2'..='7').contains(&c),
                "non-base32 char in encoding: {c:?}"
            );
        }
    }
    #[test]
    fn provision_secret_each_call_yields_different_secret() {
        let mut seen = std::collections::HashSet::new();
        for _ in 0..16 {
            let secret = provision_secret();
            assert!(seen.insert(secret.secret_bytes), "RNG produced duplicate");
        }
    }
    #[test]
    fn build_otpauth_url_matches_google_authenticator_format() {
        let url = build_otpauth_url("Acme Corp", "alice@example.com", "MZXW6YTBOI", 30);
        // Space and `@` are percent-encoded in the label; the `:` separator is not.
        assert!(
            url.starts_with("otpauth://totp/Acme%20Corp:alice%40example.com?"),
            "wrong path encoding: {url}"
        );
        assert!(url.contains("secret=MZXW6YTBOI"), "secret missing: {url}");
        assert!(
            url.contains("issuer=Acme%20Corp"),
            "issuer query missing: {url}"
        );
        assert!(url.contains("algorithm=SHA1"), "algorithm missing: {url}");
        assert!(url.contains("digits=6"), "digits missing: {url}");
        assert!(url.contains("period=30"), "period missing: {url}");
    }
    #[test]
    fn base32_decode_rfc4648_round_trips_progressive_vectors() {
        let cases: &[(&str, &[u8])] = &[
            ("MY", b"f"),
            ("MZXQ", b"fo"),
            ("MZXW6", b"foo"),
            ("MZXW6YQ", b"foob"),
            ("MZXW6YTB", b"fooba"),
            ("MZXW6YTBOI", b"foobar"),
        ];
        for &(encoded, expected) in cases {
            let decoded =
                base32_decode_no_pad(encoded).unwrap_or_else(|| panic!("decode failed: {encoded}"));
            assert_eq!(decoded.as_slice(), expected, "round-trip {encoded}");
        }
    }
    #[test]
    fn base32_decode_tolerates_hyphens_spaces_padding_and_lowercase() {
        for variant in [
            "MZXW6YTBOI",
            "mzxw6ytboi",
            "MZXW 6YTB OI",
            "MZXW-6YTB-OI",
            "MZXW6YTBOI==",
        ] {
            assert_eq!(
                base32_decode_no_pad(variant).expect("decode should succeed"),
                b"foobar",
                "variant: {variant:?}"
            );
        }
    }
    #[test]
    fn base32_decode_rejects_non_alphabet_chars() {
        // 0, 1, 8 and 9 are not base32 digits (only 2-7 are).
        assert!(base32_decode_no_pad("ABC0DEF").is_none());
        assert!(base32_decode_no_pad("ABC1DEF").is_none());
        assert!(base32_decode_no_pad("ABC8DEF").is_none());
        assert!(base32_decode_no_pad("ABC9DEF").is_none());
        assert!(base32_decode_no_pad("hello!").is_none());
    }
}