use hmac::{Hmac, Mac};
use sha1::Sha1;
use crate::error::Error;
/// HMAC-SHA1 keyed MAC used by `compute_totp` to derive codes.
type HmacSha1 = Hmac<Sha1>;
/// Length of one TOTP time step in seconds (`current_time_step` divides UNIX
/// seconds by it); also advertised as `period` in the otpauth URI.
const TIME_STEP: u64 = 30;
/// Number of decimal digits in every generated code; also advertised as
/// `digits` in the otpauth URI.
const CODE_DIGITS: u32 = 6;
/// How many time steps before and after the current one `verify_code` also
/// accepts (presumably to tolerate clock drift between host and authenticator).
const SKEW: i64 = 1;
/// Generates a fresh random TOTP shared secret.
///
/// Draws 20 bytes from the operating-system RNG (`OsRng`) and returns them as
/// unpadded RFC 4648 base32, the encoding `compute_totp` decodes.
pub fn generate_secret() -> String {
    use rand::RngCore as _;

    let mut key_material = [0u8; 20];
    rand::rngs::OsRng.fill_bytes(&mut key_material);

    let alphabet = base32::Alphabet::Rfc4648 { padding: false };
    base32::encode(alphabet, &key_material)
}
/// Builds an `otpauth://totp/` key URI for enrolling a secret in an
/// authenticator app.
///
/// The label is `envseal:<account>` with `account` percent-encoded. The secret
/// is inserted verbatim; callers are assumed to pass unpadded base32 (as
/// produced by `generate_secret`), which needs no escaping. The `digits` and
/// `period` parameters come from `CODE_DIGITS` and `TIME_STEP`.
pub fn otpauth_uri(secret_base32: &str, account: &str) -> String {
    format!(
        "otpauth://totp/envseal:{encoded_account}?secret={secret_base32}&issuer=envseal&digits={CODE_DIGITS}&period={TIME_STEP}",
        encoded_account = percent_encode_path(account),
    )
}
/// Percent-encodes `s` for use inside a URI path segment.
///
/// Each byte of the UTF-8 encoding is kept as-is when it is an RFC 3986
/// "unreserved" character (ASCII letter, digit, `-`, `_`, `.`, `~`). Any other
/// byte is emitted as `%XX` with uppercase hexadecimal digits.
fn percent_encode_path(s: &str) -> String {
    const UPPER_HEX: &[u8; 16] = b"0123456789ABCDEF";

    s.bytes()
        .fold(String::with_capacity(s.len()), |mut encoded, b| {
            let unreserved =
                b.is_ascii_alphanumeric() || matches!(b, b'-' | b'_' | b'.' | b'~');
            if unreserved {
                encoded.push(char::from(b));
            } else {
                encoded.push('%');
                encoded.push(char::from(UPPER_HEX[usize::from(b >> 4)]));
                encoded.push(char::from(UPPER_HEX[usize::from(b & 0x0f)]));
            }
            encoded
        })
}
/// Computes the TOTP code for `secret_base32` at the current time step.
///
/// # Errors
/// Propagates any error from `current_time_step` (system clock before the
/// UNIX epoch) and from `compute_totp` (e.g. a secret that is not valid base32).
pub fn generate_code(secret_base32: &str) -> Result<String, Error> {
    current_time_step().and_then(|step| compute_totp(secret_base32, step))
}
/// Checks `user_code` against the TOTP codes for `secret_base32` at the
/// current time step and the `SKEW` steps on either side of it.
///
/// Leading and trailing whitespace in `user_code` is ignored. Each call first
/// asks the process-wide `rate_limit` for permission. A wrong code records a
/// failure; a correct code clears the failure history.
///
/// # Errors
/// Returns `Error::CryptoFailure` when the rate limit is exceeded, when the
/// system clock reads before the UNIX epoch, or when `compute_totp` rejects
/// the secret.
pub fn verify_code(secret_base32: &str, user_code: &str) -> Result<bool, Error> {
    // NOTE(review): `permit_attempt` and `record_failure` take the limiter lock
    // separately, so concurrent callers can all pass this check before any of
    // their failures is recorded. An atomic check-and-reserve in `rate_limit`
    // would close that gap.
    if !rate_limit::permit_attempt() {
        return Err(Error::CryptoFailure(
            "TOTP verification rate-limit exceeded \
             (5 failed attempts per 60s); slow down and retry later"
                .to_string(),
        ));
    }
    let now = current_time_step()?;
    let user_code = user_code.trim();

    let mut matched = false;
    for offset in -SKEW..=SKEW {
        // `checked_add_signed` returns None for steps that would fall below
        // zero (or overflow). Those steps are skipped rather than wrapped.
        let Some(step) = now.checked_add_signed(offset) else {
            continue;
        };
        let expected = compute_totp(secret_base32, step)?;
        // Non-short-circuiting `|=`: every window is compared, so the work done
        // does not depend on which window (if any) matched.
        matched |= constant_time_eq(user_code.as_bytes(), expected.as_bytes());
    }

    if matched {
        rate_limit::record_success();
    } else {
        rate_limit::record_failure();
    }
    Ok(matched)
}
/// Encrypts a base32 TOTP secret under `master_key` with AES-256-GCM.
///
/// Every call draws a fresh 12-byte nonce from `OsRng`. The result is
/// lowercase hex of `nonce || ciphertext`, where `ciphertext` is the output of
/// `Aead::encrypt`. Per the `aead` crate this carries the authentication tag;
/// confirm if the crate version changes. `decrypt_secret` expects this layout.
///
/// # Errors
/// Returns `Error::CryptoFailure` if the cipher reports an encryption failure.
pub fn encrypt_secret(secret_base32: &str, master_key: &[u8; 32]) -> Result<String, Error> {
    use aes_gcm::aead::{Aead, KeyInit};
    use aes_gcm::{Aes256Gcm, Nonce};
    use rand::RngCore;

    let mut nonce_bytes = [0u8; 12];
    rand::rngs::OsRng.fill_bytes(&mut nonce_bytes);

    let sealed = Aes256Gcm::new(master_key.into())
        .encrypt(Nonce::from_slice(&nonce_bytes), secret_base32.as_bytes())
        .map_err(|_| Error::CryptoFailure("failed to encrypt TOTP secret".to_string()))?;

    let envelope = [nonce_bytes.as_slice(), sealed.as_slice()].concat();
    Ok(hex_encode(&envelope))
}
/// Decrypts a value produced by `encrypt_secret` back into the base32 secret.
///
/// `encrypted_hex` is the hex encoding of a 12-byte nonce followed by the
/// AES-256-GCM ciphertext and its authentication tag.
///
/// # Errors
/// Returns `Error::CryptoFailure` in any of these cases:
/// - the input is not valid hex;
/// - it is too short to hold a nonce plus a GCM tag;
/// - authentication fails (wrong key or tampered data);
/// - the plaintext is not UTF-8.
pub fn decrypt_secret(encrypted_hex: &str, master_key: &[u8; 32]) -> Result<String, Error> {
    use aes_gcm::{
        aead::{Aead, KeyInit},
        Aes256Gcm, Nonce,
    };
    // Nonce length written by `encrypt_secret`.
    const NONCE_LEN: usize = 12;
    // AES-GCM authentication tag length for `Aes256Gcm`.
    const TAG_LEN: usize = 16;

    let combined = hex_decode(encrypted_hex)?;
    // Even an empty plaintext encrypts to NONCE_LEN + TAG_LEN bytes. Anything
    // shorter is structurally invalid, so report it precisely instead of as a
    // generic decryption failure.
    if combined.len() < NONCE_LEN + TAG_LEN {
        return Err(Error::CryptoFailure(
            "encrypted TOTP secret too short".to_string(),
        ));
    }
    let (nonce_bytes, ciphertext) = combined.split_at(NONCE_LEN);
    let cipher = Aes256Gcm::new(master_key.into());
    let nonce = Nonce::from_slice(nonce_bytes);
    let plaintext = cipher
        .decrypt(nonce, ciphertext)
        .map_err(|_| Error::CryptoFailure("failed to decrypt TOTP secret".to_string()))?;
    String::from_utf8(plaintext)
        .map_err(|_| Error::CryptoFailure("TOTP secret is not valid UTF-8".to_string()))
}
/// Derives the TOTP code for `secret_base32` at counter `time_step`: an RFC
/// 4226 HOTP value with the time step used as the counter.
///
/// The secret is decoded as unpadded RFC 4648 base32 and keys an HMAC-SHA1
/// over the big-endian counter. The digest is then dynamically truncated to a
/// 31-bit integer and reduced modulo 10^`CODE_DIGITS`. Finally it is zero-padded
/// on the left to exactly `CODE_DIGITS` characters.
///
/// # Errors
/// Returns `Error::CryptoFailure` if the secret is not valid base32 or the
/// HMAC cannot be keyed with it.
fn compute_totp(secret_base32: &str, time_step: u64) -> Result<String, Error> {
    let key = base32::decode(base32::Alphabet::Rfc4648 { padding: false }, secret_base32)
        .ok_or_else(|| Error::CryptoFailure("invalid base32 TOTP secret".to_string()))?;

    let digest = {
        let mut mac = HmacSha1::new_from_slice(&key)
            .map_err(|_| Error::CryptoFailure("invalid HMAC key length".to_string()))?;
        mac.update(&time_step.to_be_bytes());
        mac.finalize().into_bytes()
    };

    // Dynamic truncation: the low nibble of the last (20th) digest byte picks a
    // 4-byte window, and clearing its top bit yields a 31-bit value.
    let start = usize::from(digest[19] & 0x0f);
    let mut window = [0u8; 4];
    window.copy_from_slice(&digest[start..start + 4]);
    let truncated = u32::from_be_bytes(window) & 0x7fff_ffff;

    let code = truncated % 10u32.pow(CODE_DIGITS);
    let width = CODE_DIGITS as usize;
    Ok(format!("{code:0width$}"))
}
fn current_time_step() -> Result<u64, Error> {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map_err(|_| Error::CryptoFailure("system clock before UNIX epoch".to_string()))?;
Ok(now.as_secs() / TIME_STEP)
}
/// Process-wide limiter for failed TOTP verifications: once `MAX_FAILURES`
/// failures fall inside the sliding `WINDOW`, further attempts are refused.
///
/// NOTE(review): the failure log is shared by every caller in the process
/// (across all secrets), and `record_success` clears it entirely. One
/// successful verification therefore resets the budget for everyone; confirm
/// this is intended.
mod rate_limit {
    use std::sync::{Mutex, MutexGuard, PoisonError};
    use std::time::{Duration, Instant};

