// irithyll_core/ensemble/config.rs
1//! SGBT configuration with builder pattern and full validation.
2//!
3//! [`SGBTConfig`] holds all hyperparameters for the Streaming Gradient Boosted
4//! Trees ensemble. Use [`SGBTConfig::builder`] for ergonomic construction with
5//! validation on [`build()`](SGBTConfigBuilder::build).
6
7use alloc::boxed::Box;
8use alloc::format;
9use alloc::string::String;
10use alloc::vec::Vec;
11
12use crate::drift::adwin::Adwin;
13use crate::drift::ddm::Ddm;
14use crate::drift::pht::PageHinkleyTest;
15use crate::drift::DriftDetector;
16use crate::ensemble::variants::SGBTVariant;
17use crate::error::{ConfigError, Result};
18use crate::tree::leaf_model::LeafModelType;
19
20// ---------------------------------------------------------------------------
21// FeatureType -- re-exported from irithyll-core
22// ---------------------------------------------------------------------------
23
24pub use crate::feature::FeatureType;
25
26// ---------------------------------------------------------------------------
27// ScaleMode
28// ---------------------------------------------------------------------------
29
/// Strategy used by [`DistributionalSGBT`](super::distributional::DistributionalSGBT)
/// to estimate predictive uncertainty (σ).
///
/// - **`Empirical`** (default): maintains an EWMA of squared prediction errors and
///   reports `σ = sqrt(ewma_sq_err)`. Always calibrated, needs no tuning, and costs
///   O(1) per sample. Prefer this when σ drives learning-rate modulation
///   (σ high → learn faster).
///
/// - **`TreeChain`**: fits an entire second ensemble of Hoeffding trees that predicts
///   log(σ) from the features (NGBoost-style dual chain). Yields *feature-conditional*
///   uncertainty, but only pays off when the scale gradients carry a strong signal.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ScaleMode {
    /// EWMA of squared prediction errors — always calibrated, O(1).
    #[default]
    Empirical,
    /// Full Hoeffding-tree ensemble for feature-conditional log(σ) prediction.
    TreeChain,
}
49
50// ---------------------------------------------------------------------------
51// DriftDetectorType
52// ---------------------------------------------------------------------------
53
/// Selects the drift detector instantiated for each boosting step.
///
/// Every variant carries the detector's full parameter set, so fresh instances
/// can be created on demand — e.g. when a drifted tree is replaced and needs a
/// brand-new detector.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum DriftDetectorType {
    /// Page-Hinkley Test parameterized by magnitude tolerance (`delta`) and
    /// detection threshold (`lambda`).
    PageHinkley {
        /// Magnitude tolerance. Default 0.005.
        delta: f64,
        /// Detection threshold. Default 50.0.
        lambda: f64,
    },

    /// ADWIN parameterized by its confidence value.
    Adwin {
        /// Confidence (smaller = fewer false positives). Default 0.002.
        delta: f64,
    },

    /// DDM parameterized by warning/drift multipliers and a warmup count.
    Ddm {
        /// Warning threshold multiplier. Default 2.0.
        warning_level: f64,
        /// Drift threshold multiplier. Default 3.0.
        drift_level: f64,
        /// Minimum observations before detection activates. Default 30.
        min_instances: u64,
    },
}
86
87impl Default for DriftDetectorType {
88 fn default() -> Self {
89 DriftDetectorType::PageHinkley {
90 delta: 0.005,
91 lambda: 50.0,
92 }
93 }
94}
95
96impl DriftDetectorType {
97 /// Create a new, fresh drift detector from this configuration.
98 pub fn create(&self) -> Box<dyn DriftDetector> {
99 match self {
100 Self::PageHinkley { delta, lambda } => {
101 Box::new(PageHinkleyTest::with_params(*delta, *lambda))
102 }
103 Self::Adwin { delta } => Box::new(Adwin::with_delta(*delta)),
104 Self::Ddm {
105 warning_level,
106 drift_level,
107 min_instances,
108 } => Box::new(Ddm::with_params(
109 *warning_level,
110 *drift_level,
111 *min_instances,
112 )),
113 }
114 }
115}
116
117// ---------------------------------------------------------------------------
118// SGBTConfig
119// ---------------------------------------------------------------------------
120
/// Configuration for the SGBT ensemble.
///
/// All numeric parameters are validated at build time via [`SGBTConfigBuilder`].
///
/// # Defaults
///
/// | Parameter                | Default                  |
/// |--------------------------|--------------------------|
/// | `n_steps`                | 100                      |
/// | `learning_rate`          | 0.0125                   |
/// | `feature_subsample_rate` | 0.75                     |
/// | `max_depth`              | 6                        |
/// | `n_bins`                 | 64                       |
/// | `lambda`                 | 1.0                      |
/// | `gamma`                  | 0.0                      |
/// | `grace_period`           | 200                      |
/// | `delta`                  | 1e-7                     |
/// | `drift_detector`         | PageHinkley(0.005, 50.0) |
/// | `variant`                | Standard                 |
/// | `seed`                   | 0xDEAD_BEEF_CAFE_4242    |
/// | `initial_target_count`   | 50                       |
/// | `leaf_half_life`         | None (disabled)          |
/// | `max_tree_samples`       | None (disabled)          |
/// | `split_reeval_interval`  | None (disabled)          |
///
/// Optional features not listed above default to `None` / `false` / `0`
/// (disabled); `quality_prune_threshold` defaults to 1e-6,
/// `quality_prune_patience` to 500, and `empirical_sigma_alpha` to 0.01.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SGBTConfig {
    /// Number of boosting steps (trees in the ensemble). Default 100.
    pub n_steps: usize,
    /// Learning rate (shrinkage). Default 0.0125.
    pub learning_rate: f64,
    /// Fraction of features to subsample per tree. Default 0.75.
    pub feature_subsample_rate: f64,
    /// Maximum tree depth. Default 6.
    pub max_depth: usize,
    /// Number of histogram bins. Default 64.
    pub n_bins: usize,
    /// L2 regularization parameter (lambda). Default 1.0.
    pub lambda: f64,
    /// Minimum split gain (gamma). Default 0.0.
    pub gamma: f64,
    /// Grace period: min samples before evaluating splits. Default 200.
    pub grace_period: usize,
    /// Hoeffding bound confidence (delta). Default 1e-7.
    pub delta: f64,
    /// Drift detector type for tree replacement. Default: PageHinkley.
    pub drift_detector: DriftDetectorType,
    /// SGBT computational variant. Default: Standard.
    pub variant: SGBTVariant,
    /// Random seed for deterministic reproducibility. Default: 0xDEAD_BEEF_CAFE_4242.
    ///
    /// Controls feature subsampling and variant skip/MI stochastic decisions.
    /// Two models with the same seed and same data will produce identical results.
    pub seed: u64,
    /// Number of initial targets to collect before computing the base prediction.
    /// Default: 50.
    pub initial_target_count: usize,

    /// Half-life for exponential leaf decay (in samples per leaf).
    ///
    /// After `leaf_half_life` samples, a leaf's accumulated gradient/hessian
    /// statistics have half the weight of the most recent sample. This causes
    /// the model to continuously adapt to changing data distributions rather
    /// than freezing on early observations.
    ///
    /// `None` (default) disables decay -- traditional monotonic accumulation.
    #[cfg_attr(feature = "serde", serde(default))]
    pub leaf_half_life: Option<usize>,

    /// Maximum samples a single tree processes before proactive replacement.
    ///
    /// After this many samples, the tree is replaced with a fresh one regardless
    /// of drift detector state. Prevents stale tree structure from persisting
    /// when the drift detector is not sensitive enough.
    ///
    /// `None` (default) disables time-based replacement.
    #[cfg_attr(feature = "serde", serde(default))]
    pub max_tree_samples: Option<u64>,

    /// Sigma-modulated tree replacement speed.
    ///
    /// When `Some((base_mts, k))`, the effective max_tree_samples becomes:
    /// `effective_mts = base_mts / (1.0 + k * normalized_sigma)`
    /// where `normalized_sigma = contribution_sigma / rolling_contribution_sigma_mean`.
    ///
    /// High uncertainty (contribution variance) shortens tree lifetime for faster
    /// adaptation. Low uncertainty lengthens lifetime for more knowledge accumulation.
    ///
    /// Overrides `max_tree_samples` when set. Default: `None`.
    #[cfg_attr(feature = "serde", serde(default))]
    pub adaptive_mts: Option<(u64, f64)>,

    /// Proactive tree pruning interval.
    ///
    /// Every `interval` training samples, the tree with the lowest EWMA
    /// marginal contribution is replaced with a fresh tree.
    /// Fresh trees (fewer than `interval / 2` samples) are excluded.
    ///
    /// When set, contribution tracking is automatically enabled even if
    /// `quality_prune_alpha` is `None` (uses alpha=0.01).
    ///
    /// Default: `None` (disabled).
    #[cfg_attr(feature = "serde", serde(default))]
    pub proactive_prune_interval: Option<u64>,

    /// Interval (in samples per leaf) at which max-depth leaves re-evaluate
    /// whether a split would improve them.
    ///
    /// Inspired by EFDT (Manapragada et al. 2018). When a leaf has accumulated
    /// `split_reeval_interval` samples since its last evaluation and has reached
    /// max depth, it re-evaluates whether a split should be performed.
    ///
    /// `None` (default) disables re-evaluation -- max-depth leaves are permanent.
    #[cfg_attr(feature = "serde", serde(default))]
    pub split_reeval_interval: Option<usize>,

    /// Optional human-readable feature names.
    ///
    /// When set, enables `named_feature_importances` and
    /// `train_one_named` for production-friendly named access.
    /// Length must match the number of features in training data.
    #[cfg_attr(feature = "serde", serde(default))]
    pub feature_names: Option<Vec<String>>,

    /// Optional per-feature type declarations.
    ///
    /// When set, declares which features are categorical vs continuous.
    /// Categorical features use one-bin-per-category binning and Fisher
    /// optimal binary partitioning for split evaluation.
    /// Length must match the number of features in training data.
    ///
    /// `None` (default) treats all features as continuous.
    #[cfg_attr(feature = "serde", serde(default))]
    pub feature_types: Option<Vec<FeatureType>>,

    /// Gradient clipping threshold in standard deviations per leaf.
    ///
    /// When enabled, each leaf tracks an EWMA of gradient mean and variance.
    /// Incoming gradients that exceed `mean ± sigma * gradient_clip_sigma` are
    /// clamped to the boundary. This prevents outlier samples from corrupting
    /// leaf statistics, which is critical in streaming settings where sudden
    /// label floods can destabilize the model.
    ///
    /// Typical value: 3.0 (3-sigma clipping).
    /// `None` (default) disables gradient clipping.
    #[cfg_attr(feature = "serde", serde(default))]
    pub gradient_clip_sigma: Option<f64>,

    /// Per-feature monotonic constraints.
    ///
    /// Each element specifies the monotonic relationship between a feature and
    /// the prediction:
    /// - `+1`: prediction must be non-decreasing as feature value increases.
    /// - `-1`: prediction must be non-increasing as feature value increases.
    /// - `0`: no constraint (unconstrained).
    ///
    /// During split evaluation, candidate splits that would violate monotonicity
    /// (left child value > right child value for +1 constraints, or vice versa)
    /// are rejected.
    ///
    /// Length must match the number of features in training data.
    /// `None` (default) means no monotonic constraints.
    #[cfg_attr(feature = "serde", serde(default))]
    pub monotone_constraints: Option<Vec<i8>>,

    /// EWMA smoothing factor for quality-based tree pruning.
    ///
    /// When `Some(alpha)`, each boosting step tracks an exponentially weighted
    /// moving average of its marginal contribution to the ensemble. Trees whose
    /// contribution drops below [`quality_prune_threshold`](Self::quality_prune_threshold)
    /// for [`quality_prune_patience`](Self::quality_prune_patience) consecutive
    /// samples are replaced with a fresh tree that can learn the current regime.
    ///
    /// This prevents "dead wood" -- trees from a past regime that no longer
    /// contribute meaningfully to ensemble accuracy.
    ///
    /// `None` (default) disables quality-based pruning.
    /// Suggested value: 0.01.
    #[cfg_attr(feature = "serde", serde(default))]
    pub quality_prune_alpha: Option<f64>,

    /// Minimum contribution threshold for quality-based pruning.
    ///
    /// A tree's EWMA contribution must stay above this value to avoid being
    /// flagged as dead wood. Only used when `quality_prune_alpha` is `Some`.
    ///
    /// Default: 1e-6.
    #[cfg_attr(feature = "serde", serde(default = "default_quality_prune_threshold"))]
    pub quality_prune_threshold: f64,

    /// Consecutive low-contribution samples before a tree is replaced.
    ///
    /// After this many consecutive samples where a tree's EWMA contribution
    /// is below `quality_prune_threshold`, the tree is reset. Only used when
    /// `quality_prune_alpha` is `Some`.
    ///
    /// Default: 500.
    #[cfg_attr(feature = "serde", serde(default = "default_quality_prune_patience"))]
    pub quality_prune_patience: u64,

    /// EWMA smoothing factor for error-weighted sample importance.
    ///
    /// When `Some(alpha)`, samples the model predicted poorly get higher
    /// effective weight during histogram accumulation. The weight is:
    /// `1.0 + |error| / (rolling_mean_error + epsilon)`, capped at 10x.
    ///
    /// This is a streaming version of AdaBoost's reweighting applied at the
    /// gradient level -- learning capacity focuses on hard/novel patterns,
    /// enabling faster adaptation to regime changes.
    ///
    /// `None` (default) disables error weighting.
    /// Suggested value: 0.01.
    #[cfg_attr(feature = "serde", serde(default))]
    pub error_weight_alpha: Option<f64>,

    /// Enable σ-modulated learning rate for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
    ///
    /// When `true`, the **location** (μ) ensemble's learning rate is scaled by
    /// `sigma_ratio = current_sigma / rolling_sigma_mean`, where `rolling_sigma_mean`
    /// is an EWMA of the model's predicted σ (alpha = 0.001).
    ///
    /// This means the model learns μ **faster** when σ is elevated (high uncertainty)
    /// and **slower** when σ is low (confident regime). The scale (σ) ensemble always
    /// trains at the unmodulated base rate to prevent positive feedback loops.
    ///
    /// Default: `false`.
    #[cfg_attr(feature = "serde", serde(default))]
    pub uncertainty_modulated_lr: bool,

    /// How the scale (σ) is estimated in [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
    ///
    /// - [`Empirical`](ScaleMode::Empirical) (default): EWMA of squared prediction
    ///   errors. `σ = sqrt(ewma_sq_err)`. Always calibrated, zero tuning, O(1).
    /// - [`TreeChain`](ScaleMode::TreeChain): full dual-chain NGBoost with a
    ///   separate tree ensemble predicting log(σ) from features.
    ///
    /// For σ-modulated learning (`uncertainty_modulated_lr = true`), `Empirical`
    /// is strongly recommended — scale tree gradients are inherently weak and
    /// the trees often fail to split.
    #[cfg_attr(feature = "serde", serde(default))]
    pub scale_mode: ScaleMode,

    /// EWMA smoothing factor for empirical σ estimation.
    ///
    /// Controls the adaptation speed of `σ = sqrt(ewma_sq_err)` when
    /// [`scale_mode`](Self::scale_mode) is [`Empirical`](ScaleMode::Empirical).
    /// Higher values react faster to regime changes but are noisier.
    ///
    /// Default: `0.01` (~100-sample effective window).
    #[cfg_attr(feature = "serde", serde(default = "default_empirical_sigma_alpha"))]
    pub empirical_sigma_alpha: f64,

    /// Maximum absolute leaf output value.
    ///
    /// When `Some(max)`, leaf predictions are clamped to `[-max, max]`.
    /// Prevents runaway leaf weights from causing prediction explosions
    /// in feedback loops. `None` (default) means no clamping.
    #[cfg_attr(feature = "serde", serde(default))]
    pub max_leaf_output: Option<f64>,

    /// Per-leaf adaptive output bound (sigma multiplier).
    ///
    /// When `Some(k)`, each leaf tracks an EWMA of its own output weight and
    /// clamps predictions to `|output_mean| + k * output_std`. The EWMA uses
    /// `leaf_decay_alpha` when `leaf_half_life` is set, otherwise Welford online.
    ///
    /// This is strictly superior to `max_leaf_output` for streaming — the bound
    /// is per-leaf, self-calibrating, and regime-synchronized. A leaf that usually
    /// outputs 0.3 can't suddenly output 2.9 just because it fits in the global clamp.
    ///
    /// Typical value: 3.0 (3-sigma bound).
    /// `None` (default) disables adaptive bounds (falls back to `max_leaf_output`).
    #[cfg_attr(feature = "serde", serde(default))]
    pub adaptive_leaf_bound: Option<f64>,

    /// Per-split information criterion (Lunde-Kleppe-Skaug 2020).
    ///
    /// When `Some(cir_factor)`, replaces `max_depth` with a per-split
    /// generalization test. Each candidate split must have
    /// `gain > cir_factor * sigma^2_g / n * n_features`.
    /// `max_depth * 2` becomes a hard safety ceiling.
    ///
    /// Typical: 7.5 (<=10 features), 9.0 (<=50), 11.0 (<=200).
    /// `None` (default) uses static `max_depth` only.
    #[cfg_attr(feature = "serde", serde(default))]
    pub adaptive_depth: Option<f64>,

    /// Minimum hessian sum before a leaf produces non-zero output.
    ///
    /// When `Some(min_h)`, leaves with `hess_sum < min_h` return 0.0.
    /// Prevents post-replacement spikes from fresh leaves with insufficient
    /// samples. `None` (default) means all leaves contribute immediately.
    #[cfg_attr(feature = "serde", serde(default))]
    pub min_hessian_sum: Option<f64>,

    /// Huber loss delta multiplier for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
    ///
    /// When `Some(k)`, the distributional location gradient uses Huber loss
    /// with adaptive `delta = k * empirical_sigma`. This bounds gradients by
    /// construction. Standard value: `1.345` (95% efficiency at Gaussian).
    /// `None` (default) uses squared loss.
    #[cfg_attr(feature = "serde", serde(default))]
    pub huber_k: Option<f64>,

    /// Shadow warmup for graduated tree handoff.
    ///
    /// When `Some(n)`, an always-on shadow (alternate) tree is spawned immediately
    /// alongside every active tree. The shadow trains on the same gradient stream
    /// but does not contribute to predictions until it has seen `n` samples.
    ///
    /// As the active tree ages past 80% of `max_tree_samples`, its prediction
    /// weight linearly decays to 0 at 120%. The shadow's weight ramps from 0 to 1
    /// over `n` samples after warmup. When the active weight reaches 0, the shadow
    /// is promoted and a new shadow is spawned — no cold-start prediction dip.
    ///
    /// Requires `max_tree_samples` to be set for time-based graduated handoff.
    /// Drift-based replacement still uses hard swap (shadow is already warm).
    ///
    /// `None` (default) disables graduated handoff — uses traditional hard swap.
    #[cfg_attr(feature = "serde", serde(default))]
    pub shadow_warmup: Option<usize>,

    /// Leaf prediction model type.
    ///
    /// Controls how each leaf computes its prediction:
    /// - [`ClosedForm`](LeafModelType::ClosedForm) (default): constant leaf weight.
    /// - [`Linear`](LeafModelType::Linear): per-leaf online ridge regression with
    ///   AdaGrad optimization. Optional `decay` for concept drift. Recommended for
    ///   low-depth trees (depth 2--4).
    /// - [`MLP`](LeafModelType::MLP): per-leaf single-hidden-layer neural network.
    ///   Optional `decay` for concept drift.
    /// - [`Adaptive`](LeafModelType::Adaptive): starts as closed-form, auto-promotes
    ///   when the Hoeffding bound confirms a more complex model is better.
    ///
    /// Default: [`ClosedForm`](LeafModelType::ClosedForm).
    #[cfg_attr(feature = "serde", serde(default))]
    pub leaf_model_type: LeafModelType,

    /// Packed cache refresh interval for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
    ///
    /// When non-zero, the distributional model maintains a packed f32 cache of
    /// its location ensemble that is re-exported every `packed_refresh_interval`
    /// training samples. Predictions use the cache for O(1)-per-tree inference
    /// via contiguous memory traversal, falling back to full tree traversal when
    /// the cache is absent or produces non-finite results.
    ///
    /// `0` (default) disables the packed cache.
    #[cfg_attr(feature = "serde", serde(default))]
    pub packed_refresh_interval: u64,

    /// Enable per-node auto-bandwidth soft routing at prediction time.
    ///
    /// When `true`, predictions are continuous weighted blends instead of
    /// piecewise-constant step functions. No training changes.
    ///
    /// Default: `false`.
    #[cfg_attr(feature = "serde", serde(default))]
    pub soft_routing: bool,
}
480
// Serde field-default providers. Each value MUST stay in sync with the
// corresponding literal in `SGBTConfig::default()` below; they exist only so
// that `serde(default = "...")` can fill in missing fields on deserialization.

