use hmac::{Hmac, Mac};
use serde::{Deserialize, Serialize};
use crate::error::CoreError;
/// Time-step length, in seconds, used by `TotpParams::default()`.
pub const DEFAULT_PERIOD: u8 = 30;
/// Number of decimal digits in a code, used by `TotpParams::default()`.
pub const DEFAULT_DIGITS: u8 = 6;
/// Hash function used for the HMAC step of TOTP code generation.
///
/// With `rename_all = "lowercase"` the serde representation of the variants
/// is `"sha1"`, `"sha256"` and `"sha512"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum TotpAlgorithm {
/// HMAC-SHA-1; the variant produced by `TotpAlgorithm::default()`.
#[default]
Sha1,
/// HMAC-SHA-256.
Sha256,
/// HMAC-SHA-512.
Sha512,
}
impl TotpAlgorithm {
    /// Parses an algorithm name (`SHA1`, `SHA256`, `SHA512`), ignoring ASCII case.
    ///
    /// # Errors
    ///
    /// Returns `CoreError::Totp` for any other name. The message echoes the
    /// input converted to ASCII upper case.
    pub fn parse(s: &str) -> Result<Self, CoreError> {
        const KNOWN: [TotpAlgorithm; 3] = [
            TotpAlgorithm::Sha1,
            TotpAlgorithm::Sha256,
            TotpAlgorithm::Sha512,
        ];
        KNOWN
            .iter()
            .copied()
            .find(|candidate| s.eq_ignore_ascii_case(candidate.as_str()))
            .ok_or_else(|| {
                let other = s.to_ascii_uppercase();
                CoreError::Totp(format!(
                    "unknown TOTP algorithm `{other}` (expected SHA1|SHA256|SHA512)"
                ))
            })
    }

