use cubecl::prelude::*;
use cubecl_core::{self as cubecl};
use cubecl_matmul::components::{
MatmulIdent,
global::{GlobalConfig, memory::GlobalMemoryConfig},
};
use cubecl_std::{
FastDivmod, FastDivmodArgs,
tensor::layout::{Coords3d, Layout, LayoutExpand},
};
use crate::{
components::{
ConvGemmConfig, ConvolutionConfig, ConvolutionParams, ConvolutionProblem,
global::{layout::NhwcCoords, read::im2col_tma::div_mod_seq},
},
kernels::layered::selector::RuntimeArgs,
};
#[derive(CubeType, CubeLaunch, Clone)]
pub struct Im2colLayout {
pub shape_out: Sequence<FastDivmod>,
pub shape_channel: FastDivmod,
pub shape_m: u32,
pub shape_k: u32,
#[cube(comptime)]
pub params: ConvolutionParams,
#[cube(comptime)]
pub config: GlobalMemoryConfig,
}
#[cube]
impl Im2colLayout {
pub fn new<G: GlobalConfig>(
args: &RuntimeArgs,
#[comptime] config: ConvolutionConfig<G>,
) -> Im2colLayout {
let shape_out = args.shape_out.clone();
Im2colLayout {
shape_out,
shape_channel: args.shape_channel,
shape_m: args.shape_m,
shape_k: args.shape_k,
params: config.convolution_params(),
config: config.global_memory_config(MatmulIdent::Lhs),
}
}
}
#[cube]
impl Layout for Im2colLayout {
type Coordinates = Coords3d;
type SourceCoordinates = NhwcCoords;
fn to_source_pos(&self, pos: Self::Coordinates) -> NhwcCoords {
let params = comptime![self.params];
let (_, view_m, view_k) = pos;
let (batch, out_offs) = div_mod_seq(view_m, &self.shape_out);
let (mut rem, channel) = self.shape_channel.div_mod(view_k);
let spatial_dims = comptime![self.shape_out.len()];
let mut in_pos = Sequence::<i32>::new();
#[unroll]
for i in 0..spatial_dims {
let dim = comptime![spatial_dims - i - 1];
let ksize = comptime![params.kernel_size[dim as usize]];
let k_pos = rem % ksize;
rem /= ksize;
let out_pos = *out_offs.index(dim);
let stride = comptime![params.stride[dim as usize]];
let dilate = comptime![params.dilation[dim as usize]];
let pad = comptime![params.padding[dim as usize]];
let pos = (out_pos * stride + k_pos * dilate) as i32 - pad;
in_pos.push(pos);
}
let in_pos = in_pos.rev();
NhwcCoords {
batch,
spatial: in_pos,
channel,
}
}
fn shape(&self) -> Self::Coordinates {
(1, self.shape_m, self.shape_k)
}
fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (NhwcCoords, bool) {
(self.to_source_pos(pos), self.is_in_bounds(pos))
}
fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
let (_, view_m, view_k) = pos;
let m_in_bounds = comptime!(!self.config.check_row_bounds) || view_m < self.shape_m;
let k_in_bounds = comptime!(!self.config.check_col_bounds) || view_k < self.shape_k;
m_in_bounds && k_in_bounds
}
}
impl<'a, R: Runtime> Im2colLayoutLaunch<'a, R> {
pub fn from_args(
client: &ComputeClient<R::Server>,
problem: &ConvolutionProblem,
params: ConvolutionParams,
config: GlobalMemoryConfig,
) -> Self {
let shape_out = problem
.out_shape
.iter()
.map(|s| FastDivmodArgs::new(client, *s as u32))
.collect();
let shape_channel = FastDivmodArgs::new(client, problem.channels as u32);
let shape_m = ScalarArg::new(problem.m as u32);
let shape_k = ScalarArg::new(problem.k as u32);
Im2colLayoutLaunch::new(shape_out, shape_channel, shape_m, shape_k, params, config)
}
}