Skip to main content

dragoon_server/
auth.rs

//! Authentication primitives: password hashing, TOTP, recovery codes, and the
//! session / challenge / nonce lifecycle.
//!
//! The full per-request signature pipeline (`verify_signed_request`) lives
//! near the end of this file, just above the tests; Phase 6 plugs it into an
//! axum middleware. Mirrors `python/.../server/auth.py`.
7
8use std::time::{SystemTime, UNIX_EPOCH};
9
10use anyhow::{anyhow, Context, Result};
11use argon2::password_hash::{rand_core::OsRng as ArgonOsRng, PasswordHasher, PasswordVerifier, SaltString};
12use argon2::{Algorithm, Argon2, Params, PasswordHash, Version};
13use base64::{
14    engine::general_purpose::{STANDARD as B64, URL_SAFE_NO_PAD as B64URL},
15    Engine,
16};
17use chrono::{DateTime, Duration, Utc};
18use rand::{rngs::OsRng as RandOsRng, RngCore};
19use rusqlite::{params, Connection, OptionalExtension};
20use sha2::{Digest, Sha256};
21use thiserror::Error;
22use totp_rs::{Algorithm as TotpAlg, TOTP};
23
24use dragoon_proto::{constants, verify::verify_ssh_wire_signature};
25
26// --------------------------------------------------------------------------
27// Password (argon2id)
28// --------------------------------------------------------------------------
29
30fn argon2() -> Argon2<'static> {
31    // Pinned params (m=65536 KiB, t=3, p=4) for new hashes. Verifying older
32    // hashes uses whatever params the PHC string carries so Python-produced
33    // hashes (which use argon2-cffi defaults) still verify.
34    let params = Params::new(65_536, 3, 4, None).expect("argon2 params are valid");
35    Argon2::new(Algorithm::Argon2id, Version::V0x13, params)
36}
37
38pub fn hash_password(plain: &str) -> Result<String> {
39    let salt = SaltString::generate(&mut ArgonOsRng);
40    let hash = argon2()
41        .hash_password(plain.as_bytes(), &salt)
42        .map_err(|e| anyhow!("argon2 hash: {e}"))?;
43    Ok(hash.to_string())
44}
45
46pub fn verify_password(plain: &str, hashed: &str) -> bool {
47    let Ok(parsed) = PasswordHash::new(hashed) else {
48        return false;
49    };
50    Argon2::default()
51        .verify_password(plain.as_bytes(), &parsed)
52        .is_ok()
53}
54
55// --------------------------------------------------------------------------
56// TOTP (RFC 6238, ±1 step)
57// --------------------------------------------------------------------------
58
59pub fn generate_totp_secret() -> String {
60    // 160-bit base32 secret, matching pyotp.random_base32().
61    let mut bytes = [0u8; 20];
62    RandOsRng.fill_bytes(&mut bytes);
63    base32_encode_no_pad(&bytes)
64}
65
66pub fn verify_totp(secret_base32: &str, code: &str) -> bool {
67    let Some(secret_bytes) = base32_decode(secret_base32) else {
68        return false;
69    };
70    let Ok(totp) = TOTP::new(TotpAlg::SHA1, 6, 1, 30, secret_bytes) else {
71        return false;
72    };
73    let Ok(now) = SystemTime::now().duration_since(UNIX_EPOCH) else {
74        return false;
75    };
76    let now = now.as_secs();
77    // ±1 step window
78    for offset in [-1i64, 0, 1] {
79        let t = (now as i64 + offset * 30).max(0) as u64;
80        let want = totp.generate(t);
81        if want == code {
82            return true;
83        }
84    }
85    false
86}
87
/// RFC 4648 base32 alphabet (upper-case).
const B32: &[u8; 32] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567";

