use hmac::{Hmac, Mac};
use sha1::Sha1;
use sha2::{Sha256, Sha512};
use zeroize::Zeroizing;
use crate::errors::{SafeError, SafeResult};
type HmacSha1 = Hmac<Sha1>;
type HmacSha256 = Hmac<Sha256>;
type HmacSha512 = Hmac<Sha512>;
/// Hash function used for the HMAC step when generating a TOTP code.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TotpAlgorithm {
/// HMAC-SHA1; the algorithm used when none is specified.
Sha1,
/// HMAC-SHA256.
Sha256,
/// HMAC-SHA512.
Sha512,
}
impl TotpAlgorithm {
fn parse(value: &str) -> SafeResult<Self> {
match value.to_ascii_uppercase().as_str() {
"SHA1" | "SHA-1" => Ok(Self::Sha1),
"SHA256" | "SHA-256" => Ok(Self::Sha256),
"SHA512" | "SHA-512" => Ok(Self::Sha512),
other => Err(SafeError::InvalidVault {
reason: format!("unsupported TOTP algorithm '{other}'"),
}),
}
}
fn as_uri_str(self) -> &'static str {
match self {
Self::Sha1 => "SHA1",
Self::Sha256 => "SHA256",
Self::Sha512 => "SHA512",
}
}
}
/// Parameters needed to generate TOTP codes for one account.
struct TotpConfig {
    /// Normalised base32 secret; wiped from memory when dropped.
    secret: Zeroizing<String>,
    /// Hash function used for the HMAC step.
    algorithm: TotpAlgorithm,
    /// Number of decimal digits in a generated code.
    digits: u32,
    /// Length of one time step, in seconds.
    period: u64,
}

// `Debug` is written by hand instead of derived: `Zeroizing<String>` formats
// its inner value, so a derived impl would write the secret into any log line
// or panic message that prints a `TotpConfig`.
impl std::fmt::Debug for TotpConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TotpConfig")
            .field("secret", &"<redacted>")
            .field("algorithm", &self.algorithm)
            .field("digits", &self.digits)
            .field("period", &self.period)
            .finish()
    }
}
impl TotpConfig {
    /// Builds a configuration for `secret` with the defaults used for bare
    /// secrets: SHA-1, 6 digits and a 30-second period.
    ///
    /// `secret` is stored as given; it is not validated here.
    fn default_for_secret(secret: String) -> Self {
        let secret = Zeroizing::new(secret);
        Self {
            secret,
            algorithm: TotpAlgorithm::Sha1,
            digits: 6,
            period: 30,
        }
    }
}
pub fn extract_base32(input: &str) -> SafeResult<Zeroizing<String>> {
Ok(parse_totp_config(input)?.secret)
}
/// Builds an `otpauth://totp/` provisioning URI for `input`.
///
/// `input` is either a bare base32 secret or an existing `otpauth://` URI.
/// Parameters parsed from `input` are kept unless the matching override
/// (`algorithm`, `digits`, `period`) is `Some`.
///
/// `label` is treated as raw text and percent-encoded before being placed in
/// the URI path, so characters such as spaces, `?`, `#`, `/` or `&` cannot
/// corrupt the URI. ASCII letters, digits and `-._~:@` are emitted unchanged.
///
/// The returned `String` contains the secret and is not zeroized on drop.
///
/// # Errors
///
/// Returns an error when `input` cannot be parsed or when an override is not
/// a supported algorithm name, digit count or period.
pub fn provisioning_uri(
    label: &str,
    input: &str,
    algorithm: Option<&str>,
    digits: Option<u32>,
    period: Option<u64>,
) -> SafeResult<String> {
    let mut config = parse_totp_config(input)?;
    if let Some(algorithm) = algorithm {
        config.algorithm = TotpAlgorithm::parse(algorithm)?;
    }
    if let Some(digits) = digits {
        config.digits = validate_digits(digits)?;
    }
    if let Some(period) = period {
        config.period = validate_period(period)?;
    }
    Ok(format!(
        "otpauth://totp/{}?secret={}&algorithm={}&digits={}&period={}",
        percent_encode_label(label),
        config.secret.as_str(),
        config.algorithm.as_uri_str(),
        config.digits,
        config.period
    ))
}

