cubek-matmul 0.2.0

CubeK: Matrix Multiplication Kernels
Documentation
use cubecl::{
    prelude::barrier::{copy_async, copy_async_checked},
    prelude::*,
    std::tensor::{View, layout::Coords2d},
};
use cubek_std::MatrixLayout;

use crate::components::{
    global::GlobalReaderConfig,
    stage::{StridedStageMemory, TilingLayout},
};

/// Maximum width, in bits, of a single async copy instruction.
///
/// The instruction has a max width of 128 bits, even on Blackwell, which supports 256-bit loads.
pub(crate) const ASYNC_COPY_WIDTH: u32 = 128;

/// Issues one asynchronous copy from a 2D global `view` into stage memory.
///
/// A slice starting at `pos` is taken from `view`, linearized, and copied into the stage slice
/// that begins at the swizzled `stage_offset`. When row and/or column bounds checking is enabled
/// in `config.gmem_config`, the global slice is shortened to the part that lies inside the view
/// shape and the checked copy variant is used.
///
/// # Parameters
/// - `view`: source view over global data, addressed with 2D `(row, col)` coordinates.
/// - `pos`: coordinates in `view` where the copied slice starts.
/// - `stage`: destination stage memory; its swizzle is applied to `stage_offset`.
/// - `stage_offset`: offset into the stage slice before swizzling.
/// - `config`: comptime reader configuration; supplies the stage/global matrix layouts and the
///   bounds-check flags.
/// - `copy_vector_size`: comptime length of the copy. It is divided by the vector size of the
///   view and of the stage slice before use, so it appears to be counted in scalar elements
///   (TODO confirm against callers).
#[cube]
pub(crate) fn async_copy_from<EG: Scalar, NG: Size, ES: Numeric, NS: Size, T: TilingLayout>(
    view: View<Vector<EG, NG>, Coords2d>,
    pos: Coords2d,
    stage: &mut StridedStageMemory<ES, NS, T>,
    stage_offset: u32,
    #[comptime] config: GlobalReaderConfig,
    #[comptime] copy_vector_size: u32,
) {
    let mut stage_slice = stage.as_slice_mut::<NS>();
    // 2D extent of the slice requested from the view: `copy_vector_size` along one axis and 1
    // along the other.
    // NOTE(review): the axis is picked from the *stage* layout (`smem_config`), whereas the
    // bounds handling below uses the *global* layout (`gmem_config`) — presumably the two agree
    // whenever this function is used; confirm.
    let slice_size = match config.smem_config.matrix_layout {
        MatrixLayout::RowMajor => (1u32, copy_vector_size),
        MatrixLayout::ColMajor => (copy_vector_size, 1u32),
    }
    .runtime();

    // Length of the source slice; starts at the full copy length and is reduced below when
    // bounds checks are enabled.
    let mut slice_len_global = copy_vector_size.runtime();
    // Length of the destination, expressed in stage-slice vectors.
    let slice_len_stage = copy_vector_size / stage_slice.vector_size() as u32;

    if config.gmem_config.check_row_bounds {
        // `pos`/`shape` are shadowed inside this block only: row component of each.
        let pos = pos.0;
        let shape = view.shape().0;
        match config.gmem_config.matrix_layout {
            MatrixLayout::RowMajor => {
                // Multiplies by 1 when the row index is inside the view and by 0 otherwise,
                // so the length is either kept unchanged or becomes zero.
                slice_len_global *= u32::cast_from(pos < shape);
            }
            MatrixLayout::ColMajor => {
                // Clamp the length to the number of rows left between `pos` and the view shape
                // (zero when `pos` is already past it).
                slice_len_global = shape.saturating_sub(pos).min(slice_len_global);
            }
        }
    }

    if config.gmem_config.check_col_bounds {
        // Same as above for the column component, with the roles of the two layouts swapped.
        let pos = pos.1;
        let shape = view.shape().1;
        match config.gmem_config.matrix_layout {
            MatrixLayout::RowMajor => {
                // Clamp the length to the number of columns left between `pos` and the shape.
                slice_len_global = shape.saturating_sub(pos).min(slice_len_global);
            }
            MatrixLayout::ColMajor => {
                // Keep the length when the column index is in range, zero it otherwise.
                slice_len_global *= u32::cast_from(pos < shape);
            }
        }
    }

    // Convert the (possibly clamped) length into units of the view's vector size.
    slice_len_global /= view.vector_size() as u32;

    // `pos` here is the original 2D coordinate again (the shadowing above was block-local).
    let global_slice = view.slice_unchecked(pos, slice_size).to_linear_slice();

    // The stage offset is passed through the stage's swizzle, using the stage vector type size.
    let type_size = Vector::<ES, NS>::type_size();
    let offset = stage.swizzle.apply(stage_offset, type_size);

    // Destination window: `slice_len_stage` vectors starting at the swizzled offset.
    let stage_slice = stage_slice.slice_mut(offset as usize, (offset + slice_len_stage) as usize);

    if config.gmem_config.check_row_bounds || config.gmem_config.check_col_bounds {
        // With any bounds check enabled, only the in-bounds prefix of the source is handed to
        // the checked copy; the destination keeps its full length.
        // NOTE(review): presumably `copy_async_checked` handles the source being shorter than
        // `copy_vector_size` (e.g. by filling the remainder) — confirm in the cubecl docs.
        copy_async_checked(
            &global_slice.slice(0, slice_len_global as usize),
            &mut stage_slice.downcast(),
            copy_vector_size,
        );
    } else {
        // No bounds checks configured: copy the full slice with the unchecked variant.
        copy_async(&global_slice, &mut stage_slice.downcast(), copy_vector_size);
    }
}