1use std::collections::HashMap;
7use std::time::{Duration, Instant};
8
/// A fixed-window token budget.
///
/// The bucket starts full with `max_tokens` tokens. [`try_consume`](Self::try_consume)
/// and [`remaining`](Self::remaining) first check whether at least `window` has
/// elapsed since the current window began; if so, the bucket is topped back up to
/// `max_tokens` and a new window begins at that moment.
pub struct BudgetTokenBucket {
    /// Tokens restored at the start of each window.
    max_tokens: u64,
    /// Tokens still available in the current window.
    tokens: u64,
    /// Length of one budget window.
    window: Duration,
    /// When the current window began.
    window_start: Instant,
}

impl BudgetTokenBucket {
    /// Creates a full bucket whose first window starts now.
    pub fn new(max_tokens: u64, window: Duration) -> Self {
        let window_start = Instant::now();
        Self {
            tokens: max_tokens,
            max_tokens,
            window,
            window_start,
        }
    }

    /// Tries to spend `cost` tokens from the current window.
    ///
    /// Returns `Ok(left)` with the balance after spending, or `Err(shortfall)`
    /// with the number of tokens that were missing. A failed attempt spends nothing.
    pub fn try_consume(&mut self, cost: u64) -> Result<u64, u64> {
        self.maybe_refill();
        match self.tokens.checked_sub(cost) {
            Some(left) => {
                self.tokens = left;
                Ok(left)
            }
            None => Err(cost - self.tokens),
        }
    }

    /// Returns the tokens available now, rolling over to a new window first if
    /// the current one has elapsed (which is why this takes `&mut self`).
    pub fn remaining(&mut self) -> u64 {
        self.maybe_refill();
        self.tokens
    }

    /// Unconditionally tops the bucket up and restarts the window from now.
    pub fn refill(&mut self) {
        self.tokens = self.max_tokens;
        self.window_start = Instant::now();
    }

    /// Starts a new, full window if the current one has lasted at least `window`.
    ///
    /// A zero-length `window` is always elapsed, so such a bucket refills on every call.
    fn maybe_refill(&mut self) {
        if self.window_start.elapsed() < self.window {
            return;
        }
        self.refill();
    }
}
71
/// A compact, hashable fingerprint of a query vector.
///
/// Two queries share a signature when every component quantizes to the same
/// signed byte (see [`QuerySignature::from_query`]).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct QuerySignature {
    /// FNV-style hash over the quantized components.
    hash: u64,
}

impl QuerySignature {
    /// Fingerprints `query` by quantizing and hashing each component.
    ///
    /// Each value is clamped to `[-1.0, 1.0]`, scaled by 127 and truncated toward
    /// zero to an `i8`. As a result, out-of-range values collapse onto ±127 and
    /// values within 1/127 of zero collapse onto 0. A NaN component also lands
    /// on 0, via the saturating float-to-int cast.
    ///
    /// The bytes are folded with the 64-bit FNV-1a offset basis and prime. Each
    /// `i8` is sign-extended to `u64` before the XOR, so negative components do
    /// not hash exactly like textbook byte-wise FNV-1a.
    pub fn from_query(query: &[f32]) -> Self {
        const OFFSET_BASIS: u64 = 0xcbf29ce484222325;
        const PRIME: u64 = 0x100000001b3;

        let hash = query
            .iter()
            .map(|&component| (component.clamp(-1.0, 1.0) * 127.0) as i8)
            .fold(OFFSET_BASIS, |acc, level| {
                (acc ^ level as u64).wrapping_mul(PRIME)
            });
        Self { hash }
    }
}
97
98struct NegativeCacheEntry {
100 hit_count: u32,
101 first_seen: Instant,
102 last_seen: Instant,
103}
104
105pub struct NegativeCache {
111 entries: HashMap<QuerySignature, NegativeCacheEntry>,
112 threshold: u32,
114 window: Duration,
116 max_entries: usize,
118}
119
120impl NegativeCache {
121 pub fn new(threshold: u32, window: Duration, max_entries: usize) -> Self {
128 Self {
129 entries: HashMap::new(),
130 threshold,
131 window,
132 max_entries,
133 }
134 }
135
136 pub fn record_degenerate(&mut self, sig: QuerySignature) -> bool {
139 let now = Instant::now();
140
141 if self.entries.len() >= self.max_entries {
143 self.evict_expired(now);
144 }
145
146 if self.entries.len() >= self.max_entries {
148 self.evict_oldest();
149 }
150
151 let entry = self.entries.entry(sig).or_insert(NegativeCacheEntry {
152 hit_count: 0,
153 first_seen: now,
154 last_seen: now,
155 });
156
157 if now.duration_since(entry.first_seen) > self.window {
159 entry.hit_count = 0;
160 entry.first_seen = now;
161 }
162
163 entry.hit_count += 1;
164 entry.last_seen = now;
165
166 entry.hit_count >= self.threshold
167 }
168
169 pub fn is_blacklisted(&self, sig: &QuerySignature) -> bool {
171 if let Some(entry) = self.entries.get(sig) {
172 entry.hit_count >= self.threshold
173 } else {
174 false
175 }
176 }
177
178 pub fn len(&self) -> usize {
180 self.entries.len()
181 }
182
183 pub fn is_empty(&self) -> bool {
185 self.entries.is_empty()
186 }
187
188 fn evict_expired(&mut self, now: Instant) {
189 self.entries.retain(|_, entry| {
190 now.duration_since(entry.first_seen) <= self.window
191 });
192 }
193
194 fn evict_oldest(&mut self) {
195 if let Some(oldest_key) = self
196 .entries
197 .iter()
198 .min_by_key(|(_, e)| e.last_seen)
199 .map(|(k, _)| *k)
200 {
201 self.entries.remove(&oldest_key);
202 }
203 }
204}
205
/// A hash puzzle: find a nonce whose hash together with `challenge` has enough
/// leading zero bits.
#[derive(Clone, Debug)]
pub struct ProofOfWork {
    /// Bytes hashed ahead of the nonce.
    pub challenge: [u8; 16],
    /// Required number of leading zero bits in the 64-bit hash. Values above
    /// [`ProofOfWork::MAX_DIFFICULTY`] are treated as `MAX_DIFFICULTY`.
    pub difficulty: u8,
}

