burn_cubecl/kernel/matmul/
base.rs1use super::init_matmul_output;
2use crate::{CubeRuntime, kernel::quantization::dequantize, tensor::CubeTensor};
3use burn_backend::{DType, QTensorPrimitive};
4use burn_std::QuantLevel;
5use cubek::{
6 matmul::{
7 definition::{MatmulElems, MatmulGlobalElems, MatmulSetupError},
8 launch::Strategy,
9 },
10 std::InputBinding,
11};
12
13#[cfg(feature = "autotune")]
14use super::matmul_autotune;
15
/// The strategy used to launch a matmul kernel.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MatmulStrategy {
    /// Let autotuning select the kernel (only available with the `autotune` feature).
    #[cfg(feature = "autotune")]
    Autotune,
    /// Launch the cube matmul kernel with the default launch `Strategy`.
    Cube,
}
24
25impl Default for MatmulStrategy {
26 fn default() -> Self {
27 #[cfg(feature = "autotune")]
29 return MatmulStrategy::Autotune;
30
31 #[cfg(not(feature = "autotune"))]
32 MatmulStrategy::Cube
33 }
34}
35
36pub fn matmul<R: CubeRuntime>(
38 lhs: CubeTensor<R>,
39 rhs: CubeTensor<R>,
40 out: Option<CubeTensor<R>>,
41 strategy: MatmulStrategy,
42 out_dtype: DType,
43) -> Result<CubeTensor<R>, MatmulSetupError> {
44 match strategy {
45 MatmulStrategy::Cube => {
46 let out = out.unwrap_or_else(|| init_matmul_output(&lhs, &rhs, out_dtype));
47 launch_matmul(&Default::default(), lhs, rhs, out.clone())?;
48 Ok(out)
49 }
50 #[cfg(feature = "autotune")]
51 MatmulStrategy::Autotune => Ok(matmul_autotune(lhs, rhs, out, out_dtype)),
52 }
53}
54
55pub(crate) fn launch_matmul_naive<R: CubeRuntime>(
56 strategy: &Strategy,
57 mut lhs: CubeTensor<R>,
58 mut rhs: CubeTensor<R>,
59 out: CubeTensor<R>,
60) -> Result<(), MatmulSetupError> {
61 if lhs.qparams.is_some() || rhs.qparams.is_some() {
64 match launch_matmul(strategy, lhs.clone(), rhs.clone(), out.clone()) {
65 Err(_) => {
66 if lhs.qparams.is_some() {
67 lhs = dequantize(lhs, out.dtype);
68 }
69 if rhs.qparams.is_some() {
70 rhs = dequantize(rhs, out.dtype);
71 }
72 launch_matmul(strategy, lhs, rhs, out)
73 }
74 Ok(_) => Ok(()),
75 }
76 } else {
77 launch_matmul(strategy, lhs, rhs, out)
78 }
79}
80
81pub(crate) fn launch_matmul<R: CubeRuntime>(
82 strategy: &Strategy,
83 lhs: CubeTensor<R>,
84 mut rhs: CubeTensor<R>,
85 out: CubeTensor<R>,
86) -> Result<(), MatmulSetupError> {
87 let client = &out.client;
88
89 let lhs_quant_handles = lhs.quantized_handles();
90 let out_dtype: DType = out.dtype;
91
92 let (lhs_dtype, lhs_handle) = match lhs_quant_handles {
93 None => {
94 let lhs_dtype = lhs.dtype;
95 (
96 lhs_dtype,
97 InputBinding::new(lhs.binding(), lhs_dtype.into()),
98 )
99 }
100 Some((data, scale)) => {
101 let scheme = *lhs.scheme();
102 let data_dtype = data.dtype;
103 let scale_dtype = scale.dtype;
104 (
105 out_dtype,
106 InputBinding::quantized(
107 data.binding(),
108 scale.binding(),
109 lhs.meta.shape().clone(),
110 scheme,
111 data_dtype.into(),
112 scale_dtype.into(),
113 ),
114 )
115 }
116 };
117
118 let rhs_quant_handles = rhs.quantized_handles();
119
120 let (rhs_dtype, rhs_handle) = match rhs_quant_handles {
121 None => (
122 lhs_dtype,
123 InputBinding::new(rhs.binding(), lhs_dtype.into()),
124 ),
125 Some((data, scale)) => {
126 if matches!(strategy, Strategy::Naive)
128 && matches!(rhs.scheme().level, QuantLevel::Block(_))
129 {
130 rhs = dequantize(rhs.clone(), lhs_dtype);
131 let rhs_dtype = rhs.dtype;
132 (
133 lhs_dtype,
134 InputBinding::new(rhs.binding(), rhs_dtype.into()),
135 )
136 } else {
137 let scheme = *rhs.scheme();
138 let data_dtype = data.dtype;
139 let scale_dtype = scale.dtype;
140 (
141 out_dtype,
142 InputBinding::quantized(
143 data.binding(),
144 scale.binding(),
145 rhs.meta.shape().clone(),
146 scheme,
147 data_dtype.into(),
148 scale_dtype.into(),
149 ),
150 )
151 }
152 }
153 };
154
155 let mut dtypes = MatmulElems::from_globals(&MatmulGlobalElems {
156 lhs: lhs_dtype.into(),
157 rhs: rhs_dtype.into(),
158 out: out_dtype.into(),
159 });
160
161 cubek::matmul::launch::launch_ref(
162 strategy,
163 client,
164 lhs_handle,
165 rhs_handle,
166 out.clone().binding(),
167 &mut dtypes,
168 )?;
169
170 Ok(())
171}