Skip to main content

axonml_jit/
cache.rs

1//! Function Cache
2//!
3//! Caches compiled functions for reuse.
4
5use parking_lot::RwLock;
6use rustc_hash::FxHashMap;
7use std::collections::hash_map::DefaultHasher;
8use std::hash::{Hash, Hasher};
9
10use crate::codegen::CompiledFunction;
11use crate::ir::Graph;
12
13/// Cache for compiled functions.
14pub struct FunctionCache {
15    cache: RwLock<FxHashMap<u64, CompiledFunction>>,
16    max_size: usize,
17}
18
19impl FunctionCache {
20    /// Creates a new function cache.
21    pub fn new(max_size: usize) -> Self {
22        Self {
23            cache: RwLock::new(FxHashMap::default()),
24            max_size,
25        }
26    }
27
28    /// Creates a cache with default size (1000).
29    pub fn default_size() -> Self {
30        Self::new(1000)
31    }
32
33    /// Computes a hash key for a graph.
34    pub fn hash_graph(graph: &Graph) -> u64 {
35        let mut hasher = DefaultHasher::new();
36
37        // Hash graph structure
38        for node in graph.nodes() {
39            // Hash op type
40            std::mem::discriminant(&node.op).hash(&mut hasher);
41
42            // Hash dtype
43            node.dtype.hash(&mut hasher);
44
45            // Hash shape
46            node.shape.dims().hash(&mut hasher);
47
48            // Hash op-specific data
49            match &node.op {
50                crate::ir::Op::Input { name } => name.hash(&mut hasher),
51                crate::ir::Op::Output { name, input } => {
52                    name.hash(&mut hasher);
53                    input.index().hash(&mut hasher);
54                }
55                crate::ir::Op::Constant { value } => {
56                    value.to_bits().hash(&mut hasher);
57                }
58                crate::ir::Op::AddScalar { input, scalar }
59                | crate::ir::Op::MulScalar { input, scalar } => {
60                    input.index().hash(&mut hasher);
61                    scalar.to_bits().hash(&mut hasher);
62                }
63                crate::ir::Op::Reshape { input, shape } => {
64                    input.index().hash(&mut hasher);
65                    shape.hash(&mut hasher);
66                }
67                crate::ir::Op::Transpose { input, dim0, dim1 } => {
68                    input.index().hash(&mut hasher);
69                    dim0.hash(&mut hasher);
70                    dim1.hash(&mut hasher);
71                }
72                crate::ir::Op::SumAxis {
73                    input,
74                    axis,
75                    keepdim,
76                }
77                | crate::ir::Op::MeanAxis {
78                    input,
79                    axis,
80                    keepdim,
81                }
82                | crate::ir::Op::MaxAxis {
83                    input,
84                    axis,
85                    keepdim,
86                } => {
87                    input.index().hash(&mut hasher);
88                    axis.hash(&mut hasher);
89                    keepdim.hash(&mut hasher);
90                }
91                crate::ir::Op::Squeeze { input, dim } | crate::ir::Op::Unsqueeze { input, dim } => {
92                    input.index().hash(&mut hasher);
93                    dim.hash(&mut hasher);
94                }
95                crate::ir::Op::Broadcast { input, shape } => {
96                    input.index().hash(&mut hasher);
97                    shape.hash(&mut hasher);
98                }
99                crate::ir::Op::Cast { input, dtype } => {
100                    input.index().hash(&mut hasher);
101                    dtype.hash(&mut hasher);
102                }
103                // Binary ops
104                crate::ir::Op::Add { lhs, rhs }
105                | crate::ir::Op::Sub { lhs, rhs }
106                | crate::ir::Op::Mul { lhs, rhs }
107                | crate::ir::Op::Div { lhs, rhs }
108                | crate::ir::Op::Pow {
109                    base: lhs,
110                    exp: rhs,
111                }
112                | crate::ir::Op::Max { lhs, rhs }
113                | crate::ir::Op::Min { lhs, rhs }
114                | crate::ir::Op::MatMul { lhs, rhs }
115                | crate::ir::Op::Gt { lhs, rhs }
116                | crate::ir::Op::Lt { lhs, rhs }
117                | crate::ir::Op::Eq { lhs, rhs } => {
118                    lhs.index().hash(&mut hasher);
119                    rhs.index().hash(&mut hasher);
120                }
121                // Ternary ops
122                crate::ir::Op::Where { condition, x, y } => {
123                    condition.index().hash(&mut hasher);
124                    x.index().hash(&mut hasher);
125                    y.index().hash(&mut hasher);
126                }
127                // Unary ops just hash input
128                _ => {
129                    for input in node.op.inputs() {
130                        input.index().hash(&mut hasher);
131                    }
132                }
133            }
134        }
135
136        hasher.finish()
137    }
138
139    /// Gets a cached function or returns None.
140    pub fn get(&self, key: u64) -> Option<CompiledFunction> {
141        self.cache.read().get(&key).cloned()
142    }
143
144    /// Gets a cached function by graph.
145    pub fn get_by_graph(&self, graph: &Graph) -> Option<CompiledFunction> {
146        let key = Self::hash_graph(graph);
147        self.get(key)
148    }
149
150    /// Inserts a compiled function.
151    pub fn insert(&self, key: u64, func: CompiledFunction) {
152        let mut cache = self.cache.write();
153
154        // Evict if at capacity
155        if cache.len() >= self.max_size {
156            // Simple eviction: remove first entry
157            if let Some(&first_key) = cache.keys().next() {
158                cache.remove(&first_key);
159            }
160        }
161
162        cache.insert(key, func);
163    }
164
165    /// Inserts a compiled function for a graph.
166    pub fn insert_for_graph(&self, graph: &Graph, func: CompiledFunction) {
167        let key = Self::hash_graph(graph);
168        self.insert(key, func);
169    }
170
171    /// Returns the number of cached functions.
172    pub fn len(&self) -> usize {
173        self.cache.read().len()
174    }
175
176    /// Returns whether the cache is empty.
177    pub fn is_empty(&self) -> bool {
178        self.cache.read().is_empty()
179    }
180
181    /// Clears the cache.
182    pub fn clear(&self) {
183        self.cache.write().clear();
184    }
185
186    /// Returns cache statistics.
187    pub fn stats(&self) -> CacheStats {
188        CacheStats {
189            entries: self.len(),
190            max_size: self.max_size,
191        }
192    }
193}
194
195impl Default for FunctionCache {
196    fn default() -> Self {
197        Self::default_size()
198    }
199}
200
/// Point-in-time occupancy snapshot of a function cache.
///
/// A plain value type: cheap to copy and directly comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CacheStats {
    /// Number of cached entries.
    pub entries: usize,
    /// Maximum cache size.
    pub max_size: usize,
}

