1use std::cell::RefCell;
11use std::collections::HashMap;
12
13use super::poly::IntPoly;
14use crate::core::Symbol;
15
/// Process-wide logical clock used to order cache accesses.
///
/// The value has no relation to wall-clock time; it only ever increases.
static GLOBAL_ACCESS_COUNTER: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);

/// Returns the current tick of the logical clock and advances it by one.
///
/// Every call yields a distinct value because the increment is a single
/// atomic read-modify-write; `Relaxed` ordering is what the counter uses.
fn next_access_time() -> u64 {
    use std::sync::atomic::Ordering;

    let tick = GLOBAL_ACCESS_COUNTER.fetch_add(1, Ordering::Relaxed);
    tick
}
23
24#[derive(Debug, Clone)]
26struct CacheEntry<T> {
27 value: T,
28 last_access: u64,
29}
30
31impl<T> CacheEntry<T> {
32 fn new(value: T) -> Self {
33 Self {
34 value,
35 last_access: next_access_time(),
36 }
37 }
38
39 fn touch(&mut self) {
40 self.last_access = next_access_time();
41 }
42}
43
/// Memoization store for polynomial-related results.
///
/// Every map is keyed by a caller-supplied `u64` expression hash and keeps a
/// last-access stamp per key so that least-recently-used keys can be evicted.
pub struct PolynomialCache {
    /// Expression hash -> (variable name -> cached degree).
    degree_cache: HashMap<u64, CacheEntry<HashMap<String, i64>>>,
    /// Expression hash -> cached classification.
    classification_cache: HashMap<u64, CacheEntry<CachedClassification>>,
    /// Expression hash -> (variable name -> cached leading-coefficient hash).
    leading_coeff_cache: HashMap<u64, CacheEntry<HashMap<String, u64>>>,
    /// Expression hash -> (variable name -> cached content hash).
    content_cache: HashMap<u64, CacheEntry<HashMap<String, u64>>>,
    /// Expression hash -> cached `(IntPoly, Symbol)` pair.
    intpoly_cache: HashMap<u64, CacheEntry<(IntPoly, Symbol)>>,
    /// Per-map key limit; reaching it triggers eviction of the least
    /// recently used keys before an insert.
    max_entries: usize,
    /// Successful lookups in the degree, classification, leading-coefficient
    /// and content maps.
    hits: u64,
    /// Failed lookups in the degree, classification, leading-coefficient and
    /// content maps.
    misses: u64,
    /// Successful lookups in the `IntPoly` map.
    intpoly_hits: u64,
    /// Failed lookups in the `IntPoly` map.
    intpoly_misses: u64,
}
77
/// Classification result stored in the classification cache.
///
/// The variants are produced by a classifier that lives outside this module;
/// their meaning is presumed from their names — TODO confirm against the
/// classifier. Values can be cloned and compared for equality.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CachedClassification {
    /// Presumably an integer constant.
    Integer,
    /// Presumably a rational constant.
    Rational,
    /// Presumably a polynomial in a single variable.
    Univariate {
        /// Name of the variable.
        var: String,
        /// Degree recorded for that variable.
        degree: i64,
    },
    /// Presumably a polynomial in several variables.
    Multivariate {
        /// Names of the variables.
        vars: Vec<String>,
        /// Total degree recorded for the expression.
        total_degree: i64,
    },
    /// Presumably a quotient of polynomials.
    RationalFunction,
    /// Presumably an expression involving transcendental functions.
    Transcendental,
    /// Presumably any expression not covered by the other variants.
    Symbolic,
}
95
96impl PolynomialCache {
97 pub fn new() -> Self {
99 Self {
100 degree_cache: HashMap::new(),
101 classification_cache: HashMap::new(),
102 leading_coeff_cache: HashMap::new(),
103 content_cache: HashMap::new(),
104 intpoly_cache: HashMap::new(),
105 max_entries: 1024,
106 hits: 0,
107 misses: 0,
108 intpoly_hits: 0,
109 intpoly_misses: 0,
110 }
111 }
112
113 pub fn with_capacity(max_entries: usize) -> Self {
115 Self {
116 degree_cache: HashMap::new(),
117 classification_cache: HashMap::new(),
118 leading_coeff_cache: HashMap::new(),
119 content_cache: HashMap::new(),
120 intpoly_cache: HashMap::new(),
121 max_entries,
122 hits: 0,
123 misses: 0,
124 intpoly_hits: 0,
125 intpoly_misses: 0,
126 }
127 }
128
129 pub fn get_degree(&mut self, expr_hash: u64, var: &str) -> Option<i64> {
131 if let Some(entry) = self.degree_cache.get_mut(&expr_hash) {
132 entry.touch();
133 if let Some(°ree) = entry.value.get(var) {
134 self.hits += 1;
135 return Some(degree);
136 }
137 }
138 self.misses += 1;
139 None
140 }
141
142 pub fn set_degree(&mut self, expr_hash: u64, var: &str, degree: i64) {
144 self.maybe_evict_lru(&CacheType::Degree);
145 self.degree_cache
146 .entry(expr_hash)
147 .or_insert_with(|| CacheEntry::new(HashMap::new()))
148 .value
149 .insert(var.to_owned(), degree);
150 }
151
152 pub fn get_classification(&mut self, expr_hash: u64) -> Option<CachedClassification> {
154 if let Some(entry) = self.classification_cache.get_mut(&expr_hash) {
155 entry.touch();
156 self.hits += 1;
157 return Some(entry.value.clone());
158 }
159 self.misses += 1;
160 None
161 }
162
163 pub fn set_classification(&mut self, expr_hash: u64, classification: CachedClassification) {
165 self.maybe_evict_lru(&CacheType::Classification);
166 self.classification_cache
167 .insert(expr_hash, CacheEntry::new(classification));
168 }
169
170 pub fn get_leading_coeff(&mut self, expr_hash: u64, var: &str) -> Option<u64> {
172 if let Some(entry) = self.leading_coeff_cache.get_mut(&expr_hash) {
173 entry.touch();
174 if let Some(&coeff_hash) = entry.value.get(var) {
175 self.hits += 1;
176 return Some(coeff_hash);
177 }
178 }
179 self.misses += 1;
180 None
181 }
182
183 pub fn set_leading_coeff(&mut self, expr_hash: u64, var: &str, coeff_hash: u64) {
185 self.maybe_evict_lru(&CacheType::LeadingCoeff);
186 self.leading_coeff_cache
187 .entry(expr_hash)
188 .or_insert_with(|| CacheEntry::new(HashMap::new()))
189 .value
190 .insert(var.to_owned(), coeff_hash);
191 }
192
193 pub fn get_content(&mut self, expr_hash: u64, var: &str) -> Option<u64> {
195 if let Some(entry) = self.content_cache.get_mut(&expr_hash) {
196 entry.touch();
197 if let Some(&content_hash) = entry.value.get(var) {
198 self.hits += 1;
199 return Some(content_hash);
200 }
201 }
202 self.misses += 1;
203 None
204 }
205
206 pub fn set_content(&mut self, expr_hash: u64, var: &str, content_hash: u64) {
208 self.maybe_evict_lru(&CacheType::Content);
209 self.content_cache
210 .entry(expr_hash)
211 .or_insert_with(|| CacheEntry::new(HashMap::new()))
212 .value
213 .insert(var.to_owned(), content_hash);
214 }
215
216 pub fn get_intpoly(&mut self, expr_hash: u64) -> Option<(IntPoly, Symbol)> {
221 if let Some(entry) = self.intpoly_cache.get_mut(&expr_hash) {
222 entry.touch();
223 self.intpoly_hits += 1;
224 return Some(entry.value.clone());
225 }
226 self.intpoly_misses += 1;
227 None
228 }
229
230 pub fn set_intpoly(&mut self, expr_hash: u64, poly: IntPoly, var: Symbol) {
234 self.maybe_evict_lru(&CacheType::IntPoly);
235 self.intpoly_cache
236 .insert(expr_hash, CacheEntry::new((poly, var)));
237 }
238
239 pub fn clear(&mut self) {
241 self.degree_cache.clear();
242 self.classification_cache.clear();
243 self.leading_coeff_cache.clear();
244 self.content_cache.clear();
245 self.intpoly_cache.clear();
246 self.hits = 0;
247 self.misses = 0;
248 self.intpoly_hits = 0;
249 self.intpoly_misses = 0;
250 }
251
252 pub fn stats(&self) -> CacheStats {
254 let total_hits = self.hits + self.intpoly_hits;
255 let total_misses = self.misses + self.intpoly_misses;
256 CacheStats {
257 degree_entries: self.degree_cache.len(),
258 classification_entries: self.classification_cache.len(),
259 leading_coeff_entries: self.leading_coeff_cache.len(),
260 content_entries: self.content_cache.len(),
261 intpoly_entries: self.intpoly_cache.len(),
262 hits: total_hits,
263 misses: total_misses,
264 intpoly_hits: self.intpoly_hits,
265 intpoly_misses: self.intpoly_misses,
266 hit_rate: if total_hits + total_misses > 0 {
267 total_hits as f64 / (total_hits + total_misses) as f64
268 } else {
269 0.0
270 },
271 intpoly_hit_rate: if self.intpoly_hits + self.intpoly_misses > 0 {
272 self.intpoly_hits as f64 / (self.intpoly_hits + self.intpoly_misses) as f64
273 } else {
274 0.0
275 },
276 }
277 }
278
279 fn maybe_evict_lru(&mut self, cache_type: &CacheType) {
281 match cache_type {
282 CacheType::Degree => {
283 if self.degree_cache.len() >= self.max_entries {
284 self.evict_lru_from_degree_cache();
285 }
286 }
287 CacheType::Classification => {
288 if self.classification_cache.len() >= self.max_entries {
289 self.evict_lru_from_classification_cache();
290 }
291 }
292 CacheType::LeadingCoeff => {
293 if self.leading_coeff_cache.len() >= self.max_entries {
294 self.evict_lru_from_leading_coeff_cache();
295 }
296 }
297 CacheType::Content => {
298 if self.content_cache.len() >= self.max_entries {
299 self.evict_lru_from_content_cache();
300 }
301 }
302 CacheType::IntPoly => {
303 if self.intpoly_cache.len() >= self.max_entries {
304 self.evict_lru_from_intpoly_cache();
305 }
306 }
307 }
308 }
309
310 fn evict_lru_from_degree_cache(&mut self) {
311 let to_remove = self.max_entries / 4;
312 let mut entries: Vec<_> = self
313 .degree_cache
314 .iter()
315 .map(|(k, v)| (*k, v.last_access))
316 .collect();
317 entries.sort_by_key(|(_, access)| *access);
318
319 for (key, _) in entries.into_iter().take(to_remove) {
320 self.degree_cache.remove(&key);
321 }
322 }
323
324 fn evict_lru_from_classification_cache(&mut self) {
325 let to_remove = self.max_entries / 4;
326 let mut entries: Vec<_> = self
327 .classification_cache
328 .iter()
329 .map(|(k, v)| (*k, v.last_access))
330 .collect();
331 entries.sort_by_key(|(_, access)| *access);
332
333 for (key, _) in entries.into_iter().take(to_remove) {
334 self.classification_cache.remove(&key);
335 }
336 }
337
338 fn evict_lru_from_leading_coeff_cache(&mut self) {
339 let to_remove = self.max_entries / 4;
340 let mut entries: Vec<_> = self
341 .leading_coeff_cache
342 .iter()
343 .map(|(k, v)| (*k, v.last_access))
344 .collect();
345 entries.sort_by_key(|(_, access)| *access);
346
347 for (key, _) in entries.into_iter().take(to_remove) {
348 self.leading_coeff_cache.remove(&key);
349 }
350 }
351
352 fn evict_lru_from_content_cache(&mut self) {
353 let to_remove = self.max_entries / 4;
354 let mut entries: Vec<_> = self
355 .content_cache
356 .iter()
357 .map(|(k, v)| (*k, v.last_access))
358 .collect();
359 entries.sort_by_key(|(_, access)| *access);
360
361 for (key, _) in entries.into_iter().take(to_remove) {
362 self.content_cache.remove(&key);
363 }
364 }
365
366 fn evict_lru_from_intpoly_cache(&mut self) {
367 let to_remove = self.max_entries / 4;
368 let mut entries: Vec<_> = self
369 .intpoly_cache
370 .iter()
371 .map(|(k, v)| (*k, v.last_access))
372 .collect();
373 entries.sort_by_key(|(_, access)| *access);
374
375 for (key, _) in entries.into_iter().take(to_remove) {
376 self.intpoly_cache.remove(&key);
377 }
378 }
379}
380
/// Selects which internal map of `PolynomialCache` an eviction check
/// applies to.
enum CacheType {
    /// The per-variable degree map.
    Degree,
    /// The classification map.
    Classification,
    /// The per-variable leading-coefficient map.
    LeadingCoeff,
    /// The per-variable content map.
    Content,
    /// The `IntPoly` map.
    IntPoly,
}
389
/// Snapshot of cache sizes and lookup counters returned by
/// `PolynomialCache::stats`.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of expression keys in the degree map.
    pub degree_entries: usize,
    /// Number of expression keys in the classification map.
    pub classification_entries: usize,
    /// Number of expression keys in the leading-coefficient map.
    pub leading_coeff_entries: usize,
    /// Number of expression keys in the content map.
    pub content_entries: usize,
    /// Number of expression keys in the `IntPoly` map.
    pub intpoly_entries: usize,
    /// Total successful lookups across all maps (includes `intpoly_hits`).
    pub hits: u64,
    /// Total failed lookups across all maps (includes `intpoly_misses`).
    pub misses: u64,
    /// Successful lookups in the `IntPoly` map only.
    pub intpoly_hits: u64,
    /// Failed lookups in the `IntPoly` map only.
    pub intpoly_misses: u64,
    /// `hits / (hits + misses)`, or `0.0` when no lookup was recorded.
    pub hit_rate: f64,
    /// `intpoly_hits / (intpoly_hits + intpoly_misses)`, or `0.0` when no
    /// `IntPoly` lookup was recorded.
    pub intpoly_hit_rate: f64,
}
405
406impl Default for PolynomialCache {
407 fn default() -> Self {
408 Self::new()
409 }
410}
411
thread_local! {
    /// Per-thread `PolynomialCache` instance. Each thread lazily gets its own
    /// cache, so cached values and counters are never shared across threads.
    /// The `RefCell` provides the mutable access handed out by `with_cache`.
    static CACHE: RefCell<PolynomialCache> = RefCell::new(PolynomialCache::new());
}
416
417pub fn with_cache<F, R>(f: F) -> R
419where
420 F: FnOnce(&mut PolynomialCache) -> R,
421{
422 CACHE.with(|cache| f(&mut cache.borrow_mut()))
423}
424
425pub fn clear_cache() {
427 with_cache(|cache| cache.clear());
428}
429
430pub fn cache_stats() -> CacheStats {
432 with_cache(|cache| cache.stats())
433}
434
435pub fn get_or_compute_intpoly<F>(expr_hash: u64, compute_fn: F) -> Option<(IntPoly, Symbol)>
449where
450 F: FnOnce() -> Option<(IntPoly, Symbol)>,
451{
452 with_cache(|cache| {
453 if let Some(cached) = cache.get_intpoly(expr_hash) {
454 return Some(cached);
455 }
456
457 if let Some((poly, var)) = compute_fn() {
458 cache.set_intpoly(expr_hash, poly.clone(), var.clone());
459 Some((poly, var))
460 } else {
461 None
462 }
463 })
464}
465
#[cfg(test)]
mod tests {
    use super::*;

