// imbalanced_ensemble/balanced_random_forest.rs

use std::fmt;
use std::sync::Arc;

use imbalanced_core::traits::*;
use ndarray::{Array1, Array2, ArrayView1, ArrayView2};

6pub struct BalancedRandomForest {
8 _n_estimators: usize,
10 max_depth: Option<usize>,
12 _resampler: Arc<dyn ResamplingStrategy<Input = (), Output = (Array2<f64>, Array1<i32>), Config = ()>>,
14}
15
16impl BalancedRandomForest {
17 pub fn new(n_estimators: usize) -> Self {
19 Self {
20 _n_estimators: n_estimators,
21 max_depth: None,
22 _resampler: Arc::new(DummyResampler),
23 }
24 }
25
26 pub fn with_max_depth(mut self, max_depth: usize) -> Self {
28 self.max_depth = Some(max_depth);
29 self
30 }
31}
32
/// Placeholder resampler that hands back owned copies of its inputs unchanged.
///
/// It is a zero-sized unit struct, so `Copy`/`Default` are free and `Debug` keeps it
/// printable in diagnostics.
#[derive(Debug, Clone, Copy, Default)]
struct DummyResampler;

36impl ResamplingStrategy for DummyResampler {
37 type Input = ();
38 type Output = (Array2<f64>, Array1<i32>);
39 type Config = ();
40
41 fn resample(
42 &self,
43 x: ArrayView2<f64>,
44 y: ArrayView1<i32>,
45 _config: &Self::Config,
46 ) -> Result<(Array2<f64>, Array1<i32>), ResamplingError> {
47 Ok((x.to_owned(), y.to_owned()))
49 }
50}