burn_jit/kernel/matmul/base.rs

1use cubecl::linalg::matmul::kernels::MatmulLaunchError;
2
3use super::init_matmul_output;
4use crate::{tensor::JitTensor, FloatElement, JitRuntime};
5
6#[cfg(feature = "autotune")]
7use super::matmul_autotune;
8
/// The strategy to be used when launching a matmul kernel.
///
/// This is a plain fieldless value: it can be copied, compared and printed with `{:?}`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MatmulStrategy {
    /// Using autotune to choose the best kernel based on runtime information.
    ///
    /// Only available when the `autotune` feature is enabled.
    #[cfg(feature = "autotune")]
    Autotune,
    /// Cube implementation of matmul.
    Cube,
}
17
18impl Default for MatmulStrategy {
19    fn default() -> Self {
20        // if autotune is enabled, default to autotune
21        #[cfg(feature = "autotune")]
22        return MatmulStrategy::Autotune;
23
24        #[cfg(not(feature = "autotune"))]
25        MatmulStrategy::Cube
26    }
27}
28
29/// Launch a matmul kernel using the given strategy.
30pub fn matmul<R: JitRuntime, E: FloatElement>(
31    lhs: JitTensor<R>,
32    rhs: JitTensor<R>,
33    out: Option<JitTensor<R>>,
34    strategy: MatmulStrategy,
35) -> Result<JitTensor<R>, MatmulLaunchError> {
36    match strategy {
37        MatmulStrategy::Cube => {
38            let out = out.unwrap_or_else(|| init_matmul_output::<R, E>(&lhs, &rhs));
39
40            let client = &lhs.client;
41
42            cubecl::linalg::matmul::launch_ref::<R, E>(
43                &Default::default(),
44                client,
45                &lhs.as_handle_ref(),
46                &rhs.as_handle_ref(),
47                &out.as_handle_ref(),
48            )?;
49
50            Ok(out)
51        }
52        #[cfg(feature = "autotune")]
53        MatmulStrategy::Autotune => Ok(matmul_autotune::<R, E>(lhs, rhs, out)),
54    }
55}