Skip to main content

cubecl_core/
lib.rs

1extern crate alloc;
2
3#[macro_use]
4extern crate derive_new;
5
6pub use cubecl_zspace as zspace;
7
8/// Cube Frontend Types.
9pub mod frontend;
10/// Input Output utilities.
11pub mod io;
12
13pub mod post_processing;
14
15/// Some future utilities that work across environments.
16pub use cubecl_common::future;
17
18use cubecl_ir::LineSize;
19use cubecl_runtime::client::ComputeClient;
20pub use cubecl_runtime::memory_management::MemoryConfiguration;
21use cubecl_runtime::server::CubeCountSelection;
22pub use frontend::cmma;
23
24/// Cube Language Internal Representation.
25pub use cubecl_ir as ir;
26
27pub mod codegen;
28pub mod compute;
29pub mod prelude;
30
31mod pod;
32
33pub use codegen::*;
34pub use cubecl_runtime::runtime::*;
35pub use pod::*;
36
37pub use cubecl_macros::*;
38pub use cubecl_runtime::benchmark;
39pub use cubecl_runtime::client;
40pub use cubecl_runtime::compiler::{CompilationError, Compiler, CubeTask};
41pub use cubecl_runtime::memory_management::MemoryUsage;
42pub use cubecl_runtime::server;
43pub use cubecl_runtime::tune;
44
45use frontend::LaunchArg;
46
47pub use cubecl_common::*;
48
49pub use prelude::CubeCount;
50pub use prelude::{CubeDim, ExecutionMode};
51
52pub use num_traits;
53
54mod id;
55pub use id::*;
56
57/// Calculate the number of cubes required to execute an operation where one cube unit is
58/// assigned to one element.
59pub fn calculate_cube_count_elemwise<R: Runtime>(
60    client: &ComputeClient<R>,
61    num_elems: usize,
62    cube_dim: CubeDim,
63) -> CubeCount {
64    let num_cubes = num_elems.div_ceil(cube_dim.num_elems() as usize);
65    CubeCountSelection::new(client, num_cubes as u32).cube_count()
66}
67
68pub fn tensor_vectorization_factor(
69    factors: &[LineSize],
70    shape: &[usize],
71    strides: &[usize],
72    dim: usize,
73) -> LineSize {
74    tensor_line_size_parallel(factors.iter().cloned(), shape, strides, dim)
75}
76pub fn tensor_line_size(
77    factors: &[LineSize],
78    shape: &[usize],
79    strides: &[usize],
80    dim: usize,
81) -> LineSize {
82    tensor_line_size_parallel(factors.iter().cloned(), shape, strides, dim)
83}
84
/// Reasons why no line size could be selected for a tensor axis.
///
/// `PartialEq`/`Eq` are derived so that callers and tests can compare returned errors directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LineSizeError {
    /// The requested axis is not a valid index into the provided shape or strides.
    AxisOutOfBounds,
    /// The stride of the requested axis does not meet the layout requirement of the
    /// vectorization scheme being checked.
    StrideMismatch,
    /// None of the supported line sizes satisfies the divisibility constraints.
    NoValidLineSize,
}
91
92/// Find the maximum line size usable for parallel vectorization along the given axis
93/// from the supported line sizes or return 1 if vectorization is impossible.
94///
95/// This function is designed to never return a line size above 1 by error,
96/// but doesn't guarantee to always return the actual maximum possible line size.
97/// That is, it may be overly strict.
98///
99/// Currently, this checks that the stride of the axis is 1, that it's shape is
100/// divisible by a candidate line size and that the smallest stride that is not 1
101/// is divisible by the vectorization.
102/// The last condition ensure that the current axis is contiguous within the next stride.
103pub fn tensor_line_size_parallel(
104    supported_line_sizes: impl Iterator<Item = LineSize>,
105    shape: &[usize],
106    strides: &[usize],
107    axis: usize,
108) -> LineSize {
109    try_tensor_line_size_parallel(supported_line_sizes, shape, strides, axis).unwrap_or(1)
110}
111
112/// Like `try_tensor_line_size_parallel` but does not assume 1 is supported
113pub fn try_tensor_line_size_parallel(
114    supported_line_sizes: impl Iterator<Item = LineSize>,
115    shape: &[usize],
116    strides: &[usize],
117    axis: usize,
118) -> Result<LineSize, LineSizeError> {
119    let stride = strides.get(axis).ok_or(LineSizeError::AxisOutOfBounds)?;
120    if *stride != 1 {
121        return Err(LineSizeError::StrideMismatch);
122    }
123
124    let axis_shape = shape.get(axis).ok_or(LineSizeError::AxisOutOfBounds)?;
125
126    let next_stride = *strides
127        .iter()
128        .filter(|&&stride| stride > 1)
129        .min()
130        .unwrap_or(&0);
131
132    supported_line_sizes
133        .filter(|&line_size| axis_shape % line_size == 0 && next_stride % line_size == 0)
134        .max()
135        .ok_or(LineSizeError::NoValidLineSize)
136}
137
138/// Find the maximum line size usable for perpendicular vectorization along the given axis
139/// from the supported line sizes or return 1 if vectorization is impossible.
140///
141/// This function is designed to never return a line size above 1 by error,
142/// but doesn't guarantee to always return the actual maximum possible line size.
143/// That is, it may be overly strict.
144///
145/// Currently, this checks that the stride of the axis is divisible by a candidate line size
146/// and that the product of all shapes of axes with smaller strides is equal to the stride of the axis.
147/// The second condition ensure that elements within the stride are contiguous.
148pub fn tensor_line_size_perpendicular(
149    supported_line_sizes: impl Iterator<Item = LineSize>,
150    shape: &[usize],
151    strides: &[usize],
152    axis: usize,
153) -> LineSize {
154    try_tensor_line_size_perpendicular(supported_line_sizes, shape, strides, axis).unwrap_or(1)
155}
156
157/// Like `tensor_line_size_perpendicular` but does not assume 1 is supported
158pub fn try_tensor_line_size_perpendicular(
159    supported_line_sizes: impl Iterator<Item = LineSize>,
160    shape: &[usize],
161    strides: &[usize],
162    axis: usize,
163) -> Result<LineSize, LineSizeError> {
164    let axis_stride = strides.get(axis).ok_or(LineSizeError::AxisOutOfBounds)?;
165
166    let prod_shape_axes_smaller_strides = strides
167        .iter()
168        .zip(shape.iter())
169        .filter(|(stride, _)| **stride < *axis_stride)
170        .map(|(_, shape)| shape)
171        .product::<usize>();
172
173    if *axis_stride != prod_shape_axes_smaller_strides {
174        return Err(LineSizeError::StrideMismatch);
175    }
176
177    supported_line_sizes
178        .filter(|&line_size| *axis_stride % line_size == 0)
179        .max()
180        .ok_or(LineSizeError::NoValidLineSize)
181}
182
183/// Runtime arguments to launch a kernel.
184pub type RuntimeArg<'a, T, R> = <T as LaunchArg>::RuntimeArg<'a, R>;
185
186#[cfg(feature = "export_tests")]
187/// Tests only useful for runtimes.
188pub mod runtime_tests;