// cubecl_convolution/components/global/layout/im2col.rs
use cubecl::prelude::*;
2use cubecl_core::{self as cubecl};
3use cubecl_matmul::components::{
4    MatmulIdent,
5    global::{GlobalConfig, memory::GlobalMemoryConfig},
6};
7use cubecl_std::{
8    FastDivmod, FastDivmodArgs,
9    tensor::layout::{Coords3d, Layout, LayoutExpand},
10};
11
12use crate::{
13    components::{
14        ConvGemmConfig, ConvolutionConfig, ConvolutionParams, ConvolutionProblem,
15        global::{layout::NhwcCoords, read::im2col_tma::div_mod_seq},
16    },
17    kernels::layered::selector::RuntimeArgs,
18};
19
20#[derive(CubeType, CubeLaunch, Clone)]
24pub struct Im2colLayout {
25    pub shape_out: Sequence<FastDivmod>,
27    pub shape_channel: FastDivmod,
29
30    pub shape_m: u32,
32    pub shape_k: u32,
34
35    #[cube(comptime)]
37    pub params: ConvolutionParams,
38    #[cube(comptime)]
40    pub config: GlobalMemoryConfig,
41}
42
43#[cube]
44impl Im2colLayout {
45    pub fn new<G: GlobalConfig>(
46        args: &RuntimeArgs,
47        #[comptime] config: ConvolutionConfig<G>,
48    ) -> Im2colLayout {
49        let shape_out = args.shape_out.clone();
50
51        Im2colLayout {
52            shape_out,
53            shape_channel: args.shape_channel,
54            shape_m: args.shape_m,
55            shape_k: args.shape_k,
56            params: config.convolution_params(),
57            config: config.global_memory_config(MatmulIdent::Lhs),
58        }
59    }
60}
61
62#[cube]
63impl Layout for Im2colLayout {
64    type Coordinates = Coords3d;
65    type SourceCoordinates = NhwcCoords;
66
67    fn to_source_pos(&self, pos: Self::Coordinates) -> NhwcCoords {
68        let params = comptime![self.params];
69        let (_, view_m, view_k) = pos;
70
71        let (batch, out_offs) = div_mod_seq(view_m, &self.shape_out);
72
73        let (mut rem, channel) = self.shape_channel.div_mod(view_k);
74
75        let spatial_dims = comptime![self.shape_out.len()];
76        let mut in_pos = Sequence::<i32>::new();
77
78        #[unroll]
79        for i in 0..spatial_dims {
80            let dim = comptime![spatial_dims - i - 1];
81            let ksize = comptime![params.kernel_size[dim as usize]];
82            let k_pos = rem % ksize;
83            rem /= ksize;
84
85            let out_pos = *out_offs.index(dim);
86            let stride = comptime![params.stride[dim as usize]];
87            let dilate = comptime![params.dilation[dim as usize]];
88            let pad = comptime![params.padding[dim as usize]];
89
90            let pos = (out_pos * stride + k_pos * dilate) as i32 - pad;
91            in_pos.push(pos);
92        }
93
94        let in_pos = in_pos.rev();
95
96        NhwcCoords {
97            batch,
98            spatial: in_pos,
99            channel,
100        }
101    }
102
103    fn shape(&self) -> Self::Coordinates {
104        (1, self.shape_m, self.shape_k)
105    }
106
107    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (NhwcCoords, bool) {
108        (self.to_source_pos(pos), self.is_in_bounds(pos))
109    }
110
111    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
112        let (_, view_m, view_k) = pos;
113        let m_in_bounds = comptime!(!self.config.check_row_bounds) || view_m < self.shape_m;
115        let k_in_bounds = comptime!(!self.config.check_col_bounds) || view_k < self.shape_k;
116        m_in_bounds && k_in_bounds
117    }
118}
119
120impl<'a, R: Runtime> Im2colLayoutLaunch<'a, R> {
121    pub fn from_args(
122        client: &ComputeClient<R::Server>,
123        problem: &ConvolutionProblem,
124        params: ConvolutionParams,
125        config: GlobalMemoryConfig,
126    ) -> Self {
127        let shape_out = problem
128            .out_shape
129            .iter()
130            .map(|s| FastDivmodArgs::new(client, *s as u32))
131            .collect();
132        let shape_channel = FastDivmodArgs::new(client, problem.channels as u32);
133
134        let shape_m = ScalarArg::new(problem.m as u32);
135        let shape_k = ScalarArg::new(problem.k as u32);
136
137        Im2colLayoutLaunch::new(shape_out, shape_channel, shape_m, shape_k, params, config)
138    }
139}