cubecl_convolution/loader/
im2col.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use cubecl_matmul::components::global::load::LoaderMode;
5use cubecl_std::div_ceil;
6use cubecl_std::tensor::r#virtual::VirtualTensor;
7use std::marker::PhantomData;
8
9use crate::base::RuntimeArgs;
10use crate::{ConvGemmConfig, reader::im2col::Im2colReader};
11use cubecl_matmul::components::{
12    Ident, InputIdent, MatmulPrecision,
13    stage::{ContiguousTilingLayout, FullStageToTileReader, RowMajorTilingOrder, StageMemory},
14};
15
16/// Loader that translates matrix coordinates to input coordinates using the `im2col` algorithm
17#[derive(CubeType)]
18pub struct SimpleIm2colLoader<MP: MatmulPrecision, G: ConvGemmConfig> {
19    pub tensor_view: Im2colReader<MP::EI>,
20    pub stage: StageMemory<MP::ES, ContiguousTilingLayout<RowMajorTilingOrder>>,
21    #[cube(comptime)]
22    _config: PhantomData<G>,
23}
24
25#[cube]
26impl<MP: MatmulPrecision, G: ConvGemmConfig> SimpleIm2colLoader<MP, G> {
27    pub fn new(
28        tensor: VirtualTensor<MP::EI>,
29        x_offset: u32,
30        y_offset: u32,
31        runtime_args: &RuntimeArgs,
32        #[comptime] config: G,
33    ) -> Self {
34        let stage = StageMemory::new::<G::StageConfig>(1u32, Ident::Lhs, config.stage_config());
35
36        let shape_m = runtime_args.size_m;
37        let shape_k = runtime_args.size_k;
38
39        let tensor_view = Im2colReader::<MP::EI>::new(
40            tensor,
41            comptime![runtime_args.out_shape.clone()],
42            x_offset,
43            y_offset,
44            shape_k,
45            shape_m,
46        );
47
48        SimpleIm2colLoader::<MP, G> {
49            tensor_view,
50            stage,
51            _config: PhantomData::<G>,
52        }
53    }
54
55    pub fn advance_view(this: &mut Self, k_offset: u32) {
56        this.tensor_view.update_view(k_offset);
57    }
58
59    pub fn reader(
60        this: &Self,
61    ) -> FullStageToTileReader<MP::ES, ContiguousTilingLayout<RowMajorTilingOrder>> {
62        FullStageToTileReader::new(this.stage, InputIdent::Lhs)
63    }
64
65    pub fn fill_stage(this: &mut Self, #[comptime] config: G) {
66        let line_size = config.global_line_size(Ident::Lhs);
67        SimpleIm2col::load_to_slice::<MP, G>(
68            &this.tensor_view,
69            &mut this.stage.as_slice_mut(line_size),
70            Ident::Lhs,
71            config,
72        );
73    }
74}
75
76#[derive(CubeType, Clone, Copy)]
77/// Loads the content of all tiles in the tensor view using all planes,
78/// iterating with steps determined by the plane's dimension.
79pub struct SimpleIm2col;
80
81#[cube]
82impl SimpleIm2col {
83    pub fn load_to_slice<MP: MatmulPrecision, G: ConvGemmConfig>(
84        tensor_reader: &Im2colReader<MP::EI>,
85        slice: &mut SliceMut<Line<MP::ES>>,
86        #[comptime] ident: Ident,
87        #[comptime] config: G,
88    ) {
89        let line_size = config.global_line_size(ident);
90
91        let num_stage_elements = config.tiling_scheme().elements_in_stage(ident);
92        let total_units = comptime!(config.num_loading_planes(ident) * config.plane_dim());
93
94        let unit_id = UNIT_POS_Y * config.plane_dim() + UNIT_POS_X;
95        let unit_position_base = unit_id * line_size;
96
97        if let LoaderMode::Strict = config.loader_mode() {
98            let jump_length = comptime!(total_units * line_size);
99
100            comptime! {
101                            assert!(
102                num_stage_elements % jump_length == 0,
103                "Too many data will be loaded, resulting in out of bounds.
104            Try setting line size and number of planes so that jump_length divides num_stage_elements."
105            );
106                    }
107
108            let num_loads_per_unit = num_stage_elements / jump_length;
109
110            for i in 0..num_loads_per_unit {
111                let unit_position = unit_position_base + i * jump_length;
112
113                load_at_position::<MP, G>(tensor_reader, slice, unit_position, ident, config);
114            }
115        } else {
116            let jump_length = comptime!(total_units * line_size);
117            let num_loads_per_unit = div_ceil(num_stage_elements, jump_length);
118
119            for i in 0..num_loads_per_unit {
120                let unit_position = unit_position_base + i * jump_length;
121
122                if unit_position < num_stage_elements {
123                    load_at_position::<MP, G>(tensor_reader, slice, unit_position, ident, config);
124                }
125            }
126        }
127    }
128}
129
130#[cube]
131fn load_at_position<MP: MatmulPrecision, G: ConvGemmConfig>(
132    tensor_reader: &Im2colReader<MP::EI>,
133    slice: &mut SliceMut<Line<MP::ES>>,
134    unit_position: u32,
135    #[comptime] ident: Ident,
136    #[comptime] config: G,
137) {
138    let line_size = config.global_line_size(ident);
139    let tile_num_elements = config.tiling_scheme().elements_in_tile(ident);
140    let nth_tile = unit_position / tile_num_elements;
141    let pos_within_tile = unit_position % tile_num_elements;
142
143    let (tile_x, tile_y) = ContiguousTilingLayout::<RowMajorTilingOrder>::to_x_y::<G::StageConfig>(
144        nth_tile,
145        ident,
146        config.stage_config(),
147    );
148
149    let line_read = tensor_reader.load_simple::<G>(tile_x, tile_y, pos_within_tile, ident, config);
150
151    slice[unit_position / line_size] = Line::cast_from(line_read);
152}