/// Encode `input` as RFC 4648 base32 without `=` padding.
///
/// A trailing partial 5-bit group is zero-filled on the right, so 20 input
/// bytes become exactly 32 characters.
fn base32_encode_no_pad(input: &[u8]) -> String {
    let mut encoded = String::with_capacity(input.len() * 8 / 5 + 1);
    // Emits the symbol for the low 5 bits of `group`.
    let push_symbol = |out: &mut String, group: u32| {
        out.push(char::from(B32[(group & 0x1f) as usize]));
    };
    let mut acc: u32 = 0;
    let mut pending: u32 = 0;
    for &byte in input {
        acc = (acc << 8) | u32::from(byte);
        pending += 8;
        while pending >= 5 {
            pending -= 5;
            push_symbol(&mut encoded, acc >> pending);
        }
    }
    if pending > 0 {
        push_symbol(&mut encoded, acc << (5 - pending));
    }
    encoded
}
109
/// Decode RFC 4648 base32.
///
/// Decoding is case-insensitive and stops at the first `=`; anything after the
/// `=` is ignored. Returns `None` on any other character outside `A-Z`,
/// `a-z`, `2-7`. Leftover bits that do not fill a whole byte are discarded
/// without validation.
fn base32_decode(s: &str) -> Option<Vec<u8>> {
    let mut decoded = Vec::with_capacity(s.len() * 5 / 8);
    let mut acc: u32 = 0;
    let mut pending: u32 = 0;
    for ch in s.chars().take_while(|&ch| ch != '=') {
        let symbol = match ch {
            'A'..='Z' => ch as u32 - 'A' as u32,
            'a'..='z' => ch as u32 - 'a' as u32,
            '2'..='7' => ch as u32 - '2' as u32 + 26,
            _ => return None,
        };
        acc = (acc << 5) | symbol;
        pending += 5;
        if pending >= 8 {
            pending -= 8;
            decoded.push(((acc >> pending) & 0xff) as u8);
        }
    }
    Some(decoded)
}
133
134// --------------------------------------------------------------------------
135// Recovery codes
136// --------------------------------------------------------------------------
137
138fn token_urlsafe(byte_len: usize) -> String {
139    let mut bytes = vec![0u8; byte_len];
140    RandOsRng.fill_bytes(&mut bytes);
141    B64URL.encode(&bytes)
142}
143
144fn sha256_hex(input: &str) -> String {
145    let digest = Sha256::digest(input.as_bytes());
146    hex::encode(digest)
147}
148
149pub fn generate_recovery_codes(n: usize) -> (Vec<String>, Vec<String>) {
150    let plain: Vec<String> = (0..n).map(|_| token_urlsafe(10)).collect();
151    let hashes: Vec<String> = plain.iter().map(|c| sha256_hex(c)).collect();
152    (plain, hashes)
153}
154
155pub fn consume_recovery_code(code: &str, hashes: &[String]) -> (bool, Vec<String>) {
156    let h = sha256_hex(code);
157    if !hashes.iter().any(|x| x == &h) {
158        return (false, hashes.to_vec());
159    }
160    (true, hashes.iter().filter(|x| **x != h).cloned().collect())
161}
162
163// --------------------------------------------------------------------------
164// Sessions
165// --------------------------------------------------------------------------
166
167#[derive(Debug, Clone)]
168pub struct Session {
169    pub user_id: i64,
170    pub fingerprint: String,
171    pub expires_at: DateTime<Utc>,
172}
173
174fn iso_now() -> String {
175    Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Micros, true)
176}
177
178fn iso(dt: DateTime<Utc>) -> String {
179    dt.to_rfc3339_opts(chrono::SecondsFormat::Micros, true)
180}
181
182fn parse_iso(s: &str) -> Result<DateTime<Utc>> {
183    let s = if let Some(stripped) = s.strip_suffix('Z') {
184        format!("{stripped}+00:00")
185    } else {
186        s.to_owned()
187    };
188    Ok(DateTime::parse_from_rfc3339(&s)
189        .with_context(|| format!("parse rfc3339: {s}"))?
190        .with_timezone(&Utc))
191}
192
193fn hash_token(token: &str) -> String {
194    sha256_hex(token)
195}
196
197pub fn issue_session(
198    conn: &Connection,
199    user_id: i64,
200    fingerprint: &str,
201    ttl: Duration,
202) -> Result<(String, DateTime<Utc>)> {
203    let token = token_urlsafe(32);
204    let h = hash_token(&token);
205    let now = Utc::now();
206    let expires = now + ttl;
207    conn.execute(
208        "INSERT INTO sessions (token_hash, user_id, fingerprint, created_at, expires_at)
209         VALUES (?,?,?,?,?)",
210        params![h, user_id, fingerprint, iso(now), iso(expires)],
211    )?;
212    Ok((token, expires))
213}
214
215pub fn lookup_session(conn: &Connection, token: &str) -> Result<Option<Session>> {
216    let h = hash_token(token);
217    let row: Option<(i64, String, String, Option<String>)> = conn
218        .query_row(
219            "SELECT user_id, fingerprint, expires_at, revoked_at
220             FROM sessions WHERE token_hash=?",
221            [h],
222            |r| Ok((r.get(0)?, r.get(1)?, r.get(2)?, r.get(3)?)),
223        )
224        .optional()?;
225    let Some((user_id, fingerprint, expires_at, revoked_at)) = row else {
226        return Ok(None);
227    };
228    if revoked_at.is_some() {
229        return Ok(None);
230    }
231    let expires = parse_iso(&expires_at)?;
232    if expires <= Utc::now() {
233        return Ok(None);
234    }
235    Ok(Some(Session {
236        user_id,
237        fingerprint,
238        expires_at: expires,
239    }))
240}
241
242pub fn revoke_session(conn: &Connection, token: &str) -> Result<()> {
243    let h = hash_token(token);
244    conn.execute(
245        "UPDATE sessions SET revoked_at=? WHERE token_hash=?",
246        params![iso_now(), h],
247    )?;
248    Ok(())
249}
250
251// --------------------------------------------------------------------------
252// Challenges (one-shot per login)
253// --------------------------------------------------------------------------
254
255#[derive(Debug, Clone)]
256pub struct IssuedChallenge {
257    pub challenge: String,
258    pub expires_at: DateTime<Utc>,
259}
260
261pub fn issue_challenge(conn: &Connection, ttl_sec: Option<i64>) -> Result<IssuedChallenge> {
262    let ttl = ttl_sec.unwrap_or(constants::CHALLENGE_TTL_SEC);
263    let challenge = token_urlsafe(16);
264    let expires = Utc::now() + Duration::seconds(ttl);
265    conn.execute(
266        "INSERT INTO challenges (challenge, expires_at, used_at) VALUES (?,?,NULL)",
267        params![challenge, iso(expires)],
268    )?;
269    Ok(IssuedChallenge {
270        challenge,
271        expires_at: expires,
272    })
273}
274
275pub fn consume_challenge(conn: &Connection, challenge: &str) -> Result<bool> {
276    let row: Option<(String, Option<String>)> = conn
277        .query_row(
278            "SELECT expires_at, used_at FROM challenges WHERE challenge=?",
279            [challenge],
280            |r| Ok((r.get(0)?, r.get(1)?)),
281        )
282        .optional()?;
283    let Some((expires_at, used_at)) = row else {
284        return Ok(false);
285    };
286    if used_at.is_some() {
287        return Ok(false);
288    }
289    if parse_iso(&expires_at)? <= Utc::now() {
290        return Ok(false);
291    }
292    let n = conn.execute(
293        "UPDATE challenges SET used_at=? WHERE challenge=? AND used_at IS NULL",
294        params![iso_now(), challenge],
295    )?;
296    Ok(n > 0)
297}
298
299// --------------------------------------------------------------------------
300// Nonces (per-user, replay defense)
301// --------------------------------------------------------------------------
302
303pub fn consume_nonce(
304    conn: &Connection,
305    user_id: i64,
306    nonce: &str,
307    ttl_sec: i64,
308) -> Result<bool> {
309    let expires = Utc::now() + Duration::seconds(ttl_sec);
310    let r = conn.execute(
311        "INSERT INTO nonces (user_id, nonce, expires_at) VALUES (?,?,?)",
312        params![user_id, nonce, iso(expires)],
313    );
314    match r {
315        Ok(_) => Ok(true),
316        Err(rusqlite::Error::SqliteFailure(err, _))
317            if err.code == rusqlite::ErrorCode::ConstraintViolation =>
318        {
319            Ok(false)
320        }
321        Err(e) => Err(e.into()),
322    }
323}
324
325pub fn purge_expired_nonces(conn: &Connection) -> Result<usize> {
326    Ok(conn.execute(
327        "DELETE FROM nonces WHERE expires_at<?",
328        params![iso_now()],
329    )?)
330}
331
332// --------------------------------------------------------------------------
333// Per-request signature verification (design §7.2)
334// --------------------------------------------------------------------------
335
336#[derive(Debug, Error, Clone, PartialEq, Eq)]
337pub enum AuthError {
338    #[error("no_session")]
339    NoSession,
340    #[error("clock_skew")]
341    ClockSkew,
342    #[error("fp_session_mismatch")]
343    FpSessionMismatch,
344    #[error("unknown_fp")]
345    UnknownFingerprint,
346    #[error("revoked_fp")]
347    RevokedFingerprint,
348    #[error("replay")]
349    Replay,
350    #[error("bad_sig")]
351    BadSignature,
352}
353
354impl AuthError {
355    pub fn reason(&self) -> &'static str {
356        match self {
357            Self::NoSession => "no_session",
358            Self::ClockSkew => "clock_skew",
359            Self::FpSessionMismatch => "fp_session_mismatch",
360            Self::UnknownFingerprint => "unknown_fp",
361            Self::RevokedFingerprint => "revoked_fp",
362            Self::Replay => "replay",
363            Self::BadSignature => "bad_sig",
364        }
365    }
366}
367
368#[allow(clippy::too_many_arguments)]
369pub fn verify_signed_request(
370    conn: &Connection,
371    session_token: &str,
372    method: &str,
373    path: &str,
374    timestamp: i64,
375    nonce: &str,
376    key_fingerprint: &str,
377    signature_b64: &str,
378    body: &[u8],
379    now: Option<i64>,
380) -> std::result::Result<Session, AuthError> {
381    let actual_now = now.unwrap_or_else(|| {
382        i64::try_from(
383            SystemTime::now()
384                .duration_since(UNIX_EPOCH)
385                .map(|d| d.as_secs())
386                .unwrap_or(0),
387        )
388        .unwrap_or(0)
389    });
390
391    let sess = lookup_session(conn, session_token)
392        .map_err(|_| AuthError::NoSession)?
393        .ok_or(AuthError::NoSession)?;
394
395    if (actual_now - timestamp).abs() > constants::TIMESTAMP_SKEW_SEC {
396        return Err(AuthError::ClockSkew);
397    }
398
399    if sess.fingerprint != key_fingerprint {
400        return Err(AuthError::FpSessionMismatch);
401    }
402
403    let row: Option<(Vec<u8>, Option<String>)> = conn
404        .query_row(
405            "SELECT pubkey_blob, revoked_at FROM user_pubkeys
406             WHERE user_id=? AND fingerprint=?",
407            params![sess.user_id, key_fingerprint],
408            |r| Ok((r.get(0)?, r.get(1)?)),
409        )
410        .optional()
411        .map_err(|_| AuthError::UnknownFingerprint)?;
412    let Some((pub_blob, revoked_at)) = row else {
413        return Err(AuthError::UnknownFingerprint);
414    };
415    if revoked_at.is_some() {
416        return Err(AuthError::RevokedFingerprint);
417    }
418
419    let fresh = consume_nonce(conn, sess.user_id, nonce, constants::NONCE_TTL_SEC)
420        .map_err(|_| AuthError::Replay)?;
421    if !fresh {
422        return Err(AuthError::Replay);
423    }
424
425    let canonical = dragoon_proto::canonical::canonical_string(method, path, timestamp, nonce, body);
426    let sig_wire = B64.decode(signature_b64).map_err(|_| AuthError::BadSignature)?;
427    verify_ssh_wire_signature(&pub_blob, &sig_wire, &canonical)
428        .map_err(|_| AuthError::BadSignature)?;
429    Ok(sess)
430}
431
#[cfg(test)]
mod tests {
    //! Unit tests for the auth primitives, run against a freshly bootstrapped
    //! in-memory database where one is needed.
    use super::*;

