1pub mod adaptive;
20pub mod adaptive_forest;
21pub mod bagged;
22pub mod config;
23pub mod distributional;
24pub mod lr_schedule;
25pub mod moe;
26pub mod moe_distributional;
27pub mod multi_target;
28pub mod multiclass;
29#[cfg(feature = "parallel")]
30pub mod parallel;
31pub mod quantile_regressor;
32pub mod replacement;
33pub mod stacked;
34pub mod step;
35pub mod variants;
36
37use alloc::boxed::Box;
38use alloc::string::String;
39use alloc::vec;
40use alloc::vec::Vec;
41
42use core::fmt;
43
44use crate::ensemble::config::SGBTConfig;
45use crate::ensemble::step::BoostingStep;
46use crate::loss::squared::SquaredLoss;
47use crate::loss::Loss;
48use crate::sample::Observation;
49#[allow(unused_imports)] use crate::sample::Sample;
51use crate::tree::builder::TreeConfig;
52
/// Type-erased SGBT whose loss is chosen at runtime (e.g. when rebuilding a
/// model from a serialized `ModelState`).
pub type DynSGBT = SGBT<Box<dyn Loss>>;
61
/// Streaming gradient-boosted trees ensemble.
///
/// Holds a fixed number of boosting steps (one tree slot each) that are
/// trained incrementally, one observation at a time.
pub struct SGBT<L: Loss = SquaredLoss> {
    // Ensemble-level hyperparameters.
    config: SGBTConfig,
    // One boosting step (tree slot + drift detector) per configured step.
    steps: Vec<BoostingStep>,
    // Loss providing gradients/hessians and the base prediction.
    loss: L,
    // Constant offset added to every prediction, estimated from warm-up targets.
    base_prediction: f64,
    // True once `base_prediction` has been estimated.
    base_initialized: bool,
    // Warm-up target buffer; cleared (and shrunk) after base initialization.
    initial_targets: Vec<f64>,
    // Number of targets to buffer before estimating the base prediction.
    initial_target_count: usize,
    // Total observations fed to `train_one`.
    samples_seen: u64,
    // xorshift64 state used for subsampling and training-count variants.
    rng_state: u64,
    // Per-step EWMA of |contribution|; allocated only when quality pruning is on.
    contribution_ewma: Vec<f64>,
    // Per-step consecutive low-contribution counter (quality pruning).
    low_contrib_count: Vec<u64>,
    // Rolling mean of absolute error, used for error-based sample weighting.
    rolling_mean_error: f64,
    // Per-feature smoothing bandwidths derived from split thresholds.
    auto_bandwidths: Vec<f64>,
    // Total tree replacements at the last bandwidth refresh (change detector).
    last_replacement_sum: u64,
}
124
/// Deep clone, available whenever the loss type itself is `Clone`.
///
/// Every field is cloned, so the copy shares no state with the original and
/// the two models can be trained independently afterwards.
impl<L: Loss + Clone> Clone for SGBT<L> {
    fn clone(&self) -> Self {
        Self {
            config: self.config.clone(),
            steps: self.steps.clone(),
            loss: self.loss.clone(),
            base_prediction: self.base_prediction,
            base_initialized: self.base_initialized,
            initial_targets: self.initial_targets.clone(),
            initial_target_count: self.initial_target_count,
            samples_seen: self.samples_seen,
            rng_state: self.rng_state,
            contribution_ewma: self.contribution_ewma.clone(),
            low_contrib_count: self.low_contrib_count.clone(),
            rolling_mean_error: self.rolling_mean_error,
            auto_bandwidths: self.auto_bandwidths.clone(),
            last_replacement_sum: self.last_replacement_sum,
        }
    }
}
145
/// Compact `Debug`: summarizes ensemble size and training progress instead of
/// dumping every tree.
impl<L: Loss> fmt::Debug for SGBT<L> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("SGBT")
            .field("n_steps", &self.steps.len())
            .field("samples_seen", &self.samples_seen)
            .field("base_prediction", &self.base_prediction)
            .field("base_initialized", &self.base_initialized)
            .finish()
    }
}
156
impl SGBT<SquaredLoss> {
    /// Creates an ensemble with the default squared-error loss.
    pub fn new(config: SGBTConfig) -> Self {
        Self::with_loss(config, SquaredLoss)
    }
}
170
171impl<L: Loss> SGBT<L> {
    /// Creates an ensemble with a custom loss.
    ///
    /// Builds one `BoostingStep` per configured step, each with its own seed
    /// (`config.seed ^ step_index`) and a fresh drift detector.
    pub fn with_loss(config: SGBTConfig, loss: L) -> Self {
        // Convert a half-life (in samples) into a per-sample decay factor:
        // alpha = exp(-ln(2) / half_life).
        let leaf_decay_alpha = config
            .leaf_half_life
            .map(|hl| crate::math::exp(-crate::math::ln(2.0) / hl as f64));

        // Shared tree configuration, cloned (with a distinct seed) per step.
        let tree_config = TreeConfig::new()
            .max_depth(config.max_depth)
            .n_bins(config.n_bins)
            .lambda(config.lambda)
            .gamma(config.gamma)
            .grace_period(config.grace_period)
            .delta(config.delta)
            .feature_subsample_rate(config.feature_subsample_rate)
            .leaf_decay_alpha_opt(leaf_decay_alpha)
            .split_reeval_interval_opt(config.split_reeval_interval)
            .feature_types_opt(config.feature_types.clone())
            .gradient_clip_sigma_opt(config.gradient_clip_sigma)
            .monotone_constraints_opt(config.monotone_constraints.clone())
            .max_leaf_output_opt(config.max_leaf_output)
            .adaptive_leaf_bound_opt(config.adaptive_leaf_bound)
            .min_hessian_sum_opt(config.min_hessian_sum)
            .leaf_model_type(config.leaf_model_type.clone());

        let max_tree_samples = config.max_tree_samples;

        // With a shadow warm-up configured, steps are created in graduated
        // mode; otherwise plain max-samples mode.
        let shadow_warmup = config.shadow_warmup.unwrap_or(0);
        let steps: Vec<BoostingStep> = (0..config.n_steps)
            .map(|i| {
                let mut tc = tree_config.clone();
                tc.seed = config.seed ^ (i as u64);
                let detector = config.drift_detector.create();
                if shadow_warmup > 0 {
                    BoostingStep::new_with_graduated(tc, detector, max_tree_samples, shadow_warmup)
                } else {
                    BoostingStep::new_with_max_samples(tc, detector, max_tree_samples)
                }
            })
            .collect();

        let seed = config.seed;
        let initial_target_count = config.initial_target_count;
        let n = config.n_steps;
        // Pruning bookkeeping is allocated only when quality pruning is enabled.
        let has_pruning = config.quality_prune_alpha.is_some();
        Self {
            config,
            steps,
            loss,
            base_prediction: 0.0,
            base_initialized: false,
            initial_targets: Vec::new(),
            initial_target_count,
            samples_seen: 0,
            rng_state: seed,
            contribution_ewma: if has_pruning {
                vec![0.0; n]
            } else {
                Vec::new()
            },
            low_contrib_count: if has_pruning { vec![0; n] } else { Vec::new() },
            rolling_mean_error: 0.0,
            auto_bandwidths: Vec::new(),
            last_replacement_sum: 0,
        }
    }
252
    /// Trains the ensemble on a single observation.
    ///
    /// The first `initial_target_count` targets are buffered to estimate the
    /// base prediction via `Loss::initial_prediction`; until then the base is
    /// 0.0. Each boosting step then fits the gradient of the prediction
    /// accumulated so far, and the running prediction advances by
    /// `learning_rate * step_pred`.
    pub fn train_one(&mut self, sample: &impl Observation) {
        self.samples_seen += 1;
        let target = sample.target();
        let features = sample.features();

        // Warm-up phase: buffer targets until the base prediction can be set.
        // Note that the boosting steps below still train during warm-up,
        // against a base prediction of 0.0.
        if !self.base_initialized {
            self.initial_targets.push(target);
            if self.initial_targets.len() >= self.initial_target_count {
                self.base_prediction = self.loss.initial_prediction(&self.initial_targets);
                self.base_initialized = true;
                self.initial_targets.clear();
                self.initial_targets.shrink_to_fit(); // buffer no longer needed
            }
        }

        let mut current_pred = self.base_prediction;

        let prune_alpha = self.config.quality_prune_alpha;
        let prune_threshold = self.config.quality_prune_threshold;
        let prune_patience = self.config.quality_prune_patience;

        // Optional error-based sample weighting: up-weight samples whose
        // absolute error exceeds the rolling mean error (capped at 10x).
        let error_weight = if let Some(ew_alpha) = self.config.error_weight_alpha {
            let abs_error = crate::math::abs(target - current_pred);
            if self.rolling_mean_error > 1e-15 {
                let w = (1.0 + abs_error / (self.rolling_mean_error + 1e-15)).min(10.0);
                self.rolling_mean_error =
                    ew_alpha * abs_error + (1.0 - ew_alpha) * self.rolling_mean_error;
                w
            } else {
                // First sample (or degenerate state): seed the rolling error.
                self.rolling_mean_error = abs_error.max(1e-15);
                1.0
            }
        } else {
            1.0
        };

        for s in 0..self.steps.len() {
            let gradient = self.loss.gradient(target, current_pred) * error_weight;
            let hessian = self.loss.hessian(target, current_pred) * error_weight;
            // The variant decides how many times this sample is replayed
            // (e.g. for Poisson-style bagging), driven by the shared RNG.
            let train_count = self
                .config
                .variant
                .train_count(hessian, &mut self.rng_state);

            let step_pred =
                self.steps[s].train_and_predict(features, gradient, hessian, train_count);

            current_pred += self.config.learning_rate * step_pred;

            // Quality pruning: reset a step whose EWMA contribution stays
            // below the threshold for `prune_patience` consecutive samples.
            if let Some(alpha) = prune_alpha {
                let contribution = crate::math::abs(self.config.learning_rate * step_pred);
                self.contribution_ewma[s] =
                    alpha * contribution + (1.0 - alpha) * self.contribution_ewma[s];

                if self.contribution_ewma[s] < prune_threshold {
                    self.low_contrib_count[s] += 1;
                    if self.low_contrib_count[s] >= prune_patience {
                        self.steps[s].reset();
                        self.contribution_ewma[s] = 0.0;
                        self.low_contrib_count[s] = 0;
                    }
                } else {
                    self.low_contrib_count[s] = 0;
                }
            }
        }

        // Keep smoothing bandwidths in sync with any tree replacements.
        self.refresh_bandwidths();
    }
333
334 pub fn train_batch<O: Observation>(&mut self, samples: &[O]) {
336 for sample in samples {
337 self.train_one(sample);
338 }
339 }
340
341 pub fn train_batch_with_callback<O: Observation, F: FnMut(usize)>(
362 &mut self,
363 samples: &[O],
364 interval: usize,
365 mut callback: F,
366 ) {
367 let interval = interval.max(1); for (i, sample) in samples.iter().enumerate() {
369 self.train_one(sample);
370 if (i + 1) % interval == 0 {
371 callback(i + 1);
372 }
373 }
374 let total = samples.len();
376 if total % interval != 0 {
377 callback(total);
378 }
379 }
380
    /// Trains on a uniform random subset of at most `max_samples` samples,
    /// chosen by reservoir sampling (Algorithm R) with the model's xorshift64
    /// RNG; selected samples are replayed in their original stream order.
    pub fn train_batch_subsampled<O: Observation>(&mut self, samples: &[O], max_samples: usize) {
        if max_samples >= samples.len() {
            self.train_batch(samples);
            return;
        }

        // Reservoir starts as the first `max_samples` indices.
        let mut reservoir: Vec<usize> = (0..max_samples).collect();
        let mut rng = self.rng_state;

        for i in max_samples..samples.len() {
            // xorshift64 step.
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            let j = (rng % (i as u64 + 1)) as usize;
            if j < max_samples {
                reservoir[j] = i;
            }
        }

        // Persist the RNG so repeated calls draw different subsets.
        self.rng_state = rng;

        // Preserve the temporal order of the selected samples.
        reservoir.sort_unstable();

        for &idx in &reservoir {
            self.train_one(&samples[idx]);
        }
    }
424
    /// Combination of `train_batch_subsampled` and
    /// `train_batch_with_callback`: reservoir-samples at most `max_samples`
    /// indices, trains on them in stream order, and reports progress every
    /// `interval` trained samples (plus a final partial report).
    pub fn train_batch_subsampled_with_callback<O: Observation, F: FnMut(usize)>(
        &mut self,
        samples: &[O],
        max_samples: usize,
        interval: usize,
        mut callback: F,
    ) {
        if max_samples >= samples.len() {
            self.train_batch_with_callback(samples, interval, callback);
            return;
        }

        // Reservoir sampling (Algorithm R), identical to `train_batch_subsampled`.
        let mut reservoir: Vec<usize> = (0..max_samples).collect();
        let mut rng = self.rng_state;

        for i in max_samples..samples.len() {
            // xorshift64 step.
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            let j = (rng % (i as u64 + 1)) as usize;
            if j < max_samples {
                reservoir[j] = i;
            }
        }

        self.rng_state = rng;
        reservoir.sort_unstable();

        // Guard against a zero interval (modulo by zero below).
        let interval = interval.max(1);
        for (i, &idx) in reservoir.iter().enumerate() {
            self.train_one(&samples[idx]);
            if (i + 1) % interval == 0 {
                callback(i + 1);
            }
        }
        let total = reservoir.len();
        if total % interval != 0 {
            callback(total);
        }
    }
471
472 pub fn predict(&self, features: &[f64]) -> f64 {
478 let mut pred = self.base_prediction;
479 if self.auto_bandwidths.is_empty() {
480 for step in &self.steps {
482 pred += self.config.learning_rate * step.predict(features);
483 }
484 } else {
485 for step in &self.steps {
486 pred += self.config.learning_rate
487 * step.predict_smooth_auto(features, &self.auto_bandwidths);
488 }
489 }
490 pred
491 }
492
493 pub fn predict_smooth(&self, features: &[f64], bandwidth: f64) -> f64 {
498 let mut pred = self.base_prediction;
499 for step in &self.steps {
500 pred += self.config.learning_rate * step.predict_smooth(features, bandwidth);
501 }
502 pred
503 }
504
    /// Current per-feature smoothing bandwidths (empty until first computed).
    pub fn auto_bandwidths(&self) -> &[f64] {
        &self.auto_bandwidths
    }
513
514 pub fn predict_interpolated(&self, features: &[f64]) -> f64 {
519 let mut pred = self.base_prediction;
520 for step in &self.steps {
521 pred += self.config.learning_rate * step.predict_interpolated(features);
522 }
523 pred
524 }
525
526 pub fn predict_sibling_interpolated(&self, features: &[f64]) -> f64 {
534 let mut pred = self.base_prediction;
535 for step in &self.steps {
536 pred += self.config.learning_rate
537 * step.predict_sibling_interpolated(features, &self.auto_bandwidths);
538 }
539 pred
540 }
541
542 pub fn predict_graduated(&self, features: &[f64]) -> f64 {
548 let mut pred = self.base_prediction;
549 for step in &self.steps {
550 pred += self.config.learning_rate * step.predict_graduated(features);
551 }
552 pred
553 }
554
555 pub fn predict_graduated_sibling_interpolated(&self, features: &[f64]) -> f64 {
561 let mut pred = self.base_prediction;
562 for step in &self.steps {
563 pred += self.config.learning_rate
564 * step.predict_graduated_sibling_interpolated(features, &self.auto_bandwidths);
565 }
566 pred
567 }
568
    /// Predicts and applies the loss's output transform
    /// (e.g. the sigmoid for logistic loss).
    pub fn predict_transformed(&self, features: &[f64]) -> f64 {
        self.loss.predict_transform(self.predict(features))
    }
573
    /// Alias for `predict_transformed`, named for classification use.
    pub fn predict_proba(&self, features: &[f64]) -> f64 {
        self.predict_transformed(features)
    }
578
    /// Predicts and returns `(prediction, confidence)`.
    ///
    /// Per-step leaf variance is scaled by `learning_rate^2` and summed;
    /// confidence is `1 / sqrt(total_variance)`, or 0.0 when the accumulated
    /// variance is zero or non-finite.
    pub fn predict_with_confidence(&self, features: &[f64]) -> (f64, f64) {
        let mut pred = self.base_prediction;
        let mut total_variance = 0.0;
        let lr2 = self.config.learning_rate * self.config.learning_rate;

        for step in &self.steps {
            let (value, variance) = step.predict_with_variance(features);
            pred += self.config.learning_rate * value;
            total_variance += lr2 * variance;
        }

        let confidence = if total_variance > 0.0 && total_variance.is_finite() {
            1.0 / crate::math::sqrt(total_variance)
        } else {
            0.0
        };

        (pred, confidence)
    }
612
613 pub fn predict_batch(&self, feature_matrix: &[Vec<f64>]) -> Vec<f64> {
615 feature_matrix.iter().map(|f| self.predict(f)).collect()
616 }
617
    /// Number of boosting steps in the ensemble.
    pub fn n_steps(&self) -> usize {
        self.steps.len()
    }
622
623 pub fn n_trees(&self) -> usize {
625 self.steps.len() + self.steps.iter().filter(|s| s.has_alternate()).count()
626 }
627
628 pub fn total_leaves(&self) -> usize {
630 self.steps.iter().map(|s| s.n_leaves()).sum()
631 }
632
    /// Total observations fed to `train_one` so far.
    pub fn n_samples_seen(&self) -> u64 {
        self.samples_seen
    }
637
    /// Constant offset added to every prediction (0.0 until initialized).
    pub fn base_prediction(&self) -> f64 {
        self.base_prediction
    }
642
    /// Whether the base prediction has been estimated from warm-up targets.
    pub fn is_initialized(&self) -> bool {
        self.base_initialized
    }
647
    /// Read-only access to the ensemble configuration.
    pub fn config(&self) -> &SGBTConfig {
        &self.config
    }
652
    /// Overrides the learning rate used by subsequent training and prediction
    /// (e.g. from an external learning-rate schedule).
    #[inline]
    pub fn set_learning_rate(&mut self, lr: f64) {
        self.config.learning_rate = lr;
    }
665
    /// Read-only access to the boosting steps.
    pub fn steps(&self) -> &[BoostingStep] {
        &self.steps
    }
672
    /// Read-only access to the loss function.
    pub fn loss(&self) -> &L {
        &self.loss
    }
677
678 pub fn feature_importances(&self) -> Vec<f64> {
683 let mut totals: Vec<f64> = Vec::new();
685 for step in &self.steps {
686 let gains = step.slot().split_gains();
687 if totals.is_empty() && !gains.is_empty() {
688 totals.resize(gains.len(), 0.0);
689 }
690 for (i, &g) in gains.iter().enumerate() {
691 if i < totals.len() {
692 totals[i] += g;
693 }
694 }
695 }
696
697 let sum: f64 = totals.iter().sum();
699 if sum > 0.0 {
700 totals.iter_mut().for_each(|v| *v /= sum);
701 }
702 totals
703 }
704
    /// Configured feature names, if any.
    pub fn feature_names(&self) -> Option<&[String]> {
        self.config.feature_names.as_deref()
    }
709
    /// Feature importances paired with their configured names, sorted in
    /// descending order of importance.
    ///
    /// Returns `None` when no feature names are configured. If there are more
    /// names than importance entries, the extra names are paired with 0.0
    /// (via the `chain(repeat(&0.0))` padding).
    pub fn named_feature_importances(&self) -> Option<Vec<(String, f64)>> {
        let names = self.config.feature_names.as_ref()?;
        let importances = self.feature_importances();
        let mut pairs: Vec<(String, f64)> = names
            .iter()
            .zip(importances.iter().chain(core::iter::repeat(&0.0)))
            .map(|(n, &v)| (n.clone(), v))
            .collect();
        // Stable sort; NaNs compare equal and keep their relative order.
        pairs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(core::cmp::Ordering::Equal));
        Some(pairs)
    }
