// cubecl_convolution/components/global/layout/im2col.rs
use cubecl::prelude::*;
use cubecl_core::{self as cubecl};
use cubecl_matmul::components::global::{GlobalConfig, memory::GlobalMemoryConfig};
use cubecl_std::{
    FastDivmod, FastDivmodArgs,
    tensor::layout::{Coords3d, Layout, LayoutExpand},
};

use crate::{
    components::{
        ConvGemmConfig, ConvolutionConfig, ConvolutionParams, ConvolutionProblem,
        global::{layout::NhwcCoords, read::im2col_tma::div_mod_seq},
    },
    kernels::layered::selector::RuntimeArgs,
};
16
17#[derive(CubeType, CubeLaunch, Clone)]
21pub struct Im2colLayout {
22 pub shape_out: Sequence<FastDivmod>,
24 pub shape_channel: FastDivmod,
26
27 pub shape_m: u32,
29 pub shape_k: u32,
31
32 #[cube(comptime)]
34 pub params: ConvolutionParams,
35 #[cube(comptime)]
37 pub config: GlobalMemoryConfig,
38}
39
40#[cube]
41impl Im2colLayout {
42 pub fn new<G: GlobalConfig>(
43 args: &RuntimeArgs,
44 #[comptime] config: ConvolutionConfig<G>,
45 ) -> Im2colLayout {
46 let shape_out = args.shape_out.clone();
47
48 Im2colLayout {
49 shape_out,
50 shape_channel: args.shape_channel,
51 shape_m: args.shape_m,
52 shape_k: args.shape_k,
53 params: config.convolution_params,
54 config: config.lhs_global_memory_config(),
55 }
56 }
57}
58
59#[cube]
60impl Layout for Im2colLayout {
61 type Coordinates = Coords3d;
62 type SourceCoordinates = NhwcCoords;
63
64 fn to_source_pos(&self, pos: Self::Coordinates) -> NhwcCoords {
65 let params = comptime![self.params];
66 let (_, view_m, view_k) = pos;
67
68 let (batch, out_offs) = div_mod_seq(view_m, &self.shape_out);
69
70 let (mut rem, channel) = self.shape_channel.div_mod(view_k);
71
72 let spatial_dims = comptime![self.shape_out.len()];
73 let mut in_pos = Sequence::<i32>::new();
74
75 #[unroll]
76 for i in 0..spatial_dims {
77 let dim = comptime![spatial_dims - i - 1];
78 let ksize = comptime![params.kernel_size[dim as usize]];
79 let k_pos = rem % ksize;
80 rem /= ksize;
81
82 let out_pos = *out_offs.index(dim);
83 let stride = comptime![params.stride[dim as usize]];
84 let dilate = comptime![params.dilation[dim as usize]];
85 let pad = comptime![params.padding[dim as usize]];
86
87 let pos = (out_pos * stride + k_pos * dilate) as i32 - pad;
88 in_pos.push(pos);
89 }
90
91 let in_pos = in_pos.rev();
92
93 NhwcCoords {
94 batch,
95 spatial: in_pos,
96 channel,
97 }
98 }
99
100 fn shape(&self) -> Self::Coordinates {
101 (1, self.shape_m, self.shape_k)
102 }
103
104 fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (NhwcCoords, bool) {
105 (self.to_source_pos(pos), self.is_in_bounds(pos))
106 }
107
108 fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
109 let (_, view_m, view_k) = pos;
110 let m_in_bounds = comptime!(!self.config.check_row_bounds) || view_m < self.shape_m;
112 let k_in_bounds = comptime!(!self.config.check_col_bounds) || view_k < self.shape_k;
113 m_in_bounds && k_in_bounds
114 }
115}
116
117impl<'a, R: Runtime> Im2colLayoutLaunch<'a, R> {
118 pub fn from_args(
119 client: &ComputeClient<R::Server>,
120 problem: &ConvolutionProblem,
121 params: ConvolutionParams,
122 config: GlobalMemoryConfig,
123 ) -> Self {
124 let shape_out = problem
125 .out_shape
126 .iter()
127 .map(|s| FastDivmodArgs::new(client, *s as u32))
128 .collect();
129 let shape_channel = FastDivmodArgs::new(client, problem.channels as u32);
130
131 let shape_m = ScalarArg::new(problem.m as u32);
132 let shape_k = ScalarArg::new(problem.k as u32);
133
134 Im2colLayoutLaunch::new(shape_out, shape_channel, shape_m, shape_k, params, config)
135 }
136}