impl ProofOfWork {
    /// Upper bound applied to `difficulty` by [`verify`](Self::verify) and [`solve`](Self::solve).
    pub const MAX_DIFFICULTY: u8 = 24;

    /// Checks whether `nonce` solves this puzzle.
    ///
    /// Hashes the 16 challenge bytes followed by the nonce's 8 little-endian bytes
    /// with 64-bit FNV-1a. Accepts when the digest has at least
    /// `min(difficulty, MAX_DIFFICULTY)` leading zero bits, so a difficulty of 0
    /// accepts every nonce.
    pub fn verify(&self, nonce: u64) -> bool {
        const FNV_OFFSET: u64 = 0xcbf29ce484222325;
        const FNV_PRIME: u64 = 0x100000001b3;

        let digest = self
            .challenge
            .iter()
            .chain(nonce.to_le_bytes().iter())
            .fold(FNV_OFFSET, |acc, &byte| {
                (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
            });

        let required = u32::from(self.difficulty.min(Self::MAX_DIFFICULTY));
        digest.leading_zeros() >= required
    }

    /// Brute-forces the smallest solving nonce.
    ///
    /// Tries nonces `0, 1, 2, …` up to `4 · 2^d` candidates, where `d` is the
    /// clamped difficulty (additionally capped at 30). Returns `None` if no
    /// candidate in that budget verifies.
    pub fn solve(&self) -> Option<u64> {
        let bits = self.difficulty.min(Self::MAX_DIFFICULTY).min(30);
        let budget = (1u64 << bits).saturating_mul(4);
        (0..budget).find(|&candidate| self.verify(candidate))
    }
}
255
#[cfg(test)]
mod tests {
    // Unit tests for the token bucket, query signatures, negative cache and proof-of-work.
    use super::*;

    #[test]
    fn token_bucket_basic() {
        let mut bucket = BudgetTokenBucket::new(100, Duration::from_secs(1));
        assert_eq!(bucket.remaining(), 100);
        assert_eq!(bucket.try_consume(30), Ok(70));
        assert_eq!(bucket.remaining(), 70);
    }

    #[test]
    fn token_bucket_exhaustion() {
        let mut bucket = BudgetTokenBucket::new(10, Duration::from_secs(60));
        assert_eq!(bucket.try_consume(10), Ok(0));
        // The error carries the shortfall, and a failed attempt spends nothing.
        assert_eq!(bucket.try_consume(1), Err(1));
        assert_eq!(bucket.remaining(), 0);
    }

