1use std::collections::HashMap;
48use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
49use std::sync::Arc;
50
51use dashmap::DashMap;
52use rustc_hash::FxHasher;
53use scirs2_core::Complex64;
54
55use crate::error::SymEngineResult;
56use crate::expr::Expression;
57
/// Capacity used by [`EvalCache::new`], [`ComplexEvalCache::new`] and
/// [`SimplificationCache::new`] (and therefore by [`CachedEvaluator::new`]).
pub const DEFAULT_MAX_CACHE_SIZE: usize = 10_000;
60
/// Concurrent cache of real-valued evaluation results keyed by
/// `(expr_hash, params_hash)`.
///
/// All operations take `&self`; storage is a `DashMap` and the counters are atomics.
pub struct EvalCache {
    /// Cached values keyed by `(expr_hash, params_hash)`.
    cache: DashMap<(u64, u64), CachedValue<f64>, std::hash::BuildHasherDefault<FxHasher>>,
    /// Capacity threshold checked before inserting.
    max_size: usize,
    /// Logical clock; each access takes the next value and stamps it on the entry.
    access_counter: AtomicU64,
    /// Lookups answered from the cache by the `get_or_*` methods.
    hits: AtomicUsize,
    /// Lookups in the `get_or_*` methods that had to call the compute closure.
    misses: AtomicUsize,
}
71
/// A cached value together with the logical time of its most recent access.
#[derive(Clone)]
struct CachedValue<T> {
    /// The cached result.
    value: T,
    /// Owning cache's `access_counter` reading at the last read or write of this
    /// entry; the `evict_lru` methods remove entries with the smallest stamps first.
    last_access: u64,
}
78
79impl EvalCache {
80 #[must_use]
82 pub fn new() -> Self {
83 Self::with_capacity(DEFAULT_MAX_CACHE_SIZE)
84 }
85
86 #[must_use]
88 pub fn with_capacity(max_size: usize) -> Self {
89 Self {
90 cache: DashMap::with_hasher(std::hash::BuildHasherDefault::<FxHasher>::default()),
91 max_size,
92 access_counter: AtomicU64::new(0),
93 hits: AtomicUsize::new(0),
94 misses: AtomicUsize::new(0),
95 }
96 }
97
98 pub fn get_or_compute<F>(&self, expr_hash: u64, params_hash: u64, compute: F) -> f64
100 where
101 F: FnOnce() -> f64,
102 {
103 let key = (expr_hash, params_hash);
104 let access_time = self.access_counter.fetch_add(1, Ordering::Relaxed);
105
106 if let Some(mut entry) = self.cache.get_mut(&key) {
107 self.hits.fetch_add(1, Ordering::Relaxed);
108 entry.last_access = access_time;
109 return entry.value;
110 }
111
112 self.misses.fetch_add(1, Ordering::Relaxed);
113 let result = compute();
114
115 if self.cache.len() >= self.max_size {
117 self.evict_lru();
118 }
119
120 self.cache.insert(
121 key,
122 CachedValue {
123 value: result,
124 last_access: access_time,
125 },
126 );
127 result
128 }
129
130 #[must_use]
132 pub fn get(&self, expr_hash: u64, params_hash: u64) -> Option<f64> {
133 let key = (expr_hash, params_hash);
134 let access_time = self.access_counter.fetch_add(1, Ordering::Relaxed);
135
136 self.cache.get_mut(&key).map(|mut entry| {
137 entry.last_access = access_time;
138 entry.value
139 })
140 }
141
142 pub fn get_or_try_compute<F, E>(
144 &self,
145 expr_hash: u64,
146 params_hash: u64,
147 compute: F,
148 ) -> Result<f64, E>
149 where
150 F: FnOnce() -> Result<f64, E>,
151 {
152 let key = (expr_hash, params_hash);
153 let access_time = self.access_counter.fetch_add(1, Ordering::Relaxed);
154
155 if let Some(mut entry) = self.cache.get_mut(&key) {
156 self.hits.fetch_add(1, Ordering::Relaxed);
157 entry.last_access = access_time;
158 return Ok(entry.value);
159 }
160
161 self.misses.fetch_add(1, Ordering::Relaxed);
162 let result = compute()?;
163
164 if self.cache.len() >= self.max_size {
165 self.evict_lru();
166 }
167
168 self.cache.insert(
169 key,
170 CachedValue {
171 value: result,
172 last_access: access_time,
173 },
174 );
175 Ok(result)
176 }
177
178 pub fn insert(&self, expr_hash: u64, params_hash: u64, value: f64) {
180 let key = (expr_hash, params_hash);
181 let access_time = self.access_counter.fetch_add(1, Ordering::Relaxed);
182
183 if self.cache.len() >= self.max_size {
184 self.evict_lru();
185 }
186
187 self.cache.insert(
188 key,
189 CachedValue {
190 value,
191 last_access: access_time,
192 },
193 );
194 }
195
196 fn evict_lru(&self) {
198 let evict_count = self.max_size / 10;
199 if evict_count == 0 {
200 return;
201 }
202
203 let mut entries: Vec<_> = self
205 .cache
206 .iter()
207 .map(|e| (*e.key(), e.value().last_access))
208 .collect();
209 entries.sort_by_key(|(_, access)| *access);
210
211 for (key, _) in entries.into_iter().take(evict_count) {
213 self.cache.remove(&key);
214 }
215 }
216
217 pub fn clear(&self) {
219 self.cache.clear();
220 self.hits.store(0, Ordering::Relaxed);
221 self.misses.store(0, Ordering::Relaxed);
222 }
223
224 #[must_use]
226 pub fn len(&self) -> usize {
227 self.cache.len()
228 }
229
230 #[must_use]
232 pub fn is_empty(&self) -> bool {
233 self.cache.is_empty()
234 }
235
236 #[must_use]
238 pub fn stats(&self) -> CacheStats {
239 let hits = self.hits.load(Ordering::Relaxed);
240 let misses = self.misses.load(Ordering::Relaxed);
241 CacheStats {
242 size: self.cache.len(),
243 max_size: self.max_size,
244 hits,
245 misses,
246 hit_rate: if hits + misses > 0 {
247 hits as f64 / (hits + misses) as f64
248 } else {
249 0.0
250 },
251 }
252 }
253}
254
/// Point-in-time statistics reported by the `stats()` methods of the caches.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of entries in the cache when the snapshot was taken.
    pub size: usize,
    /// Configured capacity of the cache.
    pub max_size: usize,
    /// Lookups answered from the cache.
    pub hits: usize,
    /// Lookups that had to compute a value.
    pub misses: usize,
    /// `hits / (hits + misses)`, or `0.0` when there have been no lookups.
    pub hit_rate: f64,
}
269
270impl Default for EvalCache {
271 fn default() -> Self {
272 Self::new()
273 }
274}
275
/// Interning cache mapping a [`compute_hash`] value to a shared `Arc<Expression>`.
///
/// There is no size bound; entries are only removed by `clear`.
pub struct ExpressionCache {
    /// Interned expressions keyed by the hash of their `to_string()` form.
    cache: DashMap<u64, Arc<Expression>, std::hash::BuildHasherDefault<FxHasher>>,
}
280
281impl ExpressionCache {
282 #[must_use]
284 pub fn new() -> Self {
285 Self {
286 cache: DashMap::with_hasher(std::hash::BuildHasherDefault::<FxHasher>::default()),
287 }
288 }
289
290 pub fn get_or_insert(&self, expr: Expression) -> Arc<Expression> {
292 let hash = compute_hash(&expr);
293 self.cache
294 .entry(hash)
295 .or_insert_with(|| Arc::new(expr))
296 .clone()
297 }
298
299 pub fn clear(&self) {
301 self.cache.clear();
302 }
303}
304
305impl Default for ExpressionCache {
306 fn default() -> Self {
307 Self::new()
308 }
309}
310
/// Concurrent cache of complex-valued evaluation results keyed by
/// `(expr_hash, params_hash)`; the complex counterpart of [`EvalCache`].
pub struct ComplexEvalCache {
    /// Cached values keyed by `(expr_hash, params_hash)`.
    cache: DashMap<(u64, u64), CachedValue<Complex64>, std::hash::BuildHasherDefault<FxHasher>>,
    /// Capacity threshold checked before inserting.
    max_size: usize,
    /// Logical clock; each access takes the next value and stamps it on the entry.
    access_counter: AtomicU64,
    /// Lookups answered from the cache.
    hits: AtomicUsize,
    /// Lookups that had to call the compute closure.
    misses: AtomicUsize,
}
326
327impl ComplexEvalCache {
328 #[must_use]
330 pub fn new() -> Self {
331 Self::with_capacity(DEFAULT_MAX_CACHE_SIZE)
332 }
333
334 #[must_use]
336 pub fn with_capacity(max_size: usize) -> Self {
337 Self {
338 cache: DashMap::with_hasher(std::hash::BuildHasherDefault::<FxHasher>::default()),
339 max_size,
340 access_counter: AtomicU64::new(0),
341 hits: AtomicUsize::new(0),
342 misses: AtomicUsize::new(0),
343 }
344 }
345
346 pub fn get_or_compute<F>(&self, expr_hash: u64, params_hash: u64, compute: F) -> Complex64
348 where
349 F: FnOnce() -> Complex64,
350 {
351 let key = (expr_hash, params_hash);
352 let access_time = self.access_counter.fetch_add(1, Ordering::Relaxed);
353
354 if let Some(mut entry) = self.cache.get_mut(&key) {
355 self.hits.fetch_add(1, Ordering::Relaxed);
356 entry.last_access = access_time;
357 return entry.value;
358 }
359
360 self.misses.fetch_add(1, Ordering::Relaxed);
361 let result = compute();
362
363 if self.cache.len() >= self.max_size {
364 self.evict_lru();
365 }
366
367 self.cache.insert(
368 key,
369 CachedValue {
370 value: result,
371 last_access: access_time,
372 },
373 );
374 result
375 }
376
377 pub fn get_or_try_compute<F, E>(
379 &self,
380 expr_hash: u64,
381 params_hash: u64,
382 compute: F,
383 ) -> Result<Complex64, E>
384 where
385 F: FnOnce() -> Result<Complex64, E>,
386 {
387 let key = (expr_hash, params_hash);
388 let access_time = self.access_counter.fetch_add(1, Ordering::Relaxed);
389
390 if let Some(mut entry) = self.cache.get_mut(&key) {
391 self.hits.fetch_add(1, Ordering::Relaxed);
392 entry.last_access = access_time;
393 return Ok(entry.value);
394 }
395
396 self.misses.fetch_add(1, Ordering::Relaxed);
397 let result = compute()?;
398
399 if self.cache.len() >= self.max_size {
400 self.evict_lru();
401 }
402
403 self.cache.insert(
404 key,
405 CachedValue {
406 value: result,
407 last_access: access_time,
408 },
409 );
410 Ok(result)
411 }
412
413 fn evict_lru(&self) {
415 let evict_count = self.max_size / 10;
416 if evict_count == 0 {
417 return;
418 }
419
420 let mut entries: Vec<_> = self
421 .cache
422 .iter()
423 .map(|e| (*e.key(), e.value().last_access))
424 .collect();
425 entries.sort_by_key(|(_, access)| *access);
426
427 for (key, _) in entries.into_iter().take(evict_count) {
428 self.cache.remove(&key);
429 }
430 }
431
432 pub fn clear(&self) {
434 self.cache.clear();
435 self.hits.store(0, Ordering::Relaxed);
436 self.misses.store(0, Ordering::Relaxed);
437 }
438
439 #[must_use]
441 pub fn len(&self) -> usize {
442 self.cache.len()
443 }
444
445 #[must_use]
447 pub fn is_empty(&self) -> bool {
448 self.cache.is_empty()
449 }
450
451 #[must_use]
453 pub fn stats(&self) -> CacheStats {
454 let hits = self.hits.load(Ordering::Relaxed);
455 let misses = self.misses.load(Ordering::Relaxed);
456 CacheStats {
457 size: self.cache.len(),
458 max_size: self.max_size,
459 hits,
460 misses,
461 hit_rate: if hits + misses > 0 {
462 hits as f64 / (hits + misses) as f64
463 } else {
464 0.0
465 },
466 }
467 }
468}
469
470impl Default for ComplexEvalCache {
471 fn default() -> Self {
472 Self::new()
473 }
474}
475
/// Concurrent cache mapping an expression's [`compute_hash`] value to its
/// simplified form.
pub struct SimplificationCache {
    /// Simplified expressions keyed by the hash of the input expression.
    cache: DashMap<u64, CachedValue<Expression>, std::hash::BuildHasherDefault<FxHasher>>,
    /// Capacity threshold checked before inserting.
    max_size: usize,
    /// Logical clock; each access takes the next value and stamps it on the entry.
    access_counter: AtomicU64,
    /// Lookups answered from the cache.
    hits: AtomicUsize,
    /// Lookups that had to call the simplify closure.
    misses: AtomicUsize,
}
491
492impl SimplificationCache {
493 #[must_use]
495 pub fn new() -> Self {
496 Self::with_capacity(DEFAULT_MAX_CACHE_SIZE)
497 }
498
499 #[must_use]
501 pub fn with_capacity(max_size: usize) -> Self {
502 Self {
503 cache: DashMap::with_hasher(std::hash::BuildHasherDefault::<FxHasher>::default()),
504 max_size,
505 access_counter: AtomicU64::new(0),
506 hits: AtomicUsize::new(0),
507 misses: AtomicUsize::new(0),
508 }
509 }
510
511 pub fn get_or_simplify<F>(&self, expr: &Expression, simplify: F) -> Expression
513 where
514 F: FnOnce() -> Expression,
515 {
516 let expr_hash = compute_hash(expr);
517 let access_time = self.access_counter.fetch_add(1, Ordering::Relaxed);
518
519 if let Some(mut entry) = self.cache.get_mut(&expr_hash) {
520 self.hits.fetch_add(1, Ordering::Relaxed);
521 entry.last_access = access_time;
522 return entry.value.clone();
523 }
524
525 self.misses.fetch_add(1, Ordering::Relaxed);
526 let result = simplify();
527
528 if self.cache.len() >= self.max_size {
529 self.evict_lru();
530 }
531
532 self.cache.insert(
533 expr_hash,
534 CachedValue {
535 value: result.clone(),
536 last_access: access_time,
537 },
538 );
539 result
540 }
541
542 fn evict_lru(&self) {
544 let evict_count = self.max_size / 10;
545 if evict_count == 0 {
546 return;
547 }
548
549 let mut entries: Vec<_> = self
550 .cache
551 .iter()
552 .map(|e| (*e.key(), e.value().last_access))
553 .collect();
554 entries.sort_by_key(|(_, access)| *access);
555
556 for (key, _) in entries.into_iter().take(evict_count) {
557 self.cache.remove(&key);
558 }
559 }
560
561 pub fn clear(&self) {
563 self.cache.clear();
564 self.hits.store(0, Ordering::Relaxed);
565 self.misses.store(0, Ordering::Relaxed);
566 }
567
568 #[must_use]
570 pub fn len(&self) -> usize {
571 self.cache.len()
572 }
573
574 #[must_use]
576 pub fn is_empty(&self) -> bool {
577 self.cache.is_empty()
578 }
579
580 #[must_use]
582 pub fn stats(&self) -> CacheStats {
583 let hits = self.hits.load(Ordering::Relaxed);
584 let misses = self.misses.load(Ordering::Relaxed);
585 CacheStats {
586 size: self.cache.len(),
587 max_size: self.max_size,
588 hits,
589 misses,
590 hit_rate: if hits + misses > 0 {
591 hits as f64 / (hits + misses) as f64
592 } else {
593 0.0
594 },
595 }
596 }
597}
598
599impl Default for SimplificationCache {
600 fn default() -> Self {
601 Self::new()
602 }
603}
604
/// Two-level cache of real evaluation results:
/// expression hash → (parameter-set hash → value).
pub struct BatchEvalCache {
    /// Outer map keyed by expression hash; each inner map is keyed by parameter hash.
    cache: DashMap<
        u64,
        DashMap<u64, f64, std::hash::BuildHasherDefault<FxHasher>>,
        std::hash::BuildHasherDefault<FxHasher>,
    >,
    /// Configured limit on the number of expression entries in `cache`.
    max_expressions: usize,
    /// Limit on the number of parameter hashes kept per expression.
    max_params_per_expr: usize,
}
623
624impl BatchEvalCache {
625 #[must_use]
627 pub fn new() -> Self {
628 Self::with_capacity(1000, 1000)
629 }
630
631 #[must_use]
633 pub fn with_capacity(max_expressions: usize, max_params_per_expr: usize) -> Self {
634 Self {
635 cache: DashMap::with_hasher(std::hash::BuildHasherDefault::<FxHasher>::default()),
636 max_expressions,
637 max_params_per_expr,
638 }
639 }
640
641 pub fn get_or_compute_batch<F>(
643 &self,
644 expr_hash: u64,
645 param_hashes: &[u64],
646 compute: F,
647 ) -> Vec<f64>
648 where
649 F: FnOnce(&[usize]) -> Vec<f64>,
650 {
651 let expr_cache = self.cache.entry(expr_hash).or_insert_with(|| {
653 DashMap::with_hasher(std::hash::BuildHasherDefault::<FxHasher>::default())
654 });
655
656 let mut results = vec![0.0; param_hashes.len()];
657 let mut missing_indices = Vec::new();
658
659 for (i, &ph) in param_hashes.iter().enumerate() {
660 if let Some(val) = expr_cache.get(&ph) {
661 results[i] = *val;
662 } else {
663 missing_indices.push(i);
664 }
665 }
666
667 if !missing_indices.is_empty() {
669 let computed = compute(&missing_indices);
670
671 for (j, &i) in missing_indices.iter().enumerate() {
672 results[i] = computed[j];
673 let ph = param_hashes[i];
674
675 if expr_cache.len() >= self.max_params_per_expr {
677 let first_key = expr_cache.iter().next().map(|e| *e.key());
680 if let Some(key) = first_key {
681 expr_cache.remove(&key);
682 }
683 }
684
685 expr_cache.insert(ph, computed[j]);
686 }
687 }
688
689 results
690 }
691
692 pub fn clear(&self) {
694 self.cache.clear();
695 }
696
697 #[must_use]
699 pub fn len(&self) -> usize {
700 self.cache.len()
701 }
702
703 #[must_use]
705 pub fn is_empty(&self) -> bool {
706 self.cache.is_empty()
707 }
708
709 #[must_use]
711 pub fn total_params_cached(&self) -> usize {
712 self.cache.iter().map(|e| e.value().len()).sum()
713 }
714}
715
716impl Default for BatchEvalCache {
717 fn default() -> Self {
718 Self::new()
719 }
720}
721
// Every field ends in `_cache`, which `clippy::struct_field_names` flags; the
// suffix is kept because it names what each field is.
#[allow(clippy::struct_field_names)]
/// Facade bundling a real-evaluation cache, a complex-evaluation cache and a
/// simplification cache behind expression-level methods.
pub struct CachedEvaluator {
    /// Backs `eval`.
    eval_cache: EvalCache,
    /// Backs `eval_complex`.
    complex_cache: ComplexEvalCache,
    /// Backs `simplify`.
    simplification_cache: SimplificationCache,
}
736
737impl CachedEvaluator {
738 #[must_use]
740 pub fn new() -> Self {
741 Self {
742 eval_cache: EvalCache::new(),
743 complex_cache: ComplexEvalCache::new(),
744 simplification_cache: SimplificationCache::new(),
745 }
746 }
747
748 #[must_use]
750 pub fn with_capacity(eval_size: usize, complex_size: usize, simplify_size: usize) -> Self {
751 Self {
752 eval_cache: EvalCache::with_capacity(eval_size),
753 complex_cache: ComplexEvalCache::with_capacity(complex_size),
754 simplification_cache: SimplificationCache::with_capacity(simplify_size),
755 }
756 }
757
758 pub fn eval(&self, expr: &Expression, values: &HashMap<String, f64>) -> SymEngineResult<f64> {
760 let expr_hash = compute_hash(expr);
761 let params_hash = hash_params(values);
762
763 self.eval_cache
765 .get_or_try_compute(expr_hash, params_hash, || expr.eval(values))
766 }
767
768 pub fn eval_complex(
770 &self,
771 expr: &Expression,
772 values: &HashMap<String, f64>,
773 ) -> SymEngineResult<Complex64> {
774 let expr_hash = compute_hash(expr);
775 let params_hash = hash_params(values);
776
777 self.complex_cache
778 .get_or_try_compute(expr_hash, params_hash, || expr.eval_complex(values))
779 }
780
781 pub fn simplify(&self, expr: &Expression) -> Expression {
783 self.simplification_cache
784 .get_or_simplify(expr, || expr.simplify())
785 }
786
787 pub fn clear(&self) {
789 self.eval_cache.clear();
790 self.complex_cache.clear();
791 self.simplification_cache.clear();
792 }
793
794 #[must_use]
796 pub fn stats(&self) -> CombinedCacheStats {
797 CombinedCacheStats {
798 eval: self.eval_cache.stats(),
799 complex: self.complex_cache.stats(),
800 simplification: self.simplification_cache.stats(),
801 }
802 }
803}
804
805impl Default for CachedEvaluator {
806 fn default() -> Self {
807 Self::new()
808 }
809}
810
/// Statistics for the three caches owned by a [`CachedEvaluator`].
#[derive(Debug, Clone)]
pub struct CombinedCacheStats {
    /// Stats of the real-evaluation cache.
    pub eval: CacheStats,
    /// Stats of the complex-evaluation cache.
    pub complex: CacheStats,
    /// Stats of the simplification cache.
    pub simplification: CacheStats,
}
821
822impl CombinedCacheStats {
823 #[must_use]
825 pub const fn total_size(&self) -> usize {
826 self.eval.size + self.complex.size + self.simplification.size
827 }
828
829 #[must_use]
831 pub const fn total_hits(&self) -> usize {
832 self.eval.hits + self.complex.hits + self.simplification.hits
833 }
834
835 #[must_use]
837 pub const fn total_misses(&self) -> usize {
838 self.eval.misses + self.complex.misses + self.simplification.misses
839 }
840
841 #[must_use]
843 pub fn overall_hit_rate(&self) -> f64 {
844 let total = self.total_hits() + self.total_misses();
845 if total > 0 {
846 self.total_hits() as f64 / total as f64
847 } else {
848 0.0
849 }
850 }
851}
852
853pub fn compute_hash(expr: &Expression) -> u64 {
859 use std::hash::{Hash, Hasher};
860 let mut hasher = FxHasher::default();
861 expr.to_string().hash(&mut hasher);
862 hasher.finish()
863}
864
865pub fn hash_params(params: &HashMap<String, f64>) -> u64 {
867 use std::hash::{Hash, Hasher};
868 let mut hasher = FxHasher::default();
869
870 let mut keys: Vec<_> = params.keys().collect();
872 keys.sort();
873
874 for key in keys {
875 key.hash(&mut hasher);
876 if let Some(value) = params.get(key) {
877 value.to_bits().hash(&mut hasher);
878 }
879 }
880
881 hasher.finish()
882}
883
884pub fn hash_complex_params(params: &HashMap<String, Complex64>) -> u64 {
886 use std::hash::{Hash, Hasher};
887 let mut hasher = FxHasher::default();
888
889 let mut keys: Vec<_> = params.keys().collect();
890 keys.sort();
891
892 for key in keys {
893 key.hash(&mut hasher);
894 if let Some(value) = params.get(key) {
895 value.re.to_bits().hash(&mut hasher);
896 value.im.to_bits().hash(&mut hasher);
897 }
898 }
899
900 hasher.finish()
901}
902
903pub fn hash_param_array(params: &[f64]) -> u64 {
905 use std::hash::{Hash, Hasher};
906 let mut hasher = FxHasher::default();
907
908 for value in params {
909 value.to_bits().hash(&mut hasher);
910 }
911
912 hasher.finish()
913}
914
#[cfg(test)]
#[allow(clippy::approx_constant)]
mod tests {
    use super::*;

