Skip to main content

anomstream_core/
config.rs

1//! Forest configuration and the [`ForestBuilder`] entry point.
2//!
3//! [`RcfConfig`] enforces the AWS `SageMaker` Random Cut Forest
4//! hyperparameter bounds at validation time. Callers should construct
5//! a forest through [`ForestBuilder`] rather than instantiating
6//! [`RcfConfig`] directly so the builder picks AWS-conformant defaults.
7//!
8//! Per-point dimensionality is encoded at the type level as the
9//! `D` const-generic on [`ForestBuilder`] / [`crate::RandomCutForest`]
10//! so the bounding-box and per-tree node storage live on the stack
11//! and the compiler can vectorise the hot tree-traversal loops.
12
13use alloc::format;
14use alloc::vec::Vec;
15
16use crate::error::{RcfError, RcfResult};
17use crate::forest::random_cut_forest::RandomCutForest;
18
/// AWS lower bound for `feature_dim`.
pub const MIN_DIMENSION: usize = 1;
/// AWS upper bound for `feature_dim`.
pub const MAX_DIMENSION: usize = 10_000;
/// AWS lower bound for `num_trees`.
pub const MIN_NUM_TREES: usize = 50;
/// AWS upper bound for `num_trees`.
pub const MAX_NUM_TREES: usize = 1_000;
/// AWS default for `num_trees`.
pub const DEFAULT_NUM_TREES: usize = 100;
/// AWS lower bound for `num_samples_per_tree`.
pub const MIN_SAMPLE_SIZE: usize = 1;
/// AWS upper bound for `num_samples_per_tree`.
pub const MAX_SAMPLE_SIZE: usize = 2_048;
/// AWS default for `num_samples_per_tree`.
pub const DEFAULT_SAMPLE_SIZE: usize = 256;
/// Scaling numerator used to derive the default time-decay factor
/// from `sample_size`: `default_time_decay = TIME_DECAY_NUMERATOR /
/// sample_size`. `0.1` matches the AWS Java `CompactSampler` default
/// and gives an effective reservoir "half-life" of a handful of
/// reservoirs-worth of input — enough recency bias to track baseline
/// drift on a streaming agent over hours / days without losing the
/// uniform-sampling character on each individual window.
pub const TIME_DECAY_NUMERATOR: f64 = 0.1;
/// Default time-decay resolved against [`DEFAULT_SAMPLE_SIZE`] —
/// `0.1 / 256 ≈ 3.9 × 10⁻⁴`. Prefer [`default_time_decay_for`] when
/// the sample size differs from the default.
///
/// Evaluated through [`default_time_decay_for`] at compile time so the
/// `TIME_DECAY_NUMERATOR / sample_size` formula is written exactly once.
pub const DEFAULT_TIME_DECAY: f64 = default_time_decay_for(DEFAULT_SAMPLE_SIZE);

/// Compute the default time-decay for a given `sample_size`:
/// `0.1 / sample_size`, clamped to `0.0` for `sample_size == 0`
/// (which is caught separately by [`RcfConfig::validate`]).
///
/// Usable in `const` contexts; [`DEFAULT_TIME_DECAY`] is defined
/// through it.
#[must_use]
// The cast is exact for every sample size inside
// `[MIN_SAMPLE_SIZE, MAX_SAMPLE_SIZE]`; those values are far below the
// 2^53 limit at which `usize -> f64` starts rounding.
#[allow(clippy::cast_precision_loss)]
pub const fn default_time_decay_for(sample_size: usize) -> f64 {
    // Guard the division: a zero reservoir would otherwise yield `inf`.
    if sample_size == 0 {
        return 0.0;
    }
    TIME_DECAY_NUMERATOR / sample_size as f64
}

