use alloc::format;
use crate::config::ForestBuilder;
use crate::error::{RcfError, RcfResult};
use crate::thresholded::detector::ThresholdedForest;
/// Default z-factor: used by `ThresholdMode::default()` and as the default
/// value of `ThresholdedConfig::z_factor`.
pub const DEFAULT_Z_FACTOR: f64 = 3.0;
/// Suggested quantile level for `ThresholdMode::Quantile`.
///
/// NOTE(review): not referenced by any of the defaults visible in this module —
/// confirm where (or whether) it is consumed.
pub const DEFAULT_QUANTILE: f64 = 0.99;
/// Default `ThresholdedConfig::score_decay`; any replacement must stay within
/// `(0.0, 1.0]` to pass `ThresholdedConfig::validate`.
pub const DEFAULT_SCORE_DECAY: f64 = 0.01;
/// Default `ThresholdedConfig::min_observations` (not range-checked by `validate`).
pub const DEFAULT_MIN_OBSERVATIONS: u64 = 32;
/// Default `ThresholdedConfig::min_threshold`; any replacement must be finite
/// and `>= 0` to pass `ThresholdedConfig::validate`.
pub const DEFAULT_MIN_THRESHOLD: f64 = 1.0;
/// How the thresholding layer derives its anomaly threshold.
///
/// Stored in `ThresholdedConfig::threshold_mode` and checked by
/// `ThresholdedConfig::validate`. Marked `#[non_exhaustive]` so further
/// strategies can be added without a breaking change for downstream matches.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum ThresholdMode {
    /// Threshold scaled by `z_factor` (must be finite and `> 0`).
    ///
    /// NOTE(review): presumably mean + `z_factor`·σ of observed scores —
    /// confirm against the detector implementation.
    ZSigma { z_factor: f64 },
    /// Threshold taken at quantile level `p`, which must lie strictly inside
    /// `(0.0, 1.0)`.
    Quantile { p: f64 },
}

impl Default for ThresholdMode {
    /// Z-sigma thresholding with [`DEFAULT_Z_FACTOR`].
    fn default() -> Self {
        ThresholdMode::ZSigma { z_factor: DEFAULT_Z_FACTOR }
    }
}
/// Settings for the thresholding layer wrapped around the forest.
///
/// Call [`ThresholdedConfig::validate`] before use; the builder does this in
/// `build`.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ThresholdedConfig {
    /// Top-level z-factor (must be finite and `> 0`).
    ///
    /// The builder's `z_factor()` setter writes this and a matching
    /// `ThresholdMode::ZSigma` together. A config written by hand or
    /// deserialized can hold a different value here than in `threshold_mode`.
    pub z_factor: f64,
    /// Active threshold strategy.
    ///
    /// NOTE(review): with the `serde` feature, a payload that omits this field
    /// falls back to `ThresholdMode::default()` (z-sigma with
    /// `DEFAULT_Z_FACTOR`), ignoring whatever `z_factor` was deserialized.
    #[cfg_attr(feature = "serde", serde(default))]
    pub threshold_mode: ThresholdMode,
    /// Score decay; must lie in `(0.0, 1.0]`.
    pub score_decay: f64,
    /// NOTE(review): presumably the number of scores seen before a threshold
    /// is applied — confirm in the detector. Not range-checked by `validate`.
    pub min_observations: u64,
    /// Lower bound on the threshold; must be finite and `>= 0`.
    pub min_threshold: f64,
}

