burn-tensor 0.16.1

Tensor library with user-friendly APIs and automatic differentiation support
Documentation
// Test suite for the `one_hot` constructors on float and int tensors.
// `TestTensor` and `TestTensorInt` are not defined here; they come into scope
// through `use super::*`.
#[burn_tensor_testgen::testgen(one_hot)]
mod tests {
    use super::*;
    // NOTE(review): `Int` is not referenced anywhere in this module — confirm
    // whether the import is still needed before removing it.
    use burn_tensor::{Int, TensorData};

    // `one_hot(index, num_classes, &device)` on a float tensor should yield a
    // vector of length `num_classes` holding 1 at `index` and 0 elsewhere.
    // Covers the first, an interior, and the last index, plus the two-class case.
    #[test]
    fn float_should_support_one_hot() {
        let device = Default::default();

        // The `false` argument is presumably the "strict" flag of `assert_eq`
        // (the expected data is built from untyped float literals) — confirm
        // against `TensorData::assert_eq`.
        let tensor = TestTensor::<1>::one_hot(0, 5, &device);
        let expected = TensorData::from([1., 0., 0., 0., 0.]);
        tensor.into_data().assert_eq(&expected, false);

        let tensor = TestTensor::<1>::one_hot(1, 5, &device);
        let expected = TensorData::from([0., 1., 0., 0., 0.]);
        tensor.into_data().assert_eq(&expected, false);

        let tensor = TestTensor::<1>::one_hot(4, 5, &device);
        let expected = TensorData::from([0., 0., 0., 0., 1.]);
        tensor.into_data().assert_eq(&expected, false);

        let tensor = TestTensor::<1>::one_hot(1, 2, &device);
        let expected = TensorData::from([0., 1.]);
        tensor.into_data().assert_eq(&expected, false);
    }

    // Index 1 with a single class (valid indices would be 0 only) must panic.
    #[test]
    #[should_panic]
    fn float_one_hot_should_panic_when_index_exceeds_number_of_classes() {
        let device = Default::default();
        // Only the panic matters; the result is discarded explicitly.
        let _ = TestTensor::<1>::one_hot(1, 1, &device);
    }

    // Requesting zero classes must panic.
    // NOTE(review): with zero classes every index is also out of range, so this
    // cannot tell a class-count check apart from an index-bound check.
    #[test]
    #[should_panic]
    fn float_one_hot_should_panic_when_number_of_classes_is_zero() {
        let device = Default::default();
        let _ = TestTensor::<1>::one_hot(0, 0, &device);
    }

    // One-hot encoding the indices 0..5 with 5 classes should equal the 5x5
    // identity matrix produced by `eye`.
    #[test]
    fn int_should_support_one_hot() {
        let device = Default::default();

        let index_tensor = TestTensorInt::<1>::arange(0..5, &device);
        let one_hot_tensor = index_tensor.one_hot(5);
        let expected = TestTensorInt::eye(5, &device).into_data();
        one_hot_tensor.into_data().assert_eq(&expected, false);
    }

    // Indices 0..6 include the value 5, which is not a valid class when only
    // 5 classes are requested; the call must panic.
    #[test]
    #[should_panic]
    fn int_one_hot_should_panic_when_index_exceeds_number_of_classes() {
        let device = Default::default();
        let index_tensor = TestTensorInt::<1>::arange(0..6, &device);
        // Only the panic matters; the result is discarded explicitly.
        let _ = index_tensor.one_hot(5);
    }

    // Requesting zero classes for an int index tensor must panic.
    // NOTE(review): all indices in 0..3 are also out of range for zero classes,
    // so the exact failing check is not isolated by this test.
    #[test]
    #[should_panic]
    fn int_one_hot_should_panic_when_number_of_classes_is_zero() {
        let device = Default::default();
        let index_tensor = TestTensorInt::<1>::arange(0..3, &device);
        let _ = index_tensor.one_hot(0);
    }

    // Requesting a single class must panic for the indices 0..3.
    // NOTE(review): indices 1 and 2 exceed a one-class range, so this would
    // also panic on an index-bound check alone — confirm which condition the
    // test is meant to pin (an input of only index 0 would isolate it).
    #[test]
    #[should_panic]
    fn int_one_hot_should_panic_when_number_of_classes_is_1() {
        let device = Default::default();
        let index_tensor = TestTensorInt::<1>::arange(0..3, &device);
        let _ = index_tensor.one_hot(1);
    }
}