use alloc::format;
use alloc::vec::Vec;
use crate::error::{RcfError, RcfResult};
use crate::forest::random_cut_forest::RandomCutForest;
/// Smallest point dimension accepted by [`RcfConfig::validate_dimension`].
pub const MIN_DIMENSION: usize = 1;
/// Largest point dimension accepted by [`RcfConfig::validate_dimension`].
pub const MAX_DIMENSION: usize = 10_000;

/// Fewest trees that [`RcfConfig::validate`] accepts.
pub const MIN_NUM_TREES: usize = 50;
/// Most trees that [`RcfConfig::validate`] accepts.
pub const MAX_NUM_TREES: usize = 1_000;
/// Tree count a fresh [`ForestBuilder`] starts with.
pub const DEFAULT_NUM_TREES: usize = 100;

/// Smallest sample size that [`RcfConfig::validate`] accepts.
pub const MIN_SAMPLE_SIZE: usize = 1;
/// Largest sample size that [`RcfConfig::validate`] accepts.
pub const MAX_SAMPLE_SIZE: usize = 2_048;
/// Sample size a fresh [`ForestBuilder`] starts with.
pub const DEFAULT_SAMPLE_SIZE: usize = 256;

/// Numerator of the sample-size-scaled default decay: `TIME_DECAY_NUMERATOR / sample_size`
/// (see [`default_time_decay_for`]).
pub const TIME_DECAY_NUMERATOR: f64 = 0.1;
/// Default time decay at [`DEFAULT_SAMPLE_SIZE`], i.e. `0.1 / 256`.
// 256 converts to f64 exactly; the lint is silenced because it cannot prove that.
#[allow(clippy::cast_precision_loss)]
pub const DEFAULT_TIME_DECAY: f64 = TIME_DECAY_NUMERATOR / DEFAULT_SAMPLE_SIZE as f64;
/// Returns the default time decay for `sample_size`: `TIME_DECAY_NUMERATOR / sample_size`.
///
/// A `sample_size` of `0` yields `0.0` instead of dividing by zero. A zero sample size is
/// itself rejected by [`RcfConfig::validate`].
#[must_use]
// `usize -> f64` is exact for every value below 2^53, far beyond any sample size in use.
#[allow(clippy::cast_precision_loss)]
pub fn default_time_decay_for(sample_size: usize) -> f64 {
    match sample_size {
        0 => 0.0,
        size => TIME_DECAY_NUMERATOR / size as f64,
    }
}
/// Default `initial_accept_fraction`; the upper end of its accepted range `(0.0, 1.0]`.
pub const DEFAULT_INITIAL_ACCEPT_FRACTION: f64 = 1.0;
/// Configuration for a [`RandomCutForest`].
///
/// The ranges noted on each field are enforced by [`RcfConfig::validate`], which
/// `ForestBuilder::build` calls. With the `serde` feature, deserialization is routed
/// through `RcfConfigShadow` and `validate`, so an out-of-range config is rejected
/// while decoding. `#[non_exhaustive]` prevents struct-literal construction outside
/// this crate; use [`ForestBuilder`] instead.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(try_from = "RcfConfigShadow"))]
#[non_exhaustive]
pub struct RcfConfig {
    /// Number of trees; must lie in `[MIN_NUM_TREES, MAX_NUM_TREES]`.
    pub num_trees: usize,
    /// Sample size; must lie in `[MIN_SAMPLE_SIZE, MAX_SAMPLE_SIZE]`.
    pub sample_size: usize,
    /// Time decay; must be finite and in `[0.0, 1.0]`. The builder defaults it to
    /// `TIME_DECAY_NUMERATOR / sample_size`.
    pub time_decay: f64,
    /// Optional seed (presumably for the forest's RNG); `None` when unset.
    pub seed: Option<u64>,
    /// Worker-thread count; must be non-zero when set. `None` falls back to rayon's
    /// global pool.
    pub num_threads: Option<usize>,
    /// Must be finite and in `(0.0, 1.0]`; defaults to `DEFAULT_INITIAL_ACCEPT_FRACTION`.
    #[cfg_attr(feature = "serde", serde(default = "default_initial_accept_fraction"))]
    pub initial_accept_fraction: f64,
    /// Optional per-dimension scale factors. Each entry must be finite and `> 0`, and the
    /// length must equal the forest dimension (see
    /// [`RcfConfig::validate_feature_scales_dimension`]).
    #[cfg_attr(feature = "serde", serde(default))]
    pub feature_scales: Option<Vec<f64>>,
}
/// Serde field default for `initial_accept_fraction`. It lets self-describing inputs that
/// omit the field still deserialize, yielding [`DEFAULT_INITIAL_ACCEPT_FRACTION`].
#[must_use]
#[cfg(feature = "serde")]
fn default_initial_accept_fraction() -> f64 {
    DEFAULT_INITIAL_ACCEPT_FRACTION
}
/// Unvalidated wire form of [`RcfConfig`].
///
/// `RcfConfig` deserializes into this struct first and is then converted with
/// `TryFrom`, which runs `RcfConfig::validate`. `RcfConfig` serializes through its own
/// derive, so the field order here must match `RcfConfig` exactly for positional
/// formats such as postcard. `Serialize` is derived so that tests can encode raw,
/// deliberately invalid configs.
#[cfg(feature = "serde")]
#[derive(serde::Serialize, serde::Deserialize)]
struct RcfConfigShadow {
    /// Raw `RcfConfig::num_trees`.
    num_trees: usize,
    /// Raw `RcfConfig::sample_size`.
    sample_size: usize,
    /// Raw `RcfConfig::time_decay`.
    time_decay: f64,
    /// Raw `RcfConfig::seed`.
    seed: Option<u64>,
    /// Raw `RcfConfig::num_threads`.
    num_threads: Option<usize>,
    /// Raw `RcfConfig::initial_accept_fraction`; defaulted when absent.
    #[serde(default = "default_initial_accept_fraction")]
    initial_accept_fraction: f64,
    /// Raw `RcfConfig::feature_scales`; `None` when absent.
    #[serde(default)]
    feature_scales: Option<Vec<f64>>,
}
/// Validating conversion behind `#[serde(try_from = "RcfConfigShadow")]` on [`RcfConfig`].
///
/// Every field is copied as-is, and the result is returned only if
/// [`RcfConfig::validate`] accepts it. Otherwise the validation error is surfaced to serde.
#[cfg(feature = "serde")]
impl TryFrom<RcfConfigShadow> for RcfConfig {
    type Error = RcfError;

