Skip to main content

allowthem_core/
totp.rs

1//! TOTP core: secret management, code validation, and recovery codes.
2
3use chrono::Utc;
4use rand::TryRngCore;
5use rand::rngs::OsRng;
6use sha2::{Digest, Sha256};
7use totp_rs::{Algorithm, Secret, TOTP};
8
9use crate::db::Db;
10use crate::error::AuthError;
11use crate::mfa_encrypt;
12use crate::types::{MfaRecoveryCodeId, MfaSecretId, UserId};
13
/// How many recovery codes are issued per enrollment or regeneration.
const RECOVERY_CODE_COUNT: usize = 10;
/// Number of characters in each recovery code.
const RECOVERY_CODE_LENGTH: usize = 8;
/// Recovery-code alphabet: 31 symbols with the visually ambiguous characters
/// removed (no `0`/`O`, no `1`/`I`/`L`).
const RECOVERY_CHARSET: &[u8] = b"ABCDEFGHJKMNPQRSTUVWXYZ23456789";
18
19#[derive(Debug, Clone, sqlx::FromRow)]
20struct MfaSecretRow {
21    id: MfaSecretId,
22    #[allow(dead_code)]
23    user_id: UserId,
24    secret: String, // encrypted
25    enabled: bool,
26    #[allow(dead_code)]
27    created_at: chrono::DateTime<Utc>,
28}
29
30fn build_totp(secret_base32: &str) -> Result<TOTP, AuthError> {
31    let secret_bytes = Secret::Encoded(secret_base32.to_string())
32        .to_bytes()
33        .map_err(|e| AuthError::MfaEncryption(e.to_string()))?;
34    TOTP::new(Algorithm::SHA1, 6, 1, 30, secret_bytes, None, String::new())
35        .map_err(|e| AuthError::MfaEncryption(e.to_string()))
36}
37
38fn generate_recovery_code() -> String {
39    let mut bytes = [0u8; RECOVERY_CODE_LENGTH];
40    OsRng
41        .try_fill_bytes(&mut bytes)
42        .expect("OS RNG unavailable");
43    bytes
44        .iter()
45        .map(|b| RECOVERY_CHARSET[(*b as usize) % RECOVERY_CHARSET.len()] as char)
46        .collect()
47}
48
49fn hash_mfa_challenge(raw: &str) -> String {
50    let digest = Sha256::digest(raw.as_bytes());
51    format!("{digest:x}")
52}
53
54fn hash_recovery_code(code: &str) -> String {
55    let normalized = code.to_ascii_uppercase();
56    let digest = Sha256::digest(normalized.as_bytes());
57    format!("{digest:x}")
58}
59
60/// Build an `otpauth://totp/` URI from a base32 secret.
61///
62/// The URI encodes the issuer, account name, algorithm, digits, and period.
63/// M26 will render this as a QR code; this function only produces the string.
64pub fn totp_uri(secret_base32: &str, account_name: &str, issuer: &str) -> String {
65    let secret_bytes = Secret::Encoded(secret_base32.to_string())
66        .to_bytes()
67        .expect("totp_uri called with invalid secret");
68    let totp = TOTP::new(
69        Algorithm::SHA1,
70        6,
71        1,
72        30,
73        secret_bytes,
74        Some(issuer.to_string()),
75        account_name.to_string(),
76    )
77    .expect("totp_uri called with invalid secret");
78    totp.get_url()
79}
80
81impl Db {
82    /// Retrieve a pending (non-enabled) MFA secret for a user.
83    ///
84    /// Returns `Some(base32_secret)` if a non-enabled secret exists, `None` otherwise.
85    /// Used by the setup page to avoid regenerating the secret on every page load.
86    pub async fn get_pending_mfa_secret(
87        &self,
88        user_id: UserId,
89        mfa_key: &[u8; 32],
90    ) -> Result<Option<String>, AuthError> {
91        let row: Option<MfaSecretRow> = sqlx::query_as(
92            "SELECT id, user_id, secret, enabled, created_at \
93             FROM allowthem_mfa_secrets WHERE user_id = ? AND enabled = 0",
94        )
95        .bind(user_id)
96        .fetch_optional(self.pool())
97        .await?;
98
99        match row {
100            Some(r) => {
101                let secret_bytes = mfa_encrypt::decrypt_secret(&r.secret, mfa_key)?;
102                let secret_base32 = String::from_utf8(secret_bytes)
103                    .map_err(|e| AuthError::MfaEncryption(e.to_string()))?;
104                Ok(Some(secret_base32))
105            }
106            None => Ok(None),
107        }
108    }
109
110    /// Generate a new TOTP secret for a user and store it (encrypted, not yet enabled).
111    ///
112    /// Returns the plaintext base32-encoded secret for display to the user
113    /// during the setup flow. The caller must present this secret (or a QR code
114    /// derived from it) and require the user to confirm with a valid code
115    /// before calling `enable_mfa`.
116    ///
117    /// Fails with `MfaAlreadyEnabled` if the user already has an enabled MFA secret.
118    /// If a non-enabled secret exists (abandoned setup attempt), it is replaced.
119    pub async fn create_mfa_secret(
120        &self,
121        user_id: UserId,
122        mfa_key: &[u8; 32],
123    ) -> Result<String, AuthError> {
124        let existing: Option<MfaSecretRow> = sqlx::query_as(
125            "SELECT id, user_id, secret, enabled, created_at \
126             FROM allowthem_mfa_secrets WHERE user_id = ?",
127        )
128        .bind(user_id)
129        .fetch_optional(self.pool())
130        .await?;
131
132        if let Some(row) = existing {
133            if row.enabled {
134                return Err(AuthError::MfaAlreadyEnabled);
135            }
136            // Abandoned setup -- delete the old non-enabled secret
137            sqlx::query("DELETE FROM allowthem_mfa_secrets WHERE id = ?")
138                .bind(row.id)
139                .execute(self.pool())
140                .await?;
141        }
142
143        let secret = Secret::generate_secret();
144        let secret_base32 = secret.to_encoded().to_string();
145
146        let encrypted = mfa_encrypt::encrypt_secret(secret_base32.as_bytes(), mfa_key)?;
147        let id = MfaSecretId::new();
148
149        sqlx::query(
150            "INSERT INTO allowthem_mfa_secrets (id, user_id, secret, enabled) \
151             VALUES (?, ?, ?, 0)",
152        )
153        .bind(id)
154        .bind(user_id)
155        .bind(&encrypted)
156        .execute(self.pool())
157        .await?;
158
159        Ok(secret_base32)
160    }
161
162    /// Enable MFA for a user after verifying a TOTP code.
163    ///
164    /// Decrypts the stored secret, validates the provided code against it,
165    /// and if valid, sets `enabled = 1` and inserts 10 hashed recovery codes.
166    /// Returns the plaintext recovery codes (this is the only time they are visible).
167    ///
168    /// Runs in a transaction to ensure MFA is never enabled without recovery codes.
169    pub async fn enable_mfa(
170        &self,
171        user_id: UserId,
172        code: &str,
173        mfa_key: &[u8; 32],
174    ) -> Result<Vec<String>, AuthError> {
175        let mut tx = self.pool().begin().await.map_err(AuthError::Database)?;
176
177        let row: MfaSecretRow = sqlx::query_as(
178            "SELECT id, user_id, secret, enabled, created_at \
179             FROM allowthem_mfa_secrets WHERE user_id = ? AND enabled = 0",
180        )
181        .bind(user_id)
182        .fetch_optional(&mut *tx)
183        .await
184        .map_err(AuthError::Database)?
185        .ok_or(AuthError::MfaNotEnabled)?;
186
187        let secret_bytes = mfa_encrypt::decrypt_secret(&row.secret, mfa_key)?;
188        let secret_base32 =
189            String::from_utf8(secret_bytes).map_err(|e| AuthError::MfaEncryption(e.to_string()))?;
190        let totp = build_totp(&secret_base32)?;
191
192        if !totp
193            .check_current(code)
194            .map_err(|e| AuthError::MfaEncryption(e.to_string()))?
195        {
196            return Err(AuthError::InvalidTotpCode);
197        }
198
199        sqlx::query("UPDATE allowthem_mfa_secrets SET enabled = 1 WHERE id = ?")
200            .bind(row.id)
201            .execute(&mut *tx)
202            .await
203            .map_err(AuthError::Database)?;
204
205        let mut plaintext_codes = Vec::with_capacity(RECOVERY_CODE_COUNT);
206        for _ in 0..RECOVERY_CODE_COUNT {
207            let recovery = generate_recovery_code();
208            let code_hash = hash_recovery_code(&recovery);
209            let code_id = MfaRecoveryCodeId::new();
210
211            sqlx::query(
212                "INSERT INTO allowthem_mfa_recovery_codes (id, user_id, code_hash) \
213                 VALUES (?, ?, ?)",
214            )
215            .bind(code_id)
216            .bind(user_id)
217            .bind(&code_hash)
218            .execute(&mut *tx)
219            .await
220            .map_err(AuthError::Database)?;
221
222            plaintext_codes.push(recovery);
223        }
224
225        tx.commit().await.map_err(AuthError::Database)?;
226
227        Ok(plaintext_codes)
228    }
229
230    /// Validate a TOTP code against a user's enabled MFA secret.
231    ///
232    /// Returns `Ok(true)` if the code is valid, `Ok(false)` if invalid.
233    /// Returns `Err(MfaNotEnabled)` if the user has no enabled MFA.
234    pub async fn verify_totp(
235        &self,
236        user_id: UserId,
237        code: &str,
238        mfa_key: &[u8; 32],
239    ) -> Result<bool, AuthError> {
240        let row: MfaSecretRow = sqlx::query_as(
241            "SELECT id, user_id, secret, enabled, created_at \
242             FROM allowthem_mfa_secrets WHERE user_id = ? AND enabled = 1",
243        )
244        .bind(user_id)
245        .fetch_optional(self.pool())
246        .await?
247        .ok_or(AuthError::MfaNotEnabled)?;
248
249        let secret_bytes = mfa_encrypt::decrypt_secret(&row.secret, mfa_key)?;
250        let secret_base32 =
251            String::from_utf8(secret_bytes).map_err(|e| AuthError::MfaEncryption(e.to_string()))?;
252        let totp = build_totp(&secret_base32)?;
253
254        totp.check_current(code)
255            .map_err(|e| AuthError::MfaEncryption(e.to_string()))
256    }
257
258    /// Check whether a user has MFA enabled.
259    pub async fn has_mfa_enabled(&self, user_id: UserId) -> Result<bool, AuthError> {
260        let count: (i64,) = sqlx::query_as(
261            "SELECT COUNT(*) FROM allowthem_mfa_secrets \
262             WHERE user_id = ? AND enabled = 1",
263        )
264        .bind(user_id)
265        .fetch_one(self.pool())
266        .await?;
267
268        Ok(count.0 > 0)
269    }
270
271    /// Disable MFA for a user. Deletes the secret and all recovery codes.
272    ///
273    /// Uses a transaction to ensure both deletes are atomic.
274    pub async fn disable_mfa(&self, user_id: UserId) -> Result<(), AuthError> {
275        let mut tx = self.pool().begin().await.map_err(AuthError::Database)?;
276
277        sqlx::query("DELETE FROM allowthem_mfa_recovery_codes WHERE user_id = ?")
278            .bind(user_id)
279            .execute(&mut *tx)
280            .await
281            .map_err(AuthError::Database)?;
282
283        sqlx::query("DELETE FROM allowthem_mfa_secrets WHERE user_id = ?")
284            .bind(user_id)
285            .execute(&mut *tx)
286            .await
287            .map_err(AuthError::Database)?;
288
289        tx.commit().await.map_err(AuthError::Database)?;
290
291        Ok(())
292    }
293
294    /// Verify a recovery code. If valid, marks it as used (one-time use).
295    ///
296    /// Uses atomic `UPDATE ... RETURNING` to prevent race conditions.
297    /// Returns `Ok(true)` if the code was valid and consumed,
298    /// `Ok(false)` if no matching unused code was found.
299    pub async fn verify_recovery_code(
300        &self,
301        user_id: UserId,
302        code: &str,
303    ) -> Result<bool, AuthError> {
304        let code_hash = hash_recovery_code(code);
305        let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
306
307        let row: Option<(MfaRecoveryCodeId,)> = sqlx::query_as(
308            "UPDATE allowthem_mfa_recovery_codes SET used_at = ?1 \
309             WHERE user_id = ?2 AND code_hash = ?3 AND used_at IS NULL \
310             RETURNING id",
311        )
312        .bind(&now)
313        .bind(user_id)
314        .bind(&code_hash)
315        .fetch_optional(self.pool())
316        .await?;
317
318        Ok(row.is_some())
319    }
320
321    /// Count remaining unused recovery codes for a user.
322    pub async fn remaining_recovery_codes(&self, user_id: UserId) -> Result<i64, AuthError> {
323        let count: (i64,) = sqlx::query_as(
324            "SELECT COUNT(*) FROM allowthem_mfa_recovery_codes \
325             WHERE user_id = ? AND used_at IS NULL",
326        )
327        .bind(user_id)
328        .fetch_one(self.pool())
329        .await?;
330
331        Ok(count.0)
332    }
333
334    /// Replace all recovery codes with a fresh set of 10.
335    ///
336    /// Deletes all existing codes (used and unused) and inserts 10 new ones.
337    /// Returns the plaintext codes. Runs in a transaction.
338    pub async fn regenerate_recovery_codes(
339        &self,
340        user_id: UserId,
341    ) -> Result<Vec<String>, AuthError> {
342        let mut tx = self.pool().begin().await.map_err(AuthError::Database)?;
343
344        sqlx::query("DELETE FROM allowthem_mfa_recovery_codes WHERE user_id = ?")
345            .bind(user_id)
346            .execute(&mut *tx)
347            .await
348            .map_err(AuthError::Database)?;
349
350        let mut plaintext_codes = Vec::with_capacity(RECOVERY_CODE_COUNT);
351        for _ in 0..RECOVERY_CODE_COUNT {
352            let code = generate_recovery_code();
353            let code_hash = hash_recovery_code(&code);
354            let code_id = MfaRecoveryCodeId::new();
355
356            sqlx::query(
357                "INSERT INTO allowthem_mfa_recovery_codes (id, user_id, code_hash) \
358                 VALUES (?, ?, ?)",
359            )
360            .bind(code_id)
361            .bind(user_id)
362            .bind(&code_hash)
363            .execute(&mut *tx)
364            .await
365            .map_err(AuthError::Database)?;
366
367            plaintext_codes.push(code);
368        }
369
370        tx.commit().await.map_err(AuthError::Database)?;
371
372        Ok(plaintext_codes)
373    }
374
375    /// Create a short-lived MFA challenge token after password verification.
376    ///
377    /// The integrator calls this when a user with MFA enabled passes password
378    /// verification. Returns the raw token string to send to the client. The
379    /// client presents this token along with a TOTP code to complete login.
380    /// Challenge tokens expire after 5 minutes.
381    pub async fn create_mfa_challenge(&self, user_id: UserId) -> Result<String, AuthError> {
382        use crate::sessions::generate_token;
383        use crate::types::MfaChallengeId;
384
385        let token = generate_token();
386        let token_hash = hash_mfa_challenge(token.as_str());
387        let id = MfaChallengeId::new();
388        let expires_at = Utc::now() + chrono::Duration::minutes(5);
389        let expires_at_str = expires_at.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
390
391        sqlx::query(
392            "INSERT INTO allowthem_mfa_challenges (id, token_hash, user_id, expires_at) \
393             VALUES (?, ?, ?, ?)",
394        )
395        .bind(id)
396        .bind(&token_hash)
397        .bind(user_id)
398        .bind(&expires_at_str)
399        .execute(self.pool())
400        .await?;
401
402        Ok(token.as_str().to_string())
403    }
404
405    /// Validate an MFA challenge token without consuming it.
406    ///
407    /// Returns `Some(user_id)` if the token is valid and not expired,
408    /// `None` otherwise. Does not consume the token so the user can retry
409    /// if they mistype the TOTP code.
410    pub async fn validate_mfa_challenge(
411        &self,
412        raw_token: &str,
413    ) -> Result<Option<UserId>, AuthError> {
414        let token_hash = hash_mfa_challenge(raw_token);
415        let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
416
417        let row: Option<(UserId,)> = sqlx::query_as(
418            "SELECT user_id FROM allowthem_mfa_challenges \
419             WHERE token_hash = ? AND expires_at > ?",
420        )
421        .bind(&token_hash)
422        .bind(&now)
423        .fetch_optional(self.pool())
424        .await?;
425
426        Ok(row.map(|(uid,)| uid))
427    }
428
429    /// Consume an MFA challenge token after successful TOTP verification.
430    ///
431    /// Uses `DELETE ... RETURNING` for atomicity.
432    pub async fn consume_mfa_challenge(&self, raw_token: &str) -> Result<(), AuthError> {
433        let token_hash = hash_mfa_challenge(raw_token);
434
435        sqlx::query("DELETE FROM allowthem_mfa_challenges WHERE token_hash = ?")
436            .bind(&token_hash)
437            .execute(self.pool())
438            .await?;
439
440        Ok(())
441    }
442}
443
444use crate::event_sink::AuthEvent;
445use crate::handle::AllowThem;
446
447impl AllowThem {
448    pub async fn get_pending_mfa_secret(
449        &self,
450        user_id: UserId,
451    ) -> Result<Option<String>, AuthError> {
452        self.db()
453            .get_pending_mfa_secret(user_id, self.mfa_key()?)
454            .await
455    }
456
457    pub async fn create_mfa_secret(&self, user_id: UserId) -> Result<String, AuthError> {
458        self.db().create_mfa_secret(user_id, self.mfa_key()?).await
459    }
460
461    pub async fn enable_mfa(&self, user_id: UserId, code: &str) -> Result<Vec<String>, AuthError> {
462        let codes = self.db().enable_mfa(user_id, code, self.mfa_key()?).await?;
463        self.emit_event(AuthEvent::new(
464            "mfa.enrolled",
465            Some(user_id),
466            serde_json::json!({ "user_id": user_id }),
467        ))
468        .await;
469        Ok(codes)
470    }
471
472    pub async fn verify_totp(&self, user_id: UserId, code: &str) -> Result<bool, AuthError> {
473        self.db().verify_totp(user_id, code, self.mfa_key()?).await
474    }
475
476    pub async fn has_mfa_enabled(&self, user_id: UserId) -> Result<bool, AuthError> {
477        self.db().has_mfa_enabled(user_id).await
478    }
479
480    pub async fn disable_mfa(&self, user_id: UserId) -> Result<(), AuthError> {
481        self.db().disable_mfa(user_id).await?;
482        self.emit_event(AuthEvent::new(
483            "mfa.removed",
484            Some(user_id),
485            serde_json::json!({ "user_id": user_id }),
486        ))
487        .await;
488        Ok(())
489    }
490
491    pub async fn verify_recovery_code(
492        &self,
493        user_id: UserId,
494        code: &str,
495    ) -> Result<bool, AuthError> {
496        self.db().verify_recovery_code(user_id, code).await
497    }
498
499    pub async fn remaining_recovery_codes(&self, user_id: UserId) -> Result<i64, AuthError> {
500        self.db().remaining_recovery_codes(user_id).await
501    }
502
503    pub async fn regenerate_recovery_codes(
504        &self,
505        user_id: UserId,
506    ) -> Result<Vec<String>, AuthError> {
507        self.db().regenerate_recovery_codes(user_id).await
508    }
509}
510
#[cfg(test)]
mod tests {
    use crate::db::Db;
    use crate::error::AuthError;
    use crate::handle::AllowThemBuilder;
    use crate::types::Email;

