// irithyll_core/ensemble/config.rs
//! SGBT configuration with builder pattern and full validation.
//!
//! [`SGBTConfig`] holds all hyperparameters for the Streaming Gradient Boosted
//! Trees ensemble. Use [`SGBTConfig::builder`] for ergonomic construction with
//! validation on [`build()`](SGBTConfigBuilder::build).

use alloc::boxed::Box;
use alloc::format;
use alloc::string::String;
use alloc::vec::Vec;

use crate::drift::adwin::Adwin;
use crate::drift::ddm::Ddm;
use crate::drift::pht::PageHinkleyTest;
use crate::drift::DriftDetector;
use crate::ensemble::variants::SGBTVariant;
use crate::error::{ConfigError, Result};
use crate::tree::leaf_model::LeafModelType;

20// ---------------------------------------------------------------------------
21// FeatureType -- re-exported from irithyll-core
22// ---------------------------------------------------------------------------
23
24pub use crate::feature::FeatureType;
25
// ---------------------------------------------------------------------------
// ScaleMode
// ---------------------------------------------------------------------------

/// How [`DistributionalSGBT`](super::distributional::DistributionalSGBT)
/// estimates uncertainty (σ).
///
/// - **`Empirical`** (default): tracks an EWMA of squared prediction errors.
///   `σ = sqrt(ewma_sq_err)`.  Always calibrated, zero tuning, O(1) compute.
///   Use this when σ drives learning-rate modulation (σ high → learn faster).
///
/// - **`TreeChain`**: trains a full second ensemble of Hoeffding trees to predict
///   log(σ) from features (NGBoost-style dual chain).  Gives *feature-conditional*
///   uncertainty but requires strong signal in the scale gradients.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ScaleMode {
    /// EWMA of squared prediction errors — always calibrated, O(1).
    #[default]
    Empirical,
    /// Full Hoeffding-tree ensemble for feature-conditional log(σ) prediction.
    TreeChain,
}

// ---------------------------------------------------------------------------
// DriftDetectorType
// ---------------------------------------------------------------------------

/// Which drift detector to instantiate for each boosting step.
///
/// Each variant stores the detector's configuration parameters so that fresh
/// instances can be created on demand (e.g. when replacing a drifted tree).
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum DriftDetectorType {
    /// Page-Hinkley Test with custom delta (magnitude tolerance) and lambda
    /// (detection threshold).
    PageHinkley {
        /// Magnitude tolerance. Default 0.005.
        delta: f64,
        /// Detection threshold. Default 50.0.
        lambda: f64,
    },

    /// ADWIN with custom confidence parameter.
    Adwin {
        /// Confidence (smaller = fewer false positives). Default 0.002.
        delta: f64,
    },

    /// DDM with custom warning/drift levels and minimum warmup instances.
    Ddm {
        /// Warning threshold multiplier. Default 2.0.
        warning_level: f64,
        /// Drift threshold multiplier. Default 3.0.
        drift_level: f64,
        /// Minimum observations before detection activates. Default 30.
        min_instances: u64,
    },
}