    /// Degrees are stored per (expression hash, variable name) pair.
    #[test]
    fn test_cache_degree() {
        let mut cache = PolynomialCache::new();

        cache.set_degree(12345, "x", 5);
        assert_eq!(cache.get_degree(12345, "x"), Some(5));
        // Known expression, unknown variable.
        assert_eq!(cache.get_degree(12345, "y"), None);
        // Unknown expression.
        assert_eq!(cache.get_degree(99999, "x"), None);
    }

    /// A stored classification comes back with the same variant.
    #[test]
    fn test_cache_classification() {
        let mut cache = PolynomialCache::new();

        cache.set_classification(
            12345,
            CachedClassification::Univariate {
                var: "x".to_string(),
                degree: 3,
            },
        );

        let result = cache.get_classification(12345);
        assert!(matches!(
            result,
            Some(CachedClassification::Univariate { .. })
        ));
    }

    /// Values written through `with_cache` are visible to later calls on the
    /// same thread.
    #[test]
    fn test_thread_local_cache() {
        with_cache(|cache| {
            cache.set_degree(111, "x", 2);
        });

        let degree = with_cache(|cache| cache.get_degree(111, "x"));
        assert_eq!(degree, Some(2));
    }

    /// Overfilling a cache evicts entries, oldest first.
    #[test]
    fn test_cache_lru_eviction() {
        let mut cache = PolynomialCache::with_capacity(10);

        for i in 0..15 {
            cache.set_degree(i, "x", i as i64);
        }

        let stats = cache.stats();
        assert!(
            stats.degree_entries <= 10,
            "Cache should have evicted entries"
        );

        // The size bound alone would also hold if the newest keys were
        // dropped, so check which end of the access order survived.
        assert_eq!(
            cache.get_degree(14, "x"),
            Some(14),
            "Most recently inserted entry should survive eviction"
        );
        assert_eq!(
            cache.get_degree(0, "x"),
            None,
            "Least recently used entry should have been evicted"
        );
    }

