irithyll_core/ensemble/config.rs
1//! SGBT configuration with builder pattern and full validation.
2//!
3//! [`SGBTConfig`] holds all hyperparameters for the Streaming Gradient Boosted
4//! Trees ensemble. Use [`SGBTConfig::builder`] for ergonomic construction with
5//! validation on [`build()`](SGBTConfigBuilder::build).
6
7use alloc::boxed::Box;
8use alloc::format;
9use alloc::string::String;
10use alloc::vec::Vec;
11
12use crate::drift::adwin::Adwin;
13use crate::drift::ddm::Ddm;
14use crate::drift::pht::PageHinkleyTest;
15use crate::drift::DriftDetector;
16use crate::ensemble::variants::SGBTVariant;
17use crate::error::{ConfigError, Result};
18use crate::tree::leaf_model::LeafModelType;
19
20// ---------------------------------------------------------------------------
21// FeatureType -- re-exported from irithyll-core
22// ---------------------------------------------------------------------------
23
24pub use crate::feature::FeatureType;
25
26// ---------------------------------------------------------------------------
27// ScaleMode
28// ---------------------------------------------------------------------------
29
30/// How [`DistributionalSGBT`](super::distributional::DistributionalSGBT)
31/// estimates uncertainty (σ).
32///
33/// - **`Empirical`** (default): tracks an EWMA of squared prediction errors.
34/// `σ = sqrt(ewma_sq_err)`. Always calibrated, zero tuning, O(1) compute.
35/// Use this when σ drives learning-rate modulation (σ high → learn faster).
36///
37/// - **`TreeChain`**: trains a full second ensemble of Hoeffding trees to predict
38/// log(σ) from features (NGBoost-style dual chain). Gives *feature-conditional*
39/// uncertainty but requires strong signal in the scale gradients.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ScaleMode {
    /// EWMA of squared prediction errors — always calibrated, O(1).
    /// Smoothing factor is [`SGBTConfig::empirical_sigma_alpha`].
    #[default]
    Empirical,
    /// Full Hoeffding-tree ensemble for feature-conditional log(σ) prediction.
    /// Heavier than `Empirical`; see the enum-level docs for trade-offs.
    TreeChain,
}
49
50// ---------------------------------------------------------------------------
51// DriftDetectorType
52// ---------------------------------------------------------------------------
53
54/// Which drift detector to instantiate for each boosting step.
55///
56/// Each variant stores the detector's configuration parameters so that fresh
57/// instances can be created on demand (e.g. when replacing a drifted tree).
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum DriftDetectorType {
    /// Page-Hinkley Test with custom delta (magnitude tolerance) and lambda
    /// (detection threshold). Instantiated as [`PageHinkleyTest`] by
    /// [`DriftDetectorType::create`].
    PageHinkley {
        /// Magnitude tolerance. Default 0.005.
        delta: f64,
        /// Detection threshold. Default 50.0.
        lambda: f64,
    },

    /// ADWIN with custom confidence parameter. Instantiated as [`Adwin`] by
    /// [`DriftDetectorType::create`].
    Adwin {
        /// Confidence (smaller = fewer false positives). Default 0.002.
        delta: f64,
    },

    /// DDM with custom warning/drift levels and minimum warmup instances.
    /// Instantiated as [`Ddm`] by [`DriftDetectorType::create`].
    Ddm {
        /// Warning threshold multiplier. Default 2.0.
        warning_level: f64,
        /// Drift threshold multiplier. Default 3.0.
        drift_level: f64,
        /// Minimum observations before detection activates. Default 30.
        min_instances: u64,
    },
}
86
87impl Default for DriftDetectorType {
88 fn default() -> Self {
89 DriftDetectorType::PageHinkley {
90 delta: 0.005,
91 lambda: 50.0,
92 }
93 }
94}
95
96impl DriftDetectorType {
97 /// Create a new, fresh drift detector from this configuration.
98 pub fn create(&self) -> Box<dyn DriftDetector> {
99 match self {
100 Self::PageHinkley { delta, lambda } => {
101 Box::new(PageHinkleyTest::with_params(*delta, *lambda))
102 }
103 Self::Adwin { delta } => Box::new(Adwin::with_delta(*delta)),
104 Self::Ddm {
105 warning_level,
106 drift_level,
107 min_instances,
108 } => Box::new(Ddm::with_params(
109 *warning_level,
110 *drift_level,
111 *min_instances,
112 )),
113 }
114 }
115}
116
117// ---------------------------------------------------------------------------
118// SGBTConfig
119// ---------------------------------------------------------------------------
120
121/// Configuration for the SGBT ensemble.
122///
123/// All numeric parameters are validated at build time via [`SGBTConfigBuilder`].
124///
125/// # Defaults
126///
127/// | Parameter | Default |
128/// |--------------------------|----------------------|
129/// | `n_steps` | 100 |
130/// | `learning_rate` | 0.0125 |
131/// | `feature_subsample_rate` | 0.75 |
132/// | `max_depth` | 6 |
133/// | `n_bins` | 64 |
134/// | `lambda` | 1.0 |
135/// | `gamma` | 0.0 |
136/// | `grace_period` | 200 |
137/// | `delta` | 1e-7 |
138/// | `drift_detector` | PageHinkley(0.005, 50.0) |
139/// | `variant` | Standard |
140/// | `seed` | 0xDEAD_BEEF_CAFE_4242 |
141/// | `initial_target_count` | 50 |
142/// | `leaf_half_life` | None (disabled) |
143/// | `max_tree_samples` | None (disabled) |
144/// | `split_reeval_interval` | None (disabled) |
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SGBTConfig {
    /// Number of boosting steps (trees in the ensemble). Default 100.
    pub n_steps: usize,
    /// Learning rate (shrinkage). Default 0.0125.
    pub learning_rate: f64,
    /// Fraction of features to subsample per tree. Default 0.75.
    pub feature_subsample_rate: f64,
    /// Maximum tree depth. Default 6.
    pub max_depth: usize,
    /// Number of histogram bins. Default 64.
    pub n_bins: usize,
    /// L2 regularization parameter (lambda). Default 1.0.
    pub lambda: f64,
    /// Minimum split gain (gamma). Default 0.0.
    pub gamma: f64,
    /// Grace period: min samples before evaluating splits. Default 200.
    pub grace_period: usize,
    /// Hoeffding bound confidence (delta). Default 1e-7.
    pub delta: f64,
    /// Drift detector type for tree replacement. Default: PageHinkley.
    pub drift_detector: DriftDetectorType,
    /// SGBT computational variant. Default: Standard.
    pub variant: SGBTVariant,
    /// Random seed for deterministic reproducibility. Default: 0xDEAD_BEEF_CAFE_4242.
    ///
    /// Controls feature subsampling and variant skip/MI stochastic decisions.
    /// Two models with the same seed and same data will produce identical results.
    pub seed: u64,
    /// Number of initial targets to collect before computing the base prediction.
    /// Default: 50.
    pub initial_target_count: usize,

    /// Half-life for exponential leaf decay (in samples per leaf).
    ///
    /// After `leaf_half_life` samples, a leaf's accumulated gradient/hessian
    /// statistics have half the weight of the most recent sample. This causes
    /// the model to continuously adapt to changing data distributions rather
    /// than freezing on early observations.
    ///
    /// `None` (default) disables decay -- traditional monotonic accumulation.
    #[cfg_attr(feature = "serde", serde(default))]
    pub leaf_half_life: Option<usize>,

    /// Maximum samples a single tree processes before proactive replacement.
    ///
    /// After this many samples, the tree is replaced with a fresh one regardless
    /// of drift detector state. Prevents stale tree structure from persisting
    /// when the drift detector is not sensitive enough.
    ///
    /// `None` (default) disables time-based replacement.
    #[cfg_attr(feature = "serde", serde(default))]
    pub max_tree_samples: Option<u64>,

    /// Sigma-modulated tree replacement speed.
    ///
    /// When `Some((base_mts, k))`, the effective max_tree_samples becomes:
    /// `effective_mts = base_mts / (1.0 + k * normalized_sigma)`
    /// where `normalized_sigma = contribution_sigma / rolling_contribution_sigma_mean`.
    ///
    /// High uncertainty (contribution variance) shortens tree lifetime for faster
    /// adaptation. Low uncertainty lengthens lifetime for more knowledge accumulation.
    ///
    /// Overrides `max_tree_samples` when set. Default: `None`.
    #[cfg_attr(feature = "serde", serde(default))]
    pub adaptive_mts: Option<(u64, f64)>,

    /// Minimum effective MTS as a fraction of `base_mts`.
    ///
    /// When adaptive MTS is active, the effective tree lifetime can shrink
    /// aggressively under high uncertainty. This floor prevents runaway
    /// replacement by clamping `effective_mts >= base_mts * fraction`.
    ///
    /// The hard floor (grace_period * 2) still applies beneath this.
    ///
    /// Only meaningful when [`adaptive_mts`](Self::adaptive_mts) is `Some`.
    /// Default: `0.0` (only the hard floor applies).
    #[cfg_attr(feature = "serde", serde(default))]
    pub adaptive_mts_floor: f64,

    /// Proactive tree pruning interval.
    ///
    /// Every `interval` training samples, the tree with the lowest EWMA
    /// marginal contribution is replaced with a fresh tree.
    /// Fresh trees (fewer than `interval / 2` samples) are excluded.
    ///
    /// When set, contribution tracking is automatically enabled even if
    /// `quality_prune_alpha` is `None` (uses alpha=0.01).
    ///
    /// Default: `None` (disabled).
    #[cfg_attr(feature = "serde", serde(default))]
    pub proactive_prune_interval: Option<u64>,

    /// Interval (in samples per leaf) at which max-depth leaves re-evaluate
    /// whether a split would improve them.
    ///
    /// Inspired by EFDT (Manapragada et al. 2018). When a leaf has accumulated
    /// `split_reeval_interval` samples since its last evaluation and has reached
    /// max depth, it re-evaluates whether a split should be performed.
    ///
    /// `None` (default) disables re-evaluation -- max-depth leaves are permanent.
    #[cfg_attr(feature = "serde", serde(default))]
    pub split_reeval_interval: Option<usize>,

    /// Optional human-readable feature names.
    ///
    /// When set, enables `named_feature_importances` and
    /// `train_one_named` for production-friendly named access.
    /// Length must match the number of features in training data.
    #[cfg_attr(feature = "serde", serde(default))]
    pub feature_names: Option<Vec<String>>,

    /// Optional per-feature type declarations.
    ///
    /// When set, declares which features are categorical vs continuous.
    /// Categorical features use one-bin-per-category binning and Fisher
    /// optimal binary partitioning for split evaluation.
    /// Length must match the number of features in training data.
    ///
    /// `None` (default) treats all features as continuous.
    #[cfg_attr(feature = "serde", serde(default))]
    pub feature_types: Option<Vec<FeatureType>>,

    /// Gradient clipping threshold in standard deviations per leaf.
    ///
    /// When enabled, each leaf tracks an EWMA of gradient mean and variance.
    /// Incoming gradients that exceed `mean ± sigma * gradient_clip_sigma` are
    /// clamped to the boundary. This prevents outlier samples from corrupting
    /// leaf statistics, which is critical in streaming settings where sudden
    /// label floods can destabilize the model.
    ///
    /// Typical value: 3.0 (3-sigma clipping).
    /// `None` (default) disables gradient clipping.
    #[cfg_attr(feature = "serde", serde(default))]
    pub gradient_clip_sigma: Option<f64>,

    /// Per-feature monotonic constraints.
    ///
    /// Each element specifies the monotonic relationship between a feature and
    /// the prediction:
    /// - `+1`: prediction must be non-decreasing as feature value increases.
    /// - `-1`: prediction must be non-increasing as feature value increases.
    /// - `0`: no constraint (unconstrained).
    ///
    /// During split evaluation, candidate splits that would violate monotonicity
    /// (left child value > right child value for +1 constraints, or vice versa)
    /// are rejected.
    ///
    /// Length must match the number of features in training data.
    /// `None` (default) means no monotonic constraints.
    #[cfg_attr(feature = "serde", serde(default))]
    pub monotone_constraints: Option<Vec<i8>>,

    /// EWMA smoothing factor for quality-based tree pruning.
    ///
    /// When `Some(alpha)`, each boosting step tracks an exponentially weighted
    /// moving average of its marginal contribution to the ensemble. Trees whose
    /// contribution drops below [`quality_prune_threshold`](Self::quality_prune_threshold)
    /// for [`quality_prune_patience`](Self::quality_prune_patience) consecutive
    /// samples are replaced with a fresh tree that can learn the current regime.
    ///
    /// This prevents "dead wood" -- trees from a past regime that no longer
    /// contribute meaningfully to ensemble accuracy.
    ///
    /// `None` (default) disables quality-based pruning.
    /// Suggested value: 0.01.
    #[cfg_attr(feature = "serde", serde(default))]
    pub quality_prune_alpha: Option<f64>,

    /// Minimum contribution threshold for quality-based pruning.
    ///
    /// A tree's EWMA contribution must stay above this value to avoid being
    /// flagged as dead wood. Only used when `quality_prune_alpha` is `Some`.
    ///
    /// Default: 1e-6. (When deserializing with serde, a missing value also
    /// falls back to 1e-6 via a default helper function.)
    #[cfg_attr(feature = "serde", serde(default = "default_quality_prune_threshold"))]
    pub quality_prune_threshold: f64,

    /// Consecutive low-contribution samples before a tree is replaced.
    ///
    /// After this many consecutive samples where a tree's EWMA contribution
    /// is below `quality_prune_threshold`, the tree is reset. Only used when
    /// `quality_prune_alpha` is `Some`.
    ///
    /// Default: 500. (Serde deserialization uses the same fallback for a
    /// missing value.)
    #[cfg_attr(feature = "serde", serde(default = "default_quality_prune_patience"))]
    pub quality_prune_patience: u64,

    /// EWMA smoothing factor for error-weighted sample importance.
    ///
    /// When `Some(alpha)`, samples the model predicted poorly get higher
    /// effective weight during histogram accumulation. The weight is:
    /// `1.0 + |error| / (rolling_mean_error + epsilon)`, capped at 10x.
    ///
    /// This is a streaming version of AdaBoost's reweighting applied at the
    /// gradient level -- learning capacity focuses on hard/novel patterns,
    /// enabling faster adaptation to regime changes.
    ///
    /// `None` (default) disables error weighting.
    /// Suggested value: 0.01.
    #[cfg_attr(feature = "serde", serde(default))]
    pub error_weight_alpha: Option<f64>,

    /// Enable σ-modulated learning rate for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
    ///
    /// When `true`, the **location** (μ) ensemble's learning rate is scaled by
    /// `sigma_ratio = current_sigma / rolling_sigma_mean`, where `rolling_sigma_mean`
    /// is an EWMA of the model's predicted σ (alpha = 0.001).
    ///
    /// This means the model learns μ **faster** when σ is elevated (high uncertainty)
    /// and **slower** when σ is low (confident regime). The scale (σ) ensemble always
    /// trains at the unmodulated base rate to prevent positive feedback loops.
    ///
    /// Default: `false`.
    #[cfg_attr(feature = "serde", serde(default))]
    pub uncertainty_modulated_lr: bool,

    /// How the scale (σ) is estimated in [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
    ///
    /// - [`Empirical`](ScaleMode::Empirical) (default): EWMA of squared prediction
    ///   errors. `σ = sqrt(ewma_sq_err)`. Always calibrated, zero tuning, O(1).
    /// - [`TreeChain`](ScaleMode::TreeChain): full dual-chain NGBoost with a
    ///   separate tree ensemble predicting log(σ) from features.
    ///
    /// For σ-modulated learning (`uncertainty_modulated_lr = true`), `Empirical`
    /// is strongly recommended — scale tree gradients are inherently weak and
    /// the trees often fail to split.
    #[cfg_attr(feature = "serde", serde(default))]
    pub scale_mode: ScaleMode,

    /// EWMA smoothing factor for empirical σ estimation.
    ///
    /// Controls the adaptation speed of `σ = sqrt(ewma_sq_err)` when
    /// [`scale_mode`](Self::scale_mode) is [`Empirical`](ScaleMode::Empirical).
    /// Higher values react faster to regime changes but are noisier.
    ///
    /// Default: `0.01` (~100-sample effective window).
    #[cfg_attr(feature = "serde", serde(default = "default_empirical_sigma_alpha"))]
    pub empirical_sigma_alpha: f64,

    /// Maximum absolute leaf output value.
    ///
    /// When `Some(max)`, leaf predictions are clamped to `[-max, max]`.
    /// Prevents runaway leaf weights from causing prediction explosions
    /// in feedback loops. `None` (default) means no clamping.
    #[cfg_attr(feature = "serde", serde(default))]
    pub max_leaf_output: Option<f64>,

    /// Per-leaf adaptive output bound (sigma multiplier).
    ///
    /// When `Some(k)`, each leaf tracks an EWMA of its own output weight and
    /// clamps predictions to `|output_mean| + k * output_std`. The EWMA uses
    /// `leaf_decay_alpha` when `leaf_half_life` is set, otherwise Welford online.
    ///
    /// This is strictly superior to `max_leaf_output` for streaming — the bound
    /// is per-leaf, self-calibrating, and regime-synchronized. A leaf that usually
    /// outputs 0.3 can't suddenly output 2.9 just because it fits in the global clamp.
    ///
    /// Typical value: 3.0 (3-sigma bound).
    /// `None` (default) disables adaptive bounds (falls back to `max_leaf_output`).
    #[cfg_attr(feature = "serde", serde(default))]
    pub adaptive_leaf_bound: Option<f64>,

    /// Per-split information criterion (Lunde-Kleppe-Skaug 2020).
    ///
    /// When `Some(cir_factor)`, replaces `max_depth` with a per-split
    /// generalization test. Each candidate split must have
    /// `gain > cir_factor * sigma^2_g / n * n_features`.
    /// `max_depth * 2` becomes a hard safety ceiling.
    ///
    /// Typical: 7.5 (<=10 features), 9.0 (<=50), 11.0 (<=200).
    /// `None` (default) uses static `max_depth` only.
    #[cfg_attr(feature = "serde", serde(default))]
    pub adaptive_depth: Option<f64>,

    /// Minimum hessian sum before a leaf produces non-zero output.
    ///
    /// When `Some(min_h)`, leaves with `hess_sum < min_h` return 0.0.
    /// Prevents post-replacement spikes from fresh leaves with insufficient
    /// samples. `None` (default) means all leaves contribute immediately.
    #[cfg_attr(feature = "serde", serde(default))]
    pub min_hessian_sum: Option<f64>,

    /// Huber loss delta multiplier for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
    ///
    /// When `Some(k)`, the distributional location gradient uses Huber loss
    /// with adaptive `delta = k * empirical_sigma`. This bounds gradients by
    /// construction. Standard value: `1.345` (95% efficiency at Gaussian).
    /// `None` (default) uses squared loss.
    #[cfg_attr(feature = "serde", serde(default))]
    pub huber_k: Option<f64>,

    /// Shadow warmup for graduated tree handoff.
    ///
    /// When `Some(n)`, an always-on shadow (alternate) tree is spawned immediately
    /// alongside every active tree. The shadow trains on the same gradient stream
    /// but does not contribute to predictions until it has seen `n` samples.
    ///
    /// As the active tree ages past 80% of `max_tree_samples`, its prediction
    /// weight linearly decays to 0 at 120%. The shadow's weight ramps from 0 to 1
    /// over `n` samples after warmup. When the active weight reaches 0, the shadow
    /// is promoted and a new shadow is spawned — no cold-start prediction dip.
    ///
    /// Requires `max_tree_samples` to be set for time-based graduated handoff.
    /// Drift-based replacement still uses hard swap (shadow is already warm).
    ///
    /// `None` (default) disables graduated handoff — uses traditional hard swap.
    #[cfg_attr(feature = "serde", serde(default))]
    pub shadow_warmup: Option<usize>,

    /// Leaf prediction model type.
    ///
    /// Controls how each leaf computes its prediction:
    /// - [`ClosedForm`](LeafModelType::ClosedForm) (default): constant leaf weight.
    /// - [`Linear`](LeafModelType::Linear): per-leaf online ridge regression with
    ///   AdaGrad optimization. Optional `decay` for concept drift. Recommended for
    ///   low-depth trees (depth 2--4).
    /// - [`MLP`](LeafModelType::MLP): per-leaf single-hidden-layer neural network.
    ///   Optional `decay` for concept drift.
    /// - [`Adaptive`](LeafModelType::Adaptive): starts as closed-form, auto-promotes
    ///   when the Hoeffding bound confirms a more complex model is better.
    ///
    /// Default: [`ClosedForm`](LeafModelType::ClosedForm).
    #[cfg_attr(feature = "serde", serde(default))]
    pub leaf_model_type: LeafModelType,

    /// Packed cache refresh interval for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
    ///
    /// When non-zero, the distributional model maintains a packed f32 cache of
    /// its location ensemble that is re-exported every `packed_refresh_interval`
    /// training samples. Predictions use the cache for O(1)-per-tree inference
    /// via contiguous memory traversal, falling back to full tree traversal when
    /// the cache is absent or produces non-finite results.
    ///
    /// `0` (default) disables the packed cache.
    #[cfg_attr(feature = "serde", serde(default))]
    pub packed_refresh_interval: u64,

    /// Enable per-node auto-bandwidth soft routing at prediction time.
    ///
    /// When `true`, predictions are continuous weighted blends instead of
    /// piecewise-constant step functions. No training changes.
    ///
    /// Default: `false`.
    #[cfg_attr(feature = "serde", serde(default))]
    pub soft_routing: bool,
}
492
/// Serde fallback for [`SGBTConfig::empirical_sigma_alpha`]; must match the
/// literal in `SGBTConfig::default()`.
#[cfg(feature = "serde")]
fn default_empirical_sigma_alpha() -> f64 {
    0.01
}

