Skip to main content

burn_cubecl/kernel/matmul/
utils.rs

1use crate::{CubeRuntime, ops::numeric::empty_device_dtype, tensor::CubeTensor};
2use burn_backend::{DType, calculate_matmul_output};
3
4/// Creates an empty output tensor with matmul output shape
5pub fn init_matmul_output<R: CubeRuntime>(
6    lhs: &CubeTensor<R>,
7    rhs: &CubeTensor<R>,
8    dtype: DType,
9) -> CubeTensor<R> {
10    empty_device_dtype(
11        lhs.client.clone(),
12        lhs.device.clone(),
13        calculate_matmul_output(lhs.meta.shape(), rhs.meta.shape()).unwrap(),
14        dtype,
15    )
16}