//! cubecl_core/lib.rs

1extern crate alloc;
2
3#[macro_use]
4extern crate derive_new;
5
6/// Cube Frontend Types.
7pub mod frontend;
8/// Input Output utilities.
9pub mod io;
10
11pub mod post_processing;
12
13/// Some future utilities that work across environments.
14pub use cubecl_common::{PLANE_DIM_APPROX, future};
15
16pub use cubecl_runtime::memory_management::MemoryConfiguration;
17pub use frontend::cmma;
18
19/// Cube Language Internal Representation.
20pub use cubecl_ir as ir;
21
22pub mod codegen;
23pub mod compute;
24pub mod prelude;
25
26mod pod;
27mod runtime;
28
29pub use codegen::*;
30pub use pod::*;
31pub use runtime::*;
32
33pub use cubecl_macros::*;
34pub use cubecl_runtime::benchmark;
35pub use cubecl_runtime::memory_management::MemoryUsage;
36
37use frontend::LaunchArg;
38
39pub use cubecl_common::ExecutionMode;
40pub use cubecl_common::{flex32, tf32};
41
42pub use prelude::CubeCount;
43pub use prelude::CubeDim;
44
45mod id;
46pub use id::*;
47
48/// Calculate the number of cubes required to execute an operation where one cube unit is
49/// assigned to one element.
50pub fn calculate_cube_count_elemwise(num_elems: usize, cube_dim: CubeDim) -> CubeCount {
51    let num_elems_per_cube = cube_dim.num_elems();
52    let cube_counts = f32::max(1.0, f32::ceil(num_elems as f32 / num_elems_per_cube as f32));
53    let cube_count_x = f32::ceil(f32::sqrt(cube_counts));
54    let cube_count_y = f32::max(
55        1.0,
56        f32::ceil(num_elems as f32 / (cube_count_x * num_elems_per_cube as f32)),
57    );
58
59    CubeCount::Static(cube_count_x as u32, cube_count_y as u32, 1)
60}
61
62pub fn tensor_vectorization_factor(
63    factors: &[u8],
64    shape: &[usize],
65    strides: &[usize],
66    dim: usize,
67) -> u8 {
68    tensor_line_size_parallel(factors.iter().cloned(), shape, strides, dim)
69}
70pub fn tensor_line_size(factors: &[u8], shape: &[usize], strides: &[usize], dim: usize) -> u8 {
71    tensor_line_size_parallel(factors.iter().cloned(), shape, strides, dim)
72}
73
/// Reasons why no line size could be selected for a tensor axis.
///
/// All variants are unit-like, so the enum is `Copy` and comparable, which makes
/// error handling and testing straightforward.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LineSizeError {
    /// The requested axis is outside the bounds of `shape`/`strides`.
    AxisOutOfBounds,
    /// The strides do not have the layout required for this vectorization mode.
    StrideMismatch,
    /// None of the supplied candidate line sizes satisfied the divisibility checks.
    NoValidLineSize,
}
80
81/// Find the maximum line size usable for parallel vectorization along the given axis
82/// from the supported line sizes or return 1 if vectorization is impossible.
83///
84/// This function is designed to never return a line size above 1 by error,
85/// but doesn't guarantee to always return the actual maximum possible line size.
86/// That is, it may be overly strict.
87///
88/// Currently, this checks that the stride of the axis is 1, that it's shape is
89/// divisible by a candidate line size and that the smallest stride that is not 1
90/// is divisible by the vectorization.
91/// The last condition ensure that the current axis is contiguous within the next stride.
92pub fn tensor_line_size_parallel(
93    supported_line_sizes: impl Iterator<Item = u8>,
94    shape: &[usize],
95    strides: &[usize],
96    axis: usize,
97) -> u8 {
98    try_tensor_line_size_parallel(supported_line_sizes, shape, strides, axis).unwrap_or(1)
99}
100
101/// Like `try_tensor_line_size_parallel` but does not assume 1 is supported
102pub fn try_tensor_line_size_parallel(
103    supported_line_sizes: impl Iterator<Item = u8>,
104    shape: &[usize],
105    strides: &[usize],
106    axis: usize,
107) -> Result<u8, LineSizeError> {
108    let stride = strides.get(axis).ok_or(LineSizeError::AxisOutOfBounds)?;
109    if *stride != 1 {
110        return Err(LineSizeError::StrideMismatch);
111    }
112
113    let axis_shape = shape.get(axis).ok_or(LineSizeError::AxisOutOfBounds)?;
114
115    let next_stride = *strides
116        .iter()
117        .filter(|&&stride| stride > 1)
118        .min()
119        .unwrap_or(&0);
120
121    supported_line_sizes
122        .filter(|&line_size| {
123            axis_shape % line_size as usize == 0 && next_stride % line_size as usize == 0
124        })
125        .max()
126        .ok_or(LineSizeError::NoValidLineSize)
127}
128
129/// Find the maximum line size usable for perpendicular vectorization along the given axis
130/// from the supported line sizes or return 1 if vectorization is impossible.
131///
132/// This function is designed to never return a line size above 1 by error,
133/// but doesn't guarantee to always return the actual maximum possible line size.
134/// That is, it may be overly strict.
135///
136/// Currently, this checks that the stride of the axis is divisible by a candidate line size
137/// and that the product of all shapes of axes with smaller strides is equal to the stride of the axis.
138/// The second condition ensure that elements within the stride are contiguous.
139pub fn tensor_line_size_perpendicular(
140    supported_line_sizes: impl Iterator<Item = u8>,
141    shape: &[usize],
142    strides: &[usize],
143    axis: usize,
144) -> u8 {
145    try_tensor_line_size_perpendicular(supported_line_sizes, shape, strides, axis).unwrap_or(1)
146}
147
148/// Like `tensor_line_size_perpendicular` but does not assume 1 is supported
149pub fn try_tensor_line_size_perpendicular(
150    supported_line_sizes: impl Iterator<Item = u8>,
151    shape: &[usize],
152    strides: &[usize],
153    axis: usize,
154) -> Result<u8, LineSizeError> {
155    let axis_stride = strides.get(axis).ok_or(LineSizeError::AxisOutOfBounds)?;
156
157    let prod_shape_axes_smaller_strides = strides
158        .iter()
159        .zip(shape.iter())
160        .filter(|(stride, _)| **stride < *axis_stride)
161        .map(|(_, shape)| shape)
162        .product::<usize>();
163
164    if *axis_stride != prod_shape_axes_smaller_strides {
165        return Err(LineSizeError::StrideMismatch);
166    }
167
168    supported_line_sizes
169        .filter(|&line_size| *axis_stride % line_size as usize == 0)
170        .max()
171        .ok_or(LineSizeError::NoValidLineSize)
172}
173
/// Runtime arguments to launch a kernel.
///
/// Shorthand for the [`LaunchArg::RuntimeArg`] associated type of `T` for the runtime `R`.
pub type RuntimeArg<'a, T, R> = <T as LaunchArg>::RuntimeArg<'a, R>;
176
177#[cfg(feature = "export_tests")]
178/// Tests only useful for runtimes.
179pub mod runtime_tests;