//! Matmul launch entry points.
//! Source path: `burn_cubecl/kernel/matmul/base.rs`.

1use super::init_matmul_output;
2use crate::{CubeRuntime, kernel::quantization::dequantize, tensor::CubeTensor};
3use burn_backend::{DType, QTensorPrimitive};
4use burn_std::QuantLevel;
5use cubek::{
6    matmul::{
7        definition::{MatmulElems, MatmulGlobalElems, MatmulSetupError},
8        launch::Strategy,
9    },
10    std::InputBinding,
11};
12
13#[cfg(feature = "autotune")]
14use super::matmul_autotune;
15
/// The strategy to be used when launching a matmul kernel.
pub enum MatmulStrategy {
    /// Using autotune to choose the best kernel based on runtime information.
    #[cfg(feature = "autotune")]
    Autotune,
    /// Cube implementation of matmul.
    Cube,
}
24
25impl Default for MatmulStrategy {
26    fn default() -> Self {
27        // if autotune is enabled, default to autotune
28        #[cfg(feature = "autotune")]
29        return MatmulStrategy::Autotune;
30
31        #[cfg(not(feature = "autotune"))]
32        MatmulStrategy::Cube
33    }
34}
35
36/// Launch a matmul kernel using the given strategy.
37pub fn matmul<R: CubeRuntime>(
38    lhs: CubeTensor<R>,
39    rhs: CubeTensor<R>,
40    out: Option<CubeTensor<R>>,
41    strategy: MatmulStrategy,
42    out_dtype: DType,
43) -> Result<CubeTensor<R>, MatmulSetupError> {
44    match strategy {
45        MatmulStrategy::Cube => {
46            let out = out.unwrap_or_else(|| init_matmul_output(&lhs, &rhs, out_dtype));
47            launch_matmul(&Default::default(), lhs, rhs, out.clone())?;
48            Ok(out)
49        }
50        #[cfg(feature = "autotune")]
51        MatmulStrategy::Autotune => Ok(matmul_autotune(lhs, rhs, out, out_dtype)),
52    }
53}
54
55pub(crate) fn launch_matmul_naive<R: CubeRuntime>(
56    strategy: &Strategy,
57    mut lhs: CubeTensor<R>,
58    mut rhs: CubeTensor<R>,
59    out: CubeTensor<R>,
60) -> Result<(), MatmulSetupError> {
61    // Naive has very specific layout requirements for block scaled tensors, so we need to manually
62    // dequantize if it fails to launch normally. This is because naive is assumed to always work.
63    if lhs.qparams.is_some() || rhs.qparams.is_some() {
64        match launch_matmul(strategy, lhs.clone(), rhs.clone(), out.clone()) {
65            Err(_) => {
66                if lhs.qparams.is_some() {
67                    lhs = dequantize(lhs, out.dtype);
68                }
69                if rhs.qparams.is_some() {
70                    rhs = dequantize(rhs, out.dtype);
71                }
72                launch_matmul(strategy, lhs, rhs, out)
73            }
74            Ok(_) => Ok(()),
75        }
76    } else {
77        launch_matmul(strategy, lhs, rhs, out)
78    }
79}
80
/// Launch a matmul through `cubek::matmul::launch::launch_ref`, constructing a
/// plain or quantized `InputBinding` for each operand.
///
/// # Errors
/// Propagates `MatmulSetupError` when the launcher cannot set up the selected
/// strategy for the given inputs.
pub(crate) fn launch_matmul<R: CubeRuntime>(
    strategy: &Strategy,
    lhs: CubeTensor<R>,
    // `mut`: rhs may be replaced by a dequantized copy in the naive
    // block-quantized workaround below.
    mut rhs: CubeTensor<R>,
    out: CubeTensor<R>,
) -> Result<(), MatmulSetupError> {
    let client = &out.client;

    let lhs_quant_handles = lhs.quantized_handles();
    let out_dtype: DType = out.dtype;

    // Build the lhs binding. For a quantized operand, the dtype reported
    // alongside the binding is the output dtype rather than the stored data
    // dtype — presumably the precision values are dequantized into for the
    // compute; TODO confirm against the cubek matmul contract.
    let (lhs_dtype, lhs_handle) = match lhs_quant_handles {
        None => {
            let lhs_dtype = lhs.dtype;
            (
                lhs_dtype,
                InputBinding::new(lhs.binding(), lhs_dtype.into()),
            )
        }
        Some((data, scale)) => {
            let scheme = *lhs.scheme();
            let data_dtype = data.dtype;
            let scale_dtype = scale.dtype;
            (
                out_dtype,
                InputBinding::quantized(
                    data.binding(),
                    scale.binding(),
                    lhs.meta.shape().clone(),
                    scheme,
                    data_dtype.into(),
                    scale_dtype.into(),
                ),
            )
        }
    };

    let rhs_quant_handles = rhs.quantized_handles();

    let (rhs_dtype, rhs_handle) = match rhs_quant_handles {
        // NOTE(review): a plain rhs is bound with `lhs_dtype`, not `rhs.dtype`.
        // This assumes both operands share a dtype (or that lhs's resolved
        // dtype drives the kernel element type) — confirm with callers.
        None => (
            lhs_dtype,
            InputBinding::new(rhs.binding(), lhs_dtype.into()),
        ),
        Some((data, scale)) => {
            // Extremely hacky fix to ensure naive can run in every case
            if matches!(strategy, Strategy::Naive)
                && matches!(rhs.scheme().level, QuantLevel::Block(_))
            {
                // Block-scaled rhs under the naive strategy: dequantize to
                // lhs's dtype and bind it as a plain (non-quantized) input.
                rhs = dequantize(rhs.clone(), lhs_dtype);
                let rhs_dtype = rhs.dtype;
                (
                    lhs_dtype,
                    InputBinding::new(rhs.binding(), rhs_dtype.into()),
                )
            } else {
                let scheme = *rhs.scheme();
                let data_dtype = data.dtype;
                let scale_dtype = scale.dtype;
                (
                    out_dtype,
                    InputBinding::quantized(
                        data.binding(),
                        scale.binding(),
                        rhs.meta.shape().clone(),
                        scheme,
                        data_dtype.into(),
                        scale_dtype.into(),
                    ),
                )
            }
        }
    };

    // Derive the kernel element types from the per-operand global dtypes.
    // Passed mutably because `launch_ref` takes `&mut` — it may adjust the
    // element selection during setup (unverified from this file).
    let mut dtypes = MatmulElems::from_globals(&MatmulGlobalElems {
        lhs: lhs_dtype.into(),
        rhs: rhs_dtype.into(),
        out: out_dtype.into(),
    });

    cubek::matmul::launch::launch_ref(
        strategy,
        client,
        lhs_handle,
        rhs_handle,
        out.clone().binding(),
        &mut dtypes,
    )?;

    Ok(())
}