1use chrono::Utc;
4use rand::TryRngCore;
5use rand::rngs::OsRng;
6use sha2::{Digest, Sha256};
7use totp_rs::{Algorithm, Secret, TOTP};
8
9use crate::db::Db;
10use crate::error::AuthError;
11use crate::mfa_encrypt;
12use crate::types::{MfaRecoveryCodeId, MfaSecretId, UserId};
13
/// Number of single-use recovery codes issued whenever a fresh set is generated.
const RECOVERY_CODE_COUNT: usize = 10;

/// Length, in characters, of every recovery code.
const RECOVERY_CODE_LENGTH: usize = 8;

/// Alphabet recovery codes are drawn from (31 symbols): upper-case letters and
/// digits, leaving out the easily confused `I`, `L`, `O`, `0` and `1`.
const RECOVERY_CHARSET: &[u8] = concat!("ABCDEFGHJKMNPQRSTUVWXYZ", "23456789").as_bytes();
18
19#[derive(Debug, Clone, sqlx::FromRow)]
20struct MfaSecretRow {
21 id: MfaSecretId,
22 #[allow(dead_code)]
23 user_id: UserId,
24 secret: String, enabled: bool,
26 #[allow(dead_code)]
27 created_at: chrono::DateTime<Utc>,
28}
29
30fn build_totp(secret_base32: &str) -> Result<TOTP, AuthError> {
31 let secret_bytes = Secret::Encoded(secret_base32.to_string())
32 .to_bytes()
33 .map_err(|e| AuthError::MfaEncryption(e.to_string()))?;
34 TOTP::new(Algorithm::SHA1, 6, 1, 30, secret_bytes, None, String::new())
35 .map_err(|e| AuthError::MfaEncryption(e.to_string()))
36}
37
38fn generate_recovery_code() -> String {
39 let mut bytes = [0u8; RECOVERY_CODE_LENGTH];
40 OsRng
41 .try_fill_bytes(&mut bytes)
42 .expect("OS RNG unavailable");
43 bytes
44 .iter()
45 .map(|b| RECOVERY_CHARSET[(*b as usize) % RECOVERY_CHARSET.len()] as char)
46 .collect()
47}
48
49fn hash_mfa_challenge(raw: &str) -> String {
50 let digest = Sha256::digest(raw.as_bytes());
51 format!("{digest:x}")
52}
53
54fn hash_recovery_code(code: &str) -> String {
55 let normalized = code.to_ascii_uppercase();
56 let digest = Sha256::digest(normalized.as_bytes());
57 format!("{digest:x}")
58}
59
60pub fn totp_uri(secret_base32: &str, account_name: &str, issuer: &str) -> String {
65 let secret_bytes = Secret::Encoded(secret_base32.to_string())
66 .to_bytes()
67 .expect("totp_uri called with invalid secret");
68 let totp = TOTP::new(
69 Algorithm::SHA1,
70 6,
71 1,
72 30,
73 secret_bytes,
74 Some(issuer.to_string()),
75 account_name.to_string(),
76 )
77 .expect("totp_uri called with invalid secret");
78 totp.get_url()
79}
80
81impl Db {
82 pub async fn get_pending_mfa_secret(
87 &self,
88 user_id: UserId,
89 mfa_key: &[u8; 32],
90 ) -> Result<Option<String>, AuthError> {
91 let row: Option<MfaSecretRow> = sqlx::query_as(
92 "SELECT id, user_id, secret, enabled, created_at \
93 FROM allowthem_mfa_secrets WHERE user_id = ? AND enabled = 0",
94 )
95 .bind(user_id)
96 .fetch_optional(self.pool())
97 .await?;
98
99 match row {
100 Some(r) => {
101 let secret_bytes = mfa_encrypt::decrypt_secret(&r.secret, mfa_key)?;
102 let secret_base32 = String::from_utf8(secret_bytes)
103 .map_err(|e| AuthError::MfaEncryption(e.to_string()))?;
104 Ok(Some(secret_base32))
105 }
106 None => Ok(None),
107 }
108 }
109
110 pub async fn create_mfa_secret(
120 &self,
121 user_id: UserId,
122 mfa_key: &[u8; 32],
123 ) -> Result<String, AuthError> {
124 let existing: Option<MfaSecretRow> = sqlx::query_as(
125 "SELECT id, user_id, secret, enabled, created_at \
126 FROM allowthem_mfa_secrets WHERE user_id = ?",
127 )
128 .bind(user_id)
129 .fetch_optional(self.pool())
130 .await?;
131
132 if let Some(row) = existing {
133 if row.enabled {
134 return Err(AuthError::MfaAlreadyEnabled);
135 }
136 sqlx::query("DELETE FROM allowthem_mfa_secrets WHERE id = ?")
138 .bind(row.id)
139 .execute(self.pool())
140 .await?;
141 }
142
143 let secret = Secret::generate_secret();
144 let secret_base32 = secret.to_encoded().to_string();
145
146 let encrypted = mfa_encrypt::encrypt_secret(secret_base32.as_bytes(), mfa_key)?;
147 let id = MfaSecretId::new();
148
149 sqlx::query(
150 "INSERT INTO allowthem_mfa_secrets (id, user_id, secret, enabled) \
151 VALUES (?, ?, ?, 0)",
152 )
153 .bind(id)
154 .bind(user_id)
155 .bind(&encrypted)
156 .execute(self.pool())
157 .await?;
158
159 Ok(secret_base32)
160 }
161
162 pub async fn enable_mfa(
170 &self,
171 user_id: UserId,
172 code: &str,
173 mfa_key: &[u8; 32],
174 ) -> Result<Vec<String>, AuthError> {
175 let mut tx = self.pool().begin().await.map_err(AuthError::Database)?;
176
177 let row: MfaSecretRow = sqlx::query_as(
178 "SELECT id, user_id, secret, enabled, created_at \
179 FROM allowthem_mfa_secrets WHERE user_id = ? AND enabled = 0",
180 )
181 .bind(user_id)
182 .fetch_optional(&mut *tx)
183 .await
184 .map_err(AuthError::Database)?
185 .ok_or(AuthError::MfaNotEnabled)?;
186
187 let secret_bytes = mfa_encrypt::decrypt_secret(&row.secret, mfa_key)?;
188 let secret_base32 =
189 String::from_utf8(secret_bytes).map_err(|e| AuthError::MfaEncryption(e.to_string()))?;
190 let totp = build_totp(&secret_base32)?;
191
192 if !totp
193 .check_current(code)
194 .map_err(|e| AuthError::MfaEncryption(e.to_string()))?
195 {
196 return Err(AuthError::InvalidTotpCode);
197 }
198
199 sqlx::query("UPDATE allowthem_mfa_secrets SET enabled = 1 WHERE id = ?")
200 .bind(row.id)
201 .execute(&mut *tx)
202 .await
203 .map_err(AuthError::Database)?;
204
205 let mut plaintext_codes = Vec::with_capacity(RECOVERY_CODE_COUNT);
206 for _ in 0..RECOVERY_CODE_COUNT {
207 let recovery = generate_recovery_code();
208 let code_hash = hash_recovery_code(&recovery);
209 let code_id = MfaRecoveryCodeId::new();
210
211 sqlx::query(
212 "INSERT INTO allowthem_mfa_recovery_codes (id, user_id, code_hash) \
213 VALUES (?, ?, ?)",
214 )
215 .bind(code_id)
216 .bind(user_id)
217 .bind(&code_hash)
218 .execute(&mut *tx)
219 .await
220 .map_err(AuthError::Database)?;
221
222 plaintext_codes.push(recovery);
223 }
224
225 tx.commit().await.map_err(AuthError::Database)?;
226
227 Ok(plaintext_codes)
228 }
229
230 pub async fn verify_totp(
235 &self,
236 user_id: UserId,
237 code: &str,
238 mfa_key: &[u8; 32],
239 ) -> Result<bool, AuthError> {
240 let row: MfaSecretRow = sqlx::query_as(
241 "SELECT id, user_id, secret, enabled, created_at \
242 FROM allowthem_mfa_secrets WHERE user_id = ? AND enabled = 1",
243 )
244 .bind(user_id)
245 .fetch_optional(self.pool())
246 .await?
247 .ok_or(AuthError::MfaNotEnabled)?;
248
249 let secret_bytes = mfa_encrypt::decrypt_secret(&row.secret, mfa_key)?;
250 let secret_base32 =
251 String::from_utf8(secret_bytes).map_err(|e| AuthError::MfaEncryption(e.to_string()))?;
252 let totp = build_totp(&secret_base32)?;
253
254 totp.check_current(code)
255 .map_err(|e| AuthError::MfaEncryption(e.to_string()))
256 }
257
258 pub async fn has_mfa_enabled(&self, user_id: UserId) -> Result<bool, AuthError> {
260 let count: (i64,) = sqlx::query_as(
261 "SELECT COUNT(*) FROM allowthem_mfa_secrets \
262 WHERE user_id = ? AND enabled = 1",
263 )
264 .bind(user_id)
265 .fetch_one(self.pool())
266 .await?;
267
268 Ok(count.0 > 0)
269 }
270
271 pub async fn disable_mfa(&self, user_id: UserId) -> Result<(), AuthError> {
275 let mut tx = self.pool().begin().await.map_err(AuthError::Database)?;
276
277 sqlx::query("DELETE FROM allowthem_mfa_recovery_codes WHERE user_id = ?")
278 .bind(user_id)
279 .execute(&mut *tx)
280 .await
281 .map_err(AuthError::Database)?;
282
283 sqlx::query("DELETE FROM allowthem_mfa_secrets WHERE user_id = ?")
284 .bind(user_id)
285 .execute(&mut *tx)
286 .await
287 .map_err(AuthError::Database)?;
288
289 tx.commit().await.map_err(AuthError::Database)?;
290
291 Ok(())
292 }
293
294 pub async fn verify_recovery_code(
300 &self,
301 user_id: UserId,
302 code: &str,
303 ) -> Result<bool, AuthError> {
304 let code_hash = hash_recovery_code(code);
305 let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
306
307 let row: Option<(MfaRecoveryCodeId,)> = sqlx::query_as(
308 "UPDATE allowthem_mfa_recovery_codes SET used_at = ?1 \
309 WHERE user_id = ?2 AND code_hash = ?3 AND used_at IS NULL \
310 RETURNING id",
311 )
312 .bind(&now)
313 .bind(user_id)
314 .bind(&code_hash)
315 .fetch_optional(self.pool())
316 .await?;
317
318 Ok(row.is_some())
319 }
320
321 pub async fn remaining_recovery_codes(&self, user_id: UserId) -> Result<i64, AuthError> {
323 let count: (i64,) = sqlx::query_as(
324 "SELECT COUNT(*) FROM allowthem_mfa_recovery_codes \
325 WHERE user_id = ? AND used_at IS NULL",
326 )
327 .bind(user_id)
328 .fetch_one(self.pool())
329 .await?;
330
331 Ok(count.0)
332 }
333
334 pub async fn regenerate_recovery_codes(
339 &self,
340 user_id: UserId,
341 ) -> Result<Vec<String>, AuthError> {
342 let mut tx = self.pool().begin().await.map_err(AuthError::Database)?;
343
344 sqlx::query("DELETE FROM allowthem_mfa_recovery_codes WHERE user_id = ?")
345 .bind(user_id)
346 .execute(&mut *tx)
347 .await
348 .map_err(AuthError::Database)?;
349
350 let mut plaintext_codes = Vec::with_capacity(RECOVERY_CODE_COUNT);
351 for _ in 0..RECOVERY_CODE_COUNT {
352 let code = generate_recovery_code();
353 let code_hash = hash_recovery_code(&code);
354 let code_id = MfaRecoveryCodeId::new();
355
356 sqlx::query(
357 "INSERT INTO allowthem_mfa_recovery_codes (id, user_id, code_hash) \
358 VALUES (?, ?, ?)",
359 )
360 .bind(code_id)
361 .bind(user_id)
362 .bind(&code_hash)
363 .execute(&mut *tx)
364 .await
365 .map_err(AuthError::Database)?;
366
367 plaintext_codes.push(code);
368 }
369
370 tx.commit().await.map_err(AuthError::Database)?;
371
372 Ok(plaintext_codes)
373 }
374
375 pub async fn create_mfa_challenge(&self, user_id: UserId) -> Result<String, AuthError> {
382 use crate::sessions::generate_token;
383 use crate::types::MfaChallengeId;
384
385 let token = generate_token();
386 let token_hash = hash_mfa_challenge(token.as_str());
387 let id = MfaChallengeId::new();
388 let expires_at = Utc::now() + chrono::Duration::minutes(5);
389 let expires_at_str = expires_at.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
390
391 sqlx::query(
392 "INSERT INTO allowthem_mfa_challenges (id, token_hash, user_id, expires_at) \
393 VALUES (?, ?, ?, ?)",
394 )
395 .bind(id)
396 .bind(&token_hash)
397 .bind(user_id)
398 .bind(&expires_at_str)
399 .execute(self.pool())
400 .await?;
401
402 Ok(token.as_str().to_string())
403 }
404
405 pub async fn validate_mfa_challenge(
411 &self,
412 raw_token: &str,
413 ) -> Result<Option<UserId>, AuthError> {
414 let token_hash = hash_mfa_challenge(raw_token);
415 let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
416
417 let row: Option<(UserId,)> = sqlx::query_as(
418 "SELECT user_id FROM allowthem_mfa_challenges \
419 WHERE token_hash = ? AND expires_at > ?",
420 )
421 .bind(&token_hash)
422 .bind(&now)
423 .fetch_optional(self.pool())
424 .await?;
425
426 Ok(row.map(|(uid,)| uid))
427 }
428
429 pub async fn consume_mfa_challenge(&self, raw_token: &str) -> Result<(), AuthError> {
433 let token_hash = hash_mfa_challenge(raw_token);
434
435 sqlx::query("DELETE FROM allowthem_mfa_challenges WHERE token_hash = ?")
436 .bind(&token_hash)
437 .execute(self.pool())
438 .await?;
439
440 Ok(())
441 }
442}
443
444use crate::handle::AllowThem;
445
446impl AllowThem {
447 pub async fn get_pending_mfa_secret(
448 &self,
449 user_id: UserId,
450 ) -> Result<Option<String>, AuthError> {
451 self.db()
452 .get_pending_mfa_secret(user_id, self.mfa_key()?)
453 .await
454 }
455
456 pub async fn create_mfa_secret(&self, user_id: UserId) -> Result<String, AuthError> {
457 self.db().create_mfa_secret(user_id, self.mfa_key()?).await
458 }
459
460 pub async fn enable_mfa(&self, user_id: UserId, code: &str) -> Result<Vec<String>, AuthError> {
461 self.db().enable_mfa(user_id, code, self.mfa_key()?).await
462 }
463
464 pub async fn verify_totp(&self, user_id: UserId, code: &str) -> Result<bool, AuthError> {
465 self.db().verify_totp(user_id, code, self.mfa_key()?).await
466 }
467
468 pub async fn has_mfa_enabled(&self, user_id: UserId) -> Result<bool, AuthError> {
469 self.db().has_mfa_enabled(user_id).await
470 }
471
472 pub async fn disable_mfa(&self, user_id: UserId) -> Result<(), AuthError> {
473 self.db().disable_mfa(user_id).await
474 }
475
476 pub async fn verify_recovery_code(
477 &self,
478 user_id: UserId,
479 code: &str,
480 ) -> Result<bool, AuthError> {
481 self.db().verify_recovery_code(user_id, code).await
482 }
483
484 pub async fn remaining_recovery_codes(&self, user_id: UserId) -> Result<i64, AuthError> {
485 self.db().remaining_recovery_codes(user_id).await
486 }
487
488 pub async fn regenerate_recovery_codes(
489 &self,
490 user_id: UserId,
491 ) -> Result<Vec<String>, AuthError> {
492 self.db().regenerate_recovery_codes(user_id).await
493 }
494}
495
#[cfg(test)]
mod tests {
    use std::time::{SystemTime, UNIX_EPOCH};

