sklears_kernel_approximation/
validation.rs

1//! Advanced Validation Framework for Kernel Approximation Methods
2//!
3//! This module provides comprehensive validation tools including theoretical error bound
4//! validation, convergence analysis, and approximation quality assessment.
5
6use scirs2_core::ndarray::{Array1, Array2, Axis};
7use scirs2_core::random::essentials::Normal as RandNormal;
8use scirs2_core::random::rngs::StdRng as RealStdRng;
9use scirs2_core::random::Rng;
10use scirs2_core::random::SeedableRng;
11use sklears_core::error::Result;
12use std::collections::HashMap;
13
14/// Comprehensive validation framework for kernel approximation methods
15#[derive(Debug, Clone)]
16/// KernelApproximationValidator
17pub struct KernelApproximationValidator {
18    config: ValidationConfig,
19    theoretical_bounds: HashMap<String, TheoreticalBound>,
20}
21
/// Settings that drive the analyses performed by `KernelApproximationValidator`.
#[derive(Debug, Clone)]
pub struct ValidationConfig {
    /// Confidence level for probabilistic checks.
    ///
    /// NOTE(review): not read by `KernelApproximationValidator` in this module; probabilistic
    /// theoretical bounds carry their own `confidence` value instead.
    pub confidence_level: f64,
    /// Target relative approximation error used when estimating the minimum number of samples.
    pub max_approximation_error: f64,
    /// Convergence tolerance.
    ///
    /// NOTE(review): not read by `KernelApproximationValidator` in this module.
    pub convergence_tolerance: f64,
    /// Standard deviation of the Gaussian noise added to the data in the perturbation
    /// (stability) test.
    pub stability_tolerance: f64,
    /// Row counts tried in the sample-complexity analysis; sizes larger than the data are skipped.
    pub sample_sizes: Vec<usize>,
    /// Numbers of approximation components tried when validating a method.
    pub approximation_dimensions: Vec<usize>,
    /// Number of independent fits per approximation dimension; their errors are averaged.
    pub repetitions: usize,
    /// Seed for the perturbation noise generator; `None` falls back to a fixed seed.
    pub random_state: Option<u64>,
}

impl Default for ValidationConfig {
    /// 95% confidence, a 10% target error, four sample sizes and four approximation
    /// dimensions, ten repetitions, and a fixed seed of 42 so runs are reproducible.
    fn default() -> Self {
        Self {
            confidence_level: 0.95,
            max_approximation_error: 0.1,
            convergence_tolerance: 1e-6,
            stability_tolerance: 1e-4,
            sample_sizes: vec![100, 500, 1000, 2000],
            approximation_dimensions: vec![50, 100, 200, 500],
            repetitions: 10,
            random_state: Some(42),
        }
    }
}
58
/// A theoretical error bound registered for one approximation method.
#[derive(Debug, Clone)]
pub struct TheoreticalBound {
    /// Name of the method the bound applies to; bounds are looked up by this name.
    pub method_name: String,
    /// How the raw bound value is adjusted (probabilistic, deterministic, ...).
    pub bound_type: BoundType,
    /// Formula producing the raw bound value.
    pub bound_function: BoundFunction,
    /// Named constants read by the formula; missing constants fall back to built-in defaults.
    pub constants: HashMap<String, f64>,
}

/// Kind of guarantee a [`TheoreticalBound`] expresses.
#[derive(Debug, Clone)]
pub enum BoundType {
    /// Holds with probability `confidence`; the raw bound is inflated by a
    /// confidence-dependent factor that shrinks with the number of samples.
    Probabilistic { confidence: f64 },
    /// Deterministic worst-case bound, used unchanged.
    Deterministic,
    /// Bound on the expected error (scaled down, as it is typically tighter).
    Expected,
    /// Concentration-inequality bound; `deviation_parameter` controls the inflation factor.
    Concentration { deviation_parameter: f64 },
}

