1use hmac::{Hmac, Mac};
6use sha1::Sha1;
7use std::collections::HashMap;
8use std::time::{SystemTime, UNIX_EPOCH};
9
/// HMAC instantiated with SHA-1; used by `generate_hotp` to MAC the 8-byte
/// big-endian counter before dynamic truncation.
type HmacSha1 = Hmac<Sha1>;
11
12pub fn generate_totp(secret: &[u8], time_step_secs: u64, digits: u32) -> anyhow::Result<String> {
18 let now = SystemTime::now()
19 .duration_since(UNIX_EPOCH)
20 .map_err(|_| anyhow::anyhow!("system clock before unix epoch"))?;
21 let counter = now.as_secs() / time_step_secs;
22 generate_hotp(secret, counter, digits)
23}
24
25pub fn generate_totp_at(
27 secret: &[u8],
28 time_step_secs: u64,
29 digits: u32,
30 unix_secs: u64,
31) -> anyhow::Result<String> {
32 let counter = unix_secs / time_step_secs;
33 generate_hotp(secret, counter, digits)
34}
35
36pub fn validate_totp(
44 code: &str,
45 secret: &[u8],
46 time_step_secs: u64,
47 digits: u32,
48 window: u64,
49) -> anyhow::Result<bool> {
50 let now = SystemTime::now()
51 .duration_since(UNIX_EPOCH)
52 .map_err(|_| anyhow::anyhow!("system clock before unix epoch"))?;
53 let current_counter = now.as_secs() / time_step_secs;
54
55 for offset in 0..=window {
56 let counters = if offset == 0 {
58 vec![current_counter]
59 } else {
60 vec![
61 current_counter.wrapping_add(offset),
62 current_counter.wrapping_sub(offset),
63 ]
64 };
65
66 for counter in counters {
67 let expected = generate_hotp(secret, counter, digits)?;
68 if constant_time_eq(code.as_bytes(), expected.as_bytes()) {
69 return Ok(true);
70 }
71 }
72 }
73
74 Ok(false)
75}
76
77pub fn validate_totp_at(
79 code: &str,
80 secret: &[u8],
81 time_step_secs: u64,
82 digits: u32,
83 window: u64,
84 unix_secs: u64,
85) -> anyhow::Result<bool> {
86 let current_counter = unix_secs / time_step_secs;
87
88 for offset in 0..=window {
89 let counters = if offset == 0 {
90 vec![current_counter]
91 } else {
92 vec![
93 current_counter.wrapping_add(offset),
94 current_counter.wrapping_sub(offset),
95 ]
96 };
97
98 for counter in counters {
99 let expected = generate_hotp(secret, counter, digits)?;
100 if constant_time_eq(code.as_bytes(), expected.as_bytes()) {
101 return Ok(true);
102 }
103 }
104 }
105
106 Ok(false)
107}
108
109fn generate_hotp(secret: &[u8], counter: u64, digits: u32) -> anyhow::Result<String> {
111 let mut mac =
112 HmacSha1::new_from_slice(secret).map_err(|e| anyhow::anyhow!("HMAC key error: {e}"))?;
113 mac.update(&counter.to_be_bytes());
114 let result = mac.finalize().into_bytes();
115
116 let offset = (result[19] & 0x0f) as usize;
118 let code = u32::from_be_bytes([
119 result[offset] & 0x7f,
120 result[offset + 1],
121 result[offset + 2],
122 result[offset + 3],
123 ]);
124
125 let modulus = 10u32.pow(digits);
126 Ok(format!(
127 "{:0>width$}",
128 code % modulus,
129 width = digits as usize
130 ))
131}
132
/// Byte-slice equality that, for equal-length inputs, inspects every byte
/// with no data-dependent early exit. XOR differences are OR-folded into one
/// accumulator.
///
/// Slices of different lengths compare unequal immediately, so the length
/// itself is not hidden.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len()
        && a
            .iter()
            .zip(b)
            .fold(0u8, |acc, (lhs, rhs)| acc | (lhs ^ rhs))
            == 0
}
144
/// Outcome of asking an [`OtpGate`] whether an action or domain may proceed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OtpGateResult {
    /// The gate is disabled, or the action/domain is not on a gated list.
    NotGated,
    /// The action/domain is gated and has no unexpired cached approval.
    RequiresOtp { reason: String },
    /// The action/domain is gated but was approved within the cache window.
    Approved,
    /// Not produced by `OtpGate` itself. NOTE(review): presumably emitted by
    /// callers after a failed OTP check; confirm.
    Denied { reason: String },
}

/// Policy deciding which actions and domains require an OTP, plus a cache
/// of recent approvals so the user is not re-prompted on every request.
#[derive(Debug, Clone)]
pub struct OtpGate {
    /// Master switch: when `false`, every check returns `NotGated`.
    pub enabled: bool,
    /// Action names that require an OTP (matched ASCII case-insensitively).
    pub gated_actions: Vec<String>,
    /// Domains that require an OTP. Subdomains of each entry are gated too.
    pub gated_domains: Vec<String>,
    /// NOTE(review): not consulted by any check in this impl; confirm
    /// whether callers use it or category gating is unimplemented.
    pub gated_domain_categories: Vec<String>,
    /// How long, in seconds, a recorded approval stays valid.
    pub cache_valid_secs: u64,
    /// Lower-cased approval key -> Unix time (seconds) it was recorded.
    approvals: HashMap<String, u64>,
}