725
    #[cfg(feature = "std")]
    /// Trains on a sample given as a name -> value map; values are placed in
    /// the configured feature-name order, with missing names defaulting to 0.0.
    ///
    /// # Panics
    /// Panics when `feature_names` is not configured.
    pub fn train_one_named(
        &mut self,
        features: &std::collections::HashMap<alloc::string::String, f64>,
        target: f64,
    ) {
        let names = self
            .config
            .feature_names
            .as_ref()
            .expect("train_one_named requires feature_names to be configured");
        let vec: Vec<f64> = names
            .iter()
            .map(|name| features.get(name).copied().unwrap_or(0.0))
            .collect();
        self.train_one(&(&vec[..], target));
    }
751
    #[cfg(feature = "std")]
    /// Predicts from a name -> value map; values are placed in the configured
    /// feature-name order, with missing names defaulting to 0.0.
    ///
    /// # Panics
    /// Panics when `feature_names` is not configured.
    pub fn predict_named(
        &self,
        features: &std::collections::HashMap<alloc::string::String, f64>,
    ) -> f64 {
        let names = self
            .config
            .feature_names
            .as_ref()
            .expect("predict_named requires feature_names to be configured");
        let vec: Vec<f64> = names
            .iter()
            .map(|name| features.get(name).copied().unwrap_or(0.0))
            .collect();
        self.predict(&vec)
    }
