1use chrono::Utc;
4use rand::TryRngCore;
5use rand::rngs::OsRng;
6use sha2::{Digest, Sha256};
7use totp_rs::{Algorithm, Secret, TOTP};
8
9use crate::db::Db;
10use crate::error::AuthError;
11use crate::mfa_encrypt;
12use crate::types::{MfaRecoveryCodeId, MfaSecretId, UserId};
13
/// Number of recovery codes issued each time codes are generated.
const RECOVERY_CODE_COUNT: usize = 10;
/// Length, in characters, of each recovery code.
const RECOVERY_CODE_LENGTH: usize = 8;
/// Alphabet for recovery codes (31 ASCII symbols): upper-case letters and
/// digits with `I`, `L`, `O`, `0` and `1` omitted — presumably to avoid
/// look-alike characters when users transcribe codes.
const RECOVERY_CHARSET: &[u8] = b"ABCDEFGHJKMNPQRSTUVWXYZ23456789";
18
19#[derive(Debug, Clone, sqlx::FromRow)]
20struct MfaSecretRow {
21 id: MfaSecretId,
22 #[allow(dead_code)]
23 user_id: UserId,
24 secret: String, enabled: bool,
26 #[allow(dead_code)]
27 created_at: chrono::DateTime<Utc>,
28}
29
30fn build_totp(secret_base32: &str) -> Result<TOTP, AuthError> {
31 let secret_bytes = Secret::Encoded(secret_base32.to_string())
32 .to_bytes()
33 .map_err(|e| AuthError::MfaEncryption(e.to_string()))?;
34 TOTP::new(Algorithm::SHA1, 6, 1, 30, secret_bytes, None, String::new())
35 .map_err(|e| AuthError::MfaEncryption(e.to_string()))
36}
37
38fn generate_recovery_code() -> String {
39 let mut bytes = [0u8; RECOVERY_CODE_LENGTH];
40 OsRng
41 .try_fill_bytes(&mut bytes)
42 .expect("OS RNG unavailable");
43 bytes
44 .iter()
45 .map(|b| RECOVERY_CHARSET[(*b as usize) % RECOVERY_CHARSET.len()] as char)
46 .collect()
47}
48
49fn hash_mfa_challenge(raw: &str) -> String {
50 let digest = Sha256::digest(raw.as_bytes());
51 format!("{digest:x}")
52}
53
54fn hash_recovery_code(code: &str) -> String {
55 let normalized = code.to_ascii_uppercase();
56 let digest = Sha256::digest(normalized.as_bytes());
57 format!("{digest:x}")
58}
59
60pub fn totp_uri(secret_base32: &str, account_name: &str, issuer: &str) -> String {
65 let secret_bytes = Secret::Encoded(secret_base32.to_string())
66 .to_bytes()
67 .expect("totp_uri called with invalid secret");
68 let totp = TOTP::new(
69 Algorithm::SHA1,
70 6,
71 1,
72 30,
73 secret_bytes,
74 Some(issuer.to_string()),
75 account_name.to_string(),
76 )
77 .expect("totp_uri called with invalid secret");
78 totp.get_url()
79}
80
81impl Db {
82 pub async fn create_mfa_secret(
92 &self,
93 user_id: UserId,
94 mfa_key: &[u8; 32],
95 ) -> Result<String, AuthError> {
96 let existing: Option<MfaSecretRow> = sqlx::query_as(
97 "SELECT id, user_id, secret, enabled, created_at \
98 FROM allowthem_mfa_secrets WHERE user_id = ?",
99 )
100 .bind(user_id)
101 .fetch_optional(self.pool())
102 .await?;
103
104 if let Some(row) = existing {
105 if row.enabled {
106 return Err(AuthError::MfaAlreadyEnabled);
107 }
108 sqlx::query("DELETE FROM allowthem_mfa_secrets WHERE id = ?")
110 .bind(row.id)
111 .execute(self.pool())
112 .await?;
113 }
114
115 let secret = Secret::generate_secret();
116 let secret_base32 = secret.to_encoded().to_string();
117
118 let encrypted = mfa_encrypt::encrypt_secret(secret_base32.as_bytes(), mfa_key)?;
119 let id = MfaSecretId::new();
120
121 sqlx::query(
122 "INSERT INTO allowthem_mfa_secrets (id, user_id, secret, enabled) \
123 VALUES (?, ?, ?, 0)",
124 )
125 .bind(id)
126 .bind(user_id)
127 .bind(&encrypted)
128 .execute(self.pool())
129 .await?;
130
131 Ok(secret_base32)
132 }
133
134 pub async fn enable_mfa(
142 &self,
143 user_id: UserId,
144 code: &str,
145 mfa_key: &[u8; 32],
146 ) -> Result<Vec<String>, AuthError> {
147 let mut tx = self.pool().begin().await.map_err(AuthError::Database)?;
148
149 let row: MfaSecretRow = sqlx::query_as(
150 "SELECT id, user_id, secret, enabled, created_at \
151 FROM allowthem_mfa_secrets WHERE user_id = ? AND enabled = 0",
152 )
153 .bind(user_id)
154 .fetch_optional(&mut *tx)
155 .await
156 .map_err(AuthError::Database)?
157 .ok_or(AuthError::MfaNotEnabled)?;
158
159 let secret_bytes = mfa_encrypt::decrypt_secret(&row.secret, mfa_key)?;
160 let secret_base32 =
161 String::from_utf8(secret_bytes).map_err(|e| AuthError::MfaEncryption(e.to_string()))?;
162 let totp = build_totp(&secret_base32)?;
163
164 if !totp
165 .check_current(code)
166 .map_err(|e| AuthError::MfaEncryption(e.to_string()))?
167 {
168 return Err(AuthError::InvalidTotpCode);
169 }
170
171 sqlx::query("UPDATE allowthem_mfa_secrets SET enabled = 1 WHERE id = ?")
172 .bind(row.id)
173 .execute(&mut *tx)
174 .await
175 .map_err(AuthError::Database)?;
176
177 let mut plaintext_codes = Vec::with_capacity(RECOVERY_CODE_COUNT);
178 for _ in 0..RECOVERY_CODE_COUNT {
179 let recovery = generate_recovery_code();
180 let code_hash = hash_recovery_code(&recovery);
181 let code_id = MfaRecoveryCodeId::new();
182
183 sqlx::query(
184 "INSERT INTO allowthem_mfa_recovery_codes (id, user_id, code_hash) \
185 VALUES (?, ?, ?)",
186 )
187 .bind(code_id)
188 .bind(user_id)
189 .bind(&code_hash)
190 .execute(&mut *tx)
191 .await
192 .map_err(AuthError::Database)?;
193
194 plaintext_codes.push(recovery);
195 }
196
197 tx.commit().await.map_err(AuthError::Database)?;
198
199 Ok(plaintext_codes)
200 }
201
202 pub async fn verify_totp(
207 &self,
208 user_id: UserId,
209 code: &str,
210 mfa_key: &[u8; 32],
211 ) -> Result<bool, AuthError> {
212 let row: MfaSecretRow = sqlx::query_as(
213 "SELECT id, user_id, secret, enabled, created_at \
214 FROM allowthem_mfa_secrets WHERE user_id = ? AND enabled = 1",
215 )
216 .bind(user_id)
217 .fetch_optional(self.pool())
218 .await?
219 .ok_or(AuthError::MfaNotEnabled)?;
220
221 let secret_bytes = mfa_encrypt::decrypt_secret(&row.secret, mfa_key)?;
222 let secret_base32 =
223 String::from_utf8(secret_bytes).map_err(|e| AuthError::MfaEncryption(e.to_string()))?;
224 let totp = build_totp(&secret_base32)?;
225
226 totp.check_current(code)
227 .map_err(|e| AuthError::MfaEncryption(e.to_string()))
228 }
229
230 pub async fn has_mfa_enabled(&self, user_id: UserId) -> Result<bool, AuthError> {
232 let count: (i64,) = sqlx::query_as(
233 "SELECT COUNT(*) FROM allowthem_mfa_secrets \
234 WHERE user_id = ? AND enabled = 1",
235 )
236 .bind(user_id)
237 .fetch_one(self.pool())
238 .await?;
239
240 Ok(count.0 > 0)
241 }
242
243 pub async fn disable_mfa(&self, user_id: UserId) -> Result<(), AuthError> {
247 let mut tx = self.pool().begin().await.map_err(AuthError::Database)?;
248
249 sqlx::query("DELETE FROM allowthem_mfa_recovery_codes WHERE user_id = ?")
250 .bind(user_id)
251 .execute(&mut *tx)
252 .await
253 .map_err(AuthError::Database)?;
254
255 sqlx::query("DELETE FROM allowthem_mfa_secrets WHERE user_id = ?")
256 .bind(user_id)
257 .execute(&mut *tx)
258 .await
259 .map_err(AuthError::Database)?;
260
261 tx.commit().await.map_err(AuthError::Database)?;
262
263 Ok(())
264 }
265
266 pub async fn verify_recovery_code(
272 &self,
273 user_id: UserId,
274 code: &str,
275 ) -> Result<bool, AuthError> {
276 let code_hash = hash_recovery_code(code);
277 let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
278
279 let row: Option<(MfaRecoveryCodeId,)> = sqlx::query_as(
280 "UPDATE allowthem_mfa_recovery_codes SET used_at = ?1 \
281 WHERE user_id = ?2 AND code_hash = ?3 AND used_at IS NULL \
282 RETURNING id",
283 )
284 .bind(&now)
285 .bind(user_id)
286 .bind(&code_hash)
287 .fetch_optional(self.pool())
288 .await?;
289
290 Ok(row.is_some())
291 }
292
293 pub async fn remaining_recovery_codes(&self, user_id: UserId) -> Result<i64, AuthError> {
295 let count: (i64,) = sqlx::query_as(
296 "SELECT COUNT(*) FROM allowthem_mfa_recovery_codes \
297 WHERE user_id = ? AND used_at IS NULL",
298 )
299 .bind(user_id)
300 .fetch_one(self.pool())
301 .await?;
302
303 Ok(count.0)
304 }
305
306 pub async fn regenerate_recovery_codes(
311 &self,
312 user_id: UserId,
313 ) -> Result<Vec<String>, AuthError> {
314 let mut tx = self.pool().begin().await.map_err(AuthError::Database)?;
315
316 sqlx::query("DELETE FROM allowthem_mfa_recovery_codes WHERE user_id = ?")
317 .bind(user_id)
318 .execute(&mut *tx)
319 .await
320 .map_err(AuthError::Database)?;
321
322 let mut plaintext_codes = Vec::with_capacity(RECOVERY_CODE_COUNT);
323 for _ in 0..RECOVERY_CODE_COUNT {
324 let code = generate_recovery_code();
325 let code_hash = hash_recovery_code(&code);
326 let code_id = MfaRecoveryCodeId::new();
327
328 sqlx::query(
329 "INSERT INTO allowthem_mfa_recovery_codes (id, user_id, code_hash) \
330 VALUES (?, ?, ?)",
331 )
332 .bind(code_id)
333 .bind(user_id)
334 .bind(&code_hash)
335 .execute(&mut *tx)
336 .await
337 .map_err(AuthError::Database)?;
338
339 plaintext_codes.push(code);
340 }
341
342 tx.commit().await.map_err(AuthError::Database)?;
343
344 Ok(plaintext_codes)
345 }
346
347 pub async fn create_mfa_challenge(&self, user_id: UserId) -> Result<String, AuthError> {
354 use crate::sessions::generate_token;
355 use crate::types::MfaChallengeId;
356
357 let token = generate_token();
358 let token_hash = hash_mfa_challenge(token.as_str());
359 let id = MfaChallengeId::new();
360 let expires_at = Utc::now() + chrono::Duration::minutes(5);
361 let expires_at_str = expires_at.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
362
363 sqlx::query(
364 "INSERT INTO allowthem_mfa_challenges (id, token_hash, user_id, expires_at) \
365 VALUES (?, ?, ?, ?)",
366 )
367 .bind(id)
368 .bind(&token_hash)
369 .bind(user_id)
370 .bind(&expires_at_str)
371 .execute(self.pool())
372 .await?;
373
374 Ok(token.as_str().to_string())
375 }
376
377 pub async fn validate_mfa_challenge(
383 &self,
384 raw_token: &str,
385 ) -> Result<Option<UserId>, AuthError> {
386 let token_hash = hash_mfa_challenge(raw_token);
387 let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
388
389 let row: Option<(UserId,)> = sqlx::query_as(
390 "SELECT user_id FROM allowthem_mfa_challenges \
391 WHERE token_hash = ? AND expires_at > ?",
392 )
393 .bind(&token_hash)
394 .bind(&now)
395 .fetch_optional(self.pool())
396 .await?;
397
398 Ok(row.map(|(uid,)| uid))
399 }
400
401 pub async fn consume_mfa_challenge(&self, raw_token: &str) -> Result<(), AuthError> {
405 let token_hash = hash_mfa_challenge(raw_token);
406
407 sqlx::query("DELETE FROM allowthem_mfa_challenges WHERE token_hash = ?")
408 .bind(&token_hash)
409 .execute(self.pool())
410 .await?;
411
412 Ok(())
413 }
414}
415
416use crate::handle::AllowThem;
417
418impl AllowThem {
419 pub async fn create_mfa_secret(&self, user_id: UserId) -> Result<String, AuthError> {
420 self.db().create_mfa_secret(user_id, self.mfa_key()?).await
421 }
422
423 pub async fn enable_mfa(&self, user_id: UserId, code: &str) -> Result<Vec<String>, AuthError> {
424 self.db().enable_mfa(user_id, code, self.mfa_key()?).await
425 }
426
427 pub async fn verify_totp(&self, user_id: UserId, code: &str) -> Result<bool, AuthError> {
428 self.db().verify_totp(user_id, code, self.mfa_key()?).await
429 }
430
431 pub async fn has_mfa_enabled(&self, user_id: UserId) -> Result<bool, AuthError> {
432 self.db().has_mfa_enabled(user_id).await
433 }
434
435 pub async fn disable_mfa(&self, user_id: UserId) -> Result<(), AuthError> {
436 self.db().disable_mfa(user_id).await
437 }
438
439 pub async fn verify_recovery_code(
440 &self,
441 user_id: UserId,
442 code: &str,
443 ) -> Result<bool, AuthError> {
444 self.db().verify_recovery_code(user_id, code).await
445 }
446
447 pub async fn remaining_recovery_codes(&self, user_id: UserId) -> Result<i64, AuthError> {
448 self.db().remaining_recovery_codes(user_id).await
449 }
450
451 pub async fn regenerate_recovery_codes(
452 &self,
453 user_id: UserId,
454 ) -> Result<Vec<String>, AuthError> {
455 self.db().regenerate_recovery_codes(user_id).await
456 }
457}
458
#[cfg(test)]
mod tests {
    use std::time::{SystemTime, UNIX_EPOCH};

