1use alloc::format;
14use alloc::vec::Vec;
15
16use crate::error::{RcfError, RcfResult};
17use crate::forest::random_cut_forest::RandomCutForest;
18
/// Smallest point dimension accepted by [`RcfConfig::validate_dimension`] (inclusive).
pub const MIN_DIMENSION: usize = 1;

/// Largest point dimension accepted by [`RcfConfig::validate_dimension`] (inclusive).
pub const MAX_DIMENSION: usize = 10_000;

/// Inclusive lower bound enforced on [`RcfConfig::num_trees`] by [`RcfConfig::validate`].
pub const MIN_NUM_TREES: usize = 50;

/// Inclusive upper bound enforced on [`RcfConfig::num_trees`] by [`RcfConfig::validate`].
pub const MAX_NUM_TREES: usize = 1_000;

/// Tree count a freshly created [`ForestBuilder`] starts with.
pub const DEFAULT_NUM_TREES: usize = 100;

/// Inclusive lower bound enforced on [`RcfConfig::sample_size`] by [`RcfConfig::validate`].
pub const MIN_SAMPLE_SIZE: usize = 1;

/// Inclusive upper bound enforced on [`RcfConfig::sample_size`] by [`RcfConfig::validate`].
pub const MAX_SAMPLE_SIZE: usize = 2_048;

/// Sample size a freshly created [`ForestBuilder`] starts with.
pub const DEFAULT_SAMPLE_SIZE: usize = 256;

/// Numerator of the default time-decay rule `TIME_DECAY_NUMERATOR / sample_size`
/// (see [`default_time_decay_for`]).
pub const TIME_DECAY_NUMERATOR: f64 = 0.1;