/// Default warmup admission fraction — `1.0` disables the gate and
/// matches the classic reservoir behaviour. Set below `1.0` (AWS
/// uses `0.125`) to ramp admission during the cold-start period.
pub const DEFAULT_INITIAL_ACCEPT_FRACTION: f64 = 1.0;
67
/// Forest hyperparameters, checked by [`RcfConfig::validate`]. The
/// per-point dimension is not stored here — it is carried at the type
/// level.
///
/// # Examples
///
/// ```
/// use anomstream_core::{ForestBuilder, RcfConfig};
///
/// let builder = ForestBuilder::<4>::new()
///     .num_trees(50)
///     .sample_size(64);
/// let cfg: &RcfConfig = builder.config();
/// assert_eq!(cfg.num_trees, 50);
/// cfg.validate().unwrap();
/// ```
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(try_from = "RcfConfigShadow"))]
#[non_exhaustive]
pub struct RcfConfig {
    /// How many trees the forest holds (`num_trees`).
    pub num_trees: usize,
    /// Upper limit on each tree's reservoir (`num_samples_per_tree`).
    pub sample_size: usize,
    /// Weighting factor for reservoir sampling over time. `0.0` gives
    /// strict uniform sampling; any positive value skews the
    /// reservoir toward recent points. [`ForestBuilder`] resolves the
    /// default to `0.1 / sample_size`, matching the AWS Java
    /// `CompactSampler` reference.
    pub time_decay: f64,
    /// Deterministic RNG seed, if any; `None` falls back to entropy.
    pub seed: Option<u64>,
    /// Size of an optional dedicated rayon thread pool, used with the
    /// `parallel` cargo feature. `None` means "use rayon's global
    /// pool" (configurable via the `RAYON_NUM_THREADS` env var).
    /// `Some(n)` builds a per-forest [`rayon::ThreadPool`] of `n`
    /// workers so this forest can be isolated from the rest of the
    /// application's rayon workload. Ignored without `parallel`.
    pub num_threads: Option<usize>,
    /// Warmup admission fraction handed to every per-tree
    /// [`crate::ReservoirSampler`]; that type's module-level docs
    /// describe the semantics. `1.0` disables the gate; smaller
    /// values ramp admission during the cold-start period so the
    /// first few stream entries dominate the reservoir less.
    #[cfg_attr(feature = "serde", serde(default = "default_initial_accept_fraction"))]
    pub initial_accept_fraction: f64,
    /// Optional per-dimension multiplicative weights, applied to each
    /// point before it reaches the forest's hot paths (`update`,
    /// `score`, `attribution`, `bootstrap`, `delete_by_value`). The
    /// length must equal the forest's compile-time dimension `D`.
    ///
    /// The intended use is per-feature scale normalisation. When
    /// input dimensions have wildly different dynamic ranges
    /// (packet-rate in `[10², 10⁶]`, protocol-mix ratios in `[0, 1]`,
    /// entropy in `[0, 8]` bits), a naive random cut weights each
    /// dimension by its raw range. Pre-scaling with `1 / stddev[d]`
    /// recovers a unit-variance input space where every dim pulls
    /// its weight. For full z-score normalisation the caller should
    /// still mean-centre upstream — `feature_scales` is a weight, not
    /// a full affine transform.
    ///
    /// `None` keeps the classic "forest sees the raw caller point"
    /// behaviour. The field is `#[serde(default)]` so old snapshots
    /// deserialise without migration.
    #[cfg_attr(feature = "serde", serde(default))]
    pub feature_scales: Option<Vec<f64>>,
}
135
/// Value serde substitutes when a payload carries no
/// `initial_accept_fraction` — i.e. one persisted before the warmup
/// knob existed. Returns [`DEFAULT_INITIAL_ACCEPT_FRACTION`], which
/// leaves the admission gate disabled.
#[cfg(feature = "serde")]
#[must_use]
fn default_initial_accept_fraction() -> f64 {
    DEFAULT_INITIAL_ACCEPT_FRACTION
}
144
/// Wire-format mirror of [`RcfConfig`]. Serde deserialises into this
/// unvalidated struct first; the [`TryFrom`] conversion then runs
/// [`RcfConfig::validate`] — the AWS `SageMaker` hyperparameter
/// bounds (`num_trees`, `sample_size`, `time_decay`) plus finite /
/// positive `feature_scales` — so no live config is handed out
/// without passing those checks.
#[cfg(feature = "serde")]
#[derive(serde::Serialize, serde::Deserialize)]
#[allow(clippy::missing_docs_in_private_items)]
struct RcfConfigShadow {
    // Fields mirror `RcfConfig` one-to-one, in the same order.
    num_trees: usize,
    sample_size: usize,
    time_decay: f64,
    seed: Option<u64>,
    num_threads: Option<usize>,
    // Absent in older payloads: falls back to the gate-disabled default.
    #[serde(default = "default_initial_accept_fraction")]
    initial_accept_fraction: f64,
    // Absent in older payloads: falls back to `None` (no scaling).
    #[serde(default)]
    feature_scales: Option<Vec<f64>>,
}
164
#[cfg(feature = "serde")]
impl TryFrom<RcfConfigShadow> for RcfConfig {
    type Error = RcfError;