    use crate::db::Db;
    use crate::error::AuthError;
    use crate::handle::AllowThemBuilder;
    use crate::types::Email;

    use super::*;

    const TEST_MFA_KEY: [u8; 32] = [0x42; 32];

    async fn test_db() -> Db {
        Db::connect("sqlite::memory:").await.expect("in-memory db")
    }

    async fn make_user(db: &Db) -> UserId {
        let email = Email::new("mfa@example.com".to_string()).unwrap();
        db.create_user(email, "password123", None).await.unwrap().id
    }

    /// Creates a secret for `user_id`, enables it with the current code and
    /// returns the issued recovery codes.
    async fn setup_and_enable_mfa(db: &Db, user_id: UserId) -> Vec<String> {
        let secret_b32 = db.create_mfa_secret(user_id, &TEST_MFA_KEY).await.unwrap();
        let totp = build_totp(&secret_b32).unwrap();
        let code = totp.generate_current().unwrap();
        db.enable_mfa(user_id, &code, &TEST_MFA_KEY).await.unwrap()
    }

    /// Returns a 6-digit code that `totp` will not accept right now.
    ///
    /// A fixed literal such as `"000000"` has a small but real chance of being
    /// the current code or a neighbouring step within the skew window, which
    /// makes "wrong code" tests flaky. This helper excludes the codes for two
    /// steps either side of now. That covers the ±1 skew window plus one step
    /// of clock movement while the test runs.
    fn wrong_code(totp: &TOTP) -> String {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock before Unix epoch")
            .as_secs();
        let nearby: Vec<String> = [now - 60, now - 30, now, now + 30, now + 60]
            .iter()
            .map(|&t| totp.generate(t))
            .collect();
        (0u32..)
            .map(|n| format!("{n:06}"))
            .find(|candidate| !nearby.contains(candidate))
            .expect("at most five codes are excluded")
    }

