//! burn_dragon_kernel 0.5.0
//!
//! Fused GPU kernel crate for the burn_dragon execution paths. See the crate
//! documentation for details.
/// Fallback entry point compiled when the `cuda` feature is disabled: this
/// benchmark cannot run without the CUDA backend, so report the required
/// feature flag and exit with a non-zero status.
#[cfg(not(feature = "cuda"))]
fn main() {
    // Non-zero status distinguishes "feature missing" from a successful run.
    const MISSING_FEATURE_EXIT: i32 = 1;
    // Keep the message on stderr so stdout stays reserved for report output.
    eprintln!("recurrent_cuda_bench requires --features cuda");
    std::process::exit(MISSING_FEATURE_EXIT);
}

#[cfg(feature = "cuda")]
mod app {
    use std::fs;
    use std::path::{Path, PathBuf};
    use std::time::Instant;

    use burn::tensor::backend::Backend as BackendTrait;
    use burn::tensor::{Tensor, TensorData};
    use burn_cubecl::cubecl::Runtime;
    use burn_cubecl::cubecl::cuda::CudaRuntime;
    use burn_cuda::Cuda;
    use burn_dragon_kernel::api::recurrent::{
        supports_recurrent_backend, try_fused_recurrent_attention_wgpu,
    };
    use serde::Serialize;

    type Backend = Cuda<f32, i32>;
    type Device = <Backend as BackendTrait>::Device;

    /// One benchmark configuration: the tensor extents fed to both the
    /// reference and the fused recurrent-attention paths.
    #[derive(Clone, Copy, Serialize)]
    struct BenchCase {
        /// Stable identifier used in report rows and summaries.
        name: &'static str,
        /// Batch size (first tensor axis).
        batch: usize,
        /// Number of query/state heads.
        heads: usize,
        /// Number of value heads; when 1, the value tensor is repeated across
        /// `heads` by the reference implementation.
        value_heads: usize,
        /// Sequence length (number of recurrent time steps).
        time: usize,
        /// Latent width of the query and the recurrent state.
        latent: usize,
        /// Embedding width of the value/context tensors.
        embd: usize,
    }

    /// CUDA allocator counters captured from the CubeCL runtime client
    /// (see `memory_snapshot`).
    #[derive(Clone, Copy, Serialize)]
    struct MemorySnapshot {
        /// `bytes_reserved` as reported by the runtime.
        reserved: u64,
        /// `bytes_in_use` as reported by the runtime.
        in_use: u64,
    }

    /// Elementwise difference statistics between the reference and fused
    /// outputs, accumulated in f64 (see `compare_tensors`).
    #[derive(Clone, Copy, Serialize)]
    struct ErrorMetrics {
        /// Largest absolute difference across all elements.
        max_abs: f64,
        /// Mean absolute difference.
        mean_abs: f64,
        /// Root-mean-square error.
        rmse: f64,
    }

    /// Full result record for one benchmark case.
    #[derive(Clone, Serialize)]
    struct CaseResult {
        /// The configuration this result was measured with.
        case: BenchCase,
        /// Untimed warmup iterations per implementation.
        warmup: usize,
        /// Timed iterations the mean is computed over.
        repetitions: usize,
        /// Mean wall-clock time per repetition for the reference path.
        reference_ms: f64,
        /// Mean wall-clock time per repetition for the fused kernel.
        fused_ms: f64,
        /// `reference_ms / fused_ms`.
        speedup_x: f64,
        /// Tokens (batch * time) processed per second, reference path.
        reference_tokens_per_s: f64,
        /// Tokens (batch * time) processed per second, fused path.
        fused_tokens_per_s: f64,
        /// Reference-vs-fused error on the context output.
        context_error: ErrorMetrics,
        /// Reference-vs-fused error on the final recurrent state.
        rho_error: ErrorMetrics,
        /// Allocator counters sampled before any timed work.
        memory_before: MemorySnapshot,
        /// Allocator counters sampled after both timed runs.
        memory_after: MemorySnapshot,
    }

    /// Top-level JSON report: run parameters plus one entry per case.
    #[derive(Clone, Serialize)]
    struct Report {
        /// Human-readable benchmark name.
        benchmark: &'static str,
        /// Backend label (always "cuda" for this binary).
        backend: &'static str,
        /// Warmup iterations shared by all cases.
        warmup: usize,
        /// Timed repetitions shared by all cases.
        repetitions: usize,
        /// Per-case measurements.
        cases: Vec<CaseResult>,
    }

    /// Benchmark grid: single-batch, single-head cases that differ only in
    /// latent width (32768 vs 65536).
    const CASES: &[BenchCase] = &[
        // Smaller latent state.
        BenchCase {
            name: "cuda_b1_h1_t128_l32768_e256",
            batch: 1,
            heads: 1,
            value_heads: 1,
            time: 128,
            latent: 32768,
            embd: 256,
        },
        // Double the latent state of the case above.
        BenchCase {
            name: "cuda_b1_h1_t128_l65536_e256",
            batch: 1,
            heads: 1,
            value_heads: 1,
            time: 128,
            latent: 65536,
            embd: 256,
        },
    ];

