#![cfg(feature = "std")]
use std::sync::Arc;
use crate::config::ForestBuilder;
use crate::domain::{AnomalyScore, DiVector};
use crate::error::RcfResult;
use crate::forest::RandomCutForest;
use crate::metrics::{MetricsSink, default_sink, names};
/// Tuning knobs for the shadow-forest drift recovery performed by `DriftAwareForest`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DriftRecoveryConfig {
/// Number of points a shadow forest must ingest before it replaces the primary.
pub shadow_warmup: u64,
/// Minimum number of points the primary must have ingested before a drift
/// signal is allowed to start a new shadow.
pub min_primary_age: u64,
}
impl Default for DriftRecoveryConfig {
    /// Default recovery policy: a shadow is promoted after ingesting 1024
    /// points, and recovery may only begin once the primary has ingested 512.
    fn default() -> Self {
        let min_primary_age: u64 = 512;
        let shadow_warmup: u64 = 1 << 10;
        Self {
            min_primary_age,
            shadow_warmup,
        }
    }
}
/// A candidate replacement forest that is warming up after drift was signalled.
#[derive(Debug)]
struct ShadowState<const D: usize> {
/// Fresh forest built from the same builder as the primary.
forest: RandomCutForest<D>,
/// Number of points `forest` has successfully ingested since it was created.
seen: u64,
}
/// Wraps a primary `RandomCutForest` with shadow-forest drift recovery.
///
/// Scoring and attribution always use the primary forest. After a successful
/// `on_drift` call, a freshly built shadow forest is fed the same points as the
/// primary. Once it has ingested `shadow_warmup` points, it replaces the primary.
#[derive(Debug)]
pub struct DriftAwareForest<const D: usize> {
/// Forest that serves `score`/`attribution`.
primary: RandomCutForest<D>,
/// Warming replacement; `Some` only while recovery is in progress.
shadow: Option<ShadowState<D>>,
/// Points ingested by `primary`; set to the shadow's count on promotion.
primary_age: u64,
/// Template used to build the primary and every shadow.
builder: ForestBuilder<D>,
/// Recovery policy (warm-up length and minimum primary age).
config: DriftRecoveryConfig,
/// Number of completed shadow-to-primary promotions.
swaps: u64,
/// Destination for drift-recovery counters and gauges.
metrics: Arc<dyn MetricsSink>,
}
impl<const D: usize> DriftAwareForest<D> {
    /// Builds the primary forest from `builder` and returns a wrapper with no
    /// shadow active.
    ///
    /// `builder` is retained so that [`on_drift`](Self::on_drift) can build
    /// shadow forests with an identical configuration. Metrics go to
    /// `default_sink()` until replaced with
    /// [`with_metrics_sink`](Self::with_metrics_sink).
    ///
    /// # Errors
    ///
    /// Returns any error produced while building the primary forest.
    pub fn new(builder: ForestBuilder<D>, config: DriftRecoveryConfig) -> RcfResult<Self> {
        let primary = builder.clone().build()?;
        Ok(Self {
            primary,
            shadow: None,
            primary_age: 0,
            builder,
            config,
            swaps: 0,
            metrics: default_sink(),
        })
    }

    /// Replaces the metrics sink and returns the updated wrapper.
    ///
    /// If a shadow is already active, the shadow-active gauge is published to
    /// the new sink right away. Otherwise the new sink would never observe
    /// `1.0` for the recovery in progress, because that value is only emitted
    /// by `on_drift`.
    #[must_use]
    pub fn with_metrics_sink(mut self, sink: Arc<dyn MetricsSink>) -> Self {
        self.metrics = sink;
        if self.shadow.is_some() {
            self.metrics
                .set_gauge(names::DRIFT_AWARE_SHADOW_ACTIVE, 1.0);
        }
        self
    }

    /// Returns the sink that receives this wrapper's counters and gauges.
    #[must_use]
    pub fn metrics_sink(&self) -> &Arc<dyn MetricsSink> {
        &self.metrics
    }

    /// Returns the primary forest, which serves all scoring and attribution.
    #[must_use]
    pub fn forest(&self) -> &RandomCutForest<D> {
        &self.primary
    }

    /// Returns `true` while a shadow forest is warming up.
    #[must_use]
    pub fn is_recovering(&self) -> bool {
        self.shadow.is_some()
    }

    /// Number of points the active shadow has ingested, or `0` when no shadow
    /// is active.
    #[must_use]
    pub fn shadow_progress(&self) -> u64 {
        self.shadow.as_ref().map_or(0, |s| s.seen)
    }

    /// Number of points ingested by the current primary.
    ///
    /// On promotion this becomes the number of points the shadow had ingested.
    #[must_use]
    pub fn primary_age(&self) -> u64 {
        self.primary_age
    }

