irithyll_core/ensemble/config.rs
1//! SGBT configuration with builder pattern and full validation.
2//!
3//! [`SGBTConfig`] holds all hyperparameters for the Streaming Gradient Boosted
4//! Trees ensemble. Use [`SGBTConfig::builder`] for ergonomic construction with
5//! validation on [`build()`](SGBTConfigBuilder::build).
6
7use alloc::boxed::Box;
8use alloc::format;
9use alloc::string::String;
10use alloc::vec::Vec;
11
12use crate::drift::adwin::Adwin;
13use crate::drift::ddm::Ddm;
14use crate::drift::pht::PageHinkleyTest;
15use crate::drift::DriftDetector;
16use crate::ensemble::variants::SGBTVariant;
17use crate::error::{ConfigError, Result};
18use crate::tree::leaf_model::LeafModelType;
19
20// ---------------------------------------------------------------------------
21// FeatureType -- re-exported from irithyll-core
22// ---------------------------------------------------------------------------
23
24pub use crate::feature::FeatureType;
25
26// ---------------------------------------------------------------------------
27// ScaleMode
28// ---------------------------------------------------------------------------
29
/// How [`DistributionalSGBT`](super::distributional::DistributionalSGBT)
/// estimates uncertainty (σ).
///
/// - **`Empirical`** (default): tracks an EWMA of squared prediction errors.
///   `σ = sqrt(ewma_sq_err)`. Always calibrated, zero tuning, O(1) compute.
///   Use this when σ drives learning-rate modulation (σ high → learn faster).
///
/// - **`TreeChain`**: trains a full second ensemble of Hoeffding trees to predict
///   log(σ) from features (NGBoost-style dual chain). Gives *feature-conditional*
///   uncertainty but requires strong signal in the scale gradients.
///
/// Selected via [`SGBTConfig::scale_mode`]; the EWMA speed for `Empirical` is
/// controlled by [`SGBTConfig::empirical_sigma_alpha`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ScaleMode {
    /// EWMA of squared prediction errors — always calibrated, O(1).
    #[default]
    Empirical,
    /// Full Hoeffding-tree ensemble for feature-conditional log(σ) prediction.
    TreeChain,
}
49
50// ---------------------------------------------------------------------------
51// DriftDetectorType
52// ---------------------------------------------------------------------------
53
/// Which drift detector to instantiate for each boosting step.
///
/// Each variant stores the detector's configuration parameters so that fresh
/// instances can be created on demand (e.g. when replacing a drifted tree).
/// Call [`DriftDetectorType::create`] to build a detector from the stored
/// parameters.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum DriftDetectorType {
    /// Page-Hinkley Test with custom delta (magnitude tolerance) and lambda
    /// (detection threshold).
    PageHinkley {
        /// Magnitude tolerance. Default 0.005.
        delta: f64,
        /// Detection threshold. Default 50.0.
        lambda: f64,
    },

    /// ADWIN with custom confidence parameter.
    Adwin {
        /// Confidence (smaller = fewer false positives). Default 0.002.
        delta: f64,
    },

    /// DDM with custom warning/drift levels and minimum warmup instances.
    Ddm {
        /// Warning threshold multiplier. Default 2.0.
        warning_level: f64,
        /// Drift threshold multiplier. Default 3.0.
        drift_level: f64,
        /// Minimum observations before detection activates. Default 30.
        min_instances: u64,
    },
}
86
87impl Default for DriftDetectorType {
88 fn default() -> Self {
89 DriftDetectorType::PageHinkley {
90 delta: 0.005,
91 lambda: 50.0,
92 }
93 }
94}
95
96impl DriftDetectorType {
97 /// Create a new, fresh drift detector from this configuration.
98 pub fn create(&self) -> Box<dyn DriftDetector> {
99 match self {
100 Self::PageHinkley { delta, lambda } => {
101 Box::new(PageHinkleyTest::with_params(*delta, *lambda))
102 }
103 Self::Adwin { delta } => Box::new(Adwin::with_delta(*delta)),
104 Self::Ddm {
105 warning_level,
106 drift_level,
107 min_instances,
108 } => Box::new(Ddm::with_params(
109 *warning_level,
110 *drift_level,
111 *min_instances,
112 )),
113 }
114 }
115}
116
117// ---------------------------------------------------------------------------
118// SGBTConfig
119// ---------------------------------------------------------------------------
120
/// Configuration for the SGBT ensemble.
///
/// All numeric parameters are validated at build time via [`SGBTConfigBuilder`];
/// prefer constructing via [`SGBTConfig::builder`] over a struct literal.
///
/// # Defaults
///
/// | Parameter                | Default                  |
/// |--------------------------|--------------------------|
/// | `n_steps`                | 100                      |
/// | `learning_rate`          | 0.0125                   |
/// | `feature_subsample_rate` | 0.75                     |
/// | `max_depth`              | 6                        |
/// | `n_bins`                 | 64                       |
/// | `lambda`                 | 1.0                      |
/// | `gamma`                  | 0.0                      |
/// | `grace_period`           | 200                      |
/// | `delta`                  | 1e-7                     |
/// | `drift_detector`         | PageHinkley(0.005, 50.0) |
/// | `variant`                | Standard                 |
/// | `seed`                   | 0xDEAD_BEEF_CAFE_4242    |
/// | `initial_target_count`   | 50                       |
/// | `leaf_half_life`         | None (disabled)          |
/// | `max_tree_samples`       | None (disabled)          |
/// | `split_reeval_interval`  | None (disabled)          |
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SGBTConfig {
    /// Number of boosting steps (trees in the ensemble). Default 100.
    pub n_steps: usize,
    /// Learning rate (shrinkage). Default 0.0125.
    pub learning_rate: f64,
    /// Fraction of features to subsample per tree. Default 0.75.
    pub feature_subsample_rate: f64,
    /// Maximum tree depth. Default 6.
    pub max_depth: usize,
    /// Number of histogram bins. Default 64.
    pub n_bins: usize,
    /// L2 regularization parameter (lambda). Default 1.0.
    pub lambda: f64,
    /// Minimum split gain (gamma). Default 0.0.
    pub gamma: f64,
    /// Grace period: min samples before evaluating splits. Default 200.
    pub grace_period: usize,
    /// Hoeffding bound confidence (delta). Default 1e-7.
    pub delta: f64,
    /// Drift detector type for tree replacement. Default: PageHinkley.
    pub drift_detector: DriftDetectorType,
    /// SGBT computational variant. Default: Standard.
    pub variant: SGBTVariant,
    /// Random seed for deterministic reproducibility. Default: 0xDEAD_BEEF_CAFE_4242.
    ///
    /// Controls feature subsampling and variant skip/MI stochastic decisions.
    /// Two models with the same seed and same data will produce identical results.
    pub seed: u64,
    /// Number of initial targets to collect before computing the base prediction.
    /// Default: 50.
    pub initial_target_count: usize,

    /// Half-life for exponential leaf decay (in samples per leaf).
    ///
    /// After `leaf_half_life` samples, a leaf's accumulated gradient/hessian
    /// statistics have half the weight of the most recent sample. This causes
    /// the model to continuously adapt to changing data distributions rather
    /// than freezing on early observations.
    ///
    /// `None` (default) disables decay -- traditional monotonic accumulation.
    #[cfg_attr(feature = "serde", serde(default))]
    pub leaf_half_life: Option<usize>,

    /// Maximum samples a single tree processes before proactive replacement.
    ///
    /// After this many samples, the tree is replaced with a fresh one regardless
    /// of drift detector state. Prevents stale tree structure from persisting
    /// when the drift detector is not sensitive enough.
    ///
    /// `None` (default) disables time-based replacement.
    #[cfg_attr(feature = "serde", serde(default))]
    pub max_tree_samples: Option<u64>,

    /// Interval (in samples per leaf) at which max-depth leaves re-evaluate
    /// whether a split would improve them.
    ///
    /// Inspired by EFDT (Manapragada et al. 2018). When a leaf has accumulated
    /// `split_reeval_interval` samples since its last evaluation and has reached
    /// max depth, it re-evaluates whether a split should be performed.
    ///
    /// `None` (default) disables re-evaluation -- max-depth leaves are permanent.
    #[cfg_attr(feature = "serde", serde(default))]
    pub split_reeval_interval: Option<usize>,

    /// Optional human-readable feature names.
    ///
    /// When set, enables `named_feature_importances` and
    /// `train_one_named` for production-friendly named access.
    /// Length must match the number of features in training data.
    #[cfg_attr(feature = "serde", serde(default))]
    pub feature_names: Option<Vec<String>>,

    /// Optional per-feature type declarations.
    ///
    /// When set, declares which features are categorical vs continuous.
    /// Categorical features use one-bin-per-category binning and Fisher
    /// optimal binary partitioning for split evaluation.
    /// Length must match the number of features in training data.
    ///
    /// `None` (default) treats all features as continuous.
    #[cfg_attr(feature = "serde", serde(default))]
    pub feature_types: Option<Vec<FeatureType>>,

    /// Gradient clipping threshold in standard deviations per leaf.
    ///
    /// When enabled, each leaf tracks an EWMA of gradient mean and variance.
    /// Incoming gradients that exceed `mean ± sigma * gradient_clip_sigma` are
    /// clamped to the boundary. This prevents outlier samples from corrupting
    /// leaf statistics, which is critical in streaming settings where sudden
    /// label floods can destabilize the model.
    ///
    /// Typical value: 3.0 (3-sigma clipping).
    /// `None` (default) disables gradient clipping.
    #[cfg_attr(feature = "serde", serde(default))]
    pub gradient_clip_sigma: Option<f64>,

    /// Per-feature monotonic constraints.
    ///
    /// Each element specifies the monotonic relationship between a feature and
    /// the prediction:
    /// - `+1`: prediction must be non-decreasing as feature value increases.
    /// - `-1`: prediction must be non-increasing as feature value increases.
    /// - `0`: no constraint (unconstrained).
    ///
    /// During split evaluation, candidate splits that would violate monotonicity
    /// (left child value > right child value for +1 constraints, or vice versa)
    /// are rejected.
    ///
    /// Length must match the number of features in training data.
    /// `None` (default) means no monotonic constraints.
    #[cfg_attr(feature = "serde", serde(default))]
    pub monotone_constraints: Option<Vec<i8>>,

    /// EWMA smoothing factor for quality-based tree pruning.
    ///
    /// When `Some(alpha)`, each boosting step tracks an exponentially weighted
    /// moving average of its marginal contribution to the ensemble. Trees whose
    /// contribution drops below [`quality_prune_threshold`](Self::quality_prune_threshold)
    /// for [`quality_prune_patience`](Self::quality_prune_patience) consecutive
    /// samples are replaced with a fresh tree that can learn the current regime.
    ///
    /// This prevents "dead wood" -- trees from a past regime that no longer
    /// contribute meaningfully to ensemble accuracy.
    ///
    /// `None` (default) disables quality-based pruning.
    /// Suggested value: 0.01.
    #[cfg_attr(feature = "serde", serde(default))]
    pub quality_prune_alpha: Option<f64>,

    /// Minimum contribution threshold for quality-based pruning.
    ///
    /// A tree's EWMA contribution must stay above this value to avoid being
    /// flagged as dead wood. Only used when `quality_prune_alpha` is `Some`.
    ///
    /// Default: 1e-6 (serde default mirrors `Default`).
    #[cfg_attr(feature = "serde", serde(default = "default_quality_prune_threshold"))]
    pub quality_prune_threshold: f64,

    /// Consecutive low-contribution samples before a tree is replaced.
    ///
    /// After this many consecutive samples where a tree's EWMA contribution
    /// is below `quality_prune_threshold`, the tree is reset. Only used when
    /// `quality_prune_alpha` is `Some`.
    ///
    /// Default: 500 (serde default mirrors `Default`).
    #[cfg_attr(feature = "serde", serde(default = "default_quality_prune_patience"))]
    pub quality_prune_patience: u64,

    /// EWMA smoothing factor for error-weighted sample importance.
    ///
    /// When `Some(alpha)`, samples the model predicted poorly get higher
    /// effective weight during histogram accumulation. The weight is:
    /// `1.0 + |error| / (rolling_mean_error + epsilon)`, capped at 10x.
    ///
    /// This is a streaming version of AdaBoost's reweighting applied at the
    /// gradient level -- learning capacity focuses on hard/novel patterns,
    /// enabling faster adaptation to regime changes.
    ///
    /// `None` (default) disables error weighting.
    /// Suggested value: 0.01.
    #[cfg_attr(feature = "serde", serde(default))]
    pub error_weight_alpha: Option<f64>,

    /// Enable σ-modulated learning rate for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
    ///
    /// When `true`, the **location** (μ) ensemble's learning rate is scaled by
    /// `sigma_ratio = current_sigma / rolling_sigma_mean`, where `rolling_sigma_mean`
    /// is an EWMA of the model's predicted σ (alpha = 0.001).
    ///
    /// This means the model learns μ **faster** when σ is elevated (high uncertainty)
    /// and **slower** when σ is low (confident regime). The scale (σ) ensemble always
    /// trains at the unmodulated base rate to prevent positive feedback loops.
    ///
    /// Default: `false`.
    #[cfg_attr(feature = "serde", serde(default))]
    pub uncertainty_modulated_lr: bool,

    /// How the scale (σ) is estimated in [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
    ///
    /// - [`Empirical`](ScaleMode::Empirical) (default): EWMA of squared prediction
    ///   errors. `σ = sqrt(ewma_sq_err)`. Always calibrated, zero tuning, O(1).
    /// - [`TreeChain`](ScaleMode::TreeChain): full dual-chain NGBoost with a
    ///   separate tree ensemble predicting log(σ) from features.
    ///
    /// For σ-modulated learning (`uncertainty_modulated_lr = true`), `Empirical`
    /// is strongly recommended — scale tree gradients are inherently weak and
    /// the trees often fail to split.
    #[cfg_attr(feature = "serde", serde(default))]
    pub scale_mode: ScaleMode,

    /// EWMA smoothing factor for empirical σ estimation.
    ///
    /// Controls the adaptation speed of `σ = sqrt(ewma_sq_err)` when
    /// [`scale_mode`](Self::scale_mode) is [`Empirical`](ScaleMode::Empirical).
    /// Higher values react faster to regime changes but are noisier.
    ///
    /// Default: `0.01` (~100-sample effective window).
    #[cfg_attr(feature = "serde", serde(default = "default_empirical_sigma_alpha"))]
    pub empirical_sigma_alpha: f64,

    /// Maximum absolute leaf output value.
    ///
    /// When `Some(max)`, leaf predictions are clamped to `[-max, max]`.
    /// Prevents runaway leaf weights from causing prediction explosions
    /// in feedback loops. `None` (default) means no clamping.
    #[cfg_attr(feature = "serde", serde(default))]
    pub max_leaf_output: Option<f64>,

    /// Per-leaf adaptive output bound (sigma multiplier).
    ///
    /// When `Some(k)`, each leaf tracks an EWMA of its own output weight and
    /// clamps predictions to `|output_mean| + k * output_std`. The EWMA uses
    /// `leaf_decay_alpha` when `leaf_half_life` is set, otherwise Welford online.
    ///
    /// This is strictly superior to `max_leaf_output` for streaming — the bound
    /// is per-leaf, self-calibrating, and regime-synchronized. A leaf that usually
    /// outputs 0.3 can't suddenly output 2.9 just because it fits in the global clamp.
    ///
    /// Typical value: 3.0 (3-sigma bound).
    /// `None` (default) disables adaptive bounds (falls back to `max_leaf_output`).
    #[cfg_attr(feature = "serde", serde(default))]
    pub adaptive_leaf_bound: Option<f64>,

    /// Minimum hessian sum before a leaf produces non-zero output.
    ///
    /// When `Some(min_h)`, leaves with `hess_sum < min_h` return 0.0.
    /// Prevents post-replacement spikes from fresh leaves with insufficient
    /// samples. `None` (default) means all leaves contribute immediately.
    #[cfg_attr(feature = "serde", serde(default))]
    pub min_hessian_sum: Option<f64>,

    /// Huber loss delta multiplier for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
    ///
    /// When `Some(k)`, the distributional location gradient uses Huber loss
    /// with adaptive `delta = k * empirical_sigma`. This bounds gradients by
    /// construction. Standard value: `1.345` (95% efficiency at Gaussian).
    /// `None` (default) uses squared loss.
    #[cfg_attr(feature = "serde", serde(default))]
    pub huber_k: Option<f64>,

    /// Shadow warmup for graduated tree handoff.
    ///
    /// When `Some(n)`, an always-on shadow (alternate) tree is spawned immediately
    /// alongside every active tree. The shadow trains on the same gradient stream
    /// but does not contribute to predictions until it has seen `n` samples.
    ///
    /// As the active tree ages past 80% of `max_tree_samples`, its prediction
    /// weight linearly decays to 0 at 120%. The shadow's weight ramps from 0 to 1
    /// over `n` samples after warmup. When the active weight reaches 0, the shadow
    /// is promoted and a new shadow is spawned — no cold-start prediction dip.
    ///
    /// Requires `max_tree_samples` to be set for time-based graduated handoff.
    /// Drift-based replacement still uses hard swap (shadow is already warm).
    ///
    /// `None` (default) disables graduated handoff — uses traditional hard swap.
    #[cfg_attr(feature = "serde", serde(default))]
    pub shadow_warmup: Option<usize>,

    /// Leaf prediction model type.
    ///
    /// Controls how each leaf computes its prediction:
    /// - [`ClosedForm`](LeafModelType::ClosedForm) (default): constant leaf weight.
    /// - [`Linear`](LeafModelType::Linear): per-leaf online ridge regression with
    ///   AdaGrad optimization. Optional `decay` for concept drift. Recommended for
    ///   low-depth trees (depth 2--4).
    /// - [`MLP`](LeafModelType::MLP): per-leaf single-hidden-layer neural network.
    ///   Optional `decay` for concept drift.
    /// - [`Adaptive`](LeafModelType::Adaptive): starts as closed-form, auto-promotes
    ///   when the Hoeffding bound confirms a more complex model is better.
    ///
    /// Default: [`ClosedForm`](LeafModelType::ClosedForm).
    #[cfg_attr(feature = "serde", serde(default))]
    pub leaf_model_type: LeafModelType,

    /// Packed cache refresh interval for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
    ///
    /// When non-zero, the distributional model maintains a packed f32 cache of
    /// its location ensemble that is re-exported every `packed_refresh_interval`
    /// training samples. Predictions use the cache for O(1)-per-tree inference
    /// via contiguous memory traversal, falling back to full tree traversal when
    /// the cache is absent or produces non-finite results.
    ///
    /// `0` (default) disables the packed cache.
    #[cfg_attr(feature = "serde", serde(default))]
    pub packed_refresh_interval: u64,
}
433
// Serde `default = "..."` helpers. Keep these values in sync with
// `impl Default for SGBTConfig` so a deserialized config with missing fields
// matches a freshly-defaulted one.
#[cfg(feature = "serde")]
fn default_empirical_sigma_alpha() -> f64 {
    // Mirrors `SGBTConfig::default().empirical_sigma_alpha`.
    0.01
}

