burn_dragon_kernel 0.21.0-pre.13

Fused GPU kernel crate for burn_dragon execution paths
Documentation
use super::*;
use burn::tensor::backend::AutodiffBackend;

/// Reference (unfused) backward pass for dense causal attention with a per-head decay.
///
/// The gradients computed here are those of
/// `out = (tril(Q Qᵀ, -1) ⊙ D) V` with `D[h, i, j] = decay[h]^(i - j)` for `i > j`
/// (the diagonal and everything above it is masked out).
/// NOTE(review): this forward formula is inferred from the math below; confirm it matches
/// `forward_runtime::dense_causal_attention_runtime`.
///
/// # Arguments
/// * `grad_output` - upstream gradient, assumed `[batch, heads, time, value_dim]`
///   (it is reshaped to `[batch * heads, time, value_dim]`).
/// * `query` - `[batch, heads, time, latent]`; serves as both query and key.
/// * `value` - `[batch, value_heads, time, value_dim]`, where `value_heads` is either
///   `heads` or `1` (one value tensor shared by every head).
/// * `decay` - per-head decay base with `heads` elements.
///
/// # Returns
/// `(grad_query, grad_value, grad_decay)`, shaped like `query`, `value` and `decay`.
///
/// # Panics
/// Panics if `value_heads` is neither `heads` nor `1`.
fn dense_causal_backward_impl<B: BackendTrait>(
    grad_output: BurnTensor<B, 4>,
    query: BurnTensor<B, 4>,
    value: BurnTensor<B, 4>,
    decay: BurnTensor<B, 1>,
) -> (BurnTensor<B, 4>, BurnTensor<B, 4>, BurnTensor<B, 1>)
where
    B::FloatTensorPrimitive: 'static,
{
    let [batch, heads, time, latent] = query.shape().dims::<4>();
    let [_, value_heads, _, value_dim] = value.shape().dims::<4>();
    // Any other head count would be repeated to `value_heads * heads` heads and fail a reshape
    // further down; reject it up front with a readable message instead.
    assert!(
        value_heads == heads || value_heads == 1,
        "dense causal attention backward: value heads ({value_heads}) must equal query heads ({heads}) or be 1"
    );
    let shared_value = value_heads != heads;

    let value_per_head = if shared_value {
        value.repeat_dim(1, heads)
    } else {
        value
    };

    let device = query.device();
    let pos_row = BurnTensor::<B, 1, Int>::arange(0..time as i64, &device)
        .float()
        .reshape([1, 1, time, 1]);
    let pos_col = BurnTensor::<B, 1, Int>::arange(0..time as i64, &device)
        .float()
        .reshape([1, 1, 1, time]);
    // gap[i, j] = i - j strictly below the diagonal, 0 on and above it.
    let gap = (pos_row - pos_col).tril(-1);

    // Per-head decay base laid out over the full [time, time] grid for element-wise powf.
    let decay_grid = decay
        .reshape([1, heads, 1, 1])
        .repeat_dim(2, time)
        .repeat_dim(3, time);
    // decay^gap: multiplicative factor applied to each causal score.
    let decay_matrix = decay_grid.clone().powf(gap.clone());
    // decay^(gap - 1), with the exponent clamped at 0 so masked entries (gap == 0) stay finite.
    // Used for the exact derivative d(decay^gap)/d(decay) = gap * decay^(gap - 1), which avoids
    // dividing by `decay`. That division is wrong at decay == 0 and NaN-prone in low precision.
    let decay_matrix_prev = decay_grid.powf(gap.clone().sub_scalar(1.0).clamp_min(0.0));

    let raw_scores = query.clone().matmul(query.clone().swap_dims(2, 3)).tril(-1);
    let scores = raw_scores.clone() * decay_matrix.clone();

    let batch_heads = batch * heads;
    let query_flat = query.reshape([batch_heads, time, latent]);
    let value_flat = value_per_head.reshape([batch_heads, time, value_dim]);
    let grad_output_flat = grad_output.reshape([batch_heads, time, value_dim]);

    // dL/dV = scoresᵀ · dL/dout for every (batch, head) pair.
    let grad_value_heads = scores
        .swap_dims(2, 3)
        .reshape([batch_heads, time, time])
        .matmul(grad_output_flat.clone())
        .reshape([batch, heads, time, value_dim]);
    // A value tensor shared by all heads accumulates every head's gradient.
    let grad_value = if shared_value {
        grad_value_heads
            .sum_dim(1)
            .reshape([batch, 1, time, value_dim])
    } else {
        grad_value_heads
    };

    // dL/dscores = dL/dout · Vᵀ, restricted to the strictly-lower causal mask.
    let grad_scores = grad_output_flat
        .matmul(value_flat.swap_dims(1, 2))
        .reshape([batch, heads, time, time])
        .tril(-1);

    // raw = Q Qᵀ, hence dL/dQ = (G + Gᵀ) Q with G = dL/draw = dL/dscores ⊙ decay^gap.
    let grad_raw_scores = grad_scores.clone() * decay_matrix;
    let grad_query = (grad_raw_scores.clone() + grad_raw_scores.swap_dims(2, 3))
        .reshape([batch_heads, time, time])
        .matmul(query_flat)
        .reshape([batch, heads, time, latent]);

    // dscores/ddecay = raw * gap * decay^(gap - 1); reduce over batch and both time axes.
    let grad_decay = (grad_scores * raw_scores * gap * decay_matrix_prev)
        .sum_dim(0)
        .sum_dim(2)
        .sum_dim(3)
        .reshape([heads]);

    (grad_query, grad_value, grad_decay)
}

