1use parking_lot::RwLock;
37use rustc_hash::FxHashMap;
38use std::collections::hash_map::DefaultHasher;
39use std::hash::{Hash, Hasher};
40
41use crate::codegen::CompiledFunction;
42use crate::ir::Graph;
43
44pub struct FunctionCache {
50 cache: RwLock<FxHashMap<u64, CompiledFunction>>,
51 max_size: usize,
52}
53
54impl FunctionCache {
55 pub fn new(max_size: usize) -> Self {
57 Self {
58 cache: RwLock::new(FxHashMap::default()),
59 max_size,
60 }
61 }
62
63 pub fn default_size() -> Self {
65 Self::new(1000)
66 }
67
68 pub fn hash_graph(graph: &Graph) -> u64 {
74 let mut hasher = DefaultHasher::new();
75
76 for node in graph.nodes() {
78 std::mem::discriminant(&node.op).hash(&mut hasher);
80
81 node.dtype.hash(&mut hasher);
83
84 node.shape.dims().hash(&mut hasher);
86
87 match &node.op {
89 crate::ir::Op::Input { name } => name.hash(&mut hasher),
90 crate::ir::Op::Output { name, input } => {
91 name.hash(&mut hasher);
92 input.index().hash(&mut hasher);
93 }
94 crate::ir::Op::Constant { value } => {
95 value.to_bits().hash(&mut hasher);
96 }
97 crate::ir::Op::AddScalar { input, scalar }
98 | crate::ir::Op::MulScalar { input, scalar } => {
99 input.index().hash(&mut hasher);
100 scalar.to_bits().hash(&mut hasher);
101 }
102 crate::ir::Op::Reshape { input, shape } => {
103 input.index().hash(&mut hasher);
104 shape.hash(&mut hasher);
105 }
106 crate::ir::Op::Transpose { input, dim0, dim1 } => {
107 input.index().hash(&mut hasher);
108 dim0.hash(&mut hasher);
109 dim1.hash(&mut hasher);
110 }
111 crate::ir::Op::SumAxis {
112 input,
113 axis,
114 keepdim,
115 }
116 | crate::ir::Op::MeanAxis {
117 input,
118 axis,
119 keepdim,
120 }
121 | crate::ir::Op::MaxAxis {
122 input,
123 axis,
124 keepdim,
125 } => {
126 input.index().hash(&mut hasher);
127 axis.hash(&mut hasher);
128 keepdim.hash(&mut hasher);
129 }
130 crate::ir::Op::Squeeze { input, dim } | crate::ir::Op::Unsqueeze { input, dim } => {
131 input.index().hash(&mut hasher);
132 dim.hash(&mut hasher);
133 }
134 crate::ir::Op::Broadcast { input, shape } => {
135 input.index().hash(&mut hasher);
136 shape.hash(&mut hasher);
137 }
138 crate::ir::Op::Cast { input, dtype } => {
139 input.index().hash(&mut hasher);
140 dtype.hash(&mut hasher);
141 }
142 crate::ir::Op::Add { lhs, rhs }
144 | crate::ir::Op::Sub { lhs, rhs }
145 | crate::ir::Op::Mul { lhs, rhs }
146 | crate::ir::Op::Div { lhs, rhs }
147 | crate::ir::Op::Pow {
148 base: lhs,
149 exp: rhs,
150 }
151 | crate::ir::Op::Max { lhs, rhs }
152 | crate::ir::Op::Min { lhs, rhs }
153 | crate::ir::Op::MatMul { lhs, rhs }
154 | crate::ir::Op::Gt { lhs, rhs }
155 | crate::ir::Op::Lt { lhs, rhs }
156 | crate::ir::Op::Eq { lhs, rhs } => {
157 lhs.index().hash(&mut hasher);
158 rhs.index().hash(&mut hasher);
159 }
160 crate::ir::Op::Where { condition, x, y } => {
162 condition.index().hash(&mut hasher);
163 x.index().hash(&mut hasher);
164 y.index().hash(&mut hasher);
165 }
166 _ => {
168 for input in node.op.inputs() {
169 input.index().hash(&mut hasher);
170 }
171 }
172 }
173 }
174
175 hasher.finish()
176 }
177
178 pub fn get(&self, key: u64) -> Option<CompiledFunction> {
184 self.cache.read().get(&key).cloned()
185 }
186
187 pub fn get_by_graph(&self, graph: &Graph) -> Option<CompiledFunction> {
189 let key = Self::hash_graph(graph);
190 self.get(key)
191 }
192
193 pub fn insert(&self, key: u64, func: CompiledFunction) {
195 let mut cache = self.cache.write();
196
197 if cache.len() >= self.max_size {
199 if let Some(&first_key) = cache.keys().next() {
201 cache.remove(&first_key);
202 }
203 }
204
205 cache.insert(key, func);
206 }
207
208 pub fn insert_for_graph(&self, graph: &Graph, func: CompiledFunction) {
210 let key = Self::hash_graph(graph);
211 self.insert(key, func);
212 }
213
214 pub fn len(&self) -> usize {
220 self.cache.read().len()
221 }
222
223 pub fn is_empty(&self) -> bool {
225 self.cache.read().is_empty()
226 }
227
228 pub fn clear(&self) {
230 self.cache.write().clear();
231 }
232
233 pub fn stats(&self) -> CacheStats {
235 CacheStats {
236 entries: self.len(),
237 max_size: self.max_size,
238 }
239 }
240}
241
242impl Default for FunctionCache {
243 fn default() -> Self {
244 Self::default_size()
245 }
246}
247
/// Snapshot of a cache's occupancy: how many entries it held and its capacity.
///
/// This is plain data (two `usize` fields), so it is `Copy` and comparable,
/// which lets callers and tests compare whole snapshots directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CacheStats {
    /// Number of entries recorded in the snapshot.
    pub entries: usize,
    /// Maximum number of entries recorded in the snapshot.
    pub max_size: usize,
}