    /// Number of times a shadow has been promoted to primary.
    #[must_use]
    pub fn swaps_total(&self) -> u64 {
        self.swaps
    }

    /// Returns a copy of the recovery configuration.
    #[must_use]
    pub fn config(&self) -> DriftRecoveryConfig {
        self.config
    }

    /// Feeds `point` to the primary forest and, while recovering, to the shadow.
    ///
    /// When the shadow's ingest count reaches `config.shadow_warmup`, the shadow
    /// replaces the primary and [`swaps_total`](Self::swaps_total) increases.
    ///
    /// # Errors
    ///
    /// If the primary rejects `point`, its error is returned and neither the age
    /// counter nor the shadow is touched. If the primary accepts `point` but the
    /// shadow fails, the shadow is discarded and the shadow's error is returned.
    /// In that case the primary HAS already ingested `point`, so callers should
    /// not blindly retry the update.
    pub fn update(&mut self, point: [f64; D]) -> RcfResult<()> {
        self.primary.update(point)?;
        self.primary_age = self.primary_age.saturating_add(1);

        let Some(shadow) = self.shadow.as_mut() else {
            return Ok(());
        };
        if let Err(e) = shadow.forest.update(point) {
            // The shadow missed a point the primary saw, so it no longer mirrors
            // the stream; drop it rather than ever promoting it.
            self.abort_shadow();
            return Err(e);
        }
        shadow.seen = shadow.seen.saturating_add(1);
        // Readiness is decided on the borrow we already hold; no second lookup.
        if shadow.seen >= self.config.shadow_warmup {
            self.swap_shadow_into_primary();
        }
        Ok(())
    }

    /// Scores `point` against the primary forest. The shadow is never consulted.
    ///
    /// # Errors
    ///
    /// Propagates any error from the primary forest's `score`.
    pub fn score(&self, point: &[f64; D]) -> RcfResult<AnomalyScore> {
        self.primary.score(point)
    }

    /// Computes the per-dimension attribution of `point` against the primary
    /// forest. The shadow is never consulted.
    ///
    /// # Errors
    ///
    /// Propagates any error from the primary forest's `attribution`.
    pub fn attribution(&self, point: &[f64; D]) -> RcfResult<DiVector> {
        self.primary.attribution(point)
    }

    /// Signals that drift was detected, and starts recovery if allowed.
    ///
    /// Returns `Ok(true)` when a fresh shadow forest was built from the stored
    /// builder. Returns `Ok(false)`, with no state change, when a shadow is
    /// already active or the primary has ingested fewer than
    /// `config.min_primary_age` points.
    ///
    /// # Errors
    ///
    /// Returns any error from building the shadow forest. No state changes and
    /// no metrics are emitted in that case.
    pub fn on_drift(&mut self) -> RcfResult<bool> {
        if self.shadow.is_some() {
            return Ok(false);
        }
        if self.primary_age < self.config.min_primary_age {
            return Ok(false);
        }
        let fresh = self.builder.clone().build()?;
        self.shadow = Some(ShadowState {
            forest: fresh,
            seen: 0,
        });
        self.metrics
            .inc_counter(names::DRIFT_AWARE_ON_DRIFT_TOTAL, 1);
        self.metrics
            .set_gauge(names::DRIFT_AWARE_SHADOW_ACTIVE, 1.0);
        Ok(true)
    }

    /// Discards the active shadow, if any, without promoting it.
    ///
    /// The shadow-active gauge is set to `0.0` even when no shadow was active,
    /// so repeated calls are harmless.
    pub fn abort_shadow(&mut self) {
        self.shadow = None;
        self.metrics
            .set_gauge(names::DRIFT_AWARE_SHADOW_ACTIVE, 0.0);
    }

    /// Promotes the active shadow to primary and adopts its ingest count as the
    /// new primary age. Does nothing when no shadow is active.
    fn swap_shadow_into_primary(&mut self) {
        if let Some(shadow) = self.shadow.take() {
            self.primary = shadow.forest;
            self.primary_age = shadow.seen;
            self.swaps = self.swaps.saturating_add(1);
            self.metrics.inc_counter(names::DRIFT_AWARE_SWAPS_TOTAL, 1);
            self.metrics
                .set_gauge(names::DRIFT_AWARE_SHADOW_ACTIVE, 0.0);
        }
    }
}
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::panic,
    clippy::float_cmp,
    clippy::cast_precision_loss
)]
mod tests {
    use super::*;

    /// Small, seeded builder so tests are fast and deterministic.
    fn small_builder() -> ForestBuilder<2> {
        ForestBuilder::<2>::new()
            .num_trees(50)
            .sample_size(64)
            .seed(2026)
    }

