// mathhook_core/core/polynomial/cache.rs

//! Polynomial Computation Cache
//!
//! Thread-local cache for expensive polynomial computations.
//! Uses a side-table design, keyed by expression hash, to preserve the
//! 32-byte `Expression` size constraint.
//!
//! Eviction approximates LRU (Least Recently Used): every cached entry
//! records the logical time of its last access, and when a table is full a
//! batch of its least recently accessed entries (about a quarter of the
//! capacity) is evicted at once.
9
10use std::cell::RefCell;
11use std::collections::HashMap;
12
13use super::poly::IntPoly;
14use crate::core::Symbol;
15
/// Monotonic logical clock that orders cache accesses for LRU eviction.
///
/// Every value handed out by [`next_access_time`] is larger than all values
/// handed out before it, so a larger `last_access` means a more recent use.
static GLOBAL_ACCESS_COUNTER: std::sync::atomic::AtomicU64 =
    std::sync::atomic::AtomicU64::new(0);

/// Returns the next logical timestamp for LRU bookkeeping.
fn next_access_time() -> u64 {
    use std::sync::atomic::Ordering;

    // Relaxed suffices: only the uniqueness and monotonicity of the returned
    // value matter, not its ordering relative to other memory operations.
    GLOBAL_ACCESS_COUNTER.fetch_add(1, Ordering::Relaxed)
}
23
24/// Entry wrapper that tracks last access time for LRU eviction
25#[derive(Debug, Clone)]
26struct CacheEntry<T> {
27    value: T,
28    last_access: u64,
29}
30
31impl<T> CacheEntry<T> {
32    fn new(value: T) -> Self {
33        Self {
34            value,
35            last_access: next_access_time(),
36        }
37    }
38
39    fn touch(&mut self) {
40        self.last_access = next_access_time();
41    }
42}
43
44/// Thread-local polynomial computation cache
45///
46/// Caches expensive computations like degree, classification, and content
47/// using expression pointer hash as key. Implements true LRU eviction
48/// by tracking access times for each entry.
49///
50/// # Design
51///
52/// Uses pointer-based hashing since `Expression` doesn't implement `Hash`.
53/// The cache is thread-local to avoid synchronization overhead.
54pub struct PolynomialCache {
55    /// Degree cache: expression_hash -> (variable_name -> degree)
56    degree_cache: HashMap<u64, CacheEntry<HashMap<String, i64>>>,
57    /// Classification cache: expression_hash -> classification
58    classification_cache: HashMap<u64, CacheEntry<CachedClassification>>,
59    /// Leading coefficient cache: expression_hash -> (variable_name -> coeff_hash)
60    leading_coeff_cache: HashMap<u64, CacheEntry<HashMap<String, u64>>>,
61    /// Content cache: expression_hash -> (variable_name -> content_hash)
62    content_cache: HashMap<u64, CacheEntry<HashMap<String, u64>>>,
63    /// IntPoly cache: expression_hash -> (IntPoly, Symbol)
64    /// Caches Expression → IntPoly conversions to eliminate repeated bridging
65    intpoly_cache: HashMap<u64, CacheEntry<(IntPoly, Symbol)>>,
66    /// Maximum cache entries per cache type
67    max_entries: usize,
68    /// Cache hit counter for statistics
69    hits: u64,
70    /// Cache miss counter for statistics
71    misses: u64,
72    /// IntPoly-specific hit counter
73    intpoly_hits: u64,
74    /// IntPoly-specific miss counter
75    intpoly_misses: u64,
76}
77
/// Cached classification result for an expression.
///
/// Stored by value in [`PolynomialCache`]'s classification table and cloned
/// out on every cache hit. `PartialEq`/`Eq` allow cached results to be
/// compared directly (e.g. `assert_eq!`) instead of only pattern-matched.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CachedClassification {
    /// `Integer` classification.
    Integer,
    /// `Rational` classification.
    Rational,
    /// Single-variable classification: variable `var`, degree `degree`.
    Univariate {
        var: String,
        degree: i64,
    },
    /// Multi-variable classification: variables `vars`, total degree
    /// `total_degree`.
    Multivariate {
        vars: Vec<String>,
        total_degree: i64,
    },
    /// `RationalFunction` classification.
    RationalFunction,
    /// `Transcendental` classification.
    Transcendental,
    /// `Symbolic` classification.
    Symbolic,
}
95
96impl PolynomialCache {
97    /// Create a new cache with default capacity (1024 entries per cache type)
98    pub fn new() -> Self {
99        Self {
100            degree_cache: HashMap::new(),
101            classification_cache: HashMap::new(),
102            leading_coeff_cache: HashMap::new(),
103            content_cache: HashMap::new(),
104            intpoly_cache: HashMap::new(),
105            max_entries: 1024,
106            hits: 0,
107            misses: 0,
108            intpoly_hits: 0,
109            intpoly_misses: 0,
110        }
111    }
112
113    /// Create a new cache with custom capacity
114    pub fn with_capacity(max_entries: usize) -> Self {
115        Self {
116            degree_cache: HashMap::new(),
117            classification_cache: HashMap::new(),
118            leading_coeff_cache: HashMap::new(),
119            content_cache: HashMap::new(),
120            intpoly_cache: HashMap::new(),
121            max_entries,
122            hits: 0,
123            misses: 0,
124            intpoly_hits: 0,
125            intpoly_misses: 0,
126        }
127    }
128
129    /// Get cached degree for expression and variable
130    pub fn get_degree(&mut self, expr_hash: u64, var: &str) -> Option<i64> {
131        if let Some(entry) = self.degree_cache.get_mut(&expr_hash) {
132            entry.touch();
133            if let Some(&degree) = entry.value.get(var) {
134                self.hits += 1;
135                return Some(degree);
136            }
137        }
138        self.misses += 1;
139        None
140    }
141
142    /// Cache degree for expression and variable
143    pub fn set_degree(&mut self, expr_hash: u64, var: &str, degree: i64) {
144        self.maybe_evict_lru(&CacheType::Degree);
145        self.degree_cache
146            .entry(expr_hash)
147            .or_insert_with(|| CacheEntry::new(HashMap::new()))
148            .value
149            .insert(var.to_owned(), degree);
150    }
151
152    /// Get cached classification for expression
153    pub fn get_classification(&mut self, expr_hash: u64) -> Option<CachedClassification> {
154        if let Some(entry) = self.classification_cache.get_mut(&expr_hash) {
155            entry.touch();
156            self.hits += 1;
157            return Some(entry.value.clone());
158        }
159        self.misses += 1;
160        None
161    }
162
163    /// Cache classification for expression
164    pub fn set_classification(&mut self, expr_hash: u64, classification: CachedClassification) {
165        self.maybe_evict_lru(&CacheType::Classification);
166        self.classification_cache
167            .insert(expr_hash, CacheEntry::new(classification));
168    }
169
170    /// Get cached leading coefficient hash for expression and variable
171    pub fn get_leading_coeff(&mut self, expr_hash: u64, var: &str) -> Option<u64> {
172        if let Some(entry) = self.leading_coeff_cache.get_mut(&expr_hash) {
173            entry.touch();
174            if let Some(&coeff_hash) = entry.value.get(var) {
175                self.hits += 1;
176                return Some(coeff_hash);
177            }
178        }
179        self.misses += 1;
180        None
181    }
182
183    /// Cache leading coefficient hash for expression and variable
184    pub fn set_leading_coeff(&mut self, expr_hash: u64, var: &str, coeff_hash: u64) {
185        self.maybe_evict_lru(&CacheType::LeadingCoeff);
186        self.leading_coeff_cache
187            .entry(expr_hash)
188            .or_insert_with(|| CacheEntry::new(HashMap::new()))
189            .value
190            .insert(var.to_owned(), coeff_hash);
191    }
192
193    /// Get cached content hash for expression and variable
194    pub fn get_content(&mut self, expr_hash: u64, var: &str) -> Option<u64> {
195        if let Some(entry) = self.content_cache.get_mut(&expr_hash) {
196            entry.touch();
197            if let Some(&content_hash) = entry.value.get(var) {
198                self.hits += 1;
199                return Some(content_hash);
200            }
201        }
202        self.misses += 1;
203        None
204    }
205
206    /// Cache content hash for expression and variable
207    pub fn set_content(&mut self, expr_hash: u64, var: &str, content_hash: u64) {
208        self.maybe_evict_lru(&CacheType::Content);
209        self.content_cache
210            .entry(expr_hash)
211            .or_insert_with(|| CacheEntry::new(HashMap::new()))
212            .value
213            .insert(var.to_owned(), content_hash);
214    }
215
216    /// Get cached IntPoly for expression
217    ///
218    /// Returns the cached IntPoly and variable if available.
219    /// This eliminates repeated Expression → IntPoly conversions.
220    pub fn get_intpoly(&mut self, expr_hash: u64) -> Option<(IntPoly, Symbol)> {
221        if let Some(entry) = self.intpoly_cache.get_mut(&expr_hash) {
222            entry.touch();
223            self.intpoly_hits += 1;
224            return Some(entry.value.clone());
225        }
226        self.intpoly_misses += 1;
227        None
228    }
229
230    /// Cache IntPoly for expression
231    ///
232    /// Caches the IntPoly representation along with its variable.
233    pub fn set_intpoly(&mut self, expr_hash: u64, poly: IntPoly, var: Symbol) {
234        self.maybe_evict_lru(&CacheType::IntPoly);
235        self.intpoly_cache
236            .insert(expr_hash, CacheEntry::new((poly, var)));
237    }
238
239    /// Clear all caches
240    pub fn clear(&mut self) {
241        self.degree_cache.clear();
242        self.classification_cache.clear();
243        self.leading_coeff_cache.clear();
244        self.content_cache.clear();
245        self.intpoly_cache.clear();
246        self.hits = 0;
247        self.misses = 0;
248        self.intpoly_hits = 0;
249        self.intpoly_misses = 0;
250    }
251
252    /// Get cache statistics
253    pub fn stats(&self) -> CacheStats {
254        let total_hits = self.hits + self.intpoly_hits;
255        let total_misses = self.misses + self.intpoly_misses;
256        CacheStats {
257            degree_entries: self.degree_cache.len(),
258            classification_entries: self.classification_cache.len(),
259            leading_coeff_entries: self.leading_coeff_cache.len(),
260            content_entries: self.content_cache.len(),
261            intpoly_entries: self.intpoly_cache.len(),
262            hits: total_hits,
263            misses: total_misses,
264            intpoly_hits: self.intpoly_hits,
265            intpoly_misses: self.intpoly_misses,
266            hit_rate: if total_hits + total_misses > 0 {
267                total_hits as f64 / (total_hits + total_misses) as f64
268            } else {
269                0.0
270            },
271            intpoly_hit_rate: if self.intpoly_hits + self.intpoly_misses > 0 {
272                self.intpoly_hits as f64 / (self.intpoly_hits + self.intpoly_misses) as f64
273            } else {
274                0.0
275            },
276        }
277    }
278
279    /// Evict least recently used entries from a specific cache
280    fn maybe_evict_lru(&mut self, cache_type: &CacheType) {
281        match cache_type {
282            CacheType::Degree => {
283                if self.degree_cache.len() >= self.max_entries {
284                    self.evict_lru_from_degree_cache();
285                }
286            }
287            CacheType::Classification => {
288                if self.classification_cache.len() >= self.max_entries {
289                    self.evict_lru_from_classification_cache();
290                }
291            }
292            CacheType::LeadingCoeff => {
293                if self.leading_coeff_cache.len() >= self.max_entries {
294                    self.evict_lru_from_leading_coeff_cache();
295                }
296            }
297            CacheType::Content => {
298                if self.content_cache.len() >= self.max_entries {
299                    self.evict_lru_from_content_cache();
300                }
301            }
302            CacheType::IntPoly => {
303                if self.intpoly_cache.len() >= self.max_entries {
304                    self.evict_lru_from_intpoly_cache();
305                }
306            }
307        }
308    }
309
310    fn evict_lru_from_degree_cache(&mut self) {
311        let to_remove = self.max_entries / 4;
312        let mut entries: Vec<_> = self
313            .degree_cache
314            .iter()
315            .map(|(k, v)| (*k, v.last_access))
316            .collect();
317        entries.sort_by_key(|(_, access)| *access);
318
319        for (key, _) in entries.into_iter().take(to_remove) {
320            self.degree_cache.remove(&key);
321        }
322    }
323
324    fn evict_lru_from_classification_cache(&mut self) {
325        let to_remove = self.max_entries / 4;
326        let mut entries: Vec<_> = self
327            .classification_cache
328            .iter()
329            .map(|(k, v)| (*k, v.last_access))
330            .collect();
331        entries.sort_by_key(|(_, access)| *access);
332
333        for (key, _) in entries.into_iter().take(to_remove) {
334            self.classification_cache.remove(&key);
335        }
336    }
337
338    fn evict_lru_from_leading_coeff_cache(&mut self) {
339        let to_remove = self.max_entries / 4;
340        let mut entries: Vec<_> = self
341            .leading_coeff_cache
342            .iter()
343            .map(|(k, v)| (*k, v.last_access))
344            .collect();
345        entries.sort_by_key(|(_, access)| *access);
346
347        for (key, _) in entries.into_iter().take(to_remove) {
348            self.leading_coeff_cache.remove(&key);
349        }
350    }
351
352    fn evict_lru_from_content_cache(&mut self) {
353        let to_remove = self.max_entries / 4;
354        let mut entries: Vec<_> = self
355            .content_cache
356            .iter()
357            .map(|(k, v)| (*k, v.last_access))
358            .collect();
359        entries.sort_by_key(|(_, access)| *access);
360
361        for (key, _) in entries.into_iter().take(to_remove) {
362            self.content_cache.remove(&key);
363        }
364    }
365
366    fn evict_lru_from_intpoly_cache(&mut self) {
367        let to_remove = self.max_entries / 4;
368        let mut entries: Vec<_> = self
369            .intpoly_cache
370            .iter()
371            .map(|(k, v)| (*k, v.last_access))
372            .collect();
373        entries.sort_by_key(|(_, access)| *access);
374
375        for (key, _) in entries.into_iter().take(to_remove) {
376            self.intpoly_cache.remove(&key);
377        }
378    }
379}
380
/// Selects which side table of [`PolynomialCache`] an eviction check targets.
enum CacheType {
    /// The per-variable degree table.
    Degree,
    /// The classification table.
    Classification,
    /// The per-variable leading-coefficient table.
    LeadingCoeff,
    /// The per-variable content table.
    Content,
    /// The `Expression` → `IntPoly` conversion table.
    IntPoly,
}
389
/// Snapshot of a [`PolynomialCache`]'s table sizes and hit/miss counters,
/// intended for monitoring.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of expression keys in the degree table.
    pub degree_entries: usize,
    /// Number of expression keys in the classification table.
    pub classification_entries: usize,
    /// Number of expression keys in the leading-coefficient table.
    pub leading_coeff_entries: usize,
    /// Number of expression keys in the content table.
    pub content_entries: usize,
    /// Number of expression keys in the `IntPoly` table.
    pub intpoly_entries: usize,
    /// Total hits across all tables. This includes the `IntPoly` hits that
    /// are also reported separately in `intpoly_hits`.
    pub hits: u64,
    /// Total misses across all tables. This includes the `IntPoly` misses
    /// that are also reported separately in `intpoly_misses`.
    pub misses: u64,
    /// Hits in the `IntPoly` table only.
    pub intpoly_hits: u64,
    /// Misses in the `IntPoly` table only.
    pub intpoly_misses: u64,
    /// `hits / (hits + misses)`, or `0.0` when there have been no lookups.
    pub hit_rate: f64,
    /// `intpoly_hits / (intpoly_hits + intpoly_misses)`, or `0.0` when the
    /// `IntPoly` table has not been queried.
    pub intpoly_hit_rate: f64,
}
405
406impl Default for PolynomialCache {
407    fn default() -> Self {
408        Self::new()
409    }
410}
411
412// Thread-local cache instance
413thread_local! {
414    static CACHE: RefCell<PolynomialCache> = RefCell::new(PolynomialCache::new());
415}
416
417/// Access the thread-local polynomial cache
418pub fn with_cache<F, R>(f: F) -> R
419where
420    F: FnOnce(&mut PolynomialCache) -> R,
421{
422    CACHE.with(|cache| f(&mut cache.borrow_mut()))
423}
424
425/// Clear the thread-local polynomial cache
426pub fn clear_cache() {
427    with_cache(|cache| cache.clear());
428}
429
430/// Get statistics from the thread-local polynomial cache
431pub fn cache_stats() -> CacheStats {
432    with_cache(|cache| cache.stats())
433}
434
435/// Get or compute IntPoly representation with caching
436///
437/// This is the main entry point for eliminating internal bridging.
438/// It computes Expression → IntPoly once and caches the result.
439/// Subsequent calls for the same expression hit the cache.
440///
441/// # Arguments
442/// * `expr` - The Expression to convert
443/// * `hash` - Pre-computed structural hash of the expression
444/// * `compute_fn` - Function to compute IntPoly if not cached
445///
446/// # Returns
447/// `Some((IntPoly, Symbol))` if conversion succeeds, `None` otherwise
448pub fn get_or_compute_intpoly<F>(expr_hash: u64, compute_fn: F) -> Option<(IntPoly, Symbol)>
449where
450    F: FnOnce() -> Option<(IntPoly, Symbol)>,
451{
452    with_cache(|cache| {
453        if let Some(cached) = cache.get_intpoly(expr_hash) {
454            return Some(cached);
455        }
456
457        if let Some((poly, var)) = compute_fn() {
458            cache.set_intpoly(expr_hash, poly.clone(), var.clone());
459            Some((poly, var))
460        } else {
461            None
462        }
463    })
464}
465
#[cfg(test)]
mod tests {
    //! Unit tests for `PolynomialCache` and the thread-local helper functions.