/// Serde fallback for [`SGBTConfig::quality_prune_threshold`]; must match the
/// literal in `SGBTConfig::default()`.
#[cfg(feature = "serde")]
fn default_quality_prune_threshold() -> f64 {
    1e-6
}

/// Serde fallback for [`SGBTConfig::quality_prune_patience`]; must match the
/// literal in `SGBTConfig::default()`.
#[cfg(feature = "serde")]
fn default_quality_prune_patience() -> u64 {
    500
}
507
impl Default for SGBTConfig {
    // NOTE: these literals must stay in sync with the defaults table in the
    // `SGBTConfig` docs and with the serde default helper functions above
    // (`empirical_sigma_alpha`, `quality_prune_threshold`,
    // `quality_prune_patience`).
    fn default() -> Self {
        Self {
            n_steps: 100,
            learning_rate: 0.0125,
            feature_subsample_rate: 0.75,
            max_depth: 6,
            n_bins: 64,
            lambda: 1.0,
            gamma: 0.0,
            grace_period: 200,
            delta: 1e-7,
            drift_detector: DriftDetectorType::default(),
            variant: SGBTVariant::default(),
            seed: 0xDEAD_BEEF_CAFE_4242,
            initial_target_count: 50,
            leaf_half_life: None,
            max_tree_samples: None,
            adaptive_mts: None,
            adaptive_mts_floor: 0.0,
            proactive_prune_interval: None,
            split_reeval_interval: None,
            feature_names: None,
            feature_types: None,
            gradient_clip_sigma: None,
            monotone_constraints: None,
            quality_prune_alpha: None,
            quality_prune_threshold: 1e-6,
            quality_prune_patience: 500,
            error_weight_alpha: None,
            uncertainty_modulated_lr: false,
            scale_mode: ScaleMode::default(),
            empirical_sigma_alpha: 0.01,
            max_leaf_output: None,
            adaptive_leaf_bound: None,
            adaptive_depth: None,
            min_hessian_sum: None,
            huber_k: None,
            shadow_warmup: None,
            leaf_model_type: LeafModelType::default(),
            packed_refresh_interval: 0,
            soft_routing: false,
        }
    }
}
553
554impl SGBTConfig {
555 /// Start building a configuration via the builder pattern.
556 pub fn builder() -> SGBTConfigBuilder {
557 SGBTConfigBuilder::default()
558 }
559}
560
561// ---------------------------------------------------------------------------
562// SGBTConfigBuilder
563// ---------------------------------------------------------------------------
564
565/// Builder for [`SGBTConfig`] with validation on [`build()`](Self::build).
566///
567/// # Example
568///
569/// ```ignore
570/// use irithyll::ensemble::config::{SGBTConfig, DriftDetectorType};
571/// use irithyll::ensemble::variants::SGBTVariant;
572///
573/// let config = SGBTConfig::builder()
574/// .n_steps(200)
575/// .learning_rate(0.05)
576/// .drift_detector(DriftDetectorType::Adwin { delta: 0.01 })
577/// .variant(SGBTVariant::Skip { k: 10 })
578/// .build()
579/// .expect("valid config");
580/// ```
#[derive(Debug, Clone, Default)]
pub struct SGBTConfigBuilder {
    /// Accumulated configuration; starts at `SGBTConfig::default()` and is
    /// mutated field-by-field by the setter methods.
    config: SGBTConfig,
}
585
586impl SGBTConfigBuilder {
587 /// Set the number of boosting steps (trees in the ensemble).
588 pub fn n_steps(mut self, n: usize) -> Self {
589 self.config.n_steps = n;
590 self
591 }
592
593 /// Set the learning rate (shrinkage factor).
594 pub fn learning_rate(mut self, lr: f64) -> Self {
595 self.config.learning_rate = lr;
596 self
597 }
598
599 /// Set the fraction of features to subsample per tree.
600 pub fn feature_subsample_rate(mut self, rate: f64) -> Self {
601 self.config.feature_subsample_rate = rate;
602 self
603 }
604
605 /// Set the maximum tree depth.
606 pub fn max_depth(mut self, depth: usize) -> Self {
607 self.config.max_depth = depth;
608 self
609 }
610
611 /// Set the number of histogram bins per feature.
612 pub fn n_bins(mut self, bins: usize) -> Self {
613 self.config.n_bins = bins;
614 self
615 }
616
617 /// Set the L2 regularization parameter (lambda).
618 pub fn lambda(mut self, l: f64) -> Self {
619 self.config.lambda = l;
620 self
621 }
622
623 /// Set the minimum split gain (gamma).
624 pub fn gamma(mut self, g: f64) -> Self {
625 self.config.gamma = g;
626 self
627 }
628
629 /// Set the grace period (minimum samples before evaluating splits).
630 pub fn grace_period(mut self, gp: usize) -> Self {
631 self.config.grace_period = gp;
632 self
633 }
634
635 /// Set the Hoeffding bound confidence parameter (delta).
636 pub fn delta(mut self, d: f64) -> Self {
637 self.config.delta = d;
638 self
639 }
640
641 /// Set the drift detector type for tree replacement.
642 pub fn drift_detector(mut self, dt: DriftDetectorType) -> Self {
643 self.config.drift_detector = dt;
644 self
645 }
646
647 /// Set the SGBT computational variant.
648 pub fn variant(mut self, v: SGBTVariant) -> Self {
649 self.config.variant = v;
650 self
651 }
652
653 /// Set the random seed for deterministic reproducibility.
654 ///
655 /// Controls feature subsampling and variant skip/MI stochastic decisions.
656 /// Two models with the same seed and data sequence will produce identical results.
657 pub fn seed(mut self, seed: u64) -> Self {
658 self.config.seed = seed;
659 self
660 }
661
662 /// Set the number of initial targets to collect before computing the base prediction.
663 ///
664 /// The model collects this many target values before initializing the base
665 /// prediction (via `loss.initial_prediction`). Default: 50.
666 pub fn initial_target_count(mut self, count: usize) -> Self {
667 self.config.initial_target_count = count;
668 self
669 }
670
671 /// Set the half-life for exponential leaf decay (in samples per leaf).
672 ///
673 /// After `n` samples, a leaf's accumulated statistics have half the weight
674 /// of the most recent sample. Enables continuous adaptation to concept drift.
675 pub fn leaf_half_life(mut self, n: usize) -> Self {
676 self.config.leaf_half_life = Some(n);
677 self
678 }
679
680 /// Set the maximum samples a single tree processes before proactive replacement.
681 ///
682 /// After `n` samples, the tree is replaced regardless of drift detector state.
683 pub fn max_tree_samples(mut self, n: u64) -> Self {
684 self.config.max_tree_samples = Some(n);
685 self
686 }
687
688 /// Enable sigma-modulated tree replacement speed.
689 ///
690 /// The effective max_tree_samples becomes `base_mts / (1.0 + k * normalized_sigma)`,
691 /// where `normalized_sigma` tracks contribution variance across ensemble trees.
692 /// Overrides `max_tree_samples` when set.
693 pub fn adaptive_mts(mut self, base_mts: u64, k: f64) -> Self {
694 self.config.adaptive_mts = Some((base_mts, k));
695 self
696 }
697
698 /// Set a minimum effective MTS as a fraction of `base_mts`.
699 ///
700 /// Prevents adaptive MTS from shrinking tree lifetime below
701 /// `base_mts * fraction`. For example, `0.25` ensures effective MTS
702 /// never drops below 25% of `base_mts`.
703 ///
704 /// Only meaningful when `adaptive_mts` is also set.
705 pub fn adaptive_mts_floor(mut self, fraction: f64) -> Self {
706 self.config.adaptive_mts_floor = fraction;
707 self
708 }
709
710 /// Set the proactive tree pruning interval.
711 ///
712 /// Every `interval` training samples, the tree with the lowest EWMA
713 /// marginal contribution is replaced. Automatically enables contribution
714 /// tracking with alpha=0.01 when `quality_prune_alpha` is not set.
715 pub fn proactive_prune_interval(mut self, interval: u64) -> Self {
716 self.config.proactive_prune_interval = Some(interval);
717 self
718 }
719
720 /// Set the split re-evaluation interval for max-depth leaves.
721 ///
722 /// Every `n` samples per leaf, max-depth leaves re-evaluate whether a split
723 /// would improve them. Inspired by EFDT (Manapragada et al. 2018).
724 pub fn split_reeval_interval(mut self, n: usize) -> Self {
725 self.config.split_reeval_interval = Some(n);
726 self
727 }
728
729 /// Set human-readable feature names.
730 ///
731 /// Enables named feature importances and named training input.
732 /// Names must be unique; validated at [`build()`](Self::build).
733 pub fn feature_names(mut self, names: Vec<String>) -> Self {
734 self.config.feature_names = Some(names);
735 self
736 }
737
738 /// Set per-feature type declarations.
739 ///
740 /// Declares which features are categorical vs continuous. Categorical features
741 /// use one-bin-per-category binning and Fisher optimal binary partitioning.
742 /// Supports up to 64 distinct category values per categorical feature.
743 pub fn feature_types(mut self, types: Vec<FeatureType>) -> Self {
744 self.config.feature_types = Some(types);
745 self
746 }
747
748 /// Set per-leaf gradient clipping threshold (in standard deviations).
749 ///
750 /// Each leaf tracks an EWMA of gradient mean and variance. Gradients
751 /// exceeding `mean ± sigma * n` are clamped. Prevents outlier labels
752 /// from corrupting streaming model stability.
753 ///
754 /// Typical value: 3.0 (3-sigma clipping).
755 pub fn gradient_clip_sigma(mut self, sigma: f64) -> Self {
756 self.config.gradient_clip_sigma = Some(sigma);
757 self
758 }
759
760 /// Set per-feature monotonic constraints.
761 ///
762 /// `+1` = non-decreasing, `-1` = non-increasing, `0` = unconstrained.
763 /// Candidate splits violating monotonicity are rejected during tree growth.
764 pub fn monotone_constraints(mut self, constraints: Vec<i8>) -> Self {
765 self.config.monotone_constraints = Some(constraints);
766 self
767 }
768
769 /// Enable quality-based tree pruning with the given EWMA smoothing factor.
770 ///
771 /// Trees whose marginal contribution drops below the threshold for
772 /// `patience` consecutive samples are replaced with fresh trees.
773 /// Suggested alpha: 0.01.
774 pub fn quality_prune_alpha(mut self, alpha: f64) -> Self {
775 self.config.quality_prune_alpha = Some(alpha);
776 self
777 }
778
779 /// Set the minimum contribution threshold for quality-based pruning.
780 ///
781 /// Default: 1e-6. Only relevant when `quality_prune_alpha` is set.
782 pub fn quality_prune_threshold(mut self, threshold: f64) -> Self {
783 self.config.quality_prune_threshold = threshold;
784 self
785 }
786
787 /// Set the patience (consecutive low-contribution samples) before pruning.
788 ///
789 /// Default: 500. Only relevant when `quality_prune_alpha` is set.
790 pub fn quality_prune_patience(mut self, patience: u64) -> Self {
791 self.config.quality_prune_patience = patience;
792 self
793 }
794
795 /// Enable error-weighted sample importance with the given EWMA smoothing factor.
796 ///
797 /// Samples the model predicted poorly get higher effective weight.
798 /// Suggested alpha: 0.01.
799 pub fn error_weight_alpha(mut self, alpha: f64) -> Self {
800 self.config.error_weight_alpha = Some(alpha);
801 self
802 }
803
804 /// Enable σ-modulated learning rate for distributional models.
805 ///
806 /// Scales the location (μ) learning rate by `current_sigma / rolling_sigma_mean`,
807 /// so the model adapts faster during high-uncertainty regimes and conserves
808 /// during stable periods. Only affects [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
809 ///
810 /// By default uses empirical σ (EWMA of squared errors). Set
811 /// [`scale_mode(ScaleMode::TreeChain)`](Self::scale_mode) for feature-conditional σ.
812 pub fn uncertainty_modulated_lr(mut self, enabled: bool) -> Self {
813 self.config.uncertainty_modulated_lr = enabled;
814 self
815 }
816
817 /// Set the scale estimation mode for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
818 ///
819 /// - [`Empirical`](ScaleMode::Empirical): EWMA of squared prediction errors (default, recommended).
820 /// - [`TreeChain`](ScaleMode::TreeChain): dual-chain NGBoost with scale tree ensemble.
821 pub fn scale_mode(mut self, mode: ScaleMode) -> Self {
822 self.config.scale_mode = mode;
823 self
824 }
825
826 /// Enable per-node auto-bandwidth soft routing.
827 pub fn soft_routing(mut self, enabled: bool) -> Self {
828 self.config.soft_routing = enabled;
829 self
830 }
831
832 /// EWMA alpha for empirical σ. Controls adaptation speed. Default `0.01`.
833 ///
834 /// Only used when `scale_mode` is [`Empirical`](ScaleMode::Empirical).
835 pub fn empirical_sigma_alpha(mut self, alpha: f64) -> Self {
836 self.config.empirical_sigma_alpha = alpha;
837 self
838 }
839
840 /// Set the maximum absolute leaf output value.
841 ///
842 /// Clamps leaf predictions to `[-max, max]`, breaking feedback loops
843 /// that cause prediction explosions.
844 pub fn max_leaf_output(mut self, max: f64) -> Self {
845 self.config.max_leaf_output = Some(max);
846 self
847 }
848
849 /// Set per-leaf adaptive output bound (sigma multiplier).
850 ///
851 /// Each leaf tracks EWMA of its own output weight and clamps to
852 /// `|output_mean| + k * output_std`. Self-calibrating per-leaf.
853 /// Recommended: use with `leaf_half_life` for streaming scenarios.
854 pub fn adaptive_leaf_bound(mut self, k: f64) -> Self {
855 self.config.adaptive_leaf_bound = Some(k);
856 self
857 }
858
859 /// Set the per-split information criterion factor (Lunde-Kleppe-Skaug 2020).
860 ///
861 /// Replaces static `max_depth` with a per-split generalization test.
862 /// Typical: 7.5 (<=10 features), 9.0 (<=50), 11.0 (<=200).
863 pub fn adaptive_depth(mut self, factor: f64) -> Self {
864 self.config.adaptive_depth = Some(factor);
865 self
866 }
867
868 /// Set the minimum hessian sum for leaf output.
869 ///
870 /// Fresh leaves with `hess_sum < min_h` return 0.0, preventing
871 /// post-replacement spikes.
872 pub fn min_hessian_sum(mut self, min_h: f64) -> Self {
873 self.config.min_hessian_sum = Some(min_h);
874 self
875 }
876
877 /// Set the Huber loss delta multiplier for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
878 ///
879 /// When set, location gradients use Huber loss with adaptive
880 /// `delta = k * empirical_sigma`. Standard value: `1.345` (95% Gaussian efficiency).
881 pub fn huber_k(mut self, k: f64) -> Self {
882 self.config.huber_k = Some(k);
883 self
884 }
885
886 /// Enable graduated tree handoff with the given shadow warmup samples.
887 ///
888 /// Spawns an always-on shadow tree that trains alongside the active tree.
889 /// After `warmup` samples, the shadow begins contributing to predictions
890 /// via graduated blending. Eliminates prediction dips during tree replacement.
891 pub fn shadow_warmup(mut self, warmup: usize) -> Self {
892 self.config.shadow_warmup = Some(warmup);
893 self
894 }
895
896 /// Set the leaf prediction model type.
897 ///
898 /// [`LeafModelType::Linear`] is recommended for low-depth configurations
899 /// (depth 2--4) where per-leaf linear models reduce approximation error.
900 ///
901 /// [`LeafModelType::Adaptive`] automatically selects between closed-form and
902 /// a trainable model per leaf, using the Hoeffding bound for promotion.
903 pub fn leaf_model_type(mut self, lmt: LeafModelType) -> Self {
904 self.config.leaf_model_type = lmt;
905 self
906 }
907
908 /// Set the packed cache refresh interval for distributional models.
909 ///
910 /// When non-zero, [`DistributionalSGBT`](super::distributional::DistributionalSGBT)
911 /// maintains a packed f32 cache refreshed every `interval` training samples.
912 /// `0` (default) disables the cache.
913 pub fn packed_refresh_interval(mut self, interval: u64) -> Self {
914 self.config.packed_refresh_interval = interval;
915 self
916 }
917
918 /// Validate and build the configuration.
919 ///
920 /// # Errors
921 ///
922 /// Returns [`InvalidConfig`](crate::IrithyllError::InvalidConfig) with a structured
923 /// [`ConfigError`] if any parameter is out of its valid range.
924 pub fn build(self) -> Result<SGBTConfig> {
925 let c = &self.config;
926
927 // -- Ensemble-level parameters --
928 if c.n_steps == 0 {
929 return Err(ConfigError::out_of_range("n_steps", "must be > 0", c.n_steps).into());
930 }
931 if c.learning_rate <= 0.0 || c.learning_rate > 1.0 {
932 return Err(ConfigError::out_of_range(
933 "learning_rate",
934 "must be in (0, 1]",
935 c.learning_rate,
936 )
937 .into());
938 }
939 if c.feature_subsample_rate <= 0.0 || c.feature_subsample_rate > 1.0 {
940 return Err(ConfigError::out_of_range(
941 "feature_subsample_rate",
942 "must be in (0, 1]",
943 c.feature_subsample_rate,
944 )
945 .into());
946 }
947
948 // -- Tree-level parameters --
949 if c.max_depth == 0 {
950 return Err(ConfigError::out_of_range("max_depth", "must be > 0", c.max_depth).into());
951 }
952 if c.n_bins < 2 {
953 return Err(ConfigError::out_of_range("n_bins", "must be >= 2", c.n_bins).into());
954 }
955 if c.lambda < 0.0 {
956 return Err(ConfigError::out_of_range("lambda", "must be >= 0", c.lambda).into());
957 }
958 if c.gamma < 0.0 {
959 return Err(ConfigError::out_of_range("gamma", "must be >= 0", c.gamma).into());
960 }
961 if c.grace_period == 0 {
962 return Err(
963 ConfigError::out_of_range("grace_period", "must be > 0", c.grace_period).into(),
964 );
965 }
966 if c.delta <= 0.0 || c.delta >= 1.0 {
967 return Err(ConfigError::out_of_range("delta", "must be in (0, 1)", c.delta).into());
968 }
969
970 if c.initial_target_count == 0 {
971 return Err(ConfigError::out_of_range(
972 "initial_target_count",
973 "must be > 0",
974 c.initial_target_count,
975 )
976 .into());
977 }
978
979 // -- Streaming adaptation parameters --
980 if let Some(hl) = c.leaf_half_life {
981 if hl == 0 {
982 return Err(ConfigError::out_of_range("leaf_half_life", "must be >= 1", hl).into());
983 }
984 }
985 if let Some(max) = c.max_tree_samples {
986 if max < 100 {
987 return Err(
988 ConfigError::out_of_range("max_tree_samples", "must be >= 100", max).into(),
989 );
990 }
991 }
992 if let Some((base_mts, k)) = c.adaptive_mts {
993 if base_mts < 100 {
994 return Err(ConfigError::out_of_range(
995 "adaptive_mts.base_mts",
996 "must be >= 100",
997 base_mts,
998 )
999 .into());
1000 }
1001 if k <= 0.0 {
1002 return Err(ConfigError::out_of_range("adaptive_mts.k", "must be > 0", k).into());
1003 }
1004 }
1005 if !(0.0..=1.0).contains(&c.adaptive_mts_floor) {
1006 return Err(ConfigError::out_of_range(
1007 "adaptive_mts_floor",
1008 "must be in [0.0, 1.0]",
1009 c.adaptive_mts_floor,
1010 )
1011 .into());
1012 }
1013 if let Some(interval) = c.proactive_prune_interval {
1014 if interval < 100 {
1015 return Err(ConfigError::out_of_range(
1016 "proactive_prune_interval",
1017 "must be >= 100",
1018 interval,
1019 )
1020 .into());
1021 }
1022 }
1023 if let Some(interval) = c.split_reeval_interval {
1024 if interval < c.grace_period {
1025 return Err(ConfigError::invalid(
1026 "split_reeval_interval",
1027 format!(
1028 "must be >= grace_period ({}), got {}",
1029 c.grace_period, interval
1030 ),
1031 )
1032 .into());
1033 }
1034 }
1035
1036 // -- Feature names --
1037 if let Some(ref names) = c.feature_names {
1038 // O(n^2) duplicate check — names list is small so no HashSet needed
1039 for (i, name) in names.iter().enumerate() {
1040 for prev in &names[..i] {
1041 if name == prev {
1042 return Err(ConfigError::invalid(
1043 "feature_names",
1044 format!("duplicate feature name: '{}'", name),
1045 )
1046 .into());
1047 }
1048 }
1049 }
1050 }
1051
1052 // -- Feature types --
1053 if let Some(ref types) = c.feature_types {
1054 if let Some(ref names) = c.feature_names {
1055 if !names.is_empty() && !types.is_empty() && names.len() != types.len() {
1056 return Err(ConfigError::invalid(
1057 "feature_types",
1058 format!(
1059 "length ({}) must match feature_names length ({})",
1060 types.len(),
1061 names.len()
1062 ),
1063 )
1064 .into());
1065 }
1066 }
1067 }
1068
1069 // -- Gradient clipping --
1070 if let Some(sigma) = c.gradient_clip_sigma {
1071 if sigma <= 0.0 {
1072 return Err(
1073 ConfigError::out_of_range("gradient_clip_sigma", "must be > 0", sigma).into(),
1074 );
1075 }
1076 }
1077
1078 // -- Monotonic constraints --
1079 if let Some(ref mc) = c.monotone_constraints {
1080 for (i, &v) in mc.iter().enumerate() {
1081 if v != -1 && v != 0 && v != 1 {
1082 return Err(ConfigError::invalid(
1083 "monotone_constraints",
1084 format!("feature {}: must be -1, 0, or +1, got {}", i, v),
1085 )
1086 .into());
1087 }
1088 }
1089 }
1090
1091 // -- Leaf output clamping --
1092 if let Some(max) = c.max_leaf_output {
1093 if max <= 0.0 {
1094 return Err(
1095 ConfigError::out_of_range("max_leaf_output", "must be > 0", max).into(),
1096 );
1097 }
1098 }
1099
1100 // -- Per-leaf adaptive output bound --
1101 if let Some(k) = c.adaptive_leaf_bound {
1102 if k <= 0.0 {
1103 return Err(
1104 ConfigError::out_of_range("adaptive_leaf_bound", "must be > 0", k).into(),
1105 );
1106 }
1107 }
1108
1109 // -- Adaptive depth (per-split information criterion) --
1110 if let Some(factor) = c.adaptive_depth {
1111 if factor <= 0.0 {
1112 return Err(
1113 ConfigError::out_of_range("adaptive_depth", "must be > 0", factor).into(),
1114 );
1115 }
1116 }
1117
1118 // -- Minimum hessian sum --
1119 if let Some(min_h) = c.min_hessian_sum {
1120 if min_h <= 0.0 {
1121 return Err(
1122 ConfigError::out_of_range("min_hessian_sum", "must be > 0", min_h).into(),
1123 );
1124 }
1125 }
1126
1127 // -- Huber loss multiplier --
1128 if let Some(k) = c.huber_k {
1129 if k <= 0.0 {
1130 return Err(ConfigError::out_of_range("huber_k", "must be > 0", k).into());
1131 }
1132 }
1133
1134 // -- Shadow warmup --
1135 if let Some(warmup) = c.shadow_warmup {
1136 if warmup == 0 {
1137 return Err(ConfigError::out_of_range(
1138 "shadow_warmup",
1139 "must be > 0",
1140 warmup as f64,
1141 )
1142 .into());
1143 }
1144 }
1145
1146 // -- Quality-based pruning parameters --
1147 if let Some(alpha) = c.quality_prune_alpha {
1148 if alpha <= 0.0 || alpha >= 1.0 {
1149 return Err(ConfigError::out_of_range(
1150 "quality_prune_alpha",
1151 "must be in (0, 1)",
1152 alpha,
1153 )
1154 .into());
1155 }
1156 }
1157 if c.quality_prune_threshold <= 0.0 {
1158 return Err(ConfigError::out_of_range(
1159 "quality_prune_threshold",
1160 "must be > 0",
1161 c.quality_prune_threshold,
1162 )
1163 .into());
1164 }
1165 if c.quality_prune_patience == 0 {
1166 return Err(ConfigError::out_of_range(
1167 "quality_prune_patience",
1168 "must be > 0",
1169 c.quality_prune_patience,
1170 )
1171 .into());
1172 }
1173
1174 // -- Error-weighted sample importance --
1175 if let Some(alpha) = c.error_weight_alpha {
1176 if alpha <= 0.0 || alpha >= 1.0 {
1177 return Err(ConfigError::out_of_range(
1178 "error_weight_alpha",
1179 "must be in (0, 1)",
1180 alpha,
1181 )
1182 .into());
1183 }
1184 }
1185
1186 // -- Drift detector parameters --
1187 match &c.drift_detector {
1188 DriftDetectorType::PageHinkley { delta, lambda } => {
1189 if *delta <= 0.0 {
1190 return Err(ConfigError::out_of_range(
1191 "drift_detector.PageHinkley.delta",
1192 "must be > 0",
1193 delta,
1194 )
1195 .into());
1196 }
1197 if *lambda <= 0.0 {
1198 return Err(ConfigError::out_of_range(
1199 "drift_detector.PageHinkley.lambda",
1200 "must be > 0",
1201 lambda,
1202 )
1203 .into());
1204 }
1205 }
1206 DriftDetectorType::Adwin { delta } => {
1207 if *delta <= 0.0 || *delta >= 1.0 {
1208 return Err(ConfigError::out_of_range(
1209 "drift_detector.Adwin.delta",
1210 "must be in (0, 1)",
1211 delta,
1212 )
1213 .into());
1214 }
1215 }
1216 DriftDetectorType::Ddm {
1217 warning_level,
1218 drift_level,
1219 min_instances,
1220 } => {
1221 if *warning_level <= 0.0 {
1222 return Err(ConfigError::out_of_range(
1223 "drift_detector.Ddm.warning_level",
1224 "must be > 0",
1225 warning_level,
1226 )
1227 .into());
1228 }
1229 if *drift_level <= 0.0 {
1230 return Err(ConfigError::out_of_range(
1231 "drift_detector.Ddm.drift_level",
1232 "must be > 0",
1233 drift_level,
1234 )
1235 .into());
1236 }
1237 if *drift_level <= *warning_level {
1238 return Err(ConfigError::invalid(
1239 "drift_detector.Ddm.drift_level",
1240 format!(
1241 "must be > warning_level ({}), got {}",
1242 warning_level, drift_level
1243 ),
1244 )
1245 .into());
1246 }
1247 if *min_instances == 0 {
1248 return Err(ConfigError::out_of_range(
1249 "drift_detector.Ddm.min_instances",
1250 "must be > 0",
1251 min_instances,
1252 )
1253 .into());
1254 }
1255 }
1256 }
1257
1258 // -- Variant parameters --
1259 match &c.variant {
1260 SGBTVariant::Standard => {} // no extra validation
1261 SGBTVariant::Skip { k } => {
1262 if *k == 0 {
1263 return Err(
1264 ConfigError::out_of_range("variant.Skip.k", "must be > 0", k).into(),
1265 );
1266 }
1267 }
1268 SGBTVariant::MultipleIterations { multiplier } => {
1269 if *multiplier <= 0.0 {
1270 return Err(ConfigError::out_of_range(
1271 "variant.MultipleIterations.multiplier",
1272 "must be > 0",
1273 multiplier,
1274 )
1275 .into());
1276 }
1277 }
1278 }
1279
1280 Ok(self.config)
1281 }
1282}
1283
1284// ---------------------------------------------------------------------------
1285// Display impls
1286// ---------------------------------------------------------------------------
1287
1288impl core::fmt::Display for DriftDetectorType {
1289 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1290 match self {
1291 Self::PageHinkley { delta, lambda } => {
1292 write!(f, "PageHinkley(delta={}, lambda={})", delta, lambda)
1293 }
1294 Self::Adwin { delta } => write!(f, "Adwin(delta={})", delta),
1295 Self::Ddm {
1296 warning_level,
1297 drift_level,
1298 min_instances,
1299 } => write!(
1300 f,
1301 "Ddm(warning={}, drift={}, min_instances={})",
1302 warning_level, drift_level, min_instances
1303 ),
1304 }
1305 }
1306}
1307
1308impl core::fmt::Display for SGBTConfig {
1309 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1310 write!(
1311 f,
1312 "SGBTConfig {{ steps={}, lr={}, depth={}, bins={}, variant={}, drift={} }}",
1313 self.n_steps,
1314 self.learning_rate,
1315 self.max_depth,
1316 self.n_bins,
1317 self.variant,
1318 self.drift_detector,
1319 )
1320 }
1321}
1322
1323// ---------------------------------------------------------------------------
1324// Tests
1325// ---------------------------------------------------------------------------
1326
#[cfg(test)]
mod tests {
    //! Unit tests for `SGBTConfig`, its builder, and `DriftDetectorType`.
    //! Organized by the numbered banner comments below; each test exercises
    //! a happy path, a boundary value, or a rejected out-of-range parameter.
    use super::*;
    use alloc::format;
    use alloc::vec;