    #[test]
    fn fresh_wrapper_has_no_shadow() {
        let d = DriftAwareForest::new(small_builder(), DriftRecoveryConfig::default()).unwrap();
        assert!(!d.is_recovering());
        assert_eq!(d.shadow_progress(), 0);
        assert_eq!(d.swaps_total(), 0);
    }

    #[test]
    fn on_drift_requires_min_primary_age() {
        let mut d = DriftAwareForest::new(
            small_builder(),
            DriftRecoveryConfig {
                shadow_warmup: 10,
                min_primary_age: 50,
            },
        )
        .unwrap();
        for _ in 0..10 {
            d.update([0.1, 0.2]).unwrap();
        }
        assert!(!d.on_drift().unwrap());
        assert!(!d.is_recovering());
    }

    #[test]
    fn on_drift_spawns_shadow_when_primary_mature() {
        let mut d = DriftAwareForest::new(
            small_builder(),
            DriftRecoveryConfig {
                shadow_warmup: 100,
                min_primary_age: 50,
            },
        )
        .unwrap();
        for _ in 0..60 {
            d.update([0.1, 0.2]).unwrap();
        }
        assert!(d.on_drift().unwrap());
        assert!(d.is_recovering());
        assert_eq!(d.shadow_progress(), 0);
        // A second drift signal while recovering is ignored.
        assert!(!d.on_drift().unwrap());
    }

    #[test]
    fn shadow_promotes_after_warmup() {
        let mut d = DriftAwareForest::new(
            small_builder(),
            DriftRecoveryConfig {
                shadow_warmup: 30,
                min_primary_age: 10,
            },
        )
        .unwrap();
        for _ in 0..20 {
            d.update([0.1, 0.2]).unwrap();
        }
        assert!(d.on_drift().unwrap());
        for i in 0..30 {
            let v = f64::from(i) * 0.01;
            d.update([v, v + 0.5]).unwrap();
        }
        assert!(!d.is_recovering());
        assert_eq!(d.swaps_total(), 1);
        assert_eq!(d.primary_age(), 30);
    }

    #[test]
    fn on_drift_allowed_again_after_promotion() {
        let mut d = DriftAwareForest::new(
            small_builder(),
            DriftRecoveryConfig {
                shadow_warmup: 30,
                min_primary_age: 10,
            },
        )
        .unwrap();
        for _ in 0..20 {
            d.update([0.1, 0.2]).unwrap();
        }
        assert!(d.on_drift().unwrap());
        for i in 0..30 {
            let v = f64::from(i) * 0.01;
            d.update([v, v + 0.5]).unwrap();
        }
        assert_eq!(d.swaps_total(), 1);
        // The promoted primary has age 30 >= min_primary_age, so recovery can restart.
        assert!(d.on_drift().unwrap());
        assert!(d.is_recovering());
    }

    #[test]
    fn abort_shadow_discards_recovery() {
        let mut d = DriftAwareForest::new(
            small_builder(),
            DriftRecoveryConfig {
                shadow_warmup: 100,
                min_primary_age: 10,
            },
        )
        .unwrap();
        for _ in 0..20 {
            d.update([0.1, 0.2]).unwrap();
        }
        assert!(d.on_drift().unwrap());
        assert!(d.is_recovering());
        d.abort_shadow();
        assert!(!d.is_recovering());
        assert_eq!(d.shadow_progress(), 0);
        assert_eq!(d.swaps_total(), 0);
    }

    #[test]
    fn abort_shadow_without_shadow_is_noop() {
        let mut d = DriftAwareForest::new(small_builder(), DriftRecoveryConfig::default()).unwrap();
        d.abort_shadow();
        assert!(!d.is_recovering());
        assert_eq!(d.swaps_total(), 0);
        assert_eq!(d.primary_age(), 0);
    }

    #[test]
    fn score_uses_primary_forest_always() {
        // `min_primary_age` must be reachable so `on_drift` really spawns a
        // shadow. With the default (512) and only 100 updates the call would be
        // a no-op, and the score comparison would pass vacuously.
        let mut d = DriftAwareForest::new(
            small_builder(),
            DriftRecoveryConfig {
                shadow_warmup: 1_000,
                min_primary_age: 50,
            },
        )
        .unwrap();
        for i in 0..100 {
            let v = f64::from(i) * 0.01;
            d.update([v, v + 0.5]).unwrap();
        }
        let s_before: f64 = d.score(&[0.5, 1.0]).unwrap().into();
        assert!(d.on_drift().unwrap());
        assert!(d.is_recovering());
        let s_during: f64 = d.score(&[0.5, 1.0]).unwrap().into();
        assert_eq!(s_before, s_during);
    }
}