#[cfg(not(feature = "std"))]
use alloc::collections::BTreeMap;
#[cfg(all(not(feature = "std"), feature = "serde"))]
use alloc::string::{String, ToString};
#[cfg(not(feature = "std"))]
use alloc::{format, vec, vec::Vec};
#[cfg(feature = "std")]
use std::collections::BTreeMap;
use rand::prelude::*;
use rand::rngs::Xoshiro256PlusPlus;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::{
config::RcfConfig,
error::{RcfError, Result},
point_store::PointStore,
sampler::{Sampler, reservoir_weight},
score::{Attribution, ScoreMode},
tree::RcfTree,
};
/// An ensemble of [`RcfTree`]s that share one [`PointStore`].
///
/// Each tree is paired with the [`Sampler`] at the same index; the sampler decides
/// which stored points that tree holds. (Presumably a random cut forest, judging
/// by the `Rcf*` type names — confirm against the crate documentation.)
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Forest {
/// Configuration the forest was built from.
pub(crate) config: RcfConfig,
/// The trees; `num_trees` of them are created at construction.
trees: Vec<RcfTree>,
/// One sampler per tree, indexed in parallel with `trees`.
samplers: Vec<Sampler>,
/// Storage for the (shingled) points referenced by the trees.
pub(crate) point_store: PointStore,
/// Number of `update` calls whose input was successfully shingled.
entries_seen: u64,
/// Generator used for sampling draws in `update`; cloned (never advanced) by
/// the read-only `impute` and `extrapolate` queries.
rng: Xoshiro256PlusPlus,
}
/// A neighbor candidate that refers to its point by index rather than by value.
#[derive(Clone, Debug, PartialEq)]
pub struct NeighborCandidate {
    /// Score attached to the candidate.
    pub score: f64,
    /// Index of the candidate point in the point store.
    pub point_idx: usize,
    /// Distance associated with the candidate.
    pub distance: f64,
}

/// Builds a candidate from a `(score, point_idx, distance)` tuple.
impl From<(f64, usize, f64)> for NeighborCandidate {
    fn from((score, point_idx, distance): (f64, usize, f64)) -> Self {
        Self {
            score,
            point_idx,
            distance,
        }
    }
}

/// Flattens a candidate back into a `(score, point_idx, distance)` tuple.
impl From<NeighborCandidate> for (f64, usize, f64) {
    fn from(value: NeighborCandidate) -> Self {
        let NeighborCandidate {
            score,
            point_idx,
            distance,
        } = value;
        (score, point_idx, distance)
    }
}
/// A neighbor returned to callers, carrying a copy of the point's coordinates.
#[derive(Clone, Debug, PartialEq)]
pub struct NeighborResult {
    /// Score attached to the neighbor.
    pub score: f64,
    /// Coordinates of the neighbor point.
    pub point: Vec<f32>,
    /// Distance associated with the neighbor.
    pub distance: f64,
}

/// Builds a result from a `(score, point, distance)` tuple.
impl From<(f64, Vec<f32>, f64)> for NeighborResult {
    fn from((score, point, distance): (f64, Vec<f32>, f64)) -> Self {
        Self {
            score,
            point,
            distance,
        }
    }
}

