1use cubecl_common::quant::scheme::{QuantScheme, QuantStore, QuantValue};
2use cubecl_core::{
3 Runtime,
4 client::ComputeClient,
5 prelude::{CubePrimitive, Numeric, TensorHandleRef},
6};
7
8use cubecl_std::tensor::{TensorHandle, into_contiguous_packed, into_contiguous_pitched};
9
10use crate::{
11 components::{
12 AccG, LhsG, MatmulSetupError, RhsG,
13 tile::{accelerated::AcceleratedMatmul, io::Filled},
14 },
15 kernels::layered::{
16 Selection,
17 double_buffering::DoubleBufferingArgs,
18 double_unit::{DoubleUnitAlgorithm, DoubleUnitSelectionArgs},
19 ordered_double_buffering::OrderedSelectionArgs,
20 simple::SimpleArgs,
21 simple_unit::SimpleUnitSelectionArgs,
22 vecmat::{DoubleVecMatAlgorithm, SimpleVecMatAlgorithm},
23 },
24};
25
26use super::{
27 components::{
28 MatmulPrecision,
29 global::read::{
30 async_full_cooperative, async_full_cyclic, async_full_maximize_slice_length,
31 async_full_maximize_unit_count, sync_full_strided, sync_full_tilewise,
32 },
33 stage::{ColMajorTilingOrder, RowMajorTilingOrder},
34 },
35 kernels::{
36 layered::{
37 self,
38 double_buffering::{
39 CyclicDoubleBufferingAlgorithm, HybridDoubleBufferingAlgorithm,
40 TilewiseDoubleBufferingAlgorithm,
41 },
42 ordered_double_buffering::OrderedDoubleBufferingAlgorithm,
43 simple::SimpleAlgorithm,
44 simple_barrier::SimpleBarrierAlgorithm,
45 simple_tma::SimpleTmaAlgorithm,
46 simple_unit::SimpleUnitAlgorithm,
47 },
48 naive,
49 },
50};
51
/// Selects which matmul algorithm family to launch.
///
/// Most variants carry a [`Selection`] tuning the kernel selection for that
/// family; the mapping to concrete algorithms lives in `launch_ref`.
#[derive(Debug, Clone, Default)]
pub enum Strategy {
    /// Simple algorithm with a synchronous full-stage global reading strategy.
    Simple(SyncReadingStrategy, Selection<SimpleArgs>),
    /// Simple algorithm using barrier-based asynchronous global reads.
    SimpleBarrier(AsyncReadingStrategy),
    /// Double-buffered algorithm with a synchronous partial-stage reading strategy.
    DoubleBuffering(SyncPartialReadingStrategy, Selection<DoubleBufferingArgs>),
    /// Simple unit algorithm (used as the fallback when the accelerated path is
    /// unavailable — see the `Auto` handling in `launch_ref`).
    SimpleUnit(Selection<SimpleUnitSelectionArgs>),
    /// Double-buffered unit algorithm.
    DoubleUnit(Selection<DoubleUnitSelectionArgs>),
    /// Simple vector-matrix specialization.
    SimpleVecMat(Selection<()>),
    /// Double-buffered vector-matrix specialization.
    DoubleVecMat(Selection<()>),
    /// Double buffering with ordered loading.
    OrderedDoubleBuffering(Selection<OrderedSelectionArgs>),
    /// Naive (non-layered) reference kernel.
    Naive,
    /// Tries the accelerated simple algorithm first and falls back to
    /// [`Strategy::SimpleUnit`] when it reports `Unavailable`.
    #[default]
    Auto,
}
71
/// Synchronous full-stage global reading strategies for [`Strategy::Simple`].
///
/// Each variant maps to a loading implementation in `components::global::read`.
#[derive(Debug, Clone)]
pub enum SyncReadingStrategy {
    /// Default cyclic loading.
    Cyclic,
    /// Strided loading (`sync_full_strided`).
    Strided,
    /// Tile-wise loading (`sync_full_tilewise`), column-major for lhs and
    /// row-major for rhs.
    Tilewise,
}
79
/// Synchronous partial-stage reading strategies for [`Strategy::DoubleBuffering`].
///
/// Each variant selects a double-buffering algorithm in
/// `kernels::layered::double_buffering`.
#[derive(Debug, Clone)]
pub enum SyncPartialReadingStrategy {
    /// Cyclic double-buffered loading.
    Cyclic,
    /// Tile-wise double-buffered loading.
    Tilewise,
    /// Hybrid double-buffered loading.
    Hybrid,
}
87
/// Asynchronous global reading strategies for [`Strategy::SimpleBarrier`].
#[derive(Debug, Clone)]
pub enum AsyncReadingStrategy {
    /// Cooperative full loading (`async_full_cooperative`).
    Cooperative,
    /// Cyclic full loading (`async_full_cyclic`).
    Cyclic,
    /// Loading that maximizes the length of each copied slice.
    MaximizeSliceLength,
    /// Loading that maximizes the number of participating units.
    MaximizeUnitCount,
    /// TMA-based loading — presumably requires Tensor Memory Accelerator
    /// hardware support; dispatched through `matmul_cmma_tma_ref_no_check`.
    Tma,
}
97
/// Owned input tensor for a matmul, either plain or quantized.
pub enum MatmulInputHandle<R: Runtime, E: CubePrimitive, S: CubePrimitive = f32> {
    /// Unquantized tensor data of element type `E`.
    Normal(TensorHandle<R, E>),
    /// Quantized tensor data together with its quantization parameters.
    Quantized {
        /// Quantized (possibly packed) values.
        data: TensorHandle<R, E>,
        /// Quantization scales, of element type `S`.
        scale: TensorHandle<R, S>,
        /// Logical tensor shape — presumably the unpacked shape, which may
        /// differ from `data.shape` when values are packed (see
        /// `MatmulInputHandleRef::into_contiguous`).
        shape: Vec<usize>,
        /// Quantization scheme describing the value type and storage.
        scheme: QuantScheme,
    },
}
107
108impl<R: Runtime, E: Numeric> MatmulInputHandle<R, E> {
109 pub fn as_ref(&self) -> MatmulInputHandleRef<'_, R> {
110 match self {
111 MatmulInputHandle::Normal(handle) => MatmulInputHandleRef::Normal(handle.as_ref()),
112 MatmulInputHandle::Quantized {
113 data,
114 scale,
115 shape,
116 scheme,
117 } => MatmulInputHandleRef::Quantized {
118 data: data.as_ref(),
119 scale: scale.as_ref(),
120 shape,
121 scheme,
122 },
123 }
124 }
125
126 pub fn from_ref(handle: &MatmulInputHandleRef<'_, R>) -> Self {
127 match handle {
128 MatmulInputHandleRef::Normal(handle) => {
129 MatmulInputHandle::Normal(TensorHandle::from_ref(handle))
130 }
131 MatmulInputHandleRef::Quantized {
132 data,
133 scale,
134 shape,
135 scheme,
136 } => MatmulInputHandle::Quantized {
137 data: TensorHandle::from_ref(data),
138 scale: TensorHandle::from_ref(scale),
139 shape: shape.to_vec(),
140 scheme: **scheme,
141 },
142 }
143 }
144
145 pub fn data(&self) -> &TensorHandle<R, E> {
146 match self {
147 MatmulInputHandle::Normal(handle) => handle,
148 MatmulInputHandle::Quantized { data, .. } => data,
149 }
150 }
151
152 pub fn swap_dims(&mut self, dim0: usize, dim1: usize) {
153 match self {
154 MatmulInputHandle::Normal(handle) => {
155 handle.shape.swap(dim0, dim1);
156 handle.strides.swap(dim0, dim1);
157 }
158 MatmulInputHandle::Quantized {
159 data, scale, shape, ..
160 } => {
161 data.shape.swap(dim0, dim1);
162 data.strides.swap(dim0, dim1);
163 if scale.shape.len() == data.shape.len() {
164 scale.shape.swap(dim0, dim1);
165 scale.strides.swap(dim0, dim1);
166 }
167 shape.swap(dim0, dim1);
168 }
169 }
170 }
171}
172
173impl<R: Runtime, E: CubePrimitive> Clone for MatmulInputHandle<R, E> {
174 fn clone(&self) -> Self {
175 match self {
176 Self::Normal(handle) => Self::Normal(handle.clone()),
177 Self::Quantized {
178 data,
179 scale,
180 shape,
181 scheme,
182 } => Self::Quantized {
183 data: data.clone(),
184 scale: scale.clone(),
185 shape: shape.clone(),
186 scheme: *scheme,
187 },
188 }
189 }
190}
191
/// Borrowed counterpart of [`MatmulInputHandle`].
#[derive(Debug)]
pub enum MatmulInputHandleRef<'a, R: Runtime> {
    /// Unquantized tensor handle.
    Normal(TensorHandleRef<'a, R>),
    /// Quantized tensor handle with its quantization parameters.
    Quantized {
        /// Quantized (possibly packed) values.
        data: TensorHandleRef<'a, R>,
        /// Quantization scales.
        scale: TensorHandleRef<'a, R>,
        /// Logical tensor shape — presumably unpacked; may differ from
        /// `data.shape` when values are packed (see `into_contiguous`).
        shape: &'a [usize],
        /// Quantization scheme describing the value type and storage.
        scheme: &'a QuantScheme,
    },
}
203
impl<'a, R: Runtime> Clone for MatmulInputHandleRef<'a, R> {
    fn clone(&self) -> Self {
        // The ref handle is `Copy` (impl below), so cloning is a bitwise copy.
        *self
    }
}

