irithyll_core/ensemble/replacement.rs
1//! TreeSlot: warning/danger/swap lifecycle for tree replacement.
2//!
3//! Implements the drift-triggered tree replacement strategy from
4//! Gunasekara et al. (2024). Each boosting step owns a "slot" that manages:
5//!
6//! - An **active** tree serving predictions and receiving training samples.
7//! - An optional **alternate** tree that begins training when a Warning signal
8//!   is emitted by the drift detector.
9//! - A **drift detector** monitoring prediction error magnitude.
10//!
11//! # Lifecycle
12//!
13//! ```text
14//!   Stable  --> keep training active tree
15//!   Warning --> spawn alternate tree (if not already training)
16//!   Drift   --> replace active with alternate (or fresh tree), reset detector
17//! ```
18
19use alloc::boxed::Box;
20use core::fmt;
21
22use crate::drift::{DriftDetector, DriftSignal};
23use crate::tree::builder::TreeConfig;
24use crate::tree::hoeffding::HoeffdingTree;
25use crate::tree::StreamingTree;
26
/// Manages the lifecycle of a single tree in the ensemble.
///
/// When the drift detector signals [`DriftSignal::Warning`], an alternate tree
/// begins training alongside the active tree. When [`DriftSignal::Drift`] is
/// confirmed, the alternate replaces the active tree (or a fresh tree is
/// created if no alternate exists). The drift detector is then reset via
/// [`clone_fresh`](DriftDetector::clone_fresh) to monitor the new tree.
///
/// In graduated mode (`shadow_warmup > 0`, see
/// [`with_shadow_warmup`](Self::with_shadow_warmup)) the `alternate` field
/// doubles as an always-on "shadow" tree that is respawned after every swap.
pub struct TreeSlot {
    /// The currently active tree serving predictions.
    active: HoeffdingTree,
    /// Optional alternate tree being trained during a warning period.
    /// In graduated mode this slot holds the always-on shadow tree instead.
    alternate: Option<HoeffdingTree>,
    /// Drift detector monitoring this slot's error stream.
    detector: Box<dyn DriftDetector>,
    /// Configuration for creating new trees (shared across replacements).
    tree_config: TreeConfig,
    /// Maximum samples before proactive replacement. `None` = disabled.
    max_tree_samples: Option<u64>,
    /// Total number of tree replacements (drift or time-based).
    replacements: u64,
    /// Welford online count for prediction statistics.
    pred_count: u64,
    /// Welford online mean of predictions.
    pred_mean: f64,
    /// Welford online M2 accumulator for prediction variance.
    pred_m2: f64,
    /// Shadow warmup samples (0 = disabled). When > 0, an always-on shadow
    /// tree is spawned and trained alongside the active tree.
    shadow_warmup: usize,
    /// Sample count of the active tree when it was activated (promoted from shadow).
    /// Used to compute samples-since-activation for time-based replacement,
    /// preventing cascading swaps when a shadow is promoted with a high sample count.
    samples_at_activation: u64,
}
61
62impl fmt::Debug for TreeSlot {
63    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
64        f.debug_struct("TreeSlot")
65            .field("active_leaves", &self.active.n_leaves())
66            .field("active_samples", &self.active.n_samples_seen())
67            .field("has_alternate", &self.alternate.is_some())
68            .field("tree_config", &self.tree_config)
69            .finish()
70    }
71}
72
73impl Clone for TreeSlot {
74    fn clone(&self) -> Self {
75        Self {
76            active: self.active.clone(),
77            alternate: self.alternate.clone(),
78            detector: self.detector.clone_boxed(),
79            tree_config: self.tree_config.clone(),
80            max_tree_samples: self.max_tree_samples,
81            replacements: self.replacements,
82            pred_count: self.pred_count,
83            pred_mean: self.pred_mean,
84            pred_m2: self.pred_m2,
85            shadow_warmup: self.shadow_warmup,
86            samples_at_activation: self.samples_at_activation,
87        }
88    }
89}
90
91impl TreeSlot {
92    /// Create a new `TreeSlot` with a fresh tree and drift detector.
93    ///
94    /// The active tree starts as a single-leaf tree (prediction = 0.0).
95    /// No alternate tree is created until a Warning signal is received.
96    pub fn new(
97        tree_config: TreeConfig,
98        detector: Box<dyn DriftDetector>,
99        max_tree_samples: Option<u64>,
100    ) -> Self {
101        Self::with_shadow_warmup(tree_config, detector, max_tree_samples, 0)
102    }
103
104    /// Create a new `TreeSlot` with graduated tree handoff enabled.
105    ///
106    /// When `shadow_warmup > 0`, an always-on shadow tree is spawned immediately
107    /// and trained alongside the active tree. After `shadow_warmup` samples, the
108    /// shadow begins contributing to `predict_graduated()` predictions.
109    pub fn with_shadow_warmup(
110        tree_config: TreeConfig,
111        detector: Box<dyn DriftDetector>,
112        max_tree_samples: Option<u64>,
113        shadow_warmup: usize,
114    ) -> Self {
115        let alternate = if shadow_warmup > 0 {
116            Some(HoeffdingTree::new(tree_config.clone()))
117        } else {
118            None
119        };
120        Self {
121            active: HoeffdingTree::new(tree_config.clone()),
122            alternate,
123            detector,
124            tree_config,
125            max_tree_samples,
126            replacements: 0,
127            pred_count: 0,
128            pred_mean: 0.0,
129            pred_m2: 0.0,
130            shadow_warmup,
131            samples_at_activation: 0,
132        }
133    }
134
135    /// Reconstruct a `TreeSlot` from pre-built trees and a fresh drift detector.
136    ///
137    /// Used during model deserialization to restore tree state without replaying
138    /// the training stream.
139    pub fn from_trees(
140        active: HoeffdingTree,
141        alternate: Option<HoeffdingTree>,
142        tree_config: TreeConfig,
143        detector: Box<dyn DriftDetector>,
144        max_tree_samples: Option<u64>,
145    ) -> Self {
146        Self {
147            active,
148            alternate,
149            detector,
150            tree_config,
151            max_tree_samples,
152            replacements: 0,
153            pred_count: 0,
154            pred_mean: 0.0,
155            pred_m2: 0.0,
156            shadow_warmup: 0,
157            samples_at_activation: 0,
158        }
159    }
160
161    /// Train the active tree (and alternate if it exists) on a single sample.
162    ///
163    /// The absolute value of the gradient is fed to the drift detector as an
164    /// error proxy (gradient = derivative of loss = prediction error signal).
165    ///
166    /// # Returns
167    ///
168    /// The prediction from the **active** tree **before** training on this sample.
169    /// This ensures the prediction reflects only previously seen data, which is
170    /// critical for unbiased gradient computation in the boosting loop.
171    ///
172    /// # Drift handling
173    ///
174    /// - [`DriftSignal::Stable`]: no action.
175    /// - [`DriftSignal::Warning`]: spawn an alternate tree if one is not already
176    ///   being trained. The alternate receives the same training sample.
177    /// - [`DriftSignal::Drift`]: replace the active tree with the alternate
178    ///   (or a fresh tree if no alternate exists). The drift detector is reset
179    ///   via [`clone_fresh`](DriftDetector::clone_fresh) so it monitors the
180    ///   new tree from a clean state.
181    pub fn train_and_predict(&mut self, features: &[f64], gradient: f64, hessian: f64) -> f64 {
182        // 1. Predict from active tree BEFORE training.
183        let prediction = self.active.predict(features);
184
185        // 1b. Update Welford running prediction statistics.
186        self.pred_count += 1;
187        let delta = prediction - self.pred_mean;
188        self.pred_mean += delta / self.pred_count as f64;
189        let delta2 = prediction - self.pred_mean;
190        self.pred_m2 += delta * delta2;
191
192        // 2. Train the active tree.
193        self.active.train_one(features, gradient, hessian);
194
195        // 3. Train the alternate tree if it exists.
196        if let Some(ref mut alt) = self.alternate {
197            alt.train_one(features, gradient, hessian);
198        }
199
200        // 4. Feed error magnitude to the drift detector.
201        //    |gradient| is a proxy for prediction error: for squared loss,
202        //    gradient = (prediction - target), so |gradient| = |error|.
203        let error = crate::math::abs(gradient);
204        let signal = self.detector.update(error);
205
206        // 5. React to the drift signal.
207        match signal {
208            DriftSignal::Stable => {}
209            DriftSignal::Warning => {
210                // Start training an alternate tree if not already doing so.
211                if self.alternate.is_none() {
212                    self.alternate = Some(HoeffdingTree::new(self.tree_config.clone()));
213                }
214            }
215            DriftSignal::Drift => {
216                // Replace active tree: prefer the alternate (which has been
217                // training on recent data), fall back to a fresh tree.
218                self.active = self
219                    .alternate
220                    .take()
221                    .unwrap_or_else(|| HoeffdingTree::new(self.tree_config.clone()));
222                // Record activation point to prevent cascading swaps.
223                self.samples_at_activation = self.active.n_samples_seen();
224                // Reset the drift detector to monitor the new tree cleanly.
225                self.detector = self.detector.clone_fresh();
226                // Track replacement and reset prediction stats for the new tree.
227                self.replacements += 1;
228                self.pred_count = 0;
229                self.pred_mean = 0.0;
230                self.pred_m2 = 0.0;
231                // In graduated mode, immediately spawn a new shadow.
232                if self.shadow_warmup > 0 {
233                    self.alternate = Some(HoeffdingTree::new(self.tree_config.clone()));
234                }
235            }
236        }
237
238        // 6. Proactive time-based replacement.
239        //    Compare samples-since-activation (not total lifetime) to prevent
240        //    cascading swaps when a shadow with high sample count is promoted.
241        if let Some(max_samples) = self.max_tree_samples {
242            let active_age = self
243                .active
244                .n_samples_seen()
245                .saturating_sub(self.samples_at_activation);
246
247            // In graduated mode, wait until 120% of max_samples for soft replacement.
248            let threshold = if self.shadow_warmup > 0 {
249                (max_samples as f64 * 1.2) as u64
250            } else {
251                max_samples
252            };
253
254            if active_age >= threshold {
255                self.active = self
256                    .alternate
257                    .take()
258                    .unwrap_or_else(|| HoeffdingTree::new(self.tree_config.clone()));
259                // Record activation point to prevent cascading swaps.
260                self.samples_at_activation = self.active.n_samples_seen();
261                self.detector = self.detector.clone_fresh();
262                self.replacements += 1;
263                self.pred_count = 0;
264                self.pred_mean = 0.0;
265                self.pred_m2 = 0.0;
266                // In graduated mode, immediately spawn a new shadow.
267                if self.shadow_warmup > 0 {
268                    self.alternate = Some(HoeffdingTree::new(self.tree_config.clone()));
269                }
270            }
271        }
272
273        // 7. In graduated mode, ensure shadow always exists.
274        if self.shadow_warmup > 0 && self.alternate.is_none() {
275            self.alternate = Some(HoeffdingTree::new(self.tree_config.clone()));
276        }
277
278        prediction
279    }
280
281    /// Predict without training.
282    ///
283    /// Routes the feature vector through the active tree and returns the
284    /// leaf value. Does not update any state.
285    #[inline]
286    pub fn predict(&self, features: &[f64]) -> f64 {
287        self.active.predict(features)
288    }
289
290    /// Predict with variance for confidence estimation.
291    ///
292    /// Returns `(leaf_value, variance)` where variance = 1 / (H_sum + lambda).
293    #[inline]
294    pub fn predict_with_variance(&self, features: &[f64]) -> (f64, f64) {
295        self.active.predict_with_variance(features)
296    }
297
298    /// Predict using sigmoid-blended soft routing for smooth interpolation.
299    ///
300    /// See [`crate::tree::hoeffding::HoeffdingTree::predict_smooth`] for details.
301    #[inline]
302    pub fn predict_smooth(&self, features: &[f64], bandwidth: f64) -> f64 {
303        self.active.predict_smooth(features, bandwidth)
304    }
305
306    /// Predict using per-feature auto-calibrated bandwidths.
307    #[inline]
308    pub fn predict_smooth_auto(&self, features: &[f64], bandwidths: &[f64]) -> f64 {
309        self.active.predict_smooth_auto(features, bandwidths)
310    }
311
312    /// Predict with parent-leaf linear interpolation.
313    #[inline]
314    pub fn predict_interpolated(&self, features: &[f64]) -> f64 {
315        self.active.predict_interpolated(features)
316    }
317
318    /// Predict with sibling-based interpolation for feature-continuous predictions.
319    #[inline]
320    pub fn predict_sibling_interpolated(&self, features: &[f64], bandwidths: &[f64]) -> f64 {
321        self.active
322            .predict_sibling_interpolated(features, bandwidths)
323    }
324
325    /// Predict using per-node auto-bandwidth soft routing.
326    #[inline]
327    pub fn predict_soft_routed(&self, features: &[f64]) -> f64 {
328        self.active.predict_soft_routed(features)
329    }
330
331    /// Predict with graduated active-shadow blending.
332    ///
333    /// When `shadow_warmup > 0`, blends the active tree's prediction with the
334    /// shadow's prediction based on relative maturity:
335    /// - Active weight decays from 1.0 to 0.0 as it ages from 80% to 120% of `max_tree_samples`
336    /// - Shadow weight ramps from 0.0 to 1.0 over `shadow_warmup` samples after warmup
337    ///
338    /// When `shadow_warmup == 0` or no shadow exists, returns the active prediction.
339    pub fn predict_graduated(&self, features: &[f64]) -> f64 {
340        let active_pred = self.active.predict(features);
341
342        if self.shadow_warmup == 0 {
343            return active_pred;
344        }
345
346        let Some(ref shadow) = self.alternate else {
347            return active_pred;
348        };
349
350        let shadow_samples = shadow.n_samples_seen();
351        if shadow_samples < self.shadow_warmup as u64 {
352            return active_pred;
353        }
354
355        let shadow_pred = shadow.predict(features);
356        self.blend_active_shadow(active_pred, shadow_pred, shadow_samples)
357    }
358
359    /// Predict with graduated blending + sibling interpolation (premium path).
360    ///
361    /// Combines graduated active-shadow handoff with feature-continuous sibling
362    /// interpolation for the smoothest possible prediction surface.
363    pub fn predict_graduated_sibling_interpolated(
364        &self,
365        features: &[f64],
366        bandwidths: &[f64],
367    ) -> f64 {
368        let active_pred = self
369            .active
370            .predict_sibling_interpolated(features, bandwidths);
371
372        if self.shadow_warmup == 0 {
373            return active_pred;
374        }
375
376        let Some(ref shadow) = self.alternate else {
377            return active_pred;
378        };
379
380        let shadow_samples = shadow.n_samples_seen();
381        if shadow_samples < self.shadow_warmup as u64 {
382            return active_pred;
383        }
384
385        let shadow_pred = shadow.predict_sibling_interpolated(features, bandwidths);
386        self.blend_active_shadow(active_pred, shadow_pred, shadow_samples)
387    }
388
    /// Compute the graduated blend of active and shadow predictions.
    ///
    /// Returns a weight-normalized combination of the two predictions based on
    /// tree maturity. Only invoked from the `predict_graduated*` paths, which
    /// all early-return when `shadow_warmup == 0`, so the division by
    /// `shadow_warmup` below cannot divide by zero.
    #[inline]
    fn blend_active_shadow(&self, active_pred: f64, shadow_pred: f64, shadow_samples: u64) -> f64 {
        // Age since this tree became active — not its lifetime sample count:
        // a promoted shadow already carries samples from its warmup period.
        let active_age = self
            .active
            .n_samples_seen()
            .saturating_sub(self.samples_at_activation);
        // With max_tree_samples disabled, mts = u64::MAX, so the decay
        // threshold is never reached and active_w stays 1.0.
        let mts = self.max_tree_samples.unwrap_or(u64::MAX) as f64;

        // Active weight: 1.0 until 80% of mts, then linear decay to 0.0 at 120%
        let active_w = if (active_age as f64) < mts * 0.8 {
            1.0
        } else {
            let progress = (active_age as f64 - mts * 0.8) / (mts * 0.4);
            (1.0 - progress).clamp(0.0, 1.0)
        };

        // Shadow weight: 0.0 until shadow_warmup, then ramp to 1.0 over shadow_warmup samples
        let shadow_w = ((shadow_samples as f64 - self.shadow_warmup as f64)
            / self.shadow_warmup as f64)
            .clamp(0.0, 1.0);

        // Normalize the weights into a proper weighted average. If both have
        // vanished (fully decayed active, shadow exactly at warmup), hand over
        // to the shadow outright rather than dividing by ~0.
        let total = active_w + shadow_w;
        if total < 1e-10 {
            return shadow_pred;
        }

        (active_w * active_pred + shadow_w * shadow_pred) / total
    }
419
420    /// Dynamically update the max_tree_samples threshold.
421    ///
422    /// Used by adaptive_mts to modulate tree lifetime based on contribution variance.
423    #[inline]
424    pub fn set_max_tree_samples(&mut self, max: Option<u64>) {
425        self.max_tree_samples = max;
426    }
427
428    /// Shadow warmup configuration (0 = disabled).
429    #[inline]
430    pub fn shadow_warmup(&self) -> usize {
431        self.shadow_warmup
432    }
433
434    /// Total number of tree replacements (drift or time-based).
435    #[inline]
436    pub fn replacements(&self) -> u64 {
437        self.replacements
438    }
439
440    /// Running mean of predictions from the active tree.
441    #[inline]
442    pub fn prediction_mean(&self) -> f64 {
443        self.pred_mean
444    }
445
446    /// Running standard deviation of predictions from the active tree.
447    #[inline]
448    pub fn prediction_std(&self) -> f64 {
449        if self.pred_count < 2 {
450            0.0
451        } else {
452            crate::math::sqrt(self.pred_m2 / (self.pred_count - 1) as f64)
453        }
454    }
455
456    /// Number of leaves in the active tree.
457    #[inline]
458    pub fn n_leaves(&self) -> usize {
459        self.active.n_leaves()
460    }
461
462    /// Total samples the active tree has seen.
463    #[inline]
464    pub fn n_samples_seen(&self) -> u64 {
465        self.active.n_samples_seen()
466    }
467
468    /// Whether an alternate tree is currently being trained.
469    #[inline]
470    pub fn has_alternate(&self) -> bool {
471        self.alternate.is_some()
472    }
473
474    /// Accumulated split gains per feature from the active tree.
475    #[inline]
476    pub fn split_gains(&self) -> &[f64] {
477        self.active.split_gains()
478    }
479
480    /// Immutable access to the active tree.
481    #[inline]
482    pub fn active_tree(&self) -> &HoeffdingTree {
483        &self.active
484    }
485
486    /// Immutable access to the alternate tree (if one is being trained).
487    #[inline]
488    pub fn alternate_tree(&self) -> Option<&HoeffdingTree> {
489        self.alternate.as_ref()
490    }
491
492    /// Immutable access to the tree configuration.
493    #[inline]
494    pub fn tree_config(&self) -> &TreeConfig {
495        &self.tree_config
496    }
497
498    /// Immutable access to the drift detector.
499    #[inline]
500    pub fn detector(&self) -> &dyn DriftDetector {
501        &*self.detector
502    }
503
504    /// Mutable access to the drift detector.
505    #[inline]
506    pub fn detector_mut(&mut self) -> &mut dyn DriftDetector {
507        &mut *self.detector
508    }
509
510    /// Immutable access to the alternate drift detector (always `None` in
511    /// the current architecture -- the alternate tree shares the main detector).
512    /// Reserved for future use.
513    #[inline]
514    pub fn alt_detector(&self) -> Option<&dyn DriftDetector> {
515        // Currently there's no separate alt detector.
516        None
517    }
518
519    /// Mutable access to the alternate drift detector.
520    #[inline]
521    pub fn alt_detector_mut(&mut self) -> Option<&mut dyn DriftDetector> {
522        None
523    }
524
525    /// Reset to a completely fresh state: new tree, no alternate, reset detector.
526    pub fn reset(&mut self) {
527        self.active = HoeffdingTree::new(self.tree_config.clone());
528        self.alternate = if self.shadow_warmup > 0 {
529            Some(HoeffdingTree::new(self.tree_config.clone()))
530        } else {
531            None
532        };
533        self.detector = self.detector.clone_fresh();
534        self.replacements = 0;
535        self.pred_count = 0;
536        self.pred_mean = 0.0;
537        self.pred_m2 = 0.0;
538        self.samples_at_activation = 0;
539    }
540}
541
542// ---------------------------------------------------------------------------
543// Tests
544// ---------------------------------------------------------------------------
545
#[cfg(test)]
mod tests {
    use super::*;
    use crate::drift::pht::PageHinkleyTest;
    use alloc::boxed::Box;
    use alloc::format;

