use crate::drift::adwin::Adwin;
use crate::drift::{DriftDetector, DriftSignal};
use crate::learner::StreamingLearner;
use irithyll_core::rng::xorshift64;
/// Draws one sample from a Poisson(`lambda`) distribution using
/// Knuth's multiplicative method: multiply uniform variates together
/// until the running product drops to or below e^(-lambda); the number
/// of extra draws needed is Poisson-distributed.
fn poisson(lambda: f64, rng: &mut u64) -> u64 {
    let threshold = (-lambda).exp();
    let mut acc = 1.0f64;
    let mut draws = 0u64;
    loop {
        // Map the raw 64-bit RNG output to a uniform value in [0, 1].
        let uniform = xorshift64(rng) as f64 / u64::MAX as f64;
        acc *= uniform;
        if acc <= threshold {
            return draws;
        }
        draws += 1;
    }
}
/// Validated hyper-parameters for an Adaptive Random Forest.
///
/// Construct via [`ARFConfig::builder`]; `build()` enforces the value
/// ranges, so an `ARFConfig` in hand is internally consistent.
#[derive(Debug, Clone)]
pub struct ARFConfig {
    /// Number of ensemble members (must be >= 1).
    pub n_trees: usize,
    /// Mean of the Poisson distribution used for online bagging.
    pub lambda: f64,
    /// Fraction of features each tree observes; 0.0 is a sentinel
    /// meaning "use sqrt(n_features) features".
    pub feature_fraction: f64,
    /// ADWIN delta for each tree's drift detector, in (0, 1).
    pub drift_delta: f64,
    /// ADWIN delta for each tree's warning detector, in (0, 1).
    pub warning_delta: f64,
    /// Seed for the xorshift64 RNG driving bagging and subspacing.
    pub seed: u64,
}
/// Builder for [`ARFConfig`]; obtained from [`ARFConfig::builder`].
///
/// Fields mirror `ARFConfig` one-to-one and are validated only when
/// `build()` is called.
#[derive(Debug, Clone)]
pub struct ARFConfigBuilder {
    n_trees: usize,
    lambda: f64,
    feature_fraction: f64,
    drift_delta: f64,
    warning_delta: f64,
    seed: u64,
}
impl ARFConfig {
    /// Starts a builder for a forest of `n_trees` members, pre-loaded
    /// with the defaults: lambda 6.0, feature-fraction sentinel 0.0
    /// (sqrt subspacing), drift delta 1e-3, warning delta 1e-2, seed 42.
    pub fn builder(n_trees: usize) -> ARFConfigBuilder {
        ARFConfigBuilder {
            n_trees,
            seed: 42,
            lambda: 6.0,
            drift_delta: 1e-3,
            warning_delta: 1e-2,
            feature_fraction: 0.0,
        }
    }
}
impl ARFConfigBuilder {
    /// Sets the Poisson mean used for online bagging.
    pub fn lambda(mut self, lambda: f64) -> Self {
        self.lambda = lambda;
        self
    }
    /// Sets the per-tree feature fraction (0.0 = sqrt sentinel).
    pub fn feature_fraction(mut self, f: f64) -> Self {
        self.feature_fraction = f;
        self
    }
    /// Sets the ADWIN delta for the drift detector.
    pub fn drift_delta(mut self, d: f64) -> Self {
        self.drift_delta = d;
        self
    }
    /// Sets the ADWIN delta for the warning detector.
    pub fn warning_delta(mut self, d: f64) -> Self {
        self.warning_delta = d;
        self
    }
    /// Sets the RNG seed.
    pub fn seed(mut self, s: u64) -> Self {
        self.seed = s;
        self
    }
    /// Validates every field and produces the final config.
    ///
    /// # Errors
    /// Returns a `ConfigError` when `n_trees` is zero, `lambda` is not
    /// positive and finite, `feature_fraction` is outside [0.0, 1.0],
    /// or either delta is outside the open interval (0, 1). NaN values
    /// are rejected by every floating-point check.
    pub fn build(self) -> Result<ARFConfig, irithyll_core::error::ConfigError> {
        use irithyll_core::error::ConfigError;
        if self.n_trees == 0 {
            return Err(ConfigError::out_of_range(
                "n_trees",
                "must be >= 1",
                self.n_trees,
            ));
        }
        if !self.lambda.is_finite() || self.lambda <= 0.0 {
            return Err(ConfigError::invalid(
                "lambda",
                "must be positive and finite",
            ));
        }
        // `contains` rejects NaN, which the previous
        // `< 0.0 || > 1.0` comparison pair silently accepted.
        if !(0.0..=1.0).contains(&self.feature_fraction) {
            return Err(ConfigError::out_of_range(
                "feature_fraction",
                "must be in [0.0, 1.0]",
                self.feature_fraction,
            ));
        }
        // Negated conjunction so NaN (which fails both comparisons)
        // also lands in the error branch.
        if !(self.drift_delta > 0.0 && self.drift_delta < 1.0) {
            return Err(ConfigError::out_of_range(
                "drift_delta",
                "must be in (0, 1)",
                self.drift_delta,
            ));
        }
        if !(self.warning_delta > 0.0 && self.warning_delta < 1.0) {
            return Err(ConfigError::out_of_range(
                "warning_delta",
                "must be in (0, 1)",
                self.warning_delta,
            ));
        }
        Ok(ARFConfig {
            n_trees: self.n_trees,
            lambda: self.lambda,
            feature_fraction: self.feature_fraction,
            drift_delta: self.drift_delta,
            warning_delta: self.warning_delta,
            seed: self.seed,
        })
    }
}
/// One ensemble member: a base learner plus its drift monitoring state.
struct ARFMember {
    // The underlying streaming model produced by the forest's factory.
    learner: Box<dyn StreamingLearner>,
    // Fires `DriftSignal::Drift` when the 0/1 error stream shifts;
    // triggers full replacement of this member.
    drift_detector: Adwin,
    // More sensitive detector (larger delta); currently observed but
    // its signal is not acted upon.
    warning_detector: Adwin,
    // Sorted feature indices this tree sees; empty until the forest
    // learns the dimensionality from the first sample.
    feature_mask: Vec<usize>,
    // Prequential (test-then-train) accuracy counters.
    n_correct: u64,
    n_evaluated: u64,
}
/// Streaming ensemble classifier: Poisson online bagging over random
/// feature subspaces, with per-tree ADWIN drift detection and
/// drift-triggered tree replacement.
pub struct AdaptiveRandomForest {
    config: ARFConfig,
    trees: Vec<ARFMember>,
    // 0 until the first training sample fixes the dimensionality.
    n_features: usize,
    n_samples: u64,
    // Total tree replacements triggered by drift signals.
    n_drifts: usize,
    // xorshift64 state shared by bagging and subspace sampling.
    rng_state: u64,
    // Produces fresh base learners for construction and replacement.
    factory: Box<dyn Fn() -> Box<dyn StreamingLearner> + Send + Sync>,
}
impl AdaptiveRandomForest {
    /// Builds a forest of `config.n_trees` fresh learners produced by
    /// `factory`, each paired with its own ADWIN drift and warning
    /// detectors. Feature subspaces are sampled lazily on the first
    /// call to [`AdaptiveRandomForest::train_one`], once the feature
    /// count is known.
    pub fn new<F>(config: ARFConfig, factory: F) -> Self
    where
        F: Fn() -> Box<dyn StreamingLearner> + Send + Sync + 'static,
    {
        let mut rng = config.seed;
        let trees: Vec<ARFMember> = (0..config.n_trees)
            .map(|_| {
                // Advance the RNG once per member so the downstream
                // stream of draws depends on the ensemble size.
                let _ = xorshift64(&mut rng);
                ARFMember {
                    learner: factory(),
                    drift_detector: Adwin::with_delta(config.drift_delta),
                    warning_detector: Adwin::with_delta(config.warning_delta),
                    feature_mask: Vec::new(),
                    n_correct: 0,
                    n_evaluated: 0,
                }
            })
            .collect();
        Self {
            config,
            trees,
            n_features: 0,
            n_samples: 0,
            n_drifts: 0,
            rng_state: rng,
            factory: Box::new(factory),
        }
    }
    /// Size of each tree's feature subspace: `ceil(fraction * d)`
    /// clamped to [1, d], where `fraction` falls back to sqrt(d)/d
    /// when `feature_fraction` is the 0.0 sentinel. Consumes no RNG.
    fn subspace_size(&self) -> usize {
        let d = self.n_features;
        let fraction = if self.config.feature_fraction == 0.0 {
            (d as f64).sqrt() / d as f64
        } else {
            self.config.feature_fraction
        };
        ((fraction * d as f64).ceil() as usize).max(1).min(d)
    }
    /// Draws a sorted random subset of feature indices via a partial
    /// Fisher-Yates shuffle. Shared by initial mask assignment and by
    /// tree replacement after drift (previously duplicated inline).
    fn sample_feature_mask(&mut self) -> Vec<usize> {
        let d = self.n_features;
        let k = self.subspace_size();
        let mut indices: Vec<usize> = (0..d).collect();
        for i in 0..k {
            // i < k <= d, so d - i >= 1 and the modulus is safe.
            let j = i + (xorshift64(&mut self.rng_state) as usize % (d - i));
            indices.swap(i, j);
        }
        indices.truncate(k);
        indices.sort_unstable();
        indices
    }
    /// Assigns every tree a freshly sampled feature subspace.
    fn init_feature_masks(&mut self) {
        for i in 0..self.trees.len() {
            let mask = self.sample_feature_mask();
            self.trees[i].feature_mask = mask;
        }
    }
    /// Projects `features` onto `mask`; an empty mask (forest not yet
    /// initialised) passes the full feature vector through.
    fn mask_features(&self, features: &[f64], mask: &[usize]) -> Vec<f64> {
        if mask.is_empty() {
            features.to_vec()
        } else {
            mask.iter().map(|&i| features[i]).collect()
        }
    }
    /// Feeds one labelled sample through the ensemble: prequential
    /// accuracy bookkeeping, Poisson online bagging, drift monitoring,
    /// and tree replacement when a member's ADWIN detector fires.
    pub fn train_one(&mut self, features: &[f64], target: f64) {
        if self.n_features == 0 {
            // The first sample fixes the dimensionality and subspaces.
            self.n_features = features.len();
            self.init_feature_masks();
        }
        self.n_samples += 1;
        for i in 0..self.trees.len() {
            // Online bagging: each tree trains on the sample k times,
            // with k ~ Poisson(lambda).
            let k = poisson(self.config.lambda, &mut self.rng_state);
            let masked = self.mask_features(features, &self.trees[i].feature_mask);
            // Test-then-train: evaluate before the sample is learned.
            let pred = self.trees[i].learner.predict(&masked);
            let correct = (pred.round() - target).abs() < 0.5;
            self.trees[i].n_evaluated += 1;
            if correct {
                self.trees[i].n_correct += 1;
            }
            for _ in 0..k {
                self.trees[i].learner.train(&masked, target);
            }
            // Both detectors watch the 0/1 error stream; only the
            // drift detector currently triggers replacement.
            let error_val = if correct { 0.0 } else { 1.0 };
            let drift_signal = self.trees[i].drift_detector.update(error_val);
            let _warning_signal = self.trees[i].warning_detector.update(error_val);
            if matches!(drift_signal, DriftSignal::Drift) {
                self.replace_tree(i);
            }
        }
    }
    /// Swaps tree `i` for a brand-new learner with fresh detectors,
    /// zeroed accuracy counters, and a newly sampled feature subspace.
    fn replace_tree(&mut self, i: usize) {
        self.trees[i].learner = (self.factory)();
        self.trees[i].drift_detector = Adwin::with_delta(self.config.drift_delta);
        self.trees[i].warning_detector = Adwin::with_delta(self.config.warning_delta);
        self.trees[i].n_correct = 0;
        self.trees[i].n_evaluated = 0;
        self.n_drifts += 1;
        let mask = self.sample_feature_mask();
        self.trees[i].feature_mask = mask;
    }
    /// Majority-vote prediction over all trees; 0.0 for an empty vote.
    pub fn predict(&self, features: &[f64]) -> f64 {
        let votes = self.predict_votes(features);
        votes
            .into_iter()
            .max_by_key(|&(_, count)| count)
            .map(|(class, _)| class)
            .unwrap_or(0.0)
    }
    /// Tallies each tree's rounded prediction into `(class, count)`
    /// pairs. A linear scan stands in for a hash map because f64 keys
    /// are not hashable and ensembles are small.
    pub fn predict_votes(&self, features: &[f64]) -> Vec<(f64, u64)> {
        let mut vote_map: Vec<(f64, u64)> = Vec::new();
        for member in &self.trees {
            let masked = self.mask_features(features, &member.feature_mask);
            let pred = member.learner.predict(&masked).round();
            if let Some(entry) = vote_map.iter_mut().find(|(c, _)| (*c - pred).abs() < 0.5) {
                entry.1 += 1;
            } else {
                vote_map.push((pred, 1));
            }
        }
        vote_map
    }
    /// Configured ensemble size.
    pub fn n_trees(&self) -> usize {
        self.config.n_trees
    }
    /// Total samples seen via `train_one`.
    pub fn n_samples_seen(&self) -> u64 {
        self.n_samples
    }
    /// Per-tree prequential accuracy (0.0 for a tree never evaluated).
    pub fn tree_accuracies(&self) -> Vec<f64> {
        self.trees
            .iter()
            .map(|m| {
                if m.n_evaluated == 0 {
                    0.0
                } else {
                    m.n_correct as f64 / m.n_evaluated as f64
                }
            })
            .collect()
    }
    /// Number of tree replacements triggered by drift so far.
    pub fn n_drifts_detected(&self) -> usize {
        self.n_drifts
    }
}
/// [`StreamingLearner`] adapter so the forest can be used wherever a
/// single streaming learner is expected.
impl StreamingLearner for AdaptiveRandomForest {
    // Sample weights are not used by the ensemble; `_weight` is
    // accepted for trait compatibility and ignored.
    fn train_one(&mut self, features: &[f64], target: f64, _weight: f64) {
        // Resolves to the inherent two-argument `train_one`.
        self.train_one(features, target);
    }
    fn predict(&self, features: &[f64]) -> f64 {
        self.predict(features)
    }
    fn n_samples_seen(&self) -> u64 {
        self.n_samples
    }
    // Resets counters only; trained trees, detectors, and feature
    // masks are kept. NOTE(review): confirm that this partial reset is
    // the intended `reset` contract for the trait.
    fn reset(&mut self) {
        self.n_samples = 0;
        self.n_drifts = 0;
        for member in &mut self.trees {
            member.n_correct = 0;
            member.n_evaluated = 0;
        }
    }
    // Forwards to the `Structural` implementation; `allow(deprecated)`
    // presumably because this trait method is deprecated in favor of
    // `Structural::replacement_count` — verify against the trait def.
    #[allow(deprecated)]
    fn replacement_count(&self) -> u64 {
        <Self as crate::learner::Structural>::replacement_count(self)
    }
}
/// Structural-change hooks: the forest manages its own structure via
/// drift-triggered replacement, so external deltas are ignored.
impl crate::learner::Structural for AdaptiveRandomForest {
    // Intentional no-op: replacements are driven by ADWIN signals.
    fn apply_structural_change(&mut self, _depth_delta: i32, _steps_delta: i32) {
    }
    /// Each drift-triggered tree replacement counts as one structural
    /// replacement.
    fn replacement_count(&self) -> u64 {
        self.n_drifts as u64
    }
}
/// AutoML diagnostic hook; the forest currently exposes no
/// configuration diagnostics.
impl crate::automl::DiagnosticSource for AdaptiveRandomForest {
    fn config_diagnostics(&self) -> Option<crate::automl::ConfigDiagnostics> {
        None
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Base learner stub: always predicts a fixed class and counts
    // how many times it was trained.
    struct MockClassifier {
        prediction: f64,
        n: u64,
    }
    impl MockClassifier {
        fn new(prediction: f64) -> Self {
            Self { prediction, n: 0 }
        }
    }
    impl StreamingLearner for MockClassifier {
        fn train_one(&mut self, _features: &[f64], _target: f64, _weight: f64) {
            self.n += 1;
        }
        fn predict(&self, _features: &[f64]) -> f64 {
            self.prediction
        }
        fn n_samples_seen(&self) -> u64 {
            self.n
        }
        fn reset(&mut self) {
            self.n = 0;
        }
    }
    // Smoke test: one sample trains, prediction follows the mocks.
    #[test]
    fn arf_trains_and_predicts() {
        let config = ARFConfig::builder(3).seed(42).build().unwrap();
        let mut arf = AdaptiveRandomForest::new(config, || Box::new(MockClassifier::new(1.0)));
        arf.train_one(&[1.0, 2.0], 1.0);
        let pred = arf.predict(&[1.0, 2.0]);
        assert_eq!(pred, 1.0);
        assert_eq!(arf.n_samples_seen(), 1);
    }
    // All five mocks vote for class 0, so there is a single unanimous
    // vote bucket. Internal fields are set directly to skip training.
    #[test]
    fn arf_majority_vote() {
        let config = ARFConfig::builder(5).seed(42).build().unwrap();
        let mut arf = AdaptiveRandomForest::new(config, || Box::new(MockClassifier::new(0.0)));
        arf.n_features = 2;
        arf.init_feature_masks();
        let votes = arf.predict_votes(&[1.0, 2.0]);
        assert_eq!(votes.len(), 1, "all trees should agree");
        assert_eq!(votes[0], (0.0, 5), "5 votes for class 0");
        assert_eq!(arf.predict(&[1.0, 2.0]), 0.0);
    }
    // Statistical check: the sample mean of 1000 Poisson(6) draws
    // should land near 6 (loose +/- 1.0 tolerance).
    #[test]
    fn arf_poisson_valid() {
        let mut rng = 12345u64;
        let mut total = 0u64;
        let n = 1000;
        for _ in 0..n {
            total += poisson(6.0, &mut rng);
        }
        let mean = total as f64 / n as f64;
        assert!(
            (mean - 6.0).abs() < 1.0,
            "Poisson mean should be ~6.0, got {}",
            mean
        );
    }
    // feature_fraction 0.5 over 4 features => ceil(2.0) = 2 indices
    // in every tree's mask.
    #[test]
    fn arf_feature_subspace() {
        let config = ARFConfig::builder(3)
            .feature_fraction(0.5)
            .seed(42)
            .build()
            .unwrap();
        let mut arf = AdaptiveRandomForest::new(config, || Box::new(MockClassifier::new(0.0)));
        arf.train_one(&[1.0, 2.0, 3.0, 4.0], 0.0);
        for member in &arf.trees {
            assert_eq!(
                member.feature_mask.len(),
                2,
                "expected 2 features, got {}",
                member.feature_mask.len()
            );
        }
    }
    // The forest is usable through the StreamingLearner trait object.
    #[test]
    fn arf_streaming_learner_trait() {
        let config = ARFConfig::builder(3).seed(42).build().unwrap();
        let mut arf = AdaptiveRandomForest::new(config, || Box::new(MockClassifier::new(0.0)));
        let learner: &mut dyn StreamingLearner = &mut arf;
        learner.train(&[1.0, 2.0], 0.0);
        assert_eq!(learner.n_samples_seen(), 1);
        let pred = learner.predict(&[1.0, 2.0]);
        assert_eq!(pred, 0.0);
    }
    // Builder validation: each out-of-range field is rejected and the
    // defaults build cleanly.
    #[test]
    fn arf_config_validates() {
        assert!(ARFConfig::builder(0).build().is_err());
        assert!(ARFConfig::builder(3).lambda(0.0).build().is_err());
        assert!(ARFConfig::builder(3).lambda(-1.0).build().is_err());
        assert!(ARFConfig::builder(3)
            .feature_fraction(-0.1)
            .build()
            .is_err());
        assert!(ARFConfig::builder(3).feature_fraction(1.1).build().is_err());
        assert!(ARFConfig::builder(3).drift_delta(0.0).build().is_err());
        assert!(ARFConfig::builder(3).drift_delta(1.0).build().is_err());
        assert!(ARFConfig::builder(3).build().is_ok());
    }
    // With mocks that always predict the true label, prequential
    // accuracy should be ~1.0 for every tree.
    #[test]
    fn arf_tree_accuracies() {
        let config = ARFConfig::builder(3).seed(42).build().unwrap();
        let mut arf = AdaptiveRandomForest::new(config, || Box::new(MockClassifier::new(1.0)));
        for _ in 0..10 {
            arf.train_one(&[1.0, 2.0], 1.0);
        }
        let accs = arf.tree_accuracies();
        assert_eq!(accs.len(), 3);
        for &acc in &accs {
            assert!(
                acc > 0.9,
                "accuracy should be ~1.0 for correct mock, got {}",
                acc
            );
        }
    }
}