    /// Failures tolerated within `WINDOW` before attempts are refused.
    const MAX_FAILURES: usize = 5;
    /// Length of the sliding window over which failures are counted.
    const WINDOW: Duration = Duration::from_secs(60);

    /// Timestamps of recent failed attempts.
    static FAILURES: Mutex<Vec<Instant>> = Mutex::new(Vec::new());

    /// Locks the failure log. If a previous holder panicked, the poisoned data
    /// is recovered: a `Vec<Instant>` has no invariant that a panic could break.
    fn failures() -> MutexGuard<'static, Vec<Instant>> {
        FAILURES.lock().unwrap_or_else(PoisonError::into_inner)
    }

    /// Drops failures that are at least `WINDOW` old relative to `now`.
    ///
    /// Each entry's age is measured directly instead of comparing it against a
    /// `now - WINDOW` cutoff. That subtraction is unrepresentable when the
    /// monotonic clock's origin is less than `WINDOW` in the past (e.g. shortly
    /// after boot); substituting `now` there would discard every failure and
    /// disable the limiter.
    fn prune(log: &mut Vec<Instant>, now: Instant) {
        log.retain(|&t| now.saturating_duration_since(t) < WINDOW);
    }

    /// Returns `true` if fewer than `MAX_FAILURES` failures were recorded
    /// within the last `WINDOW`.
    pub fn permit_attempt() -> bool {
        let mut log = failures();
        prune(&mut log, Instant::now());
        log.len() < MAX_FAILURES
    }

    /// Records a failed verification at the current instant.
    pub fn record_failure() {
        let mut log = failures();
        let now = Instant::now();
        log.push(now);
        prune(&mut log, now);
    }

    /// Forgets all recorded failures (called by `verify_code` after a match).
    pub fn record_success() {
        failures().clear();
    }

    /// Test helper: clears the failure log so a test starts from a clean state.
    #[cfg(test)]
    #[allow(dead_code)] // not every test build calls this helper
    pub(super) fn reset_for_test() {
        failures().clear();
    }
}
/// Compares two byte strings without exiting early on the first mismatch.
///
/// Slices of different lengths compare unequal immediately; only the contents
/// are compared in a data-independent way. For equal lengths, the XOR of every
/// byte pair is OR-accumulated. `black_box` discourages the optimizer from
/// reintroducing an early exit.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    use std::hint::black_box;

    if a.len() == b.len() {
        let accumulated = a
            .iter()
            .zip(b)
            .fold(0u8, |acc, (x, y)| acc | black_box(x ^ y));
        black_box(accumulated) == 0
    } else {
        false
    }
}
/// Encodes `bytes` as lowercase hexadecimal, two characters per byte.
fn hex_encode(bytes: &[u8]) -> String {
    const LOWER_HEX: &[u8; 16] = b"0123456789abcdef";

    bytes
        .iter()
        .flat_map(|&byte| [byte >> 4, byte & 0x0f])
        .map(|nibble| char::from(LOWER_HEX[usize::from(nibble)]))
        .collect()
}
fn hex_decode(hex: &str) -> Result<Vec<u8>, Error> {
if hex.len() % 2 != 0 {
return Err(Error::CryptoFailure("odd-length hex string".to_string()));
}
let mut bytes = Vec::with_capacity(hex.len() / 2);
for i in (0..hex.len()).step_by(2) {
let byte = u8::from_str_radix(&hex[i..i + 2], 16)
.map_err(|_| Error::CryptoFailure("invalid hex character".to_string()))?;
bytes.push(byte);
}
Ok(bytes)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn generate_and_verify() {
        rate_limit::reset_for_test();
        let secret = generate_secret();
        assert!(!secret.is_empty());
        let code = generate_code(&secret).unwrap();
        assert_eq!(code.len(), 6);
        assert!(code.chars().all(|c| c.is_ascii_digit()));
        assert!(verify_code(&secret, &code).unwrap());

        // Pick a 6-digit code that matches no window verify_code could accept,
        // even if the step advances once before it runs. A fixed guess such as
        // "000000" would be a valid code with small but non-zero probability.
        let now = current_time_step().unwrap();
        let acceptable: Vec<String> = (now - 1..=now + 2)
            .map(|step| compute_totp(&secret, step).unwrap())
            .collect();
        let wrong = (0u32..)
            .map(|n| format!("{n:06}"))
            .find(|candidate| !acceptable.contains(candidate))
            .unwrap();
        assert!(!verify_code(&secret, &wrong).unwrap());
    }

    #[test]
    fn encrypt_decrypt_roundtrip() {
        let secret = generate_secret();
        let master_key = [42u8; 32];
        let encrypted = encrypt_secret(&secret, &master_key).unwrap();
        let decrypted = decrypt_secret(&encrypted, &master_key).unwrap();
        assert_eq!(secret, decrypted);

        // Each encryption draws a fresh nonce.
        let again = encrypt_secret(&secret, &master_key).unwrap();
        assert_ne!(encrypted, again);

        // Wrong key, tampered data, and empty input must all be rejected.
        assert!(decrypt_secret(&encrypted, &[7u8; 32]).is_err());
        let mut tampered = encrypted.clone().into_bytes();
        let last = tampered.len() - 1;
        tampered[last] = if tampered[last] == b'0' { b'1' } else { b'0' };
        let tampered = String::from_utf8(tampered).unwrap();
        assert!(decrypt_secret(&tampered, &master_key).is_err());
        assert!(decrypt_secret("", &master_key).is_err());
    }

    #[test]
    fn otpauth_uri_format() {
        let uri = otpauth_uri("JBSWY3DPEHPK3PXP", "test@example.com");
        assert_eq!(
            uri,
            "otpauth://totp/envseal:test%40example.com?secret=JBSWY3DPEHPK3PXP&issuer=envseal&digits=6&period=30"
        );
    }

    #[test]
    fn percent_encode_path_escapes_reserved_and_non_ascii() {
        assert_eq!(percent_encode_path("a-b_c.d~e"), "a-b_c.d~e");
        assert_eq!(percent_encode_path("a b/\u{e9}"), "a%20b%2F%C3%A9");
    }

    #[test]
    fn constant_time_eq_works() {
        assert!(constant_time_eq(b"123456", b"123456"));
        assert!(!constant_time_eq(b"123456", b"654321"));
        assert!(!constant_time_eq(b"123456", b"12345"));
        assert!(constant_time_eq(b"", b""));
    }

    #[test]
    fn hex_roundtrip_and_rejection() {
        let bytes = [0x00u8, 0x0f, 0x10, 0xab, 0xff];
        assert_eq!(hex_encode(&bytes), "000f10abff");
        assert_eq!(hex_decode("000f10ABff").unwrap(), bytes);
        assert!(hex_decode("abc").is_err());
        assert!(hex_decode("zz").is_err());
    }

    #[test]
    fn rfc6238_test_vector() {
        // RFC 6238 Appendix B, SHA-1 column, seed "12345678901234567890". The
        // RFC lists 8-digit codes; the 6-digit code is their last six digits.
        let secret_b32 = base32::encode(
            base32::Alphabet::Rfc4648 { padding: false },
            b"12345678901234567890",
        );
        let vectors: [(u64, &str); 6] = [
            (59, "287082"),
            (1_111_111_109, "081804"),
            (1_111_111_111, "050471"),
            (1_234_567_890, "005924"),
            (2_000_000_000, "279037"),
            (20_000_000_000, "353130"),
        ];
        for (unix_time, expected) in vectors {
            let code = compute_totp(&secret_b32, unix_time / TIME_STEP).unwrap();
            assert_eq!(code, expected, "T = {unix_time}");
        }
    }
}