1use hmac::{Hmac, Mac};
23use serde::{Deserialize, Serialize};
24
25use crate::error::CoreError;
26
/// Default TOTP time-step length in seconds, used by `TotpParams::default()` and therefore
/// whenever an enrollment does not specify a `period`.
pub const DEFAULT_PERIOD: u8 = 30;
/// Default number of decimal digits in a generated code, used by `TotpParams::default()` and
/// therefore whenever an enrollment does not specify `digits`.
pub const DEFAULT_DIGITS: u8 = 6;
31
/// HMAC hash function used to derive TOTP codes.
///
/// Note the two spellings: serde (de)serializes variants in lowercase (`"sha1"`, `"sha256"`,
/// `"sha512"`) because of `rename_all = "lowercase"`, whereas `TotpAlgorithm::as_str` returns
/// the uppercase otpauth spelling and `TotpAlgorithm::parse` accepts any ASCII case.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum TotpAlgorithm {
    /// HMAC-SHA1; the default when nothing else is specified.
    #[default]
    Sha1,
    /// HMAC-SHA256.
    Sha256,
    /// HMAC-SHA512.
    Sha512,
}
46
47impl TotpAlgorithm {
48 pub fn parse(s: &str) -> Result<Self, CoreError> {
51 match s.to_ascii_uppercase().as_str() {
52 "SHA1" => Ok(TotpAlgorithm::Sha1),
53 "SHA256" => Ok(TotpAlgorithm::Sha256),
54 "SHA512" => Ok(TotpAlgorithm::Sha512),
55 other => Err(CoreError::Totp(format!(
56 "unknown TOTP algorithm `{other}` (expected SHA1|SHA256|SHA512)"
57 ))),
58 }
59 }
60
61 pub fn as_str(&self) -> &'static str {
63 match self {
64 TotpAlgorithm::Sha1 => "SHA1",
65 TotpAlgorithm::Sha256 => "SHA256",
66 TotpAlgorithm::Sha512 => "SHA512",
67 }
68 }
69}
70
/// Code-generation parameters for one TOTP enrollment.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TotpParams {
    /// HMAC hash function used to compute codes.
    pub algorithm: TotpAlgorithm,
    /// Number of decimal digits per code; `code_at` only accepts 1..=9.
    pub digits: u8,
    /// Time-step length in seconds; `code_at` rejects 0. Being a `u8`, it is at most 255.
    pub period: u8,
}
82
83impl Default for TotpParams {
84 fn default() -> Self {
85 Self {
86 algorithm: TotpAlgorithm::default(),
87 digits: DEFAULT_DIGITS,
88 period: DEFAULT_PERIOD,
89 }
90 }
91}
92
/// Result of parsing user-supplied enrollment input: either an `otpauth://totp/` URI or a bare
/// base32 seed.
// NOTE(review): there is no `Debug` derive here, unlike the neighbouring types. `seed` is secret
// key material, so keeping it out of `{:?}`/log output may be intentional — confirm before
// adding derives.
pub struct ParsedEnrollment {
    /// Base32-decoded HMAC key bytes.
    pub seed: Vec<u8>,
    /// Parameters taken from the URI query, or `TotpParams::default()` for a bare seed.
    pub params: TotpParams,
}
102
103pub fn code_at(
113 seed: &[u8],
114 unix_secs: u64,
115 algorithm: TotpAlgorithm,
116 digits: u8,
117 period: u8,
118) -> Result<String, CoreError> {
119 if period == 0 {
120 return Err(CoreError::Totp("period must be at least 1 second".into()));
121 }
122 if !(1..=9).contains(&digits) {
123 return Err(CoreError::Totp(format!(
124 "digits must be between 1 and 9 (got {digits})"
125 )));
126 }
127 let counter = unix_secs / period as u64;
128 let mac = hmac_counter(seed, counter, algorithm)?;
129 let offset = (mac[mac.len() - 1] & 0x0f) as usize;
132 let bin = ((mac[offset] as u32 & 0x7f) << 24)
133 | ((mac[offset + 1] as u32) << 16)
134 | ((mac[offset + 2] as u32) << 8)
135 | (mac[offset + 3] as u32);
136 let modulus = 10u32.pow(digits as u32);
137 let code = bin % modulus;
138 Ok(format!("{code:0width$}", width = digits as usize))
139}
140
/// Returns how many seconds remain in the current `period`-second TOTP window.
///
/// The result lies in `1..=period`. At an exact window boundary the full `period` is reported,
/// because a fresh window has just begun. A `period` of 0 yields 0 rather than dividing by zero.
pub fn seconds_remaining(unix_secs: u64, period: u64) -> u64 {
    // `checked_rem` is `None` exactly when `period == 0`.
    unix_secs
        .checked_rem(period)
        .map_or(0, |elapsed| period - elapsed)
}
155
/// Decides whether a code with `remaining` seconds of validity left is still worth returning.
///
/// The threshold is strict: the code must outlive `min_validity` by at least one second. A code
/// with exactly `min_validity` seconds left returns `false`, and so does one with 0 seconds left.
pub fn returns_current(remaining: u64, min_validity: u64) -> bool {
    min_validity < remaining
}
171
172fn hmac_counter(seed: &[u8], counter: u64, algorithm: TotpAlgorithm) -> Result<Vec<u8>, CoreError> {
178 let msg = counter.to_be_bytes();
179 let init_err = || CoreError::Totp("hmac init".into());
180 let out = match algorithm {
181 TotpAlgorithm::Sha1 => {
182 let mut mac = <Hmac<sha1::Sha1>>::new_from_slice(seed).map_err(|_| init_err())?;
183 mac.update(&msg);
184 mac.finalize().into_bytes().to_vec()
185 }
186 TotpAlgorithm::Sha256 => {
187 let mut mac = <Hmac<sha2::Sha256>>::new_from_slice(seed).map_err(|_| init_err())?;
188 mac.update(&msg);
189 mac.finalize().into_bytes().to_vec()
190 }
191 TotpAlgorithm::Sha512 => {
192 let mut mac = <Hmac<sha2::Sha512>>::new_from_slice(seed).map_err(|_| init_err())?;
193 mac.update(&msg);
194 mac.finalize().into_bytes().to_vec()
195 }
196 };
197 Ok(out)
198}
199
200pub fn decode_base32(input: &str) -> Result<Vec<u8>, CoreError> {
205 const ALPHABET: &[u8; 32] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567";
206 let mut bits: u32 = 0;
207 let mut nbits: u32 = 0;
208 let mut out = Vec::new();
209 for ch in input.chars() {
210 if ch == '=' || ch.is_whitespace() || ch == '-' {
211 continue;
212 }
213 let up = ch.to_ascii_uppercase() as u8;
214 let val = ALPHABET
215 .iter()
216 .position(|&c| c == up)
217 .ok_or_else(|| CoreError::Totp("seed is not valid base32 (A–Z, 2–7)".into()))?
218 as u32;
219 bits = (bits << 5) | val;
220 nbits += 5;
221 if nbits >= 8 {
222 nbits -= 8;
223 out.push((bits >> nbits) as u8);
224 }
225 }
226 if out.is_empty() {
227 return Err(CoreError::Totp("empty seed".into()));
228 }
229 Ok(out)
230}
231
232pub fn parse_otpauth(uri: &str) -> Result<ParsedEnrollment, CoreError> {
239 let rest = uri
240 .strip_prefix("otpauth://totp/")
241 .ok_or_else(|| CoreError::Totp("not an `otpauth://totp/` URI".into()))?;
242 let query = rest.split_once('?').map(|(_, q)| q).unwrap_or("");
243 let mut secret: Option<String> = None;
244 let mut params = TotpParams::default();
245 for pair in query.split('&').filter(|p| !p.is_empty()) {
246 let (k, v) = pair
247 .split_once('=')
248 .ok_or_else(|| CoreError::Totp("malformed otpauth query parameter".into()))?;
249 match k.to_ascii_lowercase().as_str() {
250 "secret" => secret = Some(percent_decode(v)),
251 "algorithm" => params.algorithm = TotpAlgorithm::parse(&percent_decode(v))?,
252 "digits" => {
253 params.digits = percent_decode(v)
254 .parse::<u8>()
255 .map_err(|_| CoreError::Totp("digits must be a small integer".into()))?
256 }
257 "period" => {
258 params.period = percent_decode(v)
259 .parse::<u8>()
260 .map_err(|_| CoreError::Totp("period must be a small integer".into()))?
261 }
262 _ => {}
264 }
265 }
266 let secret = secret.ok_or_else(|| CoreError::Totp("otpauth URI has no `secret`".into()))?;
267 let seed = decode_base32(&secret)?;
268 code_at(&seed, 0, params.algorithm, params.digits, params.period)?;
271 Ok(ParsedEnrollment { seed, params })
272}
273
/// Decodes `%XY` hex escapes in a URI component, leniently.
///
/// A `%` that is not followed by two hex digits is kept literally. `+` is NOT treated as a space.
/// Decoded bytes that are not valid UTF-8 are replaced with U+FFFD.
fn percent_decode(s: &str) -> String {
    let mut decoded = Vec::with_capacity(s.len());
    let mut rest = s.as_bytes();
    while let [first, tail @ ..] = rest {
        if let [b'%', hi, lo, after @ ..] = rest {
            let nibbles = ((*hi as char).to_digit(16), (*lo as char).to_digit(16));
            if let (Some(h), Some(l)) = nibbles {
                decoded.push((h * 16 + l) as u8);
                rest = after;
                continue;
            }
        }
        decoded.push(*first);
        rest = tail;
    }
    String::from_utf8_lossy(&decoded).into_owned()
}
296
297pub fn parse_seed_input(input: &str) -> Result<ParsedEnrollment, CoreError> {
301 let trimmed = input.trim();
302 if trimmed.starts_with("otpauth://") {
303 parse_otpauth(trimmed)
304 } else {
305 let seed = decode_base32(trimmed)?;
306 Ok(ParsedEnrollment {
307 seed,
308 params: TotpParams::default(),
309 })
310 }
311}
312
// Unit tests: RFC 6238 known-answer vectors for all three algorithms, base32 / otpauth parsing,
// and the countdown / validity-threshold helpers.
#[cfg(test)]
mod tests {
    use super::*;

