Skip to main content

cubecl_core/
lib.rs

1#![no_std]
2
3#[cfg(feature = "std")]
4extern crate std;
5
6extern crate alloc;
7
8#[macro_use]
9extern crate derive_new;
10
11pub use cubecl_zspace as zspace;
12
13/// Cube Frontend Types.
14pub mod frontend;
15/// Input Output utilities.
16pub mod io;
17
18pub mod post_processing;
19
20/// Some future utilities that work across environments.
21pub use cubecl_common::future;
22
23use cubecl_ir::LineSize;
24use cubecl_runtime::client::ComputeClient;
25pub use cubecl_runtime::memory_management::MemoryConfiguration;
26use cubecl_runtime::server::CubeCountSelection;
27pub use frontend::cmma;
28
29/// Cube Language Internal Representation.
30pub use cubecl_ir as ir;
31
32pub mod codegen;
33pub mod compute;
34pub mod prelude;
35
36mod pod;
37
38pub use codegen::*;
39pub use cubecl_runtime::runtime::*;
40pub use pod::*;
41
42pub use cubecl_macros::*;
43pub use cubecl_runtime::benchmark;
44pub use cubecl_runtime::client;
45pub use cubecl_runtime::compiler::{CompilationError, Compiler, CubeTask};
46pub use cubecl_runtime::memory_management::MemoryUsage;
47pub use cubecl_runtime::server;
48pub use cubecl_runtime::tune;
49
50use frontend::LaunchArg;
51
52pub use cubecl_common::*;
53
54pub use prelude::CubeCount;
55pub use prelude::{CubeDim, ExecutionMode};
56
57pub use num_traits;
58
59mod id;
60pub use id::*;
61
62// Private utils for macros
63#[doc(hidden)]
64pub mod __private {
65    pub use alloc::{format, vec};
66}
67
68/// Calculate the number of cubes required to execute an operation where one cube unit is
69/// assigned to one element.
70pub fn calculate_cube_count_elemwise<R: Runtime>(
71    client: &ComputeClient<R>,
72    num_elems: usize,
73    cube_dim: CubeDim,
74) -> CubeCount {
75    let num_cubes = num_elems.div_ceil(cube_dim.num_elems() as usize);
76    CubeCountSelection::new(client, num_cubes as u32).cube_count()
77}
78
79pub fn tensor_vectorization_factor(
80    factors: &[LineSize],
81    shape: &[usize],
82    strides: &[usize],
83    dim: usize,
84) -> LineSize {
85    tensor_line_size_parallel(factors.iter().cloned(), shape, strides, dim)
86}
87pub fn tensor_line_size(
88    factors: &[LineSize],
89    shape: &[usize],
90    strides: &[usize],
91    dim: usize,
92) -> LineSize {
93    tensor_line_size_parallel(factors.iter().cloned(), shape, strides, dim)
94}
95
/// Reasons why no line size could be selected for vectorizing along a tensor axis.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LineSizeError {
    /// The requested axis is not a valid index into the tensor's strides (or shape).
    AxisOutOfBounds,
    /// The strides are incompatible with the requested kind of vectorization along the axis.
    StrideMismatch,
    /// None of the candidate line sizes satisfy the divisibility requirements.
    NoValidLineSize,
}
102
103/// Find the maximum line size usable for parallel vectorization along the given axis
104/// from the supported line sizes or return 1 if vectorization is impossible.
105///
106/// This function is designed to never return a line size above 1 by error,
107/// but doesn't guarantee to always return the actual maximum possible line size.
108/// That is, it may be overly strict.
109///
110/// Currently, this checks that the stride of the axis is 1, that it's shape is
111/// divisible by a candidate line size and that the smallest stride that is not 1
112/// is divisible by the vectorization.
113/// The last condition ensure that the current axis is contiguous within the next stride.
114pub fn tensor_line_size_parallel(
115    optimized_line_sizes: impl Iterator<Item = LineSize>,
116    shape: &[usize],
117    strides: &[usize],
118    axis: usize,
119) -> LineSize {
120    try_tensor_line_size_parallel(optimized_line_sizes, shape, strides, axis).unwrap_or(1)
121}
122
123/// Like `try_tensor_line_size_parallel` but does not assume 1 is supported
124pub fn try_tensor_line_size_parallel(
125    supported_line_sizes: impl Iterator<Item = LineSize>,
126    shape: &[usize],
127    strides: &[usize],
128    axis: usize,
129) -> Result<LineSize, LineSizeError> {
130    let stride = strides.get(axis).ok_or(LineSizeError::AxisOutOfBounds)?;
131    if *stride != 1 {
132        return Err(LineSizeError::StrideMismatch);
133    }
134
135    let axis_shape = shape.get(axis).ok_or(LineSizeError::AxisOutOfBounds)?;
136
137    let next_stride = *strides
138        .iter()
139        .filter(|&&stride| stride > 1)
140        .min()
141        .unwrap_or(&0);
142
143    supported_line_sizes
144        .filter(|&line_size| axis_shape % line_size == 0 && next_stride % line_size == 0)
145        .max()
146        .ok_or(LineSizeError::NoValidLineSize)
147}
148
149/// Find the maximum line size usable for perpendicular vectorization along the given axis
150/// from the supported line sizes or return 1 if vectorization is impossible.
151///
152/// This function is designed to never return a line size above 1 by error,
153/// but doesn't guarantee to always return the actual maximum possible line size.
154/// That is, it may be overly strict.
155///
156/// Currently, this checks that the stride of the axis is divisible by a candidate line size
157/// and that the product of all shapes of axes with smaller strides is equal to the stride of the axis.
158/// The second condition ensure that elements within the stride are contiguous.
159pub fn tensor_line_size_perpendicular(
160    supported_line_sizes: impl Iterator<Item = LineSize>,
161    shape: &[usize],
162    strides: &[usize],
163    axis: usize,
164) -> LineSize {
165    try_tensor_line_size_perpendicular(supported_line_sizes, shape, strides, axis).unwrap_or(1)
166}
167
168/// Like `tensor_line_size_perpendicular` but does not assume 1 is supported
169pub fn try_tensor_line_size_perpendicular(
170    supported_line_sizes: impl Iterator<Item = LineSize>,
171    shape: &[usize],
172    strides: &[usize],
173    axis: usize,
174) -> Result<LineSize, LineSizeError> {
175    let axis_stride = strides.get(axis).ok_or(LineSizeError::AxisOutOfBounds)?;
176
177    let prod_shape_axes_smaller_strides = strides
178        .iter()
179        .zip(shape.iter())
180        .filter(|(stride, _)| **stride < *axis_stride)
181        .map(|(_, shape)| shape)
182        .product::<usize>();
183
184    if *axis_stride != prod_shape_axes_smaller_strides {
185        return Err(LineSizeError::StrideMismatch);
186    }
187
188    supported_line_sizes
189        .filter(|&line_size| *axis_stride % line_size == 0)
190        .max()
191        .ok_or(LineSizeError::NoValidLineSize)
192}
193
194/// Runtime arguments to launch a kernel.
195pub type RuntimeArg<'a, T, R> = <T as LaunchArg>::RuntimeArg<'a, R>;
196pub type ExpandType<T> = <T as crate::prelude::CubeType>::ExpandType;
197
/// Tests only useful for runtimes; compiled only when the `export_tests` feature is enabled.
#[cfg(feature = "export_tests")]
pub mod runtime_tests;