pylon_plugin/builtin/
totp.rs1use std::collections::HashMap;
2use std::sync::Mutex;
3use std::time::{SystemTime, UNIX_EPOCH};
4
5use hmac::{Hmac, Mac};
6use sha1::Sha1;
7
8use crate::Plugin;
9
/// HMAC instantiated with SHA-1, the PRF used by `generate_totp_at`.
type HmacSha1 = Hmac<Sha1>;
11
/// A single user's TOTP enrollment record.
#[derive(Clone)]
pub struct TotpEnrollment {
    /// Id of the enrolled user.
    pub user_id: String,
    /// Shared secret handed to the user at enrollment (base32 text when
    /// produced by this module's secret generator).
    pub secret: String,
    /// Set once a code has been successfully verified for this enrollment.
    pub verified: bool,
    /// Time-step counter (Unix seconds / 30) of the most recently accepted
    /// code; used to reject replayed codes. `None` until a code is accepted.
    pub last_accepted_counter: Option<u64>,
}
38
39pub struct TotpPlugin {
42 enrollments: Mutex<HashMap<String, TotpEnrollment>>,
43 pub enforce: bool,
45 pub protected_actions: Vec<String>,
47}
48
49impl TotpPlugin {
50 pub fn new() -> Self {
51 Self {
52 enrollments: Mutex::new(HashMap::new()),
53 enforce: false,
54 protected_actions: vec![],
55 }
56 }
57
58 pub fn enforced(protected_actions: Vec<String>) -> Self {
59 Self {
60 enrollments: Mutex::new(HashMap::new()),
61 enforce: true,
62 protected_actions,
63 }
64 }
65
66 pub fn enroll(&self, user_id: &str) -> String {
68 let secret = generate_secret();
69 self.enrollments.lock().unwrap().insert(
70 user_id.to_string(),
71 TotpEnrollment {
72 user_id: user_id.to_string(),
73 secret: secret.clone(),
74 verified: false,
75 last_accepted_counter: None,
76 },
77 );
78 secret
79 }
80
81 pub fn verify(&self, user_id: &str, code: &str) -> bool {
88 let now = SystemTime::now()
89 .duration_since(UNIX_EPOCH)
90 .unwrap_or_default()
91 .as_secs();
92 let counter = now / 30;
93
94 let mut enrollments = self.enrollments.lock().unwrap();
95 let enrollment = match enrollments.get_mut(user_id) {
96 Some(e) => e,
97 None => return false,
98 };
99
100 if enrollment.last_accepted_counter == Some(counter) {
104 return false;
105 }
106
107 let expected = generate_totp_at(&enrollment.secret, now);
108 if pylon_auth::constant_time_eq(expected.as_bytes(), code.as_bytes()) {
109 enrollment.verified = true;
110 enrollment.last_accepted_counter = Some(counter);
111 return true;
112 }
113 false
114 }
115
116 pub fn is_verified(&self, user_id: &str) -> bool {
118 self.enrollments
119 .lock()
120 .unwrap()
121 .get(user_id)
122 .map(|e| e.verified)
123 .unwrap_or(false)
124 }
125
126 pub fn is_enrolled(&self, user_id: &str) -> bool {
128 self.enrollments.lock().unwrap().contains_key(user_id)
129 }
130
131 pub fn current_code(&self, user_id: &str) -> Option<String> {
133 let enrollments = self.enrollments.lock().unwrap();
134 let enrollment = enrollments.get(user_id)?;
135 Some(generate_totp(&enrollment.secret))
136 }
137
138 pub fn unenroll(&self, user_id: &str) -> bool {
140 self.enrollments.lock().unwrap().remove(user_id).is_some()
141 }
142}
143
144impl Plugin for TotpPlugin {
145 fn name(&self) -> &str {
146 "totp-2fa"
147 }
148}
149
150fn generate_secret() -> String {
152 use rand::Rng;
153 let mut rng = rand::thread_rng();
154 let chars = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567";
155 (0..16)
156 .map(|_| chars[rng.gen_range(0..32)] as char)
157 .collect()
158}
159
160fn generate_totp(secret: &str) -> String {
163 let ts = SystemTime::now()
164 .duration_since(UNIX_EPOCH)
165 .unwrap_or_default()
166 .as_secs();
167 generate_totp_at(secret, ts)
168}
169
170fn generate_totp_at(secret: &str, unix_secs: u64) -> String {
173 let counter = unix_secs / 30;
174 let counter_bytes = counter.to_be_bytes();
175
176 let mut mac =
177 HmacSha1::new_from_slice(secret.as_bytes()).expect("HMAC can take key of any size");
178 mac.update(&counter_bytes);
179 let result = mac.finalize().into_bytes();
180 let hash = result.as_slice();
181
182 let offset = (hash[hash.len() - 1] & 0x0f) as usize;
184 let binary = ((hash[offset] as u32 & 0x7f) << 24)
185 | ((hash[offset + 1] as u32) << 16)
186 | ((hash[offset + 2] as u32) << 8)
187 | (hash[offset + 3] as u32);
188
189 format!("{:06}", binary % 1_000_000)
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE: tests that call `current_code` and then `verify` read the clock
    // twice. If a 30-second boundary falls between the two reads, the codes
    // differ and the test fails. This is rare but possible.

    #[test]
    fn enroll_and_verify() {
        let plugin = TotpPlugin::new();
        let secret = plugin.enroll("user-1");
        assert!(!secret.is_empty());
        assert!(plugin.is_enrolled("user-1"));
        assert!(!plugin.is_verified("user-1"));

        let code = plugin.current_code("user-1").unwrap();
        assert!(plugin.verify("user-1", &code));
        assert!(plugin.is_verified("user-1"));
    }

    #[test]
    fn wrong_code_rejected() {
        let plugin = TotpPlugin::new();
        plugin.enroll("user-1");
        assert!(!plugin.verify("user-1", "000000"));
        assert!(!plugin.is_verified("user-1"));
    }

    #[test]
    fn verify_unknown_user_is_rejected() {
        let plugin = TotpPlugin::new();
        assert!(!plugin.verify("user-1", "123456"));
        assert!(!plugin.is_enrolled("user-1"));
        assert!(!plugin.is_verified("user-1"));
    }

    #[test]
    fn code_cannot_be_replayed_in_same_window() {
        let plugin = TotpPlugin::new();
        plugin.enroll("user-1");
        let code = plugin.current_code("user-1").unwrap();
        assert!(
            plugin.verify("user-1", &code),
            "first verify should succeed"
        );
        assert!(
            !plugin.verify("user-1", &code),
            "replay within the same window must be rejected"
        );
    }

    #[test]
    fn reenroll_resets_verification() {
        let plugin = TotpPlugin::new();
        plugin.enroll("user-1");
        let code = plugin.current_code("user-1").unwrap();
        assert!(plugin.verify("user-1", &code));

        // Re-enrolling replaces the record: new secret, unverified, and the
        // replay marker cleared, so a fresh code is accepted immediately.
        plugin.enroll("user-1");
        assert!(plugin.is_enrolled("user-1"));
        assert!(!plugin.is_verified("user-1"));
        let new_code = plugin.current_code("user-1").unwrap();
        assert!(plugin.verify("user-1", &new_code));
    }

    #[test]
    fn not_enrolled_returns_none() {
        let plugin = TotpPlugin::new();
        assert!(plugin.current_code("user-1").is_none());
        assert!(!plugin.is_enrolled("user-1"));
    }

    #[test]
    fn unenroll() {
        let plugin = TotpPlugin::new();
        plugin.enroll("user-1");
        assert!(plugin.unenroll("user-1"));
        assert!(!plugin.is_enrolled("user-1"));
        assert!(plugin.current_code("user-1").is_none());
        assert!(!plugin.unenroll("user-1"));
    }

    #[test]
    fn code_is_six_digits() {
        let plugin = TotpPlugin::new();
        plugin.enroll("user-1");
        let code = plugin.current_code("user-1").unwrap();
        assert_eq!(code.len(), 6);
        assert!(code.chars().all(|c| c.is_ascii_digit()));
    }

    #[test]
    fn different_users_different_secrets() {
        let plugin = TotpPlugin::new();
        let s1 = plugin.enroll("user-1");
        let s2 = plugin.enroll("user-2");
        assert!(!s1.is_empty());
        assert!(!s2.is_empty());
        // Independently generated random secrets; a collision is vanishingly
        // unlikely.
        assert_ne!(s1, s2);
    }

    #[test]
    fn generate_totp_at_is_deterministic() {
        let code1 = generate_totp_at("JBSWY3DPEHPK3PXP", 1_700_000_000);
        let code2 = generate_totp_at("JBSWY3DPEHPK3PXP", 1_700_000_000);
        assert_eq!(code1, code2);
        assert_eq!(code1.len(), 6);
        assert!(code1.chars().all(|c| c.is_ascii_digit()));
    }

    #[test]
    fn generate_totp_at_different_times_differ() {
        let code1 = generate_totp_at("JBSWY3DPEHPK3PXP", 1_700_000_000);
        let code2 = generate_totp_at("JBSWY3DPEHPK3PXP", 1_700_000_030);
        assert_ne!(code1, code2);
    }

    #[test]
    fn generate_totp_at_same_window_equal() {
        let code1 = generate_totp_at("SECRET", 1_700_000_000);
        let code2 = generate_totp_at("SECRET", 1_700_000_005);
        assert_eq!(code1, code2);
    }

    #[test]
    fn generate_totp_at_different_secrets_differ() {
        let code1 = generate_totp_at("SECRET_A", 1_700_000_000);
        let code2 = generate_totp_at("SECRET_B", 1_700_000_000);
        assert_ne!(code1, code2);
    }

    #[test]
    fn generate_secret_is_16_chars_base32() {
        let s = generate_secret();
        assert_eq!(s.len(), 16);
        assert!(s
            .chars()
            .all(|c| "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567".contains(c)));
    }
}