1use scirs2_core::ndarray::{Array1, Array2, Axis};
7use scirs2_core::random::essentials::Normal as RandNormal;
8use scirs2_core::random::rngs::StdRng as RealStdRng;
9use scirs2_core::random::Distribution;
10use scirs2_core::random::Rng;
11use scirs2_core::random::SeedableRng;
12use sklears_core::error::Result;
13use std::collections::HashMap;
14
/// Validates kernel approximation methods against theoretical error bounds.
///
/// Holds the validation configuration plus a registry of per-method
/// theoretical bounds (keyed by method name) used when comparing empirical
/// approximation errors with theory.
#[derive(Debug, Clone)]
pub struct KernelApproximationValidator {
    // Settings controlling tolerances, sample sizes, repetitions, seeding.
    config: ValidationConfig,
    // Theoretical error bounds registered per method name.
    theoretical_bounds: HashMap<String, TheoreticalBound>,
}
22
/// Configuration for kernel approximation validation runs.
#[derive(Debug, Clone)]
pub struct ValidationConfig {
    /// Confidence level used for probabilistic bounds (e.g. 0.95).
    pub confidence_level: f64,
    /// Target maximum acceptable relative approximation error.
    pub max_approximation_error: f64,
    /// Tolerance used when checking convergence.
    pub convergence_tolerance: f64,
    /// Standard deviation of the Gaussian perturbations used in
    /// stability analysis.
    pub stability_tolerance: f64,
    /// Sample sizes evaluated in sample-complexity analysis.
    pub sample_sizes: Vec<usize>,
    /// Approximation dimensions (numbers of components) to evaluate.
    pub approximation_dimensions: Vec<usize>,
    /// Repetitions per configuration, averaged to smooth out randomness.
    pub repetitions: usize,
    /// Optional RNG seed for reproducible perturbations.
    pub random_state: Option<u64>,
}
44
45impl Default for ValidationConfig {
46 fn default() -> Self {
47 Self {
48 confidence_level: 0.95,
49 max_approximation_error: 0.1,
50 convergence_tolerance: 1e-6,
51 stability_tolerance: 1e-4,
52 sample_sizes: vec![100, 500, 1000, 2000],
53 approximation_dimensions: vec![50, 100, 200, 500],
54 repetitions: 10,
55 random_state: Some(42),
56 }
57 }
58}
59
/// A theoretical approximation-error bound registered for a named method.
#[derive(Debug, Clone)]
pub struct TheoreticalBound {
    /// Name of the approximation method this bound applies to (registry key).
    pub method_name: String,
    /// Statistical flavour of the guarantee (probabilistic, expected, ...).
    pub bound_type: BoundType,
    /// Which closed-form bound formula to evaluate.
    pub bound_function: BoundFunction,
    /// Named constants consumed by the bound formula
    /// (e.g. "kernel_bound", "effective_rank").
    pub constants: HashMap<String, f64>,
}
73
/// Statistical nature of a theoretical bound.
#[derive(Debug, Clone)]
pub enum BoundType {
    /// Holds with at least the given probability (e.g. confidence = 0.95).
    Probabilistic { confidence: f64 },
    /// Always holds (worst-case guarantee).
    Deterministic,
    /// Holds in expectation over the method's randomness.
    Expected,
    /// Concentration-style bound parameterized by a deviation term.
    Concentration { deviation_parameter: f64 },
}
87
/// Closed-form bound formula associated with an approximation scheme.
///
/// The formulas themselves are evaluated in
/// `KernelApproximationValidator::compute_theoretical_bound`.
#[derive(Debug, Clone)]
pub enum BoundFunction {
    /// Random Fourier features: error ~ sqrt(log d / n_components).
    RandomFourierFeatures,
    /// Nystroem: error ~ sqrt(effective_rank / n_components).
    Nystroem,
    /// Structured random projections: error ~ sqrt(d log d / n_components).
    StructuredRandomFeatures,
    /// Fastfood transform: error ~ sqrt(d log^c d / n_components).
    Fastfood,
    /// User-supplied formula. NOTE: the string is currently not parsed;
    /// evaluation falls back to a generic 1/sqrt(n_components) rate.
    Custom { formula: String },
}
103
/// Full outcome of validating one kernel approximation method.
#[derive(Debug, Clone)]
pub struct ValidationResult {
    /// Name of the validated method.
    pub method_name: String,
    /// Mean relative approximation error per tested dimension.
    pub empirical_errors: Vec<f64>,
    /// Theoretical bound per tested dimension (infinity when no bound is
    /// registered for the method).
    pub theoretical_bounds: Vec<f64>,
    /// Number of dimensions where the empirical error exceeded its bound.
    pub bound_violations: usize,
    /// Mean ratio of empirical error to theoretical bound
    /// (smaller = looser bound).
    pub bound_tightness: f64,
    /// Fitted rate of error decay with dimension, when estimable.
    pub convergence_rate: Option<f64>,
    /// Perturbation / conditioning analysis.
    pub stability_analysis: StabilityAnalysis,
    /// Error scaling with training-set size.
    pub sample_complexity: SampleComplexityAnalysis,
    /// Quality and cost as functions of approximation dimension.
    pub dimension_dependency: DimensionDependencyAnalysis,
}
127
/// Stability metrics for a fitted approximation.
#[derive(Debug, Clone)]
pub struct StabilityAnalysis {
    /// Mean relative change of the approximation under small input noise.
    pub perturbation_sensitivity: f64,
    /// Conditioning-based score; 1.0 when no condition numbers were observed.
    pub numerical_stability: f64,
    /// Condition numbers collected across fits.
    pub condition_numbers: Vec<f64>,
    /// Computed as 1 - perturbation_sensitivity (eigenstructure proxy).
    pub eigenvalue_stability: f64,
}
141
/// How approximation error scales with the number of training samples.
#[derive(Debug, Clone)]
pub struct SampleComplexityAnalysis {
    /// Smallest tested sample count meeting the configured target error.
    pub minimum_samples: usize,
    /// Fitted rate of error decay with sample count (log-log slope).
    pub convergence_rate: f64,
    /// 1 / minimum_samples: higher means fewer samples are needed.
    pub sample_efficiency: f64,
    /// Ratio of feature count to minimum_samples.
    pub dimension_scaling: f64,
}
155
/// Approximation quality and cost as functions of the approximation dimension.
#[derive(Debug, Clone)]
pub struct DimensionDependencyAnalysis {
    /// (dimension, 1 - relative error) pairs.
    pub approximation_quality_vs_dimension: Vec<(usize, f64)>,
    /// (dimension, estimated cost) pairs; cost modeled as dim * n_samples.
    pub computational_cost_vs_dimension: Vec<(usize, f64)>,
    /// Dimension with the best quality-per-cost ratio.
    pub optimal_dimension: usize,
    /// Mean quality across all tested dimensions.
    pub dimension_efficiency: f64,
}
169
/// Outcome of grid-search cross-validation for one method.
#[derive(Debug, Clone)]
pub struct CrossValidationResult {
    /// Name of the evaluated method.
    pub method_name: String,
    /// Mean CV score for every parameter combination tried.
    pub cv_scores: Vec<f64>,
    /// Mean of `cv_scores`.
    pub mean_score: f64,
    /// Standard deviation of `cv_scores` (population form).
    pub std_score: f64,
    /// Parameter assignment achieving the best mean score.
    pub best_parameters: HashMap<String, f64>,
    /// Mean absolute score change when varying each parameter alone.
    pub parameter_sensitivity: HashMap<String, f64>,
}
187
188impl KernelApproximationValidator {
189 pub fn new(config: ValidationConfig) -> Self {
191 let mut validator = Self {
192 config,
193 theoretical_bounds: HashMap::new(),
194 };
195
196 validator.add_default_bounds();
198 validator
199 }
200
201 pub fn add_theoretical_bound(&mut self, bound: TheoreticalBound) {
203 self.theoretical_bounds
204 .insert(bound.method_name.clone(), bound);
205 }
206
207 fn add_default_bounds(&mut self) {
208 self.add_theoretical_bound(TheoreticalBound {
210 method_name: "RBF".to_string(),
211 bound_type: BoundType::Probabilistic { confidence: 0.95 },
212 bound_function: BoundFunction::RandomFourierFeatures,
213 constants: [
214 ("kernel_bound".to_string(), 1.0),
215 ("lipschitz_constant".to_string(), 1.0),
216 ]
217 .iter()
218 .cloned()
219 .collect(),
220 });
221
222 self.add_theoretical_bound(TheoreticalBound {
224 method_name: "Nystroem".to_string(),
225 bound_type: BoundType::Expected,
226 bound_function: BoundFunction::Nystroem,
227 constants: [
228 ("trace_bound".to_string(), 1.0),
229 ("effective_rank".to_string(), 100.0),
230 ]
231 .iter()
232 .cloned()
233 .collect(),
234 });
235
236 self.add_theoretical_bound(TheoreticalBound {
238 method_name: "Fastfood".to_string(),
239 bound_type: BoundType::Probabilistic { confidence: 0.95 },
240 bound_function: BoundFunction::Fastfood,
241 constants: [
242 ("dimension_factor".to_string(), 1.0),
243 ("log_factor".to_string(), 2.0),
244 ]
245 .iter()
246 .cloned()
247 .collect(),
248 });
249 }
250
251 pub fn validate_method<T: ValidatableKernelMethod>(
253 &self,
254 method: &T,
255 data: &Array2<f64>,
256 true_kernel: Option<&Array2<f64>>,
257 ) -> Result<ValidationResult> {
258 let method_name = method.method_name();
259 let mut empirical_errors = Vec::new();
260 let mut theoretical_bounds = Vec::new();
261 let mut condition_numbers = Vec::new();
262
263 for &n_components in &self.config.approximation_dimensions {
265 let mut dimension_errors = Vec::new();
266
267 for _ in 0..self.config.repetitions {
268 let fitted = method.fit_with_dimension(data, n_components)?;
270 let approximation = fitted.get_kernel_approximation(data)?;
271
272 let empirical_error = if let Some(true_k) = true_kernel {
274 self.compute_approximation_error(&approximation, true_k)?
275 } else {
276 let rbf_kernel = self.compute_rbf_kernel(data, 1.0)?;
278 self.compute_approximation_error(&approximation, &rbf_kernel)?
279 };
280
281 dimension_errors.push(empirical_error);
282
283 if let Some(cond_num) = fitted.compute_condition_number()? {
285 condition_numbers.push(cond_num);
286 }
287 }
288
289 let mean_error = dimension_errors.iter().sum::<f64>() / dimension_errors.len() as f64;
290 empirical_errors.push(mean_error);
291
292 if let Some(bound) = self.theoretical_bounds.get(&method_name) {
294 let theoretical_bound = self.compute_theoretical_bound(
295 bound,
296 data.nrows(),
297 data.ncols(),
298 n_components,
299 )?;
300 theoretical_bounds.push(theoretical_bound);
301 } else {
302 theoretical_bounds.push(f64::INFINITY);
303 }
304 }
305
306 let bound_violations = empirical_errors
308 .iter()
309 .zip(theoretical_bounds.iter())
310 .filter(|(&emp, &theo)| emp > theo)
311 .count();
312
313 let bound_tightness = empirical_errors
315 .iter()
316 .zip(theoretical_bounds.iter())
317 .filter(|(_, &theo)| theo.is_finite())
318 .map(|(&emp, &theo)| emp / theo)
319 .sum::<f64>()
320 / empirical_errors.len() as f64;
321
322 let convergence_rate = self.estimate_convergence_rate(&empirical_errors);
324
325 let stability_analysis = self.analyze_stability(method, data, &condition_numbers)?;
327
328 let sample_complexity = self.analyze_sample_complexity(method, data)?;
330
331 let dimension_dependency =
333 self.analyze_dimension_dependency(method, data, &empirical_errors)?;
334
335 Ok(ValidationResult {
336 method_name,
337 empirical_errors,
338 theoretical_bounds,
339 bound_violations,
340 bound_tightness,
341 convergence_rate,
342 stability_analysis,
343 sample_complexity,
344 dimension_dependency,
345 })
346 }
347
348 pub fn cross_validate<T: ValidatableKernelMethod>(
350 &self,
351 method: &T,
352 data: &Array2<f64>,
353 targets: Option<&Array1<f64>>,
354 parameter_grid: HashMap<String, Vec<f64>>,
355 ) -> Result<CrossValidationResult> {
356 let mut best_score = f64::NEG_INFINITY;
357 let mut best_parameters = HashMap::new();
358 let mut all_scores = Vec::new();
359 let mut parameter_sensitivity = HashMap::new();
360
361 let param_combinations = self.generate_parameter_combinations(¶meter_grid);
363
364 for params in param_combinations {
365 let cv_scores = self.k_fold_cross_validation(method, data, targets, ¶ms, 5)?;
366 let mean_score = cv_scores.iter().sum::<f64>() / cv_scores.len() as f64;
367
368 all_scores.push(mean_score);
369
370 if mean_score > best_score {
371 best_score = mean_score;
372 best_parameters = params.clone();
373 }
374 }
375
376 for (param_name, param_values) in ¶meter_grid {
378 let mut sensitivities = Vec::new();
379
380 for (i, ¶m_value) in param_values.iter().enumerate() {
381 let mut single_param = best_parameters.clone();
382 single_param.insert(param_name.clone(), param_value);
383
384 let cv_scores =
385 self.k_fold_cross_validation(method, data, targets, &single_param, 3)?;
386 let mean_score = cv_scores.iter().sum::<f64>() / cv_scores.len() as f64;
387 sensitivities.push((best_score - mean_score).abs());
388 }
389
390 let sensitivity = sensitivities.iter().sum::<f64>() / sensitivities.len() as f64;
391 parameter_sensitivity.insert(param_name.clone(), sensitivity);
392 }
393
394 let mean_score = all_scores.iter().sum::<f64>() / all_scores.len() as f64;
395 let variance = all_scores
396 .iter()
397 .map(|&x| (x - mean_score).powi(2))
398 .sum::<f64>()
399 / all_scores.len() as f64;
400 let std_score = variance.sqrt();
401
402 Ok(CrossValidationResult {
403 method_name: method.method_name(),
404 cv_scores: all_scores,
405 mean_score,
406 std_score,
407 best_parameters,
408 parameter_sensitivity,
409 })
410 }
411
412 fn compute_approximation_error(
413 &self,
414 approx_kernel: &Array2<f64>,
415 true_kernel: &Array2<f64>,
416 ) -> Result<f64> {
417 let diff = approx_kernel - true_kernel;
419 let frobenius_error = diff.mapv(|x| x * x).sum().sqrt();
420
421 let true_norm = true_kernel.mapv(|x| x * x).sum().sqrt();
423 Ok(frobenius_error / true_norm.max(1e-8))
424 }
425
426 fn compute_rbf_kernel(&self, data: &Array2<f64>, gamma: f64) -> Result<Array2<f64>> {
427 let n_samples = data.nrows();
428 let mut kernel = Array2::zeros((n_samples, n_samples));
429
430 for i in 0..n_samples {
431 for j in i..n_samples {
432 let diff = &data.row(i) - &data.row(j);
433 let dist_sq = diff.mapv(|x| x * x).sum();
434 let similarity = (-gamma * dist_sq).exp();
435 kernel[[i, j]] = similarity;
436 kernel[[j, i]] = similarity;
437 }
438 }
439
440 Ok(kernel)
441 }
442
443 fn compute_theoretical_bound(
444 &self,
445 bound: &TheoreticalBound,
446 n_samples: usize,
447 n_features: usize,
448 n_components: usize,
449 ) -> Result<f64> {
450 let bound_value = match &bound.bound_function {
451 BoundFunction::RandomFourierFeatures => {
452 let kernel_bound = bound.constants.get("kernel_bound").unwrap_or(&1.0);
453 let lipschitz = bound.constants.get("lipschitz_constant").unwrap_or(&1.0);
454
455 let log_factor = (n_features as f64).ln();
457 kernel_bound * lipschitz * (log_factor / n_components as f64).sqrt()
458 }
459 BoundFunction::Nystroem => {
460 let trace_bound = bound.constants.get("trace_bound").unwrap_or(&1.0);
461 let effective_rank = bound.constants.get("effective_rank").unwrap_or(&100.0);
462
463 trace_bound * (effective_rank / n_components as f64).sqrt()
465 }
466 BoundFunction::StructuredRandomFeatures => {
467 let log_factor = (n_features as f64).ln();
468 (n_features as f64 * log_factor / n_components as f64).sqrt()
469 }
470 BoundFunction::Fastfood => {
471 let log_factor = bound.constants.get("log_factor").unwrap_or(&2.0);
472 let dim_factor = bound.constants.get("dimension_factor").unwrap_or(&1.0);
473
474 let log_d = (n_features as f64).ln();
475 dim_factor
476 * (n_features as f64 * log_d.powf(*log_factor) / n_components as f64).sqrt()
477 }
478 BoundFunction::Custom { formula: _ } => {
479 1.0 / (n_components as f64).sqrt()
481 }
482 };
483
484 let final_bound = match &bound.bound_type {
486 BoundType::Probabilistic { confidence } => {
487 let z_score = self.inverse_normal_cdf(*confidence);
489 bound_value * (1.0 + z_score / (n_samples as f64).sqrt())
490 }
491 BoundType::Deterministic => bound_value,
492 BoundType::Expected => bound_value * 0.8, BoundType::Concentration {
494 deviation_parameter,
495 } => bound_value * (1.0 + deviation_parameter / (n_samples as f64).sqrt()),
496 };
497
498 Ok(final_bound)
499 }
500
501 fn inverse_normal_cdf(&self, p: f64) -> f64 {
502 if p <= 0.5 {
504 -self.inverse_normal_cdf(1.0 - p)
505 } else {
506 let t = (-2.0 * (1.0 - p).ln()).sqrt();
507 let c0 = 2.515517;
508 let c1 = 0.802853;
509 let c2 = 0.010328;
510 let d1 = 1.432788;
511 let d2 = 0.189269;
512 let d3 = 0.001308;
513
514 t - (c0 + c1 * t + c2 * t * t) / (1.0 + d1 * t + d2 * t * t + d3 * t * t * t)
515 }
516 }
517
518 fn estimate_convergence_rate(&self, errors: &[f64]) -> Option<f64> {
519 if errors.len() < 3 {
520 return None;
521 }
522
523 let dimensions: Vec<f64> = self
525 .config
526 .approximation_dimensions
527 .iter()
528 .take(errors.len())
529 .map(|&x| (x as f64).ln())
530 .collect();
531
532 let log_errors: Vec<f64> = errors.iter().map(|&x| x.ln()).collect();
533
534 let n = dimensions.len() as f64;
536 let sum_x = dimensions.iter().sum::<f64>();
537 let sum_y = log_errors.iter().sum::<f64>();
538 let sum_xy = dimensions
539 .iter()
540 .zip(log_errors.iter())
541 .map(|(&x, &y)| x * y)
542 .sum::<f64>();
543 let sum_x2 = dimensions.iter().map(|&x| x * x).sum::<f64>();
544
545 let slope = (n * sum_xy - sum_x * sum_y) / (n * sum_x2 - sum_x * sum_x);
546 Some(-slope) }
548
549 fn analyze_stability<T: ValidatableKernelMethod>(
550 &self,
551 method: &T,
552 data: &Array2<f64>,
553 condition_numbers: &[f64],
554 ) -> Result<StabilityAnalysis> {
555 let mut rng = RealStdRng::seed_from_u64(self.config.random_state.unwrap_or(42));
556 let normal = RandNormal::new(0.0, self.config.stability_tolerance).unwrap();
557
558 let mut perturbation_errors = Vec::new();
560
561 for _ in 0..5 {
562 let mut perturbed_data = data.clone();
563 for elem in perturbed_data.iter_mut() {
564 *elem += rng.sample(normal);
565 }
566
567 let original_fitted = method.fit_with_dimension(data, 100)?;
568 let perturbed_fitted = method.fit_with_dimension(&perturbed_data, 100)?;
569
570 let original_approx = original_fitted.get_kernel_approximation(data)?;
571 let perturbed_approx = perturbed_fitted.get_kernel_approximation(data)?;
572
573 let error = self.compute_approximation_error(&perturbed_approx, &original_approx)?;
574 perturbation_errors.push(error);
575 }
576
577 let perturbation_sensitivity =
578 perturbation_errors.iter().sum::<f64>() / perturbation_errors.len() as f64;
579
580 let numerical_stability = if condition_numbers.is_empty() {
582 1.0
583 } else {
584 let mean_condition =
585 condition_numbers.iter().sum::<f64>() / condition_numbers.len() as f64;
586 1.0 / mean_condition.ln().max(1.0)
587 };
588
589 let eigenvalue_stability = 1.0 - perturbation_sensitivity;
591
592 Ok(StabilityAnalysis {
593 perturbation_sensitivity,
594 numerical_stability,
595 condition_numbers: condition_numbers.to_vec(),
596 eigenvalue_stability,
597 })
598 }
599
600 fn analyze_sample_complexity<T: ValidatableKernelMethod>(
601 &self,
602 method: &T,
603 data: &Array2<f64>,
604 ) -> Result<SampleComplexityAnalysis> {
605 let mut sample_errors = Vec::new();
606
607 for &n_samples in &self.config.sample_sizes {
609 if n_samples > data.nrows() {
610 continue;
611 }
612
613 let subset_data = data
614 .slice(scirs2_core::ndarray::s![..n_samples, ..])
615 .to_owned();
616 let fitted = method.fit_with_dimension(&subset_data, 100)?;
617 let approx = fitted.get_kernel_approximation(&subset_data)?;
618
619 let rbf_kernel = self.compute_rbf_kernel(&subset_data, 1.0)?;
620 let error = self.compute_approximation_error(&approx, &rbf_kernel)?;
621 sample_errors.push(error);
622 }
623
624 let target_error = self.config.max_approximation_error;
626 let minimum_samples = self
627 .config
628 .sample_sizes
629 .iter()
630 .zip(sample_errors.iter())
631 .find(|(_, &error)| error <= target_error)
632 .map(|(&samples, _)| samples)
633 .unwrap_or(*self.config.sample_sizes.last().unwrap());
634
635 let convergence_rate = if sample_errors.len() >= 2 {
637 let log_samples: Vec<f64> = self
638 .config
639 .sample_sizes
640 .iter()
641 .take(sample_errors.len())
642 .map(|&x| (x as f64).ln())
643 .collect();
644 let log_errors: Vec<f64> = sample_errors.iter().map(|&x| x.ln()).collect();
645
646 let n = log_samples.len() as f64;
648 let sum_x = log_samples.iter().sum::<f64>();
649 let sum_y = log_errors.iter().sum::<f64>();
650 let sum_xy = log_samples
651 .iter()
652 .zip(log_errors.iter())
653 .map(|(&x, &y)| x * y)
654 .sum::<f64>();
655 let sum_x2 = log_samples.iter().map(|&x| x * x).sum::<f64>();
656
657 -(n * sum_xy - sum_x * sum_y) / (n * sum_x2 - sum_x * sum_x)
658 } else {
659 0.5 };
661
662 let sample_efficiency = 1.0 / minimum_samples as f64;
663 let dimension_scaling = data.ncols() as f64 / minimum_samples as f64;
664
665 Ok(SampleComplexityAnalysis {
666 minimum_samples,
667 convergence_rate,
668 sample_efficiency,
669 dimension_scaling,
670 })
671 }
672
673 fn analyze_dimension_dependency<T: ValidatableKernelMethod>(
674 &self,
675 method: &T,
676 data: &Array2<f64>,
677 errors: &[f64],
678 ) -> Result<DimensionDependencyAnalysis> {
679 let approximation_quality_vs_dimension: Vec<(usize, f64)> = self
680 .config
681 .approximation_dimensions
682 .iter()
683 .take(errors.len())
684 .zip(errors.iter())
685 .map(|(&dim, &error)| (dim, 1.0 - error)) .collect();
687
688 let computational_cost_vs_dimension: Vec<(usize, f64)> = self
690 .config
691 .approximation_dimensions
692 .iter()
693 .map(|&dim| (dim, dim as f64 * data.nrows() as f64))
694 .collect();
695
696 let optimal_dimension = approximation_quality_vs_dimension
698 .iter()
699 .zip(computational_cost_vs_dimension.iter())
700 .map(|((dim, quality), (_, cost))| (*dim, quality / cost))
701 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
702 .map(|(dim, _)| dim)
703 .unwrap_or(100);
704
705 let dimension_efficiency = approximation_quality_vs_dimension
706 .iter()
707 .map(|(_, quality)| quality)
708 .sum::<f64>()
709 / approximation_quality_vs_dimension.len() as f64;
710
711 Ok(DimensionDependencyAnalysis {
712 approximation_quality_vs_dimension,
713 computational_cost_vs_dimension,
714 optimal_dimension,
715 dimension_efficiency,
716 })
717 }
718
719 fn generate_parameter_combinations(
720 &self,
721 parameter_grid: &HashMap<String, Vec<f64>>,
722 ) -> Vec<HashMap<String, f64>> {
723 let mut combinations = vec![HashMap::new()];
724
725 for (param_name, param_values) in parameter_grid {
726 let mut new_combinations = Vec::new();
727
728 for combination in &combinations {
729 for ¶m_value in param_values {
730 let mut new_combination = combination.clone();
731 new_combination.insert(param_name.clone(), param_value);
732 new_combinations.push(new_combination);
733 }
734 }
735
736 combinations = new_combinations;
737 }
738
739 combinations
740 }
741
742 fn k_fold_cross_validation<T: ValidatableKernelMethod>(
743 &self,
744 method: &T,
745 data: &Array2<f64>,
746 _targets: Option<&Array1<f64>>,
747 parameters: &HashMap<String, f64>,
748 k: usize,
749 ) -> Result<Vec<f64>> {
750 let n_samples = data.nrows();
751 let fold_size = n_samples / k;
752 let mut scores = Vec::new();
753
754 for fold in 0..k {
755 let start_idx = fold * fold_size;
756 let end_idx = if fold == k - 1 {
757 n_samples
758 } else {
759 (fold + 1) * fold_size
760 };
761
762 let train_indices: Vec<usize> = (0..n_samples)
764 .filter(|&i| i < start_idx || i >= end_idx)
765 .collect();
766 let val_indices: Vec<usize> = (start_idx..end_idx).collect();
767
768 let train_data = data.select(Axis(0), &train_indices);
769 let val_data = data.select(Axis(0), &val_indices);
770
771 let fitted = method.fit_with_parameters(&train_data, parameters)?;
773 let train_approx = fitted.get_kernel_approximation(&train_data)?;
774 let val_approx = fitted.get_kernel_approximation(&val_data)?;
775
776 let train_kernel = self.compute_rbf_kernel(&train_data, 1.0)?;
778 let val_kernel = self.compute_rbf_kernel(&val_data, 1.0)?;
779
780 let train_error = self.compute_approximation_error(&train_approx, &train_kernel)?;
781 let val_error = self.compute_approximation_error(&val_approx, &val_kernel)?;
782
783 let score = -(train_error + val_error) / 2.0;
785 scores.push(score);
786 }
787
788 Ok(scores)
789 }
790}
791
/// A kernel approximation method that can be fitted for validation.
pub trait ValidatableKernelMethod {
    /// Human-readable method name; also the key into the validator's
    /// theoretical-bound registry.
    fn method_name(&self) -> String;

