use crate::ensemble::config::{SGBTConfig, ScaleMode};
use crate::ensemble::step::BoostingStep;
use crate::export_embedded::export_distributional_packed;
use crate::sample::{Observation, SampleRef};
use crate::tree::builder::TreeConfig;

#[derive(Clone)]
struct PackedInferenceCache {
    bytes: Vec<u8>,
    base: f64,
    n_features: usize,
}

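/// A Gaussian predictive distribution: point estimate `mu` plus scale
/// `sigma` (and its log), with `honest_sigma` capturing disagreement between
/// boosting steps. A minimal sketch of deriving a symmetric interval
/// (illustrative values, mirroring the unit tests below):
///
/// ```ignore
/// let pred = GaussianPrediction {
///     mu: 10.0,
///     sigma: 2.0,
///     log_sigma: 2.0_f64.ln(),
///     honest_sigma: 0.0,
/// };
/// // z = 1.96 gives an approximate 95% interval under the Gaussian assumption.
/// let (lo, hi) = (pred.lower(1.96), pred.upper(1.96));
/// assert!((hi - lo - 2.0 * 1.96 * pred.sigma).abs() < 1e-12);
/// ```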
#[derive(Debug, Clone, Copy)]
pub struct GaussianPrediction {
    pub mu: f64,
    pub sigma: f64,
    pub log_sigma: f64,
    pub honest_sigma: f64,
}

impl GaussianPrediction {
    #[inline]
    pub fn lower(&self, z: f64) -> f64 {
        self.mu - z * self.sigma
    }

    #[inline]
    pub fn upper(&self, z: f64) -> f64 {
        self.mu + z * self.sigma
    }
}

#[derive(Debug, Clone)]
pub struct TreeDiagnostic {
    pub n_leaves: usize,
    pub max_depth_reached: usize,
    pub samples_seen: u64,
    pub leaf_weight_stats: (f64, f64, f64, f64),
    pub split_features: Vec<usize>,
    pub leaf_sample_counts: Vec<u64>,
    pub prediction_mean: f64,
    pub prediction_std: f64,
}

#[derive(Debug, Clone)]
pub struct ModelDiagnostics {
    pub trees: Vec<TreeDiagnostic>,
    pub location_trees: Vec<TreeDiagnostic>,
    pub scale_trees: Vec<TreeDiagnostic>,
    pub feature_split_counts: Vec<usize>,
    pub location_base: f64,
    pub scale_base: f64,
    pub empirical_sigma: f64,
    pub scale_mode: ScaleMode,
    pub scale_trees_active: usize,
    pub auto_bandwidths: Vec<f64>,
    pub ensemble_grad_mean: f64,
    pub ensemble_grad_std: f64,
}

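/// Per-step breakdown of a prediction: base offsets plus one additive
/// contribution per boosting step, so predictions can be reconstructed and
/// attributed term by term. A minimal sketch, assuming a trained
/// `model: DistributionalSGBT`:
///
/// ```ignore
/// let d = model.predict_decomposed(&[1.0, 2.0]);
/// // mu() is exactly location_base + the sum of location_contributions.
/// let mu = d.location_base + d.location_contributions.iter().sum::<f64>();
/// assert!((d.mu() - mu).abs() < 1e-12);
/// ```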
#[derive(Debug, Clone)]
pub struct DecomposedPrediction {
    pub location_base: f64,
    pub scale_base: f64,
    pub location_contributions: Vec<f64>,
    pub scale_contributions: Vec<f64>,
}

impl DecomposedPrediction {
    pub fn mu(&self) -> f64 {
        self.location_base + self.location_contributions.iter().sum::<f64>()
    }

    pub fn log_sigma(&self) -> f64 {
        self.scale_base + self.scale_contributions.iter().sum::<f64>()
    }

    pub fn sigma(&self) -> f64 {
        self.log_sigma().exp().max(1e-8)
    }
}

#[derive(Clone)]
pub struct DistributionalSGBT {
    config: SGBTConfig,
    location_steps: Vec<BoostingStep>,
    scale_steps: Vec<BoostingStep>,
    location_base: f64,
    scale_base: f64,
    base_initialized: bool,
    initial_targets: Vec<f64>,
    initial_target_count: usize,
    samples_seen: u64,
    rng_state: u64,
    uncertainty_modulated_lr: bool,
    rolling_sigma_mean: f64,
    scale_mode: ScaleMode,
    ewma_sq_err: f64,
    empirical_sigma_alpha: f64,
    prev_sigma: f64,
    sigma_velocity: f64,
    auto_bandwidths: Vec<f64>,
    last_replacement_sum: u64,
    ensemble_grad_mean: f64,
    ensemble_grad_m2: f64,
    ensemble_grad_count: u64,
    rolling_honest_sigma_mean: f64,
    packed_cache: Option<PackedInferenceCache>,
    samples_since_refresh: u64,
    packed_refresh_interval: u64,
    prev_contributions: Vec<f64>,
    prev_prev_contributions: Vec<f64>,
    cached_residual_alignment: f64,
    cached_reg_sensitivity: f64,
    cached_depth_sufficiency: f64,
    cached_effective_dof: f64,
    contribution_accuracy: Vec<f64>,
    prune_alpha: f64,
}

impl std::fmt::Debug for DistributionalSGBT {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut s = f.debug_struct("DistributionalSGBT");
        s.field("n_steps", &self.location_steps.len())
            .field("samples_seen", &self.samples_seen)
            .field("location_base", &self.location_base)
            .field("scale_mode", &self.scale_mode)
            .field("base_initialized", &self.base_initialized);
        match self.scale_mode {
            ScaleMode::Empirical => {
                s.field("empirical_sigma", &self.ewma_sq_err.sqrt());
            }
            ScaleMode::TreeChain => {
                s.field("scale_base", &self.scale_base);
            }
        }
        if self.uncertainty_modulated_lr {
            s.field("rolling_sigma_mean", &self.rolling_sigma_mean);
        }
        s.finish()
    }
}

impl DistributionalSGBT {
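    /// Builds a model from an [`SGBTConfig`]: one location step and one scale
    /// step per boosting step, each with its own seeded tree config and drift
    /// detector. A minimal construction sketch (parameter values are
    /// illustrative, taken from the unit tests below):
    ///
    /// ```ignore
    /// let config = SGBTConfig::builder()
    ///     .n_steps(10)
    ///     .learning_rate(0.1)
    ///     .grace_period(20)
    ///     .max_depth(4)
    ///     .n_bins(16)
    ///     .initial_target_count(10)
    ///     .build()
    ///     .unwrap();
    /// let model = DistributionalSGBT::new(config);
    /// assert!(!model.is_initialized()); // bases are set only after warm-up targets
    /// ```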
    pub fn new(config: SGBTConfig) -> Self {
        let leaf_decay_alpha = config
            .leaf_half_life
            .map(|hl| (-(2.0_f64.ln()) / hl as f64).exp());

        let tree_config = TreeConfig::new()
            .max_depth(config.max_depth)
            .n_bins(config.n_bins)
            .lambda(config.lambda)
            .gamma(config.gamma)
            .grace_period(config.grace_period)
            .delta(config.delta)
            .feature_subsample_rate(config.feature_subsample_rate)
            .leaf_decay_alpha_opt(leaf_decay_alpha)
            .split_reeval_interval_opt(config.split_reeval_interval)
            .feature_types_opt(config.feature_types.clone())
            .gradient_clip_sigma_opt(config.gradient_clip_sigma)
            .monotone_constraints_opt(config.monotone_constraints.clone())
            .max_leaf_output_opt(config.max_leaf_output)
            .adaptive_depth_opt(config.adaptive_depth)
            .min_hessian_sum_opt(config.min_hessian_sum)
            .leaf_model_type(config.leaf_model_type.clone());

        let max_tree_samples = config.max_tree_samples;

        let shadow_warmup = config.shadow_warmup.unwrap_or(0);
        let location_steps: Vec<BoostingStep> = (0..config.n_steps)
            .map(|i| {
                let mut tc = tree_config.clone();
                tc.seed = config.seed ^ (i as u64);
                let detector = config.drift_detector.create();
                if shadow_warmup > 0 {
                    BoostingStep::new_with_graduated(tc, detector, max_tree_samples, shadow_warmup)
                } else {
                    BoostingStep::new_with_max_samples(tc, detector, max_tree_samples)
                }
            })
            .collect();

        let scale_steps: Vec<BoostingStep> = (0..config.n_steps)
            .map(|i| {
                let mut tc = tree_config.clone();
                tc.seed = config.seed ^ (i as u64) ^ 0x0005_CA1E_0000_0000;
                let detector = config.drift_detector.create();
                if shadow_warmup > 0 {
                    BoostingStep::new_with_graduated(tc, detector, max_tree_samples, shadow_warmup)
                } else {
                    BoostingStep::new_with_max_samples(tc, detector, max_tree_samples)
                }
            })
            .collect();

        let seed = config.seed;
        let initial_target_count = config.initial_target_count;
        let uncertainty_modulated_lr = config.uncertainty_modulated_lr;
        let scale_mode = config.scale_mode;
        let empirical_sigma_alpha = config.empirical_sigma_alpha;
        let packed_refresh_interval = config.packed_refresh_interval;
        let n_steps = config.n_steps;
        let prune_alpha = if config.proactive_prune_interval.is_some() {
            let gp = config.grace_period.max(1) as f64;
            1.0 - (-2.0 / gp).exp()
        } else {
            0.01
        };
        Self {
            config,
            location_steps,
            scale_steps,
            location_base: 0.0,
            scale_base: 0.0,
            base_initialized: false,
            initial_targets: Vec::new(),
            initial_target_count,
            samples_seen: 0,
            rng_state: seed,
            uncertainty_modulated_lr,
            rolling_sigma_mean: 1.0,
            scale_mode,
            ewma_sq_err: 1.0,
            empirical_sigma_alpha,
            prev_sigma: 0.0,
            sigma_velocity: 0.0,
            auto_bandwidths: Vec::new(),
            last_replacement_sum: 0,
            ensemble_grad_mean: 0.0,
            ensemble_grad_m2: 0.0,
            ensemble_grad_count: 0,
            rolling_honest_sigma_mean: 0.0,
            packed_cache: None,
            samples_since_refresh: 0,
            packed_refresh_interval,
            prev_contributions: Vec::new(),
            prev_prev_contributions: Vec::new(),
            cached_residual_alignment: 0.0,
            cached_reg_sensitivity: 0.0,
            cached_depth_sufficiency: 0.0,
            cached_effective_dof: 0.0,
            contribution_accuracy: vec![0.0; n_steps],
            prune_alpha,
        }
    }

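    /// "Honest" uncertainty: the Bessel-corrected standard deviation of the
    /// per-step location contributions `c_i = lr * step_i.predict(x)`, i.e.
    /// `sqrt(n / (n - 1) * (mean(c_i^2) - mean(c_i)^2))`. Strong disagreement
    /// between steps at a point reads as high epistemic uncertainty there.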
    fn compute_honest_sigma(&self, features: &[f64]) -> f64 {
        let n = self.location_steps.len();
        if n <= 1 {
            return 0.0;
        }
        let lr = self.config.learning_rate;
        let mut sum = 0.0_f64;
        let mut sq_sum = 0.0_f64;
        for step in &self.location_steps {
            let c = lr * step.predict(features);
            sum += c;
            sq_sum += c * c;
        }
        let nf = n as f64;
        let mean_c = sum / nf;
        let var = (sq_sum / nf) - (mean_c * mean_c);
        let var_corrected = var * nf / (nf - 1.0);
        var_corrected.max(0.0).sqrt()
    }

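    /// Consumes one observation. The first `initial_target_count` targets
    /// only warm up the base location/scale estimates; afterwards each sample
    /// flows through the training path of the configured [`ScaleMode`].
    /// A minimal streaming sketch (synthetic data, tuple observations as in
    /// the unit tests below; `config` as in [`DistributionalSGBT::new`]):
    ///
    /// ```ignore
    /// let mut model = DistributionalSGBT::new(config);
    /// for i in 0..200 {
    ///     let x = i as f64 * 0.1;
    ///     model.train_one(&(vec![x], 2.0 * x + 1.0));
    /// }
    /// let pred = model.predict(&[5.0]);
    /// assert!(pred.sigma > 0.0);
    /// ```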
    pub fn train_one(&mut self, sample: &impl Observation) {
        self.samples_seen += 1;
        let target = sample.target();
        let features = sample.features();

        if !self.base_initialized {
            self.initial_targets.push(target);
            if self.initial_targets.len() >= self.initial_target_count {
                let sum: f64 = self.initial_targets.iter().sum();
                let mean = sum / self.initial_targets.len() as f64;
                self.location_base = mean;

                let var: f64 = self
                    .initial_targets
                    .iter()
                    .map(|&y| (y - mean) * (y - mean))
                    .sum::<f64>()
                    / self.initial_targets.len() as f64;
                let initial_std = var.sqrt().max(1e-6);
                self.scale_base = initial_std.ln();

                self.rolling_sigma_mean = initial_std;
                self.ewma_sq_err = var.max(1e-12);

                self.prev_sigma = initial_std;
                self.sigma_velocity = 0.0;

                self.base_initialized = true;
                self.initial_targets.clear();
                self.initial_targets.shrink_to_fit();
            }
            return;
        }

        if let Some((base_mts, k)) = self.config.adaptive_mts {
            let sigma_ratio = if self.rolling_honest_sigma_mean > 1e-12 {
                let honest_sigma = self.compute_honest_sigma(features);
                honest_sigma / self.rolling_honest_sigma_mean
            } else {
                1.0
            };
            let effective_mts = (base_mts as f64 / (1.0 + k * sigma_ratio)).max(100.0) as u64;
            for step in &mut self.location_steps {
                step.slot_mut().set_max_tree_samples(Some(effective_mts));
            }
            for step in &mut self.scale_steps {
                step.slot_mut().set_max_tree_samples(Some(effective_mts));
            }
        }

        match self.scale_mode {
            ScaleMode::Empirical => self.train_one_empirical(target, features),
            ScaleMode::TreeChain => self.train_one_tree_chain(target, features),
        }

        if let Some(interval) = self.config.proactive_prune_interval {
            if self.config.accuracy_based_pruning {
                let mut location_pred = self.location_base;
                for step in self.location_steps.iter() {
                    location_pred += self.config.learning_rate * step.predict(features);
                }
                let residual = target - location_pred;
                let sign = residual.signum();
                for (i, step) in self.location_steps.iter().enumerate() {
                    let contribution = self.config.learning_rate * step.predict(features);
                    let alignment = contribution * sign;
                    self.contribution_accuracy[i] = self.prune_alpha * alignment
                        + (1.0 - self.prune_alpha) * self.contribution_accuracy[i];
                }
            }

            if interval > 0 && self.samples_seen % interval == 0 {
                self.check_proactive_prune();
            }
        }

        self.update_diagnostic_cache(features);

        self.refresh_bandwidths();
    }

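    /// Refreshes the cheap per-sample diagnostics. Residual alignment is the
    /// cosine similarity between the current and previous change in the
    /// per-step contribution vector,
    /// `<dc_t, dc_{t-1}> / (||dc_t|| * ||dc_{t-1}||)` with
    /// `dc_t = c_t - c_{t-1}`: values near +1 suggest the steps are still
    /// moving coherently, values near 0 suggest converged or noise-dominated
    /// contributions.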
    fn update_diagnostic_cache(&mut self, features: &[f64]) {
        use crate::tree::node::NodeId;

        let lambda = self.config.lambda;
        let lr = self.config.learning_rate;
        let n_steps = self.location_steps.len();

        let mut contributions = Vec::with_capacity(n_steps);
        for step in &self.location_steps {
            contributions.push(lr * step.predict(features));
        }

        if !self.prev_contributions.is_empty()
            && self.prev_contributions.len() == contributions.len()
            && !self.prev_prev_contributions.is_empty()
            && self.prev_prev_contributions.len() == contributions.len()
        {
            let delta_curr: Vec<f64> = contributions
                .iter()
                .zip(&self.prev_contributions)
                .map(|(a, b)| a - b)
                .collect();
            let delta_prev: Vec<f64> = self
                .prev_contributions
                .iter()
                .zip(&self.prev_prev_contributions)
                .map(|(a, b)| a - b)
                .collect();

            let dot: f64 = delta_curr.iter().zip(&delta_prev).map(|(a, b)| a * b).sum();
            let norm_curr: f64 = delta_curr.iter().map(|x| x * x).sum::<f64>().sqrt();
            let norm_prev: f64 = delta_prev.iter().map(|x| x * x).sum::<f64>().sqrt();
            self.cached_residual_alignment = if norm_curr > 1e-15 && norm_prev > 1e-15 {
                dot / (norm_curr * norm_prev)
            } else {
                0.0
            };
        }
        self.prev_prev_contributions =
            core::mem::replace(&mut self.prev_contributions, contributions);

        let mut total_sensitivity = 0.0;
        let mut total_dof = 0.0;
        let mut leaf_weights: Vec<f64> = Vec::new();
        let mut leaf_within_vars: Vec<f64> = Vec::new();
        let mut n_leaves_total: u64 = 0;

        for step in &self.location_steps {
            let tree = step.slot().active_tree();
            let arena = tree.arena();

            for node_idx in 0..arena.n_nodes() {
                let nid = NodeId(node_idx as u32);
                if arena.is_leaf(nid) {
                    if let Some((g, h)) = tree.leaf_grad_hess(nid) {
                        let denom = h + lambda;
                        if denom.abs() > 1e-15 {
                            total_sensitivity += g.abs() / (denom * denom);
                            total_dof += h / denom;
                            leaf_weights.push(-g / denom);
                            leaf_within_vars.push(1.0 / denom);
                            n_leaves_total += 1;
                        }
                    }
                }
            }
        }

        if n_leaves_total > 0 {
            let n = n_leaves_total as f64;
            self.cached_reg_sensitivity = total_sensitivity / n;
            self.cached_effective_dof = total_dof;

            let mean_weight = leaf_weights.iter().sum::<f64>() / n;
            let between_var = leaf_weights
                .iter()
                .map(|w| (w - mean_weight).powi(2))
                .sum::<f64>()
                / (n - 1.0).max(1.0);
            let within_var = leaf_within_vars.iter().sum::<f64>() / n;
            self.cached_depth_sufficiency = between_var / within_var.max(1e-15);
        }
    }

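    /// Empirical scale mode: sigma is not learned by trees but tracked as an
    /// EWMA of squared residuals,
    /// `ewma_sq_err <- (1 - alpha) * ewma_sq_err + alpha * err^2`, with
    /// `sigma = sqrt(ewma_sq_err)`. When uncertainty-modulated learning is
    /// on, the effective learning rate is scaled by a PD-style ratio of
    /// current sigma (plus a velocity term) to its long-run rolling mean,
    /// clamped to [0.1, 10].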
685
686 fn train_one_empirical(&mut self, target: f64, features: &[f64]) {
688 let mut mu = self.location_base;
690 for s in 0..self.location_steps.len() {
691 mu += self.config.learning_rate * self.location_steps[s].predict(features);
692 }
693
694 let honest_sigma = self.compute_honest_sigma(features);
696 const HONEST_SIGMA_ALPHA: f64 = 0.001;
697 self.rolling_honest_sigma_mean = (1.0 - HONEST_SIGMA_ALPHA)
698 * self.rolling_honest_sigma_mean
699 + HONEST_SIGMA_ALPHA * honest_sigma;
700
701 let err = target - mu;
703 let alpha = self.empirical_sigma_alpha;
704 self.ewma_sq_err = (1.0 - alpha) * self.ewma_sq_err + alpha * err * err;
705 let empirical_sigma = self.ewma_sq_err.sqrt().max(1e-8);
706
707 let sigma_ratio = if self.uncertainty_modulated_lr {
709 let d_sigma = empirical_sigma - self.prev_sigma;
711 self.prev_sigma = empirical_sigma;
712
713 self.sigma_velocity = (1.0 - alpha) * self.sigma_velocity + alpha * d_sigma;
715
716 let k_d = if self.rolling_sigma_mean > 1e-12 {
718 self.sigma_velocity.abs() / self.rolling_sigma_mean
719 } else {
720 0.0
721 };
722
723 let pd_sigma = empirical_sigma + k_d * self.sigma_velocity;
725 let ratio = (pd_sigma / self.rolling_sigma_mean).clamp(0.1, 10.0);
726
727 const SIGMA_EWMA_ALPHA: f64 = 0.001;
729 self.rolling_sigma_mean = (1.0 - SIGMA_EWMA_ALPHA) * self.rolling_sigma_mean
730 + SIGMA_EWMA_ALPHA * empirical_sigma;
731
732 ratio
733 } else {
734 1.0
735 };
736
737 let base_lr = self.config.learning_rate;
738
739 let mut mu_accum = self.location_base;
741 for s in 0..self.location_steps.len() {
742 let (g_mu, h_mu) = self.location_gradient(mu_accum, target);
743 self.update_ensemble_grad_stats(g_mu);
745 let train_count = self.config.variant.train_count(h_mu, &mut self.rng_state);
746 let loc_pred =
747 self.location_steps[s].train_and_predict(features, g_mu, h_mu, train_count);
748 mu_accum += (base_lr * sigma_ratio) * loc_pred;
749 }
750
751 self.maybe_refresh_packed_cache();
753 }
754
755 fn train_one_tree_chain(&mut self, target: f64, features: &[f64]) {
757 let mut mu = self.location_base;
758 let mut log_sigma = self.scale_base;
759
760 let honest_sigma = self.compute_honest_sigma(features);
762 const HONEST_SIGMA_ALPHA: f64 = 0.001;
763 self.rolling_honest_sigma_mean = (1.0 - HONEST_SIGMA_ALPHA)
764 * self.rolling_honest_sigma_mean
765 + HONEST_SIGMA_ALPHA * honest_sigma;
766
767 let sigma_ratio = if self.uncertainty_modulated_lr {
769 let current_sigma = log_sigma.exp().max(1e-8);
770
771 let d_sigma = current_sigma - self.prev_sigma;
773 self.prev_sigma = current_sigma;
774
775 let alpha = self.empirical_sigma_alpha;
777 self.sigma_velocity = (1.0 - alpha) * self.sigma_velocity + alpha * d_sigma;
778
779 let k_d = if self.rolling_sigma_mean > 1e-12 {
781 self.sigma_velocity.abs() / self.rolling_sigma_mean
782 } else {
783 0.0
784 };
785
786 let pd_sigma = current_sigma + k_d * self.sigma_velocity;
788 let ratio = (pd_sigma / self.rolling_sigma_mean).clamp(0.1, 10.0);
789
790 const SIGMA_EWMA_ALPHA: f64 = 0.001;
791 self.rolling_sigma_mean = (1.0 - SIGMA_EWMA_ALPHA) * self.rolling_sigma_mean
792 + SIGMA_EWMA_ALPHA * current_sigma;
793
794 ratio
795 } else {
796 1.0
797 };
798
799 let base_lr = self.config.learning_rate;
800
801 for s in 0..self.location_steps.len() {
803 let sigma = log_sigma.exp().max(1e-8);
804 let z = (target - mu) / sigma;
805
806 let (g_mu, h_mu) = self.location_gradient(mu, target);
808 self.update_ensemble_grad_stats(g_mu);
810
811 let g_sigma = 1.0 - z * z;
813 let h_sigma = (2.0 * z * z).clamp(0.01, 100.0);
814
815 let train_count = self.config.variant.train_count(h_mu, &mut self.rng_state);
816
817 let loc_pred =
819 self.location_steps[s].train_and_predict(features, g_mu, h_mu, train_count);
820 mu += (base_lr * sigma_ratio) * loc_pred;
821
822 let scale_pred =
824 self.scale_steps[s].train_and_predict(features, g_sigma, h_sigma, train_count);
825 log_sigma += base_lr * scale_pred;
826 }
827
828 let err = target - mu;
830 let alpha = self.empirical_sigma_alpha;
831 self.ewma_sq_err = (1.0 - alpha) * self.ewma_sq_err + alpha * err * err;
832
833 self.maybe_refresh_packed_cache();
835 }
836
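    /// Full distributional prediction. The location uses the packed inference
    /// cache when one is available, falling back to the live trees on any
    /// decode failure or non-finite output; the scale comes from either the
    /// residual EWMA or the scale-tree chain, depending on [`ScaleMode`].
    /// A minimal sketch, assuming a trained model:
    ///
    /// ```ignore
    /// let pred = model.predict(&[1.0, 2.0]);
    /// println!("mu = {}, sigma = {}", pred.mu, pred.sigma);
    /// let (lo, hi) = (pred.lower(1.96), pred.upper(1.96)); // ~95% interval
    /// ```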
    pub fn predict(&self, features: &[f64]) -> GaussianPrediction {
        let mu = if let Some(ref cache) = self.packed_cache {
            let features_f32: Vec<f32> = features.iter().map(|&v| v as f32).collect();
            match irithyll_core::EnsembleView::from_bytes(&cache.bytes) {
                Ok(view) => {
                    let packed_mu = cache.base + view.predict(&features_f32) as f64;
                    if packed_mu.is_finite() {
                        packed_mu
                    } else {
                        self.predict_full_trees(features)
                    }
                }
                Err(_) => self.predict_full_trees(features),
            }
        } else {
            self.predict_full_trees(features)
        };

        let (sigma, log_sigma) = match self.scale_mode {
            ScaleMode::Empirical => {
                let s = self.ewma_sq_err.sqrt().max(1e-8);
                (s, s.ln())
            }
            ScaleMode::TreeChain => {
                let mut ls = self.scale_base;
                if self.auto_bandwidths.is_empty() {
                    for s in 0..self.scale_steps.len() {
                        ls += self.config.learning_rate * self.scale_steps[s].predict(features);
                    }
                } else {
                    for s in 0..self.scale_steps.len() {
                        ls += self.config.learning_rate
                            * self.scale_steps[s]
                                .predict_smooth_auto(features, &self.auto_bandwidths);
                    }
                }
                (ls.exp().max(1e-8), ls)
            }
        };

        let honest_sigma = self.compute_honest_sigma(features);

        GaussianPrediction {
            mu,
            sigma,
            log_sigma,
            honest_sigma,
        }
    }

    fn predict_full_trees(&self, features: &[f64]) -> f64 {
        let mut mu = self.location_base;
        if self.auto_bandwidths.is_empty() {
            for s in 0..self.location_steps.len() {
                mu += self.config.learning_rate * self.location_steps[s].predict(features);
            }
        } else {
            for s in 0..self.location_steps.len() {
                mu += self.config.learning_rate
                    * self.location_steps[s].predict_smooth_auto(features, &self.auto_bandwidths);
            }
        }
        mu
    }

    pub fn predict_smooth(&self, features: &[f64], bandwidth: f64) -> GaussianPrediction {
        let mut mu = self.location_base;
        for s in 0..self.location_steps.len() {
            mu += self.config.learning_rate
                * self.location_steps[s].predict_smooth(features, bandwidth);
        }

        let (sigma, log_sigma) = match self.scale_mode {
            ScaleMode::Empirical => {
                let s = self.ewma_sq_err.sqrt().max(1e-8);
                (s, s.ln())
            }
            ScaleMode::TreeChain => {
                let mut ls = self.scale_base;
                for s in 0..self.scale_steps.len() {
                    ls += self.config.learning_rate
                        * self.scale_steps[s].predict_smooth(features, bandwidth);
                }
                (ls.exp().max(1e-8), ls)
            }
        };

        let honest_sigma = self.compute_honest_sigma(features);

        GaussianPrediction {
            mu,
            sigma,
            log_sigma,
            honest_sigma,
        }
    }

    pub fn predict_interpolated(&self, features: &[f64]) -> GaussianPrediction {
        let mut mu = self.location_base;
        for s in 0..self.location_steps.len() {
            mu += self.config.learning_rate * self.location_steps[s].predict_interpolated(features);
        }

        let (sigma, log_sigma) = match self.scale_mode {
            ScaleMode::Empirical => {
                let s = self.ewma_sq_err.sqrt().max(1e-8);
                (s, s.ln())
            }
            ScaleMode::TreeChain => {
                let mut ls = self.scale_base;
                for s in 0..self.scale_steps.len() {
                    ls += self.config.learning_rate
                        * self.scale_steps[s].predict_interpolated(features);
                }
                (ls.exp().max(1e-8), ls)
            }
        };

        let honest_sigma = self.compute_honest_sigma(features);

        GaussianPrediction {
            mu,
            sigma,
            log_sigma,
            honest_sigma,
        }
    }

    pub fn predict_sibling_interpolated(&self, features: &[f64]) -> GaussianPrediction {
        let mut mu = self.location_base;
        for s in 0..self.location_steps.len() {
            mu += self.config.learning_rate
                * self.location_steps[s]
                    .predict_sibling_interpolated(features, &self.auto_bandwidths);
        }

        let (sigma, log_sigma) = match self.scale_mode {
            ScaleMode::Empirical => {
                let s = self.ewma_sq_err.sqrt().max(1e-8);
                (s, s.ln())
            }
            ScaleMode::TreeChain => {
                let mut ls = self.scale_base;
                for s in 0..self.scale_steps.len() {
                    ls += self.config.learning_rate
                        * self.scale_steps[s]
                            .predict_sibling_interpolated(features, &self.auto_bandwidths);
                }
                (ls.exp().max(1e-8), ls)
            }
        };

        let honest_sigma = self.compute_honest_sigma(features);

        GaussianPrediction {
            mu,
            sigma,
            log_sigma,
            honest_sigma,
        }
    }

    pub fn predict_graduated(&self, features: &[f64]) -> GaussianPrediction {
        let mut mu = self.location_base;
        for s in 0..self.location_steps.len() {
            mu += self.config.learning_rate * self.location_steps[s].predict_graduated(features);
        }

        let (sigma, log_sigma) = match self.scale_mode {
            ScaleMode::Empirical => {
                let s = self.ewma_sq_err.sqrt().max(1e-8);
                (s, s.ln())
            }
            ScaleMode::TreeChain => {
                let mut ls = self.scale_base;
                for s in 0..self.scale_steps.len() {
                    ls +=
                        self.config.learning_rate * self.scale_steps[s].predict_graduated(features);
                }
                (ls.exp().max(1e-8), ls)
            }
        };

        let honest_sigma = self.compute_honest_sigma(features);

        GaussianPrediction {
            mu,
            sigma,
            log_sigma,
            honest_sigma,
        }
    }

    pub fn predict_graduated_sibling_interpolated(&self, features: &[f64]) -> GaussianPrediction {
        let mut mu = self.location_base;
        for s in 0..self.location_steps.len() {
            mu += self.config.learning_rate
                * self.location_steps[s]
                    .predict_graduated_sibling_interpolated(features, &self.auto_bandwidths);
        }

        let (sigma, log_sigma) = match self.scale_mode {
            ScaleMode::Empirical => {
                let s = self.ewma_sq_err.sqrt().max(1e-8);
                (s, s.ln())
            }
            ScaleMode::TreeChain => {
                let mut ls = self.scale_base;
                for s in 0..self.scale_steps.len() {
                    ls += self.config.learning_rate
                        * self.scale_steps[s].predict_graduated_sibling_interpolated(
                            features,
                            &self.auto_bandwidths,
                        );
                }
                (ls.exp().max(1e-8), ls)
            }
        };

        let honest_sigma = self.compute_honest_sigma(features);

        GaussianPrediction {
            mu,
            sigma,
            log_sigma,
            honest_sigma,
        }
    }

    pub fn predict_soft_routed(&self, features: &[f64]) -> GaussianPrediction {
        let mut mu = self.location_base;
        for step in &self.location_steps {
            mu += self.config.learning_rate * step.predict_soft_routed(features);
        }

        let (sigma, log_sigma) = match self.scale_mode {
            ScaleMode::Empirical => {
                let s = self.ewma_sq_err.sqrt().max(1e-8);
                (s, s.ln())
            }
            ScaleMode::TreeChain => {
                let mut ls = self.scale_base;
                for step in &self.scale_steps {
                    ls += self.config.learning_rate * step.predict_soft_routed(features);
                }
                (ls.exp().max(1e-8), ls)
            }
        };

        let honest_sigma = self.compute_honest_sigma(features);

        GaussianPrediction {
            mu,
            sigma,
            log_sigma,
            honest_sigma,
        }
    }

    pub fn predict_distributional(&self, features: &[f64]) -> (f64, f64, f64) {
        let pred = self.predict(features);
        let sigma_ratio = if self.uncertainty_modulated_lr {
            (pred.sigma / self.rolling_sigma_mean).clamp(0.1, 10.0)
        } else {
            1.0
        };
        (pred.mu, pred.sigma, sigma_ratio)
    }

    #[inline]
    pub fn empirical_sigma(&self) -> f64 {
        self.ewma_sq_err.sqrt()
    }

    #[inline]
    pub fn scale_mode(&self) -> ScaleMode {
        self.scale_mode
    }

    #[inline]
    pub fn sigma_velocity(&self) -> f64 {
        self.sigma_velocity
    }

    #[inline]
    pub fn predict_mu(&self, features: &[f64]) -> f64 {
        self.predict(features).mu
    }

    #[inline]
    pub fn predict_sigma(&self, features: &[f64]) -> f64 {
        self.predict(features).sigma
    }

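    /// Symmetric interval around `mu`. Note that `confidence` is used
    /// directly as a z-score multiplier (pass `1.96` for an approximate 95%
    /// Gaussian interval), not as a probability:
    ///
    /// ```ignore
    /// let (lo, hi) = model.predict_interval(&[5.0], 1.96);
    /// // The width is 2 * 1.96 * sigma by construction.
    /// ```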
    pub fn predict_interval(&self, features: &[f64], confidence: f64) -> (f64, f64) {
        let pred = self.predict(features);
        (pred.lower(confidence), pred.upper(confidence))
    }

    pub fn predict_batch(&self, feature_matrix: &[Vec<f64>]) -> Vec<GaussianPrediction> {
        feature_matrix.iter().map(|f| self.predict(f)).collect()
    }

    pub fn train_batch<O: Observation>(&mut self, samples: &[O]) {
        for sample in samples {
            self.train_one(sample);
        }
    }

    pub fn train_batch_with_callback<O: Observation, F: FnMut(usize)>(
        &mut self,
        samples: &[O],
        interval: usize,
        mut callback: F,
    ) {
        let interval = interval.max(1);
        for (i, sample) in samples.iter().enumerate() {
            self.train_one(sample);
            if (i + 1) % interval == 0 {
                callback(i + 1);
            }
        }
        let total = samples.len();
        if total % interval != 0 {
            callback(total);
        }
    }

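    /// Squared-error gradient by default; with `huber_k` set, a Huber-style
    /// pseudo-gradient with threshold `delta = k * sigma_hat`: inside the
    /// band it returns `(residual, 1.0)`, outside it the gradient is clipped
    /// to `delta * sign(residual)` with a near-zero Hessian, so outliers
    /// barely move the leaf statistics.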
    #[inline]
    fn location_gradient(&self, mu: f64, target: f64) -> (f64, f64) {
        if let Some(k) = self.config.huber_k {
            let delta = k * self.ewma_sq_err.sqrt().max(1e-8);
            let residual = mu - target;
            if residual.abs() <= delta {
                (residual, 1.0)
            } else {
                (delta * residual.signum(), 1e-6)
            }
        } else {
            (mu - target, 1.0)
        }
    }

    #[inline]
    fn update_ensemble_grad_stats(&mut self, gradient: f64) {
        self.ensemble_grad_count += 1;
        let delta = gradient - self.ensemble_grad_mean;
        self.ensemble_grad_mean += delta / self.ensemble_grad_count as f64;
        let delta2 = gradient - self.ensemble_grad_mean;
        self.ensemble_grad_m2 += delta * delta2;
    }

    pub fn ensemble_grad_std(&self) -> f64 {
        if self.ensemble_grad_count < 2 {
            return 0.0;
        }
        (self.ensemble_grad_m2 / (self.ensemble_grad_count - 1) as f64)
            .sqrt()
            .max(0.0)
    }

    pub fn ensemble_grad_mean(&self) -> f64 {
        self.ensemble_grad_mean
    }

    fn maybe_refresh_packed_cache(&mut self) {
        if self.packed_refresh_interval > 0 {
            self.samples_since_refresh += 1;
            if self.samples_since_refresh >= self.packed_refresh_interval {
                self.refresh_packed_cache();
                self.samples_since_refresh = 0;
            }
        }
    }

    fn refresh_packed_cache(&mut self) {
        let n_features = self
            .location_steps
            .iter()
            .filter_map(|s| s.slot().active_tree().n_features())
            .max()
            .unwrap_or(0);

        if n_features == 0 {
            return;
        }

        let (bytes, base) = export_distributional_packed(self, n_features);
        self.packed_cache = Some(PackedInferenceCache {
            bytes,
            base,
            n_features,
        });
    }

    pub fn enable_packed_cache(&mut self, interval: u64) {
        self.packed_refresh_interval = interval;
        self.samples_since_refresh = 0;
        if interval > 0 && self.base_initialized {
            self.refresh_packed_cache();
        } else if interval == 0 {
            self.packed_cache = None;
        }
    }

    #[inline]
    pub fn has_packed_cache(&self) -> bool {
        self.packed_cache.is_some()
    }

    fn refresh_bandwidths(&mut self) {
        let current_sum: u64 = self
            .location_steps
            .iter()
            .chain(self.scale_steps.iter())
            .map(|s| s.slot().replacements())
            .sum();
        if current_sum != self.last_replacement_sum || self.auto_bandwidths.is_empty() {
            self.auto_bandwidths = self.compute_auto_bandwidths();
            self.last_replacement_sum = current_sum;
        }
    }

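    /// Per-feature smoothing bandwidths derived from the ensemble's split
    /// thresholds: `bandwidth = K * median gap` between the sorted,
    /// deduplicated thresholds (K = 2). Worked example: thresholds
    /// {1.0, 1.5, 3.0, 3.2} give gaps {0.5, 1.5, 0.2}, median gap 0.5, hence
    /// bandwidth 1.0. Features with fewer than two distinct thresholds yield
    /// `f64::INFINITY`.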
    fn compute_auto_bandwidths(&self) -> Vec<f64> {
        const K: f64 = 2.0;

        let n_features = self
            .location_steps
            .iter()
            .chain(self.scale_steps.iter())
            .filter_map(|s| s.slot().active_tree().n_features())
            .max()
            .unwrap_or(0);

        if n_features == 0 {
            return Vec::new();
        }

        let mut all_thresholds: Vec<Vec<f64>> = vec![Vec::new(); n_features];

        for step in self.location_steps.iter().chain(self.scale_steps.iter()) {
            let tree_thresholds = step
                .slot()
                .active_tree()
                .collect_split_thresholds_per_feature();
            for (i, ts) in tree_thresholds.into_iter().enumerate() {
                if i < n_features {
                    all_thresholds[i].extend(ts);
                }
            }
        }

        let n_bins = self.config.n_bins as f64;

        all_thresholds
            .iter()
            .map(|ts| {
                if ts.is_empty() {
                    return f64::INFINITY;
                }

                let mut sorted = ts.clone();
                sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
                sorted.dedup_by(|a, b| (*a - *b).abs() < 1e-15);

                if sorted.len() < 2 {
                    return f64::INFINITY;
                }

                let mut gaps: Vec<f64> = sorted.windows(2).map(|w| w[1] - w[0]).collect();

                if sorted.len() < 3 {
                    let range = sorted.last().unwrap() - sorted.first().unwrap();
                    if range < 1e-15 {
                        return f64::INFINITY;
                    }
                    return (range / n_bins) * K;
                }

                gaps.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
                let median_gap = if gaps.len() % 2 == 0 {
                    (gaps[gaps.len() / 2 - 1] + gaps[gaps.len() / 2]) / 2.0
                } else {
                    gaps[gaps.len() / 2]
                };

                if median_gap < 1e-15 {
                    f64::INFINITY
                } else {
                    median_gap * K
                }
            })
            .collect()
    }

    pub fn auto_bandwidths(&self) -> &[f64] {
        &self.auto_bandwidths
    }

    pub fn reset(&mut self) {
        for step in &mut self.location_steps {
            step.reset();
        }
        for step in &mut self.scale_steps {
            step.reset();
        }
        self.location_base = 0.0;
        self.scale_base = 0.0;
        self.base_initialized = false;
        self.initial_targets.clear();
        self.samples_seen = 0;
        self.rng_state = self.config.seed;
        self.rolling_sigma_mean = 1.0;
        self.ewma_sq_err = 1.0;
        self.prev_sigma = 0.0;
        self.sigma_velocity = 0.0;
        self.auto_bandwidths.clear();
        self.last_replacement_sum = 0;
        self.ensemble_grad_mean = 0.0;
        self.ensemble_grad_m2 = 0.0;
        self.ensemble_grad_count = 0;
        self.rolling_honest_sigma_mean = 0.0;
        self.packed_cache = None;
        self.samples_since_refresh = 0;
        self.prev_contributions.clear();
        self.prev_prev_contributions.clear();
        self.cached_residual_alignment = 0.0;
        self.cached_reg_sensitivity = 0.0;
        self.cached_depth_sufficiency = 0.0;
        self.cached_effective_dof = 0.0;
        self.contribution_accuracy = vec![0.0; self.location_steps.len()];
    }

    #[inline]
    pub fn n_samples_seen(&self) -> u64 {
        self.samples_seen
    }

    #[inline]
    pub fn n_steps(&self) -> usize {
        self.location_steps.len()
    }

    pub fn n_trees(&self) -> usize {
        let loc = self.location_steps.len()
            + self
                .location_steps
                .iter()
                .filter(|s| s.has_alternate())
                .count();
        let scale = self.scale_steps.len()
            + self
                .scale_steps
                .iter()
                .filter(|s| s.has_alternate())
                .count();
        loc + scale
    }

    pub fn total_leaves(&self) -> usize {
        let loc: usize = self.location_steps.iter().map(|s| s.n_leaves()).sum();
        let scale: usize = self.scale_steps.iter().map(|s| s.n_leaves()).sum();
        loc + scale
    }

    #[inline]
    pub fn is_initialized(&self) -> bool {
        self.base_initialized
    }

    #[inline]
    pub fn config(&self) -> &SGBTConfig {
        &self.config
    }

    #[inline]
    pub fn set_learning_rate(&mut self, lr: f64) {
        self.config.learning_rate = lr;
    }

    #[inline]
    pub fn set_lambda(&mut self, lambda: f64) {
        self.config.lambda = lambda.max(0.0);
    }

    #[inline]
    pub fn set_max_depth(&mut self, depth: usize) {
        self.config.max_depth = depth.clamp(1, 20);
    }

    pub fn set_n_steps(&mut self, n: usize) {
        let n = n.clamp(3, 1000);
        let current = self.location_steps.len();
        if n > current {
            let leaf_decay_alpha = self
                .config
                .leaf_half_life
                .map(|hl| (-(2.0_f64.ln()) / hl as f64).exp());
            let tree_config = TreeConfig::new()
                .max_depth(self.config.max_depth)
                .n_bins(self.config.n_bins)
                .lambda(self.config.lambda)
                .gamma(self.config.gamma)
                .grace_period(self.config.grace_period)
                .delta(self.config.delta)
                .feature_subsample_rate(self.config.feature_subsample_rate)
                .leaf_decay_alpha_opt(leaf_decay_alpha)
                .split_reeval_interval_opt(self.config.split_reeval_interval)
                .feature_types_opt(self.config.feature_types.clone())
                .gradient_clip_sigma_opt(self.config.gradient_clip_sigma)
                .monotone_constraints_opt(self.config.monotone_constraints.clone())
                .max_leaf_output_opt(self.config.max_leaf_output)
                .adaptive_depth_opt(self.config.adaptive_depth)
                .min_hessian_sum_opt(self.config.min_hessian_sum)
                .leaf_model_type(self.config.leaf_model_type.clone());
            let mts = self.config.max_tree_samples;
            let shadow_warmup = self.config.shadow_warmup.unwrap_or(0);
            for i in current..n {
                let mut tc = tree_config.clone();
                tc.seed = self.config.seed ^ (i as u64);
                let detector = self.config.drift_detector.create();
                let step = if shadow_warmup > 0 {
                    BoostingStep::new_with_graduated(tc, detector, mts, shadow_warmup)
                } else {
                    BoostingStep::new_with_max_samples(tc, detector, mts)
                };
                self.location_steps.push(step);

                let mut tc = tree_config.clone();
                tc.seed = self.config.seed ^ (i as u64) ^ 0x0005_CA1E_0000_0000;
                let detector = self.config.drift_detector.create();
                let step = if shadow_warmup > 0 {
                    BoostingStep::new_with_graduated(tc, detector, mts, shadow_warmup)
                } else {
                    BoostingStep::new_with_max_samples(tc, detector, mts)
                };
                self.scale_steps.push(step);
            }
        } else if n < current {
            self.location_steps.truncate(n);
            self.scale_steps.truncate(n);
        }
        self.contribution_accuracy.resize(n, 0.0);
        self.config.n_steps = n;
    }

    pub fn total_replacements(&self) -> u64 {
        self.location_steps
            .iter()
            .chain(self.scale_steps.iter())
            .map(|s| s.slot().replacements())
            .sum()
    }

    pub fn check_proactive_prune(&mut self) -> bool {
        if self.location_steps.len() <= 1 {
            return false;
        }
        if self.config.accuracy_based_pruning {
            let grace_period = self.config.grace_period as u64;
            let worst = self
                .location_steps
                .iter()
                .enumerate()
                .zip(self.contribution_accuracy.iter())
                .filter(|((_, step), _)| step.slot().n_samples_seen() >= grace_period)
                .min_by(|((_, _), a), ((_, _), b)| {
                    a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
                });
            if let Some(((worst_idx, _), &worst_acc)) = worst {
                if worst_acc < 0.0 {
                    self.location_steps[worst_idx].slot_mut().replace_active();
                    self.contribution_accuracy[worst_idx] = 0.0;
                    return true;
                }
            }
            false
        } else {
            let worst_idx = self
                .location_steps
                .iter()
                .enumerate()
                .min_by(|(_, a), (_, b)| {
                    let a_std = a.slot().prediction_std();
                    let b_std = b.slot().prediction_std();
                    a_std
                        .partial_cmp(&b_std)
                        .unwrap_or(std::cmp::Ordering::Equal)
                })
                .map(|(i, _)| i)
                .unwrap_or(0);
            self.location_steps[worst_idx].slot_mut().replace_active();
            true
        }
    }

    pub fn location_steps(&self) -> &[BoostingStep] {
        &self.location_steps
    }

    #[inline]
    pub fn location_base(&self) -> f64 {
        self.location_base
    }

    #[inline]
    pub fn learning_rate(&self) -> f64 {
        self.config.learning_rate
    }

    #[inline]
    pub fn rolling_sigma_mean(&self) -> f64 {
        self.rolling_sigma_mean
    }

    #[inline]
    pub fn is_uncertainty_modulated(&self) -> bool {
        self.uncertainty_modulated_lr
    }

    #[inline]
    pub fn rolling_honest_sigma_mean(&self) -> f64 {
        self.rolling_honest_sigma_mean
    }

    pub fn diagnostics(&self) -> ModelDiagnostics {
        let n = self.location_steps.len();
        let mut trees = Vec::with_capacity(2 * n);
        let mut feature_split_counts: Vec<usize> = Vec::new();

        fn collect_tree_diags(
            steps: &[BoostingStep],
            trees: &mut Vec<TreeDiagnostic>,
            feature_split_counts: &mut Vec<usize>,
        ) {
            for step in steps {
                let slot = step.slot();
                let tree = slot.active_tree();
                let arena = tree.arena();

                let leaf_values: Vec<f64> = (0..arena.is_leaf.len())
                    .filter(|&i| arena.is_leaf[i])
                    .map(|i| arena.leaf_value[i])
                    .collect();

                let leaf_sample_counts: Vec<u64> = (0..arena.is_leaf.len())
                    .filter(|&i| arena.is_leaf[i])
                    .map(|i| arena.sample_count[i])
                    .collect();

                let max_depth_reached = (0..arena.is_leaf.len())
                    .filter(|&i| arena.is_leaf[i])
                    .map(|i| arena.depth[i] as usize)
                    .max()
                    .unwrap_or(0);

                let leaf_weight_stats = if leaf_values.is_empty() {
                    (0.0, 0.0, 0.0, 0.0)
                } else {
                    let min = leaf_values.iter().cloned().fold(f64::INFINITY, f64::min);
                    let max = leaf_values
                        .iter()
                        .cloned()
                        .fold(f64::NEG_INFINITY, f64::max);
                    let sum: f64 = leaf_values.iter().sum();
                    let mean = sum / leaf_values.len() as f64;
                    let var: f64 = leaf_values.iter().map(|v| (v - mean).powi(2)).sum::<f64>()
                        / leaf_values.len() as f64;
                    (min, max, mean, var.sqrt())
                };

                let gains = slot.split_gains();
                let split_features: Vec<usize> = gains
                    .iter()
                    .enumerate()
                    .filter(|(_, &g)| g > 0.0)
                    .map(|(i, _)| i)
                    .collect();

                if !gains.is_empty() {
                    if feature_split_counts.is_empty() {
                        feature_split_counts.resize(gains.len(), 0);
                    }
                    for &fi in &split_features {
                        if fi < feature_split_counts.len() {
                            feature_split_counts[fi] += 1;
                        }
                    }
                }

                trees.push(TreeDiagnostic {
                    n_leaves: leaf_values.len(),
                    max_depth_reached,
                    samples_seen: step.n_samples_seen(),
                    leaf_weight_stats,
                    split_features,
                    leaf_sample_counts,
                    prediction_mean: slot.prediction_mean(),
                    prediction_std: slot.prediction_std(),
                });
            }
        }

        collect_tree_diags(&self.location_steps, &mut trees, &mut feature_split_counts);
        collect_tree_diags(&self.scale_steps, &mut trees, &mut feature_split_counts);

        let location_trees = trees[..n].to_vec();
        let scale_trees = trees[n..].to_vec();
        let scale_trees_active = scale_trees.iter().filter(|t| t.n_leaves > 1).count();

        ModelDiagnostics {
            trees,
            location_trees,
            scale_trees,
            feature_split_counts,
            location_base: self.location_base,
            scale_base: self.scale_base,
            empirical_sigma: self.ewma_sq_err.sqrt(),
            scale_mode: self.scale_mode,
            scale_trees_active,
            auto_bandwidths: self.auto_bandwidths.clone(),
            ensemble_grad_mean: self.ensemble_grad_mean,
            ensemble_grad_std: self.ensemble_grad_std(),
        }
    }

    pub fn ensemble_diagnostics(
        &self,
        features: &[f64],
    ) -> crate::ensemble::diagnostics::DistributionalDiagnostics {
        use crate::ensemble::diagnostics::build_ensemble_diagnostics;

        let location = build_ensemble_diagnostics(
            &self.location_steps,
            self.location_base,
            self.config.learning_rate,
            self.samples_seen,
            Some(features),
        );

        let scale = match self.scale_mode {
            ScaleMode::TreeChain => Some(build_ensemble_diagnostics(
                &self.scale_steps,
                self.scale_base,
                self.config.learning_rate,
                self.samples_seen,
                Some(features),
            )),
            ScaleMode::Empirical => None,
        };

        let honest_sigma = self.compute_honest_sigma(features);

        let effective_mts = self.config.adaptive_mts.map(|(base_mts, k)| {
            let sigma_ratio = if self.rolling_honest_sigma_mean > 1e-12 {
                honest_sigma / self.rolling_honest_sigma_mean
            } else {
                1.0
            };
            (base_mts as f64 / (1.0 + k * sigma_ratio)).max(100.0) as u64
        });

        crate::ensemble::diagnostics::DistributionalDiagnostics {
            location,
            scale,
            honest_sigma,
            rolling_honest_sigma_mean: self.rolling_honest_sigma_mean,
            effective_mts,
        }
    }

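    /// Splits a prediction into per-step additive contributions for
    /// attribution and debugging. In [`ScaleMode::Empirical`] all scale
    /// contributions are zero and the scale base carries
    /// `ln(empirical_sigma)`. Note this path always queries the live trees,
    /// so the reconstructed `mu()` can differ slightly from [`Self::predict`]
    /// while a packed cache or auto-bandwidth smoothing is active:
    ///
    /// ```ignore
    /// let d = model.predict_decomposed(&[1.0, 2.0]);
    /// for (i, c) in d.location_contributions.iter().enumerate() {
    ///     println!("step {i}: {c:+.4}");
    /// }
    /// ```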
    pub fn predict_decomposed(&self, features: &[f64]) -> DecomposedPrediction {
        let lr = self.config.learning_rate;
        let location: Vec<f64> = self
            .location_steps
            .iter()
            .map(|s| lr * s.predict(features))
            .collect();

        let (sb, scale) = match self.scale_mode {
            ScaleMode::Empirical => {
                let empirical_sigma = self.ewma_sq_err.sqrt().max(1e-8);
                (empirical_sigma.ln(), vec![0.0; self.location_steps.len()])
            }
            ScaleMode::TreeChain => {
                let s: Vec<f64> = self
                    .scale_steps
                    .iter()
                    .map(|s| lr * s.predict(features))
                    .collect();
                (self.scale_base, s)
            }
        };

        DecomposedPrediction {
            location_base: self.location_base,
            scale_base: sb,
            location_contributions: location,
            scale_contributions: scale,
        }
    }

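    /// Gain-based feature importances aggregated over location and scale
    /// trees, normalized to sum to 1 whenever any split gain is present:
    ///
    /// ```ignore
    /// let importances = model.feature_importances();
    /// let top = importances
    ///     .iter()
    ///     .enumerate()
    ///     .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap());
    /// ```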
    pub fn feature_importances(&self) -> Vec<f64> {
        let mut totals: Vec<f64> = Vec::new();
        for steps in [&self.location_steps, &self.scale_steps] {
            for step in steps {
                let gains = step.slot().split_gains();
                if totals.is_empty() && !gains.is_empty() {
                    totals.resize(gains.len(), 0.0);
                }
                for (i, &g) in gains.iter().enumerate() {
                    if i < totals.len() {
                        totals[i] += g;
                    }
                }
            }
        }
        let sum: f64 = totals.iter().sum();
        if sum > 0.0 {
            totals.iter_mut().for_each(|v| *v /= sum);
        }
        totals
    }

    pub fn feature_importances_split(&self) -> (Vec<f64>, Vec<f64>) {
        fn aggregate(steps: &[BoostingStep]) -> Vec<f64> {
            let mut totals: Vec<f64> = Vec::new();
            for step in steps {
                let gains = step.slot().split_gains();
                if totals.is_empty() && !gains.is_empty() {
                    totals.resize(gains.len(), 0.0);
                }
                for (i, &g) in gains.iter().enumerate() {
                    if i < totals.len() {
                        totals[i] += g;
                    }
                }
            }
            let sum: f64 = totals.iter().sum();
            if sum > 0.0 {
                totals.iter_mut().for_each(|v| *v /= sum);
            }
            totals
        }
        (
            aggregate(&self.location_steps),
            aggregate(&self.scale_steps),
        )
    }

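    /// Serializable snapshot of the model: tree structures, drift-detector
    /// state, and the scalar training state (transient caches are rebuilt
    /// lazily after restore). A minimal round-trip sketch, assuming one of
    /// the serde features is enabled and `model` is a trained instance:
    ///
    /// ```ignore
    /// let state = model.to_distributional_state();
    /// let restored = DistributionalSGBT::from_distributional_state(state);
    /// assert_eq!(restored.n_samples_seen(), model.n_samples_seen());
    /// ```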
    #[cfg(any(feature = "serde-json", feature = "serde-bincode"))]
    pub fn to_distributional_state(&self) -> crate::serde_support::DistributionalModelState {
        use super::snapshot_tree;
        use crate::serde_support::{DistributionalModelState, StepSnapshot};

        fn snapshot_step(step: &BoostingStep) -> StepSnapshot {
            let slot = step.slot();
            let tree_snap = snapshot_tree(slot.active_tree());
            let alt_snap = slot.alternate_tree().map(snapshot_tree);
            let drift_state = slot.detector().serialize_state();
            let alt_drift_state = slot.alt_detector().and_then(|d| d.serialize_state());
            StepSnapshot {
                tree: tree_snap,
                alternate_tree: alt_snap,
                drift_state,
                alt_drift_state,
            }
        }

        DistributionalModelState {
            config: self.config.clone(),
            location_steps: self.location_steps.iter().map(snapshot_step).collect(),
            scale_steps: self.scale_steps.iter().map(snapshot_step).collect(),
            location_base: self.location_base,
            scale_base: self.scale_base,
            base_initialized: self.base_initialized,
            initial_targets: self.initial_targets.clone(),
            initial_target_count: self.initial_target_count,
            samples_seen: self.samples_seen,
            rng_state: self.rng_state,
            uncertainty_modulated_lr: self.uncertainty_modulated_lr,
            rolling_sigma_mean: self.rolling_sigma_mean,
            ewma_sq_err: self.ewma_sq_err,
            rolling_honest_sigma_mean: self.rolling_honest_sigma_mean,
        }
    }

    #[cfg(any(feature = "serde-json", feature = "serde-bincode"))]
    pub fn from_distributional_state(
        state: crate::serde_support::DistributionalModelState,
    ) -> Self {
        use super::rebuild_tree;
        use crate::ensemble::replacement::TreeSlot;
        use crate::serde_support::StepSnapshot;

        let leaf_decay_alpha = state
            .config
            .leaf_half_life
            .map(|hl| (-(2.0_f64.ln()) / hl as f64).exp());
        let max_tree_samples = state.config.max_tree_samples;

        let base_tree_config = TreeConfig::new()
            .max_depth(state.config.max_depth)
            .n_bins(state.config.n_bins)
            .lambda(state.config.lambda)
            .gamma(state.config.gamma)
            .grace_period(state.config.grace_period)
            .delta(state.config.delta)
            .feature_subsample_rate(state.config.feature_subsample_rate)
            .leaf_decay_alpha_opt(leaf_decay_alpha)
            .split_reeval_interval_opt(state.config.split_reeval_interval)
            .feature_types_opt(state.config.feature_types.clone())
            .gradient_clip_sigma_opt(state.config.gradient_clip_sigma)
            .monotone_constraints_opt(state.config.monotone_constraints.clone())
            .adaptive_depth_opt(state.config.adaptive_depth)
            .leaf_model_type(state.config.leaf_model_type.clone());

        let rebuild_steps = |snaps: &[StepSnapshot], seed_xor: u64| -> Vec<BoostingStep> {
            snaps
                .iter()
                .enumerate()
                .map(|(i, snap)| {
                    let tc = base_tree_config
                        .clone()
                        .seed(state.config.seed ^ (i as u64) ^ seed_xor);

                    let active = rebuild_tree(&snap.tree, tc.clone());
                    let alternate = snap
                        .alternate_tree
                        .as_ref()
                        .map(|s| rebuild_tree(s, tc.clone()));

                    let mut detector = state.config.drift_detector.create();
                    if let Some(ref ds) = snap.drift_state {
                        detector.restore_state(ds);
                    }
                    let mut slot =
                        TreeSlot::from_trees(active, alternate, tc, detector, max_tree_samples);
                    if let Some(ref ads) = snap.alt_drift_state {
                        if let Some(alt_det) = slot.alt_detector_mut() {
                            alt_det.restore_state(ads);
                        }
                    }
                    BoostingStep::from_slot(slot)
                })
                .collect()
        };

        let location_steps = rebuild_steps(&state.location_steps, 0);
        let scale_steps = rebuild_steps(&state.scale_steps, 0x0005_CA1E_0000_0000);

        let scale_mode = state.config.scale_mode;
        let empirical_sigma_alpha = state.config.empirical_sigma_alpha;
        let packed_refresh_interval = state.config.packed_refresh_interval;
        let n_location_steps = location_steps.len();
        let prune_alpha = if state.config.proactive_prune_interval.is_some() {
            let gp = state.config.grace_period.max(1) as f64;
            1.0 - (-2.0 / gp).exp()
        } else {
            0.01
        };
        Self {
            config: state.config,
            location_steps,
            scale_steps,
            location_base: state.location_base,
            scale_base: state.scale_base,
            base_initialized: state.base_initialized,
            initial_targets: state.initial_targets,
            initial_target_count: state.initial_target_count,
            samples_seen: state.samples_seen,
            rng_state: state.rng_state,
            uncertainty_modulated_lr: state.uncertainty_modulated_lr,
            rolling_sigma_mean: state.rolling_sigma_mean,
            scale_mode,
            ewma_sq_err: state.ewma_sq_err,
            empirical_sigma_alpha,
            prev_sigma: 0.0,
            sigma_velocity: 0.0,
            auto_bandwidths: Vec::new(),
            last_replacement_sum: 0,
            ensemble_grad_mean: 0.0,
            ensemble_grad_m2: 0.0,
            ensemble_grad_count: 0,
            rolling_honest_sigma_mean: state.rolling_honest_sigma_mean,
            packed_cache: None,
            samples_since_refresh: 0,
            packed_refresh_interval,
            prev_contributions: Vec::new(),
            prev_prev_contributions: Vec::new(),
            cached_residual_alignment: 0.0,
            cached_reg_sensitivity: 0.0,
            cached_depth_sufficiency: 0.0,
            cached_effective_dof: 0.0,
            contribution_accuracy: vec![0.0; n_location_steps],
            prune_alpha,
        }
    }
}

use crate::learner::StreamingLearner;

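// Adapter impl: exposes the distributional model through the generic
// `StreamingLearner` interface (point predictions only; `predict` returns
// `mu`, and the weighted sample is forwarded to
// `DistributionalSGBT::train_one`). A minimal trait-level sketch
// (hypothetical helper, illustrative values):
//
//     fn drive<L: StreamingLearner>(learner: &mut L) {
//         learner.train_one(&[0.5], 1.0, 1.0); // features, target, weight
//         let _point = learner.predict(&[0.5]);
//     }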
impl StreamingLearner for DistributionalSGBT {
    fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
        let sample = SampleRef::weighted(features, target, weight);
        DistributionalSGBT::train_one(self, &sample);
    }

    fn predict(&self, features: &[f64]) -> f64 {
        DistributionalSGBT::predict(self, features).mu
    }

    fn n_samples_seen(&self) -> u64 {
        self.samples_seen
    }

    fn reset(&mut self) {
        DistributionalSGBT::reset(self);
    }

    fn diagnostics_array(&self) -> [f64; 5] {
        [
            self.cached_residual_alignment,
            self.cached_reg_sensitivity,
            self.cached_depth_sufficiency,
            self.cached_effective_dof,
            self.rolling_honest_sigma_mean(),
        ]
    }

    fn adjust_config(&mut self, lr_multiplier: f64, lambda_delta: f64) {
        self.set_learning_rate(self.config.learning_rate * lr_multiplier);
        self.set_lambda(self.config.lambda + lambda_delta);
    }

    fn apply_structural_change(&mut self, depth_delta: i32, steps_delta: i32) {
        if depth_delta != 0 {
            let current = self.config.max_depth as i32;
            self.set_max_depth((current + depth_delta).max(1) as usize);
        }
        if steps_delta != 0 {
            let current = self.config.n_steps as i32;
            self.set_n_steps((current + steps_delta).max(3) as usize);
        }
    }

    fn replacement_count(&self) -> u64 {
        self.total_replacements()
    }
}

impl crate::automl::DiagnosticSource for DistributionalSGBT {
    fn config_diagnostics(&self) -> Option<crate::automl::auto_builder::ConfigDiagnostics> {
        Some(crate::automl::auto_builder::ConfigDiagnostics {
            residual_alignment: self.cached_residual_alignment,
            regularization_sensitivity: self.cached_reg_sensitivity,
            depth_sufficiency: self.cached_depth_sufficiency,
            effective_dof: self.cached_effective_dof,
            uncertainty: self.rolling_honest_sigma_mean(),
        })
    }
}

2178#[cfg(test)]
2183mod tests {
2184 use super::*;
2185
2186 fn test_config() -> SGBTConfig {
2187 SGBTConfig::builder()
2188 .n_steps(10)
2189 .learning_rate(0.1)
2190 .grace_period(20)
2191 .max_depth(4)
2192 .n_bins(16)
2193 .initial_target_count(10)
2194 .build()
2195 .unwrap()
2196 }
2197
2198 #[test]
2199 fn fresh_model_predicts_zero() {
2200 let model = DistributionalSGBT::new(test_config());
2201 let pred = model.predict(&[1.0, 2.0, 3.0]);
2202 assert!(pred.mu.abs() < 1e-12);
2203 assert!(pred.sigma > 0.0);
2204 }
2205
2206 #[test]
2207 fn sigma_always_positive() {
2208 let mut model = DistributionalSGBT::new(test_config());
2209
2210 for i in 0..200 {
2212 let x = i as f64 * 0.1;
2213 model.train_one(&(vec![x, x * 0.5], x * 2.0 + 1.0));
2214 }
2215
2216 for i in 0..20 {
2218 let x = i as f64 * 0.5;
2219 let pred = model.predict(&[x, x * 0.5]);
2220 assert!(
2221 pred.sigma > 0.0,
2222 "sigma must be positive, got {}",
2223 pred.sigma
2224 );
2225 assert!(pred.sigma.is_finite(), "sigma must be finite");
2226 }
2227 }
2228
2229 #[test]
2230 fn constant_target_has_small_sigma() {
2231 let mut model = DistributionalSGBT::new(test_config());
2232
2233 for i in 0..200 {
2235 let x = i as f64 * 0.1;
2236 model.train_one(&(vec![x, x * 2.0], 5.0));
2237 }
2238
2239 let pred = model.predict(&[1.0, 2.0]);
2240 assert!(pred.mu.is_finite());
2241 assert!(pred.sigma.is_finite());
2242 assert!(pred.sigma > 0.0);
2243 }
2246
2247 #[test]
2248 fn noisy_target_has_finite_predictions() {
2249 let mut model = DistributionalSGBT::new(test_config());
2250
2251 let mut rng: u64 = 42;
2253 for i in 0..200 {
2254 rng ^= rng << 13;
2255 rng ^= rng >> 7;
2256 rng ^= rng << 17;
2257 let noise = (rng % 1000) as f64 / 500.0 - 1.0; let x = i as f64 * 0.1;
2259 model.train_one(&(vec![x], x * 2.0 + noise));
2260 }
2261
2262 let pred = model.predict(&[5.0]);
2263 assert!(pred.mu.is_finite());
2264 assert!(pred.sigma.is_finite());
2265 assert!(pred.sigma > 0.0);
2266 }
2267
2268 #[test]
2269 fn predict_interval_bounds_correct() {
2270 let mut model = DistributionalSGBT::new(test_config());
2271
2272 for i in 0..200 {
2273 let x = i as f64 * 0.1;
2274 model.train_one(&(vec![x], x * 2.0));
2275 }
2276
2277 let (lo, hi) = model.predict_interval(&[5.0], 1.96);
2278 let pred = model.predict(&[5.0]);
2279
2280 assert!(lo < pred.mu, "lower bound should be < mu");
2281 assert!(hi > pred.mu, "upper bound should be > mu");
2282 assert!((hi - lo - 2.0 * 1.96 * pred.sigma).abs() < 1e-10);
2283 }
2284
2285 #[test]
2286 fn batch_prediction_matches_individual() {
2287 let mut model = DistributionalSGBT::new(test_config());
2288
2289 for i in 0..100 {
2290 let x = i as f64 * 0.1;
2291 model.train_one(&(vec![x, x * 2.0], x));
2292 }
2293
2294 let features = vec![vec![1.0, 2.0], vec![3.0, 6.0], vec![5.0, 10.0]];
2295 let batch = model.predict_batch(&features);
2296
2297 for (feat, batch_pred) in features.iter().zip(batch.iter()) {
2298 let individual = model.predict(feat);
2299 assert!((batch_pred.mu - individual.mu).abs() < 1e-12);
2300 assert!((batch_pred.sigma - individual.sigma).abs() < 1e-12);
2301 }
2302 }
2303
2304 #[test]
2305 fn reset_clears_state() {
2306 let mut model = DistributionalSGBT::new(test_config());
2307
2308 for i in 0..200 {
2309 let x = i as f64 * 0.1;
2310 model.train_one(&(vec![x], x * 2.0));
2311 }
2312
2313 assert!(model.n_samples_seen() > 0);
2314 model.reset();
2315
2316 assert_eq!(model.n_samples_seen(), 0);
2317 assert!(!model.is_initialized());
2318 }
2319
    #[test]
    fn gaussian_prediction_lower_upper() {
        let pred = GaussianPrediction {
            mu: 10.0,
            sigma: 2.0,
            log_sigma: 2.0_f64.ln(),
            honest_sigma: 0.0,
        };

        assert!((pred.lower(1.96) - (10.0 - 1.96 * 2.0)).abs() < 1e-10);
        assert!((pred.upper(1.96) - (10.0 + 1.96 * 2.0)).abs() < 1e-10);
    }

    #[test]
    fn train_batch_works() {
        let mut model = DistributionalSGBT::new(test_config());
        let samples: Vec<(Vec<f64>, f64)> = (0..100)
            .map(|i| {
                let x = i as f64 * 0.1;
                (vec![x], x * 2.0)
            })
            .collect();

        model.train_batch(&samples);
        assert_eq!(model.n_samples_seen(), 100);
    }

    #[test]
    fn debug_format_works() {
        let model = DistributionalSGBT::new(test_config());
        let debug = format!("{:?}", model);
        assert!(debug.contains("DistributionalSGBT"));
    }

    #[test]
    fn n_trees_counts_both_ensembles() {
        let model = DistributionalSGBT::new(test_config());
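        // Both the location and scale ensembles count toward n_trees,
        // so with n_steps = 10 we expect at least 2 × 10 = 20 trees.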
        assert!(model.n_trees() >= 20);
    }

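    // Same as test_config() but with uncertainty_modulated_lr enabled.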
    fn modulated_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .uncertainty_modulated_lr(true)
            .build()
            .unwrap()
    }

    #[test]
    fn sigma_modulated_initializes_rolling_mean() {
        let mut model = DistributionalSGBT::new(modulated_config());
        assert!(model.is_uncertainty_modulated());

        assert!((model.rolling_sigma_mean() - 1.0).abs() < 1e-12);

        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 0.5], x * 2.0 + 1.0));
        }

        assert!(model.rolling_sigma_mean() > 0.0);
        assert!(model.rolling_sigma_mean().is_finite());
    }

    #[test]
    fn predict_distributional_returns_sigma_ratio() {
        let mut model = DistributionalSGBT::new(modulated_config());

        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0 + 1.0));
        }

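        // sigma_ratio is presumably the predicted sigma relative to the rolling
        // sigma mean (inferred from the modulation API); on clean data it should
        // stay within an order of magnitude of 1.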
        let (mu, sigma, sigma_ratio) = model.predict_distributional(&[5.0]);
        assert!(mu.is_finite());
        assert!(sigma > 0.0);
        assert!(
            (0.1..=10.0).contains(&sigma_ratio),
            "sigma_ratio={}",
            sigma_ratio
        );
    }

    #[test]
    fn predict_distributional_without_modulation_returns_one() {
        let mut model = DistributionalSGBT::new(test_config());
        assert!(!model.is_uncertainty_modulated());

        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0));
        }

        let (_mu, _sigma, sigma_ratio) = model.predict_distributional(&[5.0]);
        assert!(
            (sigma_ratio - 1.0).abs() < 1e-12,
            "should be 1.0 when disabled"
        );
    }

    #[test]
    fn modulated_model_sigma_finite_under_varying_noise() {
        let mut model = DistributionalSGBT::new(modulated_config());

        let mut rng: u64 = 123;
        for i in 0..500 {
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            let noise = (rng % 1000) as f64 / 100.0 - 5.0; // noise in [-5, 5)
            let x = i as f64 * 0.1;
            // Noise amplitude jumps 5x halfway through the stream.
            let scale = if i < 250 { 1.0 } else { 5.0 };
            model.train_one(&(vec![x], x * 2.0 + noise * scale));
        }

        let pred = model.predict(&[10.0]);
        assert!(pred.mu.is_finite());
        assert!(pred.sigma.is_finite());
        assert!(pred.sigma > 0.0);
        assert!(model.rolling_sigma_mean().is_finite());
    }

    #[test]
    fn reset_clears_rolling_sigma_mean() {
        let mut model = DistributionalSGBT::new(modulated_config());

        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0));
        }

        let sigma_before = model.rolling_sigma_mean();
        assert!(sigma_before > 0.0);

        model.reset();
        assert!((model.rolling_sigma_mean() - 1.0).abs() < 1e-12);
    }

    #[test]
    fn streaming_learner_returns_mu() {
        let mut model = DistributionalSGBT::new(test_config());
        for i in 0..200 {
            let x = i as f64 * 0.1;
            StreamingLearner::train(&mut model, &[x], x * 2.0 + 1.0);
        }
        let pred = StreamingLearner::predict(&model, &[5.0]);
        let gaussian = DistributionalSGBT::predict(&model, &[5.0]);
        assert!(
            (pred - gaussian.mu).abs() < 1e-12,
            "StreamingLearner::predict should return mu"
        );
    }

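    // Shared fixture: 500 samples of a 3-feature linear target.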
    fn trained_model() -> DistributionalSGBT {
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);
        for i in 0..500 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 0.5, (i % 3) as f64], x * 2.0 + 1.0));
        }
        model
    }

    #[test]
    fn diagnostics_returns_correct_tree_count() {
        let model = trained_model();
        let diag = model.diagnostics();
        assert_eq!(diag.trees.len(), 20, "should have 2*n_steps trees");
    }

    #[test]
    fn diagnostics_trees_have_leaves() {
        let model = trained_model();
        let diag = model.diagnostics();
        for (i, tree) in diag.trees.iter().enumerate() {
            assert!(tree.n_leaves >= 1, "tree {i} should have at least 1 leaf");
        }
        let total_samples: u64 = diag.trees.iter().map(|t| t.samples_seen).sum();
        assert!(
            total_samples > 0,
            "at least some trees should have seen samples"
        );
    }

    #[test]
    fn diagnostics_leaf_weight_stats_finite() {
        let model = trained_model();
        let diag = model.diagnostics();
        for (i, tree) in diag.trees.iter().enumerate() {
            let (min, max, mean, std) = tree.leaf_weight_stats;
            assert!(min.is_finite(), "tree {i} min not finite");
            assert!(max.is_finite(), "tree {i} max not finite");
            assert!(mean.is_finite(), "tree {i} mean not finite");
            assert!(std.is_finite(), "tree {i} std not finite");
            assert!(min <= max, "tree {i} min > max");
        }
    }

    #[test]
    fn diagnostics_base_predictions_match() {
        let model = trained_model();
        let diag = model.diagnostics();
        assert!(
            (diag.location_base - model.predict(&[0.0, 0.0, 0.0]).mu).abs() < 100.0,
            "location_base should be plausible"
        );
    }

    #[test]
    fn predict_decomposed_reconstructs_prediction() {
        let model = trained_model();
        let features = [5.0, 2.5, 1.0];
        let pred = model.predict(&features);
        let decomp = model.predict_decomposed(&features);

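        // DecomposedPrediction rebuilds mu as location_base plus the sum of the
        // per-step contributions (and log-sigma analogously), so both must agree
        // with predict() up to floating-point noise.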
        assert!(
            (decomp.mu() - pred.mu).abs() < 1e-10,
            "decomposed mu ({}) != predict mu ({})",
            decomp.mu(),
            pred.mu
        );
        assert!(
            (decomp.sigma() - pred.sigma).abs() < 1e-10,
            "decomposed sigma ({}) != predict sigma ({})",
            decomp.sigma(),
            pred.sigma
        );
    }

    #[test]
    fn predict_decomposed_correct_lengths() {
        let model = trained_model();
        let decomp = model.predict_decomposed(&[1.0, 0.5, 0.0]);
        assert_eq!(
            decomp.location_contributions.len(),
            model.n_steps(),
            "location contributions should have n_steps entries"
        );
        assert_eq!(
            decomp.scale_contributions.len(),
            model.n_steps(),
            "scale contributions should have n_steps entries"
        );
    }

    #[test]
    fn feature_importances_work() {
        let model = trained_model();
        let imp = model.feature_importances();
        for (i, &v) in imp.iter().enumerate() {
            assert!(v >= 0.0, "importance {i} should be non-negative, got {v}");
            assert!(v.is_finite(), "importance {i} should be finite");
        }
        let sum: f64 = imp.iter().sum();
        if sum > 0.0 {
            assert!(
                (sum - 1.0).abs() < 1e-10,
                "non-zero importances should sum to 1.0, got {sum}"
            );
        }
    }

    #[test]
    fn feature_importances_split_works() {
        let model = trained_model();
        let (loc_imp, scale_imp) = model.feature_importances_split();
        for (name, imp) in [("location", &loc_imp), ("scale", &scale_imp)] {
            let sum: f64 = imp.iter().sum();
            if sum > 0.0 {
                assert!(
                    (sum - 1.0).abs() < 1e-10,
                    "{name} importances should sum to 1.0, got {sum}"
                );
            }
            for &v in imp.iter() {
                assert!(v >= 0.0 && v.is_finite());
            }
        }
    }

    #[test]
    fn empirical_sigma_default_mode() {
        use crate::ensemble::config::ScaleMode;
        let config = test_config();
        let model = DistributionalSGBT::new(config);
        assert_eq!(model.scale_mode(), ScaleMode::Empirical);
    }

    #[test]
    fn empirical_sigma_tracks_errors() {
        let mut model = DistributionalSGBT::new(test_config());

        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 0.5], x * 2.0 + 1.0));
        }

        let sigma_clean = model.empirical_sigma();
        assert!(sigma_clean > 0.0, "sigma should be positive");
        assert!(sigma_clean.is_finite(), "sigma should be finite");

        let mut rng: u64 = 42;
        for i in 200..400 {
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            let noise = (rng % 10000) as f64 / 100.0 - 50.0; // noise in [-50, 50)
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 0.5], x * 2.0 + noise));
        }

        let sigma_noisy = model.empirical_sigma();
        assert!(
            sigma_noisy > sigma_clean,
            "noisy regime should increase sigma: clean={sigma_clean}, noisy={sigma_noisy}"
        );
    }

    #[test]
    fn empirical_sigma_modulated_lr_adapts() {
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .uncertainty_modulated_lr(true)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        for i in 0..300 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0 + 1.0));
        }

        let (_, _, sigma_ratio) = model.predict_distributional(&[5.0]);
        assert!(sigma_ratio.is_finite());
        assert!(
            (0.1..=10.0).contains(&sigma_ratio),
            "sigma_ratio={sigma_ratio}"
        );
    }

    #[test]
    fn tree_chain_mode_trains_scale_trees() {
        use crate::ensemble::config::ScaleMode;
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .scale_mode(ScaleMode::TreeChain)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);
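        // TreeChain mode appears to fit the scale (log-sigma) side with its own
        // boosted trees rather than the default empirical residual estimate.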
        assert_eq!(model.scale_mode(), ScaleMode::TreeChain);

        for i in 0..500 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 0.5, (i % 3) as f64], x * 2.0 + 1.0));
        }

        let pred = model.predict(&[5.0, 2.5, 1.0]);
        assert!(pred.mu.is_finite());
        assert!(pred.sigma > 0.0);
        assert!(pred.sigma.is_finite());
    }

    #[test]
    fn diagnostics_shows_empirical_sigma() {
        let model = trained_model();
        let diag = model.diagnostics();
        assert!(
            diag.empirical_sigma > 0.0,
            "empirical_sigma should be positive"
        );
        assert!(
            diag.empirical_sigma.is_finite(),
            "empirical_sigma should be finite"
        );
    }

    #[test]
    fn diagnostics_scale_trees_split_fields() {
        let model = trained_model();
        let diag = model.diagnostics();
        assert_eq!(diag.location_trees.len(), model.n_steps());
        assert_eq!(diag.scale_trees.len(), model.n_steps());
    }

    #[test]
    fn reset_clears_empirical_sigma() {
        let mut model = DistributionalSGBT::new(test_config());
        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0));
        }
        model.reset();
        assert!((model.empirical_sigma() - 1.0).abs() < 1e-12);
    }

    #[test]
    fn predict_smooth_returns_finite() {
        let config = SGBTConfig::builder()
            .n_steps(3)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        for i in 0..200 {
            let x = (i as f64) * 0.1;
            let features = vec![x, x.sin()];
            let target = 2.0 * x + 1.0;
            model.train_one(&(features, target));
        }

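        // The 0.5 argument is the smoothing parameter (presumably a kernel
        // bandwidth or blend weight); any positive value should give finite output.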
        let pred = model.predict_smooth(&[1.0, 1.0_f64.sin()], 0.5);
        assert!(pred.mu.is_finite(), "smooth mu should be finite");
        assert!(pred.sigma.is_finite(), "smooth sigma should be finite");
        assert!(pred.sigma > 0.0, "smooth sigma should be positive");
    }

    #[test]
    fn sigma_velocity_responds_to_error_spike() {
        let config = SGBTConfig::builder()
            .n_steps(3)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .uncertainty_modulated_lr(true)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        for i in 0..200 {
            let x = (i as f64) * 0.1;
            model.train_one(&(vec![x, x.sin()], 2.0 * x + 1.0));
        }

        let velocity_before = model.sigma_velocity();

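        // Jump to a target regime ~50x larger; the residual spike should push
        // the sigma estimate, and hence its velocity, upward.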
        for i in 0..50 {
            let x = (i as f64) * 0.1;
            model.train_one(&(vec![x, x.sin()], 100.0 * x + 50.0));
        }

        let velocity_after = model.sigma_velocity();

        assert!(
            velocity_after > velocity_before,
            "sigma velocity should increase after error spike: before={}, after={}",
            velocity_before,
            velocity_after,
        );
    }

    #[test]
    fn sigma_velocity_getter_works() {
        let config = SGBTConfig::builder()
            .n_steps(2)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .build()
            .unwrap();
        let model = DistributionalSGBT::new(config);
        assert_eq!(model.sigma_velocity(), 0.0);
    }

    #[test]
    fn diagnostics_leaf_sample_counts_populated() {
        let config = SGBTConfig::builder()
            .n_steps(3)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        for i in 0..200 {
            let x = (i as f64) * 0.1;
            let features = vec![x, x.sin()];
            let target = 2.0 * x + 1.0;
            model.train_one(&(features, target));
        }

        let diags = model.diagnostics();
        for (ti, tree) in diags.trees.iter().enumerate() {
            assert_eq!(
                tree.leaf_sample_counts.len(),
                tree.n_leaves,
                "tree {} should have sample count per leaf",
                ti,
            );
            if tree.samples_seen > 0 {
                let total: u64 = tree.leaf_sample_counts.iter().sum();
                assert!(
                    total > 0,
                    "tree {} has {} samples_seen but leaf counts sum to 0",
                    ti,
                    tree.samples_seen,
                );
            }
        }
    }

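    // Packed inference cache: an opt-in byte-packed snapshot of the ensemble
    // (see PackedInferenceCache), refreshed every `packed_refresh_interval` samples.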
    #[test]
    fn packed_cache_disabled_by_default() {
        let model = DistributionalSGBT::new(test_config());
        assert!(!model.has_packed_cache());
        assert_eq!(model.config().packed_refresh_interval, 0);
    }

    #[test]
    fn packed_cache_refreshes_after_interval() {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .grace_period(5)
            .max_depth(3)
            .n_bins(8)
            .initial_target_count(10)
            .packed_refresh_interval(20)
            .build()
            .unwrap();

        let mut model = DistributionalSGBT::new(config);

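        // 40 samples with a refresh interval of 20 guarantees the cache is
        // (re)built at least once during training.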
        for i in 0..40 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 2.0, x * 0.5], x * 3.0));
        }

        assert!(
            model.has_packed_cache(),
            "packed cache should exist after training past refresh interval"
        );

        let pred = model.predict(&[2.0, 4.0, 1.0]);
        assert!(pred.mu.is_finite(), "mu should be finite: {}", pred.mu);
    }

    #[test]
    fn packed_cache_matches_full_tree() {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .grace_period(5)
            .max_depth(3)
            .n_bins(8)
            .initial_target_count(10)
            .build()
            .unwrap();

        let mut model = DistributionalSGBT::new(config);

        for i in 0..80 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 2.0, x * 0.5], x * 3.0));
        }

        assert!(!model.has_packed_cache());
        let full_pred = model.predict(&[2.0, 4.0, 1.0]);

        model.enable_packed_cache(10);
        assert!(model.has_packed_cache());
        let cached_pred = model.predict(&[2.0, 4.0, 1.0]);

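        // The packed format trades precision for compactness (the assertion
        // below calls this "f32 tolerance"), so mu is compared coarsely while
        // sigma (apparently computed outside the cache) must match exactly.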
        let mu_diff = (full_pred.mu - cached_pred.mu).abs();
        assert!(
            mu_diff < 0.1,
            "packed cache mu ({}) should match full tree mu ({}) within f32 tolerance, diff={}",
            cached_pred.mu,
            full_pred.mu,
            mu_diff
        );

        assert!(
            (full_pred.sigma - cached_pred.sigma).abs() < 1e-12,
            "sigma should be identical: full={}, cached={}",
            full_pred.sigma,
            cached_pred.sigma
        );
    }

    #[test]
    fn honest_sigma_in_gaussian_prediction() {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .max_depth(3)
            .grace_period(2)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);
        for i in 0..100 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0));
        }

        let pred = model.predict(&[5.0]);
        assert!(
            pred.honest_sigma.is_finite(),
            "honest_sigma should be finite, got {}",
            pred.honest_sigma
        );
        assert!(
            pred.honest_sigma >= 0.0,
            "honest_sigma should be >= 0, got {}",
            pred.honest_sigma
        );
    }

    #[test]
    fn honest_sigma_increases_with_divergence() {
        let config = SGBTConfig::builder()
            .n_steps(8)
            .learning_rate(0.1)
            .max_depth(4)
            .grace_period(2)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        for i in 0..200 {
            let x = i as f64 * 0.05;
            model.train_one(&(vec![x], x * 2.0));
        }

        let in_dist = model.predict(&[5.0]);
        let out_dist = model.predict(&[500.0]);

        assert!(in_dist.honest_sigma.is_finite());
        assert!(out_dist.honest_sigma.is_finite());
        assert!(in_dist.honest_sigma >= 0.0);
        assert!(out_dist.honest_sigma >= 0.0);
    }

    #[test]
    fn honest_sigma_zero_for_single_step() {
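        // honest_sigma appears to measure disagreement across boosting steps;
        // with a single step there is nothing to disagree with, so it must be 0.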
        let config = SGBTConfig::builder()
            .n_steps(1)
            .learning_rate(0.1)
            .max_depth(3)
            .grace_period(2)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);
        for i in 0..100 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0));
        }
        let pred = model.predict(&[5.0]);
        assert!(
            pred.honest_sigma.abs() < 1e-15,
            "honest_sigma should be 0 with 1 step, got {}",
            pred.honest_sigma
        );
    }

    #[test]
    fn adaptive_mts_with_distributional() {
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .max_depth(3)
            .grace_period(5)
            .initial_target_count(10)
            .adaptive_mts(500, 1.0)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        for i in 0..500 {
            let x = (i as f64) * 0.02;
            model.train_one(&(vec![x, x * 0.5], x.sin()));
        }
        let pred = model.predict(&[1.0, 0.5]);
        assert!(
            pred.mu.is_finite(),
            "adaptive_mts distributional: mu should be finite, got {}",
            pred.mu
        );
        assert!(
            pred.sigma.is_finite() && pred.sigma > 0.0,
            "adaptive_mts distributional: sigma should be finite and positive, got {}",
            pred.sigma
        );
    }

    #[test]
    fn proactive_prune_with_distributional() {
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .max_depth(3)
            .grace_period(5)
            .initial_target_count(10)
            .proactive_prune_interval(100)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        for i in 0..300 {
            let x = (i as f64) * 0.03;
            model.train_one(&(vec![x, x * 0.7], x.cos()));
        }
        let pred = model.predict(&[2.0, 1.4]);
        assert!(
            pred.mu.is_finite(),
            "proactive_prune distributional: mu should be finite, got {}",
            pred.mu
        );
        assert!(
            pred.sigma.is_finite() && pred.sigma > 0.0,
            "proactive_prune distributional: sigma should be finite and positive, got {}",
            pred.sigma
        );
    }

    #[test]
    fn diagnostic_source_impl() {
        use crate::automl::DiagnosticSource;

        let mut model = DistributionalSGBT::new(test_config());

        for i in 0..200 {
            let x = (i as f64) * 0.1;
            model.train_one(&(vec![x, x * 0.3], x.sin()));
        }

        let diag = model.config_diagnostics();
        assert!(
            diag.is_some(),
            "config_diagnostics should return Some after training"
        );
        let diag = diag.unwrap();

        assert!(
            diag.effective_dof > 0.0,
            "effective_dof should be > 0, got {}",
            diag.effective_dof
        );
        assert!(
            diag.uncertainty.is_finite(),
            "uncertainty should be finite, got {}",
            diag.uncertainty
        );
    }

    #[test]
    fn diagnostic_signals_populated() {
        use crate::automl::DiagnosticSource;
        let mut model = DistributionalSGBT::new(test_config());

        for i in 0..500 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 0.5], x * 2.0 + 1.0));
        }

        let diag = model
            .config_diagnostics()
            .expect("should return Some diagnostics");

        assert!(
            diag.residual_alignment != 0.0,
            "residual_alignment should be non-zero, got {}",
            diag.residual_alignment
        );
        assert!(
            diag.regularization_sensitivity != 0.0,
            "regularization_sensitivity should be non-zero, got {}",
            diag.regularization_sensitivity
        );
        assert!(
            diag.depth_sufficiency != 0.0,
            "depth_sufficiency should be non-zero, got {}",
            diag.depth_sufficiency
        );
        assert!(
            diag.effective_dof != 0.0,
            "effective_dof should be non-zero, got {}",
            diag.effective_dof
        );
        assert!(
            diag.uncertainty.is_finite(),
            "uncertainty should be finite, got {}",
            diag.uncertainty
        );
    }

    #[test]
    fn residual_alignment_range() {
        use crate::automl::DiagnosticSource;
        let mut model = DistributionalSGBT::new(test_config());

        for i in 0..500 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 0.5], x * 2.0 + 1.0));
        }

        let diag = model
            .config_diagnostics()
            .expect("should return Some diagnostics");

        assert!(
            (-1.0..=1.0).contains(&diag.residual_alignment),
            "residual_alignment should be in [-1, 1], got {}",
            diag.residual_alignment
        );
    }

    #[test]
    fn effective_dof_positive() {
        use crate::automl::DiagnosticSource;
        let mut model = DistributionalSGBT::new(test_config());

        for i in 0..500 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 0.5], x * 2.0 + 1.0));
        }

        let diag = model
            .config_diagnostics()
            .expect("should return Some diagnostics");

        assert!(
            diag.effective_dof > 0.0,
            "effective_dof should be positive after training, got {}",
            diag.effective_dof
        );
    }

    #[test]
    fn depth_sufficiency_positive() {
        use crate::automl::DiagnosticSource;
        let mut model = DistributionalSGBT::new(test_config());

        for i in 0..500 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 0.5], x * 2.0 + 1.0));
        }

        let diag = model
            .config_diagnostics()
            .expect("should return Some diagnostics");

        assert!(
            diag.depth_sufficiency > 0.0,
            "depth_sufficiency should be positive after training, got {}",
            diag.depth_sufficiency
        );
    }
}

#[cfg(test)]
#[cfg(feature = "serde-json")]
mod serde_tests {
    use super::*;
    use crate::SGBTConfig;

    fn make_trained_distributional() -> DistributionalSGBT {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .max_depth(3)
            .grace_period(2)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);
        for i in 0..50 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x.sin()));
        }
        model
    }

    #[test]
    fn json_round_trip_preserves_predictions() {
        let model = make_trained_distributional();
        let state = model.to_distributional_state();
        let json = crate::serde_support::save_distributional_model(&state).unwrap();
        let loaded_state = crate::serde_support::load_distributional_model(&json).unwrap();
        let restored = DistributionalSGBT::from_distributional_state(loaded_state);

        let test_points = [0.5, 1.0, 2.0, 3.0];
        for &x in &test_points {
            let orig = model.predict(&[x]);
            let rest = restored.predict(&[x]);
            assert!(
                (orig.mu - rest.mu).abs() < 1e-10,
                "JSON round-trip mu mismatch at x={}: {} vs {}",
                x,
                orig.mu,
                rest.mu
            );
            assert!(
                (orig.sigma - rest.sigma).abs() < 1e-10,
                "JSON round-trip sigma mismatch at x={}: {} vs {}",
                x,
                orig.sigma,
                rest.sigma
            );
        }
    }

    #[test]
    fn state_preserves_rolling_sigma_mean() {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .max_depth(3)
            .grace_period(2)
            .initial_target_count(10)
            .uncertainty_modulated_lr(true)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);
        for i in 0..50 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x.sin()));
        }
        let state = model.to_distributional_state();
        assert!(state.uncertainty_modulated_lr);
        assert!(state.rolling_sigma_mean >= 0.0);

        let restored = DistributionalSGBT::from_distributional_state(state);
        assert_eq!(model.n_samples_seen(), restored.n_samples_seen());
    }

    #[test]
    fn auto_bandwidth_computed_distributional() {
        let config = SGBTConfig::builder()
            .n_steps(3)
            .learning_rate(0.1)
            .grace_period(10)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        for i in 0..200 {
            let x = (i as f64) * 0.1;
            model.train_one(&(vec![x, x.sin()], 2.0 * x + 1.0));
        }

        let bws = model.auto_bandwidths();
        assert_eq!(bws.len(), 2, "should have 2 feature bandwidths");

        let diag = model.diagnostics();
        assert_eq!(diag.auto_bandwidths.len(), 2);

        assert!(diag.location_trees[0].prediction_mean.is_finite());
        assert!(diag.location_trees[0].prediction_std.is_finite());

        let pred = model.predict(&[1.0, 1.0_f64.sin()]);
        assert!(pred.mu.is_finite(), "auto-bandwidth mu should be finite");
        assert!(pred.sigma > 0.0, "auto-bandwidth sigma should be positive");
    }

    #[test]
    fn max_leaf_output_clamps_predictions() {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(1.0) // deliberately aggressive step size
            .max_leaf_output(0.5)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        for i in 0..200 {
            let target = if i % 2 == 0 { 100.0 } else { -100.0 };
            let sample = crate::Sample::new(vec![i as f64 % 5.0, (i as f64).sin()], target);
            model.train_one(&sample);
        }

        let pred = model.predict(&[2.0, 0.5]);
        assert!(
            pred.mu.is_finite(),
            "prediction should be finite with clamping"
        );
    }

    #[test]
    fn min_hessian_sum_suppresses_fresh_leaves() {
        let config = SGBTConfig::builder()
            .n_steps(3)
            .learning_rate(0.01)
            .min_hessian_sum(50.0)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        for i in 0..60 {
            let sample = crate::Sample::new(vec![i as f64, (i as f64).sin()], i as f64 * 0.1);
            model.train_one(&sample);
        }

        let pred = model.predict(&[30.0, 0.5]);
        assert!(
            pred.mu.is_finite(),
            "prediction should be finite with min_hessian_sum"
        );
    }

    #[test]
    fn predict_interpolated_returns_finite() {
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);
        for i in 0..200 {
            let x = i as f64 * 0.1;
            let sample = crate::Sample::new(vec![x, x.sin()], x.cos());
            model.train_one(&sample);
        }

        let pred = model.predict_interpolated(&[1.0, 0.5]);
        assert!(pred.mu.is_finite(), "interpolated mu should be finite");
        assert!(pred.sigma > 0.0, "interpolated sigma should be positive");
    }

    #[test]
    fn huber_k_bounds_gradients() {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.01)
            .huber_k(1.345)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

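        // k = 1.345 is the classic Huber tuning constant (~95% efficiency under
        // Gaussian noise). Every 50th target below is a gross outlier that the
        // clipped gradients should absorb.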
        for i in 0..300 {
            let target = if i % 50 == 0 {
                1000.0
            } else {
                (i as f64 * 0.1).sin()
            };
            let sample = crate::Sample::new(vec![i as f64 % 10.0, (i as f64).cos()], target);
            model.train_one(&sample);
        }

        let pred = model.predict(&[5.0, 0.3]);
        assert!(
            pred.mu.is_finite(),
            "Huber-loss mu should be finite despite outliers"
        );
        assert!(pred.sigma > 0.0, "sigma should be positive");
    }

    #[test]
    fn ensemble_gradient_stats_populated() {
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);
        for i in 0..200 {
            let x = i as f64 * 0.1;
            let sample = crate::Sample::new(vec![x, x.sin()], x.cos());
            model.train_one(&sample);
        }

        let diag = model.diagnostics();
        assert!(
            diag.ensemble_grad_mean.is_finite(),
            "ensemble grad mean should be finite"
        );
        assert!(
            diag.ensemble_grad_std >= 0.0,
            "ensemble grad std should be non-negative"
        );
        assert!(
            diag.ensemble_grad_std.is_finite(),
            "ensemble grad std should be finite"
        );
    }

    #[test]
    fn huber_k_validation() {
        let result = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.01)
            .huber_k(-1.0)
            .build();
        assert!(result.is_err(), "negative huber_k should fail validation");
    }

    #[test]
    fn max_leaf_output_validation() {
        let result = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.01)
            .max_leaf_output(-1.0)
            .build();
        assert!(
            result.is_err(),
            "negative max_leaf_output should fail validation"
        );
    }

    #[test]
    fn predict_sibling_interpolated_varies_with_features() {
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(6)
            .delta(0.1)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        for i in 0..2000 {
            let x = (i as f64) * 0.01;
            let y = x.sin() * x + 0.5 * (x * 2.0).cos();
            let sample = crate::Sample::new(vec![x, x * 0.3], y);
            model.train_one(&sample);
        }

        let pred = model.predict_sibling_interpolated(&[5.0, 1.5]);
        assert!(
            pred.mu.is_finite(),
            "sibling interpolated mu should be finite"
        );
        assert!(
            pred.sigma > 0.0,
            "sibling interpolated sigma should be positive"
        );

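        // Sweep x and count how often consecutive predictions change: sibling
        // interpolation should respond to feature movement at least as often as
        // hard leaf routing, i.e. be no "steppier" than plain predict().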
        let bws = model.auto_bandwidths();
        if bws.iter().any(|&b| b.is_finite()) {
            let hard_preds: Vec<f64> = (0..200)
                .map(|i| {
                    let x = i as f64 * 0.1;
                    model.predict(&[x, x * 0.3]).mu
                })
                .collect();
            let hard_changes = hard_preds
                .windows(2)
                .filter(|w| (w[0] - w[1]).abs() > f64::EPSILON)
                .count();

            let preds: Vec<f64> = (0..200)
                .map(|i| {
                    let x = i as f64 * 0.1;
                    model.predict_sibling_interpolated(&[x, x * 0.3]).mu
                })
                .collect();

            let sibling_changes = preds
                .windows(2)
                .filter(|w| (w[0] - w[1]).abs() > f64::EPSILON)
                .count();
            assert!(
                sibling_changes >= hard_changes,
                "sibling should produce >= hard changes: sibling={}, hard={}",
                sibling_changes,
                hard_changes
            );
        }
    }
}