1use alloc::boxed::Box;
20use core::fmt;
21
22use crate::drift::{DriftDetector, DriftSignal};
23use crate::tree::builder::TreeConfig;
24use crate::tree::hoeffding::HoeffdingTree;
25use crate::tree::StreamingTree;
26
/// One boosting slot: an active [`HoeffdingTree`] plus the machinery to replace it.
///
/// The active tree serves all predictions. A drift detector, fed with `|gradient|`,
/// can promote an alternate tree (or a fresh one) when drift is signalled. An optional
/// sample-count limit (`max_tree_samples`) proactively replaces trees that grow too old.
/// When `shadow_warmup > 0` the slot runs in "graduated" mode: a shadow tree is always
/// kept alive and can be blended into predictions via `predict_graduated`.
pub struct TreeSlot {
    /// Tree used for every prediction; trained on every sample.
    active: HoeffdingTree,
    /// Alternate/shadow tree trained alongside `active` when present; it is promoted
    /// to `active` on drift or time-based replacement.
    alternate: Option<HoeffdingTree>,
    /// Drift detector updated with `|gradient|` each sample; replaced by a fresh clone
    /// (`clone_fresh`) whenever the active tree is replaced or the slot is reset.
    detector: Box<dyn DriftDetector>,
    /// Configuration used to build every new tree in this slot.
    tree_config: TreeConfig,
    /// Optional age limit (in samples since activation) for proactive replacement;
    /// `None` disables time-based replacement.
    max_tree_samples: Option<u64>,
    /// Number of times the active tree has been replaced since construction/reset.
    replacements: u64,
    /// Number of predictions folded into the running statistics below.
    pred_count: u64,
    /// Running mean of active-tree predictions (Welford), reset on replacement.
    pred_mean: f64,
    /// Running sum of squared deviations of predictions (Welford's M2).
    pred_m2: f64,
    /// Samples a shadow tree must see before it contributes to graduated predictions;
    /// `0` disables graduated mode.
    shadow_warmup: usize,
    /// Value of `active.n_samples_seen()` at the moment the current active tree was
    /// promoted; subtracting it yields the active tree's age.
    samples_at_activation: u64,
}
61
62impl fmt::Debug for TreeSlot {
63 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
64 f.debug_struct("TreeSlot")
65 .field("active_leaves", &self.active.n_leaves())
66 .field("active_samples", &self.active.n_samples_seen())
67 .field("has_alternate", &self.alternate.is_some())
68 .field("tree_config", &self.tree_config)
69 .finish()
70 }
71}
72
73impl Clone for TreeSlot {
74 fn clone(&self) -> Self {
75 Self {
76 active: self.active.clone(),
77 alternate: self.alternate.clone(),
78 detector: self.detector.clone_boxed(),
79 tree_config: self.tree_config.clone(),
80 max_tree_samples: self.max_tree_samples,
81 replacements: self.replacements,
82 pred_count: self.pred_count,
83 pred_mean: self.pred_mean,
84 pred_m2: self.pred_m2,
85 shadow_warmup: self.shadow_warmup,
86 samples_at_activation: self.samples_at_activation,
87 }
88 }
89}
90
91impl TreeSlot {
92 pub fn new(
97 tree_config: TreeConfig,
98 detector: Box<dyn DriftDetector>,
99 max_tree_samples: Option<u64>,
100 ) -> Self {
101 Self::with_shadow_warmup(tree_config, detector, max_tree_samples, 0)
102 }
103
104 pub fn with_shadow_warmup(
110 tree_config: TreeConfig,
111 detector: Box<dyn DriftDetector>,
112 max_tree_samples: Option<u64>,
113 shadow_warmup: usize,
114 ) -> Self {
115 let alternate = if shadow_warmup > 0 {
116 Some(HoeffdingTree::new(tree_config.clone()))
117 } else {
118 None
119 };
120 Self {
121 active: HoeffdingTree::new(tree_config.clone()),
122 alternate,
123 detector,
124 tree_config,
125 max_tree_samples,
126 replacements: 0,
127 pred_count: 0,
128 pred_mean: 0.0,
129 pred_m2: 0.0,
130 shadow_warmup,
131 samples_at_activation: 0,
132 }
133 }
134
135 pub fn from_trees(
140 active: HoeffdingTree,
141 alternate: Option<HoeffdingTree>,
142 tree_config: TreeConfig,
143 detector: Box<dyn DriftDetector>,
144 max_tree_samples: Option<u64>,
145 ) -> Self {
146 Self {
147 active,
148 alternate,
149 detector,
150 tree_config,
151 max_tree_samples,
152 replacements: 0,
153 pred_count: 0,
154 pred_mean: 0.0,
155 pred_m2: 0.0,
156 shadow_warmup: 0,
157 samples_at_activation: 0,
158 }
159 }
160
161 pub fn train_and_predict(&mut self, features: &[f64], gradient: f64, hessian: f64) -> f64 {
182 let prediction = self.active.predict(features);
184
185 self.pred_count += 1;
187 let delta = prediction - self.pred_mean;
188 self.pred_mean += delta / self.pred_count as f64;
189 let delta2 = prediction - self.pred_mean;
190 self.pred_m2 += delta * delta2;
191
192 self.active.train_one(features, gradient, hessian);
194
195 if let Some(ref mut alt) = self.alternate {
197 alt.train_one(features, gradient, hessian);
198 }
199
200 let error = crate::math::abs(gradient);
204 let signal = self.detector.update(error);
205
206 match signal {
208 DriftSignal::Stable => {}
209 DriftSignal::Warning => {
210 if self.alternate.is_none() {
212 self.alternate = Some(HoeffdingTree::new(self.tree_config.clone()));
213 }
214 }
215 DriftSignal::Drift => {
216 self.active = self
219 .alternate
220 .take()
221 .unwrap_or_else(|| HoeffdingTree::new(self.tree_config.clone()));
222 self.samples_at_activation = self.active.n_samples_seen();
224 self.detector = self.detector.clone_fresh();
226 self.replacements += 1;
228 self.pred_count = 0;
229 self.pred_mean = 0.0;
230 self.pred_m2 = 0.0;
231 if self.shadow_warmup > 0 {
233 self.alternate = Some(HoeffdingTree::new(self.tree_config.clone()));
234 }
235 }
236 }
237
238 if let Some(max_samples) = self.max_tree_samples {
242 let active_age = self
243 .active
244 .n_samples_seen()
245 .saturating_sub(self.samples_at_activation);
246
247 let threshold = if self.shadow_warmup > 0 {
249 (max_samples as f64 * 1.2) as u64
250 } else {
251 max_samples
252 };
253
254 if active_age >= threshold {
255 self.active = self
256 .alternate
257 .take()
258 .unwrap_or_else(|| HoeffdingTree::new(self.tree_config.clone()));
259 self.samples_at_activation = self.active.n_samples_seen();
261 self.detector = self.detector.clone_fresh();
262 self.replacements += 1;
263 self.pred_count = 0;
264 self.pred_mean = 0.0;
265 self.pred_m2 = 0.0;
266 if self.shadow_warmup > 0 {
268 self.alternate = Some(HoeffdingTree::new(self.tree_config.clone()));
269 }
270 }
271 }
272
273 if self.shadow_warmup > 0 && self.alternate.is_none() {
275 self.alternate = Some(HoeffdingTree::new(self.tree_config.clone()));
276 }
277
278 prediction
279 }
280
281 #[inline]
286 pub fn predict(&self, features: &[f64]) -> f64 {
287 self.active.predict(features)
288 }
289
290 #[inline]
294 pub fn predict_with_variance(&self, features: &[f64]) -> (f64, f64) {
295 self.active.predict_with_variance(features)
296 }
297
298 #[inline]
302 pub fn predict_smooth(&self, features: &[f64], bandwidth: f64) -> f64 {
303 self.active.predict_smooth(features, bandwidth)
304 }
305
306 #[inline]
308 pub fn predict_smooth_auto(&self, features: &[f64], bandwidths: &[f64]) -> f64 {
309 self.active.predict_smooth_auto(features, bandwidths)
310 }
311
312 #[inline]
314 pub fn predict_interpolated(&self, features: &[f64]) -> f64 {
315 self.active.predict_interpolated(features)
316 }
317
318 #[inline]
320 pub fn predict_sibling_interpolated(&self, features: &[f64], bandwidths: &[f64]) -> f64 {
321 self.active
322 .predict_sibling_interpolated(features, bandwidths)
323 }
324
325 pub fn predict_graduated(&self, features: &[f64]) -> f64 {
334 let active_pred = self.active.predict(features);
335
336 if self.shadow_warmup == 0 {
337 return active_pred;
338 }
339
340 let Some(ref shadow) = self.alternate else {
341 return active_pred;
342 };
343
344 let shadow_samples = shadow.n_samples_seen();
345 if shadow_samples < self.shadow_warmup as u64 {
346 return active_pred;
347 }
348
349 let shadow_pred = shadow.predict(features);
350 self.blend_active_shadow(active_pred, shadow_pred, shadow_samples)
351 }
352
353 pub fn predict_graduated_sibling_interpolated(
358 &self,
359 features: &[f64],
360 bandwidths: &[f64],
361 ) -> f64 {
362 let active_pred = self
363 .active
364 .predict_sibling_interpolated(features, bandwidths);
365
366 if self.shadow_warmup == 0 {
367 return active_pred;
368 }
369
370 let Some(ref shadow) = self.alternate else {
371 return active_pred;
372 };
373
374 let shadow_samples = shadow.n_samples_seen();
375 if shadow_samples < self.shadow_warmup as u64 {
376 return active_pred;
377 }
378
379 let shadow_pred = shadow.predict_sibling_interpolated(features, bandwidths);
380 self.blend_active_shadow(active_pred, shadow_pred, shadow_samples)
381 }
382
383 #[inline]
385 fn blend_active_shadow(&self, active_pred: f64, shadow_pred: f64, shadow_samples: u64) -> f64 {
386 let active_age = self
387 .active
388 .n_samples_seen()
389 .saturating_sub(self.samples_at_activation);
390 let mts = self.max_tree_samples.unwrap_or(u64::MAX) as f64;
391
392 let active_w = if (active_age as f64) < mts * 0.8 {
394 1.0
395 } else {
396 let progress = (active_age as f64 - mts * 0.8) / (mts * 0.4);
397 (1.0 - progress).clamp(0.0, 1.0)
398 };
399
400 let shadow_w = ((shadow_samples as f64 - self.shadow_warmup as f64)
402 / self.shadow_warmup as f64)
403 .clamp(0.0, 1.0);
404
405 let total = active_w + shadow_w;
407 if total < 1e-10 {
408 return shadow_pred;
409 }
410
411 (active_w * active_pred + shadow_w * shadow_pred) / total
412 }
413
414 #[inline]
416 pub fn shadow_warmup(&self) -> usize {
417 self.shadow_warmup
418 }
419
420 #[inline]
422 pub fn replacements(&self) -> u64 {
423 self.replacements
424 }
425
426 #[inline]
428 pub fn prediction_mean(&self) -> f64 {
429 self.pred_mean
430 }
431
432 #[inline]
434 pub fn prediction_std(&self) -> f64 {
435 if self.pred_count < 2 {
436 0.0
437 } else {
438 crate::math::sqrt(self.pred_m2 / (self.pred_count - 1) as f64)
439 }
440 }
441
442 #[inline]
444 pub fn n_leaves(&self) -> usize {
445 self.active.n_leaves()
446 }
447
448 #[inline]
450 pub fn n_samples_seen(&self) -> u64 {
451 self.active.n_samples_seen()
452 }
453
454 #[inline]
456 pub fn has_alternate(&self) -> bool {
457 self.alternate.is_some()
458 }
459
460 #[inline]
462 pub fn split_gains(&self) -> &[f64] {
463 self.active.split_gains()
464 }
465
466 #[inline]
468 pub fn active_tree(&self) -> &HoeffdingTree {
469 &self.active
470 }
471
472 #[inline]
474 pub fn alternate_tree(&self) -> Option<&HoeffdingTree> {
475 self.alternate.as_ref()
476 }
477
478 #[inline]
480 pub fn tree_config(&self) -> &TreeConfig {
481 &self.tree_config
482 }
483
484 #[inline]
486 pub fn detector(&self) -> &dyn DriftDetector {
487 &*self.detector
488 }
489
490 #[inline]
492 pub fn detector_mut(&mut self) -> &mut dyn DriftDetector {
493 &mut *self.detector
494 }
495
496 #[inline]
500 pub fn alt_detector(&self) -> Option<&dyn DriftDetector> {
501 None
503 }
504
505 #[inline]
507 pub fn alt_detector_mut(&mut self) -> Option<&mut dyn DriftDetector> {
508 None
509 }
510
511 pub fn reset(&mut self) {
513 self.active = HoeffdingTree::new(self.tree_config.clone());
514 self.alternate = if self.shadow_warmup > 0 {
515 Some(HoeffdingTree::new(self.tree_config.clone()))
516 } else {
517 None
518 };
519 self.detector = self.detector.clone_fresh();
520 self.replacements = 0;
521 self.pred_count = 0;
522 self.pred_mean = 0.0;
523 self.pred_m2 = 0.0;
524 self.samples_at_activation = 0;
525 }
526}
527
#[cfg(test)]
mod tests {
    use super::*;
    use crate::drift::pht::PageHinkleyTest;
    use alloc::boxed::Box;
    use alloc::format;