    #[test]
    fn token_bucket_refill() {
        let mut bucket = BudgetTokenBucket::new(100, Duration::from_millis(1));
        bucket.try_consume(100).unwrap();
        assert!(bucket.try_consume(1).is_err());
        std::thread::sleep(Duration::from_millis(2));
        assert_eq!(bucket.remaining(), 100);
    }

    #[test]
    fn token_bucket_manual_refill() {
        let mut bucket = BudgetTokenBucket::new(100, Duration::from_secs(60));
        bucket.try_consume(100).unwrap();
        bucket.refill();
        assert_eq!(bucket.remaining(), 100);
    }

    #[test]
    fn query_signature_deterministic() {
        let query = vec![0.1, 0.2, 0.3, 0.4];
        let sig1 = QuerySignature::from_query(&query);
        let sig2 = QuerySignature::from_query(&query);
        assert_eq!(sig1, sig2);
    }

    #[test]
    fn query_signature_different_vectors() {
        let sig1 = QuerySignature::from_query(&[0.1, 0.2, 0.3]);
        let sig2 = QuerySignature::from_query(&[0.4, 0.5, 0.6]);
        assert_ne!(sig1, sig2);
    }

    #[test]
    fn query_signature_clamps_out_of_range_components() {
        // Components are clamped to [-1, 1] before quantizing, so these collide by design.
        assert_eq!(
            QuerySignature::from_query(&[3.0, -7.5]),
            QuerySignature::from_query(&[1.0, -1.0])
        );
    }

    #[test]
    fn negative_cache_below_threshold() {
        let mut cache = NegativeCache::new(3, Duration::from_secs(60), 1000);
        let sig = QuerySignature::from_query(&[0.1, 0.2]);
        assert!(!cache.record_degenerate(sig));
        assert!(!cache.record_degenerate(sig));
        assert!(!cache.is_blacklisted(&sig));
    }

    #[test]
    fn negative_cache_reaches_threshold() {
        let mut cache = NegativeCache::new(3, Duration::from_secs(60), 1000);
        let sig = QuerySignature::from_query(&[0.1, 0.2]);
        cache.record_degenerate(sig);
        cache.record_degenerate(sig);
        assert!(cache.record_degenerate(sig));
        assert!(cache.is_blacklisted(&sig));
    }

    #[test]
    fn negative_cache_max_entries() {
        let mut cache = NegativeCache::new(100, Duration::from_secs(60), 5);
        // Keep components inside [-1, 1]. With `[i as f32]`, every i >= 1 clamps to
        // 1.0 and yields the same signature, leaving only two entries, so the cap
        // would never be exercised.
        for i in 0..10 {
            let sig = QuerySignature::from_query(&[i as f32 / 10.0]);
            cache.record_degenerate(sig);
        }
        // Ten distinct signatures were recorded, so the cache must sit exactly at its cap.
        assert_eq!(cache.len(), 5);
    }

    #[test]
    fn negative_cache_empty() {
        let cache = NegativeCache::new(3, Duration::from_secs(60), 1000);
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn proof_of_work_low_difficulty() {
        let pow = ProofOfWork {
            challenge: [0xAB; 16],
            difficulty: 1,
        };
        let nonce = pow.solve().expect("should solve easily");
        assert!(pow.verify(nonce));
    }

    #[test]
    fn proof_of_work_wrong_nonce() {
        let pow = ProofOfWork {
            challenge: [0xAB; 16],
            difficulty: 16,
        };
        assert!(!pow.verify(0xDEADBEEF));
    }

    #[test]
    fn proof_of_work_solve_and_verify() {
        let pow = ProofOfWork {
            challenge: [0x42; 16],
            difficulty: 8,
        };
        let nonce = pow.solve().expect("should solve d=8");
        assert!(pow.verify(nonce));
    }

    #[test]
    fn proof_of_work_max_difficulty_clamped() {
        assert_eq!(ProofOfWork::MAX_DIFFICULTY, 24);
        let oversized = ProofOfWork {
            challenge: [0x42; 16],
            difficulty: 255,
        };
        let at_max = ProofOfWork {
            challenge: [0x42; 16],
            difficulty: ProofOfWork::MAX_DIFFICULTY,
        };
        // On these sampled nonces, an out-of-range difficulty should be judged the
        // same as MAX_DIFFICULTY rather than as a 255-leading-zero requirement.
        for nonce in 0..1024 {
            assert_eq!(oversized.verify(nonce), at_max.verify(nonce));
        }
    }
}