1use crate::{Nystroem, RBFSampler};
7use scirs2_core::ndarray::Array2;
8use scirs2_linalg::compat::{Norm, SVD};
9use sklears_core::traits::Fit;
10use sklears_core::{
11 error::{Result, SklearsError},
12 traits::Transform,
13};
14use std::time::Instant;
15
/// Rule that decides how the number of components grows from one progressive
/// step to the next.
///
/// The per-variant descriptions reflect how `ProgressiveRBFSampler` applies them.
#[derive(Debug, Clone)]
pub enum ProgressiveStrategy {
    /// Multiply the current component count by two.
    Doubling,
    /// Add the constant `increment` to the current component count.
    FixedIncrement { increment: usize },
    /// Pick the step size from the improvement observed in the previous step.
    AdaptiveIncrement {
        /// Step used while the last improvement is above `improvement_threshold`.
        min_increment: usize,

        /// Upper end of the range from which the larger step (the midpoint
        /// between `min_increment` and `max_increment`) is derived.
        max_increment: usize,

        /// Improvement level separating the small step from the larger one.
        improvement_threshold: f64,
    },
    /// Multiply the current component count by `base`, truncating to an integer.
    Exponential { base: f64 },
    /// Use the initial component count plus successive Fibonacci numbers.
    Fibonacci,
}
37
/// Condition under which the progressive search stops adding components.
#[derive(Debug, Clone)]
pub enum StoppingCriterion {
    /// Stop as soon as the measured quality is at least `quality`.
    TargetQuality { quality: f64 },
    /// Stop when, after the first step, the quality gain drops below `threshold`.
    ImprovementThreshold { threshold: f64 },
    /// Stop after a bounded number of steps.
    MaxIterations { max_iter: usize },
    /// Stop once the component count reaches `max_components`.
    MaxComponents { max_components: usize },
    /// Several limits at once; a limit set to `None` is ignored.
    ///
    /// `ProgressiveRBFSampler` checks them in the order quality, improvement,
    /// iterations, components and reports the first one that triggers.
    Combined {
        /// Optional target quality.
        quality: Option<f64>,
        /// Optional minimum quality gain per step.
        improvement_threshold: Option<f64>,
        /// Optional bound on the number of steps.
        max_iter: Option<usize>,
        /// Optional bound on the component count.
        max_components: Option<usize>,
    },
}
58
/// Score used to judge an approximation; larger values are treated as better
/// by the stopping criteria.
///
/// The descriptions reflect the implementations in `ProgressiveRBFSampler`.
#[derive(Debug, Clone)]
pub enum ProgressiveQualityMetric {
    /// Element-wise inner product of exact and approximate kernel, divided by
    /// the product of their norms.
    KernelAlignment,
    /// `1 / (1 + e)` where `e` is the norm of the kernel difference.
    FrobeniusError,
    /// `1 / (1 + s)` where `s` is the largest singular value of the kernel difference.
    SpectralError,
    /// Exponential of the entropy of the normalised singular values of the
    /// transformed data, divided by the number of feature columns.
    EffectiveRank,
    /// Currently evaluates to the constant `1.0`.
    RelativeImprovement,
    /// Currently evaluated exactly like `KernelAlignment`.
    Custom,
}
76
/// Settings shared by the progressive approximation drivers in this file.
#[derive(Debug, Clone)]
pub struct ProgressiveConfig {
    /// Component count evaluated in the first step.
    pub initial_components: usize,
    /// How the component count grows between steps.
    pub strategy: ProgressiveStrategy,
    /// When the search stops.
    pub stopping_criterion: StoppingCriterion,
    /// Score used to rate each step.
    pub quality_metric: ProgressiveQualityMetric,
    /// Number of independent fits whose scores are averaged per step.
    pub n_trials: usize,
    /// Base seed; trial `i` of a step is seeded with `random_seed + i`.
    /// `None` leaves the underlying samplers unseeded.
    pub random_seed: Option<u64>,
    /// Fraction of the input rows (taken from the end) that
    /// `ProgressiveRBFSampler` holds out for validation.
    pub validation_fraction: f64,
    /// NOTE(review): not read by any code visible in this file — confirm intended use.
    pub store_intermediate: bool,
}
98
99impl Default for ProgressiveConfig {
100 fn default() -> Self {
101 Self {
102 initial_components: 10,
103 strategy: ProgressiveStrategy::Doubling,
104 stopping_criterion: StoppingCriterion::Combined {
105 quality: Some(0.95),
106 improvement_threshold: Some(0.01),
107 max_iter: Some(10),
108 max_components: Some(1000),
109 },
110 quality_metric: ProgressiveQualityMetric::KernelAlignment,
111 n_trials: 3,
112 random_seed: None,
113 validation_fraction: 0.2,
114 store_intermediate: true,
115 }
116 }
117}
118
/// Record of one evaluated component count during a progressive run.
#[derive(Debug, Clone)]
pub struct ProgressiveStep {
    /// Component count evaluated in this step.
    pub n_components: usize,
    /// Quality score averaged over the configured trials.
    pub quality_score: f64,
    /// Score gain over the previous step; for the first step this is the score itself.
    pub improvement: f64,
    /// Wall-clock time spent on this step, in seconds.
    pub time_taken: f64,
    /// Zero-based index of the step.
    pub iteration: usize,
}
134
/// Outcome of a complete progressive approximation run.
#[derive(Debug, Clone)]
pub struct ProgressiveResult {
    /// Component count of the last evaluated step.
    pub final_components: usize,
    /// Quality score of the last evaluated step.
    pub final_quality: f64,
    /// Every evaluated step, in execution order.
    pub steps: Vec<ProgressiveStep>,
    /// Flag supplied by the stopping check: `true` for quality/improvement
    /// based stops, `false` for iteration/component limits.
    pub converged: bool,
    /// Human-readable description of why the run stopped.
    pub stopping_reason: String,
    /// Wall-clock duration of the whole run, in seconds.
    pub total_time: f64,
}
152
/// RBF random-feature sampler whose component count is chosen by a
/// progressive search driven by a [`ProgressiveConfig`].
#[derive(Debug, Clone)]
pub struct ProgressiveRBFSampler {
    /// RBF kernel coefficient: the exact kernel is `exp(-gamma * ||a - b||^2)`.
    gamma: f64,
    /// Search settings (growth strategy, stopping rule, metric, ...).
    config: ProgressiveConfig,
}
160
161impl Default for ProgressiveRBFSampler {
162 fn default() -> Self {
163 Self::new()
164 }
165}
166
167impl ProgressiveRBFSampler {
168 pub fn new() -> Self {
170 Self {
171 gamma: 1.0,
172 config: ProgressiveConfig::default(),
173 }
174 }
175
176 pub fn gamma(mut self, gamma: f64) -> Self {
178 self.gamma = gamma;
179 self
180 }
181
182 pub fn config(mut self, config: ProgressiveConfig) -> Self {
184 self.config = config;
185 self
186 }
187
188 pub fn initial_components(mut self, components: usize) -> Self {
190 self.config.initial_components = components;
191 self
192 }
193
194 pub fn strategy(mut self, strategy: ProgressiveStrategy) -> Self {
196 self.config.strategy = strategy;
197 self
198 }
199
200 pub fn stopping_criterion(mut self, criterion: StoppingCriterion) -> Self {
202 self.config.stopping_criterion = criterion;
203 self
204 }
205
206 pub fn run_progressive_approximation(&self, x: &Array2<f64>) -> Result<ProgressiveResult> {
208 let start_time = Instant::now();
209 let n_samples = x.nrows();
210
211 let split_idx = (n_samples as f64 * (1.0 - self.config.validation_fraction)) as usize;
213 let x_train = x
214 .slice(scirs2_core::ndarray::s![..split_idx, ..])
215 .to_owned();
216 let x_val = x
217 .slice(scirs2_core::ndarray::s![split_idx.., ..])
218 .to_owned();
219
220 let k_exact = self.compute_exact_kernel_matrix(&x_val)?;
222
223 let mut steps = Vec::new();
224 let mut current_components = self.config.initial_components;
225 let mut previous_quality = 0.0;
226 let mut iteration = 0;
227 let result;
228
229 let mut fib_prev = 1;
231 let mut fib_curr = 1;
232
233 loop {
234 let step_start = Instant::now();
235
236 let quality = self.compute_quality_for_components(
238 current_components,
239 &x_train,
240 &x_val,
241 &k_exact,
242 )?;
243
244 let improvement = if iteration == 0 {
245 quality
246 } else {
247 quality - previous_quality
248 };
249
250 let step_time = step_start.elapsed().as_secs_f64();
251
252 let step = ProgressiveStep {
254 n_components: current_components,
255 quality_score: quality,
256 improvement,
257 time_taken: step_time,
258 iteration,
259 };
260 steps.push(step);
261
262 if let Some(stop_result) =
264 self.check_stopping_criteria(quality, improvement, iteration, current_components)
265 {
266 result = Some(stop_result);
267 break;
268 }
269
270 previous_quality = quality;
272 iteration += 1;
273
274 current_components = match &self.config.strategy {
276 ProgressiveStrategy::Doubling => current_components * 2,
277 ProgressiveStrategy::FixedIncrement { increment } => current_components + increment,
278 ProgressiveStrategy::AdaptiveIncrement {
279 min_increment,
280 max_increment,
281 improvement_threshold,
282 } => {
283 let increment = if improvement > *improvement_threshold {
284 *min_increment
285 } else {
286 (*min_increment + (*max_increment - *min_increment) / 2).max(*min_increment)
287 };
288 current_components + increment
289 }
290 ProgressiveStrategy::Exponential { base } => {
291 ((current_components as f64) * base) as usize
292 }
293 ProgressiveStrategy::Fibonacci => {
294 let next_fib = fib_prev + fib_curr;
295 fib_prev = fib_curr;
296 fib_curr = next_fib;
297 self.config.initial_components + fib_curr
298 }
299 };
300 }
301
302 let total_time = start_time.elapsed().as_secs_f64();
303 let (converged, stopping_reason) =
304 result.unwrap_or((false, "Max iterations reached".to_string()));
305
306 Ok(ProgressiveResult {
307 final_components: steps
308 .last()
309 .map(|s| s.n_components)
310 .unwrap_or(current_components),
311 final_quality: steps.last().map(|s| s.quality_score).unwrap_or(0.0),
312 steps,
313 converged,
314 stopping_reason,
315 total_time,
316 })
317 }
318
319 fn compute_exact_kernel_matrix(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
321 let n_samples = x.nrows().min(100); let x_subset = x.slice(scirs2_core::ndarray::s![..n_samples, ..]);
323
324 let mut k_exact = Array2::zeros((n_samples, n_samples));
325
326 for i in 0..n_samples {
327 for j in 0..n_samples {
328 let diff = &x_subset.row(i) - &x_subset.row(j);
329 let squared_norm = diff.dot(&diff);
330 k_exact[[i, j]] = (-self.gamma * squared_norm).exp();
331 }
332 }
333
334 Ok(k_exact)
335 }
336
337 fn compute_quality_for_components(
339 &self,
340 n_components: usize,
341 x_train: &Array2<f64>,
342 x_val: &Array2<f64>,
343 k_exact: &Array2<f64>,
344 ) -> Result<f64> {
345 let mut trial_qualities = Vec::new();
346
347 for trial in 0..self.config.n_trials {
349 let seed = self.config.random_seed.map(|s| s + trial as u64);
350 let sampler = if let Some(s) = seed {
351 RBFSampler::new(n_components)
352 .gamma(self.gamma)
353 .random_state(s)
354 } else {
355 RBFSampler::new(n_components).gamma(self.gamma)
356 };
357
358 let fitted = sampler.fit(x_train, &())?;
359 let x_val_transformed = fitted.transform(x_val)?;
360
361 let quality = self.compute_quality_metric(x_val, &x_val_transformed, k_exact)?;
362 trial_qualities.push(quality);
363 }
364
365 Ok(trial_qualities.iter().sum::<f64>() / trial_qualities.len() as f64)
367 }
368
369 fn compute_quality_metric(
371 &self,
372 _x: &Array2<f64>,
373 x_transformed: &Array2<f64>,
374 k_exact: &Array2<f64>,
375 ) -> Result<f64> {
376 match &self.config.quality_metric {
377 ProgressiveQualityMetric::KernelAlignment => {
378 self.compute_kernel_alignment(x_transformed, k_exact)
379 }
380 ProgressiveQualityMetric::FrobeniusError => {
381 self.compute_frobenius_error(x_transformed, k_exact)
382 }
383 ProgressiveQualityMetric::SpectralError => {
384 self.compute_spectral_error(x_transformed, k_exact)
385 }
386 ProgressiveQualityMetric::EffectiveRank => self.compute_effective_rank(x_transformed),
387 ProgressiveQualityMetric::RelativeImprovement => {
388 Ok(1.0)
390 }
391 ProgressiveQualityMetric::Custom => {
392 self.compute_kernel_alignment(x_transformed, k_exact)
394 }
395 }
396 }
397
398 fn compute_kernel_alignment(
400 &self,
401 x_transformed: &Array2<f64>,
402 k_exact: &Array2<f64>,
403 ) -> Result<f64> {
404 let n_samples = k_exact.nrows().min(x_transformed.nrows());
405 let x_subset = x_transformed.slice(scirs2_core::ndarray::s![..n_samples, ..]);
406
407 let k_approx = x_subset.dot(&x_subset.t());
409
410 let k_exact_norm = k_exact.norm_l2();
412 let k_approx_norm = k_approx.norm_l2();
413
414 if k_exact_norm > 1e-12 && k_approx_norm > 1e-12 {
415 let alignment = (k_exact * &k_approx).sum() / (k_exact_norm * k_approx_norm);
416 Ok(alignment)
417 } else {
418 Ok(0.0)
419 }
420 }
421
422 fn compute_frobenius_error(
424 &self,
425 x_transformed: &Array2<f64>,
426 k_exact: &Array2<f64>,
427 ) -> Result<f64> {
428 let n_samples = k_exact.nrows().min(x_transformed.nrows());
429 let x_subset = x_transformed.slice(scirs2_core::ndarray::s![..n_samples, ..]);
430
431 let k_approx = x_subset.dot(&x_subset.t());
433
434 let diff = k_exact - &k_approx.slice(scirs2_core::ndarray::s![..n_samples, ..n_samples]);
436 let error = diff.norm_l2();
437 let quality = 1.0 / (1.0 + error); Ok(quality)
440 }
441
442 fn compute_spectral_error(
444 &self,
445 x_transformed: &Array2<f64>,
446 k_exact: &Array2<f64>,
447 ) -> Result<f64> {
448 let n_samples = k_exact.nrows().min(x_transformed.nrows());
449 let x_subset = x_transformed.slice(scirs2_core::ndarray::s![..n_samples, ..]);
450
451 let k_approx = x_subset.dot(&x_subset.t());
453
454 let diff = k_exact - &k_approx.slice(scirs2_core::ndarray::s![..n_samples, ..n_samples]);
456 let (_, s, _) = diff
457 .svd(false)
458 .map_err(|_| SklearsError::InvalidInput("SVD computation failed".to_string()))?;
459
460 let spectral_error = s.iter().fold(0.0f64, |acc, &x| acc.max(x));
461 let quality = 1.0 / (1.0 + spectral_error);
462
463 Ok(quality)
464 }
465
466 fn compute_effective_rank(&self, x_transformed: &Array2<f64>) -> Result<f64> {
468 let (_, s, _) = x_transformed
470 .svd(true)
471 .map_err(|_| SklearsError::InvalidInput("SVD computation failed".to_string()))?;
472
473 let s_sum = s.sum();
475 if s_sum == 0.0 {
476 return Ok(0.0);
477 }
478
479 let s_normalized = &s / s_sum;
480 let entropy = -s_normalized
481 .iter()
482 .filter(|&&x| x > 1e-12)
483 .map(|&x| x * x.ln())
484 .sum::<f64>();
485
486 let effective_rank = entropy.exp();
487 Ok(effective_rank / x_transformed.ncols() as f64) }
489
490 fn check_stopping_criteria(
492 &self,
493 quality: f64,
494 improvement: f64,
495 iteration: usize,
496 components: usize,
497 ) -> Option<(bool, String)> {
498 match &self.config.stopping_criterion {
499 StoppingCriterion::TargetQuality { quality: target } => {
500 if quality >= *target {
501 Some((true, format!("Target quality {} reached", target)))
502 } else {
503 None
504 }
505 }
506 StoppingCriterion::ImprovementThreshold { threshold } => {
507 if iteration > 0 && improvement < *threshold {
508 Some((
509 true,
510 format!("Improvement {} below threshold {}", improvement, threshold),
511 ))
512 } else {
513 None
514 }
515 }
516 StoppingCriterion::MaxIterations { max_iter } => {
517 if iteration + 1 >= *max_iter {
518 Some((false, format!("Maximum iterations {} reached", max_iter)))
519 } else {
520 None
521 }
522 }
523 StoppingCriterion::MaxComponents { max_components } => {
524 if components >= *max_components {
525 Some((
526 false,
527 format!("Maximum components {} reached", max_components),
528 ))
529 } else {
530 None
531 }
532 }
533 StoppingCriterion::Combined {
534 quality: target_quality,
535 improvement_threshold,
536 max_iter,
537 max_components,
538 } => {
539 if let Some(target) = target_quality {
541 if quality >= *target {
542 return Some((true, format!("Target quality {} reached", target)));
543 }
544 }
545
546 if let Some(threshold) = improvement_threshold {
548 if iteration > 0 && improvement < *threshold {
549 return Some((
550 true,
551 format!("Improvement {} below threshold {}", improvement, threshold),
552 ));
553 }
554 }
555
556 if let Some(max) = max_iter {
558 if iteration >= *max {
559 return Some((false, format!("Maximum iterations {} reached", max)));
560 }
561 }
562
563 if let Some(max) = max_components {
565 if components >= *max {
566 return Some((false, format!("Maximum components {} reached", max)));
567 }
568 }
569
570 None
571 }
572 }
573 }
574}
575
/// Result of fitting a [`ProgressiveRBFSampler`]: the trained RBF sampler plus
/// the record of the search that chose its component count.
pub struct FittedProgressiveRBFSampler {
    /// Trained sampler used by `transform`.
    fitted_rbf: crate::rbf_sampler::RBFSampler<sklears_core::traits::Trained>,
    /// Steps and outcome of the progressive search.
    progressive_result: ProgressiveResult,
}
581
582impl Fit<Array2<f64>, ()> for ProgressiveRBFSampler {
583 type Fitted = FittedProgressiveRBFSampler;
584
585 fn fit(self, x: &Array2<f64>, _y: &()) -> Result<Self::Fitted> {
586 let progressive_result = self.run_progressive_approximation(x)?;
588
589 let rbf_sampler = RBFSampler::new(progressive_result.final_components).gamma(self.gamma);
591 let fitted_rbf = rbf_sampler.fit(x, &())?;
592
593 Ok(FittedProgressiveRBFSampler {
594 fitted_rbf,
595 progressive_result,
596 })
597 }
598}
599
impl Transform<Array2<f64>, Array2<f64>> for FittedProgressiveRBFSampler {
    /// Transforms `x` with the wrapped fitted RBF sampler and returns its
    /// result (or error) unchanged.
    fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
        self.fitted_rbf.transform(x)
    }
}
605
606impl FittedProgressiveRBFSampler {
607 pub fn progressive_result(&self) -> &ProgressiveResult {
609 &self.progressive_result
610 }
611
612 pub fn final_components(&self) -> usize {
614 self.progressive_result.final_components
615 }
616
617 pub fn final_quality(&self) -> f64 {
619 self.progressive_result.final_quality
620 }
621
622 pub fn converged(&self) -> bool {
624 self.progressive_result.converged
625 }
626
627 pub fn steps(&self) -> &[ProgressiveStep] {
629 &self.progressive_result.steps
630 }
631
632 pub fn stopping_reason(&self) -> &str {
634 &self.progressive_result.stopping_reason
635 }
636}
637
/// Nystroem approximation whose component count is chosen by a progressive
/// search driven by a [`ProgressiveConfig`].
#[derive(Debug, Clone)]
pub struct ProgressiveNystroem {
    /// Kernel handed to every `Nystroem` instance created by the search.
    kernel: crate::nystroem::Kernel,
    /// Search settings (growth strategy, stopping rule, trials, seed, ...).
    config: ProgressiveConfig,
}
645
646impl Default for ProgressiveNystroem {
647 fn default() -> Self {
648 Self::new()
649 }
650}
651
652impl ProgressiveNystroem {
653 pub fn new() -> Self {
655 Self {
656 kernel: crate::nystroem::Kernel::Rbf { gamma: 1.0 },
657 config: ProgressiveConfig::default(),
658 }
659 }
660
661 pub fn gamma(mut self, gamma: f64) -> Self {
663 self.kernel = crate::nystroem::Kernel::Rbf { gamma };
664 self
665 }
666
667 pub fn kernel(mut self, kernel: crate::nystroem::Kernel) -> Self {
669 self.kernel = kernel;
670 self
671 }
672
673 pub fn config(mut self, config: ProgressiveConfig) -> Self {
675 self.config = config;
676 self
677 }
678
679 pub fn run_progressive_approximation(&self, x: &Array2<f64>) -> Result<ProgressiveResult> {
681 let start_time = Instant::now();
682
683 let mut steps = Vec::new();
684 let mut current_components = self.config.initial_components;
685 let mut previous_quality = 0.0;
686 let mut iteration = 0;
687 let result;
688
689 loop {
690 let step_start = Instant::now();
691
692 let quality = self.compute_nystroem_quality(current_components, x)?;
694
695 let improvement = if iteration == 0 {
696 quality
697 } else {
698 quality - previous_quality
699 };
700
701 let step_time = step_start.elapsed().as_secs_f64();
702
703 let step = ProgressiveStep {
705 n_components: current_components,
706 quality_score: quality,
707 improvement,
708 time_taken: step_time,
709 iteration,
710 };
711 steps.push(step);
712
713 if let Some(stop_result) =
715 self.check_stopping_criteria(quality, improvement, iteration, current_components)
716 {
717 result = Some(stop_result);
718 break;
719 }
720
721 previous_quality = quality;
723 iteration += 1;
724
725 current_components = match &self.config.strategy {
727 ProgressiveStrategy::Doubling => current_components * 2,
728 ProgressiveStrategy::FixedIncrement { increment } => current_components + increment,
729 _ => current_components * 2, };
731 }
732
733 let total_time = start_time.elapsed().as_secs_f64();
734 let (converged, stopping_reason) =
735 result.unwrap_or((false, "Max iterations reached".to_string()));
736
737 Ok(ProgressiveResult {
738 final_components: steps
739 .last()
740 .map(|s| s.n_components)
741 .unwrap_or(current_components),
742 final_quality: steps.last().map(|s| s.quality_score).unwrap_or(0.0),
743 steps,
744 converged,
745 stopping_reason,
746 total_time,
747 })
748 }
749
750 fn compute_nystroem_quality(&self, n_components: usize, x: &Array2<f64>) -> Result<f64> {
752 let mut trial_qualities = Vec::new();
753
754 for trial in 0..self.config.n_trials {
756 let seed = self.config.random_seed.map(|s| s + trial as u64);
757 let nystroem = if let Some(s) = seed {
758 Nystroem::new(self.kernel.clone(), n_components).random_state(s)
759 } else {
760 Nystroem::new(self.kernel.clone(), n_components)
761 };
762
763 let fitted = nystroem.fit(x, &())?;
764 let x_transformed = fitted.transform(x)?;
765
766 let quality = self.compute_effective_rank(&x_transformed)?;
768 trial_qualities.push(quality);
769 }
770
771 Ok(trial_qualities.iter().sum::<f64>() / trial_qualities.len() as f64)
772 }
773
774 fn compute_effective_rank(&self, x_transformed: &Array2<f64>) -> Result<f64> {
776 let (_, s, _) = x_transformed
777 .svd(true)
778 .map_err(|_| SklearsError::InvalidInput("SVD computation failed".to_string()))?;
779
780 let s_sum = s.sum();
781 if s_sum == 0.0 {
782 return Ok(0.0);
783 }
784
785 let s_normalized = &s / s_sum;
786 let entropy = -s_normalized
787 .iter()
788 .filter(|&&x| x > 1e-12)
789 .map(|&x| x * x.ln())
790 .sum::<f64>();
791
792 let effective_rank = entropy.exp();
793 Ok(effective_rank / x_transformed.ncols() as f64)
794 }
795
796 fn check_stopping_criteria(
798 &self,
799 quality: f64,
800 _improvement: f64,
801 iteration: usize,
802 _components: usize,
803 ) -> Option<(bool, String)> {
804 match &self.config.stopping_criterion {
805 StoppingCriterion::TargetQuality { quality: target } => {
806 if quality >= *target {
807 Some((true, format!("Target quality {} reached", target)))
808 } else {
809 None
810 }
811 }
812 StoppingCriterion::MaxIterations { max_iter } => {
813 if iteration + 1 >= *max_iter {
814 Some((false, format!("Maximum iterations {} reached", max_iter)))
815 } else {
816 None
817 }
818 }
819 _ => None, }
821 }
822}
823
/// Result of fitting a [`ProgressiveNystroem`]: the trained Nystroem
/// approximation plus the record of the search that chose its component count.
pub struct FittedProgressiveNystroem {
    /// Trained approximation used by `transform`.
    fitted_nystroem: crate::nystroem::Nystroem<sklears_core::traits::Trained>,
    /// Steps and outcome of the progressive search.
    progressive_result: ProgressiveResult,
}
829
830impl Fit<Array2<f64>, ()> for ProgressiveNystroem {
831 type Fitted = FittedProgressiveNystroem;
832
833 fn fit(self, x: &Array2<f64>, _y: &()) -> Result<Self::Fitted> {
834 let progressive_result = self.run_progressive_approximation(x)?;
836
837 let nystroem = Nystroem::new(self.kernel, progressive_result.final_components);
839 let fitted_nystroem = nystroem.fit(x, &())?;
840
841 Ok(FittedProgressiveNystroem {
842 fitted_nystroem,
843 progressive_result,
844 })
845 }
846}
847
impl Transform<Array2<f64>, Array2<f64>> for FittedProgressiveNystroem {
    /// Transforms `x` with the wrapped fitted Nystroem approximation and
    /// returns its result (or error) unchanged.
    fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
        self.fitted_nystroem.transform(x)
    }
}
853
854impl FittedProgressiveNystroem {
855 pub fn progressive_result(&self) -> &ProgressiveResult {
857 &self.progressive_result
858 }
859
860 pub fn final_components(&self) -> usize {
862 self.progressive_result.final_components
863 }
864
865 pub fn final_quality(&self) -> f64 {
867 self.progressive_result.final_quality
868 }
869
870 pub fn converged(&self) -> bool {
872 self.progressive_result.converged
873 }
874}
875
// Unit tests for the progressive RBF and Nystroem drivers. All inputs are
// small deterministic ramps built from integer ranges.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// End-to-end fit/transform of the RBF driver with a three-step limit.
    #[test]
    fn test_progressive_rbf_sampler() {
        let x = Array2::from_shape_vec((100, 4), (0..400).map(|i| (i as f64) * 0.01).collect())
            .unwrap();

        let config = ProgressiveConfig {
            initial_components: 5,
            strategy: ProgressiveStrategy::Doubling,
            stopping_criterion: StoppingCriterion::MaxIterations { max_iter: 3 },
            quality_metric: ProgressiveQualityMetric::KernelAlignment,
            n_trials: 2,
            validation_fraction: 0.3,
            ..Default::default()
        };

        let sampler = ProgressiveRBFSampler::new().gamma(0.5).config(config);

        let fitted = sampler.fit(&x, &()).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        assert_eq!(transformed.nrows(), 100);
        assert!(fitted.final_components() >= 5);
        assert!(fitted.final_quality() >= 0.0);
        // `max_iter: 3` must yield exactly three recorded steps.
        assert_eq!(fitted.steps().len(), 3);
    }

    /// End-to-end fit/transform of the Nystroem driver with fixed increments.
    #[test]
    fn test_progressive_nystroem() {
        let x =
            Array2::from_shape_vec((80, 3), (0..240).map(|i| (i as f64) * 0.02).collect()).unwrap();

        let config = ProgressiveConfig {
            initial_components: 10,
            strategy: ProgressiveStrategy::FixedIncrement { increment: 5 },
            stopping_criterion: StoppingCriterion::MaxIterations { max_iter: 4 },
            n_trials: 2,
            ..Default::default()
        };

        let nystroem = ProgressiveNystroem::new().gamma(1.0).config(config);

        let fitted = nystroem.fit(&x, &()).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        assert_eq!(transformed.nrows(), 80);
        assert!(fitted.final_components() >= 10);
        assert!(fitted.final_quality() >= 0.0);
    }

    /// Each growth strategy runs for exactly the configured number of steps.
    #[test]
    fn test_progressive_strategies() {
        let x =
            Array2::from_shape_vec((50, 2), (0..100).map(|i| (i as f64) * 0.05).collect()).unwrap();

        let strategies = vec![
            ProgressiveStrategy::Doubling,
            ProgressiveStrategy::FixedIncrement { increment: 3 },
            ProgressiveStrategy::Exponential { base: 1.5 },
            ProgressiveStrategy::Fibonacci,
        ];

        for strategy in strategies {
            let config = ProgressiveConfig {
                initial_components: 5,
                strategy,
                stopping_criterion: StoppingCriterion::MaxIterations { max_iter: 3 },
                n_trials: 1,
                ..Default::default()
            };

            let sampler = ProgressiveRBFSampler::new().gamma(0.8).config(config);

            let result = sampler.run_progressive_approximation(&x).unwrap();

            assert!(result.final_components >= 5);
            assert!(result.final_quality >= 0.0);
            assert_eq!(result.steps.len(), 3);
        }
    }

    /// Each single-limit stopping criterion terminates the run with a reason.
    #[test]
    fn test_stopping_criteria() {
        let x =
            Array2::from_shape_vec((60, 3), (0..180).map(|i| (i as f64) * 0.03).collect()).unwrap();

        let criteria = vec![
            StoppingCriterion::TargetQuality { quality: 0.8 },
            StoppingCriterion::ImprovementThreshold { threshold: 0.01 },
            StoppingCriterion::MaxIterations { max_iter: 5 },
            StoppingCriterion::MaxComponents { max_components: 50 },
        ];

        for criterion in criteria {
            let config = ProgressiveConfig {
                initial_components: 10,
                strategy: ProgressiveStrategy::Doubling,
                stopping_criterion: criterion,
                n_trials: 1,
                ..Default::default()
            };

            let sampler = ProgressiveRBFSampler::new().gamma(0.5).config(config);

            let result = sampler.run_progressive_approximation(&x).unwrap();

            assert!(result.final_components >= 10);
            assert!(result.final_quality >= 0.0);
            assert!(!result.stopping_reason.is_empty());
        }
    }

    /// Every implemented quality metric produces non-negative scores and timings.
    #[test]
    fn test_quality_metrics() {
        let x =
            Array2::from_shape_vec((40, 2), (0..80).map(|i| (i as f64) * 0.05).collect()).unwrap();

        let metrics = vec![
            ProgressiveQualityMetric::KernelAlignment,
            ProgressiveQualityMetric::FrobeniusError,
            ProgressiveQualityMetric::SpectralError,
            ProgressiveQualityMetric::EffectiveRank,
        ];

        for metric in metrics {
            let config = ProgressiveConfig {
                initial_components: 5,
                strategy: ProgressiveStrategy::Doubling,
                stopping_criterion: StoppingCriterion::MaxIterations { max_iter: 3 },
                quality_metric: metric,
                n_trials: 1,
                ..Default::default()
            };

            let sampler = ProgressiveRBFSampler::new().gamma(0.3).config(config);

            let result = sampler.run_progressive_approximation(&x).unwrap();

            assert!(result.final_components >= 5);
            assert!(result.final_quality >= 0.0);

            // Every recorded step must carry sane values.
            for step in &result.steps {
                assert!(step.quality_score >= 0.0);
                assert!(step.time_taken >= 0.0);
            }
        }
    }

    /// Quality may fluctuate between steps but must not drop by more than 0.1.
    #[test]
    fn test_progressive_improvement() {
        let x =
            Array2::from_shape_vec((70, 3), (0..210).map(|i| (i as f64) * 0.02).collect()).unwrap();

        let config = ProgressiveConfig {
            initial_components: 10,
            strategy: ProgressiveStrategy::Doubling,
            stopping_criterion: StoppingCriterion::MaxIterations { max_iter: 4 },
            quality_metric: ProgressiveQualityMetric::KernelAlignment,
            n_trials: 2,
            ..Default::default()
        };

        let sampler = ProgressiveRBFSampler::new().gamma(0.7).config(config);

        let result = sampler.run_progressive_approximation(&x).unwrap();

        // Compare each step with its predecessor.
        for i in 1..result.steps.len() {
            let current_quality = result.steps[i].quality_score;
            let previous_quality = result.steps[i - 1].quality_score;

            // The tolerance allows for randomness in the sampled features.
            assert!(
                current_quality >= previous_quality - 0.1,
                "Quality should not decrease significantly: {} -> {}",
                previous_quality,
                current_quality
            );
        }
    }

    /// Two runs with the same seed and configuration agree on their results.
    #[test]
    fn test_progressive_reproducibility() {
        let x =
            Array2::from_shape_vec((50, 2), (0..100).map(|i| (i as f64) * 0.04).collect()).unwrap();

        let config = ProgressiveConfig {
            initial_components: 5,
            strategy: ProgressiveStrategy::Doubling,
            stopping_criterion: StoppingCriterion::MaxIterations { max_iter: 3 },
            n_trials: 2,
            random_seed: Some(42),
            ..Default::default()
        };

        let sampler1 = ProgressiveRBFSampler::new()
            .gamma(0.6)
            .config(config.clone());

        let sampler2 = ProgressiveRBFSampler::new().gamma(0.6).config(config);

        let result1 = sampler1.run_progressive_approximation(&x).unwrap();
        let result2 = sampler2.run_progressive_approximation(&x).unwrap();

        assert_eq!(result1.final_components, result2.final_components);
        assert_abs_diff_eq!(
            result1.final_quality,
            result2.final_quality,
            epsilon = 1e-10
        );
        assert_eq!(result1.steps.len(), result2.steps.len());
    }
}