/// Percent-encodes every byte of `label` except ASCII letters, digits and
/// `-._~:@`, using upper-case hex digits.
fn percent_encode_label(label: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";
    let mut encoded = String::with_capacity(label.len());
    for &byte in label.as_bytes() {
        match byte {
            b'A'..=b'Z'
            | b'a'..=b'z'
            | b'0'..=b'9'
            | b'-'
            | b'.'
            | b'_'
            | b'~'
            | b':'
            | b'@' => encoded.push(char::from(byte)),
            _ => {
                encoded.push('%');
                encoded.push(char::from(HEX[usize::from(byte >> 4)]));
                encoded.push(char::from(HEX[usize::from(byte & 0x0f)]));
            }
        }
    }
    encoded
}
/// Generates the TOTP code for `input` at the current system time.
///
/// `input` is either a bare base32 secret or an `otpauth://` URI.
///
/// # Errors
///
/// Returns the same errors as [`generate_code_at`].
pub fn generate_code(input: &str) -> SafeResult<String> {
    let now = unix_timestamp();
    generate_code_at(input, now)
}
pub fn generate_code_at(input: &str, timestamp: u64) -> SafeResult<String> {
let config = parse_totp_config(input)?;
let key_bytes = decode_base32(&config.secret)?;
let counter = timestamp / config.period;
let code = hotp(&key_bytes, counter, config.digits, config.algorithm)?;
Ok(format_code(code, config.digits))
}
/// Returns how many seconds remain in the current 30-second window.
///
/// The result is in `1..=30`: exactly on a window boundary it is `30`.
pub fn seconds_remaining() -> u64 {
    const WINDOW: u64 = 30;
    let elapsed_in_window = unix_timestamp() % WINDOW;
    WINDOW - elapsed_in_window
}
/// Returns how many seconds remain in the current window for `input`, using
/// the period configured in `input` and the current system time.
///
/// # Errors
///
/// Returns the same errors as [`seconds_remaining_for_at`].
pub fn seconds_remaining_for(input: &str) -> SafeResult<u64> {
    let now = unix_timestamp();
    seconds_remaining_for_at(input, now)
}
pub fn seconds_remaining_for_at(input: &str, timestamp: u64) -> SafeResult<u64> {
let config = parse_totp_config(input)?;
Ok(config.period - (timestamp % config.period))
}
/// Returns the current time as whole seconds since the Unix epoch.
///
/// Falls back to `0` when the system clock reports a time before the epoch.
fn unix_timestamp() -> u64 {
    let now = std::time::SystemTime::now();
    match now.duration_since(std::time::UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
/// Decodes an RFC 4648 base32 string into bytes.
///
/// The unpadded alphabet is tried first; the padded one is tried only if that
/// fails.
///
/// # Errors
///
/// Returns [`SafeError::InvalidVault`] when neither form decodes.
fn decode_base32(s: &str) -> SafeResult<Vec<u8>> {
    let decoded = [false, true]
        .iter()
        .find_map(|&padding| base32::decode(base32::Alphabet::Rfc4648 { padding }, s));
    match decoded {
        Some(bytes) => Ok(bytes),
        None => Err(SafeError::InvalidVault {
            reason: "invalid TOTP base32 secret".into(),
        }),
    }
}
fn parse_totp_config(input: &str) -> SafeResult<TotpConfig> {
if input.starts_with("otpauth://") {
parse_otpauth_uri(input)
} else {
let normalised = normalise_base32(input);
decode_base32(&normalised)?;
Ok(TotpConfig::default_for_secret(normalised))
}
}
/// Parses the query string of an `otpauth://` URI into a [`TotpConfig`].
///
/// Parameter names are matched ignoring ASCII case. Recognised parameters are
/// `secret`, `algorithm`, `digits` and `period`; others are ignored, as are
/// pairs without `=`. Missing `algorithm`, `digits` and `period` default to
/// SHA-1, 6 and 30. If `secret` is repeated the first occurrence wins; for the
/// other parameters the last occurrence wins.
///
/// # Errors
///
/// Returns [`SafeError::InvalidVault`] when the URI has no query string, a
/// value has invalid percent-encoding or is not UTF-8, the secret is missing
/// or not valid base32, or a recognised parameter has an unsupported value.
fn parse_otpauth_uri(input: &str) -> SafeResult<TotpConfig> {
    let (_, query) = input
        .split_once('?')
        .ok_or_else(|| SafeError::InvalidVault {
            reason: "otpauth:// URI has no query string".into(),
        })?;
    let mut secret = None;
    let mut algorithm = TotpAlgorithm::Sha1;
    let mut digits = 6;
    let mut period = 30;
    for pair in query.split('&') {
        let Some((key, value)) = pair.split_once('=') else {
            continue;
        };
        // Any decoded value may be the secret, so every one is wiped on drop.
        let value = Zeroizing::new(decode_query_component(value)?);
        match key.to_ascii_lowercase().as_str() {
            "secret" if secret.is_none() => {
                let normalised = Zeroizing::new(normalise_base32(&value));
                // Decoding is only a validity check; wipe the key bytes at once.
                drop(Zeroizing::new(decode_base32(&normalised)?));
                secret = Some(normalised);
            }
            "algorithm" => {
                algorithm = TotpAlgorithm::parse(&value)?;
            }
            "digits" => {
                digits = parse_digits(&value)?;
            }
            "period" => {
                period = parse_period(&value)?;
            }
            _ => {}
        }
    }
    let secret = secret.ok_or_else(|| SafeError::InvalidVault {
        reason: "otpauth:// URI is missing the 'secret' parameter".into(),
    })?;
    Ok(TotpConfig {
        secret,
        algorithm,
        digits,
        period,
    })
}
/// Normalises a user-supplied base32 secret.
///
/// Drops all whitespace and `-` characters and upper-cases ASCII letters.
/// Other characters are kept unchanged.
fn normalise_base32(raw: &str) -> String {
    let mut cleaned = String::with_capacity(raw.len());
    for ch in raw.chars() {
        if ch == '-' || ch.is_whitespace() {
            continue;
        }
        cleaned.push(ch.to_ascii_uppercase());
    }
    cleaned
}
/// Decodes one URL query value.
///
/// `%XX` escapes become the byte they encode, `+` becomes a space, and every
/// other byte is copied through unchanged.
///
/// # Errors
///
/// Returns [`SafeError::InvalidVault`] when a `%` is not followed by two hex
/// digits or when the decoded bytes are not valid UTF-8.
fn decode_query_component(input: &str) -> SafeResult<String> {
    let mut remaining = input.as_bytes();
    let mut decoded = Vec::with_capacity(remaining.len());
    while let Some((&first, tail)) = remaining.split_first() {
        match (first, tail) {
            (b'%', [hi, lo, after @ ..]) => {
                let high = from_hex(*hi)?;
                let low = from_hex(*lo)?;
                decoded.push((high << 4) | low);
                remaining = after;
            }
            // A '%' with fewer than two bytes after it is a truncated escape.
            (b'%', _) => {
                return Err(SafeError::InvalidVault {
                    reason: "invalid percent encoding in otpauth:// URI".into(),
                });
            }
            (b'+', _) => {
                decoded.push(b' ');
                remaining = tail;
            }
            (other, _) => {
                decoded.push(other);
                remaining = tail;
            }
        }
    }
    match String::from_utf8(decoded) {
        Ok(text) => Ok(text),
        Err(e) => Err(SafeError::InvalidVault {
            reason: format!("otpauth:// URI parameter is not UTF-8: {e}"),
        }),
    }
}
/// Converts one ASCII hex digit (`0-9`, `a-f`, `A-F`) to its value `0..=15`.
///
/// # Errors
///
/// Returns [`SafeError::InvalidVault`] for any other byte.
fn from_hex(byte: u8) -> SafeResult<u8> {
    match char::from(byte).to_digit(16) {
        // `to_digit(16)` yields at most 15, so the cast cannot truncate.
        Some(nibble) => Ok(nibble as u8),
        None => Err(SafeError::InvalidVault {
            reason: "invalid percent encoding in otpauth:// URI".into(),
        }),
    }
}
/// Parses the `digits` parameter and checks it is in the supported range.
///
/// # Errors
///
/// Returns [`SafeError::InvalidVault`] when `value` is not a `u32` or is
/// rejected by `validate_digits`.
fn parse_digits(value: &str) -> SafeResult<u32> {
    match value.parse::<u32>() {
        Ok(digits) => validate_digits(digits),
        Err(_) => Err(SafeError::InvalidVault {
            reason: format!("invalid TOTP digits '{value}'"),
        }),
    }
}
/// Returns `digits` unchanged when it is between 1 and 10 inclusive.
///
/// # Errors
///
/// Returns [`SafeError::InvalidVault`] for any other value.
fn validate_digits(digits: u32) -> SafeResult<u32> {
    match digits {
        1..=10 => Ok(digits),
        _ => Err(SafeError::InvalidVault {
            reason: "TOTP digits must be between 1 and 10".into(),
        }),
    }
}
/// Parses the `period` parameter and checks it is non-zero.
///
/// # Errors
///
/// Returns [`SafeError::InvalidVault`] when `value` is not a `u64` or is
/// rejected by `validate_period`.
fn parse_period(value: &str) -> SafeResult<u64> {
    match value.parse::<u64>() {
        Ok(period) => validate_period(period),
        Err(_) => Err(SafeError::InvalidVault {
            reason: format!("invalid TOTP period '{value}'"),
        }),
    }
}
/// Returns `period` unchanged when it is at least one second.
///
/// # Errors
///
/// Returns [`SafeError::InvalidVault`] when `period` is zero.
fn validate_period(period: u64) -> SafeResult<u64> {
    if period == 0 {
        return Err(SafeError::InvalidVault {
            reason: "TOTP period must be at least 1 second".into(),
        });
    }
    Ok(period)
}
/// Computes the numeric one-time code for `counter`.
///
/// The big-endian counter is HMAC'd with `key`. The low four bits of the last
/// digest byte select a four-byte window, which is read as a big-endian
/// integer with its top bit cleared and reduced modulo `10^digits`.
///
/// # Errors
///
/// Returns an error when the HMAC cannot be keyed with `key`.
fn hotp(key: &[u8], counter: u64, digits: u32, algorithm: TotpAlgorithm) -> SafeResult<u64> {
    let digest = hmac_digest(key, &counter.to_be_bytes(), algorithm)?;
    let start = usize::from(digest[digest.len() - 1] & 0x0f);
    let window = &digest[start..start + 4];
    let word = window
        .iter()
        .fold(0u32, |acc, &byte| (acc << 8) | u32::from(byte));
    // Clearing the top bit keeps the value within 31 bits.
    let truncated = u64::from(word & 0x7fff_ffff);
    Ok(truncated % 10u64.pow(digits))
}
fn hmac_digest(key: &[u8], counter_bytes: &[u8], algorithm: TotpAlgorithm) -> SafeResult<Vec<u8>> {
match algorithm {
TotpAlgorithm::Sha1 => {
let mut mac = HmacSha1::new_from_slice(key).map_err(hmac_key_error)?;
mac.update(counter_bytes);
Ok(mac.finalize().into_bytes().to_vec())
}
TotpAlgorithm::Sha256 => {
let mut mac = HmacSha256::new_from_slice(key).map_err(hmac_key_error)?;
mac.update(counter_bytes);
Ok(mac.finalize().into_bytes().to_vec())
}
TotpAlgorithm::Sha512 => {
let mut mac = HmacSha512::new_from_slice(key).map_err(hmac_key_error)?;
mac.update(counter_bytes);
Ok(mac.finalize().into_bytes().to_vec())
}
}
}
/// Converts an HMAC key-length error into [`SafeError::InvalidVault`].
fn hmac_key_error(e: hmac::digest::InvalidLength) -> SafeError {
    let reason = format!("HMAC key error: {e}");
    SafeError::InvalidVault { reason }
}
/// Renders `code` in decimal, left-padded with zeros to at least `digits`
/// characters. A longer value is returned in full, never truncated.
fn format_code(code: u64, digits: u32) -> String {
    let rendered = code.to_string();
    let missing = (digits as usize).saturating_sub(rendered.len());
    let mut padded = "0".repeat(missing);
    padded.push_str(&rendered);
    padded
}
#[cfg(test)]
mod tests {
    //! Unit tests for secret extraction, provisioning URIs and code generation.
    //!
    //! Code-generation tests use fixed timestamps wherever the expected value
    //! depends on time, so they cannot fail on a window boundary.

    use super::*;

    /// A well-formed base32 secret used by tests that do not need RFC vectors.
    const KNOWN_B32: &str = "JBSWY3DPEHPK3PXP";

    #[test]
    fn extract_base32_plain_returns_normalised() {
        let result = extract_base32(KNOWN_B32).unwrap();
        assert_eq!(*result, KNOWN_B32);
    }

    #[test]
    fn extract_base32_lowercase_is_normalised_to_upper() {
        let result = extract_base32(&KNOWN_B32.to_lowercase()).unwrap();
        assert_eq!(*result, KNOWN_B32);
    }

    #[test]
    fn extract_base32_strips_spaces_and_hyphens() {
        let spaced = "JBSWY 3DP-EHPK 3PXP";
        let result = extract_base32(spaced).unwrap();
        assert_eq!(*result, KNOWN_B32);
    }

    #[test]
    fn extract_base32_parses_otpauth_uri() {
        let uri = format!("otpauth://totp/Alice?secret={KNOWN_B32}&issuer=Example");
        let result = extract_base32(&uri).unwrap();
        assert_eq!(*result, KNOWN_B32);
    }

    #[test]
    fn extract_base32_otpauth_uri_secret_case_insensitive_param_name() {
        let uri = format!("otpauth://totp/Alice?SECRET={KNOWN_B32}");
        let result = extract_base32(&uri).unwrap();
        assert_eq!(*result, KNOWN_B32);
    }

    #[test]
    fn extract_base32_otpauth_uri_missing_query_string_errors() {
        let result = extract_base32("otpauth://totp/Alice");
        assert!(matches!(result, Err(SafeError::InvalidVault { .. })));
    }

    #[test]
    fn extract_base32_otpauth_uri_missing_secret_param_errors() {
        let result = extract_base32("otpauth://totp/Alice?issuer=Example");
        assert!(matches!(result, Err(SafeError::InvalidVault { .. })));
    }

    #[test]
    fn extract_base32_invalid_base32_chars_errors() {
        let result = extract_base32("!!!NOT-VALID-BASE32!!!");
        assert!(matches!(result, Err(SafeError::InvalidVault { .. })));
    }

    #[test]
    fn provisioning_uri_preserves_otpauth_parameters() {
        let seed = encode_base32(b"12345678901234567890123456789012");
        let input =
            format!("otpauth://totp/Alice?secret={seed}&algorithm=SHA256&digits=8&period=60");
        let uri = provisioning_uri("GITHUB_2FA", &input, None, None, None).unwrap();
        assert_eq!(
            uri,
            format!("otpauth://totp/GITHUB_2FA?secret={seed}&algorithm=SHA256&digits=8&period=60")
        );
    }

    #[test]
    fn provisioning_uri_overrides_otpauth_parameters() {
        let seed = encode_base32(b"12345678901234567890123456789012");
        let input =
            format!("otpauth://totp/Alice?secret={seed}&algorithm=SHA256&digits=8&period=60");
        let uri = provisioning_uri("GITHUB_2FA", &input, Some("SHA1"), Some(6), Some(30)).unwrap();
        assert_eq!(
            uri,
            format!("otpauth://totp/GITHUB_2FA?secret={seed}&algorithm=SHA1&digits=6&period=30")
        );
    }

    #[test]
    fn generate_code_returns_six_digit_string() {
        let code = generate_code(KNOWN_B32).unwrap();
        assert_eq!(
            code.len(),
            6,
            "TOTP code must be exactly 6 chars, got {code:?}"
        );
        assert!(
            code.chars().all(|c| c.is_ascii_digit()),
            "TOTP code must be all digits, got {code:?}"
        );
    }

    #[test]
    fn generate_code_is_stable_within_same_30s_window() {
        // Fixed timestamps at both ends of the window 30..=59. Reading the
        // wall clock twice could straddle a boundary and fail spuriously.
        let a = generate_code_at(KNOWN_B32, 30).unwrap();
        let b = generate_code_at(KNOWN_B32, 59).unwrap();
        assert_eq!(a, b, "codes differed within the same 30s window");
    }

    #[test]
    fn generate_code_rejects_invalid_base32() {
        let result = generate_code("!!!INVALID!!!");
        assert!(matches!(result, Err(SafeError::InvalidVault { .. })));
    }

    #[test]
    fn generate_code_zero_pads_to_six_digits() {
        // Pin the padding itself; a wall-clock code only rarely has a
        // leading zero, so it cannot exercise this on its own.
        assert_eq!(format_code(42, 6), "000042");
        assert_eq!(format_code(0, 6), "000000");
        assert_eq!(format_code(999_999, 6), "999999");

        let code = generate_code(KNOWN_B32).unwrap();
        assert_eq!(code.len(), 6, "code {code:?} must be 6 chars");
        let n: u32 = code.parse().expect("should parse as integer");
        assert!(n < 1_000_000, "code {n} must be < 1_000_000");
    }

    #[test]
    fn generate_code_at_matches_rfc6238_vectors() {
        let seed_sha1 = encode_base32(b"12345678901234567890");
        let seed_sha256 = encode_base32(b"12345678901234567890123456789012");
        let seed_sha512 =
            encode_base32(b"1234567890123456789012345678901234567890123456789012345678901234");
        // (timestamp, SHA1 code, SHA256 code, SHA512 code)
        let vectors = [
            (59, "94287082", "46119246", "90693936"),
            (1_111_111_109, "07081804", "68084774", "25091201"),
            (1_111_111_111, "14050471", "67062674", "99943326"),
            (1_234_567_890, "89005924", "91819424", "93441116"),
            (2_000_000_000, "69279037", "90698825", "38618901"),
            (20_000_000_000, "65353130", "77737706", "47863826"),
        ];
        for (timestamp, sha1, sha256, sha512) in vectors {
            assert_eq!(
                generate_code_at(&format_uri(&seed_sha1, "SHA1", 8, 30), timestamp).unwrap(),
                sha1
            );
            assert_eq!(
                generate_code_at(&format_uri(&seed_sha256, "SHA256", 8, 30), timestamp).unwrap(),
                sha256
            );
            assert_eq!(
                generate_code_at(&format_uri(&seed_sha512, "SHA512", 8, 30), timestamp).unwrap(),
                sha512
            );
        }
    }

    #[test]
    fn generate_code_at_honors_digits_parameter() {
        let seed = encode_base32(b"12345678901234567890");
        let uri = format_uri(&seed, "SHA1", 8, 30);
        assert_eq!(generate_code_at(&uri, 59).unwrap(), "94287082");
    }

    #[test]
    fn generate_code_at_honors_period_parameter() {
        let seed = encode_base32(b"12345678901234567890");
        let uri = format_uri(&seed, "SHA1", 8, 60);
        assert_eq!(generate_code_at(&uri, 59).unwrap(), "84755224");
        assert_eq!(generate_code_at(&uri, 60).unwrap(), "94287082");
    }

    #[test]
    fn generate_code_at_parses_lowercase_otpauth_parameters() {
        let seed = encode_base32(b"12345678901234567890123456789012");
        let uri = format!("otpauth://totp/Alice?secret={seed}&algorithm=sha256&digits=8&period=30");
        assert_eq!(generate_code_at(&uri, 59).unwrap(), "46119246");
    }

    #[test]
    fn generate_code_at_rejects_invalid_otpauth_parameters() {
        let seed = encode_base32(b"12345678901234567890");
        for uri in [
            format!("otpauth://totp/Alice?secret={seed}&algorithm=MD5"),
            format!("otpauth://totp/Alice?secret={seed}&digits=0"),
            format!("otpauth://totp/Alice?secret={seed}&digits=11"),
            format!("otpauth://totp/Alice?secret={seed}&period=0"),
            format!("otpauth://totp/Alice?secret={seed}&period=abc"),
        ] {
            let result = generate_code_at(&uri, 59);
            assert!(
                matches!(result, Err(SafeError::InvalidVault { .. })),
                "expected invalid parameter error for {uri:?}, got {result:?}"
            );
        }
    }

    #[test]
    fn seconds_remaining_for_honors_period_parameter() {
        let seed = encode_base32(b"12345678901234567890");
        let uri = format_uri(&seed, "SHA1", 6, 60);
        assert_eq!(seconds_remaining_for_at(&uri, 59).unwrap(), 1);
        assert_eq!(seconds_remaining_for_at(&uri, 60).unwrap(), 60);
    }

    #[test]
    fn seconds_remaining_is_in_range_1_to_30() {
        let secs = seconds_remaining();
        assert!(
            (1..=30).contains(&secs),
            "seconds_remaining() returned {secs}, expected 1..=30"
        );
    }

    /// Encodes `bytes` as unpadded RFC 4648 base32.
    fn encode_base32(bytes: &[u8]) -> String {
        base32::encode(base32::Alphabet::Rfc4648 { padding: false }, bytes)
    }

    /// Builds an `otpauth://totp/Alice` URI carrying the given parameters.
    fn format_uri(secret: &str, algorithm: &str, digits: u32, period: u64) -> String {
        format!(
            "otpauth://totp/Alice?secret={secret}&algorithm={algorithm}&digits={digits}&period={period}"
        )
    }
}