Skip to main content

axonml_jit/
cache.rs

1//! Function Cache
2//!
3//! Caches compiled functions for reuse.
4
5use std::hash::{Hash, Hasher};
6use std::collections::hash_map::DefaultHasher;
7use parking_lot::RwLock;
8use rustc_hash::FxHashMap;
9
10use crate::ir::Graph;
11use crate::codegen::CompiledFunction;
12
13/// Cache for compiled functions.
14pub struct FunctionCache {
15    cache: RwLock<FxHashMap<u64, CompiledFunction>>,
16    max_size: usize,
17}
18
19impl FunctionCache {
20    /// Creates a new function cache.
21    pub fn new(max_size: usize) -> Self {
22        Self {
23            cache: RwLock::new(FxHashMap::default()),
24            max_size,
25        }
26    }
27
28    /// Creates a cache with default size (1000).
29    pub fn default_size() -> Self {
30        Self::new(1000)
31    }
32
33    /// Computes a hash key for a graph.
34    pub fn hash_graph(graph: &Graph) -> u64 {
35        let mut hasher = DefaultHasher::new();
36
37        // Hash graph structure
38        for node in graph.nodes() {
39            // Hash op type
40            std::mem::discriminant(&node.op).hash(&mut hasher);
41
42            // Hash dtype
43            node.dtype.hash(&mut hasher);
44
45            // Hash shape
46            node.shape.dims().hash(&mut hasher);
47
48            // Hash op-specific data
49            match &node.op {
50                crate::ir::Op::Input { name } => name.hash(&mut hasher),
51                crate::ir::Op::Output { name, input } => {
52                    name.hash(&mut hasher);
53                    input.index().hash(&mut hasher);
54                }
55                crate::ir::Op::Constant { value } => {
56                    value.to_bits().hash(&mut hasher);
57                }
58                crate::ir::Op::AddScalar { input, scalar } |
59                crate::ir::Op::MulScalar { input, scalar } => {
60                    input.index().hash(&mut hasher);
61                    scalar.to_bits().hash(&mut hasher);
62                }
63                crate::ir::Op::Reshape { input, shape } => {
64                    input.index().hash(&mut hasher);
65                    shape.hash(&mut hasher);
66                }
67                crate::ir::Op::Transpose { input, dim0, dim1 } => {
68                    input.index().hash(&mut hasher);
69                    dim0.hash(&mut hasher);
70                    dim1.hash(&mut hasher);
71                }
72                crate::ir::Op::SumAxis { input, axis, keepdim } |
73                crate::ir::Op::MeanAxis { input, axis, keepdim } |
74                crate::ir::Op::MaxAxis { input, axis, keepdim } => {
75                    input.index().hash(&mut hasher);
76                    axis.hash(&mut hasher);
77                    keepdim.hash(&mut hasher);
78                }
79                crate::ir::Op::Squeeze { input, dim } |
80                crate::ir::Op::Unsqueeze { input, dim } => {
81                    input.index().hash(&mut hasher);
82                    dim.hash(&mut hasher);
83                }
84                crate::ir::Op::Broadcast { input, shape } => {
85                    input.index().hash(&mut hasher);
86                    shape.hash(&mut hasher);
87                }
88                crate::ir::Op::Cast { input, dtype } => {
89                    input.index().hash(&mut hasher);
90                    dtype.hash(&mut hasher);
91                }
92                // Binary ops
93                crate::ir::Op::Add { lhs, rhs } |
94                crate::ir::Op::Sub { lhs, rhs } |
95                crate::ir::Op::Mul { lhs, rhs } |
96                crate::ir::Op::Div { lhs, rhs } |
97                crate::ir::Op::Pow { base: lhs, exp: rhs } |
98                crate::ir::Op::Max { lhs, rhs } |
99                crate::ir::Op::Min { lhs, rhs } |
100                crate::ir::Op::MatMul { lhs, rhs } |
101                crate::ir::Op::Gt { lhs, rhs } |
102                crate::ir::Op::Lt { lhs, rhs } |
103                crate::ir::Op::Eq { lhs, rhs } => {
104                    lhs.index().hash(&mut hasher);
105                    rhs.index().hash(&mut hasher);
106                }
107                // Ternary ops
108                crate::ir::Op::Where { condition, x, y } => {
109                    condition.index().hash(&mut hasher);
110                    x.index().hash(&mut hasher);
111                    y.index().hash(&mut hasher);
112                }
113                // Unary ops just hash input
114                _ => {
115                    for input in node.op.inputs() {
116                        input.index().hash(&mut hasher);
117                    }
118                }
119            }
120        }
121
122        hasher.finish()
123    }
124
125    /// Gets a cached function or returns None.
126    pub fn get(&self, key: u64) -> Option<CompiledFunction> {
127        self.cache.read().get(&key).cloned()
128    }
129
130    /// Gets a cached function by graph.
131    pub fn get_by_graph(&self, graph: &Graph) -> Option<CompiledFunction> {
132        let key = Self::hash_graph(graph);
133        self.get(key)
134    }
135
136    /// Inserts a compiled function.
137    pub fn insert(&self, key: u64, func: CompiledFunction) {
138        let mut cache = self.cache.write();
139
140        // Evict if at capacity
141        if cache.len() >= self.max_size {
142            // Simple eviction: remove first entry
143            if let Some(&first_key) = cache.keys().next() {
144                cache.remove(&first_key);
145            }
146        }
147
148        cache.insert(key, func);
149    }
150
151    /// Inserts a compiled function for a graph.
152    pub fn insert_for_graph(&self, graph: &Graph, func: CompiledFunction) {
153        let key = Self::hash_graph(graph);
154        self.insert(key, func);
155    }
156
157    /// Returns the number of cached functions.
158    pub fn len(&self) -> usize {
159        self.cache.read().len()
160    }
161
162    /// Returns whether the cache is empty.
163    pub fn is_empty(&self) -> bool {
164        self.cache.read().is_empty()
165    }
166
167    /// Clears the cache.
168    pub fn clear(&self) {
169        self.cache.write().clear();
170    }
171
172    /// Returns cache statistics.
173    pub fn stats(&self) -> CacheStats {
174        CacheStats {
175            entries: self.len(),
176            max_size: self.max_size,
177        }
178    }
179}
180
181impl Default for FunctionCache {
182    fn default() -> Self {
183        Self::default_size()
184    }
185}
186
/// Point-in-time statistics about a function cache.
///
/// Plain data, so it is `Copy` and comparable (handy for assertions and snapshots).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CacheStats {
    /// Number of cached entries.
    pub entries: usize,
    /// Maximum cache size.
    pub max_size: usize,
}

