1pub mod adaptive;
20pub mod adaptive_forest;
21pub mod bagged;
22pub mod config;
23pub mod distributional;
24pub mod lr_schedule;
25pub mod moe;
26pub mod moe_distributional;
27pub mod multi_target;
28pub mod multiclass;
29#[cfg(feature = "parallel")]
30pub mod parallel;
31pub mod quantile_regressor;
32pub mod replacement;
33pub mod stacked;
34pub mod step;
35pub mod variants;
36
37use alloc::boxed::Box;
38use alloc::string::String;
39use alloc::vec;
40use alloc::vec::Vec;
41
42use core::fmt;
43
44use crate::ensemble::config::SGBTConfig;
45use crate::ensemble::step::BoostingStep;
46use crate::loss::squared::SquaredLoss;
47use crate::loss::Loss;
48use crate::sample::Observation;
49#[allow(unused_imports)] use crate::sample::Sample;
51use crate::tree::builder::TreeConfig;
52
/// Type-erased SGBT over a boxed dynamic loss — used when the loss type is
/// only known at runtime (e.g. models rebuilt via `from_model_state`).
pub type DynSGBT = SGBT<Box<dyn Loss>>;
61
/// Streaming (online) stochastic gradient-boosted trees.
///
/// Generic over the loss `L`; defaults to [`SquaredLoss`] for regression.
/// Trained one observation at a time via `train_one`.
pub struct SGBT<L: Loss = SquaredLoss> {
    /// Ensemble-level hyperparameters.
    config: SGBTConfig,
    /// One boosting stage per `config.n_steps`, applied sequentially.
    steps: Vec<BoostingStep>,
    /// Loss providing gradients/hessians and the output transform.
    loss: L,
    /// Constant offset added to every prediction (0.0 until initialized).
    base_prediction: f64,
    /// Whether `base_prediction` has been computed from warmup targets.
    base_initialized: bool,
    /// Targets buffered during warmup; cleared once the base is set.
    initial_targets: Vec<f64>,
    /// Number of warmup targets to collect before fixing the base.
    initial_target_count: usize,
    /// Total observations fed to `train_one`.
    samples_seen: u64,
    /// xorshift64 state used for subsampling and training-count draws.
    rng_state: u64,
    /// Per-step EWMA of absolute contribution (quality pruning only).
    contribution_ewma: Vec<f64>,
    /// Per-step count of consecutive low-contribution samples (pruning).
    low_contrib_count: Vec<u64>,
    /// Rolling mean absolute error used for optional error weighting.
    rolling_mean_error: f64,
    /// Per-feature smoothing bandwidths derived from split thresholds.
    auto_bandwidths: Vec<f64>,
    /// Total tree-replacement count at the last bandwidth refresh.
    last_replacement_sum: u64,
}
124
/// Manual `Clone` for `L: Loss + Clone`; `SGBT<Box<dyn Loss>>` is
/// intentionally not cloneable because a boxed `dyn Loss` has no `Clone`.
impl<L: Loss + Clone> Clone for SGBT<L> {
    fn clone(&self) -> Self {
        // Deep, field-by-field copy: the clone trains independently of the
        // original afterwards.
        Self {
            config: self.config.clone(),
            steps: self.steps.clone(),
            loss: self.loss.clone(),
            base_prediction: self.base_prediction,
            base_initialized: self.base_initialized,
            initial_targets: self.initial_targets.clone(),
            initial_target_count: self.initial_target_count,
            samples_seen: self.samples_seen,
            rng_state: self.rng_state,
            contribution_ewma: self.contribution_ewma.clone(),
            low_contrib_count: self.low_contrib_count.clone(),
            rolling_mean_error: self.rolling_mean_error,
            auto_bandwidths: self.auto_bandwidths.clone(),
            last_replacement_sum: self.last_replacement_sum,
        }
    }
}
145
146impl<L: Loss> fmt::Debug for SGBT<L> {
147 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
148 f.debug_struct("SGBT")
149 .field("n_steps", &self.steps.len())
150 .field("samples_seen", &self.samples_seen)
151 .field("base_prediction", &self.base_prediction)
152 .field("base_initialized", &self.base_initialized)
153 .finish()
154 }
155}
156
impl SGBT<SquaredLoss> {
    /// Creates a regression model with the default squared-error loss.
    pub fn new(config: SGBTConfig) -> Self {
        Self::with_loss(config, SquaredLoss)
    }
}
170
171impl<L: Loss> SGBT<L> {
    /// Creates a model with a custom loss function.
    ///
    /// Builds one `BoostingStep` per configured stage, each with its own
    /// tree configuration and drift detector; per-step seeds are derived as
    /// `config.seed ^ step_index` so the steps decorrelate.
    pub fn with_loss(config: SGBTConfig, loss: L) -> Self {
        // Half-life (in samples) -> per-sample decay factor:
        // alpha = exp(-ln(2)/half_life) = 2^(-1/half_life).
        let leaf_decay_alpha = config
            .leaf_half_life
            .map(|hl| crate::math::exp(-crate::math::ln(2.0) / hl as f64));

        // Translate the ensemble-level config into the per-tree config.
        let tree_config = TreeConfig::new()
            .max_depth(config.max_depth)
            .n_bins(config.n_bins)
            .lambda(config.lambda)
            .gamma(config.gamma)
            .grace_period(config.grace_period)
            .delta(config.delta)
            .feature_subsample_rate(config.feature_subsample_rate)
            .leaf_decay_alpha_opt(leaf_decay_alpha)
            .split_reeval_interval_opt(config.split_reeval_interval)
            .feature_types_opt(config.feature_types.clone())
            .gradient_clip_sigma_opt(config.gradient_clip_sigma)
            .monotone_constraints_opt(config.monotone_constraints.clone())
            .max_leaf_output_opt(config.max_leaf_output)
            .adaptive_leaf_bound_opt(config.adaptive_leaf_bound)
            .adaptive_depth_opt(config.adaptive_depth)
            .min_hessian_sum_opt(config.min_hessian_sum)
            .leaf_model_type(config.leaf_model_type.clone());

        let max_tree_samples = config.max_tree_samples;

        // A non-zero shadow warmup selects graduated steps that phase new
        // trees in gradually instead of swapping them atomically.
        let shadow_warmup = config.shadow_warmup.unwrap_or(0);
        let steps: Vec<BoostingStep> = (0..config.n_steps)
            .map(|i| {
                let mut tc = tree_config.clone();
                tc.seed = config.seed ^ (i as u64);
                let detector = config.drift_detector.create();
                if shadow_warmup > 0 {
                    BoostingStep::new_with_graduated(tc, detector, max_tree_samples, shadow_warmup)
                } else {
                    BoostingStep::new_with_max_samples(tc, detector, max_tree_samples)
                }
            })
            .collect();

        let seed = config.seed;
        let initial_target_count = config.initial_target_count;
        let n = config.n_steps;
        let has_pruning = config.quality_prune_alpha.is_some();
        Self {
            config,
            steps,
            loss,
            base_prediction: 0.0,
            base_initialized: false,
            initial_targets: Vec::new(),
            initial_target_count,
            samples_seen: 0,
            rng_state: seed,
            // Pruning bookkeeping is only allocated when pruning is enabled.
            contribution_ewma: if has_pruning {
                vec![0.0; n]
            } else {
                Vec::new()
            },
            low_contrib_count: if has_pruning { vec![0; n] } else { Vec::new() },
            rolling_mean_error: 0.0,
            auto_bandwidths: Vec::new(),
            last_replacement_sum: 0,
        }
    }
253
    /// Trains on a single (features, target) observation, updating every
    /// boosting step in sequence.
    pub fn train_one(&mut self, sample: &impl Observation) {
        self.samples_seen += 1;
        let target = sample.target();
        let features = sample.features();

        // Buffer the first `initial_target_count` targets to seed the base
        // prediction (e.g. the mean for squared loss). Trees are trained
        // from the very first sample even while the base is still 0.0 —
        // NOTE(review): presumably intentional early warm-up; confirm.
        if !self.base_initialized {
            self.initial_targets.push(target);
            if self.initial_targets.len() >= self.initial_target_count {
                self.base_prediction = self.loss.initial_prediction(&self.initial_targets);
                self.base_initialized = true;
                self.initial_targets.clear();
                self.initial_targets.shrink_to_fit();
            }
        }

        let mut current_pred = self.base_prediction;

        let prune_alpha = self.config.quality_prune_alpha;
        let prune_threshold = self.config.quality_prune_threshold;
        let prune_patience = self.config.quality_prune_patience;

        // Optional error weighting: samples whose absolute error exceeds the
        // rolling mean get up-weighted, capped at 10x. The error is measured
        // against the base prediction only, not the full ensemble output —
        // NOTE(review): looks deliberate (cheap), but confirm.
        let error_weight = if let Some(ew_alpha) = self.config.error_weight_alpha {
            let abs_error = crate::math::abs(target - current_pred);
            if self.rolling_mean_error > 1e-15 {
                let w = (1.0 + abs_error / (self.rolling_mean_error + 1e-15)).min(10.0);
                self.rolling_mean_error =
                    ew_alpha * abs_error + (1.0 - ew_alpha) * self.rolling_mean_error;
                w
            } else {
                // First weighted observation: initialize the rolling error.
                self.rolling_mean_error = abs_error.max(1e-15);
                1.0
            }
        } else {
            1.0
        };

        // Staged boosting: each step fits the gradient/hessian of the loss
        // at the ensemble's prediction so far, then contributes its output.
        for s in 0..self.steps.len() {
            let gradient = self.loss.gradient(target, current_pred) * error_weight;
            let hessian = self.loss.hessian(target, current_pred) * error_weight;
            let train_count = self
                .config
                .variant
                .train_count(hessian, &mut self.rng_state);

            let step_pred =
                self.steps[s].train_and_predict(features, gradient, hessian, train_count);

            current_pred += self.config.learning_rate * step_pred;

            // Quality pruning: EWMA each step's |contribution| and reset
            // steps that stay below threshold for `prune_patience` samples.
            if let Some(alpha) = prune_alpha {
                let contribution = crate::math::abs(self.config.learning_rate * step_pred);
                self.contribution_ewma[s] =
                    alpha * contribution + (1.0 - alpha) * self.contribution_ewma[s];

                if self.contribution_ewma[s] < prune_threshold {
                    self.low_contrib_count[s] += 1;
                    if self.low_contrib_count[s] >= prune_patience {
                        self.steps[s].reset();
                        self.contribution_ewma[s] = 0.0;
                        self.low_contrib_count[s] = 0;
                    }
                } else {
                    self.low_contrib_count[s] = 0;
                }
            }
        }

        // Recompute smoothing bandwidths if any tree was replaced.
        self.refresh_bandwidths();
    }
334
335 pub fn train_batch<O: Observation>(&mut self, samples: &[O]) {
337 for sample in samples {
338 self.train_one(sample);
339 }
340 }
341
342 pub fn train_batch_with_callback<O: Observation, F: FnMut(usize)>(
363 &mut self,
364 samples: &[O],
365 interval: usize,
366 mut callback: F,
367 ) {
368 let interval = interval.max(1); for (i, sample) in samples.iter().enumerate() {
370 self.train_one(sample);
371 if (i + 1) % interval == 0 {
372 callback(i + 1);
373 }
374 }
375 let total = samples.len();
377 if total % interval != 0 {
378 callback(total);
379 }
380 }
381
382 pub fn train_batch_subsampled<O: Observation>(&mut self, samples: &[O], max_samples: usize) {
395 if max_samples >= samples.len() {
396 self.train_batch(samples);
397 return;
398 }
399
400 let mut reservoir: Vec<usize> = (0..max_samples).collect();
402 let mut rng = self.rng_state;
403
404 for i in max_samples..samples.len() {
405 rng ^= rng << 13;
407 rng ^= rng >> 7;
408 rng ^= rng << 17;
409 let j = (rng % (i as u64 + 1)) as usize;
410 if j < max_samples {
411 reservoir[j] = i;
412 }
413 }
414
415 self.rng_state = rng;
416
417 reservoir.sort_unstable();
419
420 for &idx in &reservoir {
422 self.train_one(&samples[idx]);
423 }
424 }
425
426 pub fn train_batch_subsampled_with_callback<O: Observation, F: FnMut(usize)>(
432 &mut self,
433 samples: &[O],
434 max_samples: usize,
435 interval: usize,
436 mut callback: F,
437 ) {
438 if max_samples >= samples.len() {
439 self.train_batch_with_callback(samples, interval, callback);
440 return;
441 }
442
443 let mut reservoir: Vec<usize> = (0..max_samples).collect();
445 let mut rng = self.rng_state;
446
447 for i in max_samples..samples.len() {
448 rng ^= rng << 13;
449 rng ^= rng >> 7;
450 rng ^= rng << 17;
451 let j = (rng % (i as u64 + 1)) as usize;
452 if j < max_samples {
453 reservoir[j] = i;
454 }
455 }
456
457 self.rng_state = rng;
458 reservoir.sort_unstable();
459
460 let interval = interval.max(1);
461 for (i, &idx) in reservoir.iter().enumerate() {
462 self.train_one(&samples[idx]);
463 if (i + 1) % interval == 0 {
464 callback(i + 1);
465 }
466 }
467 let total = reservoir.len();
468 if total % interval != 0 {
469 callback(total);
470 }
471 }
472
473 pub fn predict(&self, features: &[f64]) -> f64 {
479 let mut pred = self.base_prediction;
480 if self.auto_bandwidths.is_empty() {
481 for step in &self.steps {
483 pred += self.config.learning_rate * step.predict(features);
484 }
485 } else {
486 for step in &self.steps {
487 pred += self.config.learning_rate
488 * step.predict_smooth_auto(features, &self.auto_bandwidths);
489 }
490 }
491 pred
492 }
493
494 pub fn predict_smooth(&self, features: &[f64], bandwidth: f64) -> f64 {
499 let mut pred = self.base_prediction;
500 for step in &self.steps {
501 pred += self.config.learning_rate * step.predict_smooth(features, bandwidth);
502 }
503 pred
504 }
505
506 pub fn auto_bandwidths(&self) -> &[f64] {
512 &self.auto_bandwidths
513 }
514
515 pub fn predict_interpolated(&self, features: &[f64]) -> f64 {
520 let mut pred = self.base_prediction;
521 for step in &self.steps {
522 pred += self.config.learning_rate * step.predict_interpolated(features);
523 }
524 pred
525 }
526
527 pub fn predict_sibling_interpolated(&self, features: &[f64]) -> f64 {
535 let mut pred = self.base_prediction;
536 for step in &self.steps {
537 pred += self.config.learning_rate
538 * step.predict_sibling_interpolated(features, &self.auto_bandwidths);
539 }
540 pred
541 }
542
543 pub fn predict_graduated(&self, features: &[f64]) -> f64 {
549 let mut pred = self.base_prediction;
550 for step in &self.steps {
551 pred += self.config.learning_rate * step.predict_graduated(features);
552 }
553 pred
554 }
555
556 pub fn predict_graduated_sibling_interpolated(&self, features: &[f64]) -> f64 {
562 let mut pred = self.base_prediction;
563 for step in &self.steps {
564 pred += self.config.learning_rate
565 * step.predict_graduated_sibling_interpolated(features, &self.auto_bandwidths);
566 }
567 pred
568 }
569
    /// Applies the loss's output transform (e.g. a sigmoid for logistic
    /// loss) to the raw prediction.
    pub fn predict_transformed(&self, features: &[f64]) -> f64 {
        self.loss.predict_transform(self.predict(features))
    }

    /// Alias for [`predict_transformed`](Self::predict_transformed); reads
    /// more naturally for classification losses.
    pub fn predict_proba(&self, features: &[f64]) -> f64 {
        self.predict_transformed(features)
    }
579
    /// Predicts a raw score together with a heuristic confidence value.
    ///
    /// Confidence is `1 / sqrt(total variance)`, where the total is the sum
    /// of per-step leaf variances scaled by the squared learning rate.
    /// Returns 0.0 confidence when no (finite, positive) variance is
    /// available, e.g. before any training.
    pub fn predict_with_confidence(&self, features: &[f64]) -> (f64, f64) {
        let mut pred = self.base_prediction;
        let mut total_variance = 0.0;
        let lr2 = self.config.learning_rate * self.config.learning_rate;

        for step in &self.steps {
            let (value, variance) = step.predict_with_variance(features);
            pred += self.config.learning_rate * value;
            total_variance += lr2 * variance;
        }

        let confidence = if total_variance > 0.0 && total_variance.is_finite() {
            1.0 / crate::math::sqrt(total_variance)
        } else {
            0.0
        };

        (pred, confidence)
    }
613
614 pub fn predict_batch(&self, feature_matrix: &[Vec<f64>]) -> Vec<f64> {
616 feature_matrix.iter().map(|f| self.predict(f)).collect()
617 }
618
    /// Number of boosting steps (stages) in the ensemble.
    pub fn n_steps(&self) -> usize {
        self.steps.len()
    }

    /// Total tree count, including alternate (shadow) trees grown for
    /// drift recovery.
    pub fn n_trees(&self) -> usize {
        self.steps.len() + self.steps.iter().filter(|s| s.has_alternate()).count()
    }

    /// Sum of per-step leaf counts.
    pub fn total_leaves(&self) -> usize {
        self.steps.iter().map(|s| s.n_leaves()).sum()
    }

    /// Total observations fed to `train_one`.
    pub fn n_samples_seen(&self) -> u64 {
        self.samples_seen
    }

    /// The constant offset added to every prediction.
    pub fn base_prediction(&self) -> f64 {
        self.base_prediction
    }

    /// Whether the base prediction has been computed from warmup targets.
    pub fn is_initialized(&self) -> bool {
        self.base_initialized
    }

    /// The configuration this model was built with.
    pub fn config(&self) -> &SGBTConfig {
        &self.config
    }

    /// Overrides the learning rate, e.g. from an external LR schedule.
    #[inline]
    pub fn set_learning_rate(&mut self, lr: f64) {
        self.config.learning_rate = lr;
    }

    /// Read-only access to the boosting steps.
    pub fn steps(&self) -> &[BoostingStep] {
        &self.steps
    }

    /// The loss function in use.
    pub fn loss(&self) -> &L {
        &self.loss
    }
678
679 pub fn feature_importances(&self) -> Vec<f64> {
684 let mut totals: Vec<f64> = Vec::new();
686 for step in &self.steps {
687 let gains = step.slot().split_gains();
688 if totals.is_empty() && !gains.is_empty() {
689 totals.resize(gains.len(), 0.0);
690 }
691 for (i, &g) in gains.iter().enumerate() {
692 if i < totals.len() {
693 totals[i] += g;
694 }
695 }
696 }
697
698 let sum: f64 = totals.iter().sum();
700 if sum > 0.0 {
701 totals.iter_mut().for_each(|v| *v /= sum);
702 }
703 totals
704 }
705
    /// Configured feature names, if any.
    pub fn feature_names(&self) -> Option<&[String]> {
        self.config.feature_names.as_deref()
    }

    /// Feature importances paired with their configured names, sorted by
    /// importance descending. Names beyond the length of the importance
    /// vector are padded with 0.0. Returns `None` when no names are
    /// configured.
    pub fn named_feature_importances(&self) -> Option<Vec<(String, f64)>> {
        let names = self.config.feature_names.as_ref()?;
        let importances = self.feature_importances();
        let mut pairs: Vec<(String, f64)> = names
            .iter()
            // Pad with zeros so every configured name gets an entry.
            .zip(importances.iter().chain(core::iter::repeat(&0.0)))
            .map(|(n, &v)| (n.clone(), v))
            .collect();
        pairs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(core::cmp::Ordering::Equal));
        Some(pairs)
    }
