1pub mod adaptive;
20pub mod adaptive_forest;
21pub mod bagged;
22pub mod config;
23pub mod distributional;
24pub mod lr_schedule;
25pub mod moe;
26pub mod moe_distributional;
27pub mod multi_target;
28pub mod multiclass;
29#[cfg(feature = "parallel")]
30pub mod parallel;
31pub mod quantile_regressor;
32pub mod replacement;
33pub mod stacked;
34pub mod step;
35pub mod variants;
36
37use alloc::boxed::Box;
38use alloc::string::String;
39use alloc::vec;
40use alloc::vec::Vec;
41
42use core::fmt;
43
44use crate::ensemble::config::SGBTConfig;
45use crate::ensemble::step::BoostingStep;
46use crate::loss::squared::SquaredLoss;
47use crate::loss::Loss;
48use crate::sample::Observation;
49#[allow(unused_imports)] use crate::sample::Sample;
51use crate::tree::builder::TreeConfig;
52
/// Type-erased ensemble: an [`SGBT`] whose loss is selected at runtime via a
/// boxed trait object (the form produced by `from_model_state`).
pub type DynSGBT = SGBT<Box<dyn Loss>>;
61
/// Streaming gradient-boosted trees: an online boosting ensemble trained one
/// observation at a time via `train_one`.
pub struct SGBT<L: Loss = SquaredLoss> {
    /// Ensemble hyperparameters.
    config: SGBTConfig,
    /// The boosting steps, applied in sequence; each wraps a tree slot.
    steps: Vec<BoostingStep>,
    /// Loss function supplying gradients/hessians and the base prediction.
    loss: L,
    /// Constant offset added to every prediction (from `loss.initial_prediction`).
    base_prediction: f64,
    /// Whether `base_prediction` has been computed from the warm-up targets.
    base_initialized: bool,
    /// Warm-up target buffer; cleared once the base prediction is set.
    initial_targets: Vec<f64>,
    /// How many targets to buffer before initializing the base prediction.
    initial_target_count: usize,
    /// Total observations consumed by `train_one`.
    samples_seen: u64,
    /// xorshift64-style PRNG state (seeded from `config.seed`).
    rng_state: u64,
    /// Per-step EWMA of |lr * step output|; empty unless pruning is enabled.
    contribution_ewma: Vec<f64>,
    /// Per-step run length of below-threshold contributions (pruning patience).
    low_contrib_count: Vec<u64>,
    /// Rolling mean absolute error used by error weighting in `train_one`.
    rolling_mean_error: f64,
    /// Slow EWMA of contribution spread used by adaptive max-tree-samples.
    rolling_contribution_sigma: f64,
    /// Cached per-feature smoothing bandwidths; empty until computed.
    auto_bandwidths: Vec<f64>,
    /// Sum of tree replacements at the last bandwidth refresh (staleness check).
    last_replacement_sum: u64,
}
127
128impl<L: Loss + Clone> Clone for SGBT<L> {
129 fn clone(&self) -> Self {
130 Self {
131 config: self.config.clone(),
132 steps: self.steps.clone(),
133 loss: self.loss.clone(),
134 base_prediction: self.base_prediction,
135 base_initialized: self.base_initialized,
136 initial_targets: self.initial_targets.clone(),
137 initial_target_count: self.initial_target_count,
138 samples_seen: self.samples_seen,
139 rng_state: self.rng_state,
140 contribution_ewma: self.contribution_ewma.clone(),
141 low_contrib_count: self.low_contrib_count.clone(),
142 rolling_mean_error: self.rolling_mean_error,
143 rolling_contribution_sigma: self.rolling_contribution_sigma,
144 auto_bandwidths: self.auto_bandwidths.clone(),
145 last_replacement_sum: self.last_replacement_sum,
146 }
147 }
148}
149
150impl<L: Loss> fmt::Debug for SGBT<L> {
151 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
152 f.debug_struct("SGBT")
153 .field("n_steps", &self.steps.len())
154 .field("samples_seen", &self.samples_seen)
155 .field("base_prediction", &self.base_prediction)
156 .field("base_initialized", &self.base_initialized)
157 .finish()
158 }
159}
160
161impl SGBT<SquaredLoss> {
166 pub fn new(config: SGBTConfig) -> Self {
171 Self::with_loss(config, SquaredLoss)
172 }
173}
174
175impl<L: Loss> SGBT<L> {
    /// Constructs an ensemble from `config` with a caller-supplied loss.
    ///
    /// Translates ensemble-level configuration into a per-step `TreeConfig`,
    /// builds `n_steps` boosting steps (each with its own derived seed and a
    /// fresh drift detector), and pre-sizes the pruning statistics when
    /// quality or proactive pruning is enabled.
    pub fn with_loss(config: SGBTConfig, loss: L) -> Self {
        // Convert a half-life (in samples) into a per-sample exponential
        // decay factor: alpha = exp(-ln 2 / half_life).
        let leaf_decay_alpha = config
            .leaf_half_life
            .map(|hl| crate::math::exp(-crate::math::ln(2.0) / hl as f64));

        let tree_config = TreeConfig::new()
            .max_depth(config.max_depth)
            .n_bins(config.n_bins)
            .lambda(config.lambda)
            .gamma(config.gamma)
            .grace_period(config.grace_period)
            .delta(config.delta)
            .feature_subsample_rate(config.feature_subsample_rate)
            .leaf_decay_alpha_opt(leaf_decay_alpha)
            .split_reeval_interval_opt(config.split_reeval_interval)
            .feature_types_opt(config.feature_types.clone())
            .gradient_clip_sigma_opt(config.gradient_clip_sigma)
            .monotone_constraints_opt(config.monotone_constraints.clone())
            .max_leaf_output_opt(config.max_leaf_output)
            .adaptive_leaf_bound_opt(config.adaptive_leaf_bound)
            .adaptive_depth_opt(config.adaptive_depth)
            .min_hessian_sum_opt(config.min_hessian_sum)
            .leaf_model_type(config.leaf_model_type.clone());

        let max_tree_samples = config.max_tree_samples;

        let shadow_warmup = config.shadow_warmup.unwrap_or(0);
        let steps: Vec<BoostingStep> = (0..config.n_steps)
            .map(|i| {
                let mut tc = tree_config.clone();
                // Give each step a distinct but deterministic seed.
                tc.seed = config.seed ^ (i as u64);
                let detector = config.drift_detector.create();
                if shadow_warmup > 0 {
                    BoostingStep::new_with_graduated(tc, detector, max_tree_samples, shadow_warmup)
                } else {
                    BoostingStep::new_with_max_samples(tc, detector, max_tree_samples)
                }
            })
            .collect();

        let seed = config.seed;
        let initial_target_count = config.initial_target_count;
        let n = config.n_steps;
        // Pruning statistics are only allocated when either pruning mode is on.
        let has_pruning =
            config.quality_prune_alpha.is_some() || config.proactive_prune_interval.is_some();
        Self {
            config,
            steps,
            loss,
            base_prediction: 0.0,
            base_initialized: false,
            initial_targets: Vec::new(),
            initial_target_count,
            samples_seen: 0,
            rng_state: seed,
            contribution_ewma: if has_pruning {
                vec![0.0; n]
            } else {
                Vec::new()
            },
            low_contrib_count: if has_pruning { vec![0; n] } else { Vec::new() },
            rolling_mean_error: 0.0,
            rolling_contribution_sigma: 0.0,
            auto_bandwidths: Vec::new(),
            last_replacement_sum: 0,
        }
    }
259
    /// Trains the ensemble on a single labeled observation.
    ///
    /// Per-sample pipeline:
    /// 1. buffer targets until the base prediction can be initialized,
    /// 2. optionally adapt `max_tree_samples` from the spread of per-step
    ///    contributions (adaptive MTS),
    /// 3. optionally up-weight high-error samples (error weighting),
    /// 4. run each boosting step on the gradient of the running prediction,
    /// 5. optionally prune low-contribution steps (quality / proactive),
    /// 6. refresh cached smoothing bandwidths if any tree was replaced.
    pub fn train_one(&mut self, sample: &impl Observation) {
        self.samples_seen += 1;
        let target = sample.target();
        let features = sample.features();

        // Warm-up: accumulate targets until enough have arrived to set the
        // loss-specific base prediction.
        if !self.base_initialized {
            self.initial_targets.push(target);
            if self.initial_targets.len() >= self.initial_target_count {
                self.base_prediction = self.loss.initial_prediction(&self.initial_targets);
                self.base_initialized = true;
                self.initial_targets.clear();
                self.initial_targets.shrink_to_fit();
            }
        }

        let mut current_pred = self.base_prediction;

        // Adaptive max-tree-samples: shorten tree lifetimes when the steps'
        // contributions disagree more than usual, measured against a slow
        // EWMA of their spread.
        if let Some((base_mts, k)) = self.config.adaptive_mts {
            let sigma = self.contribution_variance(features);
            self.rolling_contribution_sigma =
                0.999 * self.rolling_contribution_sigma + 0.001 * sigma;

            let normalized = if self.rolling_contribution_sigma > 1e-10 {
                sigma / self.rolling_contribution_sigma
            } else {
                1.0
            };
            let factor = 1.0 / (1.0 + k * normalized);
            // Never drop below twice the grace period, so trees still get a
            // chance to split before being replaced.
            let effective_mts =
                ((base_mts as f64) * factor).max(self.config.grace_period as f64 * 2.0) as u64;
            for step in &mut self.steps {
                step.slot_mut().set_max_tree_samples(Some(effective_mts));
            }
        }

        // Quality pruning uses the explicit alpha when configured; when only
        // proactive pruning is on, a default alpha of 0.01 keeps the EWMA
        // statistics that the proactive pass relies on.
        let prune_alpha = self
            .config
            .quality_prune_alpha
            .or_else(|| self.config.proactive_prune_interval.map(|_| 0.01));
        let prune_threshold = self.config.quality_prune_threshold;
        let prune_patience = self.config.quality_prune_patience;

        // Error weighting: scale this sample's gradient/hessian by how its
        // absolute error compares to the rolling mean error (capped at 10x).
        let error_weight = if let Some(ew_alpha) = self.config.error_weight_alpha {
            let abs_error = crate::math::abs(target - current_pred);
            if self.rolling_mean_error > 1e-15 {
                let w = (1.0 + abs_error / (self.rolling_mean_error + 1e-15)).min(10.0);
                self.rolling_mean_error =
                    ew_alpha * abs_error + (1.0 - ew_alpha) * self.rolling_mean_error;
                w
            } else {
                // First error seeds the rolling mean without any weighting.
                self.rolling_mean_error = abs_error.max(1e-15);
                1.0
            }
        } else {
            1.0
        };

        // Boosting pass: each step fits the gradient of the running
        // prediction; its output is folded in before the next step runs.
        for s in 0..self.steps.len() {
            let gradient = self.loss.gradient(target, current_pred) * error_weight;
            let hessian = self.loss.hessian(target, current_pred) * error_weight;
            let train_count = self
                .config
                .variant
                .train_count(hessian, &mut self.rng_state);

            let step_pred =
                self.steps[s].train_and_predict(features, gradient, hessian, train_count);

            current_pred += self.config.learning_rate * step_pred;

            // Quality pruning: reset a step whose EWMA contribution stays
            // below the threshold for `prune_patience` consecutive samples.
            if let Some(alpha) = prune_alpha {
                let contribution = crate::math::abs(self.config.learning_rate * step_pred);
                self.contribution_ewma[s] =
                    alpha * contribution + (1.0 - alpha) * self.contribution_ewma[s];

                if self.contribution_ewma[s] < prune_threshold {
                    self.low_contrib_count[s] += 1;
                    if self.low_contrib_count[s] >= prune_patience {
                        self.steps[s].reset();
                        self.contribution_ewma[s] = 0.0;
                        self.low_contrib_count[s] = 0;
                    }
                } else {
                    self.low_contrib_count[s] = 0;
                }
            }
        }

        // Proactive pruning: every `interval` samples, reset the single
        // worst-contributing step that is at least `interval / 2` samples old.
        if let Some(interval) = self.config.proactive_prune_interval {
            if self.samples_seen % interval == 0
                && self.samples_seen > 0
                && !self.contribution_ewma.is_empty()
            {
                let min_age = interval / 2;
                let worst_idx = self
                    .steps
                    .iter()
                    .enumerate()
                    .zip(self.contribution_ewma.iter())
                    .filter(|((_, step), _)| step.n_samples_seen() >= min_age)
                    .min_by(|((_, _), a_ewma), ((_, _), b_ewma)| {
                        a_ewma
                            .partial_cmp(b_ewma)
                            .unwrap_or(core::cmp::Ordering::Equal)
                    })
                    .map(|((i, _), _)| i);

                if let Some(idx) = worst_idx {
                    self.steps[idx].reset();
                    self.contribution_ewma[idx] = 0.0;
                    self.low_contrib_count[idx] = 0;
                }
            }
        }

        // Recompute smoothing bandwidths only when a tree was replaced.
        self.refresh_bandwidths();
    }
390
391 pub fn train_batch<O: Observation>(&mut self, samples: &[O]) {
393 for sample in samples {
394 self.train_one(sample);
395 }
396 }
397
398 pub fn train_batch_with_callback<O: Observation, F: FnMut(usize)>(
419 &mut self,
420 samples: &[O],
421 interval: usize,
422 mut callback: F,
423 ) {
424 let interval = interval.max(1); for (i, sample) in samples.iter().enumerate() {
426 self.train_one(sample);
427 if (i + 1) % interval == 0 {
428 callback(i + 1);
429 }
430 }
431 let total = samples.len();
433 if total % interval != 0 {
434 callback(total);
435 }
436 }
437
    /// Trains on a uniformly random subset of at most `max_samples` samples,
    /// chosen by reservoir sampling (Algorithm R) so every sample has equal
    /// inclusion probability. Selected samples are trained in stream order.
    pub fn train_batch_subsampled<O: Observation>(&mut self, samples: &[O], max_samples: usize) {
        // No subsampling needed when the budget covers the whole batch.
        if max_samples >= samples.len() {
            self.train_batch(samples);
            return;
        }

        // Reservoir starts as the first `max_samples` indices.
        let mut reservoir: Vec<usize> = (0..max_samples).collect();
        let mut rng = self.rng_state;

        for i in max_samples..samples.len() {
            // xorshift64 PRNG step.
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            // Replace a random reservoir slot with probability max_samples/(i+1).
            let j = (rng % (i as u64 + 1)) as usize;
            if j < max_samples {
                reservoir[j] = i;
            }
        }

        self.rng_state = rng;

        // Preserve the original stream order of the selected samples.
        reservoir.sort_unstable();

        for &idx in &reservoir {
            self.train_one(&samples[idx]);
        }
    }
481
    /// Like `train_batch_subsampled`, but reports progress: `callback(count)`
    /// fires after every `interval` trained samples and once more at the end
    /// when the selected total is not a multiple of `interval`.
    pub fn train_batch_subsampled_with_callback<O: Observation, F: FnMut(usize)>(
        &mut self,
        samples: &[O],
        max_samples: usize,
        interval: usize,
        mut callback: F,
    ) {
        // Budget covers everything: defer to the plain callback variant.
        if max_samples >= samples.len() {
            self.train_batch_with_callback(samples, interval, callback);
            return;
        }

        // Reservoir sampling (Algorithm R) over the sample indices,
        // mirroring `train_batch_subsampled`.
        let mut reservoir: Vec<usize> = (0..max_samples).collect();
        let mut rng = self.rng_state;

        for i in max_samples..samples.len() {
            // xorshift64 PRNG step.
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            let j = (rng % (i as u64 + 1)) as usize;
            if j < max_samples {
                reservoir[j] = i;
            }
        }

        self.rng_state = rng;
        // Train the selected samples in original stream order.
        reservoir.sort_unstable();

        let interval = interval.max(1);
        for (i, &idx) in reservoir.iter().enumerate() {
            self.train_one(&samples[idx]);
            if (i + 1) % interval == 0 {
                callback(i + 1);
            }
        }
        // Final progress report for a trailing partial chunk.
        let total = reservoir.len();
        if total % interval != 0 {
            callback(total);
        }
    }
528
529 pub fn predict(&self, features: &[f64]) -> f64 {
535 let mut pred = self.base_prediction;
536 if self.auto_bandwidths.is_empty() {
537 for step in &self.steps {
539 pred += self.config.learning_rate * step.predict(features);
540 }
541 } else {
542 for step in &self.steps {
543 pred += self.config.learning_rate
544 * step.predict_smooth_auto(features, &self.auto_bandwidths);
545 }
546 }
547 pred
548 }
549
550 pub fn predict_smooth(&self, features: &[f64], bandwidth: f64) -> f64 {
555 let mut pred = self.base_prediction;
556 for step in &self.steps {
557 pred += self.config.learning_rate * step.predict_smooth(features, bandwidth);
558 }
559 pred
560 }
561
    /// Cached per-feature smoothing bandwidths (see `compute_auto_bandwidths`);
    /// empty until they have been computed.
    pub fn auto_bandwidths(&self) -> &[f64] {
        &self.auto_bandwidths
    }
570
571 pub fn predict_interpolated(&self, features: &[f64]) -> f64 {
576 let mut pred = self.base_prediction;
577 for step in &self.steps {
578 pred += self.config.learning_rate * step.predict_interpolated(features);
579 }
580 pred
581 }
582
583 pub fn predict_sibling_interpolated(&self, features: &[f64]) -> f64 {
591 let mut pred = self.base_prediction;
592 for step in &self.steps {
593 pred += self.config.learning_rate
594 * step.predict_sibling_interpolated(features, &self.auto_bandwidths);
595 }
596 pred
597 }
598
599 pub fn predict_graduated(&self, features: &[f64]) -> f64 {
605 let mut pred = self.base_prediction;
606 for step in &self.steps {
607 pred += self.config.learning_rate * step.predict_graduated(features);
608 }
609 pred
610 }
611
612 pub fn predict_graduated_sibling_interpolated(&self, features: &[f64]) -> f64 {
618 let mut pred = self.base_prediction;
619 for step in &self.steps {
620 pred += self.config.learning_rate
621 * step.predict_graduated_sibling_interpolated(features, &self.auto_bandwidths);
622 }
623 pred
624 }
625
626 pub fn predict_transformed(&self, features: &[f64]) -> f64 {
628 self.loss.predict_transform(self.predict(features))
629 }
630
631 pub fn predict_proba(&self, features: &[f64]) -> f64 {
633 self.predict_transformed(features)
634 }
635
636 pub fn predict_with_confidence(&self, features: &[f64]) -> (f64, f64) {
651 let mut pred = self.base_prediction;
652 let mut total_variance = 0.0;
653 let lr2 = self.config.learning_rate * self.config.learning_rate;
654
655 for step in &self.steps {
656 let (value, variance) = step.predict_with_variance(features);
657 pred += self.config.learning_rate * value;
658 total_variance += lr2 * variance;
659 }
660
661 let confidence = if total_variance > 0.0 && total_variance.is_finite() {
662 1.0 / crate::math::sqrt(total_variance)
663 } else {
664 0.0
665 };
666
667 (pred, confidence)
668 }
669
670 pub fn predict_batch(&self, feature_matrix: &[Vec<f64>]) -> Vec<f64> {
672 feature_matrix.iter().map(|f| self.predict(f)).collect()
673 }
674
    /// Number of boosting steps in the ensemble.
    pub fn n_steps(&self) -> usize {
        self.steps.len()
    }

    /// Total tree count: one per step, plus each step's alternate tree when present.
    pub fn n_trees(&self) -> usize {
        self.steps.len() + self.steps.iter().filter(|s| s.has_alternate()).count()
    }

    /// Sum of leaf counts over all steps.
    pub fn total_leaves(&self) -> usize {
        self.steps.iter().map(|s| s.n_leaves()).sum()
    }

    /// Number of observations consumed by `train_one` so far.
    pub fn n_samples_seen(&self) -> u64 {
        self.samples_seen
    }

    /// The constant offset added to every prediction.
    pub fn base_prediction(&self) -> f64 {
        self.base_prediction
    }

    /// True once the base prediction has been set from warm-up targets.
    pub fn is_initialized(&self) -> bool {
        self.base_initialized
    }

    /// The configuration this ensemble was built with.
    pub fn config(&self) -> &SGBTConfig {
        &self.config
    }

    /// Overrides the learning rate for all subsequent training and prediction.
    #[inline]
    pub fn set_learning_rate(&mut self, lr: f64) {
        self.config.learning_rate = lr;
    }

    /// Read-only view of the boosting steps.
    pub fn steps(&self) -> &[BoostingStep] {
        &self.steps
    }

    /// The loss function in use.
    pub fn loss(&self) -> &L {
        &self.loss
    }
734
735 pub fn feature_importances(&self) -> Vec<f64> {
740 let mut totals: Vec<f64> = Vec::new();
742 for step in &self.steps {
743 let gains = step.slot().split_gains();
744 if totals.is_empty() && !gains.is_empty() {
745 totals.resize(gains.len(), 0.0);
746 }
747 for (i, &g) in gains.iter().enumerate() {
748 if i < totals.len() {
749 totals[i] += g;
750 }
751 }
752 }
753
754 let sum: f64 = totals.iter().sum();
756 if sum > 0.0 {
757 totals.iter_mut().for_each(|v| *v /= sum);
758 }
759 totals
760 }
761
    /// Configured feature names, if any were provided in the config.
    pub fn feature_names(&self) -> Option<&[String]> {
        self.config.feature_names.as_deref()
    }
766
767 pub fn named_feature_importances(&self) -> Option<Vec<(String, f64)>> {
772 let names = self.config.feature_names.as_ref()?;
773 let importances = self.feature_importances();
774 let mut pairs: Vec<(String, f64)> = names
775 .iter()
776 .zip(importances.iter().chain(core::iter::repeat(&0.0)))
777 .map(|(n, &v)| (n.clone(), v))
778 .collect();
779 pairs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(core::cmp::Ordering::Equal));
780 Some(pairs)
781 }
782
783 #[cfg(feature = "std")]
792 pub fn train_one_named(
793 &mut self,
794 features: &std::collections::HashMap<alloc::string::String, f64>,
795 target: f64,
796 ) {
797 let names = self
798 .config
799 .feature_names
800 .as_ref()
801 .expect("train_one_named requires feature_names to be configured");
802 let vec: Vec<f64> = names
803 .iter()
804 .map(|name| features.get(name).copied().unwrap_or(0.0))
805 .collect();
806 self.train_one(&(&vec[..], target));
807 }
808
809 #[cfg(feature = "std")]
817 pub fn predict_named(
818 &self,
819 features: &std::collections::HashMap<alloc::string::String, f64>,
820 ) -> f64 {
821 let names = self
822 .config
823 .feature_names
824 .as_ref()
825 .expect("predict_named requires feature_names to be configured");
826 let vec: Vec<f64> = names
827 .iter()
828 .map(|name| features.get(name).copied().unwrap_or(0.0))
829 .collect();
830 self.predict(&vec)
831 }
832
    /// Spread of the per-step contributions `lr * step.predict(features)` for
    /// one feature vector. Despite the name, this returns the Bessel-corrected
    /// sample *standard deviation* (note the final sqrt). Used by adaptive
    /// max-tree-samples in `train_one`. Returns 0.0 with fewer than two steps.
    fn contribution_variance(&self, features: &[f64]) -> f64 {
        let n = self.steps.len();
        if n <= 1 {
            return 0.0;
        }

        let lr = self.config.learning_rate;
        let mut sum = 0.0;
        let mut sq_sum = 0.0;
        for step in &self.steps {
            let c = lr * step.predict(features);
            sum += c;
            sq_sum += c * c;
        }
        let n_f = n as f64;
        let mean = sum / n_f;
        let var = (sq_sum / n_f) - (mean * mean);
        // abs() guards against a tiny negative variance from float cancellation.
        crate::math::sqrt((var.abs() * n_f / (n_f - 1.0)).max(0.0))
    }
