// irithyll_core/ensemble/config.rs
1//! SGBT configuration with builder pattern and full validation.
2//!
3//! [`SGBTConfig`] holds all hyperparameters for the Streaming Gradient Boosted
4//! Trees ensemble. Use [`SGBTConfig::builder`] for ergonomic construction with
5//! validation on [`build()`](SGBTConfigBuilder::build).
6
7use alloc::boxed::Box;
8use alloc::format;
9use alloc::string::String;
10use alloc::vec::Vec;
11
12use crate::drift::adwin::Adwin;
13use crate::drift::ddm::Ddm;
14use crate::drift::pht::PageHinkleyTest;
15use crate::drift::DriftDetector;
16use crate::ensemble::variants::SGBTVariant;
17use crate::error::{ConfigError, Result};
18use crate::tree::leaf_model::LeafModelType;
19
20// ---------------------------------------------------------------------------
21// FeatureType -- re-exported from irithyll-core
22// ---------------------------------------------------------------------------
23
24pub use crate::feature::FeatureType;
25
26// ---------------------------------------------------------------------------
27// ScaleMode
28// ---------------------------------------------------------------------------
29
/// How [`DistributionalSGBT`](super::distributional::DistributionalSGBT)
/// estimates uncertainty (σ).
///
/// - **`Empirical`** (default): tracks an EWMA of squared prediction errors.
///   `σ = sqrt(ewma_sq_err)`.  Always calibrated, zero tuning, O(1) compute.
///   Use this when σ drives learning-rate modulation (σ high → learn faster).
///
/// - **`TreeChain`**: trains a full second ensemble of Hoeffding trees to predict
///   log(σ) from features (NGBoost-style dual chain).  Gives *feature-conditional*
///   uncertainty but requires strong signal in the scale gradients.
///
/// Consumed via [`SGBTConfig::scale_mode`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ScaleMode {
    /// EWMA of squared prediction errors — always calibrated, O(1).
    #[default]
    Empirical,
    /// Full Hoeffding-tree ensemble for feature-conditional log(σ) prediction.
    TreeChain,
}
49
50// ---------------------------------------------------------------------------
51// DriftDetectorType
52// ---------------------------------------------------------------------------
53
/// Which drift detector to instantiate for each boosting step.
///
/// Each variant stores the detector's configuration parameters (all plain
/// `Copy` numbers) so that fresh instances can be created on demand via
/// [`create`](Self::create) — e.g. when replacing a drifted tree.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum DriftDetectorType {
    /// Page-Hinkley Test with custom delta (magnitude tolerance) and lambda
    /// (detection threshold).
    PageHinkley {
        /// Magnitude tolerance. Default 0.005.
        delta: f64,
        /// Detection threshold. Default 50.0.
        lambda: f64,
    },

    /// ADWIN with custom confidence parameter.
    Adwin {
        /// Confidence (smaller = fewer false positives). Default 0.002.
        delta: f64,
    },

    /// DDM with custom warning/drift levels and minimum warmup instances.
    Ddm {
        /// Warning threshold multiplier. Default 2.0.
        warning_level: f64,
        /// Drift threshold multiplier. Default 3.0.
        drift_level: f64,
        /// Minimum observations before detection activates. Default 30.
        min_instances: u64,
    },
}
86
87impl Default for DriftDetectorType {
88    fn default() -> Self {
89        DriftDetectorType::PageHinkley {
90            delta: 0.005,
91            lambda: 50.0,
92        }
93    }
94}
95
96impl DriftDetectorType {
97    /// Create a new, fresh drift detector from this configuration.
98    pub fn create(&self) -> Box<dyn DriftDetector> {
99        match self {
100            Self::PageHinkley { delta, lambda } => {
101                Box::new(PageHinkleyTest::with_params(*delta, *lambda))
102            }
103            Self::Adwin { delta } => Box::new(Adwin::with_delta(*delta)),
104            Self::Ddm {
105                warning_level,
106                drift_level,
107                min_instances,
108            } => Box::new(Ddm::with_params(
109                *warning_level,
110                *drift_level,
111                *min_instances,
112            )),
113        }
114    }
115}
116
117// ---------------------------------------------------------------------------
118// SGBTConfig
119// ---------------------------------------------------------------------------
120
/// Configuration for the SGBT ensemble.
///
/// All numeric parameters are validated at build time via [`SGBTConfigBuilder`].
///
/// # Defaults
///
/// | Parameter                | Default              |
/// |--------------------------|----------------------|
/// | `n_steps`                | 100                  |
/// | `learning_rate`          | 0.0125               |
/// | `feature_subsample_rate` | 0.75                 |
/// | `max_depth`              | 6                    |
/// | `n_bins`                 | 64                   |
/// | `lambda`                 | 1.0                  |
/// | `gamma`                  | 0.0                  |
/// | `grace_period`           | 200                  |
/// | `delta`                  | 1e-7                 |
/// | `drift_detector`         | PageHinkley(0.005, 50.0) |
/// | `variant`                | Standard             |
/// | `seed`                   | 0xDEAD_BEEF_CAFE_4242 |
/// | `initial_target_count`   | 50                   |
/// | `leaf_half_life`         | None (disabled)      |
/// | `max_tree_samples`       | None (disabled)      |
/// | `split_reeval_interval`  | None (disabled)      |
///
/// Fields not listed above (quality pruning, error weighting, distributional
/// σ options, leaf output bounds, shadow handoff, leaf models, packed cache)
/// default to `None`/disabled/`false`/`0`; see each field's documentation below.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SGBTConfig {
    /// Number of boosting steps (trees in the ensemble). Default 100.
    pub n_steps: usize,
    /// Learning rate (shrinkage). Default 0.0125.
    pub learning_rate: f64,
    /// Fraction of features to subsample per tree. Default 0.75.
    pub feature_subsample_rate: f64,
    /// Maximum tree depth. Default 6.
    pub max_depth: usize,
    /// Number of histogram bins. Default 64.
    pub n_bins: usize,
    /// L2 regularization parameter (lambda). Default 1.0.
    pub lambda: f64,
    /// Minimum split gain (gamma). Default 0.0.
    pub gamma: f64,
    /// Grace period: min samples before evaluating splits. Default 200.
    pub grace_period: usize,
    /// Hoeffding bound confidence (delta). Default 1e-7.
    pub delta: f64,
    /// Drift detector type for tree replacement. Default: PageHinkley.
    pub drift_detector: DriftDetectorType,
    /// SGBT computational variant. Default: Standard.
    pub variant: SGBTVariant,
    /// Random seed for deterministic reproducibility. Default: 0xDEAD_BEEF_CAFE_4242.
    ///
    /// Controls feature subsampling and variant skip/MI stochastic decisions.
    /// Two models with the same seed and same data will produce identical results.
    pub seed: u64,
    /// Number of initial targets to collect before computing the base prediction.
    /// Default: 50.
    pub initial_target_count: usize,

    /// Half-life for exponential leaf decay (in samples per leaf).
    ///
    /// After `leaf_half_life` samples, a leaf's accumulated gradient/hessian
    /// statistics have half the weight of the most recent sample. This causes
    /// the model to continuously adapt to changing data distributions rather
    /// than freezing on early observations.
    ///
    /// `None` (default) disables decay -- traditional monotonic accumulation.
    #[cfg_attr(feature = "serde", serde(default))]
    pub leaf_half_life: Option<usize>,

    /// Maximum samples a single tree processes before proactive replacement.
    ///
    /// After this many samples, the tree is replaced with a fresh one regardless
    /// of drift detector state. Prevents stale tree structure from persisting
    /// when the drift detector is not sensitive enough.
    ///
    /// `None` (default) disables time-based replacement.
    #[cfg_attr(feature = "serde", serde(default))]
    pub max_tree_samples: Option<u64>,

    /// Interval (in samples per leaf) at which max-depth leaves re-evaluate
    /// whether a split would improve them.
    ///
    /// Inspired by EFDT (Manapragada et al. 2018). When a leaf has accumulated
    /// `split_reeval_interval` samples since its last evaluation and has reached
    /// max depth, it re-evaluates whether a split should be performed.
    ///
    /// `None` (default) disables re-evaluation -- max-depth leaves are permanent.
    #[cfg_attr(feature = "serde", serde(default))]
    pub split_reeval_interval: Option<usize>,

    /// Optional human-readable feature names.
    ///
    /// When set, enables [`named_feature_importances`](super::SGBT::named_feature_importances) and
    /// [`train_one_named`](super::SGBT::train_one_named) for production-friendly named access.
    /// Length must match the number of features in training data.
    #[cfg_attr(feature = "serde", serde(default))]
    pub feature_names: Option<Vec<String>>,

    /// Optional per-feature type declarations.
    ///
    /// When set, declares which features are categorical vs continuous.
    /// Categorical features use one-bin-per-category binning and Fisher
    /// optimal binary partitioning for split evaluation.
    /// Length must match the number of features in training data.
    ///
    /// `None` (default) treats all features as continuous.
    #[cfg_attr(feature = "serde", serde(default))]
    pub feature_types: Option<Vec<FeatureType>>,

    /// Gradient clipping threshold in standard deviations per leaf.
    ///
    /// When enabled, each leaf tracks an EWMA of gradient mean and variance.
    /// Incoming gradients that exceed `mean ± sigma * gradient_clip_sigma` are
    /// clamped to the boundary. This prevents outlier samples from corrupting
    /// leaf statistics, which is critical in streaming settings where sudden
    /// label floods can destabilize the model.
    ///
    /// Typical value: 3.0 (3-sigma clipping).
    /// `None` (default) disables gradient clipping.
    #[cfg_attr(feature = "serde", serde(default))]
    pub gradient_clip_sigma: Option<f64>,

    /// Per-feature monotonic constraints.
    ///
    /// Each element specifies the monotonic relationship between a feature and
    /// the prediction:
    /// - `+1`: prediction must be non-decreasing as feature value increases.
    /// - `-1`: prediction must be non-increasing as feature value increases.
    /// - `0`: no constraint (unconstrained).
    ///
    /// During split evaluation, candidate splits that would violate monotonicity
    /// (left child value > right child value for +1 constraints, or vice versa)
    /// are rejected.
    ///
    /// Length must match the number of features in training data.
    /// `None` (default) means no monotonic constraints.
    #[cfg_attr(feature = "serde", serde(default))]
    pub monotone_constraints: Option<Vec<i8>>,

    /// EWMA smoothing factor for quality-based tree pruning.
    ///
    /// When `Some(alpha)`, each boosting step tracks an exponentially weighted
    /// moving average of its marginal contribution to the ensemble. Trees whose
    /// contribution drops below [`quality_prune_threshold`](Self::quality_prune_threshold)
    /// for [`quality_prune_patience`](Self::quality_prune_patience) consecutive
    /// samples are replaced with a fresh tree that can learn the current regime.
    ///
    /// This prevents "dead wood" -- trees from a past regime that no longer
    /// contribute meaningfully to ensemble accuracy.
    ///
    /// `None` (default) disables quality-based pruning.
    /// Suggested value: 0.01.
    #[cfg_attr(feature = "serde", serde(default))]
    pub quality_prune_alpha: Option<f64>,

    /// Minimum contribution threshold for quality-based pruning.
    ///
    /// A tree's EWMA contribution must stay above this value to avoid being
    /// flagged as dead wood. Only used when `quality_prune_alpha` is `Some`.
    ///
    /// Default: 1e-6.
    #[cfg_attr(feature = "serde", serde(default = "default_quality_prune_threshold"))]
    pub quality_prune_threshold: f64,

    /// Consecutive low-contribution samples before a tree is replaced.
    ///
    /// After this many consecutive samples where a tree's EWMA contribution
    /// is below `quality_prune_threshold`, the tree is reset. Only used when
    /// `quality_prune_alpha` is `Some`.
    ///
    /// Default: 500.
    #[cfg_attr(feature = "serde", serde(default = "default_quality_prune_patience"))]
    pub quality_prune_patience: u64,

    /// EWMA smoothing factor for error-weighted sample importance.
    ///
    /// When `Some(alpha)`, samples the model predicted poorly get higher
    /// effective weight during histogram accumulation. The weight is:
    /// `1.0 + |error| / (rolling_mean_error + epsilon)`, capped at 10x.
    ///
    /// This is a streaming version of AdaBoost's reweighting applied at the
    /// gradient level -- learning capacity focuses on hard/novel patterns,
    /// enabling faster adaptation to regime changes.
    ///
    /// `None` (default) disables error weighting.
    /// Suggested value: 0.01.
    #[cfg_attr(feature = "serde", serde(default))]
    pub error_weight_alpha: Option<f64>,

    /// Enable σ-modulated learning rate for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
    ///
    /// When `true`, the **location** (μ) ensemble's learning rate is scaled by
    /// `sigma_ratio = current_sigma / rolling_sigma_mean`, where `rolling_sigma_mean`
    /// is an EWMA of the model's predicted σ (alpha = 0.001).
    ///
    /// This means the model learns μ **faster** when σ is elevated (high uncertainty)
    /// and **slower** when σ is low (confident regime). The scale (σ) ensemble always
    /// trains at the unmodulated base rate to prevent positive feedback loops.
    ///
    /// Default: `false`.
    #[cfg_attr(feature = "serde", serde(default))]
    pub uncertainty_modulated_lr: bool,

    /// How the scale (σ) is estimated in [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
    ///
    /// - [`Empirical`](ScaleMode::Empirical) (default): EWMA of squared prediction
    ///   errors.  `σ = sqrt(ewma_sq_err)`.  Always calibrated, zero tuning, O(1).
    /// - [`TreeChain`](ScaleMode::TreeChain): full dual-chain NGBoost with a
    ///   separate tree ensemble predicting log(σ) from features.
    ///
    /// For σ-modulated learning (`uncertainty_modulated_lr = true`), `Empirical`
    /// is strongly recommended — scale tree gradients are inherently weak and
    /// the trees often fail to split.
    #[cfg_attr(feature = "serde", serde(default))]
    pub scale_mode: ScaleMode,

    /// EWMA smoothing factor for empirical σ estimation.
    ///
    /// Controls the adaptation speed of `σ = sqrt(ewma_sq_err)` when
    /// [`scale_mode`](Self::scale_mode) is [`Empirical`](ScaleMode::Empirical).
    /// Higher values react faster to regime changes but are noisier.
    ///
    /// Default: `0.01` (~100-sample effective window).
    #[cfg_attr(feature = "serde", serde(default = "default_empirical_sigma_alpha"))]
    pub empirical_sigma_alpha: f64,

    /// Maximum absolute leaf output value.
    ///
    /// When `Some(max)`, leaf predictions are clamped to `[-max, max]`.
    /// Prevents runaway leaf weights from causing prediction explosions
    /// in feedback loops. `None` (default) means no clamping.
    #[cfg_attr(feature = "serde", serde(default))]
    pub max_leaf_output: Option<f64>,

    /// Per-leaf adaptive output bound (sigma multiplier).
    ///
    /// When `Some(k)`, each leaf tracks an EWMA of its own output weight and
    /// clamps predictions to `|output_mean| + k * output_std`. The EWMA uses
    /// `leaf_decay_alpha` when `leaf_half_life` is set, otherwise Welford online.
    ///
    /// This is strictly superior to `max_leaf_output` for streaming — the bound
    /// is per-leaf, self-calibrating, and regime-synchronized. A leaf that usually
    /// outputs 0.3 can't suddenly output 2.9 just because it fits in the global clamp.
    ///
    /// Typical value: 3.0 (3-sigma bound).
    /// `None` (default) disables adaptive bounds (falls back to `max_leaf_output`).
    #[cfg_attr(feature = "serde", serde(default))]
    pub adaptive_leaf_bound: Option<f64>,

    /// Minimum hessian sum before a leaf produces non-zero output.
    ///
    /// When `Some(min_h)`, leaves with `hess_sum < min_h` return 0.0.
    /// Prevents post-replacement spikes from fresh leaves with insufficient
    /// samples. `None` (default) means all leaves contribute immediately.
    #[cfg_attr(feature = "serde", serde(default))]
    pub min_hessian_sum: Option<f64>,

    /// Huber loss delta multiplier for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
    ///
    /// When `Some(k)`, the distributional location gradient uses Huber loss
    /// with adaptive `delta = k * empirical_sigma`. This bounds gradients by
    /// construction. Standard value: `1.345` (95% efficiency at Gaussian).
    /// `None` (default) uses squared loss.
    #[cfg_attr(feature = "serde", serde(default))]
    pub huber_k: Option<f64>,

    /// Shadow warmup for graduated tree handoff.
    ///
    /// When `Some(n)`, an always-on shadow (alternate) tree is spawned immediately
    /// alongside every active tree. The shadow trains on the same gradient stream
    /// but does not contribute to predictions until it has seen `n` samples.
    ///
    /// As the active tree ages past 80% of `max_tree_samples`, its prediction
    /// weight linearly decays to 0 at 120%. The shadow's weight ramps from 0 to 1
    /// over `n` samples after warmup. When the active weight reaches 0, the shadow
    /// is promoted and a new shadow is spawned — no cold-start prediction dip.
    ///
    /// Requires `max_tree_samples` to be set for time-based graduated handoff.
    /// Drift-based replacement still uses hard swap (shadow is already warm).
    ///
    /// `None` (default) disables graduated handoff — uses traditional hard swap.
    #[cfg_attr(feature = "serde", serde(default))]
    pub shadow_warmup: Option<usize>,

    /// Leaf prediction model type.
    ///
    /// Controls how each leaf computes its prediction:
    /// - [`ClosedForm`](LeafModelType::ClosedForm) (default): constant leaf weight.
    /// - [`Linear`](LeafModelType::Linear): per-leaf online ridge regression with
    ///   AdaGrad optimization. Optional `decay` for concept drift. Recommended for
    ///   low-depth trees (depth 2--4).
    /// - [`MLP`](LeafModelType::MLP): per-leaf single-hidden-layer neural network.
    ///   Optional `decay` for concept drift.
    /// - [`Adaptive`](LeafModelType::Adaptive): starts as closed-form, auto-promotes
    ///   when the Hoeffding bound confirms a more complex model is better.
    ///
    /// Default: [`ClosedForm`](LeafModelType::ClosedForm).
    #[cfg_attr(feature = "serde", serde(default))]
    pub leaf_model_type: LeafModelType,

    /// Packed cache refresh interval for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
    ///
    /// When non-zero, the distributional model maintains a packed f32 cache of
    /// its location ensemble that is re-exported every `packed_refresh_interval`
    /// training samples. Predictions use the cache for O(1)-per-tree inference
    /// via contiguous memory traversal, falling back to full tree traversal when
    /// the cache is absent or produces non-finite results.
    ///
    /// `0` (default) disables the packed cache.
    #[cfg_attr(feature = "serde", serde(default))]
    pub packed_refresh_interval: u64,
}
433
/// Serde default for [`SGBTConfig::empirical_sigma_alpha`] (0.01).
fn default_empirical_sigma_alpha() -> f64 {
    1e-2
}