726
    /// Trains on a named feature map; features missing from the map default
    /// to 0.0, and the vector is ordered by the configured feature names.
    ///
    /// # Panics
    /// Panics if `feature_names` was not configured.
    #[cfg(feature = "std")]
    pub fn train_one_named(
        &mut self,
        features: &std::collections::HashMap<alloc::string::String, f64>,
        target: f64,
    ) {
        let names = self
            .config
            .feature_names
            .as_ref()
            .expect("train_one_named requires feature_names to be configured");
        let vec: Vec<f64> = names
            .iter()
            .map(|name| features.get(name).copied().unwrap_or(0.0))
            .collect();
        self.train_one(&(&vec[..], target));
    }

    /// Predicts from a named feature map; features missing from the map
    /// default to 0.0, ordered by the configured feature names.
    ///
    /// # Panics
    /// Panics if `feature_names` was not configured.
    #[cfg(feature = "std")]
    pub fn predict_named(
        &self,
        features: &std::collections::HashMap<alloc::string::String, f64>,
    ) -> f64 {
        let names = self
            .config
            .feature_names
            .as_ref()
            .expect("predict_named requires feature_names to be configured");
        let vec: Vec<f64> = names
            .iter()
            .map(|name| features.get(name).copied().unwrap_or(0.0))
            .collect();
        self.predict(&vec)
    }
