Skip to main content

irithyll/continual/
continual_wrapper.rs

1//! Drift-aware continual learning wrapper for any [`StreamingLearner`].
2//!
3//! [`ContinualLearner`] wraps an opaque streaming model and monitors
4//! prediction error via a pluggable [`DriftDetector`]. When drift is
5//! detected the inner model is reset (or partially reset), allowing it
6//! to adapt to the new data regime without accumulating stale knowledge.
7//!
8//! Because [`StreamingLearner`] is intentionally opaque -- no access to
9//! raw parameters or gradients -- the wrapper uses **prequential error**
10//! (predict-then-train) as the drift signal source. Models that expose
11//! parameters can compose with `ContinualStrategy` directly; this
12//! wrapper provides the outer orchestration layer.
13//!
14//! # Prequential Protocol
15//!
16//! On every `train_one` call the wrapper:
17//!
18//! 1. **Predicts** first (before the model has seen this sample).
19//! 2. Computes absolute prediction error `|pred - target|`.
20//! 3. Feeds the error to the drift detector.
21//! 4. If the detector signals `Drift` and `reset_on_drift` is enabled,
22//!    resets the inner model so it can re-learn from scratch.
23//! 5. Trains the inner model on the sample (whether or not a reset occurred).
24//!
25//! This is the standard **prequential evaluation** protocol used in
26//! streaming ML literature (Gama et al., 2013).
27
28use crate::learner::StreamingLearner;
29use irithyll_core::drift::{DriftDetector, DriftSignal};
30
31use std::fmt;
32
33// ---------------------------------------------------------------------------
34// ContinualLearner
35// ---------------------------------------------------------------------------
36
37/// Wraps any [`StreamingLearner`] with drift-detected continual adaptation.
38///
39/// Since `StreamingLearner` is opaque (no access to raw parameters or
40/// gradients), `ContinualLearner` uses prediction error to drive drift
41/// detection, which triggers model reset on the underlying learner.
42///
43/// For models that **do** expose parameters (neural models), the
44/// `ContinualStrategy` trait can be applied
45/// directly. This wrapper provides the higher-level orchestration layer.
46///
47/// # Example
48///
49/// ```
50/// use irithyll::continual::ContinualLearner;
51/// use irithyll::{linear, StreamingLearner};
52/// use irithyll_core::drift::pht::PageHinkleyTest;
53///
54/// let mut cl = ContinualLearner::new(linear(0.01))
55///     .with_drift_detector(PageHinkleyTest::new());
56///
57/// for i in 0..100 {
58///     cl.train(&[i as f64], i as f64 * 2.0);
59/// }
60/// let pred = cl.predict(&[50.0]);
61/// assert!(pred.is_finite());
62/// ```
63pub struct ContinualLearner {
64    /// The wrapped streaming model.
65    inner: Box<dyn StreamingLearner>,
66    /// Optional drift detector fed with prediction errors.
67    drift_detector: Option<Box<dyn DriftDetector>>,
68    /// Whether to reset the inner model on drift (default: true).
69    reset_on_drift: bool,
70    /// Total training samples seen (including across resets).
71    n_samples: u64,
72    /// Number of drift events detected.
73    drift_count: u64,
74    /// Most recent drift signal from the detector.
75    last_drift_signal: DriftSignal,
76}
77
78impl ContinualLearner {
79    /// Wrap a streaming learner with continual learning capabilities.
80    ///
81    /// The returned wrapper has no drift detector attached by default --
82    /// call [`with_drift_detector`](Self::with_drift_detector) to enable
83    /// drift-aware behaviour.
84    pub fn new(learner: impl StreamingLearner + 'static) -> Self {
85        Self {
86            inner: Box::new(learner),
87            drift_detector: None,
88            reset_on_drift: true,
89            n_samples: 0,
90            drift_count: 0,
91            last_drift_signal: DriftSignal::Stable,
92        }
93    }
94
95    /// Wrap a boxed streaming learner.
96    ///
97    /// Use this when the learner is already behind a
98    /// `Box<dyn StreamingLearner>`.
99    pub fn from_boxed(learner: Box<dyn StreamingLearner>) -> Self {
100        Self {
101            inner: learner,
102            drift_detector: None,
103            reset_on_drift: true,
104            n_samples: 0,
105            drift_count: 0,
106            last_drift_signal: DriftSignal::Stable,
107        }
108    }
109
110    // -----------------------------------------------------------------------
111    // Builder methods
112    // -----------------------------------------------------------------------
113
114    /// Attach a drift detector that monitors prediction error.
115    ///
116    /// The detector receives `|prediction - target|` on every training
117    /// sample (prequential protocol).
118    ///
119    /// # Example
120    ///
121    /// ```
122    /// use irithyll::continual::ContinualLearner;
123    /// use irithyll::linear;
124    /// use irithyll_core::drift::pht::PageHinkleyTest;
125    ///
126    /// let cl = ContinualLearner::new(linear(0.01))
127    ///     .with_drift_detector(PageHinkleyTest::new());
128    /// ```
129    pub fn with_drift_detector(mut self, detector: impl DriftDetector + 'static) -> Self {
130        self.drift_detector = Some(Box::new(detector));
131        self
132    }
133
134    /// Attach a boxed drift detector.
135    pub fn with_drift_detector_boxed(mut self, detector: Box<dyn DriftDetector>) -> Self {
136        self.drift_detector = Some(detector);
137        self
138    }
139
140    /// Set whether the inner model is reset when drift is detected.
141    ///
142    /// Default: `true`. When set to `false`, the wrapper still tracks
143    /// drift events and signals but does not reset the model.
144    pub fn with_reset_on_drift(mut self, reset: bool) -> Self {
145        self.reset_on_drift = reset;
146        self
147    }
148
149    // -----------------------------------------------------------------------
150    // Accessors
151    // -----------------------------------------------------------------------
152
153    /// Number of drift events detected since creation (or last reset).
154    #[inline]
155    pub fn drift_count(&self) -> u64 {
156        self.drift_count
157    }
158
159    /// Most recent drift signal from the detector.
160    ///
161    /// Returns [`DriftSignal::Stable`] if no detector is attached or no
162    /// samples have been processed.
163    #[inline]
164    pub fn last_signal(&self) -> DriftSignal {
165        self.last_drift_signal
166    }
167
168    /// Whether the wrapper is configured to reset on drift.
169    #[inline]
170    pub fn reset_on_drift(&self) -> bool {
171        self.reset_on_drift
172    }
173
174    /// Immutable reference to the wrapped streaming learner.
175    #[inline]
176    pub fn inner(&self) -> &dyn StreamingLearner {
177        &*self.inner
178    }
179
180    /// Mutable reference to the wrapped streaming learner.
181    #[inline]
182    pub fn inner_mut(&mut self) -> &mut dyn StreamingLearner {
183        &mut *self.inner
184    }
185
186    /// Whether a drift detector is attached.
187    #[inline]
188    pub fn has_drift_detector(&self) -> bool {
189        self.drift_detector.is_some()
190    }
191}
192
193// ---------------------------------------------------------------------------
194// StreamingLearner impl
195// ---------------------------------------------------------------------------
196
197impl StreamingLearner for ContinualLearner {
198    fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
199        // Step 1: Prequential prediction (before this sample updates the model).
200        let pred = self.inner.predict(features);
201
202        // Step 2: Feed absolute prediction error to drift detector.
203        if let Some(ref mut detector) = self.drift_detector {
204            let error = (pred - target).abs();
205            let signal = detector.update(error);
206            self.last_drift_signal = signal;
207
208            // Step 3: Handle drift.
209            if signal == DriftSignal::Drift {
210                self.drift_count += 1;
211
212                if self.reset_on_drift {
213                    self.inner.reset();
214                }
215            }
216        }
217
218        // Step 4: Train the inner model (always, even after reset).
219        self.inner.train_one(features, target, weight);
220
221        // Step 5: Increment our own sample counter.
222        self.n_samples += 1;
223    }
224
225    #[inline]
226    fn predict(&self, features: &[f64]) -> f64 {
227        self.inner.predict(features)
228    }
229
230    #[inline]
231    fn n_samples_seen(&self) -> u64 {
232        self.n_samples
233    }
234
235    fn reset(&mut self) {
236        self.inner.reset();
237        if let Some(ref mut detector) = self.drift_detector {
238            detector.reset();
239        }
240        self.n_samples = 0;
241        self.drift_count = 0;
242        self.last_drift_signal = DriftSignal::Stable;
243    }
244
245    fn diagnostics_array(&self) -> [f64; 5] {
246        self.inner.diagnostics_array()
247    }
248
249    fn adjust_config(&mut self, lr_multiplier: f64, lambda_delta: f64) {
250        self.inner.adjust_config(lr_multiplier, lambda_delta);
251    }
252
253    fn apply_structural_change(&mut self, depth_delta: i32, steps_delta: i32) {
254        self.inner.apply_structural_change(depth_delta, steps_delta);
255    }
256
257    fn replacement_count(&self) -> u64 {
258        self.inner.replacement_count()
259    }
260}
261
262// ---------------------------------------------------------------------------
263// Debug impl
264// ---------------------------------------------------------------------------
265
266impl fmt::Debug for ContinualLearner {
267    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
268        f.debug_struct("ContinualLearner")
269            .field("n_samples", &self.n_samples)
270            .field("drift_count", &self.drift_count)
271            .field("last_signal", &self.last_drift_signal)
272            .field("reset_on_drift", &self.reset_on_drift)
273            .field("has_detector", &self.drift_detector.is_some())
274            .finish()
275    }
276}
277
278// ---------------------------------------------------------------------------
279// Factory function
280// ---------------------------------------------------------------------------
281
282/// Wrap any streaming learner with drift-detected continual adaptation.
283///
284/// Returns a [`ContinualLearner`] with no drift detector attached.
285/// Chain [`ContinualLearner::with_drift_detector`] to enable detection.
286///
287/// # Example
288///
289/// ```
290/// use irithyll::continual::continual;
291/// use irithyll::{esn, StreamingLearner};
292/// use irithyll_core::drift::pht::PageHinkleyTest;
293///
294/// let mut cl = continual(esn(50, 0.9))
295///     .with_drift_detector(PageHinkleyTest::new());
296///
297/// for i in 0..60 {
298///     cl.train(&[i as f64 * 0.1], 0.0);
299/// }
300/// let pred = cl.predict(&[1.0]);
301/// assert!(pred.is_finite());
302/// ```
303pub fn continual(learner: impl StreamingLearner + 'static) -> ContinualLearner {
304    ContinualLearner::new(learner)
305}
306
307// ---------------------------------------------------------------------------
308// DiagnosticSource impl
309// ---------------------------------------------------------------------------
310
311impl crate::automl::DiagnosticSource for ContinualLearner {
312    fn config_diagnostics(&self) -> Option<crate::automl::ConfigDiagnostics> {
313        // Cannot access inner learner diagnostics through Box<dyn StreamingLearner>.
314        None
315    }
316}
317
318// ---------------------------------------------------------------------------
319// Tests
320// ---------------------------------------------------------------------------
321
#[cfg(test)]
mod tests {
    use super::*;
    use irithyll_core::drift::pht::PageHinkleyTest;