    use super::*;

    const TEST_MFA_KEY: [u8; 32] = [0x42; 32];

    async fn test_db() -> Db {
        Db::connect("sqlite::memory:").await.expect("in-memory db")
    }

    async fn make_user(db: &Db) -> UserId {
        let email = Email::new("mfa@example.com".to_string()).unwrap();
        db.create_user(email, "password123", None, None)
            .await
            .unwrap()
            .id
    }

    /// Helper: create MFA secret, generate a valid current code from it, enable MFA.
    /// Returns the recovery codes.
    async fn setup_and_enable_mfa(db: &Db, user_id: UserId) -> Vec<String> {
        let secret_b32 = db.create_mfa_secret(user_id, &TEST_MFA_KEY).await.unwrap();
        let totp = build_totp(&secret_b32).unwrap();
        let code = totp.generate_current().unwrap();
        db.enable_mfa(user_id, &code, &TEST_MFA_KEY).await.unwrap()
    }

    #[test]
    fn totp_validation() {
        let secret = Secret::generate_secret();
        let secret_b32 = secret.to_encoded().to_string();
        let totp = build_totp(&secret_b32).unwrap();
        let code = totp.generate_current().unwrap();
        let valid = totp
            .check_current(&code)
            .expect("check_current should not fail");
        assert!(valid, "generated code must validate");
    }