    // Seeds used with the RFC 6238 Appendix B vectors below: ASCII digits repeated to 20, 32
    // and 64 bytes for SHA1, SHA256 and SHA512 respectively.
    fn sha1_seed() -> Vec<u8> {
        b"12345678901234567890".to_vec()
    }
    fn sha256_seed() -> Vec<u8> {
        b"12345678901234567890123456789012".to_vec()
    }
    fn sha512_seed() -> Vec<u8> {
        b"1234567890123456789012345678901234567890123456789012345678901234".to_vec()
    }

    // 8-digit, 30-second-step vectors. They include t = 20_000_000_000, which exercises a
    // time step beyond 32-bit Unix time.
    #[test]
    fn rfc6238_known_answer_vectors_sha1() {
        for (t, expected) in [
            (59u64, "94287082"),
            (1_111_111_109, "07081804"),
            (1_111_111_111, "14050471"),
            (1_234_567_890, "89005924"),
            (2_000_000_000, "69279037"),
            (20_000_000_000, "65353130"),
        ] {
            let code = code_at(&sha1_seed(), t, TotpAlgorithm::Sha1, 8, 30).unwrap();
            assert_eq!(code, expected, "SHA1 vector at t={t}");
        }
    }

    #[test]
    fn rfc6238_known_answer_vectors_sha256() {
        for (t, expected) in [
            (59u64, "46119246"),
            (1_111_111_109, "68084774"),
            (1_234_567_890, "91819424"),
            (20_000_000_000, "77737706"),
        ] {
            let code = code_at(&sha256_seed(), t, TotpAlgorithm::Sha256, 8, 30).unwrap();
            assert_eq!(code, expected, "SHA256 vector at t={t}");
        }
    }

