1use parking_lot::RwLock;
6use rustc_hash::FxHashMap;
7use std::collections::hash_map::DefaultHasher;
8use std::hash::{Hash, Hasher};
9
10use crate::codegen::CompiledFunction;
11use crate::ir::Graph;
12
13pub struct FunctionCache {
15 cache: RwLock<FxHashMap<u64, CompiledFunction>>,
16 max_size: usize,
17}
18
19impl FunctionCache {
20 pub fn new(max_size: usize) -> Self {
22 Self {
23 cache: RwLock::new(FxHashMap::default()),
24 max_size,
25 }
26 }
27
28 pub fn default_size() -> Self {
30 Self::new(1000)
31 }
32
33 pub fn hash_graph(graph: &Graph) -> u64 {
35 let mut hasher = DefaultHasher::new();
36
37 for node in graph.nodes() {
39 std::mem::discriminant(&node.op).hash(&mut hasher);
41
42 node.dtype.hash(&mut hasher);
44
45 node.shape.dims().hash(&mut hasher);
47
48 match &node.op {
50 crate::ir::Op::Input { name } => name.hash(&mut hasher),
51 crate::ir::Op::Output { name, input } => {
52 name.hash(&mut hasher);
53 input.index().hash(&mut hasher);
54 }
55 crate::ir::Op::Constant { value } => {
56 value.to_bits().hash(&mut hasher);
57 }
58 crate::ir::Op::AddScalar { input, scalar }
59 | crate::ir::Op::MulScalar { input, scalar } => {
60 input.index().hash(&mut hasher);
61 scalar.to_bits().hash(&mut hasher);
62 }
63 crate::ir::Op::Reshape { input, shape } => {
64 input.index().hash(&mut hasher);
65 shape.hash(&mut hasher);
66 }
67 crate::ir::Op::Transpose { input, dim0, dim1 } => {
68 input.index().hash(&mut hasher);
69 dim0.hash(&mut hasher);
70 dim1.hash(&mut hasher);
71 }
72 crate::ir::Op::SumAxis {
73 input,
74 axis,
75 keepdim,
76 }
77 | crate::ir::Op::MeanAxis {
78 input,
79 axis,
80 keepdim,
81 }
82 | crate::ir::Op::MaxAxis {
83 input,
84 axis,
85 keepdim,
86 } => {
87 input.index().hash(&mut hasher);
88 axis.hash(&mut hasher);
89 keepdim.hash(&mut hasher);
90 }
91 crate::ir::Op::Squeeze { input, dim } | crate::ir::Op::Unsqueeze { input, dim } => {
92 input.index().hash(&mut hasher);
93 dim.hash(&mut hasher);
94 }
95 crate::ir::Op::Broadcast { input, shape } => {
96 input.index().hash(&mut hasher);
97 shape.hash(&mut hasher);
98 }
99 crate::ir::Op::Cast { input, dtype } => {
100 input.index().hash(&mut hasher);
101 dtype.hash(&mut hasher);
102 }
103 crate::ir::Op::Add { lhs, rhs }
105 | crate::ir::Op::Sub { lhs, rhs }
106 | crate::ir::Op::Mul { lhs, rhs }
107 | crate::ir::Op::Div { lhs, rhs }
108 | crate::ir::Op::Pow {
109 base: lhs,
110 exp: rhs,
111 }
112 | crate::ir::Op::Max { lhs, rhs }
113 | crate::ir::Op::Min { lhs, rhs }
114 | crate::ir::Op::MatMul { lhs, rhs }
115 | crate::ir::Op::Gt { lhs, rhs }
116 | crate::ir::Op::Lt { lhs, rhs }
117 | crate::ir::Op::Eq { lhs, rhs } => {
118 lhs.index().hash(&mut hasher);
119 rhs.index().hash(&mut hasher);
120 }
121 crate::ir::Op::Where { condition, x, y } => {
123 condition.index().hash(&mut hasher);
124 x.index().hash(&mut hasher);
125 y.index().hash(&mut hasher);
126 }
127 _ => {
129 for input in node.op.inputs() {
130 input.index().hash(&mut hasher);
131 }
132 }
133 }
134 }
135
136 hasher.finish()
137 }
138
139 pub fn get(&self, key: u64) -> Option<CompiledFunction> {
141 self.cache.read().get(&key).cloned()
142 }
143
144 pub fn get_by_graph(&self, graph: &Graph) -> Option<CompiledFunction> {
146 let key = Self::hash_graph(graph);
147 self.get(key)
148 }
149
150 pub fn insert(&self, key: u64, func: CompiledFunction) {
152 let mut cache = self.cache.write();
153
154 if cache.len() >= self.max_size {
156 if let Some(&first_key) = cache.keys().next() {
158 cache.remove(&first_key);
159 }
160 }
161
162 cache.insert(key, func);
163 }
164
165 pub fn insert_for_graph(&self, graph: &Graph, func: CompiledFunction) {
167 let key = Self::hash_graph(graph);
168 self.insert(key, func);
169 }
170
171 pub fn len(&self) -> usize {
173 self.cache.read().len()
174 }
175
176 pub fn is_empty(&self) -> bool {
178 self.cache.read().is_empty()
179 }
180
181 pub fn clear(&self) {
183 self.cache.write().clear();
184 }
185
186 pub fn stats(&self) -> CacheStats {
188 CacheStats {
189 entries: self.len(),
190 max_size: self.max_size,
191 }
192 }
193}
194
195impl Default for FunctionCache {
196 fn default() -> Self {
197 Self::default_size()
198 }
199}
200
/// Point-in-time occupancy figures for a cache.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of entries held when the snapshot was taken.
    pub entries: usize,
    /// Configured capacity of the cache.
    pub max_size: usize,
}