    use crate::db::Db;
    use crate::error::AuthError;
    use crate::handle::AllowThemBuilder;
    use crate::types::Email;

    use super::*;

    const TEST_MFA_KEY: [u8; 32] = [0x42; 32];

    async fn test_db() -> Db {
        Db::connect("sqlite::memory:").await.expect("in-memory db")
    }

    async fn make_user(db: &Db) -> UserId {
        let email = Email::new("mfa@example.com".to_string()).unwrap();
        db.create_user(email, "password123", None, None)
            .await
            .unwrap()
            .id
    }

    /// Enrolls `user_id` in MFA end to end and returns the plaintext recovery codes.
    async fn setup_and_enable_mfa(db: &Db, user_id: UserId) -> Vec<String> {
        let secret_b32 = db.create_mfa_secret(user_id, &TEST_MFA_KEY).await.unwrap();
        let totp = build_totp(&secret_b32).unwrap();
        let code = totp.generate_current().unwrap();
        db.enable_mfa(user_id, &code, &TEST_MFA_KEY).await.unwrap()
    }

    /// Returns a six-digit code that `totp` will not accept right now.
    ///
    /// A fixed literal such as "000000" is occasionally a valid code, which
    /// made the negative tests flaky. This excludes every code from two steps
    /// either side of now, which also covers a step boundary passing mid-test.
    fn rejected_code(totp: &TOTP) -> String {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock is before the Unix epoch")
            .as_secs();
        let step = totp.step;
        let nearby: Vec<String> = (0..=4u64)
            .map(|i| totp.generate((now + i * step).saturating_sub(2 * step)))
            .collect();
        (0u32..1_000_000)
            .map(|n| format!("{n:06}"))
            .find(|candidate| !nearby.contains(candidate))
            .expect("only five codes are excluded")
    }

