// irithyll_core/ensemble/config.rs
//! SGBT configuration with builder pattern and full validation.
//!
//! [`SGBTConfig`] holds all hyperparameters for the Streaming Gradient Boosted
//! Trees ensemble. Use [`SGBTConfig::builder`] for ergonomic construction with
//! validation on [`build()`](SGBTConfigBuilder::build).

use alloc::boxed::Box;
use alloc::format;
use alloc::string::String;
use alloc::vec::Vec;

use crate::drift::adwin::Adwin;
use crate::drift::ddm::Ddm;
use crate::drift::pht::PageHinkleyTest;
use crate::drift::DriftDetector;
use crate::ensemble::variants::SGBTVariant;
use crate::error::{ConfigError, Result};
use crate::tree::leaf_model::LeafModelType;

// ---------------------------------------------------------------------------
// FeatureType -- re-exported from irithyll-core
// ---------------------------------------------------------------------------

// Re-exported so downstream code can name feature types from this module
// without importing `crate::feature` directly.
pub use crate::feature::FeatureType;

// ---------------------------------------------------------------------------
// ScaleMode
// ---------------------------------------------------------------------------

/// How [`DistributionalSGBT`](super::distributional::DistributionalSGBT)
/// estimates uncertainty (σ).
///
/// - **`Empirical`** (default): tracks an EWMA of squared prediction errors.
///   `σ = sqrt(ewma_sq_err)`.  Always calibrated, zero tuning, O(1) compute.
///   Use this when σ drives learning-rate modulation (σ high → learn faster).
///
/// - **`TreeChain`**: trains a full second ensemble of Hoeffding trees to predict
///   log(σ) from features (NGBoost-style dual chain).  Gives *feature-conditional*
///   uncertainty but requires strong signal in the scale gradients.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ScaleMode {
    /// EWMA of squared prediction errors — always calibrated, O(1).
    /// Adaptation speed is controlled by `SGBTConfig::empirical_sigma_alpha`.
    #[default]
    Empirical,
    /// Full Hoeffding-tree ensemble for feature-conditional log(σ) prediction.
    /// Higher cost: maintains a second (scale) ensemble alongside the location one.
    TreeChain,
}

// ---------------------------------------------------------------------------
// DriftDetectorType
// ---------------------------------------------------------------------------

/// Which drift detector to instantiate for each boosting step.
///
/// Each variant stores the detector's configuration parameters so that fresh
/// instances can be created on demand (e.g. when replacing a drifted tree).
#[derive(Debug, Clone, PartialEq)]
// No `Eq`: every variant carries `f64` parameters, which are only `PartialEq`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum DriftDetectorType {
    /// Page-Hinkley Test with custom delta (magnitude tolerance) and lambda
    /// (detection threshold).
    PageHinkley {
        /// Magnitude tolerance. Default 0.005.
        delta: f64,
        /// Detection threshold. Default 50.0.
        lambda: f64,
    },

    /// ADWIN with custom confidence parameter.
    Adwin {
        /// Confidence (smaller = fewer false positives). Default 0.002.
        delta: f64,
    },

    /// DDM with custom warning/drift levels and minimum warmup instances.
    Ddm {
        /// Warning threshold multiplier. Default 2.0.
        warning_level: f64,
        /// Drift threshold multiplier. Default 3.0.
        drift_level: f64,
        /// Minimum observations before detection activates. Default 30.
        min_instances: u64,
    },
}

87impl Default for DriftDetectorType {
88    fn default() -> Self {
89        DriftDetectorType::PageHinkley {
90            delta: 0.005,
91            lambda: 50.0,
92        }
93    }
94}
95
96impl DriftDetectorType {
97    /// Create a new, fresh drift detector from this configuration.
98    pub fn create(&self) -> Box<dyn DriftDetector> {
99        match self {
100            Self::PageHinkley { delta, lambda } => {
101                Box::new(PageHinkleyTest::with_params(*delta, *lambda))
102            }
103            Self::Adwin { delta } => Box::new(Adwin::with_delta(*delta)),
104            Self::Ddm {
105                warning_level,
106                drift_level,
107                min_instances,
108            } => Box::new(Ddm::with_params(
109                *warning_level,
110                *drift_level,
111                *min_instances,
112            )),
113        }
114    }
115}
116
// ---------------------------------------------------------------------------
// SGBTConfig
// ---------------------------------------------------------------------------

