// irithyll_core/ensemble/adaptive_forest.rs
//! Adaptive Random Forest (ARF) for streaming classification.
//!
//! An ensemble of streaming learners with ADWIN-based drift detection and
//! automatic tree replacement. Each member trains on a Poisson(lambda)-weighted
//! bootstrap of the stream with a random feature subspace.
//!
//! # Algorithm
//!
//! For each new sample *(x, y)* and each tree *t*:
//!
//! 1. Draw *k ~ Poisson(lambda)* -- the bootstrap weight.
//! 2. Predict *y_hat = tree_t(x\[mask_t\])* before training (for drift detection).
//! 3. Train *k* times on the masked feature vector.
//! 4. Feed correctness (0.0 = correct, 1.0 = incorrect) to the ADWIN detector.
//! 5. On drift: reset the tree and detector, allowing a fresh learner to adapt
//!    to the new distribution.
//!
//! Final prediction is majority vote across all trees.
//!
//! # Reference
//!
//! Gomes, H. M., et al. (2017). "Adaptive random forests for evolving data
//! stream classification." *Machine Learning*, 106(9-10), 1469-1495.

use alloc::boxed::Box;
use alloc::string::String;
use alloc::vec::Vec;

use crate::drift::adwin::Adwin;
use crate::drift::{DriftDetector, DriftSignal};
use crate::learner::StreamingLearner;

use crate::rng::xorshift64;

// ---------------------------------------------------------------------------
// Utilities
// ---------------------------------------------------------------------------

39fn poisson(lambda: f64, rng: &mut u64) -> u64 {
40    let l = crate::math::exp(-lambda);
41    let mut k = 0u64;
42    let mut p = 1.0f64;
43    loop {
44        k += 1;
45        let u = xorshift64(rng) as f64 / u64::MAX as f64;
46        p *= u;
47        if p <= l {
48            return k - 1;
49        }
50    }
51}

// ---------------------------------------------------------------------------
// ARFConfig
// ---------------------------------------------------------------------------

/// Configuration for [`AdaptiveRandomForest`].
#[derive(Debug, Clone)]
pub struct ARFConfig {
    /// Number of trees in the forest.
    pub n_trees: usize,
    /// Poisson lambda for bootstrap resampling (default 6.0).
    pub lambda: f64,
    /// Fraction of features per tree. 0.0 = auto (sqrt(d)/d).
    pub feature_fraction: f64,
    /// ADWIN delta for drift detection (default 1e-3).
    pub drift_delta: f64,
    /// ADWIN delta for warning detection (default 1e-2).
    pub warning_delta: f64,
    /// Random seed.
    pub seed: u64,
}

/// Builder for [`ARFConfig`].
#[derive(Debug, Clone)]
pub struct ARFConfigBuilder {
    n_trees: usize,
    lambda: f64,
    feature_fraction: f64,
    drift_delta: f64,
    warning_delta: f64,
    seed: u64,
}

impl ARFConfig {
    /// Start building a configuration with the given number of trees.
    ///
    /// All other parameters start at their documented defaults.
    pub fn builder(n_trees: usize) -> ARFConfigBuilder {
        ARFConfigBuilder {
            n_trees,
            lambda: 6.0,
            feature_fraction: 0.0,
            drift_delta: 1e-3,
            warning_delta: 1e-2,
            seed: 42,
        }
    }
}

impl ARFConfigBuilder {
    /// Poisson lambda for bootstrap resampling.
    pub fn lambda(mut self, lambda: f64) -> Self {
        self.lambda = lambda;
        self
    }

    /// Fraction of features per tree. 0.0 = auto (sqrt).
    pub fn feature_fraction(mut self, f: f64) -> Self {
        self.feature_fraction = f;
        self
    }

    /// ADWIN delta for drift detection.
    pub fn drift_delta(mut self, d: f64) -> Self {
        self.drift_delta = d;
        self
    }

    /// ADWIN delta for warning detection.
    pub fn warning_delta(mut self, d: f64) -> Self {
        self.warning_delta = d;
        self
    }

    /// Random seed.
    pub fn seed(mut self, s: u64) -> Self {
        self.seed = s;
        self
    }