    /// Create a default TreeConfig for tests (small grace period for fast splits).
    fn test_tree_config() -> TreeConfig {
        TreeConfig::new()
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .lambda(1.0)
    }

    /// Create a default drift detector for tests (Page-Hinkley with defaults).
    fn test_detector() -> Box<dyn DriftDetector> {
        Box::new(PageHinkleyTest::new())
    }

    // -------------------------------------------------------------------
    // Test 1: TreeSlot::new creates a functional slot; predict returns 0.0.
    // -------------------------------------------------------------------
    #[test]
    fn new_slot_predicts_zero() {
        let slot = TreeSlot::new(test_tree_config(), test_detector(), None);

        // A fresh tree with no training data should predict 0.0.
        let pred = slot.predict(&[1.0, 2.0, 3.0]);
        assert!(
            pred.abs() < 1e-12,
            "fresh slot should predict ~0.0, got {}",
            pred,
        );
    }

    // -------------------------------------------------------------------
    // Test 2: train_and_predict returns a prediction and does not panic.
    // -------------------------------------------------------------------
    #[test]
    fn train_and_predict_returns_prediction() {
        let mut slot = TreeSlot::new(test_tree_config(), test_detector(), None);

        let features = [1.0, 2.0, 3.0];
        let pred = slot.train_and_predict(&features, -0.5, 1.0);

        // First prediction should be 0.0 (tree was empty before training).
        assert!(
            pred.abs() < 1e-12,
            "first prediction should be ~0.0, got {}",
            pred,
        );

        // After training, the tree should have updated, so a second predict
        // should be non-zero (gradient=-0.5 pushes leaf weight positive).
        let pred2 = slot.predict(&features);
        assert!(
            pred2.is_finite(),
            "prediction after training should be finite"
        );
    }