    /// Fits the method on `data` with a fixed number of components.
    fn fit_with_dimension(
        &self,
        data: &Array2<f64>,
        n_components: usize,
    ) -> Result<Box<dyn ValidatedFittedMethod>>;

    /// Fits the method on `data` from a named-parameter map
    /// (e.g. one combination from a cross-validation grid).
    fn fit_with_parameters(
        &self,
        data: &Array2<f64>,
        parameters: &HashMap<String, f64>,
    ) -> Result<Box<dyn ValidatedFittedMethod>>;
}
811
/// A fitted approximation that the validator can interrogate.
pub trait ValidatedFittedMethod {
    /// Approximate kernel matrix for `data` (n_samples x n_samples).
    fn get_kernel_approximation(&self, data: &Array2<f64>) -> Result<Array2<f64>>;

    /// Condition number of the fitted representation, when available.
    fn compute_condition_number(&self) -> Result<Option<f64>>;

    /// Number of components used by this fit.
    fn approximation_dimension(&self) -> usize;
}
823
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal mock method: ignores the input data and produces a fixed
    /// similarity kernel, so validator behavior can be tested deterministically.
    struct MockValidatableRBF {
        // Mirrors a real RBF parameter; the mock never reads it.
        gamma: f64,
    }

    impl ValidatableKernelMethod for MockValidatableRBF {
        fn method_name(&self) -> String {
            "MockRBF".to_string()
        }

        fn fit_with_dimension(
            &self,
            data: &Array2<f64>,
            n_components: usize,
        ) -> Result<Box<dyn ValidatedFittedMethod>> {
            // Data is ignored; only the requested dimension is recorded.
            Ok(Box::new(MockValidatedFitted { n_components }))
        }

        fn fit_with_parameters(
            &self,
            data: &Array2<f64>,
            parameters: &HashMap<String, f64>,
        ) -> Result<Box<dyn ValidatedFittedMethod>> {
            // Reads only "n_components"; defaults to 100 when absent.
            let n_components = parameters.get("n_components").copied().unwrap_or(100.0) as usize;
            Ok(Box::new(MockValidatedFitted { n_components }))
        }
    }

    /// Mock fitted state backing `MockValidatableRBF`.
    struct MockValidatedFitted {
        n_components: usize,
    }

    impl ValidatedFittedMethod for MockValidatedFitted {
        fn get_kernel_approximation(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
            let n_samples = data.nrows();
            let mut kernel = Array2::zeros((n_samples, n_samples));

            // Unit diagonal, constant 0.5 off-diagonal, symmetric.
            for i in 0..n_samples {
                kernel[[i, i]] = 1.0;
                for j in i + 1..n_samples {
                    let similarity = 0.5; kernel[[i, j]] = similarity;
                    kernel[[j, i]] = similarity;
                }
            }

            Ok(kernel)
        }

        fn compute_condition_number(&self) -> Result<Option<f64>> {
            // Fixed, well-conditioned value for deterministic tests.
            Ok(Some(10.0))
        }

        fn approximation_dimension(&self) -> usize {
            self.n_components
        }
    }

    #[test]
    fn test_validator_creation() {
        let config = ValidationConfig::default();
        let validator = KernelApproximationValidator::new(config);

        // Default bounds are registered on construction.
        assert!(!validator.theoretical_bounds.is_empty());
        assert!(validator.theoretical_bounds.contains_key("RBF"));
    }

    #[test]
    fn test_method_validation() {
        // Small grid and few repetitions keep the test fast.
        let config = ValidationConfig {
            approximation_dimensions: vec![10, 20],
            repetitions: 2,
            ..Default::default()
        };
        let validator = KernelApproximationValidator::new(config);

        let data = Array2::from_shape_fn((50, 5), |(i, j)| (i + j) as f64 * 0.1);
        let method = MockValidatableRBF { gamma: 1.0 };

        let result = validator.validate_method(&method, &data, None).unwrap();

        // One error/bound entry per tested dimension.
        assert_eq!(result.method_name, "MockRBF");
        assert_eq!(result.empirical_errors.len(), 2);
        assert_eq!(result.theoretical_bounds.len(), 2);
        // The rate may be absent for degenerate error sequences; when present
        // it must at least be finite.
        if let Some(rate) = result.convergence_rate {
            assert!(rate.is_finite());
        }
    }

    #[test]
    fn test_cross_validation() {
        let config = ValidationConfig::default();
        let validator = KernelApproximationValidator::new(config);

        let data = Array2::from_shape_fn((30, 4), |(i, j)| (i + j) as f64 * 0.1);
        let method = MockValidatableRBF { gamma: 1.0 };

        // Two parameters -> six combinations in the grid search.
        let mut parameter_grid = HashMap::new();
        parameter_grid.insert("gamma".to_string(), vec![0.5, 1.0, 2.0]);
        parameter_grid.insert("n_components".to_string(), vec![10.0, 20.0]);

        let result = validator
            .cross_validate(&method, &data, None, parameter_grid)
            .unwrap();

        assert_eq!(result.method_name, "MockRBF");
        assert!(!result.cv_scores.is_empty());
        assert!(!result.best_parameters.is_empty());
    }

    #[test]
    fn test_theoretical_bounds() {
        let config = ValidationConfig::default();
        let validator = KernelApproximationValidator::new(config);

        // The RBF bound for 100 samples, 10 features, 50 components must be
        // a positive finite value.
        let bound = validator.theoretical_bounds.get("RBF").unwrap();
        let theoretical_bound = validator
            .compute_theoretical_bound(bound, 100, 10, 50)
            .unwrap();

        assert!(theoretical_bound > 0.0);
        assert!(theoretical_bound.is_finite());
    }
}