87impl Default for DriftDetectorType {
88    fn default() -> Self {
89        DriftDetectorType::PageHinkley {
90            delta: 0.005,
91            lambda: 50.0,
92        }
93    }
94}
95
96impl DriftDetectorType {
97    /// Create a new, fresh drift detector from this configuration.
98    pub fn create(&self) -> Box<dyn DriftDetector> {
99        match self {
100            Self::PageHinkley { delta, lambda } => {
101                Box::new(PageHinkleyTest::with_params(*delta, *lambda))
102            }
103            Self::Adwin { delta } => Box::new(Adwin::with_delta(*delta)),
104            Self::Ddm {
105                warning_level,
106                drift_level,
107                min_instances,
108            } => Box::new(Ddm::with_params(
109                *warning_level,
110                *drift_level,
111                *min_instances,
112            )),
113        }
114    }
115}
116
117// ---------------------------------------------------------------------------
118// SGBTConfig
119// ---------------------------------------------------------------------------
120
121/// Configuration for the SGBT ensemble.
122///
123/// All numeric parameters are validated at build time via [`SGBTConfigBuilder`].
124///
125/// # Defaults
126///
127/// | Parameter                | Default              |
128/// |--------------------------|----------------------|
129/// | `n_steps`                | 100                  |
130/// | `learning_rate`          | 0.0125               |
131/// | `feature_subsample_rate` | 0.75                 |
132/// | `max_depth`              | 6                    |
133/// | `n_bins`                 | 64                   |
134/// | `lambda`                 | 1.0                  |
135/// | `gamma`                  | 0.0                  |
136/// | `grace_period`           | 200                  |
137/// | `delta`                  | 1e-7                 |
138/// | `drift_detector`         | PageHinkley(0.005, 50.0) |
139/// | `variant`                | Standard             |
140/// | `seed`                   | 0xDEAD_BEEF_CAFE_4242 |
141/// | `initial_target_count`   | 50                   |
142/// | `leaf_half_life`         | None (disabled)      |
143/// | `max_tree_samples`       | None (disabled)      |
144/// | `split_reeval_interval`  | None (disabled)      |
145#[derive(Debug, Clone, PartialEq)]
146#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
147pub struct SGBTConfig {
148    /// Number of boosting steps (trees in the ensemble). Default 100.
149    pub n_steps: usize,
150    /// Learning rate (shrinkage). Default 0.0125.
151    pub learning_rate: f64,
152    /// Fraction of features to subsample per tree. Default 0.75.
153    pub feature_subsample_rate: f64,
154    /// Maximum tree depth. Default 6.
155    pub max_depth: usize,
156    /// Number of histogram bins. Default 64.
157    pub n_bins: usize,
158    /// L2 regularization parameter (lambda). Default 1.0.
159    pub lambda: f64,
160    /// Minimum split gain (gamma). Default 0.0.
161    pub gamma: f64,
162    /// Grace period: min samples before evaluating splits. Default 200.
163    pub grace_period: usize,
164    /// Hoeffding bound confidence (delta). Default 1e-7.
165    pub delta: f64,
166    /// Drift detector type for tree replacement. Default: PageHinkley.
167    pub drift_detector: DriftDetectorType,
168    /// SGBT computational variant. Default: Standard.
169    pub variant: SGBTVariant,
170    /// Random seed for deterministic reproducibility. Default: 0xDEAD_BEEF_CAFE_4242.
171    ///
172    /// Controls feature subsampling and variant skip/MI stochastic decisions.
173    /// Two models with the same seed and same data will produce identical results.
174    pub seed: u64,
175    /// Number of initial targets to collect before computing the base prediction.
176    /// Default: 50.
177    pub initial_target_count: usize,
178
179    /// Half-life for exponential leaf decay (in samples per leaf).
180    ///
181    /// After `leaf_half_life` samples, a leaf's accumulated gradient/hessian
182    /// statistics have half the weight of the most recent sample. This causes
183    /// the model to continuously adapt to changing data distributions rather
184    /// than freezing on early observations.
185    ///
186    /// `None` (default) disables decay -- traditional monotonic accumulation.
187    #[cfg_attr(feature = "serde", serde(default))]
188    pub leaf_half_life: Option<usize>,
189
190    /// Maximum samples a single tree processes before proactive replacement.
191    ///
192    /// After this many samples, the tree is replaced with a fresh one regardless
193    /// of drift detector state. Prevents stale tree structure from persisting
194    /// when the drift detector is not sensitive enough.
195    ///
196    /// `None` (default) disables time-based replacement.
197    #[cfg_attr(feature = "serde", serde(default))]
198    pub max_tree_samples: Option<u64>,
199
200    /// Sigma-modulated tree replacement speed.
201    ///
202    /// When `Some((base_mts, k))`, the effective max_tree_samples becomes:
203    /// `effective_mts = base_mts / (1.0 + k * normalized_sigma)`
204    /// where `normalized_sigma = contribution_sigma / rolling_contribution_sigma_mean`.
205    ///
206    /// High uncertainty (contribution variance) shortens tree lifetime for faster
207    /// adaptation. Low uncertainty lengthens lifetime for more knowledge accumulation.
208    ///
209    /// Overrides `max_tree_samples` when set. Default: `None`.
210    #[cfg_attr(feature = "serde", serde(default))]
211    pub adaptive_mts: Option<(u64, f64)>,
212
213    /// Minimum effective MTS as a fraction of `base_mts`.
214    ///
215    /// When adaptive MTS is active, the effective tree lifetime can shrink
216    /// aggressively under high uncertainty. This floor prevents runaway
217    /// replacement by clamping `effective_mts >= base_mts * fraction`.
218    ///
219    /// The hard floor (grace_period * 2) still applies beneath this.
220    ///
221    /// Default: `0.0` (only the hard floor applies).
222    #[cfg_attr(feature = "serde", serde(default))]
223    pub adaptive_mts_floor: f64,
224
225    /// Proactive tree pruning interval.
226    ///
227    /// Every `interval` training samples, the tree with the lowest EWMA
228    /// marginal contribution is replaced with a fresh tree.
229    /// Fresh trees (fewer than `interval / 2` samples) are excluded.
230    ///
231    /// When set, contribution tracking is automatically enabled even if
232    /// `quality_prune_alpha` is `None` (uses alpha=0.01).
233    ///
234    /// Default: `None` (disabled).
235    #[cfg_attr(feature = "serde", serde(default))]
236    pub proactive_prune_interval: Option<u64>,
237
238    /// Interval (in samples per leaf) at which max-depth leaves re-evaluate
239    /// whether a split would improve them.
240    ///
241    /// Inspired by EFDT (Manapragada et al. 2018). When a leaf has accumulated
242    /// `split_reeval_interval` samples since its last evaluation and has reached
243    /// max depth, it re-evaluates whether a split should be performed.
244    ///
245    /// `None` (default) disables re-evaluation -- max-depth leaves are permanent.
246    #[cfg_attr(feature = "serde", serde(default))]
247    pub split_reeval_interval: Option<usize>,
248
249    /// Optional human-readable feature names.
250    ///
251    /// When set, enables `named_feature_importances` and
252    /// `train_one_named` for production-friendly named access.
253    /// Length must match the number of features in training data.
254    #[cfg_attr(feature = "serde", serde(default))]
255    pub feature_names: Option<Vec<String>>,
256
257    /// Optional per-feature type declarations.
258    ///
259    /// When set, declares which features are categorical vs continuous.
260    /// Categorical features use one-bin-per-category binning and Fisher
261    /// optimal binary partitioning for split evaluation.
262    /// Length must match the number of features in training data.
263    ///
264    /// `None` (default) treats all features as continuous.
265    #[cfg_attr(feature = "serde", serde(default))]
266    pub feature_types: Option<Vec<FeatureType>>,
267
268    /// Gradient clipping threshold in standard deviations per leaf.
269    ///
270    /// When enabled, each leaf tracks an EWMA of gradient mean and variance.
271    /// Incoming gradients that exceed `mean ± sigma * gradient_clip_sigma` are
272    /// clamped to the boundary. This prevents outlier samples from corrupting
273    /// leaf statistics, which is critical in streaming settings where sudden
274    /// label floods can destabilize the model.
275    ///
276    /// Typical value: 3.0 (3-sigma clipping).
277    /// `None` (default) disables gradient clipping.
278    #[cfg_attr(feature = "serde", serde(default))]
279    pub gradient_clip_sigma: Option<f64>,
280
281    /// Per-feature monotonic constraints.
282    ///
283    /// Each element specifies the monotonic relationship between a feature and
284    /// the prediction:
285    /// - `+1`: prediction must be non-decreasing as feature value increases.
286    /// - `-1`: prediction must be non-increasing as feature value increases.
287    /// - `0`: no constraint (unconstrained).
288    ///
289    /// During split evaluation, candidate splits that would violate monotonicity
290    /// (left child value > right child value for +1 constraints, or vice versa)
291    /// are rejected.
292    ///
293    /// Length must match the number of features in training data.
294    /// `None` (default) means no monotonic constraints.
295    #[cfg_attr(feature = "serde", serde(default))]
296    pub monotone_constraints: Option<Vec<i8>>,
297
298    /// EWMA smoothing factor for quality-based tree pruning.
299    ///
300    /// When `Some(alpha)`, each boosting step tracks an exponentially weighted
301    /// moving average of its marginal contribution to the ensemble. Trees whose
302    /// contribution drops below [`quality_prune_threshold`](Self::quality_prune_threshold)
303    /// for [`quality_prune_patience`](Self::quality_prune_patience) consecutive
304    /// samples are replaced with a fresh tree that can learn the current regime.
305    ///
306    /// This prevents "dead wood" -- trees from a past regime that no longer
307    /// contribute meaningfully to ensemble accuracy.
308    ///
309    /// `None` (default) disables quality-based pruning.
310    /// Suggested value: 0.01.
311    #[cfg_attr(feature = "serde", serde(default))]
312    pub quality_prune_alpha: Option<f64>,
313
314    /// Minimum contribution threshold for quality-based pruning.
315    ///
316    /// A tree's EWMA contribution must stay above this value to avoid being
317    /// flagged as dead wood. Only used when `quality_prune_alpha` is `Some`.
318    ///
319    /// Default: 1e-6.
320    #[cfg_attr(feature = "serde", serde(default = "default_quality_prune_threshold"))]
321    pub quality_prune_threshold: f64,
322
323    /// Consecutive low-contribution samples before a tree is replaced.
324    ///
325    /// After this many consecutive samples where a tree's EWMA contribution
326    /// is below `quality_prune_threshold`, the tree is reset. Only used when
327    /// `quality_prune_alpha` is `Some`.
328    ///
329    /// Default: 500.
330    #[cfg_attr(feature = "serde", serde(default = "default_quality_prune_patience"))]
331    pub quality_prune_patience: u64,
332
333    /// EWMA smoothing factor for error-weighted sample importance.
334    ///
335    /// When `Some(alpha)`, samples the model predicted poorly get higher
336    /// effective weight during histogram accumulation. The weight is:
337    /// `1.0 + |error| / (rolling_mean_error + epsilon)`, capped at 10x.
338    ///
339    /// This is a streaming version of AdaBoost's reweighting applied at the
340    /// gradient level -- learning capacity focuses on hard/novel patterns,
341    /// enabling faster adaptation to regime changes.
342    ///
343    /// `None` (default) disables error weighting.
344    /// Suggested value: 0.01.
345    #[cfg_attr(feature = "serde", serde(default))]
346    pub error_weight_alpha: Option<f64>,
347
348    /// Enable σ-modulated learning rate for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
349    ///
350    /// When `true`, the **location** (μ) ensemble's learning rate is scaled by
351    /// `sigma_ratio = current_sigma / rolling_sigma_mean`, where `rolling_sigma_mean`
352    /// is an EWMA of the model's predicted σ (alpha = 0.001).
353    ///
354    /// This means the model learns μ **faster** when σ is elevated (high uncertainty)
355    /// and **slower** when σ is low (confident regime). The scale (σ) ensemble always
356    /// trains at the unmodulated base rate to prevent positive feedback loops.
357    ///
358    /// Default: `false`.
359    #[cfg_attr(feature = "serde", serde(default))]
360    pub uncertainty_modulated_lr: bool,
361
362    /// How the scale (σ) is estimated in [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
363    ///
364    /// - [`Empirical`](ScaleMode::Empirical) (default): EWMA of squared prediction
365    ///   errors.  `σ = sqrt(ewma_sq_err)`.  Always calibrated, zero tuning, O(1).
366    /// - [`TreeChain`](ScaleMode::TreeChain): full dual-chain NGBoost with a
367    ///   separate tree ensemble predicting log(σ) from features.
368    ///
369    /// For σ-modulated learning (`uncertainty_modulated_lr = true`), `Empirical`
370    /// is strongly recommended — scale tree gradients are inherently weak and
371    /// the trees often fail to split.
372    #[cfg_attr(feature = "serde", serde(default))]
373    pub scale_mode: ScaleMode,
374
375    /// EWMA smoothing factor for empirical σ estimation.
376    ///
377    /// Controls the adaptation speed of `σ = sqrt(ewma_sq_err)` when
378    /// [`scale_mode`](Self::scale_mode) is [`Empirical`](ScaleMode::Empirical).
379    /// Higher values react faster to regime changes but are noisier.
380    ///
381    /// Default: `0.01` (~100-sample effective window).
382    #[cfg_attr(feature = "serde", serde(default = "default_empirical_sigma_alpha"))]
383    pub empirical_sigma_alpha: f64,
384
385    /// Maximum absolute leaf output value.
386    ///
387    /// When `Some(max)`, leaf predictions are clamped to `[-max, max]`.
388    /// Prevents runaway leaf weights from causing prediction explosions
389    /// in feedback loops. `None` (default) means no clamping.
390    #[cfg_attr(feature = "serde", serde(default))]
391    pub max_leaf_output: Option<f64>,
392
393    /// Per-leaf adaptive output bound (sigma multiplier).
394    ///
395    /// When `Some(k)`, each leaf tracks an EWMA of its own output weight and
396    /// clamps predictions to `|output_mean| + k * output_std`. The EWMA uses
397    /// `leaf_decay_alpha` when `leaf_half_life` is set, otherwise Welford online.
398    ///
399    /// This is strictly superior to `max_leaf_output` for streaming — the bound
400    /// is per-leaf, self-calibrating, and regime-synchronized. A leaf that usually
401    /// outputs 0.3 can't suddenly output 2.9 just because it fits in the global clamp.
402    ///
403    /// Typical value: 3.0 (3-sigma bound).
404    /// `None` (default) disables adaptive bounds (falls back to `max_leaf_output`).
405    #[cfg_attr(feature = "serde", serde(default))]
406    pub adaptive_leaf_bound: Option<f64>,
407
408    /// Per-split information criterion (Lunde-Kleppe-Skaug 2020).
409    ///
410    /// When `Some(cir_factor)`, replaces `max_depth` with a per-split
411    /// generalization test. Each candidate split must have
412    /// `gain > cir_factor * sigma^2_g / n * n_features`.
413    /// `max_depth * 2` becomes a hard safety ceiling.
414    ///
415    /// Typical: 7.5 (<=10 features), 9.0 (<=50), 11.0 (<=200).
416    /// `None` (default) uses static `max_depth` only.
417    #[cfg_attr(feature = "serde", serde(default))]
418    pub adaptive_depth: Option<f64>,
419
420    /// Minimum hessian sum before a leaf produces non-zero output.
421    ///
422    /// When `Some(min_h)`, leaves with `hess_sum < min_h` return 0.0.
423    /// Prevents post-replacement spikes from fresh leaves with insufficient
424    /// samples. `None` (default) means all leaves contribute immediately.
425    #[cfg_attr(feature = "serde", serde(default))]
426    pub min_hessian_sum: Option<f64>,
427
428    /// Huber loss delta multiplier for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
429    ///
430    /// When `Some(k)`, the distributional location gradient uses Huber loss
431    /// with adaptive `delta = k * empirical_sigma`. This bounds gradients by
432    /// construction. Standard value: `1.345` (95% efficiency at Gaussian).
433    /// `None` (default) uses squared loss.
434    #[cfg_attr(feature = "serde", serde(default))]
435    pub huber_k: Option<f64>,
436
437    /// Shadow warmup for graduated tree handoff.
438    ///
439    /// When `Some(n)`, an always-on shadow (alternate) tree is spawned immediately
440    /// alongside every active tree. The shadow trains on the same gradient stream
441    /// but does not contribute to predictions until it has seen `n` samples.
442    ///
443    /// As the active tree ages past 80% of `max_tree_samples`, its prediction
444    /// weight linearly decays to 0 at 120%. The shadow's weight ramps from 0 to 1
445    /// over `n` samples after warmup. When the active weight reaches 0, the shadow
446    /// is promoted and a new shadow is spawned — no cold-start prediction dip.
447    ///
448    /// Requires `max_tree_samples` to be set for time-based graduated handoff.
449    /// Drift-based replacement still uses hard swap (shadow is already warm).
450    ///
451    /// `None` (default) disables graduated handoff — uses traditional hard swap.
452    #[cfg_attr(feature = "serde", serde(default))]
453    pub shadow_warmup: Option<usize>,
454
455    /// Leaf prediction model type.
456    ///
457    /// Controls how each leaf computes its prediction:
458    /// - [`ClosedForm`](LeafModelType::ClosedForm) (default): constant leaf weight.
459    /// - [`Linear`](LeafModelType::Linear): per-leaf online ridge regression with
460    ///   AdaGrad optimization. Optional `decay` for concept drift. Recommended for
461    ///   low-depth trees (depth 2--4).
462    /// - [`MLP`](LeafModelType::MLP): per-leaf single-hidden-layer neural network.
463    ///   Optional `decay` for concept drift.
464    /// - [`Adaptive`](LeafModelType::Adaptive): starts as closed-form, auto-promotes
465    ///   when the Hoeffding bound confirms a more complex model is better.
466    ///
467    /// Default: [`ClosedForm`](LeafModelType::ClosedForm).
468    #[cfg_attr(feature = "serde", serde(default))]
469    pub leaf_model_type: LeafModelType,
470
471    /// Packed cache refresh interval for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
472    ///
473    /// When non-zero, the distributional model maintains a packed f32 cache of
474    /// its location ensemble that is re-exported every `packed_refresh_interval`
475    /// training samples. Predictions use the cache for O(1)-per-tree inference
476    /// via contiguous memory traversal, falling back to full tree traversal when
477    /// the cache is absent or produces non-finite results.
478    ///
479    /// `0` (default) disables the packed cache.
480    #[cfg_attr(feature = "serde", serde(default))]
481    pub packed_refresh_interval: u64,
482
483    /// Enable per-node auto-bandwidth soft routing at prediction time.
484    ///
485    /// When `true`, predictions are continuous weighted blends instead of
486    /// piecewise-constant step functions. No training changes.
487    ///
488    /// Default: `false`.
489    #[cfg_attr(feature = "serde", serde(default))]
490    pub soft_routing: bool,
491}
492
// Serde field-default helpers. These mirror the values set in
// `SGBTConfig::default()` so that deserializing a config with these fields
// omitted yields the same values as a freshly constructed default config.
#[cfg(feature = "serde")]
fn default_empirical_sigma_alpha() -> f64 {
    0.01
}