    #[test]
    fn totp_uri_format() {
        let secret = Secret::generate_secret();
        let secret_b32 = secret.to_encoded().to_string();
        let uri = totp_uri(&secret_b32, "user@example.com", "allowthem");
        assert!(
            uri.starts_with("otpauth://totp/"),
            "URI must start with otpauth://totp/"
        );
        assert!(
            uri.contains("user%40example.com"),
            "URI must contain account name"
        );
        assert!(uri.contains("allowthem"), "URI must contain issuer");
    }

    #[test]
    fn recovery_code_format() {
        for _ in 0..200 {
            let code = generate_recovery_code();
            assert_eq!(code.len(), RECOVERY_CODE_LENGTH, "code has fixed length");
            assert!(
                code.bytes().all(|b| RECOVERY_CHARSET.contains(&b)),
                "code {code} must only use the recovery charset"
            );
        }
    }

    #[tokio::test]
    async fn create_and_enable_flow() {
        let db = test_db().await;
        let user_id = make_user(&db).await;

        let secret_b32 = db.create_mfa_secret(user_id, &TEST_MFA_KEY).await.unwrap();
        let totp = build_totp(&secret_b32).unwrap();
        let code = totp.generate_current().unwrap();

        let recovery_codes = db.enable_mfa(user_id, &code, &TEST_MFA_KEY).await.unwrap();
        assert_eq!(recovery_codes.len(), 10, "must return 10 recovery codes");

        let enabled = db.has_mfa_enabled(user_id).await.unwrap();
        assert!(enabled, "MFA must be enabled after enable_mfa");
    }