/// Default time decay for a forest that uses [`DEFAULT_SAMPLE_SIZE`], i.e.
/// `TIME_DECAY_NUMERATOR / DEFAULT_SAMPLE_SIZE`.
// The cast cannot lose precision: 256 is exactly representable as an `f64`.
#[allow(clippy::cast_precision_loss)]
pub const DEFAULT_TIME_DECAY: f64 = TIME_DECAY_NUMERATOR / (DEFAULT_SAMPLE_SIZE as f64);
49
50#[must_use]
54pub fn default_time_decay_for(sample_size: usize) -> f64 {
55 if sample_size == 0 {
56 return 0.0;
57 }
58 #[allow(clippy::cast_precision_loss)]
59 {
60 TIME_DECAY_NUMERATOR / sample_size as f64
61 }
62}
/// Value used for [`RcfConfig::initial_accept_fraction`] when the caller (or a
/// serialized config) does not supply one. It sits at the upper end of the
/// `(0.0, 1.0]` range that [`RcfConfig::validate`] accepts.
pub const DEFAULT_INITIAL_ACCEPT_FRACTION: f64 = 1.0;
67
/// Tunable parameters of a random cut forest.
///
/// Field values are not constrained by the type itself; call
/// [`RcfConfig::validate`] to check them. With the `serde` feature enabled,
/// deserialization goes through `RcfConfigShadow` and its `TryFrom`
/// conversion, which runs that validation before a value is handed out.
///
/// The field order is part of the serialized layout for positional formats,
/// so new fields should only be appended.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(serde::Serialize, serde::Deserialize),
    serde(try_from = "RcfConfigShadow")
)]
#[non_exhaustive]
pub struct RcfConfig {
    /// Number of trees; must lie in `[MIN_NUM_TREES, MAX_NUM_TREES]`.
    pub num_trees: usize,
    /// Per-tree sample size; must lie in `[MIN_SAMPLE_SIZE, MAX_SAMPLE_SIZE]`.
    pub sample_size: usize,
    /// Time-decay rate; must be finite and within `[0.0, 1.0]`.
    pub time_decay: f64,
    /// Optional seed value. How `None` is handled is decided by the forest
    /// implementation, which is not part of this module.
    pub seed: Option<u64>,
    /// Optional thread count; must be greater than zero when set. `None`
    /// means "fall back to rayon's global pool".
    pub num_threads: Option<usize>,
    /// Initial accept fraction; must be finite and within `(0.0, 1.0]`.
    /// Defaults to `DEFAULT_INITIAL_ACCEPT_FRACTION` when absent from
    /// serialized input.
    #[cfg_attr(feature = "serde", serde(default = "default_initial_accept_fraction"))]
    pub initial_accept_fraction: f64,
    /// Optional per-feature scale factors. Each entry must be finite and
    /// strictly positive, and the length must equal the forest dimension
    /// (checked by [`RcfConfig::validate_feature_scales_dimension`]).
    #[cfg_attr(feature = "serde", serde(default))]
    pub feature_scales: Option<Vec<f64>>,
}
135
/// Serde `default = "..."` hook for `initial_accept_fraction`.
///
/// Serde needs a function path rather than a constant, so this simply hands
/// back [`DEFAULT_INITIAL_ACCEPT_FRACTION`].
#[cfg(feature = "serde")]
#[must_use]
fn default_initial_accept_fraction() -> f64 {
    DEFAULT_INITIAL_ACCEPT_FRACTION
}
144
/// Unvalidated mirror of [`RcfConfig`] used as the serde decoding target.
///
/// `RcfConfig` is declared with `serde(try_from = "RcfConfigShadow")`, so
/// incoming data is first decoded into this struct and then converted via
/// `TryFrom`, which rejects out-of-range values.
///
/// Field names, types and order must stay identical to `RcfConfig`: the tests
/// in this module encode a shadow and decode the bytes as an `RcfConfig`.
#[cfg(feature = "serde")]
#[derive(serde::Serialize, serde::Deserialize)]
#[allow(clippy::missing_docs_in_private_items)]
struct RcfConfigShadow {
    // Mirrors `RcfConfig::num_trees`.
    num_trees: usize,
    // Mirrors `RcfConfig::sample_size`.
    sample_size: usize,
    // Mirrors `RcfConfig::time_decay`.
    time_decay: f64,
    // Mirrors `RcfConfig::seed`.
    seed: Option<u64>,
    // Mirrors `RcfConfig::num_threads`.
    num_threads: Option<usize>,
    // Mirrors `RcfConfig::initial_accept_fraction`, including its default.
    #[serde(default = "default_initial_accept_fraction")]
    initial_accept_fraction: f64,
    // Mirrors `RcfConfig::feature_scales`; absent input becomes `None`.
    #[serde(default)]
    feature_scales: Option<Vec<f64>>,
}
164
/// Converts freshly decoded, unchecked data into a validated [`RcfConfig`].
#[cfg(feature = "serde")]
impl TryFrom<RcfConfigShadow> for RcfConfig {
    type Error = RcfError;

