1use alloc::vec;
31use alloc::vec::Vec;
32
33use crate::ensemble::config::{SGBTConfig, ScaleMode};
34use crate::ensemble::step::BoostingStep;
35use crate::sample::{Observation, SampleRef};
36use crate::tree::builder::TreeConfig;
37
/// Snapshot of the ensemble serialized for fast packed inference.
///
/// `predict` deserializes `bytes` via `EnsembleView::from_bytes` and adds
/// `base` to the packed output.
// The previous hand-written `Clone` impl cloned every field exactly as the
// derive would; replaced with `#[derive(Clone)]`.
#[derive(Clone)]
struct PackedInferenceCache {
    /// Serialized ensemble consumed by `EnsembleView::from_bytes`.
    bytes: Vec<u8>,
    /// Intercept added on top of the packed prediction.
    base: f64,
    /// Feature count at pack time — not read in the visible code;
    /// presumably used by the (elsewhere-defined) cache builder. TODO confirm.
    n_features: usize,
}
57
/// A Gaussian predictive distribution `N(mu, sigma^2)` for one sample.
#[derive(Debug, Clone, Copy)]
pub struct GaussianPrediction {
    /// Predicted mean.
    pub mu: f64,
    /// Predicted standard deviation.
    pub sigma: f64,
    /// Natural log of `sigma`.
    pub log_sigma: f64,
    /// Between-tree dispersion of the location contributions
    /// (see `compute_honest_sigma`).
    pub honest_sigma: f64,
}

impl GaussianPrediction {
    /// Lower edge of the `mu ± z·sigma` interval.
    #[inline]
    pub fn lower(&self, z: f64) -> f64 {
        self.mu - self.sigma * z
    }

