1use alloc::collections::VecDeque;
7use alloc::string::String;
8use alloc::vec;
9use alloc::vec::Vec;
10
11use core::fmt;
12
13use crate::ensemble::config::SGBTConfig;
14use crate::ensemble::step::BoostingStep;
15use crate::loss::squared::SquaredLoss;
16use crate::loss::Loss;
17use crate::sample::Observation;
18#[allow(unused_imports)] use crate::sample::Sample;
20
/// Scratch state for optional training diagnostics.
///
/// All fields default to zero/empty. Within this file only
/// `contribution_accuracy` is explicitly sized (to the number of boosting
/// steps, in `with_loss` and `reset`); the remaining fields appear to be
/// populated by diagnostic code elsewhere — NOTE(review): semantics below
/// are inferred from names, confirm against the diagnostics module.
#[derive(Debug, Clone, Default)]
#[allow(dead_code)] // some fields are only read by optional diagnostic passes
pub(crate) struct DiagnosticCache {
    // Per-step contributions from the previous sample (presumably for
    // step-to-step deltas — TODO confirm).
    pub(crate) prev_contributions: Vec<f64>,
    // Per-step contributions from two samples ago.
    pub(crate) prev_prev_contributions: Vec<f64>,
    // Cached diagnostic scalars, recomputed elsewhere and stored here.
    pub(crate) cached_residual_alignment: f64,
    pub(crate) cached_reg_sensitivity: f64,
    pub(crate) cached_depth_sufficiency: f64,
    pub(crate) cached_effective_dof: f64,
    // Per-step accuracy tracking; sized to `n_steps` on construction/reset.
    pub(crate) contribution_accuracy: Vec<f64>,
    // Pruning alpha captured for diagnostics — TODO confirm writer.
    pub(crate) prune_alpha: f64,
}
43
/// Streaming (online) gradient-boosted trees ensemble, generic over the
/// loss function `L` (defaults to squared error).
pub struct SGBT<L: Loss = SquaredLoss> {
    /// Full configuration this ensemble was built from.
    pub(crate) config: SGBTConfig,
    /// One boosting step (tree slot + drift detector) per configured step.
    pub(crate) steps: Vec<BoostingStep>,
    /// Loss providing gradients/hessians, the initial prediction, and the
    /// output transform.
    pub(crate) loss: L,
    /// Constant base score; computed by `loss.initial_prediction` from the
    /// first `initial_target_count` targets.
    pub(crate) base_prediction: f64,
    /// True once `base_prediction` has been fixed.
    pub(crate) base_initialized: bool,
    /// Targets buffered before base initialization (cleared and shrunk after).
    pub(crate) initial_targets: Vec<f64>,
    /// How many targets to collect before fixing the base prediction.
    pub(crate) initial_target_count: usize,
    /// Samples passed to `train_one` (counted even when rejected as non-finite).
    pub(crate) samples_seen: u64,
    /// xorshift64 state used for subsampling and per-step train counts.
    pub(crate) rng_state: u64,
    /// Per-step EWMA of |contribution|; allocated only when pruning is enabled.
    pub(crate) contribution_ewma: Vec<f64>,
    /// Consecutive low-contribution counts per step (pruning patience).
    pub(crate) low_contrib_count: Vec<u64>,
    /// Rolling EWMA of absolute error, used for error-weighted gradients.
    pub(crate) rolling_mean_error: f64,
    /// Per-feature smoothing bandwidths derived from split thresholds.
    pub(crate) auto_bandwidths: Vec<f64>,
    /// Total tree replacements observed when bandwidths were last recomputed.
    pub(crate) last_replacement_sum: u64,
    /// Slow EWMA of per-sample contribution spread (adaptive max-tree-samples).
    pub(crate) rolling_contribution_sigma: f64,
    /// Ring buffer pre-sized to the grace period — NOTE(review): never written
    /// in this file; confirm its producer elsewhere.
    pub(crate) sigma_ring: VecDeque<f64>,
    /// Replacement counter for adaptive-MTS bookkeeping — NOTE(review): only
    /// initialized/reset here; confirm its writer elsewhere.
    pub(crate) mts_replacement_sum: u64,
    /// Optional training-diagnostics scratch space.
    pub(crate) diag: DiagnosticCache,
}
119
120impl<L: Loss + Clone> Clone for SGBT<L> {
121 fn clone(&self) -> Self {
122 Self {
123 config: self.config.clone(),
124 steps: self.steps.clone(),
125 loss: self.loss.clone(),
126 base_prediction: self.base_prediction,
127 base_initialized: self.base_initialized,
128 initial_targets: self.initial_targets.clone(),
129 initial_target_count: self.initial_target_count,
130 samples_seen: self.samples_seen,
131 rng_state: self.rng_state,
132 contribution_ewma: self.contribution_ewma.clone(),
133 low_contrib_count: self.low_contrib_count.clone(),
134 rolling_mean_error: self.rolling_mean_error,
135 auto_bandwidths: self.auto_bandwidths.clone(),
136 last_replacement_sum: self.last_replacement_sum,
137 rolling_contribution_sigma: self.rolling_contribution_sigma,
138 sigma_ring: self.sigma_ring.clone(),
139 mts_replacement_sum: self.mts_replacement_sum,
140 diag: self.diag.clone(),
141 }
142 }
143}
144
145impl<L: Loss> fmt::Debug for SGBT<L> {
146 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
147 f.debug_struct("SGBT")
148 .field("n_steps", &self.steps.len())
149 .field("samples_seen", &self.samples_seen)
150 .field("base_prediction", &self.base_prediction)
151 .field("base_initialized", &self.base_initialized)
152 .finish()
153 }
154}
155
impl SGBT<SquaredLoss> {
    /// Creates an ensemble using the default squared-error loss.
    ///
    /// Equivalent to `SGBT::with_loss(config, SquaredLoss)`.
    pub fn new(config: SGBTConfig) -> Self {
        Self::with_loss(config, SquaredLoss)
    }
}
169
170impl<L: Loss> SGBT<L> {
    /// Builds an ensemble from `config` with an explicit loss.
    ///
    /// Derives the per-tree config (including the optional leaf half-life
    /// decay), constructs one `BoostingStep` per configured step with a
    /// per-step seed, and preallocates pruning/diagnostic buffers only when
    /// the corresponding features are enabled.
    pub fn with_loss(config: SGBTConfig, loss: L) -> Self {
        // Convert a half-life into a per-sample decay factor: alpha = 2^(-1/hl).
        let leaf_decay_alpha = config
            .leaf_half_life
            .map(|hl| crate::math::exp(-crate::math::ln(2.0) / hl as f64));

        let tree_config = crate::ensemble::config::build_tree_config(&config)
            .leaf_decay_alpha_opt(leaf_decay_alpha);

        let max_tree_samples = config.max_tree_samples;

        let shadow_warmup = config.shadow_warmup.unwrap_or(0);
        // One step per boosting stage; each gets its own seed (XOR of the
        // global seed with the step index) and its own drift detector.
        let steps: Vec<BoostingStep> = (0..config.n_steps)
            .map(|i| {
                let mut tc = tree_config.clone();
                tc.seed = config.seed ^ (i as u64);
                let detector = config.drift_detector.create();
                if shadow_warmup > 0 {
                    BoostingStep::new_with_graduated(tc, detector, max_tree_samples, shadow_warmup)
                } else {
                    BoostingStep::new_with_max_samples(tc, detector, max_tree_samples)
                }
            })
            .collect();

        // Copy out the config fields needed below before `config` is moved
        // into the struct.
        let seed = config.seed;
        let initial_target_count = config.initial_target_count;
        let n = config.n_steps;
        let has_pruning =
            config.quality_prune_alpha.is_some() || config.proactive_prune_interval.is_some();
        let grace_period = config.grace_period;
        Self {
            config,
            steps,
            loss,
            base_prediction: 0.0,
            base_initialized: false,
            initial_targets: Vec::new(),
            initial_target_count,
            samples_seen: 0,
            rng_state: seed,
            // Pruning bookkeeping is only allocated when pruning is enabled.
            contribution_ewma: if has_pruning {
                vec![0.0; n]
            } else {
                Vec::new()
            },
            low_contrib_count: if has_pruning { vec![0; n] } else { Vec::new() },
            rolling_mean_error: 0.0,
            rolling_contribution_sigma: 0.0,
            auto_bandwidths: Vec::new(),
            last_replacement_sum: 0,
            sigma_ring: VecDeque::with_capacity(grace_period),
            mts_replacement_sum: 0,
            diag: DiagnosticCache {
                contribution_accuracy: vec![0.0; n],
                ..Default::default()
            },
        }
    }
245
    /// Trains the ensemble on a single observation.
    ///
    /// Pipeline: validate the sample, warm up the base prediction, optionally
    /// adapt each step's max-tree-samples budget, then run one boosting pass
    /// (gradient/hessian per step against the running prediction), applying
    /// quality pruning per step and a periodic proactive prune, and finally
    /// refresh the smoothing bandwidths.
    ///
    /// Non-finite targets/features are rejected — but note `samples_seen` is
    /// incremented before that check, so rejected samples still count.
    pub fn train_one(&mut self, sample: &impl Observation) {
        self.samples_seen += 1;
        let target = sample.target();
        let features = sample.features();

        // Ignore samples containing NaN/inf anywhere.
        if !target.is_finite() || !features.iter().all(|f| f.is_finite()) {
            return;
        }

        // Warm-up: buffer targets until enough arrive to fix the base score.
        // Training below still proceeds (with base 0.0) during warm-up —
        // NOTE(review): presumably intentional; confirm.
        if !self.base_initialized {
            self.initial_targets.push(target);
            if self.initial_targets.len() >= self.initial_target_count {
                self.base_prediction = self.loss.initial_prediction(&self.initial_targets);
                self.base_initialized = true;
                self.initial_targets.clear();
                self.initial_targets.shrink_to_fit();
            }
        }

        let mut current_pred = self.base_prediction;

        // Adaptive max-tree-samples: shrink the budget when the per-step
        // contribution spread is large relative to its slow rolling average.
        if let Some((base_mts, k)) = self.config.adaptive_mts {
            let sigma = self.contribution_variance(features);
            self.rolling_contribution_sigma =
                0.999 * self.rolling_contribution_sigma + 0.001 * sigma;

            let normalized = if self.rolling_contribution_sigma > 1e-10 {
                sigma / self.rolling_contribution_sigma
            } else {
                1.0
            };
            let factor = 1.0 / (1.0 + k * normalized);
            // Never drop below the configured floor or twice the grace period.
            let floor = (base_mts as f64 * self.config.adaptive_mts_floor)
                .max(self.config.grace_period as f64 * 2.0);
            let effective_mts = ((base_mts as f64) * factor).max(floor) as u64;
            for step in &mut self.steps {
                step.slot_mut().set_max_tree_samples(Some(effective_mts));
            }
        }

        // Quality pruning is active if explicitly configured, or implicitly
        // (alpha = 0.01) when proactive pruning is on.
        let prune_alpha = self
            .config
            .quality_prune_alpha
            .or_else(|| self.config.proactive_prune_interval.map(|_| 0.01));
        let prune_threshold = self.config.quality_prune_threshold;
        let prune_patience = self.config.quality_prune_patience;

        // Marks steps reset this call, so the proactive prune skips them.
        let mut replaced_this_step = vec![false; self.steps.len()];

        // Error weighting: up-weight samples with above-average absolute
        // error (capped at 10x), tracking the average as an EWMA.
        let error_weight = if let Some(ew_alpha) = self.config.error_weight_alpha {
            let abs_error = crate::math::abs(target - current_pred);
            if self.rolling_mean_error > 1e-15 {
                let w = (1.0 + abs_error / (self.rolling_mean_error + 1e-15)).min(10.0);
                self.rolling_mean_error =
                    ew_alpha * abs_error + (1.0 - ew_alpha) * self.rolling_mean_error;
                w
            } else {
                // First finite error seeds the rolling mean; weight is neutral.
                self.rolling_mean_error = abs_error.max(1e-15);
                1.0
            }
        } else {
            1.0
        };

        // Boosting pass: each step fits the residual of the prediction
        // accumulated so far (indexing kept for the multiple &mut self uses).
        #[allow(clippy::needless_range_loop)]
        for s in 0..self.steps.len() {
            let gradient = self.loss.gradient(target, current_pred) * error_weight;
            let hessian = self.loss.hessian(target, current_pred) * error_weight;
            let train_count = self
                .config
                .variant
                .train_count(hessian, &mut self.rng_state);

            let step_pred =
                self.steps[s].train_and_predict(features, gradient, hessian, train_count);

            current_pred += self.config.learning_rate * step_pred;

            // Quality pruning: reset a step whose EWMA contribution stays
            // below threshold for `prune_patience` consecutive samples.
            if let Some(alpha) = prune_alpha {
                let contribution = crate::math::abs(self.config.learning_rate * step_pred);
                self.contribution_ewma[s] =
                    alpha * contribution + (1.0 - alpha) * self.contribution_ewma[s];

                if self.contribution_ewma[s] < prune_threshold {
                    self.low_contrib_count[s] += 1;
                    if self.low_contrib_count[s] >= prune_patience {
                        self.steps[s].reset();
                        self.contribution_ewma[s] = 0.0;
                        self.low_contrib_count[s] = 0;
                        replaced_this_step[s] = true;
                    }
                } else {
                    self.low_contrib_count[s] = 0;
                }
            }
        }

        // Proactive pruning: every `interval` samples, reset the single worst
        // mature step if its EWMA contribution is below the 25th percentile
        // of all mature steps.
        if let Some(interval) = self.config.proactive_prune_interval {
            if self.samples_seen % interval == 0
                && self.samples_seen > 0
                && !self.contribution_ewma.is_empty()
            {
                // Only consider steps that have seen at least half an interval
                // of samples and were not already reset above.
                let min_age = interval / 2;

                let mature: Vec<(usize, f64)> = self
                    .steps
                    .iter()
                    .enumerate()
                    .zip(self.contribution_ewma.iter())
                    .filter(|((i, step), _)| {
                        step.n_samples_seen() >= min_age && !replaced_this_step[*i]
                    })
                    .map(|((i, _), &ewma)| (i, ewma))
                    .collect();

                if !mature.is_empty() {
                    let mut sorted_ewma: Vec<f64> = mature.iter().map(|(_, e)| *e).collect();
                    sorted_ewma
                        .sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
                    let p25_idx = (sorted_ewma.len().saturating_sub(1)) / 4;
                    let p25 = sorted_ewma[p25_idx];

                    let worst = mature.iter().min_by(|(_, a), (_, b)| {
                        a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal)
                    });

                    if let Some(&(worst_idx, worst_ewma)) = worst {
                        if worst_ewma < p25 {
                            self.steps[worst_idx].reset();
                            self.contribution_ewma[worst_idx] = 0.0;
                            self.low_contrib_count[worst_idx] = 0;
                        }
                    }
                }
            }
        }

        // Recompute smoothing bandwidths if any tree was replaced.
        self.refresh_bandwidths();
    }