    // -------------------------------------------------------------------
    // Test 3: After many stable samples, no alternate tree is spawned.
    // -------------------------------------------------------------------
    #[test]
    fn stable_stream_no_alternate() {
        let mut slot = TreeSlot::new(test_tree_config(), test_detector(), None);
        let features = [1.0, 2.0, 3.0];

        // Feed many stable samples with small, consistent gradients.
        // With a constant error of 0.1, the PHT running mean settles and
        // no warning/drift should trigger.
        for _ in 0..500 {
            slot.train_and_predict(&features, -0.1, 1.0);
        }

        assert!(
            !slot.has_alternate(),
            "stable error stream should not spawn an alternate tree",
        );
    }

    // -------------------------------------------------------------------
    // Test 4: Reset returns to fresh state.
    // -------------------------------------------------------------------
    #[test]
    fn reset_returns_to_fresh_state() {
        let mut slot = TreeSlot::new(test_tree_config(), test_detector(), None);
        let features = [1.0, 2.0, 3.0];

        // Train several samples.
        for _ in 0..100 {
            slot.train_and_predict(&features, -0.5, 1.0);
        }

        assert!(slot.n_samples_seen() > 0, "should have trained samples");

        slot.reset();

        assert_eq!(
            slot.n_leaves(),
            1,
            "after reset, should have exactly 1 leaf"
        );
        assert_eq!(
            slot.n_samples_seen(),
            0,
            "after reset, samples_seen should be 0"
        );
        assert!(
            !slot.has_alternate(),
            "after reset, no alternate should exist"
        );

        // Predict should return 0.0 again.
        let pred = slot.predict(&features);
        assert!(
            pred.abs() < 1e-12,
            "prediction after reset should be ~0.0, got {}",
            pred,
        );
    }