/// Configuration for the SGBT ensemble.
///
/// All numeric parameters are validated at build time via [`SGBTConfigBuilder`].
///
/// # Defaults
///
/// | Parameter                | Default              |
/// |--------------------------|----------------------|
/// | `n_steps`                | 100                  |
/// | `learning_rate`          | 0.0125               |
/// | `feature_subsample_rate` | 0.75                 |
/// | `max_depth`              | 6                    |
/// | `n_bins`                 | 64                   |
/// | `lambda`                 | 1.0                  |
/// | `gamma`                  | 0.0                  |
/// | `grace_period`           | 200                  |
/// | `delta`                  | 1e-7                 |
/// | `drift_detector`         | PageHinkley(0.005, 50.0) |
/// | `variant`                | Standard             |
/// | `seed`                   | 0xDEAD_BEEF_CAFE_4242 |
/// | `initial_target_count`   | 50                   |
/// | `leaf_half_life`         | None (disabled)      |
/// | `max_tree_samples`       | None (disabled)      |
/// | `split_reeval_interval`  | None (disabled)      |
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SGBTConfig {
    /// Number of boosting steps (trees in the ensemble). Default 100.
    pub n_steps: usize,
    /// Learning rate (shrinkage). Default 0.0125.
    pub learning_rate: f64,
    /// Fraction of features to subsample per tree. Default 0.75.
    pub feature_subsample_rate: f64,
    /// Maximum tree depth. Default 6.
    pub max_depth: usize,
    /// Number of histogram bins. Default 64.
    pub n_bins: usize,
    /// L2 regularization parameter (lambda). Default 1.0.
    pub lambda: f64,
    /// Minimum split gain (gamma). Default 0.0.
    pub gamma: f64,
    /// Grace period: min samples before evaluating splits. Default 200.
    pub grace_period: usize,
    /// Hoeffding bound confidence (delta). Default 1e-7.
    pub delta: f64,
    /// Drift detector type for tree replacement. Default: PageHinkley.
    pub drift_detector: DriftDetectorType,
    /// SGBT computational variant. Default: Standard.
    pub variant: SGBTVariant,
    /// Random seed for deterministic reproducibility. Default: 0xDEAD_BEEF_CAFE_4242.
    ///
    /// Controls feature subsampling and variant skip/MI stochastic decisions.
    /// Two models with the same seed and same data will produce identical results.
    pub seed: u64,
    /// Number of initial targets to collect before computing the base prediction.
    /// Default: 50.
    pub initial_target_count: usize,

    // NOTE: every field below carries `serde(default)` so configs serialized
    // before the field existed still deserialize cleanly.

    /// Half-life for exponential leaf decay (in samples per leaf).
    ///
    /// After `leaf_half_life` samples, a leaf's accumulated gradient/hessian
    /// statistics have half the weight of the most recent sample. This causes
    /// the model to continuously adapt to changing data distributions rather
    /// than freezing on early observations.
    ///
    /// `None` (default) disables decay -- traditional monotonic accumulation.
    #[cfg_attr(feature = "serde", serde(default))]
    pub leaf_half_life: Option<usize>,

    /// Maximum samples a single tree processes before proactive replacement.
    ///
    /// After this many samples, the tree is replaced with a fresh one regardless
    /// of drift detector state. Prevents stale tree structure from persisting
    /// when the drift detector is not sensitive enough.
    ///
    /// `None` (default) disables time-based replacement.
    #[cfg_attr(feature = "serde", serde(default))]
    pub max_tree_samples: Option<u64>,

    /// Interval (in samples per leaf) at which max-depth leaves re-evaluate
    /// whether a split would improve them.
    ///
    /// Inspired by EFDT (Manapragada et al. 2018). When a leaf has accumulated
    /// `split_reeval_interval` samples since its last evaluation and has reached
    /// max depth, it re-evaluates whether a split should be performed.
    ///
    /// `None` (default) disables re-evaluation -- max-depth leaves are permanent.
    #[cfg_attr(feature = "serde", serde(default))]
    pub split_reeval_interval: Option<usize>,

    /// Optional human-readable feature names.
    ///
    /// When set, enables `named_feature_importances` and
    /// `train_one_named` for production-friendly named access.
    /// Length must match the number of features in training data.
    #[cfg_attr(feature = "serde", serde(default))]
    pub feature_names: Option<Vec<String>>,

    /// Optional per-feature type declarations.
    ///
    /// When set, declares which features are categorical vs continuous.
    /// Categorical features use one-bin-per-category binning and Fisher
    /// optimal binary partitioning for split evaluation.
    /// Length must match the number of features in training data.
    ///
    /// `None` (default) treats all features as continuous.
    #[cfg_attr(feature = "serde", serde(default))]
    pub feature_types: Option<Vec<FeatureType>>,

    /// Gradient clipping threshold in standard deviations per leaf.
    ///
    /// When enabled, each leaf tracks an EWMA of gradient mean and variance.
    /// Incoming gradients that exceed `mean ± sigma * gradient_clip_sigma` are
    /// clamped to the boundary. This prevents outlier samples from corrupting
    /// leaf statistics, which is critical in streaming settings where sudden
    /// label floods can destabilize the model.
    ///
    /// Typical value: 3.0 (3-sigma clipping).
    /// `None` (default) disables gradient clipping.
    #[cfg_attr(feature = "serde", serde(default))]
    pub gradient_clip_sigma: Option<f64>,

    /// Per-feature monotonic constraints.
    ///
    /// Each element specifies the monotonic relationship between a feature and
    /// the prediction:
    /// - `+1`: prediction must be non-decreasing as feature value increases.
    /// - `-1`: prediction must be non-increasing as feature value increases.
    /// - `0`: no constraint (unconstrained).
    ///
    /// During split evaluation, candidate splits that would violate monotonicity
    /// (left child value > right child value for +1 constraints, or vice versa)
    /// are rejected.
    ///
    /// Length must match the number of features in training data.
    /// `None` (default) means no monotonic constraints.
    #[cfg_attr(feature = "serde", serde(default))]
    pub monotone_constraints: Option<Vec<i8>>,

    /// EWMA smoothing factor for quality-based tree pruning.
    ///
    /// When `Some(alpha)`, each boosting step tracks an exponentially weighted
    /// moving average of its marginal contribution to the ensemble. Trees whose
    /// contribution drops below [`quality_prune_threshold`](Self::quality_prune_threshold)
    /// for [`quality_prune_patience`](Self::quality_prune_patience) consecutive
    /// samples are replaced with a fresh tree that can learn the current regime.
    ///
    /// This prevents "dead wood" -- trees from a past regime that no longer
    /// contribute meaningfully to ensemble accuracy.
    ///
    /// `None` (default) disables quality-based pruning.
    /// Suggested value: 0.01.
    #[cfg_attr(feature = "serde", serde(default))]
    pub quality_prune_alpha: Option<f64>,

    /// Minimum contribution threshold for quality-based pruning.
    ///
    /// A tree's EWMA contribution must stay above this value to avoid being
    /// flagged as dead wood. Only used when `quality_prune_alpha` is `Some`.
    ///
    /// Default: 1e-6 (kept in sync with `default_quality_prune_threshold`).
    #[cfg_attr(feature = "serde", serde(default = "default_quality_prune_threshold"))]
    pub quality_prune_threshold: f64,

    /// Consecutive low-contribution samples before a tree is replaced.
    ///
    /// After this many consecutive samples where a tree's EWMA contribution
    /// is below `quality_prune_threshold`, the tree is reset. Only used when
    /// `quality_prune_alpha` is `Some`.
    ///
    /// Default: 500 (kept in sync with `default_quality_prune_patience`).
    #[cfg_attr(feature = "serde", serde(default = "default_quality_prune_patience"))]
    pub quality_prune_patience: u64,

    /// EWMA smoothing factor for error-weighted sample importance.
    ///
    /// When `Some(alpha)`, samples the model predicted poorly get higher
    /// effective weight during histogram accumulation. The weight is:
    /// `1.0 + |error| / (rolling_mean_error + epsilon)`, capped at 10x.
    ///
    /// This is a streaming version of AdaBoost's reweighting applied at the
    /// gradient level -- learning capacity focuses on hard/novel patterns,
    /// enabling faster adaptation to regime changes.
    ///
    /// `None` (default) disables error weighting.
    /// Suggested value: 0.01.
    #[cfg_attr(feature = "serde", serde(default))]
    pub error_weight_alpha: Option<f64>,

    /// Enable σ-modulated learning rate for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
    ///
    /// When `true`, the **location** (μ) ensemble's learning rate is scaled by
    /// `sigma_ratio = current_sigma / rolling_sigma_mean`, where `rolling_sigma_mean`
    /// is an EWMA of the model's predicted σ (alpha = 0.001).
    ///
    /// This means the model learns μ **faster** when σ is elevated (high uncertainty)
    /// and **slower** when σ is low (confident regime). The scale (σ) ensemble always
    /// trains at the unmodulated base rate to prevent positive feedback loops.
    ///
    /// Default: `false`.
    #[cfg_attr(feature = "serde", serde(default))]
    pub uncertainty_modulated_lr: bool,

    /// How the scale (σ) is estimated in [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
    ///
    /// - [`Empirical`](ScaleMode::Empirical) (default): EWMA of squared prediction
    ///   errors.  `σ = sqrt(ewma_sq_err)`.  Always calibrated, zero tuning, O(1).
    /// - [`TreeChain`](ScaleMode::TreeChain): full dual-chain NGBoost with a
    ///   separate tree ensemble predicting log(σ) from features.
    ///
    /// For σ-modulated learning (`uncertainty_modulated_lr = true`), `Empirical`
    /// is strongly recommended — scale tree gradients are inherently weak and
    /// the trees often fail to split.
    #[cfg_attr(feature = "serde", serde(default))]
    pub scale_mode: ScaleMode,

    /// EWMA smoothing factor for empirical σ estimation.
    ///
    /// Controls the adaptation speed of `σ = sqrt(ewma_sq_err)` when
    /// [`scale_mode`](Self::scale_mode) is [`Empirical`](ScaleMode::Empirical).
    /// Higher values react faster to regime changes but are noisier.
    ///
    /// Default: `0.01` (~100-sample effective window).
    #[cfg_attr(feature = "serde", serde(default = "default_empirical_sigma_alpha"))]
    pub empirical_sigma_alpha: f64,

    /// Maximum absolute leaf output value.
    ///
    /// When `Some(max)`, leaf predictions are clamped to `[-max, max]`.
    /// Prevents runaway leaf weights from causing prediction explosions
    /// in feedback loops. `None` (default) means no clamping.
    #[cfg_attr(feature = "serde", serde(default))]
    pub max_leaf_output: Option<f64>,

    /// Per-leaf adaptive output bound (sigma multiplier).
    ///
    /// When `Some(k)`, each leaf tracks an EWMA of its own output weight and
    /// clamps predictions to `|output_mean| + k * output_std`. The EWMA uses
    /// `leaf_decay_alpha` when `leaf_half_life` is set, otherwise Welford online.
    ///
    /// This is strictly superior to `max_leaf_output` for streaming — the bound
    /// is per-leaf, self-calibrating, and regime-synchronized. A leaf that usually
    /// outputs 0.3 can't suddenly output 2.9 just because it fits in the global clamp.
    ///
    /// Typical value: 3.0 (3-sigma bound).
    /// `None` (default) disables adaptive bounds (falls back to `max_leaf_output`).
    #[cfg_attr(feature = "serde", serde(default))]
    pub adaptive_leaf_bound: Option<f64>,

    /// Per-split information criterion (Lunde-Kleppe-Skaug 2020).
    ///
    /// When `Some(cir_factor)`, replaces `max_depth` with a per-split
    /// generalization test. Each candidate split must have
    /// `gain > cir_factor * sigma^2_g / n * n_features`.
    /// `max_depth * 2` becomes a hard safety ceiling.
    ///
    /// Typical: 7.5 (<=10 features), 9.0 (<=50), 11.0 (<=200).
    /// `None` (default) uses static `max_depth` only.
    #[cfg_attr(feature = "serde", serde(default))]
    pub adaptive_depth: Option<f64>,

    /// Minimum hessian sum before a leaf produces non-zero output.
    ///
    /// When `Some(min_h)`, leaves with `hess_sum < min_h` return 0.0.
    /// Prevents post-replacement spikes from fresh leaves with insufficient
    /// samples. `None` (default) means all leaves contribute immediately.
    #[cfg_attr(feature = "serde", serde(default))]
    pub min_hessian_sum: Option<f64>,

    /// Huber loss delta multiplier for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
    ///
    /// When `Some(k)`, the distributional location gradient uses Huber loss
    /// with adaptive `delta = k * empirical_sigma`. This bounds gradients by
    /// construction. Standard value: `1.345` (95% efficiency at Gaussian).
    /// `None` (default) uses squared loss.
    #[cfg_attr(feature = "serde", serde(default))]
    pub huber_k: Option<f64>,

    /// Shadow warmup for graduated tree handoff.
    ///
    /// When `Some(n)`, an always-on shadow (alternate) tree is spawned immediately
    /// alongside every active tree. The shadow trains on the same gradient stream
    /// but does not contribute to predictions until it has seen `n` samples.
    ///
    /// As the active tree ages past 80% of `max_tree_samples`, its prediction
    /// weight linearly decays to 0 at 120%. The shadow's weight ramps from 0 to 1
    /// over `n` samples after warmup. When the active weight reaches 0, the shadow
    /// is promoted and a new shadow is spawned — no cold-start prediction dip.
    ///
    /// Requires `max_tree_samples` to be set for time-based graduated handoff.
    /// Drift-based replacement still uses hard swap (shadow is already warm).
    ///
    /// `None` (default) disables graduated handoff — uses traditional hard swap.
    #[cfg_attr(feature = "serde", serde(default))]
    pub shadow_warmup: Option<usize>,

    /// Leaf prediction model type.
    ///
    /// Controls how each leaf computes its prediction:
    /// - [`ClosedForm`](LeafModelType::ClosedForm) (default): constant leaf weight.
    /// - [`Linear`](LeafModelType::Linear): per-leaf online ridge regression with
    ///   AdaGrad optimization. Optional `decay` for concept drift. Recommended for
    ///   low-depth trees (depth 2--4).
    /// - [`MLP`](LeafModelType::MLP): per-leaf single-hidden-layer neural network.
    ///   Optional `decay` for concept drift.
    /// - [`Adaptive`](LeafModelType::Adaptive): starts as closed-form, auto-promotes
    ///   when the Hoeffding bound confirms a more complex model is better.
    ///
    /// Default: [`ClosedForm`](LeafModelType::ClosedForm).
    #[cfg_attr(feature = "serde", serde(default))]
    pub leaf_model_type: LeafModelType,

    /// Packed cache refresh interval for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
    ///
    /// When non-zero, the distributional model maintains a packed f32 cache of
    /// its location ensemble that is re-exported every `packed_refresh_interval`
    /// training samples. Predictions use the cache for O(1)-per-tree inference
    /// via contiguous memory traversal, falling back to full tree traversal when
    /// the cache is absent or produces non-finite results.
    ///
    /// `0` (default) disables the packed cache.
    #[cfg_attr(feature = "serde", serde(default))]
    pub packed_refresh_interval: u64,
}

