use std::vec::Vec;
use crate::{
shapes::*,
tensor::{Storage, Tensor, TensorFromVec, ZerosTensor},
};
/// Builds one-hot encoded tensors from integer class labels.
///
/// Any device that provides [`Storage`], [`ZerosTensor`] and [`TensorFromVec`]
/// for `E` gets this trait through the blanket impl below.
pub trait OneHotEncode<E: Dtype>: Storage<E> + ZerosTensor<E> + TensorFromVec<E> {
    /// One-hot encodes `labels` into a tensor of shape `(labels.dim(), n)`.
    ///
    /// Row `r` of the result is all zeros except for a one in column
    /// `labels[r]`. A label that is `>= n` matches no column, so it produces an
    /// all-zero row rather than an error.
    ///
    /// # Panics
    ///
    /// Panics if `E::from_usize` cannot represent `0` or `1`. It may also panic
    /// inside `tensor_from_vec` if `labels` yields a number of items different
    /// from `labels.dim().size()` (NOTE(review): depends on `tensor_from_vec`'s
    /// length check, not visible here).
    fn one_hot_encode<Lbls: Array<usize>, N: Dim>(
        &self,
        n: N,
        labels: Lbls,
    ) -> Tensor<(Lbls::Dim, N), E, Self> {
        let rows = labels.dim();
        let num_classes = n.size();

        // Convert the two fill values once, instead of once per element.
        let one = E::from_usize(1).expect("dtype must be able to represent 1");
        let zero = E::from_usize(0).expect("dtype must be able to represent 0");

        // Row-major: each label contributes one contiguous row of `num_classes` values.
        let mut data = Vec::with_capacity(rows.size() * num_classes);
        for label in labels.into_iter() {
            data.extend((0..num_classes).map(|class| if class == label { one } else { zero }));
        }
        self.tensor_from_vec(data, (rows, n))
    }
}
/// Blanket impl: every device meeting the storage and construction bounds uses
/// the default `one_hot_encode` provided by the trait.
impl<E, D> OneHotEncode<E> for D
where
    E: Dtype,
    D: Storage<E> + ZerosTensor<E> + TensorFromVec<E>,
{
}