    /// In-memory connection with the schema applied by `crate::db::bootstrap`.
    fn fresh() -> Connection {
        let c = crate::db::connect_in_memory().unwrap();
        crate::db::bootstrap(&c).unwrap();
        c
    }

    /// Insert a minimal user row named `username` and return its rowid.
    fn insert_user(c: &Connection, username: &str) -> i64 {
        c.execute(
            "INSERT INTO users (username, password_hash, totp_secret_enc, created_at)
             VALUES (?,?,?,?)",
            params![username, "h", "s", "2026-01-01T00:00:00Z"],
        )
        .unwrap();
        c.last_insert_rowid()
    }

    #[test]
    fn argon2_hash_then_verify() {
        let h = hash_password("hunter2").unwrap();
        assert!(h.starts_with("$argon2id$"));
        assert!(verify_password("hunter2", &h));
        assert!(!verify_password("wrong", &h));
        // A stored value that is not a PHC string never verifies.
        assert!(!verify_password("hunter2", "not a phc string"));
    }

    #[test]
    fn totp_round_trip() {
        let s = generate_totp_secret();
        assert_eq!(s.len(), 32);
        let secret_bytes = base32_decode(&s).expect("generated secret is valid base32");

        let totp = TOTP::new(TotpAlg::SHA1, 6, 1, 30, secret_bytes).unwrap();
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();
        let code: String = totp.generate(now);
        assert!(verify_totp(&s, &code));
        assert!(!verify_totp(&s, "000000"));
    }

