cubecl_matmul/components/global/read/strategy/
async_full_maximize_unit_count.rs1use crate::components::{
2 InvalidConfigError, MatmulIdent, MatrixLayout, MatrixPrecision,
3 global::{
4 CopyMechanism, GlobalConfig,
5 memory::{GlobalIterator, load_window_in_stage},
6 read::AsyncFullLoadingStrategy,
7 },
8 stage::{StageConfig, StridedStage, StridedTilingLayout, TilingValidation},
9};
10use cubecl_core::prelude::*;
11use cubecl_core::{self as cubecl, prelude::barrier::BarrierLevel};
12
13use super::{AsyncLoadingJob, LoadingValidation};
14
15#[derive(CubeType, Clone, Copy)]
16pub struct AsyncFullMaximizeUnitCountLoading {}
19
20impl LoadingValidation for AsyncFullMaximizeUnitCountLoading {
21 fn check<C: GlobalConfig>(config: &C, ident: MatmulIdent) -> Result<(), InvalidConfigError> {
22 let matrix_layout = config.matrix_layout(ident);
23 let line_size = config.global_line_size(ident);
24
25 let (num_slices, slice_length) = match matrix_layout {
26 MatrixLayout::RowMajor => (
27 config.tiling_scheme().elements_in_stage_row(ident),
28 config.tiling_scheme().elements_in_stage_col(ident) / line_size,
29 ),
30 MatrixLayout::ColMajor => (
31 config.tiling_scheme().elements_in_stage_col(ident),
32 config.tiling_scheme().elements_in_stage_row(ident) / line_size,
33 ),
34 };
35 let unit_count = config.plane_dim() * config.num_loading_planes(ident);
36
37 if !unit_count.is_multiple_of(num_slices) {
38 return Err(Box::new(
39 "Number of slices must divide number of units evenly",
40 ));
41 }
42 if slice_length % (unit_count / num_slices) != 0 {
43 return Err(Box::new(
44 "Number of units per slice must divide slice length evenly",
45 ));
46 }
47
48 StridedTilingLayout::check(config.global_memory_config(ident))?;
49
50 Ok(())
51 }
52}
53
54#[cube]
55impl AsyncFullLoadingStrategy for AsyncFullMaximizeUnitCountLoading {
56 type TilingLayout = StridedTilingLayout;
57 type Job<IP: MatrixPrecision> = AsyncFullMaximizeUnitCountJob;
58
59 fn new_job<IP: MatrixPrecision, G: GlobalConfig>(
60 #[comptime] ident: MatmulIdent,
61 #[comptime] config: G,
62 ) -> AsyncFullMaximizeUnitCountJob {
63 let matrix_layout = config.matrix_layout(ident);
64 let line_size = config
65 .stage_config()
66 .stage_line_size(comptime!(ident.into_stage()));
67
68 let (num_slices, slice_length) = match matrix_layout {
69 MatrixLayout::RowMajor => (
70 config.tiling_scheme().elements_in_stage_row(ident),
71 config.tiling_scheme().elements_in_stage_col(ident) / line_size,
72 ),
73 MatrixLayout::ColMajor => (
74 config.tiling_scheme().elements_in_stage_col(ident),
75 config.tiling_scheme().elements_in_stage_row(ident) / line_size,
76 ),
77 };
78
79 let unit_count = config.plane_dim() * config.num_loading_planes(ident);
80
81 let units_per_slice = comptime!(unit_count / num_slices);
82 let nth_slice = UNIT_POS / units_per_slice;
83
84 let segment_length = comptime!(slice_length / units_per_slice);
85 let nth_segment = UNIT_POS % units_per_slice;
86
87 AsyncFullMaximizeUnitCountJob {
88 nth_slice,
89 nth_segment,
90 segment_length,
91 ident,
92 }
93 }
94
95 fn barrier_level() -> BarrierLevel {
96 BarrierLevel::cube_manual(0u32)
97 }
98}
99
100#[derive(CubeType, Clone, Copy)]
101pub struct AsyncFullMaximizeUnitCountJob {
102 nth_slice: u32,
103 nth_segment: u32,
104 #[cube(comptime)]
105 segment_length: u32,
106 #[cube(comptime)]
107 ident: MatmulIdent,
108}
109
110#[cube]
111impl<IP: MatrixPrecision> AsyncLoadingJob<IP, StridedTilingLayout>
112 for AsyncFullMaximizeUnitCountJob
113{
114 fn execute_task<CM: CopyMechanism, G: GlobalConfig>(
115 this: &mut Self,
116 _task_id: u32,
117 global_iter: &GlobalIterator<Line<IP::Global>>,
118 stage: &mut StridedStage<IP::Stage, StridedTilingLayout>,
119 mechanism: &CM,
120 #[comptime] config: G,
121 ) {
122 let mut destination: SliceMut<Line<IP::Stage>> = StridedTilingLayout::nth_slice::<IP::Stage>(
123 stage,
124 this.nth_slice,
125 comptime!(config.stage_memory_config(this.ident)),
126 );
127
128 let window = load_window_in_stage(
129 &global_iter.view(),
130 this.nth_slice,
131 comptime!(config.global_memory_config(this.ident)),
132 );
133 let seg_start = Min::min(this.nth_segment * this.segment_length, window.len());
134 let seg_end = Min::min((this.nth_segment + 1) * this.segment_length, window.len());
135
136 let src_segment = window.slice(seg_start, seg_end);
137 let mut dest_segment = destination.slice_mut(seg_start, seg_end);
138
139 CM::memcpy_async(
140 mechanism,
141 &src_segment.try_cast_unchecked(),
142 &mut dest_segment,
143 );
144 }
145
146 fn task_count(_this: &Self) -> comptime_type!(u32) {
147 1
148 }
149}