irithyll_core/ensemble/config.rs
1//! SGBT configuration with builder pattern and full validation.
2//!
3//! [`SGBTConfig`] holds all hyperparameters for the Streaming Gradient Boosted
4//! Trees ensemble. Use [`SGBTConfig::builder`] for ergonomic construction with
5//! validation on [`build()`](SGBTConfigBuilder::build).
6
7use alloc::boxed::Box;
8use alloc::format;
9use alloc::string::String;
10use alloc::vec::Vec;
11
12use crate::drift::adwin::Adwin;
13use crate::drift::ddm::Ddm;
14use crate::drift::pht::PageHinkleyTest;
15use crate::drift::DriftDetector;
16use crate::ensemble::variants::SGBTVariant;
17use crate::error::{ConfigError, Result};
18use crate::tree::leaf_model::LeafModelType;
19
20// ---------------------------------------------------------------------------
21// FeatureType -- re-exported from irithyll-core
22// ---------------------------------------------------------------------------
23
24pub use crate::feature::FeatureType;
25
26// ---------------------------------------------------------------------------
27// ScaleMode
28// ---------------------------------------------------------------------------
29
/// How [`DistributionalSGBT`](super::distributional::DistributionalSGBT)
/// estimates uncertainty (σ).
///
/// - **`Empirical`** (default): tracks an EWMA of squared prediction errors.
///   `σ = sqrt(ewma_sq_err)`. Always calibrated, zero tuning, O(1) compute.
///   Use this when σ drives learning-rate modulation (σ high → learn faster).
///
/// - **`TreeChain`**: trains a full second ensemble of Hoeffding trees to predict
///   log(σ) from features (NGBoost-style dual chain). Gives *feature-conditional*
///   uncertainty but requires strong signal in the scale gradients.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ScaleMode {
    /// EWMA of squared prediction errors — always calibrated, O(1).
    /// The smoothing factor is [`SGBTConfig::empirical_sigma_alpha`].
    #[default]
    Empirical,
    /// Full Hoeffding-tree ensemble for feature-conditional log(σ) prediction.
    TreeChain,
}
49
50// ---------------------------------------------------------------------------
51// DriftDetectorType
52// ---------------------------------------------------------------------------
53
/// Which drift detector to instantiate for each boosting step.
///
/// Each variant stores the detector's configuration parameters so that fresh
/// instances can be created on demand (e.g. when replacing a drifted tree);
/// instantiation happens via [`create`](Self::create).
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum DriftDetectorType {
    /// Page-Hinkley Test with custom delta (magnitude tolerance) and lambda
    /// (detection threshold). This is the overall default (see [`Default`]).
    PageHinkley {
        /// Magnitude tolerance. Default 0.005.
        delta: f64,
        /// Detection threshold. Default 50.0.
        lambda: f64,
    },

    /// ADWIN with custom confidence parameter.
    Adwin {
        /// Confidence (smaller = fewer false positives). Default 0.002.
        delta: f64,
    },

    /// DDM with custom warning/drift levels and minimum warmup instances.
    Ddm {
        /// Warning threshold multiplier. Default 2.0.
        warning_level: f64,
        /// Drift threshold multiplier. Default 3.0.
        drift_level: f64,
        /// Minimum observations before detection activates. Default 30.
        min_instances: u64,
    },
}
86
87impl Default for DriftDetectorType {
88 fn default() -> Self {
89 DriftDetectorType::PageHinkley {
90 delta: 0.005,
91 lambda: 50.0,
92 }
93 }
94}
95
96impl DriftDetectorType {
97 /// Create a new, fresh drift detector from this configuration.
98 pub fn create(&self) -> Box<dyn DriftDetector> {
99 match self {
100 Self::PageHinkley { delta, lambda } => {
101 Box::new(PageHinkleyTest::with_params(*delta, *lambda))
102 }
103 Self::Adwin { delta } => Box::new(Adwin::with_delta(*delta)),
104 Self::Ddm {
105 warning_level,
106 drift_level,
107 min_instances,
108 } => Box::new(Ddm::with_params(
109 *warning_level,
110 *drift_level,
111 *min_instances,
112 )),
113 }
114 }
115}
116
117// ---------------------------------------------------------------------------
118// SGBTConfig
119// ---------------------------------------------------------------------------
120
/// Configuration for the SGBT ensemble.
///
/// All numeric parameters are validated at build time via [`SGBTConfigBuilder`].
///
/// # Defaults
///
/// | Parameter | Default |
/// |---------------------------|----------------------|
/// | `n_steps` | 100 |
/// | `learning_rate` | 0.0125 |
/// | `feature_subsample_rate` | 0.75 |
/// | `max_depth` | 6 |
/// | `n_bins` | 64 |
/// | `lambda` | 1.0 |
/// | `gamma` | 0.0 |
/// | `grace_period` | 200 |
/// | `delta` | 1e-7 |
/// | `drift_detector` | PageHinkley(0.005, 50.0) |
/// | `variant` | Standard |
/// | `seed` | 0xDEAD_BEEF_CAFE_4242 |
/// | `initial_target_count` | 50 |
/// | `leaf_half_life` | None (disabled) |
/// | `max_tree_samples` | None (disabled) |
/// | `split_reeval_interval` | None (disabled) |
/// | `quality_prune_threshold` | 1e-6 |
/// | `quality_prune_patience` | 500 |
/// | `scale_mode` | Empirical |
/// | `empirical_sigma_alpha` | 0.01 |
/// | `packed_refresh_interval` | 0 (disabled) |
///
/// The remaining optional knobs (`feature_names`, `feature_types`,
/// `gradient_clip_sigma`, `monotone_constraints`, `quality_prune_alpha`,
/// `error_weight_alpha`, `max_leaf_output`, `adaptive_leaf_bound`,
/// `min_hessian_sum`, `huber_k`, `shadow_warmup`) default to `None`,
/// `uncertainty_modulated_lr` defaults to `false`, and `leaf_model_type`
/// defaults to [`LeafModelType::default`].
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SGBTConfig {
    /// Number of boosting steps (trees in the ensemble). Default 100.
    pub n_steps: usize,
    /// Learning rate (shrinkage). Default 0.0125.
    pub learning_rate: f64,
    /// Fraction of features to subsample per tree. Default 0.75.
    pub feature_subsample_rate: f64,
    /// Maximum tree depth. Default 6.
    pub max_depth: usize,
    /// Number of histogram bins. Default 64.
    pub n_bins: usize,
    /// L2 regularization parameter (lambda). Default 1.0.
    pub lambda: f64,
    /// Minimum split gain (gamma). Default 0.0.
    pub gamma: f64,
    /// Grace period: min samples before evaluating splits. Default 200.
    pub grace_period: usize,
    /// Hoeffding bound confidence (delta). Default 1e-7.
    pub delta: f64,
    /// Drift detector type for tree replacement. Default: PageHinkley.
    pub drift_detector: DriftDetectorType,
    /// SGBT computational variant. Default: Standard.
    pub variant: SGBTVariant,
    /// Random seed for deterministic reproducibility. Default: 0xDEAD_BEEF_CAFE_4242.
    ///
    /// Controls feature subsampling and variant skip/MI stochastic decisions.
    /// Two models with the same seed and same data will produce identical results.
    pub seed: u64,
    /// Number of initial targets to collect before computing the base prediction.
    /// Default: 50.
    pub initial_target_count: usize,

    /// Half-life for exponential leaf decay (in samples per leaf).
    ///
    /// After `leaf_half_life` samples, a leaf's accumulated gradient/hessian
    /// statistics have half the weight of the most recent sample. This causes
    /// the model to continuously adapt to changing data distributions rather
    /// than freezing on early observations.
    ///
    /// `None` (default) disables decay -- traditional monotonic accumulation.
    #[cfg_attr(feature = "serde", serde(default))]
    pub leaf_half_life: Option<usize>,

    /// Maximum samples a single tree processes before proactive replacement.
    ///
    /// After this many samples, the tree is replaced with a fresh one regardless
    /// of drift detector state. Prevents stale tree structure from persisting
    /// when the drift detector is not sensitive enough.
    ///
    /// `None` (default) disables time-based replacement.
    #[cfg_attr(feature = "serde", serde(default))]
    pub max_tree_samples: Option<u64>,

    /// Interval (in samples per leaf) at which max-depth leaves re-evaluate
    /// whether a split would improve them.
    ///
    /// Inspired by EFDT (Manapragada et al. 2018). When a leaf has accumulated
    /// `split_reeval_interval` samples since its last evaluation and has reached
    /// max depth, it re-evaluates whether a split should be performed.
    ///
    /// `None` (default) disables re-evaluation -- max-depth leaves are permanent.
    #[cfg_attr(feature = "serde", serde(default))]
    pub split_reeval_interval: Option<usize>,

    /// Optional human-readable feature names.
    ///
    /// When set, enables [`named_feature_importances`](super::SGBT::named_feature_importances) and
    /// [`train_one_named`](super::SGBT::train_one_named) for production-friendly named access.
    /// Length must match the number of features in training data.
    #[cfg_attr(feature = "serde", serde(default))]
    pub feature_names: Option<Vec<String>>,

    /// Optional per-feature type declarations.
    ///
    /// When set, declares which features are categorical vs continuous.
    /// Categorical features use one-bin-per-category binning and Fisher
    /// optimal binary partitioning for split evaluation.
    /// Length must match the number of features in training data.
    ///
    /// `None` (default) treats all features as continuous.
    #[cfg_attr(feature = "serde", serde(default))]
    pub feature_types: Option<Vec<FeatureType>>,

    /// Gradient clipping threshold in standard deviations per leaf.
    ///
    /// When enabled, each leaf tracks an EWMA of gradient mean and variance.
    /// Incoming gradients that exceed `mean ± sigma * gradient_clip_sigma` are
    /// clamped to the boundary. This prevents outlier samples from corrupting
    /// leaf statistics, which is critical in streaming settings where sudden
    /// label floods can destabilize the model.
    ///
    /// Typical value: 3.0 (3-sigma clipping).
    /// `None` (default) disables gradient clipping.
    #[cfg_attr(feature = "serde", serde(default))]
    pub gradient_clip_sigma: Option<f64>,

    /// Per-feature monotonic constraints.
    ///
    /// Each element specifies the monotonic relationship between a feature and
    /// the prediction:
    /// - `+1`: prediction must be non-decreasing as feature value increases.
    /// - `-1`: prediction must be non-increasing as feature value increases.
    /// - `0`: no constraint (unconstrained).
    ///
    /// During split evaluation, candidate splits that would violate monotonicity
    /// (left child value > right child value for +1 constraints, or vice versa)
    /// are rejected.
    ///
    /// Length must match the number of features in training data.
    /// `None` (default) means no monotonic constraints.
    #[cfg_attr(feature = "serde", serde(default))]
    pub monotone_constraints: Option<Vec<i8>>,

    /// EWMA smoothing factor for quality-based tree pruning.
    ///
    /// When `Some(alpha)`, each boosting step tracks an exponentially weighted
    /// moving average of its marginal contribution to the ensemble. Trees whose
    /// contribution drops below [`quality_prune_threshold`](Self::quality_prune_threshold)
    /// for [`quality_prune_patience`](Self::quality_prune_patience) consecutive
    /// samples are replaced with a fresh tree that can learn the current regime.
    ///
    /// This prevents "dead wood" -- trees from a past regime that no longer
    /// contribute meaningfully to ensemble accuracy.
    ///
    /// `None` (default) disables quality-based pruning.
    /// Suggested value: 0.01.
    #[cfg_attr(feature = "serde", serde(default))]
    pub quality_prune_alpha: Option<f64>,

    /// Minimum contribution threshold for quality-based pruning.
    ///
    /// A tree's EWMA contribution must stay above this value to avoid being
    /// flagged as dead wood. Only used when `quality_prune_alpha` is `Some`.
    ///
    /// Default: 1e-6.
    #[cfg_attr(feature = "serde", serde(default = "default_quality_prune_threshold"))]
    pub quality_prune_threshold: f64,

    /// Consecutive low-contribution samples before a tree is replaced.
    ///
    /// After this many consecutive samples where a tree's EWMA contribution
    /// is below `quality_prune_threshold`, the tree is reset. Only used when
    /// `quality_prune_alpha` is `Some`.
    ///
    /// Default: 500.
    #[cfg_attr(feature = "serde", serde(default = "default_quality_prune_patience"))]
    pub quality_prune_patience: u64,

    /// EWMA smoothing factor for error-weighted sample importance.
    ///
    /// When `Some(alpha)`, samples the model predicted poorly get higher
    /// effective weight during histogram accumulation. The weight is:
    /// `1.0 + |error| / (rolling_mean_error + epsilon)`, capped at 10x.
    ///
    /// This is a streaming version of AdaBoost's reweighting applied at the
    /// gradient level -- learning capacity focuses on hard/novel patterns,
    /// enabling faster adaptation to regime changes.
    ///
    /// `None` (default) disables error weighting.
    /// Suggested value: 0.01.
    #[cfg_attr(feature = "serde", serde(default))]
    pub error_weight_alpha: Option<f64>,

    /// Enable σ-modulated learning rate for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
    ///
    /// When `true`, the **location** (μ) ensemble's learning rate is scaled by
    /// `sigma_ratio = current_sigma / rolling_sigma_mean`, where `rolling_sigma_mean`
    /// is an EWMA of the model's predicted σ (alpha = 0.001).
    ///
    /// This means the model learns μ **faster** when σ is elevated (high uncertainty)
    /// and **slower** when σ is low (confident regime). The scale (σ) ensemble always
    /// trains at the unmodulated base rate to prevent positive feedback loops.
    ///
    /// Default: `false`.
    #[cfg_attr(feature = "serde", serde(default))]
    pub uncertainty_modulated_lr: bool,

    /// How the scale (σ) is estimated in [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
    ///
    /// - [`Empirical`](ScaleMode::Empirical) (default): EWMA of squared prediction
    ///   errors. `σ = sqrt(ewma_sq_err)`. Always calibrated, zero tuning, O(1).
    /// - [`TreeChain`](ScaleMode::TreeChain): full dual-chain NGBoost with a
    ///   separate tree ensemble predicting log(σ) from features.
    ///
    /// For σ-modulated learning (`uncertainty_modulated_lr = true`), `Empirical`
    /// is strongly recommended — scale tree gradients are inherently weak and
    /// the trees often fail to split.
    #[cfg_attr(feature = "serde", serde(default))]
    pub scale_mode: ScaleMode,

    /// EWMA smoothing factor for empirical σ estimation.
    ///
    /// Controls the adaptation speed of `σ = sqrt(ewma_sq_err)` when
    /// [`scale_mode`](Self::scale_mode) is [`Empirical`](ScaleMode::Empirical).
    /// Higher values react faster to regime changes but are noisier.
    ///
    /// Default: `0.01` (~100-sample effective window).
    #[cfg_attr(feature = "serde", serde(default = "default_empirical_sigma_alpha"))]
    pub empirical_sigma_alpha: f64,

    /// Maximum absolute leaf output value.
    ///
    /// When `Some(max)`, leaf predictions are clamped to `[-max, max]`.
    /// Prevents runaway leaf weights from causing prediction explosions
    /// in feedback loops. `None` (default) means no clamping.
    #[cfg_attr(feature = "serde", serde(default))]
    pub max_leaf_output: Option<f64>,

    /// Per-leaf adaptive output bound (sigma multiplier).
    ///
    /// When `Some(k)`, each leaf tracks an EWMA of its own output weight and
    /// clamps predictions to `|output_mean| + k * output_std`. The EWMA uses
    /// `leaf_decay_alpha` when `leaf_half_life` is set, otherwise Welford online.
    ///
    /// This is strictly superior to `max_leaf_output` for streaming — the bound
    /// is per-leaf, self-calibrating, and regime-synchronized. A leaf that usually
    /// outputs 0.3 can't suddenly output 2.9 just because it fits in the global clamp.
    ///
    /// Typical value: 3.0 (3-sigma bound).
    /// `None` (default) disables adaptive bounds (falls back to `max_leaf_output`).
    #[cfg_attr(feature = "serde", serde(default))]
    pub adaptive_leaf_bound: Option<f64>,

    /// Minimum hessian sum before a leaf produces non-zero output.
    ///
    /// When `Some(min_h)`, leaves with `hess_sum < min_h` return 0.0.
    /// Prevents post-replacement spikes from fresh leaves with insufficient
    /// samples. `None` (default) means all leaves contribute immediately.
    #[cfg_attr(feature = "serde", serde(default))]
    pub min_hessian_sum: Option<f64>,

    /// Huber loss delta multiplier for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
    ///
    /// When `Some(k)`, the distributional location gradient uses Huber loss
    /// with adaptive `delta = k * empirical_sigma`. This bounds gradients by
    /// construction. Standard value: `1.345` (95% efficiency at Gaussian).
    /// `None` (default) uses squared loss.
    #[cfg_attr(feature = "serde", serde(default))]
    pub huber_k: Option<f64>,

    /// Shadow warmup for graduated tree handoff.
    ///
    /// When `Some(n)`, an always-on shadow (alternate) tree is spawned immediately
    /// alongside every active tree. The shadow trains on the same gradient stream
    /// but does not contribute to predictions until it has seen `n` samples.
    ///
    /// As the active tree ages past 80% of `max_tree_samples`, its prediction
    /// weight linearly decays to 0 at 120%. The shadow's weight ramps from 0 to 1
    /// over `n` samples after warmup. When the active weight reaches 0, the shadow
    /// is promoted and a new shadow is spawned — no cold-start prediction dip.
    ///
    /// Requires `max_tree_samples` to be set for time-based graduated handoff.
    /// Drift-based replacement still uses hard swap (shadow is already warm).
    ///
    /// `None` (default) disables graduated handoff — uses traditional hard swap.
    #[cfg_attr(feature = "serde", serde(default))]
    pub shadow_warmup: Option<usize>,

    /// Leaf prediction model type.
    ///
    /// Controls how each leaf computes its prediction:
    /// - [`ClosedForm`](LeafModelType::ClosedForm) (default): constant leaf weight.
    /// - [`Linear`](LeafModelType::Linear): per-leaf online ridge regression with
    ///   AdaGrad optimization. Optional `decay` for concept drift. Recommended for
    ///   low-depth trees (depth 2--4).
    /// - [`MLP`](LeafModelType::MLP): per-leaf single-hidden-layer neural network.
    ///   Optional `decay` for concept drift.
    /// - [`Adaptive`](LeafModelType::Adaptive): starts as closed-form, auto-promotes
    ///   when the Hoeffding bound confirms a more complex model is better.
    ///
    /// Default: [`ClosedForm`](LeafModelType::ClosedForm).
    #[cfg_attr(feature = "serde", serde(default))]
    pub leaf_model_type: LeafModelType,

    /// Packed cache refresh interval for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
    ///
    /// When non-zero, the distributional model maintains a packed f32 cache of
    /// its location ensemble that is re-exported every `packed_refresh_interval`
    /// training samples. Predictions use the cache for O(1)-per-tree inference
    /// via contiguous memory traversal, falling back to full tree traversal when
    /// the cache is absent or produces non-finite results.
    ///
    /// `0` (default) disables the packed cache.
    #[cfg_attr(feature = "serde", serde(default))]
    pub packed_refresh_interval: u64,
}
433
/// Serde default for [`SGBTConfig::empirical_sigma_alpha`]: `0.01`
/// (~100-sample effective EWMA window).
fn default_empirical_sigma_alpha() -> f64 {
    1e-2
}
437
/// Serde default for [`SGBTConfig::quality_prune_threshold`]: `1e-6`.
fn default_quality_prune_threshold() -> f64 {
    1.0e-6
}
441
/// Serde default for [`SGBTConfig::quality_prune_patience`]: `500`
/// consecutive low-contribution samples.
fn default_quality_prune_patience() -> u64 { 500 }
445
446impl Default for SGBTConfig {
447 fn default() -> Self {
448 Self {
449 n_steps: 100,
450 learning_rate: 0.0125,
451 feature_subsample_rate: 0.75,
452 max_depth: 6,
453 n_bins: 64,
454 lambda: 1.0,
455 gamma: 0.0,
456 grace_period: 200,
457 delta: 1e-7,
458 drift_detector: DriftDetectorType::default(),
459 variant: SGBTVariant::default(),
460 seed: 0xDEAD_BEEF_CAFE_4242,
461 initial_target_count: 50,
462 leaf_half_life: None,
463 max_tree_samples: None,
464 split_reeval_interval: None,
465 feature_names: None,
466 feature_types: None,
467 gradient_clip_sigma: None,
468 monotone_constraints: None,
469 quality_prune_alpha: None,
470 quality_prune_threshold: 1e-6,
471 quality_prune_patience: 500,
472 error_weight_alpha: None,
473 uncertainty_modulated_lr: false,
474 scale_mode: ScaleMode::default(),
475 empirical_sigma_alpha: 0.01,
476 max_leaf_output: None,
477 adaptive_leaf_bound: None,
478 min_hessian_sum: None,
479 huber_k: None,
480 shadow_warmup: None,
481 leaf_model_type: LeafModelType::default(),
482 packed_refresh_interval: 0,
483 }
484 }
485}
486
487impl SGBTConfig {
488 /// Start building a configuration via the builder pattern.
489 pub fn builder() -> SGBTConfigBuilder {
490 SGBTConfigBuilder::default()
491 }
492}
493
494// ---------------------------------------------------------------------------
495// SGBTConfigBuilder
496// ---------------------------------------------------------------------------
497
/// Builder for [`SGBTConfig`] with validation on [`build()`](Self::build).
///
/// Starts from [`SGBTConfig::default`] (via `#[derive(Default)]`); each setter
/// overrides one field and returns `self` for chaining.
///
/// # Example
///
/// ```ignore
/// use irithyll::ensemble::config::{SGBTConfig, DriftDetectorType};
/// use irithyll::ensemble::variants::SGBTVariant;
///
/// let config = SGBTConfig::builder()
///     .n_steps(200)
///     .learning_rate(0.05)
///     .drift_detector(DriftDetectorType::Adwin { delta: 0.01 })
///     .variant(SGBTVariant::Skip { k: 10 })
///     .build()
///     .expect("valid config");
/// ```
#[derive(Debug, Clone, Default)]
pub struct SGBTConfigBuilder {
    // The configuration being assembled; validated as a whole in build().
    config: SGBTConfig,
}
518
519impl SGBTConfigBuilder {
    /// Set the number of boosting steps (trees in the ensemble).
    ///
    /// Validated at [`build()`](Self::build): must be > 0.
    pub fn n_steps(mut self, n: usize) -> Self {
        self.config.n_steps = n;
        self
    }

    /// Set the learning rate (shrinkage factor).
    ///
    /// Validated at [`build()`](Self::build): must be in (0, 1].
    pub fn learning_rate(mut self, lr: f64) -> Self {
        self.config.learning_rate = lr;
        self
    }

    /// Set the fraction of features to subsample per tree.
    ///
    /// Validated at [`build()`](Self::build): must be in (0, 1].
    pub fn feature_subsample_rate(mut self, rate: f64) -> Self {
        self.config.feature_subsample_rate = rate;
        self
    }

    /// Set the maximum tree depth.
    ///
    /// Validated at [`build()`](Self::build): must be > 0.
    pub fn max_depth(mut self, depth: usize) -> Self {
        self.config.max_depth = depth;
        self
    }

    /// Set the number of histogram bins per feature.
    ///
    /// Validated at [`build()`](Self::build): must be >= 2.
    pub fn n_bins(mut self, bins: usize) -> Self {
        self.config.n_bins = bins;
        self
    }

    /// Set the L2 regularization parameter (lambda).
    ///
    /// Validated at [`build()`](Self::build): must be >= 0.
    pub fn lambda(mut self, l: f64) -> Self {
        self.config.lambda = l;
        self
    }

    /// Set the minimum split gain (gamma).
    ///
    /// Validated at [`build()`](Self::build): must be >= 0.
    pub fn gamma(mut self, g: f64) -> Self {
        self.config.gamma = g;
        self
    }

    /// Set the grace period (minimum samples before evaluating splits).
    ///
    /// Validated at [`build()`](Self::build): must be > 0.
    pub fn grace_period(mut self, gp: usize) -> Self {
        self.config.grace_period = gp;
        self
    }

    /// Set the Hoeffding bound confidence parameter (delta).
    ///
    /// Validated at [`build()`](Self::build): must be in (0, 1).
    pub fn delta(mut self, d: f64) -> Self {
        self.config.delta = d;
        self
    }

    /// Set the drift detector type for tree replacement.
    pub fn drift_detector(mut self, dt: DriftDetectorType) -> Self {
        self.config.drift_detector = dt;
        self
    }

    /// Set the SGBT computational variant.
    pub fn variant(mut self, v: SGBTVariant) -> Self {
        self.config.variant = v;
        self
    }

    /// Set the random seed for deterministic reproducibility.
    ///
    /// Controls feature subsampling and variant skip/MI stochastic decisions.
    /// Two models with the same seed and data sequence will produce identical results.
    pub fn seed(mut self, seed: u64) -> Self {
        self.config.seed = seed;
        self
    }

    /// Set the number of initial targets to collect before computing the base prediction.
    ///
    /// The model collects this many target values before initializing the base
    /// prediction (via `loss.initial_prediction`). Default: 50.
    ///
    /// Validated at [`build()`](Self::build): must be > 0.
    pub fn initial_target_count(mut self, count: usize) -> Self {
        self.config.initial_target_count = count;
        self
    }
603
    /// Set the half-life for exponential leaf decay (in samples per leaf).
    ///
    /// After `n` samples, a leaf's accumulated statistics have half the weight
    /// of the most recent sample. Enables continuous adaptation to concept drift.
    /// Sets the field to `Some(n)`; the default `None` disables decay.
    pub fn leaf_half_life(mut self, n: usize) -> Self {
        self.config.leaf_half_life = Some(n);
        self
    }

    /// Set the maximum samples a single tree processes before proactive replacement.
    ///
    /// After `n` samples, the tree is replaced regardless of drift detector state.
    /// The default `None` disables time-based replacement.
    pub fn max_tree_samples(mut self, n: u64) -> Self {
        self.config.max_tree_samples = Some(n);
        self
    }

    /// Set the split re-evaluation interval for max-depth leaves.
    ///
    /// Every `n` samples per leaf, max-depth leaves re-evaluate whether a split
    /// would improve them. Inspired by EFDT (Manapragada et al. 2018).
    pub fn split_reeval_interval(mut self, n: usize) -> Self {
        self.config.split_reeval_interval = Some(n);
        self
    }

    /// Set human-readable feature names.
    ///
    /// Enables named feature importances and named training input.
    /// Names must be unique; validated at [`build()`](Self::build).
    pub fn feature_names(mut self, names: Vec<String>) -> Self {
        self.config.feature_names = Some(names);
        self
    }

    /// Set per-feature type declarations.
    ///
    /// Declares which features are categorical vs continuous. Categorical features
    /// use one-bin-per-category binning and Fisher optimal binary partitioning.
    /// Supports up to 64 distinct category values per categorical feature.
    pub fn feature_types(mut self, types: Vec<FeatureType>) -> Self {
        self.config.feature_types = Some(types);
        self
    }

    /// Set per-leaf gradient clipping threshold (in standard deviations).
    ///
    /// Each leaf tracks an EWMA of gradient mean and variance. Gradients
    /// exceeding `mean ± sigma * n` are clamped. Prevents outlier labels
    /// from corrupting streaming model stability.
    ///
    /// Typical value: 3.0 (3-sigma clipping).
    pub fn gradient_clip_sigma(mut self, sigma: f64) -> Self {
        self.config.gradient_clip_sigma = Some(sigma);
        self
    }

    /// Set per-feature monotonic constraints.
    ///
    /// `+1` = non-decreasing, `-1` = non-increasing, `0` = unconstrained.
    /// Candidate splits violating monotonicity are rejected during tree growth.
    /// Length must match the number of features in training data.
    pub fn monotone_constraints(mut self, constraints: Vec<i8>) -> Self {
        self.config.monotone_constraints = Some(constraints);
        self
    }
669
    /// Enable quality-based tree pruning with the given EWMA smoothing factor.
    ///
    /// Trees whose marginal contribution drops below the threshold for
    /// `patience` consecutive samples are replaced with fresh trees.
    /// Suggested alpha: 0.01.
    pub fn quality_prune_alpha(mut self, alpha: f64) -> Self {
        self.config.quality_prune_alpha = Some(alpha);
        self
    }

    /// Set the minimum contribution threshold for quality-based pruning.
    ///
    /// Default: 1e-6. Only relevant when `quality_prune_alpha` is set.
    pub fn quality_prune_threshold(mut self, threshold: f64) -> Self {
        self.config.quality_prune_threshold = threshold;
        self
    }

    /// Set the patience (consecutive low-contribution samples) before pruning.
    ///
    /// Default: 500. Only relevant when `quality_prune_alpha` is set.
    pub fn quality_prune_patience(mut self, patience: u64) -> Self {
        self.config.quality_prune_patience = patience;
        self
    }

    /// Enable error-weighted sample importance with the given EWMA smoothing factor.
    ///
    /// Samples the model predicted poorly get higher effective weight.
    /// Suggested alpha: 0.01.
    pub fn error_weight_alpha(mut self, alpha: f64) -> Self {
        self.config.error_weight_alpha = Some(alpha);
        self
    }

    /// Enable σ-modulated learning rate for distributional models.
    ///
    /// Scales the location (μ) learning rate by `current_sigma / rolling_sigma_mean`,
    /// so the model adapts faster during high-uncertainty regimes and conserves
    /// during stable periods. Only affects [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
    ///
    /// By default uses empirical σ (EWMA of squared errors). Set
    /// [`scale_mode(ScaleMode::TreeChain)`](Self::scale_mode) for feature-conditional σ.
    pub fn uncertainty_modulated_lr(mut self, enabled: bool) -> Self {
        self.config.uncertainty_modulated_lr = enabled;
        self
    }

    /// Set the scale estimation mode for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
    ///
    /// - [`Empirical`](ScaleMode::Empirical): EWMA of squared prediction errors (default, recommended).
    /// - [`TreeChain`](ScaleMode::TreeChain): dual-chain NGBoost with scale tree ensemble.
    pub fn scale_mode(mut self, mode: ScaleMode) -> Self {
        self.config.scale_mode = mode;
        self
    }

    /// EWMA alpha for empirical σ. Controls adaptation speed. Default `0.01`.
    ///
    /// Only used when `scale_mode` is [`Empirical`](ScaleMode::Empirical).
    pub fn empirical_sigma_alpha(mut self, alpha: f64) -> Self {
        self.config.empirical_sigma_alpha = alpha;
        self
    }
734
    /// Set the maximum absolute leaf output value.
    ///
    /// Clamps leaf predictions to `[-max, max]`, breaking feedback loops
    /// that cause prediction explosions.
    pub fn max_leaf_output(mut self, max: f64) -> Self {
        self.config.max_leaf_output = Some(max);
        self
    }

    /// Set per-leaf adaptive output bound (sigma multiplier).
    ///
    /// Each leaf tracks EWMA of its own output weight and clamps to
    /// `|output_mean| + k * output_std`. Self-calibrating per-leaf.
    /// Recommended: use with `leaf_half_life` for streaming scenarios.
    pub fn adaptive_leaf_bound(mut self, k: f64) -> Self {
        self.config.adaptive_leaf_bound = Some(k);
        self
    }

    /// Set the minimum hessian sum for leaf output.
    ///
    /// Fresh leaves with `hess_sum < min_h` return 0.0, preventing
    /// post-replacement spikes.
    pub fn min_hessian_sum(mut self, min_h: f64) -> Self {
        self.config.min_hessian_sum = Some(min_h);
        self
    }

    /// Set the Huber loss delta multiplier for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
    ///
    /// When set, location gradients use Huber loss with adaptive
    /// `delta = k * empirical_sigma`. Standard value: `1.345` (95% Gaussian efficiency).
    pub fn huber_k(mut self, k: f64) -> Self {
        self.config.huber_k = Some(k);
        self
    }

    /// Enable graduated tree handoff with the given shadow warmup samples.
    ///
    /// Spawns an always-on shadow tree that trains alongside the active tree.
    /// After `warmup` samples, the shadow begins contributing to predictions
    /// via graduated blending. Eliminates prediction dips during tree replacement.
    /// Time-based graduated handoff additionally requires `max_tree_samples`.
    pub fn shadow_warmup(mut self, warmup: usize) -> Self {
        self.config.shadow_warmup = Some(warmup);
        self
    }

    /// Set the leaf prediction model type.
    ///
    /// [`LeafModelType::Linear`] is recommended for low-depth configurations
    /// (depth 2--4) where per-leaf linear models reduce approximation error.
    ///
    /// [`LeafModelType::Adaptive`] automatically selects between closed-form and
    /// a trainable model per leaf, using the Hoeffding bound for promotion.
    pub fn leaf_model_type(mut self, lmt: LeafModelType) -> Self {
        self.config.leaf_model_type = lmt;
        self
    }

    /// Set the packed cache refresh interval for distributional models.
    ///
    /// When non-zero, [`DistributionalSGBT`](super::distributional::DistributionalSGBT)
    /// maintains a packed f32 cache refreshed every `interval` training samples.
    /// `0` (default) disables the cache.
    pub fn packed_refresh_interval(mut self, interval: u64) -> Self {
        self.config.packed_refresh_interval = interval;
        self
    }
803
804 /// Validate and build the configuration.
805 ///
806 /// # Errors
807 ///
808 /// Returns [`InvalidConfig`](crate::IrithyllError::InvalidConfig) with a structured
809 /// [`ConfigError`] if any parameter is out of its valid range.
810 pub fn build(self) -> Result<SGBTConfig> {
811 let c = &self.config;
812
813 // -- Ensemble-level parameters --
814 if c.n_steps == 0 {
815 return Err(ConfigError::out_of_range("n_steps", "must be > 0", c.n_steps).into());
816 }
817 if c.learning_rate <= 0.0 || c.learning_rate > 1.0 {
818 return Err(ConfigError::out_of_range(
819 "learning_rate",
820 "must be in (0, 1]",
821 c.learning_rate,
822 )
823 .into());
824 }
825 if c.feature_subsample_rate <= 0.0 || c.feature_subsample_rate > 1.0 {
826 return Err(ConfigError::out_of_range(
827 "feature_subsample_rate",
828 "must be in (0, 1]",
829 c.feature_subsample_rate,
830 )
831 .into());
832 }
833
834 // -- Tree-level parameters --
835 if c.max_depth == 0 {
836 return Err(ConfigError::out_of_range("max_depth", "must be > 0", c.max_depth).into());
837 }
838 if c.n_bins < 2 {
839 return Err(ConfigError::out_of_range("n_bins", "must be >= 2", c.n_bins).into());
840 }
841 if c.lambda < 0.0 {
842 return Err(ConfigError::out_of_range("lambda", "must be >= 0", c.lambda).into());
843 }
844 if c.gamma < 0.0 {
845 return Err(ConfigError::out_of_range("gamma", "must be >= 0", c.gamma).into());
846 }
847 if c.grace_period == 0 {
848 return Err(
849 ConfigError::out_of_range("grace_period", "must be > 0", c.grace_period).into(),
850 );
851 }
852 if c.delta <= 0.0 || c.delta >= 1.0 {
853 return Err(ConfigError::out_of_range("delta", "must be in (0, 1)", c.delta).into());
854 }
855
856 if c.initial_target_count == 0 {
857 return Err(ConfigError::out_of_range(
858 "initial_target_count",
859 "must be > 0",
860 c.initial_target_count,
861 )
862 .into());
863 }
864
865 // -- Streaming adaptation parameters --
866 if let Some(hl) = c.leaf_half_life {
867 if hl == 0 {
868 return Err(ConfigError::out_of_range("leaf_half_life", "must be >= 1", hl).into());
869 }
870 }
871 if let Some(max) = c.max_tree_samples {
872 if max < 100 {
873 return Err(
874 ConfigError::out_of_range("max_tree_samples", "must be >= 100", max).into(),
875 );
876 }
877 }
878 if let Some(interval) = c.split_reeval_interval {
879 if interval < c.grace_period {
880 return Err(ConfigError::invalid(
881 "split_reeval_interval",
882 format!(
883 "must be >= grace_period ({}), got {}",
884 c.grace_period, interval
885 ),
886 )
887 .into());
888 }
889 }
890
891 // -- Feature names --
892 if let Some(ref names) = c.feature_names {
893 // O(n^2) duplicate check — names list is small so no HashSet needed
894 for (i, name) in names.iter().enumerate() {
895 for prev in &names[..i] {
896 if name == prev {
897 return Err(ConfigError::invalid(
898 "feature_names",
899 format!("duplicate feature name: '{}'", name),
900 )
901 .into());
902 }
903 }
904 }
905 }
906
907 // -- Feature types --
908 if let Some(ref types) = c.feature_types {
909 if let Some(ref names) = c.feature_names {
910 if !names.is_empty() && !types.is_empty() && names.len() != types.len() {
911 return Err(ConfigError::invalid(
912 "feature_types",
913 format!(
914 "length ({}) must match feature_names length ({})",
915 types.len(),
916 names.len()
917 ),
918 )
919 .into());
920 }
921 }
922 }
923
924 // -- Gradient clipping --
925 if let Some(sigma) = c.gradient_clip_sigma {
926 if sigma <= 0.0 {
927 return Err(
928 ConfigError::out_of_range("gradient_clip_sigma", "must be > 0", sigma).into(),
929 );
930 }
931 }
932
933 // -- Monotonic constraints --
934 if let Some(ref mc) = c.monotone_constraints {
935 for (i, &v) in mc.iter().enumerate() {
936 if v != -1 && v != 0 && v != 1 {
937 return Err(ConfigError::invalid(
938 "monotone_constraints",
939 format!("feature {}: must be -1, 0, or +1, got {}", i, v),
940 )
941 .into());
942 }
943 }
944 }
945
946 // -- Leaf output clamping --
947 if let Some(max) = c.max_leaf_output {
948 if max <= 0.0 {
949 return Err(
950 ConfigError::out_of_range("max_leaf_output", "must be > 0", max).into(),
951 );
952 }
953 }
954
955 // -- Per-leaf adaptive output bound --
956 if let Some(k) = c.adaptive_leaf_bound {
957 if k <= 0.0 {
958 return Err(
959 ConfigError::out_of_range("adaptive_leaf_bound", "must be > 0", k).into(),
960 );
961 }
962 }
963
964 // -- Minimum hessian sum --
965 if let Some(min_h) = c.min_hessian_sum {
966 if min_h <= 0.0 {
967 return Err(
968 ConfigError::out_of_range("min_hessian_sum", "must be > 0", min_h).into(),
969 );
970 }
971 }
972
973 // -- Huber loss multiplier --
974 if let Some(k) = c.huber_k {
975 if k <= 0.0 {
976 return Err(ConfigError::out_of_range("huber_k", "must be > 0", k).into());
977 }
978 }
979
980 // -- Shadow warmup --
981 if let Some(warmup) = c.shadow_warmup {
982 if warmup == 0 {
983 return Err(ConfigError::out_of_range(
984 "shadow_warmup",
985 "must be > 0",
986 warmup as f64,
987 )
988 .into());
989 }
990 }
991
992 // -- Quality-based pruning parameters --
993 if let Some(alpha) = c.quality_prune_alpha {
994 if alpha <= 0.0 || alpha >= 1.0 {
995 return Err(ConfigError::out_of_range(
996 "quality_prune_alpha",
997 "must be in (0, 1)",
998 alpha,
999 )
1000 .into());
1001 }
1002 }
1003 if c.quality_prune_threshold <= 0.0 {
1004 return Err(ConfigError::out_of_range(
1005 "quality_prune_threshold",
1006 "must be > 0",
1007 c.quality_prune_threshold,
1008 )
1009 .into());
1010 }
1011 if c.quality_prune_patience == 0 {
1012 return Err(ConfigError::out_of_range(
1013 "quality_prune_patience",
1014 "must be > 0",
1015 c.quality_prune_patience,
1016 )
1017 .into());
1018 }
1019
1020 // -- Error-weighted sample importance --
1021 if let Some(alpha) = c.error_weight_alpha {
1022 if alpha <= 0.0 || alpha >= 1.0 {
1023 return Err(ConfigError::out_of_range(
1024 "error_weight_alpha",
1025 "must be in (0, 1)",
1026 alpha,
1027 )
1028 .into());
1029 }
1030 }
1031
1032 // -- Drift detector parameters --
1033 match &c.drift_detector {
1034 DriftDetectorType::PageHinkley { delta, lambda } => {
1035 if *delta <= 0.0 {
1036 return Err(ConfigError::out_of_range(
1037 "drift_detector.PageHinkley.delta",
1038 "must be > 0",
1039 delta,
1040 )
1041 .into());
1042 }
1043 if *lambda <= 0.0 {
1044 return Err(ConfigError::out_of_range(
1045 "drift_detector.PageHinkley.lambda",
1046 "must be > 0",
1047 lambda,
1048 )
1049 .into());
1050 }
1051 }
1052 DriftDetectorType::Adwin { delta } => {
1053 if *delta <= 0.0 || *delta >= 1.0 {
1054 return Err(ConfigError::out_of_range(
1055 "drift_detector.Adwin.delta",
1056 "must be in (0, 1)",
1057 delta,
1058 )
1059 .into());
1060 }
1061 }
1062 DriftDetectorType::Ddm {
1063 warning_level,
1064 drift_level,
1065 min_instances,
1066 } => {
1067 if *warning_level <= 0.0 {
1068 return Err(ConfigError::out_of_range(
1069 "drift_detector.Ddm.warning_level",
1070 "must be > 0",
1071 warning_level,
1072 )
1073 .into());
1074 }
1075 if *drift_level <= 0.0 {
1076 return Err(ConfigError::out_of_range(
1077 "drift_detector.Ddm.drift_level",
1078 "must be > 0",
1079 drift_level,
1080 )
1081 .into());
1082 }
1083 if *drift_level <= *warning_level {
1084 return Err(ConfigError::invalid(
1085 "drift_detector.Ddm.drift_level",
1086 format!(
1087 "must be > warning_level ({}), got {}",
1088 warning_level, drift_level
1089 ),
1090 )
1091 .into());
1092 }
1093 if *min_instances == 0 {
1094 return Err(ConfigError::out_of_range(
1095 "drift_detector.Ddm.min_instances",
1096 "must be > 0",
1097 min_instances,
1098 )
1099 .into());
1100 }
1101 }
1102 }
1103
1104 // -- Variant parameters --
1105 match &c.variant {
1106 SGBTVariant::Standard => {} // no extra validation
1107 SGBTVariant::Skip { k } => {
1108 if *k == 0 {
1109 return Err(
1110 ConfigError::out_of_range("variant.Skip.k", "must be > 0", k).into(),
1111 );
1112 }
1113 }
1114 SGBTVariant::MultipleIterations { multiplier } => {
1115 if *multiplier <= 0.0 {
1116 return Err(ConfigError::out_of_range(
1117 "variant.MultipleIterations.multiplier",
1118 "must be > 0",
1119 multiplier,
1120 )
1121 .into());
1122 }
1123 }
1124 }
1125
1126 Ok(self.config)
1127 }
1128}
1129
1130// ---------------------------------------------------------------------------
1131// Display impls
1132// ---------------------------------------------------------------------------
1133
1134impl core::fmt::Display for DriftDetectorType {
1135 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1136 match self {
1137 Self::PageHinkley { delta, lambda } => {
1138 write!(f, "PageHinkley(delta={}, lambda={})", delta, lambda)
1139 }
1140 Self::Adwin { delta } => write!(f, "Adwin(delta={})", delta),
1141 Self::Ddm {
1142 warning_level,
1143 drift_level,
1144 min_instances,
1145 } => write!(
1146 f,
1147 "Ddm(warning={}, drift={}, min_instances={})",
1148 warning_level, drift_level, min_instances
1149 ),
1150 }
1151 }
1152}
1153
1154impl core::fmt::Display for SGBTConfig {
1155 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1156 write!(
1157 f,
1158 "SGBTConfig {{ steps={}, lr={}, depth={}, bins={}, variant={}, drift={} }}",
1159 self.n_steps,
1160 self.learning_rate,
1161 self.max_depth,
1162 self.n_bins,
1163 self.variant,
1164 self.drift_detector,
1165 )
1166 }
1167}
1168
1169// ---------------------------------------------------------------------------
1170// Tests
1171// ---------------------------------------------------------------------------
1172
// Unit tests: builder ergonomics, validation boundaries (accept/reject on both
// sides of each limit), drift-detector construction, and Display/clone behavior.
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::format;
    use alloc::vec;

    // ------------------------------------------------------------------
    // 1. Default config values are correct
    // ------------------------------------------------------------------
    #[test]
    fn default_config_values() {
        let cfg = SGBTConfig::default();
        // Exact comparisons via EPSILON are fine here: defaults are assigned
        // literals, never computed.
        assert_eq!(cfg.n_steps, 100);
        assert!((cfg.learning_rate - 0.0125).abs() < f64::EPSILON);
        assert!((cfg.feature_subsample_rate - 0.75).abs() < f64::EPSILON);
        assert_eq!(cfg.max_depth, 6);
        assert_eq!(cfg.n_bins, 64);
        assert!((cfg.lambda - 1.0).abs() < f64::EPSILON);
        assert!((cfg.gamma - 0.0).abs() < f64::EPSILON);
        assert_eq!(cfg.grace_period, 200);
        assert!((cfg.delta - 1e-7).abs() < f64::EPSILON);
        assert_eq!(cfg.variant, SGBTVariant::Standard);
    }

    // ------------------------------------------------------------------
    // 2. Builder chain works
    // ------------------------------------------------------------------
    #[test]
    fn builder_chain() {
        let cfg = SGBTConfig::builder()
            .n_steps(50)
            .learning_rate(0.1)
            .feature_subsample_rate(0.5)
            .max_depth(10)
            .n_bins(128)
            .lambda(0.5)
            .gamma(0.1)
            .grace_period(500)
            .delta(1e-3)
            .drift_detector(DriftDetectorType::Adwin { delta: 0.01 })
            .variant(SGBTVariant::Skip { k: 5 })
            .build()
            .expect("valid config");

        assert_eq!(cfg.n_steps, 50);
        assert!((cfg.learning_rate - 0.1).abs() < f64::EPSILON);
        assert!((cfg.feature_subsample_rate - 0.5).abs() < f64::EPSILON);
        assert_eq!(cfg.max_depth, 10);
        assert_eq!(cfg.n_bins, 128);
        assert!((cfg.lambda - 0.5).abs() < f64::EPSILON);
        assert!((cfg.gamma - 0.1).abs() < f64::EPSILON);
        assert_eq!(cfg.grace_period, 500);
        assert!((cfg.delta - 1e-3).abs() < f64::EPSILON);
        assert_eq!(cfg.variant, SGBTVariant::Skip { k: 5 });

        // Verify drift detector type is Adwin.
        match &cfg.drift_detector {
            DriftDetectorType::Adwin { delta } => {
                assert!((delta - 0.01).abs() < f64::EPSILON);
            }
            _ => panic!("expected Adwin drift detector"),
        }
    }

    // ------------------------------------------------------------------
    // 3. Validation rejects invalid values
    // ------------------------------------------------------------------
    #[test]
    fn validation_rejects_n_steps_zero() {
        let result = SGBTConfig::builder().n_steps(0).build();
        assert!(result.is_err());
        // Error message should identify the offending parameter by name.
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("n_steps"));
    }

    #[test]
    fn validation_rejects_learning_rate_zero() {
        let result = SGBTConfig::builder().learning_rate(0.0).build();
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("learning_rate"));
    }

    #[test]
    fn validation_rejects_learning_rate_above_one() {
        let result = SGBTConfig::builder().learning_rate(1.5).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_accepts_learning_rate_one() {
        // Upper boundary of the half-open interval (0, 1] is valid.
        let result = SGBTConfig::builder().learning_rate(1.0).build();
        assert!(result.is_ok());
    }

    #[test]
    fn validation_rejects_negative_learning_rate() {
        let result = SGBTConfig::builder().learning_rate(-0.1).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_feature_subsample_zero() {
        let result = SGBTConfig::builder().feature_subsample_rate(0.0).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_feature_subsample_above_one() {
        let result = SGBTConfig::builder().feature_subsample_rate(1.01).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_max_depth_zero() {
        let result = SGBTConfig::builder().max_depth(0).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_n_bins_one() {
        let result = SGBTConfig::builder().n_bins(1).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_negative_lambda() {
        let result = SGBTConfig::builder().lambda(-0.1).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_accepts_zero_lambda() {
        // lambda is a >= 0 regularizer; zero disables it but is legal.
        let result = SGBTConfig::builder().lambda(0.0).build();
        assert!(result.is_ok());
    }

    #[test]
    fn validation_rejects_negative_gamma() {
        let result = SGBTConfig::builder().gamma(-0.1).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_grace_period_zero() {
        let result = SGBTConfig::builder().grace_period(0).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_delta_zero() {
        let result = SGBTConfig::builder().delta(0.0).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_delta_one() {
        // delta lives in the open interval (0, 1); both endpoints are invalid.
        let result = SGBTConfig::builder().delta(1.0).build();
        assert!(result.is_err());
    }

    // ------------------------------------------------------------------
    // 3b. Drift detector parameter validation
    // ------------------------------------------------------------------
    #[test]
    fn validation_rejects_pht_negative_delta() {
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::PageHinkley {
                delta: -1.0,
                lambda: 50.0,
            })
            .build();
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("PageHinkley"));
    }

    #[test]
    fn validation_rejects_pht_zero_lambda() {
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::PageHinkley {
                delta: 0.005,
                lambda: 0.0,
            })
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_adwin_delta_out_of_range() {
        // Both endpoints of (0, 1) must be rejected.
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::Adwin { delta: 0.0 })
            .build();
        assert!(result.is_err());

        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::Adwin { delta: 1.0 })
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_ddm_warning_above_drift() {
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::Ddm {
                warning_level: 3.0,
                drift_level: 2.0,
                min_instances: 30,
            })
            .build();
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("drift_level"));
        assert!(msg.contains("must be > warning_level"));
    }

    #[test]
    fn validation_rejects_ddm_equal_levels() {
        // drift_level must be STRICTLY greater than warning_level.
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::Ddm {
                warning_level: 2.0,
                drift_level: 2.0,
                min_instances: 30,
            })
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_ddm_zero_min_instances() {
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::Ddm {
                warning_level: 2.0,
                drift_level: 3.0,
                min_instances: 0,
            })
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_ddm_zero_warning_level() {
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::Ddm {
                warning_level: 0.0,
                drift_level: 3.0,
                min_instances: 30,
            })
            .build();
        assert!(result.is_err());
    }

    // ------------------------------------------------------------------
    // 3c. Variant parameter validation
    // ------------------------------------------------------------------
    #[test]
    fn validation_rejects_skip_k_zero() {
        let result = SGBTConfig::builder()
            .variant(SGBTVariant::Skip { k: 0 })
            .build();
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("Skip"));
    }

    #[test]
    fn validation_rejects_mi_zero_multiplier() {
        let result = SGBTConfig::builder()
            .variant(SGBTVariant::MultipleIterations { multiplier: 0.0 })
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_mi_negative_multiplier() {
        let result = SGBTConfig::builder()
            .variant(SGBTVariant::MultipleIterations { multiplier: -1.0 })
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_accepts_standard_variant() {
        let result = SGBTConfig::builder().variant(SGBTVariant::Standard).build();
        assert!(result.is_ok());
    }

    // ------------------------------------------------------------------
    // 4. DriftDetectorType creates correct detector types
    // ------------------------------------------------------------------
    #[test]
    fn drift_detector_type_creates_page_hinkley() {
        let dt = DriftDetectorType::PageHinkley {
            delta: 0.01,
            lambda: 100.0,
        };
        let mut detector = dt.create();

        // Should start with zero mean.
        assert_eq!(detector.estimated_mean(), 0.0);

        // Feed a stable value -- should not drift.
        let signal = detector.update(1.0);
        assert_ne!(signal, crate::drift::DriftSignal::Drift);
    }

    #[test]
    fn drift_detector_type_creates_adwin() {
        let dt = DriftDetectorType::Adwin { delta: 0.05 };
        let mut detector = dt.create();

        assert_eq!(detector.estimated_mean(), 0.0);
        let signal = detector.update(1.0);
        assert_ne!(signal, crate::drift::DriftSignal::Drift);
    }

    #[test]
    fn drift_detector_type_creates_ddm() {
        let dt = DriftDetectorType::Ddm {
            warning_level: 2.0,
            drift_level: 3.0,
            min_instances: 30,
        };
        let mut detector = dt.create();

        assert_eq!(detector.estimated_mean(), 0.0);
        let signal = detector.update(0.1);
        assert_eq!(signal, crate::drift::DriftSignal::Stable);
    }

    // ------------------------------------------------------------------
    // 5. Default drift detector is PageHinkley
    // ------------------------------------------------------------------
    #[test]
    fn default_drift_detector_is_page_hinkley() {
        let dt = DriftDetectorType::default();
        match dt {
            DriftDetectorType::PageHinkley { delta, lambda } => {
                assert!((delta - 0.005).abs() < f64::EPSILON);
                assert!((lambda - 50.0).abs() < f64::EPSILON);
            }
            _ => panic!("expected default DriftDetectorType to be PageHinkley"),
        }
    }

    // ------------------------------------------------------------------
    // 6. Default config builds without error
    // ------------------------------------------------------------------
    #[test]
    fn default_config_builds() {
        let result = SGBTConfig::builder().build();
        assert!(result.is_ok());
    }

    // ------------------------------------------------------------------
    // 7. Config clone preserves all fields
    // ------------------------------------------------------------------
    #[test]
    fn config_clone_preserves_fields() {
        let original = SGBTConfig::builder()
            .n_steps(50)
            .learning_rate(0.05)
            .variant(SGBTVariant::MultipleIterations { multiplier: 5.0 })
            .build()
            .unwrap();

        let cloned = original.clone();

        assert_eq!(cloned.n_steps, 50);
        assert!((cloned.learning_rate - 0.05).abs() < f64::EPSILON);
        assert_eq!(
            cloned.variant,
            SGBTVariant::MultipleIterations { multiplier: 5.0 }
        );
    }

    // ------------------------------------------------------------------
    // 8. All three DDM valid configs accepted
    // ------------------------------------------------------------------
    #[test]
    fn valid_ddm_config_accepted() {
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::Ddm {
                warning_level: 1.5,
                drift_level: 2.5,
                min_instances: 10,
            })
            .build();
        assert!(result.is_ok());
    }

    // ------------------------------------------------------------------
    // 9. Created detectors are functional (round-trip test)
    // ------------------------------------------------------------------
    #[test]
    fn created_detectors_are_functional() {
        // Create a detector, feed it data, verify it can detect drift.
        let dt = DriftDetectorType::PageHinkley {
            delta: 0.005,
            lambda: 50.0,
        };
        let mut detector = dt.create();

        // Feed 500 stable values.
        for _ in 0..500 {
            detector.update(1.0);
        }

        // Feed shifted values -- should eventually drift.
        let mut drifted = false;
        for _ in 0..500 {
            if detector.update(10.0) == crate::drift::DriftSignal::Drift {
                drifted = true;
                break;
            }
        }
        assert!(
            drifted,
            "detector created from DriftDetectorType should be functional"
        );
    }

    // ------------------------------------------------------------------
    // 10. Boundary acceptance: n_bins=2 is the minimum valid
    // ------------------------------------------------------------------
    #[test]
    fn boundary_n_bins_two_accepted() {
        let result = SGBTConfig::builder().n_bins(2).build();
        assert!(result.is_ok());
    }

    // ------------------------------------------------------------------
    // 11. Boundary acceptance: grace_period=1 is valid
    // ------------------------------------------------------------------
    #[test]
    fn boundary_grace_period_one_accepted() {
        let result = SGBTConfig::builder().grace_period(1).build();
        assert!(result.is_ok());
    }

    // ------------------------------------------------------------------
    // 12. Feature names -- valid config with names
    // ------------------------------------------------------------------
    #[test]
    fn feature_names_accepted() {
        let cfg = SGBTConfig::builder()
            .feature_names(vec!["price".into(), "volume".into(), "spread".into()])
            .build()
            .unwrap();
        assert_eq!(
            cfg.feature_names.as_ref().unwrap(),
            &["price", "volume", "spread"]
        );
    }

    // ------------------------------------------------------------------
    // 13. Feature names -- duplicate names rejected
    // ------------------------------------------------------------------
    #[test]
    fn feature_names_rejects_duplicates() {
        let result = SGBTConfig::builder()
            .feature_names(vec!["price".into(), "volume".into(), "price".into()])
            .build();
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("duplicate"));
    }

    // ------------------------------------------------------------------
    // 14. Feature names -- serde backward compat (missing field)
    //     Requires serde_json which is only in the full irithyll crate.
    // ------------------------------------------------------------------

    // ------------------------------------------------------------------
    // 15. Feature names -- empty vec is valid
    // ------------------------------------------------------------------
    #[test]
    fn feature_names_empty_vec_accepted() {
        // An empty list means "names provided but none specified" and is legal.
        let cfg = SGBTConfig::builder().feature_names(vec![]).build().unwrap();
        assert!(cfg.feature_names.unwrap().is_empty());
    }

    // ------------------------------------------------------------------
    // 16. Adaptive leaf bound -- builder
    // ------------------------------------------------------------------
    #[test]
    fn builder_adaptive_leaf_bound() {
        let cfg = SGBTConfig::builder()
            .adaptive_leaf_bound(3.0)
            .build()
            .unwrap();
        assert_eq!(cfg.adaptive_leaf_bound, Some(3.0));
    }

    // ------------------------------------------------------------------
    // 17. Adaptive leaf bound -- validation rejects zero
    // ------------------------------------------------------------------
    #[test]
    fn validation_rejects_zero_adaptive_leaf_bound() {
        let result = SGBTConfig::builder().adaptive_leaf_bound(0.0).build();
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(
            msg.contains("adaptive_leaf_bound"),
            "error should mention adaptive_leaf_bound: {}",
            msg,
        );
    }

    // ------------------------------------------------------------------
    // 18. Adaptive leaf bound -- serde backward compat
    //     Requires serde_json which is only in the full irithyll crate.
    // ------------------------------------------------------------------
}