ant_node/payment/
cache.rs1use lru::LruCache;
7use parking_lot::Mutex;
8use std::num::NonZeroUsize;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::Arc;
11
12pub use super::quote::XorName;
13
/// Number of entries a cache built with [`VerifiedCache::new`] can hold before
/// the least recently used entry is evicted.
const DEFAULT_CACHE_CAPACITY: usize = 100_000;

17#[derive(Clone)]
27pub struct VerifiedCache {
28 inner: Arc<Mutex<LruCache<XorName, VerificationLevel>>>,
29 hits: Arc<AtomicU64>,
30 misses: Arc<AtomicU64>,
31 additions: Arc<AtomicU64>,
32}
33
/// How strongly a cached name has been verified.
///
/// `ClientPut` is the stronger level: it satisfies both `PaidList` and
/// `ClientPut` requirements, while `PaidList` satisfies only `PaidList`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum VerificationLevel {
    /// Recorded by `VerifiedCache::insert_paid_list_verified`.
    PaidList,
    /// Recorded by `VerifiedCache::insert`.
    ClientPut,
}

40impl VerificationLevel {
41 fn satisfies(self, required: Self) -> bool {
42 matches!(
43 (self, required),
44 (Self::PaidList, Self::PaidList) | (Self::ClientPut, Self::PaidList | Self::ClientPut)
45 )
46 }
47}
48
/// Point-in-time snapshot of a verified cache's lookup and insertion counters.
#[derive(Clone, Copy, Debug, Default)]
pub struct CacheStats {
    /// Lookups that found an entry satisfying the requested level.
    pub hits: u64,
    /// Lookups that found no entry, or only one at too weak a level.
    pub misses: u64,
    /// Names newly inserted; upgrades of existing entries are not counted.
    pub additions: u64,
}

