use crate::learner::StreamingLearner;
use crate::metrics::ewma::EwmaRegressionMetrics;
use crate::metrics::regression::RegressionMetrics;
use crate::metrics::rolling::RollingRegressionMetrics;
use crate::sample::Observation;
/// Controls which samples a `PrequentialEvaluator` scores.
///
/// Every sample is always used for training; these settings only decide
/// which predictions are recorded in the metrics.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PrequentialConfig {
    /// Number of initial samples that are trained on but never evaluated.
    pub warmup: usize,
    /// Evaluate only samples whose zero-based index is a multiple of this
    /// value. The interval is counted from the first sample, not from the end
    /// of warmup. Must be greater than zero.
    pub step_interval: usize,
}
impl Default for PrequentialConfig {
fn default() -> Self {
Self {
warmup: 0,
step_interval: 1,
}
}
}
/// Prequential ("test-then-train") evaluator for streaming regressors.
///
/// Each sample passed to `step` is first used to score the model's
/// prediction and then to train the model. Cumulative regression metrics are
/// always tracked; rolling-window and EWMA metrics are opt-in through the
/// `with_rolling_window` / `with_ewma` builder methods.
#[derive(Debug, Clone)]
pub struct PrequentialEvaluator {
    // Warmup / step-interval settings that decide which samples are scored.
    config: PrequentialConfig,
    // Metrics accumulated over every evaluated sample.
    regression: RegressionMetrics,
    // Windowed metrics; `None` unless enabled via `with_rolling_window`.
    rolling: Option<RollingRegressionMetrics>,
    // Exponentially weighted metrics; `None` unless enabled via `with_ewma`.
    ewma: Option<EwmaRegressionMetrics>,
    // Total samples passed through `step`, whether evaluated or not.
    samples_seen: u64,
    // Samples whose (target, prediction) pair was recorded in the metrics.
    samples_evaluated: u64,
}
impl PrequentialEvaluator {
    /// Creates an evaluator with the default configuration (no warmup, every
    /// sample evaluated) and no rolling-window or EWMA trackers.
    pub fn new() -> Self {
        Self::with_config(PrequentialConfig::default())
    }

    /// Creates an evaluator with the given configuration and no
    /// rolling-window or EWMA trackers.
    ///
    /// # Panics
    ///
    /// Panics if `config.step_interval` is zero.
    pub fn with_config(config: PrequentialConfig) -> Self {
        assert!(config.step_interval > 0, "step_interval must be > 0");
        Self {
            config,
            regression: RegressionMetrics::new(),
            rolling: None,
            ewma: None,
            samples_seen: 0,
            samples_evaluated: 0,
        }
    }

    /// Enables rolling-window metrics sized by `window_size`, replacing any
    /// rolling tracker configured earlier.
    #[must_use = "builder methods consume the evaluator and return the updated one"]
    pub fn with_rolling_window(mut self, window_size: usize) -> Self {
        self.rolling = Some(RollingRegressionMetrics::new(window_size));
        self
    }

    /// Enables exponentially weighted metrics with the given `span`, replacing
    /// any EWMA tracker configured earlier.
    #[must_use = "builder methods consume the evaluator and return the updated one"]
    pub fn with_ewma(mut self, span: usize) -> Self {
        self.ewma = Some(EwmaRegressionMetrics::new(span));
        self
    }

    /// Processes one sample test-then-train.
    ///
    /// The model first predicts on `features`. If this sample is selected for
    /// evaluation, the `(target, prediction)` pair is recorded. The model is
    /// then trained on the sample. A sample is selected when at least
    /// `warmup` samples have already been seen and its zero-based index is a
    /// multiple of `step_interval`. The model is trained on every sample
    /// regardless of selection.
    ///
    /// Returns the prediction made *before* training on this sample.
    pub fn step(
        &mut self,
        model: &mut dyn StreamingLearner,
        features: &[f64],
        target: f64,
    ) -> f64 {
        // Predict before training so the recorded error reflects unseen data.
        let prediction = model.predict(features);
        if self.should_evaluate() {
            self.record(target, prediction);
        }
        model.train(features, target);
        self.samples_seen += 1;
        prediction
    }