    #[tokio::test]
    async fn totp_validation() {
        let secret = Secret::generate_secret();
        let secret_b32 = secret.to_encoded().to_string();
        let totp = build_totp(&secret_b32).unwrap();
        let code = totp.generate_current().unwrap();
        let valid = totp
            .check_current(&code)
            .expect("check_current should not fail");
        assert!(valid, "generated code must validate");
    }

    #[tokio::test]
    async fn totp_uri_format() {
        let secret = Secret::generate_secret();
        let secret_b32 = secret.to_encoded().to_string();
        let uri = totp_uri(&secret_b32, "user@example.com", "allowthem");
        assert!(
            uri.starts_with("otpauth://totp/"),
            "URI must start with otpauth://totp/"
        );
        assert!(
            uri.contains("user%40example.com"),
            "URI must contain account name"
        );
        assert!(uri.contains("allowthem"), "URI must contain issuer");
    }

    #[tokio::test]
    async fn create_and_enable_flow() {
        let db = test_db().await;
        let user_id = make_user(&db).await;

        let secret_b32 = db.create_mfa_secret(user_id, &TEST_MFA_KEY).await.unwrap();
        let totp = build_totp(&secret_b32).unwrap();
        let code = totp.generate_current().unwrap();

        let recovery_codes = db.enable_mfa(user_id, &code, &TEST_MFA_KEY).await.unwrap();
        assert_eq!(recovery_codes.len(), 10, "must return 10 recovery codes");

        let enabled = db.has_mfa_enabled(user_id).await.unwrap();
        assert!(enabled, "MFA must be enabled after enable_mfa");
    }