    /// Trivial test double: predicts the running mean of every target it has
    /// been trained on, ignoring features and weights.
    ///
    /// Both fields are plain numeric types, so the compiler derives
    /// `Send`/`Sync` automatically -- no `unsafe impl` is needed.
    struct MeanLearner {
        sum: f64,
        count: u64,
    }

    impl MeanLearner {
        fn new() -> Self {
            Self { sum: 0.0, count: 0 }
        }
    }

    impl StreamingLearner for MeanLearner {
        fn train_one(&mut self, _features: &[f64], target: f64, _weight: f64) {
            self.sum += target;
            self.count += 1;
        }

        fn predict(&self, _features: &[f64]) -> f64 {
            if self.count == 0 {
                0.0
            } else {
                self.sum / self.count as f64
            }
        }

        fn n_samples_seen(&self) -> u64 {
            self.count
        }

        fn reset(&mut self) {
            self.sum = 0.0;
            self.count = 0;
        }
    }

    /// Samples fed per phase (stable phase, then shifted phase).
    const PHASE_LEN: usize = 200;
    /// Target used during the stable phase.
    const STABLE_TARGET: f64 = 1.0;
    /// Target used after the sudden regime shift.
    const SHIFTED_TARGET: f64 = 1000.0;

    /// Page-Hinkley detector configured to trigger quickly in these tests.
    fn sensitive_pht() -> PageHinkleyTest {
        PageHinkleyTest::with_params(0.001, 5.0)
    }

