Skip to main content

cubek_test_utils/test_tensor/
base.rs

1use cubecl::{
2    TestRuntime,
3    client::ComputeClient,
4    ir::{ElemType, StorageType},
5    prelude::CubePrimitive,
6    std::tensor::TensorHandle,
7    zspace::{Shape, Strides},
8};
9use cubecl_common::quant::scheme::QuantScheme;
10use cubek_quant::scheme::QuantStore;
11
12use crate::test_tensor::{
13    arange::build_arange,
14    custom::build_custom,
15    eye::build_eye,
16    host_data::{HostData, HostDataType},
17    quant::apply_quantization,
18    random::build_random,
19    strides::StrideSpec,
20    zeros::build_zeros,
21};
22
23#[derive(Clone)]
24/// Information about a quantized tensor in tests.
25/// This allows marking a tensor as quantized for the kernel dispatcher
26/// while keeping the original unquantized data on the host for reference.
27pub struct QuantizationInfo {
28    /// The scale tensor on the device.
29    pub scale: TensorHandle<TestRuntime>,
30    /// The quantization scheme (e.g., Symmetric, Tensor-wise, etc.)
31    pub scheme: QuantScheme,
32    /// The original unquantized shape of the tensor.
33    pub shape: Shape,
34}
35
36#[derive(Clone)]
37/// A test tensor which might be marked as quantized.
38///
39/// This structure couples the device handle, the host reference data,
40/// and optional quantization metadata. If `quantization` is `Some`,
41/// the handle on the device is expected to contain quantized data
42/// (unless it's a dummy quantization for testing purposes).
43pub struct TestTensor {
44    /// The device handle.
45    pub handle: TensorHandle<TestRuntime>,
46    /// The host data, usually stored in f32 for easy reference comparison.
47    pub host: HostData,
48    /// Optional quantization info.
49    pub quantization: Option<QuantizationInfo>,
50}
51
52#[derive(Clone, Debug)]
53pub enum InputDataType {
54    Standard(StorageType),
55    Quantized(QuantScheme),
56}
57
58impl From<StorageType> for InputDataType {
59    fn from(dtype: StorageType) -> Self {
60        InputDataType::Standard(dtype)
61    }
62}
63
64impl From<cubecl::ir::ElemType> for InputDataType {
65    fn from(elem: cubecl::ir::ElemType) -> Self {
66        InputDataType::Standard(StorageType::Scalar(elem))
67    }
68}
69
70impl InputDataType {
71    pub fn storage_type(&self) -> StorageType {
72        match self {
73            InputDataType::Standard(dtype) => *dtype,
74            InputDataType::Quantized(scheme) => {
75                let elem = ElemType::from_quant_value(scheme.value);
76
77                match scheme.store {
78                    QuantStore::Native => StorageType::Scalar(elem),
79                    QuantStore::PackedNative(_) => {
80                        // Uses the format's inherent packing factor (e.g., E2M1x2)
81                        StorageType::Packed(elem, scheme.native_packing())
82                    }
83                    QuantStore::PackedU32(_) => {
84                        // Usually represents multiple small quants in a 32-bit register
85                        // factor would be 4 for 8-bit, 8 for 4-bit, etc.
86                        let factor = scheme.num_quants();
87                        StorageType::Packed(elem, factor)
88                    }
89                }
90            }
91        }
92    }
93
94    pub fn is_quantized(&self) -> bool {
95        matches!(self, InputDataType::Quantized(_))
96    }
97
98    pub fn scheme(&self) -> Option<QuantScheme> {
99        match self {
100            InputDataType::Quantized(scheme) => Some(*scheme),
101            _ => None,
102        }
103    }
104}
105
106pub struct TestInput {
107    base_spec: BaseInputSpec,
108    data_kind: DataKind,
109    input_dtype: InputDataType,
110}
111
112pub enum DataKind {
113    Arange {
114        scale: Option<f32>,
115    },
116    Eye,
117    Zeros,
118    Random {
119        seed: u64,
120        distribution: Distribution,
121    },
122    Custom {
123        data: Vec<f32>,
124    },
125}
126
127impl TestInput {
128    /// Start a fluent builder for a test input.
129    ///
130    /// Defaults: `dtype = f32`, `stride = RowMajor`. Call `.dtype(_)` /
131    /// `.stride(_)` to override, then a finalizer such as `.arange()`,
132    /// `.eye()`, `.zeros()`, `.uniform(seed, lo, hi)`, `.bernoulli(seed, p)`,
133    /// or `.custom(data)` to produce a [`TestInput`] ready to generate.
134    pub fn builder(
135        client: ComputeClient<TestRuntime>,
136        shape: impl Into<Shape>,
137    ) -> TestInputBuilder {
138        TestInputBuilder::new(client, shape.into())
139    }
140
141    pub fn new(
142        client: ComputeClient<TestRuntime>,
143        shape: impl Into<Shape>,
144        dtype: impl Into<InputDataType>,
145        stride_spec: StrideSpec,
146        data_kind: DataKind,
147    ) -> Self {
148        let dtype = dtype.into();
149        let storage_type = match &dtype {
150            InputDataType::Standard(dtype) => *dtype,
151            InputDataType::Quantized(_scheme) => {
152                // For quantized input, the initial data is generated as f32 (Standard)
153                // then it will be quantized in generate_test_tensor.
154                f32::as_type_native_unchecked().storage_type()
155            }
156        };
157
158        let base_spec = BaseInputSpec {
159            client,
160            shape: shape.into(),
161            dtype: storage_type,
162            stride_spec,
163        };
164
165        Self {
166            base_spec,
167            data_kind,
168            input_dtype: dtype,
169        }
170    }
171
172    pub fn generate_with_f32_host_data(self) -> (TensorHandle<TestRuntime>, HostData) {
173        self.generate_host_data(HostDataType::F32)
174    }
175
176    pub fn generate_with_bool_host_data(self) -> (TensorHandle<TestRuntime>, HostData) {
177        self.generate_host_data(HostDataType::Bool)
178    }
179
180    pub fn generate_test_tensor(self) -> TestTensor {
181        let input_dtype = self.input_dtype.clone();
182        let client = self.base_spec.client.clone();
183        let (handle, host) = self.generate_with_f32_host_data();
184
185        let mut tensor = TestTensor {
186            handle,
187            host,
188            quantization: None,
189        };
190
191        if let InputDataType::Quantized(scheme) = input_dtype {
192            apply_quantization(&client, &mut tensor, scheme);
193        }
194
195        tensor
196    }
197
198    pub fn f32_host_data(self) -> HostData {
199        self.generate_host_data(HostDataType::F32).1
200    }
201
202    pub fn bool_host_data(self) -> HostData {
203        self.generate_host_data(HostDataType::Bool).1
204    }
205
206    // Public API returning only TensorHandle
207    pub fn generate_without_host_data(self) -> TensorHandle<TestRuntime> {
208        self.generate()
209    }
210
211    pub fn generate(self) -> TensorHandle<TestRuntime> {
212        let (shape, strides, dtype) = (
213            self.base_spec.shape.clone(),
214            self.base_spec.strides(),
215            self.base_spec.dtype,
216        );
217
218        let mut handle = match self.data_kind {
219            DataKind::Arange { scale } => build_arange(self.base_spec, scale),
220            DataKind::Eye => build_eye(self.base_spec),
221            DataKind::Random { seed, distribution } => {
222                build_random(self.base_spec, seed, distribution)
223            }
224            DataKind::Zeros => build_zeros(self.base_spec),
225            DataKind::Custom { data } => build_custom(self.base_spec, data),
226        };
227        handle.metadata.shape = shape;
228        handle.metadata.strides = strides;
229        handle.dtype = dtype;
230
231        handle
232    }
233
234    fn generate_host_data(
235        self,
236        host_data_type: HostDataType,
237    ) -> (TensorHandle<TestRuntime>, HostData) {
238        let client = self.base_spec.client.clone();
239
240        let tensor_handle = self.generate();
241        let host_data =
242            HostData::from_tensor_handle(&client, tensor_handle.clone(), host_data_type);
243
244        (tensor_handle, host_data)
245    }
246}
247
248pub struct BaseInputSpec {
249    pub client: ComputeClient<TestRuntime>,
250    pub shape: Shape,
251    pub dtype: StorageType,
252    pub stride_spec: StrideSpec,
253}
254
255impl BaseInputSpec {
256    pub(crate) fn strides(&self) -> Strides {
257        self.stride_spec.compute_strides(&self.shape)
258    }
259}
260
261pub struct RandomInputSpec {
262    pub seed: u64,
263    pub distribution: Distribution,
264}
265
/// Probability distribution used for randomly generated test data.
///
/// Derives `Debug` and `PartialEq` so distributions can be printed in test
/// failure messages and compared in assertions.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum Distribution {
    /// Uniform random over `[lower, upper]` (fields are `lower`, then `upper`).
    Uniform(f32, f32),
    /// Bernoulli random; the field is the probability of `1`.
    Bernoulli(f32),
    /// Normal (Gaussian) random with the given `mean` and `std`.
    Normal { mean: f32, std: f32 },
}
275
276/// Fluent builder for [`TestInput`].
277///
278/// Use [`TestInput::builder`] to start one. The builder holds the shape,
279/// dtype, and stride spec. Call a finalizer (`arange`, `eye`, `zeros`,
280/// `uniform`, `bernoulli`, `random`, `custom`) to produce a [`TestInput`]
281/// ready to generate a tensor handle, host data, or test tensor.
282///
283/// # Example
284///
285/// ```ignore
286/// use cubek_test_utils::{TestInput, StrideSpec, Distribution};
287///
288/// let (handle, host) = TestInput::builder(client, [4, 4])
289///     .stride(StrideSpec::ColMajor)
290///     .uniform( 0, -1.0, 1.0)
291///     .generate_with_f32_host_data();
292/// ```
293pub struct TestInputBuilder {
294    client: ComputeClient<TestRuntime>,
295    shape: Shape,
296    dtype: Option<InputDataType>,
297    stride_spec: StrideSpec,
298}
299
300impl TestInputBuilder {
301    fn new(client: ComputeClient<TestRuntime>, shape: Shape) -> Self {
302        Self {
303            client,
304            shape,
305            dtype: None,
306            stride_spec: StrideSpec::RowMajor,
307        }
308    }
309
310    /// Override the dtype. Defaults to f32.
311    pub fn dtype(mut self, dtype: impl Into<InputDataType>) -> Self {
312        self.dtype = Some(dtype.into());
313        self
314    }
315
316    /// Override the stride layout. Defaults to [`StrideSpec::RowMajor`].
317    pub fn stride(mut self, stride_spec: StrideSpec) -> Self {
318        self.stride_spec = stride_spec;
319        self
320    }
321
322    fn finalize(self, data_kind: DataKind) -> TestInput {
323        let dtype = self.dtype.unwrap_or_else(|| {
324            InputDataType::Standard(f32::as_type_native_unchecked().storage_type())
325        });
326        TestInput::new(self.client, self.shape, dtype, self.stride_spec, data_kind)
327    }
328
329    /// `0, 1, 2, …` in row-major order.
330    pub fn arange(self) -> TestInput {
331        self.finalize(DataKind::Arange { scale: None })
332    }
333
334    /// `arange` with each value multiplied by `scale`.
335    pub fn arange_scaled(self, scale: f32) -> TestInput {
336        self.finalize(DataKind::Arange { scale: Some(scale) })
337    }
338
339    /// Identity matrix (1 on the diagonal, 0 elsewhere).
340    pub fn eye(self) -> TestInput {
341        self.finalize(DataKind::Eye)
342    }
343
344    /// All-zeros tensor.
345    pub fn zeros(self) -> TestInput {
346        self.finalize(DataKind::Zeros)
347    }
348
349    /// Random tensor with a custom [`Distribution`].
350    pub fn random(self, seed: u64, distribution: Distribution) -> TestInput {
351        self.finalize(DataKind::Random { seed, distribution })
352    }
353
354    /// Uniform random in `[lo, hi]`.
355    pub fn uniform(self, seed: u64, lo: f32, hi: f32) -> TestInput {
356        self.random(seed, Distribution::Uniform(lo, hi))
357    }
358
359    /// Bernoulli random with probability `p` of 1.
360    pub fn bernoulli(self, seed: u64, p: f32) -> TestInput {
361        self.random(seed, Distribution::Bernoulli(p))
362    }
363
364    /// Normal (Gaussian) random with the given `mean` and `std`.
365    pub fn normal(self, seed: u64, mean: f32, std: f32) -> TestInput {
366        self.random(seed, Distribution::Normal { mean, std })
367    }
368
369    /// Tensor populated from an explicit row-major `Vec<f32>`.
370    pub fn custom(self, data: Vec<f32>) -> TestInput {
371        self.finalize(DataKind::Custom { data })
372    }
373
374    /// Evenly-spaced values from `start` to `end` inclusive, populated in
375    /// row-major order. The number of points equals the tensor's element count.
376    ///
377    /// Equivalent to NumPy's `np.linspace(start, end, num=shape.numel()).reshape(shape)`.
378    pub fn linspace(self, start: f32, end: f32) -> TestInput {
379        let num_elems: usize = self.shape.iter().product();
380        let data = if num_elems == 0 {
381            Vec::new()
382        } else if num_elems == 1 {
383            vec![start]
384        } else {
385            let step = (end - start) / (num_elems - 1) as f32;
386            (0..num_elems).map(|i| start + step * i as f32).collect()
387        };
388        self.finalize(DataKind::Custom { data })
389    }
390}