    #[tokio::test]
    async fn enable_rejects_wrong_code() {
        let db = test_db().await;
        let user_id = make_user(&db).await;
        db.create_mfa_secret(user_id, &TEST_MFA_KEY).await.unwrap();

        let result = db.enable_mfa(user_id, "000000", &TEST_MFA_KEY).await;
        assert!(
            matches!(result, Err(AuthError::InvalidTotpCode)),
            "wrong code must return InvalidTotpCode"
        );
    }

    #[tokio::test]
    async fn double_enable_blocked() {
        let db = test_db().await;
        let user_id = make_user(&db).await;
        setup_and_enable_mfa(&db, user_id).await;

        let result = db.create_mfa_secret(user_id, &TEST_MFA_KEY).await;
        assert!(
            matches!(result, Err(AuthError::MfaAlreadyEnabled)),
            "second create must return MfaAlreadyEnabled"
        );
    }

    #[tokio::test]
    async fn abandoned_setup_replacement() {
        let db = test_db().await;
        let user_id = make_user(&db).await;

        let secret_a = db.create_mfa_secret(user_id, &TEST_MFA_KEY).await.unwrap();
        let secret_b = db.create_mfa_secret(user_id, &TEST_MFA_KEY).await.unwrap();
        assert_ne!(secret_a, secret_b, "replacement must produce a new secret");

        // Enable with code from secret B
        let totp = build_totp(&secret_b).unwrap();
        let code = totp.generate_current().unwrap();
        let result = db.enable_mfa(user_id, &code, &TEST_MFA_KEY).await;
        assert!(result.is_ok(), "enable with new secret must succeed");
    }

