use cubecl::{
TestRuntime,
client::ComputeClient,
ir::{ElemType, StorageType},
prelude::CubePrimitive,
std::tensor::TensorHandle,
zspace::{Shape, Strides},
};
use cubecl_common::quant::scheme::QuantScheme;
use cubek_quant::scheme::QuantStore;
use crate::test_tensor::{
arange::build_arange,
custom::build_custom,
eye::build_eye,
host_data::{HostData, HostDataType},
quant::apply_quantization,
random::build_random,
strides::StrideSpec,
zeros::build_zeros,
};
/// Quantization metadata that can be attached to a [`TestTensor`] (see its
/// `quantization` field).
/// NOTE(review): presumably populated by `apply_quantization` — confirm.
#[derive(Clone)]
pub struct QuantizationInfo {
/// Handle to the scale tensor of the quantization (layout not visible here — TODO confirm).
pub scale: TensorHandle<TestRuntime>,
/// Quantization scheme that was applied.
pub scheme: QuantScheme,
/// A shape associated with the quantized data; whether this is the logical or the
/// packed/scale shape is not visible here — TODO confirm.
pub shape: Shape,
}
/// A generated test tensor: the runtime handle, its host-side data, and optional
/// quantization metadata.
#[derive(Clone)]
pub struct TestTensor {
/// Runtime tensor handle.
pub handle: TensorHandle<TestRuntime>,
/// Host data built from `handle`. `TestInput::generate_test_tensor` builds it as `f32`
/// host data before calling `apply_quantization` (which receives the tensor mutably).
pub host: HostData,
/// `None` when first constructed by `TestInput::generate_test_tensor`; for quantized
/// inputs `apply_quantization` is then called, which presumably sets it — TODO confirm.
pub quantization: Option<QuantizationInfo>,
}
/// Logical data type of a test input: either a plain storage type or a quantization
/// scheme.
#[derive(Clone, Debug)]
pub enum InputDataType {
/// Non-quantized input stored with the given storage type.
Standard(StorageType),
/// Quantized input; the data is generated in `f32` and quantized with this scheme.
Quantized(QuantScheme),
}
impl From<StorageType> for InputDataType {
fn from(dtype: StorageType) -> Self {
InputDataType::Standard(dtype)
}
}
/// A bare element type is treated as an unpacked scalar storage type, i.e. a
/// non-quantized input.
impl From<cubecl::ir::ElemType> for InputDataType {
    fn from(elem: cubecl::ir::ElemType) -> Self {
        let scalar = StorageType::Scalar(elem);
        Self::from(scalar)
    }
}
impl InputDataType {
    /// Returns the [`StorageType`] that values of this input occupy.
    ///
    /// * `Standard(dtype)` yields `dtype` unchanged.
    /// * `Quantized(scheme)` maps `scheme.value` to an element type with
    ///   `ElemType::from_quant_value`, then shapes it according to `scheme.store`:
    ///   - `QuantStore::Native`: an unpacked scalar of that element type;
    ///   - `QuantStore::PackedNative(_)`: packed with factor `scheme.native_packing()`;
    ///   - `QuantStore::PackedU32(_)`: packed with factor `scheme.num_quants()`.
    pub fn storage_type(&self) -> StorageType {
        let scheme = match self {
            InputDataType::Standard(dtype) => return *dtype,
            InputDataType::Quantized(scheme) => scheme,
        };

        let quant_elem = ElemType::from_quant_value(scheme.value);
        match scheme.store {
            QuantStore::Native => StorageType::Scalar(quant_elem),
            QuantStore::PackedNative(_) => {
                StorageType::Packed(quant_elem, scheme.native_packing())
            }
            QuantStore::PackedU32(_) => StorageType::Packed(quant_elem, scheme.num_quants()),
        }
    }

    /// Returns `true` when this input carries a quantization scheme.
    pub fn is_quantized(&self) -> bool {
        self.scheme().is_some()
    }