    #[test]
    fn totp_validation() {
        let secret = Secret::generate_secret();
        let secret_b32 = secret.to_encoded().to_string();
        let totp = build_totp(&secret_b32).unwrap();
        let code = totp.generate_current().unwrap();
        let valid = totp
            .check_current(&code)
            .expect("check_current should not fail");
        assert!(valid, "generated code must validate");
    }

    #[test]
    fn totp_uri_format() {
        let secret = Secret::generate_secret();
        let secret_b32 = secret.to_encoded().to_string();
        let uri = totp_uri(&secret_b32, "user@example.com", "allowthem");
        assert!(
            uri.starts_with("otpauth://totp/"),
            "URI must start with otpauth://totp/"
        );
        assert!(
            uri.contains("user%40example.com"),
            "URI must contain account name"
        );
        assert!(uri.contains("allowthem"), "URI must contain issuer");
    }

    #[test]
    fn recovery_code_shape() {
        for _ in 0..100 {
            let code = generate_recovery_code();
            assert_eq!(code.len(), RECOVERY_CODE_LENGTH, "code has the configured length");
            assert!(
                code.bytes().all(|b| RECOVERY_CHARSET.contains(&b)),
                "code uses only charset symbols: {code}"
            );
        }
    }

    #[tokio::test]
    async fn create_and_enable_flow() {
        let db = test_db().await;
        let user_id = make_user(&db).await;

        let secret_b32 = db.create_mfa_secret(user_id, &TEST_MFA_KEY).await.unwrap();
        let totp = build_totp(&secret_b32).unwrap();
        let code = totp.generate_current().unwrap();

        let recovery_codes = db.enable_mfa(user_id, &code, &TEST_MFA_KEY).await.unwrap();
        assert_eq!(recovery_codes.len(), 10, "must return 10 recovery codes");

        let enabled = db.has_mfa_enabled(user_id).await.unwrap();
        assert!(enabled, "MFA must be enabled after enable_mfa");
    }