775
    /// Recomputes the auto bandwidths, but only when a tree replacement has
    /// happened since the last refresh (detected via the total replacement
    /// count) or when no bandwidths have been computed yet.
    fn refresh_bandwidths(&mut self) {
        let current_sum: u64 = self.steps.iter().map(|s| s.slot().replacements()).sum();
        if current_sum != self.last_replacement_sum || self.auto_bandwidths.is_empty() {
            self.auto_bandwidths = self.compute_auto_bandwidths();
            self.last_replacement_sum = current_sum;
        }
    }
788
789 fn compute_auto_bandwidths(&self) -> Vec<f64> {
799 const K: f64 = 2.0;
800
801 let n_features = self
803 .steps
804 .iter()
805 .filter_map(|s| s.slot().active_tree().n_features())
806 .max()
807 .unwrap_or(0);
808
809 if n_features == 0 {
810 return Vec::new();
811 }
812
813 let mut all_thresholds: Vec<Vec<f64>> = vec![Vec::new(); n_features];
815
816 for step in &self.steps {
817 let tree_thresholds = step
818 .slot()
819 .active_tree()
820 .collect_split_thresholds_per_feature();
821 for (i, ts) in tree_thresholds.into_iter().enumerate() {
822 if i < n_features {
823 all_thresholds[i].extend(ts);
824 }
825 }
826 }
827
828 let n_bins = self.config.n_bins as f64;
829
830 all_thresholds
832 .iter()
833 .map(|ts| {
834 if ts.is_empty() {
835 return f64::INFINITY; }
837
838 let mut sorted = ts.clone();
839 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
840 sorted.dedup_by(|a, b| crate::math::abs(*a - *b) < 1e-15);
841
842 if sorted.len() < 2 {
843 return f64::INFINITY; }
845
846 let mut gaps: Vec<f64> = sorted.windows(2).map(|w| w[1] - w[0]).collect();
848
849 if sorted.len() < 3 {
850 let range = sorted.last().unwrap() - sorted.first().unwrap();
852 if range < 1e-15 {
853 return f64::INFINITY;
854 }
855 return (range / n_bins) * K;
856 }
857
858 gaps.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
860 let median_gap = if gaps.len() % 2 == 0 {
861 (gaps[gaps.len() / 2 - 1] + gaps[gaps.len() / 2]) / 2.0
862 } else {
863 gaps[gaps.len() / 2]
864 };
865
866 if median_gap < 1e-15 {
867 f64::INFINITY
868 } else {
869 median_gap * K
870 }
871 })
872 .collect()
873 }
874
875 pub fn reset(&mut self) {
877 for step in &mut self.steps {
878 step.reset();
879 }
880 self.base_prediction = 0.0;
881 self.base_initialized = false;
882 self.initial_targets.clear();
883 self.samples_seen = 0;
884 self.rng_state = self.config.seed;
885 self.auto_bandwidths.clear();
886 self.last_replacement_sum = 0;
887 }
888
    #[cfg(feature = "_serde_support")]
    /// Serializes the model into a `ModelState`, auto-detecting the loss type.
    ///
    /// # Errors
    /// Returns a `Serialization` error when `Loss::loss_type()` returns
    /// `None`; custom losses must go through `to_model_state_with`.
    pub fn to_model_state(&self) -> crate::error::Result<crate::serde_support::ModelState> {
        let loss_type = self.loss.loss_type().ok_or_else(|| {
            crate::error::IrithyllError::Serialization(
                "cannot auto-detect loss type for serialization: \
                 implement Loss::loss_type() or use to_model_state_with()"
                    .into(),
            )
        })?;
        Ok(self.to_model_state_with(loss_type))
    }
