cubek_convolution/components/global/layout/
im2col.rs1use cubecl::prelude::*;
2use cubecl::std::{
3 FastDivmod, FastDivmodArgs,
4 tensor::layout::{Coords3d, Layout, LayoutExpand},
5};
6use cubek_matmul::components::{
7 MatmulElems,
8 global::{GlobalConfig, memory::GlobalMemoryConfig},
9};
10
11use crate::components::{
12 ConvGemmConfig, ConvolutionConfig, ConvolutionParams, ConvolutionProblem,
13 global::{args::RuntimeArgs, layout::NhwcCoords, read::im2col_tma::div_mod_seq},
14};
15
16#[derive(CubeType, CubeLaunch, Clone)]
20pub struct Im2colLayout {
21 pub shape_out: Sequence<FastDivmod>,
23 pub padded_channels: FastDivmod,
25
26 pub shape_m: u32,
28 pub shape_k: u32,
30
31 #[cube(comptime)]
33 pub params: ConvolutionParams,
34 #[cube(comptime)]
36 pub config: GlobalMemoryConfig,
37}
38
39#[cube]
40impl Im2colLayout {
41 pub fn new<G: GlobalConfig>(
42 args: &RuntimeArgs,
43 shape_out: Sequence<FastDivmod>,
44 #[comptime] config: ConvolutionConfig<G>,
45 ) -> Im2colLayout {
46 Im2colLayout {
47 shape_out,
48 padded_channels: args.padded_channels,
49 shape_m: args.shape_m,
50 shape_k: args.shape_k,
51 params: config.convolution_params,
52 config: config.lhs_global_memory_config(),
53 }
54 }
55}
56
57#[cube]
58impl Layout for Im2colLayout {
59 type Coordinates = Coords3d;
60 type SourceCoordinates = NhwcCoords;
61
62 fn to_source_pos(&self, pos: Self::Coordinates) -> NhwcCoords {
63 let params = comptime![self.params];
64 let (_, view_m, view_k) = pos;
65
66 let (batch, out_offs) = div_mod_seq(view_m, &self.shape_out);
67
68 let (mut rem, channel) = self.padded_channels.div_mod(view_k);
69
70 let spatial_dims = comptime![self.shape_out.len()];
71 let mut in_pos = Sequence::<i32>::new();
72
73 #[unroll]
74 for i in 0..spatial_dims {
75 let dim = comptime![spatial_dims - i - 1];
76 let ksize = comptime![params.kernel_size[dim as usize]];
77 let k_pos = rem % ksize;
78 rem /= ksize;
79
80 let out_pos = *out_offs.index(dim);
81 let stride = comptime![params.stride[dim as usize]];
82 let dilate = comptime![params.dilation[dim as usize]];
83 let pad = comptime![params.padding[dim as usize]];
84
85 let pos = (out_pos * stride + k_pos * dilate) as i32 - pad;
86 in_pos.push(pos);
87 }
88
89 let in_pos = in_pos.rev();
90
91 NhwcCoords {
92 batch,
93 spatial: in_pos,
94 channel,
95 }
96 }
97
98 fn shape(&self) -> Self::Coordinates {
99 (1, self.shape_m, self.shape_k)
100 }
101
102 fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (NhwcCoords, bool) {
103 (self.to_source_pos(pos), self.is_in_bounds(pos))
104 }
105
106 fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
107 let (_, view_m, view_k) = pos;
108 let m_in_bounds = comptime!(!self.config.check_row_bounds) || view_m < self.shape_m;
110 let k_in_bounds = comptime!(!self.config.check_col_bounds) || view_k < self.shape_k;
111 m_in_bounds && k_in_bounds
112 }
113}
114
115impl<'a, R: Runtime> Im2colLayoutLaunch<'a, R> {
116 pub fn from_args(
117 client: &ComputeClient<R>,
118 problem: &ConvolutionProblem,
119 params: ConvolutionParams,
120 config: GlobalMemoryConfig,
121 dtypes: &MatmulElems,
122 ) -> Self {
123 let shape_out = problem
124 .out_shape
125 .iter()
126 .map(|s| FastDivmodArgs::new(client, *s as u32))
127 .collect();
128
129 let load_width = client.properties().hardware.load_width;
130 let channel_align = load_width / dtypes.lhs_global.size_bits() as u32;
131 let padded_channels = (problem.channels as u32).next_multiple_of(channel_align);
132
133 let size_k = problem.kernel_size.iter().product::<u32>() * padded_channels;
134 let padded_channels = FastDivmodArgs::new(client, padded_channels);
135
136 let shape_m = ScalarArg::new(problem.m as u32);
137 let shape_k = ScalarArg::new(size_k);
138
139 Im2colLayoutLaunch::new(shape_out, padded_channels, shape_m, shape_k, params, config)
140 }
141}