    /// Copies every field across unchanged and then runs
    /// [`RcfConfig::validate`] on the result.
    ///
    /// # Errors
    ///
    /// Returns whatever error `validate` reports for the first field that is
    /// out of range.
    fn try_from(raw: RcfConfigShadow) -> Result<Self, Self::Error> {
        // Exhaustive destructuring: adding a field to the shadow without
        // forwarding it here becomes a compile error.
        let RcfConfigShadow {
            num_trees,
            sample_size,
            time_decay,
            seed,
            num_threads,
            initial_accept_fraction,
            feature_scales,
        } = raw;

        let candidate = Self {
            num_trees,
            sample_size,
            time_decay,
            seed,
            num_threads,
            initial_accept_fraction,
            feature_scales,
        };
        candidate.validate().map(|()| candidate)
    }
}
183
184impl RcfConfig {
185 pub fn validate(&self) -> RcfResult<()> {
196 if !(MIN_NUM_TREES..=MAX_NUM_TREES).contains(&self.num_trees) {
197 return Err(RcfError::InvalidConfig(
198 format!(
199 "num_trees {} out of [{}, {}]",
200 self.num_trees, MIN_NUM_TREES, MAX_NUM_TREES
201 )
202 .into(),
203 ));
204 }
205 if !(MIN_SAMPLE_SIZE..=MAX_SAMPLE_SIZE).contains(&self.sample_size) {
206 return Err(RcfError::InvalidConfig(
207 format!(
208 "sample_size {} out of [{}, {}]",
209 self.sample_size, MIN_SAMPLE_SIZE, MAX_SAMPLE_SIZE
210 )
211 .into(),
212 ));
213 }
214 if !self.time_decay.is_finite() || !(0.0..=1.0).contains(&self.time_decay) {
215 return Err(RcfError::InvalidConfig(
216 format!("time_decay {} out of [0.0, 1.0]", self.time_decay).into(),
217 ));
218 }
219 if let Some(n) = self.num_threads
220 && n == 0
221 {
222 return Err(RcfError::InvalidConfig(
223 "num_threads must be > 0 when set; use None to fall back to rayon's global pool"
224 .into(),
225 ));
226 }
227 if !self.initial_accept_fraction.is_finite()
228 || self.initial_accept_fraction <= 0.0
229 || self.initial_accept_fraction > 1.0
230 {
231 return Err(RcfError::InvalidConfig(
232 format!(
233 "initial_accept_fraction {} out of (0.0, 1.0]",
234 self.initial_accept_fraction
235 )
236 .into(),
237 ));
238 }
239 if let Some(scales) = &self.feature_scales {
240 for (i, s) in scales.iter().enumerate() {
241 if !s.is_finite() || *s <= 0.0 {
242 return Err(RcfError::InvalidConfig(
243 format!("feature_scales[{i}] must be finite and > 0, got {s}").into(),
244 ));
245 }
246 }
247 }
248 Ok(())
249 }
250
251 pub fn validate_feature_scales_dimension(&self, d: usize) -> RcfResult<()> {
260 if let Some(scales) = &self.feature_scales
261 && scales.len() != d
262 {
263 return Err(RcfError::DimensionMismatch {
264 expected: d,
265 got: scales.len(),
266 });
267 }
268 Ok(())
269 }
270
271 pub fn validate_dimension(dimension: usize) -> RcfResult<()> {
280 if !(MIN_DIMENSION..=MAX_DIMENSION).contains(&dimension) {
281 return Err(RcfError::InvalidConfig(
282 format!("dimension {dimension} out of [{MIN_DIMENSION}, {MAX_DIMENSION}]").into(),
283 ));
284 }
285 Ok(())
286 }
287}
288
289#[derive(Debug, Clone)]
313pub struct ForestBuilder<const D: usize> {
314 config: RcfConfig,
317 time_decay_explicit: bool,
323}
324
325impl<const D: usize> Default for ForestBuilder<D> {
326 fn default() -> Self {
327 Self::new()
328 }
329}
330
331impl<const D: usize> ForestBuilder<D> {
332 #[must_use]
334 pub fn new() -> Self {
335 Self {
336 config: RcfConfig {
337 num_trees: DEFAULT_NUM_TREES,
338 sample_size: DEFAULT_SAMPLE_SIZE,
339 time_decay: default_time_decay_for(DEFAULT_SAMPLE_SIZE),
340 seed: None,
341 num_threads: None,
342 initial_accept_fraction: DEFAULT_INITIAL_ACCEPT_FRACTION,
343 feature_scales: None,
344 },
345 time_decay_explicit: false,
346 }
347 }
348
349 #[must_use]
351 pub fn num_trees(mut self, n: usize) -> Self {
352 self.config.num_trees = n;
353 self
354 }
355
356 #[must_use]
362 pub fn sample_size(mut self, s: usize) -> Self {
363 self.config.sample_size = s;
364 if !self.time_decay_explicit {
365 self.config.time_decay = default_time_decay_for(s);
366 }
367 self
368 }
369
370 #[must_use]
376 pub fn time_decay(mut self, d: f64) -> Self {
377 self.config.time_decay = d;
378 self.time_decay_explicit = true;
379 self
380 }
381
382 #[must_use]
384 pub fn seed(mut self, seed: u64) -> Self {
385 self.config.seed = Some(seed);
386 self
387 }
388
389 #[must_use]
394 pub fn num_threads(mut self, n: usize) -> Self {
395 self.config.num_threads = Some(n);
396 self
397 }
398
399 #[must_use]
405 pub fn initial_accept_fraction(mut self, f: f64) -> Self {
406 self.config.initial_accept_fraction = f;
407 self
408 }
409
410 #[must_use]
418 pub fn feature_scales(mut self, scales: [f64; D]) -> Self {
419 self.config.feature_scales = Some(scales.to_vec());
420 self
421 }
422
423 #[must_use]
428 pub fn clear_feature_scales(mut self) -> Self {
429 self.config.feature_scales = None;
430 self
431 }
432
433 #[must_use = "detector output should be checked — dropping it silently usually indicates a logic bug"]
435 pub fn config(&self) -> &RcfConfig {
436 &self.config
437 }
438
439 #[must_use]
441 pub const fn dimension(&self) -> usize {
442 D
443 }
444
445 #[must_use = "detector output should be checked — dropping it silently usually indicates a logic bug"]
452 pub fn build(self) -> RcfResult<RandomCutForest<D>> {
453 RcfConfig::validate_dimension(D)?;
454 self.config.validate()?;
455 self.config.validate_feature_scales_dimension(D)?;
456 RandomCutForest::<D>::from_config(self.config)
457 }
458}
459
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a config from the three values most tests vary; the remaining
    /// fields take the same values a fresh `ForestBuilder` starts with.
    fn cfg(n: usize, s: usize, td: f64) -> RcfConfig {
        RcfConfig {
            num_trees: n,
            sample_size: s,
            time_decay: td,
            seed: None,
            num_threads: None,
            initial_accept_fraction: DEFAULT_INITIAL_ACCEPT_FRACTION,
            feature_scales: None,
        }
    }

    /// `true` when `a` and `b` differ by less than `f64::EPSILON`.
    fn close(a: f64, b: f64) -> bool {
        (a - b).abs() < f64::EPSILON
    }

    // ----- RcfConfig::validate / validate_dimension -------------------------

    #[test]
    fn validate_default_passes() {
        let defaults = cfg(DEFAULT_NUM_TREES, DEFAULT_SAMPLE_SIZE, DEFAULT_TIME_DECAY);
        defaults.validate().unwrap();
    }

    #[test]
    fn validate_dimension_rejects_zero() {
        let outcome = RcfConfig::validate_dimension(0);
        assert!(matches!(outcome, Err(RcfError::InvalidConfig(_))));
    }

    #[test]
    fn validate_dimension_rejects_above_max() {
        let outcome = RcfConfig::validate_dimension(10_001);
        assert!(outcome.is_err());
    }

    #[test]
    fn validate_dimension_accepts_at_max() {
        RcfConfig::validate_dimension(10_000).unwrap();
    }

    #[test]
    fn validate_rejects_num_trees_below_min() {
        let too_few = cfg(49, 256, 0.0);
        assert!(too_few.validate().is_err());
    }

    #[test]
    fn validate_accepts_num_trees_at_bounds() {
        for trees in [50, 1000] {
            cfg(trees, 256, 0.0).validate().unwrap();
        }
    }

    #[test]
    fn validate_rejects_num_trees_above_max() {
        let too_many = cfg(1001, 256, 0.0);
        assert!(too_many.validate().is_err());
    }

    #[test]
    fn validate_rejects_sample_size_zero() {
        let empty_sample = cfg(100, 0, 0.0);
        assert!(empty_sample.validate().is_err());
    }

    #[test]
    fn validate_accepts_sample_size_at_bounds() {
        for size in [1, 2048] {
            cfg(100, size, 0.0).validate().unwrap();
        }
    }

    #[test]
    fn validate_rejects_sample_size_above_max() {
        let oversized = cfg(100, 2049, 0.0);
        assert!(oversized.validate().is_err());
    }

    #[test]
    fn validate_rejects_negative_time_decay() {
        let negative = cfg(100, 256, -0.01);
        assert!(negative.validate().is_err());
    }

    #[test]
    fn validate_rejects_time_decay_above_one() {
        let too_large = cfg(100, 256, 1.01);
        assert!(too_large.validate().is_err());
    }

    #[test]
    fn validate_rejects_non_finite_time_decay() {
        for decay in [f64::NAN, f64::INFINITY] {
            assert!(cfg(100, 256, decay).validate().is_err());
        }
    }

    #[test]
    fn validate_rejects_zero_num_threads() {
        let mut config = cfg(100, 256, 0.0);
        config.num_threads = Some(0);
        assert!(matches!(
            config.validate(),
            Err(RcfError::InvalidConfig(_))
        ));
    }

    #[test]
    fn validate_accepts_some_num_threads() {
        let mut config = cfg(100, 256, 0.0);
        config.num_threads = Some(4);
        config.validate().unwrap();
    }

    #[test]
    fn validate_accepts_default_num_threads_none() {
        let config = cfg(100, 256, 0.0);
        assert_eq!(config.num_threads, None);
        config.validate().unwrap();
    }

    #[test]
    fn builder_num_threads_sets_field() {
        let builder = ForestBuilder::<4>::new().num_threads(8);
        assert_eq!(builder.config().num_threads, Some(8));
    }

    #[test]
    fn validate_accepts_initial_accept_fraction_at_bounds() {
        let mut config = cfg(100, 256, 0.0);
        // 0.0 itself is excluded, so probe just inside the lower edge.
        for fraction in [0.001, 1.0] {
            config.initial_accept_fraction = fraction;
            config.validate().unwrap();
        }
    }

    #[test]
    fn validate_rejects_initial_accept_fraction_out_of_range() {
        let mut config = cfg(100, 256, 0.0);
        for fraction in [0.0, -0.1, 1.01] {
            config.initial_accept_fraction = fraction;
            assert!(config.validate().is_err());
        }
    }

    #[test]
    fn validate_rejects_non_finite_initial_accept_fraction() {
        let mut config = cfg(100, 256, 0.0);
        for fraction in [f64::NAN, f64::INFINITY] {
            config.initial_accept_fraction = fraction;
            assert!(config.validate().is_err());
        }
    }

    // ----- ForestBuilder ----------------------------------------------------

    #[test]
    fn builder_initial_accept_fraction_sets_field() {
        let builder = ForestBuilder::<4>::new().initial_accept_fraction(0.125);
        assert!(close(builder.config().initial_accept_fraction, 0.125));
    }

    #[test]
    fn builder_defaults_initial_accept_fraction_to_one() {
        let builder = ForestBuilder::<4>::new();
        assert!(close(builder.config().initial_accept_fraction, 1.0));
    }

    #[test]
    fn builder_defaults_match_aws() {
        let builder = ForestBuilder::<8>::new();
        assert_eq!(builder.dimension(), 8);

        let config = builder.config();
        assert_eq!(config.num_trees, 100);
        assert_eq!(config.sample_size, 256);
        assert!(
            close(config.time_decay, TIME_DECAY_NUMERATOR / 256.0),
            "default time_decay should resolve to 0.1 / sample_size, got {}",
            config.time_decay
        );
        assert_eq!(config.seed, None);
    }

    #[test]
    fn builder_sample_size_override_rescales_default_time_decay() {
        let builder = ForestBuilder::<4>::new().sample_size(128);
        let decay = builder.config().time_decay;
        assert!(
            close(decay, TIME_DECAY_NUMERATOR / 128.0),
            "sample_size(128) should rescale default to 0.1 / 128, got {}",
            decay,
        );
    }

    #[test]
    fn builder_explicit_time_decay_sticks_across_sample_size_override() {
        // Explicit decay first, sample size second: the decay must survive.
        let builder = ForestBuilder::<4>::new().time_decay(0.05).sample_size(128);
        assert!(close(builder.config().time_decay, 0.05));
    }

    #[test]
    fn builder_sample_size_override_before_time_decay() {
        // Sample size first, explicit decay second: the decay must win.
        let builder = ForestBuilder::<4>::new().sample_size(128).time_decay(0.05);
        assert!(close(builder.config().time_decay, 0.05));
    }

    #[test]
    fn builder_time_decay_zero_still_accepted() {
        let builder = ForestBuilder::<4>::new().time_decay(0.0);
        assert!(close(builder.config().time_decay, 0.0));
        builder.build().expect("time_decay=0 must still build");
    }

    #[test]
    fn default_time_decay_for_zero_sample_size_is_zero() {
        assert!(close(default_time_decay_for(0), 0.0));
    }

    #[test]
    fn default_time_decay_for_default_sample_size_matches_constant() {
        let derived = default_time_decay_for(DEFAULT_SAMPLE_SIZE);
        assert!(close(derived, DEFAULT_TIME_DECAY));
    }

    #[test]
    fn builder_overrides_apply() {
        let builder = ForestBuilder::<4>::new()
            .num_trees(50)
            .sample_size(64)
            .time_decay(0.05)
            .seed(42);
        let config = builder.config();
        assert_eq!(config.num_trees, 50);
        assert_eq!(config.sample_size, 64);
        assert!(close(config.time_decay, 0.05));
        assert_eq!(config.seed, Some(42));
    }

    #[test]
    fn builder_build_validates() {
        let outcome = ForestBuilder::<4>::new().num_trees(10).build();
        assert!(matches!(outcome.unwrap_err(), RcfError::InvalidConfig(_)));
    }

    // ----- serde round trips (validation on decode) -------------------------

    /// A shadow whose every field would pass validation; tests override one
    /// field at a time with struct-update syntax.
    #[cfg(all(feature = "serde", feature = "postcard"))]
    fn valid_shadow() -> RcfConfigShadow {
        RcfConfigShadow {
            num_trees: 100,
            sample_size: 256,
            time_decay: 0.0,
            seed: None,
            num_threads: None,
            initial_accept_fraction: 1.0,
            feature_scales: None,
        }
    }

    /// Encodes `raw` with postcard and reports whether the bytes decode into
    /// an `RcfConfig`.
    #[cfg(all(feature = "serde", feature = "postcard"))]
    fn decodes(raw: &RcfConfigShadow) -> bool {
        let bytes = postcard::to_allocvec(raw).unwrap();
        postcard::from_bytes::<RcfConfig>(&bytes).is_ok()
    }

    #[cfg(all(feature = "serde", feature = "postcard"))]
    #[test]
    fn deserialize_rejects_out_of_range_num_trees() {
        let bad = RcfConfigShadow {
            num_trees: MAX_NUM_TREES + 1,
            ..valid_shadow()
        };
        assert!(!decodes(&bad));
    }

    #[cfg(all(feature = "serde", feature = "postcard"))]
    #[test]
    fn deserialize_rejects_nan_time_decay() {
        let bad = RcfConfigShadow {
            time_decay: f64::NAN,
            ..valid_shadow()
        };
        assert!(!decodes(&bad));
    }

    #[cfg(all(feature = "serde", feature = "postcard"))]
    #[test]
    fn deserialize_rejects_negative_feature_scale() {
        let bad = RcfConfigShadow {
            feature_scales: Some(alloc::vec![1.0, -0.5]),
            ..valid_shadow()
        };
        assert!(!decodes(&bad));
    }
}