910
    #[cfg(feature = "_serde_support")]
    /// Serializes the model into a `ModelState` with an explicit loss type.
    ///
    /// Each step contributes a snapshot of its active tree, optional
    /// alternate tree, and drift-detector state(s). The auto bandwidths and
    /// replacement counter are not stored; they are recomputed lazily after
    /// deserialization.
    pub fn to_model_state_with(
        &self,
        loss_type: crate::loss::LossType,
    ) -> crate::serde_support::ModelState {
        use crate::serde_support::{ModelState, StepSnapshot};

        let steps = self
            .steps
            .iter()
            .map(|step| {
                let slot = step.slot();
                let tree_snap = snapshot_tree(slot.active_tree());
                let alt_snap = slot.alternate_tree().map(snapshot_tree);
                let drift_state = slot.detector().serialize_state();
                let alt_drift_state = slot.alt_detector().and_then(|d| d.serialize_state());
                StepSnapshot {
                    tree: tree_snap,
                    alternate_tree: alt_snap,
                    drift_state,
                    alt_drift_state,
                }
            })
            .collect();

        ModelState {
            config: self.config.clone(),
            loss_type,
            base_prediction: self.base_prediction,
            base_initialized: self.base_initialized,
            initial_targets: self.initial_targets.clone(),
            initial_target_count: self.initial_target_count,
            samples_seen: self.samples_seen,
            rng_state: self.rng_state,
            steps,
            rolling_mean_error: self.rolling_mean_error,
            contribution_ewma: self.contribution_ewma.clone(),
            low_contrib_count: self.low_contrib_count.clone(),
        }
    }
954}
955
#[cfg(feature = "_serde_support")]
impl SGBT<Box<dyn Loss>> {
    /// Rebuilds a (type-erased) model from a serialized `ModelState`.
    ///
    /// Mirrors `with_loss`: the per-step `TreeConfig` is reconstructed from
    /// the stored config (including per-step seeds), trees and drift-detector
    /// states are restored from their snapshots, and the pruning vectors are
    /// re-allocated when the snapshot predates them. Auto bandwidths are left
    /// empty and recomputed on the next training sample.
    pub fn from_model_state(state: crate::serde_support::ModelState) -> Self {
        use crate::ensemble::replacement::TreeSlot;

        let loss = state.loss_type.into_loss();

        // Same half-life -> decay conversion as `with_loss`.
        let leaf_decay_alpha = state
            .config
            .leaf_half_life
            .map(|hl| crate::math::exp(-crate::math::ln(2.0) / hl as f64));
        let max_tree_samples = state.config.max_tree_samples;

        let steps: Vec<BoostingStep> = state
            .steps
            .iter()
            .enumerate()
            .map(|(i, step_snap)| {
                // Keep this builder chain in sync with `with_loss`; the
                // leaf-bounding options (max_leaf_output / adaptive_leaf_bound
                // / min_hessian_sum) were previously dropped on reload.
                let tree_config = TreeConfig::new()
                    .max_depth(state.config.max_depth)
                    .n_bins(state.config.n_bins)
                    .lambda(state.config.lambda)
                    .gamma(state.config.gamma)
                    .grace_period(state.config.grace_period)
                    .delta(state.config.delta)
                    .feature_subsample_rate(state.config.feature_subsample_rate)
                    .leaf_decay_alpha_opt(leaf_decay_alpha)
                    .split_reeval_interval_opt(state.config.split_reeval_interval)
                    .feature_types_opt(state.config.feature_types.clone())
                    .gradient_clip_sigma_opt(state.config.gradient_clip_sigma)
                    .monotone_constraints_opt(state.config.monotone_constraints.clone())
                    .max_leaf_output_opt(state.config.max_leaf_output)
                    .adaptive_leaf_bound_opt(state.config.adaptive_leaf_bound)
                    .min_hessian_sum_opt(state.config.min_hessian_sum)
                    .leaf_model_type(state.config.leaf_model_type.clone())
                    .seed(state.config.seed ^ (i as u64));

                let active = rebuild_tree(&step_snap.tree, tree_config.clone());
                let alternate = step_snap
                    .alternate_tree
                    .as_ref()
                    .map(|snap| rebuild_tree(snap, tree_config.clone()));

                let mut detector = state.config.drift_detector.create();
                if let Some(ref ds) = step_snap.drift_state {
                    detector.restore_state(ds);
                }
                let mut slot = TreeSlot::from_trees(
                    active,
                    alternate,
                    tree_config,
                    detector,
                    max_tree_samples,
                );
                if let Some(ref ads) = step_snap.alt_drift_state {
                    if let Some(alt_det) = slot.alt_detector_mut() {
                        alt_det.restore_state(ads);
                    }
                }
                BoostingStep::from_slot(slot)
            })
            .collect();

        let n = steps.len();
        let has_pruning = state.config.quality_prune_alpha.is_some();

        // Older snapshots may lack pruning vectors; re-allocate when needed.
        let contribution_ewma = if !state.contribution_ewma.is_empty() {
            state.contribution_ewma
        } else if has_pruning {
            vec![0.0; n]
        } else {
            Vec::new()
        };
        let low_contrib_count = if !state.low_contrib_count.is_empty() {
            state.low_contrib_count
        } else if has_pruning {
            vec![0; n]
        } else {
            Vec::new()
        };

        Self {
            config: state.config,
            steps,
            loss,
            base_prediction: state.base_prediction,
            base_initialized: state.base_initialized,
            initial_targets: state.initial_targets,
            initial_target_count: state.initial_target_count,
            samples_seen: state.samples_seen,
            rng_state: state.rng_state,
            contribution_ewma,
            low_contrib_count,
            rolling_mean_error: state.rolling_mean_error,
            auto_bandwidths: Vec::new(),
            last_replacement_sum: 0,
        }
    }
}
1066
#[cfg(feature = "_serde_support")]
/// Copies a tree's arena (plus its feature count, sample counter, and RNG
/// state) into a serializable snapshot. `NodeId`s are stored as their raw
/// inner indices.
pub(crate) fn snapshot_tree(
    tree: &crate::tree::hoeffding::HoeffdingTree,
) -> crate::serde_support::TreeSnapshot {
    use crate::serde_support::TreeSnapshot;
    use crate::tree::StreamingTree;
    let arena = tree.arena();
    TreeSnapshot {
        feature_idx: arena.feature_idx.clone(),
        threshold: arena.threshold.clone(),
        left: arena.left.iter().map(|id| id.0).collect(),
        right: arena.right.iter().map(|id| id.0).collect(),
        leaf_value: arena.leaf_value.clone(),
        is_leaf: arena.is_leaf.clone(),
        depth: arena.depth.clone(),
        sample_count: arena.sample_count.clone(),
        n_features: tree.n_features(),
        samples_seen: tree.n_samples_seen(),
        rng_state: tree.rng_state(),
        categorical_mask: arena.categorical_mask.clone(),
    }
}
1094
#[cfg(feature = "_serde_support")]
/// Reconstructs a `HoeffdingTree` from a serialized `TreeSnapshot`,
/// the inverse of `snapshot_tree`.
pub(crate) fn rebuild_tree(
    snapshot: &crate::serde_support::TreeSnapshot,
    tree_config: TreeConfig,
) -> crate::tree::hoeffding::HoeffdingTree {
    use crate::tree::hoeffding::HoeffdingTree;
    use crate::tree::node::{NodeId, TreeArena};

    let node_count = snapshot.feature_idx.len();
    let mut arena = TreeArena::new();

    // Refill the flat arena column-by-column rather than node-by-node.
    arena.feature_idx.extend_from_slice(&snapshot.feature_idx);
    arena.threshold.extend_from_slice(&snapshot.threshold);
    arena.left.extend(snapshot.left.iter().map(|&raw| NodeId(raw)));
    arena.right.extend(snapshot.right.iter().map(|&raw| NodeId(raw)));
    arena.leaf_value.extend_from_slice(&snapshot.leaf_value);
    arena.is_leaf.extend_from_slice(&snapshot.is_leaf);
    arena.depth.extend_from_slice(&snapshot.depth);
    arena.sample_count.extend_from_slice(&snapshot.sample_count);
    // Older snapshots may carry a shorter mask vector; pad with None.
    arena
        .categorical_mask
        .extend((0..node_count).map(|i| snapshot.categorical_mask.get(i).copied().flatten()));

    HoeffdingTree::from_arena(
        tree_config,
        arena,
        snapshot.n_features,
        snapshot.samples_seen,
        snapshot.rng_state,
    )
}
1128
1129#[cfg(test)]
1130mod tests {
1131 use super::*;
1132 use alloc::boxed::Box;
1133 use alloc::vec;
1134 use alloc::vec::Vec;
1135
    /// Small shared configuration used by most tests in this module.
    fn default_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .build()
            .unwrap()
    }
