1pub mod adaptive;
20pub mod adaptive_forest;
21pub mod bagged;
22pub mod config;
23pub mod distributional;
24pub mod lr_schedule;
25pub mod moe;
26pub mod moe_distributional;
27pub mod multi_target;
28pub mod multiclass;
29#[cfg(feature = "parallel")]
30pub mod parallel;
31pub mod quantile_regressor;
32pub mod replacement;
33pub mod stacked;
34pub mod step;
35pub mod variants;
36
37use alloc::boxed::Box;
38use alloc::string::String;
39use alloc::vec;
40use alloc::vec::Vec;
41
42use core::fmt;
43
44use crate::ensemble::config::SGBTConfig;
45use crate::ensemble::step::BoostingStep;
46use crate::loss::squared::SquaredLoss;
47use crate::loss::Loss;
48use crate::sample::Observation;
49#[allow(unused_imports)] use crate::sample::Sample;
51use crate::tree::builder::TreeConfig;
52
/// An `SGBT` whose loss is selected at runtime via a boxed trait object.
pub type DynSGBT = SGBT<Box<dyn Loss>>;
61
/// Streaming gradient-boosted trees: an online boosting ensemble trained one
/// observation at a time.
///
/// `L` is the loss driving gradient/hessian updates (default [`SquaredLoss`]).
pub struct SGBT<L: Loss = SquaredLoss> {
    /// Ensemble-wide hyperparameters.
    config: SGBTConfig,
    /// One boosting step (tree slot + drift detector) per configured step.
    steps: Vec<BoostingStep>,
    /// Loss used for gradients, hessians and the initial base prediction.
    loss: L,
    /// Constant offset added to every prediction; derived from the first
    /// `initial_target_count` targets via `Loss::initial_prediction`.
    base_prediction: f64,
    /// Whether `base_prediction` has been computed yet.
    base_initialized: bool,
    /// Warmup buffer of targets; cleared once the base is initialized.
    initial_targets: Vec<f64>,
    /// Number of targets buffered before initializing the base prediction.
    initial_target_count: usize,
    /// Total number of observations processed by `train_one`.
    samples_seen: u64,
    /// xorshift64 state used for subsampling and the boosting variant.
    rng_state: u64,
    /// Per-step EWMA of |learning_rate * step prediction|, used by the
    /// quality/proactive pruning paths (empty when pruning is disabled).
    contribution_ewma: Vec<f64>,
    /// Per-step count of consecutive low-contribution samples.
    low_contrib_count: Vec<u64>,
    /// Rolling mean absolute error for error-based sample weighting.
    rolling_mean_error: f64,
    /// Slow EWMA of the per-step contribution spread (adaptive MTS path).
    rolling_contribution_sigma: f64,
    /// Per-feature bandwidths for smoothed prediction; empty until computed.
    auto_bandwidths: Vec<f64>,
    /// Sum of tree replacements at the last bandwidth refresh.
    last_replacement_sum: u64,
}
127
/// Manual `Clone` available whenever the loss itself is `Clone`; a plain
/// field-by-field deep clone producing a fully independent model (equivalent
/// to what `#[derive(Clone)]` would generate).
impl<L: Loss + Clone> Clone for SGBT<L> {
    fn clone(&self) -> Self {
        Self {
            config: self.config.clone(),
            steps: self.steps.clone(),
            loss: self.loss.clone(),
            base_prediction: self.base_prediction,
            base_initialized: self.base_initialized,
            initial_targets: self.initial_targets.clone(),
            initial_target_count: self.initial_target_count,
            samples_seen: self.samples_seen,
            rng_state: self.rng_state,
            contribution_ewma: self.contribution_ewma.clone(),
            low_contrib_count: self.low_contrib_count.clone(),
            rolling_mean_error: self.rolling_mean_error,
            rolling_contribution_sigma: self.rolling_contribution_sigma,
            auto_bandwidths: self.auto_bandwidths.clone(),
            last_replacement_sum: self.last_replacement_sum,
        }
    }
}
149
/// Compact `Debug` reporting summary statistics rather than full tree
/// contents; places no `Debug` bound on the loss `L`.
impl<L: Loss> fmt::Debug for SGBT<L> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("SGBT")
            .field("n_steps", &self.steps.len())
            .field("samples_seen", &self.samples_seen)
            .field("base_prediction", &self.base_prediction)
            .field("base_initialized", &self.base_initialized)
            .finish()
    }
}
160
impl SGBT<SquaredLoss> {
    /// Creates a regression model using the default [`SquaredLoss`].
    pub fn new(config: SGBTConfig) -> Self {
        Self::with_loss(config, SquaredLoss)
    }
}
174
175impl<L: Loss> SGBT<L> {
    /// Creates a model with the given configuration and a custom loss.
    ///
    /// Translates the ensemble-level [`SGBTConfig`] into a per-tree
    /// `TreeConfig`, then builds one [`BoostingStep`] per configured step,
    /// each with its own drift detector and a distinct RNG seed.
    pub fn with_loss(config: SGBTConfig, loss: L) -> Self {
        // Convert a half-life (in samples) into a per-sample decay factor:
        // alpha = exp(-ln(2) / half_life).
        let leaf_decay_alpha = config
            .leaf_half_life
            .map(|hl| crate::math::exp(-crate::math::ln(2.0) / hl as f64));

        let tree_config = TreeConfig::new()
            .max_depth(config.max_depth)
            .n_bins(config.n_bins)
            .lambda(config.lambda)
            .gamma(config.gamma)
            .grace_period(config.grace_period)
            .delta(config.delta)
            .feature_subsample_rate(config.feature_subsample_rate)
            .leaf_decay_alpha_opt(leaf_decay_alpha)
            .split_reeval_interval_opt(config.split_reeval_interval)
            .feature_types_opt(config.feature_types.clone())
            .gradient_clip_sigma_opt(config.gradient_clip_sigma)
            .monotone_constraints_opt(config.monotone_constraints.clone())
            .max_leaf_output_opt(config.max_leaf_output)
            .adaptive_leaf_bound_opt(config.adaptive_leaf_bound)
            .adaptive_depth_opt(config.adaptive_depth)
            .min_hessian_sum_opt(config.min_hessian_sum)
            .leaf_model_type(config.leaf_model_type.clone());

        let max_tree_samples = config.max_tree_samples;

        // A positive shadow warmup enables graduated tree replacement.
        let shadow_warmup = config.shadow_warmup.unwrap_or(0);
        let steps: Vec<BoostingStep> = (0..config.n_steps)
            .map(|i| {
                let mut tc = tree_config.clone();
                // Decorrelate per-step trees with distinct seeds.
                tc.seed = config.seed ^ (i as u64);
                let detector = config.drift_detector.create();
                if shadow_warmup > 0 {
                    BoostingStep::new_with_graduated(tc, detector, max_tree_samples, shadow_warmup)
                } else {
                    BoostingStep::new_with_max_samples(tc, detector, max_tree_samples)
                }
            })
            .collect();

        let seed = config.seed;
        let initial_target_count = config.initial_target_count;
        let n = config.n_steps;
        // Pruning bookkeeping vectors are only allocated when some pruning
        // mode is enabled.
        let has_pruning =
            config.quality_prune_alpha.is_some() || config.proactive_prune_interval.is_some();
        Self {
            config,
            steps,
            loss,
            base_prediction: 0.0,
            base_initialized: false,
            initial_targets: Vec::new(),
            initial_target_count,
            samples_seen: 0,
            rng_state: seed,
            contribution_ewma: if has_pruning {
                vec![0.0; n]
            } else {
                Vec::new()
            },
            low_contrib_count: if has_pruning { vec![0; n] } else { Vec::new() },
            rolling_mean_error: 0.0,
            rolling_contribution_sigma: 0.0,
            auto_bandwidths: Vec::new(),
            last_replacement_sum: 0,
        }
    }
259
    /// Processes a single labelled observation through the full online
    /// boosting pipeline: base-prediction warmup, optional adaptive
    /// max-tree-samples, optional error-based sample weighting, one
    /// gradient/hessian update per boosting step, optional quality /
    /// proactive pruning, and finally a bandwidth refresh.
    pub fn train_one(&mut self, sample: &impl Observation) {
        self.samples_seen += 1;
        let target = sample.target();
        let features = sample.features();

        // Warmup: buffer the first `initial_target_count` targets and let the
        // loss derive the base prediction from them. Boosting still runs
        // below during warmup with the current (initially 0.0) base.
        if !self.base_initialized {
            self.initial_targets.push(target);
            if self.initial_targets.len() >= self.initial_target_count {
                self.base_prediction = self.loss.initial_prediction(&self.initial_targets);
                self.base_initialized = true;
                self.initial_targets.clear();
                self.initial_targets.shrink_to_fit();
            }
        }

        let mut current_pred = self.base_prediction;

        // Adaptive max-tree-samples: shrink the per-tree sample budget when
        // per-step contributions disagree (high spread), so trees turn over
        // faster under apparent drift.
        if let Some((base_mts, k)) = self.config.adaptive_mts {
            let sigma = self.contribution_variance(features);
            // Slow EWMA (alpha = 0.001) of the contribution spread.
            self.rolling_contribution_sigma =
                0.999 * self.rolling_contribution_sigma + 0.001 * sigma;

            let normalized = if self.rolling_contribution_sigma > 1e-10 {
                sigma / self.rolling_contribution_sigma
            } else {
                1.0
            };
            let factor = 1.0 / (1.0 + k * normalized);
            // Never shrink below the configured floor or twice the grace
            // period (a tree needs at least that many samples to split).
            let floor = (base_mts as f64 * self.config.adaptive_mts_floor)
                .max(self.config.grace_period as f64 * 2.0);
            let effective_mts = ((base_mts as f64) * factor).max(floor) as u64;
            for step in &mut self.steps {
                step.slot_mut().set_max_tree_samples(Some(effective_mts));
            }
        }

        // Pruning is active if either quality or proactive pruning is
        // configured; proactive-only falls back to a default EWMA alpha 0.01.
        let prune_alpha = self
            .config
            .quality_prune_alpha
            .or_else(|| self.config.proactive_prune_interval.map(|_| 0.01));
        let prune_threshold = self.config.quality_prune_threshold;
        let prune_patience = self.config.quality_prune_patience;

        // Optional error-based sample weighting: up-weight (capped at 10x)
        // samples whose absolute error exceeds the rolling mean error.
        let error_weight = if let Some(ew_alpha) = self.config.error_weight_alpha {
            let abs_error = crate::math::abs(target - current_pred);
            if self.rolling_mean_error > 1e-15 {
                let w = (1.0 + abs_error / (self.rolling_mean_error + 1e-15)).min(10.0);
                self.rolling_mean_error =
                    ew_alpha * abs_error + (1.0 - ew_alpha) * self.rolling_mean_error;
                w
            } else {
                // First observation seeds the rolling error, weight stays 1.
                self.rolling_mean_error = abs_error.max(1e-15);
                1.0 }
        } else {
            1.0
        };

        // Sequential boosting: each step fits the gradient at the running
        // prediction, which already includes all earlier steps' outputs.
        for s in 0..self.steps.len() {
            let gradient = self.loss.gradient(target, current_pred) * error_weight;
            let hessian = self.loss.hessian(target, current_pred) * error_weight;
            // The variant decides how many times this sample is fed to the
            // tree (e.g. subsampling schemes may return 0 or >1).
            let train_count = self
                .config
                .variant
                .train_count(hessian, &mut self.rng_state);

            let step_pred =
                self.steps[s].train_and_predict(features, gradient, hessian, train_count);

            current_pred += self.config.learning_rate * step_pred;

            // Quality pruning: EWMA of each step's absolute contribution;
            // steps below threshold for `prune_patience` consecutive samples
            // are reset.
            if let Some(alpha) = prune_alpha {
                let contribution = crate::math::abs(self.config.learning_rate * step_pred);
                self.contribution_ewma[s] =
                    alpha * contribution + (1.0 - alpha) * self.contribution_ewma[s];

                if self.contribution_ewma[s] < prune_threshold {
                    self.low_contrib_count[s] += 1;
                    if self.low_contrib_count[s] >= prune_patience {
                        self.steps[s].reset();
                        self.contribution_ewma[s] = 0.0;
                        self.low_contrib_count[s] = 0;
                    }
                } else {
                    self.low_contrib_count[s] = 0;
                }
            }
        }

        // Proactive pruning: every `interval` samples, reset the single
        // lowest-contribution step that is at least `interval / 2` samples
        // old (age gate avoids re-pruning freshly reset steps).
        if let Some(interval) = self.config.proactive_prune_interval {
            if self.samples_seen % interval == 0
                && self.samples_seen > 0
                && !self.contribution_ewma.is_empty()
            {
                let min_age = interval / 2;
                let worst_idx = self
                    .steps
                    .iter()
                    .enumerate()
                    .zip(self.contribution_ewma.iter())
                    .filter(|((_, step), _)| step.n_samples_seen() >= min_age)
                    .min_by(|((_, _), a_ewma), ((_, _), b_ewma)| {
                        a_ewma
                            .partial_cmp(b_ewma)
                            .unwrap_or(core::cmp::Ordering::Equal)
                    })
                    .map(|((i, _), _)| i);

                if let Some(idx) = worst_idx {
                    self.steps[idx].reset();
                    self.contribution_ewma[idx] = 0.0;
                    self.low_contrib_count[idx] = 0;
                }
            }
        }

        // Keep smoothed-prediction bandwidths in sync with tree replacements.
        self.refresh_bandwidths();
    }
391
392 pub fn train_batch<O: Observation>(&mut self, samples: &[O]) {
394 for sample in samples {
395 self.train_one(sample);
396 }
397 }
398
399 pub fn train_batch_with_callback<O: Observation, F: FnMut(usize)>(
420 &mut self,
421 samples: &[O],
422 interval: usize,
423 mut callback: F,
424 ) {
425 let interval = interval.max(1); for (i, sample) in samples.iter().enumerate() {
427 self.train_one(sample);
428 if (i + 1) % interval == 0 {
429 callback(i + 1);
430 }
431 }
432 let total = samples.len();
434 if total % interval != 0 {
435 callback(total);
436 }
437 }
438
439 pub fn train_batch_subsampled<O: Observation>(&mut self, samples: &[O], max_samples: usize) {
452 if max_samples >= samples.len() {
453 self.train_batch(samples);
454 return;
455 }
456
457 let mut reservoir: Vec<usize> = (0..max_samples).collect();
459 let mut rng = self.rng_state;
460
461 for i in max_samples..samples.len() {
462 rng ^= rng << 13;
464 rng ^= rng >> 7;
465 rng ^= rng << 17;
466 let j = (rng % (i as u64 + 1)) as usize;
467 if j < max_samples {
468 reservoir[j] = i;
469 }
470 }
471
472 self.rng_state = rng;
473
474 reservoir.sort_unstable();
476
477 for &idx in &reservoir {
479 self.train_one(&samples[idx]);
480 }
481 }
482
483 pub fn train_batch_subsampled_with_callback<O: Observation, F: FnMut(usize)>(
489 &mut self,
490 samples: &[O],
491 max_samples: usize,
492 interval: usize,
493 mut callback: F,
494 ) {
495 if max_samples >= samples.len() {
496 self.train_batch_with_callback(samples, interval, callback);
497 return;
498 }
499
500 let mut reservoir: Vec<usize> = (0..max_samples).collect();
502 let mut rng = self.rng_state;
503
504 for i in max_samples..samples.len() {
505 rng ^= rng << 13;
506 rng ^= rng >> 7;
507 rng ^= rng << 17;
508 let j = (rng % (i as u64 + 1)) as usize;
509 if j < max_samples {
510 reservoir[j] = i;
511 }
512 }
513
514 self.rng_state = rng;
515 reservoir.sort_unstable();
516
517 let interval = interval.max(1);
518 for (i, &idx) in reservoir.iter().enumerate() {
519 self.train_one(&samples[idx]);
520 if (i + 1) % interval == 0 {
521 callback(i + 1);
522 }
523 }
524 let total = reservoir.len();
525 if total % interval != 0 {
526 callback(total);
527 }
528 }
529
530 pub fn predict(&self, features: &[f64]) -> f64 {
536 let mut pred = self.base_prediction;
537 if self.auto_bandwidths.is_empty() {
538 for step in &self.steps {
540 pred += self.config.learning_rate * step.predict(features);
541 }
542 } else {
543 for step in &self.steps {
544 pred += self.config.learning_rate
545 * step.predict_smooth_auto(features, &self.auto_bandwidths);
546 }
547 }
548 pred
549 }
550
551 pub fn predict_smooth(&self, features: &[f64], bandwidth: f64) -> f64 {
556 let mut pred = self.base_prediction;
557 for step in &self.steps {
558 pred += self.config.learning_rate * step.predict_smooth(features, bandwidth);
559 }
560 pred
561 }
562
    /// Current per-feature smoothing bandwidths; may be empty while no tree
    /// has produced usable split thresholds yet.
    pub fn auto_bandwidths(&self) -> &[f64] {
        &self.auto_bandwidths
    }
571
572 pub fn predict_interpolated(&self, features: &[f64]) -> f64 {
577 let mut pred = self.base_prediction;
578 for step in &self.steps {
579 pred += self.config.learning_rate * step.predict_interpolated(features);
580 }
581 pred
582 }
583
584 pub fn predict_sibling_interpolated(&self, features: &[f64]) -> f64 {
592 let mut pred = self.base_prediction;
593 for step in &self.steps {
594 pred += self.config.learning_rate
595 * step.predict_sibling_interpolated(features, &self.auto_bandwidths);
596 }
597 pred
598 }
599
600 pub fn predict_graduated(&self, features: &[f64]) -> f64 {
606 let mut pred = self.base_prediction;
607 for step in &self.steps {
608 pred += self.config.learning_rate * step.predict_graduated(features);
609 }
610 pred
611 }
612
613 pub fn predict_graduated_sibling_interpolated(&self, features: &[f64]) -> f64 {
619 let mut pred = self.base_prediction;
620 for step in &self.steps {
621 pred += self.config.learning_rate
622 * step.predict_graduated_sibling_interpolated(features, &self.auto_bandwidths);
623 }
624 pred
625 }
626
    /// Prediction passed through the loss's output transform (e.g. a
    /// link-function inverse for classification losses).
    pub fn predict_transformed(&self, features: &[f64]) -> f64 {
        self.loss.predict_transform(self.predict(features))
    }

    /// Alias for [`Self::predict_transformed`]; for probabilistic losses the
    /// transformed output is interpreted as a probability.
    pub fn predict_proba(&self, features: &[f64]) -> f64 {
        self.predict_transformed(features)
    }
636
    /// Returns `(prediction, confidence)` where confidence is the inverse
    /// standard deviation of the accumulated per-step prediction variance
    /// (0.0 when no variance is available or the total is non-finite).
    pub fn predict_with_confidence(&self, features: &[f64]) -> (f64, f64) {
        let mut pred = self.base_prediction;
        let mut total_variance = 0.0;
        // Variances scale by the squared learning rate.
        let lr2 = self.config.learning_rate * self.config.learning_rate;

        for step in &self.steps {
            let (value, variance) = step.predict_with_variance(features);
            pred += self.config.learning_rate * value;
            total_variance += lr2 * variance;
        }

        let confidence = if total_variance > 0.0 && total_variance.is_finite() {
            1.0 / crate::math::sqrt(total_variance)
        } else {
            0.0
        };

        (pred, confidence)
    }
670
671 pub fn predict_batch(&self, feature_matrix: &[Vec<f64>]) -> Vec<f64> {
673 feature_matrix.iter().map(|f| self.predict(f)).collect()
674 }
675
    /// Number of boosting steps in the ensemble.
    pub fn n_steps(&self) -> usize {
        self.steps.len()
    }

    /// Total number of live trees, counting alternate (shadow) trees.
    pub fn n_trees(&self) -> usize {
        self.steps.len() + self.steps.iter().filter(|s| s.has_alternate()).count()
    }

    /// Total number of leaves across all steps.
    pub fn total_leaves(&self) -> usize {
        self.steps.iter().map(|s| s.n_leaves()).sum()
    }

    /// Number of observations processed so far.
    pub fn n_samples_seen(&self) -> u64 {
        self.samples_seen
    }

    /// Current base (offset) prediction.
    pub fn base_prediction(&self) -> f64 {
        self.base_prediction
    }

    /// Whether the base prediction has been initialized from warmup targets.
    pub fn is_initialized(&self) -> bool {
        self.base_initialized
    }
705
    /// The model's configuration.
    pub fn config(&self) -> &SGBTConfig {
        &self.config
    }

    /// Overrides the learning rate for all subsequent training and
    /// prediction (e.g. driven by an external learning-rate schedule).
    #[inline]
    pub fn set_learning_rate(&mut self, lr: f64) {
        self.config.learning_rate = lr;
    }

    /// The boosting steps, in ensemble order.
    pub fn steps(&self) -> &[BoostingStep] {
        &self.steps
    }

    /// The loss function in use.
    pub fn loss(&self) -> &L {
        &self.loss
    }
735
736 pub fn feature_importances(&self) -> Vec<f64> {
741 let mut totals: Vec<f64> = Vec::new();
743 for step in &self.steps {
744 let gains = step.slot().split_gains();
745 if totals.is_empty() && !gains.is_empty() {
746 totals.resize(gains.len(), 0.0);
747 }
748 for (i, &g) in gains.iter().enumerate() {
749 if i < totals.len() {
750 totals[i] += g;
751 }
752 }
753 }
754
755 let sum: f64 = totals.iter().sum();
757 if sum > 0.0 {
758 totals.iter_mut().for_each(|v| *v /= sum);
759 }
760 totals
761 }
762
    /// Configured feature names, if any.
    pub fn feature_names(&self) -> Option<&[String]> {
        self.config.feature_names.as_deref()
    }
767
    /// Feature importances paired with their configured names, sorted by
    /// importance descending. Returns `None` when no names are configured.
    pub fn named_feature_importances(&self) -> Option<Vec<(String, f64)>> {
        let names = self.config.feature_names.as_ref()?;
        let importances = self.feature_importances();
        // Pad with zeros so every configured name gets an entry even when
        // fewer features have recorded gains.
        let mut pairs: Vec<(String, f64)> = names
            .iter()
            .zip(importances.iter().chain(core::iter::repeat(&0.0)))
            .map(|(n, &v)| (n.clone(), v))
            .collect();
        pairs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(core::cmp::Ordering::Equal));
        Some(pairs)
    }
