cubecl_linalg/matmul/kernels/
naive.rs

1//! Naive matmul kernel implementation
2//!
3//! Each local unit will compute a single element of the output matrix.
4use cubecl::prelude::*;
5use cubecl_core as cubecl;
6
7use crate::tensor::{MatrixBatchLayout, TensorHandle, into_contiguous, matrix_batch_layout};
8
9use super::MatmulLaunchError;
10
/// Naive matmul kernel: each unit computes a single scalar element of `out`.
///
/// Operands are read as `Line<N>` values, so each inner-loop iteration
/// accumulates `line_size` partial products; the line of partial sums is
/// reduced to a scalar before the single write to `out`.
///
/// NOTE(review): the stride-1 line indexing along `k` assumes `lhs` is
/// contiguous and `rhs` is stored transposed (columns contiguous in memory),
/// as arranged by `launch` below — confirm if calling from elsewhere.
#[cube(launch_unchecked)]
fn matmul_kernel<N: Numeric>(
    lhs: &Tensor<Line<N>>,
    rhs: &Tensor<Line<N>>,
    out: &mut Tensor<N>,
    // number of dimensions not involved in the matmul (batch dims);
    // supplying it at comptime lets the batch-offset loop below be unrolled
    #[comptime] num_batches: Option<u32>,
) {
    let rank = out.rank();
    let end = num_batches.unwrap_or_else(|| rank - 2);
    let unroll = num_batches.is_some();

    let n_rows = lhs.shape(rank - 2);
    let n_cols = rhs.shape(rank - 1);
    // shared (reduction) dimension; divided by line_size later to count lines
    let mut k = rhs.shape(rank - 2);

    let batch_pos = ABSOLUTE_POS_Z;
    let row = CUBE_DIM_X * CUBE_POS_X + UNIT_POS_X;
    let col = CUBE_DIM_Y * CUBE_POS_Y + UNIT_POS_Y;

    // Guard against out-of-bounds units in partial cubes at the matrix edges.
    if row >= n_rows || col >= n_cols {
        terminate!();
    }

    let line_size = lhs.line_size();

    let mut offset_lhs = 0;
    let mut offset_rhs = 0;
    // Start of this batch's matrix in the output buffer.
    let offset_out = batch_pos * out.stride(rank - 2) * out.shape(rank - 2);

    // Project the output batch offset onto lhs/rhs batch dims; the
    // `% shape(i)` maps broadcasted dims (shape 1) back to index 0.
    #[unroll(unroll)]
    for i in 0..end {
        let ogwl = offset_out / out.stride(i);

        offset_lhs += ogwl % lhs.shape(i) * lhs.stride(i);
        offset_rhs += ogwl % rhs.shape(i) * rhs.stride(i);
    }

    // Convert element offsets into line offsets.
    offset_lhs /= line_size.runtime();
    offset_rhs /= line_size.runtime();

    let mut sum = Line::empty(line_size).fill(N::from_int(0));

    // Iterate over lines, not elements, along k.
    k /= line_size.runtime();

    for i in 0..k {
        let lhs_index = row * lhs.stride(rank - 2) / line_size + i + offset_lhs;
        let rhs_index = col * rhs.stride(rank - 1) / line_size + i + offset_rhs;

        sum += lhs[lhs_index] * rhs[rhs_index];
    }

    let mut out_index = row * out.stride(rank - 2) + col;
    out_index += offset_out;

    // Reduce the line of partial sums to the final scalar.
    let unroll_sum = line_size != 1;
    if unroll_sum {
        let mut accum = N::from_int(0);
        // we unroll the loop to sum `line_size` elements at once, which lets us
        // use SIMD instructions to speed up the computation
        #[unroll]
        for v in 0..line_size {
            accum += sum[v];
        }

        out[out_index] = accum;
    } else {
        out[out_index] = sum[0];
    }
}
81
82/// Matrix multiplication using memory coalescing algorithm with custom cube dimensions
83#[allow(clippy::result_large_err)]
84pub fn launch_ref<R: Runtime, E: Numeric>(
85    client: &ComputeClient<R::Server, R::Channel>,
86    lhs: &TensorHandleRef<'_, R>,
87    rhs: &TensorHandleRef<'_, R>,
88    out: &TensorHandleRef<'_, R>,
89) -> Result<(), MatmulLaunchError> {
90    let lhs = TensorHandle::<R, E>::from_ref(lhs);
91    let rhs = TensorHandle::<R, E>::from_ref(rhs);
92
93    launch(client, lhs, rhs, out)
94}
95
/// Launch the naive matmul kernel, normalizing operand layouts first.
///
/// `lhs` is forced contiguous; `rhs` is re-laid-out so its columns are
/// contiguous in memory (a transposed buffer), letting the kernel read both
/// operands along the shared `k` dimension with coalesced accesses. One unit
/// is dispatched per output element.
#[allow(clippy::result_large_err)]
pub fn launch<R: Runtime, E: Numeric>(
    client: &ComputeClient<R::Server, R::Channel>,
    lhs: TensorHandle<R, E>,
    rhs: TensorHandle<R, E>,
    out: &TensorHandleRef<'_, R>,
) -> Result<(), MatmulLaunchError> {
    // 32x8 units per cube; x covers rows, y covers columns.
    let (cube_dim_x, cube_dim_y) = (32, 8);
    let ndims = lhs.shape.len();
    let dim1 = ndims - 1;
    let dim2 = ndims - 2;

    let lhs_layout = matrix_batch_layout(&lhs.strides);
    let rhs_layout = matrix_batch_layout(&rhs.strides);

    // The kernel indexes lhs rows linearly, so any non-contiguous layout
    // must be materialized first.
    let lhs = if !matches!(lhs_layout, MatrixBatchLayout::Contiguous) {
        into_contiguous::<R, E>(client, &lhs.as_ref())
    } else {
        lhs
    };

    // we swap the dimensions to achieve memory-coalescing:
    // consecutive elements of a column in the original rhs tensor will now be stored
    // consecutively in memory, which allows to fetch them with fewer memory instructions
    let correct_rhs_layout = |mut rhs: TensorHandle<R, E>| {
        // Keep the pre-transpose shape: cube-count math needs the logical n_cols.
        let rhs_original_shape = rhs.shape.to_vec();
        rhs.strides.swap(dim1, dim2);
        rhs.shape.swap(dim1, dim2);

        // Materialize the transposed view into a contiguous buffer.
        let mut rhs = into_contiguous::<R, E>(client, &rhs.as_ref());

        // Swap back so the handle's logical shape matches the original tensor
        // while the underlying buffer stays column-contiguous.
        rhs.strides.swap(dim1, dim2);
        rhs.shape.swap(dim1, dim2);

        (rhs_original_shape, rhs)
    };

    let (rhs_original_shape, rhs) = match rhs_layout {
        MatrixBatchLayout::Contiguous => correct_rhs_layout(rhs),
        MatrixBatchLayout::MildlyPermuted {
            transposed,
            batch_swap,
        } => {
            if transposed && !batch_swap {
                // Already stored transposed with batches in place: exactly
                // the layout the kernel wants, nothing to do.
                let rhs_original_shape = rhs.shape.to_vec();
                (rhs_original_shape, rhs)
            } else {
                correct_rhs_layout(rhs)
            }
        }
        MatrixBatchLayout::HighlyPermuted => correct_rhs_layout(rhs),
    };

    let cube_count = simple_cube_count(
        &lhs.shape,
        &rhs_original_shape,
        out.shape,
        cube_dim_x,
        cube_dim_y,
    )?;

    // Read 4-wide lines along k when it divides evenly, otherwise scalars.
    let vectorization_factor = match lhs.shape[ndims - 1] % 4 == 0 {
        true => 4,
        false => 1,
    };

    // SAFETY: launch_unchecked skips bounds validation; the kernel guards
    // out-of-range units itself via `terminate!`.
    unsafe {
        matmul_kernel::launch_unchecked::<E, R>(
            client,
            cube_count,
            CubeDim::new(cube_dim_x as u32, cube_dim_y as u32, 1),
            lhs.as_arg(vectorization_factor),
            rhs.as_arg(vectorization_factor),
            out.as_tensor_arg(1),
            Some(ndims as u32 - 2),
        );
    };

    Ok(())
}
176
177#[allow(clippy::result_large_err)]
178fn simple_cube_count(
179    lhs_shape: &[usize],
180    rhs_shape: &[usize],
181    output_shape: &[usize],
182    cube_dim_x: usize,
183    cube_dim_y: usize,
184) -> Result<CubeCount, MatmulLaunchError> {
185    let ndims = lhs_shape.len();
186    let num_rows = lhs_shape[ndims - 2];
187    let num_cols = rhs_shape[ndims - 1];
188
189    let cubes_x = f32::ceil(num_rows as f32 / cube_dim_x as f32) as u32;
190    let cubes_y = f32::ceil(num_cols as f32 / cube_dim_y as f32) as u32;
191    let mut num_iter = 1u32;
192
193    #[allow(clippy::needless_range_loop)]
194    for i in 0..ndims - 2 {
195        num_iter *= output_shape[i] as u32;
196    }
197
198    let result = CubeCount::Static(cubes_x, cubes_y, num_iter);
199    let max_cube_count = u16::MAX as u32;
200
201    if cubes_x > max_cube_count || cubes_y > max_cube_count || num_iter > max_cube_count {
202        return Err(MatmulLaunchError::Unavailable(
203            super::MatmulAvailabilityError::CubeCountTooBig(result),
204        ));
205    }
206
207    Ok(result)
208}