cubecl_convolution/components/global/layout/im2col.rs

use cubecl::prelude::*;
use cubecl_core::{self as cubecl};
use cubecl_matmul::components::{
    MatmulElems,
    global::{GlobalConfig, memory::GlobalMemoryConfig},
};
use cubecl_std::{
    FastDivmod, FastDivmodArgs,
    tensor::layout::{Coords3d, Layout, LayoutExpand},
};

use crate::components::{
    ConvGemmConfig, ConvolutionConfig, ConvolutionParams, ConvolutionProblem,
    global::{args::RuntimeArgs, layout::NhwcCoords, read::im2col_tma::div_mod_seq},
};

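/// Layout for the im2col (implicit GEMM) view of the convolution input.
/// Maps the `(_, m, k)` coordinates of the lhs matmul view onto [`NhwcCoords`]
/// positions in the NHWC input tensor.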
#[derive(CubeType, CubeLaunch, Clone)]
pub struct Im2colLayout {
    /// Output spatial shape, as fast divmod helpers
    pub shape_out: Sequence<FastDivmod>,
    /// Input channel count, padded for aligned loads
    pub padded_channels: FastDivmod,

    /// Number of rows (`m`) in the im2col view
    pub shape_m: u32,
    /// Number of columns (`k`) in the im2col view
    pub shape_k: u32,

    /// Comptime convolution parameters (kernel size, stride, dilation, padding)
    #[cube(comptime)]
    pub params: ConvolutionParams,
    /// Global memory config of the lhs tensor
    #[cube(comptime)]
    pub config: GlobalMemoryConfig,
}

#[cube]
impl Im2colLayout {
    pub fn new<G: GlobalConfig>(
        args: &RuntimeArgs,
        shape_out: Sequence<FastDivmod>,
        #[comptime] config: ConvolutionConfig<G>,
    ) -> Im2colLayout {
        Im2colLayout {
            shape_out,
            padded_channels: args.padded_channels,
            shape_m: args.shape_m,
            shape_k: args.shape_k,
            params: config.convolution_params,
            config: config.lhs_global_memory_config(),
        }
    }
}

#[cube]
impl Layout for Im2colLayout {
    type Coordinates = Coords3d;
    type SourceCoordinates = NhwcCoords;

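    /// Splits the row index into a batch index and output spatial position, and the
    /// column index into a kernel offset and input channel, then applies stride,
    /// dilation and padding to find the corresponding input position (which may fall
    /// outside the input when padding is used).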
    fn to_source_pos(&self, pos: Self::Coordinates) -> NhwcCoords {
        let params = comptime![self.params];
        let (_, view_m, view_k) = pos;

        // Split the row index into a batch index and an output spatial position
        let (batch, out_offs) = div_mod_seq(view_m, &self.shape_out);

        // Split the column index into a flattened kernel position and an input channel
        let (mut rem, channel) = self.padded_channels.div_mod(view_k);

        let spatial_dims = comptime![self.shape_out.len()];
        let mut in_pos = Sequence::<i32>::new();

        // Walk the spatial dimensions from innermost to outermost, peeling the kernel
        // offset for each dimension off `rem`
        #[unroll]
        for i in 0..spatial_dims {
            let dim = comptime![spatial_dims - i - 1];
            let ksize = comptime![params.kernel_size[dim as usize]];
            let k_pos = rem % ksize;
            rem /= ksize;

            let out_pos = *out_offs.index(dim);
            let stride = comptime![params.stride[dim as usize]];
            let dilate = comptime![params.dilation[dim as usize]];
            let pad = comptime![params.padding[dim as usize]];

            // May be negative or past the input shape when padding is applied
            let pos = (out_pos * stride + k_pos * dilate) as i32 - pad;
            in_pos.push(pos);
        }

        // Positions were pushed innermost-first, so reverse to the original dim order
        let in_pos = in_pos.rev();

        NhwcCoords {
            batch,
            spatial: in_pos,
            channel,
        }
    }

    fn shape(&self) -> Self::Coordinates {
        (1, self.shape_m, self.shape_k)
    }

    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (NhwcCoords, bool) {
        (self.to_source_pos(pos), self.is_in_bounds(pos))
    }

    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        let (_, view_m, view_k) = pos;
        let m_in_bounds = comptime!(!self.config.check_row_bounds) || view_m < self.shape_m;
        let k_in_bounds = comptime!(!self.config.check_col_bounds) || view_k < self.shape_k;
        m_in_bounds && k_in_bounds
    }
}

impl<'a, R: Runtime> Im2colLayoutLaunch<'a, R> {
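    /// Builds the launch arguments for [`Im2colLayout`] on the host: output shapes as
    /// [`FastDivmod`], the channel count padded for aligned loads, and the resulting
    /// `m`/`k` shape of the im2col view.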
    pub fn from_args(
        client: &ComputeClient<R>,
        problem: &ConvolutionProblem,
        params: ConvolutionParams,
        config: GlobalMemoryConfig,
        dtypes: &MatmulElems,
    ) -> Self {
        let shape_out = problem
            .out_shape
            .iter()
            .map(|s| FastDivmodArgs::new(client, *s as u32))
            .collect();

        // Pad the channel count to a multiple of the elements per hardware load,
        // so channel reads stay aligned
        let load_width = client.properties().hardware.load_width;
        let channel_align = load_width / dtypes.lhs_global.size_bits() as u32;
        let padded_channels = (problem.channels as u32).next_multiple_of(channel_align);

        // `k` covers every kernel position for every (padded) input channel
        let size_k = problem.kernel_size.iter().product::<u32>() * padded_channels;
        let padded_channels = FastDivmodArgs::new(client, padded_channels);

        let shape_m = ScalarArg::new(problem.m as u32);
        let shape_k = ScalarArg::new(size_k);

        Im2colLayoutLaunch::new(shape_out, padded_channels, shape_m, shape_k, params, config)
    }
}