/// Flattens a result back into a `(score, point, distance)` tuple.
impl From<NeighborResult> for (f64, Vec<f32>, f64) {
    fn from(value: NeighborResult) -> Self {
        let NeighborResult {
            score,
            point,
            distance,
        } = value;
        (score, point, distance)
    }
}
/// Converts a list of missing coordinate indices into a boolean mask of length `dim`.
///
/// # Errors
/// Returns [`RcfError::IndexOutOfBounds`] carrying the first index that is `>= dim`.
fn make_missing_flags(missing: &[usize], dim: usize) -> Result<Vec<bool>> {
    let mut flags = vec![false; dim];
    for &idx in missing {
        match flags.get_mut(idx) {
            Some(slot) => *slot = true,
            None => return Err(RcfError::IndexOutOfBounds(idx)),
        }
    }
    Ok(flags)
}
/// Returns the median of `vals`, reordering the slice in the process.
///
/// Values are ordered with `f32::total_cmp`. For an even number of elements the
/// result is the mean of the two middle values. The slice must be non-empty.
fn median_in_place(vals: &mut [f32]) -> f32 {
    debug_assert!(!vals.is_empty(), "median_in_place requires non-empty input");
    let half = vals.len() / 2;
    let odd_len = vals.len() % 2 == 1;
    // Partition around the element at `half`: everything in `below` sorts at or
    // before `upper_mid`.
    let (below, upper_mid, _) = vals.select_nth_unstable_by(half, |x, y| x.total_cmp(y));
    let upper = *upper_mid;
    if odd_len {
        return upper;
    }
    // The lower middle value is the largest element of the lower partition.
    let lower = below
        .iter()
        .fold(f32::NEG_INFINITY, |best, &v| best.max(v));
    (lower + upper) / 2.0
}
impl Forest {
/// Validates `config` and builds an empty forest whose randomness derives from `seed`.
///
/// Every tree receives its own seed drawn from a generator seeded with `seed`;
/// the forest's own generator is seeded from the draw that follows.
///
/// # Errors
/// Returns [`RcfError::InvalidArgument`] when `input_dim`, `shingle_size`,
/// `capacity` or `num_trees` is zero, or when the point-store capacity derived
/// from `capacity` and `num_trees` does not fit in a `usize`.
fn new_internal(config: RcfConfig, seed: u64) -> Result<Self> {
    if config.input_dim == 0 {
        return Err(RcfError::InvalidArgument("input_dim must be > 0".into()));
    }
    if config.shingle_size == 0 {
        return Err(RcfError::InvalidArgument("shingle_size must be > 0".into()));
    }
    if config.capacity == 0 {
        return Err(RcfError::InvalidArgument("capacity must be > 0".into()));
    }
    if config.num_trees == 0 {
        return Err(RcfError::InvalidArgument("num_trees must be > 0".into()));
    }
    let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed);
    let dim = config.dim();
    let capacity = config.capacity;
    let num_trees = config.num_trees;
    // Store capacity is max(capacity * num_trees + 1, 2 * capacity). Compute it
    // with checked arithmetic so that oversized configurations are rejected
    // instead of panicking (debug) or silently wrapping (release).
    let all_trees_full = capacity
        .checked_mul(num_trees)
        .and_then(|slots| slots.checked_add(1));
    let double_capacity = capacity.checked_mul(2);
    let store_capacity = match (all_trees_full, double_capacity) {
        (Some(total), Some(doubled)) => total.max(doubled),
        _ => {
            return Err(RcfError::InvalidArgument(
                "capacity * num_trees overflows usize".into(),
            ));
        }
    };
    let trees: Vec<RcfTree> = (0..num_trees)
        .map(|_| RcfTree::new(dim, capacity, rng.next_u64()))
        .collect();
    let samplers: Vec<Sampler> = (0..num_trees).map(|_| Sampler::new(capacity)).collect();
    let point_store = PointStore::new(
        config.input_dim,
        config.shingle_size,
        store_capacity,
        config.internal_shingling,
    );
    Ok(Forest {
        config,
        trees,
        samplers,
        point_store,
        entries_seen: 0,
        rng: Xoshiro256PlusPlus::seed_from_u64(rng.next_u64()),
    })
}
/// Builds a forest from `config` using a seed taken from `rand::make_rng()`.
///
/// # Errors
/// Propagates the validation errors of the internal constructor.
pub fn from_config(config: &RcfConfig) -> Result<Self> {
    let mut entropy: Xoshiro256PlusPlus = rand::make_rng();
    let seed = entropy.next_u64();
    Self::new_internal(config.clone(), seed)
}
/// Builds a forest from `config` with an explicit `seed`.
///
/// # Errors
/// Propagates the validation errors of the internal constructor.
pub fn from_config_seeded(config: &RcfConfig, seed: u64) -> Result<Self> {
    let owned_config = config.clone();
    Self::new_internal(owned_config, seed)
}
/// Starts a [`ForestBuilder`] for points with `input_dim` coordinates.
pub fn builder(input_dim: usize) -> ForestBuilder {
    ForestBuilder::new(input_dim)
}
/// Feeds one observation into the forest.
///
/// `base` is turned into a shingled point by the point store and offered to
/// every tree's sampler. A tree that accepts the point first removes the point
/// its sampler evicted (if any) and then inserts the new one.
///
/// # Errors
/// Propagates errors from the point store (`shingled_point`, `add`) and from
/// tree `delete`/`insert`.
pub fn update(&mut self, base: &[f32]) -> Result<()> {
    let shingled = self.point_store.shingled_point(base)?;
    self.entries_seen += 1;

    // With internal shingling the first `shingle_size - 1` observations are
    // counted but not inserted into any tree.
    let warmup = if self.config.internal_shingling {
        self.config.shingle_size.saturating_sub(1)
    } else {
        0
    };
    if self.entries_seen as usize <= warmup {
        return Ok(());
    }

    let point_idx = self.point_store.add(&shingled)?;
    let decay = self.config.effective_time_decay();
    let initial_frac = self.config.initial_accept_fraction;
    let mut accepted_somewhere = false;

    for t in 0..self.trees.len() {
        // Draw order matters for seeded runs: the weight draw always happens,
        // the admission draw only while the sampler is not yet full.
        let draw: f64 = self.rng.random::<f64>();
        let weight = reservoir_weight(draw, decay, self.entries_seen);
        let fill = self.samplers[t].fill_fraction();
        let is_initial = !self.samplers[t].is_full() && {
            let admit_prob = if fill < initial_frac {
                1.0
            } else if initial_frac >= 1.0 {
                0.0
            } else {
                // Falls linearly from 1 at `initial_frac` to 0 at a fill of 1.
                1.0 - (fill - initial_frac) / (1.0 - initial_frac)
            };
            self.rng.random::<f64>() < admit_prob
        };

        let outcome = self.samplers[t].accept(is_initial, weight, point_idx);
        if !outcome.accepted {
            continue;
        }
        accepted_somewhere = true;
        if let Some(evicted_idx) = outcome.evicted {
            self.trees[t].delete(evicted_idx, &self.point_store)?;
            self.point_store.dec_ref(evicted_idx);
        }
        self.trees[t].insert(point_idx, &self.point_store)?;
        self.point_store.inc_ref(point_idx);
        self.samplers[t].add_point(point_idx);
    }

    // No tree kept the point: drop the reference on the stored point again.
    if !accepted_somewhere {
        self.point_store.dec_ref(point_idx);
    }
    Ok(())
}
/// Reports whether more observations have been seen than the configured
/// `effective_output_after()` plus, with internal shingling, `shingle_size - 1`.
pub fn is_ready(&self) -> bool {
    let output_after = self.config.effective_output_after();
    let shingle_lag = if self.config.internal_shingling {
        self.config.shingle_size.saturating_sub(1)
    } else {
        0
    };
    let needed = output_after + shingle_lag;
    self.entries_seen as usize > needed
}
/// Scores `query` with the standard score mode, averaged over all trees.
///
/// # Errors
/// Returns [`RcfError::DimensionMismatch`] when `query` has an unsupported length.
pub fn score(&self, query: &[f32]) -> Result<f64> {
    self.prepare_query(query)
        .map(|prepared| self.forest_score(&prepared, &ScoreMode::standard()))
}
/// Scores `query` with the displacement score mode, averaged over all trees.
///
/// # Errors
/// Returns [`RcfError::DimensionMismatch`] when `query` has an unsupported length.
pub fn displacement_score(&self, query: &[f32]) -> Result<f64> {
    self.prepare_query(query)
        .map(|prepared| self.forest_score(&prepared, &ScoreMode::displacement()))
}
/// Returns one [`Attribution`] per coordinate of the full (shingled) point,
/// averaged over all trees.
///
/// # Errors
/// Returns [`RcfError::DimensionMismatch`] when `query` has an unsupported length.
pub fn attribution(&self, query: &[f32]) -> Result<Vec<Attribution>> {
    self.attribution_sequential(query)
}
/// Computes the per-coordinate attribution by visiting the trees one after another.
///
/// The per-tree attributions are summed coordinate-wise and then scaled by
/// `1 / number_of_trees`.
fn attribution_sequential(&self, query: &[f32]) -> Result<Vec<Attribution>> {
    let prepared = self.prepare_query(query)?;
    let dim = self.config.dim();
    let mode = ScoreMode::standard();

    let mut totals = vec![Attribution::default(); dim];
    for tree in &self.trees {
        let per_tree = tree.attribution(&prepared, &mode);
        for i in 0..dim {
            totals[i] += per_tree[i];
        }
    }

    let inv_tree_count = 1.0 / self.trees.len() as f64;
    Ok(totals
        .into_iter()
        .map(|total| total.scale(inv_tree_count))
        .collect())
}
/// Returns the mean of the per-tree density values at `query`.
///
/// # Errors
/// Returns [`RcfError::DimensionMismatch`] when `query` has an unsupported length.
pub fn density(&self, query: &[f32]) -> Result<f64> {
    self.density_sequential(query)
}
/// Averages the per-tree density at `query`, visiting the trees one after another.
fn density_sequential(&self, query: &[f32]) -> Result<f64> {
    let prepared = self.prepare_query(query)?;
    let total = self
        .trees
        .iter()
        .map(|tree| tree.density(&prepared, &self.point_store))
        .sum::<f64>();
    Ok(total / self.trees.len() as f64)
}
/// Returns up to `top_k` stored points near `query`, ordered by ascending distance.
///
/// `percentile` is forwarded unchanged to each tree's neighbor search. Candidates
/// found by several trees are merged into a single result.
///
/// # Errors
/// Returns [`RcfError::DimensionMismatch`] when `query` has an unsupported length.
pub fn near_neighbors(
    &self,
    query: &[f32],
    top_k: usize,
    percentile: usize,
) -> Result<Vec<NeighborResult>> {
    let prepared = self.prepare_query(query)?;
    let found =
        self.collect_neighbor_candidates(&prepared, &ScoreMode::standard(), percentile);
    Ok(self.aggregate_neighbor_candidates(found, top_k))
}
/// Fills the coordinates listed in `missing` and returns the completed point.
///
/// Each tree may contribute one candidate point; every missing coordinate is
/// replaced by the median of that coordinate over the candidates. `query` must
/// have the full (shingled) dimension. `centrality` is forwarded unchanged to
/// the trees.
///
/// # Errors
/// * [`RcfError::InvalidArgument`] when `missing` is empty.
/// * [`RcfError::DimensionMismatch`] when `query.len()` differs from the full dimension.
/// * [`RcfError::IndexOutOfBounds`] when a missing index is out of range.
/// * [`RcfError::NotReady`] when no tree produced a candidate.
pub fn impute(&self, query: &[f32], missing: &[usize], centrality: f64) -> Result<Vec<f32>> {
    if missing.is_empty() {
        return Err(RcfError::InvalidArgument("missing list is empty".into()));
    }
    let expected = self.config.dim();
    let got = query.len();
    if got != expected {
        return Err(RcfError::DimensionMismatch { expected, got });
    }
    let flags = make_missing_flags(missing, expected)?;

    // Draw the seed from a clone so the forest's own generator is not advanced.
    let seed = self.rng.clone().next_u64();
    let candidates =
        self.collect_conditional_candidate_indices(query, &flags, centrality, seed);
    if candidates.is_empty() {
        return Err(RcfError::NotReady);
    }

    let mut filled = query.to_vec();
    self.impute_dimensions_from_candidates(&mut filled, missing, &candidates);
    Ok(filled)
}
/// Forecasts `look_ahead` future steps from the current shingle.
///
/// For each step, the coordinates given by `point_store.next_indices(step)` are
/// treated as missing, filled with the per-coordinate median over the trees'
/// candidates, and written back into the working copy of the shingle before the
/// next step. The filled values are returned in the order they were produced.
///
/// # Errors
/// * [`RcfError::InvalidArgument`] when internal shingling is disabled, when
///   `shingle_size <= 1`, or when `look_ahead > shingle_size`.
/// * [`RcfError::IndexOutOfBounds`] when the point store reports an index
///   outside the full dimension.
/// * [`RcfError::NotReady`] when no tree produced a candidate for some step.
pub fn extrapolate(&self, look_ahead: usize) -> Result<Vec<f32>> {
    if !self.config.internal_shingling {
        return Err(RcfError::InvalidArgument(
            "extrapolation requires internal_shingling = true".into(),
        ));
    }
    if self.config.shingle_size <= 1 {
        return Err(RcfError::InvalidArgument(
            "extrapolation requires shingle_size > 1".into(),
        ));
    }
    if look_ahead == 0 {
        return Ok(Vec::new());
    }
    let shingle_size = self.config.shingle_size;
    if look_ahead > shingle_size {
        return Err(RcfError::InvalidArgument(format!(
            "extrapolation requires look_ahead <= shingle_size (got {look_ahead}, shingle_size={})",
            shingle_size
        )));
    }

    let input_dim = self.config.input_dim;
    let dim = self.config.dim();
    let mut window = self.point_store.current_shingled().to_vec();
    let mut forecast = Vec::with_capacity(look_ahead * input_dim);

    // Work on a clone so the forest's generator is untouched; the first draw is
    // discarded and only the following ones are used as per-step seeds.
    let mut seed_rng = self.rng.clone();
    let _ = seed_rng.next_u64();

    for step in 0..look_ahead {
        let targets = self.point_store.next_indices(step);
        let flags = make_missing_flags(&targets, dim)?;
        let step_seed = seed_rng.next_u64();
        let candidates =
            self.collect_conditional_candidate_indices(&window, &flags, 1.0, step_seed);
        if candidates.is_empty() {
            return Err(RcfError::NotReady);
        }
        for &coord in &targets {
            let value = self.median_for_dimension(&candidates, coord);
            window[coord] = value;
            forecast.push(value);
        }
    }
    Ok(forecast)
}
/// Serializes the forest to a JSON string.
///
/// # Errors
/// Returns [`RcfError::Io`] with the serializer's message on failure.
#[cfg(feature = "serde")]
pub fn to_json(&self) -> Result<String> {
    let encoded = serde_json::to_string(self);
    encoded.map_err(|err| RcfError::Io(err.to_string()))
}
/// Rebuilds a forest from a JSON string.
///
/// # Errors
/// Returns [`RcfError::Io`] with the deserializer's message on failure.
#[cfg(feature = "serde")]
pub fn from_json(json: &str) -> Result<Self> {
    let decoded = serde_json::from_str(json);
    decoded.map_err(|err| RcfError::Io(err.to_string()))
}
/// Serializes the forest to JSON and writes it to `path`.
///
/// # Errors
/// Returns [`RcfError::Io`] when serialization or the file write fails.
#[cfg(all(feature = "serde", feature = "std"))]
pub fn save_json(&self, path: impl Into<std::path::PathBuf>) -> Result<()> {
    let payload = self.to_json()?;
    let target: std::path::PathBuf = path.into();
    std::fs::write(target, payload).map_err(|err| RcfError::Io(err.to_string()))
}
/// Reads the file at `path` and rebuilds a forest from its JSON contents.
///
/// # Errors
/// Returns [`RcfError::Io`] when the file cannot be read or parsed.
#[cfg(all(feature = "serde", feature = "std"))]
pub fn load_json(path: impl Into<std::path::PathBuf>) -> Result<Self> {
    let source: std::path::PathBuf = path.into();
    let contents =
        std::fs::read_to_string(source).map_err(|err| RcfError::Io(err.to_string()))?;
    Self::from_json(&contents)
}
/// Number of observations counted by `update` so far.
pub fn entries_seen(&self) -> u64 {
    self.entries_seen
}
/// Number of trees currently held by the forest.
pub fn num_trees(&self) -> usize {
    self.trees.len()
}
/// Borrows the configuration the forest was built with.
pub fn config(&self) -> &RcfConfig {
    &self.config
}
/// Expands `query` to the full (shingled) dimension.
///
/// * With internal shingling, a query of `input_dim` values is written over the
///   last `input_dim` slots of a copy of the current shingle.
/// * A query that already has the full dimension is copied as is.
///
/// # Errors
/// Returns [`RcfError::DimensionMismatch`] (expecting the full dimension) for
/// any other length.
fn prepare_query(&self, query: &[f32]) -> Result<Vec<f32>> {
    let base_dim = self.config.input_dim;
    let full_dim = self.config.dim();
    let got = query.len();

    if self.config.internal_shingling && got == base_dim {
        let mut shingle = self.point_store.current_shingled().to_vec();
        let tail_start = full_dim - base_dim;
        shingle[tail_start..].copy_from_slice(query);
        return Ok(shingle);
    }
    if got == full_dim {
        return Ok(query.to_vec());
    }
    Err(RcfError::DimensionMismatch {
        expected: full_dim,
        got,
    })
}
/// Scores an already prepared `query` under `mode`; currently always delegates
/// to the sequential implementation.
fn forest_score(&self, query: &[f32], mode: &ScoreMode) -> f64 {
    self.forest_score_sequential(query, mode)
}
/// Mean of the trees' raw scores for `query`, or `0.0` when there are no trees.
fn forest_score_sequential(&self, query: &[f32], mode: &ScoreMode) -> f64 {
    let tree_count = self.trees.len();
    if tree_count == 0 {
        return 0.0;
    }
    let total = self
        .trees
        .iter()
        .map(|tree| tree.raw_score(query, &self.point_store, mode))
        .sum::<f64>();
    total / tree_count as f64
}
/// Gathers neighbor candidates from every tree; currently always delegates to
/// the sequential implementation.
fn collect_neighbor_candidates(
    &self,
    query: &[f32],
    mode: &ScoreMode,
    percentile: usize,
) -> Vec<NeighborCandidate> {
    self.collect_neighbor_candidates_sequential(query, mode, percentile)
}
/// Lets each tree, in order, append its neighbor candidates to one shared vector.
fn collect_neighbor_candidates_sequential(
    &self,
    query: &[f32],
    mode: &ScoreMode,
    percentile: usize,
) -> Vec<NeighborCandidate> {
    // Initial room for two candidates per tree; the vector grows if trees add more.
    let mut found = Vec::with_capacity(self.trees.len() * 2);
    self.trees.iter().for_each(|tree| {
        tree.near_neighbors_into(query, &self.point_store, mode, percentile, &mut found);
    });
    found
}
/// Merges per-tree candidates into at most `top_k` results ordered by distance.
///
/// Candidates sharing a `point_idx` are combined: their scores are summed and
/// divided by the number of trees, and the smallest distance is kept. The
/// closest `top_k` points are then returned with a copy of their coordinates.
fn aggregate_neighbor_candidates(
    &self,
    candidates: Vec<NeighborCandidate>,
    top_k: usize,
) -> Vec<NeighborResult> {
    if top_k == 0 || candidates.is_empty() {
        return Vec::new();
    }
    let tree_count = self.trees.len() as f64;

    // point_idx -> (sum of scores, smallest distance seen so far)
    let mut by_point: BTreeMap<usize, (f64, f64)> = BTreeMap::new();
    for candidate in candidates {
        let (score_sum, min_distance) = by_point
            .entry(candidate.point_idx)
            .or_insert((0.0, f64::MAX));
        *score_sum += candidate.score;
        *min_distance = min_distance.min(candidate.distance);
    }

    let mut ranked: Vec<NeighborCandidate> = by_point
        .into_iter()
        .map(|(point_idx, (score_sum, min_distance))| NeighborCandidate {
            score: score_sum / tree_count,
            point_idx,
            distance: min_distance,
        })
        .collect();

    // Keep only the `keep` closest entries before doing the full sort.
    let keep = top_k.min(ranked.len());
    if ranked.len() > keep {
        ranked.select_nth_unstable_by(keep - 1, |a, b| cmp_distance(a.distance, b.distance));
        ranked.truncate(keep);
    }
    ranked.sort_unstable_by(|a, b| cmp_distance(a.distance, b.distance));

    ranked
        .into_iter()
        .map(|entry| NeighborResult {
            score: entry.score,
            point: self.point_store.copy_point(entry.point_idx),
            distance: entry.distance,
        })
        .collect()
}
/// Asks every tree for a candidate point given the known coordinates of `query`.
///
/// A generator seeded with `seed` supplies one sub-seed per tree (drawn whether
/// or not the tree yields a candidate). Returns the point-store indices of the
/// candidates that were produced, in tree order.
fn collect_conditional_candidate_indices(
    &self,
    query: &[f32],
    missing_flags: &[bool],
    centrality: f64,
    seed: u64,
) -> Vec<usize> {
    let mut seed_rng = Xoshiro256PlusPlus::seed_from_u64(seed);
    let mut picked = Vec::new();
    for tree in &self.trees {
        let tree_seed = seed_rng.next_u64();
        let candidate = tree.conditional_field(
            query,
            missing_flags,
            &self.point_store,
            centrality,
            tree_seed,
        );
        if let Some(found) = candidate {
            picked.push(found.point_idx);
        }
    }
    picked
}
/// Overwrites each coordinate listed in `missing` with the median of that
/// coordinate over the candidate points.
fn impute_dimensions_from_candidates(
    &self,
    result: &mut [f32],
    missing: &[usize],
    candidate_idxs: &[usize],
) {
    for &coord in missing {
        let filled = self.median_for_dimension(candidate_idxs, coord);
        result[coord] = filled;
    }
}
/// Median of coordinate `dim_idx` over the stored points named by `candidate_idxs`.
fn median_for_dimension(&self, candidate_idxs: &[usize], dim_idx: usize) -> f32 {
    let mut column: Vec<f32> = Vec::with_capacity(candidate_idxs.len());
    for &point_idx in candidate_idxs {
        column.push(self.point_store.get(point_idx)[dim_idx]);
    }
    median_in_place(&mut column)
}
}
/// Orders two distances ascending, placing NaN after every non-NaN value.
///
/// Two NaNs compare equal; non-NaN values are ordered with `f64::total_cmp`.
fn cmp_distance(a: f64, b: f64) -> core::cmp::Ordering {
    let (a_nan, b_nan) = (a.is_nan(), b.is_nan());
    if a_nan || b_nan {
        // `false < true`, so the NaN side compares greater and two NaNs tie.
        return a_nan.cmp(&b_nan);
    }
    a.total_cmp(&b)
}
/// Step-by-step builder for a [`Forest`].
///
/// Every setter consumes the builder and returns it with the corresponding
/// [`RcfConfig`] `with_*` method applied, so calls can be chained. When no seed
/// is set, [`ForestBuilder::build`] uses [`Forest::from_config`].
#[derive(Clone, Debug)]
pub struct ForestBuilder {
    /// Configuration accumulated so far.
    config: RcfConfig,
    /// Explicit seed, if one was requested.
    seed: Option<u64>,
}

