Skip to main content

cubecl_core/
lib.rs

1#![no_std]
2
3#[cfg(feature = "std")]
4extern crate std;
5
6extern crate alloc;
7
8#[macro_use]
9extern crate derive_new;
10
11pub use cubecl_zspace as zspace;
12use cubecl_zspace::Shape;
13use cubecl_zspace::Strides;
14
15/// Cube Frontend Types.
16pub mod frontend;
17/// Input Output utilities.
18pub mod io;
19
20pub mod post_processing;
21
22/// Some future utilities that work across environments.
23pub use cubecl_common::future;
24
25use cubecl_ir::VectorSize;
26use cubecl_runtime::client::ComputeClient;
27pub use cubecl_runtime::memory_management::MemoryConfiguration;
28use cubecl_runtime::server::CubeCountSelection;
29pub use frontend::cmma;
30
31/// Cube Language Internal Representation.
32pub use cubecl_ir as ir;
33
34pub mod codegen;
35pub mod compute;
36pub mod prelude;
37
38mod pod;
39
40pub use codegen::*;
41pub use cubecl_runtime::runtime::*;
42pub use pod::*;
43
44pub use cubecl_macros::*;
45pub use cubecl_runtime::benchmark;
46pub use cubecl_runtime::client;
47pub use cubecl_runtime::compiler::{CompilationError, Compiler, CubeTask};
48pub use cubecl_runtime::memory_management::MemoryUsage;
49pub use cubecl_runtime::server;
50pub use cubecl_runtime::tune;
51
52use frontend::LaunchArg;
53
54pub use cubecl_common::*;
55
56pub use prelude::CubeCount;
57pub use prelude::{CubeDim, ExecutionMode};
58
59pub use num_traits;
60
61mod id;
62pub use id::*;
63
64// Private utils for macros
65#[doc(hidden)]
66pub mod __private {
67    pub use alloc::{format, vec};
68    pub use paste::paste;
69}
70
71pub use prelude::{Assign, IntoRuntime};
72
73/// Calculate the number of cubes required to execute an operation where one cube unit is
74/// assigned to one element.
75pub fn calculate_cube_count_elemwise<R: Runtime>(
76    client: &ComputeClient<R>,
77    num_elems: usize,
78    cube_dim: CubeDim,
79) -> CubeCount {
80    if num_elems == 0 {
81        return CubeCount::Static(0, 0, 0);
82    }
83    let num_cubes = num_elems.div_ceil(cube_dim.num_elems() as usize);
84    CubeCountSelection::new(client, num_cubes as u32).cube_count()
85}
86
87pub fn tensor_vectorization_factor(
88    factors: &[VectorSize],
89    shape: &Shape,
90    strides: &Strides,
91    dim: usize,
92) -> VectorSize {
93    tensor_vector_size_parallel(factors.iter().cloned(), shape, strides, dim)
94}
95pub fn tensor_vectorization(
96    factors: &[VectorSize],
97    shape: &Shape,
98    strides: &Strides,
99    dim: usize,
100) -> VectorSize {
101    tensor_vector_size_parallel(factors.iter().cloned(), shape, strides, dim)
102}
103
/// Reasons why no vector size could be selected for a tensor axis.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum VectorizationError {
    /// The requested axis is not a valid index into the strides or the shape.
    AxisOutOfBounds,
    /// The strides do not have the layout required by the requested kind of vectorization.
    StrideMismatch,
    /// None of the candidate vector sizes satisfies the divisibility requirements.
    NoValidVectorization,
}
110
111/// Find the maximum vector size usable for parallel vectorization along the given axis
112/// from the supported vector sizes or return 1 if vectorization is impossible.
113///
114/// This function is designed to never return a vector size above 1 by error,
115/// but doesn't guarantee to always return the actual maximum possible vector size.
116/// That is, it may be overly strict.
117///
118/// Currently, this checks that the stride of the axis is 1, that it's shape is
119/// divisible by a candidate vector size and that the smallest stride that is not 1
120/// is divisible by the vector size.
121/// The last condition ensure that the current axis is contiguous within the next stride.
122pub fn tensor_vector_size_parallel(
123    optimized_vector_sizes: impl Iterator<Item = VectorSize>,
124    shape: &Shape,
125    strides: &Strides,
126    axis: usize,
127) -> VectorSize {
128    try_tensor_vector_size_parallel(optimized_vector_sizes, shape, strides, axis).unwrap_or(1)
129}
130
131/// Like `try_tensor_vector_size_parallel` but does not assume 1 is supported
132pub fn try_tensor_vector_size_parallel(
133    supported_vector_sizes: impl Iterator<Item = VectorSize>,
134    shape: &Shape,
135    strides: &Strides,
136    axis: usize,
137) -> Result<VectorSize, VectorizationError> {
138    let stride = strides
139        .get(axis)
140        .ok_or(VectorizationError::AxisOutOfBounds)?;
141    if *stride != 1 {
142        return Err(VectorizationError::StrideMismatch);
143    }
144
145    let axis_shape = shape.get(axis).ok_or(VectorizationError::AxisOutOfBounds)?;
146
147    let next_stride = *strides
148        .iter()
149        .filter(|&&stride| stride > 1)
150        .min()
151        .unwrap_or(&0);
152
153    supported_vector_sizes
154        .filter(|&vector_size| axis_shape % vector_size == 0 && next_stride % vector_size == 0)
155        .max()
156        .ok_or(VectorizationError::NoValidVectorization)
157}
158
159/// Find the maximum vector size usable for perpendicular vectorization along the given axis
160/// from the supported vector sizes or return 1 if vectorization is impossible.
161///
162/// This function is designed to never return a vector size above 1 by error,
163/// but doesn't guarantee to always return the actual maximum possible vector size.
164/// That is, it may be overly strict.
165///
166/// Currently, this checks that the stride of the axis is divisible by a candidate vector size
167/// and that the product of all shapes of axes with smaller strides is equal to the stride of the axis.
168/// The second condition ensure that elements within the stride are contiguous.
169pub fn tensor_vector_size_perpendicular(
170    supported_vector_sizes: impl Iterator<Item = VectorSize>,
171    shape: &[usize],
172    strides: &[usize],
173    axis: usize,
174) -> VectorSize {
175    try_tensor_vector_sizes_perpendicular(supported_vector_sizes, shape, strides, axis).unwrap_or(1)
176}
177
178/// Like `tensor_vector_sizes_perpendicular` but does not assume 1 is supported
179pub fn try_tensor_vector_sizes_perpendicular(
180    supported_vector_sizes: impl Iterator<Item = VectorSize>,
181    shape: &[usize],
182    strides: &[usize],
183    axis: usize,
184) -> Result<VectorSize, VectorizationError> {
185    let axis_stride = strides
186        .get(axis)
187        .ok_or(VectorizationError::AxisOutOfBounds)?;
188
189    let prod_shape_axes_smaller_strides = strides
190        .iter()
191        .zip(shape.iter())
192        .filter(|(stride, _)| **stride < *axis_stride)
193        .map(|(_, shape)| shape)
194        .product::<usize>();
195
196    if *axis_stride != prod_shape_axes_smaller_strides {
197        return Err(VectorizationError::StrideMismatch);
198    }
199
200    supported_vector_sizes
201        .filter(|&vector_size| *axis_stride % vector_size == 0)
202        .max()
203        .ok_or(VectorizationError::NoValidVectorization)
204}
205
206/// Runtime arguments to launch a kernel.
207pub type RuntimeArg<T, R> = <T as LaunchArg>::RuntimeArg<R>;
208pub type ExpandType<T> = <T as crate::prelude::CubeType>::ExpandType;
209
210#[cfg(feature = "export_tests")]
211/// Tests only useful for runtimes.
212pub mod runtime_tests;