mod diagnostics;
mod inference;
mod training;
#[cfg(test)]
mod tests;
pub use diagnostics::{DecomposedPrediction, DistributionalTreeDiagnostic, ModelDiagnostics};
use alloc::vec::Vec;
use crate::ensemble::config::{SGBTConfig, ScaleMode};
use crate::ensemble::step::BoostingStep;
use crate::sample::{Observation, SampleRef};
/// Flattened, byte-packed snapshot of the ensemble used for fast inference.
/// Rebuilt periodically (see `packed_refresh_interval` on the owner) rather
/// than on every sample.
#[derive(Clone)]
struct PackedInferenceCache {
    /// Serialized tree data consumed by the packed-inference path.
    bytes: Vec<u8>,
    /// Base (prior) value added to packed predictions.
    base: f64,
    /// Feature count the packed layout was built for.
    n_features: usize,
}
/// A Gaussian predictive distribution for a single sample.
#[derive(Debug, Clone, Copy)]
pub struct GaussianPrediction {
    /// Predicted mean (location).
    pub mu: f64,
    /// Predicted standard deviation (scale).
    pub sigma: f64,
    /// Natural log of `sigma`, as produced by the scale head.
    pub log_sigma: f64,
    /// Ensemble-disagreement ("honest") uncertainty estimate.
    pub honest_sigma: f64,
}

impl GaussianPrediction {
    /// Lower edge of the `mu ± z·sigma` interval for the given z-score.
    #[inline]
    pub fn lower(&self, z: f64) -> f64 {
        let half_width = z * self.sigma;
        self.mu - half_width
    }

