Skip to main content

age_otp/
types.rs

1use crate::error::{Error, GenerationError, KeyError, Result};
2use crate::PublicKey;
3use std::fmt;
/// Length in bytes of an OTP seed and of a decoded age public key payload.
pub const SEED_LEN: usize = 32;
/// Smallest code length (inclusive) accepted by `truncate` and `validate_code_len`.
pub const MIN_CODE_LEN: usize = 4;
/// Largest code length (inclusive) accepted by `truncate` and `validate_code_len`.
pub const MAX_CODE_LEN: usize = 64;
/// Smallest step size (inclusive) accepted by `validate_step_secs`.
pub const MIN_STEP_SECS: u64 = 1;
/// Largest step size (inclusive) accepted by `validate_step_secs`.
pub const MAX_STEP_SECS: u64 = 3600;
/// Largest skew value (inclusive) accepted by `validate_skew_steps`.
pub const MAX_SKEW_STEPS: u64 = 10;
/// Symbols of the numeric alphabet; a symbol's index is its digit value.
const DIGITS: &[u8] = b"0123456789";
/// Symbols of the base-36 alphabet: digits followed by uppercase ASCII letters.
const ALPHANUM: &[u8] = b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
/// Symbols of the lowercase hexadecimal alphabet.
const HEX: &[u8] = b"0123456789abcdef";

/// Alphabet from which the characters of a code are drawn.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Charset {
    /// ASCII digits `0`-`9`.
    Numeric,
    /// ASCII digits followed by uppercase `A`-`Z`.
    AlphanumericUpper,
    /// Lowercase hexadecimal digits `0`-`9`, `a`-`f`.
    HexLower,
}

impl Charset {
    /// Returns the alphabet as ASCII bytes, ordered by digit value.
    fn chars(&self) -> &'static [u8] {
        match *self {
            Charset::Numeric => DIGITS,
            Charset::AlphanumericUpper => ALPHANUM,
            Charset::HexLower => HEX,
        }
    }

    /// Returns the number of symbols in the alphabet (its numeric base).
    pub fn len(&self) -> usize {
        let alphabet = self.chars();
        alphabet.len()
    }

    /// Returns `true` when `s` is non-empty and every byte of it belongs to
    /// this alphabet.
    pub fn validate(&self, s: &str) -> bool {
        if s.is_empty() {
            return false;
        }
        let alphabet = self.chars();
        for byte in s.bytes() {
            if !alphabet.contains(&byte) {
                return false;
            }
        }
        true
    }
}

