burn_dragon_kernel 0.5.0

Fused GPU kernel crate for burn_dragon execution paths
Documentation
#[cfg(feature = "cuda")]
use burn::tensor::Tensor as BurnTensor;
#[cfg(feature = "cuda")]
use burn::tensor::{Shape, TensorData};
#[cfg(feature = "cuda")]
use burn_cubecl::cubecl::cuda::CudaRuntime;
#[cfg(feature = "cuda")]
use burn_cubecl::cubecl::{self, prelude::*};
#[cfg(feature = "cuda")]
use burn_cubecl::kernel::into_contiguous;
#[cfg(feature = "cuda")]
use burn_cubecl::ops::numeric::empty_device;
#[cfg(feature = "cuda")]
use burn_cubecl::tensor::CubeTensor;
#[cfg(feature = "cuda")]
use burn_wgpu::CubeBackend;

// Concrete burn backend used in this file to build helper tensors (the
// zero-initialised gradient accumulators and the kernel parameter tensor).
// NOTE(review): the three element-type arguments are presumably float / int /
// bool storage types, in that order; confirm against `CubeBackend`.
#[cfg(feature = "cuda")]
type CudaCubeBackend = CubeBackend<CudaRuntime, f32, i32, u8>;

// Number of `f32` entries in the parameter tensor handed to both kernels.
// The host functions fill it with: batch, views, channels, time, d_conv.
#[cfg(feature = "cuda")]
const PARAMS_LEN: usize = 5;
// Width of the 1-D cube used for both launches; it is also the divisor used to
// compute how many cubes are needed along X.
#[cfg(feature = "cuda")]
const MAMBA_CONV_CUDA_WORKGROUP_X: u32 = 128;

/// Tensors produced by [`fused_mamba_depthwise_conv_forward_cuda`].
#[cfg(feature = "cuda")]
pub(crate) struct MambaDepthwiseConvCudaForwardOutput {
    /// Convolution result before the activation (bias plus weighted taps),
    /// allocated as `[batch, views, channels, time]`.
    pub(crate) preact: CubeTensor<CudaRuntime>,
    /// `preact * sigmoid(preact)`, same shape as `preact`.
    pub(crate) activated: CubeTensor<CudaRuntime>,
    /// Rolling history for the next call, allocated as
    /// `[batch, views, channels, d_conv]`: the last `d_conv` entries of the
    /// previous state followed by `x` along the time axis.
    pub(crate) next_state: CubeTensor<CudaRuntime>,
}

/// Gradients produced by [`fused_mamba_depthwise_conv_backward_cuda`].
///
/// All three are gradients with respect to the pre-activation output; no
/// gradient for the incoming state is computed.
#[cfg(feature = "cuda")]
pub(crate) struct MambaDepthwiseConvCudaBackwardOutput {
    /// Gradient for the input `x`, allocated as `[batch, views, channels, time]`.
    pub(crate) grad_x: CubeTensor<CudaRuntime>,
    /// Gradient for the convolution weights, `[channels, d_conv]`; created as
    /// zeros on the host and accumulated with atomic adds in the kernel.
    pub(crate) grad_weight: CubeTensor<CudaRuntime>,
    /// Gradient for the convolution bias, `[channels]`; created as zeros on the
    /// host and accumulated with atomic adds in the kernel.
    pub(crate) grad_bias: CubeTensor<CudaRuntime>,
}