1146
    /// A freshly constructed model predicts (approximately) zero.
    #[test]
    fn new_model_predicts_zero() {
        let model = SGBT::new(default_config());
        let pred = model.predict(&[1.0, 2.0, 3.0]);
        assert!(pred.abs() < 1e-12);
    }
1153
    /// A single training sample is accepted and counted.
    #[test]
    fn train_one_does_not_panic() {
        let mut model = SGBT::new(default_config());
        model.train_one(&Sample::new(vec![1.0, 2.0, 3.0], 5.0));
        assert_eq!(model.n_samples_seen(), 1);
    }
1160
    /// Predictions stay finite after a burst of training samples.
    #[test]
    fn prediction_changes_after_training() {
        let mut model = SGBT::new(default_config());
        let features = vec![1.0, 2.0, 3.0];
        for i in 0..100 {
            model.train_one(&Sample::new(features.clone(), (i as f64) * 0.1));
        }
        let pred = model.predict(&features);
        assert!(pred.is_finite());
    }
1171
    /// Prequential (predict-then-train) RMSE on a noiseless linear signal
    /// should drop between an early window and a late window.
    #[test]
    fn linear_signal_rmse_improves() {
        let config = SGBTConfig::builder()
            .n_steps(20)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(3)
            .n_bins(16)
            .build()
            .unwrap();
        let mut model = SGBT::new(config);

        // Deterministic xorshift64 stream for reproducible inputs.
        let mut rng: u64 = 12345;
        let mut early_errors = Vec::new();
        let mut late_errors = Vec::new();

        for i in 0..500 {
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            let x1 = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            let x2 = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
            let target = 2.0 * x1 + 3.0 * x2;

            // Evaluate before training on the sample (prequential protocol).
            let pred = model.predict(&[x1, x2]);
            let error = (pred - target).powi(2);

            if (50..150).contains(&i) {
                early_errors.push(error);
            }
            if i >= 400 {
                late_errors.push(error);
            }

            model.train_one(&Sample::new(vec![x1, x2], target));
        }

        let early_rmse = (early_errors.iter().sum::<f64>() / early_errors.len() as f64).sqrt();
        let late_rmse = (late_errors.iter().sum::<f64>() / late_errors.len() as f64).sqrt();

        assert!(
            late_rmse < early_rmse,
            "RMSE should decrease: early={:.4}, late={:.4}",
            early_rmse,
            late_rmse
        );
    }
1222
    /// `train_batch` must produce the same model as calling `train_one`
    /// sample-by-sample.
    #[test]
    fn train_batch_equivalent_to_sequential() {
        let config = default_config();
        let mut model_seq = SGBT::new(config.clone());
        let mut model_batch = SGBT::new(config);

        let samples: Vec<Sample> = (0..20)
            .map(|i| {
                let x = i as f64 * 0.5;
                Sample::new(vec![x, x * 2.0], x * 3.0)
            })
            .collect();

        for s in &samples {
            model_seq.train_one(s);
        }
        model_batch.train_batch(&samples);

        let pred_seq = model_seq.predict(&[1.0, 2.0]);
        let pred_batch = model_batch.predict(&[1.0, 2.0]);

        assert!(
            (pred_seq - pred_batch).abs() < 1e-10,
            "seq={}, batch={}",
            pred_seq,
            pred_batch
        );
    }
1251
    /// After `reset` the model behaves like a freshly constructed one.
    #[test]
    fn reset_returns_to_initial() {
        let mut model = SGBT::new(default_config());
        for i in 0..100 {
            model.train_one(&Sample::new(vec![1.0, 2.0], i as f64));
        }
        model.reset();
        assert_eq!(model.n_samples_seen(), 0);
        assert!(!model.is_initialized());
        assert!(model.predict(&[1.0, 2.0]).abs() < 1e-12);
    }
1263
    /// The base prediction settles near the mean of the warm-up targets
    /// (squared loss uses the target mean as its initial prediction).
    #[test]
    fn base_prediction_initializes() {
        let mut model = SGBT::new(default_config());
        for i in 0..50 {
            model.train_one(&Sample::new(vec![1.0], i as f64 + 100.0));
        }
        assert!(model.is_initialized());
        let expected = (100.0 + 149.0) / 2.0;
        assert!((model.base_prediction() - expected).abs() < 1.0);
    }
1274
    /// A logistic-loss model transforms the raw score through a sigmoid, so
    /// an untrained model (score 0) yields probability 0.5.
    #[test]
    fn with_loss_uses_custom_loss() {
        use crate::loss::logistic::LogisticLoss;
        let model = SGBT::with_loss(default_config(), LogisticLoss);
        let pred = model.predict_transformed(&[1.0, 2.0]);
        assert!(
            (pred - 0.5).abs() < 1e-6,
            "sigmoid(0) should be 0.5, got {}",
            pred
        );
    }
1286
    /// A model configured with a leaf half-life (EWMA leaf decay) still
    /// trains and predicts finite values.
    #[test]
    fn ewma_config_propagates_and_trains() {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(3)
            .n_bins(16)
            .leaf_half_life(50)
            .build()
            .unwrap();
        let mut model = SGBT::new(config);

        for i in 0..200 {
            let x = (i as f64) * 0.1;
            model.train_one(&Sample::new(vec![x, x * 2.0], x * 3.0));
        }

        let pred = model.predict(&[1.0, 2.0]);
        assert!(
            pred.is_finite(),
            "EWMA-enabled model should produce finite predictions, got {}",
            pred
        );
    }
1312
    /// A model with a per-tree sample cap still trains and predicts finite
    /// values after seeing more samples than the cap.
    #[test]
    fn max_tree_samples_config_propagates() {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(3)
            .n_bins(16)
            .max_tree_samples(200)
            .build()
            .unwrap();
        let mut model = SGBT::new(config);

        for i in 0..500 {
            let x = (i as f64) * 0.1;
            model.train_one(&Sample::new(vec![x, x * 2.0], x * 3.0));
        }

        let pred = model.predict(&[1.0, 2.0]);
        assert!(
            pred.is_finite(),
            "max_tree_samples model should produce finite predictions, got {}",
            pred
        );
    }
1338
    /// A model with periodic split re-evaluation still trains and predicts
    /// finite values on a random linear signal.
    #[test]
    fn split_reeval_config_propagates() {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(2)
            .n_bins(16)
            .split_reeval_interval(50)
            .build()
            .unwrap();
        let mut model = SGBT::new(config);

        // Deterministic xorshift64 input stream.
        let mut rng: u64 = 12345;
        for _ in 0..1000 {
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            let x1 = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            let x2 = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
            let target = 2.0 * x1 + 3.0 * x2;
            model.train_one(&Sample::new(vec![x1, x2], target));
        }

        let pred = model.predict(&[1.0, 2.0]);
        assert!(
            pred.is_finite(),
            "split re-eval model should produce finite predictions, got {}",
            pred
        );
    }