    #[tokio::test]
    async fn enable_rejects_wrong_code() {
        let db = test_db().await;
        let user_id = make_user(&db).await;
        let secret_b32 = db.create_mfa_secret(user_id, &TEST_MFA_KEY).await.unwrap();
        let wrong = rejected_code(&build_totp(&secret_b32).unwrap());

        let result = db.enable_mfa(user_id, &wrong, &TEST_MFA_KEY).await;
        assert!(
            matches!(result, Err(AuthError::InvalidTotpCode)),
            "wrong code must return InvalidTotpCode"
        );
    }

    #[tokio::test]
    async fn double_enable_blocked() {
        let db = test_db().await;
        let user_id = make_user(&db).await;
        setup_and_enable_mfa(&db, user_id).await;

        let result = db.create_mfa_secret(user_id, &TEST_MFA_KEY).await;
        assert!(
            matches!(result, Err(AuthError::MfaAlreadyEnabled)),
            "second create must return MfaAlreadyEnabled"
        );
    }

    #[tokio::test]
    async fn abandoned_setup_replacement() {
        let db = test_db().await;
        let user_id = make_user(&db).await;

        let secret_a = db.create_mfa_secret(user_id, &TEST_MFA_KEY).await.unwrap();
        let secret_b = db.create_mfa_secret(user_id, &TEST_MFA_KEY).await.unwrap();
        assert_ne!(secret_a, secret_b, "replacement must produce a new secret");

        let totp = build_totp(&secret_b).unwrap();
        let code = totp.generate_current().unwrap();
        let result = db.enable_mfa(user_id, &code, &TEST_MFA_KEY).await;
        assert!(result.is_ok(), "enable with new secret must succeed");
    }