/// Launches the fused depthwise convolution + `x * sigmoid(x)` forward kernel.
///
/// # Arguments
/// * `x` - read as a rank-4 tensor `[batch, views, channels, time]`.
/// * `conv_weight` - read as a rank-2 tensor whose second dimension is `d_conv`;
///   the kernel indexes it as `[channel, tap]`.
/// * `conv_bias` - indexed by channel in the kernel.
/// * `state` - indexed as `[batch, view, channel, history]` with
///   `history < d_conv` in the kernel.
///
/// # Returns
/// Freshly allocated `preact` and `activated` tensors with the shape of `x`,
/// and a `next_state` tensor of shape `[batch, views, channels, d_conv]`.
///
/// Only the shapes of `x` and `conv_weight` are inspected here; the shapes of
/// `conv_bias` and `state` are taken on trust.
#[cfg(feature = "cuda")]
pub(crate) fn fused_mamba_depthwise_conv_forward_cuda(
    x: CubeTensor<CudaRuntime>,
    conv_weight: CubeTensor<CudaRuntime>,
    conv_bias: CubeTensor<CudaRuntime>,
    state: CubeTensor<CudaRuntime>,
) -> MambaDepthwiseConvCudaForwardOutput {
    // Normalise every input to a contiguous layout before launching.
    let x = into_contiguous(x);
    let conv_weight = into_contiguous(conv_weight);
    let conv_bias = into_contiguous(conv_bias);
    let state = into_contiguous(state);

    let [batch, views, channels, time] = x.meta.shape.dims::<4>();
    let [_, d_conv] = conv_weight.meta.shape.dims::<2>();
    let client = x.client.clone();
    let device = x.device.clone();

    // Outputs are uninitialised buffers; the kernel is expected to fill them.
    let alloc_output = |dims: [usize; 4]| {
        empty_device::<CudaRuntime, f32>(client.clone(), device.clone(), Shape::new(dims))
    };
    let preact = alloc_output([batch, views, channels, time]);
    let activated = alloc_output([batch, views, channels, time]);
    let next_state = alloc_output([batch, views, channels, d_conv]);

    // Dimensions travel to the kernel as f32 values (exact only up to 2^24).
    let dims_as_f32 = [batch, views, channels, time, d_conv].map(|dim| dim as f32);
    let params = params_tensor(&device, dims_as_f32)
        .into_primitive()
        .tensor();

    // X must cover both the `time` outputs and the `d_conv` next-state slots.
    let cubes_x = div_ceil_u32(time.max(d_conv) as u32, MAMBA_CONV_CUDA_WORKGROUP_X);
    let cube_count = CubeCount::Static(cubes_x, channels as u32, (batch * views) as u32);
    let cube_dim = CubeDim::new_1d(MAMBA_CONV_CUDA_WORKGROUP_X);

    // SAFETY (review note): the kernel guards `pos` against `time` and `d_conv`
    // and the Y/Z cube indices against the dimensions in `params`; it relies on
    // `conv_bias` and `state` having shapes that match `x` / `conv_weight`,
    // which this function does not verify.
    unsafe {
        let _ = mamba_depthwise_conv_forward_cuda_kernel::launch_unchecked::<CudaRuntime>(
            &client,
            cube_count,
            cube_dim,
            x.clone().into_tensor_arg(),
            conv_weight.clone().into_tensor_arg(),
            conv_bias.clone().into_tensor_arg(),
            state.clone().into_tensor_arg(),
            preact.clone().into_tensor_arg(),
            activated.clone().into_tensor_arg(),
            next_state.clone().into_tensor_arg(),
            params.clone().into_tensor_arg(),
        );
    }

    MambaDepthwiseConvCudaForwardOutput {
        preact,
        activated,
        next_state,
    }
}

