burn_cubecl/kernel/matmul/utils.rs
1use crate::{CubeRuntime, ops::numeric::empty_device_dtype, tensor::CubeTensor};
2use burn_backend::{DType, calculate_matmul_output};
3
4/// Creates an empty output tensor with matmul output shape
5pub fn init_matmul_output<R: CubeRuntime>(
6 lhs: &CubeTensor<R>,
7 rhs: &CubeTensor<R>,
8 dtype: DType,
9) -> CubeTensor<R> {
10 empty_device_dtype(
11 lhs.client.clone(),
12 lhs.device.clone(),
13 calculate_matmul_output(lhs.meta.shape(), rhs.meta.shape()).unwrap(),
14 dtype,
15 )
16}