use crate::learner::StreamingLearner;
use irithyll_core::drift::{DriftDetector, DriftSignal};
use std::fmt;
/// Drift-aware wrapper around any [`StreamingLearner`].
///
/// Every training sample is forwarded to the inner learner. When a
/// [`DriftDetector`] is attached, the inner model's absolute prediction error
/// on each sample is fed to the detector before the sample is learned; on a
/// drift signal the wrapper counts it and, optionally, resets the inner model.
pub struct ContinualLearner {
    /// The wrapped learner that actually fits the data.
    inner: Box<dyn StreamingLearner>,
    /// Optional detector that receives the inner model's absolute error.
    drift_detector: Option<Box<dyn DriftDetector>>,
    /// Samples passed through `train_one` since construction or the last
    /// `reset()`. Not cleared when drift resets the inner model.
    n_samples: u64,
    /// Number of `DriftSignal::Drift` signals observed.
    drift_count: u64,
    /// Most recent detector signal; `Stable` until a detector reports otherwise.
    last_drift_signal: DriftSignal,
    /// Whether the inner model is reset when drift is detected.
    reset_on_drift: bool,
}
impl ContinualLearner {
    /// Wraps `learner` with no drift detector attached.
    ///
    /// Reset-on-drift starts enabled. It has no effect until a detector is
    /// installed with [`with_drift_detector`](Self::with_drift_detector) or
    /// [`with_drift_detector_boxed`](Self::with_drift_detector_boxed).
    pub fn new(learner: impl StreamingLearner + 'static) -> Self {
        Self::from_boxed(Box::new(learner))
    }

    /// Wraps an already-boxed learner with no drift detector attached.
    ///
    /// This is the single place that defines the wrapper's initial state;
    /// [`new`](Self::new) delegates here so the two constructors cannot diverge.
    pub fn from_boxed(learner: Box<dyn StreamingLearner>) -> Self {
        Self {
            inner: learner,
            drift_detector: None,
            reset_on_drift: true,
            n_samples: 0,
            drift_count: 0,
            last_drift_signal: DriftSignal::Stable,
        }
    }

    /// Builder: installs `detector`, replacing any previously configured one.
    ///
    /// The detector is fed the inner model's absolute prediction error for
    /// every training sample.
    pub fn with_drift_detector(self, detector: impl DriftDetector + 'static) -> Self {
        self.with_drift_detector_boxed(Box::new(detector))
    }

    /// Builder: installs an already-boxed `detector`, replacing any previous one.
    pub fn with_drift_detector_boxed(mut self, detector: Box<dyn DriftDetector>) -> Self {
        self.drift_detector = Some(detector);
        self
    }

    /// Builder: choose whether the inner model is reset when drift is detected.
    ///
    /// Defaults to `true`. With `false`, drift is still detected and counted,
    /// but the inner model keeps learning uninterrupted.
    pub fn with_reset_on_drift(mut self, reset: bool) -> Self {
        self.reset_on_drift = reset;
        self
    }

    /// Number of `Drift` signals observed since construction or the last `reset()`.
    #[inline]
    pub fn drift_count(&self) -> u64 {
        self.drift_count
    }

    /// Most recent signal returned by the detector.
    ///
    /// `Stable` when no detector is attached, before any training, or right
    /// after `reset()`.
    #[inline]
    pub fn last_signal(&self) -> DriftSignal {
        self.last_drift_signal
    }

    /// Whether the inner model is reset when drift is detected.
    #[inline]
    pub fn reset_on_drift(&self) -> bool {
        self.reset_on_drift
    }

    /// Shared access to the wrapped learner.
    #[inline]
    pub fn inner(&self) -> &dyn StreamingLearner {
        &*self.inner
    }

    /// Mutable access to the wrapped learner.
    ///
    /// Changes made through this reference (e.g. resetting it) do not update
    /// the wrapper's own counters.
    #[inline]
    pub fn inner_mut(&mut self) -> &mut dyn StreamingLearner {
        &mut *self.inner
    }

