cubek_test_utils/test_tensor/
base.rs1use cubecl::{TestRuntime, client::ComputeClient, ir::StorageType, std::tensor::TensorHandle};
2
3use crate::test_tensor::{
4 arange::build_arange,
5 custom::build_custom,
6 eye::build_eye,
7 host_data::{HostData, HostDataType},
8 random::build_random,
9 strides::StrideSpec,
10 zeros::build_zeros,
11};
12
13pub struct TestInput {
14 client: ComputeClient<TestRuntime>,
15 spec: TestInputSpec,
16}
17
18pub enum TestInputSpec {
19 Arange(SimpleInputSpec),
20 Eye(SimpleInputSpec),
21 Random(RandomInputSpec),
22 Zeros(SimpleInputSpec),
23 Custom(CustomInputSpec),
24}
25
26impl TestInput {
27 pub fn random(
28 client: ComputeClient<TestRuntime>,
29 shape: Vec<usize>,
30 dtype: StorageType,
31 seed: u64,
32 distribution: Distribution,
33 stride_spec: StrideSpec,
34 ) -> Self {
35 let inner = SimpleInputSpec {
36 client: client.clone(),
37 shape,
38 dtype,
39 stride_spec,
40 };
41
42 let spec = RandomInputSpec {
43 inner,
44 seed,
45 distribution,
46 };
47
48 TestInput {
49 client,
50 spec: TestInputSpec::Random(spec),
51 }
52 }
53
54 pub fn zeros(
55 client: ComputeClient<TestRuntime>,
56 shape: Vec<usize>,
57 dtype: StorageType,
58 stride_spec: StrideSpec,
59 ) -> Self {
60 TestInput {
61 client: client.clone(),
62 spec: TestInputSpec::Zeros(SimpleInputSpec {
63 client,
64 shape,
65 dtype,
66 stride_spec,
67 }),
68 }
69 }
70
71 pub fn eye(
72 client: ComputeClient<TestRuntime>,
73 shape: Vec<usize>,
74 dtype: StorageType,
75 stride_spec: StrideSpec,
76 ) -> Self {
77 TestInput {
78 client: client.clone(),
79 spec: TestInputSpec::Eye(SimpleInputSpec {
80 client,
81 shape,
82 dtype,
83 stride_spec,
84 }),
85 }
86 }
87
88 pub fn arange(
89 client: ComputeClient<TestRuntime>,
90 shape: Vec<usize>,
91 dtype: StorageType,
92 stride_spec: StrideSpec,
93 ) -> Self {
94 let spec = SimpleInputSpec {
95 client: client.clone(),
96 shape,
97 dtype,
98 stride_spec,
99 };
100
101 TestInput {
102 client,
103 spec: TestInputSpec::Arange(spec),
104 }
105 }
106
107 pub fn custom(
108 client: ComputeClient<TestRuntime>,
109 shape: Vec<usize>,
110 dtype: StorageType,
111 stride_spec: StrideSpec,
112 data: Vec<f32>,
113 ) -> Self {
114 let inner = SimpleInputSpec {
115 client: client.clone(),
116 shape,
117 dtype,
118 stride_spec,
119 };
120
121 let spec = CustomInputSpec { inner, data };
122
123 TestInput {
124 client,
125 spec: TestInputSpec::Custom(spec),
126 }
127 }
128 pub fn generate_with_f32_host_data(self) -> (TensorHandle<TestRuntime>, HostData) {
129 self.generate_host_data(HostDataType::F32)
130 }
131
132 pub fn generate_with_bool_host_data(self) -> (TensorHandle<TestRuntime>, HostData) {
133 self.generate_host_data(HostDataType::Bool)
134 }
135
136 pub fn f32_host_data(self) -> HostData {
137 self.generate_host_data(HostDataType::F32).1
138 }
139
140 pub fn bool_host_data(self) -> HostData {
141 self.generate_host_data(HostDataType::Bool).1
142 }
143
144 pub fn generate_without_host_data(self) -> TensorHandle<TestRuntime> {
146 self.generate()
147 }
148
149 pub fn generate(self) -> TensorHandle<TestRuntime> {
150 match self.spec {
151 TestInputSpec::Arange(spec) => build_arange(spec),
152 TestInputSpec::Eye(spec) => build_eye(spec),
153 TestInputSpec::Random(spec) => build_random(spec),
154 TestInputSpec::Zeros(spec) => build_zeros(spec),
155 TestInputSpec::Custom(spec) => build_custom(spec),
156 }
157 }
158
159 fn generate_host_data(
160 self,
161 host_data_type: HostDataType,
162 ) -> (TensorHandle<TestRuntime>, HostData) {
163 let client = self.client.clone();
164 let tensor_handle = self.generate();
165 let host_data = HostData::from_tensor_handle(&client, &tensor_handle, host_data_type);
166 (tensor_handle, host_data)
167 }
168}
169
170pub struct SimpleInputSpec {
171 pub(crate) client: ComputeClient<TestRuntime>,
172 pub(crate) shape: Vec<usize>,
173 pub(crate) dtype: StorageType,
174 pub(crate) stride_spec: StrideSpec,
175}
176
177impl SimpleInputSpec {
178 pub(crate) fn strides(&self) -> Vec<usize> {
179 self.stride_spec.compute_strides(&self.shape)
180 }
181}
182
183pub struct RandomInputSpec {
184 pub(crate) inner: SimpleInputSpec,
185 pub(crate) seed: u64,
186 pub(crate) distribution: Distribution,
187}
188
189pub struct CustomInputSpec {
190 pub(crate) inner: SimpleInputSpec,
191 pub(crate) data: Vec<f32>,
192}
193
/// Distribution that [`TestInput::random`] inputs draw their values from.
///
/// Derives `Debug` and `PartialEq` so test parameters can be printed in failure
/// messages and compared directly.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum Distribution {
    /// Uniform distribution.
    /// NOTE(review): fields are presumably `(low, high)` — confirm in `build_random`.
    Uniform(f32, f32),
    /// Bernoulli distribution.
    /// NOTE(review): the field is presumably the success probability — confirm in `build_random`.
    Bernoulli(f32),
}