Skip to main content

wae_authentication/totp/
secret.rs

1//! TOTP 密钥管理
2
3use base64::{Engine, engine::general_purpose::STANDARD as BASE64_STANDARD};
4use rand::Rng;
5use wae_types::{WaeError, WaeErrorKind};
6
/// Result type used throughout the TOTP module; failures are reported as [`WaeError`].
pub type TotpResult<T> = Result<T, WaeError>;
9
/// Base32 alphabet (RFC 4648 standard alphabet: `A-Z` followed by `2-7`).
/// A 5-bit value (0..=31) is used directly as an index into this table.
const BASE32_CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567";
12
/// A TOTP shared secret.
///
/// Holds the raw key bytes together with their Base32 text form.
///
/// The `Debug` implementation deliberately redacts the key material (it only
/// reports the length), so the secret cannot leak into logs through `{:?}`
/// formatting. Use [`TotpSecret::as_base32`] or `Display` when the secret
/// genuinely needs to be shown, e.g. during enrollment.
#[derive(Clone)]
pub struct TotpSecret {
    /// Raw secret bytes.
    bytes: Vec<u8>,

    /// Base32 text form of the secret, as returned by `as_base32`.
    base32: String,
}

impl std::fmt::Debug for TotpSecret {
    /// Prints the type name and key length only; never the key itself.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TotpSecret")
            .field("len", &self.bytes.len())
            .finish_non_exhaustive()
    }
}
22
23impl TotpSecret {
24    /// 生成新的随机密钥
25    ///
26    /// # Arguments
27    /// * `length` - 密钥长度 (字节),推荐 20 字节 (160 位)
28    pub fn generate(length: usize) -> TotpResult<Self> {
29        let mut bytes = vec![0u8; length];
30        rand::rng().fill_bytes(&mut bytes);
31        Self::from_bytes(&bytes)
32    }
33
34    /// 生成推荐长度的密钥 (20 字节)
35    pub fn generate_default() -> TotpResult<Self> {
36        Self::generate(20)
37    }
38
39    /// 从字节创建密钥
40    pub fn from_bytes(bytes: &[u8]) -> TotpResult<Self> {
41        let base32 = Self::encode_base32(bytes);
42        Ok(Self { bytes: bytes.to_vec(), base32 })
43    }
44
45    /// 从 Base32 字符串创建密钥
46    pub fn from_base32(s: &str) -> TotpResult<Self> {
47        let bytes = Self::decode_base32(s)?;
48        Ok(Self { bytes, base32: s.to_uppercase().replace(" ", "") })
49    }
50
51    /// 获取原始字节
52    pub fn as_bytes(&self) -> &[u8] {
53        &self.bytes
54    }
55
56    /// 获取 Base32 编码
57    pub fn as_base32(&self) -> &str {
58        &self.base32
59    }
60
61    /// 获取密钥长度 (字节)
62    pub fn len(&self) -> usize {
63        self.bytes.len()
64    }
65
66    /// 检查密钥是否为空
67    pub fn is_empty(&self) -> bool {
68        self.bytes.is_empty()
69    }
70
71    /// 编码为 Base32
72    fn encode_base32(data: &[u8]) -> String {
73        let mut result = String::new();
74        let mut i = 0;
75        let n = data.len();
76
77        while i < n {
78            let mut word: u64 = 0;
79            let mut bits = 0;
80
81            for j in 0..5 {
82                if i + j < n {
83                    word = (word << 8) | (data[i + j] as u64);
84                    bits += 8;
85                }
86            }
87
88            i += 5;
89
90            while bits >= 5 {
91                bits -= 5;
92                let index = ((word >> bits) & 0x1F) as usize;
93                result.push(BASE32_CHARS[index] as char);
94            }
95
96            if bits > 0 {
97                let index = ((word << (5 - bits)) & 0x1F) as usize;
98                result.push(BASE32_CHARS[index] as char);
99            }
100        }
101
102        result
103    }
104
105    /// 解码 Base32
106    fn decode_base32(s: &str) -> TotpResult<Vec<u8>> {
107        let s = s.to_uppercase().replace(" ", "").replace("-", "");
108        let mut result = Vec::new();
109        let chars: Vec<char> = s.chars().collect();
110
111        let mut i = 0;
112        while i < chars.len() {
113            let mut word: u64 = 0;
114            let mut bits = 0;
115
116            for j in 0..8 {
117                if i + j < chars.len() {
118                    let val = Self::base32_char_to_value(chars[i + j])?;
119                    word = (word << 5) | (val as u64);
120                    bits += 5;
121                }
122            }
123
124            i += 8;
125
126            while bits >= 8 {
127                bits -= 8;
128                result.push(((word >> bits) & 0xFF) as u8);
129            }
130        }
131
132        Ok(result)
133    }
134
135    /// Base32 字符转值
136    fn base32_char_to_value(c: char) -> TotpResult<u8> {
137        match c {
138            'A'..='Z' => Ok((c as u8) - b'A'),
139            '2'..='7' => Ok((c as u8) - b'2' + 26),
140            _ => Err(WaeError::new(WaeErrorKind::Base32Error { reason: format!("Invalid character: {}", c) })),
141        }
142    }
143}
144
145impl std::fmt::Display for TotpSecret {
146    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147        write!(f, "{}", self.base32)
148    }
149}
150
/// Output formats accepted by `TotpSecret::format`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SecretFormat {
    /// Base32 text with no separators.
    Base32,
    /// Base32 text split into groups of 4 characters separated by single spaces.
    Base32Spaced,
    /// The raw key bytes rendered as lowercase hexadecimal (two digits per byte).
    Raw,
    /// The raw key bytes encoded as standard Base64.
    Base64,
}
163
164impl TotpSecret {
165    /// 格式化密钥
166    pub fn format(&self, format: SecretFormat) -> String {
167        match format {
168            SecretFormat::Base32 => self.base32.clone(),
169            SecretFormat::Base32Spaced => self
170                .base32
171                .as_bytes()
172                .chunks(4)
173                .map(|chunk| std::str::from_utf8(chunk).unwrap_or(""))
174                .collect::<Vec<_>>()
175                .join(" "),
176            SecretFormat::Raw => self.bytes.iter().map(|b| format!("{:02x}", b)).collect(),
177            SecretFormat::Base64 => BASE64_STANDARD.encode(&self.bytes),
178        }
179    }
180}