Skip to main content

pounce_common/
cached.rs

//! Tag-keyed result cache.
//!
//! Mirrors `Common/IpCachedResults.hpp`. Ipopt's `CachedResults<T>`
//! stores a most-recent-first list of (dependency tags, scalar
//! dependencies, value) tuples and evicts the oldest insertion once the
//! size limit is exceeded (lookups do not reorder the list, so this is
//! not an LRU). On lookup it returns the newest value whose dependency
//! vector matches the current tags of the depended-on `TaggedObject`s
//! and whose scalar dependencies compare equal.
//!
//! Differences from upstream:
//! - We do not implement the `Observer/Subject` invalidation push
//!   path. Upstream uses it to drop stale entries early; correctness
//!   only requires the pull-side check at lookup time, which we keep.
//!   Stale entries therefore keep occupying slots until they are
//!   evicted or the cache is cleared.
//! - Negative `max_cache_size` (= unbounded) is expressed with
//!   `Cache::unbounded()` — same semantics as Ipopt's negative size.
14
15use crate::tagged::{Tag, TaggedObject};
16use crate::types::Number;
17use std::collections::VecDeque;
18
19/// One entry in the cache: stored value plus the dependency tags and
20/// scalar dependencies it was computed against.
21#[derive(Debug, Clone)]
22struct Entry<T> {
23    value: T,
24    dep_tags: Vec<Tag>,
25    scalar_deps: Vec<Number>,
26}
27
28impl<T> Entry<T> {
29    fn matches(&self, dep_tags: &[Tag], scalar_deps: &[Number]) -> bool {
30        if self.dep_tags.len() != dep_tags.len() || self.scalar_deps.len() != scalar_deps.len() {
31            return false;
32        }
33        for (a, b) in self.dep_tags.iter().zip(dep_tags.iter()) {
34            if a != b {
35                return false;
36            }
37        }
38        for (a, b) in self.scalar_deps.iter().zip(scalar_deps.iter()) {
39            // Matches Ipopt: bit-equality via float `!=` comparison.
40            if a != b {
41                return false;
42            }
43        }
44        true
45    }
46}
47
48/// LRU cache keyed on dependency tags + scalar dependencies. `T` is
49/// the cached value type (`Number`, a `Vec<Number>`, an `Rc<Vector>`,
50/// ...). Equivalent to `Ipopt::CachedResults<T>`.
51#[derive(Debug)]
52pub struct Cache<T> {
53    /// `None` means unbounded.
54    max_size: Option<usize>,
55    /// Front = most-recently inserted, matching Ipopt's `push_front`.
56    entries: VecDeque<Entry<T>>,
57}
58
59impl<T: Clone> Cache<T> {
60    /// Bounded cache holding up to `max_size` entries.
61    pub fn new(max_size: usize) -> Self {
62        Self {
63            max_size: Some(max_size),
64            entries: VecDeque::new(),
65        }
66    }
67
68    /// Equivalent to Ipopt's negative `max_cache_size` (no eviction).
69    pub fn unbounded() -> Self {
70        Self {
71            max_size: None,
72            entries: VecDeque::new(),
73        }
74    }
75
76    pub fn clear(&mut self) {
77        self.entries.clear();
78    }
79
80    pub fn len(&self) -> usize {
81        self.entries.len()
82    }
83
84    pub fn is_empty(&self) -> bool {
85        self.entries.is_empty()
86    }
87
88    /// Generic add — equivalent to `AddCachedResult(result, dependents, scalar_dependents)`.
89    pub fn add(
90        &mut self,
91        value: T,
92        dependents: &[&dyn TaggedObject],
93        scalar_dependents: &[Number],
94    ) {
95        let dep_tags: Vec<Tag> = dependents.iter().map(|d| d.get_tag()).collect();
96        self.add_with_tags(value, dep_tags, scalar_dependents.to_vec());
97    }
98
99    fn add_with_tags(&mut self, value: T, dep_tags: Vec<Tag>, scalar_deps: Vec<Number>) {
100        self.entries.push_front(Entry {
101            value,
102            dep_tags,
103            scalar_deps,
104        });
105        if let Some(max) = self.max_size {
106            while self.entries.len() > max {
107                self.entries.pop_back();
108            }
109        }
110    }
111
112    /// Generic lookup — equivalent to `GetCachedResult(...)`. Returns
113    /// `Some(value)` if a stored entry's dependency tags exactly match
114    /// the current tags of `dependents` and the scalar deps match.
115    pub fn get(&self, dependents: &[&dyn TaggedObject], scalar_dependents: &[Number]) -> Option<T> {
116        let dep_tags: Vec<Tag> = dependents.iter().map(|d| d.get_tag()).collect();
117        for e in &self.entries {
118            if e.matches(&dep_tags, scalar_dependents) {
119                return Some(e.value.clone());
120            }
121        }
122        None
123    }
124
125    pub fn add_1dep(&mut self, value: T, dep: &dyn TaggedObject) {
126        self.add(value, &[dep], &[]);
127    }
128
129    pub fn get_1dep(&self, dep: &dyn TaggedObject) -> Option<T> {
130        self.get(&[dep], &[])
131    }
132
133    pub fn add_2dep(&mut self, value: T, d1: &dyn TaggedObject, d2: &dyn TaggedObject) {
134        self.add(value, &[d1, d2], &[]);
135    }
136
137    pub fn get_2dep(&self, d1: &dyn TaggedObject, d2: &dyn TaggedObject) -> Option<T> {
138        self.get(&[d1, d2], &[])
139    }
140
141    pub fn add_3dep(
142        &mut self,
143        value: T,
144        d1: &dyn TaggedObject,
145        d2: &dyn TaggedObject,
146        d3: &dyn TaggedObject,
147    ) {
148        self.add(value, &[d1, d2, d3], &[]);
149    }
150
151    pub fn get_3dep(
152        &self,
153        d1: &dyn TaggedObject,
154        d2: &dyn TaggedObject,
155        d3: &dyn TaggedObject,
156    ) -> Option<T> {
157        self.get(&[d1, d2, d3], &[])
158    }
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tagged::TaggedCell;

    #[test]
    fn hit_then_miss_after_bump() {
        let dep = TaggedCell::new();
        let mut cache: Cache<f64> = Cache::new(4);
        cache.add_1dep(2.5, &dep);
        assert_eq!(cache.get_1dep(&dep), Some(2.5));
        dep.bump();
        assert_eq!(cache.get_1dep(&dep), None);
    }

    #[test]
    fn evicts_oldest_insertion() {
        let d1 = TaggedCell::new();
        let d2 = TaggedCell::new();
        let d3 = TaggedCell::new();
        let mut cache: Cache<i32> = Cache::new(2);
        cache.add_1dep(1, &d1);
        cache.add_1dep(2, &d2);
        cache.add_1dep(3, &d3);
        assert_eq!(cache.get_1dep(&d1), None); // evicted
        assert_eq!(cache.get_1dep(&d2), Some(2));
        assert_eq!(cache.get_1dep(&d3), Some(3));
    }

    #[test]
    fn hit_does_not_protect_from_eviction() {
        let d1 = TaggedCell::new();
        let d2 = TaggedCell::new();
        let d3 = TaggedCell::new();
        let mut cache: Cache<i32> = Cache::new(2);
        cache.add_1dep(1, &d1);
        cache.add_1dep(2, &d2);
        // Lookups do not reorder entries: eviction follows insertion order
        // (as in Ipopt), so the entry just hit is still the oldest one.
        assert_eq!(cache.get_1dep(&d1), Some(1));
        cache.add_1dep(3, &d3);
        assert_eq!(cache.get_1dep(&d1), None);
        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn zero_capacity_stores_nothing() {
        let dep = TaggedCell::new();
        let mut cache: Cache<i32> = Cache::new(0);
        cache.add_1dep(1, &dep);
        assert!(cache.is_empty());
        assert_eq!(cache.get_1dep(&dep), None);
    }

    #[test]
    fn newest_matching_entry_wins() {
        let dep = TaggedCell::new();
        let mut cache: Cache<i32> = Cache::new(4);
        cache.add_1dep(1, &dep);
        cache.add_1dep(2, &dep);
        assert_eq!(cache.get_1dep(&dep), Some(2));
    }

    #[test]
    fn unbounded_keeps_all() {
        let deps: Vec<TaggedCell> = (0..32).map(|_| TaggedCell::new()).collect();
        let mut cache: Cache<i32> = Cache::unbounded();
        for (i, d) in deps.iter().enumerate() {
            cache.add_1dep(i as i32, d);
        }
        for (i, d) in deps.iter().enumerate() {
            assert_eq!(cache.get_1dep(d), Some(i as i32));
        }
    }

    #[test]
    fn scalar_dep_distinguishes_entries() {
        let dep = TaggedCell::new();
        let mut cache: Cache<i32> = Cache::new(8);
        cache.add(10, &[&dep], &[1.0]);
        cache.add(20, &[&dep], &[2.0]);
        assert_eq!(cache.get(&[&dep], &[1.0]), Some(10));
        assert_eq!(cache.get(&[&dep], &[2.0]), Some(20));
        assert_eq!(cache.get(&[&dep], &[3.0]), None);
    }

    #[test]
    fn scalar_deps_use_ieee_equality() {
        let dep = TaggedCell::new();
        let mut cache: Cache<i32> = Cache::new(4);
        // `NaN != NaN`, so an entry keyed on NaN can never be hit.
        cache.add(1, &[&dep], &[Number::NAN]);
        assert_eq!(cache.get(&[&dep], &[Number::NAN]), None);
        // Signed zeros compare equal, so they share an entry.
        cache.add(2, &[&dep], &[0.0]);
        assert_eq!(cache.get(&[&dep], &[-0.0]), Some(2));
    }
}