impl Default for ThresholdedConfig {
    /// Populates every field from the `DEFAULT_*` constants, using the
    /// default (z-sigma) threshold mode.
    fn default() -> Self {
        let threshold_mode = ThresholdMode::default();
        ThresholdedConfig {
            threshold_mode,
            z_factor: DEFAULT_Z_FACTOR,
            min_observations: DEFAULT_MIN_OBSERVATIONS,
            min_threshold: DEFAULT_MIN_THRESHOLD,
            score_decay: DEFAULT_SCORE_DECAY,
        }
    }
}
impl ThresholdedConfig {
    /// Checks that every field holds a value the detector can use.
    ///
    /// The top-level `z_factor` field is checked regardless of which
    /// `threshold_mode` is active.
    ///
    /// # Errors
    ///
    /// Returns [`RcfError::InvalidConfig`] for the first failing check:
    /// - `z_factor` is not finite or is `<= 0`;
    /// - `threshold_mode` is `ZSigma` with a `z_factor` that is not finite or is `<= 0`;
    /// - `threshold_mode` is `Quantile` with `p` not strictly inside `(0.0, 1.0)`;
    /// - `score_decay` is not finite or lies outside `(0.0, 1.0]`;
    /// - `min_threshold` is not finite or is negative.
    pub fn validate(&self) -> RcfResult<()> {
        // Check the top-level field first. The builder's `z_factor()` writes the
        // same value into both places, so builder users see the message that
        // names the setter they actually called.
        if !self.z_factor.is_finite() || self.z_factor <= 0.0 {
            return Err(RcfError::InvalidConfig(
                format!("z_factor must be finite and > 0, got {}", self.z_factor).into(),
            ));
        }
        match self.threshold_mode {
            ThresholdMode::ZSigma { z_factor } => {
                if !z_factor.is_finite() || z_factor <= 0.0 {
                    // Distinct wording: this value can differ from the top-level
                    // `z_factor` when the config is built or deserialized directly,
                    // and the caller needs to know which one was rejected.
                    return Err(RcfError::InvalidConfig(
                        format!(
                            "threshold_mode ZSigma z_factor must be finite and > 0, got {z_factor}"
                        )
                        .into(),
                    ));
                }
            }
            ThresholdMode::Quantile { p } => {
                // `is_finite` is what rejects NaN: both comparisons are false for it.
                if !p.is_finite() || p <= 0.0 || p >= 1.0 {
                    return Err(RcfError::InvalidConfig(
                        format!("Quantile p must be in (0.0, 1.0), got {p}").into(),
                    ));
                }
            }
        }
        if !self.score_decay.is_finite() || self.score_decay <= 0.0 || self.score_decay > 1.0 {
            return Err(RcfError::InvalidConfig(
                format!(
                    "score_decay must be in (0.0, 1.0], got {}",
                    self.score_decay
                )
                .into(),
            ));
        }
        if !self.min_threshold.is_finite() || self.min_threshold < 0.0 {
            return Err(RcfError::InvalidConfig(
                format!(
                    "min_threshold must be finite and >= 0, got {}",
                    self.min_threshold
                )
                .into(),
            ));
        }
        Ok(())
    }
}
/// Fluent builder that configures the underlying forest and the thresholding
/// layer together, then assembles a [`ThresholdedForest`].
///
/// Forest-level setters forward to [`ForestBuilder`]. Threshold-level setters
/// edit an embedded [`ThresholdedConfig`], which is only validated in
/// [`ThresholdedForestBuilder::build`].
#[derive(Debug, Clone)]
pub struct ThresholdedForestBuilder<const D: usize> {
    /// Builder for the wrapped forest.
    forest: ForestBuilder<D>,
    /// Thresholding settings layered on top of the forest.
    thresholded: ThresholdedConfig,
}

impl<const D: usize> Default for ThresholdedForestBuilder<D> {
    /// Equivalent to [`ThresholdedForestBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl<const D: usize> ThresholdedForestBuilder<D> {
    /// Starts from `ForestBuilder::new()` and `ThresholdedConfig::default()`.
    #[must_use]
    pub fn new() -> Self {
        let forest = ForestBuilder::<D>::new();
        let thresholded = ThresholdedConfig::default();
        Self { forest, thresholded }
    }

    /// Replaces the forest builder with the result of `edit` applied to it.
    fn map_forest(mut self, edit: impl FnOnce(ForestBuilder<D>) -> ForestBuilder<D>) -> Self {
        self.forest = edit(self.forest);
        self
    }