862
863 fn refresh_bandwidths(&mut self) {
865 let current_sum: u64 = self.steps.iter().map(|s| s.slot().replacements()).sum();
866 if current_sum != self.last_replacement_sum || self.auto_bandwidths.is_empty() {
867 self.auto_bandwidths = self.compute_auto_bandwidths();
868 self.last_replacement_sum = current_sum;
869 }
870 }
871
    /// Derives one smoothing bandwidth per feature from the split thresholds
    /// of all active trees: roughly `K *` the median gap between neighboring
    /// distinct thresholds, falling back to `K * range / n_bins` when only two
    /// thresholds exist. Features without usable thresholds get
    /// `f64::INFINITY`. Returns an empty vector when no tree reports features.
    fn compute_auto_bandwidths(&self) -> Vec<f64> {
        // Multiplier applied to the gap estimate.
        const K: f64 = 2.0;

        // Widest feature count reported by any active tree.
        let n_features = self
            .steps
            .iter()
            .filter_map(|s| s.slot().active_tree().n_features())
            .max()
            .unwrap_or(0);

        if n_features == 0 {
            return Vec::new();
        }

        // Pool thresholds per feature across every step's active tree.
        let mut all_thresholds: Vec<Vec<f64>> = vec![Vec::new(); n_features];

        for step in &self.steps {
            let tree_thresholds = step
                .slot()
                .active_tree()
                .collect_split_thresholds_per_feature();
            for (i, ts) in tree_thresholds.into_iter().enumerate() {
                if i < n_features {
                    all_thresholds[i].extend(ts);
                }
            }
        }

        let n_bins = self.config.n_bins as f64;

        all_thresholds
            .iter()
            .map(|ts| {
                if ts.is_empty() {
                    return f64::INFINITY;
                }

                let mut sorted = ts.clone();
                sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
                // Collapse near-duplicate thresholds before measuring gaps.
                sorted.dedup_by(|a, b| crate::math::abs(*a - *b) < 1e-15);

                if sorted.len() < 2 {
                    return f64::INFINITY;
                }

                let mut gaps: Vec<f64> = sorted.windows(2).map(|w| w[1] - w[0]).collect();

                if sorted.len() < 3 {
                    // Exactly two distinct thresholds: use the bin-width heuristic.
                    let range = sorted.last().unwrap() - sorted.first().unwrap();
                    if range < 1e-15 {
                        return f64::INFINITY;
                    }
                    return (range / n_bins) * K;
                }

                gaps.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
                let median_gap = if gaps.len() % 2 == 0 {
                    (gaps[gaps.len() / 2 - 1] + gaps[gaps.len() / 2]) / 2.0
                } else {
                    gaps[gaps.len() / 2]
                };

                if median_gap < 1e-15 {
                    f64::INFINITY
                } else {
                    median_gap * K
                }
            })
            .collect()
    }
