1use alloc::vec;
31use alloc::vec::Vec;
32
33use crate::ensemble::config::{SGBTConfig, ScaleMode};
34use crate::ensemble::step::BoostingStep;
35use crate::sample::{Observation, SampleRef};
36use crate::tree::builder::TreeConfig;
37
/// Flattened, byte-packed copy of the location ensemble used as a fast
/// inference path by `DistributionalSGBT::predict`.
struct PackedInferenceCache {
    /// Serialized ensemble, decoded via `crate::EnsembleView::from_bytes`.
    bytes: Vec<u8>,
    /// Location intercept added on top of the packed ensemble's output.
    base: f64,
    /// Feature count — NOTE(review): not read anywhere in this chunk;
    /// presumably used when (re)building the packed bytes — confirm at the
    /// build site.
    n_features: usize,
}
47
48impl Clone for PackedInferenceCache {
49 fn clone(&self) -> Self {
50 Self {
51 bytes: self.bytes.clone(),
52 base: self.base,
53 n_features: self.n_features,
54 }
55 }
56}
57
/// A Gaussian `(mu, sigma)` predictive distribution for a single input.
#[derive(Debug, Clone, Copy)]
pub struct GaussianPrediction {
    /// Predicted mean (location).
    pub mu: f64,
    /// Predicted standard deviation; the constructing code paths clamp it
    /// to at least 1e-8.
    pub sigma: f64,
    /// Natural log of sigma, as produced by the scale model.
    pub log_sigma: f64,
}
68
69impl GaussianPrediction {
70 #[inline]
74 pub fn lower(&self, z: f64) -> f64 {
75 self.mu - z * self.sigma
76 }
77
78 #[inline]
80 pub fn upper(&self, z: f64) -> f64 {
81 self.mu + z * self.sigma
82 }
83}
84
/// Per-tree diagnostic summary (see `DistributionalSGBT::diagnostics`).
#[derive(Debug, Clone)]
pub struct TreeDiagnostic {
    /// Number of leaves in the step's active tree.
    pub n_leaves: usize,
    /// Deepest leaf depth observed in the tree arena.
    pub max_depth_reached: usize,
    /// Samples routed through this boosting step.
    pub samples_seen: u64,
    /// `(min, max, mean, std)` of the leaf values.
    pub leaf_weight_stats: (f64, f64, f64, f64),
    /// Feature indices with positive accumulated split gain.
    pub split_features: Vec<usize>,
    /// Sample count recorded at each leaf, in arena order.
    pub leaf_sample_counts: Vec<u64>,
    /// Running mean of this step's predictions (from the tree slot).
    pub prediction_mean: f64,
    /// Running std of this step's predictions (from the tree slot).
    pub prediction_std: f64,
}
109
/// Whole-model diagnostic snapshot produced by
/// `DistributionalSGBT::diagnostics`.
#[derive(Debug, Clone)]
pub struct ModelDiagnostics {
    /// All tree diagnostics: location trees first, then scale trees.
    pub trees: Vec<TreeDiagnostic>,
    /// Diagnostics for the location (mu) chain only.
    pub location_trees: Vec<TreeDiagnostic>,
    /// Diagnostics for the scale (log-sigma) chain only.
    pub scale_trees: Vec<TreeDiagnostic>,
    /// Per-feature count of trees that split on the feature.
    pub feature_split_counts: Vec<usize>,
    /// Base (intercept) of the location model.
    pub location_base: f64,
    /// Base (intercept, log space) of the scale model.
    pub scale_base: f64,
    /// `sqrt` of the EWMA of squared residuals.
    pub empirical_sigma: f64,
    /// Which scale-estimation mode the model runs in.
    pub scale_mode: ScaleMode,
    /// Number of scale trees that grew beyond a single leaf.
    pub scale_trees_active: usize,
    /// Per-feature smoothing bandwidths derived from split thresholds.
    pub auto_bandwidths: Vec<f64>,
    /// Running (Welford) mean of location gradients.
    pub ensemble_grad_mean: f64,
    /// Running (Welford) std of location gradients.
    pub ensemble_grad_std: f64,
}
142
/// A prediction broken down into base values plus per-step contributions,
/// so individual boosting steps can be inspected or attributed.
#[derive(Debug, Clone)]
pub struct DecomposedPrediction {
    /// Location intercept.
    pub location_base: f64,
    /// Scale intercept (log space). In empirical mode this carries
    /// `ln(empirical_sigma)` and the scale contributions are all zero.
    pub scale_base: f64,
    /// Per-step additive contribution to mu (already learning-rate scaled).
    pub location_contributions: Vec<f64>,
    /// Per-step additive contribution to log-sigma (already learning-rate
    /// scaled).
    pub scale_contributions: Vec<f64>,
}
157
158impl DecomposedPrediction {
159 pub fn mu(&self) -> f64 {
161 self.location_base + self.location_contributions.iter().sum::<f64>()
162 }
163
164 pub fn log_sigma(&self) -> f64 {
166 self.scale_base + self.scale_contributions.iter().sum::<f64>()
167 }
168
169 pub fn sigma(&self) -> f64 {
171 crate::math::exp(self.log_sigma()).max(1e-8)
172 }
173}
174
/// Online distributional gradient-boosted trees: predicts a Gaussian
/// `(mu, sigma)` per input and is trained one observation at a time.
pub struct DistributionalSGBT {
    // Hyperparameters the model was constructed with.
    config: SGBTConfig,
    // Boosting chain for the location (mu) model.
    location_steps: Vec<BoostingStep>,
    // Boosting chain for the scale (log-sigma) model (TreeChain mode).
    scale_steps: Vec<BoostingStep>,
    // Location intercept, set to the mean of the warm-up targets.
    location_base: f64,
    // Scale intercept (log space), set to ln(std of warm-up targets).
    scale_base: f64,
    // True once the warm-up window has produced the base values.
    base_initialized: bool,
    // Warm-up target buffer; cleared after base initialization.
    initial_targets: Vec<f64>,
    // How many warm-up targets to collect before initializing the bases.
    initial_target_count: usize,
    // Total observations passed to `train_one` (including warm-up).
    samples_seen: u64,
    // PRNG state for subsampling decisions (seeded from the config).
    rng_state: u64,
    // Enables the sigma-ratio learning-rate modulation.
    uncertainty_modulated_lr: bool,
    // Slow EWMA of sigma used as the modulation reference.
    rolling_sigma_mean: f64,
    // Empirical vs tree-chain scale estimation.
    scale_mode: ScaleMode,
    // EWMA of squared residuals (empirical variance estimate).
    ewma_sq_err: f64,
    // EWMA coefficient for `ewma_sq_err` and `sigma_velocity`.
    empirical_sigma_alpha: f64,
    // Previous sigma, for the sigma-velocity (derivative) term.
    prev_sigma: f64,
    // EWMA of sigma changes (derivative term of the LR modulation).
    sigma_velocity: f64,
    // Per-feature smoothing bandwidths derived from split thresholds.
    auto_bandwidths: Vec<f64>,
    // Snapshot of summed tree-replacement counters, used to detect when
    // the bandwidths need recomputing.
    last_replacement_sum: u64,
    // Welford running mean of location gradients.
    ensemble_grad_mean: f64,
    // Welford running M2 (sum of squared deviations) of gradients.
    ensemble_grad_m2: f64,
    // Welford sample count of gradients.
    ensemble_grad_count: u64,
    // Optional packed fast-inference cache for the location model.
    packed_cache: Option<PackedInferenceCache>,
    // Samples trained since the packed cache was last rebuilt.
    samples_since_refresh: u64,
    // Rebuild cadence for the packed cache; 0 disables it.
    packed_refresh_interval: u64,
}
259
260impl Clone for DistributionalSGBT {
261 fn clone(&self) -> Self {
262 Self {
263 config: self.config.clone(),
264 location_steps: self.location_steps.clone(),
265 scale_steps: self.scale_steps.clone(),
266 location_base: self.location_base,
267 scale_base: self.scale_base,
268 base_initialized: self.base_initialized,
269 initial_targets: self.initial_targets.clone(),
270 initial_target_count: self.initial_target_count,
271 samples_seen: self.samples_seen,
272 rng_state: self.rng_state,
273 uncertainty_modulated_lr: self.uncertainty_modulated_lr,
274 rolling_sigma_mean: self.rolling_sigma_mean,
275 scale_mode: self.scale_mode,
276 ewma_sq_err: self.ewma_sq_err,
277 empirical_sigma_alpha: self.empirical_sigma_alpha,
278 prev_sigma: self.prev_sigma,
279 sigma_velocity: self.sigma_velocity,
280 auto_bandwidths: self.auto_bandwidths.clone(),
281 last_replacement_sum: self.last_replacement_sum,
282 ensemble_grad_mean: self.ensemble_grad_mean,
283 ensemble_grad_m2: self.ensemble_grad_m2,
284 ensemble_grad_count: self.ensemble_grad_count,
285 packed_cache: self.packed_cache.clone(),
286 samples_since_refresh: self.samples_since_refresh,
287 packed_refresh_interval: self.packed_refresh_interval,
288 }
289 }
290}
291
292impl core::fmt::Debug for DistributionalSGBT {
293 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
294 let mut s = f.debug_struct("DistributionalSGBT");
295 s.field("n_steps", &self.location_steps.len())
296 .field("samples_seen", &self.samples_seen)
297 .field("location_base", &self.location_base)
298 .field("scale_mode", &self.scale_mode)
299 .field("base_initialized", &self.base_initialized);
300 match self.scale_mode {
301 ScaleMode::Empirical => {
302 s.field("empirical_sigma", &crate::math::sqrt(self.ewma_sq_err));
303 }
304 ScaleMode::TreeChain => {
305 s.field("scale_base", &self.scale_base);
306 }
307 }
308 if self.uncertainty_modulated_lr {
309 s.field("rolling_sigma_mean", &self.rolling_sigma_mean);
310 }
311 s.finish()
312 }
313}
314
315impl DistributionalSGBT {
316 pub fn new(config: SGBTConfig) -> Self {
322 let leaf_decay_alpha = config
323 .leaf_half_life
324 .map(|hl| crate::math::exp(-(crate::math::ln(2.0_f64)) / hl as f64));
325
326 let tree_config = TreeConfig::new()
327 .max_depth(config.max_depth)
328 .n_bins(config.n_bins)
329 .lambda(config.lambda)
330 .gamma(config.gamma)
331 .grace_period(config.grace_period)
332 .delta(config.delta)
333 .feature_subsample_rate(config.feature_subsample_rate)
334 .leaf_decay_alpha_opt(leaf_decay_alpha)
335 .split_reeval_interval_opt(config.split_reeval_interval)
336 .feature_types_opt(config.feature_types.clone())
337 .gradient_clip_sigma_opt(config.gradient_clip_sigma)
338 .monotone_constraints_opt(config.monotone_constraints.clone())
339 .max_leaf_output_opt(config.max_leaf_output)
340 .adaptive_depth_opt(config.adaptive_depth)
341 .min_hessian_sum_opt(config.min_hessian_sum)
342 .leaf_model_type(config.leaf_model_type.clone());
343
344 let max_tree_samples = config.max_tree_samples;
345
346 let shadow_warmup = config.shadow_warmup.unwrap_or(0);
348 let location_steps: Vec<BoostingStep> = (0..config.n_steps)
349 .map(|i| {
350 let mut tc = tree_config.clone();
351 tc.seed = config.seed ^ (i as u64);
352 let detector = config.drift_detector.create();
353 if shadow_warmup > 0 {
354 BoostingStep::new_with_graduated(tc, detector, max_tree_samples, shadow_warmup)
355 } else {
356 BoostingStep::new_with_max_samples(tc, detector, max_tree_samples)
357 }
358 })
359 .collect();
360
361 let scale_steps: Vec<BoostingStep> = (0..config.n_steps)
363 .map(|i| {
364 let mut tc = tree_config.clone();
365 tc.seed = config.seed ^ (i as u64) ^ 0x0005_CA1E_0000_0000;
366 let detector = config.drift_detector.create();
367 if shadow_warmup > 0 {
368 BoostingStep::new_with_graduated(tc, detector, max_tree_samples, shadow_warmup)
369 } else {
370 BoostingStep::new_with_max_samples(tc, detector, max_tree_samples)
371 }
372 })
373 .collect();
374
375 let seed = config.seed;
376 let initial_target_count = config.initial_target_count;
377 let uncertainty_modulated_lr = config.uncertainty_modulated_lr;
378 let scale_mode = config.scale_mode;
379 let empirical_sigma_alpha = config.empirical_sigma_alpha;
380 let packed_refresh_interval = config.packed_refresh_interval;
381 Self {
382 config,
383 location_steps,
384 scale_steps,
385 location_base: 0.0,
386 scale_base: 0.0,
387 base_initialized: false,
388 initial_targets: Vec::new(),
389 initial_target_count,
390 samples_seen: 0,
391 rng_state: seed,
392 uncertainty_modulated_lr,
393 rolling_sigma_mean: 1.0, scale_mode,
395 ewma_sq_err: 1.0, empirical_sigma_alpha,
397 prev_sigma: 0.0,
398 sigma_velocity: 0.0,
399 auto_bandwidths: Vec::new(),
400 last_replacement_sum: 0,
401 ensemble_grad_mean: 0.0,
402 ensemble_grad_m2: 0.0,
403 ensemble_grad_count: 0,
404 packed_cache: None,
405 samples_since_refresh: 0,
406 packed_refresh_interval,
407 }
408 }
409
410 pub fn train_one(&mut self, sample: &impl Observation) {
412 self.samples_seen += 1;
413 let target = sample.target();
414 let features = sample.features();
415
416 if !self.base_initialized {
418 self.initial_targets.push(target);
419 if self.initial_targets.len() >= self.initial_target_count {
420 let sum: f64 = self.initial_targets.iter().sum();
422 let mean = sum / self.initial_targets.len() as f64;
423 self.location_base = mean;
424
425 let var: f64 = self
427 .initial_targets
428 .iter()
429 .map(|&y| (y - mean) * (y - mean))
430 .sum::<f64>()
431 / self.initial_targets.len() as f64;
432 let initial_std = crate::math::sqrt(var).max(1e-6);
433 self.scale_base = crate::math::ln(initial_std);
434
435 self.rolling_sigma_mean = initial_std;
437 self.ewma_sq_err = var.max(1e-12);
438
439 self.prev_sigma = initial_std;
441 self.sigma_velocity = 0.0;
442
443 self.base_initialized = true;
444 self.initial_targets.clear();
445 self.initial_targets.shrink_to_fit();
446 }
447 return;
448 }
449
450 match self.scale_mode {
451 ScaleMode::Empirical => self.train_one_empirical(target, features),
452 ScaleMode::TreeChain => self.train_one_tree_chain(target, features),
453 }
454
455 self.refresh_bandwidths();
457 }
458
    /// One training pass in `ScaleMode::Empirical`: sigma comes from an
    /// EWMA of squared residuals rather than from the scale tree chain.
    fn train_one_empirical(&mut self, target: f64, features: &[f64]) {
        // First pass: current model prediction, used only for the residual.
        let mut mu = self.location_base;
        for s in 0..self.location_steps.len() {
            mu += self.config.learning_rate * self.location_steps[s].predict(features);
        }

        // Update the empirical variance estimate from this residual.
        let err = target - mu;
        let alpha = self.empirical_sigma_alpha;
        self.ewma_sq_err = (1.0 - alpha) * self.ewma_sq_err + alpha * err * err;
        let empirical_sigma = crate::math::sqrt(self.ewma_sq_err).max(1e-8);

        let sigma_ratio = if self.uncertainty_modulated_lr {
            // PD-style modulation: react to sigma's level (P) and its rate
            // of change (D) relative to a slow-moving reference.
            let d_sigma = empirical_sigma - self.prev_sigma;
            self.prev_sigma = empirical_sigma;

            self.sigma_velocity = (1.0 - alpha) * self.sigma_velocity + alpha * d_sigma;

            // Derivative gain scaled by the velocity's relative magnitude.
            let k_d = if self.rolling_sigma_mean > 1e-12 {
                crate::math::abs(self.sigma_velocity) / self.rolling_sigma_mean
            } else {
                0.0
            };

            let pd_sigma = empirical_sigma + k_d * self.sigma_velocity;
            let ratio = (pd_sigma / self.rolling_sigma_mean).clamp(0.1, 10.0);

            // Very slow EWMA so the reference tracks regime-level sigma.
            const SIGMA_EWMA_ALPHA: f64 = 0.001;
            self.rolling_sigma_mean = (1.0 - SIGMA_EWMA_ALPHA) * self.rolling_sigma_mean
                + SIGMA_EWMA_ALPHA * empirical_sigma;

            ratio
        } else {
            1.0
        };

        let base_lr = self.config.learning_rate;

        // Second pass: boosting proper. Each step's gradient is taken at
        // the partial prediction accumulated so far (chained boosting).
        let mut mu_accum = self.location_base;
        for s in 0..self.location_steps.len() {
            let (g_mu, h_mu) = self.location_gradient(mu_accum, target);
            self.update_ensemble_grad_stats(g_mu);
            let train_count = self.config.variant.train_count(h_mu, &mut self.rng_state);
            let loc_pred =
                self.location_steps[s].train_and_predict(features, g_mu, h_mu, train_count);
            mu_accum += (base_lr * sigma_ratio) * loc_pred;
        }

        self.maybe_refresh_packed_cache();
    }
520
    /// One training pass in `ScaleMode::TreeChain`: the location and scale
    /// chains are trained jointly, with the scale gradient `1 - z^2`
    /// matching the Gaussian NLL derivative with respect to log-sigma.
    fn train_one_tree_chain(&mut self, target: f64, features: &[f64]) {
        let mut mu = self.location_base;
        let mut log_sigma = self.scale_base;

        let sigma_ratio = if self.uncertainty_modulated_lr {
            // Same PD-style modulation as the empirical path, but driven
            // by the model's own sigma at the scale base.
            let current_sigma = crate::math::exp(log_sigma).max(1e-8);

            let d_sigma = current_sigma - self.prev_sigma;
            self.prev_sigma = current_sigma;

            let alpha = self.empirical_sigma_alpha;
            self.sigma_velocity = (1.0 - alpha) * self.sigma_velocity + alpha * d_sigma;

            // Derivative gain scaled by the velocity's relative magnitude.
            let k_d = if self.rolling_sigma_mean > 1e-12 {
                crate::math::abs(self.sigma_velocity) / self.rolling_sigma_mean
            } else {
                0.0
            };

            let pd_sigma = current_sigma + k_d * self.sigma_velocity;
            let ratio = (pd_sigma / self.rolling_sigma_mean).clamp(0.1, 10.0);

            // Very slow EWMA so the reference tracks regime-level sigma.
            const SIGMA_EWMA_ALPHA: f64 = 0.001;
            self.rolling_sigma_mean = (1.0 - SIGMA_EWMA_ALPHA) * self.rolling_sigma_mean
                + SIGMA_EWMA_ALPHA * current_sigma;

            ratio
        } else {
            1.0
        };

        let base_lr = self.config.learning_rate;

        for s in 0..self.location_steps.len() {
            // Standardized residual at the current partial prediction.
            let sigma = crate::math::exp(log_sigma).max(1e-8);
            let z = (target - mu) / sigma;

            let (g_mu, h_mu) = self.location_gradient(mu, target);
            self.update_ensemble_grad_stats(g_mu);

            // Scale gradient/hessian; the hessian clamp keeps the scale
            // updates well-conditioned near z = 0.
            let g_sigma = 1.0 - z * z;
            let h_sigma = (2.0 * z * z).clamp(0.01, 100.0);

            let train_count = self.config.variant.train_count(h_mu, &mut self.rng_state);

            let loc_pred =
                self.location_steps[s].train_and_predict(features, g_mu, h_mu, train_count);
            mu += (base_lr * sigma_ratio) * loc_pred;

            let scale_pred =
                self.scale_steps[s].train_and_predict(features, g_sigma, h_sigma, train_count);
            log_sigma += base_lr * scale_pred;
        }

        // Keep the empirical residual EWMA current too — it feeds the
        // Huber gradient threshold and the diagnostics.
        let err = target - mu;
        let alpha = self.empirical_sigma_alpha;
        self.ewma_sq_err = (1.0 - alpha) * self.ewma_sq_err + alpha * err * err;

        self.maybe_refresh_packed_cache();
    }
595
596 pub fn predict(&self, features: &[f64]) -> GaussianPrediction {
605 let mu = if let Some(ref cache) = self.packed_cache {
607 let features_f32: Vec<f32> = features.iter().map(|&v| v as f32).collect();
608 match crate::EnsembleView::from_bytes(&cache.bytes) {
609 Ok(view) => {
610 let packed_mu = cache.base + view.predict(&features_f32) as f64;
611 if packed_mu.is_finite() {
612 packed_mu
613 } else {
614 self.predict_full_trees(features)
615 }
616 }
617 Err(_) => self.predict_full_trees(features),
618 }
619 } else {
620 self.predict_full_trees(features)
621 };
622
623 let (sigma, log_sigma) = match self.scale_mode {
624 ScaleMode::Empirical => {
625 let s = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
626 (s, crate::math::ln(s))
627 }
628 ScaleMode::TreeChain => {
629 let mut ls = self.scale_base;
630 if self.auto_bandwidths.is_empty() {
631 for s in 0..self.scale_steps.len() {
632 ls += self.config.learning_rate * self.scale_steps[s].predict(features);
633 }
634 } else {
635 for s in 0..self.scale_steps.len() {
636 ls += self.config.learning_rate
637 * self.scale_steps[s]
638 .predict_smooth_auto(features, &self.auto_bandwidths);
639 }
640 }
641 (crate::math::exp(ls).max(1e-8), ls)
642 }
643 };
644
645 GaussianPrediction {
646 mu,
647 sigma,
648 log_sigma,
649 }
650 }
651
652 fn predict_full_trees(&self, features: &[f64]) -> f64 {
654 let mut mu = self.location_base;
655 if self.auto_bandwidths.is_empty() {
656 for s in 0..self.location_steps.len() {
657 mu += self.config.learning_rate * self.location_steps[s].predict(features);
658 }
659 } else {
660 for s in 0..self.location_steps.len() {
661 mu += self.config.learning_rate
662 * self.location_steps[s].predict_smooth_auto(features, &self.auto_bandwidths);
663 }
664 }
665 mu
666 }
667
668 pub fn predict_smooth(&self, features: &[f64], bandwidth: f64) -> GaussianPrediction {
678 let mut mu = self.location_base;
679 for s in 0..self.location_steps.len() {
680 mu += self.config.learning_rate
681 * self.location_steps[s].predict_smooth(features, bandwidth);
682 }
683
684 let (sigma, log_sigma) = match self.scale_mode {
685 ScaleMode::Empirical => {
686 let s = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
687 (s, crate::math::ln(s))
688 }
689 ScaleMode::TreeChain => {
690 let mut ls = self.scale_base;
691 for s in 0..self.scale_steps.len() {
692 ls += self.config.learning_rate
693 * self.scale_steps[s].predict_smooth(features, bandwidth);
694 }
695 (crate::math::exp(ls).max(1e-8), ls)
696 }
697 };
698
699 GaussianPrediction {
700 mu,
701 sigma,
702 log_sigma,
703 }
704 }
705
706 pub fn predict_interpolated(&self, features: &[f64]) -> GaussianPrediction {
711 let mut mu = self.location_base;
712 for s in 0..self.location_steps.len() {
713 mu += self.config.learning_rate * self.location_steps[s].predict_interpolated(features);
714 }
715
716 let (sigma, log_sigma) = match self.scale_mode {
717 ScaleMode::Empirical => {
718 let s = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
719 (s, crate::math::ln(s))
720 }
721 ScaleMode::TreeChain => {
722 let mut ls = self.scale_base;
723 for s in 0..self.scale_steps.len() {
724 ls += self.config.learning_rate
725 * self.scale_steps[s].predict_interpolated(features);
726 }
727 (crate::math::exp(ls).max(1e-8), ls)
728 }
729 };
730
731 GaussianPrediction {
732 mu,
733 sigma,
734 log_sigma,
735 }
736 }
737
738 pub fn predict_sibling_interpolated(&self, features: &[f64]) -> GaussianPrediction {
744 let mut mu = self.location_base;
745 for s in 0..self.location_steps.len() {
746 mu += self.config.learning_rate
747 * self.location_steps[s]
748 .predict_sibling_interpolated(features, &self.auto_bandwidths);
749 }
750
751 let (sigma, log_sigma) = match self.scale_mode {
752 ScaleMode::Empirical => {
753 let s = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
754 (s, crate::math::ln(s))
755 }
756 ScaleMode::TreeChain => {
757 let mut ls = self.scale_base;
758 for s in 0..self.scale_steps.len() {
759 ls += self.config.learning_rate
760 * self.scale_steps[s]
761 .predict_sibling_interpolated(features, &self.auto_bandwidths);
762 }
763 (crate::math::exp(ls).max(1e-8), ls)
764 }
765 };
766
767 GaussianPrediction {
768 mu,
769 sigma,
770 log_sigma,
771 }
772 }
773
774 pub fn predict_graduated(&self, features: &[f64]) -> GaussianPrediction {
779 let mut mu = self.location_base;
780 for s in 0..self.location_steps.len() {
781 mu += self.config.learning_rate * self.location_steps[s].predict_graduated(features);
782 }
783
784 let (sigma, log_sigma) = match self.scale_mode {
785 ScaleMode::Empirical => {
786 let s = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
787 (s, crate::math::ln(s))
788 }
789 ScaleMode::TreeChain => {
790 let mut ls = self.scale_base;
791 for s in 0..self.scale_steps.len() {
792 ls +=
793 self.config.learning_rate * self.scale_steps[s].predict_graduated(features);
794 }
795 (crate::math::exp(ls).max(1e-8), ls)
796 }
797 };
798
799 GaussianPrediction {
800 mu,
801 sigma,
802 log_sigma,
803 }
804 }
805
806 pub fn predict_graduated_sibling_interpolated(&self, features: &[f64]) -> GaussianPrediction {
808 let mut mu = self.location_base;
809 for s in 0..self.location_steps.len() {
810 mu += self.config.learning_rate
811 * self.location_steps[s]
812 .predict_graduated_sibling_interpolated(features, &self.auto_bandwidths);
813 }
814
815 let (sigma, log_sigma) = match self.scale_mode {
816 ScaleMode::Empirical => {
817 let s = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
818 (s, crate::math::ln(s))
819 }
820 ScaleMode::TreeChain => {
821 let mut ls = self.scale_base;
822 for s in 0..self.scale_steps.len() {
823 ls += self.config.learning_rate
824 * self.scale_steps[s].predict_graduated_sibling_interpolated(
825 features,
826 &self.auto_bandwidths,
827 );
828 }
829 (crate::math::exp(ls).max(1e-8), ls)
830 }
831 };
832
833 GaussianPrediction {
834 mu,
835 sigma,
836 log_sigma,
837 }
838 }
839
840 pub fn predict_distributional(&self, features: &[f64]) -> (f64, f64, f64) {
849 let pred = self.predict(features);
850 let sigma_ratio = if self.uncertainty_modulated_lr {
851 (pred.sigma / self.rolling_sigma_mean).clamp(0.1, 10.0)
852 } else {
853 1.0
854 };
855 (pred.mu, pred.sigma, sigma_ratio)
856 }
857
    /// Current empirical sigma: sqrt of the residual EWMA. Unlike the
    /// prediction paths, this accessor applies no lower clamp.
    #[inline]
    pub fn empirical_sigma(&self) -> f64 {
        crate::math::sqrt(self.ewma_sq_err)
    }

    /// Active scale-estimation mode.
    #[inline]
    pub fn scale_mode(&self) -> ScaleMode {
        self.scale_mode
    }

    /// EWMA of recent sigma changes (the derivative term of the
    /// uncertainty-modulated learning rate).
    #[inline]
    pub fn sigma_velocity(&self) -> f64 {
        self.sigma_velocity
    }

    /// Mean only (see `predict`).
    #[inline]
    pub fn predict_mu(&self, features: &[f64]) -> f64 {
        self.predict(features).mu
    }

    /// Sigma only (see `predict`).
    #[inline]
    pub fn predict_sigma(&self, features: &[f64]) -> f64 {
        self.predict(features).sigma
    }

    /// Symmetric interval `mu ± confidence * sigma`.
    ///
    /// NOTE(review): `confidence` is used directly as a z-score multiplier,
    /// not as a coverage probability — pass e.g. 1.96 for ~95% coverage.
    pub fn predict_interval(&self, features: &[f64], confidence: f64) -> (f64, f64) {
        let pred = self.predict(features);
        (pred.lower(confidence), pred.upper(confidence))
    }
904
905 pub fn predict_batch(&self, feature_matrix: &[Vec<f64>]) -> Vec<GaussianPrediction> {
907 feature_matrix.iter().map(|f| self.predict(f)).collect()
908 }
909
910 pub fn train_batch<O: Observation>(&mut self, samples: &[O]) {
912 for sample in samples {
913 self.train_one(sample);
914 }
915 }
916
917 pub fn train_batch_with_callback<O: Observation, F: FnMut(usize)>(
919 &mut self,
920 samples: &[O],
921 interval: usize,
922 mut callback: F,
923 ) {
924 let interval = interval.max(1);
925 for (i, sample) in samples.iter().enumerate() {
926 self.train_one(sample);
927 if (i + 1) % interval == 0 {
928 callback(i + 1);
929 }
930 }
931 let total = samples.len();
932 if total % interval != 0 {
933 callback(total);
934 }
935 }
936
937 #[inline]
942 fn location_gradient(&self, mu: f64, target: f64) -> (f64, f64) {
943 if let Some(k) = self.config.huber_k {
944 let delta = k * crate::math::sqrt(self.ewma_sq_err).max(1e-8);
945 let residual = mu - target;
946 if crate::math::abs(residual) <= delta {
947 (residual, 1.0)
948 } else {
949 (delta * residual.signum(), 1e-6)
950 }
951 } else {
952 (mu - target, 1.0)
953 }
954 }
955
    /// Folds one gradient into the running Welford accumulators
    /// (numerically stable single-pass mean and variance).
    #[inline]
    fn update_ensemble_grad_stats(&mut self, gradient: f64) {
        self.ensemble_grad_count += 1;
        let delta = gradient - self.ensemble_grad_mean;
        self.ensemble_grad_mean += delta / self.ensemble_grad_count as f64;
        let delta2 = gradient - self.ensemble_grad_mean;
        // M2 accumulates the sum of squared deviations from the running mean.
        self.ensemble_grad_m2 += delta * delta2;
    }
965
966 pub fn ensemble_grad_std(&self) -> f64 {
968 if self.ensemble_grad_count < 2 {
969 return 0.0;
970 }
971 crate::math::fmax(
972 crate::math::sqrt(self.ensemble_grad_m2 / (self.ensemble_grad_count - 1) as f64),
973 0.0,
974 )
975 }
976
    /// Running (Welford) mean of all location gradients seen so far.
    pub fn ensemble_grad_mean(&self) -> f64 {
        self.ensemble_grad_mean
    }
981
982 fn maybe_refresh_packed_cache(&mut self) {
984 if self.packed_refresh_interval > 0 {
985 self.samples_since_refresh += 1;
986 if self.samples_since_refresh >= self.packed_refresh_interval {
987 self.refresh_packed_cache();
988 self.samples_since_refresh = 0;
989 }
990 }
991 }
992
    /// Rebuilds the packed inference cache.
    ///
    /// NOTE(review): the body is empty in this build — presumably the real
    /// packing lives behind a feature gate or elsewhere; confirm before
    /// relying on `packed_cache` being populated through this path.
    fn refresh_packed_cache(&mut self) {
    }
1004
1005 pub fn enable_packed_cache(&mut self, interval: u64) {
1010 self.packed_refresh_interval = interval;
1011 self.samples_since_refresh = 0;
1012 if interval > 0 && self.base_initialized {
1013 self.refresh_packed_cache();
1014 } else if interval == 0 {
1015 self.packed_cache = None;
1016 }
1017 }
1018
    /// True when a packed inference cache is currently available.
    #[inline]
    pub fn has_packed_cache(&self) -> bool {
        self.packed_cache.is_some()
    }
1024
1025 fn refresh_bandwidths(&mut self) {
1027 let current_sum: u64 = self
1028 .location_steps
1029 .iter()
1030 .chain(self.scale_steps.iter())
1031 .map(|s| s.slot().replacements())
1032 .sum();
1033 if current_sum != self.last_replacement_sum || self.auto_bandwidths.is_empty() {
1034 self.auto_bandwidths = self.compute_auto_bandwidths();
1035 self.last_replacement_sum = current_sum;
1036 }
1037 }
1038
1039 fn compute_auto_bandwidths(&self) -> Vec<f64> {
1041 const K: f64 = 2.0;
1042
1043 let n_features = self
1044 .location_steps
1045 .iter()
1046 .chain(self.scale_steps.iter())
1047 .filter_map(|s| s.slot().active_tree().n_features())
1048 .max()
1049 .unwrap_or(0);
1050
1051 if n_features == 0 {
1052 return Vec::new();
1053 }
1054
1055 let mut all_thresholds: Vec<Vec<f64>> = vec![Vec::new(); n_features];
1056
1057 for step in self.location_steps.iter().chain(self.scale_steps.iter()) {
1058 let tree_thresholds = step
1059 .slot()
1060 .active_tree()
1061 .collect_split_thresholds_per_feature();
1062 for (i, ts) in tree_thresholds.into_iter().enumerate() {
1063 if i < n_features {
1064 all_thresholds[i].extend(ts);
1065 }
1066 }
1067 }
1068
1069 let n_bins = self.config.n_bins as f64;
1070
1071 all_thresholds
1072 .iter()
1073 .map(|ts| {
1074 if ts.is_empty() {
1075 return f64::INFINITY;
1076 }
1077
1078 let mut sorted = ts.clone();
1079 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
1080 sorted.dedup_by(|a, b| crate::math::abs(*a - *b) < 1e-15);
1081
1082 if sorted.len() < 2 {
1083 return f64::INFINITY;
1084 }
1085
1086 let mut gaps: Vec<f64> = sorted.windows(2).map(|w| w[1] - w[0]).collect();
1087
1088 if sorted.len() < 3 {
1089 let range = sorted.last().unwrap() - sorted.first().unwrap();
1090 if range < 1e-15 {
1091 return f64::INFINITY;
1092 }
1093 return (range / n_bins) * K;
1094 }
1095
1096 gaps.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
1097 let median_gap = if gaps.len() % 2 == 0 {
1098 (gaps[gaps.len() / 2 - 1] + gaps[gaps.len() / 2]) / 2.0
1099 } else {
1100 gaps[gaps.len() / 2]
1101 };
1102
1103 if median_gap < 1e-15 {
1104 f64::INFINITY
1105 } else {
1106 median_gap * K
1107 }
1108 })
1109 .collect()
1110 }
1111
    /// Current per-feature auto-bandwidths (empty until first computed).
    pub fn auto_bandwidths(&self) -> &[f64] {
        &self.auto_bandwidths
    }
1116
    /// Returns the model to its freshly-constructed state: resets every
    /// step, zeroes all learned bases and running statistics, and drops
    /// the packed cache. The configuration (including the packed-cache
    /// refresh interval) is kept.
    pub fn reset(&mut self) {
        for step in &mut self.location_steps {
            step.reset();
        }
        for step in &mut self.scale_steps {
            step.reset();
        }
        self.location_base = 0.0;
        self.scale_base = 0.0;
        self.base_initialized = false;
        self.initial_targets.clear();
        self.samples_seen = 0;
        // Re-seed the PRNG so a reset model matches a newly built one.
        self.rng_state = self.config.seed;
        self.rolling_sigma_mean = 1.0;
        self.ewma_sq_err = 1.0;
        self.prev_sigma = 0.0;
        self.sigma_velocity = 0.0;
        self.auto_bandwidths.clear();
        self.last_replacement_sum = 0;
        self.ensemble_grad_mean = 0.0;
        self.ensemble_grad_m2 = 0.0;
        self.ensemble_grad_count = 0;
        self.packed_cache = None;
        self.samples_since_refresh = 0;
    }
1143
    /// Total observations passed to `train_one`, including warm-up samples.
    #[inline]
    pub fn n_samples_seen(&self) -> u64 {
        self.samples_seen
    }

    /// Number of boosting steps per chain.
    #[inline]
    pub fn n_steps(&self) -> usize {
        self.location_steps.len()
    }
1156 pub fn n_trees(&self) -> usize {
1158 let loc = self.location_steps.len()
1159 + self
1160 .location_steps
1161 .iter()
1162 .filter(|s| s.has_alternate())
1163 .count();
1164 let scale = self.scale_steps.len()
1165 + self
1166 .scale_steps
1167 .iter()
1168 .filter(|s| s.has_alternate())
1169 .count();
1170 loc + scale
1171 }
1172
1173 pub fn total_leaves(&self) -> usize {
1175 let loc: usize = self.location_steps.iter().map(|s| s.n_leaves()).sum();
1176 let scale: usize = self.scale_steps.iter().map(|s| s.n_leaves()).sum();
1177 loc + scale
1178 }
1179
    /// True once the warm-up window has set the base values.
    #[inline]
    pub fn is_initialized(&self) -> bool {
        self.base_initialized
    }

    /// The configuration this model was built with.
    #[inline]
    pub fn config(&self) -> &SGBTConfig {
        &self.config
    }

    /// The location (mu) boosting steps.
    pub fn location_steps(&self) -> &[BoostingStep] {
        &self.location_steps
    }

    /// Location intercept (mean of the warm-up targets; 0.0 before init).
    #[inline]
    pub fn location_base(&self) -> f64 {
        self.location_base
    }

    /// Configured base learning rate.
    #[inline]
    pub fn learning_rate(&self) -> f64 {
        self.config.learning_rate
    }

    /// Slow EWMA of sigma used as the learning-rate modulation reference.
    #[inline]
    pub fn rolling_sigma_mean(&self) -> f64 {
        self.rolling_sigma_mean
    }

    /// Whether the uncertainty-modulated learning rate is enabled.
    #[inline]
    pub fn is_uncertainty_modulated(&self) -> bool {
        self.uncertainty_modulated_lr
    }
1222
    /// Builds a full diagnostic snapshot of the model: per-tree summaries
    /// for both chains, aggregate per-feature split counts, and the scalar
    /// training state.
    pub fn diagnostics(&self) -> ModelDiagnostics {
        let n = self.location_steps.len();
        let mut trees = Vec::with_capacity(2 * n);
        let mut feature_split_counts: Vec<usize> = Vec::new();

        // Summarizes each step's active tree into a `TreeDiagnostic` and
        // accumulates per-feature split counts across all trees.
        fn collect_tree_diags(
            steps: &[BoostingStep],
            trees: &mut Vec<TreeDiagnostic>,
            feature_split_counts: &mut Vec<usize>,
        ) {
            for step in steps {
                let slot = step.slot();
                let tree = slot.active_tree();
                let arena = tree.arena();

                // Leaf values and sample counts, in arena order.
                let leaf_values: Vec<f64> = (0..arena.is_leaf.len())
                    .filter(|&i| arena.is_leaf[i])
                    .map(|i| arena.leaf_value[i])
                    .collect();

                let leaf_sample_counts: Vec<u64> = (0..arena.is_leaf.len())
                    .filter(|&i| arena.is_leaf[i])
                    .map(|i| arena.sample_count[i])
                    .collect();

                let max_depth_reached = (0..arena.is_leaf.len())
                    .filter(|&i| arena.is_leaf[i])
                    .map(|i| arena.depth[i] as usize)
                    .max()
                    .unwrap_or(0);

                // (min, max, mean, std) over the leaf values.
                let leaf_weight_stats = if leaf_values.is_empty() {
                    (0.0, 0.0, 0.0, 0.0)
                } else {
                    let min = leaf_values.iter().cloned().fold(f64::INFINITY, f64::min);
                    let max = leaf_values
                        .iter()
                        .cloned()
                        .fold(f64::NEG_INFINITY, f64::max);
                    let sum: f64 = leaf_values.iter().sum();
                    let mean = sum / leaf_values.len() as f64;
                    let var: f64 = leaf_values
                        .iter()
                        .map(|v| crate::math::powi(v - mean, 2))
                        .sum::<f64>()
                        / leaf_values.len() as f64;
                    (min, max, mean, crate::math::sqrt(var))
                };

                // Features with positive accumulated gain were split on.
                let gains = slot.split_gains();
                let split_features: Vec<usize> = gains
                    .iter()
                    .enumerate()
                    .filter(|(_, &g)| g > 0.0)
                    .map(|(i, _)| i)
                    .collect();

                if !gains.is_empty() {
                    // Lazily size the shared counter from the first tree
                    // that reports gains.
                    if feature_split_counts.is_empty() {
                        feature_split_counts.resize(gains.len(), 0);
                    }
                    for &fi in &split_features {
                        if fi < feature_split_counts.len() {
                            feature_split_counts[fi] += 1;
                        }
                    }
                }

                trees.push(TreeDiagnostic {
                    n_leaves: leaf_values.len(),
                    max_depth_reached,
                    samples_seen: step.n_samples_seen(),
                    leaf_weight_stats,
                    split_features,
                    leaf_sample_counts,
                    prediction_mean: slot.prediction_mean(),
                    prediction_std: slot.prediction_std(),
                });
            }
        }

        collect_tree_diags(&self.location_steps, &mut trees, &mut feature_split_counts);
        collect_tree_diags(&self.scale_steps, &mut trees, &mut feature_split_counts);

        // `trees` holds the location diagnostics first, then the scale ones.
        let location_trees = trees[..n].to_vec();
        let scale_trees = trees[n..].to_vec();
        // A scale tree with a single leaf never split, i.e. is inactive.
        let scale_trees_active = scale_trees.iter().filter(|t| t.n_leaves > 1).count();

        ModelDiagnostics {
            trees,
            location_trees,
            scale_trees,
            feature_split_counts,
            location_base: self.location_base,
            scale_base: self.scale_base,
            empirical_sigma: crate::math::sqrt(self.ewma_sq_err),
            scale_mode: self.scale_mode,
            scale_trees_active,
            auto_bandwidths: self.auto_bandwidths.clone(),
            ensemble_grad_mean: self.ensemble_grad_mean,
            ensemble_grad_std: self.ensemble_grad_std(),
        }
    }
1337
1338 pub fn predict_decomposed(&self, features: &[f64]) -> DecomposedPrediction {
1350 let lr = self.config.learning_rate;
1351 let location: Vec<f64> = self
1352 .location_steps
1353 .iter()
1354 .map(|s| lr * s.predict(features))
1355 .collect();
1356
1357 let (sb, scale) = match self.scale_mode {
1358 ScaleMode::Empirical => {
1359 let empirical_sigma = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
1360 (
1361 crate::math::ln(empirical_sigma),
1362 vec![0.0; self.location_steps.len()],
1363 )
1364 }
1365 ScaleMode::TreeChain => {
1366 let s: Vec<f64> = self
1367 .scale_steps
1368 .iter()
1369 .map(|s| lr * s.predict(features))
1370 .collect();
1371 (self.scale_base, s)
1372 }
1373 };
1374
1375 DecomposedPrediction {
1376 location_base: self.location_base,
1377 scale_base: sb,
1378 location_contributions: location,
1379 scale_contributions: scale,
1380 }
1381 }
1382
1383 pub fn feature_importances(&self) -> Vec<f64> {
1389 let mut totals: Vec<f64> = Vec::new();
1390 for steps in [&self.location_steps, &self.scale_steps] {
1391 for step in steps {
1392 let gains = step.slot().split_gains();
1393 if totals.is_empty() && !gains.is_empty() {
1394 totals.resize(gains.len(), 0.0);
1395 }
1396 for (i, &g) in gains.iter().enumerate() {
1397 if i < totals.len() {
1398 totals[i] += g;
1399 }
1400 }
1401 }
1402 }
1403 let sum: f64 = totals.iter().sum();
1404 if sum > 0.0 {
1405 totals.iter_mut().for_each(|v| *v /= sum);
1406 }
1407 totals
1408 }
1409
1410 pub fn feature_importances_split(&self) -> (Vec<f64>, Vec<f64>) {
1415 fn aggregate(steps: &[BoostingStep]) -> Vec<f64> {
1416 let mut totals: Vec<f64> = Vec::new();
1417 for step in steps {
1418 let gains = step.slot().split_gains();
1419 if totals.is_empty() && !gains.is_empty() {
1420 totals.resize(gains.len(), 0.0);
1421 }
1422 for (i, &g) in gains.iter().enumerate() {
1423 if i < totals.len() {
1424 totals[i] += g;
1425 }
1426 }
1427 }
1428 let sum: f64 = totals.iter().sum();
1429 if sum > 0.0 {
1430 totals.iter_mut().for_each(|v| *v /= sum);
1431 }
1432 totals
1433 }
1434 (
1435 aggregate(&self.location_steps),
1436 aggregate(&self.scale_steps),
1437 )
1438 }
1439
    /// Captures the model into a serializable `DistributionalModelState`.
    ///
    /// Each boosting step is snapshotted together with its optional
    /// alternate (background) tree and the serialized drift-detector
    /// state for both, so a restored model resumes drift tracking where
    /// it left off. Transient runtime caches are intentionally excluded.
    #[cfg(feature = "_serde_support")]
    pub fn to_distributional_state(&self) -> crate::serde_support::DistributionalModelState {
        use super::snapshot_tree;
        use crate::serde_support::{DistributionalModelState, StepSnapshot};

        // Snapshot a single step: active tree, optional alternate tree,
        // and the drift-detector states (alternate state only if a
        // detector exists and produces one).
        fn snapshot_step(step: &BoostingStep) -> StepSnapshot {
            let slot = step.slot();
            let tree_snap = snapshot_tree(slot.active_tree());
            let alt_snap = slot.alternate_tree().map(snapshot_tree);
            let drift_state = slot.detector().serialize_state();
            let alt_drift_state = slot.alt_detector().and_then(|d| d.serialize_state());
            StepSnapshot {
                tree: tree_snap,
                alternate_tree: alt_snap,
                drift_state,
                alt_drift_state,
            }
        }

        DistributionalModelState {
            config: self.config.clone(),
            location_steps: self.location_steps.iter().map(snapshot_step).collect(),
            scale_steps: self.scale_steps.iter().map(snapshot_step).collect(),
            location_base: self.location_base,
            scale_base: self.scale_base,
            base_initialized: self.base_initialized,
            initial_targets: self.initial_targets.clone(),
            initial_target_count: self.initial_target_count,
            samples_seen: self.samples_seen,
            rng_state: self.rng_state,
            uncertainty_modulated_lr: self.uncertainty_modulated_lr,
            rolling_sigma_mean: self.rolling_sigma_mean,
            ewma_sq_err: self.ewma_sq_err,
        }
    }
1480
    /// Rebuilds a model from a serialized `DistributionalModelState`.
    ///
    /// Reconstructs every per-step tree slot (active tree, optional
    /// alternate tree, drift detectors) from its snapshot. Runtime fields
    /// that are not serialized (sigma velocity, auto bandwidths, ensemble
    /// gradient stats, packed inference cache) restart from defaults.
    #[cfg(feature = "_serde_support")]
    pub fn from_distributional_state(
        state: crate::serde_support::DistributionalModelState,
    ) -> Self {
        use super::rebuild_tree;
        use crate::ensemble::replacement::TreeSlot;
        use crate::serde_support::StepSnapshot;

        // Per-sample leaf decay factor derived from the configured
        // half-life: alpha = exp(-ln(2) / half_life), so leaf statistics
        // halve every `leaf_half_life` samples.
        let leaf_decay_alpha = state
            .config
            .leaf_half_life
            .map(|hl| crate::math::exp((-(crate::math::ln(2.0_f64)) / hl as f64)));
        let max_tree_samples = state.config.max_tree_samples;

        // Shared tree configuration; the per-step seed is mixed in below.
        let base_tree_config = TreeConfig::new()
            .max_depth(state.config.max_depth)
            .n_bins(state.config.n_bins)
            .lambda(state.config.lambda)
            .gamma(state.config.gamma)
            .grace_period(state.config.grace_period)
            .delta(state.config.delta)
            .feature_subsample_rate(state.config.feature_subsample_rate)
            .leaf_decay_alpha_opt(leaf_decay_alpha)
            .split_reeval_interval_opt(state.config.split_reeval_interval)
            .feature_types_opt(state.config.feature_types.clone())
            .gradient_clip_sigma_opt(state.config.gradient_clip_sigma)
            .monotone_constraints_opt(state.config.monotone_constraints.clone())
            .adaptive_depth_opt(state.config.adaptive_depth)
            .leaf_model_type(state.config.leaf_model_type.clone());

        // Rebuild one chain of boosting steps. `seed_xor` decorrelates the
        // scale chain's tree seeds from the location chain's.
        let rebuild_steps = |snaps: &[StepSnapshot], seed_xor: u64| -> Vec<BoostingStep> {
            snaps
                .iter()
                .enumerate()
                .map(|(i, snap)| {
                    // Same seed derivation as original construction: base
                    // seed xor step index xor chain tag.
                    let tc = base_tree_config
                        .clone()
                        .seed(state.config.seed ^ (i as u64) ^ seed_xor);

                    let active = rebuild_tree(&snap.tree, tc.clone());
                    let alternate = snap
                        .alternate_tree
                        .as_ref()
                        .map(|s| rebuild_tree(s, tc.clone()));

                    // Fresh detector, then restore its serialized state
                    // (if the snapshot carried one).
                    let mut detector = state.config.drift_detector.create();
                    if let Some(ref ds) = snap.drift_state {
                        detector.restore_state(ds);
                    }
                    let mut slot =
                        TreeSlot::from_trees(active, alternate, tc, detector, max_tree_samples);
                    // The alternate detector is created by `from_trees`;
                    // restore its state only when both exist.
                    if let Some(ref ads) = snap.alt_drift_state {
                        if let Some(alt_det) = slot.alt_detector_mut() {
                            alt_det.restore_state(ads);
                        }
                    }
                    BoostingStep::from_slot(slot)
                })
                .collect()
        };

        let location_steps = rebuild_steps(&state.location_steps, 0);
        let scale_steps = rebuild_steps(&state.scale_steps, 0x0005_CA1E_0000_0000);

        let scale_mode = state.config.scale_mode;
        let empirical_sigma_alpha = state.config.empirical_sigma_alpha;
        let packed_refresh_interval = state.config.packed_refresh_interval;
        Self {
            config: state.config,
            location_steps,
            scale_steps,
            location_base: state.location_base,
            scale_base: state.scale_base,
            base_initialized: state.base_initialized,
            initial_targets: state.initial_targets,
            initial_target_count: state.initial_target_count,
            samples_seen: state.samples_seen,
            rng_state: state.rng_state,
            uncertainty_modulated_lr: state.uncertainty_modulated_lr,
            rolling_sigma_mean: state.rolling_sigma_mean,
            scale_mode,
            ewma_sq_err: state.ewma_sq_err,
            empirical_sigma_alpha,
            // Transient fields below are not serialized; they are rebuilt
            // lazily as new samples arrive.
            prev_sigma: 0.0,
            sigma_velocity: 0.0,
            auto_bandwidths: Vec::new(),
            last_replacement_sum: 0,
            ensemble_grad_mean: 0.0,
            ensemble_grad_m2: 0.0,
            ensemble_grad_count: 0,
            packed_cache: None,
            samples_since_refresh: 0,
            packed_refresh_interval,
        }
    }
1583}
1584
1585use crate::learner::StreamingLearner;
1590
/// Adapter so `DistributionalSGBT` plugs in wherever a generic
/// `StreamingLearner` is expected; point predictions collapse the
/// Gaussian output to its mean.
impl StreamingLearner for DistributionalSGBT {
    fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
        // Wrap the raw triple in a `SampleRef` and delegate to the
        // inherent (distributional) training path.
        let sample = SampleRef::weighted(features, target, weight);
        DistributionalSGBT::train_one(self, &sample);
    }

    fn predict(&self, features: &[f64]) -> f64 {
        // Return only `mu`; callers needing sigma use the inherent API.
        DistributionalSGBT::predict(self, features).mu
    }

    fn n_samples_seen(&self) -> u64 {
        self.samples_seen
    }

    fn reset(&mut self) {
        DistributionalSGBT::reset(self);
    }
}
1611
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::format;
    use alloc::vec;
    use alloc::vec::Vec;

    // Baseline config shared by most tests: small ensemble, quick warm-up.
    fn test_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .build()
            .unwrap()
    }

    // An untrained model predicts mu == 0 but still reports positive sigma.
    #[test]
    fn fresh_model_predicts_zero() {
        let model = DistributionalSGBT::new(test_config());
        let pred = model.predict(&[1.0, 2.0, 3.0]);
        assert!(pred.mu.abs() < 1e-12);
        assert!(pred.sigma > 0.0);
    }

    #[test]
    fn sigma_always_positive() {
        let mut model = DistributionalSGBT::new(test_config());

        // Train on a clean linear relation.
        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 0.5], x * 2.0 + 1.0));
        }

        // Sigma must stay strictly positive and finite across inputs.
        for i in 0..20 {
            let x = i as f64 * 0.5;
            let pred = model.predict(&[x, x * 0.5]);
            assert!(
                pred.sigma > 0.0,
                "sigma must be positive, got {}",
                pred.sigma
            );
            assert!(pred.sigma.is_finite(), "sigma must be finite");
        }
    }

    #[test]
    fn constant_target_has_small_sigma() {
        let mut model = DistributionalSGBT::new(test_config());

        // A constant target: errors collapse, sigma should stay sane.
        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 2.0], 5.0));
        }

        let pred = model.predict(&[1.0, 2.0]);
        assert!(pred.mu.is_finite());
        assert!(pred.sigma.is_finite());
        assert!(pred.sigma > 0.0);
    }

    #[test]
    fn noisy_target_has_finite_predictions() {
        let mut model = DistributionalSGBT::new(test_config());

        // xorshift64 for deterministic pseudo-noise in [-1, 1).
        let mut rng: u64 = 42;
        for i in 0..200 {
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            let noise = (rng % 1000) as f64 / 500.0 - 1.0;
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0 + noise));
        }

        let pred = model.predict(&[5.0]);
        assert!(pred.mu.is_finite());
        assert!(pred.sigma.is_finite());
        assert!(pred.sigma > 0.0);
    }

    #[test]
    fn predict_interval_bounds_correct() {
        let mut model = DistributionalSGBT::new(test_config());

        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0));
        }

        let (lo, hi) = model.predict_interval(&[5.0], 1.96);
        let pred = model.predict(&[5.0]);

        // Interval must bracket mu with total width 2 * z * sigma.
        assert!(lo < pred.mu, "lower bound should be < mu");
        assert!(hi > pred.mu, "upper bound should be > mu");
        assert!((hi - lo - 2.0 * 1.96 * pred.sigma).abs() < 1e-10);
    }

    #[test]
    fn batch_prediction_matches_individual() {
        let mut model = DistributionalSGBT::new(test_config());

        for i in 0..100 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 2.0], x));
        }

        let features = vec![vec![1.0, 2.0], vec![3.0, 6.0], vec![5.0, 10.0]];
        let batch = model.predict_batch(&features);

        // Batch path must agree with one-at-a-time prediction exactly.
        for (feat, batch_pred) in features.iter().zip(batch.iter()) {
            let individual = model.predict(feat);
            assert!((batch_pred.mu - individual.mu).abs() < 1e-12);
            assert!((batch_pred.sigma - individual.sigma).abs() < 1e-12);
        }
    }

    #[test]
    fn reset_clears_state() {
        let mut model = DistributionalSGBT::new(test_config());

        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0));
        }

        assert!(model.n_samples_seen() > 0);
        model.reset();

        assert_eq!(model.n_samples_seen(), 0);
        assert!(!model.is_initialized());
    }

    // Pure struct test: interval helpers on `GaussianPrediction`.
    #[test]
    fn gaussian_prediction_lower_upper() {
        let pred = GaussianPrediction {
            mu: 10.0,
            sigma: 2.0,
            log_sigma: 2.0_f64.ln(),
        };

        assert!((pred.lower(1.96) - (10.0 - 1.96 * 2.0)).abs() < 1e-10);
        assert!((pred.upper(1.96) - (10.0 + 1.96 * 2.0)).abs() < 1e-10);
    }

    #[test]
    fn train_batch_works() {
        let mut model = DistributionalSGBT::new(test_config());
        let samples: Vec<(Vec<f64>, f64)> = (0..100)
            .map(|i| {
                let x = i as f64 * 0.1;
                (vec![x], x * 2.0)
            })
            .collect();

        model.train_batch(&samples);
        assert_eq!(model.n_samples_seen(), 100);
    }

    #[test]
    fn debug_format_works() {
        let model = DistributionalSGBT::new(test_config());
        let debug = format!("{:?}", model);
        assert!(debug.contains("DistributionalSGBT"));
    }

    // Location + scale chains together: at least 2 * n_steps trees.
    #[test]
    fn n_trees_counts_both_ensembles() {
        let model = DistributionalSGBT::new(test_config());
        assert!(model.n_trees() >= 20);
    }

    // Same as `test_config` but with uncertainty-modulated LR enabled.
    fn modulated_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .uncertainty_modulated_lr(true)
            .build()
            .unwrap()
    }

    #[test]
    fn sigma_modulated_initializes_rolling_mean() {
        let mut model = DistributionalSGBT::new(modulated_config());
        assert!(model.is_uncertainty_modulated());

        // Rolling sigma mean starts at the neutral value 1.0.
        assert!((model.rolling_sigma_mean() - 1.0).abs() < 1e-12);

        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 0.5], x * 2.0 + 1.0));
        }

        assert!(model.rolling_sigma_mean() > 0.0);
        assert!(model.rolling_sigma_mean().is_finite());
    }

    #[test]
    fn predict_distributional_returns_sigma_ratio() {
        let mut model = DistributionalSGBT::new(modulated_config());

        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0 + 1.0));
        }

        // sigma_ratio = local sigma / rolling mean; should stay bounded.
        let (mu, sigma, sigma_ratio) = model.predict_distributional(&[5.0]);
        assert!(mu.is_finite());
        assert!(sigma > 0.0);
        assert!(
            (0.1..=10.0).contains(&sigma_ratio),
            "sigma_ratio={}",
            sigma_ratio
        );
    }

    #[test]
    fn predict_distributional_without_modulation_returns_one() {
        let mut model = DistributionalSGBT::new(test_config());
        assert!(!model.is_uncertainty_modulated());

        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0));
        }

        let (_mu, _sigma, sigma_ratio) = model.predict_distributional(&[5.0]);
        assert!(
            (sigma_ratio - 1.0).abs() < 1e-12,
            "should be 1.0 when disabled"
        );
    }

    #[test]
    fn modulated_model_sigma_finite_under_varying_noise() {
        let mut model = DistributionalSGBT::new(modulated_config());

        // Noise regime changes halfway through (scale 1.0 -> 5.0).
        let mut rng: u64 = 123;
        for i in 0..500 {
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            let noise = (rng % 1000) as f64 / 100.0 - 5.0;
            let x = i as f64 * 0.1;
            let scale = if i < 250 { 1.0 } else { 5.0 };
            model.train_one(&(vec![x], x * 2.0 + noise * scale));
        }

        let pred = model.predict(&[10.0]);
        assert!(pred.mu.is_finite());
        assert!(pred.sigma.is_finite());
        assert!(pred.sigma > 0.0);
        assert!(model.rolling_sigma_mean().is_finite());
    }

    #[test]
    fn reset_clears_rolling_sigma_mean() {
        let mut model = DistributionalSGBT::new(modulated_config());

        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0));
        }

        let sigma_before = model.rolling_sigma_mean();
        assert!(sigma_before > 0.0);

        // Reset restores the neutral initial value 1.0.
        model.reset();
        assert!((model.rolling_sigma_mean() - 1.0).abs() < 1e-12);
    }

    #[test]
    fn streaming_learner_returns_mu() {
        let mut model = DistributionalSGBT::new(test_config());
        for i in 0..200 {
            let x = i as f64 * 0.1;
            StreamingLearner::train(&mut model, &[x], x * 2.0 + 1.0);
        }
        // Trait `predict` must be exactly the Gaussian mean.
        let pred = StreamingLearner::predict(&model, &[5.0]);
        let gaussian = DistributionalSGBT::predict(&model, &[5.0]);
        assert!(
            (pred - gaussian.mu).abs() < 1e-12,
            "StreamingLearner::predict should return mu"
        );
    }

    // Shared fixture: a model trained on 500 three-feature samples.
    fn trained_model() -> DistributionalSGBT {
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);
        for i in 0..500 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 0.5, (i % 3) as f64], x * 2.0 + 1.0));
        }
        model
    }

    #[test]
    fn diagnostics_returns_correct_tree_count() {
        let model = trained_model();
        let diag = model.diagnostics();
        assert_eq!(diag.trees.len(), 20, "should have 2*n_steps trees");
    }

    #[test]
    fn diagnostics_trees_have_leaves() {
        let model = trained_model();
        let diag = model.diagnostics();
        for (i, tree) in diag.trees.iter().enumerate() {
            assert!(tree.n_leaves >= 1, "tree {i} should have at least 1 leaf");
        }
        let total_samples: u64 = diag.trees.iter().map(|t| t.samples_seen).sum();
        assert!(
            total_samples > 0,
            "at least some trees should have seen samples"
        );
    }

    #[test]
    fn diagnostics_leaf_weight_stats_finite() {
        let model = trained_model();
        let diag = model.diagnostics();
        // leaf_weight_stats is (min, max, mean, std).
        for (i, tree) in diag.trees.iter().enumerate() {
            let (min, max, mean, std) = tree.leaf_weight_stats;
            assert!(min.is_finite(), "tree {i} min not finite");
            assert!(max.is_finite(), "tree {i} max not finite");
            assert!(mean.is_finite(), "tree {i} mean not finite");
            assert!(std.is_finite(), "tree {i} std not finite");
            assert!(min <= max, "tree {i} min > max");
        }
    }

    #[test]
    fn diagnostics_base_predictions_match() {
        let model = trained_model();
        let diag = model.diagnostics();
        // Loose sanity bound only: the base is an offset, not a prediction.
        assert!(
            (diag.location_base - model.predict(&[0.0, 0.0, 0.0]).mu).abs() < 100.0,
            "location_base should be plausible"
        );
    }

    #[test]
    fn predict_decomposed_reconstructs_prediction() {
        let model = trained_model();
        let features = [5.0, 2.5, 1.0];
        let pred = model.predict(&features);
        let decomp = model.predict_decomposed(&features);

        // Summing base + contributions must reproduce predict() exactly.
        assert!(
            (decomp.mu() - pred.mu).abs() < 1e-10,
            "decomposed mu ({}) != predict mu ({})",
            decomp.mu(),
            pred.mu
        );
        assert!(
            (decomp.sigma() - pred.sigma).abs() < 1e-10,
            "decomposed sigma ({}) != predict sigma ({})",
            decomp.sigma(),
            pred.sigma
        );
    }

    #[test]
    fn predict_decomposed_correct_lengths() {
        let model = trained_model();
        let decomp = model.predict_decomposed(&[1.0, 0.5, 0.0]);
        assert_eq!(
            decomp.location_contributions.len(),
            model.n_steps(),
            "location contributions should have n_steps entries"
        );
        assert_eq!(
            decomp.scale_contributions.len(),
            model.n_steps(),
            "scale contributions should have n_steps entries"
        );
    }

    #[test]
    fn feature_importances_work() {
        let model = trained_model();
        let imp = model.feature_importances();
        for (i, &v) in imp.iter().enumerate() {
            assert!(v >= 0.0, "importance {i} should be non-negative, got {v}");
            assert!(v.is_finite(), "importance {i} should be finite");
        }
        let sum: f64 = imp.iter().sum();
        if sum > 0.0 {
            assert!(
                (sum - 1.0).abs() < 1e-10,
                "non-zero importances should sum to 1.0, got {sum}"
            );
        }
    }

    #[test]
    fn feature_importances_split_works() {
        let model = trained_model();
        let (loc_imp, scale_imp) = model.feature_importances_split();
        for (name, imp) in [("location", &loc_imp), ("scale", &scale_imp)] {
            let sum: f64 = imp.iter().sum();
            if sum > 0.0 {
                assert!(
                    (sum - 1.0).abs() < 1e-10,
                    "{name} importances should sum to 1.0, got {sum}"
                );
            }
            for &v in imp.iter() {
                assert!(v >= 0.0 && v.is_finite());
            }
        }
    }

    // Default scale mode is Empirical (EWMA of squared errors).
    #[test]
    fn empirical_sigma_default_mode() {
        use crate::ensemble::config::ScaleMode;
        let config = test_config();
        let model = DistributionalSGBT::new(config);
        assert_eq!(model.scale_mode(), ScaleMode::Empirical);
    }

    #[test]
    fn empirical_sigma_tracks_errors() {
        let mut model = DistributionalSGBT::new(test_config());

        // Clean regime first.
        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 0.5], x * 2.0 + 1.0));
        }

        let sigma_clean = model.empirical_sigma();
        assert!(sigma_clean > 0.0, "sigma should be positive");
        assert!(sigma_clean.is_finite(), "sigma should be finite");

        // Then inject large noise; empirical sigma must rise.
        let mut rng: u64 = 42;
        for i in 200..400 {
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            let noise = (rng % 10000) as f64 / 100.0 - 50.0;
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 0.5], x * 2.0 + noise));
        }

        let sigma_noisy = model.empirical_sigma();
        assert!(
            sigma_noisy > sigma_clean,
            "noisy regime should increase sigma: clean={sigma_clean}, noisy={sigma_noisy}"
        );
    }

    #[test]
    fn empirical_sigma_modulated_lr_adapts() {
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .uncertainty_modulated_lr(true)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        for i in 0..300 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0 + 1.0));
        }

        let (_, _, sigma_ratio) = model.predict_distributional(&[5.0]);
        assert!(sigma_ratio.is_finite());
        assert!(
            (0.1..=10.0).contains(&sigma_ratio),
            "sigma_ratio={sigma_ratio}"
        );
    }

    #[test]
    fn tree_chain_mode_trains_scale_trees() {
        use crate::ensemble::config::ScaleMode;
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .scale_mode(ScaleMode::TreeChain)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);
        assert_eq!(model.scale_mode(), ScaleMode::TreeChain);

        for i in 0..500 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 0.5, (i % 3) as f64], x * 2.0 + 1.0));
        }

        let pred = model.predict(&[5.0, 2.5, 1.0]);
        assert!(pred.mu.is_finite());
        assert!(pred.sigma > 0.0);
        assert!(pred.sigma.is_finite());
    }

    #[test]
    fn diagnostics_shows_empirical_sigma() {
        let model = trained_model();
        let diag = model.diagnostics();
        assert!(
            diag.empirical_sigma > 0.0,
            "empirical_sigma should be positive"
        );
        assert!(
            diag.empirical_sigma.is_finite(),
            "empirical_sigma should be finite"
        );
    }

    #[test]
    fn diagnostics_scale_trees_split_fields() {
        let model = trained_model();
        let diag = model.diagnostics();
        assert_eq!(diag.location_trees.len(), model.n_steps());
        assert_eq!(diag.scale_trees.len(), model.n_steps());
    }

    #[test]
    fn reset_clears_empirical_sigma() {
        let mut model = DistributionalSGBT::new(test_config());
        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0));
        }
        model.reset();
        // Empirical sigma returns to its neutral initial value 1.0.
        assert!((model.empirical_sigma() - 1.0).abs() < 1e-12);
    }

    #[test]
    fn predict_smooth_returns_finite() {
        let config = SGBTConfig::builder()
            .n_steps(3)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        for i in 0..200 {
            let x = (i as f64) * 0.1;
            let features = vec![x, x.sin()];
            let target = 2.0 * x + 1.0;
            model.train_one(&(features, target));
        }

        // 0.5 is the smoothing bandwidth argument.
        let pred = model.predict_smooth(&[1.0, 1.0_f64.sin()], 0.5);
        assert!(pred.mu.is_finite(), "smooth mu should be finite");
        assert!(pred.sigma.is_finite(), "smooth sigma should be finite");
        assert!(pred.sigma > 0.0, "smooth sigma should be positive");
    }

    #[test]
    fn sigma_velocity_responds_to_error_spike() {
        let config = SGBTConfig::builder()
            .n_steps(3)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .uncertainty_modulated_lr(true)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        // Stable regime to settle sigma.
        for i in 0..200 {
            let x = (i as f64) * 0.1;
            model.train_one(&(vec![x, x.sin()], 2.0 * x + 1.0));
        }

        let velocity_before = model.sigma_velocity();

        // Abrupt target shift: errors spike, sigma rises quickly.
        for i in 0..50 {
            let x = (i as f64) * 0.1;
            model.train_one(&(vec![x, x.sin()], 100.0 * x + 50.0));
        }

        let velocity_after = model.sigma_velocity();

        assert!(
            velocity_after > velocity_before,
            "sigma velocity should increase after error spike: before={}, after={}",
            velocity_before,
            velocity_after,
        );
    }

    #[test]
    fn sigma_velocity_getter_works() {
        let config = SGBTConfig::builder()
            .n_steps(2)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .build()
            .unwrap();
        let model = DistributionalSGBT::new(config);
        // Untrained model: no sigma movement yet.
        assert_eq!(model.sigma_velocity(), 0.0);
    }

    #[test]
    fn diagnostics_leaf_sample_counts_populated() {
        let config = SGBTConfig::builder()
            .n_steps(3)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        for i in 0..200 {
            let x = (i as f64) * 0.1;
            let features = vec![x, x.sin()];
            let target = 2.0 * x + 1.0;
            model.train_one(&(features, target));
        }

        let diags = model.diagnostics();
        for (ti, tree) in diags.trees.iter().enumerate() {
            // One count per leaf, and counts must reflect routed samples.
            assert_eq!(
                tree.leaf_sample_counts.len(),
                tree.n_leaves,
                "tree {} should have sample count per leaf",
                ti,
            );
            if tree.samples_seen > 0 {
                let total: u64 = tree.leaf_sample_counts.iter().sum();
                assert!(
                    total > 0,
                    "tree {} has {} samples_seen but leaf counts sum to 0",
                    ti,
                    tree.samples_seen,
                );
            }
        }
    }

    // Packed inference cache is opt-in; default interval is 0 (disabled).
    #[test]
    fn packed_cache_disabled_by_default() {
        let model = DistributionalSGBT::new(test_config());
        assert!(!model.has_packed_cache());
        assert_eq!(model.config().packed_refresh_interval, 0);
    }

    #[test]
    #[cfg(feature = "_packed_cache_tests_disabled")]
    fn packed_cache_refreshes_after_interval() {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .grace_period(5)
            .max_depth(3)
            .n_bins(8)
            .initial_target_count(10)
            .packed_refresh_interval(20)
            .build()
            .unwrap();

        let mut model = DistributionalSGBT::new(config);

        // 40 samples > refresh interval of 20, so the cache must build.
        for i in 0..40 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 2.0, x * 0.5], x * 3.0));
        }

        assert!(
            model.has_packed_cache(),
            "packed cache should exist after training past refresh interval"
        );

        let pred = model.predict(&[2.0, 4.0, 1.0]);
        assert!(pred.mu.is_finite(), "mu should be finite: {}", pred.mu);
    }

    #[test]
    #[cfg(feature = "_packed_cache_tests_disabled")]
    fn packed_cache_matches_full_tree() {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .grace_period(5)
            .max_depth(3)
            .n_bins(8)
            .initial_target_count(10)
            .build()
            .unwrap();

        let mut model = DistributionalSGBT::new(config);

        for i in 0..80 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 2.0, x * 0.5], x * 3.0));
        }

        // Baseline: full-tree prediction without the cache.
        assert!(!model.has_packed_cache());
        let full_pred = model.predict(&[2.0, 4.0, 1.0]);

        model.enable_packed_cache(10);
        assert!(model.has_packed_cache());
        let cached_pred = model.predict(&[2.0, 4.0, 1.0]);

        // mu may differ slightly (cache stores reduced precision).
        let mu_diff = (full_pred.mu - cached_pred.mu).abs();
        assert!(
            mu_diff < 0.1,
            "packed cache mu ({}) should match full tree mu ({}) within f32 tolerance, diff={}",
            cached_pred.mu,
            full_pred.mu,
            mu_diff
        );

        // Sigma bypasses the cache, so it must be bit-identical.
        assert!(
            (full_pred.sigma - cached_pred.sigma).abs() < 1e-12,
            "sigma should be identical: full={}, cached={}",
            full_pred.sigma,
            cached_pred.sigma
        );
    }
}
2412
2413#[cfg(test)]
2414#[cfg(feature = "_serde_support")]
2415mod serde_tests {
2416 use super::*;
2417 use crate::SGBTConfig;
2418
2419 fn make_trained_distributional() -> DistributionalSGBT {
2420 let config = SGBTConfig::builder()
2421 .n_steps(5)
2422 .learning_rate(0.1)
2423 .max_depth(3)
2424 .grace_period(2)
2425 .initial_target_count(10)
2426 .build()
2427 .unwrap();
2428 let mut model = DistributionalSGBT::new(config);
2429 for i in 0..50 {
2430 let x = i as f64 * 0.1;
2431 model.train_one(&(vec![x], x.sin()));
2432 }
2433 model
2434 }
2435
2436 #[test]
2437 fn json_round_trip_preserves_predictions() {
2438 let model = make_trained_distributional();
2439 let state = model.to_distributional_state();
2440 let json = crate::serde_support::save_distributional_model(&state).unwrap();
2441 let loaded_state = crate::serde_support::load_distributional_model(&json).unwrap();
2442 let restored = DistributionalSGBT::from_distributional_state(loaded_state);
2443
2444 let test_points = [0.5, 1.0, 2.0, 3.0];
2445 for &x in &test_points {
2446 let orig = model.predict(&[x]);
2447 let rest = restored.predict(&[x]);
2448 assert!(
2449 (orig.mu - rest.mu).abs() < 1e-10,
2450 "JSON round-trip mu mismatch at x={}: {} vs {}",
2451 x,
2452 orig.mu,
2453 rest.mu
2454 );
2455 assert!(
2456 (orig.sigma - rest.sigma).abs() < 1e-10,
2457 "JSON round-trip sigma mismatch at x={}: {} vs {}",
2458 x,
2459 orig.sigma,
2460 rest.sigma
2461 );
2462 }
2463 }
2464
2465 #[test]
2466 fn state_preserves_rolling_sigma_mean() {
2467 let config = SGBTConfig::builder()
2468 .n_steps(5)
2469 .learning_rate(0.1)
2470 .max_depth(3)
2471 .grace_period(2)
2472 .initial_target_count(10)
2473 .uncertainty_modulated_lr(true)
2474 .build()
2475 .unwrap();
2476 let mut model = DistributionalSGBT::new(config);
2477 for i in 0..50 {
2478 let x = i as f64 * 0.1;
2479 model.train_one(&(vec![x], x.sin()));
2480 }
2481 let state = model.to_distributional_state();
2482 assert!(state.uncertainty_modulated_lr);
2483 assert!(state.rolling_sigma_mean >= 0.0);
2484
2485 let restored = DistributionalSGBT::from_distributional_state(state);
2486 assert_eq!(model.n_samples_seen(), restored.n_samples_seen());
2487 }
2488
2489 #[test]
2490 fn auto_bandwidth_computed_distributional() {
2491 let config = SGBTConfig::builder()
2492 .n_steps(3)
2493 .learning_rate(0.1)
2494 .grace_period(10)
2495 .initial_target_count(10)
2496 .build()
2497 .unwrap();
2498 let mut model = DistributionalSGBT::new(config);
2499
2500 for i in 0..200 {
2501 let x = (i as f64) * 0.1;
2502 model.train_one(&(vec![x, x.sin()], 2.0 * x + 1.0));
2503 }
2504
2505 let bws = model.auto_bandwidths();
2507 assert_eq!(bws.len(), 2, "should have 2 feature bandwidths");
2508
2509 let diag = model.diagnostics();
2511 assert_eq!(diag.auto_bandwidths.len(), 2);
2512
2513 assert!(diag.location_trees[0].prediction_mean.is_finite());
2515 assert!(diag.location_trees[0].prediction_std.is_finite());
2516
2517 let pred = model.predict(&[1.0, 1.0_f64.sin()]);
2518 assert!(pred.mu.is_finite(), "auto-bandwidth mu should be finite");
2519 assert!(pred.sigma > 0.0, "auto-bandwidth sigma should be positive");
2520 }
2521
2522 #[test]
2523 fn max_leaf_output_clamps_predictions() {
2524 let config = SGBTConfig::builder()
2525 .n_steps(5)
2526 .learning_rate(1.0) .max_leaf_output(0.5)
2528 .build()
2529 .unwrap();
2530 let mut model = DistributionalSGBT::new(config);
2531
2532 for i in 0..200 {
2534 let target = if i % 2 == 0 { 100.0 } else { -100.0 };
2535 let sample = crate::Sample::new(vec![i as f64 % 5.0, (i as f64).sin()], target);
2536 model.train_one(&sample);
2537 }
2538
2539 let pred = model.predict(&[2.0, 0.5]);
2541 assert!(
2542 pred.mu.is_finite(),
2543 "prediction should be finite with clamping"
2544 );
2545 }
2546
2547 #[test]
2548 fn min_hessian_sum_suppresses_fresh_leaves() {
2549 let config = SGBTConfig::builder()
2550 .n_steps(3)
2551 .learning_rate(0.01)
2552 .min_hessian_sum(50.0)
2553 .build()
2554 .unwrap();
2555 let mut model = DistributionalSGBT::new(config);
2556
2557 for i in 0..60 {
2559 let sample = crate::Sample::new(vec![i as f64, (i as f64).sin()], i as f64 * 0.1);
2560 model.train_one(&sample);
2561 }
2562
2563 let pred = model.predict(&[30.0, 0.5]);
2564 assert!(
2565 pred.mu.is_finite(),
2566 "prediction should be finite with min_hessian_sum"
2567 );
2568 }
2569
2570 #[test]
2571 fn predict_interpolated_returns_finite() {
2572 let config = SGBTConfig::builder()
2573 .n_steps(10)
2574 .learning_rate(0.1)
2575 .grace_period(20)
2576 .max_depth(4)
2577 .n_bins(16)
2578 .initial_target_count(10)
2579 .build()
2580 .unwrap();
2581 let mut model = DistributionalSGBT::new(config);
2582 for i in 0..200 {
2583 let x = i as f64 * 0.1;
2584 let sample = crate::Sample::new(vec![x, x.sin()], x.cos());
2585 model.train_one(&sample);
2586 }
2587
2588 let pred = model.predict_interpolated(&[1.0, 0.5]);
2589 assert!(pred.mu.is_finite(), "interpolated mu should be finite");
2590 assert!(pred.sigma > 0.0, "interpolated sigma should be positive");
2591 }
2592
2593 #[test]
2594 fn huber_k_bounds_gradients() {
2595 let config = SGBTConfig::builder()
2596 .n_steps(5)
2597 .learning_rate(0.01)
2598 .huber_k(1.345)
2599 .build()
2600 .unwrap();
2601 let mut model = DistributionalSGBT::new(config);
2602
2603 for i in 0..300 {
2605 let target = if i % 50 == 0 {
2606 1000.0
2607 } else {
2608 (i as f64 * 0.1).sin()
2609 };
2610 let sample = crate::Sample::new(vec![i as f64 % 10.0, (i as f64).cos()], target);
2611 model.train_one(&sample);
2612 }
2613
2614 let pred = model.predict(&[5.0, 0.3]);
2615 assert!(
2616 pred.mu.is_finite(),
2617 "Huber-loss mu should be finite despite outliers"
2618 );
2619 assert!(pred.sigma > 0.0, "sigma should be positive");
2620 }
2621
2622 #[test]
2623 fn ensemble_gradient_stats_populated() {
2624 let config = SGBTConfig::builder()
2625 .n_steps(10)
2626 .learning_rate(0.1)
2627 .grace_period(20)
2628 .max_depth(4)
2629 .n_bins(16)
2630 .initial_target_count(10)
2631 .build()
2632 .unwrap();
2633 let mut model = DistributionalSGBT::new(config);
2634 for i in 0..200 {
2635 let x = i as f64 * 0.1;
2636 let sample = crate::Sample::new(vec![x, x.sin()], x.cos());
2637 model.train_one(&sample);
2638 }
2639
2640 let diag = model.diagnostics();
2641 assert!(
2642 diag.ensemble_grad_mean.is_finite(),
2643 "ensemble grad mean should be finite"
2644 );
2645 assert!(
2646 diag.ensemble_grad_std >= 0.0,
2647 "ensemble grad std should be non-negative"
2648 );
2649 assert!(
2650 diag.ensemble_grad_std.is_finite(),
2651 "ensemble grad std should be finite"
2652 );
2653 }
2654
2655 #[test]
2656 fn huber_k_validation() {
2657 let result = SGBTConfig::builder()
2658 .n_steps(5)
2659 .learning_rate(0.01)
2660 .huber_k(-1.0)
2661 .build();
2662 assert!(result.is_err(), "negative huber_k should fail validation");
2663 }
2664
2665 #[test]
2666 fn max_leaf_output_validation() {
2667 let result = SGBTConfig::builder()
2668 .n_steps(5)
2669 .learning_rate(0.01)
2670 .max_leaf_output(-1.0)
2671 .build();
2672 assert!(
2673 result.is_err(),
2674 "negative max_leaf_output should fail validation"
2675 );
2676 }
2677
2678 #[test]
2679 fn predict_sibling_interpolated_varies_with_features() {
2680 let config = SGBTConfig::builder()
2681 .n_steps(10)
2682 .learning_rate(0.1)
2683 .grace_period(10)
2684 .max_depth(6)
2685 .delta(0.1)
2686 .initial_target_count(10)
2687 .build()
2688 .unwrap();
2689 let mut model = DistributionalSGBT::new(config);
2690
2691 for i in 0..2000 {
2692 let x = (i as f64) * 0.01;
2693 let y = x.sin() * x + 0.5 * (x * 2.0).cos();
2694 let sample = crate::Sample::new(vec![x, x * 0.3], y);
2695 model.train_one(&sample);
2696 }
2697
2698 let pred = model.predict_sibling_interpolated(&[5.0, 1.5]);
2700 assert!(
2701 pred.mu.is_finite(),
2702 "sibling interpolated mu should be finite"
2703 );
2704 assert!(
2705 pred.sigma > 0.0,
2706 "sibling interpolated sigma should be positive"
2707 );
2708
2709 let bws = model.auto_bandwidths();
2711 if bws.iter().any(|&b| b.is_finite()) {
2712 let hard_preds: Vec<f64> = (0..200)
2713 .map(|i| {
2714 let x = i as f64 * 0.1;
2715 model.predict(&[x, x * 0.3]).mu
2716 })
2717 .collect();
2718 let hard_changes = hard_preds
2719 .windows(2)
2720 .filter(|w| (w[0] - w[1]).abs() > f64::EPSILON)
2721 .count();
2722
2723 let preds: Vec<f64> = (0..200)
2724 .map(|i| {
2725 let x = i as f64 * 0.1;
2726 model.predict_sibling_interpolated(&[x, x * 0.3]).mu
2727 })
2728 .collect();
2729
2730 let sibling_changes = preds
2731 .windows(2)
2732 .filter(|w| (w[0] - w[1]).abs() > f64::EPSILON)
2733 .count();
2734 assert!(
2735 sibling_changes >= hard_changes,
2736 "sibling should produce >= hard changes: sibling={}, hard={}",
2737 sibling_changes,
2738 hard_changes
2739 );
2740 }
2741 }
2742}