//! burn_dragon_kernel 0.5.0
//!
//! Fused GPU kernel module for the `burn_dragon` execution paths.
use super::*;

/// Attempts the fusion-backend execution path for the attention kernel.
///
/// Returns `None` when the backend's float primitive is not a fusion tensor
/// for runtime `R`, or when any input does not resolve to an f32 tensor, so
/// the caller can fall back to another execution path.
pub(super) fn try_fusion_path_runtime<B, BT, R>(
    query: &BurnTensor<B, 4>,
    value: &BurnTensor<B, 4>,
    decay: &BurnTensor<B, 1>,
    meta: &BurnTensor<B, 1>,
) -> Option<BurnTensor<B, 4>>
where
    B: BackendTrait,
    B::FloatTensorPrimitive: 'static,
    BT: BoolElement + 'static,
    R: CubeRuntime + 'static,
{
    // Bail out early unless the backend primitive is the fusion tensor type.
    if !matches_type::<B::FloatTensorPrimitive, FusionTensor<FusionCubeRuntime<R>>>() {
        return None;
    }

    // The query is resolved by hand (rather than via the shared helper) so
    // the fusion client can be captured from it and reused for registration.
    let query_prim = query.clone().into_primitive().tensor();
    let query_fusion: FusionTensor<FusionCubeRuntime<R>> = try_cast_primitive::<B, _>(query_prim)?;
    let client = query_fusion.client.clone();
    let query = client.resolve_tensor_float::<CubeBackend<R, f32, i32, BT>>(query_fusion);
    if !matches!(query.dtype, DType::F32) {
        return None;
    }

    // The remaining inputs go through the shared resolution helper.
    let value = resolve_fusion_tensor_runtime::<B, BT, R, 4>(value)?;
    let decay = resolve_fusion_tensor_runtime::<B, BT, R, 1>(decay)?;
    let meta = resolve_fusion_tensor_runtime::<B, BT, R, 1>(meta)?;

    // Run the kernel, register the result with the fusion client, then cast
    // it back into the caller's backend primitive.
    let result = dense_causal_attention_runtime::<R>(query, value, decay, meta);
    let result_fusion = register_fusion_float_tensor(&client, result);
    let result_prim = try_cast_backend::<B, _>(result_fusion)?;
    Some(BurnTensor::<B, 4>::from_primitive(TensorPrimitive::Float(
        result_prim,
    )))
}

/// Attempts the plain (non-fusion) cube-backend execution path for the
/// attention kernel.
///
/// Each input is cast from the backend primitive to a `CubeTensor<R>`.
/// Returns `None` when any cast fails or any tensor is not f32, so the
/// caller can fall back to another execution path.
pub(super) fn try_direct_path_runtime<B, R>(
    query: &BurnTensor<B, 4>,
    value: &BurnTensor<B, 4>,
    decay: &BurnTensor<B, 1>,
    meta: &BurnTensor<B, 1>,
) -> Option<BurnTensor<B, 4>>
where
    B: BackendTrait,
    B::FloatTensorPrimitive: 'static,
    R: CubeRuntime + 'static,
{
    // Casts one input to a cube tensor, rejecting anything that is not f32.
    // Extracted to avoid four identical copies of the cast-and-check dance.
    fn cast_f32<B, R, const D: usize>(tensor: &BurnTensor<B, D>) -> Option<CubeTensor<R>>
    where
        B: BackendTrait,
        B::FloatTensorPrimitive: 'static,
        R: CubeRuntime + 'static,
    {
        let cube: CubeTensor<R> =
            try_cast_primitive::<B, _>(tensor.clone().into_primitive().tensor())?;
        (cube.dtype == DType::F32).then_some(cube)
    }

    let query = cast_f32::<B, R, 4>(query)?;
    let value = cast_f32::<B, R, 4>(value)?;
    let decay = cast_f32::<B, R, 1>(decay)?;
    let meta = cast_f32::<B, R, 1>(meta)?;

    let output = dense_causal_attention_runtime::<R>(query, value, decay, meta);
    let output_prim = try_cast_backend::<B, _>(output)?;
    Some(BurnTensor::<B, 4>::from_primitive(TensorPrimitive::Float(
        output_prim,
    )))
}

