use argon2::{
password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
Argon2,
};
use totp_rs::{Algorithm, Secret, TOTP};
use crate::errors::AppError;
/// How many recovery codes `TotpService::generate_recovery_codes` produces per call.
const RECOVERY_CODE_COUNT: usize = 10;
/// Random characters per recovery code, not counting the `-` separators:
/// four groups of four characters.
const RECOVERY_CODE_LENGTH: usize = 4 * 4;
/// TOTP (RFC 6238) enrollment/verification and recovery-code helpers.
///
/// Holds only non-secret configuration, so deriving `Debug` cannot leak
/// credentials into logs.
#[derive(Clone, Debug)]
pub struct TotpService {
    /// Issuer label passed to every TOTP instance and embedded in the
    /// generated `otpauth://` URIs.
    issuer: String,
    /// Number of time steps on either side of the current one that are
    /// accepted when verifying a code.
    skew: u8,
}
impl TotpService {
    /// Length of one TOTP time step, in seconds.
    const STEP_SECS: u64 = 30;
    /// Number of digits in each TOTP code.
    const DIGITS: usize = 6;
    /// Alphabet recovery codes are drawn from: uppercase ASCII letters and digits.
    const RECOVERY_CODE_CHARSET: &'static [u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
    /// Characters between `-` separators in a formatted recovery code.
    const RECOVERY_CODE_GROUP_LEN: usize = 4;

    /// Creates a service that labels its TOTPs with `issuer` and accepts codes
    /// up to one time step before or after the server's current step.
    pub fn new(issuer: impl Into<String>) -> Self {
        Self {
            issuer: issuer.into(),
            skew: 1,
        }
    }

    /// Generates a fresh random TOTP secret and returns its encoded (base32) form.
    pub fn generate_secret(&self) -> String {
        Secret::generate_secret().to_encoded().to_string()
    }

    /// Builds the `otpauth://totp/...` provisioning URI for `email`.
    ///
    /// # Errors
    /// `AppError::Internal` if `secret` cannot be decoded or the TOTP cannot be built.
    pub fn get_otpauth_uri(&self, secret: &str, email: &str) -> Result<String, AppError> {
        let totp = self.create_totp(secret, email)?;
        Ok(totp.get_url())
    }

    /// Verifies a TOTP `code` and rejects any code from an already-consumed time step.
    ///
    /// The code is accepted if it matches a time step within `self.skew` steps of
    /// the server's current step. The step the code *actually belongs to* is
    /// returned. The caller should persist it and pass it back as
    /// `last_used_time_step` on the next call. A code is only accepted when its step
    /// is strictly greater than `last_used_time_step`. This means the same code can
    /// never be used twice, even after the clock moves to the next step while the
    /// code is still inside the skew window.
    ///
    /// # Returns
    /// * `Ok(Some(step))`: the code is valid and fresh; `step` is its matched time step.
    /// * `Ok(None)`: the code is invalid, replayed or stale, or the clock appears to
    ///   have moved backward.
    ///
    /// # Errors
    /// `AppError::Internal` if the secret cannot be decoded, the TOTP cannot be
    /// built, or the system clock reads earlier than the Unix epoch.
    pub fn verify_with_replay_check(
        &self,
        secret: &str,
        code: &str,
        email: &str,
        last_used_time_step: Option<i64>,
    ) -> Result<Option<i64>, AppError> {
        // With zero skew, `check` tests exactly one step. That lets us find out
        // which step inside the skew window the code belongs to.
        let exact = self.build_totp(secret, email, 0)?;
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_err(|e| AppError::Internal(anyhow::anyhow!("System time error: {}", e)))?
            .as_secs();
        let current_step = now / Self::STEP_SECS;
        let skew = u64::from(self.skew);

        // Scan the acceptance window from oldest to newest step.
        let window = current_step.saturating_sub(skew)..=current_step.saturating_add(skew);
        let matched_step = match window
            .into_iter()
            .find(|&step| exact.check(code, step.saturating_mul(Self::STEP_SECS)))
        {
            Some(step) => step,
            None => return Ok(None),
        };

        // Both casts are lossless: u64::MAX / 30 plus at most 255 skew steps is
        // far below i64::MAX.
        let current_time_step = current_step as i64;
        let matched_time_step = matched_step as i64;

        if let Some(last_step) = last_used_time_step {
            // RFC 6238 §5.2: an OTP must not be accepted a second time. Comparing
            // the *matched* step, not the current server step, stops a code that
            // is still in the skew window from being replayed after a step boundary.
            if matched_time_step <= last_step {
                if current_time_step.saturating_add(i64::from(self.skew)) < last_step {
                    // This service never returns a step beyond `current + skew`.
                    // A stored step past that window means the clock moved backward.
                    tracing::warn!(
                        current_time_step = current_time_step,
                        last_used_time_step = last_step,
                        drift_steps = last_step - current_time_step,
                        "H-05: TOTP verification failed - system clock appears to have gone backward. \
                         Check NTP configuration and avoid clock stepping."
                    );
                } else {
                    tracing::warn!(
                        current_time_step = current_time_step,
                        matched_time_step = matched_time_step,
                        last_used_time_step = last_step,
                        "S-14: TOTP replay attack detected - code reuse within same time step"
                    );
                }
                return Ok(None);
            }
        }

        Ok(Some(matched_time_step))
    }