783
    /// Trains on one sample supplied as a `name -> value` map; features
    /// missing from the map default to 0.0.
    ///
    /// # Panics
    /// Panics if `feature_names` is not set in the configuration.
    #[cfg(feature = "std")]
    pub fn train_one_named(
        &mut self,
        features: &std::collections::HashMap<alloc::string::String, f64>,
        target: f64,
    ) {
        let names = self
            .config
            .feature_names
            .as_ref()
            .expect("train_one_named requires feature_names to be configured");
        // Project the map into a dense vector in configured-name order.
        let vec: Vec<f64> = names
            .iter()
            .map(|name| features.get(name).copied().unwrap_or(0.0))
            .collect();
        self.train_one(&(&vec[..], target));
    }
809
    /// Predicts from a `name -> value` map; features missing from the map
    /// default to 0.0.
    ///
    /// # Panics
    /// Panics if `feature_names` is not set in the configuration.
    #[cfg(feature = "std")]
    pub fn predict_named(
        &self,
        features: &std::collections::HashMap<alloc::string::String, f64>,
    ) -> f64 {
        let names = self
            .config
            .feature_names
            .as_ref()
            .expect("predict_named requires feature_names to be configured");
        // Project the map into a dense vector in configured-name order.
        let vec: Vec<f64> = names
            .iter()
            .map(|name| features.get(name).copied().unwrap_or(0.0))
            .collect();
        self.predict(&vec)
    }
