cubek_test_utils/test_tensor/
host_data.rs1use cubecl::{
2 CubeElement, TestRuntime, client::ComputeClient, prelude::CubePrimitive,
3 std::tensor::TensorHandle,
4};
5
6use crate::test_tensor::cast::copy_casted;
7
8#[derive(Debug)]
9pub struct HostData {
10 pub data: HostDataVec,
11 pub shape: Vec<usize>,
12 pub strides: Vec<usize>,
13}
14
/// Element type a tensor is converted to when it is read back to the host.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)]
pub enum HostDataType {
    /// Read elements as `f32`.
    F32,
    /// Read elements as `bool`; any non-zero value becomes `true`.
    Bool,
}
20
/// Host-side element storage, one variant per supported element type.
#[derive(Clone, Debug)]
pub enum HostDataVec {
    /// Elements stored as `f32`.
    F32(Vec<f32>),
    /// Elements stored as `bool`.
    Bool(Vec<bool>),
}

impl HostDataVec {
    /// Returns the element at flat offset `i` as an `f32`.
    ///
    /// # Panics
    /// Panics if `self` holds booleans, or if `i` is out of bounds.
    pub fn get_f32(&self, i: usize) -> f32 {
        match self {
            HostDataVec::F32(items) => items[i],
            HostDataVec::Bool(_) => panic!("Can't get bool as f32"),
        }
    }

    /// Returns the element at flat offset `i` as a `bool`.
    ///
    /// # Panics
    /// Panics if `self` holds `f32` values, or if `i` is out of bounds.
    pub fn get_bool(&self, i: usize) -> bool {
        match self {
            // The stored data is f32 and the caller asked for a bool.
            HostDataVec::F32(_) => panic!("Can't get f32 as bool"),
            HostDataVec::Bool(items) => items[i],
        }
    }
}
42
43impl HostData {
44 pub fn from_tensor_handle(
45 client: &ComputeClient<TestRuntime>,
46 tensor_handle: &TensorHandle<TestRuntime>,
47 host_data_type: HostDataType,
48 ) -> Self {
49 let shape = tensor_handle.shape.clone();
50 let strides = tensor_handle.strides.clone();
51
52 let data = match host_data_type {
53 HostDataType::F32 => {
54 let handle = copy_casted(client, tensor_handle, f32::as_type_native_unchecked());
55 let data = f32::from_bytes(&client.read_one_tensor(handle.as_copy_descriptor()))
56 .to_owned();
57
58 HostDataVec::F32(data)
59 }
60 HostDataType::Bool => {
61 let handle = copy_casted(client, tensor_handle, u8::as_type_native_unchecked());
62 let data =
63 u8::from_bytes(&client.read_one_tensor(handle.as_copy_descriptor())).to_owned();
64
65 HostDataVec::Bool(data.iter().map(|&x| x > 0).collect())
66 }
67 };
68
69 Self {
70 data,
71 shape,
72 strides,
73 }
74 }
75
76 pub fn get_f32(&self, index: &[usize]) -> f32 {
77 self.data.get_f32(self.strided_index(index))
78 }
79
80 pub fn get_bool(&self, index: &[usize]) -> bool {
81 self.data.get_bool(self.strided_index(index))
82 }
83
84 fn strided_index(&self, index: &[usize]) -> usize {
85 let mut i = 0usize;
86 for (d, idx) in index.iter().enumerate() {
87 i += idx * self.strides[d];
88 }
89 i
90 }
91}