1use num_rational::BigRational;
16use num_traits::{Signed, Zero};
17use oxiz_math::polynomial::{Polynomial, Var};
18use rustc_hash::FxHashMap;
19use std::collections::VecDeque;
20
21#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
23struct EvalFingerprint {
24 poly_hash: u64,
25 assignment_hash: u64,
26}
27
28impl EvalFingerprint {
29 fn new(poly: &Polynomial, assignment: &FxHashMap<Var, BigRational>) -> Self {
31 use std::collections::hash_map::DefaultHasher;
32 use std::hash::{Hash, Hasher};
33
34 let mut poly_hasher = DefaultHasher::new();
36 format!("{:?}", poly).hash(&mut poly_hasher);
37 let poly_hash = poly_hasher.finish();
38
39 let mut assignment_hasher = DefaultHasher::new();
41 let mut vars = poly.vars();
42 vars.sort();
43 for var in &vars {
44 if let Some(value) = assignment.get(var) {
45 var.hash(&mut assignment_hasher);
46 format!("{:?}", value).hash(&mut assignment_hasher);
47 }
48 }
49 let assignment_hash = assignment_hasher.finish();
50
51 Self {
52 poly_hash,
53 assignment_hash,
54 }
55 }
56}
57
58#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
60pub enum CachedSign {
61 Positive,
63 Zero,
65 Negative,
67}
68
69impl CachedSign {
70 pub fn from_value(value: &BigRational) -> Self {
72 if value.is_zero() {
73 CachedSign::Zero
74 } else if value.is_positive() {
75 CachedSign::Positive
76 } else {
77 CachedSign::Negative
78 }
79 }
80}
81
82#[derive(Debug, Clone, PartialEq, Eq, Hash)]
84pub struct SignPattern(Vec<CachedSign>);
85
86impl SignPattern {
87 pub fn new(signs: Vec<CachedSign>) -> Self {
89 Self(signs)
90 }
91
92 pub fn get(&self, index: usize) -> Option<CachedSign> {
94 self.0.get(index).copied()
95 }
96
97 pub fn len(&self) -> usize {
99 self.0.len()
100 }
101
102 pub fn is_empty(&self) -> bool {
104 self.0.is_empty()
105 }
106}
107
/// Hit, miss and eviction counters collected by [`EvalCache`].
#[derive(Debug, Clone, Default)]
pub struct EvalCacheStats {
    /// Value lookups that found a cached result.
    pub value_hits: u64,
    /// Value lookups that found nothing.
    pub value_misses: u64,
    /// Pattern lookups that found a cached pattern.
    pub pattern_hits: u64,
    /// Pattern lookups that found nothing.
    pub pattern_misses: u64,
    /// Entries removed to make room, counted across both caches.
    pub evictions: u64,
}

impl EvalCacheStats {
    /// Fraction of value lookups that hit, or `0.0` before any lookup.
    pub fn value_hit_rate(&self) -> f64 {
        Self::ratio(self.value_hits, self.value_misses)
    }

    /// Fraction of pattern lookups that hit, or `0.0` before any lookup.
    pub fn pattern_hit_rate(&self) -> f64 {
        Self::ratio(self.pattern_hits, self.pattern_misses)
    }

    /// `hits / (hits + misses)`, defined as `0.0` when there were no lookups.
    fn ratio(hits: u64, misses: u64) -> f64 {
        match hits + misses {
            0 => 0.0,
            total => hits as f64 / total as f64,
        }
    }
}
144
/// Capacity limits and on/off switches for [`EvalCache`].
#[derive(Debug, Clone)]
pub struct EvalCacheConfig {
    /// Capacity bound for the evaluation-result cache.
    pub max_value_entries: usize,
    /// Capacity bound for the sign-pattern cache.
    pub max_pattern_entries: usize,
    /// Whether evaluation results are cached at all.
    pub enable_value_cache: bool,
    /// Whether sign patterns are cached at all.
    pub enable_pattern_cache: bool,
}

impl EvalCacheConfig {
    /// Default capacity of the evaluation-result cache.
    const DEFAULT_MAX_VALUE_ENTRIES: usize = 10_000;
    /// Default capacity of the sign-pattern cache.
    const DEFAULT_MAX_PATTERN_ENTRIES: usize = 5_000;
}

