use crate::error::MetricsError;
use scirs2_core::ndarray::{Array1, ArrayBase, Data, Dimension};
/// Normalised learning-curve data returned by [`learning_curve`].
pub type LearningCurveResult = (
    Array1<usize>, // absolute training-set sizes
    Array1<f64>,   // training scores converted to f64
    Array1<f64>,   // test scores converted to f64
);
/// Normalises the raw inputs of a learning-curve computation.
///
/// Converts every training-set size into an absolute sample count and every
/// train/test score into `f64`. The three results are returned as a
/// [`LearningCurveResult`] in the order `(sizes, train_scores, test_scores)`.
///
/// # Arguments
///
/// * `train_sizes` - Training-set sizes. A value in `(0, 1]` is treated as a
///   fraction of `total_examples` and rounded to the nearest integer. A value
///   greater than `1` is treated as an absolute count, with any fractional
///   part truncated.
/// * `train_scores` - Scores on the training data, one per training size.
/// * `test_scores` - Scores on held-out data, one per training size.
/// * `total_examples` - Number of examples that fractional sizes refer to.
///
/// # Errors
///
/// * [`MetricsError::ShapeMismatch`] if the three inputs do not all contain the
///   same number of elements. `shape1` is the shape of `train_sizes`, and
///   `shape2` is the shape of the first score array whose length differs.
/// * [`MetricsError::InvalidArgument`] if any training size is NaN, infinite,
///   or not strictly positive.
#[allow(dead_code)]
pub fn learning_curve<S1, S2, S3, D1, D2, D3>(
    train_sizes: &ArrayBase<S1, D1>,
    train_scores: &ArrayBase<S2, D2>,
    test_scores: &ArrayBase<S3, D3>,
    total_examples: usize,
) -> Result<LearningCurveResult, MetricsError>
where
    S1: Data,
    S2: Data,
    S3: Data,
    S1::Elem: PartialOrd + Clone + Into<f64>,
    S2::Elem: Clone + Into<f64>,
    S3::Elem: Clone + Into<f64>,
    D1: Dimension,
    D2: Dimension,
    D3: Dimension,
{
    let expected_len = train_sizes.len();

    // Report the shape of the array that actually disagrees. Otherwise a bad
    // `test_scores` would be reported with `train_scores`' (matching) shape.
    let mismatched_shape = if train_scores.len() != expected_len {
        Some(train_scores.shape())
    } else if test_scores.len() != expected_len {
        Some(test_scores.shape())
    } else {
        None
    };
    if let Some(other_shape) = mismatched_shape {
        return Err(MetricsError::ShapeMismatch {
            shape1: format!("{:?}", train_sizes.shape()),
            shape2: format!("{:?}", other_shape),
        });
    }

    let train_sizes_abs = train_sizes
        .iter()
        .map(|x| {
            let val: f64 = x.clone().into();
            // NaN fails every comparison and `f64 as usize` saturates (NaN -> 0,
            // +inf -> usize::MAX). Reject non-finite values explicitly instead
            // of letting them become nonsensical sizes.
            if !val.is_finite() {
                return Err(MetricsError::InvalidArgument(
                    "Training sizes must be finite".to_string(),
                ));
            }
            if val <= 0.0 {
                return Err(MetricsError::InvalidArgument(
                    "Training sizes must be positive".to_string(),
                ));
            }
            if val <= 1.0 {
                // Fraction of the full dataset.
                Ok((val * total_examples as f64).round() as usize)
            } else {
                // Absolute count; any fractional part is truncated.
                Ok(val as usize)
            }
        })
        .collect::<Result<Vec<usize>, MetricsError>>()?;

    let train_scores_f64: Vec<f64> = train_scores.iter().map(|x| x.clone().into()).collect();
    let test_scores_f64: Vec<f64> = test_scores.iter().map(|x| x.clone().into()).collect();

    Ok((
        Array1::from_vec(train_sizes_abs),
        Array1::from_vec(train_scores_f64),
        Array1::from_vec(test_scores_f64),
    ))
}