use std::time::{SystemTime, UNIX_EPOCH};
use anyhow::{anyhow, Context, Result};
use argon2::password_hash::{rand_core::OsRng as ArgonOsRng, PasswordHasher, PasswordVerifier, SaltString};
use argon2::{Algorithm, Argon2, Params, PasswordHash, Version};
use base64::{
engine::general_purpose::{STANDARD as B64, URL_SAFE_NO_PAD as B64URL},
Engine,
};
use chrono::{DateTime, Duration, Utc};
use rand::{rngs::OsRng as RandOsRng, RngCore};
use rusqlite::{params, Connection, OptionalExtension};
use sha2::{Digest, Sha256};
use thiserror::Error;
use totp_rs::{Algorithm as TotpAlg, TOTP};
use dragoon_proto::{constants, verify::verify_ssh_wire_signature};
/// Builds the Argon2id (v0x13) hasher used for new password hashes.
///
/// Cost settings: `m_cost` 65 536 (KiB), `t_cost` 3 iterations, `p_cost` 4 lanes, and the
/// library's default output length.
fn argon2() -> Argon2<'static> {
    const MEMORY_KIB: u32 = 65_536;
    const ITERATIONS: u32 = 3;
    const LANES: u32 = 4;
    let cost = Params::new(MEMORY_KIB, ITERATIONS, LANES, None).expect("argon2 params are valid");
    Argon2::new(Algorithm::Argon2id, Version::V0x13, cost)
}
/// Hashes `plain` with a fresh random salt and returns the PHC-format string
/// (`$argon2id$...`), suitable for storage and later [`verify_password`].
///
/// # Errors
/// Returns an error if the Argon2 hasher rejects the input.
pub fn hash_password(plain: &str) -> Result<String> {
    let salt = SaltString::generate(&mut ArgonOsRng);
    argon2()
        .hash_password(plain.as_bytes(), &salt)
        .map(|phc| phc.to_string())
        .map_err(|e| anyhow!("argon2 hash: {e}"))
}
pub fn verify_password(plain: &str, hashed: &str) -> bool {
let Ok(parsed) = PasswordHash::new(hashed) else {
return false;
};
Argon2::default()
.verify_password(plain.as_bytes(), &parsed)
.is_ok()
}
/// Generates a new TOTP secret: 20 bytes (160 bits) from the OS RNG, encoded as unpadded
/// RFC 4648 base32.
pub fn generate_totp_secret() -> String {
    const SECRET_LEN: usize = 20;
    let mut rng = RandOsRng;
    let mut secret = [0u8; SECRET_LEN];
    rng.fill_bytes(&mut secret);
    base32_encode_no_pad(&secret)
}
/// Verifies a 6-digit, SHA-1, 30-second TOTP `code` for the base32 secret `secret_base32`.
///
/// The current time step and one step on either side (±30 s) are accepted to tolerate clock
/// drift. Returns `false` if the secret is not valid base32, is rejected by `TOTP::new`, or the
/// system clock is before the Unix epoch.
pub fn verify_totp(secret_base32: &str, code: &str) -> bool {
    let Some(secret_bytes) = base32_decode(secret_base32) else {
        return false;
    };
    let Ok(totp) = TOTP::new(TotpAlg::SHA1, 6, 1, 30, secret_bytes) else {
        return false;
    };
    let Ok(now) = SystemTime::now().duration_since(UNIX_EPOCH) else {
        return false;
    };
    let now = now.as_secs();
    let mut matched = false;
    for offset in [-1i64, 0, 1] {
        let t = (now as i64 + offset * 30).max(0) as u64;
        // Compare every window in constant time (no early exit on a mismatching byte, and no
        // early exit on a match) so response timing does not leak how close a guess was.
        matched |= constant_time_eq(totp.generate(t).as_bytes(), code.as_bytes());
    }
    matched
}