impl OtpGate {
    /// Creates a gate with the given policy and an empty approval cache.
    pub fn new(
        enabled: bool,
        gated_actions: Vec<String>,
        gated_domains: Vec<String>,
        gated_domain_categories: Vec<String>,
        cache_valid_secs: u64,
    ) -> Self {
        Self {
            enabled,
            gated_actions,
            gated_domains,
            gated_domain_categories,
            cache_valid_secs,
            approvals: HashMap::new(),
        }
    }

    /// Creates a gate that never requires an OTP (cache window 300 s).
    pub fn disabled() -> Self {
        Self::new(false, Vec::new(), Vec::new(), Vec::new(), 300)
    }

    /// Decides whether `action` needs an OTP.
    ///
    /// Matching against `gated_actions` is ASCII case-insensitive, and so is
    /// the approval cache lookup.
    pub fn check_action(&self, action: &str) -> OtpGateResult {
        if !self.enabled {
            return OtpGateResult::NotGated;
        }

        let is_gated = self
            .gated_actions
            .iter()
            .any(|a| a.eq_ignore_ascii_case(action));
        if !is_gated {
            return OtpGateResult::NotGated;
        }

        if self.is_cached_approval(&Self::action_key(action)) {
            return OtpGateResult::Approved;
        }

        OtpGateResult::RequiresOtp {
            reason: format!("Action `{action}` requires OTP verification"),
        }
    }

    /// Decides whether `domain` needs an OTP.
    ///
    /// Surrounding whitespace and trailing dots are ignored. The domain is
    /// gated when it equals a `gated_domains` entry or is a subdomain of one,
    /// compared ASCII case-insensitively (so `API.Bank.com.` is gated by
    /// `bank.com`). Approvals are per exact (normalized) domain: approving
    /// `bank.com` does not approve `api.bank.com`.
    pub fn check_domain(&self, domain: &str) -> OtpGateResult {
        if !self.enabled {
            return OtpGateResult::NotGated;
        }

        let is_gated = self
            .gated_domains
            .iter()
            .any(|gated| Self::domain_is_within(domain, gated));
        if !is_gated {
            return OtpGateResult::NotGated;
        }

        if self.is_cached_approval(&Self::domain_key(domain)) {
            return OtpGateResult::Approved;
        }

        OtpGateResult::RequiresOtp {
            reason: format!("Domain `{domain}` requires OTP verification"),
        }
    }

    /// Records an approval for `key`, timestamped with the current time.
    ///
    /// Keys are stored lower-cased, matching the case-insensitive lookups.
    pub fn record_approval(&mut self, key: &str) {
        self.approvals
            .insert(key.to_ascii_lowercase(), Self::now_secs());
    }

    /// Records an approval for `action` (case-insensitive).
    pub fn approve_action(&mut self, action: &str) {
        self.record_approval(&Self::action_key(action));
    }

    /// Records an approval for `domain` (normalized like `check_domain`).
    pub fn approve_domain(&mut self, domain: &str) {
        self.record_approval(&Self::domain_key(domain));
    }

    /// True when `key` was approved less than `cache_valid_secs` ago.
    fn is_cached_approval(&self, key: &str) -> bool {
        match self.approvals.get(&key.to_ascii_lowercase()) {
            Some(&approved_at) => {
                Self::now_secs().saturating_sub(approved_at) < self.cache_valid_secs
            }
            None => false,
        }
    }

    /// Cache key for an action approval.
    fn action_key(action: &str) -> String {
        format!("action:{action}")
    }

    /// Cache key for a domain approval, using the normalized domain.
    fn domain_key(domain: &str) -> String {
        format!("domain:{}", Self::normalize_domain(domain))
    }

    /// Strips surrounding whitespace and any trailing dots (FQDN form).
    fn normalize_domain(domain: &str) -> &str {
        domain.trim().trim_end_matches('.')
    }

    /// True when `host` equals `gated` or is a subdomain of it.
    ///
    /// The comparison is ASCII case-insensitive and allocation-free. The
    /// suffix must start at a label boundary, so `evilbank.com` is not
    /// within `bank.com`.
    fn domain_is_within(host: &str, gated: &str) -> bool {
        let host = Self::normalize_domain(host).as_bytes();
        let gated = Self::normalize_domain(gated).as_bytes();
        if gated.is_empty() || host.len() < gated.len() {
            return false;
        }
        let split = host.len() - gated.len();
        host[split..].eq_ignore_ascii_case(gated) && (split == 0 || host[split - 1] == b'.')
    }

