burn_cubecl/kernel/matmul/
base.rs1use super::init_matmul_output;
2use crate::{CubeRuntime, kernel::quantization::dequantize, tensor::CubeTensor};
3use burn_backend::{DType, QTensorPrimitive};
4use burn_std::QuantLevel;
5use cubek::matmul::{
6 definition::{MatmulElems, MatmulGlobalElems, MatmulSetupError},
7 launch::{MatmulInputHandleRef, Strategy},
8};
9
10#[cfg(feature = "autotune")]
11use super::matmul_autotune;
12
/// Selects how [`matmul`] dispatches a matrix multiplication.
pub enum MatmulStrategy {
    /// Dispatch through `matmul_autotune`. Only compiled in when the `autotune` feature is
    /// enabled.
    #[cfg(feature = "autotune")]
    Autotune,
    /// Launch directly through `launch_matmul` with the default launch `Strategy`.
    Cube,
}
21
22impl Default for MatmulStrategy {
23 fn default() -> Self {
24 #[cfg(feature = "autotune")]
26 return MatmulStrategy::Autotune;
27
28 #[cfg(not(feature = "autotune"))]
29 MatmulStrategy::Cube
30 }
31}
32
33pub fn matmul<R: CubeRuntime>(
35 lhs: CubeTensor<R>,
36 rhs: CubeTensor<R>,
37 out: Option<CubeTensor<R>>,
38 strategy: MatmulStrategy,
39 out_dtype: DType,
40) -> Result<CubeTensor<R>, MatmulSetupError> {
41 match strategy {
42 MatmulStrategy::Cube => {
43 let out = out.unwrap_or_else(|| init_matmul_output(&lhs, &rhs, out_dtype));
44 launch_matmul(&Default::default(), lhs, rhs, out.clone())?;
45 Ok(out)
46 }
47 #[cfg(feature = "autotune")]
48 MatmulStrategy::Autotune => Ok(matmul_autotune(lhs, rhs, out, out_dtype)),
49 }
50}
51
52pub(crate) fn launch_matmul_naive<R: CubeRuntime>(
53 strategy: &Strategy,
54 mut lhs: CubeTensor<R>,
55 mut rhs: CubeTensor<R>,
56 out: CubeTensor<R>,
57) -> Result<(), MatmulSetupError> {
58 if lhs.qparams.is_some() || rhs.qparams.is_some() {
61 match launch_matmul(strategy, lhs.clone(), rhs.clone(), out.clone()) {
62 Err(_) => {
63 if lhs.qparams.is_some() {
64 lhs = dequantize(lhs, out.dtype);
65 }
66 if rhs.qparams.is_some() {
67 rhs = dequantize(rhs, out.dtype);
68 }
69 launch_matmul(strategy, lhs, rhs, out)
70 }
71 Ok(_) => Ok(()),
72 }
73 } else {
74 launch_matmul(strategy, lhs, rhs, out)
75 }
76}
77
78pub(crate) fn launch_matmul<R: CubeRuntime>(
79 strategy: &Strategy,
80 lhs: CubeTensor<R>,
81 mut rhs: CubeTensor<R>,
82 out: CubeTensor<R>,
83) -> Result<(), MatmulSetupError> {
84 let client = &lhs.client;
85
86 let lhs_quant_handles = lhs.quantized_handles();
87 let out_dtype: DType = out.dtype;
88
89 let (lhs_dtype, lhs_handle) = match &lhs_quant_handles {
90 None => (
91 lhs.dtype,
92 MatmulInputHandleRef::new(lhs.as_handle_ref(), lhs.dtype.into()),
93 ),
94 Some((data, scale)) => (
95 out_dtype,
96 MatmulInputHandleRef::quantized(
97 data.as_handle_ref(),
98 scale.as_handle_ref(),
99 &lhs.shape.dims,
100 lhs.scheme(),
101 data.dtype.into(),
102 scale.dtype.into(),
103 ),
104 ),
105 };
106
107 let rhs_quant_handles = rhs.quantized_handles();
108
109 let (rhs_dtype, rhs_handle) = match &rhs_quant_handles {
110 None => (
111 lhs.dtype,
112 MatmulInputHandleRef::new(rhs.as_handle_ref(), lhs.dtype.into()),
113 ),
114 Some((data, scale)) => {
115 if matches!(strategy, Strategy::Naive)
117 && matches!(rhs.scheme().level, QuantLevel::Block(_))
118 {
119 rhs = dequantize(rhs.clone(), lhs.dtype);
120 (
121 lhs.dtype,
122 MatmulInputHandleRef::new(rhs.as_handle_ref(), rhs.dtype.into()),
123 )
124 } else {
125 (
126 out_dtype,
127 MatmulInputHandleRef::quantized(
128 data.as_handle_ref(),
129 scale.as_handle_ref(),
130 &rhs.shape.dims,
131 rhs.scheme(),
132 data.dtype.into(),
133 scale.dtype.into(),
134 ),
135 )
136 }
137 }
138 };
139
140 let mut dtypes = MatmulElems::from_globals(&MatmulGlobalElems {
141 lhs: lhs_dtype.into(),
142 rhs: rhs_dtype.into(),
143 out: out_dtype.into(),
144 });
145 cubek::matmul::launch::launch_ref(
146 strategy,
147 client,
148 &lhs_handle,
149 &rhs_handle,
150 &out.as_handle_ref(),
151 &mut dtypes,
152 )?;
153
154 Ok(())
155}