// irithyll_core/ensemble/config.rs
1//! SGBT configuration with builder pattern and full validation.
2//!
3//! [`SGBTConfig`] holds all hyperparameters for the Streaming Gradient Boosted
4//! Trees ensemble. Use [`SGBTConfig::builder`] for ergonomic construction with
5//! validation on [`build()`](SGBTConfigBuilder::build).
6
7use alloc::boxed::Box;
8use alloc::format;
9use alloc::string::String;
10use alloc::vec::Vec;
11
12use crate::drift::adwin::Adwin;
13use crate::drift::ddm::Ddm;
14use crate::drift::pht::PageHinkleyTest;
15use crate::drift::DriftDetector;
16use crate::ensemble::variants::SGBTVariant;
17use crate::error::{ConfigError, Result};
18use crate::tree::leaf_model::LeafModelType;
19
20// ---------------------------------------------------------------------------
21// FeatureType -- re-exported from irithyll-core
22// ---------------------------------------------------------------------------
23
24pub use crate::feature::FeatureType;
25
26// ---------------------------------------------------------------------------
27// ScaleMode
28// ---------------------------------------------------------------------------
29
/// Strategy used by [`DistributionalSGBT`](super::distributional::DistributionalSGBT)
/// to estimate predictive uncertainty (σ).
///
/// - **`Empirical`** (default): maintains an EWMA of squared prediction errors and
///   reports `σ = sqrt(ewma_sq_err)`.  Always calibrated, needs no tuning, O(1) cost.
///   Prefer this whenever σ drives learning-rate modulation (σ high → learn faster).
///
/// - **`TreeChain`**: fits a complete second ensemble of Hoeffding trees that predicts
///   log(σ) from the features (NGBoost-style dual chain).  Yields *feature-conditional*
///   uncertainty, but only pays off when the scale gradients carry a strong signal.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ScaleMode {
    /// EWMA of squared prediction errors — always calibrated, O(1).
    #[default]
    Empirical,
    /// Full Hoeffding-tree ensemble for feature-conditional log(σ) prediction.
    TreeChain,
}
49
50// ---------------------------------------------------------------------------
51// DriftDetectorType
52// ---------------------------------------------------------------------------
53
/// Selects which drift detector is instantiated for each boosting step.
///
/// Every variant carries its detector's tuning parameters, so fresh detector
/// instances can be constructed on demand (for example when a drifted tree is
/// replaced).
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum DriftDetectorType {
    /// Page-Hinkley Test, parameterized by delta (magnitude tolerance) and
    /// lambda (detection threshold).
    PageHinkley {
        /// Magnitude tolerance. Default 0.005.
        delta: f64,
        /// Detection threshold. Default 50.0.
        lambda: f64,
    },

    /// ADWIN, parameterized by its confidence value.
    Adwin {
        /// Confidence (smaller = fewer false positives). Default 0.002.
        delta: f64,
    },

    /// DDM, parameterized by warning/drift levels and a minimum warmup count.
    Ddm {
        /// Warning threshold multiplier. Default 2.0.
        warning_level: f64,
        /// Drift threshold multiplier. Default 3.0.
        drift_level: f64,
        /// Minimum observations before detection activates. Default 30.
        min_instances: u64,
    },
}
86
87impl Default for DriftDetectorType {
88    fn default() -> Self {
89        DriftDetectorType::PageHinkley {
90            delta: 0.005,
91            lambda: 50.0,
92        }
93    }
94}
95
96impl DriftDetectorType {
97    /// Create a new, fresh drift detector from this configuration.
98    pub fn create(&self) -> Box<dyn DriftDetector> {
99        match self {
100            Self::PageHinkley { delta, lambda } => {
101                Box::new(PageHinkleyTest::with_params(*delta, *lambda))
102            }
103            Self::Adwin { delta } => Box::new(Adwin::with_delta(*delta)),
104            Self::Ddm {
105                warning_level,
106                drift_level,
107                min_instances,
108            } => Box::new(Ddm::with_params(
109                *warning_level,
110                *drift_level,
111                *min_instances,
112            )),
113        }
114    }
115}
116
117// ---------------------------------------------------------------------------
118// SGBTConfig
119// ---------------------------------------------------------------------------
120
/// Configuration for the SGBT ensemble.
///
/// All numeric parameters are validated at build time via [`SGBTConfigBuilder`].
///
/// # Defaults
///
/// | Parameter                | Default              |
/// |--------------------------|----------------------|
/// | `n_steps`                | 100                  |
/// | `learning_rate`          | 0.0125               |
/// | `feature_subsample_rate` | 0.75                 |
/// | `max_depth`              | 6                    |
/// | `n_bins`                 | 64                   |
/// | `lambda`                 | 1.0                  |
/// | `gamma`                  | 0.0                  |
/// | `grace_period`           | 200                  |
/// | `delta`                  | 1e-7                 |
/// | `drift_detector`         | PageHinkley(0.005, 50.0) |
/// | `variant`                | Standard             |
/// | `seed`                   | 0xDEAD_BEEF_CAFE_4242 |
/// | `initial_target_count`   | 50                   |
/// | `leaf_half_life`         | None (disabled)      |
/// | `max_tree_samples`       | None (disabled)      |
/// | `split_reeval_interval`  | None (disabled)      |
///
/// The table lists the core parameters only; every optional extension below
/// (pruning, clipping, distributional settings, …) defaults to disabled —
/// see the individual field docs.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SGBTConfig {
    /// Number of boosting steps (trees in the ensemble). Default 100.
    pub n_steps: usize,
    /// Learning rate (shrinkage). Default 0.0125.
    pub learning_rate: f64,
    /// Fraction of features to subsample per tree. Default 0.75.
    pub feature_subsample_rate: f64,
    /// Maximum tree depth. Default 6.
    pub max_depth: usize,
    /// Number of histogram bins. Default 64.
    pub n_bins: usize,
    /// L2 regularization parameter (lambda). Default 1.0.
    pub lambda: f64,
    /// Minimum split gain (gamma). Default 0.0.
    pub gamma: f64,
    /// Grace period: min samples before evaluating splits. Default 200.
    pub grace_period: usize,
    /// Hoeffding bound confidence (delta). Default 1e-7.
    pub delta: f64,
    /// Drift detector type for tree replacement. Default: PageHinkley.
    pub drift_detector: DriftDetectorType,
    /// SGBT computational variant. Default: Standard.
    pub variant: SGBTVariant,
    /// Random seed for deterministic reproducibility. Default: 0xDEAD_BEEF_CAFE_4242.
    ///
    /// Controls feature subsampling and variant skip/MI stochastic decisions.
    /// Two models with the same seed and same data will produce identical results.
    pub seed: u64,
    /// Number of initial targets to collect before computing the base prediction.
    /// Default: 50.
    pub initial_target_count: usize,

    /// Half-life for exponential leaf decay (in samples per leaf).
    ///
    /// After `leaf_half_life` samples, a leaf's accumulated gradient/hessian
    /// statistics have half the weight of the most recent sample. This causes
    /// the model to continuously adapt to changing data distributions rather
    /// than freezing on early observations.
    ///
    /// `None` (default) disables decay -- traditional monotonic accumulation.
    #[cfg_attr(feature = "serde", serde(default))]
    pub leaf_half_life: Option<usize>,

    /// Maximum samples a single tree processes before proactive replacement.
    ///
    /// After this many samples, the tree is replaced with a fresh one regardless
    /// of drift detector state. Prevents stale tree structure from persisting
    /// when the drift detector is not sensitive enough.
    ///
    /// `None` (default) disables time-based replacement.
    #[cfg_attr(feature = "serde", serde(default))]
    pub max_tree_samples: Option<u64>,

    /// Sigma-modulated tree replacement speed.
    ///
    /// When `Some((base_mts, k))`, the effective max_tree_samples becomes:
    /// `effective_mts = base_mts / (1.0 + k * normalized_sigma)`
    /// where `normalized_sigma = contribution_sigma / rolling_contribution_sigma_mean`.
    ///
    /// High uncertainty (contribution variance) shortens tree lifetime for faster
    /// adaptation. Low uncertainty lengthens lifetime for more knowledge accumulation.
    ///
    /// Overrides `max_tree_samples` when set. Default: `None`.
    #[cfg_attr(feature = "serde", serde(default))]
    pub adaptive_mts: Option<(u64, f64)>,

    /// Proactive tree pruning interval.
    ///
    /// Every `interval` training samples, the tree with the lowest EWMA
    /// marginal contribution is replaced with a fresh tree.
    /// Fresh trees (fewer than `interval / 2` samples) are excluded.
    ///
    /// When set, contribution tracking is automatically enabled even if
    /// `quality_prune_alpha` is `None` (uses alpha=0.01).
    ///
    /// Default: `None` (disabled).
    #[cfg_attr(feature = "serde", serde(default))]
    pub proactive_prune_interval: Option<u64>,

    /// Interval (in samples per leaf) at which max-depth leaves re-evaluate
    /// whether a split would improve them.
    ///
    /// Inspired by EFDT (Manapragada et al. 2018). When a leaf has accumulated
    /// `split_reeval_interval` samples since its last evaluation and has reached
    /// max depth, it re-evaluates whether a split should be performed.
    ///
    /// `None` (default) disables re-evaluation -- max-depth leaves are permanent.
    #[cfg_attr(feature = "serde", serde(default))]
    pub split_reeval_interval: Option<usize>,

    /// Optional human-readable feature names.
    ///
    /// When set, enables `named_feature_importances` and
    /// `train_one_named` for production-friendly named access.
    /// Length must match the number of features in training data.
    #[cfg_attr(feature = "serde", serde(default))]
    pub feature_names: Option<Vec<String>>,

    /// Optional per-feature type declarations.
    ///
    /// When set, declares which features are categorical vs continuous.
    /// Categorical features use one-bin-per-category binning and Fisher
    /// optimal binary partitioning for split evaluation.
    /// Length must match the number of features in training data.
    ///
    /// `None` (default) treats all features as continuous.
    #[cfg_attr(feature = "serde", serde(default))]
    pub feature_types: Option<Vec<FeatureType>>,

    /// Gradient clipping threshold in standard deviations per leaf.
    ///
    /// When enabled, each leaf tracks an EWMA of gradient mean and variance.
    /// Incoming gradients that exceed `mean ± sigma * gradient_clip_sigma` are
    /// clamped to the boundary. This prevents outlier samples from corrupting
    /// leaf statistics, which is critical in streaming settings where sudden
    /// label floods can destabilize the model.
    ///
    /// Typical value: 3.0 (3-sigma clipping).
    /// `None` (default) disables gradient clipping.
    #[cfg_attr(feature = "serde", serde(default))]
    pub gradient_clip_sigma: Option<f64>,

    /// Per-feature monotonic constraints.
    ///
    /// Each element specifies the monotonic relationship between a feature and
    /// the prediction:
    /// - `+1`: prediction must be non-decreasing as feature value increases.
    /// - `-1`: prediction must be non-increasing as feature value increases.
    /// - `0`: no constraint (unconstrained).
    ///
    /// During split evaluation, candidate splits that would violate monotonicity
    /// (left child value > right child value for +1 constraints, or vice versa)
    /// are rejected.
    ///
    /// Length must match the number of features in training data.
    /// `None` (default) means no monotonic constraints.
    #[cfg_attr(feature = "serde", serde(default))]
    pub monotone_constraints: Option<Vec<i8>>,

    /// EWMA smoothing factor for quality-based tree pruning.
    ///
    /// When `Some(alpha)`, each boosting step tracks an exponentially weighted
    /// moving average of its marginal contribution to the ensemble. Trees whose
    /// contribution drops below [`quality_prune_threshold`](Self::quality_prune_threshold)
    /// for [`quality_prune_patience`](Self::quality_prune_patience) consecutive
    /// samples are replaced with a fresh tree that can learn the current regime.
    ///
    /// This prevents "dead wood" -- trees from a past regime that no longer
    /// contribute meaningfully to ensemble accuracy.
    ///
    /// `None` (default) disables quality-based pruning.
    /// Suggested value: 0.01.
    #[cfg_attr(feature = "serde", serde(default))]
    pub quality_prune_alpha: Option<f64>,

    /// Minimum contribution threshold for quality-based pruning.
    ///
    /// A tree's EWMA contribution must stay above this value to avoid being
    /// flagged as dead wood. Only used when `quality_prune_alpha` is `Some`.
    ///
    /// Default: 1e-6.
    /// NOTE(review): keep in sync with `default_quality_prune_threshold()`
    /// and the `Default` impl below.
    #[cfg_attr(feature = "serde", serde(default = "default_quality_prune_threshold"))]
    pub quality_prune_threshold: f64,

    /// Consecutive low-contribution samples before a tree is replaced.
    ///
    /// After this many consecutive samples where a tree's EWMA contribution
    /// is below `quality_prune_threshold`, the tree is reset. Only used when
    /// `quality_prune_alpha` is `Some`.
    ///
    /// Default: 500.
    /// NOTE(review): keep in sync with `default_quality_prune_patience()`
    /// and the `Default` impl below.
    #[cfg_attr(feature = "serde", serde(default = "default_quality_prune_patience"))]
    pub quality_prune_patience: u64,

    /// EWMA smoothing factor for error-weighted sample importance.
    ///
    /// When `Some(alpha)`, samples the model predicted poorly get higher
    /// effective weight during histogram accumulation. The weight is:
    /// `1.0 + |error| / (rolling_mean_error + epsilon)`, capped at 10x.
    ///
    /// This is a streaming version of AdaBoost's reweighting applied at the
    /// gradient level -- learning capacity focuses on hard/novel patterns,
    /// enabling faster adaptation to regime changes.
    ///
    /// `None` (default) disables error weighting.
    /// Suggested value: 0.01.
    #[cfg_attr(feature = "serde", serde(default))]
    pub error_weight_alpha: Option<f64>,

    /// Enable σ-modulated learning rate for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
    ///
    /// When `true`, the **location** (μ) ensemble's learning rate is scaled by
    /// `sigma_ratio = current_sigma / rolling_sigma_mean`, where `rolling_sigma_mean`
    /// is an EWMA of the model's predicted σ (alpha = 0.001).
    ///
    /// This means the model learns μ **faster** when σ is elevated (high uncertainty)
    /// and **slower** when σ is low (confident regime). The scale (σ) ensemble always
    /// trains at the unmodulated base rate to prevent positive feedback loops.
    ///
    /// Default: `false`.
    #[cfg_attr(feature = "serde", serde(default))]
    pub uncertainty_modulated_lr: bool,

    /// How the scale (σ) is estimated in [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
    ///
    /// - [`Empirical`](ScaleMode::Empirical) (default): EWMA of squared prediction
    ///   errors.  `σ = sqrt(ewma_sq_err)`.  Always calibrated, zero tuning, O(1).
    /// - [`TreeChain`](ScaleMode::TreeChain): full dual-chain NGBoost with a
    ///   separate tree ensemble predicting log(σ) from features.
    ///
    /// For σ-modulated learning (`uncertainty_modulated_lr = true`), `Empirical`
    /// is strongly recommended — scale tree gradients are inherently weak and
    /// the trees often fail to split.
    #[cfg_attr(feature = "serde", serde(default))]
    pub scale_mode: ScaleMode,

    /// EWMA smoothing factor for empirical σ estimation.
    ///
    /// Controls the adaptation speed of `σ = sqrt(ewma_sq_err)` when
    /// [`scale_mode`](Self::scale_mode) is [`Empirical`](ScaleMode::Empirical).
    /// Higher values react faster to regime changes but are noisier.
    ///
    /// Default: `0.01` (~100-sample effective window).
    /// NOTE(review): keep in sync with `default_empirical_sigma_alpha()`
    /// and the `Default` impl below.
    #[cfg_attr(feature = "serde", serde(default = "default_empirical_sigma_alpha"))]
    pub empirical_sigma_alpha: f64,

    /// Maximum absolute leaf output value.
    ///
    /// When `Some(max)`, leaf predictions are clamped to `[-max, max]`.
    /// Prevents runaway leaf weights from causing prediction explosions
    /// in feedback loops. `None` (default) means no clamping.
    #[cfg_attr(feature = "serde", serde(default))]
    pub max_leaf_output: Option<f64>,

    /// Per-leaf adaptive output bound (sigma multiplier).
    ///
    /// When `Some(k)`, each leaf tracks an EWMA of its own output weight and
    /// clamps predictions to `|output_mean| + k * output_std`. The EWMA uses
    /// `leaf_decay_alpha` when `leaf_half_life` is set, otherwise Welford online.
    ///
    /// This is strictly superior to `max_leaf_output` for streaming — the bound
    /// is per-leaf, self-calibrating, and regime-synchronized. A leaf that usually
    /// outputs 0.3 can't suddenly output 2.9 just because it fits in the global clamp.
    ///
    /// Typical value: 3.0 (3-sigma bound).
    /// `None` (default) disables adaptive bounds (falls back to `max_leaf_output`).
    #[cfg_attr(feature = "serde", serde(default))]
    pub adaptive_leaf_bound: Option<f64>,

    /// Per-split information criterion (Lunde-Kleppe-Skaug 2020).
    ///
    /// When `Some(cir_factor)`, replaces `max_depth` with a per-split
    /// generalization test. Each candidate split must have
    /// `gain > cir_factor * sigma^2_g / n * n_features`.
    /// `max_depth * 2` becomes a hard safety ceiling.
    ///
    /// Typical: 7.5 (<=10 features), 9.0 (<=50), 11.0 (<=200).
    /// `None` (default) uses static `max_depth` only.
    #[cfg_attr(feature = "serde", serde(default))]
    pub adaptive_depth: Option<f64>,

    /// Minimum hessian sum before a leaf produces non-zero output.
    ///
    /// When `Some(min_h)`, leaves with `hess_sum < min_h` return 0.0.
    /// Prevents post-replacement spikes from fresh leaves with insufficient
    /// samples. `None` (default) means all leaves contribute immediately.
    #[cfg_attr(feature = "serde", serde(default))]
    pub min_hessian_sum: Option<f64>,

    /// Huber loss delta multiplier for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
    ///
    /// When `Some(k)`, the distributional location gradient uses Huber loss
    /// with adaptive `delta = k * empirical_sigma`. This bounds gradients by
    /// construction. Standard value: `1.345` (95% efficiency at Gaussian).
    /// `None` (default) uses squared loss.
    #[cfg_attr(feature = "serde", serde(default))]
    pub huber_k: Option<f64>,

    /// Shadow warmup for graduated tree handoff.
    ///
    /// When `Some(n)`, an always-on shadow (alternate) tree is spawned immediately
    /// alongside every active tree. The shadow trains on the same gradient stream
    /// but does not contribute to predictions until it has seen `n` samples.
    ///
    /// As the active tree ages past 80% of `max_tree_samples`, its prediction
    /// weight linearly decays to 0 at 120%. The shadow's weight ramps from 0 to 1
    /// over `n` samples after warmup. When the active weight reaches 0, the shadow
    /// is promoted and a new shadow is spawned — no cold-start prediction dip.
    ///
    /// Requires `max_tree_samples` to be set for time-based graduated handoff.
    /// Drift-based replacement still uses hard swap (shadow is already warm).
    ///
    /// `None` (default) disables graduated handoff — uses traditional hard swap.
    #[cfg_attr(feature = "serde", serde(default))]
    pub shadow_warmup: Option<usize>,

    /// Leaf prediction model type.
    ///
    /// Controls how each leaf computes its prediction:
    /// - [`ClosedForm`](LeafModelType::ClosedForm) (default): constant leaf weight.
    /// - [`Linear`](LeafModelType::Linear): per-leaf online ridge regression with
    ///   AdaGrad optimization. Optional `decay` for concept drift. Recommended for
    ///   low-depth trees (depth 2--4).
    /// - [`MLP`](LeafModelType::MLP): per-leaf single-hidden-layer neural network.
    ///   Optional `decay` for concept drift.
    /// - [`Adaptive`](LeafModelType::Adaptive): starts as closed-form, auto-promotes
    ///   when the Hoeffding bound confirms a more complex model is better.
    ///
    /// Default: [`ClosedForm`](LeafModelType::ClosedForm).
    #[cfg_attr(feature = "serde", serde(default))]
    pub leaf_model_type: LeafModelType,

    /// Packed cache refresh interval for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
    ///
    /// When non-zero, the distributional model maintains a packed f32 cache of
    /// its location ensemble that is re-exported every `packed_refresh_interval`
    /// training samples. Predictions use the cache for O(1)-per-tree inference
    /// via contiguous memory traversal, falling back to full tree traversal when
    /// the cache is absent or produces non-finite results.
    ///
    /// `0` (default) disables the packed cache.
    #[cfg_attr(feature = "serde", serde(default))]
    pub packed_refresh_interval: u64,

    /// Enable per-node auto-bandwidth soft routing at prediction time.
    ///
    /// When `true`, predictions are continuous weighted blends instead of
    /// piecewise-constant step functions. No training changes.
    ///
    /// Default: `false`.
    #[cfg_attr(feature = "serde", serde(default))]
    pub soft_routing: bool,
}
480
// Serde default helpers, referenced by the `#[serde(default = "...")]`
// attributes on `SGBTConfig`. Only compiled when the `serde` feature is on.
// NOTE(review): these values must stay in sync with `SGBTConfig::default()`.

