cubecl_linalg/matmul/kernels/tiling2d/
launch.rs1use std::cmp::max;
2
3use cubecl_core::prelude::*;
4
5use crate::{
6 matmul::kernels::tiling2d::{
7 base::tiling2d_cube_kernel,
8 config::{CubeTiling2dConfig, tiling2d_cube_count, tiling2d_cube_dim},
9 },
10 tensor::{MatrixBatchLayout, TensorHandle, into_contiguous, matrix_batch_layout},
11};
12
13use super::config::Tiling2dConfig;
14
15pub fn matmul_tiling_2d<R: Runtime, F: Float>(
17 client: &ComputeClient<R::Server, R::Channel>,
18 lhs: TensorHandle<R, F>,
19 rhs: TensorHandle<R, F>,
20 out: TensorHandle<R, F>,
21 config: Tiling2dConfig,
22) -> TensorHandle<R, F> {
23 matmul_tiling_2d_ref::<R, F>(client, &lhs.as_ref(), &rhs.as_ref(), &out.as_ref(), config);
24
25 out
26}
27
28pub fn matmul_tiling_2d_ref<R: Runtime, EG: Numeric>(
30 client: &ComputeClient<R::Server, R::Channel>,
31 lhs: &TensorHandleRef<'_, R>,
32 rhs: &TensorHandleRef<'_, R>,
33 out: &TensorHandleRef<'_, R>,
34 config: Tiling2dConfig,
35) {
36 assert!(
37 EG::size().unwrap() * config.block_size_k * max(config.block_size_m, config.block_size_n)
38 <= client
39 .properties()
40 .hardware_properties()
41 .max_shared_memory_size,
42 "Shared memory limit will be busted. "
43 );
44 let check_layout = |tensor: &TensorHandleRef<'_, R>| match matrix_batch_layout(tensor.strides) {
45 MatrixBatchLayout::Contiguous => true,
46 MatrixBatchLayout::MildlyPermuted {
47 transposed: _,
48 batch_swap: _,
49 } => true,
50 MatrixBatchLayout::HighlyPermuted => false,
51 };
52 let lhs_correct_layout = check_layout(lhs);
53 let rhs_correct_layout = check_layout(rhs);
54
55 match (lhs_correct_layout, rhs_correct_layout) {
56 (true, true) => matmul_tiling_2d_ref_no_check::<R, EG>(client, lhs, rhs, out, config),
57 (true, false) => matmul_tiling_2d_ref_no_check::<R, EG>(
58 client,
59 lhs,
60 &into_contiguous::<R, EG>(client, rhs).as_ref(),
61 out,
62 config,
63 ),
64 (false, true) => matmul_tiling_2d_ref_no_check::<R, EG>(
65 client,
66 &into_contiguous::<R, EG>(client, lhs).as_ref(),
67 rhs,
68 out,
69 config,
70 ),
71 (false, false) => matmul_tiling_2d_ref_no_check::<R, EG>(
72 client,
73 &into_contiguous::<R, EG>(client, lhs).as_ref(),
74 &into_contiguous::<R, EG>(client, rhs).as_ref(),
75 out,
76 config,
77 ),
78 }
79}
80
81fn matmul_tiling_2d_ref_no_check<R: Runtime, N: Numeric>(
83 client: &ComputeClient<R::Server, R::Channel>,
84 lhs: &TensorHandleRef<'_, R>,
85 rhs: &TensorHandleRef<'_, R>,
86 out: &TensorHandleRef<'_, R>,
87 config: Tiling2dConfig,
88) {
89 let rank = lhs.strides.len();
90
91 let m = lhs.shape[rank - 2];
92 let k = lhs.shape[rank - 1];
93 let n = rhs.shape[rank - 1];
94
95 let check_layout = |strides: &[usize]| match matrix_batch_layout(strides) {
96 MatrixBatchLayout::Contiguous => false,
97 MatrixBatchLayout::MildlyPermuted {
98 transposed,
99 batch_swap: _,
100 } => transposed,
101 MatrixBatchLayout::HighlyPermuted => {
102 panic!("Can't run on highly permuted tensor")
103 }
104 };
105 let lhs_transposed = check_layout(lhs.strides);
106 let rhs_transposed = check_layout(rhs.strides);
107
108 let vectorization = |shape: usize| {
109 [4, 2]
110 .into_iter()
111 .filter(|v| shape % v == 0)
112 .map(|v| v as u8)
113 .next()
114 .unwrap_or(1)
115 };
116
117 let lhs_vectorization = match lhs_transposed {
118 true => vectorization(m),
119 false => 1,
120 };
121 let rhs_vectorization = match rhs_transposed {
122 true => 1,
123 false => vectorization(n),
124 };
125 let out_vectorization = vectorization(n);
126
127 let cube_count = tiling2d_cube_count(out.shape, &config);
128 let cube_dim = tiling2d_cube_dim(&config);
129 let cube_config = CubeTiling2dConfig::new(&config, m, k, n, lhs_transposed, rhs_transposed);
130
131 unsafe {
132 tiling2d_cube_kernel::launch_unchecked::<N, R>(
133 client,
134 cube_count,
135 cube_dim,
136 TensorArg::from_raw_parts::<N>(lhs.handle, lhs.strides, lhs.shape, lhs_vectorization),
137 TensorArg::from_raw_parts::<N>(rhs.handle, rhs.strides, rhs.shape, rhs_vectorization),
138 TensorArg::from_raw_parts::<N>(out.handle, out.strides, out.shape, out_vectorization),
139 cube_config,
140 );
141 }
142}