#[cfg(feature = "serde")]
fn default_quality_prune_threshold() -> f64 {
    1e-6
}

#[cfg(feature = "serde")]
fn default_quality_prune_patience() -> u64 {
    500
}

508impl Default for SGBTConfig {
509    fn default() -> Self {
510        Self {
511            n_steps: 100,
512            learning_rate: 0.0125,
513            feature_subsample_rate: 0.75,
514            max_depth: 6,
515            n_bins: 64,
516            lambda: 1.0,
517            gamma: 0.0,
518            grace_period: 200,
519            delta: 1e-7,
520            drift_detector: DriftDetectorType::default(),
521            variant: SGBTVariant::default(),
522            seed: 0xDEAD_BEEF_CAFE_4242,
523            initial_target_count: 50,
524            leaf_half_life: None,
525            max_tree_samples: None,
526            adaptive_mts: None,
527            adaptive_mts_floor: 0.0,
528            proactive_prune_interval: None,
529            split_reeval_interval: None,
530            feature_names: None,
531            feature_types: None,
532            gradient_clip_sigma: None,
533            monotone_constraints: None,
534            quality_prune_alpha: None,
535            quality_prune_threshold: 1e-6,
536            quality_prune_patience: 500,
537            error_weight_alpha: None,
538            uncertainty_modulated_lr: false,
539            scale_mode: ScaleMode::default(),
540            empirical_sigma_alpha: 0.01,
541            max_leaf_output: None,
542            adaptive_leaf_bound: None,
543            adaptive_depth: None,
544            min_hessian_sum: None,
545            huber_k: None,
546            shadow_warmup: None,
547            leaf_model_type: LeafModelType::default(),
548            packed_refresh_interval: 0,
549            soft_routing: false,
550        }
551    }
552}
553
554impl SGBTConfig {
555    /// Start building a configuration via the builder pattern.
556    pub fn builder() -> SGBTConfigBuilder {
557        SGBTConfigBuilder::default()
558    }
559}
560
561// ---------------------------------------------------------------------------
562// SGBTConfigBuilder
563// ---------------------------------------------------------------------------
564
565/// Builder for [`SGBTConfig`] with validation on [`build()`](Self::build).
566///
567/// # Example
568///
569/// ```ignore
570/// use irithyll::ensemble::config::{SGBTConfig, DriftDetectorType};
571/// use irithyll::ensemble::variants::SGBTVariant;
572///
573/// let config = SGBTConfig::builder()
574///     .n_steps(200)
575///     .learning_rate(0.05)
576///     .drift_detector(DriftDetectorType::Adwin { delta: 0.01 })
577///     .variant(SGBTVariant::Skip { k: 10 })
578///     .build()
579///     .expect("valid config");
580/// ```
581#[derive(Debug, Clone, Default)]
582pub struct SGBTConfigBuilder {
583    config: SGBTConfig,
584}
585
586impl SGBTConfigBuilder {
587    /// Set the number of boosting steps (trees in the ensemble).
588    pub fn n_steps(mut self, n: usize) -> Self {
589        self.config.n_steps = n;
590        self
591    }
592
593    /// Set the learning rate (shrinkage factor).
594    pub fn learning_rate(mut self, lr: f64) -> Self {
595        self.config.learning_rate = lr;
596        self
597    }
598
599    /// Set the fraction of features to subsample per tree.
600    pub fn feature_subsample_rate(mut self, rate: f64) -> Self {
601        self.config.feature_subsample_rate = rate;
602        self
603    }
604
605    /// Set the maximum tree depth.
606    pub fn max_depth(mut self, depth: usize) -> Self {
607        self.config.max_depth = depth;
608        self
609    }
610
611    /// Set the number of histogram bins per feature.
612    pub fn n_bins(mut self, bins: usize) -> Self {
613        self.config.n_bins = bins;
614        self
615    }
616
617    /// Set the L2 regularization parameter (lambda).
618    pub fn lambda(mut self, l: f64) -> Self {
619        self.config.lambda = l;
620        self
621    }
622
623    /// Set the minimum split gain (gamma).
624    pub fn gamma(mut self, g: f64) -> Self {
625        self.config.gamma = g;
626        self
627    }
628
629    /// Set the grace period (minimum samples before evaluating splits).
630    pub fn grace_period(mut self, gp: usize) -> Self {
631        self.config.grace_period = gp;
632        self
633    }
634
635    /// Set the Hoeffding bound confidence parameter (delta).
636    pub fn delta(mut self, d: f64) -> Self {
637        self.config.delta = d;
638        self
639    }
640
641    /// Set the drift detector type for tree replacement.
642    pub fn drift_detector(mut self, dt: DriftDetectorType) -> Self {
643        self.config.drift_detector = dt;
644        self
645    }
646
647    /// Set the SGBT computational variant.
648    pub fn variant(mut self, v: SGBTVariant) -> Self {
649        self.config.variant = v;
650        self
651    }
652
653    /// Set the random seed for deterministic reproducibility.
654    ///
655    /// Controls feature subsampling and variant skip/MI stochastic decisions.
656    /// Two models with the same seed and data sequence will produce identical results.
657    pub fn seed(mut self, seed: u64) -> Self {
658        self.config.seed = seed;
659        self
660    }
661
662    /// Set the number of initial targets to collect before computing the base prediction.
663    ///
664    /// The model collects this many target values before initializing the base
665    /// prediction (via `loss.initial_prediction`). Default: 50.
666    pub fn initial_target_count(mut self, count: usize) -> Self {
667        self.config.initial_target_count = count;
668        self
669    }
670
671    /// Set the half-life for exponential leaf decay (in samples per leaf).
672    ///
673    /// After `n` samples, a leaf's accumulated statistics have half the weight
674    /// of the most recent sample. Enables continuous adaptation to concept drift.
675    pub fn leaf_half_life(mut self, n: usize) -> Self {
676        self.config.leaf_half_life = Some(n);
677        self
678    }
679
680    /// Set the maximum samples a single tree processes before proactive replacement.
681    ///
682    /// After `n` samples, the tree is replaced regardless of drift detector state.
683    pub fn max_tree_samples(mut self, n: u64) -> Self {
684        self.config.max_tree_samples = Some(n);
685        self
686    }
687
688    /// Enable sigma-modulated tree replacement speed.
689    ///
690    /// The effective max_tree_samples becomes `base_mts / (1.0 + k * normalized_sigma)`,
691    /// where `normalized_sigma` tracks contribution variance across ensemble trees.
692    /// Overrides `max_tree_samples` when set.
693    pub fn adaptive_mts(mut self, base_mts: u64, k: f64) -> Self {
694        self.config.adaptive_mts = Some((base_mts, k));
695        self
696    }
697
698    /// Set a minimum effective MTS as a fraction of `base_mts`.
699    ///
700    /// Prevents adaptive MTS from shrinking tree lifetime below
701    /// `base_mts * fraction`. For example, `0.25` ensures effective MTS
702    /// never drops below 25% of `base_mts`.
703    ///
704    /// Only meaningful when `adaptive_mts` is also set.
705    pub fn adaptive_mts_floor(mut self, fraction: f64) -> Self {
706        self.config.adaptive_mts_floor = fraction;
707        self
708    }
709
710    /// Set the proactive tree pruning interval.
711    ///
712    /// Every `interval` training samples, the tree with the lowest EWMA
713    /// marginal contribution is replaced. Automatically enables contribution
714    /// tracking with alpha=0.01 when `quality_prune_alpha` is not set.
715    pub fn proactive_prune_interval(mut self, interval: u64) -> Self {
716        self.config.proactive_prune_interval = Some(interval);
717        self
718    }
719
720    /// Set the split re-evaluation interval for max-depth leaves.
721    ///
722    /// Every `n` samples per leaf, max-depth leaves re-evaluate whether a split
723    /// would improve them. Inspired by EFDT (Manapragada et al. 2018).
724    pub fn split_reeval_interval(mut self, n: usize) -> Self {
725        self.config.split_reeval_interval = Some(n);
726        self
727    }
728
729    /// Set human-readable feature names.
730    ///
731    /// Enables named feature importances and named training input.
732    /// Names must be unique; validated at [`build()`](Self::build).
733    pub fn feature_names(mut self, names: Vec<String>) -> Self {
734        self.config.feature_names = Some(names);
735        self
736    }
737
738    /// Set per-feature type declarations.
739    ///
740    /// Declares which features are categorical vs continuous. Categorical features
741    /// use one-bin-per-category binning and Fisher optimal binary partitioning.
742    /// Supports up to 64 distinct category values per categorical feature.
743    pub fn feature_types(mut self, types: Vec<FeatureType>) -> Self {
744        self.config.feature_types = Some(types);
745        self
746    }
747
748    /// Set per-leaf gradient clipping threshold (in standard deviations).
749    ///
750    /// Each leaf tracks an EWMA of gradient mean and variance. Gradients
751    /// exceeding `mean ± sigma * n` are clamped. Prevents outlier labels
752    /// from corrupting streaming model stability.
753    ///
754    /// Typical value: 3.0 (3-sigma clipping).
755    pub fn gradient_clip_sigma(mut self, sigma: f64) -> Self {
756        self.config.gradient_clip_sigma = Some(sigma);
757        self
758    }
759
760    /// Set per-feature monotonic constraints.
761    ///
762    /// `+1` = non-decreasing, `-1` = non-increasing, `0` = unconstrained.
763    /// Candidate splits violating monotonicity are rejected during tree growth.
764    pub fn monotone_constraints(mut self, constraints: Vec<i8>) -> Self {
765        self.config.monotone_constraints = Some(constraints);
766        self
767    }
768
769    /// Enable quality-based tree pruning with the given EWMA smoothing factor.
770    ///
771    /// Trees whose marginal contribution drops below the threshold for
772    /// `patience` consecutive samples are replaced with fresh trees.
773    /// Suggested alpha: 0.01.
774    pub fn quality_prune_alpha(mut self, alpha: f64) -> Self {
775        self.config.quality_prune_alpha = Some(alpha);
776        self
777    }
778
779    /// Set the minimum contribution threshold for quality-based pruning.
780    ///
781    /// Default: 1e-6. Only relevant when `quality_prune_alpha` is set.
782    pub fn quality_prune_threshold(mut self, threshold: f64) -> Self {
783        self.config.quality_prune_threshold = threshold;
784        self
785    }
786
787    /// Set the patience (consecutive low-contribution samples) before pruning.
788    ///
789    /// Default: 500. Only relevant when `quality_prune_alpha` is set.
790    pub fn quality_prune_patience(mut self, patience: u64) -> Self {
791        self.config.quality_prune_patience = patience;
792        self
793    }
794
795    /// Enable error-weighted sample importance with the given EWMA smoothing factor.
796    ///
797    /// Samples the model predicted poorly get higher effective weight.
798    /// Suggested alpha: 0.01.
799    pub fn error_weight_alpha(mut self, alpha: f64) -> Self {
800        self.config.error_weight_alpha = Some(alpha);
801        self
802    }
803
804    /// Enable σ-modulated learning rate for distributional models.
805    ///
806    /// Scales the location (μ) learning rate by `current_sigma / rolling_sigma_mean`,
807    /// so the model adapts faster during high-uncertainty regimes and conserves
808    /// during stable periods. Only affects [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
809    ///
810    /// By default uses empirical σ (EWMA of squared errors).  Set
811    /// [`scale_mode(ScaleMode::TreeChain)`](Self::scale_mode) for feature-conditional σ.
812    pub fn uncertainty_modulated_lr(mut self, enabled: bool) -> Self {
813        self.config.uncertainty_modulated_lr = enabled;
814        self
815    }
816
817    /// Set the scale estimation mode for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
818    ///
819    /// - [`Empirical`](ScaleMode::Empirical): EWMA of squared prediction errors (default, recommended).
820    /// - [`TreeChain`](ScaleMode::TreeChain): dual-chain NGBoost with scale tree ensemble.
821    pub fn scale_mode(mut self, mode: ScaleMode) -> Self {
822        self.config.scale_mode = mode;
823        self
824    }
825
826    /// Enable per-node auto-bandwidth soft routing.
827    pub fn soft_routing(mut self, enabled: bool) -> Self {
828        self.config.soft_routing = enabled;
829        self
830    }
831
832    /// EWMA alpha for empirical σ. Controls adaptation speed. Default `0.01`.
833    ///
834    /// Only used when `scale_mode` is [`Empirical`](ScaleMode::Empirical).
835    pub fn empirical_sigma_alpha(mut self, alpha: f64) -> Self {
836        self.config.empirical_sigma_alpha = alpha;
837        self
838    }
839
840    /// Set the maximum absolute leaf output value.
841    ///
842    /// Clamps leaf predictions to `[-max, max]`, breaking feedback loops
843    /// that cause prediction explosions.
844    pub fn max_leaf_output(mut self, max: f64) -> Self {
845        self.config.max_leaf_output = Some(max);
846        self
847    }
848
849    /// Set per-leaf adaptive output bound (sigma multiplier).
850    ///
851    /// Each leaf tracks EWMA of its own output weight and clamps to
852    /// `|output_mean| + k * output_std`. Self-calibrating per-leaf.
853    /// Recommended: use with `leaf_half_life` for streaming scenarios.
854    pub fn adaptive_leaf_bound(mut self, k: f64) -> Self {
855        self.config.adaptive_leaf_bound = Some(k);
856        self
857    }
858
859    /// Set the per-split information criterion factor (Lunde-Kleppe-Skaug 2020).
860    ///
861    /// Replaces static `max_depth` with a per-split generalization test.
862    /// Typical: 7.5 (<=10 features), 9.0 (<=50), 11.0 (<=200).
863    pub fn adaptive_depth(mut self, factor: f64) -> Self {
864        self.config.adaptive_depth = Some(factor);
865        self
866    }
867
868    /// Set the minimum hessian sum for leaf output.
869    ///
870    /// Fresh leaves with `hess_sum < min_h` return 0.0, preventing
871    /// post-replacement spikes.
872    pub fn min_hessian_sum(mut self, min_h: f64) -> Self {
873        self.config.min_hessian_sum = Some(min_h);
874        self
875    }
876
877    /// Set the Huber loss delta multiplier for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
878    ///
879    /// When set, location gradients use Huber loss with adaptive
880    /// `delta = k * empirical_sigma`. Standard value: `1.345` (95% Gaussian efficiency).
881    pub fn huber_k(mut self, k: f64) -> Self {
882        self.config.huber_k = Some(k);
883        self
884    }
885
886    /// Enable graduated tree handoff with the given shadow warmup samples.
887    ///
888    /// Spawns an always-on shadow tree that trains alongside the active tree.
889    /// After `warmup` samples, the shadow begins contributing to predictions
890    /// via graduated blending. Eliminates prediction dips during tree replacement.
891    pub fn shadow_warmup(mut self, warmup: usize) -> Self {
892        self.config.shadow_warmup = Some(warmup);
893        self
894    }
895
896    /// Set the leaf prediction model type.
897    ///
898    /// [`LeafModelType::Linear`] is recommended for low-depth configurations
899    /// (depth 2--4) where per-leaf linear models reduce approximation error.
900    ///
901    /// [`LeafModelType::Adaptive`] automatically selects between closed-form and
902    /// a trainable model per leaf, using the Hoeffding bound for promotion.
903    pub fn leaf_model_type(mut self, lmt: LeafModelType) -> Self {
904        self.config.leaf_model_type = lmt;
905        self
906    }
907
908    /// Set the packed cache refresh interval for distributional models.
909    ///
910    /// When non-zero, [`DistributionalSGBT`](super::distributional::DistributionalSGBT)
911    /// maintains a packed f32 cache refreshed every `interval` training samples.
912    /// `0` (default) disables the cache.
913    pub fn packed_refresh_interval(mut self, interval: u64) -> Self {
914        self.config.packed_refresh_interval = interval;
915        self
916    }
917
918    /// Validate and build the configuration.
919    ///
920    /// # Errors
921    ///
922    /// Returns [`InvalidConfig`](crate::IrithyllError::InvalidConfig) with a structured
923    /// [`ConfigError`] if any parameter is out of its valid range.
924    pub fn build(self) -> Result<SGBTConfig> {
925        let c = &self.config;
926
927        // -- Ensemble-level parameters --
928        if c.n_steps == 0 {
929            return Err(ConfigError::out_of_range("n_steps", "must be > 0", c.n_steps).into());
930        }
931        if c.learning_rate <= 0.0 || c.learning_rate > 1.0 {
932            return Err(ConfigError::out_of_range(
933                "learning_rate",
934                "must be in (0, 1]",
935                c.learning_rate,
936            )
937            .into());
938        }
939        if c.feature_subsample_rate <= 0.0 || c.feature_subsample_rate > 1.0 {
940            return Err(ConfigError::out_of_range(
941                "feature_subsample_rate",
942                "must be in (0, 1]",
943                c.feature_subsample_rate,
944            )
945            .into());
946        }
947
948        // -- Tree-level parameters --
949        if c.max_depth == 0 {
950            return Err(ConfigError::out_of_range("max_depth", "must be > 0", c.max_depth).into());
951        }
952        if c.n_bins < 2 {
953            return Err(ConfigError::out_of_range("n_bins", "must be >= 2", c.n_bins).into());
954        }
955        if c.lambda < 0.0 {
956            return Err(ConfigError::out_of_range("lambda", "must be >= 0", c.lambda).into());
957        }
958        if c.gamma < 0.0 {
959            return Err(ConfigError::out_of_range("gamma", "must be >= 0", c.gamma).into());
960        }
961        if c.grace_period == 0 {
962            return Err(
963                ConfigError::out_of_range("grace_period", "must be > 0", c.grace_period).into(),
964            );
965        }
966        if c.delta <= 0.0 || c.delta >= 1.0 {
967            return Err(ConfigError::out_of_range("delta", "must be in (0, 1)", c.delta).into());
968        }
969
970        if c.initial_target_count == 0 {
971            return Err(ConfigError::out_of_range(
972                "initial_target_count",
973                "must be > 0",
974                c.initial_target_count,
975            )
976            .into());
977        }
978
979        // -- Streaming adaptation parameters --
980        if let Some(hl) = c.leaf_half_life {
981            if hl == 0 {
982                return Err(ConfigError::out_of_range("leaf_half_life", "must be >= 1", hl).into());
983            }
984        }
985        if let Some(max) = c.max_tree_samples {
986            if max < 100 {
987                return Err(
988                    ConfigError::out_of_range("max_tree_samples", "must be >= 100", max).into(),
989                );
990            }
991        }
992        if let Some((base_mts, k)) = c.adaptive_mts {
993            if base_mts < 100 {
994                return Err(ConfigError::out_of_range(
995                    "adaptive_mts.base_mts",
996                    "must be >= 100",
997                    base_mts,
998                )
999                .into());
1000            }
1001            if k <= 0.0 {
1002                return Err(ConfigError::out_of_range("adaptive_mts.k", "must be > 0", k).into());
1003            }
1004        }
1005        if !(0.0..=1.0).contains(&c.adaptive_mts_floor) {
1006            return Err(ConfigError::out_of_range(
1007                "adaptive_mts_floor",
1008                "must be in [0.0, 1.0]",
1009                c.adaptive_mts_floor,
1010            )
1011            .into());
1012        }
1013        if let Some(interval) = c.proactive_prune_interval {
1014            if interval < 100 {
1015                return Err(ConfigError::out_of_range(
1016                    "proactive_prune_interval",
1017                    "must be >= 100",
1018                    interval,
1019                )
1020                .into());
1021            }
1022        }
1023        if let Some(interval) = c.split_reeval_interval {
1024            if interval < c.grace_period {
1025                return Err(ConfigError::invalid(
1026                    "split_reeval_interval",
1027                    format!(
1028                        "must be >= grace_period ({}), got {}",
1029                        c.grace_period, interval
1030                    ),
1031                )
1032                .into());
1033            }
1034        }
1035
1036        // -- Feature names --
1037        if let Some(ref names) = c.feature_names {
1038            // O(n^2) duplicate check — names list is small so no HashSet needed
1039            for (i, name) in names.iter().enumerate() {
1040                for prev in &names[..i] {
1041                    if name == prev {
1042                        return Err(ConfigError::invalid(
1043                            "feature_names",
1044                            format!("duplicate feature name: '{}'", name),
1045                        )
1046                        .into());
1047                    }
1048                }
1049            }
1050        }
1051
1052        // -- Feature types --
1053        if let Some(ref types) = c.feature_types {
1054            if let Some(ref names) = c.feature_names {
1055                if !names.is_empty() && !types.is_empty() && names.len() != types.len() {
1056                    return Err(ConfigError::invalid(
1057                        "feature_types",
1058                        format!(
1059                            "length ({}) must match feature_names length ({})",
1060                            types.len(),
1061                            names.len()
1062                        ),
1063                    )
1064                    .into());
1065                }
1066            }
1067        }
1068
1069        // -- Gradient clipping --
1070        if let Some(sigma) = c.gradient_clip_sigma {
1071            if sigma <= 0.0 {
1072                return Err(
1073                    ConfigError::out_of_range("gradient_clip_sigma", "must be > 0", sigma).into(),
1074                );
1075            }
1076        }
1077
1078        // -- Monotonic constraints --
1079        if let Some(ref mc) = c.monotone_constraints {
1080            for (i, &v) in mc.iter().enumerate() {
1081                if v != -1 && v != 0 && v != 1 {
1082                    return Err(ConfigError::invalid(
1083                        "monotone_constraints",
1084                        format!("feature {}: must be -1, 0, or +1, got {}", i, v),
1085                    )
1086                    .into());
1087                }
1088            }
1089        }
1090
1091        // -- Leaf output clamping --
1092        if let Some(max) = c.max_leaf_output {
1093            if max <= 0.0 {
1094                return Err(
1095                    ConfigError::out_of_range("max_leaf_output", "must be > 0", max).into(),
1096                );
1097            }
1098        }
1099
1100        // -- Per-leaf adaptive output bound --
1101        if let Some(k) = c.adaptive_leaf_bound {
1102            if k <= 0.0 {
1103                return Err(
1104                    ConfigError::out_of_range("adaptive_leaf_bound", "must be > 0", k).into(),
1105                );
1106            }
1107        }
1108
1109        // -- Adaptive depth (per-split information criterion) --
1110        if let Some(factor) = c.adaptive_depth {
1111            if factor <= 0.0 {
1112                return Err(
1113                    ConfigError::out_of_range("adaptive_depth", "must be > 0", factor).into(),
1114                );
1115            }
1116        }
1117
1118        // -- Minimum hessian sum --
1119        if let Some(min_h) = c.min_hessian_sum {
1120            if min_h <= 0.0 {
1121                return Err(
1122                    ConfigError::out_of_range("min_hessian_sum", "must be > 0", min_h).into(),
1123                );
1124            }
1125        }
1126
1127        // -- Huber loss multiplier --
1128        if let Some(k) = c.huber_k {
1129            if k <= 0.0 {
1130                return Err(ConfigError::out_of_range("huber_k", "must be > 0", k).into());
1131            }
1132        }
1133
1134        // -- Shadow warmup --
1135        if let Some(warmup) = c.shadow_warmup {
1136            if warmup == 0 {
1137                return Err(ConfigError::out_of_range(
1138                    "shadow_warmup",
1139                    "must be > 0",
1140                    warmup as f64,
1141                )
1142                .into());
1143            }
1144        }
1145
1146        // -- Quality-based pruning parameters --
1147        if let Some(alpha) = c.quality_prune_alpha {
1148            if alpha <= 0.0 || alpha >= 1.0 {
1149                return Err(ConfigError::out_of_range(
1150                    "quality_prune_alpha",
1151                    "must be in (0, 1)",
1152                    alpha,
1153                )
1154                .into());
1155            }
1156        }
1157        if c.quality_prune_threshold <= 0.0 {
1158            return Err(ConfigError::out_of_range(
1159                "quality_prune_threshold",
1160                "must be > 0",
1161                c.quality_prune_threshold,
1162            )
1163            .into());
1164        }
1165        if c.quality_prune_patience == 0 {
1166            return Err(ConfigError::out_of_range(
1167                "quality_prune_patience",
1168                "must be > 0",
1169                c.quality_prune_patience,
1170            )
1171            .into());
1172        }
1173
1174        // -- Error-weighted sample importance --
1175        if let Some(alpha) = c.error_weight_alpha {
1176            if alpha <= 0.0 || alpha >= 1.0 {
1177                return Err(ConfigError::out_of_range(
1178                    "error_weight_alpha",
1179                    "must be in (0, 1)",
1180                    alpha,
1181                )
1182                .into());
1183            }
1184        }
1185
1186        // -- Drift detector parameters --
1187        match &c.drift_detector {
1188            DriftDetectorType::PageHinkley { delta, lambda } => {
1189                if *delta <= 0.0 {
1190                    return Err(ConfigError::out_of_range(
1191                        "drift_detector.PageHinkley.delta",
1192                        "must be > 0",
1193                        delta,
1194                    )
1195                    .into());
1196                }
1197                if *lambda <= 0.0 {
1198                    return Err(ConfigError::out_of_range(
1199                        "drift_detector.PageHinkley.lambda",
1200                        "must be > 0",
1201                        lambda,
1202                    )
1203                    .into());
1204                }
1205            }
1206            DriftDetectorType::Adwin { delta } => {
1207                if *delta <= 0.0 || *delta >= 1.0 {
1208                    return Err(ConfigError::out_of_range(
1209                        "drift_detector.Adwin.delta",
1210                        "must be in (0, 1)",
1211                        delta,
1212                    )
1213                    .into());
1214                }
1215            }
1216            DriftDetectorType::Ddm {
1217                warning_level,
1218                drift_level,
1219                min_instances,
1220            } => {
1221                if *warning_level <= 0.0 {
1222                    return Err(ConfigError::out_of_range(
1223                        "drift_detector.Ddm.warning_level",
1224                        "must be > 0",
1225                        warning_level,
1226                    )
1227                    .into());
1228                }
1229                if *drift_level <= 0.0 {
1230                    return Err(ConfigError::out_of_range(
1231                        "drift_detector.Ddm.drift_level",
1232                        "must be > 0",
1233                        drift_level,
1234                    )
1235                    .into());
1236                }
1237                if *drift_level <= *warning_level {
1238                    return Err(ConfigError::invalid(
1239                        "drift_detector.Ddm.drift_level",
1240                        format!(
1241                            "must be > warning_level ({}), got {}",
1242                            warning_level, drift_level
1243                        ),
1244                    )
1245                    .into());
1246                }
1247                if *min_instances == 0 {
1248                    return Err(ConfigError::out_of_range(
1249                        "drift_detector.Ddm.min_instances",
1250                        "must be > 0",
1251                        min_instances,
1252                    )
1253                    .into());
1254                }
1255            }
1256        }
1257
1258        // -- Variant parameters --
1259        match &c.variant {
1260            SGBTVariant::Standard => {} // no extra validation
1261            SGBTVariant::Skip { k } => {
1262                if *k == 0 {
1263                    return Err(
1264                        ConfigError::out_of_range("variant.Skip.k", "must be > 0", k).into(),
1265                    );
1266                }
1267            }
1268            SGBTVariant::MultipleIterations { multiplier } => {
1269                if *multiplier <= 0.0 {
1270                    return Err(ConfigError::out_of_range(
1271                        "variant.MultipleIterations.multiplier",
1272                        "must be > 0",
1273                        multiplier,
1274                    )
1275                    .into());
1276                }
1277            }
1278        }
1279
1280        Ok(self.config)
1281    }
1282}
1283
1284// ---------------------------------------------------------------------------
1285// Display impls
1286// ---------------------------------------------------------------------------
1287
1288impl core::fmt::Display for DriftDetectorType {
1289    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1290        match self {
1291            Self::PageHinkley { delta, lambda } => {
1292                write!(f, "PageHinkley(delta={}, lambda={})", delta, lambda)
1293            }
1294            Self::Adwin { delta } => write!(f, "Adwin(delta={})", delta),
1295            Self::Ddm {
1296                warning_level,
1297                drift_level,
1298                min_instances,
1299            } => write!(
1300                f,
1301                "Ddm(warning={}, drift={}, min_instances={})",
1302                warning_level, drift_level, min_instances
1303            ),
1304        }
1305    }
1306}
1307
1308impl core::fmt::Display for SGBTConfig {
1309    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1310        write!(
1311            f,
1312            "SGBTConfig {{ steps={}, lr={}, depth={}, bins={}, variant={}, drift={} }}",
1313            self.n_steps,
1314            self.learning_rate,
1315            self.max_depth,
1316            self.n_bins,
1317            self.variant,
1318            self.drift_detector,
1319        )
1320    }
1321}
1322
1323// ---------------------------------------------------------------------------
1324// Tests
1325// ---------------------------------------------------------------------------
1326
1327#[cfg(test)]
1328mod tests {
1329    use super::*;
1330    use alloc::format;
1331    use alloc::vec;
1332
    // ------------------------------------------------------------------
    // 1. Default config values are correct
    // ------------------------------------------------------------------
    #[test]
    fn default_config_values() {
        // These constants mirror the documented defaults; update both
        // together if a default ever changes.
        let cfg = SGBTConfig::default();
        assert_eq!(cfg.n_steps, 100);
        assert!((cfg.learning_rate - 0.0125).abs() < f64::EPSILON);
        assert!((cfg.feature_subsample_rate - 0.75).abs() < f64::EPSILON);
        assert_eq!(cfg.max_depth, 6);
        assert_eq!(cfg.n_bins, 64);
        assert!((cfg.lambda - 1.0).abs() < f64::EPSILON);
        assert!((cfg.gamma - 0.0).abs() < f64::EPSILON);
        assert_eq!(cfg.grace_period, 200);
        assert!((cfg.delta - 1e-7).abs() < f64::EPSILON);
        assert_eq!(cfg.variant, SGBTVariant::Standard);
    }
1350
    // ------------------------------------------------------------------
    // 2. Builder chain works
    // ------------------------------------------------------------------
    #[test]
    fn builder_chain() {
        // Every setter should round-trip its value into the built config.
        let cfg = SGBTConfig::builder()
            .n_steps(50)
            .learning_rate(0.1)
            .feature_subsample_rate(0.5)
            .max_depth(10)
            .n_bins(128)
            .lambda(0.5)
            .gamma(0.1)
            .grace_period(500)
            .delta(1e-3)
            .drift_detector(DriftDetectorType::Adwin { delta: 0.01 })
            .variant(SGBTVariant::Skip { k: 5 })
            .build()
            .expect("valid config");

        assert_eq!(cfg.n_steps, 50);
        assert!((cfg.learning_rate - 0.1).abs() < f64::EPSILON);
        assert!((cfg.feature_subsample_rate - 0.5).abs() < f64::EPSILON);
        assert_eq!(cfg.max_depth, 10);
        assert_eq!(cfg.n_bins, 128);
        assert!((cfg.lambda - 0.5).abs() < f64::EPSILON);
        assert!((cfg.gamma - 0.1).abs() < f64::EPSILON);
        assert_eq!(cfg.grace_period, 500);
        assert!((cfg.delta - 1e-3).abs() < f64::EPSILON);
        assert_eq!(cfg.variant, SGBTVariant::Skip { k: 5 });

        // Verify drift detector type is Adwin.
        match &cfg.drift_detector {
            DriftDetectorType::Adwin { delta } => {
                assert!((delta - 0.01).abs() < f64::EPSILON);
            }
            _ => panic!("expected Adwin drift detector"),
        }
    }
1390
    // ------------------------------------------------------------------
    // 3. Validation rejects invalid values
    //
    // Each test probes one boundary of a single parameter's valid range;
    // the error message is asserted to name the offending parameter.
    // ------------------------------------------------------------------
    #[test]
    fn validation_rejects_n_steps_zero() {
        let result = SGBTConfig::builder().n_steps(0).build();
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("n_steps"));
    }

    #[test]
    fn validation_rejects_learning_rate_zero() {
        let result = SGBTConfig::builder().learning_rate(0.0).build();
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("learning_rate"));
    }

    #[test]
    fn validation_rejects_learning_rate_above_one() {
        let result = SGBTConfig::builder().learning_rate(1.5).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_accepts_learning_rate_one() {
        // 1.0 is inclusive: learning_rate is valid over (0, 1].
        let result = SGBTConfig::builder().learning_rate(1.0).build();
        assert!(result.is_ok());
    }

    #[test]
    fn validation_rejects_negative_learning_rate() {
        let result = SGBTConfig::builder().learning_rate(-0.1).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_feature_subsample_zero() {
        let result = SGBTConfig::builder().feature_subsample_rate(0.0).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_feature_subsample_above_one() {
        let result = SGBTConfig::builder().feature_subsample_rate(1.01).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_max_depth_zero() {
        let result = SGBTConfig::builder().max_depth(0).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_n_bins_one() {
        // n_bins must be >= 2; a single bin cannot produce a split.
        let result = SGBTConfig::builder().n_bins(1).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_negative_lambda() {
        let result = SGBTConfig::builder().lambda(-0.1).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_accepts_zero_lambda() {
        // lambda = 0 (no L2 regularization) is explicitly allowed.
        let result = SGBTConfig::builder().lambda(0.0).build();
        assert!(result.is_ok());
    }

    #[test]
    fn validation_rejects_negative_gamma() {
        let result = SGBTConfig::builder().gamma(-0.1).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_grace_period_zero() {
        let result = SGBTConfig::builder().grace_period(0).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_delta_zero() {
        let result = SGBTConfig::builder().delta(0.0).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_delta_one() {
        // delta is valid over the open interval (0, 1); both ends rejected.
        let result = SGBTConfig::builder().delta(1.0).build();
        assert!(result.is_err());
    }
1487
    // ------------------------------------------------------------------
    // 3b. Drift detector parameter validation
    // ------------------------------------------------------------------
    #[test]
    fn validation_rejects_pht_negative_delta() {
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::PageHinkley {
                delta: -1.0,
                lambda: 50.0,
            })
            .build();
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("PageHinkley"));
    }

    #[test]
    fn validation_rejects_pht_zero_lambda() {
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::PageHinkley {
                delta: 0.005,
                lambda: 0.0,
            })
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_adwin_delta_out_of_range() {
        // Both endpoints of (0, 1) are invalid for ADWIN's delta.
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::Adwin { delta: 0.0 })
            .build();
        assert!(result.is_err());

        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::Adwin { delta: 1.0 })
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_ddm_warning_above_drift() {
        // Cross-field rule: drift_level must be strictly greater than
        // warning_level.
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::Ddm {
                warning_level: 3.0,
                drift_level: 2.0,
                min_instances: 30,
            })
            .build();
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("drift_level"));
        assert!(msg.contains("must be > warning_level"));
    }

    #[test]
    fn validation_rejects_ddm_equal_levels() {
        // Equality is also rejected (strict inequality required).
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::Ddm {
                warning_level: 2.0,
                drift_level: 2.0,
                min_instances: 30,
            })
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_ddm_zero_min_instances() {
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::Ddm {
                warning_level: 2.0,
                drift_level: 3.0,
                min_instances: 0,
            })
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_ddm_zero_warning_level() {
        let result = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::Ddm {
                warning_level: 0.0,
                drift_level: 3.0,
                min_instances: 30,
            })
            .build();
        assert!(result.is_err());
    }
1578
1579    // ------------------------------------------------------------------
1580    // 3c. Variant parameter validation
1581    // ------------------------------------------------------------------
1582    #[test]
1583    fn validation_rejects_skip_k_zero() {
1584        let result = SGBTConfig::builder()
1585            .variant(SGBTVariant::Skip { k: 0 })
1586            .build();
1587        assert!(result.is_err());
1588        let msg = format!("{}", result.unwrap_err());
1589        assert!(msg.contains("Skip"));
1590    }
1591
1592    #[test]
1593    fn validation_rejects_mi_zero_multiplier() {
1594        let result = SGBTConfig::builder()
1595            .variant(SGBTVariant::MultipleIterations { multiplier: 0.0 })
1596            .build();
1597        assert!(result.is_err());
1598    }
1599
1600    #[test]
1601    fn validation_rejects_mi_negative_multiplier() {
1602        let result = SGBTConfig::builder()
1603            .variant(SGBTVariant::MultipleIterations { multiplier: -1.0 })
1604            .build();
1605        assert!(result.is_err());
1606    }
1607
1608    #[test]
1609    fn validation_accepts_standard_variant() {
1610        let result = SGBTConfig::builder().variant(SGBTVariant::Standard).build();
1611        assert!(result.is_ok());
1612    }
1613
1614    // ------------------------------------------------------------------
1615    // 4. DriftDetectorType creates correct detector types
1616    // ------------------------------------------------------------------
1617    #[test]
1618    fn drift_detector_type_creates_page_hinkley() {
1619        let dt = DriftDetectorType::PageHinkley {
1620            delta: 0.01,
1621            lambda: 100.0,
1622        };
1623        let mut detector = dt.create();
1624
1625        // Should start with zero mean.
1626        assert_eq!(detector.estimated_mean(), 0.0);
1627
1628        // Feed a stable value -- should not drift.
1629        let signal = detector.update(1.0);
1630        assert_ne!(signal, crate::drift::DriftSignal::Drift);
1631    }
1632
1633    #[test]
1634    fn drift_detector_type_creates_adwin() {
1635        let dt = DriftDetectorType::Adwin { delta: 0.05 };
1636        let mut detector = dt.create();
1637
1638        assert_eq!(detector.estimated_mean(), 0.0);
1639        let signal = detector.update(1.0);
1640        assert_ne!(signal, crate::drift::DriftSignal::Drift);
1641    }
1642
1643    #[test]
1644    fn drift_detector_type_creates_ddm() {
1645        let dt = DriftDetectorType::Ddm {
1646            warning_level: 2.0,
1647            drift_level: 3.0,
1648            min_instances: 30,
1649        };
1650        let mut detector = dt.create();
1651
1652        assert_eq!(detector.estimated_mean(), 0.0);
1653        let signal = detector.update(0.1);
1654        assert_eq!(signal, crate::drift::DriftSignal::Stable);
1655    }
1656
1657    // ------------------------------------------------------------------
1658    // 5. Default drift detector is PageHinkley
1659    // ------------------------------------------------------------------
1660    #[test]
1661    fn default_drift_detector_is_page_hinkley() {
1662        let dt = DriftDetectorType::default();
1663        match dt {
1664            DriftDetectorType::PageHinkley { delta, lambda } => {
1665                assert!((delta - 0.005).abs() < f64::EPSILON);
1666                assert!((lambda - 50.0).abs() < f64::EPSILON);
1667            }
1668            _ => panic!("expected default DriftDetectorType to be PageHinkley"),
1669        }
1670    }
1671
1672    // ------------------------------------------------------------------
1673    // 6. Default config builds without error
1674    // ------------------------------------------------------------------
1675    #[test]
1676    fn default_config_builds() {
1677        let result = SGBTConfig::builder().build();
1678        assert!(result.is_ok());
1679    }
1680
1681    // ------------------------------------------------------------------
1682    // 7. Config clone preserves all fields
1683    // ------------------------------------------------------------------
1684    #[test]
1685    fn config_clone_preserves_fields() {
1686        let original = SGBTConfig::builder()
1687            .n_steps(50)
1688            .learning_rate(0.05)
1689            .variant(SGBTVariant::MultipleIterations { multiplier: 5.0 })
1690            .build()
1691            .unwrap();
1692
1693        let cloned = original.clone();
1694
1695        assert_eq!(cloned.n_steps, 50);
1696        assert!((cloned.learning_rate - 0.05).abs() < f64::EPSILON);
1697        assert_eq!(
1698            cloned.variant,
1699            SGBTVariant::MultipleIterations { multiplier: 5.0 }
1700        );
1701    }
1702
1703    // ------------------------------------------------------------------
1704    // 8. All three DDM valid configs accepted
1705    // ------------------------------------------------------------------
1706    #[test]
1707    fn valid_ddm_config_accepted() {
1708        let result = SGBTConfig::builder()
1709            .drift_detector(DriftDetectorType::Ddm {
1710                warning_level: 1.5,
1711                drift_level: 2.5,
1712                min_instances: 10,
1713            })
1714            .build();
1715        assert!(result.is_ok());
1716    }
1717
1718    // ------------------------------------------------------------------
1719    // 9. Created detectors are functional (round-trip test)
1720    // ------------------------------------------------------------------
1721    #[test]
1722    fn created_detectors_are_functional() {
1723        // Create a detector, feed it data, verify it can detect drift.
1724        let dt = DriftDetectorType::PageHinkley {
1725            delta: 0.005,
1726            lambda: 50.0,
1727        };
1728        let mut detector = dt.create();
1729
1730        // Feed 500 stable values.
1731        for _ in 0..500 {
1732            detector.update(1.0);
1733        }
1734
1735        // Feed shifted values -- should eventually drift.
1736        let mut drifted = false;
1737        for _ in 0..500 {
1738            if detector.update(10.0) == crate::drift::DriftSignal::Drift {
1739                drifted = true;
1740                break;
1741            }
1742        }
1743        assert!(
1744            drifted,
1745            "detector created from DriftDetectorType should be functional"
1746        );
1747    }
1748
1749    // ------------------------------------------------------------------
1750    // 10. Boundary acceptance: n_bins=2 is the minimum valid
1751    // ------------------------------------------------------------------
1752    #[test]
1753    fn boundary_n_bins_two_accepted() {
1754        let result = SGBTConfig::builder().n_bins(2).build();
1755        assert!(result.is_ok());
1756    }
1757
1758    // ------------------------------------------------------------------
1759    // 11. Boundary acceptance: grace_period=1 is valid
1760    // ------------------------------------------------------------------
1761    #[test]
1762    fn boundary_grace_period_one_accepted() {
1763        let result = SGBTConfig::builder().grace_period(1).build();
1764        assert!(result.is_ok());
1765    }
1766
1767    // ------------------------------------------------------------------
1768    // 12. Feature names -- valid config with names
1769    // ------------------------------------------------------------------
1770    #[test]
1771    fn feature_names_accepted() {
1772        let cfg = SGBTConfig::builder()
1773            .feature_names(vec!["price".into(), "volume".into(), "spread".into()])
1774            .build()
1775            .unwrap();
1776        assert_eq!(
1777            cfg.feature_names.as_ref().unwrap(),
1778            &["price", "volume", "spread"]
1779        );
1780    }
1781
1782    // ------------------------------------------------------------------
1783    // 13. Feature names -- duplicate names rejected
1784    // ------------------------------------------------------------------
1785    #[test]
1786    fn feature_names_rejects_duplicates() {
1787        let result = SGBTConfig::builder()
1788            .feature_names(vec!["price".into(), "volume".into(), "price".into()])
1789            .build();
1790        assert!(result.is_err());
1791        let msg = format!("{}", result.unwrap_err());
1792        assert!(msg.contains("duplicate"));
1793    }
1794
1795    // ------------------------------------------------------------------
1796    // 14. Feature names -- serde backward compat (missing field)
1797    //     Requires serde_json which is only in the full irithyll crate.
1798    // ------------------------------------------------------------------
1799
1800    // ------------------------------------------------------------------
1801    // 15. Feature names -- empty vec is valid
1802    // ------------------------------------------------------------------
1803    #[test]
1804    fn feature_names_empty_vec_accepted() {
1805        let cfg = SGBTConfig::builder().feature_names(vec![]).build().unwrap();
1806        assert!(cfg.feature_names.unwrap().is_empty());
1807    }
1808
1809    // ------------------------------------------------------------------
1810    // 16. Adaptive leaf bound -- builder
1811    // ------------------------------------------------------------------
1812    #[test]
1813    fn builder_adaptive_leaf_bound() {
1814        let cfg = SGBTConfig::builder()
1815            .adaptive_leaf_bound(3.0)
1816            .build()
1817            .unwrap();
1818        assert_eq!(cfg.adaptive_leaf_bound, Some(3.0));
1819    }
1820
1821    // ------------------------------------------------------------------
1822    // 17. Adaptive leaf bound -- validation rejects zero
1823    // ------------------------------------------------------------------
1824    #[test]
1825    fn validation_rejects_zero_adaptive_leaf_bound() {
1826        let result = SGBTConfig::builder().adaptive_leaf_bound(0.0).build();
1827        assert!(result.is_err());
1828        let msg = format!("{}", result.unwrap_err());
1829        assert!(
1830            msg.contains("adaptive_leaf_bound"),
1831            "error should mention adaptive_leaf_bound: {}",
1832            msg,
1833        );
1834    }
1835
1836    // ------------------------------------------------------------------
1837    // 18. Adaptive leaf bound -- serde backward compat
1838    //     Requires serde_json which is only in the full irithyll crate.
1839    // ------------------------------------------------------------------
1840}