cubecl_core/
lib.rs

1extern crate alloc;
2
3#[macro_use]
4extern crate derive_new;
5
6/// Cube Frontend Types.
7pub mod frontend;
8/// Input Output utilities.
9pub mod io;
10
11pub mod post_processing;
12
13/// Some future utilities that work across environments.
14pub use cubecl_common::{PLANE_DIM_APPROX, future};
15
16pub use cubecl_runtime::memory_management::MemoryConfiguration;
17pub use frontend::cmma;
18
19/// Cube Language Internal Representation.
20pub use cubecl_ir as ir;
21
22pub mod codegen;
23pub mod compute;
24pub mod prelude;
25
26mod pod;
27
28pub use codegen::*;
29pub use cubecl_runtime::runtime::*;
30pub use pod::*;
31
32pub use cubecl_macros::*;
33pub use cubecl_runtime::benchmark;
34pub use cubecl_runtime::client;
35pub use cubecl_runtime::compiler::{CompilationError, Compiler, CubeTask};
36pub use cubecl_runtime::memory_management::MemoryUsage;
37pub use cubecl_runtime::server;
38pub use cubecl_runtime::tune;
39
40use frontend::LaunchArg;
41
42pub use cubecl_common::ExecutionMode;
43pub use cubecl_common::{flex32, tf32};
44
45pub use prelude::CubeCount;
46pub use prelude::CubeDim;
47
48mod id;
49pub use id::*;
50
51/// Calculate the number of cubes required to execute an operation where one cube unit is
52/// assigned to one element.
53pub fn calculate_cube_count_elemwise(num_elems: usize, cube_dim: CubeDim) -> CubeCount {
54    let num_elems_per_cube = cube_dim.num_elems();
55    let cube_counts = f32::max(1.0, f32::ceil(num_elems as f32 / num_elems_per_cube as f32));
56    let cube_count_x = f32::ceil(f32::sqrt(cube_counts));
57    let cube_count_y = f32::max(
58        1.0,
59        f32::ceil(num_elems as f32 / (cube_count_x * num_elems_per_cube as f32)),
60    );
61
62    CubeCount::Static(cube_count_x as u32, cube_count_y as u32, 1)
63}
64
65pub fn tensor_vectorization_factor(
66    factors: &[u8],
67    shape: &[usize],
68    strides: &[usize],
69    dim: usize,
70) -> u8 {
71    tensor_line_size_parallel(factors.iter().cloned(), shape, strides, dim)
72}
73pub fn tensor_line_size(factors: &[u8], shape: &[usize], strides: &[usize], dim: usize) -> u8 {
74    tensor_line_size_parallel(factors.iter().cloned(), shape, strides, dim)
75}
76
/// Reasons why no line size could be selected for a tensor axis.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LineSizeError {
    /// The requested axis is not a valid index into the provided shape or strides.
    AxisOutOfBounds,
    /// The strides do not have the layout required for the requested vectorization.
    StrideMismatch,
    /// None of the supported line sizes satisfies the divisibility requirements.
    NoValidLineSize,
}
83
84/// Find the maximum line size usable for parallel vectorization along the given axis
85/// from the supported line sizes or return 1 if vectorization is impossible.
86///
87/// This function is designed to never return a line size above 1 by error,
88/// but doesn't guarantee to always return the actual maximum possible line size.
89/// That is, it may be overly strict.
90///
91/// Currently, this checks that the stride of the axis is 1, that it's shape is
92/// divisible by a candidate line size and that the smallest stride that is not 1
93/// is divisible by the vectorization.
94/// The last condition ensure that the current axis is contiguous within the next stride.
95pub fn tensor_line_size_parallel(
96    supported_line_sizes: impl Iterator<Item = u8>,
97    shape: &[usize],
98    strides: &[usize],
99    axis: usize,
100) -> u8 {
101    try_tensor_line_size_parallel(supported_line_sizes, shape, strides, axis).unwrap_or(1)
102}
103
104/// Like `try_tensor_line_size_parallel` but does not assume 1 is supported
105pub fn try_tensor_line_size_parallel(
106    supported_line_sizes: impl Iterator<Item = u8>,
107    shape: &[usize],
108    strides: &[usize],
109    axis: usize,
110) -> Result<u8, LineSizeError> {
111    let stride = strides.get(axis).ok_or(LineSizeError::AxisOutOfBounds)?;
112    if *stride != 1 {
113        return Err(LineSizeError::StrideMismatch);
114    }
115
116    let axis_shape = shape.get(axis).ok_or(LineSizeError::AxisOutOfBounds)?;
117
118    let next_stride = *strides
119        .iter()
120        .filter(|&&stride| stride > 1)
121        .min()
122        .unwrap_or(&0);
123
124    supported_line_sizes
125        .filter(|&line_size| {
126            axis_shape % line_size as usize == 0 && next_stride % line_size as usize == 0
127        })
128        .max()
129        .ok_or(LineSizeError::NoValidLineSize)
130}
131
132/// Find the maximum line size usable for perpendicular vectorization along the given axis
133/// from the supported line sizes or return 1 if vectorization is impossible.
134///
135/// This function is designed to never return a line size above 1 by error,
136/// but doesn't guarantee to always return the actual maximum possible line size.
137/// That is, it may be overly strict.
138///
139/// Currently, this checks that the stride of the axis is divisible by a candidate line size
140/// and that the product of all shapes of axes with smaller strides is equal to the stride of the axis.
141/// The second condition ensure that elements within the stride are contiguous.
142pub fn tensor_line_size_perpendicular(
143    supported_line_sizes: impl Iterator<Item = u8>,
144    shape: &[usize],
145    strides: &[usize],
146    axis: usize,
147) -> u8 {
148    try_tensor_line_size_perpendicular(supported_line_sizes, shape, strides, axis).unwrap_or(1)
149}
150
151/// Like `tensor_line_size_perpendicular` but does not assume 1 is supported
152pub fn try_tensor_line_size_perpendicular(
153    supported_line_sizes: impl Iterator<Item = u8>,
154    shape: &[usize],
155    strides: &[usize],
156    axis: usize,
157) -> Result<u8, LineSizeError> {
158    let axis_stride = strides.get(axis).ok_or(LineSizeError::AxisOutOfBounds)?;
159
160    let prod_shape_axes_smaller_strides = strides
161        .iter()
162        .zip(shape.iter())
163        .filter(|(stride, _)| **stride < *axis_stride)
164        .map(|(_, shape)| shape)
165        .product::<usize>();
166
167    if *axis_stride != prod_shape_axes_smaller_strides {
168        return Err(LineSizeError::StrideMismatch);
169    }
170
171    supported_line_sizes
172        .filter(|&line_size| *axis_stride % line_size as usize == 0)
173        .max()
174        .ok_or(LineSizeError::NoValidLineSize)
175}
176
177/// Runtime arguments to launch a kernel.
178pub type RuntimeArg<'a, T, R> = <T as LaunchArg>::RuntimeArg<'a, R>;
179
180#[cfg(feature = "export_tests")]
181/// Tests only useful for runtimes.
182pub mod runtime_tests;