1use crate::{AuthError, AuthResult};
7use chrono::{DateTime, Utc};
8use rand::{thread_rng, Rng};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use uuid::Uuid;
12
13#[cfg(feature = "mfa")]
14use base32;
15#[cfg(feature = "mfa")]
16use qrcode::{render::unicode, QrCode};
17#[cfg(feature = "mfa")]
18use totp_lite::{totp, Sha1};
19#[cfg(feature = "mfa")]
20use urlencoding;
21
/// Tunables for TOTP enrolment and verification.
#[derive(Debug, Clone)]
pub struct MfaConfig {
    /// Length of one TOTP time step, in seconds. Also emitted as `period` in the
    /// provisioning URI.
    pub time_step: u64,
    /// Number of extra time steps accepted before and after the current one,
    /// to absorb clock drift between server and authenticator.
    pub window_tolerance: u8,
    /// Number of random bytes in a newly generated TOTP secret, measured before
    /// base32 encoding.
    pub secret_length: usize,
    /// Issuer label written into the `otpauth://` URI. It is used both as the
    /// label prefix and as the `issuer` parameter.
    pub issuer: String,
    /// Number of one-time backup codes generated per enrolment.
    pub backup_codes_count: usize,
}
36
37impl Default for MfaConfig {
38 fn default() -> Self {
39 Self {
40 time_step: 30,
41 window_tolerance: 1,
42 secret_length: 20,
43 issuer: "elif.rs".to_string(),
44 backup_codes_count: 10,
45 }
46 }
47}
48
/// A user's stored MFA enrolment: the TOTP secret plus backup-code state.
///
/// NOTE(review): the derived `Debug` prints the raw TOTP secret and the
/// backup-code hashes. Consider a redacting `Debug` impl before this value can
/// reach logs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MfaSecret {
    /// User this enrolment belongs to.
    pub user_id: Uuid,
    /// Base32-encoded TOTP shared secret.
    pub secret: String,
    /// Hashes of the issued backup codes. The plaintext codes are not stored.
    pub backup_codes: Vec<String>,
    /// Hashes of backup codes that have already been redeemed.
    pub used_backup_codes: Vec<String>,
    /// When enrolment was confirmed. `Some` means MFA counts as enabled.
    pub setup_completed_at: Option<DateTime<Utc>>,
    /// Time of the most recent successful verification.
    pub last_verified_at: Option<DateTime<Utc>>,
    /// When this record was created.
    pub created_at: DateTime<Utc>,
}
67
/// Enrolment material handed to the user before MFA is activated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MfaSetup {
    /// QR code of `totp_uri`, rendered as a string of Unicode block characters.
    pub qr_code: String,
    /// Base32 secret, for users who type the key in by hand.
    pub manual_key: String,
    /// `otpauth://totp/...` provisioning URI.
    pub totp_uri: String,
    /// Plaintext backup codes. Only their hashes are kept after setup, so these
    /// must be shown to the user now.
    pub backup_codes: Vec<String>,
}
80
/// Outcome of an MFA verification attempt.
#[derive(Debug, Clone, PartialEq)]
pub enum MfaVerificationResult {
    /// The code matched the current TOTP window (within the configured tolerance).
    TotpSuccess,
    /// The code matched an unused backup code, which is now consumed.
    BackupCodeSuccess,
    /// The user is enrolled, but the code matched nothing.
    Failed,
    /// The user has no enrolment, or the `mfa` feature is compiled out.
    NotSetup,
}
93
94pub struct MfaProvider {
96 config: MfaConfig,
97 #[allow(dead_code)]
98 secrets: HashMap<Uuid, MfaSecret>,
99}
100
101impl MfaProvider {
102 pub fn new() -> AuthResult<Self> {
104 Self::with_config(MfaConfig::default())
105 }
106
107 pub fn with_config(config: MfaConfig) -> AuthResult<Self> {
109 Ok(Self {
110 config,
111 secrets: HashMap::new(),
112 })
113 }
114
115 pub fn generate_setup(&self, _user_id: Uuid, username: &str) -> AuthResult<MfaSetup> {
117 #[cfg(not(feature = "mfa"))]
118 {
119 let _ = (_user_id, username);
120 return Err(AuthError::generic_error(
121 "MFA feature not enabled - compile with 'mfa' feature",
122 ));
123 }
124
125 #[cfg(feature = "mfa")]
126 {
127 let secret_bytes = self.generate_secret();
129 let secret_base32 =
130 base32::encode(base32::Alphabet::Rfc4648 { padding: true }, &secret_bytes);
131
132 let backup_codes = self.generate_backup_codes();
134
135 let totp_uri = format!(
137 "otpauth://totp/{}:{}?secret={}&issuer={}&digits=6&period={}",
138 urlencoding::encode(&self.config.issuer),
139 urlencoding::encode(username),
140 secret_base32,
141 urlencoding::encode(&self.config.issuer),
142 self.config.time_step
143 );
144
145 let qr_code = self.generate_qr_code(&totp_uri)?;
147
148 Ok(MfaSetup {
149 qr_code,
150 manual_key: secret_base32,
151 totp_uri,
152 backup_codes,
153 })
154 }
155 }
156
157 pub fn complete_setup(
159 &mut self,
160 user_id: Uuid,
161 setup: &MfaSetup,
162 totp_code: &str,
163 ) -> AuthResult<()> {
164 #[cfg(not(feature = "mfa"))]
165 {
166 let _ = (user_id, setup, totp_code);
167 return Err(AuthError::generic_error(
168 "MFA feature not enabled - compile with 'mfa' feature",
169 ));
170 }
171
172 #[cfg(feature = "mfa")]
173 {
174 if !self.verify_totp_code(&setup.manual_key, totp_code)? {
176 return Err(AuthError::invalid_credentials("Invalid TOTP code"));
177 }
178
179 let hashed_backup_codes = setup
181 .backup_codes
182 .iter()
183 .map(|code| self.hash_backup_code(code))
184 .collect::<Result<Vec<_>, _>>()?;
185
186 let secret = MfaSecret {
188 user_id,
189 secret: setup.manual_key.clone(),
190 backup_codes: hashed_backup_codes,
191 used_backup_codes: Vec::new(),
192 setup_completed_at: Some(Utc::now()),
193 last_verified_at: Some(Utc::now()),
194 created_at: Utc::now(),
195 };
196
197 self.secrets.insert(user_id, secret);
198 Ok(())
199 }
200 }
201
202 pub fn verify_mfa(&mut self, user_id: Uuid, code: &str) -> AuthResult<MfaVerificationResult> {
204 #[cfg(not(feature = "mfa"))]
205 {
206 let _ = (user_id, code);
207 return Ok(MfaVerificationResult::NotSetup);
208 }
209
210 #[cfg(feature = "mfa")]
211 {
212 if !self.secrets.contains_key(&user_id) {
213 return Ok(MfaVerificationResult::NotSetup);
214 }
215
216 let totp_secret = {
218 let secret = self.secrets.get(&user_id).unwrap();
219 secret.secret.clone()
220 };
221
222 if self.verify_totp_code(&totp_secret, code)? {
223 if let Some(secret) = self.secrets.get_mut(&user_id) {
225 secret.last_verified_at = Some(Utc::now());
226 }
227 return Ok(MfaVerificationResult::TotpSuccess);
228 }
229
230 let backup_code_valid = if let Some(secret) = self.secrets.get_mut(&user_id) {
232 let backup_codes = secret.backup_codes.clone();
234 let used_codes = secret.used_backup_codes.clone();
235
236 let code_hash = self.hash_backup_code(code)?;
237
238 backup_codes.contains(&code_hash) && !used_codes.contains(&code_hash)
240 } else {
241 false
242 };
243
244 if backup_code_valid {
245 let code_hash = self.hash_backup_code(code)?;
247
248 if let Some(secret) = self.secrets.get_mut(&user_id) {
250 secret.used_backup_codes.push(code_hash);
251 secret.last_verified_at = Some(Utc::now());
252 }
253 return Ok(MfaVerificationResult::BackupCodeSuccess);
254 }
255
256 Ok(MfaVerificationResult::Failed)
257 }
258 }
259
260 pub fn is_mfa_enabled(&self, user_id: Uuid) -> bool {
262 self.secrets
263 .get(&user_id)
264 .map(|secret| secret.setup_completed_at.is_some())
265 .unwrap_or(false)
266 }
267
268 pub fn disable_mfa(&mut self, user_id: Uuid) -> AuthResult<bool> {
270 Ok(self.secrets.remove(&user_id).is_some())
271 }
272
273 pub fn get_remaining_backup_codes_count(&self, user_id: Uuid) -> AuthResult<usize> {
275 match self.secrets.get(&user_id) {
276 Some(secret) => {
277 let remaining = secret
279 .backup_codes
280 .len()
281 .saturating_sub(secret.used_backup_codes.len());
282 Ok(remaining)
283 }
284 None => Ok(0),
285 }
286 }
287
288 pub fn regenerate_backup_codes(&mut self, user_id: Uuid) -> AuthResult<Vec<String>> {
290 #[cfg(not(feature = "mfa"))]
291 {
292 let _ = user_id;
293 return Err(AuthError::generic_error(
294 "MFA feature not enabled - compile with 'mfa' feature",
295 ));
296 }
297
298 #[cfg(feature = "mfa")]
299 {
300 if !self.secrets.contains_key(&user_id) {
301 return Err(AuthError::not_found("MFA not setup for user"));
302 }
303
304 let new_backup_codes = self.generate_backup_codes();
306 let hashed_codes = new_backup_codes
307 .iter()
308 .map(|code| self.hash_backup_code(code))
309 .collect::<Result<Vec<_>, _>>()?;
310
311 if let Some(secret) = self.secrets.get_mut(&user_id) {
313 secret.backup_codes = hashed_codes;
314 secret.used_backup_codes.clear();
315 }
316
317 Ok(new_backup_codes)
318 }
319 }
320
321 #[cfg(feature = "mfa")]
324 fn generate_secret(&self) -> Vec<u8> {
325 let mut secret = vec![0u8; self.config.secret_length];
326 thread_rng().fill(&mut secret[..]);
327 secret
328 }
329
330 #[cfg(feature = "mfa")]
331 fn generate_backup_codes(&self) -> Vec<String> {
332 let mut codes = Vec::with_capacity(self.config.backup_codes_count);
333 let mut rng = thread_rng();
334
335 for _ in 0..self.config.backup_codes_count {
336 let code = format!("{:08}", rng.gen_range(10000000..99999999));
338 codes.push(code);
339 }
340
341 codes
342 }
343
344 #[cfg(feature = "mfa")]
345 fn generate_qr_code(&self, totp_uri: &str) -> AuthResult<String> {
346 let qr_code = QrCode::new(totp_uri)
347 .map_err(|e| AuthError::generic_error(format!("Failed to generate QR code: {}", e)))?;
348
349 let qr_string = qr_code
350 .render::<unicode::Dense1x2>()
351 .dark_color(unicode::Dense1x2::Light)
352 .light_color(unicode::Dense1x2::Dark)
353 .build();
354
355 Ok(qr_string)
356 }
357
358 #[cfg(feature = "mfa")]
359 fn verify_totp_code(&self, secret_base32: &str, code: &str) -> AuthResult<bool> {
360 let secret = base32::decode(base32::Alphabet::Rfc4648 { padding: true }, secret_base32)
362 .ok_or_else(|| AuthError::generic_error("Invalid secret format"))?;
363
364 if code.len() != 6 || !code.chars().all(|c| c.is_ascii_digit()) {
366 return Err(AuthError::invalid_credentials("Invalid TOTP code format"));
367 }
368
369 let current_time = Utc::now().timestamp() as u64;
370
371 for i in 0..=(self.config.window_tolerance * 2) {
373 let time_offset = (i as i64) - (self.config.window_tolerance as i64);
374 let time_window =
375 ((current_time as i64) + (time_offset * self.config.time_step as i64)) as u64;
376
377 let expected_code = totp::<Sha1>(&secret, time_window);
379
380 let expected_code_6digits = if expected_code.len() >= 6 {
383 &expected_code[expected_code.len() - 6..]
384 } else {
385 &expected_code
386 };
387
388 if expected_code_6digits == code {
389 return Ok(true);
390 }
391 }
392
393 Ok(false)
394 }
395
396 #[cfg(feature = "mfa")]
397 pub fn verify_backup_code(&mut self, secret: &mut MfaSecret, code: &str) -> AuthResult<bool> {
398 let code_hash = self.hash_backup_code(code)?;
399
400 if secret.backup_codes.contains(&code_hash)
402 && !secret.used_backup_codes.contains(&code_hash)
403 {
404 secret.used_backup_codes.push(code_hash);
405 return Ok(true);
406 }
407
408 Ok(false)
409 }
410
411 #[cfg(feature = "mfa")]
412 pub fn hash_backup_code(&self, code: &str) -> AuthResult<String> {
413 use std::collections::hash_map::DefaultHasher;
415 use std::hash::{Hash, Hasher};
416
417 let mut hasher = DefaultHasher::new();
418 code.hash(&mut hasher);
419 Ok(format!("{:x}", hasher.finish()))
420 }
421}
422
423impl Default for MfaProvider {
424 fn default() -> Self {
425 Self::new().expect("Failed to create default MFA provider")
426 }
427}
428
#[cfg(test)]
mod tests {
    //! Unit tests for `MfaProvider`.
    //!
    //! Every test is synchronous, so plain `#[test]` is used instead of an
    //! async runtime.
    use super::*;

