//! rafor 0.3.0
//!
//! Fast Random Forest library.
use std::sync::{
    atomic::{AtomicUsize, Ordering},
    Arc,
};
use std::thread;

/// A trained model that can score a dataset.
///
/// `dataset` is a flat slice of `f32` values; its layout (presumably a
/// row-major samples × features matrix — confirm against implementors) is
/// interpreted by each implementation. When several predictors are combined
/// by [`predict`], their outputs are summed element-wise, so implementations
/// must return vectors of equal length for the same input (or an empty
/// vector); mismatched non-empty lengths make the aggregation panic.
pub trait Predictor {
    /// Evaluates the model on `dataset` and returns its raw output values.
    fn predict(&self, dataset: &[f32]) -> Vec<f32>;
}

/// Averages the outputs of an ensemble of predictors over `dataset`.
///
/// Every predictor is evaluated on the same `dataset`. The returned vectors
/// are summed element-wise (via [`Aggregate`]), and each element is then
/// divided by the number of predictors.
///
/// `num_threads` controls parallelism. `0` and `1` both run sequentially on
/// the calling thread. Larger values spawn at most `predictors.len()` scoped
/// worker threads, which pull predictor indices from a shared counter.
///
/// Returns an empty vector when `predictors` is empty, or when every
/// predictor returns an empty vector.
///
/// In the multi-threaded path, the assignment of predictors to threads
/// depends on scheduling. The floating-point summation order, and hence the
/// lowest bits of the result, may therefore differ between runs.
///
/// # Panics
///
/// Panics if two non-empty predictor outputs differ in length. Also panics
/// if any predictor panics; a worker thread's panic is re-raised on the
/// calling thread with its original payload.
pub fn predict<P: Predictor + Sync + Send>(
    predictors: &Vec<P>,
    dataset: &[f32],
    num_threads: usize,
) -> Vec<f32> {
    // Never spawn more workers than there are predictors. This also folds a
    // request for 0 threads into the sequential path. Previously it spawned
    // no workers at all and silently returned an empty result.
    let num_threads = num_threads.min(predictors.len());

    let mut result: Vec<f32> = Vec::new();
    if num_threads <= 1 {
        for p in predictors.iter() {
            result.aggregate(&p.predict(dataset));
        }
    } else {
        // Shared work queue: each worker claims the next unprocessed
        // predictor index. `Relaxed` suffices because the counter only has to
        // hand out unique indices. The partial sums reach this thread through
        // `join`, which synchronizes.
        let task_id = Arc::new(AtomicUsize::new(0));
        thread::scope(|s| {
            let mut handles = Vec::with_capacity(num_threads);
            for _ in 0..num_threads {
                handles.push(s.spawn(|| {
                    let mut thread_result: Vec<f32> = Vec::new();
                    while let Some(p) = predictors.get(task_id.fetch_add(1, Ordering::Relaxed)) {
                        thread_result.aggregate(&p.predict(dataset));
                    }
                    thread_result
                }));
            }
            for handle in handles {
                // Re-raise a worker's panic with its original payload, rather
                // than replacing it with a generic `unwrap` failure.
                let partial = handle
                    .join()
                    .unwrap_or_else(|payload| std::panic::resume_unwind(payload));
                result.aggregate(&partial);
            }
        });
    }

    let count = predictors.len() as f32;
    for x in result.iter_mut() {
        *x /= count;
    }
    result
}

/// In-place element-wise accumulation of prediction vectors.
pub trait Aggregate {
    /// Adds `other` into `self`, element by element.
    ///
    /// An empty `self` acts as an unset accumulator and becomes a copy of
    /// `other`. An empty `other` leaves a non-empty `self` unchanged.
    ///
    /// # Panics
    ///
    /// Panics if both vectors are non-empty and their lengths differ.
    fn aggregate(&mut self, other: &Vec<f32>);
}

impl Aggregate for Vec<f32> {
    fn aggregate(&mut self, other: &Vec<f32>) {
        if !self.is_empty() {
            assert!(other.is_empty() || self.len() == other.len());
            for (acc, &value) in self.iter_mut().zip(other) {
                *acc += value;
            }
            return;
        }
        // Nothing accumulated yet: adopt `other` wholesale.
        *self = other.clone();
    }
}