    /// Promote a freshly-deserialised shadow into a live config,
    /// failing with the first [`RcfConfig::validate`] error.
    fn try_from(raw: RcfConfigShadow) -> Result<Self, Self::Error> {
        // Exhaustive destructuring: adding a field to the shadow
        // without forwarding it here becomes a compile error.
        let RcfConfigShadow {
            num_trees,
            sample_size,
            time_decay,
            seed,
            num_threads,
            initial_accept_fraction,
            feature_scales,
        } = raw;

        let candidate = Self {
            num_trees,
            sample_size,
            time_decay,
            seed,
            num_threads,
            initial_accept_fraction,
            feature_scales,
        };
        candidate.validate()?;
        Ok(candidate)
    }
}
183
184impl RcfConfig {
185    /// Validate the configuration against the AWS hyperparameter
186    /// bounds. The forest's compile-time dimension `D` is checked
187    /// separately via [`Self::validate_dimension`] so non-const
188    /// callers can apply the AWS bounds without instantiating a
189    /// generic.
190    ///
191    /// # Errors
192    ///
193    /// Returns [`RcfError::InvalidConfig`] with the offending
194    /// parameter when any bound is violated.
195    pub fn validate(&self) -> RcfResult<()> {
196        if !(MIN_NUM_TREES..=MAX_NUM_TREES).contains(&self.num_trees) {
197            return Err(RcfError::InvalidConfig(
198                format!(
199                    "num_trees {} out of [{}, {}]",
200                    self.num_trees, MIN_NUM_TREES, MAX_NUM_TREES
201                )
202                .into(),
203            ));
204        }
205        if !(MIN_SAMPLE_SIZE..=MAX_SAMPLE_SIZE).contains(&self.sample_size) {
206            return Err(RcfError::InvalidConfig(
207                format!(
208                    "sample_size {} out of [{}, {}]",
209                    self.sample_size, MIN_SAMPLE_SIZE, MAX_SAMPLE_SIZE
210                )
211                .into(),
212            ));
213        }
214        if !self.time_decay.is_finite() || !(0.0..=1.0).contains(&self.time_decay) {
215            return Err(RcfError::InvalidConfig(
216                format!("time_decay {} out of [0.0, 1.0]", self.time_decay).into(),
217            ));
218        }
219        if let Some(n) = self.num_threads
220            && n == 0
221        {
222            return Err(RcfError::InvalidConfig(
223                "num_threads must be > 0 when set; use None to fall back to rayon's global pool"
224                    .into(),
225            ));
226        }
227        if !self.initial_accept_fraction.is_finite()
228            || self.initial_accept_fraction <= 0.0
229            || self.initial_accept_fraction > 1.0
230        {
231            return Err(RcfError::InvalidConfig(
232                format!(
233                    "initial_accept_fraction {} out of (0.0, 1.0]",
234                    self.initial_accept_fraction
235                )
236                .into(),
237            ));
238        }
239        if let Some(scales) = &self.feature_scales {
240            for (i, s) in scales.iter().enumerate() {
241                if !s.is_finite() || *s <= 0.0 {
242                    return Err(RcfError::InvalidConfig(
243                        format!("feature_scales[{i}] must be finite and > 0, got {s}").into(),
244                    ));
245                }
246            }
247        }
248        Ok(())
249    }
250
251    /// Validate the declared [`RcfConfig::feature_scales`] against a
252    /// target per-point dimension `d`. When `feature_scales` is
253    /// `None`, the check is a no-op.
254    ///
255    /// # Errors
256    ///
257    /// Returns [`RcfError::DimensionMismatch`] when the scales vector
258    /// length does not equal `d`.
259    pub fn validate_feature_scales_dimension(&self, d: usize) -> RcfResult<()> {
260        if let Some(scales) = &self.feature_scales
261            && scales.len() != d
262        {
263            return Err(RcfError::DimensionMismatch {
264                expected: d,
265                got: scales.len(),
266            });
267        }
268        Ok(())
269    }
270
271    /// Validate the compile-time dimension `D` against the AWS
272    /// `feature_dim` bounds. Called by [`ForestBuilder::build`] so
273    /// every user-facing entry point gates on the AWS limits.
274    ///
275    /// # Errors
276    ///
277    /// Returns [`RcfError::InvalidConfig`] when `D` is outside
278    /// `[MIN_DIMENSION, MAX_DIMENSION]`.
279    pub fn validate_dimension(dimension: usize) -> RcfResult<()> {
280        if !(MIN_DIMENSION..=MAX_DIMENSION).contains(&dimension) {
281            return Err(RcfError::InvalidConfig(
282                format!("dimension {dimension} out of [{MIN_DIMENSION}, {MAX_DIMENSION}]").into(),
283            ));
284        }
285        Ok(())
286    }
287}
288
289/// Fluent builder for [`RandomCutForest`].
290///
291/// Defaults: `num_trees = 100`, `sample_size = 256`,
292/// `time_decay = 0.1 / sample_size` (matches AWS Java
293/// `CompactSampler`; call [`Self::time_decay`] with `0.0` to recover
294/// strict uniform sampling), RNG seeded from entropy.
295///
296/// `D` is the per-point dimensionality. Callers pin it at construction
297/// via turbofish: `ForestBuilder::<4>::new()`.
298///
299/// # Examples
300///
301/// ```
302/// use anomstream_core::ForestBuilder;
303///
304/// let mut forest = ForestBuilder::<4>::new()
305///     .num_trees(50)
306///     .sample_size(64)
307///     .seed(42)
308///     .build()
309///     .expect("AWS-conformant config");
310/// forest.update([0.0, 0.0, 0.0, 0.0]).expect("dim matches");
311/// ```
312#[derive(Debug, Clone)]
313pub struct ForestBuilder<const D: usize> {
314    /// Working configuration mutated by the fluent builder methods
315    /// and validated when [`ForestBuilder::build`] runs.
316    config: RcfConfig,
317    /// Whether the caller has explicitly overridden `time_decay` via
318    /// [`Self::time_decay`]. When `false`, [`Self::sample_size`] and
319    /// [`Self::build`] resolve `time_decay` from the current
320    /// `sample_size` so the AWS `0.1 / sample_size` default tracks
321    /// any reservoir-size override the caller applies.
322    time_decay_explicit: bool,
323}
324
325impl<const D: usize> Default for ForestBuilder<D> {
326    fn default() -> Self {
327        Self::new()
328    }
329}
330
331impl<const D: usize> ForestBuilder<D> {
332    /// Start a new builder for `D`-dimensional points.
333    #[must_use]
334    pub fn new() -> Self {
335        Self {
336            config: RcfConfig {
337                num_trees: DEFAULT_NUM_TREES,
338                sample_size: DEFAULT_SAMPLE_SIZE,
339                time_decay: default_time_decay_for(DEFAULT_SAMPLE_SIZE),
340                seed: None,
341                num_threads: None,
342                initial_accept_fraction: DEFAULT_INITIAL_ACCEPT_FRACTION,
343                feature_scales: None,
344            },
345            time_decay_explicit: false,
346        }
347    }
348
349    /// Override the number of trees.
350    #[must_use]
351    pub fn num_trees(mut self, n: usize) -> Self {
352        self.config.num_trees = n;
353        self
354    }
355
356    /// Override the per-tree reservoir size. When `time_decay` has
357    /// not been explicitly set via [`Self::time_decay`], the default
358    /// `0.1 / sample_size` is re-resolved against the new value so
359    /// the effective recency bias stays consistent with AWS's
360    /// `CompactSampler` formula.
361    #[must_use]
362    pub fn sample_size(mut self, s: usize) -> Self {
363        self.config.sample_size = s;
364        if !self.time_decay_explicit {
365            self.config.time_decay = default_time_decay_for(s);
366        }
367        self
368    }
369
370    /// Override the sampler time-decay factor. Pass `0.0` to disable
371    /// recency bias and recover strict uniform reservoir sampling.
372    /// Once called, the builder stops auto-resolving `time_decay`
373    /// from subsequent [`Self::sample_size`] changes — the caller's
374    /// choice wins.
375    #[must_use]
376    pub fn time_decay(mut self, d: f64) -> Self {
377        self.config.time_decay = d;
378        self.time_decay_explicit = true;
379        self
380    }
381
382    /// Pin the RNG seed for reproducible runs.
383    #[must_use]
384    pub fn seed(mut self, seed: u64) -> Self {
385        self.config.seed = Some(seed);
386        self
387    }
388
389    /// Build a dedicated rayon thread pool of size `n` for this
390    /// forest's parallel score / attribution / update paths.
391    /// Requires the `parallel` cargo feature. When unset (default)
392    /// rayon's global pool is used.
393    #[must_use]
394    pub fn num_threads(mut self, n: usize) -> Self {
395        self.config.num_threads = Some(n);
396        self
397    }
398
399    /// Override the warmup admission fraction forwarded to each
400    /// per-tree reservoir. See
401    /// [`crate::ReservoirSampler`] module-level docs for semantics.
402    /// `1.0` (the default) disables the gate; AWS's `CompactSampler`
403    /// uses `0.125`.
404    #[must_use]
405    pub fn initial_accept_fraction(mut self, f: f64) -> Self {
406        self.config.initial_accept_fraction = f;
407        self
408    }
409
410    /// Set per-dimension multiplicative weights applied to every
411    /// point before it reaches the forest's hot paths. See
412    /// [`RcfConfig::feature_scales`] for semantics. Pass the exact
413    /// `[f64; D]` array — length is checked against the builder's
414    /// compile-time `D` at [`Self::build`] time, contents are
415    /// checked for finiteness and positivity by
416    /// [`RcfConfig::validate`].
417    #[must_use]
418    pub fn feature_scales(mut self, scales: [f64; D]) -> Self {
419        self.config.feature_scales = Some(scales.to_vec());
420        self
421    }
422
423    /// Drop any previously-set `feature_scales` — returns the builder
424    /// to the unweighted state. Useful for tests that want to clear a
425    /// shared template's scale vector before building a specialised
426    /// forest.
427    #[must_use]
428    pub fn clear_feature_scales(mut self) -> Self {
429        self.config.feature_scales = None;
430        self
431    }
432
433    /// Read-only access to the config under construction.
434    #[must_use = "detector output should be checked — dropping it silently usually indicates a logic bug"]
435    pub fn config(&self) -> &RcfConfig {
436        &self.config
437    }
438
439    /// Per-point dimensionality (compile-time `D`).
440    #[must_use]
441    pub const fn dimension(&self) -> usize {
442        D
443    }
444
445    /// Validate the config and instantiate the forest.
446    ///
447    /// # Errors
448    ///
449    /// Forwards [`RcfConfig::validate`] errors and propagates any
450    /// failure from the underlying [`RandomCutForest`] constructor.
451    #[must_use = "detector output should be checked — dropping it silently usually indicates a logic bug"]
452    pub fn build(self) -> RcfResult<RandomCutForest<D>> {
453        RcfConfig::validate_dimension(D)?;
454        self.config.validate()?;
455        self.config.validate_feature_scales_dimension(D)?;
456        RandomCutForest::<D>::from_config(self.config)
457    }
458}
459
#[cfg(test)]
mod tests {
    use super::*;

