vyre-driver-cuda 0.6.1

CUDA/PTX backend for vyre through the CUDA driver API.
Documentation
use super::*;

/// Checks that a storage-buffer atomic reduction accumulates every element.
///
/// The program atomically adds each of the 4096 entries of `values` (all ones) into `sum[0]`,
/// guarded by a bounds check on the global x index, and the dispatched result must equal the
/// element count.
#[test]
fn cuda_large_storage_atomic_sum_crosses_workgroup_boundary() {
    let element_count = 4096u32;

    let buffers = vec![
        BufferDecl::storage("sum", 0, BufferAccess::ReadWrite, DataType::U32).with_count(1),
        BufferDecl::read("values", 1, DataType::U32).with_count(element_count),
    ];

    // Kernel body: bind the global x index, then let in-range invocations add their element
    // of `values` into `sum[0]`.
    let bind_idx = Node::let_bind("idx", Expr::gid_x());
    let in_range = Expr::lt(Expr::var("idx"), Expr::u32(element_count));
    let accumulate = Node::let_bind(
        "old_sum",
        Expr::atomic_add("sum", Expr::u32(0), Expr::load("values", Expr::var("idx"))),
    );
    let guarded_add = Node::if_then(in_range, vec![accumulate]);

    let program = Program::wrapped(buffers, [256, 1, 1], vec![bind_idx, guarded_add]);

    let backend =
        CudaBackend::acquire().expect("Fix: CUDA backend must acquire on the GPU-required host.");

    // Inputs follow binding order: the zeroed accumulator first, then the ones to be summed.
    let inputs = [
        u32_bytes(&[0]),
        u32_bytes(&vec![1; element_count as usize]),
    ];
    let outputs = backend
        .dispatch(&program, &inputs, &DispatchConfig::default())
        .expect("Fix: CUDA storage atomic reduction must launch every inferred workgroup.");

    let total = bytes_u32(&outputs[0]);
    assert_eq!(
        total,
        vec![element_count],
        "Fix: CUDA storage atomics must produce a full multi-workgroup sum, not only the first block."
    );
}

/// Runs the atomic-sum program through a compiled pipeline with an explicit grid override.
///
/// The grid is obtained from `vyre_driver::program_walks::infer_dispatch_grid` and handed to the
/// compiled pipeline via `grid_override`; the summed output must equal the element count.
#[test]
fn cuda_compiled_pipeline_matches_direct_dispatch_for_multi_block_atomics() {
    let element_count = 4096u32;

    let buffers = vec![
        BufferDecl::storage("sum", 0, BufferAccess::ReadWrite, DataType::U32).with_count(1),
        BufferDecl::read("values", 1, DataType::U32).with_count(element_count),
    ];

    // Kernel body: in-range invocations atomically add their element of `values` into `sum[0]`.
    let bind_idx = Node::let_bind("idx", Expr::gid_x());
    let in_range = Expr::lt(Expr::var("idx"), Expr::u32(element_count));
    let accumulate = Node::let_bind(
        "old_sum",
        Expr::atomic_add("sum", Expr::u32(0), Expr::load("values", Expr::var("idx"))),
    );
    let guarded_add = Node::if_then(in_range, vec![accumulate]);

    let program = Program::wrapped(buffers, [256, 1, 1], vec![bind_idx, guarded_add]);

    let backend: Arc<dyn VyreBackend> = cuda_factory()
        .expect("Fix: CUDA factory must acquire the live GPU.")
        .into();
    let compiled_pipeline =
        pipeline::compile(Arc::clone(&backend), &program, &DispatchConfig::default())
            .expect("Fix: CUDA native pipeline compile must accept storage atomic programs.");

    let value_bytes = u32_bytes(&vec![1; element_count as usize]);
    let zero_sum_bytes = u32_bytes(&[0]);

    // Grid inference takes owned buffers, so it gets clones; the originals are borrowed below.
    let grid_inputs = [zero_sum_bytes.clone(), value_bytes.clone()];
    let inferred_grid = vyre_driver::program_walks::infer_dispatch_grid(
        &program,
        &grid_inputs,
        &DispatchConfig::default(),
    )
    .expect("Fix: shared dispatch-grid inference must handle CUDA storage atomic programs.");

    let mut launch_config = DispatchConfig::default();
    launch_config.grid_override = Some(inferred_grid);

    let borrowed_inputs = [zero_sum_bytes.as_slice(), value_bytes.as_slice()];
    let outputs = compiled_pipeline
        .dispatch_borrowed(&borrowed_inputs, &launch_config)
        .expect("Fix: CUDA compiled pipeline dispatch must match direct backend dispatch.");

    let total = bytes_u32(&outputs[0]);
    assert_eq!(
        total,
        vec![element_count],
        "Fix: CUDA compiled pipeline must honor caller launch config across all workgroups."
    );
}