impl Default for EvalCacheConfig {
    /// Both caches enabled, bounded at 10 000 values and 5 000 patterns.
    fn default() -> Self {
        EvalCacheConfig {
            enable_value_cache: true,
            enable_pattern_cache: true,
            max_value_entries: Self::DEFAULT_MAX_VALUE_ENTRIES,
            max_pattern_entries: Self::DEFAULT_MAX_PATTERN_ENTRIES,
        }
    }
}
168
169pub struct EvalCache {
171 config: EvalCacheConfig,
173 value_cache: FxHashMap<EvalFingerprint, BigRational>,
175 value_lru: VecDeque<EvalFingerprint>,
177 pattern_cache: FxHashMap<u64, SignPattern>,
179 pattern_lru: VecDeque<u64>,
181 stats: EvalCacheStats,
183}
184
185impl EvalCache {
186 pub fn new() -> Self {
188 Self::with_config(EvalCacheConfig::default())
189 }
190
191 pub fn with_config(config: EvalCacheConfig) -> Self {
193 Self {
194 config,
195 value_cache: FxHashMap::default(),
196 value_lru: VecDeque::new(),
197 pattern_cache: FxHashMap::default(),
198 pattern_lru: VecDeque::new(),
199 stats: EvalCacheStats::default(),
200 }
201 }
202
203 pub fn lookup_value(
205 &mut self,
206 poly: &Polynomial,
207 assignment: &FxHashMap<Var, BigRational>,
208 ) -> Option<BigRational> {
209 if !self.config.enable_value_cache {
210 return None;
211 }
212
213 let fingerprint = EvalFingerprint::new(poly, assignment);
214
215 if let Some(value) = self.value_cache.get(&fingerprint) {
216 self.stats.value_hits += 1;
217 self.value_lru.retain(|&f| f != fingerprint);
219 self.value_lru.push_front(fingerprint);
220 Some(value.clone())
221 } else {
222 self.stats.value_misses += 1;
223 None
224 }
225 }
226
227 pub fn insert_value(
229 &mut self,
230 poly: &Polynomial,
231 assignment: &FxHashMap<Var, BigRational>,
232 value: BigRational,
233 ) {
234 if !self.config.enable_value_cache {
235 return;
236 }
237
238 let fingerprint = EvalFingerprint::new(poly, assignment);
239
240 if self.value_cache.len() >= self.config.max_value_entries
242 && let Some(old_fingerprint) = self.value_lru.pop_back()
243 {
244 self.value_cache.remove(&old_fingerprint);
245 self.stats.evictions += 1;
246 }
247
248 self.value_cache.insert(fingerprint, value);
249 self.value_lru.push_front(fingerprint);
250 }
251
252 pub fn lookup_pattern(&mut self, pattern_hash: u64) -> Option<SignPattern> {
254 if !self.config.enable_pattern_cache {
255 return None;
256 }
257
258 if let Some(pattern) = self.pattern_cache.get(&pattern_hash) {
259 self.stats.pattern_hits += 1;
260 self.pattern_lru.retain(|&h| h != pattern_hash);
262 self.pattern_lru.push_front(pattern_hash);
263 Some(pattern.clone())
264 } else {
265 self.stats.pattern_misses += 1;
266 None
267 }
268 }
269
270 pub fn insert_pattern(&mut self, pattern_hash: u64, pattern: SignPattern) {
272 if !self.config.enable_pattern_cache {
273 return;
274 }
275
276 if self.pattern_cache.len() >= self.config.max_pattern_entries
278 && let Some(old_hash) = self.pattern_lru.pop_back()
279 {
280 self.pattern_cache.remove(&old_hash);
281 self.stats.evictions += 1;
282 }
283
284 self.pattern_cache.insert(pattern_hash, pattern);
285 self.pattern_lru.push_front(pattern_hash);
286 }
287
288 pub fn compute_pattern_hash(
290 polys: &[Polynomial],
291 assignment: &FxHashMap<Var, BigRational>,
292 ) -> u64 {
293 use std::collections::hash_map::DefaultHasher;
294 use std::hash::{Hash, Hasher};
295
296 let mut hasher = DefaultHasher::new();
297
298 for poly in polys {
299 format!("{:?}", poly).hash(&mut hasher);
300 }
301
302 let mut vars: Vec<_> = assignment.keys().copied().collect();
304 vars.sort();
305 for var in vars {
306 if let Some(value) = assignment.get(&var) {
307 var.hash(&mut hasher);
308 format!("{:?}", value).hash(&mut hasher);
309 }
310 }
311
312 hasher.finish()
313 }
314
315 pub fn clear(&mut self) {
317 self.value_cache.clear();
318 self.value_lru.clear();
319 self.pattern_cache.clear();
320 self.pattern_lru.clear();
321 }
322
323 pub fn stats(&self) -> &EvalCacheStats {
325 &self.stats
326 }
327
328 pub fn value_cache_size(&self) -> usize {
330 self.value_cache.len()
331 }
332
333 pub fn pattern_cache_size(&self) -> usize {
335 self.pattern_cache.len()
336 }
337}
338
339impl Default for EvalCache {
340 fn default() -> Self {
341 Self::new()
342 }
343}
344
#[cfg(test)]
mod tests {
    use super::*;
    use num_bigint::BigInt;
    use num_traits::Zero;

