cubecl-std 0.7.1

CubeCL Standard Library.
Documentation
use core::fmt::Display;

use cubecl_core::{
    CubeElement,
    prelude::{Numeric, Runtime},
};

use super::test_utils::identity_cpu;
use crate::tensor::{self, TensorHandle};

pub fn test_identity<R: Runtime, C: Numeric + CubeElement + Display>(
    device: &R::Device,
    dim: usize,
) {
    let client = R::client(device);

    let expected = identity_cpu::<C>(dim);

    let identity = TensorHandle::<R, C>::empty(&client, [dim, dim].to_vec());
    tensor::identity::launch(&client, &identity);

    let actual = client.read_one_tensor(identity.handle.clone().copy_descriptor(
        &identity.shape,
        &identity.strides,
        size_of::<C>(),
    ));
    let actual = C::from_bytes(&actual);

    assert_eq!(&expected[..], actual, "identity matrices are not equal.");
}