use cubecl::prelude::*;
use cubecl::std::{
FastDivmod,
tensor::layout::{Layout, LayoutExpand},
};
use cubek_matmul::{
components::global::{GlobalConfig, memory::GlobalLayoutConfig},
launch::BatchedCoords,
};
use crate::components::{
ConvolutionOperation, ConvolutionParams, ConvolutionProblem,
global::layout::{NhwcCoords, div_mod_seq},
};
/// Layout mapping coordinates of an im2col matrix view (`rows` x `cols`) to
/// [`NhwcCoords`] in the source tensor.
///
/// A view row is split into a batch index and per-dimension spatial offsets using
/// `shape_out`. A view column is split into a flattened kernel position and a channel
/// using `padded_channels`. The convolution parameters then turn these into source
/// spatial positions (see the `Layout` impl).
///
/// NOTE: the `CubeLaunch` derive generates `Im2colLayoutLaunch::new` from the field
/// order below. Reordering the fields changes that constructor's argument order.
#[derive(CubeType, CubeLaunch, Clone)]
pub struct Im2colLayout {
    /// Per-dimension divisors used to split a view row into a batch index and spatial
    /// offsets. `Im2colLayoutLaunch` fills this from `out_shape`, or from `in_shape`
    /// for transposed / data-gradient operations.
    pub shape_out: Sequence<FastDivmod<u32>>,
    /// Divisor splitting a view column into (flattened kernel position, channel).
    pub padded_channels: FastDivmod<u32>,
    /// Number of rows of the view, reported by `shape` and used by `is_in_bounds`.
    pub rows: u32,
    /// Number of columns of the view, reported by `shape` and used by `is_in_bounds`.
    pub cols: u32,
    /// Comptime convolution parameters (dimensionality, kernel size, stride, dilation,
    /// padding, operation) read by `to_source_pos`.
    #[cube(comptime)]
    pub params: ConvolutionParams,
    /// Comptime layout config. Its `check_row_bounds` / `check_col_bounds` flags gate
    /// the checks in `is_in_bounds`.
    #[cube(comptime)]
    pub config: GlobalLayoutConfig,
}
#[cube]
impl Im2colLayout {
    /// Assembles an [`Im2colLayout`] from precomputed view extents and divisors.
    ///
    /// # Arguments
    ///
    /// * `rows`, `cols` - extents of the im2col view.
    /// * `padded_channels` - divisor separating kernel position from channel in a column.
    /// * `shape_out` - per-dimension divisors used to split a row index.
    /// * `config` - comptime layout config (bounds-check flags).
    /// * `params` - comptime convolution parameters.
    ///
    /// The generic `G` is not referenced by the body.
    pub fn new<G: GlobalConfig>(
        rows: u32,
        cols: u32,
        padded_channels: FastDivmod<u32>,
        shape_out: Sequence<FastDivmod<u32>>,
        #[comptime] config: GlobalLayoutConfig,
        #[comptime] params: ConvolutionParams,
    ) -> Im2colLayout {
        // Field order in the literal mirrors the parameter order; the struct's
        // declared field order is unaffected.
        Im2colLayout {
            rows,
            cols,
            padded_channels,
            shape_out,
            config,
            params,
        }
    }
}
#[cube]
impl Layout for Im2colLayout {
    type Coordinates = BatchedCoords;
    type SourceCoordinates = NhwcCoords;

    /// Maps an im2col view coordinate `(_, row, col)` to an NHWC source coordinate.
    ///
    /// The leading batch component of `pos` is ignored. The row is split into a batch
    /// index and per-dimension spatial offsets using `shape_out`. The column is split
    /// into a flattened kernel position and a channel using `padded_channels`. The
    /// returned spatial positions are signed `i32` and are not clamped, so they can
    /// fall outside the source tensor (e.g. inside padding). This method does not
    /// check them.
    fn to_source_pos(&self, pos: Self::Coordinates) -> NhwcCoords {
        let params = self.params.comptime();
        let (_, view_m, view_k) = pos;
        // NOTE(review): assumes `div_mod_seq` decomposes `view_m` as a mixed-radix index
        // over `shape_out`, returning the leftover outer quotient as the batch — confirm.
        let (batch, out_offs) = div_mod_seq(view_m, &self.shape_out);
        // Quotient: flattened kernel position; remainder: channel.
        let (mut rem, channel) = self.padded_channels.div_mod(view_k);
        let spatial_dims = params.dimensionality.num_dims();
        let mut in_pos = Sequence::<i32>::new();
        // Visit spatial dims from last to first, peeling one kernel coordinate per dim
        // off `rem`. The last dim's kernel coordinate is therefore the fastest-varying.
        #[unroll]
        for i in 0..spatial_dims {
            let dim = spatial_dims - i - 1;
            let ksize = params.kernel_size[dim];
            let k_pos = (rem % ksize) as i32;
            rem /= ksize;
            let out_pos = out_offs[dim];
            let stride = params.stride[dim] as i32;
            let dilate = params.dilation[dim] as i32;
            let pad = params.padding[dim];
            let pos = match params.operation {
                // Direct mapping: in = out * stride + k * dilation - padding.
                ConvolutionOperation::Forward | ConvolutionOperation::BackwardWeight => {
                    (out_pos as i32 * stride + k_pos * dilate) - pad
                }
                // Inverse mapping: in = (out + padding - k * dilation) / stride.
                // NOTE(review): i32 division truncates toward zero. The result is exact
                // only when the numerator is a multiple of `stride` (always true for
                // stride 1). Otherwise it is truncated (e.g. -1 / 2 == 0) instead of
                // being rejected. Confirm that callers restrict stride or filter these
                // positions elsewhere.
                ConvolutionOperation::ForwardTransposed | ConvolutionOperation::BackwardData => {
                    (out_pos as i32 + pad - k_pos * dilate) / stride
                }
            };
            in_pos.push(pos);
        }
        // Positions were pushed last-dim-first; reverse so index d is spatial dim d.
        let in_pos = in_pos.rev();
        NhwcCoords {
            batch,
            spatial: in_pos,
            channel,
        }
    }

