use alloc::boxed::Box;
use core::fmt;
use crate::drift::DriftDetector;
use crate::ensemble::replacement::TreeSlot;
use crate::tree::builder::TreeConfig;
/// One additive stage of a boosted ensemble: a thin wrapper around a single
/// [`TreeSlot`] that owns the stage's tree and its drift detector.
///
/// All prediction and training calls delegate directly to the slot; this
/// type exists so ensemble code can treat each boosting stage uniformly.
#[derive(Clone)]
pub struct BoostingStep {
    // The managed tree slot (tree + drift detector + optional alternate).
    slot: TreeSlot,
}
impl fmt::Debug for BoostingStep {
    /// Renders as `BoostingStep { slot: .. }`, forwarding the slot's own
    /// `Debug` output (respects `{:#?}` pretty-printing via `debug_struct`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("BoostingStep");
        builder.field("slot", &self.slot);
        builder.finish()
    }
}
impl BoostingStep {
    /// Creates a step with no per-tree sample cap.
    pub fn new(tree_config: TreeConfig, detector: Box<dyn DriftDetector>) -> Self {
        // Same as the general constructor with the cap disabled.
        Self::new_with_max_samples(tree_config, detector, None)
    }

    /// Creates a step with an optional per-tree sample budget, which is
    /// forwarded verbatim to [`TreeSlot::new`] (semantics defined there).
    pub fn new_with_max_samples(
        tree_config: TreeConfig,
        detector: Box<dyn DriftDetector>,
        max_tree_samples: Option<u64>,
    ) -> Self {
        let slot = TreeSlot::new(tree_config, detector, max_tree_samples);
        Self { slot }
    }

    /// Creates a step whose slot uses a shadow warm-up of `shadow_warmup`
    /// samples (see [`TreeSlot::with_shadow_warmup`]).
    pub fn new_with_graduated(
        tree_config: TreeConfig,
        detector: Box<dyn DriftDetector>,
        max_tree_samples: Option<u64>,
        shadow_warmup: usize,
    ) -> Self {
        let slot = TreeSlot::with_shadow_warmup(
            tree_config,
            detector,
            max_tree_samples,
            shadow_warmup,
        );
        Self { slot }
    }

    /// Wraps an already-constructed slot in a boosting step.
    pub fn from_slot(slot: TreeSlot) -> Self {
        Self { slot }
    }

    /// Applies `train_count` identical training updates for the given
    /// (`features`, `gradient`, `hessian`) triple and returns a prediction.
    ///
    /// With `train_count == 0` the slot is not trained and a plain
    /// prediction is returned. Otherwise the value produced by the FIRST
    /// `train_and_predict` call on the slot is reported; the remaining
    /// `train_count - 1` updates run only for their training side effects.
    pub fn train_and_predict(
        &mut self,
        features: &[f64],
        gradient: f64,
        hessian: f64,
        train_count: usize,
    ) -> f64 {
        if train_count == 0 {
            return self.slot.predict(features);
        }
        let first = self.slot.train_and_predict(features, gradient, hessian);
        let mut remaining = train_count - 1;
        while remaining > 0 {
            self.slot.train_and_predict(features, gradient, hessian);
            remaining -= 1;
        }
        first
    }

    /// Point prediction for `features` (delegates to [`TreeSlot::predict`]).
    #[inline]
    pub fn predict(&self, features: &[f64]) -> f64 {
        self.slot.predict(features)
    }

    /// Prediction together with a variance estimate
    /// (delegates to [`TreeSlot::predict_with_variance`]).
    #[inline]
    pub fn predict_with_variance(&self, features: &[f64]) -> (f64, f64) {
        self.slot.predict_with_variance(features)
    }

    /// Smoothed prediction with a fixed `bandwidth`
    /// (delegates to [`TreeSlot::predict_smooth`]).
    #[inline]
    pub fn predict_smooth(&self, features: &[f64], bandwidth: f64) -> f64 {
        self.slot.predict_smooth(features, bandwidth)
    }