    #[test]
    fn test_mfa_provider_creation() {
        let provider = MfaProvider::new();
        assert!(provider.is_ok());
    }

    #[test]
    fn test_mfa_config_defaults() {
        let config = MfaConfig::default();
        assert_eq!(config.time_step, 30);
        assert_eq!(config.window_tolerance, 1);
        assert_eq!(config.secret_length, 20);
        assert_eq!(config.issuer, "elif.rs");
        assert_eq!(config.backup_codes_count, 10);
    }

    #[test]
    fn test_mfa_provider_with_custom_config() {
        let config = MfaConfig {
            time_step: 60,
            window_tolerance: 2,
            secret_length: 32,
            issuer: "test-app".to_string(),
            backup_codes_count: 12,
        };

        let provider = MfaProvider::with_config(config.clone());
        assert!(provider.is_ok());

        let provider = provider.unwrap();
        assert_eq!(provider.config.time_step, 60);
        assert_eq!(provider.config.issuer, "test-app");
    }

    #[cfg(feature = "mfa")]
    #[test]
    fn test_mfa_setup_generation() {
        let provider = MfaProvider::new().unwrap();
        let user_id = Uuid::new_v4();
        let username = "testuser";

        let setup = provider.generate_setup(user_id, username);
        assert!(setup.is_ok());

        let setup = setup.unwrap();
        assert!(!setup.qr_code.is_empty());
        assert!(!setup.manual_key.is_empty());
        assert!(setup.totp_uri.contains("otpauth://totp/"));
        assert!(setup.totp_uri.contains(username));
        assert_eq!(setup.backup_codes.len(), 10);
    }

