cubecl_linalg/matmul/components/global/load/strategy/sync_buffer_cyclic.rs

1use std::marker::PhantomData;
2
3use crate::matmul::components::global::load::SyncBufferLoadingStrategy;
4use crate::matmul::components::global::tensor_view::TensorReader;
5use crate::matmul::components::global::{GlobalConfig, LoadingValidation, Quantization};
6use crate::matmul::components::stage::{ContiguousTilingLayout, Stage, TilingOrder};
7use crate::matmul::components::{Ident, InputIdent, InvalidConfigError, MatmulPrecision};
8use cubecl_core as cubecl;
9use cubecl_core::prelude::*;
10use cubecl_std::{CubeOption, CubeOptionExpand};
11
12use super::LoadingJob;
13
14#[derive(CubeType, Clone, Copy)]
15/// Loads the content of all tiles in the tensor view using all planes,
16/// iterating with steps determined by the plane's dimension.
17pub struct LoadingStrategy<T: TilingOrder> {
18    #[cube(comptime)]
19    tiling_order: PhantomData<T>,
20}
21
22impl<TO: TilingOrder> LoadingValidation for LoadingStrategy<TO> {
23    fn check<C: GlobalConfig>(config: &C, ident: Ident) -> Result<(), InvalidConfigError> {
24        let tiling_dimensions = config.tiling_dimensions(ident);
25        let line_size = config.global_line_size(ident);
26        let tile_size = tiling_dimensions.tile_size();
27        let tile_count_row = tiling_dimensions.tile_count_row();
28        let tile_count_col = tiling_dimensions.tile_count_col();
29        let num_lines_per_tile = tile_size / line_size;
30
31        let total_units = config.plane_dim() * config.num_planes();
32        let jump_length = total_units * line_size;
33        let num_tiles_in_buffer = comptime! {match ident.as_input_ident() {
34            InputIdent::Lhs => tile_count_row,
35            InputIdent::Rhs => tile_count_col,
36        }};
37        let total_num_lines = num_tiles_in_buffer * num_lines_per_tile;
38        let num_lines_per_unit = (total_num_lines + total_units - 1) / total_units;
39
40        let total_num_lines = num_tiles_in_buffer * num_lines_per_tile;
41        let out_of_bounds_pos = total_num_lines * line_size;
42
43        let max_id = total_units - 1;
44        let max_iter = num_lines_per_unit - 1;
45        let max_position_base = max_id * line_size;
46        let max_position = max_position_base + max_iter * jump_length;
47
48        if max_position > out_of_bounds_pos {
49            return Err(Box::new(
50                "Too many data will be loaded, resulting in out-of-bounds",
51            ));
52        }
53
54        Ok(())
55    }
56}
57
58#[cube]
59impl<TO: TilingOrder> SyncBufferLoadingStrategy for LoadingStrategy<TO> {
60    type TilingLayout = ContiguousTilingLayout<TO>;
61    type Job<MP: MatmulPrecision> = Job;
62
63    fn new_job<MP: MatmulPrecision, G: GlobalConfig>(
64        #[comptime] buffer_index: u32,
65        #[comptime] input_ident: InputIdent,
66        #[comptime] config: G,
67    ) -> Job {
68        let tiling_dimensions = config.tiling_dimensions(input_ident);
69        let line_size = config.global_line_size(input_ident);
70        let tile_size = tiling_dimensions.tile_size();
71        let tile_count_row = tiling_dimensions.tile_count_row();
72        let tile_count_col = tiling_dimensions.tile_count_col();
73
74        let num_lines_per_tile = tile_size / line_size;
75        let total_units = config.plane_dim() * config.num_planes();
76        let jump_length = total_units * line_size;
77
78        let num_tiles_in_buffer = comptime! {match input_ident {
79            InputIdent::Lhs => tile_count_row,
80            InputIdent::Rhs => tile_count_col,
81        }};
82        let total_num_lines = num_tiles_in_buffer * num_lines_per_tile;
83        let num_tasks = (total_num_lines + total_units - 1) / total_units;
84
85        let unit_id = UNIT_POS_Y * config.plane_dim() + UNIT_POS_X;
86        let unit_position_base = unit_id * line_size;
87
88        Job {
89            unit_position_base,
90            num_tasks,
91            buffer_index,
92            jump_length,
93            num_lines_per_tile,
94            input_ident,
95        }
96    }
97}
98
99#[derive(CubeType, Clone, Copy)]
100pub struct Job {
101    unit_position_base: u32,
102
103    #[cube(comptime)]
104    num_tasks: u32,
105    #[cube(comptime)]
106    buffer_index: u32,
107    #[cube(comptime)]
108    jump_length: u32,
109    #[cube(comptime)]
110    num_lines_per_tile: u32,
111    #[cube(comptime)]
112    input_ident: InputIdent,
113}
114
#[cube]
impl<MP: MatmulPrecision, TO: TilingOrder> LoadingJob<MP, ContiguousTilingLayout<TO>> for Job {
    /// Loads the `task_id`-th line assigned to this unit from global memory into the stage.
    ///
    /// The line's element offset within the buffer is
    /// `unit_position_base + task_id * jump_length`. That offset is split into a tile index
    /// within the buffer and an offset within that tile. The line is read through
    /// `TensorReader::load_coalesced_in_tile`, then written into the matching tile of
    /// `stage`. It is dequantized when `quantization` is `Some` and cast otherwise.
    fn execute_task<G: GlobalConfig>(
        this: &mut Self,
        task_id: u32,
        tensor_reader: &TensorReader<MP::EI>,
        stage: &mut Stage<MP::ES, ContiguousTilingLayout<TO>>,
        quantization: &CubeOption<Quantization<MP>>,
        #[comptime] config: G,
    ) {
        // Compile-time tiling geometry for this input.
        let (line_size, tile_size, tile_count_row, tile_count_col) = comptime! {
            let tiling_dimensions = config.tiling_dimensions(this.input_ident);
            (
                config.global_line_size(this.input_ident),
                tiling_dimensions.tile_size(),
                tiling_dimensions.tile_count_row(),
                tiling_dimensions.tile_count_col()
            )
        };

        // Element offset of this task's line within the buffer.
        let unit_position = this.unit_position_base + task_id * this.jump_length;

        // We assume unit_position < total_num_lines * line_size;
        // LoadingValidation::check for LoadingStrategy is meant to reject configs
        // that violate this.

        // Index of the tile within the buffer, and element offset inside that tile.
        let unit_pos_in_buffer = unit_position / tile_size;
        let pos_within_tile = unit_position % tile_size;

        // The buffer index fixes tile_y for Lhs and tile_x for Rhs; the other coordinate
        // walks along the buffer. NOTE(review): `.runtime()` presumably lifts the comptime
        // buffer_index to a runtime value so both tuple arms have the same type — confirm.
        let (tile_x, tile_y) = match comptime!(this.input_ident) {
            InputIdent::Lhs => (unit_pos_in_buffer, this.buffer_index.runtime()),
            InputIdent::Rhs => (this.buffer_index.runtime(), unit_pos_in_buffer),
        };

        // Slot of that tile in the stage, as defined by the tiling order `TO`.
        let nth_tile = TO::to_nth_tile::<G::SmmConfig>(
            tile_x,
            tile_y,
            tile_count_row,
            tile_count_col,
            comptime!(this.input_ident.as_ident()),
            config.to_smm_config(),
        );

        let line_read = tensor_reader.load_coalesced_in_tile::<G>(
            tile_x,
            tile_y,
            pos_within_tile,
            this.input_ident,
            config,
        );

        // In the line-sized view of the stage, the tile occupies lines
        // [nth_tile * num_lines_per_tile, (nth_tile + 1) * num_lines_per_tile).
        let tile_start = nth_tile * this.num_lines_per_tile;
        let tile_end = tile_start + this.num_lines_per_tile;
        let mut tile_slice = stage
            .as_slice_mut(line_size)
            .slice_mut(tile_start, tile_end);

        // Store at the line index within the tile, converting to the stage element type.
        tile_slice[pos_within_tile / line_size] = match quantization {
            CubeOption::Some(quantization) => quantization.dequantize(line_read, this.input_ident),
            CubeOption::None => Line::cast_from(line_read),
        }
    }

    /// Number of tasks (one line each) this unit executes for the job.
    fn task_count(this: &Self) -> comptime_type!(u32) {
        this.num_tasks
    }
}