use std::time::{Duration, Instant};
use rand::RngCore;
use subtle::ConstantTimeEq;
use zeroize::Zeroize;
use crate::keys::{decrypt_bytes, encrypt_bytes, CryptoError, EncryptedKey};
/// Length in bytes of a freshly generated TOTP secret (256 bits).
const SECRET_LEN: usize = 32;
/// Consecutive invalid codes tolerated before a `RateLimiter` locks.
const MAX_FAILURES: u32 = 3;
/// How long a `RateLimiter` lockout lasts (five minutes).
const LOCKOUT_DURATION: Duration = Duration::from_secs(5 * 60);
pub struct TotpSecret {
secret: Vec<u8>,
}
impl TotpSecret {
pub fn from_bytes(bytes: Vec<u8>) -> Self {
Self { secret: bytes }
}
pub fn as_bytes(&self) -> &[u8] {
&self.secret
}
}
impl Drop for TotpSecret {
fn drop(&mut self) {
self.secret.zeroize();
}
}
pub fn generate_secret() -> TotpSecret {
let mut secret = vec![0u8; SECRET_LEN];
rand::rng().fill_bytes(&mut secret);
TotpSecret { secret }
}
/// Renders the `otpauth://` provisioning URI for `secret` as a text QR code.
///
/// Without the `qr` feature, this returns the URI text itself.
pub fn qr_code_unicode(secret: &TotpSecret, issuer: &str, account: &str) -> String {
    let uri = build_totp_uri(secret, issuer, account);
    qr_code_unicode_raw(&uri)
}
/// Encodes `payload` as a QR code drawn with the `Dense1x2` Unicode renderer.
///
/// If the payload cannot be encoded, the error is logged and a short
/// placeholder message is returned, so callers always get printable text.
#[cfg(feature = "qr")]
pub fn qr_code_unicode_raw(payload: &str) -> String {
    use qrcode::render::unicode::Dense1x2;
    use qrcode::QrCode;

    QrCode::new(payload.as_bytes())
        .map(|code| {
            // Dark and light are deliberately swapped. NOTE(review): presumably this
            // keeps the code scannable on dark-background terminals; confirm.
            code.render::<Dense1x2>()
                .dark_color(Dense1x2::Light)
                .light_color(Dense1x2::Dark)
                .build()
        })
        .unwrap_or_else(|e| {
            tracing::warn!(error = %e, "QR code generation failed");
            format!("(QR code unavailable: {e})")
        })
}
/// Renders the provisioning URI for `secret` as a PNG QR code, returned as a
/// `data:image/png;base64,...` URL.
///
/// Without the `qr` feature, this returns the URI text itself, not a data URL.
pub fn qr_code_png_base64(secret: &TotpSecret, issuer: &str, account: &str) -> String {
    let uri = build_totp_uri(secret, issuer, account);
    qr_code_png_base64_raw(&uri)
}
/// Encodes `payload` as a grayscale PNG QR code and returns it as a base64 data URL.
///
/// If either QR generation or PNG encoding fails, the error is logged and a
/// parenthesised placeholder message is returned instead of a data URL.
#[cfg(feature = "qr")]
pub fn qr_code_png_base64_raw(payload: &str) -> String {
    use base64::Engine;
    use image::codecs::png::PngEncoder;
    use image::{ExtendedColorType, ImageEncoder, Luma};
    use qrcode::QrCode;

    // Pixels per module, times the 21-module width of a version-1 symbol,
    // gives the minimum rendered side length.
    const MODULE_PX: u32 = 8;
    const V1_MODULES: u32 = 21;

    let code = match QrCode::new(payload.as_bytes()) {
        Ok(code) => code,
        Err(e) => {
            tracing::warn!(error = %e, "QR code generation failed");
            return format!("(QR code unavailable: {e})");
        }
    };

    let side = MODULE_PX * V1_MODULES;
    let img = code
        .render::<Luma<u8>>()
        .quiet_zone(true)
        .min_dimensions(side, side)
        .build();

    let mut png_bytes: Vec<u8> = Vec::new();
    let written = PngEncoder::new(&mut png_bytes).write_image(
        img.as_raw(),
        img.width(),
        img.height(),
        ExtendedColorType::L8,
    );
    if let Err(e) = written {
        tracing::warn!(error = %e, "PNG encoding failed");
        return format!("(QR PNG unavailable: {e})");
    }

    let b64 = base64::engine::general_purpose::STANDARD.encode(&png_bytes);
    format!("data:image/png;base64,{b64}")
}
/// Fallback when the `qr` feature is disabled: returns `payload` unchanged.
#[cfg(not(feature = "qr"))]
pub fn qr_code_unicode_raw(payload: &str) -> String {
    String::from(payload)
}
/// Fallback when the `qr` feature is disabled: returns `payload` unchanged.
///
/// NOTE(review): the result is the raw payload, not a `data:image/png` URL.
/// Callers that embed it as an image source should account for that.
#[cfg(not(feature = "qr"))]
pub fn qr_code_png_base64_raw(payload: &str) -> String {
    payload.to_owned()
}
/// Checks a user-supplied TOTP `code` against `secret` at the current time.
///
/// The code is accepted if it matches one of three time steps: the one
/// containing "now", the one before, or the one after. That allows one step of
/// clock drift in each direction. All three candidates are always generated and
/// compared with `subtle::ConstantTimeEq`, and the results are combined without
/// short-circuiting.
///
/// Returns `false` if a TOTP instance cannot be built from `secret`.
///
/// NOTE(review): nothing here stops the same code from being accepted twice
/// within its validity window. Replay protection, if required, must be
/// provided by the caller.
pub fn verify_code(secret: &TotpSecret, code: &str) -> bool {
    let Ok(totp) = build_totp(secret) else {
        return false;
    };
    // A clock set before the Unix epoch degrades to 0 rather than panicking.
    let now = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();
    // Read the step from the TOTP instance instead of repeating the constant.
    // Saturate so a near-zero `now` cannot wrap to a huge timestamp; the old
    // signed-offset cast did wrap in that case.
    let step = totp.step;
    let candidates = [now, now.saturating_sub(step), now.saturating_add(step)];
    let code_bytes = code.as_bytes();
    // Use bitwise `|`, not `||`, so every candidate is compared even after a match.
    candidates.iter().fold(false, |matched, &time| {
        let expected = totp.generate(time);
        matched | bool::from(code_bytes.ct_eq(expected.as_bytes()))
    })
}
/// Encrypts the raw secret bytes under `passphrase` using `encrypt_bytes`.
///
/// # Errors
///
/// Returns any `CryptoError` produced by `encrypt_bytes`.
pub fn encrypt_secret(secret: &TotpSecret, passphrase: &str) -> Result<EncryptedKey, CryptoError> {
    let plaintext = &secret.secret;
    encrypt_bytes(plaintext, passphrase)
}
pub fn decrypt_secret(
encrypted: &EncryptedKey,
passphrase: &str,
) -> Result<TotpSecret, CryptoError> {
let bytes = decrypt_bytes(encrypted, passphrase)?;
Ok(TotpSecret::from_bytes(bytes))
}
/// Tracks failed verification attempts and enforces a temporary lockout.
///
/// After `MAX_FAILURES` consecutive invalid attempts, the limiter locks for
/// `LOCKOUT_DURATION`. While it is locked, every call to
/// [`RateLimiter::check_and_record`] is rejected, even one reporting a valid
/// code. A fresh failure count starts after a successful attempt, or on the
/// first attempt after a lockout has expired.
#[derive(Debug)]
pub struct RateLimiter {
    /// Consecutive invalid attempts since the last reset.
    failures: u32,
    /// End of the current lockout, if one has been triggered.
    locked_until: Option<Instant>,
}