957
958 pub fn reset(&mut self) {
960 for step in &mut self.steps {
961 step.reset();
962 }
963 self.base_prediction = 0.0;
964 self.base_initialized = false;
965 self.initial_targets.clear();
966 self.samples_seen = 0;
967 self.rng_state = self.config.seed;
968 self.rolling_contribution_sigma = 0.0;
969 self.auto_bandwidths.clear();
970 self.last_replacement_sum = 0;
971 }
972
973 #[cfg(feature = "_serde_support")]
984 pub fn to_model_state(&self) -> crate::error::Result<crate::serde_support::ModelState> {
985 let loss_type = self.loss.loss_type().ok_or_else(|| {
986 crate::error::IrithyllError::Serialization(
987 "cannot auto-detect loss type for serialization: \
988 implement Loss::loss_type() or use to_model_state_with()"
989 .into(),
990 )
991 })?;
992 Ok(self.to_model_state_with(loss_type))
993 }
994
    /// Serializes the model with an explicitly specified loss tag.
    ///
    /// Captures every step's active/alternate tree and drift-detector state
    /// plus the scalar training counters and pruning statistics. Transient
    /// caches (auto-bandwidths, contribution sigma) are intentionally
    /// omitted and rebuilt after deserialization.
    #[cfg(feature = "_serde_support")]
    pub fn to_model_state_with(
        &self,
        loss_type: crate::loss::LossType,
    ) -> crate::serde_support::ModelState {
        use crate::serde_support::{ModelState, StepSnapshot};

        let steps = self
            .steps
            .iter()
            .map(|step| {
                let slot = step.slot();
                let tree_snap = snapshot_tree(slot.active_tree());
                let alt_snap = slot.alternate_tree().map(snapshot_tree);
                let drift_state = slot.detector().serialize_state();
                let alt_drift_state = slot.alt_detector().and_then(|d| d.serialize_state());
                StepSnapshot {
                    tree: tree_snap,
                    alternate_tree: alt_snap,
                    drift_state,
                    alt_drift_state,
                }
            })
            .collect();

        ModelState {
            config: self.config.clone(),
            loss_type,
            base_prediction: self.base_prediction,
            base_initialized: self.base_initialized,
            initial_targets: self.initial_targets.clone(),
            initial_target_count: self.initial_target_count,
            samples_seen: self.samples_seen,
            rng_state: self.rng_state,
            steps,
            rolling_mean_error: self.rolling_mean_error,
            contribution_ewma: self.contribution_ewma.clone(),
            low_contrib_count: self.low_contrib_count.clone(),
        }
    }
