// sklears_model_selection/validation.rs
1//! Model validation utilities
2
3use crate::{CrossValidator, ParameterValue};
4use scirs2_core::ndarray::{Array1, Array2};
5// use scirs2_core::SliceRandomExt;
6use sklears_core::{
7    error::Result,
8    prelude::{Predict, SklearsError},
9    traits::Fit,
10    traits::Score,
11    types::Float,
12};
13use sklears_metrics::{
14    classification::accuracy_score, get_scorer, regression::mean_squared_error, Scorer,
15};
16use std::collections::HashMap;
17
18/// Helper function for scoring that handles both regression and classification
19fn compute_score_for_regression_val(
20    metric_name: &str,
21    y_true: &Array1<f64>,
22    y_pred: &Array1<f64>,
23) -> Result<f64> {
24    match metric_name {
25        "neg_mean_squared_error" => Ok(-mean_squared_error(y_true, y_pred)?),
26        "mean_squared_error" => Ok(mean_squared_error(y_true, y_pred)?),
27        _ => {
28            // For unsupported metrics, return a default score
29            Err(sklears_core::error::SklearsError::InvalidInput(format!(
30                "Metric '{}' not supported for regression",
31                metric_name
32            )))
33        }
34    }
35}
36
37/// Helper function for scoring classification data
38fn compute_score_for_classification_val(
39    metric_name: &str,
40    y_true: &Array1<i32>,
41    y_pred: &Array1<i32>,
42) -> Result<f64> {
43    match metric_name {
44        "accuracy" => Ok(accuracy_score(y_true, y_pred)?),
45        _ => {
46            let scorer = get_scorer(metric_name)?;
47            scorer.score(
48                y_true.as_slice().expect("operation should succeed"),
49                y_pred.as_slice().expect("operation should succeed"),
50            )
51        }
52    }
53}
54
/// Scoring method for cross-validation
#[derive(Debug, Clone)]
pub enum Scoring {
    /// Use the estimator's built-in score method
    EstimatorScore,
    /// Look up a predefined scorer by name
    /// (e.g. "accuracy", "neg_mean_squared_error", "mean_squared_error")
    Metric(String),
    /// Use a specific scorer configuration
    Scorer(Scorer),
    /// Use multiple scoring metrics by name
    /// (NOTE(review): support is partial — `cross_validate` rejects this
    /// variant for test-fold scoring)
    MultiMetric(Vec<String>),
    /// Use a custom scoring function taking `(y_true, y_pred)`
    Custom(fn(&Array1<Float>, &Array1<Float>) -> Result<f64>),
}
69
/// Enhanced scoring result that can handle multiple metrics
#[derive(Debug, Clone)]
pub enum ScoreResult {
    /// Single score value
    Single(f64),
    /// Multiple score values keyed by metric name
    Multiple(HashMap<String, f64>),
}

impl ScoreResult {
    /// Return one representative score.
    ///
    /// For `Multiple`, an arbitrary entry is returned (map iteration order
    /// is unspecified); an empty map yields `0.0`.
    pub fn as_single(&self) -> f64 {
        match self {
            Self::Single(value) => *value,
            Self::Multiple(by_metric) => by_metric.values().next().copied().unwrap_or(0.0),
        }
    }

