1use alloc::vec;
31use alloc::vec::Vec;
32
33use crate::ensemble::config::{SGBTConfig, ScaleMode};
34use crate::ensemble::step::BoostingStep;
35use crate::sample::{Observation, SampleRef};
36use crate::tree::builder::TreeConfig;
37
/// Flat, byte-serialized snapshot of the location ensemble used for fast
/// inference; `predict` falls back to the live trees if decoding fails or
/// the packed output is non-finite.
struct PackedInferenceCache {
    // Serialized ensemble, decoded via `crate::EnsembleView::from_bytes`.
    bytes: Vec<u8>,
    // Location base value added to the packed ensemble's output.
    base: f64,
    // Feature count at pack time. NOTE(review): not read anywhere in this
    // chunk — presumably validation metadata; confirm against the packer.
    n_features: usize,
}
47
// Field-wise clone. NOTE(review): identical to what `#[derive(Clone)]`
// would generate; the manual impl could be replaced with the derive.
impl Clone for PackedInferenceCache {
    fn clone(&self) -> Self {
        Self {
            bytes: self.bytes.clone(),
            base: self.base,
            n_features: self.n_features,
        }
    }
}
57
/// A Gaussian predictive distribution: mean `mu`, standard deviation
/// `sigma`, and `log_sigma = ln(sigma)` carried alongside to avoid
/// recomputation by callers.
#[derive(Debug, Clone, Copy)]
pub struct GaussianPrediction {
    /// Predicted mean.
    pub mu: f64,
    /// Predicted standard deviation (floored at 1e-8 where it is produced).
    pub sigma: f64,
    /// Natural log of `sigma`.
    pub log_sigma: f64,
}
68
69impl GaussianPrediction {
70 #[inline]
74 pub fn lower(&self, z: f64) -> f64 {
75 self.mu - z * self.sigma
76 }
77
78 #[inline]
80 pub fn upper(&self, z: f64) -> f64 {
81 self.mu + z * self.sigma
82 }
83}
84
/// Diagnostic snapshot of a single tree slot (see
/// `DistributionalSGBT::diagnostics`).
#[derive(Debug, Clone)]
pub struct TreeDiagnostic {
    /// Number of leaves in the active tree.
    pub n_leaves: usize,
    /// Deepest leaf depth present in the tree.
    pub max_depth_reached: usize,
    /// Samples this boosting step has observed.
    pub samples_seen: u64,
    /// Leaf value statistics as `(min, max, mean, std)`.
    pub leaf_weight_stats: (f64, f64, f64, f64),
    /// Indices of features with positive accumulated split gain.
    pub split_features: Vec<usize>,
    /// Per-leaf sample counts, in leaf-arena order.
    pub leaf_sample_counts: Vec<u64>,
    /// Mean of this slot's predictions, as reported by the slot.
    pub prediction_mean: f64,
    /// Standard deviation of this slot's predictions, as reported by the slot.
    pub prediction_std: f64,
}
109
/// Aggregate diagnostics for the whole model (see
/// `DistributionalSGBT::diagnostics`).
#[derive(Debug, Clone)]
pub struct ModelDiagnostics {
    /// All tree diagnostics: location trees first, then scale trees.
    pub trees: Vec<TreeDiagnostic>,
    /// Diagnostics for the location chain only.
    pub location_trees: Vec<TreeDiagnostic>,
    /// Diagnostics for the scale chain only.
    pub scale_trees: Vec<TreeDiagnostic>,
    /// Per feature: number of trees with positive split gain on it.
    pub feature_split_counts: Vec<usize>,
    /// Base location value (mean of the warmup targets).
    pub location_base: f64,
    /// Base log-scale value (ln of the warmup standard deviation).
    pub scale_base: f64,
    /// sqrt of the EWMA of squared residuals.
    pub empirical_sigma: f64,
    /// Scale estimation mode in use.
    pub scale_mode: ScaleMode,
    /// Scale trees with more than one leaf (i.e. that have split at least once).
    pub scale_trees_active: usize,
    /// Current per-feature smoothing bandwidths (empty if not computed).
    pub auto_bandwidths: Vec<f64>,
    /// Running mean of location gradients (Welford accumulator).
    pub ensemble_grad_mean: f64,
    /// Running standard deviation of location gradients (Welford accumulator).
    pub ensemble_grad_std: f64,
}
142
/// A prediction broken into base values plus one contribution per boosting
/// step, for attribution/inspection (see `predict_decomposed`).
#[derive(Debug, Clone)]
pub struct DecomposedPrediction {
    /// Base location value.
    pub location_base: f64,
    /// Base log-scale value.
    pub scale_base: f64,
    /// Per-step location contributions (learning rate already applied).
    pub location_contributions: Vec<f64>,
    /// Per-step log-scale contributions (all zeros in `Empirical` mode).
    pub scale_contributions: Vec<f64>,
}
157
158impl DecomposedPrediction {
159 pub fn mu(&self) -> f64 {
161 self.location_base + self.location_contributions.iter().sum::<f64>()
162 }
163
164 pub fn log_sigma(&self) -> f64 {
166 self.scale_base + self.scale_contributions.iter().sum::<f64>()
167 }
168
169 pub fn sigma(&self) -> f64 {
171 crate::math::exp(self.log_sigma()).max(1e-8)
172 }
173}
174
/// Streaming distributional gradient-boosted trees predicting a Gaussian
/// `N(mu, sigma^2)` per sample. The location (`mu`) is always tree-based;
/// the scale is either tree-based (`ScaleMode::TreeChain`) or an EWMA of
/// squared residuals (`ScaleMode::Empirical`).
pub struct DistributionalSGBT {
    // Full configuration, retained for learning rate / variant lookups.
    config: SGBTConfig,
    // One boosting step per round for the location chain.
    location_steps: Vec<BoostingStep>,
    // One boosting step per round for the scale chain (only trained in
    // TreeChain mode).
    scale_steps: Vec<BoostingStep>,
    // Base location prediction: mean of the warmup targets.
    location_base: f64,
    // Base log-scale: ln(std) of the warmup targets.
    scale_base: f64,
    // True once the warmup buffer has been consumed and bases are set.
    base_initialized: bool,
    // Warmup target buffer; cleared and shrunk after initialization.
    initial_targets: Vec<f64>,
    // Number of warmup targets to collect before initializing bases.
    initial_target_count: usize,
    // Total samples seen, warmup included.
    samples_seen: u64,
    // PRNG state used by `config.variant.train_count`.
    rng_state: u64,
    // When true, the location learning rate is scaled by sigma dynamics.
    uncertainty_modulated_lr: bool,
    // Slow EWMA of sigma used as the modulation reference level.
    rolling_sigma_mean: f64,
    // How sigma is estimated (tree chain vs empirical EWMA).
    scale_mode: ScaleMode,
    // EWMA of squared residuals (drives the empirical sigma).
    ewma_sq_err: f64,
    // EWMA coefficient for `ewma_sq_err` and `sigma_velocity`.
    empirical_sigma_alpha: f64,
    // Sigma from the previous training step (for velocity estimation).
    prev_sigma: f64,
    // EWMA of sigma's step-to-step change (derivative term).
    sigma_velocity: f64,
    // Per-feature kernel bandwidths derived from split thresholds.
    auto_bandwidths: Vec<f64>,
    // Sum of slot replacement counters at the last bandwidth refresh.
    last_replacement_sum: u64,
    // Welford running mean of location gradients.
    ensemble_grad_mean: f64,
    // Welford running M2 (sum of squared deviations) of gradients.
    ensemble_grad_m2: f64,
    // Number of gradients folded into the Welford stats.
    ensemble_grad_count: u64,
    // Optional packed snapshot of the location ensemble for fast `predict`.
    packed_cache: Option<PackedInferenceCache>,
    // Samples trained since the packed cache was last rebuilt.
    samples_since_refresh: u64,
    // Rebuild the packed cache every this many samples (0 = disabled).
    packed_refresh_interval: u64,
}
259
// Field-wise clone of the full model state, including the packed cache.
// NOTE(review): identical to what `#[derive(Clone)]` would generate; kept
// manual — presumably to mirror the hand-written Debug impl; confirm.
impl Clone for DistributionalSGBT {
    fn clone(&self) -> Self {
        Self {
            config: self.config.clone(),
            location_steps: self.location_steps.clone(),
            scale_steps: self.scale_steps.clone(),
            location_base: self.location_base,
            scale_base: self.scale_base,
            base_initialized: self.base_initialized,
            initial_targets: self.initial_targets.clone(),
            initial_target_count: self.initial_target_count,
            samples_seen: self.samples_seen,
            rng_state: self.rng_state,
            uncertainty_modulated_lr: self.uncertainty_modulated_lr,
            rolling_sigma_mean: self.rolling_sigma_mean,
            scale_mode: self.scale_mode,
            ewma_sq_err: self.ewma_sq_err,
            empirical_sigma_alpha: self.empirical_sigma_alpha,
            prev_sigma: self.prev_sigma,
            sigma_velocity: self.sigma_velocity,
            auto_bandwidths: self.auto_bandwidths.clone(),
            last_replacement_sum: self.last_replacement_sum,
            ensemble_grad_mean: self.ensemble_grad_mean,
            ensemble_grad_m2: self.ensemble_grad_m2,
            ensemble_grad_count: self.ensemble_grad_count,
            packed_cache: self.packed_cache.clone(),
            samples_since_refresh: self.samples_since_refresh,
            packed_refresh_interval: self.packed_refresh_interval,
        }
    }
}
291
292impl core::fmt::Debug for DistributionalSGBT {
293 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
294 let mut s = f.debug_struct("DistributionalSGBT");
295 s.field("n_steps", &self.location_steps.len())
296 .field("samples_seen", &self.samples_seen)
297 .field("location_base", &self.location_base)
298 .field("scale_mode", &self.scale_mode)
299 .field("base_initialized", &self.base_initialized);
300 match self.scale_mode {
301 ScaleMode::Empirical => {
302 s.field("empirical_sigma", &crate::math::sqrt(self.ewma_sq_err));
303 }
304 ScaleMode::TreeChain => {
305 s.field("scale_base", &self.scale_base);
306 }
307 }
308 if self.uncertainty_modulated_lr {
309 s.field("rolling_sigma_mean", &self.rolling_sigma_mean);
310 }
311 s.finish()
312 }
313}
314
315impl DistributionalSGBT {
316 pub fn new(config: SGBTConfig) -> Self {
322 let leaf_decay_alpha = config
323 .leaf_half_life
324 .map(|hl| crate::math::exp(-(crate::math::ln(2.0_f64)) / hl as f64));
325
326 let tree_config = TreeConfig::new()
327 .max_depth(config.max_depth)
328 .n_bins(config.n_bins)
329 .lambda(config.lambda)
330 .gamma(config.gamma)
331 .grace_period(config.grace_period)
332 .delta(config.delta)
333 .feature_subsample_rate(config.feature_subsample_rate)
334 .leaf_decay_alpha_opt(leaf_decay_alpha)
335 .split_reeval_interval_opt(config.split_reeval_interval)
336 .feature_types_opt(config.feature_types.clone())
337 .gradient_clip_sigma_opt(config.gradient_clip_sigma)
338 .monotone_constraints_opt(config.monotone_constraints.clone())
339 .max_leaf_output_opt(config.max_leaf_output)
340 .min_hessian_sum_opt(config.min_hessian_sum)
341 .leaf_model_type(config.leaf_model_type.clone());
342
343 let max_tree_samples = config.max_tree_samples;
344
345 let shadow_warmup = config.shadow_warmup.unwrap_or(0);
347 let location_steps: Vec<BoostingStep> = (0..config.n_steps)
348 .map(|i| {
349 let mut tc = tree_config.clone();
350 tc.seed = config.seed ^ (i as u64);
351 let detector = config.drift_detector.create();
352 if shadow_warmup > 0 {
353 BoostingStep::new_with_graduated(tc, detector, max_tree_samples, shadow_warmup)
354 } else {
355 BoostingStep::new_with_max_samples(tc, detector, max_tree_samples)
356 }
357 })
358 .collect();
359
360 let scale_steps: Vec<BoostingStep> = (0..config.n_steps)
362 .map(|i| {
363 let mut tc = tree_config.clone();
364 tc.seed = config.seed ^ (i as u64) ^ 0x0005_CA1E_0000_0000;
365 let detector = config.drift_detector.create();
366 if shadow_warmup > 0 {
367 BoostingStep::new_with_graduated(tc, detector, max_tree_samples, shadow_warmup)
368 } else {
369 BoostingStep::new_with_max_samples(tc, detector, max_tree_samples)
370 }
371 })
372 .collect();
373
374 let seed = config.seed;
375 let initial_target_count = config.initial_target_count;
376 let uncertainty_modulated_lr = config.uncertainty_modulated_lr;
377 let scale_mode = config.scale_mode;
378 let empirical_sigma_alpha = config.empirical_sigma_alpha;
379 let packed_refresh_interval = config.packed_refresh_interval;
380 Self {
381 config,
382 location_steps,
383 scale_steps,
384 location_base: 0.0,
385 scale_base: 0.0,
386 base_initialized: false,
387 initial_targets: Vec::new(),
388 initial_target_count,
389 samples_seen: 0,
390 rng_state: seed,
391 uncertainty_modulated_lr,
392 rolling_sigma_mean: 1.0, scale_mode,
394 ewma_sq_err: 1.0, empirical_sigma_alpha,
396 prev_sigma: 0.0,
397 sigma_velocity: 0.0,
398 auto_bandwidths: Vec::new(),
399 last_replacement_sum: 0,
400 ensemble_grad_mean: 0.0,
401 ensemble_grad_m2: 0.0,
402 ensemble_grad_count: 0,
403 packed_cache: None,
404 samples_since_refresh: 0,
405 packed_refresh_interval,
406 }
407 }
408
    /// Trains on a single observation.
    ///
    /// The first `initial_target_count` targets are buffered to initialize
    /// the location base (their mean) and scale base (ln of their std); no
    /// trees are trained during this warmup. After warmup, each sample is
    /// routed to the mode-specific training path, and smoothing bandwidths
    /// are refreshed if any tree was replaced.
    pub fn train_one(&mut self, sample: &impl Observation) {
        self.samples_seen += 1;
        let target = sample.target();
        let features = sample.features();

        if !self.base_initialized {
            // Warmup: buffer targets until base mu/sigma can be estimated.
            self.initial_targets.push(target);
            if self.initial_targets.len() >= self.initial_target_count {
                let sum: f64 = self.initial_targets.iter().sum();
                let mean = sum / self.initial_targets.len() as f64;
                self.location_base = mean;

                // Population variance of the warmup targets.
                let var: f64 = self
                    .initial_targets
                    .iter()
                    .map(|&y| (y - mean) * (y - mean))
                    .sum::<f64>()
                    / self.initial_targets.len() as f64;
                let initial_std = crate::math::sqrt(var).max(1e-6);
                self.scale_base = crate::math::ln(initial_std);

                // Seed the sigma trackers with the warmup estimate.
                self.rolling_sigma_mean = initial_std;
                self.ewma_sq_err = var.max(1e-12);

                self.prev_sigma = initial_std;
                self.sigma_velocity = 0.0;

                self.base_initialized = true;
                // The warmup buffer is never needed again; release its memory.
                self.initial_targets.clear();
                self.initial_targets.shrink_to_fit();
            }
            return;
        }

        match self.scale_mode {
            ScaleMode::Empirical => self.train_one_empirical(target, features),
            ScaleMode::TreeChain => self.train_one_tree_chain(target, features),
        }

        // Keep smoothing bandwidths in sync with any tree replacements.
        self.refresh_bandwidths();
    }
457
    /// One training step in `ScaleMode::Empirical`.
    ///
    /// Sigma is not learned by trees here: a prediction is made first, the
    /// squared residual updates `ewma_sq_err`, and (when enabled) a
    /// PD-style controller on sigma produces a learning-rate ratio. The
    /// location chain is then trained stage-wise on the location gradient.
    fn train_one_empirical(&mut self, target: f64, features: &[f64]) {
        // Current ensemble prediction, used only to form the residual.
        let mut mu = self.location_base;
        for s in 0..self.location_steps.len() {
            mu += self.config.learning_rate * self.location_steps[s].predict(features);
        }

        // Fold the squared residual into the empirical sigma estimate.
        let err = target - mu;
        let alpha = self.empirical_sigma_alpha;
        self.ewma_sq_err = (1.0 - alpha) * self.ewma_sq_err + alpha * err * err;
        let empirical_sigma = crate::math::sqrt(self.ewma_sq_err).max(1e-8);

        let sigma_ratio = if self.uncertainty_modulated_lr {
            // Modulation reacts to both sigma's level and its rate of
            // change, normalized by the slow rolling mean.
            let d_sigma = empirical_sigma - self.prev_sigma;
            self.prev_sigma = empirical_sigma;

            self.sigma_velocity = (1.0 - alpha) * self.sigma_velocity + alpha * d_sigma;

            let k_d = if self.rolling_sigma_mean > 1e-12 {
                crate::math::abs(self.sigma_velocity) / self.rolling_sigma_mean
            } else {
                0.0
            };

            let pd_sigma = empirical_sigma + k_d * self.sigma_velocity;
            let ratio = (pd_sigma / self.rolling_sigma_mean).clamp(0.1, 10.0);

            // Very slow EWMA so the reference level adapts over long
            // horizons rather than tracking short-term noise.
            const SIGMA_EWMA_ALPHA: f64 = 0.001;
            self.rolling_sigma_mean = (1.0 - SIGMA_EWMA_ALPHA) * self.rolling_sigma_mean
                + SIGMA_EWMA_ALPHA * empirical_sigma;

            ratio
        } else {
            1.0
        };

        let base_lr = self.config.learning_rate;

        // Stage-wise training: each step sees the gradient at the partial
        // prediction accumulated from the steps before it.
        let mut mu_accum = self.location_base;
        for s in 0..self.location_steps.len() {
            let (g_mu, h_mu) = self.location_gradient(mu_accum, target);
            self.update_ensemble_grad_stats(g_mu);
            let train_count = self.config.variant.train_count(h_mu, &mut self.rng_state);
            let loc_pred =
                self.location_steps[s].train_and_predict(features, g_mu, h_mu, train_count);
            mu_accum += (base_lr * sigma_ratio) * loc_pred;
        }

        self.maybe_refresh_packed_cache();
    }
519
    /// One training step in `ScaleMode::TreeChain`.
    ///
    /// Both chains are trained stage-wise: the location uses `(mu - y)`
    /// (optionally Huberized), and the log-scale uses `1 - z^2` with a
    /// clamped hessian, where `z` is the standardized residual at each
    /// stage. `ewma_sq_err` is still updated so empirical sigma remains
    /// available for diagnostics and Huber clipping.
    fn train_one_tree_chain(&mut self, target: f64, features: &[f64]) {
        let mut mu = self.location_base;
        let mut log_sigma = self.scale_base;

        let sigma_ratio = if self.uncertainty_modulated_lr {
            // Same PD-style controller as the empirical path, but driven by
            // the base sigma before this sample's tree updates.
            let current_sigma = crate::math::exp(log_sigma).max(1e-8);

            let d_sigma = current_sigma - self.prev_sigma;
            self.prev_sigma = current_sigma;

            let alpha = self.empirical_sigma_alpha;
            self.sigma_velocity = (1.0 - alpha) * self.sigma_velocity + alpha * d_sigma;

            let k_d = if self.rolling_sigma_mean > 1e-12 {
                crate::math::abs(self.sigma_velocity) / self.rolling_sigma_mean
            } else {
                0.0
            };

            let pd_sigma = current_sigma + k_d * self.sigma_velocity;
            let ratio = (pd_sigma / self.rolling_sigma_mean).clamp(0.1, 10.0);

            // Very slow EWMA for the reference level.
            const SIGMA_EWMA_ALPHA: f64 = 0.001;
            self.rolling_sigma_mean = (1.0 - SIGMA_EWMA_ALPHA) * self.rolling_sigma_mean
                + SIGMA_EWMA_ALPHA * current_sigma;

            ratio
        } else {
            1.0
        };

        let base_lr = self.config.learning_rate;

        for s in 0..self.location_steps.len() {
            // Standardized residual at the current stage.
            let sigma = crate::math::exp(log_sigma).max(1e-8);
            let z = (target - mu) / sigma;

            let (g_mu, h_mu) = self.location_gradient(mu, target);
            self.update_ensemble_grad_stats(g_mu);

            // Log-scale gradient (matches d/d(ln sigma) of the Gaussian
            // NLL); hessian clamped for stability at extreme z.
            let g_sigma = 1.0 - z * z;
            let h_sigma = (2.0 * z * z).clamp(0.01, 100.0);

            let train_count = self.config.variant.train_count(h_mu, &mut self.rng_state);

            let loc_pred =
                self.location_steps[s].train_and_predict(features, g_mu, h_mu, train_count);
            mu += (base_lr * sigma_ratio) * loc_pred;

            let scale_pred =
                self.scale_steps[s].train_and_predict(features, g_sigma, h_sigma, train_count);
            log_sigma += base_lr * scale_pred;
        }

        // Keep the empirical residual EWMA current even in tree-chain mode.
        let err = target - mu;
        let alpha = self.empirical_sigma_alpha;
        self.ewma_sq_err = (1.0 - alpha) * self.ewma_sq_err + alpha * err * err;

        self.maybe_refresh_packed_cache();
    }
594
    /// Predicts a Gaussian distribution for `features`.
    ///
    /// The location uses the packed inference cache when present, falling
    /// back to the live trees on decode failure or a non-finite packed
    /// output. The scale comes from the empirical EWMA or the scale tree
    /// chain per `scale_mode`, using smoothed lookups when auto bandwidths
    /// are available.
    pub fn predict(&self, features: &[f64]) -> GaussianPrediction {
        let mu = if let Some(ref cache) = self.packed_cache {
            // The packed ensemble operates on f32 features.
            let features_f32: Vec<f32> = features.iter().map(|&v| v as f32).collect();
            match crate::EnsembleView::from_bytes(&cache.bytes) {
                Ok(view) => {
                    let packed_mu = cache.base + view.predict(&features_f32) as f64;
                    if packed_mu.is_finite() {
                        packed_mu
                    } else {
                        // Guard against corrupt cache output.
                        self.predict_full_trees(features)
                    }
                }
                Err(_) => self.predict_full_trees(features),
            }
        } else {
            self.predict_full_trees(features)
        };

        let (sigma, log_sigma) = match self.scale_mode {
            ScaleMode::Empirical => {
                let s = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
                (s, crate::math::ln(s))
            }
            ScaleMode::TreeChain => {
                let mut ls = self.scale_base;
                if self.auto_bandwidths.is_empty() {
                    for s in 0..self.scale_steps.len() {
                        ls += self.config.learning_rate * self.scale_steps[s].predict(features);
                    }
                } else {
                    for s in 0..self.scale_steps.len() {
                        ls += self.config.learning_rate
                            * self.scale_steps[s]
                                .predict_smooth_auto(features, &self.auto_bandwidths);
                    }
                }
                (crate::math::exp(ls).max(1e-8), ls)
            }
        };

        GaussianPrediction {
            mu,
            sigma,
            log_sigma,
        }
    }
650
651 fn predict_full_trees(&self, features: &[f64]) -> f64 {
653 let mut mu = self.location_base;
654 if self.auto_bandwidths.is_empty() {
655 for s in 0..self.location_steps.len() {
656 mu += self.config.learning_rate * self.location_steps[s].predict(features);
657 }
658 } else {
659 for s in 0..self.location_steps.len() {
660 mu += self.config.learning_rate
661 * self.location_steps[s].predict_smooth_auto(features, &self.auto_bandwidths);
662 }
663 }
664 mu
665 }
666
667 pub fn predict_smooth(&self, features: &[f64], bandwidth: f64) -> GaussianPrediction {
677 let mut mu = self.location_base;
678 for s in 0..self.location_steps.len() {
679 mu += self.config.learning_rate
680 * self.location_steps[s].predict_smooth(features, bandwidth);
681 }
682
683 let (sigma, log_sigma) = match self.scale_mode {
684 ScaleMode::Empirical => {
685 let s = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
686 (s, crate::math::ln(s))
687 }
688 ScaleMode::TreeChain => {
689 let mut ls = self.scale_base;
690 for s in 0..self.scale_steps.len() {
691 ls += self.config.learning_rate
692 * self.scale_steps[s].predict_smooth(features, bandwidth);
693 }
694 (crate::math::exp(ls).max(1e-8), ls)
695 }
696 };
697
698 GaussianPrediction {
699 mu,
700 sigma,
701 log_sigma,
702 }
703 }
704
705 pub fn predict_interpolated(&self, features: &[f64]) -> GaussianPrediction {
710 let mut mu = self.location_base;
711 for s in 0..self.location_steps.len() {
712 mu += self.config.learning_rate * self.location_steps[s].predict_interpolated(features);
713 }
714
715 let (sigma, log_sigma) = match self.scale_mode {
716 ScaleMode::Empirical => {
717 let s = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
718 (s, crate::math::ln(s))
719 }
720 ScaleMode::TreeChain => {
721 let mut ls = self.scale_base;
722 for s in 0..self.scale_steps.len() {
723 ls += self.config.learning_rate
724 * self.scale_steps[s].predict_interpolated(features);
725 }
726 (crate::math::exp(ls).max(1e-8), ls)
727 }
728 };
729
730 GaussianPrediction {
731 mu,
732 sigma,
733 log_sigma,
734 }
735 }
736
737 pub fn predict_sibling_interpolated(&self, features: &[f64]) -> GaussianPrediction {
743 let mut mu = self.location_base;
744 for s in 0..self.location_steps.len() {
745 mu += self.config.learning_rate
746 * self.location_steps[s]
747 .predict_sibling_interpolated(features, &self.auto_bandwidths);
748 }
749
750 let (sigma, log_sigma) = match self.scale_mode {
751 ScaleMode::Empirical => {
752 let s = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
753 (s, crate::math::ln(s))
754 }
755 ScaleMode::TreeChain => {
756 let mut ls = self.scale_base;
757 for s in 0..self.scale_steps.len() {
758 ls += self.config.learning_rate
759 * self.scale_steps[s]
760 .predict_sibling_interpolated(features, &self.auto_bandwidths);
761 }
762 (crate::math::exp(ls).max(1e-8), ls)
763 }
764 };
765
766 GaussianPrediction {
767 mu,
768 sigma,
769 log_sigma,
770 }
771 }
772
773 pub fn predict_graduated(&self, features: &[f64]) -> GaussianPrediction {
778 let mut mu = self.location_base;
779 for s in 0..self.location_steps.len() {
780 mu += self.config.learning_rate * self.location_steps[s].predict_graduated(features);
781 }
782
783 let (sigma, log_sigma) = match self.scale_mode {
784 ScaleMode::Empirical => {
785 let s = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
786 (s, crate::math::ln(s))
787 }
788 ScaleMode::TreeChain => {
789 let mut ls = self.scale_base;
790 for s in 0..self.scale_steps.len() {
791 ls +=
792 self.config.learning_rate * self.scale_steps[s].predict_graduated(features);
793 }
794 (crate::math::exp(ls).max(1e-8), ls)
795 }
796 };
797
798 GaussianPrediction {
799 mu,
800 sigma,
801 log_sigma,
802 }
803 }
804
805 pub fn predict_graduated_sibling_interpolated(&self, features: &[f64]) -> GaussianPrediction {
807 let mut mu = self.location_base;
808 for s in 0..self.location_steps.len() {
809 mu += self.config.learning_rate
810 * self.location_steps[s]
811 .predict_graduated_sibling_interpolated(features, &self.auto_bandwidths);
812 }
813
814 let (sigma, log_sigma) = match self.scale_mode {
815 ScaleMode::Empirical => {
816 let s = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
817 (s, crate::math::ln(s))
818 }
819 ScaleMode::TreeChain => {
820 let mut ls = self.scale_base;
821 for s in 0..self.scale_steps.len() {
822 ls += self.config.learning_rate
823 * self.scale_steps[s].predict_graduated_sibling_interpolated(
824 features,
825 &self.auto_bandwidths,
826 );
827 }
828 (crate::math::exp(ls).max(1e-8), ls)
829 }
830 };
831
832 GaussianPrediction {
833 mu,
834 sigma,
835 log_sigma,
836 }
837 }
838
839 pub fn predict_distributional(&self, features: &[f64]) -> (f64, f64, f64) {
848 let pred = self.predict(features);
849 let sigma_ratio = if self.uncertainty_modulated_lr {
850 (pred.sigma / self.rolling_sigma_mean).clamp(0.1, 10.0)
851 } else {
852 1.0
853 };
854 (pred.mu, pred.sigma, sigma_ratio)
855 }
856
    /// Current empirical sigma: sqrt of the EWMA of squared residuals.
    #[inline]
    pub fn empirical_sigma(&self) -> f64 {
        crate::math::sqrt(self.ewma_sq_err)
    }

    /// The configured scale estimation mode.
    #[inline]
    pub fn scale_mode(&self) -> ScaleMode {
        self.scale_mode
    }

    /// EWMA of sigma's step-to-step change (positive when sigma is rising).
    #[inline]
    pub fn sigma_velocity(&self) -> f64 {
        self.sigma_velocity
    }

    /// Convenience accessor: the predicted mean only.
    #[inline]
    pub fn predict_mu(&self, features: &[f64]) -> f64 {
        self.predict(features).mu
    }

    /// Convenience accessor: the predicted standard deviation only.
    #[inline]
    pub fn predict_sigma(&self, features: &[f64]) -> f64 {
        self.predict(features).sigma
    }
892
    /// Symmetric interval `mu ± confidence · sigma`; `confidence` is a
    /// z-multiplier (e.g. 1.96 for ~95% under a Gaussian assumption).
    pub fn predict_interval(&self, features: &[f64], confidence: f64) -> (f64, f64) {
        let pred = self.predict(features);
        (pred.lower(confidence), pred.upper(confidence))
    }

    /// Predicts every row of `feature_matrix` independently.
    pub fn predict_batch(&self, feature_matrix: &[Vec<f64>]) -> Vec<GaussianPrediction> {
        feature_matrix.iter().map(|f| self.predict(f)).collect()
    }

    /// Trains on `samples` one at a time, in order.
    pub fn train_batch<O: Observation>(&mut self, samples: &[O]) {
        for sample in samples {
            self.train_one(sample);
        }
    }
915
916 pub fn train_batch_with_callback<O: Observation, F: FnMut(usize)>(
918 &mut self,
919 samples: &[O],
920 interval: usize,
921 mut callback: F,
922 ) {
923 let interval = interval.max(1);
924 for (i, sample) in samples.iter().enumerate() {
925 self.train_one(sample);
926 if (i + 1) % interval == 0 {
927 callback(i + 1);
928 }
929 }
930 let total = samples.len();
931 if total % interval != 0 {
932 callback(total);
933 }
934 }
935
936 #[inline]
941 fn location_gradient(&self, mu: f64, target: f64) -> (f64, f64) {
942 if let Some(k) = self.config.huber_k {
943 let delta = k * crate::math::sqrt(self.ewma_sq_err).max(1e-8);
944 let residual = mu - target;
945 if crate::math::abs(residual) <= delta {
946 (residual, 1.0)
947 } else {
948 (delta * residual.signum(), 1e-6)
949 }
950 } else {
951 (mu - target, 1.0)
952 }
953 }
954
955 #[inline]
957 fn update_ensemble_grad_stats(&mut self, gradient: f64) {
958 self.ensemble_grad_count += 1;
959 let delta = gradient - self.ensemble_grad_mean;
960 self.ensemble_grad_mean += delta / self.ensemble_grad_count as f64;
961 let delta2 = gradient - self.ensemble_grad_mean;
962 self.ensemble_grad_m2 += delta * delta2;
963 }
964
965 pub fn ensemble_grad_std(&self) -> f64 {
967 if self.ensemble_grad_count < 2 {
968 return 0.0;
969 }
970 crate::math::fmax(
971 crate::math::sqrt(self.ensemble_grad_m2 / (self.ensemble_grad_count - 1) as f64),
972 0.0,
973 )
974 }
975
    /// Running mean of observed location gradients (Welford accumulator).
    pub fn ensemble_grad_mean(&self) -> f64 {
        self.ensemble_grad_mean
    }
980
981 fn maybe_refresh_packed_cache(&mut self) {
983 if self.packed_refresh_interval > 0 {
984 self.samples_since_refresh += 1;
985 if self.samples_since_refresh >= self.packed_refresh_interval {
986 self.refresh_packed_cache();
987 self.samples_since_refresh = 0;
988 }
989 }
990 }
991
    /// Rebuilds the packed inference cache from the live trees.
    ///
    /// NOTE(review): the body is empty in this chunk, so this path never
    /// populates `packed_cache` — it appears unimplemented or stripped
    /// here; confirm against the full file before relying on the cache.
    fn refresh_packed_cache(&mut self) {
    }
1003
1004 pub fn enable_packed_cache(&mut self, interval: u64) {
1009 self.packed_refresh_interval = interval;
1010 self.samples_since_refresh = 0;
1011 if interval > 0 && self.base_initialized {
1012 self.refresh_packed_cache();
1013 } else if interval == 0 {
1014 self.packed_cache = None;
1015 }
1016 }
1017
    /// Whether a packed inference cache is currently available.
    #[inline]
    pub fn has_packed_cache(&self) -> bool {
        self.packed_cache.is_some()
    }
1023
1024 fn refresh_bandwidths(&mut self) {
1026 let current_sum: u64 = self
1027 .location_steps
1028 .iter()
1029 .chain(self.scale_steps.iter())
1030 .map(|s| s.slot().replacements())
1031 .sum();
1032 if current_sum != self.last_replacement_sum || self.auto_bandwidths.is_empty() {
1033 self.auto_bandwidths = self.compute_auto_bandwidths();
1034 self.last_replacement_sum = current_sum;
1035 }
1036 }
1037
1038 fn compute_auto_bandwidths(&self) -> Vec<f64> {
1040 const K: f64 = 2.0;
1041
1042 let n_features = self
1043 .location_steps
1044 .iter()
1045 .chain(self.scale_steps.iter())
1046 .filter_map(|s| s.slot().active_tree().n_features())
1047 .max()
1048 .unwrap_or(0);
1049
1050 if n_features == 0 {
1051 return Vec::new();
1052 }
1053
1054 let mut all_thresholds: Vec<Vec<f64>> = vec![Vec::new(); n_features];
1055
1056 for step in self.location_steps.iter().chain(self.scale_steps.iter()) {
1057 let tree_thresholds = step
1058 .slot()
1059 .active_tree()
1060 .collect_split_thresholds_per_feature();
1061 for (i, ts) in tree_thresholds.into_iter().enumerate() {
1062 if i < n_features {
1063 all_thresholds[i].extend(ts);
1064 }
1065 }
1066 }
1067
1068 let n_bins = self.config.n_bins as f64;
1069
1070 all_thresholds
1071 .iter()
1072 .map(|ts| {
1073 if ts.is_empty() {
1074 return f64::INFINITY;
1075 }
1076
1077 let mut sorted = ts.clone();
1078 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
1079 sorted.dedup_by(|a, b| crate::math::abs(*a - *b) < 1e-15);
1080
1081 if sorted.len() < 2 {
1082 return f64::INFINITY;
1083 }
1084
1085 let mut gaps: Vec<f64> = sorted.windows(2).map(|w| w[1] - w[0]).collect();
1086
1087 if sorted.len() < 3 {
1088 let range = sorted.last().unwrap() - sorted.first().unwrap();
1089 if range < 1e-15 {
1090 return f64::INFINITY;
1091 }
1092 return (range / n_bins) * K;
1093 }
1094
1095 gaps.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
1096 let median_gap = if gaps.len() % 2 == 0 {
1097 (gaps[gaps.len() / 2 - 1] + gaps[gaps.len() / 2]) / 2.0
1098 } else {
1099 gaps[gaps.len() / 2]
1100 };
1101
1102 if median_gap < 1e-15 {
1103 f64::INFINITY
1104 } else {
1105 median_gap * K
1106 }
1107 })
1108 .collect()
1109 }
1110
    /// Current per-feature smoothing bandwidths (empty until computed).
    pub fn auto_bandwidths(&self) -> &[f64] {
        &self.auto_bandwidths
    }
1115
    /// Returns the model to its pre-training state.
    ///
    /// Resets every step, the bases, the warmup buffer, sigma trackers,
    /// bandwidth cache, gradient statistics, and the packed cache. The
    /// configuration (including `packed_refresh_interval`) is preserved,
    /// and the RNG is re-seeded from `config.seed`.
    pub fn reset(&mut self) {
        for step in &mut self.location_steps {
            step.reset();
        }
        for step in &mut self.scale_steps {
            step.reset();
        }
        self.location_base = 0.0;
        self.scale_base = 0.0;
        self.base_initialized = false;
        self.initial_targets.clear();
        self.samples_seen = 0;
        self.rng_state = self.config.seed;
        self.rolling_sigma_mean = 1.0;
        self.ewma_sq_err = 1.0;
        self.prev_sigma = 0.0;
        self.sigma_velocity = 0.0;
        self.auto_bandwidths.clear();
        self.last_replacement_sum = 0;
        self.ensemble_grad_mean = 0.0;
        self.ensemble_grad_m2 = 0.0;
        self.ensemble_grad_count = 0;
        self.packed_cache = None;
        self.samples_since_refresh = 0;
    }
1142
    /// Total samples observed, including warmup samples.
    #[inline]
    pub fn n_samples_seen(&self) -> u64 {
        self.samples_seen
    }

    /// Number of boosting steps per chain.
    #[inline]
    pub fn n_steps(&self) -> usize {
        self.location_steps.len()
    }
1154
1155 pub fn n_trees(&self) -> usize {
1157 let loc = self.location_steps.len()
1158 + self
1159 .location_steps
1160 .iter()
1161 .filter(|s| s.has_alternate())
1162 .count();
1163 let scale = self.scale_steps.len()
1164 + self
1165 .scale_steps
1166 .iter()
1167 .filter(|s| s.has_alternate())
1168 .count();
1169 loc + scale
1170 }
1171
1172 pub fn total_leaves(&self) -> usize {
1174 let loc: usize = self.location_steps.iter().map(|s| s.n_leaves()).sum();
1175 let scale: usize = self.scale_steps.iter().map(|s| s.n_leaves()).sum();
1176 loc + scale
1177 }
1178
    /// Whether warmup has completed and the base values are set.
    #[inline]
    pub fn is_initialized(&self) -> bool {
        self.base_initialized
    }

    /// The model configuration.
    #[inline]
    pub fn config(&self) -> &SGBTConfig {
        &self.config
    }

    /// The location boosting steps (read-only).
    pub fn location_steps(&self) -> &[BoostingStep] {
        &self.location_steps
    }

    /// Base location prediction (mean of the warmup targets).
    #[inline]
    pub fn location_base(&self) -> f64 {
        self.location_base
    }

    /// Configured learning rate.
    #[inline]
    pub fn learning_rate(&self) -> f64 {
        self.config.learning_rate
    }

    /// Slow EWMA of sigma used as the modulation reference level.
    #[inline]
    pub fn rolling_sigma_mean(&self) -> f64 {
        self.rolling_sigma_mean
    }

    /// Whether uncertainty-modulated learning rate is enabled.
    #[inline]
    pub fn is_uncertainty_modulated(&self) -> bool {
        self.uncertainty_modulated_lr
    }
1221
    /// Builds a full diagnostic snapshot of every tree in both chains.
    ///
    /// Location trees are collected first, then scale trees; `trees` holds
    /// both in that order while `location_trees`/`scale_trees` hold copies
    /// of the respective halves. `feature_split_counts` counts, per
    /// feature, how many trees recorded positive split gain on it.
    pub fn diagnostics(&self) -> ModelDiagnostics {
        let n = self.location_steps.len();
        let mut trees = Vec::with_capacity(2 * n);
        let mut feature_split_counts: Vec<usize> = Vec::new();

        // Appends one TreeDiagnostic per step and accumulates the shared
        // per-feature split counters.
        fn collect_tree_diags(
            steps: &[BoostingStep],
            trees: &mut Vec<TreeDiagnostic>,
            feature_split_counts: &mut Vec<usize>,
        ) {
            for step in steps {
                let slot = step.slot();
                let tree = slot.active_tree();
                let arena = tree.arena();

                // Leaf values and per-leaf sample counts from the arena.
                let leaf_values: Vec<f64> = (0..arena.is_leaf.len())
                    .filter(|&i| arena.is_leaf[i])
                    .map(|i| arena.leaf_value[i])
                    .collect();

                let leaf_sample_counts: Vec<u64> = (0..arena.is_leaf.len())
                    .filter(|&i| arena.is_leaf[i])
                    .map(|i| arena.sample_count[i])
                    .collect();

                let max_depth_reached = (0..arena.is_leaf.len())
                    .filter(|&i| arena.is_leaf[i])
                    .map(|i| arena.depth[i] as usize)
                    .max()
                    .unwrap_or(0);

                // (min, max, mean, std) of leaf values; zeros when empty.
                let leaf_weight_stats = if leaf_values.is_empty() {
                    (0.0, 0.0, 0.0, 0.0)
                } else {
                    let min = leaf_values.iter().cloned().fold(f64::INFINITY, f64::min);
                    let max = leaf_values
                        .iter()
                        .cloned()
                        .fold(f64::NEG_INFINITY, f64::max);
                    let sum: f64 = leaf_values.iter().sum();
                    let mean = sum / leaf_values.len() as f64;
                    let var: f64 = leaf_values
                        .iter()
                        .map(|v| crate::math::powi(v - mean, 2))
                        .sum::<f64>()
                        / leaf_values.len() as f64;
                    (min, max, mean, crate::math::sqrt(var))
                };

                // Features with positive accumulated split gain.
                let gains = slot.split_gains();
                let split_features: Vec<usize> = gains
                    .iter()
                    .enumerate()
                    .filter(|(_, &g)| g > 0.0)
                    .map(|(i, _)| i)
                    .collect();

                if !gains.is_empty() {
                    // Size the shared counter on the first non-empty gains vector.
                    if feature_split_counts.is_empty() {
                        feature_split_counts.resize(gains.len(), 0);
                    }
                    for &fi in &split_features {
                        if fi < feature_split_counts.len() {
                            feature_split_counts[fi] += 1;
                        }
                    }
                }

                trees.push(TreeDiagnostic {
                    n_leaves: leaf_values.len(),
                    max_depth_reached,
                    samples_seen: step.n_samples_seen(),
                    leaf_weight_stats,
                    split_features,
                    leaf_sample_counts,
                    prediction_mean: slot.prediction_mean(),
                    prediction_std: slot.prediction_std(),
                });
            }
        }

        collect_tree_diags(&self.location_steps, &mut trees, &mut feature_split_counts);
        collect_tree_diags(&self.scale_steps, &mut trees, &mut feature_split_counts);

        // First n entries are location trees; the rest are scale trees.
        let location_trees = trees[..n].to_vec();
        let scale_trees = trees[n..].to_vec();
        // A single-leaf scale tree has never split: treat it as inactive.
        let scale_trees_active = scale_trees.iter().filter(|t| t.n_leaves > 1).count();

        ModelDiagnostics {
            trees,
            location_trees,
            scale_trees,
            feature_split_counts,
            location_base: self.location_base,
            scale_base: self.scale_base,
            empirical_sigma: crate::math::sqrt(self.ewma_sq_err),
            scale_mode: self.scale_mode,
            scale_trees_active,
            auto_bandwidths: self.auto_bandwidths.clone(),
            ensemble_grad_mean: self.ensemble_grad_mean,
            ensemble_grad_std: self.ensemble_grad_std(),
        }
    }
1336
1337 pub fn predict_decomposed(&self, features: &[f64]) -> DecomposedPrediction {
1349 let lr = self.config.learning_rate;
1350 let location: Vec<f64> = self
1351 .location_steps
1352 .iter()
1353 .map(|s| lr * s.predict(features))
1354 .collect();
1355
1356 let (sb, scale) = match self.scale_mode {
1357 ScaleMode::Empirical => {
1358 let empirical_sigma = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
1359 (
1360 crate::math::ln(empirical_sigma),
1361 vec![0.0; self.location_steps.len()],
1362 )
1363 }
1364 ScaleMode::TreeChain => {
1365 let s: Vec<f64> = self
1366 .scale_steps
1367 .iter()
1368 .map(|s| lr * s.predict(features))
1369 .collect();
1370 (self.scale_base, s)
1371 }
1372 };
1373
1374 DecomposedPrediction {
1375 location_base: self.location_base,
1376 scale_base: sb,
1377 location_contributions: location,
1378 scale_contributions: scale,
1379 }
1380 }
1381
1382 pub fn feature_importances(&self) -> Vec<f64> {
1388 let mut totals: Vec<f64> = Vec::new();
1389 for steps in [&self.location_steps, &self.scale_steps] {
1390 for step in steps {
1391 let gains = step.slot().split_gains();
1392 if totals.is_empty() && !gains.is_empty() {
1393 totals.resize(gains.len(), 0.0);
1394 }
1395 for (i, &g) in gains.iter().enumerate() {
1396 if i < totals.len() {
1397 totals[i] += g;
1398 }
1399 }
1400 }
1401 }
1402 let sum: f64 = totals.iter().sum();
1403 if sum > 0.0 {
1404 totals.iter_mut().for_each(|v| *v /= sum);
1405 }
1406 totals
1407 }
1408
1409 pub fn feature_importances_split(&self) -> (Vec<f64>, Vec<f64>) {
1414 fn aggregate(steps: &[BoostingStep]) -> Vec<f64> {
1415 let mut totals: Vec<f64> = Vec::new();
1416 for step in steps {
1417 let gains = step.slot().split_gains();
1418 if totals.is_empty() && !gains.is_empty() {
1419 totals.resize(gains.len(), 0.0);
1420 }
1421 for (i, &g) in gains.iter().enumerate() {
1422 if i < totals.len() {
1423 totals[i] += g;
1424 }
1425 }
1426 }
1427 let sum: f64 = totals.iter().sum();
1428 if sum > 0.0 {
1429 totals.iter_mut().for_each(|v| *v /= sum);
1430 }
1431 totals
1432 }
1433 (
1434 aggregate(&self.location_steps),
1435 aggregate(&self.scale_steps),
1436 )
1437 }
1438
1439 #[cfg(feature = "_serde_support")]
1445 pub fn to_distributional_state(&self) -> crate::serde_support::DistributionalModelState {
1446 use super::snapshot_tree;
1447 use crate::serde_support::{DistributionalModelState, StepSnapshot};
1448
1449 fn snapshot_step(step: &BoostingStep) -> StepSnapshot {
1450 let slot = step.slot();
1451 let tree_snap = snapshot_tree(slot.active_tree());
1452 let alt_snap = slot.alternate_tree().map(snapshot_tree);
1453 let drift_state = slot.detector().serialize_state();
1454 let alt_drift_state = slot.alt_detector().and_then(|d| d.serialize_state());
1455 StepSnapshot {
1456 tree: tree_snap,
1457 alternate_tree: alt_snap,
1458 drift_state,
1459 alt_drift_state,
1460 }
1461 }
1462
1463 DistributionalModelState {
1464 config: self.config.clone(),
1465 location_steps: self.location_steps.iter().map(snapshot_step).collect(),
1466 scale_steps: self.scale_steps.iter().map(snapshot_step).collect(),
1467 location_base: self.location_base,
1468 scale_base: self.scale_base,
1469 base_initialized: self.base_initialized,
1470 initial_targets: self.initial_targets.clone(),
1471 initial_target_count: self.initial_target_count,
1472 samples_seen: self.samples_seen,
1473 rng_state: self.rng_state,
1474 uncertainty_modulated_lr: self.uncertainty_modulated_lr,
1475 rolling_sigma_mean: self.rolling_sigma_mean,
1476 ewma_sq_err: self.ewma_sq_err,
1477 }
1478 }
1479
    /// Rebuild a `DistributionalSGBT` from a state previously produced by
    /// `to_distributional_state`.
    ///
    /// Trees, drift detectors and training counters are restored from the
    /// snapshot. Transient fields (sigma velocity, auto bandwidths, ensemble
    /// gradient running stats, packed inference cache) are reset to their
    /// fresh-model values and repopulate as training resumes.
    #[cfg(feature = "_serde_support")]
    pub fn from_distributional_state(
        state: crate::serde_support::DistributionalModelState,
    ) -> Self {
        use super::rebuild_tree;
        use crate::ensemble::replacement::TreeSlot;
        use crate::serde_support::StepSnapshot;

        // Per-sample leaf decay factor from the configured half-life:
        // alpha = exp(-ln(2) / half_life), so leaf weights halve every
        // `leaf_half_life` samples.
        let leaf_decay_alpha = state
            .config
            .leaf_half_life
            .map(|hl| crate::math::exp((-(crate::math::ln(2.0_f64)) / hl as f64)));
        let max_tree_samples = state.config.max_tree_samples;

        // Base tree configuration shared by every step; per-step seeds are
        // derived from it inside `rebuild_steps` below.
        let base_tree_config = TreeConfig::new()
            .max_depth(state.config.max_depth)
            .n_bins(state.config.n_bins)
            .lambda(state.config.lambda)
            .gamma(state.config.gamma)
            .grace_period(state.config.grace_period)
            .delta(state.config.delta)
            .feature_subsample_rate(state.config.feature_subsample_rate)
            .leaf_decay_alpha_opt(leaf_decay_alpha)
            .split_reeval_interval_opt(state.config.split_reeval_interval)
            .feature_types_opt(state.config.feature_types.clone())
            .gradient_clip_sigma_opt(state.config.gradient_clip_sigma)
            .monotone_constraints_opt(state.config.monotone_constraints.clone())
            .leaf_model_type(state.config.leaf_model_type.clone());

        // Rebuild one ensemble from its snapshots. `seed_xor` decorrelates the
        // location and scale RNG streams; it must match the per-step seed
        // derivation used at construction time (NOTE(review): confirm against
        // the constructor — restored seeds only matter for future growth).
        let rebuild_steps = |snaps: &[StepSnapshot], seed_xor: u64| -> Vec<BoostingStep> {
            snaps
                .iter()
                .enumerate()
                .map(|(i, snap)| {
                    let tc = base_tree_config
                        .clone()
                        .seed(state.config.seed ^ (i as u64) ^ seed_xor);

                    let active = rebuild_tree(&snap.tree, tc.clone());
                    let alternate = snap
                        .alternate_tree
                        .as_ref()
                        .map(|s| rebuild_tree(s, tc.clone()));

                    // Fresh detector, then overwrite with the saved state when
                    // one was serialized.
                    let mut detector = state.config.drift_detector.create();
                    if let Some(ref ds) = snap.drift_state {
                        detector.restore_state(ds);
                    }
                    // Slot must exist before the alternate detector can be
                    // restored (the slot owns it).
                    let mut slot =
                        TreeSlot::from_trees(active, alternate, tc, detector, max_tree_samples);
                    if let Some(ref ads) = snap.alt_drift_state {
                        if let Some(alt_det) = slot.alt_detector_mut() {
                            alt_det.restore_state(ads);
                        }
                    }
                    BoostingStep::from_slot(slot)
                })
                .collect()
        };

        let location_steps = rebuild_steps(&state.location_steps, 0);
        // Distinct XOR constant ("SCALE" hex-speak) for the scale ensemble.
        let scale_steps = rebuild_steps(&state.scale_steps, 0x0005_CA1E_0000_0000);

        // Pulled out before `state.config` is moved into `Self` below.
        let scale_mode = state.config.scale_mode;
        let empirical_sigma_alpha = state.config.empirical_sigma_alpha;
        let packed_refresh_interval = state.config.packed_refresh_interval;
        Self {
            config: state.config,
            location_steps,
            scale_steps,
            location_base: state.location_base,
            scale_base: state.scale_base,
            base_initialized: state.base_initialized,
            initial_targets: state.initial_targets,
            initial_target_count: state.initial_target_count,
            samples_seen: state.samples_seen,
            rng_state: state.rng_state,
            uncertainty_modulated_lr: state.uncertainty_modulated_lr,
            rolling_sigma_mean: state.rolling_sigma_mean,
            scale_mode,
            ewma_sq_err: state.ewma_sq_err,
            empirical_sigma_alpha,
            // Everything below is transient and intentionally NOT serialized;
            // it is rebuilt from scratch as new samples arrive.
            prev_sigma: 0.0,
            sigma_velocity: 0.0,
            auto_bandwidths: Vec::new(),
            last_replacement_sum: 0,
            ensemble_grad_mean: 0.0,
            ensemble_grad_m2: 0.0,
            ensemble_grad_count: 0,
            packed_cache: None,
            samples_since_refresh: 0,
            packed_refresh_interval,
        }
    }
1581}
1582
1583use crate::learner::StreamingLearner;
1588
1589impl StreamingLearner for DistributionalSGBT {
1590 fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
1591 let sample = SampleRef::weighted(features, target, weight);
1592 DistributionalSGBT::train_one(self, &sample);
1594 }
1595
1596 fn predict(&self, features: &[f64]) -> f64 {
1598 DistributionalSGBT::predict(self, features).mu
1599 }
1600
1601 fn n_samples_seen(&self) -> u64 {
1602 self.samples_seen
1603 }
1604
1605 fn reset(&mut self) {
1606 DistributionalSGBT::reset(self);
1607 }
1608}
1609
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::format;
    use alloc::vec;
    use alloc::vec::Vec;

    /// Small, fast-to-train configuration shared by most tests.
    fn test_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .build()
            .unwrap()
    }

    // --- basic prediction invariants -----------------------------------

    #[test]
    fn fresh_model_predicts_zero() {
        let model = DistributionalSGBT::new(test_config());
        let pred = model.predict(&[1.0, 2.0, 3.0]);
        // Untrained model: mu at its zero base, sigma still strictly positive.
        assert!(pred.mu.abs() < 1e-12);
        assert!(pred.sigma > 0.0);
    }

    #[test]
    fn sigma_always_positive() {
        let mut model = DistributionalSGBT::new(test_config());

        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 0.5], x * 2.0 + 1.0));
        }

        for i in 0..20 {
            let x = i as f64 * 0.5;
            let pred = model.predict(&[x, x * 0.5]);
            assert!(
                pred.sigma > 0.0,
                "sigma must be positive, got {}",
                pred.sigma
            );
            assert!(pred.sigma.is_finite(), "sigma must be finite");
        }
    }

    #[test]
    fn constant_target_has_small_sigma() {
        let mut model = DistributionalSGBT::new(test_config());

        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 2.0], 5.0));
        }

        let pred = model.predict(&[1.0, 2.0]);
        assert!(pred.mu.is_finite());
        assert!(pred.sigma.is_finite());
        assert!(pred.sigma > 0.0);
    }

    #[test]
    fn noisy_target_has_finite_predictions() {
        let mut model = DistributionalSGBT::new(test_config());

        // xorshift64 PRNG for deterministic uniform-ish noise in [-1, 1).
        let mut rng: u64 = 42;
        for i in 0..200 {
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            let noise = (rng % 1000) as f64 / 500.0 - 1.0;
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0 + noise));
        }

        let pred = model.predict(&[5.0]);
        assert!(pred.mu.is_finite());
        assert!(pred.sigma.is_finite());
        assert!(pred.sigma > 0.0);
    }

    #[test]
    fn predict_interval_bounds_correct() {
        let mut model = DistributionalSGBT::new(test_config());

        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0));
        }

        let (lo, hi) = model.predict_interval(&[5.0], 1.96);
        let pred = model.predict(&[5.0]);

        assert!(lo < pred.mu, "lower bound should be < mu");
        assert!(hi > pred.mu, "upper bound should be > mu");
        // Interval width must equal 2 * z * sigma exactly.
        assert!((hi - lo - 2.0 * 1.96 * pred.sigma).abs() < 1e-10);
    }

    #[test]
    fn batch_prediction_matches_individual() {
        let mut model = DistributionalSGBT::new(test_config());

        for i in 0..100 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 2.0], x));
        }

        let features = vec![vec![1.0, 2.0], vec![3.0, 6.0], vec![5.0, 10.0]];
        let batch = model.predict_batch(&features);

        for (feat, batch_pred) in features.iter().zip(batch.iter()) {
            let individual = model.predict(feat);
            assert!((batch_pred.mu - individual.mu).abs() < 1e-12);
            assert!((batch_pred.sigma - individual.sigma).abs() < 1e-12);
        }
    }

    #[test]
    fn reset_clears_state() {
        let mut model = DistributionalSGBT::new(test_config());

        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0));
        }

        assert!(model.n_samples_seen() > 0);
        model.reset();

        assert_eq!(model.n_samples_seen(), 0);
        assert!(!model.is_initialized());
    }

    #[test]
    fn gaussian_prediction_lower_upper() {
        let pred = GaussianPrediction {
            mu: 10.0,
            sigma: 2.0,
            log_sigma: 2.0_f64.ln(),
        };

        assert!((pred.lower(1.96) - (10.0 - 1.96 * 2.0)).abs() < 1e-10);
        assert!((pred.upper(1.96) - (10.0 + 1.96 * 2.0)).abs() < 1e-10);
    }

    #[test]
    fn train_batch_works() {
        let mut model = DistributionalSGBT::new(test_config());
        let samples: Vec<(Vec<f64>, f64)> = (0..100)
            .map(|i| {
                let x = i as f64 * 0.1;
                (vec![x], x * 2.0)
            })
            .collect();

        model.train_batch(&samples);
        assert_eq!(model.n_samples_seen(), 100);
    }

    #[test]
    fn debug_format_works() {
        let model = DistributionalSGBT::new(test_config());
        let debug = format!("{:?}", model);
        assert!(debug.contains("DistributionalSGBT"));
    }

    #[test]
    fn n_trees_counts_both_ensembles() {
        let model = DistributionalSGBT::new(test_config());
        // n_steps(10) location trees + n_steps(10) scale trees.
        assert!(model.n_trees() >= 20);
    }

    // --- uncertainty-modulated learning rate ---------------------------

    /// Like `test_config` but with uncertainty-modulated LR enabled.
    fn modulated_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .uncertainty_modulated_lr(true)
            .build()
            .unwrap()
    }

    #[test]
    fn sigma_modulated_initializes_rolling_mean() {
        let mut model = DistributionalSGBT::new(modulated_config());
        assert!(model.is_uncertainty_modulated());

        // Fresh model starts with a neutral rolling sigma mean of 1.0.
        assert!((model.rolling_sigma_mean() - 1.0).abs() < 1e-12);

        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 0.5], x * 2.0 + 1.0));
        }

        assert!(model.rolling_sigma_mean() > 0.0);
        assert!(model.rolling_sigma_mean().is_finite());
    }

    #[test]
    fn predict_distributional_returns_sigma_ratio() {
        let mut model = DistributionalSGBT::new(modulated_config());

        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0 + 1.0));
        }

        let (mu, sigma, sigma_ratio) = model.predict_distributional(&[5.0]);
        assert!(mu.is_finite());
        assert!(sigma > 0.0);
        // Ratio of local sigma to rolling mean should stay in a sane band.
        assert!(
            (0.1..=10.0).contains(&sigma_ratio),
            "sigma_ratio={}",
            sigma_ratio
        );
    }

    #[test]
    fn predict_distributional_without_modulation_returns_one() {
        let mut model = DistributionalSGBT::new(test_config());
        assert!(!model.is_uncertainty_modulated());

        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0));
        }

        let (_mu, _sigma, sigma_ratio) = model.predict_distributional(&[5.0]);
        assert!(
            (sigma_ratio - 1.0).abs() < 1e-12,
            "should be 1.0 when disabled"
        );
    }

    #[test]
    fn modulated_model_sigma_finite_under_varying_noise() {
        let mut model = DistributionalSGBT::new(modulated_config());

        let mut rng: u64 = 123;
        for i in 0..500 {
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            let noise = (rng % 1000) as f64 / 100.0 - 5.0;
            let x = i as f64 * 0.1;
            // Noise regime shift halfway through the stream.
            let scale = if i < 250 { 1.0 } else { 5.0 };
            model.train_one(&(vec![x], x * 2.0 + noise * scale));
        }

        let pred = model.predict(&[10.0]);
        assert!(pred.mu.is_finite());
        assert!(pred.sigma.is_finite());
        assert!(pred.sigma > 0.0);
        assert!(model.rolling_sigma_mean().is_finite());
    }

    #[test]
    fn reset_clears_rolling_sigma_mean() {
        let mut model = DistributionalSGBT::new(modulated_config());

        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0));
        }

        let sigma_before = model.rolling_sigma_mean();
        assert!(sigma_before > 0.0);

        model.reset();
        // Reset restores the neutral 1.0 starting value.
        assert!((model.rolling_sigma_mean() - 1.0).abs() < 1e-12);
    }

    #[test]
    fn streaming_learner_returns_mu() {
        let mut model = DistributionalSGBT::new(test_config());
        for i in 0..200 {
            let x = i as f64 * 0.1;
            StreamingLearner::train(&mut model, &[x], x * 2.0 + 1.0);
        }
        let pred = StreamingLearner::predict(&model, &[5.0]);
        let gaussian = DistributionalSGBT::predict(&model, &[5.0]);
        assert!(
            (pred - gaussian.mu).abs() < 1e-12,
            "StreamingLearner::predict should return mu"
        );
    }

    // --- diagnostics & decomposition -----------------------------------

    /// Model trained on 3 features / 500 samples, used by diagnostics tests.
    fn trained_model() -> DistributionalSGBT {
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);
        for i in 0..500 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 0.5, (i % 3) as f64], x * 2.0 + 1.0));
        }
        model
    }

    #[test]
    fn diagnostics_returns_correct_tree_count() {
        let model = trained_model();
        let diag = model.diagnostics();
        assert_eq!(diag.trees.len(), 20, "should have 2*n_steps trees");
    }

    #[test]
    fn diagnostics_trees_have_leaves() {
        let model = trained_model();
        let diag = model.diagnostics();
        for (i, tree) in diag.trees.iter().enumerate() {
            assert!(tree.n_leaves >= 1, "tree {i} should have at least 1 leaf");
        }
        let total_samples: u64 = diag.trees.iter().map(|t| t.samples_seen).sum();
        assert!(
            total_samples > 0,
            "at least some trees should have seen samples"
        );
    }

    #[test]
    fn diagnostics_leaf_weight_stats_finite() {
        let model = trained_model();
        let diag = model.diagnostics();
        for (i, tree) in diag.trees.iter().enumerate() {
            let (min, max, mean, std) = tree.leaf_weight_stats;
            assert!(min.is_finite(), "tree {i} min not finite");
            assert!(max.is_finite(), "tree {i} max not finite");
            assert!(mean.is_finite(), "tree {i} mean not finite");
            assert!(std.is_finite(), "tree {i} std not finite");
            assert!(min <= max, "tree {i} min > max");
        }
    }

    #[test]
    fn diagnostics_base_predictions_match() {
        let model = trained_model();
        let diag = model.diagnostics();
        // Loose sanity bound only: base should be in the same ballpark as a
        // prediction at the origin.
        assert!(
            (diag.location_base - model.predict(&[0.0, 0.0, 0.0]).mu).abs() < 100.0,
            "location_base should be plausible"
        );
    }

    #[test]
    fn predict_decomposed_reconstructs_prediction() {
        let model = trained_model();
        let features = [5.0, 2.5, 1.0];
        let pred = model.predict(&features);
        let decomp = model.predict_decomposed(&features);

        // base + per-tree contributions must reproduce predict() exactly.
        assert!(
            (decomp.mu() - pred.mu).abs() < 1e-10,
            "decomposed mu ({}) != predict mu ({})",
            decomp.mu(),
            pred.mu
        );
        assert!(
            (decomp.sigma() - pred.sigma).abs() < 1e-10,
            "decomposed sigma ({}) != predict sigma ({})",
            decomp.sigma(),
            pred.sigma
        );
    }

    #[test]
    fn predict_decomposed_correct_lengths() {
        let model = trained_model();
        let decomp = model.predict_decomposed(&[1.0, 0.5, 0.0]);
        assert_eq!(
            decomp.location_contributions.len(),
            model.n_steps(),
            "location contributions should have n_steps entries"
        );
        assert_eq!(
            decomp.scale_contributions.len(),
            model.n_steps(),
            "scale contributions should have n_steps entries"
        );
    }

    #[test]
    fn feature_importances_work() {
        let model = trained_model();
        let imp = model.feature_importances();
        for (i, &v) in imp.iter().enumerate() {
            assert!(v >= 0.0, "importance {i} should be non-negative, got {v}");
            assert!(v.is_finite(), "importance {i} should be finite");
        }
        let sum: f64 = imp.iter().sum();
        if sum > 0.0 {
            assert!(
                (sum - 1.0).abs() < 1e-10,
                "non-zero importances should sum to 1.0, got {sum}"
            );
        }
    }

    #[test]
    fn feature_importances_split_works() {
        let model = trained_model();
        let (loc_imp, scale_imp) = model.feature_importances_split();
        for (name, imp) in [("location", &loc_imp), ("scale", &scale_imp)] {
            let sum: f64 = imp.iter().sum();
            if sum > 0.0 {
                assert!(
                    (sum - 1.0).abs() < 1e-10,
                    "{name} importances should sum to 1.0, got {sum}"
                );
            }
            for &v in imp.iter() {
                assert!(v >= 0.0 && v.is_finite());
            }
        }
    }

    // --- scale modes ----------------------------------------------------

    #[test]
    fn empirical_sigma_default_mode() {
        use crate::ensemble::config::ScaleMode;
        let config = test_config();
        let model = DistributionalSGBT::new(config);
        assert_eq!(model.scale_mode(), ScaleMode::Empirical);
    }

    #[test]
    fn empirical_sigma_tracks_errors() {
        let mut model = DistributionalSGBT::new(test_config());

        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 0.5], x * 2.0 + 1.0));
        }

        let sigma_clean = model.empirical_sigma();
        assert!(sigma_clean > 0.0, "sigma should be positive");
        assert!(sigma_clean.is_finite(), "sigma should be finite");

        // Switch to a much noisier target regime; sigma should rise.
        let mut rng: u64 = 42;
        for i in 200..400 {
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            let noise = (rng % 10000) as f64 / 100.0 - 50.0;
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 0.5], x * 2.0 + noise));
        }

        let sigma_noisy = model.empirical_sigma();
        assert!(
            sigma_noisy > sigma_clean,
            "noisy regime should increase sigma: clean={sigma_clean}, noisy={sigma_noisy}"
        );
    }

    #[test]
    fn empirical_sigma_modulated_lr_adapts() {
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .uncertainty_modulated_lr(true)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        for i in 0..300 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0 + 1.0));
        }

        let (_, _, sigma_ratio) = model.predict_distributional(&[5.0]);
        assert!(sigma_ratio.is_finite());
        assert!(
            (0.1..=10.0).contains(&sigma_ratio),
            "sigma_ratio={sigma_ratio}"
        );
    }

    #[test]
    fn tree_chain_mode_trains_scale_trees() {
        use crate::ensemble::config::ScaleMode;
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .scale_mode(ScaleMode::TreeChain)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);
        assert_eq!(model.scale_mode(), ScaleMode::TreeChain);

        for i in 0..500 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 0.5, (i % 3) as f64], x * 2.0 + 1.0));
        }

        let pred = model.predict(&[5.0, 2.5, 1.0]);
        assert!(pred.mu.is_finite());
        assert!(pred.sigma > 0.0);
        assert!(pred.sigma.is_finite());
    }

    #[test]
    fn diagnostics_shows_empirical_sigma() {
        let model = trained_model();
        let diag = model.diagnostics();
        assert!(
            diag.empirical_sigma > 0.0,
            "empirical_sigma should be positive"
        );
        assert!(
            diag.empirical_sigma.is_finite(),
            "empirical_sigma should be finite"
        );
    }

    #[test]
    fn diagnostics_scale_trees_split_fields() {
        let model = trained_model();
        let diag = model.diagnostics();
        assert_eq!(diag.location_trees.len(), model.n_steps());
        assert_eq!(diag.scale_trees.len(), model.n_steps());
    }

    #[test]
    fn reset_clears_empirical_sigma() {
        let mut model = DistributionalSGBT::new(test_config());
        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0));
        }
        model.reset();
        // Reset restores the neutral 1.0 starting value.
        assert!((model.empirical_sigma() - 1.0).abs() < 1e-12);
    }

    // --- smoothing / velocity / leaf counts ----------------------------

    #[test]
    fn predict_smooth_returns_finite() {
        let config = SGBTConfig::builder()
            .n_steps(3)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        for i in 0..200 {
            let x = (i as f64) * 0.1;
            let features = vec![x, x.sin()];
            let target = 2.0 * x + 1.0;
            model.train_one(&(features, target));
        }

        let pred = model.predict_smooth(&[1.0, 1.0_f64.sin()], 0.5);
        assert!(pred.mu.is_finite(), "smooth mu should be finite");
        assert!(pred.sigma.is_finite(), "smooth sigma should be finite");
        assert!(pred.sigma > 0.0, "smooth sigma should be positive");
    }

    #[test]
    fn sigma_velocity_responds_to_error_spike() {
        let config = SGBTConfig::builder()
            .n_steps(3)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .uncertainty_modulated_lr(true)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        for i in 0..200 {
            let x = (i as f64) * 0.1;
            model.train_one(&(vec![x, x.sin()], 2.0 * x + 1.0));
        }

        let velocity_before = model.sigma_velocity();

        // Abrupt target shift: 50x slope, big intercept jump.
        for i in 0..50 {
            let x = (i as f64) * 0.1;
            model.train_one(&(vec![x, x.sin()], 100.0 * x + 50.0));
        }

        let velocity_after = model.sigma_velocity();

        assert!(
            velocity_after > velocity_before,
            "sigma velocity should increase after error spike: before={}, after={}",
            velocity_before,
            velocity_after,
        );
    }

    #[test]
    fn sigma_velocity_getter_works() {
        let config = SGBTConfig::builder()
            .n_steps(2)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .build()
            .unwrap();
        let model = DistributionalSGBT::new(config);
        // Untrained model has zero velocity.
        assert_eq!(model.sigma_velocity(), 0.0);
    }

    #[test]
    fn diagnostics_leaf_sample_counts_populated() {
        let config = SGBTConfig::builder()
            .n_steps(3)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        for i in 0..200 {
            let x = (i as f64) * 0.1;
            let features = vec![x, x.sin()];
            let target = 2.0 * x + 1.0;
            model.train_one(&(features, target));
        }

        let diags = model.diagnostics();
        for (ti, tree) in diags.trees.iter().enumerate() {
            assert_eq!(
                tree.leaf_sample_counts.len(),
                tree.n_leaves,
                "tree {} should have sample count per leaf",
                ti,
            );
            if tree.samples_seen > 0 {
                let total: u64 = tree.leaf_sample_counts.iter().sum();
                assert!(
                    total > 0,
                    "tree {} has {} samples_seen but leaf counts sum to 0",
                    ti,
                    tree.samples_seen,
                );
            }
        }
    }

    // --- packed inference cache ----------------------------------------

    #[test]
    fn packed_cache_disabled_by_default() {
        let model = DistributionalSGBT::new(test_config());
        assert!(!model.has_packed_cache());
        assert_eq!(model.config().packed_refresh_interval, 0);
    }

    // NOTE(review): gated behind a feature that appears intentionally never
    // enabled ("_packed_cache_tests_disabled") — effectively skipped.
    #[test]
    #[cfg(feature = "_packed_cache_tests_disabled")]
    fn packed_cache_refreshes_after_interval() {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .grace_period(5)
            .max_depth(3)
            .n_bins(8)
            .initial_target_count(10)
            .packed_refresh_interval(20)
            .build()
            .unwrap();

        let mut model = DistributionalSGBT::new(config);

        for i in 0..40 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 2.0, x * 0.5], x * 3.0));
        }

        assert!(
            model.has_packed_cache(),
            "packed cache should exist after training past refresh interval"
        );

        let pred = model.predict(&[2.0, 4.0, 1.0]);
        assert!(pred.mu.is_finite(), "mu should be finite: {}", pred.mu);
    }

    #[test]
    #[cfg(feature = "_packed_cache_tests_disabled")]
    fn packed_cache_matches_full_tree() {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .grace_period(5)
            .max_depth(3)
            .n_bins(8)
            .initial_target_count(10)
            .build()
            .unwrap();

        let mut model = DistributionalSGBT::new(config);

        for i in 0..80 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 2.0, x * 0.5], x * 3.0));
        }

        assert!(!model.has_packed_cache());
        let full_pred = model.predict(&[2.0, 4.0, 1.0]);

        model.enable_packed_cache(10);
        assert!(model.has_packed_cache());
        let cached_pred = model.predict(&[2.0, 4.0, 1.0]);

        // Packed cache stores f32 leaf values, so mu only matches loosely.
        let mu_diff = (full_pred.mu - cached_pred.mu).abs();
        assert!(
            mu_diff < 0.1,
            "packed cache mu ({}) should match full tree mu ({}) within f32 tolerance, diff={}",
            cached_pred.mu,
            full_pred.mu,
            mu_diff
        );

        // Sigma path does not go through the cache, so it must be exact.
        assert!(
            (full_pred.sigma - cached_pred.sigma).abs() < 1e-12,
            "sigma should be identical: full={}, cached={}",
            full_pred.sigma,
            cached_pred.sigma
        );
    }
}
2410
2411#[cfg(test)]
2412#[cfg(feature = "_serde_support")]
2413mod serde_tests {
2414 use super::*;
2415 use crate::SGBTConfig;
2416
2417 fn make_trained_distributional() -> DistributionalSGBT {
2418 let config = SGBTConfig::builder()
2419 .n_steps(5)
2420 .learning_rate(0.1)
2421 .max_depth(3)
2422 .grace_period(2)
2423 .initial_target_count(10)
2424 .build()
2425 .unwrap();
2426 let mut model = DistributionalSGBT::new(config);
2427 for i in 0..50 {
2428 let x = i as f64 * 0.1;
2429 model.train_one(&(vec![x], x.sin()));
2430 }
2431 model
2432 }
2433
2434 #[test]
2435 fn json_round_trip_preserves_predictions() {
2436 let model = make_trained_distributional();
2437 let state = model.to_distributional_state();
2438 let json = crate::serde_support::save_distributional_model(&state).unwrap();
2439 let loaded_state = crate::serde_support::load_distributional_model(&json).unwrap();
2440 let restored = DistributionalSGBT::from_distributional_state(loaded_state);
2441
2442 let test_points = [0.5, 1.0, 2.0, 3.0];
2443 for &x in &test_points {
2444 let orig = model.predict(&[x]);
2445 let rest = restored.predict(&[x]);
2446 assert!(
2447 (orig.mu - rest.mu).abs() < 1e-10,
2448 "JSON round-trip mu mismatch at x={}: {} vs {}",
2449 x,
2450 orig.mu,
2451 rest.mu
2452 );
2453 assert!(
2454 (orig.sigma - rest.sigma).abs() < 1e-10,
2455 "JSON round-trip sigma mismatch at x={}: {} vs {}",
2456 x,
2457 orig.sigma,
2458 rest.sigma
2459 );
2460 }
2461 }
2462
2463 #[test]
2464 fn state_preserves_rolling_sigma_mean() {
2465 let config = SGBTConfig::builder()
2466 .n_steps(5)
2467 .learning_rate(0.1)
2468 .max_depth(3)
2469 .grace_period(2)
2470 .initial_target_count(10)
2471 .uncertainty_modulated_lr(true)
2472 .build()
2473 .unwrap();
2474 let mut model = DistributionalSGBT::new(config);
2475 for i in 0..50 {
2476 let x = i as f64 * 0.1;
2477 model.train_one(&(vec![x], x.sin()));
2478 }
2479 let state = model.to_distributional_state();
2480 assert!(state.uncertainty_modulated_lr);
2481 assert!(state.rolling_sigma_mean >= 0.0);
2482
2483 let restored = DistributionalSGBT::from_distributional_state(state);
2484 assert_eq!(model.n_samples_seen(), restored.n_samples_seen());
2485 }
2486
2487 #[test]
2488 fn auto_bandwidth_computed_distributional() {
2489 let config = SGBTConfig::builder()
2490 .n_steps(3)
2491 .learning_rate(0.1)
2492 .grace_period(10)
2493 .initial_target_count(10)
2494 .build()
2495 .unwrap();
2496 let mut model = DistributionalSGBT::new(config);
2497
2498 for i in 0..200 {
2499 let x = (i as f64) * 0.1;
2500 model.train_one(&(vec![x, x.sin()], 2.0 * x + 1.0));
2501 }
2502
2503 let bws = model.auto_bandwidths();
2505 assert_eq!(bws.len(), 2, "should have 2 feature bandwidths");
2506
2507 let diag = model.diagnostics();
2509 assert_eq!(diag.auto_bandwidths.len(), 2);
2510
2511 assert!(diag.location_trees[0].prediction_mean.is_finite());
2513 assert!(diag.location_trees[0].prediction_std.is_finite());
2514
2515 let pred = model.predict(&[1.0, 1.0_f64.sin()]);
2516 assert!(pred.mu.is_finite(), "auto-bandwidth mu should be finite");
2517 assert!(pred.sigma > 0.0, "auto-bandwidth sigma should be positive");
2518 }
2519
2520 #[test]
2521 fn max_leaf_output_clamps_predictions() {
2522 let config = SGBTConfig::builder()
2523 .n_steps(5)
2524 .learning_rate(1.0) .max_leaf_output(0.5)
2526 .build()
2527 .unwrap();
2528 let mut model = DistributionalSGBT::new(config);
2529
2530 for i in 0..200 {
2532 let target = if i % 2 == 0 { 100.0 } else { -100.0 };
2533 let sample = crate::Sample::new(vec![i as f64 % 5.0, (i as f64).sin()], target);
2534 model.train_one(&sample);
2535 }
2536
2537 let pred = model.predict(&[2.0, 0.5]);
2539 assert!(
2540 pred.mu.is_finite(),
2541 "prediction should be finite with clamping"
2542 );
2543 }
2544
2545 #[test]
2546 fn min_hessian_sum_suppresses_fresh_leaves() {
2547 let config = SGBTConfig::builder()
2548 .n_steps(3)
2549 .learning_rate(0.01)
2550 .min_hessian_sum(50.0)
2551 .build()
2552 .unwrap();
2553 let mut model = DistributionalSGBT::new(config);
2554
2555 for i in 0..60 {
2557 let sample = crate::Sample::new(vec![i as f64, (i as f64).sin()], i as f64 * 0.1);
2558 model.train_one(&sample);
2559 }
2560
2561 let pred = model.predict(&[30.0, 0.5]);
2562 assert!(
2563 pred.mu.is_finite(),
2564 "prediction should be finite with min_hessian_sum"
2565 );
2566 }
2567
2568 #[test]
2569 fn predict_interpolated_returns_finite() {
2570 let config = SGBTConfig::builder()
2571 .n_steps(10)
2572 .learning_rate(0.1)
2573 .grace_period(20)
2574 .max_depth(4)
2575 .n_bins(16)
2576 .initial_target_count(10)
2577 .build()
2578 .unwrap();
2579 let mut model = DistributionalSGBT::new(config);
2580 for i in 0..200 {
2581 let x = i as f64 * 0.1;
2582 let sample = crate::Sample::new(vec![x, x.sin()], x.cos());
2583 model.train_one(&sample);
2584 }
2585
2586 let pred = model.predict_interpolated(&[1.0, 0.5]);
2587 assert!(pred.mu.is_finite(), "interpolated mu should be finite");
2588 assert!(pred.sigma > 0.0, "interpolated sigma should be positive");
2589 }
2590
2591 #[test]
2592 fn huber_k_bounds_gradients() {
2593 let config = SGBTConfig::builder()
2594 .n_steps(5)
2595 .learning_rate(0.01)
2596 .huber_k(1.345)
2597 .build()
2598 .unwrap();
2599 let mut model = DistributionalSGBT::new(config);
2600
2601 for i in 0..300 {
2603 let target = if i % 50 == 0 {
2604 1000.0
2605 } else {
2606 (i as f64 * 0.1).sin()
2607 };
2608 let sample = crate::Sample::new(vec![i as f64 % 10.0, (i as f64).cos()], target);
2609 model.train_one(&sample);
2610 }
2611
2612 let pred = model.predict(&[5.0, 0.3]);
2613 assert!(
2614 pred.mu.is_finite(),
2615 "Huber-loss mu should be finite despite outliers"
2616 );
2617 assert!(pred.sigma > 0.0, "sigma should be positive");
2618 }
2619
2620 #[test]
2621 fn ensemble_gradient_stats_populated() {
2622 let config = SGBTConfig::builder()
2623 .n_steps(10)
2624 .learning_rate(0.1)
2625 .grace_period(20)
2626 .max_depth(4)
2627 .n_bins(16)
2628 .initial_target_count(10)
2629 .build()
2630 .unwrap();
2631 let mut model = DistributionalSGBT::new(config);
2632 for i in 0..200 {
2633 let x = i as f64 * 0.1;
2634 let sample = crate::Sample::new(vec![x, x.sin()], x.cos());
2635 model.train_one(&sample);
2636 }
2637
2638 let diag = model.diagnostics();
2639 assert!(
2640 diag.ensemble_grad_mean.is_finite(),
2641 "ensemble grad mean should be finite"
2642 );
2643 assert!(
2644 diag.ensemble_grad_std >= 0.0,
2645 "ensemble grad std should be non-negative"
2646 );
2647 assert!(
2648 diag.ensemble_grad_std.is_finite(),
2649 "ensemble grad std should be finite"
2650 );
2651 }
2652
2653 #[test]
2654 fn huber_k_validation() {
2655 let result = SGBTConfig::builder()
2656 .n_steps(5)
2657 .learning_rate(0.01)
2658 .huber_k(-1.0)
2659 .build();
2660 assert!(result.is_err(), "negative huber_k should fail validation");
2661 }
2662
2663 #[test]
2664 fn max_leaf_output_validation() {
2665 let result = SGBTConfig::builder()
2666 .n_steps(5)
2667 .learning_rate(0.01)
2668 .max_leaf_output(-1.0)
2669 .build();
2670 assert!(
2671 result.is_err(),
2672 "negative max_leaf_output should fail validation"
2673 );
2674 }
2675
2676 #[test]
2677 fn predict_sibling_interpolated_varies_with_features() {
2678 let config = SGBTConfig::builder()
2679 .n_steps(10)
2680 .learning_rate(0.1)
2681 .grace_period(10)
2682 .max_depth(6)
2683 .delta(0.1)
2684 .initial_target_count(10)
2685 .build()
2686 .unwrap();
2687 let mut model = DistributionalSGBT::new(config);
2688
2689 for i in 0..2000 {
2690 let x = (i as f64) * 0.01;
2691 let y = x.sin() * x + 0.5 * (x * 2.0).cos();
2692 let sample = crate::Sample::new(vec![x, x * 0.3], y);
2693 model.train_one(&sample);
2694 }
2695
2696 let pred = model.predict_sibling_interpolated(&[5.0, 1.5]);
2698 assert!(
2699 pred.mu.is_finite(),
2700 "sibling interpolated mu should be finite"
2701 );
2702 assert!(
2703 pred.sigma > 0.0,
2704 "sibling interpolated sigma should be positive"
2705 );
2706
2707 let bws = model.auto_bandwidths();
2709 if bws.iter().any(|&b| b.is_finite()) {
2710 let hard_preds: Vec<f64> = (0..200)
2711 .map(|i| {
2712 let x = i as f64 * 0.1;
2713 model.predict(&[x, x * 0.3]).mu
2714 })
2715 .collect();
2716 let hard_changes = hard_preds
2717 .windows(2)
2718 .filter(|w| (w[0] - w[1]).abs() > f64::EPSILON)
2719 .count();
2720
2721 let preds: Vec<f64> = (0..200)
2722 .map(|i| {
2723 let x = i as f64 * 0.1;
2724 model.predict_sibling_interpolated(&[x, x * 0.3]).mu
2725 })
2726 .collect();
2727
2728 let sibling_changes = preds
2729 .windows(2)
2730 .filter(|w| (w[0] - w[1]).abs() > f64::EPSILON)
2731 .count();
2732 assert!(
2733 sibling_changes >= hard_changes,
2734 "sibling should produce >= hard changes: sibling={}, hard={}",
2735 sibling_changes,
2736 hard_changes
2737 );
2738 }
2739 }
2740}