Skip to main content

irithyll_core/ensemble/
step.rs

1//! Single boosting step: owns one tree + drift detector + optional alternate.
2//!
3//! [`BoostingStep`] is a thin wrapper around [`TreeSlot`]
4//! that adds SGBT variant logic. The three variants from Gunasekara et al. (2024) are:
5//!
6//! - **Standard** (`train_count = 1`): each sample trains the tree exactly once.
7//! - **Skip** (`train_count = 0`): the sample is skipped (only prediction returned).
8//! - **Multiple Iterations** (`train_count > 1`): the sample trains the tree
9//!   multiple times, weighted by the hessian.
10//!
11//! The variant logic is computed externally (by `SGBTVariant::train_count()` in the
12//! ensemble orchestrator) and passed in as `train_count`. This keeps `BoostingStep`
13//! focused on execution rather than policy.
14
15use alloc::boxed::Box;
16use core::fmt;
17
18use crate::drift::DriftDetector;
19use crate::ensemble::replacement::TreeSlot;
20use crate::tree::builder::TreeConfig;
21
/// One stage of the SGBT boosting sequence.
///
/// A `BoostingStep` wraps a single [`TreeSlot`] and layers variant-aware
/// training repetition on top of it. The step itself holds no policy: the
/// caller (the ensemble orchestrator) derives `train_count` from the
/// configured SGBT variant and hands it to
/// [`train_and_predict`](BoostingStep::train_and_predict).
///
/// # Prediction semantics
///
/// Both [`train_and_predict`](BoostingStep::train_and_predict) and
/// [`predict`](BoostingStep::predict) return the active tree's prediction
/// **before** any training on the current sample. This ensures unbiased
/// gradient computation in the boosting loop.
///
/// # Trait implementations
///
/// `Clone` is derived; `Debug` is implemented by hand alongside this type.
#[derive(Clone)]
pub struct BoostingStep {
    /// The tree slot managing the active tree, alternate, and drift detector.
    /// Every method on `BoostingStep` ultimately delegates to this slot.
    slot: TreeSlot,
}
39
40impl fmt::Debug for BoostingStep {
41    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42        f.debug_struct("BoostingStep")
43            .field("slot", &self.slot)
44            .finish()
45    }
46}
47
48impl BoostingStep {
49    /// Create a new boosting step with a fresh tree and drift detector.
50    pub fn new(tree_config: TreeConfig, detector: Box<dyn DriftDetector>) -> Self {
51        Self {
52            slot: TreeSlot::new(tree_config, detector, None),
53        }
54    }
55
56    /// Create a new boosting step with optional time-based tree replacement.
57    pub fn new_with_max_samples(
58        tree_config: TreeConfig,
59        detector: Box<dyn DriftDetector>,
60        max_tree_samples: Option<u64>,
61    ) -> Self {
62        Self {
63            slot: TreeSlot::new(tree_config, detector, max_tree_samples),
64        }
65    }
66
67    /// Create a new boosting step with graduated tree handoff.
68    pub fn new_with_graduated(
69        tree_config: TreeConfig,
70        detector: Box<dyn DriftDetector>,
71        max_tree_samples: Option<u64>,
72        shadow_warmup: usize,
73    ) -> Self {
74        Self {
75            slot: TreeSlot::with_shadow_warmup(
76                tree_config,
77                detector,
78                max_tree_samples,
79                shadow_warmup,
80            ),
81        }
82    }
83
84    /// Reconstruct a boosting step from a pre-built tree slot.
85    ///
86    /// Used during model deserialization.
87    pub fn from_slot(slot: TreeSlot) -> Self {
88        Self { slot }
89    }
90
91    /// Train on a single sample with variant-aware repetition.
92    ///
93    /// # Arguments
94    ///
95    /// * `features` - Input feature vector.
96    /// * `gradient` - Negative gradient of the loss at this sample.
97    /// * `hessian` - Second derivative (curvature) of the loss at this sample.
98    /// * `train_count` - Number of training iterations for this sample:
99    ///   - `0`: skip training entirely (SK variant or stochastic skip).
100    ///   - `1`: standard single-pass training.
101    ///   - `>1`: multiple iterations (MI variant).
102    ///
103    /// # Returns
104    ///
105    /// The prediction from the active tree **before** training.
106    pub fn train_and_predict(
107        &mut self,
108        features: &[f64],
109        gradient: f64,
110        hessian: f64,
111        train_count: usize,
112    ) -> f64 {
113        if train_count == 0 {
114            // Skip variant: no training, just predict.
115            return self.slot.predict(features);
116        }
117
118        // First iteration: train and get the pre-training prediction.
119        let pred = self.slot.train_and_predict(features, gradient, hessian);
120
121        // Additional iterations for MI variant.
122        // Each subsequent call still feeds the same gradient/hessian to the
123        // tree and drift detector, effectively weighting this sample more heavily.
124        for _ in 1..train_count {
125            self.slot.train_and_predict(features, gradient, hessian);
126        }
127
128        pred
129    }
130
131    /// Predict without training.
132    ///
133    /// Routes the feature vector through the active tree and returns the
134    /// leaf value. Does not update any state.
135    #[inline]
136    pub fn predict(&self, features: &[f64]) -> f64 {
137        self.slot.predict(features)
138    }
139
140    /// Predict with variance for confidence estimation.
141    ///
142    /// Returns `(leaf_value, variance)` where variance = 1 / (H_sum + lambda).
143    #[inline]
144    pub fn predict_with_variance(&self, features: &[f64]) -> (f64, f64) {
145        self.slot.predict_with_variance(features)
146    }
147
148    /// Predict using sigmoid-blended soft routing for smooth interpolation.
149    ///
150    /// See [`crate::tree::hoeffding::HoeffdingTree::predict_smooth`] for details.
151    #[inline]
152    pub fn predict_smooth(&self, features: &[f64], bandwidth: f64) -> f64 {
153        self.slot.predict_smooth(features, bandwidth)
154    }
155
156    /// Predict using per-feature auto-calibrated bandwidths.
157    #[inline]
158    pub fn predict_smooth_auto(&self, features: &[f64], bandwidths: &[f64]) -> f64 {
159        self.slot.predict_smooth_auto(features, bandwidths)
160    }
161
162    /// Predict with parent-leaf linear interpolation.
163    #[inline]
164    pub fn predict_interpolated(&self, features: &[f64]) -> f64 {
165        self.slot.predict_interpolated(features)
166    }
167
168    /// Predict with sibling-based interpolation for feature-continuous predictions.
169    #[inline]
170    pub fn predict_sibling_interpolated(&self, features: &[f64], bandwidths: &[f64]) -> f64 {
171        self.slot.predict_sibling_interpolated(features, bandwidths)
172    }
173
174    /// Predict using per-node auto-bandwidth soft routing.
175    #[inline]
176    pub fn predict_soft_routed(&self, features: &[f64]) -> f64 {
177        self.slot.predict_soft_routed(features)
178    }
179
180    /// Predict with graduated active-shadow blending.
181    #[inline]
182    pub fn predict_graduated(&self, features: &[f64]) -> f64 {
183        self.slot.predict_graduated(features)
184    }
185
186    /// Predict with graduated blending + sibling interpolation.
187    #[inline]
188    pub fn predict_graduated_sibling_interpolated(
189        &self,
190        features: &[f64],
191        bandwidths: &[f64],
192    ) -> f64 {
193        self.slot
194            .predict_graduated_sibling_interpolated(features, bandwidths)
195    }
196
197    /// Number of leaves in the active tree.
198    #[inline]
199    pub fn n_leaves(&self) -> usize {
200        self.slot.n_leaves()
201    }
202
203    /// Total samples the active tree has seen.
204    #[inline]
205    pub fn n_samples_seen(&self) -> u64 {
206        self.slot.n_samples_seen()
207    }
208
209    /// Whether the slot has an alternate tree being trained.
210    #[inline]
211    pub fn has_alternate(&self) -> bool {
212        self.slot.has_alternate()
213    }
214
215    /// Reset to a completely fresh state: new tree, no alternate, reset detector.
216    pub fn reset(&mut self) {
217        self.slot.reset();
218    }
219
220    /// Immutable access to the underlying [`TreeSlot`].
221    #[inline]
222    pub fn slot(&self) -> &TreeSlot {
223        &self.slot
224    }
225
226    /// Mutable access to the underlying [`TreeSlot`].
227    #[inline]
228    pub fn slot_mut(&mut self) -> &mut TreeSlot {
229        &mut self.slot
230    }
231}
232
233// ---------------------------------------------------------------------------
234// Tests
235// ---------------------------------------------------------------------------
236
#[cfg(test)]
mod tests {
    use super::*;
    use crate::drift::pht::PageHinkleyTest;
    use alloc::boxed::Box;
    use alloc::format;

