1extern crate alloc;
2
3#[macro_use]
4extern crate derive_new;
5
6pub mod frontend;
8
9pub use cubecl_common::future;
11
12pub use cubecl_runtime::memory_management::MemoryConfiguration;
13pub use frontend::cmma;
14
15pub mod ir;
17
18pub mod codegen;
19pub mod compute;
20pub mod prelude;
21
22mod pod;
23mod runtime;
24
25pub use codegen::*;
26pub use pod::*;
27pub use runtime::*;
28
29pub use cubecl_macros::*;
30pub use cubecl_runtime::benchmark;
31
/// Approximate plane dimension, fixed at 16.
///
/// NOTE(review): presumably a fallback estimate for code that cannot query the
/// real plane size from the hardware — confirm against the call sites.
pub const PLANE_DIM_APPROX: usize = 16;
34
35use crate::ir::KernelDefinition;
36use frontend::LaunchArg;
37
38pub use prelude::CubeCount;
39pub use prelude::CubeDim;
40pub use prelude::{flex32, tf32};
41
42mod id;
43pub use id::*;
44
/// A type that can describe itself as a [`KernelDefinition`].
///
/// Implementors must be `Send + Sync + 'static + Sized`.
pub trait Kernel: Send + Sync + 'static + Sized {
    /// Builds the [`KernelDefinition`] for this kernel.
    fn define(&self) -> KernelDefinition;

    /// Returns the identifier of this kernel.
    ///
    /// The default implementation derives the id from the implementing type
    /// only (`KernelId::new::<Self>()`); no field of `self` is consulted.
    /// NOTE(review): implementors whose definition depends on runtime state
    /// presumably need to override this — confirm with how ids are consumed.
    fn id(&self) -> KernelId {
        KernelId::new::<Self>()
    }
}
54
55pub fn calculate_cube_count_elemwise(num_elems: usize, cube_dim: CubeDim) -> CubeCount {
58 let num_elems_per_cube = cube_dim.num_elems();
59 let cube_counts = f32::max(1.0, f32::ceil(num_elems as f32 / num_elems_per_cube as f32));
60 let cube_count_x = f32::ceil(f32::sqrt(cube_counts));
61 let cube_count_y = f32::ceil(num_elems as f32 / (cube_count_x * num_elems_per_cube as f32));
62
63 CubeCount::Static(cube_count_x as u32, cube_count_y as u32, 1)
64}
65
66pub fn tensor_vectorization_factor(
67 factors: &[u8],
68 shape: &[usize],
69 strides: &[usize],
70 dim: usize,
71) -> u8 {
72 tensor_line_size_parallel(factors.iter().cloned(), shape, strides, dim)
73}
74pub fn tensor_line_size(factors: &[u8], shape: &[usize], strides: &[usize], dim: usize) -> u8 {
75 tensor_line_size_parallel(factors.iter().cloned(), shape, strides, dim)
76}
77
/// Returns the largest candidate in `supported_line_sizes` usable along
/// `axis`, or `1` when no candidate qualifies.
///
/// The result is `1` unless all of the following hold:
/// - `axis` is a valid index into both `strides` and `shape`;
/// - the stride of `axis` is exactly `1`;
/// - the smallest stride strictly greater than `1`, if any, equals the extent
///   of `axis`.
///
/// When they hold, the largest candidate that evenly divides the extent of
/// `axis` is returned. A candidate of `0` is never usable and is skipped
/// instead of triggering a remainder-by-zero panic.
pub fn tensor_line_size_parallel(
    supported_line_sizes: impl Iterator<Item = u8>,
    shape: &[usize],
    strides: &[usize],
    axis: usize,
) -> u8 {
    // Covers both an out-of-range axis and a non-unit stride.
    if strides.get(axis) != Some(&1) {
        return 1;
    }

    let Some(&axis_extent) = shape.get(axis) else {
        return 1;
    };

    // Smallest stride above 1; it has to line up with the extent of `axis`.
    let next_stride = strides.iter().copied().filter(|&stride| stride > 1).min();
    if matches!(next_stride, Some(stride) if stride != axis_extent) {
        return 1;
    }

    supported_line_sizes
        .filter(|&line_size| line_size != 0 && axis_extent % usize::from(line_size) == 0)
        .max()
        .unwrap_or(1)
}
122
/// Returns the largest candidate in `supported_line_sizes` that evenly divides
/// the stride of `axis`, or `1` when no candidate qualifies.
///
/// The result is `1` unless both of the following hold:
/// - `axis` is a valid index into `strides`;
/// - the stride of `axis` equals the product of the extents of every axis
///   whose stride is strictly smaller (the empty product is `1`).
///
/// `strides` and `shape` are paired element-wise; entries beyond the shorter
/// slice are ignored. A candidate of `0` is never usable and is skipped
/// instead of triggering a remainder-by-zero panic.
pub fn tensor_line_size_perpendicular(
    supported_line_sizes: impl Iterator<Item = u8>,
    shape: &[usize],
    strides: &[usize],
    axis: usize,
) -> u8 {
    let Some(&axis_stride) = strides.get(axis) else {
        return 1;
    };

    // Number of elements spanned by all axes laid out "inside" `axis`.
    let inner_elems: usize = strides
        .iter()
        .zip(shape.iter())
        .filter(|&(&stride, _)| stride < axis_stride)
        .map(|(_, &extent)| extent)
        .product();

    if inner_elems != axis_stride {
        return 1;
    }

    supported_line_sizes
        .filter(|&line_size| line_size != 0 && axis_stride % usize::from(line_size) == 0)
        .max()
        .unwrap_or(1)
}
160
/// Shorthand for the `RuntimeArg` associated type that `T` declares through
/// its [`LaunchArg`] implementation, instantiated with lifetime `'a` and `R`.
///
/// NOTE(review): `R` is presumably the runtime type — confirm against the
/// definition of `LaunchArg`.
pub type RuntimeArg<'a, T, R> = <T as LaunchArg>::RuntimeArg<'a, R>;
163
164#[cfg(feature = "export_tests")]
165pub mod runtime_tests;