impl RateLimiter {
    /// Creates a limiter with no recorded failures and no lockout.
    pub fn new() -> Self {
        Self {
            failures: 0,
            locked_until: None,
        }
    }

    /// Returns `true` while a lockout is in effect.
    pub fn is_locked(&self) -> bool {
        matches!(self.locked_until, Some(until) if Instant::now() < until)
    }

    /// Records the outcome of one verification attempt.
    ///
    /// # Errors
    ///
    /// - [`RateLimitError::LockedOut`] in two cases:
    ///   - A lockout is already active. The attempt is not counted and `valid`
    ///     is ignored.
    ///   - This failure reaches the limit and starts a lockout.
    /// - [`RateLimitError::InvalidCode`] for a failure that still leaves attempts.
    pub fn check_and_record(&mut self, valid: bool) -> Result<(), RateLimitError> {
        // Sample the clock once so every decision in this call uses the same instant.
        let now = Instant::now();
        if let Some(until) = self.locked_until {
            if now < until {
                return Err(RateLimitError::LockedOut {
                    remaining_secs: Self::secs_until(until, now),
                });
            }
            // The lockout has expired, so start again with a clean slate.
            self.locked_until = None;
            self.failures = 0;
        }
        if valid {
            self.failures = 0;
            return Ok(());
        }
        self.failures += 1;
        if self.failures >= MAX_FAILURES {
            self.locked_until = Some(now + LOCKOUT_DURATION);
            Err(RateLimitError::LockedOut {
                remaining_secs: LOCKOUT_DURATION.as_secs(),
            })
        } else {
            Err(RateLimitError::InvalidCode {
                attempts_remaining: MAX_FAILURES - self.failures,
            })
        }
    }

