1mod maintenance;
2mod metrics;
3mod policy;
4mod shard;
5mod tinylfu;
6
7pub use crate::maintenance::MaintenanceConfig;
8use crate::maintenance::MaintenanceHandle;
9pub use crate::metrics::MetricsSnapshot;
10use crate::shard::Shard;
11
12use axhash_core::AxHasher;
13use core::borrow::Borrow;
14use core::hash::{BuildHasher, BuildHasherDefault, Hash};
15use core::sync::atomic::{AtomicBool, Ordering};
16use std::sync::{Arc, OnceLock};
17use std::time::{Duration, Instant};
18
/// Expiry sentinel meaning "this entry never expires".
///
/// `Cache::insert` stores this value for entries without a TTL. Real deadlines and
/// clock readings are clamped to at most `NO_EXPIRY - 1`, so they can never collide
/// with the sentinel.
const NO_EXPIRY: u32 = u32::MAX;
20
/// What happened to an entry handed to the cache for insertion.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InsertOutcome {
    /// The key was absent and the entry is now stored.
    Inserted,
    /// The key was already present and its value was replaced.
    Updated,
    /// The cache declined to store the entry.
    Rejected,
}

impl InsertOutcome {
    /// `true` when the entry ended up stored, either fresh or by replacement.
    #[inline]
    pub const fn is_present(self) -> bool {
        !self.is_rejected()
    }

    /// `true` only when the key was not in the cache before this insertion.
    #[inline]
    pub const fn is_new(self) -> bool {
        match self {
            Self::Inserted => true,
            Self::Updated | Self::Rejected => false,
        }
    }

