// cubecl_linalg/matmul/components/batch/shared.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use crate::matmul::components::{
5    MatmulPrecision,
6    global::{self, Quantization},
7};
8use cubecl_std::CubeOption;
9use cubecl_std::tensor::r#virtual::{ReadWrite, VirtualTensor};
10
11#[cube]
12/// Execute global matmul on lhs, rhs, writing in out.
13/// x and y offsets are absolute rows and columns
14pub(crate) fn gmm_execute<MP: MatmulPrecision, GMM: global::GlobalMatmul<MP>>(
15    lhs: VirtualTensor<MP::EI>,
16    rhs: VirtualTensor<MP::EI>,
17    out: VirtualTensor<MP::EO, ReadWrite>,
18    x_offset: u32,
19    y_offset: u32,
20    nth_batch: u32,
21    acc: &mut GMM::Accumulator,
22    k_range: (u32, u32),
23    quantization: CubeOption<Quantization<MP>>,
24    #[comptime] config: GMM::Config,
25) {
26    let rank = out.rank();
27
28    let batch_out = nth_batch * out.stride(rank - 2) * out.shape(rank - 2);
29    let mut batch_lhs = 0u32.runtime();
30    let mut batch_rhs = 0u32.runtime();
31    for axis in 0..rank - 2 {
32        let tmp = batch_out / out.stride(axis);
33        batch_lhs += tmp % lhs.shape(axis) * lhs.stride(axis);
34        batch_rhs += tmp % rhs.shape(axis) * rhs.stride(axis);
35    }
36
37    GMM::execute(
38        GMM::init_lhs_loader(
39            lhs,
40            x_offset,
41            k_range.0,
42            nth_batch,
43            batch_lhs,
44            quantization,
45            config,
46        ),
47        GMM::init_rhs_loader(
48            rhs,
49            k_range.0,
50            y_offset,
51            nth_batch,
52            batch_rhs,
53            quantization,
54            config,
55        ),
56        GMM::init_unloader(out, x_offset, y_offset, nth_batch, batch_out),
57        acc,
58        k_range,
59        config,
60    );
61}
62
63#[cube]
64pub fn swizzle(nth: u32, height: u32, #[comptime] swizzle_width: u32) -> (u32, u32) {
65    let num_elem_per_swizzle_col = height * swizzle_width;
66
67    let swizzle_id = nth % num_elem_per_swizzle_col;
68    let swizzle_col = nth / num_elem_per_swizzle_col;
69
70    let col_within_swizzle = swizzle_id / height;
71    let col = swizzle_col * swizzle_width + col_within_swizzle;
72
73    let topdown_row = swizzle_id % height;
74    let is_bottom_up = swizzle_col % 2;
75
76    let row = topdown_row + is_bottom_up * (height - 2 * topdown_row - 1);
77
78    (row, col)
79}