use cubecl::{
Runtime,
client::ComputeClient,
prelude::*,
std::tensor::{
launch::ViewArg,
layout::{
VirtualLayoutLaunch,
chain::{Chain, ChainLaunch},
},
},
zspace::{shape, strides},
};
use cubek_matmul::{
components::global::memory::{GlobalLayoutConfig, NoopLayout, NoopLayoutLaunch},
definition::{Blueprint, MatmulElems, TilingBlueprint},
launch::*,
routines::Routine,
};
use cubek_std::launch::tma::remap_storage_for_tma;
use cubek_std::{InputBinding, MatrixLayout, stage::SwizzleMode};
use enumset::EnumSet;
use crate::components::{
ConvolutionParams, ConvolutionProblem,
global::{
args::{RuntimeArgs, RuntimeArgsLaunch},
layout::{
BiasLayout, Im2colLayout, Im2colLayoutLaunch, NhwcCheck, NhwcLayout, NhwcLayoutLaunch,
OutLayout, OutLayoutLaunch, TmaIm2colLayout, TmaIm2colLayoutLaunch, WeightLayout,
WeightLayoutLaunch,
},
},
};
/// Matmul argument family that can be used to launch convolution routine `A`.
///
/// On top of [`MatmulArgs`], this requires that the family's input type
/// implements [`ConcreteInputsFactory`], that its output type implements
/// [`ConcreteOutputFactory`], and that its `Config` is [`RuntimeArgs`].
pub trait ConcreteArgs<A: Routine<RuntimeArgs>>:
MatmulArgs<
Input<Vector<Lhs, LhsSize>, Vector<Rhs, RhsSize>, Vector<Acc, AccSize>>: ConcreteInputsFactory<A>,
Output<Vector<Acc, AccSize>>: ConcreteOutputFactory<A>,
Config = RuntimeArgs,
>
{
/// Takes ownership of `problem` and returns the problem that should actually
/// be launched for this argument family.
///
/// The implementations in this file use it to pad the channel count to an
/// alignment derived from either the hardware (`client`, `dtypes`) or the
/// `blueprint`, and to recompute `k` from the padded channel count.
fn adjust_problem<R: Runtime>(
client: &ComputeClient<R>,
problem: ConvolutionProblem,
blueprint: &A::Blueprint,
dtypes: &MatmulElems,
) -> ConvolutionProblem;
}
impl<A: Routine<RuntimeArgs>> ConcreteArgs<A> for TensorArgs<RuntimeArgs> {
    /// Pads the channel count so that it is a multiple of the number of lhs
    /// elements covered by one hardware load, then recomputes `k`.
    ///
    /// The alignment is `load_width / lhs_global.size_bits()` (integer
    /// division; `load_width` is presumably expressed in bits — confirm against
    /// the hardware properties definition). `k` becomes the product of all
    /// kernel extents times the padded channel count.
    fn adjust_problem<R: Runtime>(
        client: &ComputeClient<R>,
        mut problem: ConvolutionProblem,
        _blueprint: &A::Blueprint,
        dtypes: &MatmulElems,
    ) -> ConvolutionProblem {
        let hardware_load_width = client.properties().hardware.load_width as usize;
        let elems_per_load = hardware_load_width / dtypes.lhs_global.size_bits();

        // Round the channel count up to a whole number of loads.
        problem.padded_channels = problem.channels.next_multiple_of(elems_per_load);

        // The reduction dimension spans every kernel position for every (padded) channel.
        let kernel_volume = problem.kernel_size.iter().product::<u32>() as usize;
        problem.k = kernel_volume * problem.padded_channels;

        problem
    }
}
impl<A: Routine<RuntimeArgs, Blueprint = TilingBlueprint>> ConcreteArgs<A>
    for TensorMapArgs<RuntimeArgs>
{
    /// Pads the channel count to the `k` granularity implied by the blueprint,
    /// then recomputes `k`.
    ///
    /// Without lhs swizzling the granularity is the tile size along `k`; with
    /// any swizzle mode it is the number of elements per stage along `k`.
    /// Neither the client nor the element types are consulted.
    fn adjust_problem<R: Runtime>(
        _client: &ComputeClient<R>,
        mut problem: ConvolutionProblem,
        blueprint: &TilingBlueprint,
        _dtypes: &MatmulElems,
    ) -> ConvolutionProblem {
        let scheme = &blueprint.tiling_scheme;
        let k_granularity = if matches!(blueprint.swizzle_modes.lhs, SwizzleMode::None) {
            scheme.tile_size.k() as usize
        } else {
            scheme.elements_per_stage_along_k() as usize
        };

        // Round the channel count up to a whole number of `k` units.
        problem.padded_channels = problem.channels.next_multiple_of(k_granularity);

        // The reduction dimension spans every kernel position for every (padded) channel.
        let kernel_volume = problem.kernel_size.iter().product::<u32>() as usize;
        problem.k = kernel_volume * problem.padded_channels;

        problem
    }
}
/// Builds the launch-time input argument of a matmul argument family for
/// convolution routine `A`.
pub trait ConcreteInputsFactory<A: Routine<RuntimeArgs>>: LaunchArg {
// NOTE(review): `create` currently takes six parameters; confirm whether this
// allow is still required or is left over from an earlier, longer signature.
#[allow(clippy::too_many_arguments)]
/// Creates the input launch argument together with the convolution runtime
/// arguments.
///
/// * `lhs`, `rhs` - bindings for the two matmul operands.
/// * `bias` - optional bias binding; `None` when the convolution has no bias.
/// * `blueprint` - blueprint of the routine being launched.
/// * `problem` - convolution problem the arguments are built for.
/// * `dtypes` - element types of the matmul.
///
/// Returns the launch argument for `Self` and the matching
/// [`RuntimeArgsLaunch`].
fn create<R: Runtime>(
lhs: InputBinding<R>,
rhs: InputBinding<R>,
bias: Option<InputBinding<R>>,
blueprint: &A::Blueprint,
problem: &ConvolutionProblem,
dtypes: &MatmulElems,
) -> (Self::RuntimeArg<R>, RuntimeArgsLaunch<R>);
}
/// Builds the launch-time output argument of a matmul argument family for
/// convolution routine `A`.
pub trait ConcreteOutputFactory<A: Routine<RuntimeArgs>>: LaunchArg {
/// Creates the output launch argument.
///
/// * `out` - binding of the tensor the result is written to.
/// * `blueprint` - blueprint of the routine being launched.
/// * `problem` - convolution problem the argument is built for.
/// * `dtypes` - element types of the matmul.
fn create<R: Runtime>(
out: TensorBinding<R>,
blueprint: &A::Blueprint,
problem: &ConvolutionProblem,
dtypes: &MatmulElems,
) -> Self::RuntimeArg<R>;
}
impl<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive, A: Routine<RuntimeArgs>>
    ConcreteInputsFactory<A> for TensorInputs<Lhs, Rhs, EO>
{
    /// Builds plain-tensor inputs for the convolution.
    ///
    /// The lhs is viewed through an NHWC layout chained with an im2col layout
    /// and the rhs through an NHWC layout chained with a weight layout. Bounds
    /// checks on the NHWC layouts are enabled only when `problem` reports they
    /// are needed. Every batch layout is a no-op layout. `_dtypes` is unused.
    fn create<R: Runtime>(
        lhs: InputBinding<R>,
        rhs: InputBinding<R>,
        bias: Option<InputBinding<R>>,
        blueprint: &A::Blueprint,
        problem: &ConvolutionProblem,
        _dtypes: &MatmulElems,
    ) -> (Self::RuntimeArg<R>, RuntimeArgsLaunch<R>) {
        type LhsLayout = Chain<NhwcLayout, Im2colLayout>;
        type RhsLayout = Chain<NhwcLayout, WeightLayout>;

        let conv_params = ConvolutionParams::from_problem(problem);

        // Matrix-side layouts: im2col for the feature map, weight layout for the filter.
        let im2col =
            Im2colLayoutLaunch::from_args(problem, conv_params, blueprint.lhs_global_layout_config());
        let weight = WeightLayoutLaunch::from_args(problem, blueprint.rhs_global_layout_config());

        // The feature map may need both spatial and channel checks.
        let mut lhs_checks = EnumSet::empty();
        if problem.should_check_spatial_bounds() {
            lhs_checks.insert(NhwcCheck::Spatial);
        }
        if problem.should_check_channel() {
            lhs_checks.insert(NhwcCheck::Channel);
        }
        let lhs_chain = ChainLaunch::new(NhwcLayoutLaunch::checked(lhs_checks), im2col);

        // The weights only ever need the channel check.
        let mut rhs_checks = EnumSet::empty();
        if problem.should_check_channel() {
            rhs_checks.insert(NhwcCheck::Channel);
        }
        let rhs_chain = ChainLaunch::new(NhwcLayoutLaunch::checked(rhs_checks), weight);

        // The bias batch layout exists exactly when a bias was supplied.
        let bias_batch = bias
            .is_some()
            .then(|| VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()));
        let bias_view = bias.map(|binding| {
            ViewArg::new_tensor::<BiasLayout>(binding.into_data().into_tensor_arg(), ())
        });

        let inputs = TensorInputsLaunch::new(
            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
            ViewArg::new_tensor::<LhsLayout>(lhs.into_data().into_tensor_arg(), lhs_chain),
            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
            ViewArg::new_tensor::<RhsLayout>(rhs.into_data().into_tensor_arg(), rhs_chain),
            bias_batch.into(),
            bias_view.into(),
        );

        let runtime_args = RuntimeArgsLaunch::new(
            problem.k as u32,
            problem.channels as u32,
            problem.padded_channels as u32,
            conv_params.operation,
        );

        (inputs, runtime_args)
    }
}
impl<EG: CubePrimitive, A: Routine<RuntimeArgs>> ConcreteOutputFactory<A> for TensorOutput<EG> {
    /// Builds the plain-tensor output argument.
    ///
    /// The output is viewed through an unchecked NHWC layout chained with the
    /// output layout derived from `problem` and the blueprint's output layout
    /// config. The batch layout is a no-op layout. `_dtypes` is unused.
    fn create<R: Runtime>(
        out: TensorBinding<R>,
        blueprint: &A::Blueprint,
        problem: &ConvolutionProblem,
        _dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<R> {
        type Layout = Chain<NhwcLayout, OutLayout>;

        let chained = ChainLaunch::new(
            NhwcLayoutLaunch::unchecked(),
            OutLayoutLaunch::from_args(problem, blueprint.out_global_layout_config()),
        );

        TensorOutputLaunch::new(
            ViewArg::new_tensor::<Layout>(out.into_tensor_arg(), chained),
            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
        )
    }
}
impl<
    Lhs: CubePrimitive,
    Rhs: CubePrimitive,
    EO: CubePrimitive,
    A: Routine<RuntimeArgs, Blueprint = TilingBlueprint>,
> ConcreteInputsFactory<A> for TensorMapInputs<Lhs, Rhs, EO>
{
    /// Builds tensor-map inputs for the convolution.
    ///
    /// The lhs becomes an im2col tensor map and the rhs a tiled tensor map; the
    /// optional bias stays a plain tensor view. `lhs`, `rhs` and `bias` are
    /// consumed: their bindings are moved into the created arguments.
    ///
    /// Returns the input launch argument together with the runtime arguments
    /// (`k`, channel count, padded channel count and operation) taken from
    /// `problem`.
    fn create<R: Runtime>(
        lhs: InputBinding<R>,
        rhs: InputBinding<R>,
        bias: Option<InputBinding<R>>,
        blueprint: &TilingBlueprint,
        problem: &ConvolutionProblem,
        dtypes: &MatmulElems,
    ) -> (Self::RuntimeArg<R>, RuntimeArgsLaunch<R>) {
        let tiling_scheme = blueprint.tiling_scheme;
        let stage_m = tiling_scheme.elements_per_stage_along_m();
        let stage_n = tiling_scheme.elements_per_stage_along_n();

        // Without lhs swizzling one tile along `k` is loaded at a time; with any
        // swizzle mode the whole stage along `k` is used. This mirrors the
        // channel alignment chosen in `TensorMapArgs::adjust_problem`.
        let tile_size_k = match blueprint.swizzle_modes.lhs {
            SwizzleMode::None => tiling_scheme.tile_size.k,
            _ => tiling_scheme.elements_per_stage_along_k(),
        };

        // Rhs tile: `stage_n` on the first axis, 1 on each spatial axis and
        // `tile_size_k` on the last axis.
        let mut stage_size_rhs = shape![1; problem.dimensionality.num_dims()];
        stage_size_rhs.insert(0, stage_n as usize);
        stage_size_rhs.push(tile_size_k as usize);

        let lhs_elem = remap_storage_for_tma(dtypes.lhs_stage);

        // One element stride per lhs axis: the first and last stay 1, the axes
        // in between take the convolution stride.
        let mut elem_stride = strides![1; 2 + problem.stride.len()];
        for (i, stride) in problem.stride.iter().enumerate() {
            elem_stride[i + 1] = *stride as usize;
        }

        // `lhs` is owned and never used again, so it is moved rather than
        // cloned (same as the plain-tensor factory).
        let lhs = TensorMapArg::new(
            Im2colArgs {
                pixel_box_lower_corner: calculate_lower_corner(&problem.padding),
                pixel_box_upper_corner: calculate_upper_corner(
                    &problem.padding,
                    &problem.kernel_size,
                    &problem.dilation,
                ),
                channels_per_pixel: tile_size_k,
                pixels_per_column: stage_m,
            },
            lhs.into_data().into_tensor_arg(),
            lhs_elem,
        )
        .with_elem_stride(elem_stride)
        .with_swizzle(blueprint.swizzle_modes.lhs.into());

        // Same for `rhs`: moved, not cloned.
        let rhs = TensorMapArg::new(
            TiledArgs {
                tile_size: stage_size_rhs,
            },
            rhs.into_data().into_tensor_arg(),
            dtypes.rhs_global,
        )
        .with_swizzle(blueprint.swizzle_modes.rhs.into());

        let padded_channels = problem.padded_channels as u32;
        let shape_k = problem.k as u32;

        // The kernel check is only requested when `k` is not a whole number of
        // (stage along k) x (number of lhs stages).
        let stages_lhs = A::num_stages().lhs;
        let stages_size_k = tiling_scheme.elements_per_stage_along_k() * stages_lhs;
        let check_kernel = !shape_k.is_multiple_of(stages_size_k);

        let lhs_layout = TmaIm2colLayoutLaunch::from_args(problem, check_kernel);
        let rhs_layout = WeightLayoutLaunch::from_args(
            problem,
            GlobalLayoutConfig {
                check_row_bounds: false,
                check_col_bounds: false,
                matrix_layout: MatrixLayout::ColMajor,
            },
        );

        let bias = bias
            .map(|bias| ViewArg::new_tensor::<BiasLayout>(bias.into_data().into_tensor_arg(), ()));

        let inputs = TensorMapInputsLaunch::new(
            ViewArg::new_tensor_map_im2col::<TmaIm2colLayout, _, _>(lhs, lhs_layout),
            ViewArg::new_tensor_map_tiled::<WeightLayout>(rhs, rhs_layout),
            bias.into(),
            ComptimeOptionArgs::Some(VirtualLayoutLaunch::new::<NoopLayout>(
                NoopLayoutLaunch::new(),
            )),
        );

        let runtime_args = RuntimeArgsLaunch::new(
            shape_k,
            problem.channels as u32,
            padded_channels,
            problem.operation,
        );

        (inputs, runtime_args)
    }
}
/// Returns the lower corner of the im2col pixel box: the negated padding of
/// each axis, in the same order as `padding`.
fn calculate_lower_corner(padding: &[i32]) -> Vec<i32> {
    let mut lower = Vec::with_capacity(padding.len());
    lower.extend(padding.iter().map(|&pad| -pad));
    lower
}
/// Returns the upper corner of the im2col pixel box.
///
/// For each axis the value is `padding - (kernel_size - 1) * dilation`. Only
/// as many axes as the shortest of the three slices are produced. The
/// `kernel_size - 1` subtraction is done in `u32`, so a kernel extent of 0
/// underflows.
fn calculate_upper_corner(padding: &[i32], kernel_size: &[u32], dilation: &[u32]) -> Vec<i32> {
    let mut upper = Vec::with_capacity(padding.len());
    for (&pad, (&kernel, &dil)) in padding.iter().zip(kernel_size.iter().zip(dilation)) {
        // Distance from the first to the last kernel tap along this axis.
        let kernel_reach = (kernel - 1) as i32 * dil as i32;
        upper.push(pad - kernel_reach);
    }
    upper
}