use cubecl::prelude::*;
use cubecl_core::{self as cubecl};
use cubecl_matmul::components::global::memory::GlobalMemoryConfig;
use cubecl_std::{
FastDivmod, FastDivmodArgs,
tensor::layout::{Coords3d, Layout, LayoutExpand},
};
use crate::{
components::{
ConvolutionProblem,
global::{
layout::{NhwcCoords, cast_seq},
read::im2col_tma::div_mod_seq,
},
},
kernels::layered::selector::RuntimeArgs,
};
/// Layout that maps a 3D `(_, m, n)` view of the convolution output onto [`NhwcCoords`].
///
/// The `m` coordinate is decomposed with `shape_out` into a batch index plus spatial
/// coordinates. The `n` coordinate is used as the channel.
#[derive(CubeType, CubeLaunch, Clone)]
pub struct OutLayout {
    // One `FastDivmod` per entry of `ConvolutionProblem::out_shape` (see `from_args`).
    // Used by `div_mod_seq` to split `m` into (batch, spatial).
    // NOTE(review): presumably the spatial output dims, with batch as the leftover
    // quotient. Confirm against `div_mod_seq`.
    pub shape_out: Sequence<FastDivmod>,
    // Extent of the `m` (row) dimension, taken from `ConvolutionProblem::m`.
    pub shape_m: u32,
    // Extent of the `n` (column / channel) dimension, taken from `ConvolutionProblem::n`.
    pub shape_n: u32,
    // Comptime config. Its `check_row_bounds` / `check_col_bounds` flags select which
    // bounds checks `is_in_bounds` performs.
    // Keep the field order as is: `from_args` passes arguments to the generated
    // `OutLayoutLaunch::new` in this same order.
    #[cube(comptime)]
    pub config: GlobalMemoryConfig,
}
#[cube]
impl OutLayout {
    /// Creates an `OutLayout` from the kernel's runtime arguments.
    ///
    /// The output-shape divisors and the `m` / `n` extents are copied out of `args`.
    /// The comptime `config` is attached as-is.
    pub fn new(args: &RuntimeArgs, #[comptime] config: GlobalMemoryConfig) -> OutLayout {
        let shape_out = args.shape_out.clone();
        let shape_m = args.shape_m;
        let shape_n = args.shape_n;

        OutLayout {
            shape_out,
            shape_m,
            shape_n,
            config,
        }
    }
}
#[cube]
impl Layout for OutLayout {
    type Coordinates = Coords3d;
    type SourceCoordinates = NhwcCoords;

    /// Converts a `(_, m, n)` view position into NHWC source coordinates without bounds
    /// checking.
    ///
    /// The leading coordinate is ignored. `m` is split by `div_mod_seq` over `shape_out`
    /// into a batch index and spatial coordinates, which are converted with `cast_seq`.
    /// `n` becomes the channel.
    fn to_source_pos(&self, coords: Self::Coordinates) -> NhwcCoords {
        let (_, row, col) = coords;
        let (batch, spatial) = div_mod_seq(row, &self.shape_out);

        NhwcCoords {
            batch,
            spatial: cast_seq(spatial),
            channel: col,
        }
    }

    /// Returns the unchecked source position together with `is_in_bounds` for the same
    /// view position.
    fn to_source_pos_checked(&self, coords: Self::Coordinates) -> (NhwcCoords, bool) {
        let source = self.to_source_pos(coords);
        let in_bounds = self.is_in_bounds(coords);
        (source, in_bounds)
    }

    /// View shape: a unit leading dimension followed by the `m` and `n` extents.
    fn shape(&self) -> Self::Coordinates {
        (1, self.shape_m, self.shape_n)
    }

    /// Reports whether `(_, m, n)` lies inside the view.
    ///
    /// Each axis is checked only when the matching comptime flag in `config` is set.
    /// An unchecked axis always counts as in bounds. The leading coordinate is never
    /// checked.
    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        let (_, row, col) = pos;
        let check_row = comptime![self.config.check_row_bounds];
        let check_col = comptime![self.config.check_col_bounds];

        let row_ok = !check_row || row < self.shape_m;
        let col_ok = !check_col || col < self.shape_n;
        row_ok && col_ok
    }
}
impl<'a, R: Runtime> OutLayoutLaunch<'a, R> {
    /// Builds the launch arguments for [`OutLayout`] from a convolution problem.
    ///
    /// - `shape_out` receives one `FastDivmodArgs` per entry of `problem.out_shape`.
    /// - `shape_m` / `shape_n` come from `problem.m` / `problem.n`.
    /// - `config` is forwarded unchanged.
    ///
    /// Every dimension is narrowed with an `as u32` cast. That cast does not error if a
    /// value does not fit in `u32`.
    pub fn from_args(
        client: &ComputeClient<R::Server>,
        problem: &ConvolutionProblem,
        config: GlobalMemoryConfig,
    ) -> Self {
        let divisors = problem
            .out_shape
            .iter()
            .map(|&dim| FastDivmodArgs::new(client, dim as u32))
            .collect();

        Self::new(
            divisors,
            ScalarArg::new(problem.m as u32),
            ScalarArg::new(problem.n as u32),
            config,
        )
    }
}