    /// Upper edge of the `mu ± z·sigma` interval for the given z-score.
    #[inline]
    pub fn upper(&self, z: f64) -> f64 {
        let half_width = z * self.sigma;
        self.mu + half_width
    }
}
/// Online stochastic gradient-boosted trees producing a Gaussian predictive
/// distribution `(mu, sigma)` per sample.
///
/// Two boosting chains are maintained: `location_steps` models the mean, and
/// `scale_steps` models the scale when `ScaleMode::TreeChain` is active; in
/// `ScaleMode::Empirical` the scale is derived from `ewma_sq_err` instead
/// (see the `Debug` impl).
#[derive(Clone)]
pub struct DistributionalSGBT {
    /// Configuration this model was built from; kept for `reset` and getters.
    config: SGBTConfig,
    /// Boosting chain for the location (mean) head.
    location_steps: Vec<BoostingStep>,
    /// Boosting chain for the scale head (used in `ScaleMode::TreeChain`).
    scale_steps: Vec<BoostingStep>,
    /// Base (prior) value for the location chain.
    location_base: f64,
    /// Base (prior) value for the scale chain.
    scale_base: f64,
    /// Whether the base values have been initialized from warm-up targets.
    base_initialized: bool,
    /// Warm-up targets buffered until `initial_target_count` are collected.
    initial_targets: Vec<f64>,
    /// Number of warm-up targets to collect before initializing the bases.
    initial_target_count: usize,
    /// Total observations trained on.
    samples_seen: u64,
    /// Internal PRNG state, seeded from `config.seed`.
    rng_state: u64,
    /// Enables uncertainty-based learning-rate modulation.
    uncertainty_modulated_lr: bool,
    /// Rolling mean of predicted sigma (used when modulation is enabled).
    rolling_sigma_mean: f64,
    /// Active scale-estimation strategy (copied out of `config`).
    scale_mode: ScaleMode,
    /// EWMA of squared error; its square root is the empirical sigma.
    ewma_sq_err: f64,
    /// Smoothing factor for `ewma_sq_err`.
    empirical_sigma_alpha: f64,
    /// Previous sigma value, for velocity tracking.
    prev_sigma: f64,
    /// Rate of change of sigma between updates.
    sigma_velocity: f64,
    /// Per-feature bandwidths for smoothed prediction (auto-estimated).
    auto_bandwidths: Vec<f64>,
    /// Running total used for tree-replacement accounting.
    last_replacement_sum: u64,
    /// Welford running mean of ensemble gradients.
    ensemble_grad_mean: f64,
    /// Welford running M2 (sum of squared deviations) of ensemble gradients.
    ensemble_grad_m2: f64,
    /// Welford sample count for the gradient statistics.
    ensemble_grad_count: u64,
    /// Rolling mean of the honest (ensemble-disagreement) sigma.
    rolling_honest_sigma_mean: f64,
    /// Packed fast-inference snapshot; `None` until first built.
    packed_cache: Option<PackedInferenceCache>,
    /// Samples trained since the packed cache was last rebuilt.
    samples_since_refresh: u64,
    /// Rebuild the packed cache every this many samples.
    packed_refresh_interval: u64,
}
impl core::fmt::Debug for DistributionalSGBT {
    /// Compact debug view: core counters first, then the scale field that
    /// matches the active `ScaleMode`, and the rolling sigma mean only when
    /// learning-rate modulation is enabled.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut dbg = f.debug_struct("DistributionalSGBT");
        dbg.field("n_steps", &self.location_steps.len());
        dbg.field("samples_seen", &self.samples_seen);
        dbg.field("location_base", &self.location_base);
        dbg.field("scale_mode", &self.scale_mode);
        dbg.field("base_initialized", &self.base_initialized);
        match self.scale_mode {
            ScaleMode::Empirical => {
                // Empirical sigma is the root of the EWMA of squared errors.
                dbg.field("empirical_sigma", &crate::math::sqrt(self.ewma_sq_err));
            }
            ScaleMode::TreeChain => {
                dbg.field("scale_base", &self.scale_base);
            }
        }
        if self.uncertainty_modulated_lr {
            dbg.field("rolling_sigma_mean", &self.rolling_sigma_mean);
        }
        dbg.finish()
    }
}
impl DistributionalSGBT {
pub fn new(config: SGBTConfig) -> Self {
let n_steps = config.n_steps;
let initial_target_count = config.initial_target_count;
let seed = config.seed;
let uncertainty_modulated_lr = config.uncertainty_modulated_lr;
let scale_mode = config.scale_mode;
let leaf_decay_alpha = config
.leaf_half_life
.map(|hl| crate::math::exp(-crate::math::ln(2.0) / hl as f64));
let tree_config = crate::ensemble::config::build_tree_config(&config)
.leaf_decay_alpha_opt(leaf_decay_alpha);
let max_tree_samples = config.max_tree_samples;
let shadow_warmup = config.shadow_warmup.unwrap_or(0);
let build_steps = |salt: u64| -> Vec<BoostingStep> {
(0..n_steps)
.map(|i| {
let mut tc = tree_config.clone();
tc.seed = seed ^ salt ^ (i as u64);
let detector = config.drift_detector.create();
if shadow_warmup > 0 {
BoostingStep::new_with_graduated(
tc,
detector,
max_tree_samples,
shadow_warmup,
)
} else {
BoostingStep::new_with_max_samples(tc, detector, max_tree_samples)
}
})
.collect()
};
let location_steps = build_steps(0);
let scale_steps = build_steps(0xD15C_A1E5_5CA1_E000);
Self {
config,
location_steps,
scale_steps,
location_base: 0.0,
scale_base: 0.0,
base_initialized: false,
initial_targets: Vec::with_capacity(initial_target_count),
initial_target_count,
samples_seen: 0,
rng_state: 1u64.wrapping_add(seed),
uncertainty_modulated_lr,
rolling_sigma_mean: 1.0,
scale_mode,
ewma_sq_err: 0.0,
empirical_sigma_alpha: 0.05,
prev_sigma: 0.0,
sigma_velocity: 0.0,
auto_bandwidths: Vec::new(),
last_replacement_sum: 0,
ensemble_grad_mean: 0.0,
ensemble_grad_m2: 0.0,
ensemble_grad_count: 0,
rolling_honest_sigma_mean: 1.0,
packed_cache: None,
samples_since_refresh: 0,
packed_refresh_interval: 1000,
}
}
pub fn config(&self) -> &SGBTConfig {
&self.config
}
pub fn train_one(&mut self, obs: &impl Observation) {
training::train_distributional_one(self, obs);
}
pub fn train_batch(&mut self, samples: &[(Vec<f64>, f64)]) {
for (features, target) in samples {
self.train_one(&(features.clone(), *target));
}
}
pub fn predict(&self, features: &[f64]) -> GaussianPrediction {
inference::predict_distributional(self, features)
}
pub fn predict_batch(&self, batch: &[Vec<f64>]) -> Vec<GaussianPrediction> {
batch.iter().map(|f| self.predict(f)).collect()
}
pub fn predict_interval(&self, features: &[f64], z: f64) -> (f64, f64) {
let pred = self.predict(features);
(pred.lower(z), pred.upper(z))
}
pub fn predict_distributional(&self, features: &[f64]) -> (f64, f64, f64) {
let pred = self.predict(features);
let ratio = if self.uncertainty_modulated_lr {
(pred.honest_sigma / self.rolling_honest_sigma_mean).clamp(0.1, 10.0)
} else {
1.0
};
(pred.mu, pred.sigma, ratio)
}
pub fn predict_smooth(&self, features: &[f64], bandwidth: f64) -> GaussianPrediction {
inference::predict_smooth(self, features, bandwidth)
}
pub fn predict_interpolated(&self, features: &[f64]) -> GaussianPrediction {
inference::predict_interpolated(self, features)
}
pub fn predict_sibling_interpolated(&self, features: &[f64]) -> GaussianPrediction {
inference::predict_sibling_interpolated(self, features)
}
pub fn is_initialized(&self) -> bool {
self.base_initialized
}
pub fn n_location_trees(&self) -> usize {
self.location_steps.len()
}
pub fn n_scale_trees(&self) -> usize {
self.scale_steps.len()
}
pub fn n_trees(&self) -> usize {
self.location_steps.len() + self.scale_steps.len()
}
pub fn n_samples_seen(&self) -> u64 {
self.samples_seen
}
pub fn is_uncertainty_modulated(&self) -> bool {
self.uncertainty_modulated_lr
}
pub fn rolling_sigma_mean(&self) -> f64 {
self.rolling_sigma_mean
}
pub fn reset(&mut self) {
self.location_steps.clear();
self.scale_steps.clear();
self.location_base = 0.0;
self.scale_base = 0.0;
self.base_initialized = false;
self.initial_targets.clear();
self.samples_seen = 0;
self.rng_state = 1u64.wrapping_add(self.config.seed);
self.rolling_sigma_mean = 1.0;
self.ewma_sq_err = 0.0;
self.prev_sigma = 0.0;
self.sigma_velocity = 0.0;
self.auto_bandwidths.clear();
self.ensemble_grad_mean = 0.0;
self.ensemble_grad_m2 = 0.0;
self.ensemble_grad_count = 0;
self.rolling_honest_sigma_mean = 1.0;
self.packed_cache = None;
}
pub fn diagnostics(&self) -> ModelDiagnostics {
diagnostics::compute_diagnostics(self)
}
pub fn predict_decomposed(&self, features: &[f64]) -> DecomposedPrediction {
diagnostics::decompose_prediction(self, features)
}
pub fn feature_importances(&self) -> Vec<f64> {
diagnostics::compute_feature_importances(self, false)
}
pub fn feature_importances_split(&self) -> (Vec<f64>, Vec<f64>) {
let location = diagnostics::compute_feature_importances(self, true);
let scale = diagnostics::compute_feature_importances_scale(self);
(location, scale)
}
#[allow(dead_code)]
fn compute_honest_sigma(&self, features: &[f64]) -> f64 {
if self.location_steps.len() < 2 {
return 0.0;
}
let preds: Vec<f64> = self
.location_steps
.iter()
.map(|s| s.predict(features))
.collect();
let n = preds.len() as f64;
let mean = preds.iter().sum::<f64>() / n;
let var = preds
.iter()
.map(|p| {
let d = p - mean;
d * d
})
.sum::<f64>()
/ (n - 1.0).max(1.0);
crate::math::sqrt(var)
}
}
/// Adapter so `DistributionalSGBT` can be used wherever a generic
/// `StreamingLearner` is expected; only the location mean is exposed as the
/// point prediction.
impl crate::learner::StreamingLearner for DistributionalSGBT {
    /// Wraps the raw slice in a weighted `SampleRef` and delegates to the
    /// inherent `train_one`. Calls are fully qualified to disambiguate from
    /// this trait method of the same name.
    fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
        let sample = SampleRef::weighted(features, target, weight);
        DistributionalSGBT::train_one(self, &sample);
    }
    /// Point prediction: the mean (`mu`) of the predicted Gaussian.
    fn predict(&self, features: &[f64]) -> f64 {
        DistributionalSGBT::predict(self, features).mu
    }
    /// Total observations trained on.
    fn n_samples_seen(&self) -> u64 {
        self.samples_seen
    }
    /// Delegates to the inherent `reset`.
    fn reset(&mut self) {
        DistributionalSGBT::reset(self);
    }
}