// cubecl_core/lib.rs

1extern crate alloc;
2
3#[macro_use]
4extern crate derive_new;
5
6/// Cube Frontend Types.
7pub mod frontend;
8
9/// Some future utilities that work across environments.
10pub use cubecl_common::future;
11
12pub use cubecl_runtime::memory_management::MemoryConfiguration;
13pub use frontend::cmma;
14
15/// Cube Language Internal Representation.
16pub mod ir;
17
18pub mod codegen;
19pub mod compute;
20pub mod prelude;
21
22mod pod;
23mod runtime;
24
25pub use codegen::*;
26pub use pod::*;
27pub use runtime::*;
28
29pub use cubecl_macros::*;
30pub use cubecl_runtime::benchmark;
31
/// An approximation of the plane dimension.
///
/// NOTE(review): presumably "plane" is the hardware SIMD group (warp/subgroup);
/// 16 is a portable middle-ground guess — confirm against runtime docs.
pub const PLANE_DIM_APPROX: usize = 16;
34
35use crate::ir::KernelDefinition;
36use frontend::LaunchArg;
37
38pub use prelude::CubeCount;
39pub use prelude::CubeDim;
40pub use prelude::{flex32, tf32};
41
42mod id;
43pub use id::*;
44
/// Implement this trait to create a [kernel definition](KernelDefinition).
pub trait Kernel: Send + Sync + 'static + Sized {
    /// Convert to a kernel definition.
    fn define(&self) -> KernelDefinition;
    /// Identifier for the kernel, used for caching kernel compilation.
    ///
    /// The default implementation derives the id from the concrete `Self` type,
    /// so it only needs to be overridden when one type produces several
    /// distinct kernels.
    fn id(&self) -> KernelId {
        KernelId::new::<Self>()
    }
}
54
55/// Calculate the number of cubes required to execute an operation where one cube unit is
56/// assigned to one element.
57pub fn calculate_cube_count_elemwise(num_elems: usize, cube_dim: CubeDim) -> CubeCount {
58    let num_elems_per_cube = cube_dim.num_elems();
59    let cube_counts = f32::max(1.0, f32::ceil(num_elems as f32 / num_elems_per_cube as f32));
60    let cube_count_x = f32::ceil(f32::sqrt(cube_counts));
61    let cube_count_y = f32::ceil(num_elems as f32 / (cube_count_x * num_elems_per_cube as f32));
62
63    CubeCount::Static(cube_count_x as u32, cube_count_y as u32, 1)
64}
65
66pub fn tensor_vectorization_factor(
67    factors: &[u8],
68    shape: &[usize],
69    strides: &[usize],
70    dim: usize,
71) -> u8 {
72    tensor_line_size_parallel(factors.iter().cloned(), shape, strides, dim)
73}
74pub fn tensor_line_size(factors: &[u8], shape: &[usize], strides: &[usize], dim: usize) -> u8 {
75    tensor_line_size_parallel(factors.iter().cloned(), shape, strides, dim)
76}
77
/// Find the maximum line size usable for parallel vectorization along the given axis
/// from the supported line sizes or return 1 if vectorization is impossible.
///
/// This function is designed to never return a line size above 1 by error,
/// but doesn't guarantee to always return the actual maximum possible line size.
/// That is, it may be overly strict.
///
/// Currently, this checks that the stride of the axis is 1, that its shape is
/// divisible by a candidate line size and that the smallest stride that is not 1
/// is equal to the shape of the axis.
/// The last condition ensures that the current axis is contiguous within the next stride.
pub fn tensor_line_size_parallel(
    supported_line_sizes: impl Iterator<Item = u8>,
    shape: &[usize],
    strides: &[usize],
    axis: usize,
) -> u8 {
    // Parallel vectorization requires a unit stride along the target axis
    // (an out-of-bounds axis also falls through to 1 here).
    if strides.get(axis).copied() != Some(1) {
        return 1;
    }

    let Some(&axis_shape) = shape.get(axis) else {
        return 1;
    };

    // The smallest stride above 1 (if any) must exactly match the axis shape,
    // i.e. the axis must be contiguous within the next stride.
    let contiguous = strides
        .iter()
        .filter(|&&stride| stride > 1)
        .min()
        .map_or(true, |&next_stride| next_stride == axis_shape);
    if !contiguous {
        return 1;
    }

    // Largest candidate that evenly divides the axis shape; 1 when none fits.
    supported_line_sizes
        .filter(|&line_size| axis_shape % line_size as usize == 0)
        .max()
        .unwrap_or(1)
}
122
/// Find the maximum line size usable for perpendicular vectorization along the given axis
/// from the supported line sizes or return 1 if vectorization is impossible.
///
/// This function is designed to never return a line size above 1 by error,
/// but doesn't guarantee to always return the actual maximum possible line size.
/// That is, it may be overly strict.
///
/// Currently, this checks that the stride of the axis is divisible by a candidate line size
/// and that the product of all shapes of axes with smaller strides is equal to the stride of the axis.
/// The second condition ensures that elements within the stride are contiguous.
pub fn tensor_line_size_perpendicular(
    supported_line_sizes: impl Iterator<Item = u8>,
    shape: &[usize],
    strides: &[usize],
    axis: usize,
) -> u8 {
    // An out-of-bounds axis cannot be vectorized.
    let Some(&axis_stride) = strides.get(axis) else {
        return 1;
    };

    // The axes packed below this one must exactly fill its stride,
    // i.e. elements within the stride are contiguous.
    let inner_elems: usize = strides
        .iter()
        .zip(shape)
        .filter_map(|(&stride, &dim)| (stride < axis_stride).then_some(dim))
        .product();
    if inner_elems != axis_stride {
        return 1;
    }

    // Largest candidate that evenly divides the axis stride; 1 when none fits.
    supported_line_sizes
        .filter(|&line_size| axis_stride % line_size as usize == 0)
        .max()
        .unwrap_or(1)
}
160
/// Runtime arguments to launch a kernel.
///
/// Shorthand for the [`LaunchArg::RuntimeArg`] associated type of `T` for runtime `R`.
pub type RuntimeArg<'a, T, R> = <T as LaunchArg>::RuntimeArg<'a, R>;
163
164#[cfg(feature = "export_tests")]
165/// Tests only useful for runtimes.
166pub mod runtime_tests;