Skip to main content

agentzero_core/security/
otp.rs

1//! TOTP (Time-based One-Time Password) implementation per RFC 6238.
2//!
3//! Used for OTP gating of sensitive actions, domains, and estop resume.
4
5use hmac::{Hmac, Mac};
6use sha1::Sha1;
7use std::collections::HashMap;
8use std::time::{SystemTime, UNIX_EPOCH};
9
/// HMAC instantiated with SHA-1, the hash used by RFC 4226 HOTP — and therefore by every
/// TOTP function in this module, since they all compute codes through `generate_hotp`.
type HmacSha1 = Hmac<Sha1>;
11
12/// Generate a TOTP code for the given secret and time step.
13///
14/// - `secret`: shared secret key (raw bytes, typically base32-decoded)
15/// - `time_step_secs`: time step in seconds (default: 30)
16/// - `digits`: number of digits in the OTP (default: 6)
17pub fn generate_totp(secret: &[u8], time_step_secs: u64, digits: u32) -> anyhow::Result<String> {
18    let now = SystemTime::now()
19        .duration_since(UNIX_EPOCH)
20        .map_err(|_| anyhow::anyhow!("system clock before unix epoch"))?;
21    let counter = now.as_secs() / time_step_secs;
22    generate_hotp(secret, counter, digits)
23}
24
25/// Generate a TOTP code for a specific unix timestamp (for testing).
26pub fn generate_totp_at(
27    secret: &[u8],
28    time_step_secs: u64,
29    digits: u32,
30    unix_secs: u64,
31) -> anyhow::Result<String> {
32    let counter = unix_secs / time_step_secs;
33    generate_hotp(secret, counter, digits)
34}
35
36/// Validate a TOTP code with a time window for clock skew tolerance.
37///
38/// - `code`: the OTP code to validate
39/// - `secret`: shared secret key
40/// - `time_step_secs`: time step in seconds
41/// - `digits`: number of digits
42/// - `window`: number of time steps to check before and after current
43pub fn validate_totp(
44    code: &str,
45    secret: &[u8],
46    time_step_secs: u64,
47    digits: u32,
48    window: u64,
49) -> anyhow::Result<bool> {
50    let now = SystemTime::now()
51        .duration_since(UNIX_EPOCH)
52        .map_err(|_| anyhow::anyhow!("system clock before unix epoch"))?;
53    let current_counter = now.as_secs() / time_step_secs;
54
55    for offset in 0..=window {
56        // Check current and both directions
57        let counters = if offset == 0 {
58            vec![current_counter]
59        } else {
60            vec![
61                current_counter.wrapping_add(offset),
62                current_counter.wrapping_sub(offset),
63            ]
64        };
65
66        for counter in counters {
67            let expected = generate_hotp(secret, counter, digits)?;
68            if constant_time_eq(code.as_bytes(), expected.as_bytes()) {
69                return Ok(true);
70            }
71        }
72    }
73
74    Ok(false)
75}
76
77/// Validate a TOTP code at a specific timestamp (for testing).
78pub fn validate_totp_at(
79    code: &str,
80    secret: &[u8],
81    time_step_secs: u64,
82    digits: u32,
83    window: u64,
84    unix_secs: u64,
85) -> anyhow::Result<bool> {
86    let current_counter = unix_secs / time_step_secs;
87
88    for offset in 0..=window {
89        let counters = if offset == 0 {
90            vec![current_counter]
91        } else {
92            vec![
93                current_counter.wrapping_add(offset),
94                current_counter.wrapping_sub(offset),
95            ]
96        };
97
98        for counter in counters {
99            let expected = generate_hotp(secret, counter, digits)?;
100            if constant_time_eq(code.as_bytes(), expected.as_bytes()) {
101                return Ok(true);
102            }
103        }
104    }
105
106    Ok(false)
107}
108
109/// HOTP (HMAC-based One-Time Password) per RFC 4226.
110fn generate_hotp(secret: &[u8], counter: u64, digits: u32) -> anyhow::Result<String> {
111    let mut mac =
112        HmacSha1::new_from_slice(secret).map_err(|e| anyhow::anyhow!("HMAC key error: {e}"))?;
113    mac.update(&counter.to_be_bytes());
114    let result = mac.finalize().into_bytes();
115
116    // Dynamic truncation (RFC 4226 section 5.4)
117    let offset = (result[19] & 0x0f) as usize;
118    let code = u32::from_be_bytes([
119        result[offset] & 0x7f,
120        result[offset + 1],
121        result[offset + 2],
122        result[offset + 3],
123    ]);
124
125    let modulus = 10u32.pow(digits);
126    Ok(format!(
127        "{:0>width$}",
128        code % modulus,
129        width = digits as usize
130    ))
131}
132
/// Compare two byte strings without stopping at the first differing byte.
///
/// Inputs of different length return `false` immediately. For equal lengths, every byte
/// pair is XOR-ed into a single accumulator, so the whole input is always visited before
/// deciding. This blunts timing attacks against OTP comparison.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len()
        && a
            .iter()
            .zip(b)
            .fold(0u8, |acc, (lhs, rhs)| acc | (lhs ^ rhs))
            == 0
}
144
145// ---------------------------------------------------------------------------
146// OTP Gating Engine
147// ---------------------------------------------------------------------------
148
/// Outcome of asking an [`OtpGate`] whether an action or domain may proceed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OtpGateResult {
    /// The gate is disabled or does not cover this action/domain; no OTP is needed.
    NotGated,
    /// The action/domain is gated and has no unexpired cached approval.
    RequiresOtp {
        /// Human-readable explanation of why an OTP is needed.
        reason: String,
    },
    /// An earlier OTP approval is still inside the gate's `cache_valid_secs` window.
    Approved,
    /// OTP validation failed. Not produced by the gate's `check_*` methods; intended for
    /// callers reporting a rejected OTP.
    Denied {
        /// Human-readable explanation of why the OTP was rejected.
        reason: String,
    },
}
161
/// Policy engine deciding whether an action or domain needs an OTP before it may proceed.
///
/// Successful approvals are remembered together with their Unix timestamp and honored
/// until `cache_valid_secs` seconds have elapsed.
#[derive(Debug, Clone)]
pub struct OtpGate {
    /// Master switch: when `false`, every check reports "not gated".
    pub enabled: bool,
    /// Action names that require an OTP (compared ASCII case-insensitively).
    pub gated_actions: Vec<String>,
    /// Domains that require an OTP; subdomains of each entry are gated as well.
    pub gated_domains: Vec<String>,
    /// Domain categories configured for gating.
    /// NOTE(review): not consulted by any check in this module — confirm where it is used.
    pub gated_domain_categories: Vec<String>,
    /// Lifetime, in seconds, of a recorded approval.
    pub cache_valid_secs: u64,
    /// Approval cache: key (`"action:..."` / `"domain:..."`) -> Unix seconds of approval.
    approvals: HashMap<String, u64>,
}
173
174impl OtpGate {
175    pub fn new(
176        enabled: bool,
177        gated_actions: Vec<String>,
178        gated_domains: Vec<String>,
179        gated_domain_categories: Vec<String>,
180        cache_valid_secs: u64,
181    ) -> Self {
182        Self {
183            enabled,
184            gated_actions,
185            gated_domains,
186            gated_domain_categories,
187            cache_valid_secs,
188            approvals: HashMap::new(),
189        }
190    }
191
192    pub fn disabled() -> Self {
193        Self {
194            enabled: false,
195            gated_actions: Vec::new(),
196            gated_domains: Vec::new(),
197            gated_domain_categories: Vec::new(),
198            cache_valid_secs: 300,
199            approvals: HashMap::new(),
200        }
201    }
202
203    /// Check if an action requires OTP.
204    pub fn check_action(&self, action: &str) -> OtpGateResult {
205        if !self.enabled {
206            return OtpGateResult::NotGated;
207        }
208
209        let is_gated = self
210            .gated_actions
211            .iter()
212            .any(|a| a.eq_ignore_ascii_case(action));
213
214        if !is_gated {
215            return OtpGateResult::NotGated;
216        }
217
218        // Check cached approval
219        let cache_key = format!("action:{action}");
220        if self.is_cached_approval(&cache_key) {
221            return OtpGateResult::Approved;
222        }
223
224        OtpGateResult::RequiresOtp {
225            reason: format!("Action `{action}` requires OTP verification"),
226        }
227    }
228
229    /// Check if a domain requires OTP.
230    pub fn check_domain(&self, domain: &str) -> OtpGateResult {
231        if !self.enabled {
232            return OtpGateResult::NotGated;
233        }
234
235        let is_gated = self
236            .gated_domains
237            .iter()
238            .any(|d| d.eq_ignore_ascii_case(domain) || domain.ends_with(&format!(".{d}")));
239
240        if !is_gated {
241            return OtpGateResult::NotGated;
242        }
243
244        let cache_key = format!("domain:{domain}");
245        if self.is_cached_approval(&cache_key) {
246            return OtpGateResult::Approved;
247        }
248
249        OtpGateResult::RequiresOtp {
250            reason: format!("Domain `{domain}` requires OTP verification"),
251        }
252    }
253
254    /// Record an OTP approval for caching.
255    pub fn record_approval(&mut self, key: &str) {
256        let now = SystemTime::now()
257            .duration_since(UNIX_EPOCH)
258            .unwrap_or_default()
259            .as_secs();
260        self.approvals.insert(key.to_string(), now);
261    }
262
263    /// Record an action approval after successful OTP validation.
264    pub fn approve_action(&mut self, action: &str) {
265        self.record_approval(&format!("action:{action}"));
266    }
267
268    /// Record a domain approval after successful OTP validation.
269    pub fn approve_domain(&mut self, domain: &str) {
270        self.record_approval(&format!("domain:{domain}"));
271    }
272
273    fn is_cached_approval(&self, key: &str) -> bool {
274        if let Some(&approved_at) = self.approvals.get(key) {
275            let now = SystemTime::now()
276                .duration_since(UNIX_EPOCH)
277                .unwrap_or_default()
278                .as_secs();
279            now.saturating_sub(approved_at) < self.cache_valid_secs
280        } else {
281            false
282        }
283    }
284}
285
#[cfg(test)]
mod tests {
    use super::*;

