Skip to main content

age_otp/
engine.rs

1use crate::error::{Error, Result, VerificationError};
2use crate::types::{
3    compute_hmac, ct_eq, now_ts, truncate, validate_code_len, validate_skew_steps,
4    validate_step_secs, Charset, OtpCode, OtpSeed,
5};
6use crate::PublicKey;
7use std::fmt;
8#[derive(Clone)]
9pub struct OtpEngine {
10    seed: OtpSeed,
11}
12impl fmt::Debug for OtpEngine {
13    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
14        f.debug_struct("OtpEngine")
15            .field("seed", &self.seed)
16            .finish()
17    }
18}
19impl OtpEngine {
20    pub fn from_public_key(pk: &PublicKey) -> Result<Self> {
21        let seed = OtpSeed::from_public_key(pk)?;
22        Ok(Self { seed })
23    }
24    pub fn from_seed(seed: OtpSeed) -> Self {
25        Self { seed }
26    }
27    pub fn seed(&self) -> &OtpSeed {
28        &self.seed
29    }
30    pub fn generate(
31        &self,
32        len: usize,
33        time_step: u64,
34        step_secs: u64,
35        charset: Charset,
36    ) -> Result<OtpCode> {
37        validate_code_len(len)?;
38        validate_step_secs(step_secs)?;
39        let hash = compute_hmac(self.seed.as_bytes(), time_step)?;
40        let value = truncate(&hash, charset, len)?;
41        OtpCode::new(value, time_step, step_secs)
42    }
43    pub fn generate_default(&self, len: usize, time_step: u64) -> Result<OtpCode> {
44        self.generate(len, time_step, 30, Charset::default())
45    }
46    pub fn generate_now(&self, len: usize) -> Result<OtpCode> {
47        let step = now_ts() / 30;
48        self.generate_default(len, step)
49    }
50    pub fn verify(
51        &self,
52        code: &OtpCode,
53        time_step: u64,
54        ttl: u64,
55        step_secs: u64,
56        charset: Charset,
57    ) -> Result<()> {
58        validate_step_secs(step_secs)?;
59        let expected = self.generate(code.len(), time_step, step_secs, charset)?;
60        if !ct_eq(code.as_str().as_bytes(), expected.as_str().as_bytes()) {
61            return Err(VerificationError::Mismatch.into());
62        }
63        let now = now_ts();
64        if !code.is_valid_at(now, ttl) {
65            return Err(VerificationError::Expired(code.born_at().saturating_add(ttl), now).into());
66        }
67        Ok(())
68    }
69    pub fn verify_default(&self, code: &OtpCode, time_step: u64, ttl: u64) -> Result<()> {
70        self.verify(code, time_step, ttl, 30, Charset::default())
71    }
72    pub fn verify_raw(
73        &self,
74        raw: &str,
75        len: usize,
76        time_step: u64,
77        ttl: u64,
78        step_secs: u64,
79        charset: Charset,
80    ) -> Result<()> {
81        let expected = self.validate_and_generate(raw, len, time_step, step_secs, charset)?;
82        if !ct_eq(raw.as_bytes(), expected.as_str().as_bytes()) {
83            return Err(VerificationError::Mismatch.into());
84        }
85        let now = now_ts();
86        let born = time_step.checked_mul(step_secs).ok_or_else(|| {
87            Error::Verification(VerificationError::InvalidFormat("overflow".into()))
88        })?;
89        if now < born || now >= born.saturating_add(ttl) {
90            return Err(VerificationError::Expired(born.saturating_add(ttl), now).into());
91        }
92        Ok(())
93    }
94    pub fn verify_with_skew(
95        &self,
96        raw: &str,
97        len: usize,
98        time_step: u64,
99        ttl: u64,
100        step_secs: u64,
101        charset: Charset,
102        skew_steps: u64,
103    ) -> Result<()> {
104        validate_skew_steps(skew_steps)?;
105        validate_code_len(len)?;
106        validate_step_secs(step_secs)?;
107        if raw.len() != len {
108            return Err(VerificationError::InvalidFormat(format!(
109                "expected length {}, got {}",
110                len,
111                raw.len()
112            ))
113            .into());
114        }
115        if !charset.validate(raw) {
116            return Err(VerificationError::InvalidFormat("invalid charset".into()).into());
117        }
118        let start = time_step.saturating_sub(skew_steps);
119        let end = time_step.saturating_add(skew_steps);
120        for step in start..=end {
121            let expected = self.generate(len, step, step_secs, charset)?;
122            if ct_eq(raw.as_bytes(), expected.as_str().as_bytes()) {
123                let now = now_ts();
124                let born = step.checked_mul(step_secs).ok_or_else(|| {
125                    Error::Verification(VerificationError::InvalidFormat("overflow".into()))
126                })?;
127                if now >= born && now < born.saturating_add(ttl) {
128                    return Ok(());
129                }
130                return Err(VerificationError::Expired(born.saturating_add(ttl), now).into());
131            }
132        }
133        Err(VerificationError::Mismatch.into())
134    }
135    fn validate_and_generate(
136        &self,
137        raw: &str,
138        len: usize,
139        time_step: u64,
140        step_secs: u64,
141        charset: Charset,
142    ) -> Result<OtpCode> {
143        validate_code_len(len)?;
144        validate_step_secs(step_secs)?;
145        if raw.len() != len {
146            return Err(VerificationError::InvalidFormat(format!(
147                "expected length {}, got {}",
148                len,
149                raw.len()
150            ))
151            .into());
152        }
153        if !charset.validate(raw) {
154            return Err(VerificationError::InvalidFormat("invalid charset".into()).into());
155        }
156        self.generate(len, time_step, step_secs, charset)
157    }
158}