geronimo_captcha/
registry.rs

1use crate::utils::get_timestamp;
2
3use dashmap::DashMap;
4use std::fmt;
5use std::sync::Mutex;
6
7/// This file defines trait for the challenge registry implementation that
8/// stores generated challenges in memory or database, checks how
9/// many resolving attempts performed to prevent brute-force.
10pub trait ChallengeRegistry: Send + Sync {
11    fn register(&self, id: &str);
12    fn check(&self, id: &str) -> RegistryCheckResult;
13    fn verify(&self, id: &str);
14    fn note_attempt(&self, id: &str, success: bool);
15}
16
/// Outcome of [`ChallengeRegistry::check`] for a single challenge id.
///
/// The type is a plain tag without payload, so it is cheap to copy and can be
/// compared for full equality.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RegistryCheckResult {
    /// The challenge is known, not solved yet and may still be attempted.
    Ok,
    /// The challenge has already been marked as verified.
    AlreadyVerified,
    /// The challenge is unknown to the registry or its lifetime has elapsed.
    NotRegistered,
    /// The number of failed attempts has reached the configured maximum.
    MaxAttemptsLimitExceeded,
}
24
25impl fmt::Display for RegistryCheckResult {
26    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
27        write!(f, "{}", format!("{self:?}").to_uppercase())
28    }
29}
30
/// Book-keeping record kept in the cache for every registered challenge.
struct ChallengeStatus {
    /// `true` once the challenge has been marked as verified.
    verified: bool,
    /// Number of failed attempts noted so far (saturates at `u16::MAX`).
    attempts_count: u16,
    /// Value of `get_timestamp()` at registration; the age derived from it
    /// is compared against the TTL (presumably seconds — confirm against
    /// `get_timestamp`).
    timestamp: u64,
}
36
/// Ring of buckets used to expire challenges without scanning the whole
/// cache.
struct Wheel {
    /// `buckets[i]` holds the ids scheduled to expire when the wheel reaches
    /// index `i`.
    buckets: Vec<Vec<String>>,
    /// Index of the current bucket; advances as time passes.
    pos: usize,
    /// Last observed time, in seconds.
    last_tick: u64,
    /// Number of buckets, which equals the TTL in seconds.
    len: usize,
}
43
44pub struct ChallengeInMemoryRegistry {
45    cache: DashMap<String, ChallengeStatus>,
46    max_attempts: u16,
47    ttl: u64,
48    wheel: Mutex<Wheel>,
49}
50
51impl ChallengeInMemoryRegistry {
52    pub fn new(ttl: u64, max_attempts: u16) -> Self {
53        let now = get_timestamp();
54        let len = ttl.max(1) as usize;
55        let pos = (now as usize) % len;
56        let wheel = Wheel {
57            buckets: vec![Vec::new(); len],
58            pos,
59            last_tick: now,
60            len,
61        };
62
63        Self {
64            cache: DashMap::new(),
65            max_attempts,
66            ttl,
67            wheel: Mutex::new(wheel),
68        }
69    }
70
71    fn advance_wheel(&self, now: u64) {
72        let mut w = self.wheel.lock().unwrap();
73        if now <= w.last_tick {
74            return;
75        }
76
77        let steps = ((now - w.last_tick) as usize).min(w.len);
78        for _ in 0..steps {
79            w.pos = (w.pos + 1) % w.len;
80
81            let pos = w.pos;
82            let expired_ids = std::mem::take(&mut w.buckets[pos]);
83
84            for id in expired_ids {
85                if let Some(cs_ref) = self.cache.get(&id) {
86                    let expired = now.saturating_sub(cs_ref.timestamp) >= self.ttl;
87                    drop(cs_ref);
88                    if expired {
89                        let _ = self.cache.remove(&id);
90                    }
91                } else {
92                    // already removed
93                }
94            }
95        }
96
97        w.last_tick = now;
98    }
99
100    fn schedule_expiry(&self, id: &str, now: u64) {
101        let mut w = self.wheel.lock().unwrap();
102
103        // Schedule at (now + ttl) bucket
104        let target = (now + self.ttl) as usize % w.len;
105        w.buckets[target].push(id.to_string());
106    }
107}
108
109impl ChallengeRegistry for ChallengeInMemoryRegistry {
110    fn register(&self, id: &str) {
111        let now = get_timestamp();
112        self.advance_wheel(now);
113        self.cache.insert(
114            id.to_string(),
115            ChallengeStatus {
116                verified: false,
117                attempts_count: 0,
118                timestamp: now,
119            },
120        );
121        self.schedule_expiry(id, now);
122    }
123
124    fn check(&self, id: &str) -> RegistryCheckResult {
125        let now = get_timestamp();
126        self.advance_wheel(now);
127
128        if let Some(challenge_ref) = self.cache.get(id) {
129            let cs = challenge_ref.value();
130            if cs.verified {
131                return RegistryCheckResult::AlreadyVerified;
132            }
133
134            if cs.attempts_count >= self.max_attempts {
135                return RegistryCheckResult::MaxAttemptsLimitExceeded;
136            }
137
138            if now.saturating_sub(cs.timestamp) <= self.ttl {
139                return RegistryCheckResult::Ok;
140            }
141        }
142
143        RegistryCheckResult::NotRegistered
144    }
145
146    fn verify(&self, id: &str) {
147        if let Some(mut challenge_ref) = self.cache.get_mut(id) {
148            let cs = challenge_ref.value_mut();
149            cs.verified = true;
150        }
151    }
152
153    fn note_attempt(&self, id: &str, success: bool) {
154        if let Some(mut challenge_ref) = self.cache.get_mut(id) {
155            let cs = challenge_ref.value_mut();
156            if !success {
157                cs.attempts_count = cs.attempts_count.saturating_add(1);
158            }
159        }
160    }
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    const DEFAULT_TTL: u64 = 60;
    /// Challenge id shared by the single-threaded tests.
    const CHALLENGE_ID: &str = "challenge-123";

    /// Builds a registry with the default TTL and the given attempt limit.
    fn registry_with(max_attempts: u16) -> ChallengeInMemoryRegistry {
        ChallengeInMemoryRegistry::new(DEFAULT_TTL, max_attempts)
    }

    /// A freshly registered challenge passes the check.
    #[test]
    fn test_register_and_check() {
        let reg = registry_with(1);

        reg.register(CHALLENGE_ID);

        assert_eq!(reg.check(CHALLENGE_ID), RegistryCheckResult::Ok);
    }

    /// An id that was never registered is reported as such.
    #[test]
    fn test_check_unregistered() {
        let reg = registry_with(1);

        let outcome = reg.check(CHALLENGE_ID);

        assert_eq!(outcome, RegistryCheckResult::NotRegistered);
    }

    /// A verified challenge cannot be used again.
    #[test]
    fn test_check_already_verified() {
        let reg = registry_with(1);
        reg.register(CHALLENGE_ID);
        reg.verify(CHALLENGE_ID);

        let outcome = reg.check(CHALLENGE_ID);

        assert_eq!(outcome, RegistryCheckResult::AlreadyVerified);
    }

    /// The check keeps passing until the failed attempts reach the limit.
    #[test]
    fn test_check_max_attempts_limit() {
        let reg = registry_with(2);
        reg.register(CHALLENGE_ID);
        assert_eq!(reg.check(CHALLENGE_ID), RegistryCheckResult::Ok);

        // First failure: still below the limit of two.
        reg.note_attempt(CHALLENGE_ID, false);
        assert_eq!(reg.check(CHALLENGE_ID), RegistryCheckResult::Ok);

        // Second failure: limit reached.
        reg.note_attempt(CHALLENGE_ID, false);
        let outcome = reg.check(CHALLENGE_ID);
        assert_eq!(outcome, RegistryCheckResult::MaxAttemptsLimitExceeded);
    }

    /// Registrations made from several threads are all visible afterwards.
    #[test]
    fn test_concurrent_usage_safe() {
        let shared = Arc::new(registry_with(1));

        let workers: Vec<_> = (0..10)
            .map(|n| {
                let reg = Arc::clone(&shared);
                thread::spawn(move || {
                    let id = format!("challenge-{n}");
                    reg.register(&id);
                })
            })
            .collect();
        for worker in workers {
            worker.join().unwrap();
        }

        for n in 0..10 {
            let id = format!("challenge-{n}");
            assert_eq!(shared.check(&id), RegistryCheckResult::Ok);
        }
    }
}
243}