60impl CacheStats {
61 #[must_use]
63 #[allow(clippy::cast_precision_loss)]
64 pub fn hit_rate(&self) -> f64 {
65 let total = self.hits + self.misses;
66 if total == 0 {
67 0.0
68 } else {
69 (self.hits as f64 / total as f64) * 100.0
70 }
71 }
72}
73
74impl VerifiedCache {
75 #[must_use]
77 pub fn new() -> Self {
78 Self::with_capacity(DEFAULT_CACHE_CAPACITY)
79 }
80
81 #[must_use]
85 pub fn with_capacity(capacity: usize) -> Self {
86 let effective_capacity = capacity.max(1);
88 let cap = NonZeroUsize::new(effective_capacity).unwrap_or(NonZeroUsize::MIN);
91 Self {
92 inner: Arc::new(Mutex::new(LruCache::new(cap))),
93 hits: Arc::new(AtomicU64::new(0)),
94 misses: Arc::new(AtomicU64::new(0)),
95 additions: Arc::new(AtomicU64::new(0)),
96 }
97 }
98
99 #[must_use]
104 pub fn contains(&self, xorname: &XorName) -> bool {
105 let found = self.inner.lock().get(xorname).is_some();
106
107 if found {
108 self.hits.fetch_add(1, Ordering::Relaxed);
109 } else {
110 self.misses.fetch_add(1, Ordering::Relaxed);
111 }
112
113 found
114 }
115
116 #[must_use]
122 pub fn contains_paid_list_verified(&self, xorname: &XorName) -> bool {
123 let found = self
124 .inner
125 .lock()
126 .get(xorname)
127 .copied()
128 .is_some_and(|level| level.satisfies(VerificationLevel::PaidList));
129
130 if found {
131 self.hits.fetch_add(1, Ordering::Relaxed);
132 } else {
133 self.misses.fetch_add(1, Ordering::Relaxed);
134 }
135
136 found
137 }
138
139 #[must_use]
145 pub fn contains_client_put_verified(&self, xorname: &XorName) -> bool {
146 let found = self
147 .inner
148 .lock()
149 .get(xorname)
150 .copied()
151 .is_some_and(|level| level.satisfies(VerificationLevel::ClientPut));
152
153 if found {
154 self.hits.fetch_add(1, Ordering::Relaxed);
155 } else {
156 self.misses.fetch_add(1, Ordering::Relaxed);
157 }
158
159 found
160 }
161
162 pub fn insert(&self, xorname: XorName) {
167 self.insert_with_level(xorname, VerificationLevel::ClientPut);
168 }
169
170 pub fn insert_paid_list_verified(&self, xorname: XorName) {
174 self.insert_with_level(xorname, VerificationLevel::PaidList);
175 }
176
177 fn insert_with_level(&self, xorname: XorName, level: VerificationLevel) {
178 let added = {
179 let mut inner = self.inner.lock();
180 if inner.get(&xorname).is_some() {
182 if let Some(existing) = inner.get_mut(&xorname) {
183 if !existing.satisfies(level) {
184 *existing = level;
185 }
186 }
187 false
188 } else {
189 inner.put(xorname, level);
190 true
191 }
192 };
193 if added {
194 self.additions.fetch_add(1, Ordering::Relaxed);
195 }
196 }
197
198 #[must_use]
200 pub fn stats(&self) -> CacheStats {
201 CacheStats {
202 hits: self.hits.load(Ordering::Relaxed),
203 misses: self.misses.load(Ordering::Relaxed),
204 additions: self.additions.load(Ordering::Relaxed),
205 }
206 }
207
208 #[must_use]
210 pub fn len(&self) -> usize {
211 self.inner.lock().len()
212 }
213
214 #[must_use]
216 pub fn is_empty(&self) -> bool {
217 self.inner.lock().is_empty()
218 }
219
220 pub fn clear(&self) {
222 self.inner.lock().clear();
223 }
224}
225
226impl Default for VerifiedCache {
227 fn default() -> Self {
228 Self::new()
229 }
230}
231
#[cfg(test)]
#[allow(clippy::expect_used)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn test_cache_basic_operations() {
        let cache = VerifiedCache::new();

        let xorname1 = [1u8; 32];
        let xorname2 = [2u8; 32];

        assert!(cache.is_empty());
        assert!(!cache.contains(&xorname1));

        cache.insert(xorname1);
        assert!(cache.contains(&xorname1));
        assert!(!cache.contains(&xorname2));
        assert_eq!(cache.len(), 1);

        cache.insert(xorname2);
        assert!(cache.contains(&xorname1));
        assert!(cache.contains(&xorname2));
        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn test_cache_verification_levels_do_not_downgrade_or_over_authorize() {
        let cache = VerifiedCache::new();
        let paid_list = [2u8; 32];
        let client_put = [3u8; 32];

        cache.insert_paid_list_verified(paid_list);
        assert!(cache.contains(&paid_list));
        assert!(cache.contains_paid_list_verified(&paid_list));
        assert!(!cache.contains_client_put_verified(&paid_list));

        // Upgrading PaidList -> ClientPut is allowed.
        cache.insert(paid_list);
        assert!(cache.contains_client_put_verified(&paid_list));

        cache.insert(client_put);
        assert!(cache.contains(&client_put));
        assert!(cache.contains_paid_list_verified(&client_put));
        assert!(cache.contains_client_put_verified(&client_put));

        // Downgrading ClientPut -> PaidList must not happen.
        cache.insert_paid_list_verified(client_put);
        assert!(cache.contains_client_put_verified(&client_put));
    }

    #[test]
    fn test_reinsert_and_upgrade_do_not_count_as_additions() {
        let cache = VerifiedCache::new();
        let xorname = [7u8; 32];

        cache.insert_paid_list_verified(xorname);
        cache.insert(xorname); // upgrade
        cache.insert(xorname); // no-op

        assert_eq!(cache.len(), 1);
        assert_eq!(cache.stats().additions, 1);
        assert!(cache.contains_client_put_verified(&xorname));
    }

    #[test]
    fn test_too_weak_level_counts_as_miss() {
        let cache = VerifiedCache::new();
        let xorname = [4u8; 32];

        cache.insert_paid_list_verified(xorname);
        assert!(!cache.contains_client_put_verified(&xorname));

        let stats = cache.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 1);
    }

    #[test]
    fn test_cache_stats() {
        let cache = VerifiedCache::new();
        let xorname = [1u8; 32];

        assert!(!cache.contains(&xorname));
        let stats = cache.stats();
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 0);

        cache.insert(xorname);
        let stats = cache.stats();
        assert_eq!(stats.additions, 1);

        assert!(cache.contains(&xorname));
        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);

        assert!((stats.hit_rate() - 50.0).abs() < 0.01);
    }

    #[test]
    fn test_cache_lru_eviction() {
        let cache = VerifiedCache::with_capacity(2);

        let xorname1 = [1u8; 32];
        let xorname2 = [2u8; 32];
        let xorname3 = [3u8; 32];

        cache.insert(xorname1);
        cache.insert(xorname2);
        assert_eq!(cache.len(), 2);

        // Inserting a third entry evicts the least recently used one (xorname1).
        cache.insert(xorname3);
        assert_eq!(cache.len(), 2);
        assert!(!cache.contains(&xorname1));
        assert!(cache.contains(&xorname2));
        assert!(cache.contains(&xorname3));
    }

    #[test]
    fn test_lookup_refreshes_recency() {
        let cache = VerifiedCache::with_capacity(2);

        let a = [1u8; 32];
        let b = [2u8; 32];
        let c = [3u8; 32];

        cache.insert(a);
        cache.insert(b);
        // Touching `a` makes `b` the least recently used entry.
        assert!(cache.contains(&a));

        cache.insert(c);
        assert!(cache.contains(&a));
        assert!(!cache.contains(&b));
        assert!(cache.contains(&c));
    }

    #[test]
    fn test_cache_clear() {
        let cache = VerifiedCache::new();

        cache.insert([1u8; 32]);
        cache.insert([2u8; 32]);
        assert_eq!(cache.len(), 2);

        cache.clear();
        assert!(cache.is_empty());
    }

    #[test]
    fn test_with_capacity_zero_defaults_to_one() {
        let cache = VerifiedCache::with_capacity(0);
        cache.insert([1u8; 32]);
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn test_default_impl() {
        let cache = VerifiedCache::default();
        assert!(cache.is_empty());
        cache.insert([1u8; 32]);
        assert!(cache.contains(&[1u8; 32]));
    }

    #[test]
    fn test_hit_rate_zero_total() {
        let stats = CacheStats::default();
        assert!(stats.hit_rate().abs() < f64::EPSILON);
    }

    #[test]
    fn test_hit_rate_all_hits() {
        let stats = CacheStats {
            hits: 10,
            misses: 0,
            additions: 0,
        };
        assert!((stats.hit_rate() - 100.0).abs() < 0.01);
    }

    #[test]
    fn test_hit_rate_all_misses() {
        let stats = CacheStats {
            hits: 0,
            misses: 10,
            additions: 0,
        };
        assert!(stats.hit_rate().abs() < f64::EPSILON);
    }

    #[test]
    fn test_clear_does_not_reset_stats() {
        let cache = VerifiedCache::new();
        cache.insert([1u8; 32]);
        let _ = cache.contains(&[1u8; 32]); // hit
        let _ = cache.contains(&[2u8; 32]); // miss
        cache.clear();

        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.additions, 1);
    }

    #[test]
    fn test_concurrent_insert_and_contains() {
        // `VerifiedCache` is internally `Arc`-backed, so plain clones share state;
        // the final `len` check on the original handle proves that.
        let cache = VerifiedCache::with_capacity(1000);
        let mut handles = Vec::new();

        for i in 0..10u8 {
            let c = cache.clone();
            handles.push(thread::spawn(move || {
                c.insert([i; 32]);
            }));
        }

        for i in 0..10u8 {
            let c = cache.clone();
            handles.push(thread::spawn(move || {
                let _ = c.contains(&[i; 32]);
            }));
        }

        for handle in handles {
            handle.join().expect("thread panicked");
        }

        assert_eq!(cache.len(), 10);
    }

    #[test]
    fn test_cache_stats_copy() {
        let stats = CacheStats {
            hits: 5,
            misses: 3,
            additions: 8,
        };
        let stats2 = stats; // `Copy`: `stats` stays usable.
        assert_eq!(stats.hits, stats2.hits);
        assert_eq!(stats.misses, stats2.misses);
        assert_eq!(stats.additions, stats2.additions);
    }
}