// cubecl_linalg/convolution/reader/im2col.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use cubecl_std::{FastDivmod, tensor::r#virtual::VirtualTensor};
4
5use crate::{convolution::ConvGemmConfig, matmul::components::Ident};
6
#[derive(CubeType)]
/// A view of a feature map tensor that starts reading data from a specified offset.
/// Ensures safe access by preventing out-of-bounds errors.
/// Includes pre-fetched shapes and strides for optimized performance.
pub struct Im2colReader<E: Numeric> {
    /// Underlying input feature map tensor being viewed.
    pub tensor: VirtualTensor<E>,
    /// Row offset of this view into the virtual im2col "m" dimension
    /// (m enumerates output spatial positions across the batch).
    pub m_offset: u32,
    /// Column offset of this view into the virtual im2col "k" dimension
    /// (k enumerates kernel positions and input channels). Advanced by
    /// `update_view` as the matmul walks along k.
    pub k_offset: u32,

    // Pre-fetched tensor strides; the indexing in `load_simple` reads them as
    // strides 0..=3 of the input, i.e. it assumes NHWC layout — batch, y, x,
    // channel.
    pub stride_batch: u32,
    pub stride_y: u32,
    pub stride_x: u32,
    pub stride_channel: u32,

    // Pre-fetched input spatial/channel shapes used for bounds checking.
    pub shape_y: u32,
    pub shape_x: u32,
    pub shape_channel: u32,

    // Output spatial shapes wrapped in fast division/modulo helpers, used to
    // decompose a flat m index into (batch, out_y, out_x).
    pub shape_out_y: FastDivmod,
    pub shape_out_x: FastDivmod,

    /// Total extent of the virtual m dimension (used for row bounds checks).
    pub shape_m: u32,
    /// Total extent of the virtual k dimension (used for column bounds checks).
    pub shape_k: u32,
}
31
#[cube]
impl<E: Numeric> Im2colReader<E> {
    /// Constructs an [`Im2colReader`] over `tensor`, pre-fetching the strides
    /// and shapes needed by `load_simple` so they are not re-read per load.
    ///
    /// Strides/shapes are taken from dims 0..=3 of `tensor`, i.e. the input is
    /// assumed to be laid out as (batch, y, x, channel) — NHWC.
    ///
    /// NOTE(review): `x_offset` is stored as `m_offset` and `y_offset` as
    /// `k_offset`; the parameter names follow a (x=row, y=col) convention that
    /// differs from the field names — confirm against call sites.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        tensor: VirtualTensor<E>,
        shape_out_y: FastDivmod,
        shape_out_x: FastDivmod,
        x_offset: u32,
        y_offset: u32,
        shape_k: u32,
        shape_channel: u32,
        shape_m: u32,
    ) -> Im2colReader<E> {
        // Pre-fetch NHWC strides and the spatial shapes once at construction.
        let stride_batch = tensor.stride(0);
        let stride_y = tensor.stride(1);
        let stride_x = tensor.stride(2);
        let stride_channel = tensor.stride(3);
        let shape_y = tensor.shape(1);
        let shape_x = tensor.shape(2);

        Im2colReader::<E> {
            tensor,
            m_offset: x_offset,
            k_offset: y_offset,
            stride_batch,
            stride_y,
            stride_x,
            stride_channel,
            shape_y,
            shape_x,
            shape_channel,
            shape_out_y,
            shape_out_x,
            shape_m,
            shape_k,
        }
    }
}
70
// SAFETY: Im2colReader holds only plain `u32` values, `FastDivmod` helpers and
// a `VirtualTensor` handle; it is built per kernel launch and not mutated from
// multiple host threads. Soundness ultimately depends on `VirtualTensor` being
// safe to share/send — NOTE(review): confirm that `VirtualTensor<E>` carries no
// thread-affine state.
unsafe impl<E: Numeric> Sync for Im2colReader<E> {}
unsafe impl<E: Numeric> Send for Im2colReader<E> {}
73
#[cube]
impl<E: Numeric> Im2colReader<E> {
    /// Advance the view along the k dimension by a specified offset, `k_offset`.
    pub fn update_view(&mut self, k_offset: u32) {
        self.k_offset += k_offset;
    }