/// Formula used to compute a raw bound value from the problem sizes
/// (`d` = number of features, `m` = number of approximation components).
#[derive(Debug, Clone)]
pub enum BoundFunction {
    /// Random Fourier features: O(sqrt(log(d)/m)).
    RandomFourierFeatures,
    /// Nyström approximation: depends on eigenvalue decay (trace bound and effective rank).
    Nystroem,
    /// Structured random features: O(sqrt(d*log(d)/m)).
    StructuredRandomFeatures,
    /// Fastfood approximation: O(sqrt(d*log^2(d)/m)).
    Fastfood,
    /// User-supplied formula.
    ///
    /// NOTE(review): `formula` is not evaluated by the validator in this module; a generic
    /// O(1/sqrt(m)) rate is used instead.
    Custom { formula: String },
}
102
/// Outcome of validating one kernel approximation method.
#[derive(Debug, Clone)]
pub struct ValidationResult {
    /// Name reported by the validated method.
    pub method_name: String,
    /// Mean relative approximation error for each configured approximation dimension.
    pub empirical_errors: Vec<f64>,
    /// Theoretical bound for each configured dimension (infinite when no bound is registered).
    pub theoretical_bounds: Vec<f64>,
    /// Number of dimensions whose empirical error exceeded the theoretical bound.
    pub bound_violations: usize,
    /// Average ratio of empirical error to theoretical bound; only finite bounds yield a ratio.
    /// Values near 1 mean a tight bound, values near 0 a loose one.
    pub bound_tightness: f64,
    /// Estimated decay rate of the error with the approximation dimension, when estimable.
    pub convergence_rate: Option<f64>,
    /// Sensitivity of the approximation to perturbations and numerical conditioning.
    pub stability_analysis: StabilityAnalysis,
    /// How the error behaves as the number of samples grows.
    pub sample_complexity: SampleComplexityAnalysis,
    /// Quality/cost trade-off across approximation dimensions.
    pub dimension_dependency: DimensionDependencyAnalysis,
}

/// Stability diagnostics for a kernel approximation.
#[derive(Debug, Clone)]
pub struct StabilityAnalysis {
    /// Mean relative difference between approximations fitted on the original data and on
    /// noise-perturbed copies of it.
    pub perturbation_sensitivity: f64,
    /// Score derived from the reported condition numbers; 1.0 is the best possible value.
    pub numerical_stability: f64,
    /// Condition numbers reported by the fitted approximations.
    pub condition_numbers: Vec<f64>,
    /// Placeholder score derived from `perturbation_sensitivity`.
    pub eigenvalue_stability: f64,
}

/// Sample-complexity diagnostics for a kernel approximation.
#[derive(Debug, Clone)]
pub struct SampleComplexityAnalysis {
    /// First configured sample size whose error met the configured target error, or a
    /// fallback size when none did.
    pub minimum_samples: usize,
    /// Estimated decay rate of the error with the sample size.
    pub convergence_rate: f64,
    /// Reciprocal of `minimum_samples`.
    pub sample_efficiency: f64,
    /// Number of features divided by `minimum_samples`.
    pub dimension_scaling: f64,
}