    /// Upper edge of the `mu ± z·sigma` interval.
    #[inline]
    pub fn upper(&self, z: f64) -> f64 {
        self.mu + self.sigma * z
    }
}
93
/// Per-tree snapshot collected by [`DistributionalSGBT::diagnostics`].
#[derive(Debug, Clone)]
pub struct TreeDiagnostic {
    /// Number of leaves in the step's active tree.
    pub n_leaves: usize,
    /// Deepest leaf depth actually present in the tree.
    pub max_depth_reached: usize,
    /// Samples this step has seen (from `BoostingStep::n_samples_seen`).
    pub samples_seen: u64,
    /// `(min, max, mean, std)` over the leaf output values.
    pub leaf_weight_stats: (f64, f64, f64, f64),
    /// Feature indices whose recorded split gain is > 0.
    pub split_features: Vec<usize>,
    /// Sample count per leaf, in arena order.
    pub leaf_sample_counts: Vec<u64>,
    /// Value reported by the slot's `prediction_mean()`.
    pub prediction_mean: f64,
    /// Value reported by the slot's `prediction_std()`.
    pub prediction_std: f64,
}
118
/// Aggregated diagnostic snapshot of the whole ensemble.
#[derive(Debug, Clone)]
pub struct ModelDiagnostics {
    /// All tree diagnostics: location trees first, then scale trees.
    pub trees: Vec<TreeDiagnostic>,
    /// Diagnostics for the location (mean) chain only.
    pub location_trees: Vec<TreeDiagnostic>,
    /// Diagnostics for the scale (log-sigma) chain only.
    pub scale_trees: Vec<TreeDiagnostic>,
    /// Per-feature count of trees that split on that feature.
    pub feature_split_counts: Vec<usize>,
    /// Additive intercept of the location chain.
    pub location_base: f64,
    /// Additive intercept (log sigma) of the scale chain.
    pub scale_base: f64,
    /// `sqrt` of the EWMA of squared residuals.
    pub empirical_sigma: f64,
    /// Whether sigma comes from the empirical EWMA or the tree chain.
    pub scale_mode: ScaleMode,
    /// Number of scale trees that have grown past a single leaf.
    pub scale_trees_active: usize,
    /// Current per-feature smoothing bandwidths.
    pub auto_bandwidths: Vec<f64>,
    /// Running mean of location gradients (Welford accumulator).
    pub ensemble_grad_mean: f64,
    /// Running std of location gradients (Welford accumulator).
    pub ensemble_grad_std: f64,
}
151
/// A prediction broken into base terms plus one contribution per boosting
/// step, so individual trees' influence can be attributed.
#[derive(Debug, Clone)]
pub struct DecomposedPrediction {
    /// Intercept of the location chain.
    pub location_base: f64,
    /// Intercept (log sigma) of the scale chain; in `Empirical` mode this is
    /// `ln(empirical sigma)` instead.
    pub scale_base: f64,
    /// Learning-rate-scaled output of each location step, in step order.
    pub location_contributions: Vec<f64>,
    /// Learning-rate-scaled output of each scale step (all zeros in
    /// `Empirical` mode).
    pub scale_contributions: Vec<f64>,
}
166
167impl DecomposedPrediction {
168 pub fn mu(&self) -> f64 {
170 self.location_base + self.location_contributions.iter().sum::<f64>()
171 }
172
173 pub fn log_sigma(&self) -> f64 {
175 self.scale_base + self.scale_contributions.iter().sum::<f64>()
176 }
177
178 pub fn sigma(&self) -> f64 {
180 crate::math::exp(self.log_sigma()).max(1e-8)
181 }
182}
183
184pub struct DistributionalSGBT {
210 config: SGBTConfig,
212 location_steps: Vec<BoostingStep>,
214 scale_steps: Vec<BoostingStep>,
216 location_base: f64,
218 scale_base: f64,
220 base_initialized: bool,
222 initial_targets: Vec<f64>,
224 initial_target_count: usize,
226 samples_seen: u64,
228 rng_state: u64,
230 uncertainty_modulated_lr: bool,
232 rolling_sigma_mean: f64,
237 scale_mode: ScaleMode,
239 ewma_sq_err: f64,
244 empirical_sigma_alpha: f64,
246 prev_sigma: f64,
248 sigma_velocity: f64,
251 auto_bandwidths: Vec<f64>,
253 last_replacement_sum: u64,
255 ensemble_grad_mean: f64,
257 ensemble_grad_m2: f64,
259 ensemble_grad_count: u64,
261 rolling_honest_sigma_mean: f64,
267 packed_cache: Option<PackedInferenceCache>,
269 samples_since_refresh: u64,
271 packed_refresh_interval: u64,
273}
274
275impl Clone for DistributionalSGBT {
276 fn clone(&self) -> Self {
277 Self {
278 config: self.config.clone(),
279 location_steps: self.location_steps.clone(),
280 scale_steps: self.scale_steps.clone(),
281 location_base: self.location_base,
282 scale_base: self.scale_base,
283 base_initialized: self.base_initialized,
284 initial_targets: self.initial_targets.clone(),
285 initial_target_count: self.initial_target_count,
286 samples_seen: self.samples_seen,
287 rng_state: self.rng_state,
288 uncertainty_modulated_lr: self.uncertainty_modulated_lr,
289 rolling_sigma_mean: self.rolling_sigma_mean,
290 scale_mode: self.scale_mode,
291 ewma_sq_err: self.ewma_sq_err,
292 empirical_sigma_alpha: self.empirical_sigma_alpha,
293 prev_sigma: self.prev_sigma,
294 sigma_velocity: self.sigma_velocity,
295 auto_bandwidths: self.auto_bandwidths.clone(),
296 last_replacement_sum: self.last_replacement_sum,
297 ensemble_grad_mean: self.ensemble_grad_mean,
298 ensemble_grad_m2: self.ensemble_grad_m2,
299 ensemble_grad_count: self.ensemble_grad_count,
300 rolling_honest_sigma_mean: self.rolling_honest_sigma_mean,
301 packed_cache: self.packed_cache.clone(),
302 samples_since_refresh: self.samples_since_refresh,
303 packed_refresh_interval: self.packed_refresh_interval,
304 }
305 }
306}
307
308impl core::fmt::Debug for DistributionalSGBT {
309 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
310 let mut s = f.debug_struct("DistributionalSGBT");
311 s.field("n_steps", &self.location_steps.len())
312 .field("samples_seen", &self.samples_seen)
313 .field("location_base", &self.location_base)
314 .field("scale_mode", &self.scale_mode)
315 .field("base_initialized", &self.base_initialized);
316 match self.scale_mode {
317 ScaleMode::Empirical => {
318 s.field("empirical_sigma", &crate::math::sqrt(self.ewma_sq_err));
319 }
320 ScaleMode::TreeChain => {
321 s.field("scale_base", &self.scale_base);
322 }
323 }
324 if self.uncertainty_modulated_lr {
325 s.field("rolling_sigma_mean", &self.rolling_sigma_mean);
326 }
327 s.finish()
328 }
329}
330
331impl DistributionalSGBT {
332 pub fn new(config: SGBTConfig) -> Self {
338 let leaf_decay_alpha = config
339 .leaf_half_life
340 .map(|hl| crate::math::exp(-(crate::math::ln(2.0_f64)) / hl as f64));
341
342 let tree_config = TreeConfig::new()
343 .max_depth(config.max_depth)
344 .n_bins(config.n_bins)
345 .lambda(config.lambda)
346 .gamma(config.gamma)
347 .grace_period(config.grace_period)
348 .delta(config.delta)
349 .feature_subsample_rate(config.feature_subsample_rate)
350 .leaf_decay_alpha_opt(leaf_decay_alpha)
351 .split_reeval_interval_opt(config.split_reeval_interval)
352 .feature_types_opt(config.feature_types.clone())
353 .gradient_clip_sigma_opt(config.gradient_clip_sigma)
354 .monotone_constraints_opt(config.monotone_constraints.clone())
355 .max_leaf_output_opt(config.max_leaf_output)
356 .adaptive_depth_opt(config.adaptive_depth)
357 .min_hessian_sum_opt(config.min_hessian_sum)
358 .leaf_model_type(config.leaf_model_type.clone());
359
360 let max_tree_samples = config.max_tree_samples;
361
362 let shadow_warmup = config.shadow_warmup.unwrap_or(0);
364 let location_steps: Vec<BoostingStep> = (0..config.n_steps)
365 .map(|i| {
366 let mut tc = tree_config.clone();
367 tc.seed = config.seed ^ (i as u64);
368 let detector = config.drift_detector.create();
369 if shadow_warmup > 0 {
370 BoostingStep::new_with_graduated(tc, detector, max_tree_samples, shadow_warmup)
371 } else {
372 BoostingStep::new_with_max_samples(tc, detector, max_tree_samples)
373 }
374 })
375 .collect();
376
377 let scale_steps: Vec<BoostingStep> = (0..config.n_steps)
379 .map(|i| {
380 let mut tc = tree_config.clone();
381 tc.seed = config.seed ^ (i as u64) ^ 0x0005_CA1E_0000_0000;
382 let detector = config.drift_detector.create();
383 if shadow_warmup > 0 {
384 BoostingStep::new_with_graduated(tc, detector, max_tree_samples, shadow_warmup)
385 } else {
386 BoostingStep::new_with_max_samples(tc, detector, max_tree_samples)
387 }
388 })
389 .collect();
390
391 let seed = config.seed;
392 let initial_target_count = config.initial_target_count;
393 let uncertainty_modulated_lr = config.uncertainty_modulated_lr;
394 let scale_mode = config.scale_mode;
395 let empirical_sigma_alpha = config.empirical_sigma_alpha;
396 let packed_refresh_interval = config.packed_refresh_interval;
397 Self {
398 config,
399 location_steps,
400 scale_steps,
401 location_base: 0.0,
402 scale_base: 0.0,
403 base_initialized: false,
404 initial_targets: Vec::new(),
405 initial_target_count,
406 samples_seen: 0,
407 rng_state: seed,
408 uncertainty_modulated_lr,
409 rolling_sigma_mean: 1.0, scale_mode,
411 ewma_sq_err: 1.0, empirical_sigma_alpha,
413 prev_sigma: 0.0,
414 sigma_velocity: 0.0,
415 auto_bandwidths: Vec::new(),
416 last_replacement_sum: 0,
417 ensemble_grad_mean: 0.0,
418 ensemble_grad_m2: 0.0,
419 ensemble_grad_count: 0,
420 rolling_honest_sigma_mean: 0.0,
421 packed_cache: None,
422 samples_since_refresh: 0,
423 packed_refresh_interval,
424 }
425 }
426
427 fn compute_honest_sigma(&self, features: &[f64]) -> f64 {
433 let n = self.location_steps.len();
434 if n <= 1 {
435 return 0.0;
436 }
437 let lr = self.config.learning_rate;
438 let mut sum = 0.0_f64;
439 let mut sq_sum = 0.0_f64;
440 for step in &self.location_steps {
441 let c = lr * step.predict(features);
442 sum += c;
443 sq_sum += c * c;
444 }
445 let nf = n as f64;
446 let mean_c = sum / nf;
447 let var = (sq_sum / nf) - (mean_c * mean_c);
448 let var_corrected = var * nf / (nf - 1.0);
449 crate::math::sqrt(var_corrected.max(0.0))
450 }
451
    /// Consume a single observation.
    ///
    /// The first `initial_target_count` targets are only buffered; once the
    /// buffer is full, the location base is set to their mean and the scale
    /// base to `ln(std)`, and tree training starts with the *next* sample
    /// (warmup samples never train trees). After the warmup, training is
    /// dispatched by `ScaleMode` and bandwidths are refreshed if any tree
    /// was replaced.
    pub fn train_one(&mut self, sample: &impl Observation) {
        self.samples_seen += 1;
        let target = sample.target();
        let features = sample.features();

        if !self.base_initialized {
            self.initial_targets.push(target);
            if self.initial_targets.len() >= self.initial_target_count {
                let sum: f64 = self.initial_targets.iter().sum();
                let mean = sum / self.initial_targets.len() as f64;
                self.location_base = mean;

                // Population variance of the warmup targets seeds the scale.
                let var: f64 = self
                    .initial_targets
                    .iter()
                    .map(|&y| (y - mean) * (y - mean))
                    .sum::<f64>()
                    / self.initial_targets.len() as f64;
                let initial_std = crate::math::sqrt(var).max(1e-6);
                self.scale_base = crate::math::ln(initial_std);

                // Seed the running sigma statistics consistently with the
                // warmup estimate so the modulation ratio starts near 1.
                self.rolling_sigma_mean = initial_std;
                self.ewma_sq_err = var.max(1e-12);

                self.prev_sigma = initial_std;
                self.sigma_velocity = 0.0;

                self.base_initialized = true;
                self.initial_targets.clear();
                self.initial_targets.shrink_to_fit();
            }
            return;
        }

        match self.scale_mode {
            ScaleMode::Empirical => self.train_one_empirical(target, features),
            ScaleMode::TreeChain => self.train_one_tree_chain(target, features),
        }

        self.refresh_bandwidths();
    }
500
    /// One training update in `ScaleMode::Empirical`.
    ///
    /// Pass 1 predicts mu with the *current* trees to update the EWMA of
    /// squared residuals and the sigma-ratio (a PD-style factor: sigma level
    /// plus a velocity term). Pass 2 then trains each step stagewise,
    /// recomputing the running mean from the freshly-trained trees. The
    /// statement order here is load-bearing: every EWMA update reads the
    /// previous state before overwriting it.
    fn train_one_empirical(&mut self, target: f64, features: &[f64]) {
        // Pass 1: prediction with the trees as they are now.
        let mut mu = self.location_base;
        for s in 0..self.location_steps.len() {
            mu += self.config.learning_rate * self.location_steps[s].predict(features);
        }

        // Track a slow EWMA of the between-tree dispersion.
        let honest_sigma = self.compute_honest_sigma(features);
        const HONEST_SIGMA_ALPHA: f64 = 0.001;
        self.rolling_honest_sigma_mean = (1.0 - HONEST_SIGMA_ALPHA)
            * self.rolling_honest_sigma_mean
            + HONEST_SIGMA_ALPHA * honest_sigma;

        // Empirical sigma from the EWMA of squared residuals.
        let err = target - mu;
        let alpha = self.empirical_sigma_alpha;
        self.ewma_sq_err = (1.0 - alpha) * self.ewma_sq_err + alpha * err * err;
        let empirical_sigma = crate::math::sqrt(self.ewma_sq_err).max(1e-8);

        let sigma_ratio = if self.uncertainty_modulated_lr {
            // Velocity (derivative) term of the PD controller.
            let d_sigma = empirical_sigma - self.prev_sigma;
            self.prev_sigma = empirical_sigma;

            self.sigma_velocity = (1.0 - alpha) * self.sigma_velocity + alpha * d_sigma;

            // Adaptive derivative gain, normalized by the rolling mean.
            let k_d = if self.rolling_sigma_mean > 1e-12 {
                crate::math::abs(self.sigma_velocity) / self.rolling_sigma_mean
            } else {
                0.0
            };

            let pd_sigma = empirical_sigma + k_d * self.sigma_velocity;
            let ratio = (pd_sigma / self.rolling_sigma_mean).clamp(0.1, 10.0);

            // Rolling mean updated *after* the ratio is taken.
            const SIGMA_EWMA_ALPHA: f64 = 0.001;
            self.rolling_sigma_mean = (1.0 - SIGMA_EWMA_ALPHA) * self.rolling_sigma_mean
                + SIGMA_EWMA_ALPHA * empirical_sigma;

            ratio
        } else {
            1.0
        };

        let base_lr = self.config.learning_rate;

        // Pass 2: stagewise training. Each step sees the gradient at the
        // partial sum produced by the already-trained earlier steps.
        let mut mu_accum = self.location_base;
        for s in 0..self.location_steps.len() {
            let (g_mu, h_mu) = self.location_gradient(mu_accum, target);
            self.update_ensemble_grad_stats(g_mu);
            let train_count = self.config.variant.train_count(h_mu, &mut self.rng_state);
            let loc_pred =
                self.location_steps[s].train_and_predict(features, g_mu, h_mu, train_count);
            mu_accum += (base_lr * sigma_ratio) * loc_pred;
        }

        self.maybe_refresh_packed_cache();
    }
569
    /// One training update in `ScaleMode::TreeChain`: location and scale
    /// trees are trained jointly, stagewise.
    ///
    /// At each step, `g_sigma = 1 - z^2` and `h_sigma = 2 z^2` match the
    /// gradient/Hessian of the Gaussian NLL with respect to log sigma
    /// (NLL = log(sigma) + z^2/2). Statement order is load-bearing: every
    /// EWMA update reads the previous state before overwriting it.
    fn train_one_tree_chain(&mut self, target: f64, features: &[f64]) {
        let mut mu = self.location_base;
        let mut log_sigma = self.scale_base;

        // Track a slow EWMA of the between-tree dispersion.
        let honest_sigma = self.compute_honest_sigma(features);
        const HONEST_SIGMA_ALPHA: f64 = 0.001;
        self.rolling_honest_sigma_mean = (1.0 - HONEST_SIGMA_ALPHA)
            * self.rolling_honest_sigma_mean
            + HONEST_SIGMA_ALPHA * honest_sigma;

        let sigma_ratio = if self.uncertainty_modulated_lr {
            // NOTE: this is sigma at the *base* only — the per-step tree
            // contributions have not been added yet at this point.
            let current_sigma = crate::math::exp(log_sigma).max(1e-8);

            let d_sigma = current_sigma - self.prev_sigma;
            self.prev_sigma = current_sigma;

            let alpha = self.empirical_sigma_alpha;
            self.sigma_velocity = (1.0 - alpha) * self.sigma_velocity + alpha * d_sigma;

            // Adaptive derivative gain, normalized by the rolling mean.
            let k_d = if self.rolling_sigma_mean > 1e-12 {
                crate::math::abs(self.sigma_velocity) / self.rolling_sigma_mean
            } else {
                0.0
            };

            let pd_sigma = current_sigma + k_d * self.sigma_velocity;
            let ratio = (pd_sigma / self.rolling_sigma_mean).clamp(0.1, 10.0);

            // Rolling mean updated *after* the ratio is taken.
            const SIGMA_EWMA_ALPHA: f64 = 0.001;
            self.rolling_sigma_mean = (1.0 - SIGMA_EWMA_ALPHA) * self.rolling_sigma_mean
                + SIGMA_EWMA_ALPHA * current_sigma;

            ratio
        } else {
            1.0
        };

        let base_lr = self.config.learning_rate;

        // Stagewise joint training: each step's gradients are computed at the
        // partial (mu, log_sigma) produced by the earlier, freshly-trained steps.
        for s in 0..self.location_steps.len() {
            let sigma = crate::math::exp(log_sigma).max(1e-8);
            let z = (target - mu) / sigma;

            let (g_mu, h_mu) = self.location_gradient(mu, target);
            self.update_ensemble_grad_stats(g_mu);

            // Gaussian NLL w.r.t. log sigma; Hessian clamped for stability.
            let g_sigma = 1.0 - z * z;
            let h_sigma = (2.0 * z * z).clamp(0.01, 100.0);

            let train_count = self.config.variant.train_count(h_mu, &mut self.rng_state);

            let loc_pred =
                self.location_steps[s].train_and_predict(features, g_mu, h_mu, train_count);
            mu += (base_lr * sigma_ratio) * loc_pred;

            let scale_pred =
                self.scale_steps[s].train_and_predict(features, g_sigma, h_sigma, train_count);
            log_sigma += base_lr * scale_pred;
        }

        // Keep the empirical sigma estimate warm even in tree-chain mode.
        let err = target - mu;
        let alpha = self.empirical_sigma_alpha;
        self.ewma_sq_err = (1.0 - alpha) * self.ewma_sq_err + alpha * err * err;

        self.maybe_refresh_packed_cache();
    }
651
652 pub fn predict(&self, features: &[f64]) -> GaussianPrediction {
661 let mu = if let Some(ref cache) = self.packed_cache {
663 let features_f32: Vec<f32> = features.iter().map(|&v| v as f32).collect();
664 match crate::EnsembleView::from_bytes(&cache.bytes) {
665 Ok(view) => {
666 let packed_mu = cache.base + view.predict(&features_f32) as f64;
667 if packed_mu.is_finite() {
668 packed_mu
669 } else {
670 self.predict_full_trees(features)
671 }
672 }
673 Err(_) => self.predict_full_trees(features),
674 }
675 } else {
676 self.predict_full_trees(features)
677 };
678
679 let (sigma, log_sigma) = match self.scale_mode {
680 ScaleMode::Empirical => {
681 let s = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
682 (s, crate::math::ln(s))
683 }
684 ScaleMode::TreeChain => {
685 let mut ls = self.scale_base;
686 if self.auto_bandwidths.is_empty() {
687 for s in 0..self.scale_steps.len() {
688 ls += self.config.learning_rate * self.scale_steps[s].predict(features);
689 }
690 } else {
691 for s in 0..self.scale_steps.len() {
692 ls += self.config.learning_rate
693 * self.scale_steps[s]
694 .predict_smooth_auto(features, &self.auto_bandwidths);
695 }
696 }
697 (crate::math::exp(ls).max(1e-8), ls)
698 }
699 };
700
701 let honest_sigma = self.compute_honest_sigma(features);
702
703 GaussianPrediction {
704 mu,
705 sigma,
706 log_sigma,
707 honest_sigma,
708 }
709 }
710
711 fn predict_full_trees(&self, features: &[f64]) -> f64 {
713 let mut mu = self.location_base;
714 if self.auto_bandwidths.is_empty() {
715 for s in 0..self.location_steps.len() {
716 mu += self.config.learning_rate * self.location_steps[s].predict(features);
717 }
718 } else {
719 for s in 0..self.location_steps.len() {
720 mu += self.config.learning_rate
721 * self.location_steps[s].predict_smooth_auto(features, &self.auto_bandwidths);
722 }
723 }
724 mu
725 }
726
727 pub fn predict_smooth(&self, features: &[f64], bandwidth: f64) -> GaussianPrediction {
737 let mut mu = self.location_base;
738 for s in 0..self.location_steps.len() {
739 mu += self.config.learning_rate
740 * self.location_steps[s].predict_smooth(features, bandwidth);
741 }
742
743 let (sigma, log_sigma) = match self.scale_mode {
744 ScaleMode::Empirical => {
745 let s = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
746 (s, crate::math::ln(s))
747 }
748 ScaleMode::TreeChain => {
749 let mut ls = self.scale_base;
750 for s in 0..self.scale_steps.len() {
751 ls += self.config.learning_rate
752 * self.scale_steps[s].predict_smooth(features, bandwidth);
753 }
754 (crate::math::exp(ls).max(1e-8), ls)
755 }
756 };
757
758 let honest_sigma = self.compute_honest_sigma(features);
759
760 GaussianPrediction {
761 mu,
762 sigma,
763 log_sigma,
764 honest_sigma,
765 }
766 }
767
768 pub fn predict_interpolated(&self, features: &[f64]) -> GaussianPrediction {
773 let mut mu = self.location_base;
774 for s in 0..self.location_steps.len() {
775 mu += self.config.learning_rate * self.location_steps[s].predict_interpolated(features);
776 }
777
778 let (sigma, log_sigma) = match self.scale_mode {
779 ScaleMode::Empirical => {
780 let s = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
781 (s, crate::math::ln(s))
782 }
783 ScaleMode::TreeChain => {
784 let mut ls = self.scale_base;
785 for s in 0..self.scale_steps.len() {
786 ls += self.config.learning_rate
787 * self.scale_steps[s].predict_interpolated(features);
788 }
789 (crate::math::exp(ls).max(1e-8), ls)
790 }
791 };
792
793 let honest_sigma = self.compute_honest_sigma(features);
794
795 GaussianPrediction {
796 mu,
797 sigma,
798 log_sigma,
799 honest_sigma,
800 }
801 }
802
803 pub fn predict_sibling_interpolated(&self, features: &[f64]) -> GaussianPrediction {
809 let mut mu = self.location_base;
810 for s in 0..self.location_steps.len() {
811 mu += self.config.learning_rate
812 * self.location_steps[s]
813 .predict_sibling_interpolated(features, &self.auto_bandwidths);
814 }
815
816 let (sigma, log_sigma) = match self.scale_mode {
817 ScaleMode::Empirical => {
818 let s = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
819 (s, crate::math::ln(s))
820 }
821 ScaleMode::TreeChain => {
822 let mut ls = self.scale_base;
823 for s in 0..self.scale_steps.len() {
824 ls += self.config.learning_rate
825 * self.scale_steps[s]
826 .predict_sibling_interpolated(features, &self.auto_bandwidths);
827 }
828 (crate::math::exp(ls).max(1e-8), ls)
829 }
830 };
831
832 let honest_sigma = self.compute_honest_sigma(features);
833
834 GaussianPrediction {
835 mu,
836 sigma,
837 log_sigma,
838 honest_sigma,
839 }
840 }
841
842 pub fn predict_graduated(&self, features: &[f64]) -> GaussianPrediction {
847 let mut mu = self.location_base;
848 for s in 0..self.location_steps.len() {
849 mu += self.config.learning_rate * self.location_steps[s].predict_graduated(features);
850 }
851
852 let (sigma, log_sigma) = match self.scale_mode {
853 ScaleMode::Empirical => {
854 let s = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
855 (s, crate::math::ln(s))
856 }
857 ScaleMode::TreeChain => {
858 let mut ls = self.scale_base;
859 for s in 0..self.scale_steps.len() {
860 ls +=
861 self.config.learning_rate * self.scale_steps[s].predict_graduated(features);
862 }
863 (crate::math::exp(ls).max(1e-8), ls)
864 }
865 };
866
867 let honest_sigma = self.compute_honest_sigma(features);
868
869 GaussianPrediction {
870 mu,
871 sigma,
872 log_sigma,
873 honest_sigma,
874 }
875 }
876
877 pub fn predict_graduated_sibling_interpolated(&self, features: &[f64]) -> GaussianPrediction {
879 let mut mu = self.location_base;
880 for s in 0..self.location_steps.len() {
881 mu += self.config.learning_rate
882 * self.location_steps[s]
883 .predict_graduated_sibling_interpolated(features, &self.auto_bandwidths);
884 }
885
886 let (sigma, log_sigma) = match self.scale_mode {
887 ScaleMode::Empirical => {
888 let s = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
889 (s, crate::math::ln(s))
890 }
891 ScaleMode::TreeChain => {
892 let mut ls = self.scale_base;
893 for s in 0..self.scale_steps.len() {
894 ls += self.config.learning_rate
895 * self.scale_steps[s].predict_graduated_sibling_interpolated(
896 features,
897 &self.auto_bandwidths,
898 );
899 }
900 (crate::math::exp(ls).max(1e-8), ls)
901 }
902 };
903
904 let honest_sigma = self.compute_honest_sigma(features);
905
906 GaussianPrediction {
907 mu,
908 sigma,
909 log_sigma,
910 honest_sigma,
911 }
912 }
913
914 pub fn predict_soft_routed(&self, features: &[f64]) -> GaussianPrediction {
919 let mut mu = self.location_base;
921 for step in &self.location_steps {
922 mu += self.config.learning_rate * step.predict_soft_routed(features);
923 }
924
925 let (sigma, log_sigma) = match self.scale_mode {
927 ScaleMode::Empirical => {
928 let s = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
929 (s, crate::math::ln(s))
930 }
931 ScaleMode::TreeChain => {
932 let mut ls = self.scale_base;
933 for step in &self.scale_steps {
934 ls += self.config.learning_rate * step.predict_soft_routed(features);
935 }
936 (crate::math::exp(ls).max(1e-8), ls)
937 }
938 };
939
940 let honest_sigma = self.compute_honest_sigma(features);
941
942 GaussianPrediction {
943 mu,
944 sigma,
945 log_sigma,
946 honest_sigma,
947 }
948 }
949
950 pub fn predict_distributional(&self, features: &[f64]) -> (f64, f64, f64) {
959 let pred = self.predict(features);
960 let sigma_ratio = if self.uncertainty_modulated_lr {
961 (pred.sigma / self.rolling_sigma_mean).clamp(0.1, 10.0)
962 } else {
963 1.0
964 };
965 (pred.mu, pred.sigma, sigma_ratio)
966 }
967
    /// Current empirical sigma: `sqrt` of the EWMA of squared residuals.
    #[inline]
    pub fn empirical_sigma(&self) -> f64 {
        crate::math::sqrt(self.ewma_sq_err)
    }
975
    /// The configured scale-estimation mode (empirical EWMA vs tree chain).
    #[inline]
    pub fn scale_mode(&self) -> ScaleMode {
        self.scale_mode
    }
981
    /// EWMA of sigma's step-to-step change; positive while uncertainty is
    /// rising. Only advanced when uncertainty-modulated LR is enabled.
    #[inline]
    pub fn sigma_velocity(&self) -> f64 {
        self.sigma_velocity
    }
991
    /// Shortcut for `self.predict(features).mu`.
    #[inline]
    pub fn predict_mu(&self, features: &[f64]) -> f64 {
        self.predict(features).mu
    }
997
    /// Shortcut for `self.predict(features).sigma`.
    #[inline]
    pub fn predict_sigma(&self, features: &[f64]) -> f64 {
        self.predict(features).sigma
    }
1003
    /// `(lower, upper)` interval `mu ± confidence·sigma`.
    ///
    /// NOTE(review): `confidence` is a z-score multiplier, not a coverage
    /// probability.
    pub fn predict_interval(&self, features: &[f64], confidence: f64) -> (f64, f64) {
        let pred = self.predict(features);
        (pred.lower(confidence), pred.upper(confidence))
    }
1014
    /// Predict a distribution for every feature row, preserving order.
    pub fn predict_batch(&self, feature_matrix: &[Vec<f64>]) -> Vec<GaussianPrediction> {
        feature_matrix.iter().map(|f| self.predict(f)).collect()
    }
1019
    /// Train sequentially on every sample in `samples`.
    pub fn train_batch<O: Observation>(&mut self, samples: &[O]) {
        for sample in samples {
            self.train_one(sample);
        }
    }
1026
1027 pub fn train_batch_with_callback<O: Observation, F: FnMut(usize)>(
1029 &mut self,
1030 samples: &[O],
1031 interval: usize,
1032 mut callback: F,
1033 ) {
1034 let interval = interval.max(1);
1035 for (i, sample) in samples.iter().enumerate() {
1036 self.train_one(sample);
1037 if (i + 1) % interval == 0 {
1038 callback(i + 1);
1039 }
1040 }
1041 let total = samples.len();
1042 if total % interval != 0 {
1043 callback(total);
1044 }
1045 }
1046
1047 #[inline]
1052 fn location_gradient(&self, mu: f64, target: f64) -> (f64, f64) {
1053 if let Some(k) = self.config.huber_k {
1054 let delta = k * crate::math::sqrt(self.ewma_sq_err).max(1e-8);
1055 let residual = mu - target;
1056 if crate::math::abs(residual) <= delta {
1057 (residual, 1.0)
1058 } else {
1059 (delta * residual.signum(), 1e-6)
1060 }
1061 } else {
1062 (mu - target, 1.0)
1063 }
1064 }
1065
    /// Fold one location gradient into the running mean/M2 using Welford's
    /// online algorithm (numerically stable single pass).
    #[inline]
    fn update_ensemble_grad_stats(&mut self, gradient: f64) {
        self.ensemble_grad_count += 1;
        let delta = gradient - self.ensemble_grad_mean;
        self.ensemble_grad_mean += delta / self.ensemble_grad_count as f64;
        // Second deviation uses the *updated* mean — required by Welford.
        let delta2 = gradient - self.ensemble_grad_mean;
        self.ensemble_grad_m2 += delta * delta2;
    }
1075
1076 pub fn ensemble_grad_std(&self) -> f64 {
1078 if self.ensemble_grad_count < 2 {
1079 return 0.0;
1080 }
1081 crate::math::fmax(
1082 crate::math::sqrt(self.ensemble_grad_m2 / (self.ensemble_grad_count - 1) as f64),
1083 0.0,
1084 )
1085 }
1086
    /// Running mean of all location gradients seen so far.
    pub fn ensemble_grad_mean(&self) -> f64 {
        self.ensemble_grad_mean
    }
1091
1092 fn maybe_refresh_packed_cache(&mut self) {
1094 if self.packed_refresh_interval > 0 {
1095 self.samples_since_refresh += 1;
1096 if self.samples_since_refresh >= self.packed_refresh_interval {
1097 self.refresh_packed_cache();
1098 self.samples_since_refresh = 0;
1099 }
1100 }
1101 }
1102
    /// Rebuild `packed_cache` from the current trees.
    ///
    /// NOTE(review): the body is currently empty, so the cache is never
    /// (re)built here even though `maybe_refresh_packed_cache` and
    /// `enable_packed_cache` both call it — presumably implemented elsewhere
    /// or still pending; confirm before relying on packed inference.
    fn refresh_packed_cache(&mut self) {
    }
1114
1115 pub fn enable_packed_cache(&mut self, interval: u64) {
1120 self.packed_refresh_interval = interval;
1121 self.samples_since_refresh = 0;
1122 if interval > 0 && self.base_initialized {
1123 self.refresh_packed_cache();
1124 } else if interval == 0 {
1125 self.packed_cache = None;
1126 }
1127 }
1128
    /// Whether a packed inference cache is currently available.
    #[inline]
    pub fn has_packed_cache(&self) -> bool {
        self.packed_cache.is_some()
    }
1134
1135 fn refresh_bandwidths(&mut self) {
1137 let current_sum: u64 = self
1138 .location_steps
1139 .iter()
1140 .chain(self.scale_steps.iter())
1141 .map(|s| s.slot().replacements())
1142 .sum();
1143 if current_sum != self.last_replacement_sum || self.auto_bandwidths.is_empty() {
1144 self.auto_bandwidths = self.compute_auto_bandwidths();
1145 self.last_replacement_sum = current_sum;
1146 }
1147 }
1148
    /// Derive a smoothing bandwidth per feature from the split thresholds of
    /// every active tree: K times the median gap between distinct sorted
    /// thresholds. Features with too few thresholds (or degenerate spacing)
    /// get an infinite bandwidth. Returns an empty vec when no tree reports
    /// a feature count yet.
    fn compute_auto_bandwidths(&self) -> Vec<f64> {
        // Bandwidth = K * median threshold gap.
        const K: f64 = 2.0;

        let n_features = self
            .location_steps
            .iter()
            .chain(self.scale_steps.iter())
            .filter_map(|s| s.slot().active_tree().n_features())
            .max()
            .unwrap_or(0);

        if n_features == 0 {
            return Vec::new();
        }

        // Pool thresholds per feature across every active tree.
        let mut all_thresholds: Vec<Vec<f64>> = vec![Vec::new(); n_features];

        for step in self.location_steps.iter().chain(self.scale_steps.iter()) {
            let tree_thresholds = step
                .slot()
                .active_tree()
                .collect_split_thresholds_per_feature();
            for (i, ts) in tree_thresholds.into_iter().enumerate() {
                if i < n_features {
                    all_thresholds[i].extend(ts);
                }
            }
        }

        let n_bins = self.config.n_bins as f64;

        all_thresholds
            .iter()
            .map(|ts| {
                if ts.is_empty() {
                    return f64::INFINITY;
                }

                // Sort and collapse near-duplicate thresholds.
                let mut sorted = ts.clone();
                sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
                sorted.dedup_by(|a, b| crate::math::abs(*a - *b) < 1e-15);

                if sorted.len() < 2 {
                    return f64::INFINITY;
                }

                let mut gaps: Vec<f64> = sorted.windows(2).map(|w| w[1] - w[0]).collect();

                // With only two distinct thresholds a median gap is not
                // meaningful; fall back to range / n_bins.
                if sorted.len() < 3 {
                    let range = sorted.last().unwrap() - sorted.first().unwrap();
                    if range < 1e-15 {
                        return f64::INFINITY;
                    }
                    return (range / n_bins) * K;
                }

                gaps.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
                let median_gap = if gaps.len() % 2 == 0 {
                    (gaps[gaps.len() / 2 - 1] + gaps[gaps.len() / 2]) / 2.0
                } else {
                    gaps[gaps.len() / 2]
                };

                if median_gap < 1e-15 {
                    f64::INFINITY
                } else {
                    median_gap * K
                }
            })
            .collect()
    }
1221
    /// Current per-feature smoothing bandwidths (empty until computed).
    pub fn auto_bandwidths(&self) -> &[f64] {
        &self.auto_bandwidths
    }
1226
1227 pub fn reset(&mut self) {
1229 for step in &mut self.location_steps {
1230 step.reset();
1231 }
1232 for step in &mut self.scale_steps {
1233 step.reset();
1234 }
1235 self.location_base = 0.0;
1236 self.scale_base = 0.0;
1237 self.base_initialized = false;
1238 self.initial_targets.clear();
1239 self.samples_seen = 0;
1240 self.rng_state = self.config.seed;
1241 self.rolling_sigma_mean = 1.0;
1242 self.ewma_sq_err = 1.0;
1243 self.prev_sigma = 0.0;
1244 self.sigma_velocity = 0.0;
1245 self.auto_bandwidths.clear();
1246 self.last_replacement_sum = 0;
1247 self.ensemble_grad_mean = 0.0;
1248 self.ensemble_grad_m2 = 0.0;
1249 self.ensemble_grad_count = 0;
1250 self.rolling_honest_sigma_mean = 0.0;
1251 self.packed_cache = None;
1252 self.samples_since_refresh = 0;
1253 }
1254
    /// Total observations passed to `train_one` (warmup samples included).
    #[inline]
    pub fn n_samples_seen(&self) -> u64 {
        self.samples_seen
    }
1260
    /// Number of boosting steps in the location chain (the scale chain is
    /// constructed with the same length).
    #[inline]
    pub fn n_steps(&self) -> usize {
        self.location_steps.len()
    }
1266
1267 pub fn n_trees(&self) -> usize {
1269 let loc = self.location_steps.len()
1270 + self
1271 .location_steps
1272 .iter()
1273 .filter(|s| s.has_alternate())
1274 .count();
1275 let scale = self.scale_steps.len()
1276 + self
1277 .scale_steps
1278 .iter()
1279 .filter(|s| s.has_alternate())
1280 .count();
1281 loc + scale
1282 }
1283
1284 pub fn total_leaves(&self) -> usize {
1286 let loc: usize = self.location_steps.iter().map(|s| s.n_leaves()).sum();
1287 let scale: usize = self.scale_steps.iter().map(|s| s.n_leaves()).sum();
1288 loc + scale
1289 }
1290
    /// Whether the warmup phase has completed and the bases are set.
    #[inline]
    pub fn is_initialized(&self) -> bool {
        self.base_initialized
    }
1296
    /// Read-only access to the training configuration.
    #[inline]
    pub fn config(&self) -> &SGBTConfig {
        &self.config
    }
1302
    /// Read-only access to the location boosting steps.
    pub fn location_steps(&self) -> &[BoostingStep] {
        &self.location_steps
    }
1307
    /// Additive intercept of the location chain (warmup target mean).
    #[inline]
    pub fn location_base(&self) -> f64 {
        self.location_base
    }
1313
    /// The configured base learning rate.
    #[inline]
    pub fn learning_rate(&self) -> f64 {
        self.config.learning_rate
    }
1319
    /// Slow EWMA of sigma used as the denominator of the learning-rate
    /// modulation ratio.
    #[inline]
    pub fn rolling_sigma_mean(&self) -> f64 {
        self.rolling_sigma_mean
    }
1327
    /// Whether the location learning rate is modulated by the sigma ratio.
    #[inline]
    pub fn is_uncertainty_modulated(&self) -> bool {
        self.uncertainty_modulated_lr
    }
1333
    /// Slow EWMA of the between-tree ("honest") sigma observed during
    /// training.
    #[inline]
    pub fn rolling_honest_sigma_mean(&self) -> f64 {
        self.rolling_honest_sigma_mean
    }
1341
    /// Collect a full diagnostic snapshot of the ensemble.
    ///
    /// Walks every step in both chains, summarizing each active tree's leaf
    /// structure and split usage, then assembles the ensemble-level view.
    /// `trees` contains the location diagnostics first, then the scale ones;
    /// `location_trees`/`scale_trees` are copies of those two halves.
    pub fn diagnostics(&self) -> ModelDiagnostics {
        let n = self.location_steps.len();
        let mut trees = Vec::with_capacity(2 * n);
        let mut feature_split_counts: Vec<usize> = Vec::new();

        // Shared collector for both chains; appends one TreeDiagnostic per
        // step and tallies split features into `feature_split_counts`.
        fn collect_tree_diags(
            steps: &[BoostingStep],
            trees: &mut Vec<TreeDiagnostic>,
            feature_split_counts: &mut Vec<usize>,
        ) {
            for step in steps {
                let slot = step.slot();
                let tree = slot.active_tree();
                let arena = tree.arena();

                // Leaf outputs, in arena order.
                let leaf_values: Vec<f64> = (0..arena.is_leaf.len())
                    .filter(|&i| arena.is_leaf[i])
                    .map(|i| arena.leaf_value[i])
                    .collect();

                let leaf_sample_counts: Vec<u64> = (0..arena.is_leaf.len())
                    .filter(|&i| arena.is_leaf[i])
                    .map(|i| arena.sample_count[i])
                    .collect();

                let max_depth_reached = (0..arena.is_leaf.len())
                    .filter(|&i| arena.is_leaf[i])
                    .map(|i| arena.depth[i] as usize)
                    .max()
                    .unwrap_or(0);

                // (min, max, mean, std) over leaf outputs; zeros for an
                // empty tree.
                let leaf_weight_stats = if leaf_values.is_empty() {
                    (0.0, 0.0, 0.0, 0.0)
                } else {
                    let min = leaf_values.iter().cloned().fold(f64::INFINITY, f64::min);
                    let max = leaf_values
                        .iter()
                        .cloned()
                        .fold(f64::NEG_INFINITY, f64::max);
                    let sum: f64 = leaf_values.iter().sum();
                    let mean = sum / leaf_values.len() as f64;
                    let var: f64 = leaf_values
                        .iter()
                        .map(|v| crate::math::powi(v - mean, 2))
                        .sum::<f64>()
                        / leaf_values.len() as f64;
                    (min, max, mean, crate::math::sqrt(var))
                };

                // Features with a positive recorded split gain.
                let gains = slot.split_gains();
                let split_features: Vec<usize> = gains
                    .iter()
                    .enumerate()
                    .filter(|(_, &g)| g > 0.0)
                    .map(|(i, _)| i)
                    .collect();

                if !gains.is_empty() {
                    // Size the tally from the first non-empty gains vector.
                    if feature_split_counts.is_empty() {
                        feature_split_counts.resize(gains.len(), 0);
                    }
                    for &fi in &split_features {
                        if fi < feature_split_counts.len() {
                            feature_split_counts[fi] += 1;
                        }
                    }
                }

                trees.push(TreeDiagnostic {
                    n_leaves: leaf_values.len(),
                    max_depth_reached,
                    samples_seen: step.n_samples_seen(),
                    leaf_weight_stats,
                    split_features,
                    leaf_sample_counts,
                    prediction_mean: slot.prediction_mean(),
                    prediction_std: slot.prediction_std(),
                });
            }
        }

        collect_tree_diags(&self.location_steps, &mut trees, &mut feature_split_counts);
        collect_tree_diags(&self.scale_steps, &mut trees, &mut feature_split_counts);

        // First n entries are location trees, the rest are scale trees.
        let location_trees = trees[..n].to_vec();
        let scale_trees = trees[n..].to_vec();
        let scale_trees_active = scale_trees.iter().filter(|t| t.n_leaves > 1).count();

        ModelDiagnostics {
            trees,
            location_trees,
            scale_trees,
            feature_split_counts,
            location_base: self.location_base,
            scale_base: self.scale_base,
            empirical_sigma: crate::math::sqrt(self.ewma_sq_err),
            scale_mode: self.scale_mode,
            scale_trees_active,
            auto_bandwidths: self.auto_bandwidths.clone(),
            ensemble_grad_mean: self.ensemble_grad_mean,
            ensemble_grad_std: self.ensemble_grad_std(),
        }
    }
1456
1457 pub fn predict_decomposed(&self, features: &[f64]) -> DecomposedPrediction {
1469 let lr = self.config.learning_rate;
1470 let location: Vec<f64> = self
1471 .location_steps
1472 .iter()
1473 .map(|s| lr * s.predict(features))
1474 .collect();
1475
1476 let (sb, scale) = match self.scale_mode {
1477 ScaleMode::Empirical => {
1478 let empirical_sigma = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
1479 (
1480 crate::math::ln(empirical_sigma),
1481 vec![0.0; self.location_steps.len()],
1482 )
1483 }
1484 ScaleMode::TreeChain => {
1485 let s: Vec<f64> = self
1486 .scale_steps
1487 .iter()
1488 .map(|s| lr * s.predict(features))
1489 .collect();
1490 (self.scale_base, s)
1491 }
1492 };
1493
1494 DecomposedPrediction {
1495 location_base: self.location_base,
1496 scale_base: sb,
1497 location_contributions: location,
1498 scale_contributions: scale,
1499 }
1500 }
1501
1502 pub fn feature_importances(&self) -> Vec<f64> {
1508 let mut totals: Vec<f64> = Vec::new();
1509 for steps in [&self.location_steps, &self.scale_steps] {
1510 for step in steps {
1511 let gains = step.slot().split_gains();
1512 if totals.is_empty() && !gains.is_empty() {
1513 totals.resize(gains.len(), 0.0);
1514 }
1515 for (i, &g) in gains.iter().enumerate() {
1516 if i < totals.len() {
1517 totals[i] += g;
1518 }
1519 }
1520 }
1521 }
1522 let sum: f64 = totals.iter().sum();
1523 if sum > 0.0 {
1524 totals.iter_mut().for_each(|v| *v /= sum);
1525 }
1526 totals
1527 }
1528
1529 pub fn feature_importances_split(&self) -> (Vec<f64>, Vec<f64>) {
1534 fn aggregate(steps: &[BoostingStep]) -> Vec<f64> {
1535 let mut totals: Vec<f64> = Vec::new();
1536 for step in steps {
1537 let gains = step.slot().split_gains();
1538 if totals.is_empty() && !gains.is_empty() {
1539 totals.resize(gains.len(), 0.0);
1540 }
1541 for (i, &g) in gains.iter().enumerate() {
1542 if i < totals.len() {
1543 totals[i] += g;
1544 }
1545 }
1546 }
1547 let sum: f64 = totals.iter().sum();
1548 if sum > 0.0 {
1549 totals.iter_mut().for_each(|v| *v /= sum);
1550 }
1551 totals
1552 }
1553 (
1554 aggregate(&self.location_steps),
1555 aggregate(&self.scale_steps),
1556 )
1557 }
1558
    #[cfg(feature = "_serde_support")]
    /// Captures the full model into a serializable `DistributionalModelState`:
    /// config, both tree chains (with drift-detector state), base values, and
    /// the running statistics needed to resume training.
    ///
    /// Fields not listed in the state struct (packed cache, sigma velocity,
    /// gradient stats, auto bandwidths) are ephemeral and are NOT serialized;
    /// `from_distributional_state` rebuilds them with defaults.
    pub fn to_distributional_state(&self) -> crate::serde_support::DistributionalModelState {
        use super::snapshot_tree;
        use crate::serde_support::{DistributionalModelState, StepSnapshot};

        // Snapshots a single boosting step: active tree, optional alternate
        // (background) tree, and serialized state for both drift detectors.
        fn snapshot_step(step: &BoostingStep) -> StepSnapshot {
            let slot = step.slot();
            let tree_snap = snapshot_tree(slot.active_tree());
            let alt_snap = slot.alternate_tree().map(snapshot_tree);
            let drift_state = slot.detector().serialize_state();
            // alt detector may be absent, and its serialization may also be None.
            let alt_drift_state = slot.alt_detector().and_then(|d| d.serialize_state());
            StepSnapshot {
                tree: tree_snap,
                alternate_tree: alt_snap,
                drift_state,
                alt_drift_state,
            }
        }

        DistributionalModelState {
            config: self.config.clone(),
            location_steps: self.location_steps.iter().map(snapshot_step).collect(),
            scale_steps: self.scale_steps.iter().map(snapshot_step).collect(),
            location_base: self.location_base,
            scale_base: self.scale_base,
            base_initialized: self.base_initialized,
            initial_targets: self.initial_targets.clone(),
            initial_target_count: self.initial_target_count,
            samples_seen: self.samples_seen,
            rng_state: self.rng_state,
            uncertainty_modulated_lr: self.uncertainty_modulated_lr,
            rolling_sigma_mean: self.rolling_sigma_mean,
            ewma_sq_err: self.ewma_sq_err,
            rolling_honest_sigma_mean: self.rolling_honest_sigma_mean,
        }
    }
1600
    #[cfg(feature = "_serde_support")]
    /// Reconstructs a model from a deserialized `DistributionalModelState`.
    ///
    /// Trees and drift detectors are rebuilt exactly from their snapshots;
    /// per-step seeds are re-derived from the config seed so the RNG sequence
    /// matches a freshly constructed model. Ephemeral fields that were not
    /// serialized (sigma velocity, packed cache, gradient stats, auto
    /// bandwidths, replacement counter) are reset to their defaults.
    pub fn from_distributional_state(
        state: crate::serde_support::DistributionalModelState,
    ) -> Self {
        use super::rebuild_tree;
        use crate::ensemble::replacement::TreeSlot;
        use crate::serde_support::StepSnapshot;

        // Per-sample leaf decay factor from the configured half-life:
        // alpha = exp(-ln 2 / half_life), so alpha^half_life == 0.5.
        let leaf_decay_alpha = state
            .config
            .leaf_half_life
            .map(|hl| crate::math::exp((-(crate::math::ln(2.0_f64)) / hl as f64)));
        let max_tree_samples = state.config.max_tree_samples;

        // Tree configuration shared by every rebuilt step; only the per-step
        // seed differs (set inside `rebuild_steps`).
        let base_tree_config = TreeConfig::new()
            .max_depth(state.config.max_depth)
            .n_bins(state.config.n_bins)
            .lambda(state.config.lambda)
            .gamma(state.config.gamma)
            .grace_period(state.config.grace_period)
            .delta(state.config.delta)
            .feature_subsample_rate(state.config.feature_subsample_rate)
            .leaf_decay_alpha_opt(leaf_decay_alpha)
            .split_reeval_interval_opt(state.config.split_reeval_interval)
            .feature_types_opt(state.config.feature_types.clone())
            .gradient_clip_sigma_opt(state.config.gradient_clip_sigma)
            .monotone_constraints_opt(state.config.monotone_constraints.clone())
            .adaptive_depth_opt(state.config.adaptive_depth)
            .leaf_model_type(state.config.leaf_model_type.clone());

        // Rebuilds one chain of steps; `seed_xor` separates the location and
        // scale chains so they never share a seed for the same step index.
        let rebuild_steps = |snaps: &[StepSnapshot], seed_xor: u64| -> Vec<BoostingStep> {
            snaps
                .iter()
                .enumerate()
                .map(|(i, snap)| {
                    let tc = base_tree_config
                        .clone()
                        .seed(state.config.seed ^ (i as u64) ^ seed_xor);

                    let active = rebuild_tree(&snap.tree, tc.clone());
                    let alternate = snap
                        .alternate_tree
                        .as_ref()
                        .map(|s| rebuild_tree(s, tc.clone()));

                    // Fresh detector, then overwrite with the saved state (if any).
                    let mut detector = state.config.drift_detector.create();
                    if let Some(ref ds) = snap.drift_state {
                        detector.restore_state(ds);
                    }
                    let mut slot =
                        TreeSlot::from_trees(active, alternate, tc, detector, max_tree_samples);
                    // The alternate detector only exists when the slot has an
                    // alternate tree; restore its state after slot construction.
                    if let Some(ref ads) = snap.alt_drift_state {
                        if let Some(alt_det) = slot.alt_detector_mut() {
                            alt_det.restore_state(ads);
                        }
                    }
                    BoostingStep::from_slot(slot)
                })
                .collect()
        };

        let location_steps = rebuild_steps(&state.location_steps, 0);
        // Arbitrary non-zero xor constant ("SCALE" leetspeak) for the scale chain.
        let scale_steps = rebuild_steps(&state.scale_steps, 0x0005_CA1E_0000_0000);

        // Pull these out before `state.config` is moved into `Self`.
        let scale_mode = state.config.scale_mode;
        let empirical_sigma_alpha = state.config.empirical_sigma_alpha;
        let packed_refresh_interval = state.config.packed_refresh_interval;
        Self {
            config: state.config,
            location_steps,
            scale_steps,
            location_base: state.location_base,
            scale_base: state.scale_base,
            base_initialized: state.base_initialized,
            initial_targets: state.initial_targets,
            initial_target_count: state.initial_target_count,
            samples_seen: state.samples_seen,
            rng_state: state.rng_state,
            uncertainty_modulated_lr: state.uncertainty_modulated_lr,
            rolling_sigma_mean: state.rolling_sigma_mean,
            scale_mode,
            ewma_sq_err: state.ewma_sq_err,
            empirical_sigma_alpha,
            // Everything below is transient and intentionally not serialized.
            prev_sigma: 0.0,
            sigma_velocity: 0.0,
            auto_bandwidths: Vec::new(),
            last_replacement_sum: 0,
            ensemble_grad_mean: 0.0,
            ensemble_grad_m2: 0.0,
            ensemble_grad_count: 0,
            rolling_honest_sigma_mean: state.rolling_honest_sigma_mean,
            packed_cache: None,
            samples_since_refresh: 0,
            packed_refresh_interval,
        }
    }
1704}
1705
1706use crate::learner::StreamingLearner;
1711
/// Adapter so `DistributionalSGBT` can be used wherever a point-prediction
/// `StreamingLearner` is expected: `predict` exposes the Gaussian mean.
impl StreamingLearner for DistributionalSGBT {
    fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
        // Delegate to the inherent weighted-sample training path.
        let sample = SampleRef::weighted(features, target, weight);
        DistributionalSGBT::train_one(self, &sample);
    }

    fn predict(&self, features: &[f64]) -> f64 {
        // Point prediction is mu of the predictive Gaussian; sigma is discarded.
        DistributionalSGBT::predict(self, features).mu
    }

    fn n_samples_seen(&self) -> u64 {
        self.samples_seen
    }

    fn reset(&mut self) {
        DistributionalSGBT::reset(self);
    }
}
1732
1733#[cfg(test)]
1738mod tests {
1739 use super::*;
1740 use alloc::format;
1741 use alloc::vec;
1742 use alloc::vec::Vec;
1743
    // Default configuration shared by most tests: small ensemble, short warm-up.
    fn test_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .build()
            .unwrap()
    }

    // Before any training, mu should be (near) zero and sigma strictly positive.
    #[test]
    fn fresh_model_predicts_zero() {
        let model = DistributionalSGBT::new(test_config());
        let pred = model.predict(&[1.0, 2.0, 3.0]);
        assert!(pred.mu.abs() < 1e-12);
        assert!(pred.sigma > 0.0);
    }

    // Sigma must stay positive and finite after training on a clean linear target.
    #[test]
    fn sigma_always_positive() {
        let mut model = DistributionalSGBT::new(test_config());

        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 0.5], x * 2.0 + 1.0));
        }

        for i in 0..20 {
            let x = i as f64 * 0.5;
            let pred = model.predict(&[x, x * 0.5]);
            assert!(
                pred.sigma > 0.0,
                "sigma must be positive, got {}",
                pred.sigma
            );
            assert!(pred.sigma.is_finite(), "sigma must be finite");
        }
    }

    // A constant target should still give finite mu and a positive sigma.
    // NOTE(review): only sanity bounds are asserted, not an actual "small"
    // magnitude, despite the test name.
    #[test]
    fn constant_target_has_small_sigma() {
        let mut model = DistributionalSGBT::new(test_config());

        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 2.0], 5.0));
        }

        let pred = model.predict(&[1.0, 2.0]);
        assert!(pred.mu.is_finite());
        assert!(pred.sigma.is_finite());
        assert!(pred.sigma > 0.0);
    }