    /// Returns how many more invalid attempts are allowed before a lockout.
    ///
    /// This is `0` while locked. Once an earlier lockout has expired, it is the
    /// full `MAX_FAILURES`, because the next attempt starts a fresh count. The
    /// previous version kept returning 0 until the next `check_and_record` call.
    pub fn attempts_remaining(&self) -> u32 {
        match self.locked_until {
            Some(until) if Instant::now() >= until => MAX_FAILURES,
            _ => MAX_FAILURES.saturating_sub(self.failures),
        }
    }

    /// Returns the whole seconds from `now` until `until`, rounded up.
    ///
    /// Rounding up means a lockout that is still active never reports 0 seconds.
    fn secs_until(until: Instant, now: Instant) -> u64 {
        let left = until.saturating_duration_since(now);
        left.as_secs() + u64::from(left.subsec_nanos() > 0)
    }
}

impl Default for RateLimiter {
    fn default() -> Self {
        Self::new()
    }
}
/// Why [`RateLimiter::check_and_record`] rejected an attempt.
#[derive(Debug, thiserror::Error)]
pub enum RateLimitError {
/// The code was invalid, but the failure limit has not been reached yet.
#[error("invalid code ({attempts_remaining} attempts remaining)")]
InvalidCode { attempts_remaining: u32 },
/// A lockout is active, either one that already existed or one this failure
/// just triggered. `remaining_secs` is the time left on it, in seconds.
#[error("locked out for {remaining_secs} seconds")]
LockedOut { remaining_secs: u64 },
}
/// Returns the TOTP code for the current time step.
///
/// If the system clock reads earlier than the Unix epoch, time 0 is used.
///
/// # Errors
///
/// If the TOTP instance cannot be built, returns the error rendered as a `String`.
pub fn current_code(secret: &TotpSecret) -> Result<String, String> {
    let unix_secs = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_secs());
    build_totp(secret)
        .map(|totp| totp.generate(unix_secs))
        .map_err(|err| err.to_string())
}
pub fn build_totp_uri(secret: &TotpSecret, issuer: &str, account: &str) -> String {
use totp_rs::Secret;
let encoded = Secret::Raw(secret.secret.clone()).to_encoded().to_string();
format!(
"otpauth://totp/{}:{}?secret={}&issuer={}&algorithm=SHA1&digits=6&period=30",
issuer, account, encoded, issuer
)
}
/// Builds the `totp_rs::TOTP` instance used to generate and check codes.
///
/// The parameters are fixed: SHA-1, 6 digits, a skew of 1 step and a
/// 30-second step. These match the values advertised by `build_totp_uri`.
///
/// # Errors
///
/// Returns whatever `TOTP::new` rejects. Per the totp_rs documentation, that
/// includes an invalid digit count and a secret that is too short.
fn build_totp(secret: &TotpSecret) -> Result<totp_rs::TOTP, totp_rs::TotpUrlError> {
    use totp_rs::{Algorithm, TOTP};
    // The secret is already raw key material, so pass it to `TOTP::new` directly.
    // The previous `Secret::Raw(..).to_bytes()` round-trip was an identity
    // conversion. It only added a second heap copy of the key, which
    // `TotpSecret`'s zeroize-on-drop never covers, plus an error branch that
    // could not fire.
    TOTP::new(Algorithm::SHA1, 6, 1, 30, secret.secret.clone())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the current time in seconds since the Unix epoch, for driving TOTP generation.
    fn unix_now() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    /// Records `n` invalid attempts on `rl` and discards the results.
    fn record_failures(rl: &mut RateLimiter, n: usize) {
        for _ in 0..n {
            let _ = rl.check_and_record(false);
        }
    }

    #[test]
    fn generate_secret_produces_correct_length() {
        assert_eq!(generate_secret().as_bytes().len(), SECRET_LEN);
    }

    #[cfg(feature = "qr")]
    #[test]
    fn qr_code_contains_unicode() {
        let qr = qr_code_unicode(&generate_secret(), "Koi", "test@example.com");
        assert!(!qr.is_empty());
        assert!(qr.contains('\n'));
    }

    #[cfg(not(feature = "qr"))]
    #[test]
    fn qr_code_falls_back_to_uri_without_feature() {
        let out = qr_code_unicode(&generate_secret(), "Koi", "test@example.com");
        assert!(out.starts_with("otpauth://"));
    }

    #[test]
    fn verify_valid_code() {
        let secret = generate_secret();
        let code = build_totp(&secret).unwrap().generate(unix_now());
        assert!(verify_code(&secret, &code));
    }

    #[test]
    fn verify_invalid_code() {
        let secret = generate_secret();
        // Pick a code guaranteed to differ from the current one.
        let invalid = match current_code(&secret).unwrap().as_str() {
            "000000" => "111111",
            _ => "000000",
        };
        assert!(!verify_code(&secret, invalid));
    }

    #[test]
    fn rate_limiter_allows_initial_attempts() {
        let rl = RateLimiter::new();
        assert!(!rl.is_locked());
        assert_eq!(rl.attempts_remaining(), 3);
    }

    #[test]
    fn rate_limiter_tracks_failures() {
        let mut rl = RateLimiter::new();
        for expected_remaining in [2, 1] {
            assert!(rl.check_and_record(false).is_err());
            assert_eq!(rl.attempts_remaining(), expected_remaining);
        }
    }

    #[test]
    fn rate_limiter_locks_after_max_failures() {
        let mut rl = RateLimiter::new();
        record_failures(&mut rl, 2);
        let outcome = rl.check_and_record(false);
        assert!(outcome.is_err());
        assert!(rl.is_locked());
        assert!(matches!(outcome, Err(RateLimitError::LockedOut { .. })));
    }

    #[test]
    fn rate_limiter_resets_on_success() {
        let mut rl = RateLimiter::new();
        record_failures(&mut rl, 2);
        assert!(rl.check_and_record(true).is_ok());
        assert!(!rl.is_locked());
        assert_eq!(rl.attempts_remaining(), 3);
    }

    #[test]
    fn rate_limiter_rejects_during_lockout() {
        let mut rl = RateLimiter::new();
        record_failures(&mut rl, 3);
        assert!(rl.is_locked());
        assert!(rl.check_and_record(true).is_err());
    }

    #[test]
    fn encrypt_decrypt_secret_round_trip() {
        let secret = generate_secret();
        let expected = secret.as_bytes().to_vec();
        let encrypted = encrypt_secret(&secret, "test-pass").unwrap();
        let decrypted = decrypt_secret(&encrypted, "test-pass").unwrap();
        assert_eq!(decrypted.as_bytes(), expected.as_slice());
    }

    #[test]
    fn totp_uri_format() {
        let uri = build_totp_uri(&generate_secret(), "Koi Certmesh", "admin@stone-01");
        assert!(uri.starts_with("otpauth://totp/Koi Certmesh:admin@stone-01?secret="));
        for param in ["algorithm=SHA1", "digits=6", "period=30"] {
            assert!(uri.contains(param));
        }
    }
}