Skip to main content

axonml_jit/
cache.rs

//! Function Cache — Hash-Keyed Reuse of Compiled JIT Graphs
//!
//! Implements `FunctionCache`, a bounded `FxHashMap<u64, CompiledFunction>`
//! guarded by a `parking_lot::RwLock` that memoizes codegen output keyed by a
//! structural hash of the IR graph. `hash_graph` walks every `Node`, hashing
//! op discriminant, dtype, shape dims, and per-op payload (input name, output
//! name and source index, constant bit-pattern, scalar bit-pattern for
//! `AddScalar` / `MulScalar`, reshape / transpose / reduction / squeeze /
//! broadcast / cast arguments, and left/right/condition node indices for
//! binary and ternary ops), with a fallback that hashes all input indices for
//! the remaining (e.g. unary) ops. Entries are keyed by that 64-bit hash alone,
//! with no secondary equality check. Provides capacity-bounded eviction on
//! insert (the victim is the first key in hash-map iteration order, i.e.
//! effectively arbitrary, not LRU), graph-keyed `get_by_graph` /
//! `insert_for_graph` helpers, a `CacheStats` view (`entries`, `max_size`,
//! `utilization`), `Default` constructing a 1000-entry cache, and tests
//! covering structural hash equality, insert/get round-trip, capacity-bounded
//! eviction, and stats reporting.
16//!
17//! # File
18//! `crates/axonml-jit/src/cache.rs`
19//!
20//! # Author
21//! Andrew Jewell Sr. — AutomataNexus LLC
22//! ORCID: 0009-0005-2158-7060
23//!
24//! # Updated
25//! April 16, 2026 11:15 PM EST
26//!
27//! # Disclaimer
28//! Use at own risk. This software is provided "as is", without warranty of any
29//! kind, express or implied. The author and AutomataNexus shall not be held
30//! liable for any damages arising from the use of this software.
31
32// =============================================================================
33// Imports
34// =============================================================================
35
36use parking_lot::RwLock;
37use rustc_hash::FxHashMap;
38use std::collections::hash_map::DefaultHasher;
39use std::hash::{Hash, Hasher};
40
41use crate::codegen::CompiledFunction;
42use crate::ir::Graph;
43
44// =============================================================================
45// FunctionCache
46// =============================================================================
47
48/// Cache for compiled functions.
49pub struct FunctionCache {
50    cache: RwLock<FxHashMap<u64, CompiledFunction>>,
51    max_size: usize,
52}
53
54impl FunctionCache {
55    /// Creates a new function cache.
56    pub fn new(max_size: usize) -> Self {
57        Self {
58            cache: RwLock::new(FxHashMap::default()),
59            max_size,
60        }
61    }
62
63    /// Creates a cache with default size (1000).
64    pub fn default_size() -> Self {
65        Self::new(1000)
66    }
67
68    // -------------------------------------------------------------------------
69    // Graph Hashing
70    // -------------------------------------------------------------------------
71
72    /// Computes a hash key for a graph.
73    pub fn hash_graph(graph: &Graph) -> u64 {
74        let mut hasher = DefaultHasher::new();
75
76        // Hash graph structure
77        for node in graph.nodes() {
78            // Hash op type
79            std::mem::discriminant(&node.op).hash(&mut hasher);
80
81            // Hash dtype
82            node.dtype.hash(&mut hasher);
83
84            // Hash shape
85            node.shape.dims().hash(&mut hasher);
86
87            // Hash op-specific data
88            match &node.op {
89                crate::ir::Op::Input { name } => name.hash(&mut hasher),
90                crate::ir::Op::Output { name, input } => {
91                    name.hash(&mut hasher);
92                    input.index().hash(&mut hasher);
93                }
94                crate::ir::Op::Constant { value } => {
95                    value.to_bits().hash(&mut hasher);
96                }
97                crate::ir::Op::AddScalar { input, scalar }
98                | crate::ir::Op::MulScalar { input, scalar } => {
99                    input.index().hash(&mut hasher);
100                    scalar.to_bits().hash(&mut hasher);
101                }
102                crate::ir::Op::Reshape { input, shape } => {
103                    input.index().hash(&mut hasher);
104                    shape.hash(&mut hasher);
105                }
106                crate::ir::Op::Transpose { input, dim0, dim1 } => {
107                    input.index().hash(&mut hasher);
108                    dim0.hash(&mut hasher);
109                    dim1.hash(&mut hasher);
110                }
111                crate::ir::Op::SumAxis {
112                    input,
113                    axis,
114                    keepdim,
115                }
116                | crate::ir::Op::MeanAxis {
117                    input,
118                    axis,
119                    keepdim,
120                }
121                | crate::ir::Op::MaxAxis {
122                    input,
123                    axis,
124                    keepdim,
125                } => {
126                    input.index().hash(&mut hasher);
127                    axis.hash(&mut hasher);
128                    keepdim.hash(&mut hasher);
129                }
130                crate::ir::Op::Squeeze { input, dim } | crate::ir::Op::Unsqueeze { input, dim } => {
131                    input.index().hash(&mut hasher);
132                    dim.hash(&mut hasher);
133                }
134                crate::ir::Op::Broadcast { input, shape } => {
135                    input.index().hash(&mut hasher);
136                    shape.hash(&mut hasher);
137                }
138                crate::ir::Op::Cast { input, dtype } => {
139                    input.index().hash(&mut hasher);
140                    dtype.hash(&mut hasher);
141                }
142                // Binary ops
143                crate::ir::Op::Add { lhs, rhs }
144                | crate::ir::Op::Sub { lhs, rhs }
145                | crate::ir::Op::Mul { lhs, rhs }
146                | crate::ir::Op::Div { lhs, rhs }
147                | crate::ir::Op::Pow {
148                    base: lhs,
149                    exp: rhs,
150                }
151                | crate::ir::Op::Max { lhs, rhs }
152                | crate::ir::Op::Min { lhs, rhs }
153                | crate::ir::Op::MatMul { lhs, rhs }
154                | crate::ir::Op::Gt { lhs, rhs }
155                | crate::ir::Op::Lt { lhs, rhs }
156                | crate::ir::Op::Eq { lhs, rhs } => {
157                    lhs.index().hash(&mut hasher);
158                    rhs.index().hash(&mut hasher);
159                }
160                // Ternary ops
161                crate::ir::Op::Where { condition, x, y } => {
162                    condition.index().hash(&mut hasher);
163                    x.index().hash(&mut hasher);
164                    y.index().hash(&mut hasher);
165                }
166                // Unary ops just hash input
167                _ => {
168                    for input in node.op.inputs() {
169                        input.index().hash(&mut hasher);
170                    }
171                }
172            }
173        }
174
175        hasher.finish()
176    }
177
178    // -------------------------------------------------------------------------
179    // Lookup and Insertion
180    // -------------------------------------------------------------------------
181
182    /// Gets a cached function or returns None.
183    pub fn get(&self, key: u64) -> Option<CompiledFunction> {
184        self.cache.read().get(&key).cloned()
185    }
186
187    /// Gets a cached function by graph.
188    pub fn get_by_graph(&self, graph: &Graph) -> Option<CompiledFunction> {
189        let key = Self::hash_graph(graph);
190        self.get(key)
191    }
192
193    /// Inserts a compiled function.
194    pub fn insert(&self, key: u64, func: CompiledFunction) {
195        let mut cache = self.cache.write();
196
197        // Evict if at capacity
198        if cache.len() >= self.max_size {
199            // Simple eviction: remove first entry
200            if let Some(&first_key) = cache.keys().next() {
201                cache.remove(&first_key);
202            }
203        }
204
205        cache.insert(key, func);
206    }
207
208    /// Inserts a compiled function for a graph.
209    pub fn insert_for_graph(&self, graph: &Graph, func: CompiledFunction) {
210        let key = Self::hash_graph(graph);
211        self.insert(key, func);
212    }
213
214    // -------------------------------------------------------------------------
215    // Introspection and Maintenance
216    // -------------------------------------------------------------------------
217
218    /// Returns the number of cached functions.
219    pub fn len(&self) -> usize {
220        self.cache.read().len()
221    }
222
223    /// Returns whether the cache is empty.
224    pub fn is_empty(&self) -> bool {
225        self.cache.read().is_empty()
226    }
227
228    /// Clears the cache.
229    pub fn clear(&self) {
230        self.cache.write().clear();
231    }
232
233    /// Returns cache statistics.
234    pub fn stats(&self) -> CacheStats {
235        CacheStats {
236            entries: self.len(),
237            max_size: self.max_size,
238        }
239    }
240}
241
242impl Default for FunctionCache {
243    fn default() -> Self {
244        Self::default_size()
245    }
246}
247
248// =============================================================================
249// CacheStats
250// =============================================================================
251
/// Point-in-time statistics for a function cache.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CacheStats {
    /// Number of cached entries.
    pub entries: usize,
    /// Maximum cache size.
    pub max_size: usize,
}

