//! ax_cache — crate root (lib.rs): a sharded, concurrent in-memory cache
//! with optional per-entry TTL and background maintenance.
1mod maintenance;
2mod metrics;
3mod policy;
4mod shard;
5mod tinylfu;
6
7pub use crate::maintenance::MaintenanceConfig;
8use crate::maintenance::MaintenanceHandle;
9pub use crate::metrics::MetricsSnapshot;
10use crate::shard::Shard;
11
12use axhash_core::AxHasher;
13use core::borrow::Borrow;
14use core::hash::{BuildHasher, BuildHasherDefault, Hash};
15use core::sync::atomic::{AtomicBool, Ordering};
16use std::sync::{Arc, OnceLock};
17use std::time::{Duration, Instant};
18
/// Sentinel expiry timestamp meaning "never expires". Real deadlines are
/// clamped to at most `NO_EXPIRY - 1` so a live TTL entry can never collide
/// with this sentinel.
const NO_EXPIRY: u32 = u32::MAX;
20
/// Result of an insert operation: the entry was newly stored, replaced an
/// existing entry, or was rejected (e.g. by the admission policy).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InsertOutcome {
    Inserted,
    Updated,
    Rejected,
}

impl InsertOutcome {
    /// True when the key ended up in the cache, whether it was newly
    /// inserted or replaced an existing value.
    #[inline]
    pub const fn is_present(self) -> bool {
        match self {
            Self::Inserted | Self::Updated => true,
            Self::Rejected => false,
        }
    }

    /// True only for a brand-new insertion (not an update of an existing key).
    #[inline]
    pub const fn is_new(self) -> bool {
        match self {
            Self::Inserted => true,
            Self::Updated | Self::Rejected => false,
        }
    }