    /// Whether a drift detector is attached.
    #[inline]
    pub fn has_drift_detector(&self) -> bool {
        self.drift_detector.is_some()
    }
}
impl StreamingLearner for ContinualLearner {
    /// Trains on one sample, first monitoring the inner model for drift.
    ///
    /// When a detector is configured, the inner model's prediction for
    /// `features` is scored against `target` *before* training
    /// (test-then-train), and the absolute error is fed to the detector. On a
    /// `Drift` signal the drift counter is incremented. If reset-on-drift is
    /// enabled, the inner model is also reset, so this sample is the first
    /// one it learns after the shift.
    ///
    /// Without a detector this is a plain pass-through. No prediction is
    /// computed, because its result would only be discarded.
    fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
        if let Some(detector) = self.drift_detector.as_mut() {
            let error = (self.inner.predict(features) - target).abs();
            let signal = detector.update(error);
            self.last_drift_signal = signal;
            if signal == DriftSignal::Drift {
                self.drift_count += 1;
                if self.reset_on_drift {
                    self.inner.reset();
                }
            }
        }
        self.inner.train_one(features, target, weight);
        self.n_samples += 1;
    }

    /// Delegates directly to the inner learner; no wrapper state is touched.
    #[inline]
    fn predict(&self, features: &[f64]) -> f64 {
        self.inner.predict(features)
    }

    /// Total samples trained through this wrapper.
    ///
    /// Drift-triggered resets do not clear this count, so it can exceed
    /// `inner().n_samples_seen()`.
    #[inline]
    fn n_samples_seen(&self) -> u64 {
        self.n_samples
    }

    /// Resets the inner learner, the detector (if any), and all wrapper counters.
    fn reset(&mut self) {
        self.inner.reset();
        if let Some(detector) = self.drift_detector.as_mut() {
            detector.reset();
        }
        self.n_samples = 0;
        self.drift_count = 0;
        self.last_drift_signal = DriftSignal::Stable;
    }

    // The hooks below are forwarded unchanged to the inner learner. The
    // `allow(deprecated)` silences the warning from calling the inner
    // learner's deprecated trait methods.

    #[allow(deprecated)]
    fn diagnostics_array(&self) -> [f64; 5] {
        self.inner.diagnostics_array()
    }

    #[allow(deprecated)]
    fn adjust_config(&mut self, lr_multiplier: f64, lambda_delta: f64) {
        self.inner.adjust_config(lr_multiplier, lambda_delta);
    }

    #[allow(deprecated)]
    fn apply_structural_change(&mut self, depth_delta: i32, steps_delta: i32) {
        self.inner.apply_structural_change(depth_delta, steps_delta);
    }

    #[allow(deprecated)]
    fn replacement_count(&self) -> u64 {
        self.inner.replacement_count()
    }
}
impl fmt::Debug for ContinualLearner {
    /// Summarizes the wrapper's counters and configuration.
    ///
    /// The inner learner and the detector are not printed; only whether a
    /// detector is present is reported.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let has_detector = self.drift_detector.is_some();
        let mut out = f.debug_struct("ContinualLearner");
        out.field("n_samples", &self.n_samples);
        out.field("drift_count", &self.drift_count);
        out.field("last_signal", &self.last_drift_signal);
        out.field("reset_on_drift", &self.reset_on_drift);
        out.field("has_detector", &has_detector);
        out.finish()
    }
}
/// Convenience constructor that wraps `learner` in a [`ContinualLearner`].
///
/// Produces the same state as [`ContinualLearner::new`]: no drift detector,
/// with reset-on-drift enabled. Attach a detector with
/// [`ContinualLearner::with_drift_detector`].
pub fn continual(learner: impl StreamingLearner + 'static) -> ContinualLearner {
    let boxed: Box<dyn StreamingLearner> = Box::new(learner);
    ContinualLearner::from_boxed(boxed)
}
impl crate::automl::DiagnosticSource for ContinualLearner {
    /// The wrapper reports no configuration diagnostics of its own.
    ///
    /// NOTE(review): the inner learner is held as `dyn StreamingLearner`, so its
    /// diagnostics (if it has any) are presumably not reachable here — confirm
    /// before relying on this for automl tuning of wrapped models.
    fn config_diagnostics(&self) -> Option<crate::automl::ConfigDiagnostics> {
        Option::None
    }
}
#[cfg(test)]
mod tests {
// Unit tests for `ContinualLearner`.
// Drift scenarios train 200 samples with target 1.0, then up to 200 samples
// with target 1000.0, using a Page-Hinkley test built with
// `with_params(0.001, 5.0)`. The assertions rely on that jump being detected
// within the post-shift window.
use super::*;
use irithyll_core::drift::pht::PageHinkleyTest;
/// Minimal test learner: predicts the running mean of all targets seen
/// (0.0 before any training). Ignores features and weights.
struct MeanLearner {
sum: f64,
count: u64,
}
impl MeanLearner {
fn new() -> Self {
Self { sum: 0.0, count: 0 }
}
}
impl StreamingLearner for MeanLearner {
fn train_one(&mut self, _features: &[f64], target: f64, _weight: f64) {
self.sum += target;
self.count += 1;
}
fn predict(&self, _features: &[f64]) -> f64 {
// Avoid 0/0 before the first sample.
if self.count == 0 {
return 0.0;
}
self.sum / self.count as f64
}
fn n_samples_seen(&self) -> u64 {
self.count
}
fn reset(&mut self) {
self.sum = 0.0;
self.count = 0;
}
}
// Without a detector the wrapper behaves exactly like the inner learner.
#[test]
fn wraps_learner_transparently() {
let mut cl = ContinualLearner::new(MeanLearner::new());
cl.train(&[1.0], 10.0);
cl.train(&[2.0], 20.0);
assert_eq!(cl.n_samples_seen(), 2);
let pred = cl.predict(&[0.0]);
assert!(
(pred - 15.0).abs() < 1e-6,
"expected mean ~15.0, got {}",
pred
);
}
// A sudden jump in target (and thus in absolute error) must raise drift_count.
#[test]
fn drift_detection_triggers_on_error_spike() {
let pht = PageHinkleyTest::with_params(0.001, 5.0);
let mut cl = ContinualLearner::new(MeanLearner::new()).with_drift_detector(pht);
// Stable regime: the mean learner converges to 1.0, so error is ~0.
for _ in 0..200 {
cl.train(&[0.0], 1.0);
}
let drifts_before = cl.drift_count();
let mut detected = false;
// Shifted regime: stop at the first newly counted drift.
for _ in 0..200 {
cl.train(&[0.0], 1000.0);
if cl.drift_count() > drifts_before {
detected = true;
break;
}
}
assert!(detected, "drift should be detected on sudden error spike");
}
// drift_count starts at zero and is at least one after the regime shift.
#[test]
fn drift_count_increments() {
let pht = PageHinkleyTest::with_params(0.001, 5.0);
let mut cl = ContinualLearner::new(MeanLearner::new()).with_drift_detector(pht);
assert_eq!(cl.drift_count(), 0);
for _ in 0..200 {
cl.train(&[0.0], 1.0);
}
for _ in 0..200 {
cl.train(&[0.0], 1000.0);
}
assert!(
cl.drift_count() >= 1,
"drift_count should be >= 1 after regime shift, got {}",
cl.drift_count()
);
}
// With reset_on_drift(true) the inner model's count falls behind the
// wrapper's count, which is never cleared by drift resets.
#[test]
fn reset_on_drift_resets_inner_model() {
let pht = PageHinkleyTest::with_params(0.001, 5.0);
let mut cl = ContinualLearner::new(MeanLearner::new())
.with_drift_detector(pht)
.with_reset_on_drift(true);
for _ in 0..200 {
cl.train(&[0.0], 1.0);
}
assert!(
cl.inner().n_samples_seen() > 0,
"inner should have samples before drift"
);
for _ in 0..200 {
cl.train(&[0.0], 1000.0);
}
assert!(
cl.inner().n_samples_seen() < cl.n_samples_seen(),
"inner model samples ({}) should be less than total ({}) after reset",
cl.inner().n_samples_seen(),
cl.n_samples_seen()
);
}
// No detector: pure pass-through, and the drift state stays at its defaults.
#[test]
fn no_drift_detector_works_fine() {
let mut cl = ContinualLearner::new(MeanLearner::new());
cl.train(&[0.0], 5.0);
cl.train(&[0.0], 15.0);
assert_eq!(cl.n_samples_seen(), 2);
let pred = cl.predict(&[0.0]);
assert!(
(pred - 10.0).abs() < 1e-6,
"pass-through should work without detector: got {}",
pred
);
assert_eq!(cl.drift_count(), 0);
assert_eq!(cl.last_signal(), DriftSignal::Stable);
}
// predict() must not mutate any wrapper state, even with a detector attached.
#[test]
fn predict_is_side_effect_free() {
let pht = PageHinkleyTest::new();
let mut cl = ContinualLearner::new(MeanLearner::new()).with_drift_detector(pht);
cl.train(&[0.0], 10.0);
let n_before = cl.n_samples_seen();
let drift_before = cl.drift_count();
let signal_before = cl.last_signal();
let _ = cl.predict(&[0.0]);
let _ = cl.predict(&[0.0]);
let _ = cl.predict(&[0.0]);
assert_eq!(
cl.n_samples_seen(),
n_before,
"predict should not change n_samples"
);
assert_eq!(
cl.drift_count(),
drift_before,
"predict should not change drift_count"
);
assert_eq!(
cl.last_signal(),
signal_before,
"predict should not change last_signal"
);
}
// n_samples_seen increments by exactly one per train call.
#[test]
fn n_samples_tracks_correctly() {
let mut cl = ContinualLearner::new(MeanLearner::new());
assert_eq!(cl.n_samples_seen(), 0);
for i in 1..=50 {
cl.train(&[0.0], i as f64);
assert_eq!(
cl.n_samples_seen(),
i,
"n_samples should be {} after {} trains",
i,
i
);
}
}
// inner()/inner_mut() expose the wrapped learner directly.
#[test]
fn inner_access_works() {
let mut cl = ContinualLearner::new(MeanLearner::new());
cl.train(&[0.0], 10.0);
cl.train(&[0.0], 20.0);
assert_eq!(cl.inner().n_samples_seen(), 2);
cl.inner_mut().reset();
assert_eq!(cl.inner().n_samples_seen(), 0);
}
// reset() clears the wrapper counters, the last signal, and the inner model.
#[test]
fn reset_clears_everything() {
let pht = PageHinkleyTest::with_params(0.001, 5.0);
let mut cl = ContinualLearner::new(MeanLearner::new()).with_drift_detector(pht);
for _ in 0..200 {
cl.train(&[0.0], 1.0);
}
for _ in 0..200 {
cl.train(&[0.0], 1000.0);
}
assert!(cl.n_samples_seen() > 0);
cl.reset();
assert_eq!(
cl.n_samples_seen(),
0,
"n_samples should be zero after reset"
);
assert_eq!(
cl.drift_count(),
0,
"drift_count should be zero after reset"
);
assert_eq!(
cl.last_signal(),
DriftSignal::Stable,
"last_signal should be Stable after reset"
);
assert_eq!(
cl.inner().n_samples_seen(),
0,
"inner model should be reset"
);
}
// The wrapper can be used as the learner stage of a Pipeline.
#[test]
fn pipeline_composition_works() {
use crate::pipeline::Pipeline;
let cl = continual(MeanLearner::new());
let mut pipeline = Pipeline::builder().learner(cl);
pipeline.train(&[1.0, 2.0], 10.0);
pipeline.train(&[3.0, 4.0], 20.0);
assert_eq!(pipeline.n_samples_seen(), 2);
let pred = pipeline.predict(&[5.0, 6.0]);
assert!(pred.is_finite(), "pipeline prediction should be finite");
}
// The `continual` free function builds a working wrapper.
#[test]
fn factory_function_creates_wrapper() {
let mut cl = continual(MeanLearner::new());
cl.train(&[0.0], 42.0);
assert_eq!(cl.n_samples_seen(), 1);
let pred = cl.predict(&[0.0]);
assert!(
(pred - 42.0).abs() < 1e-6,
"factory-created wrapper should work: got {}",
pred
);
}
// With reset_on_drift(false), drift is still counted but the inner model's
// sample count stays in lockstep with the wrapper's.
#[test]
fn with_reset_on_drift_false_does_not_reset() {
let pht = PageHinkleyTest::with_params(0.001, 5.0);
let mut cl = ContinualLearner::new(MeanLearner::new())
.with_drift_detector(pht)
.with_reset_on_drift(false);
for _ in 0..200 {
cl.train(&[0.0], 1.0);
}
let inner_count_before_shift = cl.inner().n_samples_seen();
for _ in 0..200 {
cl.train(&[0.0], 1000.0);
}
assert!(
cl.drift_count() >= 1,
"drift should still be detected even with reset_on_drift=false"
);
assert_eq!(
cl.inner().n_samples_seen(),
cl.n_samples_seen(),
"inner model should NOT have been reset (reset_on_drift=false): inner={}, total={}",
cl.inner().n_samples_seen(),
cl.n_samples_seen()
);
assert!(
cl.inner().n_samples_seen() > inner_count_before_shift,
"inner should have continued accumulating samples"
);
}
// The wrapper is usable behind `Box<dyn StreamingLearner>`.
#[test]
fn as_trait_object() {
let cl = ContinualLearner::new(MeanLearner::new());
let mut boxed: Box<dyn StreamingLearner> = Box::new(cl);
boxed.train(&[0.0], 7.0);
assert_eq!(boxed.n_samples_seen(), 1);
let pred = boxed.predict(&[0.0]);
assert!(
(pred - 7.0).abs() < 1e-6,
"trait object predict should work: got {}",
pred
);
}
// The manual Debug impl includes the struct name and key fields.
#[test]
fn debug_format_is_informative() {
let cl =
ContinualLearner::new(MeanLearner::new()).with_drift_detector(PageHinkleyTest::new());
let debug = format!("{:?}", cl);
assert!(
debug.contains("ContinualLearner"),
"debug output should contain struct name"
);
assert!(
debug.contains("drift_count"),
"debug output should contain drift_count field"
);
assert!(
debug.contains("has_detector"),
"debug output should contain has_detector field"
);
}
}