use hmac::{Hmac, Mac};
use sha1::Sha1;
use zeroize::Zeroizing;
use crate::errors::{SafeError, SafeResult};
/// HMAC keyed with SHA-1; `hotp` uses it to authenticate the counter value.
type HmacSha1 = Hmac<Sha1>;
/// Extracts and normalises a base32 TOTP secret from user input.
///
/// `input` is either a bare base32 string or an `otpauth://` URI. For a URI, the value of
/// the first query parameter named `secret` (name compared ASCII case-insensitively) is used.
/// Whitespace and `-` separators are removed and ASCII letters are upper-cased. The result is
/// validated by decoding it as base32 before it is returned.
///
/// The secret is never copied into a plain `String`: the raw value is borrowed from `input`,
/// the normalised value is built directly inside a pre-sized [`Zeroizing`] buffer, and the
/// bytes decoded for validation are wiped when dropped.
///
/// # Errors
///
/// Returns [`SafeError::InvalidVault`] when:
/// - an `otpauth://` URI has no `?` query string,
/// - an `otpauth://` URI has no `secret` parameter,
/// - the normalised secret is empty, or
/// - the normalised secret is not valid base32.
pub fn extract_base32(input: &str) -> SafeResult<Zeroizing<String>> {
    // Borrow the secret straight out of `input`; an owned temporary would be freed
    // without being wiped.
    let raw: &str = if input.starts_with("otpauth://") {
        let query_start = input.find('?').ok_or_else(|| SafeError::InvalidVault {
            reason: "otpauth:// URI has no query string".into(),
        })?;
        let query = &input[query_start + 1..];
        query
            .split('&')
            .find_map(|pair| {
                let (k, v) = pair.split_once('=')?;
                if k.eq_ignore_ascii_case("secret") {
                    Some(v)
                } else {
                    None
                }
            })
            .ok_or_else(|| SafeError::InvalidVault {
                reason: "otpauth:// URI is missing the 'secret' parameter".into(),
            })?
    } else {
        input
    };

    // Filtering only removes characters and ASCII upper-casing keeps each character's UTF-8
    // length, so `raw.len()` bytes is an upper bound. Pre-sizing means the buffer never
    // reallocates, which would otherwise leave stale, unwiped copies of the secret behind.
    let mut normalised = Zeroizing::new(String::with_capacity(raw.len()));
    for c in raw.chars().filter(|c| !c.is_whitespace() && *c != '-') {
        normalised.push(c.to_ascii_uppercase());
    }

    // An empty secret would otherwise be accepted and later used as an empty HMAC key.
    if normalised.is_empty() {
        return Err(SafeError::InvalidVault {
            reason: "TOTP secret is empty".into(),
        });
    }

    // Decode purely to validate. The decoded bytes are the raw key, so wipe them as well.
    drop(Zeroizing::new(decode_base32(&normalised)?));

    Ok(normalised)
}
/// Generates the TOTP code for the current time step as a zero-padded six-character string.
///
/// `base32_secret` is decoded with `decode_base32`, and the counter comes from
/// `current_counter`.
///
/// # Errors
///
/// Propagates the error from `decode_base32` when the secret is not valid base32, and any
/// error returned by `hotp`.
pub fn generate_code(base32_secret: &str) -> SafeResult<String> {
    let secret = decode_base32(base32_secret)?;
    let value = hotp(&secret, current_counter(), 6)?;
    // `hotp` yields a number below 10^6; left-pad with zeros to a fixed width of six.
    Ok(format!("{value:06}"))
}
/// Returns the number of seconds left in the current 30-second window.
///
/// The value is always in `1..=30`: it is 30 at the exact start of a window and 1 during
/// the window's final second.
pub fn seconds_remaining() -> u64 {
    const PERIOD: u64 = 30;
    let into_window = unix_timestamp() % PERIOD;
    PERIOD - into_window
}
/// Returns the whole seconds elapsed since the Unix epoch according to the system clock.
///
/// Falls back to `0` when the system clock reports a time before the epoch.
fn unix_timestamp() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
/// Returns the index of the current 30-second time step: the Unix timestamp divided by 30.
fn current_counter() -> u64 {
    const TIME_STEP_SECS: u64 = 30;
    unix_timestamp() / TIME_STEP_SECS
}
/// Decodes `s` as RFC 4648 base32, accepting either the unpadded or the padded form.
///
/// The unpadded alphabet is tried first; the padded one is tried only if that fails.
///
/// # Errors
///
/// Returns [`SafeError::InvalidVault`] when neither form decodes.
fn decode_base32(s: &str) -> SafeResult<Vec<u8>> {
    for padding in [false, true] {
        if let Some(bytes) = base32::decode(base32::Alphabet::RFC4648 { padding }, s) {
            return Ok(bytes);
        }
    }
    Err(SafeError::InvalidVault {
        reason: "invalid TOTP base32 secret".into(),
    })
}
/// Computes an HOTP value: HMAC-SHA1 over the big-endian counter, followed by dynamic
/// truncation and reduction to `digits` decimal digits.
///
/// * `key` - raw (already base32-decoded) shared secret.
/// * `counter` - moving factor; fed to the MAC as 8 big-endian bytes.
/// * `digits` - number of decimal digits to keep. If `10^digits` does not fit in a `u32`
///   (`digits >= 10`), the truncated 31-bit value is returned unreduced, because it is
///   already smaller than that modulus.
///
/// The returned number is not zero-padded; callers format it to the desired width.
///
/// # Errors
///
/// Returns [`SafeError::InvalidVault`] if the HMAC instance rejects `key`.
fn hotp(key: &[u8], counter: u64, digits: u32) -> SafeResult<u32> {
    let mut mac = HmacSha1::new_from_slice(key).map_err(|e| SafeError::InvalidVault {
        reason: format!("HMAC key error: {e}"),
    })?;
    mac.update(&counter.to_be_bytes());
    let digest = mac.finalize().into_bytes();
    let digest = digest.as_slice();

    // Dynamic truncation: the low nibble of the last digest byte (index 19 of the 20-byte
    // SHA-1 output) picks a 4-byte window. The offset is at most 15, so `offset + 3 <= 18`
    // stays in bounds.
    let offset = usize::from(digest[19] & 0x0f);
    // Mask the top bit so the value is a non-negative 31-bit integer.
    let code = u32::from_be_bytes([
        digest[offset] & 0x7f,
        digest[offset + 1],
        digest[offset + 2],
        digest[offset + 3],
    ]);

    // `10u32.pow(digits)` overflows for `digits >= 10` (panic in debug builds, a wrapped and
    // therefore wrong modulus in release builds). When the power does not fit in a `u32`, the
    // modulus exceeds every possible `code`, so the reduction is the identity.
    Ok(match 10u32.checked_pow(digits) {
        Some(modulus) => code % modulus,
        None => code,
    })
}
/// Unit tests for secret extraction/normalisation and time-based code generation.
#[cfg(test)]
mod tests {
    use super::*;

