cubecl_convolution/components/global/read/reader/strategy/async_full_cyclic.rs

1use std::marker::PhantomData;
2
3use cubecl_core::prelude::*;
4use cubecl_core::{self as cubecl, prelude::barrier::Barrier};
5use cubecl_matmul::components::{
6    InvalidConfigError, MatmulElems, MatmulProblem,
7    global::{
8        GlobalReaderConfig, RoleRule,
9        memory::GlobalIterator,
10        multi_stage::LoadMaxRoundPlaneCount,
11        read::{
12            LoadingJob, LoadingValidation, ReaderMode, async_barrier::AsyncCopy,
13            async_full_cyclic::AsyncFullCyclicLoading as MatmulCyclicLoading, tiled::TiledLayout,
14        },
15    },
16    stage::{ContiguousTilingLayout, StridedStageFamily, StridedStageMemory, TilingOrder},
17};
18use cubecl_std::tensor::layout::{Layout, LayoutExpand};
19
20use crate::components::global::{
21    args::RuntimeArgs,
22    read::{
23        full_reader::FullLoadingStrategy,
24        strategy::async_copy::{ASYNC_COPY_WIDTH, async_copy_from},
25    },
26};
27
28#[derive(CubeType, Clone, Copy)]
29/// Loads the content of all tiles in the stage using all planes.
30/// Unit with pos X loads lines with indices X, X + NUM_UNITS, X + 2 * NUM_UNITS, ...
31pub struct AsyncFullCyclicLoading<T: TilingOrder> {
32    #[cube(comptime)]
33    _t: PhantomData<T>,
34}
35
36impl<TO: TilingOrder> LoadingValidation for AsyncFullCyclicLoading<TO> {
37    fn check<R: Runtime>(
38        client: &ComputeClient<R>,
39        problem: &MatmulProblem,
40        config: &GlobalReaderConfig,
41        dtypes: &MatmulElems,
42    ) -> Result<(), InvalidConfigError> {
43        MatmulCyclicLoading::<TO>::check(client, problem, config, dtypes)
44    }
45}
46
47impl<TO: TilingOrder> LoadMaxRoundPlaneCount for AsyncFullCyclicLoading<TO> {
48    fn max_round_plane_count(
49        elements_per_tile: u32,
50        tiles_per_stage: u32,
51        line_size: u8,
52        plane_dim: u32,
53        dtype: StorageType,
54    ) -> u32 {
55        MatmulCyclicLoading::<TO>::max_round_plane_count(
56            elements_per_tile,
57            tiles_per_stage,
58            line_size,
59            plane_dim,
60            dtype,
61        )
62    }
63}
64
#[cube]
impl<TO: TilingOrder> FullLoadingStrategy for AsyncFullCyclicLoading<TO> {
    type TilingLayout = ContiguousTilingLayout<TO>;
    type SyncStrategy = AsyncCopy;
    type Job<EG: Numeric, ES: Numeric> = AsyncFullCyclicJob;

    /// Builds the loading job for the calling unit.
    ///
    /// The `_line_size` argument is ignored. The copy width is derived from
    /// `ASYNC_COPY_WIDTH` and the bit width of `ES` instead.
    fn new_job<EG: Numeric, ES: Numeric>(
        runtime_args: RuntimeArgs,
        #[comptime] _line_size: u32,
        #[comptime] config: GlobalReaderConfig,
    ) -> Self::Job<EG, ES> {
        // Number of `ES` elements moved by one async copy. `type_size` is in bits, so
        // ASYNC_COPY_WIDTH is presumably also expressed in bits (TODO confirm).
        let type_size = ES::type_size_bits();
        let line_size = comptime![ASYNC_COPY_WIDTH / type_size];
        let tile_num_elements = config.smem_config.elements_per_tile();
        let num_stage_elements = config.smem_config.elements_per_stage();

        // Cut the stage into copy-sized lines and spread them over all loading units.
        let num_stage_lines = num_stage_elements.div_ceil(line_size);
        let total_units = config.loading_units_count();
        let num_tasks_per_unit = comptime!(num_stage_lines.div_ceil(total_units));
        // True when every unit gets the same number of lines, i.e. no task of any unit
        // points past the last stage line.
        let balanced_workload = comptime!(num_stage_lines.is_multiple_of(total_units));
        // Element distance between two consecutive tasks of the same unit.
        let jump_length = comptime!(total_units * line_size);

        // Index of this unit among all loading units: the plane's load index (from the
        // role rule) times the plane size, plus `UNIT_POS_X` (presumably the lane
        // within the plane).
        let unit_id = RoleRule::new(config.plane_role_config.rule)
            .load_index(config.specialization_tensor_config)
            * config.plane_dim
            + UNIT_POS_X;
        // Element offset of this unit's first line in the stage.
        let unit_position_base = unit_id * line_size;

        AsyncFullCyclicJob {
            unit_position_base,
            runtime_args,
            num_tasks_per_unit,
            tile_num_elements,
            jump_length,
            copy_line_size: line_size,
            balanced_workload,
            num_stage_elements,
            reader_mode: config.reader_mode,
        }
    }
}
106
/// Per-unit state of the async full cyclic loader. It is built by `new_job` and
/// consumed by `execute_task`.
///
/// All positions, strides and sizes below are counted in stage elements, not lines.
#[derive(CubeType, Clone)]
pub struct AsyncFullCyclicJob {
    /// Element offset in the stage of this unit's first copy (`unit_id * copy_line_size`).
    unit_position_base: u32,
    /// Runtime arguments forwarded to `async_copy_from`.
    runtime_args: RuntimeArgs,