406
407 pub fn train_batch<O: Observation>(&mut self, samples: &[O]) {
409 for sample in samples {
410 self.train_one(sample);
411 }
412 }
413
414 pub fn train_batch_with_callback<O: Observation, F: FnMut(usize)>(
416 &mut self,
417 samples: &[O],
418 interval: usize,
419 mut callback: F,
420 ) {
421 let interval = interval.max(1);
422 for (i, sample) in samples.iter().enumerate() {
423 self.train_one(sample);
424 if (i + 1) % interval == 0 {
425 callback(i + 1);
426 }
427 }
428 let total = samples.len();
429 if total % interval != 0 {
430 callback(total);
431 }
432 }
433
434 pub fn train_batch_subsampled<O: Observation>(&mut self, samples: &[O], max_samples: usize) {
436 if max_samples >= samples.len() {
437 self.train_batch(samples);
438 return;
439 }
440 let mut reservoir: Vec<usize> = (0..max_samples).collect();
441 let mut rng = self.rng_state;
442 for i in max_samples..samples.len() {
443 rng ^= rng << 13;
444 rng ^= rng >> 7;
445 rng ^= rng << 17;
446 let j = (rng % (i as u64 + 1)) as usize;
447 if j < max_samples {
448 reservoir[j] = i;
449 }
450 }
451 self.rng_state = rng;
452 reservoir.sort_unstable();
453 for &idx in &reservoir {
454 self.train_one(&samples[idx]);
455 }
456 }
457
458 pub fn train_batch_subsampled_with_callback<O: Observation, F: FnMut(usize)>(
460 &mut self,
461 samples: &[O],
462 max_samples: usize,
463 interval: usize,
464 mut callback: F,
465 ) {
466 if max_samples >= samples.len() {
467 self.train_batch_with_callback(samples, interval, callback);
468 return;
469 }
470 let mut reservoir: Vec<usize> = (0..max_samples).collect();
471 let mut rng = self.rng_state;
472 for i in max_samples..samples.len() {
473 rng ^= rng << 13;
474 rng ^= rng >> 7;
475 rng ^= rng << 17;
476 let j = (rng % (i as u64 + 1)) as usize;
477 if j < max_samples {
478 reservoir[j] = i;
479 }
480 }
481 self.rng_state = rng;
482 reservoir.sort_unstable();
483 let interval = interval.max(1);
484 for (i, &idx) in reservoir.iter().enumerate() {
485 self.train_one(&samples[idx]);
486 if (i + 1) % interval == 0 {
487 callback(i + 1);
488 }
489 }
490 let total = reservoir.len();
491 if total % interval != 0 {
492 callback(total);
493 }
494 }
495
496 pub fn predict(&self, features: &[f64]) -> f64 {
505 let mut pred = self.base_prediction;
506 if self.auto_bandwidths.is_empty() {
507 for step in &self.steps {
508 pred += self.config.learning_rate * step.predict(features);
509 }
510 } else {
511 for step in &self.steps {
512 pred += self.config.learning_rate
513 * step.predict_smooth_auto(features, &self.auto_bandwidths);
514 }
515 }
516 pred
517 }
518
519 pub fn predict_smooth(&self, features: &[f64], bandwidth: f64) -> f64 {
521 let mut pred = self.base_prediction;
522 for step in &self.steps {
523 pred += self.config.learning_rate * step.predict_smooth(features, bandwidth);
524 }
525 pred
526 }
527
528 pub fn auto_bandwidths(&self) -> &[f64] {
530 &self.auto_bandwidths
531 }
532
533 pub fn predict_interpolated(&self, features: &[f64]) -> f64 {
535 let mut pred = self.base_prediction;
536 for step in &self.steps {
537 pred += self.config.learning_rate * step.predict_interpolated(features);
538 }
539 pred
540 }
541
542 pub fn predict_sibling_interpolated(&self, features: &[f64]) -> f64 {
544 let mut pred = self.base_prediction;
545 for step in &self.steps {
546 pred += self.config.learning_rate
547 * step.predict_sibling_interpolated(features, &self.auto_bandwidths);
548 }
549 pred
550 }
551
552 pub fn predict_graduated(&self, features: &[f64]) -> f64 {
554 let mut pred = self.base_prediction;
555 for step in &self.steps {
556 pred += self.config.learning_rate * step.predict_graduated(features);
557 }
558 pred
559 }
560
561 pub fn predict_graduated_sibling_interpolated(&self, features: &[f64]) -> f64 {
563 let mut pred = self.base_prediction;
564 for step in &self.steps {
565 pred += self.config.learning_rate
566 * step.predict_graduated_sibling_interpolated(features, &self.auto_bandwidths);
567 }
568 pred
569 }
570
571 pub fn predict_transformed(&self, features: &[f64]) -> f64 {
573 self.loss.predict_transform(self.predict(features))
574 }
575
    /// Alias for [`predict_transformed`], named for classification use.
    /// The output is a probability only if the configured loss's
    /// `predict_transform` maps scores into [0, 1] — depends on `L`.
    pub fn predict_proba(&self, features: &[f64]) -> f64 {
        self.predict_transformed(features)
    }