    /// Whether the sample at index `samples_seen` is past warmup and falls on
    /// the evaluation grid. The grid is anchored at index 0, not at the end of
    /// warmup.
    fn should_evaluate(&self) -> bool {
        // usize -> u64 is lossless on every platform Rust supports (<= 64-bit).
        let past_warmup = self.samples_seen >= self.config.warmup as u64;
        let on_step = self.samples_seen % self.config.step_interval as u64 == 0;
        past_warmup && on_step
    }

    /// Feeds one `(target, prediction)` pair into every enabled tracker and
    /// bumps the evaluated-sample counter.
    fn record(&mut self, target: f64, prediction: f64) {
        self.regression.update(target, prediction);
        if let Some(rolling) = self.rolling.as_mut() {
            rolling.update(target, prediction);
        }
        if let Some(ewma) = self.ewma.as_mut() {
            ewma.update(target, prediction);
        }
        self.samples_evaluated += 1;
    }

    /// Runs [`step`](Self::step) over every observation in `data`, in order,
    /// discarding the individual predictions.
    pub fn evaluate<O: Observation>(&mut self, model: &mut dyn StreamingLearner, data: &[O]) {
        for obs in data {
            self.step(model, obs.features(), obs.target());
        }
    }

    /// Cumulative metrics over every evaluated sample.
    pub fn regression_metrics(&self) -> &RegressionMetrics {
        &self.regression
    }

    /// Rolling-window metrics, or `None` if `with_rolling_window` was not used.
    pub fn rolling_metrics(&self) -> Option<&RollingRegressionMetrics> {
        self.rolling.as_ref()
    }

    /// EWMA metrics, or `None` if `with_ewma` was not used.
    pub fn ewma_metrics(&self) -> Option<&EwmaRegressionMetrics> {
        self.ewma.as_ref()
    }

    /// Total samples passed through `step`, including warmup and skipped ones.
    pub fn samples_seen(&self) -> u64 {
        self.samples_seen
    }

    /// Number of samples whose prediction error was recorded.
    pub fn samples_evaluated(&self) -> u64 {
        self.samples_evaluated
    }

    /// Clears all accumulated metrics and counters.
    ///
    /// The configuration and the set of enabled trackers are kept (enabled
    /// trackers are reset, not removed). The model is not touched; reset it
    /// separately if needed.
    pub fn reset(&mut self) {
        self.regression.reset();
        if let Some(rolling) = self.rolling.as_mut() {
            rolling.reset();
        }
        if let Some(ewma) = self.ewma.as_mut() {
            ewma.reset();
        }
        self.samples_seen = 0;
        self.samples_evaluated = 0;
    }
}
impl Default for PrequentialEvaluator {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::learner::StreamingLearner;
    use crate::sample::Sample;

    /// Learner that always predicts a fixed constant.
    struct MockLearner {
        prediction: f64,
        n: u64,
    }

    impl MockLearner {
        fn new(prediction: f64) -> Self {
            Self { prediction, n: 0 }
        }
    }

    impl StreamingLearner for MockLearner {
        fn train_one(&mut self, _features: &[f64], _target: f64, _weight: f64) {
            self.n += 1;
        }
        fn predict(&self, _features: &[f64]) -> f64 {
            self.prediction
        }
        fn n_samples_seen(&self) -> u64 {
            self.n
        }
        fn reset(&mut self) {
            self.n = 0;
        }
    }

    /// Learner that predicts the most recent target it was trained on
    /// (0.0 before any training). Unlike `MockLearner`, its output depends on
    /// training, so tests can observe the predict-then-train ordering.
    struct LastTargetLearner {
        last: f64,
        n: u64,
    }

    impl LastTargetLearner {
        fn new() -> Self {
            Self { last: 0.0, n: 0 }
        }
    }