impl CacheStats {
    /// Returns `entries` as a percentage of `max_size`.
    ///
    /// A `max_size` of zero yields `0.0` instead of dividing by zero.
    #[must_use]
    pub fn utilization(&self) -> f64 {
        if self.max_size == 0 {
            return 0.0;
        }
        (self.entries as f64 / self.max_size as f64) * 100.0
    }
}
220
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trace::trace;

    #[test]
    fn test_graph_hash() {
        // Two identical `a + b` traces and one `a * b` trace.
        let add_first = trace(|tracer| {
            let lhs = tracer.input("a", &[2, 3]);
            let rhs = tracer.input("b", &[2, 3]);
            let sum = lhs.add(&rhs);
            tracer.output("result", sum)
        });
        let add_second = trace(|tracer| {
            let lhs = tracer.input("a", &[2, 3]);
            let rhs = tracer.input("b", &[2, 3]);
            let sum = lhs.add(&rhs);
            tracer.output("result", sum)
        });
        let mul = trace(|tracer| {
            let lhs = tracer.input("a", &[2, 3]);
            let rhs = tracer.input("b", &[2, 3]);
            let product = lhs.mul(&rhs);
            tracer.output("result", product)
        });

        let (key_add_first, key_add_second, key_mul) = (
            FunctionCache::hash_graph(&add_first),
            FunctionCache::hash_graph(&add_second),
            FunctionCache::hash_graph(&mul),
        );

        // Structurally identical graphs must share a key...
        assert_eq!(key_add_first, key_add_second);
        // ...while swapping the op must change it.
        assert_ne!(key_add_first, key_mul);
    }

    #[test]
    fn test_cache_insert_get() {
        let graph = trace(|tracer| {
            let x = tracer.input("a", &[2, 3]);
            tracer.output("result", x.relu())
        });
        let cache = FunctionCache::new(10);
        let key = FunctionCache::hash_graph(&graph);

        assert!(cache.get(key).is_none());
        cache.insert(key, CompiledFunction::placeholder());
        assert!(cache.get(key).is_some());
    }

    #[test]
    fn test_cache_eviction() {
        let cache = FunctionCache::new(2);

        // Three distinct keys into a two-slot cache force one eviction.
        for key in 0u64..3 {
            cache.insert(key, CompiledFunction::placeholder());
        }

        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn test_cache_stats() {
        let cache = FunctionCache::new(100);
        cache.insert(1, CompiledFunction::placeholder());
        cache.insert(2, CompiledFunction::placeholder());

        let stats = cache.stats();
        let CacheStats { entries, max_size } = &stats;
        assert_eq!(*entries, 2);
        assert_eq!(*max_size, 100);
        assert!((stats.utilization() - 2.0).abs() < 0.01);
    }
}