impl CacheStats {
    /// Returns `entries` as a percentage of `max_size`.
    ///
    /// A zero `max_size` reports `0.0` instead of dividing by zero.
    #[must_use]
    pub fn utilization(&self) -> f64 {
        if self.max_size == 0 {
            return 0.0;
        }
        (self.entries as f64 / self.max_size as f64) * 100.0
    }
}
271
272// =============================================================================
273// Tests
274// =============================================================================
275
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trace::trace;

    /// Structurally identical graphs hash equally; changing an op or an
    /// input name changes the hash.
    #[test]
    fn test_graph_hash() {
        let graph1 = trace(|tracer| {
            let a = tracer.input("a", &[2, 3]);
            let b = tracer.input("b", &[2, 3]);
            let c = a.add(&b);
            tracer.output("result", c)
        });

        let graph2 = trace(|tracer| {
            let a = tracer.input("a", &[2, 3]);
            let b = tracer.input("b", &[2, 3]);
            let c = a.add(&b);
            tracer.output("result", c)
        });

        let graph3 = trace(|tracer| {
            let a = tracer.input("a", &[2, 3]);
            let b = tracer.input("b", &[2, 3]);
            let c = a.mul(&b); // Different op
            tracer.output("result", c)
        });

        let graph4 = trace(|tracer| {
            let a = tracer.input("x", &[2, 3]); // Different input name
            let b = tracer.input("b", &[2, 3]);
            let c = a.add(&b);
            tracer.output("result", c)
        });

        let hash1 = FunctionCache::hash_graph(&graph1);
        let hash2 = FunctionCache::hash_graph(&graph2);
        let hash3 = FunctionCache::hash_graph(&graph3);
        let hash4 = FunctionCache::hash_graph(&graph4);

        // Same structure should have same hash
        assert_eq!(hash1, hash2);
        // Different structure should have different hash
        assert_ne!(hash1, hash3);
        // Input names are part of the structural key
        assert_ne!(hash1, hash4);
    }

    /// Raw-key insert followed by get round-trips.
    #[test]
    fn test_cache_insert_get() {
        let cache = FunctionCache::new(10);

        let graph = trace(|tracer| {
            let a = tracer.input("a", &[2, 3]);
            tracer.output("result", a.relu())
        });

        let key = FunctionCache::hash_graph(&graph);

        assert!(cache.get(key).is_none());
        cache.insert(key, CompiledFunction::placeholder());
        assert!(cache.get(key).is_some());
    }

    /// The graph-keyed helpers agree with the raw-key API.
    #[test]
    fn test_cache_graph_helpers() {
        let cache = FunctionCache::new(10);

        let graph = trace(|tracer| {
            let a = tracer.input("a", &[2, 3]);
            tracer.output("result", a.relu())
        });

        assert!(cache.get_by_graph(&graph).is_none());
        cache.insert_for_graph(&graph, CompiledFunction::placeholder());
        assert!(cache.get_by_graph(&graph).is_some());
        assert!(cache.get(FunctionCache::hash_graph(&graph)).is_some());
    }

    /// Inserting past capacity keeps the size bounded and retains the newest key.
    #[test]
    fn test_cache_eviction() {
        let cache = FunctionCache::new(2);

        for key in 0..3u64 {
            cache.insert(key, CompiledFunction::placeholder());
        }

        // Cache should only have 2 entries
        assert_eq!(cache.len(), 2);
        // The entry just inserted is never the eviction victim.
        assert!(cache.get(2).is_some());
    }

    /// `clear` empties the cache.
    #[test]
    fn test_cache_clear() {
        let cache = FunctionCache::new(4);
        assert!(cache.is_empty());

        cache.insert(7, CompiledFunction::placeholder());
        assert!(!cache.is_empty());

        cache.clear();
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
    }

    /// Stats report entry count, capacity, and utilization percentage.
    #[test]
    fn test_cache_stats() {
        let cache = FunctionCache::new(100);
        cache.insert(1, CompiledFunction::placeholder());
        cache.insert(2, CompiledFunction::placeholder());

        let stats = cache.stats();
        assert_eq!(stats.entries, 2);
        assert_eq!(stats.max_size, 100);
        assert!((stats.utilization() - 2.0).abs() < 0.01);
    }
}