    impl StreamingLearner for LastTargetLearner {
        fn train_one(&mut self, _features: &[f64], target: f64, _weight: f64) {
            self.last = target;
            self.n += 1;
        }
        fn predict(&self, _features: &[f64]) -> f64 {
            self.last
        }
        fn n_samples_seen(&self) -> u64 {
            self.n
        }
        fn reset(&mut self) {
            self.last = 0.0;
            self.n = 0;
        }
    }

    #[test]
    fn basic_prequential_evaluation() {
        let mut model = MockLearner::new(0.0);
        let mut eval = PrequentialEvaluator::new();
        for i in 0..100 {
            eval.step(&mut model, &[i as f64], i as f64);
        }
        assert_eq!(eval.samples_seen(), 100);
        assert_eq!(eval.samples_evaluated(), 100);
        assert_eq!(eval.regression_metrics().n_samples(), 100);
        let mae = eval.regression_metrics().mae();
        assert!((mae - 49.5).abs() < 0.01, "expected MAE ~49.5, got {}", mae);
    }

    #[test]
    fn predicts_before_training() {
        let mut model = LastTargetLearner::new();
        let mut eval = PrequentialEvaluator::new();
        // The first prediction must come from the untrained model.
        assert_eq!(eval.step(&mut model, &[0.0], 5.0), 0.0);
        // Each later prediction reflects only the previous samples.
        assert_eq!(eval.step(&mut model, &[0.0], 7.0), 5.0);
        assert_eq!(eval.step(&mut model, &[0.0], 7.0), 7.0);
        // Absolute errors 5, 2, 0 -> MAE 7/3.
        let mae = eval.regression_metrics().mae();
        assert!(
            (mae - 7.0 / 3.0).abs() < 1e-9,
            "expected MAE 7/3, got {}",
            mae
        );
        assert_eq!(model.n_samples_seen(), 3);
    }

    #[test]
    fn warmup_skips_early_evaluation() {
        let mut model = MockLearner::new(0.0);
        let config = PrequentialConfig {
            warmup: 10,
            step_interval: 1,
        };
        let mut eval = PrequentialEvaluator::with_config(config);
        for i in 0..50 {
            eval.step(&mut model, &[i as f64], i as f64);
        }
        assert_eq!(eval.samples_seen(), 50);
        assert_eq!(eval.samples_evaluated(), 40);
        assert_eq!(eval.regression_metrics().n_samples(), 40);
        assert_eq!(model.n_samples_seen(), 50);
    }

    #[test]
    fn warmup_applies_to_optional_trackers() {
        let mut model = MockLearner::new(0.0);
        let config = PrequentialConfig {
            warmup: 10,
            step_interval: 1,
        };
        let mut eval = PrequentialEvaluator::with_config(config)
            .with_rolling_window(100)
            .with_ewma(20);
        for i in 0..50 {
            eval.step(&mut model, &[i as f64], i as f64);
        }
        assert_eq!(eval.rolling_metrics().unwrap().n_samples(), 40);
        assert_eq!(eval.ewma_metrics().unwrap().n_samples(), 40);
    }

    #[test]
    fn step_interval_evaluates_periodically() {
        let mut model = MockLearner::new(0.0);
        let config = PrequentialConfig {
            warmup: 0,
            step_interval: 3,
        };
        let mut eval = PrequentialEvaluator::with_config(config);
        for i in 0..30 {
            eval.step(&mut model, &[i as f64], i as f64);
        }
        assert_eq!(eval.samples_seen(), 30);
        assert_eq!(eval.samples_evaluated(), 10);
        assert_eq!(model.n_samples_seen(), 30);
    }

    #[test]
    fn rolling_window_tracks_metrics() {
        let mut model = MockLearner::new(0.0);
        let mut eval = PrequentialEvaluator::new().with_rolling_window(10);
        for i in 0..100 {
            eval.step(&mut model, &[i as f64], i as f64);
        }
        let rolling = eval.rolling_metrics().expect("rolling should be enabled");
        assert_eq!(rolling.n_samples(), 10);
        assert!(rolling.is_full());
        let rolling_mae = rolling.mae();
        assert!(
            (rolling_mae - 94.5).abs() < 0.01,
            "expected rolling MAE ~94.5, got {}",
            rolling_mae
        );
    }