1804
1805 #[test]
1806 fn noisy_target_has_finite_predictions() {
1807 let mut model = DistributionalSGBT::new(test_config());
1808
1809 let mut rng: u64 = 42;
1811 for i in 0..200 {
1812 rng ^= rng << 13;
1813 rng ^= rng >> 7;
1814 rng ^= rng << 17;
1815 let noise = (rng % 1000) as f64 / 500.0 - 1.0; let x = i as f64 * 0.1;
1817 model.train_one(&(vec![x], x * 2.0 + noise));
1818 }
1819
1820 let pred = model.predict(&[5.0]);
1821 assert!(pred.mu.is_finite());
1822 assert!(pred.sigma.is_finite());
1823 assert!(pred.sigma > 0.0);
1824 }
1825
    // predict_interval should bracket mu symmetrically with half-width z * sigma.
    #[test]
    fn predict_interval_bounds_correct() {
        let mut model = DistributionalSGBT::new(test_config());

        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0));
        }

        let (lo, hi) = model.predict_interval(&[5.0], 1.96);
        let pred = model.predict(&[5.0]);

        assert!(lo < pred.mu, "lower bound should be < mu");
        assert!(hi > pred.mu, "upper bound should be > mu");
        assert!((hi - lo - 2.0 * 1.96 * pred.sigma).abs() < 1e-10);
    }

    // predict_batch must agree with per-sample predict for every row.
    #[test]
    fn batch_prediction_matches_individual() {
        let mut model = DistributionalSGBT::new(test_config());

        for i in 0..100 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 2.0], x));
        }

        let features = vec![vec![1.0, 2.0], vec![3.0, 6.0], vec![5.0, 10.0]];
        let batch = model.predict_batch(&features);

        for (feat, batch_pred) in features.iter().zip(batch.iter()) {
            let individual = model.predict(feat);
            assert!((batch_pred.mu - individual.mu).abs() < 1e-12);
            assert!((batch_pred.sigma - individual.sigma).abs() < 1e-12);
        }
    }

    // reset() must zero the sample counter and drop initialization.
    #[test]
    fn reset_clears_state() {
        let mut model = DistributionalSGBT::new(test_config());

        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0));
        }

        assert!(model.n_samples_seen() > 0);
        model.reset();

        assert_eq!(model.n_samples_seen(), 0);
        assert!(!model.is_initialized());
    }

    // lower()/upper() are mu -/+ z * sigma on a hand-built prediction.
    #[test]
    fn gaussian_prediction_lower_upper() {
        let pred = GaussianPrediction {
            mu: 10.0,
            sigma: 2.0,
            log_sigma: 2.0_f64.ln(),
            honest_sigma: 0.0,
        };

        assert!((pred.lower(1.96) - (10.0 - 1.96 * 2.0)).abs() < 1e-10);
        assert!((pred.upper(1.96) - (10.0 + 1.96 * 2.0)).abs() < 1e-10);
    }

    // train_batch should count every sample exactly once.
    #[test]
    fn train_batch_works() {
        let mut model = DistributionalSGBT::new(test_config());
        let samples: Vec<(Vec<f64>, f64)> = (0..100)
            .map(|i| {
                let x = i as f64 * 0.1;
                (vec![x], x * 2.0)
            })
            .collect();

        model.train_batch(&samples);
        assert_eq!(model.n_samples_seen(), 100);
    }

    // The Debug impl should at least name the type.
    #[test]
    fn debug_format_works() {
        let model = DistributionalSGBT::new(test_config());
        let debug = format!("{:?}", model);
        assert!(debug.contains("DistributionalSGBT"));
    }

    // n_trees covers both ensembles: at least 2 * n_steps with n_steps = 10.
    #[test]
    fn n_trees_counts_both_ensembles() {
        let model = DistributionalSGBT::new(test_config());
        assert!(model.n_trees() >= 20);
    }

    // Like test_config() but with uncertainty-modulated learning rate enabled.
    fn modulated_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .uncertainty_modulated_lr(true)
            .build()
            .unwrap()
    }

    // Modulated models start with rolling sigma mean == 1.0 and keep it
    // positive/finite after training.
    #[test]
    fn sigma_modulated_initializes_rolling_mean() {
        let mut model = DistributionalSGBT::new(modulated_config());
        assert!(model.is_uncertainty_modulated());

        assert!((model.rolling_sigma_mean() - 1.0).abs() < 1e-12);

        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 0.5], x * 2.0 + 1.0));
        }

        assert!(model.rolling_sigma_mean() > 0.0);
        assert!(model.rolling_sigma_mean().is_finite());
    }

    // With modulation on, the returned sigma ratio should be finite and
    // within an order of magnitude of 1.
    #[test]
    fn predict_distributional_returns_sigma_ratio() {
        let mut model = DistributionalSGBT::new(modulated_config());

        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0 + 1.0));
        }

        let (mu, sigma, sigma_ratio) = model.predict_distributional(&[5.0]);
        assert!(mu.is_finite());
        assert!(sigma > 0.0);
        assert!(
            (0.1..=10.0).contains(&sigma_ratio),
            "sigma_ratio={}",
            sigma_ratio
        );
    }

    // With modulation off, the sigma ratio must be exactly 1.0.
    #[test]
    fn predict_distributional_without_modulation_returns_one() {
        let mut model = DistributionalSGBT::new(test_config());
        assert!(!model.is_uncertainty_modulated());

        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0));
        }

        let (_mu, _sigma, sigma_ratio) = model.predict_distributional(&[5.0]);
        assert!(
            (sigma_ratio - 1.0).abs() < 1e-12,
            "should be 1.0 when disabled"
        );
    }