    /// True when the cache declined to store the entry.
    #[inline]
    pub const fn is_rejected(self) -> bool {
        match self {
            Self::Rejected => true,
            Self::Inserted | Self::Updated => false,
        }
    }
}
44
/// A sharded, concurrent cache with optional per-entry TTL.
pub struct Cache<K, V> {
    // Fixed set of shards; keys are routed to a shard by hash (see `route`).
    shards: Arc<[Shard<K, V>]>,
    // `shard_count - 1`; valid as a mask because the shard count is always
    // rounded up to a power of two.
    mask: u64,
    // Right-shift applied to the mixed hash before masking, so shard
    // selection consumes the hash's top bits (0 when there is one shard).
    shard_shift: u32,
    // Creation instant; all expiry timestamps are milliseconds since this.
    epoch: Instant,
    // Flipped on the first TTL insert; lets `now_ms` skip the clock read
    // entirely for TTL-free workloads.
    has_ttl: AtomicBool,
    // Lazily-started background maintenance worker handle.
    _maintenance: OnceLock<MaintenanceHandle>,
}
53
54impl<K, V> Cache<K, V>
55where
56    K: Eq + Hash + Clone,
57    V: Clone,
58{
59    pub fn new(capacity: usize) -> Self {
60        let parallelism = std::thread::available_parallelism()
61            .map(|n| n.get())
62            .unwrap_or(4);
63
64        let shard_count = (parallelism * 4).next_power_of_two();
65        Self::with_shards(capacity, shard_count)
66    }
67
68    pub fn with_shards(capacity: usize, shard_count: usize) -> Self {
69        let shard_count = shard_count.max(1).next_power_of_two();
70        let per_shard = (capacity / shard_count).max(1);
71        let shards: Vec<Shard<K, V>> = (0..shard_count).map(|_| Shard::new(per_shard)).collect();
72        let mask = (shard_count - 1) as u64;
73
74        let shard_shift = if shard_count == 1 {
75            0
76        } else {
77            64 - shard_count.trailing_zeros()
78        };
79        Self {
80            shards: Arc::from(shards.into_boxed_slice()),
81            mask,
82            shard_shift,
83            epoch: Instant::now(),
84            has_ttl: AtomicBool::new(false),
85            _maintenance: OnceLock::new(),
86        }
87    }
88
89    pub fn enable_maintenance(&self, config: MaintenanceConfig)
90    where
91        K: Send + Sync + 'static,
92        V: Send + Sync + 'static,
93    {
94        if self._maintenance.get().is_some() {
95            return;
96        }
97        let shards = Arc::clone(&self.shards);
98        let epoch = self.epoch;
99        let now_fn =
100            move || -> u32 { u32::try_from(epoch.elapsed().as_millis()).unwrap_or(NO_EXPIRY - 1) };
101        let handle = maintenance::spawn_worker(shards, config, now_fn);
102        let _ = self._maintenance.set(handle);
103    }
104
105    #[inline(always)]
106    fn route<Q: Hash + ?Sized>(&self, key: &Q) -> (usize, u64) {
107        let hasher_builder = BuildHasherDefault::<AxHasher>::default();
108        let h = hasher_builder.hash_one(key);
109        let mixed = h.wrapping_mul(0x9E3779B97F4A7C15);
110        let idx = ((mixed >> self.shard_shift) & self.mask) as usize;
111        (idx, h)
112    }
113
114    #[inline(always)]
115    fn now_ms(&self) -> u32 {
116        if !self.has_ttl.load(Ordering::Relaxed) {
117            return 0;
118        }
119
120        u32::try_from(self.epoch.elapsed().as_millis()).unwrap_or(NO_EXPIRY - 1)
121    }
122
123    #[inline(always)]
124    fn expiry_for(&self, ttl: Duration, now: u32) -> u32 {
125        let ttl_ms = u32::try_from(ttl.as_millis()).unwrap_or(NO_EXPIRY - 1);
126        now.saturating_add(ttl_ms).min(NO_EXPIRY - 1)
127    }
128
129    pub fn get<Q>(&self, key: &Q) -> Option<V>
130    where
131        K: Borrow<Q>,
132        Q: Eq + Hash + ?Sized,
133    {
134        let (idx, hash) = self.route(key);
135        let shard = &self.shards[idx];
136        let now = self.now_ms();
137        match shard.get(key, hash, now) {
138            Some(v) => {
139                shard.metrics.hit();
140                Some(v)
141            }
142            None => {
143                shard.metrics.miss();
144                None
145            }
146        }
147    }
148
149    pub fn insert(&self, key: K, value: V) -> InsertOutcome {
150        let (idx, key_hash) = self.route(&key);
151        self.shards[idx].insert(key, value, NO_EXPIRY, self.now_ms(), key_hash)
152    }
153
154    pub fn insert_with_ttl(&self, key: K, value: V, ttl: Duration) -> InsertOutcome {
155        if !self.has_ttl.load(Ordering::Relaxed) {
156            self.has_ttl.store(true, Ordering::Relaxed);
157        }
158        let now = self.now_ms();
159        let expiry = self.expiry_for(ttl, now);
160        let (idx, key_hash) = self.route(&key);
161        self.shards[idx].insert(key, value, expiry, now, key_hash)
162    }
163
164    pub fn remove<Q>(&self, key: &Q) -> Option<V>
165    where
166        K: Borrow<Q>,
167        Q: Eq + Hash + ?Sized,
168    {
169        let (idx, hash) = self.route(key);
170        self.shards[idx].remove(key, hash)
171    }
172
173    pub fn contains_key<Q>(&self, key: &Q) -> bool
174    where
175        K: Borrow<Q>,
176        Q: Eq + Hash + ?Sized,
177    {
178        let (idx, hash) = self.route(key);
179        self.shards[idx].contains_key(key, hash, self.now_ms())
180    }
181
182    pub fn clear(&self) {
183        for shard in self.shards.iter() {
184            shard.clear();
185        }
186    }
187
188    pub fn len(&self) -> usize {
189        self.shards.iter().map(|s| s.len()).sum()
190    }
191
192    pub fn is_empty(&self) -> bool {
193        self.len() == 0
194    }
195
196    pub fn shard_count(&self) -> usize {
197        self.shards.len()
198    }
199
200    pub fn metrics(&self) -> MetricsSnapshot {
201        let mut snap = MetricsSnapshot::default();
202        for shard in self.shards.iter() {
203            snap.merge(&shard.metrics.snapshot());
204        }
205        snap
206    }
207}
208
// Unit tests for the public `Cache` API. Timing-based tests sleep for 2-3x
// the TTL to tolerate scheduler jitter; eviction-count thresholds leave
// slack because the admission/eviction policy is probabilistic.
#[cfg(test)]
mod tests {
    use super::*;