    fn try_from(raw: RcfConfigShadow) -> Result<Self, Self::Error> {
        // Exhaustive destructuring: adding a shadow field without mapping it fails to compile.
        let RcfConfigShadow {
            num_trees,
            sample_size,
            time_decay,
            seed,
            num_threads,
            initial_accept_fraction,
            feature_scales,
        } = raw;
        let config = Self {
            num_trees,
            sample_size,
            time_decay,
            seed,
            num_threads,
            initial_accept_fraction,
            feature_scales,
        };
        config.validate().map(|()| config)
    }
}
impl RcfConfig {
    /// Checks every field against its accepted range.
    ///
    /// Fields are checked in declaration order and the first violation is reported.
    ///
    /// # Errors
    ///
    /// Returns [`RcfError::InvalidConfig`] when:
    /// - `num_trees` is outside `[MIN_NUM_TREES, MAX_NUM_TREES]`;
    /// - `sample_size` is outside `[MIN_SAMPLE_SIZE, MAX_SAMPLE_SIZE]`;
    /// - `time_decay` is not a finite value in `[0.0, 1.0]`;
    /// - `num_threads` is `Some(0)`;
    /// - `initial_accept_fraction` is not a finite value in `(0.0, 1.0]`;
    /// - any `feature_scales` entry is non-finite or `<= 0.0` (the first such index is named).
    pub fn validate(&self) -> RcfResult<()> {
        let trees_ok = (MIN_NUM_TREES..=MAX_NUM_TREES).contains(&self.num_trees);
        if !trees_ok {
            return Err(RcfError::InvalidConfig(
                format!(
                    "num_trees {} out of [{}, {}]",
                    self.num_trees, MIN_NUM_TREES, MAX_NUM_TREES
                )
                .into(),
            ));
        }

        let samples_ok = (MIN_SAMPLE_SIZE..=MAX_SAMPLE_SIZE).contains(&self.sample_size);
        if !samples_ok {
            return Err(RcfError::InvalidConfig(
                format!(
                    "sample_size {} out of [{}, {}]",
                    self.sample_size, MIN_SAMPLE_SIZE, MAX_SAMPLE_SIZE
                )
                .into(),
            ));
        }

        // `contains` alone already rejects NaN and ±inf; the finiteness test states intent.
        let decay_ok = self.time_decay.is_finite() && (0.0..=1.0).contains(&self.time_decay);
        if !decay_ok {
            return Err(RcfError::InvalidConfig(
                format!("time_decay {} out of [0.0, 1.0]", self.time_decay).into(),
            ));
        }

        if self.num_threads == Some(0) {
            return Err(RcfError::InvalidConfig(
                "num_threads must be > 0 when set; use None to fall back to rayon's global pool"
                    .into(),
            ));
        }

        // Lower bound is exclusive, so this cannot be expressed as a `RangeInclusive`.
        let fraction = self.initial_accept_fraction;
        let fraction_ok = fraction.is_finite() && fraction > 0.0 && fraction <= 1.0;
        if !fraction_ok {
            return Err(RcfError::InvalidConfig(
                format!(
                    "initial_accept_fraction {} out of (0.0, 1.0]",
                    self.initial_accept_fraction
                )
                .into(),
            ));
        }

        let scales = self.feature_scales.as_deref().unwrap_or(&[]);
        match scales
            .iter()
            .enumerate()
            .find(|&(_, &scale)| !scale.is_finite() || scale <= 0.0)
        {
            Some((i, s)) => Err(RcfError::InvalidConfig(
                format!("feature_scales[{i}] must be finite and > 0, got {s}").into(),
            )),
            None => Ok(()),
        }
    }