    /// Feature vector shared by all tests that train the step.
    const FEATURES: [f64; 3] = [1.0, 2.0, 3.0];

    /// Tree configuration shared by every test in this module.
    fn test_tree_config() -> TreeConfig {
        TreeConfig::new()
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .lambda(1.0)
    }

    /// Drift detector shared by every test in this module.
    fn test_detector() -> Box<dyn DriftDetector> {
        Box::new(PageHinkleyTest::new())
    }

    /// A brand-new step built from the shared config and detector.
    fn fresh_step() -> BoostingStep {
        BoostingStep::new(test_tree_config(), test_detector())
    }

    /// `train_count = 0` must only predict and leave the tree untouched.
    #[test]
    fn train_count_zero_skips_training() {
        let mut step = fresh_step();

        let skipped = step.train_and_predict(&FEATURES, -0.5, 1.0, 0);
        assert!(
            skipped.abs() < 1e-12,
            "train_count=0 should return fresh prediction (~0.0), got {}",
            skipped,
        );

        // The sample counter proves that no training happened.
        let seen = step.n_samples_seen();
        assert_eq!(seen, 0, "train_count=0 should not increment samples_seen");
    }

    /// `train_count = 1` must train exactly once.
    #[test]
    fn train_count_one_trains_once() {
        let mut step = fresh_step();

        let first = step.train_and_predict(&FEATURES, -0.5, 1.0, 1);
        assert!(
            first.abs() < 1e-12,
            "first prediction should be ~0.0, got {}",
            first,
        );

        // Exactly one sample must have reached the tree.
        let seen = step.n_samples_seen();
        assert_eq!(seen, 1, "train_count=1 should train exactly once");

        // A follow-up prediction on the trained tree must at least be finite.
        let follow_up = step.predict(&FEATURES);
        assert!(
            follow_up.is_finite(),
            "prediction after training should be finite",
        );
    }

