irithyll_core/ensemble/distributional/
mod.rs1mod diagnostics;
31mod inference;
32mod training;
33
34#[cfg(test)]
35mod tests;
36
37pub use diagnostics::{DecomposedPrediction, DistributionalTreeDiagnostic, ModelDiagnostics};
38
39use alloc::vec::Vec;
40
41use crate::ensemble::config::{SGBTConfig, ScaleMode};
42use crate::ensemble::step::BoostingStep;
43use crate::sample::{Observation, SampleRef};
44
/// Cached packed representation of the model used for inference.
///
/// `Clone` is derived: every field is plain data, so a derived field-wise clone
/// is exactly what a hand-written impl would do.
#[derive(Clone)]
struct PackedInferenceCache {
    /// Packed byte buffer (format defined by the inference module — NOTE(review): confirm).
    bytes: Vec<u8>,
    /// Base value stored alongside the packed bytes.
    base: f64,
    /// Number of input features the packed buffer was built for.
    n_features: usize,
}
64
/// A Gaussian predictive distribution returned by `DistributionalSGBT::predict`.
#[derive(Debug, Clone, Copy)]
pub struct GaussianPrediction {
    /// Predicted location (mean).
    pub mu: f64,
    /// Predicted scale (standard deviation); used for the interval bounds below.
    pub sigma: f64,
    /// Logarithm of the predicted scale (presumably `ln(sigma)` — confirm in inference).
    pub log_sigma: f64,
    /// Secondary uncertainty estimate; NOTE(review): presumably derived from
    /// disagreement across location trees — confirm in the inference module.
    pub honest_sigma: f64,
}

impl GaussianPrediction {
    /// Lower bound of the symmetric interval `mu ± z * sigma`.
    #[inline]
    pub fn lower(&self, z: f64) -> f64 {
        let half_width = z * self.sigma;
        self.mu - half_width
    }