    /// Checks that `feature_scales`, when present, has exactly `d` entries.
    ///
    /// # Errors
    ///
    /// Returns [`RcfError::DimensionMismatch`] with `expected: d` and `got` set to the
    /// number of scales when the lengths differ. `None` scales always pass.
    pub fn validate_feature_scales_dimension(&self, d: usize) -> RcfResult<()> {
        match self.feature_scales.as_ref().map(Vec::len) {
            Some(got) if got != d => Err(RcfError::DimensionMismatch { expected: d, got }),
            _ => Ok(()),
        }
    }

    /// Checks that a point dimension lies in `[MIN_DIMENSION, MAX_DIMENSION]`.
    ///
    /// # Errors
    ///
    /// Returns [`RcfError::InvalidConfig`] when `dimension` is out of range.
    pub fn validate_dimension(dimension: usize) -> RcfResult<()> {
        let allowed = MIN_DIMENSION..=MAX_DIMENSION;
        if allowed.contains(&dimension) {
            Ok(())
        } else {
            Err(RcfError::InvalidConfig(
                format!("dimension {dimension} out of [{MIN_DIMENSION}, {MAX_DIMENSION}]").into(),
            ))
        }
    }
}
/// Fluent builder for a [`RandomCutForest`] over `D`-dimensional points.
///
/// Setters only record values; every range check is deferred to `build`.
#[derive(Clone, Debug)]
pub struct ForestBuilder<const D: usize> {
    /// Configuration accumulated by the setters.
    config: RcfConfig,
    /// Set once `time_decay` is called explicitly. After that, `sample_size` no longer
    /// re-derives the default decay.
    time_decay_explicit: bool,
}
impl<const D: usize> Default for ForestBuilder<D> {
fn default() -> Self {
Self::new()
}
}
impl<const D: usize> ForestBuilder<D> {
    /// Creates a builder pre-populated with the crate defaults.
    ///
    /// The builder starts with:
    /// - [`DEFAULT_NUM_TREES`] trees and a sample size of [`DEFAULT_SAMPLE_SIZE`];
    /// - `time_decay = default_time_decay_for(DEFAULT_SAMPLE_SIZE)`;
    /// - [`DEFAULT_INITIAL_ACCEPT_FRACTION`];
    /// - no seed, thread count, or feature scales.
    #[must_use]
    pub fn new() -> Self {
        Self {
            config: RcfConfig {
                num_trees: DEFAULT_NUM_TREES,
                sample_size: DEFAULT_SAMPLE_SIZE,
                time_decay: default_time_decay_for(DEFAULT_SAMPLE_SIZE),
                seed: None,
                num_threads: None,
                initial_accept_fraction: DEFAULT_INITIAL_ACCEPT_FRACTION,
                feature_scales: None,
            },
            // The decay above is derived, not user-chosen, so `sample_size` may rescale it.
            time_decay_explicit: false,
        }
    }

