Skip to main content

wae_crypto/
totp.rs

1//! TOTP/HOTP 模块
2
3use crate::{
4    base64::base64_encode,
5    error::{CryptoError, CryptoResult},
6    hmac::{HmacAlgorithm, hmac_sign},
7};
8use rand::Rng;
9
/// Hash function used inside the HMAC step of HOTP/TOTP code generation.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum TotpAlgorithm {
    /// HMAC-SHA-1.
    SHA1,
    /// HMAC-SHA-256.
    SHA256,
    /// HMAC-SHA-512.
    SHA512,
}
20
21impl Default for TotpAlgorithm {
22    fn default() -> Self {
23        TotpAlgorithm::SHA1
24    }
25}
26
/// RFC 4648 Base32 alphabet: `A`–`Z` followed by `2`–`7`.
///
/// Typed as a fixed-size array so the compiler guarantees exactly 32 symbols, which keeps
/// every `value & 0x1F` lookup into it in bounds.
const BASE32_CHARS: &[u8; 32] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567";
29
/// A TOTP/HOTP shared secret, held both as raw key bytes and as Base32 text.
///
/// `Debug` is implemented by hand and deliberately omits the key material, so
/// accidentally logging a `TotpSecret` with `{:?}` does not leak it. Use
/// [`TotpSecret::as_base32`] when the secret must be shown on purpose.
#[derive(Clone)]
pub struct TotpSecret {
    /// Raw key bytes used as the HMAC key.
    bytes: Vec<u8>,
    /// Base32 text form of the key.
    base32: String,
}