impl ForestBuilder {
    /// Starts from `RcfConfig::new(input_dim)` with no explicit seed.
    pub fn new(input_dim: usize) -> Self {
        let config = RcfConfig::new(input_dim);
        ForestBuilder { config, seed: None }
    }

    /// Sets the shingle size.
    pub fn shingle_size(mut self, n: usize) -> Self {
        self.config = self.config.with_shingle_size(n);
        self
    }

    /// Sets the number of trees.
    pub fn num_trees(mut self, n: usize) -> Self {
        self.config = self.config.with_num_trees(n);
        self
    }

    /// Sets the per-tree sample capacity.
    pub fn capacity(mut self, c: usize) -> Self {
        self.config = self.config.with_capacity(c);
        self
    }

    /// Sets the time-decay parameter.
    pub fn time_decay(mut self, d: f64) -> Self {
        self.config = self.config.with_time_decay(d);
        self
    }

    /// Sets the `output_after` threshold that `Forest::is_ready` checks against.
    pub fn output_after(mut self, n: usize) -> Self {
        self.config = self.config.with_output_after(n);
        self
    }

    /// Enables or disables internal shingling.
    pub fn internal_shingling(mut self, v: bool) -> Self {
        self.config = self.config.with_internal_shingling(v);
        self
    }