/// Shared autodiff backward hook for the fused dense causal attention op.
///
/// Consumes the upstream gradient of this op's output node and rebuilds the saved forward
/// inputs as `B` tensors. It then computes the reference gradients with
/// `dense_causal_backward_impl`. Each gradient is registered with its parent, in the order
/// query, value, decay, but only for parents that are present.
fn dense_causal_attention_backward_impl<B: BackendTrait>(
    ops: Ops<DenseCausalAttentionBackwardState<B::FloatTensorPrimitive>, 3>,
    grads: &mut Gradients,
) where
    B::FloatTensorPrimitive: 'static,
{
    let upstream = grads.consume::<B>(&ops.node);
    let DenseCausalAttentionBackwardState { query, value, decay } = ops.state;

    let (grad_query, grad_value, grad_decay) = dense_causal_backward_impl(
        BurnTensor::<B, 4>::from_primitive(TensorPrimitive::Float(upstream)),
        BurnTensor::<B, 4>::from_primitive(TensorPrimitive::Float(query)),
        BurnTensor::<B, 4>::from_primitive(TensorPrimitive::Float(value)),
        BurnTensor::<B, 1>::from_primitive(TensorPrimitive::Float(decay)),
    );

    // Same order as the parent nodes passed to `prepare` in the forward pass.
    let input_grads = [
        grad_query.into_primitive().tensor(),
        grad_value.into_primitive().tensor(),
        grad_decay.into_primitive().tensor(),
    ];
    for (parent, grad) in ops.parents.iter().zip(input_grads) {
        if let Some(parent) = parent {
            grads.register::<B>(parent.id, grad);
        }
    }
}

/// WGPU registration of the fused dense causal attention backward pass.
///
/// Delegates to the shared backward hook. The checkpointer is ignored because the forward pass
/// prepares the op with `NoCheckpointing`.
impl Backward<WgpuCubeBackend, 3> for FusedDenseCausalAttentionBackward<WgpuCubeBackend> {
    type State = DenseCausalAttentionBackwardState<CubeTensor<WgpuRuntime>>;

    fn backward(self, ops: Ops<Self::State, 3>, grads: &mut Gradients, _: &mut Checkpointer) {
        dense_causal_attention_backward_impl::<WgpuCubeBackend>(ops, grads)
    }
}

/// CUDA registration of the fused dense causal attention backward pass.
///
/// Delegates to the shared backward hook. The checkpointer is ignored because the forward pass
/// prepares the op with `NoCheckpointing`.
#[cfg(feature = "cuda")]
impl Backward<CudaCubeBackend, 3> for FusedDenseCausalAttentionBackward<CudaCubeBackend> {
    type State = DenseCausalAttentionBackwardState<CubeTensor<CudaRuntime>>;

    fn backward(self, ops: Ops<Self::State, 3>, grads: &mut Gradients, _: &mut Checkpointer) {
        dense_causal_attention_backward_impl::<CudaCubeBackend>(ops, grads)
    }
}

fn dense_causal_attention_autodiff_custom_wgpu<B: BackendTrait>(
    query: &BurnTensor<B, 4>,
    value: &BurnTensor<B, 4>,
    decay: &BurnTensor<B, 1>,
    meta: &BurnTensor<B, 1>,
) -> Option<BurnTensor<B, 4>>
where
    B::FloatTensorPrimitive: 'static,
{
    let query_ad: WgpuCubeAutodiffTensor =
        try_cast_primitive::<B, _>(query.clone().into_primitive().tensor())?;
    let value_ad: WgpuCubeAutodiffTensor =
        try_cast_primitive::<B, _>(value.clone().into_primitive().tensor())?;
    let decay_ad: WgpuCubeAutodiffTensor =
        try_cast_primitive::<B, _>(decay.clone().into_primitive().tensor())?;
    let meta_ad: WgpuCubeAutodiffTensor =
        try_cast_primitive::<B, _>(meta.clone().into_primitive().tensor())?;

    let query_inner = <WgpuCubeAutodiffBackend as AutodiffBackend>::inner(query_ad.clone());
    let value_inner = <WgpuCubeAutodiffBackend as AutodiffBackend>::inner(value_ad.clone());
    let decay_inner = <WgpuCubeAutodiffBackend as AutodiffBackend>::inner(decay_ad.clone());
    let meta_inner = <WgpuCubeAutodiffBackend as AutodiffBackend>::inner(meta_ad);

    let context = forward_runtime::dense_causal_attention_runtime::<WgpuRuntime>(
        query_inner.clone(),
        value_inner.clone(),
        decay_inner.clone(),
        meta_inner,
    );

    let output = match FusedDenseCausalAttentionBackward::<WgpuCubeBackend>(PhantomData)
        .prepare::<NoCheckpointing>([
            query_ad.node.clone(),
            value_ad.node.clone(),
            decay_ad.node.clone(),
        ])
        .compute_bound()
        .stateful()
    {
        OpsKind::Tracked(prep) => prep.finish(
            DenseCausalAttentionBackwardState {
                query: query_inner,
                value: value_inner,
                decay: decay_inner,
            },
            context,
        ),
        OpsKind::UnTracked(prep) => prep.finish(context),
    };

    Some(BurnTensor::<B, 4>::from_primitive(TensorPrimitive::Float(
        try_cast_backend::<B, _>(output)?,
    )))
}