    /// A `MeanLearner` wrapped with the sensitive detector and default settings.
    fn detecting_learner() -> ContinualLearner {
        ContinualLearner::new(MeanLearner::new()).with_drift_detector(sensitive_pht())
    }

    /// Trains `cl` on `n` samples that all carry the same `target`.
    fn feed(cl: &mut ContinualLearner, target: f64, n: usize) {
        for _ in 0..n {
            cl.train(&[0.0], target);
        }
    }

    /// Stable phase followed by a sudden jump in the target: the prediction
    /// sits near 1.0 while the target becomes 1000.0, so the error spikes.
    fn feed_regime_shift(cl: &mut ContinualLearner) {
        feed(cl, STABLE_TARGET, PHASE_LEN);
        feed(cl, SHIFTED_TARGET, PHASE_LEN);
    }

    #[test]
    fn wraps_learner_transparently() {
        let mut cl = ContinualLearner::new(MeanLearner::new());

        cl.train(&[1.0], 10.0);
        cl.train(&[2.0], 20.0);

        assert_eq!(cl.n_samples_seen(), 2);

        // The prediction comes straight from the inner MeanLearner.
        let pred = cl.predict(&[0.0]);
        assert!(
            (pred - 15.0).abs() < 1e-6,
            "expected mean ~15.0, got {}",
            pred
        );
    }