    // -------------------------------------------------------------------
    // Test 5: Predict without training works.
    // -------------------------------------------------------------------
    #[test]
    fn predict_without_training() {
        let slot = TreeSlot::new(test_tree_config(), test_detector(), None);

        // Multiple predict calls on a fresh slot should all return 0.0.
        for i in 0..10 {
            let x = (i as f64) * 0.5;
            let pred = slot.predict(&[x, x + 1.0]);
            assert!(
                pred.abs() < 1e-12,
                "untrained slot should predict ~0.0 for any input, got {} at i={}",
                pred,
                i,
            );
        }
    }

    // -------------------------------------------------------------------
    // Test 6: Drift replaces the active tree.
    // -------------------------------------------------------------------
    #[test]
    fn drift_replaces_active_tree() {
        // Use a very sensitive detector: small lambda triggers drift quickly.
        let sensitive_detector = Box::new(PageHinkleyTest::with_params(0.005, 5.0));
        let mut slot = TreeSlot::new(test_tree_config(), sensitive_detector, None);
        let features = [1.0, 2.0, 3.0];

        // Phase 1: stable training with small gradients.
        for _ in 0..200 {
            slot.train_and_predict(&features, -0.01, 1.0);
        }
        let samples_before_drift = slot.n_samples_seen();

        // Phase 2: abrupt shift in gradient magnitude to trigger drift.
        let mut drift_occurred = false;
        for _ in 0..500 {
            slot.train_and_predict(&features, -50.0, 1.0);
            // If drift occurred, the tree was replaced and samples_seen resets
            // (new tree starts from 0). A drop in the counter is the
            // observable side effect of a replacement.
            if slot.n_samples_seen() < samples_before_drift {
                drift_occurred = true;
                break;
            }
        }

        assert!(
            drift_occurred,
            "abrupt gradient shift should trigger drift and replace the active tree",
        );
    }

