#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::error::{RcfError, Result};
/// Builder-style configuration consumed by the `crate::rcf` module.
///
/// Create one with [`RcfConfig::new`], adjust it with the `with_*` methods and
/// read it back through the accessor of the same name as each field. Fields
/// are private; nothing is checked until the crate-internal `validate` runs.
///
/// With the `serde` feature enabled the struct serializes under its field
/// names. On deserialization only `input_dim` is required; every other field
/// falls back to the default documented on it.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct RcfConfig {
/// Size of one input; multiplied by `shingle_size` to give `dim()`.
/// Must be non-zero. The only field without a serde default.
input_dim: usize,
/// Multiplier applied to `input_dim` to obtain `dim()`. Must be non-zero.
/// Defaults to 1.
#[cfg_attr(feature = "serde", serde(default = "default_shingle_size"))]
shingle_size: usize,
/// Must be non-zero. Defaults to 256. Also drives the derived values
/// returned by `effective_time_decay` and `effective_output_after`.
// NOTE(review): presumably the per-tree sample capacity — confirm against the tree code.
#[cfg_attr(feature = "serde", serde(default = "default_capacity"))]
capacity: usize,
/// Number of trees. Must be non-zero. Defaults to 50.
#[cfg_attr(feature = "serde", serde(default = "default_num_trees"))]
num_trees: usize,
/// Explicit time decay; must be finite and non-negative. The default `0.0`
/// is a sentinel: `effective_time_decay` then returns `0.1 / capacity`.
#[cfg_attr(feature = "serde", serde(default))]
time_decay: f64,
/// Explicit output-after value. The default `0` is a sentinel:
/// `effective_output_after` then returns `1 + capacity / 4`.
#[cfg_attr(feature = "serde", serde(default))]
output_after: usize,
/// Defaults to `true`. Stored and exposed here, interpreted elsewhere.
#[cfg_attr(feature = "serde", serde(default = "default_internal_shingling"))]
internal_shingling: bool,
/// Must be finite and within `[0.0, 1.0]`. Defaults to `0.125`.
#[cfg_attr(feature = "serde", serde(default = "default_initial_accept_fraction"))]
initial_accept_fraction: f64,
}
/// Default for `shingle_size`.
///
/// Used by `RcfConfig::new` and, with the `serde` feature, as the fallback
/// when the field is missing from the input, so both paths agree.
fn default_shingle_size() -> usize {
    const DEFAULT_SHINGLE_SIZE: usize = 1;
    DEFAULT_SHINGLE_SIZE
}
/// Default for `capacity`.
///
/// Used by `RcfConfig::new` and, with the `serde` feature, as the fallback
/// when the field is missing from the input, so both paths agree.
fn default_capacity() -> usize {
    const DEFAULT_CAPACITY: usize = 256;
    DEFAULT_CAPACITY
}
/// Default for `num_trees`.
///
/// Used by `RcfConfig::new` and, with the `serde` feature, as the fallback
/// when the field is missing from the input, so both paths agree.
fn default_num_trees() -> usize {
    const DEFAULT_NUM_TREES: usize = 50;
    DEFAULT_NUM_TREES
}
/// Default for `internal_shingling`.
///
/// Used by `RcfConfig::new` and, with the `serde` feature, as the fallback
/// when the field is missing from the input, so both paths agree.
fn default_internal_shingling() -> bool {
    const DEFAULT_INTERNAL_SHINGLING: bool = true;
    DEFAULT_INTERNAL_SHINGLING
}
/// Default for `initial_accept_fraction`.
///
/// Used by `RcfConfig::new` and, with the `serde` feature, as the fallback
/// when the field is missing from the input, so both paths agree.
fn default_initial_accept_fraction() -> f64 {
    const DEFAULT_INITIAL_ACCEPT_FRACTION: f64 = 0.125;
    DEFAULT_INITIAL_ACCEPT_FRACTION
}
/// Computes the tree-arena size `2 * capacity + 4` with overflow checking.
///
/// Returns `None` when either the doubling or the addition would exceed
/// `usize::MAX`; otherwise returns `Some` of the exact result.
pub(in crate::rcf) fn checked_tree_arena_capacity(capacity: usize) -> Option<usize> {
    let doubled = capacity.checked_mul(2)?;
    doubled.checked_add(4)
}
/// Computes the point-store size with overflow checking.
///
/// The result is the larger of `capacity * num_trees + 1` and `2 * capacity`.
/// Returns `None` if computing either candidate would overflow `usize`.
pub(in crate::rcf) fn checked_point_store_capacity(
    capacity: usize,
    num_trees: usize,
) -> Option<usize> {
    // Both candidates are computed (and overflow-checked) before comparing,
    // so an overflow in either one yields `None`.
    let across_trees = capacity.checked_mul(num_trees)?.checked_add(1)?;
    let floor = capacity.checked_mul(2)?;
    if floor > across_trees {
        Some(floor)
    } else {
        Some(across_trees)
    }
}
impl RcfConfig {
    /// Creates a configuration for the given `input_dim` with every other
    /// setting at its default: `shingle_size = 1`, `capacity = 256`,
    /// `num_trees = 50`, `time_decay = 0.0`, `output_after = 0`,
    /// `internal_shingling = true`, `initial_accept_fraction = 0.125`.
    ///
    /// No validation is performed here; an `input_dim` of zero is only
    /// rejected later by `validate`.
    pub fn new(input_dim: usize) -> Self {
        Self {
            input_dim,
            shingle_size: default_shingle_size(),
            capacity: default_capacity(),
            num_trees: default_num_trees(),
            time_decay: 0.0,
            output_after: 0,
            internal_shingling: default_internal_shingling(),
            initial_accept_fraction: default_initial_accept_fraction(),
        }
    }