/// Autodiff entry point for the plain (non-fusion) cube-backend path.
///
/// Thin wrapper that delegates to the custom autodiff implementation in the
/// parent module; returns `None` when that path does not apply so the caller
/// can fall back to another execution path.
pub(super) fn try_direct_path_autodiff_cube_runtime<B, R>(
    query: &BurnTensor<B, 4>,
    value: &BurnTensor<B, 4>,
    decay: &BurnTensor<B, 1>,
    meta: &BurnTensor<B, 1>,
) -> Option<BurnTensor<B, 4>>
where
    B: BackendTrait,
    B::FloatTensorPrimitive: 'static,
    R: CubeRuntime + 'static,
{
    super::dense_causal_attention_autodiff_custom::<B, R>(query, value, decay, meta)
}

/// Attempts the fusion + autodiff execution path for the attention kernel.
///
/// Three cases are tried in order:
/// 1. `R` is the wgpu runtime: inputs are unwrapped from the wgpu
///    fusion-autodiff primitive type.
/// 2. (`cuda` feature) `R` is the CUDA runtime: same flow with the CUDA
///    fusion-autodiff primitive type.
/// 3. Any other runtime whose primitive matches the generic
///    autodiff-over-fusion shape, detected via `matches_autodiff_fusion_type`.
///
/// Returns `None` whenever a cast fails or an input does not resolve to an
/// f32 tensor, so the caller can fall back to another execution path.
///
/// NOTE(review): the kernel runs on the inner (non-autodiff) tensors and the
/// result is re-wrapped with `wrap_fusion_autodiff_inner`; confirm gradient
/// tracking for this op is handled there (or by a custom backward), since
/// nothing in this function records a backward step.
pub(super) fn try_fusion_path_autodiff_runtime<B, BT, R>(
    query: &BurnTensor<B, 4>,
    value: &BurnTensor<B, 4>,
    decay: &BurnTensor<B, 1>,
    meta: &BurnTensor<B, 1>,
) -> Option<BurnTensor<B, 4>>
where
    B: BackendTrait,
    B::FloatTensorPrimitive: 'static,
    BT: BoolElement + 'static,
    R: CubeRuntime + 'static,
{
    // --- Case 1: wgpu runtime --------------------------------------------
    if TypeId::of::<R>() == TypeId::of::<WgpuRuntime>() {
        // The query is handled first so its fusion client can be reused for
        // the remaining tensors and the output registration.
        let prim_query = query.clone().into_primitive().tensor();
        let query_ad: WgpuFusionAutodiffTensor<BT> = try_cast_primitive::<B, _>(prim_query)?;
        let fusion_query: FusionTensor<FusionCubeRuntime<WgpuRuntime>> =
            <WgpuFusionAutodiffBackend<BT> as AutodiffBackend>::inner(query_ad);
        let fusion_client = fusion_query.client.clone();
        let query = fusion_client
            .resolve_tensor_float::<CubeBackend<WgpuRuntime, f32, i32, BT>>(fusion_query);
        if query.dtype != DType::F32 {
            return None;
        }

        let prim_value = value.clone().into_primitive().tensor();
        let value_ad: WgpuFusionAutodiffTensor<BT> = try_cast_primitive::<B, _>(prim_value)?;
        let fusion_value: FusionTensor<FusionCubeRuntime<WgpuRuntime>> =
            <WgpuFusionAutodiffBackend<BT> as AutodiffBackend>::inner(value_ad);
        let value = fusion_client
            .resolve_tensor_float::<CubeBackend<WgpuRuntime, f32, i32, BT>>(fusion_value);
        if value.dtype != DType::F32 {
            return None;
        }

        let prim_decay = decay.clone().into_primitive().tensor();
        let decay_ad: WgpuFusionAutodiffTensor<BT> = try_cast_primitive::<B, _>(prim_decay)?;
        let fusion_decay: FusionTensor<FusionCubeRuntime<WgpuRuntime>> =
            <WgpuFusionAutodiffBackend<BT> as AutodiffBackend>::inner(decay_ad);
        let decay = fusion_client
            .resolve_tensor_float::<CubeBackend<WgpuRuntime, f32, i32, BT>>(fusion_decay);
        if decay.dtype != DType::F32 {
            return None;
        }

        let prim_meta = meta.clone().into_primitive().tensor();
        let meta_ad: WgpuFusionAutodiffTensor<BT> = try_cast_primitive::<B, _>(prim_meta)?;
        let fusion_meta: FusionTensor<FusionCubeRuntime<WgpuRuntime>> =
            <WgpuFusionAutodiffBackend<BT> as AutodiffBackend>::inner(meta_ad);
        let meta = fusion_client
            .resolve_tensor_float::<CubeBackend<WgpuRuntime, f32, i32, BT>>(fusion_meta);
        if meta.dtype != DType::F32 {
            return None;
        }

        let output = dense_causal_attention_runtime::<WgpuRuntime>(query, value, decay, meta);
        let output_fusion = register_fusion_float_tensor(&fusion_client, output);
        let output_ad = wrap_fusion_autodiff_inner::<B, BT, WgpuRuntime>(output_fusion)?;
        let output_prim = try_cast_backend::<B, _>(output_ad)?;
        return Some(BurnTensor::<B, 4>::from_primitive(TensorPrimitive::Float(
            output_prim,
        )));
    }

    // --- Case 2: CUDA runtime (feature-gated) ----------------------------
    #[cfg(feature = "cuda")]
    if TypeId::of::<R>() == TypeId::of::<CudaRuntime>() {
        let prim_query = query.clone().into_primitive().tensor();
        let query_ad: CudaFusionAutodiffTensor<BT> = try_cast_primitive::<B, _>(prim_query)?;
        let fusion_query: FusionTensor<FusionCubeRuntime<CudaRuntime>> =
            <CudaFusionAutodiffBackend<BT> as AutodiffBackend>::inner(query_ad);
        let fusion_client = fusion_query.client.clone();
        let query = fusion_client
            .resolve_tensor_float::<CubeBackend<CudaRuntime, f32, i32, BT>>(fusion_query);
        if query.dtype != DType::F32 {
            return None;
        }

        let prim_value = value.clone().into_primitive().tensor();
        let value_ad: CudaFusionAutodiffTensor<BT> = try_cast_primitive::<B, _>(prim_value)?;
        let fusion_value: FusionTensor<FusionCubeRuntime<CudaRuntime>> =
            <CudaFusionAutodiffBackend<BT> as AutodiffBackend>::inner(value_ad);
        let value = fusion_client
            .resolve_tensor_float::<CubeBackend<CudaRuntime, f32, i32, BT>>(fusion_value);
        if value.dtype != DType::F32 {
            return None;
        }

        let prim_decay = decay.clone().into_primitive().tensor();
        let decay_ad: CudaFusionAutodiffTensor<BT> = try_cast_primitive::<B, _>(prim_decay)?;
        let fusion_decay: FusionTensor<FusionCubeRuntime<CudaRuntime>> =
            <CudaFusionAutodiffBackend<BT> as AutodiffBackend>::inner(decay_ad);
        let decay = fusion_client
            .resolve_tensor_float::<CubeBackend<CudaRuntime, f32, i32, BT>>(fusion_decay);
        if decay.dtype != DType::F32 {
            return None;
        }

        let prim_meta = meta.clone().into_primitive().tensor();
        let meta_ad: CudaFusionAutodiffTensor<BT> = try_cast_primitive::<B, _>(prim_meta)?;
        let fusion_meta: FusionTensor<FusionCubeRuntime<CudaRuntime>> =
            <CudaFusionAutodiffBackend<BT> as AutodiffBackend>::inner(meta_ad);
        let meta = fusion_client
            .resolve_tensor_float::<CubeBackend<CudaRuntime, f32, i32, BT>>(fusion_meta);
        if meta.dtype != DType::F32 {
            return None;
        }

        let output = dense_causal_attention_runtime::<CudaRuntime>(query, value, decay, meta);
        let output_fusion = register_fusion_float_tensor(&fusion_client, output);
        let output_ad = wrap_fusion_autodiff_inner::<B, BT, CudaRuntime>(output_fusion)?;
        let output_prim = try_cast_backend::<B, _>(output_ad)?;
        return Some(BurnTensor::<B, 4>::from_primitive(TensorPrimitive::Float(
            output_prim,
        )));
    }

    // --- Case 3: generic autodiff-over-fusion runtime ---------------------
    if !matches_autodiff_fusion_type::<B, BT, R>() {
        return None;
    }

    let prim_query = query.clone().into_primitive().tensor();
    let query_ad: B::FloatTensorPrimitive = try_cast_primitive::<B, _>(prim_query)?;
    let fusion_query: FusionTensor<FusionCubeRuntime<R>> =
        extract_fusion_autodiff_inner::<B, BT, R>(query_ad)?;
    let fusion_client = fusion_query.client.clone();
    let query = fusion_client.resolve_tensor_float::<CubeBackend<R, f32, i32, BT>>(fusion_query);
    if query.dtype != DType::F32 {
        return None;
    }

    let prim_value = value.clone().into_primitive().tensor();
    let value_ad: B::FloatTensorPrimitive = try_cast_primitive::<B, _>(prim_value)?;
    let fusion_value: FusionTensor<FusionCubeRuntime<R>> =
        extract_fusion_autodiff_inner::<B, BT, R>(value_ad)?;
    let value = fusion_client.resolve_tensor_float::<CubeBackend<R, f32, i32, BT>>(fusion_value);
    if value.dtype != DType::F32 {
        return None;
    }

    let prim_decay = decay.clone().into_primitive().tensor();
    let decay_ad: B::FloatTensorPrimitive = try_cast_primitive::<B, _>(prim_decay)?;
    let fusion_decay: FusionTensor<FusionCubeRuntime<R>> =
        extract_fusion_autodiff_inner::<B, BT, R>(decay_ad)?;
    let decay = fusion_client.resolve_tensor_float::<CubeBackend<R, f32, i32, BT>>(fusion_decay);
    if decay.dtype != DType::F32 {
        return None;
    }

    let prim_meta = meta.clone().into_primitive().tensor();
    let meta_ad: B::FloatTensorPrimitive = try_cast_primitive::<B, _>(prim_meta)?;
    let fusion_meta: FusionTensor<FusionCubeRuntime<R>> =
        extract_fusion_autodiff_inner::<B, BT, R>(meta_ad)?;
    let meta = fusion_client.resolve_tensor_float::<CubeBackend<R, f32, i32, BT>>(fusion_meta);
    if meta.dtype != DType::F32 {
        return None;
    }

    let output = dense_causal_attention_runtime::<R>(query, value, decay, meta);
    let output_fusion = register_fusion_float_tensor(&fusion_client, output);
    let output_ad = wrap_fusion_autodiff_inner::<B, BT, R>(output_fusion)?;
    let output_prim = try_cast_backend::<B, _>(output_ad)?;
    Some(BurnTensor::<B, 4>::from_primitive(TensorPrimitive::Float(
        output_prim,
    )))
}

