cubek_convolution/components/global/read/reader/strategy/
async_full_strided.rs1use cubecl::prelude::*;
2use cubecl::std::tensor::layout::{Layout, LayoutExpand};
3use cubecl::{ir::DeviceProperties, prelude::barrier::Barrier};
4use cubek_matmul::components::{
5 global::{
6 GlobalReaderConfig, PlaneFlowPartition,
7 memory::GlobalIterator,
8 multi_stage::LoadMaxRoundPlaneCount,
9 read::{
10 LoadingJob, LoadingValidation, async_barrier::AsyncCopy,
11 async_full_strided::AsyncFullStridedLoading as MatmulStridedLoading,
12 stage::FullStageLayout,
13 },
14 },
15 stage::{StridedStageFamily, StridedStageMemory, StridedTilingLayout},
16};
17use cubek_matmul::definition::{InvalidConfigError, MatmulElems, MatmulProblem, StageIdent};
18
19use crate::components::global::{
20 args::RuntimeArgs,
21 read::{
22 full_reader::FullLoadingStrategy,
23 strategy::async_copy::{ASYNC_COPY_WIDTH, async_copy_from},
24 },
25};
26
27#[derive(CubeType, Clone, Copy)]
28pub struct AsyncFullStridedLoading {}
31
32impl LoadingValidation for AsyncFullStridedLoading {
33 fn validate_with_config(
34 device_props: &DeviceProperties,
35 config: &GlobalReaderConfig,
36 ) -> Result<(), InvalidConfigError> {
37 MatmulStridedLoading::validate_with_config(device_props, config)
38 }
39
40 fn validate_with_problem(
41 _problem: &MatmulProblem,
42 _dtypes: &MatmulElems,
43 _ident: StageIdent,
44 ) -> Result<(), InvalidConfigError> {
45 Ok(())
46 }
47}
48
49impl LoadMaxRoundPlaneCount for AsyncFullStridedLoading {
50 fn max_round_plane_count(
51 elements_per_tile: u32,
52 tiles_per_stage: u32,
53 line_size: LineSize,
54 plane_dim: u32,
55 dtype: StorageType,
56 ) -> u32 {
57 MatmulStridedLoading::max_round_plane_count(
58 elements_per_tile,
59 tiles_per_stage,
60 line_size,
61 plane_dim,
62 dtype,
63 )
64 }
65}
66
67#[cube]
68impl FullLoadingStrategy for AsyncFullStridedLoading {
69 type TilingLayout = StridedTilingLayout;
70 type SyncStrategy = AsyncCopy;
71 type Job<EG: Numeric, ES: Numeric> = AsyncFullStridedJob;
72
73 fn new_job<EG: Numeric, ES: Numeric>(
74 runtime_args: RuntimeArgs,
75 #[comptime] _line_size: LineSize,
76 #[comptime] config: GlobalReaderConfig,
77 ) -> Self::Job<EG, ES> {
78 let type_size = ES::type_size_bits().comptime();
79 let line_size = ASYNC_COPY_WIDTH / type_size as u32;
80 let num_stage_lines = config.smem_config.elements_per_stage() / line_size;
81 let unit_count = config.loading_planes_count() * config.plane_dim;
82 let num_tasks_per_unit = num_stage_lines / unit_count;
83
84 let unit_position_base = PlaneFlowPartition::new(config.plane_flow_config.partition_rule)
85 .load_index(config.input_load_flow)
86 * config.plane_dim
87 + UNIT_POS_X;
88
89 AsyncFullStridedJob {
90 unit_position_base,
91 runtime_args,
92 num_tasks_per_unit,
93 unit_count,
94 copy_line_size: line_size,
95 }
96 }
97}
98
99#[derive(CubeType, Clone)]
100pub struct AsyncFullStridedJob {
101 unit_position_base: u32,
102 runtime_args: RuntimeArgs,
103
104 #[cube(comptime)]
105 num_tasks_per_unit: u32,
106 #[cube(comptime)]
107 unit_count: u32,
108 #[cube(comptime)]
109 copy_line_size: u32,
110}
111
112#[cube]
113impl<EG: Numeric, ES: Numeric> LoadingJob<EG, ES, StridedTilingLayout, AsyncCopy>
114 for AsyncFullStridedJob
115{
116 type Stage = StridedStageFamily;
117
118 fn execute_task(
119 this: &mut Self,
120 #[comptime] task_id: u32,
121 global_iter: &GlobalIterator<Line<EG>>,
122 stage: &mut StridedStageMemory<ES, StridedTilingLayout>,
123 _barrier: &mut Shared<Barrier>,
124 #[comptime] config: GlobalReaderConfig,
125 ) {
126 let unit_position = this.unit_position_base + task_id * this.unit_count;
127 let unit_position_abs = unit_position * this.copy_line_size;
128
129 let layout = FullStageLayout::new(config.smem_config);
130 let view = global_iter.view();
131
132 let pos = layout.to_source_pos(unit_position_abs);
133 let stage_offset = unit_position_abs / stage.smem.line_size() as u32;
134
135 async_copy_from(
136 view,
137 pos,
138 stage,
139 stage_offset,
140 &this.runtime_args,
141 global_iter.offset(),
142 config,
143 this.copy_line_size,
144 );
145 }
146
147 fn task_count(this: &Self) -> comptime_type!(u32) {
148 this.num_tasks_per_unit
149 }
150}