1use crate::learner::StreamingLearner;
29use irithyll_core::drift::{DriftDetector, DriftSignal};
30
31use std::fmt;
32
33pub struct ContinualLearner {
64 inner: Box<dyn StreamingLearner>,
66 drift_detector: Option<Box<dyn DriftDetector>>,
68 reset_on_drift: bool,
70 n_samples: u64,
72 drift_count: u64,
74 last_drift_signal: DriftSignal,
76}
77
78impl ContinualLearner {
79 pub fn new(learner: impl StreamingLearner + 'static) -> Self {
85 Self {
86 inner: Box::new(learner),
87 drift_detector: None,
88 reset_on_drift: true,
89 n_samples: 0,
90 drift_count: 0,
91 last_drift_signal: DriftSignal::Stable,
92 }
93 }
94
95 pub fn from_boxed(learner: Box<dyn StreamingLearner>) -> Self {
100 Self {
101 inner: learner,
102 drift_detector: None,
103 reset_on_drift: true,
104 n_samples: 0,
105 drift_count: 0,
106 last_drift_signal: DriftSignal::Stable,
107 }
108 }
109
110 pub fn with_drift_detector(mut self, detector: impl DriftDetector + 'static) -> Self {
130 self.drift_detector = Some(Box::new(detector));
131 self
132 }
133
134 pub fn with_drift_detector_boxed(mut self, detector: Box<dyn DriftDetector>) -> Self {
136 self.drift_detector = Some(detector);
137 self
138 }
139
140 pub fn with_reset_on_drift(mut self, reset: bool) -> Self {
145 self.reset_on_drift = reset;
146 self
147 }
148
149 #[inline]
155 pub fn drift_count(&self) -> u64 {
156 self.drift_count
157 }
158
159 #[inline]
164 pub fn last_signal(&self) -> DriftSignal {
165 self.last_drift_signal
166 }
167
168 #[inline]
170 pub fn reset_on_drift(&self) -> bool {
171 self.reset_on_drift
172 }
173
174 #[inline]
176 pub fn inner(&self) -> &dyn StreamingLearner {
177 &*self.inner
178 }
179
180 #[inline]
182 pub fn inner_mut(&mut self) -> &mut dyn StreamingLearner {
183 &mut *self.inner
184 }
185
186 #[inline]
188 pub fn has_drift_detector(&self) -> bool {
189 self.drift_detector.is_some()
190 }
191}
192
193impl StreamingLearner for ContinualLearner {
198 fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
199 let pred = self.inner.predict(features);
201
202 if let Some(ref mut detector) = self.drift_detector {
204 let error = (pred - target).abs();
205 let signal = detector.update(error);
206 self.last_drift_signal = signal;
207
208 if signal == DriftSignal::Drift {
210 self.drift_count += 1;
211
212 if self.reset_on_drift {
213 self.inner.reset();
214 }
215 }
216 }
217
218 self.inner.train_one(features, target, weight);
220
221 self.n_samples += 1;
223 }
224
225 #[inline]
226 fn predict(&self, features: &[f64]) -> f64 {
227 self.inner.predict(features)
228 }
229
230 #[inline]
231 fn n_samples_seen(&self) -> u64 {
232 self.n_samples
233 }
234
235 fn reset(&mut self) {
236 self.inner.reset();
237 if let Some(ref mut detector) = self.drift_detector {
238 detector.reset();
239 }
240 self.n_samples = 0;
241 self.drift_count = 0;
242 self.last_drift_signal = DriftSignal::Stable;
243 }
244
245 fn diagnostics_array(&self) -> [f64; 5] {
246 self.inner.diagnostics_array()
247 }
248
249 fn adjust_config(&mut self, lr_multiplier: f64, lambda_delta: f64) {
250 self.inner.adjust_config(lr_multiplier, lambda_delta);
251 }
252
253 fn apply_structural_change(&mut self, depth_delta: i32, steps_delta: i32) {
254 self.inner.apply_structural_change(depth_delta, steps_delta);
255 }
256
257 fn replacement_count(&self) -> u64 {
258 self.inner.replacement_count()
259 }
260}
261
262impl fmt::Debug for ContinualLearner {
267 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
268 f.debug_struct("ContinualLearner")
269 .field("n_samples", &self.n_samples)
270 .field("drift_count", &self.drift_count)
271 .field("last_signal", &self.last_drift_signal)
272 .field("reset_on_drift", &self.reset_on_drift)
273 .field("has_detector", &self.drift_detector.is_some())
274 .finish()
275 }
276}
277
278pub fn continual(learner: impl StreamingLearner + 'static) -> ContinualLearner {
304 ContinualLearner::new(learner)
305}
306
307impl crate::automl::DiagnosticSource for ContinualLearner {
312 fn config_diagnostics(&self) -> Option<crate::automl::ConfigDiagnostics> {
313 None
315 }
316}
317
#[cfg(test)]
mod tests {
    use super::*;
    use irithyll_core::drift::pht::PageHinkleyTest;