    /// Config with the three bounded numeric knobs set explicitly and
    /// every remaining field at its builder default.
    fn config_with(num_trees: usize, sample_size: usize, time_decay: f64) -> RcfConfig {
        RcfConfig {
            num_trees,
            sample_size,
            time_decay,
            seed: None,
            num_threads: None,
            initial_accept_fraction: DEFAULT_INITIAL_ACCEPT_FRACTION,
            feature_scales: None,
        }
    }

    /// In-bounds starting point (100 trees, 256 samples, no decay)
    /// for tests that perturb a single field.
    fn baseline() -> RcfConfig {
        config_with(100, 256, 0.0)
    }

    /// Shadow payload with in-bounds fixed fields; callers vary only
    /// the three fields the deserialisation tests corrupt.
    #[cfg(all(feature = "serde", feature = "postcard"))]
    fn shadow_with(
        num_trees: usize,
        time_decay: f64,
        feature_scales: Option<Vec<f64>>,
    ) -> RcfConfigShadow {
        RcfConfigShadow {
            num_trees,
            sample_size: 256,
            time_decay,
            seed: None,
            num_threads: None,
            initial_accept_fraction: 1.0,
            feature_scales,
        }
    }

    /// Serialise the shadow with postcard and report whether decoding
    /// it back as a live [`RcfConfig`] is refused.
    #[cfg(all(feature = "serde", feature = "postcard"))]
    fn decode_is_rejected(raw: &RcfConfigShadow) -> bool {
        let bytes = postcard::to_allocvec(raw).unwrap();
        let back: Result<RcfConfig, _> = postcard::from_bytes(&bytes);
        back.is_err()
    }