1373
    /// `loss()` returns the concrete loss the model was built with.
    #[test]
    fn loss_accessor_works() {
        use crate::loss::logistic::LogisticLoss;
        let model = SGBT::with_loss(default_config(), LogisticLoss);
        let _loss: &LogisticLoss = model.loss();
        assert_eq!(_loss.n_outputs(), 1);
    }
1382
    /// A clone predicts identically at first, then diverges under independent
    /// training while the original stays unchanged.
    #[test]
    fn clone_produces_independent_copy() {
        let config = default_config();
        let mut model = SGBT::new(config);

        // Deterministic xorshift64 input stream: target = 2x + 1.
        let mut rng: u64 = 99999;
        for _ in 0..200 {
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            let x = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
            let target = 2.0 * x + 1.0;
            model.train_one(&Sample::new(vec![x], target));
        }

        let mut cloned = model.clone();

        let test_features = [3.0];
        let pred_original = model.predict(&test_features);
        let pred_cloned = cloned.predict(&test_features);
        assert!(
            (pred_original - pred_cloned).abs() < 1e-12,
            "clone should predict identically: original={pred_original}, cloned={pred_cloned}"
        );

        // Train only the clone on a conflicting signal: target = -3x + 5.
        for _ in 0..200 {
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            let x = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
            let target = -3.0 * x + 5.0;
            cloned.train_one(&Sample::new(vec![x], target));
        }

        let pred_original_after = model.predict(&test_features);
        let pred_cloned_after = cloned.predict(&test_features);

        assert!(
            (pred_original - pred_original_after).abs() < 1e-12,
            "original should be unchanged after training clone"
        );

        assert!(
            (pred_original_after - pred_cloned_after).abs() > 1e-6,
            "clone should diverge after independent training"
        );
    }
1436
    /// Confidence output is finite and positive after training.
    /// Note: `(&[f64], f64)` tuples implement `Observation` directly.
    #[test]
    fn predict_with_confidence_finite() {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .grace_period(10)
            .build()
            .unwrap();
        let mut model = SGBT::new(config);

        for i in 0..100 {
            let x = i as f64 * 0.1;
            model.train_one(&(&[x, x * 2.0][..], x + 1.0));
        }

        let (pred, confidence) = model.predict_with_confidence(&[1.0, 2.0]);
        assert!(pred.is_finite(), "prediction should be finite");
        assert!(confidence.is_finite(), "confidence should be finite");
        assert!(
            confidence > 0.0,
            "confidence should be positive after training"
        );
    }