    /// Small, fast tree configuration shared by all tests.
    fn test_tree_config() -> TreeConfig {
        TreeConfig::new()
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .lambda(1.0)
    }

    /// Default Page-Hinkley drift detector.
    fn test_detector() -> Box<dyn DriftDetector> {
        Box::new(PageHinkleyTest::new())
    }

    /// An untrained slot predicts (approximately) zero.
    #[test]
    fn new_slot_predicts_zero() {
        let slot = TreeSlot::new(test_tree_config(), test_detector(), None);

        let pred = slot.predict(&[1.0, 2.0, 3.0]);
        assert!(
            pred.abs() < 1e-12,
            "fresh slot should predict ~0.0, got {}",
            pred,
        );
    }

    /// `train_and_predict` returns the pre-training prediction and leaves the slot
    /// able to predict finite values.
    #[test]
    fn train_and_predict_returns_prediction() {
        let mut slot = TreeSlot::new(test_tree_config(), test_detector(), None);

        let features = [1.0, 2.0, 3.0];
        let pred = slot.train_and_predict(&features, -0.5, 1.0);

        // Prediction is taken before training, so the first one comes from an empty tree.
        assert!(
            pred.abs() < 1e-12,
            "first prediction should be ~0.0, got {}",
            pred,
        );

        let pred2 = slot.predict(&features);
        assert!(
            pred2.is_finite(),
            "prediction after training should be finite"
        );
    }