    // -------------------------------------------------------------------
    // Test 7: n_leaves reflects the active tree.
    // -------------------------------------------------------------------
    #[test]
    fn n_leaves_reflects_active_tree() {
        let slot = TreeSlot::new(test_tree_config(), test_detector(), None);
        assert_eq!(slot.n_leaves(), 1, "fresh slot should have exactly 1 leaf",);
    }

    // -------------------------------------------------------------------
    // Test 8: Debug formatting works.
    // -------------------------------------------------------------------
    #[test]
    fn debug_format_does_not_panic() {
        let slot = TreeSlot::new(test_tree_config(), test_detector(), None);
        let debug_str = format!("{:?}", slot);
        assert!(
            debug_str.contains("TreeSlot"),
            "debug output should contain 'TreeSlot'",
        );
    }

    // -------------------------------------------------------------------
    // Test 9: Time-based replacement triggers after max_tree_samples.
    // -------------------------------------------------------------------
    #[test]
    fn time_based_replacement_triggers() {
        let mut slot = TreeSlot::new(test_tree_config(), test_detector(), Some(200));
        let features = [1.0, 2.0, 3.0];

        // Train up to the limit.
        for _ in 0..200 {
            slot.train_and_predict(&features, -0.1, 1.0);
        }

        // At exactly 200 samples, the tree should have been replaced.
        // The new tree has 0 samples seen (or the most recently trained sample).
        assert!(
            slot.n_samples_seen() < 200,
            "after 200 samples with max_tree_samples=200, tree should be replaced (got {} samples)",
            slot.n_samples_seen(),
        );
    }

