irithyll_core/ensemble/config.rs
1//! SGBT configuration with builder pattern and full validation.
2//!
3//! [`SGBTConfig`] holds all hyperparameters for the Streaming Gradient Boosted
4//! Trees ensemble. Use [`SGBTConfig::builder`] for ergonomic construction with
5//! validation on [`build()`](SGBTConfigBuilder::build).
6
7use alloc::boxed::Box;
8use alloc::format;
9use alloc::string::String;
10use alloc::vec::Vec;
11
12use crate::drift::adwin::Adwin;
13use crate::drift::ddm::Ddm;
14use crate::drift::pht::PageHinkleyTest;
15use crate::drift::DriftDetector;
16use crate::ensemble::variants::SGBTVariant;
17use crate::error::{ConfigError, Result};
18use crate::tree::leaf_model::LeafModelType;
19
20// ---------------------------------------------------------------------------
21// FeatureType -- re-exported from irithyll-core
22// ---------------------------------------------------------------------------
23
24pub use crate::feature::FeatureType;
25
26// ---------------------------------------------------------------------------
27// ScaleMode
28// ---------------------------------------------------------------------------
29
/// How [`DistributionalSGBT`](super::distributional::DistributionalSGBT)
/// estimates uncertainty (σ).
///
/// - **`Empirical`** (default): tracks an EWMA of squared prediction errors.
///   `σ = sqrt(ewma_sq_err)`. Always calibrated, zero tuning, O(1) compute.
///   Use this when σ drives learning-rate modulation (σ high → learn faster).
///
/// - **`TreeChain`**: trains a full second ensemble of Hoeffding trees to predict
///   log(σ) from features (NGBoost-style dual chain). Gives *feature-conditional*
///   uncertainty but requires strong signal in the scale gradients.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ScaleMode {
    /// EWMA of squared prediction errors — always calibrated, O(1) per sample.
    #[default]
    Empirical,
    /// Full Hoeffding-tree ensemble for feature-conditional log(σ) prediction.
    TreeChain,
}
49
50// ---------------------------------------------------------------------------
51// DriftDetectorType
52// ---------------------------------------------------------------------------
53
/// Which drift detector to instantiate for each boosting step.
///
/// Each variant stores the detector's configuration parameters so that fresh
/// instances can be created on demand (e.g. when replacing a drifted tree);
/// see [`DriftDetectorType::create`]. The per-field defaults quoted below are
/// the values produced by this type's `Default` impl (Page-Hinkley).
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum DriftDetectorType {
    /// Page-Hinkley Test with custom delta (magnitude tolerance) and lambda
    /// (detection threshold).
    PageHinkley {
        /// Magnitude tolerance. Default 0.005.
        delta: f64,
        /// Detection threshold. Default 50.0.
        lambda: f64,
    },

    /// ADWIN with custom confidence parameter.
    Adwin {
        /// Confidence (smaller = fewer false positives). Default 0.002.
        delta: f64,
    },

    /// DDM with custom warning/drift levels and minimum warmup instances.
    Ddm {
        /// Warning threshold multiplier. Default 2.0.
        warning_level: f64,
        /// Drift threshold multiplier. Default 3.0.
        drift_level: f64,
        /// Minimum observations before detection activates. Default 30.
        min_instances: u64,
    },
}
86
87impl Default for DriftDetectorType {
88 fn default() -> Self {
89 DriftDetectorType::PageHinkley {
90 delta: 0.005,
91 lambda: 50.0,
92 }
93 }
94}
95
96impl DriftDetectorType {
97 /// Create a new, fresh drift detector from this configuration.
98 pub fn create(&self) -> Box<dyn DriftDetector> {
99 match self {
100 Self::PageHinkley { delta, lambda } => {
101 Box::new(PageHinkleyTest::with_params(*delta, *lambda))
102 }
103 Self::Adwin { delta } => Box::new(Adwin::with_delta(*delta)),
104 Self::Ddm {
105 warning_level,
106 drift_level,
107 min_instances,
108 } => Box::new(Ddm::with_params(
109 *warning_level,
110 *drift_level,
111 *min_instances,
112 )),
113 }
114 }
115}
116
117// ---------------------------------------------------------------------------
118// SGBTConfig
119// ---------------------------------------------------------------------------
120
/// Configuration for the SGBT ensemble.
///
/// All numeric parameters are validated at build time via [`SGBTConfigBuilder`].
///
/// # Defaults
///
/// | Parameter                | Default              |
/// |--------------------------|----------------------|
/// | `n_steps`                | 100                  |
/// | `learning_rate`          | 0.0125               |
/// | `feature_subsample_rate` | 0.75                 |
/// | `max_depth`              | 6                    |
/// | `n_bins`                 | 64                   |
/// | `lambda`                 | 1.0                  |
/// | `gamma`                  | 0.0                  |
/// | `grace_period`           | 200                  |
/// | `delta`                  | 1e-7                 |
/// | `drift_detector`         | PageHinkley(0.005, 50.0) |
/// | `variant`                | Standard             |
/// | `seed`                   | 0xDEAD_BEEF_CAFE_4242 |
/// | `initial_target_count`   | 50                   |
/// | `leaf_half_life`         | None (disabled)      |
/// | `max_tree_samples`       | None (disabled)      |
/// | `split_reeval_interval`  | None (disabled)      |
///
/// All remaining optional streaming extensions (feature names/types, gradient
/// clipping, monotone constraints, quality pruning, error weighting, adaptive
/// bounds/depth, Huber loss, shadow trees, packed cache, …) default to
/// disabled (`None`, `false`, or `0`), except `quality_prune_threshold`
/// (1e-6), `quality_prune_patience` (500), and `empirical_sigma_alpha` (0.01).
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SGBTConfig {
    /// Number of boosting steps (trees in the ensemble). Default 100.
    pub n_steps: usize,
    /// Learning rate (shrinkage). Default 0.0125.
    pub learning_rate: f64,
    /// Fraction of features to subsample per tree. Default 0.75.
    pub feature_subsample_rate: f64,
    /// Maximum tree depth. Default 6.
    pub max_depth: usize,
    /// Number of histogram bins. Default 64.
    pub n_bins: usize,
    /// L2 regularization parameter (lambda). Default 1.0.
    pub lambda: f64,
    /// Minimum split gain (gamma). Default 0.0.
    pub gamma: f64,
    /// Grace period: min samples before evaluating splits. Default 200.
    pub grace_period: usize,
    /// Hoeffding bound confidence (delta). Default 1e-7.
    pub delta: f64,
    /// Drift detector type for tree replacement. Default: PageHinkley.
    pub drift_detector: DriftDetectorType,
    /// SGBT computational variant. Default: Standard.
    pub variant: SGBTVariant,
    /// Random seed for deterministic reproducibility. Default: 0xDEAD_BEEF_CAFE_4242.
    ///
    /// Controls feature subsampling and variant skip/MI stochastic decisions.
    /// Two models with the same seed and same data will produce identical results.
    pub seed: u64,
    /// Number of initial targets to collect before computing the base prediction.
    /// Default: 50.
    pub initial_target_count: usize,

    /// Half-life for exponential leaf decay (in samples per leaf).
    ///
    /// After `leaf_half_life` samples, a leaf's accumulated gradient/hessian
    /// statistics have half the weight of the most recent sample. This causes
    /// the model to continuously adapt to changing data distributions rather
    /// than freezing on early observations.
    ///
    /// `None` (default) disables decay -- traditional monotonic accumulation.
    #[cfg_attr(feature = "serde", serde(default))]
    pub leaf_half_life: Option<usize>,

    /// Maximum samples a single tree processes before proactive replacement.
    ///
    /// After this many samples, the tree is replaced with a fresh one regardless
    /// of drift detector state. Prevents stale tree structure from persisting
    /// when the drift detector is not sensitive enough.
    ///
    /// `None` (default) disables time-based replacement.
    #[cfg_attr(feature = "serde", serde(default))]
    pub max_tree_samples: Option<u64>,

    /// Interval (in samples per leaf) at which max-depth leaves re-evaluate
    /// whether a split would improve them.
    ///
    /// Inspired by EFDT (Manapragada et al. 2018). When a leaf has accumulated
    /// `split_reeval_interval` samples since its last evaluation and has reached
    /// max depth, it re-evaluates whether a split should be performed.
    ///
    /// `None` (default) disables re-evaluation -- max-depth leaves are permanent.
    #[cfg_attr(feature = "serde", serde(default))]
    pub split_reeval_interval: Option<usize>,

    /// Optional human-readable feature names.
    ///
    /// When set, enables `named_feature_importances` and
    /// `train_one_named` for production-friendly named access.
    /// Length must match the number of features in training data.
    #[cfg_attr(feature = "serde", serde(default))]
    pub feature_names: Option<Vec<String>>,

    /// Optional per-feature type declarations.
    ///
    /// When set, declares which features are categorical vs continuous.
    /// Categorical features use one-bin-per-category binning and Fisher
    /// optimal binary partitioning for split evaluation.
    /// Length must match the number of features in training data.
    ///
    /// `None` (default) treats all features as continuous.
    #[cfg_attr(feature = "serde", serde(default))]
    pub feature_types: Option<Vec<FeatureType>>,

    /// Gradient clipping threshold in standard deviations per leaf.
    ///
    /// When enabled, each leaf tracks an EWMA of gradient mean and variance.
    /// Incoming gradients that exceed `mean ± sigma * gradient_clip_sigma` are
    /// clamped to the boundary. This prevents outlier samples from corrupting
    /// leaf statistics, which is critical in streaming settings where sudden
    /// label floods can destabilize the model.
    ///
    /// Typical value: 3.0 (3-sigma clipping).
    /// `None` (default) disables gradient clipping.
    #[cfg_attr(feature = "serde", serde(default))]
    pub gradient_clip_sigma: Option<f64>,

    /// Per-feature monotonic constraints.
    ///
    /// Each element specifies the monotonic relationship between a feature and
    /// the prediction:
    /// - `+1`: prediction must be non-decreasing as feature value increases.
    /// - `-1`: prediction must be non-increasing as feature value increases.
    /// - `0`: no constraint (unconstrained).
    ///
    /// During split evaluation, candidate splits that would violate monotonicity
    /// (left child value > right child value for +1 constraints, or vice versa)
    /// are rejected.
    ///
    /// Length must match the number of features in training data.
    /// `None` (default) means no monotonic constraints.
    #[cfg_attr(feature = "serde", serde(default))]
    pub monotone_constraints: Option<Vec<i8>>,

    /// EWMA smoothing factor for quality-based tree pruning.
    ///
    /// When `Some(alpha)`, each boosting step tracks an exponentially weighted
    /// moving average of its marginal contribution to the ensemble. Trees whose
    /// contribution drops below [`quality_prune_threshold`](Self::quality_prune_threshold)
    /// for [`quality_prune_patience`](Self::quality_prune_patience) consecutive
    /// samples are replaced with a fresh tree that can learn the current regime.
    ///
    /// This prevents "dead wood" -- trees from a past regime that no longer
    /// contribute meaningfully to ensemble accuracy.
    ///
    /// `None` (default) disables quality-based pruning.
    /// Suggested value: 0.01.
    #[cfg_attr(feature = "serde", serde(default))]
    pub quality_prune_alpha: Option<f64>,

    /// Minimum contribution threshold for quality-based pruning.
    ///
    /// A tree's EWMA contribution must stay above this value to avoid being
    /// flagged as dead wood. Only used when `quality_prune_alpha` is `Some`.
    ///
    /// Default: 1e-6.
    #[cfg_attr(feature = "serde", serde(default = "default_quality_prune_threshold"))]
    pub quality_prune_threshold: f64,

    /// Consecutive low-contribution samples before a tree is replaced.
    ///
    /// After this many consecutive samples where a tree's EWMA contribution
    /// is below `quality_prune_threshold`, the tree is reset. Only used when
    /// `quality_prune_alpha` is `Some`.
    ///
    /// Default: 500.
    #[cfg_attr(feature = "serde", serde(default = "default_quality_prune_patience"))]
    pub quality_prune_patience: u64,

    /// EWMA smoothing factor for error-weighted sample importance.
    ///
    /// When `Some(alpha)`, samples the model predicted poorly get higher
    /// effective weight during histogram accumulation. The weight is:
    /// `1.0 + |error| / (rolling_mean_error + epsilon)`, capped at 10x.
    ///
    /// This is a streaming version of AdaBoost's reweighting applied at the
    /// gradient level -- learning capacity focuses on hard/novel patterns,
    /// enabling faster adaptation to regime changes.
    ///
    /// `None` (default) disables error weighting.
    /// Suggested value: 0.01.
    #[cfg_attr(feature = "serde", serde(default))]
    pub error_weight_alpha: Option<f64>,

    /// Enable σ-modulated learning rate for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
    ///
    /// When `true`, the **location** (μ) ensemble's learning rate is scaled by
    /// `sigma_ratio = current_sigma / rolling_sigma_mean`, where `rolling_sigma_mean`
    /// is an EWMA of the model's predicted σ (alpha = 0.001).
    ///
    /// This means the model learns μ **faster** when σ is elevated (high uncertainty)
    /// and **slower** when σ is low (confident regime). The scale (σ) ensemble always
    /// trains at the unmodulated base rate to prevent positive feedback loops.
    ///
    /// Default: `false`.
    #[cfg_attr(feature = "serde", serde(default))]
    pub uncertainty_modulated_lr: bool,

    /// How the scale (σ) is estimated in [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
    ///
    /// - [`Empirical`](ScaleMode::Empirical) (default): EWMA of squared prediction
    ///   errors. `σ = sqrt(ewma_sq_err)`. Always calibrated, zero tuning, O(1).
    /// - [`TreeChain`](ScaleMode::TreeChain): full dual-chain NGBoost with a
    ///   separate tree ensemble predicting log(σ) from features.
    ///
    /// For σ-modulated learning (`uncertainty_modulated_lr = true`), `Empirical`
    /// is strongly recommended — scale tree gradients are inherently weak and
    /// the trees often fail to split.
    #[cfg_attr(feature = "serde", serde(default))]
    pub scale_mode: ScaleMode,

    /// EWMA smoothing factor for empirical σ estimation.
    ///
    /// Controls the adaptation speed of `σ = sqrt(ewma_sq_err)` when
    /// [`scale_mode`](Self::scale_mode) is [`Empirical`](ScaleMode::Empirical).
    /// Higher values react faster to regime changes but are noisier.
    ///
    /// Default: `0.01` (~100-sample effective window).
    #[cfg_attr(feature = "serde", serde(default = "default_empirical_sigma_alpha"))]
    pub empirical_sigma_alpha: f64,

    /// Maximum absolute leaf output value.
    ///
    /// When `Some(max)`, leaf predictions are clamped to `[-max, max]`.
    /// Prevents runaway leaf weights from causing prediction explosions
    /// in feedback loops. `None` (default) means no clamping.
    #[cfg_attr(feature = "serde", serde(default))]
    pub max_leaf_output: Option<f64>,

    /// Per-leaf adaptive output bound (sigma multiplier).
    ///
    /// When `Some(k)`, each leaf tracks an EWMA of its own output weight and
    /// clamps predictions to `|output_mean| + k * output_std`. The EWMA uses
    /// `leaf_decay_alpha` when `leaf_half_life` is set, otherwise Welford online.
    ///
    /// This is strictly superior to `max_leaf_output` for streaming — the bound
    /// is per-leaf, self-calibrating, and regime-synchronized. A leaf that usually
    /// outputs 0.3 can't suddenly output 2.9 just because it fits in the global clamp.
    ///
    /// Typical value: 3.0 (3-sigma bound).
    /// `None` (default) disables adaptive bounds (falls back to `max_leaf_output`).
    #[cfg_attr(feature = "serde", serde(default))]
    pub adaptive_leaf_bound: Option<f64>,

    /// Per-split information criterion (Lunde-Kleppe-Skaug 2020).
    ///
    /// When `Some(cir_factor)`, replaces `max_depth` with a per-split
    /// generalization test. Each candidate split must have
    /// `gain > cir_factor * sigma^2_g / n * n_features`.
    /// `max_depth * 2` becomes a hard safety ceiling.
    ///
    /// Typical: 7.5 (<=10 features), 9.0 (<=50), 11.0 (<=200).
    /// `None` (default) uses static `max_depth` only.
    #[cfg_attr(feature = "serde", serde(default))]
    pub adaptive_depth: Option<f64>,

    /// Minimum hessian sum before a leaf produces non-zero output.
    ///
    /// When `Some(min_h)`, leaves with `hess_sum < min_h` return 0.0.
    /// Prevents post-replacement spikes from fresh leaves with insufficient
    /// samples. `None` (default) means all leaves contribute immediately.
    #[cfg_attr(feature = "serde", serde(default))]
    pub min_hessian_sum: Option<f64>,

    /// Huber loss delta multiplier for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
    ///
    /// When `Some(k)`, the distributional location gradient uses Huber loss
    /// with adaptive `delta = k * empirical_sigma`. This bounds gradients by
    /// construction. Standard value: `1.345` (95% efficiency at Gaussian).
    /// `None` (default) uses squared loss.
    #[cfg_attr(feature = "serde", serde(default))]
    pub huber_k: Option<f64>,

    /// Shadow warmup for graduated tree handoff.
    ///
    /// When `Some(n)`, an always-on shadow (alternate) tree is spawned immediately
    /// alongside every active tree. The shadow trains on the same gradient stream
    /// but does not contribute to predictions until it has seen `n` samples.
    ///
    /// As the active tree ages past 80% of `max_tree_samples`, its prediction
    /// weight linearly decays to 0 at 120%. The shadow's weight ramps from 0 to 1
    /// over `n` samples after warmup. When the active weight reaches 0, the shadow
    /// is promoted and a new shadow is spawned — no cold-start prediction dip.
    ///
    /// Requires `max_tree_samples` to be set for time-based graduated handoff.
    /// Drift-based replacement still uses hard swap (shadow is already warm).
    ///
    /// `None` (default) disables graduated handoff — uses traditional hard swap.
    #[cfg_attr(feature = "serde", serde(default))]
    pub shadow_warmup: Option<usize>,

    /// Leaf prediction model type.
    ///
    /// Controls how each leaf computes its prediction:
    /// - [`ClosedForm`](LeafModelType::ClosedForm) (default): constant leaf weight.
    /// - [`Linear`](LeafModelType::Linear): per-leaf online ridge regression with
    ///   AdaGrad optimization. Optional `decay` for concept drift. Recommended for
    ///   low-depth trees (depth 2--4).
    /// - [`MLP`](LeafModelType::MLP): per-leaf single-hidden-layer neural network.
    ///   Optional `decay` for concept drift.
    /// - [`Adaptive`](LeafModelType::Adaptive): starts as closed-form, auto-promotes
    ///   when the Hoeffding bound confirms a more complex model is better.
    ///
    /// Default: [`ClosedForm`](LeafModelType::ClosedForm).
    #[cfg_attr(feature = "serde", serde(default))]
    pub leaf_model_type: LeafModelType,

    /// Packed cache refresh interval for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
    ///
    /// When non-zero, the distributional model maintains a packed f32 cache of
    /// its location ensemble that is re-exported every `packed_refresh_interval`
    /// training samples. Predictions use the cache for O(1)-per-tree inference
    /// via contiguous memory traversal, falling back to full tree traversal when
    /// the cache is absent or produces non-finite results.
    ///
    /// `0` (default) disables the packed cache.
    #[cfg_attr(feature = "serde", serde(default))]
    pub packed_refresh_interval: u64,
}
445
/// Serde default for [`SGBTConfig::empirical_sigma_alpha`].
/// Must stay in sync with the value assigned in `SGBTConfig::default()`.
#[cfg(feature = "serde")]
fn default_empirical_sigma_alpha() -> f64 {
    0.01
}
450
/// Serde default for [`SGBTConfig::quality_prune_threshold`].
/// Must stay in sync with the value assigned in `SGBTConfig::default()`.
#[cfg(feature = "serde")]
fn default_quality_prune_threshold() -> f64 {
    1e-6
}
455
/// Serde default for [`SGBTConfig::quality_prune_patience`].
/// Must stay in sync with the value assigned in `SGBTConfig::default()`.
#[cfg(feature = "serde")]
fn default_quality_prune_patience() -> u64 {
    500
}
460
impl Default for SGBTConfig {
    /// Defaults matching the table in the [`SGBTConfig`] docs. All optional
    /// streaming extensions start disabled.
    fn default() -> Self {
        Self {
            // Core boosting hyperparameters.
            n_steps: 100,
            learning_rate: 0.0125,
            feature_subsample_rate: 0.75,
            // Tree growth parameters.
            max_depth: 6,
            n_bins: 64,
            lambda: 1.0,
            gamma: 0.0,
            grace_period: 200,
            delta: 1e-7,
            // Drift handling, variant selection, reproducibility.
            drift_detector: DriftDetectorType::default(),
            variant: SGBTVariant::default(),
            seed: 0xDEAD_BEEF_CAFE_4242,
            initial_target_count: 50,
            // Optional streaming extensions — all disabled by default.
            leaf_half_life: None,
            max_tree_samples: None,
            split_reeval_interval: None,
            feature_names: None,
            feature_types: None,
            gradient_clip_sigma: None,
            monotone_constraints: None,
            quality_prune_alpha: None,
            // These two only take effect when quality_prune_alpha is Some;
            // values mirror the serde default helpers above.
            quality_prune_threshold: 1e-6,
            quality_prune_patience: 500,
            error_weight_alpha: None,
            // Distributional (μ/σ) settings.
            uncertainty_modulated_lr: false,
            scale_mode: ScaleMode::default(),
            empirical_sigma_alpha: 0.01,
            // Leaf output safeguards and model choice.
            max_leaf_output: None,
            adaptive_leaf_bound: None,
            adaptive_depth: None,
            min_hessian_sum: None,
            huber_k: None,
            shadow_warmup: None,
            leaf_model_type: LeafModelType::default(),
            packed_refresh_interval: 0,
        }
    }
}
502
503impl SGBTConfig {
504 /// Start building a configuration via the builder pattern.
505 pub fn builder() -> SGBTConfigBuilder {
506 SGBTConfigBuilder::default()
507 }
508}
509
510// ---------------------------------------------------------------------------
511// SGBTConfigBuilder
512// ---------------------------------------------------------------------------
513
514/// Builder for [`SGBTConfig`] with validation on [`build()`](Self::build).
515///
516/// # Example
517///
518/// ```ignore
519/// use irithyll::ensemble::config::{SGBTConfig, DriftDetectorType};
520/// use irithyll::ensemble::variants::SGBTVariant;
521///
522/// let config = SGBTConfig::builder()
523/// .n_steps(200)
524/// .learning_rate(0.05)
525/// .drift_detector(DriftDetectorType::Adwin { delta: 0.01 })
526/// .variant(SGBTVariant::Skip { k: 10 })
527/// .build()
528/// .expect("valid config");
529/// ```
530#[derive(Debug, Clone, Default)]
531pub struct SGBTConfigBuilder {
532 config: SGBTConfig,
533}
534
535impl SGBTConfigBuilder {
536 /// Set the number of boosting steps (trees in the ensemble).
537 pub fn n_steps(mut self, n: usize) -> Self {
538 self.config.n_steps = n;
539 self
540 }
541
542 /// Set the learning rate (shrinkage factor).
543 pub fn learning_rate(mut self, lr: f64) -> Self {
544 self.config.learning_rate = lr;
545 self
546 }
547
548 /// Set the fraction of features to subsample per tree.
549 pub fn feature_subsample_rate(mut self, rate: f64) -> Self {
550 self.config.feature_subsample_rate = rate;
551 self
552 }
553
554 /// Set the maximum tree depth.
555 pub fn max_depth(mut self, depth: usize) -> Self {
556 self.config.max_depth = depth;
557 self
558 }
559
560 /// Set the number of histogram bins per feature.
561 pub fn n_bins(mut self, bins: usize) -> Self {
562 self.config.n_bins = bins;
563 self
564 }
565
566 /// Set the L2 regularization parameter (lambda).
567 pub fn lambda(mut self, l: f64) -> Self {
568 self.config.lambda = l;
569 self
570 }
571
572 /// Set the minimum split gain (gamma).
573 pub fn gamma(mut self, g: f64) -> Self {
574 self.config.gamma = g;
575 self
576 }
577
578 /// Set the grace period (minimum samples before evaluating splits).
579 pub fn grace_period(mut self, gp: usize) -> Self {
580 self.config.grace_period = gp;
581 self
582 }
583
584 /// Set the Hoeffding bound confidence parameter (delta).
585 pub fn delta(mut self, d: f64) -> Self {
586 self.config.delta = d;
587 self
588 }
589
590 /// Set the drift detector type for tree replacement.
591 pub fn drift_detector(mut self, dt: DriftDetectorType) -> Self {
592 self.config.drift_detector = dt;
593 self
594 }
595
596 /// Set the SGBT computational variant.
597 pub fn variant(mut self, v: SGBTVariant) -> Self {
598 self.config.variant = v;
599 self
600 }
601
602 /// Set the random seed for deterministic reproducibility.
603 ///
604 /// Controls feature subsampling and variant skip/MI stochastic decisions.
605 /// Two models with the same seed and data sequence will produce identical results.
606 pub fn seed(mut self, seed: u64) -> Self {
607 self.config.seed = seed;
608 self
609 }
610
611 /// Set the number of initial targets to collect before computing the base prediction.
612 ///
613 /// The model collects this many target values before initializing the base
614 /// prediction (via `loss.initial_prediction`). Default: 50.
615 pub fn initial_target_count(mut self, count: usize) -> Self {
616 self.config.initial_target_count = count;
617 self
618 }
619
620 /// Set the half-life for exponential leaf decay (in samples per leaf).
621 ///
622 /// After `n` samples, a leaf's accumulated statistics have half the weight
623 /// of the most recent sample. Enables continuous adaptation to concept drift.
624 pub fn leaf_half_life(mut self, n: usize) -> Self {
625 self.config.leaf_half_life = Some(n);
626 self
627 }
628
629 /// Set the maximum samples a single tree processes before proactive replacement.
630 ///
631 /// After `n` samples, the tree is replaced regardless of drift detector state.
632 pub fn max_tree_samples(mut self, n: u64) -> Self {
633 self.config.max_tree_samples = Some(n);
634 self
635 }
636
637 /// Set the split re-evaluation interval for max-depth leaves.
638 ///
639 /// Every `n` samples per leaf, max-depth leaves re-evaluate whether a split
640 /// would improve them. Inspired by EFDT (Manapragada et al. 2018).
641 pub fn split_reeval_interval(mut self, n: usize) -> Self {
642 self.config.split_reeval_interval = Some(n);
643 self
644 }
645
646 /// Set human-readable feature names.
647 ///
648 /// Enables named feature importances and named training input.
649 /// Names must be unique; validated at [`build()`](Self::build).
650 pub fn feature_names(mut self, names: Vec<String>) -> Self {
651 self.config.feature_names = Some(names);
652 self
653 }
654
655 /// Set per-feature type declarations.
656 ///
657 /// Declares which features are categorical vs continuous. Categorical features
658 /// use one-bin-per-category binning and Fisher optimal binary partitioning.
659 /// Supports up to 64 distinct category values per categorical feature.
660 pub fn feature_types(mut self, types: Vec<FeatureType>) -> Self {
661 self.config.feature_types = Some(types);
662 self
663 }
664
665 /// Set per-leaf gradient clipping threshold (in standard deviations).
666 ///
667 /// Each leaf tracks an EWMA of gradient mean and variance. Gradients
668 /// exceeding `mean ± sigma * n` are clamped. Prevents outlier labels
669 /// from corrupting streaming model stability.
670 ///
671 /// Typical value: 3.0 (3-sigma clipping).
672 pub fn gradient_clip_sigma(mut self, sigma: f64) -> Self {
673 self.config.gradient_clip_sigma = Some(sigma);
674 self
675 }
676
677 /// Set per-feature monotonic constraints.
678 ///
679 /// `+1` = non-decreasing, `-1` = non-increasing, `0` = unconstrained.
680 /// Candidate splits violating monotonicity are rejected during tree growth.
681 pub fn monotone_constraints(mut self, constraints: Vec<i8>) -> Self {
682 self.config.monotone_constraints = Some(constraints);
683 self
684 }
685
686 /// Enable quality-based tree pruning with the given EWMA smoothing factor.
687 ///
688 /// Trees whose marginal contribution drops below the threshold for
689 /// `patience` consecutive samples are replaced with fresh trees.
690 /// Suggested alpha: 0.01.
691 pub fn quality_prune_alpha(mut self, alpha: f64) -> Self {
692 self.config.quality_prune_alpha = Some(alpha);
693 self
694 }
695
696 /// Set the minimum contribution threshold for quality-based pruning.
697 ///
698 /// Default: 1e-6. Only relevant when `quality_prune_alpha` is set.
699 pub fn quality_prune_threshold(mut self, threshold: f64) -> Self {
700 self.config.quality_prune_threshold = threshold;
701 self
702 }
703
704 /// Set the patience (consecutive low-contribution samples) before pruning.
705 ///
706 /// Default: 500. Only relevant when `quality_prune_alpha` is set.
707 pub fn quality_prune_patience(mut self, patience: u64) -> Self {
708 self.config.quality_prune_patience = patience;
709 self
710 }
711
712 /// Enable error-weighted sample importance with the given EWMA smoothing factor.
713 ///
714 /// Samples the model predicted poorly get higher effective weight.
715 /// Suggested alpha: 0.01.
716 pub fn error_weight_alpha(mut self, alpha: f64) -> Self {
717 self.config.error_weight_alpha = Some(alpha);
718 self
719 }
720
721 /// Enable σ-modulated learning rate for distributional models.
722 ///
723 /// Scales the location (μ) learning rate by `current_sigma / rolling_sigma_mean`,
724 /// so the model adapts faster during high-uncertainty regimes and conserves
725 /// during stable periods. Only affects [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
726 ///
727 /// By default uses empirical σ (EWMA of squared errors). Set
728 /// [`scale_mode(ScaleMode::TreeChain)`](Self::scale_mode) for feature-conditional σ.
729 pub fn uncertainty_modulated_lr(mut self, enabled: bool) -> Self {
730 self.config.uncertainty_modulated_lr = enabled;
731 self
732 }
733
734 /// Set the scale estimation mode for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
735 ///
736 /// - [`Empirical`](ScaleMode::Empirical): EWMA of squared prediction errors (default, recommended).
737 /// - [`TreeChain`](ScaleMode::TreeChain): dual-chain NGBoost with scale tree ensemble.
738 pub fn scale_mode(mut self, mode: ScaleMode) -> Self {
739 self.config.scale_mode = mode;
740 self
741 }
742
743 /// EWMA alpha for empirical σ. Controls adaptation speed. Default `0.01`.
744 ///
745 /// Only used when `scale_mode` is [`Empirical`](ScaleMode::Empirical).
746 pub fn empirical_sigma_alpha(mut self, alpha: f64) -> Self {
747 self.config.empirical_sigma_alpha = alpha;
748 self
749 }
750
751 /// Set the maximum absolute leaf output value.
752 ///
753 /// Clamps leaf predictions to `[-max, max]`, breaking feedback loops
754 /// that cause prediction explosions.
755 pub fn max_leaf_output(mut self, max: f64) -> Self {
756 self.config.max_leaf_output = Some(max);
757 self
758 }
759
760 /// Set per-leaf adaptive output bound (sigma multiplier).
761 ///
762 /// Each leaf tracks EWMA of its own output weight and clamps to
763 /// `|output_mean| + k * output_std`. Self-calibrating per-leaf.
764 /// Recommended: use with `leaf_half_life` for streaming scenarios.
765 pub fn adaptive_leaf_bound(mut self, k: f64) -> Self {
766 self.config.adaptive_leaf_bound = Some(k);
767 self
768 }
769
770 /// Set the per-split information criterion factor (Lunde-Kleppe-Skaug 2020).
771 ///
772 /// Replaces static `max_depth` with a per-split generalization test.
773 /// Typical: 7.5 (<=10 features), 9.0 (<=50), 11.0 (<=200).
774 pub fn adaptive_depth(mut self, factor: f64) -> Self {
775 self.config.adaptive_depth = Some(factor);
776 self
777 }
778
779 /// Set the minimum hessian sum for leaf output.
780 ///
781 /// Fresh leaves with `hess_sum < min_h` return 0.0, preventing
782 /// post-replacement spikes.
783 pub fn min_hessian_sum(mut self, min_h: f64) -> Self {
784 self.config.min_hessian_sum = Some(min_h);
785 self
786 }
787
788 /// Set the Huber loss delta multiplier for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
789 ///
790 /// When set, location gradients use Huber loss with adaptive
791 /// `delta = k * empirical_sigma`. Standard value: `1.345` (95% Gaussian efficiency).
792 pub fn huber_k(mut self, k: f64) -> Self {
793 self.config.huber_k = Some(k);
794 self
795 }
796
797 /// Enable graduated tree handoff with the given shadow warmup samples.
798 ///
799 /// Spawns an always-on shadow tree that trains alongside the active tree.
800 /// After `warmup` samples, the shadow begins contributing to predictions
801 /// via graduated blending. Eliminates prediction dips during tree replacement.
802 pub fn shadow_warmup(mut self, warmup: usize) -> Self {
803 self.config.shadow_warmup = Some(warmup);
804 self
805 }
806
807 /// Set the leaf prediction model type.
808 ///
809 /// [`LeafModelType::Linear`] is recommended for low-depth configurations
810 /// (depth 2--4) where per-leaf linear models reduce approximation error.
811 ///
812 /// [`LeafModelType::Adaptive`] automatically selects between closed-form and
813 /// a trainable model per leaf, using the Hoeffding bound for promotion.
814 pub fn leaf_model_type(mut self, lmt: LeafModelType) -> Self {
815 self.config.leaf_model_type = lmt;
816 self
817 }
818
819 /// Set the packed cache refresh interval for distributional models.
820 ///
821 /// When non-zero, [`DistributionalSGBT`](super::distributional::DistributionalSGBT)
822 /// maintains a packed f32 cache refreshed every `interval` training samples.
823 /// `0` (default) disables the cache.
824 pub fn packed_refresh_interval(mut self, interval: u64) -> Self {
825 self.config.packed_refresh_interval = interval;
826 self
827 }
828
829 /// Validate and build the configuration.
830 ///
831 /// # Errors
832 ///
833 /// Returns [`InvalidConfig`](crate::IrithyllError::InvalidConfig) with a structured
834 /// [`ConfigError`] if any parameter is out of its valid range.
835 pub fn build(self) -> Result<SGBTConfig> {
836 let c = &self.config;
837
838 // -- Ensemble-level parameters --
839 if c.n_steps == 0 {
840 return Err(ConfigError::out_of_range("n_steps", "must be > 0", c.n_steps).into());
841 }
842 if c.learning_rate <= 0.0 || c.learning_rate > 1.0 {
843 return Err(ConfigError::out_of_range(
844 "learning_rate",
845 "must be in (0, 1]",
846 c.learning_rate,
847 )
848 .into());
849 }
850 if c.feature_subsample_rate <= 0.0 || c.feature_subsample_rate > 1.0 {
851 return Err(ConfigError::out_of_range(
852 "feature_subsample_rate",
853 "must be in (0, 1]",
854 c.feature_subsample_rate,
855 )
856 .into());
857 }
858
859 // -- Tree-level parameters --
860 if c.max_depth == 0 {
861 return Err(ConfigError::out_of_range("max_depth", "must be > 0", c.max_depth).into());
862 }
863 if c.n_bins < 2 {
864 return Err(ConfigError::out_of_range("n_bins", "must be >= 2", c.n_bins).into());
865 }
866 if c.lambda < 0.0 {
867 return Err(ConfigError::out_of_range("lambda", "must be >= 0", c.lambda).into());
868 }
869 if c.gamma < 0.0 {
870 return Err(ConfigError::out_of_range("gamma", "must be >= 0", c.gamma).into());
871 }
872 if c.grace_period == 0 {
873 return Err(
874 ConfigError::out_of_range("grace_period", "must be > 0", c.grace_period).into(),
875 );
876 }
877 if c.delta <= 0.0 || c.delta >= 1.0 {
878 return Err(ConfigError::out_of_range("delta", "must be in (0, 1)", c.delta).into());
879 }
880
881 if c.initial_target_count == 0 {
882 return Err(ConfigError::out_of_range(
883 "initial_target_count",
884 "must be > 0",
885 c.initial_target_count,
886 )
887 .into());
888 }
889
890 // -- Streaming adaptation parameters --
891 if let Some(hl) = c.leaf_half_life {
892 if hl == 0 {
893 return Err(ConfigError::out_of_range("leaf_half_life", "must be >= 1", hl).into());
894 }
895 }
896 if let Some(max) = c.max_tree_samples {
897 if max < 100 {
898 return Err(
899 ConfigError::out_of_range("max_tree_samples", "must be >= 100", max).into(),
900 );
901 }
902 }
903 if let Some(interval) = c.split_reeval_interval {
904 if interval < c.grace_period {
905 return Err(ConfigError::invalid(
906 "split_reeval_interval",
907 format!(
908 "must be >= grace_period ({}), got {}",
909 c.grace_period, interval
910 ),
911 )
912 .into());
913 }
914 }
915
916 // -- Feature names --
917 if let Some(ref names) = c.feature_names {
918 // O(n^2) duplicate check — names list is small so no HashSet needed
919 for (i, name) in names.iter().enumerate() {
920 for prev in &names[..i] {
921 if name == prev {
922 return Err(ConfigError::invalid(
923 "feature_names",
924 format!("duplicate feature name: '{}'", name),
925 )
926 .into());
927 }
928 }
929 }
930 }
931
932 // -- Feature types --
933 if let Some(ref types) = c.feature_types {
934 if let Some(ref names) = c.feature_names {
935 if !names.is_empty() && !types.is_empty() && names.len() != types.len() {
936 return Err(ConfigError::invalid(
937 "feature_types",
938 format!(
939 "length ({}) must match feature_names length ({})",
940 types.len(),
941 names.len()
942 ),
943 )
944 .into());
945 }
946 }
947 }
948
949 // -- Gradient clipping --
950 if let Some(sigma) = c.gradient_clip_sigma {
951 if sigma <= 0.0 {
952 return Err(
953 ConfigError::out_of_range("gradient_clip_sigma", "must be > 0", sigma).into(),
954 );
955 }
956 }
957
958 // -- Monotonic constraints --
959 if let Some(ref mc) = c.monotone_constraints {
960 for (i, &v) in mc.iter().enumerate() {
961 if v != -1 && v != 0 && v != 1 {
962 return Err(ConfigError::invalid(
963 "monotone_constraints",
964 format!("feature {}: must be -1, 0, or +1, got {}", i, v),
965 )
966 .into());
967 }
968 }
969 }
970
971 // -- Leaf output clamping --
972 if let Some(max) = c.max_leaf_output {
973 if max <= 0.0 {
974 return Err(
975 ConfigError::out_of_range("max_leaf_output", "must be > 0", max).into(),
976 );
977 }
978 }
979
980 // -- Per-leaf adaptive output bound --
981 if let Some(k) = c.adaptive_leaf_bound {
982 if k <= 0.0 {
983 return Err(
984 ConfigError::out_of_range("adaptive_leaf_bound", "must be > 0", k).into(),
985 );
986 }
987 }
988
989 // -- Adaptive depth (per-split information criterion) --
990 if let Some(factor) = c.adaptive_depth {
991 if factor <= 0.0 {
992 return Err(
993 ConfigError::out_of_range("adaptive_depth", "must be > 0", factor).into(),
994 );
995 }
996 }
997
998 // -- Minimum hessian sum --
999 if let Some(min_h) = c.min_hessian_sum {
1000 if min_h <= 0.0 {
1001 return Err(
1002 ConfigError::out_of_range("min_hessian_sum", "must be > 0", min_h).into(),
1003 );
1004 }
1005 }
1006
1007 // -- Huber loss multiplier --
1008 if let Some(k) = c.huber_k {
1009 if k <= 0.0 {
1010 return Err(ConfigError::out_of_range("huber_k", "must be > 0", k).into());
1011 }
1012 }
1013
1014 // -- Shadow warmup --
1015 if let Some(warmup) = c.shadow_warmup {
1016 if warmup == 0 {
1017 return Err(ConfigError::out_of_range(
1018 "shadow_warmup",
1019 "must be > 0",
1020 warmup as f64,
1021 )
1022 .into());
1023 }
1024 }
1025
1026 // -- Quality-based pruning parameters --
1027 if let Some(alpha) = c.quality_prune_alpha {
1028 if alpha <= 0.0 || alpha >= 1.0 {
1029 return Err(ConfigError::out_of_range(
1030 "quality_prune_alpha",
1031 "must be in (0, 1)",
1032 alpha,
1033 )
1034 .into());
1035 }
1036 }
1037 if c.quality_prune_threshold <= 0.0 {
1038 return Err(ConfigError::out_of_range(
1039 "quality_prune_threshold",
1040 "must be > 0",
1041 c.quality_prune_threshold,
1042 )
1043 .into());
1044 }
1045 if c.quality_prune_patience == 0 {
1046 return Err(ConfigError::out_of_range(
1047 "quality_prune_patience",
1048 "must be > 0",
1049 c.quality_prune_patience,
1050 )
1051 .into());
1052 }
1053
1054 // -- Error-weighted sample importance --
1055 if let Some(alpha) = c.error_weight_alpha {
1056 if alpha <= 0.0 || alpha >= 1.0 {
1057 return Err(ConfigError::out_of_range(
1058 "error_weight_alpha",
1059 "must be in (0, 1)",
1060 alpha,
1061 )
1062 .into());
1063 }
1064 }
1065
1066 // -- Drift detector parameters --
1067 match &c.drift_detector {
1068 DriftDetectorType::PageHinkley { delta, lambda } => {
1069 if *delta <= 0.0 {
1070 return Err(ConfigError::out_of_range(
1071 "drift_detector.PageHinkley.delta",
1072 "must be > 0",
1073 delta,
1074 )
1075 .into());
1076 }
1077 if *lambda <= 0.0 {
1078 return Err(ConfigError::out_of_range(
1079 "drift_detector.PageHinkley.lambda",
1080 "must be > 0",
1081 lambda,
1082 )
1083 .into());
1084 }
1085 }
1086 DriftDetectorType::Adwin { delta } => {
1087 if *delta <= 0.0 || *delta >= 1.0 {
1088 return Err(ConfigError::out_of_range(
1089 "drift_detector.Adwin.delta",
1090 "must be in (0, 1)",
1091 delta,
1092 )
1093 .into());
1094 }
1095 }
1096 DriftDetectorType::Ddm {
1097 warning_level,
1098 drift_level,
1099 min_instances,
1100 } => {
1101 if *warning_level <= 0.0 {
1102 return Err(ConfigError::out_of_range(
1103 "drift_detector.Ddm.warning_level",
1104 "must be > 0",
1105 warning_level,
1106 )
1107 .into());
1108 }
1109 if *drift_level <= 0.0 {
1110 return Err(ConfigError::out_of_range(
1111 "drift_detector.Ddm.drift_level",
1112 "must be > 0",
1113 drift_level,
1114 )
1115 .into());
1116 }
1117 if *drift_level <= *warning_level {
1118 return Err(ConfigError::invalid(
1119 "drift_detector.Ddm.drift_level",
1120 format!(
1121 "must be > warning_level ({}), got {}",
1122 warning_level, drift_level
1123 ),
1124 )
1125 .into());
1126 }
1127 if *min_instances == 0 {
1128 return Err(ConfigError::out_of_range(
1129 "drift_detector.Ddm.min_instances",
1130 "must be > 0",
1131 min_instances,
1132 )
1133 .into());
1134 }
1135 }
1136 }
1137
1138 // -- Variant parameters --
1139 match &c.variant {
1140 SGBTVariant::Standard => {} // no extra validation
1141 SGBTVariant::Skip { k } => {
1142 if *k == 0 {
1143 return Err(
1144 ConfigError::out_of_range("variant.Skip.k", "must be > 0", k).into(),
1145 );
1146 }
1147 }
1148 SGBTVariant::MultipleIterations { multiplier } => {
1149 if *multiplier <= 0.0 {
1150 return Err(ConfigError::out_of_range(
1151 "variant.MultipleIterations.multiplier",
1152 "must be > 0",
1153 multiplier,
1154 )
1155 .into());
1156 }
1157 }
1158 }
1159
1160 Ok(self.config)
1161 }
1162}
1163
1164// ---------------------------------------------------------------------------
1165// Display impls
1166// ---------------------------------------------------------------------------
1167
1168impl core::fmt::Display for DriftDetectorType {
1169 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1170 match self {
1171 Self::PageHinkley { delta, lambda } => {
1172 write!(f, "PageHinkley(delta={}, lambda={})", delta, lambda)
1173 }
1174 Self::Adwin { delta } => write!(f, "Adwin(delta={})", delta),
1175 Self::Ddm {
1176 warning_level,
1177 drift_level,
1178 min_instances,
1179 } => write!(
1180 f,
1181 "Ddm(warning={}, drift={}, min_instances={})",
1182 warning_level, drift_level, min_instances
1183 ),
1184 }
1185 }
1186}
1187
1188impl core::fmt::Display for SGBTConfig {
1189 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1190 write!(
1191 f,
1192 "SGBTConfig {{ steps={}, lr={}, depth={}, bins={}, variant={}, drift={} }}",
1193 self.n_steps,
1194 self.learning_rate,
1195 self.max_depth,
1196 self.n_bins,
1197 self.variant,
1198 self.drift_detector,
1199 )
1200 }
1201}
1202
1203// ---------------------------------------------------------------------------
1204// Tests
1205// ---------------------------------------------------------------------------
1206
#[cfg(test)]
mod tests {
    //! Unit tests covering default values, builder chaining, every
    //! validation rule in [`SGBTConfigBuilder::build`], boundary
    //! acceptance, and the [`DriftDetectorType`] detector factory.

