cubecl_linalg/matmul/components/global/load/strategy/
sync_buffer_cyclic.rs1use std::marker::PhantomData;
2
3use crate::matmul::components::global::load::SyncBufferLoadingStrategy;
4use crate::matmul::components::global::tensor_view::TensorReader;
5use crate::matmul::components::global::{GlobalConfig, LoadingValidation, Quantization};
6use crate::matmul::components::stage::{ContiguousTilingLayout, Stage, TilingOrder};
7use crate::matmul::components::{Ident, InputIdent, InvalidConfigError, MatmulPrecision};
8use cubecl_core as cubecl;
9use cubecl_core::prelude::*;
10use cubecl_std::{CubeOption, CubeOptionExpand};
11
12use super::LoadingJob;
13
14#[derive(CubeType, Clone, Copy)]
15pub struct LoadingStrategy<T: TilingOrder> {
18 #[cube(comptime)]
19 tiling_order: PhantomData<T>,
20}
21
22impl<TO: TilingOrder> LoadingValidation for LoadingStrategy<TO> {
23 fn check<C: GlobalConfig>(config: &C, ident: Ident) -> Result<(), InvalidConfigError> {
24 let tiling_dimensions = config.tiling_dimensions(ident);
25 let line_size = config.global_line_size(ident);
26 let tile_size = tiling_dimensions.tile_size();
27 let tile_count_row = tiling_dimensions.tile_count_row();
28 let tile_count_col = tiling_dimensions.tile_count_col();
29 let num_lines_per_tile = tile_size / line_size;
30
31 let total_units = config.plane_dim() * config.num_planes();
32 let jump_length = total_units * line_size;
33 let num_tiles_in_buffer = comptime! {match ident.as_input_ident() {
34 InputIdent::Lhs => tile_count_row,
35 InputIdent::Rhs => tile_count_col,
36 }};
37 let total_num_lines = num_tiles_in_buffer * num_lines_per_tile;
38 let num_lines_per_unit = (total_num_lines + total_units - 1) / total_units;
39
40 let total_num_lines = num_tiles_in_buffer * num_lines_per_tile;
41 let out_of_bounds_pos = total_num_lines * line_size;
42
43 let max_id = total_units - 1;
44 let max_iter = num_lines_per_unit - 1;
45 let max_position_base = max_id * line_size;
46 let max_position = max_position_base + max_iter * jump_length;
47
48 if max_position > out_of_bounds_pos {
49 return Err(Box::new(
50 "Too many data will be loaded, resulting in out-of-bounds",
51 ));
52 }
53
54 Ok(())
55 }
56}
57
58#[cube]
59impl<TO: TilingOrder> SyncBufferLoadingStrategy for LoadingStrategy<TO> {
60 type TilingLayout = ContiguousTilingLayout<TO>;
61 type Job<MP: MatmulPrecision> = Job;
62
63 fn new_job<MP: MatmulPrecision, G: GlobalConfig>(
64 #[comptime] buffer_index: u32,
65 #[comptime] input_ident: InputIdent,
66 #[comptime] config: G,
67 ) -> Job {
68 let tiling_dimensions = config.tiling_dimensions(input_ident);
69 let line_size = config.global_line_size(input_ident);
70 let tile_size = tiling_dimensions.tile_size();
71 let tile_count_row = tiling_dimensions.tile_count_row();
72 let tile_count_col = tiling_dimensions.tile_count_col();
73
74 let num_lines_per_tile = tile_size / line_size;
75 let total_units = config.plane_dim() * config.num_planes();
76 let jump_length = total_units * line_size;
77
78 let num_tiles_in_buffer = comptime! {match input_ident {
79 InputIdent::Lhs => tile_count_row,
80 InputIdent::Rhs => tile_count_col,
81 }};
82 let total_num_lines = num_tiles_in_buffer * num_lines_per_tile;
83 let num_tasks = (total_num_lines + total_units - 1) / total_units;
84
85 let unit_id = UNIT_POS_Y * config.plane_dim() + UNIT_POS_X;
86 let unit_position_base = unit_id * line_size;
87
88 Job {
89 unit_position_base,
90 num_tasks,
91 buffer_index,
92 jump_length,
93 num_lines_per_tile,
94 input_ident,
95 }
96 }
97}
98
99#[derive(CubeType, Clone, Copy)]
100pub struct Job {
101 unit_position_base: u32,
102
103 #[cube(comptime)]
104 num_tasks: u32,
105 #[cube(comptime)]
106 buffer_index: u32,
107 #[cube(comptime)]
108 jump_length: u32,
109 #[cube(comptime)]
110 num_lines_per_tile: u32,
111 #[cube(comptime)]
112 input_ident: InputIdent,
113}
114
115#[cube]
116impl<MP: MatmulPrecision, TO: TilingOrder> LoadingJob<MP, ContiguousTilingLayout<TO>> for Job {
117 fn execute_task<G: GlobalConfig>(
118 this: &mut Self,
119 task_id: u32,
120 tensor_reader: &TensorReader<MP::EI>,
121 stage: &mut Stage<MP::ES, ContiguousTilingLayout<TO>>,
122 quantization: &CubeOption<Quantization<MP>>,
123 #[comptime] config: G,
124 ) {
125 let (line_size, tile_size, tile_count_row, tile_count_col) = comptime! {
126 let tiling_dimensions = config.tiling_dimensions(this.input_ident);
127 (
128 config.global_line_size(this.input_ident),
129 tiling_dimensions.tile_size(),
130 tiling_dimensions.tile_count_row(),
131 tiling_dimensions.tile_count_col()
132 )
133 };
134
135 let unit_position = this.unit_position_base + task_id * this.jump_length;
136
137 let unit_pos_in_buffer = unit_position / tile_size;
141 let pos_within_tile = unit_position % tile_size;
142
143 let (tile_x, tile_y) = match comptime!(this.input_ident) {
144 InputIdent::Lhs => (unit_pos_in_buffer, this.buffer_index.runtime()),
145 InputIdent::Rhs => (this.buffer_index.runtime(), unit_pos_in_buffer),
146 };
147
148 let nth_tile = TO::to_nth_tile::<G::SmmConfig>(
149 tile_x,
150 tile_y,
151 tile_count_row,
152 tile_count_col,
153 comptime!(this.input_ident.as_ident()),
154 config.to_smm_config(),
155 );
156
157 let line_read = tensor_reader.load_coalesced_in_tile::<G>(
158 tile_x,
159 tile_y,
160 pos_within_tile,
161 this.input_ident,
162 config,
163 );
164
165 let tile_start = nth_tile * this.num_lines_per_tile;
166 let tile_end = tile_start + this.num_lines_per_tile;
167 let mut tile_slice = stage
168 .as_slice_mut(line_size)
169 .slice_mut(tile_start, tile_end);
170
171 tile_slice[pos_within_tile / line_size] = match quantization {
172 CubeOption::Some(quantization) => quantization.dequantize(line_read, this.input_ident),
173 CubeOption::None => Line::cast_from(line_read),
174 }
175 }
176
177 fn task_count(this: &Self) -> comptime_type!(u32) {
178 this.num_tasks
179 }
180}