    // -------------------------------------------------------------------
    // Test 10: Time-based replacement disabled (None) never triggers.
    // -------------------------------------------------------------------
    #[test]
    fn time_based_replacement_disabled() {
        let mut slot = TreeSlot::new(test_tree_config(), test_detector(), None);
        let features = [1.0, 2.0, 3.0];

        for _ in 0..500 {
            slot.train_and_predict(&features, -0.1, 1.0);
        }

        assert_eq!(
            slot.n_samples_seen(),
            500,
            "without max_tree_samples, tree should never be proactively replaced",
        );
    }

    // -------------------------------------------------------------------
    // Graduated handoff tests
    // -------------------------------------------------------------------

    #[test]
    fn graduated_shadow_spawns_immediately() {
        let slot = TreeSlot::with_shadow_warmup(test_tree_config(), test_detector(), Some(200), 50);

        assert!(
            slot.has_alternate(),
            "graduated mode should spawn shadow immediately"
        );
        assert_eq!(slot.shadow_warmup(), 50);
    }

    #[test]
    fn graduated_predict_returns_finite() {
        let mut slot =
            TreeSlot::with_shadow_warmup(test_tree_config(), test_detector(), Some(200), 50);
        let features = [1.0, 2.0, 3.0];

        for _ in 0..100 {
            slot.train_and_predict(&features, -0.1, 1.0);
        }

        let pred = slot.predict_graduated(&features);
        assert!(
            pred.is_finite(),
            "graduated prediction should be finite: {}",
            pred
        );
    }