/// Serde default for `SGBTConfig::empirical_sigma_alpha`.
#[cfg(feature = "serde")]
fn default_empirical_sigma_alpha() -> f64 {
    0.01
}

/// Serde default for `SGBTConfig::quality_prune_threshold`.
#[cfg(feature = "serde")]
fn default_quality_prune_threshold() -> f64 {
    1e-6
}

/// Serde default for `SGBTConfig::quality_prune_patience`.
#[cfg(feature = "serde")]
fn default_quality_prune_patience() -> u64 {
    500
}
495
impl Default for SGBTConfig {
    fn default() -> Self {
        Self {
            // Core boosting hyperparameters.
            n_steps: 100,
            learning_rate: 0.0125,
            feature_subsample_rate: 0.75,
            max_depth: 6,
            n_bins: 64,
            lambda: 1.0,
            gamma: 0.0,
            grace_period: 200,
            delta: 1e-7,
            drift_detector: DriftDetectorType::default(),
            variant: SGBTVariant::default(),
            seed: 0xDEAD_BEEF_CAFE_4242,
            initial_target_count: 50,
            // Streaming-adaptation features: all opt-in, disabled by default.
            leaf_half_life: None,
            max_tree_samples: None,
            adaptive_mts: None,
            proactive_prune_interval: None,
            split_reeval_interval: None,
            // Feature metadata (names/types) is optional.
            feature_names: None,
            feature_types: None,
            gradient_clip_sigma: None,
            monotone_constraints: None,
            // Quality-based pruning is off; threshold/patience values must
            // match the serde default fns defined above this impl.
            quality_prune_alpha: None,
            quality_prune_threshold: 1e-6,
            quality_prune_patience: 500,
            error_weight_alpha: None,
            // Distributional (σ) settings.
            uncertainty_modulated_lr: false,
            scale_mode: ScaleMode::default(),
            empirical_sigma_alpha: 0.01,
            // Leaf output safety bounds and structural extras: all disabled.
            max_leaf_output: None,
            adaptive_leaf_bound: None,
            adaptive_depth: None,
            min_hessian_sum: None,
            huber_k: None,
            shadow_warmup: None,
            leaf_model_type: LeafModelType::default(),
            packed_refresh_interval: 0,
            soft_routing: false,
        }
    }
}
540
541impl SGBTConfig {
542 /// Start building a configuration via the builder pattern.
543 pub fn builder() -> SGBTConfigBuilder {
544 SGBTConfigBuilder::default()
545 }
546}
547
548// ---------------------------------------------------------------------------
549// SGBTConfigBuilder
550// ---------------------------------------------------------------------------
551
/// Builder for [`SGBTConfig`] with validation on [`build()`](Self::build).
///
/// # Example
///
/// ```ignore
/// use irithyll::ensemble::config::{SGBTConfig, DriftDetectorType};
/// use irithyll::ensemble::variants::SGBTVariant;
///
/// let config = SGBTConfig::builder()
///     .n_steps(200)
///     .learning_rate(0.05)
///     .drift_detector(DriftDetectorType::Adwin { delta: 0.01 })
///     .variant(SGBTVariant::Skip { k: 10 })
///     .build()
///     .expect("valid config");
/// ```
#[derive(Debug, Clone, Default)]
pub struct SGBTConfigBuilder {
    // The configuration being assembled; starts from `SGBTConfig::default()`
    // (via the derived `Default`) and is validated only at `build()` time.
    config: SGBTConfig,
}
572
impl SGBTConfigBuilder {
    /// Set the number of boosting steps (trees in the ensemble).
    pub fn n_steps(self, n: usize) -> Self {
        Self { config: SGBTConfig { n_steps: n, ..self.config } }
    }

