cubecl_linalg/matmul/kernels/tiling2d/launch.rs

1use std::cmp::max;
2
3use cubecl_core::prelude::*;
4
5use crate::{
6    matmul::kernels::tiling2d::{
7        base::tiling2d_cube_kernel,
8        config::{CubeTiling2dConfig, tiling2d_cube_count, tiling2d_cube_dim},
9    },
10    tensor::{MatrixBatchLayout, TensorHandle, into_contiguous, matrix_batch_layout},
11};
12
13use super::config::Tiling2dConfig;
14
15/// Matrix multiplication using tiling 2d algorithm.
16pub fn matmul_tiling_2d<R: Runtime, F: Float>(
17    client: &ComputeClient<R::Server, R::Channel>,
18    lhs: TensorHandle<R, F>,
19    rhs: TensorHandle<R, F>,
20    out: TensorHandle<R, F>,
21    config: Tiling2dConfig,
22) -> TensorHandle<R, F> {
23    matmul_tiling_2d_ref::<R, F>(client, &lhs.as_ref(), &rhs.as_ref(), &out.as_ref(), config);
24
25    out
26}
27
28/// Matrix multiplication using tiling 2d algorithm.
29pub fn matmul_tiling_2d_ref<R: Runtime, EG: Numeric>(
30    client: &ComputeClient<R::Server, R::Channel>,
31    lhs: &TensorHandleRef<'_, R>,
32    rhs: &TensorHandleRef<'_, R>,
33    out: &TensorHandleRef<'_, R>,
34    config: Tiling2dConfig,
35) {
36    assert!(
37        EG::size().unwrap() * config.block_size_k * max(config.block_size_m, config.block_size_n)
38            <= client
39                .properties()
40                .hardware_properties()
41                .max_shared_memory_size,
42        "Shared memory limit will be busted. "
43    );
44    let check_layout = |tensor: &TensorHandleRef<'_, R>| match matrix_batch_layout(tensor.strides) {
45        MatrixBatchLayout::Contiguous => true,
46        MatrixBatchLayout::MildlyPermuted {
47            transposed: _,
48            batch_swap: _,
49        } => true,
50        MatrixBatchLayout::HighlyPermuted => false,
51    };
52    let lhs_correct_layout = check_layout(lhs);
53    let rhs_correct_layout = check_layout(rhs);
54
55    match (lhs_correct_layout, rhs_correct_layout) {
56        (true, true) => matmul_tiling_2d_ref_no_check::<R, EG>(client, lhs, rhs, out, config),
57        (true, false) => matmul_tiling_2d_ref_no_check::<R, EG>(
58            client,
59            lhs,
60            &into_contiguous::<R, EG>(client, rhs).as_ref(),
61            out,
62            config,
63        ),
64        (false, true) => matmul_tiling_2d_ref_no_check::<R, EG>(
65            client,
66            &into_contiguous::<R, EG>(client, lhs).as_ref(),
67            rhs,
68            out,
69            config,
70        ),
71        (false, false) => matmul_tiling_2d_ref_no_check::<R, EG>(
72            client,
73            &into_contiguous::<R, EG>(client, lhs).as_ref(),
74            &into_contiguous::<R, EG>(client, rhs).as_ref(),
75            out,
76            config,
77        ),
78    }
79}
80
81/// Matrix multiplication using tiling 2d algorithm.
82fn matmul_tiling_2d_ref_no_check<R: Runtime, N: Numeric>(
83    client: &ComputeClient<R::Server, R::Channel>,
84    lhs: &TensorHandleRef<'_, R>,
85    rhs: &TensorHandleRef<'_, R>,
86    out: &TensorHandleRef<'_, R>,
87    config: Tiling2dConfig,
88) {
89    let rank = lhs.strides.len();
90
91    let m = lhs.shape[rank - 2];
92    let k = lhs.shape[rank - 1];
93    let n = rhs.shape[rank - 1];
94
95    let check_layout = |strides: &[usize]| match matrix_batch_layout(strides) {
96        MatrixBatchLayout::Contiguous => false,
97        MatrixBatchLayout::MildlyPermuted {
98            transposed,
99            batch_swap: _,
100        } => transposed,
101        MatrixBatchLayout::HighlyPermuted => {
102            panic!("Can't run on highly permuted tensor")
103        }
104    };
105    let lhs_transposed = check_layout(lhs.strides);
106    let rhs_transposed = check_layout(rhs.strides);
107
108    let vectorization = |shape: usize| {
109        [4, 2]
110            .into_iter()
111            .filter(|v| shape % v == 0)
112            .map(|v| v as u8)
113            .next()
114            .unwrap_or(1)
115    };
116
117    let lhs_vectorization = match lhs_transposed {
118        true => vectorization(m),
119        false => 1,
120    };
121    let rhs_vectorization = match rhs_transposed {
122        true => 1,
123        false => vectorization(n),
124    };
125    let out_vectorization = vectorization(n);
126
127    let cube_count = tiling2d_cube_count(out.shape, &config);
128    let cube_dim = tiling2d_cube_dim(&config);
129    let cube_config = CubeTiling2dConfig::new(&config, m, k, n, lhs_transposed, rhs_transposed);
130
131    unsafe {
132        tiling2d_cube_kernel::launch_unchecked::<N, R>(
133            client,
134            cube_count,
135            cube_dim,
136            TensorArg::from_raw_parts::<N>(lhs.handle, lhs.strides, lhs.shape, lhs_vectorization),
137            TensorArg::from_raw_parts::<N>(rhs.handle, rhs.strides, rhs.shape, rhs_vectorization),
138            TensorArg::from_raw_parts::<N>(out.handle, out.strides, out.shape, out_vectorization),
139            cube_config,
140        );
141    }
142}