Skip to main content

cubek_convolution/components/global/read/strategy/async_full_strided.rs

1use cubecl::{
2    prelude::*,
3    std::tensor::layout::{Layout, LayoutExpand},
4    {ir::DeviceProperties, prelude::barrier::Barrier},
5};
6use cubek_matmul::components::{
7    global::{
8        GlobalReaderConfig, PlaneFlowPartition,
9        memory::GlobalIterator,
10        multi_stage::LoadMaxRoundPlaneCount,
11        read::{
12            FullLoadingStrategy, LoadingJob, LoadingValidation, async_barrier::AsyncCopy,
13            async_full_strided::AsyncFullStridedLoading as MatmulStridedLoading,
14            stage::FullStageLayout,
15        },
16    },
17    stage::{StridedStageFamily, StridedStageMemory, StridedTilingLayout},
18};
19use cubek_matmul::definition::{MatmulElems, MatmulProblem};
20use cubek_std::{InvalidConfigError, StageIdent, tile::Strided};
21
22use crate::components::global::{
23    args::RuntimeArgs,
24    read::strategy::async_copy::{ASYNC_COPY_WIDTH, async_copy_from},
25};
26
27#[derive(CubeType, Clone, Copy)]
28/// Loads the content of all the stage using all planes,
29/// keeping the original layout, making each tile strided
30pub struct AsyncFullStridedLoading {}
31
32impl LoadingValidation for AsyncFullStridedLoading {
33    fn validate_with_config(
34        device_props: &DeviceProperties,
35        config: &GlobalReaderConfig,
36    ) -> Result<(), InvalidConfigError> {
37        MatmulStridedLoading::validate_with_config(device_props, config)
38    }
39
40    fn validate_with_problem(
41        problem: &MatmulProblem,
42        dtypes: &MatmulElems,
43        ident: StageIdent,
44    ) -> Result<(), InvalidConfigError> {
45        MatmulStridedLoading::validate_with_problem(problem, dtypes, ident)
46    }
47}
48
49impl LoadMaxRoundPlaneCount for AsyncFullStridedLoading {
50    fn max_round_plane_count(
51        elements_per_tile: u32,
52        tiles_per_stage: u32,
53        vector_size: VectorSize,
54        plane_dim: u32,
55        dtype: StorageType,
56    ) -> u32 {
57        MatmulStridedLoading::max_round_plane_count(
58            elements_per_tile,
59            tiles_per_stage,
60            vector_size,
61            plane_dim,
62            dtype,
63        )
64    }
65}
66
67#[cube]
68impl FullLoadingStrategy<RuntimeArgs> for AsyncFullStridedLoading {
69    type TilingLayout = StridedTilingLayout;
70    type SyncStrategy = AsyncCopy;
71    type Job<EG: Numeric, NG: Size, ES: Numeric, NS: Size> = AsyncFullStridedJob;
72    type Stage = StridedStageFamily;
73    type TileKind = Strided;
74
75    fn new_job<EG: Numeric, NG: Size, ES: Numeric, NS: Size>(
76        runtime_args: RuntimeArgs,
77        #[comptime] config: GlobalReaderConfig,
78    ) -> Self::Job<EG, NG, ES, NS> {
79        let type_size = ES::type_size_bits().comptime();
80        let vector_size = ASYNC_COPY_WIDTH / type_size as u32;
81        let num_stage_vectors = config.smem_config.elements_per_stage() / vector_size;
82        let unit_count = config.loading_planes_count() * config.plane_dim;
83        let num_tasks_per_unit = num_stage_vectors / unit_count;
84
85        let unit_position_base = PlaneFlowPartition::new(config.plane_flow_config.partition_rule)
86            .load_index(config.input_load_flow)
87            * config.plane_dim
88            + UNIT_POS_X;
89
90        AsyncFullStridedJob {
91            unit_position_base,
92            runtime_args,
93            num_tasks_per_unit,
94            unit_count,
95            copy_vector_size: vector_size,
96        }
97    }
98}
99
100#[derive(CubeType, Clone)]
101pub struct AsyncFullStridedJob {
102    unit_position_base: u32,
103    runtime_args: RuntimeArgs,
104
105    #[cube(comptime)]
106    num_tasks_per_unit: u32,
107    #[cube(comptime)]
108    unit_count: u32,
109    #[cube(comptime)]
110    copy_vector_size: u32,
111}
112
113#[cube]
114impl<EG: Numeric, NG: Size, ES: Numeric, NS: Size>
115    LoadingJob<EG, NG, ES, NS, StridedTilingLayout, AsyncCopy> for AsyncFullStridedJob
116{
117    type Stage = StridedStageFamily;
118
119    fn execute_task(
120        this: &mut Self,
121        #[comptime] task_id: u32,
122        global_iter: &GlobalIterator<Vector<EG, NG>>,
123        stage: &mut StridedStageMemory<ES, NS, StridedTilingLayout>,
124        _barrier: &mut Shared<Barrier>,
125        #[comptime] config: GlobalReaderConfig,
126    ) {
127        let unit_position = this.unit_position_base + task_id * this.unit_count;
128        let unit_position_abs = unit_position * this.copy_vector_size;
129
130        let layout = FullStageLayout::new(config.smem_config);
131        let view = global_iter.view();
132
133        let pos = layout.to_source_pos(unit_position_abs);
134        let stage_offset = unit_position_abs / stage.smem.vector_size() as u32;
135
136        async_copy_from(
137            view,
138            pos,
139            stage,
140            stage_offset,
141            &this.runtime_args,
142            global_iter.offset(),
143            config,
144            this.copy_vector_size,
145        );
146    }
147
148    fn task_count(this: &Self) -> comptime_type!(u32) {
149        this.num_tasks_per_unit
150    }
151}