    /// Returns the configuration with `shingle_size` replaced by `v`.
    #[must_use = "builder methods consume `self` and return the updated config"]
    pub fn with_shingle_size(mut self, v: usize) -> Self {
        self.shingle_size = v;
        self
    }

    /// Returns the configuration with `capacity` replaced by `v`.
    #[must_use = "builder methods consume `self` and return the updated config"]
    pub fn with_capacity(mut self, v: usize) -> Self {
        self.capacity = v;
        self
    }

    /// Returns the configuration with `num_trees` replaced by `v`.
    #[must_use = "builder methods consume `self` and return the updated config"]
    pub fn with_num_trees(mut self, v: usize) -> Self {
        self.num_trees = v;
        self
    }

    /// Returns the configuration with `time_decay` replaced by `v`.
    ///
    /// Passing `0.0` selects the derived value; see
    /// [`effective_time_decay`](Self::effective_time_decay).
    #[must_use = "builder methods consume `self` and return the updated config"]
    pub fn with_time_decay(mut self, v: f64) -> Self {
        self.time_decay = v;
        self
    }

    /// Returns the configuration with `output_after` replaced by `v`.
    ///
    /// Passing `0` selects the derived value; see
    /// [`effective_output_after`](Self::effective_output_after).
    #[must_use = "builder methods consume `self` and return the updated config"]
    pub fn with_output_after(mut self, v: usize) -> Self {
        self.output_after = v;
        self
    }

    /// Returns the configuration with `internal_shingling` replaced by `v`.
    #[must_use = "builder methods consume `self` and return the updated config"]
    pub fn with_internal_shingling(mut self, v: bool) -> Self {
        self.internal_shingling = v;
        self
    }

    /// Returns the configuration with `initial_accept_fraction` replaced by `v`.
    #[must_use = "builder methods consume `self` and return the updated config"]
    pub fn with_initial_accept_fraction(mut self, v: f64) -> Self {
        self.initial_accept_fraction = v;
        self
    }

    /// Returns the configured `input_dim`.
    pub fn input_dim(&self) -> usize {
        self.input_dim
    }

    /// Returns the configured `shingle_size`.
    pub fn shingle_size(&self) -> usize {
        self.shingle_size
    }

    /// Returns the configured `capacity`.
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Returns the configured `num_trees`.
    pub fn num_trees(&self) -> usize {
        self.num_trees
    }

    /// Returns the raw `time_decay` as stored, where `0.0` means "derive it".
    ///
    /// Use [`effective_time_decay`](Self::effective_time_decay) for the
    /// resolved value.
    pub fn time_decay(&self) -> f64 {
        self.time_decay
    }