    use super::*;
    use alloc::format;
    use alloc::vec;

    // ------------------------------------------------------------------
    // 1. Default config values are correct
    // ------------------------------------------------------------------
    #[test]
    fn default_config_values() {
        let cfg = SGBTConfig::default();
        assert_eq!(cfg.n_steps, 100);
        // Float defaults are compared within EPSILON rather than with ==.
        assert!((cfg.learning_rate - 0.0125).abs() < f64::EPSILON);
        assert!((cfg.feature_subsample_rate - 0.75).abs() < f64::EPSILON);
        assert_eq!(cfg.max_depth, 6);
        assert_eq!(cfg.n_bins, 64);
        assert!((cfg.lambda - 1.0).abs() < f64::EPSILON);
        assert!((cfg.gamma - 0.0).abs() < f64::EPSILON);
        assert_eq!(cfg.grace_period, 200);
        assert!((cfg.delta - 1e-7).abs() < f64::EPSILON);
        assert_eq!(cfg.variant, SGBTVariant::Standard);
    }

    // ------------------------------------------------------------------
    // 2. Builder chain works
    // ------------------------------------------------------------------
    #[test]
    fn builder_chain() {
        let cfg = SGBTConfig::builder()
            .n_steps(50)
            .learning_rate(0.1)
            .feature_subsample_rate(0.5)
            .max_depth(10)
            .n_bins(128)
            .lambda(0.5)
            .gamma(0.1)
            .grace_period(500)
            .delta(1e-3)
            .drift_detector(DriftDetectorType::Adwin { delta: 0.01 })
            .variant(SGBTVariant::Skip { k: 5 })
            .build()
            .expect("valid config");

        assert_eq!(cfg.n_steps, 50);
        assert!((cfg.learning_rate - 0.1).abs() < f64::EPSILON);
        assert!((cfg.feature_subsample_rate - 0.5).abs() < f64::EPSILON);
        assert_eq!(cfg.max_depth, 10);
        assert_eq!(cfg.n_bins, 128);
        assert!((cfg.lambda - 0.5).abs() < f64::EPSILON);
        assert!((cfg.gamma - 0.1).abs() < f64::EPSILON);
        assert_eq!(cfg.grace_period, 500);
        assert!((cfg.delta - 1e-3).abs() < f64::EPSILON);
        assert_eq!(cfg.variant, SGBTVariant::Skip { k: 5 });

        // Verify drift detector type is Adwin.
        match &cfg.drift_detector {
            DriftDetectorType::Adwin { delta } => {
                assert!((delta - 0.01).abs() < f64::EPSILON);
            }
            _ => panic!("expected Adwin drift detector"),
        }
    }

