// irithyll_core/ensemble/config.rs

1//! SGBT configuration with builder pattern and full validation.
2//!
3//! [`SGBTConfig`] holds all hyperparameters for the Streaming Gradient Boosted
4//! Trees ensemble. Use [`SGBTConfig::builder`] for ergonomic construction with
5//! validation on [`build()`](SGBTConfigBuilder::build).
6
7use alloc::boxed::Box;
8use alloc::format;
9use alloc::string::String;
10use alloc::vec::Vec;
11
12use crate::drift::adwin::Adwin;
13use crate::drift::ddm::Ddm;
14use crate::drift::pht::PageHinkleyTest;
15use crate::drift::DriftDetector;
16use crate::ensemble::variants::SGBTVariant;
17use crate::error::{ConfigError, Result};
18use crate::tree::leaf_model::LeafModelType;
19
20// ---------------------------------------------------------------------------
21// FeatureType -- re-exported from irithyll-core
22// ---------------------------------------------------------------------------
23
24pub use crate::feature::FeatureType;
25
26// ---------------------------------------------------------------------------
27// ScaleMode
28// ---------------------------------------------------------------------------
29
/// How [`DistributionalSGBT`](super::distributional::DistributionalSGBT)
/// estimates uncertainty (σ).
///
/// - **`Empirical`** (default): tracks an EWMA of squared prediction errors.
///   `σ = sqrt(ewma_sq_err)`.  Always calibrated, zero tuning, O(1) compute.
///   Use this when σ drives learning-rate modulation (σ high → learn faster).
///
/// - **`TreeChain`**: trains a full second ensemble of Hoeffding trees to predict
///   log(σ) from features (NGBoost-style dual chain).  Gives *feature-conditional*
///   uncertainty but requires strong signal in the scale gradients.
///
/// Defaults to [`Empirical`](Self::Empirical) via the derived `Default`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ScaleMode {
    /// EWMA of squared prediction errors — always calibrated, O(1).
    #[default]
    Empirical,
    /// Full Hoeffding-tree ensemble for feature-conditional log(σ) prediction.
    TreeChain,
}
49
50// ---------------------------------------------------------------------------
51// DriftDetectorType
52// ---------------------------------------------------------------------------
53
/// Which drift detector to instantiate for each boosting step.
///
/// Each variant stores the detector's configuration parameters so that fresh
/// instances can be created on demand (e.g. when replacing a drifted tree);
/// see [`DriftDetectorType::create`].
// NOTE: `Eq` cannot be derived — variants carry `f64` fields, hence `PartialEq` only.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum DriftDetectorType {
    /// Page-Hinkley Test with custom delta (magnitude tolerance) and lambda
    /// (detection threshold).
    PageHinkley {
        /// Magnitude tolerance. Default 0.005.
        delta: f64,
        /// Detection threshold. Default 50.0.
        lambda: f64,
    },

    /// ADWIN with custom confidence parameter.
    Adwin {
        /// Confidence (smaller = fewer false positives). Default 0.002.
        delta: f64,
    },

    /// DDM with custom warning/drift levels and minimum warmup instances.
    Ddm {
        /// Warning threshold multiplier. Default 2.0.
        warning_level: f64,
        /// Drift threshold multiplier. Default 3.0.
        drift_level: f64,
        /// Minimum observations before detection activates. Default 30.
        min_instances: u64,
    },
}
86
87impl Default for DriftDetectorType {
88    fn default() -> Self {
89        DriftDetectorType::PageHinkley {
90            delta: 0.005,
91            lambda: 50.0,
92        }
93    }
94}
95
96impl DriftDetectorType {
97    /// Create a new, fresh drift detector from this configuration.
98    pub fn create(&self) -> Box<dyn DriftDetector> {
99        match self {
100            Self::PageHinkley { delta, lambda } => {
101                Box::new(PageHinkleyTest::with_params(*delta, *lambda))
102            }
103            Self::Adwin { delta } => Box::new(Adwin::with_delta(*delta)),
104            Self::Ddm {
105                warning_level,
106                drift_level,
107                min_instances,
108            } => Box::new(Ddm::with_params(
109                *warning_level,
110                *drift_level,
111                *min_instances,
112            )),
113        }
114    }
115}
116
117// ---------------------------------------------------------------------------
118// SGBTConfig
119// ---------------------------------------------------------------------------
120
/// Configuration for the SGBT ensemble.
///
/// All numeric parameters are validated at build time via [`SGBTConfigBuilder`].
///
/// NOTE(review): fields up to and including `initial_target_count` carry no
/// `serde(default)` attribute and are therefore required when deserializing,
/// while every later field declares a serde default (so configs serialized
/// before those fields existed still load) — confirm this asymmetry is the
/// intended compatibility policy.
///
/// # Defaults
///
/// | Parameter                | Default              |
/// |--------------------------|----------------------|
/// | `n_steps`                | 100                  |
/// | `learning_rate`          | 0.0125               |
/// | `feature_subsample_rate` | 0.75                 |
/// | `max_depth`              | 6                    |
/// | `n_bins`                 | 64                   |
/// | `lambda`                 | 1.0                  |
/// | `gamma`                  | 0.0                  |
/// | `grace_period`           | 200                  |
/// | `delta`                  | 1e-7                 |
/// | `drift_detector`         | PageHinkley(0.005, 50.0) |
/// | `variant`                | Standard             |
/// | `seed`                   | 0xDEAD_BEEF_CAFE_4242 |
/// | `initial_target_count`   | 50                   |
/// | `leaf_half_life`         | None (disabled)      |
/// | `max_tree_samples`       | None (disabled)      |
/// | `split_reeval_interval`  | None (disabled)      |
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SGBTConfig {
    /// Number of boosting steps (trees in the ensemble). Default 100.
    pub n_steps: usize,
    /// Learning rate (shrinkage). Default 0.0125.
    pub learning_rate: f64,
    /// Fraction of features to subsample per tree. Default 0.75.
    pub feature_subsample_rate: f64,
    /// Maximum tree depth. Default 6.
    pub max_depth: usize,
    /// Number of histogram bins. Default 64.
    pub n_bins: usize,
    /// L2 regularization parameter (lambda). Default 1.0.
    pub lambda: f64,
    /// Minimum split gain (gamma). Default 0.0.
    pub gamma: f64,
    /// Grace period: min samples before evaluating splits. Default 200.
    pub grace_period: usize,
    /// Hoeffding bound confidence (delta). Default 1e-7.
    pub delta: f64,
    /// Drift detector type for tree replacement. Default: PageHinkley.
    pub drift_detector: DriftDetectorType,
    /// SGBT computational variant. Default: Standard.
    pub variant: SGBTVariant,
    /// Random seed for deterministic reproducibility. Default: 0xDEAD_BEEF_CAFE_4242.
    ///
    /// Controls feature subsampling and variant skip/MI stochastic decisions.
    /// Two models with the same seed and same data will produce identical results.
    pub seed: u64,
    /// Number of initial targets to collect before computing the base prediction.
    /// Default: 50.
    pub initial_target_count: usize,

    /// Half-life for exponential leaf decay (in samples per leaf).
    ///
    /// After `leaf_half_life` samples, a leaf's accumulated gradient/hessian
    /// statistics have half the weight of the most recent sample. This causes
    /// the model to continuously adapt to changing data distributions rather
    /// than freezing on early observations.
    ///
    /// `None` (default) disables decay -- traditional monotonic accumulation.
    #[cfg_attr(feature = "serde", serde(default))]
    pub leaf_half_life: Option<usize>,

    /// Maximum samples a single tree processes before proactive replacement.
    ///
    /// After this many samples, the tree is replaced with a fresh one regardless
    /// of drift detector state. Prevents stale tree structure from persisting
    /// when the drift detector is not sensitive enough.
    ///
    /// `None` (default) disables time-based replacement.
    #[cfg_attr(feature = "serde", serde(default))]
    pub max_tree_samples: Option<u64>,

    /// Interval (in samples per leaf) at which max-depth leaves re-evaluate
    /// whether a split would improve them.
    ///
    /// Inspired by EFDT (Manapragada et al. 2018). When a leaf has accumulated
    /// `split_reeval_interval` samples since its last evaluation and has reached
    /// max depth, it re-evaluates whether a split should be performed.
    ///
    /// `None` (default) disables re-evaluation -- max-depth leaves are permanent.
    #[cfg_attr(feature = "serde", serde(default))]
    pub split_reeval_interval: Option<usize>,

    /// Optional human-readable feature names.
    ///
    /// When set, enables `named_feature_importances` and
    /// `train_one_named` for production-friendly named access.
    /// Length must match the number of features in training data.
    #[cfg_attr(feature = "serde", serde(default))]
    pub feature_names: Option<Vec<String>>,

    /// Optional per-feature type declarations.
    ///
    /// When set, declares which features are categorical vs continuous.
    /// Categorical features use one-bin-per-category binning and Fisher
    /// optimal binary partitioning for split evaluation.
    /// Length must match the number of features in training data.
    ///
    /// `None` (default) treats all features as continuous.
    #[cfg_attr(feature = "serde", serde(default))]
    pub feature_types: Option<Vec<FeatureType>>,

    /// Gradient clipping threshold in standard deviations per leaf.
    ///
    /// When enabled, each leaf tracks an EWMA of gradient mean and variance.
    /// Incoming gradients that exceed `mean ± sigma * gradient_clip_sigma` are
    /// clamped to the boundary. This prevents outlier samples from corrupting
    /// leaf statistics, which is critical in streaming settings where sudden
    /// label floods can destabilize the model.
    ///
    /// Typical value: 3.0 (3-sigma clipping).
    /// `None` (default) disables gradient clipping.
    #[cfg_attr(feature = "serde", serde(default))]
    pub gradient_clip_sigma: Option<f64>,

    /// Per-feature monotonic constraints.
    ///
    /// Each element specifies the monotonic relationship between a feature and
    /// the prediction:
    /// - `+1`: prediction must be non-decreasing as feature value increases.
    /// - `-1`: prediction must be non-increasing as feature value increases.
    /// - `0`: no constraint (unconstrained).
    ///
    /// During split evaluation, candidate splits that would violate monotonicity
    /// (left child value > right child value for +1 constraints, or vice versa)
    /// are rejected.
    ///
    /// Length must match the number of features in training data.
    /// `None` (default) means no monotonic constraints.
    #[cfg_attr(feature = "serde", serde(default))]
    pub monotone_constraints: Option<Vec<i8>>,

    /// EWMA smoothing factor for quality-based tree pruning.
    ///
    /// When `Some(alpha)`, each boosting step tracks an exponentially weighted
    /// moving average of its marginal contribution to the ensemble. Trees whose
    /// contribution drops below [`quality_prune_threshold`](Self::quality_prune_threshold)
    /// for [`quality_prune_patience`](Self::quality_prune_patience) consecutive
    /// samples are replaced with a fresh tree that can learn the current regime.
    ///
    /// This prevents "dead wood" -- trees from a past regime that no longer
    /// contribute meaningfully to ensemble accuracy.
    ///
    /// `None` (default) disables quality-based pruning.
    /// Suggested value: 0.01.
    #[cfg_attr(feature = "serde", serde(default))]
    pub quality_prune_alpha: Option<f64>,

    /// Minimum contribution threshold for quality-based pruning.
    ///
    /// A tree's EWMA contribution must stay above this value to avoid being
    /// flagged as dead wood. Only used when `quality_prune_alpha` is `Some`.
    ///
    /// Default: 1e-6.
    #[cfg_attr(feature = "serde", serde(default = "default_quality_prune_threshold"))]
    pub quality_prune_threshold: f64,

    /// Consecutive low-contribution samples before a tree is replaced.
    ///
    /// After this many consecutive samples where a tree's EWMA contribution
    /// is below `quality_prune_threshold`, the tree is reset. Only used when
    /// `quality_prune_alpha` is `Some`.
    ///
    /// Default: 500.
    #[cfg_attr(feature = "serde", serde(default = "default_quality_prune_patience"))]
    pub quality_prune_patience: u64,

    /// EWMA smoothing factor for error-weighted sample importance.
    ///
    /// When `Some(alpha)`, samples the model predicted poorly get higher
    /// effective weight during histogram accumulation. The weight is:
    /// `1.0 + |error| / (rolling_mean_error + epsilon)`, capped at 10x.
    ///
    /// This is a streaming version of AdaBoost's reweighting applied at the
    /// gradient level -- learning capacity focuses on hard/novel patterns,
    /// enabling faster adaptation to regime changes.
    ///
    /// `None` (default) disables error weighting.
    /// Suggested value: 0.01.
    #[cfg_attr(feature = "serde", serde(default))]
    pub error_weight_alpha: Option<f64>,

    /// Enable σ-modulated learning rate for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
    ///
    /// When `true`, the **location** (μ) ensemble's learning rate is scaled by
    /// `sigma_ratio = current_sigma / rolling_sigma_mean`, where `rolling_sigma_mean`
    /// is an EWMA of the model's predicted σ (alpha = 0.001).
    ///
    /// This means the model learns μ **faster** when σ is elevated (high uncertainty)
    /// and **slower** when σ is low (confident regime). The scale (σ) ensemble always
    /// trains at the unmodulated base rate to prevent positive feedback loops.
    ///
    /// Default: `false`.
    #[cfg_attr(feature = "serde", serde(default))]
    pub uncertainty_modulated_lr: bool,

    /// How the scale (σ) is estimated in [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
    ///
    /// - [`Empirical`](ScaleMode::Empirical) (default): EWMA of squared prediction
    ///   errors.  `σ = sqrt(ewma_sq_err)`.  Always calibrated, zero tuning, O(1).
    /// - [`TreeChain`](ScaleMode::TreeChain): full dual-chain NGBoost with a
    ///   separate tree ensemble predicting log(σ) from features.
    ///
    /// For σ-modulated learning (`uncertainty_modulated_lr = true`), `Empirical`
    /// is strongly recommended — scale tree gradients are inherently weak and
    /// the trees often fail to split.
    #[cfg_attr(feature = "serde", serde(default))]
    pub scale_mode: ScaleMode,

    /// EWMA smoothing factor for empirical σ estimation.
    ///
    /// Controls the adaptation speed of `σ = sqrt(ewma_sq_err)` when
    /// [`scale_mode`](Self::scale_mode) is [`Empirical`](ScaleMode::Empirical).
    /// Higher values react faster to regime changes but are noisier.
    ///
    /// Default: `0.01` (~100-sample effective window).
    #[cfg_attr(feature = "serde", serde(default = "default_empirical_sigma_alpha"))]
    pub empirical_sigma_alpha: f64,

    /// Maximum absolute leaf output value.
    ///
    /// When `Some(max)`, leaf predictions are clamped to `[-max, max]`.
    /// Prevents runaway leaf weights from causing prediction explosions
    /// in feedback loops. `None` (default) means no clamping.
    #[cfg_attr(feature = "serde", serde(default))]
    pub max_leaf_output: Option<f64>,

    /// Per-leaf adaptive output bound (sigma multiplier).
    ///
    /// When `Some(k)`, each leaf tracks an EWMA of its own output weight and
    /// clamps predictions to `|output_mean| + k * output_std`. The EWMA uses
    /// `leaf_decay_alpha` when `leaf_half_life` is set, otherwise Welford online.
    ///
    /// This is strictly superior to `max_leaf_output` for streaming — the bound
    /// is per-leaf, self-calibrating, and regime-synchronized. A leaf that usually
    /// outputs 0.3 can't suddenly output 2.9 just because it fits in the global clamp.
    ///
    /// Typical value: 3.0 (3-sigma bound).
    /// `None` (default) disables adaptive bounds (falls back to `max_leaf_output`).
    #[cfg_attr(feature = "serde", serde(default))]
    pub adaptive_leaf_bound: Option<f64>,

    /// Minimum hessian sum before a leaf produces non-zero output.
    ///
    /// When `Some(min_h)`, leaves with `hess_sum < min_h` return 0.0.
    /// Prevents post-replacement spikes from fresh leaves with insufficient
    /// samples. `None` (default) means all leaves contribute immediately.
    #[cfg_attr(feature = "serde", serde(default))]
    pub min_hessian_sum: Option<f64>,

    /// Huber loss delta multiplier for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
    ///
    /// When `Some(k)`, the distributional location gradient uses Huber loss
    /// with adaptive `delta = k * empirical_sigma`. This bounds gradients by
    /// construction. Standard value: `1.345` (95% efficiency at Gaussian).
    /// `None` (default) uses squared loss.
    #[cfg_attr(feature = "serde", serde(default))]
    pub huber_k: Option<f64>,

    /// Shadow warmup for graduated tree handoff.
    ///
    /// When `Some(n)`, an always-on shadow (alternate) tree is spawned immediately
    /// alongside every active tree. The shadow trains on the same gradient stream
    /// but does not contribute to predictions until it has seen `n` samples.
    ///
    /// As the active tree ages past 80% of `max_tree_samples`, its prediction
    /// weight linearly decays to 0 at 120%. The shadow's weight ramps from 0 to 1
    /// over `n` samples after warmup. When the active weight reaches 0, the shadow
    /// is promoted and a new shadow is spawned — no cold-start prediction dip.
    ///
    /// Requires `max_tree_samples` to be set for time-based graduated handoff.
    /// Drift-based replacement still uses hard swap (shadow is already warm).
    ///
    /// `None` (default) disables graduated handoff — uses traditional hard swap.
    #[cfg_attr(feature = "serde", serde(default))]
    pub shadow_warmup: Option<usize>,

    /// Leaf prediction model type.
    ///
    /// Controls how each leaf computes its prediction:
    /// - [`ClosedForm`](LeafModelType::ClosedForm) (default): constant leaf weight.
    /// - [`Linear`](LeafModelType::Linear): per-leaf online ridge regression with
    ///   AdaGrad optimization. Optional `decay` for concept drift. Recommended for
    ///   low-depth trees (depth 2--4).
    /// - [`MLP`](LeafModelType::MLP): per-leaf single-hidden-layer neural network.
    ///   Optional `decay` for concept drift.
    /// - [`Adaptive`](LeafModelType::Adaptive): starts as closed-form, auto-promotes
    ///   when the Hoeffding bound confirms a more complex model is better.
    ///
    /// Default: [`ClosedForm`](LeafModelType::ClosedForm).
    #[cfg_attr(feature = "serde", serde(default))]
    pub leaf_model_type: LeafModelType,

    /// Packed cache refresh interval for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
    ///
    /// When non-zero, the distributional model maintains a packed f32 cache of
    /// its location ensemble that is re-exported every `packed_refresh_interval`
    /// training samples. Predictions use the cache for O(1)-per-tree inference
    /// via contiguous memory traversal, falling back to full tree traversal when
    /// the cache is absent or produces non-finite results.
    ///
    /// `0` (default) disables the packed cache.
    #[cfg_attr(feature = "serde", serde(default))]
    pub packed_refresh_interval: u64,
}
433
434#[cfg(feature = "serde")]
435fn default_empirical_sigma_alpha() -> f64 {
436    0.01
437}
438
439#[cfg(feature = "serde")]
440fn default_quality_prune_threshold() -> f64 {
441    1e-6
442}
443
444#[cfg(feature = "serde")]
445fn default_quality_prune_patience() -> u64 {
446    500
447}
448
449impl Default for SGBTConfig {
450    fn default() -> Self {
451        Self {
452            n_steps: 100,
453            learning_rate: 0.0125,
454            feature_subsample_rate: 0.75,
455            max_depth: 6,
456            n_bins: 64,
457            lambda: 1.0,
458            gamma: 0.0,
459            grace_period: 200,
460            delta: 1e-7,
461            drift_detector: DriftDetectorType::default(),
462            variant: SGBTVariant::default(),
463            seed: 0xDEAD_BEEF_CAFE_4242,
464            initial_target_count: 50,
465            leaf_half_life: None,
466            max_tree_samples: None,
467            split_reeval_interval: None,
468            feature_names: None,
469            feature_types: None,
470            gradient_clip_sigma: None,
471            monotone_constraints: None,
472            quality_prune_alpha: None,
473            quality_prune_threshold: 1e-6,
474            quality_prune_patience: 500,
475            error_weight_alpha: None,
476            uncertainty_modulated_lr: false,
477            scale_mode: ScaleMode::default(),
478            empirical_sigma_alpha: 0.01,
479            max_leaf_output: None,
480            adaptive_leaf_bound: None,
481            min_hessian_sum: None,
482            huber_k: None,
483            shadow_warmup: None,
484            leaf_model_type: LeafModelType::default(),
485            packed_refresh_interval: 0,
486        }
487    }
488}
489
490impl SGBTConfig {
491    /// Start building a configuration via the builder pattern.
492    pub fn builder() -> SGBTConfigBuilder {
493        SGBTConfigBuilder::default()
494    }
495}
496
497// ---------------------------------------------------------------------------
498// SGBTConfigBuilder
499// ---------------------------------------------------------------------------
500
/// Builder for [`SGBTConfig`] with validation on [`build()`](Self::build).
///
/// Obtained via [`SGBTConfig::builder`]; the derived `Default` starts from
/// [`SGBTConfig::default`], and each setter overrides one field.
///
/// # Example
///
/// ```ignore
/// use irithyll::ensemble::config::{SGBTConfig, DriftDetectorType};
/// use irithyll::ensemble::variants::SGBTVariant;
///
/// let config = SGBTConfig::builder()
///     .n_steps(200)
///     .learning_rate(0.05)
///     .drift_detector(DriftDetectorType::Adwin { delta: 0.01 })
///     .variant(SGBTVariant::Skip { k: 10 })
///     .build()
///     .expect("valid config");
/// ```
#[derive(Debug, Clone, Default)]
pub struct SGBTConfigBuilder {
    // Accumulated configuration; validated and handed out by `build()`.
    config: SGBTConfig,
}
521
522impl SGBTConfigBuilder {
523    /// Set the number of boosting steps (trees in the ensemble).
524    pub fn n_steps(mut self, n: usize) -> Self {
525        self.config.n_steps = n;
526        self
527    }
528
529    /// Set the learning rate (shrinkage factor).
530    pub fn learning_rate(mut self, lr: f64) -> Self {
531        self.config.learning_rate = lr;
532        self
533    }
534
535    /// Set the fraction of features to subsample per tree.
536    pub fn feature_subsample_rate(mut self, rate: f64) -> Self {
537        self.config.feature_subsample_rate = rate;
538        self
539    }
540
541    /// Set the maximum tree depth.
542    pub fn max_depth(mut self, depth: usize) -> Self {
543        self.config.max_depth = depth;
544        self
545    }
546
547    /// Set the number of histogram bins per feature.
548    pub fn n_bins(mut self, bins: usize) -> Self {
549        self.config.n_bins = bins;
550        self
551    }
552
553    /// Set the L2 regularization parameter (lambda).
554    pub fn lambda(mut self, l: f64) -> Self {
555        self.config.lambda = l;
556        self
557    }
558
559    /// Set the minimum split gain (gamma).
560    pub fn gamma(mut self, g: f64) -> Self {
561        self.config.gamma = g;
562        self
563    }
564
565    /// Set the grace period (minimum samples before evaluating splits).
566    pub fn grace_period(mut self, gp: usize) -> Self {
567        self.config.grace_period = gp;
568        self
569    }
570
571    /// Set the Hoeffding bound confidence parameter (delta).
572    pub fn delta(mut self, d: f64) -> Self {
573        self.config.delta = d;
574        self
575    }
576
577    /// Set the drift detector type for tree replacement.
578    pub fn drift_detector(mut self, dt: DriftDetectorType) -> Self {
579        self.config.drift_detector = dt;
580        self
581    }
582
583    /// Set the SGBT computational variant.
584    pub fn variant(mut self, v: SGBTVariant) -> Self {
585        self.config.variant = v;
586        self
587    }
588
589    /// Set the random seed for deterministic reproducibility.
590    ///
591    /// Controls feature subsampling and variant skip/MI stochastic decisions.
592    /// Two models with the same seed and data sequence will produce identical results.
593    pub fn seed(mut self, seed: u64) -> Self {
594        self.config.seed = seed;
595        self
596    }
597
598    /// Set the number of initial targets to collect before computing the base prediction.
599    ///
600    /// The model collects this many target values before initializing the base
601    /// prediction (via `loss.initial_prediction`). Default: 50.
602    pub fn initial_target_count(mut self, count: usize) -> Self {
603        self.config.initial_target_count = count;
604        self
605    }
606
607    /// Set the half-life for exponential leaf decay (in samples per leaf).
608    ///
609    /// After `n` samples, a leaf's accumulated statistics have half the weight
610    /// of the most recent sample. Enables continuous adaptation to concept drift.
611    pub fn leaf_half_life(mut self, n: usize) -> Self {
612        self.config.leaf_half_life = Some(n);
613        self
614    }
615
616    /// Set the maximum samples a single tree processes before proactive replacement.
617    ///
618    /// After `n` samples, the tree is replaced regardless of drift detector state.
619    pub fn max_tree_samples(mut self, n: u64) -> Self {
620        self.config.max_tree_samples = Some(n);
621        self
622    }
623
624    /// Set the split re-evaluation interval for max-depth leaves.
625    ///
626    /// Every `n` samples per leaf, max-depth leaves re-evaluate whether a split
627    /// would improve them. Inspired by EFDT (Manapragada et al. 2018).
628    pub fn split_reeval_interval(mut self, n: usize) -> Self {
629        self.config.split_reeval_interval = Some(n);
630        self
631    }
632
633    /// Set human-readable feature names.
634    ///
635    /// Enables named feature importances and named training input.
636    /// Names must be unique; validated at [`build()`](Self::build).
637    pub fn feature_names(mut self, names: Vec<String>) -> Self {
638        self.config.feature_names = Some(names);
639        self
640    }
641
642    /// Set per-feature type declarations.
643    ///
644    /// Declares which features are categorical vs continuous. Categorical features
645    /// use one-bin-per-category binning and Fisher optimal binary partitioning.
646    /// Supports up to 64 distinct category values per categorical feature.
647    pub fn feature_types(mut self, types: Vec<FeatureType>) -> Self {
648        self.config.feature_types = Some(types);
649        self
650    }
651
652    /// Set per-leaf gradient clipping threshold (in standard deviations).
653    ///
654    /// Each leaf tracks an EWMA of gradient mean and variance. Gradients
655    /// exceeding `mean ± sigma * n` are clamped. Prevents outlier labels
656    /// from corrupting streaming model stability.
657    ///
658    /// Typical value: 3.0 (3-sigma clipping).
659    pub fn gradient_clip_sigma(mut self, sigma: f64) -> Self {
660        self.config.gradient_clip_sigma = Some(sigma);
661        self
662    }
663
664    /// Set per-feature monotonic constraints.
665    ///
666    /// `+1` = non-decreasing, `-1` = non-increasing, `0` = unconstrained.
667    /// Candidate splits violating monotonicity are rejected during tree growth.
668    pub fn monotone_constraints(mut self, constraints: Vec<i8>) -> Self {
669        self.config.monotone_constraints = Some(constraints);
670        self
671    }
672
673    /// Enable quality-based tree pruning with the given EWMA smoothing factor.
674    ///
675    /// Trees whose marginal contribution drops below the threshold for
676    /// `patience` consecutive samples are replaced with fresh trees.
677    /// Suggested alpha: 0.01.
678    pub fn quality_prune_alpha(mut self, alpha: f64) -> Self {
679        self.config.quality_prune_alpha = Some(alpha);
680        self
681    }
682
683    /// Set the minimum contribution threshold for quality-based pruning.
684    ///
685    /// Default: 1e-6. Only relevant when `quality_prune_alpha` is set.
686    pub fn quality_prune_threshold(mut self, threshold: f64) -> Self {
687        self.config.quality_prune_threshold = threshold;
688        self
689    }
690
691    /// Set the patience (consecutive low-contribution samples) before pruning.
692    ///
693    /// Default: 500. Only relevant when `quality_prune_alpha` is set.
694    pub fn quality_prune_patience(mut self, patience: u64) -> Self {
695        self.config.quality_prune_patience = patience;
696        self
697    }
698
699    /// Enable error-weighted sample importance with the given EWMA smoothing factor.
700    ///
701    /// Samples the model predicted poorly get higher effective weight.
702    /// Suggested alpha: 0.01.
703    pub fn error_weight_alpha(mut self, alpha: f64) -> Self {
704        self.config.error_weight_alpha = Some(alpha);
705        self
706    }
707
708    /// Enable σ-modulated learning rate for distributional models.
709    ///
710    /// Scales the location (μ) learning rate by `current_sigma / rolling_sigma_mean`,
711    /// so the model adapts faster during high-uncertainty regimes and conserves
712    /// during stable periods. Only affects [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
713    ///
714    /// By default uses empirical σ (EWMA of squared errors).  Set
715    /// [`scale_mode(ScaleMode::TreeChain)`](Self::scale_mode) for feature-conditional σ.
716    pub fn uncertainty_modulated_lr(mut self, enabled: bool) -> Self {
717        self.config.uncertainty_modulated_lr = enabled;
718        self
719    }
720
721    /// Set the scale estimation mode for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
722    ///
723    /// - [`Empirical`](ScaleMode::Empirical): EWMA of squared prediction errors (default, recommended).
724    /// - [`TreeChain`](ScaleMode::TreeChain): dual-chain NGBoost with scale tree ensemble.
725    pub fn scale_mode(mut self, mode: ScaleMode) -> Self {
726        self.config.scale_mode = mode;
727        self
728    }
729
730    /// EWMA alpha for empirical σ. Controls adaptation speed. Default `0.01`.
731    ///
732    /// Only used when `scale_mode` is [`Empirical`](ScaleMode::Empirical).
733    pub fn empirical_sigma_alpha(mut self, alpha: f64) -> Self {
734        self.config.empirical_sigma_alpha = alpha;
735        self
736    }
737
738    /// Set the maximum absolute leaf output value.
739    ///
740    /// Clamps leaf predictions to `[-max, max]`, breaking feedback loops
741    /// that cause prediction explosions.
742    pub fn max_leaf_output(mut self, max: f64) -> Self {
743        self.config.max_leaf_output = Some(max);
744        self
745    }
746
747    /// Set per-leaf adaptive output bound (sigma multiplier).
748    ///
749    /// Each leaf tracks EWMA of its own output weight and clamps to
750    /// `|output_mean| + k * output_std`. Self-calibrating per-leaf.
751    /// Recommended: use with `leaf_half_life` for streaming scenarios.
752    pub fn adaptive_leaf_bound(mut self, k: f64) -> Self {
753        self.config.adaptive_leaf_bound = Some(k);
754        self
755    }
756
757    /// Set the minimum hessian sum for leaf output.
758    ///
759    /// Fresh leaves with `hess_sum < min_h` return 0.0, preventing
760    /// post-replacement spikes.
761    pub fn min_hessian_sum(mut self, min_h: f64) -> Self {
762        self.config.min_hessian_sum = Some(min_h);
763        self
764    }
765
766    /// Set the Huber loss delta multiplier for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
767    ///
768    /// When set, location gradients use Huber loss with adaptive
769    /// `delta = k * empirical_sigma`. Standard value: `1.345` (95% Gaussian efficiency).
770    pub fn huber_k(mut self, k: f64) -> Self {
771        self.config.huber_k = Some(k);
772        self
773    }
774
775    /// Enable graduated tree handoff with the given shadow warmup samples.
776    ///
777    /// Spawns an always-on shadow tree that trains alongside the active tree.
778    /// After `warmup` samples, the shadow begins contributing to predictions
779    /// via graduated blending. Eliminates prediction dips during tree replacement.
780    pub fn shadow_warmup(mut self, warmup: usize) -> Self {
781        self.config.shadow_warmup = Some(warmup);
782        self
783    }
784
785    /// Set the leaf prediction model type.
786    ///
787    /// [`LeafModelType::Linear`] is recommended for low-depth configurations
788    /// (depth 2--4) where per-leaf linear models reduce approximation error.
789    ///
790    /// [`LeafModelType::Adaptive`] automatically selects between closed-form and
791    /// a trainable model per leaf, using the Hoeffding bound for promotion.
792    pub fn leaf_model_type(mut self, lmt: LeafModelType) -> Self {
793        self.config.leaf_model_type = lmt;
794        self
795    }
796
797    /// Set the packed cache refresh interval for distributional models.
798    ///
799    /// When non-zero, [`DistributionalSGBT`](super::distributional::DistributionalSGBT)
800    /// maintains a packed f32 cache refreshed every `interval` training samples.
801    /// `0` (default) disables the cache.
802    pub fn packed_refresh_interval(mut self, interval: u64) -> Self {
803        self.config.packed_refresh_interval = interval;
804        self
805    }
806
807    /// Validate and build the configuration.
808    ///
809    /// # Errors
810    ///
811    /// Returns [`InvalidConfig`](crate::IrithyllError::InvalidConfig) with a structured
812    /// [`ConfigError`] if any parameter is out of its valid range.
813    pub fn build(self) -> Result<SGBTConfig> {
814        let c = &self.config;
815
816        // -- Ensemble-level parameters --
817        if c.n_steps == 0 {
818            return Err(ConfigError::out_of_range("n_steps", "must be > 0", c.n_steps).into());
819        }
820        if c.learning_rate <= 0.0 || c.learning_rate > 1.0 {
821            return Err(ConfigError::out_of_range(
822                "learning_rate",
823                "must be in (0, 1]",
824                c.learning_rate,
825            )
826            .into());
827        }
828        if c.feature_subsample_rate <= 0.0 || c.feature_subsample_rate > 1.0 {
829            return Err(ConfigError::out_of_range(
830                "feature_subsample_rate",
831                "must be in (0, 1]",
832                c.feature_subsample_rate,
833            )
834            .into());
835        }
836
837        // -- Tree-level parameters --
838        if c.max_depth == 0 {
839            return Err(ConfigError::out_of_range("max_depth", "must be > 0", c.max_depth).into());
840        }
841        if c.n_bins < 2 {
842            return Err(ConfigError::out_of_range("n_bins", "must be >= 2", c.n_bins).into());
843        }
844        if c.lambda < 0.0 {
845            return Err(ConfigError::out_of_range("lambda", "must be >= 0", c.lambda).into());
846        }
847        if c.gamma < 0.0 {
848            return Err(ConfigError::out_of_range("gamma", "must be >= 0", c.gamma).into());
849        }
850        if c.grace_period == 0 {
851            return Err(
852                ConfigError::out_of_range("grace_period", "must be > 0", c.grace_period).into(),
853            );
854        }
855        if c.delta <= 0.0 || c.delta >= 1.0 {
856            return Err(ConfigError::out_of_range("delta", "must be in (0, 1)", c.delta).into());
857        }
858
859        if c.initial_target_count == 0 {
860            return Err(ConfigError::out_of_range(
861                "initial_target_count",
862                "must be > 0",
863                c.initial_target_count,
864            )
865            .into());
866        }
867
868        // -- Streaming adaptation parameters --
869        if let Some(hl) = c.leaf_half_life {
870            if hl == 0 {
871                return Err(ConfigError::out_of_range("leaf_half_life", "must be >= 1", hl).into());
872            }
873        }
874        if let Some(max) = c.max_tree_samples {
875            if max < 100 {
876                return Err(
877                    ConfigError::out_of_range("max_tree_samples", "must be >= 100", max).into(),
878                );
879            }
880        }
881        if let Some(interval) = c.split_reeval_interval {
882            if interval < c.grace_period {
883                return Err(ConfigError::invalid(
884                    "split_reeval_interval",
885                    format!(
886                        "must be >= grace_period ({}), got {}",
887                        c.grace_period, interval
888                    ),
889                )
890                .into());
891            }
892        }
893
894        // -- Feature names --
895        if let Some(ref names) = c.feature_names {
896            // O(n^2) duplicate check — names list is small so no HashSet needed
897            for (i, name) in names.iter().enumerate() {
898                for prev in &names[..i] {
899                    if name == prev {
900                        return Err(ConfigError::invalid(
901                            "feature_names",
902                            format!("duplicate feature name: '{}'", name),
903                        )
904                        .into());
905                    }
906                }
907            }
908        }
909
910        // -- Feature types --
911        if let Some(ref types) = c.feature_types {
912            if let Some(ref names) = c.feature_names {
913                if !names.is_empty() && !types.is_empty() && names.len() != types.len() {
914                    return Err(ConfigError::invalid(
915                        "feature_types",
916                        format!(
917                            "length ({}) must match feature_names length ({})",
918                            types.len(),
919                            names.len()
920                        ),
921                    )
922                    .into());
923                }
924            }
925        }
926
927        // -- Gradient clipping --
928        if let Some(sigma) = c.gradient_clip_sigma {
929            if sigma <= 0.0 {
930                return Err(
931                    ConfigError::out_of_range("gradient_clip_sigma", "must be > 0", sigma).into(),
932                );
933            }
934        }
935
936        // -- Monotonic constraints --
937        if let Some(ref mc) = c.monotone_constraints {
938            for (i, &v) in mc.iter().enumerate() {
939                if v != -1 && v != 0 && v != 1 {
940                    return Err(ConfigError::invalid(
941                        "monotone_constraints",
942                        format!("feature {}: must be -1, 0, or +1, got {}", i, v),
943                    )
944                    .into());
945                }
946            }
947        }
948
949        // -- Leaf output clamping --
950        if let Some(max) = c.max_leaf_output {
951            if max <= 0.0 {
952                return Err(
953                    ConfigError::out_of_range("max_leaf_output", "must be > 0", max).into(),
954                );
955            }
956        }
957
958        // -- Per-leaf adaptive output bound --
959        if let Some(k) = c.adaptive_leaf_bound {
960            if k <= 0.0 {
961                return Err(
962                    ConfigError::out_of_range("adaptive_leaf_bound", "must be > 0", k).into(),
963                );
964            }
965        }
966
967        // -- Minimum hessian sum --
968        if let Some(min_h) = c.min_hessian_sum {
969            if min_h <= 0.0 {
970                return Err(
971                    ConfigError::out_of_range("min_hessian_sum", "must be > 0", min_h).into(),
972                );
973            }
974        }
975
976        // -- Huber loss multiplier --
977        if let Some(k) = c.huber_k {
978            if k <= 0.0 {
979                return Err(ConfigError::out_of_range("huber_k", "must be > 0", k).into());
980            }
981        }
982
983        // -- Shadow warmup --
984        if let Some(warmup) = c.shadow_warmup {
985            if warmup == 0 {
986                return Err(ConfigError::out_of_range(
987                    "shadow_warmup",
988                    "must be > 0",
989                    warmup as f64,
990                )
991                .into());
992            }
993        }
994
995        // -- Quality-based pruning parameters --
996        if let Some(alpha) = c.quality_prune_alpha {
997            if alpha <= 0.0 || alpha >= 1.0 {
998                return Err(ConfigError::out_of_range(
999                    "quality_prune_alpha",
1000                    "must be in (0, 1)",
1001                    alpha,
1002                )
1003                .into());
1004            }
1005        }
1006        if c.quality_prune_threshold <= 0.0 {
1007            return Err(ConfigError::out_of_range(
1008                "quality_prune_threshold",
1009                "must be > 0",
1010                c.quality_prune_threshold,
1011            )
1012            .into());
1013        }
1014        if c.quality_prune_patience == 0 {
1015            return Err(ConfigError::out_of_range(
1016                "quality_prune_patience",
1017                "must be > 0",
1018                c.quality_prune_patience,
1019            )
1020            .into());
1021        }
1022
1023        // -- Error-weighted sample importance --
1024        if let Some(alpha) = c.error_weight_alpha {
1025            if alpha <= 0.0 || alpha >= 1.0 {
1026                return Err(ConfigError::out_of_range(
1027                    "error_weight_alpha",
1028                    "must be in (0, 1)",
1029                    alpha,
1030                )
1031                .into());
1032            }
1033        }
1034
1035        // -- Drift detector parameters --
1036        match &c.drift_detector {
1037            DriftDetectorType::PageHinkley { delta, lambda } => {
1038                if *delta <= 0.0 {
1039                    return Err(ConfigError::out_of_range(
1040                        "drift_detector.PageHinkley.delta",
1041                        "must be > 0",
1042                        delta,
1043                    )
1044                    .into());
1045                }
1046                if *lambda <= 0.0 {
1047                    return Err(ConfigError::out_of_range(
1048                        "drift_detector.PageHinkley.lambda",
1049                        "must be > 0",
1050                        lambda,
1051                    )
1052                    .into());
1053                }
1054            }
1055            DriftDetectorType::Adwin { delta } => {
1056                if *delta <= 0.0 || *delta >= 1.0 {
1057                    return Err(ConfigError::out_of_range(
1058                        "drift_detector.Adwin.delta",
1059                        "must be in (0, 1)",
1060                        delta,
1061                    )
1062                    .into());
1063                }
1064            }
1065            DriftDetectorType::Ddm {
1066                warning_level,
1067                drift_level,
1068                min_instances,
1069            } => {
1070                if *warning_level <= 0.0 {
1071                    return Err(ConfigError::out_of_range(
1072                        "drift_detector.Ddm.warning_level",
1073                        "must be > 0",
1074                        warning_level,
1075                    )
1076                    .into());
1077                }
1078                if *drift_level <= 0.0 {
1079                    return Err(ConfigError::out_of_range(
1080                        "drift_detector.Ddm.drift_level",
1081                        "must be > 0",
1082                        drift_level,
1083                    )
1084                    .into());
1085                }
1086                if *drift_level <= *warning_level {
1087                    return Err(ConfigError::invalid(
1088                        "drift_detector.Ddm.drift_level",
1089                        format!(
1090                            "must be > warning_level ({}), got {}",
1091                            warning_level, drift_level
1092                        ),
1093                    )
1094                    .into());
1095                }
1096                if *min_instances == 0 {
1097                    return Err(ConfigError::out_of_range(
1098                        "drift_detector.Ddm.min_instances",
1099                        "must be > 0",
1100                        min_instances,
1101                    )
1102                    .into());
1103                }
1104            }
1105        }
1106
1107        // -- Variant parameters --
1108        match &c.variant {
1109            SGBTVariant::Standard => {} // no extra validation
1110            SGBTVariant::Skip { k } => {
1111                if *k == 0 {
1112                    return Err(
1113                        ConfigError::out_of_range("variant.Skip.k", "must be > 0", k).into(),
1114                    );
1115                }
1116            }
1117            SGBTVariant::MultipleIterations { multiplier } => {
1118                if *multiplier <= 0.0 {
1119                    return Err(ConfigError::out_of_range(
1120                        "variant.MultipleIterations.multiplier",
1121                        "must be > 0",
1122                        multiplier,
1123                    )
1124                    .into());
1125                }
1126            }
1127        }
1128
1129        Ok(self.config)
1130    }
1131}
1132
1133// ---------------------------------------------------------------------------
1134// Display impls
1135// ---------------------------------------------------------------------------
1136
1137impl core::fmt::Display for DriftDetectorType {
1138    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1139        match self {
1140            Self::PageHinkley { delta, lambda } => {
1141                write!(f, "PageHinkley(delta={}, lambda={})", delta, lambda)
1142            }
1143            Self::Adwin { delta } => write!(f, "Adwin(delta={})", delta),
1144            Self::Ddm {
1145                warning_level,
1146                drift_level,
1147                min_instances,
1148            } => write!(
1149                f,
1150                "Ddm(warning={}, drift={}, min_instances={})",
1151                warning_level, drift_level, min_instances
1152            ),
1153        }
1154    }
1155}
1156
1157impl core::fmt::Display for SGBTConfig {
1158    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1159        write!(
1160            f,
1161            "SGBTConfig {{ steps={}, lr={}, depth={}, bins={}, variant={}, drift={} }}",
1162            self.n_steps,
1163            self.learning_rate,
1164            self.max_depth,
1165            self.n_bins,
1166            self.variant,
1167            self.drift_detector,
1168        )
1169    }
1170}
1171
1172// ---------------------------------------------------------------------------
1173// Tests
1174// ---------------------------------------------------------------------------
1175
1176#[cfg(test)]
1177mod tests {
1178    use super::*;
1179    use alloc::format;
1180    use alloc::vec;
1181
1182    // ------------------------------------------------------------------
1183    // 1. Default config values are correct
1184    // ------------------------------------------------------------------
1185    #[test]
1186    fn default_config_values() {
1187        let cfg = SGBTConfig::default();
1188        assert_eq!(cfg.n_steps, 100);
1189        assert!((cfg.learning_rate - 0.0125).abs() < f64::EPSILON);
1190        assert!((cfg.feature_subsample_rate - 0.75).abs() < f64::EPSILON);
1191        assert_eq!(cfg.max_depth, 6);
1192        assert_eq!(cfg.n_bins, 64);
1193        assert!((cfg.lambda - 1.0).abs() < f64::EPSILON);
1194        assert!((cfg.gamma - 0.0).abs() < f64::EPSILON);
1195        assert_eq!(cfg.grace_period, 200);
1196        assert!((cfg.delta - 1e-7).abs() < f64::EPSILON);
1197        assert_eq!(cfg.variant, SGBTVariant::Standard);
1198    }
1199
1200    // ------------------------------------------------------------------
1201    // 2. Builder chain works
1202    // ------------------------------------------------------------------
1203    #[test]
1204    fn builder_chain() {
1205        let cfg = SGBTConfig::builder()
1206            .n_steps(50)
1207            .learning_rate(0.1)
1208            .feature_subsample_rate(0.5)
1209            .max_depth(10)
1210            .n_bins(128)
1211            .lambda(0.5)
1212            .gamma(0.1)
1213            .grace_period(500)
1214            .delta(1e-3)
1215            .drift_detector(DriftDetectorType::Adwin { delta: 0.01 })
1216            .variant(SGBTVariant::Skip { k: 5 })
1217            .build()
1218            .expect("valid config");
1219
1220        assert_eq!(cfg.n_steps, 50);
1221        assert!((cfg.learning_rate - 0.1).abs() < f64::EPSILON);
1222        assert!((cfg.feature_subsample_rate - 0.5).abs() < f64::EPSILON);
1223        assert_eq!(cfg.max_depth, 10);
1224        assert_eq!(cfg.n_bins, 128);
1225        assert!((cfg.lambda - 0.5).abs() < f64::EPSILON);
1226        assert!((cfg.gamma - 0.1).abs() < f64::EPSILON);
1227        assert_eq!(cfg.grace_period, 500);
1228        assert!((cfg.delta - 1e-3).abs() < f64::EPSILON);
1229        assert_eq!(cfg.variant, SGBTVariant::Skip { k: 5 });
1230
1231        // Verify drift detector type is Adwin.
1232        match &cfg.drift_detector {
1233            DriftDetectorType::Adwin { delta } => {
1234                assert!((delta - 0.01).abs() < f64::EPSILON);
1235            }
1236            _ => panic!("expected Adwin drift detector"),
1237        }
1238    }
1239
1240    // ------------------------------------------------------------------
1241    // 3. Validation rejects invalid values
1242    // ------------------------------------------------------------------
1243    #[test]
1244    fn validation_rejects_n_steps_zero() {
1245        let result = SGBTConfig::builder().n_steps(0).build();
1246        assert!(result.is_err());
1247        let msg = format!("{}", result.unwrap_err());
1248        assert!(msg.contains("n_steps"));
1249    }
1250
1251    #[test]
1252    fn validation_rejects_learning_rate_zero() {
1253        let result = SGBTConfig::builder().learning_rate(0.0).build();
1254        assert!(result.is_err());
1255        let msg = format!("{}", result.unwrap_err());
1256        assert!(msg.contains("learning_rate"));
1257    }
1258
1259    #[test]
1260    fn validation_rejects_learning_rate_above_one() {
1261        let result = SGBTConfig::builder().learning_rate(1.5).build();
1262        assert!(result.is_err());
1263    }
1264
1265    #[test]
1266    fn validation_accepts_learning_rate_one() {
1267        let result = SGBTConfig::builder().learning_rate(1.0).build();
1268        assert!(result.is_ok());
1269    }
1270
1271    #[test]
1272    fn validation_rejects_negative_learning_rate() {
1273        let result = SGBTConfig::builder().learning_rate(-0.1).build();
1274        assert!(result.is_err());
1275    }
1276
1277    #[test]
1278    fn validation_rejects_feature_subsample_zero() {
1279        let result = SGBTConfig::builder().feature_subsample_rate(0.0).build();
1280        assert!(result.is_err());
1281    }
1282
1283    #[test]
1284    fn validation_rejects_feature_subsample_above_one() {
1285        let result = SGBTConfig::builder().feature_subsample_rate(1.01).build();
1286        assert!(result.is_err());
1287    }
1288
1289    #[test]
1290    fn validation_rejects_max_depth_zero() {
1291        let result = SGBTConfig::builder().max_depth(0).build();
1292        assert!(result.is_err());
1293    }
1294
1295    #[test]
1296    fn validation_rejects_n_bins_one() {
1297        let result = SGBTConfig::builder().n_bins(1).build();
1298        assert!(result.is_err());
1299    }
1300
1301    #[test]
1302    fn validation_rejects_negative_lambda() {
1303        let result = SGBTConfig::builder().lambda(-0.1).build();
1304        assert!(result.is_err());
1305    }
1306
1307    #[test]
1308    fn validation_accepts_zero_lambda() {
1309        let result = SGBTConfig::builder().lambda(0.0).build();
1310        assert!(result.is_ok());
1311    }
1312
1313    #[test]
1314    fn validation_rejects_negative_gamma() {
1315        let result = SGBTConfig::builder().gamma(-0.1).build();
1316        assert!(result.is_err());
1317    }
1318
1319    #[test]
1320    fn validation_rejects_grace_period_zero() {
1321        let result = SGBTConfig::builder().grace_period(0).build();
1322        assert!(result.is_err());
1323    }
1324
1325    #[test]
1326    fn validation_rejects_delta_zero() {
1327        let result = SGBTConfig::builder().delta(0.0).build();
1328        assert!(result.is_err());
1329    }
1330
1331    #[test]
1332    fn validation_rejects_delta_one() {
1333        let result = SGBTConfig::builder().delta(1.0).build();
1334        assert!(result.is_err());
1335    }
1336
1337    // ------------------------------------------------------------------
1338    // 3b. Drift detector parameter validation
1339    // ------------------------------------------------------------------
1340    #[test]
1341    fn validation_rejects_pht_negative_delta() {
1342        let result = SGBTConfig::builder()
1343            .drift_detector(DriftDetectorType::PageHinkley {
1344                delta: -1.0,
1345                lambda: 50.0,
1346            })
1347            .build();
1348        assert!(result.is_err());
1349        let msg = format!("{}", result.unwrap_err());
1350        assert!(msg.contains("PageHinkley"));
1351    }
1352
1353    #[test]
1354    fn validation_rejects_pht_zero_lambda() {
1355        let result = SGBTConfig::builder()
1356            .drift_detector(DriftDetectorType::PageHinkley {
1357                delta: 0.005,
1358                lambda: 0.0,
1359            })
1360            .build();
1361        assert!(result.is_err());
1362    }
1363
1364    #[test]
1365    fn validation_rejects_adwin_delta_out_of_range() {
1366        let result = SGBTConfig::builder()
1367            .drift_detector(DriftDetectorType::Adwin { delta: 0.0 })
1368            .build();
1369        assert!(result.is_err());
1370
1371        let result = SGBTConfig::builder()
1372            .drift_detector(DriftDetectorType::Adwin { delta: 1.0 })
1373            .build();
1374        assert!(result.is_err());
1375    }
1376
1377    #[test]
1378    fn validation_rejects_ddm_warning_above_drift() {
1379        let result = SGBTConfig::builder()
1380            .drift_detector(DriftDetectorType::Ddm {
1381                warning_level: 3.0,
1382                drift_level: 2.0,
1383                min_instances: 30,
1384            })
1385            .build();
1386        assert!(result.is_err());
1387        let msg = format!("{}", result.unwrap_err());
1388        assert!(msg.contains("drift_level"));
1389        assert!(msg.contains("must be > warning_level"));
1390    }
1391
1392    #[test]
1393    fn validation_rejects_ddm_equal_levels() {
1394        let result = SGBTConfig::builder()
1395            .drift_detector(DriftDetectorType::Ddm {
1396                warning_level: 2.0,
1397                drift_level: 2.0,
1398                min_instances: 30,
1399            })
1400            .build();
1401        assert!(result.is_err());
1402    }
1403
1404    #[test]
1405    fn validation_rejects_ddm_zero_min_instances() {
1406        let result = SGBTConfig::builder()
1407            .drift_detector(DriftDetectorType::Ddm {
1408                warning_level: 2.0,
1409                drift_level: 3.0,
1410                min_instances: 0,
1411            })
1412            .build();
1413        assert!(result.is_err());
1414    }
1415
1416    #[test]
1417    fn validation_rejects_ddm_zero_warning_level() {
1418        let result = SGBTConfig::builder()
1419            .drift_detector(DriftDetectorType::Ddm {
1420                warning_level: 0.0,
1421                drift_level: 3.0,
1422                min_instances: 30,
1423            })
1424            .build();
1425        assert!(result.is_err());
1426    }
1427
1428    // ------------------------------------------------------------------
1429    // 3c. Variant parameter validation
1430    // ------------------------------------------------------------------
1431    #[test]
1432    fn validation_rejects_skip_k_zero() {
1433        let result = SGBTConfig::builder()
1434            .variant(SGBTVariant::Skip { k: 0 })
1435            .build();
1436        assert!(result.is_err());
1437        let msg = format!("{}", result.unwrap_err());
1438        assert!(msg.contains("Skip"));
1439    }
1440
1441    #[test]
1442    fn validation_rejects_mi_zero_multiplier() {
1443        let result = SGBTConfig::builder()
1444            .variant(SGBTVariant::MultipleIterations { multiplier: 0.0 })
1445            .build();
1446        assert!(result.is_err());
1447    }
1448
1449    #[test]
1450    fn validation_rejects_mi_negative_multiplier() {
1451        let result = SGBTConfig::builder()
1452            .variant(SGBTVariant::MultipleIterations { multiplier: -1.0 })
1453            .build();
1454        assert!(result.is_err());
1455    }
1456
1457    #[test]
1458    fn validation_accepts_standard_variant() {
1459        let result = SGBTConfig::builder().variant(SGBTVariant::Standard).build();
1460        assert!(result.is_ok());
1461    }
1462
1463    // ------------------------------------------------------------------
1464    // 4. DriftDetectorType creates correct detector types
1465    // ------------------------------------------------------------------
1466    #[test]
1467    fn drift_detector_type_creates_page_hinkley() {
1468        let dt = DriftDetectorType::PageHinkley {
1469            delta: 0.01,
1470            lambda: 100.0,
1471        };
1472        let mut detector = dt.create();
1473
1474        // Should start with zero mean.
1475        assert_eq!(detector.estimated_mean(), 0.0);
1476
1477        // Feed a stable value -- should not drift.
1478        let signal = detector.update(1.0);
1479        assert_ne!(signal, crate::drift::DriftSignal::Drift);
1480    }
1481
1482    #[test]
1483    fn drift_detector_type_creates_adwin() {
1484        let dt = DriftDetectorType::Adwin { delta: 0.05 };
1485        let mut detector = dt.create();
1486
1487        assert_eq!(detector.estimated_mean(), 0.0);
1488        let signal = detector.update(1.0);
1489        assert_ne!(signal, crate::drift::DriftSignal::Drift);
1490    }
1491
1492    #[test]
1493    fn drift_detector_type_creates_ddm() {
1494        let dt = DriftDetectorType::Ddm {
1495            warning_level: 2.0,
1496            drift_level: 3.0,
1497            min_instances: 30,
1498        };
1499        let mut detector = dt.create();
1500
1501        assert_eq!(detector.estimated_mean(), 0.0);
1502        let signal = detector.update(0.1);
1503        assert_eq!(signal, crate::drift::DriftSignal::Stable);
1504    }
1505
1506    // ------------------------------------------------------------------
1507    // 5. Default drift detector is PageHinkley
1508    // ------------------------------------------------------------------
1509    #[test]
1510    fn default_drift_detector_is_page_hinkley() {
1511        let dt = DriftDetectorType::default();
1512        match dt {
1513            DriftDetectorType::PageHinkley { delta, lambda } => {
1514                assert!((delta - 0.005).abs() < f64::EPSILON);
1515                assert!((lambda - 50.0).abs() < f64::EPSILON);
1516            }
1517            _ => panic!("expected default DriftDetectorType to be PageHinkley"),
1518        }
1519    }
1520
1521    // ------------------------------------------------------------------
1522    // 6. Default config builds without error
1523    // ------------------------------------------------------------------
1524    #[test]
1525    fn default_config_builds() {
1526        let result = SGBTConfig::builder().build();
1527        assert!(result.is_ok());
1528    }
1529
1530    // ------------------------------------------------------------------
1531    // 7. Config clone preserves all fields
1532    // ------------------------------------------------------------------
1533    #[test]
1534    fn config_clone_preserves_fields() {
1535        let original = SGBTConfig::builder()
1536            .n_steps(50)
1537            .learning_rate(0.05)
1538            .variant(SGBTVariant::MultipleIterations { multiplier: 5.0 })
1539            .build()
1540            .unwrap();
1541
1542        let cloned = original.clone();
1543
1544        assert_eq!(cloned.n_steps, 50);
1545        assert!((cloned.learning_rate - 0.05).abs() < f64::EPSILON);
1546        assert_eq!(
1547            cloned.variant,
1548            SGBTVariant::MultipleIterations { multiplier: 5.0 }
1549        );
1550    }
1551
1552    // ------------------------------------------------------------------
1553    // 8. All three DDM valid configs accepted
1554    // ------------------------------------------------------------------
1555    #[test]
1556    fn valid_ddm_config_accepted() {
1557        let result = SGBTConfig::builder()
1558            .drift_detector(DriftDetectorType::Ddm {
1559                warning_level: 1.5,
1560                drift_level: 2.5,
1561                min_instances: 10,
1562            })
1563            .build();
1564        assert!(result.is_ok());
1565    }
1566
1567    // ------------------------------------------------------------------
1568    // 9. Created detectors are functional (round-trip test)
1569    // ------------------------------------------------------------------
1570    #[test]
1571    fn created_detectors_are_functional() {
1572        // Create a detector, feed it data, verify it can detect drift.
1573        let dt = DriftDetectorType::PageHinkley {
1574            delta: 0.005,
1575            lambda: 50.0,
1576        };
1577        let mut detector = dt.create();
1578
1579        // Feed 500 stable values.
1580        for _ in 0..500 {
1581            detector.update(1.0);
1582        }
1583
1584        // Feed shifted values -- should eventually drift.
1585        let mut drifted = false;
1586        for _ in 0..500 {
1587            if detector.update(10.0) == crate::drift::DriftSignal::Drift {
1588                drifted = true;
1589                break;
1590            }
1591        }
1592        assert!(
1593            drifted,
1594            "detector created from DriftDetectorType should be functional"
1595        );
1596    }
1597
1598    // ------------------------------------------------------------------
1599    // 10. Boundary acceptance: n_bins=2 is the minimum valid
1600    // ------------------------------------------------------------------
1601    #[test]
1602    fn boundary_n_bins_two_accepted() {
1603        let result = SGBTConfig::builder().n_bins(2).build();
1604        assert!(result.is_ok());
1605    }
1606
1607    // ------------------------------------------------------------------
1608    // 11. Boundary acceptance: grace_period=1 is valid
1609    // ------------------------------------------------------------------
1610    #[test]
1611    fn boundary_grace_period_one_accepted() {
1612        let result = SGBTConfig::builder().grace_period(1).build();
1613        assert!(result.is_ok());
1614    }
1615
1616    // ------------------------------------------------------------------
1617    // 12. Feature names -- valid config with names
1618    // ------------------------------------------------------------------
1619    #[test]
1620    fn feature_names_accepted() {
1621        let cfg = SGBTConfig::builder()
1622            .feature_names(vec!["price".into(), "volume".into(), "spread".into()])
1623            .build()
1624            .unwrap();
1625        assert_eq!(
1626            cfg.feature_names.as_ref().unwrap(),
1627            &["price", "volume", "spread"]
1628        );
1629    }
1630
1631    // ------------------------------------------------------------------
1632    // 13. Feature names -- duplicate names rejected
1633    // ------------------------------------------------------------------
1634    #[test]
1635    fn feature_names_rejects_duplicates() {
1636        let result = SGBTConfig::builder()
1637            .feature_names(vec!["price".into(), "volume".into(), "price".into()])
1638            .build();
1639        assert!(result.is_err());
1640        let msg = format!("{}", result.unwrap_err());
1641        assert!(msg.contains("duplicate"));
1642    }
1643
1644    // ------------------------------------------------------------------
1645    // 14. Feature names -- serde backward compat (missing field)
1646    //     Requires serde_json which is only in the full irithyll crate.
1647    // ------------------------------------------------------------------
1648
1649    // ------------------------------------------------------------------
1650    // 15. Feature names -- empty vec is valid
1651    // ------------------------------------------------------------------
1652    #[test]
1653    fn feature_names_empty_vec_accepted() {
1654        let cfg = SGBTConfig::builder().feature_names(vec![]).build().unwrap();
1655        assert!(cfg.feature_names.unwrap().is_empty());
1656    }
1657
1658    // ------------------------------------------------------------------
1659    // 16. Adaptive leaf bound -- builder
1660    // ------------------------------------------------------------------
1661    #[test]
1662    fn builder_adaptive_leaf_bound() {
1663        let cfg = SGBTConfig::builder()
1664            .adaptive_leaf_bound(3.0)
1665            .build()
1666            .unwrap();
1667        assert_eq!(cfg.adaptive_leaf_bound, Some(3.0));
1668    }
1669
1670    // ------------------------------------------------------------------
1671    // 17. Adaptive leaf bound -- validation rejects zero
1672    // ------------------------------------------------------------------
1673    #[test]
1674    fn validation_rejects_zero_adaptive_leaf_bound() {
1675        let result = SGBTConfig::builder().adaptive_leaf_bound(0.0).build();
1676        assert!(result.is_err());
1677        let msg = format!("{}", result.unwrap_err());
1678        assert!(
1679            msg.contains("adaptive_leaf_bound"),
1680            "error should mention adaptive_leaf_bound: {}",
1681            msg,
1682        );
1683    }
1684
1685    // ------------------------------------------------------------------
1686    // 18. Adaptive leaf bound -- serde backward compat
1687    //     Requires serde_json which is only in the full irithyll crate.
1688    // ------------------------------------------------------------------
1689}