cubecl_convolution/components/global/layout/
im2col.rs1use cubecl::prelude::*;
2use cubecl_core::{self as cubecl};
3use cubecl_matmul::components::{
4 MatmulIdent,
5 global::{GlobalConfig, memory::GlobalMemoryConfig},
6};
7use cubecl_std::{
8 FastDivmod, FastDivmodArgs,
9 tensor::layout::{Coords3d, Layout, LayoutExpand},
10};
11
12use crate::{
13 components::{
14 ConvGemmConfig, ConvolutionConfig, ConvolutionParams, ConvolutionProblem,
15 global::{layout::NhwcCoords, read::im2col_tma::div_mod_seq},
16 },
17 kernels::layered::selector::RuntimeArgs,
18};
19
20#[derive(CubeType, CubeLaunch, Clone)]
24pub struct Im2colLayout {
25 pub shape_out: Sequence<FastDivmod>,
27 pub shape_channel: FastDivmod,
29
30 pub shape_m: u32,
32 pub shape_k: u32,
34
35 #[cube(comptime)]
37 pub params: ConvolutionParams,
38 #[cube(comptime)]
40 pub config: GlobalMemoryConfig,
41}
42
43#[cube]
44impl Im2colLayout {
45 pub fn new<G: GlobalConfig>(
46 args: &RuntimeArgs,
47 #[comptime] config: ConvolutionConfig<G>,
48 ) -> Im2colLayout {
49 let shape_out = args.shape_out.clone();
50
51 Im2colLayout {
52 shape_out,
53 shape_channel: args.shape_channel,
54 shape_m: args.shape_m,
55 shape_k: args.shape_k,
56 params: config.convolution_params(),
57 config: config.global_memory_config(MatmulIdent::Lhs),
58 }
59 }
60}
61
62#[cube]
63impl Layout for Im2colLayout {
64 type Coordinates = Coords3d;
65 type SourceCoordinates = NhwcCoords;
66
67 fn to_source_pos(&self, pos: Self::Coordinates) -> NhwcCoords {
68 let params = comptime![self.params];
69 let (_, view_m, view_k) = pos;
70
71 let (batch, out_offs) = div_mod_seq(view_m, &self.shape_out);
72
73 let (mut rem, channel) = self.shape_channel.div_mod(view_k);
74
75 let spatial_dims = comptime![self.shape_out.len()];
76 let mut in_pos = Sequence::<i32>::new();
77
78 #[unroll]
79 for i in 0..spatial_dims {
80 let dim = comptime![spatial_dims - i - 1];
81 let ksize = comptime![params.kernel_size[dim as usize]];
82 let k_pos = rem % ksize;
83 rem /= ksize;
84
85 let out_pos = *out_offs.index(dim);
86 let stride = comptime![params.stride[dim as usize]];
87 let dilate = comptime![params.dilation[dim as usize]];
88 let pad = comptime![params.padding[dim as usize]];
89
90 let pos = (out_pos * stride + k_pos * dilate) as i32 - pad;
91 in_pos.push(pos);
92 }
93
94 let in_pos = in_pos.rev();
95
96 NhwcCoords {
97 batch,
98 spatial: in_pos,
99 channel,
100 }
101 }
102
103 fn shape(&self) -> Self::Coordinates {
104 (1, self.shape_m, self.shape_k)
105 }
106
107 fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (NhwcCoords, bool) {
108 (self.to_source_pos(pos), self.is_in_bounds(pos))
109 }
110
111 fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
112 let (_, view_m, view_k) = pos;
113 let m_in_bounds = comptime!(!self.config.check_row_bounds) || view_m < self.shape_m;
115 let k_in_bounds = comptime!(!self.config.check_col_bounds) || view_k < self.shape_k;
116 m_in_bounds && k_in_bounds
117 }
118}
119
120impl<'a, R: Runtime> Im2colLayoutLaunch<'a, R> {
121 pub fn from_args(
122 client: &ComputeClient<R::Server>,
123 problem: &ConvolutionProblem,
124 params: ConvolutionParams,
125 config: GlobalMemoryConfig,
126 ) -> Self {
127 let shape_out = problem
128 .out_shape
129 .iter()
130 .map(|s| FastDivmodArgs::new(client, *s as u32))
131 .collect();
132 let shape_channel = FastDivmodArgs::new(client, problem.channels as u32);
133
134 let shape_m = ScalarArg::new(problem.m as u32);
135 let shape_k = ScalarArg::new(problem.k as u32);
136
137 Im2colLayoutLaunch::new(shape_out, shape_channel, shape_m, shape_k, params, config)
138 }
139}