// irithyll_core/ensemble/step.rs
//! Single boosting step: owns one tree + drift detector + optional alternate.
//!
//! [`BoostingStep`] is a thin wrapper around [`TreeSlot`]
//! that adds SGBT variant logic. The three variants from Gunasekara et al. (2024) are:
//!
//! - **Standard** (`train_count = 1`): each sample trains the tree exactly once.
//! - **Skip** (`train_count = 0`): the sample is skipped (only the prediction is returned).
//! - **Multiple Iterations** (`train_count > 1`): the sample trains the tree
//!   multiple times, weighted by the hessian.
//!
//! The variant logic is computed externally (by `SGBTVariant::train_count()` in the
//! ensemble orchestrator) and passed in as `train_count`. This keeps `BoostingStep`
//! focused on execution rather than policy.

15use alloc::boxed::Box;
16use core::fmt;
17
18use crate::drift::DriftDetector;
19use crate::ensemble::replacement::TreeSlot;
20use crate::tree::builder::TreeConfig;
21
22/// A single step in the SGBT boosting sequence.
23///
24/// Owns a [`TreeSlot`] and applies variant-aware training repetition. The number
25/// of training iterations per sample is determined by the caller (the ensemble
26/// orchestrator computes `train_count` from the configured SGBT variant).
27///
28/// # Prediction semantics
29///
30/// Both [`train_and_predict`](BoostingStep::train_and_predict) and
31/// [`predict`](BoostingStep::predict) return the active tree's prediction
32/// **before** any training on the current sample. This ensures unbiased
33/// gradient computation in the boosting loop.
34#[derive(Clone)]
35pub struct BoostingStep {
36    /// The tree slot managing the active tree, alternate, and drift detector.
37    slot: TreeSlot,
38}
39
40impl fmt::Debug for BoostingStep {
41    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42        f.debug_struct("BoostingStep")
43            .field("slot", &self.slot)
44            .finish()
45    }
46}
47
48impl BoostingStep {
49    /// Create a new boosting step with a fresh tree and drift detector.
50    pub fn new(tree_config: TreeConfig, detector: Box<dyn DriftDetector>) -> Self {
51        Self {
52            slot: TreeSlot::new(tree_config, detector, None),
53        }
54    }
55
56    /// Create a new boosting step with optional time-based tree replacement.
57    pub fn new_with_max_samples(
58        tree_config: TreeConfig,
59        detector: Box<dyn DriftDetector>,
60        max_tree_samples: Option<u64>,
61    ) -> Self {
62        Self {
63            slot: TreeSlot::new(tree_config, detector, max_tree_samples),
64        }
65    }
66
67    /// Create a new boosting step with graduated tree handoff.
68    pub fn new_with_graduated(
69        tree_config: TreeConfig,
70        detector: Box<dyn DriftDetector>,
71        max_tree_samples: Option<u64>,
72        shadow_warmup: usize,
73    ) -> Self {
74        Self {
75            slot: TreeSlot::with_shadow_warmup(
76                tree_config,
77                detector,
78                max_tree_samples,
79                shadow_warmup,
80            ),
81        }
82    }
83
84    /// Reconstruct a boosting step from a pre-built tree slot.
85    ///
86    /// Used during model deserialization.
87    pub fn from_slot(slot: TreeSlot) -> Self {
88        Self { slot }
89    }
90
91    /// Train on a single sample with variant-aware repetition.
92    ///
93    /// # Arguments
94    ///
95    /// * `features` - Input feature vector.
96    /// * `gradient` - Negative gradient of the loss at this sample.
97    /// * `hessian` - Second derivative (curvature) of the loss at this sample.
98    /// * `train_count` - Number of training iterations for this sample:
99    ///   - `0`: skip training entirely (SK variant or stochastic skip).
100    ///   - `1`: standard single-pass training.
101    ///   - `>1`: multiple iterations (MI variant).
102    ///
103    /// # Returns
104    ///
105    /// The prediction from the active tree **before** training.
106    pub fn train_and_predict(
107        &mut self,
108        features: &[f64],
109        gradient: f64,
110        hessian: f64,
111        train_count: usize,
112    ) -> f64 {
113        if train_count == 0 {
114            // Skip variant: no training, just predict.
115            return self.slot.predict(features);
116        }
117
118        // First iteration: train and get the pre-training prediction.
119        let pred = self.slot.train_and_predict(features, gradient, hessian);
120
121        // Additional iterations for MI variant.
122        // Each subsequent call still feeds the same gradient/hessian to the
123        // tree and drift detector, effectively weighting this sample more heavily.
124        for _ in 1..train_count {
125            self.slot.train_and_predict(features, gradient, hessian);
126        }
127
128        pred
129    }
130
131    /// Predict without training.
132    ///
133    /// Routes the feature vector through the active tree and returns the
134    /// leaf value. Does not update any state.
135    #[inline]
136    pub fn predict(&self, features: &[f64]) -> f64 {
137        self.slot.predict(features)
138    }
139
140    /// Predict with variance for confidence estimation.
141    ///
142    /// Returns `(leaf_value, variance)` where variance = 1 / (H_sum + lambda).
143    #[inline]
144    pub fn predict_with_variance(&self, features: &[f64]) -> (f64, f64) {
145        self.slot.predict_with_variance(features)
146    }
147
148    /// Predict using sigmoid-blended soft routing for smooth interpolation.
149    ///
150    /// See [`crate::tree::hoeffding::HoeffdingTree::predict_smooth`] for details.
151    #[inline]
152    pub fn predict_smooth(&self, features: &[f64], bandwidth: f64) -> f64 {
153        self.slot.predict_smooth(features, bandwidth)
154    }
155
156    /// Predict using per-feature auto-calibrated bandwidths.
157    #[inline]
158    pub fn predict_smooth_auto(&self, features: &[f64], bandwidths: &[f64]) -> f64 {
159        self.slot.predict_smooth_auto(features, bandwidths)
160    }
161
162    /// Predict with parent-leaf linear interpolation.
163    #[inline]
164    pub fn predict_interpolated(&self, features: &[f64]) -> f64 {
165        self.slot.predict_interpolated(features)
166    }
167
168    /// Predict with sibling-based interpolation for feature-continuous predictions.
169    #[inline]
170    pub fn predict_sibling_interpolated(&self, features: &[f64], bandwidths: &[f64]) -> f64 {
171        self.slot.predict_sibling_interpolated(features, bandwidths)
172    }
173
174    /// Predict with graduated active-shadow blending.
175    #[inline]
176    pub fn predict_graduated(&self, features: &[f64]) -> f64 {
177        self.slot.predict_graduated(features)
178    }
179
180    /// Predict with graduated blending + sibling interpolation.
181    #[inline]
182    pub fn predict_graduated_sibling_interpolated(
183        &self,
184        features: &[f64],
185        bandwidths: &[f64],
186    ) -> f64 {
187        self.slot
188            .predict_graduated_sibling_interpolated(features, bandwidths)
189    }
190
191    /// Number of leaves in the active tree.
192    #[inline]
193    pub fn n_leaves(&self) -> usize {
194        self.slot.n_leaves()
195    }
196
197    /// Total samples the active tree has seen.
198    #[inline]
199    pub fn n_samples_seen(&self) -> u64 {
200        self.slot.n_samples_seen()
201    }
202
203    /// Whether the slot has an alternate tree being trained.
204    #[inline]
205    pub fn has_alternate(&self) -> bool {
206        self.slot.has_alternate()
207    }
208
209    /// Reset to a completely fresh state: new tree, no alternate, reset detector.
210    pub fn reset(&mut self) {
211        self.slot.reset();
212    }
213
214    /// Immutable access to the underlying [`TreeSlot`].
215    #[inline]
216    pub fn slot(&self) -> &TreeSlot {
217        &self.slot
218    }
219
220    /// Mutable access to the underlying [`TreeSlot`].
221    #[inline]
222    pub fn slot_mut(&mut self) -> &mut TreeSlot {
223        &mut self.slot
224    }
225}
226
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
230
#[cfg(test)]
mod tests {
    use super::*;
    use crate::drift::pht::PageHinkleyTest;
    use alloc::boxed::Box;
    use alloc::format;