    // ------------------------------------------------------------------
    // 3. Validation rejects invalid values
    // ------------------------------------------------------------------
    #[test]
    fn validation_rejects_n_steps_zero() {
        let result = SGBTConfig::builder().n_steps(0).build();
        assert!(result.is_err());
        // The error message should name the offending parameter.
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("n_steps"));
    }

    #[test]
    fn validation_rejects_learning_rate_zero() {
        let result = SGBTConfig::builder().learning_rate(0.0).build();
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("learning_rate"));
    }

    #[test]
    fn validation_rejects_learning_rate_above_one() {
        let result = SGBTConfig::builder().learning_rate(1.5).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_accepts_learning_rate_one() {
        // 1.0 is inclusive: learning_rate range is (0, 1].
        let result = SGBTConfig::builder().learning_rate(1.0).build();
        assert!(result.is_ok());
    }

    #[test]
    fn validation_rejects_negative_learning_rate() {
        let result = SGBTConfig::builder().learning_rate(-0.1).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_feature_subsample_zero() {
        let result = SGBTConfig::builder().feature_subsample_rate(0.0).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_feature_subsample_above_one() {
        let result = SGBTConfig::builder().feature_subsample_rate(1.01).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_max_depth_zero() {
        let result = SGBTConfig::builder().max_depth(0).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_n_bins_one() {
        let result = SGBTConfig::builder().n_bins(1).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_negative_lambda() {
        let result = SGBTConfig::builder().lambda(-0.1).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_accepts_zero_lambda() {
        // lambda is >= 0, so zero (no L2 regularization) is valid.
        let result = SGBTConfig::builder().lambda(0.0).build();
        assert!(result.is_ok());
    }

    #[test]
    fn validation_rejects_negative_gamma() {
        let result = SGBTConfig::builder().gamma(-0.1).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_grace_period_zero() {
        let result = SGBTConfig::builder().grace_period(0).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_delta_zero() {
        // delta range is the open interval (0, 1): both endpoints rejected.
        let result = SGBTConfig::builder().delta(0.0).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_delta_one() {
        let result = SGBTConfig::builder().delta(1.0).build();
        assert!(result.is_err());
    }

    // ------------------------------------------------------------------
    // 3b. Drift detector parameter validation
    // ------------------------------------------------------------------
    #[test]
    fn validation_rejects_pht_negative_delta() {
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::PageHinkley {
                delta: -1.0,
                lambda: 50.0,
            })
            .build();
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("PageHinkley"));
    }

    #[test]
    fn validation_rejects_pht_zero_lambda() {
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::PageHinkley {
                delta: 0.005,
                lambda: 0.0,
            })
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_adwin_delta_out_of_range() {
        // Both endpoints of the open (0, 1) interval must be rejected.
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::Adwin { delta: 0.0 })
            .build();
        assert!(result.is_err());

        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::Adwin { delta: 1.0 })
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_ddm_warning_above_drift() {
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::Ddm {
                warning_level: 3.0,
                drift_level: 2.0,
                min_instances: 30,
            })
            .build();
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("drift_level"));
        assert!(msg.contains("must be > warning_level"));
    }

    #[test]
    fn validation_rejects_ddm_equal_levels() {
        // drift_level must be strictly greater than warning_level.
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::Ddm {
                warning_level: 2.0,
                drift_level: 2.0,
                min_instances: 30,
            })
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_ddm_zero_min_instances() {
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::Ddm {
                warning_level: 2.0,
                drift_level: 3.0,
                min_instances: 0,
            })
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_ddm_zero_warning_level() {
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::Ddm {
                warning_level: 0.0,
                drift_level: 3.0,
                min_instances: 30,
            })
            .build();
        assert!(result.is_err());
    }

    // ------------------------------------------------------------------
    // 3c. Variant parameter validation
    // ------------------------------------------------------------------
    #[test]
    fn validation_rejects_skip_k_zero() {
        let result = SGBTConfig::builder()
            .variant(SGBTVariant::Skip { k: 0 })
            .build();
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("Skip"));
    }

    #[test]
    fn validation_rejects_mi_zero_multiplier() {
        let result = SGBTConfig::builder()
            .variant(SGBTVariant::MultipleIterations { multiplier: 0.0 })
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_mi_negative_multiplier() {
        let result = SGBTConfig::builder()
            .variant(SGBTVariant::MultipleIterations { multiplier: -1.0 })
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_accepts_standard_variant() {
        let result = SGBTConfig::builder().variant(SGBTVariant::Standard).build();
        assert!(result.is_ok());
    }

    // ------------------------------------------------------------------
    // 4. DriftDetectorType creates correct detector types
    // ------------------------------------------------------------------
    #[test]
    fn drift_detector_type_creates_page_hinkley() {
        let dt = DriftDetectorType::PageHinkley {
            delta: 0.01,
            lambda: 100.0,
        };
        let mut detector = dt.create();

        // Should start with zero mean.
        assert_eq!(detector.estimated_mean(), 0.0);

        // Feed a stable value -- should not drift.
        let signal = detector.update(1.0);
        assert_ne!(signal, crate::drift::DriftSignal::Drift);
    }

    #[test]
    fn drift_detector_type_creates_adwin() {
        let dt = DriftDetectorType::Adwin { delta: 0.05 };
        let mut detector = dt.create();

        assert_eq!(detector.estimated_mean(), 0.0);
        let signal = detector.update(1.0);
        assert_ne!(signal, crate::drift::DriftSignal::Drift);
    }

    #[test]
    fn drift_detector_type_creates_ddm() {
        let dt = DriftDetectorType::Ddm {
            warning_level: 2.0,
            drift_level: 3.0,
            min_instances: 30,
        };
        let mut detector = dt.create();

        assert_eq!(detector.estimated_mean(), 0.0);
        let signal = detector.update(0.1);
        assert_eq!(signal, crate::drift::DriftSignal::Stable);
    }

    // ------------------------------------------------------------------
    // 5. Default drift detector is PageHinkley
    // ------------------------------------------------------------------
    #[test]
    fn default_drift_detector_is_page_hinkley() {
        let dt = DriftDetectorType::default();
        match dt {
            DriftDetectorType::PageHinkley { delta, lambda } => {
                assert!((delta - 0.005).abs() < f64::EPSILON);
                assert!((lambda - 50.0).abs() < f64::EPSILON);
            }
            _ => panic!("expected default DriftDetectorType to be PageHinkley"),
        }
    }

    // ------------------------------------------------------------------
    // 6. Default config builds without error
    // ------------------------------------------------------------------
    #[test]
    fn default_config_builds() {
        let result = SGBTConfig::builder().build();
        assert!(result.is_ok());
    }

    // ------------------------------------------------------------------
    // 7. Config clone preserves all fields
    // ------------------------------------------------------------------
    #[test]
    fn config_clone_preserves_fields() {
        let original = SGBTConfig::builder()
            .n_steps(50)
            .learning_rate(0.05)
            .variant(SGBTVariant::MultipleIterations { multiplier: 5.0 })
            .build()
            .unwrap();

        let cloned = original.clone();

        assert_eq!(cloned.n_steps, 50);
        assert!((cloned.learning_rate - 0.05).abs() < f64::EPSILON);
        assert_eq!(
            cloned.variant,
            SGBTVariant::MultipleIterations { multiplier: 5.0 }
        );
    }

    // ------------------------------------------------------------------
    // 8. All three DDM valid configs accepted
    // ------------------------------------------------------------------
    #[test]
    fn valid_ddm_config_accepted() {
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::Ddm {
                warning_level: 1.5,
                drift_level: 2.5,
                min_instances: 10,
            })
            .build();
        assert!(result.is_ok());
    }

    // ------------------------------------------------------------------
    // 9. Created detectors are functional (round-trip test)
    // ------------------------------------------------------------------
    #[test]
    fn created_detectors_are_functional() {
        // Create a detector, feed it data, verify it can detect drift.
        let dt = DriftDetectorType::PageHinkley {
            delta: 0.005,
            lambda: 50.0,
        };
        let mut detector = dt.create();

        // Feed 500 stable values.
        for _ in 0..500 {
            detector.update(1.0);
        }

        // Feed shifted values -- should eventually drift.
        let mut drifted = false;
        for _ in 0..500 {
            if detector.update(10.0) == crate::drift::DriftSignal::Drift {
                drifted = true;
                break;
            }
        }
        assert!(
            drifted,
            "detector created from DriftDetectorType should be functional"
        );
    }

    // ------------------------------------------------------------------
    // 10. Boundary acceptance: n_bins=2 is the minimum valid
    // ------------------------------------------------------------------
    #[test]
    fn boundary_n_bins_two_accepted() {
        let result = SGBTConfig::builder().n_bins(2).build();
        assert!(result.is_ok());
    }

    // ------------------------------------------------------------------
    // 11. Boundary acceptance: grace_period=1 is valid
    // ------------------------------------------------------------------
    #[test]
    fn boundary_grace_period_one_accepted() {
        let result = SGBTConfig::builder().grace_period(1).build();
        assert!(result.is_ok());
    }

    // ------------------------------------------------------------------
    // 12. Feature names -- valid config with names
    // ------------------------------------------------------------------
    #[test]
    fn feature_names_accepted() {
        let cfg = SGBTConfig::builder()
            .feature_names(vec!["price".into(), "volume".into(), "spread".into()])
            .build()
            .unwrap();
        assert_eq!(
            cfg.feature_names.as_ref().unwrap(),
            &["price", "volume", "spread"]
        );
    }

    // ------------------------------------------------------------------
    // 13. Feature names -- duplicate names rejected
    // ------------------------------------------------------------------
    #[test]
    fn feature_names_rejects_duplicates() {
        let result = SGBTConfig::builder()
            .feature_names(vec!["price".into(), "volume".into(), "price".into()])
            .build();
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("duplicate"));
    }

    // ------------------------------------------------------------------
    // 14. Feature names -- serde backward compat (missing field)
    //     Requires serde_json which is only in the full irithyll crate.
    // ------------------------------------------------------------------

    // ------------------------------------------------------------------
    // 15. Feature names -- empty vec is valid
    // ------------------------------------------------------------------
    #[test]
    fn feature_names_empty_vec_accepted() {
        let cfg = SGBTConfig::builder().feature_names(vec![]).build().unwrap();
        assert!(cfg.feature_names.unwrap().is_empty());
    }

    // ------------------------------------------------------------------
    // 16. Adaptive leaf bound -- builder
    // ------------------------------------------------------------------
    #[test]
    fn builder_adaptive_leaf_bound() {
        let cfg = SGBTConfig::builder()
            .adaptive_leaf_bound(3.0)
            .build()
            .unwrap();
        assert_eq!(cfg.adaptive_leaf_bound, Some(3.0));
    }

    // ------------------------------------------------------------------
    // 17. Adaptive leaf bound -- validation rejects zero
    // ------------------------------------------------------------------
    #[test]
    fn validation_rejects_zero_adaptive_leaf_bound() {
        let result = SGBTConfig::builder().adaptive_leaf_bound(0.0).build();
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(
            msg.contains("adaptive_leaf_bound"),
            "error should mention adaptive_leaf_bound: {}",
            msg,
        );
    }

    // ------------------------------------------------------------------
    // 18. Adaptive leaf bound -- serde backward compat
    //     Requires serde_json which is only in the full irithyll crate.
    // ------------------------------------------------------------------
}