/// Checks that `subgroup_add` of the constant 1 yields 32 in every one of the 32 output slots.
///
/// Each invocation stores the subgroup sum at its global x index in `out`; the program takes no
/// input buffers.
#[test]
fn cuda_subgroup_add_reports_full_warp_sum() {
    let out_buffer = BufferDecl::output("out", 0, DataType::U32).with_count(32);

    // Every invocation contributes 1 and writes the reduced total to its own slot.
    let lane_total = Expr::subgroup_add(Expr::u32(1));
    let write_total = Node::store("out", Expr::gid_x(), lane_total);

    let program = Program::wrapped(vec![out_buffer], [32, 1, 1], vec![write_total]);

    let backend =
        CudaBackend::acquire().expect("Fix: CUDA backend must acquire on the GPU-required host.");
    let outputs = backend
        .dispatch(&program, &[], &DispatchConfig::default())
        .expect("Fix: CUDA subgroup add must execute through warp-sync PTX.");

    let per_lane = bytes_u32(&outputs[0]);
    assert_eq!(
        per_lane,
        vec![32; 32],
        "Fix: CUDA subgroup add must reduce across every lane in the warp."
    );
}

/// Round-trips 32 values through a workgroup `scratch` buffer separated by a barrier.
///
/// Each invocation copies `input[idx]` into `scratch[lane]`, hits a barrier, then writes
/// `scratch[lane] + 7` to `out[idx]`. The output must equal the input with 7 added to each value.
#[test]
fn cuda_shared_memory_round_trips_with_barrier() {
    let buffers = vec![
        BufferDecl::read("input", 0, DataType::U32).with_count(32),
        BufferDecl::workgroup("scratch", 32, DataType::U32),
        BufferDecl::output("out", 1, DataType::U32).with_count(32),
    ];

    let bind_idx = Node::let_bind("idx", Expr::gid_x());
    let bind_lane = Node::let_bind("lane", Expr::local_x());
    // Stage the input element in the workgroup buffer.
    let stage = Node::store(
        "scratch",
        Expr::var("lane"),
        Expr::load("input", Expr::var("idx")),
    );
    // Read the staged element back after the barrier and offset it by 7.
    let bumped = Expr::add(Expr::load("scratch", Expr::var("lane")), Expr::u32(7));
    let publish = Node::store("out", Expr::var("idx"), bumped);

    let program = Program::wrapped(
        buffers,
        [32, 1, 1],
        vec![bind_idx, bind_lane, stage, Node::barrier(), publish],
    );

    let input: Vec<u32> = (0..32).collect();
    let expected: Vec<u32> = input.iter().map(|&value| value + 7).collect();

    let backend =
        CudaBackend::acquire().expect("Fix: CUDA backend must acquire on the GPU-required host.");
    let inputs = [u32_bytes(&input)];
    let outputs = backend
        .dispatch(&program, &inputs, &DispatchConfig::default())
        .expect("Fix: CUDA shared-memory program must lower and execute.");

    let observed = bytes_u32(&outputs[0]);
    assert_eq!(
        observed,
        expected,
        "Fix: CUDA shared memory must use shared address space and respect workgroup barriers."
    );
}