    /// Minimal learner that predicts the running mean of all targets seen.
    ///
    /// Its fields are plain `f64`/`u64`, so it is automatically `Send + Sync`.
    /// No `unsafe impl` is needed.
    struct MeanLearner {
        sum: f64,
        count: u64,
    }

    impl MeanLearner {
        fn new() -> Self {
            Self { sum: 0.0, count: 0 }
        }
    }

    impl StreamingLearner for MeanLearner {
        fn train_one(&mut self, _features: &[f64], target: f64, _weight: f64) {
            self.sum += target;
            self.count += 1;
        }

        fn predict(&self, _features: &[f64]) -> f64 {
            if self.count == 0 {
                return 0.0;
            }
            self.sum / self.count as f64
        }

        fn n_samples_seen(&self) -> u64 {
            self.count
        }

        fn reset(&mut self) {
            self.sum = 0.0;
            self.count = 0;
        }
    }

    #[test]
    fn wraps_learner_transparently() {
        let mut cl = ContinualLearner::new(MeanLearner::new());

        cl.train(&[1.0], 10.0);
        cl.train(&[2.0], 20.0);

        assert_eq!(cl.n_samples_seen(), 2);

        let pred = cl.predict(&[0.0]);
        assert!(
            (pred - 15.0).abs() < 1e-6,
            "expected mean ~15.0, got {}",
            pred
        );
    }

    #[test]
    fn drift_detection_triggers_on_error_spike() {
        let pht = PageHinkleyTest::with_params(0.001, 5.0);
        let mut cl = ContinualLearner::new(MeanLearner::new()).with_drift_detector(pht);

        // Stable regime: constant target, so the error quickly settles to zero.
        for _ in 0..200 {
            cl.train(&[0.0], 1.0);
        }
        let drifts_before = cl.drift_count();

        // Abrupt regime shift: the error jumps by orders of magnitude.
        let mut detected = false;
        for _ in 0..200 {
            cl.train(&[0.0], 1000.0);
            if cl.drift_count() > drifts_before {
                detected = true;
                break;
            }
        }

        assert!(detected, "drift should be detected on sudden error spike");
    }

    #[test]
    fn drift_count_increments() {
        let pht = PageHinkleyTest::with_params(0.001, 5.0);
        let mut cl = ContinualLearner::new(MeanLearner::new()).with_drift_detector(pht);

        assert_eq!(cl.drift_count(), 0);

        for _ in 0..200 {
            cl.train(&[0.0], 1.0);
        }

        for _ in 0..200 {
            cl.train(&[0.0], 1000.0);
        }

        assert!(
            cl.drift_count() >= 1,
            "drift_count should be >= 1 after regime shift, got {}",
            cl.drift_count()
        );
    }