/// Runs the fused dense causal attention forward kernel on the CUDA runtime.
///
/// When autodiff reports the op as tracked (`OpsKind::Tracked`), this also records a
/// `FusedDenseCausalAttentionBackward` node with query, value and decay as parents. `meta` goes
/// to the kernel but is not registered as a differentiable parent.
///
/// Returns `None` if any input (or the output) fails to cast between `B` and the CUDA autodiff
/// tensor type. This presumably means `B` is not the CUDA autodiff backend, so the caller can
/// fall back.
#[cfg(feature = "cuda")]
fn dense_causal_attention_autodiff_custom_cuda<B: BackendTrait>(
    query: &BurnTensor<B, 4>,
    value: &BurnTensor<B, 4>,
    decay: &BurnTensor<B, 1>,
    meta: &BurnTensor<B, 1>,
) -> Option<BurnTensor<B, 4>>
where
    B::FloatTensorPrimitive: 'static,
{
    type Ad = CudaCubeAutodiffBackend;

    let query_ad: CudaCubeAutodiffTensor =
        try_cast_primitive::<B, _>(query.clone().into_primitive().tensor())?;
    let value_ad: CudaCubeAutodiffTensor =
        try_cast_primitive::<B, _>(value.clone().into_primitive().tensor())?;
    let decay_ad: CudaCubeAutodiffTensor =
        try_cast_primitive::<B, _>(decay.clone().into_primitive().tensor())?;
    let meta_ad: CudaCubeAutodiffTensor =
        try_cast_primitive::<B, _>(meta.clone().into_primitive().tensor())?;

    // Unwrap to the inner backend for the kernel; the autodiff handles are kept for their nodes.
    let query_inner = <Ad as AutodiffBackend>::inner(query_ad.clone());
    let value_inner = <Ad as AutodiffBackend>::inner(value_ad.clone());
    let decay_inner = <Ad as AutodiffBackend>::inner(decay_ad.clone());
    let meta_inner = <Ad as AutodiffBackend>::inner(meta_ad);

    let context = forward_runtime::dense_causal_attention_runtime::<CudaRuntime>(
        query_inner.clone(),
        value_inner.clone(),
        decay_inner.clone(),
        meta_inner,
    );

    let parents = [
        query_ad.node.clone(),
        value_ad.node.clone(),
        decay_ad.node.clone(),
    ];
    let kind = FusedDenseCausalAttentionBackward::<CudaCubeBackend>(PhantomData)
        .prepare::<NoCheckpointing>(parents)
        .compute_bound()
        .stateful();

    let output = match kind {
        OpsKind::UnTracked(prep) => prep.finish(context),
        OpsKind::Tracked(prep) => {
            let state = DenseCausalAttentionBackwardState {
                query: query_inner,
                value: value_inner,
                decay: decay_inner,
            };
            prep.finish(state, context)
        }
    };

    let output = try_cast_backend::<B, _>(output)?;
    Some(BurnTensor::<B, 4>::from_primitive(TensorPrimitive::Float(
        output,
    )))
}

/// Dispatches the autodiff-aware fused dense causal attention to the runtime-specific
/// implementation selected by `R`.
///
/// WGPU is always available. CUDA is only available when the `cuda` feature is enabled.
/// Returns `None` for any other runtime, or when the selected implementation declines the
/// inputs.
pub(super) fn dense_causal_attention_autodiff_custom<B: BackendTrait, R: CubeRuntime + 'static>(
    query: &BurnTensor<B, 4>,
    value: &BurnTensor<B, 4>,
    decay: &BurnTensor<B, 1>,
    meta: &BurnTensor<B, 1>,
) -> Option<BurnTensor<B, 4>>
where
    B::FloatTensorPrimitive: 'static,
{
    let runtime = TypeId::of::<R>();

    if runtime == TypeId::of::<WgpuRuntime>() {
        return dense_causal_attention_autodiff_custom_wgpu(query, value, decay, meta);
    }

    #[cfg(feature = "cuda")]
    {
        if runtime == TypeId::of::<CudaRuntime>() {
            return dense_causal_attention_autodiff_custom_cuda(query, value, decay, meta);
        }
    }

    None
}