/// Launches the backward kernel for the fused depthwise convolution.
///
/// # Arguments
/// * `x` - forward input, read as a rank-4 tensor `[batch, views, channels, time]`.
/// * `conv_weight` - read as a rank-2 tensor whose second dimension is `d_conv`.
/// * `state` - history tensor that was fed to the forward pass; indexed as
///   `[batch, view, channel, history]` in the kernel.
/// * `grad_preact` - gradient with respect to the pre-activation output,
///   indexed like `x` in the kernel.
///
/// # Returns
/// `grad_x` with the shape of `x`, `grad_weight` of shape `[channels, d_conv]`
/// and `grad_bias` of shape `[channels]`. No gradient for `state` is produced.
///
/// Only the shapes of `x` and `conv_weight` are inspected here; the shapes of
/// `state` and `grad_preact` are taken on trust.
#[cfg(feature = "cuda")]
pub(crate) fn fused_mamba_depthwise_conv_backward_cuda(
    x: CubeTensor<CudaRuntime>,
    conv_weight: CubeTensor<CudaRuntime>,
    state: CubeTensor<CudaRuntime>,
    grad_preact: CubeTensor<CudaRuntime>,
) -> MambaDepthwiseConvCudaBackwardOutput {
    // Normalise every input to a contiguous layout before launching.
    let x = into_contiguous(x);
    let conv_weight = into_contiguous(conv_weight);
    let state = into_contiguous(state);
    let grad_preact = into_contiguous(grad_preact);

    let [batch, views, channels, time] = x.meta.shape.dims::<4>();
    let [_, d_conv] = conv_weight.meta.shape.dims::<2>();
    let client = x.client.clone();
    let device = x.device.clone();

    // `grad_x` is written once per element, so it can start uninitialised.
    let grad_x_shape = Shape::new([batch, views, channels, time]);
    let grad_x = empty_device::<CudaRuntime, f32>(client.clone(), device.clone(), grad_x_shape);

    // The weight and bias gradients are accumulated atomically and therefore
    // have to start from zero.
    let grad_weight = BurnTensor::<CudaCubeBackend, 2>::zeros([channels, d_conv], &device)
        .into_primitive()
        .tensor();
    let grad_bias = BurnTensor::<CudaCubeBackend, 1>::zeros([channels], &device)
        .into_primitive()
        .tensor();

    // Dimensions travel to the kernel as f32 values (exact only up to 2^24).
    let dims_as_f32 = [batch, views, channels, time, d_conv].map(|dim| dim as f32);
    let params = params_tensor(&device, dims_as_f32)
        .into_primitive()
        .tensor();

    // Same launch geometry as the forward pass: X spans max(time, d_conv).
    let cubes_x = div_ceil_u32(time.max(d_conv) as u32, MAMBA_CONV_CUDA_WORKGROUP_X);
    let cube_count = CubeCount::Static(cubes_x, channels as u32, (batch * views) as u32);
    let cube_dim = CubeDim::new_1d(MAMBA_CONV_CUDA_WORKGROUP_X);

    // SAFETY (review note): the kernel guards `pos` against `time` and the Y/Z
    // cube indices against the dimensions in `params`; it relies on `state` and
    // `grad_preact` having shapes that match `x` / `conv_weight`, which this
    // function does not verify.
    unsafe {
        let _ = mamba_depthwise_conv_backward_cuda_kernel::launch_unchecked::<CudaRuntime>(
            &client,
            cube_count,
            cube_dim,
            x.clone().into_tensor_arg(),
            conv_weight.clone().into_tensor_arg(),
            state.clone().into_tensor_arg(),
            grad_preact.clone().into_tensor_arg(),
            grad_x.clone().into_tensor_arg(),
            grad_weight.clone().into_tensor_arg(),
            grad_bias.clone().into_tensor_arg(),
            params.clone().into_tensor_arg(),
        );
    }

    MambaDepthwiseConvCudaBackwardOutput {
        grad_x,
        grad_weight,
        grad_bias,
    }
}

/// Uploads the kernel parameters as a rank-1 `f32` tensor of length
/// `PARAMS_LEN` on `device`.
///
/// The values are stored in the order given; both kernels read them back by
/// position.
#[cfg(feature = "cuda")]
fn params_tensor(
    device: &<CudaCubeBackend as burn::tensor::backend::Backend>::Device,
    values: [f32; PARAMS_LEN],
) -> BurnTensor<CudaCubeBackend, 1> {
    let host_data = TensorData::new(Vec::from(values), [PARAMS_LEN]);
    BurnTensor::<CudaCubeBackend, 1>::from_data(host_data, device)
}