/// Byte-string equality that does not short-circuit on the first differing byte.
/// Lengths are compared directly; only the content comparison is constant-time.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    a.iter().zip(b).fold(0u8, |diff, (x, y)| diff | (x ^ y)) == 0
}
const B32: &[u8; 32] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567";
/// Encodes `input` as RFC 4648 base32 (uppercase alphabet) without `=` padding.
///
/// A trailing partial 5-bit group is zero-filled on the right, as in the RFC.
fn base32_encode_no_pad(input: &[u8]) -> String {
    // Significant output symbols for a group of 0..=5 input bytes (ceil(len * 8 / 5)).
    const SYMBOLS_PER_GROUP: [usize; 6] = [0, 2, 4, 5, 7, 8];
    let mut encoded = String::with_capacity(input.len() / 5 * 8 + 8);
    for group in input.chunks(5) {
        // Left-align the group in a 40-bit window; absent trailing bytes stay zero.
        let window = group
            .iter()
            .enumerate()
            .fold(0u64, |acc, (i, &byte)| acc | (u64::from(byte) << (32 - 8 * i)));
        for k in 0..SYMBOLS_PER_GROUP[group.len()] {
            let idx = ((window >> (35 - 5 * k)) & 0x1f) as usize;
            encoded.push(char::from(B32[idx]));
        }
    }
    encoded
}
/// Decodes RFC 4648 base32 text, accepting upper- or lowercase letters.
///
/// Decoding stops at the first `=` (anything after it is ignored). Returns `None` if a
/// character before that point is outside the base32 alphabet. Leftover bits that do not fill
/// a whole byte are discarded.
fn base32_decode(s: &str) -> Option<Vec<u8>> {
    let mut decoded = Vec::with_capacity(s.len() * 5 / 8);
    let mut acc: u32 = 0;
    let mut pending = 0u32;
    for ch in s.chars().take_while(|&ch| ch != '=') {
        let value = match ch {
            'A'..='Z' => u32::from(ch) - u32::from('A'),
            'a'..='z' => u32::from(ch) - u32::from('a'),
            '2'..='7' => u32::from(ch) - u32::from('2') + 26,
            _ => return None,
        };
        // Only the low `pending + 5` bits are ever read, so bits shifted out are irrelevant.
        acc = (acc << 5) | value;
        pending += 5;
        if pending >= 8 {
            pending -= 8;
            decoded.push((acc >> pending) as u8);
        }
    }
    Some(decoded)
}
/// Returns `byte_len` bytes from the OS RNG encoded as unpadded URL-safe base64.
fn token_urlsafe(byte_len: usize) -> String {
    let mut rng = RandOsRng;
    let mut raw = vec![0u8; byte_len];
    rng.fill_bytes(raw.as_mut_slice());
    B64URL.encode(raw)
}
/// Lowercase hex encoding of the SHA-256 digest of `input`'s UTF-8 bytes.
fn sha256_hex(input: &str) -> String {
    let mut hasher = Sha256::new();
    hasher.update(input.as_bytes());
    hex::encode(hasher.finalize())
}
/// Generates `n` random recovery codes.
///
/// Returns `(plain, hashes)`. `plain[i]` is a URL-safe token built from 10 random bytes, meant
/// to be shown to the user once. `hashes[i]` is its SHA-256 hex digest, meant for storage.
pub fn generate_recovery_codes(n: usize) -> (Vec<String>, Vec<String>) {
    let mut plain = Vec::with_capacity(n);
    let mut hashes = Vec::with_capacity(n);
    for _ in 0..n {
        let code = token_urlsafe(10);
        hashes.push(sha256_hex(&code));
        plain.push(code);
    }
    (plain, hashes)
}
/// Attempts to redeem recovery `code` against the stored SHA-256 `hashes`.
///
/// Returns `(true, remaining)` when the code's hash is present. In that case `remaining` omits
/// every entry equal to that hash. Otherwise returns `(false, hashes.to_vec())`, i.e. unchanged.
pub fn consume_recovery_code(code: &str, hashes: &[String]) -> (bool, Vec<String>) {
    let needle = sha256_hex(code);
    if hashes.contains(&needle) {
        let remaining = hashes.iter().filter(|stored| **stored != needle).cloned().collect();
        (true, remaining)
    } else {
        (false, hashes.to_vec())
    }
}
/// A live (unrevoked, unexpired) session, as returned by [`lookup_session`].
#[derive(Debug, Clone)]
pub struct Session {
    /// Owning user's id, read from the `sessions.user_id` column.
    pub user_id: i64,
    /// Key fingerprint bound to the session when it was issued; `verify_signed_request`
    /// rejects requests signed with any other fingerprint.
    pub fingerprint: String,
    /// Instant at (or after) which `lookup_session` stops returning this session.
    pub expires_at: DateTime<Utc>,
}
fn iso_now() -> String {
Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Micros, true)
}
fn iso(dt: DateTime<Utc>) -> String {
dt.to_rfc3339_opts(chrono::SecondsFormat::Micros, true)
}
/// Parses an RFC 3339 timestamp (e.g. one produced by `iso`) into a UTC `DateTime`.
///
/// # Errors
/// Returns an error whose context quotes the original input when `s` is not valid RFC 3339.
fn parse_iso(s: &str) -> Result<DateTime<Utc>> {
    // `parse_from_rfc3339` accepts a trailing `Z` natively. Rewriting it to `+00:00` was
    // unnecessary, allocated, and made error messages show a string the caller never passed.
    let parsed =
        DateTime::parse_from_rfc3339(s).with_context(|| format!("parse rfc3339: {s}"))?;
    Ok(parsed.with_timezone(&Utc))
}
fn hash_token(token: &str) -> String {
sha256_hex(token)
}
/// Creates a session for `user_id`, bound to key `fingerprint` and valid for `ttl`.
///
/// Returns the raw bearer token (32 random bytes, URL-safe base64) and its expiry. Only the
/// token's hash is written to the `sessions` table.
///
/// # Errors
/// Propagates any database error from the insert.
pub fn issue_session(
    conn: &Connection,
    user_id: i64,
    fingerprint: &str,
    ttl: Duration,
) -> Result<(String, DateTime<Utc>)> {
    let issued_at = Utc::now();
    let expires_at = issued_at + ttl;
    let token = token_urlsafe(32);
    conn.execute(
        "INSERT INTO sessions (token_hash, user_id, fingerprint, created_at, expires_at)\nVALUES (?,?,?,?,?)",
        params![hash_token(&token), user_id, fingerprint, iso(issued_at), iso(expires_at)],
    )?;
    Ok((token, expires_at))
}
/// Resolves a bearer session token to its live [`Session`].
///
/// The token is looked up by its hash. Returns `Ok(None)` when no row matches, the session has
/// been revoked, or its expiry is at or before the current time.
///
/// # Errors
/// Propagates database errors. Also fails if the stored `expires_at` of an unrevoked session is
/// not valid RFC 3339.
pub fn lookup_session(conn: &Connection, token: &str) -> Result<Option<Session>> {
    let found: Option<(i64, String, String, Option<String>)> = conn
        .query_row(
            "SELECT user_id, fingerprint, expires_at, revoked_at\nFROM sessions WHERE token_hash=?",
            [hash_token(token)],
            |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?, row.get(3)?)),
        )
        .optional()?;
    match found {
        // Unknown token, or a session that has been explicitly revoked.
        None | Some((_, _, _, Some(_))) => Ok(None),
        Some((user_id, fingerprint, expires_raw, None)) => {
            let expires_at = parse_iso(&expires_raw)?;
            if expires_at > Utc::now() {
                Ok(Some(Session {
                    user_id,
                    fingerprint,
                    expires_at,
                }))
            } else {
                Ok(None)
            }
        }
    }
}
/// Marks the session identified by `token` as revoked.
///
/// Succeeds (as a no-op) if no session matches or the session is already revoked.
///
/// # Errors
/// Propagates any database error from the update.
pub fn revoke_session(conn: &Connection, token: &str) -> Result<()> {
    let h = hash_token(token);
    // Only stamp sessions that are still unrevoked, so a repeated revoke does not overwrite
    // the time of the original revocation (mirrors the `used_at IS NULL` guard used for
    // challenges).
    conn.execute(
        "UPDATE sessions SET revoked_at=? WHERE token_hash=? AND revoked_at IS NULL",
        params![iso_now(), h],
    )?;
    Ok(())
}
/// A freshly issued one-time challenge, as returned by [`issue_challenge`].
#[derive(Debug, Clone)]
pub struct IssuedChallenge {
    /// Random token (16 bytes from the OS RNG, URL-safe base64) stored in the `challenges` table.
    pub challenge: String,
    /// Instant at (or after) which [`consume_challenge`] rejects this challenge.
    pub expires_at: DateTime<Utc>,
}
/// Creates and stores a new unused challenge.
///
/// The challenge expires after `ttl_sec` seconds, or after `constants::CHALLENGE_TTL_SEC` when
/// `ttl_sec` is `None`.
///
/// # Errors
/// Propagates any database error from the insert.
pub fn issue_challenge(conn: &Connection, ttl_sec: Option<i64>) -> Result<IssuedChallenge> {
    let lifetime = Duration::seconds(ttl_sec.unwrap_or(constants::CHALLENGE_TTL_SEC));
    let challenge = token_urlsafe(16);
    let expires_at = Utc::now() + lifetime;
    conn.execute(
        "INSERT INTO challenges (challenge, expires_at, used_at) VALUES (?,?,NULL)",
        params![challenge, iso(expires_at)],
    )?;
    Ok(IssuedChallenge {
        challenge,
        expires_at,
    })
}
/// Redeems `challenge` exactly once.
///
/// Returns `Ok(true)` only if the challenge exists, is unused, is unexpired, and this call is
/// the one that set its `used_at`. Returns `Ok(false)` otherwise.
///
/// # Errors
/// Propagates database errors and fails if the stored `expires_at` is not valid RFC 3339.
pub fn consume_challenge(conn: &Connection, challenge: &str) -> Result<bool> {
    let state: Option<(String, Option<String>)> = conn
        .query_row(
            "SELECT expires_at, used_at FROM challenges WHERE challenge=?",
            [challenge],
            |row| Ok((row.get(0)?, row.get(1)?)),
        )
        .optional()?;
    let expires_raw = match state {
        Some((expires_raw, None)) => expires_raw,
        // Unknown challenge, or already used.
        _ => return Ok(false),
    };
    if parse_iso(&expires_raw)? <= Utc::now() {
        return Ok(false);
    }
    // The `used_at IS NULL` guard makes the claim atomic: if two callers race past the read
    // above, only one update touches the row.
    let updated = conn.execute(
        "UPDATE challenges SET used_at=? WHERE challenge=? AND used_at IS NULL",
        params![iso_now(), challenge],
    )?;
    Ok(updated > 0)
}
pub fn consume_nonce(
conn: &Connection,
user_id: i64,
nonce: &str,
ttl_sec: i64,
) -> Result<bool> {
let expires = Utc::now() + Duration::seconds(ttl_sec);
let r = conn.execute(
"INSERT INTO nonces (user_id, nonce, expires_at) VALUES (?,?,?)",
params![user_id, nonce, iso(expires)],
);
match r {
Ok(_) => Ok(true),
Err(rusqlite::Error::SqliteFailure(err, _))
if err.code == rusqlite::ErrorCode::ConstraintViolation =>
{
Ok(false)
}
Err(e) => Err(e.into()),
}
}
pub fn purge_expired_nonces(conn: &Connection) -> Result<usize> {
Ok(conn.execute(
"DELETE FROM nonces WHERE expires_at<?",
params![iso_now()],
)?)
}
/// Reasons `verify_signed_request` can reject a request.
///
/// Each variant's `Display` text equals the code returned by [`AuthError::reason`].
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum AuthError {
    /// The session token did not resolve to a live session (or the lookup failed).
    #[error("no_session")]
    NoSession,
    /// The request timestamp is too far from the server's notion of "now".
    #[error("clock_skew")]
    ClockSkew,
    /// The signing key's fingerprint differs from the one bound to the session.
    #[error("fp_session_mismatch")]
    FpSessionMismatch,
    /// No public key with this fingerprint is registered for the session's user (or the
    /// lookup failed).
    #[error("unknown_fp")]
    UnknownFingerprint,
    /// The public key exists but has been revoked.
    #[error("revoked_fp")]
    RevokedFingerprint,
    /// The nonce was already recorded for this user (or recording it failed).
    #[error("replay")]
    Replay,
    /// The signature was not valid base64 or did not verify over the canonical request.
    #[error("bad_sig")]
    BadSignature,
}
impl AuthError {
pub fn reason(&self) -> &'static str {
match self {
Self::NoSession => "no_session",
Self::ClockSkew => "clock_skew",
Self::FpSessionMismatch => "fp_session_mismatch",
Self::UnknownFingerprint => "unknown_fp",
Self::RevokedFingerprint => "revoked_fp",
Self::Replay => "replay",
Self::BadSignature => "bad_sig",
}
}
}
/// Authenticates a signed request made under an existing session.
///
/// Checks run in this order, and the first failure is returned:
/// 1. `session_token` resolves to a live session (`NoSession`; lookup errors also map here).
/// 2. `timestamp` is within `constants::TIMESTAMP_SKEW_SEC` seconds of `now`, or of the system
///    clock when `now` is `None` (`ClockSkew`).
/// 3. `key_fingerprint` equals the fingerprint bound to the session (`FpSessionMismatch`).
/// 4. The key is registered for the session's user (`UnknownFingerprint`; query errors also
///    map here) and is not revoked (`RevokedFingerprint`).
/// 5. `signature_b64` is standard base64 and verifies over the canonical string built from
///    `method`, `path`, `timestamp`, `nonce` and `body` (`BadSignature`).
/// 6. `nonce` can be recorded for this user via `consume_nonce`, i.e. it is not a replay
///    (`Replay`; storage errors also map here).
///
/// On success the nonce stays recorded for `constants::NONCE_TTL_SEC` seconds and the session is
/// returned.
// Every argument is a distinct component of the signed request; bundling them would only move
// the list elsewhere.
#[allow(clippy::too_many_arguments)]
pub fn verify_signed_request(
    conn: &Connection,
    session_token: &str,
    method: &str,
    path: &str,
    timestamp: i64,
    nonce: &str,
    key_fingerprint: &str,
    signature_b64: &str,
    body: &[u8],
    now: Option<i64>,
) -> std::result::Result<Session, AuthError> {
    let actual_now = now.unwrap_or_else(|| {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .ok()
            .and_then(|d| i64::try_from(d.as_secs()).ok())
            .unwrap_or(0)
    });
    let sess = lookup_session(conn, session_token)
        .map_err(|_| AuthError::NoSession)?
        .ok_or(AuthError::NoSession)?;

    // `timestamp` is client-controlled. `actual_now - timestamp` can overflow: that panics in
    // debug builds, and in release it can wrap to i64::MIN, whose `.abs()` is still negative and
    // would pass the check. `abs_diff` is exact for every pair of i64 values.
    let within_skew = match u64::try_from(constants::TIMESTAMP_SKEW_SEC) {
        Ok(max_skew) => actual_now.abs_diff(timestamp) <= max_skew,
        // A negative tolerance admits nothing, as the signed comparison did.
        Err(_) => false,
    };
    if !within_skew {
        return Err(AuthError::ClockSkew);
    }

    if sess.fingerprint != key_fingerprint {
        return Err(AuthError::FpSessionMismatch);
    }
    let row: Option<(Vec<u8>, Option<String>)> = conn
        .query_row(
            "SELECT pubkey_blob, revoked_at FROM user_pubkeys\nWHERE user_id=? AND fingerprint=?",
            params![sess.user_id, key_fingerprint],
            |r| Ok((r.get(0)?, r.get(1)?)),
        )
        .optional()
        .map_err(|_| AuthError::UnknownFingerprint)?;
    let Some((pub_blob, revoked_at)) = row else {
        return Err(AuthError::UnknownFingerprint);
    };
    if revoked_at.is_some() {
        return Err(AuthError::RevokedFingerprint);
    }

    let canonical =
        dragoon_proto::canonical::canonical_string(method, path, timestamp, nonce, body);
    let sig_wire = B64.decode(signature_b64).map_err(|_| AuthError::BadSignature)?;
    verify_ssh_wire_signature(&pub_blob, &sig_wire, &canonical)
        .map_err(|_| AuthError::BadSignature)?;

    // Record the nonce only after the signature has verified. Otherwise anyone holding just the
    // session token could write rows into `nonces` or burn nonces with unsigned requests.
    // The insert itself is atomic, so concurrent copies of a valid request still yield exactly
    // one success.
    let fresh = consume_nonce(conn, sess.user_id, nonce, constants::NONCE_TTL_SEC)
        .map_err(|_| AuthError::Replay)?;
    if !fresh {
        return Err(AuthError::Replay);
    }
    Ok(sess)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens a fresh in-memory database with the application schema applied.
    fn fresh() -> Connection {
        let c = crate::db::connect_in_memory().unwrap();
        crate::db::bootstrap(&c).unwrap();
        c
    }

    /// Inserts a throwaway user row named `username` and returns its rowid.
    fn insert_user(c: &Connection, username: &str) -> i64 {
        c.execute(
            "INSERT INTO users (username, password_hash, totp_secret_enc, created_at)\nVALUES (?,?,?,?)",
            params![username, "h", "s", "2026-01-01T00:00:00Z"],
        )
        .unwrap();
        c.last_insert_rowid()
    }

    #[test]
    fn argon2_hash_then_verify() {
        let h = hash_password("hunter2").unwrap();
        assert!(h.starts_with("$argon2id$"));
        assert!(verify_password("hunter2", &h));
        assert!(!verify_password("wrong", &h));
    }

    #[test]
    fn base32_rfc4648_vectors() {
        // RFC 4648 §10 test vectors with the `=` padding stripped.
        let cases = [
            ("", ""),
            ("f", "MY"),
            ("fo", "MZXQ"),
            ("foo", "MZXW6"),
            ("foob", "MZXW6YQ"),
            ("fooba", "MZXW6YTB"),
            ("foobar", "MZXW6YTBOI"),
        ];
        for &(raw, encoded) in cases.iter() {
            assert_eq!(base32_encode_no_pad(raw.as_bytes()), encoded);
            assert_eq!(base32_decode(encoded).as_deref(), Some(raw.as_bytes()));
            assert_eq!(
                base32_decode(&encoded.to_lowercase()).as_deref(),
                Some(raw.as_bytes())
            );
        }
        // Padded input is accepted; characters outside the alphabet are rejected.
        assert_eq!(base32_decode("MY======").as_deref(), Some("f".as_bytes()));
        assert!(base32_decode("MY1").is_none());
    }

    #[test]
    fn totp_round_trip() {
        let s = generate_totp_secret();
        let secret_bytes = base32_decode(&s).expect("generated secret must be valid base32");
        let totp = TOTP::new(TotpAlg::SHA1, 6, 1, 30, secret_bytes).unwrap();
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();
        let code: String = totp.generate(now);
        assert!(verify_totp(&s, &code));
        // A fixed guess like "000000" is occasionally a valid code, which made this flaky.
        // Pick a candidate that is invalid in every window verify_totp could consult; ±2 steps
        // also covers a step boundary passing between here and the call.
        let nearby: Vec<String> = (0..5u64).map(|k| totp.generate(now - 60 + k * 30)).collect();
        let wrong = ["000000", "111111", "222222", "333333", "444444", "555555"]
            .iter()
            .copied()
            .find(|&c| nearby.iter().all(|n| n != c))
            .expect("five windows cannot cover six distinct codes");
        assert!(!verify_totp(&s, wrong));
    }

    #[test]
    fn recovery_codes_consume_once() {
        let (plain, hashes) = generate_recovery_codes(3);
        let (ok, remaining) = consume_recovery_code(&plain[1], &hashes);
        assert!(ok);
        assert_eq!(remaining.len(), 2);
        let (ok2, _) = consume_recovery_code(&plain[1], &remaining);
        assert!(!ok2);
    }

    #[test]
    fn session_round_trip_then_revoke() {
        let c = fresh();
        let uid = insert_user(&c, "alice");
        let (tok, _) = issue_session(&c, uid, "SHA256:fp", Duration::hours(1)).unwrap();
        let sess = lookup_session(&c, &tok).unwrap().unwrap();
        assert_eq!(sess.user_id, uid);
        assert_eq!(sess.fingerprint, "SHA256:fp");
        revoke_session(&c, &tok).unwrap();
        assert!(lookup_session(&c, &tok).unwrap().is_none());
    }

    #[test]
    fn challenge_one_shot() {
        let c = fresh();
        let ch = issue_challenge(&c, None).unwrap();
        assert!(consume_challenge(&c, &ch.challenge).unwrap());
        assert!(!consume_challenge(&c, &ch.challenge).unwrap());
    }

    #[test]
    fn nonce_rejected_on_replay() {
        let c = fresh();
        let uid = insert_user(&c, "u");
        assert!(consume_nonce(&c, uid, "abc", 300).unwrap());
        assert!(!consume_nonce(&c, uid, "abc", 300).unwrap());
    }
}