// cubecl_linalg/convolution/reader/im2col.rs
use cubecl_core as cubecl;
use cubecl_core::prelude::*;
use cubecl_std::{FastDivmod, tensor::r#virtual::VirtualTensor};

use crate::{convolution::ConvGemmConfig, matmul::components::Ident};
6
7#[derive(CubeType)]
8pub struct Im2colReader<E: Numeric> {
12 pub tensor: VirtualTensor<E>,
13 pub m_offset: u32,
14 pub k_offset: u32,
15
16 pub stride_batch: u32,
17 pub stride_y: u32,
18 pub stride_x: u32,
19 pub stride_channel: u32,
20
21 pub shape_y: u32,
22 pub shape_x: u32,
23 pub shape_channel: u32,
24
25 pub shape_out_y: FastDivmod,
26 pub shape_out_x: FastDivmod,
27
28 pub shape_m: u32,
29 pub shape_k: u32,
30}
31
32#[cube]
33impl<E: Numeric> Im2colReader<E> {
34 #[allow(clippy::too_many_arguments)]
35 pub fn new(
36 tensor: VirtualTensor<E>,
37 shape_out_y: FastDivmod,
38 shape_out_x: FastDivmod,
39 x_offset: u32,
40 y_offset: u32,
41 shape_k: u32,
42 shape_channel: u32,
43 shape_m: u32,
44 ) -> Im2colReader<E> {
45 let stride_batch = tensor.stride(0);
46 let stride_y = tensor.stride(1);
47 let stride_x = tensor.stride(2);
48 let stride_channel = tensor.stride(3);
49 let shape_y = tensor.shape(1);
50 let shape_x = tensor.shape(2);
51
52 Im2colReader::<E> {
53 tensor,
54 m_offset: x_offset,
55 k_offset: y_offset,
56 stride_batch,
57 stride_y,
58 stride_x,
59 stride_channel,
60 shape_y,
61 shape_x,
62 shape_channel,
63 shape_out_y,
64 shape_out_x,
65 shape_m,
66 shape_k,
67 }
68 }
69}
70
71unsafe impl<E: Numeric> Sync for Im2colReader<E> {}
72unsafe impl<E: Numeric> Send for Im2colReader<E> {}
73
74#[cube]
75impl<E: Numeric> Im2colReader<E> {
76 pub fn update_view(&mut self, k_offset: u32) {
78 self.k_offset += k_offset;
79 }
80
81 pub fn load_simple<G: ConvGemmConfig>(
92 &self,
93 tile_x: u32,
94 tile_y: u32,
95 unit_id: u32,
96 #[comptime] ident: Ident,
97 #[comptime] config: G,
98 ) -> Line<E> {
99 let line_size = config.global_line_size(ident);
100 let tile_size_x = config.tiling_dimensions(ident).tile_shape_row();
101 let tile_size_y = config.tiling_dimensions(ident).tile_shape_col();
102
103 let view_tile_m = tile_x * tile_size_x + self.m_offset;
104 let view_tile_k = tile_y * tile_size_y + self.k_offset;
105
106 let load_m = unit_id / tile_size_y;
107 let load_k = unit_id % tile_size_y;
108
109 let view_m = view_tile_m + load_m;
110 let view_k = view_tile_k + load_k;
111
112 let (out_nh, out_x) = self.shape_out_x.div_mod(view_m);
113 let (batch, out_y) = self.shape_out_y.div_mod(out_nh);
114
115 let kernel_w = config.kernel_size(1);
116
117 let channel = view_k % self.shape_channel;
118 let rem = view_k / self.shape_channel;
119 let kernel_x = rem % kernel_w;
120 let kernel_y = rem / kernel_w;
121
122 let y =
123 (out_y * config.stride(0) + kernel_y * config.dilation(0)) as i32 - config.padding(0);
124 let x =
125 (out_x * config.stride(1) + kernel_x * config.dilation(1)) as i32 - config.padding(1);
126
127 let m_in_bounds = comptime!(!config.check_row_bounds(Ident::Lhs)) || view_m < self.shape_m;
128 let k_in_bounds = comptime!(!config.check_col_bounds(Ident::Lhs)) || view_k < self.shape_k;
129 let no_padding = comptime!(config.padding(0) == 0 && config.padding(1) == 0);
130 let hw_in_bounds = no_padding
131 || (y >= 0 && (y as u32) < self.shape_y && x >= 0 && (x as u32) < self.shape_x);
132 let in_bounds = m_in_bounds && k_in_bounds && hw_in_bounds;
133 let read_pos = batch * self.stride_batch
134 + y as u32 * self.stride_y
135 + x as u32 * self.stride_x
136 + channel * self.stride_channel;
137
138 let read_pos = read_pos / line_size;
139
140 let mut res = Line::empty(line_size).fill(E::from_int(0));
141 if in_bounds {
142 res = self.read(read_pos);
143 }
144
145 res
146 }
147
148 fn read(&self, position: u32) -> Line<E> {
149 self.tensor.read(position)
150 }
151}