1use crate::{
4 base64::base64_encode,
5 error::{CryptoError, CryptoResult},
6 hmac::{HmacAlgorithm, hmac_sign},
7};
8
/// Hash function used for the HMAC step of HOTP/TOTP code generation.
///
/// Each variant is mapped one-to-one onto the matching `HmacAlgorithm` variant by
/// `compute_hmac`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TotpAlgorithm {
    /// HMAC-SHA1.
    SHA1,
    /// HMAC-SHA256.
    SHA256,
    /// HMAC-SHA512.
    SHA512,
}
19
20impl Default for TotpAlgorithm {
21 fn default() -> Self {
22 TotpAlgorithm::SHA1
23 }
24}
25
/// Base32 alphabet: index 0-25 is `A`-`Z`, index 26-31 is `2`-`7`.
const BASE32_CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567";
28
/// A shared OTP secret, held both as raw bytes and as a base32 string.
///
/// NOTE(review): the derived `Debug` prints both `bytes` and `base32`, i.e. the secret
/// itself. Avoid `{:?}`-logging values of this type.
#[derive(Debug, Clone)]
pub struct TotpSecret {
    /// Raw secret bytes, exposed through `as_bytes`.
    bytes: Vec<u8>,
    /// Base32 text of the secret, exposed through `as_base32` and `Display`.
    base32: String,
}
35
36impl TotpSecret {
37 pub fn generate(length: usize) -> CryptoResult<Self> {
39 use argon2::password_hash::rand_core::{OsRng, RngCore};
40 let mut rng = OsRng;
41 let mut bytes = vec![0u8; length];
42 rng.fill_bytes(&mut bytes);
43 Self::from_bytes(&bytes)
44 }
45
46 pub fn generate_default() -> CryptoResult<Self> {
48 Self::generate(20)
49 }
50
51 pub fn from_bytes(bytes: &[u8]) -> CryptoResult<Self> {
53 let base32 = Self::encode_base32(bytes);
54 Ok(Self { bytes: bytes.to_vec(), base32 })
55 }
56
57 pub fn from_base32(s: &str) -> CryptoResult<Self> {
59 let bytes = Self::decode_base32(s)?;
60 Ok(Self { bytes, base32: s.to_uppercase().replace(" ", "") })
61 }
62
63 pub fn as_bytes(&self) -> &[u8] {
65 &self.bytes
66 }
67
68 pub fn as_base32(&self) -> &str {
70 &self.base32
71 }
72
73 pub fn len(&self) -> usize {
75 self.bytes.len()
76 }
77
78 pub fn is_empty(&self) -> bool {
80 self.bytes.is_empty()
81 }
82
83 fn encode_base32(data: &[u8]) -> String {
85 let mut result = String::new();
86 let mut i = 0;
87 let n = data.len();
88
89 while i < n {
90 let mut word: u64 = 0;
91 let mut bits = 0;
92
93 for j in 0..5 {
94 if i + j < n {
95 word = (word << 8) | (data[i + j] as u64);
96 bits += 8;
97 }
98 }
99
100 i += 5;
101
102 while bits >= 5 {
103 bits -= 5;
104 let index = ((word >> bits) & 0x1F) as usize;
105 result.push(BASE32_CHARS[index] as char);
106 }
107
108 if bits > 0 {
109 let index = ((word << (5 - bits)) & 0x1F) as usize;
110 result.push(BASE32_CHARS[index] as char);
111 }
112 }
113
114 result
115 }
116
117 fn decode_base32(s: &str) -> CryptoResult<Vec<u8>> {
119 let s = s.to_uppercase().replace(" ", "").replace("-", "");
120 let mut result = Vec::new();
121 let chars: Vec<char> = s.chars().collect();
122
123 let mut i = 0;
124 while i < chars.len() {
125 let mut word: u64 = 0;
126 let mut bits = 0;
127
128 for j in 0..8 {
129 if i + j < chars.len() {
130 let val = Self::base32_char_to_value(chars[i + j])?;
131 word = (word << 5) | (val as u64);
132 bits += 5;
133 }
134 }
135
136 i += 8;
137
138 while bits >= 8 {
139 bits -= 8;
140 result.push(((word >> bits) & 0xFF) as u8);
141 }
142 }
143
144 Ok(result)
145 }
146
147 fn base32_char_to_value(c: char) -> CryptoResult<u8> {
149 match c {
150 'A'..='Z' => Ok((c as u8) - b'A'),
151 '2'..='7' => Ok((c as u8) - b'2' + 26),
152 _ => Err(CryptoError::Base32Error(format!("Invalid character: {}", c))),
153 }
154 }
155}
156
157impl std::fmt::Display for TotpSecret {
158 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
159 write!(f, "{}", self.base32)
160 }
161}
162
/// Output representation selected for `TotpSecret::format`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SecretFormat {
    /// The stored base32 string, unchanged.
    Base32,
    /// The base32 string split into groups of four characters joined by single spaces.
    Base32Spaced,
    /// Lowercase hexadecimal, two digits per byte.
    Raw,
    /// The raw bytes encoded with `base64_encode`.
    Base64,
}
175
176impl TotpSecret {
177 pub fn format(&self, format: SecretFormat) -> String {
179 match format {
180 SecretFormat::Base32 => self.base32.clone(),
181 SecretFormat::Base32Spaced => self
182 .base32
183 .as_bytes()
184 .chunks(4)
185 .map(|chunk| std::str::from_utf8(chunk).unwrap_or(""))
186 .collect::<Vec<_>>()
187 .join(" "),
188 SecretFormat::Raw => self.bytes.iter().map(|b| format!("{:02x}", b)).collect(),
189 SecretFormat::Base64 => base64_encode(&self.bytes),
190 }
191 }
192}
193
/// Applies RFC 4226 "dynamic truncation" to an HMAC result and reduces it to `digits`
/// decimal digits.
///
/// Steps:
/// 1. The low nibble of the last byte selects an offset.
/// 2. The four bytes starting at that offset are read big-endian with the top bit
///    cleared, giving a 31-bit value.
/// 3. That value is taken modulo `10^digits`.
///
/// When `10^digits` does not fit in a `u32` (`digits >= 10`), every 31-bit value is
/// already smaller than it, so the value is returned unchanged instead of overflowing.
///
/// # Panics
/// Panics if `hmac_result` is empty or shorter than `offset + 4` bytes. HMAC-SHA1,
/// HMAC-SHA256 and HMAC-SHA512 outputs (20/32/64 bytes) are always long enough.
fn dynamic_truncate(hmac_result: &[u8], digits: u32) -> u32 {
    let last = *hmac_result
        .last()
        .expect("HMAC output must not be empty");
    let offset = usize::from(last & 0x0F);
    let window = hmac_result
        .get(offset..offset + 4)
        .expect("HMAC output too short for dynamic truncation");

    // Clear the sign bit so the result is identical on signed/unsigned implementations.
    let binary = u32::from_be_bytes([window[0], window[1], window[2], window[3]]) & 0x7FFF_FFFF;

    match 10u32.checked_pow(digits) {
        Some(modulus) => binary % modulus,
        None => binary,
    }
}
205
206fn compute_hmac(algorithm: TotpAlgorithm, key: &[u8], counter: &[u8]) -> CryptoResult<Vec<u8>> {
208 let hmac_alg = match algorithm {
209 TotpAlgorithm::SHA1 => HmacAlgorithm::SHA1,
210 TotpAlgorithm::SHA256 => HmacAlgorithm::SHA256,
211 TotpAlgorithm::SHA512 => HmacAlgorithm::SHA512,
212 };
213 hmac_sign(hmac_alg, key, counter)
214}
215
216pub fn generate_hotp(secret: &[u8], counter: u64, digits: u32, algorithm: TotpAlgorithm) -> CryptoResult<String> {
218 let counter_bytes = counter.to_be_bytes();
219 let hmac_result = compute_hmac(algorithm, secret, &counter_bytes)?;
220 let code = dynamic_truncate(&hmac_result, digits);
221 Ok(format!("{:0width$}", code, width = digits as usize))
222}
223
224pub fn generate_totp(
226 secret: &[u8],
227 timestamp: u64,
228 time_step: u64,
229 digits: u32,
230 algorithm: TotpAlgorithm,
231) -> CryptoResult<String> {
232 let counter = timestamp / time_step;
233 generate_hotp(secret, counter, digits, algorithm)
234}
235
236pub fn verify_totp(
238 secret: &[u8],
239 code: &str,
240 timestamp: u64,
241 time_step: u64,
242 digits: u32,
243 algorithm: TotpAlgorithm,
244 window: u32,
245) -> CryptoResult<bool> {
246 let current_counter = timestamp / time_step;
247
248 for i in -(window as i64)..=(window as i64) {
249 let counter = (current_counter as i64 + i) as u64;
250 let expected = generate_hotp(secret, counter, digits, algorithm)?;
251
252 if constant_time_compare(code, &expected) {
253 return Ok(true);
254 }
255 }
256
257 Ok(false)
258}
259
260pub fn verify_hotp(secret: &[u8], code: &str, counter: u64, digits: u32, algorithm: TotpAlgorithm) -> CryptoResult<bool> {
262 let expected = generate_hotp(secret, counter, digits, algorithm)?;
263 Ok(constant_time_compare(code, &expected))
264}
265
/// Compares two strings byte-by-byte without exiting early on the first mismatch.
///
/// Strings of different lengths are rejected immediately, so only the length is
/// observable through timing.
fn constant_time_compare(a: &str, b: &str) -> bool {
    if a.len() == b.len() {
        let diff = a
            .bytes()
            .zip(b.bytes())
            .fold(0u8, |acc, (x, y)| acc | (x ^ y));
        diff == 0
    } else {
        false
    }
}
282
/// Returns the index of the time step containing `timestamp`, i.e. `timestamp / time_step`.
///
/// # Panics
/// Panics if `time_step` is zero.
pub fn get_time_step(timestamp: u64, time_step: u64) -> u64 {
    timestamp.div_euclid(time_step)
}
287
/// Returns the number of seconds until the current time step ends.
///
/// Exactly on a step boundary this is the full `time_step`, never 0.
///
/// # Panics
/// Panics if `time_step` is zero.
pub fn get_remaining_seconds(timestamp: u64, time_step: u64) -> u64 {
    let elapsed_in_step = timestamp.rem_euclid(time_step);
    time_step - elapsed_in_step
}