833
    /// Despite the name, this returns the *sample standard deviation* (with
    /// Bessel's correction) of the per-step contributions
    /// `learning_rate * step.predict(features)`, or 0.0 with fewer than two
    /// steps. Used by the adaptive max-tree-samples path in `train_one`.
    fn contribution_variance(&self, features: &[f64]) -> f64 {
        let n = self.steps.len();
        if n <= 1 {
            return 0.0;
        }

        let lr = self.config.learning_rate;
        let mut sum = 0.0;
        let mut sq_sum = 0.0;
        for step in &self.steps {
            let c = lr * step.predict(features);
            sum += c;
            sq_sum += c * c;
        }
        let n_f = n as f64;
        let mean = sum / n_f;
        // Population variance; `abs` guards tiny negative values from
        // floating-point cancellation before rescaling to the sample estimate.
        let var = (sq_sum / n_f) - (mean * mean);
        crate::math::sqrt((var.abs() * n_f / (n_f - 1.0)).max(0.0))
    }
863
864 fn refresh_bandwidths(&mut self) {
866 let current_sum: u64 = self.steps.iter().map(|s| s.slot().replacements()).sum();
867 if current_sum != self.last_replacement_sum || self.auto_bandwidths.is_empty() {
868 self.auto_bandwidths = self.compute_auto_bandwidths();
869 self.last_replacement_sum = current_sum;
870 }
871 }
872
    /// Derives one smoothing bandwidth per feature from the split thresholds
    /// present in the active trees.
    ///
    /// Per feature: bandwidth = `K` times the median gap between its distinct
    /// sorted thresholds; features without enough usable thresholds get
    /// `f64::INFINITY` (no finite smoothing scale).
    fn compute_auto_bandwidths(&self) -> Vec<f64> {
        // Scale factor applied to the median threshold gap.
        const K: f64 = 2.0;

        // Widest feature count across all active trees.
        let n_features = self
            .steps
            .iter()
            .filter_map(|s| s.slot().active_tree().n_features())
            .max()
            .unwrap_or(0);

        if n_features == 0 {
            return Vec::new();
        }

        // Gather every tree's split thresholds, grouped by feature.
        let mut all_thresholds: Vec<Vec<f64>> = vec![Vec::new(); n_features];

        for step in &self.steps {
            let tree_thresholds = step
                .slot()
                .active_tree()
                .collect_split_thresholds_per_feature();
            for (i, ts) in tree_thresholds.into_iter().enumerate() {
                if i < n_features {
                    all_thresholds[i].extend(ts);
                }
            }
        }

        let n_bins = self.config.n_bins as f64;

        all_thresholds
            .iter()
            .map(|ts| {
                // No splits on this feature: no finite bandwidth.
                if ts.is_empty() {
                    return f64::INFINITY; }

                let mut sorted = ts.clone();
                sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
                // Collapse near-duplicate thresholds.
                sorted.dedup_by(|a, b| crate::math::abs(*a - *b) < 1e-15);

                if sorted.len() < 2 {
                    return f64::INFINITY; }

                let mut gaps: Vec<f64> = sorted.windows(2).map(|w| w[1] - w[0]).collect();

                // Too few thresholds for a meaningful median gap: fall back
                // to K * (range / n_bins).
                if sorted.len() < 3 {
                    let range = sorted.last().unwrap() - sorted.first().unwrap();
                    if range < 1e-15 {
                        return f64::INFINITY;
                    }
                    return (range / n_bins) * K;
                }

                gaps.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
                let median_gap = if gaps.len() % 2 == 0 {
                    (gaps[gaps.len() / 2 - 1] + gaps[gaps.len() / 2]) / 2.0
                } else {
                    gaps[gaps.len() / 2]
                };

                if median_gap < 1e-15 {
                    f64::INFINITY
                } else {
                    median_gap * K
                }
            })
            .collect()
    }
