use cubecl::prelude::*;
use cubecl_core::{
    self as cubecl,
    ir::{ElemType, IntKind, UIntKind},
    tensor_line_size_parallel,
};

use cubecl_std::tensor::{
    MatrixBatchLayout, View, launch::ViewArg, layout::Coords3d, matrix_batch_layout,
};

use crate::{
    MatmulInputHandle, MatmulInputHandleRef,
    components::{
        MatmulAvailabilityError, MatmulProblem, MatmulSetupError, MatrixLayout,
        global::memory::{
            BatchedGlobalLayout, BatchedGlobalLayoutLaunch, BatchedGlobalScaleLayout,
            GlobalLayoutConfig,
        },
    },
};

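/// Loads a `Line` of `line_size` values starting at `pos`, gathering from a
/// view whose native line size may be smaller. When the sizes match this is a
/// single vectorized load; otherwise it steps along the contiguous axis
/// (columns for row-major, rows for column-major) and concatenates the
/// smaller lines into one.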
#[cube]
fn load_unrolled<I: Numeric>(
    view: &View<Line<I>, Coords3d>,
    pos: Coords3d,
    #[comptime] layout: MatrixLayout,
    #[comptime] line_size: u32,
) -> Line<I> {
    comptime![assert!(line_size >= view.line_size())];
    let view_line_size = view.line_size();
    if comptime![view.line_size() == line_size] {
        view[pos]
    } else {
        let (b, row, col) = pos;
        let mut out = Line::empty(line_size);
        #[unroll]
        for i in range_stepped(0, line_size, view_line_size) {
            let pos = match layout {
                MatrixLayout::RowMajor => (b, row, col + i),
                MatrixLayout::ColMajor => (b, row + i, col),
            };
            let value = view[pos];
            #[unroll]
            for n in 0..view_line_size {
                out[i + n] = value[n];
            }
        }
        out
    }
}

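/// Naive matmul kernel: each unit computes a single element of `out`, walking
/// the `k` dimension in steps of `line_size`. `lhs` is read row-major and
/// `rhs` column-major so both loads hit contiguous memory. Products are formed
/// in the intermediate precision `M` before being accumulated into `O`.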
#[cube(launch_unchecked)]
fn matmul_kernel<I: Numeric, M: Numeric, O: Numeric>(
    lhs: &View<Line<I>, Coords3d>,
    rhs: &View<Line<I>, Coords3d>,
    out: &mut Tensor<O>,
) {
    let rank = out.rank();

    let (_, _, k) = lhs.shape();
    let size_m = out.shape(rank - 2);
    let size_n = out.shape(rank - 1);

    let batch = ABSOLUTE_POS_Z;
    let m = ABSOLUTE_POS_X;
    let n = ABSOLUTE_POS_Y;

    if m >= size_m || n >= size_n {
        terminate!();
    }

    let offset_out = batch * out.stride(rank - 2) * out.shape(rank - 2);

    let line_size = comptime![Ord::max(lhs.line_size(), rhs.line_size())];
    let mut sum = Line::empty(line_size).fill(O::from_int(0));

    for k in range_stepped(0, k, line_size) {
        let lhs = load_unrolled(lhs, (batch, m, k), MatrixLayout::RowMajor, line_size);
        let rhs = load_unrolled(rhs, (batch, k, n), MatrixLayout::ColMajor, line_size);

        // Multiply in the intermediate precision `M`, then cast the partial
        // products back to the output precision for accumulation.
        sum += Line::cast_from(Line::<M>::cast_from(lhs) * Line::<M>::cast_from(rhs));
    }

    let mut out_index = m * out.stride(rank - 2) + n;
    out_index += offset_out;

    let unroll_sum = line_size != 1;
    if unroll_sum {
        let mut accum = O::from_int(0);
        #[unroll]
        for v in 0..line_size {
            accum += sum[v];
        }

        out[out_index] = accum;
    } else {
        out[out_index] = sum[0];
    }
}

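/// Launches a simple batched matmul (one output element per unit) on owned
/// input handles, reading `lhs` row-major and `rhs` column-major.
///
/// A hedged usage sketch; `MyRuntime`, `client`, and the handles are
/// placeholders for values built elsewhere:
///
/// ```ignore
/// launch::<MyRuntime, f32, f32>(&client, lhs_handle, rhs_handle, &out_handle.as_ref())?;
/// ```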
#[allow(clippy::result_large_err)]
pub fn launch<R: Runtime, EI: Numeric, EO: Numeric>(
    client: &ComputeClient<R::Server>,
    lhs: MatmulInputHandle<R, EI>,
    rhs: MatmulInputHandle<R, EI>,
    out: &TensorHandleRef<'_, R>,
) -> Result<(), MatmulSetupError> {
    launch_ref::<R, EI, EO>(client, &lhs.as_ref(), &rhs.as_ref(), out)
}

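/// Like [`launch`], but on borrowed input handles. Normalizes layouts first:
/// `lhs` is made contiguous (row-major) and `rhs` is made column-major over
/// its last two dimensions, copying only when the existing strides don't
/// already match.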
#[allow(clippy::result_large_err)]
pub fn launch_ref<R: Runtime, EI: Numeric, EO: Numeric>(
    client: &ComputeClient<R::Server>,
    lhs: &MatmulInputHandleRef<'_, R>,
    rhs: &MatmulInputHandleRef<'_, R>,
    out: &TensorHandleRef<'_, R>,
) -> Result<(), MatmulSetupError> {
    let (cube_dim_x, cube_dim_y) = (32, 8);
    let rank = lhs.shape().len();
    let dim1 = rank - 1;
    let dim2 = rank - 2;

    let lhs_layout = matrix_batch_layout(lhs.data().strides);
    let rhs_layout = matrix_batch_layout(rhs.data().strides);

    // The kernel reads `lhs` row-major, so make it contiguous if it isn't.
    let lhs = if !matches!(lhs_layout, MatrixBatchLayout::Contiguous) {
        lhs.into_contiguous::<EI>(client)
    } else {
        MatmulInputHandle::from_ref(lhs)
    };
    let lhs = lhs.as_ref();
    let rhs = MatmulInputHandle::from_ref(rhs);

    // The kernel reads `rhs` column-major: transpose, make contiguous, then
    // transpose back so the last two strides describe a column-major matrix.
    let correct_rhs_layout = |mut rhs: MatmulInputHandle<R, EI>| {
        rhs.swap_dims(dim1, dim2);

        let mut rhs = rhs.as_ref().into_contiguous::<EI>(client);

        rhs.swap_dims(dim1, dim2);
        rhs
    };

    let rhs = match rhs_layout {
        MatrixBatchLayout::Contiguous => correct_rhs_layout(rhs),
        MatrixBatchLayout::MildlyPermuted {
            transposed,
            batch_swap,
        } => {
            if transposed && !batch_swap {
                // Already column-major in the last two dims; use as-is.
                rhs
            } else {
                correct_rhs_layout(rhs)
            }
        }
        MatrixBatchLayout::HighlyPermuted => correct_rhs_layout(rhs),
    };
    let rhs = rhs.as_ref();

    let lhs_shape = lhs.shape();
    let rhs_shape = rhs.shape();
    let out_shape = out.shape;

    let cube_count = simple_cube_count(lhs_shape, rhs_shape, out_shape, cube_dim_x, cube_dim_y)?;

    let elem = EI::as_type_native_unchecked();
    let lhs_line_size = tensor_line_size_parallel(
        R::io_optimized_line_sizes(&elem),
        lhs.data().shape,
        lhs.data().strides,
        rank - 1,
    );
    let rhs_line_size = tensor_line_size_parallel(
        R::io_optimized_line_sizes(&elem),
        rhs.data().shape,
        rhs.data().strides,
        rank - 2,
    );

    let problem = MatmulProblem {
        m: out_shape[rank - 2],
        n: out_shape[rank - 1],
        k: lhs_shape[rank - 1],
        lhs_batches: lhs_shape[..rank - 2].to_vec(),
        rhs_batches: rhs_shape[..rank - 2].to_vec(),
        out_batches: out_shape[..rank - 2].to_vec(),
        lhs_layout: MatrixLayout::RowMajor,
        rhs_layout: MatrixLayout::ColMajor,
    };

    // Widen the intermediate precision for small integer inputs to reduce the
    // risk of overflow in the inner product.
    let launch = match EI::as_type_native_unchecked().elem_type() {
        ElemType::Int(IntKind::I8) => matmul_kernel::launch_unchecked::<EI, i16, EO, R>,
        ElemType::Int(IntKind::I16) | ElemType::UInt(UIntKind::U16) => {
            matmul_kernel::launch_unchecked::<EI, i32, EO, R>
        }
        ElemType::UInt(UIntKind::U8) => matmul_kernel::launch_unchecked::<EI, u16, EO, R>,
        _ => matmul_kernel::launch_unchecked::<EI, EI, EO, R>,
    };

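    /// Builds a `(batch, row, col)` view over an input handle with the
    /// batched global memory layout; quantized inputs wrap the data and
    /// scale views together.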
    fn view<'a, R: Runtime>(
        client: &ComputeClient<R::Server>,
        handle: &'a MatmulInputHandleRef<'a, R>,
        layout: MatrixLayout,
        line_size: u8,
        problem: &MatmulProblem,
    ) -> ViewArg<'a, Coords3d, R> {
        let config = GlobalLayoutConfig {
            matrix_layout: layout,
            ..Default::default()
        };
        match handle {
            MatmulInputHandleRef::Normal(handle) => {
                let layout = BatchedGlobalLayoutLaunch::from_handle(
                    client, handle, problem, line_size, config,
                );
                ViewArg::new::<BatchedGlobalLayout>(handle.as_array_arg(line_size), layout)
            }
            MatmulInputHandleRef::Quantized {
                data,
                scale,
                shape,
                scheme,
            } => {
                let (data_layout, scales_layout) = BatchedGlobalLayoutLaunch::from_quantized_handle(
                    client, data, scale, shape, problem, **scheme, line_size, config,
                );
                let data_view =
                    ViewArg::new::<BatchedGlobalLayout>(data.as_array_arg(line_size), data_layout);
                let scales_view =
                    ViewArg::new::<BatchedGlobalScaleLayout>(scale.as_array_arg(1), scales_layout);
                ViewArg::new_quantized(data_view, scales_view, **scheme)
            }
        }
    }

    let lhs_view = view(
        client,
        &lhs,
        MatrixLayout::RowMajor,
        lhs_line_size,
        &problem,
    );
    let rhs_view = view(
        client,
        &rhs,
        MatrixLayout::ColMajor,
        rhs_line_size,
        &problem,
    );

    // SAFETY: out-of-range `m`/`n` positions terminate early in the kernel,
    // and the cube count was validated by `simple_cube_count`.
    unsafe {
        launch(
            client,
            cube_count,
            CubeDim::new(cube_dim_x as u32, cube_dim_y as u32, 1),
            lhs_view,
            rhs_view,
            out.as_tensor_arg(1),
        );
    }

    Ok(())
}

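/// Computes the cube count: `x` tiles the rows of `lhs`, `y` the columns of
/// `rhs`, and `z` covers the flattened batch dimensions of the output. Errors
/// when any axis exceeds the `u16::MAX` dispatch limit.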
#[allow(clippy::result_large_err)]
fn simple_cube_count(
    lhs_shape: &[usize],
    rhs_shape: &[usize],
    output_shape: &[usize],
    cube_dim_x: usize,
    cube_dim_y: usize,
) -> Result<CubeCount, MatmulSetupError> {
    let ndims = lhs_shape.len();
    let num_rows = lhs_shape[ndims - 2];
    let num_cols = rhs_shape[ndims - 1];

    let cubes_x = f32::ceil(num_rows as f32 / cube_dim_x as f32) as u32;
    let cubes_y = f32::ceil(num_cols as f32 / cube_dim_y as f32) as u32;
    let mut num_iter = 1u32;

    // One cube on `z` per (flattened) batch of the output.
    #[allow(clippy::needless_range_loop)]
    for i in 0..ndims - 2 {
        num_iter *= output_shape[i] as u32;
    }

    let result = CubeCount::Static(cubes_x, cubes_y, num_iter);
    let max_cube_count = u16::MAX as u32;

    if cubes_x > max_cube_count || cubes_y > max_cube_count || num_iter > max_cube_count {
        return Err(MatmulSetupError::Unavailable(
            MatmulAvailabilityError::CubeCountTooBig(result),
        ));
    }

    Ok(result)
}