580
581 pub fn predict_with_confidence(&self, features: &[f64]) -> (f64, f64) {
585 let mut pred = self.base_prediction;
586 let mut total_variance = 0.0;
587 let lr2 = self.config.learning_rate * self.config.learning_rate;
588 for step in &self.steps {
589 let (value, variance) = step.predict_with_variance(features);
590 pred += self.config.learning_rate * value;
591 total_variance += lr2 * variance;
592 }
593 let confidence = if total_variance > 0.0 && total_variance.is_finite() {
594 1.0 / crate::math::sqrt(total_variance)
595 } else {
596 0.0
597 };
598 (pred, confidence)
599 }
600
601 pub fn predict_batch(&self, feature_matrix: &[Vec<f64>]) -> Vec<f64> {
603 feature_matrix.iter().map(|f| self.predict(f)).collect()
604 }
605
    /// Number of boosting steps in the ensemble.
    pub fn n_steps(&self) -> usize {
        self.steps.len()
    }

    /// Total live trees: one per step, plus any alternate (shadow) trees.
    pub fn n_trees(&self) -> usize {
        self.steps.len() + self.steps.iter().filter(|s| s.has_alternate()).count()
    }

    /// Sum of leaf counts across all steps.
    pub fn total_leaves(&self) -> usize {
        self.steps.iter().map(|s| s.n_leaves()).sum()
    }

    /// Samples passed to `train_one` so far (non-finite rejects included).
    pub fn n_samples_seen(&self) -> u64 {
        self.samples_seen
    }

    /// Constant base score added to every prediction.
    pub fn base_prediction(&self) -> f64 {
        self.base_prediction
    }

    /// Whether the base prediction has been computed from warm-up targets.
    pub fn is_initialized(&self) -> bool {
        self.base_initialized
    }

    /// Read access to the full configuration.
    pub fn config(&self) -> &SGBTConfig {
        &self.config
    }

    /// Overrides the learning rate used by subsequent training/prediction.
    #[inline]
    pub fn set_learning_rate(&mut self, lr: f64) {
        self.config.learning_rate = lr;
    }

    /// Read access to the boosting steps.
    pub fn steps(&self) -> &[BoostingStep] {
        &self.steps
    }

    /// Read access to the loss function.
    pub fn loss(&self) -> &L {
        &self.loss
    }
