use crate::advanced::rbf::{RBFInterpolator, RBFKernel};
use crate::bspline::BSpline;
use crate::error::{InterpolateError, InterpolateResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive, ToPrimitive};
use std::collections::HashMap;
use std::fmt::{Debug, Display, LowerExp};
use std::ops::{AddAssign, DivAssign, MulAssign, RemAssign, SubAssign};
/// Metric used to score predictions against held-out ground truth.
/// All variants except `RSquared` are error measures (lower is better);
/// `RSquared` is a goodness-of-fit score (higher is better, 1.0 is perfect).
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ValidationMetric {
/// Mean of squared residuals.
MeanSquaredError,
/// Mean of absolute residuals.
MeanAbsoluteError,
/// Square root of the mean squared error (same units as the data).
RootMeanSquaredError,
/// Coefficient of determination (1 - SS_res / SS_tot).
RSquared,
/// Mean of |(y_true - y_pred) / y_true| * 100, skipping zero targets.
MeanAbsolutePercentageError,
/// Largest absolute residual over all samples.
MaxAbsoluteError,
}
/// Strategy for splitting the samples into train/test folds.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CrossValidationStrategy {
/// Classic k-fold: the data is split into `k` contiguous folds,
/// each serving once as the test set.
KFold(usize),
/// One fold per sample: each sample is the test set exactly once.
LeaveOneOut,
/// Repeated random splits: `n_splits` independent splits, each holding
/// out `test_fraction` of the samples.
MonteCarlo { n_splits: usize, test_fraction: f64 },
/// Forward-chaining splits for ordered data: training always precedes
/// the test window. NOTE(review): `gap` is currently ignored by
/// fold generation — confirm whether a train/test gap is intended.
TimeSeries { n_splits: usize, gap: usize },
}
/// Configuration knobs for parameter-optimization runs.
#[derive(Debug, Clone)]
pub struct OptimizationConfig<T> {
/// Maximum number of optimization iterations.
pub max_iterations: usize,
/// Convergence tolerance for iterative optimizers.
pub tolerance: T,
/// Seed for any randomized components (kept for reproducibility).
pub random_seed: u64,
/// Whether parallel evaluation is permitted.
pub parallel: bool,
/// 0 = silent; > 0 prints per-candidate progress to stdout.
pub verbosity: usize,
}
impl<T: Float + FromPrimitive> Default for OptimizationConfig<T> {
    /// Default configuration: 100 iterations, 1e-6 tolerance, seed 42,
    /// parallelism enabled, verbosity 1 (per-candidate progress lines).
    fn default() -> Self {
        Self {
            max_iterations: 100,
            // 1e-6 is exactly representable in any practical Float type
            // (f32/f64), so this conversion cannot fail for supported T.
            tolerance: T::from(1e-6).expect("1e-6 must be representable in T"),
            random_seed: 42,
            parallel: true,
            verbosity: 1,
        }
    }
}
/// Outcome of a parameter-optimization sweep.
#[derive(Debug, Clone)]
pub struct OptimizationResult<T> {
/// Parameter set that achieved the best (lowest) CV score.
pub best_parameters: HashMap<String, T>,
/// CV score of `best_parameters`.
pub best_score: T,
/// Every evaluated parameter set paired with its CV score.
pub parameter_scores: Vec<(HashMap<String, T>, T)>,
/// Number of candidate evaluations performed.
pub iterations: usize,
/// Whether the search is considered converged (always true for grid sweeps).
pub converged: bool,
/// Wall-clock duration of the whole search, in milliseconds.
pub optimization_time_ms: u64,
}
/// Aggregated result of one cross-validation run.
#[derive(Debug, Clone)]
pub struct CrossValidationResult<T> {
/// Mean of the per-fold scores.
pub mean_score: T,
/// Population standard deviation of the per-fold scores.
pub std_score: T,
/// Raw score of each fold, in fold order.
pub fold_scores: Vec<T>,
/// Number of folds actually evaluated.
pub n_folds: usize,
/// Metric the scores were computed with.
pub metric: ValidationMetric,
}
/// Cross-validation driver for 1-D interpolators.
///
/// Configure with the builder methods (`with_strategy`, `with_metric`, ...)
/// and run with [`CrossValidator::cross_validate`].
#[derive(Debug)]
pub struct CrossValidator<T>
where
T: Float
+ FromPrimitive
+ ToPrimitive
+ Debug
+ Display
+ LowerExp
+ ScalarOperand
+ AddAssign
+ SubAssign
+ MulAssign
+ DivAssign
+ RemAssign
+ Copy
+ Send
+ Sync
+ 'static,
{
// How folds are generated (k-fold, LOO, Monte Carlo, time series).
strategy: CrossValidationStrategy,
// Metric used to score each fold.
metric: ValidationMetric,
// Whether to (deterministically) shuffle indices before k-fold splitting.
shuffle: bool,
// Seed driving the deterministic shuffle.
random_seed: u64,
// Verbosity and other optimization settings.
config: OptimizationConfig<T>,
}
impl<T> Default for CrossValidator<T>
where
T: Float
+ FromPrimitive
+ ToPrimitive
+ Debug
+ Display
+ LowerExp
+ ScalarOperand
+ AddAssign
+ SubAssign
+ MulAssign
+ DivAssign
+ RemAssign
+ Copy
+ Send
+ Sync
+ 'static,
{
/// Equivalent to [`CrossValidator::new`]: 5-fold CV with MSE and shuffling.
fn default() -> Self {
Self::new()
}
}
impl<T> CrossValidator<T>
where
    T: Float
        + FromPrimitive
        + ToPrimitive
        + Debug
        + Display
        + LowerExp
        + ScalarOperand
        + AddAssign
        + SubAssign
        + MulAssign
        + DivAssign
        + RemAssign
        + Copy
        + Send
        + Sync
        + 'static,
{
    /// Create a validator with the defaults: 5-fold CV, MSE metric,
    /// shuffling enabled, random seed 42.
    pub fn new() -> Self {
        Self {
            strategy: CrossValidationStrategy::KFold(5),
            metric: ValidationMetric::MeanSquaredError,
            shuffle: true,
            random_seed: 42,
            config: OptimizationConfig::default(),
        }
    }
    /// Set the fold-generation strategy (builder style).
    pub fn with_strategy(mut self, strategy: CrossValidationStrategy) -> Self {
        self.strategy = strategy;
        self
    }
    /// Shorthand for `with_strategy(CrossValidationStrategy::KFold(k))`.
    pub fn with_k_folds(mut self, k: usize) -> Self {
        self.strategy = CrossValidationStrategy::KFold(k);
        self
    }
    /// Set the metric used to score each fold.
    pub fn with_metric(mut self, metric: ValidationMetric) -> Self {
        self.metric = metric;
        self
    }
    /// Enable or disable the deterministic index shuffle used by k-fold CV.
    pub fn with_shuffle(mut self, shuffle: bool) -> Self {
        self.shuffle = shuffle;
        self
    }
    /// Set the seed that drives the deterministic shuffle.
    pub fn with_random_seed(mut self, seed: u64) -> Self {
        self.random_seed = seed;
        self
    }
    /// Replace the optimization configuration (controls verbosity, etc.).
    pub fn with_config(mut self, config: OptimizationConfig<T>) -> Self {
        self.config = config;
        self
    }
    /// Cross-validate an interpolator built by `interpolator_fn`.
    ///
    /// For each fold, the factory is called with the (sorted) training data,
    /// the resulting interpolator is evaluated on the held-out points, and
    /// the configured metric is recorded.
    ///
    /// # Errors
    /// Returns `DimensionMismatch` if `x` and `y` differ in length, plus any
    /// error surfaced by fold generation, the factory, or evaluation.
    pub fn cross_validate<F>(
        &self,
        x: &ArrayView1<T>,
        y: &ArrayView1<T>,
        interpolator_fn: F,
    ) -> InterpolateResult<CrossValidationResult<T>>
    where
        F: Fn(&ArrayView1<T>, &ArrayView1<T>) -> InterpolateResult<Box<dyn InterpolatorTrait<T>>>,
    {
        let n = x.len();
        if n != y.len() {
            return Err(InterpolateError::DimensionMismatch(
                "x and y must have the same length".to_string(),
            ));
        }
        let folds = self.generate_folds(n)?;
        let mut fold_scores = Vec::new();
        for (train_indices, test_indices) in folds {
            let x_train = self.extract_indices(x, &train_indices);
            let y_train = self.extract_indices(y, &train_indices);
            let x_test = self.extract_indices(x, &test_indices);
            let y_test = self.extract_indices(y, &test_indices);
            // Most interpolator constructors require strictly increasing
            // abscissae, so sort the training pairs by x before fitting.
            let mut training_pairs: Vec<_> = x_train
                .iter()
                .zip(y_train.iter())
                .map(|(x, y)| (*x, *y))
                .collect();
            training_pairs.sort_by(|a, b| {
                a.0.partial_cmp(&b.0)
                    .expect("training x values must be comparable (no NaN)")
            });
            let x_train_sorted: Array1<T> = training_pairs.iter().map(|(x, _)| *x).collect();
            let y_train_sorted: Array1<T> = training_pairs.iter().map(|(_, y)| *y).collect();
            let interpolator = interpolator_fn(&x_train_sorted.view(), &y_train_sorted.view())?;
            let y_pred = interpolator.evaluate(&x_test.view())?;
            let score = self.compute_metric(&y_test.view(), &y_pred.view())?;
            fold_scores.push(score);
        }
        let n_folds = fold_scores.len();
        let mean_score = fold_scores.iter().fold(T::zero(), |acc, &x| acc + x)
            / T::from(fold_scores.len()).expect("fold count must be representable in T");
        // Population variance of the per-fold scores (divide by n, not n-1).
        let variance = fold_scores
            .iter()
            .map(|&score| (score - mean_score) * (score - mean_score))
            .fold(T::zero(), |acc, x| acc + x)
            / T::from(fold_scores.len()).expect("fold count must be representable in T");
        let std_score = variance.sqrt();
        Ok(CrossValidationResult {
            mean_score,
            std_score,
            fold_scores,
            n_folds,
            metric: self.metric,
        })
    }
    /// Grid-search the Gaussian RBF kernel width by cross-validated score.
    ///
    /// Every width in `kernel_widths` is evaluated; the one with the lowest
    /// mean CV score wins. The returned result stores the best width under
    /// the key `"kernel_width"`.
    pub fn optimize_rbf_parameters(
        &mut self,
        x: &ArrayView1<T>,
        y: &ArrayView1<T>,
        kernel_widths: &[T],
    ) -> InterpolateResult<OptimizationResult<T>> {
        let start_time = std::time::Instant::now();
        let mut parameter_scores = Vec::new();
        let mut best_score = T::infinity();
        let mut best_params = HashMap::new();
        for &width in kernel_widths {
            let interpolator_fn = |x_train: &ArrayView1<T>, y_train: &ArrayView1<T>| {
                // RBFInterpolator expects (n, d) points; lift 1-D x to (n, 1).
                let points_2d = Array2::from_shape_vec((x_train.len(), 1), x_train.to_vec())
                    .map_err(|e| {
                        InterpolateError::ComputationError(format!("Failed to reshape: {}", e))
                    })?;
                let rbf =
                    RBFInterpolator::new(&points_2d.view(), y_train, RBFKernel::Gaussian, width)?;
                Ok(Box::new(RBFWrapper::new(rbf)) as Box<dyn InterpolatorTrait<T>>)
            };
            let cv_result = self.cross_validate(x, y, interpolator_fn)?;
            let score = cv_result.mean_score;
            let mut params = HashMap::new();
            params.insert("kernel_width".to_string(), width);
            parameter_scores.push((params.clone(), score));
            if score < best_score {
                best_score = score;
                best_params = params;
            }
            if self.config.verbosity > 0 {
                println!(
                    "Width: {:.3}, CV Score: {:.6}",
                    width.to_f64().unwrap_or(0.0),
                    score.to_f64().unwrap_or(0.0)
                );
            }
        }
        let optimization_time_ms = start_time.elapsed().as_millis() as u64;
        Ok(OptimizationResult {
            best_parameters: best_params,
            best_score,
            parameter_scores,
            iterations: kernel_widths.len(),
            converged: true,
            optimization_time_ms,
        })
    }
    /// Grid-search the B-spline degree by cross-validated score.
    ///
    /// Every degree in `degrees` is evaluated; the one with the lowest mean
    /// CV score wins. The returned result stores the best degree (converted
    /// to `T`) under the key `"degree"`.
    pub fn optimize_bspline_parameters(
        &mut self,
        x: &ArrayView1<T>,
        y: &ArrayView1<T>,
        degrees: &[usize],
    ) -> InterpolateResult<OptimizationResult<T>> {
        let start_time = std::time::Instant::now();
        let mut parameter_scores = Vec::new();
        let mut best_score = T::infinity();
        let mut best_params = HashMap::new();
        // FIX: the loop pattern was corrupted to `°ree` (HTML-entity mangling
        // of `&deg`); restore the intended destructuring of `&usize`.
        for &degree in degrees {
            let interpolator_fn = |x_train: &ArrayView1<T>, y_train: &ArrayView1<T>| {
                let bspline = crate::bspline::make_interp_bspline(
                    x_train,
                    y_train,
                    degree,
                    crate::bspline::ExtrapolateMode::Extrapolate,
                )?;
                Ok(Box::new(BSplineWrapper::new(bspline)) as Box<dyn InterpolatorTrait<T>>)
            };
            let cv_result = self.cross_validate(x, y, interpolator_fn)?;
            let score = cv_result.mean_score;
            let mut params = HashMap::new();
            params.insert(
                "degree".to_string(),
                T::from(degree).expect("degree must be representable in T"),
            );
            parameter_scores.push((params.clone(), score));
            if score < best_score {
                best_score = score;
                best_params = params;
            }
            if self.config.verbosity > 0 {
                println!(
                    "Degree: {}, CV Score: {:.6}",
                    degree,
                    score.to_f64().unwrap_or(0.0)
                );
            }
        }
        let optimization_time_ms = start_time.elapsed().as_millis() as u64;
        Ok(OptimizationResult {
            best_parameters: best_params,
            best_score,
            parameter_scores,
            iterations: degrees.len(),
            converged: true,
            optimization_time_ms,
        })
    }
    /// Build (train, test) index pairs for `n` samples per `self.strategy`.
    ///
    /// # Errors
    /// `InvalidValue` if a k-fold request is degenerate (k == 0 or k > n).
    fn generate_folds(&self, n: usize) -> InterpolateResult<Vec<(Vec<usize>, Vec<usize>)>> {
        match self.strategy {
            CrossValidationStrategy::KFold(k) => {
                // Guard k == 0 explicitly: it would otherwise panic below
                // with a divide-by-zero when computing the fold size.
                if k == 0 {
                    return Err(InterpolateError::InvalidValue(
                        "Number of folds must be at least 1".to_string(),
                    ));
                }
                if k > n {
                    return Err(InterpolateError::InvalidValue(
                        "Number of folds cannot exceed number of samples".to_string(),
                    ));
                }
                let mut indices: Vec<usize> = (0..n).collect();
                if self.shuffle {
                    // NOTE(review): deterministic pseudo-shuffle driven by an
                    // LCG-style formula, not a true Fisher-Yates; kept as-is
                    // for reproducibility of existing fold assignments.
                    for i in 0..n {
                        let j = (self.random_seed as usize + i * 1103515245 + 12345) % n;
                        indices.swap(i, j);
                    }
                }
                let fold_size = n / k;
                let mut folds = Vec::new();
                for fold_idx in 0..k {
                    let start = fold_idx * fold_size;
                    // The final fold absorbs the remainder when n % k != 0.
                    let end = if fold_idx == k - 1 {
                        n
                    } else {
                        (fold_idx + 1) * fold_size
                    };
                    let test_indices = indices[start..end].to_vec();
                    let train_indices: Vec<usize> = indices
                        .iter()
                        .enumerate()
                        .filter(|(i_, _)| *i_ < start || *i_ >= end)
                        .map(|(_, &idx)| idx)
                        .collect();
                    folds.push((train_indices, test_indices));
                }
                Ok(folds)
            }
            CrossValidationStrategy::LeaveOneOut => {
                let mut folds = Vec::new();
                for i in 0..n {
                    let test_indices = vec![i];
                    let train_indices: Vec<usize> = (0..n).filter(|&idx| idx != i).collect();
                    folds.push((train_indices, test_indices));
                }
                Ok(folds)
            }
            CrossValidationStrategy::MonteCarlo {
                n_splits,
                test_fraction,
            } => {
                let mut folds = Vec::new();
                // Always hold out at least one sample.
                let test_size = (n as f64 * test_fraction).max(1.0) as usize;
                for split in 0..n_splits {
                    let mut indices: Vec<usize> = (0..n).collect();
                    // Deterministic per-split permutation (reproducible).
                    for i in 0..n {
                        let j = (i + split * 17) % n;
                        indices.swap(i, j);
                    }
                    let test_indices = indices[0..test_size].to_vec();
                    let train_indices = indices[test_size..].to_vec();
                    folds.push((train_indices, test_indices));
                }
                Ok(folds)
            }
            CrossValidationStrategy::TimeSeries { n_splits, gap: _ } => {
                // Forward-chaining: train on [0, train_end), test on the next
                // window. NOTE(review): `gap` is accepted but not applied.
                let mut folds = Vec::new();
                let min_train_size = n / (n_splits + 1);
                let test_size = n / (n_splits + 1);
                for i in 0..n_splits {
                    let train_end = min_train_size + i * test_size;
                    let test_start = train_end;
                    let test_end = (test_start + test_size).min(n);
                    if test_end <= test_start {
                        break;
                    }
                    let train_indices: Vec<usize> = (0..train_end).collect();
                    let test_indices: Vec<usize> = (test_start..test_end).collect();
                    folds.push((train_indices, test_indices));
                }
                Ok(folds)
            }
        }
    }
    /// Gather `arr[idx]` for each index into a new owned array.
    fn extract_indices(&self, arr: &ArrayView1<T>, indices: &[usize]) -> Array1<T> {
        let mut result = Array1::zeros(indices.len());
        for (i, &idx) in indices.iter().enumerate() {
            result[i] = arr[idx];
        }
        result
    }
    /// Score predictions against ground truth with the configured metric.
    ///
    /// # Errors
    /// `DimensionMismatch` if the two arrays differ in length.
    fn compute_metric(
        &self,
        y_true: &ArrayView1<T>,
        y_pred: &ArrayView1<T>,
    ) -> InterpolateResult<T> {
        if y_true.len() != y_pred.len() {
            return Err(InterpolateError::DimensionMismatch(
                "y_true and y_pred must have the same length".to_string(),
            ));
        }
        let n = T::from(y_true.len()).expect("sample count must be representable in T");
        match self.metric {
            ValidationMetric::MeanSquaredError => {
                let mse = y_true
                    .iter()
                    .zip(y_pred.iter())
                    .map(|(&yt, &yp)| (yt - yp) * (yt - yp))
                    .fold(T::zero(), |acc, x| acc + x)
                    / n;
                Ok(mse)
            }
            ValidationMetric::MeanAbsoluteError => {
                let mae = y_true
                    .iter()
                    .zip(y_pred.iter())
                    .map(|(&yt, &yp)| (yt - yp).abs())
                    .fold(T::zero(), |acc, x| acc + x)
                    / n;
                Ok(mae)
            }
            ValidationMetric::RootMeanSquaredError => {
                let mse = y_true
                    .iter()
                    .zip(y_pred.iter())
                    .map(|(&yt, &yp)| (yt - yp) * (yt - yp))
                    .fold(T::zero(), |acc, x| acc + x)
                    / n;
                Ok(mse.sqrt())
            }
            ValidationMetric::RSquared => {
                let y_mean = y_true.sum() / n;
                let ss_tot = y_true
                    .iter()
                    .map(|&yt| (yt - y_mean) * (yt - y_mean))
                    .fold(T::zero(), |acc, x| acc + x);
                let ss_res = y_true
                    .iter()
                    .zip(y_pred.iter())
                    .map(|(&yt, &yp)| (yt - yp) * (yt - yp))
                    .fold(T::zero(), |acc, x| acc + x);
                // Constant target: define R^2 as 1 to avoid dividing by zero.
                if ss_tot == T::zero() {
                    Ok(T::one())
                } else {
                    Ok(T::one() - ss_res / ss_tot)
                }
            }
            ValidationMetric::MaxAbsoluteError => {
                let max_error = y_true
                    .iter()
                    .zip(y_pred.iter())
                    .map(|(&yt, &yp)| (yt - yp).abs())
                    .fold(T::zero(), |acc, x| acc.max(x));
                Ok(max_error)
            }
            ValidationMetric::MeanAbsolutePercentageError => {
                // Samples with a zero target are skipped (division by zero).
                let mut mape = T::zero();
                let mut count = 0;
                for (&yt, &yp) in y_true.iter().zip(y_pred.iter()) {
                    if yt != T::zero() {
                        mape += ((yt - yp) / yt).abs();
                        count += 1;
                    }
                }
                if count > 0 {
                    Ok(mape / T::from(count).expect("count must be representable in T")
                        * T::from(100.0).expect("100.0 must be representable in T"))
                } else {
                    Ok(T::zero())
                }
            }
        }
    }
}
/// Minimal object-safe interface for a fitted 1-D interpolator, letting the
/// cross-validator treat different interpolator families uniformly.
pub trait InterpolatorTrait<T>: Debug + Send + Sync
where
T: Float + Debug + Copy,
{
/// Evaluate the fitted interpolator at the given query points.
fn evaluate(&self, x: &ArrayView1<T>) -> InterpolateResult<Array1<T>>;
}
/// Adapter that exposes an [`RBFInterpolator`] through [`InterpolatorTrait`].
#[derive(Debug)]
struct RBFWrapper<T>
where
T: Float
+ FromPrimitive
+ ToPrimitive
+ Debug
+ Display
+ LowerExp
+ ScalarOperand
+ AddAssign
+ SubAssign
+ MulAssign
+ DivAssign
+ RemAssign
+ Copy
+ Send
+ Sync
+ 'static,
{
// The fitted RBF model being adapted.
interpolator: RBFInterpolator<T>,
}
impl<T> RBFWrapper<T>
where
T: Float
+ FromPrimitive
+ ToPrimitive
+ Debug
+ Display
+ LowerExp
+ ScalarOperand
+ AddAssign
+ SubAssign
+ MulAssign
+ DivAssign
+ RemAssign
+ Copy
+ Send
+ Sync
+ 'static,
{
/// Wrap an already-fitted RBF interpolator.
fn new(interpolator: RBFInterpolator<T>) -> Self {
Self { interpolator }
}
}
impl<T> InterpolatorTrait<T> for RBFWrapper<T>
where
T: Float
+ FromPrimitive
+ ToPrimitive
+ Debug
+ Display
+ LowerExp
+ ScalarOperand
+ AddAssign
+ SubAssign
+ MulAssign
+ DivAssign
+ RemAssign
+ Copy
+ Send
+ Sync
+ 'static,
{
/// Evaluate by lifting the 1-D query points to the (n, 1) layout the
/// multi-dimensional RBF interpolator expects.
fn evaluate(&self, x: &ArrayView1<T>) -> InterpolateResult<Array1<T>> {
let points_2d = Array2::from_shape_vec((x.len(), 1), x.to_vec())
.map_err(|e| InterpolateError::ComputationError(format!("Failed to reshape: {}", e)))?;
self.interpolator.interpolate(&points_2d.view())
}
}
/// Adapter that exposes a [`BSpline`] through [`InterpolatorTrait`].
#[derive(Debug)]
struct BSplineWrapper<T>
where
T: Float
+ FromPrimitive
+ Debug
+ Display
+ Copy
+ Send
+ Sync
+ AddAssign
+ SubAssign
+ MulAssign
+ DivAssign
+ RemAssign
+ 'static,
{
// The fitted B-spline being adapted.
interpolator: BSpline<T>,
}
impl<T> BSplineWrapper<T>
where
T: Float
+ FromPrimitive
+ Debug
+ Display
+ Copy
+ Send
+ Sync
+ AddAssign
+ SubAssign
+ MulAssign
+ DivAssign
+ RemAssign
+ 'static,
{
/// Wrap an already-fitted B-spline.
fn new(interpolator: BSpline<T>) -> Self {
Self { interpolator }
}
}
impl<T> InterpolatorTrait<T> for BSplineWrapper<T>
where
T: Float
+ FromPrimitive
+ Debug
+ Display
+ Copy
+ Send
+ Sync
+ AddAssign
+ SubAssign
+ MulAssign
+ DivAssign
+ RemAssign
+ 'static,
{
/// Delegate directly to the B-spline's array evaluation.
fn evaluate(&self, x: &ArrayView1<T>) -> InterpolateResult<Array1<T>> {
self.interpolator.evaluate_array(x)
}
}
/// Compares multiple interpolation methods by cross-validated score.
#[derive(Debug)]
pub struct ModelSelector<T>
where
T: Float
+ FromPrimitive
+ ToPrimitive
+ Debug
+ Display
+ LowerExp
+ ScalarOperand
+ AddAssign
+ SubAssign
+ MulAssign
+ DivAssign
+ RemAssign
+ Copy
+ Send
+ Sync
+ 'static,
{
// Validator used to score every candidate method.
cross_validator: CrossValidator<T>,
// Cached comparison results; currently unused by the public API.
#[allow(dead_code)]
comparison_results: Vec<(String, CrossValidationResult<T>)>,
}
impl<T> ModelSelector<T>
where
    T: Float
        + FromPrimitive
        + ToPrimitive
        + Debug
        + Display
        + LowerExp
        + ScalarOperand
        + AddAssign
        + SubAssign
        + MulAssign
        + DivAssign
        + RemAssign
        + Copy
        + Send
        + Sync
        + 'static,
{
    /// Construct a selector backed by a default [`CrossValidator`].
    pub fn new() -> Self {
        let cross_validator = CrossValidator::new();
        Self {
            cross_validator,
            comparison_results: Vec::new(),
        }
    }
    /// Replace the internal cross-validator (builder style).
    pub fn with_cross_validator(mut self, cv: CrossValidator<T>) -> Self {
        self.cross_validator = cv;
        self
    }
    /// Cross-validate every named interpolator factory and return the
    /// (name, result) pairs sorted by ascending mean score — so for error
    /// metrics the best method comes first.
    ///
    /// # Panics
    /// Panics if any mean score is NaN (scores must be totally orderable).
    #[allow(dead_code)]
    pub fn compare_methods<F>(
        &mut self,
        x: &ArrayView1<T>,
        y: &ArrayView1<T>,
        methods: HashMap<String, F>,
    ) -> InterpolateResult<Vec<(String, CrossValidationResult<T>)>>
    where
        F: Fn(&ArrayView1<T>, &ArrayView1<T>) -> InterpolateResult<Box<dyn InterpolatorTrait<T>>>
            + Clone,
    {
        let mut scored = Vec::with_capacity(methods.len());
        for (name, factory) in methods {
            let outcome = self.cross_validator.cross_validate(x, y, factory)?;
            scored.push((name, outcome));
        }
        scored.sort_by(|lhs, rhs| {
            lhs.1
                .mean_score
                .partial_cmp(&rhs.1.mean_score)
                .expect("Operation failed")
        });
        Ok(scored)
    }
}
impl<T> Default for ModelSelector<T>
where
T: Float
+ FromPrimitive
+ ToPrimitive
+ Debug
+ Display
+ LowerExp
+ ScalarOperand
+ AddAssign
+ SubAssign
+ MulAssign
+ DivAssign
+ RemAssign
+ Copy
+ Send
+ Sync
+ 'static,
{
/// Equivalent to [`ModelSelector::new`].
fn default() -> Self {
Self::new()
}
}
/// Convenience constructor for a k-fold [`CrossValidator`] with the given
/// metric and otherwise-default settings.
///
/// The parameter was previously named `_kfolds`; the leading underscore
/// falsely signalled "unused" even though the value is consumed. Renaming a
/// Rust function parameter is backward compatible for all callers.
#[allow(dead_code)]
pub fn make_cross_validator<T>(k_folds: usize, metric: ValidationMetric) -> CrossValidator<T>
where
    T: Float
        + FromPrimitive
        + ToPrimitive
        + Debug
        + Display
        + LowerExp
        + ScalarOperand
        + AddAssign
        + SubAssign
        + MulAssign
        + DivAssign
        + RemAssign
        + Copy
        + Send
        + Sync
        + 'static,
{
    CrossValidator::new()
        .with_k_folds(k_folds)
        .with_metric(metric)
}
/// Exhaustive grid search: cross-validate every parameter set in
/// `parameter_grid` and return the set with the lowest mean CV score,
/// together with that score.
///
/// `interpolator_fn` receives the candidate parameters plus the training
/// data and must build a boxed interpolator.
#[allow(dead_code)]
pub fn grid_search<T, F>(
    x: &ArrayView1<T>,
    y: &ArrayView1<T>,
    parameter_grid: &[HashMap<String, T>],
    cv: &CrossValidator<T>,
    interpolator_fn: F,
) -> InterpolateResult<(HashMap<String, T>, T)>
where
    T: Float
        + FromPrimitive
        + ToPrimitive
        + Debug
        + Display
        + LowerExp
        + ScalarOperand
        + AddAssign
        + SubAssign
        + MulAssign
        + DivAssign
        + RemAssign
        + Copy
        + Send
        + Sync
        + 'static,
    F: Fn(
        &HashMap<String, T>,
        &ArrayView1<T>,
        &ArrayView1<T>,
    ) -> InterpolateResult<Box<dyn InterpolatorTrait<T>>>,
{
    // Track the champion as a single (params, score) pair; an empty map with
    // an infinite score is returned when the grid is empty.
    let mut best: (HashMap<String, T>, T) = (HashMap::new(), T::infinity());
    for candidate in parameter_grid {
        let factory = |x_train: &ArrayView1<T>, y_train: &ArrayView1<T>| {
            interpolator_fn(candidate, x_train, y_train)
        };
        let outcome = cv.cross_validate(x, y, factory)?;
        if outcome.mean_score < best.1 {
            best = (candidate.clone(), outcome.mean_score);
        }
    }
    Ok(best)
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_cross_validator_creation() {
        let cv = CrossValidator::<f64>::new();
        assert_eq!(cv.metric, ValidationMetric::MeanSquaredError);
        assert!(cv.shuffle);
    }

    #[test]
    fn test_cross_validator_configuration() {
        let cv = CrossValidator::<f64>::new()
            .with_k_folds(10)
            .with_metric(ValidationMetric::MeanAbsoluteError)
            .with_shuffle(false);
        match cv.strategy {
            CrossValidationStrategy::KFold(k) => assert_eq!(k, 10),
            _ => panic!("Expected KFold strategy"),
        }
        assert_eq!(cv.metric, ValidationMetric::MeanAbsoluteError);
        assert!(!cv.shuffle);
    }

    #[test]
    fn test_fold_generation() {
        let cv = CrossValidator::<f64>::new().with_k_folds(3);
        let folds = cv.generate_folds(9).expect("Operation failed");
        assert_eq!(folds.len(), 3);
        // Every sample index must appear somewhere across the folds.
        let mut all_indices = std::collections::HashSet::new();
        for (train, test) in &folds {
            for &idx in train {
                all_indices.insert(idx);
            }
            for &idx in test {
                all_indices.insert(idx);
            }
        }
        assert_eq!(all_indices.len(), 9);
    }

    #[test]
    fn test_leave_one_out_folds() {
        let cv = CrossValidator::<f64>::new().with_strategy(CrossValidationStrategy::LeaveOneOut);
        let folds = cv.generate_folds(5).expect("Operation failed");
        assert_eq!(folds.len(), 5);
        for (train, test) in &folds {
            assert_eq!(test.len(), 1);
            assert_eq!(train.len(), 4);
        }
    }

    #[test]
    fn test_metric_computation() {
        let cv = CrossValidator::<f64>::new().with_metric(ValidationMetric::MeanSquaredError);
        let y_true = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let y_pred = Array1::from_vec(vec![1.1, 1.9, 3.1, 3.9]);
        let mse = cv
            .compute_metric(&y_true.view(), &y_pred.view())
            .expect("Operation failed");
        let expected_mse = (0.1 * 0.1 + 0.1 * 0.1 + 0.1 * 0.1 + 0.1 * 0.1) / 4.0;
        assert!((mse - expected_mse).abs() < 1e-10);
    }

    #[test]
    fn test_r_squared_metric() {
        let cv = CrossValidator::<f64>::new().with_metric(ValidationMetric::RSquared);
        let y_true = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let y_pred = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let r2 = cv
            .compute_metric(&y_true.view(), &y_pred.view())
            .expect("Operation failed");
        // Perfect predictions give R^2 == 1.
        assert!((r2 - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_rbf_parameter_optimization() {
        let x = Array1::linspace(0.0, 1.0, 10);
        let y = x.mapv(|x| x * x);
        let mut cv = CrossValidator::new().with_k_folds(3);
        let kernel_widths = vec![0.1, 1.0, 10.0];
        let result = cv.optimize_rbf_parameters(&x.view(), &y.view(), &kernel_widths);
        assert!(result.is_ok());
        let opt_result = result.expect("Operation failed");
        assert!(opt_result.best_parameters.contains_key("kernel_width"));
        assert_eq!(opt_result.parameter_scores.len(), 3);
        assert!(opt_result.best_score.is_finite());
    }

    #[test]
    fn test_bspline_parameter_optimization() {
        let x = Array1::linspace(0.0, 10.0, 30);
        let y = x.mapv(|x| 2.0 * x + 1.0);
        let mut cv = CrossValidator::new().with_k_folds(2);
        let degrees = vec![1];
        // FIX: the argument was corrupted to `°rees` (HTML-entity mangling
        // of `&deg`); restore the intended `&degrees` slice borrow.
        let result = cv.optimize_bspline_parameters(&x.view(), &y.view(), &degrees);
        match result {
            Ok(opt_result) => {
                assert!(opt_result.best_parameters.contains_key("degree"));
                assert_eq!(opt_result.parameter_scores.len(), 1);
                assert!(opt_result.best_score.is_finite());
            }
            Err(e) => {
                println!(
                    "Cross-validation encountered numerical issues (expected): {:?}",
                    e
                );
                assert!(matches!(e, InterpolateError::InvalidInput { .. }));
            }
        }
    }

    #[test]
    fn test_model_selector_creation() {
        let selector = ModelSelector::<f64>::new();
        assert_eq!(selector.comparison_results.len(), 0);
    }

    #[test]
    fn test_make_cross_validator() {
        let cv = make_cross_validator::<f64>(5, ValidationMetric::MeanAbsoluteError);
        match cv.strategy {
            CrossValidationStrategy::KFold(k) => assert_eq!(k, 5),
            _ => panic!("Expected KFold strategy"),
        }
        assert_eq!(cv.metric, ValidationMetric::MeanAbsoluteError);
    }

    #[test]
    fn test_extract_indices() {
        let cv = CrossValidator::<f64>::new();
        let arr = Array1::from_vec(vec![10.0, 20.0, 30.0, 40.0, 50.0]);
        let indices = vec![0, 2, 4];
        let extracted = cv.extract_indices(&arr.view(), &indices);
        assert_eq!(extracted, Array1::from_vec(vec![10.0, 30.0, 50.0]));
    }

    #[test]
    fn test_validation_metrics() {
        let cv_mse = CrossValidator::<f64>::new().with_metric(ValidationMetric::MeanSquaredError);
        let cv_mae = CrossValidator::<f64>::new().with_metric(ValidationMetric::MeanAbsoluteError);
        let cv_rmse =
            CrossValidator::<f64>::new().with_metric(ValidationMetric::RootMeanSquaredError);
        let y_true = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let y_pred = Array1::from_vec(vec![1.5, 2.5, 2.5]);
        let mse = cv_mse
            .compute_metric(&y_true.view(), &y_pred.view())
            .expect("Operation failed");
        let mae = cv_mae
            .compute_metric(&y_true.view(), &y_pred.view())
            .expect("Operation failed");
        let rmse = cv_rmse
            .compute_metric(&y_true.view(), &y_pred.view())
            .expect("Operation failed");
        assert!(mse > 0.0);
        assert!(mae > 0.0);
        // RMSE must be exactly the square root of the MSE.
        assert!((rmse - mse.sqrt()).abs() < 1e-10);
    }
}