    /// ASCII seed shared by the RFC 4226 and RFC 6238 (SHA-1) test vectors.
    const TEST_SECRET: &[u8] = b"12345678901234567890";

    #[test]
    fn hotp_rfc4226_test_vectors() {
        // RFC 4226 Appendix D test vectors for HOTP with secret "12345678901234567890"
        let expected = [
            "755224", "287082", "359152", "969429", "338314", "254676", "287922", "162583",
            "399871", "520489",
        ];

        for (counter, expected_code) in expected.iter().enumerate() {
            let code = generate_hotp(TEST_SECRET, counter as u64, 6).unwrap();
            assert_eq!(&code, expected_code, "HOTP mismatch at counter {counter}");
        }
    }

    #[test]
    fn totp_rfc6238_sha1_test_vectors() {
        // RFC 6238 Appendix B, SHA-1 rows: 8 digits, 30 s step. Includes a leading-zero code.
        let vectors: [(u64, &str); 6] = [
            (59, "94287082"),
            (1_111_111_109, "07081804"),
            (1_111_111_111, "14050471"),
            (1_234_567_890, "89005924"),
            (2_000_000_000, "69279037"),
            (20_000_000_000, "65353130"),
        ];

        for &(unix_secs, expected) in &vectors {
            let code = generate_totp_at(TEST_SECRET, 30, 8, unix_secs).unwrap();
            assert_eq!(code, expected, "TOTP mismatch at t={unix_secs}");
            assert!(validate_totp_at(expected, TEST_SECRET, 30, 8, 0, unix_secs).unwrap());
        }
    }

