cubecl_convolution/components/global/read/reader/strategy/async_full_strided.rs

1use cubecl_core::prelude::*;
2use cubecl_core::{self as cubecl, prelude::barrier::Barrier};
3use cubecl_matmul::components::{
4    InvalidConfigError, MatmulElems, MatmulProblem,
5    global::{
6        GlobalReaderConfig, RoleRule,
7        memory::GlobalIterator,
8        multi_stage::LoadMaxRoundPlaneCount,
9        read::{
10            LoadingJob, LoadingValidation, async_barrier::AsyncCopy,
11            async_full_strided::AsyncFullStridedLoading as MatmulStridedLoading,
12            stage::FullStageLayout,
13        },
14    },
15    stage::{StridedStageFamily, StridedStageMemory, StridedTilingLayout},
16};
17use cubecl_std::tensor::layout::{Layout, LayoutExpand};
18
19use crate::components::global::{
20    args::RuntimeArgs,
21    read::{
22        full_reader::FullLoadingStrategy,
23        strategy::async_copy::{ASYNC_COPY_WIDTH, async_copy_from},
24    },
25};
26
27#[derive(CubeType, Clone, Copy)]
28/// Loads the content of all the stage using all planes,
29/// keeping the original layout, making each tile strided
30pub struct AsyncFullStridedLoading {}
31
32impl LoadingValidation for AsyncFullStridedLoading {
33    fn check<R: Runtime>(
34        client: &ComputeClient<R>,
35        problem: &MatmulProblem,
36        config: &GlobalReaderConfig,
37        dtypes: &MatmulElems,
38    ) -> Result<(), InvalidConfigError> {
39        MatmulStridedLoading::check(client, problem, config, dtypes)
40    }
41}
42
43impl LoadMaxRoundPlaneCount for AsyncFullStridedLoading {
44    fn max_round_plane_count(
45        elements_per_tile: u32,
46        tiles_per_stage: u32,
47        line_size: u8,
48        plane_dim: u32,
49        dtype: StorageType,
50    ) -> u32 {
51        MatmulStridedLoading::max_round_plane_count(
52            elements_per_tile,
53            tiles_per_stage,
54            line_size,
55            plane_dim,
56            dtype,
57        )
58    }
59}
60
// Job construction for the async full strided strategy: every loading unit
// gets a job describing which async-copy-sized lines of the stage it copies.
#[cube]
impl FullLoadingStrategy for AsyncFullStridedLoading {
    type TilingLayout = StridedTilingLayout;
    type SyncStrategy = AsyncCopy;
    type Job<EG: Numeric, ES: Numeric> = AsyncFullStridedJob;

    // Builds the loading job for the calling unit.
    //
    // - `runtime_args`: stored in the job and forwarded to `async_copy_from`.
    // - `_line_size`: ignored; the copy line size is recomputed below from
    //   `ES` and `ASYNC_COPY_WIDTH`.
    // - `config`: provides the stage size, loading plane count, plane
    //   dimension and plane-role configuration.
    fn new_job<EG: Numeric, ES: Numeric>(
        runtime_args: RuntimeArgs,
        #[comptime] _line_size: u32,
        #[comptime] config: GlobalReaderConfig,
    ) -> Self::Job<EG, ES> {
        // Elements of `ES` moved by one async copy: ASYNC_COPY_WIDTH divided by
        // the bit size of `ES` (assumes ASYNC_COPY_WIDTH is a bit width — confirm).
        let type_size = ES::type_size_bits();
        let line_size = comptime![ASYNC_COPY_WIDTH / type_size];
        // Number of copy-sized lines that cover the whole stage.
        let num_stage_lines = config.smem_config.elements_per_stage() / line_size;
        // All units of all loading planes take part in the copy.
        let unit_count = config.loading_planes_count() * config.plane_dim;
        // NOTE(review): truncating division — if num_stage_lines is not a
        // multiple of unit_count, the leftover lines are never copied.
        // Presumably rejected by `MatmulStridedLoading::check`; confirm.
        let num_tasks_per_unit = comptime!(num_stage_lines / unit_count);

        // Index of this unit among the loading units: the plane's load index
        // (from its role rule) times plane_dim, plus the unit's lane in the plane.
        let unit_position_base = RoleRule::new(config.plane_role_config.rule)
            .load_index(config.specialization_tensor_config)
            * config.plane_dim
            + UNIT_POS_X;

        AsyncFullStridedJob {
            unit_position_base,
            runtime_args,
            num_tasks_per_unit,
            unit_count,
            copy_line_size: line_size,
        }
    }
}
92
/// Per-unit loading job for [`AsyncFullStridedLoading`]: identifies which
/// stage lines this unit copies and how many elements each copy moves.
#[derive(CubeType, Clone)]
pub struct AsyncFullStridedJob {
    // Index of this unit among all loading units; it is also the line index
    // of the unit's first task.
    unit_position_base: u32,
    // Forwarded to `async_copy_from` on every task.
    runtime_args: RuntimeArgs,

    // Number of tasks (copy lines) performed by each unit.
    #[cube(comptime)]
    num_tasks_per_unit: u32,
    // Total number of loading units; also the distance, in lines, between
    // consecutive tasks of the same unit.
    #[cube(comptime)]
    unit_count: u32,
    // Number of stage elements moved by a single async copy.
    #[cube(comptime)]
    copy_line_size: u32,
}
105
// Executes the loading job: each task copies one async-copy-sized line of the
// stage from the global view into stage memory.
#[cube]
impl<EG: Numeric, ES: Numeric> LoadingJob<EG, ES, StridedTilingLayout, AsyncCopy>
    for AsyncFullStridedJob
{
    type Stage = StridedStageFamily;

    // Runs task `task_id` of this unit, copying `copy_line_size` elements into
    // `stage`. A unit's tasks are `unit_count` lines apart
    // (`unit_position_base + task_id * unit_count`). `_barrier` is not used here.
    fn execute_task(
        this: &mut Self,
        #[comptime] task_id: u32,
        global_iter: &GlobalIterator<Line<EG>>,
        stage: &mut StridedStageMemory<ES, StridedTilingLayout>,
        _barrier: &mut Shared<Barrier>,
        #[comptime] config: GlobalReaderConfig,
    ) {
        // Copy-line index handled by this task, then the same position as an
        // element index within the stage.
        let unit_position = this.unit_position_base + task_id * this.unit_count;
        let unit_position_abs = unit_position * this.copy_line_size;

        let layout = FullStageLayout::new(comptime![config.smem_config]);
        let view = global_iter.view();

        // Map the stage element index to its position in the global view.
        let pos = layout.to_source_pos(unit_position_abs);
        // Destination offset expressed in units of the stage's own line size,
        // which may differ from the copy line size.
        let stage_offset = unit_position_abs / stage.smem.line_size();

        async_copy_from(
            view,
            pos,
            stage,
            stage_offset,
            &this.runtime_args,
            global_iter.offset(),
            config,
            this.copy_line_size,
        );
    }

    // Number of tasks each unit executes, as stored in the job at creation.
    fn task_count(this: &Self) -> comptime_type!(u32) {
        this.num_tasks_per_unit
    }
}