    #[tokio::test]
    async fn enable_rejects_wrong_code() {
        let db = test_db().await;
        let user_id = make_user(&db).await;
        let secret_b32 = db.create_mfa_secret(user_id, &TEST_MFA_KEY).await.unwrap();
        let wrong = wrong_code(&build_totp(&secret_b32).unwrap());

        let result = db.enable_mfa(user_id, &wrong, &TEST_MFA_KEY).await;
        assert!(
            matches!(result, Err(AuthError::InvalidTotpCode)),
            "wrong code must return InvalidTotpCode"
        );
    }

    #[tokio::test]
    async fn double_enable_blocked() {
        let db = test_db().await;
        let user_id = make_user(&db).await;
        setup_and_enable_mfa(&db, user_id).await;

        let result = db.create_mfa_secret(user_id, &TEST_MFA_KEY).await;
        assert!(
            matches!(result, Err(AuthError::MfaAlreadyEnabled)),
            "second create must return MfaAlreadyEnabled"
        );
    }

    #[tokio::test]
    async fn abandoned_setup_replacement() {
        let db = test_db().await;
        let user_id = make_user(&db).await;

        let secret_a = db.create_mfa_secret(user_id, &TEST_MFA_KEY).await.unwrap();
        let secret_b = db.create_mfa_secret(user_id, &TEST_MFA_KEY).await.unwrap();
        assert_ne!(secret_a, secret_b, "replacement must produce a new secret");

        let totp = build_totp(&secret_b).unwrap();
        let code = totp.generate_current().unwrap();
        let result = db.enable_mfa(user_id, &code, &TEST_MFA_KEY).await;
        assert!(result.is_ok(), "enable with new secret must succeed");
    }