    #[test]
    fn rfc6238_known_answer_vectors_sha512() {
        for (t, expected) in [
            (59u64, "90693936"),
            (1_111_111_109, "25091201"),
            (1_234_567_890, "93441116"),
            (20_000_000_000, "47863826"),
        ] {
            let code = code_at(&sha512_seed(), t, TotpAlgorithm::Sha512, 8, 30).unwrap();
            assert_eq!(code, expected, "SHA512 vector at t={t}");
        }
    }

    // Feeding the time through the crate's mock clock gives the same answer as a literal
    // timestamp.
    #[test]
    fn code_via_mock_clock_matches_vector() {
        use crate::clock::{Clock, MockClock};
        let clock = MockClock::at(59);
        let code = code_at(&sha1_seed(), clock.unix_secs(), TotpAlgorithm::Sha1, 8, 30).unwrap();
        assert_eq!(code, "94287082");
    }

    // A 6-digit code is the last six digits of the 8-digit vector (same truncated value, smaller
    // modulus).
    #[test]
    fn default_six_digits_truncates_the_vector() {
        let code = code_at(&sha1_seed(), 59, TotpAlgorithm::Sha1, 6, 30).unwrap();
        assert_eq!(code, "287082");
        assert_eq!(code.len(), 6);
    }

    // Covers padding, lowercase input, embedded whitespace, invalid symbols and empty input.
    #[test]
    fn base32_decode_known_vectors() {
        assert_eq!(decode_base32("MFRGG===").unwrap(), b"abc");
        assert_eq!(decode_base32("mfrgg").unwrap(), b"abc");
        assert_eq!(
            decode_base32("JBSWY3DPEHPK3PXP").unwrap(),
            b"Hello!\xde\xad\xbe\xef"
        );
        assert_eq!(decode_base32("MFRG G===").unwrap(), b"abc");
        assert!(decode_base32("0189!").is_err());
        assert!(decode_base32("").is_err());
    }

