// cubecl_matmul/components/global/read/strategy/async_full_maximize_slice_length.rs
use crate::components::{
2 InvalidConfigError, MatmulIdent, MatrixLayout, MatrixPrecision,
3 global::{
4 CopyMechanism, GlobalConfig,
5 memory::{GlobalIterator, load_window_in_stage},
6 read::AsyncFullLoadingStrategy,
7 },
8 stage::{StridedStage, StridedTilingLayout, TilingValidation},
9};
10use cubecl_core::prelude::*;
11use cubecl_core::{self as cubecl, prelude::barrier::BarrierLevel};
12
13use super::{AsyncLoadingJob, LoadingValidation};
14
15#[derive(CubeType, Clone, Copy)]
16pub struct AsyncFullMaximizeSliceLengthLoading {}
19
20impl LoadingValidation for AsyncFullMaximizeSliceLengthLoading {
21 fn check<C: GlobalConfig>(config: &C, ident: MatmulIdent) -> Result<(), InvalidConfigError> {
22 StridedTilingLayout::check(config.global_memory_config(ident))?;
23
24 Ok(())
25 }
26}
27
28#[cube]
29impl AsyncFullLoadingStrategy for AsyncFullMaximizeSliceLengthLoading {
30 type TilingLayout = StridedTilingLayout;
31 type Job<IP: MatrixPrecision> = AsynFullMaximizeSliceLengthJob;
32
33 fn new_job<IP: MatrixPrecision, G: GlobalConfig>(
34 #[comptime] ident: MatmulIdent,
35 #[comptime] config: G,
36 ) -> AsynFullMaximizeSliceLengthJob {
37 let matrix_layout = config.matrix_layout(ident);
38
39 let num_slices = match matrix_layout {
40 MatrixLayout::RowMajor => config.tiling_scheme().elements_in_stage_row(ident),
41 MatrixLayout::ColMajor => config.tiling_scheme().elements_in_stage_col(ident),
42 };
43 let unit_count = config.plane_dim() * config.num_loading_planes(ident);
44
45 let num_tasks_per_unit = comptime!(div_ceil(num_slices, unit_count));
46
47 AsynFullMaximizeSliceLengthJob {
48 num_tasks_per_unit,
49 unit_count,
50 num_slices,
51 ident,
52 }
53 }
54
55 fn barrier_level() -> BarrierLevel {
56 BarrierLevel::cube_manual(0u32)
57 }
58}
59
60#[derive(CubeType, Clone, Copy)]
61pub struct AsynFullMaximizeSliceLengthJob {
62 #[cube(comptime)]
63 num_tasks_per_unit: u32,
64 #[cube(comptime)]
65 unit_count: u32,
66 #[cube(comptime)]
67 num_slices: u32,
68 #[cube(comptime)]
69 ident: MatmulIdent,
70}
71
72#[cube]
73impl<IP: MatrixPrecision> AsyncLoadingJob<IP, StridedTilingLayout>
74 for AsynFullMaximizeSliceLengthJob
75{
76 fn execute_task<CM: CopyMechanism, G: GlobalConfig>(
77 this: &mut Self,
78 task_id: u32,
79 tensor_reader: &GlobalIterator<Line<IP::Global>>,
80 stage: &mut StridedStage<IP::Stage, StridedTilingLayout>,
81 mechanism: &CM,
82 #[comptime] config: G,
83 ) {
84 let nth_slice = this.unit_count * task_id + UNIT_POS;
85
86 #[allow(clippy::collapsible_else_if)]
87 if comptime!(this.num_slices.is_multiple_of(this.unit_count)) {
88 load_nth_slice::<IP::Global, IP::Stage, CM, G>(
89 nth_slice,
90 tensor_reader,
91 stage,
92 mechanism,
93 this.ident,
94 config,
95 );
96 } else {
97 if nth_slice < this.num_slices {
98 load_nth_slice::<IP::Global, IP::Stage, CM, G>(
99 nth_slice,
100 tensor_reader,
101 stage,
102 mechanism,
103 this.ident,
104 config,
105 );
106 }
107 };
108 }
109
110 fn task_count(this: &Self) -> comptime_type!(u32) {
111 this.num_tasks_per_unit
112 }
113}
114
115#[cube]
116fn load_nth_slice<EG: Numeric, ES: Numeric, CM: CopyMechanism, G: GlobalConfig>(
117 nth_slice: u32,
118 global_iter: &GlobalIterator<Line<EG>>,
119 stage: &mut StridedStage<ES, StridedTilingLayout>,
120 mechanism: &CM,
121 #[comptime] ident: MatmulIdent,
122 #[comptime] config: G,
123) {
124 let window = load_window_in_stage(
125 &global_iter.view(),
126 nth_slice,
127 comptime!(config.global_memory_config(ident)),
128 );
129 let mut destination: SliceMut<Line<ES>> = StridedTilingLayout::nth_slice::<ES>(
130 stage,
131 nth_slice,
132 comptime!(config.stage_memory_config(ident)),
133 );
134
135 CM::memcpy_async(mechanism, &window.try_cast_unchecked(), &mut destination);
136}