    /// Return all scores keyed by metric name.
    ///
    /// A `Single` score is exposed under the key `"score"`.
    pub fn as_multiple(&self) -> HashMap<String, f64> {
        match self {
            Self::Single(value) => HashMap::from([("score".to_string(), *value)]),
            Self::Multiple(by_metric) => by_metric.clone(),
        }
    }
}
100
101/// Evaluate metric(s) by cross-validation and also record fit/score times
102#[allow(clippy::too_many_arguments)]
103pub fn cross_validate<E, F, C>(
104    estimator: E,
105    x: &Array2<Float>,
106    y: &Array1<Float>,
107    cv: &C,
108    scoring: Scoring,
109    return_train_score: bool,
110    return_estimator: bool,
111    _n_jobs: Option<usize>,
112) -> Result<CrossValidateResult<F>>
113where
114    E: Clone,
115    F: Clone,
116    E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
117    F: Predict<Array2<Float>, Array1<Float>>,
118    F: Score<Array2<Float>, Array1<Float>, Float = f64>,
119    C: CrossValidator,
120{
121    // Note: This assumes KFold or other CV that doesn't need y
122    // For StratifiedKFold, you would need to pass integer labels
123    let splits = cv.split(x.nrows(), None);
124    let n_splits = splits.len();
125
126    let mut test_scores = Vec::with_capacity(n_splits);
127    let mut train_scores = if return_train_score {
128        Some(Vec::with_capacity(n_splits))
129    } else {
130        None
131    };
132    let mut fit_times = Vec::with_capacity(n_splits);
133    let mut score_times = Vec::with_capacity(n_splits);
134    let mut estimators = if return_estimator {
135        Some(Vec::with_capacity(n_splits))
136    } else {
137        None
138    };
139
140    // Process each fold
141    for (train_idx, test_idx) in splits {
142        // Extract train and test data
143        let x_train = x.select(scirs2_core::ndarray::Axis(0), &train_idx);
144        let y_train = y.select(scirs2_core::ndarray::Axis(0), &train_idx);
145        let x_test = x.select(scirs2_core::ndarray::Axis(0), &test_idx);
146        let y_test = y.select(scirs2_core::ndarray::Axis(0), &test_idx);
147
148        // Fit the estimator
149        let start = std::time::Instant::now();
150        let fitted = estimator.clone().fit(&x_train, &y_train)?;
151        let fit_time = start.elapsed().as_secs_f64();
152        fit_times.push(fit_time);
153
154        // Score on test set
155        let start = std::time::Instant::now();
156        let test_score = match &scoring {
157            Scoring::EstimatorScore => fitted.score(&x_test, &y_test)?,
158            Scoring::Custom(func) => {
159                let y_pred = fitted.predict(&x_test)?;
160                func(&y_test.to_owned(), &y_pred)?
161            }
162            Scoring::Metric(metric_name) => {
163                let y_pred = fitted.predict(&x_test)?;
164                // Determine if this is classification or regression based on the data type
165                if y_test.iter().all(|&x| x.fract() == 0.0) {
166                    // Integer-like values, likely classification
167                    let y_true_int: Array1<i32> = y_test.mapv(|x| x as i32);
168                    let y_pred_int: Array1<i32> = y_pred.mapv(|x| x as i32);
169                    compute_score_for_classification_val(metric_name, &y_true_int, &y_pred_int)?
170                } else {
171                    // Float values, likely regression
172                    compute_score_for_regression_val(metric_name, &y_test, &y_pred)?
173                }
174            }
175            Scoring::Scorer(scorer) => {
176                let y_pred = fitted.predict(&x_test)?;
177                scorer.score_float(
178                    y_test.as_slice().expect("operation should succeed"),
179                    y_pred.as_slice().expect("operation should succeed"),
180                )?
181            }
182            Scoring::MultiMetric(_) => {
183                return Err(SklearsError::InvalidInput(
184                    "MultiMetric scoring not supported in single metric context".to_string(),
185                ));
186            }
187        };
188        let score_time = start.elapsed().as_secs_f64();
189        score_times.push(score_time);
190        test_scores.push(test_score);
191
192        // Score on train set if requested
193        if let Some(ref mut train_scores) = train_scores {
194            let train_score = match &scoring {
195                Scoring::EstimatorScore => fitted.score(&x_train, &y_train)?,
196                Scoring::Custom(func) => {
197                    let y_pred = fitted.predict(&x_train)?;
198                    func(&y_train.to_owned(), &y_pred)?
199                }
200                Scoring::Metric(metric_name) => {
201                    let y_pred = fitted.predict(&x_train)?;
202                    // Determine if this is classification or regression based on the data type
203                    if y_train.iter().all(|&x| x.fract() == 0.0) {
204                        // Integer-like values, likely classification
205                        let y_true_int: Array1<i32> = y_train.mapv(|x| x as i32);
206                        let y_pred_int: Array1<i32> = y_pred.mapv(|x| x as i32);
207                        compute_score_for_classification_val(metric_name, &y_true_int, &y_pred_int)?
208                    } else {
209                        // Float values, likely regression
210                        compute_score_for_regression_val(metric_name, &y_train, &y_pred)?
211                    }
212                }
213                Scoring::Scorer(scorer) => {
214                    let y_pred = fitted.predict(&x_train)?;
215                    scorer.score_float(
216                        y_train.as_slice().expect("operation should succeed"),
217                        y_pred.as_slice().expect("operation should succeed"),
218                    )?
219                }
220                Scoring::MultiMetric(_metrics) => {
221                    // For multi-metric, just use the first metric for now
222                    fitted.score(&x_train, &y_train)?
223                }
224            };
225            train_scores.push(train_score);
226        }
227
228        // Store estimator if requested
229        if let Some(ref mut estimators) = estimators {
230            estimators.push(fitted);
231        }
232    }
233
234    Ok(CrossValidateResult {
235        test_scores: Array1::from_vec(test_scores),
236        train_scores: train_scores.map(Array1::from_vec),
237        fit_times: Array1::from_vec(fit_times),
238        score_times: Array1::from_vec(score_times),
239        estimators,
240    })
241}
242
/// Result of cross_validate
#[derive(Debug, Clone)]
pub struct CrossValidateResult<F> {
    /// Score on the held-out test fold, one entry per split
    pub test_scores: Array1<f64>,
    /// Score on the training fold, one entry per split
    /// (`Some` only when `return_train_score` was requested)
    pub train_scores: Option<Array1<f64>>,
    /// Wall-clock seconds spent fitting, one entry per split
    pub fit_times: Array1<f64>,
    /// Wall-clock seconds spent scoring the test fold, one entry per split
    pub score_times: Array1<f64>,
    /// Fitted estimator for each split
    /// (`Some` only when `return_estimator` was requested)
    pub estimators: Option<Vec<F>>,
}
252
253/// Evaluate a score by cross-validation
254pub fn cross_val_score<E, F, C>(
255    estimator: E,
256    x: &Array2<Float>,
257    y: &Array1<Float>,
258    cv: &C,
259    scoring: Option<Scoring>,
260    n_jobs: Option<usize>,
261) -> Result<Array1<f64>>
262where
263    E: Clone,
264    F: Clone,
265    E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
266    F: Predict<Array2<Float>, Array1<Float>>,
267    F: Score<Array2<Float>, Array1<Float>, Float = f64>,
268    C: CrossValidator,
269{
270    let scoring = scoring.unwrap_or(Scoring::EstimatorScore);
271    let result = cross_validate(
272        estimator, x, y, cv, scoring, false, // return_train_score
273        false, // return_estimator
274        n_jobs,
275    )?;
276
277    Ok(result.test_scores)
278}
279
280/// Generate cross-validated estimates for each input data point
281pub fn cross_val_predict<E, F, C>(
282    estimator: E,
283    x: &Array2<Float>,
284    y: &Array1<Float>,
285    cv: &C,
286    _n_jobs: Option<usize>,
287) -> Result<Array1<Float>>
288where
289    E: Clone,
290    E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
291    F: Predict<Array2<Float>, Array1<Float>>,
292    C: CrossValidator,
293{
294    // Note: This assumes KFold or other CV that doesn't need y
295    // For StratifiedKFold, you would need to pass integer labels
296    let splits = cv.split(x.nrows(), None);
297    let n_samples = x.nrows();
298
299    // Initialize predictions array
300    let mut predictions = Array1::<Float>::zeros(n_samples);
301
302    // Process each fold
303    for (train_idx, test_idx) in splits {
304        // Extract train and test data
305        let x_train = x.select(scirs2_core::ndarray::Axis(0), &train_idx);
306        let y_train = y.select(scirs2_core::ndarray::Axis(0), &train_idx);
307        let x_test = x.select(scirs2_core::ndarray::Axis(0), &test_idx);
308
309        // Fit and predict
310        let fitted = estimator.clone().fit(&x_train, &y_train)?;
311        let y_pred = fitted.predict(&x_test)?;
312
313        // Store predictions at the correct indices
314        for (i, &idx) in test_idx.iter().enumerate() {
315            predictions[idx] = y_pred[i];
316        }
317    }
318
319    Ok(predictions)
320}
321
/// Learning curve results
///
/// Raw scores are laid out as `(n_train_sizes, n_splits)` matrices; the
/// per-size summary statistics are computed across CV splits.
#[derive(Debug, Clone)]
pub struct LearningCurveResult {
    /// Training set sizes used (number of samples per size)
    pub train_sizes: Array1<usize>,
    /// Training scores for each size, shape `(n_train_sizes, n_splits)`
    pub train_scores: Array2<f64>,
    /// Validation scores for each size, shape `(n_train_sizes, n_splits)`
    pub test_scores: Array2<f64>,
    /// Mean training scores for each size (averaged over splits)
    pub train_scores_mean: Array1<f64>,
    /// Mean validation scores for each size (averaged over splits)
    pub test_scores_mean: Array1<f64>,
    /// Standard deviation of training scores for each size
    pub train_scores_std: Array1<f64>,
    /// Standard deviation of validation scores for each size
    pub test_scores_std: Array1<f64>,
    /// Lower confidence bound for training scores (mean - z * standard error)
    pub train_scores_lower: Array1<f64>,
    /// Upper confidence bound for training scores (mean + z * standard error)
    pub train_scores_upper: Array1<f64>,
    /// Lower confidence bound for validation scores (mean - z * standard error)
    pub test_scores_lower: Array1<f64>,
    /// Upper confidence bound for validation scores (mean + z * standard error)
    pub test_scores_upper: Array1<f64>,
}
348
349/// Compute learning curves for an estimator
350///
351/// Determines cross-validated training and test scores for different training
352/// set sizes. This is useful to find out if we suffer from bias vs variance
353/// when we add more data to the training set.
354///
355/// # Arguments
356/// * `estimator` - The estimator to evaluate
357/// * `x` - Training data features
358/// * `y` - Training data targets
359/// * `cv` - Cross-validation splitter
360/// * `train_sizes` - Relative or absolute numbers of training examples that will be used to generate the learning curve
361/// * `scoring` - Scoring method to use
362/// * `confidence_level` - Confidence level for confidence bands (default: 0.95 for 95% confidence interval)
363#[allow(clippy::too_many_arguments)]
364pub fn learning_curve<E, F, C>(
365    estimator: E,
366    x: &Array2<Float>,
367    y: &Array1<Float>,
368    cv: &C,
369    train_sizes: Option<Vec<f64>>,
370    scoring: Option<Scoring>,
371    confidence_level: Option<f64>,
372) -> Result<LearningCurveResult>
373where
374    E: Clone,
375    F: Clone,
376    E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
377    F: Predict<Array2<Float>, Array1<Float>>,
378    F: Score<Array2<Float>, Array1<Float>, Float = f64>,
379    C: CrossValidator,
380{
381    let n_samples = x.nrows();
382    let scoring = scoring.unwrap_or(Scoring::EstimatorScore);
383
384    // Default train sizes: 10%, 30%, 50%, 70%, 90%, 100%
385    let train_size_fractions = train_sizes.unwrap_or_else(|| vec![0.1, 0.3, 0.5, 0.7, 0.9, 1.0]);
386
387    // Convert fractions to actual sizes
388    let train_sizes_actual: Vec<usize> = train_size_fractions
389        .iter()
390        .map(|&frac| {
391            let size = (frac * n_samples as f64).round() as usize;
392            size.max(1).min(n_samples) // Ensure between 1 and n_samples
393        })
394        .collect();
395
396    let n_splits = cv.n_splits();
397    let n_train_sizes = train_sizes_actual.len();
398
399    let mut train_scores = Array2::<f64>::zeros((n_train_sizes, n_splits));
400    let mut test_scores = Array2::<f64>::zeros((n_train_sizes, n_splits));
401
402    // Get CV splits
403    let splits = cv.split(x.nrows(), None);
404
405    for (size_idx, &train_size) in train_sizes_actual.iter().enumerate() {
406        for (split_idx, (train_idx, test_idx)) in splits.iter().enumerate() {
407            // Limit training set to the desired size
408            let mut limited_train_idx = train_idx.clone();
409            if limited_train_idx.len() > train_size {
410                limited_train_idx.truncate(train_size);
411            }
412
413            // Extract data
414            let x_train = x.select(scirs2_core::ndarray::Axis(0), &limited_train_idx);
415            let y_train = y.select(scirs2_core::ndarray::Axis(0), &limited_train_idx);
416            let x_test = x.select(scirs2_core::ndarray::Axis(0), test_idx);
417            let y_test = y.select(scirs2_core::ndarray::Axis(0), test_idx);
418
419            // Fit estimator
420            let fitted = estimator.clone().fit(&x_train, &y_train)?;
421
422            // Score on training set
423            let train_score = match &scoring {
424                Scoring::EstimatorScore => fitted.score(&x_train, &y_train)?,
425                Scoring::Custom(func) => {
426                    let y_pred = fitted.predict(&x_train)?;
427                    func(&y_train.to_owned(), &y_pred)?
428                }
429                Scoring::Metric(metric_name) => {
430                    let y_pred = fitted.predict(&x_train)?;
431                    // Determine if this is classification or regression based on the data type
432                    if y_train.iter().all(|&x| x.fract() == 0.0) {
433                        // Integer-like values, likely classification
434                        let y_true_int: Array1<i32> = y_train.mapv(|x| x as i32);
435                        let y_pred_int: Array1<i32> = y_pred.mapv(|x| x as i32);
436                        compute_score_for_classification_val(metric_name, &y_true_int, &y_pred_int)?
437                    } else {
438                        // Float values, likely regression
439                        compute_score_for_regression_val(metric_name, &y_train, &y_pred)?
440                    }
441                }
442                Scoring::Scorer(scorer) => {
443                    let y_pred = fitted.predict(&x_train)?;
444                    scorer.score_float(
445                        y_train.as_slice().expect("operation should succeed"),
446                        y_pred.as_slice().expect("operation should succeed"),
447                    )?
448                }
449                Scoring::MultiMetric(_metrics) => {
450                    // For multi-metric, just use the first metric for now
451                    fitted.score(&x_train, &y_train)?
452                }
453            };
454            train_scores[[size_idx, split_idx]] = train_score;
455
456            // Score on test set
457            let test_score = match &scoring {
458                Scoring::EstimatorScore => fitted.score(&x_test, &y_test)?,
459                Scoring::Custom(func) => {
460                    let y_pred = fitted.predict(&x_test)?;
461                    func(&y_test.to_owned(), &y_pred)?
462                }
463                Scoring::Metric(metric_name) => {
464                    let y_pred = fitted.predict(&x_test)?;
465                    // Determine if this is classification or regression based on the data type
466                    if y_test.iter().all(|&x| x.fract() == 0.0) {
467                        // Integer-like values, likely classification
468                        let y_true_int: Array1<i32> = y_test.mapv(|x| x as i32);
469                        let y_pred_int: Array1<i32> = y_pred.mapv(|x| x as i32);
470                        compute_score_for_classification_val(metric_name, &y_true_int, &y_pred_int)?
471                    } else {
472                        // Float values, likely regression
473                        compute_score_for_regression_val(metric_name, &y_test, &y_pred)?
474                    }
475                }
476                Scoring::Scorer(scorer) => {
477                    let y_pred = fitted.predict(&x_test)?;
478                    scorer.score_float(
479                        y_test.as_slice().expect("operation should succeed"),
480                        y_pred.as_slice().expect("operation should succeed"),
481                    )?
482                }
483                Scoring::MultiMetric(_metrics) => {
484                    // For multi-metric, just use the first metric for now
485                    fitted.score(&x_test, &y_test)?
486                }
487            };
488            test_scores[[size_idx, split_idx]] = test_score;
489        }
490    }
491
492    // Calculate confidence level (default 95%)
493    let confidence = confidence_level.unwrap_or(0.95);
494    let _alpha = 1.0 - confidence;
495    let z_score = 1.96; // Approximate 95% confidence interval
496
497    // Calculate statistics for each training size
498    let mut train_scores_mean = Array1::<f64>::zeros(n_train_sizes);
499    let mut test_scores_mean = Array1::<f64>::zeros(n_train_sizes);
500    let mut train_scores_std = Array1::<f64>::zeros(n_train_sizes);
501    let mut test_scores_std = Array1::<f64>::zeros(n_train_sizes);
502    let mut train_scores_lower = Array1::<f64>::zeros(n_train_sizes);
503    let mut train_scores_upper = Array1::<f64>::zeros(n_train_sizes);
504    let mut test_scores_lower = Array1::<f64>::zeros(n_train_sizes);
505    let mut test_scores_upper = Array1::<f64>::zeros(n_train_sizes);
506
507    for size_idx in 0..n_train_sizes {
508        // Extract scores for this training size across all CV folds
509        let train_scores_for_size: Vec<f64> = (0..n_splits)
510            .map(|split_idx| train_scores[[size_idx, split_idx]])
511            .collect();
512        let test_scores_for_size: Vec<f64> = (0..n_splits)
513            .map(|split_idx| test_scores[[size_idx, split_idx]])
514            .collect();
515
516        // Calculate mean and std for training scores
517        let train_mean = train_scores_for_size.iter().sum::<f64>() / n_splits as f64;
518        let train_variance = train_scores_for_size
519            .iter()
520            .map(|&x| (x - train_mean).powi(2))
521            .sum::<f64>()
522            / (n_splits - 1).max(1) as f64;
523        let train_std = train_variance.sqrt();
524        let train_sem = train_std / (n_splits as f64).sqrt(); // Standard error of the mean
525
526        // Calculate mean and std for test scores
527        let test_mean = test_scores_for_size.iter().sum::<f64>() / n_splits as f64;
528        let test_variance = test_scores_for_size
529            .iter()
530            .map(|&x| (x - test_mean).powi(2))
531            .sum::<f64>()
532            / (n_splits - 1).max(1) as f64;
533        let test_std = test_variance.sqrt();
534        let test_sem = test_std / (n_splits as f64).sqrt(); // Standard error of the mean
535
536        // Calculate confidence intervals
537        let train_margin = z_score * train_sem;
538        let test_margin = z_score * test_sem;
539
540        train_scores_mean[size_idx] = train_mean;
541        test_scores_mean[size_idx] = test_mean;
542        train_scores_std[size_idx] = train_std;
543        test_scores_std[size_idx] = test_std;
544        train_scores_lower[size_idx] = train_mean - train_margin;
545        train_scores_upper[size_idx] = train_mean + train_margin;
546        test_scores_lower[size_idx] = test_mean - test_margin;
547        test_scores_upper[size_idx] = test_mean + test_margin;
548    }
549
550    Ok(LearningCurveResult {
551        train_sizes: Array1::from_vec(train_sizes_actual),
552        train_scores,
553        test_scores,
554        train_scores_mean,
555        test_scores_mean,
556        train_scores_std,
557        test_scores_std,
558        train_scores_lower,
559        train_scores_upper,
560        test_scores_lower,
561        test_scores_upper,
562    })
563}
564
/// Validation curve results
///
/// Raw scores are laid out as `(n_params, n_splits)` matrices; the per-value
/// summary statistics are computed across CV splits.
#[derive(Debug, Clone)]
pub struct ValidationCurveResult {
    /// Parameter values used
    pub param_values: Vec<ParameterValue>,
    /// Training scores for each parameter value, shape `(n_params, n_splits)`
    pub train_scores: Array2<f64>,
    /// Validation scores for each parameter value, shape `(n_params, n_splits)`
    pub test_scores: Array2<f64>,
    /// Mean training scores for each parameter value (averaged over splits)
    pub train_scores_mean: Array1<f64>,
    /// Mean validation scores for each parameter value (averaged over splits)
    pub test_scores_mean: Array1<f64>,
    /// Standard deviation of training scores for each parameter value
    pub train_scores_std: Array1<f64>,
    /// Standard deviation of validation scores for each parameter value
    pub test_scores_std: Array1<f64>,
    /// Lower error bar for training scores (mean - std_error)
    pub train_scores_lower: Array1<f64>,
    /// Upper error bar for training scores (mean + std_error)
    pub train_scores_upper: Array1<f64>,
    /// Lower error bar for validation scores (mean - std_error)
    pub test_scores_lower: Array1<f64>,
    /// Upper error bar for validation scores (mean + std_error)
    pub test_scores_upper: Array1<f64>,
}
591
/// Parameter configuration function type
///
/// Takes an estimator and a candidate parameter value and returns a
/// reconfigured estimator (or an error if the value cannot be applied).
pub type ParamConfigFn<E> = Box<dyn Fn(E, &ParameterValue) -> Result<E>>;
594
595/// Compute validation curves for an estimator
596///
597/// Determines training and test scores for a varying parameter value.
598/// This is useful to understand the effect of a specific parameter on
599/// model performance and to detect overfitting/underfitting.
600///
601/// # Arguments
602/// * `estimator` - The estimator to evaluate
603/// * `x` - Training data features
604/// * `y` - Training data targets
605/// * `_param_name` - Name of the parameter being varied (for documentation)
606/// * `param_range` - Parameter values to test
607/// * `param_config` - Function to configure estimator with parameter values
608/// * `cv` - Cross-validation splitter
609/// * `scoring` - Scoring method to use
610/// * `confidence_level` - Confidence level for error bars (default: 0.95 for 95% confidence interval)
611#[allow(clippy::too_many_arguments)]
612pub fn validation_curve<E, F, C>(
613    estimator: E,
614    x: &Array2<Float>,
615    y: &Array1<Float>,
616    _param_name: &str,
617    param_range: Vec<ParameterValue>,
618    param_config: ParamConfigFn<E>,
619    cv: &C,
620    scoring: Option<Scoring>,
621    confidence_level: Option<f64>,
622) -> Result<ValidationCurveResult>
623where
624    E: Clone,
625    F: Clone,
626    E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
627    F: Predict<Array2<Float>, Array1<Float>>,
628    F: Score<Array2<Float>, Array1<Float>, Float = f64>,
629    C: CrossValidator,
630{
631    let scoring = scoring.unwrap_or(Scoring::EstimatorScore);
632    let n_splits = cv.n_splits();
633    let n_params = param_range.len();
634
635    let mut train_scores = Array2::<f64>::zeros((n_params, n_splits));
636    let mut test_scores = Array2::<f64>::zeros((n_params, n_splits));
637
638    // Get CV splits
639    let splits = cv.split(x.nrows(), None);
640
641    for (param_idx, param_value) in param_range.iter().enumerate() {
642        for (split_idx, (train_idx, test_idx)) in splits.iter().enumerate() {
643            // Extract data
644            let x_train = x.select(scirs2_core::ndarray::Axis(0), train_idx);
645            let y_train = y.select(scirs2_core::ndarray::Axis(0), train_idx);
646            let x_test = x.select(scirs2_core::ndarray::Axis(0), test_idx);
647            let y_test = y.select(scirs2_core::ndarray::Axis(0), test_idx);
648
649            // Configure estimator with current parameter value
650            let configured_estimator = param_config(estimator.clone(), param_value)?;
651
652            // Fit estimator
653            let fitted = configured_estimator.fit(&x_train, &y_train)?;
654
655            // Score on training set
656            let train_score = match &scoring {
657                Scoring::EstimatorScore => fitted.score(&x_train, &y_train)?,
658                Scoring::Custom(func) => {
659                    let y_pred = fitted.predict(&x_train)?;
660                    func(&y_train.to_owned(), &y_pred)?
661                }
662                Scoring::Metric(metric_name) => {
663                    let y_pred = fitted.predict(&x_train)?;
664                    // Determine if this is classification or regression based on the data type
665                    if y_train.iter().all(|&x| x.fract() == 0.0) {
666                        // Integer-like values, likely classification
667                        let y_true_int: Array1<i32> = y_train.mapv(|x| x as i32);
668                        let y_pred_int: Array1<i32> = y_pred.mapv(|x| x as i32);
669                        compute_score_for_classification_val(metric_name, &y_true_int, &y_pred_int)?
670                    } else {
671                        // Float values, likely regression
672                        compute_score_for_regression_val(metric_name, &y_train, &y_pred)?
673                    }
674                }
675                Scoring::Scorer(scorer) => {
676                    let y_pred = fitted.predict(&x_train)?;
677                    scorer.score_float(
678                        y_train.as_slice().expect("operation should succeed"),
679                        y_pred.as_slice().expect("operation should succeed"),
680                    )?
681                }
682                Scoring::MultiMetric(_metrics) => {
683                    // For multi-metric, just use the first metric for now
684                    fitted.score(&x_train, &y_train)?
685                }
686            };
687            train_scores[[param_idx, split_idx]] = train_score;
688
689            // Score on test set
690            let test_score = match &scoring {
691                Scoring::EstimatorScore => fitted.score(&x_test, &y_test)?,
692                Scoring::Custom(func) => {
693                    let y_pred = fitted.predict(&x_test)?;
694                    func(&y_test.to_owned(), &y_pred)?
695                }
696                Scoring::Metric(metric_name) => {
697                    let y_pred = fitted.predict(&x_test)?;
698                    // Determine if this is classification or regression based on the data type
699                    if y_test.iter().all(|&x| x.fract() == 0.0) {
700                        // Integer-like values, likely classification
701                        let y_true_int: Array1<i32> = y_test.mapv(|x| x as i32);
702                        let y_pred_int: Array1<i32> = y_pred.mapv(|x| x as i32);
703                        compute_score_for_classification_val(metric_name, &y_true_int, &y_pred_int)?
704                    } else {
705                        // Float values, likely regression
706                        compute_score_for_regression_val(metric_name, &y_test, &y_pred)?
707                    }
708                }
709                Scoring::Scorer(scorer) => {
710                    let y_pred = fitted.predict(&x_test)?;
711                    scorer.score_float(
712                        y_test.as_slice().expect("operation should succeed"),
713                        y_pred.as_slice().expect("operation should succeed"),
714                    )?
715                }
716                Scoring::MultiMetric(_metrics) => {
717                    // For multi-metric, just use the first metric for now
718                    fitted.score(&x_test, &y_test)?
719                }
720            };
721            test_scores[[param_idx, split_idx]] = test_score;
722        }
723    }
724
725    // Calculate confidence level (default 95%)
726    let _confidence = confidence_level.unwrap_or(0.95);
727    let _z_score = 1.96; // Approximate 95% confidence interval
728
729    // Calculate statistics for each parameter value
730    let mut train_scores_mean = Array1::<f64>::zeros(n_params);
731    let mut test_scores_mean = Array1::<f64>::zeros(n_params);
732    let mut train_scores_std = Array1::<f64>::zeros(n_params);
733    let mut test_scores_std = Array1::<f64>::zeros(n_params);
734    let mut train_scores_lower = Array1::<f64>::zeros(n_params);
735    let mut train_scores_upper = Array1::<f64>::zeros(n_params);
736    let mut test_scores_lower = Array1::<f64>::zeros(n_params);
737    let mut test_scores_upper = Array1::<f64>::zeros(n_params);
738
739    for param_idx in 0..n_params {
740        // Extract scores for this parameter value across all CV folds
741        let train_scores_for_param: Vec<f64> = (0..n_splits)
742            .map(|split_idx| train_scores[[param_idx, split_idx]])
743            .collect();
744        let test_scores_for_param: Vec<f64> = (0..n_splits)
745            .map(|split_idx| test_scores[[param_idx, split_idx]])
746            .collect();
747
748        // Calculate mean and std for training scores
749        let train_mean = train_scores_for_param.iter().sum::<f64>() / n_splits as f64;
750        let train_variance = train_scores_for_param
751            .iter()
752            .map(|&x| (x - train_mean).powi(2))
753            .sum::<f64>()
754            / (n_splits - 1).max(1) as f64;
755        let train_std = train_variance.sqrt();
756        let train_sem = train_std / (n_splits as f64).sqrt(); // Standard error of the mean
757
758        // Calculate mean and std for test scores
759        let test_mean = test_scores_for_param.iter().sum::<f64>() / n_splits as f64;
760        let test_variance = test_scores_for_param
761            .iter()
762            .map(|&x| (x - test_mean).powi(2))
763            .sum::<f64>()
764            / (n_splits - 1).max(1) as f64;
765        let test_std = test_variance.sqrt();
766        let test_sem = test_std / (n_splits as f64).sqrt(); // Standard error of the mean
767
768        // Calculate error bars (using standard error for error bars)
769        let train_margin = train_sem;
770        let test_margin = test_sem;
771
772        train_scores_mean[param_idx] = train_mean;
773        test_scores_mean[param_idx] = test_mean;
774        train_scores_std[param_idx] = train_std;
775        test_scores_std[param_idx] = test_std;
776        train_scores_lower[param_idx] = train_mean - train_margin;
777        train_scores_upper[param_idx] = train_mean + train_margin;
778        test_scores_lower[param_idx] = test_mean - test_margin;
779        test_scores_upper[param_idx] = test_mean + test_margin;
780    }
781
782    Ok(ValidationCurveResult {
783        param_values: param_range,
784        train_scores,
785        test_scores,
786        train_scores_mean,
787        test_scores_mean,
788        train_scores_std,
789        test_scores_std,
790        train_scores_lower,
791        train_scores_upper,
792        test_scores_lower,
793        test_scores_upper,
794    })
795}
796
797/// Evaluate the significance of a cross-validated score with permutations
798///
799/// This function tests whether the estimator performs significantly better than
800/// random by computing cross-validation scores on permuted labels.
801#[allow(clippy::too_many_arguments)]
802pub fn permutation_test_score<E, F, C>(
803    estimator: E,
804    x: &Array2<Float>,
805    y: &Array1<Float>,
806    cv: &C,
807    scoring: Option<Scoring>,
808    n_permutations: usize,
809    random_state: Option<u64>,
810    n_jobs: Option<usize>,
811) -> Result<PermutationTestResult>
812where
813    E: Clone,
814    F: Clone,
815    E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
816    F: Predict<Array2<Float>, Array1<Float>>,
817    F: Score<Array2<Float>, Array1<Float>, Float = f64>,
818    C: CrossValidator,
819{
820    use scirs2_core::random::prelude::*;
821    use scirs2_core::random::rngs::StdRng;
822
823    let scoring = scoring.unwrap_or(Scoring::EstimatorScore);
824
825    // Compute original score
826    let original_scores =
827        cross_val_score(estimator.clone(), x, y, cv, Some(scoring.clone()), n_jobs)?;
828    let original_score = original_scores.mean().unwrap_or(0.0);
829
830    // Initialize random number generator
831    let mut rng = if let Some(seed) = random_state {
832        StdRng::seed_from_u64(seed)
833    } else {
834        StdRng::seed_from_u64(42)
835    };
836
837    // Compute permutation scores
838    let mut permutation_scores = Vec::with_capacity(n_permutations);
839
840    for _ in 0..n_permutations {
841        // Create permuted labels
842        let mut y_permuted = y.to_owned();
843        let mut indices: Vec<usize> = (0..y.len()).collect();
844        indices.shuffle(&mut rng);
845
846        for (i, &perm_idx) in indices.iter().enumerate() {
847            y_permuted[i] = y[perm_idx];
848        }
849
850        // Compute score with permuted labels
851        let perm_scores = cross_val_score(
852            estimator.clone(),
853            x,
854            &y_permuted,
855            cv,
856            Some(scoring.clone()),
857            n_jobs,
858        )?;
859        let perm_score = perm_scores.mean().unwrap_or(0.0);
860        permutation_scores.push(perm_score);
861    }
862
863    // Compute p-value
864    let n_better_or_equal = permutation_scores
865        .iter()
866        .filter(|&&score| score >= original_score)
867        .count();
868    let p_value = (n_better_or_equal + 1) as f64 / (n_permutations + 1) as f64;
869
870    Ok(PermutationTestResult {
871        statistic: original_score,
872        pvalue: p_value,
873        permutation_scores: Array1::from_vec(permutation_scores),
874    })
875}
876
/// Result of a permutation significance test
///
/// Returned by `permutation_test_score`; bundles the score on the original
/// labels, the empirical p-value, and every permuted-label score.
#[derive(Debug, Clone)]
pub struct PermutationTestResult {
    /// The cross-validation score obtained on the original (non-permuted) labels
    pub statistic: f64,
    /// Empirical p-value: (1 + #permutations scoring >= `statistic`) / (1 + n_permutations)
    pub pvalue: f64,
    /// Cross-validation score for each label permutation (length = n_permutations)
    pub permutation_scores: Array1<f64>,
}
887
888/// Nested cross-validation for unbiased model evaluation with hyperparameter optimization
889///
890/// This implements nested cross-validation which provides an unbiased estimate of model
891/// performance by using separate CV loops for hyperparameter optimization (inner loop)
892/// and performance estimation (outer loop).
893#[allow(clippy::too_many_arguments)]
894pub fn nested_cross_validate<E, F, C>(
895    estimator: E,
896    x: &Array2<Float>,
897    y: &Array1<Float>,
898    outer_cv: &C,
899    inner_cv: &C,
900    param_grid: &[ParameterValue],
901    param_config: ParamConfigFn<E>,
902    scoring: Option<fn(&Array1<Float>, &Array1<Float>) -> f64>,
903) -> Result<NestedCVResult>
904where
905    E: Clone,
906    F: Clone,
907    E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
908    F: Predict<Array2<Float>, Array1<Float>>,
909    F: Score<Array2<Float>, Array1<Float>, Float = f64>,
910    C: CrossValidator,
911{
912    let outer_splits = outer_cv.split(x.nrows(), None);
913    let mut outer_scores = Vec::with_capacity(outer_splits.len());
914    let mut best_params_per_fold = Vec::with_capacity(outer_splits.len());
915    let mut inner_scores_per_fold = Vec::with_capacity(outer_splits.len());
916
917    for (outer_train_idx, outer_test_idx) in outer_splits {
918        // Extract outer train/test data
919        let outer_train_x = extract_rows(x, &outer_train_idx);
920        let outer_train_y = extract_elements(y, &outer_train_idx);
921        let outer_test_x = extract_rows(x, &outer_test_idx);
922        let outer_test_y = extract_elements(y, &outer_test_idx);
923
924        // Inner cross-validation for hyperparameter optimization
925        let mut best_score = f64::NEG_INFINITY;
926        let mut best_param = param_grid[0].clone();
927        let mut inner_scores = Vec::new();
928
929        for param in param_grid {
930            let param_estimator = param_config(estimator.clone(), param)?;
931
932            // Inner CV evaluation
933            let inner_splits = inner_cv.split(outer_train_x.nrows(), None);
934            let mut param_scores = Vec::new();
935
936            for (inner_train_idx, inner_test_idx) in inner_splits {
937                let inner_train_x = extract_rows(&outer_train_x, &inner_train_idx);
938                let inner_train_y = extract_elements(&outer_train_y, &inner_train_idx);
939                let inner_test_x = extract_rows(&outer_train_x, &inner_test_idx);
940                let inner_test_y = extract_elements(&outer_train_y, &inner_test_idx);
941
942                // Fit and score on inner split
943                let fitted = param_estimator
944                    .clone()
945                    .fit(&inner_train_x, &inner_train_y)?;
946                let predictions = fitted.predict(&inner_test_x)?;
947
948                let score = if let Some(scoring_fn) = scoring {
949                    scoring_fn(&inner_test_y, &predictions)
950                } else {
951                    fitted.score(&inner_test_x, &inner_test_y)?
952                };
953
954                param_scores.push(score);
955            }
956
957            let mean_score = param_scores.iter().sum::<f64>() / param_scores.len() as f64;
958            inner_scores.push(mean_score);
959
960            if mean_score > best_score {
961                best_score = mean_score;
962                best_param = param.clone();
963            }
964        }
965
966        // Train best model on full outer training set and evaluate on outer test set
967        let best_estimator = param_config(estimator.clone(), &best_param)?;
968        let final_fitted = best_estimator.fit(&outer_train_x, &outer_train_y)?;
969        let outer_predictions = final_fitted.predict(&outer_test_x)?;
970
971        let outer_score = if let Some(scoring_fn) = scoring {
972            scoring_fn(&outer_test_y, &outer_predictions)
973        } else {
974            final_fitted.score(&outer_test_x, &outer_test_y)?
975        };
976
977        outer_scores.push(outer_score);
978        best_params_per_fold.push(best_param);
979        inner_scores_per_fold.push(inner_scores);
980    }
981
982    let mean_score = outer_scores.iter().sum::<f64>() / outer_scores.len() as f64;
983    let std_score = {
984        let variance = outer_scores
985            .iter()
986            .map(|&x| (x - mean_score).powi(2))
987            .sum::<f64>()
988            / outer_scores.len() as f64;
989        variance.sqrt()
990    };
991
992    Ok(NestedCVResult {
993        outer_scores: Array1::from_vec(outer_scores),
994        best_params_per_fold,
995        inner_scores_per_fold,
996        mean_outer_score: mean_score,
997        std_outer_score: std_score,
998    })
999}
1000
/// Result of nested cross-validation
///
/// Returned by `nested_cross_validate`; the outer scores are the unbiased
/// performance estimates, while the inner scores record the hyperparameter
/// search that took place inside each outer fold.
#[derive(Debug, Clone)]
pub struct NestedCVResult {
    /// Outer cross-validation scores (unbiased performance estimates), one per outer fold
    pub outer_scores: Array1<f64>,
    /// Best parameter value selected by the inner CV for each outer fold
    pub best_params_per_fold: Vec<ParameterValue>,
    /// Mean inner-CV score per candidate parameter, for each outer fold
    /// (outer index: fold, inner index: position in the parameter grid)
    pub inner_scores_per_fold: Vec<Vec<f64>>,
    /// Mean of `outer_scores`
    pub mean_outer_score: f64,
    /// Standard deviation of `outer_scores`
    pub std_outer_score: f64,
}
1015
1016// Helper functions for data extraction
1017fn extract_rows(arr: &Array2<Float>, indices: &[usize]) -> Array2<Float> {
1018    let mut result = Array2::zeros((indices.len(), arr.ncols()));
1019    for (i, &idx) in indices.iter().enumerate() {
1020        for j in 0..arr.ncols() {
1021            result[[i, j]] = arr[[idx, j]];
1022        }
1023    }
1024    result
1025}
1026
1027fn extract_elements(arr: &Array1<Float>, indices: &[usize]) -> Array1<Float> {
1028    Array1::from_iter(indices.iter().map(|&i| arr[i]))
1029}
1030
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::KFold;
    use scirs2_core::ndarray::array;

    // Mock estimator for testing: "fitting" records the mean of the training
    // targets, and prediction returns that constant for every row.
    #[derive(Clone)]
    struct MockEstimator;

    // Fitted counterpart of `MockEstimator`.
    #[derive(Clone)]
    struct MockFitted {
        train_mean: f64, // mean of the training targets captured at fit time
    }

    impl Fit<Array2<Float>, Array1<Float>> for MockEstimator {
        type Fitted = MockFitted;

        fn fit(self, _x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
            Ok(MockFitted {
                train_mean: y.mean().unwrap_or(0.0),
            })
        }
    }

    impl Predict<Array2<Float>, Array1<Float>> for MockFitted {
        fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
            // Constant prediction: every sample gets the training-set mean.
            Ok(Array1::from_elem(x.nrows(), self.train_mean))
        }
    }

    impl Score<Array2<Float>, Array1<Float>> for MockFitted {
        type Float = Float;

        fn score(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<f64> {
            let y_pred = self.predict(x)?;
            let mse = (y - &y_pred).mapv(|e| e * e).mean().unwrap_or(0.0);
            Ok(1.0 - mse) // Simple R² approximation
        }
    }

    // One score per CV fold, each bounded above by the perfect score of 1.0.
    #[test]
    fn test_cross_val_score() {
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]];
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let estimator = MockEstimator;
        let cv = KFold::new(3);

        let scores =
            cross_val_score(estimator, &x, &y, &cv, None, None).expect("operation should succeed");

        assert_eq!(scores.len(), 3);
        // Scores are 1 - MSE, so they can never exceed 1.0.
        for score in scores.iter() {
            assert!(*score <= 1.0);
        }
    }

    // Out-of-fold predictions cover every sample exactly once.
    #[test]
    fn test_cross_val_predict() {
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]];
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let estimator = MockEstimator;
        let cv = KFold::new(3);

        let predictions =
            cross_val_predict(estimator, &x, &y, &cv, None).expect("operation should succeed");

        assert_eq!(predictions.len(), 6);
        // Each prediction should be the mean of the training fold
        // Since we're using KFold with 3 splits, each test set has 2 samples
        // and each train set has 4 samples
    }

    // Learning curve with explicit training-size fractions: checks the shape
    // of the result, the computed absolute sizes, and the confidence bands.
    #[test]
    fn test_learning_curve() {
        let x = array![
            [1.0],
            [2.0],
            [3.0],
            [4.0],
            [5.0],
            [6.0],
            [7.0],
            [8.0],
            [9.0],
            [10.0]
        ];
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];

        let estimator = MockEstimator;
        let cv = KFold::new(3);

        let result = learning_curve(
            estimator,
            &x,
            &y,
            &cv,
            Some(vec![0.3, 0.6, 1.0]), // 30%, 60%, 100% of training data
            None,
            None, // Use default confidence level
        )
        .expect("operation should succeed");

        // Check dimensions
        assert_eq!(result.train_sizes.len(), 3);
        assert_eq!(result.train_scores.dim(), (3, 3)); // 3 sizes x 3 CV folds
        assert_eq!(result.test_scores.dim(), (3, 3));

        // Check that train sizes are reasonable
        assert_eq!(result.train_sizes[0], 3); // 30% of 10 = 3
        assert_eq!(result.train_sizes[1], 6); // 60% of 10 = 6
        assert_eq!(result.train_sizes[2], 10); // 100% of 10 = 10

        // Training scores should generally be better than test scores for our mock estimator
        let mean_train_score = result
            .train_scores
            .mean()
            .expect("operation should succeed");
        let mean_test_score = result.test_scores.mean().expect("operation should succeed");
        // Our mock estimator predicts the mean, so training should be perfect
        assert!(mean_train_score >= mean_test_score);

        // Verify confidence bands are calculated
        assert_eq!(result.train_scores_mean.len(), 3);
        assert_eq!(result.test_scores_mean.len(), 3);
        assert_eq!(result.train_scores_std.len(), 3);
        assert_eq!(result.test_scores_std.len(), 3);
        assert_eq!(result.train_scores_lower.len(), 3);
        assert_eq!(result.train_scores_upper.len(), 3);
        assert_eq!(result.test_scores_lower.len(), 3);
        assert_eq!(result.test_scores_upper.len(), 3);

        // Verify confidence intervals are sensible (lower < mean < upper)
        for i in 0..3 {
            assert!(result.train_scores_lower[i] <= result.train_scores_mean[i]);
            assert!(result.train_scores_mean[i] <= result.train_scores_upper[i]);
            assert!(result.test_scores_lower[i] <= result.test_scores_mean[i]);
            assert!(result.test_scores_mean[i] <= result.test_scores_upper[i]);
        }
    }

    // Validation curve over a 3-value parameter range; the mock estimator
    // ignores parameters, so scores should barely vary across the range.
    #[test]
    fn test_validation_curve() {
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]];
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let estimator = MockEstimator;
        let cv = KFold::new(3);

        // Mock parameter configuration function
        let param_config: ParamConfigFn<MockEstimator> = Box::new(|estimator, _param_value| {
            // For our mock estimator, parameters don't matter
            Ok(estimator)
        });

        let param_range = vec![
            ParameterValue::Float(0.1),
            ParameterValue::Float(0.5),
            ParameterValue::Float(1.0),
        ];

        let result = validation_curve(
            estimator,
            &x,
            &y,
            "mock_param",
            param_range.clone(),
            param_config,
            &cv,
            None,
            None, // Use default confidence level
        )
        .expect("operation should succeed");

        // Check dimensions
        assert_eq!(result.param_values.len(), 3);
        assert_eq!(result.train_scores.dim(), (3, 3)); // 3 params x 3 CV folds
        assert_eq!(result.test_scores.dim(), (3, 3));

        // Check that parameter values match
        assert_eq!(result.param_values, param_range);

        // For our mock estimator, all parameter values should give similar results
        let train_score_std = {
            let mean = result
                .train_scores
                .mean()
                .expect("operation should succeed");
            let variance = result
                .train_scores
                .mapv(|x| (x - mean).powi(2))
                .mean()
                .expect("operation should succeed");
            variance.sqrt()
        };

        // Standard deviation should be low since our mock estimator ignores parameters
        // But allow for some variation due to different CV folds
        assert!(train_score_std < 2.0);

        // Verify error bars are calculated
        assert_eq!(result.train_scores_mean.len(), 3);
        assert_eq!(result.test_scores_mean.len(), 3);
        assert_eq!(result.train_scores_std.len(), 3);
        assert_eq!(result.test_scores_std.len(), 3);
        assert_eq!(result.train_scores_lower.len(), 3);
        assert_eq!(result.train_scores_upper.len(), 3);
        assert_eq!(result.test_scores_lower.len(), 3);
        assert_eq!(result.test_scores_upper.len(), 3);

        // Verify error bars are sensible (lower <= mean <= upper)
        for i in 0..3 {
            assert!(result.train_scores_lower[i] <= result.train_scores_mean[i]);
            assert!(result.train_scores_mean[i] <= result.train_scores_upper[i]);
            assert!(result.test_scores_lower[i] <= result.test_scores_mean[i]);
            assert!(result.test_scores_mean[i] <= result.test_scores_upper[i]);
        }
    }

    // Passing None for the training sizes should fall back to the six
    // default fractions, in non-decreasing order of absolute size.
    #[test]
    fn test_learning_curve_default_sizes() {
        let x = array![
            [1.0],
            [2.0],
            [3.0],
            [4.0],
            [5.0],
            [6.0],
            [7.0],
            [8.0],
            [9.0],
            [10.0]
        ];
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];

        let estimator = MockEstimator;
        let cv = KFold::new(2);

        let result = learning_curve(
            estimator, &x, &y, &cv, None, // Use default train sizes
            None, None, // Use default confidence level
        )
        .expect("operation should succeed");

        // Should use default sizes: 10%, 30%, 50%, 70%, 90%, 100%
        assert_eq!(result.train_sizes.len(), 6);
        assert_eq!(result.train_scores.dim(), (6, 2)); // 6 sizes x 2 CV folds

        // Check that sizes are increasing
        for i in 1..result.train_sizes.len() {
            assert!(result.train_sizes[i] >= result.train_sizes[i - 1]);
        }
    }

    // Seeded permutation test: checks p-value range, score finiteness, and
    // that the reported p-value matches the (k+1)/(n+1) formula exactly.
    #[test]
    fn test_permutation_test_score() {
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]];
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        let estimator = MockEstimator;
        let cv = KFold::new(4);

        let result = permutation_test_score(
            estimator,
            &x,
            &y,
            &cv,
            None,
            10, // 10 permutations
            Some(42),
            None,
        )
        .expect("operation should succeed");

        // Check that we got reasonable results
        assert!(result.pvalue >= 0.0 && result.pvalue <= 1.0);
        assert_eq!(result.permutation_scores.len(), 10);

        // For our mock estimator, the original score should be reasonably good
        // compared to permuted scores
        assert!(result.statistic.is_finite());

        // Permutation scores should all be finite
        for &score in result.permutation_scores.iter() {
            assert!(score.is_finite());
        }

        // P-value should be calculated correctly (at least one score >= original)
        let n_better = result
            .permutation_scores
            .iter()
            .filter(|&&score| score >= result.statistic)
            .count();
        let expected_p = (n_better + 1) as f64 / 11.0; // 10 permutations + 1
        assert!((result.pvalue - expected_p).abs() < 1e-10);
    }

    // Nested CV with a 3-fold outer and 2-fold inner loop over 3 candidate
    // parameters; verifies shapes, finiteness, and the mean/std summaries.
    #[test]
    fn test_nested_cross_validate() {
        let x = array![
            [1.0],
            [2.0],
            [3.0],
            [4.0],
            [5.0],
            [6.0],
            [7.0],
            [8.0],
            [9.0],
            [10.0]
        ];
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];

        let estimator = MockEstimator;
        let outer_cv = KFold::new(3);
        let inner_cv = KFold::new(2);

        // Mock parameter configuration function
        let param_config: ParamConfigFn<MockEstimator> = Box::new(|estimator, _param_value| {
            // For our mock estimator, parameters don't matter
            Ok(estimator)
        });

        let param_grid = vec![
            ParameterValue::Float(0.1),
            ParameterValue::Float(0.5),
            ParameterValue::Float(1.0),
        ];

        let result = nested_cross_validate(
            estimator,
            &x,
            &y,
            &outer_cv,
            &inner_cv,
            &param_grid,
            param_config,
            None,
        )
        .expect("operation should succeed");

        // Check dimensions
        assert_eq!(result.outer_scores.len(), 3); // 3 outer folds
        assert_eq!(result.best_params_per_fold.len(), 3);
        assert_eq!(result.inner_scores_per_fold.len(), 3);

        // Each inner fold should have scores for all parameters
        for inner_scores in &result.inner_scores_per_fold {
            assert_eq!(inner_scores.len(), 3); // 3 parameters
        }

        // Check that outer scores are finite
        for &score in result.outer_scores.iter() {
            assert!(score.is_finite());
        }

        // Check that mean and std are calculated correctly
        let manual_mean =
            result.outer_scores.iter().sum::<f64>() / result.outer_scores.len() as f64;
        assert!((result.mean_outer_score - manual_mean).abs() < 1e-10);

        assert!(result.std_outer_score >= 0.0);
        assert!(result.std_outer_score.is_finite());
    }
}