    /// CUDA benchmark entry point.
    ///
    /// Parses an optional `--output-dir[=PATH]` flag, verifies the fused
    /// recurrent kernel supports this backend, runs every [`CASES`] entry,
    /// prints the JSON report to stdout, and writes report artifacts when an
    /// output directory was requested.
    ///
    /// Exits with status `2` when the backend is unsupported (`1` is used by
    /// the non-CUDA build for the missing-feature error).
    pub fn main() {
        // `--output-dir=PATH` selects an explicit directory; the bare
        // `--output-dir` flag falls back to the default artifact location.
        let output_dir = std::env::args().skip(1).find_map(|arg| {
            arg.strip_prefix("--output-dir=")
                .map(PathBuf::from)
                .or_else(|| {
                    // `then` (lazy) instead of `then_some` (eager) so the
                    // default PathBuf is only allocated when the bare flag
                    // actually matches, not once per scanned argument.
                    (arg == "--output-dir").then(|| {
                        PathBuf::from("artifacts/language/recurrent_cuda_bench/latest")
                    })
                })
        });

        if !supports_recurrent_backend::<Backend>() {
            eprintln!(
                "cuda runtime support is currently disabled for the recurrent fused kernel because the kernel source is WGSL-only; port it to a portable CubeCL kernel or add a CUDA-specific source first"
            );
            std::process::exit(2);
        }
        let device = Device::default();
        // Fixed seed keeps any backend-internal randomness reproducible.
        <Backend as BackendTrait>::seed(&device, 20260319);

        let warmup = 1usize;
        let repetitions = 3usize;
        let cases = CASES
            .iter()
            .map(|case| run_case(case, &device, warmup, repetitions))
            .collect::<Vec<_>>();

        let report = Report {
            benchmark: "burn_dragon_kernel recurrent cuda bench",
            backend: "cuda",
            warmup,
            repetitions,
            cases,
        };

        // The JSON report always goes to stdout; file artifacts are opt-in.
        let json = serde_json::to_string_pretty(&report).expect("serialize report");
        println!("{json}");

        if let Some(root) = output_dir.as_ref() {
            write_artifacts(root, &report, &json);
        }
    }

    /// Runs one benchmark case: times the reference and fused implementations
    /// on identical inputs and packages timings, error metrics, and memory
    /// snapshots into a [`CaseResult`].
    fn run_case(
        case: &BenchCase,
        device: &Device,
        warmup: usize,
        repetitions: usize,
    ) -> CaseResult {
        // Deterministic inputs shared by both implementations; `rho` starts as
        // an all-zero recurrent state of shape [batch, heads, latent, embd].
        let query = sample_query(case, device);
        let value = sample_value(case, device);
        let rho =
            Tensor::<Backend, 4>::zeros([case.batch, case.heads, case.latent, case.embd], device);
        let decay = sample_decay(case.heads, device);

        let memory_before = memory_snapshot(device);

        // Each closure clones (or borrows) its operands so every timed
        // invocation sees the same unmodified inputs.
        let (reference_ms, reference_context, reference_rho) = timed_run(
            warmup,
            repetitions,
            || reference_recurrent(query.clone(), value.clone(), rho.clone(), decay.clone()),
            device,
        );
        let (fused_ms, fused_context, fused_rho) = timed_run(
            warmup,
            repetitions,
            || {
                let output = try_fused_recurrent_attention_wgpu::<Backend>(
                    &query,
                    &value,
                    Some(&rho),
                    Some(&decay),
                )
                .expect("fused recurrent output");
                (output.context, output.rho)
            },
            device,
        );

        // NOTE(review): taken after both runs, so `memory_after` reflects the
        // combined footprint of reference + fused allocations, not fused alone.
        let memory_after = memory_snapshot(device);
        // Throughput counts one token per (batch, time) pair.
        let tokens = (case.batch * case.time) as f64;

        CaseResult {
            case: *case,
            warmup,
            repetitions,
            reference_ms,
            fused_ms,
            speedup_x: reference_ms / fused_ms,
            reference_tokens_per_s: tokens / (reference_ms / 1_000.0),
            fused_tokens_per_s: tokens / (fused_ms / 1_000.0),
            context_error: compare_tensors(reference_context, fused_context),
            rho_error: compare_tensors(reference_rho, fused_rho),
            memory_before,
            memory_after,
        }
    }

