1use constant_time_eq::constant_time_eq;
19use hmac::{Hmac, Mac};
20use rand::RngCore;
21use std::time::SystemTime;
22
/// Errors returned by the fallible operations in this module.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Error {
    /// Computing the current Unix time failed; holds the underlying error message.
    SystemTimeError(String),
    /// A secret string could not be decoded as base32.
    Base32DecodeError,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            Self::SystemTimeError(reason) => write!(f, "System time error: {}", reason),
            Self::Base32DecodeError => f.write_str("Base32 decode error"),
        }
    }
}

// No `source()` override: the wrapped system-time error is stored only as text.
impl std::error::Error for Error {}

/// Shorthand for results whose error type is this module's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
45
46fn system_time() -> Result<u64> {
47 SystemTime::now()
48 .duration_since(SystemTime::UNIX_EPOCH)
49 .map(|d| d.as_secs())
50 .map_err(|e| Error::SystemTimeError(e.to_string()))
51}
52
/// HMAC variant used to sign the time-step counter.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Algorithm {
    /// HMAC with SHA-1.
    HmacSha1,
    /// HMAC with SHA-256.
    HmacSha256,
    /// HMAC with SHA-512.
    HmacSha512,
}

/// [`Algorithm::HmacSha1`] is the default variant.
impl Default for Algorithm {
    fn default() -> Self {
        Self::HmacSha1
    }
}
69
70impl Algorithm {
71 fn digest<M: Mac>(mut mac: M, data: &[u8]) -> Vec<u8> {
72 mac.update(data);
73 mac.finalize().into_bytes().to_vec()
74 }
75
76 fn sign(&self, key: &[u8], data: &[u8]) -> Vec<u8> {
77 match self {
78 Algorithm::HmacSha1 => {
79 Algorithm::digest(Hmac::<sha1::Sha1>::new_from_slice(key).unwrap(), data)
80 }
81 Algorithm::HmacSha256 => {
82 Algorithm::digest(Hmac::<sha2::Sha256>::new_from_slice(key).unwrap(), data)
83 }
84 Algorithm::HmacSha512 => {
85 Algorithm::digest(Hmac::<sha2::Sha512>::new_from_slice(key).unwrap(), data)
86 }
87 }
88 }
89}
90
91#[derive(Debug, Clone)]
93pub struct Secret {
94 key: Vec<u8>,
96}
97
98impl PartialEq for Secret {
99 fn eq(&self, other: &Self) -> bool {
100 constant_time_eq(&self.key, &other.key)
101 }
102}
103
104impl Eq for Secret {}
105
106impl Secret {
107 pub fn new() -> Self {
109 let mut key = vec![0u8; 16];
110 rand::thread_rng().fill_bytes(&mut key);
111 Self { key }
112 }
113
114 pub fn from_base32(key: &str) -> Result<Self> {
116 let key = base32::decode(base32::Alphabet::RFC4648 { padding: false }, key);
117 if key.is_none() {
118 return Err(Error::Base32DecodeError);
119 }
120 Ok(Self { key: key.unwrap() })
121 }
122
123 pub fn from_raw(key: &str) -> Self {
125 Self {
126 key: key.as_bytes().to_vec(),
127 }
128 }
129
130 pub fn from_slice(key: &[u8]) -> Self {
132 Self { key: key.to_vec() }
133 }
134
135 pub fn raw(&self) -> String {
137 String::from_utf8(self.key.clone()).unwrap()
138 }
139
140 pub fn base32(&self) -> String {
142 base32::encode(base32::Alphabet::RFC4648 { padding: false }, &self.key)
143 }
144
145 pub fn as_slice(&self) -> &[u8] {
147 &self.key
148 }
149}
150
151#[derive(Debug, Clone, PartialEq, Eq)]
153pub struct TOTP {
154 algorithm: Algorithm,
156 digits: usize,
158 window: u64,
160 time_step: u64,
162 key: Secret,
164}
165
166impl TOTP {
167 pub fn new(
169 algorithm: Algorithm,
170 digits: usize,
171 window: u64,
172 time_step: u64,
173 key: Secret,
174 ) -> Self {
175 Self::new_unchecked(algorithm, digits, window, time_step, key)
176 }
177
178 pub fn new_unchecked(
180 algorithm: Algorithm,
181 digits: usize,
182 window: u64,
183 time_step: u64,
184 key: Secret,
185 ) -> Self {
186 Self {
187 algorithm,
188 digits,
189 window,
190 time_step,
191 key,
192 }
193 }
194
195 pub fn generate(&self, time: u64) -> String {
197 let hash = self.algorithm.sign(
198 self.key.as_slice(),
199 (time / self.time_step).to_be_bytes().as_ref(),
200 );
201 let offset = (hash.last().unwrap() & 0xf) as usize;
202 let result = u32::from_be_bytes(hash[offset..offset + 4].try_into().unwrap()) & 0x7fffffff;
203 format!(
204 "{1:00$}",
205 self.digits,
206 result % 10_u32.pow(self.digits as u32)
207 )
208 }
209
210 pub fn generate_now(&self) -> Result<String> {
212 let time = system_time()?;
213 Ok(self.generate(time))
214 }
215
216 pub fn next_time_step(&self, time: u64) -> u64 {
218 let time_step = time / self.time_step;
219 (time_step + 1) * self.time_step
220 }
221
222 pub fn ttl(&self) -> Result<u64> {
224 let time = system_time()?;
225 let remain = self.time_step - (time % self.time_step);
226 Ok(remain)
227 }
228
229 pub fn verify(&self, token: &str) -> Result<bool> {
231 let time = system_time()?;
232 Ok(self.verify_with_time(token, time))
233 }
234
235 pub fn verify_with_time(&self, token: &str, time: u64) -> bool {
237 let step = time / self.time_step - self.window;
238 for i in 0..self.window * 2 + 1 {
239 let t = (step + i) * self.time_step;
240 let code = self.generate(t);
241 if constant_time_eq(token.as_bytes(), code.as_bytes()) {
242 return true;
243 }
244 }
245 false
246 }
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly generated code must verify against the current time.
    #[test]
    fn test_totp() {
        let secret =
            Secret::from_base32("KLOKQRZDX46UDXHSOUUM32PKUPPKGI7LM4U3G4GEDVUTSNXLTLNA").unwrap();
        let totp = TOTP::new(Algorithm::HmacSha1, 6, 1, 30, secret);
        let code = totp.generate_now().unwrap();
        assert_eq!(code.len(), 6);
        assert!(totp.verify(&code).unwrap());
    }

    /// Known-answer tests from RFC 6238 Appendix B (8 digits, 30-second step).
    ///
    /// The round-trip test above only checks self-consistency; these pin the
    /// output to the published reference values for every algorithm.
    #[test]
    fn test_rfc6238_vectors() {
        // The RFC's reference implementation uses an ASCII seed whose length
        // matches each hash's output size.
        let sha1_key = Secret::from_raw("12345678901234567890");
        let sha256_key = Secret::from_raw("12345678901234567890123456789012");
        let sha512_key = Secret::from_raw(
            "1234567890123456789012345678901234567890123456789012345678901234",
        );
        let cases: [(u64, &str, &str, &str); 6] = [
            (59, "94287082", "46119246", "90693936"),
            (1111111109, "07081804", "68084774", "25091201"),
            (1111111111, "14050471", "67062674", "99943326"),
            (1234567890, "89005924", "91819424", "93441116"),
            (2000000000, "69279037", "90698825", "38618901"),
            (20000000000, "65353130", "77737706", "47863826"),
        ];
        for &(time, sha1_code, sha256_code, sha512_code) in cases.iter() {
            let sha1 = TOTP::new(Algorithm::HmacSha1, 8, 1, 30, sha1_key.clone());
            let sha256 = TOTP::new(Algorithm::HmacSha256, 8, 1, 30, sha256_key.clone());
            let sha512 = TOTP::new(Algorithm::HmacSha512, 8, 1, 30, sha512_key.clone());
            assert_eq!(sha1.generate(time), sha1_code, "SHA1 at {}", time);
            assert_eq!(sha256.generate(time), sha256_code, "SHA256 at {}", time);
            assert_eq!(sha512.generate(time), sha512_code, "SHA512 at {}", time);
        }
    }

    /// Fewer digits keep the low-order digits of the same truncated value, and
    /// verification accepts exactly `window` steps around the current one.
    #[test]
    fn test_digits_and_window() {
        let key = Secret::from_raw("12345678901234567890");
        let six = TOTP::new(Algorithm::HmacSha1, 6, 1, 30, key.clone());
        assert_eq!(six.generate(59), "287082");

        let eight = TOTP::new(Algorithm::HmacSha1, 8, 1, 30, key);
        // Code for counter 1 (times 30..=59).
        assert!(eight.verify_with_time("94287082", 59));
        // Counter 2: counter 1 is within a window of 1.
        assert!(eight.verify_with_time("94287082", 89));
        // Counter 3: counter 1 is outside the window.
        assert!(!eight.verify_with_time("94287082", 119));
        assert!(!eight.verify_with_time("94287082", 1111111109));
        assert!(!eight.verify_with_time("9428708", 59));
    }
}