    /// Returns the raw `output_after` as stored, where `0` means "derive it".
    ///
    /// Use [`effective_output_after`](Self::effective_output_after) for the
    /// resolved value.
    pub fn output_after(&self) -> usize {
        self.output_after
    }

    /// Returns the configured `internal_shingling` flag.
    pub fn internal_shingling(&self) -> bool {
        self.internal_shingling
    }

    /// Returns the configured `initial_accept_fraction`.
    pub fn initial_accept_fraction(&self) -> f64 {
        self.initial_accept_fraction
    }

    /// Resolves the time decay: the stored value when it is non-zero,
    /// otherwise `0.1 / capacity`.
    ///
    /// On a configuration that has not been validated, a `capacity` of zero
    /// makes the derived value infinite.
    pub fn effective_time_decay(&self) -> f64 {
        // Exact comparison is intentional: 0.0 is the "unset" sentinel.
        if self.time_decay == 0.0 {
            0.1 / self.capacity as f64
        } else {
            self.time_decay
        }
    }

    /// Resolves the output-after value: the stored value when it is non-zero,
    /// otherwise `1 + capacity / 4`.
    pub fn effective_output_after(&self) -> usize {
        if self.output_after == 0 {
            // capacity / 4 is at most usize::MAX / 4, so adding 1 cannot overflow.
            1 + self.capacity / 4
        } else {
            self.output_after
        }
    }

    /// Returns `input_dim * shingle_size`.
    ///
    /// The product is unchecked; `validate` rejects configurations for which
    /// it would overflow `usize`.
    pub fn dim(&self) -> usize {
        self.input_dim * self.shingle_size
    }

    /// Returns the point-store size for this configuration, as computed by
    /// `checked_point_store_capacity`.
    ///
    /// # Panics
    ///
    /// Panics if that computation overflows, which `validate` rules out.
    pub(in crate::rcf) fn point_store_capacity(&self) -> usize {
        checked_point_store_capacity(self.capacity, self.num_trees)
            .expect("validated config must have a valid point-store capacity")
    }