impl Default for Charset {
    /// The default alphabet is [`Charset::Numeric`].
    fn default() -> Self {
        Charset::Numeric
    }
}
40use bech32::primitives::decode::CheckedHrpstring;
41use bech32::Bech32;
42pub(crate) fn decode_age_public_key(raw: &str) -> Result<[u8; SEED_LEN]> {
43    if raw.is_empty() {
44        return Err(KeyError::Empty.into());
45    }
46    if !raw.starts_with("age1") {
47        return Err(KeyError::InvalidPrefix(raw[..raw.len().min(10)].into()).into());
48    }
49    let parsed = CheckedHrpstring::new::<Bech32>(raw)
50        .map_err(|e| KeyError::Bech32Decode(format!("Bech32 parse error: {}", e)))?;
51    let bytes: Vec<u8> = parsed.byte_iter().collect();
52    if bytes.len() != SEED_LEN {
53        return Err(KeyError::InvalidDecodedLength(bytes.len()).into());
54    }
55    let mut result = [0u8; SEED_LEN];
56    result.copy_from_slice(&bytes);
57    Ok(result)
58}
59#[derive(Clone)]
60pub struct OtpSeed {
61    bytes: [u8; SEED_LEN],
62}
63impl OtpSeed {
64    pub fn from_public_key(pk: &PublicKey) -> Result<Self> {
65        let raw = pk.expose();
66        let decoded_key = decode_age_public_key(raw)?;
67        use hkdf::Hkdf;
68        use sha2::Sha256;
69        let hk = Hkdf::<Sha256>::new(None, &decoded_key);
70        let mut seed = [0u8; SEED_LEN];
71        hk.expand(b"age-otp-v1", &mut seed)
72            .map_err(|_| GenerationError::HmacFailed)?;
73        Ok(Self { bytes: seed })
74    }
75    pub fn from_bytes(bytes: [u8; SEED_LEN]) -> Self {
76        Self { bytes }
77    }
78    pub fn as_bytes(&self) -> &[u8; SEED_LEN] {
79        &self.bytes
80    }
81    pub fn to_hex(&self) -> String {
82        self.bytes.iter().map(|b| format!("{:02x}", b)).collect()
83    }
84}
85impl fmt::Debug for OtpSeed {
86    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
87        f.debug_struct("OtpSeed")
88            .field("hex_prefix", &&self.to_hex()[..8])
89            .finish_non_exhaustive()
90    }
91}
92#[derive(Clone, PartialEq, Eq)]
93pub struct OtpCode {
94    value: String,
95    born: u64,
96}
97impl OtpCode {
98    pub fn new(value: String, time_step: u64, step_secs: u64) -> Result<Self> {
99        let born = time_step
100            .checked_mul(step_secs)
101            .ok_or(GenerationError::Overflow)?;
102        Ok(Self { value, born })
103    }
104    pub fn as_str(&self) -> &str {
105        &self.value
106    }
107    pub fn born_at(&self) -> u64 {
108        self.born
109    }
110    pub fn len(&self) -> usize {
111        self.value.len()
112    }
113    pub fn is_empty(&self) -> bool {
114        self.value.is_empty()
115    }
116    pub fn is_valid_at(&self, ts: u64, ttl: u64) -> bool {
117        let expires = self.born.saturating_add(ttl);
118        ts >= self.born && ts < expires
119    }
120}
121impl fmt::Debug for OtpCode {
122    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
123        let masked = if self.value.len() >= 2 {
124            format!("{}***", &self.value[..2])
125        } else {
126            "***".into()
127        };
128        f.debug_struct("OtpCode")
129            .field("code", &masked)
130            .field("born", &self.born)
131            .finish()
132    }
133}
134impl fmt::Display for OtpCode {
135    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
136        write!(f, "{}", self.value)
137    }
138}
139impl AsRef<str> for OtpCode {
140    fn as_ref(&self) -> &str {
141        &self.value
142    }
143}
/// Returns the current system time as whole seconds since the Unix epoch.
///
/// If the system clock reports a time before the epoch, `0` is returned
/// instead of an error.
pub fn now_ts() -> u64 {
    let since_epoch = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH);
    match since_epoch {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
150pub fn compute_hmac(seed: &[u8; SEED_LEN], step: u64) -> Result<[u8; 32]> {
151    use hmac::{Hmac, Mac};
152    use sha2::Sha256;
153    type HmacSha256 = Hmac<Sha256>;
154    let mut mac = HmacSha256::new_from_slice(seed).map_err(|_| GenerationError::HmacFailed)?;
155    mac.update(&step.to_be_bytes());
156    let result = mac.finalize().into_bytes();
157    let mut hash = [0u8; 32];
158    hash.copy_from_slice(&result);
159    Ok(hash)
160}
161pub fn truncate(hash: &[u8; 32], charset: Charset, len: usize) -> Result<String> {
162    if len < MIN_CODE_LEN || len > MAX_CODE_LEN {
163        return Err(GenerationError::InvalidLength(format!(
164            "code length must be {}-{}, got {}",
165            MIN_CODE_LEN, MAX_CODE_LEN, len
166        ))
167        .into());
168    }
169    let offset = (hash[31] & 0x0f) as usize;
170    let binary = ((hash[offset] as u32) << 24)
171        | ((hash[offset + 1] as u32) << 16)
172        | ((hash[offset + 2] as u32) << 8)
173        | (hash[offset + 3] as u32);
174    let base = charset.len() as u64;
175    let code_val = match base.checked_pow(len as u32) {
176        Some(max_val) => (binary as u64) % max_val,
177        None => binary as u64,
178    };
179    let chars = charset.chars();
180    let mut s = String::with_capacity(len);
181    let mut rem = code_val;
182    for _ in 0..len {
183        let idx = (rem % base) as usize;
184        s.push(chars[idx] as char);
185        rem /= base;
186    }
187    let s: String = s.chars().rev().collect();
188    debug_assert_eq!(s.len(), len, "truncation produced wrong length");
189    Ok(s)
190}
191pub fn ct_eq(a: &[u8], b: &[u8]) -> bool {
192    use subtle::ConstantTimeEq;
193    if a.len() != b.len() {
194        return false;
195    }
196    a.ct_eq(b).into()
197}
198pub fn validate_code_len(len: usize) -> Result<()> {
199    if len < MIN_CODE_LEN || len > MAX_CODE_LEN {
200        return Err(Error::Generation(GenerationError::InvalidLength(format!(
201            "code length must be {}-{}, got {}",
202            MIN_CODE_LEN, MAX_CODE_LEN, len
203        ))));
204    }
205    Ok(())
206}
207pub fn validate_step_secs(secs: u64) -> Result<()> {
208    if secs < MIN_STEP_SECS || secs > MAX_STEP_SECS {
209        return Err(Error::Generation(GenerationError::TruncateFailed(format!(
210            "step_secs must be {}-{}, got {}",
211            MIN_STEP_SECS, MAX_STEP_SECS, secs
212        ))));
213    }
214    Ok(())
215}
216pub fn validate_skew_steps(skew: u64) -> Result<()> {
217    if skew > MAX_SKEW_STEPS {
218        return Err(Error::Verification(
219            crate::error::VerificationError::InvalidFormat(format!(
220                "skew_steps must be <= {}, got {}",
221                MAX_SKEW_STEPS, skew
222            )),
223        ));
224    }
225    Ok(())
226}