cubek_test_utils/test_tensor/base.rs

1use cubecl::{TestRuntime, client::ComputeClient, ir::StorageType, std::tensor::TensorHandle};
2
3use crate::test_tensor::{
4    arange::build_arange,
5    custom::build_custom,
6    eye::build_eye,
7    host_data::{HostData, HostDataType},
8    random::build_random,
9    strides::StrideSpec,
10    zeros::build_zeros,
11};
12
13pub struct TestInput {
14    client: ComputeClient<TestRuntime>,
15    spec: TestInputSpec,
16}
17
18pub enum TestInputSpec {
19    Arange(SimpleInputSpec),
20    Eye(SimpleInputSpec),
21    Random(RandomInputSpec),
22    Zeros(SimpleInputSpec),
23    Custom(CustomInputSpec),
24}
25
26impl TestInput {
27    pub fn random(
28        client: ComputeClient<TestRuntime>,
29        shape: Vec<usize>,
30        dtype: StorageType,
31        seed: u64,
32        distribution: Distribution,
33        stride_spec: StrideSpec,
34    ) -> Self {
35        let inner = SimpleInputSpec {
36            client: client.clone(),
37            shape,
38            dtype,
39            stride_spec,
40        };
41
42        let spec = RandomInputSpec {
43            inner,
44            seed,
45            distribution,
46        };
47
48        TestInput {
49            client,
50            spec: TestInputSpec::Random(spec),
51        }
52    }
53
54    pub fn zeros(
55        client: ComputeClient<TestRuntime>,
56        shape: Vec<usize>,
57        dtype: StorageType,
58        stride_spec: StrideSpec,
59    ) -> Self {
60        TestInput {
61            client: client.clone(),
62            spec: TestInputSpec::Zeros(SimpleInputSpec {
63                client,
64                shape,
65                dtype,
66                stride_spec,
67            }),
68        }
69    }
70
71    pub fn eye(
72        client: ComputeClient<TestRuntime>,
73        shape: Vec<usize>,
74        dtype: StorageType,
75        stride_spec: StrideSpec,
76    ) -> Self {
77        TestInput {
78            client: client.clone(),
79            spec: TestInputSpec::Eye(SimpleInputSpec {
80                client,
81                shape,
82                dtype,
83                stride_spec,
84            }),
85        }
86    }
87
88    pub fn arange(
89        client: ComputeClient<TestRuntime>,
90        shape: Vec<usize>,
91        dtype: StorageType,
92        stride_spec: StrideSpec,
93    ) -> Self {
94        let spec = SimpleInputSpec {
95            client: client.clone(),
96            shape,
97            dtype,
98            stride_spec,
99        };
100
101        TestInput {
102            client,
103            spec: TestInputSpec::Arange(spec),
104        }
105    }
106
107    pub fn custom(
108        client: ComputeClient<TestRuntime>,
109        shape: Vec<usize>,
110        dtype: StorageType,
111        stride_spec: StrideSpec,
112        data: Vec<f32>,
113    ) -> Self {
114        let inner = SimpleInputSpec {
115            client: client.clone(),
116            shape,
117            dtype,
118            stride_spec,
119        };
120
121        let spec = CustomInputSpec { inner, data };
122
123        TestInput {
124            client,
125            spec: TestInputSpec::Custom(spec),
126        }
127    }
128    pub fn generate_with_f32_host_data(self) -> (TensorHandle<TestRuntime>, HostData) {
129        self.generate_host_data(HostDataType::F32)
130    }
131
132    pub fn generate_with_bool_host_data(self) -> (TensorHandle<TestRuntime>, HostData) {
133        self.generate_host_data(HostDataType::Bool)
134    }
135
136    pub fn f32_host_data(self) -> HostData {
137        self.generate_host_data(HostDataType::F32).1
138    }
139
140    pub fn bool_host_data(self) -> HostData {
141        self.generate_host_data(HostDataType::Bool).1
142    }
143
144    // Public API returning only TensorHandle
145    pub fn generate_without_host_data(self) -> TensorHandle<TestRuntime> {
146        self.generate()
147    }
148
149    pub fn generate(self) -> TensorHandle<TestRuntime> {
150        match self.spec {
151            TestInputSpec::Arange(spec) => build_arange(spec),
152            TestInputSpec::Eye(spec) => build_eye(spec),
153            TestInputSpec::Random(spec) => build_random(spec),
154            TestInputSpec::Zeros(spec) => build_zeros(spec),
155            TestInputSpec::Custom(spec) => build_custom(spec),
156        }
157    }
158
159    fn generate_host_data(
160        self,
161        host_data_type: HostDataType,
162    ) -> (TensorHandle<TestRuntime>, HostData) {
163        let client = self.client.clone();
164        let tensor_handle = self.generate();
165        let host_data = HostData::from_tensor_handle(&client, &tensor_handle, host_data_type);
166        (tensor_handle, host_data)
167    }
168}
169
170pub struct SimpleInputSpec {
171    pub(crate) client: ComputeClient<TestRuntime>,
172    pub(crate) shape: Vec<usize>,
173    pub(crate) dtype: StorageType,
174    pub(crate) stride_spec: StrideSpec,
175}
176
177impl SimpleInputSpec {
178    pub(crate) fn strides(&self) -> Vec<usize> {
179        self.stride_spec.compute_strides(&self.shape)
180    }
181}
182
183pub struct RandomInputSpec {
184    pub(crate) inner: SimpleInputSpec,
185    pub(crate) seed: u64,
186    pub(crate) distribution: Distribution,
187}
188
189pub struct CustomInputSpec {
190    pub(crate) inner: SimpleInputSpec,
191    pub(crate) data: Vec<f32>,
192}
193
/// Distribution stored in a `RandomInputSpec` and used when building random inputs.
///
/// Derives `Debug` and `PartialEq` so test configurations can be printed in assertion
/// messages and compared directly.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum Distribution {
    /// Uniform distribution given by its `(lower, upper)` bounds.
    Uniform(f32, f32),
    /// Bernoulli distribution given by its probability.
    Bernoulli(f32),
}