    #[test]
    fn validate_default_passes() {
        config_with(DEFAULT_NUM_TREES, DEFAULT_SAMPLE_SIZE, DEFAULT_TIME_DECAY)
            .validate()
            .unwrap();
    }

    #[test]
    fn validate_dimension_rejects_zero() {
        let failure = RcfConfig::validate_dimension(0).unwrap_err();
        assert!(matches!(failure, RcfError::InvalidConfig(_)));
    }

    #[test]
    fn validate_dimension_rejects_above_max() {
        let outcome = RcfConfig::validate_dimension(10_001);
        assert!(outcome.is_err());
    }

    #[test]
    fn validate_dimension_accepts_at_max() {
        RcfConfig::validate_dimension(10_000).unwrap();
    }

    #[test]
    fn validate_rejects_num_trees_below_min() {
        let outcome = config_with(49, 256, 0.0).validate();
        assert!(outcome.is_err());
    }

    #[test]
    fn validate_accepts_num_trees_at_bounds() {
        for trees in [50, 1000] {
            config_with(trees, 256, 0.0).validate().unwrap();
        }
    }

    #[test]
    fn validate_rejects_num_trees_above_max() {
        let outcome = config_with(1001, 256, 0.0).validate();
        assert!(outcome.is_err());
    }

    #[test]
    fn validate_rejects_sample_size_zero() {
        let outcome = config_with(100, 0, 0.0).validate();
        assert!(outcome.is_err());
    }

