// cubek_test_utils/test_tensor/base.rs
use cubecl::{
    TestRuntime,
    client::ComputeClient,
    ir::{ElemType, StorageType},
    prelude::CubePrimitive,
    std::tensor::TensorHandle,
    zspace::{Shape, Strides},
};
use cubecl_common::quant::scheme::QuantScheme;
use cubek_quant::scheme::QuantStore;

use crate::test_tensor::{
    arange::build_arange,
    custom::build_custom,
    eye::build_eye,
    host_data::{HostData, HostDataType},
    quant::apply_quantization,
    random::build_random,
    strides::StrideSpec,
    zeros::build_zeros,
};

23#[derive(Clone)]
24pub struct QuantizationInfo {
28 pub scale: TensorHandle<TestRuntime>,
30 pub scheme: QuantScheme,
32 pub shape: Shape,
34}
35
36#[derive(Clone)]
37pub struct TestTensor {
44 pub handle: TensorHandle<TestRuntime>,
46 pub host: HostData,
48 pub quantization: Option<QuantizationInfo>,
50}
51
52#[derive(Clone, Debug)]
53pub enum InputDataType {
54 Standard(StorageType),
55 Quantized(QuantScheme),
56}
57
58impl From<StorageType> for InputDataType {
59 fn from(dtype: StorageType) -> Self {
60 InputDataType::Standard(dtype)
61 }
62}
63
64impl From<cubecl::ir::ElemType> for InputDataType {
65 fn from(elem: cubecl::ir::ElemType) -> Self {
66 InputDataType::Standard(StorageType::Scalar(elem))
67 }
68}
69
70impl InputDataType {
71 pub fn storage_type(&self) -> StorageType {
72 match self {
73 InputDataType::Standard(dtype) => *dtype,
74 InputDataType::Quantized(scheme) => {
75 let elem = ElemType::from_quant_value(scheme.value);
76
77 match scheme.store {
78 QuantStore::Native => StorageType::Scalar(elem),
79 QuantStore::PackedNative(_) => {
80 StorageType::Packed(elem, scheme.native_packing())
82 }
83 QuantStore::PackedU32(_) => {
84 let factor = scheme.num_quants();
87 StorageType::Packed(elem, factor)
88 }
89 }
90 }
91 }
92 }
93
94 pub fn is_quantized(&self) -> bool {
95 matches!(self, InputDataType::Quantized(_))
96 }
97
98 pub fn scheme(&self) -> Option<QuantScheme> {
99 match self {
100 InputDataType::Quantized(scheme) => Some(*scheme),
101 _ => None,
102 }
103 }
104}
105
106pub struct TestInput {
107 base_spec: BaseInputSpec,
108 data_kind: DataKind,
109 input_dtype: InputDataType,
110}
111
112pub enum DataKind {
113 Arange {
114 scale: Option<f32>,
115 },
116 Eye,
117 Zeros,
118 Random {
119 seed: u64,
120 distribution: Distribution,
121 },
122 Custom {
123 data: Vec<f32>,
124 },
125}
126
127impl TestInput {
128 pub fn builder(
135 client: ComputeClient<TestRuntime>,
136 shape: impl Into<Shape>,
137 ) -> TestInputBuilder {
138 TestInputBuilder::new(client, shape.into())
139 }
140
141 pub fn new(
142 client: ComputeClient<TestRuntime>,
143 shape: impl Into<Shape>,
144 dtype: impl Into<InputDataType>,
145 stride_spec: StrideSpec,
146 data_kind: DataKind,
147 ) -> Self {
148 let dtype = dtype.into();
149 let storage_type = match &dtype {
150 InputDataType::Standard(dtype) => *dtype,
151 InputDataType::Quantized(_scheme) => {
152 f32::as_type_native_unchecked().storage_type()
155 }
156 };
157
158 let base_spec = BaseInputSpec {
159 client,
160 shape: shape.into(),
161 dtype: storage_type,
162 stride_spec,
163 };
164
165 Self {
166 base_spec,
167 data_kind,
168 input_dtype: dtype,
169 }
170 }
171
172 pub fn generate_with_f32_host_data(self) -> (TensorHandle<TestRuntime>, HostData) {
173 self.generate_host_data(HostDataType::F32)
174 }
175
176 pub fn generate_with_bool_host_data(self) -> (TensorHandle<TestRuntime>, HostData) {
177 self.generate_host_data(HostDataType::Bool)
178 }
179
180 pub fn generate_test_tensor(self) -> TestTensor {
181 let input_dtype = self.input_dtype.clone();
182 let client = self.base_spec.client.clone();
183 let (handle, host) = self.generate_with_f32_host_data();
184
185 let mut tensor = TestTensor {
186 handle,
187 host,
188 quantization: None,
189 };
190
191 if let InputDataType::Quantized(scheme) = input_dtype {
192 apply_quantization(&client, &mut tensor, scheme);
193 }
194
195 tensor
196 }
197
198 pub fn f32_host_data(self) -> HostData {
199 self.generate_host_data(HostDataType::F32).1
200 }
201
202 pub fn bool_host_data(self) -> HostData {
203 self.generate_host_data(HostDataType::Bool).1
204 }
205
206 pub fn generate_without_host_data(self) -> TensorHandle<TestRuntime> {
208 self.generate()
209 }
210
211 pub fn generate(self) -> TensorHandle<TestRuntime> {
212 let (shape, strides, dtype) = (
213 self.base_spec.shape.clone(),
214 self.base_spec.strides(),
215 self.base_spec.dtype,
216 );
217
218 let mut handle = match self.data_kind {
219 DataKind::Arange { scale } => build_arange(self.base_spec, scale),
220 DataKind::Eye => build_eye(self.base_spec),
221 DataKind::Random { seed, distribution } => {
222 build_random(self.base_spec, seed, distribution)
223 }
224 DataKind::Zeros => build_zeros(self.base_spec),
225 DataKind::Custom { data } => build_custom(self.base_spec, data),
226 };
227 handle.metadata.shape = shape;
228 handle.metadata.strides = strides;
229 handle.dtype = dtype;
230
231 handle
232 }
233
234 fn generate_host_data(
235 self,
236 host_data_type: HostDataType,
237 ) -> (TensorHandle<TestRuntime>, HostData) {
238 let client = self.base_spec.client.clone();
239
240 let tensor_handle = self.generate();
241 let host_data =
242 HostData::from_tensor_handle(&client, tensor_handle.clone(), host_data_type);
243
244 (tensor_handle, host_data)
245 }
246}
247
248pub struct BaseInputSpec {
249 pub client: ComputeClient<TestRuntime>,
250 pub shape: Shape,
251 pub dtype: StorageType,
252 pub stride_spec: StrideSpec,
253}
254
255impl BaseInputSpec {
256 pub(crate) fn strides(&self) -> Strides {
257 self.stride_spec.compute_strides(&self.shape)
258 }
259}
260
261pub struct RandomInputSpec {
262 pub seed: u64,
263 pub distribution: Distribution,
264}
265
/// Distribution that random test data is drawn from (see [`DataKind::Random`]).
///
/// Derives `Debug`/`PartialEq` so specs can be printed in failing-test output and compared.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum Distribution {
    /// Uniform between `lo` and `hi`, in the argument order of `TestInputBuilder::uniform`.
    Uniform(f32, f32),
    /// Bernoulli with parameter `p`; presumably the probability of drawing 1.
    Bernoulli(f32),
    /// Normal with the given mean and standard deviation.
    Normal { mean: f32, std: f32 },
}