1463
1464 #[test]
1468 fn predict_with_confidence_positive_after_training() {
1469 let config = SGBTConfig::builder()
1470 .n_steps(5)
1471 .grace_period(10)
1472 .build()
1473 .unwrap();
1474 let mut model = SGBT::new(config);
1475
1476 for i in 0..200 {
1478 let x = i as f64 * 0.05;
1479 model.train_one(&(&[x][..], x * 2.0));
1480 }
1481
1482 let (pred, confidence) = model.predict_with_confidence(&[1.0]);
1483
1484 assert!(pred.is_finite(), "prediction should be finite");
1485 assert!(
1486 confidence > 0.0 && confidence.is_finite(),
1487 "confidence should be finite and positive, got {}",
1488 confidence,
1489 );
1490
1491 let (pred2, conf2) = model.predict_with_confidence(&[1.0]);
1493 assert!(
1494 (pred - pred2).abs() < 1e-12,
1495 "same input should give same prediction"
1496 );
1497 assert!(
1498 (confidence - conf2).abs() < 1e-12,
1499 "same input should give same confidence"
1500 );
1501 }
1502
1503 #[test]
1507 fn predict_with_confidence_matches_predict() {
1508 let config = SGBTConfig::builder()
1509 .n_steps(10)
1510 .grace_period(10)
1511 .build()
1512 .unwrap();
1513 let mut model = SGBT::new(config);
1514
1515 for i in 0..200 {
1516 let x = (i as f64 - 100.0) * 0.01;
1517 model.train_one(&(&[x, x * x][..], x * 3.0 + 1.0));
1518 }
1519
1520 let pred = model.predict(&[0.5, 0.25]);
1521 let (conf_pred, _) = model.predict_with_confidence(&[0.5, 0.25]);
1522
1523 assert!(
1524 (pred - conf_pred).abs() < 1e-10,
1525 "prediction mismatch: predict()={} vs predict_with_confidence()={}",
1526 pred,
1527 conf_pred,
1528 );
1529 }
1530
1531 #[test]
1535 fn gradient_clip_config_builder() {
1536 let config = SGBTConfig::builder()
1537 .n_steps(10)
1538 .gradient_clip_sigma(3.0)
1539 .build()
1540 .unwrap();
1541
1542 assert_eq!(config.gradient_clip_sigma, Some(3.0));
1543 }
1544
1545 #[test]
1549 fn monotone_constraints_config_builder() {
1550 let config = SGBTConfig::builder()
1551 .n_steps(10)
1552 .monotone_constraints(vec![1, -1, 0])
1553 .build()
1554 .unwrap();
1555
1556 assert_eq!(config.monotone_constraints, Some(vec![1, -1, 0]));
1557 }
1558
1559 #[test]
1563 fn monotone_constraints_invalid_value_rejected() {
1564 let result = SGBTConfig::builder()
1565 .n_steps(10)
1566 .monotone_constraints(vec![1, 2, 0])
1567 .build();
1568
1569 assert!(result.is_err(), "constraint value 2 should be rejected");
1570 }
1571
1572 #[test]
1576 fn gradient_clip_sigma_negative_rejected() {
1577 let result = SGBTConfig::builder()
1578 .n_steps(10)
1579 .gradient_clip_sigma(-1.0)
1580 .build();
1581
1582 assert!(result.is_err(), "negative sigma should be rejected");
1583 }
1584
1585 #[test]
1589 fn gradient_clipping_reduces_outlier_impact() {
1590 let config_no_clip = SGBTConfig::builder()
1592 .n_steps(5)
1593 .grace_period(10)
1594 .build()
1595 .unwrap();
1596 let mut model_no_clip = SGBT::new(config_no_clip);
1597
1598 let config_clip = SGBTConfig::builder()
1600 .n_steps(5)
1601 .grace_period(10)
1602 .gradient_clip_sigma(3.0)
1603 .build()
1604 .unwrap();
1605 let mut model_clip = SGBT::new(config_clip);
1606
1607 for i in 0..100 {
1609 let x = (i as f64) * 0.01;
1610 let sample = (&[x][..], x * 2.0);
1611 model_no_clip.train_one(&sample);
1612 model_clip.train_one(&sample);
1613 }
1614
1615 let pred_no_clip_before = model_no_clip.predict(&[0.5]);
1616 let pred_clip_before = model_clip.predict(&[0.5]);
1617
1618 let outlier = (&[0.5_f64][..], 10000.0);
1620 model_no_clip.train_one(&outlier);
1621 model_clip.train_one(&outlier);
1622
1623 let pred_no_clip_after = model_no_clip.predict(&[0.5]);
1624 let pred_clip_after = model_clip.predict(&[0.5]);
1625
1626 let delta_no_clip = (pred_no_clip_after - pred_no_clip_before).abs();
1627 let delta_clip = (pred_clip_after - pred_clip_before).abs();
1628
1629 assert!(
1631 delta_clip <= delta_no_clip + 1e-10,
1632 "clipped model should be less affected: delta_clip={}, delta_no_clip={}",
1633 delta_clip,
1634 delta_no_clip,
1635 );
1636 }
1637
1638 #[test]
1642 fn train_batch_with_callback_fires() {
1643 let config = SGBTConfig::builder()
1644 .n_steps(3)
1645 .grace_period(5)
1646 .build()
1647 .unwrap();
1648 let mut model = SGBT::new(config);
1649
1650 let data: Vec<(Vec<f64>, f64)> = (0..25)
1651 .map(|i| (vec![i as f64 * 0.1], i as f64 * 0.5))
1652 .collect();
1653
1654 let mut callbacks = Vec::new();
1655 model.train_batch_with_callback(&data, 10, |n| {
1656 callbacks.push(n);
1657 });
1658
1659 assert_eq!(callbacks, vec![10, 20, 25]);
1661 }
1662
1663 #[test]
1667 fn train_batch_subsampled_trains_subset() {
1668 let config = SGBTConfig::builder()
1669 .n_steps(3)
1670 .grace_period(5)
1671 .build()
1672 .unwrap();
1673 let mut model = SGBT::new(config);
1674
1675 let data: Vec<(Vec<f64>, f64)> = (0..100)
1676 .map(|i| (vec![i as f64 * 0.01], i as f64 * 0.1))
1677 .collect();
1678
1679 model.train_batch_subsampled(&data, 20);
1681
1682 assert!(
1684 model.n_samples_seen() > 0,
1685 "model should have trained on subset"
1686 );
1687 assert!(
1688 model.n_samples_seen() <= 20,
1689 "model should have trained at most 20 samples, got {}",
1690 model.n_samples_seen(),
1691 );
1692 }
1693
1694 #[test]
1698 fn train_batch_subsampled_full_equals_batch() {
1699 let config1 = SGBTConfig::builder()
1700 .n_steps(3)
1701 .grace_period(5)
1702 .build()
1703 .unwrap();
1704 let config2 = config1.clone();
1705
1706 let mut model1 = SGBT::new(config1);
1707 let mut model2 = SGBT::new(config2);
1708
1709 let data: Vec<(Vec<f64>, f64)> = (0..50)
1710 .map(|i| (vec![i as f64 * 0.1], i as f64 * 0.5))
1711 .collect();
1712
1713 model1.train_batch(&data);
1714 model2.train_batch_subsampled(&data, 1000); assert_eq!(model1.n_samples_seen(), model2.n_samples_seen());
1718 let pred1 = model1.predict(&[2.5]);
1719 let pred2 = model2.predict(&[2.5]);
1720 assert!(
1721 (pred1 - pred2).abs() < 1e-12,
1722 "full subsample should equal batch: {} vs {}",
1723 pred1,
1724 pred2,
1725 );
1726 }
1727
1728 #[test]
1732 fn train_batch_subsampled_with_callback_works() {
1733 let config = SGBTConfig::builder()
1734 .n_steps(3)
1735 .grace_period(5)
1736 .build()
1737 .unwrap();
1738 let mut model = SGBT::new(config);
1739
1740 let data: Vec<(Vec<f64>, f64)> = (0..200)
1741 .map(|i| (vec![i as f64 * 0.01], i as f64 * 0.1))
1742 .collect();
1743
1744 let mut callbacks = Vec::new();
1745 model.train_batch_subsampled_with_callback(&data, 50, 10, |n| {
1746 callbacks.push(n);
1747 });
1748
1749 assert!(!callbacks.is_empty(), "should have received callbacks");
1751 assert_eq!(
1752 *callbacks.last().unwrap(),
1753 50,
1754 "final callback should be total samples"
1755 );
1756 }
1757
1758 fn xorshift64(state: &mut u64) -> u64 {
1764 let mut s = *state;
1765 s ^= s << 13;
1766 s ^= s >> 7;
1767 s ^= s << 17;
1768 *state = s;
1769 s
1770 }
1771
1772 fn rand_f64(state: &mut u64) -> f64 {
1773 xorshift64(state) as f64 / u64::MAX as f64
1774 }
1775
1776 fn linear_leaves_config() -> SGBTConfig {
1777 SGBTConfig::builder()
1778 .n_steps(10)
1779 .learning_rate(0.1)
1780 .grace_period(20)
1781 .max_depth(2) .n_bins(16)
1783 .leaf_model_type(crate::tree::leaf_model::LeafModelType::Linear {
1784 learning_rate: 0.1,
1785 decay: None,
1786 use_adagrad: false,
1787 })
1788 .build()
1789 .unwrap()
1790 }
1791
1792 #[test]
1793 fn linear_leaves_trains_without_panic() {
1794 let mut model = SGBT::new(linear_leaves_config());
1795 let mut rng = 42u64;
1796 for _ in 0..200 {
1797 let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1798 let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1799 let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1800 model.train_one(&Sample::new(vec![x1, x2], y));
1801 }
1802 assert_eq!(model.n_samples_seen(), 200);
1803 }
1804
1805 #[test]
1806 fn linear_leaves_prediction_finite() {
1807 let mut model = SGBT::new(linear_leaves_config());
1808 let mut rng = 42u64;
1809 for _ in 0..200 {
1810 let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1811 let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1812 let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1813 model.train_one(&Sample::new(vec![x1, x2], y));
1814 }
1815 let pred = model.predict(&[0.5, -0.3]);
1816 assert!(pred.is_finite(), "prediction should be finite, got {pred}");
1817 }
1818
1819 #[test]
1820 fn linear_leaves_learns_linear_target() {
1821 let mut model = SGBT::new(linear_leaves_config());
1822 let mut rng = 42u64;
1823 for _ in 0..500 {
1824 let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1825 let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1826 let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1827 model.train_one(&Sample::new(vec![x1, x2], y));
1828 }
1829
1830 let mut total_error = 0.0;
1832 for _ in 0..50 {
1833 let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1834 let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1835 let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1836 let pred = model.predict(&[x1, x2]);
1837 total_error += (pred - y).powi(2);
1838 }
1839 let mse = total_error / 50.0;
1840 assert!(
1841 mse < 5.0,
1842 "linear leaves MSE on linear target should be < 5.0, got {mse}"
1843 );
1844 }
1845
1846 #[test]
1847 fn linear_leaves_better_than_constant_at_low_depth() {
1848 let constant_config = SGBTConfig::builder()
1851 .n_steps(10)
1852 .learning_rate(0.1)
1853 .grace_period(20)
1854 .max_depth(2)
1855 .n_bins(16)
1856 .seed(0xDEAD)
1857 .build()
1858 .unwrap();
1859 let linear_config = SGBTConfig::builder()
1860 .n_steps(10)
1861 .learning_rate(0.1)
1862 .grace_period(20)
1863 .max_depth(2)
1864 .n_bins(16)
1865 .seed(0xDEAD)
1866 .leaf_model_type(crate::tree::leaf_model::LeafModelType::Linear {
1867 learning_rate: 0.1,
1868 decay: None,
1869 use_adagrad: false,
1870 })
1871 .build()
1872 .unwrap();
1873
1874 let mut constant_model = SGBT::new(constant_config);
1875 let mut linear_model = SGBT::new(linear_config);
1876 let mut rng = 42u64;
1877
1878 for _ in 0..500 {
1879 let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1880 let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1881 let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1882 let sample = Sample::new(vec![x1, x2], y);
1883 constant_model.train_one(&sample);
1884 linear_model.train_one(&sample);
1885 }
1886
1887 let mut constant_mse = 0.0;
1889 let mut linear_mse = 0.0;
1890 for _ in 0..100 {
1891 let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1892 let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1893 let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1894 constant_mse += (constant_model.predict(&[x1, x2]) - y).powi(2);
1895 linear_mse += (linear_model.predict(&[x1, x2]) - y).powi(2);
1896 }
1897 constant_mse /= 100.0;
1898 linear_mse /= 100.0;
1899
1900 assert!(
1902 linear_mse < constant_mse,
1903 "linear leaves MSE ({linear_mse:.4}) should be less than constant ({constant_mse:.4})"
1904 );
1905 }
1906
1907 #[test]
1908 fn adaptive_leaves_trains_without_panic() {
1909 let config = SGBTConfig::builder()
1910 .n_steps(10)
1911 .learning_rate(0.1)
1912 .grace_period(20)
1913 .max_depth(3)
1914 .n_bins(16)
1915 .leaf_model_type(crate::tree::leaf_model::LeafModelType::Adaptive {
1916 promote_to: Box::new(crate::tree::leaf_model::LeafModelType::Linear {
1917 learning_rate: 0.1,
1918 decay: None,
1919 use_adagrad: false,
1920 }),
1921 })
1922 .build()
1923 .unwrap();
1924
1925 let mut model = SGBT::new(config);
1926 let mut rng = 42u64;
1927 for _ in 0..500 {
1928 let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1929 let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1930 let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1931 model.train_one(&Sample::new(vec![x1, x2], y));
1932 }
1933 let pred = model.predict(&[0.5, -0.3]);
1934 assert!(
1935 pred.is_finite(),
1936 "adaptive leaf prediction should be finite, got {pred}"
1937 );
1938 }
1939
1940 #[test]
1941 fn linear_leaves_with_decay_trains_without_panic() {
1942 let config = SGBTConfig::builder()
1943 .n_steps(10)
1944 .learning_rate(0.1)
1945 .grace_period(20)
1946 .max_depth(3)
1947 .n_bins(16)
1948 .leaf_model_type(crate::tree::leaf_model::LeafModelType::Linear {
1949 learning_rate: 0.1,
1950 decay: Some(0.995),
1951 use_adagrad: false,
1952 })
1953 .build()
1954 .unwrap();
1955
1956 let mut model = SGBT::new(config);
1957 let mut rng = 42u64;
1958 for _ in 0..500 {
1959 let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1960 let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1961 let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1962 model.train_one(&Sample::new(vec![x1, x2], y));
1963 }
1964 let pred = model.predict(&[0.5, -0.3]);
1965 assert!(
1966 pred.is_finite(),
1967 "decay leaf prediction should be finite, got {pred}"
1968 );
1969 }
1970
1971 #[test]
1975 fn predict_smooth_returns_finite() {
1976 let config = SGBTConfig::builder()
1977 .n_steps(5)
1978 .learning_rate(0.1)
1979 .grace_period(10)
1980 .build()
1981 .unwrap();
1982 let mut model = SGBT::new(config);
1983
1984 for i in 0..200 {
1985 let x = (i as f64) * 0.1;
1986 model.train_one(&Sample::new(vec![x, x.sin()], 2.0 * x + 1.0));
1987 }
1988
1989 let pred_hard = model.predict(&[1.0, 1.0_f64.sin()]);
1990 let pred_smooth = model.predict_smooth(&[1.0, 1.0_f64.sin()], 0.5);
1991
1992 assert!(pred_hard.is_finite(), "hard prediction should be finite");
1993 assert!(
1994 pred_smooth.is_finite(),
1995 "smooth prediction should be finite"
1996 );
1997 }
1998
1999 #[test]
2003 fn predict_smooth_converges_to_hard_at_small_bandwidth() {
2004 let config = SGBTConfig::builder()
2005 .n_steps(5)
2006 .learning_rate(0.1)
2007 .grace_period(10)
2008 .build()
2009 .unwrap();
2010 let mut model = SGBT::new(config);
2011
2012 for i in 0..300 {
2013 let x = (i as f64) * 0.1;
2014 model.train_one(&Sample::new(vec![x, x * 0.5], 2.0 * x + 1.0));
2015 }
2016
2017 let features = [5.0, 2.5];
2018 let hard = model.predict(&features);
2019 let smooth = model.predict_smooth(&features, 0.001);
2020
2021 assert!(
2022 (hard - smooth).abs() < 0.5,
2023 "smooth with tiny bandwidth should approximate hard: hard={}, smooth={}",
2024 hard,
2025 smooth,
2026 );
2027 }
2028
2029 #[test]
2030 fn auto_bandwidth_computed_after_training() {
2031 let config = SGBTConfig::builder()
2032 .n_steps(5)
2033 .learning_rate(0.1)
2034 .grace_period(10)
2035 .build()
2036 .unwrap();
2037 let mut model = SGBT::new(config);
2038
2039 assert!(model.auto_bandwidths().is_empty());
2041
2042 for i in 0..200 {
2043 let x = (i as f64) * 0.1;
2044 model.train_one(&Sample::new(vec![x, x * 0.5], 2.0 * x + 1.0));
2045 }
2046
2047 let bws = model.auto_bandwidths();
2049 assert_eq!(bws.len(), 2, "should have 2 feature bandwidths");
2050
2051 let pred = model.predict(&[5.0, 2.5]);
2053 assert!(
2054 pred.is_finite(),
2055 "auto-bandwidth predict should be finite: {}",
2056 pred
2057 );
2058 }
2059
2060 #[test]
2061 fn predict_interpolated_returns_finite() {
2062 let config = SGBTConfig::builder()
2063 .n_steps(5)
2064 .learning_rate(0.01)
2065 .build()
2066 .unwrap();
2067 let mut model = SGBT::new(config);
2068
2069 for i in 0..200 {
2070 let x = (i as f64) * 0.1;
2071 model.train_one(&Sample::new(vec![x, x.sin()], x.cos()));
2072 }
2073
2074 let pred = model.predict_interpolated(&[1.0, 0.5]);
2075 assert!(
2076 pred.is_finite(),
2077 "interpolated prediction should be finite: {}",
2078 pred
2079 );
2080 }
2081
    /// Sibling interpolation should yield finite predictions and, when the
    /// auto-computed bandwidths are usable, produce at least as many distinct
    /// consecutive output values as hard routing over a sweep of inputs.
    #[test]
    fn predict_sibling_interpolated_varies_with_features() {
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(6)
            .delta(0.1)
            .build()
            .unwrap();
        let mut model = SGBT::new(config);

        // Train on a wiggly nonlinear 1-D signal (second feature is a
        // scaled copy of the first) to force non-trivial tree structure.
        for i in 0..2000 {
            let x = (i as f64) * 0.01;
            let y = x.sin() * x + 0.5 * (x * 2.0).cos();
            model.train_one(&Sample::new(vec![x, x * 0.3], y));
        }

        let pred = model.predict_sibling_interpolated(&[5.0, 1.5]);
        assert!(pred.is_finite(), "sibling interpolated should be finite");

        // The variation comparison is only meaningful when at least one
        // auto bandwidth is finite; otherwise skip it.
        let bws = model.auto_bandwidths();
        if bws.iter().any(|&b| b.is_finite()) {
            // Sweep 200 points; count how often consecutive outputs differ.
            let hard: Vec<f64> = (0..200)
                .map(|i| model.predict(&[i as f64 * 0.1, i as f64 * 0.03]))
                .collect();
            let sib: Vec<f64> = (0..200)
                .map(|i| model.predict_sibling_interpolated(&[i as f64 * 0.1, i as f64 * 0.03]))
                .collect();
            // hc/sc = number of adjacent pairs whose values differ: a proxy
            // for how smoothly the prediction varies with the input.
            let hc = hard
                .windows(2)
                .filter(|w| (w[0] - w[1]).abs() > f64::EPSILON)
                .count();
            let sc = sib
                .windows(2)
                .filter(|w| (w[0] - w[1]).abs() > f64::EPSILON)
                .count();
            assert!(
                sc >= hc,
                "sibling should produce >= hard changes: sib={}, hard={}",
                sc,
                hc
            );
        }
    }