    // Round-trip: inserted keys are retrievable; absent keys return None.
    #[test]
    fn basic_insert_get() {
        let c: Cache<String, u64> = Cache::with_shards(64, 4);
        c.insert("alpha".to_string(), 1);
        c.insert("beta".to_string(), 2);
        assert_eq!(c.get("alpha"), Some(1));
        assert_eq!(c.get("beta"), Some(2));
        assert_eq!(c.get("missing"), None);
    }

    // Re-inserting an existing key reports Updated and replaces the value.
    #[test]
    fn update_replaces_value() {
        let c: Cache<u32, u32> = Cache::with_shards(32, 2);
        assert_eq!(c.insert(1, 10), InsertOutcome::Inserted);
        assert_eq!(c.insert(1, 20), InsertOutcome::Updated);
        assert_eq!(c.get(&1), Some(20));
    }

    // Exhaustive truth table for the InsertOutcome predicate helpers.
    #[test]
    fn insert_outcome_helpers() {
        assert!(InsertOutcome::Inserted.is_present());
        assert!(InsertOutcome::Inserted.is_new());
        assert!(!InsertOutcome::Inserted.is_rejected());

        assert!(InsertOutcome::Updated.is_present());
        assert!(!InsertOutcome::Updated.is_new());
        assert!(!InsertOutcome::Updated.is_rejected());

        assert!(!InsertOutcome::Rejected.is_present());
        assert!(!InsertOutcome::Rejected.is_new());
        assert!(InsertOutcome::Rejected.is_rejected());
    }

    // contains_key tracks insert/remove without touching hit/miss metrics.
    #[test]
    fn contains_key_works() {
        let c: Cache<&'static str, u32> = Cache::with_shards(64, 1);
        assert!(!c.contains_key("missing"));
        c.insert("present", 1);
        assert!(c.contains_key("present"));
        assert!(!c.contains_key("missing"));
        c.remove("present");
        assert!(!c.contains_key("present"));
    }

    // contains_key must report false once a TTL entry's deadline has passed.
    #[test]
    fn contains_key_respects_ttl() {
        let c: Cache<u32, u32> = Cache::with_shards(64, 1);
        c.insert_with_ttl(1, 100, Duration::from_millis(40));
        assert!(c.contains_key(&1));
        // Sleep 2x the TTL to avoid flakiness on slow runners.
        std::thread::sleep(Duration::from_millis(80));
        assert!(!c.contains_key(&1));
    }

    // clear() empties every shard and the cache remains usable afterwards.
    #[test]
    fn clear_empties_cache() {
        let c: Cache<u32, u32> = Cache::with_shards(64, 4);
        for i in 0..32u32 {
            c.insert(i, i);
        }
        assert_eq!(c.len(), 32);
        c.clear();
        assert_eq!(c.len(), 0);
        assert!(c.is_empty());
        for i in 0..32u32 {
            assert!(c.get(&i).is_none());
        }
        c.insert(99, 99);
        assert_eq!(c.get(&99), Some(99));
    }

    // remove() returns the value once, then None; the key is gone.
    #[test]
    fn remove_works() {
        let c: Cache<u32, u32> = Cache::with_shards(32, 2);
        c.insert(1, 10);
        assert_eq!(c.remove(&1), Some(10));
        assert_eq!(c.remove(&1), None);
        assert_eq!(c.get(&1), None);
    }

