1use parking_lot::RwLock;
18use rustc_hash::FxHashMap;
19use std::collections::hash_map::DefaultHasher;
20use std::hash::{Hash, Hasher};
21
22use crate::codegen::CompiledFunction;
23use crate::ir::Graph;
24
25pub struct FunctionCache {
27 cache: RwLock<FxHashMap<u64, CompiledFunction>>,
28 max_size: usize,
29}
30
31impl FunctionCache {
32 pub fn new(max_size: usize) -> Self {
34 Self {
35 cache: RwLock::new(FxHashMap::default()),
36 max_size,
37 }
38 }
39
40 pub fn default_size() -> Self {
42 Self::new(1000)
43 }
44
45 pub fn hash_graph(graph: &Graph) -> u64 {
47 let mut hasher = DefaultHasher::new();
48
49 for node in graph.nodes() {
51 std::mem::discriminant(&node.op).hash(&mut hasher);
53
54 node.dtype.hash(&mut hasher);
56
57 node.shape.dims().hash(&mut hasher);
59
60 match &node.op {
62 crate::ir::Op::Input { name } => name.hash(&mut hasher),
63 crate::ir::Op::Output { name, input } => {
64 name.hash(&mut hasher);
65 input.index().hash(&mut hasher);
66 }
67 crate::ir::Op::Constant { value } => {
68 value.to_bits().hash(&mut hasher);
69 }
70 crate::ir::Op::AddScalar { input, scalar }
71 | crate::ir::Op::MulScalar { input, scalar } => {
72 input.index().hash(&mut hasher);
73 scalar.to_bits().hash(&mut hasher);
74 }
75 crate::ir::Op::Reshape { input, shape } => {
76 input.index().hash(&mut hasher);
77 shape.hash(&mut hasher);
78 }
79 crate::ir::Op::Transpose { input, dim0, dim1 } => {
80 input.index().hash(&mut hasher);
81 dim0.hash(&mut hasher);
82 dim1.hash(&mut hasher);
83 }
84 crate::ir::Op::SumAxis {
85 input,
86 axis,
87 keepdim,
88 }
89 | crate::ir::Op::MeanAxis {
90 input,
91 axis,
92 keepdim,
93 }
94 | crate::ir::Op::MaxAxis {
95 input,
96 axis,
97 keepdim,
98 } => {
99 input.index().hash(&mut hasher);
100 axis.hash(&mut hasher);
101 keepdim.hash(&mut hasher);
102 }
103 crate::ir::Op::Squeeze { input, dim } | crate::ir::Op::Unsqueeze { input, dim } => {
104 input.index().hash(&mut hasher);
105 dim.hash(&mut hasher);
106 }
107 crate::ir::Op::Broadcast { input, shape } => {
108 input.index().hash(&mut hasher);
109 shape.hash(&mut hasher);
110 }
111 crate::ir::Op::Cast { input, dtype } => {
112 input.index().hash(&mut hasher);
113 dtype.hash(&mut hasher);
114 }
115 crate::ir::Op::Add { lhs, rhs }
117 | crate::ir::Op::Sub { lhs, rhs }
118 | crate::ir::Op::Mul { lhs, rhs }
119 | crate::ir::Op::Div { lhs, rhs }
120 | crate::ir::Op::Pow {
121 base: lhs,
122 exp: rhs,
123 }
124 | crate::ir::Op::Max { lhs, rhs }
125 | crate::ir::Op::Min { lhs, rhs }
126 | crate::ir::Op::MatMul { lhs, rhs }
127 | crate::ir::Op::Gt { lhs, rhs }
128 | crate::ir::Op::Lt { lhs, rhs }
129 | crate::ir::Op::Eq { lhs, rhs } => {
130 lhs.index().hash(&mut hasher);
131 rhs.index().hash(&mut hasher);
132 }
133 crate::ir::Op::Where { condition, x, y } => {
135 condition.index().hash(&mut hasher);
136 x.index().hash(&mut hasher);
137 y.index().hash(&mut hasher);
138 }
139 _ => {
141 for input in node.op.inputs() {
142 input.index().hash(&mut hasher);
143 }
144 }
145 }
146 }
147
148 hasher.finish()
149 }
150
151 pub fn get(&self, key: u64) -> Option<CompiledFunction> {
153 self.cache.read().get(&key).cloned()
154 }
155
156 pub fn get_by_graph(&self, graph: &Graph) -> Option<CompiledFunction> {
158 let key = Self::hash_graph(graph);
159 self.get(key)
160 }
161
162 pub fn insert(&self, key: u64, func: CompiledFunction) {
164 let mut cache = self.cache.write();
165
166 if cache.len() >= self.max_size {
168 if let Some(&first_key) = cache.keys().next() {
170 cache.remove(&first_key);
171 }
172 }
173
174 cache.insert(key, func);
175 }
176
177 pub fn insert_for_graph(&self, graph: &Graph, func: CompiledFunction) {
179 let key = Self::hash_graph(graph);
180 self.insert(key, func);
181 }
182
183 pub fn len(&self) -> usize {
185 self.cache.read().len()
186 }
187
188 pub fn is_empty(&self) -> bool {
190 self.cache.read().is_empty()
191 }
192
193 pub fn clear(&self) {
195 self.cache.write().clear();
196 }
197
198 pub fn stats(&self) -> CacheStats {
200 CacheStats {
201 entries: self.len(),
202 max_size: self.max_size,
203 }
204 }
205}
206
207impl Default for FunctionCache {
208 fn default() -> Self {
209 Self::default_size()
210 }
211}
212
/// Point-in-time snapshot of a `FunctionCache`'s occupancy, as produced by
/// `FunctionCache::stats`.
///
/// It is a plain value type, so it also derives equality for easy comparison
/// in tests and monitoring code.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CacheStats {
    /// Number of entries stored when the snapshot was taken.
    pub entries: usize,
    /// Configured capacity of the cache.
    pub max_size: usize,
}
221
222impl CacheStats {
223 pub fn utilization(&self) -> f64 {
225 if self.max_size == 0 {
226 0.0
227 } else {
228 (self.entries as f64 / self.max_size as f64) * 100.0
229 }
230 }
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trace::trace;

    /// Identically traced graphs must hash equally; swapping the op must change the hash.
    #[test]
    fn test_graph_hash() {
        let added = trace(|t| {
            let lhs = t.input("a", &[2, 3]);
            let rhs = t.input("b", &[2, 3]);
            let sum = lhs.add(&rhs);
            t.output("result", sum)
        });

        let added_again = trace(|t| {
            let lhs = t.input("a", &[2, 3]);
            let rhs = t.input("b", &[2, 3]);
            let sum = lhs.add(&rhs);
            t.output("result", sum)
        });

        let multiplied = trace(|t| {
            let lhs = t.input("a", &[2, 3]);
            let rhs = t.input("b", &[2, 3]);
            let product = lhs.mul(&rhs);
            t.output("result", product)
        });

        let (h_add, h_add_again, h_mul) = (
            FunctionCache::hash_graph(&added),
            FunctionCache::hash_graph(&added_again),
            FunctionCache::hash_graph(&multiplied),
        );

        assert_eq!(h_add, h_add_again);
        assert_ne!(h_add, h_mul);
    }

    /// A freshly inserted entry becomes visible to `get`.
    #[test]
    fn test_cache_insert_get() {
        let cache = FunctionCache::new(10);

        let relu_graph = trace(|t| {
            let x = t.input("a", &[2, 3]);
            t.output("result", x.relu())
        });
        let key = FunctionCache::hash_graph(&relu_graph);

        assert!(cache.get(key).is_none());
        cache.insert(key, CompiledFunction::placeholder());
        assert!(cache.get(key).is_some());
    }

    /// Inserting past capacity keeps the entry count at the limit.
    #[test]
    fn test_cache_eviction() {
        let cache = FunctionCache::new(2);

        for key in 0u64..3 {
            cache.insert(key, CompiledFunction::placeholder());
        }

        assert_eq!(cache.len(), 2);
    }

    /// `stats` reports entry count, capacity and percentage utilization.
    #[test]
    fn test_cache_stats() {
        let cache = FunctionCache::new(100);
        [1u64, 2].iter().for_each(|&key| cache.insert(key, CompiledFunction::placeholder()));

        let snapshot = cache.stats();
        assert_eq!(snapshot.entries, 2);
        assert_eq!(snapshot.max_size, 100);
        assert!((snapshot.utilization() - 2.0).abs() < 0.01);
    }
}