use imbalanced_core::traits::*;
use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use std::sync::Arc;
/// Configuration for a balanced random forest.
///
/// Only construction and configuration are visible here. The underscore-prefixed
/// fields are not read anywhere in this view.
/// NOTE(review): presumably a training/prediction implementation will consume
/// them — confirm.
pub struct BalancedRandomForest {
    /// Estimator count passed to `BalancedRandomForest::new`
    /// (presumably the number of trees — TODO confirm).
    _n_estimators: usize,
    /// Optional maximum depth. It stays `None` until
    /// `BalancedRandomForest::with_max_depth` sets it.
    max_depth: Option<usize>,
    /// Shared, type-erased resampling strategy that yields an owned
    /// `(Array2<f64>, Array1<i32>)` pair.
    _resampler: Arc<
        dyn ResamplingStrategy<
            Input = (),
            Output = (Array2<f64>, Array1<i32>),
            Config = (),
        >,
    >,
}
impl BalancedRandomForest {
    /// Creates a configuration with `n_estimators` estimators.
    ///
    /// `max_depth` starts out as `None`. The resampler is a pass-through
    /// `DummyResampler`, which returns owned copies of its inputs unchanged.
    #[must_use]
    pub fn new(n_estimators: usize) -> Self {
        Self {
            _n_estimators: n_estimators,
            max_depth: None,
            _resampler: Arc::new(DummyResampler),
        }
    }

    /// Sets the maximum depth and returns the updated configuration.
    ///
    /// This consumes `self`, so the returned value must be used. Discarding it
    /// silently drops the setting, which is why it is `#[must_use]`.
    #[must_use]
    pub fn with_max_depth(mut self, max_depth: usize) -> Self {
        self.max_depth = Some(max_depth);
        self
    }
}
/// Pass-through resampler: returns owned copies of its inputs unchanged.
///
/// `BalancedRandomForest::new` installs this as the default strategy.
struct DummyResampler;

impl ResamplingStrategy for DummyResampler {
    type Input = ();
    type Output = (Array2<f64>, Array1<i32>);
    type Config = ();

    /// Copies `x` and `y` into owned arrays without modification.
    /// The config argument is ignored.
    ///
    /// # Errors
    /// Never returns `Err`.
    fn resample(
        &self,
        x: ArrayView2<f64>,
        y: ArrayView1<i32>,
        _config: &Self::Config,
    ) -> Result<Self::Output, ResamplingError> {
        let owned_x = x.to_owned();
        let owned_y = y.to_owned();
        Ok((owned_x, owned_y))
    }
}