    /// Build the configuration, validating all parameters.
    ///
    /// # Errors
    ///
    /// Returns a message describing the first invalid parameter. NaN values
    /// are rejected: the checks below are written so that a NaN fails them
    /// (the previous `x < lo || x > hi` form was false for NaN and silently
    /// accepted non-finite deltas and fractions).
    pub fn build(self) -> Result<ARFConfig, String> {
        if self.n_trees == 0 {
            return Err("n_trees must be >= 1".into());
        }
        if self.lambda <= 0.0 || !self.lambda.is_finite() {
            return Err("lambda must be positive and finite".into());
        }
        // `RangeInclusive::contains` is false for NaN, so NaN is rejected here.
        if !(0.0..=1.0).contains(&self.feature_fraction) {
            return Err("feature_fraction must be in [0.0, 1.0]".into());
        }
        // `!(a && b)` form: any comparison involving NaN is false, so NaN errors.
        if !(self.drift_delta > 0.0 && self.drift_delta < 1.0) {
            return Err("drift_delta must be in (0, 1)".into());
        }
        if !(self.warning_delta > 0.0 && self.warning_delta < 1.0) {
            return Err("warning_delta must be in (0, 1)".into());
        }
        Ok(ARFConfig {
            n_trees: self.n_trees,
            lambda: self.lambda,
            feature_fraction: self.feature_fraction,
            drift_delta: self.drift_delta,
            warning_delta: self.warning_delta,
            seed: self.seed,
        })
    }
}

// ---------------------------------------------------------------------------
// ARFMember
// ---------------------------------------------------------------------------

/// One member of the forest: a learner plus its drift-detection bookkeeping.
struct ARFMember {
    // The wrapped streaming learner; replaced wholesale when drift fires.
    learner: Box<dyn StreamingLearner>,
    // ADWIN fed the 0/1 error stream; a Drift signal triggers tree replacement.
    drift_detector: Adwin,
    // Second ADWIN with a larger delta (more sensitive); its signal is
    // currently observed but not acted upon.
    warning_detector: Adwin,
    // Sorted feature indices this tree sees; empty until lazily initialized.
    feature_mask: Vec<usize>,
    // Correct pre-training predictions since creation/last replacement.
    n_correct: u64,
    // Predictions evaluated since creation/last replacement.
    n_evaluated: u64,
}

// ---------------------------------------------------------------------------
// AdaptiveRandomForest
// ---------------------------------------------------------------------------

/// Adaptive Random Forest for streaming classification.
///
/// An ensemble of `n_trees` streaming learners, each trained on a
/// Poisson-weighted bootstrap with a random feature subspace. ADWIN drift
/// detection automatically resets individual trees when their error rate
/// changes significantly.
///
/// # Example
///
/// ```text
/// use irithyll::{AdaptiveRandomForest, StreamingLearner};
/// use irithyll::ensemble::adaptive_forest::ARFConfig;
/// use irithyll::learners::linear::StreamingLinearModel;
///
/// let config = ARFConfig::builder(5).lambda(6.0).build().unwrap();
/// let mut arf = AdaptiveRandomForest::new(config, || {
///     Box::new(StreamingLinearModel::new(0.01))
/// });
///
/// arf.train_one(&[1.0, 0.0], 0.0);
/// let pred = arf.predict(&[1.0, 0.0]);
/// ```
pub struct AdaptiveRandomForest {
    // Validated ensemble configuration.
    config: ARFConfig,
    // Ensemble members (learner + detectors + feature mask).
    trees: Vec<ARFMember>,
    // Input dimensionality; 0 until the first training sample fixes it.
    n_features: usize,
    // Total samples processed through train_one.
    n_samples: u64,
    // Total drift-triggered tree replacements.
    n_drifts: usize,
    // Xorshift64 state shared by bootstrap and subspace sampling.
    rng_state: u64,
    /// Stored factory for creating replacement learners on drift.
    factory: Box<dyn Fn() -> Box<dyn StreamingLearner> + Send + Sync>,
}

