#![no_std]

#[cfg(feature = "std")]
extern crate std;

extern crate alloc;

#[macro_use]
extern crate derive_new;

pub use cubecl_zspace as zspace;
use cubecl_zspace::Shape;
use cubecl_zspace::Strides;

pub mod frontend;
pub mod io;

pub mod post_processing;

pub use cubecl_common::future;

use cubecl_ir::VectorSize;
use cubecl_runtime::client::ComputeClient;
pub use cubecl_runtime::memory_management::MemoryConfiguration;
use cubecl_runtime::server::CubeCountSelection;
pub use frontend::cmma;

pub use cubecl_ir as ir;

pub mod codegen;
pub mod compute;
pub mod prelude;

mod pod;

pub use codegen::*;
pub use cubecl_runtime::runtime::*;
pub use pod::*;

pub use cubecl_macros::*;
pub use cubecl_runtime::benchmark;
pub use cubecl_runtime::client;
pub use cubecl_runtime::compiler::{CompilationError, Compiler, CubeTask};
pub use cubecl_runtime::memory_management::MemoryUsage;
pub use cubecl_runtime::server;
pub use cubecl_runtime::tune;

use frontend::LaunchArg;

pub use cubecl_common::*;

pub use prelude::CubeCount;
pub use prelude::{CubeDim, ExecutionMode};

pub use num_traits;

mod id;
pub use id::*;

#[doc(hidden)]
pub mod __private {
    pub use alloc::{format, vec};
    pub use paste::paste;
}

pub use prelude::{Assign, IntoRuntime};

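/// Computes the [`CubeCount`] needed to cover `num_elems` elements with an elementwise kernel,
/// where each cube processes `cube_dim.num_elems()` elements.
///
/// Returns a zero count when `num_elems` is `0`; otherwise the number of cubes is
/// `num_elems / cube_dim.num_elems()` rounded up, dispatched through [`CubeCountSelection`].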
pub fn calculate_cube_count_elemwise<R: Runtime>(
    client: &ComputeClient<R>,
    num_elems: usize,
    cube_dim: CubeDim,
) -> CubeCount {
    if num_elems == 0 {
        return CubeCount::Static(0, 0, 0);
    }
    let num_cubes = num_elems.div_ceil(cube_dim.num_elems() as usize);
    CubeCountSelection::new(client, num_cubes as u32).cube_count()
}

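/// Returns the largest vectorization factor among `factors` that can be used along `dim`,
/// given the tensor `shape` and `strides`. Thin wrapper around [`tensor_vector_size_parallel`].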
pub fn tensor_vectorization_factor(
    factors: &[VectorSize],
    shape: &Shape,
    strides: &Strides,
    dim: usize,
) -> VectorSize {
    tensor_vector_size_parallel(factors.iter().cloned(), shape, strides, dim)
}
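
/// Same as [`tensor_vectorization_factor`]: selects the vector size along `dim` by delegating
/// to [`tensor_vector_size_parallel`].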
pub fn tensor_vectorization(
    factors: &[VectorSize],
    shape: &Shape,
    strides: &Strides,
    dim: usize,
) -> VectorSize {
    tensor_vector_size_parallel(factors.iter().cloned(), shape, strides, dim)
}

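/// Reasons why no vectorization factor could be selected for a tensor axis.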
#[derive(Debug, Clone)]
pub enum VectorizationError {
    /// The requested axis does not exist in the given shape or strides.
    AxisOutOfBounds,
    /// The strides are incompatible with vectorized access along the requested axis.
    StrideMismatch,
    /// None of the supported vector sizes fits the axis.
    NoValidVectorization,
}

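/// Returns the largest vector size among `optimized_vector_sizes` usable for accesses parallel
/// to `axis`, or `1` when no candidate is valid.
///
/// See [`try_tensor_vector_size_parallel`] for the variant that reports why vectorization
/// failed.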
pub fn tensor_vector_size_parallel(
    optimized_vector_sizes: impl Iterator<Item = VectorSize>,
    shape: &Shape,
    strides: &Strides,
    axis: usize,
) -> VectorSize {
    try_tensor_vector_size_parallel(optimized_vector_sizes, shape, strides, axis).unwrap_or(1)
}

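/// Returns the largest supported vector size usable for accesses parallel to `axis`.
///
/// The axis must be contiguous (stride `1`), its extent must be divisible by the vector size,
/// and the vector size must also divide the smallest non-zero stride of the other axes so a
/// vectorized access never crosses into a neighbouring axis. Broadcast axes (stride `0`) are
/// ignored.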
pub fn try_tensor_vector_size_parallel(
    supported_vector_sizes: impl Iterator<Item = VectorSize>,
    shape: &Shape,
    strides: &Strides,
    axis: usize,
) -> Result<VectorSize, VectorizationError> {
    // Vectorized accesses along `axis` require the axis to be contiguous in memory.
    let stride = strides
        .get(axis)
        .ok_or(VectorizationError::AxisOutOfBounds)?;
    if *stride != 1 {
        return Err(VectorizationError::StrideMismatch);
    }

    let axis_shape = shape.get(axis).ok_or(VectorizationError::AxisOutOfBounds)?;

    // Smallest non-zero stride among the other axes: a vectorized access must not cross into
    // a neighbouring axis, so the vector size must also divide this stride. Broadcast axes
    // (stride 0) are ignored.
    let next_stride = strides
        .iter()
        .enumerate()
        .filter_map(|(i, &s)| (i != axis && s != 0).then_some(s))
        .min()
        .unwrap_or(0);

    supported_vector_sizes
        .filter(|&vector_size| axis_shape % vector_size == 0 && next_stride % vector_size == 0)
        .max()
        .ok_or(VectorizationError::NoValidVectorization)
}

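/// Returns the largest vector size among `supported_vector_sizes` usable for accesses
/// perpendicular to `axis`, or `1` when no candidate is valid.
///
/// See [`try_tensor_vector_sizes_perpendicular`] for the variant that reports why
/// vectorization failed.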
pub fn tensor_vector_size_perpendicular(
    supported_vector_sizes: impl Iterator<Item = VectorSize>,
    shape: &[usize],
    strides: &[usize],
    axis: usize,
) -> VectorSize {
    try_tensor_vector_sizes_perpendicular(supported_vector_sizes, shape, strides, axis)
        .unwrap_or(1)
}

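/// Returns the largest supported vector size usable for accesses perpendicular to `axis`.
///
/// The stride of `axis` must equal the product of the extents of all axes with a smaller
/// stride (i.e. the block below `axis` is contiguous), and the vector size must divide that
/// stride.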
pub fn try_tensor_vector_sizes_perpendicular(
    supported_vector_sizes: impl Iterator<Item = VectorSize>,
    shape: &[usize],
    strides: &[usize],
    axis: usize,
) -> Result<VectorSize, VectorizationError> {
    let axis_stride = strides
        .get(axis)
        .ok_or(VectorizationError::AxisOutOfBounds)?;

    // Product of the extents of every axis with a smaller stride than `axis`. When the block
    // below `axis` is contiguous, this product equals `axis_stride`.
    let prod_shape_axes_smaller_strides = strides
        .iter()
        .zip(shape.iter())
        .filter(|(stride, _)| **stride < *axis_stride)
        .map(|(_, shape)| shape)
        .product::<usize>();

    if *axis_stride != prod_shape_axes_smaller_strides {
        return Err(VectorizationError::StrideMismatch);
    }

    supported_vector_sizes
        .filter(|&vector_size| *axis_stride % vector_size == 0)
        .max()
        .ok_or(VectorizationError::NoValidVectorization)
}

pub type RuntimeArg<T, R> = <T as LaunchArg>::RuntimeArg<R>;
pub type ExpandType<T> = <T as crate::prelude::CubeType>::ExpandType;

#[cfg(feature = "export_tests")]
pub mod runtime_tests;

#[cfg(test)]
mod tests {
    use super::*;

    // Adapts plain slices to the `Shape`/`Strides` types expected by
    // `try_tensor_vector_size_parallel`.
    fn try_parallel(
        sizes: &[VectorSize],
        shape: &[usize],
        strides: &[usize],
        axis: usize,
    ) -> Result<VectorSize, VectorizationError> {
        try_tensor_vector_size_parallel(
            sizes.iter().copied(),
            &Shape::from(shape.iter().copied()),
            &Strides::new(strides),
            axis,
        )
    }

    #[test]
    fn parallel_contiguous_picks_max_vector_size() {
        // Contiguous row-major layout: the innermost axis can use the widest supported vector.
        let v = try_parallel(&[1, 2, 4], &[1, 9, 4], &[36, 4, 1], 2).unwrap();
        assert_eq!(v, 4);
    }

    #[test]
    fn parallel_unfold_step_one_rejects_vectorization() {
        // Unfold with step 1: another axis also has stride 1, so only scalar access is valid.
        let v = try_parallel(&[1, 2, 4], &[1, 9, 4], &[12, 1, 1], 2).unwrap();
        assert_eq!(v, 1);
    }

    #[test]
    fn parallel_unfold_step_two_allows_vectorization() {
        // Unfold with step 2: the next smallest stride is 2, which caps the vector size at 2.
        let v = try_parallel(&[1, 2, 4], &[1, 9, 4], &[12, 2, 1], 2).unwrap();
        assert_eq!(v, 2);
    }

    #[test]
    fn parallel_broadcast_dim_ignored() {
        // The broadcast axis (stride 0) is ignored when computing the limiting stride.
        let v = try_parallel(&[1, 2, 4], &[1, 9, 4], &[0, 4, 1], 2).unwrap();
        assert_eq!(v, 4);
    }

    #[test]
    fn parallel_axis_stride_not_one_is_error() {
        // The requested axis is not contiguous (stride 4), so vectorizing along it is an error.
        let err = try_parallel(&[1, 2, 4], &[1, 9, 4], &[36, 1, 4], 2).unwrap_err();
        assert!(matches!(err, VectorizationError::StrideMismatch));
    }
}
281}