extern crate alloc;
#[macro_use]
extern crate derive_new;
pub mod frontend;
pub mod io;
pub mod post_processing;
pub use cubecl_common::{PLANE_DIM_APPROX, future};
pub use cubecl_runtime::memory_management::MemoryConfiguration;
pub use frontend::cmma;
pub use cubecl_ir as ir;
pub mod codegen;
pub mod compute;
pub mod prelude;
mod pod;
mod runtime;
pub use codegen::*;
pub use pod::*;
pub use runtime::*;
pub use cubecl_macros::*;
pub use cubecl_runtime::benchmark;
pub use cubecl_runtime::memory_management::MemoryUsage;
use frontend::LaunchArg;
pub use cubecl_common::ExecutionMode;
pub use cubecl_common::{flex32, tf32};
pub use prelude::CubeCount;
pub use prelude::CubeDim;
mod id;
pub use id::*;
/// Computes a two-dimensional static cube count large enough to cover `num_elems` elements.
///
/// Each cube is taken to cover `cube_dim.num_elems()` elements. The required number of
/// cubes, `ceil(num_elems / cube_dim.num_elems())` (at least one), is laid out as a roughly
/// square grid: `x` is the ceiling of the square root of that number and `y` is the number
/// of rows of `x` cubes needed to reach it. The `z` dimension is always `1`.
///
/// The resulting grid may contain more cubes than strictly required, since `x * y` is only
/// guaranteed to be at least the required number of cubes.
///
/// The computation uses exact integer arithmetic. Single-precision floats only represent
/// integers exactly up to 2^24, so rounding `num_elems` through `f32` can produce a grid
/// that is too small for large inputs.
///
/// A `cube_dim` reporting zero elements is treated as covering one element per cube, and
/// counts that do not fit in a `u32` saturate at `u32::MAX`.
pub fn calculate_cube_count_elemwise(num_elems: usize, cube_dim: CubeDim) -> CubeCount {
    /// Integer `ceil(numerator / denominator)` for a non-zero denominator.
    fn ceil_div(numerator: u64, denominator: u64) -> u64 {
        numerator / denominator + u64::from(numerator % denominator != 0)
    }

    /// Smallest `root` such that `root * root >= value`.
    fn ceil_sqrt(value: u64) -> u64 {
        // Start from a floating-point estimate, then correct it with exact integer
        // comparisons; the squares are widened to `u128` so they cannot overflow.
        let mut root = (value as f64).sqrt() as u64;
        while u128::from(root) * u128::from(root) < u128::from(value) {
            root += 1;
        }
        while root > 0 && u128::from(root - 1) * u128::from(root - 1) >= u128::from(value) {
            root -= 1;
        }
        root
    }

    /// Converts to `u32`, saturating instead of truncating.
    fn saturate_to_u32(count: u64) -> u32 {
        // After clamping, the cast cannot lose information.
        count.min(u64::from(u32::MAX)) as u32
    }

    // Guard against a zero divisor coming from a degenerate cube dimension.
    let num_elems_per_cube = (cube_dim.num_elems() as u64).max(1);
    let cube_counts = ceil_div(num_elems as u64, num_elems_per_cube).max(1);

    // `cube_counts >= 1`, hence `cube_count_x >= 1` and the division below is well defined.
    let cube_count_x = ceil_sqrt(cube_counts);
    let cube_count_y = ceil_div(cube_counts, cube_count_x);

    CubeCount::Static(
        saturate_to_u32(cube_count_x),
        saturate_to_u32(cube_count_y),
        1,
    )
}
/// Returns the largest entry of `factors` usable as a line size along `dim`.
///
/// This forwards to [`tensor_line_size_parallel`], handing it the values of `factors` as the
/// supported line sizes, and therefore returns `1` whenever no factor qualifies.
///
/// NOTE(review): the body is identical to [`tensor_line_size`]; presumably this name is kept
/// for backward compatibility — confirm before removing either.
pub fn tensor_vectorization_factor(
    factors: &[u8],
    shape: &[usize],
    strides: &[usize],
    dim: usize,
) -> u8 {
    let candidates = factors.iter().cloned();

    tensor_line_size_parallel(candidates, shape, strides, dim)
}
/// Returns the largest entry of `factors` usable as a line size along `dim`.
///
/// Slice-based convenience wrapper around [`tensor_line_size_parallel`]: the values of
/// `factors` are passed on as the supported line sizes, and `1` is returned whenever none of
/// them qualifies.
pub fn tensor_line_size(factors: &[u8], shape: &[usize], strides: &[usize], dim: usize) -> u8 {
    let candidates = factors.iter().cloned();

    tensor_line_size_parallel(candidates, shape, strides, dim)
}
/// Reasons why no line size could be selected for a tensor axis.
///
/// Returned by [`try_tensor_line_size_parallel`] and [`try_tensor_line_size_perpendicular`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LineSizeError {
    /// The requested axis has no entry in `strides` (or, for the parallel lookup, in
    /// `shape`).
    AxisOutOfBounds,
    /// The stride of the requested axis does not have the value the lookup requires.
    StrideMismatch,
    /// None of the supported line sizes satisfies the divisibility requirements.
    NoValidLineSize,
}
/// Infallible counterpart of [`try_tensor_line_size_parallel`].
///
/// Returns the line size found by the fallible lookup, or `1` when that lookup reports any
/// error.
pub fn tensor_line_size_parallel(
    supported_line_sizes: impl Iterator<Item = u8>,
    shape: &[usize],
    strides: &[usize],
    axis: usize,
) -> u8 {
    // Line size reported when the fallible lookup fails, whatever the reason.
    const FALLBACK_LINE_SIZE: u8 = 1;

    let lookup = try_tensor_line_size_parallel(supported_line_sizes, shape, strides, axis);

    lookup.unwrap_or(FALLBACK_LINE_SIZE)
}
/// Finds the largest supported line size usable along `axis` when that axis has stride `1`.
///
/// A candidate from `supported_line_sizes` is accepted when it is non-zero and divides both
/// the extent of `axis` in `shape` and the smallest stride in `strides` that is greater than
/// one (when no such stride exists, only the extent is constraining).
///
/// # Errors
///
/// - [`LineSizeError::AxisOutOfBounds`] if `axis` has no entry in `strides` or in `shape`.
/// - [`LineSizeError::StrideMismatch`] if the stride of `axis` is not `1`.
/// - [`LineSizeError::NoValidLineSize`] if no candidate is accepted.
pub fn try_tensor_line_size_parallel(
    supported_line_sizes: impl Iterator<Item = u8>,
    shape: &[usize],
    strides: &[usize],
    axis: usize,
) -> Result<u8, LineSizeError> {
    let axis_stride = *strides.get(axis).ok_or(LineSizeError::AxisOutOfBounds)?;
    if axis_stride != 1 {
        return Err(LineSizeError::StrideMismatch);
    }

    let axis_shape = *shape.get(axis).ok_or(LineSizeError::AxisOutOfBounds)?;

    // Smallest stride strictly greater than one. Defaults to `0`, which every candidate
    // divides, so the stride constraint vanishes when there is no such stride.
    let next_stride = *strides
        .iter()
        .filter(|&&stride| stride > 1)
        .min()
        .unwrap_or(&0);

    supported_line_sizes
        .filter(|&line_size| {
            let line_size = usize::from(line_size);
            // A zero line size is meaningless and would make `%` panic, so skip it.
            line_size != 0 && axis_shape % line_size == 0 && next_stride % line_size == 0
        })
        .max()
        .ok_or(LineSizeError::NoValidLineSize)
}
/// Infallible counterpart of [`try_tensor_line_size_perpendicular`].
///
/// Returns the line size found by the fallible lookup, or `1` when that lookup reports any
/// error.
pub fn tensor_line_size_perpendicular(
    supported_line_sizes: impl Iterator<Item = u8>,
    shape: &[usize],
    strides: &[usize],
    axis: usize,
) -> u8 {
    // Line size reported when the fallible lookup fails, whatever the reason.
    const FALLBACK_LINE_SIZE: u8 = 1;

    let lookup = try_tensor_line_size_perpendicular(supported_line_sizes, shape, strides, axis);

    lookup.unwrap_or(FALLBACK_LINE_SIZE)
}
/// Finds the largest supported line size that divides the stride of `axis`.
///
/// The lookup first requires the stride of `axis` to equal the product of the extents of
/// every axis whose stride is strictly smaller (an empty product counts as `1`). A candidate
/// from `supported_line_sizes` is then accepted when it is non-zero and divides that stride.
///
/// NOTE(review): presumably meant for vectorizing across an axis other than the contiguous
/// one — confirm against the callers.
///
/// # Errors
///
/// - [`LineSizeError::AxisOutOfBounds`] if `axis` has no entry in `strides`.
/// - [`LineSizeError::StrideMismatch`] if the stride of `axis` differs from the product
///   described above.
/// - [`LineSizeError::NoValidLineSize`] if no candidate is accepted.
pub fn try_tensor_line_size_perpendicular(
    supported_line_sizes: impl Iterator<Item = u8>,
    shape: &[usize],
    strides: &[usize],
    axis: usize,
) -> Result<u8, LineSizeError> {
    let axis_stride = *strides.get(axis).ok_or(LineSizeError::AxisOutOfBounds)?;

    // Product of the extents of all axes laid out with a smaller stride than `axis`.
    let prod_shape_axes_smaller_strides: usize = strides
        .iter()
        .zip(shape)
        .filter(|&(&stride, _)| stride < axis_stride)
        .map(|(_, &extent)| extent)
        .product();

    if axis_stride != prod_shape_axes_smaller_strides {
        return Err(LineSizeError::StrideMismatch);
    }

    supported_line_sizes
        .filter(|&line_size| {
            let line_size = usize::from(line_size);
            // A zero line size is meaningless and would make `%` panic, so skip it.
            line_size != 0 && axis_stride % line_size == 0
        })
        .max()
        .ok_or(LineSizeError::NoValidLineSize)
}
/// Shorthand for the runtime-argument type that `T` declares through its `LaunchArg`
/// implementation, i.e. `<T as LaunchArg>::RuntimeArg<'a, R>`.
///
/// NOTE(review): `R` is presumably the compute runtime the argument is bound to — confirm
/// against the definition of `LaunchArg`.
pub type RuntimeArg<'a, T, R> = <T as LaunchArg>::RuntimeArg<'a, R>;
#[cfg(feature = "export_tests")]
pub mod runtime_tests;