// irithyll_core/ensemble/adaptive_forest.rs
1//! Adaptive Random Forest (ARF) for streaming classification.
2//!
3//! An ensemble of streaming learners with ADWIN-based drift detection and
4//! automatic tree replacement. Each member trains on a Poisson(lambda)-weighted
5//! bootstrap of the stream with a random feature subspace.
6//!
7//! # Algorithm
8//!
9//! For each new sample *(x, y)* and each tree *t*:
10//!
11//! 1. Draw *k ~ Poisson(lambda)* -- the bootstrap weight.
12//! 2. Predict *y_hat = tree_t(x\[mask_t\])* before training (for drift detection).
13//! 3. Train *k* times on the masked feature vector.
14//! 4. Feed correctness (0.0 = correct, 1.0 = incorrect) to the ADWIN detector.
15//! 5. On drift: reset the tree and detector, allowing a fresh learner to adapt
16//!    to the new distribution.
17//!
18//! Final prediction is majority vote across all trees.
19//!
20//! # Reference
21//!
22//! Gomes, H. M., et al. (2017). "Adaptive random forests for evolving data
23//! stream classification." *Machine Learning*, 106(9-10), 1469-1495.
24
25use alloc::boxed::Box;
26use alloc::string::String;
27use alloc::vec::Vec;
28
29use crate::drift::adwin::Adwin;
30use crate::drift::{DriftDetector, DriftSignal};
31use crate::learner::StreamingLearner;
32
33// ---------------------------------------------------------------------------
34// Utilities
35// ---------------------------------------------------------------------------
36
/// Advance a 64-bit xorshift PRNG (Marsaglia's xorshift64, shift triple
/// 13/7/17) and return the next value.
///
/// A state of 0 is a fixed point of the xorshift transform — it would yield
/// 0 forever, degenerating every Poisson draw and feature mask built on top
/// of it — so a zero state is first remapped to an arbitrary non-zero
/// constant. For any non-zero state the sequence is unchanged.
fn xorshift64(state: &mut u64) -> u64 {
    if *state == 0 {
        // 0 never occurs in a xorshift orbit; substitute a fixed odd seed.
        *state = 0x9E37_79B9_7F4A_7C15;
    }
    let mut x = *state;
    x ^= x << 13;
    x ^= x >> 7;
    x ^= x << 17;
    *state = x;
    x
}
45
46fn poisson(lambda: f64, rng: &mut u64) -> u64 {
47    let l = crate::math::exp(-lambda);
48    let mut k = 0u64;
49    let mut p = 1.0f64;
50    loop {
51        k += 1;
52        let u = xorshift64(rng) as f64 / u64::MAX as f64;
53        p *= u;
54        if p <= l {
55            return k - 1;
56        }
57    }
58}
59
60// ---------------------------------------------------------------------------
61// ARFConfig
62// ---------------------------------------------------------------------------
63
/// Configuration for [`AdaptiveRandomForest`].
///
/// Construct via [`ARFConfig::builder`], which supplies validated defaults
/// for every field except `n_trees`.
#[derive(Debug, Clone)]
pub struct ARFConfig {
    /// Number of trees in the forest. Must be >= 1.
    pub n_trees: usize,
    /// Poisson lambda for bootstrap resampling (default 6.0).
    pub lambda: f64,
    /// Fraction of features per tree. 0.0 = auto (sqrt(d)/d).
    pub feature_fraction: f64,
    /// ADWIN delta for drift detection (default 1e-3).
    pub drift_delta: f64,
    /// ADWIN delta for warning detection (default 1e-2).
    pub warning_delta: f64,
    /// Random seed for the internal xorshift generator.
    pub seed: u64,
}
80
/// Builder for [`ARFConfig`].
///
/// Created by [`ARFConfig::builder`]; finalize with `build()`, which
/// validates every parameter before producing a config.
#[derive(Debug, Clone)]
pub struct ARFConfigBuilder {
    // Mirrors the fields of `ARFConfig`; see that struct for semantics.
    n_trees: usize,
    lambda: f64,
    feature_fraction: f64,
    drift_delta: f64,
    warning_delta: f64,
    seed: u64,
}
91
92impl ARFConfig {
93    /// Start building a configuration with the given number of trees.
94    pub fn builder(n_trees: usize) -> ARFConfigBuilder {
95        ARFConfigBuilder {
96            n_trees,
97            lambda: 6.0,
98            feature_fraction: 0.0,
99            drift_delta: 1e-3,
100            warning_delta: 1e-2,
101            seed: 42,
102        }
103    }
104}
105
106impl ARFConfigBuilder {
107    /// Poisson lambda for bootstrap resampling.
108    pub fn lambda(mut self, lambda: f64) -> Self {
109        self.lambda = lambda;
110        self
111    }
112
113    /// Fraction of features per tree. 0.0 = auto (sqrt).
114    pub fn feature_fraction(mut self, f: f64) -> Self {
115        self.feature_fraction = f;
116        self
117    }
118
119    /// ADWIN delta for drift detection.
120    pub fn drift_delta(mut self, d: f64) -> Self {
121        self.drift_delta = d;
122        self
123    }
124
125    /// ADWIN delta for warning detection.
126    pub fn warning_delta(mut self, d: f64) -> Self {
127        self.warning_delta = d;
128        self
129    }
130
131    /// Random seed.
132    pub fn seed(mut self, s: u64) -> Self {
133        self.seed = s;
134        self
135    }
136
137    /// Build the configuration, validating all parameters.
138    pub fn build(self) -> Result<ARFConfig, String> {
139        if self.n_trees == 0 {
140            return Err("n_trees must be >= 1".into());
141        }
142        if self.lambda <= 0.0 || !self.lambda.is_finite() {
143            return Err("lambda must be positive and finite".into());
144        }
145        if self.feature_fraction < 0.0 || self.feature_fraction > 1.0 {
146            return Err("feature_fraction must be in [0.0, 1.0]".into());
147        }
148        if self.drift_delta <= 0.0 || self.drift_delta >= 1.0 {
149            return Err("drift_delta must be in (0, 1)".into());
150        }
151        if self.warning_delta <= 0.0 || self.warning_delta >= 1.0 {
152            return Err("warning_delta must be in (0, 1)".into());
153        }
154        Ok(ARFConfig {
155            n_trees: self.n_trees,
156            lambda: self.lambda,
157            feature_fraction: self.feature_fraction,
158            drift_delta: self.drift_delta,
159            warning_delta: self.warning_delta,
160            seed: self.seed,
161        })
162    }
163}
164
165// ---------------------------------------------------------------------------
166// ARFMember
167// ---------------------------------------------------------------------------
168
/// One ensemble member: a learner plus its per-tree drift bookkeeping.
struct ARFMember {
    /// The wrapped streaming learner; replaced wholesale on drift.
    learner: Box<dyn StreamingLearner>,
    /// ADWIN over this tree's 0/1 error stream; a drift signal triggers
    /// replacement of the tree.
    drift_detector: Adwin,
    /// ADWIN with the (larger) warning delta; its signal is currently
    /// computed but not acted upon.
    warning_detector: Adwin,
    /// Sorted feature indices this tree sees; empty = all features
    /// (masks are drawn lazily once the input dimensionality is known).
    feature_mask: Vec<usize>,
    /// Correct prequential predictions since the tree was (re)created.
    n_correct: u64,
    /// Total prequential predictions since the tree was (re)created.
    n_evaluated: u64,
}
177
178// ---------------------------------------------------------------------------
179// AdaptiveRandomForest
180// ---------------------------------------------------------------------------
181
/// Adaptive Random Forest for streaming classification.
///
/// An ensemble of `n_trees` streaming learners, each trained on a
/// Poisson-weighted bootstrap with a random feature subspace. ADWIN drift
/// detection automatically resets individual trees when their error rate
/// changes significantly.
///
/// # Example
///
/// ```
/// use irithyll::{AdaptiveRandomForest, StreamingLearner};
/// use irithyll::ensemble::adaptive_forest::ARFConfig;
/// use irithyll::learners::linear::StreamingLinearModel;
///
/// let config = ARFConfig::builder(5).lambda(6.0).build().unwrap();
/// let mut arf = AdaptiveRandomForest::new(config, || {
///     Box::new(StreamingLinearModel::new(0.01))
/// });
///
/// arf.train_one(&[1.0, 0.0], 0.0);
/// let pred = arf.predict(&[1.0, 0.0]);
/// ```
pub struct AdaptiveRandomForest {
    /// Validated hyper-parameters.
    config: ARFConfig,
    /// Ensemble members; `config.n_trees` long.
    trees: Vec<ARFMember>,
    /// Input dimensionality, learned from the first training sample
    /// (0 until then).
    n_features: usize,
    /// Count of samples passed to `train_one`.
    n_samples: u64,
    /// Cumulative count of drift-triggered tree replacements.
    n_drifts: usize,
    /// xorshift64 state driving Poisson draws and feature-mask selection.
    rng_state: u64,
    /// Stored factory for creating replacement learners on drift.
    factory: Box<dyn Fn() -> Box<dyn StreamingLearner> + Send + Sync>,
}
214
215impl AdaptiveRandomForest {
216    /// Create a new ARF with the given config and learner factory.
217    ///
218    /// The `factory` closure is called `n_trees` times to create the initial
219    /// ensemble, and again whenever a tree is replaced after drift detection.
220    pub fn new<F>(config: ARFConfig, factory: F) -> Self
221    where
222        F: Fn() -> Box<dyn StreamingLearner> + Send + Sync + 'static,
223    {
224        let mut rng = config.seed;
225        let trees: Vec<ARFMember> = (0..config.n_trees)
226            .map(|_| {
227                // Feature masks are initialized lazily on first train_one
228                let _ = xorshift64(&mut rng);
229                ARFMember {
230                    learner: factory(),
231                    drift_detector: Adwin::with_delta(config.drift_delta),
232                    warning_detector: Adwin::with_delta(config.warning_delta),
233                    feature_mask: Vec::new(),
234                    n_correct: 0,
235                    n_evaluated: 0,
236                }
237            })
238            .collect();
239
240        Self {
241            config,
242            trees,
243            n_features: 0,
244            n_samples: 0,
245            n_drifts: 0,
246            rng_state: rng,
247            factory: Box::new(factory),
248        }
249    }
250
251    /// Initialize feature masks for all trees (called lazily on first sample).
252    fn init_feature_masks(&mut self) {
253        let d = self.n_features;
254        let fraction = if self.config.feature_fraction == 0.0 {
255            crate::math::sqrt(d as f64) / d as f64
256        } else {
257            self.config.feature_fraction
258        };
259        let k = (crate::math::ceil(fraction * d as f64) as usize)
260            .max(1)
261            .min(d);
262
263        for member in &mut self.trees {
264            // Fisher-Yates partial shuffle to select k unique indices
265            let mut indices: Vec<usize> = (0..d).collect();
266            for i in 0..k {
267                let j = i + (xorshift64(&mut self.rng_state) as usize % (d - i));
268                indices.swap(i, j);
269            }
270            indices.truncate(k);
271            indices.sort_unstable();
272            member.feature_mask = indices;
273        }
274    }
275
276    /// Extract masked features for a given tree member.
277    fn mask_features(&self, features: &[f64], mask: &[usize]) -> Vec<f64> {
278        if mask.is_empty() {
279            features.to_vec()
280        } else {
281            mask.iter().map(|&i| features[i]).collect()
282        }
283    }
284
285    /// Train on a single sample.
286    pub fn train_one(&mut self, features: &[f64], target: f64) {
287        if self.n_features == 0 {
288            self.n_features = features.len();
289            self.init_feature_masks();
290        }
291        self.n_samples += 1;
292
293        for i in 0..self.trees.len() {
294            let k = poisson(self.config.lambda, &mut self.rng_state);
295            let masked = self.mask_features(features, &self.trees[i].feature_mask);
296
297            // Predict before training (for drift detection).
298            let pred = self.trees[i].learner.predict(&masked);
299            let correct = crate::math::abs(crate::math::round(pred) - target) < 0.5;
300            self.trees[i].n_evaluated += 1;
301            if correct {
302                self.trees[i].n_correct += 1;
303            }
304
305            // Train k times (Poisson-weighted bootstrap).
306            for _ in 0..k {
307                self.trees[i].learner.train(&masked, target);
308            }
309
310            // Feed error signal to drift detectors.
311            let error_val = if correct { 0.0 } else { 1.0 };
312            let drift_signal = self.trees[i].drift_detector.update(error_val);
313            let _warning_signal = self.trees[i].warning_detector.update(error_val);
314
315            // On drift: replace the tree.
316            if matches!(drift_signal, DriftSignal::Drift) {
317                self.trees[i].learner = (self.factory)();
318                self.trees[i].drift_detector = Adwin::with_delta(self.config.drift_delta);
319                self.trees[i].warning_detector = Adwin::with_delta(self.config.warning_delta);
320                self.trees[i].n_correct = 0;
321                self.trees[i].n_evaluated = 0;
322                self.n_drifts += 1;
323
324                // Re-init feature mask for the new tree.
325                let d = self.n_features;
326                let fraction = if self.config.feature_fraction == 0.0 {
327                    crate::math::sqrt(d as f64) / d as f64
328                } else {
329                    self.config.feature_fraction
330                };
331                let k_features = (crate::math::ceil(fraction * d as f64) as usize)
332                    .max(1)
333                    .min(d);
334                let mut indices: Vec<usize> = (0..d).collect();
335                for j in 0..k_features {
336                    let swap = j + (xorshift64(&mut self.rng_state) as usize % (d - j));
337                    indices.swap(j, swap);
338                }
339                indices.truncate(k_features);
340                indices.sort_unstable();
341                self.trees[i].feature_mask = indices;
342            }
343        }
344    }
345
346    /// Predict by majority vote across all trees.
347    ///
348    /// Each tree casts a vote for a class (prediction rounded to nearest
349    /// integer). The class with the most votes wins.
350    pub fn predict(&self, features: &[f64]) -> f64 {
351        let votes = self.predict_votes(features);
352        votes
353            .into_iter()
354            .max_by_key(|&(_, count)| count)
355            .map(|(class, _)| class)
356            .unwrap_or(0.0)
357    }
358
359    /// Vote counts per predicted class.
360    pub fn predict_votes(&self, features: &[f64]) -> Vec<(f64, u64)> {
361        let mut vote_map: Vec<(f64, u64)> = Vec::new();
362        for member in &self.trees {
363            let masked = self.mask_features(features, &member.feature_mask);
364            let pred = crate::math::round(member.learner.predict(&masked));
365            if let Some(entry) = vote_map
366                .iter_mut()
367                .find(|(c, _)| crate::math::abs(*c - pred) < 0.5)
368            {
369                entry.1 += 1;
370            } else {
371                vote_map.push((pred, 1));
372            }
373        }
374        vote_map
375    }
376
377    /// Number of trees in the ensemble.
378    pub fn n_trees(&self) -> usize {
379        self.config.n_trees
380    }
381
382    /// Total samples processed.
383    pub fn n_samples_seen(&self) -> u64 {
384        self.n_samples
385    }
386
387    /// Per-tree accuracy (correct / evaluated).
388    pub fn tree_accuracies(&self) -> Vec<f64> {
389        self.trees
390            .iter()
391            .map(|m| {
392                if m.n_evaluated == 0 {
393                    0.0
394                } else {
395                    m.n_correct as f64 / m.n_evaluated as f64
396                }
397            })
398            .collect()
399    }
400
401    /// Total number of drift-triggered tree replacements.
402    pub fn n_drifts_detected(&self) -> usize {
403        self.n_drifts
404    }
405}
406
407impl StreamingLearner for AdaptiveRandomForest {
408    fn train_one(&mut self, features: &[f64], target: f64, _weight: f64) {
409        self.train_one(features, target);
410    }
411
412    fn predict(&self, features: &[f64]) -> f64 {
413        self.predict(features)
414    }
415
416    fn n_samples_seen(&self) -> u64 {
417        self.n_samples
418    }
419
420    fn reset(&mut self) {
421        self.n_samples = 0;
422        self.n_drifts = 0;
423        for member in &mut self.trees {
424            member.n_correct = 0;
425            member.n_evaluated = 0;
426        }
427    }
428}
429
430// ---------------------------------------------------------------------------
431// Tests
432// ---------------------------------------------------------------------------
433
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::boxed::Box;

    /// Constant-output stand-in learner: always predicts `prediction` and
    /// counts training calls in `n`.
    struct MockClassifier {
        prediction: f64,
        n: u64,
    }

    impl MockClassifier {
        fn new(prediction: f64) -> Self {
            Self { prediction, n: 0 }
        }
    }

    impl StreamingLearner for MockClassifier {
        fn train_one(&mut self, _features: &[f64], _target: f64, _weight: f64) {
            self.n += 1;
        }
        fn predict(&self, _features: &[f64]) -> f64 {
            self.prediction
        }
        fn n_samples_seen(&self) -> u64 {
            self.n
        }
        fn reset(&mut self) {
            self.n = 0;
        }
    }

    // Smoke test: one sample in, unanimous prediction out, sample counted.
    #[test]
    fn arf_trains_and_predicts() {
        let config = ARFConfig::builder(3).seed(42).build().unwrap();
        let mut arf = AdaptiveRandomForest::new(config, || Box::new(MockClassifier::new(1.0)));

        arf.train_one(&[1.0, 2.0], 1.0);
        let pred = arf.predict(&[1.0, 2.0]);
        assert_eq!(pred, 1.0);
        assert_eq!(arf.n_samples_seen(), 1);
    }

    // All members agree, so predict_votes collapses to a single bucket.
    #[test]
    fn arf_majority_vote() {
        let config = ARFConfig::builder(5).seed(42).build().unwrap();
        // All trees predict 0.0 → unanimous vote
        let mut arf = AdaptiveRandomForest::new(config, || Box::new(MockClassifier::new(0.0)));
        // Need to init feature masks
        arf.n_features = 2;
        arf.init_feature_masks();

        let votes = arf.predict_votes(&[1.0, 2.0]);
        assert_eq!(votes.len(), 1, "all trees should agree");
        assert_eq!(votes[0], (0.0, 5), "5 votes for class 0");
        assert_eq!(arf.predict(&[1.0, 2.0]), 0.0);
    }

    // Statistical sanity check on the Knuth sampler: the empirical mean of
    // Poisson(6) over 1000 draws should land near 6.
    #[test]
    fn arf_poisson_valid() {
        let mut rng = 12345u64;
        let mut total = 0u64;
        let n = 1000;
        for _ in 0..n {
            total += poisson(6.0, &mut rng);
        }
        let mean = total as f64 / n as f64;
        // Poisson(6) should have mean ~6
        assert!(
            (mean - 6.0).abs() < 1.0,
            "Poisson mean should be ~6.0, got {}",
            mean
        );
    }

    // With feature_fraction = 0.5 and 4 inputs, each mask holds ceil(2) = 2
    // feature indices.
    #[test]
    fn arf_feature_subspace() {
        let config = ARFConfig::builder(3)
            .feature_fraction(0.5)
            .seed(42)
            .build()
            .unwrap();
        let mut arf = AdaptiveRandomForest::new(config, || Box::new(MockClassifier::new(0.0)));

        arf.train_one(&[1.0, 2.0, 3.0, 4.0], 0.0);

        // Each tree should have ceil(0.5 * 4) = 2 features
        for member in &arf.trees {
            assert_eq!(
                member.feature_mask.len(),
                2,
                "expected 2 features, got {}",
                member.feature_mask.len()
            );
        }
    }

    // The forest is usable through the StreamingLearner trait object.
    #[test]
    fn arf_streaming_learner_trait() {
        let config = ARFConfig::builder(3).seed(42).build().unwrap();
        let mut arf = AdaptiveRandomForest::new(config, || Box::new(MockClassifier::new(0.0)));

        let learner: &mut dyn StreamingLearner = &mut arf;
        learner.train(&[1.0, 2.0], 0.0);
        assert_eq!(learner.n_samples_seen(), 1);
        let pred = learner.predict(&[1.0, 2.0]);
        assert_eq!(pred, 0.0);
    }

    // Builder rejects out-of-range parameters and accepts the defaults.
    #[test]
    fn arf_config_validates() {
        assert!(ARFConfig::builder(0).build().is_err());
        assert!(ARFConfig::builder(3).lambda(0.0).build().is_err());
        assert!(ARFConfig::builder(3).lambda(-1.0).build().is_err());
        assert!(ARFConfig::builder(3)
            .feature_fraction(-0.1)
            .build()
            .is_err());
        assert!(ARFConfig::builder(3).feature_fraction(1.1).build().is_err());
        assert!(ARFConfig::builder(3).drift_delta(0.0).build().is_err());
        assert!(ARFConfig::builder(3).drift_delta(1.0).build().is_err());
        assert!(ARFConfig::builder(3).build().is_ok());
    }

    // An always-correct mock drives every tree's prequential accuracy to 1.
    #[test]
    fn arf_tree_accuracies() {
        let config = ARFConfig::builder(3).seed(42).build().unwrap();
        let mut arf = AdaptiveRandomForest::new(config, || Box::new(MockClassifier::new(1.0)));

        // Train with target=1.0, mock predicts 1.0 → all correct
        for _ in 0..10 {
            arf.train_one(&[1.0, 2.0], 1.0);
        }

        let accs = arf.tree_accuracies();
        assert_eq!(accs.len(), 3);
        for &acc in &accs {
            assert!(
                acc > 0.9,
                "accuracy should be ~1.0 for correct mock, got {}",
                acc
            );
        }
    }
}