use crate::learner::StreamingLearner;
use crate::metrics::regression::RegressionMetrics;
use crate::sample::Observation;
/// Controls which incoming samples are withheld from training.
///
/// Every sample is always evaluated first (prequential "test-then-train");
/// the strategy only decides whether the sample is ALSO used to train the
/// model afterwards.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum HoldoutStrategy {
/// Train on every sample; nothing is held out.
None,
/// Hold out every `period`-th sample. The sample counter starts at 0,
/// so the very first sample is held out. `period` must be > 0
/// (enforced in `ProgressiveValidator::new`).
Periodic {
period: usize,
},
/// Hold out each sample independently with probability
/// `holdout_fraction` (must lie in [0.0, 1.0]), drawn from a
/// deterministic xorshift64 PRNG seeded with `seed` (must be non-zero).
Random {
holdout_fraction: f64,
seed: u64,
},
}
use irithyll_core::rng::xorshift64;
/// Prequential (test-then-train) validator: each sample is scored against
/// the model's prediction before it may be used for training, so the
/// metrics are always out-of-sample.
#[derive(Debug, Clone)]
pub struct ProgressiveValidator {
/// Strategy deciding which samples are withheld from training.
holdout: HoldoutStrategy,
/// Running regression metrics over ALL predictions (held-out or not).
regression: RegressionMetrics,
/// Total samples passed through `step`.
samples_seen: u64,
/// Samples whose prediction was scored; currently always equals
/// `samples_seen`, kept separate as an independent counter.
samples_evaluated: u64,
/// Samples actually forwarded to `model.train`.
samples_trained: u64,
/// xorshift64 state; initialized from `Random { seed }`, 1 otherwise.
rng_state: u64,
}
impl ProgressiveValidator {
    /// Builds a validator for `holdout`, validating strategy parameters
    /// up front so misconfiguration fails fast.
    ///
    /// # Panics
    ///
    /// * `Periodic` with `period == 0`.
    /// * `Random` with `holdout_fraction` outside `[0.0, 1.0]`.
    /// * `Random` with `seed == 0` (xorshift64 would be stuck at zero).
    pub fn new(holdout: HoldoutStrategy) -> Self {
        let rng_state = match &holdout {
            // Strategies that never draw random numbers still get a
            // non-zero state so the xorshift64 generator stays valid.
            HoldoutStrategy::None => 1,
            HoldoutStrategy::Periodic { period } => {
                assert!(*period > 0, "holdout period must be > 0");
                1
            }
            HoldoutStrategy::Random {
                holdout_fraction,
                seed,
            } => {
                assert!(
                    (0.0..=1.0).contains(holdout_fraction),
                    "holdout_fraction must be in [0.0, 1.0], got {}",
                    holdout_fraction
                );
                assert!(*seed != 0, "xorshift64 seed must be non-zero");
                *seed
            }
        };
        Self {
            holdout,
            regression: RegressionMetrics::new(),
            samples_seen: 0,
            samples_evaluated: 0,
            samples_trained: 0,
            rng_state,
        }
    }

    /// Prequential step: score the model's prediction for `features`
    /// against `target`, then (unless held out) train on the sample.
    ///
    /// Returns the prediction made BEFORE any training occurred.
    pub fn step(&mut self, model: &mut dyn StreamingLearner, features: &[f64], target: f64) -> f64 {
        // Predict first so the metrics never score a sample the model has
        // already been trained on.
        let prediction = model.predict(features);
        self.regression.update(target, prediction);
        self.samples_evaluated += 1;
        // Decide BEFORE bumping `samples_seen`: the periodic rule keys
        // off the pre-increment counter (so sample 0 is held out).
        if self.admit_for_training() {
            model.train(features, target);
            self.samples_trained += 1;
        }
        self.samples_seen += 1;
        prediction
    }

    /// True when the current sample should be trained on, false when it
    /// is held out. Advances the PRNG for the `Random` strategy.
    fn admit_for_training(&mut self) -> bool {
        match &self.holdout {
            HoldoutStrategy::None => true,
            HoldoutStrategy::Periodic { period } => self.samples_seen % (*period as u64) != 0,
            HoldoutStrategy::Random {
                holdout_fraction, ..
            } => {
                let threshold = *holdout_fraction;
                // Map the 64-bit draw onto [0, 1]; hold out whenever it
                // falls below the requested fraction.
                let draw = xorshift64(&mut self.rng_state) as f64 / u64::MAX as f64;
                draw >= threshold
            }
        }
    }

    /// Runs `step` over every observation in `data`, in order.
    pub fn evaluate<O: Observation>(&mut self, model: &mut dyn StreamingLearner, data: &[O]) {
        data.iter().for_each(|obs| {
            self.step(model, obs.features(), obs.target());
        });
    }

    /// All-time prediction quality across every sample (holdout included).
    pub fn regression_metrics(&self) -> &RegressionMetrics {
        &self.regression
    }

    /// Total samples fed through `step`.
    pub fn samples_seen(&self) -> u64 {
        self.samples_seen
    }

