cubecl_matmul/kernels/
naive.rs

1//! Naive matmul kernel implementation
2//!
3//! Each local unit will compute a single element of the output matrix.
4use cubecl::prelude::*;
5use cubecl_core::{
6    self as cubecl,
7    ir::{ElemType, IntKind, UIntKind},
8    tensor_line_size_parallel,
9};
10
11use cubecl_std::tensor::{
12    MatrixBatchLayout, View, launch::ViewArg, layout::Coords3d, matrix_batch_layout,
13};
14
15use crate::{
16    MatmulInputHandle, MatmulInputHandleRef,
17    components::{
18        MatmulAvailabilityError, MatmulProblem, MatmulSetupError, MatrixLayout,
19        global::memory::{
20            BatchedGlobalLayout, BatchedGlobalLayoutLaunch, BatchedGlobalScaleLayout,
21            GlobalLayoutConfig,
22        },
23    },
24};
25
26#[cube]
27fn load_unrolled<I: Numeric>(
28    view: &View<Line<I>, Coords3d>,
29    pos: Coords3d,
30    #[comptime] layout: MatrixLayout,
31    #[comptime] line_size: u32,
32) -> Line<I> {
33    comptime![assert!(line_size <= view.line_size())];
34    let view_line_size = view.line_size();
35    if comptime![view.line_size() == line_size] {
36        view[pos]
37    } else {
38        let (b, row, col) = pos;
39        let mut out = Line::empty(line_size);
40        #[unroll]
41        for i in range_stepped(0, line_size, view_line_size) {
42            let pos = match layout {
43                MatrixLayout::RowMajor => (b, row, col + i),
44                MatrixLayout::ColMajor => (b, row + i, col),
45            };
46            let value = view[pos];
47            #[unroll]
48            for n in 0..view_line_size {
49                out[i + n] = value[n];
50            }
51        }
52        out
53    }
54}
55
/// Naive matmul kernel: each execution unit computes a single element of the
/// output matrix by reducing over the shared `k` dimension in `line_size`
/// chunks.
///
/// `I` is the storage element type of the inputs, `M` the type the
/// element-wise multiply is performed in (a wider integer for small int
/// inputs — see the dispatch in `launch_ref`), and `O` the output type.
#[cube(launch_unchecked)]
fn matmul_kernel<I: Numeric, M: Numeric, O: Numeric>(
    lhs: &View<Line<I>, Coords3d>,
    rhs: &View<Line<I>, Coords3d>,
    out: &mut Tensor<O>,
) {
    let rank = out.rank();

    let (_, _, k) = lhs.shape();
    let size_m = out.shape(rank - 2);
    let size_n = out.shape(rank - 1);

    // One unit per output element: x -> row, y -> column, z -> batch.
    let batch = ABSOLUTE_POS_Z;
    let m = ABSOLUTE_POS_X;
    let n = ABSOLUTE_POS_Y;

    // Units past the matrix edge (cube count is rounded up) exit early.
    if m >= size_m || n >= size_n {
        terminate!();
    }

    // Start of this batch's output matrix; assumes batches are laid out
    // contiguously, i.e. stride(rank-2) * shape(rank-2) elements apart.
    let offset_out = batch * out.stride(rank - 2) * out.shape(rank - 2);

    // Accumulate `line_size` partial products per iteration; views with a
    // smaller native line size are widened by `load_unrolled`.
    let line_size = comptime![Ord::max(lhs.line_size(), rhs.line_size())];
    let mut sum = Line::empty(line_size).fill(O::from_int(0));

    for k in range_stepped(0, k, line_size) {
        // lhs is read along its row (row-major) and rhs along its column
        // (col-major), so both loads walk their contiguous axis.
        let lhs = load_unrolled(lhs, (batch, m, k), MatrixLayout::RowMajor, line_size);
        let rhs = load_unrolled(rhs, (batch, k, n), MatrixLayout::ColMajor, line_size);

        // Multiply in `M`, then cast and accumulate in `O`.
        sum += Line::cast_from(Line::<M>::cast_from(lhs) * Line::<M>::cast_from(rhs));
    }

    let mut out_index = m * out.stride(rank - 2) + n;
    out_index += offset_out;

    // Horizontal reduction of the partial-sum line into a single scalar.
    let unroll_sum = line_size != 1;
    if unroll_sum {
        let mut accum = O::from_int(0);
        // we unroll the loop to sum `vectorization_factor` elements at once, which lets us
        // use SIMD instructions to speed up the computation
        #[unroll]
        for v in 0..line_size {
            accum += sum[v];
        }

        out[out_index] = accum;
    } else {
        out[out_index] = sum[0];
    }
}
106
107/// Matrix multiplication using memory coalescing algorithm with custom cube dimensions
108#[allow(clippy::result_large_err)]
109pub fn launch<R: Runtime, EI: Numeric, EO: Numeric>(
110    client: &ComputeClient<R::Server>,
111    lhs: MatmulInputHandle<R, EI>,
112    rhs: MatmulInputHandle<R, EI>,
113    out: &TensorHandleRef<'_, R>,
114) -> Result<(), MatmulSetupError> {
115    launch_ref::<R, EI, EO>(client, &lhs.as_ref(), &rhs.as_ref(), out)
116}
117
/// Launches the naive matmul kernel on borrowed input handles.
///
/// Prepares the operands so the kernel can read both along a contiguous
/// axis: `lhs` is made contiguous (row-major) if permuted, and `rhs` is
/// re-materialized so its columns are contiguous in memory (col-major).
/// Then builds the views, picks line sizes, and dispatches the kernel.
#[allow(clippy::result_large_err)]
pub fn launch_ref<R: Runtime, EI: Numeric, EO: Numeric>(
    client: &ComputeClient<R::Server>,
    lhs: &MatmulInputHandleRef<'_, R>,
    rhs: &MatmulInputHandleRef<'_, R>,
    out: &TensorHandleRef<'_, R>,
) -> Result<(), MatmulSetupError> {
    // Units per cube along (rows, cols) of the output.
    let (cube_dim_x, cube_dim_y) = (32, 8);
    let rank = lhs.shape().len();
    // Last two dims are the matrix dims; everything before is batch.
    let dim1 = rank - 1;
    let dim2 = rank - 2;

    let lhs_layout = matrix_batch_layout(lhs.data().strides);
    let rhs_layout = matrix_batch_layout(rhs.data().strides);

    // lhs must be contiguous (row-major) for the kernel's row reads.
    let lhs = if !matches!(lhs_layout, MatrixBatchLayout::Contiguous) {
        lhs.into_contiguous::<EI>(client)
    } else {
        MatmulInputHandle::from_ref(lhs)
    };
    let lhs = lhs.as_ref();
    let rhs = MatmulInputHandle::from_ref(rhs);

    // we swap the dimensions to achieve memory-coalescing:
    // consecutive elements of a column in the original rhs tensor will now be stored
    // consecutively in memory, which allows to fetch them with fewer memory instructions
    let correct_rhs_layout = |mut rhs: MatmulInputHandle<R, EI>| {
        rhs.swap_dims(dim1, dim2);

        // Materialize the transposed view, then swap the dims back so the
        // logical shape is unchanged while the data is now column-major.
        let mut rhs = rhs.as_ref().into_contiguous::<EI>(client);

        rhs.swap_dims(dim1, dim2);
        rhs
    };

    // Only an already matrix-transposed (and not batch-swapped) rhs can be
    // used as-is; every other layout is re-materialized as col-major.
    let rhs = match rhs_layout {
        MatrixBatchLayout::Contiguous => correct_rhs_layout(rhs),
        MatrixBatchLayout::MildlyPermuted {
            transposed,
            batch_swap,
        } => {
            if transposed && !batch_swap {
                rhs
            } else {
                correct_rhs_layout(rhs)
            }
        }
        MatrixBatchLayout::HighlyPermuted => correct_rhs_layout(rhs),
    };
    let rhs = rhs.as_ref();

    let lhs_shape = lhs.shape();
    let rhs_shape = rhs.shape();
    let out_shape = out.shape;

    let cube_count = simple_cube_count(lhs_shape, rhs_shape, out_shape, cube_dim_x, cube_dim_y)?;

    // Pick the widest supported line size along each operand's contiguous
    // axis: last dim (k) for row-major lhs, second-to-last (k) for
    // col-major rhs.
    let elem = EI::as_type_native_unchecked();
    let lhs_line_size = tensor_line_size_parallel(
        R::io_optimized_line_sizes(&elem),
        lhs.data().shape,
        lhs.data().strides,
        rank - 1,
    );
    let rhs_line_size = tensor_line_size_parallel(
        R::io_optimized_line_sizes(&elem),
        rhs.data().shape,
        rhs.data().strides,
        rank - 2,
    );

    let problem = MatmulProblem {
        m: out_shape[rank - 2],
        n: out_shape[rank - 1],
        k: lhs_shape[rank - 1],
        lhs_batches: lhs_shape[..rank - 2].to_vec(),
        rhs_batches: rhs_shape[..rank - 2].to_vec(),
        out_batches: out_shape[..rank - 2].to_vec(),
        lhs_layout: MatrixLayout::RowMajor,
        rhs_layout: MatrixLayout::ColMajor,
    };

    // Small integer inputs multiply in a widened type (the kernel's `M`
    // parameter) so the element-wise products don't overflow the input type.
    let launch = match EI::as_type_native_unchecked().elem_type() {
        ElemType::Int(IntKind::I8) => matmul_kernel::launch_unchecked::<EI, i16, EO, R>,
        ElemType::Int(IntKind::I16) | ElemType::UInt(UIntKind::U16) => {
            matmul_kernel::launch_unchecked::<EI, i32, EO, R>
        }
        ElemType::UInt(UIntKind::U8) => matmul_kernel::launch_unchecked::<EI, u16, EO, R>,
        _ => matmul_kernel::launch_unchecked::<EI, EI, EO, R>,
    };

    // Builds the 3d view argument for one operand, handling both plain and
    // quantized (data + scales) handles.
    fn view<'a, R: Runtime>(
        client: &ComputeClient<R::Server>,
        handle: &'a MatmulInputHandleRef<'a, R>,
        layout: MatrixLayout,
        line_size: u8,
        problem: &MatmulProblem,
    ) -> ViewArg<'a, Coords3d, R> {
        // Checks off, other properties are unused
        let config = GlobalLayoutConfig {
            matrix_layout: layout,
            ..Default::default()
        };
        match handle {
            MatmulInputHandleRef::Normal(handle) => {
                let layout = BatchedGlobalLayoutLaunch::from_handle(
                    client, handle, problem, line_size, config,
                );
                ViewArg::new::<BatchedGlobalLayout>(handle.as_array_arg(line_size), layout)
            }
            MatmulInputHandleRef::Quantized {
                data,
                scale,
                shape,
                scheme,
            } => {
                // Scales use line size 1; data uses the operand's line size.
                let (data_layout, scales_layout) = BatchedGlobalLayoutLaunch::from_quantized_handle(
                    client, data, scale, shape, problem, **scheme, line_size, config,
                );
                let data_view =
                    ViewArg::new::<BatchedGlobalLayout>(data.as_array_arg(line_size), data_layout);
                let scales_view =
                    ViewArg::new::<BatchedGlobalScaleLayout>(scale.as_array_arg(1), scales_layout);
                ViewArg::new_quantized(data_view, scales_view, **scheme)
            }
        }
    }

    let lhs_view = view(
        client,
        &lhs,
        MatrixLayout::RowMajor,
        lhs_line_size,
        &problem,
    );
    let rhs_view = view(
        client,
        &rhs,
        MatrixLayout::ColMajor,
        rhs_line_size,
        &problem,
    );

    // SAFETY: launch_unchecked skips bounds validation; the kernel itself
    // terminates units that fall outside the output matrix.
    unsafe {
        launch(
            client,
            cube_count,
            CubeDim::new(cube_dim_x as u32, cube_dim_y as u32, 1),
            lhs_view,
            rhs_view,
            out.as_tensor_arg(1),
        );
    };

    Ok(())
}
274
275#[allow(clippy::result_large_err)]
276fn simple_cube_count(
277    lhs_shape: &[usize],
278    rhs_shape: &[usize],
279    output_shape: &[usize],
280    cube_dim_x: usize,
281    cube_dim_y: usize,
282) -> Result<CubeCount, MatmulSetupError> {
283    let ndims = lhs_shape.len();
284    let num_rows = lhs_shape[ndims - 2];
285    let num_cols = rhs_shape[ndims - 1];
286
287    let cubes_x = f32::ceil(num_rows as f32 / cube_dim_x as f32) as u32;
288    let cubes_y = f32::ceil(num_cols as f32 / cube_dim_y as f32) as u32;
289    let mut num_iter = 1u32;
290
291    #[allow(clippy::needless_range_loop)]
292    for i in 0..ndims - 2 {
293        num_iter *= output_shape[i] as u32;
294    }
295
296    let result = CubeCount::Static(cubes_x, cubes_y, num_iter);
297    let max_cube_count = u16::MAX as u32;
298
299    if cubes_x > max_cube_count || cubes_y > max_cube_count || num_iter > max_cube_count {
300        return Err(MatmulSetupError::Unavailable(
301            MatmulAvailabilityError::CubeCountTooBig(result),
302        ));
303    }
304
305    Ok(result)
306}