cubecl_linalg/matmul/components/global/load/strategy/
async_full_cyclic.rs

1use std::marker::PhantomData;
2
3use crate::matmul::components::{
4    Ident, InputIdent, InvalidConfigError, MatmulPrecision, MatrixLayout,
5    global::{
6        CopyMechanism, GlobalConfig, LoadingValidation, load::AsyncFullLoadingStrategy,
7        tensor_view::TensorReader,
8    },
9    stage::{ContiguousTilingLayout, Stage, TilingOrder},
10};
11use cubecl_core::prelude::*;
12use cubecl_core::{self as cubecl, prelude::barrier::BarrierLevel};
13
14use super::AsyncLoadingJob;
15
16#[derive(CubeType, Clone, Copy)]
17/// Loads the content of all tiles in the tensor view using all planes,
18/// iterating with steps determined by the plane's dimension.
19pub struct LoadingStrategy<T: TilingOrder> {
20    #[cube(comptime)]
21    _phantom: PhantomData<T>,
22}
23
24impl<T: TilingOrder> LoadingValidation for LoadingStrategy<T> {
25    fn check<C: GlobalConfig>(config: &C, ident: Ident) -> Result<(), InvalidConfigError> {
26        let tiling = config.tiling_dimensions(ident);
27        let total_units = config.num_planes() * config.plane_dim();
28
29        let num_slices = tiling.tile_shape_row() * tiling.tile_count();
30        if num_slices >= total_units && num_slices % total_units != 0 {
31            return Err(Box::new(format!(
32                "Number of units ({total_units:?}) must divide number of slices ({num_slices:?}). Would require units doing different numbers of slices"
33            )));
34        }
35
36        Ok(())
37    }
38}
39
40#[cube]
41impl<TO: TilingOrder> AsyncFullLoadingStrategy for LoadingStrategy<TO> {
42    type TilingLayout = ContiguousTilingLayout<TO>;
43    type Job<MP: MatmulPrecision> = Job;
44
45    fn new_job<MP: MatmulPrecision, G: GlobalConfig>(
46        #[comptime] input_ident: InputIdent,
47        #[comptime] config: G,
48    ) -> Job {
49        let stage_dim = config.tiling_dimensions(input_ident);
50        let total_units = config.plane_dim() * config.num_planes();
51        let line_size = config.global_line_size(input_ident);
52
53        let (num_slices_per_tile, slice_length_in_lines) = match config.matrix_layout(input_ident) {
54            MatrixLayout::RowMajor => (
55                stage_dim.tile_shape_row(),
56                stage_dim.tile_shape_col() / line_size,
57            ),
58            MatrixLayout::ColMajor => (
59                stage_dim.tile_shape_col(),
60                stage_dim.tile_shape_row() / line_size,
61            ),
62        };
63
64        let num_slices = comptime!(num_slices_per_tile * stage_dim.tile_count());
65        let num_tasks_per_unit = num_slices.div_ceil(total_units);
66
67        let unit_id = UNIT_POS_Y * config.plane_dim() + UNIT_POS_X;
68
69        Job {
70            unit_id,
71            num_tasks_per_unit,
72            total_units,
73            num_slices,
74            input_ident,
75            num_slices_per_tile,
76            slice_length_in_lines,
77            line_size,
78        }
79    }
80
81    fn barrier_level() -> BarrierLevel {
82        BarrierLevel::cube_manual(0u32)
83    }
84}
85
86#[derive(CubeType, Clone, Copy)]
87pub struct Job {
88    unit_id: u32,
89
90    #[cube(comptime)]
91    num_tasks_per_unit: u32,
92    #[cube(comptime)]
93    total_units: u32,
94    #[cube(comptime)]
95    num_slices: u32,
96    #[cube(comptime)]
97    input_ident: InputIdent,
98    #[cube(comptime)]
99    num_slices_per_tile: u32,
100    #[cube(comptime)]
101    slice_length_in_lines: u32,
102    #[cube(comptime)]
103    line_size: u32,
104}
105
106#[cube]
107impl<MP: MatmulPrecision, TO: TilingOrder> AsyncLoadingJob<MP, ContiguousTilingLayout<TO>> for Job {
108    fn execute_task<CM: CopyMechanism<MP::ES>, G: GlobalConfig>(
109        this: &mut Self,
110        task_id: u32,
111        tensor_reader: &TensorReader<MP::EI>,
112        stage: &mut Stage<MP::ES, ContiguousTilingLayout<TO>>,
113        mechanism: &CM,
114        #[comptime] config: G,
115    ) {
116        let slice_index = this.unit_id + this.total_units * task_id;
117
118        let nth_tile = slice_index / this.num_slices_per_tile;
119        let (tile_x, tile_y) = ContiguousTilingLayout::<TO>::to_x_y::<G::SmmConfig>(
120            nth_tile,
121            comptime!(this.input_ident.as_ident()),
122            config.to_smm_config(),
123        );
124        let nth_slice = slice_index % this.num_slices_per_tile;
125
126        // TODO make branching comptime conditional (using balanced_workload)
127        if slice_index < this.num_slices {
128            let window = tensor_reader.load_window_in_tile::<G>(
129                (tile_x, tile_y),
130                nth_slice,
131                this.input_ident,
132                config,
133            );
134
135            // Where this unit writes source in the stage
136            let slice_destination_offset =
137                (nth_tile * this.num_slices_per_tile + nth_slice) * this.slice_length_in_lines;
138
139            // Make destination start at offset
140            let mut destination = stage.as_slice_mut(this.line_size).slice_mut(
141                slice_destination_offset,
142                slice_destination_offset + this.slice_length_in_lines,
143            );
144
145            CM::memcpy_async(
146                mechanism,
147                &window.slice.try_cast_unchecked(),
148                &mut destination,
149            );
150        }
151    }
152
153    fn task_count(this: &Self) -> comptime_type!(u32) {
154        this.num_tasks_per_unit
155    }
156}