/// Forward kernel: per-channel convolution over a rolling history, followed by
/// `acc * sigmoid(acc)`, plus the history update for the next call.
///
/// Each unit handles one `(batch, view, channel, pos)` combination:
/// cube Z selects the flattened `(batch, view)` pair, cube Y the channel and
/// the global X index the position `pos`.
///
/// The history for a channel is treated as the concatenation of `state`
/// (indices `0..d_conv`) and `x` (indices `d_conv..d_conv + time`).
///
/// * `params` holds, in order: batch, views, channels, time, d_conv (as `f32`).
/// * `preact[.., pos]` receives bias + sum over taps of history * weight.
/// * `activated[.., pos]` receives `preact * sigmoid(preact)`.
/// * `next_state[.., pos]` receives the last `d_conv` history entries.
#[cfg(feature = "cuda")]
#[cube(launch_unchecked)]
fn mamba_depthwise_conv_forward_cuda_kernel(
    x: &Tensor<f32>,
    conv_weight: &Tensor<f32>,
    conv_bias: &Tensor<f32>,
    state: &Tensor<f32>,
    preact: &mut Tensor<f32>,
    activated: &mut Tensor<f32>,
    next_state: &mut Tensor<f32>,
    params: &Tensor<f32>,
) {
    // Dimensions arrive as f32 values and are converted back to indices.
    let batch = params[0] as usize;
    let views = params[1] as usize;
    let channels = params[2] as usize;
    let time = params[3] as usize;
    let d_conv = params[4] as usize;

    // Z = flattened (batch, view), Y = channel, global X = position.
    let bv = CUBE_POS_Z as usize;
    let channel = CUBE_POS_Y as usize;
    let pos = (CUBE_POS_X * CUBE_DIM_X + UNIT_POS_X) as usize;
    // Units outside the (batch * views, channels) grid have nothing to do.
    if bv >= batch * views || channel >= channels {
        terminate!();
    }

    let batch_idx = bv / views;
    let view_idx = bv % views;

    // Part 1: convolution output and activation for time step `pos`.
    if pos < time {
        let mut acc = conv_bias[channel * conv_bias.stride(0)];
        let mut tap = 0usize;
        while tap < d_conv {
            // Tap `tap` reads history index pos + 1 + tap, so the window for
            // `pos` spans pos + 1 ..= pos + d_conv and its last entry is x[pos].
            let hist_idx = pos + 1usize + tap;
            let value = if hist_idx < d_conv {
                // Still inside the carried-over state.
                let idx = batch_idx * state.stride(0)
                    + view_idx * state.stride(1)
                    + channel * state.stride(2)
                    + hist_idx * state.stride(3);
                state[idx]
            } else {
                // Past the state: read from the current input instead.
                let x_idx = hist_idx - d_conv;
                let idx = batch_idx * x.stride(0)
                    + view_idx * x.stride(1)
                    + channel * x.stride(2)
                    + x_idx * x.stride(3);
                x[idx]
            };
            let weight_idx = channel * conv_weight.stride(0) + tap * conv_weight.stride(1);
            acc += value * conv_weight[weight_idx];
            tap += 1usize;
        }
        let out_idx = batch_idx * preact.stride(0)
            + view_idx * preact.stride(1)
            + channel * preact.stride(2)
            + pos * preact.stride(3);
        preact[out_idx] = acc;
        // activated = acc * sigmoid(acc)
        let sigmoid = 1.0 / (1.0 + f32::exp(-acc));
        activated[batch_idx * activated.stride(0)
            + view_idx * activated.stride(1)
            + channel * activated.stride(2)
            + pos * activated.stride(3)] = acc * sigmoid;
    }

    // Part 2: slot `pos` of the next state is history entry time + pos, i.e.
    // the next state is the last d_conv entries of the concatenated history.
    if pos < d_conv {
        let hist_idx = time + pos;
        let value = if hist_idx < d_conv {
            // Only reachable when time < d_conv: old state entries shift down.
            let idx = batch_idx * state.stride(0)
                + view_idx * state.stride(1)
                + channel * state.stride(2)
                + hist_idx * state.stride(3);
            state[idx]
        } else {
            let x_idx = hist_idx - d_conv;
            let idx = batch_idx * x.stride(0)
                + view_idx * x.stride(1)
                + channel * x.stride(2)
                + x_idx * x.stride(3);
            x[idx]
        };
        let out_idx = batch_idx * next_state.stride(0)
            + view_idx * next_state.stride(1)
            + channel * next_state.stride(2)
            + pos * next_state.stride(3);
        next_state[out_idx] = value;
    }
}

