// cubecl_linalg/matmul/components/batch/span.rs
use cubecl_core as cubecl;
use cubecl_core::prelude::*;

use crate::matmul::components::{
    MatmulPrecision,
    batch::shared::swizzle,
    global::{self, Quantization},
};
use cubecl_std::{
    CubeOption,
    tensor::r#virtual::{ReadWrite, VirtualTensor},
};

use super::shared::gmm_execute;

/// The region of the output that one cube is responsible for computing.
///
/// Each field is an independent [`SpanDim`] iteration range; the
/// [`SpanMatmul`] strategies decide in which order the (row, col, batch)
/// positions of this region are traversed.
#[derive(CubeType)]
pub struct Span {
    row: SpanDim,
    col: SpanDim,
    batch: SpanDim,
}

/// A half-open range `[start, end)` walked with stride `step`, describing
/// one dimension of a [`Span`].
#[derive(CubeType)]
pub struct SpanDim {
    start: u32,
    end: u32,
    step: u32,
}

/// Strategy deciding in which order a cube visits the positions of its
/// [`Span`], running one global matmul per (row, col, batch) position.
#[cube]
pub trait SpanMatmul: 'static + Send + Sync {
    /// Executes the global matmul `GMM` at every position of `span`,
    /// forwarding `k_range` and `quantization` to each invocation. Every
    /// implementation zeroes `acc` before reusing it for the next position.
    fn execute<MP: MatmulPrecision, GMM: global::GlobalMatmul<MP>>(
        lhs: VirtualTensor<MP::EI>,
        rhs: VirtualTensor<MP::EI>,
        out: VirtualTensor<MP::EO, ReadWrite>,
        span: Span,
        acc: GMM::Accumulator,
        k_range: (u32, u32),
        quantization: CubeOption<Quantization<MP>>,
        #[comptime] config: GMM::Config,
    );
}

/// [`SpanMatmul`] strategy iterating batch, then row, with the column loop
/// innermost (row-major traversal within each batch).
#[derive(CubeType)]
pub struct RowMajorSpanMatmul {}

/// [`SpanMatmul`] strategy iterating batch, then column, with the row loop
/// innermost (column-major traversal within each batch).
#[derive(CubeType)]
pub struct ColMajorSpanMatmul {}

/// [`SpanMatmul`] strategy visiting the (row, col) grid of each batch in the
/// order produced by [`swizzle`] with width `W` — presumably to improve data
/// locality between consecutive matmuls (see `batch::shared::swizzle`).
#[derive(CubeType)]
pub struct SwizzleSpanMatmul<const W: u32> {}

#[cube]
impl Span {
    /// Builds a span from its per-dimension iteration ranges.
    pub fn new(row: SpanDim, col: SpanDim, batch: SpanDim) -> Span {
        Span { row, col, batch }
    }
}

#[cube]
impl SpanDim {
    /// Computes the stage-aligned sub-range of a dimension of size `shape`
    /// that the cube at `cube_pos` (out of `num_cubes`) must cover when each
    /// step handles `stage` elements.
    ///
    /// Stages are split as evenly as possible across cubes using ceiling
    /// divisions, and `end` is clamped to `shape` so the last cube does not
    /// run past the dimension.
    pub fn new(shape: u32, stage: u32, cube_pos: u32, num_cubes: u32) -> SpanDim {
        // Ceiling division: total stages needed to cover the dimension.
        let num_stages = (shape + stage - 1) / stage;
        // Ceiling division: stages assigned to each cube.
        let num = (num_stages + num_cubes - 1) / num_cubes;
        // Contiguous element span owned by one cube.
        let span = num * stage;
        let start = cube_pos * span;
        let end = Min::min(start + span, shape);
        SpanDim {
            start,
            end,
            step: stage,
        }
    }

    /// Number of steps needed to walk this range (ceiling division of the
    /// range length by `step`).
    pub fn num_iterations(&self) -> u32 {
        let range = self.end - self.start;
        (range + self.step - 1) / self.step
    }
}

