cubecl_matmul/components/global/read/strategy/async_full_cyclic.rs

1use std::marker::PhantomData;
2
3use crate::components::{
4    InvalidConfigError, MatmulElems, MatrixLayout,
5    global::{
6        GlobalReaderConfig, RoleRule,
7        memory::{GlobalIterator, load_window_in_tile},
8        multi_stage::LoadMaxRoundPlaneCount,
9        read::{
10            FullLoadingStrategy, LoadingJob, async_barrier::AsyncBarrier, validate_async_barrier,
11            validate_noswizzle,
12        },
13    },
14    stage::{
15        ContiguousTilingLayout, StridedStageFamily, StridedStageMemory, TilingOrder,
16        TilingValidation,
17    },
18};
19use cubecl_core::prelude::{barrier::Barrier, *};
20use cubecl_core::{self as cubecl};
21
22use super::LoadingValidation;
23
24#[derive(CubeType, Clone, Copy)]
25/// Loads the content of all tiles in the stage memory using all planes,
26/// iterating with steps determined by the plane's dimension.
27pub struct AsyncFullCyclicLoading<T: TilingOrder> {
28    #[cube(comptime)]
29    _phantom: PhantomData<T>,
30}
31
32impl<T: TilingOrder> LoadingValidation for AsyncFullCyclicLoading<T> {
33    fn check<R: Runtime>(
34        client: &ComputeClient<R::Server>,
35        config: &GlobalReaderConfig,
36        _dtypes: &MatmulElems,
37    ) -> Result<(), InvalidConfigError> {
38        let total_units = config.loading_planes_count() * config.plane_dim;
39        let num_slices =
40            config.smem_config.elements_per_tile_along_row * config.smem_config.tiles_per_stage();
41
42        if num_slices >= total_units && !num_slices.is_multiple_of(total_units) {
43            return Err(Box::new(format!(
44                "Number of units ({total_units:?}) must divide number of slices ({num_slices:?}). Would require units doing different numbers of slices"
45            )));
46        }
47
48        ContiguousTilingLayout::<T>::check(config.smem_config)?;
49        validate_async_barrier::<R>(client)?;
50        validate_noswizzle(config.smem_config)?;
51
52        Ok(())
53    }
54}
55
56impl<TO: TilingOrder> LoadMaxRoundPlaneCount for AsyncFullCyclicLoading<TO> {
57    fn max_round_plane_count(
58        _elements_per_tile: u32,
59        _tiles_per_stage: u32,
60        _line_size: u8,
61        _plane_dim: u32,
62    ) -> u32 {
63        // Not sure what's ideal here, the current specialization isn't great anyways so can deal
64        // with it later
65        4
66    }
67}
68
69#[cube]
70impl<TO: TilingOrder> FullLoadingStrategy for AsyncFullCyclicLoading<TO> {
71    type TilingLayout = ContiguousTilingLayout<TO>;
72    type SyncStrategy = AsyncBarrier;
73    type Job<EG: Numeric, ES: Numeric> = AsyncFullCyclicJob;
74
75    const SHOULD_CLEAR: bool = true;
76
77    fn new_job<EG: Numeric, ES: Numeric>(
78        #[comptime] line_size: u32,
79        #[comptime] config: GlobalReaderConfig,
80    ) -> AsyncFullCyclicJob {
81        let total_units = config.loading_units_count();
82
83        let (num_slices_per_tile, slice_length_in_lines) = match config.gmem_config.matrix_layout {
84            MatrixLayout::RowMajor => (
85                config.smem_config.elements_per_tile_along_row,
86                config.smem_config.elements_per_tile_along_col / line_size,
87            ),
88            MatrixLayout::ColMajor => (
89                config.smem_config.elements_per_tile_along_col,
90                config.smem_config.elements_per_tile_along_row / line_size,
91            ),
92        };
93
94        let num_slices = comptime!(num_slices_per_tile * config.smem_config.tiles_per_stage());
95        let num_tasks_per_unit = num_slices.div_ceil(total_units);
96
97        let unit_id = RoleRule::new(config.plane_role_config.rule)
98            .load_index(config.specialization_tensor_config)
99            * config.plane_dim
100            + UNIT_POS_X;
101
102        AsyncFullCyclicJob {
103            unit_id,
104            num_tasks_per_unit,
105            total_units,
106            num_slices,
107            num_slices_per_tile,
108            slice_length_in_lines,
109            line_size,
110        }
111    }
112}
113
114#[derive(CubeType, Clone, Copy)]
115pub struct AsyncFullCyclicJob {
116    unit_id: u32,
117
118    #[cube(comptime)]
119    num_tasks_per_unit: u32,
120    #[cube(comptime)]
121    total_units: u32,
122    #[cube(comptime)]
123    num_slices: u32,
124    #[cube(comptime)]
125    num_slices_per_tile: u32,
126    #[cube(comptime)]
127    slice_length_in_lines: u32,
128    #[cube(comptime)]
129    line_size: u32,
130}
131
132#[cube]
133impl<EG: Numeric, ES: Numeric, TO: TilingOrder>
134    LoadingJob<EG, ES, ContiguousTilingLayout<TO>, AsyncBarrier> for AsyncFullCyclicJob
135{
136    type Stage = StridedStageFamily;
137
138    fn execute_task(
139        this: &mut Self,
140        #[comptime] task_id: u32,
141        global_iter: &GlobalIterator<Line<EG>>,
142        stage: &mut StridedStageMemory<ES, ContiguousTilingLayout<TO>>,
143        barrier: &mut Barrier,
144        #[comptime] config: GlobalReaderConfig,
145    ) {
146        let slice_index = this.unit_id + this.total_units * task_id;
147
148        let nth_tile = slice_index / this.num_slices_per_tile;
149        let (tile_x, tile_y) =
150            ContiguousTilingLayout::<TO>::to_x_y(nth_tile, comptime!(config.smem_config));
151        let nth_slice = slice_index % this.num_slices_per_tile;
152
153        // TODO make branching comptime conditional (using Reader Mode)
154        if slice_index < this.num_slices {
155            let window = load_window_in_tile(
156                &global_iter.view(),
157                (tile_x, tile_y),
158                nth_slice,
159                config.smem_config,
160                config.gmem_config,
161            );
162
163            // Where this unit writes source in the stage
164            let slice_destination_offset =
165                (nth_tile * this.num_slices_per_tile + nth_slice) * this.slice_length_in_lines;
166
167            // Make destination start at offset
168            let mut destination = stage.as_slice_mut(this.line_size).slice_mut(
169                slice_destination_offset,
170                slice_destination_offset + this.slice_length_in_lines,
171            );
172
173            barrier.memcpy_async(&window.try_cast_unchecked(), &mut destination);
174        }
175    }
176
177    fn task_count(this: &Self) -> comptime_type!(u32) {
178        this.num_tasks_per_unit
179    }
180}