impl CacheStats {
    /// Returns `entries` as a percentage of `max_size`.
    ///
    /// A `max_size` of `0` yields `0.0` rather than dividing by zero. The
    /// result is not clamped, so it exceeds `100.0` whenever `entries` is
    /// larger than `max_size`.
    #[must_use]
    pub fn utilization(&self) -> f64 {
        if self.max_size == 0 {
            return 0.0;
        }

        (self.entries as f64 / self.max_size as f64) * 100.0
    }
}
271
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trace::trace;

    /// Identically traced graphs hash equal; swapping `add` for `mul`
    /// produces a different hash.
    #[test]
    fn test_graph_hash() {
        let add_graph = trace(|tracer| {
            let lhs = tracer.input("a", &[2, 3]);
            let rhs = tracer.input("b", &[2, 3]);
            let sum = lhs.add(&rhs);
            tracer.output("result", sum)
        });

        let add_graph_again = trace(|tracer| {
            let lhs = tracer.input("a", &[2, 3]);
            let rhs = tracer.input("b", &[2, 3]);
            let sum = lhs.add(&rhs);
            tracer.output("result", sum)
        });

        let mul_graph = trace(|tracer| {
            let lhs = tracer.input("a", &[2, 3]);
            let rhs = tracer.input("b", &[2, 3]);
            let product = lhs.mul(&rhs);
            tracer.output("result", product)
        });

        let add_hash = FunctionCache::hash_graph(&add_graph);
        let add_hash_again = FunctionCache::hash_graph(&add_graph_again);
        let mul_hash = FunctionCache::hash_graph(&mul_graph);

        assert_eq!(add_hash, add_hash_again);
        assert_ne!(add_hash, mul_hash);
    }

    /// A key is absent before `insert` and present afterwards.
    #[test]
    fn test_cache_insert_get() {
        let cache = FunctionCache::new(10);

        let relu_graph = trace(|tracer| {
            let a = tracer.input("a", &[2, 3]);
            tracer.output("result", a.relu())
        });

        let key = FunctionCache::hash_graph(&relu_graph);
        let compiled = CompiledFunction::placeholder();

        assert!(cache.get(key).is_none());
        cache.insert(key, compiled.clone());
        assert!(cache.get(key).is_some());
    }

    /// Inserting three distinct keys into a two-entry cache leaves two entries.
    #[test]
    fn test_cache_eviction() {
        let cache = FunctionCache::new(2);

        for key in 0..3u64 {
            cache.insert(key, CompiledFunction::placeholder());
        }

        assert_eq!(cache.len(), 2);
    }

    /// `stats` reports the entry count, the capacity and the derived percentage.
    #[test]
    fn test_cache_stats() {
        let cache = FunctionCache::new(100);
        for key in [1, 2] {
            cache.insert(key, CompiledFunction::placeholder());
        }

        let snapshot = cache.stats();
        assert_eq!(snapshot.entries, 2);
        assert_eq!(snapshot.max_size, 100);
        assert!((snapshot.utilization() - 2.0).abs() < 0.01);
    }
}