impl CacheStats {
    /// Returns occupancy as a percentage of capacity (`entries / max_size * 100`).
    ///
    /// A capacity of zero reports `0.0` rather than dividing by zero.
    pub fn utilization(&self) -> f64 {
        match self.max_size {
            0 => 0.0,
            capacity => {
                let ratio = self.entries as f64 / capacity as f64;
                ratio * 100.0
            }
        }
    }
}
220
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trace::trace;

    /// Structurally identical graphs hash equally; changing the operation
    /// changes the hash.
    #[test]
    fn test_graph_hash() {
        let add_first = trace(|tracer| {
            let lhs = tracer.input("a", &[2, 3]);
            let rhs = tracer.input("b", &[2, 3]);
            let sum = lhs.add(&rhs);
            tracer.output("result", sum)
        });

        let add_second = trace(|tracer| {
            let lhs = tracer.input("a", &[2, 3]);
            let rhs = tracer.input("b", &[2, 3]);
            let sum = lhs.add(&rhs);
            tracer.output("result", sum)
        });

        let multiply = trace(|tracer| {
            let lhs = tracer.input("a", &[2, 3]);
            let rhs = tracer.input("b", &[2, 3]);
            let product = lhs.mul(&rhs);
            tracer.output("result", product)
        });

        let add_first_hash = FunctionCache::hash_graph(&add_first);
        let add_second_hash = FunctionCache::hash_graph(&add_second);
        let multiply_hash = FunctionCache::hash_graph(&multiply);

        assert_eq!(add_first_hash, add_second_hash);
        assert_ne!(add_first_hash, multiply_hash);
    }

    /// A key misses before insertion and hits afterwards.
    #[test]
    fn test_cache_insert_get() {
        let cache = FunctionCache::new(10);

        let relu_graph = trace(|tracer| {
            let x = tracer.input("a", &[2, 3]);
            tracer.output("result", x.relu())
        });

        let key = FunctionCache::hash_graph(&relu_graph);
        let compiled = CompiledFunction::placeholder();

        assert!(cache.get(key).is_none());
        cache.insert(key, compiled.clone());
        assert!(cache.get(key).is_some());
    }

    /// Inserting more distinct keys than the capacity keeps the cache at
    /// capacity.
    #[test]
    fn test_cache_eviction() {
        let cache = FunctionCache::new(2);

        for key in 0..3u64 {
            cache.insert(key, CompiledFunction::placeholder());
        }

        assert_eq!(cache.len(), 2);
    }

    /// `stats` reports the entry count, the capacity, and the percentage used.
    #[test]
    fn test_cache_stats() {
        let cache = FunctionCache::new(100);
        cache.insert(1, CompiledFunction::placeholder());
        cache.insert(2, CompiledFunction::placeholder());

        let snapshot = cache.stats();
        assert_eq!(snapshot.entries, 2);
        assert_eq!(snapshot.max_size, 100);
        assert!((snapshot.utilization() - 2.0).abs() < 0.01);
    }
}