    /// Builds a deterministic query tensor of shape
    /// `[batch, heads, time, latent]` with values in roughly (0, 1].
    fn sample_query(case: &BenchCase, device: &Device) -> Tensor<Backend, 4> {
        let shape = [case.batch, case.heads, case.time, case.latent];
        let len = shape.iter().product::<usize>();
        let mut data = Vec::with_capacity(len);
        for idx in 0..len {
            // Periodic ramp over 97 buckets, offset away from zero.
            data.push(((idx % 97) as f32) / 97.0 + 0.01);
        }
        Tensor::<Backend, 4>::from_data(TensorData::new(data, shape), device)
    }

    /// Builds a deterministic value tensor of shape
    /// `[batch, value_heads, time, embd]` with values in roughly (0, 1].
    fn sample_value(case: &BenchCase, device: &Device) -> Tensor<Backend, 4> {
        let shape = [case.batch, case.value_heads, case.time, case.embd];
        let len = shape.iter().product::<usize>();
        let mut data = Vec::with_capacity(len);
        for idx in 0..len {
            // Periodic ramp over 53 buckets, offset away from zero.
            data.push(((idx % 53) as f32) / 53.0 + 0.02);
        }
        Tensor::<Backend, 4>::from_data(TensorData::new(data, shape), device)
    }

    /// Per-head decay factors: 0.99 for head 0, stepping down 0.01 per head.
    fn sample_decay(heads: usize, device: &Device) -> Tensor<Backend, 1> {
        let mut rates = Vec::with_capacity(heads);
        for head in 0..heads {
            rates.push(0.99f32 - head as f32 * 0.01);
        }
        Tensor::<Backend, 1>::from_data(TensorData::new(rates, [heads]), device)
    }

    /// Reference (unfused) implementation of the recurrent attention update.
    ///
    /// For each time step `t`:
    ///   `context_t = sum over the latent axis of (state * q_t)`
    ///   `state     = (state + q_t * v_t) * decay`
    /// Returns the per-step contexts concatenated along the time axis together
    /// with the final recurrent state.
    fn reference_recurrent(
        query: Tensor<Backend, 4>,
        value: Tensor<Backend, 4>,
        rho: Tensor<Backend, 4>,
        decay: Tensor<Backend, 1>,
    ) -> (Tensor<Backend, 4>, Tensor<Backend, 4>) {
        let [batch, heads, time, _latent] = query.shape().dims::<4>();
        let value_heads = value.shape().dims::<4>()[1];
        let embd = value.shape().dims::<4>()[3];

        // Reshape per-head decay so it broadcasts against the 4-D state.
        let decay = decay.reshape([1, heads, 1, 1]);
        // A single value head is shared across all query heads.
        let value = if value_heads == 1 {
            value.repeat_dim(1, heads)
        } else {
            value
        };

        // `state` starts from the caller-provided rho; run_case passes zeros
        // of shape [batch, heads, latent, embd].
        let mut state = rho;
        let mut outputs = Vec::with_capacity(time);
        for t in 0..time {
            // One time step: q_t keeps latent on the last axis, v_t keeps embd.
            let q_t = query.clone().slice_dim(2, t..t + 1);
            let v_t = value.clone().slice_dim(2, t..t + 1);
            // Move latent to axis 2 so it lines up with the state's latent axis.
            let q_latent = q_t.swap_dims(2, 3);
            // Contract the latent axis of state * q, then restore a
            // [batch, heads, 1, embd] context slice.
            let context = (state.clone() * q_latent.clone())
                .sum_dim(2)
                .reshape([batch, heads, 1, embd]);
            outputs.push(context);
            // Rank-1 update of the state (presumably an outer product via
            // broadcasting — confirm against the fused kernel), then per-head
            // decay.
            state = (state + q_latent * v_t) * decay.clone();
        }
        (Tensor::cat(outputs, 2), state)
    }

    /// Times `run` and returns `(mean ms per repetition, last context, last rho)`.
    ///
    /// The closure is invoked once up front (initializes `last` and,
    /// presumably, triggers lazy kernel compilation — confirm), then `warmup`
    /// untimed times, then `repetitions` timed times. Every invocation is
    /// followed by a device sync so the wall clock covers completed GPU work
    /// rather than just kernel launches.
    fn timed_run<F>(
        warmup: usize,
        repetitions: usize,
        mut run: F,
        device: &Device,
    ) -> (f64, Tensor<Backend, 4>, Tensor<Backend, 4>)
    where
        F: FnMut() -> (Tensor<Backend, 4>, Tensor<Backend, 4>),
    {
        let mut last = run();
        let _ = <Backend as BackendTrait>::sync(device);
        for _ in 0..warmup {
            last = run();
            let _ = <Backend as BackendTrait>::sync(device);
        }
        let start = Instant::now();
        for _ in 0..repetitions {
            last = run();
            let _ = <Backend as BackendTrait>::sync(device);
        }
        // Mean milliseconds across the timed repetitions only.
        let elapsed_ms = start.elapsed().as_secs_f64() * 1_000.0 / repetitions as f64;
        (elapsed_ms, last.0, last.1)
    }

