Skip to main content

wae_crypto/
totp.rs

1//! TOTP/HOTP 模块
2
3use crate::{
4    base64::base64_encode,
5    error::{CryptoError, CryptoResult},
6    hmac::{HmacAlgorithm, hmac_sign},
7};
8
/// Hash function used by the HMAC step of HOTP/TOTP code generation.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TotpAlgorithm {
    /// HMAC with SHA-1.
    SHA1,
    /// HMAC with SHA-256.
    SHA256,
    /// HMAC with SHA-512.
    SHA512,
}
19
20impl Default for TotpAlgorithm {
21    fn default() -> Self {
22        TotpAlgorithm::SHA1
23    }
24}
25
26/// Base32 字符集
27const BASE32_CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567";
28
29/// TOTP 密钥
30#[derive(Debug, Clone)]
31pub struct TotpSecret {
32    bytes: Vec<u8>,
33    base32: String,
34}
35
36impl TotpSecret {
37    /// 生成新的随机密钥
38    pub fn generate(length: usize) -> CryptoResult<Self> {
39        use argon2::password_hash::rand_core::{OsRng, RngCore};
40        let mut rng = OsRng;
41        let mut bytes = vec![0u8; length];
42        rng.fill_bytes(&mut bytes);
43        Self::from_bytes(&bytes)
44    }
45
46    /// 生成推荐长度的密钥(20 字节)
47    pub fn generate_default() -> CryptoResult<Self> {
48        Self::generate(20)
49    }
50
51    /// 从字节创建密钥
52    pub fn from_bytes(bytes: &[u8]) -> CryptoResult<Self> {
53        let base32 = Self::encode_base32(bytes);
54        Ok(Self { bytes: bytes.to_vec(), base32 })
55    }
56
57    /// 从 Base32 字符串创建密钥
58    pub fn from_base32(s: &str) -> CryptoResult<Self> {
59        let bytes = Self::decode_base32(s)?;
60        Ok(Self { bytes, base32: s.to_uppercase().replace(" ", "") })
61    }
62
63    /// 获取原始字节
64    pub fn as_bytes(&self) -> &[u8] {
65        &self.bytes
66    }
67
68    /// 获取 Base32 编码
69    pub fn as_base32(&self) -> &str {
70        &self.base32
71    }
72
73    /// 获取密钥长度(字节)
74    pub fn len(&self) -> usize {
75        self.bytes.len()
76    }
77
78    /// 检查密钥是否为空
79    pub fn is_empty(&self) -> bool {
80        self.bytes.is_empty()
81    }
82
83    /// 编码为 Base32
84    fn encode_base32(data: &[u8]) -> String {
85        let mut result = String::new();
86        let mut i = 0;
87        let n = data.len();
88
89        while i < n {
90            let mut word: u64 = 0;
91            let mut bits = 0;
92
93            for j in 0..5 {
94                if i + j < n {
95                    word = (word << 8) | (data[i + j] as u64);
96                    bits += 8;
97                }
98            }
99
100            i += 5;
101
102            while bits >= 5 {
103                bits -= 5;
104                let index = ((word >> bits) & 0x1F) as usize;
105                result.push(BASE32_CHARS[index] as char);
106            }
107
108            if bits > 0 {
109                let index = ((word << (5 - bits)) & 0x1F) as usize;
110                result.push(BASE32_CHARS[index] as char);
111            }
112        }
113
114        result
115    }
116
117    /// 解码 Base32
118    fn decode_base32(s: &str) -> CryptoResult<Vec<u8>> {
119        let s = s.to_uppercase().replace(" ", "").replace("-", "");
120        let mut result = Vec::new();
121        let chars: Vec<char> = s.chars().collect();
122
123        let mut i = 0;
124        while i < chars.len() {
125            let mut word: u64 = 0;
126            let mut bits = 0;
127
128            for j in 0..8 {
129                if i + j < chars.len() {
130                    let val = Self::base32_char_to_value(chars[i + j])?;
131                    word = (word << 5) | (val as u64);
132                    bits += 5;
133                }
134            }
135
136            i += 8;
137
138            while bits >= 8 {
139                bits -= 8;
140                result.push(((word >> bits) & 0xFF) as u8);
141            }
142        }
143
144        Ok(result)
145    }
146
147    /// Base32 字符转值
148    fn base32_char_to_value(c: char) -> CryptoResult<u8> {
149        match c {
150            'A'..='Z' => Ok((c as u8) - b'A'),
151            '2'..='7' => Ok((c as u8) - b'2' + 26),
152            _ => Err(CryptoError::Base32Error(format!("Invalid character: {}", c))),
153        }
154    }
155}
156
157impl std::fmt::Display for TotpSecret {
158    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
159        write!(f, "{}", self.base32)
160    }
161}
162
/// Output formats available when rendering a secret as text.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SecretFormat {
    /// Base32 without separators.
    Base32,
    /// Base32 split into space-separated groups of 4 characters.
    Base32Spaced,
    /// Raw bytes rendered as hexadecimal.
    Raw,
    /// Base64.
    Base64,
}
175
176impl TotpSecret {
177    /// 格式化密钥
178    pub fn format(&self, format: SecretFormat) -> String {
179        match format {
180            SecretFormat::Base32 => self.base32.clone(),
181            SecretFormat::Base32Spaced => self
182                .base32
183                .as_bytes()
184                .chunks(4)
185                .map(|chunk| std::str::from_utf8(chunk).unwrap_or(""))
186                .collect::<Vec<_>>()
187                .join(" "),
188            SecretFormat::Raw => self.bytes.iter().map(|b| format!("{:02x}", b)).collect(),
189            SecretFormat::Base64 => base64_encode(&self.bytes),
190        }
191    }
192}
193
/// Dynamic truncation: extracts a 31-bit value from an HMAC digest and
/// reduces it to at most `digits` decimal digits.
///
/// The low 4 bits of the last digest byte select an offset; the 4 bytes
/// starting there are read big-endian with the top bit cleared, and the
/// result is taken modulo `10^digits`.
///
/// When `10^digits` does not fit in a `u32` (`digits >= 10`) the 31-bit value
/// is returned unreduced: it is always smaller than `10^10`, so the modulo
/// would be a no-op. (Previously this case overflowed in `10u32.pow`.)
///
/// # Panics
///
/// Panics if `hmac_result` is empty or shorter than `offset + 4` bytes.
fn dynamic_truncate(hmac_result: &[u8], digits: u32) -> u32 {
    let last = hmac_result.last().expect("HMAC output must not be empty");
    let offset = usize::from(last & 0x0F);

    let window = &hmac_result[offset..offset + 4];
    // Clear the most significant bit so the value is a non-negative 31-bit integer.
    let binary = u32::from_be_bytes([window[0] & 0x7F, window[1], window[2], window[3]]);

    match 10u32.checked_pow(digits) {
        Some(modulus) => binary % modulus,
        None => binary,
    }
}
205
206/// 计算 HMAC
207fn compute_hmac(algorithm: TotpAlgorithm, key: &[u8], counter: &[u8]) -> CryptoResult<Vec<u8>> {
208    let hmac_alg = match algorithm {
209        TotpAlgorithm::SHA1 => HmacAlgorithm::SHA1,
210        TotpAlgorithm::SHA256 => HmacAlgorithm::SHA256,
211        TotpAlgorithm::SHA512 => HmacAlgorithm::SHA512,
212    };
213    hmac_sign(hmac_alg, key, counter)
214}
215
216/// 生成 HOTP
217pub fn generate_hotp(secret: &[u8], counter: u64, digits: u32, algorithm: TotpAlgorithm) -> CryptoResult<String> {
218    let counter_bytes = counter.to_be_bytes();
219    let hmac_result = compute_hmac(algorithm, secret, &counter_bytes)?;
220    let code = dynamic_truncate(&hmac_result, digits);
221    Ok(format!("{:0width$}", code, width = digits as usize))
222}
223
224/// 生成 TOTP
225pub fn generate_totp(
226    secret: &[u8],
227    timestamp: u64,
228    time_step: u64,
229    digits: u32,
230    algorithm: TotpAlgorithm,
231) -> CryptoResult<String> {
232    let counter = timestamp / time_step;
233    generate_hotp(secret, counter, digits, algorithm)
234}
235
236/// 验证 TOTP 码
237pub fn verify_totp(
238    secret: &[u8],
239    code: &str,
240    timestamp: u64,
241    time_step: u64,
242    digits: u32,
243    algorithm: TotpAlgorithm,
244    window: u32,
245) -> CryptoResult<bool> {
246    let current_counter = timestamp / time_step;
247
248    for i in -(window as i64)..=(window as i64) {
249        let counter = (current_counter as i64 + i) as u64;
250        let expected = generate_hotp(secret, counter, digits, algorithm)?;
251
252        if constant_time_compare(code, &expected) {
253            return Ok(true);
254        }
255    }
256
257    Ok(false)
258}
259
260/// 验证 HOTP 码
261pub fn verify_hotp(secret: &[u8], code: &str, counter: u64, digits: u32, algorithm: TotpAlgorithm) -> CryptoResult<bool> {
262    let expected = generate_hotp(secret, counter, digits, algorithm)?;
263    Ok(constant_time_compare(code, &expected))
264}
265
/// Compares two strings byte by byte without stopping at the first mismatch.
///
/// Inputs of different length are rejected up front; for equal lengths every
/// byte pair is XOR-ed into an accumulator and the result is checked once at
/// the end, so the loop does the same work wherever the inputs differ.
fn constant_time_compare(a: &str, b: &str) -> bool {
    let (lhs, rhs) = (a.as_bytes(), b.as_bytes());

    if lhs.len() != rhs.len() {
        return false;
    }

    let difference = lhs.iter().zip(rhs).fold(0u8, |acc, (x, y)| acc | (x ^ y));
    difference == 0
}
282
/// Returns the index of the time step containing `timestamp`,
/// i.e. `timestamp / time_step` rounded down.
///
/// # Panics
///
/// Panics if `time_step` is 0 (division by zero).
pub fn get_time_step(timestamp: u64, time_step: u64) -> u64 {
    timestamp.div_euclid(time_step)
}
287
/// Returns how many seconds remain until the current time step ends.
///
/// The result is in `1..=time_step`: exactly on a step boundary the full
/// `time_step` is returned.
///
/// # Panics
///
/// Panics if `time_step` is 0 (remainder by zero).
pub fn get_remaining_seconds(timestamp: u64, time_step: u64) -> u64 {
    let elapsed_in_step = timestamp % time_step;
    time_step - elapsed_in_step
}