1use crate::bspline::{BSpline, ExtrapolateMode};
43use crate::error::InterpolateResult;
44use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
45use scirs2_core::numeric::{Float, FromPrimitive, Zero};
46use std::collections::HashMap;
47use std::fmt::{Debug, Display};
48use std::hash::{Hash, Hasher};
49use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, RemAssign, Sub, SubAssign};
50
/// Strategy used to pick which cached entries are dropped when a cache is full.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum EvictionPolicy {
    /// Ranks entries by the logical time of their last access.
    LRU,
    /// Ranks entries by how many times they have been accessed.
    LFU,
    /// Ranks entries by the logical time they were inserted.
    FIFO,
    /// Ranks entries by a pseudo-random value derived from their insertion time.
    Random,
    /// Ranks entries by a score combining access frequency, recency and size.
    Adaptive,
}
67
68#[derive(Debug, Clone)]
70pub struct CacheConfig {
71 pub max_basis_cache_size: usize,
73 pub max_matrix_cache_size: usize,
75 pub max_distance_cache_size: usize,
77 pub tolerance: f64,
79 pub track_stats: bool,
81 pub eviction_policy: EvictionPolicy,
83 pub memory_limit_mb: usize,
85 pub adaptive_sizing: bool,
87}
88
89impl Default for CacheConfig {
90 fn default() -> Self {
91 Self {
92 max_basis_cache_size: 1024,
93 max_matrix_cache_size: 64,
94 max_distance_cache_size: 256,
95 tolerance: 1e-12,
96 track_stats: false,
97 eviction_policy: EvictionPolicy::LRU,
98 memory_limit_mb: 0, adaptive_sizing: true,
100 }
101 }
102}
103
/// Counters describing how a cache has been used.
#[derive(Clone, Debug)]
pub struct CacheStats {
    /// Number of lookups answered from the cache.
    pub hits: usize,
    /// Number of lookups that required computing the value.
    pub misses: usize,
    /// Number of entries removed to make room.
    pub evictions: usize,
    /// Most recently recorded memory footprint, in bytes.
    pub memory_usage_bytes: usize,
    /// Exponential moving average of reported access counts.
    pub avg_access_frequency: f64,
    /// Number of capacity-driven eviction rounds.
    pub resize_count: usize,
    /// Largest memory footprint recorded so far, in bytes.
    pub peak_memory_bytes: usize,
    /// Instant of creation or of the last `reset`.
    pub last_cleanup_time: std::time::Instant,
}
124
125impl Default for CacheStats {
126 fn default() -> Self {
127 Self {
128 hits: 0,
129 misses: 0,
130 evictions: 0,
131 memory_usage_bytes: 0,
132 avg_access_frequency: 0.0,
133 resize_count: 0,
134 peak_memory_bytes: 0,
135 last_cleanup_time: std::time::Instant::now(),
136 }
137 }
138}
139
140impl CacheStats {
141 pub fn hit_ratio(&self) -> f64 {
143 if self.hits + self.misses == 0 {
144 0.0
145 } else {
146 self.hits as f64 / (self.hits + self.misses) as f64
147 }
148 }
149
150 pub fn efficiency_score(&self) -> f64 {
152 let hit_ratio = self.hit_ratio();
153 let memory_factor = if self.peak_memory_bytes > 0 {
154 1.0 - (self.memory_usage_bytes as f64 / self.peak_memory_bytes as f64)
155 } else {
156 1.0
157 };
158 (hit_ratio + memory_factor) / 2.0
159 }
160
161 pub fn memory_usage_mb(&self) -> f64 {
163 self.memory_usage_bytes as f64 / (1024.0 * 1024.0)
164 }
165
166 pub fn peak_memory_mb(&self) -> f64 {
168 self.peak_memory_bytes as f64 / (1024.0 * 1024.0)
169 }
170
171 pub fn update_memory_usage(&mut self, currentbytes: usize) {
173 self.memory_usage_bytes = currentbytes;
174 if currentbytes > self.peak_memory_bytes {
175 self.peak_memory_bytes = currentbytes;
176 }
177 }
178
179 pub fn update_access_frequency(&mut self, new_accesscount: usize) {
181 let alpha = 0.1; let new_freq = new_accesscount as f64;
183 self.avg_access_frequency = alpha * new_freq + (1.0 - alpha) * self.avg_access_frequency;
184 }
185
186 pub fn needs_cleanup(&self, thresholdsecs: u64) -> bool {
188 self.last_cleanup_time.elapsed().as_secs() >= thresholdsecs
189 }
190
191 pub fn reset(&mut self) {
193 self.hits = 0;
194 self.misses = 0;
195 self.evictions = 0;
196 self.memory_usage_bytes = 0;
197 self.avg_access_frequency = 0.0;
198 self.resize_count = 0;
199 self.peak_memory_bytes = 0;
200 self.last_cleanup_time = std::time::Instant::now();
201 }
202}
203
204#[derive(Debug, Clone)]
206struct FloatKey<F: Float> {
207 value: F,
208 tolerance: F,
209}
210
211impl<F: Float> FloatKey<F> {
212 #[allow(dead_code)]
213 fn new(value: F, tolerance: F) -> Self {
214 Self { value, tolerance }
215 }
216}
217
218impl<F: crate::traits::InterpolationFloat> PartialEq for FloatKey<F> {
219 fn eq(&self, other: &Self) -> bool {
220 (self.value - other.value).abs() <= self.tolerance
221 }
222}
223
224impl<F: crate::traits::InterpolationFloat> Eq for FloatKey<F> {}
225
226impl<F: crate::traits::InterpolationFloat> Hash for FloatKey<F> {
227 fn hash<H: Hasher>(&self, state: &mut H) {
228 let quantized = (self.value / self.tolerance).round() * self.tolerance;
230 let bits = quantized.to_f64().unwrap_or(0.0).to_bits();
232 bits.hash(state);
233 }
234}
235
236#[derive(Debug)]
238pub struct BSplineCache<F: Float> {
239 basis_cache: HashMap<(FloatKey<F>, usize, usize), CacheEntry<F>>,
241 span_cache: HashMap<FloatKey<F>, CacheEntry<usize>>,
243 config: CacheConfig,
245 stats: CacheStats,
247 #[allow(dead_code)]
249 access_counter: usize,
250}
251
/// Public, already-quantised key for a basis-function cache lookup.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct BasisCacheKey {
    /// Bit pattern (`f64::to_bits`) of the quantised evaluation point.
    pub x_quantized: u64,
    /// Index of the basis function.
    pub index: usize,
    /// Degree of the basis function.
    pub degree: usize,
}
262
/// Bookkeeping wrapper around a cached value.
///
/// Times are logical ticks taken from the owning cache's access counter,
/// not wall-clock time.
#[allow(dead_code)]
#[derive(Debug, Clone)]
struct CacheEntry<T> {
    /// The cached payload.
    value: T,
    /// Tick of the most recent access.
    last_access: usize,
    /// Number of accesses, including the insertion.
    access_count: usize,
    /// Tick at which the entry was inserted.
    insertion_time: usize,
    /// Rough size estimate in bytes.
    memory_size: usize,
}
282
283#[allow(dead_code)]
284impl<T> CacheEntry<T> {
285 fn new(_value: T, insertiontime: usize) -> Self {
287 let memory_size = std::mem::size_of::<T>() + std::mem::size_of::<Self>();
288 Self {
289 value: _value,
290 last_access: insertiontime,
291 access_count: 1,
292 insertion_time: insertiontime,
293 memory_size,
294 }
295 }
296
297 fn update_access(&mut self, currenttime: usize) {
299 self.last_access = currenttime;
300 self.access_count += 1;
301 }
302
303 fn eviction_priority(&self, policy: EvictionPolicy, currenttime: usize) -> f64 {
305 match policy {
306 EvictionPolicy::LRU => -(self.last_access as f64),
307 EvictionPolicy::LFU => -(self.access_count as f64),
308 EvictionPolicy::FIFO => -(self.insertion_time as f64),
309 EvictionPolicy::Random => {
310 let x = (self.insertion_time * 1103515245 + 12345) & 0x7fffffff;
312 x as f64 / 0x7fffffff as f64
313 }
314 EvictionPolicy::Adaptive => {
315 let recency = (currenttime - self.last_access) as f64;
317 let frequency = self.access_count as f64;
318 let memory_factor = self.memory_size as f64;
319
320 -(frequency / (1.0 + recency + memory_factor / 1000.0))
322 }
323 }
324 }
325}
326
327impl<F: crate::traits::InterpolationFloat> Default for BSplineCache<F> {
328 fn default() -> Self {
329 Self::new(CacheConfig::default())
330 }
331}
332
333#[allow(dead_code)]
334impl<F: crate::traits::InterpolationFloat> BSplineCache<F> {
335 pub fn new(config: CacheConfig) -> Self {
337 Self {
338 basis_cache: HashMap::new(),
339 span_cache: HashMap::new(),
340 config,
341 stats: CacheStats::default(),
342 access_counter: 0,
343 }
344 }
345
346 fn get_or_compute_basis<T>(
348 &mut self,
349 x: F,
350 i: usize,
351 k: usize,
352 knots: &[T],
353 computer: impl FnOnce() -> T,
354 ) -> T
355 where
356 T: Float + Copy,
357 {
358 self.access_counter += 1;
359 let tolerance = F::from_f64(self.config.tolerance).expect("Operation failed");
360 let key = (FloatKey::new(x, tolerance), i, k);
361
362 if let Some(cache_entry) = self.basis_cache.get_mut(&key) {
363 if self.config.track_stats {
364 self.stats.hits += 1;
365 }
366
367 cache_entry.update_access(self.access_counter);
369
370 unsafe { std::mem::transmute_copy(&cache_entry.value) }
372 } else {
373 if self.config.track_stats {
374 self.stats.misses += 1;
375 }
376 let computed = computer();
377
378 let cached: F = unsafe { std::mem::transmute_copy(&computed) };
380
381 if self.basis_cache.len() >= self.config.max_basis_cache_size {
383 self.evict_basis_cache();
384 }
385
386 let cache_entry = CacheEntry::new(cached, self.access_counter);
388 self.basis_cache.insert(key, cache_entry);
389
390 if self.config.track_stats {
392 self.update_memory_usage();
393 }
394
395 computed
396 }
397 }
398
399 pub fn get_or_compute_basis_with_key(
401 &mut self,
402 key: BasisCacheKey,
403 computer: impl FnOnce() -> F,
404 ) -> F {
405 self.access_counter += 1;
406
407 let tolerance = F::from_f64(self.config.tolerance).expect("Operation failed");
409 let x = F::from_f64(f64::from_bits(key.x_quantized)).unwrap_or_else(F::zero);
410 let internal_key = (FloatKey::new(x, tolerance), key.index, key.degree);
411
412 if let Some(cache_entry) = self.basis_cache.get_mut(&internal_key) {
413 if self.config.track_stats {
414 self.stats.hits += 1;
415 }
416
417 cache_entry.update_access(self.access_counter);
419 cache_entry.value
420 } else {
421 if self.config.track_stats {
422 self.stats.misses += 1;
423 }
424 let computed = computer();
425
426 if self.basis_cache.len() >= self.config.max_basis_cache_size {
428 self.evict_basis_cache();
429 }
430
431 let cache_entry = CacheEntry::new(computed, self.access_counter);
433 self.basis_cache.insert(internal_key, cache_entry);
434
435 if self.config.track_stats {
437 self.update_memory_usage();
438 }
439
440 computed
441 }
442 }
443
444 fn get_or_compute_span(&mut self, x: F, computer: impl FnOnce() -> usize) -> usize {
446 self.access_counter += 1;
447 let tolerance = F::from_f64(self.config.tolerance).expect("Operation failed");
448 let key = FloatKey::new(x, tolerance);
449
450 if let Some(cache_entry) = self.span_cache.get_mut(&key) {
451 if self.config.track_stats {
452 self.stats.hits += 1;
453 }
454
455 cache_entry.update_access(self.access_counter);
457 cache_entry.value
458 } else {
459 if self.config.track_stats {
460 self.stats.misses += 1;
461 }
462 let computed = computer();
463
464 let cache_entry = CacheEntry::new(computed, self.access_counter);
466 self.span_cache.insert(key, cache_entry);
467
468 if self.config.track_stats {
470 self.update_memory_usage();
471 }
472
473 computed
474 }
475 }
476
477 fn update_memory_usage(&mut self) {
479 if !self.config.track_stats {
480 return;
481 }
482
483 let basis_memory: usize = self
484 .basis_cache
485 .values()
486 .map(|entry| entry.memory_size)
487 .sum();
488 let span_memory: usize = self
489 .span_cache
490 .values()
491 .map(|entry| entry.memory_size)
492 .sum();
493
494 let total_memory = basis_memory + span_memory;
495 self.stats.update_memory_usage(total_memory);
496
497 if self.config.memory_limit_mb > 0 {
499 let limit_bytes = self.config.memory_limit_mb * 1024 * 1024;
500 if total_memory > limit_bytes {
501 self.evict_basis_cache_by_memory();
502 }
503 }
504 }
505
506 fn evict_basis_cache_by_memory(&mut self) {
508 let target_size = self.config.memory_limit_mb * 1024 * 1024 * 3 / 4; let mut current_memory = self.stats.memory_usage_bytes;
510
511 if current_memory <= target_size {
512 return;
513 }
514
515 let mut entries: Vec<_> = self
517 .basis_cache
518 .iter()
519 .map(|(key, entry)| {
520 let priority =
521 entry.eviction_priority(self.config.eviction_policy, self.access_counter);
522 (key.clone(), priority, entry.memory_size)
523 })
524 .collect();
525
526 entries.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
528
529 for (key_, _, memory_size) in entries {
531 if current_memory <= target_size {
532 break;
533 }
534
535 self.basis_cache.remove(&key_);
536 current_memory = current_memory.saturating_sub(memory_size);
537
538 if self.config.track_stats {
539 self.stats.evictions += 1;
540 }
541 }
542
543 self.update_memory_usage();
545 }
546
547 fn evict_basis_cache(&mut self) {
550 let total_entries = self.basis_cache.len();
551 let remove_count = if self.config.adaptive_sizing {
552 let hit_ratio = self.stats.hit_ratio();
554 if hit_ratio > 0.8 {
555 total_entries / 8 } else if hit_ratio > 0.5 {
557 total_entries / 4 } else {
559 total_entries / 2 }
561 } else {
562 total_entries / 4 };
564
565 let mut entries: Vec<_> = self
567 .basis_cache
568 .iter()
569 .map(|(key, entry)| {
570 let priority =
571 entry.eviction_priority(self.config.eviction_policy, self.access_counter);
572 (key.clone(), priority)
573 })
574 .collect();
575
576 entries.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
578
579 for (key, _) in entries.into_iter().take(remove_count) {
581 self.basis_cache.remove(&key);
582 if self.config.track_stats {
583 self.stats.evictions += 1;
584 }
585 }
586
587 if self.config.track_stats {
589 self.stats.resize_count += 1;
590 self.update_memory_usage();
591 }
592 }
593
594 pub fn clear(&mut self) {
596 self.basis_cache.clear();
597 self.span_cache.clear();
598 }
599
600 pub fn stats(&self) -> &CacheStats {
602 &self.stats
603 }
604
605 pub fn reset_stats(&mut self) {
607 self.stats.reset();
608 }
609}
610
611#[derive(Debug)]
613pub struct CachedBSpline<T>
614where
615 T: Float
616 + FromPrimitive
617 + Debug
618 + Display
619 + Add<Output = T>
620 + Sub<Output = T>
621 + Mul<Output = T>
622 + Div<Output = T>
623 + AddAssign
624 + SubAssign
625 + MulAssign
626 + DivAssign
627 + RemAssign
628 + Zero
629 + Copy,
630{
631 spline: BSpline<T>,
633 cache: BSplineCache<T>,
635}
636
637#[allow(dead_code)]
638impl<T> CachedBSpline<T>
639where
640 T: crate::traits::InterpolationFloat,
641{
642 pub fn new(
678 knots: &ArrayView1<T>,
679 coeffs: &ArrayView1<T>,
680 degree: usize,
681 extrapolate: ExtrapolateMode,
682 cache: BSplineCache<T>,
683 ) -> InterpolateResult<Self> {
684 let spline = BSpline::new(knots, coeffs, degree, extrapolate)?;
685
686 Ok(Self { spline, cache })
687 }
688
689 pub fn evaluate_cached(&mut self, x: T) -> InterpolateResult<T> {
732 self.evaluate_with_cache_optimization(x)
736 }
737
738 fn evaluate_with_cache_optimization(&mut self, x: T) -> InterpolateResult<T> {
740 let knots = self.spline.knot_vector();
742 let coeffs = self.spline.coefficients();
743 let degree = self.spline.degree();
744
745 let knots_clone = knots.to_owned();
748 let span = self
749 .cache
750 .get_or_compute_span(x, || Self::find_knot_span(x, &knots_clone, degree));
751
752 let mut result = T::zero();
754 for i in 0..=degree {
755 let basis_index = span.saturating_sub(degree) + i;
756 if basis_index < coeffs.len() {
757 let basis_key = BasisCacheKey {
759 x_quantized: self.quantize_x(x),
760 index: basis_index,
761 degree,
762 };
763
764 let basis_value = self.cache.get_or_compute_basis_with_key(basis_key, || {
765 if basis_index < knots_clone.len() - degree - 1 {
766 Self::compute_basis_function(x, basis_index, degree, &knots_clone)
767 } else {
768 T::zero()
769 }
770 });
771
772 result += coeffs[basis_index] * basis_value;
773 }
774 }
775
776 Ok(result)
777 }
778
779 fn quantize_x(&self, x: T) -> u64 {
781 let x_f64 = x.to_f64().unwrap_or(0.0);
782 let tolerance = self.cache.config.tolerance;
783 let quantized = (x_f64 / tolerance).round() * tolerance;
784 quantized.to_bits()
785 }
786
787 fn compute_basis_functions_at_x(&self, x: T, knots: &Array1<T>, degree: usize) -> Array1<T> {
789 let span = Self::find_knot_span(x, knots, degree);
790 let mut basis = Array1::zeros(degree + 1);
791
792 basis[0] = T::one();
794
795 for j in 1..=degree {
796 let mut saved = T::zero();
797 for r in 0..j {
798 let left_knot_idx = span + 1 - j + r;
799 let right_knot_idx = span + r;
800
801 if left_knot_idx < knots.len() && right_knot_idx + 1 < knots.len() {
802 let alpha = if knots[right_knot_idx + 1] != knots[left_knot_idx] {
803 (x - knots[left_knot_idx])
804 / (knots[right_knot_idx + 1] - knots[left_knot_idx])
805 } else {
806 T::zero()
807 };
808
809 let temp = basis[r];
810 basis[r] = saved + (T::one() - alpha) * temp;
811 saved = alpha * temp;
812 }
813 }
814 basis[j] = saved;
815 }
816
817 basis
818 }
819
820 pub fn evaluate_batch_optimized(
822 &mut self,
823 x_vals: &ArrayView1<T>,
824 ) -> InterpolateResult<Array1<T>> {
825 let mut results = Array1::zeros(x_vals.len());
826
827 for &x in x_vals.iter().take(x_vals.len().min(10)) {
829 let _ = self.evaluate_cached(x)?;
830 }
831
832 for (i, &x) in x_vals.iter().enumerate() {
834 results[i] = self.evaluate_cached(x)?;
835 }
836
837 Ok(results)
838 }
839
840 pub fn refresh_cache(&mut self) {
842 self.cache.clear();
843 }
844
845 pub fn optimize_cache_settings(&mut self) {
847 let (hit_ratio, total_requests) = {
848 let stats = self.cache.stats();
849 (stats.hit_ratio(), stats.hits + stats.misses)
850 };
851
852 if hit_ratio < 0.3 && self.cache.config.max_basis_cache_size > 64 {
854 self.cache.config.max_basis_cache_size /= 2;
856 } else if hit_ratio > 0.8 && self.cache.config.max_basis_cache_size < 4096 {
857 self.cache.config.max_basis_cache_size *= 2;
859 }
860
861 if total_requests > 1000 {
863 self.cache.config.adaptive_sizing = true;
864 }
865 }
866
867 fn find_knot_span(x: T, knots: &Array1<T>, degree: usize) -> usize {
869 let n = knots.len() - degree - 1;
870
871 if x >= knots[n] {
872 return n - 1;
873 }
874 if x <= knots[degree] {
875 return degree;
876 }
877
878 let mut low = degree;
880 let mut high = n;
881 let mut mid = (low + high) / 2;
882
883 while x < knots[mid] || x >= knots[mid + 1] {
884 if x < knots[mid] {
885 high = mid;
886 } else {
887 low = mid;
888 }
889 mid = (low + high) / 2;
890 }
891
892 mid
893 }
894
895 #[allow(clippy::only_used_in_recursion)]
897 fn compute_basis_function(x: T, i: usize, degree: usize, knots: &Array1<T>) -> T {
898 if degree == 0 {
900 if i < knots.len() - 1 && x >= knots[i] && x < knots[i + 1] {
901 T::one()
902 } else {
903 T::zero()
904 }
905 } else {
906 let mut left = T::zero();
907 let mut right = T::zero();
908
909 if i < knots.len() - degree - 1 && knots[i + degree] != knots[i] {
911 let basis_left = Self::compute_basis_function(x, i, degree - 1, knots);
912 left = (x - knots[i]) / (knots[i + degree] - knots[i]) * basis_left;
913 }
914
915 if i + 1 < knots.len() - degree - 1 && knots[i + degree + 1] != knots[i + 1] {
917 let basis_right = Self::compute_basis_function(x, i + 1, degree - 1, knots);
918 right = (knots[i + degree + 1] - x) / (knots[i + degree + 1] - knots[i + 1])
919 * basis_right;
920 }
921
922 left + right
923 }
924 }
925
926 pub fn evaluate_standard(&self, x: T) -> InterpolateResult<T> {
928 self.spline.evaluate(x)
929 }
930
931 pub fn evaluate_array_cached(
968 &mut self,
969 x_vals: &ArrayView1<T>,
970 ) -> InterpolateResult<Array1<T>> {
971 let mut results = Array1::zeros(x_vals.len());
972 for (i, &x) in x_vals.iter().enumerate() {
973 results[i] = self.evaluate_cached(x)?;
974 }
975 Ok(results)
976 }
977
978 pub fn cache_stats(&self) -> &CacheStats {
980 self.cache.stats()
981 }
982
983 pub fn reset_cache_stats(&mut self) {
985 self.cache.reset_stats();
986 }
987
988 pub fn clear_cache(&mut self) {
990 self.cache.clear();
991 }
992
993 pub fn spline(&self) -> &BSpline<T> {
995 &self.spline
996 }
997}
998
999#[derive(Debug)]
1001pub struct DistanceMatrixCache<F: Float> {
1002 matrix_cache: HashMap<u64, Array2<F>>,
1004 config: CacheConfig,
1006 stats: CacheStats,
1008}
1009
1010impl<F: crate::traits::InterpolationFloat> DistanceMatrixCache<F> {
1011 pub fn new(config: CacheConfig) -> Self {
1013 Self {
1014 matrix_cache: HashMap::new(),
1015 config,
1016 stats: CacheStats::default(),
1017 }
1018 }
1019
1020 pub fn get_or_compute_distance_matrix<T>(
1022 &mut self,
1023 points: &Array2<T>,
1024 computer: impl FnOnce(&Array2<T>) -> Array2<F>,
1025 ) -> Array2<F>
1026 where
1027 T: Float,
1028 {
1029 let key = self.hash_points_safe(points);
1031
1032 if let Some(cached_matrix) = self.matrix_cache.get(&key) {
1033 if self.config.track_stats {
1034 self.stats.hits += 1;
1035 }
1036 cached_matrix.clone()
1037 } else {
1038 if self.config.track_stats {
1039 self.stats.misses += 1;
1040 }
1041
1042 let computed = computer(points);
1043
1044 if self.matrix_cache.len() >= self.config.max_distance_cache_size {
1046 self.evict_matrix_cache();
1047 }
1048
1049 self.matrix_cache.insert(key, computed.clone());
1050 computed
1051 }
1052 }
1053
1054 fn hash_points_safe<T: Float>(&self, points: &Array2<T>) -> u64 {
1056 use std::collections::hash_map::DefaultHasher;
1057 let mut hasher = DefaultHasher::new();
1058
1059 points.shape()[0].hash(&mut hasher);
1061 points.shape()[1].hash(&mut hasher);
1062
1063 let hash_stride = if points.len() > 1000 {
1065 points.len() / 100
1066 } else {
1067 1
1068 };
1069
1070 for (i, &val) in points.iter().enumerate() {
1071 if i % hash_stride == 0 {
1072 let val_f64 = val.to_f64().unwrap_or(0.0);
1074 let quantized = (val_f64 / self.config.tolerance).round() * self.config.tolerance;
1076 let bits = quantized.to_bits();
1077 bits.hash(&mut hasher);
1078 }
1079 }
1080
1081 hasher.finish()
1082 }
1083
1084 fn evict_matrix_cache(&mut self) {
1086 let remove_count = self.matrix_cache.len() / 4; let keys_to_remove: Vec<_> = self
1088 .matrix_cache
1089 .keys()
1090 .take(remove_count)
1091 .cloned()
1092 .collect();
1093
1094 for key in keys_to_remove {
1095 self.matrix_cache.remove(&key);
1096 if self.config.track_stats {
1097 self.stats.evictions += 1;
1098 }
1099 }
1100 }
1101
1102 pub fn clear(&mut self) {
1104 self.matrix_cache.clear();
1105 }
1106
1107 pub fn stats(&self) -> &CacheStats {
1109 &self.stats
1110 }
1111
1112 pub fn reset_stats(&mut self) {
1114 self.stats.reset();
1115 }
1116}
1117
1118#[allow(dead_code)]
1120pub fn make_cached_bspline<T>(
1121 knots: &ArrayView1<T>,
1122 coeffs: &ArrayView1<T>,
1123 degree: usize,
1124 extrapolate: ExtrapolateMode,
1125) -> InterpolateResult<CachedBSpline<T>>
1126where
1127 T: crate::traits::InterpolationFloat,
1128{
1129 let cache = BSplineCache::default();
1130 CachedBSpline::new(knots, coeffs, degree, extrapolate, cache)
1131}
1132
1133#[allow(dead_code)]
1135pub fn make_cached_bspline_with_config<T>(
1136 knots: &ArrayView1<T>,
1137 coeffs: &ArrayView1<T>,
1138 degree: usize,
1139 extrapolate: ExtrapolateMode,
1140 cache_config: CacheConfig,
1141) -> InterpolateResult<CachedBSpline<T>>
1142where
1143 T: crate::traits::InterpolationFloat,
1144{
1145 let cache = BSplineCache::new(cache_config);
1146 CachedBSpline::new(knots, coeffs, degree, extrapolate, cache)
1147}
1148
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// Reference computer: pairwise Euclidean distances between the rows of `pts`.
    fn pairwise_euclidean(pts: &Array2<f64>) -> Array2<f64> {
        let n = pts.nrows();
        let mut distances = Array2::zeros((n, n));
        for i in 0..n {
            for j in 0..n {
                let diff = &pts.slice(scirs2_core::ndarray::s![i, ..])
                    - &pts.slice(scirs2_core::ndarray::s![j, ..]);
                distances[[i, j]] = diff.iter().map(|&d| d * d).sum::<f64>().sqrt();
            }
        }
        distances
    }

    #[test]
    fn test_cached_bspline_evaluation() {
        // Clamped quadratic knot vector over [0, 3].
        let knots = array![0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 3.0, 3.0];
        let coeffs = array![1.0, 2.0, 3.0, 4.0, 5.0];

        let mut spline = make_cached_bspline(
            &knots.view(),
            &coeffs.view(),
            2,
            ExtrapolateMode::Extrapolate,
        )
        .expect("Operation failed");

        let probes = array![0.5, 1.0, 1.5, 2.0, 2.5];
        for &x in probes.iter() {
            let via_cache = spline.evaluate_cached(x).expect("Operation failed");
            let direct = spline.evaluate_standard(x).expect("Operation failed");
            assert_relative_eq!(via_cache, direct, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_cache_statistics() {
        let knots = array![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let coeffs = array![1.0, 2.0, 3.0, 4.0];
        let config = CacheConfig {
            track_stats: true,
            ..Default::default()
        };

        let mut spline = make_cached_bspline_with_config(
            &knots.view(),
            &coeffs.view(),
            2,
            ExtrapolateMode::Extrapolate,
            config,
        )
        .expect("Operation failed");

        // The first evaluation has to populate the caches.
        let _ = spline.evaluate_cached(1.5).expect("Operation failed");
        assert!(spline.cache_stats().misses > 0);

        // Evaluating the same point again must be served from the cache.
        let _ = spline.evaluate_cached(1.5).expect("Operation failed");
        assert!(spline.cache_stats().hits > 0);
    }

    #[test]
    fn test_distance_matrix_cache() {
        let mut cache = DistanceMatrixCache::<f64>::new(CacheConfig {
            track_stats: true,
            max_distance_cache_size: 10,
            ..Default::default()
        });

        let triangle = Array2::from_shape_vec((3, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0])
            .expect("Operation failed");

        let first = cache.get_or_compute_distance_matrix(&triangle, pairwise_euclidean);
        assert_eq!(first.shape(), &[3, 3]);
        assert_eq!((cache.stats().misses, cache.stats().hits), (1, 0));

        let second = cache.get_or_compute_distance_matrix(&triangle, |_| {
            panic!("Should not be called on cache hit");
        });
        assert_eq!(first, second);
        assert_eq!((cache.stats().misses, cache.stats().hits), (1, 1));

        let segment =
            Array2::from_shape_vec((2, 2), vec![2.0, 2.0, 3.0, 3.0]).expect("Operation failed");
        let _ = cache.get_or_compute_distance_matrix(&segment, |pts| {
            Array2::zeros((pts.nrows(), pts.nrows()))
        });
        assert_eq!((cache.stats().misses, cache.stats().hits), (2, 1));
    }
}