sklears_kernel_approximation/
progressive.rs

1//! Progressive kernel approximation methods
2//!
3//! This module provides progressive approximation strategies that start with
4//! coarse approximations and progressively refine them based on quality criteria.
5
6use crate::{Nystroem, RBFSampler};
7use scirs2_core::ndarray::ndarray_linalg::{Norm, SVD};
8use scirs2_core::ndarray::Array2;
9use sklears_core::traits::Fit;
10use sklears_core::{
11    error::{Result, SklearsError},
12    traits::Transform,
13};
14use std::time::Instant;
15
/// Rule for growing the number of components between two progressive steps.
///
/// Each variant turns the component count of the step just evaluated into the
/// component count evaluated on the next step.
#[derive(Debug, Clone)]
pub enum ProgressiveStrategy {
    /// Multiply the component count by two.
    Doubling,
    /// Grow the component count by a constant amount.
    FixedIncrement {
        /// Components added on every step.
        increment: usize,
    },
    /// Choose the increment from the improvement of the last step: `min_increment`
    /// while the improvement exceeds `improvement_threshold`, otherwise the midpoint
    /// of `min_increment..=max_increment`.
    AdaptiveIncrement {
        /// Smallest increment, used while quality is still improving quickly.
        min_increment: usize,
        /// Upper end of the increment range.
        max_increment: usize,
        /// Improvement above which the small increment is kept.
        improvement_threshold: f64,
    },
    /// Multiply the component count by `base` (result truncated to an integer).
    Exponential {
        /// Growth factor applied on every step.
        base: f64,
    },
    /// Offset `initial_components` by successive Fibonacci numbers.
    Fibonacci,
}
37
/// Condition that ends a progressive approximation run.
///
/// The criterion is checked after every evaluated step. In the resulting
/// [`ProgressiveResult`], `converged` is `true` when the run stopped because a quality
/// target was reached or the improvement became too small, and `false` when it stopped
/// because of an iteration or component limit.
#[derive(Debug, Clone)]
pub enum StoppingCriterion {
    /// Stop as soon as a step's quality score is at least `quality`.
    TargetQuality { quality: f64 },
    /// Stop once a step (other than the first) improves on its predecessor by less
    /// than `threshold`.
    ImprovementThreshold { threshold: f64 },
    /// Stop after `max_iter` steps have been evaluated.
    MaxIterations { max_iter: usize },
    /// Stop once the evaluated component count reaches `max_components`.
    MaxComponents { max_components: usize },
    /// Several limits at once; a `None` field disables that limit.
    ///
    /// The run stops as soon as **any** configured limit is met (not when all of them
    /// are). The limits are checked in field order: quality, improvement, iterations,
    /// components.
    Combined {
        /// Target quality score.
        quality: Option<f64>,
        /// Minimum improvement between consecutive steps.
        improvement_threshold: Option<f64>,
        /// Iteration limit.
        max_iter: Option<usize>,
        /// Component-count limit.
        max_components: Option<usize>,
    },
}
58
/// Score used to judge an approximation at each progressive step (higher is better).
///
/// Only `ProgressiveRBFSampler` consults this setting; `ProgressiveNystroem` always
/// scores with the effective rank of its features.
#[derive(Debug, Clone)]
pub enum ProgressiveQualityMetric {
    /// Normalized alignment `<K, K~>_F / (||K||_F ||K~||_F)` between the exact and
    /// approximate kernel matrices on the validation rows.
    KernelAlignment,
    /// Frobenius norm `e` of the kernel approximation error, reported as `1 / (1 + e)`.
    FrobeniusError,
    /// Spectral norm `e` (largest singular value) of the kernel approximation error,
    /// reported as `1 / (1 + e)`.
    SpectralError,
    /// Normalized effective rank (exponential of the entropy of the normalized
    /// singular values) of the transformed features.
    EffectiveRank,
    /// Relative improvement over the previous iteration.
    ///
    /// NOTE(review): not implemented yet — every step currently scores a constant `1.0`.
    RelativeImprovement,
    /// Placeholder for a user-supplied quality function.
    ///
    /// NOTE(review): there is no way to supply one yet; this currently falls back to
    /// [`ProgressiveQualityMetric::KernelAlignment`].
    Custom,
}
76
77/// Configuration for progressive approximation
78#[derive(Debug, Clone)]
79/// ProgressiveConfig
80pub struct ProgressiveConfig {
81    /// Initial number of components
82    pub initial_components: usize,
83    /// Progressive strategy
84    pub strategy: ProgressiveStrategy,
85    /// Stopping criterion
86    pub stopping_criterion: StoppingCriterion,
87    /// Quality metric to optimize
88    pub quality_metric: ProgressiveQualityMetric,
89    /// Number of trials per iteration for stability
90    pub n_trials: usize,
91    /// Random seed for reproducibility
92    pub random_seed: Option<u64>,
93    /// Validation fraction for quality assessment
94    pub validation_fraction: f64,
95    /// Whether to store intermediate results
96    pub store_intermediate: bool,
97}
98
99impl Default for ProgressiveConfig {
100    fn default() -> Self {
101        Self {
102            initial_components: 10,
103            strategy: ProgressiveStrategy::Doubling,
104            stopping_criterion: StoppingCriterion::Combined {
105                quality: Some(0.95),
106                improvement_threshold: Some(0.01),
107                max_iter: Some(10),
108                max_components: Some(1000),
109            },
110            quality_metric: ProgressiveQualityMetric::KernelAlignment,
111            n_trials: 3,
112            random_seed: None,
113            validation_fraction: 0.2,
114            store_intermediate: true,
115        }
116    }
117}
118
/// Outcome of evaluating one candidate component count.
#[derive(Debug, Clone)]
pub struct ProgressiveStep {
    /// Component count evaluated in this step.
    pub n_components: usize,
    /// Quality score averaged over the configured trials.
    pub quality_score: f64,
    /// Quality gained over the previous step; the first step reports its own quality.
    pub improvement: f64,
    /// Wall-clock duration of the step, in seconds.
    pub time_taken: f64,
    /// Zero-based index of the step.
    pub iteration: usize,
}
134
135/// Results from progressive approximation
136#[derive(Debug, Clone)]
137/// ProgressiveResult
138pub struct ProgressiveResult {
139    /// Final number of components
140    pub final_components: usize,
141    /// Final quality score
142    pub final_quality: f64,
143    /// All progressive steps
144    pub steps: Vec<ProgressiveStep>,
145    /// Whether convergence was achieved
146    pub converged: bool,
147    /// Stopping reason
148    pub stopping_reason: String,
149    /// Total time taken
150    pub total_time: f64,
151}
152
153/// Progressive RBF sampler
154#[derive(Debug, Clone)]
155/// ProgressiveRBFSampler
156pub struct ProgressiveRBFSampler {
157    gamma: f64,
158    config: ProgressiveConfig,
159}
160
161impl ProgressiveRBFSampler {
162    /// Create a new progressive RBF sampler
163    pub fn new() -> Self {
164        Self {
165            gamma: 1.0,
166            config: ProgressiveConfig::default(),
167        }
168    }
169
170    /// Set gamma parameter
171    pub fn gamma(mut self, gamma: f64) -> Self {
172        self.gamma = gamma;
173        self
174    }
175
176    /// Set configuration
177    pub fn config(mut self, config: ProgressiveConfig) -> Self {
178        self.config = config;
179        self
180    }
181
182    /// Set initial components
183    pub fn initial_components(mut self, components: usize) -> Self {
184        self.config.initial_components = components;
185        self
186    }
187
188    /// Set progressive strategy
189    pub fn strategy(mut self, strategy: ProgressiveStrategy) -> Self {
190        self.config.strategy = strategy;
191        self
192    }
193
194    /// Set stopping criterion
195    pub fn stopping_criterion(mut self, criterion: StoppingCriterion) -> Self {
196        self.config.stopping_criterion = criterion;
197        self
198    }
199
200    /// Run progressive approximation
201    pub fn run_progressive_approximation(&self, x: &Array2<f64>) -> Result<ProgressiveResult> {
202        let start_time = Instant::now();
203        let n_samples = x.nrows();
204
205        // Split data for validation
206        let split_idx = (n_samples as f64 * (1.0 - self.config.validation_fraction)) as usize;
207        let x_train = x
208            .slice(scirs2_core::ndarray::s![..split_idx, ..])
209            .to_owned();
210        let x_val = x
211            .slice(scirs2_core::ndarray::s![split_idx.., ..])
212            .to_owned();
213
214        // Compute exact kernel matrix for validation (small subset)
215        let k_exact = self.compute_exact_kernel_matrix(&x_val)?;
216
217        let mut steps = Vec::new();
218        let mut current_components = self.config.initial_components;
219        let mut previous_quality = 0.0;
220        let mut iteration = 0;
221        let mut converged = false;
222        let mut stopping_reason = String::from("Max iterations reached");
223
224        // Fibonacci sequence state (for Fibonacci strategy)
225        let mut fib_prev = 1;
226        let mut fib_curr = 1;
227
228        loop {
229            let step_start = Instant::now();
230
231            // Compute quality for current number of components
232            let quality = self.compute_quality_for_components(
233                current_components,
234                &x_train,
235                &x_val,
236                &k_exact,
237            )?;
238
239            let improvement = if iteration == 0 {
240                quality
241            } else {
242                quality - previous_quality
243            };
244
245            let step_time = step_start.elapsed().as_secs_f64();
246
247            // Store step result
248            let step = ProgressiveStep {
249                n_components: current_components,
250                quality_score: quality,
251                improvement,
252                time_taken: step_time,
253                iteration,
254            };
255            steps.push(step);
256
257            // Check stopping criteria
258            if let Some((converged_flag, reason)) =
259                self.check_stopping_criteria(quality, improvement, iteration, current_components)
260            {
261                converged = converged_flag;
262                stopping_reason = reason;
263                break;
264            }
265
266            // Update for next iteration
267            previous_quality = quality;
268            iteration += 1;
269
270            // Determine next number of components
271            current_components = match &self.config.strategy {
272                ProgressiveStrategy::Doubling => current_components * 2,
273                ProgressiveStrategy::FixedIncrement { increment } => current_components + increment,
274                ProgressiveStrategy::AdaptiveIncrement {
275                    min_increment,
276                    max_increment,
277                    improvement_threshold,
278                } => {
279                    let increment = if improvement > *improvement_threshold {
280                        *min_increment
281                    } else {
282                        (*min_increment + (*max_increment - *min_increment) / 2).max(*min_increment)
283                    };
284                    current_components + increment
285                }
286                ProgressiveStrategy::Exponential { base } => {
287                    ((current_components as f64) * base) as usize
288                }
289                ProgressiveStrategy::Fibonacci => {
290                    let next_fib = fib_prev + fib_curr;
291                    fib_prev = fib_curr;
292                    fib_curr = next_fib;
293                    self.config.initial_components + fib_curr
294                }
295            };
296        }
297
298        let total_time = start_time.elapsed().as_secs_f64();
299
300        Ok(ProgressiveResult {
301            final_components: steps
302                .last()
303                .map(|s| s.n_components)
304                .unwrap_or(current_components),
305            final_quality: steps.last().map(|s| s.quality_score).unwrap_or(0.0),
306            steps,
307            converged,
308            stopping_reason,
309            total_time,
310        })
311    }
312
313    /// Compute exact kernel matrix for validation
314    fn compute_exact_kernel_matrix(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
315        let n_samples = x.nrows().min(100); // Limit for computational efficiency
316        let x_subset = x.slice(scirs2_core::ndarray::s![..n_samples, ..]);
317
318        let mut k_exact = Array2::zeros((n_samples, n_samples));
319
320        for i in 0..n_samples {
321            for j in 0..n_samples {
322                let diff = &x_subset.row(i) - &x_subset.row(j);
323                let squared_norm = diff.dot(&diff);
324                k_exact[[i, j]] = (-self.gamma * squared_norm).exp();
325            }
326        }
327
328        Ok(k_exact)
329    }
330
331    /// Compute quality for a given number of components
332    fn compute_quality_for_components(
333        &self,
334        n_components: usize,
335        x_train: &Array2<f64>,
336        x_val: &Array2<f64>,
337        k_exact: &Array2<f64>,
338    ) -> Result<f64> {
339        let mut trial_qualities = Vec::new();
340
341        // Run multiple trials for stability
342        for trial in 0..self.config.n_trials {
343            let seed = self.config.random_seed.map(|s| s + trial as u64);
344            let sampler = if let Some(s) = seed {
345                RBFSampler::new(n_components)
346                    .gamma(self.gamma)
347                    .random_state(s)
348            } else {
349                RBFSampler::new(n_components).gamma(self.gamma)
350            };
351
352            let fitted = sampler.fit(x_train, &())?;
353            let x_val_transformed = fitted.transform(x_val)?;
354
355            let quality = self.compute_quality_metric(x_val, &x_val_transformed, k_exact)?;
356            trial_qualities.push(quality);
357        }
358
359        // Return average quality across trials
360        Ok(trial_qualities.iter().sum::<f64>() / trial_qualities.len() as f64)
361    }
362
363    /// Compute quality metric
364    fn compute_quality_metric(
365        &self,
366        x: &Array2<f64>,
367        x_transformed: &Array2<f64>,
368        k_exact: &Array2<f64>,
369    ) -> Result<f64> {
370        match &self.config.quality_metric {
371            ProgressiveQualityMetric::KernelAlignment => {
372                self.compute_kernel_alignment(x_transformed, k_exact)
373            }
374            ProgressiveQualityMetric::FrobeniusError => {
375                self.compute_frobenius_error(x_transformed, k_exact)
376            }
377            ProgressiveQualityMetric::SpectralError => {
378                self.compute_spectral_error(x_transformed, k_exact)
379            }
380            ProgressiveQualityMetric::EffectiveRank => self.compute_effective_rank(x_transformed),
381            ProgressiveQualityMetric::RelativeImprovement => {
382                // This is handled at a higher level
383                Ok(1.0)
384            }
385            ProgressiveQualityMetric::Custom => {
386                // Placeholder for custom quality function
387                self.compute_kernel_alignment(x_transformed, k_exact)
388            }
389        }
390    }
391
392    /// Compute kernel alignment
393    fn compute_kernel_alignment(
394        &self,
395        x_transformed: &Array2<f64>,
396        k_exact: &Array2<f64>,
397    ) -> Result<f64> {
398        let n_samples = k_exact.nrows().min(x_transformed.nrows());
399        let x_subset = x_transformed.slice(scirs2_core::ndarray::s![..n_samples, ..]);
400
401        // Compute approximate kernel matrix
402        let k_approx = x_subset.dot(&x_subset.t());
403
404        // Compute alignment
405        let k_exact_norm = k_exact.norm_l2();
406        let k_approx_norm = k_approx.norm_l2();
407
408        if k_exact_norm > 1e-12 && k_approx_norm > 1e-12 {
409            let alignment = (k_exact * &k_approx).sum() / (k_exact_norm * k_approx_norm);
410            Ok(alignment)
411        } else {
412            Ok(0.0)
413        }
414    }
415
416    /// Compute Frobenius error (as quality score, so higher is better)
417    fn compute_frobenius_error(
418        &self,
419        x_transformed: &Array2<f64>,
420        k_exact: &Array2<f64>,
421    ) -> Result<f64> {
422        let n_samples = k_exact.nrows().min(x_transformed.nrows());
423        let x_subset = x_transformed.slice(scirs2_core::ndarray::s![..n_samples, ..]);
424
425        // Compute approximate kernel matrix
426        let k_approx = x_subset.dot(&x_subset.t());
427
428        // Compute error and convert to quality (higher is better)
429        let diff = k_exact - &k_approx.slice(scirs2_core::ndarray::s![..n_samples, ..n_samples]);
430        let error = diff.norm_l2();
431        let quality = 1.0 / (1.0 + error); // Convert error to quality score
432
433        Ok(quality)
434    }
435
436    /// Compute spectral error (as quality score)
437    fn compute_spectral_error(
438        &self,
439        x_transformed: &Array2<f64>,
440        k_exact: &Array2<f64>,
441    ) -> Result<f64> {
442        let n_samples = k_exact.nrows().min(x_transformed.nrows());
443        let x_subset = x_transformed.slice(scirs2_core::ndarray::s![..n_samples, ..]);
444
445        // Compute approximate kernel matrix
446        let k_approx = x_subset.dot(&x_subset.t());
447
448        // Compute spectral norm (largest singular value) of the error
449        let diff = k_exact - &k_approx.slice(scirs2_core::ndarray::s![..n_samples, ..n_samples]);
450        let (_, s, _) = diff
451            .svd(false, false)
452            .map_err(|_| SklearsError::InvalidInput("SVD computation failed".to_string()))?;
453
454        let spectral_error = s.iter().fold(0.0f64, |acc, &x| acc.max(x));
455        let quality = 1.0 / (1.0 + spectral_error);
456
457        Ok(quality)
458    }
459
460    /// Compute effective rank
461    fn compute_effective_rank(&self, x_transformed: &Array2<f64>) -> Result<f64> {
462        // Compute SVD of transformed data
463        let (_, s, _) = x_transformed
464            .svd(true, true)
465            .map_err(|_| SklearsError::InvalidInput("SVD computation failed".to_string()))?;
466
467        // Compute effective rank using entropy
468        let s_sum = s.sum();
469        if s_sum == 0.0 {
470            return Ok(0.0);
471        }
472
473        let s_normalized = &s / s_sum;
474        let entropy = -s_normalized
475            .iter()
476            .filter(|&&x| x > 1e-12)
477            .map(|&x| x * x.ln())
478            .sum::<f64>();
479
480        let effective_rank = entropy.exp();
481        Ok(effective_rank / x_transformed.ncols() as f64) // Normalize by max possible rank
482    }
483
484    /// Check stopping criteria
485    fn check_stopping_criteria(
486        &self,
487        quality: f64,
488        improvement: f64,
489        iteration: usize,
490        components: usize,
491    ) -> Option<(bool, String)> {
492        match &self.config.stopping_criterion {
493            StoppingCriterion::TargetQuality { quality: target } => {
494                if quality >= *target {
495                    Some((true, format!("Target quality {} reached", target)))
496                } else {
497                    None
498                }
499            }
500            StoppingCriterion::ImprovementThreshold { threshold } => {
501                if iteration > 0 && improvement < *threshold {
502                    Some((
503                        true,
504                        format!("Improvement {} below threshold {}", improvement, threshold),
505                    ))
506                } else {
507                    None
508                }
509            }
510            StoppingCriterion::MaxIterations { max_iter } => {
511                if iteration + 1 >= *max_iter {
512                    Some((false, format!("Maximum iterations {} reached", max_iter)))
513                } else {
514                    None
515                }
516            }
517            StoppingCriterion::MaxComponents { max_components } => {
518                if components >= *max_components {
519                    Some((
520                        false,
521                        format!("Maximum components {} reached", max_components),
522                    ))
523                } else {
524                    None
525                }
526            }
527            StoppingCriterion::Combined {
528                quality: target_quality,
529                improvement_threshold,
530                max_iter,
531                max_components,
532            } => {
533                // Check target quality
534                if let Some(target) = target_quality {
535                    if quality >= *target {
536                        return Some((true, format!("Target quality {} reached", target)));
537                    }
538                }
539
540                // Check improvement threshold
541                if let Some(threshold) = improvement_threshold {
542                    if iteration > 0 && improvement < *threshold {
543                        return Some((
544                            true,
545                            format!("Improvement {} below threshold {}", improvement, threshold),
546                        ));
547                    }
548                }
549
550                // Check max iterations
551                if let Some(max) = max_iter {
552                    if iteration >= *max {
553                        return Some((false, format!("Maximum iterations {} reached", max)));
554                    }
555                }
556
557                // Check max components
558                if let Some(max) = max_components {
559                    if components >= *max {
560                        return Some((false, format!("Maximum components {} reached", max)));
561                    }
562                }
563
564                None
565            }
566        }
567    }
568}
569
570/// Fitted progressive RBF sampler
571pub struct FittedProgressiveRBFSampler {
572    fitted_rbf: crate::rbf_sampler::RBFSampler<sklears_core::traits::Trained>,
573    progressive_result: ProgressiveResult,
574}
575
576impl Fit<Array2<f64>, ()> for ProgressiveRBFSampler {
577    type Fitted = FittedProgressiveRBFSampler;
578
579    fn fit(self, x: &Array2<f64>, _y: &()) -> Result<Self::Fitted> {
580        // Run progressive approximation
581        let progressive_result = self.run_progressive_approximation(x)?;
582
583        // Fit RBF sampler with final configuration
584        let rbf_sampler = RBFSampler::new(progressive_result.final_components).gamma(self.gamma);
585        let fitted_rbf = rbf_sampler.fit(x, &())?;
586
587        Ok(FittedProgressiveRBFSampler {
588            fitted_rbf,
589            progressive_result,
590        })
591    }
592}
593
594impl Transform<Array2<f64>, Array2<f64>> for FittedProgressiveRBFSampler {
595    fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
596        self.fitted_rbf.transform(x)
597    }
598}
599
600impl FittedProgressiveRBFSampler {
601    /// Get the progressive result
602    pub fn progressive_result(&self) -> &ProgressiveResult {
603        &self.progressive_result
604    }
605
606    /// Get the final number of components
607    pub fn final_components(&self) -> usize {
608        self.progressive_result.final_components
609    }
610
611    /// Get the final quality score
612    pub fn final_quality(&self) -> f64 {
613        self.progressive_result.final_quality
614    }
615
616    /// Check if progressive approximation converged
617    pub fn converged(&self) -> bool {
618        self.progressive_result.converged
619    }
620
621    /// Get all progressive steps
622    pub fn steps(&self) -> &[ProgressiveStep] {
623        &self.progressive_result.steps
624    }
625
626    /// Get the stopping reason
627    pub fn stopping_reason(&self) -> &str {
628        &self.progressive_result.stopping_reason
629    }
630}
631
632/// Progressive Nyström method
633#[derive(Debug, Clone)]
634/// ProgressiveNystroem
635pub struct ProgressiveNystroem {
636    kernel: crate::nystroem::Kernel,
637    config: ProgressiveConfig,
638}
639
640impl ProgressiveNystroem {
641    /// Create a new progressive Nyström method
642    pub fn new() -> Self {
643        Self {
644            kernel: crate::nystroem::Kernel::Rbf { gamma: 1.0 },
645            config: ProgressiveConfig::default(),
646        }
647    }
648
649    /// Set gamma parameter (for RBF kernel)
650    pub fn gamma(mut self, gamma: f64) -> Self {
651        self.kernel = crate::nystroem::Kernel::Rbf { gamma };
652        self
653    }
654
655    /// Set kernel type
656    pub fn kernel(mut self, kernel: crate::nystroem::Kernel) -> Self {
657        self.kernel = kernel;
658        self
659    }
660
661    /// Set configuration
662    pub fn config(mut self, config: ProgressiveConfig) -> Self {
663        self.config = config;
664        self
665    }
666
667    /// Run progressive approximation for Nyström method
668    pub fn run_progressive_approximation(&self, x: &Array2<f64>) -> Result<ProgressiveResult> {
669        let start_time = Instant::now();
670
671        let mut steps = Vec::new();
672        let mut current_components = self.config.initial_components;
673        let mut previous_quality = 0.0;
674        let mut iteration = 0;
675        let mut converged = false;
676        let mut stopping_reason = String::from("Max iterations reached");
677
678        loop {
679            let step_start = Instant::now();
680
681            // Compute quality for current number of components
682            let quality = self.compute_nystroem_quality(current_components, x)?;
683
684            let improvement = if iteration == 0 {
685                quality
686            } else {
687                quality - previous_quality
688            };
689
690            let step_time = step_start.elapsed().as_secs_f64();
691
692            // Store step result
693            let step = ProgressiveStep {
694                n_components: current_components,
695                quality_score: quality,
696                improvement,
697                time_taken: step_time,
698                iteration,
699            };
700            steps.push(step);
701
702            // Check stopping criteria (using same logic as RBF sampler)
703            if let Some((converged_flag, reason)) =
704                self.check_stopping_criteria(quality, improvement, iteration, current_components)
705            {
706                converged = converged_flag;
707                stopping_reason = reason;
708                break;
709            }
710
711            // Update for next iteration
712            previous_quality = quality;
713            iteration += 1;
714
715            // Determine next number of components (same logic as RBF sampler)
716            current_components = match &self.config.strategy {
717                ProgressiveStrategy::Doubling => current_components * 2,
718                ProgressiveStrategy::FixedIncrement { increment } => current_components + increment,
719                _ => current_components * 2, // Simplified for Nyström
720            };
721        }
722
723        let total_time = start_time.elapsed().as_secs_f64();
724
725        Ok(ProgressiveResult {
726            final_components: steps
727                .last()
728                .map(|s| s.n_components)
729                .unwrap_or(current_components),
730            final_quality: steps.last().map(|s| s.quality_score).unwrap_or(0.0),
731            steps,
732            converged,
733            stopping_reason,
734            total_time,
735        })
736    }
737
738    /// Compute quality for Nyström with given components
739    fn compute_nystroem_quality(&self, n_components: usize, x: &Array2<f64>) -> Result<f64> {
740        let mut trial_qualities = Vec::new();
741
742        // Run multiple trials for stability
743        for trial in 0..self.config.n_trials {
744            let seed = self.config.random_seed.map(|s| s + trial as u64);
745            let nystroem = if let Some(s) = seed {
746                Nystroem::new(self.kernel.clone(), n_components).random_state(s)
747            } else {
748                Nystroem::new(self.kernel.clone(), n_components)
749            };
750
751            let fitted = nystroem.fit(x, &())?;
752            let x_transformed = fitted.transform(x)?;
753
754            // Use effective rank as quality measure
755            let quality = self.compute_effective_rank(&x_transformed)?;
756            trial_qualities.push(quality);
757        }
758
759        Ok(trial_qualities.iter().sum::<f64>() / trial_qualities.len() as f64)
760    }
761
762    /// Compute effective rank (same as RBF sampler)
763    fn compute_effective_rank(&self, x_transformed: &Array2<f64>) -> Result<f64> {
764        let (_, s, _) = x_transformed
765            .svd(true, true)
766            .map_err(|_| SklearsError::InvalidInput("SVD computation failed".to_string()))?;
767
768        let s_sum = s.sum();
769        if s_sum == 0.0 {
770            return Ok(0.0);
771        }
772
773        let s_normalized = &s / s_sum;
774        let entropy = -s_normalized
775            .iter()
776            .filter(|&&x| x > 1e-12)
777            .map(|&x| x * x.ln())
778            .sum::<f64>();
779
780        let effective_rank = entropy.exp();
781        Ok(effective_rank / x_transformed.ncols() as f64)
782    }
783
784    /// Check stopping criteria (same as RBF sampler)
785    fn check_stopping_criteria(
786        &self,
787        quality: f64,
788        improvement: f64,
789        iteration: usize,
790        components: usize,
791    ) -> Option<(bool, String)> {
792        match &self.config.stopping_criterion {
793            StoppingCriterion::TargetQuality { quality: target } => {
794                if quality >= *target {
795                    Some((true, format!("Target quality {} reached", target)))
796                } else {
797                    None
798                }
799            }
800            StoppingCriterion::MaxIterations { max_iter } => {
801                if iteration + 1 >= *max_iter {
802                    Some((false, format!("Maximum iterations {} reached", max_iter)))
803                } else {
804                    None
805                }
806            }
807            _ => None, // Simplified for Nyström
808        }
809    }
810}
811
812/// Fitted progressive Nyström method
813pub struct FittedProgressiveNystroem {
814    fitted_nystroem: crate::nystroem::Nystroem<sklears_core::traits::Trained>,
815    progressive_result: ProgressiveResult,
816}
817
818impl Fit<Array2<f64>, ()> for ProgressiveNystroem {
819    type Fitted = FittedProgressiveNystroem;
820
821    fn fit(self, x: &Array2<f64>, _y: &()) -> Result<Self::Fitted> {
822        // Run progressive approximation
823        let progressive_result = self.run_progressive_approximation(x)?;
824
825        // Fit Nyström method with final configuration
826        let nystroem = Nystroem::new(self.kernel, progressive_result.final_components);
827        let fitted_nystroem = nystroem.fit(x, &())?;
828
829        Ok(FittedProgressiveNystroem {
830            fitted_nystroem,
831            progressive_result,
832        })
833    }
834}
835
836impl Transform<Array2<f64>, Array2<f64>> for FittedProgressiveNystroem {
837    fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
838        self.fitted_nystroem.transform(x)
839    }
840}
841
842impl FittedProgressiveNystroem {
843    /// Get the progressive result
844    pub fn progressive_result(&self) -> &ProgressiveResult {
845        &self.progressive_result
846    }
847
848    /// Get the final number of components
849    pub fn final_components(&self) -> usize {
850        self.progressive_result.final_components
851    }
852
853    /// Get the final quality score
854    pub fn final_quality(&self) -> f64 {
855        self.progressive_result.final_quality
856    }
857
858    /// Check if progressive approximation converged
859    pub fn converged(&self) -> bool {
860        self.progressive_result.converged
861    }
862}
863
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Builds a deterministic `rows x cols` matrix filled in row-major order
    /// with `0, step, 2 * step, ...`, used as a smooth synthetic dataset.
    fn ramp(rows: usize, cols: usize, step: f64) -> Array2<f64> {
        let values: Vec<f64> = (0..rows * cols).map(|i| (i as f64) * step).collect();
        Array2::from_shape_vec((rows, cols), values)
            .expect("shape matches the number of generated values")
    }

    #[test]
    fn test_progressive_rbf_sampler() {
        let x = ramp(100, 4, 0.01);

        let config = ProgressiveConfig {
            initial_components: 5,
            strategy: ProgressiveStrategy::Doubling,
            stopping_criterion: StoppingCriterion::MaxIterations { max_iter: 3 },
            quality_metric: ProgressiveQualityMetric::KernelAlignment,
            n_trials: 2,
            validation_fraction: 0.3,
            ..Default::default()
        };

        let sampler = ProgressiveRBFSampler::new().gamma(0.5).config(config);

        let fitted = sampler.fit(&x, &()).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        assert_eq!(transformed.nrows(), 100);
        assert!(fitted.final_components() >= 5);
        assert!(fitted.final_quality() >= 0.0);
        assert_eq!(fitted.steps().len(), 3); // 3 iterations max
    }

    #[test]
    fn test_progressive_nystroem() {
        let x = ramp(80, 3, 0.02);

        let config = ProgressiveConfig {
            initial_components: 10,
            strategy: ProgressiveStrategy::FixedIncrement { increment: 5 },
            stopping_criterion: StoppingCriterion::MaxIterations { max_iter: 4 },
            n_trials: 2,
            ..Default::default()
        };

        let nystroem = ProgressiveNystroem::new().gamma(1.0).config(config);

        let fitted = nystroem.fit(&x, &()).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        assert_eq!(transformed.nrows(), 80);
        assert!(fitted.final_components() >= 10);
        assert!(fitted.final_quality() >= 0.0);

        // The convenience accessors must agree with the underlying result.
        let result = fitted.progressive_result();
        assert_eq!(result.final_components, fitted.final_components());
        assert_eq!(result.converged, fitted.converged());
        assert!(!result.steps.is_empty());
    }

    #[test]
    fn test_progressive_strategies() {
        let x = ramp(50, 2, 0.05);

        let strategies = [
            ProgressiveStrategy::Doubling,
            ProgressiveStrategy::FixedIncrement { increment: 3 },
            ProgressiveStrategy::Exponential { base: 1.5 },
            ProgressiveStrategy::Fibonacci,
        ];

        for strategy in strategies {
            let config = ProgressiveConfig {
                initial_components: 5,
                strategy,
                stopping_criterion: StoppingCriterion::MaxIterations { max_iter: 3 },
                n_trials: 1,
                ..Default::default()
            };

            let sampler = ProgressiveRBFSampler::new().gamma(0.8).config(config);

            let result = sampler.run_progressive_approximation(&x).unwrap();

            assert!(result.final_components >= 5);
            assert!(result.final_quality >= 0.0);
            assert_eq!(result.steps.len(), 3);
        }
    }

    #[test]
    fn test_stopping_criteria() {
        let x = ramp(60, 3, 0.03);

        let criteria = [
            StoppingCriterion::TargetQuality { quality: 0.8 },
            StoppingCriterion::ImprovementThreshold { threshold: 0.01 },
            StoppingCriterion::MaxIterations { max_iter: 5 },
            StoppingCriterion::MaxComponents { max_components: 50 },
        ];

        for criterion in criteria {
            let config = ProgressiveConfig {
                initial_components: 10,
                strategy: ProgressiveStrategy::Doubling,
                stopping_criterion: criterion,
                n_trials: 1,
                ..Default::default()
            };

            let sampler = ProgressiveRBFSampler::new().gamma(0.5).config(config);

            let result = sampler.run_progressive_approximation(&x).unwrap();

            assert!(result.final_components >= 10);
            assert!(result.final_quality >= 0.0);
            assert!(!result.stopping_reason.is_empty());
        }
    }

    #[test]
    fn test_quality_metrics() {
        let x = ramp(40, 2, 0.05);

        let metrics = [
            ProgressiveQualityMetric::KernelAlignment,
            ProgressiveQualityMetric::FrobeniusError,
            ProgressiveQualityMetric::SpectralError,
            ProgressiveQualityMetric::EffectiveRank,
        ];

        for metric in metrics {
            let config = ProgressiveConfig {
                initial_components: 5,
                strategy: ProgressiveStrategy::Doubling,
                stopping_criterion: StoppingCriterion::MaxIterations { max_iter: 3 },
                quality_metric: metric,
                n_trials: 1,
                ..Default::default()
            };

            let sampler = ProgressiveRBFSampler::new().gamma(0.3).config(config);

            let result = sampler.run_progressive_approximation(&x).unwrap();

            assert!(result.final_components >= 5);
            assert!(result.final_quality >= 0.0);

            // All steps should have valid quality scores
            for step in &result.steps {
                assert!(step.quality_score >= 0.0);
                assert!(step.time_taken >= 0.0);
            }
        }
    }

    #[test]
    fn test_progressive_improvement() {
        let x = ramp(70, 3, 0.02);

        let config = ProgressiveConfig {
            initial_components: 10,
            strategy: ProgressiveStrategy::Doubling,
            stopping_criterion: StoppingCriterion::MaxIterations { max_iter: 4 },
            quality_metric: ProgressiveQualityMetric::KernelAlignment,
            n_trials: 2,
            ..Default::default()
        };

        let sampler = ProgressiveRBFSampler::new().gamma(0.7).config(config);

        let result = sampler.run_progressive_approximation(&x).unwrap();

        // Quality should generally improve or stay stable between consecutive steps.
        for pair in result.steps.windows(2) {
            let previous_quality = pair[0].quality_score;
            let current_quality = pair[1].quality_score;

            // Allow for small numerical differences
            assert!(
                current_quality >= previous_quality - 0.1,
                "Quality should not decrease significantly: {} -> {}",
                previous_quality,
                current_quality
            );
        }
    }

    #[test]
    fn test_progressive_reproducibility() {
        let x = ramp(50, 2, 0.04);

        let config = ProgressiveConfig {
            initial_components: 5,
            strategy: ProgressiveStrategy::Doubling,
            stopping_criterion: StoppingCriterion::MaxIterations { max_iter: 3 },
            n_trials: 2,
            random_seed: Some(42),
            ..Default::default()
        };

        let sampler1 = ProgressiveRBFSampler::new()
            .gamma(0.6)
            .config(config.clone());

        let sampler2 = ProgressiveRBFSampler::new().gamma(0.6).config(config);

        let result1 = sampler1.run_progressive_approximation(&x).unwrap();
        let result2 = sampler2.run_progressive_approximation(&x).unwrap();

        assert_eq!(result1.final_components, result2.final_components);
        assert_abs_diff_eq!(
            result1.final_quality,
            result2.final_quality,
            epsilon = 1e-10
        );
        assert_eq!(result1.steps.len(), result2.steps.len());

        // With a fixed seed the whole trajectory, not just its end point,
        // must be reproducible.
        for (step1, step2) in result1.steps.iter().zip(&result2.steps) {
            assert_abs_diff_eq!(step1.quality_score, step2.quality_score, epsilon = 1e-10);
        }
    }
}