use super::result::ValidationResult;
use super::strategy::ValidationStrategy;
/// Positional train/test validator.
///
/// The input slice is cut (without shuffling) into a leading "train" portion
/// holding `split_ratio` of the elements and a trailing "test" portion holding
/// the rest. Scores computed on those portions are handed to a pluggable
/// [`ValidationStrategy`], which produces the [`ValidationResult`].
pub struct Validator<T> {
    /// Fraction of the data placed in the train portion; always in `0.0..=1.0`.
    split_ratio: f64,
    /// Strategy that judges `(baseline score, current score, test length)`.
    strategy: Box<dyn ValidationStrategy>,
    /// Binds the validator to the element type of the slices it splits;
    /// no `T` value is actually stored.
    _marker: std::marker::PhantomData<T>,
}

impl<T> Validator<T> {
    /// Creates a validator that assigns `split_ratio` of the data to the train
    /// portion and the remainder to the test portion.
    ///
    /// A ratio of `0.0` yields an always-empty train portion and `1.0` an
    /// always-empty test portion.
    ///
    /// # Panics
    ///
    /// Panics if `split_ratio` is outside `0.0..=1.0` (NaN included, since a
    /// NaN is never contained in the range).
    pub fn new(split_ratio: f64, strategy: Box<dyn ValidationStrategy>) -> Self {
        assert!(
            (0.0..=1.0).contains(&split_ratio),
            "split_ratio must be between 0.0 and 1.0"
        );
        Self {
            split_ratio,
            strategy,
            _marker: std::marker::PhantomData,
        }
    }

    /// Shorthand for [`Validator::new`] with an 80% train / 20% test split.
    pub fn with_80_20_split(strategy: Box<dyn ValidationStrategy>) -> Self {
        Self::new(0.8, strategy)
    }

    /// Shorthand for [`Validator::new`] with a 70% train / 30% test split.
    pub fn with_70_30_split(strategy: Box<dyn ValidationStrategy>) -> Self {
        Self::new(0.7, strategy)
    }

    /// Scores the train portion with `baseline_fn` and the test portion with
    /// `evaluate_fn`, then lets the strategy judge the two scores.
    ///
    /// The strategy also receives the number of test elements. Either portion
    /// can be empty (short input, or a ratio of `0.0`/`1.0`), so the closures
    /// must cope with empty slices.
    pub fn validate<F, G>(&self, data: &[T], baseline_fn: F, evaluate_fn: G) -> ValidationResult
    where
        F: FnOnce(&[T]) -> f64,
        G: FnOnce(&[T]) -> f64,
    {
        let (train, test) = self.split(data);
        let baseline = baseline_fn(train);
        let current = evaluate_fn(test);
        self.strategy.evaluate(baseline, current, test.len())
    }

    /// Like [`Validator::validate`], but compares the test-portion score
    /// against a caller-supplied `baseline` instead of deriving one from the
    /// train portion (which is ignored here).
    pub fn validate_with_baseline<F>(
        &self,
        data: &[T],
        baseline: f64,
        evaluate_fn: F,
    ) -> ValidationResult
    where
        F: FnOnce(&[T]) -> f64,
    {
        let (_, test) = self.split(data);
        let current = evaluate_fn(test);
        self.strategy.evaluate(baseline, current, test.len())
    }

    /// Cuts `data` into `(train, test)` at index `floor(len * split_ratio)`.
    ///
    /// The explicit lifetime ties both halves to `data`, not to `self`.
    fn split<'a>(&self, data: &'a [T]) -> (&'a [T], &'a [T]) {
        let scaled = data.len() as f64 * self.split_ratio;
        // Ratios such as 0.7 are not exactly representable in binary, so the
        // product can land just below the intended integer
        // (90.0 * 0.7 == 62.99999999999999) and plain truncation would drop an
        // element from the train portion. Nudging by a few ULPs before
        // flooring absorbs that representation/rounding error.
        let split_idx = (scaled + scaled * 4.0 * f64::EPSILON).floor() as usize;
        // Defensive clamp: the index must never exceed the slice length.
        let split_idx = split_idx.min(data.len());
        data.split_at(split_idx)
    }

    /// Returns the name reported by the configured strategy.
    pub fn strategy_name(&self) -> &str {
        self.strategy.name()
    }

    /// Returns the train fraction this validator was built with.
    pub fn split_ratio(&self) -> f64 {
        self.split_ratio
    }
}
#[cfg(test)]
mod tests {
    use super::super::strategy::{Absolute, Improvement, NoRegression};
    use super::*;

    // Arithmetic mean used as the scoring function throughout these tests.
    fn mean(values: &[f64]) -> f64 {
        values.iter().sum::<f64>() / values.len() as f64
    }

    // Builds `len` samples where the first `ones` are 1.0 and the rest 0.0.
    fn ones_then_zeros(len: usize, ones: usize) -> Vec<f64> {
        (0..len).map(|i| if i < ones { 1.0 } else { 0.0 }).collect()
    }

    #[test]
    fn test_validator_split() {
        let data: Vec<i32> = (0..100).collect();
        let validator: Validator<i32> = Validator::new(0.8, Box::new(NoRegression::new()));
        let (train, test) = validator.split(&data);
        assert_eq!((train.len(), test.len()), (80, 20));
    }

    #[test]
    fn test_validator_validate() {
        // Train portion: 64 of 80 are ones (0.8); test portion: 18 of 20 (0.9).
        let mut data = ones_then_zeros(80, 64);
        data.extend(ones_then_zeros(20, 18));
        let validator: Validator<f64> =
            Validator::with_80_20_split(Box::new(NoRegression::new()));
        let result = validator.validate(&data, mean, mean);
        assert!(result.passed);
        assert!((result.baseline - 0.8).abs() < 0.01);
        assert!((result.current - 0.9).abs() < 0.01);
    }

    #[test]
    fn test_validator_with_baseline() {
        // Test portion (indices 80..100) holds 5 ones, scoring 0.25 against 0.7.
        let data = ones_then_zeros(100, 85);
        let validator: Validator<f64> =
            Validator::with_80_20_split(Box::new(Improvement::ten_percent()));
        let result = validator.validate_with_baseline(&data, 0.7, mean);
        assert!(!result.passed);
    }

    #[test]
    fn test_validator_absolute_strategy() {
        // Test portion (indices 80..100) holds 10 ones, scoring 0.5.
        let data = ones_then_zeros(100, 90);
        let validator: Validator<f64> =
            Validator::with_80_20_split(Box::new(Absolute::eighty_percent()));
        let result = validator.validate_with_baseline(&data, 0.5, mean);
        assert!(!result.passed);
    }

    #[test]
    fn test_validator_empty_data() {
        let data: Vec<i32> = Vec::new();
        let validator: Validator<i32> = Validator::new(0.8, Box::new(NoRegression::new()));
        let (train, test) = validator.split(&data);
        assert!(train.is_empty() && test.is_empty());
    }

    #[test]
    #[should_panic(expected = "split_ratio must be between")]
    fn test_validator_invalid_ratio() {
        let _: Validator<i32> = Validator::new(1.5, Box::new(NoRegression::new()));
    }
}