    #[tokio::test]
    async fn verify_totp_valid_and_invalid() {
        let db = test_db().await;
        let user_id = make_user(&db).await;

        let secret_b32 = db.create_mfa_secret(user_id, &TEST_MFA_KEY).await.unwrap();
        let totp = build_totp(&secret_b32).unwrap();
        let code = totp.generate_current().unwrap();
        db.enable_mfa(user_id, &code, &TEST_MFA_KEY).await.unwrap();

        // Valid code
        let fresh_code = totp.generate_current().unwrap();
        let valid = db
            .verify_totp(user_id, &fresh_code, &TEST_MFA_KEY)
            .await
            .unwrap();
        assert!(valid, "correct TOTP code must validate");

        // Invalid code
        let invalid = db
            .verify_totp(user_id, "000000", &TEST_MFA_KEY)
            .await
            .unwrap();
        assert!(!invalid, "wrong TOTP code must return false");
    }

    #[tokio::test]
    async fn verify_totp_no_mfa() {
        let db = test_db().await;
        let user_id = make_user(&db).await;

        let result = db.verify_totp(user_id, "123456", &TEST_MFA_KEY).await;
        assert!(
            matches!(result, Err(AuthError::MfaNotEnabled)),
            "verify_totp on non-MFA user must return MfaNotEnabled"
        );
    }