    #[test]
    fn drift_detection_triggers_on_error_spike() {
        let mut cl = detecting_learner();

        // Phase 1: stable data, so prediction error stays small.
        feed(&mut cl, STABLE_TARGET, PHASE_LEN);
        let drifts_before = cl.drift_count();

        // Phase 2: sudden shift; stop as soon as a new drift is counted.
        let detected = (0..PHASE_LEN).any(|_| {
            cl.train(&[0.0], SHIFTED_TARGET);
            cl.drift_count() > drifts_before
        });

        assert!(detected, "drift should be detected on sudden error spike");
    }

    #[test]
    fn drift_count_increments() {
        let mut cl = detecting_learner();
        assert_eq!(cl.drift_count(), 0);

        feed_regime_shift(&mut cl);

        assert!(
            cl.drift_count() >= 1,
            "drift_count should be >= 1 after regime shift, got {}",
            cl.drift_count()
        );
    }

    #[test]
    fn reset_on_drift_resets_inner_model() {
        let mut cl = detecting_learner().with_reset_on_drift(true);

        feed(&mut cl, STABLE_TARGET, PHASE_LEN);
        assert!(
            cl.inner().n_samples_seen() > 0,
            "inner should have samples before drift"
        );

        feed(&mut cl, SHIFTED_TARGET, PHASE_LEN);

        // The inner model was reset on drift and re-trained on later samples,
        // so it has seen fewer samples than the wrapper has counted.
        assert!(
            cl.inner().n_samples_seen() < cl.n_samples_seen(),
            "inner model samples ({}) should be less than total ({}) after reset",
            cl.inner().n_samples_seen(),
            cl.n_samples_seen()
        );
    }

    #[test]
    fn no_drift_detector_works_fine() {
        // No detector attached -- pure pass-through.
        let mut cl = ContinualLearner::new(MeanLearner::new());

        cl.train(&[0.0], 5.0);
        cl.train(&[0.0], 15.0);
        assert_eq!(cl.n_samples_seen(), 2);

        let pred = cl.predict(&[0.0]);
        assert!(
            (pred - 10.0).abs() < 1e-6,
            "pass-through should work without detector: got {}",
            pred
        );

        assert_eq!(cl.drift_count(), 0);
        assert_eq!(cl.last_signal(), DriftSignal::Stable);
    }

    #[test]
    fn predict_is_side_effect_free() {
        let mut cl =
            ContinualLearner::new(MeanLearner::new()).with_drift_detector(PageHinkleyTest::new());

        cl.train(&[0.0], 10.0);
        let n_before = cl.n_samples_seen();
        let drift_before = cl.drift_count();
        let signal_before = cl.last_signal();

        // Repeated predictions must leave all wrapper state untouched.
        for _ in 0..3 {
            let _ = cl.predict(&[0.0]);
        }

        assert_eq!(
            cl.n_samples_seen(),
            n_before,
            "predict should not change n_samples"
        );
        assert_eq!(
            cl.drift_count(),
            drift_before,
            "predict should not change drift_count"
        );
        assert_eq!(
            cl.last_signal(),
            signal_before,
            "predict should not change last_signal"
        );
    }

    #[test]
    fn n_samples_tracks_correctly() {
        let mut cl = ContinualLearner::new(MeanLearner::new());
        assert_eq!(cl.n_samples_seen(), 0);

        for i in 1..=50u64 {
            cl.train(&[0.0], i as f64);
            assert_eq!(
                cl.n_samples_seen(),
                i,
                "n_samples should be {} after {} trains",
                i,
                i
            );
        }
    }

