1use std::hash::{Hash, Hasher};
6use std::collections::hash_map::DefaultHasher;
7use parking_lot::RwLock;
8use rustc_hash::FxHashMap;
9
10use crate::ir::Graph;
11use crate::codegen::CompiledFunction;
12
13pub struct FunctionCache {
15 cache: RwLock<FxHashMap<u64, CompiledFunction>>,
16 max_size: usize,
17}
18
19impl FunctionCache {
20 pub fn new(max_size: usize) -> Self {
22 Self {
23 cache: RwLock::new(FxHashMap::default()),
24 max_size,
25 }
26 }
27
28 pub fn default_size() -> Self {
30 Self::new(1000)
31 }
32
33 pub fn hash_graph(graph: &Graph) -> u64 {
35 let mut hasher = DefaultHasher::new();
36
37 for node in graph.nodes() {
39 std::mem::discriminant(&node.op).hash(&mut hasher);
41
42 node.dtype.hash(&mut hasher);
44
45 node.shape.dims().hash(&mut hasher);
47
48 match &node.op {
50 crate::ir::Op::Input { name } => name.hash(&mut hasher),
51 crate::ir::Op::Output { name, input } => {
52 name.hash(&mut hasher);
53 input.index().hash(&mut hasher);
54 }
55 crate::ir::Op::Constant { value } => {
56 value.to_bits().hash(&mut hasher);
57 }
58 crate::ir::Op::AddScalar { input, scalar } |
59 crate::ir::Op::MulScalar { input, scalar } => {
60 input.index().hash(&mut hasher);
61 scalar.to_bits().hash(&mut hasher);
62 }
63 crate::ir::Op::Reshape { input, shape } => {
64 input.index().hash(&mut hasher);
65 shape.hash(&mut hasher);
66 }
67 crate::ir::Op::Transpose { input, dim0, dim1 } => {
68 input.index().hash(&mut hasher);
69 dim0.hash(&mut hasher);
70 dim1.hash(&mut hasher);
71 }
72 crate::ir::Op::SumAxis { input, axis, keepdim } |
73 crate::ir::Op::MeanAxis { input, axis, keepdim } |
74 crate::ir::Op::MaxAxis { input, axis, keepdim } => {
75 input.index().hash(&mut hasher);
76 axis.hash(&mut hasher);
77 keepdim.hash(&mut hasher);
78 }
79 crate::ir::Op::Squeeze { input, dim } |
80 crate::ir::Op::Unsqueeze { input, dim } => {
81 input.index().hash(&mut hasher);
82 dim.hash(&mut hasher);
83 }
84 crate::ir::Op::Broadcast { input, shape } => {
85 input.index().hash(&mut hasher);
86 shape.hash(&mut hasher);
87 }
88 crate::ir::Op::Cast { input, dtype } => {
89 input.index().hash(&mut hasher);
90 dtype.hash(&mut hasher);
91 }
92 crate::ir::Op::Add { lhs, rhs } |
94 crate::ir::Op::Sub { lhs, rhs } |
95 crate::ir::Op::Mul { lhs, rhs } |
96 crate::ir::Op::Div { lhs, rhs } |
97 crate::ir::Op::Pow { base: lhs, exp: rhs } |
98 crate::ir::Op::Max { lhs, rhs } |
99 crate::ir::Op::Min { lhs, rhs } |
100 crate::ir::Op::MatMul { lhs, rhs } |
101 crate::ir::Op::Gt { lhs, rhs } |
102 crate::ir::Op::Lt { lhs, rhs } |
103 crate::ir::Op::Eq { lhs, rhs } => {
104 lhs.index().hash(&mut hasher);
105 rhs.index().hash(&mut hasher);
106 }
107 crate::ir::Op::Where { condition, x, y } => {
109 condition.index().hash(&mut hasher);
110 x.index().hash(&mut hasher);
111 y.index().hash(&mut hasher);
112 }
113 _ => {
115 for input in node.op.inputs() {
116 input.index().hash(&mut hasher);
117 }
118 }
119 }
120 }
121
122 hasher.finish()
123 }
124
125 pub fn get(&self, key: u64) -> Option<CompiledFunction> {
127 self.cache.read().get(&key).cloned()
128 }
129
130 pub fn get_by_graph(&self, graph: &Graph) -> Option<CompiledFunction> {
132 let key = Self::hash_graph(graph);
133 self.get(key)
134 }
135
136 pub fn insert(&self, key: u64, func: CompiledFunction) {
138 let mut cache = self.cache.write();
139
140 if cache.len() >= self.max_size {
142 if let Some(&first_key) = cache.keys().next() {
144 cache.remove(&first_key);
145 }
146 }
147
148 cache.insert(key, func);
149 }
150
151 pub fn insert_for_graph(&self, graph: &Graph, func: CompiledFunction) {
153 let key = Self::hash_graph(graph);
154 self.insert(key, func);
155 }
156
157 pub fn len(&self) -> usize {
159 self.cache.read().len()
160 }
161
162 pub fn is_empty(&self) -> bool {
164 self.cache.read().is_empty()
165 }
166
167 pub fn clear(&self) {
169 self.cache.write().clear();
170 }
171
172 pub fn stats(&self) -> CacheStats {
174 CacheStats {
175 entries: self.len(),
176 max_size: self.max_size,
177 }
178 }
179}
180
181impl Default for FunctionCache {
182 fn default() -> Self {
183 Self::default_size()
184 }
185}
186
/// Point-in-time snapshot of a `FunctionCache`'s occupancy.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct CacheStats {
    /// Number of entries held when the snapshot was taken.
    pub entries: usize,
    /// Configured capacity of the cache.
    pub max_size: usize,
}