    /// A well-formed, already normalised base32 secret shared by the tests.
    const KNOWN_B32: &str = "JBSWY3DPEHPK3PXP";

    /// A bare, already normalised secret is returned unchanged.
    #[test]
    fn extract_base32_plain_returns_normalised() {
        let result = extract_base32(KNOWN_B32).unwrap();
        assert_eq!(*result, KNOWN_B32);
    }

    /// Lower-case input is upper-cased.
    #[test]
    fn extract_base32_lowercase_is_normalised_to_upper() {
        let result = extract_base32(&KNOWN_B32.to_lowercase()).unwrap();
        assert_eq!(*result, KNOWN_B32);
    }

    /// Grouping separators (spaces and hyphens) are removed.
    #[test]
    fn extract_base32_strips_spaces_and_hyphens() {
        let spaced = "JBSWY 3DP-EHPK 3PXP";
        let result = extract_base32(spaced).unwrap();
        assert_eq!(*result, KNOWN_B32);
    }

    /// The `secret` query parameter is pulled out of an `otpauth://` URI.
    #[test]
    fn extract_base32_parses_otpauth_uri() {
        let uri = format!("otpauth://totp/Alice?secret={KNOWN_B32}&issuer=Example");
        let result = extract_base32(&uri).unwrap();
        assert_eq!(*result, KNOWN_B32);
    }

