use hmac::{Hmac, Mac};
use sha1::Sha1;
use crate::error::Error;
/// HMAC-SHA1, the MAC used by `compute_totp` to derive codes.
type HmacSha1 = Hmac<Sha1>;
/// Length of one TOTP time step, in seconds (Unix seconds are divided by this).
const TIME_STEP: u64 = 30;
/// Number of decimal digits in every generated code.
const CODE_DIGITS: u32 = 6;
/// Number of adjacent time steps on each side of the current one that
/// `verify_code` also accepts, to tolerate clock drift between devices.
const SKEW: i64 = 1;
/// Creates a new random TOTP secret.
///
/// The secret is 20 bytes (160 bits, the length RFC 4226 recommends) drawn from
/// `rand::thread_rng()`, returned as unpadded RFC 4648 base32 text suitable for
/// `otpauth_uri`, `generate_code` and `verify_code`.
pub fn generate_secret() -> String {
    use rand::RngCore;

    let mut secret_bytes = [0u8; 20];
    let mut rng = rand::thread_rng();
    rng.fill_bytes(&mut secret_bytes);

    let alphabet = base32::Alphabet::Rfc4648 { padding: false };
    base32::encode(alphabet, &secret_bytes)
}
/// Builds an `otpauth://totp/...` provisioning URI (Key URI format) for
/// authenticator apps.
///
/// The label is `envseal:<account>`. The account is percent-encoded so that
/// characters such as spaces, `?`, `#`, `&`, `/`, `%` or `:` (the issuer/account
/// separator) cannot corrupt the URI structure. `secret_base32` is inserted as-is
/// and is expected to be unpadded base32 as produced by `generate_secret`.
pub fn otpauth_uri(secret_base32: &str, account: &str) -> String {
    let account = percent_encode_label(account);
    format!(
        "otpauth://totp/envseal:{account}?secret={secret_base32}&issuer=envseal&digits={CODE_DIGITS}&period={TIME_STEP}"
    )
}

