cubecl_matmul/components/global/read/strategy/
async_full_cyclic.rs1use std::marker::PhantomData;
2
3use crate::components::{
4 InvalidConfigError, MatmulIdent, MatrixLayout, MatrixPrecision,
5 global::{
6 CopyMechanism, GlobalConfig, RoleRule,
7 memory::{GlobalIterator, load_window_in_tile},
8 read::AsyncFullLoadingStrategy,
9 },
10 stage::{ContiguousTilingLayout, StridedStage, TilingOrder, TilingValidation},
11};
12use cubecl_core::prelude::*;
13use cubecl_core::{self as cubecl, prelude::barrier::BarrierLevel};
14
15use super::{AsyncLoadingJob, LoadingValidation};
16
/// Loading strategy that fills the whole stage using every loading unit.
///
/// The stage is cut into slices (one per tile row for row-major input, one per tile
/// column for col-major input), and the slices are dealt out cyclically: unit `u`
/// copies slices `u`, `u + total_units`, `u + 2 * total_units`, ...
#[derive(CubeType, Clone, Copy)]
pub struct AsyncFullCyclicLoading<T: TilingOrder> {
    /// Only carries the tiling order type; no data is stored.
    #[cube(comptime)]
    _phantom: PhantomData<T>,
}
24
25impl<T: TilingOrder> LoadingValidation for AsyncFullCyclicLoading<T> {
26 fn check<C: GlobalConfig>(config: &C, ident: MatmulIdent) -> Result<(), InvalidConfigError> {
27 let total_units = config.num_loading_planes(ident) * config.plane_dim();
28 let num_slices = config.tiling_scheme().elements_in_tile_row(ident)
29 * config.tiling_scheme().tiles_in_stage(ident);
30
31 if num_slices >= total_units && !num_slices.is_multiple_of(total_units) {
32 return Err(Box::new(format!(
33 "Number of units ({total_units:?}) must divide number of slices ({num_slices:?}). Would require units doing different numbers of slices"
34 )));
35 }
36
37 ContiguousTilingLayout::<T>::check(config.global_memory_config(ident))?;
38
39 Ok(())
40 }
41}
42
#[cube]
impl<TO: TilingOrder> AsyncFullLoadingStrategy for AsyncFullCyclicLoading<TO> {
    // The stage is laid out as contiguous tiles ordered by `TO`.
    type TilingLayout = ContiguousTilingLayout<TO>;
    // The job does not depend on the precision, so every `IP` maps to the same job type.
    type Job<IP: MatrixPrecision> = AsyncFullCyclicJob;

    /// Builds the job this unit runs to copy its share of the stage for `ident`.
    ///
    /// The stage is cut into slices (one per tile row for row-major input, one per tile
    /// column for col-major input), and each of the `plane_dim * num_loading_planes`
    /// loading units is given `ceil(num_slices / total_units)` tasks.
    fn new_job<IP: MatrixPrecision, G: GlobalConfig>(
        #[comptime] ident: MatmulIdent,
        #[comptime] config: G,
    ) -> AsyncFullCyclicJob {
        let total_units = config.plane_dim() * config.num_loading_planes(ident);
        let line_size = config.global_line_size(ident);

        // Slice count per tile and slice length (in lines) follow the memory layout, so
        // a slice runs along the contiguous dimension of the matrix.
        // NOTE(review): `/ line_size` truncates; assumes the tile extent is a multiple of
        // the global line size — confirm this is validated elsewhere.
        let (num_slices_per_tile, slice_length_in_lines) = match config.matrix_layout(ident) {
            MatrixLayout::RowMajor => (
                config.tiling_scheme().elements_in_tile_row(ident),
                config.tiling_scheme().elements_in_tile_col(ident) / line_size,
            ),
            MatrixLayout::ColMajor => (
                config.tiling_scheme().elements_in_tile_col(ident),
                config.tiling_scheme().elements_in_tile_row(ident) / line_size,
            ),
        };

        let num_slices =
            comptime!(num_slices_per_tile * config.tiling_scheme().tiles_in_stage(ident));
        // Rounded up: when `total_units` does not divide `num_slices`, trailing tasks land
        // past the last slice and are skipped by the bounds check in `execute_task`.
        let num_tasks_per_unit = num_slices.div_ceil(total_units);

        // Index of this unit among all loading units: the role rule's load index
        // (presumably this plane's rank among loading planes), times the plane width,
        // plus the unit's position along x.
        // NOTE(review): assumes UNIT_POS_X spans 0..plane_dim — confirm with the launch cube dim.
        let unit_id = RoleRule::new(config.role_rule_config())
            .load_index(ident, config.specialized_loading_sides())
            * config.plane_dim()
            + UNIT_POS_X;

        AsyncFullCyclicJob {
            unit_id,
            num_tasks_per_unit,
            total_units,
            num_slices,
            ident,
            num_slices_per_tile,
            slice_length_in_lines,
            line_size,
        }
    }