    /// Applies `edit` to the thresholding config in place.
    fn map_thresholded(mut self, edit: impl FnOnce(&mut ThresholdedConfig)) -> Self {
        edit(&mut self.thresholded);
        self
    }

    /// Forwards `n` to [`ForestBuilder::num_trees`].
    #[must_use]
    pub fn num_trees(self, n: usize) -> Self {
        self.map_forest(|forest| forest.num_trees(n))
    }

    /// Forwards `s` to [`ForestBuilder::sample_size`].
    #[must_use]
    pub fn sample_size(self, s: usize) -> Self {
        self.map_forest(|forest| forest.sample_size(s))
    }

    /// Forwards `d` to [`ForestBuilder::time_decay`].
    #[must_use]
    pub fn time_decay(self, d: f64) -> Self {
        self.map_forest(|forest| forest.time_decay(d))
    }

    /// Forwards `seed` to [`ForestBuilder::seed`].
    #[must_use]
    pub fn seed(self, seed: u64) -> Self {
        self.map_forest(|forest| forest.seed(seed))
    }

    /// Forwards `n` to [`ForestBuilder::num_threads`].
    #[must_use]
    pub fn num_threads(self, n: usize) -> Self {
        self.map_forest(|forest| forest.num_threads(n))
    }

    /// Forwards `f` to [`ForestBuilder::initial_accept_fraction`].
    #[must_use]
    pub fn initial_accept_fraction(self, f: f64) -> Self {
        self.map_forest(|forest| forest.initial_accept_fraction(f))
    }

    /// Forwards the per-dimension `scales` to [`ForestBuilder::feature_scales`].
    #[must_use]
    pub fn feature_scales(self, scales: [f64; D]) -> Self {
        self.map_forest(|forest| forest.feature_scales(scales))
    }

    /// Selects z-sigma thresholding with factor `z`.
    ///
    /// Writes `z` to both `ThresholdedConfig::z_factor` and a
    /// `ThresholdMode::ZSigma`, replacing any quantile mode chosen earlier.
    #[must_use]
    pub fn z_factor(self, z: f64) -> Self {
        self.map_thresholded(|cfg| {
            cfg.z_factor = z;
            cfg.threshold_mode = ThresholdMode::ZSigma { z_factor: z };
        })
    }

    /// Selects quantile thresholding at level `p`.
    ///
    /// `ThresholdedConfig::z_factor` is left untouched. `build` rejects `p`
    /// outside `(0.0, 1.0)`.
    #[must_use]
    pub fn quantile_threshold(self, p: f64) -> Self {
        self.map_thresholded(|cfg| cfg.threshold_mode = ThresholdMode::Quantile { p })
    }

    /// Sets `ThresholdedConfig::score_decay` (validated in `build`).
    #[must_use]
    pub fn score_decay(self, d: f64) -> Self {
        self.map_thresholded(|cfg| cfg.score_decay = d)
    }

    /// Sets `ThresholdedConfig::min_observations`.
    #[must_use]
    pub fn min_observations(self, n: u64) -> Self {
        self.map_thresholded(|cfg| cfg.min_observations = n)
    }

    /// Sets `ThresholdedConfig::min_threshold` (validated in `build`).
    #[must_use]
    pub fn min_threshold(self, t: f64) -> Self {
        self.map_thresholded(|cfg| cfg.min_threshold = t)
    }

    /// Borrows the forest-level builder, e.g. to inspect its config.
    #[must_use]
    pub fn forest_builder(&self) -> &ForestBuilder<D> {
        &self.forest
    }

    /// Borrows the thresholding config as currently set.
    #[must_use]
    pub fn thresholded_config(&self) -> &ThresholdedConfig {
        &self.thresholded
    }