208impl AdaptiveRandomForest {
209    /// Create a new ARF with the given config and learner factory.
210    ///
211    /// The `factory` closure is called `n_trees` times to create the initial
212    /// ensemble, and again whenever a tree is replaced after drift detection.
213    pub fn new<F>(config: ARFConfig, factory: F) -> Self
214    where
215        F: Fn() -> Box<dyn StreamingLearner> + Send + Sync + 'static,
216    {
217        let mut rng = config.seed;
218        let trees: Vec<ARFMember> = (0..config.n_trees)
219            .map(|_| {
220                // Feature masks are initialized lazily on first train_one
221                let _ = xorshift64(&mut rng);
222                ARFMember {
223                    learner: factory(),
224                    drift_detector: Adwin::with_delta(config.drift_delta),
225                    warning_detector: Adwin::with_delta(config.warning_delta),
226                    feature_mask: Vec::new(),
227                    n_correct: 0,
228                    n_evaluated: 0,
229                }
230            })
231            .collect();
232
233        Self {
234            config,
235            trees,
236            n_features: 0,
237            n_samples: 0,
238            n_drifts: 0,
239            rng_state: rng,
240            factory: Box::new(factory),
241        }
242    }
243
244    /// Initialize feature masks for all trees (called lazily on first sample).
245    fn init_feature_masks(&mut self) {
246        let d = self.n_features;
247        let fraction = if self.config.feature_fraction == 0.0 {
248            crate::math::sqrt(d as f64) / d as f64
249        } else {
250            self.config.feature_fraction
251        };
252        let k = (crate::math::ceil(fraction * d as f64) as usize)
253            .max(1)
254            .min(d);
255
256        for member in &mut self.trees {
257            // Fisher-Yates partial shuffle to select k unique indices
258            let mut indices: Vec<usize> = (0..d).collect();
259            for i in 0..k {
260                let j = i + (xorshift64(&mut self.rng_state) as usize % (d - i));
261                indices.swap(i, j);
262            }
263            indices.truncate(k);
264            indices.sort_unstable();
265            member.feature_mask = indices;
266        }
267    }
268
269    /// Extract masked features for a given tree member.
270    fn mask_features(&self, features: &[f64], mask: &[usize]) -> Vec<f64> {
271        if mask.is_empty() {
272            features.to_vec()
273        } else {
274            mask.iter().map(|&i| features[i]).collect()
275        }
276    }
277
278    /// Train on a single sample.
279    pub fn train_one(&mut self, features: &[f64], target: f64) {
280        if self.n_features == 0 {
281            self.n_features = features.len();
282            self.init_feature_masks();
283        }
284        self.n_samples += 1;
285
286        for i in 0..self.trees.len() {
287            let k = poisson(self.config.lambda, &mut self.rng_state);
288            let masked = self.mask_features(features, &self.trees[i].feature_mask);
289
290            // Predict before training (for drift detection).
291            let pred = self.trees[i].learner.predict(&masked);
292            let correct = crate::math::abs(crate::math::round(pred) - target) < 0.5;
293            self.trees[i].n_evaluated += 1;
294            if correct {
295                self.trees[i].n_correct += 1;
296            }
297
298            // Train k times (Poisson-weighted bootstrap).
299            for _ in 0..k {
300                self.trees[i].learner.train(&masked, target);
301            }
302
303            // Feed error signal to drift detectors.
304            let error_val = if correct { 0.0 } else { 1.0 };
305            let drift_signal = self.trees[i].drift_detector.update(error_val);
306            let _warning_signal = self.trees[i].warning_detector.update(error_val);
307
308            // On drift: replace the tree.
309            if matches!(drift_signal, DriftSignal::Drift) {
310                self.trees[i].learner = (self.factory)();
311                self.trees[i].drift_detector = Adwin::with_delta(self.config.drift_delta);
312                self.trees[i].warning_detector = Adwin::with_delta(self.config.warning_delta);
313                self.trees[i].n_correct = 0;
314                self.trees[i].n_evaluated = 0;
315                self.n_drifts += 1;
316
317                // Re-init feature mask for the new tree.
318                let d = self.n_features;
319                let fraction = if self.config.feature_fraction == 0.0 {
320                    crate::math::sqrt(d as f64) / d as f64
321                } else {
322                    self.config.feature_fraction
323                };
324                let k_features = (crate::math::ceil(fraction * d as f64) as usize)
325                    .max(1)
326                    .min(d);
327                let mut indices: Vec<usize> = (0..d).collect();
328                for j in 0..k_features {
329                    let swap = j + (xorshift64(&mut self.rng_state) as usize % (d - j));
330                    indices.swap(j, swap);
331                }
332                indices.truncate(k_features);
333                indices.sort_unstable();
334                self.trees[i].feature_mask = indices;
335            }
336        }
337    }
338
339    /// Predict by majority vote across all trees.
340    ///
341    /// Each tree casts a vote for a class (prediction rounded to nearest
342    /// integer). The class with the most votes wins.
343    pub fn predict(&self, features: &[f64]) -> f64 {
344        let votes = self.predict_votes(features);
345        votes
346            .into_iter()
347            .max_by_key(|&(_, count)| count)
348            .map(|(class, _)| class)
349            .unwrap_or(0.0)
350    }
351
352    /// Vote counts per predicted class.
353    pub fn predict_votes(&self, features: &[f64]) -> Vec<(f64, u64)> {
354        let mut vote_map: Vec<(f64, u64)> = Vec::new();
355        for member in &self.trees {
356            let masked = self.mask_features(features, &member.feature_mask);
357            let pred = crate::math::round(member.learner.predict(&masked));
358            if let Some(entry) = vote_map
359                .iter_mut()
360                .find(|(c, _)| crate::math::abs(*c - pred) < 0.5)
361            {
362                entry.1 += 1;
363            } else {
364                vote_map.push((pred, 1));
365            }
366        }
367        vote_map
368    }
369
370    /// Number of trees in the ensemble.
371    pub fn n_trees(&self) -> usize {
372        self.config.n_trees
373    }
374
375    /// Total samples processed.
376    pub fn n_samples_seen(&self) -> u64 {
377        self.n_samples
378    }
379
380    /// Per-tree accuracy (correct / evaluated).
381    pub fn tree_accuracies(&self) -> Vec<f64> {
382        self.trees
383            .iter()
384            .map(|m| {
385                if m.n_evaluated == 0 {
386                    0.0
387                } else {
388                    m.n_correct as f64 / m.n_evaluated as f64
389                }
390            })
391            .collect()
392    }
393
394    /// Total number of drift-triggered tree replacements.
395    pub fn n_drifts_detected(&self) -> usize {
396        self.n_drifts
397    }
398}
399
400impl StreamingLearner for AdaptiveRandomForest {
401    fn train_one(&mut self, features: &[f64], target: f64, _weight: f64) {
402        self.train_one(features, target);
403    }
404
405    fn predict(&self, features: &[f64]) -> f64 {
406        self.predict(features)
407    }
408
409    fn n_samples_seen(&self) -> u64 {
410        self.n_samples
411    }
412
413    fn reset(&mut self) {
414        self.n_samples = 0;
415        self.n_drifts = 0;
416        for member in &mut self.trees {
417            member.n_correct = 0;
418            member.n_evaluated = 0;
419        }
420    }
421}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use alloc::boxed::Box;

    /// Deterministic stub learner: always predicts a fixed value and counts
    /// how many times it was trained.
    struct MockClassifier {
        prediction: f64,
        n: u64,
    }

    impl MockClassifier {
        fn new(prediction: f64) -> Self {
            Self { prediction, n: 0 }
        }
    }

    impl StreamingLearner for MockClassifier {
        fn train_one(&mut self, _features: &[f64], _target: f64, _weight: f64) {
            self.n += 1;
        }
        fn predict(&self, _features: &[f64]) -> f64 {
            self.prediction
        }
        fn n_samples_seen(&self) -> u64 {
            self.n
        }
        fn reset(&mut self) {
            self.n = 0;
        }
    }

    #[test]
    fn arf_trains_and_predicts() {
        let config = ARFConfig::builder(3).seed(42).build().unwrap();
        let mut arf = AdaptiveRandomForest::new(config, || Box::new(MockClassifier::new(1.0)));

        arf.train_one(&[1.0, 2.0], 1.0);
        let pred = arf.predict(&[1.0, 2.0]);
        assert_eq!(pred, 1.0);
        assert_eq!(arf.n_samples_seen(), 1);
    }

    #[test]
    fn arf_majority_vote() {
        let config = ARFConfig::builder(5).seed(42).build().unwrap();
        // All trees predict 0.0 → unanimous vote
        let mut arf = AdaptiveRandomForest::new(config, || Box::new(MockClassifier::new(0.0)));
        // Need to init feature masks (normally done lazily by train_one)
        arf.n_features = 2;
        arf.init_feature_masks();

        let votes = arf.predict_votes(&[1.0, 2.0]);
        assert_eq!(votes.len(), 1, "all trees should agree");
        assert_eq!(votes[0], (0.0, 5), "5 votes for class 0");
        assert_eq!(arf.predict(&[1.0, 2.0]), 0.0);
    }

    #[test]
    fn arf_poisson_valid() {
        // Sample mean of Poisson(6) over 1000 draws should be close to 6.
        let mut rng = 12345u64;
        let mut total = 0u64;
        let n = 1000;
        for _ in 0..n {
            total += poisson(6.0, &mut rng);
        }
        let mean = total as f64 / n as f64;
        // Poisson(6) should have mean ~6
        assert!(
            (mean - 6.0).abs() < 1.0,
            "Poisson mean should be ~6.0, got {}",
            mean
        );
    }

    #[test]
    fn arf_feature_subspace() {
        let config = ARFConfig::builder(3)
            .feature_fraction(0.5)
            .seed(42)
            .build()
            .unwrap();
        let mut arf = AdaptiveRandomForest::new(config, || Box::new(MockClassifier::new(0.0)));

        arf.train_one(&[1.0, 2.0, 3.0, 4.0], 0.0);

        // Each tree should have ceil(0.5 * 4) = 2 features
        for member in &arf.trees {
            assert_eq!(
                member.feature_mask.len(),
                2,
                "expected 2 features, got {}",
                member.feature_mask.len()
            );
        }
    }

    #[test]
    fn arf_streaming_learner_trait() {
        // Exercise the forest through the trait object interface.
        let config = ARFConfig::builder(3).seed(42).build().unwrap();
        let mut arf = AdaptiveRandomForest::new(config, || Box::new(MockClassifier::new(0.0)));

        let learner: &mut dyn StreamingLearner = &mut arf;
        learner.train(&[1.0, 2.0], 0.0);
        assert_eq!(learner.n_samples_seen(), 1);
        let pred = learner.predict(&[1.0, 2.0]);
        assert_eq!(pred, 0.0);
    }

    #[test]
    fn arf_config_validates() {
        assert!(ARFConfig::builder(0).build().is_err());
        assert!(ARFConfig::builder(3).lambda(0.0).build().is_err());
        assert!(ARFConfig::builder(3).lambda(-1.0).build().is_err());
        assert!(ARFConfig::builder(3)
            .feature_fraction(-0.1)
            .build()
            .is_err());
        assert!(ARFConfig::builder(3).feature_fraction(1.1).build().is_err());
        assert!(ARFConfig::builder(3).drift_delta(0.0).build().is_err());
        assert!(ARFConfig::builder(3).drift_delta(1.0).build().is_err());
        assert!(ARFConfig::builder(3).build().is_ok());
    }

    #[test]
    fn arf_tree_accuracies() {
        let config = ARFConfig::builder(3).seed(42).build().unwrap();
        let mut arf = AdaptiveRandomForest::new(config, || Box::new(MockClassifier::new(1.0)));

        // Train with target=1.0, mock predicts 1.0 → all correct
        for _ in 0..10 {
            arf.train_one(&[1.0, 2.0], 1.0);
        }

        let accs = arf.tree_accuracies();
        assert_eq!(accs.len(), 3);
        for &acc in &accs {
            assert!(
                acc > 0.9,
                "accuracy should be ~1.0 for correct mock, got {}",
                acc
            );
        }
    }
}