comemo/
memoize.rs

1use std::sync::LazyLock;
2use std::sync::atomic::{AtomicUsize, Ordering};
3
4use parking_lot::RwLock;
5use siphasher::sip128::{Hasher128, SipHasher13};
6
7use crate::accelerate;
8use crate::constraint::Constraint;
9use crate::input::Input;
10use crate::track::Call;
11use crate::tree::{CallTree, InsertError};
12
13/// The global list of eviction functions.
14static EVICTORS: RwLock<Vec<fn(usize)>> = RwLock::new(Vec::new());
15
16/// Executes a function, trying to use a cached result for it.
17#[allow(clippy::type_complexity)]
18pub fn memoize<'a, In, Out, F>(
19    cache: &Cache<In::Call, Out>,
20    mut input: In,
21    // These values must come from outside so that they have a lifetime that
22    // allows them to be attached to the `input`. On the call site, they are
23    // simply initialized as `&mut Default::default()`.
24    (storage, constraint): &'a mut (
25        In::Storage<&'a Constraint<In::Call>>,
26        Constraint<In::Call>,
27    ),
28    enabled: bool,
29    func: F,
30) -> Out
31where
32    In: Input<'a>,
33    Out: Clone + 'static,
34    F: FnOnce(In) -> Out,
35{
36    // Early bypass if memoization is disabled.
37    if !enabled {
38        let output = func(input);
39
40        // Ensure that the last call was a miss during testing.
41        #[cfg(feature = "testing")]
42        crate::testing::register_miss();
43
44        return output;
45    }
46
47    // Compute the hash of the input's key part.
48    let key = {
49        let mut state = SipHasher13::new();
50        input.key(&mut state);
51        state.finish128().as_u128()
52    };
53
54    // Check if there is a cached output.
55    if let Some(entry) = cache.0.read().lookup(key, &input) {
56        // Replay mutations.
57        for call in &entry.mutable {
58            input.call_mut(call);
59        }
60
61        #[cfg(feature = "testing")]
62        crate::testing::register_hit();
63
64        return entry.output.clone();
65    }
66
67    // Attach the constraint.
68    input.attach(storage, constraint);
69
70    // Execute the function with the constraint attached.
71    let output = func(input);
72
73    // Insert the result into the cache.
74    match cache.0.write().insert(key, constraint, output.clone()) {
75        Ok(()) => {}
76        Err(InsertError::AlreadyExists) => {
77            // A concurrent call with the same arguments may have inserted
78            // a value in the meantime. That's okay.
79        }
80        Err(InsertError::MissingCall) => {
81            // A missing call indicates a bug from a comemo user. See the
82            // documentation for `InsertError::MissingCall` for more details.
83            #[cfg(debug_assertions)]
84            panic!("comemo: memoized function is non-deterministic");
85        }
86    }
87
88    #[cfg(feature = "testing")]
89    crate::testing::register_miss();
90
91    output
92}
93
94/// Evict the global cache.
95///
96/// This removes all memoized results from the cache whose age is larger than or
97/// equal to `max_age`. The age of a result grows by one during each eviction
98/// and is reset to zero when the result produces a cache hit. Set `max_age` to
99/// zero to completely clear the cache.
100pub fn evict(max_age: usize) {
101    for subevict in EVICTORS.read().iter() {
102        subevict(max_age);
103    }
104
105    accelerate::evict();
106}
107
108/// Register an eviction function in the global list.
109pub fn register_evictor(evict: fn(usize)) {
110    EVICTORS.write().push(evict);
111}
112
113/// A cache for a single memoized function.
114pub struct Cache<C, Out>(LazyLock<RwLock<CacheData<C, Out>>>);
115
116impl<C: 'static, Out: 'static> Cache<C, Out> {
117    /// Create an empty cache.
118    ///
119    /// It must take an initialization function because the `evict` fn
120    /// pointer cannot be passed as an argument otherwise the function
121    /// passed to `Lazy::new` is a closure and not a function pointer.
122    pub const fn new(init: fn() -> RwLock<CacheData<C, Out>>) -> Self {
123        Self(LazyLock::new(init))
124    }
125
126    /// Evict all entries whose age is larger than or equal to `max_age`.
127    pub fn evict(&self, max_age: usize) {
128        self.0.write().evict(max_age);
129    }
130}
131
132/// The internal data for a cache.
133pub struct CacheData<C, Out> {
134    /// Maps from hashes to memoized results.
135    tree: CallTree<C, CacheEntry<C, Out>>,
136}
137
/// One memoized result together with its bookkeeping.
struct CacheEntry<C, Out> {
    /// The output produced by the memoized function.
    output: Out,
    /// Tracked mutable calls that have to be replayed on a cache hit.
    mutable: Vec<C>,
    /// Number of evictions since this entry was last used. It is atomic so
    /// that a hit can reset it through a shared reference.
    age: AtomicUsize,
}
147
148impl<C, Out: 'static> CacheData<C, Out> {
149    /// Evict all entries whose age is larger than or equal to `max_age`.
150    fn evict(&mut self, max_age: usize) {
151        self.tree.retain(|entry| {
152            let age = entry.age.get_mut();
153            *age += 1;
154            *age <= max_age
155        });
156    }
157
158    /// Look for a matching entry in the cache.
159    fn lookup<'a, In>(&self, key: u128, input: &In) -> Option<&CacheEntry<C, Out>>
160    where
161        C: Call,
162        In: Input<'a, Call = C>,
163    {
164        self.tree
165            .get(key, |c| input.call(c))
166            .inspect(|entry| entry.age.store(0, Ordering::SeqCst))
167    }
168
169    /// Insert an entry into the cache.
170    fn insert(
171        &mut self,
172        key: u128,
173        constraint: &Constraint<C>,
174        output: Out,
175    ) -> Result<(), InsertError>
176    where
177        C: Call,
178    {
179        let (immutable, mutable) = constraint.take();
180        self.tree.insert(
181            key,
182            immutable,
183            CacheEntry { output, mutable, age: AtomicUsize::new(0) },
184        )
185    }
186}
187
188impl<C, Out> Default for CacheData<C, Out> {
189    fn default() -> Self {
190        Self { tree: CallTree::new() }
191    }
192}