// Serde fallbacks for fields whose default is non-zero. Each value mirrors the
// corresponding literal in `impl Default for SGBTConfig` — keep them in sync.

#[cfg(feature = "serde")]
fn default_empirical_sigma_alpha() -> f64 { 0.01 }

#[cfg(feature = "serde")]
fn default_quality_prune_threshold() -> f64 { 1e-6 }

#[cfg(feature = "serde")]
fn default_quality_prune_patience() -> u64 { 500 }

impl Default for SGBTConfig {
    fn default() -> Self {
        // NOTE: these literals must stay in sync with the defaults table in the
        // [`SGBTConfig`] docs and with the `default_*` serde fallback functions.
        Self {
            n_steps: 100,
            learning_rate: 0.0125,
            feature_subsample_rate: 0.75,
            max_depth: 6,
            n_bins: 64,
            lambda: 1.0,
            gamma: 0.0,
            grace_period: 200,
            delta: 1e-7,
            drift_detector: DriftDetectorType::default(),
            variant: SGBTVariant::default(),
            seed: 0xDEAD_BEEF_CAFE_4242,
            initial_target_count: 50,
            // All opt-in streaming extensions are disabled by default.
            leaf_half_life: None,
            max_tree_samples: None,
            split_reeval_interval: None,
            feature_names: None,
            feature_types: None,
            gradient_clip_sigma: None,
            monotone_constraints: None,
            quality_prune_alpha: None,
            quality_prune_threshold: 1e-6,
            quality_prune_patience: 500,
            error_weight_alpha: None,
            uncertainty_modulated_lr: false,
            scale_mode: ScaleMode::default(),
            empirical_sigma_alpha: 0.01,
            max_leaf_output: None,
            adaptive_leaf_bound: None,
            adaptive_depth: None,
            min_hessian_sum: None,
            huber_k: None,
            shadow_warmup: None,
            leaf_model_type: LeafModelType::default(),
            packed_refresh_interval: 0,
        }
    }
}

503impl SGBTConfig {
504    /// Start building a configuration via the builder pattern.
505    pub fn builder() -> SGBTConfigBuilder {
506        SGBTConfigBuilder::default()
507    }
508}
509
// ---------------------------------------------------------------------------
// SGBTConfigBuilder
// ---------------------------------------------------------------------------

/// Builder for [`SGBTConfig`] with validation on [`build()`](Self::build).
///
/// # Example
///
/// ```ignore
/// use irithyll::ensemble::config::{SGBTConfig, DriftDetectorType};
/// use irithyll::ensemble::variants::SGBTVariant;
///
/// let config = SGBTConfig::builder()
///     .n_steps(200)
///     .learning_rate(0.05)
///     .drift_detector(DriftDetectorType::Adwin { delta: 0.01 })
///     .variant(SGBTVariant::Skip { k: 10 })
///     .build()
///     .expect("valid config");
/// ```
#[derive(Debug, Clone, Default)]
pub struct SGBTConfigBuilder {
    // Starts from `SGBTConfig::default()`; mutated by the fluent setters and
    // validated when `build()` is called.
    config: SGBTConfig,
}

impl SGBTConfigBuilder {
    /// Set the number of boosting steps (trees in the ensemble).
    pub fn n_steps(self, n: usize) -> Self {
        Self { config: SGBTConfig { n_steps: n, ..self.config } }
    }

    /// Set the learning rate (shrinkage factor).
    pub fn learning_rate(self, lr: f64) -> Self {
        Self { config: SGBTConfig { learning_rate: lr, ..self.config } }
    }

    /// Set the fraction of features to subsample per tree.
    pub fn feature_subsample_rate(self, rate: f64) -> Self {
        Self { config: SGBTConfig { feature_subsample_rate: rate, ..self.config } }
    }

    /// Set the maximum tree depth.
    pub fn max_depth(self, depth: usize) -> Self {
        Self { config: SGBTConfig { max_depth: depth, ..self.config } }
    }

    /// Set the number of histogram bins per feature.
    pub fn n_bins(self, bins: usize) -> Self {
        Self { config: SGBTConfig { n_bins: bins, ..self.config } }
    }

    /// Set the L2 regularization parameter (lambda).
    pub fn lambda(self, l: f64) -> Self {
        Self { config: SGBTConfig { lambda: l, ..self.config } }
    }

    /// Set the minimum split gain (gamma).
    pub fn gamma(self, g: f64) -> Self {
        Self { config: SGBTConfig { gamma: g, ..self.config } }
    }

    /// Set the grace period (minimum samples before evaluating splits).
    pub fn grace_period(self, gp: usize) -> Self {
        Self { config: SGBTConfig { grace_period: gp, ..self.config } }
    }