    #[test]
    fn graduated_shadow_always_respawns() {
        let mut slot =
            TreeSlot::with_shadow_warmup(test_tree_config(), test_detector(), Some(100), 30);
        let features = [1.0, 2.0, 3.0];

        // Train past the 120% soft replacement threshold (120 samples)
        for _ in 0..130 {
            slot.train_and_predict(&features, -0.1, 1.0);
        }

        // After soft replacement, shadow should still exist
        assert!(
            slot.has_alternate(),
            "shadow should be respawned after soft replacement"
        );
    }

    #[test]
    fn graduated_blending_produces_intermediate_values() {
        let mut slot =
            TreeSlot::with_shadow_warmup(test_tree_config(), test_detector(), Some(200), 50);
        let features = [1.0, 2.0, 3.0];

        // Train enough samples for shadow to be warm and blending to be active
        // (past 80% of max_tree_samples = 160 samples, shadow needs 50 warmup)
        for _ in 0..180 {
            slot.train_and_predict(&features, -0.1, 1.0);
        }

        let active_pred = slot.predict(&features);
        let graduated_pred = slot.predict_graduated(&features);

        // Both should be finite
        assert!(active_pred.is_finite());
        assert!(graduated_pred.is_finite());
    }

    #[test]
    fn graduated_reset_preserves_shadow() {
        let mut slot =
            TreeSlot::with_shadow_warmup(test_tree_config(), test_detector(), Some(200), 50);

        slot.reset();

        assert!(
            slot.has_alternate(),
            "reset in graduated mode should preserve shadow spawning"
        );
    }

