// cubecl_linalg/convolution/loader/im2col.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use cubecl_std::tensor::r#virtual::VirtualTensor;
5use std::marker::PhantomData;
6
7use crate::{
8    convolution::base::RuntimeArgs,
9    matmul::components::{
10        Ident, InputIdent, MatmulPrecision,
11        stage::{ContiguousTilingLayout, FullReader, RowMajorTilingOrder},
12    },
13};
14use crate::{
15    convolution::{ConvGemmConfig, reader::im2col::Im2colReader},
16    matmul::components::stage::Stage,
17};
18
19/// Loader that translates matrix coordinates to input coordinates using the `im2col` algorithm
20#[derive(CubeType)]
21pub struct SimpleIm2colLoader<MP: MatmulPrecision, G: ConvGemmConfig> {
22    pub tensor_view: Im2colReader<MP::EI>,
23    pub stage: Stage<MP::ES, ContiguousTilingLayout<RowMajorTilingOrder>>,
24    #[cube(comptime)]
25    _config: PhantomData<G>,
26}
27
28#[cube]
29impl<MP: MatmulPrecision, G: ConvGemmConfig> SimpleIm2colLoader<MP, G> {
30    pub fn new(
31        tensor: VirtualTensor<MP::EI>,
32        x_offset: u32,
33        y_offset: u32,
34        runtime_args: &RuntimeArgs,
35        #[comptime] config: G,
36    ) -> Self {
37        let stage = Stage::new::<G::SmmConfig>(Ident::Lhs, config.to_smm_config());
38        let shape_channel = tensor.shape(3);
39
40        let shape_m = runtime_args.size_m;
41        let shape_k = runtime_args.size_k;
42
43        let tensor_view = Im2colReader::<MP::EI>::new(
44            tensor,
45            runtime_args.out_h,
46            runtime_args.out_w,
47            x_offset,
48            y_offset,
49            shape_k,
50            shape_channel,
51            shape_m,
52        );
53
54        SimpleIm2colLoader::<MP, G> {
55            tensor_view,
56            stage,
57            _config: PhantomData::<G>,
58        }
59    }
60
61    pub fn advance_view(this: &mut Self, k_offset: u32) {
62        this.tensor_view.update_view(k_offset);
63    }
64
65    pub fn reader(this: &Self) -> FullReader<MP::ES, ContiguousTilingLayout<RowMajorTilingOrder>> {
66        FullReader::new(this.stage, InputIdent::Lhs)
67    }
68
69    pub fn fill_stage(this: &mut Self, #[comptime] config: G) {
70        let line_size = config.global_line_size(Ident::Lhs);
71        SimpleIm2col::load_to_slice::<MP, G>(
72            &this.tensor_view,
73            &mut this.stage.as_slice_mut(line_size),
74            Ident::Lhs,
75            config,
76        );
77    }
78}
79
80#[derive(CubeType, Clone, Copy)]
81/// Loads the content of all tiles in the tensor view using all planes,
82/// iterating with steps determined by the plane's dimension.
83pub struct SimpleIm2col;
84
85#[cube]
86impl SimpleIm2col {
87    pub fn load_to_slice<MP: MatmulPrecision, G: ConvGemmConfig>(
88        tensor_reader: &Im2colReader<MP::EI>,
89        slice: &mut SliceMut<Line<MP::ES>>,
90        #[comptime] ident: Ident,
91        #[comptime] config: G,
92    ) {
93        let stage_tiling = config.tiling_dimensions(ident);
94        let line_size = config.global_line_size(ident);
95
96        let num_stage_elements = stage_tiling.total_size();
97        let total_units = comptime!(config.num_planes() * config.plane_dim());
98        let jump_length = comptime!(total_units * line_size);
99        let num_loads_per_unit = num_stage_elements / jump_length;
100
101        #[allow(clippy::all)]
102        let _ = comptime!(check_jump_divides_well(num_stage_elements, jump_length));
103
104        let unit_id = UNIT_POS_Y * config.plane_dim() + UNIT_POS_X;
105        let unit_position_base = unit_id * line_size;
106
107        for i in 0..num_loads_per_unit {
108            let unit_position = unit_position_base + i * jump_length;
109
110            let tile_num_elements = stage_tiling.tile_size();
111            let nth_tile = unit_position / tile_num_elements;
112            let pos_within_tile = unit_position % tile_num_elements;
113
114            let (tile_x, tile_y) = ContiguousTilingLayout::<RowMajorTilingOrder>::to_x_y::<
115                G::SmmConfig,
116            >(nth_tile, ident, config.to_smm_config());
117
118            let line_read =
119                tensor_reader.load_simple::<G>(tile_x, tile_y, pos_within_tile, ident, config);
120
121            slice[unit_position / line_size] = Line::cast_from(line_read);
122        }
123    }
124}
125
/// Asserts that a stage of `num_stage_elements` elements can be loaded in a whole number of
/// jumps of `jump_length` elements.
///
/// Callers derive the per-unit load count as `num_stage_elements / jump_length`. If the
/// division is not exact, that count would not cover the stage exactly.
///
/// # Panics
///
/// Panics if `jump_length` is zero or does not evenly divide `num_stage_elements`.
pub fn check_jump_divides_well(num_stage_elements: u32, jump_length: u32) {
    // Guard explicitly so a zero jump yields a readable message instead of a
    // remainder-by-zero panic.
    assert!(
        jump_length != 0,
        "jump_length must be non-zero; check the line size, number of planes and plane dimension."
    );
    assert!(
        num_stage_elements % jump_length == 0,
        "Stage cannot be loaded evenly: jump_length ({jump_length}) does not divide \
         num_stage_elements ({num_stage_elements}). Try setting line size and number of planes \
         so that jump_length divides num_stage_elements."
    );
}
132}