// cubecl_matmul/kernels/layered/base.rs

1use crate::components::global::args::TensorArgs;
2use crate::components::{
3    AccG, AccS,
4    batch::{BatchMatmulFamily, CubeCountInputArgs},
5};
6use crate::components::{
7    AvailableLineSizes, InputRuntimeArg, LhsG, LhsS, MatmulAvailabilityError, MatmulLineSizes,
8    MatmulPrecision, MatmulProblem, MatmulSelection, MatmulSetupError, MatmulSpec, MatrixLayout,
9    OutputRuntimeArg, RhsG, RhsS,
10};
11use crate::components::{global::args::TensorMapArgs, tile::TileMatmulFamily};
12use crate::kernels::layered::selector::launch_kernel_concrete;
13use crate::{MatmulInputHandle, MatmulInputHandleRef};
14use core::any::TypeId;
15use cubecl_core::{Runtime, client::ComputeClient, frontend::TensorHandleRef};
16use cubecl_core::{prelude::*, try_tensor_line_size_parallel};
17use cubecl_runtime::TypeUsage;
18use cubecl_std::tensor::{MatrixBatchLayout, TensorHandle, matrix_batch_layout};
19
20use super::Algorithm;
21
22#[derive(Debug, Clone)]
23pub enum Selection<S> {
24    /// Use a predefined MatmulSelection
25    Forced(MatmulSelection),
26    /// Allows to give limited MatmulSelection information, and the rest is inferred from it
27    Inferred(S),
28}
29
30impl<S: Default + Clone> Selection<S> {
31    pub fn maybe_forced_default(s: &Option<MatmulSelection>) -> Self {
32        s.as_ref()
33            .map(|s| Self::Forced(s.clone()))
34            .unwrap_or_default()
35    }
36    pub fn maybe_forced_or(s: &Option<MatmulSelection>, args: &S) -> Self {
37        s.as_ref()
38            .map(|s| Self::Forced(s.clone()))
39            .unwrap_or_else(|| Self::Inferred(args.clone()))
40    }
41}
42
43impl<S: Default> Default for Selection<S> {
44    fn default() -> Self {
45        Self::Inferred(Default::default())
46    }
47}
48
49/// Launch a matrix multiplication kernel.
50///
51/// Cmma will be used if enabled
52/// Will fail if unavailable
53#[allow(clippy::result_large_err)]
54pub fn launch<R: Runtime, MP: MatmulPrecision, A: Algorithm>(
55    client: &ComputeClient<R::Server>,
56    lhs: MatmulInputHandle<R, LhsG<MP>>,
57    rhs: MatmulInputHandle<R, RhsG<MP>>,
58    out: TensorHandle<R, AccG<MP>>,
59    selection: &Selection<A::SelectionArgs>,
60) -> Result<TensorHandle<R, AccG<MP>>, MatmulSetupError> {
61    let result = launch_ref::<R, MP, A>(
62        client,
63        &lhs.as_ref(),
64        &rhs.as_ref(),
65        &out.as_ref(),
66        selection,
67    );
68
69    match result {
70        Ok(_) => Ok(out),
71        Err(e) => Err(e),
72    }
73}
74
75/// Launch a matrix multiplication kernel.
76///
77/// Cmma will be used if available and enabled,
78/// otherwise it will fall back on a non-cmma implementation
79#[allow(clippy::result_large_err)]
80pub fn launch_ref<R: Runtime, MP: MatmulPrecision, A: Algorithm>(
81    client: &ComputeClient<R::Server>,
82    lhs: &MatmulInputHandleRef<'_, R>,
83    rhs: &MatmulInputHandleRef<'_, R>,
84    out: &TensorHandleRef<'_, R>,
85    selection: &Selection<A::SelectionArgs>,
86) -> Result<(), MatmulSetupError> {
87    let check_layout = |tensor: &TensorHandleRef<'_, R>| match matrix_batch_layout(tensor.strides) {
88        MatrixBatchLayout::Contiguous => (false, false),
89        MatrixBatchLayout::MildlyPermuted {
90            transposed,
91            batch_swap: _,
92        } => (false, transposed),
93        MatrixBatchLayout::HighlyPermuted => (true, false),
94    };
95
96    let (lhs_make_contiguous, lhs_transposed) = check_layout(lhs.data());
97    let (rhs_make_contiguous, rhs_transposed) = check_layout(rhs.data());
98
99    let lhs_owned;
100    let rhs_owned;
101    let lhs = if lhs_make_contiguous {
102        lhs_owned = lhs.into_contiguous::<LhsG<MP>>(client);
103        &lhs_owned.as_ref()
104    } else {
105        lhs
106    };
107    let rhs = if rhs_make_contiguous {
108        rhs_owned = rhs.into_contiguous::<RhsG<MP>>(client);
109        &rhs_owned.as_ref()
110    } else {
111        rhs
112    };
113
114    launch_inner_ref::<R, MP, A>(
115        client,
116        lhs,
117        rhs,
118        out,
119        (lhs_transposed, rhs_transposed),
120        selection,
121    )
122}
123
124#[allow(clippy::result_large_err, clippy::too_many_arguments)]
125fn launch_inner_ref<R: Runtime, MP: MatmulPrecision, A: Algorithm>(
126    client: &ComputeClient<R::Server>,
127    lhs_handle: &MatmulInputHandleRef<'_, R>,
128    rhs_handle: &MatmulInputHandleRef<'_, R>,
129    out: &TensorHandleRef<'_, R>,
130    transposed: (bool, bool),
131    selection: &Selection<A::SelectionArgs>,
132) -> Result<(), MatmulSetupError> {
133    let lhs_shape = lhs_handle.shape();
134    let rhs_shape = rhs_handle.shape();
135
136    let rank = lhs_shape.len();
137    let lhs_elem = LhsG::<MP>::as_type_native().expect("To be a native type");
138    let rhs_elem = RhsG::<MP>::as_type_native().expect("To be a native type");
139    let acc_elem = AccG::<MP>::as_type_native().expect("To be a native type");
140
141    if !LhsG::<MP>::supported_uses(client).contains(TypeUsage::Conversion)
142        || !RhsG::<MP>::supported_uses(client).contains(TypeUsage::Conversion)
143        || !AccG::<MP>::supported_uses(client).contains(TypeUsage::Conversion)
144    {
145        return Err(MatmulSetupError::Unavailable(
146            MatmulAvailabilityError::TypesUnavailable {
147                lhs: lhs_elem,
148                rhs: rhs_elem,
149                output: acc_elem,
150            },
151        ));
152    }
153
154    let m = lhs_shape[rank - 2] as u32;
155    let k = lhs_shape[rank - 1] as u32;
156    let n = rhs_shape[rank - 1] as u32;
157
158    let lhs_layout = match transposed.0 {
159        true => MatrixLayout::ColMajor,
160        false => MatrixLayout::RowMajor,
161    };
162
163    let rhs_layout = match transposed.1 {
164        true => MatrixLayout::ColMajor,
165        false => MatrixLayout::RowMajor,
166    };
167
168    let problem = MatmulProblem {
169        m: m as usize,
170        n: n as usize,
171        k: k as usize,
172        lhs_batches: lhs_shape[..lhs_shape.len() - 2].to_vec(),
173        rhs_batches: rhs_shape[..rhs_shape.len() - 2].to_vec(),
174        out_batches: out.shape[..out.shape.len() - 2].to_vec(),
175        lhs_layout,
176        rhs_layout,
177    };
178
179    let lhs = lhs_handle.data();
180    let rhs = rhs_handle.data();
181
182    let line_sizes =
183        AvailableLineSizes::from_type_sizes::<R>(lhs.elem_size, rhs.elem_size, out.elem_size);
184    let line_sizes = A::filter_line_sizes(line_sizes);
185    let mut line_sizes = line_sizes
186        .filter_lhs_with_tensor(lhs.strides, lhs.shape, problem.lhs_layout)
187        .filter_rhs_with_tensor(rhs.strides, rhs.shape, problem.rhs_layout)
188        .filter_out_with_tensor(out.strides, out.shape)
189        .pick_max()?;
190
191    // The large line size resulting from dequantizing ends up slower due to restrictions on
192    // algorithms. Use this as a quick and dirty fix.
193    if lhs_handle.scale().is_some() {
194        line_sizes.lhs = 1;
195    }
196    if rhs_handle.scale().is_some() {
197        line_sizes.rhs = 1;
198    }
199
200    let fix_plane_dim = |plane_dim: u32| {
201        // Sometimes the GPU doesn't support plane instructions and doesn't report the
202        // plane size, but we can still execute algorithms that don't use plane instructions.
203        //
204        // In this case, we set a plane size for the selector to work, defaulting to 32 as it
205        // is a common plane size.
206        if plane_dim == 0 { 32 } else { plane_dim }
207    };
208
209    let plane_dim = fix_plane_dim(A::select_plane_dim::<R>(client));
210
211    launch_inner_ref_fix_dtype::<R, MP, A>(
212        client, lhs_handle, rhs_handle, out, problem, line_sizes, plane_dim, selection,
213    )
214}
215
216#[allow(clippy::result_large_err, clippy::too_many_arguments)]
217fn launch_inner_ref_fix_dtype<R: Runtime, MP: MatmulPrecision, A: Algorithm>(
218    client: &ComputeClient<R::Server>,
219    lhs: &MatmulInputHandleRef<'_, R>,
220    rhs: &MatmulInputHandleRef<'_, R>,
221    out: &TensorHandleRef<'_, R>,
222    problem: MatmulProblem,
223    line_sizes: MatmulLineSizes,
224    plane_dim: u32,
225    selection: &Selection<A::SelectionArgs>,
226) -> Result<(), MatmulSetupError> {
227    if <A::TileMatmul as TileMatmulFamily>::requires_accelerator()
228        && tf32::supported_uses(client).contains(TypeUsage::Conversion)
229    {
230        match (
231            TypeId::of::<LhsG<MP>>() == TypeId::of::<f32>(),
232            TypeId::of::<RhsG<MP>>() == TypeId::of::<f32>(),
233        ) {
234            (true, true) => launch_kernel_concrete::<
235                ((f32, f32, AccG<MP>, tf32, tf32, AccS<MP>), TensorArgs),
236                R,
237                A,
238            >(
239                client, lhs, rhs, out, problem, line_sizes, plane_dim, selection,
240            ),
241            (true, false) => launch_kernel_concrete::<
242                (
243                    (f32, RhsG<MP>, AccG<MP>, tf32, RhsS<MP>, AccS<MP>),
244                    TensorArgs,
245                ),
246                R,
247                A,
248            >(
249                client, lhs, rhs, out, problem, line_sizes, plane_dim, selection,
250            ),
251            (false, true) => launch_kernel_concrete::<
252                (
253                    (LhsG<MP>, f32, AccG<MP>, LhsS<MP>, tf32, AccS<MP>),
254                    TensorArgs,
255                ),
256                R,
257                A,
258            >(
259                client, lhs, rhs, out, problem, line_sizes, plane_dim, selection,
260            ),
261            (false, false) => launch_kernel_concrete::<(MP, TensorArgs), R, A>(
262                client, lhs, rhs, out, problem, line_sizes, plane_dim, selection,
263            ),
264        }
265    } else {
266        launch_kernel_concrete::<(MP, TensorArgs), R, A>(
267            client, lhs, rhs, out, problem, line_sizes, plane_dim, selection,
268        )
269    }
270}
271
272#[allow(clippy::result_large_err, clippy::too_many_arguments)]
273pub fn matmul_cmma_tma_ref_no_check<R: Runtime, MP: MatmulPrecision, A: Algorithm>(
274    client: &ComputeClient<R::Server>,
275    lhs_handle: &MatmulInputHandleRef<'_, R>,
276    rhs_handle: &MatmulInputHandleRef<'_, R>,
277    out: &TensorHandleRef<'_, R>,
278    transposed: (bool, bool),
279    selection: &Selection<A::SelectionArgs>,
280) -> Result<(), MatmulSetupError> {
281    let lhs = lhs_handle.data();
282    let rhs = rhs_handle.data();
283
284    let rank = lhs.strides.len();
285
286    let m = lhs.shape[rank - 2] as u32;
287    let k = lhs.shape[rank - 1] as u32;
288    let n = rhs.shape[rank - 1] as u32;
289
290    let lhs_layout = match transposed.0 {
291        true => MatrixLayout::ColMajor,
292        false => MatrixLayout::RowMajor,
293    };
294    let rhs_layout = match transposed.1 {
295        true => MatrixLayout::ColMajor,
296        false => MatrixLayout::RowMajor,
297    };
298
299    let line_sizes = MatmulLineSizes {
300        lhs: 1,
301        rhs: 1,
302        out: try_tensor_line_size_parallel(
303            R::io_optimized_line_sizes_unchecked(out.elem_size),
304            out.shape,
305            out.strides,
306            rank - 1,
307        )?,
308    };
309
310    let batch_lhs: usize = lhs.shape[..lhs.shape.len() - 2].iter().product();
311    let batch_rhs: usize = rhs.shape[..rhs.shape.len() - 2].iter().product();
312    let batch_out: usize = out.shape[..out.shape.len() - 2].iter().product();
313
314    let problem = MatmulProblem {
315        m: m as usize,
316        n: n as usize,
317        k: k as usize,
318        lhs_batches: [batch_lhs].to_vec(),
319        rhs_batches: [batch_rhs].to_vec(),
320        out_batches: [batch_out].to_vec(),
321        lhs_layout,
322        rhs_layout,
323    };
324
325    let plane_size = client.properties().hardware.plane_size_max;
326
327    let plane_dim = match plane_size {
328        32 | 64 => plane_size,
329        _ => {
330            return Err(MatmulSetupError::Unavailable(
331                MatmulAvailabilityError::PlaneDimUnsupported {
332                    plane_dim: plane_size,
333                },
334            ));
335        }
336    };
337
338    if tf32::supported_uses(client).contains(TypeUsage::Conversion) {
339        match (
340            TypeId::of::<LhsG<MP>>() == TypeId::of::<f32>(),
341            TypeId::of::<RhsG<MP>>() == TypeId::of::<f32>(),
342        ) {
343            (true, true) => launch_kernel_concrete::<
344                ((f32, f32, AccG<MP>, tf32, tf32, AccS<MP>), TensorMapArgs),
345                R,
346                A,
347            >(
348                client, lhs_handle, rhs_handle, out, problem, line_sizes, plane_dim, selection,
349            ),
350            (true, false) => launch_kernel_concrete::<
351                (
352                    (f32, RhsG<MP>, AccG<MP>, tf32, RhsS<MP>, AccS<MP>),
353                    TensorMapArgs,
354                ),
355                R,
356                A,
357            >(
358                client, lhs_handle, rhs_handle, out, problem, line_sizes, plane_dim, selection,
359            ),
360            (false, true) => launch_kernel_concrete::<
361                (
362                    (LhsG<MP>, f32, AccG<MP>, LhsS<MP>, tf32, AccS<MP>),
363                    TensorMapArgs,
364                ),
365                R,
366                A,
367            >(
368                client, lhs_handle, rhs_handle, out, problem, line_sizes, plane_dim, selection,
369            ),
370            (false, false) => launch_kernel_concrete::<(MP, TensorMapArgs), R, A>(
371                client, lhs_handle, rhs_handle, out, problem, line_sizes, plane_dim, selection,
372            ),
373        }
374    } else {
375        launch_kernel_concrete::<(MP, TensorMapArgs), R, A>(
376            client, lhs_handle, rhs_handle, out, problem, line_sizes, plane_dim, selection,
377        )
378    }
379}
380
381#[allow(clippy::too_many_arguments, clippy::result_large_err)]
382pub fn launch_with_config<'a, MS: MatmulSpec, R: Runtime, A: Algorithm>(
383    client: &ComputeClient<R::Server>,
384    cube_dim: CubeDim,
385    cube_count: CubeCount,
386    input: InputRuntimeArg<'a, MS, R>,
387    output: OutputRuntimeArg<'a, MS, R>,
388    cube_count_input: CubeCountInputArgs<'a, R>,
389    config: <A::BatchMatmul as BatchMatmulFamily>::Config,
390) -> Result<(), MatmulSetupError> {
391    unsafe {
392        A::BatchMatmul::launch_unchecked::<MS, R>(
393            client,
394            cube_dim,
395            cube_count,
396            input,
397            output,
398            cube_count_input,
399            config,
400        );
401    };
402
403    Ok(())
404}