    // ------------------------------------------------------------------
    // 1. Default config values are correct
    // ------------------------------------------------------------------
    #[test]
    fn default_config_values() {
        let cfg = SGBTConfig::default();
        assert_eq!(cfg.n_steps, 100);
        // Float fields are compared with an EPSILON tolerance, not `==`.
        assert!((cfg.learning_rate - 0.0125).abs() < f64::EPSILON);
        assert!((cfg.feature_subsample_rate - 0.75).abs() < f64::EPSILON);
        assert_eq!(cfg.max_depth, 6);
        assert_eq!(cfg.n_bins, 64);
        assert!((cfg.lambda - 1.0).abs() < f64::EPSILON);
        assert!((cfg.gamma - 0.0).abs() < f64::EPSILON);
        assert_eq!(cfg.grace_period, 200);
        assert!((cfg.delta - 1e-7).abs() < f64::EPSILON);
        assert_eq!(cfg.variant, SGBTVariant::Standard);
    }

    // ------------------------------------------------------------------
    // 2. Builder chain works
    // ------------------------------------------------------------------
    #[test]
    fn builder_chain() {
        let cfg = SGBTConfig::builder()
            .n_steps(50)
            .learning_rate(0.1)
            .feature_subsample_rate(0.5)
            .max_depth(10)
            .n_bins(128)
            .lambda(0.5)
            .gamma(0.1)
            .grace_period(500)
            .delta(1e-3)
            .drift_detector(DriftDetectorType::Adwin { delta: 0.01 })
            .variant(SGBTVariant::Skip { k: 5 })
            .build()
            .expect("valid config");

        assert_eq!(cfg.n_steps, 50);
        assert!((cfg.learning_rate - 0.1).abs() < f64::EPSILON);
        assert!((cfg.feature_subsample_rate - 0.5).abs() < f64::EPSILON);
        assert_eq!(cfg.max_depth, 10);
        assert_eq!(cfg.n_bins, 128);
        assert!((cfg.lambda - 0.5).abs() < f64::EPSILON);
        assert!((cfg.gamma - 0.1).abs() < f64::EPSILON);
        assert_eq!(cfg.grace_period, 500);
        assert!((cfg.delta - 1e-3).abs() < f64::EPSILON);
        assert_eq!(cfg.variant, SGBTVariant::Skip { k: 5 });

        // Verify drift detector type is Adwin.
        match &cfg.drift_detector {
            DriftDetectorType::Adwin { delta } => {
                assert!((delta - 0.01).abs() < f64::EPSILON);
            }
            _ => panic!("expected Adwin drift detector"),
        }
    }