    /// Feature vector shared by most tests.
    const FEATURES: [f64; 3] = [1.0, 2.0, 3.0];

    /// Build a fully-wired step with the standard test configuration:
    /// small tree (depth 4, 16 bins) and a Page-Hinkley drift detector.
    fn make_step() -> BoostingStep {
        let config = TreeConfig::new()
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .lambda(1.0);
        BoostingStep::new(config, Box::new(PageHinkleyTest::new()))
    }

    // -------------------------------------------------------------------
    // Test 1: train_count=0 skips training (just predicts).
    // -------------------------------------------------------------------
    #[test]
    fn train_count_zero_skips_training() {
        let mut step = make_step();

        let pred = step.train_and_predict(&FEATURES, -0.5, 1.0, 0);
        assert!(
            pred.abs() < 1e-12,
            "train_count=0 should return fresh prediction (~0.0), got {}",
            pred,
        );
        assert_eq!(
            step.n_samples_seen(),
            0,
            "train_count=0 should not increment samples_seen",
        );
    }

    // -------------------------------------------------------------------
    // Test 2: train_count=1 trains once.
    // -------------------------------------------------------------------
    #[test]
    fn train_count_one_trains_once() {
        let mut step = make_step();

        let pred = step.train_and_predict(&FEATURES, -0.5, 1.0, 1);
        assert!(
            pred.abs() < 1e-12,
            "first prediction should be ~0.0, got {}",
            pred,
        );
        assert_eq!(
            step.n_samples_seen(),
            1,
            "train_count=1 should train exactly once",
        );
        assert!(
            step.predict(&FEATURES).is_finite(),
            "prediction after training should be finite",
        );
    }