    #[tokio::test]
    async fn verify_totp_valid_and_invalid() {
        let db = test_db().await;
        let user_id = make_user(&db).await;

        let secret_b32 = db.create_mfa_secret(user_id, &TEST_MFA_KEY).await.unwrap();
        let totp = build_totp(&secret_b32).unwrap();
        let code = totp.generate_current().unwrap();
        db.enable_mfa(user_id, &code, &TEST_MFA_KEY).await.unwrap();

        let fresh_code = totp.generate_current().unwrap();
        let valid = db
            .verify_totp(user_id, &fresh_code, &TEST_MFA_KEY)
            .await
            .unwrap();
        assert!(valid, "correct TOTP code must validate");

        let invalid = db
            .verify_totp(user_id, &wrong_code(&totp), &TEST_MFA_KEY)
            .await
            .unwrap();
        assert!(!invalid, "wrong TOTP code must return false");
    }

    #[tokio::test]
    async fn verify_totp_no_mfa() {
        let db = test_db().await;
        let user_id = make_user(&db).await;

        let result = db.verify_totp(user_id, "123456", &TEST_MFA_KEY).await;
        assert!(
            matches!(result, Err(AuthError::MfaNotEnabled)),
            "verify_totp on non-MFA user must return MfaNotEnabled"
        );
    }