    /// Smoothed prediction with automatic bandwidth selection over the
    /// candidate `bandwidths` (delegates to [`TreeSlot::predict_smooth_auto`]).
    #[inline]
    pub fn predict_smooth_auto(&self, features: &[f64], bandwidths: &[f64]) -> f64 {
        self.slot.predict_smooth_auto(features, bandwidths)
    }

    /// Interpolated prediction (delegates to [`TreeSlot::predict_interpolated`]).
    #[inline]
    pub fn predict_interpolated(&self, features: &[f64]) -> f64 {
        self.slot.predict_interpolated(features)
    }

    /// Sibling-interpolated prediction over the candidate `bandwidths`
    /// (delegates to [`TreeSlot::predict_sibling_interpolated`]).
    #[inline]
    pub fn predict_sibling_interpolated(&self, features: &[f64], bandwidths: &[f64]) -> f64 {
        self.slot.predict_sibling_interpolated(features, bandwidths)
    }

    /// Soft-routed prediction (delegates to [`TreeSlot::predict_soft_routed`]).
    #[inline]
    pub fn predict_soft_routed(&self, features: &[f64]) -> f64 {
        self.slot.predict_soft_routed(features)
    }

    /// Graduated prediction (delegates to [`TreeSlot::predict_graduated`]).
    #[inline]
    pub fn predict_graduated(&self, features: &[f64]) -> f64 {
        self.slot.predict_graduated(features)
    }

    /// Graduated sibling-interpolated prediction over the candidate
    /// `bandwidths` (delegates to
    /// [`TreeSlot::predict_graduated_sibling_interpolated`]).
    #[inline]
    pub fn predict_graduated_sibling_interpolated(
        &self,
        features: &[f64],
        bandwidths: &[f64],
    ) -> f64 {
        self.slot
            .predict_graduated_sibling_interpolated(features, bandwidths)
    }

    /// Number of leaves in the slot's active tree.
    #[inline]
    pub fn n_leaves(&self) -> usize {
        self.slot.n_leaves()
    }

    /// Total samples the slot has been trained on.
    #[inline]
    pub fn n_samples_seen(&self) -> u64 {
        self.slot.n_samples_seen()
    }

    /// Whether the slot currently holds an alternate (background) tree.
    #[inline]
    pub fn has_alternate(&self) -> bool {
        self.slot.has_alternate()
    }

    /// Resets the slot back to its freshly-constructed state.
    pub fn reset(&mut self) {
        self.slot.reset();
    }

    /// Shared access to the underlying slot.
    #[inline]
    pub fn slot(&self) -> &TreeSlot {
        &self.slot
    }

