cubecl_matmul/components/global/read/strategy/async_full_maximize_unit_count.rs

1use crate::components::{
2    InvalidConfigError, MatmulIdent, MatrixLayout, MatrixPrecision,
3    global::{
4        CopyMechanism, GlobalConfig,
5        memory::{GlobalIterator, load_window_in_stage},
6        read::AsyncFullLoadingStrategy,
7    },
8    stage::{StageConfig, StridedStage, StridedTilingLayout, TilingValidation},
9};
10use cubecl_core::prelude::*;
11use cubecl_core::{self as cubecl, prelude::barrier::BarrierLevel};
12
13use super::{AsyncLoadingJob, LoadingValidation};
14
#[derive(CubeType, Clone, Copy)]
/// Executes one memcpy_async call per unit.
/// The objective is to reduce branching, prioritizing this over maximizing memory slice length.
///
/// Each unit is assigned one fixed-length segment of a single stage slice and
/// issues exactly one `memcpy_async` for it (see `AsyncFullMaximizeUnitCountJob`).
pub struct AsyncFullMaximizeUnitCountLoading {}
19
20impl LoadingValidation for AsyncFullMaximizeUnitCountLoading {
21    fn check<C: GlobalConfig>(config: &C, ident: MatmulIdent) -> Result<(), InvalidConfigError> {
22        let matrix_layout = config.matrix_layout(ident);
23        let line_size = config.global_line_size(ident);
24
25        let (num_slices, slice_length) = match matrix_layout {
26            MatrixLayout::RowMajor => (
27                config.tiling_scheme().elements_in_stage_row(ident),
28                config.tiling_scheme().elements_in_stage_col(ident) / line_size,
29            ),
30            MatrixLayout::ColMajor => (
31                config.tiling_scheme().elements_in_stage_col(ident),
32                config.tiling_scheme().elements_in_stage_row(ident) / line_size,
33            ),
34        };
35        let unit_count = config.plane_dim() * config.num_loading_planes(ident);
36
37        if !unit_count.is_multiple_of(num_slices) {
38            return Err(Box::new(
39                "Number of slices must divide number of units evenly",
40            ));
41        }
42        if slice_length % (unit_count / num_slices) != 0 {
43            return Err(Box::new(
44                "Number of units per slice must divide slice length evenly",
45            ));
46        }
47
48        StridedTilingLayout::check(config.global_memory_config(ident))?;
49
50        Ok(())
51    }
52}
53
#[cube]
impl AsyncFullLoadingStrategy for AsyncFullMaximizeUnitCountLoading {
    type TilingLayout = StridedTilingLayout;
    type Job<IP: MatrixPrecision> = AsyncFullMaximizeUnitCountJob;

    /// Builds the per-unit job: assigns each unit one slice of the stage and,
    /// within that slice, one fixed-length segment to copy.
    fn new_job<IP: MatrixPrecision, G: GlobalConfig>(
        #[comptime] ident: MatmulIdent,
        #[comptime] config: G,
    ) -> AsyncFullMaximizeUnitCountJob {
        let matrix_layout = config.matrix_layout(ident);
        // NOTE(review): this uses the *stage* line size, while `check` validates
        // divisibility against the *global* line size — confirm the two are
        // guaranteed equal for this loading strategy.
        let line_size = config
            .stage_config()
            .stage_line_size(comptime!(ident.into_stage()));

        // Slice count and per-slice length in lines, depending on memory order
        // (rows are contiguous for row-major, columns for col-major).
        let (num_slices, slice_length) = match matrix_layout {
            MatrixLayout::RowMajor => (
                config.tiling_scheme().elements_in_stage_row(ident),
                config.tiling_scheme().elements_in_stage_col(ident) / line_size,
            ),
            MatrixLayout::ColMajor => (
                config.tiling_scheme().elements_in_stage_col(ident),
                config.tiling_scheme().elements_in_stage_row(ident) / line_size,
            ),
        };

        let unit_count = config.plane_dim() * config.num_loading_planes(ident);

        // Units are split evenly across slices (validated in `check`):
        // consecutive unit ids share a slice, then subdivide it into segments.
        let units_per_slice = comptime!(unit_count / num_slices);
        let nth_slice = UNIT_POS / units_per_slice;

        let segment_length = comptime!(slice_length / units_per_slice);
        let nth_segment = UNIT_POS % units_per_slice;

        AsyncFullMaximizeUnitCountJob {
            nth_slice,
            nth_segment,
            segment_length,
            ident,
        }
    }

    fn barrier_level() -> BarrierLevel {
        // Manually-managed barrier at cube scope, index 0 — exact arrival
        // semantics are defined by `BarrierLevel::cube_manual`.
        BarrierLevel::cube_manual(0u32)
    }
}
99
#[derive(CubeType, Clone, Copy)]
/// Per-unit work assignment produced by
/// `AsyncFullMaximizeUnitCountLoading::new_job`: which stage slice this unit
/// copies into, and which segment of that slice it owns.
pub struct AsyncFullMaximizeUnitCountJob {
    // Index of the stage slice (row or column, per layout) this unit writes.
    nth_slice: u32,
    // Index of this unit's segment within the slice.
    nth_segment: u32,
    #[cube(comptime)]
    // Length of each segment, in lines (comptime constant).
    segment_length: u32,
    #[cube(comptime)]
    // Which matmul operand (Lhs/Rhs/...) this job loads.
    ident: MatmulIdent,
}
109
#[cube]
impl<IP: MatrixPrecision> AsyncLoadingJob<IP, StridedTilingLayout>
    for AsyncFullMaximizeUnitCountJob
{
    /// Copies this unit's segment from global memory into the stage with a
    /// single `memcpy_async` call (one task total — see `task_count`).
    fn execute_task<CM: CopyMechanism, G: GlobalConfig>(
        this: &mut Self,
        _task_id: u32,
        global_iter: &GlobalIterator<Line<IP::Global>>,
        stage: &mut StridedStage<IP::Stage, StridedTilingLayout>,
        mechanism: &CM,
        #[comptime] config: G,
    ) {
        // Destination: the unit's assigned slice of the shared stage.
        let mut destination: SliceMut<Line<IP::Stage>> = StridedTilingLayout::nth_slice::<IP::Stage>(
            stage,
            this.nth_slice,
            comptime!(config.stage_memory_config(this.ident)),
        );

        // Source: the matching window of global memory for that slice.
        let window = load_window_in_stage(
            &global_iter.view(),
            this.nth_slice,
            comptime!(config.global_memory_config(this.ident)),
        );
        // Clamp segment bounds to the window length so a partial (tail) window
        // yields a shorter — possibly empty — copy instead of out-of-bounds.
        let seg_start = Min::min(this.nth_segment * this.segment_length, window.len());
        let seg_end = Min::min((this.nth_segment + 1) * this.segment_length, window.len());

        let src_segment = window.slice(seg_start, seg_end);
        let mut dest_segment = destination.slice_mut(seg_start, seg_end);

        CM::memcpy_async(
            mechanism,
            // Reinterprets global-precision lines as stage precision; relies on
            // the two line types being layout-compatible here.
            &src_segment.try_cast_unchecked(),
            &mut dest_segment,
        );
    }

    /// This strategy issues exactly one copy per unit.
    fn task_count(_this: &Self) -> comptime_type!(u32) {
        1
    }
}