Skip to main content

alkahest_cas/jit/
cache.rs

1//! Content-addressed cache for JIT-compiled functions.
2//!
//! [`CompileCache`] maps `(ExprId, Vec<ExprId>, CompileConfig)` → `Arc<CompiledFn>`.
4//! Because [`ExprPool`] already hash-conses
5//! expressions, `ExprId` is a stable content key: the same expression tree
6//! always produces the same `ExprId`.  This means **the cache key _is_ the
7//! content hash** — no separate hashing of the expression tree is required.
8//!
9//! # Key design
10//!
//! The cache key is `(ExprId, Vec<ExprId>, CompileConfig)` — the root
//! expression, the ordered list of input variables, and the compile
//! configuration.  Two compilations of the same expression with different
//! variable orderings produce separate entries (and separate compiled functions
//! with different argument positions); compilations under different
//! configurations are likewise cached separately.
15//!
16//! # Lifetime of compiled code
17//!
18//! Each cached value is an `Arc<CompiledFn>`.  The compiled code stays alive
19//! as long as any live `Arc` references it — clearing or dropping the cache
20//! does not invalidate `Arc`s already returned to callers.
21//!
22//! # Thread safety
23//!
24//! `CompileCache` itself requires `&mut self` for writes and is therefore
25//! single-owner.  Wrap in `Mutex<CompileCache>` or `RwLock<CompileCache>` for
26//! shared multi-threaded access.  `Arc<CompiledFn>` is `Send + Sync` so
27//! compiled functions can be freely shared across threads after retrieval.
28//!
29//! # Example
30//!
31//! ```
32//! use alkahest_cas::kernel::{Domain, ExprPool};
33//! use alkahest_cas::jit::CompileCache;
34//! use std::sync::Arc;
35//!
36//! let pool = ExprPool::new();
37//! let x = pool.symbol("x", Domain::Real);
38//! let expr = pool.pow(x, pool.integer(2_i32));
39//!
40//! let mut cache = CompileCache::new();
41//!
42//! // First call compiles
43//! let f1 = cache.compile(expr, &[x], &pool).unwrap();
44//! // Second call is a cache hit — same Arc, no recompilation
45//! let f2 = cache.compile(expr, &[x], &pool).unwrap();
46//!
47//! assert!(Arc::ptr_eq(&f1, &f2));
48//! assert!((f1.call(&[3.0]) - 9.0).abs() < 1e-10);
49//! ```
50
51use super::{compile_with, CompileConfig, CompiledFn, JitError};
52use crate::kernel::{ExprId, ExprPool};
53use std::collections::HashMap;
54use std::sync::Arc;
55
56// ---------------------------------------------------------------------------
57// Cache key
58// ---------------------------------------------------------------------------
59
/// Key under which a compiled function is stored:
/// `(expression root, ordered input variables, compile configuration)`.
///
/// The `Vec<ExprId>` captures variable order — two compilations of the same
/// expression with different orderings produce separate entries.  The
/// [`CompileConfig`] component is part of the key as well, so the same
/// `(expression, inputs)` pair compiled under different configurations is
/// stored under different keys.
type CacheKey = (ExprId, Vec<ExprId>, CompileConfig);
65
66// ---------------------------------------------------------------------------
67// CompileCache
68// ---------------------------------------------------------------------------
69
70/// Content-addressed cache of JIT-compiled functions.
71///
72/// See the [module documentation](self) for full details.
73pub struct CompileCache {
74    store: HashMap<CacheKey, Arc<CompiledFn>>,
75    /// Total number of compilations (cache misses + initial compiles).
76    compiles: u64,
77    /// Total number of cache hits.
78    hits: u64,
79}
80
81impl CompileCache {
82    /// Create a new, empty cache.
83    pub fn new() -> Self {
84        Self {
85            store: HashMap::new(),
86            compiles: 0,
87            hits: 0,
88        }
89    }
90
91    /// Compile `expr` with the given `inputs`, returning a cached `Arc<CompiledFn>`.
92    ///
93    /// # Cache behaviour
94    ///
95    /// - **Miss**: the expression is compiled via [`jit::compile`](super::compile)
96    ///   and the result is stored.  Subsequent calls with the same `(expr,
97    ///   inputs)` pair return the cached value immediately.
98    /// - **Hit**: returns `Arc::clone` of the cached value — O(1), no
99    ///   recompilation.
100    ///
101    /// # Errors
102    ///
103    /// Returns `Err(JitError)` only on a cache miss where compilation fails.
104    /// Cache hits never fail.
105    pub fn compile(
106        &mut self,
107        expr: ExprId,
108        inputs: &[ExprId],
109        pool: &ExprPool,
110    ) -> Result<Arc<CompiledFn>, JitError> {
111        self.compile_with(expr, inputs, pool, CompileConfig::default())
112    }
113
114    /// Like [`compile`](Self::compile) but passes [`CompileConfig`] to tier selection.
115    ///
116    /// Use [`CompileConfig::for_batch`] when the cached function will drive a large
117    /// `call_bulk` / `call_batch` sweep so LLVM is selected when available.
118    pub fn compile_with(
119        &mut self,
120        expr: ExprId,
121        inputs: &[ExprId],
122        pool: &ExprPool,
123        config: CompileConfig,
124    ) -> Result<Arc<CompiledFn>, JitError> {
125        let key: CacheKey = (expr, inputs.to_vec(), config);
126        if let Some(cached) = self.store.get(&key) {
127            self.hits += 1;
128            return Ok(Arc::clone(cached));
129        }
130        self.compiles += 1;
131        let compiled = Arc::new(compile_with(expr, inputs, pool, config)?);
132        self.store.insert(key, Arc::clone(&compiled));
133        Ok(compiled)
134    }
135
136    /// Number of `(expr, inputs)` pairs currently cached.
137    pub fn len(&self) -> usize {
138        self.store.len()
139    }
140
141    /// Returns `true` if the cache contains no entries.
142    pub fn is_empty(&self) -> bool {
143        self.store.is_empty()
144    }
145
146    /// Returns `true` if a compiled function for `(expr, inputs)` is cached.
147    pub fn contains(&self, expr: ExprId, inputs: &[ExprId]) -> bool {
148        self.contains_with(expr, inputs, CompileConfig::default())
149    }
150
151    /// Returns `true` if `(expr, inputs, config)` is cached.
152    pub fn contains_with(&self, expr: ExprId, inputs: &[ExprId], config: CompileConfig) -> bool {
153        self.store.contains_key(&(expr, inputs.to_vec(), config))
154    }
155
156    /// Total number of compilations performed (cache misses that succeeded).
157    pub fn compile_count(&self) -> u64 {
158        self.compiles
159    }
160
161    /// Total number of cache hits.
162    pub fn hit_count(&self) -> u64 {
163        self.hits
164    }
165
166    /// Cache hit rate in `[0.0, 1.0]`; `0.0` when no lookups have been made.
167    pub fn hit_rate(&self) -> f64 {
168        let total = self.compiles + self.hits;
169        if total == 0 {
170            0.0
171        } else {
172            self.hits as f64 / total as f64
173        }
174    }
175
176    /// Evict all cached functions, freeing compiled code unless other `Arc`s
177    /// keep them alive.
178    pub fn clear(&mut self) {
179        self.store.clear();
180        // Keep statistics — they describe the lifetime of the cache, not just
181        // the current contents.
182    }
183
184    /// Evict a single entry.  Returns the cached function if it was present.
185    pub fn evict(&mut self, expr: ExprId, inputs: &[ExprId]) -> Option<Arc<CompiledFn>> {
186        self.evict_with(expr, inputs, CompileConfig::default())
187    }
188
189    /// Evict a single entry for the given compile configuration.
190    pub fn evict_with(
191        &mut self,
192        expr: ExprId,
193        inputs: &[ExprId],
194        config: CompileConfig,
195    ) -> Option<Arc<CompiledFn>> {
196        self.store.remove(&(expr, inputs.to_vec(), config))
197    }
198}
199
200impl Default for CompileCache {
201    fn default() -> Self {
202        Self::new()
203    }
204}
205
206// ---------------------------------------------------------------------------
207// Tests
208// ---------------------------------------------------------------------------
209
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::{Domain, ExprPool};

    /// Fresh expression pool for a single test.
    fn p() -> ExprPool {
        ExprPool::new()
    }

    /// A miss compiles and stores; the next identical request is a hit that
    /// returns the very same `Arc`.
    #[test]
    fn cache_miss_then_hit() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.pow(x, pool.integer(2_i32));

        let mut cache = CompileCache::new();
        assert!(cache.is_empty());
        assert_eq!(cache.compile_count(), 0);
        assert_eq!(cache.hit_count(), 0);

        let f1 = cache.compile(expr, &[x], &pool).unwrap();
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.compile_count(), 1);
        assert_eq!(cache.hit_count(), 0);

        let f2 = cache.compile(expr, &[x], &pool).unwrap();
        assert_eq!(cache.len(), 1); // still one entry
        assert_eq!(cache.compile_count(), 1); // no new compile
        assert_eq!(cache.hit_count(), 1);

        // Same Arc — identical pointer
        assert!(Arc::ptr_eq(&f1, &f2));
    }

    /// A function obtained through the cache evaluates correctly.
    #[test]
    fn cache_correct_result() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.pow(x, pool.integer(2_i32));

        let mut cache = CompileCache::new();
        let f = cache.compile(expr, &[x], &pool).unwrap();
        assert!((f.call(&[3.0]) - 9.0).abs() < 1e-10);
        assert!((f.call(&[5.0]) - 25.0).abs() < 1e-10);
    }

    /// Variable order is part of the key and determines argument positions.
    ///
    /// The expression `x + y²` is deliberately asymmetric in `x` and `y`, so
    /// swapping the argument positions changes the result; a symmetric
    /// expression such as `x + y` could not detect a mix-up.
    #[test]
    fn different_var_order_different_entry() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let y = pool.symbol("y", Domain::Real);
        let expr = pool.add(vec![x, pool.pow(y, pool.integer(2_i32))]);

        let mut cache = CompileCache::new();
        let f_xy = cache.compile(expr, &[x, y], &pool).unwrap();
        let f_yx = cache.compile(expr, &[y, x], &pool).unwrap();

        // Different orderings → different cache entries
        assert_eq!(cache.len(), 2);
        assert!(!Arc::ptr_eq(&f_xy, &f_yx));

        // f_xy: inputs[0]=x, inputs[1]=y; call(1.0, 2.0) → x=1, y=2 → 1 + 4 = 5
        assert!((f_xy.call(&[1.0, 2.0]) - 5.0).abs() < 1e-10);
        // f_yx: inputs[0]=y, inputs[1]=x; call(1.0, 2.0) → y=1, x=2 → 2 + 1 = 3
        assert!((f_yx.call(&[1.0, 2.0]) - 3.0).abs() < 1e-10);
    }

    /// Distinct expressions over the same inputs get distinct entries.
    #[test]
    fn different_exprs_different_entries() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let sq = pool.pow(x, pool.integer(2_i32));
        let cube = pool.pow(x, pool.integer(3_i32));

        let mut cache = CompileCache::new();
        let f_sq = cache.compile(sq, &[x], &pool).unwrap();
        let f_cu = cache.compile(cube, &[x], &pool).unwrap();

        assert_eq!(cache.len(), 2);
        assert!(!Arc::ptr_eq(&f_sq, &f_cu));
        assert!((f_sq.call(&[3.0]) - 9.0).abs() < 1e-10);
        assert!((f_cu.call(&[3.0]) - 27.0).abs() < 1e-10);
    }

    /// An `Arc` handed out earlier remains callable after `clear`.
    #[test]
    fn arc_survives_cache_clear() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.pow(x, pool.integer(2_i32));

        let mut cache = CompileCache::new();
        let f = cache.compile(expr, &[x], &pool).unwrap();

        cache.clear();
        assert!(cache.is_empty());

        // f still valid — Arc keeps it alive
        assert!((f.call(&[4.0]) - 16.0).abs() < 1e-10);
    }

    /// `clear` empties the store but leaves the lifetime statistics intact.
    #[test]
    fn clear_preserves_statistics() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.pow(x, pool.integer(2_i32));

        let mut cache = CompileCache::new();
        cache.compile(expr, &[x], &pool).unwrap(); // miss
        cache.compile(expr, &[x], &pool).unwrap(); // hit

        cache.clear();
        assert!(cache.is_empty());
        assert_eq!(cache.compile_count(), 1);
        assert_eq!(cache.hit_count(), 1);

        // The entry is gone, so the next request compiles again.
        cache.compile(expr, &[x], &pool).unwrap();
        assert_eq!(cache.compile_count(), 2);
        assert_eq!(cache.hit_count(), 1);
    }

    /// `evict` removes exactly the requested entry.
    #[test]
    fn evict_removes_single_entry() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let sq = pool.pow(x, pool.integer(2_i32));
        let cube = pool.pow(x, pool.integer(3_i32));

        let mut cache = CompileCache::new();
        cache.compile(sq, &[x], &pool).unwrap();
        cache.compile(cube, &[x], &pool).unwrap();
        assert_eq!(cache.len(), 2);

        let evicted = cache.evict(sq, &[x]);
        assert!(evicted.is_some());
        assert_eq!(cache.len(), 1);
        assert!(!cache.contains(sq, &[x]));
        assert!(cache.contains(cube, &[x]));

        // Evicting an absent entry is a no-op that reports `None`.
        assert!(cache.evict(sq, &[x]).is_none());
        assert_eq!(cache.len(), 1);
    }

    /// `contains` respects both the expression and the input ordering.
    #[test]
    fn contains_checks_key() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let y = pool.symbol("y", Domain::Real);
        let expr = pool.add(vec![x, y]);

        let mut cache = CompileCache::new();
        assert!(!cache.contains(expr, &[x, y]));
        cache.compile(expr, &[x, y], &pool).unwrap();
        assert!(cache.contains(expr, &[x, y]));
        assert!(!cache.contains(expr, &[y, x])); // different order
    }

    /// The plain methods are the `*_with` methods at the default config, so
    /// both families address the same entry.
    #[test]
    fn with_variants_share_default_config_entry() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.pow(x, pool.integer(2_i32));

        let mut cache = CompileCache::new();
        let plain = cache.compile(expr, &[x], &pool).unwrap();
        let explicit = cache
            .compile_with(expr, &[x], &pool, CompileConfig::default())
            .unwrap();

        assert!(Arc::ptr_eq(&plain, &explicit));
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.compile_count(), 1);
        assert_eq!(cache.hit_count(), 1);

        assert!(cache.contains_with(expr, &[x], CompileConfig::default()));
        assert!(cache
            .evict_with(expr, &[x], CompileConfig::default())
            .is_some());
        assert!(cache.is_empty());
        assert!(!cache.contains(expr, &[x]));
    }

    /// `hit_rate` is hits over hits-plus-compiles, and `0.0` before any lookup.
    #[test]
    fn hit_rate_is_correct() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.pow(x, pool.integer(2_i32));

        let mut cache = CompileCache::new();
        assert_eq!(cache.hit_rate(), 0.0);

        cache.compile(expr, &[x], &pool).unwrap(); // miss
        assert_eq!(cache.hit_rate(), 0.0); // 0/1

        cache.compile(expr, &[x], &pool).unwrap(); // hit
        cache.compile(expr, &[x], &pool).unwrap(); // hit

        // 2 hits / 3 total = 2/3
        let rate = cache.hit_rate();
        assert!((rate - 2.0 / 3.0).abs() < 1e-10);
    }
}