    /// Set the learning rate (shrinkage factor).
    pub fn learning_rate(self, lr: f64) -> Self {
        Self { config: SGBTConfig { learning_rate: lr, ..self.config } }
    }

    /// Set the fraction of features to subsample per tree.
    pub fn feature_subsample_rate(self, rate: f64) -> Self {
        Self { config: SGBTConfig { feature_subsample_rate: rate, ..self.config } }
    }

    /// Set the maximum tree depth.
    pub fn max_depth(self, depth: usize) -> Self {
        Self { config: SGBTConfig { max_depth: depth, ..self.config } }
    }

    /// Set the number of histogram bins per feature.
    pub fn n_bins(self, bins: usize) -> Self {
        Self { config: SGBTConfig { n_bins: bins, ..self.config } }
    }

    /// Set the L2 regularization parameter (lambda).
    pub fn lambda(self, l: f64) -> Self {
        Self { config: SGBTConfig { lambda: l, ..self.config } }
    }

    /// Set the minimum split gain (gamma).
    pub fn gamma(self, g: f64) -> Self {
        Self { config: SGBTConfig { gamma: g, ..self.config } }
    }

    /// Set the grace period (minimum samples before evaluating splits).
    pub fn grace_period(self, gp: usize) -> Self {
        Self { config: SGBTConfig { grace_period: gp, ..self.config } }
    }