1038}
1039
#[cfg(feature = "_serde_support")]
impl SGBT<Box<dyn Loss>> {
    /// Reconstructs a type-erased ensemble from a serialized `ModelState`.
    ///
    /// Rebuilds each step's tree(s) and drift-detector state, re-derives the
    /// per-step `TreeConfig` from the saved `SGBTConfig` (mirroring
    /// `with_loss`), and restores or re-sizes the pruning statistics.
    pub fn from_model_state(state: crate::serde_support::ModelState) -> Self {
        use crate::ensemble::replacement::TreeSlot;

        // The loss implementation is chosen at runtime from the saved tag.
        let loss = state.loss_type.into_loss();

        // Same half-life → decay-alpha conversion as `with_loss`.
        let leaf_decay_alpha = state
            .config
            .leaf_half_life
            .map(|hl| crate::math::exp(-crate::math::ln(2.0) / hl as f64));
        let max_tree_samples = state.config.max_tree_samples;

        let steps: Vec<BoostingStep> = state
            .steps
            .iter()
            .enumerate()
            .map(|(i, step_snap)| {
                // Mirror the TreeConfig built by `with_loss` so a restored
                // model keeps training under the same tree settings.
                // Fix: max_leaf_output / adaptive_leaf_bound / min_hessian_sum
                // were previously dropped on deserialization.
                let tree_config = TreeConfig::new()
                    .max_depth(state.config.max_depth)
                    .n_bins(state.config.n_bins)
                    .lambda(state.config.lambda)
                    .gamma(state.config.gamma)
                    .grace_period(state.config.grace_period)
                    .delta(state.config.delta)
                    .feature_subsample_rate(state.config.feature_subsample_rate)
                    .leaf_decay_alpha_opt(leaf_decay_alpha)
                    .split_reeval_interval_opt(state.config.split_reeval_interval)
                    .feature_types_opt(state.config.feature_types.clone())
                    .gradient_clip_sigma_opt(state.config.gradient_clip_sigma)
                    .monotone_constraints_opt(state.config.monotone_constraints.clone())
                    .max_leaf_output_opt(state.config.max_leaf_output)
                    .adaptive_leaf_bound_opt(state.config.adaptive_leaf_bound)
                    .min_hessian_sum_opt(state.config.min_hessian_sum)
                    .adaptive_depth_opt(state.config.adaptive_depth)
                    .leaf_model_type(state.config.leaf_model_type.clone())
                    .seed(state.config.seed ^ (i as u64));

                // Rebuild the active and (optional) alternate tree.
                let active = rebuild_tree(&step_snap.tree, tree_config.clone());
                let alternate = step_snap
                    .alternate_tree
                    .as_ref()
                    .map(|snap| rebuild_tree(snap, tree_config.clone()));

                // Restore drift-detector state for the active tree...
                let mut detector = state.config.drift_detector.create();
                if let Some(ref ds) = step_snap.drift_state {
                    detector.restore_state(ds);
                }
                let mut slot = TreeSlot::from_trees(
                    active,
                    alternate,
                    tree_config,
                    detector,
                    max_tree_samples,
                );
                // ...and for the alternate tree's detector, when present.
                if let Some(ref ads) = step_snap.alt_drift_state {
                    if let Some(alt_det) = slot.alt_detector_mut() {
                        alt_det.restore_state(ads);
                    }
                }
                BoostingStep::from_slot(slot)
            })
            .collect();

        let n = steps.len();
        let has_pruning = state.config.quality_prune_alpha.is_some()
            || state.config.proactive_prune_interval.is_some();

        // Prefer serialized pruning statistics; otherwise size fresh vectors
        // to the step count when pruning is enabled.
        let contribution_ewma = if !state.contribution_ewma.is_empty() {
            state.contribution_ewma
        } else if has_pruning {
            vec![0.0; n]
        } else {
            Vec::new()
        };
        let low_contrib_count = if !state.low_contrib_count.is_empty() {
            state.low_contrib_count
        } else if has_pruning {
            vec![0; n]
        } else {
            Vec::new()
        };

        Self {
            config: state.config,
            steps,
            loss,
            base_prediction: state.base_prediction,
            base_initialized: state.base_initialized,
            initial_targets: state.initial_targets,
            initial_target_count: state.initial_target_count,
            samples_seen: state.samples_seen,
            rng_state: state.rng_state,
            contribution_ewma,
            low_contrib_count,
            rolling_mean_error: state.rolling_mean_error,
            // Not serialized: recomputed lazily as training resumes.
            rolling_contribution_sigma: 0.0,
            auto_bandwidths: Vec::new(),
            last_replacement_sum: 0,
        }
    }
}
1153
#[cfg(feature = "_serde_support")]
/// Flattens a `HoeffdingTree`'s arena into a plain-data `TreeSnapshot`
/// (parallel per-node vectors plus tree-level counters) for serialization.
pub(crate) fn snapshot_tree(
    tree: &crate::tree::hoeffding::HoeffdingTree,
) -> crate::serde_support::TreeSnapshot {
    use crate::serde_support::TreeSnapshot;
    use crate::tree::StreamingTree;
    let arena = tree.arena();
    TreeSnapshot {
        feature_idx: arena.feature_idx.clone(),
        threshold: arena.threshold.clone(),
        // NodeIds are stored as their raw inner indices.
        left: arena.left.iter().map(|id| id.0).collect(),
        right: arena.right.iter().map(|id| id.0).collect(),
        leaf_value: arena.leaf_value.clone(),
        is_leaf: arena.is_leaf.clone(),
        depth: arena.depth.clone(),
        sample_count: arena.sample_count.clone(),
        n_features: tree.n_features(),
        samples_seen: tree.n_samples_seen(),
        rng_state: tree.rng_state(),
        categorical_mask: arena.categorical_mask.clone(),
    }
}
1181
#[cfg(feature = "_serde_support")]
/// Reconstructs a `HoeffdingTree` from a serialized `TreeSnapshot` by
/// refilling a fresh arena from the snapshot's parallel per-node vectors.
pub(crate) fn rebuild_tree(
    snapshot: &crate::serde_support::TreeSnapshot,
    tree_config: TreeConfig,
) -> crate::tree::hoeffding::HoeffdingTree {
    use crate::tree::hoeffding::HoeffdingTree;
    use crate::tree::node::{NodeId, TreeArena};

    // `feature_idx` defines the node count; the other vectors must be at
    // least as long (slicing panics otherwise, as indexing did before).
    let node_count = snapshot.feature_idx.len();
    let mut arena = TreeArena::new();

    arena.feature_idx.extend_from_slice(&snapshot.feature_idx);
    arena.threshold.extend_from_slice(&snapshot.threshold[..node_count]);
    arena
        .left
        .extend(snapshot.left[..node_count].iter().map(|&raw| NodeId(raw)));
    arena
        .right
        .extend(snapshot.right[..node_count].iter().map(|&raw| NodeId(raw)));
    arena.leaf_value.extend_from_slice(&snapshot.leaf_value[..node_count]);
    arena.is_leaf.extend_from_slice(&snapshot.is_leaf[..node_count]);
    arena.depth.extend_from_slice(&snapshot.depth[..node_count]);
    arena
        .sample_count
        .extend_from_slice(&snapshot.sample_count[..node_count]);
    // The mask vector may be shorter (older snapshots); missing entries are None.
    arena.categorical_mask.extend(
        (0..node_count).map(|i| snapshot.categorical_mask.get(i).copied().flatten()),
    );

    HoeffdingTree::from_arena(
        tree_config,
        arena,
        snapshot.n_features,
        snapshot.samples_seen,
        snapshot.rng_state,
    )
}
1215
1216#[cfg(test)]
1217mod tests {
1218 use super::*;
1219 use alloc::boxed::Box;
1220 use alloc::vec;
1221 use alloc::vec::Vec;
1222
1223 fn default_config() -> SGBTConfig {
1224 SGBTConfig::builder()
1225 .n_steps(10)
1226 .learning_rate(0.1)
1227 .grace_period(20)
1228 .max_depth(4)
1229 .n_bins(16)
1230 .build()
1231 .unwrap()
1232 }
1233
1234 #[test]
1235 fn new_model_predicts_zero() {
1236 let model = SGBT::new(default_config());
1237 let pred = model.predict(&[1.0, 2.0, 3.0]);
1238 assert!(pred.abs() < 1e-12);
1239 }
1240
1241 #[test]
1242 fn train_one_does_not_panic() {
1243 let mut model = SGBT::new(default_config());
1244 model.train_one(&Sample::new(vec![1.0, 2.0, 3.0], 5.0));
1245 assert_eq!(model.n_samples_seen(), 1);
1246 }
1247
1248 #[test]
1249 fn prediction_changes_after_training() {
1250 let mut model = SGBT::new(default_config());
1251 let features = vec![1.0, 2.0, 3.0];
1252 for i in 0..100 {
1253 model.train_one(&Sample::new(features.clone(), (i as f64) * 0.1));
1254 }
1255 let pred = model.predict(&features);
1256 assert!(pred.is_finite());
1257 }
1258
1259 #[test]
1260 fn linear_signal_rmse_improves() {
1261 let config = SGBTConfig::builder()
1262 .n_steps(20)
1263 .learning_rate(0.1)
1264 .grace_period(10)
1265 .max_depth(3)
1266 .n_bins(16)
1267 .build()
1268 .unwrap();
1269 let mut model = SGBT::new(config);
1270
1271 let mut rng: u64 = 12345;
1272 let mut early_errors = Vec::new();
1273 let mut late_errors = Vec::new();
1274
1275 for i in 0..500 {
1276 rng ^= rng << 13;
1277 rng ^= rng >> 7;
1278 rng ^= rng << 17;
1279 let x1 = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
1280 rng ^= rng << 13;
1281 rng ^= rng >> 7;
1282 rng ^= rng << 17;
1283 let x2 = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
1284 let target = 2.0 * x1 + 3.0 * x2;
1285
1286 let pred = model.predict(&[x1, x2]);
1287 let error = (pred - target).powi(2);
1288
1289 if (50..150).contains(&i) {
1290 early_errors.push(error);
1291 }
1292 if i >= 400 {
1293 late_errors.push(error);
1294 }
1295
1296 model.train_one(&Sample::new(vec![x1, x2], target));
1297 }
1298
1299 let early_rmse = (early_errors.iter().sum::<f64>() / early_errors.len() as f64).sqrt();
1300 let late_rmse = (late_errors.iter().sum::<f64>() / late_errors.len() as f64).sqrt();
1301
1302 assert!(
1303 late_rmse < early_rmse,
1304 "RMSE should decrease: early={:.4}, late={:.4}",
1305 early_rmse,
1306 late_rmse
1307 );
1308 }
1309
1310 #[test]
1311 fn train_batch_equivalent_to_sequential() {
1312 let config = default_config();
1313 let mut model_seq = SGBT::new(config.clone());
1314 let mut model_batch = SGBT::new(config);
1315
1316 let samples: Vec<Sample> = (0..20)
1317 .map(|i| {
1318 let x = i as f64 * 0.5;
1319 Sample::new(vec![x, x * 2.0], x * 3.0)
1320 })
1321 .collect();
1322
1323 for s in &samples {
1324 model_seq.train_one(s);
1325 }
1326 model_batch.train_batch(&samples);
1327
1328 let pred_seq = model_seq.predict(&[1.0, 2.0]);
1329 let pred_batch = model_batch.predict(&[1.0, 2.0]);
1330
1331 assert!(
1332 (pred_seq - pred_batch).abs() < 1e-10,
1333 "seq={}, batch={}",
1334 pred_seq,
1335 pred_batch
1336 );
1337 }
1338
1339 #[test]
1340 fn reset_returns_to_initial() {
1341 let mut model = SGBT::new(default_config());
1342 for i in 0..100 {
1343 model.train_one(&Sample::new(vec![1.0, 2.0], i as f64));
1344 }
1345 model.reset();
1346 assert_eq!(model.n_samples_seen(), 0);
1347 assert!(!model.is_initialized());
1348 assert!(model.predict(&[1.0, 2.0]).abs() < 1e-12);
1349 }
1350
1351 #[test]
1352 fn base_prediction_initializes() {
1353 let mut model = SGBT::new(default_config());
1354 for i in 0..50 {
1355 model.train_one(&Sample::new(vec![1.0], i as f64 + 100.0));
1356 }
1357 assert!(model.is_initialized());
1358 let expected = (100.0 + 149.0) / 2.0;
1359 assert!((model.base_prediction() - expected).abs() < 1.0);
1360 }
1361
1362 #[test]
1363 fn with_loss_uses_custom_loss() {
1364 use crate::loss::logistic::LogisticLoss;
1365 let model = SGBT::with_loss(default_config(), LogisticLoss);
1366 let pred = model.predict_transformed(&[1.0, 2.0]);
1367 assert!(
1368 (pred - 0.5).abs() < 1e-6,
1369 "sigmoid(0) should be 0.5, got {}",
1370 pred
1371 );
1372 }
1373
1374 #[test]
1375 fn ewma_config_propagates_and_trains() {
1376 let config = SGBTConfig::builder()
1377 .n_steps(5)
1378 .learning_rate(0.1)
1379 .grace_period(10)
1380 .max_depth(3)
1381 .n_bins(16)
1382 .leaf_half_life(50)
1383 .build()
1384 .unwrap();
1385 let mut model = SGBT::new(config);
1386
1387 for i in 0..200 {
1388 let x = (i as f64) * 0.1;
1389 model.train_one(&Sample::new(vec![x, x * 2.0], x * 3.0));
1390 }
1391
1392 let pred = model.predict(&[1.0, 2.0]);
1393 assert!(
1394 pred.is_finite(),
1395 "EWMA-enabled model should produce finite predictions, got {}",
1396 pred
1397 );
1398 }
1399
1400 #[test]
1401 fn max_tree_samples_config_propagates() {
1402 let config = SGBTConfig::builder()
1403 .n_steps(5)
1404 .learning_rate(0.1)
1405 .grace_period(10)
1406 .max_depth(3)
1407 .n_bins(16)
1408 .max_tree_samples(200)
1409 .build()
1410 .unwrap();
1411 let mut model = SGBT::new(config);
1412
1413 for i in 0..500 {
1414 let x = (i as f64) * 0.1;
1415 model.train_one(&Sample::new(vec![x, x * 2.0], x * 3.0));
1416 }
1417
1418 let pred = model.predict(&[1.0, 2.0]);
1419 assert!(
1420 pred.is_finite(),
1421 "max_tree_samples model should produce finite predictions, got {}",
1422 pred
1423 );
1424 }
1425
1426 #[test]
1427 fn split_reeval_config_propagates() {
1428 let config = SGBTConfig::builder()
1429 .n_steps(5)
1430 .learning_rate(0.1)
1431 .grace_period(10)
1432 .max_depth(2)
1433 .n_bins(16)
1434 .split_reeval_interval(50)
1435 .build()
1436 .unwrap();
1437 let mut model = SGBT::new(config);
1438
1439 let mut rng: u64 = 12345;
1440 for _ in 0..1000 {
1441 rng ^= rng << 13;
1442 rng ^= rng >> 7;
1443 rng ^= rng << 17;
1444 let x1 = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
1445 rng ^= rng << 13;
1446 rng ^= rng >> 7;
1447 rng ^= rng << 17;
1448 let x2 = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
1449 let target = 2.0 * x1 + 3.0 * x2;
1450 model.train_one(&Sample::new(vec![x1, x2], target));
1451 }
1452
1453 let pred = model.predict(&[1.0, 2.0]);
1454 assert!(
1455 pred.is_finite(),
1456 "split re-eval model should produce finite predictions, got {}",
1457 pred
1458 );
1459 }
1460
1461 #[test]
1462 fn loss_accessor_works() {
1463 use crate::loss::logistic::LogisticLoss;
1464 let model = SGBT::with_loss(default_config(), LogisticLoss);
1465 let _loss: &LogisticLoss = model.loss();
1467 assert_eq!(_loss.n_outputs(), 1);
1468 }
1469
1470 #[test]
1471 fn clone_produces_independent_copy() {
1472 let config = default_config();
1473 let mut model = SGBT::new(config);
1474
1475 let mut rng: u64 = 99999;
1477 for _ in 0..200 {
1478 rng ^= rng << 13;
1479 rng ^= rng >> 7;
1480 rng ^= rng << 17;
1481 let x = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
1482 let target = 2.0 * x + 1.0;
1483 model.train_one(&Sample::new(vec![x], target));
1484 }
1485
1486 let mut cloned = model.clone();
1488
1489 let test_features = [3.0];
1491 let pred_original = model.predict(&test_features);
1492 let pred_cloned = cloned.predict(&test_features);
1493 assert!(
1494 (pred_original - pred_cloned).abs() < 1e-12,
1495 "clone should predict identically: original={pred_original}, cloned={pred_cloned}"
1496 );
1497
1498 for _ in 0..200 {
1500 rng ^= rng << 13;
1501 rng ^= rng >> 7;
1502 rng ^= rng << 17;
1503 let x = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
1504 let target = -3.0 * x + 5.0; cloned.train_one(&Sample::new(vec![x], target));
1506 }
1507
1508 let pred_original_after = model.predict(&test_features);
1509 let pred_cloned_after = cloned.predict(&test_features);
1510
1511 assert!(
1513 (pred_original - pred_original_after).abs() < 1e-12,
1514 "original should be unchanged after training clone"
1515 );
1516
1517 assert!(
1519 (pred_original_after - pred_cloned_after).abs() > 1e-6,
1520 "clone should diverge after independent training"
1521 );
1522 }
1523
1524 #[test]
1528 fn predict_with_confidence_finite() {
1529 let config = SGBTConfig::builder()
1530 .n_steps(5)
1531 .grace_period(10)
1532 .build()
1533 .unwrap();
1534 let mut model = SGBT::new(config);
1535
1536 for i in 0..100 {
1538 let x = i as f64 * 0.1;
1539 model.train_one(&(&[x, x * 2.0][..], x + 1.0));
1540 }
1541
1542 let (pred, confidence) = model.predict_with_confidence(&[1.0, 2.0]);
1543 assert!(pred.is_finite(), "prediction should be finite");
1544 assert!(confidence.is_finite(), "confidence should be finite");
1545 assert!(
1546 confidence > 0.0,
1547 "confidence should be positive after training"
1548 );
1549 }
1550
1551 #[test]
1555 fn predict_with_confidence_positive_after_training() {
1556 let config = SGBTConfig::builder()
1557 .n_steps(5)
1558 .grace_period(10)
1559 .build()
1560 .unwrap();
1561 let mut model = SGBT::new(config);
1562
1563 for i in 0..200 {
1565 let x = i as f64 * 0.05;
1566 model.train_one(&(&[x][..], x * 2.0));
1567 }
1568
1569 let (pred, confidence) = model.predict_with_confidence(&[1.0]);
1570
1571 assert!(pred.is_finite(), "prediction should be finite");
1572 assert!(
1573 confidence > 0.0 && confidence.is_finite(),
1574 "confidence should be finite and positive, got {}",
1575 confidence,
1576 );
1577
1578 let (pred2, conf2) = model.predict_with_confidence(&[1.0]);
1580 assert!(
1581 (pred - pred2).abs() < 1e-12,
1582 "same input should give same prediction"
1583 );
1584 assert!(
1585 (confidence - conf2).abs() < 1e-12,
1586 "same input should give same confidence"
1587 );
1588 }
1589
1590 #[test]
1594 fn predict_with_confidence_matches_predict() {
1595 let config = SGBTConfig::builder()
1596 .n_steps(10)
1597 .grace_period(10)
1598 .build()
1599 .unwrap();
1600 let mut model = SGBT::new(config);
1601
1602 for i in 0..200 {
1603 let x = (i as f64 - 100.0) * 0.01;
1604 model.train_one(&(&[x, x * x][..], x * 3.0 + 1.0));
1605 }
1606
1607 let pred = model.predict(&[0.5, 0.25]);
1608 let (conf_pred, _) = model.predict_with_confidence(&[0.5, 0.25]);
1609
1610 assert!(
1611 (pred - conf_pred).abs() < 1e-10,
1612 "prediction mismatch: predict()={} vs predict_with_confidence()={}",
1613 pred,
1614 conf_pred,
1615 );
1616 }
1617
1618 #[test]
1622 fn gradient_clip_config_builder() {
1623 let config = SGBTConfig::builder()
1624 .n_steps(10)
1625 .gradient_clip_sigma(3.0)
1626 .build()
1627 .unwrap();
1628
1629 assert_eq!(config.gradient_clip_sigma, Some(3.0));
1630 }
1631
1632 #[test]
1636 fn monotone_constraints_config_builder() {
1637 let config = SGBTConfig::builder()
1638 .n_steps(10)
1639 .monotone_constraints(vec![1, -1, 0])
1640 .build()
1641 .unwrap();
1642
1643 assert_eq!(config.monotone_constraints, Some(vec![1, -1, 0]));
1644 }
1645
1646 #[test]
1650 fn monotone_constraints_invalid_value_rejected() {
1651 let result = SGBTConfig::builder()
1652 .n_steps(10)
1653 .monotone_constraints(vec![1, 2, 0])
1654 .build();
1655
1656 assert!(result.is_err(), "constraint value 2 should be rejected");
1657 }
1658
1659 #[test]
1663 fn gradient_clip_sigma_negative_rejected() {
1664 let result = SGBTConfig::builder()
1665 .n_steps(10)
1666 .gradient_clip_sigma(-1.0)
1667 .build();
1668
1669 assert!(result.is_err(), "negative sigma should be rejected");
1670 }
1671
1672 #[test]
1676 fn gradient_clipping_reduces_outlier_impact() {
1677 let config_no_clip = SGBTConfig::builder()
1679 .n_steps(5)
1680 .grace_period(10)
1681 .build()
1682 .unwrap();
1683 let mut model_no_clip = SGBT::new(config_no_clip);
1684
1685 let config_clip = SGBTConfig::builder()
1687 .n_steps(5)
1688 .grace_period(10)
1689 .gradient_clip_sigma(3.0)
1690 .build()
1691 .unwrap();
1692 let mut model_clip = SGBT::new(config_clip);
1693
1694 for i in 0..100 {
1696 let x = (i as f64) * 0.01;
1697 let sample = (&[x][..], x * 2.0);
1698 model_no_clip.train_one(&sample);
1699 model_clip.train_one(&sample);
1700 }
1701
1702 let pred_no_clip_before = model_no_clip.predict(&[0.5]);
1703 let pred_clip_before = model_clip.predict(&[0.5]);
1704
1705 let outlier = (&[0.5_f64][..], 10000.0);
1707 model_no_clip.train_one(&outlier);
1708 model_clip.train_one(&outlier);
1709
1710 let pred_no_clip_after = model_no_clip.predict(&[0.5]);
1711 let pred_clip_after = model_clip.predict(&[0.5]);
1712
1713 let delta_no_clip = (pred_no_clip_after - pred_no_clip_before).abs();
1714 let delta_clip = (pred_clip_after - pred_clip_before).abs();
1715
1716 assert!(
1718 delta_clip <= delta_no_clip + 1e-10,
1719 "clipped model should be less affected: delta_clip={}, delta_no_clip={}",
1720 delta_clip,
1721 delta_no_clip,
1722 );
1723 }
1724
1725 #[test]
1729 fn train_batch_with_callback_fires() {
1730 let config = SGBTConfig::builder()
1731 .n_steps(3)
1732 .grace_period(5)
1733 .build()
1734 .unwrap();
1735 let mut model = SGBT::new(config);
1736
1737 let data: Vec<(Vec<f64>, f64)> = (0..25)
1738 .map(|i| (vec![i as f64 * 0.1], i as f64 * 0.5))
1739 .collect();
1740
1741 let mut callbacks = Vec::new();
1742 model.train_batch_with_callback(&data, 10, |n| {
1743 callbacks.push(n);
1744 });
1745
1746 assert_eq!(callbacks, vec![10, 20, 25]);
1748 }
1749
1750 #[test]
1754 fn train_batch_subsampled_trains_subset() {
1755 let config = SGBTConfig::builder()
1756 .n_steps(3)
1757 .grace_period(5)
1758 .build()
1759 .unwrap();
1760 let mut model = SGBT::new(config);
1761
1762 let data: Vec<(Vec<f64>, f64)> = (0..100)
1763 .map(|i| (vec![i as f64 * 0.01], i as f64 * 0.1))
1764 .collect();
1765
1766 model.train_batch_subsampled(&data, 20);
1768
1769 assert!(
1771 model.n_samples_seen() > 0,
1772 "model should have trained on subset"
1773 );
1774 assert!(
1775 model.n_samples_seen() <= 20,
1776 "model should have trained at most 20 samples, got {}",
1777 model.n_samples_seen(),
1778 );
1779 }
1780
1781 #[test]
1785 fn train_batch_subsampled_full_equals_batch() {
1786 let config1 = SGBTConfig::builder()
1787 .n_steps(3)
1788 .grace_period(5)
1789 .build()
1790 .unwrap();
1791 let config2 = config1.clone();
1792
1793 let mut model1 = SGBT::new(config1);
1794 let mut model2 = SGBT::new(config2);
1795
1796 let data: Vec<(Vec<f64>, f64)> = (0..50)
1797 .map(|i| (vec![i as f64 * 0.1], i as f64 * 0.5))
1798 .collect();
1799
1800 model1.train_batch(&data);
1801 model2.train_batch_subsampled(&data, 1000); assert_eq!(model1.n_samples_seen(), model2.n_samples_seen());
1805 let pred1 = model1.predict(&[2.5]);
1806 let pred2 = model2.predict(&[2.5]);
1807 assert!(
1808 (pred1 - pred2).abs() < 1e-12,
1809 "full subsample should equal batch: {} vs {}",
1810 pred1,
1811 pred2,
1812 );
1813 }
1814
1815 #[test]
1819 fn train_batch_subsampled_with_callback_works() {
1820 let config = SGBTConfig::builder()
1821 .n_steps(3)
1822 .grace_period(5)
1823 .build()
1824 .unwrap();
1825 let mut model = SGBT::new(config);
1826
1827 let data: Vec<(Vec<f64>, f64)> = (0..200)
1828 .map(|i| (vec![i as f64 * 0.01], i as f64 * 0.1))
1829 .collect();
1830
1831 let mut callbacks = Vec::new();
1832 model.train_batch_subsampled_with_callback(&data, 50, 10, |n| {
1833 callbacks.push(n);
1834 });
1835
1836 assert!(!callbacks.is_empty(), "should have received callbacks");
1838 assert_eq!(
1839 *callbacks.last().unwrap(),
1840 50,
1841 "final callback should be total samples"
1842 );
1843 }
1844
1845 fn xorshift64(state: &mut u64) -> u64 {
1851 let mut s = *state;
1852 s ^= s << 13;
1853 s ^= s >> 7;
1854 s ^= s << 17;
1855 *state = s;
1856 s
1857 }
1858
1859 fn rand_f64(state: &mut u64) -> f64 {
1860 xorshift64(state) as f64 / u64::MAX as f64
1861 }
1862
1863 fn linear_leaves_config() -> SGBTConfig {
1864 SGBTConfig::builder()
1865 .n_steps(10)
1866 .learning_rate(0.1)
1867 .grace_period(20)
1868 .max_depth(2) .n_bins(16)
1870 .leaf_model_type(crate::tree::leaf_model::LeafModelType::Linear {
1871 learning_rate: 0.1,
1872 decay: None,
1873 use_adagrad: false,
1874 })
1875 .build()
1876 .unwrap()
1877 }
1878
1879 #[test]
1880 fn linear_leaves_trains_without_panic() {
1881 let mut model = SGBT::new(linear_leaves_config());
1882 let mut rng = 42u64;
1883 for _ in 0..200 {
1884 let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1885 let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1886 let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1887 model.train_one(&Sample::new(vec![x1, x2], y));
1888 }
1889 assert_eq!(model.n_samples_seen(), 200);
1890 }
1891
1892 #[test]
1893 fn linear_leaves_prediction_finite() {
1894 let mut model = SGBT::new(linear_leaves_config());
1895 let mut rng = 42u64;
1896 for _ in 0..200 {
1897 let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1898 let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1899 let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1900 model.train_one(&Sample::new(vec![x1, x2], y));
1901 }
1902 let pred = model.predict(&[0.5, -0.3]);
1903 assert!(pred.is_finite(), "prediction should be finite, got {pred}");
1904 }
1905
1906 #[test]
1907 fn linear_leaves_learns_linear_target() {
1908 let mut model = SGBT::new(linear_leaves_config());
1909 let mut rng = 42u64;
1910 for _ in 0..500 {
1911 let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1912 let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1913 let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1914 model.train_one(&Sample::new(vec![x1, x2], y));
1915 }
1916
1917 let mut total_error = 0.0;
1919 for _ in 0..50 {
1920 let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1921 let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1922 let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1923 let pred = model.predict(&[x1, x2]);
1924 total_error += (pred - y).powi(2);
1925 }
1926 let mse = total_error / 50.0;
1927 assert!(
1928 mse < 5.0,
1929 "linear leaves MSE on linear target should be < 5.0, got {mse}"
1930 );
1931 }
1932
1933 #[test]
1934 fn linear_leaves_better_than_constant_at_low_depth() {
1935 let constant_config = SGBTConfig::builder()
1938 .n_steps(10)
1939 .learning_rate(0.1)
1940 .grace_period(20)
1941 .max_depth(2)
1942 .n_bins(16)
1943 .seed(0xDEAD)
1944 .build()
1945 .unwrap();
1946 let linear_config = SGBTConfig::builder()
1947 .n_steps(10)
1948 .learning_rate(0.1)
1949 .grace_period(20)
1950 .max_depth(2)
1951 .n_bins(16)
1952 .seed(0xDEAD)
1953 .leaf_model_type(crate::tree::leaf_model::LeafModelType::Linear {
1954 learning_rate: 0.1,
1955 decay: None,
1956 use_adagrad: false,
1957 })
1958 .build()
1959 .unwrap();
1960
1961 let mut constant_model = SGBT::new(constant_config);
1962 let mut linear_model = SGBT::new(linear_config);
1963 let mut rng = 42u64;
1964
1965 for _ in 0..500 {
1966 let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1967 let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1968 let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1969 let sample = Sample::new(vec![x1, x2], y);
1970 constant_model.train_one(&sample);
1971 linear_model.train_one(&sample);
1972 }
1973
1974 let mut constant_mse = 0.0;
1976 let mut linear_mse = 0.0;
1977 for _ in 0..100 {
1978 let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1979 let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1980 let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1981 constant_mse += (constant_model.predict(&[x1, x2]) - y).powi(2);
1982 linear_mse += (linear_model.predict(&[x1, x2]) - y).powi(2);
1983 }
1984 constant_mse /= 100.0;
1985 linear_mse /= 100.0;
1986
1987 assert!(
1989 linear_mse < constant_mse,
1990 "linear leaves MSE ({linear_mse:.4}) should be less than constant ({constant_mse:.4})"
1991 );
1992 }
1993
1994 #[test]
1995 fn adaptive_leaves_trains_without_panic() {
1996 let config = SGBTConfig::builder()
1997 .n_steps(10)
1998 .learning_rate(0.1)
1999 .grace_period(20)
2000 .max_depth(3)
2001 .n_bins(16)
2002 .leaf_model_type(crate::tree::leaf_model::LeafModelType::Adaptive {
2003 promote_to: Box::new(crate::tree::leaf_model::LeafModelType::Linear {
2004 learning_rate: 0.1,
2005 decay: None,
2006 use_adagrad: false,
2007 }),
2008 })
2009 .build()
2010 .unwrap();
2011
2012 let mut model = SGBT::new(config);
2013 let mut rng = 42u64;
2014 for _ in 0..500 {
2015 let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
2016 let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
2017 let y = 3.0 * x1 + 2.0 * x2 + 1.0;
2018 model.train_one(&Sample::new(vec![x1, x2], y));
2019 }
2020 let pred = model.predict(&[0.5, -0.3]);
2021 assert!(
2022 pred.is_finite(),
2023 "adaptive leaf prediction should be finite, got {pred}"
2024 );
2025 }
2026
2027 #[test]
2028 fn linear_leaves_with_decay_trains_without_panic() {
2029 let config = SGBTConfig::builder()
2030 .n_steps(10)
2031 .learning_rate(0.1)
2032 .grace_period(20)
2033 .max_depth(3)
2034 .n_bins(16)
2035 .leaf_model_type(crate::tree::leaf_model::LeafModelType::Linear {
2036 learning_rate: 0.1,
2037 decay: Some(0.995),
2038 use_adagrad: false,
2039 })
2040 .build()
2041 .unwrap();
2042
2043 let mut model = SGBT::new(config);
2044 let mut rng = 42u64;
2045 for _ in 0..500 {
2046 let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
2047 let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
2048 let y = 3.0 * x1 + 2.0 * x2 + 1.0;
2049 model.train_one(&Sample::new(vec![x1, x2], y));
2050 }
2051 let pred = model.predict(&[0.5, -0.3]);
2052 assert!(
2053 pred.is_finite(),
2054 "decay leaf prediction should be finite, got {pred}"
2055 );
2056 }
2057
2058 #[test]
2062 fn predict_smooth_returns_finite() {
2063 let config = SGBTConfig::builder()
2064 .n_steps(5)
2065 .learning_rate(0.1)
2066 .grace_period(10)
2067 .build()
2068 .unwrap();
2069 let mut model = SGBT::new(config);
2070
2071 for i in 0..200 {
2072 let x = (i as f64) * 0.1;
2073 model.train_one(&Sample::new(vec![x, x.sin()], 2.0 * x + 1.0));
2074 }
2075
2076 let pred_hard = model.predict(&[1.0, 1.0_f64.sin()]);
2077 let pred_smooth = model.predict_smooth(&[1.0, 1.0_f64.sin()], 0.5);
2078
2079 assert!(pred_hard.is_finite(), "hard prediction should be finite");
2080 assert!(
2081 pred_smooth.is_finite(),
2082 "smooth prediction should be finite"
2083 );
2084 }
2085
2086 #[test]
2090 fn predict_smooth_converges_to_hard_at_small_bandwidth() {
2091 let config = SGBTConfig::builder()
2092 .n_steps(5)
2093 .learning_rate(0.1)
2094 .grace_period(10)
2095 .build()
2096 .unwrap();
2097 let mut model = SGBT::new(config);
2098
2099 for i in 0..300 {
2100 let x = (i as f64) * 0.1;
2101 model.train_one(&Sample::new(vec![x, x * 0.5], 2.0 * x + 1.0));
2102 }
2103
2104 let features = [5.0, 2.5];
2105 let hard = model.predict(&features);
2106 let smooth = model.predict_smooth(&features, 0.001);
2107
2108 assert!(
2109 (hard - smooth).abs() < 0.5,
2110 "smooth with tiny bandwidth should approximate hard: hard={}, smooth={}",
2111 hard,
2112 smooth,
2113 );
2114 }
2115
2116 #[test]
2117 fn auto_bandwidth_computed_after_training() {
2118 let config = SGBTConfig::builder()
2119 .n_steps(5)
2120 .learning_rate(0.1)
2121 .grace_period(10)
2122 .build()
2123 .unwrap();
2124 let mut model = SGBT::new(config);
2125
2126 assert!(model.auto_bandwidths().is_empty());
2128
2129 for i in 0..200 {
2130 let x = (i as f64) * 0.1;
2131 model.train_one(&Sample::new(vec![x, x * 0.5], 2.0 * x + 1.0));
2132 }
2133
2134 let bws = model.auto_bandwidths();
2136 assert_eq!(bws.len(), 2, "should have 2 feature bandwidths");
2137
2138 let pred = model.predict(&[5.0, 2.5]);
2140 assert!(
2141 pred.is_finite(),
2142 "auto-bandwidth predict should be finite: {}",
2143 pred
2144 );
2145 }
2146
2147 #[test]
2148 fn predict_interpolated_returns_finite() {
2149 let config = SGBTConfig::builder()
2150 .n_steps(5)
2151 .learning_rate(0.01)
2152 .build()
2153 .unwrap();
2154 let mut model = SGBT::new(config);
2155
2156 for i in 0..200 {
2157 let x = (i as f64) * 0.1;
2158 model.train_one(&Sample::new(vec![x, x.sin()], x.cos()));
2159 }
2160
2161 let pred = model.predict_interpolated(&[1.0, 0.5]);
2162 assert!(
2163 pred.is_finite(),
2164 "interpolated prediction should be finite: {}",
2165 pred
2166 );
2167 }
2168
2169 #[test]
2170 fn predict_sibling_interpolated_varies_with_features() {
2171 let config = SGBTConfig::builder()
2172 .n_steps(10)
2173 .learning_rate(0.1)
2174 .grace_period(10)
2175 .max_depth(6)
2176 .delta(0.1)
2177 .build()
2178 .unwrap();
2179 let mut model = SGBT::new(config);
2180
2181 for i in 0..2000 {
2182 let x = (i as f64) * 0.01;
2183 let y = x.sin() * x + 0.5 * (x * 2.0).cos();
2184 model.train_one(&Sample::new(vec![x, x * 0.3], y));
2185 }
2186
2187 let pred = model.predict_sibling_interpolated(&[5.0, 1.5]);
2189 assert!(pred.is_finite(), "sibling interpolated should be finite");
2190
2191 let bws = model.auto_bandwidths();
2194 if bws.iter().any(|&b| b.is_finite()) {
2195 let hard: Vec<f64> = (0..200)
2196 .map(|i| model.predict(&[i as f64 * 0.1, i as f64 * 0.03]))
2197 .collect();
2198 let sib: Vec<f64> = (0..200)
2199 .map(|i| model.predict_sibling_interpolated(&[i as f64 * 0.1, i as f64 * 0.03]))
2200 .collect();
2201 let hc = hard
2202 .windows(2)
2203 .filter(|w| (w[0] - w[1]).abs() > f64::EPSILON)
2204 .count();
2205 let sc = sib
2206 .windows(2)
2207 .filter(|w| (w[0] - w[1]).abs() > f64::EPSILON)
2208 .count();
2209 assert!(
2210 sc >= hc,
2211 "sibling should produce >= hard changes: sib={}, hard={}",
2212 sc,
2213 hc
2214 );
2215 }
2216 }
2217
2218 #[test]
2219 fn predict_graduated_returns_finite() {
2220 let config = SGBTConfig::builder()
2221 .n_steps(5)
2222 .learning_rate(0.01)
2223 .max_tree_samples(200)
2224 .shadow_warmup(50)
2225 .build()
2226 .unwrap();
2227 let mut model = SGBT::new(config);
2228
2229 for i in 0..300 {
2230 let x = (i as f64) * 0.1;
2231 model.train_one(&Sample::new(vec![x, x.sin()], x.cos()));
2232 }
2233
2234 let pred = model.predict_graduated(&[1.0, 0.5]);
2235 assert!(
2236 pred.is_finite(),
2237 "graduated prediction should be finite: {}",
2238 pred
2239 );
2240
2241 let pred2 = model.predict_graduated_sibling_interpolated(&[1.0, 0.5]);
2242 assert!(
2243 pred2.is_finite(),
2244 "graduated+sibling prediction should be finite: {}",
2245 pred2
2246 );
2247 }
2248
2249 #[test]
2250 fn shadow_warmup_validation() {
2251 let result = SGBTConfig::builder()
2252 .n_steps(5)
2253 .learning_rate(0.01)
2254 .shadow_warmup(0)
2255 .build();
2256 assert!(result.is_err(), "shadow_warmup=0 should fail validation");
2257 }
2258
2259 #[test]
2264 fn adaptive_mts_defaults_to_none() {
2265 let cfg = SGBTConfig::default();
2266 assert!(
2267 cfg.adaptive_mts.is_none(),
2268 "adaptive_mts should default to None"
2269 );
2270 }
2271
2272 #[test]
2273 fn adaptive_mts_config_builder() {
2274 let cfg = SGBTConfig::builder()
2275 .n_steps(10)
2276 .adaptive_mts(500, 2.0)
2277 .build()
2278 .unwrap();
2279 assert_eq!(
2280 cfg.adaptive_mts,
2281 Some((500, 2.0)),
2282 "adaptive_mts should store (base_mts, k)"
2283 );
2284 }
2285
2286 #[test]
2287 fn adaptive_mts_validation_rejects_low_base() {
2288 let result = SGBTConfig::builder()
2289 .n_steps(5)
2290 .adaptive_mts(50, 1.0)
2291 .build();
2292 assert!(
2293 result.is_err(),
2294 "adaptive_mts with base_mts < 100 should fail"
2295 );
2296 }
2297
2298 #[test]
2299 fn adaptive_mts_validation_rejects_zero_k() {
2300 let result = SGBTConfig::builder()
2301 .n_steps(5)
2302 .adaptive_mts(500, 0.0)
2303 .build();
2304 assert!(result.is_err(), "adaptive_mts with k=0 should fail");
2305 }
2306
2307 #[test]
2308 fn adaptive_mts_trains_without_panic() {
2309 let config = SGBTConfig::builder()
2310 .n_steps(5)
2311 .learning_rate(0.1)
2312 .grace_period(10)
2313 .max_depth(3)
2314 .n_bins(16)
2315 .adaptive_mts(200, 1.0)
2316 .build()
2317 .unwrap();
2318 let mut model = SGBT::new(config);
2319
2320 for i in 0..500 {
2321 let x = (i as f64) * 0.1;
2322 model.train_one(&Sample::new(vec![x, x * 2.0], x * 3.0));
2323 }
2324
2325 let pred = model.predict(&[1.0, 2.0]);
2326 assert!(
2327 pred.is_finite(),
2328 "adaptive_mts model should produce finite predictions, got {}",
2329 pred
2330 );
2331 }
2332
2333 #[test]
2338 fn proactive_prune_defaults_to_none() {
2339 let cfg = SGBTConfig::default();
2340 assert!(
2341 cfg.proactive_prune_interval.is_none(),
2342 "proactive_prune_interval should default to None"
2343 );
2344 }
2345
2346 #[test]
2347 fn proactive_prune_config_builder() {
2348 let cfg = SGBTConfig::builder()
2349 .n_steps(10)
2350 .proactive_prune_interval(500)
2351 .build()
2352 .unwrap();
2353 assert_eq!(
2354 cfg.proactive_prune_interval,
2355 Some(500),
2356 "proactive_prune_interval should be set"
2357 );
2358 }
2359
2360 #[test]
2361 fn proactive_prune_validation_rejects_low_interval() {
2362 let result = SGBTConfig::builder()
2363 .n_steps(5)
2364 .proactive_prune_interval(50)
2365 .build();
2366 assert!(
2367 result.is_err(),
2368 "proactive_prune_interval < 100 should fail"
2369 );
2370 }
2371
2372 #[test]
2373 fn proactive_prune_enables_contribution_tracking() {
2374 let config = SGBTConfig::builder()
2375 .n_steps(5)
2376 .learning_rate(0.1)
2377 .grace_period(10)
2378 .max_depth(3)
2379 .n_bins(16)
2380 .proactive_prune_interval(200)
2381 .build()
2382 .unwrap();
2383
2384 assert!(config.quality_prune_alpha.is_none());
2386
2387 let mut model = SGBT::new(config);
2388
2389 for i in 0..100 {
2391 let x = (i as f64) * 0.1;
2392 model.train_one(&Sample::new(vec![x, x * 2.0], x * 3.0));
2393 }
2394
2395 let pred = model.predict(&[1.0, 2.0]);
2399 assert!(
2400 pred.is_finite(),
2401 "proactive_prune model should produce finite predictions, got {}",
2402 pred
2403 );
2404 }
2405
2406 #[test]
2407 fn proactive_prune_trains_without_panic() {
2408 let config = SGBTConfig::builder()
2409 .n_steps(5)
2410 .learning_rate(0.1)
2411 .grace_period(10)
2412 .max_depth(3)
2413 .n_bins(16)
2414 .proactive_prune_interval(200)
2415 .build()
2416 .unwrap();
2417 let mut model = SGBT::new(config);
2418
2419 for i in 0..500 {
2421 let x = (i as f64) * 0.1;
2422 model.train_one(&Sample::new(vec![x, x * 2.0], x * 3.0));
2423 }
2424
2425 let pred = model.predict(&[1.0, 2.0]);
2426 assert!(
2427 pred.is_finite(),
2428 "proactive_prune model should produce finite predictions after pruning, got {}",
2429 pred
2430 );
2431 }
2432}