    /// A constant error stream should never raise a warning, so no alternate appears.
    #[test]
    fn stable_stream_no_alternate() {
        let mut slot = TreeSlot::new(test_tree_config(), test_detector(), None);
        let features = [1.0, 2.0, 3.0];

        for _ in 0..500 {
            slot.train_and_predict(&features, -0.1, 1.0);
        }

        assert!(
            !slot.has_alternate(),
            "stable error stream should not spawn an alternate tree",
        );
    }

    /// `reset` discards all learned state in non-graduated mode.
    #[test]
    fn reset_returns_to_fresh_state() {
        let mut slot = TreeSlot::new(test_tree_config(), test_detector(), None);
        let features = [1.0, 2.0, 3.0];

        for _ in 0..100 {
            slot.train_and_predict(&features, -0.5, 1.0);
        }

        assert!(slot.n_samples_seen() > 0, "should have trained samples");

        slot.reset();

        assert_eq!(
            slot.n_leaves(),
            1,
            "after reset, should have exactly 1 leaf"
        );
        assert_eq!(
            slot.n_samples_seen(),
            0,
            "after reset, samples_seen should be 0"
        );
        assert!(
            !slot.has_alternate(),
            "after reset, no alternate should exist"
        );

        let pred = slot.predict(&features);
        assert!(
            pred.abs() < 1e-12,
            "prediction after reset should be ~0.0, got {}",
            pred,
        );
    }