    #[test]
    fn inner_access_works() {
        let mut cl = ContinualLearner::new(MeanLearner::new());

        cl.train(&[0.0], 10.0);
        cl.train(&[0.0], 20.0);

        // inner() reflects the model's state.
        assert_eq!(cl.inner().n_samples_seen(), 2);

        // inner_mut() allows modifying the model directly.
        cl.inner_mut().reset();
        assert_eq!(cl.inner().n_samples_seen(), 0);
    }

    #[test]
    fn reset_clears_everything() {
        let mut cl = detecting_learner();

        // Train and trigger drift so that state accumulates.
        feed_regime_shift(&mut cl);
        assert!(cl.n_samples_seen() > 0);

        cl.reset();

        assert_eq!(
            cl.n_samples_seen(),
            0,
            "n_samples should be zero after reset"
        );
        assert_eq!(
            cl.drift_count(),
            0,
            "drift_count should be zero after reset"
        );
        assert_eq!(
            cl.last_signal(),
            DriftSignal::Stable,
            "last_signal should be Stable after reset"
        );
        assert_eq!(
            cl.inner().n_samples_seen(),
            0,
            "inner model should be reset"
        );
    }

    #[test]
    fn pipeline_composition_works() {
        use crate::pipeline::Pipeline;

        let cl = continual(MeanLearner::new());
        let mut pipeline = Pipeline::builder().learner(cl);

        pipeline.train(&[1.0, 2.0], 10.0);
        pipeline.train(&[3.0, 4.0], 20.0);

        assert_eq!(pipeline.n_samples_seen(), 2);

        let pred = pipeline.predict(&[5.0, 6.0]);
        assert!(pred.is_finite(), "pipeline prediction should be finite");
    }

    #[test]
    fn factory_function_creates_wrapper() {
        let mut cl = continual(MeanLearner::new());

        cl.train(&[0.0], 42.0);
        assert_eq!(cl.n_samples_seen(), 1);

        let pred = cl.predict(&[0.0]);
        assert!(
            (pred - 42.0).abs() < 1e-6,
            "factory-created wrapper should work: got {}",
            pred
        );
    }

    #[test]
    fn with_reset_on_drift_false_does_not_reset() {
        let mut cl = detecting_learner().with_reset_on_drift(false);

        feed(&mut cl, STABLE_TARGET, PHASE_LEN);
        let inner_count_before_shift = cl.inner().n_samples_seen();

        // Drift is triggered here, but resetting is disabled.
        feed(&mut cl, SHIFTED_TARGET, PHASE_LEN);

        // Drift is still counted, yet the inner model keeps every sample, so
        // its count matches the wrapper's.
        assert!(
            cl.drift_count() >= 1,
            "drift should still be detected even with reset_on_drift=false"
        );
        assert_eq!(
            cl.inner().n_samples_seen(),
            cl.n_samples_seen(),
            "inner model should NOT have been reset (reset_on_drift=false): inner={}, total={}",
            cl.inner().n_samples_seen(),
            cl.n_samples_seen()
        );
        assert!(
            cl.inner().n_samples_seen() > inner_count_before_shift,
            "inner should have continued accumulating samples"
        );
    }

    #[test]
    fn as_trait_object() {
        // ContinualLearner should work behind Box<dyn StreamingLearner>.
        let mut boxed: Box<dyn StreamingLearner> =
            Box::new(ContinualLearner::new(MeanLearner::new()));

        boxed.train(&[0.0], 7.0);
        assert_eq!(boxed.n_samples_seen(), 1);

        let pred = boxed.predict(&[0.0]);
        assert!(
            (pred - 7.0).abs() < 1e-6,
            "trait object predict should work: got {}",
            pred
        );
    }

    #[test]
    fn debug_format_is_informative() {
        let cl =
            ContinualLearner::new(MeanLearner::new()).with_drift_detector(PageHinkleyTest::new());

        let debug = format!("{:?}", cl);
        assert!(
            debug.contains("ContinualLearner"),
            "debug output should contain struct name"
        );
        assert!(
            debug.contains("drift_count"),
            "debug output should contain drift_count field"
        );
        assert!(
            debug.contains("has_detector"),
            "debug output should contain has_detector field"
        );
    }
}