    #[test]
    fn validate_accepts_sample_size_at_bounds() {
        for samples in [1, 2048] {
            config_with(100, samples, 0.0).validate().unwrap();
        }
    }

    #[test]
    fn validate_rejects_sample_size_above_max() {
        let outcome = config_with(100, 2049, 0.0).validate();
        assert!(outcome.is_err());
    }

    #[test]
    fn validate_rejects_negative_time_decay() {
        let outcome = config_with(100, 256, -0.01).validate();
        assert!(outcome.is_err());
    }

    #[test]
    fn validate_rejects_time_decay_above_one() {
        let outcome = config_with(100, 256, 1.01).validate();
        assert!(outcome.is_err());
    }

    #[test]
    fn validate_rejects_non_finite_time_decay() {
        for decay in [f64::NAN, f64::INFINITY] {
            assert!(config_with(100, 256, decay).validate().is_err());
        }
    }

    #[test]
    fn validate_rejects_zero_num_threads() {
        let mut config = baseline();
        config.num_threads = Some(0);
        let failure = config.validate().unwrap_err();
        assert!(matches!(failure, RcfError::InvalidConfig(_)));
    }

    #[test]
    fn validate_accepts_some_num_threads() {
        let mut config = baseline();
        config.num_threads = Some(4);
        config.validate().unwrap();
    }

    #[test]
    fn validate_accepts_default_num_threads_none() {
        let config = baseline();
        assert_eq!(config.num_threads, None);
        config.validate().unwrap();
    }

    #[test]
    fn builder_num_threads_sets_field() {
        let builder = ForestBuilder::<4>::new().num_threads(8);
        assert_eq!(builder.config().num_threads, Some(8));
    }

    #[test]
    fn validate_accepts_initial_accept_fraction_at_bounds() {
        let mut config = baseline();
        for fraction in [0.001, 1.0] {
            config.initial_accept_fraction = fraction;
            config.validate().unwrap();
        }
    }

    #[test]
    fn validate_rejects_initial_accept_fraction_out_of_range() {
        let mut config = baseline();
        for fraction in [0.0, -0.1, 1.01] {
            config.initial_accept_fraction = fraction;
            assert!(config.validate().is_err());
        }
    }

    #[test]
    fn validate_rejects_non_finite_initial_accept_fraction() {
        let mut config = baseline();
        for fraction in [f64::NAN, f64::INFINITY] {
            config.initial_accept_fraction = fraction;
            assert!(config.validate().is_err());
        }
    }