    #[test]
    fn ewma_tracks_metrics() {
        let mut model = MockLearner::new(0.0);
        let mut eval = PrequentialEvaluator::new().with_ewma(20);
        for i in 0..100 {
            eval.step(&mut model, &[i as f64], i as f64);
        }
        let ewma = eval.ewma_metrics().expect("EWMA should be enabled");
        assert_eq!(ewma.n_samples(), 100);
        assert!(
            ewma.mae() > 50.0,
            "EWMA MAE should be > 50, got {}",
            ewma.mae()
        );
    }

    #[test]
    fn no_rolling_or_ewma_by_default() {
        let eval = PrequentialEvaluator::new();
        assert!(eval.rolling_metrics().is_none());
        assert!(eval.ewma_metrics().is_none());
    }

    #[test]
    fn step_returns_prediction() {
        let mut model = MockLearner::new(42.0);
        let mut eval = PrequentialEvaluator::new();
        let pred = eval.step(&mut model, &[1.0], 5.0);
        assert!(
            (pred - 42.0).abs() < 1e-12,
            "step should return the model's prediction, got {}",
            pred
        );
    }

    #[test]
    fn evaluate_over_observations() {
        let mut model = MockLearner::new(0.0);
        let mut eval = PrequentialEvaluator::new();
        let data: Vec<Sample> = (0..50)
            .map(|i| Sample::new(vec![i as f64], i as f64))
            .collect();
        eval.evaluate(&mut model, &data);
        assert_eq!(eval.samples_seen(), 50);
        assert_eq!(eval.samples_evaluated(), 50);
        assert_eq!(model.n_samples_seen(), 50);
    }

    #[test]
    fn reset_clears_all_state() {
        let mut model = MockLearner::new(0.0);
        let mut eval = PrequentialEvaluator::new()
            .with_rolling_window(10)
            .with_ewma(20);
        for i in 0..50 {
            eval.step(&mut model, &[i as f64], i as f64);
        }
        assert!(eval.samples_seen() > 0);
        assert!(eval.samples_evaluated() > 0);
        eval.reset();
        assert_eq!(eval.samples_seen(), 0);
        assert_eq!(eval.samples_evaluated(), 0);
        assert_eq!(eval.regression_metrics().n_samples(), 0);
        assert_eq!(eval.rolling_metrics().unwrap().n_samples(), 0);
        assert_eq!(eval.ewma_metrics().unwrap().n_samples(), 0);
    }

    #[test]
    fn warmup_and_step_interval_combined() {
        let mut model = MockLearner::new(0.0);
        let config = PrequentialConfig {
            warmup: 5,
            step_interval: 2,
        };
        let mut eval = PrequentialEvaluator::with_config(config);
        for i in 0..20 {
            eval.step(&mut model, &[i as f64], i as f64);
        }
        assert_eq!(eval.samples_seen(), 20);
        // Indices 6, 8, ..., 18: the interval is anchored at index 0, not at
        // the end of warmup.
        assert_eq!(eval.samples_evaluated(), 7);
        assert_eq!(model.n_samples_seen(), 20);
    }

    #[test]
    #[should_panic(expected = "step_interval must be > 0")]
    fn zero_step_interval_panics() {
        let config = PrequentialConfig {
            warmup: 0,
            step_interval: 0,
        };
        PrequentialEvaluator::with_config(config);
    }

    #[test]
    fn default_evaluator_is_new() {
        let eval = PrequentialEvaluator::default();
        assert_eq!(eval.samples_seen(), 0);
        assert_eq!(eval.samples_evaluated(), 0);
        assert!(eval.rolling_metrics().is_none());
        assert!(eval.ewma_metrics().is_none());
    }
}