/// Backward kernel for the depthwise convolution, taking the gradient with
/// respect to the pre-activation output (`grad_preact`).
///
/// Each unit handles one `(batch, view, channel, pos)` combination:
/// cube Z selects the flattened `(batch, view)` pair, cube Y the channel and
/// the global X index the position `pos`.
///
/// * `params` holds, in order: batch, views, channels, time, d_conv (as `f32`).
/// * `grad_bias[channel]` and `grad_weight[channel, tap]` are accumulated with
///   atomic adds from every unit, so they must be zero on entry.
/// * `grad_x[.., pos]` is written exactly once by the unit that owns `pos`.
/// * No gradient is produced for `state`.
#[cfg(feature = "cuda")]
#[cube(launch_unchecked)]
fn mamba_depthwise_conv_backward_cuda_kernel(
    x: &Tensor<f32>,
    conv_weight: &Tensor<f32>,
    state: &Tensor<f32>,
    grad_preact: &Tensor<f32>,
    grad_x: &mut Tensor<f32>,
    grad_weight: &mut Tensor<Atomic<f32>>,
    grad_bias: &mut Tensor<Atomic<f32>>,
    params: &Tensor<f32>,
) {
    // Dimensions arrive as f32 values and are converted back to indices.
    let batch = params[0] as usize;
    let views = params[1] as usize;
    let channels = params[2] as usize;
    let time = params[3] as usize;
    let d_conv = params[4] as usize;

    // Z = flattened (batch, view), Y = channel, global X = position.
    let bv = CUBE_POS_Z as usize;
    let channel = CUBE_POS_Y as usize;
    let pos = (CUBE_POS_X * CUBE_DIM_X + UNIT_POS_X) as usize;
    // Units outside the (batch * views, channels) grid have nothing to do.
    if bv >= batch * views || channel >= channels {
        terminate!();
    }

    let batch_idx = bv / views;
    let view_idx = bv % views;

    // Part 1: contributions of output step `pos` to the bias and weight
    // gradients.
    if pos < time {
        let grad_idx = batch_idx * grad_preact.stride(0)
            + view_idx * grad_preact.stride(1)
            + channel * grad_preact.stride(2)
            + pos * grad_preact.stride(3);
        let grad_val = grad_preact[grad_idx];
        // Bias gradient: sum of grad_preact over batch, view and time.
        grad_bias[channel * grad_bias.stride(0)].fetch_add(grad_val);

        let mut tap = 0usize;
        while tap < d_conv {
            // Re-read the same history value the forward pass multiplied by
            // weight[channel, tap] when producing output `pos`.
            let hist_idx = pos + 1usize + tap;
            let value = if hist_idx < d_conv {
                let idx = batch_idx * state.stride(0)
                    + view_idx * state.stride(1)
                    + channel * state.stride(2)
                    + hist_idx * state.stride(3);
                state[idx]
            } else {
                let x_idx = hist_idx - d_conv;
                let idx = batch_idx * x.stride(0)
                    + view_idx * x.stride(1)
                    + channel * x.stride(2)
                    + x_idx * x.stride(3);
                x[idx]
            };
            let weight_idx = channel * grad_weight.stride(0) + tap * grad_weight.stride(1);
            grad_weight[weight_idx].fetch_add(grad_val * value);
            tap += 1usize;
        }
    }

    // Part 2: gradient for x[pos], gathered from every output that read it.
    if pos < time {
        let mut acc = 0.0;
        let mut tap = 0usize;
        while tap < d_conv {
            // x[pos] was read through tap `tap` by output step
            // pos + d_conv - 1 - tap.
            let out_t = pos + d_conv - 1usize;
            // Keeps the subtraction below from underflowing.
            if out_t >= tap {
                let target_t = out_t - tap;
                // Outputs at or beyond `time` do not exist in this call.
                if target_t < time {
                    let grad_idx = batch_idx * grad_preact.stride(0)
                        + view_idx * grad_preact.stride(1)
                        + channel * grad_preact.stride(2)
                        + target_t * grad_preact.stride(3);
                    let weight_idx = channel * conv_weight.stride(0) + tap * conv_weight.stride(1);
                    acc += grad_preact[grad_idx] * conv_weight[weight_idx];
                }
            }
            tap += 1usize;
        }
        let out_idx = batch_idx * grad_x.stride(0)
            + view_idx * grad_x.stride(1)
            + channel * grad_x.stride(2)
            + pos * grad_x.stride(3);
        grad_x[out_idx] = acc;
    }
}

#[cfg(feature = "cuda")]
/// Divides `value` by `divisor`, rounding the quotient up.
///
/// # Panics
/// Panics if `divisor` is zero.
fn div_ceil_u32(value: u32, divisor: u32) -> u32 {
    let quotient = value / divisor;
    // A non-zero remainder means one more partial group is needed; the
    // increment cannot overflow because the quotient is then below u32::MAX.
    if value % divisor == 0 {
        quotient
    } else {
        quotient + 1
    }
}