    #[test]
    fn base32_matches_rfc4648_vectors() {
        // RFC 4648 §10 test vectors, with padding stripped.
        let cases: [(&[u8], &str); 7] = [
            (&b""[..], ""),
            (&b"f"[..], "MY"),
            (&b"fo"[..], "MZXQ"),
            (&b"foo"[..], "MZXW6"),
            (&b"foob"[..], "MZXW6YQ"),
            (&b"fooba"[..], "MZXW6YTB"),
            (&b"foobar"[..], "MZXW6YTBOI"),
        ];
        for (raw, encoded) in cases {
            assert_eq!(base32_encode_no_pad(raw), encoded);
            assert_eq!(base32_decode(encoded).as_deref(), Some(raw));
        }
        // Padded and lower-case input decode to the same bytes.
        assert_eq!(base32_decode("MZXW6YTBOI======").as_deref(), Some(&b"foobar"[..]));
        assert_eq!(base32_decode("mzxw6ytboi").as_deref(), Some(&b"foobar"[..]));
        assert!(base32_decode("MZ1").is_none());
    }

    #[test]
    fn recovery_codes_consume_once() {
        let (plain, hashes) = generate_recovery_codes(3);
        let (ok, remaining) = consume_recovery_code(&plain[1], &hashes);
        assert!(ok);
        assert_eq!(remaining.len(), 2);
        let (ok2, _) = consume_recovery_code(&plain[1], &remaining);
        assert!(!ok2);
    }