/// Dispatches the fused attention launch to the runtime-specific kernel.
///
/// When the `cuda` feature is enabled and `R` is the CUDA runtime, the typed
/// cubecl kernel entry point is used; every other runtime goes through the
/// WGSL source kernel.
pub(crate) fn dense_causal_attention_runtime<R: CubeRuntime + 'static>(
    query: CubeTensor<R>,
    value: CubeTensor<R>,
    decay: CubeTensor<R>,
    meta: CubeTensor<R>,
) -> CubeTensor<R> {
    #[cfg(feature = "cuda")]
    {
        let runs_on_cuda = TypeId::of::<R>() == TypeId::of::<CudaRuntime>();
        if runs_on_cuda {
            return dense_causal_attention_cube_runtime::<R>(query, value, decay, meta);
        }
    }
    dense_causal_attention_wgsl_runtime::<R>(query, value, decay, meta)
}

/// Launches the fused dense causal-attention kernel through the WGSL source
/// kernel path (usable on any cube runtime).
///
/// Inputs are made contiguous first; the result is a freshly allocated f32
/// tensor of shape `[batch, heads, time, value_dim]`.
pub(super) fn dense_causal_attention_wgsl_runtime<R: CubeRuntime>(
    query: CubeTensor<R>,
    value: CubeTensor<R>,
    decay: CubeTensor<R>,
    meta: CubeTensor<R>,
) -> CubeTensor<R> {
    // The kernel reads raw buffers, so every input must be contiguous.
    let query = into_contiguous(query);
    let value = into_contiguous(value);
    let decay = into_contiguous(decay);
    let meta = into_contiguous(meta);

    let [batch, heads, time, _latent] = query.meta.shape.dims::<4>();
    let value_dim = value.meta.shape.dims::<4>()[3];

    let client = query.client.clone();
    let output = empty_device::<R, f32>(
        client.clone(),
        query.device.clone(),
        Shape::new([batch, heads, time, value_dim]),
    );

    // Grid layout: x covers value features in WORKGROUP_SIZE_X-wide columns,
    // y walks the time dimension, z walks batch * heads.
    let groups_x = div_ceil_u32(value_dim as u32, WORKGROUP_SIZE_X);
    let cube_count = CubeCount::Static(groups_x, time as u32, (batch * heads) as u32);
    let kernel = SourceKernel::new(
        DenseCausalAttentionKernel,
        CubeDim::new_3d(WORKGROUP_SIZE_X, 1, 1),
    );
    // Buffer order is part of the kernel's binding layout: query, value,
    // decay, output, meta — do not reorder.
    let bindings = KernelArguments::new().with_buffers(vec![
        query.handle.clone().binding(),
        value.handle.clone().binding(),
        decay.handle.clone().binding(),
        output.handle.clone().binding(),
        meta.handle.clone().binding(),
    ]);
    client.launch(Box::new(kernel), cube_count, bindings);

    output
}