660
661 pub fn feature_importances(&self) -> Vec<f64> {
665 let mut totals: Vec<f64> = Vec::new();
666 for step in &self.steps {
667 let gains = step.slot().split_gains();
668 if totals.is_empty() && !gains.is_empty() {
669 totals.resize(gains.len(), 0.0);
670 }
671 for (i, &g) in gains.iter().enumerate() {
672 if i < totals.len() {
673 totals[i] += g;
674 }
675 }
676 }
677 let sum: f64 = totals.iter().sum();
678 if sum > 0.0 {
679 totals.iter_mut().for_each(|v| *v /= sum);
680 }
681 totals
682 }
683
684 pub fn feature_names(&self) -> Option<&[String]> {
686 self.config.feature_names.as_deref()
687 }
688
689 pub fn named_feature_importances(&self) -> Option<Vec<(String, f64)>> {
694 let names = self.config.feature_names.as_ref()?;
695 let importances = self.feature_importances();
696 let mut pairs: Vec<(String, f64)> = names
697 .iter()
698 .zip(importances.iter().chain(core::iter::repeat(&0.0)))
699 .map(|(n, &v)| (n.clone(), v))
700 .collect();
701 pairs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(core::cmp::Ordering::Equal));
702 Some(pairs)
703 }
704
705 #[cfg(feature = "std")]
707 pub fn train_one_named(
708 &mut self,
709 features: &std::collections::HashMap<alloc::string::String, f64>,
710 target: f64,
711 ) {
712 let names = self
713 .config
714 .feature_names
715 .as_ref()
716 .expect("train_one_named requires feature_names to be configured");
717 let vec: Vec<f64> = names
718 .iter()
719 .map(|name| features.get(name).copied().unwrap_or(0.0))
720 .collect();
721 self.train_one(&(&vec[..], target));
722 }
723
724 #[cfg(feature = "std")]
726 pub fn predict_named(
727 &self,
728 features: &std::collections::HashMap<alloc::string::String, f64>,
729 ) -> f64 {
730 let names = self
731 .config
732 .feature_names
733 .as_ref()
734 .expect("predict_named requires feature_names to be configured");
735 let vec: Vec<f64> = names
736 .iter()
737 .map(|name| features.get(name).copied().unwrap_or(0.0))
738 .collect();
739 self.predict(&vec)
740 }
741
742 pub fn reset(&mut self) {
748 for step in &mut self.steps {
749 step.reset();
750 }
751 self.base_prediction = 0.0;
752 self.base_initialized = false;
753 self.initial_targets.clear();
754 self.samples_seen = 0;
755 self.rng_state = self.config.seed;
756 self.rolling_mean_error = 0.0;
757 self.rolling_contribution_sigma = 0.0;
758 self.auto_bandwidths.clear();
759 self.last_replacement_sum = 0;
760 self.sigma_ring.clear();
761 self.mts_replacement_sum = 0;
762 self.diag = DiagnosticCache {
763 contribution_accuracy: vec![0.0; self.steps.len()],
764 ..Default::default()
765 };
766 if !self.contribution_ewma.is_empty() {
767 self.contribution_ewma.iter_mut().for_each(|v| *v = 0.0);
768 }
769 if !self.low_contrib_count.is_empty() {
770 self.low_contrib_count.iter_mut().for_each(|v| *v = 0);
771 }
772 }
773
774 fn contribution_variance(&self, features: &[f64]) -> f64 {
780 let n = self.steps.len();
781 if n <= 1 {
782 return 0.0;
783 }
784 let lr = self.config.learning_rate;
785 let mut sum = 0.0;
786 let mut sq_sum = 0.0;
787 for step in &self.steps {
788 let c = lr * step.predict(features);
789 sum += c;
790 sq_sum += c * c;
791 }
792 let n_f = n as f64;
793 let mean = sum / n_f;
794 let var = (sq_sum / n_f) - (mean * mean);
795 crate::math::sqrt((var.abs() * n_f / (n_f - 1.0)).max(0.0))
796 }
797
798 fn refresh_bandwidths(&mut self) {
800 let current_sum: u64 = self.steps.iter().map(|s| s.slot().replacements()).sum();
801 if current_sum != self.last_replacement_sum || self.auto_bandwidths.is_empty() {
802 self.auto_bandwidths = self.compute_auto_bandwidths();
803 self.last_replacement_sum = current_sum;
804 }
805 }
806
    /// Derives a smoothing bandwidth per feature from the split thresholds of
    /// all active trees: K times the median gap between adjacent distinct
    /// thresholds. Features with too few thresholds to estimate a gap get
    /// `f64::INFINITY` — presumably treated downstream as "no smoothing";
    /// TODO confirm against `predict_smooth_auto`.
    fn compute_auto_bandwidths(&self) -> Vec<f64> {
        // Bandwidth multiplier applied to the median threshold gap.
        const K: f64 = 2.0;
        // Widest feature count reported by any active tree.
        let n_features = self
            .steps
            .iter()
            .filter_map(|s| s.slot().active_tree().n_features())
            .max()
            .unwrap_or(0);

        if n_features == 0 {
            return Vec::new();
        }

        // Pool all split thresholds per feature across every step's tree.
        let mut all_thresholds: Vec<Vec<f64>> = vec![Vec::new(); n_features];
        for step in &self.steps {
            let tree_thresholds = step
                .slot()
                .active_tree()
                .collect_split_thresholds_per_feature();
            for (i, ts) in tree_thresholds.into_iter().enumerate() {
                if i < n_features {
                    all_thresholds[i].extend(ts);
                }
            }
        }

        let n_bins = self.config.n_bins as f64;

        all_thresholds
            .iter()
            .map(|ts| {
                if ts.is_empty() {
                    return f64::INFINITY;
                }
                // Sort and collapse near-duplicate thresholds before gap
                // estimation.
                let mut sorted = ts.clone();
                sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
                sorted.dedup_by(|a, b| crate::math::abs(*a - *b) < 1e-15);

                if sorted.len() < 2 {
                    return f64::INFINITY;
                }

                let mut gaps: Vec<f64> = sorted.windows(2).map(|w| w[1] - w[0]).collect();

                // With only two distinct thresholds a median gap is
                // meaningless; fall back to range / n_bins.
                if sorted.len() < 3 {
                    let range = sorted.last().unwrap() - sorted.first().unwrap();
                    if range < 1e-15 {
                        return f64::INFINITY;
                    }
                    return (range / n_bins) * K;
                }

                gaps.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
                let median_gap = if gaps.len() % 2 == 0 {
                    (gaps[gaps.len() / 2 - 1] + gaps[gaps.len() / 2]) / 2.0
                } else {
                    gaps[gaps.len() / 2]
                };

                if median_gap < 1e-15 {
                    f64::INFINITY
                } else {
                    median_gap * K
                }
            })
            .collect()
    }
875}