    /// Barrier level requested by this strategy.
    // NOTE(review): a manually managed cube-level barrier; the `0u32` argument presumably
    // designates unit 0 — confirm against `BarrierLevel::cube_manual`.
    fn barrier_level() -> BarrierLevel {
        BarrierLevel::cube_manual(0u32)
    }
}
91
/// Per-unit job built by `AsyncFullCyclicLoading::new_job`.
///
/// Task `t` of this unit copies stage slice `unit_id + total_units * t`, if that slice exists.
#[derive(CubeType, Clone, Copy)]
pub struct AsyncFullCyclicJob {
    /// Index of this unit among all loading units (the only runtime field).
    unit_id: u32,

    /// Tasks per unit: `ceil(num_slices / total_units)`.
    #[cube(comptime)]
    num_tasks_per_unit: u32,
    /// Number of loading units (`plane_dim * num_loading_planes`).
    #[cube(comptime)]
    total_units: u32,
    /// Total number of slices in the stage.
    #[cube(comptime)]
    num_slices: u32,
    /// Identifier of the matrix being loaded; used for the stage/global memory config lookups.
    #[cube(comptime)]
    ident: MatmulIdent,
    /// Slices per tile: tile rows (row-major) or tile columns (col-major).
    #[cube(comptime)]
    num_slices_per_tile: u32,
    /// Length of one slice, in lines of `line_size` elements.
    #[cube(comptime)]
    slice_length_in_lines: u32,
    /// Global line size; the stage is viewed as lines of this size when copying.
    #[cube(comptime)]
    line_size: u32,
}
111
112#[cube]
113impl<IP: MatrixPrecision, TO: TilingOrder> AsyncLoadingJob<IP, ContiguousTilingLayout<TO>>
114 for AsyncFullCyclicJob
115{
116 fn execute_task<CM: CopyMechanism, G: GlobalConfig>(
117 this: &mut Self,
118 task_id: u32,
119 global_iter: &GlobalIterator<Line<IP::Global>>,
120 stage: &mut StridedStage<IP::Stage, ContiguousTilingLayout<TO>>,
121 mechanism: &CM,
122 #[comptime] config: G,
123 ) {
124 let slice_index = this.unit_id + this.total_units * task_id;
125
126 let nth_tile = slice_index / this.num_slices_per_tile;
127 let (tile_x, tile_y) = ContiguousTilingLayout::<TO>::to_x_y(
128 nth_tile,
129 comptime!(config.stage_memory_config(this.ident)),
130 );
131 let nth_slice = slice_index % this.num_slices_per_tile;
132
133 if slice_index < this.num_slices {
135 let window = load_window_in_tile(
136 &global_iter.view(),
137 (tile_x, tile_y),
138 nth_slice,
139 comptime!(config.global_memory_config(this.ident)),
140 );
141
142 let slice_destination_offset =
144 (nth_tile * this.num_slices_per_tile + nth_slice) * this.slice_length_in_lines;
145
146 let mut destination = stage.as_slice_mut(this.line_size).slice_mut(
148 slice_destination_offset,
149 slice_destination_offset + this.slice_length_in_lines,
150 );
151
152 CM::memcpy_async(mechanism, &window.try_cast_unchecked(), &mut destination);
153 }
154 }
155
156 fn task_count(this: &Self) -> comptime_type!(u32) {
157 this.num_tasks_per_unit
158 }
159}