//! burn_dragon_kernel 0.5.0
//!
//! Fused GPU kernel crate for `burn_dragon` execution paths.
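//!
//! This module tests the fused dense causal attention kernel against the
//! unfused reference from the parent module: forward values and autodiff
//! gradients on the wgpu backend, plus feature-gated CUDA coverage.
//!
//! A minimal call sketch (names come in via `super::*`; the shapes appear to
//! be `[batch, heads, seq, dim]`, inferred from the tests below):
//!
//! ```ignore
//! let out = try_fused_dense_causal_attention_wgpu::<Backend>(&query, &value, &decay)
//!     .expect("fused dense causal attention");
//! ```
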
use super::*;
use burn::tensor::backend::Backend as BackendTrait;
use burn::tensor::{Distribution, Tensor, TensorData};
use burn_autodiff::Autodiff;
#[cfg(feature = "cuda")]
use burn_cuda::Cuda;
use burn_wgpu::{graphics, CubeBackend, RuntimeOptions, WgpuRuntime};

type Backend = CubeBackend<WgpuRuntime, f32, i32, u32>;
type AutodiffBackendImpl = Autodiff<Backend>;

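// Initialize the shared wgpu runtime exactly once for the whole test binary;
// `AutoGraphicsApi` lets wgpu choose a suitable graphics API for the host.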
fn init_runtime(device: &<Backend as BackendTrait>::Device) {
    static INIT: std::sync::Once = std::sync::Once::new();
    INIT.call_once(|| {
        burn_wgpu::init_setup::<graphics::AutoGraphicsApi>(device, RuntimeOptions::default());
    });
}

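// Element-wise comparison with combined tolerances: a pair (a, b) passes when
// |a - b| <= atol + rtol * |b|. On failure, the worst offending pair is
// reported to ease debugging.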
fn assert_close_backend<B: BackendTrait, const D: usize>(
    lhs: Tensor<B, D>,
    rhs: Tensor<B, D>,
    atol: f32,
    rtol: f32,
) {
    let lhs_data = lhs
        .to_data()
        .convert::<f32>()
        .into_vec::<f32>()
        .expect("lhs vec");
    let rhs_data = rhs
        .to_data()
        .convert::<f32>()
        .into_vec::<f32>()
        .expect("rhs vec");
    let mut max_diff = 0.0_f32;
    let mut max_tol = 0.0_f32;
    let mut max_lhs = 0.0_f32;
    let mut max_rhs = 0.0_f32;
    for (a, b) in lhs_data.iter().zip(rhs_data.iter()) {
        let diff = (a - b).abs();
        let tol = atol + rtol * b.abs();
        if diff > max_diff {
            max_diff = diff;
            max_tol = tol;
            max_lhs = *a;
            max_rhs = *b;
        }
    }
    assert!(
        max_diff <= max_tol,
        "max difference {max_diff} exceeds tolerance {max_tol} (lhs={max_lhs}, rhs={max_rhs})"
    );
}

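// Convenience wrapper pinning the backend and rank used by the wgpu tests.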
fn assert_close(lhs: Tensor<Backend, 4>, rhs: Tensor<Backend, 4>, atol: f32, rtol: f32) {
    assert_close_backend(lhs, rhs, atol, rtol);
}

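// Thin alias over the unfused reference from `super` so the comparisons below
// read symmetrically against the fused path.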
fn reference_attention(
    query: Tensor<Backend, 4>,
    value: Tensor<Backend, 4>,
    decay: Tensor<Backend, 1>,
) -> Tensor<Backend, 4> {
    dense_causal_attention_reference(query, value, decay)
}

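// Forward check on wgpu: random inputs with four query heads, a single value
// head shared across them, and one decay rate per query head.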
#[test]
fn dense_causal_attention_matches_reference_on_wgpu() {
    let device = <Backend as BackendTrait>::Device::default();
    init_runtime(&device);
    <Backend as BackendTrait>::seed(&device, 17);

    let query =
        Tensor::<Backend, 4>::random([2, 4, 16, 32], Distribution::Normal(0.0, 1.0), &device);
    let value =
        Tensor::<Backend, 4>::random([2, 1, 16, 24], Distribution::Normal(0.0, 1.0), &device);
    let decay = Tensor::<Backend, 1>::from_floats([0.97, 0.93, 0.89, 0.85], &device);

    let fused = try_fused_dense_causal_attention_wgpu::<Backend>(&query, &value, &decay)
        .expect("wgpu dense causal attention");
    let expected = reference_attention(query, value, decay);
    assert_close(fused, expected, 2e-4, 2e-4);
}

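// Gradient check on wgpu: reduce the fused and reference outputs to scalars
// through the same fixed weights, backprop, and compare the query, value,
// and decay gradients.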
#[test]
fn dense_causal_attention_matches_reference_gradients_on_wgpu_autodiff() {
    let device = <AutodiffBackendImpl as BackendTrait>::Device::default();
    init_runtime(&device);

    let query = Tensor::<AutodiffBackendImpl, 4>::from_data(
        TensorData::new(
            (0..24).map(|i| (i as f32) * 0.03 - 0.2).collect(),
            [1, 2, 3, 4],
        ),
        &device,
    )
    .require_grad();
    let value = Tensor::<AutodiffBackendImpl, 4>::from_data(
        TensorData::new(
            (0..18).map(|i| (i as f32) * 0.05 - 0.15).collect(),
            [1, 1, 3, 6],
        ),
        &device,
    )
    .require_grad();
    let decay =
        Tensor::<AutodiffBackendImpl, 1>::from_data(TensorData::new(vec![0.95, 0.9], [2]), &device)
            .require_grad();
    let weights = Tensor::<AutodiffBackendImpl, 4>::from_data(
        TensorData::new(
            (0..36).map(|i| (i as f32) * 0.02 - 0.1).collect(),
            [1, 2, 3, 6],
        ),
        &device,
    );

    let fused =
        try_fused_dense_causal_attention_wgpu::<AutodiffBackendImpl>(&query, &value, &decay)
            .expect("wgpu dense causal attention autodiff");
    let reference = dense_causal_attention_reference(query.clone(), value.clone(), decay.clone());

    let fused_grads = (fused * weights.clone()).sum().backward();
    let reference_grads = (reference * weights).sum().backward();

    assert_close_backend(
        query.grad(&fused_grads).expect("fused query grad"),
        query.grad(&reference_grads).expect("reference query grad"),
        5e-3,
        5e-3,
    );
    assert_close_backend(
        value.grad(&fused_grads).expect("fused value grad"),
        value.grad(&reference_grads).expect("reference value grad"),
        5e-3,
        5e-3,
    );
    assert_close_backend(
        decay.grad(&fused_grads).expect("fused decay grad"),
        decay.grad(&reference_grads).expect("reference decay grad"),
        5e-3,
        5e-3,
    );
}

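// Everything below compiles only with the `cuda` feature enabled.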
#[cfg(feature = "cuda")]
#[test]
fn dense_causal_attention_supports_cuda_backend_types() {
    type CudaBackend = Cuda<f32, i32>;
    type CudaAutodiffBackend = Autodiff<CudaBackend>;

    assert!(supports_dense_causal_attention_backend::<CudaBackend>());
    assert!(supports_dense_causal_attention_backend::<CudaAutodiffBackend>());
}

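// Forward check on CUDA, mirroring the wgpu test with smaller shapes and
// looser tolerances. The `_wgpu`-named entry point is reused here; judging by
// these tests it dispatches on the backend type rather than a fixed runtime.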
#[cfg(feature = "cuda")]
#[test]
fn dense_causal_attention_matches_reference_on_cuda() {
    type CudaBackend = Cuda<f32, i32>;
    let device = <CudaBackend as BackendTrait>::Device::default();
    <CudaBackend as BackendTrait>::seed(&device, 17);

    let query =
        Tensor::<CudaBackend, 4>::random([1, 2, 8, 16], Distribution::Normal(0.0, 1.0), &device);
    let value =
        Tensor::<CudaBackend, 4>::random([1, 1, 8, 12], Distribution::Normal(0.0, 1.0), &device);
    let decay = Tensor::<CudaBackend, 1>::from_floats([0.97, 0.93], &device);

    let fused = try_fused_dense_causal_attention_wgpu::<CudaBackend>(&query, &value, &decay)
        .expect("cuda dense causal attention");
    let expected = dense_causal_attention_reference(query, value, decay);
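    // Wait for queued device work to finish before reading results back.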
    let _ = <CudaBackend as BackendTrait>::sync(&device);
    assert_close_backend(fused, expected, 2e-2, 2e-2);
}

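// Gradient check on CUDA, mirroring the wgpu autodiff test with looser
// tolerances.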
#[cfg(feature = "cuda")]
#[test]
fn dense_causal_attention_matches_reference_gradients_on_cuda_autodiff() {
    type CudaBackend = Cuda<f32, i32>;
    type CudaAutodiffBackend = Autodiff<CudaBackend>;

    let device = <CudaAutodiffBackend as BackendTrait>::Device::default();

    let query = Tensor::<CudaAutodiffBackend, 4>::from_data(
        TensorData::new(
            (0..24).map(|i| (i as f32) * 0.03 - 0.2).collect(),
            [1, 2, 3, 4],
        ),
        &device,
    )
    .require_grad();
    let value = Tensor::<CudaAutodiffBackend, 4>::from_data(
        TensorData::new(
            (0..18).map(|i| (i as f32) * 0.05 - 0.15).collect(),
            [1, 1, 3, 6],
        ),
        &device,
    )
    .require_grad();
    let decay =
        Tensor::<CudaAutodiffBackend, 1>::from_data(TensorData::new(vec![0.95, 0.9], [2]), &device)
            .require_grad();
    let weights = Tensor::<CudaAutodiffBackend, 4>::from_data(
        TensorData::new(
            (0..36).map(|i| (i as f32) * 0.02 - 0.1).collect(),
            [1, 2, 3, 6],
        ),
        &device,
    );

    let fused =
        try_fused_dense_causal_attention_wgpu::<CudaAutodiffBackend>(&query, &value, &decay)
            .expect("cuda dense causal attention autodiff");
    let reference = dense_causal_attention_reference(query.clone(), value.clone(), decay.clone());

    let fused_grads = (fused * weights.clone()).sum().backward();
    let reference_grads = (reference * weights).sum().backward();

    assert_close_backend(
        query.grad(&fused_grads).expect("fused query grad"),
        query.grad(&reference_grads).expect("reference query grad"),
        3e-2,
        3e-2,
    );
    assert_close_backend(
        value.grad(&fused_grads).expect("fused value grad"),
        value.grad(&reference_grads).expect("reference value grad"),
        3e-2,
        3e-2,
    );
    assert_close_backend(
        decay.grad(&fused_grads).expect("fused decay grad"),
        decay.grad(&reference_grads).expect("reference decay grad"),
        3e-2,
        3e-2,
    );
}