    /// Exclusive access to the underlying slot.
    #[inline]
    pub fn slot_mut(&mut self) -> &mut TreeSlot {
        &mut self.slot
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::drift::pht::PageHinkleyTest;
    use alloc::boxed::Box;
    use alloc::format;

    /// Small tree configuration shared by all tests: shallow, few bins,
    /// short grace period so tests stay fast.
    fn test_tree_config() -> TreeConfig {
        TreeConfig::new()
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .lambda(1.0)
    }

    /// Fresh boxed drift detector for constructing steps.
    fn test_detector() -> Box<dyn DriftDetector> {
        Box::new(PageHinkleyTest::new())
    }

    // train_count == 0 must be a pure prediction: no training side effects.
    #[test]
    fn train_count_zero_skips_training() {
        let mut step = BoostingStep::new(test_tree_config(), test_detector());
        let features = [1.0, 2.0, 3.0];
        let pred = step.train_and_predict(&features, -0.5, 1.0, 0);
        assert!(
            pred.abs() < 1e-12,
            "train_count=0 should return fresh prediction (~0.0), got {}",
            pred,
        );
        assert_eq!(
            step.n_samples_seen(),
            0,
            "train_count=0 should not increment samples_seen",
        );
    }

    // train_count == 1 trains exactly once and returns the pre-update
    // prediction (which is ~0.0 on a fresh tree).
    #[test]
    fn train_count_one_trains_once() {
        let mut step = BoostingStep::new(test_tree_config(), test_detector());
        let features = [1.0, 2.0, 3.0];
        let pred = step.train_and_predict(&features, -0.5, 1.0, 1);
        assert!(
            pred.abs() < 1e-12,
            "first prediction should be ~0.0, got {}",
            pred,
        );
        assert_eq!(
            step.n_samples_seen(),
            1,
            "train_count=1 should train exactly once",
        );
        let pred2 = step.predict(&features);
        assert!(
            pred2.is_finite(),
            "prediction after training should be finite",
        );
    }

    // train_count > 1 applies the extra updates as training-only repeats.
    #[test]
    fn train_count_three_trains_multiple_times() {
        let mut step = BoostingStep::new(test_tree_config(), test_detector());
        let features = [1.0, 2.0, 3.0];
        let pred = step.train_and_predict(&features, -0.5, 1.0, 3);
        assert!(
            pred.abs() < 1e-12,
            "first prediction should be ~0.0, got {}",
            pred,
        );
        assert_eq!(
            step.n_samples_seen(),
            3,
            "train_count=3 should train exactly 3 times",
        );
    }

    // reset() must return the step to its freshly-constructed state:
    // single leaf, zero samples, no alternate, ~0.0 predictions.
    #[test]
    fn reset_clears_state() {
        let mut step = BoostingStep::new(test_tree_config(), test_detector());
        let features = [1.0, 2.0, 3.0];
        for _ in 0..50 {
            step.train_and_predict(&features, -0.5, 1.0, 1);
        }
        assert!(step.n_samples_seen() > 0, "should have trained samples");
        step.reset();
        assert_eq!(step.n_leaves(), 1, "after reset, should have 1 leaf");
        assert_eq!(
            step.n_samples_seen(),
            0,
            "after reset, samples_seen should be 0"
        );
        assert!(
            !step.has_alternate(),
            "after reset, no alternate should exist"
        );
        let pred = step.predict(&features);
        assert!(
            pred.abs() < 1e-12,
            "prediction after reset should be ~0.0, got {}",
            pred,
        );
    }

    // A never-trained step predicts ~0.0 for any input.
    #[test]
    fn predict_only_on_fresh_step() {
        let step = BoostingStep::new(test_tree_config(), test_detector());
        for i in 0..10 {
            let x = (i as f64) * 0.5;
            let pred = step.predict(&[x, x + 1.0, x + 2.0]);
            assert!(
                pred.abs() < 1e-12,
                "untrained step should predict ~0.0, got {} at i={}",
                pred,
                i,
            );
        }
    }

    // Sample counts accumulate across calls with heterogeneous train_counts
    // (including zero, which must be a no-op for the counter).
    #[test]
    fn mixed_train_counts_accumulate_correctly() {
        let mut step = BoostingStep::new(test_tree_config(), test_detector());
        let features = [1.0, 2.0, 3.0];
        step.train_and_predict(&features, -0.1, 1.0, 2);
        assert_eq!(step.n_samples_seen(), 2);
        step.train_and_predict(&features, -0.1, 1.0, 0);
        assert_eq!(step.n_samples_seen(), 2);
        step.train_and_predict(&features, -0.1, 1.0, 1);
        assert_eq!(step.n_samples_seen(), 3);
        step.train_and_predict(&features, -0.1, 1.0, 5);
        assert_eq!(step.n_samples_seen(), 8);
    }

    // Convenience accessors must mirror the underlying slot exactly.
    #[test]
    fn accessors_match_slot() {
        let step = BoostingStep::new(test_tree_config(), test_detector());
        assert_eq!(step.n_leaves(), step.slot().n_leaves());
        assert_eq!(step.has_alternate(), step.slot().has_alternate());
        assert_eq!(step.n_samples_seen(), step.slot().n_samples_seen());
    }

    // Debug formatting is manual (see the fmt::Debug impl); smoke-test it.
    #[test]
    fn debug_format_does_not_panic() {
        let step = BoostingStep::new(test_tree_config(), test_detector());
        let debug_str = format!("{:?}", step);
        assert!(
            debug_str.contains("BoostingStep"),
            "debug output should contain 'BoostingStep'",
        );
    }
}