#[cube]
impl SpanMatmul for RowMajorSpanMatmul {
    /// Visits span positions batch-by-batch, then row-by-row, with the
    /// column loop innermost; runs one global matmul per position.
    fn execute<MP: MatmulPrecision, GMM: global::GlobalMatmul<MP>>(
        lhs: VirtualTensor<MP::EI>,
        rhs: VirtualTensor<MP::EI>,
        out: VirtualTensor<MP::EO, ReadWrite>,
        span: Span,
        mut acc: GMM::Accumulator,
        k_range: (u32, u32),
        quantization: CubeOption<Quantization<MP>>,
        #[comptime] config: GMM::Config,
    ) {
        for batch_iter in range_stepped(span.batch.start, span.batch.end, span.batch.step) {
            for row_iter in range_stepped(span.row.start, span.row.end, span.row.step) {
                for col_iter in range_stepped(span.col.start, span.col.end, span.col.step) {
                    // Each position accumulates independently, so the shared
                    // accumulator is reset before every matmul.
                    GMM::zero_accumulator(&mut acc, config);
                    gmm_execute::<MP, GMM>(
                        lhs,
                        rhs,
                        out,
                        row_iter,
                        col_iter,
                        batch_iter,
                        &mut acc,
                        k_range,
                        quantization,
                        config,
                    );
                }
            }
        }
    }
}

#[cube]
impl SpanMatmul for ColMajorSpanMatmul {
    /// Visits span positions batch-by-batch, then column-by-column, with the
    /// row loop innermost; runs one global matmul per position.
    fn execute<MP: MatmulPrecision, GMM: global::GlobalMatmul<MP>>(
        lhs: VirtualTensor<MP::EI>,
        rhs: VirtualTensor<MP::EI>,
        out: VirtualTensor<MP::EO, ReadWrite>,
        span: Span,
        mut acc: GMM::Accumulator,
        k_range: (u32, u32),
        quantization: CubeOption<Quantization<MP>>,
        #[comptime] config: GMM::Config,
    ) {
        for batch_iter in range_stepped(span.batch.start, span.batch.end, span.batch.step) {
            for col_iter in range_stepped(span.col.start, span.col.end, span.col.step) {
                for row_iter in range_stepped(span.row.start, span.row.end, span.row.step) {
                    // Each position accumulates independently, so the shared
                    // accumulator is reset before every matmul.
                    GMM::zero_accumulator(&mut acc, config);
                    gmm_execute::<MP, GMM>(
                        lhs,
                        rhs,
                        out,
                        row_iter,
                        col_iter,
                        batch_iter,
                        &mut acc,
                        k_range,
                        quantization,
                        config,
                    );
                }
            }
        }
    }
}

#[cube]
impl<const W: u32> SpanMatmul for SwizzleSpanMatmul<W> {
    /// Visits the (row, col) grid of each batch in the order given by
    /// [`swizzle`] with width `W`; runs one global matmul per position.
    fn execute<MP: MatmulPrecision, GMM: global::GlobalMatmul<MP>>(
        lhs: VirtualTensor<MP::EI>,
        rhs: VirtualTensor<MP::EI>,
        out: VirtualTensor<MP::EO, ReadWrite>,
        span: Span,
        mut acc: GMM::Accumulator,
        k_range: (u32, u32),
        quantization: CubeOption<Quantization<MP>>,
        #[comptime] config: GMM::Config,
    ) {
        // Total number of (row, col) positions in this span.
        let num_swizzle = span.row.num_iterations() * span.col.num_iterations();

        for batch_iter in range_stepped(span.batch.start, span.batch.end, span.batch.step) {
            for n in 0..num_swizzle {
                // Each position accumulates independently, so the shared
                // accumulator is reset before every matmul.
                GMM::zero_accumulator(&mut acc, config);
                // Map the linear index `n` to a (row, col) step pair; the
                // exact traversal pattern is defined by `swizzle`.
                let (row, col) = swizzle(n, span.row.num_iterations(), W);

                // Convert step indices back to absolute tensor offsets.
                let row_iter = span.row.start + row * span.row.step;
                let col_iter = span.col.start + col * span.col.step;
                gmm_execute::<MP, GMM>(
                    lhs,
                    rhs,
                    out,
                    row_iter,
                    col_iter,
                    batch_iter,
                    &mut acc,
                    k_range,
                    quantization,
                    config,
                );
            }
        }
    }
}