cubecl_matmul/components/batch/partitioned_matmul/partition/
matmul.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use crate::components::{
5    AccG, LhsG, MatmulPrecision, RhsG,
6    batch::SliceIndex,
7    global::{self, GlobalConfig},
8};
9use cubecl_std::{
10    CubeOption, CubeOptionExpand,
11    tensor::{View, layout::Coords3d},
12};
13
14#[derive(CubeType)]
15/// Area of a tensor a cube is responsible of performing matmul
16pub struct PartitionRanges {
17    row: PartitionRangeDim,
18    col: PartitionRangeDim,
19    batch: PartitionRangeDim,
20}
21
22#[derive(CubeType)]
23pub struct PartitionRangeDim {
24    start: u32,
25    #[cube(comptime)]
26    step: u32,
27    #[cube(comptime)]
28    num_steps: u32,
29}
30
31#[cube]
32/// Iterates on several global matmul across a global partition
33pub trait GlobalPartitionMatmul: 'static + Send + Sync {
34    fn execute<MP: MatmulPrecision, GMM: global::GlobalMatmul<MP>>(
35        a: View<Line<LhsG<MP>>, Coords3d>,
36        b: View<Line<RhsG<MP>>, Coords3d>,
37        c: CubeOption<View<Line<AccG<MP>>, Coords3d>>,
38        out: View<Line<AccG<MP>>, Coords3d, ReadWrite>,
39        partition_ranges: PartitionRanges,
40        acc: GMM::Accumulators,
41        k_range: (u32, u32),
42        #[comptime] config: GMM::Config,
43    );
44}
45
46#[derive(CubeType)]
47/// Iterates on global matmuls in a row major fashion
48pub struct RowMajorGlobalPartitionMatmul {}
49
50#[derive(CubeType)]
51/// Iterates on global matmuls in a col major fashion
52pub struct ColMajorGlobalPartitionMatmul {}
53
54#[cube]
55impl PartitionRanges {
56    /// Create a new [PartitionRanges]
57    pub fn new(
58        row: PartitionRangeDim,
59        col: PartitionRangeDim,
60        batch: PartitionRangeDim,
61    ) -> PartitionRanges {
62        PartitionRanges { row, col, batch }
63    }
64}
65
66#[cube]
67impl PartitionRangeDim {
68    /// Create a new [PartitionRangeDim]
69    pub fn new(
70        cube_pos: u32,
71        #[comptime] stage_dim: u32,
72        #[comptime] global_partition_dim: u32,
73    ) -> PartitionRangeDim {
74        let start = cube_pos * global_partition_dim;
75        PartitionRangeDim {
76            start,
77            step: stage_dim,
78            num_steps: global_partition_dim.div_ceil(stage_dim),
79        }
80    }
81}
82
83#[cube]
84impl GlobalPartitionMatmul for RowMajorGlobalPartitionMatmul {
85    fn execute<MP: MatmulPrecision, GMM: global::GlobalMatmul<MP>>(
86        a: View<Line<LhsG<MP>>, Coords3d>,
87        b: View<Line<RhsG<MP>>, Coords3d>,
88        c: CubeOption<View<Line<AccG<MP>>, Coords3d>>,
89        out: View<Line<AccG<MP>>, Coords3d, ReadWrite>,
90        ranges: PartitionRanges,
91        mut acc: GMM::Accumulators,
92        k_range: (u32, u32),
93        #[comptime] config: GMM::Config,
94    ) {
95        // Needed for the unroll macro to work.
96        let num_steps_batch = comptime!(ranges.batch.num_steps);
97        let num_steps_row = comptime!(ranges.row.num_steps);
98        let num_steps_col = comptime!(ranges.col.num_steps);
99
100        #[unroll(num_steps_batch == 1)]
101        for batch in 0..num_steps_batch {
102            let batch_iter = ranges.batch.start + batch * ranges.batch.step;
103
104            #[unroll(num_steps_row == 1)]
105            for row in 0..num_steps_row {
106                let row_offset = ranges.row.start + row * ranges.row.step;
107
108                #[unroll(num_steps_col == 1)]
109                for col in 0..num_steps_col {
110                    let col_offset = ranges.col.start + col * ranges.col.step;
111
112                    execute_global_matmul::<MP, GMM>(
113                        a, b, c, out, batch_iter, row_offset, col_offset, &mut acc, k_range, config,
114                    );
115                }
116            }
117        }
118    }
119}
120
121#[cube]
122impl GlobalPartitionMatmul for ColMajorGlobalPartitionMatmul {
123    fn execute<MP: MatmulPrecision, GMM: global::GlobalMatmul<MP>>(
124        a: View<Line<LhsG<MP>>, Coords3d>,
125        b: View<Line<RhsG<MP>>, Coords3d>,
126        c: CubeOption<View<Line<AccG<MP>>, Coords3d>>,
127        out: View<Line<AccG<MP>>, Coords3d, ReadWrite>,
128        ranges: PartitionRanges,
129        mut acc: GMM::Accumulators,
130        k_range: (u32, u32),
131        #[comptime] config: GMM::Config,
132    ) {
133        // Needed for the unroll macro to work.
134        let num_steps_batch = comptime!(ranges.batch.num_steps);
135        let num_steps_row = comptime!(ranges.row.num_steps);
136        let num_steps_col = comptime!(ranges.col.num_steps);
137
138        #[unroll(num_steps_batch == 1)]
139        for batch in 0..num_steps_batch {
140            let batch_iter = ranges.batch.start + batch * ranges.batch.step;
141
142            #[unroll(num_steps_col == 1)]
143            for col in 0..num_steps_col {
144                let col_offset = ranges.col.start + col * ranges.col.step;
145
146                #[unroll(num_steps_row == 1)]
147                for row in 0..num_steps_row {
148                    let row_offset = ranges.row.start + row * ranges.row.step;
149
150                    execute_global_matmul::<MP, GMM>(
151                        a, b, c, out, batch_iter, row_offset, col_offset, &mut acc, k_range, config,
152                    );
153                }
154            }
155        }
156    }
157}
158
159#[cube]
160/// Execute global matmul on lhs, rhs, writing in out.
161/// m and n offsets are absolute rows and columns
162pub(crate) fn execute_global_matmul<MP: MatmulPrecision, GMM: global::GlobalMatmul<MP>>(
163    a: View<Line<LhsG<MP>>, Coords3d>,
164    b: View<Line<RhsG<MP>>, Coords3d>,
165    c: CubeOption<View<Line<AccG<MP>>, Coords3d>>,
166    out: View<Line<AccG<MP>>, Coords3d, ReadWrite>,
167    nth_batch: u32,
168    m_offset: u32,
169    n_offset: u32,
170    acc: &mut GMM::Accumulators,
171    k_range: (u32, u32),
172    #[comptime] config: GMM::Config,
173) {
174    let tiling = config.tiling_scheme();
175    let stage_m = tiling.elements_in_stage_m().runtime();
176    let stage_n = tiling.elements_in_stage_n().runtime();
177    let k_size = k_range.1 - k_range.0;
178
179    let a = a.view(SliceIndex::new(nth_batch, a.shape()));
180    let b = b.view(SliceIndex::new(nth_batch, b.shape()));
181    let c = match c {
182        CubeOption::Some(c) => {
183            let c = c.view(SliceIndex::new(nth_batch, c.shape()));
184            CubeOption::new_Some(c.slice_unchecked((m_offset, n_offset), (stage_m, stage_n)))
185        }
186        CubeOption::None => CubeOption::new_None(),
187    };
188    let out = out.view_mut(SliceIndex::new(nth_batch, out.shape()));
189
190    GMM::execute(
191        GMM::init_lhs_global_reader(
192            a.slice_unchecked((m_offset, k_range.0), (stage_m, k_size)),
193            config,
194        ),
195        GMM::init_rhs_global_reader(
196            b.slice_unchecked((k_range.0, n_offset), (k_size, stage_n)),
197            config,
198        ),
199        GMM::init_acc_global_reader(c, config),
200        GMM::init_global_writer(
201            out.slice_mut_unchecked((m_offset, n_offset), (stage_m, stage_n)),
202            config,
203        ),
204        acc,
205        k_range,
206        config,
207    );
208}