    /// `true` when the entry was not stored at all.
    #[inline]
    pub const fn is_rejected(self) -> bool {
        match self {
            Self::Rejected => true,
            Self::Inserted | Self::Updated => false,
        }
    }
}
44
45pub struct Cache<K, V> {
46 shards: Arc<[Shard<K, V>]>,
47 mask: u64,
48 shard_shift: u32,
49 epoch: Instant,
50 has_ttl: AtomicBool,
51 _maintenance: OnceLock<MaintenanceHandle>,
52}
53
54impl<K, V> Cache<K, V>
55where
56 K: Eq + Hash + Clone,
57 V: Clone,
58{
59 pub fn new(capacity: usize) -> Self {
60 let parallelism = std::thread::available_parallelism()
61 .map(|n| n.get())
62 .unwrap_or(4);
63
64 let shard_count = (parallelism * 4).next_power_of_two();
65 Self::with_shards(capacity, shard_count)
66 }
67
68 pub fn with_shards(capacity: usize, shard_count: usize) -> Self {
69 let shard_count = shard_count.max(1).next_power_of_two();
70 let per_shard = (capacity / shard_count).max(1);
71 let shards: Vec<Shard<K, V>> = (0..shard_count).map(|_| Shard::new(per_shard)).collect();
72 let mask = (shard_count - 1) as u64;
73
74 let shard_shift = if shard_count == 1 {
75 0
76 } else {
77 64 - shard_count.trailing_zeros()
78 };
79 Self {
80 shards: Arc::from(shards.into_boxed_slice()),
81 mask,
82 shard_shift,
83 epoch: Instant::now(),
84 has_ttl: AtomicBool::new(false),
85 _maintenance: OnceLock::new(),
86 }
87 }
88
89 pub fn enable_maintenance(&self, config: MaintenanceConfig)
90 where
91 K: Send + Sync + 'static,
92 V: Send + Sync + 'static,
93 {
94 if self._maintenance.get().is_some() {
95 return;
96 }
97 let shards = Arc::clone(&self.shards);
98 let epoch = self.epoch;
99 let now_fn =
100 move || -> u32 { u32::try_from(epoch.elapsed().as_millis()).unwrap_or(NO_EXPIRY - 1) };
101 let handle = maintenance::spawn_worker(shards, config, now_fn);
102 let _ = self._maintenance.set(handle);
103 }
104
105 #[inline(always)]
106 fn route<Q: Hash + ?Sized>(&self, key: &Q) -> (usize, u64) {
107 let hasher_builder = BuildHasherDefault::<AxHasher>::default();
108 let h = hasher_builder.hash_one(key);
109 let mixed = h.wrapping_mul(0x9E3779B97F4A7C15);
110 let idx = ((mixed >> self.shard_shift) & self.mask) as usize;
111 (idx, h)
112 }
113
114 #[inline(always)]
115 fn now_ms(&self) -> u32 {
116 if !self.has_ttl.load(Ordering::Relaxed) {
117 return 0;
118 }
119
120 u32::try_from(self.epoch.elapsed().as_millis()).unwrap_or(NO_EXPIRY - 1)
121 }
122
123 #[inline(always)]
124 fn expiry_for(&self, ttl: Duration, now: u32) -> u32 {
125 let ttl_ms = u32::try_from(ttl.as_millis()).unwrap_or(NO_EXPIRY - 1);
126 now.saturating_add(ttl_ms).min(NO_EXPIRY - 1)
127 }
128
129 pub fn get<Q>(&self, key: &Q) -> Option<V>
130 where
131 K: Borrow<Q>,
132 Q: Eq + Hash + ?Sized,
133 {
134 let (idx, hash) = self.route(key);
135 let shard = &self.shards[idx];
136 let now = self.now_ms();
137 match shard.get(key, hash, now) {
138 Some(v) => {
139 shard.metrics.hit();
140 Some(v)
141 }
142 None => {
143 shard.metrics.miss();
144 None
145 }
146 }
147 }
148
149 pub fn insert(&self, key: K, value: V) -> InsertOutcome {
150 let (idx, key_hash) = self.route(&key);
151 self.shards[idx].insert(key, value, NO_EXPIRY, self.now_ms(), key_hash)
152 }
153
154 pub fn insert_with_ttl(&self, key: K, value: V, ttl: Duration) -> InsertOutcome {
155 if !self.has_ttl.load(Ordering::Relaxed) {
156 self.has_ttl.store(true, Ordering::Relaxed);
157 }
158 let now = self.now_ms();
159 let expiry = self.expiry_for(ttl, now);
160 let (idx, key_hash) = self.route(&key);
161 self.shards[idx].insert(key, value, expiry, now, key_hash)
162 }
163
164 pub fn remove<Q>(&self, key: &Q) -> Option<V>
165 where
166 K: Borrow<Q>,
167 Q: Eq + Hash + ?Sized,
168 {
169 let (idx, hash) = self.route(key);
170 self.shards[idx].remove(key, hash)
171 }
172
173 pub fn contains_key<Q>(&self, key: &Q) -> bool
174 where
175 K: Borrow<Q>,
176 Q: Eq + Hash + ?Sized,
177 {
178 let (idx, hash) = self.route(key);
179 self.shards[idx].contains_key(key, hash, self.now_ms())
180 }
181
182 pub fn clear(&self) {
183 for shard in self.shards.iter() {
184 shard.clear();
185 }
186 }
187
188 pub fn len(&self) -> usize {
189 self.shards.iter().map(|s| s.len()).sum()
190 }
191
192 pub fn is_empty(&self) -> bool {
193 self.len() == 0
194 }
195
196 pub fn shard_count(&self) -> usize {
197 self.shards.len()
198 }
199
200 pub fn metrics(&self) -> MetricsSnapshot {
201 let mut snap = MetricsSnapshot::default();
202 for shard in self.shards.iter() {
203 snap.merge(&shard.metrics.snapshot());
204 }
205 snap
206 }
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn basic_insert_get() {
        let c: Cache<String, u64> = Cache::with_shards(64, 4);
        c.insert("alpha".to_string(), 1);
        c.insert("beta".to_string(), 2);
        assert_eq!(c.get("alpha"), Some(1));
        assert_eq!(c.get("beta"), Some(2));
        assert_eq!(c.get("missing"), None);
    }

    #[test]
    fn update_replaces_value() {
        let c: Cache<u32, u32> = Cache::with_shards(32, 2);
        assert_eq!(c.insert(1, 10), InsertOutcome::Inserted);
        assert_eq!(c.insert(1, 20), InsertOutcome::Updated);
        assert_eq!(c.get(&1), Some(20));
    }

    #[test]
    fn insert_outcome_helpers() {
        assert!(InsertOutcome::Inserted.is_present());
        assert!(InsertOutcome::Inserted.is_new());
        assert!(!InsertOutcome::Inserted.is_rejected());

        assert!(InsertOutcome::Updated.is_present());
        assert!(!InsertOutcome::Updated.is_new());
        assert!(!InsertOutcome::Updated.is_rejected());

        assert!(!InsertOutcome::Rejected.is_present());
        assert!(!InsertOutcome::Rejected.is_new());
        assert!(InsertOutcome::Rejected.is_rejected());
    }

    #[test]
    fn contains_key_works() {
        let c: Cache<&'static str, u32> = Cache::with_shards(64, 1);
        assert!(!c.contains_key("missing"));
        c.insert("present", 1);
        assert!(c.contains_key("present"));
        assert!(!c.contains_key("missing"));
        c.remove("present");
        assert!(!c.contains_key("present"));
    }

    #[test]
    fn contains_key_respects_ttl() {
        let c: Cache<u32, u32> = Cache::with_shards(64, 1);
        c.insert_with_ttl(1, 100, Duration::from_millis(40));
        assert!(c.contains_key(&1));
        std::thread::sleep(Duration::from_millis(80));
        assert!(!c.contains_key(&1));
    }

    #[test]
    fn clear_empties_cache() {
        let c: Cache<u32, u32> = Cache::with_shards(64, 4);
        for i in 0..32u32 {
            c.insert(i, i);
        }
        assert_eq!(c.len(), 32);
        c.clear();
        assert_eq!(c.len(), 0);
        assert!(c.is_empty());
        for i in 0..32u32 {
            assert!(c.get(&i).is_none());
        }
        c.insert(99, 99);
        assert_eq!(c.get(&99), Some(99));
    }

    #[test]
    fn remove_works() {
        let c: Cache<u32, u32> = Cache::with_shards(32, 2);
        c.insert(1, 10);
        assert_eq!(c.remove(&1), Some(10));
        assert_eq!(c.remove(&1), None);
        assert_eq!(c.get(&1), None);
    }

    #[test]
    fn capacity_is_respected() {
        let c: Cache<u64, u64> = Cache::with_shards(32, 4);
        for i in 0..256u64 {
            c.insert(i, i);
        }

        assert!(c.len() <= 32, "expected len ≤ 32, got {}", c.len());
    }

    #[test]
    fn capacity_holds_under_all_hot_workload() {
        const CAP: usize = 1024;
        let c: Cache<u64, u64> = Cache::with_shards(CAP, 8);

        for i in 0..CAP as u64 {
            c.insert(i, i);
        }
        for _ in 0..8 {
            for i in 0..CAP as u64 {
                let _ = c.get(&i);
            }
        }

        for i in (CAP as u64)..(CAP as u64 * 100) {
            c.insert(i, i);
        }

        let len = c.len();
        assert!(
            len <= CAP * 2,
            "cache grew unboundedly under hot workload: len={} cap={}",
            len,
            CAP
        );
    }

    #[test]
    fn hot_keys_survive_eviction() {
        let c: Cache<u64, u64> = Cache::with_shards(64, 1);
        for i in 0..8u64 {
            c.insert(i, i);
        }
        for _ in 0..16 {
            for i in 0..8u64 {
                let _ = c.get(&i);
            }
        }
        for i in 1000..2000u64 {
            c.insert(i, i);
        }
        let surviving = (0..8u64).filter(|i| c.get(i).is_some()).count();
        assert!(
            surviving >= 6,
            "expected ≥6 hot keys to survive, got {}",
            surviving
        );
    }

    #[test]
    fn ttl_expires_after_deadline() {
        let c: Cache<u32, u32> = Cache::with_shards(64, 1);
        c.insert_with_ttl(1, 100, Duration::from_millis(50));
        assert_eq!(c.get(&1), Some(100));
        std::thread::sleep(Duration::from_millis(150));
        assert_eq!(c.get(&1), None);
    }

    #[test]
    fn ttl_default_insert_never_expires_automatically() {
        let c: Cache<u32, u32> = Cache::with_shards(64, 1);
        c.insert(1, 100);
        std::thread::sleep(Duration::from_millis(60));
        assert_eq!(c.get(&1), Some(100));
    }

    #[test]
    fn ttl_zero_insert_is_immediately_expired() {
        let c: Cache<u32, u32> = Cache::with_shards(64, 1);
        c.insert_with_ttl(1, 100, Duration::ZERO);
        assert_eq!(c.get(&1), None);
    }

    #[test]
    fn ttl_mixed_with_no_ttl_in_same_cache() {
        let c: Cache<u32, u32> = Cache::with_shards(64, 1);
        c.insert(1, 100);
        c.insert_with_ttl(2, 200, Duration::from_millis(50));
        std::thread::sleep(Duration::from_millis(150));
        assert_eq!(c.get(&1), Some(100));
        assert_eq!(c.get(&2), None);
    }

    #[test]
    fn ttl_reinsert_extends_deadline() {
        let c: Cache<u32, u32> = Cache::with_shards(64, 1);
        c.insert_with_ttl(1, 100, Duration::from_millis(50));
        std::thread::sleep(Duration::from_millis(30));
        c.insert_with_ttl(1, 200, Duration::from_millis(200));
        std::thread::sleep(Duration::from_millis(40));
        assert_eq!(c.get(&1), Some(200));
    }

    #[test]
    fn ttl_expired_entries_get_swept_on_rebalance() {
        let c: Cache<u32, u32> = Cache::with_shards(4, 1);
        c.insert_with_ttl(1, 100, Duration::from_millis(40));
        c.insert_with_ttl(2, 200, Duration::from_millis(40));
        c.insert_with_ttl(3, 300, Duration::from_millis(40));
        c.insert(4, 400);
        std::thread::sleep(Duration::from_millis(100));

        for k in 5..20u32 {
            c.insert(k, k);
        }
        assert_eq!(c.get(&1), None);
        assert_eq!(c.get(&2), None);
        assert_eq!(c.get(&3), None);
        assert!(c.len() <= 4, "expected len ≤ 4, got {}", c.len());
    }

    #[test]
    fn concurrent_smoke() {
        use std::sync::Arc;
        use std::thread;
        let c = Arc::new(Cache::<u64, u64>::with_shards(1024, 16));
        let mut handles = Vec::new();
        for t in 0..8u64 {
            let c = Arc::clone(&c);
            handles.push(thread::spawn(move || {
                for i in 0..2000u64 {
                    let k = (t * 10_000) + i;
                    c.insert(k, k);
                    let _ = c.get(&k);
                }
            }));
        }
        for h in handles {
            h.join().unwrap();
        }
        let m = c.metrics();
        assert!(m.insertions > 0);
        assert!(m.hits + m.misses > 0);
    }

    #[test]
    fn remove_churn_does_not_leak_queue_memory() {
        let c: Cache<u64, u64> = Cache::with_shards(100, 1);
        for cycle in 0..100u64 {
            for i in 0..50u64 {
                let k = cycle * 1000 + i;
                c.insert(k, k);
            }
            for i in 0..50u64 {
                let k = cycle * 1000 + i;
                c.remove(&k);
            }
        }
        assert_eq!(c.len(), 0);
    }

    #[test]
    fn shard_distribution_uniformity() {
        let c: Cache<u64, u64> = Cache::with_shards(10_000, 16);
        for i in 0..10_000u64 {
            c.insert(i, i);
        }

        let total = c.len();
        assert!(total > 0);
        assert!(total <= 10_000, "total {} exceeds capacity", total);

        // Every shard should hold within ±50% of the mean. A skewed router would leave
        // some shards nearly empty while others fill up and evict.
        let expected_per_shard = total as f64 / c.shard_count() as f64;
        let lo = (expected_per_shard * 0.5) as usize;
        let hi = (expected_per_shard * 1.5) as usize;
        for (idx, shard) in c.shards.iter().enumerate() {
            let held = shard.len();
            assert!(
                (lo..=hi).contains(&held),
                "shard {} holds {} entries, expected {}..={}",
                idx,
                held,
                lo,
                hi
            );
        }
    }

    #[test]
    fn maintenance_sweeps_expired_entries() {
        let c: Cache<u32, u32> = Cache::with_shards(64, 1);
        c.enable_maintenance(MaintenanceConfig {
            sweep_interval: Duration::from_millis(50),
            max_sweep_per_shard: 32,
        });
        for i in 0..10u32 {
            c.insert_with_ttl(i, i * 10, Duration::from_millis(30));
        }
        assert!(!c.is_empty());
        std::thread::sleep(Duration::from_millis(200));
        assert_eq!(c.len(), 0, "expected 0 after sweep, got {}", c.len());
    }
}
484}