276pub struct TestInputBuilder {
294 client: ComputeClient<TestRuntime>,
295 shape: Shape,
296 dtype: Option<InputDataType>,
297 stride_spec: StrideSpec,
298}
299
300impl TestInputBuilder {
301 fn new(client: ComputeClient<TestRuntime>, shape: Shape) -> Self {
302 Self {
303 client,
304 shape,
305 dtype: None,
306 stride_spec: StrideSpec::RowMajor,
307 }
308 }
309
310 pub fn dtype(mut self, dtype: impl Into<InputDataType>) -> Self {
312 self.dtype = Some(dtype.into());
313 self
314 }
315
316 pub fn stride(mut self, stride_spec: StrideSpec) -> Self {
318 self.stride_spec = stride_spec;
319 self
320 }
321
322 fn finalize(self, data_kind: DataKind) -> TestInput {
323 let dtype = self.dtype.unwrap_or_else(|| {
324 InputDataType::Standard(f32::as_type_native_unchecked().storage_type())
325 });
326 TestInput::new(self.client, self.shape, dtype, self.stride_spec, data_kind)
327 }
328
329 pub fn arange(self) -> TestInput {
331 self.finalize(DataKind::Arange { scale: None })
332 }
333
334 pub fn arange_scaled(self, scale: f32) -> TestInput {
336 self.finalize(DataKind::Arange { scale: Some(scale) })
337 }
338
339 pub fn eye(self) -> TestInput {
341 self.finalize(DataKind::Eye)
342 }
343
344 pub fn zeros(self) -> TestInput {
346 self.finalize(DataKind::Zeros)
347 }
348
349 pub fn random(self, seed: u64, distribution: Distribution) -> TestInput {
351 self.finalize(DataKind::Random { seed, distribution })
352 }
353
354 pub fn uniform(self, seed: u64, lo: f32, hi: f32) -> TestInput {
356 self.random(seed, Distribution::Uniform(lo, hi))
357 }
358
359 pub fn bernoulli(self, seed: u64, p: f32) -> TestInput {
361 self.random(seed, Distribution::Bernoulli(p))
362 }
363
364 pub fn normal(self, seed: u64, mean: f32, std: f32) -> TestInput {
366 self.random(seed, Distribution::Normal { mean, std })
367 }
368
369 pub fn custom(self, data: Vec<f32>) -> TestInput {
371 self.finalize(DataKind::Custom { data })
372 }
373
374 pub fn linspace(self, start: f32, end: f32) -> TestInput {
379 let num_elems: usize = self.shape.iter().product();
380 let data = if num_elems == 0 {
381 Vec::new()
382 } else if num_elems == 1 {
383 vec![start]
384 } else {
385 let step = (end - start) / (num_elems - 1) as f32;
386 (0..num_elems).map(|i| start + step * i as f32).collect()
387 };
388 self.finalize(DataKind::Custom { data })
389 }
390}