#[cfg(feature = "serde")]
fn default_empirical_sigma_alpha() -> f64 {
    0.01
}

#[cfg(feature = "serde")]
fn default_quality_prune_threshold() -> f64 {
    1e-6
}

#[cfg(feature = "serde")]
fn default_quality_prune_patience() -> u64 {
    500
}
495
496impl Default for SGBTConfig {
497    fn default() -> Self {
498        Self {
499            n_steps: 100,
500            learning_rate: 0.0125,
501            feature_subsample_rate: 0.75,
502            max_depth: 6,
503            n_bins: 64,
504            lambda: 1.0,
505            gamma: 0.0,
506            grace_period: 200,
507            delta: 1e-7,
508            drift_detector: DriftDetectorType::default(),
509            variant: SGBTVariant::default(),
510            seed: 0xDEAD_BEEF_CAFE_4242,
511            initial_target_count: 50,
512            leaf_half_life: None,
513            max_tree_samples: None,
514            adaptive_mts: None,
515            proactive_prune_interval: None,
516            split_reeval_interval: None,
517            feature_names: None,
518            feature_types: None,
519            gradient_clip_sigma: None,
520            monotone_constraints: None,
521            quality_prune_alpha: None,
522            quality_prune_threshold: 1e-6,
523            quality_prune_patience: 500,
524            error_weight_alpha: None,
525            uncertainty_modulated_lr: false,
526            scale_mode: ScaleMode::default(),
527            empirical_sigma_alpha: 0.01,
528            max_leaf_output: None,
529            adaptive_leaf_bound: None,
530            adaptive_depth: None,
531            min_hessian_sum: None,
532            huber_k: None,
533            shadow_warmup: None,
534            leaf_model_type: LeafModelType::default(),
535            packed_refresh_interval: 0,
536            soft_routing: false,
537        }
538    }
539}
540
541impl SGBTConfig {
542    /// Start building a configuration via the builder pattern.
543    pub fn builder() -> SGBTConfigBuilder {
544        SGBTConfigBuilder::default()
545    }
546}
547
548// ---------------------------------------------------------------------------
549// SGBTConfigBuilder
550// ---------------------------------------------------------------------------
551
552/// Builder for [`SGBTConfig`] with validation on [`build()`](Self::build).
553///
554/// # Example
555///
556/// ```ignore
557/// use irithyll::ensemble::config::{SGBTConfig, DriftDetectorType};
558/// use irithyll::ensemble::variants::SGBTVariant;
559///
560/// let config = SGBTConfig::builder()
561///     .n_steps(200)
562///     .learning_rate(0.05)
563///     .drift_detector(DriftDetectorType::Adwin { delta: 0.01 })
564///     .variant(SGBTVariant::Skip { k: 10 })
565///     .build()
566///     .expect("valid config");
567/// ```
568#[derive(Debug, Clone, Default)]
569pub struct SGBTConfigBuilder {
570    config: SGBTConfig,
571}
572
573impl SGBTConfigBuilder {
574    /// Set the number of boosting steps (trees in the ensemble).
575    pub fn n_steps(mut self, n: usize) -> Self {
576        self.config.n_steps = n;
577        self
578    }
579
580    /// Set the learning rate (shrinkage factor).
581    pub fn learning_rate(mut self, lr: f64) -> Self {
582        self.config.learning_rate = lr;
583        self
584    }
585
586    /// Set the fraction of features to subsample per tree.
587    pub fn feature_subsample_rate(mut self, rate: f64) -> Self {
588        self.config.feature_subsample_rate = rate;
589        self
590    }
591
592    /// Set the maximum tree depth.
593    pub fn max_depth(mut self, depth: usize) -> Self {
594        self.config.max_depth = depth;
595        self
596    }
597
598    /// Set the number of histogram bins per feature.
599    pub fn n_bins(mut self, bins: usize) -> Self {
600        self.config.n_bins = bins;
601        self
602    }
603
604    /// Set the L2 regularization parameter (lambda).
605    pub fn lambda(mut self, l: f64) -> Self {
606        self.config.lambda = l;
607        self
608    }
609
610    /// Set the minimum split gain (gamma).
611    pub fn gamma(mut self, g: f64) -> Self {
612        self.config.gamma = g;
613        self
614    }
615
616    /// Set the grace period (minimum samples before evaluating splits).
617    pub fn grace_period(mut self, gp: usize) -> Self {
618        self.config.grace_period = gp;
619        self
620    }
621
622    /// Set the Hoeffding bound confidence parameter (delta).
623    pub fn delta(mut self, d: f64) -> Self {
624        self.config.delta = d;
625        self
626    }
627
628    /// Set the drift detector type for tree replacement.
629    pub fn drift_detector(mut self, dt: DriftDetectorType) -> Self {
630        self.config.drift_detector = dt;
631        self
632    }
633
634    /// Set the SGBT computational variant.
635    pub fn variant(mut self, v: SGBTVariant) -> Self {
636        self.config.variant = v;
637        self
638    }
639
640    /// Set the random seed for deterministic reproducibility.
641    ///
642    /// Controls feature subsampling and variant skip/MI stochastic decisions.
643    /// Two models with the same seed and data sequence will produce identical results.
644    pub fn seed(mut self, seed: u64) -> Self {
645        self.config.seed = seed;
646        self
647    }
648
649    /// Set the number of initial targets to collect before computing the base prediction.
650    ///
651    /// The model collects this many target values before initializing the base
652    /// prediction (via `loss.initial_prediction`). Default: 50.
653    pub fn initial_target_count(mut self, count: usize) -> Self {
654        self.config.initial_target_count = count;
655        self
656    }
657
658    /// Set the half-life for exponential leaf decay (in samples per leaf).
659    ///
660    /// After `n` samples, a leaf's accumulated statistics have half the weight
661    /// of the most recent sample. Enables continuous adaptation to concept drift.
662    pub fn leaf_half_life(mut self, n: usize) -> Self {
663        self.config.leaf_half_life = Some(n);
664        self
665    }
666
667    /// Set the maximum samples a single tree processes before proactive replacement.
668    ///
669    /// After `n` samples, the tree is replaced regardless of drift detector state.
670    pub fn max_tree_samples(mut self, n: u64) -> Self {
671        self.config.max_tree_samples = Some(n);
672        self
673    }
674
675    /// Enable sigma-modulated tree replacement speed.
676    ///
677    /// The effective max_tree_samples becomes `base_mts / (1.0 + k * normalized_sigma)`,
678    /// where `normalized_sigma` tracks contribution variance across ensemble trees.
679    /// Overrides `max_tree_samples` when set.
680    pub fn adaptive_mts(mut self, base_mts: u64, k: f64) -> Self {
681        self.config.adaptive_mts = Some((base_mts, k));
682        self
683    }
684
685    /// Set the proactive tree pruning interval.
686    ///
687    /// Every `interval` training samples, the tree with the lowest EWMA
688    /// marginal contribution is replaced. Automatically enables contribution
689    /// tracking with alpha=0.01 when `quality_prune_alpha` is not set.
690    pub fn proactive_prune_interval(mut self, interval: u64) -> Self {
691        self.config.proactive_prune_interval = Some(interval);
692        self
693    }
694
695    /// Set the split re-evaluation interval for max-depth leaves.
696    ///
697    /// Every `n` samples per leaf, max-depth leaves re-evaluate whether a split
698    /// would improve them. Inspired by EFDT (Manapragada et al. 2018).
699    pub fn split_reeval_interval(mut self, n: usize) -> Self {
700        self.config.split_reeval_interval = Some(n);
701        self
702    }
703
704    /// Set human-readable feature names.
705    ///
706    /// Enables named feature importances and named training input.
707    /// Names must be unique; validated at [`build()`](Self::build).
708    pub fn feature_names(mut self, names: Vec<String>) -> Self {
709        self.config.feature_names = Some(names);
710        self
711    }
712
713    /// Set per-feature type declarations.
714    ///
715    /// Declares which features are categorical vs continuous. Categorical features
716    /// use one-bin-per-category binning and Fisher optimal binary partitioning.
717    /// Supports up to 64 distinct category values per categorical feature.
718    pub fn feature_types(mut self, types: Vec<FeatureType>) -> Self {
719        self.config.feature_types = Some(types);
720        self
721    }
722
723    /// Set per-leaf gradient clipping threshold (in standard deviations).
724    ///
725    /// Each leaf tracks an EWMA of gradient mean and variance. Gradients
726    /// exceeding `mean ± sigma * n` are clamped. Prevents outlier labels
727    /// from corrupting streaming model stability.
728    ///
729    /// Typical value: 3.0 (3-sigma clipping).
730    pub fn gradient_clip_sigma(mut self, sigma: f64) -> Self {
731        self.config.gradient_clip_sigma = Some(sigma);
732        self
733    }
734
735    /// Set per-feature monotonic constraints.
736    ///
737    /// `+1` = non-decreasing, `-1` = non-increasing, `0` = unconstrained.
738    /// Candidate splits violating monotonicity are rejected during tree growth.
739    pub fn monotone_constraints(mut self, constraints: Vec<i8>) -> Self {
740        self.config.monotone_constraints = Some(constraints);
741        self
742    }
743
744    /// Enable quality-based tree pruning with the given EWMA smoothing factor.
745    ///
746    /// Trees whose marginal contribution drops below the threshold for
747    /// `patience` consecutive samples are replaced with fresh trees.
748    /// Suggested alpha: 0.01.
749    pub fn quality_prune_alpha(mut self, alpha: f64) -> Self {
750        self.config.quality_prune_alpha = Some(alpha);
751        self
752    }
753
754    /// Set the minimum contribution threshold for quality-based pruning.
755    ///
756    /// Default: 1e-6. Only relevant when `quality_prune_alpha` is set.
757    pub fn quality_prune_threshold(mut self, threshold: f64) -> Self {
758        self.config.quality_prune_threshold = threshold;
759        self
760    }
761
762    /// Set the patience (consecutive low-contribution samples) before pruning.
763    ///
764    /// Default: 500. Only relevant when `quality_prune_alpha` is set.
765    pub fn quality_prune_patience(mut self, patience: u64) -> Self {
766        self.config.quality_prune_patience = patience;
767        self
768    }
769
770    /// Enable error-weighted sample importance with the given EWMA smoothing factor.
771    ///
772    /// Samples the model predicted poorly get higher effective weight.
773    /// Suggested alpha: 0.01.
774    pub fn error_weight_alpha(mut self, alpha: f64) -> Self {
775        self.config.error_weight_alpha = Some(alpha);
776        self
777    }
778
779    /// Enable σ-modulated learning rate for distributional models.
780    ///
781    /// Scales the location (μ) learning rate by `current_sigma / rolling_sigma_mean`,
782    /// so the model adapts faster during high-uncertainty regimes and conserves
783    /// during stable periods. Only affects [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
784    ///
785    /// By default uses empirical σ (EWMA of squared errors).  Set
786    /// [`scale_mode(ScaleMode::TreeChain)`](Self::scale_mode) for feature-conditional σ.
787    pub fn uncertainty_modulated_lr(mut self, enabled: bool) -> Self {
788        self.config.uncertainty_modulated_lr = enabled;
789        self
790    }
791
792    /// Set the scale estimation mode for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
793    ///
794    /// - [`Empirical`](ScaleMode::Empirical): EWMA of squared prediction errors (default, recommended).
795    /// - [`TreeChain`](ScaleMode::TreeChain): dual-chain NGBoost with scale tree ensemble.
796    pub fn scale_mode(mut self, mode: ScaleMode) -> Self {
797        self.config.scale_mode = mode;
798        self
799    }
800
801    /// Enable per-node auto-bandwidth soft routing.
802    pub fn soft_routing(mut self, enabled: bool) -> Self {
803        self.config.soft_routing = enabled;
804        self
805    }
806
807    /// EWMA alpha for empirical σ. Controls adaptation speed. Default `0.01`.
808    ///
809    /// Only used when `scale_mode` is [`Empirical`](ScaleMode::Empirical).
810    pub fn empirical_sigma_alpha(mut self, alpha: f64) -> Self {
811        self.config.empirical_sigma_alpha = alpha;
812        self
813    }
814
815    /// Set the maximum absolute leaf output value.
816    ///
817    /// Clamps leaf predictions to `[-max, max]`, breaking feedback loops
818    /// that cause prediction explosions.
819    pub fn max_leaf_output(mut self, max: f64) -> Self {
820        self.config.max_leaf_output = Some(max);
821        self
822    }
823
824    /// Set per-leaf adaptive output bound (sigma multiplier).
825    ///
826    /// Each leaf tracks EWMA of its own output weight and clamps to
827    /// `|output_mean| + k * output_std`. Self-calibrating per-leaf.
828    /// Recommended: use with `leaf_half_life` for streaming scenarios.
829    pub fn adaptive_leaf_bound(mut self, k: f64) -> Self {
830        self.config.adaptive_leaf_bound = Some(k);
831        self
832    }
833
834    /// Set the per-split information criterion factor (Lunde-Kleppe-Skaug 2020).
835    ///
836    /// Replaces static `max_depth` with a per-split generalization test.
837    /// Typical: 7.5 (<=10 features), 9.0 (<=50), 11.0 (<=200).
838    pub fn adaptive_depth(mut self, factor: f64) -> Self {
839        self.config.adaptive_depth = Some(factor);
840        self
841    }
842
843    /// Set the minimum hessian sum for leaf output.
844    ///
845    /// Fresh leaves with `hess_sum < min_h` return 0.0, preventing
846    /// post-replacement spikes.
847    pub fn min_hessian_sum(mut self, min_h: f64) -> Self {
848        self.config.min_hessian_sum = Some(min_h);
849        self
850    }
851
852    /// Set the Huber loss delta multiplier for [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
853    ///
854    /// When set, location gradients use Huber loss with adaptive
855    /// `delta = k * empirical_sigma`. Standard value: `1.345` (95% Gaussian efficiency).
856    pub fn huber_k(mut self, k: f64) -> Self {
857        self.config.huber_k = Some(k);
858        self
859    }
860
861    /// Enable graduated tree handoff with the given shadow warmup samples.
862    ///
863    /// Spawns an always-on shadow tree that trains alongside the active tree.
864    /// After `warmup` samples, the shadow begins contributing to predictions
865    /// via graduated blending. Eliminates prediction dips during tree replacement.
866    pub fn shadow_warmup(mut self, warmup: usize) -> Self {
867        self.config.shadow_warmup = Some(warmup);
868        self
869    }
870
871    /// Set the leaf prediction model type.
872    ///
873    /// [`LeafModelType::Linear`] is recommended for low-depth configurations
874    /// (depth 2--4) where per-leaf linear models reduce approximation error.
875    ///
876    /// [`LeafModelType::Adaptive`] automatically selects between closed-form and
877    /// a trainable model per leaf, using the Hoeffding bound for promotion.
878    pub fn leaf_model_type(mut self, lmt: LeafModelType) -> Self {
879        self.config.leaf_model_type = lmt;
880        self
881    }
882
883    /// Set the packed cache refresh interval for distributional models.
884    ///
885    /// When non-zero, [`DistributionalSGBT`](super::distributional::DistributionalSGBT)
886    /// maintains a packed f32 cache refreshed every `interval` training samples.
887    /// `0` (default) disables the cache.
888    pub fn packed_refresh_interval(mut self, interval: u64) -> Self {
889        self.config.packed_refresh_interval = interval;
890        self
891    }
892
893    /// Validate and build the configuration.
894    ///
895    /// # Errors
896    ///
897    /// Returns [`InvalidConfig`](crate::IrithyllError::InvalidConfig) with a structured
898    /// [`ConfigError`] if any parameter is out of its valid range.
899    pub fn build(self) -> Result<SGBTConfig> {
900        let c = &self.config;
901
902        // -- Ensemble-level parameters --
903        if c.n_steps == 0 {
904            return Err(ConfigError::out_of_range("n_steps", "must be > 0", c.n_steps).into());
905        }
906        if c.learning_rate <= 0.0 || c.learning_rate > 1.0 {
907            return Err(ConfigError::out_of_range(
908                "learning_rate",
909                "must be in (0, 1]",
910                c.learning_rate,
911            )
912            .into());
913        }
914        if c.feature_subsample_rate <= 0.0 || c.feature_subsample_rate > 1.0 {
915            return Err(ConfigError::out_of_range(
916                "feature_subsample_rate",
917                "must be in (0, 1]",
918                c.feature_subsample_rate,
919            )
920            .into());
921        }
922
923        // -- Tree-level parameters --
924        if c.max_depth == 0 {
925            return Err(ConfigError::out_of_range("max_depth", "must be > 0", c.max_depth).into());
926        }
927        if c.n_bins < 2 {
928            return Err(ConfigError::out_of_range("n_bins", "must be >= 2", c.n_bins).into());
929        }
930        if c.lambda < 0.0 {
931            return Err(ConfigError::out_of_range("lambda", "must be >= 0", c.lambda).into());
932        }
933        if c.gamma < 0.0 {
934            return Err(ConfigError::out_of_range("gamma", "must be >= 0", c.gamma).into());
935        }
936        if c.grace_period == 0 {
937            return Err(
938                ConfigError::out_of_range("grace_period", "must be > 0", c.grace_period).into(),
939            );
940        }
941        if c.delta <= 0.0 || c.delta >= 1.0 {
942            return Err(ConfigError::out_of_range("delta", "must be in (0, 1)", c.delta).into());
943        }
944
945        if c.initial_target_count == 0 {
946            return Err(ConfigError::out_of_range(
947                "initial_target_count",
948                "must be > 0",
949                c.initial_target_count,
950            )
951            .into());
952        }
953
954        // -- Streaming adaptation parameters --
955        if let Some(hl) = c.leaf_half_life {
956            if hl == 0 {
957                return Err(ConfigError::out_of_range("leaf_half_life", "must be >= 1", hl).into());
958            }
959        }
960        if let Some(max) = c.max_tree_samples {
961            if max < 100 {
962                return Err(
963                    ConfigError::out_of_range("max_tree_samples", "must be >= 100", max).into(),
964                );
965            }
966        }
967        if let Some((base_mts, k)) = c.adaptive_mts {
968            if base_mts < 100 {
969                return Err(ConfigError::out_of_range(
970                    "adaptive_mts.base_mts",
971                    "must be >= 100",
972                    base_mts,
973                )
974                .into());
975            }
976            if k <= 0.0 {
977                return Err(ConfigError::out_of_range("adaptive_mts.k", "must be > 0", k).into());
978            }
979        }
980        if let Some(interval) = c.proactive_prune_interval {
981            if interval < 100 {
982                return Err(ConfigError::out_of_range(
983                    "proactive_prune_interval",
984                    "must be >= 100",
985                    interval,
986                )
987                .into());
988            }
989        }
990        if let Some(interval) = c.split_reeval_interval {
991            if interval < c.grace_period {
992                return Err(ConfigError::invalid(
993                    "split_reeval_interval",
994                    format!(
995                        "must be >= grace_period ({}), got {}",
996                        c.grace_period, interval
997                    ),
998                )
999                .into());
1000            }
1001        }
1002
1003        // -- Feature names --
1004        if let Some(ref names) = c.feature_names {
1005            // O(n^2) duplicate check — names list is small so no HashSet needed
1006            for (i, name) in names.iter().enumerate() {
1007                for prev in &names[..i] {
1008                    if name == prev {
1009                        return Err(ConfigError::invalid(
1010                            "feature_names",
1011                            format!("duplicate feature name: '{}'", name),
1012                        )
1013                        .into());
1014                    }
1015                }
1016            }
1017        }
1018
1019        // -- Feature types --
1020        if let Some(ref types) = c.feature_types {
1021            if let Some(ref names) = c.feature_names {
1022                if !names.is_empty() && !types.is_empty() && names.len() != types.len() {
1023                    return Err(ConfigError::invalid(
1024                        "feature_types",
1025                        format!(
1026                            "length ({}) must match feature_names length ({})",
1027                            types.len(),
1028                            names.len()
1029                        ),
1030                    )
1031                    .into());
1032                }
1033            }
1034        }
1035
1036        // -- Gradient clipping --
1037        if let Some(sigma) = c.gradient_clip_sigma {
1038            if sigma <= 0.0 {
1039                return Err(
1040                    ConfigError::out_of_range("gradient_clip_sigma", "must be > 0", sigma).into(),
1041                );
1042            }
1043        }
1044
1045        // -- Monotonic constraints --
1046        if let Some(ref mc) = c.monotone_constraints {
1047            for (i, &v) in mc.iter().enumerate() {
1048                if v != -1 && v != 0 && v != 1 {
1049                    return Err(ConfigError::invalid(
1050                        "monotone_constraints",
1051                        format!("feature {}: must be -1, 0, or +1, got {}", i, v),
1052                    )
1053                    .into());
1054                }
1055            }
1056        }
1057
1058        // -- Leaf output clamping --
1059        if let Some(max) = c.max_leaf_output {
1060            if max <= 0.0 {
1061                return Err(
1062                    ConfigError::out_of_range("max_leaf_output", "must be > 0", max).into(),
1063                );
1064            }
1065        }
1066
1067        // -- Per-leaf adaptive output bound --
1068        if let Some(k) = c.adaptive_leaf_bound {
1069            if k <= 0.0 {
1070                return Err(
1071                    ConfigError::out_of_range("adaptive_leaf_bound", "must be > 0", k).into(),
1072                );
1073            }
1074        }
1075
1076        // -- Adaptive depth (per-split information criterion) --
1077        if let Some(factor) = c.adaptive_depth {
1078            if factor <= 0.0 {
1079                return Err(
1080                    ConfigError::out_of_range("adaptive_depth", "must be > 0", factor).into(),
1081                );
1082            }
1083        }
1084
1085        // -- Minimum hessian sum --
1086        if let Some(min_h) = c.min_hessian_sum {
1087            if min_h <= 0.0 {
1088                return Err(
1089                    ConfigError::out_of_range("min_hessian_sum", "must be > 0", min_h).into(),
1090                );
1091            }
1092        }
1093
1094        // -- Huber loss multiplier --
1095        if let Some(k) = c.huber_k {
1096            if k <= 0.0 {
1097                return Err(ConfigError::out_of_range("huber_k", "must be > 0", k).into());
1098            }
1099        }
1100
1101        // -- Shadow warmup --
1102        if let Some(warmup) = c.shadow_warmup {
1103            if warmup == 0 {
1104                return Err(ConfigError::out_of_range(
1105                    "shadow_warmup",
1106                    "must be > 0",
1107                    warmup as f64,
1108                )
1109                .into());
1110            }
1111        }
1112
1113        // -- Quality-based pruning parameters --
1114        if let Some(alpha) = c.quality_prune_alpha {
1115            if alpha <= 0.0 || alpha >= 1.0 {
1116                return Err(ConfigError::out_of_range(
1117                    "quality_prune_alpha",
1118                    "must be in (0, 1)",
1119                    alpha,
1120                )
1121                .into());
1122            }
1123        }
1124        if c.quality_prune_threshold <= 0.0 {
1125            return Err(ConfigError::out_of_range(
1126                "quality_prune_threshold",
1127                "must be > 0",
1128                c.quality_prune_threshold,
1129            )
1130            .into());
1131        }
1132        if c.quality_prune_patience == 0 {
1133            return Err(ConfigError::out_of_range(
1134                "quality_prune_patience",
1135                "must be > 0",
1136                c.quality_prune_patience,
1137            )
1138            .into());
1139        }
1140
1141        // -- Error-weighted sample importance --
1142        if let Some(alpha) = c.error_weight_alpha {
1143            if alpha <= 0.0 || alpha >= 1.0 {
1144                return Err(ConfigError::out_of_range(
1145                    "error_weight_alpha",
1146                    "must be in (0, 1)",
1147                    alpha,
1148                )
1149                .into());
1150            }
1151        }
1152
1153        // -- Drift detector parameters --
1154        match &c.drift_detector {
1155            DriftDetectorType::PageHinkley { delta, lambda } => {
1156                if *delta <= 0.0 {
1157                    return Err(ConfigError::out_of_range(
1158                        "drift_detector.PageHinkley.delta",
1159                        "must be > 0",
1160                        delta,
1161                    )
1162                    .into());
1163                }
1164                if *lambda <= 0.0 {
1165                    return Err(ConfigError::out_of_range(
1166                        "drift_detector.PageHinkley.lambda",
1167                        "must be > 0",
1168                        lambda,
1169                    )
1170                    .into());
1171                }
1172            }
1173            DriftDetectorType::Adwin { delta } => {
1174                if *delta <= 0.0 || *delta >= 1.0 {
1175                    return Err(ConfigError::out_of_range(
1176                        "drift_detector.Adwin.delta",
1177                        "must be in (0, 1)",
1178                        delta,
1179                    )
1180                    .into());
1181                }
1182            }
1183            DriftDetectorType::Ddm {
1184                warning_level,
1185                drift_level,
1186                min_instances,
1187            } => {
1188                if *warning_level <= 0.0 {
1189                    return Err(ConfigError::out_of_range(
1190                        "drift_detector.Ddm.warning_level",
1191                        "must be > 0",
1192                        warning_level,
1193                    )
1194                    .into());
1195                }
1196                if *drift_level <= 0.0 {
1197                    return Err(ConfigError::out_of_range(
1198                        "drift_detector.Ddm.drift_level",
1199                        "must be > 0",
1200                        drift_level,
1201                    )
1202                    .into());
1203                }
1204                if *drift_level <= *warning_level {
1205                    return Err(ConfigError::invalid(
1206                        "drift_detector.Ddm.drift_level",
1207                        format!(
1208                            "must be > warning_level ({}), got {}",
1209                            warning_level, drift_level
1210                        ),
1211                    )
1212                    .into());
1213                }
1214                if *min_instances == 0 {
1215                    return Err(ConfigError::out_of_range(
1216                        "drift_detector.Ddm.min_instances",
1217                        "must be > 0",
1218                        min_instances,
1219                    )
1220                    .into());
1221                }
1222            }
1223        }
1224
1225        // -- Variant parameters --
1226        match &c.variant {
1227            SGBTVariant::Standard => {} // no extra validation
1228            SGBTVariant::Skip { k } => {
1229                if *k == 0 {
1230                    return Err(
1231                        ConfigError::out_of_range("variant.Skip.k", "must be > 0", k).into(),
1232                    );
1233                }
1234            }
1235            SGBTVariant::MultipleIterations { multiplier } => {
1236                if *multiplier <= 0.0 {
1237                    return Err(ConfigError::out_of_range(
1238                        "variant.MultipleIterations.multiplier",
1239                        "must be > 0",
1240                        multiplier,
1241                    )
1242                    .into());
1243                }
1244            }
1245        }
1246
1247        Ok(self.config)
1248    }
1249}
1250
1251// ---------------------------------------------------------------------------
1252// Display impls
1253// ---------------------------------------------------------------------------
1254
1255impl core::fmt::Display for DriftDetectorType {
1256    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1257        match self {
1258            Self::PageHinkley { delta, lambda } => {
1259                write!(f, "PageHinkley(delta={}, lambda={})", delta, lambda)
1260            }
1261            Self::Adwin { delta } => write!(f, "Adwin(delta={})", delta),
1262            Self::Ddm {
1263                warning_level,
1264                drift_level,
1265                min_instances,
1266            } => write!(
1267                f,
1268                "Ddm(warning={}, drift={}, min_instances={})",
1269                warning_level, drift_level, min_instances
1270            ),
1271        }
1272    }
1273}
1274
1275impl core::fmt::Display for SGBTConfig {
1276    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1277        write!(
1278            f,
1279            "SGBTConfig {{ steps={}, lr={}, depth={}, bins={}, variant={}, drift={} }}",
1280            self.n_steps,
1281            self.learning_rate,
1282            self.max_depth,
1283            self.n_bins,
1284            self.variant,
1285            self.drift_detector,
1286        )
1287    }
1288}
1289
1290// ---------------------------------------------------------------------------
1291// Tests
1292// ---------------------------------------------------------------------------
1293
1294#[cfg(test)]
1295mod tests {
1296    use super::*;
1297    use alloc::format;
1298    use alloc::vec;
1299
1300    // ------------------------------------------------------------------
1301    // 1. Default config values are correct
1302    // ------------------------------------------------------------------
1303    #[test]
1304    fn default_config_values() {
1305        let cfg = SGBTConfig::default();
1306        assert_eq!(cfg.n_steps, 100);
1307        assert!((cfg.learning_rate - 0.0125).abs() < f64::EPSILON);
1308        assert!((cfg.feature_subsample_rate - 0.75).abs() < f64::EPSILON);
1309        assert_eq!(cfg.max_depth, 6);
1310        assert_eq!(cfg.n_bins, 64);
1311        assert!((cfg.lambda - 1.0).abs() < f64::EPSILON);
1312        assert!((cfg.gamma - 0.0).abs() < f64::EPSILON);
1313        assert_eq!(cfg.grace_period, 200);
1314        assert!((cfg.delta - 1e-7).abs() < f64::EPSILON);
1315        assert_eq!(cfg.variant, SGBTVariant::Standard);
1316    }
1317
1318    // ------------------------------------------------------------------
1319    // 2. Builder chain works
1320    // ------------------------------------------------------------------
1321    #[test]
1322    fn builder_chain() {
1323        let cfg = SGBTConfig::builder()
1324            .n_steps(50)
1325            .learning_rate(0.1)
1326            .feature_subsample_rate(0.5)
1327            .max_depth(10)
1328            .n_bins(128)
1329            .lambda(0.5)
1330            .gamma(0.1)
1331            .grace_period(500)
1332            .delta(1e-3)
1333            .drift_detector(DriftDetectorType::Adwin { delta: 0.01 })
1334            .variant(SGBTVariant::Skip { k: 5 })
1335            .build()
1336            .expect("valid config");
1337
1338        assert_eq!(cfg.n_steps, 50);
1339        assert!((cfg.learning_rate - 0.1).abs() < f64::EPSILON);
1340        assert!((cfg.feature_subsample_rate - 0.5).abs() < f64::EPSILON);
1341        assert_eq!(cfg.max_depth, 10);
1342        assert_eq!(cfg.n_bins, 128);
1343        assert!((cfg.lambda - 0.5).abs() < f64::EPSILON);
1344        assert!((cfg.gamma - 0.1).abs() < f64::EPSILON);
1345        assert_eq!(cfg.grace_period, 500);
1346        assert!((cfg.delta - 1e-3).abs() < f64::EPSILON);
1347        assert_eq!(cfg.variant, SGBTVariant::Skip { k: 5 });
1348
1349        // Verify drift detector type is Adwin.
1350        match &cfg.drift_detector {
1351            DriftDetectorType::Adwin { delta } => {
1352                assert!((delta - 0.01).abs() < f64::EPSILON);
1353            }
1354            _ => panic!("expected Adwin drift detector"),
1355        }
1356    }
1357
1358    // ------------------------------------------------------------------
1359    // 3. Validation rejects invalid values
1360    // ------------------------------------------------------------------
1361    #[test]
1362    fn validation_rejects_n_steps_zero() {
1363        let result = SGBTConfig::builder().n_steps(0).build();
1364        assert!(result.is_err());
1365        let msg = format!("{}", result.unwrap_err());
1366        assert!(msg.contains("n_steps"));
1367    }
1368
1369    #[test]
1370    fn validation_rejects_learning_rate_zero() {
1371        let result = SGBTConfig::builder().learning_rate(0.0).build();
1372        assert!(result.is_err());
1373        let msg = format!("{}", result.unwrap_err());
1374        assert!(msg.contains("learning_rate"));
1375    }
1376
1377    #[test]
1378    fn validation_rejects_learning_rate_above_one() {
1379        let result = SGBTConfig::builder().learning_rate(1.5).build();
1380        assert!(result.is_err());
1381    }
1382
1383    #[test]
1384    fn validation_accepts_learning_rate_one() {
1385        let result = SGBTConfig::builder().learning_rate(1.0).build();
1386        assert!(result.is_ok());
1387    }
1388
1389    #[test]
1390    fn validation_rejects_negative_learning_rate() {
1391        let result = SGBTConfig::builder().learning_rate(-0.1).build();
1392        assert!(result.is_err());
1393    }
1394
1395    #[test]
1396    fn validation_rejects_feature_subsample_zero() {
1397        let result = SGBTConfig::builder().feature_subsample_rate(0.0).build();
1398        assert!(result.is_err());
1399    }
1400
1401    #[test]
1402    fn validation_rejects_feature_subsample_above_one() {
1403        let result = SGBTConfig::builder().feature_subsample_rate(1.01).build();
1404        assert!(result.is_err());
1405    }
1406
1407    #[test]
1408    fn validation_rejects_max_depth_zero() {
1409        let result = SGBTConfig::builder().max_depth(0).build();
1410        assert!(result.is_err());
1411    }
1412
1413    #[test]
1414    fn validation_rejects_n_bins_one() {
1415        let result = SGBTConfig::builder().n_bins(1).build();
1416        assert!(result.is_err());
1417    }
1418
1419    #[test]
1420    fn validation_rejects_negative_lambda() {
1421        let result = SGBTConfig::builder().lambda(-0.1).build();
1422        assert!(result.is_err());
1423    }
1424
1425    #[test]
1426    fn validation_accepts_zero_lambda() {
1427        let result = SGBTConfig::builder().lambda(0.0).build();
1428        assert!(result.is_ok());
1429    }
1430
1431    #[test]
1432    fn validation_rejects_negative_gamma() {
1433        let result = SGBTConfig::builder().gamma(-0.1).build();
1434        assert!(result.is_err());
1435    }
1436
1437    #[test]
1438    fn validation_rejects_grace_period_zero() {
1439        let result = SGBTConfig::builder().grace_period(0).build();
1440        assert!(result.is_err());
1441    }
1442
1443    #[test]
1444    fn validation_rejects_delta_zero() {
1445        let result = SGBTConfig::builder().delta(0.0).build();
1446        assert!(result.is_err());
1447    }
1448
1449    #[test]
1450    fn validation_rejects_delta_one() {
1451        let result = SGBTConfig::builder().delta(1.0).build();
1452        assert!(result.is_err());
1453    }
1454
1455    // ------------------------------------------------------------------
1456    // 3b. Drift detector parameter validation
1457    // ------------------------------------------------------------------
1458    #[test]
1459    fn validation_rejects_pht_negative_delta() {
1460        let result = SGBTConfig::builder()
1461            .drift_detector(DriftDetectorType::PageHinkley {
1462                delta: -1.0,
1463                lambda: 50.0,
1464            })
1465            .build();
1466        assert!(result.is_err());
1467        let msg = format!("{}", result.unwrap_err());
1468        assert!(msg.contains("PageHinkley"));
1469    }
1470
1471    #[test]
1472    fn validation_rejects_pht_zero_lambda() {
1473        let result = SGBTConfig::builder()
1474            .drift_detector(DriftDetectorType::PageHinkley {
1475                delta: 0.005,
1476                lambda: 0.0,
1477            })
1478            .build();
1479        assert!(result.is_err());
1480    }
1481
1482    #[test]
1483    fn validation_rejects_adwin_delta_out_of_range() {
1484        let result = SGBTConfig::builder()
1485            .drift_detector(DriftDetectorType::Adwin { delta: 0.0 })
1486            .build();
1487        assert!(result.is_err());
1488
1489        let result = SGBTConfig::builder()
1490            .drift_detector(DriftDetectorType::Adwin { delta: 1.0 })
1491            .build();
1492        assert!(result.is_err());
1493    }
1494
1495    #[test]
1496    fn validation_rejects_ddm_warning_above_drift() {
1497        let result = SGBTConfig::builder()
1498            .drift_detector(DriftDetectorType::Ddm {
1499                warning_level: 3.0,
1500                drift_level: 2.0,
1501                min_instances: 30,
1502            })
1503            .build();
1504        assert!(result.is_err());
1505        let msg = format!("{}", result.unwrap_err());
1506        assert!(msg.contains("drift_level"));
1507        assert!(msg.contains("must be > warning_level"));
1508    }
1509
1510    #[test]
1511    fn validation_rejects_ddm_equal_levels() {
1512        let result = SGBTConfig::builder()
1513            .drift_detector(DriftDetectorType::Ddm {
1514                warning_level: 2.0,
1515                drift_level: 2.0,
1516                min_instances: 30,
1517            })
1518            .build();
1519        assert!(result.is_err());
1520    }
1521
1522    #[test]
1523    fn validation_rejects_ddm_zero_min_instances() {
1524        let result = SGBTConfig::builder()
1525            .drift_detector(DriftDetectorType::Ddm {
1526                warning_level: 2.0,
1527                drift_level: 3.0,
1528                min_instances: 0,
1529            })
1530            .build();
1531        assert!(result.is_err());
1532    }
1533
1534    #[test]
1535    fn validation_rejects_ddm_zero_warning_level() {
1536        let result = SGBTConfig::builder()
1537            .drift_detector(DriftDetectorType::Ddm {
1538                warning_level: 0.0,
1539                drift_level: 3.0,
1540                min_instances: 30,
1541            })
1542            .build();
1543        assert!(result.is_err());
1544    }
1545
1546    // ------------------------------------------------------------------
1547    // 3c. Variant parameter validation
1548    // ------------------------------------------------------------------
1549    #[test]
1550    fn validation_rejects_skip_k_zero() {
1551        let result = SGBTConfig::builder()
1552            .variant(SGBTVariant::Skip { k: 0 })
1553            .build();
1554        assert!(result.is_err());
1555        let msg = format!("{}", result.unwrap_err());
1556        assert!(msg.contains("Skip"));
1557    }
1558
1559    #[test]
1560    fn validation_rejects_mi_zero_multiplier() {
1561        let result = SGBTConfig::builder()
1562            .variant(SGBTVariant::MultipleIterations { multiplier: 0.0 })
1563            .build();
1564        assert!(result.is_err());
1565    }
1566
1567    #[test]
1568    fn validation_rejects_mi_negative_multiplier() {
1569        let result = SGBTConfig::builder()
1570            .variant(SGBTVariant::MultipleIterations { multiplier: -1.0 })
1571            .build();
1572        assert!(result.is_err());
1573    }
1574
1575    #[test]
1576    fn validation_accepts_standard_variant() {
1577        let result = SGBTConfig::builder().variant(SGBTVariant::Standard).build();
1578        assert!(result.is_ok());
1579    }
1580
1581    // ------------------------------------------------------------------
1582    // 4. DriftDetectorType creates correct detector types
1583    // ------------------------------------------------------------------
1584    #[test]
1585    fn drift_detector_type_creates_page_hinkley() {
1586        let dt = DriftDetectorType::PageHinkley {
1587            delta: 0.01,
1588            lambda: 100.0,
1589        };
1590        let mut detector = dt.create();
1591
1592        // Should start with zero mean.
1593        assert_eq!(detector.estimated_mean(), 0.0);
1594
1595        // Feed a stable value -- should not drift.
1596        let signal = detector.update(1.0);
1597        assert_ne!(signal, crate::drift::DriftSignal::Drift);
1598    }
1599
1600    #[test]
1601    fn drift_detector_type_creates_adwin() {
1602        let dt = DriftDetectorType::Adwin { delta: 0.05 };
1603        let mut detector = dt.create();
1604
1605        assert_eq!(detector.estimated_mean(), 0.0);
1606        let signal = detector.update(1.0);
1607        assert_ne!(signal, crate::drift::DriftSignal::Drift);
1608    }
1609
1610    #[test]
1611    fn drift_detector_type_creates_ddm() {
1612        let dt = DriftDetectorType::Ddm {
1613            warning_level: 2.0,
1614            drift_level: 3.0,
1615            min_instances: 30,
1616        };
1617        let mut detector = dt.create();
1618
1619        assert_eq!(detector.estimated_mean(), 0.0);
1620        let signal = detector.update(0.1);
1621        assert_eq!(signal, crate::drift::DriftSignal::Stable);
1622    }
1623
1624    // ------------------------------------------------------------------
1625    // 5. Default drift detector is PageHinkley
1626    // ------------------------------------------------------------------
1627    #[test]
1628    fn default_drift_detector_is_page_hinkley() {
1629        let dt = DriftDetectorType::default();
1630        match dt {
1631            DriftDetectorType::PageHinkley { delta, lambda } => {
1632                assert!((delta - 0.005).abs() < f64::EPSILON);
1633                assert!((lambda - 50.0).abs() < f64::EPSILON);
1634            }
1635            _ => panic!("expected default DriftDetectorType to be PageHinkley"),
1636        }
1637    }
1638
1639    // ------------------------------------------------------------------
1640    // 6. Default config builds without error
1641    // ------------------------------------------------------------------
1642    #[test]
1643    fn default_config_builds() {
1644        let result = SGBTConfig::builder().build();
1645        assert!(result.is_ok());
1646    }
1647
1648    // ------------------------------------------------------------------
1649    // 7. Config clone preserves all fields
1650    // ------------------------------------------------------------------
1651    #[test]
1652    fn config_clone_preserves_fields() {
1653        let original = SGBTConfig::builder()
1654            .n_steps(50)
1655            .learning_rate(0.05)
1656            .variant(SGBTVariant::MultipleIterations { multiplier: 5.0 })
1657            .build()
1658            .unwrap();
1659
1660        let cloned = original.clone();
1661
1662        assert_eq!(cloned.n_steps, 50);
1663        assert!((cloned.learning_rate - 0.05).abs() < f64::EPSILON);
1664        assert_eq!(
1665            cloned.variant,
1666            SGBTVariant::MultipleIterations { multiplier: 5.0 }
1667        );
1668    }
1669
1670    // ------------------------------------------------------------------
    // 8. A valid DDM configuration is accepted
1672    // ------------------------------------------------------------------
1673    #[test]
1674    fn valid_ddm_config_accepted() {
1675        let result = SGBTConfig::builder()
1676            .drift_detector(DriftDetectorType::Ddm {
1677                warning_level: 1.5,
1678                drift_level: 2.5,
1679                min_instances: 10,
1680            })
1681            .build();
1682        assert!(result.is_ok());
1683    }
1684
1685    // ------------------------------------------------------------------
1686    // 9. Created detectors are functional (round-trip test)
1687    // ------------------------------------------------------------------
1688    #[test]
1689    fn created_detectors_are_functional() {
1690        // Create a detector, feed it data, verify it can detect drift.
1691        let dt = DriftDetectorType::PageHinkley {
1692            delta: 0.005,
1693            lambda: 50.0,
1694        };
1695        let mut detector = dt.create();
1696
1697        // Feed 500 stable values.
1698        for _ in 0..500 {
1699            detector.update(1.0);
1700        }
1701
1702        // Feed shifted values -- should eventually drift.
1703        let mut drifted = false;
1704        for _ in 0..500 {
1705            if detector.update(10.0) == crate::drift::DriftSignal::Drift {
1706                drifted = true;
1707                break;
1708            }
1709        }
1710        assert!(
1711            drifted,
1712            "detector created from DriftDetectorType should be functional"
1713        );
1714    }
1715
1716    // ------------------------------------------------------------------
1717    // 10. Boundary acceptance: n_bins=2 is the minimum valid
1718    // ------------------------------------------------------------------
1719    #[test]
1720    fn boundary_n_bins_two_accepted() {
1721        let result = SGBTConfig::builder().n_bins(2).build();
1722        assert!(result.is_ok());
1723    }
1724
1725    // ------------------------------------------------------------------
1726    // 11. Boundary acceptance: grace_period=1 is valid
1727    // ------------------------------------------------------------------
1728    #[test]
1729    fn boundary_grace_period_one_accepted() {
1730        let result = SGBTConfig::builder().grace_period(1).build();
1731        assert!(result.is_ok());
1732    }
1733
1734    // ------------------------------------------------------------------
1735    // 12. Feature names -- valid config with names
1736    // ------------------------------------------------------------------
1737    #[test]
1738    fn feature_names_accepted() {
1739        let cfg = SGBTConfig::builder()
1740            .feature_names(vec!["price".into(), "volume".into(), "spread".into()])
1741            .build()
1742            .unwrap();
1743        assert_eq!(
1744            cfg.feature_names.as_ref().unwrap(),
1745            &["price", "volume", "spread"]
1746        );
1747    }
1748
1749    // ------------------------------------------------------------------
1750    // 13. Feature names -- duplicate names rejected
1751    // ------------------------------------------------------------------
1752    #[test]
1753    fn feature_names_rejects_duplicates() {
1754        let result = SGBTConfig::builder()
1755            .feature_names(vec!["price".into(), "volume".into(), "price".into()])
1756            .build();
1757        assert!(result.is_err());
1758        let msg = format!("{}", result.unwrap_err());
1759        assert!(msg.contains("duplicate"));
1760    }
1761
1762    // ------------------------------------------------------------------
1763    // 14. Feature names -- serde backward compat (missing field)
1764    //     Requires serde_json which is only in the full irithyll crate.
1765    // ------------------------------------------------------------------
1766
1767    // ------------------------------------------------------------------
1768    // 15. Feature names -- empty vec is valid
1769    // ------------------------------------------------------------------
1770    #[test]
1771    fn feature_names_empty_vec_accepted() {
1772        let cfg = SGBTConfig::builder().feature_names(vec![]).build().unwrap();
1773        assert!(cfg.feature_names.unwrap().is_empty());
1774    }
1775
1776    // ------------------------------------------------------------------
1777    // 16. Adaptive leaf bound -- builder
1778    // ------------------------------------------------------------------
1779    #[test]
1780    fn builder_adaptive_leaf_bound() {
1781        let cfg = SGBTConfig::builder()
1782            .adaptive_leaf_bound(3.0)
1783            .build()
1784            .unwrap();
1785        assert_eq!(cfg.adaptive_leaf_bound, Some(3.0));
1786    }
1787
1788    // ------------------------------------------------------------------
1789    // 17. Adaptive leaf bound -- validation rejects zero
1790    // ------------------------------------------------------------------
1791    #[test]
1792    fn validation_rejects_zero_adaptive_leaf_bound() {
1793        let result = SGBTConfig::builder().adaptive_leaf_bound(0.0).build();
1794        assert!(result.is_err());
1795        let msg = format!("{}", result.unwrap_err());
1796        assert!(
1797            msg.contains("adaptive_leaf_bound"),
1798            "error should mention adaptive_leaf_bound: {}",
1799            msg,
1800        );
1801    }
1802
1803    // ------------------------------------------------------------------
1804    // 18. Adaptive leaf bound -- serde backward compat
1805    //     Requires serde_json which is only in the full irithyll crate.
1806    // ------------------------------------------------------------------
1807}