    /// Generates `RECOVERY_CODE_COUNT` independent recovery codes.
    pub fn generate_recovery_codes(&self) -> Vec<String> {
        (0..RECOVERY_CODE_COUNT)
            .map(|_| self.generate_recovery_code())
            .collect()
    }

    /// Hashes a recovery code with Argon2 (default parameters, random salt) and
    /// returns the PHC-format hash string.
    ///
    /// The code is normalized first (see `normalize_recovery_code`), so any
    /// formatting variant of the same code verifies against the result.
    ///
    /// # Errors
    /// `AppError::Internal` if Argon2 hashing fails.
    pub fn hash_recovery_code(code: &str) -> Result<String, AppError> {
        let normalized = Self::normalize_recovery_code(code);
        let salt = SaltString::generate(&mut OsRng);
        let argon2 = Argon2::default();
        argon2
            .hash_password(normalized.as_bytes(), &salt)
            .map(|hash| hash.to_string())
            .map_err(|e| AppError::Internal(anyhow::anyhow!("Recovery code hashing failed: {}", e)))
    }

    /// Returns `true` if `code` (after normalization) matches the PHC `hash`.
    ///
    /// An unparseable `hash` deliberately yields `false` rather than an error.
    pub fn verify_recovery_code(code: &str, hash: &str) -> bool {
        let normalized = Self::normalize_recovery_code(code);
        let parsed_hash = match PasswordHash::new(hash) {
            Ok(h) => h,
            Err(_) => return false,
        };
        Argon2::default()
            .verify_password(normalized.as_bytes(), &parsed_hash)
            .is_ok()
    }

    /// Produces the canonical form of a recovery code, shared by hashing and
    /// verification. It drops `-` separators and whitespace and uppercases
    /// letters, so `abcd-1234`, `ABCD1234` and ` ABCD-1234 ` compare equal.
    fn normalize_recovery_code(code: &str) -> String {
        code.chars()
            .filter(|c| *c != '-' && !c.is_whitespace())
            .flat_map(char::to_uppercase)
            .collect()
    }

    /// Builds the TOTP (SHA-1, `DIGITS` digits, `STEP_SECS` step) using the
    /// service's configured skew.
    fn create_totp(&self, secret: &str, email: &str) -> Result<TOTP, AppError> {
        self.build_totp(secret, email, self.skew)
    }

    /// Decodes `secret` and builds a TOTP with an explicit `skew`.
    ///
    /// # Errors
    /// `AppError::Internal` if the secret cannot be decoded or `TOTP::new` rejects
    /// the parameters.
    fn build_totp(&self, secret: &str, email: &str, skew: u8) -> Result<TOTP, AppError> {
        let secret = Secret::Encoded(secret.to_string())
            .to_bytes()
            .map_err(|e| AppError::Internal(anyhow::anyhow!("Invalid TOTP secret: {}", e)))?;
        TOTP::new(
            Algorithm::SHA1,
            Self::DIGITS,
            skew,
            Self::STEP_SECS,
            secret,
            Some(self.issuer.clone()),
            email.to_string(),
        )
        .map_err(|e| AppError::Internal(anyhow::anyhow!("Failed to create TOTP: {}", e)))
    }

    /// Generates one recovery code of `RECOVERY_CODE_LENGTH` characters drawn
    /// uniformly from the charset using the OS RNG. The characters are grouped
    /// with `-` separators, e.g. `AB12-CD34-EF56-GH78`.
    fn generate_recovery_code(&self) -> String {
        use rand::Rng;
        let mut rng = rand::rngs::OsRng;
        let charset = Self::RECOVERY_CODE_CHARSET;
        let chars: Vec<char> = (0..RECOVERY_CODE_LENGTH)
            .map(|_| charset[rng.gen_range(0..charset.len())] as char)
            .collect();
        // Chunking (instead of fixed slices) keeps the formatting correct for any
        // RECOVERY_CODE_LENGTH, not just 16.
        chars
            .chunks(Self::RECOVERY_CODE_GROUP_LEN)
            .map(|group| group.iter().collect::<String>())
            .collect::<Vec<_>>()
            .join("-")
    }
}
impl Default for TotpService {
fn default() -> Self {
Self::new("Cedros")
}
}
#[cfg(test)]
mod tests {
    use super::*;

    const ISSUER: &str = "Test App";
    const ACCOUNT: &str = "test@example.com";