    /// Returns the canonical upper-case name of the algorithm.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Sha1 => "SHA1",
            Self::Sha256 => "SHA256",
            Self::Sha512 => "SHA512",
        }
    }
}
/// Parameters that, together with a seed, determine how a TOTP code is computed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TotpParams {
/// Hash function used for the HMAC.
pub algorithm: TotpAlgorithm,
/// Number of decimal digits in the code; `code_at` accepts 1 through 9.
pub digits: u8,
/// Time-step length in seconds; `code_at` rejects 0.
pub period: u8,
}
impl Default for TotpParams {
fn default() -> Self {
Self {
algorithm: TotpAlgorithm::default(),
digits: DEFAULT_DIGITS,
period: DEFAULT_PERIOD,
}
}
}
/// Result of parsing enrollment input: the raw seed plus the parameters to use with it.
///
/// NOTE(review): no `Debug` is derived here — possibly deliberate so the seed
/// cannot be printed by accident; confirm before adding one.
pub struct ParsedEnrollment {
/// Decoded seed bytes (the HMAC key passed to `code_at`).
pub seed: Vec<u8>,
/// Algorithm, digit count and period to use with `seed`.
pub params: TotpParams,
}
/// Computes the TOTP code for `seed` at the instant `unix_secs`.
///
/// The counter is `unix_secs / period`. It is HMAC'd with `algorithm`, then
/// truncated dynamically: the low nibble of the last MAC byte selects a
/// 4-byte big-endian window, whose top bit is cleared. The result is reduced
/// modulo `10^digits` and left-padded with zeros to `digits` characters.
///
/// # Errors
///
/// Returns `CoreError::Totp` when `period` is 0, when `digits` is outside
/// `1..=9`, or when the HMAC cannot be initialised.
pub fn code_at(
    seed: &[u8],
    unix_secs: u64,
    algorithm: TotpAlgorithm,
    digits: u8,
    period: u8,
) -> Result<String, CoreError> {
    if period == 0 {
        return Err(CoreError::Totp("period must be at least 1 second".into()));
    }
    if !matches!(digits, 1..=9) {
        return Err(CoreError::Totp(format!(
            "digits must be between 1 and 9 (got {digits})"
        )));
    }

    let step = unix_secs / u64::from(period);
    let digest = hmac_counter(seed, step, algorithm)?;

    // The low nibble of the final byte is at most 15, and every supported
    // digest is at least 20 bytes long, so the 4-byte window always fits.
    let start = usize::from(digest[digest.len() - 1] & 0x0f);
    let mut window = [0u8; 4];
    window.copy_from_slice(&digest[start..start + 4]);
    let truncated = u32::from_be_bytes(window) & 0x7fff_ffff;

    // 10^9 still fits in a u32, which is why `digits` is capped at 9.
    let code = truncated % 10u32.pow(u32::from(digits));
    Ok(format!("{code:0width$}", width = usize::from(digits)))
}
/// Returns how many seconds are left in the `period`-second window containing `unix_secs`.
///
/// The result is in `1..=period`: exactly on a window boundary the full
/// `period` is returned. A `period` of 0 yields 0 instead of dividing by zero.
pub fn seconds_remaining(unix_secs: u64, period: u64) -> u64 {
    match unix_secs.checked_rem(period) {
        Some(elapsed) => period - elapsed,
        // `checked_rem` is `None` only for a zero divisor.
        None => 0,
    }
}
/// Returns `true` when `remaining` is strictly greater than `min_validity`.
///
/// NOTE(review): presumably used to decide whether the current code still has
/// enough lifetime left to be handed out — confirm against callers.
pub fn returns_current(remaining: u64, min_validity: u64) -> bool {
    min_validity < remaining
}
fn hmac_counter(seed: &[u8], counter: u64, algorithm: TotpAlgorithm) -> Result<Vec<u8>, CoreError> {
let msg = counter.to_be_bytes();
let init_err = || CoreError::Totp("hmac init".into());
let out = match algorithm {
TotpAlgorithm::Sha1 => {
let mut mac = <Hmac<sha1::Sha1>>::new_from_slice(seed).map_err(|_| init_err())?;
mac.update(&msg);
mac.finalize().into_bytes().to_vec()
}
TotpAlgorithm::Sha256 => {
let mut mac = <Hmac<sha2::Sha256>>::new_from_slice(seed).map_err(|_| init_err())?;
mac.update(&msg);
mac.finalize().into_bytes().to_vec()
}
TotpAlgorithm::Sha512 => {
let mut mac = <Hmac<sha2::Sha512>>::new_from_slice(seed).map_err(|_| init_err())?;
mac.update(&msg);
mac.finalize().into_bytes().to_vec()
}
};
Ok(out)
}
/// Decodes a base32 (alphabet `A–Z`, `2–7`) seed into raw bytes.
///
/// Decoding is deliberately lenient about presentation: letters may be in
/// either ASCII case, and `=` padding, `-` separators and any whitespace are
/// skipped wherever they appear. Trailing bits that do not fill a whole byte
/// are dropped.
///
/// # Errors
///
/// Returns `CoreError::Totp` if any other character is present (including
/// non-ASCII characters), or if no complete byte could be decoded.
pub fn decode_base32(input: &str) -> Result<Vec<u8>, CoreError> {
    let mut bits: u32 = 0;
    let mut nbits: u32 = 0;
    let mut out = Vec::with_capacity(input.len() * 5 / 8);
    for ch in input.chars() {
        if ch == '=' || ch == '-' || ch.is_whitespace() {
            continue;
        }
        // Classify the `char` itself. Casting it to `u8` first would keep only
        // the low byte, letting e.g. 'Ł' (U+0141) masquerade as 'A' (0x41).
        let val = match ch.to_ascii_uppercase() {
            up @ 'A'..='Z' => up as u32 - 'A' as u32,
            up @ '2'..='7' => up as u32 - '2' as u32 + 26,
            _ => {
                return Err(CoreError::Totp(
                    "seed is not valid base32 (A–Z, 2–7)".into(),
                ))
            }
        };
        bits = (bits << 5) | val;
        nbits += 5;
        if nbits >= 8 {
            nbits -= 8;
            // Only the low 8 bits of the shifted accumulator are the new byte.
            out.push((bits >> nbits) as u8);
        }
    }
    if out.is_empty() {
        return Err(CoreError::Totp("empty seed".into()));
    }
    Ok(out)
}
/// Parses an `otpauth://totp/...` enrollment URI into a seed and parameters.
///
/// Only the query string is inspected; everything between the prefix and the
/// first `?` is ignored. Keys are matched ignoring ASCII case, values are
/// percent-decoded, unknown keys are skipped, and a repeated key overrides
/// the earlier value. Parameters that are absent keep their
/// `TotpParams::default()` value.
///
/// # Errors
///
/// Returns `CoreError::Totp` when the prefix is wrong, a query pair has no
/// `=`, `algorithm`/`digits`/`period` cannot be parsed, `secret` is missing
/// or not valid base32, or the resulting parameters are rejected by `code_at`.
pub fn parse_otpauth(uri: &str) -> Result<ParsedEnrollment, CoreError> {
    let after_prefix = match uri.strip_prefix("otpauth://totp/") {
        Some(tail) => tail,
        None => return Err(CoreError::Totp("not an `otpauth://totp/` URI".into())),
    };
    let query = match after_prefix.split_once('?') {
        Some((_label, q)) => q,
        None => "",
    };

    // Shared parser for the two small-integer parameters.
    let small_int = |raw: &str, complaint: &str| -> Result<u8, CoreError> {
        percent_decode(raw)
            .parse::<u8>()
            .map_err(|_| CoreError::Totp(complaint.into()))
    };

    let mut encoded_secret: Option<String> = None;
    let mut params = TotpParams::default();
    for field in query.split('&') {
        if field.is_empty() {
            continue;
        }
        let (key, raw) = match field.split_once('=') {
            Some(kv) => kv,
            None => {
                return Err(CoreError::Totp(
                    "malformed otpauth query parameter".into(),
                ))
            }
        };
        let key = key.to_ascii_lowercase();
        if key == "secret" {
            encoded_secret = Some(percent_decode(raw));
        } else if key == "algorithm" {
            params.algorithm = TotpAlgorithm::parse(&percent_decode(raw))?;
        } else if key == "digits" {
            params.digits = small_int(raw, "digits must be a small integer")?;
        } else if key == "period" {
            params.period = small_int(raw, "period must be a small integer")?;
        }
    }

    let seed = match encoded_secret {
        Some(text) => decode_base32(&text)?,
        None => return Err(CoreError::Totp("otpauth URI has no `secret`".into())),
    };
    // Generate a throwaway code so out-of-range digits/period are rejected
    // at enrollment time rather than on first use.
    code_at(&seed, 0, params.algorithm, params.digits, params.period)?;
    Ok(ParsedEnrollment { seed, params })
}
/// Decodes `%XX` escapes in `s`.
///
/// A `%` that is not followed by two hexadecimal digits is copied through
/// unchanged. The decoded bytes are converted to a `String` lossily, so
/// invalid UTF-8 sequences become U+FFFD.
fn percent_decode(s: &str) -> String {
    let mut pending = s.as_bytes();
    let mut decoded = Vec::with_capacity(pending.len());
    while let Some((&head, tail)) = pending.split_first() {
        if head == b'%' {
            if let [hi, lo, after @ ..] = tail {
                let nibbles = (
                    char::from(*hi).to_digit(16),
                    char::from(*lo).to_digit(16),
                );
                if let (Some(hi), Some(lo)) = nibbles {
                    // Two hex digits combine to at most 0xff, so the cast is lossless.
                    decoded.push((hi * 16 + lo) as u8);
                    pending = after;
                    continue;
                }
            }
        }
        decoded.push(head);
        pending = tail;
    }
    String::from_utf8_lossy(&decoded).into_owned()
}
/// Parses user-supplied enrollment input, which may be either an
/// `otpauth://` URI or a bare base32 seed.
///
/// Surrounding whitespace is trimmed first. Input starting with `otpauth://`
/// is handed to `parse_otpauth`; anything else is decoded as base32 and
/// paired with `TotpParams::default()`.
///
/// # Errors
///
/// Propagates the `CoreError` from `parse_otpauth` or `decode_base32`.
pub fn parse_seed_input(input: &str) -> Result<ParsedEnrollment, CoreError> {
    let candidate = input.trim();
    if !candidate.starts_with("otpauth://") {
        return decode_base32(candidate).map(|seed| ParsedEnrollment {
            seed,
            params: TotpParams::default(),
        });
    }
    parse_otpauth(candidate)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// 20-byte ASCII seed used with SHA-1 vectors.
    fn sha1_seed() -> Vec<u8> {
        b"12345678901234567890".to_vec()
    }

    /// 32-byte ASCII seed used with SHA-256 vectors.
    fn sha256_seed() -> Vec<u8> {
        b"12345678901234567890123456789012".to_vec()
    }

    /// 64-byte ASCII seed used with SHA-512 vectors.
    fn sha512_seed() -> Vec<u8> {
        b"1234567890123456789012345678901234567890123456789012345678901234".to_vec()
    }

    #[test]
    fn rfc6238_known_answer_vectors_sha1() {
        for (t, expected) in [
            (59u64, "94287082"),
            (1_111_111_109, "07081804"),
            (1_111_111_111, "14050471"),
            (1_234_567_890, "89005924"),
            (2_000_000_000, "69279037"),
            (20_000_000_000, "65353130"),
        ] {
            let code = code_at(&sha1_seed(), t, TotpAlgorithm::Sha1, 8, 30).unwrap();
            assert_eq!(code, expected, "SHA1 vector at t={t}");
        }
    }

    #[test]
    fn rfc6238_known_answer_vectors_sha256() {
        for (t, expected) in [
            (59u64, "46119246"),
            (1_111_111_109, "68084774"),
            (1_234_567_890, "91819424"),
            (20_000_000_000, "77737706"),
        ] {
            let code = code_at(&sha256_seed(), t, TotpAlgorithm::Sha256, 8, 30).unwrap();
            assert_eq!(code, expected, "SHA256 vector at t={t}");
        }
    }

    #[test]
    fn rfc6238_known_answer_vectors_sha512() {
        for (t, expected) in [
            (59u64, "90693936"),
            (1_111_111_109, "25091201"),
            (1_234_567_890, "93441116"),
            (20_000_000_000, "47863826"),
        ] {
            let code = code_at(&sha512_seed(), t, TotpAlgorithm::Sha512, 8, 30).unwrap();
            assert_eq!(code, expected, "SHA512 vector at t={t}");
        }
    }

    #[test]
    fn code_via_mock_clock_matches_vector() {
        use crate::clock::{Clock, MockClock};
        let clock = MockClock::at(59);
        let code = code_at(&sha1_seed(), clock.unix_secs(), TotpAlgorithm::Sha1, 8, 30).unwrap();
        assert_eq!(code, "94287082");
    }

    #[test]
    fn default_six_digits_truncates_the_vector() {
        // Six digits are the low six digits of the 8-digit vector at t=59.
        let code = code_at(&sha1_seed(), 59, TotpAlgorithm::Sha1, 6, 30).unwrap();
        assert_eq!(code, "287082");
        assert_eq!(code.len(), 6);
    }

    #[test]
    fn base32_decode_known_vectors() {
        assert_eq!(decode_base32("MFRGG===").unwrap(), b"abc");
        assert_eq!(decode_base32("mfrgg").unwrap(), b"abc");
        assert_eq!(
            decode_base32("JBSWY3DPEHPK3PXP").unwrap(),
            b"Hello!\xde\xad\xbe\xef"
        );
        assert_eq!(decode_base32("MFRG G===").unwrap(), b"abc");
        assert!(decode_base32("0189!").is_err());
        assert!(decode_base32("").is_err());
    }

    #[test]
    fn otpauth_parse_round_trip() {
        let uri = "otpauth://totp/ACME:alice@example.com?secret=GEZDGNBVGY3TQOJQGEZDGNBVGY3TQOJQ&issuer=ACME&algorithm=SHA1&digits=8&period=30";
        let parsed = parse_otpauth(uri).unwrap();
        assert_eq!(parsed.seed, sha1_seed());
        assert_eq!(parsed.params.algorithm, TotpAlgorithm::Sha1);
        assert_eq!(parsed.params.digits, 8);
        assert_eq!(parsed.params.period, 30);
        let code = code_at(
            &parsed.seed,
            59,
            parsed.params.algorithm,
            parsed.params.digits,
            parsed.params.period,
        )
        .unwrap();
        assert_eq!(code, "94287082");
    }

    #[test]
    fn otpauth_defaults_and_bare_seed() {
        let uri = "otpauth://totp/x?secret=GEZDGNBVGY3TQOJQGEZDGNBVGY3TQOJQ";
        let parsed = parse_otpauth(uri).unwrap();
        assert_eq!(parsed.params, TotpParams::default());
        let bare = parse_seed_input("GEZDGNBVGY3TQOJQGEZDGNBVGY3TQOJQ").unwrap();
        assert_eq!(bare.seed, sha1_seed());
        assert_eq!(bare.params, TotpParams::default());
        assert_eq!(bare.params.digits, 6);
    }

    #[test]
    fn parse_seed_input_routes_uri_vs_bare() {
        assert!(parse_seed_input("otpauth://totp/x?secret=MFRGG").is_ok());
        assert!(parse_seed_input("MFRGG").is_ok());
        assert!(parse_seed_input("otpauth://hotp/x?secret=MFRGG").is_err());
    }

    #[test]
    fn seconds_remaining_counts_down_within_the_window() {
        assert_eq!(seconds_remaining(59, 30), 1);
        assert_eq!(seconds_remaining(60, 30), 30);
        assert_eq!(seconds_remaining(75, 30), 15);
        assert_eq!(seconds_remaining(0, 30), 30);
        assert_eq!(seconds_remaining(30, 30), 30);
        assert_eq!(seconds_remaining(5, 0), 0);
    }

    #[test]
    fn returns_current_thresholds_on_min_validity() {
        assert!(returns_current(30, 0));
        assert!(returns_current(11, 10));
        assert!(returns_current(2, 1));
        assert!(!returns_current(10, 10));
        assert!(!returns_current(5, 10));
        assert!(!returns_current(0, 0));
        for remaining in 1..=30 {
            assert!(returns_current(remaining, 0));
        }
    }

    #[test]
    fn rejects_degenerate_params() {
        // Zero-length period.
        assert!(code_at(b"seed", 0, TotpAlgorithm::Sha1, 6, 0).is_err());
        // Zero digits.
        assert!(code_at(b"seed", 0, TotpAlgorithm::Sha1, 0, 30).is_err());
        // Ten digits is one past the supported maximum of nine.
        assert!(code_at(b"seed", 0, TotpAlgorithm::Sha1, 10, 30).is_err());
    }

    #[test]
    fn algorithm_parse_round_trips() {
        assert_eq!(TotpAlgorithm::parse("sha1").unwrap(), TotpAlgorithm::Sha1);
        assert_eq!(
            TotpAlgorithm::parse("SHA256").unwrap(),
            TotpAlgorithm::Sha256
        );
        assert_eq!(TotpAlgorithm::Sha512.as_str(), "SHA512");
        assert!(TotpAlgorithm::parse("md5").is_err());
    }
}