    #[test]
    fn builder_initial_accept_fraction_sets_field() {
        let builder = ForestBuilder::<4>::new().initial_accept_fraction(0.125);
        let stored = builder.config().initial_accept_fraction;
        assert!((stored - 0.125).abs() < f64::EPSILON);
    }

    #[test]
    fn builder_defaults_initial_accept_fraction_to_one() {
        let builder = ForestBuilder::<4>::new();
        let stored = builder.config().initial_accept_fraction;
        assert!((stored - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn builder_defaults_match_aws() {
        let builder = ForestBuilder::<8>::new();
        assert_eq!(builder.dimension(), 8);

        let config = builder.config();
        assert_eq!(config.num_trees, 100);
        assert_eq!(config.sample_size, 256);
        assert!(
            (config.time_decay - TIME_DECAY_NUMERATOR / 256.0).abs() < f64::EPSILON,
            "default time_decay should resolve to 0.1 / sample_size, got {}",
            config.time_decay
        );
        assert_eq!(config.seed, None);
    }

    #[test]
    fn builder_sample_size_override_rescales_default_time_decay() {
        // time_decay is never set explicitly, so it must follow the
        // new reservoir size: 0.1 / 128.
        let builder = ForestBuilder::<4>::new().sample_size(128);
        assert!(
            (builder.config().time_decay - TIME_DECAY_NUMERATOR / 128.0).abs() < f64::EPSILON,
            "sample_size(128) should rescale default to 0.1 / 128, got {}",
            builder.config().time_decay,
        );
    }

    #[test]
    fn builder_explicit_time_decay_sticks_across_sample_size_override() {
        // A later sample_size call must leave the explicit value alone.
        let builder = ForestBuilder::<4>::new().time_decay(0.05).sample_size(128);
        let stored = builder.config().time_decay;
        assert!((stored - 0.05).abs() < f64::EPSILON);
    }

    #[test]
    fn builder_sample_size_override_before_time_decay() {
        // An explicit value set after sample_size wins as well.
        let builder = ForestBuilder::<4>::new().sample_size(128).time_decay(0.05);
        let stored = builder.config().time_decay;
        assert!((stored - 0.05).abs() < f64::EPSILON);
    }

    #[test]
    fn builder_time_decay_zero_still_accepted() {
        let builder = ForestBuilder::<4>::new().time_decay(0.0);
        assert!(builder.config().time_decay.abs() < f64::EPSILON);
        builder.build().expect("time_decay=0 must still build");
    }

    #[test]
    fn default_time_decay_for_zero_sample_size_is_zero() {
        let resolved = default_time_decay_for(0);
        assert!(resolved.abs() < f64::EPSILON);
    }

    #[test]
    fn default_time_decay_for_default_sample_size_matches_constant() {
        let resolved = default_time_decay_for(DEFAULT_SAMPLE_SIZE);
        assert!((resolved - DEFAULT_TIME_DECAY).abs() < f64::EPSILON);
    }

    #[test]
    fn builder_overrides_apply() {
        let builder = ForestBuilder::<4>::new()
            .num_trees(50)
            .sample_size(64)
            .time_decay(0.05)
            .seed(42);

        let config = builder.config();
        assert_eq!(config.num_trees, 50);
        assert_eq!(config.sample_size, 64);
        assert!((config.time_decay - 0.05).abs() < f64::EPSILON);
        assert_eq!(config.seed, Some(42));
    }

    #[test]
    fn builder_build_validates() {
        let failure = ForestBuilder::<4>::new().num_trees(10).build().unwrap_err();
        assert!(matches!(failure, RcfError::InvalidConfig(_)));
    }

    #[cfg(all(feature = "serde", feature = "postcard"))]
    #[test]
    fn deserialize_rejects_out_of_range_num_trees() {
        let bad = shadow_with(MAX_NUM_TREES + 1, 0.0, None);
        assert!(decode_is_rejected(&bad));
    }

    #[cfg(all(feature = "serde", feature = "postcard"))]
    #[test]
    fn deserialize_rejects_nan_time_decay() {
        let bad = shadow_with(100, f64::NAN, None);
        assert!(decode_is_rejected(&bad));
    }

    #[cfg(all(feature = "serde", feature = "postcard"))]
    #[test]
    fn deserialize_rejects_negative_feature_scale() {
        let bad = shadow_with(100, 0.0, Some(alloc::vec![1.0, -0.5]));
        assert!(decode_is_rejected(&bad));
    }
}