    // Inserting 8x the capacity must not grow the cache past its limit.
    #[test]
    fn capacity_is_respected() {
        let c: Cache<u64, u64> = Cache::with_shards(32, 4);
        for i in 0..256u64 {
            c.insert(i, i);
        }

        assert!(c.len() <= 32, "expected len ≤ 32, got {}", c.len());
    }

    // Heavy churn after warming a full working set: len may transiently
    // exceed CAP (2x slack allowed) but must stay bounded.
    #[test]
    fn capacity_holds_under_all_hot_workload() {
        const CAP: usize = 1024;
        let c: Cache<u64, u64> = Cache::with_shards(CAP, 8);

        for i in 0..CAP as u64 {
            c.insert(i, i);
        }
        // Heat every key so the admission policy sees them as frequent.
        for _ in 0..8 {
            for i in 0..CAP as u64 {
                let _ = c.get(&i);
            }
        }

        // 100x capacity of cold inserts must not cause unbounded growth.
        for i in (CAP as u64)..(CAP as u64 * 100) {
            c.insert(i, i);
        }

        let len = c.len();
        assert!(
            len <= CAP * 2,
            "cache grew unboundedly under hot workload: len={} cap={}",
            len,
            CAP
        );
    }

    // Frequently-accessed keys should be retained over a flood of cold
    // inserts — presumably via the tinylfu admission policy; the ≥6/8
    // threshold leaves slack for its probabilistic counters.
    #[test]
    fn hot_keys_survive_eviction() {
        let c: Cache<u64, u64> = Cache::with_shards(64, 1);
        for i in 0..8u64 {
            c.insert(i, i);
        }
        for _ in 0..16 {
            for i in 0..8u64 {
                let _ = c.get(&i);
            }
        }
        for i in 1000..2000u64 {
            c.insert(i, i);
        }
        let surviving = (0..8u64).filter(|i| c.get(i).is_some()).count();
        assert!(
            surviving >= 6,
            "expected ≥6 hot keys to survive, got {}",
            surviving
        );
    }

    // A TTL entry is readable before its deadline and gone after it.
    #[test]
    fn ttl_expires_after_deadline() {
        let c: Cache<u32, u32> = Cache::with_shards(64, 1);
        c.insert_with_ttl(1, 100, Duration::from_millis(50));
        assert_eq!(c.get(&1), Some(100));
        std::thread::sleep(Duration::from_millis(150));
        assert_eq!(c.get(&1), None);
    }

    // Plain insert() entries carry NO_EXPIRY and never time out.
    #[test]
    fn ttl_default_insert_never_expires_automatically() {
        let c: Cache<u32, u32> = Cache::with_shards(64, 1);
        c.insert(1, 100);
        std::thread::sleep(Duration::from_millis(60));
        assert_eq!(c.get(&1), Some(100));
    }

    // Duration::ZERO yields expiry == now, so the entry is dead on arrival.
    #[test]
    fn ttl_zero_insert_is_immediately_expired() {
        let c: Cache<u32, u32> = Cache::with_shards(64, 1);
        c.insert_with_ttl(1, 100, Duration::ZERO);
        assert_eq!(c.get(&1), None);
    }

    // TTL and non-TTL entries coexist: only the TTL entry expires.
    #[test]
    fn ttl_mixed_with_no_ttl_in_same_cache() {
        let c: Cache<u32, u32> = Cache::with_shards(64, 1);
        c.insert(1, 100); // no TTL
        c.insert_with_ttl(2, 200, Duration::from_millis(50));
        std::thread::sleep(Duration::from_millis(150));
        assert_eq!(c.get(&1), Some(100));
        assert_eq!(c.get(&2), None);
    }

    // Re-inserting with a longer TTL replaces the old deadline: 30+40 ms
    // elapsed exceeds the original 50 ms TTL but not the new 200 ms one.
    #[test]
    fn ttl_reinsert_extends_deadline() {
        let c: Cache<u32, u32> = Cache::with_shards(64, 1);
        c.insert_with_ttl(1, 100, Duration::from_millis(50));
        std::thread::sleep(Duration::from_millis(30));
        c.insert_with_ttl(1, 200, Duration::from_millis(200));
        std::thread::sleep(Duration::from_millis(40));
        assert_eq!(c.get(&1), Some(200));
    }