/// Percent-encodes `value` byte-by-byte (UTF-8) for use in the otpauth label,
/// leaving RFC 3986 unreserved characters and `@` (common in e-mail accounts,
/// and valid in a URI path segment) untouched.
fn percent_encode_label(value: &str) -> String {
    use std::fmt::Write;

    let mut encoded = String::with_capacity(value.len());
    for byte in value.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~' | b'@') {
            encoded.push(char::from(byte));
        } else {
            // Writing into a String cannot fail.
            let _ = write!(encoded, "%{byte:02X}");
        }
    }
    encoded
}
/// Returns the TOTP code for the current time step.
///
/// # Errors
///
/// Fails if the system clock is before the Unix epoch, or if `compute_totp`
/// rejects `secret_base32`.
pub fn generate_code(secret_base32: &str) -> Result<String, Error> {
    compute_totp(secret_base32, current_time_step()?)
}
/// Checks a user-supplied TOTP code against `secret_base32`.
///
/// Surrounding whitespace in `user_code` is ignored. The code is accepted if it
/// matches the code for any time step from `current - SKEW` to `current + SKEW`.
/// Steps are tried oldest first, and steps that would fall outside the `u64`
/// range are skipped. Each comparison uses `constant_time_eq`.
///
/// # Errors
///
/// Fails if the system clock is before the Unix epoch, or if `compute_totp`
/// rejects `secret_base32`.
pub fn verify_code(secret_base32: &str, user_code: &str) -> Result<bool, Error> {
    let current = current_time_step()?;
    let candidate = user_code.trim().as_bytes();

    let window = (-SKEW..=SKEW).filter_map(|delta| {
        let distance = delta.unsigned_abs();
        if delta < 0 {
            current.checked_sub(distance)
        } else {
            current.checked_add(distance)
        }
    });

    for step in window {
        let expected = compute_totp(secret_base32, step)?;
        if constant_time_eq(candidate, expected.as_bytes()) {
            return Ok(true);
        }
    }
    Ok(false)
}
/// Encrypts a TOTP secret with AES-256-GCM under `master_key`.
///
/// A fresh random 12-byte nonce is generated per call. The result is the
/// lowercase hex encoding of `nonce || ciphertext`, where the ciphertext is
/// whatever `Aead::encrypt` returns. `decrypt_secret` reverses this.
///
/// # Errors
///
/// Returns `Error::CryptoFailure` if encryption fails.
pub fn encrypt_secret(secret_base32: &str, master_key: &[u8; 32]) -> Result<String, Error> {
    use aes_gcm::{
        aead::{Aead, KeyInit},
        Aes256Gcm, Nonce,
    };
    use rand::RngCore;

    let mut nonce_bytes = [0u8; 12];
    rand::thread_rng().fill_bytes(&mut nonce_bytes);

    let sealed = Aes256Gcm::new(master_key.into())
        .encrypt(Nonce::from_slice(&nonce_bytes), secret_base32.as_bytes())
        .map_err(|_| Error::CryptoFailure("failed to encrypt TOTP secret".to_string()))?;

    let blob: Vec<u8> = nonce_bytes.iter().copied().chain(sealed).collect();
    Ok(hex_encode(&blob))
}
/// Decrypts a TOTP secret produced by `encrypt_secret`.
///
/// `encrypted_hex` must be the hex encoding of `nonce (12 bytes) || ciphertext`.
/// With aes-gcm the ciphertext carries a 16-byte authentication tag, so any blob
/// shorter than 28 bytes cannot be valid and is rejected up front.
///
/// # Errors
///
/// Returns `Error::CryptoFailure` in these cases:
/// - the input is not valid hex;
/// - the input is too short;
/// - authentication/decryption fails (wrong key or tampered data);
/// - the plaintext is not UTF-8.
pub fn decrypt_secret(encrypted_hex: &str, master_key: &[u8; 32]) -> Result<String, Error> {
    use aes_gcm::{
        aead::{Aead, KeyInit},
        Aes256Gcm, Nonce,
    };
    /// Nonce length written by `encrypt_secret`.
    const NONCE_LEN: usize = 12;
    /// AES-GCM authentication tag length (aes-gcm's default for `Aes256Gcm`).
    const TAG_LEN: usize = 16;

    let combined = hex_decode(encrypted_hex)?;
    // Nonce plus tag is the smallest blob `encrypt_secret` can emit (empty plaintext).
    // The length check also guarantees `split_at` below cannot panic.
    if combined.len() < NONCE_LEN + TAG_LEN {
        return Err(Error::CryptoFailure(
            "encrypted TOTP secret too short".to_string(),
        ));
    }
    let (nonce_bytes, ciphertext) = combined.split_at(NONCE_LEN);
    let cipher = Aes256Gcm::new(master_key.into());
    let nonce = Nonce::from_slice(nonce_bytes);
    let plaintext = cipher
        .decrypt(nonce, ciphertext)
        .map_err(|_| Error::CryptoFailure("failed to decrypt TOTP secret".to_string()))?;
    String::from_utf8(plaintext)
        .map_err(|_| Error::CryptoFailure("TOTP secret is not valid UTF-8".to_string()))
}
fn compute_totp(secret_base32: &str, time_step: u64) -> Result<String, Error> {
let secret = base32::decode(base32::Alphabet::Rfc4648 { padding: false }, secret_base32)
.ok_or_else(|| Error::CryptoFailure("invalid base32 TOTP secret".to_string()))?;
let mut mac = HmacSha1::new_from_slice(&secret)
.map_err(|_| Error::CryptoFailure("invalid HMAC key length".to_string()))?;
mac.update(&time_step.to_be_bytes());
let result = mac.finalize().into_bytes();
let offset = (result[19] & 0x0f) as usize;
let code = u32::from_be_bytes([
result[offset] & 0x7f,
result[offset + 1],
result[offset + 2],
result[offset + 3],
]);
let modulus = 10u32.pow(CODE_DIGITS);
Ok(format!(
"{:0>width$}",
code % modulus,
width = CODE_DIGITS as usize
))
}
fn current_time_step() -> Result<u64, Error> {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map_err(|_| Error::CryptoFailure("system clock before UNIX epoch".to_string()))?;
Ok(now.as_secs() / TIME_STEP)
}
/// Compares two byte strings without stopping at the first differing byte.
///
/// Slices of different lengths compare unequal immediately, so only the length
/// is revealed through timing. For equal-length inputs every byte pair is XORed
/// and OR-accumulated before the result is inspected.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len() && a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
}
/// Encodes `bytes` as lowercase hexadecimal, two characters per byte, high
/// nibble first.
fn hex_encode(bytes: &[u8]) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut encoded = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        encoded.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
        encoded.push(char::from(HEX_DIGITS[usize::from(byte & 0x0f)]));
    }
    encoded
}
fn hex_decode(hex: &str) -> Result<Vec<u8>, Error> {
if hex.len() % 2 != 0 {
return Err(Error::CryptoFailure("odd-length hex string".to_string()));
}
let mut bytes = Vec::with_capacity(hex.len() / 2);
for i in (0..hex.len()).step_by(2) {
let byte = u8::from_str_radix(&hex[i..i + 2], 16)
.map_err(|_| Error::CryptoFailure("invalid hex character".to_string()))?;
bytes.push(byte);
}
Ok(bytes)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Base32 form of the RFC 4226 / RFC 6238 SHA-1 test secret "12345678901234567890".
    fn rfc_test_secret() -> String {
        base32::encode(
            base32::Alphabet::Rfc4648 { padding: false },
            b"12345678901234567890",
        )
    }

    #[test]
    fn generate_and_verify() {
        let secret = generate_secret();
        assert!(!secret.is_empty());
        let code = generate_code(&secret).unwrap();
        assert_eq!(code.len(), 6);
        assert!(code.chars().all(|c| c.is_ascii_digit()));
        assert!(verify_code(&secret, &code).unwrap());
        // Surrounding whitespace (e.g. from copy/paste) is trimmed before comparison.
        assert!(verify_code(&secret, &format!(" {code}\n")).unwrap());
        // NOTE: with a random secret, "000000" is a valid code somewhere in the
        // ±SKEW window with probability ~3 in 10^6, so this can (very rarely) fail.
        assert!(!verify_code(&secret, "000000").unwrap_or(true));
    }

    #[test]
    fn verify_rejects_malformed_codes() {
        let secret = generate_secret();
        assert!(!verify_code(&secret, "abcdef").unwrap());
        assert!(!verify_code(&secret, "12345").unwrap());
        assert!(!verify_code(&secret, "").unwrap());
    }

    #[test]
    fn encrypt_decrypt_roundtrip() {
        let secret = generate_secret();
        let master_key = [42u8; 32];
        let encrypted = encrypt_secret(&secret, &master_key).unwrap();
        let decrypted = decrypt_secret(&encrypted, &master_key).unwrap();
        assert_eq!(secret, decrypted);
    }

    #[test]
    fn decrypt_rejects_wrong_key_tampering_and_malformed_input() {
        let secret = generate_secret();
        let master_key = [42u8; 32];
        let mut encrypted = encrypt_secret(&secret, &master_key).unwrap();
        assert!(decrypt_secret(&encrypted, &[7u8; 32]).is_err());

        // Flip the final hex digit (part of the GCM tag) so authentication fails.
        let last = encrypted.pop().unwrap();
        encrypted.push(if last == '0' { '1' } else { '0' });
        assert!(decrypt_secret(&encrypted, &master_key).is_err());

        assert!(decrypt_secret("00", &master_key).is_err());
        assert!(decrypt_secret("abc", &master_key).is_err());
    }

    #[test]
    fn otpauth_uri_format() {
        let uri = otpauth_uri("JBSWY3DPEHPK3PXP", "test@example.com");
        assert!(uri.starts_with("otpauth://totp/"));
        assert!(uri.contains("secret=JBSWY3DPEHPK3PXP"));
        assert!(uri.contains("issuer=envseal"));
        assert!(uri.contains("digits=6"));
        assert!(uri.contains("period=30"));
    }

    #[test]
    fn constant_time_eq_works() {
        assert!(constant_time_eq(b"123456", b"123456"));
        assert!(!constant_time_eq(b"123456", b"654321"));
        assert!(!constant_time_eq(b"123456", b"12345"));
    }

    #[test]
    fn hex_roundtrip() {
        let bytes = [0x00u8, 0x0f, 0xa5, 0xff];
        assert_eq!(hex_encode(&bytes), "000fa5ff");
        assert_eq!(hex_decode("000fa5ff").unwrap(), bytes);
        assert_eq!(hex_decode("000FA5FF").unwrap(), bytes);
        assert!(hex_decode("").unwrap().is_empty());
        assert!(hex_decode("abc").is_err());
        assert!(hex_decode("zz").is_err());
    }

    #[test]
    fn rfc6238_test_vector() {
        let secret_b32 = rfc_test_secret();
        // RFC 4226 Appendix D: HOTP values for counters 0 and 1.
        assert_eq!(compute_totp(&secret_b32, 0).unwrap(), "755224");
        assert_eq!(compute_totp(&secret_b32, 1).unwrap(), "287082");
        // RFC 6238 Appendix B (SHA-1): last six digits of the published 8-digit codes.
        assert_eq!(compute_totp(&secret_b32, 1_111_111_109 / 30).unwrap(), "081804");
        // Exercises left zero-padding of the result.
        assert_eq!(compute_totp(&secret_b32, 1_234_567_890 / 30).unwrap(), "005924");
    }
}