impl CacheStats {
    /// Returns `entries` as a percentage of `max_size`.
    ///
    /// A zero `max_size` yields `0.0` rather than dividing by zero.
    pub fn utilization(&self) -> f64 {
        if self.max_size == 0 {
            0.0
        } else {
            (self.entries as f64 / self.max_size as f64) * 100.0
        }
    }
}
206
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trace::trace;

    #[test]
    fn test_graph_hash() {
        // Traces `result = a <op> b` over two [2, 3] inputs; `multiply` selects mul over add.
        let build = |multiply: bool| {
            trace(move |tracer| {
                let a = tracer.input("a", &[2, 3]);
                let b = tracer.input("b", &[2, 3]);
                let c = if multiply { a.mul(&b) } else { a.add(&b) };
                tracer.output("result", c)
            })
        };

        let add_first = FunctionCache::hash_graph(&build(false));
        let add_second = FunctionCache::hash_graph(&build(false));
        let mul_key = FunctionCache::hash_graph(&build(true));

        // Identically-built graphs must agree on the key...
        assert_eq!(add_first, add_second);
        // ...while swapping the op must change it.
        assert_ne!(add_first, mul_key);
    }

    #[test]
    fn test_cache_insert_get() {
        let cache = FunctionCache::new(10);

        let relu_graph = trace(|tracer| {
            let x = tracer.input("a", &[2, 3]);
            tracer.output("result", x.relu())
        });
        let key = FunctionCache::hash_graph(&relu_graph);

        assert!(cache.get(key).is_none(), "fresh cache should miss");
        let compiled = CompiledFunction::placeholder();
        cache.insert(key, compiled);
        assert!(cache.get(key).is_some(), "inserted key should hit");
    }

    #[test]
    fn test_cache_eviction() {
        let cache = FunctionCache::new(2);

        (0u64..3).for_each(|key| cache.insert(key, CompiledFunction::placeholder()));

        // Capacity is two, so the third distinct key must have displaced one entry.
        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn test_cache_stats() {
        let cache = FunctionCache::new(100);
        for key in 1..=2u64 {
            cache.insert(key, CompiledFunction::placeholder());
        }

        let stats = cache.stats();
        assert_eq!((stats.entries, stats.max_size), (2, 100));
        assert!((stats.utilization() - 2.0).abs() < 0.01);
    }
}