    // -------------------------------------------------------------------
    // Test 3: train_count=3 trains multiple times.
    // -------------------------------------------------------------------
    #[test]
    fn train_count_three_trains_multiple_times() {
        let mut step = make_step();

        let pred = step.train_and_predict(&FEATURES, -0.5, 1.0, 3);
        assert!(
            pred.abs() < 1e-12,
            "first prediction should be ~0.0, got {}",
            pred,
        );
        assert_eq!(
            step.n_samples_seen(),
            3,
            "train_count=3 should train exactly 3 times",
        );
    }

    // -------------------------------------------------------------------
    // Test 4: Reset works.
    // -------------------------------------------------------------------
    #[test]
    fn reset_clears_state() {
        let mut step = make_step();
        for _ in 0..50 {
            step.train_and_predict(&FEATURES, -0.5, 1.0, 1);
        }
        assert!(step.n_samples_seen() > 0, "should have trained samples");

        step.reset();

        assert_eq!(step.n_leaves(), 1, "after reset, should have 1 leaf");
        assert_eq!(
            step.n_samples_seen(),
            0,
            "after reset, samples_seen should be 0"
        );
        assert!(
            !step.has_alternate(),
            "after reset, no alternate should exist"
        );
        let pred = step.predict(&FEATURES);
        assert!(
            pred.abs() < 1e-12,
            "prediction after reset should be ~0.0, got {}",
            pred,
        );
    }

    // -------------------------------------------------------------------
    // Test 5: Predict-only (no training) works on fresh step.
    // -------------------------------------------------------------------
    #[test]
    fn predict_only_on_fresh_step() {
        let step = make_step();

        for i in 0..10 {
            let x = f64::from(i) * 0.5;
            let pred = step.predict(&[x, x + 1.0, x + 2.0]);
            assert!(
                pred.abs() < 1e-12,
                "untrained step should predict ~0.0, got {} at i={}",
                pred,
                i,
            );
        }
    }

    // -------------------------------------------------------------------
    // Test 6: Mixed train_counts produce the expected cumulative totals.
    // -------------------------------------------------------------------
    #[test]
    fn mixed_train_counts_accumulate_correctly() {
        let mut step = make_step();

        // (train_count, expected cumulative samples_seen).
        let schedule: [(usize, u64); 4] = [(2, 2), (0, 2), (1, 3), (5, 8)];
        for (count, expected_total) in schedule.iter().copied() {
            step.train_and_predict(&FEATURES, -0.1, 1.0, count);
            assert_eq!(step.n_samples_seen(), expected_total);
        }
    }

    // -------------------------------------------------------------------
    // Test 7: n_leaves and has_alternate passthrough to slot.
    // -------------------------------------------------------------------
    #[test]
    fn accessors_match_slot() {
        let step = make_step();

        assert_eq!(step.n_leaves(), step.slot().n_leaves());
        assert_eq!(step.has_alternate(), step.slot().has_alternate());
        assert_eq!(step.n_samples_seen(), step.slot().n_samples_seen());
    }

    // -------------------------------------------------------------------
    // Test 8: Debug formatting works.
    // -------------------------------------------------------------------
    #[test]
    fn debug_format_does_not_panic() {
        let rendered = format!("{:?}", make_step());
        assert!(
            rendered.contains("BoostingStep"),
            "debug output should contain 'BoostingStep'",
        );
    }
}