    #[tokio::test]
    async fn recovery_code_consumption() {
        let db = test_db().await;
        let user_id = make_user(&db).await;
        let codes = setup_and_enable_mfa(&db, user_id).await;

        let consumed = db.verify_recovery_code(user_id, &codes[0]).await.unwrap();
        assert!(consumed, "valid recovery code must be consumed");

        let reuse = db.verify_recovery_code(user_id, &codes[0]).await.unwrap();
        assert!(!reuse, "used recovery code must not be reusable");

        let remaining = db.remaining_recovery_codes(user_id).await.unwrap();
        assert_eq!(remaining, 9, "one code consumed, 9 remaining");
    }

    #[tokio::test]
    async fn recovery_code_wrong() {
        let db = test_db().await;
        let user_id = make_user(&db).await;
        setup_and_enable_mfa(&db, user_id).await;

        let result = db.verify_recovery_code(user_id, "ZZZZZZZZ").await.unwrap();
        assert!(!result, "wrong recovery code must return false");
    }

    #[tokio::test]
    async fn recovery_code_case_insensitive() {
        let db = test_db().await;
        let user_id = make_user(&db).await;
        let codes = setup_and_enable_mfa(&db, user_id).await;

        let consumed = db
            .verify_recovery_code(user_id, &codes[1].to_lowercase())
            .await
            .unwrap();
        assert!(consumed, "lowercase recovery code must match");
    }