    /// Number of copy tasks per unit: stage lines divided by loading units, rounded up.
    #[cube(comptime)]
    num_tasks_per_unit: u32,
    /// Elements per tile; used to split a stage position into (tile, offset in tile).
    #[cube(comptime)]
    tile_num_elements: u32,
    /// Distance between consecutive tasks of one unit (`loading units * copy_line_size`).
    #[cube(comptime)]
    jump_length: u32,
    /// Elements moved by one async copy (`ASYNC_COPY_WIDTH / ES bit width`).
    #[cube(comptime)]
    copy_line_size: u32,
    /// Whether stage lines divide evenly among loading units; if so no bounds check is
    /// emitted in `execute_task`.
    #[cube(comptime)]
    balanced_workload: bool,
    /// Total elements in the stage; the bound checked for unbalanced, non-`Strict` loads.
    #[cube(comptime)]
    num_stage_elements: u32,
    /// Reader mode from the config; `ReaderMode::Strict` also skips the bounds check.
    #[cube(comptime)]
    reader_mode: ReaderMode,
}
127
#[cube]
impl<EG: Numeric, ES: Numeric, TO: TilingOrder>
    LoadingJob<EG, ES, ContiguousTilingLayout<TO>, AsyncCopy> for AsyncFullCyclicJob
{
    type Stage = StridedStageFamily;

    /// Issues this unit's `task_id`-th async line copy into the stage.
    ///
    /// The barrier argument is not used by this job.
    fn execute_task(
        this: &mut Self,
        #[comptime] task_id: u32,
        global_iter: &GlobalIterator<Line<EG>>,
        stage: &mut StridedStageMemory<ES, ContiguousTilingLayout<TO>>,
        _barrier: &mut Shared<Barrier>,
        #[comptime] config: GlobalReaderConfig,
    ) {
        // Tasks of one unit are `jump_length` elements apart, so together the units
        // walk through the stage cyclically.
        let unit_position = this.unit_position_base + task_id * this.jump_length;

        // The outer condition is compile-time and the inner one is runtime. They are
        // presumably kept nested on purpose to keep the two kinds of branch separate,
        // hence the clippy allow.
        #[allow(clippy::collapsible_else_if)]
        if comptime!(this.reader_mode == ReaderMode::Strict || this.balanced_workload) {
            // No bounds check emitted. Strict mode presumably relies on config
            // validation to keep every task in bounds (TODO confirm).
            copy_line::<EG, ES, TO>(
                this,
                unit_position,
                global_iter,
                stage,
                &this.runtime_args,
                config,
            );
        } else {
            // Unbalanced workload: trailing tasks of some units fall past the stage end.
            if unit_position < this.num_stage_elements {
                copy_line::<EG, ES, TO>(
                    this,
                    unit_position,
                    global_iter,
                    stage,
                    &this.runtime_args,
                    config,
                );
            }
        }
    }

    /// Number of tasks each unit runs; known at compile time.
    fn task_count(this: &Self) -> comptime_type!(u32) {
        this.num_tasks_per_unit
    }
}
172
/// Issues the asynchronous copy of a single line from global memory into the stage.
///
/// `unit_position` is an element index into the stage. It is split into a tile index
/// and an offset inside that tile, which `TiledLayout` maps to a position in the
/// global view. The destination is stage line `unit_position / stage line size`.
/// `job.copy_line_size` is passed on to `async_copy_from` as the copy width.
///
/// Both call sites in this file pass `&job.runtime_args` as `runtime_args`.
#[cube]
pub(crate) fn copy_line<EG: Numeric, ES: Numeric, TO: TilingOrder>(
    job: &AsyncFullCyclicJob,
    unit_position: u32,
    global_iter: &GlobalIterator<Line<EG>>,
    stage: &mut StridedStageMemory<ES, ContiguousTilingLayout<TO>>,
    runtime_args: &RuntimeArgs,
    #[comptime] config: GlobalReaderConfig,
) {
    // Flat stage element index -> (tile index, element offset inside the tile).
    let nth_tile = unit_position / job.tile_num_elements;
    let pos_within_tile = unit_position % job.tile_num_elements;

    let layout = TiledLayout::new(config.stage_ident, config.smem_config);
    let view = global_iter.view();

    // Presumably the (x, y) coordinates of the tile under tiling order `TO`.
    let tile = ContiguousTilingLayout::<TO>::to_x_y(nth_tile, config.smem_config);

    // Source position in the global view for this (tile, offset) pair.
    let pos = layout.to_source_pos((tile, pos_within_tile));
    // Destination index counted in stage lines rather than elements.
    let stage_offset = unit_position / stage.smem.line_size();

    async_copy_from(
        view,
        pos,
        stage,
        stage_offset,
        runtime_args,
        global_iter.offset(),
        config,
        job.copy_line_size,
    );
}