    #[test]
    fn test_mfa_not_enabled_by_default() {
        let mut provider = MfaProvider::new().unwrap();
        let user_id = Uuid::new_v4();

        assert!(!provider.is_mfa_enabled(user_id));

        let result = provider.verify_mfa(user_id, "123456");
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), MfaVerificationResult::NotSetup);
    }

    #[test]
    fn test_mfa_disable() {
        let mut provider = MfaProvider::new().unwrap();
        let user_id = Uuid::new_v4();

        // Disabling a user that never enrolled reports that nothing was removed.
        let disabled = provider.disable_mfa(user_id).unwrap();
        assert!(!disabled);
    }

    #[test]
    fn test_backup_codes_count() {
        let provider = MfaProvider::new().unwrap();
        let user_id = Uuid::new_v4();

        let count = provider.get_remaining_backup_codes_count(user_id);
        assert!(count.is_ok());
        assert_eq!(count.unwrap(), 0);
    }

    #[test]
    fn test_backup_codes_count_underflow() {
        let mut provider = MfaProvider::new().unwrap();
        let user_id = Uuid::new_v4();

        // More "used" codes than issued codes must saturate at zero, not underflow.
        provider.secrets.insert(
            user_id,
            MfaSecret {
                user_id,
                secret: "dummy".to_string(),
                backup_codes: Vec::new(),
                used_backup_codes: vec!["used".to_string()],
                setup_completed_at: None,
                last_verified_at: None,
                created_at: Utc::now(),
            },
        );

        let count = provider.get_remaining_backup_codes_count(user_id);
        assert!(count.is_ok());
        assert_eq!(count.unwrap(), 0);
    }

    #[cfg(not(feature = "mfa"))]
    #[test]
    fn test_mfa_disabled_error_messages() {
        let provider = MfaProvider::new().unwrap();
        let user_id = Uuid::new_v4();

        let setup_result = provider.generate_setup(user_id, "testuser");
        assert!(setup_result.is_err());
        assert!(setup_result
            .unwrap_err()
            .to_string()
            .contains("MFA feature not enabled"));
    }

    #[cfg(feature = "mfa")]
    #[test]
    fn test_backup_code_generation() {
        let provider = MfaProvider::new().unwrap();
        let codes = provider.generate_backup_codes();

        assert_eq!(codes.len(), 10);
        for code in &codes {
            assert_eq!(code.len(), 8);
            assert!(code.chars().all(|c| c.is_ascii_digit()));
        }

        let unique_codes: std::collections::HashSet<_> = codes.iter().collect();
        assert_eq!(unique_codes.len(), codes.len());
    }

    #[cfg(feature = "mfa")]
    #[test]
    fn test_secret_generation() {
        let provider = MfaProvider::new().unwrap();
        let secret1 = provider.generate_secret();
        let secret2 = provider.generate_secret();

        assert_eq!(secret1.len(), 20);
        assert_eq!(secret2.len(), 20);
        assert_ne!(secret1, secret2);
    }

    #[cfg(feature = "mfa")]
    #[test]
    fn test_totp_uri_format() {
        let provider = MfaProvider::new().unwrap();
        let user_id = Uuid::new_v4();
        let username = "test@example.com";

        let setup = provider.generate_setup(user_id, username).unwrap();

        assert!(setup.totp_uri.starts_with("otpauth://totp/"));
        assert!(setup.totp_uri.contains("elif.rs"));
        // `@` must be percent-encoded in the account label.
        assert!(setup.totp_uri.contains("test%40example.com"));
        assert!(setup.totp_uri.contains("digits=6"));
        assert!(setup.totp_uri.contains("period=30"));
    }

    #[cfg(feature = "mfa")]
    #[test]
    fn test_backup_code_hash_is_deterministic() {
        let provider = MfaProvider::new().unwrap();
        let first = provider.hash_backup_code("12345678").unwrap();
        let second = provider.hash_backup_code("12345678").unwrap();
        let other = provider.hash_backup_code("12345679").unwrap();

        assert_eq!(first, second);
        assert_ne!(first, other);
    }

    #[cfg(feature = "mfa")]
    #[test]
    fn test_verify_backup_code_is_single_use() {
        let mut provider = MfaProvider::new().unwrap();
        let user_id = Uuid::new_v4();
        let issued = provider.hash_backup_code("12345678").unwrap();

        let mut secret = MfaSecret {
            user_id,
            secret: "dummy".to_string(),
            backup_codes: vec![issued],
            used_backup_codes: Vec::new(),
            setup_completed_at: None,
            last_verified_at: None,
            created_at: Utc::now(),
        };

        assert!(provider.verify_backup_code(&mut secret, "12345678").unwrap());
        assert!(!provider.verify_backup_code(&mut secret, "12345678").unwrap());
        assert!(!provider.verify_backup_code(&mut secret, "87654321").unwrap());
        assert_eq!(secret.used_backup_codes.len(), 1);
    }

    #[cfg(feature = "mfa")]
    #[test]
    fn test_regenerate_backup_codes_requires_setup() {
        let mut provider = MfaProvider::new().unwrap();
        let user_id = Uuid::new_v4();

        assert!(provider.regenerate_backup_codes(user_id).is_err());
    }
}