958
959 pub fn reset(&mut self) {
961 for step in &mut self.steps {
962 step.reset();
963 }
964 self.base_prediction = 0.0;
965 self.base_initialized = false;
966 self.initial_targets.clear();
967 self.samples_seen = 0;
968 self.rng_state = self.config.seed;
969 self.rolling_contribution_sigma = 0.0;
970 self.auto_bandwidths.clear();
971 self.last_replacement_sum = 0;
972 }
973
    /// Serializes the model, auto-detecting the loss type from
    /// `Loss::loss_type()`.
    ///
    /// # Errors
    /// Returns a serialization error when the loss type cannot be detected;
    /// use [`Self::to_model_state_with`] to supply it explicitly.
    #[cfg(feature = "_serde_support")]
    pub fn to_model_state(&self) -> crate::error::Result<crate::serde_support::ModelState> {
        let loss_type = self.loss.loss_type().ok_or_else(|| {
            crate::error::IrithyllError::Serialization(
                "cannot auto-detect loss type for serialization: \
                 implement Loss::loss_type() or use to_model_state_with()"
                    .into(),
            )
        })?;
        Ok(self.to_model_state_with(loss_type))
    }
995
    /// Serializes the model into a `ModelState`, tagging it with the given
    /// loss type.
    ///
    /// Transient fields (`auto_bandwidths`, `rolling_contribution_sigma`,
    /// `last_replacement_sum`) are not captured; they are recomputed after
    /// deserialization.
    #[cfg(feature = "_serde_support")]
    pub fn to_model_state_with(
        &self,
        loss_type: crate::loss::LossType,
    ) -> crate::serde_support::ModelState {
        use crate::serde_support::{ModelState, StepSnapshot};

        // Snapshot every step: active tree, optional alternate tree, and the
        // drift-detector states (when a detector supports serialization).
        let steps = self
            .steps
            .iter()
            .map(|step| {
                let slot = step.slot();
                let tree_snap = snapshot_tree(slot.active_tree());
                let alt_snap = slot.alternate_tree().map(snapshot_tree);
                let drift_state = slot.detector().serialize_state();
                let alt_drift_state = slot.alt_detector().and_then(|d| d.serialize_state());
                StepSnapshot {
                    tree: tree_snap,
                    alternate_tree: alt_snap,
                    drift_state,
                    alt_drift_state,
                }
            })
            .collect();

        ModelState {
            config: self.config.clone(),
            loss_type,
            base_prediction: self.base_prediction,
            base_initialized: self.base_initialized,
            initial_targets: self.initial_targets.clone(),
            initial_target_count: self.initial_target_count,
            samples_seen: self.samples_seen,
            rng_state: self.rng_state,
            steps,
            rolling_mean_error: self.rolling_mean_error,
            contribution_ewma: self.contribution_ewma.clone(),
            low_contrib_count: self.low_contrib_count.clone(),
        }
    }