    #[test]
    fn test_eval_cache() {
        let cache = EvalCache::new();

        let result1 = cache.get_or_compute(1, 1, || 42.0);
        assert!((result1 - 42.0).abs() < 1e-10);

        // Same key: must hit the cache and ignore the new closure.
        let result2 = cache.get_or_compute(1, 1, || 100.0);
        assert!((result2 - 42.0).abs() < 1e-10);

        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn test_eval_cache_stats() {
        let cache = EvalCache::new();

        cache.get_or_compute(1, 1, || 42.0);
        cache.get_or_compute(1, 1, || 42.0);

        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert!((stats.hit_rate - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_eval_cache_lru_eviction() {
        let cache = EvalCache::with_capacity(10);

        for i in 0..15u64 {
            cache.get_or_compute(i, 0, || i as f64);
        }

        assert!(cache.len() <= 10);
    }

    #[test]
    fn test_complex_eval_cache() {
        let cache = ComplexEvalCache::new();

        let result1 = cache.get_or_compute(1, 1, || Complex64::new(3.0, 4.0));
        assert!((result1.re - 3.0).abs() < 1e-10);
        assert!((result1.im - 4.0).abs() < 1e-10);

        // Same key: the cached value wins over the new closure.
        let result2 = cache.get_or_compute(1, 1, || Complex64::new(100.0, 200.0));
        assert!((result2.re - 3.0).abs() < 1e-10);
        assert!((result2.im - 4.0).abs() < 1e-10);
    }

    #[test]
    fn test_complex_eval_cache_try_compute() {
        let cache = ComplexEvalCache::new();

        let result: Result<_, &str> =
            cache.get_or_try_compute(1, 1, || Ok(Complex64::new(1.0, 2.0)));
        assert!(result.is_ok());

        let stats = cache.stats();
        assert_eq!(stats.misses, 1);

        // A hit must not invoke the (failing) closure.
        let result2: Result<_, &str> =
            cache.get_or_try_compute(1, 1, || Err("should not be called"));
        assert!(result2.is_ok());

        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
    }

    #[test]
    fn test_simplification_cache() {
        let cache = SimplificationCache::new();

        let expr = Expression::symbol("x") + Expression::symbol("x");
        let simplified = cache.get_or_simplify(&expr, || {
            Expression::int(2) * Expression::symbol("x")
        });

        assert_eq!(cache.len(), 1);

        let simplified2 = cache.get_or_simplify(&expr, || {
            Expression::symbol("should_not_appear")
        });

        assert_eq!(simplified.to_string(), simplified2.to_string());

        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
    }

    #[test]
    fn test_batch_eval_cache() {
        let cache = BatchEvalCache::new();

        let expr_hash = 12345u64;
        let param_hashes = vec![1, 2, 3, 4, 5];

        let mut compute_count = 0;
        let results = cache.get_or_compute_batch(expr_hash, &param_hashes, |missing| {
            compute_count = missing.len();
            missing.iter().map(|&i| i as f64 * 10.0).collect()
        });

        assert_eq!(compute_count, 5);
        assert!((results[0] - 0.0).abs() < 1e-10);
        assert!((results[1] - 10.0).abs() < 1e-10);

        // Everything is cached now, so `compute` must not be invoked.
        let mut compute_count2 = 0;
        let results2 = cache.get_or_compute_batch(expr_hash, &param_hashes, |missing| {
            compute_count2 = missing.len();
            missing.iter().map(|&i| i as f64 * 100.0).collect()
        });

        assert_eq!(compute_count2, 0);
        assert!((results2[0] - 0.0).abs() < 1e-10);
        assert!((results2[1] - 10.0).abs() < 1e-10);
    }

    #[test]
    fn test_batch_eval_cache_partial_hit() {
        let cache = BatchEvalCache::new();

        let expr_hash = 12345u64;

        cache.get_or_compute_batch(expr_hash, &[1, 2, 3], |missing| {
            missing.iter().map(|&i| i as f64).collect()
        });

        // Hashes 4 and 5 sit at indices 2 and 3 and are the only misses.
        let mut computed_indices = Vec::new();
        cache.get_or_compute_batch(expr_hash, &[2, 3, 4, 5], |missing| {
            computed_indices = missing.to_vec();
            missing.iter().map(|&i| i as f64).collect()
        });

        assert_eq!(computed_indices, vec![2, 3]);
    }

    #[test]
    fn test_cached_evaluator() {
        let evaluator = CachedEvaluator::new();

        let expr = Expression::symbol("x");
        let mut values = HashMap::new();
        values.insert("x".to_string(), 5.0);

        let result1 = evaluator.eval(&expr, &values).expect("should eval");
        assert!((result1 - 5.0).abs() < 1e-10);

        let result2 = evaluator.eval(&expr, &values).expect("should eval");
        assert!((result2 - 5.0).abs() < 1e-10);

        let stats = evaluator.stats();
        assert_eq!(stats.eval.misses, 1);
        assert_eq!(stats.eval.hits, 1);
    }

    #[test]
    fn test_cached_evaluator_complex() {
        let evaluator = CachedEvaluator::new();

        let expr = Expression::int(1) + Expression::symbol("I");
        let values = HashMap::new();

        let result = evaluator.eval_complex(&expr, &values).expect("should eval");
        assert!((result.re - 1.0).abs() < 1e-10);
        assert!((result.im - 1.0).abs() < 1e-10);

        let stats = evaluator.stats();
        assert_eq!(stats.complex.misses, 1);
    }

    #[test]
    fn test_cached_evaluator_simplify() {
        let evaluator = CachedEvaluator::new();

        let expr = Expression::symbol("x") + Expression::int(0);
        let simplified = evaluator.simplify(&expr);

        assert!(simplified.is_symbol() || simplified.to_string().contains('x'));

        let simplified2 = evaluator.simplify(&expr);
        assert_eq!(simplified.to_string(), simplified2.to_string());

        let stats = evaluator.stats();
        assert_eq!(stats.simplification.misses, 1);
        assert_eq!(stats.simplification.hits, 1);
    }

    #[test]
    fn test_combined_cache_stats() {
        let evaluator = CachedEvaluator::new();

        let expr = Expression::symbol("x");
        let mut values = HashMap::new();
        values.insert("x".to_string(), 1.0);

        for _ in 0..3 {
            let _ = evaluator.eval(&expr, &values);
        }

        let stats = evaluator.stats();
        assert_eq!(stats.total_size(), 1);
        assert_eq!(stats.total_hits(), 2);
        assert_eq!(stats.total_misses(), 1);
        assert!((stats.overall_hit_rate() - 2.0 / 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_hash_params() {
        let mut params1 = HashMap::new();
        params1.insert("x".to_string(), 1.0);
        params1.insert("y".to_string(), 2.0);

        let mut params2 = HashMap::new();
        params2.insert("y".to_string(), 2.0);
        params2.insert("x".to_string(), 1.0);

        // Insertion order must not affect the hash.
        assert_eq!(hash_params(&params1), hash_params(&params2));
    }

    #[test]
    fn test_hash_complex_params() {
        let mut params1 = HashMap::new();
        params1.insert("a".to_string(), Complex64::new(1.0, 2.0));
        params1.insert("b".to_string(), Complex64::new(3.0, 4.0));

        let mut params2 = HashMap::new();
        params2.insert("b".to_string(), Complex64::new(3.0, 4.0));
        params2.insert("a".to_string(), Complex64::new(1.0, 2.0));

        assert_eq!(hash_complex_params(&params1), hash_complex_params(&params2));
    }

    #[test]
    fn test_hash_param_array() {
        let params1 = [1.0, 2.0, 3.0];
        let params2 = [1.0, 2.0, 3.0];
        let params3 = [1.0, 2.0, 4.0];

        assert_eq!(hash_param_array(&params1), hash_param_array(&params2));
        assert_ne!(hash_param_array(&params1), hash_param_array(&params3));
    }

    #[test]
    fn test_expression_cache() {
        let cache = ExpressionCache::new();

        let expr1 = Expression::symbol("x");
        let arc1 = cache.get_or_insert(expr1.clone());
        let arc2 = cache.get_or_insert(expr1);

        assert!(Arc::ptr_eq(&arc1, &arc2));
    }

    #[test]
    fn test_cache_clear() {
        let cache = EvalCache::new();
        cache.get_or_compute(1, 1, || 42.0);
        assert_eq!(cache.len(), 1);

        cache.clear();
        assert!(cache.is_empty());

        let stats = cache.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
    }
}