1988
1989 #[test]
1990 fn modulated_model_sigma_finite_under_varying_noise() {
1991 let mut model = DistributionalSGBT::new(modulated_config());
1992
1993 let mut rng: u64 = 123;
1994 for i in 0..500 {
1995 rng ^= rng << 13;
1996 rng ^= rng >> 7;
1997 rng ^= rng << 17;
1998 let noise = (rng % 1000) as f64 / 100.0 - 5.0; let x = i as f64 * 0.1;
2000 let scale = if i < 250 { 1.0 } else { 5.0 };
2002 model.train_one(&(vec![x], x * 2.0 + noise * scale));
2003 }
2004
2005 let pred = model.predict(&[10.0]);
2006 assert!(pred.mu.is_finite());
2007 assert!(pred.sigma.is_finite());
2008 assert!(pred.sigma > 0.0);
2009 assert!(model.rolling_sigma_mean().is_finite());
2010 }
2011
    // reset() restores the rolling sigma mean to its initial value of 1.0.
    #[test]
    fn reset_clears_rolling_sigma_mean() {
        let mut model = DistributionalSGBT::new(modulated_config());

        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0));
        }

        let sigma_before = model.rolling_sigma_mean();
        assert!(sigma_before > 0.0);

        model.reset();
        assert!((model.rolling_sigma_mean() - 1.0).abs() < 1e-12);
    }

    // The StreamingLearner facade must return mu as its point prediction.
    // NOTE(review): uses `StreamingLearner::train`, presumably a provided
    // trait method that defaults the weight — confirm against the trait.
    #[test]
    fn streaming_learner_returns_mu() {
        let mut model = DistributionalSGBT::new(test_config());
        for i in 0..200 {
            let x = i as f64 * 0.1;
            StreamingLearner::train(&mut model, &[x], x * 2.0 + 1.0);
        }
        let pred = StreamingLearner::predict(&model, &[5.0]);
        let gaussian = DistributionalSGBT::predict(&model, &[5.0]);
        assert!(
            (pred - gaussian.mu).abs() < 1e-12,
            "StreamingLearner::predict should return mu"
        );
    }