2130
2131 #[test]
2132 fn predict_graduated_returns_finite() {
2133 let config = SGBTConfig::builder()
2134 .n_steps(5)
2135 .learning_rate(0.01)
2136 .max_tree_samples(200)
2137 .shadow_warmup(50)
2138 .build()
2139 .unwrap();
2140 let mut model = SGBT::new(config);
2141
2142 for i in 0..300 {
2143 let x = (i as f64) * 0.1;
2144 model.train_one(&Sample::new(vec![x, x.sin()], x.cos()));
2145 }
2146
2147 let pred = model.predict_graduated(&[1.0, 0.5]);
2148 assert!(
2149 pred.is_finite(),
2150 "graduated prediction should be finite: {}",
2151 pred
2152 );
2153
2154 let pred2 = model.predict_graduated_sibling_interpolated(&[1.0, 0.5]);
2155 assert!(
2156 pred2.is_finite(),
2157 "graduated+sibling prediction should be finite: {}",
2158 pred2
2159 );
2160 }
2161
2162 #[test]
2163 fn shadow_warmup_validation() {
2164 let result = SGBTConfig::builder()
2165 .n_steps(5)
2166 .learning_rate(0.01)
2167 .shadow_warmup(0)
2168 .build();
2169 assert!(result.is_err(), "shadow_warmup=0 should fail validation");
2170 }
2171}