    /// Sets the initial accept fraction used while samplers are filling up.
    pub fn initial_accept_fraction(mut self, f: f64) -> Self {
        self.config = self.config.with_initial_accept_fraction(f);
        self
    }

    /// Fixes the seed so that the built forest is reproducible.
    pub fn seed(mut self, s: u64) -> Self {
        self.seed = Some(s);
        self
    }

    /// Validates the configuration and builds the forest.
    ///
    /// # Errors
    /// Propagates the errors of [`Forest::from_config_seeded`] (seed set) or
    /// [`Forest::from_config`] (no seed).
    pub fn build(self) -> Result<Forest> {
        match self.seed {
            Some(s) => Forest::from_config_seeded(&self.config, s),
            None => Forest::from_config(&self.config),
        }
    }
}
// Unit tests for `Forest`, `ForestBuilder` and the median helper.
// All forests are built with fixed seeds so the assertions are reproducible.
#[cfg(test)]
mod tests {
use approx::assert_abs_diff_eq;
use rstest::*;
use super::*;
use crate::score::attribution_total;
// Small 2-D forest without shingling (shingle_size = 1), fixed seed 42.
fn make_forest() -> Forest {
Forest::builder(2)
.shingle_size(1)
.num_trees(10)
.capacity(64)
.output_after(10)
.seed(42)
.build()
.unwrap()
}
// The builder's default configuration uses a shingle size of 1.
#[test]
fn builder_uses_default_shingle_size() {
let f = Forest::builder(2).build().unwrap();
assert_eq!(f.config().shingle_size, 1);
}
// An explicit shingle size is stored in the forest's configuration.
#[test]
fn builder_applies_explicit_shingle_size() {
let f = Forest::builder(2).shingle_size(4).build().unwrap();
assert_eq!(f.config().shingle_size, 4);
}
// A freshly built forest has seen no data and is not ready.
#[test]
fn forest_not_ready_initially() {
let f = make_forest();
assert!(!f.is_ready());
}
// 100 updates exceed `output_after = 10`, so the forest reports ready.
#[test]
fn forest_ready_after_enough_updates() {
let mut f = make_forest();
for i in 0..100 {
f.update(&[i as f32, 0.5]).unwrap();
}
assert!(f.is_ready());
}
// After training on a single repeated point, far-away points must score higher.
#[rstest]
#[case([100.0f32, 100.0])]
#[case([-50.0f32, -50.0])]
#[case([0.0f32, 500.0])]
fn outlier_scores_higher_than_inlier(#[case] outlier: [f32; 2]) {
let mut f = make_forest();
for _ in 0..200 {
f.update(&[0.5f32, 0.5]).unwrap();
}
let inlier = f.score(&[0.5f32, 0.5]).unwrap();
let out = f.score(&outlier).unwrap();
assert!(
out > inlier,
"outlier={out:.4} should be > inlier={inlier:.4}"
);
}
// Scoring against a forest that has seen no data yields 0.
#[test]
fn score_zero_before_ready() {
let f = make_forest();
let s = f.score(&[1.0f32, 2.0]).unwrap();
assert_abs_diff_eq!(s, 0.0, epsilon = 1e-12);
}
// The attribution total must not exceed the score (1% slack) and must be at
// least 5% of it.
#[test]
fn attribution_sums_close_to_score() {
let mut f = make_forest();
for i in 0..200 {
f.update(&[(i % 5) as f32 * 0.1, 0.5]).unwrap();
}
let query = &[5.0f32, 0.5];
let score = f.score(query).unwrap();
let attr = f.attribution(query).unwrap();
let attr_total: f64 = attribution_total(&attr);
let ratio = attr_total / score;
assert!(
(0.05..=1.01).contains(&ratio),
"attr_total={attr_total:.4} score={score:.4} ratio={ratio:.4}"
);
}
// Saving to JSON and loading again must reproduce the same score.
#[test]
#[cfg(all(feature = "serde", feature = "std"))]
fn save_load_roundtrip() {
let mut f = make_forest();
for i in 0..200 {
f.update(&[i as f32 * 0.01, 0.5]).unwrap();
}
let query = &[0.5f32, 0.5];
let score_before = f.score(query).unwrap();
let tmpdir = tempfile::tempdir().unwrap();
let path = tmpdir.path().join("forest.json");
f.save_json(&path).unwrap();
let f2 = Forest::load_json(&path).unwrap();
let score_after = f2.score(query).unwrap();
assert_abs_diff_eq!(score_before, score_after, epsilon = 1e-10);
}
// With internal shingling, single-value updates and queries are accepted.
#[test]
fn shingling_forest_update_and_score() {
let mut f = Forest::builder(1)
.shingle_size(4)
.num_trees(10)
.capacity(64)
.output_after(10)
.internal_shingling(true)
.seed(7)
.build()
.unwrap();
for i in 0..200 {
let v = (i as f32 * 0.1).sin();
f.update(&[v]).unwrap();
}
assert!(f.is_ready());
let s = f.score(&[0.0f32]).unwrap();
assert!(s >= 0.0);
}
// Extrapolation returns `look_ahead * input_dim` finite values.
#[test]
fn extrapolate_returns_expected_length() {
let mut f = Forest::builder(1)
.shingle_size(4)
.num_trees(10)
.capacity(64)
.output_after(10)
.internal_shingling(true)
.seed(17)
.build()
.unwrap();
for i in 0..200 {
let v = (i as f32 * 0.1).sin();
f.update(&[v]).unwrap();
}
let look_ahead = 3;
let out = f.extrapolate(look_ahead).unwrap();
assert_eq!(out.len(), look_ahead * f.config().input_dim);
assert!(out.iter().all(|x| x.is_finite()));
}
// Without internal shingling, extrapolation is rejected with InvalidArgument.
#[test]
fn extrapolate_requires_internal_shingling() {
let mut f = Forest::builder(1)
.shingle_size(4)
.num_trees(10)
.capacity(64)
.output_after(10)
.internal_shingling(false)
.seed(19)
.build()
.unwrap();
for i in 0..200 {
let v = (i as f32 * 0.1).sin();
// External shingling: the caller supplies all four shingle slots.
f.update(&[v, v, v, v]).unwrap();
}
let err = f.extrapolate(1).unwrap_err();
assert!(
matches!(err, RcfError::InvalidArgument(msg) if msg.contains("internal_shingling"))
);
}
// A look-ahead larger than the shingle size (5 > 4) is rejected.
#[test]
fn extrapolate_rejects_look_ahead_beyond_shingle_size() {
let mut f = Forest::builder(1)
.shingle_size(4)
.num_trees(10)
.capacity(64)
.output_after(10)
.internal_shingling(true)
.seed(23)
.build()
.unwrap();
for i in 0..200 {
let v = (i as f32 * 0.1).sin();
f.update(&[v]).unwrap();
}
let err = f.extrapolate(5).unwrap_err();
assert!(matches!(err, RcfError::InvalidArgument(msg) if msg.contains("look_ahead")));
}
// near_neighbors returns at most `top_k` results in ascending distance order.
#[rstest]
#[case::top1(1)]
#[case::top5(5)]
#[case::top7(7)]
#[case::top15(15)]
fn near_neighbors_sorted_and_bounded(#[case] top_k: usize) {
let mut f = make_forest();
for i in 0..300 {
let x = (i as f32 * 0.07).sin();
let y = (i as f32 * 0.11).cos();
f.update(&[x, y]).unwrap();
}
let neighbors = f.near_neighbors(&[0.1, -0.2], top_k, 0).unwrap();
assert!(neighbors.len() <= top_k);
for w in neighbors.windows(2) {
assert!(
w[0].distance <= w[1].distance,
"neighbors are not sorted by distance"
);
}
}
// median_in_place: middle element for odd lengths, mean of the two middle
// elements for even lengths.
#[rstest]
#[case::odd_3(vec![7.0f32, 1.0, 5.0], 5.0f32)]
#[case::even_4(vec![8.0f32, 2.0, 6.0, 4.0], 5.0f32)]
#[case::single(vec![3.0f32], 3.0f32)]
#[case::two(vec![2.0f32, 8.0], 5.0f32)]
fn median_in_place_handles_odd_and_even_lengths(
#[case] mut data: Vec<f32>,
#[case] expected: f32,
) {
let m = median_in_place(&mut data);
assert_abs_diff_eq!(m, expected, epsilon = f32::EPSILON);
}
// Larger forest with internal shingling, used by the end-to-end simulation.
fn make_anomaly_forest() -> Forest {
Forest::builder(2)
.shingle_size(4)
.num_trees(50)
.capacity(512)
.output_after(50)
.internal_shingling(true)
.seed(1234)
.build()
.unwrap()
}
// `n` seeded pseudo-random points within +/-0.15 of (0.5, 0.5) on each axis.
fn normal_cluster_points(n: usize, seed: u64) -> Vec<[f32; 2]> {
let mut rng = SmallRng::seed_from_u64(seed);
(0..n)
.map(|_| {
let dx: f32 = rng.random_range(-0.15f32..0.15);
let dy: f32 = rng.random_range(-0.15f32..0.15);
[0.5 + dx, 0.5 + dy]
})
.collect()
}
// Trains `f` on 250 cluster points and returns the score of the cluster center.
fn warm_up_forest(f: &mut Forest) -> f64 {
for pt in normal_cluster_points(250, 42) {
f.update(&pt).unwrap();
}
assert!(f.is_ready(), "forest must be ready after warm-up");
f.score(&[0.5f32, 0.5]).unwrap()
}
// End-to-end check: an anomaly must score more than twice the cluster center,
// have a positive displacement score, show the expected dominant attribution
// direction (0 = `below`, 1 = `above`) and a larger nearest-neighbor distance.
#[rstest]
#[case::far_positive([10.0f32, 10.0], 0)] #[case::far_negative([-8.0f32, -8.0], 1)] #[case::axis_spike([0.5f32, 15.0], 0)] fn anomaly_detection_simulation(
#[case] anomaly: [f32; 2],
#[case] dominant_direction: usize,
) {
let mut f = make_anomaly_forest();
let normal_score = warm_up_forest(&mut f);
let normal_attr = f.attribution(&[0.5f32, 0.5]).unwrap();
let normal_attr_total: f64 = attribution_total(&normal_attr);
// A zero score would make the ratio undefined; treat it as a ratio of 1.
let normal_ratio = if normal_score > 0.0 {
normal_attr_total / normal_score
} else {
1.0
};
assert!(
normal_ratio <= 1.01,
"attribution total {normal_attr_total:.4} exceeds score {normal_score:.4}"
);
let anomaly_score = f.score(&anomaly).unwrap();
assert!(
anomaly_score > normal_score * 2.0,
"anomaly score {anomaly_score:.4} not > 2× normal {normal_score:.4} \
for point {anomaly:?}"
);
let disp = f.displacement_score(&anomaly).unwrap();
assert!(
disp > 0.0,
"displacement score {disp:.4} should be positive for {anomaly:?}"
);
let attr = f.attribution(&anomaly).unwrap();
let input_dim = f.config().input_dim;
// Only the last `input_dim` entries (the slot a short query is written to) are inspected.
let current_slot = &attr[attr.len() - input_dim..];
let total_dir0: f64 = current_slot.iter().map(|a| a.below).sum();
let total_dir1: f64 = current_slot.iter().map(|a| a.above).sum();
let direction_margin = 1.01;
if dominant_direction == 0 {
assert!(
total_dir0 > total_dir1 * direction_margin,
"expected 'below' direction to dominate for {anomaly:?}: \
dir0={total_dir0:.4} dir1={total_dir1:.4}. attr={attr:?}"
);
} else {
assert!(
total_dir1 > total_dir0 * direction_margin,
"expected 'above' direction to dominate for {anomaly:?}: \
dir1={total_dir1:.4} dir0={total_dir0:.4}. attr={attr:?}"
);
}
let anomaly_neighbors = f.near_neighbors(&anomaly, 5, 0).unwrap();
assert!(
!anomaly_neighbors.is_empty(),
"near_neighbors must return at least 1 result for {anomaly:?}"
);
for w in anomaly_neighbors.windows(2) {
assert!(
w[0].distance <= w[1].distance,
"neighbors not sorted by distance for {anomaly:?}"
);
}
let normal_neighbors = f.near_neighbors(&[0.5f32, 0.5], 5, 0).unwrap();
let normal_nn_dist = normal_neighbors.first().map(|r| r.distance).unwrap_or(0.0);
let anomaly_nn_dist = anomaly_neighbors.first().map(|r| r.distance).unwrap_or(0.0);
assert!(
anomaly_nn_dist > normal_nn_dist,
"anomaly nn-distance {anomaly_nn_dist:.4} should exceed \
normal nn-distance {normal_nn_dist:.4} for {anomaly:?}"
);
}
}