cubecl_matmul/components/global/read/strategy/
async_full_cyclic.rs

1use std::marker::PhantomData;
2
3use crate::components::{
4    InvalidConfigError, MatmulIdent, MatrixLayout, MatrixPrecision,
5    global::{
6        CopyMechanism, GlobalConfig, RoleRule,
7        memory::{GlobalIterator, load_window_in_tile},
8        read::AsyncFullLoadingStrategy,
9    },
10    stage::{ContiguousTilingLayout, StridedStage, TilingOrder, TilingValidation},
11};
12use cubecl_core::prelude::*;
13use cubecl_core::{self as cubecl, prelude::barrier::BarrierLevel};
14
15use super::{AsyncLoadingJob, LoadingValidation};
16
17#[derive(CubeType, Clone, Copy)]
18/// Loads the content of all tiles in the stage memory using all planes,
19/// iterating with steps determined by the plane's dimension.
20pub struct AsyncFullCyclicLoading<T: TilingOrder> {
21    #[cube(comptime)]
22    _phantom: PhantomData<T>,
23}
24
25impl<T: TilingOrder> LoadingValidation for AsyncFullCyclicLoading<T> {
26    fn check<C: GlobalConfig>(config: &C, ident: MatmulIdent) -> Result<(), InvalidConfigError> {
27        let total_units = config.num_loading_planes(ident) * config.plane_dim();
28        let num_slices = config.tiling_scheme().elements_in_tile_row(ident)
29            * config.tiling_scheme().tiles_in_stage(ident);
30
31        if num_slices >= total_units && !num_slices.is_multiple_of(total_units) {
32            return Err(Box::new(format!(
33                "Number of units ({total_units:?}) must divide number of slices ({num_slices:?}). Would require units doing different numbers of slices"
34            )));
35        }
36
37        ContiguousTilingLayout::<T>::check(config.global_memory_config(ident))?;
38
39        Ok(())
40    }
41}
42
43#[cube]
44impl<TO: TilingOrder> AsyncFullLoadingStrategy for AsyncFullCyclicLoading<TO> {
45    type TilingLayout = ContiguousTilingLayout<TO>;
46    type Job<IP: MatrixPrecision> = AsyncFullCyclicJob;
47
48    fn new_job<IP: MatrixPrecision, G: GlobalConfig>(
49        #[comptime] ident: MatmulIdent,
50        #[comptime] config: G,
51    ) -> AsyncFullCyclicJob {
52        let total_units = config.plane_dim() * config.num_loading_planes(ident);
53        let line_size = config.global_line_size(ident);
54
55        let (num_slices_per_tile, slice_length_in_lines) = match config.matrix_layout(ident) {
56            MatrixLayout::RowMajor => (
57                config.tiling_scheme().elements_in_tile_row(ident),
58                config.tiling_scheme().elements_in_tile_col(ident) / line_size,
59            ),
60            MatrixLayout::ColMajor => (
61                config.tiling_scheme().elements_in_tile_col(ident),
62                config.tiling_scheme().elements_in_tile_row(ident) / line_size,
63            ),
64        };
65
66        let num_slices =
67            comptime!(num_slices_per_tile * config.tiling_scheme().tiles_in_stage(ident));
68        let num_tasks_per_unit = num_slices.div_ceil(total_units);
69
70        let unit_id = RoleRule::new(config.role_rule_config())
71            .load_index(ident, config.specialized_loading_sides())
72            * config.plane_dim()
73            + UNIT_POS_X;
74
75        AsyncFullCyclicJob {
76            unit_id,
77            num_tasks_per_unit,
78            total_units,
79            num_slices,
80            ident,
81            num_slices_per_tile,
82            slice_length_in_lines,
83            line_size,
84        }
85    }
86
87    fn barrier_level() -> BarrierLevel {
88        BarrierLevel::cube_manual(0u32)
89    }
90}
91
92#[derive(CubeType, Clone, Copy)]
93pub struct AsyncFullCyclicJob {
94    unit_id: u32,
95
96    #[cube(comptime)]
97    num_tasks_per_unit: u32,
98    #[cube(comptime)]
99    total_units: u32,
100    #[cube(comptime)]
101    num_slices: u32,
102    #[cube(comptime)]
103    ident: MatmulIdent,
104    #[cube(comptime)]
105    num_slices_per_tile: u32,
106    #[cube(comptime)]
107    slice_length_in_lines: u32,
108    #[cube(comptime)]
109    line_size: u32,
110}
111
112#[cube]
113impl<IP: MatrixPrecision, TO: TilingOrder> AsyncLoadingJob<IP, ContiguousTilingLayout<TO>>
114    for AsyncFullCyclicJob
115{
116    fn execute_task<CM: CopyMechanism, G: GlobalConfig>(
117        this: &mut Self,
118        task_id: u32,
119        global_iter: &GlobalIterator<Line<IP::Global>>,
120        stage: &mut StridedStage<IP::Stage, ContiguousTilingLayout<TO>>,
121        mechanism: &CM,
122        #[comptime] config: G,
123    ) {
124        let slice_index = this.unit_id + this.total_units * task_id;
125
126        let nth_tile = slice_index / this.num_slices_per_tile;
127        let (tile_x, tile_y) = ContiguousTilingLayout::<TO>::to_x_y(
128            nth_tile,
129            comptime!(config.stage_memory_config(this.ident)),
130        );
131        let nth_slice = slice_index % this.num_slices_per_tile;
132
133        // TODO make branching comptime conditional (using Reader Mode)
134        if slice_index < this.num_slices {
135            let window = load_window_in_tile(
136                &global_iter.view(),
137                (tile_x, tile_y),
138                nth_slice,
139                comptime!(config.global_memory_config(this.ident)),
140            );
141
142            // Where this unit writes source in the stage
143            let slice_destination_offset =
144                (nth_tile * this.num_slices_per_tile + nth_slice) * this.slice_length_in_lines;
145
146            // Make destination start at offset
147            let mut destination = stage.as_slice_mut(this.line_size).slice_mut(
148                slice_destination_offset,
149                slice_destination_offset + this.slice_length_in_lines,
150            );
151
152            CM::memcpy_async(mechanism, &window.try_cast_unchecked(), &mut destination);
153        }
154    }
155
156    fn task_count(this: &Self) -> comptime_type!(u32) {
157        this.num_tasks_per_unit
158    }
159}