    /// Upper bound of the symmetric interval `mu ± z * sigma`.
    #[inline]
    pub fn upper(&self, z: f64) -> f64 {
        let half_width = z * self.sigma;
        self.mu + half_width
    }
}
100
101pub struct DistributionalSGBT {
123 config: SGBTConfig,
124 location_steps: Vec<BoostingStep>,
125 scale_steps: Vec<BoostingStep>,
126 location_base: f64,
127 scale_base: f64,
128 base_initialized: bool,
129 initial_targets: Vec<f64>,
130 initial_target_count: usize,
131 samples_seen: u64,
132 rng_state: u64,
133 uncertainty_modulated_lr: bool,
134 rolling_sigma_mean: f64,
135 scale_mode: ScaleMode,
136 ewma_sq_err: f64,
137 empirical_sigma_alpha: f64,
138 prev_sigma: f64,
139 sigma_velocity: f64,
140 auto_bandwidths: Vec<f64>,
141 last_replacement_sum: u64,
142 ensemble_grad_mean: f64,
143 ensemble_grad_m2: f64,
144 ensemble_grad_count: u64,
145 rolling_honest_sigma_mean: f64,
146 packed_cache: Option<PackedInferenceCache>,
147 samples_since_refresh: u64,
148 packed_refresh_interval: u64,
149}
150
151impl Clone for DistributionalSGBT {
152 fn clone(&self) -> Self {
153 Self {
154 config: self.config.clone(),
155 location_steps: self.location_steps.clone(),
156 scale_steps: self.scale_steps.clone(),
157 location_base: self.location_base,
158 scale_base: self.scale_base,
159 base_initialized: self.base_initialized,
160 initial_targets: self.initial_targets.clone(),
161 initial_target_count: self.initial_target_count,
162 samples_seen: self.samples_seen,
163 rng_state: self.rng_state,
164 uncertainty_modulated_lr: self.uncertainty_modulated_lr,
165 rolling_sigma_mean: self.rolling_sigma_mean,
166 scale_mode: self.scale_mode,
167 ewma_sq_err: self.ewma_sq_err,
168 empirical_sigma_alpha: self.empirical_sigma_alpha,
169 prev_sigma: self.prev_sigma,
170 sigma_velocity: self.sigma_velocity,
171 auto_bandwidths: self.auto_bandwidths.clone(),
172 last_replacement_sum: self.last_replacement_sum,
173 ensemble_grad_mean: self.ensemble_grad_mean,
174 ensemble_grad_m2: self.ensemble_grad_m2,
175 ensemble_grad_count: self.ensemble_grad_count,
176 rolling_honest_sigma_mean: self.rolling_honest_sigma_mean,
177 packed_cache: self.packed_cache.clone(),
178 samples_since_refresh: self.samples_since_refresh,
179 packed_refresh_interval: self.packed_refresh_interval,
180 }
181 }
182}
183
184impl core::fmt::Debug for DistributionalSGBT {
185 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
186 let mut s = f.debug_struct("DistributionalSGBT");
187 s.field("n_steps", &self.location_steps.len())
188 .field("samples_seen", &self.samples_seen)
189 .field("location_base", &self.location_base)
190 .field("scale_mode", &self.scale_mode)
191 .field("base_initialized", &self.base_initialized);
192 match self.scale_mode {
193 ScaleMode::Empirical => {
194 s.field("empirical_sigma", &crate::math::sqrt(self.ewma_sq_err));
195 }
196 ScaleMode::TreeChain => {
197 s.field("scale_base", &self.scale_base);
198 }
199 }
200 if self.uncertainty_modulated_lr {
201 s.field("rolling_sigma_mean", &self.rolling_sigma_mean);
202 }
203 s.finish()
204 }
205}
206
207impl DistributionalSGBT {
208 pub fn new(config: SGBTConfig) -> Self {
210 let n_steps = config.n_steps;
211 let initial_target_count = config.initial_target_count;
212 let seed = config.seed;
213 let uncertainty_modulated_lr = config.uncertainty_modulated_lr;
214 let scale_mode = config.scale_mode;
215
216 let leaf_decay_alpha = config
217 .leaf_half_life
218 .map(|hl| crate::math::exp(-crate::math::ln(2.0) / hl as f64));
219 let tree_config = crate::ensemble::config::build_tree_config(&config)
220 .leaf_decay_alpha_opt(leaf_decay_alpha);
221 let max_tree_samples = config.max_tree_samples;
222 let shadow_warmup = config.shadow_warmup.unwrap_or(0);
223
224 let build_steps = |salt: u64| -> Vec<BoostingStep> {
225 (0..n_steps)
226 .map(|i| {
227 let mut tc = tree_config.clone();
228 tc.seed = seed ^ salt ^ (i as u64);
229 let detector = config.drift_detector.create();
230 if shadow_warmup > 0 {
231 BoostingStep::new_with_graduated(
232 tc,
233 detector,
234 max_tree_samples,
235 shadow_warmup,
236 )
237 } else {
238 BoostingStep::new_with_max_samples(tc, detector, max_tree_samples)
239 }
240 })
241 .collect()
242 };
243
244 let location_steps = build_steps(0);
245 let scale_steps = build_steps(0xD15C_A1E5_5CA1_E000);
246
247 Self {
248 config,
249 location_steps,
250 scale_steps,
251 location_base: 0.0,
252 scale_base: 0.0,
253 base_initialized: false,
254 initial_targets: Vec::with_capacity(initial_target_count),
255 initial_target_count,
256 samples_seen: 0,
257 rng_state: 1u64.wrapping_add(seed),
258 uncertainty_modulated_lr,
259 rolling_sigma_mean: 1.0,
260 scale_mode,
261 ewma_sq_err: 0.0,
262 empirical_sigma_alpha: 0.05,
263 prev_sigma: 0.0,
264 sigma_velocity: 0.0,
265 auto_bandwidths: Vec::new(),
266 last_replacement_sum: 0,
267 ensemble_grad_mean: 0.0,
268 ensemble_grad_m2: 0.0,
269 ensemble_grad_count: 0,
270 rolling_honest_sigma_mean: 1.0,
271 packed_cache: None,
272 samples_since_refresh: 0,
273 packed_refresh_interval: 1000,
274 }
275 }
276
277 pub fn config(&self) -> &SGBTConfig {
279 &self.config
280 }
281
282 pub fn train_one(&mut self, obs: &impl Observation) {
284 training::train_distributional_one(self, obs);
285 }
286
287 pub fn train_batch(&mut self, samples: &[(Vec<f64>, f64)]) {
289 for (features, target) in samples {
290 self.train_one(&(features.clone(), *target));
291 }
292 }
293
294 pub fn predict(&self, features: &[f64]) -> GaussianPrediction {
296 inference::predict_distributional(self, features)
297 }
298
299 pub fn predict_batch(&self, batch: &[Vec<f64>]) -> Vec<GaussianPrediction> {
301 batch.iter().map(|f| self.predict(f)).collect()
302 }
303
304 pub fn predict_interval(&self, features: &[f64], z: f64) -> (f64, f64) {
308 let pred = self.predict(features);
309 (pred.lower(z), pred.upper(z))
310 }
311
312 pub fn predict_distributional(&self, features: &[f64]) -> (f64, f64, f64) {
315 let pred = self.predict(features);
316 let ratio = if self.uncertainty_modulated_lr {
317 (pred.honest_sigma / self.rolling_honest_sigma_mean).clamp(0.1, 10.0)
318 } else {
319 1.0
320 };
321 (pred.mu, pred.sigma, ratio)
322 }
323
324 pub fn predict_smooth(&self, features: &[f64], bandwidth: f64) -> GaussianPrediction {
334 inference::predict_smooth(self, features, bandwidth)
335 }
336
337 pub fn predict_interpolated(&self, features: &[f64]) -> GaussianPrediction {
342 inference::predict_interpolated(self, features)
343 }
344
345 pub fn predict_sibling_interpolated(&self, features: &[f64]) -> GaussianPrediction {
351 inference::predict_sibling_interpolated(self, features)
352 }
353
354 pub fn is_initialized(&self) -> bool {
356 self.base_initialized
357 }
358
359 pub fn n_location_trees(&self) -> usize {
361 self.location_steps.len()
362 }
363
364 pub fn n_scale_trees(&self) -> usize {
366 self.scale_steps.len()
367 }
368
369 pub fn n_trees(&self) -> usize {
371 self.location_steps.len() + self.scale_steps.len()
372 }
373
374 pub fn n_samples_seen(&self) -> u64 {
376 self.samples_seen
377 }
378
379 pub fn is_uncertainty_modulated(&self) -> bool {
381 self.uncertainty_modulated_lr
382 }
383
384 pub fn rolling_sigma_mean(&self) -> f64 {
386 self.rolling_sigma_mean
387 }
388
389 pub fn reset(&mut self) {
391 self.location_steps.clear();
392 self.scale_steps.clear();
393 self.location_base = 0.0;
394 self.scale_base = 0.0;
395 self.base_initialized = false;
396 self.initial_targets.clear();
397 self.samples_seen = 0;
398 self.rng_state = 1u64.wrapping_add(self.config.seed);
399 self.rolling_sigma_mean = 1.0;
400 self.ewma_sq_err = 0.0;
401 self.prev_sigma = 0.0;
402 self.sigma_velocity = 0.0;
403 self.auto_bandwidths.clear();
404 self.ensemble_grad_mean = 0.0;
405 self.ensemble_grad_m2 = 0.0;
406 self.ensemble_grad_count = 0;
407 self.rolling_honest_sigma_mean = 1.0;
408 self.packed_cache = None;
409 }
410
411 pub fn diagnostics(&self) -> ModelDiagnostics {
413 diagnostics::compute_diagnostics(self)
414 }
415
416 pub fn predict_decomposed(&self, features: &[f64]) -> DecomposedPrediction {
418 diagnostics::decompose_prediction(self, features)
419 }
420
421 pub fn feature_importances(&self) -> Vec<f64> {
423 diagnostics::compute_feature_importances(self, false)
424 }
425
426 pub fn feature_importances_split(&self) -> (Vec<f64>, Vec<f64>) {
428 let location = diagnostics::compute_feature_importances(self, true);
429 let scale = diagnostics::compute_feature_importances_scale(self);
430 (location, scale)
431 }
432
433 #[allow(dead_code)]
435 fn compute_honest_sigma(&self, features: &[f64]) -> f64 {
436 if self.location_steps.len() < 2 {
437 return 0.0;
438 }
439
440 let preds: Vec<f64> = self
441 .location_steps
442 .iter()
443 .map(|s| s.predict(features))
444 .collect();
445
446 let n = preds.len() as f64;
447 let mean = preds.iter().sum::<f64>() / n;
448 let var = preds
449 .iter()
450 .map(|p| {
451 let d = p - mean;
452 d * d
453 })
454 .sum::<f64>()
455 / (n - 1.0).max(1.0);
456 crate::math::sqrt(var)
457 }
458}
459
460impl crate::learner::StreamingLearner for DistributionalSGBT {
461 fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
462 let sample = SampleRef::weighted(features, target, weight);
463 DistributionalSGBT::train_one(self, &sample);
464 }
465
466 fn predict(&self, features: &[f64]) -> f64 {
467 DistributionalSGBT::predict(self, features).mu
468 }
469
470 fn n_samples_seen(&self) -> u64 {
471 self.samples_seen
472 }
473
474 fn reset(&mut self) {
475 DistributionalSGBT::reset(self);
476 }
477}