cubecl_matmul/components/global/read/strategy/
async_partial_maximize_slice_length.rs1use crate::components::{
2 InvalidConfigError, MatmulIdent, MatrixLayout, MatrixPrecision,
3 global::{
4 CopyMechanism, GlobalConfig,
5 memory::{GlobalIterator, load_window_in_stage},
6 read::AsyncPartialLoadingStrategy,
7 },
8 stage::{StageConfig, StridedStage, StridedTilingLayout, TilingValidation},
9};
10use cubecl_core::prelude::*;
11use cubecl_core::{self as cubecl, prelude::barrier::BarrierLevel};
12
13use super::{AsyncLoadingJob, LoadingValidation};
14
/// Field-less strategy type for asynchronous partial stage loading.
///
/// Its `AsyncPartialLoadingStrategy` implementation builds
/// [`AsyncPartialMaximizeSliceLengthJob`]s, in which every task issues at most
/// one `memcpy_async` call covering a range of a single slice.
#[derive(CubeType, Clone, Copy)]
pub struct AsyncPartialMaximizeSliceLengthLoading {}
19
20impl LoadingValidation for AsyncPartialMaximizeSliceLengthLoading {
21 fn check<C: GlobalConfig>(config: &C, ident: MatmulIdent) -> Result<(), InvalidConfigError> {
22 StridedTilingLayout::check(config.global_memory_config(ident))?;
23
24 Ok(())
25 }
26}
27
28#[cube]
29impl AsyncPartialLoadingStrategy for AsyncPartialMaximizeSliceLengthLoading {
30 type TilingLayout = StridedTilingLayout;
31 type Job<IP: MatrixPrecision> = AsyncPartialMaximizeSliceLengthJob;
32
33 fn new_job<IP: MatrixPrecision, G: GlobalConfig>(
34 #[comptime] stage_index: u32,
35 #[comptime] ident: MatmulIdent,
36 #[comptime] config: G,
37 ) -> AsyncPartialMaximizeSliceLengthJob {
38 let matrix_layout = config.matrix_layout(ident);
39 let line_size = config
40 .stage_config()
41 .stage_line_size(comptime!(ident.into_stage()));
42 let num_stages = 2;
43
44 let total_row = config.tiling_scheme().elements_in_stage_row(ident);
45 let total_col = config.tiling_scheme().elements_in_stage_col(ident);
46
47 let (num_slices, num_slices_stage_offset, slice_length, slice_stage_offset) = comptime! {
50 match (ident, matrix_layout) {
51 (MatmulIdent::Lhs, MatrixLayout::RowMajor) => {
52 let slice_length = total_col / (num_stages * line_size);
53
54 (total_row, 0, slice_length, stage_index * slice_length)
55 },
56 (MatmulIdent::Lhs, MatrixLayout::ColMajor) => {
57 let num_slices = total_col / num_stages;
58
59 (num_slices, stage_index * num_slices, total_row / line_size, 0)
60 },
61 (MatmulIdent::Rhs, MatrixLayout::RowMajor) => {
62 let num_slices = total_row / num_stages;
63
64 (num_slices, stage_index * num_slices, total_col / line_size, 0)
65 },
66 (MatmulIdent::Rhs, MatrixLayout::ColMajor) => {
67 let slice_length = total_row / (num_stages * line_size);
68
69 (total_col, 0, slice_length, stage_index * slice_length)
70 },
71 (MatmulIdent::Out, _) => unreachable!()
72 }
73 };
74
75 let unit_count = config.plane_dim() * config.num_loading_planes(ident);
76 let num_tasks_per_unit = comptime!(num_slices.div_ceil(unit_count));
77
78 AsyncPartialMaximizeSliceLengthJob {
79 num_tasks_per_unit,
80 unit_count,
81 num_slices_stage_offset,
82 ident,
83 slice_stage_offset,
84 slice_length,
85 num_slices,
86 }
87 }
88
89 fn barrier_level() -> BarrierLevel {
90 BarrierLevel::cube_manual(0u32)
91 }
92}
93
/// Compile-time description of the work one unit performs when loading a
/// stage with [`AsyncPartialMaximizeSliceLengthLoading`].
///
/// Every field is a comptime value computed in `new_job`.
#[derive(CubeType, Clone, Copy)]
pub struct AsyncPartialMaximizeSliceLengthJob {
    /// Number of tasks each unit executes: `num_slices` divided by
    /// `unit_count`, rounded up.
    #[cube(comptime)]
    num_tasks_per_unit: u32,
    /// Number of units taking part in the load
    /// (`plane_dim * num_loading_planes`).
    #[cube(comptime)]
    unit_count: u32,
    /// Added to the in-stage slice index to obtain the slice index that is
    /// actually read and written.
    #[cube(comptime)]
    num_slices_stage_offset: u32,
    /// Operand this job loads; used to look up the memory configurations.
    #[cube(comptime)]
    ident: MatmulIdent,
    /// Start position, inside a slice, of the range that is copied.
    #[cube(comptime)]
    slice_stage_offset: u32,
    /// Upper bound on the length of the range copied from a slice.
    #[cube(comptime)]
    slice_length: u32,
    /// Number of slices handled by this job; tasks whose in-stage slice index
    /// is not below this value perform no copy.
    #[cube(comptime)]
    num_slices: u32,
}
111
112#[cube]
113impl<IP: MatrixPrecision> AsyncLoadingJob<IP, StridedTilingLayout>
114 for AsyncPartialMaximizeSliceLengthJob
115{
116 fn execute_task<CM: CopyMechanism, G: GlobalConfig>(
117 this: &mut Self,
118 task_id: u32,
119 global_iter: &GlobalIterator<Line<IP::Global>>,
120 stage: &mut StridedStage<IP::Stage, StridedTilingLayout>,
121 mechanism: &CM,
122 #[comptime] config: G,
123 ) {
124 let nth_slice_in_stage = this.unit_count * task_id + UNIT_POS;
125
126 let nth_slice = nth_slice_in_stage + this.num_slices_stage_offset;
127
128 let window = load_window_in_stage(
129 &global_iter.view(),
130 nth_slice,
131 comptime!(config.global_memory_config(this.ident)),
132 );
133 let mut destination: SliceMut<Line<IP::Stage>> = StridedTilingLayout::nth_slice::<IP::Stage>(
134 stage,
135 nth_slice,
136 comptime!(config.stage_memory_config(this.ident)),
137 );
138
139 let start = this.slice_stage_offset;
140 let limit = select(
141 this.slice_stage_offset < window.len(),
142 this.slice_stage_offset,
143 window.len(),
144 );
145 let end = start + Min::min(window.len() - limit, this.slice_length);
146
147 let src = window.slice(start, end);
148 let mut dest = destination.slice_mut(start, end);
149
150 #[allow(clippy::collapsible_else_if)]
151 if comptime!(this.num_slices.is_multiple_of(this.unit_count)) {
152 CM::memcpy_async(mechanism, &src.try_cast_unchecked(), &mut dest);
153 } else {
154 if nth_slice_in_stage < this.num_slices {
155 CM::memcpy_async(mechanism, &src.try_cast_unchecked(), &mut dest);
156 }
157 };
158 }
159
160 fn task_count(this: &Self) -> comptime_type!(u32) {
161 this.num_tasks_per_unit
162 }
163}