    #[tokio::test]
    async fn regenerate_recovery_codes_replaces_previous_set() {
        let db = test_db().await;
        let user_id = make_user(&db).await;
        let old_codes = setup_and_enable_mfa(&db, user_id).await;
        assert!(db.verify_recovery_code(user_id, &old_codes[0]).await.unwrap());

        let new_codes = db.regenerate_recovery_codes(user_id).await.unwrap();
        assert_eq!(new_codes.len(), 10, "must return 10 fresh codes");
        assert_eq!(
            db.remaining_recovery_codes(user_id).await.unwrap(),
            10,
            "regeneration resets the unused count"
        );

        let old = db.verify_recovery_code(user_id, &old_codes[1]).await.unwrap();
        assert!(!old, "codes from the previous set must be invalidated");
        let new = db.verify_recovery_code(user_id, &new_codes[0]).await.unwrap();
        assert!(new, "codes from the new set must work");
    }

    #[tokio::test]
    async fn reenroll_after_disable_issues_fresh_codes() {
        let db = test_db().await;
        let user_id = make_user(&db).await;
        setup_and_enable_mfa(&db, user_id).await;
        db.disable_mfa(user_id).await.unwrap();

        let codes = setup_and_enable_mfa(&db, user_id).await;
        assert_eq!(codes.len(), 10);
        assert_eq!(
            db.remaining_recovery_codes(user_id).await.unwrap(),
            10,
            "re-enrollment must leave exactly one set of codes"
        );
    }

