//! cubecl_core — crate root (lib.rs).

1extern crate alloc;
2
3#[macro_use]
4extern crate derive_new;
5
6/// Cube Frontend Types.
7pub mod frontend;
8/// Input Output utilities.
9pub mod io;
10
11/// Some future utilities that work across environments.
12pub use cubecl_common::{PLANE_DIM_APPROX, future};
13
14pub use cubecl_runtime::memory_management::MemoryConfiguration;
15pub use frontend::cmma;
16
17/// Cube Language Internal Representation.
18pub use cubecl_ir as ir;
19
20pub mod codegen;
21pub mod compute;
22pub mod prelude;
23
24mod pod;
25mod runtime;
26
27pub use codegen::*;
28pub use pod::*;
29pub use runtime::*;
30
31pub use cubecl_macros::*;
32pub use cubecl_runtime::benchmark;
33pub use cubecl_runtime::memory_management::MemoryUsage;
34
35use crate::compute::KernelDefinition;
36use frontend::LaunchArg;
37
38pub use cubecl_common::ExecutionMode;
39pub use cubecl_common::{flex32, tf32};
40
41pub use prelude::CubeCount;
42pub use prelude::CubeDim;
43
44mod id;
45pub use id::*;
46
47/// Implement this trait to create a [kernel definition](KernelDefinition).
48pub trait Kernel: Send + Sync + 'static + Sized {
49    /// Convert to a kernel definition.
50    fn define(&self) -> KernelDefinition;
51    /// Identifier for the kernel, used for caching kernel compilation.
52    fn id(&self) -> KernelId {
53        KernelId::new::<Self>()
54    }
55}
56
57/// Calculate the number of cubes required to execute an operation where one cube unit is
58/// assigned to one element.
59pub fn calculate_cube_count_elemwise(num_elems: usize, cube_dim: CubeDim) -> CubeCount {
60    let num_elems_per_cube = cube_dim.num_elems();
61    let cube_counts = f32::max(1.0, f32::ceil(num_elems as f32 / num_elems_per_cube as f32));
62    let cube_count_x = f32::ceil(f32::sqrt(cube_counts));
63    let cube_count_y = f32::ceil(num_elems as f32 / (cube_count_x * num_elems_per_cube as f32));
64
65    CubeCount::Static(cube_count_x as u32, cube_count_y as u32, 1)
66}
67
68pub fn tensor_vectorization_factor(
69    factors: &[u8],
70    shape: &[usize],
71    strides: &[usize],
72    dim: usize,
73) -> u8 {
74    tensor_line_size_parallel(factors.iter().cloned(), shape, strides, dim)
75}
76pub fn tensor_line_size(factors: &[u8], shape: &[usize], strides: &[usize], dim: usize) -> u8 {
77    tensor_line_size_parallel(factors.iter().cloned(), shape, strides, dim)
78}
79
80/// Find the maximum line size usable for parallel vectorization along the given axis
81/// from the supported line sizes or return 1 if vectorization is impossible.
82///
83/// This function is designed to never return a line size above 1 by error,
84/// but doesn't guarantee to always return the actual maximum possible line size.
85/// That is, it may be overly strict.
86///
87/// Currently, this checks that the stride of the axis is 1, that it's shape is
88/// divisible by a candidate line size and that the smallest stride that is not 1
89/// is divisible by the vectorization.
90/// The last condition ensure that the current axis is contiguous within the next stride.
/// Find the maximum line size usable for parallel vectorization along the given axis
/// from the supported line sizes or return 1 if vectorization is impossible.
///
/// This function is designed to never return a line size above 1 by error,
/// but doesn't guarantee to always return the actual maximum possible line size.
/// That is, it may be overly strict.
///
/// Currently, this checks that the stride of the axis is 1, that its shape is
/// divisible by a candidate line size and that the smallest stride that is not 1
/// is divisible by the vectorization.
/// The last condition ensures that the current axis is contiguous within the next stride.
pub fn tensor_line_size_parallel(
    supported_line_sizes: impl Iterator<Item = u8>,
    shape: &[usize],
    strides: &[usize],
    axis: usize,
) -> u8 {
    // Parallel vectorization requires the axis itself to be contiguous.
    // A missing axis (out of bounds) also disqualifies vectorization.
    if strides.get(axis).copied() != Some(1) {
        return 1;
    }

    let Some(&axis_shape) = shape.get(axis) else {
        return 1;
    };

    // Smallest stride strictly above 1; defaults to 0 when everything is
    // contiguous, since `x % n == 0` holds for every candidate when x is 0.
    let next_stride = strides
        .iter()
        .copied()
        .filter(|&stride| stride > 1)
        .min()
        .unwrap_or(0);

    // Keep candidates that evenly divide both the axis shape and the next
    // stride, then pick the largest; fall back to 1 when none qualify.
    supported_line_sizes
        .filter(|&line_size| {
            let line_size = line_size as usize;
            axis_shape % line_size == 0 && next_stride % line_size == 0
        })
        .max()
        .unwrap_or(1)
}
124
125/// Find the maximum line size usable for perpendicular vectorization along the given axis
126/// from the supported line sizes or return 1 if vectorization is impossible.
127///
128/// This function is designed to never return a line size above 1 by error,
129/// but doesn't guarantee to always return the actual maximum possible line size.
130/// That is, it may be overly strict.
131///
132/// Currently, this checks that the stride of the axis is divisible by a candidate line size
133/// and that the product of all shapes of axes with smaller strides is equal to the stride of the axis.
134/// The second condition ensure that elements within the stride are contiguous.
/// Find the maximum line size usable for perpendicular vectorization along the given axis
/// from the supported line sizes or return 1 if vectorization is impossible.
///
/// This function is designed to never return a line size above 1 by error,
/// but doesn't guarantee to always return the actual maximum possible line size.
/// That is, it may be overly strict.
///
/// Currently, this checks that the stride of the axis is divisible by a candidate line size
/// and that the product of all shapes of axes with smaller strides is equal to the stride of the axis.
/// The second condition ensures that elements within the stride are contiguous.
pub fn tensor_line_size_perpendicular(
    supported_line_sizes: impl Iterator<Item = u8>,
    shape: &[usize],
    strides: &[usize],
    axis: usize,
) -> u8 {
    let Some(&axis_stride) = strides.get(axis) else {
        return 1;
    };

    // Total number of elements covered by axes with strictly smaller strides.
    let inner_elems: usize = strides
        .iter()
        .zip(shape)
        .filter_map(|(&stride, &dim)| (stride < axis_stride).then_some(dim))
        .product();

    // If the inner axes don't exactly fill the stride, the data under this
    // axis is not contiguous and perpendicular vectorization is unsafe.
    if inner_elems != axis_stride {
        return 1;
    }

    // Largest candidate that evenly divides the axis stride, defaulting to 1.
    supported_line_sizes
        .filter(|&line_size| axis_stride % line_size as usize == 0)
        .max()
        .unwrap_or(1)
}
162
/// Runtime arguments to launch a kernel.
///
/// Shorthand for the [`LaunchArg::RuntimeArg`] associated type of `T`,
/// instantiated for the runtime `R` and lifetime `'a`.
pub type RuntimeArg<'a, T, R> = <T as LaunchArg>::RuntimeArg<'a, R>;
165
166#[cfg(feature = "export_tests")]
167/// Tests only useful for runtimes.
168pub mod runtime_tests;