    #[test]
    fn totp_at_known_time() {
        // With time_step=30, at time=59 the counter=1
        let code = generate_totp_at(TEST_SECRET, 30, 6, 59).unwrap();
        // Counter 1 should produce "287082" per RFC 4226
        assert_eq!(code, "287082");
    }

    #[test]
    fn totp_at_boundary() {
        // At time=30, counter=1
        let code_30 = generate_totp_at(TEST_SECRET, 30, 6, 30).unwrap();
        // At time=29, counter=0
        let code_29 = generate_totp_at(TEST_SECRET, 30, 6, 29).unwrap();
        // These should be different (different counters)
        assert_ne!(code_30, code_29);
    }

    #[test]
    fn validate_totp_correct_code() {
        let unix_time = 59u64;
        let code = generate_totp_at(TEST_SECRET, 30, 6, unix_time).unwrap();
        let valid = validate_totp_at(&code, TEST_SECRET, 30, 6, 0, unix_time).unwrap();
        assert!(valid);
    }

    #[test]
    fn validate_totp_wrong_code() {
        let valid = validate_totp_at("000000", TEST_SECRET, 30, 6, 0, 59).unwrap();
        assert!(!valid);
    }

    #[test]
    fn validate_totp_with_window() {
        // Generate code at time=30 (counter=1)
        let code = generate_totp_at(TEST_SECRET, 30, 6, 30).unwrap();
        // Validate at time=60 (counter=2), with window=1 should accept counter=1
        let valid = validate_totp_at(&code, TEST_SECRET, 30, 6, 1, 60).unwrap();
        assert!(valid);
        // Without window, should reject
        let invalid = validate_totp_at(&code, TEST_SECRET, 30, 6, 0, 60).unwrap();
        assert!(!invalid);
    }