    /// Validates both layers and assembles the detector.
    ///
    /// # Errors
    ///
    /// Returns the first error produced, in this order, by
    /// [`ThresholdedConfig::validate`], [`ForestBuilder::build`], and
    /// [`ThresholdedForest::from_parts`].
    #[must_use = "detector output should be checked — dropping it silently usually indicates a logic bug"]
    pub fn build(self) -> RcfResult<ThresholdedForest<D>> {
        let Self { forest, thresholded } = self;
        thresholded.validate()?;
        ThresholdedForest::<D>::from_parts(forest.build()?, thresholded)
    }
}
#[cfg(test)]
// Exact float equality is deliberate: every value compared with `assert_eq!`
// is a literal or a `DEFAULT_*` constant copied through the config unchanged.
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;

    #[test]
    fn default_config_validates() {
        ThresholdedConfig::default().validate().unwrap();
    }

    #[test]
    fn default_config_fields_match_constants() {
        let c = ThresholdedConfig::default();
        assert_eq!(c.z_factor, DEFAULT_Z_FACTOR);
        assert_eq!(
            c.threshold_mode,
            ThresholdMode::ZSigma {
                z_factor: DEFAULT_Z_FACTOR
            }
        );
        assert_eq!(c.score_decay, DEFAULT_SCORE_DECAY);
        assert_eq!(c.min_observations, DEFAULT_MIN_OBSERVATIONS);
        assert_eq!(c.min_threshold, DEFAULT_MIN_THRESHOLD);
    }

    #[test]
    fn default_threshold_mode_matches_config_default() {
        assert_eq!(
            ThresholdMode::default(),
            ThresholdedConfig::default().threshold_mode
        );
    }

    /// Z-sigma config whose top-level field and mode agree.
    fn cfg(z: f64, decay: f64, min_obs: u64, min_thr: f64) -> ThresholdedConfig {
        ThresholdedConfig {
            z_factor: z,
            threshold_mode: ThresholdMode::ZSigma { z_factor: z },
            score_decay: decay,
            min_observations: min_obs,
            min_threshold: min_thr,
        }
    }

    /// Default config switched to quantile mode at level `p`.
    fn quantile_cfg(p: f64) -> ThresholdedConfig {
        ThresholdedConfig {
            threshold_mode: ThresholdMode::Quantile { p },
            ..ThresholdedConfig::default()
        }
    }

    #[test]
    fn validate_rejects_non_finite_z_factor() {
        assert!(
            cfg(f64::NAN, DEFAULT_SCORE_DECAY, 1, 0.0)
                .validate()
                .is_err()
        );
        assert!(
            cfg(f64::INFINITY, DEFAULT_SCORE_DECAY, 1, 0.0)
                .validate()
                .is_err()
        );
    }

    #[test]
    fn validate_rejects_non_positive_z_factor() {
        assert!(cfg(0.0, DEFAULT_SCORE_DECAY, 1, 0.0).validate().is_err());
        assert!(cfg(-1.0, DEFAULT_SCORE_DECAY, 1, 0.0).validate().is_err());
    }

    #[test]
    fn validate_checks_zsigma_mode_independently_of_field() {
        let c = ThresholdedConfig {
            threshold_mode: ThresholdMode::ZSigma { z_factor: -1.0 },
            ..ThresholdedConfig::default()
        };
        assert!(matches!(c.validate(), Err(RcfError::InvalidConfig(_))));
    }

    #[test]
    fn validate_accepts_quantile_in_open_unit_interval() {
        quantile_cfg(DEFAULT_QUANTILE).validate().unwrap();
        quantile_cfg(0.5).validate().unwrap();
    }

    #[test]
    fn validate_rejects_quantile_outside_open_unit_interval() {
        for p in [0.0, 1.0, -0.1, 1.5, f64::NAN, f64::INFINITY] {
            assert!(
                matches!(quantile_cfg(p).validate(), Err(RcfError::InvalidConfig(_))),
                "p = {p} should be rejected"
            );
        }
    }

    #[test]
    fn validate_checks_top_level_z_factor_in_quantile_mode() {
        let c = ThresholdedConfig {
            z_factor: 0.0,
            ..quantile_cfg(0.9)
        };
        assert!(c.validate().is_err());
    }

    #[test]
    fn validate_rejects_score_decay_outside_range() {
        assert!(cfg(DEFAULT_Z_FACTOR, 0.0, 1, 0.0).validate().is_err());
        assert!(cfg(DEFAULT_Z_FACTOR, 1.5, 1, 0.0).validate().is_err());
        assert!(cfg(DEFAULT_Z_FACTOR, f64::NAN, 1, 0.0).validate().is_err());
    }

    #[test]
    fn validate_rejects_negative_min_threshold() {
        assert!(
            cfg(DEFAULT_Z_FACTOR, DEFAULT_SCORE_DECAY, 1, -0.001)
                .validate()
                .is_err()
        );
    }

    #[test]
    fn builder_defaults_pass_validation() {
        let b = ThresholdedForestBuilder::<4>::new();
        b.thresholded_config().validate().unwrap();
        b.forest_builder().config().validate().unwrap();
    }

    #[test]
    fn builder_overrides_apply_to_both_layers() {
        let b = ThresholdedForestBuilder::<4>::new()
            .num_trees(150)
            .sample_size(128)
            .z_factor(2.5)
            .score_decay(0.05)
            .min_observations(10)
            .min_threshold(0.5)
            .initial_accept_fraction(0.125)
            .seed(7);
        assert_eq!(b.forest_builder().config().num_trees, 150);
        assert_eq!(b.forest_builder().config().sample_size, 128);
        assert_eq!(b.forest_builder().config().seed, Some(7));
        assert!(
            (b.forest_builder().config().initial_accept_fraction - 0.125).abs() < f64::EPSILON
        );
        assert_eq!(b.thresholded_config().z_factor, 2.5);
        assert_eq!(
            b.thresholded_config().threshold_mode,
            ThresholdMode::ZSigma { z_factor: 2.5 }
        );
        assert_eq!(b.thresholded_config().score_decay, 0.05);
        assert_eq!(b.thresholded_config().min_observations, 10);
        assert_eq!(b.thresholded_config().min_threshold, 0.5);
    }

    #[test]
    fn builder_quantile_threshold_sets_mode_only() {
        let b = ThresholdedForestBuilder::<4>::new().quantile_threshold(0.95);
        assert_eq!(
            b.thresholded_config().threshold_mode,
            ThresholdMode::Quantile { p: 0.95 }
        );
        assert_eq!(b.thresholded_config().z_factor, DEFAULT_Z_FACTOR);
    }

    #[test]
    fn builder_z_factor_switches_back_to_zsigma() {
        let b = ThresholdedForestBuilder::<4>::new()
            .quantile_threshold(0.95)
            .z_factor(2.0);
        assert_eq!(
            b.thresholded_config().threshold_mode,
            ThresholdMode::ZSigma { z_factor: 2.0 }
        );
        assert_eq!(b.thresholded_config().z_factor, 2.0);
    }

    #[test]
    fn builder_build_validates_forest_layer() {
        let err = ThresholdedForestBuilder::<4>::new()
            .num_trees(10)
            .build()
            .unwrap_err();
        assert!(matches!(err, RcfError::InvalidConfig(_)));
    }

    #[test]
    fn builder_build_validates_threshold_layer() {
        let err = ThresholdedForestBuilder::<4>::new()
            .z_factor(-1.0)
            .build()
            .unwrap_err();
        assert!(matches!(err, RcfError::InvalidConfig(_)));
    }

    #[test]
    fn builder_build_validates_quantile_level() {
        let err = ThresholdedForestBuilder::<4>::new()
            .quantile_threshold(1.0)
            .build()
            .unwrap_err();
        assert!(matches!(err, RcfError::InvalidConfig(_)));
    }
}