cubecl_core/
lib.rs

1extern crate alloc;
2
3#[macro_use]
4extern crate derive_new;
5
6/// Cube Frontend Types.
7pub mod frontend;
8/// Input Output utilities.
9pub mod io;
10
11pub mod post_processing;
12
13/// Some future utilities that work across environments.
14pub use cubecl_common::future;
15
16use cubecl_runtime::client::ComputeClient;
17pub use cubecl_runtime::memory_management::MemoryConfiguration;
18use cubecl_runtime::server::CubeCountSelection;
19pub use frontend::cmma;
20
21/// Cube Language Internal Representation.
22pub use cubecl_ir as ir;
23
24pub mod codegen;
25pub mod compute;
26pub mod prelude;
27
28mod pod;
29
30pub use codegen::*;
31pub use cubecl_runtime::runtime::*;
32pub use pod::*;
33
34pub use cubecl_macros::*;
35pub use cubecl_runtime::benchmark;
36pub use cubecl_runtime::client;
37pub use cubecl_runtime::compiler::{CompilationError, Compiler, CubeTask};
38pub use cubecl_runtime::memory_management::MemoryUsage;
39pub use cubecl_runtime::server;
40pub use cubecl_runtime::tune;
41
42use frontend::LaunchArg;
43
44pub use cubecl_common::*;
45
46pub use prelude::CubeCount;
47pub use prelude::{CubeDim, ExecutionMode};
48
49mod id;
50pub use id::*;
51
52/// Calculate the number of cubes required to execute an operation where one cube unit is
53/// assigned to one element.
54pub fn calculate_cube_count_elemwise<R: Runtime>(
55    client: &ComputeClient<R>,
56    num_elems: usize,
57    cube_dim: CubeDim,
58) -> CubeCount {
59    let num_cubes = num_elems.div_ceil(cube_dim.num_elems() as usize);
60    CubeCountSelection::new(client, num_cubes as u32).cube_count()
61}
62
63pub fn tensor_vectorization_factor(
64    factors: &[u8],
65    shape: &[usize],
66    strides: &[usize],
67    dim: usize,
68) -> u8 {
69    tensor_line_size_parallel(factors.iter().cloned(), shape, strides, dim)
70}
71pub fn tensor_line_size(factors: &[u8], shape: &[usize], strides: &[usize], dim: usize) -> u8 {
72    tensor_line_size_parallel(factors.iter().cloned(), shape, strides, dim)
73}
74
/// Reason why no line size could be selected for a tensor axis.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LineSizeError {
    /// The requested axis is not a valid index into the tensor's strides (or shape).
    AxisOutOfBounds,
    /// The strides of the tensor do not permit vectorizing along the requested axis.
    StrideMismatch,
    /// None of the supported line sizes satisfies the divisibility requirements.
    NoValidLineSize,
}
81
82/// Find the maximum line size usable for parallel vectorization along the given axis
83/// from the supported line sizes or return 1 if vectorization is impossible.
84///
85/// This function is designed to never return a line size above 1 by error,
86/// but doesn't guarantee to always return the actual maximum possible line size.
87/// That is, it may be overly strict.
88///
89/// Currently, this checks that the stride of the axis is 1, that it's shape is
90/// divisible by a candidate line size and that the smallest stride that is not 1
91/// is divisible by the vectorization.
92/// The last condition ensure that the current axis is contiguous within the next stride.
93pub fn tensor_line_size_parallel(
94    supported_line_sizes: impl Iterator<Item = u8>,
95    shape: &[usize],
96    strides: &[usize],
97    axis: usize,
98) -> u8 {
99    try_tensor_line_size_parallel(supported_line_sizes, shape, strides, axis).unwrap_or(1)
100}
101
102/// Like `try_tensor_line_size_parallel` but does not assume 1 is supported
103pub fn try_tensor_line_size_parallel(
104    supported_line_sizes: impl Iterator<Item = u8>,
105    shape: &[usize],
106    strides: &[usize],
107    axis: usize,
108) -> Result<u8, LineSizeError> {
109    let stride = strides.get(axis).ok_or(LineSizeError::AxisOutOfBounds)?;
110    if *stride != 1 {
111        return Err(LineSizeError::StrideMismatch);
112    }
113
114    let axis_shape = shape.get(axis).ok_or(LineSizeError::AxisOutOfBounds)?;
115
116    let next_stride = *strides
117        .iter()
118        .filter(|&&stride| stride > 1)
119        .min()
120        .unwrap_or(&0);
121
122    supported_line_sizes
123        .filter(|&line_size| {
124            axis_shape % line_size as usize == 0 && next_stride % line_size as usize == 0
125        })
126        .max()
127        .ok_or(LineSizeError::NoValidLineSize)
128}
129
130/// Find the maximum line size usable for perpendicular vectorization along the given axis
131/// from the supported line sizes or return 1 if vectorization is impossible.
132///
133/// This function is designed to never return a line size above 1 by error,
134/// but doesn't guarantee to always return the actual maximum possible line size.
135/// That is, it may be overly strict.
136///
137/// Currently, this checks that the stride of the axis is divisible by a candidate line size
138/// and that the product of all shapes of axes with smaller strides is equal to the stride of the axis.
139/// The second condition ensure that elements within the stride are contiguous.
140pub fn tensor_line_size_perpendicular(
141    supported_line_sizes: impl Iterator<Item = u8>,
142    shape: &[usize],
143    strides: &[usize],
144    axis: usize,
145) -> u8 {
146    try_tensor_line_size_perpendicular(supported_line_sizes, shape, strides, axis).unwrap_or(1)
147}
148
149/// Like `tensor_line_size_perpendicular` but does not assume 1 is supported
150pub fn try_tensor_line_size_perpendicular(
151    supported_line_sizes: impl Iterator<Item = u8>,
152    shape: &[usize],
153    strides: &[usize],
154    axis: usize,
155) -> Result<u8, LineSizeError> {
156    let axis_stride = strides.get(axis).ok_or(LineSizeError::AxisOutOfBounds)?;
157
158    let prod_shape_axes_smaller_strides = strides
159        .iter()
160        .zip(shape.iter())
161        .filter(|(stride, _)| **stride < *axis_stride)
162        .map(|(_, shape)| shape)
163        .product::<usize>();
164
165    if *axis_stride != prod_shape_axes_smaller_strides {
166        return Err(LineSizeError::StrideMismatch);
167    }
168
169    supported_line_sizes
170        .filter(|&line_size| *axis_stride % line_size as usize == 0)
171        .max()
172        .ok_or(LineSizeError::NoValidLineSize)
173}
174
175/// Runtime arguments to launch a kernel.
176pub type RuntimeArg<'a, T, R> = <T as LaunchArg>::RuntimeArg<'a, R>;
177
178#[cfg(feature = "export_tests")]
179/// Tests only useful for runtimes.
180pub mod runtime_tests;