2042
2043 fn trained_model() -> DistributionalSGBT {
2046 let config = SGBTConfig::builder()
2047 .n_steps(10)
2048 .learning_rate(0.1)
2049 .grace_period(10) .max_depth(4)
2051 .n_bins(16)
2052 .initial_target_count(10)
2053 .build()
2054 .unwrap();
2055 let mut model = DistributionalSGBT::new(config);
2056 for i in 0..500 {
2057 let x = i as f64 * 0.1;
2058 model.train_one(&(vec![x, x * 0.5, (i % 3) as f64], x * 2.0 + 1.0));
2059 }
2060 model
2061 }
2062
    // diagnostics() reports one entry per tree across both ensembles.
    #[test]
    fn diagnostics_returns_correct_tree_count() {
        let model = trained_model();
        let diag = model.diagnostics();
        assert_eq!(diag.trees.len(), 20, "should have 2*n_steps trees");
    }

    // Every tree has at least one leaf, and the ensemble saw samples overall.
    #[test]
    fn diagnostics_trees_have_leaves() {
        let model = trained_model();
        let diag = model.diagnostics();
        for (i, tree) in diag.trees.iter().enumerate() {
            assert!(tree.n_leaves >= 1, "tree {i} should have at least 1 leaf");
        }
        let total_samples: u64 = diag.trees.iter().map(|t| t.samples_seen).sum();
        assert!(
            total_samples > 0,
            "at least some trees should have seen samples"
        );
    }

    // Leaf-weight summary stats (min, max, mean, std) must be finite and consistent.
    #[test]
    fn diagnostics_leaf_weight_stats_finite() {
        let model = trained_model();
        let diag = model.diagnostics();
        for (i, tree) in diag.trees.iter().enumerate() {
            let (min, max, mean, std) = tree.leaf_weight_stats;
            assert!(min.is_finite(), "tree {i} min not finite");
            assert!(max.is_finite(), "tree {i} max not finite");
            assert!(mean.is_finite(), "tree {i} mean not finite");
            assert!(std.is_finite(), "tree {i} std not finite");
            assert!(min <= max, "tree {i} min > max");
        }
    }

    // location_base should be in the same ballpark as a prediction at the
    // origin (deliberately loose 100.0 tolerance: plausibility, not equality).
    #[test]
    fn diagnostics_base_predictions_match() {
        let model = trained_model();
        let diag = model.diagnostics();
        assert!(
            (diag.location_base - model.predict(&[0.0, 0.0, 0.0]).mu).abs() < 100.0,
            "location_base should be plausible"
        );
    }

    // Summing the decomposition must reproduce predict() exactly (mu and sigma).
    #[test]
    fn predict_decomposed_reconstructs_prediction() {
        let model = trained_model();
        let features = [5.0, 2.5, 1.0];
        let pred = model.predict(&features);
        let decomp = model.predict_decomposed(&features);

        assert!(
            (decomp.mu() - pred.mu).abs() < 1e-10,
            "decomposed mu ({}) != predict mu ({})",
            decomp.mu(),
            pred.mu
        );
        assert!(
            (decomp.sigma() - pred.sigma).abs() < 1e-10,
            "decomposed sigma ({}) != predict sigma ({})",
            decomp.sigma(),
            pred.sigma
        );
    }

    // Both contribution vectors must have exactly n_steps entries.
    #[test]
    fn predict_decomposed_correct_lengths() {
        let model = trained_model();
        let decomp = model.predict_decomposed(&[1.0, 0.5, 0.0]);
        assert_eq!(
            decomp.location_contributions.len(),
            model.n_steps(),
            "location contributions should have n_steps entries"
        );
        assert_eq!(
            decomp.scale_contributions.len(),
            model.n_steps(),
            "scale contributions should have n_steps entries"
        );
    }

    // Importances are non-negative, finite, and normalized when non-zero.
    #[test]
    fn feature_importances_work() {
        let model = trained_model();
        let imp = model.feature_importances();
        for (i, &v) in imp.iter().enumerate() {
            assert!(v >= 0.0, "importance {i} should be non-negative, got {v}");
            assert!(v.is_finite(), "importance {i} should be finite");
        }
        let sum: f64 = imp.iter().sum();
        if sum > 0.0 {
            assert!(
                (sum - 1.0).abs() < 1e-10,
                "non-zero importances should sum to 1.0, got {sum}"
            );
        }
    }

    // Same invariants, but for the per-ensemble (location/scale) split.
    #[test]
    fn feature_importances_split_works() {
        let model = trained_model();
        let (loc_imp, scale_imp) = model.feature_importances_split();
        for (name, imp) in [("location", &loc_imp), ("scale", &scale_imp)] {
            let sum: f64 = imp.iter().sum();
            if sum > 0.0 {
                assert!(
                    (sum - 1.0).abs() < 1e-10,
                    "{name} importances should sum to 1.0, got {sum}"
                );
            }
            for &v in imp.iter() {
                assert!(v >= 0.0 && v.is_finite());
            }
        }
    }

    // Default scale mode is Empirical (EWMA of squared errors, no scale trees).
    #[test]
    fn empirical_sigma_default_mode() {
        use crate::ensemble::config::ScaleMode;
        let config = test_config();
        let model = DistributionalSGBT::new(config);
        assert_eq!(model.scale_mode(), ScaleMode::Empirical);
    }
