rten/gemm/im2col.rs

use std::mem::MaybeUninit;
use std::ops::Range;

use rten_simd::{SimdInt, SimdMask};
use rten_tensor::{NdTensorView, Storage};

use super::packing::int8::shift_cast_i8_u8;
use crate::slice_cast::cast_pod_mut_slice;

/// Maps rows of an [`Im2Col`] matrix to locations in the source image.
///
/// For efficiency when packing the image, the locations are premultiplied by
/// the corresponding stride.
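///
/// Each row of the im2col matrix corresponds to one (channel, kernel y,
/// kernel x) combination. A sketch of how one row index decomposes, assuming
/// rows are ordered with channel outermost and using a hypothetical
/// `[2, 4, 4]` image (chan_stride = 16, row_stride = 4, col_stride = 1) with
/// a 2x2 kernel:
///
/// ```
/// // Row 5 of 8 decomposes into channel 1, kernel y 0, kernel x 1, so the
/// // premultiplied entries at index 5 would be:
/// let (chan, y, x) = (1 * 16, 0 * 4, 1 * 1);
/// assert_eq!((chan, y, x), (16, 0, 1));
/// ```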
pub struct RowOffsets {
    /// Map of row index to `channel * channel_stride`.
    pub chan: Vec<i32>,

    /// Map of row index to `row * row_stride`.
    pub y: Vec<i32>,

    /// Map of row index to `col * col_stride`.
    pub x: Vec<i32>,
}

/// Maps columns of an [`Im2Col`] matrix to locations in the source image.
///
/// For efficiency when packing the image, the locations are premultiplied by
/// the corresponding stride.
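///
/// Each column of the im2col matrix corresponds to one output position. A
/// sketch assuming row-major output positions, a 3x3 output, a stride-1
/// convolution and a hypothetical row_stride of 4 / col_stride of 1:
///
/// ```
/// // Column 7 maps to output position (2, 1), so the premultiplied entries
/// // at index 7 would be:
/// let (y, x) = (2 * 4, 1 * 1);
/// assert_eq!((y, x), (8, 1));
/// ```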
pub struct ColOffsets {
    /// Map of column index to `row * row_stride`.
    pub y: Vec<i32>,

    /// Map of column index to `col * col_stride`.
    pub x: Vec<i32>,
}

/// A matrix formed by unrolling patches of an image into columns.
///
/// The input image has shape [C, H, W] and is transformed into a matrix with
/// shape [C * Kh * Kw, Oh * Ow], where Kh/Kw are the convolution kernel height
/// and width and Oh/Ow are the number of patches in the Y and X directions.
///
/// The matrix is _virtual_ as it is not materialized fully in memory. Instead,
/// blocks of the matrix are materialized as needed during computation.
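///
/// For example (a sketch assuming a stride-1, unpadded convolution):
///
/// ```
/// // A [3, 5, 5] image convolved with a 2x2 kernel has 4x4 patch positions.
/// let (c, kh, kw) = (3, 2, 2);
/// let (oh, ow) = (5 - 2 + 1, 5 - 2 + 1);
/// assert_eq!(c * kh * kw, 12); // im2col rows
/// assert_eq!(oh * ow, 16); // im2col columns
/// ```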
pub struct Im2Col<'a, T> {
    pub image: NdTensorView<'a, T, 3>,

    /// Map of im2col row index to input image coordinate, premultiplied with
    /// the corresponding stride.
    ///
    /// The arrays may be padded to a multiple of a step size specified by the
    /// GEMM kernel. `n_rows` contains the actual number of rows in the virtual
    /// matrix.
    pub row_offsets: RowOffsets,

    /// Map of im2col column index to input image coordinate, premultiplied with
    /// the corresponding stride.
    ///
    /// The arrays may be padded to a multiple of a step size specified by the
    /// GEMM kernel. `n_cols` contains the actual number of columns in the
    /// virtual matrix.
    pub col_offsets: ColOffsets,

    /// Number of columns in the im2col matrix.
    pub n_cols: usize,

    /// Number of rows in the im2col matrix.
    pub n_rows: usize,

    /// Maximum valid sum of `row_offsets.y + col_offsets.y`. Sums outside the
    /// range `0..=max_y_offset` correspond to the padding region.
    pub max_y_offset: i32,

    /// Maximum valid sum of `row_offsets.x + col_offsets.x`. Sums outside the
    /// range `0..=max_x_offset` correspond to the padding region.
    pub max_x_offset: i32,
}

impl<T: Copy + Default> Im2Col<'_, T> {
    /// Return the number of rows in the im2col matrix.
    pub fn rows(&self) -> usize {
        self.n_rows
    }

    /// Return the number of columns in the im2col matrix.
    pub fn cols(&self) -> usize {
        self.n_cols
    }

    /// Pack part of an image into a packing buffer.
    ///
    /// This method is for use by kernels using the "standard" packing buffer
    /// layout for the B / RHS input.
    ///
    /// `NR_REGS` specifies the width of each column panel as a multiple of
    /// `S::LEN` elements. In other words, `panel_width` must exactly equal
    /// `NR_REGS * S::LEN`.
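    ///
    /// The output must hold `rows.len()` values for every column in `cols`,
    /// with the column count rounded up to a whole number of panels. A sketch
    /// of the size calculation, mirroring the assertions below:
    ///
    /// ```
    /// let (panel_width, n_rows) = (4usize, 8usize);
    /// let cols = 0..10usize;
    /// let padded_cols = cols.start..cols.end.next_multiple_of(panel_width);
    /// assert_eq!(n_rows * padded_cols.len(), 96); // required `out.len()`
    /// ```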
    ///
    /// # Safety
    ///
    /// The caller must ensure the SIMD type `S` is supported on the current
    /// system.
    #[inline(always)]
    pub(super) unsafe fn pack_block<S: SimdInt, const NR_REGS: usize>(
        &self,
        out: &mut [MaybeUninit<T>],
        panel_width: usize,
        rows: Range<usize>,
        cols: Range<usize>,
    ) {
        assert_eq!(panel_width, S::LEN * NR_REGS);

        let col_range = cols.start..cols.end.next_multiple_of(panel_width);
        let used_size = rows.len() * col_range.len();
        assert_eq!(out.len(), used_size);

        let col_y_offsets = &self.col_offsets.y[col_range.clone()];
        let col_x_offsets = &self.col_offsets.x[col_range.clone()];
        let row_chan_offsets = &self.row_offsets.chan[rows.clone()];
        let row_y_offsets = &self.row_offsets.y[rows.clone()];
        let row_x_offsets = &self.row_offsets.x[rows.clone()];

        let img_ptr = self.image.storage().as_ptr();

        // Compute max valid image buffer offset. Used to clamp generated offsets
        // as a form of bounds check.
        let img_len = self.image.storage().len();
        assert!(img_len > 0 && img_len <= i32::MAX as usize);
        let max_img_offset = S::splat(img_len as i32 - 1);

        // Loop over column panels, then rows, then `S::LEN`-wide column groups
        // within each panel.
        let out_ptr = out.as_mut_ptr();
        let mut out_offset = 0;

        for start_col in (0..col_y_offsets.len()).step_by(S::LEN * NR_REGS) {
            let col_y_offset: [S; NR_REGS] = std::array::from_fn(|i| {
                S::load(col_y_offsets.as_ptr().add(start_col + S::LEN * i))
            });
            let col_x_offset: [S; NR_REGS] = std::array::from_fn(|i| {
                S::load(col_x_offsets.as_ptr().add(start_col + S::LEN * i))
            });
            let max_x_offset = S::splat(self.max_x_offset);
            let max_y_offset = S::splat(self.max_y_offset);

            for ((&row_chan_offset, &row_y_offset), &row_x_offset) in row_chan_offsets
                .iter()
                .zip(row_y_offsets.iter())
                .zip(row_x_offsets.iter())
            {
                let row_chan_offset = S::splat(row_chan_offset);
                let row_y_offset = S::splat(row_y_offset);
                let row_x_offset = S::splat(row_x_offset);

                for i in 0..NR_REGS {
                    let y_offset = col_y_offset[i].add(row_y_offset);
                    let x_offset = col_x_offset[i].add(row_x_offset);
                    let offsets = row_chan_offset
                        .add(y_offset)
                        .add(x_offset)
                        // Ensure offsets cannot be out of bounds even if row /
                        // column offsets were calculated incorrectly.
                        .max(S::zero())
                        .min(max_img_offset);

                    // Create a mask specifying which offsets are valid. The
                    // others correspond to the padding region.
                    let zero = S::zero();
                    let pad_mask = y_offset
                        .ge(zero)
                        .and(y_offset.le(max_y_offset))
                        .and(x_offset.ge(zero))
                        .and(x_offset.le(max_x_offset));

                    // Set offsets to zero for padding elements. Offset zero is
                    // always valid since the image is non-empty (asserted above).
                    let offsets_array = zero.blend(offsets, pad_mask).to_array();
                    let pad_mask_array = pad_mask.to_array();

                    // Gather elements and store in the packing buffer.
                    for idx in 0..S::LEN {
                        // Cast `*mut MaybeUninit<T>` to `*mut T` for the write below.
                        let out_ptr: *mut T = std::mem::transmute(out_ptr.add(out_offset + idx));

                        // Safety: Offsets are clamped, so they must be in-bounds.
                        let src_elem = *img_ptr.add(offsets_array[idx] as usize);

                        // This should be compiled to a conditional move.
                        let elem = if pad_mask_array[idx] {
                            src_elem
                        } else {
                            T::default()
                        };

                        out_ptr.write(elem);
                    }

                    out_offset += S::LEN;
                }
            }
        }

        // Check that we initialized exactly as many elements as claimed.
        assert_eq!(out_offset, used_size);
    }
}

impl Im2Col<'_, i8> {
    /// Pack part of an image into a packing buffer.
    ///
    /// This method is for use by kernels using int8 dot product instructions
    /// to compute `S::LEN x i32` dot products from two input vectors each
    /// containing `S::LEN x 4 x i8` (or u8) inputs.
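    ///
    /// Within each `S::LEN`-wide column group, every group of 4 consecutive
    /// rows is interleaved so that one dot product instruction can consume
    /// four i8 values per lane. The element order for columns c0, c1, ... and
    /// rows k0..k3 within a group is:
    ///
    /// ```text
    /// c0k0 c0k1 c0k2 c0k3 | c1k0 c1k1 c1k2 c1k3 | ...
    /// ```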
    #[inline(always)]
    #[allow(unused)] // Used by some architectures only.
    pub(super) unsafe fn pack_block_i8_dot<S: SimdInt, const NR_REGS: usize>(
        &self,
        out: &mut [MaybeUninit<i8>],
        rows: Range<usize>,
        cols: Range<usize>,
    ) {
        self.pack_block_int8::<S, NR_REGS, false>(out, rows, cols);
    }

    /// Variant of [`pack_block_i8_dot`](Self::pack_block_i8_dot) which shifts
    /// i8 values to u8 by adding 128.
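    ///
    /// The shift maps the i8 range `-128..=127` onto the u8 range `0..=255`,
    /// which is the same as flipping the sign bit. An illustrative stand-in
    /// for `shift_cast_i8_u8` (not necessarily its actual implementation):
    ///
    /// ```
    /// let shift = |x: i8| (x as u8) ^ 0x80;
    /// assert_eq!(shift(-128), 0u8);
    /// assert_eq!(shift(0), 128u8);
    /// assert_eq!(shift(127), 255u8);
    /// ```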
    #[inline(always)]
    #[allow(unused)] // Used by some architectures only.
    pub(super) unsafe fn pack_block_i8_dot_cast_u8<S: SimdInt, const NR_REGS: usize>(
        &self,
        out: &mut [MaybeUninit<u8>],
        rows: Range<usize>,
        cols: Range<usize>,
    ) {
        let out = cast_pod_mut_slice(out).unwrap();
        self.pack_block_int8::<S, NR_REGS, true>(out, rows, cols);
    }

    /// Shared implementation of int8 packing for dot product kernels.
    ///
    /// If `CAST_B_U8` is true, values are shifted from i8 to u8 by adding 128
    /// before being written to `out`.
    #[inline(always)]
    unsafe fn pack_block_int8<S: SimdInt, const NR_REGS: usize, const CAST_B_U8: bool>(
        &self,
        out: &mut [MaybeUninit<i8>],
        rows: Range<usize>,
        cols: Range<usize>,
    ) {
        // Number of i8 elements that pack into one i32 (the "K tile" consumed
        // by a single dot product instruction).
        const K_TILE: usize = size_of::<i32>() / size_of::<i8>();

        debug_assert!(rows.end <= self.rows());
        debug_assert!(cols.end <= self.cols());

        let max_x_offset = S::splat(self.max_x_offset);
        let max_y_offset = S::splat(self.max_y_offset);

        let col_x_offsets = &self.col_offsets.x;
        debug_assert_eq!(col_x_offsets.len() % S::LEN, 0);

        let col_y_offsets = &self.col_offsets.y;
        debug_assert_eq!(col_y_offsets.len() % S::LEN, 0);

        let row_x_offsets = &self.row_offsets.x;
        debug_assert_eq!(row_x_offsets.len() % K_TILE, 0);

        let row_y_offsets = &self.row_offsets.y;
        debug_assert_eq!(row_y_offsets.len() % K_TILE, 0);

        let row_chan_offsets = &self.row_offsets.chan;
        debug_assert_eq!(row_chan_offsets.len() % K_TILE, 0);

        let img_ptr = self.image.storage().as_ptr();
        let out_ptr = out.as_mut_ptr();

        let mut out_offset = 0;

        for start_col in cols.step_by(S::LEN * NR_REGS) {
            let col_y_offset: [S; NR_REGS] = std::array::from_fn(|i| {
                S::load(col_y_offsets.get_unchecked(start_col + i * S::LEN))
            });
            let col_x_offset: [S; NR_REGS] = std::array::from_fn(|i| {
                S::load(col_x_offsets.get_unchecked(start_col + i * S::LEN))
            });
            let zero = S::zero();

            let mut col_sums = [S::zero().to_array(); NR_REGS];

            for start_row in rows.clone().step_by(K_TILE) {
                for i in 0..K_TILE {
                    let k = start_row + i;
                    let row_x_offset = S::splat(*row_x_offsets.get_unchecked(k));
                    let row_y_offset = S::splat(*row_y_offsets.get_unchecked(k));
                    let row_chan_offset = S::splat(*row_chan_offsets.get_unchecked(k));

                    for c_block in 0..NR_REGS {
                        let x_offsets = row_x_offset.add(col_x_offset[c_block]);
                        let y_offsets = row_y_offset.add(col_y_offset[c_block]);
                        let offsets = x_offsets.add(y_offsets).add(row_chan_offset);

                        let pad_mask = y_offsets
                            .ge(zero)
                            .and(y_offsets.le(max_y_offset))
                            .and(x_offsets.ge(zero))
                            .and(x_offsets.le(max_x_offset));
                        let pad_mask_array = pad_mask.to_array();

                        // Set offsets to zero for padding elements. Offset
                        // zero is required to always be valid.
                        let offsets_array = zero.blend(offsets, pad_mask).to_array();

                        for idx in 0..S::LEN {
                            let out_ptr =
                                out_ptr.add(out_offset + (c_block * S::LEN + idx) * K_TILE + i);
                            let src_elem = *img_ptr.add(offsets_array[idx] as usize);

                            if CAST_B_U8 {
                                let src_elem = shift_cast_i8_u8(src_elem);
                                let elem = if pad_mask_array[idx] { src_elem } else { 0 };
                                col_sums[c_block][idx] += elem as i32;
                                out_ptr.write(MaybeUninit::new(elem as i8));
                            } else {
                                let elem = if pad_mask_array[idx] { src_elem } else { 0 };
                                col_sums[c_block][idx] += elem as i32;
                                out_ptr.write(MaybeUninit::new(elem));
                            }
                        }
                    }
                }
                out_offset += S::LEN * NR_REGS * K_TILE;
            }

            // Store column sums at the end of each panel.
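            // These are presumably consumed by the int8 kernels for
            // zero-point adjustment, as is typical in quantized GEMM.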
            for c_block in 0..NR_REGS {
                let col_sum_ptr = out_ptr.add(out_offset) as *mut i32;
                for i in 0..S::LEN {
                    *col_sum_ptr.add(i) = col_sums[c_block][i];
                }
                out_offset += S::LEN * K_TILE;
            }
        }

        // Check that we initialized exactly as many elements as `out` holds.
        assert_eq!(out_offset, out.len());
    }
}