    #[test]
    fn validate_totp_window_accepts_later_step() {
        // Code for counter 2 (t=60) checked at counter 1 (t=30): the "ahead" side of the window.
        let code = generate_totp_at(TEST_SECRET, 30, 6, 60).unwrap();
        assert!(validate_totp_at(&code, TEST_SECRET, 30, 6, 1, 30).unwrap());
        assert!(!validate_totp_at(&code, TEST_SECRET, 30, 6, 0, 30).unwrap());
    }

    #[test]
    fn validate_totp_window_at_counter_zero() {
        // The window extends below counter 0 here; the current step must still validate.
        let code = generate_totp_at(TEST_SECRET, 30, 6, 0).unwrap();
        assert!(validate_totp_at(&code, TEST_SECRET, 30, 6, 1, 0).unwrap());
    }

    #[test]
    fn digits_8_generates_8_chars() {
        let code = generate_totp_at(TEST_SECRET, 30, 8, 59).unwrap();
        assert_eq!(code.len(), 8);
    }

    #[test]
    fn constant_time_eq_works() {
        assert!(constant_time_eq(b"hello", b"hello"));
        assert!(!constant_time_eq(b"hello", b"world"));
        assert!(!constant_time_eq(b"hello", b"hell"));
    }

    // --- OTP Gate tests ---

    #[test]
    fn disabled_gate_never_requires_otp() {
        let gate = OtpGate::disabled();
        assert_eq!(gate.check_action("shell"), OtpGateResult::NotGated);
        assert_eq!(gate.check_domain("example.com"), OtpGateResult::NotGated);
    }

    #[test]
    fn gate_requires_otp_for_gated_action() {
        let gate = OtpGate::new(
            true,
            vec!["shell".into(), "file_write".into()],
            Vec::new(),
            Vec::new(),
            300,
        );

        match gate.check_action("shell") {
            OtpGateResult::RequiresOtp { reason } => {
                assert!(reason.contains("shell"));
            }
            other => panic!("expected RequiresOtp, got {other:?}"),
        }

        assert_eq!(gate.check_action("file_read"), OtpGateResult::NotGated);
    }

    #[test]
    fn gate_requires_otp_for_gated_domain() {
        let gate = OtpGate::new(true, Vec::new(), vec!["bank.com".into()], Vec::new(), 300);

        match gate.check_domain("bank.com") {
            OtpGateResult::RequiresOtp { reason } => {
                assert!(reason.contains("bank.com"));
            }
            other => panic!("expected RequiresOtp, got {other:?}"),
        }

        // Subdomain should also be gated
        match gate.check_domain("api.bank.com") {
            OtpGateResult::RequiresOtp { .. } => {}
            other => panic!("expected RequiresOtp for subdomain, got {other:?}"),
        }

        assert_eq!(gate.check_domain("example.com"), OtpGateResult::NotGated);
    }

    #[test]
    fn gate_does_not_match_lookalike_domain() {
        let gate = OtpGate::new(true, Vec::new(), vec!["bank.com".into()], Vec::new(), 300);
        assert_eq!(gate.check_domain("evilbank.com"), OtpGateResult::NotGated);
        assert_eq!(gate.check_domain("bank.com.evil.net"), OtpGateResult::NotGated);
    }

    #[test]
    fn gate_caches_approval() {
        let mut gate = OtpGate::new(true, vec!["shell".into()], Vec::new(), Vec::new(), 300);

        // First check should require OTP
        assert!(matches!(
            gate.check_action("shell"),
            OtpGateResult::RequiresOtp { .. }
        ));

        // Record approval
        gate.approve_action("shell");

        // Now should be approved (cached)
        assert_eq!(gate.check_action("shell"), OtpGateResult::Approved);
    }

    #[test]
    fn gate_caches_domain_approval() {
        let mut gate = OtpGate::new(true, Vec::new(), vec!["bank.com".into()], Vec::new(), 300);
        gate.approve_domain("bank.com");
        assert_eq!(gate.check_domain("bank.com"), OtpGateResult::Approved);
    }

    #[test]
    fn gate_zero_ttl_never_caches() {
        let mut gate = OtpGate::new(true, vec!["shell".into()], Vec::new(), Vec::new(), 0);
        gate.approve_action("shell");
        assert!(matches!(
            gate.check_action("shell"),
            OtpGateResult::RequiresOtp { .. }
        ));
    }

    #[test]
    fn gate_case_insensitive_action_check() {
        let gate = OtpGate::new(true, vec!["Shell".into()], Vec::new(), Vec::new(), 300);

        assert!(matches!(
            gate.check_action("shell"),
            OtpGateResult::RequiresOtp { .. }
        ));
        assert!(matches!(
            gate.check_action("SHELL"),
            OtpGateResult::RequiresOtp { .. }
        ));
    }
}