#[cfg(feature = "serde")]
fn default_quality_prune_threshold() -> f64 {
    // Mirrors `SGBTConfig::default().quality_prune_threshold`.
    1e-6
}

#[cfg(feature = "serde")]
fn default_quality_prune_patience() -> u64 {
    // Mirrors `SGBTConfig::default().quality_prune_patience`.
    500
}
448
449impl Default for SGBTConfig {
450 fn default() -> Self {
451 Self {
452 n_steps: 100,
453 learning_rate: 0.0125,
454 feature_subsample_rate: 0.75,
455 max_depth: 6,
456 n_bins: 64,
457 lambda: 1.0,
458 gamma: 0.0,
459 grace_period: 200,
460 delta: 1e-7,
461 drift_detector: DriftDetectorType::default(),
462 variant: SGBTVariant::default(),
463 seed: 0xDEAD_BEEF_CAFE_4242,
464 initial_target_count: 50,
465 leaf_half_life: None,
466 max_tree_samples: None,
467 split_reeval_interval: None,
468 feature_names: None,
469 feature_types: None,
470 gradient_clip_sigma: None,
471 monotone_constraints: None,
472 quality_prune_alpha: None,
473 quality_prune_threshold: 1e-6,
474 quality_prune_patience: 500,
475 error_weight_alpha: None,
476 uncertainty_modulated_lr: false,
477 scale_mode: ScaleMode::default(),
478 empirical_sigma_alpha: 0.01,
479 max_leaf_output: None,
480 adaptive_leaf_bound: None,
481 min_hessian_sum: None,
482 huber_k: None,
483 shadow_warmup: None,
484 leaf_model_type: LeafModelType::default(),
485 packed_refresh_interval: 0,
486 }
487 }
488}
489
490impl SGBTConfig {
491 /// Start building a configuration via the builder pattern.
492 pub fn builder() -> SGBTConfigBuilder {
493 SGBTConfigBuilder::default()
494 }
495}
496
497// ---------------------------------------------------------------------------
498// SGBTConfigBuilder
499// ---------------------------------------------------------------------------
500
501/// Builder for [`SGBTConfig`] with validation on [`build()`](Self::build).
502///
503/// # Example
504///
505/// ```ignore
506/// use irithyll::ensemble::config::{SGBTConfig, DriftDetectorType};
507/// use irithyll::ensemble::variants::SGBTVariant;
508///
509/// let config = SGBTConfig::builder()
510/// .n_steps(200)
511/// .learning_rate(0.05)
512/// .drift_detector(DriftDetectorType::Adwin { delta: 0.01 })
513/// .variant(SGBTVariant::Skip { k: 10 })
514/// .build()
515/// .expect("valid config");
516/// ```
517#[derive(Debug, Clone, Default)]
518pub struct SGBTConfigBuilder {
519 config: SGBTConfig,
520}
521
522impl SGBTConfigBuilder {
523 /// Set the number of boosting steps (trees in the ensemble).
524 pub fn n_steps(mut self, n: usize) -> Self {
525 self.config.n_steps = n;
526 self
527 }
528
529 /// Set the learning rate (shrinkage factor).
530 pub fn learning_rate(mut self, lr: f64) -> Self {
531 self.config.learning_rate = lr;
532 self
533 }
534
535 /// Set the fraction of features to subsample per tree.
536 pub fn feature_subsample_rate(mut self, rate: f64) -> Self {
537 self.config.feature_subsample_rate = rate;
538 self
539 }
540
541 /// Set the maximum tree depth.
542 pub fn max_depth(mut self, depth: usize) -> Self {
543 self.config.max_depth = depth;
544 self
545 }
546
547 /// Set the number of histogram bins per feature.
548 pub fn n_bins(mut self, bins: usize) -> Self {
549 self.config.n_bins = bins;
550 self
551 }
552
553 /// Set the L2 regularization parameter (lambda).
554 pub fn lambda(mut self, l: f64) -> Self {
555 self.config.lambda = l;
556 self
557 }
558
559 /// Set the minimum split gain (gamma).
560 pub fn gamma(mut self, g: f64) -> Self {
561 self.config.gamma = g;
562 self
563 }
564
565 /// Set the grace period (minimum samples before evaluating splits).
566 pub fn grace_period(mut self, gp: usize) -> Self {
567 self.config.grace_period = gp;
568 self
569 }
570
571 /// Set the Hoeffding bound confidence parameter (delta).
572 pub fn delta(mut self, d: f64) -> Self {
573 self.config.delta = d;
574 self
575 }
576
577 /// Set the drift detector type for tree replacement.
578 pub fn drift_detector(mut self, dt: DriftDetectorType) -> Self {
579 self.config.drift_detector = dt;
580 self
581 }
582
583 /// Set the SGBT computational variant.
584 pub fn variant(mut self, v: SGBTVariant) -> Self {
585 self.config.variant = v;
586 self
587 }
588
589 /// Set the random seed for deterministic reproducibility.
590 ///
591 /// Controls feature subsampling and variant skip/MI stochastic decisions.
592 /// Two models with the same seed and data sequence will produce identical results.
593 pub fn seed(mut self, seed: u64) -> Self {
594 self.config.seed = seed;
595 self
596 }
597
598 /// Set the number of initial targets to collect before computing the base prediction.
599 ///
600 /// The model collects this many target values before initializing the base
601 /// prediction (via `loss.initial_prediction`). Default: 50.
602 pub fn initial_target_count(mut self, count: usize) -> Self {
603 self.config.initial_target_count = count;
604 self
605 }
606
607 /// Set the half-life for exponential leaf decay (in samples per leaf).
608 ///
609 /// After `n` samples, a leaf's accumulated statistics have half the weight
610 /// of the most recent sample. Enables continuous adaptation to concept drift.
611 pub fn leaf_half_life(mut self, n: usize) -> Self {
612 self.config.leaf_half_life = Some(n);
613 self
614 }
615
616 /// Set the maximum samples a single tree processes before proactive replacement.
617 ///
618 /// After `n` samples, the tree is replaced regardless of drift detector state.
619 pub fn max_tree_samples(mut self, n: u64) -> Self {
620 self.config.max_tree_samples = Some(n);
621 self
622 }
623
624 /// Set the split re-evaluation interval for max-depth leaves.
625 ///
626 /// Every `n` samples per leaf, max-depth leaves re-evaluate whether a split
627 /// would improve them. Inspired by EFDT (Manapragada et al. 2018).
628 pub fn split_reeval_interval(mut self, n: usize) -> Self {
629 self.config.split_reeval_interval = Some(n);
630 self
631 }
632
633 /// Set human-readable feature names.
634 ///
635 /// Enables named feature importances and named training input.
636 /// Names must be unique; validated at [`build()`](Self::build).
637 pub fn feature_names(mut self, names: Vec<String>) -> Self {
638 self.config.feature_names = Some(names);
639 self
640 }
641
642 /// Set per-feature type declarations.
643 ///
644 /// Declares which features are categorical vs continuous. Categorical features
645 /// use one-bin-per-category binning and Fisher optimal binary partitioning.
646 /// Supports up to 64 distinct category values per categorical feature.
647 pub fn feature_types(mut self, types: Vec<FeatureType>) -> Self {
648 self.config.feature_types = Some(types);
649 self
650 }
651
652 /// Set per-leaf gradient clipping threshold (in standard deviations).
653 ///
654 /// Each leaf tracks an EWMA of gradient mean and variance. Gradients
655 /// exceeding `mean ± sigma * n` are clamped. Prevents outlier labels
656 /// from corrupting streaming model stability.
657 ///
658 /// Typical value: 3.0 (3-sigma clipping).
659 pub fn gradient_clip_sigma(mut self, sigma: f64) -> Self {
660 self.config.gradient_clip_sigma = Some(sigma);
661 self
662 }
663
664 /// Set per-feature monotonic constraints.
665 ///
666 /// `+1` = non-decreasing, `-1` = non-increasing, `0` = unconstrained.
667 /// Candidate splits violating monotonicity are rejected during tree growth.
668 pub fn monotone_constraints(mut self, constraints: Vec<i8>) -> Self {
669 self.config.monotone_constraints = Some(constraints);
670 self
671 }
672
673 /// Enable quality-based tree pruning with the given EWMA smoothing factor.
674 ///
675 /// Trees whose marginal contribution drops below the threshold for
676 /// `patience` consecutive samples are replaced with fresh trees.
677 /// Suggested alpha: 0.01.
678 pub fn quality_prune_alpha(mut self, alpha: f64) -> Self {
679 self.config.quality_prune_alpha = Some(alpha);
680 self
681 }
682
683 /// Set the minimum contribution threshold for quality-based pruning.
684 ///
685 /// Default: 1e-6. Only relevant when `quality_prune_alpha` is set.
686 pub fn quality_prune_threshold(mut self, threshold: f64) -> Self {
687 self.config.quality_prune_threshold = threshold;
688 self
689 }
690
691 /// Set the patience (consecutive low-contribution samples) before pruning.
692 ///
693 /// Default: 500. Only relevant when `quality_prune_alpha` is set.
694 pub fn quality_prune_patience(mut self, patience: u64) -> Self {
695 self.config.quality_prune_patience = patience;
696 self
697 }
698
699 /// Enable error-weighted sample importance with the given EWMA smoothing factor.
700 ///
701 /// Samples the model predicted poorly get higher effective weight.
702 /// Suggested alpha: 0.01.
703 pub fn error_weight_alpha(mut self, alpha: f64) -> Self {
704 self.config.error_weight_alpha = Some(alpha);
705 self
706 }
707
708 /// Enable σ-modulated learning rate for distributional models.
709 ///
710 /// Scales the location (μ) learning rate by `current_sigma / rolling_sigma_mean`,
711 /// so the model adapts faster during high-uncertainty regimes and conserves
712 /// during stable periods. Only affects [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
713 ///
714 /// By default uses empirical σ (EWMA of squared errors). Set
715 /// [`scale_mode(ScaleMode::TreeChain)`](Self::scale_mode) for feature-conditional σ.
716 pub fn uncertainty_modulated_lr(mut self, enabled: bool) -> Self {
717 self.config.uncertainty_modulated_lr = enabled;
718 self
719 }
720
721 /// Set the scale estimation mode for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
722 ///
723 /// - [`Empirical`](ScaleMode::Empirical): EWMA of squared prediction errors (default, recommended).
724 /// - [`TreeChain`](ScaleMode::TreeChain): dual-chain NGBoost with scale tree ensemble.
725 pub fn scale_mode(mut self, mode: ScaleMode) -> Self {
726 self.config.scale_mode = mode;
727 self
728 }
729
730 /// EWMA alpha for empirical σ. Controls adaptation speed. Default `0.01`.
731 ///
732 /// Only used when `scale_mode` is [`Empirical`](ScaleMode::Empirical).
733 pub fn empirical_sigma_alpha(mut self, alpha: f64) -> Self {
734 self.config.empirical_sigma_alpha = alpha;
735 self
736 }
737
738 /// Set the maximum absolute leaf output value.
739 ///
740 /// Clamps leaf predictions to `[-max, max]`, breaking feedback loops
741 /// that cause prediction explosions.
742 pub fn max_leaf_output(mut self, max: f64) -> Self {
743 self.config.max_leaf_output = Some(max);
744 self
745 }
746
747 /// Set per-leaf adaptive output bound (sigma multiplier).
748 ///
749 /// Each leaf tracks EWMA of its own output weight and clamps to
750 /// `|output_mean| + k * output_std`. Self-calibrating per-leaf.
751 /// Recommended: use with `leaf_half_life` for streaming scenarios.
752 pub fn adaptive_leaf_bound(mut self, k: f64) -> Self {
753 self.config.adaptive_leaf_bound = Some(k);
754 self
755 }
756
757 /// Set the minimum hessian sum for leaf output.
758 ///
759 /// Fresh leaves with `hess_sum < min_h` return 0.0, preventing
760 /// post-replacement spikes.
761 pub fn min_hessian_sum(mut self, min_h: f64) -> Self {
762 self.config.min_hessian_sum = Some(min_h);
763 self
764 }
765
766 /// Set the Huber loss delta multiplier for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
767 ///
768 /// When set, location gradients use Huber loss with adaptive
769 /// `delta = k * empirical_sigma`. Standard value: `1.345` (95% Gaussian efficiency).
770 pub fn huber_k(mut self, k: f64) -> Self {
771 self.config.huber_k = Some(k);
772 self
773 }
774
775 /// Enable graduated tree handoff with the given shadow warmup samples.
776 ///
777 /// Spawns an always-on shadow tree that trains alongside the active tree.
778 /// After `warmup` samples, the shadow begins contributing to predictions
779 /// via graduated blending. Eliminates prediction dips during tree replacement.
780 pub fn shadow_warmup(mut self, warmup: usize) -> Self {
781 self.config.shadow_warmup = Some(warmup);
782 self
783 }
784
785 /// Set the leaf prediction model type.
786 ///
787 /// [`LeafModelType::Linear`] is recommended for low-depth configurations
788 /// (depth 2--4) where per-leaf linear models reduce approximation error.
789 ///
790 /// [`LeafModelType::Adaptive`] automatically selects between closed-form and
791 /// a trainable model per leaf, using the Hoeffding bound for promotion.
792 pub fn leaf_model_type(mut self, lmt: LeafModelType) -> Self {
793 self.config.leaf_model_type = lmt;
794 self
795 }
796
797 /// Set the packed cache refresh interval for distributional models.
798 ///
799 /// When non-zero, [`DistributionalSGBT`](super::distributional::DistributionalSGBT)
800 /// maintains a packed f32 cache refreshed every `interval` training samples.
801 /// `0` (default) disables the cache.
802 pub fn packed_refresh_interval(mut self, interval: u64) -> Self {
803 self.config.packed_refresh_interval = interval;
804 self
805 }
806
807 /// Validate and build the configuration.
808 ///
809 /// # Errors
810 ///
811 /// Returns [`InvalidConfig`](crate::IrithyllError::InvalidConfig) with a structured
812 /// [`ConfigError`] if any parameter is out of its valid range.
813 pub fn build(self) -> Result<SGBTConfig> {
814 let c = &self.config;
815
816 // -- Ensemble-level parameters --
817 if c.n_steps == 0 {
818 return Err(ConfigError::out_of_range("n_steps", "must be > 0", c.n_steps).into());
819 }
820 if c.learning_rate <= 0.0 || c.learning_rate > 1.0 {
821 return Err(ConfigError::out_of_range(
822 "learning_rate",
823 "must be in (0, 1]",
824 c.learning_rate,
825 )
826 .into());
827 }
828 if c.feature_subsample_rate <= 0.0 || c.feature_subsample_rate > 1.0 {
829 return Err(ConfigError::out_of_range(
830 "feature_subsample_rate",
831 "must be in (0, 1]",
832 c.feature_subsample_rate,
833 )
834 .into());
835 }
836
837 // -- Tree-level parameters --
838 if c.max_depth == 0 {
839 return Err(ConfigError::out_of_range("max_depth", "must be > 0", c.max_depth).into());
840 }
841 if c.n_bins < 2 {
842 return Err(ConfigError::out_of_range("n_bins", "must be >= 2", c.n_bins).into());
843 }
844 if c.lambda < 0.0 {
845 return Err(ConfigError::out_of_range("lambda", "must be >= 0", c.lambda).into());
846 }
847 if c.gamma < 0.0 {
848 return Err(ConfigError::out_of_range("gamma", "must be >= 0", c.gamma).into());
849 }
850 if c.grace_period == 0 {
851 return Err(
852 ConfigError::out_of_range("grace_period", "must be > 0", c.grace_period).into(),
853 );
854 }
855 if c.delta <= 0.0 || c.delta >= 1.0 {
856 return Err(ConfigError::out_of_range("delta", "must be in (0, 1)", c.delta).into());
857 }
858
859 if c.initial_target_count == 0 {
860 return Err(ConfigError::out_of_range(
861 "initial_target_count",
862 "must be > 0",
863 c.initial_target_count,
864 )
865 .into());
866 }
867
868 // -- Streaming adaptation parameters --
869 if let Some(hl) = c.leaf_half_life {
870 if hl == 0 {
871 return Err(ConfigError::out_of_range("leaf_half_life", "must be >= 1", hl).into());
872 }
873 }
874 if let Some(max) = c.max_tree_samples {
875 if max < 100 {
876 return Err(
877 ConfigError::out_of_range("max_tree_samples", "must be >= 100", max).into(),
878 );
879 }
880 }
881 if let Some(interval) = c.split_reeval_interval {
882 if interval < c.grace_period {
883 return Err(ConfigError::invalid(
884 "split_reeval_interval",
885 format!(
886 "must be >= grace_period ({}), got {}",
887 c.grace_period, interval
888 ),
889 )
890 .into());
891 }
892 }
893
894 // -- Feature names --
895 if let Some(ref names) = c.feature_names {
896 // O(n^2) duplicate check — names list is small so no HashSet needed
897 for (i, name) in names.iter().enumerate() {
898 for prev in &names[..i] {
899 if name == prev {
900 return Err(ConfigError::invalid(
901 "feature_names",
902 format!("duplicate feature name: '{}'", name),
903 )
904 .into());
905 }
906 }
907 }
908 }
909
910 // -- Feature types --
911 if let Some(ref types) = c.feature_types {
912 if let Some(ref names) = c.feature_names {
913 if !names.is_empty() && !types.is_empty() && names.len() != types.len() {
914 return Err(ConfigError::invalid(
915 "feature_types",
916 format!(
917 "length ({}) must match feature_names length ({})",
918 types.len(),
919 names.len()
920 ),
921 )
922 .into());
923 }
924 }
925 }
926
927 // -- Gradient clipping --
928 if let Some(sigma) = c.gradient_clip_sigma {
929 if sigma <= 0.0 {
930 return Err(
931 ConfigError::out_of_range("gradient_clip_sigma", "must be > 0", sigma).into(),
932 );
933 }
934 }
935
936 // -- Monotonic constraints --
937 if let Some(ref mc) = c.monotone_constraints {
938 for (i, &v) in mc.iter().enumerate() {
939 if v != -1 && v != 0 && v != 1 {
940 return Err(ConfigError::invalid(
941 "monotone_constraints",
942 format!("feature {}: must be -1, 0, or +1, got {}", i, v),
943 )
944 .into());
945 }
946 }
947 }
948
949 // -- Leaf output clamping --
950 if let Some(max) = c.max_leaf_output {
951 if max <= 0.0 {
952 return Err(
953 ConfigError::out_of_range("max_leaf_output", "must be > 0", max).into(),
954 );
955 }
956 }
957
958 // -- Per-leaf adaptive output bound --
959 if let Some(k) = c.adaptive_leaf_bound {
960 if k <= 0.0 {
961 return Err(
962 ConfigError::out_of_range("adaptive_leaf_bound", "must be > 0", k).into(),
963 );
964 }
965 }
966
967 // -- Minimum hessian sum --
968 if let Some(min_h) = c.min_hessian_sum {
969 if min_h <= 0.0 {
970 return Err(
971 ConfigError::out_of_range("min_hessian_sum", "must be > 0", min_h).into(),
972 );
973 }
974 }
975
976 // -- Huber loss multiplier --
977 if let Some(k) = c.huber_k {
978 if k <= 0.0 {
979 return Err(ConfigError::out_of_range("huber_k", "must be > 0", k).into());
980 }
981 }
982
983 // -- Shadow warmup --
984 if let Some(warmup) = c.shadow_warmup {
985 if warmup == 0 {
986 return Err(ConfigError::out_of_range(
987 "shadow_warmup",
988 "must be > 0",
989 warmup as f64,
990 )
991 .into());
992 }
993 }
994
995 // -- Quality-based pruning parameters --
996 if let Some(alpha) = c.quality_prune_alpha {
997 if alpha <= 0.0 || alpha >= 1.0 {
998 return Err(ConfigError::out_of_range(
999 "quality_prune_alpha",
1000 "must be in (0, 1)",
1001 alpha,
1002 )
1003 .into());
1004 }
1005 }
1006 if c.quality_prune_threshold <= 0.0 {
1007 return Err(ConfigError::out_of_range(
1008 "quality_prune_threshold",
1009 "must be > 0",
1010 c.quality_prune_threshold,
1011 )
1012 .into());
1013 }
1014 if c.quality_prune_patience == 0 {
1015 return Err(ConfigError::out_of_range(
1016 "quality_prune_patience",
1017 "must be > 0",
1018 c.quality_prune_patience,
1019 )
1020 .into());
1021 }
1022
1023 // -- Error-weighted sample importance --
1024 if let Some(alpha) = c.error_weight_alpha {
1025 if alpha <= 0.0 || alpha >= 1.0 {
1026 return Err(ConfigError::out_of_range(
1027 "error_weight_alpha",
1028 "must be in (0, 1)",
1029 alpha,
1030 )
1031 .into());
1032 }
1033 }
1034
1035 // -- Drift detector parameters --
1036 match &c.drift_detector {
1037 DriftDetectorType::PageHinkley { delta, lambda } => {
1038 if *delta <= 0.0 {
1039 return Err(ConfigError::out_of_range(
1040 "drift_detector.PageHinkley.delta",
1041 "must be > 0",
1042 delta,
1043 )
1044 .into());
1045 }
1046 if *lambda <= 0.0 {
1047 return Err(ConfigError::out_of_range(
1048 "drift_detector.PageHinkley.lambda",
1049 "must be > 0",
1050 lambda,
1051 )
1052 .into());
1053 }
1054 }
1055 DriftDetectorType::Adwin { delta } => {
1056 if *delta <= 0.0 || *delta >= 1.0 {
1057 return Err(ConfigError::out_of_range(
1058 "drift_detector.Adwin.delta",
1059 "must be in (0, 1)",
1060 delta,
1061 )
1062 .into());
1063 }
1064 }
1065 DriftDetectorType::Ddm {
1066 warning_level,
1067 drift_level,
1068 min_instances,
1069 } => {
1070 if *warning_level <= 0.0 {
1071 return Err(ConfigError::out_of_range(
1072 "drift_detector.Ddm.warning_level",
1073 "must be > 0",
1074 warning_level,
1075 )
1076 .into());
1077 }
1078 if *drift_level <= 0.0 {
1079 return Err(ConfigError::out_of_range(
1080 "drift_detector.Ddm.drift_level",
1081 "must be > 0",
1082 drift_level,
1083 )
1084 .into());
1085 }
1086 if *drift_level <= *warning_level {
1087 return Err(ConfigError::invalid(
1088 "drift_detector.Ddm.drift_level",
1089 format!(
1090 "must be > warning_level ({}), got {}",
1091 warning_level, drift_level
1092 ),
1093 )
1094 .into());
1095 }
1096 if *min_instances == 0 {
1097 return Err(ConfigError::out_of_range(
1098 "drift_detector.Ddm.min_instances",
1099 "must be > 0",
1100 min_instances,
1101 )
1102 .into());
1103 }
1104 }
1105 }
1106
1107 // -- Variant parameters --
1108 match &c.variant {
1109 SGBTVariant::Standard => {} // no extra validation
1110 SGBTVariant::Skip { k } => {
1111 if *k == 0 {
1112 return Err(
1113 ConfigError::out_of_range("variant.Skip.k", "must be > 0", k).into(),
1114 );
1115 }
1116 }
1117 SGBTVariant::MultipleIterations { multiplier } => {
1118 if *multiplier <= 0.0 {
1119 return Err(ConfigError::out_of_range(
1120 "variant.MultipleIterations.multiplier",
1121 "must be > 0",
1122 multiplier,
1123 )
1124 .into());
1125 }
1126 }
1127 }
1128
1129 Ok(self.config)
1130 }
1131}
1132
1133// ---------------------------------------------------------------------------
1134// Display impls
1135// ---------------------------------------------------------------------------
1136
1137impl core::fmt::Display for DriftDetectorType {
1138 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1139 match self {
1140 Self::PageHinkley { delta, lambda } => {
1141 write!(f, "PageHinkley(delta={}, lambda={})", delta, lambda)
1142 }
1143 Self::Adwin { delta } => write!(f, "Adwin(delta={})", delta),
1144 Self::Ddm {
1145 warning_level,
1146 drift_level,
1147 min_instances,
1148 } => write!(
1149 f,
1150 "Ddm(warning={}, drift={}, min_instances={})",
1151 warning_level, drift_level, min_instances
1152 ),
1153 }
1154 }
1155}
1156
1157impl core::fmt::Display for SGBTConfig {
1158 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1159 write!(
1160 f,
1161 "SGBTConfig {{ steps={}, lr={}, depth={}, bins={}, variant={}, drift={} }}",
1162 self.n_steps,
1163 self.learning_rate,
1164 self.max_depth,
1165 self.n_bins,
1166 self.variant,
1167 self.drift_detector,
1168 )
1169 }
1170}
1171
1172// ---------------------------------------------------------------------------
1173// Tests
1174// ---------------------------------------------------------------------------
1175
// Unit tests: builder defaults, happy-path chaining, and at least one
// rejection test per validation rule in `build()`, plus functional checks
// of the detectors produced by `DriftDetectorType::create`.
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::format;
    use alloc::vec;

    // ------------------------------------------------------------------
    // 1. Default config values are correct
    // ------------------------------------------------------------------
    #[test]
    fn default_config_values() {
        // Pins the documented defaults; a change here must be intentional
        // and reflected in the rustdoc.
        let cfg = SGBTConfig::default();
        assert_eq!(cfg.n_steps, 100);
        assert!((cfg.learning_rate - 0.0125).abs() < f64::EPSILON);
        assert!((cfg.feature_subsample_rate - 0.75).abs() < f64::EPSILON);
        assert_eq!(cfg.max_depth, 6);
        assert_eq!(cfg.n_bins, 64);
        assert!((cfg.lambda - 1.0).abs() < f64::EPSILON);
        assert!((cfg.gamma - 0.0).abs() < f64::EPSILON);
        assert_eq!(cfg.grace_period, 200);
        assert!((cfg.delta - 1e-7).abs() < f64::EPSILON);
        assert_eq!(cfg.variant, SGBTVariant::Standard);
    }

    // ------------------------------------------------------------------
    // 2. Builder chain works
    // ------------------------------------------------------------------
    #[test]
    fn builder_chain() {
        // Every setter should round-trip its value through build().
        let cfg = SGBTConfig::builder()
            .n_steps(50)
            .learning_rate(0.1)
            .feature_subsample_rate(0.5)
            .max_depth(10)
            .n_bins(128)
            .lambda(0.5)
            .gamma(0.1)
            .grace_period(500)
            .delta(1e-3)
            .drift_detector(DriftDetectorType::Adwin { delta: 0.01 })
            .variant(SGBTVariant::Skip { k: 5 })
            .build()
            .expect("valid config");

        assert_eq!(cfg.n_steps, 50);
        assert!((cfg.learning_rate - 0.1).abs() < f64::EPSILON);
        assert!((cfg.feature_subsample_rate - 0.5).abs() < f64::EPSILON);
        assert_eq!(cfg.max_depth, 10);
        assert_eq!(cfg.n_bins, 128);
        assert!((cfg.lambda - 0.5).abs() < f64::EPSILON);
        assert!((cfg.gamma - 0.1).abs() < f64::EPSILON);
        assert_eq!(cfg.grace_period, 500);
        assert!((cfg.delta - 1e-3).abs() < f64::EPSILON);
        assert_eq!(cfg.variant, SGBTVariant::Skip { k: 5 });

        // Verify drift detector type is Adwin.
        match &cfg.drift_detector {
            DriftDetectorType::Adwin { delta } => {
                assert!((delta - 0.01).abs() < f64::EPSILON);
            }
            _ => panic!("expected Adwin drift detector"),
        }
    }

    // ------------------------------------------------------------------
    // 3. Validation rejects invalid values
    // ------------------------------------------------------------------
    #[test]
    fn validation_rejects_n_steps_zero() {
        let result = SGBTConfig::builder().n_steps(0).build();
        assert!(result.is_err());
        // Error message should name the offending parameter.
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("n_steps"));
    }

    #[test]
    fn validation_rejects_learning_rate_zero() {
        let result = SGBTConfig::builder().learning_rate(0.0).build();
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("learning_rate"));
    }

    #[test]
    fn validation_rejects_learning_rate_above_one() {
        let result = SGBTConfig::builder().learning_rate(1.5).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_accepts_learning_rate_one() {
        // Upper bound of the (0, 1] range is inclusive.
        let result = SGBTConfig::builder().learning_rate(1.0).build();
        assert!(result.is_ok());
    }

    #[test]
    fn validation_rejects_negative_learning_rate() {
        let result = SGBTConfig::builder().learning_rate(-0.1).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_feature_subsample_zero() {
        let result = SGBTConfig::builder().feature_subsample_rate(0.0).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_feature_subsample_above_one() {
        let result = SGBTConfig::builder().feature_subsample_rate(1.01).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_max_depth_zero() {
        let result = SGBTConfig::builder().max_depth(0).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_n_bins_one() {
        let result = SGBTConfig::builder().n_bins(1).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_negative_lambda() {
        let result = SGBTConfig::builder().lambda(-0.1).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_accepts_zero_lambda() {
        // lambda range is [0, inf): zero regularization is legal.
        let result = SGBTConfig::builder().lambda(0.0).build();
        assert!(result.is_ok());
    }

    #[test]
    fn validation_rejects_negative_gamma() {
        let result = SGBTConfig::builder().gamma(-0.1).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_grace_period_zero() {
        let result = SGBTConfig::builder().grace_period(0).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_delta_zero() {
        let result = SGBTConfig::builder().delta(0.0).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_delta_one() {
        // delta range (0, 1) is open at both ends.
        let result = SGBTConfig::builder().delta(1.0).build();
        assert!(result.is_err());
    }

    // ------------------------------------------------------------------
    // 3b. Drift detector parameter validation
    // ------------------------------------------------------------------
    #[test]
    fn validation_rejects_pht_negative_delta() {
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::PageHinkley {
                delta: -1.0,
                lambda: 50.0,
            })
            .build();
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("PageHinkley"));
    }

    #[test]
    fn validation_rejects_pht_zero_lambda() {
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::PageHinkley {
                delta: 0.005,
                lambda: 0.0,
            })
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_adwin_delta_out_of_range() {
        // Both open-interval endpoints of (0, 1) must be rejected.
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::Adwin { delta: 0.0 })
            .build();
        assert!(result.is_err());

        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::Adwin { delta: 1.0 })
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_ddm_warning_above_drift() {
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::Ddm {
                warning_level: 3.0,
                drift_level: 2.0,
                min_instances: 30,
            })
            .build();
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("drift_level"));
        assert!(msg.contains("must be > warning_level"));
    }

    #[test]
    fn validation_rejects_ddm_equal_levels() {
        // drift_level must be strictly greater than warning_level.
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::Ddm {
                warning_level: 2.0,
                drift_level: 2.0,
                min_instances: 30,
            })
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_ddm_zero_min_instances() {
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::Ddm {
                warning_level: 2.0,
                drift_level: 3.0,
                min_instances: 0,
            })
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_ddm_zero_warning_level() {
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::Ddm {
                warning_level: 0.0,
                drift_level: 3.0,
                min_instances: 30,
            })
            .build();
        assert!(result.is_err());
    }

    // ------------------------------------------------------------------
    // 3c. Variant parameter validation
    // ------------------------------------------------------------------
    #[test]
    fn validation_rejects_skip_k_zero() {
        let result = SGBTConfig::builder()
            .variant(SGBTVariant::Skip { k: 0 })
            .build();
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("Skip"));
    }

    #[test]
    fn validation_rejects_mi_zero_multiplier() {
        let result = SGBTConfig::builder()
            .variant(SGBTVariant::MultipleIterations { multiplier: 0.0 })
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_mi_negative_multiplier() {
        let result = SGBTConfig::builder()
            .variant(SGBTVariant::MultipleIterations { multiplier: -1.0 })
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_accepts_standard_variant() {
        let result = SGBTConfig::builder().variant(SGBTVariant::Standard).build();
        assert!(result.is_ok());
    }

    // ------------------------------------------------------------------
    // 4. DriftDetectorType creates correct detector types
    // ------------------------------------------------------------------
    #[test]
    fn drift_detector_type_creates_page_hinkley() {
        let dt = DriftDetectorType::PageHinkley {
            delta: 0.01,
            lambda: 100.0,
        };
        let mut detector = dt.create();

        // Should start with zero mean.
        assert_eq!(detector.estimated_mean(), 0.0);

        // Feed a stable value -- should not drift.
        let signal = detector.update(1.0);
        assert_ne!(signal, crate::drift::DriftSignal::Drift);
    }

    #[test]
    fn drift_detector_type_creates_adwin() {
        let dt = DriftDetectorType::Adwin { delta: 0.05 };
        let mut detector = dt.create();

        assert_eq!(detector.estimated_mean(), 0.0);
        let signal = detector.update(1.0);
        assert_ne!(signal, crate::drift::DriftSignal::Drift);
    }

    #[test]
    fn drift_detector_type_creates_ddm() {
        let dt = DriftDetectorType::Ddm {
            warning_level: 2.0,
            drift_level: 3.0,
            min_instances: 30,
        };
        let mut detector = dt.create();

        assert_eq!(detector.estimated_mean(), 0.0);
        let signal = detector.update(0.1);
        assert_eq!(signal, crate::drift::DriftSignal::Stable);
    }

    // ------------------------------------------------------------------
    // 5. Default drift detector is PageHinkley
    // ------------------------------------------------------------------
    #[test]
    fn default_drift_detector_is_page_hinkley() {
        let dt = DriftDetectorType::default();
        match dt {
            DriftDetectorType::PageHinkley { delta, lambda } => {
                assert!((delta - 0.005).abs() < f64::EPSILON);
                assert!((lambda - 50.0).abs() < f64::EPSILON);
            }
            _ => panic!("expected default DriftDetectorType to be PageHinkley"),
        }
    }

    // ------------------------------------------------------------------
    // 6. Default config builds without error
    // ------------------------------------------------------------------
    #[test]
    fn default_config_builds() {
        // The no-setter builder must always produce a valid config.
        let result = SGBTConfig::builder().build();
        assert!(result.is_ok());
    }

    // ------------------------------------------------------------------
    // 7. Config clone preserves all fields
    // ------------------------------------------------------------------
    #[test]
    fn config_clone_preserves_fields() {
        let original = SGBTConfig::builder()
            .n_steps(50)
            .learning_rate(0.05)
            .variant(SGBTVariant::MultipleIterations { multiplier: 5.0 })
            .build()
            .unwrap();

        let cloned = original.clone();

        assert_eq!(cloned.n_steps, 50);
        assert!((cloned.learning_rate - 0.05).abs() < f64::EPSILON);
        assert_eq!(
            cloned.variant,
            SGBTVariant::MultipleIterations { multiplier: 5.0 }
        );
    }

    // ------------------------------------------------------------------
    // 8. All three DDM valid configs accepted
    // ------------------------------------------------------------------
    #[test]
    fn valid_ddm_config_accepted() {
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::Ddm {
                warning_level: 1.5,
                drift_level: 2.5,
                min_instances: 10,
            })
            .build();
        assert!(result.is_ok());
    }

    // ------------------------------------------------------------------
    // 9. Created detectors are functional (round-trip test)
    // ------------------------------------------------------------------
    #[test]
    fn created_detectors_are_functional() {
        // Create a detector, feed it data, verify it can detect drift.
        let dt = DriftDetectorType::PageHinkley {
            delta: 0.005,
            lambda: 50.0,
        };
        let mut detector = dt.create();

        // Feed 500 stable values.
        for _ in 0..500 {
            detector.update(1.0);
        }

        // Feed shifted values -- should eventually drift.
        let mut drifted = false;
        for _ in 0..500 {
            if detector.update(10.0) == crate::drift::DriftSignal::Drift {
                drifted = true;
                break;
            }
        }
        assert!(
            drifted,
            "detector created from DriftDetectorType should be functional"
        );
    }

    // ------------------------------------------------------------------
    // 10. Boundary acceptance: n_bins=2 is the minimum valid
    // ------------------------------------------------------------------
    #[test]
    fn boundary_n_bins_two_accepted() {
        let result = SGBTConfig::builder().n_bins(2).build();
        assert!(result.is_ok());
    }

    // ------------------------------------------------------------------
    // 11. Boundary acceptance: grace_period=1 is valid
    // ------------------------------------------------------------------
    #[test]
    fn boundary_grace_period_one_accepted() {
        let result = SGBTConfig::builder().grace_period(1).build();
        assert!(result.is_ok());
    }

    // ------------------------------------------------------------------
    // 12. Feature names -- valid config with names
    // ------------------------------------------------------------------
    #[test]
    fn feature_names_accepted() {
        let cfg = SGBTConfig::builder()
            .feature_names(vec!["price".into(), "volume".into(), "spread".into()])
            .build()
            .unwrap();
        assert_eq!(
            cfg.feature_names.as_ref().unwrap(),
            &["price", "volume", "spread"]
        );
    }

    // ------------------------------------------------------------------
    // 13. Feature names -- duplicate names rejected
    // ------------------------------------------------------------------
    #[test]
    fn feature_names_rejects_duplicates() {
        let result = SGBTConfig::builder()
            .feature_names(vec!["price".into(), "volume".into(), "price".into()])
            .build();
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("duplicate"));
    }

    // ------------------------------------------------------------------
    // 14. Feature names -- serde backward compat (missing field)
    //     Requires serde_json which is only in the full irithyll crate.
    // ------------------------------------------------------------------

    // ------------------------------------------------------------------
    // 15. Feature names -- empty vec is valid
    // ------------------------------------------------------------------
    #[test]
    fn feature_names_empty_vec_accepted() {
        // An empty list means "no names declared", which build() accepts.
        let cfg = SGBTConfig::builder().feature_names(vec![]).build().unwrap();
        assert!(cfg.feature_names.unwrap().is_empty());
    }

    // ------------------------------------------------------------------
    // 16. Adaptive leaf bound -- builder
    // ------------------------------------------------------------------
    #[test]
    fn builder_adaptive_leaf_bound() {
        let cfg = SGBTConfig::builder()
            .adaptive_leaf_bound(3.0)
            .build()
            .unwrap();
        assert_eq!(cfg.adaptive_leaf_bound, Some(3.0));
    }

    // ------------------------------------------------------------------
    // 17. Adaptive leaf bound -- validation rejects zero
    // ------------------------------------------------------------------
    #[test]
    fn validation_rejects_zero_adaptive_leaf_bound() {
        let result = SGBTConfig::builder().adaptive_leaf_bound(0.0).build();
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(
            msg.contains("adaptive_leaf_bound"),
            "error should mention adaptive_leaf_bound: {}",
            msg,
        );
    }

    // ------------------------------------------------------------------
    // 18. Adaptive leaf bound -- serde backward compat
    //     Requires serde_json which is only in the full irithyll crate.
    // ------------------------------------------------------------------
}