// burn_dragon_kernel 0.5.0
//
// Fused GPU kernel crate for burn_dragon execution paths (see the crate documentation).
#![cfg_attr(not(feature = "cuda"), allow(unused))]

#[cfg(feature = "cuda")]
use std::time::Instant;

#[cfg(feature = "cuda")]
use burn::tensor::backend::Backend as BackendTrait;
#[cfg(feature = "cuda")]
use burn::tensor::{Distribution, Tensor};
#[cfg(feature = "cuda")]
use burn_autodiff::Autodiff;
#[cfg(feature = "cuda")]
use burn_cubecl::cubecl::cuda::CudaRuntime;
#[cfg(feature = "cuda")]
use burn_dragon_kernel::kernels::sequence::rwkv8::forward::{
    tensorized_rwkv8_forward, tensorized_rwkv8_forward_direct_graph,
};
#[cfg(feature = "cuda")]
use burn_wgpu::CubeBackend;
#[cfg(feature = "cuda")]
use serde::Serialize;

// Concrete CubeCL backend on the CUDA runtime. The trailing generic arguments are presumably the
// float / int / bool element types (f32, i32, u8) — confirm against the pinned burn version.
#[cfg(feature = "cuda")]
type Backend = CubeBackend<CudaRuntime, f32, i32, u8>;
// Autodiff decorator over `Backend`; used for the gradient (backward) timings.
#[cfg(feature = "cuda")]
type AutodiffBackend = Autodiff<Backend>;
// Device handle. The same value is passed to both `Backend` and `AutodiffBackend` calls.
#[cfg(feature = "cuda")]
type Device = <Backend as BackendTrait>::Device;

/// Scoped override of one process environment variable.
///
/// [`EnvVarGuard::set`] snapshots the variable's current value and installs a new one. Dropping
/// the guard puts the snapshot back, or removes the variable if it was previously unset. The
/// snapshot is taken with `var_os`, so a set-but-non-Unicode previous value is restored rather
/// than deleted.
#[cfg(feature = "cuda")]
#[must_use = "the previous value is restored as soon as the guard is dropped"]
struct EnvVarGuard {
    /// Name of the overridden variable.
    key: &'static str,
    /// Value before the override; `None` when the variable was not set at all.
    previous: Option<std::ffi::OsString>,
}

#[cfg(feature = "cuda")]
impl EnvVarGuard {
    /// Sets `key` to `value` for the lifetime of the returned guard.
    ///
    /// Precondition (not enforced by the type system): no other thread may read or write the
    /// process environment while this runs, or while the guard is later dropped.
    fn set(key: &'static str, value: &str) -> Self {
        let previous = std::env::var_os(key);
        // SAFETY: `set_var` races with concurrent environment access from other threads. The
        // caller must invoke this before any such threads exist (see the precondition above).
        unsafe { std::env::set_var(key, value) };
        Self { key, previous }
    }
}

#[cfg(feature = "cuda")]
impl Drop for EnvVarGuard {
    fn drop(&mut self) {
        // SAFETY: same single-threaded precondition as `EnvVarGuard::set`.
        // NOTE(review): in `main` this runs after the CUDA backend has been initialised. Confirm
        // that no runtime threads touch the environment at that point.
        match &self.previous {
            Some(previous) => unsafe { std::env::set_var(self.key, previous) },
            None => unsafe { std::env::remove_var(self.key) },
        }
    }
}

/// One benchmark problem size.
///
/// `run_case` uses these dimensions to shape its inputs as:
/// - query `[batch, heads, time, latent]`
/// - value `[batch, 1, time, embd]`
/// - decay `[1, heads, latent]`
///
/// Field order is also the JSON field order in the report.
#[cfg(feature = "cuda")]
#[derive(Clone, Copy, Serialize)]
struct BenchCase {
    /// Label echoed into the JSON report.
    name: &'static str,
    /// Batch size.
    batch: usize,
    /// Head count of the query and decay tensors.
    heads: usize,
    /// Sequence length.
    time: usize,
    /// Trailing width of the query and decay tensors.
    latent: usize,
    /// Trailing width of the value tensor.
    embd: usize,
}

