1use alloc::boxed::Box;
26use alloc::string::String;
27use alloc::vec::Vec;
28
29use crate::drift::adwin::Adwin;
30use crate::drift::{DriftDetector, DriftSignal};
31use crate::learner::StreamingLearner;
32
/// Advances a xorshift64 generator in place and returns the new state.
///
/// Uses the shift triple (13, 7, 17). The all-zero state is a fixed point of
/// the xorshift recurrence (every subsequent output would be 0), so a zero
/// state is replaced by a fixed non-zero constant before stepping. Sequences
/// produced from non-zero states are unaffected.
fn xorshift64(state: &mut u64) -> u64 {
    // Arbitrary non-zero value, used only to escape the zero fixed point.
    const ZERO_STATE_REPLACEMENT: u64 = 0x9E37_79B9_7F4A_7C15;

    let mut x = if *state == 0 {
        ZERO_STATE_REPLACEMENT
    } else {
        *state
    };
    x ^= x << 13;
    x ^= x >> 7;
    x ^= x << 17;
    *state = x;
    x
}
45
46fn poisson(lambda: f64, rng: &mut u64) -> u64 {
47 let l = crate::math::exp(-lambda);
48 let mut k = 0u64;
49 let mut p = 1.0f64;
50 loop {
51 k += 1;
52 let u = xorshift64(rng) as f64 / u64::MAX as f64;
53 p *= u;
54 if p <= l {
55 return k - 1;
56 }
57 }
58}
59
/// Validated configuration for [`AdaptiveRandomForest`].
///
/// Obtain one through [`ARFConfig::builder`]; the builder's `build` method
/// checks every field.
#[derive(Debug, Clone)]
pub struct ARFConfig {
    /// Number of ensemble members (at least 1).
    pub n_trees: usize,
    /// Rate of the Poisson distribution that decides how many times each
    /// member trains on a sample.
    pub lambda: f64,
    /// Fraction of the input features each member sees, in `[0.0, 1.0]`.
    /// `0.0` selects the default of `sqrt(d) / d`.
    pub feature_fraction: f64,
    /// Delta used to build each member's drift detector, in `(0, 1)`.
    pub drift_delta: f64,
    /// Delta used to build each member's warning detector, in `(0, 1)`.
    pub warning_delta: f64,
    /// Seed for the forest's pseudo-random generator.
    pub seed: u64,
}

/// Builder for [`ARFConfig`]; created by [`ARFConfig::builder`].
#[derive(Debug, Clone)]
pub struct ARFConfigBuilder {
    n_trees: usize,
    lambda: f64,
    feature_fraction: f64,
    drift_delta: f64,
    warning_delta: f64,
    seed: u64,
}

impl ARFConfig {
    /// Starts a builder for a forest of `n_trees` members.
    ///
    /// Defaults: `lambda = 6.0`, `feature_fraction = 0.0` (meaning
    /// `sqrt(d) / d`), `drift_delta = 1e-3`, `warning_delta = 1e-2`,
    /// `seed = 42`.
    pub fn builder(n_trees: usize) -> ARFConfigBuilder {
        ARFConfigBuilder {
            n_trees,
            lambda: 6.0,
            feature_fraction: 0.0,
            drift_delta: 1e-3,
            warning_delta: 1e-2,
            seed: 42,
        }
    }
}

impl ARFConfigBuilder {
    /// Sets the Poisson rate used for online bagging.
    pub fn lambda(mut self, lambda: f64) -> Self {
        self.lambda = lambda;
        self
    }

    /// Sets the per-member feature fraction (`0.0` means `sqrt(d) / d`).
    pub fn feature_fraction(mut self, f: f64) -> Self {
        self.feature_fraction = f;
        self
    }

    /// Sets the delta for the drift detectors.
    pub fn drift_delta(mut self, d: f64) -> Self {
        self.drift_delta = d;
        self
    }

    /// Sets the delta for the warning detectors.
    pub fn warning_delta(mut self, d: f64) -> Self {
        self.warning_delta = d;
        self
    }

    /// Sets the seed of the pseudo-random generator.
    pub fn seed(mut self, s: u64) -> Self {
        self.seed = s;
        self
    }

