1#![no_std]
2
3#[cfg(feature = "std")]
4extern crate std;
5
6extern crate alloc;
7
8#[macro_use]
9extern crate derive_new;
10
11pub use cubecl_zspace as zspace;
12
13pub mod frontend;
15pub mod io;
17
18pub mod post_processing;
19
20pub use cubecl_common::future;
22
23use cubecl_ir::LineSize;
24use cubecl_runtime::client::ComputeClient;
25pub use cubecl_runtime::memory_management::MemoryConfiguration;
26use cubecl_runtime::server::CubeCountSelection;
27pub use frontend::cmma;
28
29pub use cubecl_ir as ir;
31
32pub mod codegen;
33pub mod compute;
34pub mod prelude;
35
36mod pod;
37
38pub use codegen::*;
39pub use cubecl_runtime::runtime::*;
40pub use pod::*;
41
42pub use cubecl_macros::*;
43pub use cubecl_runtime::benchmark;
44pub use cubecl_runtime::client;
45pub use cubecl_runtime::compiler::{CompilationError, Compiler, CubeTask};
46pub use cubecl_runtime::memory_management::MemoryUsage;
47pub use cubecl_runtime::server;
48pub use cubecl_runtime::tune;
49
50use frontend::LaunchArg;
51
52pub use cubecl_common::*;
53
54pub use prelude::CubeCount;
55pub use prelude::{CubeDim, ExecutionMode};
56
57pub use num_traits;
58
59mod id;
60pub use id::*;
61
62#[doc(hidden)]
64pub mod __private {
65 pub use alloc::{format, vec};
66}
67
68pub fn calculate_cube_count_elemwise<R: Runtime>(
71 client: &ComputeClient<R>,
72 num_elems: usize,
73 cube_dim: CubeDim,
74) -> CubeCount {
75 let num_cubes = num_elems.div_ceil(cube_dim.num_elems() as usize);
76 CubeCountSelection::new(client, num_cubes as u32).cube_count()
77}
78
79pub fn tensor_vectorization_factor(
80 factors: &[LineSize],
81 shape: &[usize],
82 strides: &[usize],
83 dim: usize,
84) -> LineSize {
85 tensor_line_size_parallel(factors.iter().cloned(), shape, strides, dim)
86}
87pub fn tensor_line_size(
88 factors: &[LineSize],
89 shape: &[usize],
90 strides: &[usize],
91 dim: usize,
92) -> LineSize {
93 tensor_line_size_parallel(factors.iter().cloned(), shape, strides, dim)
94}
95
/// Reason why no line size larger than the fallback could be selected for a tensor axis.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LineSizeError {
    /// The requested axis index is outside the tensor's `strides` (or, for the parallel
    /// search, its `shape`).
    AxisOutOfBounds,
    /// The strides do not have the layout the search requires along the axis.
    StrideMismatch,
    /// No candidate line size satisfied the divisibility constraints (or no candidates
    /// were supplied).
    NoValidLineSize,
}

impl core::fmt::Display for LineSizeError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let message = match self {
            Self::AxisOutOfBounds => "axis is out of bounds for the tensor's shape or strides",
            Self::StrideMismatch => "tensor strides do not allow vectorization along the axis",
            Self::NoValidLineSize => "no supported line size is compatible with the tensor layout",
        };
        f.write_str(message)
    }
}
102
103pub fn tensor_line_size_parallel(
115 optimized_line_sizes: impl Iterator<Item = LineSize>,
116 shape: &[usize],
117 strides: &[usize],
118 axis: usize,
119) -> LineSize {
120 try_tensor_line_size_parallel(optimized_line_sizes, shape, strides, axis).unwrap_or(1)
121}
122
123pub fn try_tensor_line_size_parallel(
125 supported_line_sizes: impl Iterator<Item = LineSize>,
126 shape: &[usize],
127 strides: &[usize],
128 axis: usize,
129) -> Result<LineSize, LineSizeError> {
130 let stride = strides.get(axis).ok_or(LineSizeError::AxisOutOfBounds)?;
131 if *stride != 1 {
132 return Err(LineSizeError::StrideMismatch);
133 }
134
135 let axis_shape = shape.get(axis).ok_or(LineSizeError::AxisOutOfBounds)?;
136
137 let next_stride = *strides
138 .iter()
139 .filter(|&&stride| stride > 1)
140 .min()
141 .unwrap_or(&0);
142
143 supported_line_sizes
144 .filter(|&line_size| axis_shape % line_size == 0 && next_stride % line_size == 0)
145 .max()
146 .ok_or(LineSizeError::NoValidLineSize)
147}
148
149pub fn tensor_line_size_perpendicular(
160 supported_line_sizes: impl Iterator<Item = LineSize>,
161 shape: &[usize],
162 strides: &[usize],
163 axis: usize,
164) -> LineSize {
165 try_tensor_line_size_perpendicular(supported_line_sizes, shape, strides, axis).unwrap_or(1)
166}
167
168pub fn try_tensor_line_size_perpendicular(
170 supported_line_sizes: impl Iterator<Item = LineSize>,
171 shape: &[usize],
172 strides: &[usize],
173 axis: usize,
174) -> Result<LineSize, LineSizeError> {
175 let axis_stride = strides.get(axis).ok_or(LineSizeError::AxisOutOfBounds)?;
176
177 let prod_shape_axes_smaller_strides = strides
178 .iter()
179 .zip(shape.iter())
180 .filter(|(stride, _)| **stride < *axis_stride)
181 .map(|(_, shape)| shape)
182 .product::<usize>();
183
184 if *axis_stride != prod_shape_axes_smaller_strides {
185 return Err(LineSizeError::StrideMismatch);
186 }
187
188 supported_line_sizes
189 .filter(|&line_size| *axis_stride % line_size == 0)
190 .max()
191 .ok_or(LineSizeError::NoValidLineSize)
192}
193
194pub type RuntimeArg<'a, T, R> = <T as LaunchArg>::RuntimeArg<'a, R>;
196pub type ExpandType<T> = <T as crate::prelude::CubeType>::ExpandType;
197
198#[cfg(feature = "export_tests")]
199pub mod runtime_tests;