/// Serde default for [`SGBTConfig::quality_prune_threshold`] (1e-6).
fn default_quality_prune_threshold() -> f64 {
    1e-6
}

/// Serde default for [`SGBTConfig::quality_prune_patience`] (500 samples).
fn default_quality_prune_patience() -> u64 {
    500
}
445
446impl Default for SGBTConfig {
447    fn default() -> Self {
448        Self {
449            n_steps: 100,
450            learning_rate: 0.0125,
451            feature_subsample_rate: 0.75,
452            max_depth: 6,
453            n_bins: 64,
454            lambda: 1.0,
455            gamma: 0.0,
456            grace_period: 200,
457            delta: 1e-7,
458            drift_detector: DriftDetectorType::default(),
459            variant: SGBTVariant::default(),
460            seed: 0xDEAD_BEEF_CAFE_4242,
461            initial_target_count: 50,
462            leaf_half_life: None,
463            max_tree_samples: None,
464            split_reeval_interval: None,
465            feature_names: None,
466            feature_types: None,
467            gradient_clip_sigma: None,
468            monotone_constraints: None,
469            quality_prune_alpha: None,
470            quality_prune_threshold: 1e-6,
471            quality_prune_patience: 500,
472            error_weight_alpha: None,
473            uncertainty_modulated_lr: false,
474            scale_mode: ScaleMode::default(),
475            empirical_sigma_alpha: 0.01,
476            max_leaf_output: None,
477            adaptive_leaf_bound: None,
478            min_hessian_sum: None,
479            huber_k: None,
480            shadow_warmup: None,
481            leaf_model_type: LeafModelType::default(),
482            packed_refresh_interval: 0,
483        }
484    }
485}
486
487impl SGBTConfig {
488    /// Start building a configuration via the builder pattern.
489    pub fn builder() -> SGBTConfigBuilder {
490        SGBTConfigBuilder::default()
491    }
492}
493
494// ---------------------------------------------------------------------------
495// SGBTConfigBuilder
496// ---------------------------------------------------------------------------
497
/// Builder for [`SGBTConfig`] with validation on [`build()`](Self::build).
///
/// # Example
///
/// ```ignore
/// use irithyll::ensemble::config::{SGBTConfig, DriftDetectorType};
/// use irithyll::ensemble::variants::SGBTVariant;
///
/// let config = SGBTConfig::builder()
///     .n_steps(200)
///     .learning_rate(0.05)
///     .drift_detector(DriftDetectorType::Adwin { delta: 0.01 })
///     .variant(SGBTVariant::Skip { k: 10 })
///     .build()
///     .expect("valid config");
/// ```
#[derive(Debug, Clone, Default)]
pub struct SGBTConfigBuilder {
    // Work-in-progress configuration: starts at `SGBTConfig::default()`
    // and is mutated by the setter methods until `build()` validates it.
    config: SGBTConfig,
}
518
519impl SGBTConfigBuilder {
520    /// Set the number of boosting steps (trees in the ensemble).
521    pub fn n_steps(mut self, n: usize) -> Self {
522        self.config.n_steps = n;
523        self
524    }
525
526    /// Set the learning rate (shrinkage factor).
527    pub fn learning_rate(mut self, lr: f64) -> Self {
528        self.config.learning_rate = lr;
529        self
530    }
531
532    /// Set the fraction of features to subsample per tree.
533    pub fn feature_subsample_rate(mut self, rate: f64) -> Self {
534        self.config.feature_subsample_rate = rate;
535        self
536    }
537
538    /// Set the maximum tree depth.
539    pub fn max_depth(mut self, depth: usize) -> Self {
540        self.config.max_depth = depth;
541        self
542    }
543
544    /// Set the number of histogram bins per feature.
545    pub fn n_bins(mut self, bins: usize) -> Self {
546        self.config.n_bins = bins;
547        self
548    }
549
550    /// Set the L2 regularization parameter (lambda).
551    pub fn lambda(mut self, l: f64) -> Self {
552        self.config.lambda = l;
553        self
554    }
555
556    /// Set the minimum split gain (gamma).
557    pub fn gamma(mut self, g: f64) -> Self {
558        self.config.gamma = g;
559        self
560    }
561
562    /// Set the grace period (minimum samples before evaluating splits).
563    pub fn grace_period(mut self, gp: usize) -> Self {
564        self.config.grace_period = gp;
565        self
566    }
567
568    /// Set the Hoeffding bound confidence parameter (delta).
569    pub fn delta(mut self, d: f64) -> Self {
570        self.config.delta = d;
571        self
572    }
573
574    /// Set the drift detector type for tree replacement.
575    pub fn drift_detector(mut self, dt: DriftDetectorType) -> Self {
576        self.config.drift_detector = dt;
577        self
578    }
579
580    /// Set the SGBT computational variant.
581    pub fn variant(mut self, v: SGBTVariant) -> Self {
582        self.config.variant = v;
583        self
584    }
585
586    /// Set the random seed for deterministic reproducibility.
587    ///
588    /// Controls feature subsampling and variant skip/MI stochastic decisions.
589    /// Two models with the same seed and data sequence will produce identical results.
590    pub fn seed(mut self, seed: u64) -> Self {
591        self.config.seed = seed;
592        self
593    }
594
595    /// Set the number of initial targets to collect before computing the base prediction.
596    ///
597    /// The model collects this many target values before initializing the base
598    /// prediction (via `loss.initial_prediction`). Default: 50.
599    pub fn initial_target_count(mut self, count: usize) -> Self {
600        self.config.initial_target_count = count;
601        self
602    }
603
604    /// Set the half-life for exponential leaf decay (in samples per leaf).
605    ///
606    /// After `n` samples, a leaf's accumulated statistics have half the weight
607    /// of the most recent sample. Enables continuous adaptation to concept drift.
608    pub fn leaf_half_life(mut self, n: usize) -> Self {
609        self.config.leaf_half_life = Some(n);
610        self
611    }
612
613    /// Set the maximum samples a single tree processes before proactive replacement.
614    ///
615    /// After `n` samples, the tree is replaced regardless of drift detector state.
616    pub fn max_tree_samples(mut self, n: u64) -> Self {
617        self.config.max_tree_samples = Some(n);
618        self
619    }
620
621    /// Set the split re-evaluation interval for max-depth leaves.
622    ///
623    /// Every `n` samples per leaf, max-depth leaves re-evaluate whether a split
624    /// would improve them. Inspired by EFDT (Manapragada et al. 2018).
625    pub fn split_reeval_interval(mut self, n: usize) -> Self {
626        self.config.split_reeval_interval = Some(n);
627        self
628    }
629
630    /// Set human-readable feature names.
631    ///
632    /// Enables named feature importances and named training input.
633    /// Names must be unique; validated at [`build()`](Self::build).
634    pub fn feature_names(mut self, names: Vec<String>) -> Self {
635        self.config.feature_names = Some(names);
636        self
637    }
638
639    /// Set per-feature type declarations.
640    ///
641    /// Declares which features are categorical vs continuous. Categorical features
642    /// use one-bin-per-category binning and Fisher optimal binary partitioning.
643    /// Supports up to 64 distinct category values per categorical feature.
644    pub fn feature_types(mut self, types: Vec<FeatureType>) -> Self {
645        self.config.feature_types = Some(types);
646        self
647    }
648
649    /// Set per-leaf gradient clipping threshold (in standard deviations).
650    ///
651    /// Each leaf tracks an EWMA of gradient mean and variance. Gradients
652    /// exceeding `mean ± sigma * n` are clamped. Prevents outlier labels
653    /// from corrupting streaming model stability.
654    ///
655    /// Typical value: 3.0 (3-sigma clipping).
656    pub fn gradient_clip_sigma(mut self, sigma: f64) -> Self {
657        self.config.gradient_clip_sigma = Some(sigma);
658        self
659    }
660
661    /// Set per-feature monotonic constraints.
662    ///
663    /// `+1` = non-decreasing, `-1` = non-increasing, `0` = unconstrained.
664    /// Candidate splits violating monotonicity are rejected during tree growth.
665    pub fn monotone_constraints(mut self, constraints: Vec<i8>) -> Self {
666        self.config.monotone_constraints = Some(constraints);
667        self
668    }
669
670    /// Enable quality-based tree pruning with the given EWMA smoothing factor.
671    ///
672    /// Trees whose marginal contribution drops below the threshold for
673    /// `patience` consecutive samples are replaced with fresh trees.
674    /// Suggested alpha: 0.01.
675    pub fn quality_prune_alpha(mut self, alpha: f64) -> Self {
676        self.config.quality_prune_alpha = Some(alpha);
677        self
678    }
679
680    /// Set the minimum contribution threshold for quality-based pruning.
681    ///
682    /// Default: 1e-6. Only relevant when `quality_prune_alpha` is set.
683    pub fn quality_prune_threshold(mut self, threshold: f64) -> Self {
684        self.config.quality_prune_threshold = threshold;
685        self
686    }
687
688    /// Set the patience (consecutive low-contribution samples) before pruning.
689    ///
690    /// Default: 500. Only relevant when `quality_prune_alpha` is set.
691    pub fn quality_prune_patience(mut self, patience: u64) -> Self {
692        self.config.quality_prune_patience = patience;
693        self
694    }
695
696    /// Enable error-weighted sample importance with the given EWMA smoothing factor.
697    ///
698    /// Samples the model predicted poorly get higher effective weight.
699    /// Suggested alpha: 0.01.
700    pub fn error_weight_alpha(mut self, alpha: f64) -> Self {
701        self.config.error_weight_alpha = Some(alpha);
702        self
703    }
704
705    /// Enable σ-modulated learning rate for distributional models.
706    ///
707    /// Scales the location (μ) learning rate by `current_sigma / rolling_sigma_mean`,
708    /// so the model adapts faster during high-uncertainty regimes and conserves
709    /// during stable periods. Only affects [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
710    ///
711    /// By default uses empirical σ (EWMA of squared errors).  Set
712    /// [`scale_mode(ScaleMode::TreeChain)`](Self::scale_mode) for feature-conditional σ.
713    pub fn uncertainty_modulated_lr(mut self, enabled: bool) -> Self {
714        self.config.uncertainty_modulated_lr = enabled;
715        self
716    }
717
718    /// Set the scale estimation mode for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
719    ///
720    /// - [`Empirical`](ScaleMode::Empirical): EWMA of squared prediction errors (default, recommended).
721    /// - [`TreeChain`](ScaleMode::TreeChain): dual-chain NGBoost with scale tree ensemble.
722    pub fn scale_mode(mut self, mode: ScaleMode) -> Self {
723        self.config.scale_mode = mode;
724        self
725    }
726
727    /// EWMA alpha for empirical σ. Controls adaptation speed. Default `0.01`.
728    ///
729    /// Only used when `scale_mode` is [`Empirical`](ScaleMode::Empirical).
730    pub fn empirical_sigma_alpha(mut self, alpha: f64) -> Self {
731        self.config.empirical_sigma_alpha = alpha;
732        self
733    }
734
735    /// Set the maximum absolute leaf output value.
736    ///
737    /// Clamps leaf predictions to `[-max, max]`, breaking feedback loops
738    /// that cause prediction explosions.
739    pub fn max_leaf_output(mut self, max: f64) -> Self {
740        self.config.max_leaf_output = Some(max);
741        self
742    }
743
744    /// Set per-leaf adaptive output bound (sigma multiplier).
745    ///
746    /// Each leaf tracks EWMA of its own output weight and clamps to
747    /// `|output_mean| + k * output_std`. Self-calibrating per-leaf.
748    /// Recommended: use with `leaf_half_life` for streaming scenarios.
749    pub fn adaptive_leaf_bound(mut self, k: f64) -> Self {
750        self.config.adaptive_leaf_bound = Some(k);
751        self
752    }
753
754    /// Set the minimum hessian sum for leaf output.
755    ///
756    /// Fresh leaves with `hess_sum < min_h` return 0.0, preventing
757    /// post-replacement spikes.
758    pub fn min_hessian_sum(mut self, min_h: f64) -> Self {
759        self.config.min_hessian_sum = Some(min_h);
760        self
761    }
762
763    /// Set the Huber loss delta multiplier for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
764    ///
765    /// When set, location gradients use Huber loss with adaptive
766    /// `delta = k * empirical_sigma`. Standard value: `1.345` (95% Gaussian efficiency).
767    pub fn huber_k(mut self, k: f64) -> Self {
768        self.config.huber_k = Some(k);
769        self
770    }
771
772    /// Enable graduated tree handoff with the given shadow warmup samples.
773    ///
774    /// Spawns an always-on shadow tree that trains alongside the active tree.
775    /// After `warmup` samples, the shadow begins contributing to predictions
776    /// via graduated blending. Eliminates prediction dips during tree replacement.
777    pub fn shadow_warmup(mut self, warmup: usize) -> Self {
778        self.config.shadow_warmup = Some(warmup);
779        self
780    }
781
782    /// Set the leaf prediction model type.
783    ///
784    /// [`LeafModelType::Linear`] is recommended for low-depth configurations
785    /// (depth 2--4) where per-leaf linear models reduce approximation error.
786    ///
787    /// [`LeafModelType::Adaptive`] automatically selects between closed-form and
788    /// a trainable model per leaf, using the Hoeffding bound for promotion.
789    pub fn leaf_model_type(mut self, lmt: LeafModelType) -> Self {
790        self.config.leaf_model_type = lmt;
791        self
792    }
793
794    /// Set the packed cache refresh interval for distributional models.
795    ///
796    /// When non-zero, [`DistributionalSGBT`](super::distributional::DistributionalSGBT)
797    /// maintains a packed f32 cache refreshed every `interval` training samples.
798    /// `0` (default) disables the cache.
799    pub fn packed_refresh_interval(mut self, interval: u64) -> Self {
800        self.config.packed_refresh_interval = interval;
801        self
802    }
803
804    /// Validate and build the configuration.
805    ///
806    /// # Errors
807    ///
808    /// Returns [`InvalidConfig`](crate::IrithyllError::InvalidConfig) with a structured
809    /// [`ConfigError`] if any parameter is out of its valid range.
810    pub fn build(self) -> Result<SGBTConfig> {
811        let c = &self.config;
812
813        // -- Ensemble-level parameters --
814        if c.n_steps == 0 {
815            return Err(ConfigError::out_of_range("n_steps", "must be > 0", c.n_steps).into());
816        }
817        if c.learning_rate <= 0.0 || c.learning_rate > 1.0 {
818            return Err(ConfigError::out_of_range(
819                "learning_rate",
820                "must be in (0, 1]",
821                c.learning_rate,
822            )
823            .into());
824        }
825        if c.feature_subsample_rate <= 0.0 || c.feature_subsample_rate > 1.0 {
826            return Err(ConfigError::out_of_range(
827                "feature_subsample_rate",
828                "must be in (0, 1]",
829                c.feature_subsample_rate,
830            )
831            .into());
832        }
833
834        // -- Tree-level parameters --
835        if c.max_depth == 0 {
836            return Err(ConfigError::out_of_range("max_depth", "must be > 0", c.max_depth).into());
837        }
838        if c.n_bins < 2 {
839            return Err(ConfigError::out_of_range("n_bins", "must be >= 2", c.n_bins).into());
840        }
841        if c.lambda < 0.0 {
842            return Err(ConfigError::out_of_range("lambda", "must be >= 0", c.lambda).into());
843        }
844        if c.gamma < 0.0 {
845            return Err(ConfigError::out_of_range("gamma", "must be >= 0", c.gamma).into());
846        }
847        if c.grace_period == 0 {
848            return Err(
849                ConfigError::out_of_range("grace_period", "must be > 0", c.grace_period).into(),
850            );
851        }
852        if c.delta <= 0.0 || c.delta >= 1.0 {
853            return Err(ConfigError::out_of_range("delta", "must be in (0, 1)", c.delta).into());
854        }
855
856        if c.initial_target_count == 0 {
857            return Err(ConfigError::out_of_range(
858                "initial_target_count",
859                "must be > 0",
860                c.initial_target_count,
861            )
862            .into());
863        }
864
865        // -- Streaming adaptation parameters --
866        if let Some(hl) = c.leaf_half_life {
867            if hl == 0 {
868                return Err(ConfigError::out_of_range("leaf_half_life", "must be >= 1", hl).into());
869            }
870        }
871        if let Some(max) = c.max_tree_samples {
872            if max < 100 {
873                return Err(
874                    ConfigError::out_of_range("max_tree_samples", "must be >= 100", max).into(),
875                );
876            }
877        }
878        if let Some(interval) = c.split_reeval_interval {
879            if interval < c.grace_period {
880                return Err(ConfigError::invalid(
881                    "split_reeval_interval",
882                    format!(
883                        "must be >= grace_period ({}), got {}",
884                        c.grace_period, interval
885                    ),
886                )
887                .into());
888            }
889        }
890
891        // -- Feature names --
892        if let Some(ref names) = c.feature_names {
893            // O(n^2) duplicate check — names list is small so no HashSet needed
894            for (i, name) in names.iter().enumerate() {
895                for prev in &names[..i] {
896                    if name == prev {
897                        return Err(ConfigError::invalid(
898                            "feature_names",
899                            format!("duplicate feature name: '{}'", name),
900                        )
901                        .into());
902                    }
903                }
904            }
905        }
906
907        // -- Feature types --
908        if let Some(ref types) = c.feature_types {
909            if let Some(ref names) = c.feature_names {
910                if !names.is_empty() && !types.is_empty() && names.len() != types.len() {
911                    return Err(ConfigError::invalid(
912                        "feature_types",
913                        format!(
914                            "length ({}) must match feature_names length ({})",
915                            types.len(),
916                            names.len()
917                        ),
918                    )
919                    .into());
920                }
921            }
922        }
923
924        // -- Gradient clipping --
925        if let Some(sigma) = c.gradient_clip_sigma {
926            if sigma <= 0.0 {
927                return Err(
928                    ConfigError::out_of_range("gradient_clip_sigma", "must be > 0", sigma).into(),
929                );
930            }
931        }
932
933        // -- Monotonic constraints --
934        if let Some(ref mc) = c.monotone_constraints {
935            for (i, &v) in mc.iter().enumerate() {
936                if v != -1 && v != 0 && v != 1 {
937                    return Err(ConfigError::invalid(
938                        "monotone_constraints",
939                        format!("feature {}: must be -1, 0, or +1, got {}", i, v),
940                    )
941                    .into());
942                }
943            }
944        }
945
946        // -- Leaf output clamping --
947        if let Some(max) = c.max_leaf_output {
948            if max <= 0.0 {
949                return Err(
950                    ConfigError::out_of_range("max_leaf_output", "must be > 0", max).into(),
951                );
952            }
953        }
954
955        // -- Per-leaf adaptive output bound --
956        if let Some(k) = c.adaptive_leaf_bound {
957            if k <= 0.0 {
958                return Err(
959                    ConfigError::out_of_range("adaptive_leaf_bound", "must be > 0", k).into(),
960                );
961            }
962        }
963
964        // -- Minimum hessian sum --
965        if let Some(min_h) = c.min_hessian_sum {
966            if min_h <= 0.0 {
967                return Err(
968                    ConfigError::out_of_range("min_hessian_sum", "must be > 0", min_h).into(),
969                );
970            }
971        }
972
973        // -- Huber loss multiplier --
974        if let Some(k) = c.huber_k {
975            if k <= 0.0 {
976                return Err(ConfigError::out_of_range("huber_k", "must be > 0", k).into());
977            }
978        }
979
980        // -- Shadow warmup --
981        if let Some(warmup) = c.shadow_warmup {
982            if warmup == 0 {
983                return Err(ConfigError::out_of_range(
984                    "shadow_warmup",
985                    "must be > 0",
986                    warmup as f64,
987                )
988                .into());
989            }
990        }
991
992        // -- Quality-based pruning parameters --
993        if let Some(alpha) = c.quality_prune_alpha {
994            if alpha <= 0.0 || alpha >= 1.0 {
995                return Err(ConfigError::out_of_range(
996                    "quality_prune_alpha",
997                    "must be in (0, 1)",
998                    alpha,
999                )
1000                .into());
1001            }
1002        }
1003        if c.quality_prune_threshold <= 0.0 {
1004            return Err(ConfigError::out_of_range(
1005                "quality_prune_threshold",
1006                "must be > 0",
1007                c.quality_prune_threshold,
1008            )
1009            .into());
1010        }
1011        if c.quality_prune_patience == 0 {
1012            return Err(ConfigError::out_of_range(
1013                "quality_prune_patience",
1014                "must be > 0",
1015                c.quality_prune_patience,
1016            )
1017            .into());
1018        }
1019
1020        // -- Error-weighted sample importance --
1021        if let Some(alpha) = c.error_weight_alpha {
1022            if alpha <= 0.0 || alpha >= 1.0 {
1023                return Err(ConfigError::out_of_range(
1024                    "error_weight_alpha",
1025                    "must be in (0, 1)",
1026                    alpha,
1027                )
1028                .into());
1029            }
1030        }
1031
1032        // -- Drift detector parameters --
1033        match &c.drift_detector {
1034            DriftDetectorType::PageHinkley { delta, lambda } => {
1035                if *delta <= 0.0 {
1036                    return Err(ConfigError::out_of_range(
1037                        "drift_detector.PageHinkley.delta",
1038                        "must be > 0",
1039                        delta,
1040                    )
1041                    .into());
1042                }
1043                if *lambda <= 0.0 {
1044                    return Err(ConfigError::out_of_range(
1045                        "drift_detector.PageHinkley.lambda",
1046                        "must be > 0",
1047                        lambda,
1048                    )
1049                    .into());
1050                }
1051            }
1052            DriftDetectorType::Adwin { delta } => {
1053                if *delta <= 0.0 || *delta >= 1.0 {
1054                    return Err(ConfigError::out_of_range(
1055                        "drift_detector.Adwin.delta",
1056                        "must be in (0, 1)",
1057                        delta,
1058                    )
1059                    .into());
1060                }
1061            }
1062            DriftDetectorType::Ddm {
1063                warning_level,
1064                drift_level,
1065                min_instances,
1066            } => {
1067                if *warning_level <= 0.0 {
1068                    return Err(ConfigError::out_of_range(
1069                        "drift_detector.Ddm.warning_level",
1070                        "must be > 0",
1071                        warning_level,
1072                    )
1073                    .into());
1074                }
1075                if *drift_level <= 0.0 {
1076                    return Err(ConfigError::out_of_range(
1077                        "drift_detector.Ddm.drift_level",
1078                        "must be > 0",
1079                        drift_level,
1080                    )
1081                    .into());
1082                }
1083                if *drift_level <= *warning_level {
1084                    return Err(ConfigError::invalid(
1085                        "drift_detector.Ddm.drift_level",
1086                        format!(
1087                            "must be > warning_level ({}), got {}",
1088                            warning_level, drift_level
1089                        ),
1090                    )
1091                    .into());
1092                }
1093                if *min_instances == 0 {
1094                    return Err(ConfigError::out_of_range(
1095                        "drift_detector.Ddm.min_instances",
1096                        "must be > 0",
1097                        min_instances,
1098                    )
1099                    .into());
1100                }
1101            }
1102        }
1103
1104        // -- Variant parameters --
1105        match &c.variant {
1106            SGBTVariant::Standard => {} // no extra validation
1107            SGBTVariant::Skip { k } => {
1108                if *k == 0 {
1109                    return Err(
1110                        ConfigError::out_of_range("variant.Skip.k", "must be > 0", k).into(),
1111                    );
1112                }
1113            }
1114            SGBTVariant::MultipleIterations { multiplier } => {
1115                if *multiplier <= 0.0 {
1116                    return Err(ConfigError::out_of_range(
1117                        "variant.MultipleIterations.multiplier",
1118                        "must be > 0",
1119                        multiplier,
1120                    )
1121                    .into());
1122                }
1123            }
1124        }
1125
1126        Ok(self.config)
1127    }
1128}
1129
1130// ---------------------------------------------------------------------------
1131// Display impls
1132// ---------------------------------------------------------------------------
1133
1134impl core::fmt::Display for DriftDetectorType {
1135    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1136        match self {
1137            Self::PageHinkley { delta, lambda } => {
1138                write!(f, "PageHinkley(delta={}, lambda={})", delta, lambda)
1139            }
1140            Self::Adwin { delta } => write!(f, "Adwin(delta={})", delta),
1141            Self::Ddm {
1142                warning_level,
1143                drift_level,
1144                min_instances,
1145            } => write!(
1146                f,
1147                "Ddm(warning={}, drift={}, min_instances={})",
1148                warning_level, drift_level, min_instances
1149            ),
1150        }
1151    }
1152}
1153
1154impl core::fmt::Display for SGBTConfig {
1155    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1156        write!(
1157            f,
1158            "SGBTConfig {{ steps={}, lr={}, depth={}, bins={}, variant={}, drift={} }}",
1159            self.n_steps,
1160            self.learning_rate,
1161            self.max_depth,
1162            self.n_bins,
1163            self.variant,
1164            self.drift_detector,
1165        )
1166    }
1167}
1168
1169// ---------------------------------------------------------------------------
1170// Tests
1171// ---------------------------------------------------------------------------
1172
1173#[cfg(test)]
1174mod tests {
1175    use super::*;
1176    use alloc::format;
1177    use alloc::vec;
1178
1179    // ------------------------------------------------------------------
1180    // 1. Default config values are correct
1181    // ------------------------------------------------------------------
1182    #[test]
1183    fn default_config_values() {
1184        let cfg = SGBTConfig::default();
1185        assert_eq!(cfg.n_steps, 100);
1186        assert!((cfg.learning_rate - 0.0125).abs() < f64::EPSILON);
1187        assert!((cfg.feature_subsample_rate - 0.75).abs() < f64::EPSILON);
1188        assert_eq!(cfg.max_depth, 6);
1189        assert_eq!(cfg.n_bins, 64);
1190        assert!((cfg.lambda - 1.0).abs() < f64::EPSILON);
1191        assert!((cfg.gamma - 0.0).abs() < f64::EPSILON);
1192        assert_eq!(cfg.grace_period, 200);
1193        assert!((cfg.delta - 1e-7).abs() < f64::EPSILON);
1194        assert_eq!(cfg.variant, SGBTVariant::Standard);
1195    }
1196
1197    // ------------------------------------------------------------------
1198    // 2. Builder chain works
1199    // ------------------------------------------------------------------
1200    #[test]
1201    fn builder_chain() {
1202        let cfg = SGBTConfig::builder()
1203            .n_steps(50)
1204            .learning_rate(0.1)
1205            .feature_subsample_rate(0.5)
1206            .max_depth(10)
1207            .n_bins(128)
1208            .lambda(0.5)
1209            .gamma(0.1)
1210            .grace_period(500)
1211            .delta(1e-3)
1212            .drift_detector(DriftDetectorType::Adwin { delta: 0.01 })
1213            .variant(SGBTVariant::Skip { k: 5 })
1214            .build()
1215            .expect("valid config");
1216
1217        assert_eq!(cfg.n_steps, 50);
1218        assert!((cfg.learning_rate - 0.1).abs() < f64::EPSILON);
1219        assert!((cfg.feature_subsample_rate - 0.5).abs() < f64::EPSILON);
1220        assert_eq!(cfg.max_depth, 10);
1221        assert_eq!(cfg.n_bins, 128);
1222        assert!((cfg.lambda - 0.5).abs() < f64::EPSILON);
1223        assert!((cfg.gamma - 0.1).abs() < f64::EPSILON);
1224        assert_eq!(cfg.grace_period, 500);
1225        assert!((cfg.delta - 1e-3).abs() < f64::EPSILON);
1226        assert_eq!(cfg.variant, SGBTVariant::Skip { k: 5 });
1227
1228        // Verify drift detector type is Adwin.
1229        match &cfg.drift_detector {
1230            DriftDetectorType::Adwin { delta } => {
1231                assert!((delta - 0.01).abs() < f64::EPSILON);
1232            }
1233            _ => panic!("expected Adwin drift detector"),
1234        }
1235    }
1236
1237    // ------------------------------------------------------------------
1238    // 3. Validation rejects invalid values
1239    // ------------------------------------------------------------------
1240    #[test]
1241    fn validation_rejects_n_steps_zero() {
1242        let result = SGBTConfig::builder().n_steps(0).build();
1243        assert!(result.is_err());
1244        let msg = format!("{}", result.unwrap_err());
1245        assert!(msg.contains("n_steps"));
1246    }
1247
1248    #[test]
1249    fn validation_rejects_learning_rate_zero() {
1250        let result = SGBTConfig::builder().learning_rate(0.0).build();
1251        assert!(result.is_err());
1252        let msg = format!("{}", result.unwrap_err());
1253        assert!(msg.contains("learning_rate"));
1254    }
1255
1256    #[test]
1257    fn validation_rejects_learning_rate_above_one() {
1258        let result = SGBTConfig::builder().learning_rate(1.5).build();
1259        assert!(result.is_err());
1260    }
1261
1262    #[test]
1263    fn validation_accepts_learning_rate_one() {
1264        let result = SGBTConfig::builder().learning_rate(1.0).build();
1265        assert!(result.is_ok());
1266    }
1267
1268    #[test]
1269    fn validation_rejects_negative_learning_rate() {
1270        let result = SGBTConfig::builder().learning_rate(-0.1).build();
1271        assert!(result.is_err());
1272    }
1273
1274    #[test]
1275    fn validation_rejects_feature_subsample_zero() {
1276        let result = SGBTConfig::builder().feature_subsample_rate(0.0).build();
1277        assert!(result.is_err());
1278    }
1279
1280    #[test]
1281    fn validation_rejects_feature_subsample_above_one() {
1282        let result = SGBTConfig::builder().feature_subsample_rate(1.01).build();
1283        assert!(result.is_err());
1284    }
1285
1286    #[test]
1287    fn validation_rejects_max_depth_zero() {
1288        let result = SGBTConfig::builder().max_depth(0).build();
1289        assert!(result.is_err());
1290    }
1291
1292    #[test]
1293    fn validation_rejects_n_bins_one() {
1294        let result = SGBTConfig::builder().n_bins(1).build();
1295        assert!(result.is_err());
1296    }
1297
1298    #[test]
1299    fn validation_rejects_negative_lambda() {
1300        let result = SGBTConfig::builder().lambda(-0.1).build();
1301        assert!(result.is_err());
1302    }
1303
1304    #[test]
1305    fn validation_accepts_zero_lambda() {
1306        let result = SGBTConfig::builder().lambda(0.0).build();
1307        assert!(result.is_ok());
1308    }
1309
1310    #[test]
1311    fn validation_rejects_negative_gamma() {
1312        let result = SGBTConfig::builder().gamma(-0.1).build();
1313        assert!(result.is_err());
1314    }
1315
1316    #[test]
1317    fn validation_rejects_grace_period_zero() {
1318        let result = SGBTConfig::builder().grace_period(0).build();
1319        assert!(result.is_err());
1320    }
1321
1322    #[test]
1323    fn validation_rejects_delta_zero() {
1324        let result = SGBTConfig::builder().delta(0.0).build();
1325        assert!(result.is_err());
1326    }
1327
1328    #[test]
1329    fn validation_rejects_delta_one() {
1330        let result = SGBTConfig::builder().delta(1.0).build();
1331        assert!(result.is_err());
1332    }
1333
1334    // ------------------------------------------------------------------
1335    // 3b. Drift detector parameter validation
1336    // ------------------------------------------------------------------
1337    #[test]
1338    fn validation_rejects_pht_negative_delta() {
1339        let result = SGBTConfig::builder()
1340            .drift_detector(DriftDetectorType::PageHinkley {
1341                delta: -1.0,
1342                lambda: 50.0,
1343            })
1344            .build();
1345        assert!(result.is_err());
1346        let msg = format!("{}", result.unwrap_err());
1347        assert!(msg.contains("PageHinkley"));
1348    }
1349
1350    #[test]
1351    fn validation_rejects_pht_zero_lambda() {
1352        let result = SGBTConfig::builder()
1353            .drift_detector(DriftDetectorType::PageHinkley {
1354                delta: 0.005,
1355                lambda: 0.0,
1356            })
1357            .build();
1358        assert!(result.is_err());
1359    }
1360
1361    #[test]
1362    fn validation_rejects_adwin_delta_out_of_range() {
1363        let result = SGBTConfig::builder()
1364            .drift_detector(DriftDetectorType::Adwin { delta: 0.0 })
1365            .build();
1366        assert!(result.is_err());
1367
1368        let result = SGBTConfig::builder()
1369            .drift_detector(DriftDetectorType::Adwin { delta: 1.0 })
1370            .build();
1371        assert!(result.is_err());
1372    }
1373
1374    #[test]
1375    fn validation_rejects_ddm_warning_above_drift() {
1376        let result = SGBTConfig::builder()
1377            .drift_detector(DriftDetectorType::Ddm {
1378                warning_level: 3.0,
1379                drift_level: 2.0,
1380                min_instances: 30,
1381            })
1382            .build();
1383        assert!(result.is_err());
1384        let msg = format!("{}", result.unwrap_err());
1385        assert!(msg.contains("drift_level"));
1386        assert!(msg.contains("must be > warning_level"));
1387    }
1388
1389    #[test]
1390    fn validation_rejects_ddm_equal_levels() {
1391        let result = SGBTConfig::builder()
1392            .drift_detector(DriftDetectorType::Ddm {
1393                warning_level: 2.0,
1394                drift_level: 2.0,
1395                min_instances: 30,
1396            })
1397            .build();
1398        assert!(result.is_err());
1399    }
1400
1401    #[test]
1402    fn validation_rejects_ddm_zero_min_instances() {
1403        let result = SGBTConfig::builder()
1404            .drift_detector(DriftDetectorType::Ddm {
1405                warning_level: 2.0,
1406                drift_level: 3.0,
1407                min_instances: 0,
1408            })
1409            .build();
1410        assert!(result.is_err());
1411    }
1412
1413    #[test]
1414    fn validation_rejects_ddm_zero_warning_level() {
1415        let result = SGBTConfig::builder()
1416            .drift_detector(DriftDetectorType::Ddm {
1417                warning_level: 0.0,
1418                drift_level: 3.0,
1419                min_instances: 30,
1420            })
1421            .build();
1422        assert!(result.is_err());
1423    }
1424
1425    // ------------------------------------------------------------------
1426    // 3c. Variant parameter validation
1427    // ------------------------------------------------------------------
1428    #[test]
1429    fn validation_rejects_skip_k_zero() {
1430        let result = SGBTConfig::builder()
1431            .variant(SGBTVariant::Skip { k: 0 })
1432            .build();
1433        assert!(result.is_err());
1434        let msg = format!("{}", result.unwrap_err());
1435        assert!(msg.contains("Skip"));
1436    }
1437
1438    #[test]
1439    fn validation_rejects_mi_zero_multiplier() {
1440        let result = SGBTConfig::builder()
1441            .variant(SGBTVariant::MultipleIterations { multiplier: 0.0 })
1442            .build();
1443        assert!(result.is_err());
1444    }
1445
1446    #[test]
1447    fn validation_rejects_mi_negative_multiplier() {
1448        let result = SGBTConfig::builder()
1449            .variant(SGBTVariant::MultipleIterations { multiplier: -1.0 })
1450            .build();
1451        assert!(result.is_err());
1452    }
1453
1454    #[test]
1455    fn validation_accepts_standard_variant() {
1456        let result = SGBTConfig::builder().variant(SGBTVariant::Standard).build();
1457        assert!(result.is_ok());
1458    }
1459
1460    // ------------------------------------------------------------------
1461    // 4. DriftDetectorType creates correct detector types
1462    // ------------------------------------------------------------------
1463    #[test]
1464    fn drift_detector_type_creates_page_hinkley() {
1465        let dt = DriftDetectorType::PageHinkley {
1466            delta: 0.01,
1467            lambda: 100.0,
1468        };
1469        let mut detector = dt.create();
1470
1471        // Should start with zero mean.
1472        assert_eq!(detector.estimated_mean(), 0.0);
1473
1474        // Feed a stable value -- should not drift.
1475        let signal = detector.update(1.0);
1476        assert_ne!(signal, crate::drift::DriftSignal::Drift);
1477    }
1478
1479    #[test]
1480    fn drift_detector_type_creates_adwin() {
1481        let dt = DriftDetectorType::Adwin { delta: 0.05 };
1482        let mut detector = dt.create();
1483
1484        assert_eq!(detector.estimated_mean(), 0.0);
1485        let signal = detector.update(1.0);
1486        assert_ne!(signal, crate::drift::DriftSignal::Drift);
1487    }
1488
1489    #[test]
1490    fn drift_detector_type_creates_ddm() {
1491        let dt = DriftDetectorType::Ddm {
1492            warning_level: 2.0,
1493            drift_level: 3.0,
1494            min_instances: 30,
1495        };
1496        let mut detector = dt.create();
1497
1498        assert_eq!(detector.estimated_mean(), 0.0);
1499        let signal = detector.update(0.1);
1500        assert_eq!(signal, crate::drift::DriftSignal::Stable);
1501    }
1502
1503    // ------------------------------------------------------------------
1504    // 5. Default drift detector is PageHinkley
1505    // ------------------------------------------------------------------
1506    #[test]
1507    fn default_drift_detector_is_page_hinkley() {
1508        let dt = DriftDetectorType::default();
1509        match dt {
1510            DriftDetectorType::PageHinkley { delta, lambda } => {
1511                assert!((delta - 0.005).abs() < f64::EPSILON);
1512                assert!((lambda - 50.0).abs() < f64::EPSILON);
1513            }
1514            _ => panic!("expected default DriftDetectorType to be PageHinkley"),
1515        }
1516    }
1517
1518    // ------------------------------------------------------------------
1519    // 6. Default config builds without error
1520    // ------------------------------------------------------------------
1521    #[test]
1522    fn default_config_builds() {
1523        let result = SGBTConfig::builder().build();
1524        assert!(result.is_ok());
1525    }
1526
1527    // ------------------------------------------------------------------
1528    // 7. Config clone preserves all fields
1529    // ------------------------------------------------------------------
1530    #[test]
1531    fn config_clone_preserves_fields() {
1532        let original = SGBTConfig::builder()
1533            .n_steps(50)
1534            .learning_rate(0.05)
1535            .variant(SGBTVariant::MultipleIterations { multiplier: 5.0 })
1536            .build()
1537            .unwrap();
1538
1539        let cloned = original.clone();
1540
1541        assert_eq!(cloned.n_steps, 50);
1542        assert!((cloned.learning_rate - 0.05).abs() < f64::EPSILON);
1543        assert_eq!(
1544            cloned.variant,
1545            SGBTVariant::MultipleIterations { multiplier: 5.0 }
1546        );
1547    }
1548
1549    // ------------------------------------------------------------------
1550    // 8. All three DDM valid configs accepted
1551    // ------------------------------------------------------------------
1552    #[test]
1553    fn valid_ddm_config_accepted() {
1554        let result = SGBTConfig::builder()
1555            .drift_detector(DriftDetectorType::Ddm {
1556                warning_level: 1.5,
1557                drift_level: 2.5,
1558                min_instances: 10,
1559            })
1560            .build();
1561        assert!(result.is_ok());
1562    }
1563
1564    // ------------------------------------------------------------------
1565    // 9. Created detectors are functional (round-trip test)
1566    // ------------------------------------------------------------------
1567    #[test]
1568    fn created_detectors_are_functional() {
1569        // Create a detector, feed it data, verify it can detect drift.
1570        let dt = DriftDetectorType::PageHinkley {
1571            delta: 0.005,
1572            lambda: 50.0,
1573        };
1574        let mut detector = dt.create();
1575
1576        // Feed 500 stable values.
1577        for _ in 0..500 {
1578            detector.update(1.0);
1579        }
1580
1581        // Feed shifted values -- should eventually drift.
1582        let mut drifted = false;
1583        for _ in 0..500 {
1584            if detector.update(10.0) == crate::drift::DriftSignal::Drift {
1585                drifted = true;
1586                break;
1587            }
1588        }
1589        assert!(
1590            drifted,
1591            "detector created from DriftDetectorType should be functional"
1592        );
1593    }
1594
1595    // ------------------------------------------------------------------
1596    // 10. Boundary acceptance: n_bins=2 is the minimum valid
1597    // ------------------------------------------------------------------
1598    #[test]
1599    fn boundary_n_bins_two_accepted() {
1600        let result = SGBTConfig::builder().n_bins(2).build();
1601        assert!(result.is_ok());
1602    }
1603
1604    // ------------------------------------------------------------------
1605    // 11. Boundary acceptance: grace_period=1 is valid
1606    // ------------------------------------------------------------------
1607    #[test]
1608    fn boundary_grace_period_one_accepted() {
1609        let result = SGBTConfig::builder().grace_period(1).build();
1610        assert!(result.is_ok());
1611    }
1612
1613    // ------------------------------------------------------------------
1614    // 12. Feature names -- valid config with names
1615    // ------------------------------------------------------------------
1616    #[test]
1617    fn feature_names_accepted() {
1618        let cfg = SGBTConfig::builder()
1619            .feature_names(vec!["price".into(), "volume".into(), "spread".into()])
1620            .build()
1621            .unwrap();
1622        assert_eq!(
1623            cfg.feature_names.as_ref().unwrap(),
1624            &["price", "volume", "spread"]
1625        );
1626    }
1627
1628    // ------------------------------------------------------------------
1629    // 13. Feature names -- duplicate names rejected
1630    // ------------------------------------------------------------------
1631    #[test]
1632    fn feature_names_rejects_duplicates() {
1633        let result = SGBTConfig::builder()
1634            .feature_names(vec!["price".into(), "volume".into(), "price".into()])
1635            .build();
1636        assert!(result.is_err());
1637        let msg = format!("{}", result.unwrap_err());
1638        assert!(msg.contains("duplicate"));
1639    }
1640
1641    // ------------------------------------------------------------------
1642    // 14. Feature names -- serde backward compat (missing field)
1643    //     Requires serde_json which is only in the full irithyll crate.
1644    // ------------------------------------------------------------------
1645
1646    // ------------------------------------------------------------------
1647    // 15. Feature names -- empty vec is valid
1648    // ------------------------------------------------------------------
1649    #[test]
1650    fn feature_names_empty_vec_accepted() {
1651        let cfg = SGBTConfig::builder().feature_names(vec![]).build().unwrap();
1652        assert!(cfg.feature_names.unwrap().is_empty());
1653    }
1654
1655    // ------------------------------------------------------------------
1656    // 16. Adaptive leaf bound -- builder
1657    // ------------------------------------------------------------------
1658    #[test]
1659    fn builder_adaptive_leaf_bound() {
1660        let cfg = SGBTConfig::builder()
1661            .adaptive_leaf_bound(3.0)
1662            .build()
1663            .unwrap();
1664        assert_eq!(cfg.adaptive_leaf_bound, Some(3.0));
1665    }
1666
1667    // ------------------------------------------------------------------
1668    // 17. Adaptive leaf bound -- validation rejects zero
1669    // ------------------------------------------------------------------
1670    #[test]
1671    fn validation_rejects_zero_adaptive_leaf_bound() {
1672        let result = SGBTConfig::builder().adaptive_leaf_bound(0.0).build();
1673        assert!(result.is_err());
1674        let msg = format!("{}", result.unwrap_err());
1675        assert!(
1676            msg.contains("adaptive_leaf_bound"),
1677            "error should mention adaptive_leaf_bound: {}",
1678            msg,
1679        );
1680    }
1681
1682    // ------------------------------------------------------------------
1683    // 18. Adaptive leaf bound -- serde backward compat
1684    //     Requires serde_json which is only in the full irithyll crate.
1685    // ------------------------------------------------------------------
1686}