    /// Set the Hoeffding bound confidence parameter (delta).
    pub fn delta(self, d: f64) -> Self {
        Self { config: SGBTConfig { delta: d, ..self.config } }
    }

    /// Set the drift detector type for tree replacement.
    pub fn drift_detector(self, dt: DriftDetectorType) -> Self {
        Self { config: SGBTConfig { drift_detector: dt, ..self.config } }
    }

    /// Set the SGBT computational variant.
    pub fn variant(self, v: SGBTVariant) -> Self {
        Self { config: SGBTConfig { variant: v, ..self.config } }
    }

    /// Set the random seed for deterministic reproducibility.
    ///
    /// Controls feature subsampling and variant skip/MI stochastic decisions.
    /// Two models with the same seed and data sequence produce identical results.
    pub fn seed(self, seed: u64) -> Self {
        Self { config: SGBTConfig { seed, ..self.config } }
    }
648
649 /// Set the number of initial targets to collect before computing the base prediction.
650 ///
651 /// The model collects this many target values before initializing the base
652 /// prediction (via `loss.initial_prediction`). Default: 50.
653 pub fn initial_target_count(mut self, count: usize) -> Self {
654 self.config.initial_target_count = count;
655 self
656 }
657
658 /// Set the half-life for exponential leaf decay (in samples per leaf).
659 ///
660 /// After `n` samples, a leaf's accumulated statistics have half the weight
661 /// of the most recent sample. Enables continuous adaptation to concept drift.
662 pub fn leaf_half_life(mut self, n: usize) -> Self {
663 self.config.leaf_half_life = Some(n);
664 self
665 }
666
667 /// Set the maximum samples a single tree processes before proactive replacement.
668 ///
669 /// After `n` samples, the tree is replaced regardless of drift detector state.
670 pub fn max_tree_samples(mut self, n: u64) -> Self {
671 self.config.max_tree_samples = Some(n);
672 self
673 }
674
675 /// Enable sigma-modulated tree replacement speed.
676 ///
677 /// The effective max_tree_samples becomes `base_mts / (1.0 + k * normalized_sigma)`,
678 /// where `normalized_sigma` tracks contribution variance across ensemble trees.
679 /// Overrides `max_tree_samples` when set.
680 pub fn adaptive_mts(mut self, base_mts: u64, k: f64) -> Self {
681 self.config.adaptive_mts = Some((base_mts, k));
682 self
683 }
684
685 /// Set the proactive tree pruning interval.
686 ///
687 /// Every `interval` training samples, the tree with the lowest EWMA
688 /// marginal contribution is replaced. Automatically enables contribution
689 /// tracking with alpha=0.01 when `quality_prune_alpha` is not set.
690 pub fn proactive_prune_interval(mut self, interval: u64) -> Self {
691 self.config.proactive_prune_interval = Some(interval);
692 self
693 }
694
695 /// Set the split re-evaluation interval for max-depth leaves.
696 ///
697 /// Every `n` samples per leaf, max-depth leaves re-evaluate whether a split
698 /// would improve them. Inspired by EFDT (Manapragada et al. 2018).
699 pub fn split_reeval_interval(mut self, n: usize) -> Self {
700 self.config.split_reeval_interval = Some(n);
701 self
702 }
703
704 /// Set human-readable feature names.
705 ///
706 /// Enables named feature importances and named training input.
707 /// Names must be unique; validated at [`build()`](Self::build).
708 pub fn feature_names(mut self, names: Vec<String>) -> Self {
709 self.config.feature_names = Some(names);
710 self
711 }
712
713 /// Set per-feature type declarations.
714 ///
715 /// Declares which features are categorical vs continuous. Categorical features
716 /// use one-bin-per-category binning and Fisher optimal binary partitioning.
717 /// Supports up to 64 distinct category values per categorical feature.
718 pub fn feature_types(mut self, types: Vec<FeatureType>) -> Self {
719 self.config.feature_types = Some(types);
720 self
721 }
722
723 /// Set per-leaf gradient clipping threshold (in standard deviations).
724 ///
725 /// Each leaf tracks an EWMA of gradient mean and variance. Gradients
726 /// exceeding `mean ± sigma * n` are clamped. Prevents outlier labels
727 /// from corrupting streaming model stability.
728 ///
729 /// Typical value: 3.0 (3-sigma clipping).
730 pub fn gradient_clip_sigma(mut self, sigma: f64) -> Self {
731 self.config.gradient_clip_sigma = Some(sigma);
732 self
733 }
734
735 /// Set per-feature monotonic constraints.
736 ///
737 /// `+1` = non-decreasing, `-1` = non-increasing, `0` = unconstrained.
738 /// Candidate splits violating monotonicity are rejected during tree growth.
739 pub fn monotone_constraints(mut self, constraints: Vec<i8>) -> Self {
740 self.config.monotone_constraints = Some(constraints);
741 self
742 }
743
744 /// Enable quality-based tree pruning with the given EWMA smoothing factor.
745 ///
746 /// Trees whose marginal contribution drops below the threshold for
747 /// `patience` consecutive samples are replaced with fresh trees.
748 /// Suggested alpha: 0.01.
749 pub fn quality_prune_alpha(mut self, alpha: f64) -> Self {
750 self.config.quality_prune_alpha = Some(alpha);
751 self
752 }
753
754 /// Set the minimum contribution threshold for quality-based pruning.
755 ///
756 /// Default: 1e-6. Only relevant when `quality_prune_alpha` is set.
757 pub fn quality_prune_threshold(mut self, threshold: f64) -> Self {
758 self.config.quality_prune_threshold = threshold;
759 self
760 }
761
762 /// Set the patience (consecutive low-contribution samples) before pruning.
763 ///
764 /// Default: 500. Only relevant when `quality_prune_alpha` is set.
765 pub fn quality_prune_patience(mut self, patience: u64) -> Self {
766 self.config.quality_prune_patience = patience;
767 self
768 }
769
770 /// Enable error-weighted sample importance with the given EWMA smoothing factor.
771 ///
772 /// Samples the model predicted poorly get higher effective weight.
773 /// Suggested alpha: 0.01.
774 pub fn error_weight_alpha(mut self, alpha: f64) -> Self {
775 self.config.error_weight_alpha = Some(alpha);
776 self
777 }
778
779 /// Enable σ-modulated learning rate for distributional models.
780 ///
781 /// Scales the location (μ) learning rate by `current_sigma / rolling_sigma_mean`,
782 /// so the model adapts faster during high-uncertainty regimes and conserves
783 /// during stable periods. Only affects [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
784 ///
785 /// By default uses empirical σ (EWMA of squared errors). Set
786 /// [`scale_mode(ScaleMode::TreeChain)`](Self::scale_mode) for feature-conditional σ.
787 pub fn uncertainty_modulated_lr(mut self, enabled: bool) -> Self {
788 self.config.uncertainty_modulated_lr = enabled;
789 self
790 }
791
792 /// Set the scale estimation mode for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
793 ///
794 /// - [`Empirical`](ScaleMode::Empirical): EWMA of squared prediction errors (default, recommended).
795 /// - [`TreeChain`](ScaleMode::TreeChain): dual-chain NGBoost with scale tree ensemble.
796 pub fn scale_mode(mut self, mode: ScaleMode) -> Self {
797 self.config.scale_mode = mode;
798 self
799 }
800
801 /// Enable per-node auto-bandwidth soft routing.
802 pub fn soft_routing(mut self, enabled: bool) -> Self {
803 self.config.soft_routing = enabled;
804 self
805 }
806
807 /// EWMA alpha for empirical σ. Controls adaptation speed. Default `0.01`.
808 ///
809 /// Only used when `scale_mode` is [`Empirical`](ScaleMode::Empirical).
810 pub fn empirical_sigma_alpha(mut self, alpha: f64) -> Self {
811 self.config.empirical_sigma_alpha = alpha;
812 self
813 }
814
815 /// Set the maximum absolute leaf output value.
816 ///
817 /// Clamps leaf predictions to `[-max, max]`, breaking feedback loops
818 /// that cause prediction explosions.
819 pub fn max_leaf_output(mut self, max: f64) -> Self {
820 self.config.max_leaf_output = Some(max);
821 self
822 }
823
824 /// Set per-leaf adaptive output bound (sigma multiplier).
825 ///
826 /// Each leaf tracks EWMA of its own output weight and clamps to
827 /// `|output_mean| + k * output_std`. Self-calibrating per-leaf.
828 /// Recommended: use with `leaf_half_life` for streaming scenarios.
829 pub fn adaptive_leaf_bound(mut self, k: f64) -> Self {
830 self.config.adaptive_leaf_bound = Some(k);
831 self
832 }
833
834 /// Set the per-split information criterion factor (Lunde-Kleppe-Skaug 2020).
835 ///
836 /// Replaces static `max_depth` with a per-split generalization test.
837 /// Typical: 7.5 (<=10 features), 9.0 (<=50), 11.0 (<=200).
838 pub fn adaptive_depth(mut self, factor: f64) -> Self {
839 self.config.adaptive_depth = Some(factor);
840 self
841 }
842
843 /// Set the minimum hessian sum for leaf output.
844 ///
845 /// Fresh leaves with `hess_sum < min_h` return 0.0, preventing
846 /// post-replacement spikes.
847 pub fn min_hessian_sum(mut self, min_h: f64) -> Self {
848 self.config.min_hessian_sum = Some(min_h);
849 self
850 }
851
852 /// Set the Huber loss delta multiplier for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
853 ///
854 /// When set, location gradients use Huber loss with adaptive
855 /// `delta = k * empirical_sigma`. Standard value: `1.345` (95% Gaussian efficiency).
856 pub fn huber_k(mut self, k: f64) -> Self {
857 self.config.huber_k = Some(k);
858 self
859 }
860
861 /// Enable graduated tree handoff with the given shadow warmup samples.
862 ///
863 /// Spawns an always-on shadow tree that trains alongside the active tree.
864 /// After `warmup` samples, the shadow begins contributing to predictions
865 /// via graduated blending. Eliminates prediction dips during tree replacement.
866 pub fn shadow_warmup(mut self, warmup: usize) -> Self {
867 self.config.shadow_warmup = Some(warmup);
868 self
869 }
870
871 /// Set the leaf prediction model type.
872 ///
873 /// [`LeafModelType::Linear`] is recommended for low-depth configurations
874 /// (depth 2--4) where per-leaf linear models reduce approximation error.
875 ///
876 /// [`LeafModelType::Adaptive`] automatically selects between closed-form and
877 /// a trainable model per leaf, using the Hoeffding bound for promotion.
878 pub fn leaf_model_type(mut self, lmt: LeafModelType) -> Self {
879 self.config.leaf_model_type = lmt;
880 self
881 }
882
883 /// Set the packed cache refresh interval for distributional models.
884 ///
885 /// When non-zero, [`DistributionalSGBT`](super::distributional::DistributionalSGBT)
886 /// maintains a packed f32 cache refreshed every `interval` training samples.
887 /// `0` (default) disables the cache.
888 pub fn packed_refresh_interval(mut self, interval: u64) -> Self {
889 self.config.packed_refresh_interval = interval;
890 self
891 }
892
893 /// Validate and build the configuration.
894 ///
895 /// # Errors
896 ///
897 /// Returns [`InvalidConfig`](crate::IrithyllError::InvalidConfig) with a structured
898 /// [`ConfigError`] if any parameter is out of its valid range.
899 pub fn build(self) -> Result<SGBTConfig> {
900 let c = &self.config;
901
902 // -- Ensemble-level parameters --
903 if c.n_steps == 0 {
904 return Err(ConfigError::out_of_range("n_steps", "must be > 0", c.n_steps).into());
905 }
906 if c.learning_rate <= 0.0 || c.learning_rate > 1.0 {
907 return Err(ConfigError::out_of_range(
908 "learning_rate",
909 "must be in (0, 1]",
910 c.learning_rate,
911 )
912 .into());
913 }
914 if c.feature_subsample_rate <= 0.0 || c.feature_subsample_rate > 1.0 {
915 return Err(ConfigError::out_of_range(
916 "feature_subsample_rate",
917 "must be in (0, 1]",
918 c.feature_subsample_rate,
919 )
920 .into());
921 }
922
923 // -- Tree-level parameters --
924 if c.max_depth == 0 {
925 return Err(ConfigError::out_of_range("max_depth", "must be > 0", c.max_depth).into());
926 }
927 if c.n_bins < 2 {
928 return Err(ConfigError::out_of_range("n_bins", "must be >= 2", c.n_bins).into());
929 }
930 if c.lambda < 0.0 {
931 return Err(ConfigError::out_of_range("lambda", "must be >= 0", c.lambda).into());
932 }
933 if c.gamma < 0.0 {
934 return Err(ConfigError::out_of_range("gamma", "must be >= 0", c.gamma).into());
935 }
936 if c.grace_period == 0 {
937 return Err(
938 ConfigError::out_of_range("grace_period", "must be > 0", c.grace_period).into(),
939 );
940 }
941 if c.delta <= 0.0 || c.delta >= 1.0 {
942 return Err(ConfigError::out_of_range("delta", "must be in (0, 1)", c.delta).into());
943 }
944
945 if c.initial_target_count == 0 {
946 return Err(ConfigError::out_of_range(
947 "initial_target_count",
948 "must be > 0",
949 c.initial_target_count,
950 )
951 .into());
952 }
953
954 // -- Streaming adaptation parameters --
955 if let Some(hl) = c.leaf_half_life {
956 if hl == 0 {
957 return Err(ConfigError::out_of_range("leaf_half_life", "must be >= 1", hl).into());
958 }
959 }
960 if let Some(max) = c.max_tree_samples {
961 if max < 100 {
962 return Err(
963 ConfigError::out_of_range("max_tree_samples", "must be >= 100", max).into(),
964 );
965 }
966 }
967 if let Some((base_mts, k)) = c.adaptive_mts {
968 if base_mts < 100 {
969 return Err(ConfigError::out_of_range(
970 "adaptive_mts.base_mts",
971 "must be >= 100",
972 base_mts,
973 )
974 .into());
975 }
976 if k <= 0.0 {
977 return Err(ConfigError::out_of_range("adaptive_mts.k", "must be > 0", k).into());
978 }
979 }
980 if let Some(interval) = c.proactive_prune_interval {
981 if interval < 100 {
982 return Err(ConfigError::out_of_range(
983 "proactive_prune_interval",
984 "must be >= 100",
985 interval,
986 )
987 .into());
988 }
989 }
990 if let Some(interval) = c.split_reeval_interval {
991 if interval < c.grace_period {
992 return Err(ConfigError::invalid(
993 "split_reeval_interval",
994 format!(
995 "must be >= grace_period ({}), got {}",
996 c.grace_period, interval
997 ),
998 )
999 .into());
1000 }
1001 }
1002
1003 // -- Feature names --
1004 if let Some(ref names) = c.feature_names {
1005 // O(n^2) duplicate check — names list is small so no HashSet needed
1006 for (i, name) in names.iter().enumerate() {
1007 for prev in &names[..i] {
1008 if name == prev {
1009 return Err(ConfigError::invalid(
1010 "feature_names",
1011 format!("duplicate feature name: '{}'", name),
1012 )
1013 .into());
1014 }
1015 }
1016 }
1017 }
1018
1019 // -- Feature types --
1020 if let Some(ref types) = c.feature_types {
1021 if let Some(ref names) = c.feature_names {
1022 if !names.is_empty() && !types.is_empty() && names.len() != types.len() {
1023 return Err(ConfigError::invalid(
1024 "feature_types",
1025 format!(
1026 "length ({}) must match feature_names length ({})",
1027 types.len(),
1028 names.len()
1029 ),
1030 )
1031 .into());
1032 }
1033 }
1034 }
1035
1036 // -- Gradient clipping --
1037 if let Some(sigma) = c.gradient_clip_sigma {
1038 if sigma <= 0.0 {
1039 return Err(
1040 ConfigError::out_of_range("gradient_clip_sigma", "must be > 0", sigma).into(),
1041 );
1042 }
1043 }
1044
1045 // -- Monotonic constraints --
1046 if let Some(ref mc) = c.monotone_constraints {
1047 for (i, &v) in mc.iter().enumerate() {
1048 if v != -1 && v != 0 && v != 1 {
1049 return Err(ConfigError::invalid(
1050 "monotone_constraints",
1051 format!("feature {}: must be -1, 0, or +1, got {}", i, v),
1052 )
1053 .into());
1054 }
1055 }
1056 }
1057
1058 // -- Leaf output clamping --
1059 if let Some(max) = c.max_leaf_output {
1060 if max <= 0.0 {
1061 return Err(
1062 ConfigError::out_of_range("max_leaf_output", "must be > 0", max).into(),
1063 );
1064 }
1065 }
1066
1067 // -- Per-leaf adaptive output bound --
1068 if let Some(k) = c.adaptive_leaf_bound {
1069 if k <= 0.0 {
1070 return Err(
1071 ConfigError::out_of_range("adaptive_leaf_bound", "must be > 0", k).into(),
1072 );
1073 }
1074 }
1075
1076 // -- Adaptive depth (per-split information criterion) --
1077 if let Some(factor) = c.adaptive_depth {
1078 if factor <= 0.0 {
1079 return Err(
1080 ConfigError::out_of_range("adaptive_depth", "must be > 0", factor).into(),
1081 );
1082 }
1083 }
1084
1085 // -- Minimum hessian sum --
1086 if let Some(min_h) = c.min_hessian_sum {
1087 if min_h <= 0.0 {
1088 return Err(
1089 ConfigError::out_of_range("min_hessian_sum", "must be > 0", min_h).into(),
1090 );
1091 }
1092 }
1093
1094 // -- Huber loss multiplier --
1095 if let Some(k) = c.huber_k {
1096 if k <= 0.0 {
1097 return Err(ConfigError::out_of_range("huber_k", "must be > 0", k).into());
1098 }
1099 }
1100
1101 // -- Shadow warmup --
1102 if let Some(warmup) = c.shadow_warmup {
1103 if warmup == 0 {
1104 return Err(ConfigError::out_of_range(
1105 "shadow_warmup",
1106 "must be > 0",
1107 warmup as f64,
1108 )
1109 .into());
1110 }
1111 }
1112
1113 // -- Quality-based pruning parameters --
1114 if let Some(alpha) = c.quality_prune_alpha {
1115 if alpha <= 0.0 || alpha >= 1.0 {
1116 return Err(ConfigError::out_of_range(
1117 "quality_prune_alpha",
1118 "must be in (0, 1)",
1119 alpha,
1120 )
1121 .into());
1122 }
1123 }
1124 if c.quality_prune_threshold <= 0.0 {
1125 return Err(ConfigError::out_of_range(
1126 "quality_prune_threshold",
1127 "must be > 0",
1128 c.quality_prune_threshold,
1129 )
1130 .into());
1131 }
1132 if c.quality_prune_patience == 0 {
1133 return Err(ConfigError::out_of_range(
1134 "quality_prune_patience",
1135 "must be > 0",
1136 c.quality_prune_patience,
1137 )
1138 .into());
1139 }
1140
1141 // -- Error-weighted sample importance --
1142 if let Some(alpha) = c.error_weight_alpha {
1143 if alpha <= 0.0 || alpha >= 1.0 {
1144 return Err(ConfigError::out_of_range(
1145 "error_weight_alpha",
1146 "must be in (0, 1)",
1147 alpha,
1148 )
1149 .into());
1150 }
1151 }
1152
1153 // -- Drift detector parameters --
1154 match &c.drift_detector {
1155 DriftDetectorType::PageHinkley { delta, lambda } => {
1156 if *delta <= 0.0 {
1157 return Err(ConfigError::out_of_range(
1158 "drift_detector.PageHinkley.delta",
1159 "must be > 0",
1160 delta,
1161 )
1162 .into());
1163 }
1164 if *lambda <= 0.0 {
1165 return Err(ConfigError::out_of_range(
1166 "drift_detector.PageHinkley.lambda",
1167 "must be > 0",
1168 lambda,
1169 )
1170 .into());
1171 }
1172 }
1173 DriftDetectorType::Adwin { delta } => {
1174 if *delta <= 0.0 || *delta >= 1.0 {
1175 return Err(ConfigError::out_of_range(
1176 "drift_detector.Adwin.delta",
1177 "must be in (0, 1)",
1178 delta,
1179 )
1180 .into());
1181 }
1182 }
1183 DriftDetectorType::Ddm {
1184 warning_level,
1185 drift_level,
1186 min_instances,
1187 } => {
1188 if *warning_level <= 0.0 {
1189 return Err(ConfigError::out_of_range(
1190 "drift_detector.Ddm.warning_level",
1191 "must be > 0",
1192 warning_level,
1193 )
1194 .into());
1195 }
1196 if *drift_level <= 0.0 {
1197 return Err(ConfigError::out_of_range(
1198 "drift_detector.Ddm.drift_level",
1199 "must be > 0",
1200 drift_level,
1201 )
1202 .into());
1203 }
1204 if *drift_level <= *warning_level {
1205 return Err(ConfigError::invalid(
1206 "drift_detector.Ddm.drift_level",
1207 format!(
1208 "must be > warning_level ({}), got {}",
1209 warning_level, drift_level
1210 ),
1211 )
1212 .into());
1213 }
1214 if *min_instances == 0 {
1215 return Err(ConfigError::out_of_range(
1216 "drift_detector.Ddm.min_instances",
1217 "must be > 0",
1218 min_instances,
1219 )
1220 .into());
1221 }
1222 }
1223 }
1224
1225 // -- Variant parameters --
1226 match &c.variant {
1227 SGBTVariant::Standard => {} // no extra validation
1228 SGBTVariant::Skip { k } => {
1229 if *k == 0 {
1230 return Err(
1231 ConfigError::out_of_range("variant.Skip.k", "must be > 0", k).into(),
1232 );
1233 }
1234 }
1235 SGBTVariant::MultipleIterations { multiplier } => {
1236 if *multiplier <= 0.0 {
1237 return Err(ConfigError::out_of_range(
1238 "variant.MultipleIterations.multiplier",
1239 "must be > 0",
1240 multiplier,
1241 )
1242 .into());
1243 }
1244 }
1245 }
1246
1247 Ok(self.config)
1248 }
1249}
1250
1251// ---------------------------------------------------------------------------
1252// Display impls
1253// ---------------------------------------------------------------------------
1254
1255impl core::fmt::Display for DriftDetectorType {
1256 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1257 match self {
1258 Self::PageHinkley { delta, lambda } => {
1259 write!(f, "PageHinkley(delta={}, lambda={})", delta, lambda)
1260 }
1261 Self::Adwin { delta } => write!(f, "Adwin(delta={})", delta),
1262 Self::Ddm {
1263 warning_level,
1264 drift_level,
1265 min_instances,
1266 } => write!(
1267 f,
1268 "Ddm(warning={}, drift={}, min_instances={})",
1269 warning_level, drift_level, min_instances
1270 ),
1271 }
1272 }
1273}
1274
1275impl core::fmt::Display for SGBTConfig {
1276 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1277 write!(
1278 f,
1279 "SGBTConfig {{ steps={}, lr={}, depth={}, bins={}, variant={}, drift={} }}",
1280 self.n_steps,
1281 self.learning_rate,
1282 self.max_depth,
1283 self.n_bins,
1284 self.variant,
1285 self.drift_detector,
1286 )
1287 }
1288}
1289
1290// ---------------------------------------------------------------------------
1291// Tests
1292// ---------------------------------------------------------------------------
1293
#[cfg(test)]
mod tests {
    // Unit tests for the builder: defaults, the fluent chain, every
    // validation rule in `build()`, and `DriftDetectorType::create()`.
    use super::*;
    use alloc::format;
    use alloc::vec;

