Skip to main content

cubek_convolution/components/global/read/strategy/
async_full_cyclic.rs

1use std::marker::PhantomData;
2
3use cubecl::prelude::*;
4use cubecl::std::tensor::layout::{Layout, LayoutExpand};
5use cubecl::{ir::DeviceProperties, prelude::barrier::Barrier};
6use cubek_matmul::components::{
7    global::{
8        GlobalReaderConfig, PlaneFlowPartition,
9        memory::GlobalIterator,
10        multi_stage::LoadMaxRoundPlaneCount,
11        read::{
12            FullLoadingStrategy, LoadingJob, LoadingValidation, ReaderMode,
13            async_barrier::AsyncCopy,
14            async_full_cyclic::AsyncFullCyclicLoading as MatmulCyclicLoading, tiled::TiledLayout,
15        },
16    },
17    stage::{ContiguousTilingLayout, StridedStageFamily, StridedStageMemory, TilingOrder},
18    tile::io::Strided,
19};
20use cubek_matmul::definition::{InvalidConfigError, MatmulElems, MatmulProblem, StageIdent};
21
22use crate::components::global::{
23    args::RuntimeArgs,
24    read::strategy::async_copy::{ASYNC_COPY_WIDTH, async_copy_from},
25};
26
27#[derive(CubeType, Clone, Copy)]
28/// Loads the content of all tiles in the stage using all planes.
29/// Unit with pos X loads lines with indices X, X + NUM_UNITS, X + 2 * NUM_UNITS, ...
30pub struct AsyncFullCyclicLoading<T: TilingOrder> {
31    #[cube(comptime)]
32    _t: PhantomData<T>,
33}
34
35impl<TO: TilingOrder> LoadingValidation for AsyncFullCyclicLoading<TO> {
36    fn validate_with_config(
37        device_props: &DeviceProperties,
38        config: &GlobalReaderConfig,
39    ) -> Result<(), InvalidConfigError> {
40        MatmulCyclicLoading::<TO>::validate_with_config(device_props, config)
41    }
42
43    fn validate_with_problem(
44        problem: &MatmulProblem,
45        dtypes: &MatmulElems,
46        ident: StageIdent,
47    ) -> Result<(), InvalidConfigError> {
48        MatmulCyclicLoading::<TO>::validate_with_problem(problem, dtypes, ident)
49    }
50}
51
52impl<TO: TilingOrder> LoadMaxRoundPlaneCount for AsyncFullCyclicLoading<TO> {
53    fn max_round_plane_count(
54        elements_per_tile: u32,
55        tiles_per_stage: u32,
56        line_size: LineSize,
57        plane_dim: u32,
58        dtype: StorageType,
59    ) -> u32 {
60        MatmulCyclicLoading::<TO>::max_round_plane_count(
61            elements_per_tile,
62            tiles_per_stage,
63            line_size,
64            plane_dim,
65            dtype,
66        )
67    }
68}
69
70#[cube]
71impl<TO: TilingOrder> FullLoadingStrategy<RuntimeArgs> for AsyncFullCyclicLoading<TO> {
72    type TilingLayout = ContiguousTilingLayout<TO>;
73    type SyncStrategy = AsyncCopy;
74    type Job<EG: Numeric, ES: Numeric> = AsyncFullCyclicJob;
75    type Stage = StridedStageFamily;
76    type TileKind = Strided;
77
78    fn new_job<EG: Numeric, ES: Numeric>(
79        runtime_args: RuntimeArgs,
80        #[comptime] _line_size: LineSize,
81        #[comptime] config: GlobalReaderConfig,
82    ) -> Self::Job<EG, ES> {
83        let type_size = ES::type_size_bits().comptime();
84        let line_size = ASYNC_COPY_WIDTH / type_size as u32;
85        let tile_num_elements = config.smem_config.elements_per_tile();
86        let num_stage_elements = config.smem_config.elements_per_stage();
87
88        let num_stage_lines = num_stage_elements.div_ceil(line_size);
89        let total_units = config.loading_units_count();
90        let num_tasks_per_unit = num_stage_lines.div_ceil(total_units);
91        let balanced_workload = num_stage_lines.is_multiple_of(total_units);
92        let jump_length = total_units * line_size;
93
94        let unit_id = PlaneFlowPartition::new(config.plane_flow_config.partition_rule)
95            .load_index(config.input_load_flow)
96            * config.plane_dim
97            + UNIT_POS_X;
98        let unit_position_base = unit_id * line_size;
99
100        AsyncFullCyclicJob {
101            unit_position_base,
102            runtime_args,
103            num_tasks_per_unit,
104            tile_num_elements,
105            jump_length,
106            copy_line_size: line_size,
107            balanced_workload,
108            num_stage_elements,
109            reader_mode: config.reader_mode,
110        }
111    }
112}
113
114#[derive(CubeType, Clone)]
115pub struct AsyncFullCyclicJob {
116    unit_position_base: u32,
117    runtime_args: RuntimeArgs,
118
119    #[cube(comptime)]
120    num_tasks_per_unit: u32,
121    #[cube(comptime)]
122    tile_num_elements: u32,
123    #[cube(comptime)]
124    jump_length: u32,
125    #[cube(comptime)]
126    copy_line_size: u32,
127    #[cube(comptime)]
128    balanced_workload: bool,
129    #[cube(comptime)]
130    num_stage_elements: u32,
131    #[cube(comptime)]
132    reader_mode: ReaderMode,
133}
134
135#[cube]
136impl<EG: Numeric, ES: Numeric, TO: TilingOrder>
137    LoadingJob<EG, ES, ContiguousTilingLayout<TO>, AsyncCopy> for AsyncFullCyclicJob
138{
139    type Stage = StridedStageFamily;
140
141    fn execute_task(
142        this: &mut Self,
143        #[comptime] task_id: u32,
144        global_iter: &GlobalIterator<Line<EG>>,
145        stage: &mut StridedStageMemory<ES, ContiguousTilingLayout<TO>>,
146        _barrier: &mut Shared<Barrier>,
147        #[comptime] config: GlobalReaderConfig,
148    ) {
149        let unit_position = this.unit_position_base + task_id * this.jump_length;
150
151        #[allow(clippy::collapsible_else_if)]
152        if comptime!(this.reader_mode == ReaderMode::Strict || this.balanced_workload) {
153            copy_line::<EG, ES, TO>(
154                this,
155                unit_position,
156                global_iter,
157                stage,
158                &this.runtime_args,
159                config,
160            );
161        } else {
162            if unit_position < this.num_stage_elements {
163                copy_line::<EG, ES, TO>(
164                    this,
165                    unit_position,
166                    global_iter,
167                    stage,
168                    &this.runtime_args,
169                    config,
170                );
171            }
172        }
173    }
174
175    fn task_count(this: &Self) -> comptime_type!(u32) {
176        this.num_tasks_per_unit
177    }
178}
179
180#[cube]
181pub(crate) fn copy_line<EG: Numeric, ES: Numeric, TO: TilingOrder>(
182    job: &AsyncFullCyclicJob,
183    unit_position: u32,
184    global_iter: &GlobalIterator<Line<EG>>,
185    stage: &mut StridedStageMemory<ES, ContiguousTilingLayout<TO>>,
186    runtime_args: &RuntimeArgs,
187    #[comptime] config: GlobalReaderConfig,
188) {
189    let nth_tile = unit_position / job.tile_num_elements;
190    let pos_within_tile = unit_position % job.tile_num_elements;
191
192    let layout = TiledLayout::new(config.stage_ident, config.smem_config);
193    let view = global_iter.view();
194
195    let tile = ContiguousTilingLayout::<TO>::to_x_y(nth_tile, config.smem_config);
196
197    let pos = layout.to_source_pos((tile, pos_within_tile));
198    let stage_offset = unit_position / stage.smem.line_size() as u32;
199
200    async_copy_from(
201        view,
202        pos,
203        stage,
204        stage_offset,
205        runtime_args,
206        global_iter.offset(),
207        config,
208        job.copy_line_size,
209    );
210}