1039}
1040
#[cfg(feature = "_serde_support")]
impl SGBT<Box<dyn Loss>> {
    /// Rebuilds a model from a serialized `ModelState`, reconstructing each
    /// step's tree configuration from the saved [`SGBTConfig`] the same way
    /// [`SGBT::with_loss`] does.
    ///
    /// Transient state that is cheap to recompute (`auto_bandwidths`,
    /// `rolling_contribution_sigma`, `last_replacement_sum`) is reset rather
    /// than restored.
    pub fn from_model_state(state: crate::serde_support::ModelState) -> Self {
        use crate::ensemble::replacement::TreeSlot;

        let loss = state.loss_type.into_loss();

        // Same half-life -> per-sample decay conversion as `with_loss`.
        let leaf_decay_alpha = state
            .config
            .leaf_half_life
            .map(|hl| crate::math::exp(-crate::math::ln(2.0) / hl as f64));
        let max_tree_samples = state.config.max_tree_samples;

        let steps: Vec<BoostingStep> = state
            .steps
            .iter()
            .enumerate()
            .map(|(i, step_snap)| {
                // Mirror the TreeConfig built in `with_loss`. Previously
                // `max_leaf_output`, `adaptive_leaf_bound` and
                // `min_hessian_sum` were dropped here, so a restored model
                // was configured differently from the one that was saved.
                let tree_config = TreeConfig::new()
                    .max_depth(state.config.max_depth)
                    .n_bins(state.config.n_bins)
                    .lambda(state.config.lambda)
                    .gamma(state.config.gamma)
                    .grace_period(state.config.grace_period)
                    .delta(state.config.delta)
                    .feature_subsample_rate(state.config.feature_subsample_rate)
                    .leaf_decay_alpha_opt(leaf_decay_alpha)
                    .split_reeval_interval_opt(state.config.split_reeval_interval)
                    .feature_types_opt(state.config.feature_types.clone())
                    .gradient_clip_sigma_opt(state.config.gradient_clip_sigma)
                    .monotone_constraints_opt(state.config.monotone_constraints.clone())
                    .max_leaf_output_opt(state.config.max_leaf_output)
                    .adaptive_leaf_bound_opt(state.config.adaptive_leaf_bound)
                    .adaptive_depth_opt(state.config.adaptive_depth)
                    .min_hessian_sum_opt(state.config.min_hessian_sum)
                    .leaf_model_type(state.config.leaf_model_type.clone())
                    .seed(state.config.seed ^ (i as u64));

                let active = rebuild_tree(&step_snap.tree, tree_config.clone());
                let alternate = step_snap
                    .alternate_tree
                    .as_ref()
                    .map(|snap| rebuild_tree(snap, tree_config.clone()));

                // Restore drift-detector internals when present.
                let mut detector = state.config.drift_detector.create();
                if let Some(ref ds) = step_snap.drift_state {
                    detector.restore_state(ds);
                }
                let mut slot = TreeSlot::from_trees(
                    active,
                    alternate,
                    tree_config,
                    detector,
                    max_tree_samples,
                );
                if let Some(ref ads) = step_snap.alt_drift_state {
                    if let Some(alt_det) = slot.alt_detector_mut() {
                        alt_det.restore_state(ads);
                    }
                }
                // NOTE(review): graduated/shadow-warmup mode is not
                // reconstructed here — confirm whether `from_slot` preserves
                // it or whether dropping warmup state is intentional.
                BoostingStep::from_slot(slot)
            })
            .collect();

        let n = steps.len();
        let has_pruning = state.config.quality_prune_alpha.is_some()
            || state.config.proactive_prune_interval.is_some();

        // Prefer restored pruning statistics; otherwise allocate fresh ones
        // sized for the current step count when pruning is enabled.
        let contribution_ewma = if !state.contribution_ewma.is_empty() {
            state.contribution_ewma
        } else if has_pruning {
            vec![0.0; n]
        } else {
            Vec::new()
        };
        let low_contrib_count = if !state.low_contrib_count.is_empty() {
            state.low_contrib_count
        } else if has_pruning {
            vec![0; n]
        } else {
            Vec::new()
        };

        Self {
            config: state.config,
            steps,
            loss,
            base_prediction: state.base_prediction,
            base_initialized: state.base_initialized,
            initial_targets: state.initial_targets,
            initial_target_count: state.initial_target_count,
            samples_seen: state.samples_seen,
            rng_state: state.rng_state,
            contribution_ewma,
            low_contrib_count,
            rolling_mean_error: state.rolling_mean_error,
            rolling_contribution_sigma: 0.0,
            auto_bandwidths: Vec::new(),
            last_replacement_sum: 0,
        }
    }
}
1154
/// Captures a serializable snapshot of a Hoeffding tree: the flattened node
/// arena plus the scalar state (`n_features`, `samples_seen`, `rng_state`)
/// needed by [`rebuild_tree`].
#[cfg(feature = "_serde_support")]
pub(crate) fn snapshot_tree(
    tree: &crate::tree::hoeffding::HoeffdingTree,
) -> crate::serde_support::TreeSnapshot {
    use crate::serde_support::TreeSnapshot;
    use crate::tree::StreamingTree;
    let arena = tree.arena();
    TreeSnapshot {
        feature_idx: arena.feature_idx.clone(),
        threshold: arena.threshold.clone(),
        // NodeIds are flattened to their raw indices for serialization.
        left: arena.left.iter().map(|id| id.0).collect(),
        right: arena.right.iter().map(|id| id.0).collect(),
        leaf_value: arena.leaf_value.clone(),
        is_leaf: arena.is_leaf.clone(),
        depth: arena.depth.clone(),
        sample_count: arena.sample_count.clone(),
        n_features: tree.n_features(),
        samples_seen: tree.n_samples_seen(),
        rng_state: tree.rng_state(),
        categorical_mask: arena.categorical_mask.clone(),
    }
}
1182
/// Inverse of [`snapshot_tree`]: reassembles a Hoeffding tree's node arena
/// from a snapshot and attaches the given tree configuration.
#[cfg(feature = "_serde_support")]
pub(crate) fn rebuild_tree(
    snapshot: &crate::serde_support::TreeSnapshot,
    tree_config: TreeConfig,
) -> crate::tree::hoeffding::HoeffdingTree {
    use crate::tree::hoeffding::HoeffdingTree;
    use crate::tree::node::{NodeId, TreeArena};

    let mut arena = TreeArena::new();
    let n = snapshot.feature_idx.len();

    for i in 0..n {
        arena.feature_idx.push(snapshot.feature_idx[i]);
        arena.threshold.push(snapshot.threshold[i]);
        arena.left.push(NodeId(snapshot.left[i]));
        arena.right.push(NodeId(snapshot.right[i]));
        arena.leaf_value.push(snapshot.leaf_value[i]);
        arena.is_leaf.push(snapshot.is_leaf[i]);
        arena.depth.push(snapshot.depth[i]);
        arena.sample_count.push(snapshot.sample_count[i]);
        // `get(...)` tolerates snapshots whose mask vector is shorter than
        // the node count (presumably older snapshot versions — TODO confirm).
        let mask = snapshot.categorical_mask.get(i).copied().flatten();
        arena.categorical_mask.push(mask);
    }

    HoeffdingTree::from_arena(
        tree_config,
        arena,
        snapshot.n_features,
        snapshot.samples_seen,
        snapshot.rng_state,
    )
}
1216
1217#[cfg(test)]
1218mod tests {
1219 use super::*;
1220 use alloc::boxed::Box;
1221 use alloc::vec;
1222 use alloc::vec::Vec;
1223
1224 fn default_config() -> SGBTConfig {
1225 SGBTConfig::builder()
1226 .n_steps(10)
1227 .learning_rate(0.1)
1228 .grace_period(20)
1229 .max_depth(4)
1230 .n_bins(16)
1231 .build()
1232 .unwrap()
1233 }
1234
1235 #[test]
1236 fn new_model_predicts_zero() {
1237 let model = SGBT::new(default_config());
1238 let pred = model.predict(&[1.0, 2.0, 3.0]);
1239 assert!(pred.abs() < 1e-12);
1240 }
1241
1242 #[test]
1243 fn train_one_does_not_panic() {
1244 let mut model = SGBT::new(default_config());
1245 model.train_one(&Sample::new(vec![1.0, 2.0, 3.0], 5.0));
1246 assert_eq!(model.n_samples_seen(), 1);
1247 }
1248
1249 #[test]
1250 fn prediction_changes_after_training() {
1251 let mut model = SGBT::new(default_config());
1252 let features = vec![1.0, 2.0, 3.0];
1253 for i in 0..100 {
1254 model.train_one(&Sample::new(features.clone(), (i as f64) * 0.1));
1255 }
1256 let pred = model.predict(&features);
1257 assert!(pred.is_finite());
1258 }
1259
1260 #[test]
1261 fn linear_signal_rmse_improves() {
1262 let config = SGBTConfig::builder()
1263 .n_steps(20)
1264 .learning_rate(0.1)
1265 .grace_period(10)
1266 .max_depth(3)
1267 .n_bins(16)
1268 .build()
1269 .unwrap();
1270 let mut model = SGBT::new(config);
1271
1272 let mut rng: u64 = 12345;
1273 let mut early_errors = Vec::new();
1274 let mut late_errors = Vec::new();
1275
1276 for i in 0..500 {
1277 rng ^= rng << 13;
1278 rng ^= rng >> 7;
1279 rng ^= rng << 17;
1280 let x1 = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
1281 rng ^= rng << 13;
1282 rng ^= rng >> 7;
1283 rng ^= rng << 17;
1284 let x2 = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
1285 let target = 2.0 * x1 + 3.0 * x2;
1286
1287 let pred = model.predict(&[x1, x2]);
1288 let error = (pred - target).powi(2);
1289
1290 if (50..150).contains(&i) {
1291 early_errors.push(error);
1292 }
1293 if i >= 400 {
1294 late_errors.push(error);
1295 }
1296
1297 model.train_one(&Sample::new(vec![x1, x2], target));
1298 }
1299
1300 let early_rmse = (early_errors.iter().sum::<f64>() / early_errors.len() as f64).sqrt();
1301 let late_rmse = (late_errors.iter().sum::<f64>() / late_errors.len() as f64).sqrt();
1302
1303 assert!(
1304 late_rmse < early_rmse,
1305 "RMSE should decrease: early={:.4}, late={:.4}",
1306 early_rmse,
1307 late_rmse
1308 );
1309 }
1310
1311 #[test]
1312 fn train_batch_equivalent_to_sequential() {
1313 let config = default_config();
1314 let mut model_seq = SGBT::new(config.clone());
1315 let mut model_batch = SGBT::new(config);
1316
1317 let samples: Vec<Sample> = (0..20)
1318 .map(|i| {
1319 let x = i as f64 * 0.5;
1320 Sample::new(vec![x, x * 2.0], x * 3.0)
1321 })
1322 .collect();
1323
1324 for s in &samples {
1325 model_seq.train_one(s);
1326 }
1327 model_batch.train_batch(&samples);
1328
1329 let pred_seq = model_seq.predict(&[1.0, 2.0]);
1330 let pred_batch = model_batch.predict(&[1.0, 2.0]);
1331
1332 assert!(
1333 (pred_seq - pred_batch).abs() < 1e-10,
1334 "seq={}, batch={}",
1335 pred_seq,
1336 pred_batch
1337 );
1338 }
1339
1340 #[test]
1341 fn reset_returns_to_initial() {
1342 let mut model = SGBT::new(default_config());
1343 for i in 0..100 {
1344 model.train_one(&Sample::new(vec![1.0, 2.0], i as f64));
1345 }
1346 model.reset();
1347 assert_eq!(model.n_samples_seen(), 0);
1348 assert!(!model.is_initialized());
1349 assert!(model.predict(&[1.0, 2.0]).abs() < 1e-12);
1350 }
1351
1352 #[test]
1353 fn base_prediction_initializes() {
1354 let mut model = SGBT::new(default_config());
1355 for i in 0..50 {
1356 model.train_one(&Sample::new(vec![1.0], i as f64 + 100.0));
1357 }
1358 assert!(model.is_initialized());
1359 let expected = (100.0 + 149.0) / 2.0;
1360 assert!((model.base_prediction() - expected).abs() < 1.0);
1361 }
1362
1363 #[test]
1364 fn with_loss_uses_custom_loss() {
1365 use crate::loss::logistic::LogisticLoss;
1366 let model = SGBT::with_loss(default_config(), LogisticLoss);
1367 let pred = model.predict_transformed(&[1.0, 2.0]);
1368 assert!(
1369 (pred - 0.5).abs() < 1e-6,
1370 "sigmoid(0) should be 0.5, got {}",
1371 pred
1372 );
1373 }
1374
1375 #[test]
1376 fn ewma_config_propagates_and_trains() {
1377 let config = SGBTConfig::builder()
1378 .n_steps(5)
1379 .learning_rate(0.1)
1380 .grace_period(10)
1381 .max_depth(3)
1382 .n_bins(16)
1383 .leaf_half_life(50)
1384 .build()
1385 .unwrap();
1386 let mut model = SGBT::new(config);
1387
1388 for i in 0..200 {
1389 let x = (i as f64) * 0.1;
1390 model.train_one(&Sample::new(vec![x, x * 2.0], x * 3.0));
1391 }
1392
1393 let pred = model.predict(&[1.0, 2.0]);
1394 assert!(
1395 pred.is_finite(),
1396 "EWMA-enabled model should produce finite predictions, got {}",
1397 pred
1398 );
1399 }
1400
1401 #[test]
1402 fn max_tree_samples_config_propagates() {
1403 let config = SGBTConfig::builder()
1404 .n_steps(5)
1405 .learning_rate(0.1)
1406 .grace_period(10)
1407 .max_depth(3)
1408 .n_bins(16)
1409 .max_tree_samples(200)
1410 .build()
1411 .unwrap();
1412 let mut model = SGBT::new(config);
1413
1414 for i in 0..500 {
1415 let x = (i as f64) * 0.1;
1416 model.train_one(&Sample::new(vec![x, x * 2.0], x * 3.0));
1417 }
1418
1419 let pred = model.predict(&[1.0, 2.0]);
1420 assert!(
1421 pred.is_finite(),
1422 "max_tree_samples model should produce finite predictions, got {}",
1423 pred
1424 );
1425 }
1426
1427 #[test]
1428 fn split_reeval_config_propagates() {
1429 let config = SGBTConfig::builder()
1430 .n_steps(5)
1431 .learning_rate(0.1)
1432 .grace_period(10)
1433 .max_depth(2)
1434 .n_bins(16)
1435 .split_reeval_interval(50)
1436 .build()
1437 .unwrap();
1438 let mut model = SGBT::new(config);
1439
1440 let mut rng: u64 = 12345;
1441 for _ in 0..1000 {
1442 rng ^= rng << 13;
1443 rng ^= rng >> 7;
1444 rng ^= rng << 17;
1445 let x1 = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
1446 rng ^= rng << 13;
1447 rng ^= rng >> 7;
1448 rng ^= rng << 17;
1449 let x2 = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
1450 let target = 2.0 * x1 + 3.0 * x2;
1451 model.train_one(&Sample::new(vec![x1, x2], target));
1452 }
1453
1454 let pred = model.predict(&[1.0, 2.0]);
1455 assert!(
1456 pred.is_finite(),
1457 "split re-eval model should produce finite predictions, got {}",
1458 pred
1459 );
1460 }
1461
1462 #[test]
1463 fn loss_accessor_works() {
1464 use crate::loss::logistic::LogisticLoss;
1465 let model = SGBT::with_loss(default_config(), LogisticLoss);
1466 let _loss: &LogisticLoss = model.loss();
1468 assert_eq!(_loss.n_outputs(), 1);
1469 }
1470
1471 #[test]
1472 fn clone_produces_independent_copy() {
1473 let config = default_config();
1474 let mut model = SGBT::new(config);
1475
1476 let mut rng: u64 = 99999;
1478 for _ in 0..200 {
1479 rng ^= rng << 13;
1480 rng ^= rng >> 7;
1481 rng ^= rng << 17;
1482 let x = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
1483 let target = 2.0 * x + 1.0;
1484 model.train_one(&Sample::new(vec![x], target));
1485 }
1486
1487 let mut cloned = model.clone();
1489
1490 let test_features = [3.0];
1492 let pred_original = model.predict(&test_features);
1493 let pred_cloned = cloned.predict(&test_features);
1494 assert!(
1495 (pred_original - pred_cloned).abs() < 1e-12,
1496 "clone should predict identically: original={pred_original}, cloned={pred_cloned}"
1497 );
1498
1499 for _ in 0..200 {
1501 rng ^= rng << 13;
1502 rng ^= rng >> 7;
1503 rng ^= rng << 17;
1504 let x = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
1505 let target = -3.0 * x + 5.0; cloned.train_one(&Sample::new(vec![x], target));
1507 }
1508
1509 let pred_original_after = model.predict(&test_features);
1510 let pred_cloned_after = cloned.predict(&test_features);
1511
1512 assert!(
1514 (pred_original - pred_original_after).abs() < 1e-12,
1515 "original should be unchanged after training clone"
1516 );
1517
1518 assert!(
1520 (pred_original_after - pred_cloned_after).abs() > 1e-6,
1521 "clone should diverge after independent training"
1522 );
1523 }
1524
1525 #[test]
1529 fn predict_with_confidence_finite() {
1530 let config = SGBTConfig::builder()
1531 .n_steps(5)
1532 .grace_period(10)
1533 .build()
1534 .unwrap();
1535 let mut model = SGBT::new(config);
1536
1537 for i in 0..100 {
1539 let x = i as f64 * 0.1;
1540 model.train_one(&(&[x, x * 2.0][..], x + 1.0));
1541 }
1542
1543 let (pred, confidence) = model.predict_with_confidence(&[1.0, 2.0]);
1544 assert!(pred.is_finite(), "prediction should be finite");
1545 assert!(confidence.is_finite(), "confidence should be finite");
1546 assert!(
1547 confidence > 0.0,
1548 "confidence should be positive after training"
1549 );
1550 }
1551
1552 #[test]
1556 fn predict_with_confidence_positive_after_training() {
1557 let config = SGBTConfig::builder()
1558 .n_steps(5)
1559 .grace_period(10)
1560 .build()
1561 .unwrap();
1562 let mut model = SGBT::new(config);
1563
1564 for i in 0..200 {
1566 let x = i as f64 * 0.05;
1567 model.train_one(&(&[x][..], x * 2.0));
1568 }
1569
1570 let (pred, confidence) = model.predict_with_confidence(&[1.0]);
1571
1572 assert!(pred.is_finite(), "prediction should be finite");
1573 assert!(
1574 confidence > 0.0 && confidence.is_finite(),
1575 "confidence should be finite and positive, got {}",
1576 confidence,
1577 );
1578
1579 let (pred2, conf2) = model.predict_with_confidence(&[1.0]);
1581 assert!(
1582 (pred - pred2).abs() < 1e-12,
1583 "same input should give same prediction"
1584 );
1585 assert!(
1586 (confidence - conf2).abs() < 1e-12,
1587 "same input should give same confidence"
1588 );
1589 }
1590
1591 #[test]
1595 fn predict_with_confidence_matches_predict() {
1596 let config = SGBTConfig::builder()
1597 .n_steps(10)
1598 .grace_period(10)
1599 .build()
1600 .unwrap();
1601 let mut model = SGBT::new(config);
1602
1603 for i in 0..200 {
1604 let x = (i as f64 - 100.0) * 0.01;
1605 model.train_one(&(&[x, x * x][..], x * 3.0 + 1.0));
1606 }
1607
1608 let pred = model.predict(&[0.5, 0.25]);
1609 let (conf_pred, _) = model.predict_with_confidence(&[0.5, 0.25]);
1610
1611 assert!(
1612 (pred - conf_pred).abs() < 1e-10,
1613 "prediction mismatch: predict()={} vs predict_with_confidence()={}",
1614 pred,
1615 conf_pred,
1616 );
1617 }
1618
1619 #[test]
1623 fn gradient_clip_config_builder() {
1624 let config = SGBTConfig::builder()
1625 .n_steps(10)
1626 .gradient_clip_sigma(3.0)
1627 .build()
1628 .unwrap();
1629
1630 assert_eq!(config.gradient_clip_sigma, Some(3.0));
1631 }
1632
1633 #[test]
1637 fn monotone_constraints_config_builder() {
1638 let config = SGBTConfig::builder()
1639 .n_steps(10)
1640 .monotone_constraints(vec![1, -1, 0])
1641 .build()
1642 .unwrap();
1643
1644 assert_eq!(config.monotone_constraints, Some(vec![1, -1, 0]));
1645 }
1646
1647 #[test]
1651 fn monotone_constraints_invalid_value_rejected() {
1652 let result = SGBTConfig::builder()
1653 .n_steps(10)
1654 .monotone_constraints(vec![1, 2, 0])
1655 .build();
1656
1657 assert!(result.is_err(), "constraint value 2 should be rejected");
1658 }
1659
1660 #[test]
1664 fn gradient_clip_sigma_negative_rejected() {
1665 let result = SGBTConfig::builder()
1666 .n_steps(10)
1667 .gradient_clip_sigma(-1.0)
1668 .build();
1669
1670 assert!(result.is_err(), "negative sigma should be rejected");
1671 }
1672
1673 #[test]
1677 fn gradient_clipping_reduces_outlier_impact() {
1678 let config_no_clip = SGBTConfig::builder()
1680 .n_steps(5)
1681 .grace_period(10)
1682 .build()
1683 .unwrap();
1684 let mut model_no_clip = SGBT::new(config_no_clip);
1685
1686 let config_clip = SGBTConfig::builder()
1688 .n_steps(5)
1689 .grace_period(10)
1690 .gradient_clip_sigma(3.0)
1691 .build()
1692 .unwrap();
1693 let mut model_clip = SGBT::new(config_clip);
1694
1695 for i in 0..100 {
1697 let x = (i as f64) * 0.01;
1698 let sample = (&[x][..], x * 2.0);
1699 model_no_clip.train_one(&sample);
1700 model_clip.train_one(&sample);
1701 }
1702
1703 let pred_no_clip_before = model_no_clip.predict(&[0.5]);
1704 let pred_clip_before = model_clip.predict(&[0.5]);
1705
1706 let outlier = (&[0.5_f64][..], 10000.0);
1708 model_no_clip.train_one(&outlier);
1709 model_clip.train_one(&outlier);
1710
1711 let pred_no_clip_after = model_no_clip.predict(&[0.5]);
1712 let pred_clip_after = model_clip.predict(&[0.5]);
1713
1714 let delta_no_clip = (pred_no_clip_after - pred_no_clip_before).abs();
1715 let delta_clip = (pred_clip_after - pred_clip_before).abs();
1716
1717 assert!(
1719 delta_clip <= delta_no_clip + 1e-10,
1720 "clipped model should be less affected: delta_clip={}, delta_no_clip={}",
1721 delta_clip,
1722 delta_no_clip,
1723 );
1724 }
1725
1726 #[test]
1730 fn train_batch_with_callback_fires() {
1731 let config = SGBTConfig::builder()
1732 .n_steps(3)
1733 .grace_period(5)
1734 .build()
1735 .unwrap();
1736 let mut model = SGBT::new(config);
1737
1738 let data: Vec<(Vec<f64>, f64)> = (0..25)
1739 .map(|i| (vec![i as f64 * 0.1], i as f64 * 0.5))
1740 .collect();
1741
1742 let mut callbacks = Vec::new();
1743 model.train_batch_with_callback(&data, 10, |n| {
1744 callbacks.push(n);
1745 });
1746
1747 assert_eq!(callbacks, vec![10, 20, 25]);
1749 }
1750
1751 #[test]
1755 fn train_batch_subsampled_trains_subset() {
1756 let config = SGBTConfig::builder()
1757 .n_steps(3)
1758 .grace_period(5)
1759 .build()
1760 .unwrap();
1761 let mut model = SGBT::new(config);
1762
1763 let data: Vec<(Vec<f64>, f64)> = (0..100)
1764 .map(|i| (vec![i as f64 * 0.01], i as f64 * 0.1))
1765 .collect();
1766
1767 model.train_batch_subsampled(&data, 20);
1769
1770 assert!(
1772 model.n_samples_seen() > 0,
1773 "model should have trained on subset"
1774 );
1775 assert!(
1776 model.n_samples_seen() <= 20,
1777 "model should have trained at most 20 samples, got {}",
1778 model.n_samples_seen(),
1779 );
1780 }
1781
1782 #[test]
1786 fn train_batch_subsampled_full_equals_batch() {
1787 let config1 = SGBTConfig::builder()
1788 .n_steps(3)
1789 .grace_period(5)
1790 .build()
1791 .unwrap();
1792 let config2 = config1.clone();
1793
1794 let mut model1 = SGBT::new(config1);
1795 let mut model2 = SGBT::new(config2);
1796
1797 let data: Vec<(Vec<f64>, f64)> = (0..50)
1798 .map(|i| (vec![i as f64 * 0.1], i as f64 * 0.5))
1799 .collect();
1800
1801 model1.train_batch(&data);
1802 model2.train_batch_subsampled(&data, 1000); assert_eq!(model1.n_samples_seen(), model2.n_samples_seen());
1806 let pred1 = model1.predict(&[2.5]);
1807 let pred2 = model2.predict(&[2.5]);
1808 assert!(
1809 (pred1 - pred2).abs() < 1e-12,
1810 "full subsample should equal batch: {} vs {}",
1811 pred1,
1812 pred2,
1813 );
1814 }
1815
1816 #[test]
1820 fn train_batch_subsampled_with_callback_works() {
1821 let config = SGBTConfig::builder()
1822 .n_steps(3)
1823 .grace_period(5)
1824 .build()
1825 .unwrap();
1826 let mut model = SGBT::new(config);
1827
1828 let data: Vec<(Vec<f64>, f64)> = (0..200)
1829 .map(|i| (vec![i as f64 * 0.01], i as f64 * 0.1))
1830 .collect();
1831
1832 let mut callbacks = Vec::new();
1833 model.train_batch_subsampled_with_callback(&data, 50, 10, |n| {
1834 callbacks.push(n);
1835 });
1836
1837 assert!(!callbacks.is_empty(), "should have received callbacks");
1839 assert_eq!(
1840 *callbacks.last().unwrap(),
1841 50,
1842 "final callback should be total samples"
1843 );
1844 }
1845
1846 fn xorshift64(state: &mut u64) -> u64 {
1852 let mut s = *state;
1853 s ^= s << 13;
1854 s ^= s >> 7;
1855 s ^= s << 17;
1856 *state = s;
1857 s
1858 }
1859
1860 fn rand_f64(state: &mut u64) -> f64 {
1861 xorshift64(state) as f64 / u64::MAX as f64
1862 }
1863
1864 fn linear_leaves_config() -> SGBTConfig {
1865 SGBTConfig::builder()
1866 .n_steps(10)
1867 .learning_rate(0.1)
1868 .grace_period(20)
1869 .max_depth(2) .n_bins(16)
1871 .leaf_model_type(crate::tree::leaf_model::LeafModelType::Linear {
1872 learning_rate: 0.1,
1873 decay: None,
1874 use_adagrad: false,
1875 })
1876 .build()
1877 .unwrap()
1878 }
1879
1880 #[test]
1881 fn linear_leaves_trains_without_panic() {
1882 let mut model = SGBT::new(linear_leaves_config());
1883 let mut rng = 42u64;
1884 for _ in 0..200 {
1885 let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1886 let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1887 let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1888 model.train_one(&Sample::new(vec![x1, x2], y));
1889 }
1890 assert_eq!(model.n_samples_seen(), 200);
1891 }
1892
1893 #[test]
1894 fn linear_leaves_prediction_finite() {
1895 let mut model = SGBT::new(linear_leaves_config());
1896 let mut rng = 42u64;
1897 for _ in 0..200 {
1898 let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1899 let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1900 let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1901 model.train_one(&Sample::new(vec![x1, x2], y));
1902 }
1903 let pred = model.predict(&[0.5, -0.3]);
1904 assert!(pred.is_finite(), "prediction should be finite, got {pred}");
1905 }
1906
1907 #[test]
1908 fn linear_leaves_learns_linear_target() {
1909 let mut model = SGBT::new(linear_leaves_config());
1910 let mut rng = 42u64;
1911 for _ in 0..500 {
1912 let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1913 let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1914 let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1915 model.train_one(&Sample::new(vec![x1, x2], y));
1916 }
1917
1918 let mut total_error = 0.0;
1920 for _ in 0..50 {
1921 let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1922 let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1923 let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1924 let pred = model.predict(&[x1, x2]);
1925 total_error += (pred - y).powi(2);
1926 }
1927 let mse = total_error / 50.0;
1928 assert!(
1929 mse < 5.0,
1930 "linear leaves MSE on linear target should be < 5.0, got {mse}"
1931 );
1932 }
1933
1934 #[test]
1935 fn linear_leaves_better_than_constant_at_low_depth() {
1936 let constant_config = SGBTConfig::builder()
1939 .n_steps(10)
1940 .learning_rate(0.1)
1941 .grace_period(20)
1942 .max_depth(2)
1943 .n_bins(16)
1944 .seed(0xDEAD)
1945 .build()
1946 .unwrap();
1947 let linear_config = SGBTConfig::builder()
1948 .n_steps(10)
1949 .learning_rate(0.1)
1950 .grace_period(20)
1951 .max_depth(2)
1952 .n_bins(16)
1953 .seed(0xDEAD)
1954 .leaf_model_type(crate::tree::leaf_model::LeafModelType::Linear {
1955 learning_rate: 0.1,
1956 decay: None,
1957 use_adagrad: false,
1958 })
1959 .build()
1960 .unwrap();
1961
1962 let mut constant_model = SGBT::new(constant_config);
1963 let mut linear_model = SGBT::new(linear_config);
1964 let mut rng = 42u64;
1965
1966 for _ in 0..500 {
1967 let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1968 let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1969 let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1970 let sample = Sample::new(vec![x1, x2], y);
1971 constant_model.train_one(&sample);
1972 linear_model.train_one(&sample);
1973 }
1974
1975 let mut constant_mse = 0.0;
1977 let mut linear_mse = 0.0;
1978 for _ in 0..100 {
1979 let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1980 let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1981 let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1982 constant_mse += (constant_model.predict(&[x1, x2]) - y).powi(2);
1983 linear_mse += (linear_model.predict(&[x1, x2]) - y).powi(2);
1984 }
1985 constant_mse /= 100.0;
1986 linear_mse /= 100.0;
1987
1988 assert!(
1990 linear_mse < constant_mse,
1991 "linear leaves MSE ({linear_mse:.4}) should be less than constant ({constant_mse:.4})"
1992 );
1993 }
1994
1995 #[test]
1996 fn adaptive_leaves_trains_without_panic() {
1997 let config = SGBTConfig::builder()
1998 .n_steps(10)
1999 .learning_rate(0.1)
2000 .grace_period(20)
2001 .max_depth(3)
2002 .n_bins(16)
2003 .leaf_model_type(crate::tree::leaf_model::LeafModelType::Adaptive {
2004 promote_to: Box::new(crate::tree::leaf_model::LeafModelType::Linear {
2005 learning_rate: 0.1,
2006 decay: None,
2007 use_adagrad: false,
2008 }),
2009 })
2010 .build()
2011 .unwrap();
2012
2013 let mut model = SGBT::new(config);
2014 let mut rng = 42u64;
2015 for _ in 0..500 {
2016 let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
2017 let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
2018 let y = 3.0 * x1 + 2.0 * x2 + 1.0;
2019 model.train_one(&Sample::new(vec![x1, x2], y));
2020 }
2021 let pred = model.predict(&[0.5, -0.3]);
2022 assert!(
2023 pred.is_finite(),
2024 "adaptive leaf prediction should be finite, got {pred}"
2025 );
2026 }
2027
2028 #[test]
2029 fn linear_leaves_with_decay_trains_without_panic() {
2030 let config = SGBTConfig::builder()
2031 .n_steps(10)
2032 .learning_rate(0.1)
2033 .grace_period(20)
2034 .max_depth(3)
2035 .n_bins(16)
2036 .leaf_model_type(crate::tree::leaf_model::LeafModelType::Linear {
2037 learning_rate: 0.1,
2038 decay: Some(0.995),
2039 use_adagrad: false,
2040 })
2041 .build()
2042 .unwrap();
2043
2044 let mut model = SGBT::new(config);
2045 let mut rng = 42u64;
2046 for _ in 0..500 {
2047 let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
2048 let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
2049 let y = 3.0 * x1 + 2.0 * x2 + 1.0;
2050 model.train_one(&Sample::new(vec![x1, x2], y));
2051 }
2052 let pred = model.predict(&[0.5, -0.3]);
2053 assert!(
2054 pred.is_finite(),
2055 "decay leaf prediction should be finite, got {pred}"
2056 );
2057 }
2058
2059 #[test]
2063 fn predict_smooth_returns_finite() {
2064 let config = SGBTConfig::builder()
2065 .n_steps(5)
2066 .learning_rate(0.1)
2067 .grace_period(10)
2068 .build()
2069 .unwrap();
2070 let mut model = SGBT::new(config);
2071
2072 for i in 0..200 {
2073 let x = (i as f64) * 0.1;
2074 model.train_one(&Sample::new(vec![x, x.sin()], 2.0 * x + 1.0));
2075 }
2076
2077 let pred_hard = model.predict(&[1.0, 1.0_f64.sin()]);
2078 let pred_smooth = model.predict_smooth(&[1.0, 1.0_f64.sin()], 0.5);
2079
2080 assert!(pred_hard.is_finite(), "hard prediction should be finite");
2081 assert!(
2082 pred_smooth.is_finite(),
2083 "smooth prediction should be finite"
2084 );
2085 }
2086
2087 #[test]
2091 fn predict_smooth_converges_to_hard_at_small_bandwidth() {
2092 let config = SGBTConfig::builder()
2093 .n_steps(5)
2094 .learning_rate(0.1)
2095 .grace_period(10)
2096 .build()
2097 .unwrap();
2098 let mut model = SGBT::new(config);
2099
2100 for i in 0..300 {
2101 let x = (i as f64) * 0.1;
2102 model.train_one(&Sample::new(vec![x, x * 0.5], 2.0 * x + 1.0));
2103 }
2104
2105 let features = [5.0, 2.5];
2106 let hard = model.predict(&features);
2107 let smooth = model.predict_smooth(&features, 0.001);
2108
2109 assert!(
2110 (hard - smooth).abs() < 0.5,
2111 "smooth with tiny bandwidth should approximate hard: hard={}, smooth={}",
2112 hard,
2113 smooth,
2114 );
2115 }
2116
2117 #[test]
2118 fn auto_bandwidth_computed_after_training() {
2119 let config = SGBTConfig::builder()
2120 .n_steps(5)
2121 .learning_rate(0.1)
2122 .grace_period(10)
2123 .build()
2124 .unwrap();
2125 let mut model = SGBT::new(config);
2126
2127 assert!(model.auto_bandwidths().is_empty());
2129
2130 for i in 0..200 {
2131 let x = (i as f64) * 0.1;
2132 model.train_one(&Sample::new(vec![x, x * 0.5], 2.0 * x + 1.0));
2133 }
2134
2135 let bws = model.auto_bandwidths();
2137 assert_eq!(bws.len(), 2, "should have 2 feature bandwidths");
2138
2139 let pred = model.predict(&[5.0, 2.5]);
2141 assert!(
2142 pred.is_finite(),
2143 "auto-bandwidth predict should be finite: {}",
2144 pred
2145 );
2146 }
2147
2148 #[test]
2149 fn predict_interpolated_returns_finite() {
2150 let config = SGBTConfig::builder()
2151 .n_steps(5)
2152 .learning_rate(0.01)
2153 .build()
2154 .unwrap();
2155 let mut model = SGBT::new(config);
2156
2157 for i in 0..200 {
2158 let x = (i as f64) * 0.1;
2159 model.train_one(&Sample::new(vec![x, x.sin()], x.cos()));
2160 }
2161
2162 let pred = model.predict_interpolated(&[1.0, 0.5]);
2163 assert!(
2164 pred.is_finite(),
2165 "interpolated prediction should be finite: {}",
2166 pred
2167 );
2168 }
2169
2170 #[test]
2171 fn predict_sibling_interpolated_varies_with_features() {
2172 let config = SGBTConfig::builder()
2173 .n_steps(10)
2174 .learning_rate(0.1)
2175 .grace_period(10)
2176 .max_depth(6)
2177 .delta(0.1)
2178 .build()
2179 .unwrap();
2180 let mut model = SGBT::new(config);
2181
2182 for i in 0..2000 {
2183 let x = (i as f64) * 0.01;
2184 let y = x.sin() * x + 0.5 * (x * 2.0).cos();
2185 model.train_one(&Sample::new(vec![x, x * 0.3], y));
2186 }
2187
2188 let pred = model.predict_sibling_interpolated(&[5.0, 1.5]);
2190 assert!(pred.is_finite(), "sibling interpolated should be finite");
2191
2192 let bws = model.auto_bandwidths();
2195 if bws.iter().any(|&b| b.is_finite()) {
2196 let hard: Vec<f64> = (0..200)
2197 .map(|i| model.predict(&[i as f64 * 0.1, i as f64 * 0.03]))
2198 .collect();
2199 let sib: Vec<f64> = (0..200)
2200 .map(|i| model.predict_sibling_interpolated(&[i as f64 * 0.1, i as f64 * 0.03]))
2201 .collect();
2202 let hc = hard
2203 .windows(2)
2204 .filter(|w| (w[0] - w[1]).abs() > f64::EPSILON)
2205 .count();
2206 let sc = sib
2207 .windows(2)
2208 .filter(|w| (w[0] - w[1]).abs() > f64::EPSILON)
2209 .count();
2210 assert!(
2211 sc >= hc,
2212 "sibling should produce >= hard changes: sib={}, hard={}",
2213 sc,
2214 hc
2215 );
2216 }
2217 }
2218
2219 #[test]
2220 fn predict_graduated_returns_finite() {
2221 let config = SGBTConfig::builder()
2222 .n_steps(5)
2223 .learning_rate(0.01)
2224 .max_tree_samples(200)
2225 .shadow_warmup(50)
2226 .build()
2227 .unwrap();
2228 let mut model = SGBT::new(config);
2229
2230 for i in 0..300 {
2231 let x = (i as f64) * 0.1;
2232 model.train_one(&Sample::new(vec![x, x.sin()], x.cos()));
2233 }
2234
2235 let pred = model.predict_graduated(&[1.0, 0.5]);
2236 assert!(
2237 pred.is_finite(),
2238 "graduated prediction should be finite: {}",
2239 pred
2240 );
2241
2242 let pred2 = model.predict_graduated_sibling_interpolated(&[1.0, 0.5]);
2243 assert!(
2244 pred2.is_finite(),
2245 "graduated+sibling prediction should be finite: {}",
2246 pred2
2247 );
2248 }
2249
2250 #[test]
2251 fn shadow_warmup_validation() {
2252 let result = SGBTConfig::builder()
2253 .n_steps(5)
2254 .learning_rate(0.01)
2255 .shadow_warmup(0)
2256 .build();
2257 assert!(result.is_err(), "shadow_warmup=0 should fail validation");
2258 }
2259
2260 #[test]
2265 fn adaptive_mts_defaults_to_none() {
2266 let cfg = SGBTConfig::default();
2267 assert!(
2268 cfg.adaptive_mts.is_none(),
2269 "adaptive_mts should default to None"
2270 );
2271 }
2272
2273 #[test]
2274 fn adaptive_mts_config_builder() {
2275 let cfg = SGBTConfig::builder()
2276 .n_steps(10)
2277 .adaptive_mts(500, 2.0)
2278 .build()
2279 .unwrap();
2280 assert_eq!(
2281 cfg.adaptive_mts,
2282 Some((500, 2.0)),
2283 "adaptive_mts should store (base_mts, k)"
2284 );
2285 }
2286
2287 #[test]
2288 fn adaptive_mts_validation_rejects_low_base() {
2289 let result = SGBTConfig::builder()
2290 .n_steps(5)
2291 .adaptive_mts(50, 1.0)
2292 .build();
2293 assert!(
2294 result.is_err(),
2295 "adaptive_mts with base_mts < 100 should fail"
2296 );
2297 }
2298
2299 #[test]
2300 fn adaptive_mts_validation_rejects_zero_k() {
2301 let result = SGBTConfig::builder()
2302 .n_steps(5)
2303 .adaptive_mts(500, 0.0)
2304 .build();
2305 assert!(result.is_err(), "adaptive_mts with k=0 should fail");
2306 }
2307
2308 #[test]
2309 fn adaptive_mts_trains_without_panic() {
2310 let config = SGBTConfig::builder()
2311 .n_steps(5)
2312 .learning_rate(0.1)
2313 .grace_period(10)
2314 .max_depth(3)
2315 .n_bins(16)
2316 .adaptive_mts(200, 1.0)
2317 .build()
2318 .unwrap();
2319 let mut model = SGBT::new(config);
2320
2321 for i in 0..500 {
2322 let x = (i as f64) * 0.1;
2323 model.train_one(&Sample::new(vec![x, x * 2.0], x * 3.0));
2324 }
2325
2326 let pred = model.predict(&[1.0, 2.0]);
2327 assert!(
2328 pred.is_finite(),
2329 "adaptive_mts model should produce finite predictions, got {}",
2330 pred
2331 );
2332 }
2333
2334 #[test]
2339 fn proactive_prune_defaults_to_none() {
2340 let cfg = SGBTConfig::default();
2341 assert!(
2342 cfg.proactive_prune_interval.is_none(),
2343 "proactive_prune_interval should default to None"
2344 );
2345 }
2346
2347 #[test]
2348 fn proactive_prune_config_builder() {
2349 let cfg = SGBTConfig::builder()
2350 .n_steps(10)
2351 .proactive_prune_interval(500)
2352 .build()
2353 .unwrap();
2354 assert_eq!(
2355 cfg.proactive_prune_interval,
2356 Some(500),
2357 "proactive_prune_interval should be set"
2358 );
2359 }
2360
2361 #[test]
2362 fn proactive_prune_validation_rejects_low_interval() {
2363 let result = SGBTConfig::builder()
2364 .n_steps(5)
2365 .proactive_prune_interval(50)
2366 .build();
2367 assert!(
2368 result.is_err(),
2369 "proactive_prune_interval < 100 should fail"
2370 );
2371 }
2372
2373 #[test]
2374 fn proactive_prune_enables_contribution_tracking() {
2375 let config = SGBTConfig::builder()
2376 .n_steps(5)
2377 .learning_rate(0.1)
2378 .grace_period(10)
2379 .max_depth(3)
2380 .n_bins(16)
2381 .proactive_prune_interval(200)
2382 .build()
2383 .unwrap();
2384
2385 assert!(config.quality_prune_alpha.is_none());
2387
2388 let mut model = SGBT::new(config);
2389
2390 for i in 0..100 {
2392 let x = (i as f64) * 0.1;
2393 model.train_one(&Sample::new(vec![x, x * 2.0], x * 3.0));
2394 }
2395
2396 let pred = model.predict(&[1.0, 2.0]);
2400 assert!(
2401 pred.is_finite(),
2402 "proactive_prune model should produce finite predictions, got {}",
2403 pred
2404 );
2405 }
2406
2407 #[test]
2408 fn proactive_prune_trains_without_panic() {
2409 let config = SGBTConfig::builder()
2410 .n_steps(5)
2411 .learning_rate(0.1)
2412 .grace_period(10)
2413 .max_depth(3)
2414 .n_bins(16)
2415 .proactive_prune_interval(200)
2416 .build()
2417 .unwrap();
2418 let mut model = SGBT::new(config);
2419
2420 for i in 0..500 {
2422 let x = (i as f64) * 0.1;
2423 model.train_one(&Sample::new(vec![x, x * 2.0], x * 3.0));
2424 }
2425
2426 let pred = model.predict(&[1.0, 2.0]);
2427 assert!(
2428 pred.is_finite(),
2429 "proactive_prune model should produce finite predictions after pruning, got {}",
2430 pred
2431 );
2432 }
2433}