    /// An untrained slot predicts ~0 for arbitrary inputs.
    #[test]
    fn predict_without_training() {
        let slot = TreeSlot::new(test_tree_config(), test_detector(), None);

        for i in 0..10 {
            let x = (i as f64) * 0.5;
            let pred = slot.predict(&[x, x + 1.0]);
            assert!(
                pred.abs() < 1e-12,
                "untrained slot should predict ~0.0 for any input, got {} at i={}",
                pred,
                i,
            );
        }
    }

    /// An abrupt jump in |gradient| triggers drift, which swaps in a younger tree.
    #[test]
    fn drift_replaces_active_tree() {
        let sensitive_detector = Box::new(PageHinkleyTest::with_params(0.005, 5.0));
        let mut slot = TreeSlot::new(test_tree_config(), sensitive_detector, None);
        let features = [1.0, 2.0, 3.0];

        for _ in 0..200 {
            slot.train_and_predict(&features, -0.01, 1.0);
        }
        let samples_before_drift = slot.n_samples_seen();

        let mut drift_occurred = false;
        for _ in 0..500 {
            slot.train_and_predict(&features, -50.0, 1.0);
            // A replaced active tree has seen fewer samples than the old one.
            if slot.n_samples_seen() < samples_before_drift {
                drift_occurred = true;
                break;
            }
        }

        assert!(
            drift_occurred,
            "abrupt gradient shift should trigger drift and replace the active tree",
        );
    }