    #[test]
    fn reset_on_drift_resets_inner_model() {
        let pht = PageHinkleyTest::with_params(0.001, 5.0);
        let mut cl = ContinualLearner::new(MeanLearner::new())
            .with_drift_detector(pht)
            .with_reset_on_drift(true);

        for _ in 0..200 {
            cl.train(&[0.0], 1.0);
        }

        assert!(
            cl.inner().n_samples_seen() > 0,
            "inner should have samples before drift"
        );

        for _ in 0..200 {
            cl.train(&[0.0], 1000.0);
        }

        // A drift-triggered reset clears the inner model's count but not the wrapper's.
        assert!(
            cl.inner().n_samples_seen() < cl.n_samples_seen(),
            "inner model samples ({}) should be less than total ({}) after reset",
            cl.inner().n_samples_seen(),
            cl.n_samples_seen()
        );
    }

    #[test]
    fn no_drift_detector_works_fine() {
        let mut cl = ContinualLearner::new(MeanLearner::new());

        cl.train(&[0.0], 5.0);
        cl.train(&[0.0], 15.0);
        assert_eq!(cl.n_samples_seen(), 2);

        let pred = cl.predict(&[0.0]);
        assert!(
            (pred - 10.0).abs() < 1e-6,
            "pass-through should work without detector: got {}",
            pred
        );

        assert_eq!(cl.drift_count(), 0);
        assert_eq!(cl.last_signal(), DriftSignal::Stable);
    }

    #[test]
    fn predict_is_side_effect_free() {
        let pht = PageHinkleyTest::new();
        let mut cl = ContinualLearner::new(MeanLearner::new()).with_drift_detector(pht);

        cl.train(&[0.0], 10.0);
        let n_before = cl.n_samples_seen();
        let drift_before = cl.drift_count();
        let signal_before = cl.last_signal();

        let _ = cl.predict(&[0.0]);
        let _ = cl.predict(&[0.0]);
        let _ = cl.predict(&[0.0]);

        assert_eq!(
            cl.n_samples_seen(),
            n_before,
            "predict should not change n_samples"
        );
        assert_eq!(
            cl.drift_count(),
            drift_before,
            "predict should not change drift_count"
        );
        assert_eq!(
            cl.last_signal(),
            signal_before,
            "predict should not change last_signal"
        );
    }

    #[test]
    fn n_samples_tracks_correctly() {
        let mut cl = ContinualLearner::new(MeanLearner::new());

        assert_eq!(cl.n_samples_seen(), 0);

        for i in 1..=50 {
            cl.train(&[0.0], i as f64);
            assert_eq!(
                cl.n_samples_seen(),
                i,
                "n_samples should be {} after {} trains",
                i,
                i
            );
        }
    }

    #[test]
    fn inner_access_works() {
        let mut cl = ContinualLearner::new(MeanLearner::new());

        cl.train(&[0.0], 10.0);
        cl.train(&[0.0], 20.0);

        assert_eq!(cl.inner().n_samples_seen(), 2);

        cl.inner_mut().reset();
        assert_eq!(cl.inner().n_samples_seen(), 0);
    }

    #[test]
    fn boxed_constructors_work() {
        let mut cl = ContinualLearner::from_boxed(Box::new(MeanLearner::new()));
        assert!(!cl.has_drift_detector());
        assert!(cl.reset_on_drift(), "reset_on_drift should default to true");

        cl.train(&[0.0], 4.0);
        cl.train(&[0.0], 6.0);
        let pred = cl.predict(&[0.0]);
        assert!(
            (pred - 5.0).abs() < 1e-6,
            "boxed wrapper should pass through: got {}",
            pred
        );

        let cl = cl.with_drift_detector_boxed(Box::new(PageHinkleyTest::new()));
        assert!(cl.has_drift_detector());
        assert_eq!(
            cl.n_samples_seen(),
            2,
            "attaching a detector should not clear counters"
        );
    }