2193
2194 #[test]
2195 fn empirical_sigma_tracks_errors() {
2196 let mut model = DistributionalSGBT::new(test_config());
2197
2198 for i in 0..200 {
2200 let x = i as f64 * 0.1;
2201 model.train_one(&(vec![x, x * 0.5], x * 2.0 + 1.0));
2202 }
2203
2204 let sigma_clean = model.empirical_sigma();
2205 assert!(sigma_clean > 0.0, "sigma should be positive");
2206 assert!(sigma_clean.is_finite(), "sigma should be finite");
2207
2208 let mut rng: u64 = 42;
2210 for i in 200..400 {
2211 rng ^= rng << 13;
2212 rng ^= rng >> 7;
2213 rng ^= rng << 17;
2214 let noise = (rng % 10000) as f64 / 100.0 - 50.0; let x = i as f64 * 0.1;
2216 model.train_one(&(vec![x, x * 0.5], x * 2.0 + noise));
2217 }
2218
2219 let sigma_noisy = model.empirical_sigma();
2220 assert!(
2221 sigma_noisy > sigma_clean,
2222 "noisy regime should increase sigma: clean={sigma_clean}, noisy={sigma_noisy}"
2223 );
2224 }
2225
    // Modulated LR with the default (Empirical) scale mode still yields a
    // sane, finite sigma ratio.
    #[test]
    fn empirical_sigma_modulated_lr_adapts() {
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .uncertainty_modulated_lr(true)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        for i in 0..300 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0 + 1.0));
        }

        let (_, _, sigma_ratio) = model.predict_distributional(&[5.0]);
        assert!(sigma_ratio.is_finite());
        assert!(
            (0.1..=10.0).contains(&sigma_ratio),
            "sigma_ratio={sigma_ratio}"
        );
    }

    // TreeChain mode is honored by the builder and still produces valid
    // finite, positive predictions.
    #[test]
    fn tree_chain_mode_trains_scale_trees() {
        use crate::ensemble::config::ScaleMode;
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .scale_mode(ScaleMode::TreeChain)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);
        assert_eq!(model.scale_mode(), ScaleMode::TreeChain);

        for i in 0..500 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 0.5, (i % 3) as f64], x * 2.0 + 1.0));
        }

        let pred = model.predict(&[5.0, 2.5, 1.0]);
        assert!(pred.mu.is_finite());
        assert!(pred.sigma > 0.0);
        assert!(pred.sigma.is_finite());
    }

    // The diagnostics struct exposes a positive, finite empirical sigma.
    #[test]
    fn diagnostics_shows_empirical_sigma() {
        let model = trained_model();
        let diag = model.diagnostics();
        assert!(
            diag.empirical_sigma > 0.0,
            "empirical_sigma should be positive"
        );
        assert!(
            diag.empirical_sigma.is_finite(),
            "empirical_sigma should be finite"
        );
    }

    // The per-ensemble tree lists each have exactly n_steps entries.
    #[test]
    fn diagnostics_scale_trees_split_fields() {
        let model = trained_model();
        let diag = model.diagnostics();
        assert_eq!(diag.location_trees.len(), model.n_steps());
        assert_eq!(diag.scale_trees.len(), model.n_steps());
    }

    // reset() restores empirical sigma to its initial value of 1.0.
    #[test]
    fn reset_clears_empirical_sigma() {
        let mut model = DistributionalSGBT::new(test_config());
        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0));
        }
        model.reset();
        assert!((model.empirical_sigma() - 1.0).abs() < 1e-12);
    }

    // predict_smooth (soft/temperature-blended prediction) stays finite
    // with a positive sigma.
    #[test]
    fn predict_smooth_returns_finite() {
        let config = SGBTConfig::builder()
            .n_steps(3)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        for i in 0..200 {
            let x = (i as f64) * 0.1;
            let features = vec![x, x.sin()];
            let target = 2.0 * x + 1.0;
            model.train_one(&(features, target));
        }

        let pred = model.predict_smooth(&[1.0, 1.0_f64.sin()], 0.5);
        assert!(pred.mu.is_finite(), "smooth mu should be finite");
        assert!(pred.sigma.is_finite(), "smooth sigma should be finite");
        assert!(pred.sigma > 0.0, "smooth sigma should be positive");
    }

    // After an abrupt shift to a much steeper target, the tracked sigma
    // velocity must rise relative to the calm phase.
    #[test]
    fn sigma_velocity_responds_to_error_spike() {
        let config = SGBTConfig::builder()
            .n_steps(3)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .uncertainty_modulated_lr(true)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        for i in 0..200 {
            let x = (i as f64) * 0.1;
            model.train_one(&(vec![x, x.sin()], 2.0 * x + 1.0));
        }

        let velocity_before = model.sigma_velocity();

        for i in 0..50 {
            let x = (i as f64) * 0.1;
            model.train_one(&(vec![x, x.sin()], 100.0 * x + 50.0));
        }

        let velocity_after = model.sigma_velocity();

        assert!(
            velocity_after > velocity_before,
            "sigma velocity should increase after error spike: before={}, after={}",
            velocity_before,
            velocity_after,
        );
    }

    // A fresh model reports zero sigma velocity.
    #[test]
    fn sigma_velocity_getter_works() {
        let config = SGBTConfig::builder()
            .n_steps(2)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .build()
            .unwrap();
        let model = DistributionalSGBT::new(config);
        assert_eq!(model.sigma_velocity(), 0.0);
    }