    /// A fresh slot's active tree is a single leaf.
    #[test]
    fn n_leaves_reflects_active_tree() {
        let slot = TreeSlot::new(test_tree_config(), test_detector(), None);
        assert_eq!(slot.n_leaves(), 1, "fresh slot should have exactly 1 leaf",);
    }

    /// The manual `Debug` impl formats without panicking and names the type.
    #[test]
    fn debug_format_does_not_panic() {
        let slot = TreeSlot::new(test_tree_config(), test_detector(), None);
        let debug_str = format!("{:?}", slot);
        assert!(
            debug_str.contains("TreeSlot"),
            "debug output should contain 'TreeSlot'",
        );
    }

    /// Reaching `max_tree_samples` replaces the active tree.
    #[test]
    fn time_based_replacement_triggers() {
        let mut slot = TreeSlot::new(test_tree_config(), test_detector(), Some(200));
        let features = [1.0, 2.0, 3.0];

        for _ in 0..200 {
            slot.train_and_predict(&features, -0.1, 1.0);
        }

        assert!(
            slot.n_samples_seen() < 200,
            "after 200 samples with max_tree_samples=200, tree should be replaced (got {} samples)",
            slot.n_samples_seen(),
        );
    }

    /// Without `max_tree_samples` the active tree keeps every sample.
    #[test]
    fn time_based_replacement_disabled() {
        let mut slot = TreeSlot::new(test_tree_config(), test_detector(), None);
        let features = [1.0, 2.0, 3.0];

        for _ in 0..500 {
            slot.train_and_predict(&features, -0.1, 1.0);
        }

        assert_eq!(
            slot.n_samples_seen(),
            500,
            "without max_tree_samples, tree should never be proactively replaced",
        );
    }

    /// Graduated mode creates its shadow tree at construction time.
    #[test]
    fn graduated_shadow_spawns_immediately() {
        let slot = TreeSlot::with_shadow_warmup(test_tree_config(), test_detector(), Some(200), 50);

        assert!(
            slot.has_alternate(),
            "graduated mode should spawn shadow immediately"
        );
        assert_eq!(slot.shadow_warmup(), 50);
    }

    /// Graduated predictions stay finite after training.
    #[test]
    fn graduated_predict_returns_finite() {
        let mut slot =
            TreeSlot::with_shadow_warmup(test_tree_config(), test_detector(), Some(200), 50);
        let features = [1.0, 2.0, 3.0];

        for _ in 0..100 {
            slot.train_and_predict(&features, -0.1, 1.0);
        }

        let pred = slot.predict_graduated(&features);
        assert!(
            pred.is_finite(),
            "graduated prediction should be finite: {}",
            pred
        );
    }