    #[test]
    fn reset_clears_everything() {
        let pht = PageHinkleyTest::with_params(0.001, 5.0);
        let mut cl = ContinualLearner::new(MeanLearner::new()).with_drift_detector(pht);

        for _ in 0..200 {
            cl.train(&[0.0], 1.0);
        }
        for _ in 0..200 {
            cl.train(&[0.0], 1000.0);
        }

        assert!(cl.n_samples_seen() > 0);

        cl.reset();

        assert_eq!(
            cl.n_samples_seen(),
            0,
            "n_samples should be zero after reset"
        );
        assert_eq!(
            cl.drift_count(),
            0,
            "drift_count should be zero after reset"
        );
        assert_eq!(
            cl.last_signal(),
            DriftSignal::Stable,
            "last_signal should be Stable after reset"
        );
        assert_eq!(
            cl.inner().n_samples_seen(),
            0,
            "inner model should be reset"
        );
    }

    #[test]
    fn pipeline_composition_works() {
        use crate::pipeline::Pipeline;

        let cl = continual(MeanLearner::new());
        let mut pipeline = Pipeline::builder().learner(cl);

        pipeline.train(&[1.0, 2.0], 10.0);
        pipeline.train(&[3.0, 4.0], 20.0);

        assert_eq!(pipeline.n_samples_seen(), 2);

        let pred = pipeline.predict(&[5.0, 6.0]);
        assert!(pred.is_finite(), "pipeline prediction should be finite");
    }

    #[test]
    fn factory_function_creates_wrapper() {
        let mut cl = continual(MeanLearner::new());

        cl.train(&[0.0], 42.0);
        assert_eq!(cl.n_samples_seen(), 1);

        let pred = cl.predict(&[0.0]);
        assert!(
            (pred - 42.0).abs() < 1e-6,
            "factory-created wrapper should work: got {}",
            pred
        );
    }

    #[test]
    fn with_reset_on_drift_false_does_not_reset() {
        let pht = PageHinkleyTest::with_params(0.001, 5.0);
        let mut cl = ContinualLearner::new(MeanLearner::new())
            .with_drift_detector(pht)
            .with_reset_on_drift(false);

        for _ in 0..200 {
            cl.train(&[0.0], 1.0);
        }
        let inner_count_before_shift = cl.inner().n_samples_seen();

        for _ in 0..200 {
            cl.train(&[0.0], 1000.0);
        }

        // Drift is still detected and counted; only the reset is suppressed.
        assert!(
            cl.drift_count() >= 1,
            "drift should still be detected even with reset_on_drift=false"
        );
        assert_eq!(
            cl.inner().n_samples_seen(),
            cl.n_samples_seen(),
            "inner model should NOT have been reset (reset_on_drift=false): inner={}, total={}",
            cl.inner().n_samples_seen(),
            cl.n_samples_seen()
        );
        assert!(
            cl.inner().n_samples_seen() > inner_count_before_shift,
            "inner should have continued accumulating samples"
        );
    }

    #[test]
    fn as_trait_object() {
        let cl = ContinualLearner::new(MeanLearner::new());
        let mut boxed: Box<dyn StreamingLearner> = Box::new(cl);

        boxed.train(&[0.0], 7.0);
        assert_eq!(boxed.n_samples_seen(), 1);

        let pred = boxed.predict(&[0.0]);
        assert!(
            (pred - 7.0).abs() < 1e-6,
            "trait object predict should work: got {}",
            pred
        );
    }

    #[test]
    fn debug_format_is_informative() {
        let cl =
            ContinualLearner::new(MeanLearner::new()).with_drift_detector(PageHinkleyTest::new());

        let debug = format!("{:?}", cl);
        assert!(
            debug.contains("ContinualLearner"),
            "debug output should contain struct name"
        );
        assert!(
            debug.contains("drift_count"),
            "debug output should contain drift_count field"
        );
        assert!(
            debug.contains("has_detector"),
            "debug output should contain has_detector field"
        );
    }
}