    /// Returns a copy of the quantization scheme, or `None` for standard inputs.
    pub fn scheme(&self) -> Option<QuantScheme> {
        if let InputDataType::Quantized(scheme) = self {
            Some(*scheme)
        } else {
            None
        }
    }
}
/// A fully specified test input: the allocation spec, the kind of data to fill it
/// with, and the logical dtype requested by the caller.
pub struct TestInput {
/// Client, shape, generation dtype and stride spec. For quantized inputs the dtype
/// recorded here is `f32` (see `TestInput::new`), not the quantized storage type.
base_spec: BaseInputSpec,
/// Which `build_*` helper produces the data.
data_kind: DataKind,
/// Logical dtype; quantized inputs are quantized after generation in
/// `TestInput::generate_test_tensor`.
input_dtype: InputDataType,
}
/// How the values of a [`TestInput`] are produced. `TestInput::generate` dispatches
/// each variant to the matching `build_*` helper.
pub enum DataKind {
/// Data from `build_arange`; `scale` is forwarded to it unchanged.
Arange {
scale: Option<f32>,
},
/// Data from `build_eye` (presumably an identity pattern — confirm behavior for
/// non-2D shapes).
Eye,
/// Data from `build_zeros`.
Zeros,
/// Random data from `build_random`, seeded with `seed` and drawn from `distribution`.
Random {
seed: u64,
distribution: Distribution,
},
/// Caller-supplied values passed to `build_custom`; presumably one value per element
/// in logical order — TODO confirm.
Custom {
data: Vec<f32>,
},
}
impl TestInput {
    /// Starts a [`TestInputBuilder`] for a tensor of `shape` on `client`.
    pub fn builder(
        client: ComputeClient<TestRuntime>,
        shape: impl Into<Shape>,
    ) -> TestInputBuilder {
        let shape: Shape = shape.into();
        TestInputBuilder::new(client, shape)
    }

    /// Creates a test input directly, without going through the builder.
    ///
    /// `dtype` is kept as the input's logical dtype. The dtype written into the
    /// underlying [`BaseInputSpec`] is `dtype` itself for standard inputs and `f32`
    /// for quantized ones: quantized inputs are generated in `f32` and only quantized
    /// afterwards, in [`TestInput::generate_test_tensor`].
    pub fn new(
        client: ComputeClient<TestRuntime>,
        shape: impl Into<Shape>,
        dtype: impl Into<InputDataType>,
        stride_spec: StrideSpec,
        data_kind: DataKind,
    ) -> Self {
        let input_dtype: InputDataType = dtype.into();
        let generation_dtype = match &input_dtype {
            InputDataType::Standard(storage) => *storage,
            InputDataType::Quantized(_) => f32::as_type_native_unchecked().storage_type(),
        };

        Self {
            base_spec: BaseInputSpec {
                client,
                shape: shape.into(),
                dtype: generation_dtype,
                stride_spec,
            },
            data_kind,
            input_dtype,
        }
    }

    /// Generates the tensor and builds `HostDataType::F32` host data from it.
    pub fn generate_with_f32_host_data(self) -> (TensorHandle<TestRuntime>, HostData) {
        self.generate_host_data(HostDataType::F32)
    }

    /// Generates the tensor and builds `HostDataType::Bool` host data from it.
    pub fn generate_with_bool_host_data(self) -> (TensorHandle<TestRuntime>, HostData) {
        self.generate_host_data(HostDataType::Bool)
    }

    /// Generates a [`TestTensor`] with `f32` host data, then applies quantization
    /// when the input dtype is quantized.
    pub fn generate_test_tensor(self) -> TestTensor {
        let quant_scheme = self.input_dtype.scheme();
        let client = self.base_spec.client.clone();

        let (handle, host) = self.generate_with_f32_host_data();
        let mut tensor = TestTensor {
            handle,
            host,
            quantization: None,
        };

        if let Some(scheme) = quant_scheme {
            apply_quantization(&client, &mut tensor, scheme);
        }
        tensor
    }

    /// Generates the tensor and returns only its `f32` host data.
    pub fn f32_host_data(self) -> HostData {
        let (_, host) = self.generate_host_data(HostDataType::F32);
        host
    }

    /// Generates the tensor and returns only its `bool` host data.
    pub fn bool_host_data(self) -> HostData {
        let (_, host) = self.generate_host_data(HostDataType::Bool);
        host
    }

    /// Alias of [`TestInput::generate`].
    pub fn generate_without_host_data(self) -> TensorHandle<TestRuntime> {
        self.generate()
    }

    /// Builds the tensor with the `build_*` helper selected by the data kind.
    ///
    /// Shape, strides and dtype are captured from the spec before building and written
    /// back onto the returned handle, so the handle always reports the requested layout
    /// and dtype whatever the builder set.
    pub fn generate(self) -> TensorHandle<TestRuntime> {
        let shape = self.base_spec.shape.clone();
        let strides = self.base_spec.strides();
        let dtype = self.base_spec.dtype;

        let spec = self.base_spec;
        let mut tensor = match self.data_kind {
            DataKind::Arange { scale } => build_arange(spec, scale),
            DataKind::Eye => build_eye(spec),
            DataKind::Random { seed, distribution } => build_random(spec, seed, distribution),
            DataKind::Zeros => build_zeros(spec),
            DataKind::Custom { data } => build_custom(spec, data),
        };

        tensor.metadata.shape = shape;
        tensor.metadata.strides = strides;
        tensor.dtype = dtype;
        tensor
    }

