
burn_cubecl/kernel/matmul/base.rs

use super::init_matmul_output;
use crate::{CubeRuntime, kernel::quantization::dequantize, tensor::CubeTensor};
use burn_backend::{DType, QTensorPrimitive};
use burn_std::QuantLevel;
use cubek::matmul::{
    definition::{MatmulElems, MatmulGlobalElems, MatmulSetupError},
    launch::{MatmulInputHandleRef, Strategy},
};

#[cfg(feature = "autotune")]
use super::matmul_autotune;

/// The strategy to be used when launching a matmul kernel.
pub enum MatmulStrategy {
    #[cfg(feature = "autotune")]
    /// Using autotune to choose the best kernel based on runtime information.
    Autotune,
    /// Cube implementation of matmul.
    Cube,
}

impl Default for MatmulStrategy {
    fn default() -> Self {
        // if autotune is enabled, default to autotune
        #[cfg(feature = "autotune")]
        return MatmulStrategy::Autotune;

        #[cfg(not(feature = "autotune"))]
        MatmulStrategy::Cube
    }
}
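// Note: the default is resolved at compile time by the `autotune` feature flag; when the
// feature is enabled the early `return` above is taken and the `Cube` fallback is never
// compiled in. Callers can still bypass autotuning by passing `MatmulStrategy::Cube` to
// `matmul` explicitly.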

/// Launch a matmul kernel using the given strategy.
pub fn matmul<R: CubeRuntime>(
    lhs: CubeTensor<R>,
    rhs: CubeTensor<R>,
    out: Option<CubeTensor<R>>,
    strategy: MatmulStrategy,
    out_dtype: DType,
) -> Result<CubeTensor<R>, MatmulSetupError> {
    match strategy {
        MatmulStrategy::Cube => {
            let out = out.unwrap_or_else(|| init_matmul_output(&lhs, &rhs, out_dtype));
            launch_matmul(&Default::default(), lhs, rhs, out.clone())?;
            Ok(out)
        }
        #[cfg(feature = "autotune")]
        MatmulStrategy::Autotune => Ok(matmul_autotune(lhs, rhs, out, out_dtype)),
    }
}
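// Usage sketch (hypothetical call site, not part of this file): with `out` set to `None`,
// the `Cube` path allocates the result via `init_matmul_output` using `out_dtype`, while the
// feature-dependent default strategy decides between autotuning and the plain cube kernel.
// `lhs` and `rhs` are assumed to be existing `CubeTensor`s on the same runtime client.
//
//     let out = matmul(lhs, rhs, None, MatmulStrategy::default(), DType::F32)?;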

pub(crate) fn launch_matmul_naive<R: CubeRuntime>(
    strategy: &Strategy,
    mut lhs: CubeTensor<R>,
    mut rhs: CubeTensor<R>,
    out: CubeTensor<R>,
) -> Result<(), MatmulSetupError> {
    // The naive kernel has very specific layout requirements for block-scaled tensors, so if the
    // quantized launch fails we manually dequantize and retry. Naive is assumed to always work,
    // so it needs this fallback.
    if lhs.qparams.is_some() || rhs.qparams.is_some() {
        match launch_matmul(strategy, lhs.clone(), rhs.clone(), out.clone()) {
            Err(_) => {
                if lhs.qparams.is_some() {
                    lhs = dequantize(lhs, out.dtype);
                }
                if rhs.qparams.is_some() {
                    rhs = dequantize(rhs, out.dtype);
                }
                launch_matmul(strategy, lhs, rhs, out)
            }
            Ok(_) => Ok(()),
        }
    } else {
        launch_matmul(strategy, lhs, rhs, out)
    }
}
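// The fallback above trades quantized execution for robustness: if the quantized launch fails
// (for example because the naive kernel rejects a block-scale layout), the offending operand(s)
// are dequantized to `out.dtype` and the launch is retried on non-quantized handles.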

pub(crate) fn launch_matmul<R: CubeRuntime>(
    strategy: &Strategy,
    lhs: CubeTensor<R>,
    mut rhs: CubeTensor<R>,
    out: CubeTensor<R>,
) -> Result<(), MatmulSetupError> {
    let client = &lhs.client;

    let lhs_quant_handles = lhs.quantized_handles();
    let out_dtype: DType = out.dtype;

    let (lhs_dtype, lhs_handle) = match &lhs_quant_handles {
        None => (
            lhs.dtype,
            MatmulInputHandleRef::new(lhs.as_handle_ref(), lhs.dtype.into()),
        ),
        Some((data, scale)) => (
            out_dtype,
            MatmulInputHandleRef::quantized(
                data.as_handle_ref(),
                scale.as_handle_ref(),
                &lhs.shape.dims,
                lhs.scheme(),
                data.dtype.into(),
                scale.dtype.into(),
            ),
        ),
    };

    let rhs_quant_handles = rhs.quantized_handles();

    let (rhs_dtype, rhs_handle) = match &rhs_quant_handles {
        None => (
            rhs.dtype,
            MatmulInputHandleRef::new(rhs.as_handle_ref(), rhs.dtype.into()),
        ),
        Some((data, scale)) => {
            // Extremely hacky fix to ensure naive can run in every case: for a block-scaled
            // rhs, dequantize up front instead of handing the naive kernel a quantized handle.
            if matches!(strategy, Strategy::Naive)
                && matches!(rhs.scheme().level, QuantLevel::Block(_))
            {
                rhs = dequantize(rhs.clone(), lhs.dtype);
                (
                    lhs.dtype,
                    MatmulInputHandleRef::new(rhs.as_handle_ref(), rhs.dtype.into()),
                )
            } else {
                (
                    out_dtype,
                    MatmulInputHandleRef::quantized(
                        data.as_handle_ref(),
                        scale.as_handle_ref(),
                        &rhs.shape.dims,
                        rhs.scheme(),
                        data.dtype.into(),
                        scale.dtype.into(),
                    ),
                )
            }
        }
    };

    let mut dtypes = MatmulElems::from_globals(&MatmulGlobalElems {
        lhs: lhs_dtype.into(),
        rhs: rhs_dtype.into(),
        out: out_dtype.into(),
    });
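    // The kernel's element types are derived from the (possibly dequantized) input dtypes and
    // the output dtype; `dtypes` is passed mutably below, presumably so `launch_ref` can adjust
    // the selection for the chosen strategy.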
    cubek::matmul::launch::launch_ref(
        strategy,
        client,
        &lhs_handle,
        &rhs_handle,
        &out.as_handle_ref(),
        &mut dtypes,
    )?;

    Ok(())
}