Trait dfdx::data::OneHotEncode

pub trait OneHotEncode<E: Dtype>: Storage<E> + ZerosTensor<E> + TensorFromVec<E> {
    // Provided method
    fn one_hot_encode<Lbls: Array<usize>, N: Dim>(
        &self,
        n: N,
        labels: Lbls
    ) -> Tensor<(Lbls::Dim, N), E, Self> { ... }
}

One-hot encodes an array of class labels into a 2D tensor of probability vectors, one row per label. The result can be passed as the target probabilities to crate::losses::cross_entropy_with_logits_loss().
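To make the pairing with a cross-entropy loss concrete, here is a minimal std-only Rust sketch (it does not use dfdx). It shows the row-major data a one-hot tensor holds and how a cross-entropy-with-logits loss consumes it as target probabilities. The helper names one_hot and cross_entropy_with_logits are hypothetical. The loss is the standard mean over rows of -sum(target * log_softmax(logits)); that it matches dfdx's exact reduction is an assumption.

```rust
/// Hypothetical helper: one-hot encodes `labels` into a row-major
/// `labels.len() x n` buffer (row i has a 1.0 at column labels[i]).
fn one_hot(n: usize, labels: &[usize]) -> Vec<f32> {
    let mut out = vec![0.0f32; labels.len() * n];
    for (row, &label) in labels.iter().enumerate() {
        // Without this check an out-of-range label could silently land in the next row.
        assert!(label < n, "label {} out of range for {} classes", label, n);
        out[row * n + label] = 1.0;
    }
    out
}

/// Hypothetical helper: mean over rows of -sum_c target[c] * log_softmax(logits)[c].
fn cross_entropy_with_logits(logits: &[f32], targets: &[f32], n: usize) -> f32 {
    assert_eq!(logits.len(), targets.len());
    let batch = logits.len() / n;
    let mut total = 0.0f32;
    for (lrow, trow) in logits.chunks(n).zip(targets.chunks(n)) {
        // Numerically stable log-sum-exp: subtract the row max before exponentiating.
        let max = lrow.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let lse = max + lrow.iter().map(|&x| (x - max).exp()).sum::<f32>().ln();
        // log_softmax(x) = x - lse; only the class with target 1.0 contributes.
        let row_loss: f32 = lrow.iter().zip(trow).map(|(&x, &t)| -t * (x - lse)).sum();
        total += row_loss;
    }
    total / batch as f32
}

fn main() {
    let targets = one_hot(3, &[0, 2]);
    let logits = [0.0f32; 6]; // two rows of uniform logits over 3 classes
    let loss = cross_entropy_with_logits(&logits, &targets, 3);
    println!("{}", loss);
}
```

With uniform logits every class gets probability 1/3, so the loss is ln 3 whichever labels are encoded.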

Provided Methods

fn one_hot_encode<Lbls: Array<usize>, N: Dim>( &self, n: N, labels: Lbls ) -> Tensor<(Lbls::Dim, N), E, Self>

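One-hot encodes an array or Vec of class labels into a tensor of shape (number of labels, n): row i has a 1 at column labels[i] and 0 everywhere else.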

Arguments:

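  • n - the number of classes to encode; can be a Const (known at compile time) or a usize (known at runtime)
  • labels - the class labels, either an array [usize; M] (giving a Const row dimension) or a Vec<usize> (giving a usize row dimension)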
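Choosing Const or usize for n, and an array or a Vec for labels, decides which output dimensions are fixed at compile time. As a rough std-only analogy (not dfdx code), const generics can fix both the number of labels M and the number of classes N in the return type. The function name one_hot_array is hypothetical.

```rust
/// Hypothetical compile-time-shaped one-hot encoder: M labels, N classes.
/// Both dimensions live in the return type, like a Tensor<Rank2<M, N>, ...>.
fn one_hot_array<const M: usize, const N: usize>(labels: [usize; M]) -> [[f32; N]; M] {
    let mut out = [[0.0f32; N]; M];
    // Arrays iterate by value in Rust 2021, so `label` is a plain usize.
    for (row, label) in out.iter_mut().zip(labels) {
        assert!(label < N, "label {} out of range for {} classes", label, N);
        row[label] = 1.0;
    }
    out
}

fn main() {
    // M = 5 comes from the array literal; N = 3 comes from the annotation,
    // much like Const::<3> fixes the class count at compile time.
    let probs: [[f32; 3]; 5] = one_hot_array([0, 1, 2, 1, 1]);
    for row in &probs {
        println!("{:?}", row);
    }
}
```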

Const class labels and const n:

use dfdx::{prelude::*, data::OneHotEncode};
let dev: Cpu = Default::default();
let class_labels = [0, 1, 2, 1, 1];
let probs: Tensor<Rank2<5, 3>, f32, _> = dev.one_hot_encode(Const::<3>, class_labels);
assert_eq!(probs.array(), [
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
    [0.0, 1.0, 0.0],
    [0.0, 1.0, 0.0],
]);

Const class labels and runtime n:

let class_labels = [0, 1, 2, 1, 1];
let probs: Tensor<(Const<5>, usize), f32, _> = dev.one_hot_encode(3, class_labels);
assert_eq!(&probs.as_vec(), &[
    1.0, 0.0, 0.0,
    0.0, 1.0, 0.0,
    0.0, 0.0, 1.0,
    0.0, 1.0, 0.0,
    0.0, 1.0, 0.0,
]);

Runtime class labels and const n:

let class_labels = std::vec![0, 1, 2, 1, 1];
let probs: Tensor<(usize, Const<3>), f32, _> = dev.one_hot_encode(Const, class_labels);
assert_eq!(&probs.as_vec(), &[
    1.0, 0.0, 0.0,
    0.0, 1.0, 0.0,
    0.0, 0.0, 1.0,
    0.0, 1.0, 0.0,
    0.0, 1.0, 0.0,
]);

Runtime class labels and runtime n:

let class_labels = std::vec![0, 1, 2, 1, 1];
let probs: Tensor<(usize, usize), f32, _> = dev.one_hot_encode(3, class_labels);
assert_eq!(&probs.as_vec(), &[
    1.0, 0.0, 0.0,
    0.0, 1.0, 0.0,
    0.0, 0.0, 1.0,
    0.0, 1.0, 0.0,
    0.0, 1.0, 0.0,
]);

Implementors

impl<E: Dtype, D: Storage<E> + ZerosTensor<E> + TensorFromVec<E>> OneHotEncode<E> for D
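Because of this blanket impl, OneHotEncode never has to be implemented by hand: any device type that satisfies the three supertrait bounds gets one_hot_encode through the provided method. The pattern can be sketched in std-only Rust. FromVec, Matrix, OneHot and ToyCpu are hypothetical, simplified stand-ins, not dfdx types.

```rust
/// Hypothetical stand-in for a device capability like TensorFromVec.
trait FromVec {
    fn from_vec(&self, data: Vec<f32>, shape: (usize, usize)) -> Matrix;
}

/// Hypothetical minimal 2D tensor: row-major data plus its shape.
#[derive(Debug, PartialEq)]
struct Matrix {
    shape: (usize, usize),
    data: Vec<f32>,
}

/// A trait whose only method is provided, built on its supertrait.
trait OneHot: FromVec {
    fn one_hot(&self, n: usize, labels: &[usize]) -> Matrix {
        let mut data = vec![0.0; labels.len() * n];
        for (i, &label) in labels.iter().enumerate() {
            assert!(label < n, "label {} out of range for {} classes", label, n);
            data[i * n + label] = 1.0;
        }
        self.from_vec(data, (labels.len(), n))
    }
}

/// Blanket impl: every type implementing the supertrait gets one_hot for free.
impl<D: FromVec> OneHot for D {}

/// Hypothetical toy device; it only implements FromVec.
struct ToyCpu;

impl FromVec for ToyCpu {
    fn from_vec(&self, data: Vec<f32>, shape: (usize, usize)) -> Matrix {
        assert_eq!(data.len(), shape.0 * shape.1);
        Matrix { shape, data }
    }
}

fn main() {
    let dev = ToyCpu;
    // one_hot is available on ToyCpu purely through the blanket impl.
    let m = dev.one_hot(3, &[0, 2]);
    println!("{:?}", m);
}
```

The design keeps the encoding logic in one place: a new device only has to provide the low-level capabilities, not the encoder itself.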