Skip to main content

irithyll_core/ensemble/
replacement.rs

1//! TreeSlot: warning/danger/swap lifecycle for tree replacement.
2//!
3//! Implements the drift-triggered tree replacement strategy from
4//! Gunasekara et al. (2024). Each boosting step owns a "slot" that manages:
5//!
6//! - An **active** tree serving predictions and receiving training samples.
7//! - An optional **alternate** tree that begins training when a Warning signal
8//!   is emitted by the drift detector.
9//! - A **drift detector** monitoring prediction error magnitude.
10//!
11//! # Lifecycle
12//!
13//! ```text
14//!   Stable  --> keep training active tree
15//!   Warning --> spawn alternate tree (if not already training)
16//!   Drift   --> replace active with alternate (or fresh tree), reset detector
17//! ```
18
19use alloc::boxed::Box;
20use core::fmt;
21
22use crate::drift::{DriftDetector, DriftSignal};
23use crate::tree::builder::TreeConfig;
24use crate::tree::hoeffding::HoeffdingTree;
25use crate::tree::StreamingTree;
26
/// Manages the lifecycle of a single tree in the ensemble.
///
/// When the drift detector signals [`DriftSignal::Warning`], an alternate tree
/// begins training alongside the active tree. When [`DriftSignal::Drift`] is
/// confirmed, the alternate replaces the active tree (or a fresh tree is
/// created if no alternate exists). The drift detector is then reset via
/// [`clone_fresh`](DriftDetector::clone_fresh) to monitor the new tree.
pub struct TreeSlot {
    /// The currently active tree serving predictions.
    active: HoeffdingTree,
    /// Optional alternate tree being trained during a warning period.
    /// In graduated mode (`shadow_warmup > 0`) this same field holds the
    /// always-on shadow tree used by `predict_graduated()`.
    alternate: Option<HoeffdingTree>,
    /// Drift detector monitoring this slot's error stream (fed |gradient|).
    detector: Box<dyn DriftDetector>,
    /// Configuration for creating new trees (shared across replacements).
    tree_config: TreeConfig,
    /// Maximum samples before proactive replacement. `None` = disabled.
    max_tree_samples: Option<u64>,
    /// Total number of tree replacements (drift or time-based).
    replacements: u64,
    /// Welford online count for prediction statistics.
    pred_count: u64,
    /// Welford online mean of predictions.
    pred_mean: f64,
    /// Welford online M2 accumulator for prediction variance.
    pred_m2: f64,
    /// Shadow warmup samples (0 = disabled). When > 0, an always-on shadow
    /// tree is spawned and trained alongside the active tree.
    shadow_warmup: usize,
    /// Sample count of the active tree when it was activated (promoted from shadow).
    /// Used to compute samples-since-activation for time-based replacement,
    /// preventing cascading swaps when a shadow is promoted with a high sample count.
    samples_at_activation: u64,
}
61
62impl fmt::Debug for TreeSlot {
63    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
64        f.debug_struct("TreeSlot")
65            .field("active_leaves", &self.active.n_leaves())
66            .field("active_samples", &self.active.n_samples_seen())
67            .field("has_alternate", &self.alternate.is_some())
68            .field("tree_config", &self.tree_config)
69            .finish()
70    }
71}
72
73impl Clone for TreeSlot {
74    fn clone(&self) -> Self {
75        Self {
76            active: self.active.clone(),
77            alternate: self.alternate.clone(),
78            detector: self.detector.clone_boxed(),
79            tree_config: self.tree_config.clone(),
80            max_tree_samples: self.max_tree_samples,
81            replacements: self.replacements,
82            pred_count: self.pred_count,
83            pred_mean: self.pred_mean,
84            pred_m2: self.pred_m2,
85            shadow_warmup: self.shadow_warmup,
86            samples_at_activation: self.samples_at_activation,
87        }
88    }
89}
90
91impl TreeSlot {
    /// Create a new `TreeSlot` with a fresh tree and drift detector.
    ///
    /// The active tree starts as a single-leaf tree (prediction = 0.0).
    /// No alternate tree is created until a Warning signal is received.
    /// Equivalent to [`Self::with_shadow_warmup`] with `shadow_warmup = 0`,
    /// i.e. graduated handoff disabled.
    pub fn new(
        tree_config: TreeConfig,
        detector: Box<dyn DriftDetector>,
        max_tree_samples: Option<u64>,
    ) -> Self {
        Self::with_shadow_warmup(tree_config, detector, max_tree_samples, 0)
    }
103
104    /// Create a new `TreeSlot` with graduated tree handoff enabled.
105    ///
106    /// When `shadow_warmup > 0`, an always-on shadow tree is spawned immediately
107    /// and trained alongside the active tree. After `shadow_warmup` samples, the
108    /// shadow begins contributing to `predict_graduated()` predictions.
109    pub fn with_shadow_warmup(
110        tree_config: TreeConfig,
111        detector: Box<dyn DriftDetector>,
112        max_tree_samples: Option<u64>,
113        shadow_warmup: usize,
114    ) -> Self {
115        let alternate = if shadow_warmup > 0 {
116            Some(HoeffdingTree::new(tree_config.clone()))
117        } else {
118            None
119        };
120        Self {
121            active: HoeffdingTree::new(tree_config.clone()),
122            alternate,
123            detector,
124            tree_config,
125            max_tree_samples,
126            replacements: 0,
127            pred_count: 0,
128            pred_mean: 0.0,
129            pred_m2: 0.0,
130            shadow_warmup,
131            samples_at_activation: 0,
132        }
133    }
134
    /// Reconstruct a `TreeSlot` from pre-built trees and a fresh drift detector.
    ///
    /// Used during model deserialization to restore tree state without replaying
    /// the training stream. Replacement counters and Welford prediction
    /// statistics restart from zero, and graduated handoff is disabled
    /// (`shadow_warmup = 0`).
    pub fn from_trees(
        active: HoeffdingTree,
        alternate: Option<HoeffdingTree>,
        tree_config: TreeConfig,
        detector: Box<dyn DriftDetector>,
        max_tree_samples: Option<u64>,
    ) -> Self {
        Self {
            active,
            alternate,
            detector,
            tree_config,
            max_tree_samples,
            replacements: 0,
            pred_count: 0,
            pred_mean: 0.0,
            pred_m2: 0.0,
            shadow_warmup: 0,
            // NOTE(review): starting from 0 makes the restored tree's entire
            // lifetime count toward the time-based replacement check, so a
            // mature restored tree may be swapped on the first training call
            // when `max_tree_samples` is set — confirm this is intended.
            samples_at_activation: 0,
        }
    }
160
161    /// Train the active tree (and alternate if it exists) on a single sample.
162    ///
163    /// The absolute value of the gradient is fed to the drift detector as an
164    /// error proxy (gradient = derivative of loss = prediction error signal).
165    ///
166    /// # Returns
167    ///
168    /// The prediction from the **active** tree **before** training on this sample.
169    /// This ensures the prediction reflects only previously seen data, which is
170    /// critical for unbiased gradient computation in the boosting loop.
171    ///
172    /// # Drift handling
173    ///
174    /// - [`DriftSignal::Stable`]: no action.
175    /// - [`DriftSignal::Warning`]: spawn an alternate tree if one is not already
176    ///   being trained. The alternate receives the same training sample.
177    /// - [`DriftSignal::Drift`]: replace the active tree with the alternate
178    ///   (or a fresh tree if no alternate exists). The drift detector is reset
179    ///   via [`clone_fresh`](DriftDetector::clone_fresh) so it monitors the
180    ///   new tree from a clean state.
181    pub fn train_and_predict(&mut self, features: &[f64], gradient: f64, hessian: f64) -> f64 {
182        // 1. Predict from active tree BEFORE training.
183        let prediction = self.active.predict(features);
184
185        // 1b. Update Welford running prediction statistics.
186        self.pred_count += 1;
187        let delta = prediction - self.pred_mean;
188        self.pred_mean += delta / self.pred_count as f64;
189        let delta2 = prediction - self.pred_mean;
190        self.pred_m2 += delta * delta2;
191
192        // 2. Train the active tree.
193        self.active.train_one(features, gradient, hessian);
194
195        // 3. Train the alternate tree if it exists.
196        if let Some(ref mut alt) = self.alternate {
197            alt.train_one(features, gradient, hessian);
198        }
199
200        // 4. Feed error magnitude to the drift detector.
201        //    |gradient| is a proxy for prediction error: for squared loss,
202        //    gradient = (prediction - target), so |gradient| = |error|.
203        let error = crate::math::abs(gradient);
204        let signal = self.detector.update(error);
205
206        // 5. React to the drift signal.
207        match signal {
208            DriftSignal::Stable => {}
209            DriftSignal::Warning => {
210                // Start training an alternate tree if not already doing so.
211                if self.alternate.is_none() {
212                    self.alternate = Some(HoeffdingTree::new(self.tree_config.clone()));
213                }
214            }
215            DriftSignal::Drift => {
216                // Replace active tree: prefer the alternate (which has been
217                // training on recent data), fall back to a fresh tree.
218                self.active = self
219                    .alternate
220                    .take()
221                    .unwrap_or_else(|| HoeffdingTree::new(self.tree_config.clone()));
222                // Record activation point to prevent cascading swaps.
223                self.samples_at_activation = self.active.n_samples_seen();
224                // Reset the drift detector to monitor the new tree cleanly.
225                self.detector = self.detector.clone_fresh();
226                // Track replacement and reset prediction stats for the new tree.
227                self.replacements += 1;
228                self.pred_count = 0;
229                self.pred_mean = 0.0;
230                self.pred_m2 = 0.0;
231                // In graduated mode, immediately spawn a new shadow.
232                if self.shadow_warmup > 0 {
233                    self.alternate = Some(HoeffdingTree::new(self.tree_config.clone()));
234                }
235            }
236        }
237
238        // 6. Proactive time-based replacement.
239        //    Compare samples-since-activation (not total lifetime) to prevent
240        //    cascading swaps when a shadow with high sample count is promoted.
241        if let Some(max_samples) = self.max_tree_samples {
242            let active_age = self
243                .active
244                .n_samples_seen()
245                .saturating_sub(self.samples_at_activation);
246
247            // In graduated mode, wait until 120% of max_samples for soft replacement.
248            let threshold = if self.shadow_warmup > 0 {
249                (max_samples as f64 * 1.2) as u64
250            } else {
251                max_samples
252            };
253
254            if active_age >= threshold {
255                self.active = self
256                    .alternate
257                    .take()
258                    .unwrap_or_else(|| HoeffdingTree::new(self.tree_config.clone()));
259                // Record activation point to prevent cascading swaps.
260                self.samples_at_activation = self.active.n_samples_seen();
261                self.detector = self.detector.clone_fresh();
262                self.replacements += 1;
263                self.pred_count = 0;
264                self.pred_mean = 0.0;
265                self.pred_m2 = 0.0;
266                // In graduated mode, immediately spawn a new shadow.
267                if self.shadow_warmup > 0 {
268                    self.alternate = Some(HoeffdingTree::new(self.tree_config.clone()));
269                }
270            }
271        }
272
273        // 7. In graduated mode, ensure shadow always exists.
274        if self.shadow_warmup > 0 && self.alternate.is_none() {
275            self.alternate = Some(HoeffdingTree::new(self.tree_config.clone()));
276        }
277
278        prediction
279    }
280
    // --- Prediction delegates: all of these route through the ACTIVE tree
    // --- only; the alternate/shadow never serves these paths.

    /// Predict without training.
    ///
    /// Routes the feature vector through the active tree and returns the
    /// leaf value. Does not update any state.
    #[inline]
    pub fn predict(&self, features: &[f64]) -> f64 {
        self.active.predict(features)
    }

    /// Predict with variance for confidence estimation.
    ///
    /// Returns `(leaf_value, variance)` where variance = 1 / (H_sum + lambda).
    #[inline]
    pub fn predict_with_variance(&self, features: &[f64]) -> (f64, f64) {
        self.active.predict_with_variance(features)
    }

    /// Predict using sigmoid-blended soft routing for smooth interpolation.
    ///
    /// See [`crate::tree::hoeffding::HoeffdingTree::predict_smooth`] for details.
    #[inline]
    pub fn predict_smooth(&self, features: &[f64], bandwidth: f64) -> f64 {
        self.active.predict_smooth(features, bandwidth)
    }

    /// Predict using per-feature auto-calibrated bandwidths.
    #[inline]
    pub fn predict_smooth_auto(&self, features: &[f64], bandwidths: &[f64]) -> f64 {
        self.active.predict_smooth_auto(features, bandwidths)
    }

    /// Predict with parent-leaf linear interpolation.
    #[inline]
    pub fn predict_interpolated(&self, features: &[f64]) -> f64 {
        self.active.predict_interpolated(features)
    }

    /// Predict with sibling-based interpolation for feature-continuous predictions.
    #[inline]
    pub fn predict_sibling_interpolated(&self, features: &[f64], bandwidths: &[f64]) -> f64 {
        self.active
            .predict_sibling_interpolated(features, bandwidths)
    }
324
325    /// Predict with graduated active-shadow blending.
326    ///
327    /// When `shadow_warmup > 0`, blends the active tree's prediction with the
328    /// shadow's prediction based on relative maturity:
329    /// - Active weight decays from 1.0 to 0.0 as it ages from 80% to 120% of `max_tree_samples`
330    /// - Shadow weight ramps from 0.0 to 1.0 over `shadow_warmup` samples after warmup
331    ///
332    /// When `shadow_warmup == 0` or no shadow exists, returns the active prediction.
333    pub fn predict_graduated(&self, features: &[f64]) -> f64 {
334        let active_pred = self.active.predict(features);
335
336        if self.shadow_warmup == 0 {
337            return active_pred;
338        }
339
340        let Some(ref shadow) = self.alternate else {
341            return active_pred;
342        };
343
344        let shadow_samples = shadow.n_samples_seen();
345        if shadow_samples < self.shadow_warmup as u64 {
346            return active_pred;
347        }
348
349        let shadow_pred = shadow.predict(features);
350        self.blend_active_shadow(active_pred, shadow_pred, shadow_samples)
351    }
352
353    /// Predict with graduated blending + sibling interpolation (premium path).
354    ///
355    /// Combines graduated active-shadow handoff with feature-continuous sibling
356    /// interpolation for the smoothest possible prediction surface.
357    pub fn predict_graduated_sibling_interpolated(
358        &self,
359        features: &[f64],
360        bandwidths: &[f64],
361    ) -> f64 {
362        let active_pred = self
363            .active
364            .predict_sibling_interpolated(features, bandwidths);
365
366        if self.shadow_warmup == 0 {
367            return active_pred;
368        }
369
370        let Some(ref shadow) = self.alternate else {
371            return active_pred;
372        };
373
374        let shadow_samples = shadow.n_samples_seen();
375        if shadow_samples < self.shadow_warmup as u64 {
376            return active_pred;
377        }
378
379        let shadow_pred = shadow.predict_sibling_interpolated(features, bandwidths);
380        self.blend_active_shadow(active_pred, shadow_pred, shadow_samples)
381    }
382
383    /// Compute the graduated blend of active and shadow predictions.
384    #[inline]
385    fn blend_active_shadow(&self, active_pred: f64, shadow_pred: f64, shadow_samples: u64) -> f64 {
386        let active_age = self
387            .active
388            .n_samples_seen()
389            .saturating_sub(self.samples_at_activation);
390        let mts = self.max_tree_samples.unwrap_or(u64::MAX) as f64;
391
392        // Active weight: 1.0 until 80% of mts, then linear decay to 0.0 at 120%
393        let active_w = if (active_age as f64) < mts * 0.8 {
394            1.0
395        } else {
396            let progress = (active_age as f64 - mts * 0.8) / (mts * 0.4);
397            (1.0 - progress).clamp(0.0, 1.0)
398        };
399
400        // Shadow weight: 0.0 until shadow_warmup, then ramp to 1.0 over shadow_warmup samples
401        let shadow_w = ((shadow_samples as f64 - self.shadow_warmup as f64)
402            / self.shadow_warmup as f64)
403            .clamp(0.0, 1.0);
404
405        // Normalize
406        let total = active_w + shadow_w;
407        if total < 1e-10 {
408            return shadow_pred;
409        }
410
411        (active_w * active_pred + shadow_w * shadow_pred) / total
412    }
413
    /// Shadow warmup configuration (0 = disabled).
    #[inline]
    pub fn shadow_warmup(&self) -> usize {
        self.shadow_warmup
    }

    /// Total number of tree replacements (drift or time-based).
    #[inline]
    pub fn replacements(&self) -> u64 {
        self.replacements
    }

    /// Running mean of predictions from the active tree.
    ///
    /// Reset to 0.0 whenever the active tree is replaced.
    #[inline]
    pub fn prediction_mean(&self) -> f64 {
        self.pred_mean
    }

    /// Running standard deviation of predictions from the active tree.
    ///
    /// Uses the Welford M2 accumulator with Bessel's correction (n - 1);
    /// returns 0.0 until at least two predictions have been observed.
    #[inline]
    pub fn prediction_std(&self) -> f64 {
        if self.pred_count < 2 {
            0.0
        } else {
            crate::math::sqrt(self.pred_m2 / (self.pred_count - 1) as f64)
        }
    }

    /// Number of leaves in the active tree.
    #[inline]
    pub fn n_leaves(&self) -> usize {
        self.active.n_leaves()
    }

    /// Total samples the active tree has seen.
    #[inline]
    pub fn n_samples_seen(&self) -> u64 {
        self.active.n_samples_seen()
    }

    /// Whether an alternate tree is currently being trained.
    #[inline]
    pub fn has_alternate(&self) -> bool {
        self.alternate.is_some()
    }

    /// Accumulated split gains per feature from the active tree.
    #[inline]
    pub fn split_gains(&self) -> &[f64] {
        self.active.split_gains()
    }

    /// Immutable access to the active tree.
    #[inline]
    pub fn active_tree(&self) -> &HoeffdingTree {
        &self.active
    }

    /// Immutable access to the alternate tree (if one is being trained).
    #[inline]
    pub fn alternate_tree(&self) -> Option<&HoeffdingTree> {
        self.alternate.as_ref()
    }

    /// Immutable access to the tree configuration.
    #[inline]
    pub fn tree_config(&self) -> &TreeConfig {
        &self.tree_config
    }

    /// Immutable access to the drift detector.
    #[inline]
    pub fn detector(&self) -> &dyn DriftDetector {
        &*self.detector
    }

    /// Mutable access to the drift detector.
    #[inline]
    pub fn detector_mut(&mut self) -> &mut dyn DriftDetector {
        &mut *self.detector
    }

    /// Immutable access to the alternate drift detector (always `None` in
    /// the current architecture -- the alternate tree shares the main detector).
    /// Reserved for future use.
    #[inline]
    pub fn alt_detector(&self) -> Option<&dyn DriftDetector> {
        // Currently there's no separate alt detector.
        None
    }

    /// Mutable access to the alternate drift detector.
    ///
    /// Always `None` today, mirroring [`Self::alt_detector`].
    #[inline]
    pub fn alt_detector_mut(&mut self) -> Option<&mut dyn DriftDetector> {
        None
    }
510
511    /// Reset to a completely fresh state: new tree, no alternate, reset detector.
512    pub fn reset(&mut self) {
513        self.active = HoeffdingTree::new(self.tree_config.clone());
514        self.alternate = if self.shadow_warmup > 0 {
515            Some(HoeffdingTree::new(self.tree_config.clone()))
516        } else {
517            None
518        };
519        self.detector = self.detector.clone_fresh();
520        self.replacements = 0;
521        self.pred_count = 0;
522        self.pred_mean = 0.0;
523        self.pred_m2 = 0.0;
524        self.samples_at_activation = 0;
525    }
526}
527
528// ---------------------------------------------------------------------------
529// Tests
530// ---------------------------------------------------------------------------
531
532#[cfg(test)]
533mod tests {
534    use super::*;
535    use crate::drift::pht::PageHinkleyTest;
536    use alloc::boxed::Box;
537    use alloc::format;
538
    /// Create a default TreeConfig for tests (small grace period for fast splits).
    fn test_tree_config() -> TreeConfig {
        TreeConfig::new()
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .lambda(1.0)
    }

    /// Create a default drift detector for tests (Page-Hinkley with default
    /// parameters — insensitive enough not to fire on a steady error stream).
    fn test_detector() -> Box<dyn DriftDetector> {
        Box::new(PageHinkleyTest::new())
    }

    // -------------------------------------------------------------------
    // Test 1: TreeSlot::new creates a functional slot; predict returns 0.0.
    // -------------------------------------------------------------------
    #[test]
    fn new_slot_predicts_zero() {
        let slot = TreeSlot::new(test_tree_config(), test_detector(), None);

        // A fresh tree with no training data should predict 0.0.
        let pred = slot.predict(&[1.0, 2.0, 3.0]);
        assert!(
            pred.abs() < 1e-12,
            "fresh slot should predict ~0.0, got {}",
            pred,
        );
    }
568
    // -------------------------------------------------------------------
    // Test 2: train_and_predict returns a prediction and does not panic.
    // -------------------------------------------------------------------
    #[test]
    fn train_and_predict_returns_prediction() {
        let mut slot = TreeSlot::new(test_tree_config(), test_detector(), None);

        let features = [1.0, 2.0, 3.0];
        let pred = slot.train_and_predict(&features, -0.5, 1.0);

        // First prediction should be 0.0 (tree was empty before training).
        assert!(
            pred.abs() < 1e-12,
            "first prediction should be ~0.0, got {}",
            pred,
        );

        // After training, the tree's state has changed; only finiteness is
        // asserted here since the exact leaf value after a single sample
        // depends on the tree's internal update schedule.
        let pred2 = slot.predict(&features);
        assert!(
            pred2.is_finite(),
            "prediction after training should be finite"
        );
    }

    // -------------------------------------------------------------------
    // Test 3: After many stable samples, no alternate tree is spawned.
    // -------------------------------------------------------------------
    #[test]
    fn stable_stream_no_alternate() {
        let mut slot = TreeSlot::new(test_tree_config(), test_detector(), None);
        let features = [1.0, 2.0, 3.0];

        // Feed many stable samples with small, consistent gradients.
        // With a constant error of 0.1, the PHT running mean settles and
        // no warning/drift should trigger.
        for _ in 0..500 {
            slot.train_and_predict(&features, -0.1, 1.0);
        }

        assert!(
            !slot.has_alternate(),
            "stable error stream should not spawn an alternate tree",
        );
    }

    // -------------------------------------------------------------------
    // Test 4: Reset returns to fresh state.
    // -------------------------------------------------------------------
    #[test]
    fn reset_returns_to_fresh_state() {
        let mut slot = TreeSlot::new(test_tree_config(), test_detector(), None);
        let features = [1.0, 2.0, 3.0];

        // Train several samples.
        for _ in 0..100 {
            slot.train_and_predict(&features, -0.5, 1.0);
        }

        assert!(slot.n_samples_seen() > 0, "should have trained samples");

        slot.reset();

        assert_eq!(
            slot.n_leaves(),
            1,
            "after reset, should have exactly 1 leaf"
        );
        assert_eq!(
            slot.n_samples_seen(),
            0,
            "after reset, samples_seen should be 0"
        );
        assert!(
            !slot.has_alternate(),
            "after reset, no alternate should exist"
        );

        // Predict should return 0.0 again.
        let pred = slot.predict(&features);
        assert!(
            pred.abs() < 1e-12,
            "prediction after reset should be ~0.0, got {}",
            pred,
        );
    }
656
    // -------------------------------------------------------------------
    // Test 5: Predict without training works.
    // -------------------------------------------------------------------
    #[test]
    fn predict_without_training() {
        let slot = TreeSlot::new(test_tree_config(), test_detector(), None);

        // Multiple predict calls on a fresh slot should all return 0.0.
        for i in 0..10 {
            let x = (i as f64) * 0.5;
            let pred = slot.predict(&[x, x + 1.0]);
            assert!(
                pred.abs() < 1e-12,
                "untrained slot should predict ~0.0 for any input, got {} at i={}",
                pred,
                i,
            );
        }
    }

    // -------------------------------------------------------------------
    // Test 6: Drift replaces the active tree.
    // -------------------------------------------------------------------
    #[test]
    fn drift_replaces_active_tree() {
        // Use a very sensitive detector so drift triggers quickly
        // (small PHT parameters relative to the defaults).
        let sensitive_detector = Box::new(PageHinkleyTest::with_params(0.005, 5.0));
        let mut slot = TreeSlot::new(test_tree_config(), sensitive_detector, None);
        let features = [1.0, 2.0, 3.0];

        // Phase 1: stable training with small gradients.
        for _ in 0..200 {
            slot.train_and_predict(&features, -0.01, 1.0);
        }
        let samples_before_drift = slot.n_samples_seen();

        // Phase 2: abrupt shift in gradient magnitude to trigger drift.
        let mut drift_occurred = false;
        for _ in 0..500 {
            slot.train_and_predict(&features, -50.0, 1.0);
            // If drift occurred, the tree was replaced and samples_seen resets
            // (new tree starts from 0).
            if slot.n_samples_seen() < samples_before_drift {
                drift_occurred = true;
                break;
            }
        }

        assert!(
            drift_occurred,
            "abrupt gradient shift should trigger drift and replace the active tree",
        );
    }

    // -------------------------------------------------------------------
    // Test 7: n_leaves reflects the active tree.
    // -------------------------------------------------------------------
    #[test]
    fn n_leaves_reflects_active_tree() {
        let slot = TreeSlot::new(test_tree_config(), test_detector(), None);
        assert_eq!(slot.n_leaves(), 1, "fresh slot should have exactly 1 leaf",);
    }

    // -------------------------------------------------------------------
    // Test 8: Debug formatting works.
    // -------------------------------------------------------------------
    #[test]
    fn debug_format_does_not_panic() {
        let slot = TreeSlot::new(test_tree_config(), test_detector(), None);
        let debug_str = format!("{:?}", slot);
        assert!(
            debug_str.contains("TreeSlot"),
            "debug output should contain 'TreeSlot'",
        );
    }

    // -------------------------------------------------------------------
    // Test 9: Time-based replacement triggers after max_tree_samples.
    // -------------------------------------------------------------------
    #[test]
    fn time_based_replacement_triggers() {
        let mut slot = TreeSlot::new(test_tree_config(), test_detector(), Some(200));
        let features = [1.0, 2.0, 3.0];

        // Train up to the limit.
        for _ in 0..200 {
            slot.train_and_predict(&features, -0.1, 1.0);
        }

        // At exactly 200 samples, the tree should have been replaced.
        // The new tree has 0 samples seen (or the most recently trained sample).
        assert!(
            slot.n_samples_seen() < 200,
            "after 200 samples with max_tree_samples=200, tree should be replaced (got {} samples)",
            slot.n_samples_seen(),
        );
    }

    // -------------------------------------------------------------------
    // Test 10: Time-based replacement disabled (None) never triggers.
    // -------------------------------------------------------------------
    #[test]
    fn time_based_replacement_disabled() {
        let mut slot = TreeSlot::new(test_tree_config(), test_detector(), None);
        let features = [1.0, 2.0, 3.0];

        for _ in 0..500 {
            slot.train_and_predict(&features, -0.1, 1.0);
        }

        assert_eq!(
            slot.n_samples_seen(),
            500,
            "without max_tree_samples, tree should never be proactively replaced",
        );
    }
773
    // -------------------------------------------------------------------
    // Graduated handoff tests
    // -------------------------------------------------------------------

    #[test]
    fn graduated_shadow_spawns_immediately() {
        let slot = TreeSlot::with_shadow_warmup(test_tree_config(), test_detector(), Some(200), 50);

        assert!(
            slot.has_alternate(),
            "graduated mode should spawn shadow immediately"
        );
        assert_eq!(slot.shadow_warmup(), 50);
    }

    #[test]
    fn graduated_predict_returns_finite() {
        let mut slot =
            TreeSlot::with_shadow_warmup(test_tree_config(), test_detector(), Some(200), 50);
        let features = [1.0, 2.0, 3.0];

        // 100 samples: shadow is past its 50-sample warmup, active is still
        // below 160 (80% of max_tree_samples = 200), so blending is reachable.
        for _ in 0..100 {
            slot.train_and_predict(&features, -0.1, 1.0);
        }

        let pred = slot.predict_graduated(&features);
        assert!(
            pred.is_finite(),
            "graduated prediction should be finite: {}",
            pred
        );
    }

    #[test]
    fn graduated_shadow_always_respawns() {
        let mut slot =
            TreeSlot::with_shadow_warmup(test_tree_config(), test_detector(), Some(100), 30);
        let features = [1.0, 2.0, 3.0];

        // Train past the 120% soft replacement threshold (120 samples)
        for _ in 0..130 {
            slot.train_and_predict(&features, -0.1, 1.0);
        }

        // After soft replacement, shadow should still exist
        assert!(
            slot.has_alternate(),
            "shadow should be respawned after soft replacement"
        );
    }

    #[test]
    fn graduated_blending_produces_intermediate_values() {
        let mut slot =
            TreeSlot::with_shadow_warmup(test_tree_config(), test_detector(), Some(200), 50);
        let features = [1.0, 2.0, 3.0];

        // Train enough samples for shadow to be warm and blending to be active
        // (past 80% of max_tree_samples = 160 samples, shadow needs 50 warmup)
        for _ in 0..180 {
            slot.train_and_predict(&features, -0.1, 1.0);
        }

        let active_pred = slot.predict(&features);
        let graduated_pred = slot.predict_graduated(&features);

        // Both should be finite
        assert!(active_pred.is_finite());
        assert!(graduated_pred.is_finite());
    }

    #[test]
    fn graduated_reset_preserves_shadow() {
        let mut slot =
            TreeSlot::with_shadow_warmup(test_tree_config(), test_detector(), Some(200), 50);

        slot.reset();

        assert!(
            slot.has_alternate(),
            "reset in graduated mode should preserve shadow spawning"
        );
    }

    #[test]
    fn graduated_no_cascading_swap() {
        let mut slot =
            TreeSlot::with_shadow_warmup(test_tree_config(), test_detector(), Some(200), 50);

        // Train past the 120% soft replacement threshold (240 samples).
        // Use varying features so trees can actually split.
        for i in 0..250 {
            let x = (i as f64) * 0.1;
            let features = [x, x.sin(), x.cos()];
            let gradient = -0.1 * (1.0 + x.sin());
            slot.train_and_predict(&features, gradient, 1.0);
        }

        let replacements_after_first_swap = slot.replacements();
        assert!(
            replacements_after_first_swap >= 1,
            "should have swapped at least once after 250 samples with mts=200"
        );

        // Train 50 more samples — should NOT trigger another swap
        for i in 250..300 {
            let x = (i as f64) * 0.1;
            let features = [x, x.sin(), x.cos()];
            slot.train_and_predict(&features, -0.1, 1.0);
        }

        assert_eq!(
            slot.replacements(),
            replacements_after_first_swap,
            "should not cascade-swap immediately after promotion"
        );
    }

    #[test]
    fn graduated_without_max_tree_samples_still_works() {
        // shadow_warmup enabled but no max_tree_samples — active never decays
        let mut slot = TreeSlot::with_shadow_warmup(test_tree_config(), test_detector(), None, 50);
        let features = [1.0, 2.0, 3.0];

        for _ in 0..100 {
            slot.train_and_predict(&features, -0.1, 1.0);
        }

        // predict_graduated should work (active_w stays 1.0 since mts = MAX)
        let pred = slot.predict_graduated(&features);
        assert!(pred.is_finite(), "graduated without mts should be finite");
    }
906}