1#![no_std]
2
3#[cfg(feature = "std")]
4extern crate std;
5
6extern crate alloc;
7
8#[macro_use]
9extern crate derive_new;
10
11pub use cubecl_zspace as zspace;
12use cubecl_zspace::Shape;
13use cubecl_zspace::Strides;
14
15pub mod frontend;
17pub mod io;
19
20pub mod post_processing;
21
22pub use cubecl_common::future;
24
25use cubecl_ir::VectorSize;
26use cubecl_runtime::client::ComputeClient;
27pub use cubecl_runtime::memory_management::MemoryConfiguration;
28use cubecl_runtime::server::CubeCountSelection;
29pub use frontend::cmma;
30
31pub use cubecl_ir as ir;
33
34pub mod codegen;
35pub mod compute;
36pub mod prelude;
37
38mod pod;
39
40pub use codegen::*;
41pub use cubecl_runtime::runtime::*;
42pub use pod::*;
43
44pub use cubecl_macros::*;
45pub use cubecl_runtime::benchmark;
46pub use cubecl_runtime::client;
47pub use cubecl_runtime::compiler::{CompilationError, Compiler, CubeTask};
48pub use cubecl_runtime::memory_management::MemoryUsage;
49pub use cubecl_runtime::server;
50pub use cubecl_runtime::tune;
51
52use frontend::LaunchArg;
53
54pub use cubecl_common::*;
55
56pub use prelude::CubeCount;
57pub use prelude::{CubeDim, ExecutionMode};
58
59pub use num_traits;
60
61mod id;
62pub use id::*;
63
64#[doc(hidden)]
66pub mod __private {
67 pub use alloc::{format, vec};
68 pub use paste::paste;
69}
70
71pub use prelude::{Assign, IntoRuntime};
72
73pub fn calculate_cube_count_elemwise<R: Runtime>(
76 client: &ComputeClient<R>,
77 num_elems: usize,
78 cube_dim: CubeDim,
79) -> CubeCount {
80 if num_elems == 0 {
81 return CubeCount::Static(0, 0, 0);
82 }
83 let num_cubes = num_elems.div_ceil(cube_dim.num_elems() as usize);
84 CubeCountSelection::new(client, num_cubes as u32).cube_count()
85}
86
87pub fn tensor_vectorization_factor(
88 factors: &[VectorSize],
89 shape: &Shape,
90 strides: &Strides,
91 dim: usize,
92) -> VectorSize {
93 tensor_vector_size_parallel(factors.iter().cloned(), shape, strides, dim)
94}
95pub fn tensor_vectorization(
96 factors: &[VectorSize],
97 shape: &Shape,
98 strides: &Strides,
99 dim: usize,
100) -> VectorSize {
101 tensor_vector_size_parallel(factors.iter().cloned(), shape, strides, dim)
102}
103
/// Reasons why no vector size could be selected for a tensor axis.
///
/// The type is comparable so that callers can match on or assert against a specific
/// failure without destructuring.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum VectorizationError {
    /// The requested axis is not a valid index into the tensor's shape or strides.
    AxisOutOfBounds,
    /// The stride of the requested axis does not have the value required by the
    /// vectorization scheme being checked.
    StrideMismatch,
    /// None of the candidate vector sizes satisfies the divisibility requirements.
    NoValidVectorization,
}
110
111pub fn tensor_vector_size_parallel(
123 optimized_vector_sizes: impl Iterator<Item = VectorSize>,
124 shape: &Shape,
125 strides: &Strides,
126 axis: usize,
127) -> VectorSize {
128 try_tensor_vector_size_parallel(optimized_vector_sizes, shape, strides, axis).unwrap_or(1)
129}
130
131pub fn try_tensor_vector_size_parallel(
133 supported_vector_sizes: impl Iterator<Item = VectorSize>,
134 shape: &Shape,
135 strides: &Strides,
136 axis: usize,
137) -> Result<VectorSize, VectorizationError> {
138 let stride = strides
139 .get(axis)
140 .ok_or(VectorizationError::AxisOutOfBounds)?;
141 if *stride != 1 {
142 return Err(VectorizationError::StrideMismatch);
143 }
144
145 let axis_shape = shape.get(axis).ok_or(VectorizationError::AxisOutOfBounds)?;
146
147 let next_stride = *strides
148 .iter()
149 .filter(|&&stride| stride > 1)
150 .min()
151 .unwrap_or(&0);
152
153 supported_vector_sizes
154 .filter(|&vector_size| axis_shape % vector_size == 0 && next_stride % vector_size == 0)
155 .max()
156 .ok_or(VectorizationError::NoValidVectorization)
157}
158
159pub fn tensor_vector_size_perpendicular(
170 supported_vector_sizes: impl Iterator<Item = VectorSize>,
171 shape: &[usize],
172 strides: &[usize],
173 axis: usize,
174) -> VectorSize {
175 try_tensor_vector_sizes_perpendicular(supported_vector_sizes, shape, strides, axis).unwrap_or(1)
176}
177
178pub fn try_tensor_vector_sizes_perpendicular(
180 supported_vector_sizes: impl Iterator<Item = VectorSize>,
181 shape: &[usize],
182 strides: &[usize],
183 axis: usize,
184) -> Result<VectorSize, VectorizationError> {
185 let axis_stride = strides
186 .get(axis)
187 .ok_or(VectorizationError::AxisOutOfBounds)?;
188
189 let prod_shape_axes_smaller_strides = strides
190 .iter()
191 .zip(shape.iter())
192 .filter(|(stride, _)| **stride < *axis_stride)
193 .map(|(_, shape)| shape)
194 .product::<usize>();
195
196 if *axis_stride != prod_shape_axes_smaller_strides {
197 return Err(VectorizationError::StrideMismatch);
198 }
199
200 supported_vector_sizes
201 .filter(|&vector_size| *axis_stride % vector_size == 0)
202 .max()
203 .ok_or(VectorizationError::NoValidVectorization)
204}
205
/// Shorthand for the `RuntimeArg<R>` associated type of `T`'s `LaunchArg` implementation.
pub type RuntimeArg<T, R> = <T as LaunchArg>::RuntimeArg<R>;
/// Shorthand for the `ExpandType` associated type of `T`'s `CubeType` implementation.
pub type ExpandType<T> = <T as crate::prelude::CubeType>::ExpandType;
209
210#[cfg(feature = "export_tests")]
211pub mod runtime_tests;