    /// Set the Hoeffding bound confidence parameter (delta).
    pub fn delta(self, d: f64) -> Self {
        Self { config: SGBTConfig { delta: d, ..self.config } }
    }

590    /// Set the drift detector type for tree replacement.
591    pub fn drift_detector(mut self, dt: DriftDetectorType) -> Self {
592        self.config.drift_detector = dt;
593        self
594    }
595
596    /// Set the SGBT computational variant.
597    pub fn variant(mut self, v: SGBTVariant) -> Self {
598        self.config.variant = v;
599        self
600    }
601
602    /// Set the random seed for deterministic reproducibility.
603    ///
604    /// Controls feature subsampling and variant skip/MI stochastic decisions.
605    /// Two models with the same seed and data sequence will produce identical results.
606    pub fn seed(mut self, seed: u64) -> Self {
607        self.config.seed = seed;
608        self
609    }
610
611    /// Set the number of initial targets to collect before computing the base prediction.
612    ///
613    /// The model collects this many target values before initializing the base
614    /// prediction (via `loss.initial_prediction`). Default: 50.
615    pub fn initial_target_count(mut self, count: usize) -> Self {
616        self.config.initial_target_count = count;
617        self
618    }
619
620    /// Set the half-life for exponential leaf decay (in samples per leaf).
621    ///
622    /// After `n` samples, a leaf's accumulated statistics have half the weight
623    /// of the most recent sample. Enables continuous adaptation to concept drift.
624    pub fn leaf_half_life(mut self, n: usize) -> Self {
625        self.config.leaf_half_life = Some(n);
626        self
627    }
628
629    /// Set the maximum samples a single tree processes before proactive replacement.
630    ///
631    /// After `n` samples, the tree is replaced regardless of drift detector state.
632    pub fn max_tree_samples(mut self, n: u64) -> Self {
633        self.config.max_tree_samples = Some(n);
634        self
635    }
636
637    /// Set the split re-evaluation interval for max-depth leaves.
638    ///
639    /// Every `n` samples per leaf, max-depth leaves re-evaluate whether a split
640    /// would improve them. Inspired by EFDT (Manapragada et al. 2018).
641    pub fn split_reeval_interval(mut self, n: usize) -> Self {
642        self.config.split_reeval_interval = Some(n);
643        self
644    }
645
646    /// Set human-readable feature names.
647    ///
648    /// Enables named feature importances and named training input.
649    /// Names must be unique; validated at [`build()`](Self::build).
650    pub fn feature_names(mut self, names: Vec<String>) -> Self {
651        self.config.feature_names = Some(names);
652        self
653    }
654
655    /// Set per-feature type declarations.
656    ///
657    /// Declares which features are categorical vs continuous. Categorical features
658    /// use one-bin-per-category binning and Fisher optimal binary partitioning.
659    /// Supports up to 64 distinct category values per categorical feature.
660    pub fn feature_types(mut self, types: Vec<FeatureType>) -> Self {
661        self.config.feature_types = Some(types);
662        self
663    }
664
665    /// Set per-leaf gradient clipping threshold (in standard deviations).
666    ///
667    /// Each leaf tracks an EWMA of gradient mean and variance. Gradients
668    /// exceeding `mean ± sigma * n` are clamped. Prevents outlier labels
669    /// from corrupting streaming model stability.
670    ///
671    /// Typical value: 3.0 (3-sigma clipping).
672    pub fn gradient_clip_sigma(mut self, sigma: f64) -> Self {
673        self.config.gradient_clip_sigma = Some(sigma);
674        self
675    }
676
677    /// Set per-feature monotonic constraints.
678    ///
679    /// `+1` = non-decreasing, `-1` = non-increasing, `0` = unconstrained.
680    /// Candidate splits violating monotonicity are rejected during tree growth.
681    pub fn monotone_constraints(mut self, constraints: Vec<i8>) -> Self {
682        self.config.monotone_constraints = Some(constraints);
683        self
684    }
685
686    /// Enable quality-based tree pruning with the given EWMA smoothing factor.
687    ///
688    /// Trees whose marginal contribution drops below the threshold for
689    /// `patience` consecutive samples are replaced with fresh trees.
690    /// Suggested alpha: 0.01.
691    pub fn quality_prune_alpha(mut self, alpha: f64) -> Self {
692        self.config.quality_prune_alpha = Some(alpha);
693        self
694    }
695
696    /// Set the minimum contribution threshold for quality-based pruning.
697    ///
698    /// Default: 1e-6. Only relevant when `quality_prune_alpha` is set.
699    pub fn quality_prune_threshold(mut self, threshold: f64) -> Self {
700        self.config.quality_prune_threshold = threshold;
701        self
702    }
703
704    /// Set the patience (consecutive low-contribution samples) before pruning.
705    ///
706    /// Default: 500. Only relevant when `quality_prune_alpha` is set.
707    pub fn quality_prune_patience(mut self, patience: u64) -> Self {
708        self.config.quality_prune_patience = patience;
709        self
710    }
711
712    /// Enable error-weighted sample importance with the given EWMA smoothing factor.
713    ///
714    /// Samples the model predicted poorly get higher effective weight.
715    /// Suggested alpha: 0.01.
716    pub fn error_weight_alpha(mut self, alpha: f64) -> Self {
717        self.config.error_weight_alpha = Some(alpha);
718        self
719    }
720
721    /// Enable σ-modulated learning rate for distributional models.
722    ///
723    /// Scales the location (μ) learning rate by `current_sigma / rolling_sigma_mean`,
724    /// so the model adapts faster during high-uncertainty regimes and conserves
725    /// during stable periods. Only affects [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
726    ///
727    /// By default uses empirical σ (EWMA of squared errors).  Set
728    /// [`scale_mode(ScaleMode::TreeChain)`](Self::scale_mode) for feature-conditional σ.
729    pub fn uncertainty_modulated_lr(mut self, enabled: bool) -> Self {
730        self.config.uncertainty_modulated_lr = enabled;
731        self
732    }
733
734    /// Set the scale estimation mode for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
735    ///
736    /// - [`Empirical`](ScaleMode::Empirical): EWMA of squared prediction errors (default, recommended).
737    /// - [`TreeChain`](ScaleMode::TreeChain): dual-chain NGBoost with scale tree ensemble.
738    pub fn scale_mode(mut self, mode: ScaleMode) -> Self {
739        self.config.scale_mode = mode;
740        self
741    }
742
743    /// EWMA alpha for empirical σ. Controls adaptation speed. Default `0.01`.
744    ///
745    /// Only used when `scale_mode` is [`Empirical`](ScaleMode::Empirical).
746    pub fn empirical_sigma_alpha(mut self, alpha: f64) -> Self {
747        self.config.empirical_sigma_alpha = alpha;
748        self
749    }
750
751    /// Set the maximum absolute leaf output value.
752    ///
753    /// Clamps leaf predictions to `[-max, max]`, breaking feedback loops
754    /// that cause prediction explosions.
755    pub fn max_leaf_output(mut self, max: f64) -> Self {
756        self.config.max_leaf_output = Some(max);
757        self
758    }
759
760    /// Set per-leaf adaptive output bound (sigma multiplier).
761    ///
762    /// Each leaf tracks EWMA of its own output weight and clamps to
763    /// `|output_mean| + k * output_std`. Self-calibrating per-leaf.
764    /// Recommended: use with `leaf_half_life` for streaming scenarios.
765    pub fn adaptive_leaf_bound(mut self, k: f64) -> Self {
766        self.config.adaptive_leaf_bound = Some(k);
767        self
768    }
769
770    /// Set the per-split information criterion factor (Lunde-Kleppe-Skaug 2020).
771    ///
772    /// Replaces static `max_depth` with a per-split generalization test.
773    /// Typical: 7.5 (<=10 features), 9.0 (<=50), 11.0 (<=200).
774    pub fn adaptive_depth(mut self, factor: f64) -> Self {
775        self.config.adaptive_depth = Some(factor);
776        self
777    }
778
779    /// Set the minimum hessian sum for leaf output.
780    ///
781    /// Fresh leaves with `hess_sum < min_h` return 0.0, preventing
782    /// post-replacement spikes.
783    pub fn min_hessian_sum(mut self, min_h: f64) -> Self {
784        self.config.min_hessian_sum = Some(min_h);
785        self
786    }
787
788    /// Set the Huber loss delta multiplier for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
789    ///
790    /// When set, location gradients use Huber loss with adaptive
791    /// `delta = k * empirical_sigma`. Standard value: `1.345` (95% Gaussian efficiency).
792    pub fn huber_k(mut self, k: f64) -> Self {
793        self.config.huber_k = Some(k);
794        self
795    }
796
797    /// Enable graduated tree handoff with the given shadow warmup samples.
798    ///
799    /// Spawns an always-on shadow tree that trains alongside the active tree.
800    /// After `warmup` samples, the shadow begins contributing to predictions
801    /// via graduated blending. Eliminates prediction dips during tree replacement.
802    pub fn shadow_warmup(mut self, warmup: usize) -> Self {
803        self.config.shadow_warmup = Some(warmup);
804        self
805    }
806
807    /// Set the leaf prediction model type.
808    ///
809    /// [`LeafModelType::Linear`] is recommended for low-depth configurations
810    /// (depth 2--4) where per-leaf linear models reduce approximation error.
811    ///
812    /// [`LeafModelType::Adaptive`] automatically selects between closed-form and
813    /// a trainable model per leaf, using the Hoeffding bound for promotion.
814    pub fn leaf_model_type(mut self, lmt: LeafModelType) -> Self {
815        self.config.leaf_model_type = lmt;
816        self
817    }
818
819    /// Set the packed cache refresh interval for distributional models.
820    ///
821    /// When non-zero, [`DistributionalSGBT`](super::distributional::DistributionalSGBT)
822    /// maintains a packed f32 cache refreshed every `interval` training samples.
823    /// `0` (default) disables the cache.
824    pub fn packed_refresh_interval(mut self, interval: u64) -> Self {
825        self.config.packed_refresh_interval = interval;
826        self
827    }
828
829    /// Validate and build the configuration.
830    ///
831    /// # Errors
832    ///
833    /// Returns [`InvalidConfig`](crate::IrithyllError::InvalidConfig) with a structured
834    /// [`ConfigError`] if any parameter is out of its valid range.
835    pub fn build(self) -> Result<SGBTConfig> {
836        let c = &self.config;
837
838        // -- Ensemble-level parameters --
839        if c.n_steps == 0 {
840            return Err(ConfigError::out_of_range("n_steps", "must be > 0", c.n_steps).into());
841        }
842        if c.learning_rate <= 0.0 || c.learning_rate > 1.0 {
843            return Err(ConfigError::out_of_range(
844                "learning_rate",
845                "must be in (0, 1]",
846                c.learning_rate,
847            )
848            .into());
849        }
850        if c.feature_subsample_rate <= 0.0 || c.feature_subsample_rate > 1.0 {
851            return Err(ConfigError::out_of_range(
852                "feature_subsample_rate",
853                "must be in (0, 1]",
854                c.feature_subsample_rate,
855            )
856            .into());
857        }
858
859        // -- Tree-level parameters --
860        if c.max_depth == 0 {
861            return Err(ConfigError::out_of_range("max_depth", "must be > 0", c.max_depth).into());
862        }
863        if c.n_bins < 2 {
864            return Err(ConfigError::out_of_range("n_bins", "must be >= 2", c.n_bins).into());
865        }
866        if c.lambda < 0.0 {
867            return Err(ConfigError::out_of_range("lambda", "must be >= 0", c.lambda).into());
868        }
869        if c.gamma < 0.0 {
870            return Err(ConfigError::out_of_range("gamma", "must be >= 0", c.gamma).into());
871        }
872        if c.grace_period == 0 {
873            return Err(
874                ConfigError::out_of_range("grace_period", "must be > 0", c.grace_period).into(),
875            );
876        }
877        if c.delta <= 0.0 || c.delta >= 1.0 {
878            return Err(ConfigError::out_of_range("delta", "must be in (0, 1)", c.delta).into());
879        }
880
881        if c.initial_target_count == 0 {
882            return Err(ConfigError::out_of_range(
883                "initial_target_count",
884                "must be > 0",
885                c.initial_target_count,
886            )
887            .into());
888        }
889
890        // -- Streaming adaptation parameters --
891        if let Some(hl) = c.leaf_half_life {
892            if hl == 0 {
893                return Err(ConfigError::out_of_range("leaf_half_life", "must be >= 1", hl).into());
894            }
895        }
896        if let Some(max) = c.max_tree_samples {
897            if max < 100 {
898                return Err(
899                    ConfigError::out_of_range("max_tree_samples", "must be >= 100", max).into(),
900                );
901            }
902        }
903        if let Some(interval) = c.split_reeval_interval {
904            if interval < c.grace_period {
905                return Err(ConfigError::invalid(
906                    "split_reeval_interval",
907                    format!(
908                        "must be >= grace_period ({}), got {}",
909                        c.grace_period, interval
910                    ),
911                )
912                .into());
913            }
914        }
915
916        // -- Feature names --
917        if let Some(ref names) = c.feature_names {
918            // O(n^2) duplicate check — names list is small so no HashSet needed
919            for (i, name) in names.iter().enumerate() {
920                for prev in &names[..i] {
921                    if name == prev {
922                        return Err(ConfigError::invalid(
923                            "feature_names",
924                            format!("duplicate feature name: '{}'", name),
925                        )
926                        .into());
927                    }
928                }
929            }
930        }
931
932        // -- Feature types --
933        if let Some(ref types) = c.feature_types {
934            if let Some(ref names) = c.feature_names {
935                if !names.is_empty() && !types.is_empty() && names.len() != types.len() {
936                    return Err(ConfigError::invalid(
937                        "feature_types",
938                        format!(
939                            "length ({}) must match feature_names length ({})",
940                            types.len(),
941                            names.len()
942                        ),
943                    )
944                    .into());
945                }
946            }
947        }
948
949        // -- Gradient clipping --
950        if let Some(sigma) = c.gradient_clip_sigma {
951            if sigma <= 0.0 {
952                return Err(
953                    ConfigError::out_of_range("gradient_clip_sigma", "must be > 0", sigma).into(),
954                );
955            }
956        }
957
958        // -- Monotonic constraints --
959        if let Some(ref mc) = c.monotone_constraints {
960            for (i, &v) in mc.iter().enumerate() {
961                if v != -1 && v != 0 && v != 1 {
962                    return Err(ConfigError::invalid(
963                        "monotone_constraints",
964                        format!("feature {}: must be -1, 0, or +1, got {}", i, v),
965                    )
966                    .into());
967                }
968            }
969        }
970
971        // -- Leaf output clamping --
972        if let Some(max) = c.max_leaf_output {
973            if max <= 0.0 {
974                return Err(
975                    ConfigError::out_of_range("max_leaf_output", "must be > 0", max).into(),
976                );
977            }
978        }
979
980        // -- Per-leaf adaptive output bound --
981        if let Some(k) = c.adaptive_leaf_bound {
982            if k <= 0.0 {
983                return Err(
984                    ConfigError::out_of_range("adaptive_leaf_bound", "must be > 0", k).into(),
985                );
986            }
987        }
988
989        // -- Adaptive depth (per-split information criterion) --
990        if let Some(factor) = c.adaptive_depth {
991            if factor <= 0.0 {
992                return Err(
993                    ConfigError::out_of_range("adaptive_depth", "must be > 0", factor).into(),
994                );
995            }
996        }
997
998        // -- Minimum hessian sum --
999        if let Some(min_h) = c.min_hessian_sum {
1000            if min_h <= 0.0 {
1001                return Err(
1002                    ConfigError::out_of_range("min_hessian_sum", "must be > 0", min_h).into(),
1003                );
1004            }
1005        }
1006
1007        // -- Huber loss multiplier --
1008        if let Some(k) = c.huber_k {
1009            if k <= 0.0 {
1010                return Err(ConfigError::out_of_range("huber_k", "must be > 0", k).into());
1011            }
1012        }
1013
1014        // -- Shadow warmup --
1015        if let Some(warmup) = c.shadow_warmup {
1016            if warmup == 0 {
1017                return Err(ConfigError::out_of_range(
1018                    "shadow_warmup",
1019                    "must be > 0",
1020                    warmup as f64,
1021                )
1022                .into());
1023            }
1024        }
1025
1026        // -- Quality-based pruning parameters --
1027        if let Some(alpha) = c.quality_prune_alpha {
1028            if alpha <= 0.0 || alpha >= 1.0 {
1029                return Err(ConfigError::out_of_range(
1030                    "quality_prune_alpha",
1031                    "must be in (0, 1)",
1032                    alpha,
1033                )
1034                .into());
1035            }
1036        }
1037        if c.quality_prune_threshold <= 0.0 {
1038            return Err(ConfigError::out_of_range(
1039                "quality_prune_threshold",
1040                "must be > 0",
1041                c.quality_prune_threshold,
1042            )
1043            .into());
1044        }
1045        if c.quality_prune_patience == 0 {
1046            return Err(ConfigError::out_of_range(
1047                "quality_prune_patience",
1048                "must be > 0",
1049                c.quality_prune_patience,
1050            )
1051            .into());
1052        }
1053
1054        // -- Error-weighted sample importance --
1055        if let Some(alpha) = c.error_weight_alpha {
1056            if alpha <= 0.0 || alpha >= 1.0 {
1057                return Err(ConfigError::out_of_range(
1058                    "error_weight_alpha",
1059                    "must be in (0, 1)",
1060                    alpha,
1061                )
1062                .into());
1063            }
1064        }
1065
1066        // -- Drift detector parameters --
1067        match &c.drift_detector {
1068            DriftDetectorType::PageHinkley { delta, lambda } => {
1069                if *delta <= 0.0 {
1070                    return Err(ConfigError::out_of_range(
1071                        "drift_detector.PageHinkley.delta",
1072                        "must be > 0",
1073                        delta,
1074                    )
1075                    .into());
1076                }
1077                if *lambda <= 0.0 {
1078                    return Err(ConfigError::out_of_range(
1079                        "drift_detector.PageHinkley.lambda",
1080                        "must be > 0",
1081                        lambda,
1082                    )
1083                    .into());
1084                }
1085            }
1086            DriftDetectorType::Adwin { delta } => {
1087                if *delta <= 0.0 || *delta >= 1.0 {
1088                    return Err(ConfigError::out_of_range(
1089                        "drift_detector.Adwin.delta",
1090                        "must be in (0, 1)",
1091                        delta,
1092                    )
1093                    .into());
1094                }
1095            }
1096            DriftDetectorType::Ddm {
1097                warning_level,
1098                drift_level,
1099                min_instances,
1100            } => {
1101                if *warning_level <= 0.0 {
1102                    return Err(ConfigError::out_of_range(
1103                        "drift_detector.Ddm.warning_level",
1104                        "must be > 0",
1105                        warning_level,
1106                    )
1107                    .into());
1108                }
1109                if *drift_level <= 0.0 {
1110                    return Err(ConfigError::out_of_range(
1111                        "drift_detector.Ddm.drift_level",
1112                        "must be > 0",
1113                        drift_level,
1114                    )
1115                    .into());
1116                }
1117                if *drift_level <= *warning_level {
1118                    return Err(ConfigError::invalid(
1119                        "drift_detector.Ddm.drift_level",
1120                        format!(
1121                            "must be > warning_level ({}), got {}",
1122                            warning_level, drift_level
1123                        ),
1124                    )
1125                    .into());
1126                }
1127                if *min_instances == 0 {
1128                    return Err(ConfigError::out_of_range(
1129                        "drift_detector.Ddm.min_instances",
1130                        "must be > 0",
1131                        min_instances,
1132                    )
1133                    .into());
1134                }
1135            }
1136        }
1137
1138        // -- Variant parameters --
1139        match &c.variant {
1140            SGBTVariant::Standard => {} // no extra validation
1141            SGBTVariant::Skip { k } => {
1142                if *k == 0 {
1143                    return Err(
1144                        ConfigError::out_of_range("variant.Skip.k", "must be > 0", k).into(),
1145                    );
1146                }
1147            }
1148            SGBTVariant::MultipleIterations { multiplier } => {
1149                if *multiplier <= 0.0 {
1150                    return Err(ConfigError::out_of_range(
1151                        "variant.MultipleIterations.multiplier",
1152                        "must be > 0",
1153                        multiplier,
1154                    )
1155                    .into());
1156                }
1157            }
1158        }
1159
1160        Ok(self.config)
1161    }
1162}
1163
1164// ---------------------------------------------------------------------------
1165// Display impls
1166// ---------------------------------------------------------------------------
1167
1168impl core::fmt::Display for DriftDetectorType {
1169    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1170        match self {
1171            Self::PageHinkley { delta, lambda } => {
1172                write!(f, "PageHinkley(delta={}, lambda={})", delta, lambda)
1173            }
1174            Self::Adwin { delta } => write!(f, "Adwin(delta={})", delta),
1175            Self::Ddm {
1176                warning_level,
1177                drift_level,
1178                min_instances,
1179            } => write!(
1180                f,
1181                "Ddm(warning={}, drift={}, min_instances={})",
1182                warning_level, drift_level, min_instances
1183            ),
1184        }
1185    }
1186}
1187
1188impl core::fmt::Display for SGBTConfig {
1189    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1190        write!(
1191            f,
1192            "SGBTConfig {{ steps={}, lr={}, depth={}, bins={}, variant={}, drift={} }}",
1193            self.n_steps,
1194            self.learning_rate,
1195            self.max_depth,
1196            self.n_bins,
1197            self.variant,
1198            self.drift_detector,
1199        )
1200    }
1201}
1202
1203// ---------------------------------------------------------------------------
1204// Tests
1205// ---------------------------------------------------------------------------
1206
1207#[cfg(test)]
1208mod tests {
1209    use super::*;
1210    use alloc::format;
1211    use alloc::vec;
1212
1213    // ------------------------------------------------------------------
1214    // 1. Default config values are correct
1215    // ------------------------------------------------------------------
1216    #[test]
1217    fn default_config_values() {
1218        let cfg = SGBTConfig::default();
1219        assert_eq!(cfg.n_steps, 100);
1220        assert!((cfg.learning_rate - 0.0125).abs() < f64::EPSILON);
1221        assert!((cfg.feature_subsample_rate - 0.75).abs() < f64::EPSILON);
1222        assert_eq!(cfg.max_depth, 6);
1223        assert_eq!(cfg.n_bins, 64);
1224        assert!((cfg.lambda - 1.0).abs() < f64::EPSILON);
1225        assert!((cfg.gamma - 0.0).abs() < f64::EPSILON);
1226        assert_eq!(cfg.grace_period, 200);
1227        assert!((cfg.delta - 1e-7).abs() < f64::EPSILON);
1228        assert_eq!(cfg.variant, SGBTVariant::Standard);
1229    }
1230
1231    // ------------------------------------------------------------------
1232    // 2. Builder chain works
1233    // ------------------------------------------------------------------
1234    #[test]
1235    fn builder_chain() {
1236        let cfg = SGBTConfig::builder()
1237            .n_steps(50)
1238            .learning_rate(0.1)
1239            .feature_subsample_rate(0.5)
1240            .max_depth(10)
1241            .n_bins(128)
1242            .lambda(0.5)
1243            .gamma(0.1)
1244            .grace_period(500)
1245            .delta(1e-3)
1246            .drift_detector(DriftDetectorType::Adwin { delta: 0.01 })
1247            .variant(SGBTVariant::Skip { k: 5 })
1248            .build()
1249            .expect("valid config");
1250
1251        assert_eq!(cfg.n_steps, 50);
1252        assert!((cfg.learning_rate - 0.1).abs() < f64::EPSILON);
1253        assert!((cfg.feature_subsample_rate - 0.5).abs() < f64::EPSILON);
1254        assert_eq!(cfg.max_depth, 10);
1255        assert_eq!(cfg.n_bins, 128);
1256        assert!((cfg.lambda - 0.5).abs() < f64::EPSILON);
1257        assert!((cfg.gamma - 0.1).abs() < f64::EPSILON);
1258        assert_eq!(cfg.grace_period, 500);
1259        assert!((cfg.delta - 1e-3).abs() < f64::EPSILON);
1260        assert_eq!(cfg.variant, SGBTVariant::Skip { k: 5 });
1261
1262        // Verify drift detector type is Adwin.
1263        match &cfg.drift_detector {
1264            DriftDetectorType::Adwin { delta } => {
1265                assert!((delta - 0.01).abs() < f64::EPSILON);
1266            }
1267            _ => panic!("expected Adwin drift detector"),
1268        }
1269    }
1270
1271    // ------------------------------------------------------------------
1272    // 3. Validation rejects invalid values
1273    // ------------------------------------------------------------------
1274    #[test]
1275    fn validation_rejects_n_steps_zero() {
1276        let result = SGBTConfig::builder().n_steps(0).build();
1277        assert!(result.is_err());
1278        let msg = format!("{}", result.unwrap_err());
1279        assert!(msg.contains("n_steps"));
1280    }
1281
1282    #[test]
1283    fn validation_rejects_learning_rate_zero() {
1284        let result = SGBTConfig::builder().learning_rate(0.0).build();
1285        assert!(result.is_err());
1286        let msg = format!("{}", result.unwrap_err());
1287        assert!(msg.contains("learning_rate"));
1288    }
1289
1290    #[test]
1291    fn validation_rejects_learning_rate_above_one() {
1292        let result = SGBTConfig::builder().learning_rate(1.5).build();
1293        assert!(result.is_err());
1294    }
1295
1296    #[test]
1297    fn validation_accepts_learning_rate_one() {
1298        let result = SGBTConfig::builder().learning_rate(1.0).build();
1299        assert!(result.is_ok());
1300    }
1301
1302    #[test]
1303    fn validation_rejects_negative_learning_rate() {
1304        let result = SGBTConfig::builder().learning_rate(-0.1).build();
1305        assert!(result.is_err());
1306    }
1307
1308    #[test]
1309    fn validation_rejects_feature_subsample_zero() {
1310        let result = SGBTConfig::builder().feature_subsample_rate(0.0).build();
1311        assert!(result.is_err());
1312    }
1313
1314    #[test]
1315    fn validation_rejects_feature_subsample_above_one() {
1316        let result = SGBTConfig::builder().feature_subsample_rate(1.01).build();
1317        assert!(result.is_err());
1318    }
1319
1320    #[test]
1321    fn validation_rejects_max_depth_zero() {
1322        let result = SGBTConfig::builder().max_depth(0).build();
1323        assert!(result.is_err());
1324    }
1325
1326    #[test]
1327    fn validation_rejects_n_bins_one() {
1328        let result = SGBTConfig::builder().n_bins(1).build();
1329        assert!(result.is_err());
1330    }
1331
1332    #[test]
1333    fn validation_rejects_negative_lambda() {
1334        let result = SGBTConfig::builder().lambda(-0.1).build();
1335        assert!(result.is_err());
1336    }
1337
1338    #[test]
1339    fn validation_accepts_zero_lambda() {
1340        let result = SGBTConfig::builder().lambda(0.0).build();
1341        assert!(result.is_ok());
1342    }
1343
1344    #[test]
1345    fn validation_rejects_negative_gamma() {
1346        let result = SGBTConfig::builder().gamma(-0.1).build();
1347        assert!(result.is_err());
1348    }
1349
1350    #[test]
1351    fn validation_rejects_grace_period_zero() {
1352        let result = SGBTConfig::builder().grace_period(0).build();
1353        assert!(result.is_err());
1354    }
1355
1356    #[test]
1357    fn validation_rejects_delta_zero() {
1358        let result = SGBTConfig::builder().delta(0.0).build();
1359        assert!(result.is_err());
1360    }
1361
1362    #[test]
1363    fn validation_rejects_delta_one() {
1364        let result = SGBTConfig::builder().delta(1.0).build();
1365        assert!(result.is_err());
1366    }
1367
1368    // ------------------------------------------------------------------
1369    // 3b. Drift detector parameter validation
1370    // ------------------------------------------------------------------
1371    #[test]
1372    fn validation_rejects_pht_negative_delta() {
1373        let result = SGBTConfig::builder()
1374            .drift_detector(DriftDetectorType::PageHinkley {
1375                delta: -1.0,
1376                lambda: 50.0,
1377            })
1378            .build();
1379        assert!(result.is_err());
1380        let msg = format!("{}", result.unwrap_err());
1381        assert!(msg.contains("PageHinkley"));
1382    }
1383
1384    #[test]
1385    fn validation_rejects_pht_zero_lambda() {
1386        let result = SGBTConfig::builder()
1387            .drift_detector(DriftDetectorType::PageHinkley {
1388                delta: 0.005,
1389                lambda: 0.0,
1390            })
1391            .build();
1392        assert!(result.is_err());
1393    }
1394
1395    #[test]
1396    fn validation_rejects_adwin_delta_out_of_range() {
1397        let result = SGBTConfig::builder()
1398            .drift_detector(DriftDetectorType::Adwin { delta: 0.0 })
1399            .build();
1400        assert!(result.is_err());
1401
1402        let result = SGBTConfig::builder()
1403            .drift_detector(DriftDetectorType::Adwin { delta: 1.0 })
1404            .build();
1405        assert!(result.is_err());
1406    }
1407
1408    #[test]
1409    fn validation_rejects_ddm_warning_above_drift() {
1410        let result = SGBTConfig::builder()
1411            .drift_detector(DriftDetectorType::Ddm {
1412                warning_level: 3.0,
1413                drift_level: 2.0,
1414                min_instances: 30,
1415            })
1416            .build();
1417        assert!(result.is_err());
1418        let msg = format!("{}", result.unwrap_err());
1419        assert!(msg.contains("drift_level"));
1420        assert!(msg.contains("must be > warning_level"));
1421    }
1422
1423    #[test]
1424    fn validation_rejects_ddm_equal_levels() {
1425        let result = SGBTConfig::builder()
1426            .drift_detector(DriftDetectorType::Ddm {
1427                warning_level: 2.0,
1428                drift_level: 2.0,
1429                min_instances: 30,
1430            })
1431            .build();
1432        assert!(result.is_err());
1433    }
1434
1435    #[test]
1436    fn validation_rejects_ddm_zero_min_instances() {
1437        let result = SGBTConfig::builder()
1438            .drift_detector(DriftDetectorType::Ddm {
1439                warning_level: 2.0,
1440                drift_level: 3.0,
1441                min_instances: 0,
1442            })
1443            .build();
1444        assert!(result.is_err());
1445    }
1446
1447    #[test]
1448    fn validation_rejects_ddm_zero_warning_level() {
1449        let result = SGBTConfig::builder()
1450            .drift_detector(DriftDetectorType::Ddm {
1451                warning_level: 0.0,
1452                drift_level: 3.0,
1453                min_instances: 30,
1454            })
1455            .build();
1456        assert!(result.is_err());
1457    }
1458
1459    // ------------------------------------------------------------------
1460    // 3c. Variant parameter validation
1461    // ------------------------------------------------------------------
1462    #[test]
1463    fn validation_rejects_skip_k_zero() {
1464        let result = SGBTConfig::builder()
1465            .variant(SGBTVariant::Skip { k: 0 })
1466            .build();
1467        assert!(result.is_err());
1468        let msg = format!("{}", result.unwrap_err());
1469        assert!(msg.contains("Skip"));
1470    }
1471
1472    #[test]
1473    fn validation_rejects_mi_zero_multiplier() {
1474        let result = SGBTConfig::builder()
1475            .variant(SGBTVariant::MultipleIterations { multiplier: 0.0 })
1476            .build();
1477        assert!(result.is_err());
1478    }
1479
1480    #[test]
1481    fn validation_rejects_mi_negative_multiplier() {
1482        let result = SGBTConfig::builder()
1483            .variant(SGBTVariant::MultipleIterations { multiplier: -1.0 })
1484            .build();
1485        assert!(result.is_err());
1486    }
1487
1488    #[test]
1489    fn validation_accepts_standard_variant() {
1490        let result = SGBTConfig::builder().variant(SGBTVariant::Standard).build();
1491        assert!(result.is_ok());
1492    }
1493
1494    // ------------------------------------------------------------------
1495    // 4. DriftDetectorType creates correct detector types
1496    // ------------------------------------------------------------------
1497    #[test]
1498    fn drift_detector_type_creates_page_hinkley() {
1499        let dt = DriftDetectorType::PageHinkley {
1500            delta: 0.01,
1501            lambda: 100.0,
1502        };
1503        let mut detector = dt.create();
1504
1505        // Should start with zero mean.
1506        assert_eq!(detector.estimated_mean(), 0.0);
1507
1508        // Feed a stable value -- should not drift.
1509        let signal = detector.update(1.0);
1510        assert_ne!(signal, crate::drift::DriftSignal::Drift);
1511    }
1512
1513    #[test]
1514    fn drift_detector_type_creates_adwin() {
1515        let dt = DriftDetectorType::Adwin { delta: 0.05 };
1516        let mut detector = dt.create();
1517
1518        assert_eq!(detector.estimated_mean(), 0.0);
1519        let signal = detector.update(1.0);
1520        assert_ne!(signal, crate::drift::DriftSignal::Drift);
1521    }
1522
1523    #[test]
1524    fn drift_detector_type_creates_ddm() {
1525        let dt = DriftDetectorType::Ddm {
1526            warning_level: 2.0,
1527            drift_level: 3.0,
1528            min_instances: 30,
1529        };
1530        let mut detector = dt.create();
1531
1532        assert_eq!(detector.estimated_mean(), 0.0);
1533        let signal = detector.update(0.1);
1534        assert_eq!(signal, crate::drift::DriftSignal::Stable);
1535    }
1536
1537    // ------------------------------------------------------------------
1538    // 5. Default drift detector is PageHinkley
1539    // ------------------------------------------------------------------
1540    #[test]
1541    fn default_drift_detector_is_page_hinkley() {
1542        let dt = DriftDetectorType::default();
1543        match dt {
1544            DriftDetectorType::PageHinkley { delta, lambda } => {
1545                assert!((delta - 0.005).abs() < f64::EPSILON);
1546                assert!((lambda - 50.0).abs() < f64::EPSILON);
1547            }
1548            _ => panic!("expected default DriftDetectorType to be PageHinkley"),
1549        }
1550    }
1551
1552    // ------------------------------------------------------------------
1553    // 6. Default config builds without error
1554    // ------------------------------------------------------------------
1555    #[test]
1556    fn default_config_builds() {
1557        let result = SGBTConfig::builder().build();
1558        assert!(result.is_ok());
1559    }
1560
1561    // ------------------------------------------------------------------
1562    // 7. Config clone preserves all fields
1563    // ------------------------------------------------------------------
1564    #[test]
1565    fn config_clone_preserves_fields() {
1566        let original = SGBTConfig::builder()
1567            .n_steps(50)
1568            .learning_rate(0.05)
1569            .variant(SGBTVariant::MultipleIterations { multiplier: 5.0 })
1570            .build()
1571            .unwrap();
1572
1573        let cloned = original.clone();
1574
1575        assert_eq!(cloned.n_steps, 50);
1576        assert!((cloned.learning_rate - 0.05).abs() < f64::EPSILON);
1577        assert_eq!(
1578            cloned.variant,
1579            SGBTVariant::MultipleIterations { multiplier: 5.0 }
1580        );
1581    }
1582
1583    // ------------------------------------------------------------------
    // 8. A valid DDM configuration is accepted
1585    // ------------------------------------------------------------------
1586    #[test]
1587    fn valid_ddm_config_accepted() {
1588        let result = SGBTConfig::builder()
1589            .drift_detector(DriftDetectorType::Ddm {
1590                warning_level: 1.5,
1591                drift_level: 2.5,
1592                min_instances: 10,
1593            })
1594            .build();
1595        assert!(result.is_ok());
1596    }
1597
1598    // ------------------------------------------------------------------
1599    // 9. Created detectors are functional (round-trip test)
1600    // ------------------------------------------------------------------
1601    #[test]
1602    fn created_detectors_are_functional() {
1603        // Create a detector, feed it data, verify it can detect drift.
1604        let dt = DriftDetectorType::PageHinkley {
1605            delta: 0.005,
1606            lambda: 50.0,
1607        };
1608        let mut detector = dt.create();
1609
1610        // Feed 500 stable values.
1611        for _ in 0..500 {
1612            detector.update(1.0);
1613        }
1614
1615        // Feed shifted values -- should eventually drift.
1616        let mut drifted = false;
1617        for _ in 0..500 {
1618            if detector.update(10.0) == crate::drift::DriftSignal::Drift {
1619                drifted = true;
1620                break;
1621            }
1622        }
1623        assert!(
1624            drifted,
1625            "detector created from DriftDetectorType should be functional"
1626        );
1627    }
1628
1629    // ------------------------------------------------------------------
1630    // 10. Boundary acceptance: n_bins=2 is the minimum valid
1631    // ------------------------------------------------------------------
1632    #[test]
1633    fn boundary_n_bins_two_accepted() {
1634        let result = SGBTConfig::builder().n_bins(2).build();
1635        assert!(result.is_ok());
1636    }
1637
1638    // ------------------------------------------------------------------
1639    // 11. Boundary acceptance: grace_period=1 is valid
1640    // ------------------------------------------------------------------
1641    #[test]
1642    fn boundary_grace_period_one_accepted() {
1643        let result = SGBTConfig::builder().grace_period(1).build();
1644        assert!(result.is_ok());
1645    }
1646
1647    // ------------------------------------------------------------------
1648    // 12. Feature names -- valid config with names
1649    // ------------------------------------------------------------------
1650    #[test]
1651    fn feature_names_accepted() {
1652        let cfg = SGBTConfig::builder()
1653            .feature_names(vec!["price".into(), "volume".into(), "spread".into()])
1654            .build()
1655            .unwrap();
1656        assert_eq!(
1657            cfg.feature_names.as_ref().unwrap(),
1658            &["price", "volume", "spread"]
1659        );
1660    }
1661
1662    // ------------------------------------------------------------------
1663    // 13. Feature names -- duplicate names rejected
1664    // ------------------------------------------------------------------
1665    #[test]
1666    fn feature_names_rejects_duplicates() {
1667        let result = SGBTConfig::builder()
1668            .feature_names(vec!["price".into(), "volume".into(), "price".into()])
1669            .build();
1670        assert!(result.is_err());
1671        let msg = format!("{}", result.unwrap_err());
1672        assert!(msg.contains("duplicate"));
1673    }
1674
1675    // ------------------------------------------------------------------
1676    // 14. Feature names -- serde backward compat (missing field)
1677    //     Requires serde_json which is only in the full irithyll crate.
1678    // ------------------------------------------------------------------
1679
1680    // ------------------------------------------------------------------
1681    // 15. Feature names -- empty vec is valid
1682    // ------------------------------------------------------------------
1683    #[test]
1684    fn feature_names_empty_vec_accepted() {
1685        let cfg = SGBTConfig::builder().feature_names(vec![]).build().unwrap();
1686        assert!(cfg.feature_names.unwrap().is_empty());
1687    }
1688
1689    // ------------------------------------------------------------------
1690    // 16. Adaptive leaf bound -- builder
1691    // ------------------------------------------------------------------
1692    #[test]
1693    fn builder_adaptive_leaf_bound() {
1694        let cfg = SGBTConfig::builder()
1695            .adaptive_leaf_bound(3.0)
1696            .build()
1697            .unwrap();
1698        assert_eq!(cfg.adaptive_leaf_bound, Some(3.0));
1699    }
1700
1701    // ------------------------------------------------------------------
1702    // 17. Adaptive leaf bound -- validation rejects zero
1703    // ------------------------------------------------------------------
1704    #[test]
1705    fn validation_rejects_zero_adaptive_leaf_bound() {
1706        let result = SGBTConfig::builder().adaptive_leaf_bound(0.0).build();
1707        assert!(result.is_err());
1708        let msg = format!("{}", result.unwrap_err());
1709        assert!(
1710            msg.contains("adaptive_leaf_bound"),
1711            "error should mention adaptive_leaf_bound: {}",
1712            msg,
1713        );
1714    }
1715
1716    // ------------------------------------------------------------------
1717    // 18. Adaptive leaf bound -- serde backward compat
1718    //     Requires serde_json which is only in the full irithyll crate.
1719    // ------------------------------------------------------------------
1720}