extern crate alloc;

#[macro_use]
extern crate derive_new;

pub mod frontend;
pub mod io;

pub use cubecl_common::{PLANE_DIM_APPROX, future};

pub use cubecl_runtime::memory_management::MemoryConfiguration;
pub use frontend::cmma;

pub use cubecl_ir as ir;

pub mod codegen;
pub mod compute;
pub mod prelude;

mod pod;
mod runtime;

pub use codegen::*;
pub use pod::*;
pub use runtime::*;

pub use cubecl_macros::*;
pub use cubecl_runtime::benchmark;
pub use cubecl_runtime::memory_management::MemoryUsage;

use crate::compute::KernelDefinition;
use frontend::LaunchArg;

pub use cubecl_common::ExecutionMode;
pub use cubecl_common::{flex32, tf32};

pub use prelude::CubeCount;
pub use prelude::CubeDim;

mod id;
pub use id::*;

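/// Trait implemented by types that can be compiled into a [`KernelDefinition`].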
pub trait Kernel: Send + Sync + 'static + Sized {
    /// Produce the [`KernelDefinition`] for this kernel.
    fn define(&self) -> KernelDefinition;
    /// Identifier for the kernel; the default implementation derives it from the concrete type.
    fn id(&self) -> KernelId {
        KernelId::new::<Self>()
    }
}

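/// Calculate the [`CubeCount`] needed to cover `num_elems` elements when each unit of a cube
/// handles one element.
///
/// The total number of cubes is spread over the `x` and `y` dimensions (roughly as a square)
/// rather than placed on `x` alone. For example, 1024 elements with a 16x16x1 cube dimension
/// (256 units) yields `CubeCount::Static(2, 2, 1)`.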
pub fn calculate_cube_count_elemwise(num_elems: usize, cube_dim: CubeDim) -> CubeCount {
    let num_elems_per_cube = cube_dim.num_elems();
    let cube_counts = f32::max(1.0, f32::ceil(num_elems as f32 / num_elems_per_cube as f32));
    let cube_count_x = f32::ceil(f32::sqrt(cube_counts));
    let cube_count_y = f32::ceil(num_elems as f32 / (cube_count_x * num_elems_per_cube as f32));

    CubeCount::Static(cube_count_x as u32, cube_count_y as u32, 1)
}

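/// Vectorization factor usable for the tensor along axis `dim`; forwards to
/// [`tensor_line_size_parallel`].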
pub fn tensor_vectorization_factor(
    factors: &[u8],
    shape: &[usize],
    strides: &[usize],
    dim: usize,
) -> u8 {
    tensor_line_size_parallel(factors.iter().cloned(), shape, strides, dim)
}
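/// Line size usable for the tensor along axis `dim`; forwards to
/// [`tensor_line_size_parallel`].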
pub fn tensor_line_size(factors: &[u8], shape: &[usize], strides: &[usize], dim: usize) -> u8 {
    tensor_line_size_parallel(factors.iter().cloned(), shape, strides, dim)
}

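/// Find the largest supported line size for accessing the tensor along `axis`, with the line
/// running parallel to that axis.
///
/// Returns `1` (no vectorization) when `axis` is out of bounds, when its stride is not `1`
/// (the axis is not contiguous in memory), or when no supported line size divides both the
/// axis shape and the smallest stride greater than `1`.
///
/// For example, with `shape = [2, 4]`, `strides = [4, 1]`, `axis = 1`, and supported line
/// sizes `[1, 2, 4]`, the result is `4`.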
pub fn tensor_line_size_parallel(
    supported_line_sizes: impl Iterator<Item = u8>,
    shape: &[usize],
    strides: &[usize],
    axis: usize,
) -> u8 {
    match strides.get(axis) {
        Some(val) => {
            if *val != 1 {
                return 1;
            }
        }
        None => return 1,
    }

    let axis_shape = match shape.get(axis) {
        Some(val) => val,
        None => return 1,
    };

    let next_stride = *strides
        .iter()
        .filter(|stride| **stride > 1)
        .min()
        .unwrap_or(&0);

    supported_line_sizes
        .filter(|line_size| {
            axis_shape % *line_size as usize == 0 && next_stride % *line_size as usize == 0
        })
        .max()
        .unwrap_or(1)
}

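/// Find the largest supported line size for accessing the tensor along `axis`, with the line
/// running perpendicular to that axis (elements are taken from the axes with smaller strides).
///
/// Returns `1` when `axis` is out of bounds, when the axis stride does not equal the product
/// of the shapes of the axes with smaller strides (those axes are not packed contiguously
/// beneath it), or when no supported line size divides the axis stride.
///
/// For example, with `shape = [2, 4]`, `strides = [4, 1]`, `axis = 0`, and supported line
/// sizes `[1, 2, 4]`, the result is `4`.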
pub fn tensor_line_size_perpendicular(
    supported_line_sizes: impl Iterator<Item = u8>,
    shape: &[usize],
    strides: &[usize],
    axis: usize,
) -> u8 {
    let axis_stride = match strides.get(axis) {
        Some(stride) => *stride,
        None => return 1,
    };

    let prod_shape_axes_smaller_strides = strides
        .iter()
        .zip(shape.iter())
        .filter(|(stride, _)| **stride < axis_stride)
        .map(|(_, shape)| shape)
        .product::<usize>();

    if axis_stride != prod_shape_axes_smaller_strides {
        return 1;
    }

    supported_line_sizes
        .filter(|line_size| axis_stride % *line_size as usize == 0)
        .max()
        .unwrap_or(1)
}

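/// Runtime argument of the [`LaunchArg`] type `T` for the runtime `R`.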
pub type RuntimeArg<'a, T, R> = <T as LaunchArg>::RuntimeArg<'a, R>;

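/// Test suite that downstream runtimes can import and run; only compiled when the
/// `export_tests` feature is enabled.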
#[cfg(feature = "export_tests")]
pub mod runtime_tests;