    /// One successful and two failed lookups are counted as such.
    #[test]
    fn test_cache_hit_tracking() {
        let mut cache = PolynomialCache::new();

        cache.set_degree(123, "x", 5);

        let _ = cache.get_degree(123, "x");
        let _ = cache.get_degree(123, "y");
        let _ = cache.get_degree(999, "x");

        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 2);
        assert!((stats.hit_rate - 1.0 / 3.0).abs() < 1e-12);
    }

    /// Leading-coefficient hashes are stored per variable.
    #[test]
    fn test_cache_leading_coeff() {
        let mut cache = PolynomialCache::new();

        cache.set_leading_coeff(12345, "x", 999);
        assert_eq!(cache.get_leading_coeff(12345, "x"), Some(999));
        assert_eq!(cache.get_leading_coeff(12345, "y"), None);
    }

    /// Content hashes are stored per variable.
    #[test]
    fn test_cache_content() {
        let mut cache = PolynomialCache::new();

        cache.set_content(12345, "x", 777);
        assert_eq!(cache.get_content(12345, "x"), Some(777));
        assert_eq!(cache.get_content(12345, "y"), None);
    }

    /// `cache_stats` reports the sizes of the thread-local cache.
    #[test]
    fn test_cache_stats_helper() {
        // Start from a known-empty thread-local cache.
        clear_cache();
        with_cache(|cache| {
            cache.set_degree(1, "x", 1);
            cache.set_classification(2, CachedClassification::Integer);
        });

        let stats = cache_stats();
        assert_eq!(stats.degree_entries, 1);
        assert_eq!(stats.classification_entries, 1);
    }

    /// A stored `(IntPoly, Symbol)` pair comes back unchanged.
    #[test]
    fn test_intpoly_cache() {
        use crate::symbol;

        let mut cache = PolynomialCache::new();
        let x = symbol!(x);
        let poly = IntPoly::from_coeffs(vec![1, 2, 3]);

        cache.set_intpoly(12345, poly.clone(), x.clone());
        let (p, v) = cache
            .get_intpoly(12345)
            .expect("entry was inserted just above");
        assert_eq!(p, poly);
        assert_eq!(v, x);

        assert!(cache.get_intpoly(99999).is_none());
    }

    /// The compute closure runs only on the first call for a given hash.
    #[test]
    fn test_get_or_compute_intpoly() {
        use crate::symbol;

        clear_cache();
        let x = symbol!(x);
        let poly = IntPoly::from_coeffs(vec![1, 2, 3]);
        let hash = 54321u64;

        let mut call_count = 0;

        let result1 = get_or_compute_intpoly(hash, || {
            call_count += 1;
            Some((poly.clone(), x.clone()))
        });
        assert!(result1.is_some());
        assert_eq!(call_count, 1);

        // Second call is served from the cache.
        let result2 = get_or_compute_intpoly(hash, || {
            call_count += 1;
            Some((poly.clone(), x.clone()))
        });
        assert!(result2.is_some());
        assert_eq!(call_count, 1);
        assert_eq!(result1, result2);

        let stats = cache_stats();
        assert!(stats.intpoly_hits >= 1);
    }
}