    use super::*;

    #[test]
    fn test_cache_degree() {
        let mut cache = PolynomialCache::new();
        cache.set_degree(12345, "x", 5);

        // Only the known expression + known variable pair should hit.
        let cases: [(u64, &str, Option<i64>); 3] =
            [(12345, "x", Some(5)), (12345, "y", None), (99999, "x", None)];
        for (expr_hash, var, expected) in cases {
            assert_eq!(cache.get_degree(expr_hash, var), expected);
        }
    }

    #[test]
    fn test_cache_classification() {
        let mut cache = PolynomialCache::new();
        let univariate = CachedClassification::Univariate {
            var: "x".to_string(),
            degree: 3,
        };
        cache.set_classification(12345, univariate);

        assert!(matches!(
            cache.get_classification(12345),
            Some(CachedClassification::Univariate { .. })
        ));
    }

    #[test]
    fn test_thread_local_cache() {
        with_cache(|cache| cache.set_degree(111, "x", 2));

        assert_eq!(with_cache(|cache| cache.get_degree(111, "x")), Some(2));
    }

    #[test]
    fn test_cache_lru_eviction() {
        let capacity = 10;
        let mut cache = PolynomialCache::with_capacity(capacity);
        (0..15u64).for_each(|i| cache.set_degree(i, "x", i as i64));

        assert!(
            cache.stats().degree_entries <= capacity,
            "Cache should have evicted entries"
        );
    }

