1use crate::error::{SecurityError, Result};
4use crate::config::MfaConfig;
5use qrcode::QrCode;
6use qrcode::render::svg;
7use totp_rs::{Algorithm, Secret, TOTP};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use rand::Rng;
11
/// Complete description of one TOTP enrolment: the shared secret together with
/// the parameters needed to generate or verify codes and to build a provisioning URI.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MfaSecret {
    /// Shared secret as a hex string (produced with `hex::encode` from 20 random bytes).
    pub secret: String,
    /// Hash algorithm used when computing codes.
    pub algorithm: MfaAlgorithm,
    /// Number of digits in each generated code.
    pub digits: usize,
    /// Skew value handed to `TOTP::new` when verifying codes.
    pub skew: u8,
    /// Length of one time step; used as a number of seconds against the Unix clock.
    pub step: u64,
    /// Issuer shown in the provisioning URI.
    pub issuer: String,
    /// Account label shown in the provisioning URI.
    pub account_name: String,
}
23
/// A one-time code paired with a timestamp.
///
/// NOTE(review): this type is not referenced anywhere else in this file; confirm
/// its intended use with the callers.
#[derive(Debug, Clone)]
pub struct MfaCode {
    /// The code text.
    pub code: String,
    /// Timestamp associated with the code (presumably Unix seconds — TODO confirm).
    pub timestamp: u64,
}
30
/// Serializable selector for the hash algorithm used by TOTP.
///
/// Converts into `totp_rs::Algorithm` through the `From` impl in this file.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum MfaAlgorithm {
    /// Maps to `Algorithm::SHA1`; the only algorithm the generators in this file emit.
    SHA1,
    /// Maps to `Algorithm::SHA256`.
    SHA256,
    /// Maps to `Algorithm::SHA512`.
    SHA512,
}
38
39impl From<MfaAlgorithm> for Algorithm {
40 fn from(alg: MfaAlgorithm) -> Self {
41 match alg {
42 MfaAlgorithm::SHA1 => Algorithm::SHA1,
43 MfaAlgorithm::SHA256 => Algorithm::SHA256,
44 MfaAlgorithm::SHA512 => Algorithm::SHA512,
45 }
46 }
47}
48
/// TOTP-based multi-factor authentication service.
pub struct MfaService {
    /// Supplies the digits, skew, step, issuer and QR code size used by the service methods.
    config: MfaConfig,
}
53
54impl MfaService {
55 pub fn new() -> Self {
57 Self {
58 config: MfaConfig::default(),
59 }
60 }
61
62 pub fn with_config(config: MfaConfig) -> Self {
64 Self { config }
65 }
66
67 pub fn generate_secret(&self, account_name: &str) -> Result<(String, String)> {
69 let mut rng = rand::thread_rng();
70
71 let secret_bytes: [u8; 20] = rng.gen();
73 let secret = Secret::Encoded(hex::encode(secret_bytes));
74
75 let _totp = TOTP::new(
77 Algorithm::SHA1,
78 self.config.digits.into(),
79 self.config.skew,
80 self.config.step,
81 secret.to_bytes()
82 .map_err(|e| SecurityError::Mfa(format!("Failed to decode secret: {}", e)))?,
83 ).map_err(|e| SecurityError::Mfa(format!("Failed to create TOTP: {}", e)))?;
84
85 let issuer_encoded = self.config.issuer.replace(" ", "%20").replace(":", "%3A");
87 let account_encoded = account_name.replace(" ", "%20").replace(":", "%3A");
88
89 let url = format!(
90 "otpauth://totp/{}:{}?secret={}&issuer={}&algorithm={}&digits={}&period={}",
91 issuer_encoded,
92 account_encoded,
93 hex::encode(&secret_bytes),
94 issuer_encoded,
95 "SHA1",
96 self.config.digits,
97 self.config.step
98 );
99
100 let qr_code = QrCode::new(url.as_bytes())
102 .map_err(|e| SecurityError::Mfa(format!("Failed to generate QR code: {}", e)))?;
103
104 let qr_svg = qr_code
105 .render()
106 .min_dimensions(self.config.qr_code_size, self.config.qr_code_size)
107 .dark_color(svg::Color("#000000"))
108 .light_color(svg::Color("#FFFFFF"))
109 .build();
110
111 let secret_hex = hex::encode(secret_bytes);
112
113 Ok((secret_hex, qr_svg))
114 }
115
116 pub fn generate_secret_detailed(&self, account_name: &str) -> Result<MfaSecret> {
118 let mut rng = rand::thread_rng();
119 let secret_bytes: [u8; 20] = rng.gen();
120 let secret_hex = hex::encode(secret_bytes);
121
122 let secret = MfaSecret {
123 secret: secret_hex,
124 algorithm: MfaAlgorithm::SHA1,
125 digits: self.config.digits as usize,
126 skew: self.config.skew,
127 step: self.config.step,
128 issuer: self.config.issuer.clone(),
129 account_name: account_name.to_string(),
130 };
131
132 Ok(secret)
133 }
134
135 pub fn generate_qr_code(&self, secret: &MfaSecret) -> Result<String> {
137 let _secret_bytes = hex::decode(&secret.secret)
138 .map_err(|e| SecurityError::Mfa(format!("Invalid secret hex: {}", e)))?;
139
140 let totp_secret = Secret::Encoded(secret.secret.clone());
141
142 let _totp = TOTP::new(
143 secret.algorithm.clone().into(),
144 secret.digits,
145 secret.skew,
146 secret.step,
147 totp_secret.to_bytes()
148 .map_err(|e| SecurityError::Mfa(format!("Failed to decode secret: {}", e)))?,
149 ).map_err(|e| SecurityError::Mfa(format!("Failed to create TOTP: {}", e)))?;
150
151 let issuer_encoded = secret.issuer.replace(" ", "%20").replace(":", "%3A");
153 let account_encoded = secret.account_name.replace(" ", "%20").replace(":", "%3A");
154
155 let url = format!(
156 "otpauth://totp/{}:{}?secret={}&issuer={}&algorithm={}&digits={}&period={}",
157 issuer_encoded,
158 account_encoded,
159 secret.secret,
160 issuer_encoded,
161 match secret.algorithm {
162 MfaAlgorithm::SHA1 => "SHA1",
163 MfaAlgorithm::SHA256 => "SHA256",
164 MfaAlgorithm::SHA512 => "SHA512",
165 },
166 secret.digits,
167 secret.step
168 );
169
170 let qr_code = QrCode::new(url.as_bytes())
171 .map_err(|e| SecurityError::Mfa(format!("Failed to generate QR code: {}", e)))?;
172
173 let qr_svg = qr_code
174 .render()
175 .min_dimensions(self.config.qr_code_size, self.config.qr_code_size)
176 .dark_color(svg::Color("#000000"))
177 .light_color(svg::Color("#FFFFFF"))
178 .build();
179
180 Ok(qr_svg)
181 }
182
183 pub fn verify_code(&self, secret_hex: &str, code: &str) -> Result<bool> {
185 let _secret_bytes = hex::decode(secret_hex)
186 .map_err(|e| SecurityError::Mfa(format!("Invalid secret hex: {}", e)))?;
187
188 let secret = Secret::Encoded(secret_hex.to_string());
189
190 let secret_bytes = secret.to_bytes()
191 .map_err(|e| SecurityError::Mfa(format!("Failed to decode secret: {}", e)))?;
192 let totp = TOTP::new(
193 Algorithm::SHA1,
194 self.config.digits.into(),
195 self.config.skew,
196 self.config.step,
197 secret_bytes,
198 ).map_err(|e| SecurityError::Mfa(format!("Failed to create TOTP: {}", e)))?;
199
200 let current_time = std::time::SystemTime::now()
201 .duration_since(std::time::UNIX_EPOCH)
202 .map_err(|e| SecurityError::Time(e.to_string()))?
203 .as_secs();
204
205 Ok(totp.check(code, current_time))
206 }
207
208 pub fn verify_code_detailed(&self, secret: &MfaSecret, code: &str) -> Result<bool> {
210 let secret_obj = Secret::Encoded(secret.secret.clone());
211
212 let totp = TOTP::new(
213 secret.algorithm.clone().into(),
214 secret.digits,
215 secret.skew,
216 secret.step,
217 secret_obj.to_bytes().map_err(|e| SecurityError::Mfa(format!("Failed to decode secret: {}", e)))?,
218 ).map_err(|e| SecurityError::Mfa(format!("Failed to create TOTP: {}", e)))?;
219
220 let current_time = std::time::SystemTime::now()
221 .duration_since(std::time::UNIX_EPOCH)
222 .map_err(|e| SecurityError::Time(e.to_string()))?
223 .as_secs();
224
225 Ok(totp.check(code, current_time))
226 }
227
228 pub fn generate_backup_codes(&self, count: usize) -> Vec<String> {
230 let mut rng = rand::thread_rng();
231 let mut codes = Vec::with_capacity(count);
232
233 for _ in 0..count {
234 let code: u32 = rng.gen_range(100000..999999);
235 codes.push(format!("{:06}", code));
236 }
237
238 codes
239 }
240
241 pub fn get_current_code(&self, secret_hex: &str) -> Result<String> {
243 let secret = Secret::Encoded(secret_hex.to_string());
244 let secret_bytes = secret.to_bytes()
245 .map_err(|e| SecurityError::Mfa(format!("Failed to decode secret: {}", e)))?;
246
247 let totp = TOTP::new(
248 Algorithm::SHA1,
249 self.config.digits.into(),
250 self.config.skew,
251 self.config.step,
252 secret_bytes,
253 ).map_err(|e| SecurityError::Mfa(format!("Failed to create TOTP: {}", e)))?;
254
255 let current_time = std::time::SystemTime::now()
256 .duration_since(std::time::UNIX_EPOCH)
257 .map_err(|e| SecurityError::Time(e.to_string()))?
258 .as_secs();
259
260 Ok(totp.generate(current_time))
261 }
262
263 pub fn validate_secret(&self, secret_hex: &str) -> Result<()> {
265 hex::decode(secret_hex)
267 .map_err(|_| SecurityError::InvalidInput("Invalid secret format".to_string()))?;
268
269 if secret_hex.len() != 40 {
271 return Err(SecurityError::InvalidInput("Secret must be 40 hex characters".to_string()));
272 }
273
274 Ok(())
275 }
276
277 pub fn get_remaining_time(&self, secret: &MfaSecret) -> Result<u64> {
279 let current_time = std::time::SystemTime::now()
280 .duration_since(std::time::UNIX_EPOCH)
281 .map_err(|e| SecurityError::Time(e.to_string()))?
282 .as_secs();
283
284 let remaining = secret.step - (current_time % secret.step);
285 Ok(remaining)
286 }
287
288 pub fn create_provisioning_uri(&self, secret: &MfaSecret) -> Result<String> {
290 let secret_obj = Secret::Encoded(secret.secret.clone());
291
292 let _totp = TOTP::new(
293 secret.algorithm.clone().into(),
294 secret.digits,
295 secret.skew,
296 secret.step,
297 secret_obj.to_bytes().map_err(|e| SecurityError::Mfa(format!("Failed to decode secret: {}", e)))?,
298 ).map_err(|e| SecurityError::Mfa(format!("Failed to create TOTP: {}", e)))?;
299
300 let issuer_encoded = secret.issuer.replace(" ", "%20").replace(":", "%3A");
303 let account_encoded = secret.account_name.replace(" ", "%20").replace(":", "%3A");
304
305 let uri = format!(
306 "otpauth://totp/{}:{}?secret={}&issuer={}&algorithm={}&digits={}&period={}",
307 issuer_encoded,
308 account_encoded,
309 secret.secret,
310 issuer_encoded,
311 match secret.algorithm {
312 MfaAlgorithm::SHA1 => "SHA1",
313 MfaAlgorithm::SHA256 => "SHA256",
314 MfaAlgorithm::SHA512 => "SHA512",
315 },
316 secret.digits,
317 secret.step
318 );
319
320 Ok(uri)
321 }
322}
323
/// In-memory store of single-use recovery codes.
pub struct MfaRecoveryManager {
    /// Recovery codes that have not been used yet, keyed by user id.
    codes: HashMap<String, Vec<String>>,
}
328
329impl MfaRecoveryManager {
330 pub fn new() -> Self {
331 Self {
332 codes: HashMap::new(),
333 }
334 }
335
336 pub fn generate_recovery_codes(&mut self, user_id: &str, count: usize) -> Vec<String> {
338 let codes = (0..count)
339 .map(|_| {
340 let mut rng = rand::thread_rng();
341 let code: u64 = rng.gen();
342 hex::encode(&code.to_be_bytes()[..6]) })
344 .collect::<Vec<_>>();
345
346 self.codes.insert(user_id.to_string(), codes.clone());
347 codes
348 }
349
350 pub fn verify_recovery_code(&mut self, user_id: &str, code: &str) -> bool {
352 if let Some(codes) = self.codes.get_mut(user_id) {
353 if let Some(pos) = codes.iter().position(|c| c == code) {
354 codes.remove(pos);
355 return true;
356 }
357 }
358 false
359 }
360
361 pub fn get_remaining_codes_count(&self, user_id: &str) -> usize {
363 self.codes.get(user_id).map(|codes| codes.len()).unwrap_or(0)
364 }
365
366 pub fn has_recovery_codes(&self, user_id: &str) -> bool {
368 self.codes.get(user_id).map(|codes| !codes.is_empty()).unwrap_or(false)
369 }
370}
371
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    /// Builds a service with the default configuration.
    fn create_test_service() -> MfaService {
        MfaService::new()
    }

    #[test]
    fn test_generate_secret() {
        let service = create_test_service();
        let result = service.generate_secret("test@example.com");

        assert!(result.is_ok());
        let (secret, qr_code) = result.unwrap();

        // 20 random bytes, hex-encoded.
        assert_eq!(secret.len(), 40);
        // The account name lives in the QR modules, not in the SVG text, so only
        // the markup itself can be checked here.
        assert!(qr_code.contains("<svg"));
    }

    #[test]
    fn test_secret_validation() {
        let service = create_test_service();

        // Well-formed: 40 hex characters.
        assert!(service
            .validate_secret("1234567890abcdef1234567890abcdef12345678")
            .is_ok());

        // Right length, but not hexadecimal.
        assert!(service
            .validate_secret("gggggggggggggggggggggggggggggggggggggggg")
            .is_err());

        // Hexadecimal, but too short.
        assert!(service.validate_secret("1234567890abcdef").is_err());
    }

    #[test]
    fn test_code_verification() {
        let service = create_test_service();

        let (secret, _) = service.generate_secret("test@example.com").unwrap();

        let current_code = service.get_current_code(&secret).unwrap();

        let is_valid = service.verify_code(&secret, &current_code).unwrap();
        assert!(is_valid);

        let is_valid_wrong = service.verify_code(&secret, "000000").unwrap();
        assert!(!is_valid_wrong);
    }

    #[test]
    #[ignore = "sleeps past the TOTP validity window; run with `cargo test -- --ignored`"]
    fn test_code_expiration() {
        let service = create_test_service();

        let (secret, _) = service.generate_secret("test@example.com").unwrap();

        let current_code = service.get_current_code(&secret).unwrap();

        // A code is accepted during its own step and for `skew` steps afterwards,
        // so wait out the whole window instead of a fixed 31 seconds.
        let window = service.config.step * (u64::from(service.config.skew) + 1);
        thread::sleep(Duration::from_secs(window + 1));

        let is_valid_old = service.verify_code(&secret, &current_code).unwrap();
        assert!(!is_valid_old);
    }

    #[test]
    fn test_detailed_secret() {
        let service = create_test_service();
        let secret = service.generate_secret_detailed("test@example.com").unwrap();

        assert_eq!(secret.digits, 6);
        assert_eq!(secret.step, 30);
        assert_eq!(secret.algorithm, MfaAlgorithm::SHA1);
        assert_eq!(secret.issuer, "Kotoba");
        assert_eq!(secret.account_name, "test@example.com");

        let qr_code = service.generate_qr_code(&secret).unwrap();
        assert!(qr_code.contains("<svg"));
    }

    #[test]
    fn test_backup_codes() {
        let service = create_test_service();
        let codes = service.generate_backup_codes(5);

        assert_eq!(codes.len(), 5);
        for code in &codes {
            assert_eq!(code.len(), 6);
            assert!(code.chars().all(|c| c.is_ascii_digit()));
        }
    }

    #[test]
    fn test_recovery_manager() {
        let mut manager = MfaRecoveryManager::new();

        let user_id = "user123";
        let codes = manager.generate_recovery_codes(user_id, 5);

        assert_eq!(codes.len(), 5);
        assert_eq!(manager.get_remaining_codes_count(user_id), 5);

        // A valid code is accepted once and consumed.
        let first_code = codes[0].clone();
        assert!(manager.verify_recovery_code(user_id, &first_code));
        assert_eq!(manager.get_remaining_codes_count(user_id), 4);

        // Replaying the same code fails.
        assert!(!manager.verify_recovery_code(user_id, &first_code));
        assert_eq!(manager.get_remaining_codes_count(user_id), 4);

        // Unknown codes fail and consume nothing.
        assert!(!manager.verify_recovery_code(user_id, "invalid"));
        assert_eq!(manager.get_remaining_codes_count(user_id), 4);
    }

    #[test]
    fn test_remaining_time() {
        let service = create_test_service();
        let secret = service.generate_secret_detailed("test@example.com").unwrap();

        // `step - (now % step)` always lies in `1..=step`.
        let remaining = service.get_remaining_time(&secret).unwrap();
        assert!(remaining <= 30);
        assert!(remaining >= 1);
    }

    #[test]
    fn test_provisioning_uri() {
        let service = create_test_service();
        let secret = service.generate_secret_detailed("test@example.com").unwrap();

        let uri = service.create_provisioning_uri(&secret).unwrap();
        assert!(uri.starts_with("otpauth://totp/"));
        assert!(uri.contains("Kotoba"));
        assert!(uri.contains("test@example.com"));
    }
}