    // The secret here is base32 for `sha1_seed()`, so the parsed enrollment must reproduce the
    // t=59 SHA1 vector.
    #[test]
    fn otpauth_parse_round_trip() {
        let uri = "otpauth://totp/ACME:alice@example.com?secret=GEZDGNBVGY3TQOJQGEZDGNBVGY3TQOJQ&issuer=ACME&algorithm=SHA1&digits=8&period=30";
        let parsed = parse_otpauth(uri).unwrap();
        assert_eq!(parsed.seed, sha1_seed());
        assert_eq!(parsed.params.algorithm, TotpAlgorithm::Sha1);
        assert_eq!(parsed.params.digits, 8);
        assert_eq!(parsed.params.period, 30);
        let code = code_at(
            &parsed.seed,
            59,
            parsed.params.algorithm,
            parsed.params.digits,
            parsed.params.period,
        )
        .unwrap();
        assert_eq!(code, "94287082");
    }

    // A URI without algorithm/digits/period, and a bare seed, both get `TotpParams::default()`.
    #[test]
    fn otpauth_defaults_and_bare_seed() {
        let uri = "otpauth://totp/x?secret=GEZDGNBVGY3TQOJQGEZDGNBVGY3TQOJQ";
        let parsed = parse_otpauth(uri).unwrap();
        assert_eq!(parsed.params, TotpParams::default());

        let bare = parse_seed_input("GEZDGNBVGY3TQOJQGEZDGNBVGY3TQOJQ").unwrap();
        assert_eq!(bare.seed, sha1_seed());
        assert_eq!(bare.params, TotpParams::default());
        assert_eq!(bare.params.digits, 6);
    }

    // `otpauth://` input goes to the URI parser, which rejects the `hotp` type.
    #[test]
    fn parse_seed_input_routes_uri_vs_bare() {
        assert!(parse_seed_input("otpauth://totp/x?secret=MFRGG").is_ok());
        assert!(parse_seed_input("MFRGG").is_ok());
        assert!(parse_seed_input("otpauth://hotp/x?secret=MFRGG").is_err());
    }

    // Window boundaries report the full period; a zero period reports 0.
    #[test]
    fn seconds_remaining_counts_down_within_the_window() {
        assert_eq!(seconds_remaining(59, 30), 1);
        assert_eq!(seconds_remaining(60, 30), 30);
        assert_eq!(seconds_remaining(75, 30), 15);
        assert_eq!(seconds_remaining(0, 30), 30);
        assert_eq!(seconds_remaining(30, 30), 30);
        assert_eq!(seconds_remaining(5, 0), 0);
    }

    // The threshold is strict: `remaining == min_validity` is not enough.
    #[test]
    fn returns_current_thresholds_on_min_validity() {
        assert!(returns_current(30, 0));
        assert!(returns_current(11, 10));
        assert!(returns_current(2, 1));
        assert!(!returns_current(10, 10));
        assert!(!returns_current(5, 10));
        assert!(!returns_current(0, 0));
        for remaining in 1..=30 {
            assert!(returns_current(remaining, 0));
        }
    }

    // Zero period, zero digits and more than 9 digits are all rejected.
    #[test]
    fn rejects_degenerate_params() {
        assert!(code_at(b"seed", 0, TotpAlgorithm::Sha1, 6, 0).is_err());
        assert!(code_at(b"seed", 0, TotpAlgorithm::Sha1, 0, 30).is_err());
        assert!(code_at(b"seed", 0, TotpAlgorithm::Sha1, 10, 30).is_err());
    }

    #[test]
    fn algorithm_parse_round_trips() {
        assert_eq!(TotpAlgorithm::parse("sha1").unwrap(), TotpAlgorithm::Sha1);
        assert_eq!(
            TotpAlgorithm::parse("SHA256").unwrap(),
            TotpAlgorithm::Sha256
        );
        assert_eq!(TotpAlgorithm::Sha512.as_str(), "SHA512");
        assert!(TotpAlgorithm::parse("md5").is_err());
    }
}