    #[test]
    fn test_cache_hit_tracking() {
        let mut cache = PolynomialCache::new();
        cache.set_degree(123, "x", 5);

        // One hit, then a missing variable and a missing expression.
        for (expr_hash, var) in [(123, "x"), (123, "y"), (999, "x")] {
            let _ = cache.get_degree(expr_hash, var);
        }

        let stats = cache.stats();
        assert_eq!((stats.hits, stats.misses), (1, 2));
    }

    #[test]
    fn test_cache_leading_coeff() {
        let mut cache = PolynomialCache::new();
        cache.set_leading_coeff(12345, "x", 999);

        assert_eq!(cache.get_leading_coeff(12345, "x"), Some(999));
        assert_eq!(cache.get_leading_coeff(12345, "y"), None);
    }

    #[test]
    fn test_cache_content() {
        let mut cache = PolynomialCache::new();
        cache.set_content(12345, "x", 777);

        assert_eq!(cache.get_content(12345, "x"), Some(777));
        assert_eq!(cache.get_content(12345, "y"), None);
    }

    #[test]
    fn test_cache_stats_helper() {
        clear_cache();
        with_cache(|cache| {
            cache.set_degree(1, "x", 1);
            cache.set_classification(2, CachedClassification::Integer);
        });

        let CacheStats {
            degree_entries,
            classification_entries,
            ..
        } = cache_stats();
        assert_eq!(degree_entries, 1);
        assert_eq!(classification_entries, 1);
    }

