cubek_convolution/components/global/read/reader/strategy/
async_full_cyclic.rs

1use std::marker::PhantomData;
2
3use cubecl::prelude::*;
4use cubecl::std::tensor::layout::{Layout, LayoutExpand};
5use cubecl::{ir::DeviceProperties, prelude::barrier::Barrier};
6use cubek_matmul::components::{
7    global::{
8        GlobalReaderConfig, PlaneFlowPartition,
9        memory::GlobalIterator,
10        multi_stage::LoadMaxRoundPlaneCount,
11        read::{
12            LoadingJob, LoadingValidation, ReaderMode, async_barrier::AsyncCopy,
13            async_full_cyclic::AsyncFullCyclicLoading as MatmulCyclicLoading, tiled::TiledLayout,
14        },
15    },
16    stage::{ContiguousTilingLayout, StridedStageFamily, StridedStageMemory, TilingOrder},
17};
18use cubek_matmul::definition::{InvalidConfigError, MatmulElems, MatmulProblem, StageIdent};
19
20use crate::components::global::{
21    args::RuntimeArgs,
22    read::{
23        full_reader::FullLoadingStrategy,
24        strategy::async_copy::{ASYNC_COPY_WIDTH, async_copy_from},
25    },
26};
27
28#[derive(CubeType, Clone, Copy)]
29/// Loads the content of all tiles in the stage using all planes.
30/// Unit with pos X loads lines with indices X, X + NUM_UNITS, X + 2 * NUM_UNITS, ...
31pub struct AsyncFullCyclicLoading<T: TilingOrder> {
32    #[cube(comptime)]
33    _t: PhantomData<T>,
34}
35
36impl<TO: TilingOrder> LoadingValidation for AsyncFullCyclicLoading<TO> {
37    fn validate_with_config(
38        device_props: &DeviceProperties,
39        config: &GlobalReaderConfig,
40    ) -> Result<(), InvalidConfigError> {
41        MatmulCyclicLoading::<TO>::validate_with_config(device_props, config)
42    }
43
44    fn validate_with_problem(
45        _problem: &MatmulProblem,
46        _dtypes: &MatmulElems,
47        _ident: StageIdent,
48    ) -> Result<(), InvalidConfigError> {
49        Ok(())
50    }
51}
52
53impl<TO: TilingOrder> LoadMaxRoundPlaneCount for AsyncFullCyclicLoading<TO> {
54    fn max_round_plane_count(
55        elements_per_tile: u32,
56        tiles_per_stage: u32,
57        line_size: LineSize,
58        plane_dim: u32,
59        dtype: StorageType,
60    ) -> u32 {
61        MatmulCyclicLoading::<TO>::max_round_plane_count(
62            elements_per_tile,
63            tiles_per_stage,
64            line_size,
65            plane_dim,
66            dtype,
67        )
68    }
69}
70
71#[cube]
72impl<TO: TilingOrder> FullLoadingStrategy for AsyncFullCyclicLoading<TO> {
73    type TilingLayout = ContiguousTilingLayout<TO>;
74    type SyncStrategy = AsyncCopy;
75    type Job<EG: Numeric, ES: Numeric> = AsyncFullCyclicJob;
76
77    fn new_job<EG: Numeric, ES: Numeric>(
78        runtime_args: RuntimeArgs,
79        #[comptime] _line_size: LineSize,
80        #[comptime] config: GlobalReaderConfig,
81    ) -> Self::Job<EG, ES> {
82        let type_size = ES::type_size_bits().comptime();
83        let line_size = ASYNC_COPY_WIDTH / type_size as u32;
84        let tile_num_elements = config.smem_config.elements_per_tile();
85        let num_stage_elements = config.smem_config.elements_per_stage();
86
87        let num_stage_lines = num_stage_elements.div_ceil(line_size);
88        let total_units = config.loading_units_count();
89        let num_tasks_per_unit = num_stage_lines.div_ceil(total_units);
90        let balanced_workload = num_stage_lines.is_multiple_of(total_units);
91        let jump_length = total_units * line_size;
92
93        let unit_id = PlaneFlowPartition::new(config.plane_flow_config.partition_rule)
94            .load_index(config.input_load_flow)
95            * config.plane_dim
96            + UNIT_POS_X;
97        let unit_position_base = unit_id * line_size;
98
99        AsyncFullCyclicJob {
100            unit_position_base,
101            runtime_args,
102            num_tasks_per_unit,
103            tile_num_elements,
104            jump_length,
105            copy_line_size: line_size,
106            balanced_workload,
107            num_stage_elements,
108            reader_mode: config.reader_mode,
109        }
110    }
111}
112
113#[derive(CubeType, Clone)]
114pub struct AsyncFullCyclicJob {
115    unit_position_base: u32,
116    runtime_args: RuntimeArgs,
117
118    #[cube(comptime)]
119    num_tasks_per_unit: u32,
120    #[cube(comptime)]
121    tile_num_elements: u32,
122    #[cube(comptime)]
123    jump_length: u32,
124    #[cube(comptime)]
125    copy_line_size: u32,
126    #[cube(comptime)]
127    balanced_workload: bool,
128    #[cube(comptime)]
129    num_stage_elements: u32,
130    #[cube(comptime)]
131    reader_mode: ReaderMode,
132}
133
134#[cube]
135impl<EG: Numeric, ES: Numeric, TO: TilingOrder>
136    LoadingJob<EG, ES, ContiguousTilingLayout<TO>, AsyncCopy> for AsyncFullCyclicJob
137{
138    type Stage = StridedStageFamily;
139
140    fn execute_task(
141        this: &mut Self,
142        #[comptime] task_id: u32,
143        global_iter: &GlobalIterator<Line<EG>>,
144        stage: &mut StridedStageMemory<ES, ContiguousTilingLayout<TO>>,
145        _barrier: &mut Shared<Barrier>,
146        #[comptime] config: GlobalReaderConfig,
147    ) {
148        let unit_position = this.unit_position_base + task_id * this.jump_length;
149
150        #[allow(clippy::collapsible_else_if)]
151        if comptime!(this.reader_mode == ReaderMode::Strict || this.balanced_workload) {
152            copy_line::<EG, ES, TO>(
153                this,
154                unit_position,
155                global_iter,
156                stage,
157                &this.runtime_args,
158                config,
159            );
160        } else {
161            if unit_position < this.num_stage_elements {
162                copy_line::<EG, ES, TO>(
163                    this,
164                    unit_position,
165                    global_iter,
166                    stage,
167                    &this.runtime_args,
168                    config,
169                );
170            }
171        }
172    }
173
174    fn task_count(this: &Self) -> comptime_type!(u32) {
175        this.num_tasks_per_unit
176    }
177}
178
179#[cube]
180pub(crate) fn copy_line<EG: Numeric, ES: Numeric, TO: TilingOrder>(
181    job: &AsyncFullCyclicJob,
182    unit_position: u32,
183    global_iter: &GlobalIterator<Line<EG>>,
184    stage: &mut StridedStageMemory<ES, ContiguousTilingLayout<TO>>,
185    runtime_args: &RuntimeArgs,
186    #[comptime] config: GlobalReaderConfig,
187) {
188    let nth_tile = unit_position / job.tile_num_elements;
189    let pos_within_tile = unit_position % job.tile_num_elements;
190
191    let layout = TiledLayout::new(config.stage_ident, config.smem_config);
192    let view = global_iter.view();
193
194    let tile = ContiguousTilingLayout::<TO>::to_x_y(nth_tile, config.smem_config);
195
196    let pos = layout.to_source_pos((tile, pos_within_tile));
197    let stage_offset = unit_position / stage.smem.line_size() as u32;
198
199    async_copy_from(
200        view,
201        pos,
202        stage,
203        stage_offset,
204        runtime_args,
205        global_iter.offset(),
206        config,
207        job.copy_line_size,
208    );
209}