use crate::dataset::Records;
use crate::traits::PredictInplace;
use ndarray::{Array1, Array2, ArrayBase, Axis, Data, Ix2};
use std::iter::FromIterator;
/// A collection of single-target predictors that are evaluated together.
///
/// Each stored model maps records of type `R` to a one-dimensional array of
/// labels `L`. The models are boxed trait objects, so predictors of different
/// concrete types can be combined in one `MultiTargetModel`.
pub struct MultiTargetModel<R, L>
where
    R: Records,
{
    /// The wrapped predictors, in the order they were supplied.
    models: Vec<Box<dyn PredictInplace<R, Array1<L>>>>,
}
impl<R, L> MultiTargetModel<R, L>
where
    R: Records,
{
    /// Builds a multi-target model from already boxed single-target models.
    ///
    /// The order of `models` is stored as given.
    pub fn new(models: Vec<Box<dyn PredictInplace<R, Array1<L>>>>) -> Self {
        Self { models }
    }
}
impl<L: Default, F, D: Data<Elem = F>> PredictInplace<ArrayBase<D, Ix2>, Array2<L>>
    for MultiTargetModel<ArrayBase<D, Ix2>, L>
{
    /// Runs every wrapped model on `arr` and stores the result of the `i`-th
    /// model in the `i`-th column of `targets`.
    ///
    /// The predictions are written into the existing `targets` buffer, so its
    /// shape and memory layout are left untouched and the same buffer can be
    /// passed to this method again.
    ///
    /// # Panics
    ///
    /// * if `targets` does not have shape `(arr.nrows(), number of models)`,
    /// * if `targets` is not in row-major (standard) layout,
    /// * if a wrapped model leaves its output with a length other than
    ///   `arr.nrows()`.
    fn predict_inplace(&self, arr: &ArrayBase<D, Ix2>, targets: &mut Array2<L>) {
        assert_eq!(
            targets.shape(),
            &[arr.nrows(), self.models.len()],
            "The number of data points must match the number of output targets."
        );
        assert!(
            targets.is_standard_layout(),
            "targets not in row-major layout"
        );

        // Iterating along axis 1 yields exactly one mutable column view per model.
        for (model, column) in self.models.iter().zip(targets.axis_iter_mut(Axis(1))) {
            let mut predicted = Array1::default(arr.nrows());
            model.predict_inplace(arr, &mut predicted);

            // A model is free to replace its output buffer; make sure it still
            // provides one value per data point before copying.
            assert_eq!(
                predicted.len(),
                arr.nrows(),
                "The number of predictions must match the number of data points."
            );

            // Move the values over in logical order. This needs no `Clone`
            // bound on `L` and does not depend on how `predicted` is laid out
            // in memory.
            for (slot, value) in column.into_iter().zip(predicted) {
                *slot = value;
            }
        }
    }

    /// Creates a default-initialised target array with one row per row of `x`
    /// and one column per wrapped model.
    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array2<L> {
        Array2::default((x.nrows(), self.models.len()))
    }
}
impl<F, D, L, P> FromIterator<P> for MultiTargetModel<ArrayBase<D, Ix2>, L>
where
    D: Data<Elem = F>,
    P: PredictInplace<ArrayBase<D, Ix2>, Array1<L>> + 'static,
{
    /// Collects an iterator of single-target models into a multi-target model.
    ///
    /// Every item is boxed as a trait object; the iteration order becomes the
    /// order of the stored models.
    fn from_iter<I: IntoIterator<Item = P>>(iter: I) -> Self {
        let boxed = iter.into_iter().map(
            |single| -> Box<dyn PredictInplace<ArrayBase<D, Ix2>, Array1<L>>> {
                Box::new(single)
            },
        );
        Self {
            models: boxed.collect(),
        }
    }
}
#[cfg(test)]
mod tests {
    use crate::{
        traits::{Predict, PredictInplace},
        MultiTargetModel,
    };
    use approx::assert_abs_diff_eq;
    use ndarray::{array, Array1, Array2, Axis};

    /// Test predictor that outputs the same constant for every data point.
    struct DummyModel {
        val: f32,
    }

    impl PredictInplace<Array2<f32>, Array1<f32>> for DummyModel {
        fn predict_inplace(&self, arr: &Array2<f32>, targets: &mut Array1<f32>) {
            assert_eq!(
                arr.nrows(),
                targets.len(),
                "The number of data points must match the number of output targets."
            );
            // Overwrite every entry with the configured constant.
            targets.iter_mut().for_each(|slot| *slot = self.val);
        }

        fn default_target(&self, x: &Array2<f32>) -> Array1<f32> {
            Array1::zeros(x.nrows())
        }
    }

    /// Test predictor that outputs `val, val + 1, val + 2, ...`, one value per
    /// data point.
    struct DummyModel2 {
        val: f32,
    }

    impl PredictInplace<Array2<f32>, Array1<f32>> for DummyModel2 {
        fn predict_inplace(&self, arr: &Array2<f32>, targets: &mut Array1<f32>) {
            assert_eq!(
                arr.nrows(),
                targets.len(),
                "The number of data points must match the number of output targets."
            );
            let n_points = arr.len_of(Axis(0));
            let first = self.val;
            let last = first + n_points as f32 - 1.0;
            // Replaces the buffer instead of filling it.
            *targets = Array1::linspace(first, last, n_points);
        }

        fn default_target(&self, x: &Array2<f32>) -> Array1<f32> {
            Array1::zeros(x.nrows())
        }
    }

    /// Four constant models must produce four constant columns.
    #[test]
    fn dummy_constant() {
        let model = (0..4)
            .map(|idx| DummyModel { val: idx as f32 })
            .collect::<MultiTargetModel<_, _>>();

        let records: Array2<f32> = Array2::zeros((5, 2));
        let targets = model.predict(&records);

        let expected = array![
            [0., 1., 2., 3.],
            [0., 1., 2., 3.],
            [0., 1., 2., 3.],
            [0., 1., 2., 3.],
            [0., 1., 2., 3.],
        ];
        assert_abs_diff_eq!(targets, expected);
    }

    /// Models of different concrete types can be combined via `new`.
    #[test]
    fn different_dummys() {
        let model_a = DummyModel { val: 42.0 };
        let model_b = DummyModel2 { val: 42.0 };
        let model = MultiTargetModel::new(vec![Box::new(model_a), Box::new(model_b)]);

        let records: Array2<f32> = Array2::zeros((5, 2));
        let targets = model.predict(&records);

        let expected = array![[42., 42.], [42., 43.], [42., 44.], [42., 45.], [42., 46.]];
        assert_abs_diff_eq!(targets, expected);
    }
}