use crate::dataset::{Pr, Records};
use crate::traits::PredictInplace;
use ndarray::{Array1, ArrayBase, Data, Ix2};
use std::iter::FromIterator;
/// `(label, model)` pairs: each boxed model yields one probability per record,
/// and `label` is the class that probability is attributed to.
type MultiClassVec<R, L> = Vec<(L, Box<dyn PredictInplace<R, Array1<Pr>>>)>;

/// A classifier assembled from several probability-producing models.
///
/// Each stored model is tagged with a class label `L`. When predicting, every
/// model is evaluated and, per record, the label whose model reported the
/// highest probability is selected.
pub struct MultiClassModel<R: Records, L> {
    models: MultiClassVec<R, L>,
}

impl<R: Records, L> MultiClassModel<R, L> {
    /// Wraps an already boxed list of `(label, model)` pairs.
    ///
    /// The order of `models` matters: on equal probabilities the earlier entry wins.
    pub fn new(models: MultiClassVec<R, L>) -> Self {
        Self { models }
    }
}
impl<L: Clone + Default, F, D: Data<Elem = F>> PredictInplace<ArrayBase<D, Ix2>, Array1<L>>
for MultiClassModel<ArrayBase<D, Ix2>, L>
{
fn predict_inplace(&self, arr: &ArrayBase<D, Ix2>, y: &mut Array1<L>) {
assert_eq!(
arr.nrows(),
y.len(),
"The number of data points must match the number of output targets."
);
let mut res = Vec::new();
for pairs in self.models.iter().map(|(elm, model)| {
let mut targets = Array1::default(arr.nrows());
model.predict_inplace(arr, &mut targets);
targets.into_iter().map(|x| (elm.clone(), *x)).collect()
}) {
if res.is_empty() {
res = pairs;
continue;
}
res = res
.into_iter()
.zip(pairs.into_iter())
.map(|(c, d)| if d.1 > c.1 { d } else { c })
.collect();
}
for (r, target) in res.into_iter().zip(y.iter_mut()) {
*target = r.0;
}
}
fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<L> {
Array1::default(x.nrows())
}
}
impl<F, D: Data<Elem = F>, L, P: PredictInplace<ArrayBase<D, Ix2>, Array1<Pr>> + 'static>
FromIterator<(L, P)> for MultiClassModel<ArrayBase<D, Ix2>, L>
{
fn from_iter<I: IntoIterator<Item = (L, P)>>(iter: I) -> Self {
let models = iter
.into_iter()
.map(|(l, x)| {
(
l,
Box::new(x) as Box<dyn PredictInplace<ArrayBase<D, Ix2>, Array1<Pr>>>,
)
})
.collect();
MultiClassModel { models }
}
}
#[cfg(test)]
mod tests {
    use crate::{
        dataset::Pr,
        traits::{Predict, PredictInplace},
        MultiClassModel,
    };
    use ndarray::{array, Array1, Array2};

    /// Test double that assigns probability one to rows of a given parity.
    ///
    /// With `on_even == false` odd rows score 1.0; with `on_even == true` even rows do.
    struct DummyModel {
        on_even: bool,
    }

    impl PredictInplace<Array2<f32>, Array1<Pr>> for DummyModel {
        fn predict_inplace(&self, arr: &Array2<f32>, targets: &mut Array1<Pr>) {
            assert_eq!(
                arr.nrows(),
                targets.len(),
                "The number of data points must match the number of output targets."
            );
            let on_even = self.on_even;
            *targets = Array1::from_shape_fn(arr.nrows(), |row| {
                // A row "hits" when it is odd XOR the model targets even rows.
                let is_odd = row % 2 == 1;
                let hit = is_odd != on_even;
                Pr::new(if hit { 1.0 } else { 0.0 })
            });
        }

        fn default_target(&self, x: &Array2<f32>) -> Array1<Pr> {
            Array1::default(x.nrows())
        }
    }

    #[test]
    fn correct_dummies() {
        let data = Array2::zeros((4, 2));
        let odd_model = DummyModel { on_even: false };
        let even_model = DummyModel { on_even: true };

        let expected_odd = array![0.0, 1.0, 0.0, 1.0].mapv(Pr::new);
        let expected_even = array![1.0, 0.0, 1.0, 0.0].mapv(Pr::new);

        assert_eq!(odd_model.predict(&data), expected_odd);
        assert_eq!(even_model.predict(&data), expected_even);
    }

    #[test]
    fn choose_correct() {
        let pairs = vec![
            (0, DummyModel { on_even: false }),
            (1, DummyModel { on_even: true }),
        ];
        let model: MultiClassModel<_, usize> = pairs.into_iter().collect();

        let data = Array2::zeros((4, 2));
        let predicted = model.predict(&data);
        assert_eq!(predicted, array![1, 0, 1, 0]);
    }
}