    /// The integer `n` as a `BigRational`.
    fn int(n: i32) -> BigRational {
        BigRational::from_integer(BigInt::from(n))
    }

    /// An assignment binding the single variable `var` to the integer `n`.
    fn assign(var: Var, n: i32) -> FxHashMap<Var, BigRational> {
        let mut bindings = FxHashMap::default();
        bindings.insert(var, int(n));
        bindings
    }

    #[test]
    fn test_eval_cache_new() {
        let fresh = EvalCache::new();
        assert_eq!(
            (fresh.value_cache_size(), fresh.pattern_cache_size()),
            (0, 0)
        );
    }

    #[test]
    fn test_value_cache_basic() {
        let mut cache = EvalCache::new();
        let x = Polynomial::from_var(0);
        let point = assign(0, 5);

        // Cold lookup misses; after inserting, the same key hits.
        assert!(cache.lookup_value(&x, &point).is_none());
        cache.insert_value(&x, &point, int(5));
        assert_eq!(cache.lookup_value(&x, &point), Some(int(5)));

        let counters = cache.stats();
        assert_eq!(counters.value_hits, 1);
        assert_eq!(counters.value_misses, 1);
    }

    #[test]
    fn test_value_cache_eviction() {
        let mut cache = EvalCache::with_config(EvalCacheConfig {
            max_value_entries: 2,
            ..Default::default()
        });

        // Three distinct keys into a two-slot cache force an eviction.
        for var in 0..3 {
            let poly = Polynomial::from_var(var);
            cache.insert_value(&poly, &assign(var, var as i32), BigRational::zero());
        }

        assert_eq!(cache.value_cache_size(), 2);
        assert!(cache.stats().evictions > 0);
    }

    #[test]
    fn test_sign_pattern() {
        let signs = [CachedSign::Positive, CachedSign::Zero, CachedSign::Negative];
        let pattern = SignPattern::new(signs.to_vec());

        assert_eq!(pattern.len(), signs.len());
        for (index, &expected) in signs.iter().enumerate() {
            assert_eq!(pattern.get(index), Some(expected));
        }
        assert_eq!(pattern.get(signs.len()), None);
    }

    #[test]
    fn test_pattern_cache_basic() {
        let mut cache = EvalCache::new();
        let key = 12345u64;
        let pattern = SignPattern::new(vec![CachedSign::Positive, CachedSign::Negative]);

        assert!(cache.lookup_pattern(key).is_none());
        cache.insert_pattern(key, pattern.clone());
        assert_eq!(cache.lookup_pattern(key), Some(pattern));

        let counters = cache.stats();
        assert_eq!(counters.pattern_hits, 1);
        assert_eq!(counters.pattern_misses, 1);
    }

    #[test]
    fn test_stats_hit_rate() {
        let counters = EvalCacheStats {
            value_hits: 80,
            value_misses: 20,
            pattern_hits: 60,
            pattern_misses: 40,
            evictions: 5,
        };

        assert_eq!(counters.value_hit_rate(), 0.8);
        assert_eq!(counters.pattern_hit_rate(), 0.6);
    }

    #[test]
    fn test_clear() {
        let mut cache = EvalCache::new();
        cache.insert_value(&Polynomial::from_var(0), &assign(0, 1), BigRational::zero());
        cache.insert_pattern(123, SignPattern::new(vec![CachedSign::Positive]));

        cache.clear();

        assert_eq!(cache.value_cache_size(), 0);
        assert_eq!(cache.pattern_cache_size(), 0);
    }
}