    #[test]
    fn test_intpoly_cache() {
        use crate::symbol;

        let mut cache = PolynomialCache::new();
        let x = symbol!(x);
        let poly = IntPoly::from_coeffs(vec![1, 2, 3]);

        cache.set_intpoly(12345, poly.clone(), x.clone());
        let (cached_poly, cached_var) = cache.get_intpoly(12345).unwrap();
        assert_eq!(cached_poly, poly);
        assert_eq!(cached_var, x);

        assert!(cache.get_intpoly(99999).is_none());
    }

    #[test]
    fn test_get_or_compute_intpoly() {
        use crate::symbol;

        clear_cache();
        let x = symbol!(x);
        let poly = IntPoly::from_coeffs(vec![1, 2, 3]);
        let hash = 54321u64;

        // The closure only captures shared references, so it is `Copy` and
        // can be handed to both calls; `calls` records how often it ran.
        let calls = std::cell::Cell::new(0);
        let compute = || {
            calls.set(calls.get() + 1);
            Some((poly.clone(), x.clone()))
        };

        assert!(get_or_compute_intpoly(hash, compute).is_some());
        assert_eq!(calls.get(), 1);

        // Second request must be served from the cache without recomputing.
        assert!(get_or_compute_intpoly(hash, compute).is_some());
        assert_eq!(calls.get(), 1);

        assert!(cache_stats().intpoly_hits >= 1);
    }
}