    #[tokio::test]
    async fn verify_totp_valid_and_invalid() {
        let db = test_db().await;
        let user_id = make_user(&db).await;

        let secret_b32 = db.create_mfa_secret(user_id, &TEST_MFA_KEY).await.unwrap();
        let totp = build_totp(&secret_b32).unwrap();
        let code = totp.generate_current().unwrap();
        db.enable_mfa(user_id, &code, &TEST_MFA_KEY).await.unwrap();

        let fresh_code = totp.generate_current().unwrap();
        let valid = db
            .verify_totp(user_id, &fresh_code, &TEST_MFA_KEY)
            .await
            .unwrap();
        assert!(valid, "correct TOTP code must validate");

        let wrong = rejected_code(&totp);
        let invalid = db
            .verify_totp(user_id, &wrong, &TEST_MFA_KEY)
            .await
            .unwrap();
        assert!(!invalid, "wrong TOTP code must return false");
    }

    #[tokio::test]
    async fn verify_totp_no_mfa() {
        let db = test_db().await;
        let user_id = make_user(&db).await;

        let result = db.verify_totp(user_id, "123456", &TEST_MFA_KEY).await;
        assert!(
            matches!(result, Err(AuthError::MfaNotEnabled)),
            "verify_totp on non-MFA user must return MfaNotEnabled"
        );
    }