    /// `train_count = 3` must present the sample three times.
    #[test]
    fn train_count_three_trains_multiple_times() {
        let mut step = fresh_step();

        let first = step.train_and_predict(&FEATURES, -0.5, 1.0, 3);
        assert!(
            first.abs() < 1e-12,
            "first prediction should be ~0.0, got {}",
            first,
        );

        // Each repetition counts as one sample seen by the tree.
        let seen = step.n_samples_seen();
        assert_eq!(seen, 3, "train_count=3 should train exactly 3 times");
    }

    /// `reset` must bring a trained step back to its initial state.
    #[test]
    fn reset_clears_state() {
        let mut step = fresh_step();

        // Accumulate some training state first.
        (0..50).for_each(|_| {
            step.train_and_predict(&FEATURES, -0.5, 1.0, 1);
        });
        assert!(step.n_samples_seen() > 0, "should have trained samples");

        step.reset();

        assert_eq!(step.n_leaves(), 1, "after reset, should have 1 leaf");
        assert_eq!(
            step.n_samples_seen(),
            0,
            "after reset, samples_seen should be 0"
        );
        assert!(
            !step.has_alternate(),
            "after reset, no alternate should exist"
        );

        let after_reset = step.predict(&FEATURES);
        assert!(
            after_reset.abs() < 1e-12,
            "prediction after reset should be ~0.0, got {}",
            after_reset,
        );
    }

    /// Pure prediction on an untrained step must yield ~0.0 everywhere.
    #[test]
    fn predict_only_on_fresh_step() {
        let step = fresh_step();

        for i in 0..10 {
            let base = (i as f64) * 0.5;
            let probe = [base, base + 1.0, base + 2.0];
            let pred = step.predict(&probe);
            assert!(
                pred.abs() < 1e-12,
                "untrained step should predict ~0.0, got {} at i={}",
                pred,
                i,
            );
        }
    }

    /// Successive calls with different `train_count` values must add up to
    /// the expected cumulative sample count.
    #[test]
    fn mixed_train_counts_accumulate_correctly() {
        let mut step = fresh_step();

        // (train_count, expected cumulative samples seen after the call)
        let schedule: [(usize, u64); 4] = [
            (2, 2), // two repetitions
            (0, 2), // skipped, total unchanged
            (1, 3), // one more
            (5, 8), // five more
        ];

        for &(train_count, expected_total) in schedule.iter() {
            step.train_and_predict(&FEATURES, -0.1, 1.0, train_count);
            assert_eq!(step.n_samples_seen(), expected_total);
        }
    }

    /// The step-level accessors must mirror the wrapped slot's accessors.
    #[test]
    fn accessors_match_slot() {
        let step = fresh_step();
        let slot = step.slot();

        assert_eq!(step.n_leaves(), slot.n_leaves());
        assert_eq!(step.has_alternate(), slot.has_alternate());
        assert_eq!(step.n_samples_seen(), slot.n_samples_seen());
    }

    /// Debug formatting must succeed and name the type.
    #[test]
    fn debug_format_does_not_panic() {
        let step = fresh_step();
        let rendered = format!("{:?}", step);
        assert!(
            rendered.contains("BoostingStep"),
            "debug output should contain 'BoostingStep'",
        );
    }
}