    /// Reads data from the tensor view at the specified tile coordinates (tile_x, tile_y) using
    /// the `im2col` algorithm to translate them to input coordinates.
    ///
    /// Each unit loads one line in a coalesced manner for improved efficiency.
    /// For row-major tensors, subsequent units read lines horizontally within the tile,
    /// while for column-major tensors, they read lines vertically.
    ///
    /// The virtual im2col matrix is indexed as:
    /// - m = (batch * out_h + out_y) * out_w + out_x  (output spatial position)
    /// - k = (kernel_y * kernel_w + kernel_x) * channels + channel
    ///
    /// # Note
    ///
    /// Out-of-bounds reads will be translated to zeros.
    pub fn load_simple<G: ConvGemmConfig>(
        &self,
        tile_x: u32,
        tile_y: u32,
        unit_id: u32,
        #[comptime] ident: Ident,
        #[comptime] config: G,
    ) -> Line<E> {
        let line_size = config.global_line_size(ident);
        let tile_size_x = config.tiling_dimensions(ident).tile_shape_row();
        let tile_size_y = config.tiling_dimensions(ident).tile_shape_col();

        // Global (m, k) origin of the requested tile, including this view's offsets.
        let view_tile_m = tile_x * tile_size_x + self.m_offset;
        let view_tile_k = tile_y * tile_size_y + self.k_offset;

        // Each unit covers one element of the tile, laid out row-major over k
        // (consecutive unit ids walk along k for coalescing).
        let load_m = unit_id / tile_size_y;
        let load_k = unit_id % tile_size_y;

        // Absolute (m, k) coordinates this unit will read.
        let view_m = view_tile_m + load_m;
        let view_k = view_tile_k + load_k;

        // Decompose m into (batch, out_y, out_x): first split off out_x, then
        // split the remaining (batch * out_h + out_y) into batch and out_y.
        let (out_nh, out_x) = self.shape_out_x.div_mod(view_m);
        let (batch, out_y) = self.shape_out_y.div_mod(out_nh);

        let kernel_w = config.kernel_size(1);

        // Decompose k into (kernel_y, kernel_x, channel), channel fastest.
        let channel = view_k % self.shape_channel;
        let rem = view_k / self.shape_channel;
        let kernel_x = rem % kernel_w;
        let kernel_y = rem / kernel_w;

        // Input spatial coordinates; signed because padding can shift them
        // below zero.
        let y =
            (out_y * config.stride(0) + kernel_y * config.dilation(0)) as i32 - config.padding(0);
        let x =
            (out_x * config.stride(1) + kernel_x * config.dilation(1)) as i32 - config.padding(1);

        // Bounds checks are elided at compile time when the config guarantees
        // they can't fail; the spatial check is skipped entirely when there is
        // no padding.
        let m_in_bounds = comptime!(!config.check_row_bounds(Ident::Lhs)) || view_m < self.shape_m;
        let k_in_bounds = comptime!(!config.check_col_bounds(Ident::Lhs)) || view_k < self.shape_k;
        let no_padding = comptime!(config.padding(0) == 0 && config.padding(1) == 0);
        let hw_in_bounds = no_padding
            || (y >= 0 && (y as u32) < self.shape_y && x >= 0 && (x as u32) < self.shape_x);
        let in_bounds = m_in_bounds && k_in_bounds && hw_in_bounds;
        // NHWC linear index. When `y`/`x` are negative the casts wrap, but the
        // resulting position is never read because `in_bounds` is false then.
        let read_pos = batch * self.stride_batch
            + y as u32 * self.stride_y
            + x as u32 * self.stride_x
            + channel * self.stride_channel;

        // Convert element index to line index; assumes the channel dimension is
        // contiguous and line-aligned — TODO confirm against launch checks.
        let read_pos = read_pos / line_size;

        // Out-of-bounds loads produce a zero-filled line.
        let mut res = Line::empty(line_size).fill(E::from_int(0));
        if in_bounds {
            res = self.read(read_pos);
        }

        res
    }

    /// Raw line read at a pre-computed line index.
    fn read(&self, position: u32) -> Line<E> {
        self.tensor.read(position)
    }
}
151}