1use super::{compile_with, CompileConfig, CompiledFn, JitError};
52use crate::kernel::{ExprId, ExprPool};
53use std::collections::HashMap;
54use std::sync::Arc;
55
56type CacheKey = (ExprId, Vec<ExprId>, CompileConfig);
65
66pub struct CompileCache {
74 store: HashMap<CacheKey, Arc<CompiledFn>>,
75 compiles: u64,
77 hits: u64,
79}
80
81impl CompileCache {
82 pub fn new() -> Self {
84 Self {
85 store: HashMap::new(),
86 compiles: 0,
87 hits: 0,
88 }
89 }
90
91 pub fn compile(
106 &mut self,
107 expr: ExprId,
108 inputs: &[ExprId],
109 pool: &ExprPool,
110 ) -> Result<Arc<CompiledFn>, JitError> {
111 self.compile_with(expr, inputs, pool, CompileConfig::default())
112 }
113
114 pub fn compile_with(
119 &mut self,
120 expr: ExprId,
121 inputs: &[ExprId],
122 pool: &ExprPool,
123 config: CompileConfig,
124 ) -> Result<Arc<CompiledFn>, JitError> {
125 let key: CacheKey = (expr, inputs.to_vec(), config);
126 if let Some(cached) = self.store.get(&key) {
127 self.hits += 1;
128 return Ok(Arc::clone(cached));
129 }
130 self.compiles += 1;
131 let compiled = Arc::new(compile_with(expr, inputs, pool, config)?);
132 self.store.insert(key, Arc::clone(&compiled));
133 Ok(compiled)
134 }
135
136 pub fn len(&self) -> usize {
138 self.store.len()
139 }
140
141 pub fn is_empty(&self) -> bool {
143 self.store.is_empty()
144 }
145
146 pub fn contains(&self, expr: ExprId, inputs: &[ExprId]) -> bool {
148 self.contains_with(expr, inputs, CompileConfig::default())
149 }
150
151 pub fn contains_with(&self, expr: ExprId, inputs: &[ExprId], config: CompileConfig) -> bool {
153 self.store.contains_key(&(expr, inputs.to_vec(), config))
154 }
155
156 pub fn compile_count(&self) -> u64 {
158 self.compiles
159 }
160
161 pub fn hit_count(&self) -> u64 {
163 self.hits
164 }
165
166 pub fn hit_rate(&self) -> f64 {
168 let total = self.compiles + self.hits;
169 if total == 0 {
170 0.0
171 } else {
172 self.hits as f64 / total as f64
173 }
174 }
175
176 pub fn clear(&mut self) {
179 self.store.clear();
180 }
183
184 pub fn evict(&mut self, expr: ExprId, inputs: &[ExprId]) -> Option<Arc<CompiledFn>> {
186 self.evict_with(expr, inputs, CompileConfig::default())
187 }
188
189 pub fn evict_with(
191 &mut self,
192 expr: ExprId,
193 inputs: &[ExprId],
194 config: CompileConfig,
195 ) -> Option<Arc<CompiledFn>> {
196 self.store.remove(&(expr, inputs.to_vec(), config))
197 }
198}
199
200impl Default for CompileCache {
201 fn default() -> Self {
202 Self::new()
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::{Domain, ExprPool};

    /// Fresh expression pool for each test.
    fn p() -> ExprPool {
        ExprPool::new()
    }

    #[test]
    fn cache_miss_then_hit() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.pow(x, pool.integer(2_i32));

        let mut cache = CompileCache::new();
        assert!(cache.is_empty());
        assert_eq!(cache.compile_count(), 0);
        assert_eq!(cache.hit_count(), 0);

        let f1 = cache.compile(expr, &[x], &pool).unwrap();
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.compile_count(), 1);
        assert_eq!(cache.hit_count(), 0);

        let f2 = cache.compile(expr, &[x], &pool).unwrap();
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.compile_count(), 1);
        assert_eq!(cache.hit_count(), 1);

        // A hit hands back the very same compiled function, not a recompile.
        assert!(Arc::ptr_eq(&f1, &f2));
    }

    #[test]
    fn cache_correct_result() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.pow(x, pool.integer(2_i32));

        let mut cache = CompileCache::new();
        let f = cache.compile(expr, &[x], &pool).unwrap();
        assert!((f.call(&[3.0]) - 9.0).abs() < 1e-10);
        assert!((f.call(&[5.0]) - 25.0).abs() < 1e-10);
    }

    #[test]
    fn different_var_order_different_entry() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let y = pool.symbol("y", Domain::Real);
        let expr = pool.add(vec![x, y]);

        let mut cache = CompileCache::new();
        let f_xy = cache.compile(expr, &[x, y], &pool).unwrap();
        let f_yx = cache.compile(expr, &[y, x], &pool).unwrap();

        // Input order is part of the key, so these are separate entries.
        assert_eq!(cache.len(), 2);
        assert!(!Arc::ptr_eq(&f_xy, &f_yx));

        // Addition is commutative, so both orderings give the same value.
        assert!((f_xy.call(&[1.0, 2.0]) - 3.0).abs() < 1e-10);
        assert!((f_yx.call(&[1.0, 2.0]) - 3.0).abs() < 1e-10);
    }

    #[test]
    fn different_exprs_different_entries() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let sq = pool.pow(x, pool.integer(2_i32));
        let cube = pool.pow(x, pool.integer(3_i32));

        let mut cache = CompileCache::new();
        let f_sq = cache.compile(sq, &[x], &pool).unwrap();
        let f_cu = cache.compile(cube, &[x], &pool).unwrap();

        assert_eq!(cache.len(), 2);
        assert!(!Arc::ptr_eq(&f_sq, &f_cu));
        assert!((f_sq.call(&[3.0]) - 9.0).abs() < 1e-10);
        assert!((f_cu.call(&[3.0]) - 27.0).abs() < 1e-10);
    }

    #[test]
    fn arc_survives_cache_clear() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.pow(x, pool.integer(2_i32));

        let mut cache = CompileCache::new();
        let f = cache.compile(expr, &[x], &pool).unwrap();

        cache.clear();
        assert!(cache.is_empty());

        // The caller's Arc keeps the compiled function alive after clearing.
        assert!((f.call(&[4.0]) - 16.0).abs() < 1e-10);
    }

    #[test]
    fn clear_forces_recompile() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.pow(x, pool.integer(2_i32));

        let mut cache = CompileCache::new();
        cache.compile(expr, &[x], &pool).unwrap();
        cache.clear();
        assert!(!cache.contains(expr, &[x]));

        // After clearing, the same request is a miss again.
        cache.compile(expr, &[x], &pool).unwrap();
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.compile_count(), 2);
        assert_eq!(cache.hit_count(), 0);
    }

    #[test]
    fn evict_removes_single_entry() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let sq = pool.pow(x, pool.integer(2_i32));
        let cube = pool.pow(x, pool.integer(3_i32));

        let mut cache = CompileCache::new();
        cache.compile(sq, &[x], &pool).unwrap();
        cache.compile(cube, &[x], &pool).unwrap();
        assert_eq!(cache.len(), 2);

        let evicted = cache.evict(sq, &[x]);
        assert!(evicted.is_some());
        assert_eq!(cache.len(), 1);
        assert!(!cache.contains(sq, &[x]));
        assert!(cache.contains(cube, &[x]));
    }

    #[test]
    fn evict_missing_entry_returns_none() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.pow(x, pool.integer(2_i32));

        let mut cache = CompileCache::default();
        assert!(cache.is_empty());
        assert!(cache.evict(expr, &[x]).is_none());
        assert!(cache.is_empty());
    }

    #[test]
    fn contains_checks_key() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let y = pool.symbol("y", Domain::Real);
        let expr = pool.add(vec![x, y]);

        let mut cache = CompileCache::new();
        assert!(!cache.contains(expr, &[x, y]));
        cache.compile(expr, &[x, y], &pool).unwrap();
        assert!(cache.contains(expr, &[x, y]));
        assert!(!cache.contains(expr, &[y, x]));
    }

    #[test]
    fn hit_rate_is_correct() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.pow(x, pool.integer(2_i32));

        let mut cache = CompileCache::new();
        assert_eq!(cache.hit_rate(), 0.0);

        // First request misses: 0 hits out of 1 request.
        cache.compile(expr, &[x], &pool).unwrap();
        assert_eq!(cache.hit_rate(), 0.0);

        // Two further requests hit: 2 hits out of 3 requests.
        cache.compile(expr, &[x], &pool).unwrap();
        cache.compile(expr, &[x], &pool).unwrap();
        let rate = cache.hit_rate();
        assert!((rate - 2.0 / 3.0).abs() < 1e-10);
    }
}