use cubecl::{
    prelude::{barrier::copy_async_checked, *},
    std::tensor::{View, layout::Coords2d},
};
use cubek_matmul::components::{
global::GlobalReaderConfig,
stage::{StridedStageMemory, TilingLayout},
};
use cubek_std::{MatrixLayout, StageIdent};
use crate::components::{ConvolutionOperation, global::args::RuntimeArgs};
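
/// Width of a single asynchronous copy.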
pub(crate) const ASYNC_COPY_WIDTH: u32 = 128;
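
/// Copies a `copy_vector_size`-wide strip from the global `view` at `pos`
/// into the stage memory at `stage_offset` using a checked asynchronous copy.
///
/// The length actually read from global memory is clamped (or zeroed)
/// according to the convolution operation and the stage identity: positions
/// that fall into the zero-padded channel region or outside the view's shape
/// are excluded, so the async copy never reads out of bounds. `k_offset` is
/// the current offset along the reduction dimension and is combined with
/// `padded_channels` to recover the channel index at `pos`.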
#[cube]
#[expect(
    clippy::overly_complex_bool_expr,
    reason = "channel bounds are intentionally always checked"
)]
pub(crate) fn async_copy_from<EG: Scalar, EGS: Size, ES: Numeric, ESS: Size, T: TilingLayout>(
view: View<Vector<EG, EGS>, Coords2d>,
pos: Coords2d,
stage: &mut StridedStageMemory<ES, ESS, T>,
stage_offset: u32,
runtime_args: &RuntimeArgs,
k_offset: u32,
#[comptime] config: GlobalReaderConfig,
#[comptime] copy_vector_size: u32,
) {
let operation = runtime_args.operation.comptime();
let channels = runtime_args.channels;
let mut stage_slice = stage.as_slice_mut::<ESS>();
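    // Shape of the strip read from the global view: extent `copy_vector_size`
    // along the contiguous (fastest-varying) dimension and 1 along the other.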
let slice_size = match config.smem_config.matrix_layout {
MatrixLayout::RowMajor => (1u32, copy_vector_size),
MatrixLayout::ColMajor => (copy_vector_size, 1u32),
}
.runtime();
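    // Length of the copy: `slice_len_global` is clamped by the bounds checks
    // below and converted to the view's vectorization at the end, while
    // `slice_len_stage` is already in units of the stage slice's vectors.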
let mut slice_len_global = copy_vector_size.runtime();
let slice_len_stage = copy_vector_size / stage_slice.vector_size() as u32;
match (config.stage_ident, operation) {
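        // The k (column) coordinate indexes into the padded channel
        // dimension: zero the copy when `pos` is out of bounds for the view,
        // and clamp it so it never reads past the real channel count into the
        // zero-padding.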
(StageIdent::Lhs, ConvolutionOperation::Forward)
| (StageIdent::Lhs, ConvolutionOperation::ForwardTransposed)
| (StageIdent::Lhs, ConvolutionOperation::BackwardData)
| (StageIdent::Rhs, ConvolutionOperation::BackwardWeight) => {
slice_len_global *= u32::cast_from(view.is_in_bounds(pos));
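            // The channel bound is always applied, independent of the
            // comptime flag; the `#[expect]` on this function covers the
            // resulting `|| true`.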
if config.gmem_config.check_col_bounds || true {
let in_c = runtime_args.padded_channels.modulo(k_offset + pos.1);
slice_len_global = channels.saturating_sub(in_c).min(slice_len_global);
}
}
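        // Here the k (row) coordinate indexes into the padded channel
        // dimension; the column is checked against the view's shape when the
        // comptime flag asks for it.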
(StageIdent::Rhs, ConvolutionOperation::Forward)
| (StageIdent::Out, ConvolutionOperation::BackwardWeight) => {
if config.gmem_config.check_row_bounds || true {
let in_c = runtime_args.padded_channels.modulo(k_offset + pos.0);
slice_len_global = channels.saturating_sub(in_c).min(slice_len_global);
}
if config.gmem_config.check_col_bounds {
slice_len_global *= u32::cast_from(pos.1 < view.shape().1);
}
}
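        // The row indexes into padded channels (`out_c`) and must be both a
        // real channel and inside the view; the column length is clamped to
        // the view's shape.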
(StageIdent::Rhs, ConvolutionOperation::ForwardTransposed)
| (StageIdent::Rhs, ConvolutionOperation::BackwardData) => {
if config.gmem_config.check_row_bounds || true {
let out_c = runtime_args.padded_channels.modulo(k_offset + pos.0);
                slice_len_global *=
                    u32::cast_from(out_c < channels && pos.0 < view.shape().0);
}
if config.gmem_config.check_col_bounds {
let pos = pos.1;
let shape = view.shape().1;
slice_len_global = shape.saturating_sub(pos).min(slice_len_global);
}
}
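        // Remaining combinations: plain matrix bounds checks. The dimension
        // the copy runs along is clamped to the shape; the orthogonal
        // dimension zeroes the copy length when `pos` is out of bounds.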
_ => {
if config.gmem_config.check_row_bounds {
let pos = pos.0;
let shape = view.shape().0;
match config.gmem_config.matrix_layout {
MatrixLayout::RowMajor => {
slice_len_global *= u32::cast_from(pos < shape);
}
MatrixLayout::ColMajor => {
slice_len_global = shape.saturating_sub(pos).min(slice_len_global);
}
}
}
if config.gmem_config.check_col_bounds {
let pos = pos.1;
let shape = view.shape().1;
match config.gmem_config.matrix_layout {
MatrixLayout::RowMajor => {
slice_len_global = shape.saturating_sub(pos).min(slice_len_global);
}
MatrixLayout::ColMajor => {
slice_len_global *= u32::cast_from(pos < shape);
}
}
}
}
}
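
    // Convert the clamped global length to the view's vectorization, then
    // issue the checked asynchronous copy into the (swizzled) stage slice.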
slice_len_global /= view.vector_size() as u32;
let global_slice = view.slice_unchecked(pos, slice_size).to_linear_slice();
let type_size = Vector::<ES, ESS>::type_size();
let offset = stage.swizzle.apply(stage_offset, type_size);
let stage_slice = stage_slice.slice_mut(offset as usize, (offset + slice_len_stage) as usize);
copy_async_checked(
&global_slice.slice(0, slice_len_global as usize),
&mut stage_slice.downcast(),
copy_vector_size,
);
}