//! burn_dragon_kernel 0.21.0-pre.13
//!
//! Fused GPU kernel crate for burn_dragon execution paths.
use std::any::TypeId;

#[cfg(not(feature = "cuda"))]
use burn_cubecl::cubecl::wgpu::WgpuRuntime;
use burn_cubecl::fusion::FusionCubeRuntime;
use burn_cubecl::tensor::CubeTensor;
use burn_cubecl::{BoolElement, CubeRuntime};
#[cfg(feature = "cuda")]
use burn_cubecl::{cubecl::cuda::CudaRuntime, cubecl::wgpu::WgpuRuntime};
use burn_fusion::{Client, FusionTensor, NoOp, stream::OperationStreams};
use burn_ir::{InitOperationIr, OperationIr, OperationOutput};
use burn_wgpu::CubeBackend;

/// Wraps a raw `CubeTensor` into the fusion runtime by registering its handle
/// and emitting an `Init` operation on the given fusion client.
///
/// `BT` selects the boolean storage element of the concrete backend
/// (callers pass `u32` for wgpu and `u8` for CUDA).
fn register_fusion_float_tensor_with_bool<R: CubeRuntime, BT: BoolElement + 'static>(
    client: &Client<FusionCubeRuntime<R>>,
    tensor: CubeTensor<R>,
) -> FusionTensor<FusionCubeRuntime<R>> {
    // Capture metadata first: the handle conversion below consumes the tensor.
    let (shape, dtype) = (tensor.meta.shape().clone(), tensor.dtype);
    let handle = tensor.into();

    // Build the init IR lazily so the handle is only registered when the
    // descriptor is actually created.
    let init_ir = InitOperationIr::create(shape, dtype, || client.register_tensor_handle(handle));
    let streams = OperationStreams::default();
    let noop = NoOp::<CubeBackend<R, f32, i32, BT>>::new();

    client
        .register(streams, OperationIr::Init(init_ir), noop)
        .output()
}

/// Registers a float `CubeTensor` with the fusion client, dispatching on the
/// concrete runtime type to pick the backend's boolean element.
///
/// Supported runtimes: wgpu (bool stored as `u32`) and, when the `cuda`
/// feature is enabled, CUDA (bool stored as `u8`). Panics for any other
/// runtime type.
pub(crate) fn register_fusion_float_tensor<R: CubeRuntime + 'static>(
    client: &Client<FusionCubeRuntime<R>>,
    tensor: CubeTensor<R>,
) -> FusionTensor<FusionCubeRuntime<R>> {
    let runtime = TypeId::of::<R>();

    if runtime == TypeId::of::<WgpuRuntime>() {
        return register_fusion_float_tensor_with_bool::<R, u32>(client, tensor);
    }

    // CUDA dispatch only exists when the crate is built with the feature on.
    #[cfg(feature = "cuda")]
    if runtime == TypeId::of::<CudaRuntime>() {
        return register_fusion_float_tensor_with_bool::<R, u8>(client, tensor);
    }

    panic!(
        "unsupported fusion runtime for float tensor registration: {}",
        std::any::type_name::<R>()
    );
}