1use alloc::boxed::Box;
20use core::fmt;
21
22use crate::drift::{DriftDetector, DriftSignal};
23use crate::tree::builder::TreeConfig;
24use crate::tree::hoeffding::HoeffdingTree;
25use crate::tree::StreamingTree;
26
27pub struct TreeSlot {
35 active: HoeffdingTree,
37 alternate: Option<HoeffdingTree>,
39 detector: Box<dyn DriftDetector>,
41 tree_config: TreeConfig,
43 max_tree_samples: Option<u64>,
45 replacements: u64,
47 pred_count: u64,
49 pred_mean: f64,
51 pred_m2: f64,
53 shadow_warmup: usize,
56 samples_at_activation: u64,
60}
61
62impl fmt::Debug for TreeSlot {
63 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
64 f.debug_struct("TreeSlot")
65 .field("active_leaves", &self.active.n_leaves())
66 .field("active_samples", &self.active.n_samples_seen())
67 .field("has_alternate", &self.alternate.is_some())
68 .field("tree_config", &self.tree_config)
69 .finish()
70 }
71}
72
73impl Clone for TreeSlot {
74 fn clone(&self) -> Self {
75 Self {
76 active: self.active.clone(),
77 alternate: self.alternate.clone(),
78 detector: self.detector.clone_boxed(),
79 tree_config: self.tree_config.clone(),
80 max_tree_samples: self.max_tree_samples,
81 replacements: self.replacements,
82 pred_count: self.pred_count,
83 pred_mean: self.pred_mean,
84 pred_m2: self.pred_m2,
85 shadow_warmup: self.shadow_warmup,
86 samples_at_activation: self.samples_at_activation,
87 }
88 }
89}
90
91impl TreeSlot {
92 pub fn new(
97 tree_config: TreeConfig,
98 detector: Box<dyn DriftDetector>,
99 max_tree_samples: Option<u64>,
100 ) -> Self {
101 Self::with_shadow_warmup(tree_config, detector, max_tree_samples, 0)
102 }
103
104 pub fn with_shadow_warmup(
110 tree_config: TreeConfig,
111 detector: Box<dyn DriftDetector>,
112 max_tree_samples: Option<u64>,
113 shadow_warmup: usize,
114 ) -> Self {
115 let alternate = if shadow_warmup > 0 {
116 Some(HoeffdingTree::new(tree_config.clone()))
117 } else {
118 None
119 };
120 Self {
121 active: HoeffdingTree::new(tree_config.clone()),
122 alternate,
123 detector,
124 tree_config,
125 max_tree_samples,
126 replacements: 0,
127 pred_count: 0,
128 pred_mean: 0.0,
129 pred_m2: 0.0,
130 shadow_warmup,
131 samples_at_activation: 0,
132 }
133 }
134
135 pub fn from_trees(
140 active: HoeffdingTree,
141 alternate: Option<HoeffdingTree>,
142 tree_config: TreeConfig,
143 detector: Box<dyn DriftDetector>,
144 max_tree_samples: Option<u64>,
145 ) -> Self {
146 Self {
147 active,
148 alternate,
149 detector,
150 tree_config,
151 max_tree_samples,
152 replacements: 0,
153 pred_count: 0,
154 pred_mean: 0.0,
155 pred_m2: 0.0,
156 shadow_warmup: 0,
157 samples_at_activation: 0,
158 }
159 }
160
161 pub fn train_and_predict(&mut self, features: &[f64], gradient: f64, hessian: f64) -> f64 {
182 let prediction = self.active.predict(features);
184
185 self.pred_count += 1;
187 let delta = prediction - self.pred_mean;
188 self.pred_mean += delta / self.pred_count as f64;
189 let delta2 = prediction - self.pred_mean;
190 self.pred_m2 += delta * delta2;
191
192 self.active.train_one(features, gradient, hessian);
194
195 if let Some(ref mut alt) = self.alternate {
197 alt.train_one(features, gradient, hessian);
198 }
199
200 let error = crate::math::abs(gradient);
204 let signal = self.detector.update(error);
205
206 match signal {
208 DriftSignal::Stable => {}
209 DriftSignal::Warning => {
210 if self.alternate.is_none() {
212 self.alternate = Some(HoeffdingTree::new(self.tree_config.clone()));
213 }
214 }
215 DriftSignal::Drift => {
216 self.active = self
219 .alternate
220 .take()
221 .unwrap_or_else(|| HoeffdingTree::new(self.tree_config.clone()));
222 self.samples_at_activation = self.active.n_samples_seen();
224 self.detector = self.detector.clone_fresh();
226 self.replacements += 1;
228 self.pred_count = 0;
229 self.pred_mean = 0.0;
230 self.pred_m2 = 0.0;
231 if self.shadow_warmup > 0 {
233 self.alternate = Some(HoeffdingTree::new(self.tree_config.clone()));
234 }
235 }
236 }
237
238 if let Some(max_samples) = self.max_tree_samples {
242 let active_age = self
243 .active
244 .n_samples_seen()
245 .saturating_sub(self.samples_at_activation);
246
247 let threshold = if self.shadow_warmup > 0 {
249 (max_samples as f64 * 1.2) as u64
250 } else {
251 max_samples
252 };
253
254 if active_age >= threshold {
255 self.active = self
256 .alternate
257 .take()
258 .unwrap_or_else(|| HoeffdingTree::new(self.tree_config.clone()));
259 self.samples_at_activation = self.active.n_samples_seen();
261 self.detector = self.detector.clone_fresh();
262 self.replacements += 1;
263 self.pred_count = 0;
264 self.pred_mean = 0.0;
265 self.pred_m2 = 0.0;
266 if self.shadow_warmup > 0 {
268 self.alternate = Some(HoeffdingTree::new(self.tree_config.clone()));
269 }
270 }
271 }
272
273 if self.shadow_warmup > 0 && self.alternate.is_none() {
275 self.alternate = Some(HoeffdingTree::new(self.tree_config.clone()));
276 }
277
278 prediction
279 }
280
281 #[inline]
286 pub fn predict(&self, features: &[f64]) -> f64 {
287 self.active.predict(features)
288 }
289
290 #[inline]
294 pub fn predict_with_variance(&self, features: &[f64]) -> (f64, f64) {
295 self.active.predict_with_variance(features)
296 }
297
298 #[inline]
302 pub fn predict_smooth(&self, features: &[f64], bandwidth: f64) -> f64 {
303 self.active.predict_smooth(features, bandwidth)
304 }
305
306 #[inline]
308 pub fn predict_smooth_auto(&self, features: &[f64], bandwidths: &[f64]) -> f64 {
309 self.active.predict_smooth_auto(features, bandwidths)
310 }
311
312 #[inline]
314 pub fn predict_interpolated(&self, features: &[f64]) -> f64 {
315 self.active.predict_interpolated(features)
316 }
317
318 #[inline]
320 pub fn predict_sibling_interpolated(&self, features: &[f64], bandwidths: &[f64]) -> f64 {
321 self.active
322 .predict_sibling_interpolated(features, bandwidths)
323 }
324
325 #[inline]
327 pub fn predict_soft_routed(&self, features: &[f64]) -> f64 {
328 self.active.predict_soft_routed(features)
329 }
330
331 pub fn predict_graduated(&self, features: &[f64]) -> f64 {
340 let active_pred = self.active.predict(features);
341
342 if self.shadow_warmup == 0 {
343 return active_pred;
344 }
345
346 let Some(ref shadow) = self.alternate else {
347 return active_pred;
348 };
349
350 let shadow_samples = shadow.n_samples_seen();
351 if shadow_samples < self.shadow_warmup as u64 {
352 return active_pred;
353 }
354
355 let shadow_pred = shadow.predict(features);
356 self.blend_active_shadow(active_pred, shadow_pred, shadow_samples)
357 }
358
359 pub fn predict_graduated_sibling_interpolated(
364 &self,
365 features: &[f64],
366 bandwidths: &[f64],
367 ) -> f64 {
368 let active_pred = self
369 .active
370 .predict_sibling_interpolated(features, bandwidths);
371
372 if self.shadow_warmup == 0 {
373 return active_pred;
374 }
375
376 let Some(ref shadow) = self.alternate else {
377 return active_pred;
378 };
379
380 let shadow_samples = shadow.n_samples_seen();
381 if shadow_samples < self.shadow_warmup as u64 {
382 return active_pred;
383 }
384
385 let shadow_pred = shadow.predict_sibling_interpolated(features, bandwidths);
386 self.blend_active_shadow(active_pred, shadow_pred, shadow_samples)
387 }
388
389 #[inline]
391 fn blend_active_shadow(&self, active_pred: f64, shadow_pred: f64, shadow_samples: u64) -> f64 {
392 let active_age = self
393 .active
394 .n_samples_seen()
395 .saturating_sub(self.samples_at_activation);
396 let mts = self.max_tree_samples.unwrap_or(u64::MAX) as f64;
397
398 let active_w = if (active_age as f64) < mts * 0.8 {
400 1.0
401 } else {
402 let progress = (active_age as f64 - mts * 0.8) / (mts * 0.4);
403 (1.0 - progress).clamp(0.0, 1.0)
404 };
405
406 let shadow_w = ((shadow_samples as f64 - self.shadow_warmup as f64)
408 / self.shadow_warmup as f64)
409 .clamp(0.0, 1.0);
410
411 let total = active_w + shadow_w;
413 if total < 1e-10 {
414 return shadow_pred;
415 }
416
417 (active_w * active_pred + shadow_w * shadow_pred) / total
418 }
419
420 #[inline]
424 pub fn set_max_tree_samples(&mut self, max: Option<u64>) {
425 self.max_tree_samples = max;
426 }
427
428 #[inline]
430 pub fn shadow_warmup(&self) -> usize {
431 self.shadow_warmup
432 }
433
434 #[inline]
436 pub fn replacements(&self) -> u64 {
437 self.replacements
438 }
439
440 #[inline]
442 pub fn prediction_mean(&self) -> f64 {
443 self.pred_mean
444 }
445
446 #[inline]
448 pub fn prediction_std(&self) -> f64 {
449 if self.pred_count < 2 {
450 0.0
451 } else {
452 crate::math::sqrt(self.pred_m2 / (self.pred_count - 1) as f64)
453 }
454 }
455
456 #[inline]
458 pub fn n_leaves(&self) -> usize {
459 self.active.n_leaves()
460 }
461
462 #[inline]
464 pub fn n_samples_seen(&self) -> u64 {
465 self.active.n_samples_seen()
466 }
467
468 #[inline]
470 pub fn has_alternate(&self) -> bool {
471 self.alternate.is_some()
472 }
473
474 #[inline]
476 pub fn split_gains(&self) -> &[f64] {
477 self.active.split_gains()
478 }
479
480 #[inline]
482 pub fn active_tree(&self) -> &HoeffdingTree {
483 &self.active
484 }
485
486 #[inline]
488 pub fn alternate_tree(&self) -> Option<&HoeffdingTree> {
489 self.alternate.as_ref()
490 }
491
492 #[inline]
494 pub fn tree_config(&self) -> &TreeConfig {
495 &self.tree_config
496 }
497
498 #[inline]
500 pub fn detector(&self) -> &dyn DriftDetector {
501 &*self.detector
502 }
503
504 #[inline]
506 pub fn detector_mut(&mut self) -> &mut dyn DriftDetector {
507 &mut *self.detector
508 }
509
510 #[inline]
514 pub fn alt_detector(&self) -> Option<&dyn DriftDetector> {
515 None
517 }
518
519 #[inline]
521 pub fn alt_detector_mut(&mut self) -> Option<&mut dyn DriftDetector> {
522 None
523 }
524
525 pub fn reset(&mut self) {
527 self.active = HoeffdingTree::new(self.tree_config.clone());
528 self.alternate = if self.shadow_warmup > 0 {
529 Some(HoeffdingTree::new(self.tree_config.clone()))
530 } else {
531 None
532 };
533 self.detector = self.detector.clone_fresh();
534 self.replacements = 0;
535 self.pred_count = 0;
536 self.pred_mean = 0.0;
537 self.pred_m2 = 0.0;
538 self.samples_at_activation = 0;
539 }
540}
541
/// Unit tests for `TreeSlot`: construction, training, drift- and
/// lifetime-based replacement, and graduated (shadow-blended) prediction.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::drift::pht::PageHinkleyTest;
    use alloc::boxed::Box;
    use alloc::format;

    /// Small tree configuration shared by all tests.
    fn test_tree_config() -> TreeConfig {
        TreeConfig::new()
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .lambda(1.0)
    }

    /// Detector with default Page-Hinkley parameters.
    fn test_detector() -> Box<dyn DriftDetector> {
        Box::new(PageHinkleyTest::new())
    }

    /// Trains `slot` on the same sample `rounds` times with a hessian of 1.0.
    fn feed(slot: &mut TreeSlot, features: &[f64], gradient: f64, rounds: usize) {
        for _ in 0..rounds {
            slot.train_and_predict(features, gradient, 1.0);
        }
    }

    /// Graduated-mode slot with the default test config and detector.
    fn graduated_slot(max_tree_samples: Option<u64>, shadow_warmup: usize) -> TreeSlot {
        TreeSlot::with_shadow_warmup(
            test_tree_config(),
            test_detector(),
            max_tree_samples,
            shadow_warmup,
        )
    }

    #[test]
    fn new_slot_predicts_zero() {
        let slot = TreeSlot::new(test_tree_config(), test_detector(), None);

        let pred = slot.predict(&[1.0, 2.0, 3.0]);
        assert!(
            pred.abs() < 1e-12,
            "fresh slot should predict ~0.0, got {}",
            pred,
        );
    }

    #[test]
    fn train_and_predict_returns_prediction() {
        let mut slot = TreeSlot::new(test_tree_config(), test_detector(), None);
        let features = [1.0, 2.0, 3.0];

        // The returned value is the prediction made before training.
        let first = slot.train_and_predict(&features, -0.5, 1.0);
        assert!(
            first.abs() < 1e-12,
            "first prediction should be ~0.0, got {}",
            first,
        );

        let after = slot.predict(&features);
        assert!(
            after.is_finite(),
            "prediction after training should be finite"
        );
    }

    #[test]
    fn stable_stream_no_alternate() {
        let mut slot = TreeSlot::new(test_tree_config(), test_detector(), None);

        feed(&mut slot, &[1.0, 2.0, 3.0], -0.1, 500);

        assert!(
            !slot.has_alternate(),
            "stable error stream should not spawn an alternate tree",
        );
    }

    #[test]
    fn reset_returns_to_fresh_state() {
        let mut slot = TreeSlot::new(test_tree_config(), test_detector(), None);
        let features = [1.0, 2.0, 3.0];

        feed(&mut slot, &features, -0.5, 100);
        assert!(slot.n_samples_seen() > 0, "should have trained samples");

        slot.reset();

        assert_eq!(
            slot.n_leaves(),
            1,
            "after reset, should have exactly 1 leaf"
        );
        assert_eq!(
            slot.n_samples_seen(),
            0,
            "after reset, samples_seen should be 0"
        );
        assert!(
            !slot.has_alternate(),
            "after reset, no alternate should exist"
        );

        let pred = slot.predict(&features);
        assert!(
            pred.abs() < 1e-12,
            "prediction after reset should be ~0.0, got {}",
            pred,
        );
    }

    #[test]
    fn predict_without_training() {
        let slot = TreeSlot::new(test_tree_config(), test_detector(), None);

        for i in 0..10 {
            let x = (i as f64) * 0.5;
            let pred = slot.predict(&[x, x + 1.0]);
            assert!(
                pred.abs() < 1e-12,
                "untrained slot should predict ~0.0 for any input, got {} at i={}",
                pred,
                i,
            );
        }
    }

    #[test]
    fn drift_replaces_active_tree() {
        let sensitive_detector = Box::new(PageHinkleyTest::with_params(0.005, 5.0));
        let mut slot = TreeSlot::new(test_tree_config(), sensitive_detector, None);
        let features = [1.0, 2.0, 3.0];

        // Quiet phase: tiny, constant errors.
        feed(&mut slot, &features, -0.01, 200);
        let samples_before_drift = slot.n_samples_seen();

        // Abrupt shift: a replacement shows up as a drop in the sample count.
        let drift_occurred = (0..500).any(|_| {
            slot.train_and_predict(&features, -50.0, 1.0);
            slot.n_samples_seen() < samples_before_drift
        });

        assert!(
            drift_occurred,
            "abrupt gradient shift should trigger drift and replace the active tree",
        );
    }

    #[test]
    fn n_leaves_reflects_active_tree() {
        let slot = TreeSlot::new(test_tree_config(), test_detector(), None);
        assert_eq!(slot.n_leaves(), 1, "fresh slot should have exactly 1 leaf",);
    }

    #[test]
    fn debug_format_does_not_panic() {
        let slot = TreeSlot::new(test_tree_config(), test_detector(), None);
        let debug_str = format!("{:?}", slot);
        assert!(
            debug_str.contains("TreeSlot"),
            "debug output should contain 'TreeSlot'",
        );
    }

    #[test]
    fn time_based_replacement_triggers() {
        let mut slot = TreeSlot::new(test_tree_config(), test_detector(), Some(200));

        feed(&mut slot, &[1.0, 2.0, 3.0], -0.1, 200);

        assert!(
            slot.n_samples_seen() < 200,
            "after 200 samples with max_tree_samples=200, tree should be replaced (got {} samples)",
            slot.n_samples_seen(),
        );
    }

    #[test]
    fn time_based_replacement_disabled() {
        let mut slot = TreeSlot::new(test_tree_config(), test_detector(), None);

        feed(&mut slot, &[1.0, 2.0, 3.0], -0.1, 500);

        assert_eq!(
            slot.n_samples_seen(),
            500,
            "without max_tree_samples, tree should never be proactively replaced",
        );
    }

    #[test]
    fn graduated_shadow_spawns_immediately() {
        let slot = graduated_slot(Some(200), 50);

        assert!(
            slot.has_alternate(),
            "graduated mode should spawn shadow immediately"
        );
        assert_eq!(slot.shadow_warmup(), 50);
    }

    #[test]
    fn graduated_predict_returns_finite() {
        let mut slot = graduated_slot(Some(200), 50);
        let features = [1.0, 2.0, 3.0];

        feed(&mut slot, &features, -0.1, 100);

        let pred = slot.predict_graduated(&features);
        assert!(
            pred.is_finite(),
            "graduated prediction should be finite: {}",
            pred
        );
    }

    #[test]
    fn graduated_shadow_always_respawns() {
        let mut slot = graduated_slot(Some(100), 30);

        // 130 samples is past the 120% replacement threshold for mts=100.
        feed(&mut slot, &[1.0, 2.0, 3.0], -0.1, 130);

        assert!(
            slot.has_alternate(),
            "shadow should be respawned after soft replacement"
        );
    }

    #[test]
    fn graduated_blending_produces_intermediate_values() {
        let mut slot = graduated_slot(Some(200), 50);
        let features = [1.0, 2.0, 3.0];

        // 180 samples: past 80% of mts=200 (blending ramp active) and past
        // the shadow warmup, but before the 120% replacement threshold.
        feed(&mut slot, &features, -0.1, 180);

        let active_pred = slot.predict(&features);
        let shadow_pred = slot
            .alternate_tree()
            .expect("graduated mode should always keep a shadow tree")
            .predict(&features);
        let graduated_pred = slot.predict_graduated(&features);

        assert!(active_pred.is_finite());
        assert!(shadow_pred.is_finite());
        assert!(graduated_pred.is_finite());

        // A weighted average can never leave the interval spanned by its
        // inputs (small tolerance for floating-point rounding).
        let (lo, hi) = if active_pred <= shadow_pred {
            (active_pred, shadow_pred)
        } else {
            (shadow_pred, active_pred)
        };
        assert!(
            graduated_pred >= lo - 1e-9 && graduated_pred <= hi + 1e-9,
            "blended prediction {} should lie between active {} and shadow {}",
            graduated_pred,
            active_pred,
            shadow_pred,
        );
    }

    #[test]
    fn graduated_reset_preserves_shadow() {
        let mut slot = graduated_slot(Some(200), 50);

        slot.reset();

        assert!(
            slot.has_alternate(),
            "reset in graduated mode should preserve shadow spawning"
        );
    }

    #[test]
    fn graduated_no_cascading_swap() {
        let mut slot = graduated_slot(Some(200), 50);

        // Feature vector for a given step: [x, sin x, cos x] with x = 0.1 * step.
        let features_at = |step: u32| {
            let x = (step as f64) * 0.1;
            [x, x.sin(), x.cos()]
        };

        for step in 0..250 {
            let features = features_at(step);
            // features[1] is sin(x), so this is -0.1 * (1 + sin x).
            let gradient = -0.1 * (1.0 + features[1]);
            slot.train_and_predict(&features, gradient, 1.0);
        }

        let replacements_after_first_swap = slot.replacements();
        assert!(
            replacements_after_first_swap >= 1,
            "should have swapped at least once after 250 samples with mts=200"
        );

        for step in 250..300 {
            let features = features_at(step);
            slot.train_and_predict(&features, -0.1, 1.0);
        }

        assert_eq!(
            slot.replacements(),
            replacements_after_first_swap,
            "should not cascade-swap immediately after promotion"
        );
    }

    #[test]
    fn graduated_without_max_tree_samples_still_works() {
        let mut slot = graduated_slot(None, 50);
        let features = [1.0, 2.0, 3.0];

        feed(&mut slot, &features, -0.1, 100);

        let pred = slot.predict_graduated(&features);
        assert!(pred.is_finite(), "graduated without mts should be finite");
    }
}