    /// Generates the tensor and builds host data of `host_data_type` from a clone of
    /// its handle.
    fn generate_host_data(
        self,
        host_data_type: HostDataType,
    ) -> (TensorHandle<TestRuntime>, HostData) {
        let client = self.base_spec.client.clone();
        let handle = self.generate();
        let host = HostData::from_tensor_handle(&client, handle.clone(), host_data_type);
        (handle, host)
    }
}
/// Low-level description of the tensor to generate; passed by value to the `build_*`
/// helpers in `TestInput::generate`.
pub struct BaseInputSpec {
/// Compute client the tensor is generated with.
pub client: ComputeClient<TestRuntime>,
/// Logical tensor shape.
pub shape: Shape,
/// Storage type written onto the generated handle.
pub dtype: StorageType,
/// Strategy for deriving strides from `shape` (resolved by `BaseInputSpec::strides`).
pub stride_spec: StrideSpec,
}
impl BaseInputSpec {
    /// Resolves `stride_spec` against `shape` into concrete strides.
    pub(crate) fn strides(&self) -> Strides {
        let spec = &self.stride_spec;
        spec.compute_strides(&self.shape)
    }
}
/// Seed and distribution describing random test data.
/// NOTE(review): not referenced in this part of the file — `DataKind::Random` carries
/// the same two fields inline; confirm whether this type is still used elsewhere.
pub struct RandomInputSpec {
/// Seed for the random generator.
pub seed: u64,
/// Distribution the values are drawn from.
pub distribution: Distribution,
}
/// Distribution used for random test data (see `DataKind::Random`).
#[derive(Copy, Clone)]
pub enum Distribution {
/// Uniform over the bounds `(lo, hi)`, as passed by `TestInputBuilder::uniform`;
/// whether `hi` is inclusive depends on `build_random` — TODO confirm.
Uniform(f32, f32),
/// Bernoulli with parameter `p` (presumably the probability of a one).
Bernoulli(f32),
/// Normal distribution with the given mean and standard deviation.
Normal { mean: f32, std: f32 },
}
/// Builder for [`TestInput`]. Obtain one with `TestInput::builder`, optionally set the
/// dtype and stride spec, then pick the data kind (`arange`, `eye`, `zeros`, `random`,
/// `custom`, `linspace`, ...).
pub struct TestInputBuilder {
/// Client the input will be generated with.
client: ComputeClient<TestRuntime>,
/// Shape of the input.
shape: Shape,
/// Requested dtype; `None` falls back to a standard `f32` dtype when finalized.
dtype: Option<InputDataType>,
/// Stride strategy; `StrideSpec::RowMajor` unless overridden.
stride_spec: StrideSpec,
}
impl TestInputBuilder {
fn new(client: ComputeClient<TestRuntime>, shape: Shape) -> Self {
Self {
client,
shape,
dtype: None,
stride_spec: StrideSpec::RowMajor,
}
}
pub fn dtype(mut self, dtype: impl Into<InputDataType>) -> Self {
self.dtype = Some(dtype.into());
self
}
pub fn stride(mut self, stride_spec: StrideSpec) -> Self {
self.stride_spec = stride_spec;
self
}
fn finalize(self, data_kind: DataKind) -> TestInput {
let dtype = self.dtype.unwrap_or_else(|| {
InputDataType::Standard(f32::as_type_native_unchecked().storage_type())
});
TestInput::new(self.client, self.shape, dtype, self.stride_spec, data_kind)
}
pub fn arange(self) -> TestInput {
self.finalize(DataKind::Arange { scale: None })
}
pub fn arange_scaled(self, scale: f32) -> TestInput {
self.finalize(DataKind::Arange { scale: Some(scale) })
}
pub fn eye(self) -> TestInput {
self.finalize(DataKind::Eye)
}
pub fn zeros(self) -> TestInput {
self.finalize(DataKind::Zeros)
}
pub fn random(self, seed: u64, distribution: Distribution) -> TestInput {
self.finalize(DataKind::Random { seed, distribution })
}
pub fn uniform(self, seed: u64, lo: f32, hi: f32) -> TestInput {
self.random(seed, Distribution::Uniform(lo, hi))
}
pub fn bernoulli(self, seed: u64, p: f32) -> TestInput {
self.random(seed, Distribution::Bernoulli(p))
}
pub fn normal(self, seed: u64, mean: f32, std: f32) -> TestInput {
self.random(seed, Distribution::Normal { mean, std })
}
pub fn custom(self, data: Vec<f32>) -> TestInput {
self.finalize(DataKind::Custom { data })
}
pub fn linspace(self, start: f32, end: f32) -> TestInput {
let num_elems: usize = self.shape.iter().product();
let data = if num_elems == 0 {
Vec::new()
} else if num_elems == 1 {
vec![start]
} else {
let step = (end - start) / (num_elems - 1) as f32;
(0..num_elems).map(|i| start + step * i as f32).collect()
};
self.finalize(DataKind::Custom { data })
}
}