    /// Checks every invariant of the configuration.
    ///
    /// Checks run in a fixed order and the first failure is returned.
    ///
    /// # Errors
    ///
    /// Returns `RcfError::InvalidArgument` when:
    /// - `input_dim`, `shingle_size`, `capacity` or `num_trees` is zero;
    /// - `input_dim * shingle_size`, `2 * capacity + 4` or
    ///   `capacity * num_trees + 1` overflows `usize`;
    /// - `time_decay` is not finite or is negative;
    /// - `initial_accept_fraction` is not finite or lies outside `[0.0, 1.0]`.
    pub(in crate::rcf) fn validate(&self) -> Result<()> {
        if self.input_dim == 0 {
            return Err(RcfError::InvalidArgument("input_dim must be > 0".into()));
        }
        if self.shingle_size == 0 {
            return Err(RcfError::InvalidArgument("shingle_size must be > 0".into()));
        }
        if self.input_dim.checked_mul(self.shingle_size).is_none() {
            return Err(RcfError::InvalidArgument(
                "input_dim * shingle_size overflows usize".into(),
            ));
        }
        if self.capacity == 0 {
            return Err(RcfError::InvalidArgument("capacity must be > 0".into()));
        }
        if self.num_trees == 0 {
            return Err(RcfError::InvalidArgument("num_trees must be > 0".into()));
        }
        // Checked before the point-store size: once 2 * capacity + 4 fits,
        // the `2 * capacity` term inside the point-store formula fits too,
        // so the next check can only fail on `capacity * num_trees + 1`.
        if checked_tree_arena_capacity(self.capacity).is_none() {
            return Err(RcfError::InvalidArgument(
                "2 * capacity + 4 overflows usize".into(),
            ));
        }
        if checked_point_store_capacity(self.capacity, self.num_trees).is_none() {
            return Err(RcfError::InvalidArgument(
                "capacity * num_trees + 1 overflows usize".into(),
            ));
        }
        if !self.time_decay.is_finite() || self.time_decay < 0.0 {
            return Err(RcfError::InvalidArgument(
                "time_decay must be finite and >= 0.0".into(),
            ));
        }
        if !self.initial_accept_fraction.is_finite()
            || !(0.0..=1.0).contains(&self.initial_accept_fraction)
        {
            return Err(RcfError::InvalidArgument(
                "initial_accept_fraction must be finite and in [0.0, 1.0]".into(),
            ));
        }
        Ok(())
    }
}
// Unit tests for `RcfConfig` and the capacity helpers. Compiled only for test
// builds with the `std` feature enabled.
#[cfg(all(test, feature = "std"))]
mod tests {
use super::*;
use crate::error::RcfError;
use proptest::prelude::*;
use rstest::rstest;
// Property tests over small, overflow-free ranges of the builder inputs.
proptest! {
// `dim()` is the plain product of the two size parameters.
#[test]
fn dim_equals_input_times_shingle(
input_dim in 1usize..=32,
shingle_size in 1usize..=16,
) {
let cfg = RcfConfig::new(input_dim).with_shingle_size(shingle_size);
prop_assert_eq!(cfg.dim(), input_dim * shingle_size);
}
// With `time_decay` left at its 0.0 default the derived value is 0.1 / capacity.
#[test]
fn effective_time_decay_positive(capacity in 1usize..=1000) {
let cfg = RcfConfig::new(1).with_capacity(capacity);
prop_assert!(cfg.effective_time_decay() > 0.0);
}
// With `output_after == 0` the derived value is 1 + capacity / 4, hence >= 1.
#[test]
fn effective_output_after_positive(capacity in 1usize..=1000) {
let cfg = RcfConfig::new(1).with_capacity(capacity).with_output_after(0);
prop_assert!(cfg.effective_output_after() >= 1);
}
// A builder call is observable through the matching getter.
#[test]
fn setters_reflect_values(n in 1usize..=100) {
let cfg = RcfConfig::new(1).with_num_trees(n);
prop_assert_eq!(cfg.num_trees(), n);
}
}
// The all-defaults configuration must pass validation.
#[test]
fn validate_accepts_default_config() {
RcfConfig::new(1).validate().unwrap();
}
// Pins the helper formulas:
//   tree arena:  2 * capacity + 4                          (3 -> 10)
//   point store: max(capacity * num_trees + 1, 2 * capacity)
//                (3, 4) -> 13, where the per-tree product wins
//                (3, 1) -> 6,  where the 2 * capacity floor wins
// and checks that overflow yields `None`.
#[rstest]
#[case::tree_arena_3(checked_tree_arena_capacity(3), Some(10))]
#[case::tree_arena_max(checked_tree_arena_capacity(usize::MAX), None)]
#[case::point_store_3_4(checked_point_store_capacity(3, 4), Some(13))]
#[case::point_store_3_1(checked_point_store_capacity(3, 1), Some(6))]
#[case::point_store_max(checked_point_store_capacity(9, usize::MAX), None)]
fn internal_capacity_helpers_match_construction_formulas(
#[case] actual: Option<usize>,
#[case] expected: Option<usize>,
) {
assert_eq!(actual, expected);
}
// Every builder method writes the field that the same-named getter reads.
#[test]
fn getters_reflect_all_config_fields() {
let config = RcfConfig::new(3)
.with_shingle_size(4)
.with_capacity(128)
.with_num_trees(17)
.with_time_decay(0.25)
.with_output_after(9)
.with_internal_shingling(false)
.with_initial_accept_fraction(0.5);
assert_eq!(config.input_dim(), 3);
assert_eq!(config.shingle_size(), 4);
assert_eq!(config.capacity(), 128);
assert_eq!(config.num_trees(), 17);
assert_eq!(config.time_decay(), 0.25);
assert_eq!(config.output_after(), 9);
assert!(!config.internal_shingling());
assert_eq!(config.initial_accept_fraction(), 0.5);
assert_eq!(config.dim(), 12);
}
// Guards the serialized shape: JSON keys are the private field names, and a
// document containing only `input_dim` deserializes with the documented defaults.
#[cfg(feature = "serde")]
#[test]
fn serde_preserves_private_field_wire_shape_and_defaults() {
let config = RcfConfig::new(3)
.with_shingle_size(4)
.with_capacity(128)
.with_num_trees(17)
.with_time_decay(0.25)
.with_output_after(9)
.with_internal_shingling(false)
.with_initial_accept_fraction(0.5);
let value = serde_json::to_value(&config).unwrap();
assert_eq!(value["input_dim"], 3);
assert_eq!(value["shingle_size"], 4);
assert_eq!(value["capacity"], 128);
assert_eq!(value["num_trees"], 17);
assert_eq!(value["time_decay"], 0.25);
assert_eq!(value["output_after"], 9);
assert_eq!(value["internal_shingling"], false);
assert_eq!(value["initial_accept_fraction"], 0.5);
let minimal: RcfConfig = serde_json::from_str(r#"{"input_dim":2}"#).unwrap();
assert_eq!(minimal.input_dim(), 2);
assert_eq!(minimal.shingle_size(), 1);
assert_eq!(minimal.capacity(), 256);
assert_eq!(minimal.num_trees(), 50);
assert_eq!(minimal.time_decay(), 0.0);
assert_eq!(minimal.output_after(), 0);
assert!(minimal.internal_shingling());
assert_eq!(minimal.initial_accept_fraction(), 0.125);
}
// One case per rejection path in `validate`; the second argument is a
// substring expected in the `InvalidArgument` message.
// Overflow inputs: `usize::MAX / 2` doubles to `usize::MAX - 1`, so adding 4
// overflows; `2 * (usize::MAX / 2 + 1)` is `usize::MAX + 1`, so the product
// itself overflows.
#[rstest]
#[case::zero_input_dim(RcfConfig::new(0), "input_dim")]
#[case::zero_shingle_size(RcfConfig::new(1).with_shingle_size(0), "shingle_size")]
#[case::zero_capacity(RcfConfig::new(1).with_capacity(0), "capacity")]
#[case::zero_num_trees(RcfConfig::new(1).with_num_trees(0), "num_trees")]
#[case::tree_arena_capacity_overflow(
RcfConfig::new(1).with_capacity(usize::MAX / 2),
"2 * capacity + 4"
)]
#[case::point_store_capacity_overflow(
RcfConfig::new(1)
.with_capacity(2)
.with_num_trees(usize::MAX / 2 + 1),
"capacity * num_trees + 1"
)]
#[case::negative_time_decay(RcfConfig::new(1).with_time_decay(-0.1), "time_decay")]
#[case::nan_time_decay(RcfConfig::new(1).with_time_decay(f64::NAN), "time_decay")]
#[case::infinite_time_decay(
RcfConfig::new(1).with_time_decay(f64::INFINITY),
"time_decay"
)]
#[case::negative_initial_accept_fraction(
RcfConfig::new(1).with_initial_accept_fraction(-0.1),
"initial_accept_fraction"
)]
#[case::nan_initial_accept_fraction(
RcfConfig::new(1).with_initial_accept_fraction(f64::NAN),
"initial_accept_fraction"
)]
#[case::infinite_initial_accept_fraction(
RcfConfig::new(1).with_initial_accept_fraction(f64::INFINITY),
"initial_accept_fraction"
)]
#[case::too_large_initial_accept_fraction(
RcfConfig::new(1).with_initial_accept_fraction(1.1),
"initial_accept_fraction"
)]
fn validate_rejects_invalid_core_fields(
#[case] config: RcfConfig,
#[case] expected_message: &str,
) {
let err = config.validate().unwrap_err();
assert!(
matches!(err, RcfError::InvalidArgument(ref msg) if msg.contains(expected_message)),
"unexpected error variant: {err:?}"
);
}
// The inclusive ends of the float ranges (0.0 and 1.0) are valid.
#[rstest]
#[case::zero_time_decay_and_zero_initial_fraction(
RcfConfig::new(1)
.with_time_decay(0.0)
.with_initial_accept_fraction(0.0)
)]
#[case::positive_time_decay_and_full_initial_fraction(
RcfConfig::new(1)
.with_time_decay(0.1)
.with_initial_accept_fraction(1.0)
)]
fn validate_accepts_float_boundaries(#[case] config: RcfConfig) {
config.validate().unwrap();
}
}