burn_jit/kernel/matmul/
base.rs1use cubecl::linalg::matmul::kernels::MatmulLaunchError;
2
3use super::init_matmul_output;
4use crate::{tensor::JitTensor, FloatElement, JitRuntime};
5
6#[cfg(feature = "autotune")]
7use super::matmul_autotune;
8
/// The strategy used to launch a matmul kernel.
///
/// This is a plain fieldless enum, so it derives the common value traits: it can be
/// logged with `{:?}`, copied freely, and compared.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MatmulStrategy {
    /// Delegate kernel selection to the autotuner.
    ///
    /// Only available when the `autotune` feature is enabled.
    #[cfg(feature = "autotune")]
    Autotune,
    /// Launch the cubecl matmul kernel directly with its default launch strategy.
    Cube,
}
17
18impl Default for MatmulStrategy {
19 fn default() -> Self {
20 #[cfg(feature = "autotune")]
22 return MatmulStrategy::Autotune;
23
24 #[cfg(not(feature = "autotune"))]
25 MatmulStrategy::Cube
26 }
27}
28
29pub fn matmul<R: JitRuntime, E: FloatElement>(
31 lhs: JitTensor<R>,
32 rhs: JitTensor<R>,
33 out: Option<JitTensor<R>>,
34 strategy: MatmulStrategy,
35) -> Result<JitTensor<R>, MatmulLaunchError> {
36 match strategy {
37 MatmulStrategy::Cube => {
38 let out = out.unwrap_or_else(|| init_matmul_output::<R, E>(&lhs, &rhs));
39
40 let client = &lhs.client;
41
42 cubecl::linalg::matmul::launch_ref::<R, E>(
43 &Default::default(),
44 client,
45 &lhs.as_handle_ref(),
46 &rhs.as_handle_ref(),
47 &out.as_handle_ref(),
48 )?;
49
50 Ok(out)
51 }
52 #[cfg(feature = "autotune")]
53 MatmulStrategy::Autotune => Ok(matmul_autotune::<R, E>(lhs, rhs, out)),
54 }
55}