/// Launches the fused dense causal-attention kernel through the typed cubecl
/// kernel entry point (CUDA path).
///
/// Inputs are made contiguous first; the result is a freshly allocated f32
/// tensor of shape `[batch, heads, time, value_dim]`. The launch geometry
/// matches the WGSL path: x covers value features in WORKGROUP_SIZE_X-wide
/// columns, y walks the time dimension, z walks batch * heads.
#[cfg(feature = "cuda")]
fn dense_causal_attention_cube_runtime<R: CubeRuntime>(
    query: CubeTensor<R>,
    value: CubeTensor<R>,
    decay: CubeTensor<R>,
    meta: CubeTensor<R>,
) -> CubeTensor<R> {
    // The kernel reads raw buffers, so every input must be contiguous.
    let query = into_contiguous(query);
    let value = into_contiguous(value);
    let decay = into_contiguous(decay);
    let meta = into_contiguous(meta);

    let [batch, heads, time, _latent] = query.meta.shape.dims::<4>();
    let value_dim = value.meta.shape.dims::<4>()[3];

    let client = query.client.clone();
    let device = query.device.clone();
    let output = empty_device::<R, f32>(
        client.clone(),
        device,
        Shape::new([batch, heads, time, value_dim]),
    );

    let cube_dim = CubeDim::new_1d(WORKGROUP_SIZE_X);
    let cube_count = CubeCount::Static(
        div_ceil_u32(value_dim as u32, WORKGROUP_SIZE_X),
        time as u32,
        (batch * heads) as u32,
    );

    // The inputs are consumed by the launch; only `output` must stay alive
    // for the return value, so it is the only tensor cloned here.
    let _ = dense_causal_attention_cube_kernel::launch::<R>(
        &client,
        cube_count,
        cube_dim,
        query.into_tensor_arg(),
        value.into_tensor_arg(),
        decay.into_tensor_arg(),
        output.clone().into_tensor_arg(),
        meta.into_tensor_arg(),
        MAX_FUSED_TIME,
    );

    output
}