776
    /// Recomputes the auto-bandwidths when any tree was replaced (detected
    /// via the total replacement count across all slots) or when none have
    /// been computed yet. Called after every training sample; the sum check
    /// keeps the expensive recomputation rare once bandwidths exist.
    fn refresh_bandwidths(&mut self) {
        let current_sum: u64 = self.steps.iter().map(|s| s.slot().replacements()).sum();
        if current_sum != self.last_replacement_sum || self.auto_bandwidths.is_empty() {
            self.auto_bandwidths = self.compute_auto_bandwidths();
            self.last_replacement_sum = current_sum;
        }
    }
789
790 fn compute_auto_bandwidths(&self) -> Vec<f64> {
800 const K: f64 = 2.0;
801
802 let n_features = self
804 .steps
805 .iter()
806 .filter_map(|s| s.slot().active_tree().n_features())
807 .max()
808 .unwrap_or(0);
809
810 if n_features == 0 {
811 return Vec::new();
812 }
813
814 let mut all_thresholds: Vec<Vec<f64>> = vec![Vec::new(); n_features];
816
817 for step in &self.steps {
818 let tree_thresholds = step
819 .slot()
820 .active_tree()
821 .collect_split_thresholds_per_feature();
822 for (i, ts) in tree_thresholds.into_iter().enumerate() {
823 if i < n_features {
824 all_thresholds[i].extend(ts);
825 }
826 }
827 }
828
829 let n_bins = self.config.n_bins as f64;
830
831 all_thresholds
833 .iter()
834 .map(|ts| {
835 if ts.is_empty() {
836 return f64::INFINITY; }
838
839 let mut sorted = ts.clone();
840 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
841 sorted.dedup_by(|a, b| crate::math::abs(*a - *b) < 1e-15);
842
843 if sorted.len() < 2 {
844 return f64::INFINITY; }
846
847 let mut gaps: Vec<f64> = sorted.windows(2).map(|w| w[1] - w[0]).collect();
849
850 if sorted.len() < 3 {
851 let range = sorted.last().unwrap() - sorted.first().unwrap();
853 if range < 1e-15 {
854 return f64::INFINITY;
855 }
856 return (range / n_bins) * K;
857 }
858
859 gaps.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
861 let median_gap = if gaps.len() % 2 == 0 {
862 (gaps[gaps.len() / 2 - 1] + gaps[gaps.len() / 2]) / 2.0
863 } else {
864 gaps[gaps.len() / 2]
865 };
866
867 if median_gap < 1e-15 {
868 f64::INFINITY
869 } else {
870 median_gap * K
871 }
872 })
873 .collect()
874 }
875
876 pub fn reset(&mut self) {
878 for step in &mut self.steps {
879 step.reset();
880 }
881 self.base_prediction = 0.0;
882 self.base_initialized = false;
883 self.initial_targets.clear();
884 self.samples_seen = 0;
885 self.rng_state = self.config.seed;
886 self.auto_bandwidths.clear();
887 self.last_replacement_sum = 0;
888 }
889
    /// Serializes the model into a `ModelState`, auto-detecting the loss
    /// type via `Loss::loss_type()`.
    ///
    /// # Errors
    /// Returns a serialization error when the loss type cannot be detected
    /// (custom losses should use `to_model_state_with` instead).
    #[cfg(feature = "_serde_support")]
    pub fn to_model_state(&self) -> crate::error::Result<crate::serde_support::ModelState> {
        let loss_type = self.loss.loss_type().ok_or_else(|| {
            crate::error::IrithyllError::Serialization(
                "cannot auto-detect loss type for serialization: \
                 implement Loss::loss_type() or use to_model_state_with()"
                    .into(),
            )
        })?;
        Ok(self.to_model_state_with(loss_type))
    }