    // ------------------------------------------------------------------
    // 3. Validation rejects invalid values
    // ------------------------------------------------------------------
    #[test]
    fn validation_rejects_n_steps_zero() {
        let result = SGBTConfig::builder().n_steps(0).build();
        assert!(result.is_err());
        // Error message should name the offending parameter.
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("n_steps"));
    }

    #[test]
    fn validation_rejects_learning_rate_zero() {
        let result = SGBTConfig::builder().learning_rate(0.0).build();
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("learning_rate"));
    }

    #[test]
    fn validation_rejects_learning_rate_above_one() {
        let result = SGBTConfig::builder().learning_rate(1.5).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_accepts_learning_rate_one() {
        // 1.0 is the inclusive upper bound of (0, 1].
        let result = SGBTConfig::builder().learning_rate(1.0).build();
        assert!(result.is_ok());
    }

    #[test]
    fn validation_rejects_negative_learning_rate() {
        let result = SGBTConfig::builder().learning_rate(-0.1).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_feature_subsample_zero() {
        let result = SGBTConfig::builder().feature_subsample_rate(0.0).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_feature_subsample_above_one() {
        let result = SGBTConfig::builder().feature_subsample_rate(1.01).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_max_depth_zero() {
        let result = SGBTConfig::builder().max_depth(0).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_n_bins_one() {
        let result = SGBTConfig::builder().n_bins(1).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_negative_lambda() {
        let result = SGBTConfig::builder().lambda(-0.1).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_accepts_zero_lambda() {
        // lambda is validated as >= 0, so exactly zero is legal.
        let result = SGBTConfig::builder().lambda(0.0).build();
        assert!(result.is_ok());
    }

    #[test]
    fn validation_rejects_negative_gamma() {
        let result = SGBTConfig::builder().gamma(-0.1).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_grace_period_zero() {
        let result = SGBTConfig::builder().grace_period(0).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_delta_zero() {
        let result = SGBTConfig::builder().delta(0.0).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_delta_one() {
        // delta's valid range (0, 1) is open at both ends.
        let result = SGBTConfig::builder().delta(1.0).build();
        assert!(result.is_err());
    }

    // ------------------------------------------------------------------
    // 3b. Drift detector parameter validation
    // ------------------------------------------------------------------
    #[test]
    fn validation_rejects_pht_negative_delta() {
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::PageHinkley {
                delta: -1.0,
                lambda: 50.0,
            })
            .build();
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("PageHinkley"));
    }

    #[test]
    fn validation_rejects_pht_zero_lambda() {
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::PageHinkley {
                delta: 0.005,
                lambda: 0.0,
            })
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_adwin_delta_out_of_range() {
        // Both open endpoints of (0, 1) must be rejected.
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::Adwin { delta: 0.0 })
            .build();
        assert!(result.is_err());

        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::Adwin { delta: 1.0 })
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_ddm_warning_above_drift() {
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::Ddm {
                warning_level: 3.0,
                drift_level: 2.0,
                min_instances: 30,
            })
            .build();
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("drift_level"));
        assert!(msg.contains("must be > warning_level"));
    }

    #[test]
    fn validation_rejects_ddm_equal_levels() {
        // drift_level must be strictly greater than warning_level.
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::Ddm {
                warning_level: 2.0,
                drift_level: 2.0,
                min_instances: 30,
            })
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_ddm_zero_min_instances() {
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::Ddm {
                warning_level: 2.0,
                drift_level: 3.0,
                min_instances: 0,
            })
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_ddm_zero_warning_level() {
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::Ddm {
                warning_level: 0.0,
                drift_level: 3.0,
                min_instances: 30,
            })
            .build();
        assert!(result.is_err());
    }

    // ------------------------------------------------------------------
    // 3c. Variant parameter validation
    // ------------------------------------------------------------------
    #[test]
    fn validation_rejects_skip_k_zero() {
        let result = SGBTConfig::builder()
            .variant(SGBTVariant::Skip { k: 0 })
            .build();
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("Skip"));
    }

    #[test]
    fn validation_rejects_mi_zero_multiplier() {
        let result = SGBTConfig::builder()
            .variant(SGBTVariant::MultipleIterations { multiplier: 0.0 })
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_mi_negative_multiplier() {
        let result = SGBTConfig::builder()
            .variant(SGBTVariant::MultipleIterations { multiplier: -1.0 })
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_accepts_standard_variant() {
        let result = SGBTConfig::builder().variant(SGBTVariant::Standard).build();
        assert!(result.is_ok());
    }

    // ------------------------------------------------------------------
    // 4. DriftDetectorType creates correct detector types
    // ------------------------------------------------------------------
    #[test]
    fn drift_detector_type_creates_page_hinkley() {
        let dt = DriftDetectorType::PageHinkley {
            delta: 0.01,
            lambda: 100.0,
        };
        let mut detector = dt.create();

        // Should start with zero mean.
        assert_eq!(detector.estimated_mean(), 0.0);

        // Feed a stable value -- should not drift.
        let signal = detector.update(1.0);
        assert_ne!(signal, crate::drift::DriftSignal::Drift);
    }

    #[test]
    fn drift_detector_type_creates_adwin() {
        let dt = DriftDetectorType::Adwin { delta: 0.05 };
        let mut detector = dt.create();

        assert_eq!(detector.estimated_mean(), 0.0);
        let signal = detector.update(1.0);
        assert_ne!(signal, crate::drift::DriftSignal::Drift);
    }

    #[test]
    fn drift_detector_type_creates_ddm() {
        let dt = DriftDetectorType::Ddm {
            warning_level: 2.0,
            drift_level: 3.0,
            min_instances: 30,
        };
        let mut detector = dt.create();

        assert_eq!(detector.estimated_mean(), 0.0);
        let signal = detector.update(0.1);
        assert_eq!(signal, crate::drift::DriftSignal::Stable);
    }

    // ------------------------------------------------------------------
    // 5. Default drift detector is PageHinkley
    // ------------------------------------------------------------------
    #[test]
    fn default_drift_detector_is_page_hinkley() {
        let dt = DriftDetectorType::default();
        match dt {
            DriftDetectorType::PageHinkley { delta, lambda } => {
                assert!((delta - 0.005).abs() < f64::EPSILON);
                assert!((lambda - 50.0).abs() < f64::EPSILON);
            }
            _ => panic!("expected default DriftDetectorType to be PageHinkley"),
        }
    }

    // ------------------------------------------------------------------
    // 6. Default config builds without error
    // ------------------------------------------------------------------
    #[test]
    fn default_config_builds() {
        let result = SGBTConfig::builder().build();
        assert!(result.is_ok());
    }

    // ------------------------------------------------------------------
    // 7. Config clone preserves all fields
    // ------------------------------------------------------------------
    #[test]
    fn config_clone_preserves_fields() {
        let original = SGBTConfig::builder()
            .n_steps(50)
            .learning_rate(0.05)
            .variant(SGBTVariant::MultipleIterations { multiplier: 5.0 })
            .build()
            .unwrap();

        let cloned = original.clone();

        assert_eq!(cloned.n_steps, 50);
        assert!((cloned.learning_rate - 0.05).abs() < f64::EPSILON);
        assert_eq!(
            cloned.variant,
            SGBTVariant::MultipleIterations { multiplier: 5.0 }
        );
    }

    // ------------------------------------------------------------------
    // 8. All three DDM valid configs accepted
    // ------------------------------------------------------------------
    #[test]
    fn valid_ddm_config_accepted() {
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::Ddm {
                warning_level: 1.5,
                drift_level: 2.5,
                min_instances: 10,
            })
            .build();
        assert!(result.is_ok());
    }

    // ------------------------------------------------------------------
    // 9. Created detectors are functional (round-trip test)
    // ------------------------------------------------------------------
    #[test]
    fn created_detectors_are_functional() {
        // Create a detector, feed it data, verify it can detect drift.
        let dt = DriftDetectorType::PageHinkley {
            delta: 0.005,
            lambda: 50.0,
        };
        let mut detector = dt.create();

        // Feed 500 stable values.
        for _ in 0..500 {
            detector.update(1.0);
        }

        // Feed shifted values -- should eventually drift.
        let mut drifted = false;
        for _ in 0..500 {
            if detector.update(10.0) == crate::drift::DriftSignal::Drift {
                drifted = true;
                break;
            }
        }
        assert!(
            drifted,
            "detector created from DriftDetectorType should be functional"
        );
    }

    // ------------------------------------------------------------------
    // 10. Boundary acceptance: n_bins=2 is the minimum valid
    // ------------------------------------------------------------------
    #[test]
    fn boundary_n_bins_two_accepted() {
        let result = SGBTConfig::builder().n_bins(2).build();
        assert!(result.is_ok());
    }

    // ------------------------------------------------------------------
    // 11. Boundary acceptance: grace_period=1 is valid
    // ------------------------------------------------------------------
    #[test]
    fn boundary_grace_period_one_accepted() {
        let result = SGBTConfig::builder().grace_period(1).build();
        assert!(result.is_ok());
    }

    // ------------------------------------------------------------------
    // 12. Feature names -- valid config with names
    // ------------------------------------------------------------------
    #[test]
    fn feature_names_accepted() {
        let cfg = SGBTConfig::builder()
            .feature_names(vec!["price".into(), "volume".into(), "spread".into()])
            .build()
            .unwrap();
        assert_eq!(
            cfg.feature_names.as_ref().unwrap(),
            &["price", "volume", "spread"]
        );
    }

    // ------------------------------------------------------------------
    // 13. Feature names -- duplicate names rejected
    // ------------------------------------------------------------------
    #[test]
    fn feature_names_rejects_duplicates() {
        let result = SGBTConfig::builder()
            .feature_names(vec!["price".into(), "volume".into(), "price".into()])
            .build();
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("duplicate"));
    }

    // ------------------------------------------------------------------
    // 14. Feature names -- serde backward compat (missing field)
    // Requires serde_json which is only in the full irithyll crate.
    // ------------------------------------------------------------------

    // ------------------------------------------------------------------
    // 15. Feature names -- empty vec is valid
    // ------------------------------------------------------------------
    #[test]
    fn feature_names_empty_vec_accepted() {
        let cfg = SGBTConfig::builder().feature_names(vec![]).build().unwrap();
        assert!(cfg.feature_names.unwrap().is_empty());
    }

    // ------------------------------------------------------------------
    // 16. Adaptive leaf bound -- builder
    // ------------------------------------------------------------------
    #[test]
    fn builder_adaptive_leaf_bound() {
        let cfg = SGBTConfig::builder()
            .adaptive_leaf_bound(3.0)
            .build()
            .unwrap();
        assert_eq!(cfg.adaptive_leaf_bound, Some(3.0));
    }

    // ------------------------------------------------------------------
    // 17. Adaptive leaf bound -- validation rejects zero
    // ------------------------------------------------------------------
    #[test]
    fn validation_rejects_zero_adaptive_leaf_bound() {
        let result = SGBTConfig::builder().adaptive_leaf_bound(0.0).build();
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(
            msg.contains("adaptive_leaf_bound"),
            "error should mention adaptive_leaf_bound: {}",
            msg,
        );
    }

    // ------------------------------------------------------------------
    // 18. Adaptive leaf bound -- serde backward compat
    // Requires serde_json which is only in the full irithyll crate.
    // ------------------------------------------------------------------
}