Function dfdx::data::one_hot_encode
One-hot encodes an array of class labels into a Tensor2D of probability vectors. The result can be used together with cross_entropy_with_logits_loss().
Const Generic Arguments:
B - the batch size
N - the number of classes
Arguments:
class_labels - an array of size B where each element is the class label (an index in 0..N) of the corresponding batch item
Outputs: a Tensor2D with shape (B, N), where row i is the one-hot vector for class_labels[i]
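To illustrate what the encoding computes, here is a minimal dependency-free sketch. It is not dfdx code: one_hot_sketch is a hypothetical helper that returns plain nested arrays ([[f32; N]; B]) instead of a Tensor2D. It also returns None for a label outside 0..N. That range check is an assumption of this sketch and says nothing about how dfdx itself handles out-of-range labels.

```rust
/// Hypothetical, dependency-free sketch of one-hot encoding (not part of dfdx).
/// `B` is the batch size and `N` the number of classes, mirroring the const
/// generics above. Returns `None` if any label is outside `0..N` (an assumption
/// of this sketch, not documented dfdx behaviour).
fn one_hot_sketch<const B: usize, const N: usize>(
    class_labels: &[usize; B],
) -> Option<[[f32; N]; B]> {
    // Start from an all-zero B x N matrix.
    let mut out = [[0.0f32; N]; B];
    for (row, &label) in out.iter_mut().zip(class_labels.iter()) {
        // A label with no matching column cannot be encoded.
        if label >= N {
            return None;
        }
        // Put all of the probability mass on the labelled class.
        row[label] = 1.0;
    }
    Some(out)
}

fn main() {
    let probs = one_hot_sketch::<5, 3>(&[0, 1, 2, 1, 1]).unwrap();
    assert_eq!(probs[2], [0.0, 0.0, 1.0]);
    // Every row sums to 1, so each row is a valid probability vector.
    for row in probs.iter() {
        assert_eq!(row.iter().sum::<f32>(), 1.0);
    }
    // With N = 3, the label 3 is out of range.
    assert!(one_hot_sketch::<2, 3>(&[0, 3]).is_none());
    println!("{:?}", probs);
}
```

Using const generics for both dimensions means a mismatch between the label array length and B is a compile-time error rather than a runtime one.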
Examples:
use dfdx::prelude::*;

let class_labels = [0, 1, 2, 1, 1];
// NOTE: 5 is the batch size, 3 is the number of classes
let probs = one_hot_encode::<5, 3>(&class_labels);
assert_eq!(probs.data(), &[
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
    [0.0, 1.0, 0.0],
    [0.0, 1.0, 0.0],
]);
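The description mentions using one-hot vectors together with cross_entropy_with_logits_loss(). The sketch below shows why they pair well, using the textbook softmax cross-entropy on plain arrays. cross_entropy_with_logits_sketch is a hypothetical helper, not the dfdx function. Its batch-mean reduction is an assumption, and dfdx's exact scaling may differ. With a one-hot target, the inner sum keeps only the log-probability of the labelled class.

```rust
/// Hypothetical sketch of softmax cross-entropy against target probability
/// vectors (not dfdx's implementation; the batch-mean reduction is an assumption).
fn cross_entropy_with_logits_sketch<const B: usize, const N: usize>(
    logits: &[[f32; N]; B],
    targets: &[[f32; N]; B],
) -> f32 {
    let mut total = 0.0f32;
    for (z, t) in logits.iter().zip(targets.iter()) {
        // Numerically stable log-sum-exp: subtract the row maximum before exp.
        let max = z.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let log_sum_exp = max + z.iter().map(|&v| (v - max).exp()).sum::<f32>().ln();
        // -sum_j t_j * log_softmax(z)_j; a one-hot t selects a single term.
        let row_loss: f32 = z
            .iter()
            .zip(t.iter())
            .map(|(&zj, &tj)| -tj * (zj - log_sum_exp))
            .sum();
        total += row_loss;
    }
    // Average over the batch.
    total / B as f32
}

fn main() {
    // One-hot targets for labels [0, 1] with 3 classes.
    let targets = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]];
    // Equal logits give a uniform softmax, so each row's loss is ln(3).
    let uniform = [[0.0; 3]; 2];
    let loss = cross_entropy_with_logits_sketch(&uniform, &targets);
    assert!((loss - 3f32.ln()).abs() < 1e-6);
    // Logits that favour the labelled class give a lower loss.
    let confident = [[5.0, 0.0, 0.0], [0.0, 5.0, 0.0]];
    assert!(cross_entropy_with_logits_sketch(&confident, &targets) < loss);
    println!("uniform loss = {loss}");
}
```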