cubek-convolution 0.2.0-pre.5

CubeK: Convolution Kernels
Documentation
use cubecl::{
    prelude::barrier::copy_async_checked,
    prelude::*,
    std::tensor::{View, layout::Coords2d},
};
use cubek_matmul::components::{
    global::GlobalReaderConfig,
    stage::{StridedStageMemory, TilingLayout},
};
use cubek_std::{MatrixLayout, StageIdent};

use crate::components::{ConvolutionOperation, global::args::RuntimeArgs};

/// The instruction has a max width of 128 bits, even on Blackwell which supports 256-bit loads
pub(crate) const ASYNC_COPY_WIDTH: u32 = 128;

/// Custom version of async copy to clamp slice on channels, not `k` as a whole.
#[cube]
#[expect(clippy::overly_complex_bool_expr, reason = "override")]
pub(crate) fn async_copy_from<EG: Scalar, EGS: Size, ES: Numeric, ESS: Size, T: TilingLayout>(
    view: View<Vector<EG, EGS>, Coords2d>,
    pos: Coords2d,
    stage: &mut StridedStageMemory<ES, ESS, T>,
    stage_offset: u32,
    runtime_args: &RuntimeArgs,
    k_offset: u32,
    #[comptime] config: GlobalReaderConfig,
    #[comptime] copy_vector_size: u32,
) {
    let operation = runtime_args.operation.comptime();
    let channels = runtime_args.channels;

    let mut stage_slice = stage.as_slice_mut::<ESS>();
    let slice_size = match config.smem_config.matrix_layout {
        MatrixLayout::RowMajor => (1u32, copy_vector_size),
        MatrixLayout::ColMajor => (copy_vector_size, 1u32),
    }
    .runtime();

    let mut slice_len_global = copy_vector_size.runtime();
    let slice_len_stage = copy_vector_size / stage_slice.vector_size() as u32;

    match (config.stage_ident, operation) {
        (StageIdent::Lhs, ConvolutionOperation::Forward)
        | (StageIdent::Lhs, ConvolutionOperation::ForwardTransposed)
        | (StageIdent::Lhs, ConvolutionOperation::BackwardData)
        | (StageIdent::Rhs, ConvolutionOperation::BackwardWeight) => {
            // im2col can give negative spatial indices so need to do a full bounds check on Lhs
            slice_len_global *= u32::cast_from(view.is_in_bounds(pos));

            // Remove check override later, currently checks are always false because matmul can't
            // understand channel padding and treats padded `k` as actual shape
            if config.gmem_config.check_col_bounds || true {
                let in_c = runtime_args.padded_channels.modulo(k_offset + pos.1);
                slice_len_global = channels.saturating_sub(in_c).min(slice_len_global);
            }
        }
        (StageIdent::Rhs, ConvolutionOperation::Forward)
        | (StageIdent::Out, ConvolutionOperation::BackwardWeight) => {
            // Remove check override later, currently checks are always false because matmul can't
            // understand channel padding and treats padded `k` as actual shape
            if config.gmem_config.check_row_bounds || true {
                let in_c = runtime_args.padded_channels.modulo(k_offset + pos.0);
                slice_len_global = channels.saturating_sub(in_c).min(slice_len_global);
            }
            if config.gmem_config.check_col_bounds {
                slice_len_global *= u32::cast_from(pos.1 < view.shape().1);
            }
        }
        (StageIdent::Rhs, ConvolutionOperation::ForwardTransposed)
        | (StageIdent::Rhs, ConvolutionOperation::BackwardData) => {
            // Remove check override later, currently checks are always false because matmul can't
            // understand channel padding and treats padded `k` as actual shape
            if config.gmem_config.check_row_bounds || true {
                let out_c = runtime_args.padded_channels.modulo(k_offset + pos.0);
                slice_len_global *=
                    u32::cast_from(out_c < runtime_args.channels && pos.0 < view.shape().0);
            }
            if config.gmem_config.check_col_bounds {
                let pos = pos.1;
                let shape = view.shape().1;
                slice_len_global = shape.saturating_sub(pos).min(slice_len_global);
            }
        }
        _ => {
            if config.gmem_config.check_row_bounds {
                let pos = pos.0;
                let shape = view.shape().0;
                match config.gmem_config.matrix_layout {
                    MatrixLayout::RowMajor => {
                        slice_len_global *= u32::cast_from(pos < shape);
                    }
                    MatrixLayout::ColMajor => {
                        slice_len_global = shape.saturating_sub(pos).min(slice_len_global);
                    }
                }
            }

            if config.gmem_config.check_col_bounds {
                let pos = pos.1;
                let shape = view.shape().1;
                match config.gmem_config.matrix_layout {
                    MatrixLayout::RowMajor => {
                        slice_len_global = shape.saturating_sub(pos).min(slice_len_global);
                    }
                    MatrixLayout::ColMajor => {
                        slice_len_global *= u32::cast_from(pos < shape);
                    }
                }
            }
        }
    }

    slice_len_global /= view.vector_size() as u32;

    let global_slice = view.slice_unchecked(pos, slice_size).to_linear_slice();

    let type_size = Vector::<ES, ESS>::type_size();
    let offset = stage.swizzle.apply(stage_offset, type_size);

    let stage_slice = stage_slice.slice_mut(offset as usize, (offset + slice_len_stage) as usize);

    copy_async_checked(
        &global_slice.slice(0, slice_len_global as usize),
        &mut stage_slice.downcast(),
        copy_vector_size,
    );
}