    #[tokio::test]
    async fn recovery_code_consumption() {
        let db = test_db().await;
        let user_id = make_user(&db).await;
        let codes = setup_and_enable_mfa(&db, user_id).await;

        let consumed = db.verify_recovery_code(user_id, &codes[0]).await.unwrap();
        assert!(consumed, "valid recovery code must be consumed");

        let reuse = db.verify_recovery_code(user_id, &codes[0]).await.unwrap();
        assert!(!reuse, "used recovery code must not be reusable");

        let remaining = db.remaining_recovery_codes(user_id).await.unwrap();
        assert_eq!(remaining, 9, "one code consumed, 9 remaining");
    }

    #[tokio::test]
    async fn recovery_code_wrong() {
        let db = test_db().await;
        let user_id = make_user(&db).await;
        setup_and_enable_mfa(&db, user_id).await;

        let result = db.verify_recovery_code(user_id, "ZZZZZZZZ").await.unwrap();
        assert!(!result, "wrong recovery code must return false");
    }

    #[tokio::test]
    async fn recovery_code_case_insensitive() {
        let db = test_db().await;
        let user_id = make_user(&db).await;
        let codes = setup_and_enable_mfa(&db, user_id).await;

        let consumed = db
            .verify_recovery_code(user_id, &codes[1].to_lowercase())
            .await
            .unwrap();
        assert!(consumed, "lowercase recovery code must match");
    }

    #[tokio::test]
    async fn disable_mfa_cleans_up() {
        let db = test_db().await;
        let user_id = make_user(&db).await;
        setup_and_enable_mfa(&db, user_id).await;

        db.disable_mfa(user_id).await.unwrap();

        let enabled = db.has_mfa_enabled(user_id).await.unwrap();
        assert!(!enabled, "MFA must not be enabled after disable");

        let remaining = db.remaining_recovery_codes(user_id).await.unwrap();
        assert_eq!(remaining, 0, "recovery codes must be deleted");
    }

    #[tokio::test]
    async fn user_deletion_cascades() {
        let db = test_db().await;
        let user_id = make_user(&db).await;
        setup_and_enable_mfa(&db, user_id).await;

        db.delete_user(user_id).await.unwrap();

        let enabled = db.has_mfa_enabled(user_id).await.unwrap();
        assert!(!enabled, "MFA must not be enabled after user deletion");
    }

    #[tokio::test]
    async fn mfa_not_configured_without_key() {
        let ath = AllowThemBuilder::new("sqlite::memory:")
            .build()
            .await
            .unwrap();
        let email = Email::new("nokey@example.com".to_string()).unwrap();
        let user = ath
            .db()
            .create_user(email, "password123", None, None)
            .await
            .unwrap();

        let result = ath.create_mfa_secret(user.id).await;
        assert!(
            matches!(result, Err(AuthError::MfaNotConfigured)),
            "MFA without key must return MfaNotConfigured"
        );
    }
}