/// Quality/cost trade-off across approximation dimensions.
#[derive(Debug, Clone)]
pub struct DimensionDependencyAnalysis {
    /// `(dimension, 1 - error)` pairs.
    pub approximation_quality_vs_dimension: Vec<(usize, f64)>,
    /// `(dimension, estimated cost)` pairs from a simplified cost model.
    pub computational_cost_vs_dimension: Vec<(usize, f64)>,
    /// Dimension with the best quality-to-cost ratio.
    pub optimal_dimension: usize,
    /// Mean approximation quality over the tested dimensions.
    pub dimension_efficiency: f64,
}
168
/// Outcome of a cross-validated parameter search for a kernel approximation.
#[derive(Debug, Clone)]
pub struct CrossValidationResult {
    /// Name reported by the evaluated method.
    pub method_name: String,
    /// Mean fold score of every evaluated parameter combination (higher is better).
    pub cv_scores: Vec<f64>,
    /// Mean of `cv_scores`.
    pub mean_score: f64,
    /// Population standard deviation of `cv_scores`.
    pub std_score: f64,
    /// Parameter combination with the highest mean fold score.
    pub best_parameters: HashMap<String, f64>,
    /// Per parameter: mean absolute change of the score when only that parameter is varied
    /// over its grid values around `best_parameters`.
    pub parameter_sensitivity: HashMap<String, f64>,
}
186
187impl KernelApproximationValidator {
188    /// Create a new validator with configuration
189    pub fn new(config: ValidationConfig) -> Self {
190        let mut validator = Self {
191            config,
192            theoretical_bounds: HashMap::new(),
193        };
194
195        // Add default theoretical bounds
196        validator.add_default_bounds();
197        validator
198    }
199
200    /// Add theoretical bounds for a specific method
201    pub fn add_theoretical_bound(&mut self, bound: TheoreticalBound) {
202        self.theoretical_bounds
203            .insert(bound.method_name.clone(), bound);
204    }
205
206    fn add_default_bounds(&mut self) {
207        // RFF bounds
208        self.add_theoretical_bound(TheoreticalBound {
209            method_name: "RBF".to_string(),
210            bound_type: BoundType::Probabilistic { confidence: 0.95 },
211            bound_function: BoundFunction::RandomFourierFeatures,
212            constants: [
213                ("kernel_bound".to_string(), 1.0),
214                ("lipschitz_constant".to_string(), 1.0),
215            ]
216            .iter()
217            .cloned()
218            .collect(),
219        });
220
221        // Nyström bounds
222        self.add_theoretical_bound(TheoreticalBound {
223            method_name: "Nystroem".to_string(),
224            bound_type: BoundType::Expected,
225            bound_function: BoundFunction::Nystroem,
226            constants: [
227                ("trace_bound".to_string(), 1.0),
228                ("effective_rank".to_string(), 100.0),
229            ]
230            .iter()
231            .cloned()
232            .collect(),
233        });
234
235        // Fastfood bounds
236        self.add_theoretical_bound(TheoreticalBound {
237            method_name: "Fastfood".to_string(),
238            bound_type: BoundType::Probabilistic { confidence: 0.95 },
239            bound_function: BoundFunction::Fastfood,
240            constants: [
241                ("dimension_factor".to_string(), 1.0),
242                ("log_factor".to_string(), 2.0),
243            ]
244            .iter()
245            .cloned()
246            .collect(),
247        });
248    }
249
250    /// Validate a kernel approximation method
251    pub fn validate_method<T: ValidatableKernelMethod>(
252        &self,
253        method: &T,
254        data: &Array2<f64>,
255        true_kernel: Option<&Array2<f64>>,
256    ) -> Result<ValidationResult> {
257        let method_name = method.method_name();
258        let mut empirical_errors = Vec::new();
259        let mut theoretical_bounds = Vec::new();
260        let mut condition_numbers = Vec::new();
261
262        // Test different approximation dimensions
263        for &n_components in &self.config.approximation_dimensions {
264            let mut dimension_errors = Vec::new();
265
266            for _ in 0..self.config.repetitions {
267                // Fit and evaluate the method
268                let fitted = method.fit_with_dimension(data, n_components)?;
269                let approximation = fitted.get_kernel_approximation(data)?;
270
271                // Compute empirical error
272                let empirical_error = if let Some(true_k) = true_kernel {
273                    self.compute_approximation_error(&approximation, true_k)?
274                } else {
275                    // Use RBF kernel as reference
276                    let rbf_kernel = self.compute_rbf_kernel(data, 1.0)?;
277                    self.compute_approximation_error(&approximation, &rbf_kernel)?
278                };
279
280                dimension_errors.push(empirical_error);
281
282                // Compute condition number for stability analysis
283                if let Some(cond_num) = fitted.compute_condition_number()? {
284                    condition_numbers.push(cond_num);
285                }
286            }
287
288            let mean_error = dimension_errors.iter().sum::<f64>() / dimension_errors.len() as f64;
289            empirical_errors.push(mean_error);
290
291            // Compute theoretical bound
292            if let Some(bound) = self.theoretical_bounds.get(&method_name) {
293                let theoretical_bound = self.compute_theoretical_bound(
294                    bound,
295                    data.nrows(),
296                    data.ncols(),
297                    n_components,
298                )?;
299                theoretical_bounds.push(theoretical_bound);
300            } else {
301                theoretical_bounds.push(f64::INFINITY);
302            }
303        }
304
305        // Count bound violations
306        let bound_violations = empirical_errors
307            .iter()
308            .zip(theoretical_bounds.iter())
309            .filter(|(&emp, &theo)| emp > theo)
310            .count();
311
312        // Compute bound tightness (average ratio of empirical to theoretical)
313        let bound_tightness = empirical_errors
314            .iter()
315            .zip(theoretical_bounds.iter())
316            .filter(|(_, &theo)| theo.is_finite())
317            .map(|(&emp, &theo)| emp / theo)
318            .sum::<f64>()
319            / empirical_errors.len() as f64;
320
321        // Analyze convergence rate
322        let convergence_rate = self.estimate_convergence_rate(&empirical_errors);
323
324        // Perform stability analysis
325        let stability_analysis = self.analyze_stability(method, data, &condition_numbers)?;
326
327        // Analyze sample complexity
328        let sample_complexity = self.analyze_sample_complexity(method, data)?;
329
330        // Analyze dimension dependency
331        let dimension_dependency =
332            self.analyze_dimension_dependency(method, data, &empirical_errors)?;
333
334        Ok(ValidationResult {
335            method_name,
336            empirical_errors,
337            theoretical_bounds,
338            bound_violations,
339            bound_tightness,
340            convergence_rate,
341            stability_analysis,
342            sample_complexity,
343            dimension_dependency,
344        })
345    }
346
347    /// Perform cross-validation for parameter selection
348    pub fn cross_validate<T: ValidatableKernelMethod>(
349        &self,
350        method: &T,
351        data: &Array2<f64>,
352        targets: Option<&Array1<f64>>,
353        parameter_grid: HashMap<String, Vec<f64>>,
354    ) -> Result<CrossValidationResult> {
355        let mut best_score = f64::NEG_INFINITY;
356        let mut best_parameters = HashMap::new();
357        let mut all_scores = Vec::new();
358        let mut parameter_sensitivity = HashMap::new();
359
360        // Generate parameter combinations
361        let param_combinations = self.generate_parameter_combinations(&parameter_grid);
362
363        for params in param_combinations {
364            let cv_scores = self.k_fold_cross_validation(method, data, targets, &params, 5)?;
365            let mean_score = cv_scores.iter().sum::<f64>() / cv_scores.len() as f64;
366
367            all_scores.push(mean_score);
368
369            if mean_score > best_score {
370                best_score = mean_score;
371                best_parameters = params.clone();
372            }
373        }
374
375        // Analyze parameter sensitivity
376        for (param_name, param_values) in &parameter_grid {
377            let mut sensitivities = Vec::new();
378
379            for &param_value in param_values.iter() {
380                let mut single_param = best_parameters.clone();
381                single_param.insert(param_name.clone(), param_value);
382
383                let cv_scores =
384                    self.k_fold_cross_validation(method, data, targets, &single_param, 3)?;
385                let mean_score = cv_scores.iter().sum::<f64>() / cv_scores.len() as f64;
386                sensitivities.push((best_score - mean_score).abs());
387            }
388
389            let sensitivity = sensitivities.iter().sum::<f64>() / sensitivities.len() as f64;
390            parameter_sensitivity.insert(param_name.clone(), sensitivity);
391        }
392
393        let mean_score = all_scores.iter().sum::<f64>() / all_scores.len() as f64;
394        let variance = all_scores
395            .iter()
396            .map(|&x| (x - mean_score).powi(2))
397            .sum::<f64>()
398            / all_scores.len() as f64;
399        let std_score = variance.sqrt();
400
401        Ok(CrossValidationResult {
402            method_name: method.method_name(),
403            cv_scores: all_scores,
404            mean_score,
405            std_score,
406            best_parameters,
407            parameter_sensitivity,
408        })
409    }
410
411    fn compute_approximation_error(
412        &self,
413        approx_kernel: &Array2<f64>,
414        true_kernel: &Array2<f64>,
415    ) -> Result<f64> {
416        // Compute Frobenius norm error
417        let diff = approx_kernel - true_kernel;
418        let frobenius_error = diff.mapv(|x| x * x).sum().sqrt();
419
420        // Normalize by true kernel norm
421        let true_norm = true_kernel.mapv(|x| x * x).sum().sqrt();
422        Ok(frobenius_error / true_norm.max(1e-8))
423    }
424
425    fn compute_rbf_kernel(&self, data: &Array2<f64>, gamma: f64) -> Result<Array2<f64>> {
426        let n_samples = data.nrows();
427        let mut kernel = Array2::zeros((n_samples, n_samples));
428
429        for i in 0..n_samples {
430            for j in i..n_samples {
431                let diff = &data.row(i) - &data.row(j);
432                let dist_sq = diff.mapv(|x| x * x).sum();
433                let similarity = (-gamma * dist_sq).exp();
434                kernel[[i, j]] = similarity;
435                kernel[[j, i]] = similarity;
436            }
437        }
438
439        Ok(kernel)
440    }
441
442    fn compute_theoretical_bound(
443        &self,
444        bound: &TheoreticalBound,
445        n_samples: usize,
446        n_features: usize,
447        n_components: usize,
448    ) -> Result<f64> {
449        let bound_value = match &bound.bound_function {
450            BoundFunction::RandomFourierFeatures => {
451                let kernel_bound = bound.constants.get("kernel_bound").unwrap_or(&1.0);
452                let lipschitz = bound.constants.get("lipschitz_constant").unwrap_or(&1.0);
453
454                // O(sqrt(log(d)/m)) bound for RFF
455                let log_factor = (n_features as f64).ln();
456                kernel_bound * lipschitz * (log_factor / n_components as f64).sqrt()
457            }
458            BoundFunction::Nystroem => {
459                let trace_bound = bound.constants.get("trace_bound").unwrap_or(&1.0);
460                let effective_rank = bound.constants.get("effective_rank").unwrap_or(&100.0);
461
462                // Approximation depends on eigenvalue decay
463                trace_bound * (effective_rank / n_components as f64).sqrt()
464            }
465            BoundFunction::StructuredRandomFeatures => {
466                let log_factor = (n_features as f64).ln();
467                (n_features as f64 * log_factor / n_components as f64).sqrt()
468            }
469            BoundFunction::Fastfood => {
470                let log_factor = bound.constants.get("log_factor").unwrap_or(&2.0);
471                let dim_factor = bound.constants.get("dimension_factor").unwrap_or(&1.0);
472
473                let log_d = (n_features as f64).ln();
474                dim_factor
475                    * (n_features as f64 * log_d.powf(*log_factor) / n_components as f64).sqrt()
476            }
477            BoundFunction::Custom { formula: _ } => {
478                // Placeholder for custom formulas
479                1.0 / (n_components as f64).sqrt()
480            }
481        };
482
483        // Apply bound type modifications
484        let final_bound = match &bound.bound_type {
485            BoundType::Probabilistic { confidence } => {
486                // Add confidence-dependent factor
487                let z_score = self.inverse_normal_cdf(*confidence);
488                bound_value * (1.0 + z_score / (n_samples as f64).sqrt())
489            }
490            BoundType::Deterministic => bound_value,
491            BoundType::Expected => bound_value * 0.8, // Expected is typically tighter
492            BoundType::Concentration {
493                deviation_parameter,
494            } => bound_value * (1.0 + deviation_parameter / (n_samples as f64).sqrt()),
495        };
496
497        Ok(final_bound)
498    }
499
500    fn inverse_normal_cdf(&self, p: f64) -> f64 {
501        // Approximation of inverse normal CDF for confidence intervals
502        if p <= 0.5 {
503            -self.inverse_normal_cdf(1.0 - p)
504        } else {
505            let t = (-2.0 * (1.0 - p).ln()).sqrt();
506            let c0 = 2.515517;
507            let c1 = 0.802853;
508            let c2 = 0.010328;
509            let d1 = 1.432788;
510            let d2 = 0.189269;
511            let d3 = 0.001308;
512
513            t - (c0 + c1 * t + c2 * t * t) / (1.0 + d1 * t + d2 * t * t + d3 * t * t * t)
514        }
515    }
516
517    fn estimate_convergence_rate(&self, errors: &[f64]) -> Option<f64> {
518        if errors.len() < 3 {
519            return None;
520        }
521
522        // Fit log(error) = a + b * log(dimension) to estimate convergence rate
523        let dimensions: Vec<f64> = self
524            .config
525            .approximation_dimensions
526            .iter()
527            .take(errors.len())
528            .map(|&x| (x as f64).ln())
529            .collect();
530
531        let log_errors: Vec<f64> = errors.iter().map(|&x| x.ln()).collect();
532
533        // Simple linear regression
534        let n = dimensions.len() as f64;
535        let sum_x = dimensions.iter().sum::<f64>();
536        let sum_y = log_errors.iter().sum::<f64>();
537        let sum_xy = dimensions
538            .iter()
539            .zip(log_errors.iter())
540            .map(|(&x, &y)| x * y)
541            .sum::<f64>();
542        let sum_x2 = dimensions.iter().map(|&x| x * x).sum::<f64>();
543
544        let slope = (n * sum_xy - sum_x * sum_y) / (n * sum_x2 - sum_x * sum_x);
545        Some(-slope) // Negative because we expect decreasing error
546    }
547
548    fn analyze_stability<T: ValidatableKernelMethod>(
549        &self,
550        method: &T,
551        data: &Array2<f64>,
552        condition_numbers: &[f64],
553    ) -> Result<StabilityAnalysis> {
554        let mut rng = RealStdRng::seed_from_u64(self.config.random_state.unwrap_or(42));
555        let normal = RandNormal::new(0.0, self.config.stability_tolerance).unwrap();
556
557        // Test perturbation sensitivity
558        let mut perturbation_errors = Vec::new();
559
560        for _ in 0..5 {
561            let mut perturbed_data = data.clone();
562            for elem in perturbed_data.iter_mut() {
563                *elem += rng.sample(normal);
564            }
565
566            let original_fitted = method.fit_with_dimension(data, 100)?;
567            let perturbed_fitted = method.fit_with_dimension(&perturbed_data, 100)?;
568
569            let original_approx = original_fitted.get_kernel_approximation(data)?;
570            let perturbed_approx = perturbed_fitted.get_kernel_approximation(data)?;
571
572            let error = self.compute_approximation_error(&perturbed_approx, &original_approx)?;
573            perturbation_errors.push(error);
574        }
575
576        let perturbation_sensitivity =
577            perturbation_errors.iter().sum::<f64>() / perturbation_errors.len() as f64;
578
579        // Numerical stability from condition numbers
580        let numerical_stability = if condition_numbers.is_empty() {
581            1.0
582        } else {
583            let mean_condition =
584                condition_numbers.iter().sum::<f64>() / condition_numbers.len() as f64;
585            1.0 / mean_condition.ln().max(1.0)
586        };
587
588        // Eigenvalue stability (placeholder)
589        let eigenvalue_stability = 1.0 - perturbation_sensitivity;
590
591        Ok(StabilityAnalysis {
592            perturbation_sensitivity,
593            numerical_stability,
594            condition_numbers: condition_numbers.to_vec(),
595            eigenvalue_stability,
596        })
597    }
598
599    fn analyze_sample_complexity<T: ValidatableKernelMethod>(
600        &self,
601        method: &T,
602        data: &Array2<f64>,
603    ) -> Result<SampleComplexityAnalysis> {
604        let mut sample_errors = Vec::new();
605
606        // Test different sample sizes
607        for &n_samples in &self.config.sample_sizes {
608            if n_samples > data.nrows() {
609                continue;
610            }
611
612            let subset_data = data
613                .slice(scirs2_core::ndarray::s![..n_samples, ..])
614                .to_owned();
615            let fitted = method.fit_with_dimension(&subset_data, 100)?;
616            let approx = fitted.get_kernel_approximation(&subset_data)?;
617
618            let rbf_kernel = self.compute_rbf_kernel(&subset_data, 1.0)?;
619            let error = self.compute_approximation_error(&approx, &rbf_kernel)?;
620            sample_errors.push(error);
621        }
622
623        // Estimate minimum required samples
624        let target_error = self.config.max_approximation_error;
625        let minimum_samples = self
626            .config
627            .sample_sizes
628            .iter()
629            .zip(sample_errors.iter())
630            .find(|(_, &error)| error <= target_error)
631            .map(|(&samples, _)| samples)
632            .unwrap_or(*self.config.sample_sizes.last().unwrap());
633
634        // Estimate convergence rate with respect to sample size
635        let convergence_rate = if sample_errors.len() >= 2 {
636            let log_samples: Vec<f64> = self
637                .config
638                .sample_sizes
639                .iter()
640                .take(sample_errors.len())
641                .map(|&x| (x as f64).ln())
642                .collect();
643            let log_errors: Vec<f64> = sample_errors.iter().map(|&x| x.ln()).collect();
644
645            // Linear regression for convergence rate
646            let n = log_samples.len() as f64;
647            let sum_x = log_samples.iter().sum::<f64>();
648            let sum_y = log_errors.iter().sum::<f64>();
649            let sum_xy = log_samples
650                .iter()
651                .zip(log_errors.iter())
652                .map(|(&x, &y)| x * y)
653                .sum::<f64>();
654            let sum_x2 = log_samples.iter().map(|&x| x * x).sum::<f64>();
655
656            -(n * sum_xy - sum_x * sum_y) / (n * sum_x2 - sum_x * sum_x)
657        } else {
658            0.5 // Default assumption
659        };
660
661        let sample_efficiency = 1.0 / minimum_samples as f64;
662        let dimension_scaling = data.ncols() as f64 / minimum_samples as f64;
663
664        Ok(SampleComplexityAnalysis {
665            minimum_samples,
666            convergence_rate,
667            sample_efficiency,
668            dimension_scaling,
669        })
670    }
671
672    fn analyze_dimension_dependency<T: ValidatableKernelMethod>(
673        &self,
674        _method: &T,
675        data: &Array2<f64>,
676        errors: &[f64],
677    ) -> Result<DimensionDependencyAnalysis> {
678        let approximation_quality_vs_dimension: Vec<(usize, f64)> = self
679            .config
680            .approximation_dimensions
681            .iter()
682            .take(errors.len())
683            .zip(errors.iter())
684            .map(|(&dim, &error)| (dim, 1.0 - error)) // Convert error to quality
685            .collect();
686
687        // Estimate computational cost (simplified)
688        let computational_cost_vs_dimension: Vec<(usize, f64)> = self
689            .config
690            .approximation_dimensions
691            .iter()
692            .map(|&dim| (dim, dim as f64 * data.nrows() as f64))
693            .collect();
694
695        // Find optimal dimension (best quality-to-cost ratio)
696        let optimal_dimension = approximation_quality_vs_dimension
697            .iter()
698            .zip(computational_cost_vs_dimension.iter())
699            .map(|((dim, quality), (_, cost))| (*dim, quality / cost))
700            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
701            .map(|(dim, _)| dim)
702            .unwrap_or(100);
703
704        let dimension_efficiency = approximation_quality_vs_dimension
705            .iter()
706            .map(|(_, quality)| quality)
707            .sum::<f64>()
708            / approximation_quality_vs_dimension.len() as f64;
709
710        Ok(DimensionDependencyAnalysis {
711            approximation_quality_vs_dimension,
712            computational_cost_vs_dimension,
713            optimal_dimension,
714            dimension_efficiency,
715        })
716    }
717
718    fn generate_parameter_combinations(
719        &self,
720        parameter_grid: &HashMap<String, Vec<f64>>,
721    ) -> Vec<HashMap<String, f64>> {
722        let mut combinations = vec![HashMap::new()];
723
724        for (param_name, param_values) in parameter_grid {
725            let mut new_combinations = Vec::new();
726
727            for combination in &combinations {
728                for &param_value in param_values {
729                    let mut new_combination = combination.clone();
730                    new_combination.insert(param_name.clone(), param_value);
731                    new_combinations.push(new_combination);
732                }
733            }
734
735            combinations = new_combinations;
736        }
737
738        combinations
739    }
740
741    fn k_fold_cross_validation<T: ValidatableKernelMethod>(
742        &self,
743        method: &T,
744        data: &Array2<f64>,
745        _targets: Option<&Array1<f64>>,
746        parameters: &HashMap<String, f64>,
747        k: usize,
748    ) -> Result<Vec<f64>> {
749        let n_samples = data.nrows();
750        let fold_size = n_samples / k;
751        let mut scores = Vec::new();
752
753        for fold in 0..k {
754            let start_idx = fold * fold_size;
755            let end_idx = if fold == k - 1 {
756                n_samples
757            } else {
758                (fold + 1) * fold_size
759            };
760
761            // Create train and validation sets
762            let train_indices: Vec<usize> = (0..n_samples)
763                .filter(|&i| i < start_idx || i >= end_idx)
764                .collect();
765            let val_indices: Vec<usize> = (start_idx..end_idx).collect();
766
767            let train_data = data.select(Axis(0), &train_indices);
768            let val_data = data.select(Axis(0), &val_indices);
769
770            // Fit with parameters
771            let fitted = method.fit_with_parameters(&train_data, parameters)?;
772            let train_approx = fitted.get_kernel_approximation(&train_data)?;
773            let val_approx = fitted.get_kernel_approximation(&val_data)?;
774
775            // Compute validation score (kernel alignment as proxy)
776            let train_kernel = self.compute_rbf_kernel(&train_data, 1.0)?;
777            let val_kernel = self.compute_rbf_kernel(&val_data, 1.0)?;
778
779            let train_error = self.compute_approximation_error(&train_approx, &train_kernel)?;
780            let val_error = self.compute_approximation_error(&val_approx, &val_kernel)?;
781
782            // Score is negative error (higher is better)
783            let score = -(train_error + val_error) / 2.0;
784            scores.push(score);
785        }
786
787        Ok(scores)
788    }
789}
790
791/// Trait for kernel methods that can be validated
792pub trait ValidatableKernelMethod {
793    /// Get method name
794    fn method_name(&self) -> String;
795
796    /// Fit with specific approximation dimension
797    fn fit_with_dimension(
798        &self,
799        data: &Array2<f64>,
800        n_components: usize,
801    ) -> Result<Box<dyn ValidatedFittedMethod>>;
802
803    /// Fit with specific parameters
804    fn fit_with_parameters(
805        &self,
806        data: &Array2<f64>,
807        parameters: &HashMap<String, f64>,
808    ) -> Result<Box<dyn ValidatedFittedMethod>>;
809}
810
811/// Trait for fitted methods that can be validated
812pub trait ValidatedFittedMethod {
813    /// Get kernel approximation matrix
814    fn get_kernel_approximation(&self, data: &Array2<f64>) -> Result<Array2<f64>>;
815
816    /// Compute condition number if applicable
817    fn compute_condition_number(&self) -> Result<Option<f64>>;
818
819    /// Get approximation dimension
820    fn approximation_dimension(&self) -> usize;
821}
822
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double for `ValidatableKernelMethod`: it ignores the data and every parameter
    /// except `n_components`.
    struct MockValidatableRBF;

    impl ValidatableKernelMethod for MockValidatableRBF {
        fn method_name(&self) -> String {
            "MockRBF".to_string()
        }

        fn fit_with_dimension(
            &self,
            _data: &Array2<f64>,
            n_components: usize,
        ) -> Result<Box<dyn ValidatedFittedMethod>> {
            Ok(Box::new(MockValidatedFitted { n_components }))
        }

        fn fit_with_parameters(
            &self,
            _data: &Array2<f64>,
            parameters: &HashMap<String, f64>,
        ) -> Result<Box<dyn ValidatedFittedMethod>> {
            let n_components = parameters.get("n_components").copied().unwrap_or(100.0) as usize;
            Ok(Box::new(MockValidatedFitted { n_components }))
        }
    }

    /// Fitted test double returning a fixed kernel: 1.0 on the diagonal, 0.5 elsewhere.
    struct MockValidatedFitted {
        n_components: usize,
    }

    impl ValidatedFittedMethod for MockValidatedFitted {
        fn get_kernel_approximation(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
            let n_samples = data.nrows();
            Ok(Array2::from_shape_fn((n_samples, n_samples), |(i, j)| {
                if i == j {
                    1.0
                } else {
                    0.5
                }
            }))
        }

        fn compute_condition_number(&self) -> Result<Option<f64>> {
            Ok(Some(10.0))
        }

        fn approximation_dimension(&self) -> usize {
            self.n_components
        }
    }

    fn default_validator() -> KernelApproximationValidator {
        KernelApproximationValidator::new(ValidationConfig::default())
    }

    #[test]
    fn test_validator_creation() {
        let validator = default_validator();

        assert!(!validator.theoretical_bounds.is_empty());
        assert!(validator.theoretical_bounds.contains_key("RBF"));
    }

    #[test]
    fn test_method_validation() {
        let config = ValidationConfig {
            approximation_dimensions: vec![10, 20],
            repetitions: 2,
            ..Default::default()
        };
        let validator = KernelApproximationValidator::new(config);

        let data = Array2::from_shape_fn((50, 5), |(i, j)| (i + j) as f64 * 0.1);
        let result = validator
            .validate_method(&MockValidatableRBF, &data, None)
            .unwrap();

        assert_eq!(result.method_name, "MockRBF");
        assert_eq!(result.empirical_errors.len(), 2);
        assert_eq!(result.theoretical_bounds.len(), 2);
        // Convergence rate may be None with too few dimensions; when present it must be finite.
        if let Some(rate) = result.convergence_rate {
            assert!(rate.is_finite());
        }
    }

    #[test]
    fn test_cross_validation() {
        let validator = default_validator();

        let data = Array2::from_shape_fn((30, 4), |(i, j)| (i + j) as f64 * 0.1);

        let mut parameter_grid = HashMap::new();
        parameter_grid.insert("gamma".to_string(), vec![0.5, 1.0, 2.0]);
        parameter_grid.insert("n_components".to_string(), vec![10.0, 20.0]);

        let result = validator
            .cross_validate(&MockValidatableRBF, &data, None, parameter_grid)
            .unwrap();

        assert_eq!(result.method_name, "MockRBF");
        assert!(!result.cv_scores.is_empty());
        assert!(!result.best_parameters.is_empty());
    }

    #[test]
    fn test_theoretical_bounds() {
        let validator = default_validator();

        let bound = validator.theoretical_bounds.get("RBF").unwrap();
        let theoretical_bound = validator
            .compute_theoretical_bound(bound, 100, 10, 50)
            .unwrap();

        assert!(theoretical_bound > 0.0);
        assert!(theoretical_bound.is_finite());
    }

    #[test]
    fn test_inverse_normal_cdf_matches_known_quantiles() {
        let validator = default_validator();

        // Abramowitz & Stegun 26.2.23 is accurate to about 4.5e-4.
        assert!((validator.inverse_normal_cdf(0.975) - 1.959_964).abs() < 1e-3);
        assert!((validator.inverse_normal_cdf(0.025) + 1.959_964).abs() < 1e-3);
        assert!((validator.inverse_normal_cdf(0.95) - 1.644_854).abs() < 1e-3);
    }

    #[test]
    fn test_rbf_kernel_is_symmetric_with_unit_diagonal() {
        let validator = default_validator();
        let data = Array2::from_shape_fn((4, 3), |(i, j)| (i * 3 + j) as f64 * 0.25);

        let kernel = validator.compute_rbf_kernel(&data, 0.5).unwrap();

        assert_eq!(kernel.dim(), (4, 4));
        for i in 0..4 {
            assert!((kernel[[i, i]] - 1.0).abs() < 1e-12);
            for j in 0..4 {
                assert_eq!(kernel[[i, j]], kernel[[j, i]]);
            }
        }
        assert_eq!(
            validator
                .compute_approximation_error(&kernel, &kernel)
                .unwrap(),
            0.0
        );
    }

    #[test]
    fn test_parameter_grid_expands_to_cartesian_product() {
        let validator = default_validator();

        let mut grid = HashMap::new();
        grid.insert("gamma".to_string(), vec![0.5, 1.0, 2.0]);
        grid.insert("n_components".to_string(), vec![10.0, 20.0]);

        let combinations = validator.generate_parameter_combinations(&grid);

        assert_eq!(combinations.len(), 6);
        assert!(combinations.iter().all(|c| c.len() == 2));
        for &gamma in &[0.5, 1.0, 2.0] {
            for &n in &[10.0, 20.0] {
                assert!(combinations
                    .iter()
                    .any(|c| c["gamma"] == gamma && c["n_components"] == n));
            }
        }
    }
}