2399
    // Each tree diagnostic carries exactly one sample count per leaf, and
    // trees that saw samples report a non-zero leaf-count total.
    #[test]
    fn diagnostics_leaf_sample_counts_populated() {
        let config = SGBTConfig::builder()
            .n_steps(3)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        for i in 0..200 {
            let x = (i as f64) * 0.1;
            let features = vec![x, x.sin()];
            let target = 2.0 * x + 1.0;
            model.train_one(&(features, target));
        }

        let diags = model.diagnostics();
        for (ti, tree) in diags.trees.iter().enumerate() {
            assert_eq!(
                tree.leaf_sample_counts.len(),
                tree.n_leaves,
                "tree {} should have sample count per leaf",
                ti,
            );
            if tree.samples_seen > 0 {
                let total: u64 = tree.leaf_sample_counts.iter().sum();
                assert!(
                    total > 0,
                    "tree {} has {} samples_seen but leaf counts sum to 0",
                    ti,
                    tree.samples_seen,
                );
            }
        }
    }

    // Packed inference cache is opt-in: off unless configured or enabled.
    #[test]
    fn packed_cache_disabled_by_default() {
        let model = DistributionalSGBT::new(test_config());
        assert!(!model.has_packed_cache());
        assert_eq!(model.config().packed_refresh_interval, 0);
    }

    // Disabled via feature gate: packed cache should appear automatically
    // after training past the configured refresh interval.
    #[test]
    #[cfg(feature = "_packed_cache_tests_disabled")]
    fn packed_cache_refreshes_after_interval() {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .grace_period(5)
            .max_depth(3)
            .n_bins(8)
            .initial_target_count(10)
            .packed_refresh_interval(20)
            .build()
            .unwrap();

        let mut model = DistributionalSGBT::new(config);

        for i in 0..40 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 2.0, x * 0.5], x * 3.0));
        }

        assert!(
            model.has_packed_cache(),
            "packed cache should exist after training past refresh interval"
        );

        let pred = model.predict(&[2.0, 4.0, 1.0]);
        assert!(pred.mu.is_finite(), "mu should be finite: {}", pred.mu);
    }

    // Disabled via feature gate: packed-cache predictions must agree with
    // the full trees (mu within f32 tolerance, sigma exactly).
    #[test]
    #[cfg(feature = "_packed_cache_tests_disabled")]
    fn packed_cache_matches_full_tree() {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .grace_period(5)
            .max_depth(3)
            .n_bins(8)
            .initial_target_count(10)
            .build()
            .unwrap();

        let mut model = DistributionalSGBT::new(config);

        for i in 0..80 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 2.0, x * 0.5], x * 3.0));
        }

        assert!(!model.has_packed_cache());
        let full_pred = model.predict(&[2.0, 4.0, 1.0]);

        model.enable_packed_cache(10);
        assert!(model.has_packed_cache());
        let cached_pred = model.predict(&[2.0, 4.0, 1.0]);

        let mu_diff = (full_pred.mu - cached_pred.mu).abs();
        assert!(
            mu_diff < 0.1,
            "packed cache mu ({}) should match full tree mu ({}) within f32 tolerance, diff={}",
            cached_pred.mu,
            full_pred.mu,
            mu_diff
        );

        assert!(
            (full_pred.sigma - cached_pred.sigma).abs() < 1e-12,
            "sigma should be identical: full={}, cached={}",
            full_pred.sigma,
            cached_pred.sigma
        );
    }