// All fields are borrows or `Copy` data, so the borrowed handle is freely copyable.
impl<'a, R: Runtime> Copy for MatmulInputHandleRef<'a, R> {}
211
impl<'a, R: Runtime> MatmulInputHandleRef<'a, R> {
    /// Wraps an unquantized tensor handle.
    pub fn new(data: TensorHandleRef<'a, R>) -> Self {
        Self::Normal(data)
    }

    /// Wraps quantized tensor data with its scales, logical shape and scheme.
    pub fn quantized(
        data: TensorHandleRef<'a, R>,
        scale: TensorHandleRef<'a, R>,
        shape: &'a [usize],
        scheme: &'a QuantScheme,
    ) -> Self {
        Self::Quantized {
            data,
            scale,
            shape,
            scheme,
        }
    }

    /// Returns the underlying data handle (the quantized values for `Quantized`).
    pub fn data(&self) -> &TensorHandleRef<'a, R> {
        match self {
            MatmulInputHandleRef::Normal(handle) => handle,
            MatmulInputHandleRef::Quantized { data, .. } => data,
        }
    }

    /// Mutable access to the underlying data handle.
    pub fn data_mut(&mut self) -> &mut TensorHandleRef<'a, R> {
        match self {
            MatmulInputHandleRef::Normal(handle) => handle,
            MatmulInputHandleRef::Quantized { data, .. } => data,
        }
    }

    /// Quantization scales, or `None` for unquantized input.
    pub fn scale(&self) -> Option<&TensorHandleRef<'a, R>> {
        match self {
            MatmulInputHandleRef::Normal(_) => None,
            MatmulInputHandleRef::Quantized { scale, .. } => Some(scale),
        }
    }

    /// Quantization scheme, or `None` for unquantized input.
    pub fn scheme(&self) -> Option<&QuantScheme> {
        match self {
            MatmulInputHandleRef::Normal(_) => None,
            MatmulInputHandleRef::Quantized { scheme, .. } => Some(scheme),
        }
    }

    /// Logical shape of the tensor: the handle's own shape for `Normal`, the
    /// stored logical shape for `Quantized`.
    pub fn shape(&self) -> &[usize] {
        match self {
            MatmulInputHandleRef::Normal(handle) => handle.shape,
            MatmulInputHandleRef::Quantized { shape, .. } => shape,
        }
    }

    /// Copies the tensor into a contiguous (pitched) layout, returning an owned
    /// handle of element type `E`.
    ///
    /// Quantized packed storage is special-cased:
    /// - `QuantStore::Native` with `QuantValue::E2M1` packs with factor 2 into
    ///   `u8` (two 4-bit values per byte);
    /// - `QuantStore::U32` packs `scheme.num_quants()` values per `u32`;
    /// - any other store is copied as plain `E` elements.
    ///
    /// NOTE(review): in the packed branches the `u8`/`u32` result is rebuilt as
    /// a `TensorHandle<R, E>` via `from_ref` on its raw ref — presumably sound
    /// because the handle carries an untyped buffer; confirm against
    /// `TensorHandle::from_ref`.
    pub fn into_contiguous<E: Numeric>(
        &self,
        client: &ComputeClient<R::Server>,
    ) -> MatmulInputHandle<R, E> {
        match self {
            MatmulInputHandleRef::Normal(data) => {
                MatmulInputHandle::Normal(into_contiguous_pitched::<R, E>(client, data))
            }
            MatmulInputHandleRef::Quantized {
                data,
                scale,
                shape,
                scheme,
            } => {
                let data = match scheme.store {
                    QuantStore::Native if scheme.value == QuantValue::E2M1 => {
                        let data = into_contiguous_packed::<R, u8>(client, data, shape, 2);
                        // Reinterpret the packed u8 buffer as element type `E`.
                        TensorHandle::from_ref(&data.as_ref())
                    }
                    QuantStore::U32 => {
                        let data = into_contiguous_packed::<R, u32>(
                            client,
                            data,
                            shape,
                            scheme.num_quants() as u32,
                        );
                        // Reinterpret the packed u32 buffer as element type `E`.
                        TensorHandle::from_ref(&data.as_ref())
                    }
                    _ => into_contiguous_pitched::<R, E>(client, data),
                };
                MatmulInputHandle::Quantized {
                    data,
                    scale: TensorHandle::from_ref(scale),
                    shape: shape.to_vec(),
                    scheme: **scheme,
                }
            }
        }
    }
}
309
310#[allow(clippy::result_large_err)]
311pub fn launch<R: Runtime, MP: MatmulPrecision>(
312 strategy: &Strategy,
313 client: &ComputeClient<R::Server>,
314 lhs: MatmulInputHandle<R, LhsG<MP>>,
315 rhs: MatmulInputHandle<R, RhsG<MP>>,
316 out: TensorHandle<R, AccG<MP>>,
317) -> Result<(), MatmulSetupError> {
318 launch_ref::<R, MP>(
319 strategy,
320 client,
321 &lhs.as_ref(),
322 &rhs.as_ref(),
323 &out.as_ref(),
324 )
325}
326
327#[allow(clippy::result_large_err)]
328pub fn launch_ref<R: Runtime, MP: MatmulPrecision>(
329 strategy: &Strategy,
330 client: &ComputeClient<R::Server>,
331 lhs: &MatmulInputHandleRef<R>,
332 rhs: &MatmulInputHandleRef<R>,
333 out: &TensorHandleRef<R>,
334) -> Result<(), MatmulSetupError> {
335 type Accelerated = AcceleratedMatmul<Filled>;
336
337 match strategy {
338 Strategy::Simple(loading_strategy, selection) => match loading_strategy {
339 SyncReadingStrategy::Cyclic => {
340 layered::launch_ref::<R, MP, SimpleAlgorithm<Accelerated>>(
341 client, lhs, rhs, out, selection,
342 )
343 }
344 SyncReadingStrategy::Strided => layered::launch_ref::<
345 R,
346 MP,
347 SimpleAlgorithm<
348 Accelerated,
349 sync_full_strided::SyncFullStridedLoading,
350 sync_full_strided::SyncFullStridedLoading,
351 >,
352 >(client, lhs, rhs, out, selection),
353 SyncReadingStrategy::Tilewise => {
354 layered::launch_ref::<
355 R,
356 MP,
357 SimpleAlgorithm<
358 Accelerated,
359 sync_full_tilewise::SyncFullTilewiseLoading<ColMajorTilingOrder>,
360 sync_full_tilewise::SyncFullTilewiseLoading<RowMajorTilingOrder>,
361 >,
362 >(client, lhs, rhs, out, &Default::default())
363 }
364 },
365 Strategy::SimpleBarrier(loading_strategy) => match loading_strategy {
366 AsyncReadingStrategy::Cooperative => {
367 layered::launch_ref::<
368 R,
369 MP,
370 SimpleBarrierAlgorithm<
371 Accelerated,
372 async_full_cooperative::AsyncFullCooperativeLoading,
373 >,
374 >(client, lhs, rhs, out, &Default::default())
375 }
376 AsyncReadingStrategy::Cyclic => {
377 layered::launch_ref::<
378 R,
379 MP,
380 SimpleBarrierAlgorithm<
381 Accelerated,
382 async_full_cyclic::AsyncFullCyclicLoading<ColMajorTilingOrder>,
383 >,
384 >(client, lhs, rhs, out, &Default::default())
385 }
386 AsyncReadingStrategy::MaximizeSliceLength => {
387 layered::launch_ref::<
388 R,
389 MP,
390 SimpleBarrierAlgorithm<
391 Accelerated,
392 async_full_maximize_slice_length::AsyncFullMaximizeSliceLengthLoading,
393 >,
394 >(client, lhs, rhs, out, &Default::default())
395 }
396 AsyncReadingStrategy::MaximizeUnitCount => {
397 layered::launch_ref::<
398 R,
399 MP,
400 SimpleBarrierAlgorithm<
401 Accelerated,
402 async_full_maximize_unit_count::AsyncFullMaximizeUnitCountLoading,
403 >,
404 >(client, lhs, rhs, out, &Default::default())
405 }
406 AsyncReadingStrategy::Tma => {
407 layered::matmul_cmma_tma_ref_no_check::<R, MP, SimpleTmaAlgorithm<Accelerated>>(
408 client,
409 lhs,
410 rhs,
411 out,
412 (false, false),
413 &Default::default(),
414 )
415 }
416 },
417 Strategy::DoubleBuffering(loading_strategy, selection) => match loading_strategy {
418 SyncPartialReadingStrategy::Cyclic => {
419 layered::launch_ref::<R, MP, CyclicDoubleBufferingAlgorithm<Accelerated>>(
420 client, lhs, rhs, out, selection,
421 )
422 }
423 SyncPartialReadingStrategy::Tilewise => {
424 layered::launch_ref::<R, MP, TilewiseDoubleBufferingAlgorithm<Accelerated>>(
425 client, lhs, rhs, out, selection,
426 )
427 }
428 SyncPartialReadingStrategy::Hybrid => {
429 layered::launch_ref::<R, MP, HybridDoubleBufferingAlgorithm<Accelerated>>(
430 client, lhs, rhs, out, selection,
431 )
432 }
433 },
434 Strategy::OrderedDoubleBuffering(selection) => {
435 layered::launch_ref::<R, MP, OrderedDoubleBufferingAlgorithm<Accelerated>>(
436 client, lhs, rhs, out, selection,
437 )
438 }
439 Strategy::SimpleUnit(selection) => {
440 layered::launch_ref::<R, MP, SimpleUnitAlgorithm>(client, lhs, rhs, out, selection)
441 }
442 Strategy::DoubleUnit(selection) => {
443 layered::launch_ref::<R, MP, DoubleUnitAlgorithm>(client, lhs, rhs, out, selection)
444 }
445 Strategy::Naive => {
446 naive::launch_ref::<R, LhsG<MP>, AccG<MP>>(client, lhs, rhs, out)?;
447 Ok(())
448 }
449 Strategy::Auto => {
450 if let Err(err) = layered::launch_ref::<R, MP, SimpleAlgorithm<Accelerated>>(
451 client,
452 lhs,
453 rhs,
454 out,
455 &Default::default(),
456 ) {
457 match err {
458 MatmulSetupError::Unavailable(_) => {
459 layered::launch_ref::<R, MP, SimpleUnitAlgorithm>(
460 client,
461 lhs,
462 rhs,
463 out,
464 &Default::default(),
465 )
466 .unwrap();
467 }
468 _ => panic!("{err:?}"),
469 }
470 }
471
472 Ok(())
473 }
474 Strategy::SimpleVecMat(selection) => {
475 layered::launch_ref::<R, MP, SimpleVecMatAlgorithm>(client, lhs, rhs, out, selection)
476 }
477 Strategy::DoubleVecMat(selection) => {
478 layered::launch_ref::<R, MP, DoubleVecMatAlgorithm>(client, lhs, rhs, out, selection)
479 }
480 }
481}