quantrs2_symengine_pure/cache/
mod.rs

1//! Expression caching and memoization.
2//!
3//! This module provides caching mechanisms for expensive operations
4//! like evaluation, simplification, and complex number operations.
5//!
6//! ## Features
7//!
8//! - **`EvalCache`**: Thread-safe cache for real-valued evaluation results with LRU eviction
9//! - **`ComplexEvalCache`**: Thread-safe cache for complex-valued evaluations (quantum amplitudes)
10//! - **`SimplificationCache`**: Cache for expression simplification results
11//! - **`BatchEvalCache`**: Optimized for VQE optimization loops with parameter sweeps
12//! - **`CachedEvaluator`**: Convenient wrapper with all caching features integrated
13//! - **`ExpressionCache`**: Hash consing for structural sharing of expressions
14//!
15//! ## Performance Benefits
16//!
17//! Expression caching is critical for quantum computing applications:
18//!
19//! - **VQE/QAOA loops**: Same expressions evaluated thousands of times with different parameters
20//! - **Gradient computation**: Derivatives computed repeatedly during optimization
21//! - **Circuit simulation**: Gate matrices cached after first computation
22//!
23//! ## Example
24//!
25//! ```ignore
26//! use quantrs2_symengine_pure::cache::{CachedEvaluator, hash_params};
27//! use quantrs2_symengine_pure::Expression;
28//! use std::collections::HashMap;
29//!
30//! let evaluator = CachedEvaluator::new();
31//! let expr = Expression::symbol("x").sin();
32//!
33//! // First evaluation computes the result
34//! let mut params = HashMap::new();
35//! params.insert("x".to_string(), 0.5);
36//! let result1 = evaluator.eval(&expr, &params).unwrap();
37//!
38//! // Second evaluation retrieves from cache
39//! let result2 = evaluator.eval(&expr, &params).unwrap();
40//! assert!((result1 - result2).abs() < 1e-10);
41//!
42//! // Check hit rate
43//! let stats = evaluator.stats();
44//! println!("Cache hit rate: {:.1}%", stats.overall_hit_rate() * 100.0);
45//! ```
46
47use std::collections::HashMap;
48use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
49use std::sync::Arc;
50
51use dashmap::DashMap;
52use rustc_hash::FxHasher;
53use scirs2_core::Complex64;
54
55use crate::error::SymEngineResult;
56use crate::expr::Expression;
57
/// Default capacity, in entries, used by the `new()` constructors of the
/// evaluation and simplification caches.
pub const DEFAULT_MAX_CACHE_SIZE: usize = 10 * 1_000;
60
61/// Thread-safe cache for expression evaluation results with LRU eviction.
62///
63/// Uses DashMap for concurrent access and FxHasher for fast hashing.
64pub struct EvalCache {
65    cache: DashMap<(u64, u64), CachedValue<f64>, std::hash::BuildHasherDefault<FxHasher>>,
66    max_size: usize,
67    access_counter: AtomicU64,
68    hits: AtomicUsize,
69    misses: AtomicUsize,
70}
71
/// A stored value paired with the logical time it was last read or written.
///
/// The `last_access` stamp is what the LRU eviction routines sort on.
#[derive(Clone)]
struct CachedValue<T> {
    /// The cached payload.
    value: T,
    /// Logical access time (larger means more recent).
    last_access: u64,
}
78
79impl EvalCache {
80    /// Create a new evaluation cache with default size
81    #[must_use]
82    pub fn new() -> Self {
83        Self::with_capacity(DEFAULT_MAX_CACHE_SIZE)
84    }
85
86    /// Create a new evaluation cache with specified maximum size
87    #[must_use]
88    pub fn with_capacity(max_size: usize) -> Self {
89        Self {
90            cache: DashMap::with_hasher(std::hash::BuildHasherDefault::<FxHasher>::default()),
91            max_size,
92            access_counter: AtomicU64::new(0),
93            hits: AtomicUsize::new(0),
94            misses: AtomicUsize::new(0),
95        }
96    }
97
98    /// Get or compute an evaluation result
99    pub fn get_or_compute<F>(&self, expr_hash: u64, params_hash: u64, compute: F) -> f64
100    where
101        F: FnOnce() -> f64,
102    {
103        let key = (expr_hash, params_hash);
104        let access_time = self.access_counter.fetch_add(1, Ordering::Relaxed);
105
106        if let Some(mut entry) = self.cache.get_mut(&key) {
107            self.hits.fetch_add(1, Ordering::Relaxed);
108            entry.last_access = access_time;
109            return entry.value;
110        }
111
112        self.misses.fetch_add(1, Ordering::Relaxed);
113        let result = compute();
114
115        // Check if we need to evict
116        if self.cache.len() >= self.max_size {
117            self.evict_lru();
118        }
119
120        self.cache.insert(
121            key,
122            CachedValue {
123                value: result,
124                last_access: access_time,
125            },
126        );
127        result
128    }
129
130    /// Try to get a cached value without computing
131    #[must_use]
132    pub fn get(&self, expr_hash: u64, params_hash: u64) -> Option<f64> {
133        let key = (expr_hash, params_hash);
134        let access_time = self.access_counter.fetch_add(1, Ordering::Relaxed);
135
136        self.cache.get_mut(&key).map(|mut entry| {
137            entry.last_access = access_time;
138            entry.value
139        })
140    }
141
142    /// Get or compute with Result return type
143    pub fn get_or_try_compute<F, E>(
144        &self,
145        expr_hash: u64,
146        params_hash: u64,
147        compute: F,
148    ) -> Result<f64, E>
149    where
150        F: FnOnce() -> Result<f64, E>,
151    {
152        let key = (expr_hash, params_hash);
153        let access_time = self.access_counter.fetch_add(1, Ordering::Relaxed);
154
155        if let Some(mut entry) = self.cache.get_mut(&key) {
156            self.hits.fetch_add(1, Ordering::Relaxed);
157            entry.last_access = access_time;
158            return Ok(entry.value);
159        }
160
161        self.misses.fetch_add(1, Ordering::Relaxed);
162        let result = compute()?;
163
164        if self.cache.len() >= self.max_size {
165            self.evict_lru();
166        }
167
168        self.cache.insert(
169            key,
170            CachedValue {
171                value: result,
172                last_access: access_time,
173            },
174        );
175        Ok(result)
176    }
177
178    /// Insert a value into the cache
179    pub fn insert(&self, expr_hash: u64, params_hash: u64, value: f64) {
180        let key = (expr_hash, params_hash);
181        let access_time = self.access_counter.fetch_add(1, Ordering::Relaxed);
182
183        if self.cache.len() >= self.max_size {
184            self.evict_lru();
185        }
186
187        self.cache.insert(
188            key,
189            CachedValue {
190                value,
191                last_access: access_time,
192            },
193        );
194    }
195
196    /// Evict the least recently used entries (removes ~10% of cache)
197    fn evict_lru(&self) {
198        let evict_count = self.max_size / 10;
199        if evict_count == 0 {
200            return;
201        }
202
203        // Collect entries sorted by access time
204        let mut entries: Vec<_> = self
205            .cache
206            .iter()
207            .map(|e| (*e.key(), e.value().last_access))
208            .collect();
209        entries.sort_by_key(|(_, access)| *access);
210
211        // Remove oldest entries
212        for (key, _) in entries.into_iter().take(evict_count) {
213            self.cache.remove(&key);
214        }
215    }
216
217    /// Clear the cache
218    pub fn clear(&self) {
219        self.cache.clear();
220        self.hits.store(0, Ordering::Relaxed);
221        self.misses.store(0, Ordering::Relaxed);
222    }
223
224    /// Get the number of cached entries
225    #[must_use]
226    pub fn len(&self) -> usize {
227        self.cache.len()
228    }
229
230    /// Check if the cache is empty
231    #[must_use]
232    pub fn is_empty(&self) -> bool {
233        self.cache.is_empty()
234    }
235
236    /// Get cache statistics
237    #[must_use]
238    pub fn stats(&self) -> CacheStats {
239        let hits = self.hits.load(Ordering::Relaxed);
240        let misses = self.misses.load(Ordering::Relaxed);
241        CacheStats {
242            size: self.cache.len(),
243            max_size: self.max_size,
244            hits,
245            misses,
246            hit_rate: if hits + misses > 0 {
247                hits as f64 / (hits + misses) as f64
248            } else {
249                0.0
250            },
251        }
252    }
253}
254
/// Point-in-time snapshot of a cache's occupancy and effectiveness.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Entries currently stored.
    pub size: usize,
    /// Capacity the cache was configured with.
    pub max_size: usize,
    /// Lookups served from the cache.
    pub hits: usize,
    /// Lookups that required a computation.
    pub misses: usize,
    /// `hits / (hits + misses)`, or `0.0` when there have been no lookups.
    pub hit_rate: f64,
}
269
270impl Default for EvalCache {
271    fn default() -> Self {
272        Self::new()
273    }
274}
275
276/// Hash consing for structural sharing of expressions
277pub struct ExpressionCache {
278    cache: DashMap<u64, Arc<Expression>, std::hash::BuildHasherDefault<FxHasher>>,
279}
280
281impl ExpressionCache {
282    /// Create a new expression cache
283    #[must_use]
284    pub fn new() -> Self {
285        Self {
286            cache: DashMap::with_hasher(std::hash::BuildHasherDefault::<FxHasher>::default()),
287        }
288    }
289
290    /// Get or insert an expression, returning a shared reference
291    pub fn get_or_insert(&self, expr: Expression) -> Arc<Expression> {
292        let hash = compute_hash(&expr);
293        self.cache
294            .entry(hash)
295            .or_insert_with(|| Arc::new(expr))
296            .clone()
297    }
298
299    /// Clear the cache
300    pub fn clear(&self) {
301        self.cache.clear();
302    }
303}
304
305impl Default for ExpressionCache {
306    fn default() -> Self {
307        Self::new()
308    }
309}
310
311// =========================================================================
312// Complex Evaluation Cache
313// =========================================================================
314
315/// Thread-safe cache for complex number evaluation results.
316///
317/// Used for caching quantum amplitude calculations and complex-valued
318/// expression evaluations.
319pub struct ComplexEvalCache {
320    cache: DashMap<(u64, u64), CachedValue<Complex64>, std::hash::BuildHasherDefault<FxHasher>>,
321    max_size: usize,
322    access_counter: AtomicU64,
323    hits: AtomicUsize,
324    misses: AtomicUsize,
325}
326
327impl ComplexEvalCache {
328    /// Create a new complex evaluation cache with default size
329    #[must_use]
330    pub fn new() -> Self {
331        Self::with_capacity(DEFAULT_MAX_CACHE_SIZE)
332    }
333
334    /// Create a new complex evaluation cache with specified maximum size
335    #[must_use]
336    pub fn with_capacity(max_size: usize) -> Self {
337        Self {
338            cache: DashMap::with_hasher(std::hash::BuildHasherDefault::<FxHasher>::default()),
339            max_size,
340            access_counter: AtomicU64::new(0),
341            hits: AtomicUsize::new(0),
342            misses: AtomicUsize::new(0),
343        }
344    }
345
346    /// Get or compute a complex evaluation result
347    pub fn get_or_compute<F>(&self, expr_hash: u64, params_hash: u64, compute: F) -> Complex64
348    where
349        F: FnOnce() -> Complex64,
350    {
351        let key = (expr_hash, params_hash);
352        let access_time = self.access_counter.fetch_add(1, Ordering::Relaxed);
353
354        if let Some(mut entry) = self.cache.get_mut(&key) {
355            self.hits.fetch_add(1, Ordering::Relaxed);
356            entry.last_access = access_time;
357            return entry.value;
358        }
359
360        self.misses.fetch_add(1, Ordering::Relaxed);
361        let result = compute();
362
363        if self.cache.len() >= self.max_size {
364            self.evict_lru();
365        }
366
367        self.cache.insert(
368            key,
369            CachedValue {
370                value: result,
371                last_access: access_time,
372            },
373        );
374        result
375    }
376
377    /// Get or compute with Result return type
378    pub fn get_or_try_compute<F, E>(
379        &self,
380        expr_hash: u64,
381        params_hash: u64,
382        compute: F,
383    ) -> Result<Complex64, E>
384    where
385        F: FnOnce() -> Result<Complex64, E>,
386    {
387        let key = (expr_hash, params_hash);
388        let access_time = self.access_counter.fetch_add(1, Ordering::Relaxed);
389
390        if let Some(mut entry) = self.cache.get_mut(&key) {
391            self.hits.fetch_add(1, Ordering::Relaxed);
392            entry.last_access = access_time;
393            return Ok(entry.value);
394        }
395
396        self.misses.fetch_add(1, Ordering::Relaxed);
397        let result = compute()?;
398
399        if self.cache.len() >= self.max_size {
400            self.evict_lru();
401        }
402
403        self.cache.insert(
404            key,
405            CachedValue {
406                value: result,
407                last_access: access_time,
408            },
409        );
410        Ok(result)
411    }
412
413    /// Evict the least recently used entries
414    fn evict_lru(&self) {
415        let evict_count = self.max_size / 10;
416        if evict_count == 0 {
417            return;
418        }
419
420        let mut entries: Vec<_> = self
421            .cache
422            .iter()
423            .map(|e| (*e.key(), e.value().last_access))
424            .collect();
425        entries.sort_by_key(|(_, access)| *access);
426
427        for (key, _) in entries.into_iter().take(evict_count) {
428            self.cache.remove(&key);
429        }
430    }
431
432    /// Clear the cache
433    pub fn clear(&self) {
434        self.cache.clear();
435        self.hits.store(0, Ordering::Relaxed);
436        self.misses.store(0, Ordering::Relaxed);
437    }
438
439    /// Get the number of cached entries
440    #[must_use]
441    pub fn len(&self) -> usize {
442        self.cache.len()
443    }
444
445    /// Check if the cache is empty
446    #[must_use]
447    pub fn is_empty(&self) -> bool {
448        self.cache.is_empty()
449    }
450
451    /// Get cache statistics
452    #[must_use]
453    pub fn stats(&self) -> CacheStats {
454        let hits = self.hits.load(Ordering::Relaxed);
455        let misses = self.misses.load(Ordering::Relaxed);
456        CacheStats {
457            size: self.cache.len(),
458            max_size: self.max_size,
459            hits,
460            misses,
461            hit_rate: if hits + misses > 0 {
462                hits as f64 / (hits + misses) as f64
463            } else {
464                0.0
465            },
466        }
467    }
468}
469
470impl Default for ComplexEvalCache {
471    fn default() -> Self {
472        Self::new()
473    }
474}
475
476// =========================================================================
477// Simplification Cache
478// =========================================================================
479
480/// Thread-safe cache for expression simplification results.
481///
482/// Caches the result of expensive simplification operations to avoid
483/// re-running e-graph saturation for the same expressions.
484pub struct SimplificationCache {
485    cache: DashMap<u64, CachedValue<Expression>, std::hash::BuildHasherDefault<FxHasher>>,
486    max_size: usize,
487    access_counter: AtomicU64,
488    hits: AtomicUsize,
489    misses: AtomicUsize,
490}
491
492impl SimplificationCache {
493    /// Create a new simplification cache with default size
494    #[must_use]
495    pub fn new() -> Self {
496        Self::with_capacity(DEFAULT_MAX_CACHE_SIZE)
497    }
498
499    /// Create a new simplification cache with specified maximum size
500    #[must_use]
501    pub fn with_capacity(max_size: usize) -> Self {
502        Self {
503            cache: DashMap::with_hasher(std::hash::BuildHasherDefault::<FxHasher>::default()),
504            max_size,
505            access_counter: AtomicU64::new(0),
506            hits: AtomicUsize::new(0),
507            misses: AtomicUsize::new(0),
508        }
509    }
510
511    /// Get or compute a simplified expression
512    pub fn get_or_simplify<F>(&self, expr: &Expression, simplify: F) -> Expression
513    where
514        F: FnOnce() -> Expression,
515    {
516        let expr_hash = compute_hash(expr);
517        let access_time = self.access_counter.fetch_add(1, Ordering::Relaxed);
518
519        if let Some(mut entry) = self.cache.get_mut(&expr_hash) {
520            self.hits.fetch_add(1, Ordering::Relaxed);
521            entry.last_access = access_time;
522            return entry.value.clone();
523        }
524
525        self.misses.fetch_add(1, Ordering::Relaxed);
526        let result = simplify();
527
528        if self.cache.len() >= self.max_size {
529            self.evict_lru();
530        }
531
532        self.cache.insert(
533            expr_hash,
534            CachedValue {
535                value: result.clone(),
536                last_access: access_time,
537            },
538        );
539        result
540    }
541
542    /// Evict the least recently used entries
543    fn evict_lru(&self) {
544        let evict_count = self.max_size / 10;
545        if evict_count == 0 {
546            return;
547        }
548
549        let mut entries: Vec<_> = self
550            .cache
551            .iter()
552            .map(|e| (*e.key(), e.value().last_access))
553            .collect();
554        entries.sort_by_key(|(_, access)| *access);
555
556        for (key, _) in entries.into_iter().take(evict_count) {
557            self.cache.remove(&key);
558        }
559    }
560
561    /// Clear the cache
562    pub fn clear(&self) {
563        self.cache.clear();
564        self.hits.store(0, Ordering::Relaxed);
565        self.misses.store(0, Ordering::Relaxed);
566    }
567
568    /// Get the number of cached entries
569    #[must_use]
570    pub fn len(&self) -> usize {
571        self.cache.len()
572    }
573
574    /// Check if the cache is empty
575    #[must_use]
576    pub fn is_empty(&self) -> bool {
577        self.cache.is_empty()
578    }
579
580    /// Get cache statistics
581    #[must_use]
582    pub fn stats(&self) -> CacheStats {
583        let hits = self.hits.load(Ordering::Relaxed);
584        let misses = self.misses.load(Ordering::Relaxed);
585        CacheStats {
586            size: self.cache.len(),
587            max_size: self.max_size,
588            hits,
589            misses,
590            hit_rate: if hits + misses > 0 {
591                hits as f64 / (hits + misses) as f64
592            } else {
593                0.0
594            },
595        }
596    }
597}
598
599impl Default for SimplificationCache {
600    fn default() -> Self {
601        Self::new()
602    }
603}
604
605// =========================================================================
606// Batch Evaluation Cache
607// =========================================================================
608
609/// Cache for batch evaluation results in VQE optimization loops.
610///
611/// Optimized for scenarios where the same expression is evaluated many times
612/// with slightly different parameter sets (e.g., parameter sweeps).
613pub struct BatchEvalCache {
614    /// Expression hash -> (params_hash -> result)
615    cache: DashMap<
616        u64,
617        DashMap<u64, f64, std::hash::BuildHasherDefault<FxHasher>>,
618        std::hash::BuildHasherDefault<FxHasher>,
619    >,
620    max_expressions: usize,
621    max_params_per_expr: usize,
622}
623
624impl BatchEvalCache {
625    /// Create a new batch evaluation cache
626    #[must_use]
627    pub fn new() -> Self {
628        Self::with_capacity(1000, 1000)
629    }
630
631    /// Create a new batch evaluation cache with specified capacities
632    #[must_use]
633    pub fn with_capacity(max_expressions: usize, max_params_per_expr: usize) -> Self {
634        Self {
635            cache: DashMap::with_hasher(std::hash::BuildHasherDefault::<FxHasher>::default()),
636            max_expressions,
637            max_params_per_expr,
638        }
639    }
640
641    /// Get or compute a batch of evaluation results
642    pub fn get_or_compute_batch<F>(
643        &self,
644        expr_hash: u64,
645        param_hashes: &[u64],
646        compute: F,
647    ) -> Vec<f64>
648    where
649        F: FnOnce(&[usize]) -> Vec<f64>,
650    {
651        // Find which parameter sets we need to compute
652        let expr_cache = self.cache.entry(expr_hash).or_insert_with(|| {
653            DashMap::with_hasher(std::hash::BuildHasherDefault::<FxHasher>::default())
654        });
655
656        let mut results = vec![0.0; param_hashes.len()];
657        let mut missing_indices = Vec::new();
658
659        for (i, &ph) in param_hashes.iter().enumerate() {
660            if let Some(val) = expr_cache.get(&ph) {
661                results[i] = *val;
662            } else {
663                missing_indices.push(i);
664            }
665        }
666
667        // Compute missing values
668        if !missing_indices.is_empty() {
669            let computed = compute(&missing_indices);
670
671            for (j, &i) in missing_indices.iter().enumerate() {
672                results[i] = computed[j];
673                let ph = param_hashes[i];
674
675                // Check if we need to evict from per-expression cache
676                if expr_cache.len() >= self.max_params_per_expr {
677                    // Simple random eviction (for speed)
678                    // Extract key first to avoid holding the iterator lock
679                    let first_key = expr_cache.iter().next().map(|e| *e.key());
680                    if let Some(key) = first_key {
681                        expr_cache.remove(&key);
682                    }
683                }
684
685                expr_cache.insert(ph, computed[j]);
686            }
687        }
688
689        results
690    }
691
692    /// Clear the cache
693    pub fn clear(&self) {
694        self.cache.clear();
695    }
696
697    /// Get the number of cached expressions
698    #[must_use]
699    pub fn len(&self) -> usize {
700        self.cache.len()
701    }
702
703    /// Check if the cache is empty
704    #[must_use]
705    pub fn is_empty(&self) -> bool {
706        self.cache.is_empty()
707    }
708
709    /// Get total number of cached parameter sets across all expressions
710    #[must_use]
711    pub fn total_params_cached(&self) -> usize {
712        self.cache.iter().map(|e| e.value().len()).sum()
713    }
714}
715
716impl Default for BatchEvalCache {
717    fn default() -> Self {
718        Self::new()
719    }
720}
721
722// =========================================================================
723// Cached Expression Evaluator
724// =========================================================================
725
726/// An expression evaluator with integrated caching.
727///
728/// This provides a convenient interface for evaluating expressions with
729/// automatic caching of results.
730#[allow(clippy::struct_field_names)]
731pub struct CachedEvaluator {
732    eval_cache: EvalCache,
733    complex_cache: ComplexEvalCache,
734    simplification_cache: SimplificationCache,
735}
736
737impl CachedEvaluator {
738    /// Create a new cached evaluator
739    #[must_use]
740    pub fn new() -> Self {
741        Self {
742            eval_cache: EvalCache::new(),
743            complex_cache: ComplexEvalCache::new(),
744            simplification_cache: SimplificationCache::new(),
745        }
746    }
747
748    /// Create a new cached evaluator with specified cache sizes
749    #[must_use]
750    pub fn with_capacity(eval_size: usize, complex_size: usize, simplify_size: usize) -> Self {
751        Self {
752            eval_cache: EvalCache::with_capacity(eval_size),
753            complex_cache: ComplexEvalCache::with_capacity(complex_size),
754            simplification_cache: SimplificationCache::with_capacity(simplify_size),
755        }
756    }
757
758    /// Evaluate an expression with caching
759    pub fn eval(&self, expr: &Expression, values: &HashMap<String, f64>) -> SymEngineResult<f64> {
760        let expr_hash = compute_hash(expr);
761        let params_hash = hash_params(values);
762
763        // Use get_or_try_compute to properly track hits/misses
764        self.eval_cache
765            .get_or_try_compute(expr_hash, params_hash, || expr.eval(values))
766    }
767
768    /// Evaluate an expression as complex with caching
769    pub fn eval_complex(
770        &self,
771        expr: &Expression,
772        values: &HashMap<String, f64>,
773    ) -> SymEngineResult<Complex64> {
774        let expr_hash = compute_hash(expr);
775        let params_hash = hash_params(values);
776
777        self.complex_cache
778            .get_or_try_compute(expr_hash, params_hash, || expr.eval_complex(values))
779    }
780
781    /// Simplify an expression with caching
782    pub fn simplify(&self, expr: &Expression) -> Expression {
783        self.simplification_cache
784            .get_or_simplify(expr, || expr.simplify())
785    }
786
787    /// Clear all caches
788    pub fn clear(&self) {
789        self.eval_cache.clear();
790        self.complex_cache.clear();
791        self.simplification_cache.clear();
792    }
793
794    /// Get combined cache statistics
795    #[must_use]
796    pub fn stats(&self) -> CombinedCacheStats {
797        CombinedCacheStats {
798            eval: self.eval_cache.stats(),
799            complex: self.complex_cache.stats(),
800            simplification: self.simplification_cache.stats(),
801        }
802    }
803}
804
805impl Default for CachedEvaluator {
806    fn default() -> Self {
807        Self::new()
808    }
809}
810
811/// Combined statistics for all cache types
812#[derive(Debug, Clone)]
813pub struct CombinedCacheStats {
814    /// Real evaluation cache stats
815    pub eval: CacheStats,
816    /// Complex evaluation cache stats
817    pub complex: CacheStats,
818    /// Simplification cache stats
819    pub simplification: CacheStats,
820}
821
822impl CombinedCacheStats {
823    /// Get the total number of cached entries
824    #[must_use]
825    pub const fn total_size(&self) -> usize {
826        self.eval.size + self.complex.size + self.simplification.size
827    }
828
829    /// Get the total number of cache hits
830    #[must_use]
831    pub const fn total_hits(&self) -> usize {
832        self.eval.hits + self.complex.hits + self.simplification.hits
833    }
834
835    /// Get the total number of cache misses
836    #[must_use]
837    pub const fn total_misses(&self) -> usize {
838        self.eval.misses + self.complex.misses + self.simplification.misses
839    }
840
841    /// Get the overall hit rate
842    #[must_use]
843    pub fn overall_hit_rate(&self) -> f64 {
844        let total = self.total_hits() + self.total_misses();
845        if total > 0 {
846            self.total_hits() as f64 / total as f64
847        } else {
848            0.0
849        }
850    }
851}
852
853// =========================================================================
854// Hash Functions
855// =========================================================================
856
857/// Compute a hash for an expression
858pub fn compute_hash(expr: &Expression) -> u64 {
859    use std::hash::{Hash, Hasher};
860    let mut hasher = FxHasher::default();
861    expr.to_string().hash(&mut hasher);
862    hasher.finish()
863}
864
865/// Compute a hash for a set of real parameters
866pub fn hash_params(params: &HashMap<String, f64>) -> u64 {
867    use std::hash::{Hash, Hasher};
868    let mut hasher = FxHasher::default();
869
870    // Sort keys for consistent hashing
871    let mut keys: Vec<_> = params.keys().collect();
872    keys.sort();
873
874    for key in keys {
875        key.hash(&mut hasher);
876        if let Some(value) = params.get(key) {
877            value.to_bits().hash(&mut hasher);
878        }
879    }
880
881    hasher.finish()
882}
883
884/// Compute a hash for complex parameters
885pub fn hash_complex_params(params: &HashMap<String, Complex64>) -> u64 {
886    use std::hash::{Hash, Hasher};
887    let mut hasher = FxHasher::default();
888
889    let mut keys: Vec<_> = params.keys().collect();
890    keys.sort();
891
892    for key in keys {
893        key.hash(&mut hasher);
894        if let Some(value) = params.get(key) {
895            value.re.to_bits().hash(&mut hasher);
896            value.im.to_bits().hash(&mut hasher);
897        }
898    }
899
900    hasher.finish()
901}
902
903/// Compute a hash for a parameter array (for batch operations)
904pub fn hash_param_array(params: &[f64]) -> u64 {
905    use std::hash::{Hash, Hasher};
906    let mut hasher = FxHasher::default();
907
908    for value in params {
909        value.to_bits().hash(&mut hasher);
910    }
911
912    hasher.finish()
913}
914
915#[cfg(test)]
916#[allow(clippy::approx_constant)]
917mod tests {
918    use super::*;
919
920    #[test]
921    fn test_eval_cache() {
922        let cache = EvalCache::new();
923
924        let result1 = cache.get_or_compute(1, 1, || 42.0);
925        assert!((result1 - 42.0).abs() < 1e-10);
926
927        // Should return cached value
928        let result2 = cache.get_or_compute(1, 1, || 100.0);
929        assert!((result2 - 42.0).abs() < 1e-10);
930
931        assert_eq!(cache.len(), 1);
932    }
933
934    #[test]
935    fn test_eval_cache_stats() {
936        let cache = EvalCache::new();
937
938        // Miss then hit
939        cache.get_or_compute(1, 1, || 42.0);
940        cache.get_or_compute(1, 1, || 42.0);
941
942        let stats = cache.stats();
943        assert_eq!(stats.hits, 1);
944        assert_eq!(stats.misses, 1);
945        assert!((stats.hit_rate - 0.5).abs() < 1e-10);
946    }
947
948    #[test]
949    fn test_eval_cache_lru_eviction() {
950        let cache = EvalCache::with_capacity(10);
951
952        // Fill cache beyond capacity
953        for i in 0..15u64 {
954            cache.get_or_compute(i, 0, || i as f64);
955        }
956
957        // Should have evicted some entries
958        assert!(cache.len() <= 10);
959    }
960
961    #[test]
962    fn test_complex_eval_cache() {
963        let cache = ComplexEvalCache::new();
964
965        let result1 = cache.get_or_compute(1, 1, || Complex64::new(3.0, 4.0));
966        assert!((result1.re - 3.0).abs() < 1e-10);
967        assert!((result1.im - 4.0).abs() < 1e-10);
968
969        // Should return cached value
970        let result2 = cache.get_or_compute(1, 1, || Complex64::new(100.0, 200.0));
971        assert!((result2.re - 3.0).abs() < 1e-10);
972        assert!((result2.im - 4.0).abs() < 1e-10);
973    }
974
975    #[test]
976    fn test_complex_eval_cache_try_compute() {
977        let cache = ComplexEvalCache::new();
978
979        let result: Result<_, &str> =
980            cache.get_or_try_compute(1, 1, || Ok(Complex64::new(1.0, 2.0)));
981        assert!(result.is_ok());
982
983        let stats = cache.stats();
984        assert_eq!(stats.misses, 1);
985
986        // Second call should hit cache
987        let result2: Result<_, &str> =
988            cache.get_or_try_compute(1, 1, || Err("should not be called"));
989        assert!(result2.is_ok());
990
991        let stats = cache.stats();
992        assert_eq!(stats.hits, 1);
993    }
994
995    #[test]
996    fn test_simplification_cache() {
997        let cache = SimplificationCache::new();
998
999        let expr = Expression::symbol("x") + Expression::symbol("x");
1000        let simplified = cache.get_or_simplify(&expr, || {
1001            // This simulates simplification
1002            Expression::int(2) * Expression::symbol("x")
1003        });
1004
1005        // Should have cached
1006        assert_eq!(cache.len(), 1);
1007
1008        // Second call should return cached
1009        let simplified2 = cache.get_or_simplify(&expr, || {
1010            // This should not be called
1011            Expression::symbol("should_not_appear")
1012        });
1013
1014        // Both should be equivalent
1015        assert_eq!(simplified.to_string(), simplified2.to_string());
1016
1017        let stats = cache.stats();
1018        assert_eq!(stats.hits, 1);
1019        assert_eq!(stats.misses, 1);
1020    }
1021
1022    #[test]
1023    fn test_batch_eval_cache() {
1024        let cache = BatchEvalCache::new();
1025
1026        let expr_hash = 12345u64;
1027        let param_hashes = vec![1, 2, 3, 4, 5];
1028
1029        let mut compute_count = 0;
1030        let results = cache.get_or_compute_batch(expr_hash, &param_hashes, |missing| {
1031            compute_count = missing.len();
1032            missing.iter().map(|&i| i as f64 * 10.0).collect()
1033        });
1034
1035        assert_eq!(compute_count, 5); // All were missing
1036        assert!((results[0] - 0.0).abs() < 1e-10);
1037        assert!((results[1] - 10.0).abs() < 1e-10);
1038
1039        // Second call - all should be cached
1040        let mut compute_count2 = 0;
1041        let results2 = cache.get_or_compute_batch(expr_hash, &param_hashes, |missing| {
1042            compute_count2 = missing.len();
1043            missing.iter().map(|&i| i as f64 * 100.0).collect()
1044        });
1045
1046        assert_eq!(compute_count2, 0); // All were cached
1047        assert!((results2[0] - 0.0).abs() < 1e-10);
1048        assert!((results2[1] - 10.0).abs() < 1e-10);
1049    }
1050
1051    #[test]
1052    fn test_batch_eval_cache_partial_hit() {
1053        let cache = BatchEvalCache::new();
1054
1055        let expr_hash = 12345u64;
1056
1057        // First call with params 1, 2, 3
1058        cache.get_or_compute_batch(expr_hash, &[1, 2, 3], |missing| {
1059            missing.iter().map(|&i| i as f64).collect()
1060        });
1061
1062        // Second call with params 2, 3, 4, 5 - 2 and 3 should be cached
1063        let mut computed_indices = Vec::new();
1064        cache.get_or_compute_batch(expr_hash, &[2, 3, 4, 5], |missing| {
1065            computed_indices = missing.to_vec();
1066            missing.iter().map(|&i| i as f64).collect()
1067        });
1068
1069        // Only indices 2 and 3 (params 4 and 5) should be computed
1070        assert_eq!(computed_indices, vec![2, 3]);
1071    }
1072
1073    #[test]
1074    fn test_cached_evaluator() {
1075        let evaluator = CachedEvaluator::new();
1076
1077        let expr = Expression::symbol("x");
1078        let mut values = HashMap::new();
1079        values.insert("x".to_string(), 5.0);
1080
1081        let result1 = evaluator.eval(&expr, &values).expect("should eval");
1082        assert!((result1 - 5.0).abs() < 1e-10);
1083
1084        // Second call should use cache
1085        let result2 = evaluator.eval(&expr, &values).expect("should eval");
1086        assert!((result2 - 5.0).abs() < 1e-10);
1087
1088        let stats = evaluator.stats();
1089        assert_eq!(stats.eval.misses, 1);
1090        assert_eq!(stats.eval.hits, 1);
1091    }
1092
1093    #[test]
1094    fn test_cached_evaluator_complex() {
1095        let evaluator = CachedEvaluator::new();
1096
1097        // Expression: 1 + I (imaginary unit)
1098        let expr = Expression::int(1) + Expression::symbol("I");
1099        let values = HashMap::new();
1100
1101        let result = evaluator.eval_complex(&expr, &values).expect("should eval");
1102        assert!((result.re - 1.0).abs() < 1e-10);
1103        assert!((result.im - 1.0).abs() < 1e-10);
1104
1105        let stats = evaluator.stats();
1106        assert_eq!(stats.complex.misses, 1);
1107    }
1108
1109    #[test]
1110    fn test_cached_evaluator_simplify() {
1111        let evaluator = CachedEvaluator::new();
1112
1113        let expr = Expression::symbol("x") + Expression::int(0);
1114        let simplified = evaluator.simplify(&expr);
1115
1116        // x + 0 should simplify to just x
1117        assert!(simplified.is_symbol() || simplified.to_string().contains('x'));
1118
1119        // Second call should use cache
1120        let simplified2 = evaluator.simplify(&expr);
1121        assert_eq!(simplified.to_string(), simplified2.to_string());
1122
1123        let stats = evaluator.stats();
1124        assert_eq!(stats.simplification.misses, 1);
1125        assert_eq!(stats.simplification.hits, 1);
1126    }
1127
1128    #[test]
1129    fn test_combined_cache_stats() {
1130        let evaluator = CachedEvaluator::new();
1131
1132        // Generate some hits and misses
1133        let expr = Expression::symbol("x");
1134        let mut values = HashMap::new();
1135        values.insert("x".to_string(), 1.0);
1136
1137        // Miss, hit, hit
1138        for _ in 0..3 {
1139            let _ = evaluator.eval(&expr, &values);
1140        }
1141
1142        let stats = evaluator.stats();
1143        assert_eq!(stats.total_size(), 1);
1144        assert_eq!(stats.total_hits(), 2);
1145        assert_eq!(stats.total_misses(), 1);
1146        assert!((stats.overall_hit_rate() - 2.0 / 3.0).abs() < 1e-10);
1147    }
1148
1149    #[test]
1150    fn test_hash_params() {
1151        let mut params1 = HashMap::new();
1152        params1.insert("x".to_string(), 1.0);
1153        params1.insert("y".to_string(), 2.0);
1154
1155        let mut params2 = HashMap::new();
1156        params2.insert("y".to_string(), 2.0);
1157        params2.insert("x".to_string(), 1.0);
1158
1159        // Order shouldn't matter
1160        assert_eq!(hash_params(&params1), hash_params(&params2));
1161    }
1162
1163    #[test]
1164    fn test_hash_complex_params() {
1165        let mut params1 = HashMap::new();
1166        params1.insert("a".to_string(), Complex64::new(1.0, 2.0));
1167        params1.insert("b".to_string(), Complex64::new(3.0, 4.0));
1168
1169        let mut params2 = HashMap::new();
1170        params2.insert("b".to_string(), Complex64::new(3.0, 4.0));
1171        params2.insert("a".to_string(), Complex64::new(1.0, 2.0));
1172
1173        // Order shouldn't matter
1174        assert_eq!(hash_complex_params(&params1), hash_complex_params(&params2));
1175    }
1176
1177    #[test]
1178    fn test_hash_param_array() {
1179        let params1 = [1.0, 2.0, 3.0];
1180        let params2 = [1.0, 2.0, 3.0];
1181        let params3 = [1.0, 2.0, 4.0];
1182
1183        assert_eq!(hash_param_array(&params1), hash_param_array(&params2));
1184        assert_ne!(hash_param_array(&params1), hash_param_array(&params3));
1185    }
1186
1187    #[test]
1188    fn test_expression_cache() {
1189        let cache = ExpressionCache::new();
1190
1191        let expr1 = Expression::symbol("x");
1192        let arc1 = cache.get_or_insert(expr1.clone());
1193        let arc2 = cache.get_or_insert(expr1);
1194
1195        // Should be the same Arc
1196        assert!(Arc::ptr_eq(&arc1, &arc2));
1197    }
1198
1199    #[test]
1200    fn test_cache_clear() {
1201        let cache = EvalCache::new();
1202        cache.get_or_compute(1, 1, || 42.0);
1203        assert_eq!(cache.len(), 1);
1204
1205        cache.clear();
1206        assert!(cache.is_empty());
1207
1208        let stats = cache.stats();
1209        assert_eq!(stats.hits, 0);
1210        assert_eq!(stats.misses, 0);
1211    }
1212}