911
    /// Serializes the model into a `ModelState` with an explicit loss type.
    ///
    /// Captures each step's active tree, its alternate (shadow) tree if
    /// present, and both drift detectors' states. Auto-bandwidths are not
    /// persisted; they are recomputed after deserialization.
    #[cfg(feature = "_serde_support")]
    pub fn to_model_state_with(
        &self,
        loss_type: crate::loss::LossType,
    ) -> crate::serde_support::ModelState {
        use crate::serde_support::{ModelState, StepSnapshot};

        let steps = self
            .steps
            .iter()
            .map(|step| {
                let slot = step.slot();
                let tree_snap = snapshot_tree(slot.active_tree());
                let alt_snap = slot.alternate_tree().map(snapshot_tree);
                let drift_state = slot.detector().serialize_state();
                let alt_drift_state = slot.alt_detector().and_then(|d| d.serialize_state());
                StepSnapshot {
                    tree: tree_snap,
                    alternate_tree: alt_snap,
                    drift_state,
                    alt_drift_state,
                }
            })
            .collect();

        ModelState {
            config: self.config.clone(),
            loss_type,
            base_prediction: self.base_prediction,
            base_initialized: self.base_initialized,
            initial_targets: self.initial_targets.clone(),
            initial_target_count: self.initial_target_count,
            samples_seen: self.samples_seen,
            rng_state: self.rng_state,
            steps,
            rolling_mean_error: self.rolling_mean_error,
            contribution_ewma: self.contribution_ewma.clone(),
            low_contrib_count: self.low_contrib_count.clone(),
        }
    }