    #[test]
    fn unknown_recovery_code_leaves_hashes_untouched() {
        let (_, hashes) = generate_recovery_codes(2);
        let (ok, untouched) = consume_recovery_code("not-a-code", &hashes);
        assert!(!ok);
        assert_eq!(untouched, hashes);
    }

    #[test]
    fn session_round_trip_then_revoke() {
        let c = fresh();
        let uid = insert_user(&c, "alice");
        let (tok, _) = issue_session(&c, uid, "SHA256:fp", Duration::hours(1)).unwrap();
        let sess = lookup_session(&c, &tok).unwrap().unwrap();
        assert_eq!(sess.user_id, uid);
        assert_eq!(sess.fingerprint, "SHA256:fp");
        revoke_session(&c, &tok).unwrap();
        assert!(lookup_session(&c, &tok).unwrap().is_none());
    }

    #[test]
    fn unknown_or_expired_session_is_none() {
        let c = fresh();
        assert!(lookup_session(&c, "no-such-token").unwrap().is_none());
        let uid = insert_user(&c, "bob");
        let (tok, _) = issue_session(&c, uid, "SHA256:fp", Duration::seconds(-1)).unwrap();
        assert!(lookup_session(&c, &tok).unwrap().is_none());
    }

    #[test]
    fn challenge_one_shot() {
        let c = fresh();
        let ch = issue_challenge(&c, None).unwrap();
        assert!(consume_challenge(&c, &ch.challenge).unwrap());
        assert!(!consume_challenge(&c, &ch.challenge).unwrap());
    }

    #[test]
    fn unknown_or_expired_challenge_is_rejected() {
        let c = fresh();
        assert!(!consume_challenge(&c, "no-such-challenge").unwrap());
        let ch = issue_challenge(&c, Some(-1)).unwrap();
        assert!(!consume_challenge(&c, &ch.challenge).unwrap());
    }

    #[test]
    fn nonce_rejected_on_replay() {
        let c = fresh();
        let uid = insert_user(&c, "u");
        assert!(consume_nonce(&c, uid, "abc", 300).unwrap());
        assert!(!consume_nonce(&c, uid, "abc", 300).unwrap());
    }

    #[test]
    fn purge_frees_expired_nonce() {
        let c = fresh();
        let uid = insert_user(&c, "u");
        assert!(consume_nonce(&c, uid, "old", -1).unwrap());
        assert_eq!(purge_expired_nonces(&c).unwrap(), 1);
        assert!(consume_nonce(&c, uid, "old", 300).unwrap());
    }
}