    /// Validates the collected values and produces an [`ARFConfig`].
    ///
    /// # Errors
    ///
    /// Returns a message naming the first offending field when
    /// `n_trees == 0`, `lambda` is not positive and finite,
    /// `feature_fraction` is outside `[0.0, 1.0]`, or either delta is outside
    /// `(0, 1)`. NaN is rejected for every floating-point field.
    pub fn build(self) -> Result<ARFConfig, String> {
        if self.n_trees == 0 {
            return Err("n_trees must be >= 1".into());
        }
        // Each check is phrased as "not inside the valid range": every
        // comparison with NaN is false, so testing for "outside the range"
        // would silently let NaN through.
        if !(self.lambda > 0.0 && self.lambda.is_finite()) {
            return Err("lambda must be positive and finite".into());
        }
        if !(0.0..=1.0).contains(&self.feature_fraction) {
            return Err("feature_fraction must be in [0.0, 1.0]".into());
        }
        if !(self.drift_delta > 0.0 && self.drift_delta < 1.0) {
            return Err("drift_delta must be in (0, 1)".into());
        }
        if !(self.warning_delta > 0.0 && self.warning_delta < 1.0) {
            return Err("warning_delta must be in (0, 1)".into());
        }
        Ok(ARFConfig {
            n_trees: self.n_trees,
            lambda: self.lambda,
            feature_fraction: self.feature_fraction,
            drift_delta: self.drift_delta,
            warning_delta: self.warning_delta,
            seed: self.seed,
        })
    }
}
164
/// One ensemble member together with its per-member bookkeeping.
struct ARFMember {
    /// Base learner produced by the forest's factory.
    learner: Box<dyn StreamingLearner>,
    /// Detector built from the configured `drift_delta`; fed the member's
    /// 0/1 prediction error on every sample.
    drift_detector: Adwin,
    /// Detector built from the configured `warning_delta`; fed the same
    /// error stream as `drift_detector`.
    warning_detector: Adwin,
    /// Sorted indices of the input features this member sees. An empty mask
    /// means the full feature vector is passed through.
    feature_mask: Vec<usize>,
    /// Number of evaluated samples the member predicted correctly.
    n_correct: u64,
    /// Number of samples the member has been evaluated on.
    n_evaluated: u64,
}
177
/// Ensemble of streaming learners with per-member feature subsets, Poisson
/// resampling of training samples and per-member drift detection.
pub struct AdaptiveRandomForest {
    /// Validated hyper-parameters.
    config: ARFConfig,
    /// Ensemble members, one per configured tree.
    trees: Vec<ARFMember>,
    /// Length of the first feature vector seen by `train_one`; `0` until then.
    n_features: usize,
    /// Number of calls to `train_one`.
    n_samples: u64,
    /// Number of times a member was replaced after a drift signal.
    n_drifts: usize,
    /// State of the xorshift generator shared by all random draws.
    rng_state: u64,
    /// Produces a fresh base learner for construction and drift replacement.
    factory: Box<dyn Fn() -> Box<dyn StreamingLearner> + Send + Sync>,
}
214
215impl AdaptiveRandomForest {
216 pub fn new<F>(config: ARFConfig, factory: F) -> Self
221 where
222 F: Fn() -> Box<dyn StreamingLearner> + Send + Sync + 'static,
223 {
224 let mut rng = config.seed;
225 let trees: Vec<ARFMember> = (0..config.n_trees)
226 .map(|_| {
227 let _ = xorshift64(&mut rng);
229 ARFMember {
230 learner: factory(),
231 drift_detector: Adwin::with_delta(config.drift_delta),
232 warning_detector: Adwin::with_delta(config.warning_delta),
233 feature_mask: Vec::new(),
234 n_correct: 0,
235 n_evaluated: 0,
236 }
237 })
238 .collect();
239
240 Self {
241 config,
242 trees,
243 n_features: 0,
244 n_samples: 0,
245 n_drifts: 0,
246 rng_state: rng,
247 factory: Box::new(factory),
248 }
249 }
250
251 fn init_feature_masks(&mut self) {
253 let d = self.n_features;
254 let fraction = if self.config.feature_fraction == 0.0 {
255 crate::math::sqrt(d as f64) / d as f64
256 } else {
257 self.config.feature_fraction
258 };
259 let k = (crate::math::ceil(fraction * d as f64) as usize)
260 .max(1)
261 .min(d);
262
263 for member in &mut self.trees {
264 let mut indices: Vec<usize> = (0..d).collect();
266 for i in 0..k {
267 let j = i + (xorshift64(&mut self.rng_state) as usize % (d - i));
268 indices.swap(i, j);
269 }
270 indices.truncate(k);
271 indices.sort_unstable();
272 member.feature_mask = indices;
273 }
274 }
275
276 fn mask_features(&self, features: &[f64], mask: &[usize]) -> Vec<f64> {
278 if mask.is_empty() {
279 features.to_vec()
280 } else {
281 mask.iter().map(|&i| features[i]).collect()
282 }
283 }
284
285 pub fn train_one(&mut self, features: &[f64], target: f64) {
287 if self.n_features == 0 {
288 self.n_features = features.len();
289 self.init_feature_masks();
290 }
291 self.n_samples += 1;
292
293 for i in 0..self.trees.len() {
294 let k = poisson(self.config.lambda, &mut self.rng_state);
295 let masked = self.mask_features(features, &self.trees[i].feature_mask);
296
297 let pred = self.trees[i].learner.predict(&masked);
299 let correct = crate::math::abs(crate::math::round(pred) - target) < 0.5;
300 self.trees[i].n_evaluated += 1;
301 if correct {
302 self.trees[i].n_correct += 1;
303 }
304
305 for _ in 0..k {
307 self.trees[i].learner.train(&masked, target);
308 }
309
310 let error_val = if correct { 0.0 } else { 1.0 };
312 let drift_signal = self.trees[i].drift_detector.update(error_val);
313 let _warning_signal = self.trees[i].warning_detector.update(error_val);
314
315 if matches!(drift_signal, DriftSignal::Drift) {
317 self.trees[i].learner = (self.factory)();
318 self.trees[i].drift_detector = Adwin::with_delta(self.config.drift_delta);
319 self.trees[i].warning_detector = Adwin::with_delta(self.config.warning_delta);
320 self.trees[i].n_correct = 0;
321 self.trees[i].n_evaluated = 0;
322 self.n_drifts += 1;
323
324 let d = self.n_features;
326 let fraction = if self.config.feature_fraction == 0.0 {
327 crate::math::sqrt(d as f64) / d as f64
328 } else {
329 self.config.feature_fraction
330 };
331 let k_features = (crate::math::ceil(fraction * d as f64) as usize)
332 .max(1)
333 .min(d);
334 let mut indices: Vec<usize> = (0..d).collect();
335 for j in 0..k_features {
336 let swap = j + (xorshift64(&mut self.rng_state) as usize % (d - j));
337 indices.swap(j, swap);
338 }
339 indices.truncate(k_features);
340 indices.sort_unstable();
341 self.trees[i].feature_mask = indices;
342 }
343 }
344 }
345
346 pub fn predict(&self, features: &[f64]) -> f64 {
351 let votes = self.predict_votes(features);
352 votes
353 .into_iter()
354 .max_by_key(|&(_, count)| count)
355 .map(|(class, _)| class)
356 .unwrap_or(0.0)
357 }
358
359 pub fn predict_votes(&self, features: &[f64]) -> Vec<(f64, u64)> {
361 let mut vote_map: Vec<(f64, u64)> = Vec::new();
362 for member in &self.trees {
363 let masked = self.mask_features(features, &member.feature_mask);
364 let pred = crate::math::round(member.learner.predict(&masked));
365 if let Some(entry) = vote_map
366 .iter_mut()
367 .find(|(c, _)| crate::math::abs(*c - pred) < 0.5)
368 {
369 entry.1 += 1;
370 } else {
371 vote_map.push((pred, 1));
372 }
373 }
374 vote_map
375 }
376
    /// Returns the configured number of ensemble members.
    pub fn n_trees(&self) -> usize {
        self.config.n_trees
    }
381
    /// Returns the value of the sample counter, which `train_one` increments
    /// once per call.
    pub fn n_samples_seen(&self) -> u64 {
        self.n_samples
    }
386
387 pub fn tree_accuracies(&self) -> Vec<f64> {
389 self.trees
390 .iter()
391 .map(|m| {
392 if m.n_evaluated == 0 {
393 0.0
394 } else {
395 m.n_correct as f64 / m.n_evaluated as f64
396 }
397 })
398 .collect()
399 }
400
    /// Returns how many times a member has been replaced after its drift
    /// detector signalled drift.
    pub fn n_drifts_detected(&self) -> usize {
        self.n_drifts
    }
405}
406
407impl StreamingLearner for AdaptiveRandomForest {
408 fn train_one(&mut self, features: &[f64], target: f64, _weight: f64) {
409 self.train_one(features, target);
410 }
411
412 fn predict(&self, features: &[f64]) -> f64 {
413 self.predict(features)
414 }
415
416 fn n_samples_seen(&self) -> u64 {
417 self.n_samples
418 }
419
420 fn reset(&mut self) {
421 self.n_samples = 0;
422 self.n_drifts = 0;
423 for member in &mut self.trees {
424 member.n_correct = 0;
425 member.n_evaluated = 0;
426 }
427 }
428}
429
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::boxed::Box;

    /// Learner stub that always predicts a fixed value and counts how many
    /// times it was trained.
    struct MockClassifier {
        prediction: f64,
        n: u64,
    }

    impl MockClassifier {
        fn new(prediction: f64) -> Self {
            Self { prediction, n: 0 }
        }
    }

    impl StreamingLearner for MockClassifier {
        fn train_one(&mut self, _features: &[f64], _target: f64, _weight: f64) {
            self.n += 1;
        }
        fn predict(&self, _features: &[f64]) -> f64 {
            self.prediction
        }
        fn n_samples_seen(&self) -> u64 {
            self.n
        }
        fn reset(&mut self) {
            self.n = 0;
        }
    }

    /// A single training step is counted and the constant mock prediction is
    /// returned by the ensemble.
    #[test]
    fn arf_trains_and_predicts() {
        let config = ARFConfig::builder(3).seed(42).build().unwrap();
        let mut arf = AdaptiveRandomForest::new(config, || Box::new(MockClassifier::new(1.0)));

        arf.train_one(&[1.0, 2.0], 1.0);
        let pred = arf.predict(&[1.0, 2.0]);
        assert_eq!(pred, 1.0);
        assert_eq!(arf.n_samples_seen(), 1);
    }

    /// Five identical members produce a single class with five votes.
    #[test]
    fn arf_majority_vote() {
        let config = ARFConfig::builder(5).seed(42).build().unwrap();
        let mut arf = AdaptiveRandomForest::new(config, || Box::new(MockClassifier::new(0.0)));
        // Set the feature count by hand so masks exist without training.
        arf.n_features = 2;
        arf.init_feature_masks();

        let votes = arf.predict_votes(&[1.0, 2.0]);
        assert_eq!(votes.len(), 1, "all trees should agree");
        assert_eq!(votes[0], (0.0, 5), "5 votes for class 0");
        assert_eq!(arf.predict(&[1.0, 2.0]), 0.0);
    }

    /// The empirical mean of the Poisson sampler is close to its rate.
    #[test]
    fn arf_poisson_valid() {
        let mut rng = 12345u64;
        let mut total = 0u64;
        let n = 1000;
        for _ in 0..n {
            total += poisson(6.0, &mut rng);
        }
        let mean = total as f64 / n as f64;
        assert!(
            (mean - 6.0).abs() < 1.0,
            "Poisson mean should be ~6.0, got {}",
            mean
        );
    }

    /// With four features and a fraction of 0.5 every member sees two.
    #[test]
    fn arf_feature_subspace() {
        let config = ARFConfig::builder(3)
            .feature_fraction(0.5)
            .seed(42)
            .build()
            .unwrap();
        let mut arf = AdaptiveRandomForest::new(config, || Box::new(MockClassifier::new(0.0)));

        arf.train_one(&[1.0, 2.0, 3.0, 4.0], 0.0);

        for member in &arf.trees {
            assert_eq!(
                member.feature_mask.len(),
                2,
                "expected 2 features, got {}",
                member.feature_mask.len()
            );
        }
    }

    /// The forest is usable through a `dyn StreamingLearner` reference.
    #[test]
    fn arf_streaming_learner_trait() {
        let config = ARFConfig::builder(3).seed(42).build().unwrap();
        let mut arf = AdaptiveRandomForest::new(config, || Box::new(MockClassifier::new(0.0)));

        let learner: &mut dyn StreamingLearner = &mut arf;
        learner.train(&[1.0, 2.0], 0.0);
        assert_eq!(learner.n_samples_seen(), 1);
        let pred = learner.predict(&[1.0, 2.0]);
        assert_eq!(pred, 0.0);
    }

    /// Out-of-range values are rejected for every validated field.
    #[test]
    fn arf_config_validates() {
        assert!(ARFConfig::builder(0).build().is_err());
        assert!(ARFConfig::builder(3).lambda(0.0).build().is_err());
        assert!(ARFConfig::builder(3).lambda(-1.0).build().is_err());
        assert!(ARFConfig::builder(3)
            .lambda(f64::INFINITY)
            .build()
            .is_err());
        assert!(ARFConfig::builder(3)
            .feature_fraction(-0.1)
            .build()
            .is_err());
        assert!(ARFConfig::builder(3).feature_fraction(1.1).build().is_err());
        assert!(ARFConfig::builder(3).drift_delta(0.0).build().is_err());
        assert!(ARFConfig::builder(3).drift_delta(1.0).build().is_err());
        assert!(ARFConfig::builder(3).warning_delta(0.0).build().is_err());
        assert!(ARFConfig::builder(3).warning_delta(1.0).build().is_err());
        assert!(ARFConfig::builder(3).build().is_ok());
    }

    /// A mock that always predicts the target yields near-perfect accuracy.
    #[test]
    fn arf_tree_accuracies() {
        let config = ARFConfig::builder(3).seed(42).build().unwrap();
        let mut arf = AdaptiveRandomForest::new(config, || Box::new(MockClassifier::new(1.0)));

        for _ in 0..10 {
            arf.train_one(&[1.0, 2.0], 1.0);
        }

        let accs = arf.tree_accuracies();
        assert_eq!(accs.len(), 3);
        for &acc in &accs {
            assert!(
                acc > 0.9,
                "accuracy should be ~1.0 for correct mock, got {}",
                acc
            );
        }
    }
}