impl std::fmt::Debug for TotpSecret {
    /// Prints only the key length, never the key itself.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TotpSecret")
            .field("len", &self.bytes.len())
            .finish_non_exhaustive()
    }
}
36
37impl TotpSecret {
38    /// 生成新的随机密钥
39    pub fn generate(length: usize) -> CryptoResult<Self> {
40        let mut bytes = vec![0u8; length];
41        rand::rng().fill_bytes(&mut bytes);
42        Self::from_bytes(&bytes)
43    }
44
45    /// 生成推荐长度的密钥(20 字节)
46    pub fn generate_default() -> CryptoResult<Self> {
47        Self::generate(20)
48    }
49
50    /// 从字节创建密钥
51    pub fn from_bytes(bytes: &[u8]) -> CryptoResult<Self> {
52        let base32 = Self::encode_base32(bytes);
53        Ok(Self { bytes: bytes.to_vec(), base32 })
54    }
55
56    /// 从 Base32 字符串创建密钥
57    pub fn from_base32(s: &str) -> CryptoResult<Self> {
58        let bytes = Self::decode_base32(s)?;
59        Ok(Self { bytes, base32: s.to_uppercase().replace(" ", "") })
60    }
61
62    /// 获取原始字节
63    pub fn as_bytes(&self) -> &[u8] {
64        &self.bytes
65    }
66
67    /// 获取 Base32 编码
68    pub fn as_base32(&self) -> &str {
69        &self.base32
70    }
71
72    /// 获取密钥长度(字节)
73    pub fn len(&self) -> usize {
74        self.bytes.len()
75    }
76
77    /// 检查密钥是否为空
78    pub fn is_empty(&self) -> bool {
79        self.bytes.is_empty()
80    }
81
82    /// 编码为 Base32
83    fn encode_base32(data: &[u8]) -> String {
84        let mut result = String::new();
85        let mut i = 0;
86        let n = data.len();
87
88        while i < n {
89            let mut word: u64 = 0;
90            let mut bits = 0;
91
92            for j in 0..5 {
93                if i + j < n {
94                    word = (word << 8) | (data[i + j] as u64);
95                    bits += 8;
96                }
97            }
98
99            i += 5;
100
101            while bits >= 5 {
102                bits -= 5;
103                let index = ((word >> bits) & 0x1F) as usize;
104                result.push(BASE32_CHARS[index] as char);
105            }
106
107            if bits > 0 {
108                let index = ((word << (5 - bits)) & 0x1F) as usize;
109                result.push(BASE32_CHARS[index] as char);
110            }
111        }
112
113        result
114    }
115
116    /// 解码 Base32
117    fn decode_base32(s: &str) -> CryptoResult<Vec<u8>> {
118        let s = s.to_uppercase().replace(" ", "").replace("-", "");
119        let mut result = Vec::new();
120        let chars: Vec<char> = s.chars().collect();
121
122        let mut i = 0;
123        while i < chars.len() {
124            let mut word: u64 = 0;
125            let mut bits = 0;
126
127            for j in 0..8 {
128                if i + j < chars.len() {
129                    let val = Self::base32_char_to_value(chars[i + j])?;
130                    word = (word << 5) | (val as u64);
131                    bits += 5;
132                }
133            }
134
135            i += 8;
136
137            while bits >= 8 {
138                bits -= 8;
139                result.push(((word >> bits) & 0xFF) as u8);
140            }
141        }
142
143        Ok(result)
144    }
145
146    /// Base32 字符转值
147    fn base32_char_to_value(c: char) -> CryptoResult<u8> {
148        match c {
149            'A'..='Z' => Ok((c as u8) - b'A'),
150            '2'..='7' => Ok((c as u8) - b'2' + 26),
151            _ => Err(CryptoError::Base32Error(format!("Invalid character: {}", c))),
152        }
153    }
154}
155
156impl std::fmt::Display for TotpSecret {
157    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
158        write!(f, "{}", self.base32)
159    }
160}
161
/// Output formats accepted by `TotpSecret::format`.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum SecretFormat {
    /// Base32 text with no separators.
    Base32,
    /// Base32 text split into groups of 4 characters separated by single spaces.
    Base32Spaced,
    /// Raw key bytes as lowercase hexadecimal.
    Raw,
    /// Raw key bytes encoded as Base64.
    Base64,
}
174
175impl TotpSecret {
176    /// 格式化密钥
177    pub fn format(&self, format: SecretFormat) -> String {
178        match format {
179            SecretFormat::Base32 => self.base32.clone(),
180            SecretFormat::Base32Spaced => self
181                .base32
182                .as_bytes()
183                .chunks(4)
184                .map(|chunk| std::str::from_utf8(chunk).unwrap_or(""))
185                .collect::<Vec<_>>()
186                .join(" "),
187            SecretFormat::Raw => self.bytes.iter().map(|b| format!("{:02x}", b)).collect(),
188            SecretFormat::Base64 => base64_encode(&self.bytes),
189        }
190    }
191}
192
/// Dynamic truncation from RFC 4226 §5.3: turns an HMAC output into a `digits`-digit number.
///
/// The low nibble of the last byte selects a 4-byte window. That window is read big-endian
/// with the top bit cleared, giving a 31-bit value, which is then reduced modulo `10^digits`.
///
/// For `digits >= 10`, `10^digits` does not fit in `u32` (the original `pow` overflowed).
/// Since the 31-bit value is always below `10^10`, the modulo would leave it unchanged, so
/// the full value is returned.
///
/// # Panics
///
/// Panics if `hmac_result` is empty or shorter than `offset + 4`. Neither can happen for
/// SHA-1/256/512 HMAC outputs (at least 20 bytes, offset at most 15).
fn dynamic_truncate(hmac_result: &[u8], digits: u32) -> u32 {
    let last = *hmac_result
        .last()
        .expect("HMAC output is never empty");
    let offset = usize::from(last & 0x0F);
    let window = hmac_result
        .get(offset..offset + 4)
        .expect("HMAC output must contain offset + 4 bytes");
    // Clear the sign bit so the value is the same on every platform (RFC 4226 §5.3).
    let binary = u32::from_be_bytes([window[0], window[1], window[2], window[3]]) & 0x7FFF_FFFF;

    match 10u32.checked_pow(digits) {
        Some(modulus) => binary % modulus,
        None => binary,
    }
}
204
205/// 计算 HMAC
206fn compute_hmac(algorithm: TotpAlgorithm, key: &[u8], counter: &[u8]) -> CryptoResult<Vec<u8>> {
207    let hmac_alg = match algorithm {
208        TotpAlgorithm::SHA1 => HmacAlgorithm::SHA1,
209        TotpAlgorithm::SHA256 => HmacAlgorithm::SHA256,
210        TotpAlgorithm::SHA512 => HmacAlgorithm::SHA512,
211    };
212    hmac_sign(hmac_alg, key, counter)
213}
214
215/// 生成 HOTP
216pub fn generate_hotp(secret: &[u8], counter: u64, digits: u32, algorithm: TotpAlgorithm) -> CryptoResult<String> {
217    let counter_bytes = counter.to_be_bytes();
218    let hmac_result = compute_hmac(algorithm, secret, &counter_bytes)?;
219    let code = dynamic_truncate(&hmac_result, digits);
220    Ok(format!("{:0width$}", code, width = digits as usize))
221}
222
223/// 生成 TOTP
224pub fn generate_totp(
225    secret: &[u8],
226    timestamp: u64,
227    time_step: u64,
228    digits: u32,
229    algorithm: TotpAlgorithm,
230) -> CryptoResult<String> {
231    let counter = timestamp / time_step;
232    generate_hotp(secret, counter, digits, algorithm)
233}
234
235/// 验证 TOTP 码
236pub fn verify_totp(
237    secret: &[u8],
238    code: &str,
239    timestamp: u64,
240    time_step: u64,
241    digits: u32,
242    algorithm: TotpAlgorithm,
243    window: u32,
244) -> CryptoResult<bool> {
245    let current_counter = timestamp / time_step;
246
247    for i in -(window as i64)..=(window as i64) {
248        let counter = (current_counter as i64 + i) as u64;
249        let expected = generate_hotp(secret, counter, digits, algorithm)?;
250
251        if constant_time_compare(code, &expected) {
252            return Ok(true);
253        }
254    }
255
256    Ok(false)
257}
258
259/// 验证 HOTP 码
260pub fn verify_hotp(secret: &[u8], code: &str, counter: u64, digits: u32, algorithm: TotpAlgorithm) -> CryptoResult<bool> {
261    let expected = generate_hotp(secret, counter, digits, algorithm)?;
262    Ok(constant_time_compare(code, &expected))
263}
264
/// Compares two strings without short-circuiting on the first differing byte, to reduce
/// timing side channels.
///
/// Strings of different lengths return `false` immediately, so only the length can leak;
/// presumably acceptable because code lengths are not secret.
fn constant_time_compare(a: &str, b: &str) -> bool {
    let (lhs, rhs) = (a.as_bytes(), b.as_bytes());
    if lhs.len() != rhs.len() {
        return false;
    }

    let diff = lhs
        .iter()
        .zip(rhs.iter())
        .fold(0u8, |acc, (x, y)| acc | (x ^ y));
    diff == 0
}
281
/// Returns the TOTP counter for `timestamp`: the number of whole `time_step` periods elapsed.
///
/// # Panics
///
/// Panics if `time_step` is zero (integer division by zero).
pub fn get_time_step(timestamp: u64, time_step: u64) -> u64 {
    let elapsed_steps = timestamp / time_step;
    elapsed_steps
}
286
/// Returns the seconds left before the current time step ends.
///
/// At an exact step boundary this is the full `time_step`, never zero.
///
/// # Panics
///
/// Panics if `time_step` is zero (integer remainder by zero).
pub fn get_remaining_seconds(timestamp: u64, time_step: u64) -> u64 {
    let elapsed_in_step = timestamp % time_step;
    time_step - elapsed_in_step
}