// irithyll_core/ensemble/step.rs
//! Single boosting step: owns one tree + drift detector + optional alternate.
//!
//! [`BoostingStep`] is a thin wrapper around [`TreeSlot`]
//! that adds SGBT variant logic. The three variants from Gunasekara et al. (2024) are:
//!
//! - **Standard** (`train_count = 1`): each sample trains the tree exactly once.
//! - **Skip** (`train_count = 0`): the sample is skipped (only the prediction is returned).
//! - **Multiple Iterations** (`train_count > 1`): the sample trains the tree
//!   multiple times, weighted by the hessian.
//!
//! The variant logic is computed externally (by `SGBTVariant::train_count()` in the
//! ensemble orchestrator) and passed in as `train_count`. This keeps `BoostingStep`
//! focused on execution rather than policy.
14
15use alloc::boxed::Box;
16use core::fmt;
17
18use crate::drift::DriftDetector;
19use crate::ensemble::replacement::TreeSlot;
20use crate::tree::builder::TreeConfig;
21
/// A single step in the SGBT boosting sequence.
///
/// Owns a [`TreeSlot`] and applies variant-aware training repetition. The number
/// of training iterations per sample is determined by the caller (the ensemble
/// orchestrator computes `train_count` from the configured SGBT variant).
///
/// # Prediction semantics
///
/// Both [`train_and_predict`](BoostingStep::train_and_predict) and
/// [`predict`](BoostingStep::predict) return the active tree's prediction
/// **before** any training on the current sample. This ensures unbiased
/// gradient computation in the boosting loop.
34#[derive(Clone)]
35pub struct BoostingStep {
36 /// The tree slot managing the active tree, alternate, and drift detector.
37 slot: TreeSlot,
38}
39
40impl fmt::Debug for BoostingStep {
41 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42 f.debug_struct("BoostingStep")
43 .field("slot", &self.slot)
44 .finish()
45 }
46}
47
48impl BoostingStep {
49 /// Create a new boosting step with a fresh tree and drift detector.
50 pub fn new(tree_config: TreeConfig, detector: Box<dyn DriftDetector>) -> Self {
51 Self {
52 slot: TreeSlot::new(tree_config, detector, None),
53 }
54 }
55
56 /// Create a new boosting step with optional time-based tree replacement.
57 pub fn new_with_max_samples(
58 tree_config: TreeConfig,
59 detector: Box<dyn DriftDetector>,
60 max_tree_samples: Option<u64>,
61 ) -> Self {
62 Self {
63 slot: TreeSlot::new(tree_config, detector, max_tree_samples),
64 }
65 }
66
67 /// Create a new boosting step with graduated tree handoff.
68 pub fn new_with_graduated(
69 tree_config: TreeConfig,
70 detector: Box<dyn DriftDetector>,
71 max_tree_samples: Option<u64>,
72 shadow_warmup: usize,
73 ) -> Self {
74 Self {
75 slot: TreeSlot::with_shadow_warmup(
76 tree_config,
77 detector,
78 max_tree_samples,
79 shadow_warmup,
80 ),
81 }
82 }
83
84 /// Reconstruct a boosting step from a pre-built tree slot.
85 ///
86 /// Used during model deserialization.
87 pub fn from_slot(slot: TreeSlot) -> Self {
88 Self { slot }
89 }
90
91 /// Train on a single sample with variant-aware repetition.
92 ///
93 /// # Arguments
94 ///
95 /// * `features` - Input feature vector.
96 /// * `gradient` - Negative gradient of the loss at this sample.
97 /// * `hessian` - Second derivative (curvature) of the loss at this sample.
98 /// * `train_count` - Number of training iterations for this sample:
99 /// - `0`: skip training entirely (SK variant or stochastic skip).
100 /// - `1`: standard single-pass training.
101 /// - `>1`: multiple iterations (MI variant).
102 ///
103 /// # Returns
104 ///
105 /// The prediction from the active tree **before** training.
106 pub fn train_and_predict(
107 &mut self,
108 features: &[f64],
109 gradient: f64,
110 hessian: f64,
111 train_count: usize,
112 ) -> f64 {
113 if train_count == 0 {
114 // Skip variant: no training, just predict.
115 return self.slot.predict(features);
116 }
117
118 // First iteration: train and get the pre-training prediction.
119 let pred = self.slot.train_and_predict(features, gradient, hessian);
120
121 // Additional iterations for MI variant.
122 // Each subsequent call still feeds the same gradient/hessian to the
123 // tree and drift detector, effectively weighting this sample more heavily.
124 for _ in 1..train_count {
125 self.slot.train_and_predict(features, gradient, hessian);
126 }
127
128 pred
129 }
130
131 /// Predict without training.
132 ///
133 /// Routes the feature vector through the active tree and returns the
134 /// leaf value. Does not update any state.
135 #[inline]
136 pub fn predict(&self, features: &[f64]) -> f64 {
137 self.slot.predict(features)
138 }
139
140 /// Predict with variance for confidence estimation.
141 ///
142 /// Returns `(leaf_value, variance)` where variance = 1 / (H_sum + lambda).
143 #[inline]
144 pub fn predict_with_variance(&self, features: &[f64]) -> (f64, f64) {
145 self.slot.predict_with_variance(features)
146 }
147
148 /// Predict using sigmoid-blended soft routing for smooth interpolation.
149 ///
150 /// See [`crate::tree::hoeffding::HoeffdingTree::predict_smooth`] for details.
151 #[inline]
152 pub fn predict_smooth(&self, features: &[f64], bandwidth: f64) -> f64 {
153 self.slot.predict_smooth(features, bandwidth)
154 }
155
156 /// Predict using per-feature auto-calibrated bandwidths.
157 #[inline]
158 pub fn predict_smooth_auto(&self, features: &[f64], bandwidths: &[f64]) -> f64 {
159 self.slot.predict_smooth_auto(features, bandwidths)
160 }
161
162 /// Predict with parent-leaf linear interpolation.
163 #[inline]
164 pub fn predict_interpolated(&self, features: &[f64]) -> f64 {
165 self.slot.predict_interpolated(features)
166 }
167
168 /// Predict with sibling-based interpolation for feature-continuous predictions.
169 #[inline]
170 pub fn predict_sibling_interpolated(&self, features: &[f64], bandwidths: &[f64]) -> f64 {
171 self.slot.predict_sibling_interpolated(features, bandwidths)
172 }
173
174 /// Predict with graduated active-shadow blending.
175 #[inline]
176 pub fn predict_graduated(&self, features: &[f64]) -> f64 {
177 self.slot.predict_graduated(features)
178 }
179
180 /// Predict with graduated blending + sibling interpolation.
181 #[inline]
182 pub fn predict_graduated_sibling_interpolated(
183 &self,
184 features: &[f64],
185 bandwidths: &[f64],
186 ) -> f64 {
187 self.slot
188 .predict_graduated_sibling_interpolated(features, bandwidths)
189 }
190
191 /// Number of leaves in the active tree.
192 #[inline]
193 pub fn n_leaves(&self) -> usize {
194 self.slot.n_leaves()
195 }
196
197 /// Total samples the active tree has seen.
198 #[inline]
199 pub fn n_samples_seen(&self) -> u64 {
200 self.slot.n_samples_seen()
201 }
202
203 /// Whether the slot has an alternate tree being trained.
204 #[inline]
205 pub fn has_alternate(&self) -> bool {
206 self.slot.has_alternate()
207 }
208
209 /// Reset to a completely fresh state: new tree, no alternate, reset detector.
210 pub fn reset(&mut self) {
211 self.slot.reset();
212 }
213
214 /// Immutable access to the underlying [`TreeSlot`].
215 #[inline]
216 pub fn slot(&self) -> &TreeSlot {
217 &self.slot
218 }
219
220 /// Mutable access to the underlying [`TreeSlot`].
221 #[inline]
222 pub fn slot_mut(&mut self) -> &mut TreeSlot {
223 &mut self.slot
224 }
225}
226
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
230
#[cfg(test)]
mod tests {
    use super::*;
    use crate::drift::pht::PageHinkleyTest;
    use alloc::boxed::Box;
    use alloc::format;