    /// Current Unix time in seconds; 0 if the clock is before the epoch.
    fn now_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }
}
285
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared secret used by the RFC 4226 Appendix D test vectors.
    const TEST_SECRET: &[u8] = b"12345678901234567890";

    /// Expected 6-digit HOTP values for counters 0 through 9 (RFC 4226 Appendix D).
    const RFC4226_CODES: [&str; 10] = [
        "755224", "287082", "359152", "969429", "338314", "254676", "287922", "162583",
        "399871", "520489",
    ];

    /// Unwraps a `RequiresOtp` result into its reason, panicking otherwise.
    /// `context` is spliced into the panic message (e.g. " for subdomain").
    fn expect_requires_otp(result: OtpGateResult, context: &str) -> String {
        match result {
            OtpGateResult::RequiresOtp { reason } => reason,
            other => panic!("expected RequiresOtp{context}, got {other:?}"),
        }
    }

    #[test]
    fn hotp_rfc4226_test_vectors() {
        for (counter, &want) in (0u64..).zip(RFC4226_CODES.iter()) {
            let got = generate_hotp(TEST_SECRET, counter, 6).unwrap();
            assert_eq!(got, want, "HOTP mismatch at counter {counter}");
        }
    }

    #[test]
    fn totp_at_known_time() {
        // t = 59 s with a 30 s step lands on counter 1.
        assert_eq!(generate_totp_at(TEST_SECRET, 30, 6, 59).unwrap(), "287082");
    }

    #[test]
    fn totp_at_boundary() {
        let last_of_step_0 = generate_totp_at(TEST_SECRET, 30, 6, 29).unwrap();
        let first_of_step_1 = generate_totp_at(TEST_SECRET, 30, 6, 30).unwrap();
        assert_ne!(first_of_step_1, last_of_step_0);
    }

    #[test]
    fn validate_totp_correct_code() {
        let t = 59u64;
        let code = generate_totp_at(TEST_SECRET, 30, 6, t).unwrap();
        assert!(validate_totp_at(&code, TEST_SECRET, 30, 6, 0, t).unwrap());
    }

    #[test]
    fn validate_totp_wrong_code() {
        assert!(!validate_totp_at("000000", TEST_SECRET, 30, 6, 0, 59).unwrap());
    }

    #[test]
    fn validate_totp_with_window() {
        // Code from step 1 (t = 30) presented during step 2 (t = 60).
        let stale = generate_totp_at(TEST_SECRET, 30, 6, 30).unwrap();
        assert!(validate_totp_at(&stale, TEST_SECRET, 30, 6, 1, 60).unwrap());
        assert!(!validate_totp_at(&stale, TEST_SECRET, 30, 6, 0, 60).unwrap());
    }

    #[test]
    fn digits_8_generates_8_chars() {
        assert_eq!(generate_totp_at(TEST_SECRET, 30, 8, 59).unwrap().len(), 8);
    }

    #[test]
    fn constant_time_eq_works() {
        let cases: [(&[u8], &[u8], bool); 3] = [
            (b"hello", b"hello", true),
            (b"hello", b"world", false),
            (b"hello", b"hell", false),
        ];
        for (lhs, rhs, equal) in cases {
            assert_eq!(constant_time_eq(lhs, rhs), equal);
        }
    }

    #[test]
    fn disabled_gate_never_requires_otp() {
        let gate = OtpGate::disabled();
        assert_eq!(gate.check_action("shell"), OtpGateResult::NotGated);
        assert_eq!(gate.check_domain("example.com"), OtpGateResult::NotGated);
    }

    #[test]
    fn gate_requires_otp_for_gated_action() {
        let actions = vec![String::from("shell"), String::from("file_write")];
        let gate = OtpGate::new(true, actions, Vec::new(), Vec::new(), 300);

        let reason = expect_requires_otp(gate.check_action("shell"), "");
        assert!(reason.contains("shell"));

        assert_eq!(gate.check_action("file_read"), OtpGateResult::NotGated);
    }

    #[test]
    fn gate_requires_otp_for_gated_domain() {
        let domains = vec![String::from("bank.com")];
        let gate = OtpGate::new(true, Vec::new(), domains, Vec::new(), 300);

        let reason = expect_requires_otp(gate.check_domain("bank.com"), "");
        assert!(reason.contains("bank.com"));

        // Subdomains inherit the parent domain's gate.
        expect_requires_otp(gate.check_domain("api.bank.com"), " for subdomain");

        assert_eq!(gate.check_domain("example.com"), OtpGateResult::NotGated);
    }

    #[test]
    fn gate_caches_approval() {
        let actions = vec![String::from("shell")];
        let mut gate = OtpGate::new(true, actions, Vec::new(), Vec::new(), 300);

        expect_requires_otp(gate.check_action("shell"), "");
        gate.approve_action("shell");
        assert_eq!(gate.check_action("shell"), OtpGateResult::Approved);
    }

    #[test]
    fn gate_case_insensitive_action_check() {
        let actions = vec![String::from("Shell")];
        let gate = OtpGate::new(true, actions, Vec::new(), Vec::new(), 300);

        for spelling in ["shell", "SHELL"] {
            assert!(matches!(
                gate.check_action(spelling),
                OtpGateResult::RequiresOtp { .. }
            ));
        }
    }
}