2533
    // honest_sigma is always finite and non-negative after training.
    #[test]
    fn honest_sigma_in_gaussian_prediction() {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .max_depth(3)
            .grace_period(2)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);
        for i in 0..100 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0));
        }

        let pred = model.predict(&[5.0]);
        assert!(
            pred.honest_sigma.is_finite(),
            "honest_sigma should be finite, got {}",
            pred.honest_sigma
        );
        assert!(
            pred.honest_sigma >= 0.0,
            "honest_sigma should be >= 0, got {}",
            pred.honest_sigma
        );
    }

    // In- vs out-of-distribution inputs: only finiteness/non-negativity is
    // asserted here (no monotonicity guarantee is enforced), despite the name.
    #[test]
    fn honest_sigma_increases_with_divergence() {
        let config = SGBTConfig::builder()
            .n_steps(8)
            .learning_rate(0.1)
            .max_depth(4)
            .grace_period(2)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        for i in 0..200 {
            let x = i as f64 * 0.05;
            model.train_one(&(vec![x], x * 2.0));
        }

        let in_dist = model.predict(&[5.0]);
        let out_dist = model.predict(&[500.0]);

        assert!(in_dist.honest_sigma.is_finite());
        assert!(out_dist.honest_sigma.is_finite());
        assert!(in_dist.honest_sigma >= 0.0);
        assert!(out_dist.honest_sigma >= 0.0);

    }

    // With a single boosting step there is no cross-step disagreement,
    // so honest_sigma must be exactly zero.
    #[test]
    fn honest_sigma_zero_for_single_step() {
        let config = SGBTConfig::builder()
            .n_steps(1)
            .learning_rate(0.1)
            .max_depth(3)
            .grace_period(2)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);
        for i in 0..100 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0));
        }
        let pred = model.predict(&[5.0]);
        assert!(
            pred.honest_sigma.abs() < 1e-15,
            "honest_sigma should be 0 with 1 step, got {}",
            pred.honest_sigma
        );
    }

    // Soft-routed prediction returns a fully finite GaussianPrediction
    // with positive sigma.
    #[test]
    fn distributional_soft_routed_prediction() {
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        for i in 0..200 {
            let x = i as f64 * 0.05;
            let y = x * 2.0 + 1.0;
            model.train_one(&(vec![x, x * 0.3], y));
        }

        let pred = model.predict_soft_routed(&[1.0, 0.3]);
        assert!(
            pred.mu.is_finite(),
            "soft-routed mu should be finite, got {}",
            pred.mu
        );
        assert!(
            pred.sigma.is_finite() && pred.sigma > 0.0,
            "soft-routed sigma should be finite and positive, got {}",
            pred.sigma
        );
        assert!(
            pred.log_sigma.is_finite(),
            "soft-routed log_sigma should be finite, got {}",
            pred.log_sigma
        );
        assert!(
            pred.honest_sigma.is_finite(),
            "soft-routed honest_sigma should be finite, got {}",
            pred.honest_sigma
        );
    }
2666}
2667
// Serialization round-trip tests plus coverage of several config knobs
// (leaf-output clamping, Huber loss, hessian gating, interpolated
// prediction modes). Only compiled with the `_serde_support` feature.
#[cfg(test)]
#[cfg(feature = "_serde_support")]
mod serde_tests {
    use super::*;
    use crate::SGBTConfig;

    /// Shared fixture: a small model trained on 50 samples of y = sin(x),
    /// so the serde tests all exercise the same canonical state.
    fn make_trained_distributional() -> DistributionalSGBT {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .max_depth(3)
            .grace_period(2)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);
        for i in 0..50 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x.sin()));
        }
        model
    }

    /// Save-to-JSON then load-back must reproduce mu and sigma to 1e-10.
    #[test]
    fn json_round_trip_preserves_predictions() {
        let model = make_trained_distributional();
        let state = model.to_distributional_state();
        let json = crate::serde_support::save_distributional_model(&state).unwrap();
        let loaded_state = crate::serde_support::load_distributional_model(&json).unwrap();
        let restored = DistributionalSGBT::from_distributional_state(loaded_state);

        let test_points = [0.5, 1.0, 2.0, 3.0];
        for &x in &test_points {
            let orig = model.predict(&[x]);
            let rest = restored.predict(&[x]);
            assert!(
                (orig.mu - rest.mu).abs() < 1e-10,
                "JSON round-trip mu mismatch at x={}: {} vs {}",
                x,
                orig.mu,
                rest.mu
            );
            assert!(
                (orig.sigma - rest.sigma).abs() < 1e-10,
                "JSON round-trip sigma mismatch at x={}: {} vs {}",
                x,
                orig.sigma,
                rest.sigma
            );
        }
    }

    /// Exported state must carry the uncertainty-modulated-LR flag and its
    /// rolling sigma mean, and restoring must keep the sample counter.
    #[test]
    fn state_preserves_rolling_sigma_mean() {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .max_depth(3)
            .grace_period(2)
            .initial_target_count(10)
            .uncertainty_modulated_lr(true)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);
        for i in 0..50 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x.sin()));
        }
        let state = model.to_distributional_state();
        assert!(state.uncertainty_modulated_lr);
        assert!(state.rolling_sigma_mean >= 0.0);

        let restored = DistributionalSGBT::from_distributional_state(state);
        assert_eq!(model.n_samples_seen(), restored.n_samples_seen());
    }

    /// After enough samples the model should report one auto bandwidth per
    /// feature, expose them in diagnostics, and still predict finitely.
    #[test]
    fn auto_bandwidth_computed_distributional() {
        let config = SGBTConfig::builder()
            .n_steps(3)
            .learning_rate(0.1)
            .grace_period(10)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        for i in 0..200 {
            let x = (i as f64) * 0.1;
            model.train_one(&(vec![x, x.sin()], 2.0 * x + 1.0));
        }

        let bws = model.auto_bandwidths();
        assert_eq!(bws.len(), 2, "should have 2 feature bandwidths");

        let diag = model.diagnostics();
        assert_eq!(diag.auto_bandwidths.len(), 2);

        assert!(diag.location_trees[0].prediction_mean.is_finite());
        assert!(diag.location_trees[0].prediction_std.is_finite());

        let pred = model.predict(&[1.0, 1.0_f64.sin()]);
        assert!(pred.mu.is_finite(), "auto-bandwidth mu should be finite");
        assert!(pred.sigma > 0.0, "auto-bandwidth sigma should be positive");
    }

    /// With an aggressive learning rate and alternating ±100 targets,
    /// `max_leaf_output` clamping must keep predictions finite.
    #[test]
    fn max_leaf_output_clamps_predictions() {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(1.0)
            .max_leaf_output(0.5)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        for i in 0..200 {
            let target = if i % 2 == 0 { 100.0 } else { -100.0 };
            let sample = crate::Sample::new(vec![i as f64 % 5.0, (i as f64).sin()], target);
            model.train_one(&sample);
        }

        let pred = model.predict(&[2.0, 0.5]);
        assert!(
            pred.mu.is_finite(),
            "prediction should be finite with clamping"
        );
    }

    /// A large `min_hessian_sum` suppresses fresh-leaf contributions; the
    /// model must still yield finite predictions with few samples.
    #[test]
    fn min_hessian_sum_suppresses_fresh_leaves() {
        let config = SGBTConfig::builder()
            .n_steps(3)
            .learning_rate(0.01)
            .min_hessian_sum(50.0)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        for i in 0..60 {
            let sample = crate::Sample::new(vec![i as f64, (i as f64).sin()], i as f64 * 0.1);
            model.train_one(&sample);
        }

        let pred = model.predict(&[30.0, 0.5]);
        assert!(
            pred.mu.is_finite(),
            "prediction should be finite with min_hessian_sum"
        );
    }

    /// Interpolated prediction mode must return finite mu and positive sigma.
    #[test]
    fn predict_interpolated_returns_finite() {
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);
        for i in 0..200 {
            let x = i as f64 * 0.1;
            let sample = crate::Sample::new(vec![x, x.sin()], x.cos());
            model.train_one(&sample);
        }

        let pred = model.predict_interpolated(&[1.0, 0.5]);
        assert!(pred.mu.is_finite(), "interpolated mu should be finite");
        assert!(pred.sigma > 0.0, "interpolated sigma should be positive");
    }

    /// Huber loss (k = 1.345) should bound gradients so that periodic
    /// extreme outliers (target 1000) cannot blow up predictions.
    #[test]
    fn huber_k_bounds_gradients() {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.01)
            .huber_k(1.345)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        for i in 0..300 {
            let target = if i % 50 == 0 {
                1000.0
            } else {
                (i as f64 * 0.1).sin()
            };
            let sample = crate::Sample::new(vec![i as f64 % 10.0, (i as f64).cos()], target);
            model.train_one(&sample);
        }

        let pred = model.predict(&[5.0, 0.3]);
        assert!(
            pred.mu.is_finite(),
            "Huber-loss mu should be finite despite outliers"
        );
        assert!(pred.sigma > 0.0, "sigma should be positive");
    }

    /// Diagnostics must expose finite gradient mean and a finite,
    /// non-negative gradient standard deviation after training.
    #[test]
    fn ensemble_gradient_stats_populated() {
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);
        for i in 0..200 {
            let x = i as f64 * 0.1;
            let sample = crate::Sample::new(vec![x, x.sin()], x.cos());
            model.train_one(&sample);
        }

        let diag = model.diagnostics();
        assert!(
            diag.ensemble_grad_mean.is_finite(),
            "ensemble grad mean should be finite"
        );
        assert!(
            diag.ensemble_grad_std >= 0.0,
            "ensemble grad std should be non-negative"
        );
        assert!(
            diag.ensemble_grad_std.is_finite(),
            "ensemble grad std should be finite"
        );
    }

    /// Negative huber_k is rejected by config validation.
    #[test]
    fn huber_k_validation() {
        let result = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.01)
            .huber_k(-1.0)
            .build();
        assert!(result.is_err(), "negative huber_k should fail validation");
    }

    /// Negative max_leaf_output is rejected by config validation.
    #[test]
    fn max_leaf_output_validation() {
        let result = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.01)
            .max_leaf_output(-1.0)
            .build();
        assert!(
            result.is_err(),
            "negative max_leaf_output should fail validation"
        );
    }

    /// Sibling interpolation should be at least as smooth as hard routing:
    /// across a sweep of inputs it must change value at least as often.
    #[test]
    fn predict_sibling_interpolated_varies_with_features() {
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(6)
            .delta(0.1)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        for i in 0..2000 {
            let x = (i as f64) * 0.01;
            let y = x.sin() * x + 0.5 * (x * 2.0).cos();
            let sample = crate::Sample::new(vec![x, x * 0.3], y);
            model.train_one(&sample);
        }

        let pred = model.predict_sibling_interpolated(&[5.0, 1.5]);
        assert!(
            pred.mu.is_finite(),
            "sibling interpolated mu should be finite"
        );
        assert!(
            pred.sigma > 0.0,
            "sibling interpolated sigma should be positive"
        );

        let bws = model.auto_bandwidths();
        // The smoothness comparison is only meaningful once at least one
        // feature bandwidth has been computed.
        if bws.iter().any(|&b| b.is_finite()) {
            let hard_preds: Vec<f64> = (0..200)
                .map(|i| {
                    let x = i as f64 * 0.1;
                    model.predict(&[x, x * 0.3]).mu
                })
                .collect();
            let hard_changes = hard_preds
                .windows(2)
                .filter(|w| (w[0] - w[1]).abs() > f64::EPSILON)
                .count();

            let preds: Vec<f64> = (0..200)
                .map(|i| {
                    let x = i as f64 * 0.1;
                    model.predict_sibling_interpolated(&[x, x * 0.3]).mu
                })
                .collect();

            let sibling_changes = preds
                .windows(2)
                .filter(|w| (w[0] - w[1]).abs() > f64::EPSILON)
                .count();
            assert!(
                sibling_changes >= hard_changes,
                "sibling should produce >= hard changes: sibling={}, hard={}",
                sibling_changes,
                hard_changes
            );
        }
    }
}