burn-cubecl 0.21.0-pre.2

Generic backend that can be compiled just-in-time to any shader language target
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
use crate::{CubeRuntime, ops::numeric::empty_device_dtype, tensor::CubeTensor};
use burn_backend::{DType, calculate_matmul_output};

/// Creates an empty output tensor shaped for the matrix multiplication of `lhs` and `rhs`.
///
/// The tensor is created through [`empty_device_dtype`] using the client and device of `lhs`
/// and the requested `dtype`.
///
/// # Panics
///
/// Panics if [`calculate_matmul_output`] cannot produce an output shape for the shapes of
/// `lhs` and `rhs`.
pub fn init_matmul_output<R: CubeRuntime>(
    lhs: &CubeTensor<R>,
    rhs: &CubeTensor<R>,
    dtype: DType,
) -> CubeTensor<R> {
    // Resolve the output shape up front so a failure is reported with a clear invariant
    // message (the underlying error is still included in the panic output).
    let shape = calculate_matmul_output(lhs.meta.shape(), rhs.meta.shape())
        .expect("matmul operand shapes must be compatible");

    empty_device_dtype(lhs.client.clone(), lhs.device.clone(), shape, dtype)
}