    // ------------------------------------------------------------------
    // 1. Default config values are correct
    // ------------------------------------------------------------------
    #[test]
    fn default_config_values() {
        let cfg = SGBTConfig::default();
        assert_eq!(cfg.n_steps, 100);
        // Float fields are set from literals, so an EPSILON-width
        // comparison is effectively an exact equality check.
        assert!((cfg.learning_rate - 0.0125).abs() < f64::EPSILON);
        assert!((cfg.feature_subsample_rate - 0.75).abs() < f64::EPSILON);
        assert_eq!(cfg.max_depth, 6);
        assert_eq!(cfg.n_bins, 64);
        assert!((cfg.lambda - 1.0).abs() < f64::EPSILON);
        assert!((cfg.gamma - 0.0).abs() < f64::EPSILON);
        assert_eq!(cfg.grace_period, 200);
        assert!((cfg.delta - 1e-7).abs() < f64::EPSILON);
        assert_eq!(cfg.variant, SGBTVariant::Standard);
    }

    // ------------------------------------------------------------------
    // 2. Builder chain works
    // ------------------------------------------------------------------
    #[test]
    fn builder_chain() {
        let cfg = SGBTConfig::builder()
            .n_steps(50)
            .learning_rate(0.1)
            .feature_subsample_rate(0.5)
            .max_depth(10)
            .n_bins(128)
            .lambda(0.5)
            .gamma(0.1)
            .grace_period(500)
            .delta(1e-3)
            .drift_detector(DriftDetectorType::Adwin { delta: 0.01 })
            .variant(SGBTVariant::Skip { k: 5 })
            .build()
            .expect("valid config");

        assert_eq!(cfg.n_steps, 50);
        assert!((cfg.learning_rate - 0.1).abs() < f64::EPSILON);
        assert!((cfg.feature_subsample_rate - 0.5).abs() < f64::EPSILON);
        assert_eq!(cfg.max_depth, 10);
        assert_eq!(cfg.n_bins, 128);
        assert!((cfg.lambda - 0.5).abs() < f64::EPSILON);
        assert!((cfg.gamma - 0.1).abs() < f64::EPSILON);
        assert_eq!(cfg.grace_period, 500);
        assert!((cfg.delta - 1e-3).abs() < f64::EPSILON);
        assert_eq!(cfg.variant, SGBTVariant::Skip { k: 5 });

        // Verify drift detector type is Adwin.
        match &cfg.drift_detector {
            DriftDetectorType::Adwin { delta } => {
                assert!((delta - 0.01).abs() < f64::EPSILON);
            }
            _ => panic!("expected Adwin drift detector"),
        }
    }