    /// Samples whose prediction was scored (one per `step` call).
    pub fn samples_evaluated(&self) -> u64 {
        self.samples_evaluated
    }

    /// Samples the model was actually trained on.
    pub fn samples_trained(&self) -> u64 {
        self.samples_trained
    }

    /// Observed holdout rate: fraction of seen samples NOT trained on.
    /// Returns 0.0 before any samples have arrived.
    pub fn holdout_fraction_actual(&self) -> f64 {
        match self.samples_seen {
            0 => 0.0,
            seen => 1.0 - self.samples_trained as f64 / seen as f64,
        }
    }

    /// Clears metrics and counters and rewinds the PRNG to its original
    /// seed, so a fresh run reproduces the same holdout pattern.
    pub fn reset(&mut self) {
        self.regression.reset();
        self.samples_seen = 0;
        self.samples_evaluated = 0;
        self.samples_trained = 0;
        // `seed` is Copy, so this binds by copy without moving `holdout`.
        if let HoldoutStrategy::Random { seed, .. } = self.holdout {
            self.rng_state = seed;
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::learner::StreamingLearner;
use crate::sample::Sample;
// Constant-prediction learner that only counts training calls; lets the
// tests observe exactly which samples the validator forwarded to training.
struct MockLearner {
prediction: f64,
n: u64,
}
impl MockLearner {
fn new(prediction: f64) -> Self {
Self { prediction, n: 0 }
}
}
// NOTE(review): the validator calls `model.train(...)` while this mock only
// implements `train_one`; presumably `train` is a provided trait method
// forwarding with weight 1.0 — confirm against the StreamingLearner trait.
impl StreamingLearner for MockLearner {
fn train_one(&mut self, _features: &[f64], _target: f64, _weight: f64) {
self.n += 1;
}
fn predict(&self, _features: &[f64]) -> f64 {
self.prediction
}
fn n_samples_seen(&self) -> u64 {
self.n
}
fn reset(&mut self) {
self.n = 0;
}
}
// With no holdout, every sample is both evaluated and trained on.
#[test]
fn holdout_none_trains_on_everything() {
let mut model = MockLearner::new(0.0);
let mut validator = ProgressiveValidator::new(HoldoutStrategy::None);
for i in 0..100 {
validator.step(&mut model, &[i as f64], i as f64);
}
assert_eq!(validator.samples_seen(), 100);
assert_eq!(validator.samples_evaluated(), 100);
assert_eq!(validator.samples_trained(), 100);
assert!(
(validator.holdout_fraction_actual() - 0.0).abs() < 1e-12,
"holdout fraction should be 0.0 for None strategy"
);
assert_eq!(model.n_samples_seen(), 100);
}
// period=5 over 100 samples holds out indices 0,5,...,95 => 20 held out.
#[test]
fn holdout_periodic_holds_out_every_nth() {
let mut model = MockLearner::new(0.0);
let holdout = HoldoutStrategy::Periodic { period: 5 };
let mut validator = ProgressiveValidator::new(holdout);
for i in 0..100 {
validator.step(&mut model, &[i as f64], i as f64);
}
assert_eq!(validator.samples_seen(), 100);
assert_eq!(validator.samples_evaluated(), 100);
assert_eq!(validator.samples_trained(), 80);
assert_eq!(model.n_samples_seen(), 80);
let actual = validator.holdout_fraction_actual();
assert!(
(actual - 0.2).abs() < 0.01,
"expected holdout ~0.2, got {}",
actual
);
}
// Statistical check: over 10k samples the realized holdout rate should be
// close to the configured 0.3 (loose 0.05 tolerance for PRNG variance).
#[test]
fn holdout_random_approximates_target_fraction() {
let mut model = MockLearner::new(0.0);
let holdout = HoldoutStrategy::Random {
holdout_fraction: 0.3,
seed: 42,
};
let mut validator = ProgressiveValidator::new(holdout);
for i in 0..10_000 {
validator.step(&mut model, &[i as f64], i as f64);
}
assert_eq!(validator.samples_seen(), 10_000);
assert_eq!(validator.samples_evaluated(), 10_000);
let actual = validator.holdout_fraction_actual();
assert!(
(actual - 0.3).abs() < 0.05,
"expected holdout ~0.3, got {} (tolerance 0.05)",
actual
);
}
// The per-sample train/holdout pattern must be a pure function of the seed.
#[test]
fn holdout_random_deterministic_with_same_seed() {
let run = |seed: u64| -> Vec<bool> {
let mut model = MockLearner::new(0.0);
let holdout = HoldoutStrategy::Random {
holdout_fraction: 0.5,
seed,
};
let mut validator = ProgressiveValidator::new(holdout);
let mut pattern = Vec::new();
for i in 0..50 {
// Detect "trained this step" via the counter delta.
let prev_trained = validator.samples_trained();
validator.step(&mut model, &[i as f64], i as f64);
pattern.push(validator.samples_trained() > prev_trained);
}
pattern
};
let run1 = run(12345);
let run2 = run(12345);
assert_eq!(
run1, run2,
"same seed should produce identical holdout patterns"
);
let run3 = run(99999);
assert_ne!(
run1, run3,
"different seeds should produce different holdout patterns"
);
}
// `step` must return the model's (pre-training) prediction to the caller.
#[test]
fn step_returns_prediction() {
let mut model = MockLearner::new(7.5);
let mut validator = ProgressiveValidator::new(HoldoutStrategy::None);
let pred = validator.step(&mut model, &[1.0], 5.0);
assert!(
(pred - 7.5).abs() < 1e-12,
"step should return the model's prediction, got {}",
pred
);
}
// `evaluate` is just `step` applied to each observation in a slice.
#[test]
fn evaluate_over_observations() {
let mut model = MockLearner::new(0.0);
let mut validator = ProgressiveValidator::new(HoldoutStrategy::None);
let data: Vec<Sample> = (0..30)
.map(|i| Sample::new(vec![i as f64], i as f64))
.collect();
validator.evaluate(&mut model, &data);
assert_eq!(validator.samples_seen(), 30);
assert_eq!(validator.samples_evaluated(), 30);
assert_eq!(validator.samples_trained(), 30);
}
// After `reset` all counters and metrics are back to the pristine state.
#[test]
fn reset_clears_all_state() {
let mut model = MockLearner::new(0.0);
let holdout = HoldoutStrategy::Random {
holdout_fraction: 0.2,
seed: 42,
};
let mut validator = ProgressiveValidator::new(holdout);
for i in 0..50 {
validator.step(&mut model, &[i as f64], i as f64);
}
assert!(validator.samples_seen() > 0);
validator.reset();
assert_eq!(validator.samples_seen(), 0);
assert_eq!(validator.samples_evaluated(), 0);
assert_eq!(validator.samples_trained(), 0);
assert_eq!(validator.regression_metrics().n_samples(), 0);
assert!(
(validator.holdout_fraction_actual() - 0.0).abs() < 1e-12,
"holdout fraction should be 0.0 after reset"
);
}
// `reset` must also rewind the PRNG so a repeat run replays the same
// random holdout decisions.
#[test]
fn reset_restores_prng_seed() {
let holdout = HoldoutStrategy::Random {
holdout_fraction: 0.5,
seed: 42,
};
let mut model1 = MockLearner::new(0.0);
let mut validator = ProgressiveValidator::new(holdout);
for i in 0..20 {
validator.step(&mut model1, &[i as f64], i as f64);
}
let trained_first_run = validator.samples_trained();
validator.reset();
let mut model2 = MockLearner::new(0.0);
for i in 0..20 {
validator.step(&mut model2, &[i as f64], i as f64);
}
let trained_second_run = validator.samples_trained();
assert_eq!(
trained_first_run, trained_second_run,
"reset should restore PRNG state, producing identical holdout patterns"
);
}
// Held-out samples still count toward the regression metrics — holdout
// only skips training, never evaluation.
#[test]
fn metrics_always_recorded_even_for_holdout() {
let mut model = MockLearner::new(0.0);
let holdout = HoldoutStrategy::Periodic { period: 2 };
let mut validator = ProgressiveValidator::new(holdout);
for i in 0..10 {
validator.step(&mut model, &[i as f64], i as f64);
}
assert_eq!(validator.samples_evaluated(), 10);
assert_eq!(validator.regression_metrics().n_samples(), 10);
assert_eq!(validator.samples_trained(), 5);
}
#[test]
#[should_panic(expected = "holdout period must be > 0")]
fn periodic_zero_period_panics() {
ProgressiveValidator::new(HoldoutStrategy::Periodic { period: 0 });
}
#[test]
#[should_panic(expected = "holdout_fraction must be in [0.0, 1.0]")]
fn random_invalid_fraction_panics() {
ProgressiveValidator::new(HoldoutStrategy::Random {
holdout_fraction: 1.5,
seed: 42,
});
}
#[test]
#[should_panic(expected = "xorshift64 seed must be non-zero")]
fn random_zero_seed_panics() {
ProgressiveValidator::new(HoldoutStrategy::Random {
holdout_fraction: 0.5,
seed: 0,
});
}
// Degenerate fraction 0.0: every draw in [0,1] is >= 0.0, so all train.
#[test]
fn holdout_fraction_zero_means_all_train() {
let mut model = MockLearner::new(0.0);
let holdout = HoldoutStrategy::Random {
holdout_fraction: 0.0,
seed: 42,
};
let mut validator = ProgressiveValidator::new(holdout);
for i in 0..100 {
validator.step(&mut model, &[i as f64], i as f64);
}
assert_eq!(validator.samples_trained(), 100);
}
// Degenerate fraction 1.0: draws below 1.0 are held out, so (essentially)
// nothing trains.
#[test]
fn holdout_fraction_one_means_none_train() {
let mut model = MockLearner::new(0.0);
let holdout = HoldoutStrategy::Random {
holdout_fraction: 1.0,
seed: 42,
};
let mut validator = ProgressiveValidator::new(holdout);
for i in 0..100 {
validator.step(&mut model, &[i as f64], i as f64);
}
assert_eq!(validator.samples_trained(), 0);
assert_eq!(model.n_samples_seen(), 0);
}
}