    /// The parameter name is matched case-insensitively.
    #[test]
    fn extract_base32_otpauth_uri_secret_case_insensitive_param_name() {
        let uri = format!("otpauth://totp/Alice?SECRET={KNOWN_B32}");
        let result = extract_base32(&uri).unwrap();
        assert_eq!(*result, KNOWN_B32);
    }

    /// A URI without any query string is rejected.
    #[test]
    fn extract_base32_otpauth_uri_missing_query_string_errors() {
        let result = extract_base32("otpauth://totp/Alice");
        assert!(matches!(result, Err(SafeError::InvalidVault { .. })));
    }

    /// A URI whose query string lacks `secret` is rejected.
    #[test]
    fn extract_base32_otpauth_uri_missing_secret_param_errors() {
        let result = extract_base32("otpauth://totp/Alice?issuer=Example");
        assert!(matches!(result, Err(SafeError::InvalidVault { .. })));
    }

    /// Characters outside the base32 alphabet are rejected.
    #[test]
    fn extract_base32_invalid_base32_chars_errors() {
        let result = extract_base32("!!!NOT-VALID-BASE32!!!");
        assert!(matches!(result, Err(SafeError::InvalidVault { .. })));
    }

    /// Generated codes are exactly six ASCII digits.
    #[test]
    fn generate_code_returns_six_digit_string() {
        let code = generate_code(KNOWN_B32).unwrap();
        assert_eq!(
            code.len(),
            6,
            "TOTP code must be exactly 6 chars, got {code:?}"
        );
        assert!(
            code.chars().all(|c| c.is_ascii_digit()),
            "TOTP code must be all digits, got {code:?}"
        );
    }

    /// Calls made within the same 30-second window yield the same code.
    #[test]
    fn generate_code_is_stable_within_same_30s_window() {
        // Comparing just two calls is flaky: a window boundary can fall between them and the
        // codes then legitimately differ. Three rapid calls can straddle at most one
        // boundary, so at least one adjacent pair shares a window and must agree.
        let a = generate_code(KNOWN_B32).unwrap();
        let b = generate_code(KNOWN_B32).unwrap();
        let c = generate_code(KNOWN_B32).unwrap();
        assert!(
            a == b || b == c,
            "codes differed between rapid calls in the same window: {a:?}, {b:?}, {c:?}"
        );
    }

    /// An undecodable secret is reported as an error rather than producing a code.
    #[test]
    fn generate_code_rejects_invalid_base32() {
        let result = generate_code("!!!INVALID!!!");
        assert!(matches!(result, Err(SafeError::InvalidVault { .. })));
    }

    /// The code is a fixed-width, zero-padded rendering of a number below one million.
    #[test]
    fn generate_code_zero_pads_to_six_digits() {
        for _ in 0..3 {
            let code = generate_code(KNOWN_B32).unwrap();
            // The width check is what actually exercises the padding.
            assert_eq!(code.len(), 6, "code {code:?} must be padded to 6 chars");
            let n: u32 = code.parse().expect("should parse as integer");
            assert!(n < 1_000_000, "code {n} must be < 1_000_000");
        }
    }

    /// The countdown never reports 0 and never exceeds the window length.
    #[test]
    fn seconds_remaining_is_in_range_1_to_30() {
        let secs = seconds_remaining();
        assert!(
            (1..=30).contains(&secs),
            "seconds_remaining() returned {secs}, expected 1..=30"
        );
    }
}