Skip to main content

cubek_convolution/components/global/read/strategy/
async_full_strided.rs

1use cubecl::prelude::*;
2use cubecl::std::tensor::layout::{Layout, LayoutExpand};
3use cubecl::{ir::DeviceProperties, prelude::barrier::Barrier};
4use cubek_matmul::components::{
5    global::{
6        GlobalReaderConfig, PlaneFlowPartition,
7        memory::GlobalIterator,
8        multi_stage::LoadMaxRoundPlaneCount,
9        read::{
10            FullLoadingStrategy, LoadingJob, LoadingValidation, async_barrier::AsyncCopy,
11            async_full_strided::AsyncFullStridedLoading as MatmulStridedLoading,
12            stage::FullStageLayout,
13        },
14    },
15    stage::{StridedStageFamily, StridedStageMemory, StridedTilingLayout},
16    tile::io::Strided,
17};
18use cubek_matmul::definition::{InvalidConfigError, MatmulElems, MatmulProblem, StageIdent};
19
20use crate::components::global::{
21    args::RuntimeArgs,
22    read::strategy::async_copy::{ASYNC_COPY_WIDTH, async_copy_from},
23};
24
25#[derive(CubeType, Clone, Copy)]
26/// Loads the content of all the stage using all planes,
27/// keeping the original layout, making each tile strided
28pub struct AsyncFullStridedLoading {}
29
30impl LoadingValidation for AsyncFullStridedLoading {
31    fn validate_with_config(
32        device_props: &DeviceProperties,
33        config: &GlobalReaderConfig,
34    ) -> Result<(), InvalidConfigError> {
35        MatmulStridedLoading::validate_with_config(device_props, config)
36    }
37
38    fn validate_with_problem(
39        problem: &MatmulProblem,
40        dtypes: &MatmulElems,
41        ident: StageIdent,
42    ) -> Result<(), InvalidConfigError> {
43        MatmulStridedLoading::validate_with_problem(problem, dtypes, ident)
44    }
45}
46
47impl LoadMaxRoundPlaneCount for AsyncFullStridedLoading {
48    fn max_round_plane_count(
49        elements_per_tile: u32,
50        tiles_per_stage: u32,
51        line_size: LineSize,
52        plane_dim: u32,
53        dtype: StorageType,
54    ) -> u32 {
55        MatmulStridedLoading::max_round_plane_count(
56            elements_per_tile,
57            tiles_per_stage,
58            line_size,
59            plane_dim,
60            dtype,
61        )
62    }
63}
64
65#[cube]
66impl FullLoadingStrategy<RuntimeArgs> for AsyncFullStridedLoading {
67    type TilingLayout = StridedTilingLayout;
68    type SyncStrategy = AsyncCopy;
69    type Job<EG: Numeric, ES: Numeric> = AsyncFullStridedJob;
70    type Stage = StridedStageFamily;
71    type TileKind = Strided;
72
73    fn new_job<EG: Numeric, ES: Numeric>(
74        runtime_args: RuntimeArgs,
75        #[comptime] _line_size: LineSize,
76        #[comptime] config: GlobalReaderConfig,
77    ) -> Self::Job<EG, ES> {
78        let type_size = ES::type_size_bits().comptime();
79        let line_size = ASYNC_COPY_WIDTH / type_size as u32;
80        let num_stage_lines = config.smem_config.elements_per_stage() / line_size;
81        let unit_count = config.loading_planes_count() * config.plane_dim;
82        let num_tasks_per_unit = num_stage_lines / unit_count;
83
84        let unit_position_base = PlaneFlowPartition::new(config.plane_flow_config.partition_rule)
85            .load_index(config.input_load_flow)
86            * config.plane_dim
87            + UNIT_POS_X;
88
89        AsyncFullStridedJob {
90            unit_position_base,
91            runtime_args,
92            num_tasks_per_unit,
93            unit_count,
94            copy_line_size: line_size,
95        }
96    }
97}
98
99#[derive(CubeType, Clone)]
100pub struct AsyncFullStridedJob {
101    unit_position_base: u32,
102    runtime_args: RuntimeArgs,
103
104    #[cube(comptime)]
105    num_tasks_per_unit: u32,
106    #[cube(comptime)]
107    unit_count: u32,
108    #[cube(comptime)]
109    copy_line_size: u32,
110}
111
112#[cube]
113impl<EG: Numeric, ES: Numeric> LoadingJob<EG, ES, StridedTilingLayout, AsyncCopy>
114    for AsyncFullStridedJob
115{
116    type Stage = StridedStageFamily;
117
118    fn execute_task(
119        this: &mut Self,
120        #[comptime] task_id: u32,
121        global_iter: &GlobalIterator<Line<EG>>,
122        stage: &mut StridedStageMemory<ES, StridedTilingLayout>,
123        _barrier: &mut Shared<Barrier>,
124        #[comptime] config: GlobalReaderConfig,
125    ) {
126        let unit_position = this.unit_position_base + task_id * this.unit_count;
127        let unit_position_abs = unit_position * this.copy_line_size;
128
129        let layout = FullStageLayout::new(config.smem_config);
130        let view = global_iter.view();
131
132        let pos = layout.to_source_pos(unit_position_abs);
133        let stage_offset = unit_position_abs / stage.smem.line_size() as u32;
134
135        async_copy_from(
136            view,
137            pos,
138            stage,
139            stage_offset,
140            &this.runtime_args,
141            global_iter.offset(),
142            config,
143            this.copy_line_size,
144        );
145    }
146
147    fn task_count(this: &Self) -> comptime_type!(u32) {
148        this.num_tasks_per_unit
149    }
150}