    /// After a time-based swap in graduated mode, a new shadow is spawned.
    #[test]
    fn graduated_shadow_always_respawns() {
        let mut slot =
            TreeSlot::with_shadow_warmup(test_tree_config(), test_detector(), Some(100), 30);
        let features = [1.0, 2.0, 3.0];

        // 130 samples >= 1.2 * 100, so at least one soft replacement happened.
        for _ in 0..130 {
            slot.train_and_predict(&features, -0.1, 1.0);
        }

        assert!(
            slot.has_alternate(),
            "shadow should be respawned after soft replacement"
        );
    }

    /// Inside the fade window, the graduated prediction is a convex combination of
    /// the active and shadow predictions.
    #[test]
    fn graduated_blending_produces_intermediate_values() {
        let mut slot =
            TreeSlot::with_shadow_warmup(test_tree_config(), test_detector(), Some(200), 50);
        let features = [1.0, 2.0, 3.0];

        // 180 samples: the active age lies in the fade window [160, 240), assuming no
        // drift on this constant stream, and the shadow is past its 50-sample warm-up.
        for _ in 0..180 {
            slot.train_and_predict(&features, -0.1, 1.0);
        }

        let active_pred = slot.predict(&features);
        let graduated_pred = slot.predict_graduated(&features);

        assert!(active_pred.is_finite());
        assert!(graduated_pred.is_finite());

        let shadow_pred = slot
            .alternate_tree()
            .expect("graduated mode keeps a shadow tree alive")
            .predict(&features);
        let lo = active_pred.min(shadow_pred);
        let hi = active_pred.max(shadow_pred);
        let tol = 1e-9 * (1.0 + lo.abs().max(hi.abs()));
        assert!(
            graduated_pred >= lo - tol && graduated_pred <= hi + tol,
            "graduated prediction {} should lie between active {} and shadow {}",
            graduated_pred,
            active_pred,
            shadow_pred,
        );
    }

    /// `reset` in graduated mode spawns a fresh shadow again.
    #[test]
    fn graduated_reset_preserves_shadow() {
        let mut slot =
            TreeSlot::with_shadow_warmup(test_tree_config(), test_detector(), Some(200), 50);

        slot.reset();

        assert!(
            slot.has_alternate(),
            "reset in graduated mode should preserve shadow spawning"
        );
    }

    /// A freshly promoted tree is not immediately replaced again.
    #[test]
    fn graduated_no_cascading_swap() {
        let mut slot =
            TreeSlot::with_shadow_warmup(test_tree_config(), test_detector(), Some(200), 50);

        // Graduated threshold is 1.2 * 200 = 240, so the first swap happens within 250.
        for i in 0..250 {
            let x = (i as f64) * 0.1;
            let features = [x, x.sin(), x.cos()];
            let gradient = -0.1 * (1.0 + x.sin());
            slot.train_and_predict(&features, gradient, 1.0);
        }

        let replacements_after_first_swap = slot.replacements();
        assert!(
            replacements_after_first_swap >= 1,
            "should have swapped at least once after 250 samples with mts=200"
        );

        // The new active tree's age is measured from promotion, so 50 more samples
        // stay far below the 240-sample threshold.
        for i in 250..300 {
            let x = (i as f64) * 0.1;
            let features = [x, x.sin(), x.cos()];
            slot.train_and_predict(&features, -0.1, 1.0);
        }

        assert_eq!(
            slot.replacements(),
            replacements_after_first_swap,
            "should not cascade-swap immediately after promotion"
        );
    }

    /// Graduated mode without `max_tree_samples` still yields finite predictions.
    #[test]
    fn graduated_without_max_tree_samples_still_works() {
        let mut slot = TreeSlot::with_shadow_warmup(test_tree_config(), test_detector(), None, 50);
        let features = [1.0, 2.0, 3.0];

        for _ in 0..100 {
            slot.train_and_predict(&features, -0.1, 1.0);
        }

        let pred = slot.predict_graduated(&features);
        assert!(pred.is_finite(), "graduated without mts should be finite");
    }
}
906}