
cubecl_core/lib.rs

#![no_std]

#[cfg(feature = "std")]
extern crate std;

extern crate alloc;

#[macro_use]
extern crate derive_new;

pub use cubecl_zspace as zspace;
use cubecl_zspace::Shape;
use cubecl_zspace::Strides;

/// Cube Frontend Types.
pub mod frontend;
/// Input Output utilities.
pub mod io;

/// Post-processing utilities.
pub mod post_processing;

/// Utilities for working with futures across environments.
pub use cubecl_common::future;

use cubecl_ir::VectorSize;
use cubecl_runtime::client::ComputeClient;
pub use cubecl_runtime::memory_management::MemoryConfiguration;
use cubecl_runtime::server::CubeCountSelection;
pub use frontend::cmma;

/// Cube Language Internal Representation.
pub use cubecl_ir as ir;

pub mod codegen;
pub mod compute;
pub mod prelude;

mod pod;

pub use codegen::*;
pub use cubecl_runtime::runtime::*;
pub use pod::*;

pub use cubecl_macros::*;
pub use cubecl_runtime::benchmark;
pub use cubecl_runtime::client;
pub use cubecl_runtime::compiler::{CompilationError, Compiler, CubeTask};
pub use cubecl_runtime::memory_management::MemoryUsage;
pub use cubecl_runtime::server;
pub use cubecl_runtime::tune;

use frontend::LaunchArg;

pub use cubecl_common::*;

pub use prelude::CubeCount;
pub use prelude::{CubeDim, ExecutionMode};

pub use num_traits;

mod id;
pub use id::*;

// Private utils for macros
#[doc(hidden)]
pub mod __private {
    pub use alloc::{format, vec};
    pub use paste::paste;
}

pub use prelude::{Assign, IntoRuntime};

/// Calculate the number of cubes required to execute an operation where one cube unit is
/// assigned to one element.
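///
/// # Example
///
/// A usage sketch (marked `ignore`): `R`, `client`, and the element count are stand-ins
/// supplied by the caller, and the 16x16x1 cube dim is just an illustrative choice.
///
/// ```ignore
/// // 256 units per cube, so ceil(1_000_000 / 256) = 3907 cubes are requested.
/// let cube_count = calculate_cube_count_elemwise::<R>(&client, 1_000_000, CubeDim::new(16, 16, 1));
/// ```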
pub fn calculate_cube_count_elemwise<R: Runtime>(
    client: &ComputeClient<R>,
    num_elems: usize,
    cube_dim: CubeDim,
) -> CubeCount {
    if num_elems == 0 {
        return CubeCount::Static(0, 0, 0);
    }
    let num_cubes = num_elems.div_ceil(cube_dim.num_elems() as usize);
    CubeCountSelection::new(client, num_cubes as u32).cube_count()
}

/// Finds the largest supported vector size for vectorizing along `dim`; forwards to
/// `tensor_vector_size_parallel`.
pub fn tensor_vectorization_factor(
    factors: &[VectorSize],
    shape: &Shape,
    strides: &Strides,
    dim: usize,
) -> VectorSize {
    tensor_vector_size_parallel(factors.iter().cloned(), shape, strides, dim)
}

/// Finds the largest supported vector size for vectorizing along `dim`; forwards to
/// `tensor_vector_size_parallel`.
pub fn tensor_vectorization(
    factors: &[VectorSize],
    shape: &Shape,
    strides: &Strides,
    dim: usize,
) -> VectorSize {
    tensor_vector_size_parallel(factors.iter().cloned(), shape, strides, dim)
}

#[derive(Debug, Clone)]
pub enum VectorizationError {
    /// The requested axis is outside the tensor's rank.
    AxisOutOfBounds,
    /// The strides are incompatible with vectorization along the requested axis.
    StrideMismatch,
    /// None of the supported vector sizes satisfy the divisibility checks.
    NoValidVectorization,
}

/// Find the maximum vector size usable for parallel vectorization along the given axis
/// from the supported vector sizes, or return 1 if vectorization is impossible.
///
/// This function is designed never to return a vector size above 1 in error,
/// but it doesn't guarantee to always return the actual maximum possible vector size.
/// That is, it may be overly strict.
///
/// Currently, this checks that the stride of the axis is 1, that the axis shape is
/// divisible by a candidate vector size, and that every non-broadcast stride outside
/// the axis is divisible by the vector size.
/// The last condition ensures a vectorized read on `axis` stays contiguous in the
/// source buffer as coordinates in other dimensions change.
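///
/// # Example
///
/// A usage sketch (marked `ignore`); it assumes `Shape` and `Strides` can be built from
/// slices the same way the tests below build them, and the layout is illustrative.
///
/// ```ignore
/// use cubecl_zspace::{Shape, Strides};
///
/// // Contiguous [8, 64] tensor: the last axis has stride 1, its length is divisible
/// // by 4, and the outer stride (64) is also a multiple of 4, so 4 is selected.
/// let line_size = tensor_vector_size_parallel(
///     [1, 2, 4].into_iter(),
///     &Shape::from([8usize, 64].iter().copied()),
///     &Strides::new(&[64, 1]),
///     1,
/// );
/// assert_eq!(line_size, 4);
/// ```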
pub fn tensor_vector_size_parallel(
    supported_vector_sizes: impl Iterator<Item = VectorSize>,
    shape: &Shape,
    strides: &Strides,
    axis: usize,
) -> VectorSize {
    try_tensor_vector_size_parallel(supported_vector_sizes, shape, strides, axis).unwrap_or(1)
}

/// Like `tensor_vector_size_parallel`, but returns an error instead of assuming
/// that a vector size of 1 is always supported.
pub fn try_tensor_vector_size_parallel(
    supported_vector_sizes: impl Iterator<Item = VectorSize>,
    shape: &Shape,
    strides: &Strides,
    axis: usize,
) -> Result<VectorSize, VectorizationError> {
    let stride = strides
        .get(axis)
        .ok_or(VectorizationError::AxisOutOfBounds)?;
    if *stride != 1 {
        return Err(VectorizationError::StrideMismatch);
    }

    let axis_shape = shape.get(axis).ok_or(VectorizationError::AxisOutOfBounds)?;

    // Smallest non-zero stride among non-axis dims. Stride 0 is a broadcast and
    // never contributes to the source offset, so it can be ignored. Every other
    // dim can shift the source offset when its coord changes, so its stride must
    // be a multiple of the vector size for vectorized reads to stay aligned.
    // Unit-size dims are included for simplicity; they only cause false negatives
    // (vectorization disabled) rather than incorrect output.
    let next_stride = strides
        .iter()
        .enumerate()
        .filter_map(|(i, &s)| (i != axis && s != 0).then_some(s))
        .min()
        .unwrap_or(0);

    supported_vector_sizes
        .filter(|&vector_size| axis_shape % vector_size == 0 && next_stride % vector_size == 0)
        .max()
        .ok_or(VectorizationError::NoValidVectorization)
}

/// Find the maximum vector size usable for perpendicular vectorization along the given axis
/// from the supported vector sizes, or return 1 if vectorization is impossible.
///
/// This function is designed never to return a vector size above 1 in error,
/// but it doesn't guarantee to always return the actual maximum possible vector size.
/// That is, it may be overly strict.
///
/// Currently, this checks that the stride of the axis is divisible by a candidate vector size
/// and that the product of the shapes of all axes with smaller strides equals the stride of the axis.
/// The second condition ensures that the elements spanned by the axis stride are contiguous.
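///
/// # Example
///
/// A usage sketch (marked `ignore`); the shape and strides are illustrative.
///
/// ```ignore
/// // Contiguous [4, 8] tensor, vectorizing perpendicular to axis 0 (stride 8): the
/// // axis stride equals the product of the inner shapes (8) and is divisible by 4.
/// let line_size = tensor_vector_size_perpendicular([1, 2, 4].into_iter(), &[4, 8], &[8, 1], 0);
/// assert_eq!(line_size, 4);
/// ```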
pub fn tensor_vector_size_perpendicular(
    supported_vector_sizes: impl Iterator<Item = VectorSize>,
    shape: &[usize],
    strides: &[usize],
    axis: usize,
) -> VectorSize {
    try_tensor_vector_sizes_perpendicular(supported_vector_sizes, shape, strides, axis).unwrap_or(1)
}

/// Like `tensor_vector_size_perpendicular`, but returns an error instead of assuming
/// that a vector size of 1 is always supported.
pub fn try_tensor_vector_sizes_perpendicular(
    supported_vector_sizes: impl Iterator<Item = VectorSize>,
    shape: &[usize],
    strides: &[usize],
    axis: usize,
) -> Result<VectorSize, VectorizationError> {
    let axis_stride = strides
        .get(axis)
        .ok_or(VectorizationError::AxisOutOfBounds)?;

    let prod_shape_axes_smaller_strides = strides
        .iter()
        .zip(shape.iter())
        .filter(|(stride, _)| **stride < *axis_stride)
        .map(|(_, shape)| shape)
        .product::<usize>();

    if *axis_stride != prod_shape_axes_smaller_strides {
        return Err(VectorizationError::StrideMismatch);
    }

    supported_vector_sizes
        .filter(|&vector_size| *axis_stride % vector_size == 0)
        .max()
        .ok_or(VectorizationError::NoValidVectorization)
}

/// Runtime arguments to launch a kernel.
pub type RuntimeArg<T, R> = <T as LaunchArg>::RuntimeArg<R>;
/// Expand type of a `CubeType`, used while expanding a kernel.
pub type ExpandType<T> = <T as crate::prelude::CubeType>::ExpandType;

#[cfg(feature = "export_tests")]
/// Tests only useful for runtimes.
pub mod runtime_tests;

#[cfg(test)]
mod tests {
    use super::*;

    fn try_parallel(
        sizes: &[VectorSize],
        shape: &[usize],
        strides: &[usize],
        axis: usize,
    ) -> Result<VectorSize, VectorizationError> {
        try_tensor_vector_size_parallel(
            sizes.iter().copied(),
            &Shape::from(shape.iter().copied()),
            &Strides::new(strides),
            axis,
        )
    }

    #[test]
    fn parallel_contiguous_picks_max_vector_size() {
        // Contiguous [1, 9, 4], vectorize along last dim (stride 1).
        // Outer stride 4 is a multiple of 4, so vec_size = 4 is safe.
        let v = try_parallel(&[1, 2, 4], &[1, 9, 4], &[36, 4, 1], 2).unwrap();
        assert_eq!(v, 4);
    }

    #[test]
    fn parallel_unfold_step_one_rejects_vectorization() {
        // Unfold view produced by `unfold(1, 4, 1)` on a [1, 12] contiguous tensor:
        // shape [1, 9, 4], strides [12, 1, 1]. The frame dim has stride 1, so each
        // step in the frame coord shifts the source offset by 1 - not a multiple
        // of any vec_size > 1, so vectorized reads would be unaligned and return
        // the wrong data. Must fall back to vec_size = 1.
        let v = try_parallel(&[1, 2, 4], &[1, 9, 4], &[12, 1, 1], 2).unwrap();
        assert_eq!(v, 1);
    }

    #[test]
    fn parallel_unfold_step_two_allows_vectorization() {
        // Same unfold pattern but with step=2: strides [12, 2, 1]. Frame coord
        // shifts source by 2 (still not a multiple of 4), so vec_size = 4 must
        // be rejected - but vec_size = 2 is fine.
        let v = try_parallel(&[1, 2, 4], &[1, 9, 4], &[12, 2, 1], 2).unwrap();
        assert_eq!(v, 2);
    }

    #[test]
    fn parallel_broadcast_dim_ignored() {
        // Broadcast dim has stride 0; it never shifts the source offset, so
        // it should not disqualify vectorization.
        let v = try_parallel(&[1, 2, 4], &[1, 9, 4], &[0, 4, 1], 2).unwrap();
        assert_eq!(v, 4);
    }

    #[test]
    fn parallel_axis_stride_not_one_is_error() {
        let err = try_parallel(&[1, 2, 4], &[1, 9, 4], &[36, 1, 4], 2).unwrap_err();
        assert!(matches!(err, VectorizationError::StrideMismatch));
    }
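
    // The perpendicular path had no coverage here; these two checks are a small
    // sketch added for it. The layouts are illustrative: a contiguous [4, 8]
    // tensor (strides [8, 1]) and a row-padded variant (strides [16, 1]).
    #[test]
    fn perpendicular_contiguous_picks_max_vector_size() {
        // Axis 0 has stride 8, which equals the product of the inner shapes (8)
        // and is divisible by 4, so vec_size = 4 is selected.
        let v = tensor_vector_size_perpendicular([1, 2, 4].into_iter(), &[4, 8], &[8, 1], 0);
        assert_eq!(v, 4);
    }

    #[test]
    fn perpendicular_padded_rows_fall_back_to_one() {
        // With padded rows the axis stride (16) no longer equals the product of
        // the inner shapes (8), so the helper falls back to vec_size = 1.
        let v = tensor_vector_size_perpendicular([1, 2, 4].into_iter(), &[4, 8], &[16, 1], 0);
        assert_eq!(v, 1);
    }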
}