1use alloc::boxed::Box;
26use alloc::string::String;
27use alloc::vec::Vec;
28
29use crate::drift::adwin::Adwin;
30use crate::drift::{DriftDetector, DriftSignal};
31use crate::learner::StreamingLearner;
32
33use crate::rng::xorshift64;
34
35fn poisson(lambda: f64, rng: &mut u64) -> u64 {
40 let l = crate::math::exp(-lambda);
41 let mut k = 0u64;
42 let mut p = 1.0f64;
43 loop {
44 k += 1;
45 let u = xorshift64(rng) as f64 / u64::MAX as f64;
46 p *= u;
47 if p <= l {
48 return k - 1;
49 }
50 }
51}
52
/// Hyper-parameters for [`AdaptiveRandomForest`].
///
/// Construct through [`ARFConfig::builder`], whose `build()` step
/// validates every field before producing a config.
#[derive(Debug, Clone)]
pub struct ARFConfig {
    /// Number of ensemble members; must be >= 1.
    pub n_trees: usize,
    /// Poisson rate for online bagging (how many times each example is
    /// replicated per tree); must be positive and finite.
    pub lambda: f64,
    /// Fraction of features each tree observes, in [0, 1]. The value
    /// 0.0 is a sentinel selecting the sqrt(d)/d heuristic.
    pub feature_fraction: f64,
    /// ADWIN delta for the per-tree drift detector, in (0, 1).
    pub drift_delta: f64,
    /// ADWIN delta for the per-tree warning detector, in (0, 1).
    pub warning_delta: f64,
    /// Seed for the forest's internal xorshift64 RNG.
    pub seed: u64,
}
73
/// Builder for [`ARFConfig`]; created by [`ARFConfig::builder`].
///
/// Fields mirror [`ARFConfig`] and are validated in `build()`.
#[derive(Debug, Clone)]
pub struct ARFConfigBuilder {
    n_trees: usize,
    lambda: f64,
    feature_fraction: f64,
    drift_delta: f64,
    warning_delta: f64,
    seed: u64,
}
84
impl ARFConfig {
    /// Starts a builder for a forest of `n_trees` members.
    ///
    /// Defaults: `lambda = 6.0`, `feature_fraction = 0.0` (sqrt
    /// heuristic sentinel), `drift_delta = 1e-3`,
    /// `warning_delta = 1e-2`, `seed = 42`.
    pub fn builder(n_trees: usize) -> ARFConfigBuilder {
        ARFConfigBuilder {
            n_trees,
            lambda: 6.0,
            feature_fraction: 0.0,
            drift_delta: 1e-3,
            warning_delta: 1e-2,
            seed: 42,
        }
    }
}
98
99impl ARFConfigBuilder {
100 pub fn lambda(mut self, lambda: f64) -> Self {
102 self.lambda = lambda;
103 self
104 }
105
106 pub fn feature_fraction(mut self, f: f64) -> Self {
108 self.feature_fraction = f;
109 self
110 }
111
112 pub fn drift_delta(mut self, d: f64) -> Self {
114 self.drift_delta = d;
115 self
116 }
117
118 pub fn warning_delta(mut self, d: f64) -> Self {
120 self.warning_delta = d;
121 self
122 }
123
124 pub fn seed(mut self, s: u64) -> Self {
126 self.seed = s;
127 self
128 }
129
130 pub fn build(self) -> Result<ARFConfig, String> {
132 if self.n_trees == 0 {
133 return Err("n_trees must be >= 1".into());
134 }
135 if self.lambda <= 0.0 || !self.lambda.is_finite() {
136 return Err("lambda must be positive and finite".into());
137 }
138 if self.feature_fraction < 0.0 || self.feature_fraction > 1.0 {
139 return Err("feature_fraction must be in [0.0, 1.0]".into());
140 }
141 if self.drift_delta <= 0.0 || self.drift_delta >= 1.0 {
142 return Err("drift_delta must be in (0, 1)".into());
143 }
144 if self.warning_delta <= 0.0 || self.warning_delta >= 1.0 {
145 return Err("warning_delta must be in (0, 1)".into());
146 }
147 Ok(ARFConfig {
148 n_trees: self.n_trees,
149 lambda: self.lambda,
150 feature_fraction: self.feature_fraction,
151 drift_delta: self.drift_delta,
152 warning_delta: self.warning_delta,
153 seed: self.seed,
154 })
155 }
156}
157
/// One ensemble member: a learner restricted to a feature subspace,
/// paired with its own drift/warning detectors and prequential
/// accuracy counters.
struct ARFMember {
    // Boxed trait object produced by the forest's factory closure.
    learner: Box<dyn StreamingLearner>,
    // ADWIN fed with this member's 0/1 error stream; a drift signal
    // triggers a full member reset in `train_one`.
    drift_detector: Adwin,
    // Second ADWIN with a looser delta; currently updated but unused.
    warning_detector: Adwin,
    // Sorted feature indices this member observes; empty means
    // "all features" (see `mask_features`).
    feature_mask: Vec<usize>,
    // Prequential counters: correct predictions / total evaluated.
    n_correct: u64,
    n_evaluated: u64,
}
170
/// Streaming ensemble classifier: online bagging (Poisson-weighted
/// replication) over feature-subspaced base learners, with per-member
/// ADWIN drift detection and automatic member replacement.
pub struct AdaptiveRandomForest {
    config: ARFConfig,
    trees: Vec<ARFMember>,
    // Input dimensionality, learned from the first training example
    // (0 until then).
    n_features: usize,
    n_samples: u64,
    // Total drift-triggered member resets across the forest's life.
    n_drifts: usize,
    // xorshift64 state shared by bagging and mask sampling.
    rng_state: u64,
    // Produces fresh base learners, both at construction and when a
    // member is replaced after drift.
    factory: Box<dyn Fn() -> Box<dyn StreamingLearner> + Send + Sync>,
}
207
208impl AdaptiveRandomForest {
209 pub fn new<F>(config: ARFConfig, factory: F) -> Self
214 where
215 F: Fn() -> Box<dyn StreamingLearner> + Send + Sync + 'static,
216 {
217 let mut rng = config.seed;
218 let trees: Vec<ARFMember> = (0..config.n_trees)
219 .map(|_| {
220 let _ = xorshift64(&mut rng);
222 ARFMember {
223 learner: factory(),
224 drift_detector: Adwin::with_delta(config.drift_delta),
225 warning_detector: Adwin::with_delta(config.warning_delta),
226 feature_mask: Vec::new(),
227 n_correct: 0,
228 n_evaluated: 0,
229 }
230 })
231 .collect();
232
233 Self {
234 config,
235 trees,
236 n_features: 0,
237 n_samples: 0,
238 n_drifts: 0,
239 rng_state: rng,
240 factory: Box::new(factory),
241 }
242 }
243
244 fn init_feature_masks(&mut self) {
246 let d = self.n_features;
247 let fraction = if self.config.feature_fraction == 0.0 {
248 crate::math::sqrt(d as f64) / d as f64
249 } else {
250 self.config.feature_fraction
251 };
252 let k = (crate::math::ceil(fraction * d as f64) as usize)
253 .max(1)
254 .min(d);
255
256 for member in &mut self.trees {
257 let mut indices: Vec<usize> = (0..d).collect();
259 for i in 0..k {
260 let j = i + (xorshift64(&mut self.rng_state) as usize % (d - i));
261 indices.swap(i, j);
262 }
263 indices.truncate(k);
264 indices.sort_unstable();
265 member.feature_mask = indices;
266 }
267 }
268
269 fn mask_features(&self, features: &[f64], mask: &[usize]) -> Vec<f64> {
271 if mask.is_empty() {
272 features.to_vec()
273 } else {
274 mask.iter().map(|&i| features[i]).collect()
275 }
276 }
277
278 pub fn train_one(&mut self, features: &[f64], target: f64) {
280 if self.n_features == 0 {
281 self.n_features = features.len();
282 self.init_feature_masks();
283 }
284 self.n_samples += 1;
285
286 for i in 0..self.trees.len() {
287 let k = poisson(self.config.lambda, &mut self.rng_state);
288 let masked = self.mask_features(features, &self.trees[i].feature_mask);
289
290 let pred = self.trees[i].learner.predict(&masked);
292 let correct = crate::math::abs(crate::math::round(pred) - target) < 0.5;
293 self.trees[i].n_evaluated += 1;
294 if correct {
295 self.trees[i].n_correct += 1;
296 }
297
298 for _ in 0..k {
300 self.trees[i].learner.train(&masked, target);
301 }
302
303 let error_val = if correct { 0.0 } else { 1.0 };
305 let drift_signal = self.trees[i].drift_detector.update(error_val);
306 let _warning_signal = self.trees[i].warning_detector.update(error_val);
307
308 if matches!(drift_signal, DriftSignal::Drift) {
310 self.trees[i].learner = (self.factory)();
311 self.trees[i].drift_detector = Adwin::with_delta(self.config.drift_delta);
312 self.trees[i].warning_detector = Adwin::with_delta(self.config.warning_delta);
313 self.trees[i].n_correct = 0;
314 self.trees[i].n_evaluated = 0;
315 self.n_drifts += 1;
316
317 let d = self.n_features;
319 let fraction = if self.config.feature_fraction == 0.0 {
320 crate::math::sqrt(d as f64) / d as f64
321 } else {
322 self.config.feature_fraction
323 };
324 let k_features = (crate::math::ceil(fraction * d as f64) as usize)
325 .max(1)
326 .min(d);
327 let mut indices: Vec<usize> = (0..d).collect();
328 for j in 0..k_features {
329 let swap = j + (xorshift64(&mut self.rng_state) as usize % (d - j));
330 indices.swap(j, swap);
331 }
332 indices.truncate(k_features);
333 indices.sort_unstable();
334 self.trees[i].feature_mask = indices;
335 }
336 }
337 }
338
339 pub fn predict(&self, features: &[f64]) -> f64 {
344 let votes = self.predict_votes(features);
345 votes
346 .into_iter()
347 .max_by_key(|&(_, count)| count)
348 .map(|(class, _)| class)
349 .unwrap_or(0.0)
350 }
351
352 pub fn predict_votes(&self, features: &[f64]) -> Vec<(f64, u64)> {
354 let mut vote_map: Vec<(f64, u64)> = Vec::new();
355 for member in &self.trees {
356 let masked = self.mask_features(features, &member.feature_mask);
357 let pred = crate::math::round(member.learner.predict(&masked));
358 if let Some(entry) = vote_map
359 .iter_mut()
360 .find(|(c, _)| crate::math::abs(*c - pred) < 0.5)
361 {
362 entry.1 += 1;
363 } else {
364 vote_map.push((pred, 1));
365 }
366 }
367 vote_map
368 }
369
370 pub fn n_trees(&self) -> usize {
372 self.config.n_trees
373 }
374
375 pub fn n_samples_seen(&self) -> u64 {
377 self.n_samples
378 }
379
380 pub fn tree_accuracies(&self) -> Vec<f64> {
382 self.trees
383 .iter()
384 .map(|m| {
385 if m.n_evaluated == 0 {
386 0.0
387 } else {
388 m.n_correct as f64 / m.n_evaluated as f64
389 }
390 })
391 .collect()
392 }
393
394 pub fn n_drifts_detected(&self) -> usize {
396 self.n_drifts
397 }
398}
399
400impl StreamingLearner for AdaptiveRandomForest {
401 fn train_one(&mut self, features: &[f64], target: f64, _weight: f64) {
402 self.train_one(features, target);
403 }
404
405 fn predict(&self, features: &[f64]) -> f64 {
406 self.predict(features)
407 }
408
409 fn n_samples_seen(&self) -> u64 {
410 self.n_samples
411 }
412
413 fn reset(&mut self) {
414 self.n_samples = 0;
415 self.n_drifts = 0;
416 for member in &mut self.trees {
417 member.n_correct = 0;
418 member.n_evaluated = 0;
419 }
420 }
421}
422
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::boxed::Box;

    // Trivial learner that always predicts a fixed value and counts
    // training calls; isolates the ensemble mechanics (voting,
    // bagging, masks) from any real model.
    struct MockClassifier {
        prediction: f64,
        n: u64,
    }

    impl MockClassifier {
        fn new(prediction: f64) -> Self {
            Self { prediction, n: 0 }
        }
    }

    impl StreamingLearner for MockClassifier {
        fn train_one(&mut self, _features: &[f64], _target: f64, _weight: f64) {
            self.n += 1;
        }
        fn predict(&self, _features: &[f64]) -> f64 {
            self.prediction
        }
        fn n_samples_seen(&self) -> u64 {
            self.n
        }
        fn reset(&mut self) {
            self.n = 0;
        }
    }

    // Basic smoke test: one train step, unanimous prediction.
    #[test]
    fn arf_trains_and_predicts() {
        let config = ARFConfig::builder(3).seed(42).build().unwrap();
        let mut arf = AdaptiveRandomForest::new(config, || Box::new(MockClassifier::new(1.0)));

        arf.train_one(&[1.0, 2.0], 1.0);
        let pred = arf.predict(&[1.0, 2.0]);
        assert_eq!(pred, 1.0);
        assert_eq!(arf.n_samples_seen(), 1);
    }

    // All mocks predict 0.0, so the vote map must collapse to a
    // single (class, count) entry covering every tree.
    #[test]
    fn arf_majority_vote() {
        let config = ARFConfig::builder(5).seed(42).build().unwrap();
        let mut arf = AdaptiveRandomForest::new(config, || Box::new(MockClassifier::new(0.0)));
        // Set dimensionality manually so masks can be sampled without
        // a training step.
        arf.n_features = 2;
        arf.init_feature_masks();

        let votes = arf.predict_votes(&[1.0, 2.0]);
        assert_eq!(votes.len(), 1, "all trees should agree");
        assert_eq!(votes[0], (0.0, 5), "5 votes for class 0");
        assert_eq!(arf.predict(&[1.0, 2.0]), 0.0);
    }

    // Sanity-check the sampler: empirical mean over 1000 draws should
    // land near lambda (loose tolerance, fixed seed).
    #[test]
    fn arf_poisson_valid() {
        let mut rng = 12345u64;
        let mut total = 0u64;
        let n = 1000;
        for _ in 0..n {
            total += poisson(6.0, &mut rng);
        }
        let mean = total as f64 / n as f64;
        assert!(
            (mean - 6.0).abs() < 1.0,
            "Poisson mean should be ~6.0, got {}",
            mean
        );
    }

    // feature_fraction = 0.5 over 4 features => ceil(0.5 * 4) = 2
    // indices per member mask.
    #[test]
    fn arf_feature_subspace() {
        let config = ARFConfig::builder(3)
            .feature_fraction(0.5)
            .seed(42)
            .build()
            .unwrap();
        let mut arf = AdaptiveRandomForest::new(config, || Box::new(MockClassifier::new(0.0)));

        arf.train_one(&[1.0, 2.0, 3.0, 4.0], 0.0);

        for member in &arf.trees {
            assert_eq!(
                member.feature_mask.len(),
                2,
                "expected 2 features, got {}",
                member.feature_mask.len()
            );
        }
    }

    // The forest must be usable through the StreamingLearner trait
    // object interface (trait `train` delegates into the ensemble).
    #[test]
    fn arf_streaming_learner_trait() {
        let config = ARFConfig::builder(3).seed(42).build().unwrap();
        let mut arf = AdaptiveRandomForest::new(config, || Box::new(MockClassifier::new(0.0)));

        let learner: &mut dyn StreamingLearner = &mut arf;
        learner.train(&[1.0, 2.0], 0.0);
        assert_eq!(learner.n_samples_seen(), 1);
        let pred = learner.predict(&[1.0, 2.0]);
        assert_eq!(pred, 0.0);
    }

    // Builder validation: each out-of-range field is rejected, and
    // the defaults build cleanly.
    #[test]
    fn arf_config_validates() {
        assert!(ARFConfig::builder(0).build().is_err());
        assert!(ARFConfig::builder(3).lambda(0.0).build().is_err());
        assert!(ARFConfig::builder(3).lambda(-1.0).build().is_err());
        assert!(ARFConfig::builder(3)
            .feature_fraction(-0.1)
            .build()
            .is_err());
        assert!(ARFConfig::builder(3).feature_fraction(1.1).build().is_err());
        assert!(ARFConfig::builder(3).drift_delta(0.0).build().is_err());
        assert!(ARFConfig::builder(3).drift_delta(1.0).build().is_err());
        assert!(ARFConfig::builder(3).build().is_ok());
    }

    // With an always-correct mock, the prequential accuracy of every
    // tree should approach 1.0 after a few examples.
    #[test]
    fn arf_tree_accuracies() {
        let config = ARFConfig::builder(3).seed(42).build().unwrap();
        let mut arf = AdaptiveRandomForest::new(config, || Box::new(MockClassifier::new(1.0)));

        for _ in 0..10 {
            arf.train_one(&[1.0, 2.0], 1.0);
        }

        let accs = arf.tree_accuracies();
        assert_eq!(accs.len(), 3);
        for &acc in &accs {
            assert!(
                acc > 0.9,
                "accuracy should be ~1.0 for correct mock, got {}",
                acc
            );
        }
    }
}