    #[tokio::test]
    async fn disable_mfa_cleans_up() {
        let db = test_db().await;
        let user_id = make_user(&db).await;
        setup_and_enable_mfa(&db, user_id).await;

        db.disable_mfa(user_id).await.unwrap();

        let enabled = db.has_mfa_enabled(user_id).await.unwrap();
        assert!(!enabled, "MFA must not be enabled after disable");

        let remaining = db.remaining_recovery_codes(user_id).await.unwrap();
        assert_eq!(remaining, 0, "recovery codes must be deleted");
    }

    #[tokio::test]
    async fn user_deletion_cascades() {
        let db = test_db().await;
        let user_id = make_user(&db).await;
        setup_and_enable_mfa(&db, user_id).await;

        db.delete_user(user_id).await.unwrap();

        let enabled = db.has_mfa_enabled(user_id).await.unwrap();
        assert!(!enabled, "MFA must not be enabled after user deletion");
    }

    #[tokio::test]
    async fn mfa_challenge_round_trip() {
        let db = test_db().await;
        let user_id = make_user(&db).await;

        let token = db.create_mfa_challenge(user_id).await.unwrap();
        assert!(
            db.validate_mfa_challenge(&token).await.unwrap().is_some(),
            "fresh challenge must validate"
        );
        assert!(
            db.validate_mfa_challenge(&token).await.unwrap().is_some(),
            "validation must not consume the challenge"
        );

        db.consume_mfa_challenge(&token).await.unwrap();
        assert!(
            db.validate_mfa_challenge(&token).await.unwrap().is_none(),
            "consumed challenge must no longer validate"
        );
    }

    #[tokio::test]
    async fn mfa_challenge_unknown_token() {
        let db = test_db().await;
        let result = db.validate_mfa_challenge("not-a-real-token").await.unwrap();
        assert!(result.is_none(), "unknown token must not validate");
    }

    #[tokio::test]
    async fn mfa_not_configured_without_key() {
        let ath = AllowThemBuilder::new("sqlite::memory:")
            .build()
            .await
            .unwrap();
        let email = Email::new("nokey@example.com".to_string()).unwrap();
        let user = ath
            .db()
            .create_user(email, "password123", None, None)
            .await
            .unwrap();

        let result = ath.create_mfa_secret(user.id).await;
        assert!(
            matches!(result, Err(AuthError::MfaNotConfigured)),
            "MFA without key must return MfaNotConfigured"
        );
    }
}