    /// Sets the number of trees. The range is checked by [`ForestBuilder::build`].
    #[must_use]
    pub fn num_trees(mut self, n: usize) -> Self {
        self.config.num_trees = n;
        self
    }

    /// Sets the sample size. The range is checked by [`ForestBuilder::build`].
    ///
    /// Unless [`ForestBuilder::time_decay`] was already called, the time decay is re-derived
    /// as `default_time_decay_for(s)`, so the default tracks the sample size.
    #[must_use]
    pub fn sample_size(mut self, s: usize) -> Self {
        self.config.sample_size = s;
        if !self.time_decay_explicit {
            self.config.time_decay = default_time_decay_for(s);
        }
        self
    }

    /// Sets an explicit time decay; it must be finite and in `[0.0, 1.0]` at build time.
    ///
    /// Later [`ForestBuilder::sample_size`] calls no longer overwrite this value.
    #[must_use]
    pub fn time_decay(mut self, d: f64) -> Self {
        self.config.time_decay = d;
        self.time_decay_explicit = true;
        self
    }

    /// Sets the seed to `Some(seed)`. Presumably this makes forest construction
    /// reproducible; confirm against `RandomCutForest`.
    #[must_use]
    pub fn seed(mut self, seed: u64) -> Self {
        self.config.seed = Some(seed);
        self
    }

    /// Sets the worker-thread count. `0` is rejected by [`ForestBuilder::build`]; leave it
    /// unset to fall back to rayon's global pool.
    #[must_use]
    pub fn num_threads(mut self, n: usize) -> Self {
        self.config.num_threads = Some(n);
        self
    }

    /// Sets `initial_accept_fraction`; it must be finite and in `(0.0, 1.0]` at build time.
    #[must_use]
    pub fn initial_accept_fraction(mut self, f: f64) -> Self {
        self.config.initial_accept_fraction = f;
        self
    }

    /// Sets one scale factor per dimension. Taking `[f64; D]` guarantees the length matches
    /// `D`; each entry must still be finite and `> 0` at build time.
    #[must_use]
    pub fn feature_scales(mut self, scales: [f64; D]) -> Self {
        self.config.feature_scales = Some(scales.to_vec());
        self
    }

    /// Removes any previously set feature scales.
    #[must_use]
    pub fn clear_feature_scales(mut self) -> Self {
        self.config.feature_scales = None;
        self
    }

    /// Borrows the configuration accumulated so far. It is not yet validated.
    #[must_use]
    pub fn config(&self) -> &RcfConfig {
        &self.config
    }

    /// Returns the compile-time point dimension `D`.
    #[must_use]
    pub const fn dimension(&self) -> usize {
        D
    }