    /// Extent of the view: a single batch of `rows` x `cols`.
    fn shape(&self) -> Self::Coordinates {
        (1, self.rows, self.cols)
    }

    /// Returns the mapped source coordinate together with the view-level bounds result
    /// from `is_in_bounds`. The source spatial position itself is not checked.
    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (NhwcCoords, bool) {
        (self.to_source_pos(pos), self.is_in_bounds(pos))
    }

    /// Returns whether `pos` lies inside the `rows` x `cols` view.
    ///
    /// The row axis is only checked when `config.check_row_bounds` is set, and the
    /// column axis only when `config.check_col_bounds` is set. An unchecked axis is
    /// treated as in bounds. The batch component of `pos` is ignored.
    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        let (_, view_m, view_k) = pos;
        let m_in_bounds = !self.config.check_row_bounds || view_m < self.rows;
        let k_in_bounds = !self.config.check_col_bounds || view_k < self.cols;
        m_in_bounds && k_in_bounds
    }
}
impl<R: Runtime> Im2colLayoutLaunch<R> {
    /// Builds the launch arguments of an [`Im2colLayout`] for `problem`.
    ///
    /// `problem.operation` determines which spatial shape splits the view rows and
    /// which problem extents become the view's rows and columns:
    ///
    /// | operation                           | rows split over | rows | cols |
    /// |-------------------------------------|-----------------|------|------|
    /// | `Forward`                           | `out_shape`     | `m`  | `k`  |
    /// | `ForwardTransposed`, `BackwardData` | `in_shape`      | `m`  | `k`  |
    /// | `BackwardWeight`                    | `out_shape`     | `k`  | `n`  |
    ///
    /// # Panics
    ///
    /// Panics if a spatial extent, `padded_channels`, or a view extent does not fit in
    /// `u32`. The layout stores these as `u32`, so truncating them would silently
    /// produce wrong coordinates.
    pub fn from_args(
        problem: &ConvolutionProblem,
        params: ConvolutionParams,
        config: GlobalLayoutConfig,
    ) -> Self {
        match problem.operation {
            ConvolutionOperation::Forward => Self::from_args_fprop(problem, params, config),
            ConvolutionOperation::ForwardTransposed | ConvolutionOperation::BackwardData => {
                Self::from_args_dgrad(problem, params, config)
            }
            ConvolutionOperation::BackwardWeight => Self::from_args_wgrad(problem, params, config),
        }
    }

    /// Forward convolution: rows (`m`) are split over `out_shape`; columns are `k`.
    fn from_args_fprop(
        problem: &ConvolutionProblem,
        params: ConvolutionParams,
        config: GlobalLayoutConfig,
    ) -> Self {
        let shape_out = problem
            .out_shape
            .iter()
            .map(|&s| Self::dim_u32(s, "out_shape"))
            .collect();
        let padded_channels = Self::dim_u32(problem.padded_channels, "padded_channels");
        let shape_m = Self::dim_u32(problem.m, "m");
        let shape_k = Self::dim_u32(problem.k, "k");
        Im2colLayoutLaunch::new(shape_out, padded_channels, shape_m, shape_k, params, config)
    }

    /// Transposed / data-gradient convolution: rows (`m`) are split over `in_shape`
    /// rather than `out_shape`; columns are `k`.
    fn from_args_dgrad(
        problem: &ConvolutionProblem,
        params: ConvolutionParams,
        config: GlobalLayoutConfig,
    ) -> Self {
        let shape = problem
            .in_shape
            .iter()
            .map(|&s| Self::dim_u32(s, "in_shape"))
            .collect();
        let padded_channels = Self::dim_u32(problem.padded_channels, "padded_channels");
        let shape_m = Self::dim_u32(problem.m, "m");
        let shape_k = Self::dim_u32(problem.k, "k");
        Im2colLayoutLaunch::new(shape, padded_channels, shape_m, shape_k, params, config)
    }

    /// Weight-gradient convolution: the view has `k` rows (split over `out_shape`) and
    /// `n` columns.
    fn from_args_wgrad(
        problem: &ConvolutionProblem,
        params: ConvolutionParams,
        config: GlobalLayoutConfig,
    ) -> Self {
        let shape_out = problem
            .out_shape
            .iter()
            .map(|&s| Self::dim_u32(s, "out_shape"))
            .collect();
        let padded_channels = Self::dim_u32(problem.padded_channels, "padded_channels");
        let shape_k = Self::dim_u32(problem.k, "k");
        let shape_n = Self::dim_u32(problem.n, "n");
        Im2colLayoutLaunch::new(shape_out, padded_channels, shape_k, shape_n, params, config)
    }

    /// Converts a host-side problem extent to the `u32` stored in the layout.
    ///
    /// # Panics
    ///
    /// Panics with a message naming `what` if `value` does not fit in `u32`. This
    /// replaces an `as u32` cast, which would silently truncate.
    fn dim_u32<T>(value: T, what: &str) -> u32
    where
        T: TryInto<u32> + Copy + core::fmt::Display,
    {
        match value.try_into() {
            Ok(converted) => converted,
            Err(_) => panic!("Im2colLayout: `{what}` = {value} does not fit in u32"),
        }
    }
}