cubecl_convolution/reader/
im2col.rs1use cubecl_core::prelude::*;
2use cubecl_core::{self as cubecl, intrinsic};
3use cubecl_std::{FastDivmod, tensor::r#virtual::VirtualTensor};
4
5use crate::{ConvGemmConfig, loader::im2col_tma::div_mod_seq};
6use cubecl_matmul::components::Ident;
7
8#[derive(CubeType)]
9pub struct Im2colReader<E: Numeric> {
13 pub tensor: VirtualTensor<E>,
14 pub m_offset: u32,
15 pub k_offset: u32,
16
17 pub stride_batch: u32,
18 pub strides_spatial: Sequence<u32>,
19 pub stride_channel: u32,
20
21 pub shapes_spatial: Sequence<u32>,
22 pub shape_channel: u32,
23
24 pub shape_out: Sequence<FastDivmod>,
25
26 pub shape_m: u32,
27 pub shape_k: u32,
28}
29
30#[cube]
31impl<E: Numeric> Im2colReader<E> {
32 #[allow(clippy::too_many_arguments)]
33 pub fn new(
34 tensor: VirtualTensor<E>,
35 shape_out: Sequence<FastDivmod>,
36 x_offset: u32,
37 y_offset: u32,
38 shape_k: u32,
39 shape_m: u32,
40 ) -> Im2colReader<E> {
41 let spatial_dims = comptime![shape_out.len()];
42 let mut strides_spatial = Sequence::new();
43 let mut shapes_spatial = Sequence::new();
44
45 #[unroll]
46 for i in 0..spatial_dims {
47 strides_spatial.push(tensor.stride(i + 1));
48 shapes_spatial.push(tensor.shape(i + 1));
49 }
50
51 let stride_batch = tensor.stride(0);
52 let stride_channel = tensor.stride(spatial_dims + 1);
53
54 let shape_channel = tensor.shape(spatial_dims + 1);
55
56 Im2colReader::<E> {
57 tensor,
58 m_offset: x_offset,
59 k_offset: y_offset,
60 stride_batch,
61 strides_spatial,
62 stride_channel,
63 shapes_spatial,
64 shape_channel,
65 shape_out,
66 shape_m,
67 shape_k,
68 }
69 }
70}
71
72unsafe impl<E: Numeric> Sync for Im2colReader<E> {}
73unsafe impl<E: Numeric> Send for Im2colReader<E> {}
74
75#[cube]
76impl<E: Numeric> Im2colReader<E> {
77 pub fn update_view(&mut self, k_offset: u32) {
79 self.k_offset += k_offset;
80 }
81
82 pub fn load_simple<G: ConvGemmConfig>(
93 &self,
94 tile_x: u32,
95 tile_y: u32,
96 unit_id: u32,
97 #[comptime] ident: Ident,
98 #[comptime] config: G,
99 ) -> Line<E> {
100 let line_size = config.global_line_size(ident);
101 let tile_size_x = config.tiling_scheme().elements_in_tile_row(ident);
102 let tile_size_y = config.tiling_scheme().elements_in_tile_col(ident);
103
104 let view_tile_m = tile_x * tile_size_x + self.m_offset;
105 let view_tile_k = tile_y * tile_size_y + self.k_offset;
106
107 let load_m = unit_id / tile_size_y;
108 let load_k = unit_id % tile_size_y;
109
110 let view_m = view_tile_m + load_m;
111 let view_k = view_tile_k + load_k;
112
113 let (batch, out_offs) = div_mod_seq(view_m, &self.shape_out);
114
115 let channel = view_k % self.shape_channel;
116 let mut rem = view_k / self.shape_channel;
117
118 let spatial_dims = comptime![self.shapes_spatial.len()];
119 let mut in_pos = Sequence::<i32>::new();
120
121 #[unroll]
122 for i in 0..spatial_dims {
123 let i = unwrap(i);
124 let dim = comptime![spatial_dims - i - 1];
125 let ksize = comptime![config.kernel_size(dim)];
126 let k_pos = rem % ksize;
127 rem /= ksize;
128
129 let out_pos = *out_offs.index(dim);
130 let stride = comptime![config.stride(dim)];
131 let dilate = comptime![config.dilation(dim)];
132 let pad = comptime![config.padding(dim)];
133
134 let pos = (out_pos * stride + k_pos * dilate) as i32 - pad;
135 in_pos.push(pos);
136 }
137
138 let in_pos = in_pos.rev();
139
140 let has_padding = comptime! {
141 let mut has_padding = false;
142 for i in 0..spatial_dims {
143 has_padding |= config.padding(i) != 0;
144 }
145 has_padding
146 };
147
148 let m_in_bounds = comptime!(!config.check_row_bounds(Ident::Lhs)) || view_m < self.shape_m;
149 let k_in_bounds = comptime!(!config.check_col_bounds(Ident::Lhs)) || view_k < self.shape_k;
150 let mut spatial_in_bounds = true;
151
152 if has_padding {
153 #[unroll]
154 for i in 0..spatial_dims {
155 let i = unwrap(i);
156 let pos = *in_pos.index(i);
157 spatial_in_bounds &= pos >= 0 && (pos as u32) < *self.shapes_spatial.index(i);
158 }
159 }
160
161 let in_bounds = m_in_bounds && k_in_bounds && spatial_in_bounds;
162
163 let mut read_pos = batch * self.stride_batch + channel * self.stride_channel;
164
165 #[unroll]
166 for i in 0..spatial_dims {
167 let i = unwrap(i);
168 read_pos += *in_pos.index(i) as u32 * *self.strides_spatial.index(i);
169 }
170
171 let read_pos = read_pos / line_size;
172
173 let mut res = Line::empty(line_size).fill(E::from_int(0));
174 if in_bounds {
175 res = self.read(read_pos);
176 }
177
178 res
179 }
180
181 fn read(&self, position: u32) -> Line<E> {
182 self.tensor.read(position)
183 }
184}
185
186#[allow(unused_variables)]
187#[cube]
188fn unwrap(v: u32) -> comptime_type!(u32) {
189 intrinsic!(|_| v.constant().expect("Must be constant").as_u32())
190}