cubecl_matmul/components/global/read/strategy/async_partial_maximize_slice_length.rs

1use crate::components::{
2    InvalidConfigError, MatmulIdent, MatrixLayout, MatrixPrecision,
3    global::{
4        CopyMechanism, GlobalConfig,
5        memory::{GlobalIterator, load_window_in_stage},
6        read::AsyncPartialLoadingStrategy,
7    },
8    stage::{StageConfig, StridedStage, StridedTilingLayout, TilingValidation},
9};
10use cubecl_core::prelude::*;
11use cubecl_core::{self as cubecl, prelude::barrier::BarrierLevel};
12
13use super::{AsyncLoadingJob, LoadingValidation};
14
15#[derive(CubeType, Clone, Copy)]
16/// Executes one `memcpy_async` call per contiguous slice.
17/// The goal is to reduce the total number of `memcpy_async` calls, though it may result in idle threads.
18pub struct AsyncPartialMaximizeSliceLengthLoading {}
19
20impl LoadingValidation for AsyncPartialMaximizeSliceLengthLoading {
21    fn check<C: GlobalConfig>(config: &C, ident: MatmulIdent) -> Result<(), InvalidConfigError> {
22        StridedTilingLayout::check(config.global_memory_config(ident))?;
23
24        Ok(())
25    }
26}
27
28#[cube]
29impl AsyncPartialLoadingStrategy for AsyncPartialMaximizeSliceLengthLoading {
30    type TilingLayout = StridedTilingLayout;
31    type Job<IP: MatrixPrecision> = AsyncPartialMaximizeSliceLengthJob;
32
33    fn new_job<IP: MatrixPrecision, G: GlobalConfig>(
34        #[comptime] stage_index: u32,
35        #[comptime] ident: MatmulIdent,
36        #[comptime] config: G,
37    ) -> AsyncPartialMaximizeSliceLengthJob {
38        let matrix_layout = config.matrix_layout(ident);
39        let line_size = config
40            .stage_config()
41            .stage_line_size(comptime!(ident.into_stage()));
42        let num_stages = 2;
43
44        let total_row = config.tiling_scheme().elements_in_stage_row(ident);
45        let total_col = config.tiling_scheme().elements_in_stage_col(ident);
46
47        // If stage is parallel to slices, slices are as long as in full stage memory, but there are less.
48        // Otherwise, slices are shorter but there are as many as in full stage memory
49        let (num_slices, num_slices_stage_offset, slice_length, slice_stage_offset) = comptime! {
50            match (ident, matrix_layout) {
51                (MatmulIdent::Lhs, MatrixLayout::RowMajor) => {
52                    let slice_length = total_col / (num_stages * line_size);
53
54                    (total_row, 0, slice_length, stage_index * slice_length)
55                },
56                (MatmulIdent::Lhs, MatrixLayout::ColMajor) => {
57                    let num_slices = total_col / num_stages;
58
59                    (num_slices, stage_index * num_slices, total_row / line_size, 0)
60                },
61                (MatmulIdent::Rhs, MatrixLayout::RowMajor) => {
62                    let num_slices = total_row / num_stages;
63
64                    (num_slices, stage_index * num_slices, total_col / line_size, 0)
65                },
66                (MatmulIdent::Rhs, MatrixLayout::ColMajor) => {
67                    let slice_length = total_row / (num_stages * line_size);
68
69                    (total_col, 0, slice_length, stage_index * slice_length)
70                },
71                (MatmulIdent::Out, _) => unreachable!()
72            }
73        };
74
75        let unit_count = config.plane_dim() * config.num_loading_planes(ident);
76        let num_tasks_per_unit = comptime!(num_slices.div_ceil(unit_count));
77
78        AsyncPartialMaximizeSliceLengthJob {
79            num_tasks_per_unit,
80            unit_count,
81            num_slices_stage_offset,
82            ident,
83            slice_stage_offset,
84            slice_length,
85            num_slices,
86        }
87    }
88
89    fn barrier_level() -> BarrierLevel {
90        BarrierLevel::cube_manual(0u32)
91    }
92}
93
94#[derive(CubeType, Clone, Copy)]
95pub struct AsyncPartialMaximizeSliceLengthJob {
96    #[cube(comptime)]
97    num_tasks_per_unit: u32,
98    #[cube(comptime)]
99    unit_count: u32,
100    #[cube(comptime)]
101    num_slices_stage_offset: u32,
102    #[cube(comptime)]
103    ident: MatmulIdent,
104    #[cube(comptime)]
105    slice_stage_offset: u32,
106    #[cube(comptime)]
107    slice_length: u32,
108    #[cube(comptime)]
109    num_slices: u32,
110}
111
112#[cube]
113impl<IP: MatrixPrecision> AsyncLoadingJob<IP, StridedTilingLayout>
114    for AsyncPartialMaximizeSliceLengthJob
115{
116    fn execute_task<CM: CopyMechanism, G: GlobalConfig>(
117        this: &mut Self,
118        task_id: u32,
119        global_iter: &GlobalIterator<Line<IP::Global>>,
120        stage: &mut StridedStage<IP::Stage, StridedTilingLayout>,
121        mechanism: &CM,
122        #[comptime] config: G,
123    ) {
124        let nth_slice_in_stage = this.unit_count * task_id + UNIT_POS;
125
126        let nth_slice = nth_slice_in_stage + this.num_slices_stage_offset;
127
128        let window = load_window_in_stage(
129            &global_iter.view(),
130            nth_slice,
131            comptime!(config.global_memory_config(this.ident)),
132        );
133        let mut destination: SliceMut<Line<IP::Stage>> = StridedTilingLayout::nth_slice::<IP::Stage>(
134            stage,
135            nth_slice,
136            comptime!(config.stage_memory_config(this.ident)),
137        );
138
139        let start = this.slice_stage_offset;
140        let limit = select(
141            this.slice_stage_offset < window.len(),
142            this.slice_stage_offset,
143            window.len(),
144        );
145        let end = start + Min::min(window.len() - limit, this.slice_length);
146
147        let src = window.slice(start, end);
148        let mut dest = destination.slice_mut(start, end);
149
150        #[allow(clippy::collapsible_else_if)]
151        if comptime!(this.num_slices.is_multiple_of(this.unit_count)) {
152            CM::memcpy_async(mechanism, &src.try_cast_unchecked(), &mut dest);
153        } else {
154            if nth_slice_in_stage < this.num_slices {
155                CM::memcpy_async(mechanism, &src.try_cast_unchecked(), &mut dest);
156            }
157        };
158    }
159
160    fn task_count(this: &Self) -> comptime_type!(u32) {
161        this.num_tasks_per_unit
162    }
163}