use crate::{CubeRuntime, ops::numeric::empty_device_dtype, tensor::CubeTensor};
use burn_backend::{DType, calculate_matmul_output};
/// Creates the output tensor for a matrix multiplication of `lhs` and `rhs`.
///
/// The output shape is obtained from [`calculate_matmul_output`] applied to the
/// shapes of both operands. The tensor itself is produced by
/// [`empty_device_dtype`], using the client and device of `lhs` and the
/// requested `dtype`.
///
/// # Panics
///
/// Panics if [`calculate_matmul_output`] does not yield a shape for the two
/// operand shapes, since its result is unwrapped.
pub fn init_matmul_output<R: CubeRuntime>(
    lhs: &CubeTensor<R>,
    rhs: &CubeTensor<R>,
    dtype: DType,
) -> CubeTensor<R> {
    // The output is created with the left-hand operand's client and device;
    // `rhs` only contributes its shape.
    let client = lhs.client.clone();
    let device = lhs.device.clone();

    let lhs_shape = lhs.meta.shape();
    let rhs_shape = rhs.meta.shape();
    let out_shape = calculate_matmul_output(lhs_shape, rhs_shape).unwrap();

    empty_device_dtype(client, device, out_shape, dtype)
}