// cubecl_linalg/matmul/components/batch/shared.rs

use cubecl_core as cubecl;
use cubecl_core::prelude::*;

use crate::matmul::components::{
    MatmulPrecision,
    global::{self, Quantization},
};
use cubecl_std::CubeOption;
use cubecl_std::tensor::r#virtual::{ReadWrite, VirtualTensor};
10
11#[cube]
12pub(crate) fn gmm_execute<MP: MatmulPrecision, GMM: global::GlobalMatmul<MP>>(
15 lhs: VirtualTensor<MP::EI>,
16 rhs: VirtualTensor<MP::EI>,
17 out: VirtualTensor<MP::EO, ReadWrite>,
18 x_offset: u32,
19 y_offset: u32,
20 nth_batch: u32,
21 acc: &mut GMM::Accumulator,
22 k_range: (u32, u32),
23 quantization: CubeOption<Quantization<MP>>,
24 #[comptime] config: GMM::Config,
25) {
26 let rank = out.rank();
27
28 let batch_out = nth_batch * out.stride(rank - 2) * out.shape(rank - 2);
29 let mut batch_lhs = 0u32.runtime();
30 let mut batch_rhs = 0u32.runtime();
31 for axis in 0..rank - 2 {
32 let tmp = batch_out / out.stride(axis);
33 batch_lhs += tmp % lhs.shape(axis) * lhs.stride(axis);
34 batch_rhs += tmp % rhs.shape(axis) * rhs.stride(axis);
35 }
36
37 GMM::execute(
38 GMM::init_lhs_loader(
39 lhs,
40 x_offset,
41 k_range.0,
42 nth_batch,
43 batch_lhs,
44 quantization,
45 config,
46 ),
47 GMM::init_rhs_loader(
48 rhs,
49 k_range.0,
50 y_offset,
51 nth_batch,
52 batch_rhs,
53 quantization,
54 config,
55 ),
56 GMM::init_unloader(out, x_offset, y_offset, nth_batch, batch_out),
57 acc,
58 k_range,
59 config,
60 );
61}
62
63#[cube]
64pub fn swizzle(nth: u32, height: u32, #[comptime] swizzle_width: u32) -> (u32, u32) {
65 let num_elem_per_swizzle_col = height * swizzle_width;
66
67 let swizzle_id = nth % num_elem_per_swizzle_col;
68 let swizzle_col = nth / num_elem_per_swizzle_col;
69
70 let col_within_swizzle = swizzle_id / height;
71 let col = swizzle_col * swizzle_width + col_within_swizzle;
72
73 let topdown_row = swizzle_id % height;
74 let is_bottom_up = swizzle_col % 2;
75
76 let row = topdown_row + is_bottom_up * (height - 2 * topdown_row - 1);
77
78 (row, col)
79}