    #[test]
    fn graduated_no_cascading_swap() {
        let mut slot =
            TreeSlot::with_shadow_warmup(test_tree_config(), test_detector(), Some(200), 50);

        // Train past the 120% soft replacement threshold (240 samples).
        // Use varying features so trees can actually split.
        for i in 0..250 {
            let x = (i as f64) * 0.1;
            let features = [x, x.sin(), x.cos()];
            let gradient = -0.1 * (1.0 + x.sin());
            slot.train_and_predict(&features, gradient, 1.0);
        }

        let replacements_after_first_swap = slot.replacements();
        assert!(
            replacements_after_first_swap >= 1,
            "should have swapped at least once after 250 samples with mts=200"
        );

        // Train 50 more samples — should NOT trigger another swap
        for i in 250..300 {
            let x = (i as f64) * 0.1;
            let features = [x, x.sin(), x.cos()];
            slot.train_and_predict(&features, -0.1, 1.0);
        }

        assert_eq!(
            slot.replacements(),
            replacements_after_first_swap,
            "should not cascade-swap immediately after promotion"
        );
    }

    #[test]
    fn graduated_without_max_tree_samples_still_works() {
        // shadow_warmup enabled but no max_tree_samples — active never decays
        let mut slot = TreeSlot::with_shadow_warmup(test_tree_config(), test_detector(), None, 50);
        let features = [1.0, 2.0, 3.0];

        for _ in 0..100 {
            slot.train_and_predict(&features, -0.1, 1.0);
        }

        // predict_graduated should work (active_w stays 1.0 since mts = MAX)
        let pred = slot.predict_graduated(&features);
        assert!(pred.is_finite(), "graduated without mts should be finite");
    }
}