/// Expects dispatch to fail when the third size component exceeds the backend's reported limit.
///
/// The program's size is `[1, 1, max_block_dim[2] + 1]` (saturating), and the returned error text
/// must mention both "axis 2" and "Fix:".
#[test]
fn cuda_dispatch_rejects_per_axis_block_overflow_before_launch() {
    let backend =
        CudaBackend::acquire().expect("Fix: CUDA backend must acquire on the GPU-required host.");

    // One past the limit the backend reports for index 2 of `max_block_dim`.
    let limits = backend.max_block_dim();
    let oversized_z = limits[2].saturating_add(1);

    let dst_buffer = BufferDecl::output("dst", 0, DataType::U32).with_count(1);
    let write_one = Node::store("dst", Expr::u32(0), Expr::u32(1));
    let program = Program::wrapped(vec![dst_buffer], [1, 1, oversized_z], vec![write_one]);

    let dispatch_result = backend.dispatch(&program, &[], &DispatchConfig::default());
    let err = dispatch_result
        .expect_err("Fix: CUDA must reject per-axis block overflow before cuLaunchKernel.");

    let msg = err.to_string();
    let names_axis = msg.contains("axis 2");
    let is_actionable = msg.contains("Fix:");
    assert!(
        names_axis && is_actionable,
        "Fix: CUDA block-axis validation must produce an actionable axis-specific error; got: {msg}"
    );
}

/// Compares the CUDA `tanh` lowering against the host `f32::tanh` on a four-element fixture.
///
/// The kernel computes `tanh(input[i] / 30.0) * 30.0` for every in-range global x index and the
/// host computes the same expression as the reference. Each lane must be within 128 ULPs, measured
/// through `ordered_f32_bits`.
///
/// The output byte length is checked before decoding: `chunks_exact` drops a trailing partial
/// chunk and `zip` stops at the shorter side, so without the check a short or empty readback would
/// pass with fewer comparisons than lanes.
#[test]
fn cuda_tanh_matches_reference_softcap_fixture() {
    let i = Expr::var("i");
    // Kernel expression: tanh(input[i] / 30.0) * 30.0.
    let value = Expr::mul(
        Expr::UnOp {
            op: UnOp::Tanh,
            operand: Box::new(Expr::div(Expr::load("input", i.clone()), Expr::f32(30.0))),
        },
        Expr::f32(30.0),
    );
    let program = Program::wrapped(
        vec![
            BufferDecl::read("input", 0, DataType::F32).with_count(4),
            BufferDecl::output("output", 1, DataType::F32).with_count(4),
        ],
        [64, 1, 1],
        vec![
            Node::let_bind("i", Expr::gid_x()),
            // The size above is larger than the fixture, so out-of-range indices skip the store.
            Node::if_then(
                Expr::lt(i.clone(), Expr::buf_len("input")),
                vec![Node::store("output", i, value)],
            ),
        ],
    );
    let input = [0.0_f32, 15.0, -60.0, 100.0];
    // Host-side reference for the same expression.
    let expected = input
        .iter()
        .map(|x| (x / 30.0).tanh() * 30.0)
        .collect::<Vec<_>>();
    let backend =
        CudaBackend::acquire().expect("Fix: CUDA backend must acquire on the GPU-required host.");
    let mut config = DispatchConfig::default();
    config.ulp_budget = Some(128);
    let outputs = backend
        .dispatch(
            &program,
            &[input.iter().flat_map(|x| x.to_le_bytes()).collect()],
            &config,
        )
        .expect("Fix: CUDA tanh dispatch must complete.");

    // Guard against a vacuous pass: every expected lane must be present as a whole 4-byte f32.
    assert_eq!(
        outputs[0].len(),
        expected.len() * 4,
        "Fix: CUDA tanh output must hold exactly one little-endian f32 per input lane so every lane is compared."
    );

    let got = outputs[0]
        .chunks_exact(4)
        .map(|bytes| f32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
        .collect::<Vec<_>>();
    for (lane, (actual, expected)) in got.iter().zip(expected.iter()).enumerate() {
        let diff = ordered_f32_bits(*actual).abs_diff(ordered_f32_bits(*expected));
        assert!(
            diff <= 128,
            "Fix: CUDA tanh lane {lane} exceeded the 128-ULP native transcendental budget: actual={actual:e} expected={expected:e} diff={diff}."
        );
    }
}