    /// Fresh service configured with the shared test issuer.
    fn service() -> TotpService {
        TotpService::new(ISSUER)
    }

    /// Returns the code for the current time step. It comes from an independently
    /// built TOTP (SHA-1, 6 digits, skew 1, 30 s step), not from the service itself.
    fn current_code(secret: &str) -> String {
        let totp = TOTP::new(
            Algorithm::SHA1,
            6,
            1,
            30,
            Secret::Encoded(secret.to_string()).to_bytes().unwrap(),
            Some(ISSUER.to_string()),
            ACCOUNT.to_string(),
        )
        .unwrap();
        totp.generate_current().unwrap()
    }

    #[test]
    fn test_generate_secret() {
        let secret = service().generate_secret();
        assert!(!secret.is_empty());
        let upper_or_digit = |c: char| c.is_ascii_uppercase() || c.is_ascii_digit();
        assert!(secret.chars().all(upper_or_digit));
    }

    #[test]
    fn test_get_otpauth_uri() {
        let svc = service();
        let secret = svc.generate_secret();
        let uri = svc.get_otpauth_uri(&secret, ACCOUNT).unwrap();
        assert!(uri.starts_with("otpauth://totp/"));
        for fragment in ["test%40example.com", "issuer=Test%20App"].iter() {
            assert!(uri.contains(fragment), "URI {} lacks {}", uri, fragment);
        }
    }

    #[test]
    fn test_generate_recovery_codes() {
        let codes = service().generate_recovery_codes();
        assert_eq!(codes.len(), RECOVERY_CODE_COUNT);
        for code in codes.iter() {
            assert_eq!(code.len(), 19);
            for &pos in [4usize, 9, 14].iter() {
                assert_eq!(code.chars().nth(pos), Some('-'), "separator missing in {}", code);
            }
        }
        let distinct: std::collections::HashSet<&String> = codes.iter().collect();
        assert_eq!(distinct.len(), codes.len());
    }

    #[test]
    fn test_hash_recovery_code() {
        let code = "ABCD-1234";
        let first = TotpService::hash_recovery_code(code).unwrap();
        let second = TotpService::hash_recovery_code(code).unwrap();
        // Random salts make the two hashes differ, yet both must verify.
        assert_ne!(first, second);
        assert!(TotpService::verify_recovery_code(code, &first));
        assert!(TotpService::verify_recovery_code(code, &second));
        assert!(first.starts_with("$argon2id$"));
    }

    #[test]
    fn test_verify_recovery_code_case_insensitive() {
        let hash = TotpService::hash_recovery_code("ABCD-1234").unwrap();
        for variant in ["ABCD-1234", "abcd-1234", "ABCD1234", "abcd1234"].iter() {
            assert!(
                TotpService::verify_recovery_code(variant, &hash),
                "{} should verify",
                variant
            );
        }
        assert!(!TotpService::verify_recovery_code("WXYZ-5678", &hash));
    }

    #[test]
    fn test_verify_recovery_code_invalid_hash() {
        for bogus in ["invalid-hash", ""].iter() {
            assert!(!TotpService::verify_recovery_code("ABCD-1234", bogus));
        }
    }

    #[test]
    fn test_verify_with_replay_check_no_previous_use() {
        let svc = service();
        let secret = svc.generate_secret();
        let code = current_code(&secret);
        let step = svc
            .verify_with_replay_check(&secret, &code, ACCOUNT, None)
            .unwrap()
            .expect("fresh code should be accepted");
        assert!(step > 0);
    }

    #[test]
    fn test_verify_with_replay_check_rejects_replay() {
        let svc = service();
        let secret = svc.generate_secret();
        let code = current_code(&secret);
        let time_step = svc
            .verify_with_replay_check(&secret, &code, ACCOUNT, None)
            .unwrap()
            .expect("first use should be accepted");
        let replay = svc
            .verify_with_replay_check(&secret, &code, ACCOUNT, Some(time_step))
            .unwrap();
        assert!(replay.is_none(), "S-14: Replay attack should be rejected");
    }

    #[test]
    fn test_verify_with_replay_check_invalid_code() {
        let svc = service();
        let secret = svc.generate_secret();
        let outcome = svc
            .verify_with_replay_check(&secret, "000000", ACCOUNT, None)
            .unwrap();
        assert!(outcome.is_none());
    }

    #[test]
    fn test_verify_with_replay_check_backward_clock() {
        let svc = service();
        let secret = svc.generate_secret();
        let code = current_code(&secret);
        // A "last used" step far in the future can only mean the clock went back.
        let future_time_step = i64::MAX - 1;
        let outcome = svc
            .verify_with_replay_check(&secret, &code, ACCOUNT, Some(future_time_step))
            .unwrap();
        assert!(outcome.is_none(), "H-05: Backward clock should be rejected");
    }
}