/// Timings and numerical agreement for one [`BenchCase`].
///
/// All `*_ms` values are wall-clock milliseconds averaged over `repetitions` timed runs.
/// Field order is also the JSON field order in the report.
#[cfg(feature = "cuda")]
#[derive(Serialize)]
struct BenchResult {
    /// The problem size that was measured.
    case: BenchCase,
    /// Untimed rounds run before measuring.
    warmup: usize,
    /// Timed rounds averaged into each `*_ms` value.
    repetitions: usize,
    /// Forward pass via `tensorized_rwkv8_forward_direct_graph` on the plain backend.
    direct_graph_forward_ms: f64,
    /// Forward pass via `tensorized_rwkv8_forward` on the plain backend.
    analytic_wrapper_forward_ms: f64,
    /// `direct_graph_forward_ms / analytic_wrapper_forward_ms` (denominator clamped to EPSILON);
    /// a value above 1 means the analytic wrapper was faster.
    forward_speedup_x: f64,
    /// Direct-graph forward + weighted-sum loss + backward on the autodiff backend.
    direct_graph_backward_ms: f64,
    /// Analytic-wrapper forward + weighted-sum loss + backward on the autodiff backend.
    analytic_wrapper_backward_ms: f64,
    /// `direct_graph_backward_ms / analytic_wrapper_backward_ms` (denominator clamped to EPSILON).
    backward_speedup_x: f64,
    /// Max |direct - analytic| over the forward `context` outputs.
    context_max_abs: f64,
    /// Max |direct - analytic| over the forward `rho` outputs.
    rho_max_abs: f64,
    /// Max |direct - analytic| over the forward `rho_norm` outputs.
    rho_norm_max_abs: f64,
}

/// Top-level JSON document printed to stdout by `main`.
#[cfg(feature = "cuda")]
#[derive(Serialize)]
struct BenchReport {
    /// Human-readable benchmark title.
    benchmark: &'static str,
    /// Selected profile: `"compact"` or `"full"`.
    profile: &'static str,
    /// One entry per case, in case-table order.
    results: Vec<BenchResult>,
}

/// Single small case used by the default `compact` profile, i.e. whenever
/// `BURN_DRAGON_BENCH_PROFILE` is not exactly `full`.
#[cfg(feature = "cuda")]
const COMPACT_CASES: &[BenchCase] = &[BenchCase {
    name: "cuda_b1_h2_t16_l16_e32",
    batch: 1,
    heads: 2,
    time: 16,
    latent: 16,
    embd: 32,
}];

/// Larger cases used when `BURN_DRAGON_BENCH_PROFILE=full`. Results appear in the report in this
/// order.
#[cfg(feature = "cuda")]
const FULL_CASES: &[BenchCase] = &[
    BenchCase {
        name: "cuda_b1_h4_t64_l32_e64",
        batch: 1,
        heads: 4,
        time: 64,
        latent: 32,
        embd: 64,
    },
    BenchCase {
        name: "cuda_b1_h4_t128_l64_e128",
        batch: 1,
        heads: 4,
        time: 128,
        latent: 64,
        embd: 128,
    },
];

/// Benchmark entry point for CUDA builds.
///
/// Steps:
/// 1. Forces `BURN_DRAGON_RWKV8_TENSORIZED_TRAIN_WRAPPER=1` for the whole run.
/// 2. Seeds the backend.
/// 3. Runs every case of the selected profile.
/// 4. Prints the report as pretty JSON on stdout.
#[cfg(feature = "cuda")]
fn main() {
    // Held (not `_`) so the override stays in place until `main` returns.
    let _wrapper_env = EnvVarGuard::set("BURN_DRAGON_RWKV8_TENSORIZED_TRAIN_WRAPPER", "1");

    let device = Device::default();
    <Backend as BackendTrait>::seed(&device, 2026);

    let profile = bench_profile();
    let warmup = 1;
    let repetitions = match profile {
        "full" => 3,
        _ => 2,
    };

    let mut results = Vec::new();
    for &case in bench_cases(profile) {
        results.push(run_case(case, &device, warmup, repetitions));
    }

    let report = BenchReport {
        benchmark: "burn_dragon_kernel rwkv8 cuda direct-graph vs analytical wrapper",
        profile,
        results,
    };
    let json = serde_json::to_string_pretty(&report).expect("serialize rwkv8 cuda bench report");
    println!("{json}");
}

