use cubecl::prelude::*;
use cubecl::std::{
FastDivmod,
tensor::layout::{
Coords1d, Coords2d, Layout, LayoutExpand, VirtualLayout, VirtualLayoutLaunch,
},
};
use cubecl::zspace::Shape;
use cubecl_common::quant::scheme::{QuantLevel, QuantScheme};
use cubek_std::MatrixLayout;
use crate::{
definition::MatmulProblem,
{components::global::memory::GlobalMemoryConfig, launch::BatchedCoords},
};
#[derive(CubeType, CubeLaunch, Clone, Copy)]
pub struct SimpleTmaGlobalLayout {
#[cube(comptime)]
transposed: bool,
shape: BatchedCoords,
}
#[cube]
impl SimpleTmaGlobalLayout {
pub fn new(shape: BatchedCoords, #[comptime] layout: MatrixLayout) -> Self {
let transposed = comptime![matches!(layout, MatrixLayout::ColMajor)];
SimpleTmaGlobalLayout { shape, transposed }
}
}
#[cube]
impl Layout for SimpleTmaGlobalLayout {
type Coordinates = BatchedCoords;
type SourceCoordinates = BatchedCoords;
fn to_source_pos(&self, coords: Self::Coordinates) -> BatchedCoords {
let (batch, row, col) = coords;
let batch = select(self.shape.0 == 1, 0, batch);
if self.transposed.comptime() {
(batch, col, row)
} else {
(batch, row, col)
}
}
fn to_source_pos_checked(&self, coords: Self::Coordinates) -> (BatchedCoords, bool) {
(self.to_source_pos(coords), self.is_in_bounds(coords))
}
fn shape(&self) -> Self::Coordinates {
self.shape
}
fn is_in_bounds(&self, _pos: Self::Coordinates) -> bool {
true.runtime()
}
}
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, Default)]
pub struct GlobalLayoutConfig {
pub matrix_layout: MatrixLayout,
pub check_row_bounds: bool,
pub check_col_bounds: bool,
}
impl From<GlobalMemoryConfig> for GlobalLayoutConfig {
fn from(gmem_config: GlobalMemoryConfig) -> Self {
gmem_config.as_global_layout_config()
}
}
#[derive(CubeType, CubeLaunch, Clone)]
pub struct GlobalLayout {
batch_layout: VirtualLayout<Coords1d, Coords1d>,
rows: u32,
cols: u32,
stride_row: usize,
stride_col: usize,
#[cube(comptime)]
vector_size: VectorSize,
#[cube(comptime)]
packing: u32,
#[cube(comptime)]
config: GlobalLayoutConfig,
}
#[cube]
impl GlobalLayout {
#[allow(clippy::too_many_arguments)]
pub fn new(
batch_layout: VirtualLayout<Coords1d, Coords1d>,
shape_row: u32,
shape_col: u32,
stride_row: usize,
stride_col: usize,
#[comptime] vector_size: VectorSize,
#[comptime] packing: u32,
#[comptime] config: GlobalLayoutConfig,
) -> Self {
GlobalLayout {
batch_layout,
rows: shape_row,
cols: shape_col,
stride_row,
stride_col,
vector_size,
packing,
config,
}
}
}
#[cube]
impl Layout for GlobalLayout {
type Coordinates = BatchedCoords;
type SourceCoordinates = Coords1d;
fn to_source_pos(&self, coords: Self::Coordinates) -> usize {
let (batch, row, col) = coords;
let batch_offs = self.batch_layout.to_source_pos(batch);
let (row, col) = match self.config.matrix_layout.comptime() {
MatrixLayout::RowMajor => (row, col / self.packing),
MatrixLayout::ColMajor => (row / self.packing, col),
};
let idx = batch_offs + row as usize * self.stride_row + col as usize * self.stride_col;
idx / self.vector_size
}
fn to_source_pos_checked(&self, coords: Self::Coordinates) -> (usize, bool) {
(self.to_source_pos(coords), self.is_in_bounds(coords))
}
fn shape(&self) -> Self::Coordinates {
(u32::MAX.runtime() as usize, self.rows, self.cols)
}
fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
let config = self.config.comptime();
let (_, row, col) = pos;
match (config.check_row_bounds, config.check_col_bounds) {
(true, true) => row < self.rows && col < self.cols,
(true, false) => row < self.rows,
(false, true) => col < self.cols,
(false, false) => true,
}
}
}
impl<R: Runtime> GlobalLayoutLaunch<R> {
pub fn from_handle(
handle: &TensorBinding<R>,
vector_size: VectorSize,
config: GlobalLayoutConfig,
) -> Self {
let rank = handle.shape.len();
let rows = handle.shape[rank - 2];
let cols = handle.shape[rank - 1];
let stride_row = handle.strides[rank - 2];
let stride_col = handle.strides[rank - 1];
GlobalLayoutLaunch::new(
VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
rows as u32,
cols as u32,
stride_row,
stride_col,
vector_size,
1,
config,
)
}
pub fn from_handle_batched(
handle: &TensorBinding<R>,
problem: &MatmulProblem,
vector_size: VectorSize,
config: GlobalLayoutConfig,
) -> Self {
let rank = handle.shape.len();
let rows = handle.shape[rank - 2];
let cols = handle.shape[rank - 1];
let stride_row = handle.strides[rank - 2];
let stride_col = handle.strides[rank - 1];
let batch_layout = BatchLayoutLaunch::from_handle(handle, problem);
GlobalLayoutLaunch::new(
VirtualLayoutLaunch::new::<BatchLayout>(batch_layout),
rows as u32,
cols as u32,
stride_row,
stride_col,
vector_size,
1,
config,
)
}
#[allow(clippy::too_many_arguments)]
pub fn from_quantized_handle(
values: &TensorBinding<R>,
scales: &TensorBinding<R>,
shape: &Shape,
problem: &MatmulProblem,
scheme: QuantScheme,
vector_size: VectorSize,
config: GlobalLayoutConfig,
) -> (GlobalLayoutLaunch<R>, GlobalScaleLayoutArgs<R>) {
let rank = values.shape.len();
let (rows, cols) = (shape[rank - 2], shape[rank - 1]);
let values_layout = {
let (stride_row, stride_col) = (values.strides[rank - 2], values.strides[rank - 1]);
let batch_layout = BatchLayoutLaunch::from_handle(values, problem);
GlobalLayoutLaunch::new(
VirtualLayoutLaunch::new::<BatchLayout>(batch_layout),
rows as u32,
cols as u32,
stride_row,
stride_col,
vector_size / scheme.num_quants(),
scheme.num_quants() as u32,
config,
)
};
let scales_layout = {
let shape = (rows as u32, cols as u32);
match scheme.level {
QuantLevel::Tensor => GlobalScaleLayoutArgs::PerTensor { shape },
QuantLevel::Block(block_size) => {
let [block_row, block_col] = block_size.as_dim();
let scales_layout =
GlobalLayoutLaunch::from_handle_batched(scales, problem, 1, config);
GlobalScaleLayoutArgs::BlockScaled(BlockScaledLayoutLaunch::new(
shape,
scales_layout,
(block_row as u32, block_col as u32),
))
}
}
};
(values_layout, scales_layout)
}
}
#[derive(CubeType, CubeLaunch)]
pub struct BatchLayout {
batch_shape: Sequence<FastDivmod<u32>>,
batch_strides: Sequence<usize>,
}
#[cube]
impl BatchLayout {
pub fn new(batch_strides: Sequence<usize>, batch_shape: Sequence<FastDivmod<u32>>) -> Self {
BatchLayout {
batch_shape,
batch_strides,
}
}
}
#[cube]
impl Layout for BatchLayout {
type Coordinates = Coords1d;
type SourceCoordinates = Coords1d;
fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
let mut batch = pos as u32;
let mut batch_offs = 0;
let batch_shape = self.batch_shape.rev();
let batch_strides = self.batch_strides.rev();
#[unroll]
for i in 0..batch_shape.len() {
let (rem, local_pos) = batch_shape[i].div_mod(batch);
batch = rem;
batch_offs += local_pos as usize * batch_strides[i];
}
batch_offs
}
#[allow(clippy::legacy_numeric_constants)]
fn shape(&self) -> Self::Coordinates {
usize::max_value()
}
fn is_in_bounds(&self, _pos: Self::Coordinates) -> bool {
true.runtime()
}
fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
(self.to_source_pos(pos), self.is_in_bounds(pos))
}
}
#[derive(CubeType, CubeLaunch)]
pub struct NoopLayout {}
#[cube]
impl NoopLayout {
#[allow(clippy::new_without_default)]
pub fn new() -> Self {
NoopLayout {}
}
}
#[cube]
impl Layout for NoopLayout {
type Coordinates = Coords1d;
type SourceCoordinates = Coords1d;
fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
pos
}
#[allow(clippy::legacy_numeric_constants)]
fn shape(&self) -> Self::Coordinates {
usize::max_value()
}
fn is_in_bounds(&self, _pos: Self::Coordinates) -> bool {
true.runtime()
}
fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
(self.to_source_pos(pos), self.is_in_bounds(pos))
}
}
impl<R: Runtime> BatchLayoutLaunch<R> {
pub fn from_handle(handle: &TensorBinding<R>, problem: &MatmulProblem) -> Self {
let rank = handle.shape.len();
let batch_shape = problem
.out_batches
.iter()
.map(|shape| *shape as u32)
.collect();
let batch_strides = handle.strides[..rank - 2]
.iter()
.zip(&handle.shape[..rank - 2])
.map(|(stride, shape)| if *shape == 1 { 0 } else { *stride })
.collect();
BatchLayoutLaunch::new(batch_shape, batch_strides)
}
}
#[derive(CubeType, CubeLaunch)]
pub enum GlobalScaleLayout {
PerTensor { shape: Coords2d },
BlockScaled(BlockScaledLayout),
}
#[derive(CubeType, CubeLaunch)]
pub struct BlockScaledLayout {
shape: Coords2d,
scales_layout: GlobalLayout,
#[cube(comptime)]
block_size: Coords2d,
}
#[cube]
impl BlockScaledLayout {
pub fn new(
shape: Coords2d,
scales_layout: GlobalLayout,
#[comptime] block_size: Coords2d,
) -> Self {
BlockScaledLayout {
shape,
scales_layout,
block_size,
}
}
}
#[cube]
impl Layout for GlobalScaleLayout {
type Coordinates = BatchedCoords;
type SourceCoordinates = Coords1d;
fn to_source_pos(&self, coords: Self::Coordinates) -> usize {
match self {
GlobalScaleLayout::PerTensor { .. } => 0usize.runtime(),
GlobalScaleLayout::BlockScaled(layout) => {
let BlockScaledLayout {
scales_layout,
block_size,
..
} = layout;
let (batch, row, col) = coords;
let (block_row, block_col) = block_size;
let (row, col) = (row / block_row, col / block_col);
scales_layout.to_source_pos((batch, row, col))
}
}
}
fn to_source_pos_checked(&self, coords: Self::Coordinates) -> (usize, bool) {
(self.to_source_pos(coords), self.is_in_bounds(coords))
}
fn shape(&self) -> Self::Coordinates {
match self {
GlobalScaleLayout::PerTensor { shape } => {
(u32::MAX.runtime() as usize, shape.0, shape.1)
}
GlobalScaleLayout::BlockScaled(layout) => {
let (row, col) = layout.shape;
(u32::MAX.runtime() as usize, row, col)
}
}
}
fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
match self {
GlobalScaleLayout::PerTensor { .. } => true.runtime(),
GlobalScaleLayout::BlockScaled(layout) => {
let (_, row, col) = pos;
let config = &layout.scales_layout.config.comptime();
let (rows, cols) = layout.shape;
match (config.check_row_bounds, config.check_col_bounds) {
(true, true) => row < rows && col < cols,
(true, false) => row < rows,
(false, true) => col < cols,
(false, false) => true,
}
}
}
}
}
#[derive(CubeType, CubeLaunch)]
pub struct Transpose<Inner: Layout + LaunchArg> {
inner: Inner,
}
#[cube]
impl<Inner: Layout + LaunchArg> Transpose<Inner> {
pub fn new(inner: Inner) -> Self {
Transpose::<Inner> { inner }
}
}
#[cube]
impl<Inner: Layout<Coordinates = BatchedCoords> + LaunchArg> Layout for Transpose<Inner> {
type Coordinates = BatchedCoords;
type SourceCoordinates = Inner::SourceCoordinates;
fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
let (batch, row, col) = pos;
self.inner.to_source_pos((batch, col, row))
}
fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
let (batch, row, col) = pos;
self.inner.is_in_bounds((batch, col, row))
}
fn shape(&self) -> Self::Coordinates {
let (batches, rows, cols) = self.inner.shape();
(batches, cols, rows)
}
fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
(self.to_source_pos(pos), self.is_in_bounds(pos))
}
}