    /// Computes max-abs, mean-abs, and RMSE of the elementwise difference
    /// between two tensors, reading both back as f32 and accumulating in f64.
    fn compare_tensors(lhs: Tensor<Backend, 4>, rhs: Tensor<Backend, 4>) -> ErrorMetrics {
        let lhs = lhs
            .to_data()
            .convert::<f32>()
            .into_vec::<f32>()
            .expect("lhs vec");
        let rhs = rhs
            .to_data()
            .convert::<f32>()
            .into_vec::<f32>()
            .expect("rhs vec");
        // A silent `zip` over mismatched lengths would truncate the comparison
        // while the divisor below still used `lhs.len()`, skewing the metrics;
        // fail loudly instead.
        assert_eq!(
            lhs.len(),
            rhs.len(),
            "compared tensors must have the same element count"
        );

        let mut max_abs = 0.0f64;
        let mut sum_abs = 0.0f64;
        let mut sum_sq = 0.0f64;
        for (a, b) in lhs.iter().zip(rhs.iter()) {
            let diff = (*a as f64 - *b as f64).abs();
            max_abs = max_abs.max(diff);
            sum_abs += diff;
            sum_sq += diff * diff;
        }
        // `.max(1)` keeps the means well-defined (0.0) for empty tensors.
        let len = lhs.len().max(1) as f64;
        ErrorMetrics {
            max_abs,
            mean_abs: sum_abs / len,
            rmse: (sum_sq / len).sqrt(),
        }
    }

    /// Samples the CUDA allocator counters for `device` via the CubeCL
    /// runtime client.
    fn memory_snapshot(device: &Device) -> MemorySnapshot {
        let client = <CudaRuntime as Runtime>::client(device);
        let usage = client.memory_usage().expect("cuda memory usage");
        MemorySnapshot {
            reserved: usage.bytes_reserved,
            in_use: usage.bytes_in_use,
        }
    }

    /// Persists the report under `root` as JSON, Markdown, and CSV.
    fn write_artifacts(root: &Path, report: &Report, json: &str) {
        fs::create_dir_all(root).expect("create output dir");
        // (file name, rendered body, panic message on write failure)
        let outputs = [
            ("report.json", json.to_owned(), "write report json"),
            ("summary.md", format_markdown(report), "write summary"),
            ("summary.csv", format_csv(report), "write csv"),
        ];
        for (file, body, failure) in outputs {
            fs::write(root.join(file), body).expect(failure);
        }
    }

    /// Renders the report as a Markdown summary with one table row per case.
    fn format_markdown(report: &Report) -> String {
        let mut md = String::from("# CUDA Recurrent Bench\n\n");
        // Run parameters, then the fixed table header.
        md.push_str(&format!(
            "- backend: `{}`\n- warmup: `{}`\n- repetitions: `{}`\n\n",
            report.backend, report.warmup, report.repetitions
        ));
        md.push_str("| case | reference ms | fused ms | speedup | ref tok/s | fused tok/s | ctx max abs | rho max abs |\n");
        md.push_str("| --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: |\n");
        for row in &report.cases {
            let line = format!(
                "| {} | {:.3} | {:.3} | {:.2}x | {:.2} | {:.2} | {:.6} | {:.6} |\n",
                row.case.name,
                row.reference_ms,
                row.fused_ms,
                row.speedup_x,
                row.reference_tokens_per_s,
                row.fused_tokens_per_s,
                row.context_error.max_abs,
                row.rho_error.max_abs,
            );
            md.push_str(&line);
        }
        md
    }

    /// Renders the report as CSV with a header line and one row per case.
    fn format_csv(report: &Report) -> String {
        let mut csv = String::from(
            "case,reference_ms,fused_ms,speedup_x,reference_tokens_per_s,fused_tokens_per_s,context_max_abs,rho_max_abs,reserved_before,reserved_after,in_use_before,in_use_after\n",
        );
        for row in &report.cases {
            csv += &format!(
                "{},{:.6},{:.6},{:.6},{:.6},{:.6},{:.9},{:.9},{},{},{},{}\n",
                row.case.name,
                row.reference_ms,
                row.fused_ms,
                row.speedup_x,
                row.reference_tokens_per_s,
                row.fused_tokens_per_s,
                row.context_error.max_abs,
                row.rho_error.max_abs,
                row.memory_before.reserved,
                row.memory_after.reserved,
                row.memory_before.in_use,
                row.memory_after.in_use,
            );
        }
        csv
    }
}

/// Real entry point for the CUDA build: delegates to the `app` module, which
/// holds all CUDA-gated imports and the benchmark logic.
#[cfg(feature = "cuda")]
fn main() {
    app::main();
}