    /// Validates the accumulated configuration and constructs the forest.
    ///
    /// Checks run in this order:
    /// 1. the dimension `D`;
    /// 2. every config field ([`RcfConfig::validate`]);
    /// 3. the length of `feature_scales` against `D`.
    ///
    /// # Errors
    ///
    /// - [`RcfError::InvalidConfig`] if `D` is outside `[MIN_DIMENSION, MAX_DIMENSION]` or any
    ///   config field is out of range.
    /// - [`RcfError::DimensionMismatch`] if `feature_scales` does not have `D` entries. This
    ///   is unreachable through [`ForestBuilder::feature_scales`], which takes exactly `D`
    ///   values.
    /// - Any error returned by `RandomCutForest::from_config`, propagated unchanged.
    #[must_use = "the built forest (or the validation error) is returned, not stored in the builder"]
    pub fn build(self) -> RcfResult<RandomCutForest<D>> {
        RcfConfig::validate_dimension(D)?;
        self.config.validate()?;
        self.config.validate_feature_scales_dimension(D)?;
        RandomCutForest::<D>::from_config(self.config)
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for `RcfConfig` validation and the `ForestBuilder` defaults and setters.
    // With `serde` + `postcard` enabled, they also check that invalid configs are rejected
    // while deserializing.
    use super::*;

    /// Builds a config with the given tree count, sample size and time decay; every other
    /// field is at its default.
    fn cfg(n: usize, s: usize, td: f64) -> RcfConfig {
        RcfConfig {
            num_trees: n,
            sample_size: s,
            time_decay: td,
            seed: None,
            num_threads: None,
            initial_accept_fraction: DEFAULT_INITIAL_ACCEPT_FRACTION,
            feature_scales: None,
        }
    }

    // --- RcfConfig::validate / validate_dimension range checks ---

    #[test]
    fn validate_default_passes() {
        let c = cfg(DEFAULT_NUM_TREES, DEFAULT_SAMPLE_SIZE, DEFAULT_TIME_DECAY);
        c.validate().unwrap();
    }
    #[test]
    fn validate_dimension_rejects_zero() {
        assert!(matches!(
            RcfConfig::validate_dimension(0).unwrap_err(),
            RcfError::InvalidConfig(_)
        ));
    }
    #[test]
    fn validate_dimension_rejects_above_max() {
        assert!(RcfConfig::validate_dimension(10_001).is_err());
    }
    #[test]
    fn validate_dimension_accepts_at_max() {
        RcfConfig::validate_dimension(10_000).unwrap();
    }
    #[test]
    fn validate_rejects_num_trees_below_min() {
        assert!(cfg(49, 256, 0.0).validate().is_err());
    }
    #[test]
    fn validate_accepts_num_trees_at_bounds() {
        cfg(50, 256, 0.0).validate().unwrap();
        cfg(1000, 256, 0.0).validate().unwrap();
    }
    #[test]
    fn validate_rejects_num_trees_above_max() {
        assert!(cfg(1001, 256, 0.0).validate().is_err());
    }
    #[test]
    fn validate_rejects_sample_size_zero() {
        assert!(cfg(100, 0, 0.0).validate().is_err());
    }
    #[test]
    fn validate_accepts_sample_size_at_bounds() {
        cfg(100, 1, 0.0).validate().unwrap();
        cfg(100, 2048, 0.0).validate().unwrap();
    }
    #[test]
    fn validate_rejects_sample_size_above_max() {
        assert!(cfg(100, 2049, 0.0).validate().is_err());
    }
    #[test]
    fn validate_rejects_negative_time_decay() {
        assert!(cfg(100, 256, -0.01).validate().is_err());
    }
    #[test]
    fn validate_rejects_time_decay_above_one() {
        assert!(cfg(100, 256, 1.01).validate().is_err());
    }
    #[test]
    fn validate_rejects_non_finite_time_decay() {
        assert!(cfg(100, 256, f64::NAN).validate().is_err());
        assert!(cfg(100, 256, f64::INFINITY).validate().is_err());
    }

    // --- num_threads ---

    #[test]
    fn validate_rejects_zero_num_threads() {
        let mut c = cfg(100, 256, 0.0);
        c.num_threads = Some(0);
        assert!(matches!(
            c.validate().unwrap_err(),
            RcfError::InvalidConfig(_)
        ));
    }
    #[test]
    fn validate_accepts_some_num_threads() {
        let mut c = cfg(100, 256, 0.0);
        c.num_threads = Some(4);
        c.validate().unwrap();
    }
    #[test]
    fn validate_accepts_default_num_threads_none() {
        let c = cfg(100, 256, 0.0);
        assert_eq!(c.num_threads, None);
        c.validate().unwrap();
    }
    #[test]
    fn builder_num_threads_sets_field() {
        let b = ForestBuilder::<4>::new().num_threads(8);
        assert_eq!(b.config().num_threads, Some(8));
    }

    // --- initial_accept_fraction ---

    #[test]
    fn validate_accepts_initial_accept_fraction_at_bounds() {
        let mut c = cfg(100, 256, 0.0);
        c.initial_accept_fraction = 0.001;
        c.validate().unwrap();
        c.initial_accept_fraction = 1.0;
        c.validate().unwrap();
    }
    #[test]
    fn validate_rejects_initial_accept_fraction_out_of_range() {
        let mut c = cfg(100, 256, 0.0);
        c.initial_accept_fraction = 0.0;
        assert!(c.validate().is_err());
        c.initial_accept_fraction = -0.1;
        assert!(c.validate().is_err());
        c.initial_accept_fraction = 1.01;
        assert!(c.validate().is_err());
    }
    #[test]
    fn validate_rejects_non_finite_initial_accept_fraction() {
        let mut c = cfg(100, 256, 0.0);
        c.initial_accept_fraction = f64::NAN;
        assert!(c.validate().is_err());
        c.initial_accept_fraction = f64::INFINITY;
        assert!(c.validate().is_err());
    }
    #[test]
    fn builder_initial_accept_fraction_sets_field() {
        let b = ForestBuilder::<4>::new().initial_accept_fraction(0.125);
        assert!((b.config().initial_accept_fraction - 0.125).abs() < f64::EPSILON);
    }
    #[test]
    fn builder_defaults_initial_accept_fraction_to_one() {
        let b = ForestBuilder::<4>::new();
        assert!((b.config().initial_accept_fraction - 1.0).abs() < f64::EPSILON);
    }

    // --- Builder defaults and time_decay / sample_size interaction ---

    #[test]
    fn builder_defaults_match_aws() {
        let b = ForestBuilder::<8>::new();
        assert_eq!(b.dimension(), 8);
        assert_eq!(b.config().num_trees, 100);
        assert_eq!(b.config().sample_size, 256);
        assert!(
            (b.config().time_decay - TIME_DECAY_NUMERATOR / 256.0).abs() < f64::EPSILON,
            "default time_decay should resolve to 0.1 / sample_size, got {}",
            b.config().time_decay
        );
        assert_eq!(b.config().seed, None);
    }
    #[test]
    fn builder_sample_size_override_rescales_default_time_decay() {
        let b = ForestBuilder::<4>::new().sample_size(128);
        assert!(
            (b.config().time_decay - TIME_DECAY_NUMERATOR / 128.0).abs() < f64::EPSILON,
            "sample_size(128) should rescale default to 0.1 / 128, got {}",
            b.config().time_decay,
        );
    }
    #[test]
    fn builder_explicit_time_decay_sticks_across_sample_size_override() {
        let b = ForestBuilder::<4>::new().time_decay(0.05).sample_size(128);
        assert!((b.config().time_decay - 0.05).abs() < f64::EPSILON);
    }
    #[test]
    fn builder_sample_size_override_before_time_decay() {
        let b = ForestBuilder::<4>::new().sample_size(128).time_decay(0.05);
        assert!((b.config().time_decay - 0.05).abs() < f64::EPSILON);
    }
    #[test]
    fn builder_time_decay_zero_still_accepted() {
        let b = ForestBuilder::<4>::new().time_decay(0.0);
        assert!(b.config().time_decay.abs() < f64::EPSILON);
        b.build().expect("time_decay=0 must still build");
    }
    #[test]
    fn default_time_decay_for_zero_sample_size_is_zero() {
        assert!(default_time_decay_for(0).abs() < f64::EPSILON);
    }
    #[test]
    fn default_time_decay_for_default_sample_size_matches_constant() {
        assert!(
            (default_time_decay_for(DEFAULT_SAMPLE_SIZE) - DEFAULT_TIME_DECAY).abs() < f64::EPSILON,
        );
    }
    #[test]
    fn builder_overrides_apply() {
        let b = ForestBuilder::<4>::new()
            .num_trees(50)
            .sample_size(64)
            .time_decay(0.05)
            .seed(42);
        assert_eq!(b.config().num_trees, 50);
        assert_eq!(b.config().sample_size, 64);
        assert!((b.config().time_decay - 0.05).abs() < f64::EPSILON);
        assert_eq!(b.config().seed, Some(42));
    }
    #[test]
    fn builder_build_validates() {
        let err = ForestBuilder::<4>::new().num_trees(10).build().unwrap_err();
        assert!(matches!(err, RcfError::InvalidConfig(_)));
    }

    // --- feature_scales (previously covered only under serde + postcard) ---

    #[test]
    fn validate_rejects_bad_feature_scales() {
        let mut c = cfg(100, 256, 0.0);
        for bad in [0.0, -0.5, f64::NAN, f64::INFINITY] {
            c.feature_scales = Some(alloc::vec![1.0, bad]);
            assert!(matches!(
                c.validate().unwrap_err(),
                RcfError::InvalidConfig(_)
            ));
        }
    }
    #[test]
    fn validate_accepts_positive_finite_feature_scales() {
        let mut c = cfg(100, 256, 0.0);
        c.feature_scales = Some(alloc::vec![0.5, 1.0, 2.0]);
        c.validate().unwrap();
    }
    #[test]
    fn validate_feature_scales_dimension_checks_length() {
        let mut c = cfg(100, 256, 0.0);
        // No scales means there is nothing to compare against.
        c.validate_feature_scales_dimension(3).unwrap();
        c.feature_scales = Some(alloc::vec![1.0, 1.0]);
        c.validate_feature_scales_dimension(2).unwrap();
        assert!(matches!(
            c.validate_feature_scales_dimension(3).unwrap_err(),
            RcfError::DimensionMismatch { expected: 3, got: 2 }
        ));
    }
    #[test]
    fn builder_feature_scales_set_and_clear() {
        let b = ForestBuilder::<3>::new().feature_scales([1.0, 2.0, 3.0]);
        let scales = b
            .config()
            .feature_scales
            .as_deref()
            .expect("feature_scales was just set");
        assert_eq!(scales.len(), 3);
        assert!(scales
            .iter()
            .zip([1.0, 2.0, 3.0])
            .all(|(got, want)| (got - want).abs() < f64::EPSILON));
        let b = b.clear_feature_scales();
        assert_eq!(b.config().feature_scales, None);
    }
    #[test]
    fn builder_build_rejects_invalid_feature_scale() {
        let err = ForestBuilder::<2>::new()
            .feature_scales([1.0, 0.0])
            .build()
            .unwrap_err();
        assert!(matches!(err, RcfError::InvalidConfig(_)));
    }

    // --- Deserialization runs `validate` (serde + postcard only) ---

    #[cfg(all(feature = "serde", feature = "postcard"))]
    #[test]
    fn deserialize_rejects_out_of_range_num_trees() {
        let bad = RcfConfigShadow {
            num_trees: MAX_NUM_TREES + 1,
            sample_size: 256,
            time_decay: 0.0,
            seed: None,
            num_threads: None,
            initial_accept_fraction: 1.0,
            feature_scales: None,
        };
        let bytes = postcard::to_allocvec(&bad).unwrap();
        let back: Result<RcfConfig, _> = postcard::from_bytes(&bytes);
        assert!(back.is_err());
    }
    #[cfg(all(feature = "serde", feature = "postcard"))]
    #[test]
    fn deserialize_rejects_nan_time_decay() {
        let bad = RcfConfigShadow {
            num_trees: 100,
            sample_size: 256,
            time_decay: f64::NAN,
            seed: None,
            num_threads: None,
            initial_accept_fraction: 1.0,
            feature_scales: None,
        };
        let bytes = postcard::to_allocvec(&bad).unwrap();
        let back: Result<RcfConfig, _> = postcard::from_bytes(&bytes);
        assert!(back.is_err());
    }
    #[cfg(all(feature = "serde", feature = "postcard"))]
    #[test]
    fn deserialize_rejects_negative_feature_scale() {
        let bad = RcfConfigShadow {
            num_trees: 100,
            sample_size: 256,
            time_decay: 0.0,
            seed: None,
            num_threads: None,
            initial_accept_fraction: 1.0,
            feature_scales: Some(alloc::vec![1.0, -0.5]),
        };
        let bytes = postcard::to_allocvec(&bad).unwrap();
        let back: Result<RcfConfig, _> = postcard::from_bytes(&bytes);
        assert!(back.is_err());
    }
}