cubek_convolution/components/global/read/strategy/
async_full_strided.rs1use cubecl::{
2 prelude::*,
3 std::tensor::layout::{Layout, LayoutExpand},
4 {ir::DeviceProperties, prelude::barrier::Barrier},
5};
6use cubek_matmul::components::{
7 global::{
8 GlobalReaderConfig, PlaneFlowPartition,
9 memory::GlobalIterator,
10 multi_stage::LoadMaxRoundPlaneCount,
11 read::{
12 FullLoadingStrategy, LoadingJob, LoadingValidation, async_barrier::AsyncCopy,
13 async_full_strided::AsyncFullStridedLoading as MatmulStridedLoading,
14 stage::FullStageLayout,
15 },
16 },
17 stage::{StridedStageFamily, StridedStageMemory, StridedTilingLayout},
18};
19use cubek_matmul::definition::{MatmulElems, MatmulProblem, StageIdent};
20use cubek_std::{InvalidConfigError, tile::Strided};
21
22use crate::components::global::{
23 args::RuntimeArgs,
24 read::strategy::async_copy::{ASYNC_COPY_WIDTH, async_copy_from},
25};
26
27#[derive(CubeType, Clone, Copy)]
28pub struct AsyncFullStridedLoading {}
31
32impl LoadingValidation for AsyncFullStridedLoading {
33 fn validate_with_config(
34 device_props: &DeviceProperties,
35 config: &GlobalReaderConfig,
36 ) -> Result<(), InvalidConfigError> {
37 MatmulStridedLoading::validate_with_config(device_props, config)
38 }
39
40 fn validate_with_problem(
41 problem: &MatmulProblem,
42 dtypes: &MatmulElems,
43 ident: StageIdent,
44 ) -> Result<(), InvalidConfigError> {
45 MatmulStridedLoading::validate_with_problem(problem, dtypes, ident)
46 }
47}
48
49impl LoadMaxRoundPlaneCount for AsyncFullStridedLoading {
50 fn max_round_plane_count(
51 elements_per_tile: u32,
52 tiles_per_stage: u32,
53 vector_size: VectorSize,
54 plane_dim: u32,
55 dtype: StorageType,
56 ) -> u32 {
57 MatmulStridedLoading::max_round_plane_count(
58 elements_per_tile,
59 tiles_per_stage,
60 vector_size,
61 plane_dim,
62 dtype,
63 )
64 }
65}
66
67#[cube]
68impl FullLoadingStrategy<RuntimeArgs> for AsyncFullStridedLoading {
69 type TilingLayout = StridedTilingLayout;
70 type SyncStrategy = AsyncCopy;
71 type Job<EG: Numeric, NG: Size, ES: Numeric, NS: Size> = AsyncFullStridedJob;
72 type Stage = StridedStageFamily;
73 type TileKind = Strided;
74
75 fn new_job<EG: Numeric, NG: Size, ES: Numeric, NS: Size>(
76 runtime_args: RuntimeArgs,
77 #[comptime] config: GlobalReaderConfig,
78 ) -> Self::Job<EG, NG, ES, NS> {
79 let type_size = ES::type_size_bits().comptime();
80 let vector_size = ASYNC_COPY_WIDTH / type_size as u32;
81 let num_stage_vectors = config.smem_config.elements_per_stage() / vector_size;
82 let unit_count = config.loading_planes_count() * config.plane_dim;
83 let num_tasks_per_unit = num_stage_vectors / unit_count;
84
85 let unit_position_base = PlaneFlowPartition::new(config.plane_flow_config.partition_rule)
86 .load_index(config.input_load_flow)
87 * config.plane_dim
88 + UNIT_POS_X;
89
90 AsyncFullStridedJob {
91 unit_position_base,
92 runtime_args,
93 num_tasks_per_unit,
94 unit_count,
95 copy_vector_size: vector_size,
96 }
97 }
98}
99
100#[derive(CubeType, Clone)]
101pub struct AsyncFullStridedJob {
102 unit_position_base: u32,
103 runtime_args: RuntimeArgs,
104
105 #[cube(comptime)]
106 num_tasks_per_unit: u32,
107 #[cube(comptime)]
108 unit_count: u32,
109 #[cube(comptime)]
110 copy_vector_size: u32,
111}
112
113#[cube]
114impl<EG: Numeric, NG: Size, ES: Numeric, NS: Size>
115 LoadingJob<EG, NG, ES, NS, StridedTilingLayout, AsyncCopy> for AsyncFullStridedJob
116{
117 type Stage = StridedStageFamily;
118
119 fn execute_task(
120 this: &mut Self,
121 #[comptime] task_id: u32,
122 global_iter: &GlobalIterator<Vector<EG, NG>>,
123 stage: &mut StridedStageMemory<ES, NS, StridedTilingLayout>,
124 _barrier: &mut Shared<Barrier>,
125 #[comptime] config: GlobalReaderConfig,
126 ) {
127 let unit_position = this.unit_position_base + task_id * this.unit_count;
128 let unit_position_abs = unit_position * this.copy_vector_size;
129
130 let layout = FullStageLayout::new(config.smem_config);
131 let view = global_iter.view();
132
133 let pos = layout.to_source_pos(unit_position_abs);
134 let stage_offset = unit_position_abs / stage.smem.vector_size() as u32;
135
136 async_copy_from(
137 view,
138 pos,
139 stage,
140 stage_offset,
141 &this.runtime_args,
142 global_iter.offset(),
143 config,
144 this.copy_vector_size,
145 );
146 }
147
148 fn task_count(this: &Self) -> comptime_type!(u32) {
149 this.num_tasks_per_unit
150 }
151}