// cubecl_convolution/reader/im2col.rs

1use cubecl_core::prelude::*;
2use cubecl_core::{self as cubecl, intrinsic};
3use cubecl_std::{FastDivmod, tensor::r#virtual::VirtualTensor};
4
5use crate::{ConvGemmConfig, loader::im2col_tma::div_mod_seq};
6use cubecl_matmul::components::Ident;
7
/// A view of a feature map tensor, read through the implicit im2col matrix starting
/// at a given `(m, k)` offset.
///
/// The tensor is indexed as `[batch, spatial..., channel]` (axis 0 is the batch axis,
/// axes `1..=N` are the spatial axes and axis `N + 1` is the channel axis, as read in
/// [`Im2colReader::new`]). Its strides and shapes are fetched once at construction so
/// that reads use the cached values instead of querying tensor metadata again.
///
/// Reads that fall outside the checked bounds (or into the padding region) return
/// zeros instead of accessing memory.
#[derive(CubeType)]
pub struct Im2colReader<E: Numeric> {
    /// Underlying tensor that lines are read from.
    pub tensor: VirtualTensor<E>,
    /// Offset of the view along the m dimension.
    pub m_offset: u32,
    /// Offset of the view along the k dimension; advanced by `update_view`.
    pub k_offset: u32,

    /// Stride of axis 0 (batch).
    pub stride_batch: u32,
    /// Strides of the spatial axes `1..=N`, outermost axis first.
    pub strides_spatial: Sequence<u32>,
    /// Stride of the last (channel) axis.
    pub stride_channel: u32,

    /// Sizes of the spatial axes `1..=N`, used to detect reads into the padding region.
    pub shapes_spatial: Sequence<u32>,
    /// Size of the last (channel) axis.
    pub shape_channel: u32,

    /// Output spatial shape as fast divisors; `div_mod_seq` uses it to split an m index
    /// into a batch index and per-dimension output offsets.
    pub shape_out: Sequence<FastDivmod>,

    /// Extent of the m dimension, compared against when row bounds checks are enabled.
    pub shape_m: u32,
    /// Extent of the k dimension, compared against when column bounds checks are enabled.
    pub shape_k: u32,
}
29
30#[cube]
31impl<E: Numeric> Im2colReader<E> {
32    #[allow(clippy::too_many_arguments)]
33    pub fn new(
34        tensor: VirtualTensor<E>,
35        shape_out: Sequence<FastDivmod>,
36        x_offset: u32,
37        y_offset: u32,
38        shape_k: u32,
39        shape_m: u32,
40    ) -> Im2colReader<E> {
41        let spatial_dims = comptime![shape_out.len()];
42        let mut strides_spatial = Sequence::new();
43        let mut shapes_spatial = Sequence::new();
44
45        #[unroll]
46        for i in 0..spatial_dims {
47            strides_spatial.push(tensor.stride(i + 1));
48            shapes_spatial.push(tensor.shape(i + 1));
49        }
50
51        let stride_batch = tensor.stride(0);
52        let stride_channel = tensor.stride(spatial_dims + 1);
53
54        let shape_channel = tensor.shape(spatial_dims + 1);
55
56        Im2colReader::<E> {
57            tensor,
58            m_offset: x_offset,
59            k_offset: y_offset,
60            stride_batch,
61            strides_spatial,
62            stride_channel,
63            shapes_spatial,
64            shape_channel,
65            shape_out,
66            shape_m,
67            shape_k,
68        }
69    }
70}
71
// SAFETY: NOTE(review): no justification for these impls is visible in this file.
// `Im2colReader` holds a `VirtualTensor<E>` and several `Sequence`s, and these impls
// assert `Send`/`Sync` for every `E: Numeric` regardless of what those fields contain.
// Presumably they exist to satisfy `Send`/`Sync` bounds required elsewhere by the cube
// type machinery. Confirm, and document the actual invariant that makes this sound.
unsafe impl<E: Numeric> Sync for Im2colReader<E> {}
unsafe impl<E: Numeric> Send for Im2colReader<E> {}
74
75#[cube]
76impl<E: Numeric> Im2colReader<E> {
77    /// Advance the view along the k dimension by a specified offset, `k_offset`.
78    pub fn update_view(&mut self, k_offset: u32) {
79        self.k_offset += k_offset;
80    }
81
82    /// Reads data from the tensor view at the specified tile coordinates (tile_x, tile_y) using
83    /// the `im2col` algorithm to translate them to input coordinates.
84    ///
85    /// Each unit loads one line in a coalesced manner for improved efficiency.
86    /// For row-major tensors, subsequent units read lines horizontally within the tile,
87    /// while for column-major tensors, they read lines vertically.
88    ///
89    /// # Note
90    ///
91    /// Out-of-bounds reads will be translated to zeros.
92    pub fn load_simple<G: ConvGemmConfig>(
93        &self,
94        tile_x: u32,
95        tile_y: u32,
96        unit_id: u32,
97        #[comptime] ident: Ident,
98        #[comptime] config: G,
99    ) -> Line<E> {
100        let line_size = config.global_line_size(ident);
101        let tile_size_x = config.tiling_scheme().elements_in_tile_row(ident);
102        let tile_size_y = config.tiling_scheme().elements_in_tile_col(ident);
103
104        let view_tile_m = tile_x * tile_size_x + self.m_offset;
105        let view_tile_k = tile_y * tile_size_y + self.k_offset;
106
107        let load_m = unit_id / tile_size_y;
108        let load_k = unit_id % tile_size_y;
109
110        let view_m = view_tile_m + load_m;
111        let view_k = view_tile_k + load_k;
112
113        let (batch, out_offs) = div_mod_seq(view_m, &self.shape_out);
114
115        let channel = view_k % self.shape_channel;
116        let mut rem = view_k / self.shape_channel;
117
118        let spatial_dims = comptime![self.shapes_spatial.len()];
119        let mut in_pos = Sequence::<i32>::new();
120
121        #[unroll]
122        for i in 0..spatial_dims {
123            let i = unwrap(i);
124            let dim = comptime![spatial_dims - i - 1];
125            let ksize = comptime![config.kernel_size(dim)];
126            let k_pos = rem % ksize;
127            rem /= ksize;
128
129            let out_pos = *out_offs.index(dim);
130            let stride = comptime![config.stride(dim)];
131            let dilate = comptime![config.dilation(dim)];
132            let pad = comptime![config.padding(dim)];
133
134            let pos = (out_pos * stride + k_pos * dilate) as i32 - pad;
135            in_pos.push(pos);
136        }
137
138        let in_pos = in_pos.rev();
139
140        let has_padding = comptime! {
141            let mut has_padding = false;
142            for i in 0..spatial_dims {
143                has_padding |= config.padding(i) != 0;
144            }
145            has_padding
146        };
147
148        let m_in_bounds = comptime!(!config.check_row_bounds(Ident::Lhs)) || view_m < self.shape_m;
149        let k_in_bounds = comptime!(!config.check_col_bounds(Ident::Lhs)) || view_k < self.shape_k;
150        let mut spatial_in_bounds = true;
151
152        if has_padding {
153            #[unroll]
154            for i in 0..spatial_dims {
155                let i = unwrap(i);
156                let pos = *in_pos.index(i);
157                spatial_in_bounds &= pos >= 0 && (pos as u32) < *self.shapes_spatial.index(i);
158            }
159        }
160
161        let in_bounds = m_in_bounds && k_in_bounds && spatial_in_bounds;
162
163        let mut read_pos = batch * self.stride_batch + channel * self.stride_channel;
164
165        #[unroll]
166        for i in 0..spatial_dims {
167            let i = unwrap(i);
168            read_pos += *in_pos.index(i) as u32 * *self.strides_spatial.index(i);
169        }
170
171        let read_pos = read_pos / line_size;
172
173        let mut res = Line::empty(line_size).fill(E::from_int(0));
174        if in_bounds {
175            res = self.read(read_pos);
176        }
177
178        res
179    }
180
181    fn read(&self, position: u32) -> Line<E> {
182        self.tensor.read(position)
183    }
184}
185
/// Extracts the compile-time value of `v` as a plain `u32`.
///
/// `v` must be a constant at expansion time. In this file it is always the index of
/// an `#[unroll]`ed loop. The returned value can then be used inside `comptime!`
/// expressions such as dimension arithmetic.
///
/// # Panics
///
/// Panics during kernel expansion with "Must be constant" if `v` is not a constant.
// NOTE(review): `v` is only referenced inside `intrinsic!`. The allow presumably
// silences an unused-variable warning on the non-expanded function. Confirm.
#[allow(unused_variables)]
#[cube]
fn unwrap(v: u32) -> comptime_type!(u32) {
    intrinsic!(|_| v.constant().expect("Must be constant").as_u32())
}