    // ------------------------------------------------------------------
    // 3. Validation rejects invalid values
    // ------------------------------------------------------------------
    #[test]
    fn validation_rejects_n_steps_zero() {
        let result = SGBTConfig::builder().n_steps(0).build();
        assert!(result.is_err());
        // Error Display must name the offending parameter.
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("n_steps"));
    }

    #[test]
    fn validation_rejects_learning_rate_zero() {
        let result = SGBTConfig::builder().learning_rate(0.0).build();
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("learning_rate"));
    }

    #[test]
    fn validation_rejects_learning_rate_above_one() {
        let result = SGBTConfig::builder().learning_rate(1.5).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_accepts_learning_rate_one() {
        // Boundary: learning_rate range is (0, 1], so 1.0 is valid.
        let result = SGBTConfig::builder().learning_rate(1.0).build();
        assert!(result.is_ok());
    }

    #[test]
    fn validation_rejects_negative_learning_rate() {
        let result = SGBTConfig::builder().learning_rate(-0.1).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_feature_subsample_zero() {
        let result = SGBTConfig::builder().feature_subsample_rate(0.0).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_feature_subsample_above_one() {
        let result = SGBTConfig::builder().feature_subsample_rate(1.01).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_max_depth_zero() {
        let result = SGBTConfig::builder().max_depth(0).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_n_bins_one() {
        let result = SGBTConfig::builder().n_bins(1).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_negative_lambda() {
        let result = SGBTConfig::builder().lambda(-0.1).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_accepts_zero_lambda() {
        // Boundary: lambda range is [0, inf), so 0.0 is valid.
        let result = SGBTConfig::builder().lambda(0.0).build();
        assert!(result.is_ok());
    }

    #[test]
    fn validation_rejects_negative_gamma() {
        let result = SGBTConfig::builder().gamma(-0.1).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_grace_period_zero() {
        let result = SGBTConfig::builder().grace_period(0).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_delta_zero() {
        let result = SGBTConfig::builder().delta(0.0).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_delta_one() {
        // delta range is the open interval (0, 1): both endpoints rejected.
        let result = SGBTConfig::builder().delta(1.0).build();
        assert!(result.is_err());
    }

    // ------------------------------------------------------------------
    // 3b. Drift detector parameter validation
    // ------------------------------------------------------------------
    #[test]
    fn validation_rejects_pht_negative_delta() {
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::PageHinkley {
                delta: -1.0,
                lambda: 50.0,
            })
            .build();
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("PageHinkley"));
    }

    #[test]
    fn validation_rejects_pht_zero_lambda() {
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::PageHinkley {
                delta: 0.005,
                lambda: 0.0,
            })
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_adwin_delta_out_of_range() {
        // Both endpoints of the open interval (0, 1) must be rejected.
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::Adwin { delta: 0.0 })
            .build();
        assert!(result.is_err());

        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::Adwin { delta: 1.0 })
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_ddm_warning_above_drift() {
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::Ddm {
                warning_level: 3.0,
                drift_level: 2.0,
                min_instances: 30,
            })
            .build();
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("drift_level"));
        assert!(msg.contains("must be > warning_level"));
    }

    #[test]
    fn validation_rejects_ddm_equal_levels() {
        // drift_level must be strictly greater than warning_level.
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::Ddm {
                warning_level: 2.0,
                drift_level: 2.0,
                min_instances: 30,
            })
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_ddm_zero_min_instances() {
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::Ddm {
                warning_level: 2.0,
                drift_level: 3.0,
                min_instances: 0,
            })
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_ddm_zero_warning_level() {
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::Ddm {
                warning_level: 0.0,
                drift_level: 3.0,
                min_instances: 30,
            })
            .build();
        assert!(result.is_err());
    }

    // ------------------------------------------------------------------
    // 3c. Variant parameter validation
    // ------------------------------------------------------------------
    #[test]
    fn validation_rejects_skip_k_zero() {
        let result = SGBTConfig::builder()
            .variant(SGBTVariant::Skip { k: 0 })
            .build();
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("Skip"));
    }

    #[test]
    fn validation_rejects_mi_zero_multiplier() {
        let result = SGBTConfig::builder()
            .variant(SGBTVariant::MultipleIterations { multiplier: 0.0 })
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_mi_negative_multiplier() {
        let result = SGBTConfig::builder()
            .variant(SGBTVariant::MultipleIterations { multiplier: -1.0 })
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_accepts_standard_variant() {
        let result = SGBTConfig::builder().variant(SGBTVariant::Standard).build();
        assert!(result.is_ok());
    }

    // ------------------------------------------------------------------
    // 4. DriftDetectorType creates correct detector types
    // ------------------------------------------------------------------
    #[test]
    fn drift_detector_type_creates_page_hinkley() {
        let dt = DriftDetectorType::PageHinkley {
            delta: 0.01,
            lambda: 100.0,
        };
        let mut detector = dt.create();

        // Should start with zero mean.
        assert_eq!(detector.estimated_mean(), 0.0);

        // Feed a stable value -- should not drift.
        let signal = detector.update(1.0);
        assert_ne!(signal, crate::drift::DriftSignal::Drift);
    }

    #[test]
    fn drift_detector_type_creates_adwin() {
        let dt = DriftDetectorType::Adwin { delta: 0.05 };
        let mut detector = dt.create();

        assert_eq!(detector.estimated_mean(), 0.0);
        let signal = detector.update(1.0);
        assert_ne!(signal, crate::drift::DriftSignal::Drift);
    }

    #[test]
    fn drift_detector_type_creates_ddm() {
        let dt = DriftDetectorType::Ddm {
            warning_level: 2.0,
            drift_level: 3.0,
            min_instances: 30,
        };
        let mut detector = dt.create();

        assert_eq!(detector.estimated_mean(), 0.0);
        let signal = detector.update(0.1);
        assert_eq!(signal, crate::drift::DriftSignal::Stable);
    }

    // ------------------------------------------------------------------
    // 5. Default drift detector is PageHinkley
    // ------------------------------------------------------------------
    #[test]
    fn default_drift_detector_is_page_hinkley() {
        let dt = DriftDetectorType::default();
        match dt {
            DriftDetectorType::PageHinkley { delta, lambda } => {
                assert!((delta - 0.005).abs() < f64::EPSILON);
                assert!((lambda - 50.0).abs() < f64::EPSILON);
            }
            _ => panic!("expected default DriftDetectorType to be PageHinkley"),
        }
    }

    // ------------------------------------------------------------------
    // 6. Default config builds without error
    // ------------------------------------------------------------------
    #[test]
    fn default_config_builds() {
        let result = SGBTConfig::builder().build();
        assert!(result.is_ok());
    }

    // ------------------------------------------------------------------
    // 7. Config clone preserves all fields
    // ------------------------------------------------------------------
    #[test]
    fn config_clone_preserves_fields() {
        let original = SGBTConfig::builder()
            .n_steps(50)
            .learning_rate(0.05)
            .variant(SGBTVariant::MultipleIterations { multiplier: 5.0 })
            .build()
            .unwrap();

        let cloned = original.clone();

        assert_eq!(cloned.n_steps, 50);
        assert!((cloned.learning_rate - 0.05).abs() < f64::EPSILON);
        assert_eq!(
            cloned.variant,
            SGBTVariant::MultipleIterations { multiplier: 5.0 }
        );
    }

    // ------------------------------------------------------------------
    // 8. All three DDM valid configs accepted
    // ------------------------------------------------------------------
    #[test]
    fn valid_ddm_config_accepted() {
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::Ddm {
                warning_level: 1.5,
                drift_level: 2.5,
                min_instances: 10,
            })
            .build();
        assert!(result.is_ok());
    }

    // ------------------------------------------------------------------
    // 9. Created detectors are functional (round-trip test)
    // ------------------------------------------------------------------
    #[test]
    fn created_detectors_are_functional() {
        // Create a detector, feed it data, verify it can detect drift.
        let dt = DriftDetectorType::PageHinkley {
            delta: 0.005,
            lambda: 50.0,
        };
        let mut detector = dt.create();

        // Feed 500 stable values.
        for _ in 0..500 {
            detector.update(1.0);
        }

        // Feed shifted values -- should eventually drift.
        let mut drifted = false;
        for _ in 0..500 {
            if detector.update(10.0) == crate::drift::DriftSignal::Drift {
                drifted = true;
                break;
            }
        }
        assert!(
            drifted,
            "detector created from DriftDetectorType should be functional"
        );
    }

    // ------------------------------------------------------------------
    // 10. Boundary acceptance: n_bins=2 is the minimum valid
    // ------------------------------------------------------------------
    #[test]
    fn boundary_n_bins_two_accepted() {
        let result = SGBTConfig::builder().n_bins(2).build();
        assert!(result.is_ok());
    }

    // ------------------------------------------------------------------
    // 11. Boundary acceptance: grace_period=1 is valid
    // ------------------------------------------------------------------
    #[test]
    fn boundary_grace_period_one_accepted() {
        let result = SGBTConfig::builder().grace_period(1).build();
        assert!(result.is_ok());
    }

    // ------------------------------------------------------------------
    // 12. Feature names -- valid config with names
    // ------------------------------------------------------------------
    #[test]
    fn feature_names_accepted() {
        let cfg = SGBTConfig::builder()
            .feature_names(vec!["price".into(), "volume".into(), "spread".into()])
            .build()
            .unwrap();
        assert_eq!(
            cfg.feature_names.as_ref().unwrap(),
            &["price", "volume", "spread"]
        );
    }

    // ------------------------------------------------------------------
    // 13. Feature names -- duplicate names rejected
    // ------------------------------------------------------------------
    #[test]
    fn feature_names_rejects_duplicates() {
        let result = SGBTConfig::builder()
            .feature_names(vec!["price".into(), "volume".into(), "price".into()])
            .build();
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("duplicate"));
    }

    // ------------------------------------------------------------------
    // 14. Feature names -- serde backward compat (missing field)
    //     Requires serde_json which is only in the full irithyll crate.
    // ------------------------------------------------------------------

    // ------------------------------------------------------------------
    // 15. Feature names -- empty vec is valid
    // ------------------------------------------------------------------
    #[test]
    fn feature_names_empty_vec_accepted() {
        let cfg = SGBTConfig::builder().feature_names(vec![]).build().unwrap();
        assert!(cfg.feature_names.unwrap().is_empty());
    }

    // ------------------------------------------------------------------
    // 16. Adaptive leaf bound -- builder
    // ------------------------------------------------------------------
    #[test]
    fn builder_adaptive_leaf_bound() {
        let cfg = SGBTConfig::builder()
            .adaptive_leaf_bound(3.0)
            .build()
            .unwrap();
        assert_eq!(cfg.adaptive_leaf_bound, Some(3.0));
    }

    // ------------------------------------------------------------------
    // 17. Adaptive leaf bound -- validation rejects zero
    // ------------------------------------------------------------------
    #[test]
    fn validation_rejects_zero_adaptive_leaf_bound() {
        let result = SGBTConfig::builder().adaptive_leaf_bound(0.0).build();
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(
            msg.contains("adaptive_leaf_bound"),
            "error should mention adaptive_leaf_bound: {}",
            msg,
        );
    }

    // ------------------------------------------------------------------
    // 18. Adaptive leaf bound -- serde backward compat
    //     Requires serde_json which is only in the full irithyll crate.
    // ------------------------------------------------------------------
}