955}
956
957#[cfg(feature = "_serde_support")]
962impl SGBT<Box<dyn Loss>> {
963 pub fn from_model_state(state: crate::serde_support::ModelState) -> Self {
973 use crate::ensemble::replacement::TreeSlot;
974
975 let loss = state.loss_type.into_loss();
976
977 let leaf_decay_alpha = state
978 .config
979 .leaf_half_life
980 .map(|hl| crate::math::exp(-crate::math::ln(2.0) / hl as f64));
981 let max_tree_samples = state.config.max_tree_samples;
982
983 let steps: Vec<BoostingStep> = state
984 .steps
985 .iter()
986 .enumerate()
987 .map(|(i, step_snap)| {
988 let tree_config = TreeConfig::new()
989 .max_depth(state.config.max_depth)
990 .n_bins(state.config.n_bins)
991 .lambda(state.config.lambda)
992 .gamma(state.config.gamma)
993 .grace_period(state.config.grace_period)
994 .delta(state.config.delta)
995 .feature_subsample_rate(state.config.feature_subsample_rate)
996 .leaf_decay_alpha_opt(leaf_decay_alpha)
997 .split_reeval_interval_opt(state.config.split_reeval_interval)
998 .feature_types_opt(state.config.feature_types.clone())
999 .gradient_clip_sigma_opt(state.config.gradient_clip_sigma)
1000 .monotone_constraints_opt(state.config.monotone_constraints.clone())
1001 .adaptive_depth_opt(state.config.adaptive_depth)
1002 .leaf_model_type(state.config.leaf_model_type.clone())
1003 .seed(state.config.seed ^ (i as u64));
1004
1005 let active = rebuild_tree(&step_snap.tree, tree_config.clone());
1006 let alternate = step_snap
1007 .alternate_tree
1008 .as_ref()
1009 .map(|snap| rebuild_tree(snap, tree_config.clone()));
1010
1011 let mut detector = state.config.drift_detector.create();
1012 if let Some(ref ds) = step_snap.drift_state {
1013 detector.restore_state(ds);
1014 }
1015 let mut slot = TreeSlot::from_trees(
1016 active,
1017 alternate,
1018 tree_config,
1019 detector,
1020 max_tree_samples,
1021 );
1022 if let Some(ref ads) = step_snap.alt_drift_state {
1023 if let Some(alt_det) = slot.alt_detector_mut() {
1024 alt_det.restore_state(ads);
1025 }
1026 }
1027 BoostingStep::from_slot(slot)
1028 })
1029 .collect();
1030
1031 let n = steps.len();
1032 let has_pruning = state.config.quality_prune_alpha.is_some();
1033
1034 let contribution_ewma = if !state.contribution_ewma.is_empty() {
1036 state.contribution_ewma
1037 } else if has_pruning {
1038 vec![0.0; n]
1039 } else {
1040 Vec::new()
1041 };
1042 let low_contrib_count = if !state.low_contrib_count.is_empty() {
1043 state.low_contrib_count
1044 } else if has_pruning {
1045 vec![0; n]
1046 } else {
1047 Vec::new()
1048 };
1049
1050 Self {
1051 config: state.config,
1052 steps,
1053 loss,
1054 base_prediction: state.base_prediction,
1055 base_initialized: state.base_initialized,
1056 initial_targets: state.initial_targets,
1057 initial_target_count: state.initial_target_count,
1058 samples_seen: state.samples_seen,
1059 rng_state: state.rng_state,
1060 contribution_ewma,
1061 low_contrib_count,
1062 rolling_mean_error: state.rolling_mean_error,
1063 auto_bandwidths: Vec::new(),
1064 last_replacement_sum: 0,
1065 }
1066 }
1067}
1068
1069#[cfg(feature = "_serde_support")]
/// Captures a serializable snapshot of a tree's arena plus the metadata
/// (feature count, sample count, RNG state) needed to rebuild it with
/// [`rebuild_tree`].
pub(crate) fn snapshot_tree(
    tree: &crate::tree::hoeffding::HoeffdingTree,
) -> crate::serde_support::TreeSnapshot {
    use crate::serde_support::TreeSnapshot;
    use crate::tree::StreamingTree;
    let arena = tree.arena();
    TreeSnapshot {
        feature_idx: arena.feature_idx.clone(),
        threshold: arena.threshold.clone(),
        // Node ids are stored as raw indices in the snapshot.
        left: arena.left.iter().map(|id| id.0).collect(),
        right: arena.right.iter().map(|id| id.0).collect(),
        leaf_value: arena.leaf_value.clone(),
        is_leaf: arena.is_leaf.clone(),
        depth: arena.depth.clone(),
        sample_count: arena.sample_count.clone(),
        n_features: tree.n_features(),
        samples_seen: tree.n_samples_seen(),
        rng_state: tree.rng_state(),
        categorical_mask: arena.categorical_mask.clone(),
    }
}
1096
1097#[cfg(feature = "_serde_support")]
/// Rebuilds a Hoeffding tree from a snapshot produced by [`snapshot_tree`],
/// using `tree_config` for all non-persisted hyperparameters.
pub(crate) fn rebuild_tree(
    snapshot: &crate::serde_support::TreeSnapshot,
    tree_config: TreeConfig,
) -> crate::tree::hoeffding::HoeffdingTree {
    use crate::tree::hoeffding::HoeffdingTree;
    use crate::tree::node::{NodeId, TreeArena};

    let mut arena = TreeArena::new();
    let n = snapshot.feature_idx.len();

    // Repopulate the arena node by node; raw indices become NodeIds again.
    for i in 0..n {
        arena.feature_idx.push(snapshot.feature_idx[i]);
        arena.threshold.push(snapshot.threshold[i]);
        arena.left.push(NodeId(snapshot.left[i]));
        arena.right.push(NodeId(snapshot.right[i]));
        arena.leaf_value.push(snapshot.leaf_value[i]);
        arena.is_leaf.push(snapshot.is_leaf[i]);
        arena.depth.push(snapshot.depth[i]);
        arena.sample_count.push(snapshot.sample_count[i]);
        // Tolerate a shorter mask vector (older snapshots): missing entries
        // become None rather than panicking on an out-of-bounds index.
        let mask = snapshot.categorical_mask.get(i).copied().flatten();
        arena.categorical_mask.push(mask);
    }

    HoeffdingTree::from_arena(
        tree_config,
        arena,
        snapshot.n_features,
        snapshot.samples_seen,
        snapshot.rng_state,
    )
}
1130
1131#[cfg(test)]
1132mod tests {
1133 use super::*;
1134 use alloc::boxed::Box;
1135 use alloc::vec;
1136 use alloc::vec::Vec;
1137
1138 fn default_config() -> SGBTConfig {
1139 SGBTConfig::builder()
1140 .n_steps(10)
1141 .learning_rate(0.1)
1142 .grace_period(20)
1143 .max_depth(4)
1144 .n_bins(16)
1145 .build()
1146 .unwrap()
1147 }
1148
1149 #[test]
1150 fn new_model_predicts_zero() {
1151 let model = SGBT::new(default_config());
1152 let pred = model.predict(&[1.0, 2.0, 3.0]);
1153 assert!(pred.abs() < 1e-12);
1154 }
1155
1156 #[test]
1157 fn train_one_does_not_panic() {
1158 let mut model = SGBT::new(default_config());
1159 model.train_one(&Sample::new(vec![1.0, 2.0, 3.0], 5.0));
1160 assert_eq!(model.n_samples_seen(), 1);
1161 }
1162
1163 #[test]
1164 fn prediction_changes_after_training() {
1165 let mut model = SGBT::new(default_config());
1166 let features = vec![1.0, 2.0, 3.0];
1167 for i in 0..100 {
1168 model.train_one(&Sample::new(features.clone(), (i as f64) * 0.1));
1169 }
1170 let pred = model.predict(&features);
1171 assert!(pred.is_finite());
1172 }
1173
1174 #[test]
1175 fn linear_signal_rmse_improves() {
1176 let config = SGBTConfig::builder()
1177 .n_steps(20)
1178 .learning_rate(0.1)
1179 .grace_period(10)
1180 .max_depth(3)
1181 .n_bins(16)
1182 .build()
1183 .unwrap();
1184 let mut model = SGBT::new(config);
1185
1186 let mut rng: u64 = 12345;
1187 let mut early_errors = Vec::new();
1188 let mut late_errors = Vec::new();
1189
1190 for i in 0..500 {
1191 rng ^= rng << 13;
1192 rng ^= rng >> 7;
1193 rng ^= rng << 17;
1194 let x1 = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
1195 rng ^= rng << 13;
1196 rng ^= rng >> 7;
1197 rng ^= rng << 17;
1198 let x2 = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
1199 let target = 2.0 * x1 + 3.0 * x2;
1200
1201 let pred = model.predict(&[x1, x2]);
1202 let error = (pred - target).powi(2);
1203
1204 if (50..150).contains(&i) {
1205 early_errors.push(error);
1206 }
1207 if i >= 400 {
1208 late_errors.push(error);
1209 }
1210
1211 model.train_one(&Sample::new(vec![x1, x2], target));
1212 }
1213
1214 let early_rmse = (early_errors.iter().sum::<f64>() / early_errors.len() as f64).sqrt();
1215 let late_rmse = (late_errors.iter().sum::<f64>() / late_errors.len() as f64).sqrt();
1216
1217 assert!(
1218 late_rmse < early_rmse,
1219 "RMSE should decrease: early={:.4}, late={:.4}",
1220 early_rmse,
1221 late_rmse
1222 );
1223 }
1224
1225 #[test]
1226 fn train_batch_equivalent_to_sequential() {
1227 let config = default_config();
1228 let mut model_seq = SGBT::new(config.clone());
1229 let mut model_batch = SGBT::new(config);
1230
1231 let samples: Vec<Sample> = (0..20)
1232 .map(|i| {
1233 let x = i as f64 * 0.5;
1234 Sample::new(vec![x, x * 2.0], x * 3.0)
1235 })
1236 .collect();
1237
1238 for s in &samples {
1239 model_seq.train_one(s);
1240 }
1241 model_batch.train_batch(&samples);
1242
1243 let pred_seq = model_seq.predict(&[1.0, 2.0]);
1244 let pred_batch = model_batch.predict(&[1.0, 2.0]);
1245
1246 assert!(
1247 (pred_seq - pred_batch).abs() < 1e-10,
1248 "seq={}, batch={}",
1249 pred_seq,
1250 pred_batch
1251 );
1252 }
1253
1254 #[test]
1255 fn reset_returns_to_initial() {
1256 let mut model = SGBT::new(default_config());
1257 for i in 0..100 {
1258 model.train_one(&Sample::new(vec![1.0, 2.0], i as f64));
1259 }
1260 model.reset();
1261 assert_eq!(model.n_samples_seen(), 0);
1262 assert!(!model.is_initialized());
1263 assert!(model.predict(&[1.0, 2.0]).abs() < 1e-12);
1264 }
1265
1266 #[test]
1267 fn base_prediction_initializes() {
1268 let mut model = SGBT::new(default_config());
1269 for i in 0..50 {
1270 model.train_one(&Sample::new(vec![1.0], i as f64 + 100.0));
1271 }
1272 assert!(model.is_initialized());
1273 let expected = (100.0 + 149.0) / 2.0;
1274 assert!((model.base_prediction() - expected).abs() < 1.0);
1275 }
1276
1277 #[test]
1278 fn with_loss_uses_custom_loss() {
1279 use crate::loss::logistic::LogisticLoss;
1280 let model = SGBT::with_loss(default_config(), LogisticLoss);
1281 let pred = model.predict_transformed(&[1.0, 2.0]);
1282 assert!(
1283 (pred - 0.5).abs() < 1e-6,
1284 "sigmoid(0) should be 0.5, got {}",
1285 pred
1286 );
1287 }
1288
1289 #[test]
1290 fn ewma_config_propagates_and_trains() {
1291 let config = SGBTConfig::builder()
1292 .n_steps(5)
1293 .learning_rate(0.1)
1294 .grace_period(10)
1295 .max_depth(3)
1296 .n_bins(16)
1297 .leaf_half_life(50)
1298 .build()
1299 .unwrap();
1300 let mut model = SGBT::new(config);
1301
1302 for i in 0..200 {
1303 let x = (i as f64) * 0.1;
1304 model.train_one(&Sample::new(vec![x, x * 2.0], x * 3.0));
1305 }
1306
1307 let pred = model.predict(&[1.0, 2.0]);
1308 assert!(
1309 pred.is_finite(),
1310 "EWMA-enabled model should produce finite predictions, got {}",
1311 pred
1312 );
1313 }
1314
1315 #[test]
1316 fn max_tree_samples_config_propagates() {
1317 let config = SGBTConfig::builder()
1318 .n_steps(5)
1319 .learning_rate(0.1)
1320 .grace_period(10)
1321 .max_depth(3)
1322 .n_bins(16)
1323 .max_tree_samples(200)
1324 .build()
1325 .unwrap();
1326 let mut model = SGBT::new(config);
1327
1328 for i in 0..500 {
1329 let x = (i as f64) * 0.1;
1330 model.train_one(&Sample::new(vec![x, x * 2.0], x * 3.0));
1331 }
1332
1333 let pred = model.predict(&[1.0, 2.0]);
1334 assert!(
1335 pred.is_finite(),
1336 "max_tree_samples model should produce finite predictions, got {}",
1337 pred
1338 );
1339 }
1340
1341 #[test]
1342 fn split_reeval_config_propagates() {
1343 let config = SGBTConfig::builder()
1344 .n_steps(5)
1345 .learning_rate(0.1)
1346 .grace_period(10)
1347 .max_depth(2)
1348 .n_bins(16)
1349 .split_reeval_interval(50)
1350 .build()
1351 .unwrap();
1352 let mut model = SGBT::new(config);
1353
1354 let mut rng: u64 = 12345;
1355 for _ in 0..1000 {
1356 rng ^= rng << 13;
1357 rng ^= rng >> 7;
1358 rng ^= rng << 17;
1359 let x1 = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
1360 rng ^= rng << 13;
1361 rng ^= rng >> 7;
1362 rng ^= rng << 17;
1363 let x2 = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
1364 let target = 2.0 * x1 + 3.0 * x2;
1365 model.train_one(&Sample::new(vec![x1, x2], target));
1366 }
1367
1368 let pred = model.predict(&[1.0, 2.0]);
1369 assert!(
1370 pred.is_finite(),
1371 "split re-eval model should produce finite predictions, got {}",
1372 pred
1373 );
1374 }
1375
1376 #[test]
1377 fn loss_accessor_works() {
1378 use crate::loss::logistic::LogisticLoss;
1379 let model = SGBT::with_loss(default_config(), LogisticLoss);
1380 let _loss: &LogisticLoss = model.loss();
1382 assert_eq!(_loss.n_outputs(), 1);
1383 }
1384
1385 #[test]
1386 fn clone_produces_independent_copy() {
1387 let config = default_config();
1388 let mut model = SGBT::new(config);
1389
1390 let mut rng: u64 = 99999;
1392 for _ in 0..200 {
1393 rng ^= rng << 13;
1394 rng ^= rng >> 7;
1395 rng ^= rng << 17;
1396 let x = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
1397 let target = 2.0 * x + 1.0;
1398 model.train_one(&Sample::new(vec![x], target));
1399 }
1400
1401 let mut cloned = model.clone();
1403
1404 let test_features = [3.0];
1406 let pred_original = model.predict(&test_features);
1407 let pred_cloned = cloned.predict(&test_features);
1408 assert!(
1409 (pred_original - pred_cloned).abs() < 1e-12,
1410 "clone should predict identically: original={pred_original}, cloned={pred_cloned}"
1411 );
1412
1413 for _ in 0..200 {
1415 rng ^= rng << 13;
1416 rng ^= rng >> 7;
1417 rng ^= rng << 17;
1418 let x = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
1419 let target = -3.0 * x + 5.0; cloned.train_one(&Sample::new(vec![x], target));
1421 }
1422
1423 let pred_original_after = model.predict(&test_features);
1424 let pred_cloned_after = cloned.predict(&test_features);
1425
1426 assert!(
1428 (pred_original - pred_original_after).abs() < 1e-12,
1429 "original should be unchanged after training clone"
1430 );
1431
1432 assert!(
1434 (pred_original_after - pred_cloned_after).abs() > 1e-6,
1435 "clone should diverge after independent training"
1436 );
1437 }
1438
1439 #[test]
1443 fn predict_with_confidence_finite() {
1444 let config = SGBTConfig::builder()
1445 .n_steps(5)
1446 .grace_period(10)
1447 .build()
1448 .unwrap();
1449 let mut model = SGBT::new(config);
1450
1451 for i in 0..100 {
1453 let x = i as f64 * 0.1;
1454 model.train_one(&(&[x, x * 2.0][..], x + 1.0));
1455 }
1456
1457 let (pred, confidence) = model.predict_with_confidence(&[1.0, 2.0]);
1458 assert!(pred.is_finite(), "prediction should be finite");
1459 assert!(confidence.is_finite(), "confidence should be finite");
1460 assert!(
1461 confidence > 0.0,
1462 "confidence should be positive after training"
1463 );
1464 }
1465
1466 #[test]
1470 fn predict_with_confidence_positive_after_training() {
1471 let config = SGBTConfig::builder()
1472 .n_steps(5)
1473 .grace_period(10)
1474 .build()
1475 .unwrap();
1476 let mut model = SGBT::new(config);
1477
1478 for i in 0..200 {
1480 let x = i as f64 * 0.05;
1481 model.train_one(&(&[x][..], x * 2.0));
1482 }
1483
1484 let (pred, confidence) = model.predict_with_confidence(&[1.0]);
1485
1486 assert!(pred.is_finite(), "prediction should be finite");
1487 assert!(
1488 confidence > 0.0 && confidence.is_finite(),
1489 "confidence should be finite and positive, got {}",
1490 confidence,
1491 );
1492
1493 let (pred2, conf2) = model.predict_with_confidence(&[1.0]);
1495 assert!(
1496 (pred - pred2).abs() < 1e-12,
1497 "same input should give same prediction"
1498 );
1499 assert!(
1500 (confidence - conf2).abs() < 1e-12,
1501 "same input should give same confidence"
1502 );
1503 }
1504
1505 #[test]
1509 fn predict_with_confidence_matches_predict() {
1510 let config = SGBTConfig::builder()
1511 .n_steps(10)
1512 .grace_period(10)
1513 .build()
1514 .unwrap();
1515 let mut model = SGBT::new(config);
1516
1517 for i in 0..200 {
1518 let x = (i as f64 - 100.0) * 0.01;
1519 model.train_one(&(&[x, x * x][..], x * 3.0 + 1.0));
1520 }
1521
1522 let pred = model.predict(&[0.5, 0.25]);
1523 let (conf_pred, _) = model.predict_with_confidence(&[0.5, 0.25]);
1524
1525 assert!(
1526 (pred - conf_pred).abs() < 1e-10,
1527 "prediction mismatch: predict()={} vs predict_with_confidence()={}",
1528 pred,
1529 conf_pred,
1530 );
1531 }
1532
1533 #[test]
1537 fn gradient_clip_config_builder() {
1538 let config = SGBTConfig::builder()
1539 .n_steps(10)
1540 .gradient_clip_sigma(3.0)
1541 .build()
1542 .unwrap();
1543
1544 assert_eq!(config.gradient_clip_sigma, Some(3.0));
1545 }
1546
1547 #[test]
1551 fn monotone_constraints_config_builder() {
1552 let config = SGBTConfig::builder()
1553 .n_steps(10)
1554 .monotone_constraints(vec![1, -1, 0])
1555 .build()
1556 .unwrap();
1557
1558 assert_eq!(config.monotone_constraints, Some(vec![1, -1, 0]));
1559 }
1560
1561 #[test]
1565 fn monotone_constraints_invalid_value_rejected() {
1566 let result = SGBTConfig::builder()
1567 .n_steps(10)
1568 .monotone_constraints(vec![1, 2, 0])
1569 .build();
1570
1571 assert!(result.is_err(), "constraint value 2 should be rejected");
1572 }
1573
1574 #[test]
1578 fn gradient_clip_sigma_negative_rejected() {
1579 let result = SGBTConfig::builder()
1580 .n_steps(10)
1581 .gradient_clip_sigma(-1.0)
1582 .build();
1583
1584 assert!(result.is_err(), "negative sigma should be rejected");
1585 }
1586
1587 #[test]
1591 fn gradient_clipping_reduces_outlier_impact() {
1592 let config_no_clip = SGBTConfig::builder()
1594 .n_steps(5)
1595 .grace_period(10)
1596 .build()
1597 .unwrap();
1598 let mut model_no_clip = SGBT::new(config_no_clip);
1599
1600 let config_clip = SGBTConfig::builder()
1602 .n_steps(5)
1603 .grace_period(10)
1604 .gradient_clip_sigma(3.0)
1605 .build()
1606 .unwrap();
1607 let mut model_clip = SGBT::new(config_clip);
1608
1609 for i in 0..100 {
1611 let x = (i as f64) * 0.01;
1612 let sample = (&[x][..], x * 2.0);
1613 model_no_clip.train_one(&sample);
1614 model_clip.train_one(&sample);
1615 }
1616
1617 let pred_no_clip_before = model_no_clip.predict(&[0.5]);
1618 let pred_clip_before = model_clip.predict(&[0.5]);
1619
1620 let outlier = (&[0.5_f64][..], 10000.0);
1622 model_no_clip.train_one(&outlier);
1623 model_clip.train_one(&outlier);
1624
1625 let pred_no_clip_after = model_no_clip.predict(&[0.5]);
1626 let pred_clip_after = model_clip.predict(&[0.5]);
1627
1628 let delta_no_clip = (pred_no_clip_after - pred_no_clip_before).abs();
1629 let delta_clip = (pred_clip_after - pred_clip_before).abs();
1630
1631 assert!(
1633 delta_clip <= delta_no_clip + 1e-10,
1634 "clipped model should be less affected: delta_clip={}, delta_no_clip={}",
1635 delta_clip,
1636 delta_no_clip,
1637 );
1638 }
1639
1640 #[test]
1644 fn train_batch_with_callback_fires() {
1645 let config = SGBTConfig::builder()
1646 .n_steps(3)
1647 .grace_period(5)
1648 .build()
1649 .unwrap();
1650 let mut model = SGBT::new(config);
1651
1652 let data: Vec<(Vec<f64>, f64)> = (0..25)
1653 .map(|i| (vec![i as f64 * 0.1], i as f64 * 0.5))
1654 .collect();
1655
1656 let mut callbacks = Vec::new();
1657 model.train_batch_with_callback(&data, 10, |n| {
1658 callbacks.push(n);
1659 });
1660
1661 assert_eq!(callbacks, vec![10, 20, 25]);
1663 }
1664
1665 #[test]
1669 fn train_batch_subsampled_trains_subset() {
1670 let config = SGBTConfig::builder()
1671 .n_steps(3)
1672 .grace_period(5)
1673 .build()
1674 .unwrap();
1675 let mut model = SGBT::new(config);
1676
1677 let data: Vec<(Vec<f64>, f64)> = (0..100)
1678 .map(|i| (vec![i as f64 * 0.01], i as f64 * 0.1))
1679 .collect();
1680
1681 model.train_batch_subsampled(&data, 20);
1683
1684 assert!(
1686 model.n_samples_seen() > 0,
1687 "model should have trained on subset"
1688 );
1689 assert!(
1690 model.n_samples_seen() <= 20,
1691 "model should have trained at most 20 samples, got {}",
1692 model.n_samples_seen(),
1693 );
1694 }
1695
1696 #[test]
1700 fn train_batch_subsampled_full_equals_batch() {
1701 let config1 = SGBTConfig::builder()
1702 .n_steps(3)
1703 .grace_period(5)
1704 .build()
1705 .unwrap();
1706 let config2 = config1.clone();
1707
1708 let mut model1 = SGBT::new(config1);
1709 let mut model2 = SGBT::new(config2);
1710
1711 let data: Vec<(Vec<f64>, f64)> = (0..50)
1712 .map(|i| (vec![i as f64 * 0.1], i as f64 * 0.5))
1713 .collect();
1714
1715 model1.train_batch(&data);
1716 model2.train_batch_subsampled(&data, 1000); assert_eq!(model1.n_samples_seen(), model2.n_samples_seen());
1720 let pred1 = model1.predict(&[2.5]);
1721 let pred2 = model2.predict(&[2.5]);
1722 assert!(
1723 (pred1 - pred2).abs() < 1e-12,
1724 "full subsample should equal batch: {} vs {}",
1725 pred1,
1726 pred2,
1727 );
1728 }
1729
1730 #[test]
1734 fn train_batch_subsampled_with_callback_works() {
1735 let config = SGBTConfig::builder()
1736 .n_steps(3)
1737 .grace_period(5)
1738 .build()
1739 .unwrap();
1740 let mut model = SGBT::new(config);
1741
1742 let data: Vec<(Vec<f64>, f64)> = (0..200)
1743 .map(|i| (vec![i as f64 * 0.01], i as f64 * 0.1))
1744 .collect();
1745
1746 let mut callbacks = Vec::new();
1747 model.train_batch_subsampled_with_callback(&data, 50, 10, |n| {
1748 callbacks.push(n);
1749 });
1750
1751 assert!(!callbacks.is_empty(), "should have received callbacks");
1753 assert_eq!(
1754 *callbacks.last().unwrap(),
1755 50,
1756 "final callback should be total samples"
1757 );
1758 }
1759
1760 fn xorshift64(state: &mut u64) -> u64 {
1766 let mut s = *state;
1767 s ^= s << 13;
1768 s ^= s >> 7;
1769 s ^= s << 17;
1770 *state = s;
1771 s
1772 }
1773
1774 fn rand_f64(state: &mut u64) -> f64 {
1775 xorshift64(state) as f64 / u64::MAX as f64
1776 }
1777
1778 fn linear_leaves_config() -> SGBTConfig {
1779 SGBTConfig::builder()
1780 .n_steps(10)
1781 .learning_rate(0.1)
1782 .grace_period(20)
1783 .max_depth(2) .n_bins(16)
1785 .leaf_model_type(crate::tree::leaf_model::LeafModelType::Linear {
1786 learning_rate: 0.1,
1787 decay: None,
1788 use_adagrad: false,
1789 })
1790 .build()
1791 .unwrap()
1792 }
1793
1794 #[test]
1795 fn linear_leaves_trains_without_panic() {
1796 let mut model = SGBT::new(linear_leaves_config());
1797 let mut rng = 42u64;
1798 for _ in 0..200 {
1799 let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1800 let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1801 let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1802 model.train_one(&Sample::new(vec![x1, x2], y));
1803 }
1804 assert_eq!(model.n_samples_seen(), 200);
1805 }
1806
1807 #[test]
1808 fn linear_leaves_prediction_finite() {
1809 let mut model = SGBT::new(linear_leaves_config());
1810 let mut rng = 42u64;
1811 for _ in 0..200 {
1812 let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1813 let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1814 let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1815 model.train_one(&Sample::new(vec![x1, x2], y));
1816 }
1817 let pred = model.predict(&[0.5, -0.3]);
1818 assert!(pred.is_finite(), "prediction should be finite, got {pred}");
1819 }
1820
1821 #[test]
1822 fn linear_leaves_learns_linear_target() {
1823 let mut model = SGBT::new(linear_leaves_config());
1824 let mut rng = 42u64;
1825 for _ in 0..500 {
1826 let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1827 let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1828 let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1829 model.train_one(&Sample::new(vec![x1, x2], y));
1830 }
1831
1832 let mut total_error = 0.0;
1834 for _ in 0..50 {
1835 let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1836 let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1837 let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1838 let pred = model.predict(&[x1, x2]);
1839 total_error += (pred - y).powi(2);
1840 }
1841 let mse = total_error / 50.0;
1842 assert!(
1843 mse < 5.0,
1844 "linear leaves MSE on linear target should be < 5.0, got {mse}"
1845 );
1846 }
1847
1848 #[test]
1849 fn linear_leaves_better_than_constant_at_low_depth() {
1850 let constant_config = SGBTConfig::builder()
1853 .n_steps(10)
1854 .learning_rate(0.1)
1855 .grace_period(20)
1856 .max_depth(2)
1857 .n_bins(16)
1858 .seed(0xDEAD)
1859 .build()
1860 .unwrap();
1861 let linear_config = SGBTConfig::builder()
1862 .n_steps(10)
1863 .learning_rate(0.1)
1864 .grace_period(20)
1865 .max_depth(2)
1866 .n_bins(16)
1867 .seed(0xDEAD)
1868 .leaf_model_type(crate::tree::leaf_model::LeafModelType::Linear {
1869 learning_rate: 0.1,
1870 decay: None,
1871 use_adagrad: false,
1872 })
1873 .build()
1874 .unwrap();
1875
1876 let mut constant_model = SGBT::new(constant_config);
1877 let mut linear_model = SGBT::new(linear_config);
1878 let mut rng = 42u64;
1879
1880 for _ in 0..500 {
1881 let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1882 let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1883 let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1884 let sample = Sample::new(vec![x1, x2], y);
1885 constant_model.train_one(&sample);
1886 linear_model.train_one(&sample);
1887 }
1888
1889 let mut constant_mse = 0.0;
1891 let mut linear_mse = 0.0;
1892 for _ in 0..100 {
1893 let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1894 let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1895 let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1896 constant_mse += (constant_model.predict(&[x1, x2]) - y).powi(2);
1897 linear_mse += (linear_model.predict(&[x1, x2]) - y).powi(2);
1898 }
1899 constant_mse /= 100.0;
1900 linear_mse /= 100.0;
1901
1902 assert!(
1904 linear_mse < constant_mse,
1905 "linear leaves MSE ({linear_mse:.4}) should be less than constant ({constant_mse:.4})"
1906 );
1907 }
1908
1909 #[test]
1910 fn adaptive_leaves_trains_without_panic() {
1911 let config = SGBTConfig::builder()
1912 .n_steps(10)
1913 .learning_rate(0.1)
1914 .grace_period(20)
1915 .max_depth(3)
1916 .n_bins(16)
1917 .leaf_model_type(crate::tree::leaf_model::LeafModelType::Adaptive {
1918 promote_to: Box::new(crate::tree::leaf_model::LeafModelType::Linear {
1919 learning_rate: 0.1,
1920 decay: None,
1921 use_adagrad: false,
1922 }),
1923 })
1924 .build()
1925 .unwrap();
1926
1927 let mut model = SGBT::new(config);
1928 let mut rng = 42u64;
1929 for _ in 0..500 {
1930 let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1931 let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1932 let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1933 model.train_one(&Sample::new(vec![x1, x2], y));
1934 }
1935 let pred = model.predict(&[0.5, -0.3]);
1936 assert!(
1937 pred.is_finite(),
1938 "adaptive leaf prediction should be finite, got {pred}"
1939 );
1940 }
1941
1942 #[test]
1943 fn linear_leaves_with_decay_trains_without_panic() {
1944 let config = SGBTConfig::builder()
1945 .n_steps(10)
1946 .learning_rate(0.1)
1947 .grace_period(20)
1948 .max_depth(3)
1949 .n_bins(16)
1950 .leaf_model_type(crate::tree::leaf_model::LeafModelType::Linear {
1951 learning_rate: 0.1,
1952 decay: Some(0.995),
1953 use_adagrad: false,
1954 })
1955 .build()
1956 .unwrap();
1957
1958 let mut model = SGBT::new(config);
1959 let mut rng = 42u64;
1960 for _ in 0..500 {
1961 let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1962 let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1963 let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1964 model.train_one(&Sample::new(vec![x1, x2], y));
1965 }
1966 let pred = model.predict(&[0.5, -0.3]);
1967 assert!(
1968 pred.is_finite(),
1969 "decay leaf prediction should be finite, got {pred}"
1970 );
1971 }
1972
1973 #[test]
1977 fn predict_smooth_returns_finite() {
1978 let config = SGBTConfig::builder()
1979 .n_steps(5)
1980 .learning_rate(0.1)
1981 .grace_period(10)
1982 .build()
1983 .unwrap();
1984 let mut model = SGBT::new(config);
1985
1986 for i in 0..200 {
1987 let x = (i as f64) * 0.1;
1988 model.train_one(&Sample::new(vec![x, x.sin()], 2.0 * x + 1.0));
1989 }
1990
1991 let pred_hard = model.predict(&[1.0, 1.0_f64.sin()]);
1992 let pred_smooth = model.predict_smooth(&[1.0, 1.0_f64.sin()], 0.5);
1993
1994 assert!(pred_hard.is_finite(), "hard prediction should be finite");
1995 assert!(
1996 pred_smooth.is_finite(),
1997 "smooth prediction should be finite"
1998 );
1999 }
2000
2001 #[test]
2005 fn predict_smooth_converges_to_hard_at_small_bandwidth() {
2006 let config = SGBTConfig::builder()
2007 .n_steps(5)
2008 .learning_rate(0.1)
2009 .grace_period(10)
2010 .build()
2011 .unwrap();
2012 let mut model = SGBT::new(config);
2013
2014 for i in 0..300 {
2015 let x = (i as f64) * 0.1;
2016 model.train_one(&Sample::new(vec![x, x * 0.5], 2.0 * x + 1.0));
2017 }
2018
2019 let features = [5.0, 2.5];
2020 let hard = model.predict(&features);
2021 let smooth = model.predict_smooth(&features, 0.001);
2022
2023 assert!(
2024 (hard - smooth).abs() < 0.5,
2025 "smooth with tiny bandwidth should approximate hard: hard={}, smooth={}",
2026 hard,
2027 smooth,
2028 );
2029 }
2030
2031 #[test]
2032 fn auto_bandwidth_computed_after_training() {
2033 let config = SGBTConfig::builder()
2034 .n_steps(5)
2035 .learning_rate(0.1)
2036 .grace_period(10)
2037 .build()
2038 .unwrap();
2039 let mut model = SGBT::new(config);
2040
2041 assert!(model.auto_bandwidths().is_empty());
2043
2044 for i in 0..200 {
2045 let x = (i as f64) * 0.1;
2046 model.train_one(&Sample::new(vec![x, x * 0.5], 2.0 * x + 1.0));
2047 }
2048
2049 let bws = model.auto_bandwidths();
2051 assert_eq!(bws.len(), 2, "should have 2 feature bandwidths");
2052
2053 let pred = model.predict(&[5.0, 2.5]);
2055 assert!(
2056 pred.is_finite(),
2057 "auto-bandwidth predict should be finite: {}",
2058 pred
2059 );
2060 }
2061
2062 #[test]
2063 fn predict_interpolated_returns_finite() {
2064 let config = SGBTConfig::builder()
2065 .n_steps(5)
2066 .learning_rate(0.01)
2067 .build()
2068 .unwrap();
2069 let mut model = SGBT::new(config);
2070
2071 for i in 0..200 {
2072 let x = (i as f64) * 0.1;
2073 model.train_one(&Sample::new(vec![x, x.sin()], x.cos()));
2074 }
2075
2076 let pred = model.predict_interpolated(&[1.0, 0.5]);
2077 assert!(
2078 pred.is_finite(),
2079 "interpolated prediction should be finite: {}",
2080 pred
2081 );
2082 }
2083
2084 #[test]
2085 fn predict_sibling_interpolated_varies_with_features() {
2086 let config = SGBTConfig::builder()
2087 .n_steps(10)
2088 .learning_rate(0.1)
2089 .grace_period(10)
2090 .max_depth(6)
2091 .delta(0.1)
2092 .build()
2093 .unwrap();
2094 let mut model = SGBT::new(config);
2095
2096 for i in 0..2000 {
2097 let x = (i as f64) * 0.01;
2098 let y = x.sin() * x + 0.5 * (x * 2.0).cos();
2099 model.train_one(&Sample::new(vec![x, x * 0.3], y));
2100 }
2101
2102 let pred = model.predict_sibling_interpolated(&[5.0, 1.5]);
2104 assert!(pred.is_finite(), "sibling interpolated should be finite");
2105
2106 let bws = model.auto_bandwidths();
2109 if bws.iter().any(|&b| b.is_finite()) {
2110 let hard: Vec<f64> = (0..200)
2111 .map(|i| model.predict(&[i as f64 * 0.1, i as f64 * 0.03]))
2112 .collect();
2113 let sib: Vec<f64> = (0..200)
2114 .map(|i| model.predict_sibling_interpolated(&[i as f64 * 0.1, i as f64 * 0.03]))
2115 .collect();
2116 let hc = hard
2117 .windows(2)
2118 .filter(|w| (w[0] - w[1]).abs() > f64::EPSILON)
2119 .count();
2120 let sc = sib
2121 .windows(2)
2122 .filter(|w| (w[0] - w[1]).abs() > f64::EPSILON)
2123 .count();
2124 assert!(
2125 sc >= hc,
2126 "sibling should produce >= hard changes: sib={}, hard={}",
2127 sc,
2128 hc
2129 );
2130 }
2131 }
2132
2133 #[test]
2134 fn predict_graduated_returns_finite() {
2135 let config = SGBTConfig::builder()
2136 .n_steps(5)
2137 .learning_rate(0.01)
2138 .max_tree_samples(200)
2139 .shadow_warmup(50)
2140 .build()
2141 .unwrap();
2142 let mut model = SGBT::new(config);
2143
2144 for i in 0..300 {
2145 let x = (i as f64) * 0.1;
2146 model.train_one(&Sample::new(vec![x, x.sin()], x.cos()));
2147 }
2148
2149 let pred = model.predict_graduated(&[1.0, 0.5]);
2150 assert!(
2151 pred.is_finite(),
2152 "graduated prediction should be finite: {}",
2153 pred
2154 );
2155
2156 let pred2 = model.predict_graduated_sibling_interpolated(&[1.0, 0.5]);
2157 assert!(
2158 pred2.is_finite(),
2159 "graduated+sibling prediction should be finite: {}",
2160 pred2
2161 );
2162 }
2163
2164 #[test]
2165 fn shadow_warmup_validation() {
2166 let result = SGBTConfig::builder()
2167 .n_steps(5)
2168 .learning_rate(0.01)
2169 .shadow_warmup(0)
2170 .build();
2171 assert!(result.is_err(), "shadow_warmup=0 should fail validation");
2172 }
2173}