comemo/
cache.rs

1use std::collections::HashMap;
2use std::sync::atomic::{AtomicUsize, Ordering};
3
4use once_cell::sync::Lazy;
5use parking_lot::RwLock;
6use siphasher::sip128::{Hasher128, SipHasher13};
7
8use crate::accelerate;
9use crate::constraint::Join;
10use crate::input::Input;
11
12/// The global list of eviction functions.
13static EVICTORS: RwLock<Vec<fn(usize)>> = RwLock::new(Vec::new());
14
#[cfg(feature = "testing")]
thread_local! {
    /// Per-thread flag that `memoized` sets to `true` when it answers from the
    /// cache and to `false` when it has to execute the function.
    static LAST_WAS_HIT: std::cell::Cell<bool> =
        const { std::cell::Cell::<bool>::new(false) };
}
20
21/// Execute a function or use a cached result for it.
22pub fn memoized<'c, In, Out, F>(
23    mut input: In,
24    constraint: &'c In::Constraint,
25    cache: &Cache<In::Constraint, Out>,
26    func: F,
27) -> Out
28where
29    In: Input + 'c,
30    Out: Clone + 'static,
31    F: FnOnce(In::Tracked<'c>) -> Out,
32{
33    // Compute the hash of the input's key part.
34    let key = {
35        let mut state = SipHasher13::new();
36        input.key(&mut state);
37        state.finish128().as_u128()
38    };
39
40    // Check if there is a cached output.
41    let borrow = cache.0.read();
42    if let Some((constrained, value)) = borrow.lookup::<In>(key, &input) {
43        // Replay the mutations.
44        input.replay(constrained);
45
46        // Add the cached constraints to the outer ones.
47        input.retrack(constraint).1.join(constrained);
48
49        #[cfg(feature = "testing")]
50        LAST_WAS_HIT.with(|cell| cell.set(true));
51
52        return value.clone();
53    }
54
55    // Release the borrow so that nested memoized calls can access the
56    // cache without dead locking.
57    drop(borrow);
58
59    // Execute the function with the new constraints hooked in.
60    let (input, outer) = input.retrack(constraint);
61    let output = func(input);
62
63    // Add the new constraints to the outer ones.
64    outer.join(constraint);
65
66    // Insert the result into the cache.
67    let mut borrow = cache.0.write();
68    borrow.insert::<In>(key, constraint.take(), output.clone());
69
70    #[cfg(feature = "testing")]
71    LAST_WAS_HIT.with(|cell| cell.set(false));
72
73    output
74}
75
76/// Evict the global cache.
77///
78/// This removes all memoized results from the cache whose age is larger than or
79/// equal to `max_age`. The age of a result grows by one during each eviction
80/// and is reset to zero when the result produces a cache hit. Set `max_age` to
81/// zero to completely clear the cache.
82pub fn evict(max_age: usize) {
83    for subevict in EVICTORS.read().iter() {
84        subevict(max_age);
85    }
86
87    accelerate::evict();
88}
89
90/// Register an eviction function in the global list.
91pub fn register_evictor(evict: fn(usize)) {
92    EVICTORS.write().push(evict);
93}
94
/// Whether the most recently completed memoized call on the current thread
/// was served from the cache.
///
/// Only available with the `testing` feature.
#[cfg(feature = "testing")]
pub fn last_was_hit() -> bool {
    LAST_WAS_HIT.with(std::cell::Cell::get)
}
100
101/// A cache for a single memoized function.
102pub struct Cache<C, Out>(Lazy<RwLock<CacheData<C, Out>>>);
103
104impl<C: 'static, Out: 'static> Cache<C, Out> {
105    /// Create an empty cache.
106    ///
107    /// It must take an initialization function because the `evict` fn
108    /// pointer cannot be passed as an argument otherwise the function
109    /// passed to `Lazy::new` is a closure and not a function pointer.
110    pub const fn new(init: fn() -> RwLock<CacheData<C, Out>>) -> Self {
111        Self(Lazy::new(init))
112    }
113
114    /// Evict all entries whose age is larger than or equal to `max_age`.
115    pub fn evict(&self, max_age: usize) {
116        self.0.write().evict(max_age)
117    }
118}
119
120/// The internal data for a cache.
121pub struct CacheData<C, Out> {
122    /// Maps from hashes to memoized results.
123    entries: HashMap<u128, Vec<CacheEntry<C, Out>>>,
124}
125
126impl<C, Out: 'static> CacheData<C, Out> {
127    /// Evict all entries whose age is larger than or equal to `max_age`.
128    fn evict(&mut self, max_age: usize) {
129        self.entries.retain(|_, entries| {
130            entries.retain_mut(|entry| {
131                let age = entry.age.get_mut();
132                *age += 1;
133                *age <= max_age
134            });
135            !entries.is_empty()
136        });
137    }
138
139    /// Look for a matching entry in the cache.
140    fn lookup<In>(&self, key: u128, input: &In) -> Option<(&In::Constraint, &Out)>
141    where
142        In: Input<Constraint = C>,
143    {
144        self.entries
145            .get(&key)?
146            .iter()
147            .rev()
148            .find_map(|entry| entry.lookup::<In>(input))
149    }
150
151    /// Insert an entry into the cache.
152    fn insert<In>(&mut self, key: u128, constraint: In::Constraint, output: Out)
153    where
154        In: Input<Constraint = C>,
155    {
156        self.entries
157            .entry(key)
158            .or_default()
159            .push(CacheEntry::new::<In>(constraint, output));
160    }
161}
162
163impl<C, Out> Default for CacheData<C, Out> {
164    fn default() -> Self {
165        Self { entries: HashMap::new() }
166    }
167}
168
169/// A memoized result.
170struct CacheEntry<C, Out> {
171    /// The memoized function's constraint.
172    constraint: C,
173    /// The memoized function's output.
174    output: Out,
175    /// How many evictions have passed since the entry has been last used.
176    age: AtomicUsize,
177}
178
179impl<C, Out: 'static> CacheEntry<C, Out> {
180    /// Create a new entry.
181    fn new<In>(constraint: In::Constraint, output: Out) -> Self
182    where
183        In: Input<Constraint = C>,
184    {
185        Self { constraint, output, age: AtomicUsize::new(0) }
186    }
187
188    /// Return the entry's output if it is valid for the given input.
189    fn lookup<In>(&self, input: &In) -> Option<(&In::Constraint, &Out)>
190    where
191        In: Input<Constraint = C>,
192    {
193        input.validate(&self.constraint).then(|| {
194            self.age.store(0, Ordering::SeqCst);
195            (&self.constraint, &self.output)
196        })
197    }
198}