cubecl_linalg/matmul/components/batch/span.rs

use cubecl_core as cubecl;
use cubecl_core::prelude::*;

use crate::matmul::components::{
    MatmulPrecision,
    batch::shared::swizzle,
    global::{self, Quantization},
};
use cubecl_std::{
    CubeOption,
    tensor::r#virtual::{ReadWrite, VirtualTensor},
};

use super::shared::gmm_execute;

#[derive(CubeType)]
/// Area of a tensor a cube is responsible for performing the matmul on.
/// Similar to the concept of a tensor slice, but specialized for matmul constraints.
pub struct Span {
    row: SpanDim,
    col: SpanDim,
    batch: SpanDim,
}

#[derive(CubeType)]
/// Span information in one dimension.
pub struct SpanDim {
    start: u32,
    end: u32,
    step: u32,
}

#[cube]
/// Iterates over several global matmuls across a span.
pub trait SpanMatmul: 'static + Send + Sync {
    fn execute<MP: MatmulPrecision, GMM: global::GlobalMatmul<MP>>(
        lhs: VirtualTensor<MP::EI>,
        rhs: VirtualTensor<MP::EI>,
        out: VirtualTensor<MP::EO, ReadWrite>,
        span: Span,
        acc: GMM::Accumulator,
        k_range: (u32, u32),
        quantization: CubeOption<Quantization<MP>>,
        #[comptime] config: GMM::Config,
    );
}
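
// A hedged sketch of the intended call site (not part of this file): the
// batch-level matmul would compute this cube's `Span` from the problem shape
// and cube counts, obtain an accumulator from the global matmul, then hand
// the traversal to a chosen `SpanMatmul` implementation `S`:
//
//     let span = Span::new(row_span, col_span, batch_span);
//     S::execute::<MP, GMM>(lhs, rhs, out, span, acc, k_range, quantization, config);
//
// All variable names above are illustrative only.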

#[derive(CubeType)]
/// Iterates over global matmuls in a row-major fashion.
pub struct RowMajorSpanMatmul {}

#[derive(CubeType)]
/// Iterates over global matmuls in a column-major fashion.
pub struct ColMajorSpanMatmul {}

#[derive(CubeType)]
/// Iterates over global matmuls following the swizzle algorithm.
///
/// The swizzle algorithm processes W elements per row in a top-down pass,
/// then shifts to the next W columns in a bottom-up pass.
/// This zigzag (top-down, bottom-up) repeats, covering the matrix span by span.
pub struct SwizzleSpanMatmul<const W: u32> {}
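
// A hand-traced sketch of the traversal described above (assuming `swizzle`
// in `batch::shared` implements this zigzag): with W = 1 over a 3x3 grid of
// global matmuls, successive n map to (row, col) as
//
//   n:   0      1      2      3      4      5      6      7      8
//   ->  (0,0)  (1,0)  (2,0)  (2,1)  (1,1)  (0,1)  (0,2)  (1,2)  (2,2)
//
// i.e. down the first column, back up the second, down the third. Consecutive
// matmuls stay spatially close, which can improve cache reuse of lhs and rhs.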

#[cube]
impl Span {
    pub fn new(row: SpanDim, col: SpanDim, batch: SpanDim) -> Span {
        Span { row, col, batch }
    }
}

#[cube]
impl SpanDim {
    pub fn new(shape: u32, stage: u32, cube_pos: u32, num_cubes: u32) -> SpanDim {
        // Ceiling division: stages needed to cover the whole dimension.
        let num_stages = (shape + stage - 1) / stage;
        // Stages assigned to each cube, rounded up so all stages are covered.
        let num = (num_stages + num_cubes - 1) / num_cubes;
        let span = num * stage;
        let start = cube_pos * span;
        // Clamp to the dimension's shape for the last cube(s).
        let end = Min::min(start + span, shape);
        SpanDim {
            start,
            end,
            step: stage,
        }
    }

    pub fn num_iterations(&self) -> u32 {
        let range = self.end - self.start;
        // Ceiling division: one iteration per stage-sized step.
        (range + self.step - 1) / self.step
    }
}
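
// A hand-traced example of the partitioning above (values are illustrative,
// not from the original file): SpanDim::new(1000, 128, 1, 4) gives
//   num_stages = ceil(1000 / 128)     = 8
//   num        = ceil(8 / 4)          = 2
//   span       = 2 * 128              = 256
//   start      = 1 * 256              = 256
//   end        = min(256 + 256, 1000) = 512
// so cube 1 covers offsets [256, 512) in steps of 128, i.e. 2 iterations.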

#[cube]
impl SpanMatmul for RowMajorSpanMatmul {
    fn execute<MP: MatmulPrecision, GMM: global::GlobalMatmul<MP>>(
        lhs: VirtualTensor<MP::EI>,
        rhs: VirtualTensor<MP::EI>,
        out: VirtualTensor<MP::EO, ReadWrite>,
        span: Span,
        mut acc: GMM::Accumulator,
        k_range: (u32, u32),
        quantization: CubeOption<Quantization<MP>>,
        #[comptime] config: GMM::Config,
    ) {
        for batch_iter in range_stepped(span.batch.start, span.batch.end, span.batch.step) {
            for row_iter in range_stepped(span.row.start, span.row.end, span.row.step) {
                for col_iter in range_stepped(span.col.start, span.col.end, span.col.step) {
                    GMM::zero_accumulator(&mut acc, config);
                    gmm_execute::<MP, GMM>(
                        lhs,
                        rhs,
                        out,
                        row_iter,
                        col_iter,
                        batch_iter,
                        &mut acc,
                        k_range,
                        quantization,
                        config,
                    );
                }
            }
        }
    }
}

#[cube]
impl SpanMatmul for ColMajorSpanMatmul {
    fn execute<MP: MatmulPrecision, GMM: global::GlobalMatmul<MP>>(
        lhs: VirtualTensor<MP::EI>,
        rhs: VirtualTensor<MP::EI>,
        out: VirtualTensor<MP::EO, ReadWrite>,
        span: Span,
        mut acc: GMM::Accumulator,
        k_range: (u32, u32),
        quantization: CubeOption<Quantization<MP>>,
        #[comptime] config: GMM::Config,
    ) {
        for batch_iter in range_stepped(span.batch.start, span.batch.end, span.batch.step) {
            for col_iter in range_stepped(span.col.start, span.col.end, span.col.step) {
                for row_iter in range_stepped(span.row.start, span.row.end, span.row.step) {
                    GMM::zero_accumulator(&mut acc, config);
                    gmm_execute::<MP, GMM>(
                        lhs,
                        rhs,
                        out,
                        row_iter,
                        col_iter,
                        batch_iter,
                        &mut acc,
                        k_range,
                        quantization,
                        config,
                    );
                }
            }
        }
    }
}
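
// For a span with two row iterations and two col iterations, RowMajorSpanMatmul
// visits (0,0) (0,1) (1,0) (1,1) while ColMajorSpanMatmul visits
// (0,0) (1,0) (0,1) (1,1). The order determines whether consecutive global
// matmuls reuse the same rows of lhs or the same columns of rhs.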

#[cube]
impl<const W: u32> SpanMatmul for SwizzleSpanMatmul<W> {
    fn execute<MP: MatmulPrecision, GMM: global::GlobalMatmul<MP>>(
        lhs: VirtualTensor<MP::EI>,
        rhs: VirtualTensor<MP::EI>,
        out: VirtualTensor<MP::EO, ReadWrite>,
        span: Span,
        mut acc: GMM::Accumulator,
        k_range: (u32, u32),
        quantization: CubeOption<Quantization<MP>>,
        #[comptime] config: GMM::Config,
    ) {
        let num_swizzle = span.row.num_iterations() * span.col.num_iterations();

        for batch_iter in range_stepped(span.batch.start, span.batch.end, span.batch.step) {
            for n in 0..num_swizzle {
                GMM::zero_accumulator(&mut acc, config);
                // Map the linear index n to a (row, col) position in swizzled order.
                let (row, col) = swizzle(n, span.row.num_iterations(), W);

                let row_iter = span.row.start + row * span.row.step;
                let col_iter = span.col.start + col * span.col.step;
                gmm_execute::<MP, GMM>(
                    lhs,
                    rhs,
                    out,
                    row_iter,
                    col_iter,
                    batch_iter,
                    &mut acc,
                    k_range,
                    quantization,
                    config,
                );
            }
        }
    }
}
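
// A minimal host-side sanity check (not in the original file), assuming, as
// cubecl allows, that #[cube] functions stay callable as plain Rust. Values
// follow the hand trace above `SpanDim::new`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn span_dim_partitions_shape_across_cubes() {
        // 1000 elements, stage of 128, 4 cubes: cube 1 should own [256, 512).
        let dim = SpanDim::new(1000, 128, 1, 4);
        assert_eq!(dim.start, 256);
        assert_eq!(dim.end, 512);
        assert_eq!(dim.step, 128);
        assert_eq!(dim.num_iterations(), 2);
    }
}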