cubek_test_utils/test_tensor/
host_data.rs

1use cubecl::{
2    CubeElement, TestRuntime, client::ComputeClient, prelude::CubePrimitive,
3    std::tensor::TensorHandle,
4};
5
6use crate::test_tensor::cast::copy_casted;
7
8#[derive(Debug)]
9pub struct HostData {
10    pub data: HostDataVec,
11    pub shape: Vec<usize>,
12    pub strides: Vec<usize>,
13}
14
/// Element type to read a tensor back as; selects which [`HostDataVec`] variant is produced.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum HostDataType {
    /// Read the tensor as `f32` values.
    F32,
    /// Read the tensor as booleans (non-zero bytes become `true`).
    Bool,
}
20
/// Host-side element storage, with one variant per supported element type.
#[derive(Debug, Clone)]
pub enum HostDataVec {
    /// Elements read back as `f32`.
    F32(Vec<f32>),
    /// Elements read back as booleans.
    Bool(Vec<bool>),
}
26
27impl HostDataVec {
28    pub fn get_f32(&self, i: usize) -> f32 {
29        match self {
30            HostDataVec::F32(items) => items[i],
31            HostDataVec::Bool(_) => panic!("Can't get bool as f32"),
32        }
33    }
34
35    pub fn get_bool(&self, i: usize) -> bool {
36        match self {
37            HostDataVec::F32(_) => panic!("Can't get bool as f32"),
38            HostDataVec::Bool(items) => items[i],
39        }
40    }
41}
42
43impl HostData {
44    pub fn from_tensor_handle(
45        client: &ComputeClient<TestRuntime>,
46        tensor_handle: &TensorHandle<TestRuntime>,
47        host_data_type: HostDataType,
48    ) -> Self {
49        let shape = tensor_handle.shape.clone();
50        let strides = tensor_handle.strides.clone();
51
52        let data = match host_data_type {
53            HostDataType::F32 => {
54                let handle = copy_casted(client, tensor_handle, f32::as_type_native_unchecked());
55                let data = f32::from_bytes(&client.read_one_tensor(handle.as_copy_descriptor()))
56                    .to_owned();
57
58                HostDataVec::F32(data)
59            }
60            HostDataType::Bool => {
61                let handle = copy_casted(client, tensor_handle, u8::as_type_native_unchecked());
62                let data =
63                    u8::from_bytes(&client.read_one_tensor(handle.as_copy_descriptor())).to_owned();
64
65                HostDataVec::Bool(data.iter().map(|&x| x > 0).collect())
66            }
67        };
68
69        Self {
70            data,
71            shape,
72            strides,
73        }
74    }
75
76    pub fn get_f32(&self, index: &[usize]) -> f32 {
77        self.data.get_f32(self.strided_index(index))
78    }
79
80    pub fn get_bool(&self, index: &[usize]) -> bool {
81        self.data.get_bool(self.strided_index(index))
82    }
83
84    fn strided_index(&self, index: &[usize]) -> usize {
85        let mut i = 0usize;
86        for (d, idx) in index.iter().enumerate() {
87            i += idx * self.strides[d];
88        }
89        i
90    }
91}