    /// Create a default TreeConfig for tests.
    fn test_tree_config() -> TreeConfig {
        TreeConfig::new()
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .lambda(1.0)
    }

    /// Create a default drift detector for tests.
    fn test_detector() -> Box<dyn DriftDetector> {
        Box::new(PageHinkleyTest::new())
    }

    // -------------------------------------------------------------------
    // Test 1: train_count=0 skips training (just predicts).
    // -------------------------------------------------------------------
    #[test]
    fn train_count_zero_skips_training() {
        let mut step = BoostingStep::new(test_tree_config(), test_detector());
        let features = [1.0, 2.0, 3.0];

        // Train with count=0 should not actually train.
        let pred = step.train_and_predict(&features, -0.5, 1.0, 0);
        assert!(
            pred.abs() < 1e-12,
            "train_count=0 should return fresh prediction (~0.0), got {}",
            pred,
        );

        // Verify no samples were actually trained.
        assert_eq!(
            step.n_samples_seen(),
            0,
            "train_count=0 should not increment samples_seen",
        );
    }

    // -------------------------------------------------------------------
    // Test 2: train_count=1 trains once.
    // -------------------------------------------------------------------
    #[test]
    fn train_count_one_trains_once() {
        let mut step = BoostingStep::new(test_tree_config(), test_detector());
        let features = [1.0, 2.0, 3.0];

        let pred = step.train_and_predict(&features, -0.5, 1.0, 1);
        assert!(
            pred.abs() < 1e-12,
            "first prediction should be ~0.0, got {}",
            pred,
        );

        // After one training call, the tree should have seen 1 sample.
        assert_eq!(
            step.n_samples_seen(),
            1,
            "train_count=1 should train exactly once",
        );

        // The tree has now been trained; we only assert finiteness here
        // because the exact value depends on the grace period (the leaf may
        // not have split/updated yet after a single sample).
        let pred2 = step.predict(&features);
        assert!(
            pred2.is_finite(),
            "prediction after training should be finite",
        );
    }