impl CacheStats {
    /// Returns occupancy as a percentage of capacity: `entries / max_size * 100`.
    ///
    /// Returns `0.0` when `max_size` is zero, to avoid dividing by zero.
    #[must_use]
    pub fn utilization(&self) -> f64 {
        if self.max_size == 0 {
            0.0
        } else {
            (self.entries as f64 / self.max_size as f64) * 100.0
        }
    }
}
206
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trace::trace;

    #[test]
    fn test_graph_hash() {
        // Two independently traced but structurally identical `a + b` graphs, and a
        // structurally different `a * b` graph.
        let add_graph = trace(|tracer| {
            let lhs = tracer.input("a", &[2, 3]);
            let rhs = tracer.input("b", &[2, 3]);
            let sum = lhs.add(&rhs);
            tracer.output("result", sum)
        });
        let add_graph_again = trace(|tracer| {
            let lhs = tracer.input("a", &[2, 3]);
            let rhs = tracer.input("b", &[2, 3]);
            let sum = lhs.add(&rhs);
            tracer.output("result", sum)
        });
        let mul_graph = trace(|tracer| {
            let lhs = tracer.input("a", &[2, 3]);
            let rhs = tracer.input("b", &[2, 3]);
            let product = lhs.mul(&rhs);
            tracer.output("result", product)
        });

        let add_hash = FunctionCache::hash_graph(&add_graph);

        assert_eq!(add_hash, FunctionCache::hash_graph(&add_graph_again));
        assert_ne!(add_hash, FunctionCache::hash_graph(&mul_graph));
    }

    #[test]
    fn test_cache_insert_get() {
        let graph = trace(|tracer| {
            let x = tracer.input("a", &[2, 3]);
            tracer.output("result", x.relu())
        });
        let cache = FunctionCache::new(10);
        let key = FunctionCache::hash_graph(&graph);

        assert!(cache.get(key).is_none(), "cache should start empty");
        cache.insert(key, CompiledFunction::placeholder());
        assert!(cache.get(key).is_some(), "inserted entry should be retrievable");
    }

    #[test]
    fn test_cache_eviction() {
        let cache = FunctionCache::new(2);

        (0..3u64).for_each(|key| cache.insert(key, CompiledFunction::placeholder()));

        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn test_cache_stats() {
        let cache = FunctionCache::new(100);
        for key in 1..=2u64 {
            cache.insert(key, CompiledFunction::placeholder());
        }

        let stats = cache.stats();
        assert_eq!((stats.entries, stats.max_size), (2, 100));
        assert!((stats.utilization() - 2.0).abs() < 0.01);
    }
}