use crate::error::{RcfError, RcfResult};
use crate::forest::RandomCutForest;
use crate::thresholded::ThresholdedForest;
/// Summary of one bootstrap (bulk warm-up) pass over a batch of points.
///
/// Produced by the `bootstrap` methods defined in this file for
/// `RandomCutForest` and `ThresholdedForest`.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct BootstrapReport {
    /// Number of points the forest accepted during the pass.
    pub points_ingested: u64,
    /// Number of points dropped, either because a coordinate was not finite
    /// or because the forest rejected the point with `RcfError::NaNValue`.
    pub points_skipped: u64,
    /// Observation count reported by the forest/detector once the pass ended.
    pub final_observations: u64,
    /// Threshold reported by the detector once the pass ended; the bare
    /// forest's `bootstrap` always stores `0.0` here.
    pub final_threshold: f64,
}

impl BootstrapReport {
    /// Returns the report of a pass that saw nothing: all counters are zero
    /// and the threshold is `0.0`.
    #[must_use]
    pub fn empty() -> Self {
        Self::default()
    }

    /// Returns `true` when at least one point was ingested.
    #[must_use]
    pub fn is_hot(&self) -> bool {
        self.points_ingested != 0
    }
}

impl Default for BootstrapReport {
    /// Same value as [`BootstrapReport::empty`].
    fn default() -> Self {
        Self {
            points_ingested: 0,
            points_skipped: 0,
            final_observations: 0,
            final_threshold: 0.0,
        }
    }
}
/// Returns `true` when no coordinate of `p` is NaN or ±infinity.
///
/// A zero-dimensional point has no coordinates and is therefore finite.
fn is_finite_point<const D: usize>(p: &[f64; D]) -> bool {
    let has_bad_coordinate = p
        .iter()
        .any(|coord| coord.is_nan() || coord.is_infinite());
    !has_bad_coordinate
}
impl<const D: usize> RandomCutForest<D> {
    /// Warms the forest up by feeding it every point yielded by `points`.
    ///
    /// A point with any non-finite coordinate is counted as skipped and never
    /// handed to `update`. If `update` itself answers `RcfError::NaNValue`,
    /// the point is likewise counted as skipped and the pass continues.
    ///
    /// The returned report carries the ingested/skipped counts and the value
    /// of `updates_seen()` after the pass; `final_threshold` is always `0.0`
    /// for the bare forest.
    ///
    /// With the `std` feature enabled, the ingested and skipped counts are
    /// added to the `BOOTSTRAP_POINTS_TOTAL` and `BOOTSTRAP_SKIPPED_TOTAL`
    /// counters of the forest's metrics sink.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first error from `update` other than
    /// `RcfError::NaNValue`. Points ingested before the failure are not
    /// rolled back, and the counters are still emitted for them so metrics
    /// stay consistent with the forest state.
    pub fn bootstrap<I>(&mut self, points: I) -> RcfResult<BootstrapReport>
    where
        I: IntoIterator<Item = [f64; D]>,
    {
        let mut ingested: u64 = 0;
        let mut skipped: u64 = 0;
        // First fatal error, held back (instead of returned on the spot) so
        // the counters below still account for the work already done.
        let mut failure = None;
        for p in points {
            if !is_finite_point(&p) {
                skipped = skipped.saturating_add(1);
                continue;
            }
            match self.update(p) {
                Ok(()) => ingested = ingested.saturating_add(1),
                Err(RcfError::NaNValue) => skipped = skipped.saturating_add(1),
                Err(other) => {
                    failure = Some(other);
                    break;
                }
            }
        }
        #[cfg(feature = "std")]
        {
            use crate::metrics::names;
            let sink = self.metrics_sink();
            sink.inc_counter(names::BOOTSTRAP_POINTS_TOTAL, ingested);
            sink.inc_counter(names::BOOTSTRAP_SKIPPED_TOTAL, skipped);
        }
        if let Some(err) = failure {
            return Err(err);
        }
        Ok(BootstrapReport {
            points_ingested: ingested,
            points_skipped: skipped,
            final_observations: self.updates_seen(),
            // The bare forest has no threshold to report.
            final_threshold: 0.0,
        })
    }
}
impl<const D: usize> ThresholdedForest<D> {
    /// Warms the detector up by running `process` on every point yielded by
    /// `points` and discarding the per-point results.
    ///
    /// A point with any non-finite coordinate is counted as skipped and never
    /// handed to `process`. If `process` itself answers
    /// `RcfError::NaNValue`, the point is likewise counted as skipped and the
    /// pass continues.
    ///
    /// The returned report carries the ingested/skipped counts together with
    /// `stats().observations()` and `current_threshold()` as read after the
    /// pass.
    ///
    /// With the `std` feature enabled, the ingested and skipped counts are
    /// added to the `BOOTSTRAP_POINTS_TOTAL` and `BOOTSTRAP_SKIPPED_TOTAL`
    /// counters of the detector's metrics sink.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first error from `process` other than
    /// `RcfError::NaNValue`. Points processed before the failure are not
    /// rolled back, and the counters are still emitted for them so metrics
    /// stay consistent with the detector state.
    pub fn bootstrap<I>(&mut self, points: I) -> RcfResult<BootstrapReport>
    where
        I: IntoIterator<Item = [f64; D]>,
    {
        let mut ingested: u64 = 0;
        let mut skipped: u64 = 0;
        // First fatal error, held back (instead of returned on the spot) so
        // the counters below still account for the work already done.
        let mut failure = None;
        for p in points {
            if !is_finite_point(&p) {
                skipped = skipped.saturating_add(1);
                continue;
            }
            match self.process(p) {
                Ok(_) => ingested = ingested.saturating_add(1),
                Err(RcfError::NaNValue) => skipped = skipped.saturating_add(1),
                Err(other) => {
                    failure = Some(other);
                    break;
                }
            }
        }
        #[cfg(feature = "std")]
        {
            use crate::metrics::names;
            let sink = self.metrics_sink();
            sink.inc_counter(names::BOOTSTRAP_POINTS_TOTAL, ingested);
            sink.inc_counter(names::BOOTSTRAP_SKIPPED_TOTAL, skipped);
        }
        if let Some(err) = failure {
            return Err(err);
        }
        Ok(BootstrapReport {
            points_ingested: ingested,
            points_skipped: skipped,
            final_observations: self.stats().observations(),
            final_threshold: self.current_threshold(),
        })
    }
}
#[cfg(test)]
// Floats compared here are exact literals (e.g. the `0.0` stored by `empty()`),
// so direct equality is intended.
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;
    use crate::{ForestBuilder, ThresholdedForestBuilder};

    /// Generates 512 four-dimensional points from a `ChaCha8Rng` seeded with
    /// `seed`; every coordinate is a `random::<f64>()` draw scaled by 0.1.
    fn low_amplitude_history(seed: u64) -> Vec<[f64; 4]> {
        use rand::{RngExt, SeedableRng};
        use rand_chacha::ChaCha8Rng;

        let mut rng = ChaCha8Rng::seed_from_u64(seed);
        let mut history = Vec::with_capacity(512);
        for _ in 0..512 {
            // Coordinates are drawn left to right, one RNG call each.
            history.push([
                rng.random::<f64>() * 0.1,
                rng.random::<f64>() * 0.1,
                rng.random::<f64>() * 0.1,
                rng.random::<f64>() * 0.1,
            ]);
        }
        history
    }

    /// `empty()` zeroes every field, is not hot, and equals `default()`.
    #[test]
    fn bootstrap_report_empty_defaults() {
        let report = BootstrapReport::empty();
        assert_eq!(report.points_ingested, 0);
        assert_eq!(report.points_skipped, 0);
        assert_eq!(report.final_observations, 0);
        assert_eq!(report.final_threshold, 0.0);
        assert!(!report.is_hot());
        assert_eq!(report, BootstrapReport::default());
    }

    /// Bootstrapping from an empty iterator leaves the forest untouched.
    #[test]
    fn forest_bootstrap_from_empty_iter_is_noop() {
        let mut forest = ForestBuilder::<2>::new().seed(1).build().unwrap();
        let report = forest.bootstrap(std::iter::empty::<[f64; 2]>()).unwrap();
        assert_eq!(report.points_ingested, 0);
        assert!(!report.is_hot());
        assert_eq!(forest.updates_seen(), 0);
    }

    /// Finite points are ingested; NaN/infinite points are counted as skipped.
    #[test]
    fn forest_bootstrap_counts_ingested_and_skipped() {
        let mut forest = ForestBuilder::<2>::new().seed(1).build().unwrap();
        let points: Vec<[f64; 2]> = vec![
            [0.0, 0.0],
            [1.0, 1.0],
            [f64::NAN, 0.0], // skipped: NaN coordinate
            [2.0, 2.0],
            [0.0, f64::INFINITY], // skipped: infinite coordinate
        ];
        let report = forest.bootstrap(points).unwrap();
        assert_eq!(report.points_ingested, 3);
        assert_eq!(report.points_skipped, 2);
        assert_eq!(forest.updates_seen(), 3);
        assert_eq!(report.final_observations, 3);
    }

    /// After a 512-point bootstrap the detector reports itself ready.
    #[test]
    fn thresholded_bootstrap_makes_detector_ready() {
        let mut detector = ThresholdedForestBuilder::<4>::new()
            .num_trees(50)
            .sample_size(64)
            .min_observations(32)
            .min_threshold(0.1)
            .seed(42)
            .build()
            .unwrap();

        let report = detector.bootstrap(low_amplitude_history(42)).unwrap();
        assert_eq!(report.points_ingested, 512);
        assert!(report.is_hot());
        assert!(report.final_observations >= 32, "should be past warmup");
        assert!(report.final_threshold > 0.1, "threshold should be adaptive");

        let verdict = detector.score_only(&[0.05, 0.05, 0.05, 0.05]).unwrap();
        assert!(verdict.ready(), "detector must be hot after bootstrap");
    }

    /// A far-away point processed right after bootstrap is flagged.
    #[test]
    fn thresholded_bootstrap_detects_outlier_immediately() {
        let mut detector = ThresholdedForestBuilder::<4>::new()
            .num_trees(50)
            .sample_size(64)
            .min_observations(32)
            .min_threshold(0.1)
            .seed(3)
            .build()
            .unwrap();

        detector.bootstrap(low_amplitude_history(3)).unwrap();

        let outlier = detector.process([50.0, 50.0, 50.0, 50.0]).unwrap();
        assert!(outlier.ready());
        assert!(outlier.is_anomaly());
        assert!(outlier.grade() > 0.0);
    }

    /// The thresholded bootstrap skips NaN and infinite points as well.
    #[test]
    fn thresholded_bootstrap_skips_non_finite() {
        let mut detector = ThresholdedForestBuilder::<2>::new()
            .num_trees(50)
            .sample_size(16)
            .min_observations(4)
            .seed(1)
            .build()
            .unwrap();
        let points: Vec<[f64; 2]> = vec![
            [0.0, 0.0],
            [f64::NAN, 0.0],
            [0.5, 0.5],
            [f64::NEG_INFINITY, 1.0],
        ];
        let report = detector.bootstrap(points).unwrap();
        assert_eq!(report.points_ingested, 2);
        assert_eq!(report.points_skipped, 2);
    }
}