    // -------------------------------------------------------------------
    // Test 3: train_count=3 trains multiple times.
    // -------------------------------------------------------------------
    #[test]
    fn train_count_three_trains_multiple_times() {
        let mut step = BoostingStep::new(test_tree_config(), test_detector());
        let features = [1.0, 2.0, 3.0];

        let pred = step.train_and_predict(&features, -0.5, 1.0, 3);
        assert!(
            pred.abs() < 1e-12,
            "first prediction should be ~0.0, got {}",
            pred,
        );

        // After train_count=3, the tree should have seen 3 samples
        // (each MI repetition increments the slot's sample counter).
        assert_eq!(
            step.n_samples_seen(),
            3,
            "train_count=3 should train exactly 3 times",
        );
    }

    // -------------------------------------------------------------------
    // Test 4: Reset works.
    // -------------------------------------------------------------------
    #[test]
    fn reset_clears_state() {
        let mut step = BoostingStep::new(test_tree_config(), test_detector());
        let features = [1.0, 2.0, 3.0];

        // Train several samples (enough to pass the grace period of 20).
        for _ in 0..50 {
            step.train_and_predict(&features, -0.5, 1.0, 1);
        }

        assert!(step.n_samples_seen() > 0, "should have trained samples");

        step.reset();

        // After reset the step behaves exactly like a freshly constructed one.
        assert_eq!(step.n_leaves(), 1, "after reset, should have 1 leaf");
        assert_eq!(
            step.n_samples_seen(),
            0,
            "after reset, samples_seen should be 0"
        );
        assert!(
            !step.has_alternate(),
            "after reset, no alternate should exist"
        );

        let pred = step.predict(&features);
        assert!(
            pred.abs() < 1e-12,
            "prediction after reset should be ~0.0, got {}",
            pred,
        );
    }

    // -------------------------------------------------------------------
    // Test 5: Predict-only (no training) works on fresh step.
    // -------------------------------------------------------------------
    #[test]
    fn predict_only_on_fresh_step() {
        let step = BoostingStep::new(test_tree_config(), test_detector());

        // An untrained tree should always predict ~0.0 regardless of input.
        for i in 0..10 {
            let x = (i as f64) * 0.5;
            let pred = step.predict(&[x, x + 1.0, x + 2.0]);
            assert!(
                pred.abs() < 1e-12,
                "untrained step should predict ~0.0, got {} at i={}",
                pred,
                i,
            );
        }
    }

    // -------------------------------------------------------------------
    // Test 6: Multiple calls with different train_counts produce expected
    // cumulative sample counts.
    // -------------------------------------------------------------------
    #[test]
    fn mixed_train_counts_accumulate_correctly() {
        let mut step = BoostingStep::new(test_tree_config(), test_detector());
        let features = [1.0, 2.0, 3.0];

        // count=2 -> 2 samples
        step.train_and_predict(&features, -0.1, 1.0, 2);
        assert_eq!(step.n_samples_seen(), 2);

        // count=0 -> still 2 samples (skipped)
        step.train_and_predict(&features, -0.1, 1.0, 0);
        assert_eq!(step.n_samples_seen(), 2);

        // count=1 -> 3 samples
        step.train_and_predict(&features, -0.1, 1.0, 1);
        assert_eq!(step.n_samples_seen(), 3);

        // count=5 -> 8 samples
        step.train_and_predict(&features, -0.1, 1.0, 5);
        assert_eq!(step.n_samples_seen(), 8);
    }

    // -------------------------------------------------------------------
    // Test 7: n_leaves and has_alternate passthrough to slot.
    // -------------------------------------------------------------------
    #[test]
    fn accessors_match_slot() {
        let step = BoostingStep::new(test_tree_config(), test_detector());

        // The step-level accessors are thin delegations to the slot.
        assert_eq!(step.n_leaves(), step.slot().n_leaves());
        assert_eq!(step.has_alternate(), step.slot().has_alternate());
        assert_eq!(step.n_samples_seen(), step.slot().n_samples_seen());
    }

    // -------------------------------------------------------------------
    // Test 8: Debug formatting works.
    // -------------------------------------------------------------------
    #[test]
    fn debug_format_does_not_panic() {
        let step = BoostingStep::new(test_tree_config(), test_detector());
        let debug_str = format!("{:?}", step);
        assert!(
            debug_str.contains("BoostingStep"),
            "debug output should contain 'BoostingStep'",
        );
    }
}
434}