    #[tokio::test]
    async fn recovery_code_consumption() {
        let db = test_db().await;
        let user_id = make_user(&db).await;
        let codes = setup_and_enable_mfa(&db, user_id).await;

        let consumed = db.verify_recovery_code(user_id, &codes[0]).await.unwrap();
        assert!(consumed, "valid recovery code must be consumed");

        let reuse = db.verify_recovery_code(user_id, &codes[0]).await.unwrap();
        assert!(!reuse, "used recovery code must not be reusable");

        let remaining = db.remaining_recovery_codes(user_id).await.unwrap();
        assert_eq!(remaining, 9, "one code consumed, 9 remaining");
    }

    #[tokio::test]
    async fn recovery_code_wrong() {
        let db = test_db().await;
        let user_id = make_user(&db).await;
        setup_and_enable_mfa(&db, user_id).await;

        let result = db.verify_recovery_code(user_id, "ZZZZZZZZ").await.unwrap();
        assert!(!result, "wrong recovery code must return false");
    }

    #[tokio::test]
    async fn recovery_code_case_insensitive() {
        let db = test_db().await;
        let user_id = make_user(&db).await;
        let codes = setup_and_enable_mfa(&db, user_id).await;

        let consumed = db
            .verify_recovery_code(user_id, &codes[1].to_lowercase())
            .await
            .unwrap();
        assert!(consumed, "lowercase recovery code must match");
    }

    #[tokio::test]
    async fn disable_mfa_cleans_up() {
        let db = test_db().await;
        let user_id = make_user(&db).await;
        setup_and_enable_mfa(&db, user_id).await;

        db.disable_mfa(user_id).await.unwrap();

        let enabled = db.has_mfa_enabled(user_id).await.unwrap();
        assert!(!enabled, "MFA must not be enabled after disable");

        let remaining = db.remaining_recovery_codes(user_id).await.unwrap();
        assert_eq!(remaining, 0, "recovery codes must be deleted");
    }

    #[tokio::test]
    async fn user_deletion_cascades() {
        let db = test_db().await;
        let user_id = make_user(&db).await;
        setup_and_enable_mfa(&db, user_id).await;

        db.delete_user(user_id).await.unwrap();

        let enabled = db.has_mfa_enabled(user_id).await.unwrap();
        assert!(!enabled, "MFA must not be enabled after user deletion");
    }

    #[tokio::test]
    async fn mfa_not_configured_without_key() {
        let ath = AllowThemBuilder::new("sqlite::memory:")
            .build()
            .await
            .unwrap();
        let email = Email::new("nokey@example.com".to_string()).unwrap();
        let user = ath
            .db()
            .create_user(email, "password123", None)
            .await
            .unwrap();

        let result = ath.create_mfa_secret(user.id).await;
        assert!(
            matches!(result, Err(AuthError::MfaNotConfigured)),
            "MFA without key must return MfaNotConfigured"
        );
    }
}