cubecl_linalg/matmul/components/global/load/strategy/
async_full_cyclic.rs1use std::marker::PhantomData;
2
3use crate::matmul::components::{
4 Ident, InputIdent, InvalidConfigError, MatmulPrecision, MatrixLayout,
5 global::{
6 CopyMechanism, GlobalConfig, LoadingValidation, load::AsyncFullLoadingStrategy,
7 tensor_view::TensorReader,
8 },
9 stage::{ContiguousTilingLayout, Stage, TilingOrder},
10};
11use cubecl_core::prelude::*;
12use cubecl_core::{self as cubecl, prelude::barrier::BarrierLevel};
13
14use super::AsyncLoadingJob;
15
16#[derive(CubeType, Clone, Copy)]
17pub struct LoadingStrategy<T: TilingOrder> {
20 #[cube(comptime)]
21 _phantom: PhantomData<T>,
22}
23
24impl<T: TilingOrder> LoadingValidation for LoadingStrategy<T> {
25 fn check<C: GlobalConfig>(config: &C, ident: Ident) -> Result<(), InvalidConfigError> {
26 let tiling = config.tiling_dimensions(ident);
27 let total_units = config.num_planes() * config.plane_dim();
28
29 let num_slices = tiling.tile_shape_row() * tiling.tile_count();
30 if num_slices >= total_units && num_slices % total_units != 0 {
31 return Err(Box::new(format!(
32 "Number of units ({total_units:?}) must divide number of slices ({num_slices:?}). Would require units doing different numbers of slices"
33 )));
34 }
35
36 Ok(())
37 }
38}
39
40#[cube]
41impl<TO: TilingOrder> AsyncFullLoadingStrategy for LoadingStrategy<TO> {
42 type TilingLayout = ContiguousTilingLayout<TO>;
43 type Job<MP: MatmulPrecision> = Job;
44
45 fn new_job<MP: MatmulPrecision, G: GlobalConfig>(
46 #[comptime] input_ident: InputIdent,
47 #[comptime] config: G,
48 ) -> Job {
49 let stage_dim = config.tiling_dimensions(input_ident);
50 let total_units = config.plane_dim() * config.num_planes();
51 let line_size = config.global_line_size(input_ident);
52
53 let (num_slices_per_tile, slice_length_in_lines) = match config.matrix_layout(input_ident) {
54 MatrixLayout::RowMajor => (
55 stage_dim.tile_shape_row(),
56 stage_dim.tile_shape_col() / line_size,
57 ),
58 MatrixLayout::ColMajor => (
59 stage_dim.tile_shape_col(),
60 stage_dim.tile_shape_row() / line_size,
61 ),
62 };
63
64 let num_slices = comptime!(num_slices_per_tile * stage_dim.tile_count());
65 let num_tasks_per_unit = num_slices.div_ceil(total_units);
66
67 let unit_id = UNIT_POS_Y * config.plane_dim() + UNIT_POS_X;
68
69 Job {
70 unit_id,
71 num_tasks_per_unit,
72 total_units,
73 num_slices,
74 input_ident,
75 num_slices_per_tile,
76 slice_length_in_lines,
77 line_size,
78 }
79 }
80
81 fn barrier_level() -> BarrierLevel {
82 BarrierLevel::cube_manual(0u32)
83 }
84}
85
86#[derive(CubeType, Clone, Copy)]
87pub struct Job {
88 unit_id: u32,
89
90 #[cube(comptime)]
91 num_tasks_per_unit: u32,
92 #[cube(comptime)]
93 total_units: u32,
94 #[cube(comptime)]
95 num_slices: u32,
96 #[cube(comptime)]
97 input_ident: InputIdent,
98 #[cube(comptime)]
99 num_slices_per_tile: u32,
100 #[cube(comptime)]
101 slice_length_in_lines: u32,
102 #[cube(comptime)]
103 line_size: u32,
104}
105
106#[cube]
107impl<MP: MatmulPrecision, TO: TilingOrder> AsyncLoadingJob<MP, ContiguousTilingLayout<TO>> for Job {
108 fn execute_task<CM: CopyMechanism<MP::ES>, G: GlobalConfig>(
109 this: &mut Self,
110 task_id: u32,
111 tensor_reader: &TensorReader<MP::EI>,
112 stage: &mut Stage<MP::ES, ContiguousTilingLayout<TO>>,
113 mechanism: &CM,
114 #[comptime] config: G,
115 ) {
116 let slice_index = this.unit_id + this.total_units * task_id;
117
118 let nth_tile = slice_index / this.num_slices_per_tile;
119 let (tile_x, tile_y) = ContiguousTilingLayout::<TO>::to_x_y::<G::SmmConfig>(
120 nth_tile,
121 comptime!(this.input_ident.as_ident()),
122 config.to_smm_config(),
123 );
124 let nth_slice = slice_index % this.num_slices_per_tile;
125
126 if slice_index < this.num_slices {
128 let window = tensor_reader.load_window_in_tile::<G>(
129 (tile_x, tile_y),
130 nth_slice,
131 this.input_ident,
132 config,
133 );
134
135 let slice_destination_offset =
137 (nth_tile * this.num_slices_per_tile + nth_slice) * this.slice_length_in_lines;
138
139 let mut destination = stage.as_slice_mut(this.line_size).slice_mut(
141 slice_destination_offset,
142 slice_destination_offset + this.slice_length_in_lines,
143 );
144
145 CM::memcpy_async(
146 mechanism,
147 &window.slice.try_cast_unchecked(),
148 &mut destination,
149 );
150 }
151 }
152
153 fn task_count(this: &Self) -> comptime_type!(u32) {
154 this.num_tasks_per_unit
155 }
156}