    // Insert pressure after expiry should reclaim expired slots rather than
    // evict the live no-TTL entry; len stays within the tiny capacity.
    #[test]
    fn ttl_expired_entries_get_swept_on_rebalance() {
        let c: Cache<u32, u32> = Cache::with_shards(4, 1);
        c.insert_with_ttl(1, 100, Duration::from_millis(40));
        c.insert_with_ttl(2, 200, Duration::from_millis(40));
        c.insert_with_ttl(3, 300, Duration::from_millis(40));
        c.insert(4, 400); // no TTL

        std::thread::sleep(Duration::from_millis(100));

        // New inserts force the shard to process its expired entries.
        for k in 5..20u32 {
            c.insert(k, k);
        }
        assert_eq!(c.get(&1), None);
        assert_eq!(c.get(&2), None);
        assert_eq!(c.get(&3), None);
        assert!(c.len() <= 4, "expected len ≤ 4, got {}", c.len());
    }

    // 8 threads, disjoint key ranges: no panics/deadlocks, and the merged
    // metrics reflect the activity.
    #[test]
    fn concurrent_smoke() {
        use std::sync::Arc;
        use std::thread;
        let c = Arc::new(Cache::<u64, u64>::with_shards(1024, 16));
        let mut handles = Vec::new();
        for t in 0..8u64 {
            let c = Arc::clone(&c);
            handles.push(thread::spawn(move || {
                for i in 0..2000u64 {
                    let k = (t * 10_000) + i;
                    c.insert(k, k);
                    let _ = c.get(&k);
                }
            }));
        }
        for h in handles {
            h.join().unwrap();
        }
        let m = c.metrics();
        assert!(m.insertions > 0);
        assert!(m.hits + m.misses > 0);
    }

    // Repeated insert/remove cycles with unique keys must leave the cache
    // empty — guards against stale bookkeeping accumulating per cycle.
    #[test]
    fn remove_churn_does_not_leak_queue_memory() {
        let c: Cache<u64, u64> = Cache::with_shards(100, 1);
        for cycle in 0..100u64 {
            for i in 0..50u64 {
                let k = cycle * 1000 + i;
                c.insert(k, k);
            }
            for i in 0..50u64 {
                let k = cycle * 1000 + i;
                c.remove(&k);
            }
        }
        assert_eq!(c.len(), 0);
    }

    // Sanity check on total occupancy across 16 shards; the per-shard
    // bounds are computed but not asserted (no per-shard len accessor here).
    #[test]
    fn shard_distribution_uniformity() {
        let c: Cache<u64, u64> = Cache::with_shards(10_000, 16);
        for i in 0..10_000u64 {
            c.insert(i, i);
        }

        let total = c.len();
        let expected_per_shard = total as f64 / c.shard_count() as f64;
        let lo = (expected_per_shard * 0.5) as usize;
        let hi = (expected_per_shard * 1.5) as usize;
        assert!(total > 0);
        assert!(total <= 10_000, "total {} exceeds capacity", total);
        let _ = (lo, hi);
    }

    // The background worker must sweep expired entries with no foreground
    // access: a 50 ms sweep interval gets ≥2 chances within the 200 ms wait.
    #[test]
    fn maintenance_sweeps_expired_entries() {
        let c: Cache<u32, u32> = Cache::with_shards(64, 1);
        c.enable_maintenance(MaintenanceConfig {
            sweep_interval: Duration::from_millis(50),
            max_sweep_per_shard: 32,
        });
        for i in 0..10u32 {
            c.insert_with_ttl(i, i * 10, Duration::from_millis(30));
        }
        assert!(!c.is_empty());
        std::thread::sleep(Duration::from_millis(200));
        assert_eq!(c.len(), 0, "expected 0 after sweep, got {}", c.len());
    }
}