cubecl_matmul/components/global/read/strategy/
async_full_maximize_slice_length.rs

1use crate::components::{
2    InvalidConfigError, MatmulIdent, MatrixLayout, MatrixPrecision,
3    global::{
4        CopyMechanism, GlobalConfig,
5        memory::{GlobalIterator, load_window_in_stage},
6        read::AsyncFullLoadingStrategy,
7    },
8    stage::{StridedStage, StridedTilingLayout, TilingValidation},
9};
10use cubecl_core::prelude::*;
11use cubecl_core::{self as cubecl, prelude::barrier::BarrierLevel};
12
13use super::{AsyncLoadingJob, LoadingValidation};
14
15#[derive(CubeType, Clone, Copy)]
16/// Executes one memcpy_async call per contiguous slice.
17/// The goal is to reduce the total number of memcpy_async calls, though it may result in idle threads.
18pub struct AsyncFullMaximizeSliceLengthLoading {}
19
20impl LoadingValidation for AsyncFullMaximizeSliceLengthLoading {
21    fn check<C: GlobalConfig>(config: &C, ident: MatmulIdent) -> Result<(), InvalidConfigError> {
22        StridedTilingLayout::check(config.global_memory_config(ident))?;
23
24        Ok(())
25    }
26}
27
28#[cube]
29impl AsyncFullLoadingStrategy for AsyncFullMaximizeSliceLengthLoading {
30    type TilingLayout = StridedTilingLayout;
31    type Job<IP: MatrixPrecision> = AsynFullMaximizeSliceLengthJob;
32
33    fn new_job<IP: MatrixPrecision, G: GlobalConfig>(
34        #[comptime] ident: MatmulIdent,
35        #[comptime] config: G,
36    ) -> AsynFullMaximizeSliceLengthJob {
37        let matrix_layout = config.matrix_layout(ident);
38
39        let num_slices = match matrix_layout {
40            MatrixLayout::RowMajor => config.tiling_scheme().elements_in_stage_row(ident),
41            MatrixLayout::ColMajor => config.tiling_scheme().elements_in_stage_col(ident),
42        };
43        let unit_count = config.plane_dim() * config.num_loading_planes(ident);
44
45        let num_tasks_per_unit = comptime!(div_ceil(num_slices, unit_count));
46
47        AsynFullMaximizeSliceLengthJob {
48            num_tasks_per_unit,
49            unit_count,
50            num_slices,
51            ident,
52        }
53    }
54
55    fn barrier_level() -> BarrierLevel {
56        BarrierLevel::cube_manual(0u32)
57    }
58}
59
60#[derive(CubeType, Clone, Copy)]
61pub struct AsynFullMaximizeSliceLengthJob {
62    #[cube(comptime)]
63    num_tasks_per_unit: u32,
64    #[cube(comptime)]
65    unit_count: u32,
66    #[cube(comptime)]
67    num_slices: u32,
68    #[cube(comptime)]
69    ident: MatmulIdent,
70}
71
72#[cube]
73impl<IP: MatrixPrecision> AsyncLoadingJob<IP, StridedTilingLayout>
74    for AsynFullMaximizeSliceLengthJob
75{
76    fn execute_task<CM: CopyMechanism, G: GlobalConfig>(
77        this: &mut Self,
78        task_id: u32,
79        tensor_reader: &GlobalIterator<Line<IP::Global>>,
80        stage: &mut StridedStage<IP::Stage, StridedTilingLayout>,
81        mechanism: &CM,
82        #[comptime] config: G,
83    ) {
84        let nth_slice = this.unit_count * task_id + UNIT_POS;
85
86        #[allow(clippy::collapsible_else_if)]
87        if comptime!(this.num_slices.is_multiple_of(this.unit_count)) {
88            load_nth_slice::<IP::Global, IP::Stage, CM, G>(
89                nth_slice,
90                tensor_reader,
91                stage,
92                mechanism,
93                this.ident,
94                config,
95            );
96        } else {
97            if nth_slice < this.num_slices {
98                load_nth_slice::<IP::Global, IP::Stage, CM, G>(
99                    nth_slice,
100                    tensor_reader,
101                    stage,
102                    mechanism,
103                    this.ident,
104                    config,
105                );
106            }
107        };
108    }
109
110    fn task_count(this: &Self) -> comptime_type!(u32) {
111        this.num_tasks_per_unit
112    }
113}
114
115#[cube]
116fn load_nth_slice<EG: Numeric, ES: Numeric, CM: CopyMechanism, G: GlobalConfig>(
117    nth_slice: u32,
118    global_iter: &GlobalIterator<Line<EG>>,
119    stage: &mut StridedStage<ES, StridedTilingLayout>,
120    mechanism: &CM,
121    #[comptime] ident: MatmulIdent,
122    #[comptime] config: G,
123) {
124    let window = load_window_in_stage(
125        &global_iter.view(),
126        nth_slice,
127        comptime!(config.global_memory_config(ident)),
128    );
129    let mut destination: SliceMut<Line<ES>> = StridedTilingLayout::nth_slice::<ES>(
130        stage,
131        nth_slice,
132        comptime!(config.stage_memory_config(ident)),
133    );
134
135    CM::memcpy_async(mechanism, &window.try_cast_unchecked(), &mut destination);
136}