cubecl_core/
lib.rs

1extern crate alloc;
2
3#[macro_use]
4extern crate derive_new;
5
6/// Cube Frontend Types.
7pub mod frontend;
8/// Input Output utilities.
9pub mod io;
10
11pub mod post_processing;
12
13/// Some future utilities that work across environments.
14pub use cubecl_common::future;
15
16use cubecl_ir::LineSize;
17use cubecl_runtime::client::ComputeClient;
18pub use cubecl_runtime::memory_management::MemoryConfiguration;
19use cubecl_runtime::server::CubeCountSelection;
20pub use frontend::cmma;
21
22/// Cube Language Internal Representation.
23pub use cubecl_ir as ir;
24
25pub mod codegen;
26pub mod compute;
27pub mod prelude;
28
29mod pod;
30
31pub use codegen::*;
32pub use cubecl_runtime::runtime::*;
33pub use pod::*;
34
35pub use cubecl_macros::*;
36pub use cubecl_runtime::benchmark;
37pub use cubecl_runtime::client;
38pub use cubecl_runtime::compiler::{CompilationError, Compiler, CubeTask};
39pub use cubecl_runtime::memory_management::MemoryUsage;
40pub use cubecl_runtime::server;
41pub use cubecl_runtime::tune;
42
43use frontend::LaunchArg;
44
45pub use cubecl_common::*;
46
47pub use prelude::CubeCount;
48pub use prelude::{CubeDim, ExecutionMode};
49
50pub use num_traits;
51
52mod id;
53pub use id::*;
54
55/// Calculate the number of cubes required to execute an operation where one cube unit is
56/// assigned to one element.
57pub fn calculate_cube_count_elemwise<R: Runtime>(
58    client: &ComputeClient<R>,
59    num_elems: usize,
60    cube_dim: CubeDim,
61) -> CubeCount {
62    let num_cubes = num_elems.div_ceil(cube_dim.num_elems() as usize);
63    CubeCountSelection::new(client, num_cubes as u32).cube_count()
64}
65
66pub fn tensor_vectorization_factor(
67    factors: &[LineSize],
68    shape: &[usize],
69    strides: &[usize],
70    dim: usize,
71) -> LineSize {
72    tensor_line_size_parallel(factors.iter().cloned(), shape, strides, dim)
73}
74pub fn tensor_line_size(
75    factors: &[LineSize],
76    shape: &[usize],
77    strides: &[usize],
78    dim: usize,
79) -> LineSize {
80    tensor_line_size_parallel(factors.iter().cloned(), shape, strides, dim)
81}
82
/// Reasons why no line size could be selected for a tensor axis.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LineSizeError {
    /// The requested axis has no entry in the provided `strides` (or `shape`) slice.
    AxisOutOfBounds,
    /// The stride of the axis does not satisfy the layout requirement of the chosen
    /// vectorization scheme.
    StrideMismatch,
    /// None of the supported line sizes passed the divisibility checks.
    NoValidLineSize,
}
89
90/// Find the maximum line size usable for parallel vectorization along the given axis
91/// from the supported line sizes or return 1 if vectorization is impossible.
92///
93/// This function is designed to never return a line size above 1 by error,
94/// but doesn't guarantee to always return the actual maximum possible line size.
95/// That is, it may be overly strict.
96///
97/// Currently, this checks that the stride of the axis is 1, that it's shape is
98/// divisible by a candidate line size and that the smallest stride that is not 1
99/// is divisible by the vectorization.
100/// The last condition ensure that the current axis is contiguous within the next stride.
101pub fn tensor_line_size_parallel(
102    supported_line_sizes: impl Iterator<Item = LineSize>,
103    shape: &[usize],
104    strides: &[usize],
105    axis: usize,
106) -> LineSize {
107    try_tensor_line_size_parallel(supported_line_sizes, shape, strides, axis).unwrap_or(1)
108}
109
110/// Like `try_tensor_line_size_parallel` but does not assume 1 is supported
111pub fn try_tensor_line_size_parallel(
112    supported_line_sizes: impl Iterator<Item = LineSize>,
113    shape: &[usize],
114    strides: &[usize],
115    axis: usize,
116) -> Result<LineSize, LineSizeError> {
117    let stride = strides.get(axis).ok_or(LineSizeError::AxisOutOfBounds)?;
118    if *stride != 1 {
119        return Err(LineSizeError::StrideMismatch);
120    }
121
122    let axis_shape = shape.get(axis).ok_or(LineSizeError::AxisOutOfBounds)?;
123
124    let next_stride = *strides
125        .iter()
126        .filter(|&&stride| stride > 1)
127        .min()
128        .unwrap_or(&0);
129
130    supported_line_sizes
131        .filter(|&line_size| axis_shape % line_size == 0 && next_stride % line_size == 0)
132        .max()
133        .ok_or(LineSizeError::NoValidLineSize)
134}
135
136/// Find the maximum line size usable for perpendicular vectorization along the given axis
137/// from the supported line sizes or return 1 if vectorization is impossible.
138///
139/// This function is designed to never return a line size above 1 by error,
140/// but doesn't guarantee to always return the actual maximum possible line size.
141/// That is, it may be overly strict.
142///
143/// Currently, this checks that the stride of the axis is divisible by a candidate line size
144/// and that the product of all shapes of axes with smaller strides is equal to the stride of the axis.
145/// The second condition ensure that elements within the stride are contiguous.
146pub fn tensor_line_size_perpendicular(
147    supported_line_sizes: impl Iterator<Item = LineSize>,
148    shape: &[usize],
149    strides: &[usize],
150    axis: usize,
151) -> LineSize {
152    try_tensor_line_size_perpendicular(supported_line_sizes, shape, strides, axis).unwrap_or(1)
153}
154
155/// Like `tensor_line_size_perpendicular` but does not assume 1 is supported
156pub fn try_tensor_line_size_perpendicular(
157    supported_line_sizes: impl Iterator<Item = LineSize>,
158    shape: &[usize],
159    strides: &[usize],
160    axis: usize,
161) -> Result<LineSize, LineSizeError> {
162    let axis_stride = strides.get(axis).ok_or(LineSizeError::AxisOutOfBounds)?;
163
164    let prod_shape_axes_smaller_strides = strides
165        .iter()
166        .zip(shape.iter())
167        .filter(|(stride, _)| **stride < *axis_stride)
168        .map(|(_, shape)| shape)
169        .product::<usize>();
170
171    if *axis_stride != prod_shape_axes_smaller_strides {
172        return Err(LineSizeError::StrideMismatch);
173    }
174
175    supported_line_sizes
176        .filter(|&line_size| *axis_stride % line_size == 0)
177        .max()
178        .ok_or(LineSizeError::NoValidLineSize)
179}
180
/// Runtime arguments to launch a kernel.
///
/// Shorthand for the `RuntimeArg` associated type that the [`LaunchArg`] implementor `T`
/// declares for the lifetime `'a` and the runtime `R`.
pub type RuntimeArg<'a, T, R> = <T as LaunchArg>::RuntimeArg<'a, R>;
183
184#[cfg(feature = "export_tests")]
185/// Tests only useful for runtimes.
186pub mod runtime_tests;