/// Entry point for builds without the `cuda` feature.
///
/// There is no non-CUDA implementation to benchmark, so this fails loudly instead of printing an
/// empty report.
#[cfg(not(feature = "cuda"))]
fn main() {
    let reason = "rwkv8_cuda_bench requires --features cuda";
    panic!("{reason}");
}

/// Reads `BURN_DRAGON_BENCH_PROFILE` and normalises it to `"full"` or `"compact"`.
///
/// Only the exact value `full` selects the full profile. Any other value, an unset variable, or
/// a non-Unicode value falls back to `"compact"`.
#[cfg(feature = "cuda")]
fn bench_profile() -> &'static str {
    let requested = std::env::var("BURN_DRAGON_BENCH_PROFILE");
    if matches!(requested.as_deref(), Ok("full")) {
        "full"
    } else {
        "compact"
    }
}

/// Maps a profile name to its case table: `"full"` → [`FULL_CASES`]; anything else →
/// [`COMPACT_CASES`].
#[cfg(feature = "cuda")]
fn bench_cases(profile: &'static str) -> &'static [BenchCase] {
    match profile {
        "full" => FULL_CASES,
        _ => COMPACT_CASES,
    }
}

/// Benchmarks one problem size and compares the two RWKV8 code paths.
///
/// Timings:
/// - Forward timings run on the plain backend.
/// - `*_backward_ms` timings run on the autodiff backend and cover building the forward graph,
///   the weighted-sum loss, and the backward pass — not the backward pass alone.
/// - Each timing is the mean wall-clock milliseconds over `repetitions` runs (divided by at least
///   1), bracketed by device syncs and preceded by `warmup` untimed rounds.
///
/// Agreement: finally computes one more forward pass per path on the plain-backend inputs and
/// reports the max absolute difference of `context`, `rho` and `rho_norm`.
#[cfg(feature = "cuda")]
fn run_case(case: BenchCase, device: &Device, warmup: usize, repetitions: usize) -> BenchResult {
    let query = Tensor::<Backend, 4>::random(
        [case.batch, case.heads, case.time, case.latent],
        Distribution::Uniform(0.0, 1.0),
        device,
    );
    let value = Tensor::<Backend, 4>::random(
        [case.batch, 1, case.time, case.embd],
        Distribution::Uniform(0.0, 1.0),
        device,
    );
    let decay = Tensor::<Backend, 3>::random(
        [1, case.heads, case.latent],
        Distribution::Uniform(0.75, 0.99),
        device,
    );

    let (direct_graph_forward_ms, analytic_wrapper_forward_ms) = time_alternating::<Backend>(
        device,
        warmup,
        repetitions,
        || {
            let _ = tensorized_rwkv8_forward_direct_graph(
                query.clone(),
                value.clone(),
                None,
                None,
                decay.clone(),
            );
        },
        || {
            let _ =
                tensorized_rwkv8_forward(query.clone(), value.clone(), None, None, decay.clone());
        },
    );

    // Autodiff inputs are generated after the forward timings, keeping the original order of
    // random-tensor creation from the seeded backend.
    let query_ad = Tensor::<AutodiffBackend, 4>::random(
        [case.batch, case.heads, case.time, case.latent],
        Distribution::Uniform(0.0, 1.0),
        device,
    )
    .require_grad();
    let value_ad = Tensor::<AutodiffBackend, 4>::random(
        [case.batch, 1, case.time, case.embd],
        Distribution::Uniform(0.0, 1.0),
        device,
    )
    .require_grad();
    let decay_ad = Tensor::<AutodiffBackend, 3>::random(
        [1, case.heads, case.latent],
        Distribution::Uniform(0.75, 0.99),
        device,
    )
    .require_grad();
    // Fixed loss weights; they do not require gradients.
    let weights = Tensor::<AutodiffBackend, 4>::random(
        [case.batch, case.heads, case.time, case.embd],
        Distribution::Uniform(-0.25, 0.25),
        device,
    );

    let (direct_graph_backward_ms, analytic_wrapper_backward_ms) =
        time_alternating::<AutodiffBackend>(
            device,
            warmup,
            repetitions,
            || {
                let output = tensorized_rwkv8_forward_direct_graph(
                    query_ad.clone(),
                    value_ad.clone(),
                    None,
                    None,
                    decay_ad.clone(),
                );
                backward_weighted_sum(output.context, &weights);
            },
            || {
                let output = tensorized_rwkv8_forward(
                    query_ad.clone(),
                    value_ad.clone(),
                    None,
                    None,
                    decay_ad.clone(),
                );
                backward_weighted_sum(output.context, &weights);
            },
        );

    let direct_graph_output = tensorized_rwkv8_forward_direct_graph(
        query.clone(),
        value.clone(),
        None,
        None,
        decay.clone(),
    );
    let analytic_wrapper_output = tensorized_rwkv8_forward(query, value, None, None, decay);

    BenchResult {
        case,
        warmup,
        repetitions,
        direct_graph_forward_ms,
        analytic_wrapper_forward_ms,
        forward_speedup_x: direct_graph_forward_ms / analytic_wrapper_forward_ms.max(f64::EPSILON),
        direct_graph_backward_ms,
        analytic_wrapper_backward_ms,
        backward_speedup_x: direct_graph_backward_ms
            / analytic_wrapper_backward_ms.max(f64::EPSILON),
        context_max_abs: max_abs_4(direct_graph_output.context, analytic_wrapper_output.context),
        rho_max_abs: max_abs_4(direct_graph_output.rho, analytic_wrapper_output.rho),
        rho_norm_max_abs: max_abs_3(
            direct_graph_output.rho_norm,
            analytic_wrapper_output.rho_norm,
        ),
    }
}

/// Times two workloads head-to-head on backend `B`.
///
/// Runs `warmup` untimed rounds of `first` then `second`, syncing after each. It then runs
/// `repetitions` timed rounds in the same alternating order. Returns the mean wall-clock
/// milliseconds of `(first, second)`, dividing by at least 1, so zero repetitions yield
/// `(0.0, 0.0)`.
#[cfg(feature = "cuda")]
fn time_alternating<B: BackendTrait>(
    device: &B::Device,
    warmup: usize,
    repetitions: usize,
    mut first: impl FnMut(),
    mut second: impl FnMut(),
) -> (f64, f64) {
    for _ in 0..warmup {
        first();
        let _ = B::sync(device);
        second();
        let _ = B::sync(device);
    }

    let mut first_total_ms = 0.0;
    let mut second_total_ms = 0.0;
    for _ in 0..repetitions {
        first_total_ms += synced_ms::<B>(device, &mut first);
        second_total_ms += synced_ms::<B>(device, &mut second);
    }

    let runs = repetitions.max(1) as f64;
    (first_total_ms / runs, second_total_ms / runs)
}

/// Wall-clock milliseconds for one call of `run`.
///
/// The device is synced before starting the clock, so earlier queued work is excluded. It is
/// synced again before stopping, so the work `run` enqueued is included.
#[cfg(feature = "cuda")]
fn synced_ms<B: BackendTrait>(device: &B::Device, mut run: impl FnMut()) -> f64 {
    let _ = B::sync(device);
    let started = Instant::now();
    run();
    let _ = B::sync(device);
    started.elapsed().as_secs_f64() * 1_000.0
}

/// Reduces `context` to a scalar loss `sum(context * weights)` and runs the backward pass. The
/// returned gradients are discarded.
#[cfg(feature = "cuda")]
fn backward_weighted_sum(
    context: Tensor<AutodiffBackend, 4>,
    weights: &Tensor<AutodiffBackend, 4>,
) {
    let loss = (context * weights.clone()).sum();
    let _ = loss.backward();
}

/// Largest absolute element-wise difference between two tensors of the same rank, as `f64`.
///
/// Consumes both tensors and reads the reduced result back via `into_scalar`.
#[cfg(feature = "cuda")]
fn max_abs_diff<const D: usize>(lhs: Tensor<Backend, D>, rhs: Tensor<Backend, D>) -> f64 {
    lhs.sub(rhs).abs().max().into_scalar() as f64
}

/// Rank-4 form of [`max_abs_diff`]; `run_case` uses it for the `context` and `rho` outputs.
#[cfg(feature = "cuda")]
fn max_abs_4(lhs: Tensor<Backend, 4>, rhs: Tensor<Backend, 4>) -> f64 {
    max_abs_diff(lhs, rhs)
}

/// Rank-3 form of [`max_abs_diff`]; `run_case` uses it for the `rho_norm` outputs.
#[cfg(feature = "cuda")]
fn max_abs_3(lhs: Tensor<Backend, 3>, rhs: Tensor<Backend, 3>) -> f64 {
    max_abs_diff(lhs, rhs)
}