Skip to main content

axonml_jit/
cache.rs

1//! Function Cache
2//!
3//! # File
4//! `crates/axonml-jit/src/cache.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr - AutomataNexus
8//!
9//! # Updated
10//! March 8, 2026
11//!
12//! # Disclaimer
13//! Use at own risk. This software is provided "as is", without warranty of any
14//! kind, express or implied. The author and AutomataNexus shall not be held
15//! liable for any damages arising from the use of this software.
16
17use parking_lot::RwLock;
18use rustc_hash::FxHashMap;
19use std::collections::hash_map::DefaultHasher;
20use std::hash::{Hash, Hasher};
21
22use crate::codegen::CompiledFunction;
23use crate::ir::Graph;
24
/// Cache for compiled functions.
///
/// Entries are stored under a caller-supplied `u64` key; the `*_for_graph`
/// helpers derive that key with [`FunctionCache::hash_graph`]. The map lives
/// behind a `RwLock`, so every method takes `&self`.
pub struct FunctionCache {
    /// Compiled functions indexed by their `u64` key.
    cache: RwLock<FxHashMap<u64, CompiledFunction>>,
    /// Capacity threshold consulted by `insert` when deciding whether to evict.
    max_size: usize,
}
30
31impl FunctionCache {
32    /// Creates a new function cache.
33    pub fn new(max_size: usize) -> Self {
34        Self {
35            cache: RwLock::new(FxHashMap::default()),
36            max_size,
37        }
38    }
39
40    /// Creates a cache with default size (1000).
41    pub fn default_size() -> Self {
42        Self::new(1000)
43    }
44
45    /// Computes a hash key for a graph.
46    pub fn hash_graph(graph: &Graph) -> u64 {
47        let mut hasher = DefaultHasher::new();
48
49        // Hash graph structure
50        for node in graph.nodes() {
51            // Hash op type
52            std::mem::discriminant(&node.op).hash(&mut hasher);
53
54            // Hash dtype
55            node.dtype.hash(&mut hasher);
56
57            // Hash shape
58            node.shape.dims().hash(&mut hasher);
59
60            // Hash op-specific data
61            match &node.op {
62                crate::ir::Op::Input { name } => name.hash(&mut hasher),
63                crate::ir::Op::Output { name, input } => {
64                    name.hash(&mut hasher);
65                    input.index().hash(&mut hasher);
66                }
67                crate::ir::Op::Constant { value } => {
68                    value.to_bits().hash(&mut hasher);
69                }
70                crate::ir::Op::AddScalar { input, scalar }
71                | crate::ir::Op::MulScalar { input, scalar } => {
72                    input.index().hash(&mut hasher);
73                    scalar.to_bits().hash(&mut hasher);
74                }
75                crate::ir::Op::Reshape { input, shape } => {
76                    input.index().hash(&mut hasher);
77                    shape.hash(&mut hasher);
78                }
79                crate::ir::Op::Transpose { input, dim0, dim1 } => {
80                    input.index().hash(&mut hasher);
81                    dim0.hash(&mut hasher);
82                    dim1.hash(&mut hasher);
83                }
84                crate::ir::Op::SumAxis {
85                    input,
86                    axis,
87                    keepdim,
88                }
89                | crate::ir::Op::MeanAxis {
90                    input,
91                    axis,
92                    keepdim,
93                }
94                | crate::ir::Op::MaxAxis {
95                    input,
96                    axis,
97                    keepdim,
98                } => {
99                    input.index().hash(&mut hasher);
100                    axis.hash(&mut hasher);
101                    keepdim.hash(&mut hasher);
102                }
103                crate::ir::Op::Squeeze { input, dim } | crate::ir::Op::Unsqueeze { input, dim } => {
104                    input.index().hash(&mut hasher);
105                    dim.hash(&mut hasher);
106                }
107                crate::ir::Op::Broadcast { input, shape } => {
108                    input.index().hash(&mut hasher);
109                    shape.hash(&mut hasher);
110                }
111                crate::ir::Op::Cast { input, dtype } => {
112                    input.index().hash(&mut hasher);
113                    dtype.hash(&mut hasher);
114                }
115                // Binary ops
116                crate::ir::Op::Add { lhs, rhs }
117                | crate::ir::Op::Sub { lhs, rhs }
118                | crate::ir::Op::Mul { lhs, rhs }
119                | crate::ir::Op::Div { lhs, rhs }
120                | crate::ir::Op::Pow {
121                    base: lhs,
122                    exp: rhs,
123                }
124                | crate::ir::Op::Max { lhs, rhs }
125                | crate::ir::Op::Min { lhs, rhs }
126                | crate::ir::Op::MatMul { lhs, rhs }
127                | crate::ir::Op::Gt { lhs, rhs }
128                | crate::ir::Op::Lt { lhs, rhs }
129                | crate::ir::Op::Eq { lhs, rhs } => {
130                    lhs.index().hash(&mut hasher);
131                    rhs.index().hash(&mut hasher);
132                }
133                // Ternary ops
134                crate::ir::Op::Where { condition, x, y } => {
135                    condition.index().hash(&mut hasher);
136                    x.index().hash(&mut hasher);
137                    y.index().hash(&mut hasher);
138                }
139                // Unary ops just hash input
140                _ => {
141                    for input in node.op.inputs() {
142                        input.index().hash(&mut hasher);
143                    }
144                }
145            }
146        }
147
148        hasher.finish()
149    }
150
151    /// Gets a cached function or returns None.
152    pub fn get(&self, key: u64) -> Option<CompiledFunction> {
153        self.cache.read().get(&key).cloned()
154    }
155
156    /// Gets a cached function by graph.
157    pub fn get_by_graph(&self, graph: &Graph) -> Option<CompiledFunction> {
158        let key = Self::hash_graph(graph);
159        self.get(key)
160    }
161
162    /// Inserts a compiled function.
163    pub fn insert(&self, key: u64, func: CompiledFunction) {
164        let mut cache = self.cache.write();
165
166        // Evict if at capacity
167        if cache.len() >= self.max_size {
168            // Simple eviction: remove first entry
169            if let Some(&first_key) = cache.keys().next() {
170                cache.remove(&first_key);
171            }
172        }
173
174        cache.insert(key, func);
175    }
176
177    /// Inserts a compiled function for a graph.
178    pub fn insert_for_graph(&self, graph: &Graph, func: CompiledFunction) {
179        let key = Self::hash_graph(graph);
180        self.insert(key, func);
181    }
182
183    /// Returns the number of cached functions.
184    pub fn len(&self) -> usize {
185        self.cache.read().len()
186    }
187
188    /// Returns whether the cache is empty.
189    pub fn is_empty(&self) -> bool {
190        self.cache.read().is_empty()
191    }
192
193    /// Clears the cache.
194    pub fn clear(&self) {
195        self.cache.write().clear();
196    }
197
198    /// Returns cache statistics.
199    pub fn stats(&self) -> CacheStats {
200        CacheStats {
201            entries: self.len(),
202            max_size: self.max_size,
203        }
204    }
205}
206
207impl Default for FunctionCache {
208    fn default() -> Self {
209        Self::default_size()
210    }
211}
212
/// Cache statistics.
///
/// A plain value snapshot: it holds the counts copied at the moment it was
/// built and is not updated afterwards.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of cached entries.
    pub entries: usize,
    /// Maximum cache size.
    pub max_size: usize,
}
221
222impl CacheStats {
223    /// Returns the utilization as a percentage.
224    pub fn utilization(&self) -> f64 {
225        if self.max_size == 0 {
226            0.0
227        } else {
228            (self.entries as f64 / self.max_size as f64) * 100.0
229        }
230    }
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trace::trace;

    /// Structurally identical graphs hash equal; swapping the op changes the hash.
    #[test]
    fn test_graph_hash() {
        let add_first = trace(|tracer| {
            let a = tracer.input("a", &[2, 3]);
            let b = tracer.input("b", &[2, 3]);
            let c = a.add(&b);
            tracer.output("result", c)
        });
        let add_second = trace(|tracer| {
            let a = tracer.input("a", &[2, 3]);
            let b = tracer.input("b", &[2, 3]);
            let c = a.add(&b);
            tracer.output("result", c)
        });
        // Same inputs and output, but multiplies instead of adding.
        let mul_variant = trace(|tracer| {
            let a = tracer.input("a", &[2, 3]);
            let b = tracer.input("b", &[2, 3]);
            let c = a.mul(&b);
            tracer.output("result", c)
        });

        let add_hash = FunctionCache::hash_graph(&add_first);

        // Same structure should have same hash
        assert_eq!(add_hash, FunctionCache::hash_graph(&add_second));
        // Different structure should have different hash
        assert_ne!(add_hash, FunctionCache::hash_graph(&mul_variant));
    }

    /// A key is absent before `insert` and present afterwards.
    #[test]
    fn test_cache_insert_get() {
        let relu_graph = trace(|tracer| {
            let a = tracer.input("a", &[2, 3]);
            tracer.output("result", a.relu())
        });
        let key = FunctionCache::hash_graph(&relu_graph);

        let cache = FunctionCache::new(10);
        assert!(cache.get(key).is_none());

        cache.insert(key, CompiledFunction::placeholder());
        assert!(cache.get(key).is_some());
    }

    /// Inserting three distinct keys into a two-slot cache leaves two entries.
    #[test]
    fn test_cache_eviction() {
        let cache = FunctionCache::new(2);

        for key in 0..3u64 {
            cache.insert(key, CompiledFunction::placeholder());
        }

        // Cache should only have 2 entries
        assert_eq!(cache.len(), 2);
    }

    /// `stats` reports the entry count, the configured size and their ratio.
    #[test]
    fn test_cache_stats() {
        let cache = FunctionCache::new(100);
        for key in [1u64, 2] {
            cache.insert(key, CompiledFunction::placeholder());
        }

        let snapshot = cache.stats();
        assert_eq!(snapshot.entries, 2);
        assert_eq!(snapshot.max_size, 100);
        assert!((snapshot.utilization() - 2.0).abs() < 0.01);
    }
}
311}