Skip to main content

cubek_convolution/components/global/read/strategy/
async_full_cyclic.rs

1use std::marker::PhantomData;
2
3use cubecl::{
4    prelude::*,
5    std::tensor::layout::{Layout, LayoutExpand},
6    {ir::DeviceProperties, prelude::barrier::Barrier},
7};
8use cubek_matmul::components::{
9    global::{
10        GlobalReaderConfig, PlaneFlowPartition,
11        memory::GlobalIterator,
12        multi_stage::LoadMaxRoundPlaneCount,
13        read::{
14            FullLoadingStrategy, LoadingJob, LoadingValidation, ReaderMode,
15            async_barrier::AsyncCopy,
16            async_full_cyclic::AsyncFullCyclicLoading as MatmulCyclicLoading, tiled::TiledLayout,
17        },
18    },
19    stage::{ContiguousTilingLayout, StridedStageFamily, StridedStageMemory, TilingOrder},
20};
21use cubek_matmul::definition::{MatmulElems, MatmulProblem, StageIdent};
22use cubek_std::{InvalidConfigError, tile::Strided};
23
24use crate::components::global::{
25    args::RuntimeArgs,
26    read::strategy::async_copy::{ASYNC_COPY_WIDTH, async_copy_from},
27};
28
29#[derive(CubeType, Clone, Copy)]
30/// Loads the content of all tiles in the stage using all planes.
31/// Unit with pos X loads vectors with indices X, X + NUM_UNITS, X + 2 * NUM_UNITS, ...
32pub struct AsyncFullCyclicLoading<T: TilingOrder> {
33    #[cube(comptime)]
34    _t: PhantomData<T>,
35}
36
37impl<TO: TilingOrder> LoadingValidation for AsyncFullCyclicLoading<TO> {
38    fn validate_with_config(
39        device_props: &DeviceProperties,
40        config: &GlobalReaderConfig,
41    ) -> Result<(), InvalidConfigError> {
42        MatmulCyclicLoading::<TO>::validate_with_config(device_props, config)
43    }
44
45    fn validate_with_problem(
46        problem: &MatmulProblem,
47        dtypes: &MatmulElems,
48        ident: StageIdent,
49    ) -> Result<(), InvalidConfigError> {
50        MatmulCyclicLoading::<TO>::validate_with_problem(problem, dtypes, ident)
51    }
52}
53
54impl<TO: TilingOrder> LoadMaxRoundPlaneCount for AsyncFullCyclicLoading<TO> {
55    fn max_round_plane_count(
56        elements_per_tile: u32,
57        tiles_per_stage: u32,
58        vector_size: VectorSize,
59        plane_dim: u32,
60        dtype: StorageType,
61    ) -> u32 {
62        MatmulCyclicLoading::<TO>::max_round_plane_count(
63            elements_per_tile,
64            tiles_per_stage,
65            vector_size,
66            plane_dim,
67            dtype,
68        )
69    }
70}
71
72#[cube]
73impl<TO: TilingOrder> FullLoadingStrategy<RuntimeArgs> for AsyncFullCyclicLoading<TO> {
74    type TilingLayout = ContiguousTilingLayout<TO>;
75    type SyncStrategy = AsyncCopy;
76    type Job<EG: Numeric, NG: Size, ES: Numeric, NS: Size> = AsyncFullCyclicJob;
77    type Stage = StridedStageFamily;
78    type TileKind = Strided;
79
80    fn new_job<EG: Numeric, NG: Size, ES: Numeric, NS: Size>(
81        runtime_args: RuntimeArgs,
82        #[comptime] config: GlobalReaderConfig,
83    ) -> Self::Job<EG, NG, ES, NS> {
84        let type_size = ES::type_size_bits().comptime();
85        let vector_size = ASYNC_COPY_WIDTH / type_size as u32;
86        let tile_num_elements = config.smem_config.elements_per_tile();
87        let num_stage_elements = config.smem_config.elements_per_stage();
88
89        let num_stage_vectors = num_stage_elements.div_ceil(vector_size);
90        let total_units = config.loading_units_count();
91        let num_tasks_per_unit = num_stage_vectors.div_ceil(total_units);
92        let balanced_workload = num_stage_vectors.is_multiple_of(total_units);
93        let jump_length = total_units * vector_size;
94
95        let unit_id = PlaneFlowPartition::new(config.plane_flow_config.partition_rule)
96            .load_index(config.input_load_flow)
97            * config.plane_dim
98            + UNIT_POS_X;
99        let unit_position_base = unit_id * vector_size;
100
101        AsyncFullCyclicJob {
102            unit_position_base,
103            runtime_args,
104            num_tasks_per_unit,
105            tile_num_elements,
106            jump_length,
107            copy_vector_size: vector_size,
108            balanced_workload,
109            num_stage_elements,
110            reader_mode: config.reader_mode,
111        }
112    }
113}
114
115#[derive(CubeType, Clone)]
116pub struct AsyncFullCyclicJob {
117    unit_position_base: u32,
118    runtime_args: RuntimeArgs,
119
120    #[cube(comptime)]
121    num_tasks_per_unit: u32,
122    #[cube(comptime)]
123    tile_num_elements: u32,
124    #[cube(comptime)]
125    jump_length: u32,
126    #[cube(comptime)]
127    copy_vector_size: u32,
128    #[cube(comptime)]
129    balanced_workload: bool,
130    #[cube(comptime)]
131    num_stage_elements: u32,
132    #[cube(comptime)]
133    reader_mode: ReaderMode,
134}
135
136#[cube]
137impl<EG: Numeric, NG: Size, ES: Numeric, NS: Size, TO: TilingOrder>
138    LoadingJob<EG, NG, ES, NS, ContiguousTilingLayout<TO>, AsyncCopy> for AsyncFullCyclicJob
139{
140    type Stage = StridedStageFamily;
141
142    fn execute_task(
143        this: &mut Self,
144        #[comptime] task_id: u32,
145        global_iter: &GlobalIterator<Vector<EG, NG>>,
146        stage: &mut StridedStageMemory<ES, NS, ContiguousTilingLayout<TO>>,
147        _barrier: &mut Shared<Barrier>,
148        #[comptime] config: GlobalReaderConfig,
149    ) {
150        let unit_position = this.unit_position_base + task_id * this.jump_length;
151
152        #[allow(clippy::collapsible_else_if)]
153        if comptime!(this.reader_mode == ReaderMode::Strict || this.balanced_workload) {
154            copy_vector::<EG, NG, ES, NS, TO>(
155                this,
156                unit_position,
157                global_iter,
158                stage,
159                &this.runtime_args,
160                config,
161            );
162        } else {
163            if unit_position < this.num_stage_elements {
164                copy_vector::<EG, NG, ES, NS, TO>(
165                    this,
166                    unit_position,
167                    global_iter,
168                    stage,
169                    &this.runtime_args,
170                    config,
171                );
172            }
173        }
174    }
175
176    fn task_count(this: &Self) -> comptime_type!(u32) {
177        this.num_tasks_per_unit
178    }
179}
180
181#[cube]
182pub(crate) fn copy_vector<EG: Numeric, NG: Size, ES: Numeric, NS: Size, TO: TilingOrder>(
183    job: &AsyncFullCyclicJob,
184    unit_position: u32,
185    global_iter: &GlobalIterator<Vector<EG, NG>>,
186    stage: &mut StridedStageMemory<ES, NS, ContiguousTilingLayout<TO>>,
187    runtime_args: &RuntimeArgs,
188    #[comptime] config: GlobalReaderConfig,
189) {
190    let nth_tile = unit_position / job.tile_num_elements;
191    let pos_within_tile = unit_position % job.tile_num_elements;
192
193    let layout = TiledLayout::new(config.stage_ident, config.smem_config);
194    let view = global_iter.view();
195
196    let tile = ContiguousTilingLayout::<TO>::to_x_y(nth_tile, config.smem_config);
197
198    let pos = layout.to_source_pos((tile, pos_within_tile));
199    let stage_offset = unit_position / stage.smem.vector_size() as u32;
200
201    async_copy_from(
202        view,
203        pos,
204        stage,
205        stage_offset,
206        runtime_args,
207        global_iter.offset(),
208        config,
209        job.copy_vector_size,
210    );
211}