geronimo_captcha/
registry.rs1use crate::utils::get_timestamp;
2
3use dashmap::DashMap;
4use std::fmt;
5use std::sync::Mutex;
6
7pub trait ChallengeRegistry: Send + Sync {
11 fn register(&self, id: &str);
12 fn check(&self, id: &str) -> RegistryCheckResult;
13 fn verify(&self, id: &str);
14 fn note_attempt(&self, id: &str, success: bool);
15}
16
/// Outcome of `ChallengeRegistry::check`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RegistryCheckResult {
    /// The challenge is registered, not yet verified, within its attempt budget
    /// and not older than the TTL.
    Ok,
    /// The challenge has already been verified.
    AlreadyVerified,
    /// The id is unknown, or its entry is older than the TTL.
    NotRegistered,
    /// The recorded failed attempts reached the configured maximum.
    MaxAttemptsLimitExceeded,
}

impl fmt::Display for RegistryCheckResult {
    /// Writes the variant name upper-cased with no separators
    /// (e.g. `MAXATTEMPTSLIMITEXCEEDED`).
    ///
    /// The names are static strings, so formatting does not allocate an
    /// intermediate `String` the way `format!("{self:?}").to_uppercase()` did.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let name = match self {
            Self::Ok => "OK",
            Self::AlreadyVerified => "ALREADYVERIFIED",
            Self::NotRegistered => "NOTREGISTERED",
            Self::MaxAttemptsLimitExceeded => "MAXATTEMPTSLIMITEXCEEDED",
        };
        f.write_str(name)
    }
}
30
/// Per-challenge state kept by the in-memory registry.
struct ChallengeStatus {
    /// Value of `get_timestamp()` when the challenge was registered.
    timestamp: u64,
    /// Number of failed answer attempts recorded so far.
    attempts_count: u16,
    /// Set once `verify` has been called for this challenge.
    verified: bool,
}
36
/// Timing wheel that schedules lazy expiry checks for registered challenge ids.
struct Wheel {
    /// Number of buckets; the registry constructor allocates exactly this many.
    len: usize,
    /// Cursor: the bucket most recently reached while advancing the wheel.
    pos: usize,
    /// Timestamp the wheel was last advanced to.
    last_tick: u64,
    /// Challenge ids awaiting an expiry check, one bucket per time unit.
    buckets: Vec<Vec<String>>,
}
43
44pub struct ChallengeInMemoryRegistry {
45 cache: DashMap<String, ChallengeStatus>,
46 max_attempts: u16,
47 ttl: u64,
48 wheel: Mutex<Wheel>,
49}
50
51impl ChallengeInMemoryRegistry {
52 pub fn new(ttl: u64, max_attempts: u16) -> Self {
53 let now = get_timestamp();
54 let len = ttl.max(1) as usize;
55 let pos = (now as usize) % len;
56 let wheel = Wheel {
57 buckets: vec![Vec::new(); len],
58 pos,
59 last_tick: now,
60 len,
61 };
62
63 Self {
64 cache: DashMap::new(),
65 max_attempts,
66 ttl,
67 wheel: Mutex::new(wheel),
68 }
69 }
70
71 fn advance_wheel(&self, now: u64) {
72 let mut w = self.wheel.lock().unwrap();
73 if now <= w.last_tick {
74 return;
75 }
76
77 let steps = ((now - w.last_tick) as usize).min(w.len);
78 for _ in 0..steps {
79 w.pos = (w.pos + 1) % w.len;
80
81 let pos = w.pos;
82 let expired_ids = std::mem::take(&mut w.buckets[pos]);
83
84 for id in expired_ids {
85 if let Some(cs_ref) = self.cache.get(&id) {
86 let expired = now.saturating_sub(cs_ref.timestamp) >= self.ttl;
87 drop(cs_ref);
88 if expired {
89 let _ = self.cache.remove(&id);
90 }
91 } else {
92 }
94 }
95 }
96
97 w.last_tick = now;
98 }
99
100 fn schedule_expiry(&self, id: &str, now: u64) {
101 let mut w = self.wheel.lock().unwrap();
102
103 let target = (now + self.ttl) as usize % w.len;
105 w.buckets[target].push(id.to_string());
106 }
107}
108
109impl ChallengeRegistry for ChallengeInMemoryRegistry {
110 fn register(&self, id: &str) {
111 let now = get_timestamp();
112 self.advance_wheel(now);
113 self.cache.insert(
114 id.to_string(),
115 ChallengeStatus {
116 verified: false,
117 attempts_count: 0,
118 timestamp: now,
119 },
120 );
121 self.schedule_expiry(id, now);
122 }
123
124 fn check(&self, id: &str) -> RegistryCheckResult {
125 let now = get_timestamp();
126 self.advance_wheel(now);
127
128 if let Some(challenge_ref) = self.cache.get(id) {
129 let cs = challenge_ref.value();
130 if cs.verified {
131 return RegistryCheckResult::AlreadyVerified;
132 }
133
134 if cs.attempts_count >= self.max_attempts {
135 return RegistryCheckResult::MaxAttemptsLimitExceeded;
136 }
137
138 if now.saturating_sub(cs.timestamp) <= self.ttl {
139 return RegistryCheckResult::Ok;
140 }
141 }
142
143 RegistryCheckResult::NotRegistered
144 }
145
146 fn verify(&self, id: &str) {
147 if let Some(mut challenge_ref) = self.cache.get_mut(id) {
148 let cs = challenge_ref.value_mut();
149 cs.verified = true;
150 }
151 }
152
153 fn note_attempt(&self, id: &str, success: bool) {
154 if let Some(mut challenge_ref) = self.cache.get_mut(id) {
155 let cs = challenge_ref.value_mut();
156 if !success {
157 cs.attempts_count = cs.attempts_count.saturating_add(1);
158 }
159 }
160 }
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    const DEFAULT_TTL: u64 = 60;
    const CHALLENGE_ID: &str = "challenge-123";

    /// Builds a registry with the default TTL and the given attempt budget.
    fn registry_with(max_attempts: u16) -> ChallengeInMemoryRegistry {
        ChallengeInMemoryRegistry::new(DEFAULT_TTL, max_attempts)
    }

    #[test]
    fn test_register_and_check() {
        let registry = registry_with(1);

        registry.register(CHALLENGE_ID);

        assert_eq!(registry.check(CHALLENGE_ID), RegistryCheckResult::Ok);
    }

    #[test]
    fn test_check_unregistered() {
        let registry = registry_with(1);

        assert_eq!(
            registry.check(CHALLENGE_ID),
            RegistryCheckResult::NotRegistered
        );
    }

    #[test]
    fn test_check_already_verified() {
        let registry = registry_with(1);

        registry.register(CHALLENGE_ID);
        registry.verify(CHALLENGE_ID);

        assert_eq!(
            registry.check(CHALLENGE_ID),
            RegistryCheckResult::AlreadyVerified
        );
    }

    #[test]
    fn test_check_max_attempts_limit() {
        let registry = registry_with(2);
        registry.register(CHALLENGE_ID);
        assert_eq!(registry.check(CHALLENGE_ID), RegistryCheckResult::Ok);

        // First failure: one attempt left.
        registry.note_attempt(CHALLENGE_ID, false);
        assert_eq!(registry.check(CHALLENGE_ID), RegistryCheckResult::Ok);

        // Second failure exhausts the budget.
        registry.note_attempt(CHALLENGE_ID, false);
        assert_eq!(
            registry.check(CHALLENGE_ID),
            RegistryCheckResult::MaxAttemptsLimitExceeded
        );
    }

    #[test]
    fn test_concurrent_usage_safe() {
        let registry = Arc::new(registry_with(1));
        let ids: Vec<String> = (0..10).map(|i| format!("challenge-{i}")).collect();

        let workers: Vec<_> = ids
            .iter()
            .cloned()
            .map(|id| {
                let shared = Arc::clone(&registry);
                thread::spawn(move || shared.register(&id))
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }

        for id in &ids {
            assert_eq!(registry.check(id), RegistryCheckResult::Ok);
        }
    }
}