// cubecl_linalg/matmul/kernels/naive.rs
use cubecl::prelude::*;
use cubecl_core as cubecl;

use crate::tensor::{MatrixBatchLayout, TensorHandle, into_contiguous, matrix_batch_layout};

use super::MatmulLaunchError;
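/// Naive matmul kernel: every unit computes one element of `out` by reducing
/// over the shared dimension `k` in `Line`s of `line_size` elements.
///
/// `lhs` is expected to be contiguous along its last dimension and `rhs` along
/// its second-to-last dimension (i.e. stored transposed), so both operands are
/// read with unit stride along `k`.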
#[cube(launch_unchecked)]
fn matmul_kernel<N: Numeric>(
    lhs: &Tensor<Line<N>>,
    rhs: &Tensor<Line<N>>,
    out: &mut Tensor<N>,
    #[comptime] num_batches: Option<u32>,
) {
    let rank = out.rank();
    let end = num_batches.unwrap_or_else(|| rank - 2);
    let unroll = num_batches.is_some();

    let n_rows = lhs.shape(rank - 2);
    let n_cols = rhs.shape(rank - 1);
    let mut k = rhs.shape(rank - 2);

    let batch_pos = ABSOLUTE_POS_Z;
    let row = CUBE_DIM_X * CUBE_POS_X + UNIT_POS_X;
    let col = CUBE_DIM_Y * CUBE_POS_Y + UNIT_POS_Y;

    if row >= n_rows || col >= n_cols {
        terminate!();
    }

    let line_size = lhs.line_size();

    let mut offset_lhs = 0;
    let mut offset_rhs = 0;
    let offset_out = batch_pos * out.stride(rank - 2) * out.shape(rank - 2);

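    // Translate the output batch offset into per-operand offsets, taking the
    // modulo with each operand's shape so that broadcast batch dimensions
    // (size 1) keep reading the same matrix.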
    #[unroll(unroll)]
    for i in 0..end {
        let ogwl = offset_out / out.stride(i);

        offset_lhs += ogwl % lhs.shape(i) * lhs.stride(i);
        offset_rhs += ogwl % rhs.shape(i) * rhs.stride(i);
    }

    offset_lhs /= line_size.runtime();
    offset_rhs /= line_size.runtime();

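    // Lane-wise accumulator; the loop below walks both operands along `k` in
    // lines of `line_size`, so `k` is first re-expressed as a number of lines.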
    let mut sum = Line::empty(line_size).fill(N::from_int(0));

    k /= line_size.runtime();

    for i in 0..k {
        let lhs_index = row * lhs.stride(rank - 2) / line_size + i + offset_lhs;
        let rhs_index = col * rhs.stride(rank - 1) / line_size + i + offset_rhs;

        sum += lhs[lhs_index] * rhs[rhs_index];
    }

    let mut out_index = row * out.stride(rank - 2) + col;
    out_index += offset_out;

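    // Reduce the `line_size` partial sums into a single scalar result; the
    // unrolled reduction is only needed when lines hold more than one element.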
    let unroll_sum = line_size != 1;
    if unroll_sum {
        let mut accum = N::from_int(0);
        #[unroll]
        for v in 0..line_size {
            accum += sum[v];
        }

        out[out_index] = accum;
    } else {
        out[out_index] = sum[0];
    }
}

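/// Launches the naive matmul kernel from raw tensor handles, converting them
/// into owned [`TensorHandle`]s before delegating to [`launch`].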
#[allow(clippy::result_large_err)]
pub fn launch_ref<R: Runtime, E: Numeric>(
    client: &ComputeClient<R::Server, R::Channel>,
    lhs: &TensorHandleRef<'_, R>,
    rhs: &TensorHandleRef<'_, R>,
    out: &TensorHandleRef<'_, R>,
) -> Result<(), MatmulLaunchError> {
    let lhs = TensorHandle::<R, E>::from_ref(lhs);
    let rhs = TensorHandle::<R, E>::from_ref(rhs);

    launch(client, lhs, rhs, out)
}

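/// Launches the naive matmul kernel.
///
/// `lhs` is made contiguous if needed, and `rhs` is rewritten into a layout
/// that is contiguous along its second-to-last dimension (transposed), which
/// is what [`matmul_kernel`] expects. Loads are vectorized by a factor of 4
/// when the shared dimension allows it.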
#[allow(clippy::result_large_err)]
pub fn launch<R: Runtime, E: Numeric>(
    client: &ComputeClient<R::Server, R::Channel>,
    lhs: TensorHandle<R, E>,
    rhs: TensorHandle<R, E>,
    out: &TensorHandleRef<'_, R>,
) -> Result<(), MatmulLaunchError> {
    let (cube_dim_x, cube_dim_y) = (32, 8);
    let ndims = lhs.shape.len();
    let dim1 = ndims - 1;
    let dim2 = ndims - 2;

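    // Inspect the memory layouts to decide whether the operands need to be
    // rewritten before launching.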
    let lhs_layout = matrix_batch_layout(&lhs.strides);
    let rhs_layout = matrix_batch_layout(&rhs.strides);

    let lhs = if !matches!(lhs_layout, MatrixBatchLayout::Contiguous) {
        into_contiguous::<R, E>(client, &lhs.as_ref())
    } else {
        lhs
    };

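    // The kernel reads `rhs` with unit stride along `k`, so `rhs` must be
    // contiguous in its second-to-last dimension. This closure swaps the last
    // two dimensions, makes the result contiguous, then swaps them back,
    // yielding that transposed layout along with the original shape.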
    let correct_rhs_layout = |mut rhs: TensorHandle<R, E>| {
        let rhs_original_shape = rhs.shape.to_vec();
        rhs.strides.swap(dim1, dim2);
        rhs.shape.swap(dim1, dim2);

        let mut rhs = into_contiguous::<R, E>(client, &rhs.as_ref());

        rhs.strides.swap(dim1, dim2);
        rhs.shape.swap(dim1, dim2);

        (rhs_original_shape, rhs)
    };

    let (rhs_original_shape, rhs) = match rhs_layout {
        MatrixBatchLayout::Contiguous => correct_rhs_layout(rhs),
        MatrixBatchLayout::MildlyPermuted {
            transposed,
            batch_swap,
        } => {
            if transposed && !batch_swap {
                let rhs_original_shape = rhs.shape.to_vec();
                (rhs_original_shape, rhs)
            } else {
                correct_rhs_layout(rhs)
            }
        }
        MatrixBatchLayout::HighlyPermuted => correct_rhs_layout(rhs),
    };

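    // One unit per output element: the x/y grid covers the output matrix and
    // the z dimension covers the batches.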
    let cube_count = simple_cube_count(
        &lhs.shape,
        &rhs_original_shape,
        out.shape,
        cube_dim_x,
        cube_dim_y,
    )?;

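    // Vectorize loads along `k` by 4 when the shared dimension is a multiple
    // of 4, otherwise fall back to scalar loads.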
    let vectorization_factor = match lhs.shape[ndims - 1] % 4 == 0 {
        true => 4,
        false => 1,
    };

    unsafe {
        matmul_kernel::launch_unchecked::<E, R>(
            client,
            cube_count,
            CubeDim::new(cube_dim_x as u32, cube_dim_y as u32, 1),
            lhs.as_arg(vectorization_factor),
            rhs.as_arg(vectorization_factor),
            out.as_tensor_arg(1),
            Some(ndims as u32 - 2),
        );
    };

    Ok(())
}

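/// Computes a static cube count that covers one unit per output element, with
/// one cube layer per batch, and errors when any grid dimension exceeds
/// `u16::MAX`.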
#[allow(clippy::result_large_err)]
fn simple_cube_count(
    lhs_shape: &[usize],
    rhs_shape: &[usize],
    output_shape: &[usize],
    cube_dim_x: usize,
    cube_dim_y: usize,
) -> Result<CubeCount, MatmulLaunchError> {
    let ndims = lhs_shape.len();
    let num_rows = lhs_shape[ndims - 2];
    let num_cols = rhs_shape[ndims - 1];

    let cubes_x = f32::ceil(num_rows as f32 / cube_dim_x as f32) as u32;
    let cubes_y = f32::ceil(num_cols as f32 / cube_dim_y as f32) as u32;
    let mut num_iter = 1u32;

    #[allow(clippy::needless_range_loop)]
    for i in 0..ndims - 2 {
        num_iter *= output_shape[i] as u32;
    }

    let result = CubeCount::Static(cubes_x, cubes_y, num_iter);
    let max_cube_count = u16::MAX as u32;

    if cubes_x > max_cube_count || cubes_y > max_cube_count || num_iter > max_cube_count {
        return Err(MatmulLaunchError::Unavailable(
            super::MatmulAvailabilityError::CubeCountTooBig(result),
        ));
    }

    Ok(result)
}