baracuda-kernels 0.0.1-alpha.68

Unified ML op facade for the baracuda CUDA ecosystem. Exposes every primitive an ML framework would expect (union of PyTorch torch.* + nn.functional and JAX lax.* / numpy ops) through a single Plan-based Rust surface, internally dispatching to baracuda-cutlass, the baracuda-* NVIDIA-library wrappers, or bespoke baracuda-kernels-sys kernels.
Documentation
//! Real-GPU smoke test for `BinaryBackwardPlan<T, N> + BinaryKind::Maximum`.
//!
//! Dual-save backward: the first backward op whose saved inputs (`a`, `b`)
//! are used purely as comparison operands rather than as multipliers. Tie
//! convention follows PyTorch: split `dy` evenly when `a == b`, and
//! propagate `dy` to both inputs when either one is NaN.
//!
//! Per dtype, the input pattern deliberately hits every branch:
//! `a > b`, `a < b`, `a == b`, and — for f32 and f64 only — `NaN`
//! (the f16 / bf16 tests replace the NaN cells with 0).
//!
//! `#[ignore]` by default; run with
//! `cargo test -p baracuda-kernels --release --features sm89 \
//!   --test binary_maximum_backward_smoke -- --ignored`.

use baracuda_driver::{init, Context, Device, DeviceBuffer, Stream};
use baracuda_kernels::{
    contiguous_stride, BinaryBackwardArgs, BinaryBackwardDescriptor, BinaryBackwardPlan,
    BinaryKind, ElementKind, PlanPreference, TensorMut, TensorRef, Workspace,
};
use half::{bf16, f16};

/// Tolerance unit for the f16 test, roughly f16 machine epsilon (2^-10).
/// It is scaled by `max(|expected|, 1)`, so it acts relative above 1 and absolute below.
const F16_EPS: f32 = 0.000_977;
/// Tolerance unit for the bf16 test, roughly bf16 machine epsilon (2^-7).
/// It is scaled the same way as [`F16_EPS`].
const BF16_EPS: f32 = 0.007_81;

/// Initializes the CUDA driver, then returns a context on device 0 together
/// with a stream created on that context. Panics if any step fails.
fn setup() -> (Context, Stream) {
    init().expect("driver init");
    let gpu = Device::get(0).expect("device 0");
    let context = Context::new(&gpu).expect("context");
    let queue = Stream::new(&context).expect("stream");
    (context, queue)
}

/// Builds the `(dy, a, b)` host inputs for a `numel`-element backward test.
///
/// `dy[i] = 0.5 * i - 17.25`. The `(a, b)` pair cycles through an 8-cell
/// pattern keyed on `i % 8`, with `base = 0.125 * i - 1.0`:
///   - `0..=3`: `a > b`  (`a = base + 1`, `b = base`) — half of all cells
///   - `4..=5`: `a < b`  (`a = base`, `b = base + 1`) — a quarter of cells
///   - `6`:     `a == b` (both `base`, the tie path) — one cell in eight
///   - `7`:     `a = NaN`, `b = base` — one cell in eight
///
/// Under the PyTorch convention, the NaN cells expect both gradients to equal `dy`.
fn build_inputs_f32(numel: usize) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
    let host_dy: Vec<f32> = (0..numel).map(|i| (i as f32) * 0.5 - 17.25).collect();
    let (host_a, host_b): (Vec<f32>, Vec<f32>) = (0..numel)
        .map(|i| {
            let base = (i as f32) * 0.125 - 1.0;
            match i % 8 {
                0..=3 => (base + 1.0, base), // a > b
                4 | 5 => (base, base + 1.0), // a < b
                6 => (base, base),           // tie
                _ => (f32::NAN, base),       // NaN on `a`
            }
        })
        .unzip();
    (host_dy, host_a, host_b)
}

/// Host reference for the `maximum` backward pass, returning `(da, db)`.
///
/// Mirrors PyTorch's `derivatives.yaml` rule for `maximum`:
///   da = where(a == b, dy/2, dy).masked_fill(a <  b, 0)
///   db = where(a == b, dy/2, dy).masked_fill(b <  a, 0)
/// When either input is NaN the pair is unordered and every comparison is
/// false, so both gradients receive the full `dy`.
fn ref_maximum_backward(dy: f32, a: f32, b: f32) -> (f32, f32) {
    use std::cmp::Ordering;

    match a.partial_cmp(&b) {
        // Tie: split the upstream gradient evenly between both inputs.
        Some(Ordering::Equal) => (dy * 0.5, dy * 0.5),
        // `b` is the max: only `b` receives the gradient.
        Some(Ordering::Less) => (0.0, dy),
        // `a` is the max: only `a` receives the gradient.
        Some(Ordering::Greater) => (dy, 0.0),
        // Unordered (NaN on either side): both receive `dy`.
        None => (dy, dy),
    }
}

/// f32 variant: the full input pattern, including the NaN cells, checked bit-exactly
/// against the host reference.
#[test]
#[ignore]
fn maximum_backward_f32_3d() {
    let (ctx, stream) = setup();
    let shape = [8i32, 128, 128];
    let numel = shape.iter().fold(1usize, |acc, &d| acc * d as usize);

    // Host inputs mirrored onto the device; the gradient outputs start zeroed.
    let (host_dy, host_a, host_b) = build_inputs_f32(numel);
    let d_dy = DeviceBuffer::from_slice(&ctx, &host_dy).expect("upload dy");
    let d_a = DeviceBuffer::from_slice(&ctx, &host_a).expect("upload a");
    let d_b = DeviceBuffer::from_slice(&ctx, &host_b).expect("upload b");
    let mut d_da: DeviceBuffer<f32> = DeviceBuffer::zeros(&ctx, numel).expect("alloc da");
    let mut d_db: DeviceBuffer<f32> = DeviceBuffer::zeros(&ctx, numel).expect("alloc db");

    // Select and launch the plan; every tensor shares one shape and contiguous stride.
    let stride = contiguous_stride(shape);
    let descriptor = BinaryBackwardDescriptor {
        kind: BinaryKind::Maximum,
        shape,
        element: ElementKind::F32,
    };
    let plan =
        BinaryBackwardPlan::<f32, 3>::select(&stream, &descriptor, PlanPreference::default())
            .expect("select");
    let args = BinaryBackwardArgs::<f32, 3> {
        dy: TensorRef { data: d_dy.as_slice(), shape, stride },
        a: Some(TensorRef { data: d_a.as_slice(), shape, stride }),
        b: Some(TensorRef { data: d_b.as_slice(), shape, stride }),
        da: TensorMut { data: d_da.as_slice_mut(), shape, stride },
        db: TensorMut { data: d_db.as_slice_mut(), shape, stride },
    };
    plan.run(&stream, Workspace::None, args).expect("run");
    stream.synchronize().expect("sync");

    // Read the gradients back.
    let mut got_da = vec![0.0_f32; numel];
    let mut got_db = vec![0.0_f32; numel];
    d_da.copy_to_host(&mut got_da).expect("download da");
    d_db.copy_to_host(&mut got_db).expect("download db");

    // Compare bit-for-bit against the host reference, cell by cell.
    for (i, ((&dy, &a), &b)) in host_dy.iter().zip(&host_a).zip(&host_b).enumerate() {
        let (want_da, want_db) = ref_maximum_backward(dy, a, b);
        assert_eq!(
            got_da[i].to_bits(),
            want_da.to_bits(),
            "maximum bw f32 da @ {i}: a={}, b={}, dy={}",
            a,
            b,
            dy
        );
        assert_eq!(
            got_db[i].to_bits(),
            want_db.to_bits(),
            "maximum bw f32 db @ {i}: a={}, b={}, dy={}",
            a,
            b,
            dy
        );
    }
}

/// f64 variant: the f32 input pattern, NaN cells included, widened to f64 and
/// checked bit-exactly against the widened f32 host reference.
#[test]
#[ignore]
fn maximum_backward_f64_3d() {
    let (ctx, stream) = setup();
    let shape = [8i32, 128, 128];
    let numel: usize = shape.iter().map(|&d| d as usize).product();
    let (host_dy_f32, host_a_f32, host_b_f32) = build_inputs_f32(numel);
    let host_dy: Vec<f64> = host_dy_f32.iter().map(|&x| x as f64).collect();
    let host_a: Vec<f64> = host_a_f32.iter().map(|&x| x as f64).collect();
    let host_b: Vec<f64> = host_b_f32.iter().map(|&x| x as f64).collect();
    let dev_dy = DeviceBuffer::from_slice(&ctx, &host_dy).expect("upload dy");
    let dev_a = DeviceBuffer::from_slice(&ctx, &host_a).expect("upload a");
    let dev_b = DeviceBuffer::from_slice(&ctx, &host_b).expect("upload b");
    let mut dev_da: DeviceBuffer<f64> = DeviceBuffer::zeros(&ctx, numel).expect("alloc da");
    let mut dev_db: DeviceBuffer<f64> = DeviceBuffer::zeros(&ctx, numel).expect("alloc db");
    let stride = contiguous_stride(shape);
    let desc = BinaryBackwardDescriptor {
        kind: BinaryKind::Maximum,
        shape,
        element: ElementKind::F64,
    };
    let plan = BinaryBackwardPlan::<f64, 3>::select(&stream, &desc, PlanPreference::default())
        .expect("select");
    let args = BinaryBackwardArgs::<f64, 3> {
        dy: TensorRef { data: dev_dy.as_slice(), shape, stride },
        a: Some(TensorRef { data: dev_a.as_slice(), shape, stride }),
        b: Some(TensorRef { data: dev_b.as_slice(), shape, stride }),
        da: TensorMut { data: dev_da.as_slice_mut(), shape, stride },
        db: TensorMut { data: dev_db.as_slice_mut(), shape, stride },
    };
    plan.run(&stream, Workspace::None, args).expect("run");
    stream.synchronize().expect("sync");
    let mut got_da = vec![0f64; numel];
    let mut got_db = vec![0f64; numel];
    dev_da.copy_to_host(&mut got_da).expect("download da");
    dev_db.copy_to_host(&mut got_db).expect("download db");
    for i in 0..numel {
        // Widening f32 -> f64 preserves ordering, equality and NaN-ness, and
        // `dy * 0.5` is exact in both widths for these inputs. So the f32
        // reference, widened, is the exact expected f64 result.
        let (exp_da_f32, exp_db_f32) =
            ref_maximum_backward(host_dy_f32[i], host_a_f32[i], host_b_f32[i]);
        let exp_da = exp_da_f32 as f64;
        let exp_db = exp_db_f32 as f64;
        // Report the inputs on failure, matching the f32 test's diagnostics.
        assert_eq!(
            got_da[i].to_bits(),
            exp_da.to_bits(),
            "maximum bw f64 da @ {i}: a={}, b={}, dy={}",
            host_a[i],
            host_b[i],
            host_dy[i]
        );
        assert_eq!(
            got_db[i].to_bits(),
            exp_db.to_bits(),
            "maximum bw f64 db @ {i}: a={}, b={}, dy={}",
            host_a[i],
            host_b[i],
            host_dy[i]
        );
    }
}

/// f16 variant: the f32 input pattern folded into f16's small range, with the
/// NaN cells replaced by 0. Checked against the host reference within an f16
/// tolerance.
#[test]
#[ignore]
fn maximum_backward_f16_3d() {
    let (ctx, stream) = setup();
    let shape = [8i32, 128, 128];
    let numel: usize = shape.iter().map(|&d| d as usize).product();
    // Reuse the f32 pattern, but fold every value into [-32, 32) rather than
    // clamping it. A clamp to ±32 would saturate every cell beyond the first
    // few hundred into `a == b == 32` with `dy == 32`, leaving the `>` / `<`
    // branches almost untested. With the fold, the 8-cell pattern survives
    // across the whole tensor. A cell whose `+1` offset wraps around just
    // lands in the opposite branch. The host reference is computed from the
    // converted f16 values, so every cell is still checked.
    //
    // NaN cells on `a` are replaced with 0 here; NaN propagation is covered
    // by the f32 / f64 tests.
    let to_f16_folded = |x: f32| f16::from_f32((x + 32.0).rem_euclid(64.0) - 32.0);
    let (host_dy_f32, host_a_f32, host_b_f32) = build_inputs_f32(numel);
    let host_dy: Vec<f16> = host_dy_f32.iter().map(|&x| to_f16_folded(x)).collect();
    let host_a: Vec<f16> = host_a_f32
        .iter()
        .map(|&x| if x.is_nan() { f16::from_f32(0.0) } else { to_f16_folded(x) })
        .collect();
    let host_b: Vec<f16> = host_b_f32.iter().map(|&x| to_f16_folded(x)).collect();
    let dev_dy = DeviceBuffer::from_slice(&ctx, &host_dy).expect("upload dy");
    let dev_a = DeviceBuffer::from_slice(&ctx, &host_a).expect("upload a");
    let dev_b = DeviceBuffer::from_slice(&ctx, &host_b).expect("upload b");
    let mut dev_da: DeviceBuffer<f16> = DeviceBuffer::zeros(&ctx, numel).expect("alloc da");
    let mut dev_db: DeviceBuffer<f16> = DeviceBuffer::zeros(&ctx, numel).expect("alloc db");
    let stride = contiguous_stride(shape);
    let desc = BinaryBackwardDescriptor {
        kind: BinaryKind::Maximum,
        shape,
        element: ElementKind::F16,
    };
    let plan = BinaryBackwardPlan::<f16, 3>::select(&stream, &desc, PlanPreference::default())
        .expect("select");
    let args = BinaryBackwardArgs::<f16, 3> {
        dy: TensorRef { data: dev_dy.as_slice(), shape, stride },
        a: Some(TensorRef { data: dev_a.as_slice(), shape, stride }),
        b: Some(TensorRef { data: dev_b.as_slice(), shape, stride }),
        da: TensorMut { data: dev_da.as_slice_mut(), shape, stride },
        db: TensorMut { data: dev_db.as_slice_mut(), shape, stride },
    };
    plan.run(&stream, Workspace::None, args).expect("run");
    stream.synchronize().expect("sync");
    let mut got_da = vec![f16::from_f32(0.0); numel];
    let mut got_db = vec![f16::from_f32(0.0); numel];
    dev_da.copy_to_host(&mut got_da).expect("download da");
    dev_db.copy_to_host(&mut got_db).expect("download db");
    for i in 0..numel {
        let dy = host_dy[i].to_f32();
        let a = host_a[i].to_f32();
        let b = host_b[i].to_f32();
        let (exp_da, exp_db) = ref_maximum_backward(dy, a, b);
        // Tolerance: F16_EPS scaled by max(|expected|, 1).
        let gd = got_da[i].to_f32();
        let tol = exp_da.abs().max(1.0) * F16_EPS;
        assert!(
            (gd - exp_da).abs() <= tol,
            "maximum bw f16 da @ {i}: a={a}, b={b}, dy={dy}, got {gd}, exp {exp_da}"
        );
        let gdb = got_db[i].to_f32();
        let tol = exp_db.abs().max(1.0) * F16_EPS;
        assert!(
            (gdb - exp_db).abs() <= tol,
            "maximum bw f16 db @ {i}: a={a}, b={b}, dy={dy}, got {gdb}, exp {exp_db}"
        );
    }
}

/// bf16 variant: the f32 input pattern converted to bf16 with no clamping, and
/// the NaN cells on `a` replaced by 0. Checked against the host reference
/// within a bf16 tolerance.
#[test]
#[ignore]
fn maximum_backward_bf16_3d() {
    let (ctx, stream) = setup();
    let shape = [8i32, 128, 128];
    let numel = shape.iter().fold(1usize, |acc, &d| acc * d as usize);

    // Host inputs narrowed to bf16; NaN on `a` becomes 0.
    let (src_dy, src_a, src_b) = build_inputs_f32(numel);
    let to_bf16 = |v: &[f32]| -> Vec<bf16> { v.iter().copied().map(bf16::from_f32).collect() };
    let host_dy = to_bf16(&src_dy);
    let host_a: Vec<bf16> = src_a
        .iter()
        .map(|&x| bf16::from_f32(if x.is_nan() { 0.0 } else { x }))
        .collect();
    let host_b = to_bf16(&src_b);

    // Device mirrors; the gradient outputs start zeroed.
    let d_dy = DeviceBuffer::from_slice(&ctx, &host_dy).expect("upload dy");
    let d_a = DeviceBuffer::from_slice(&ctx, &host_a).expect("upload a");
    let d_b = DeviceBuffer::from_slice(&ctx, &host_b).expect("upload b");
    let mut d_da: DeviceBuffer<bf16> = DeviceBuffer::zeros(&ctx, numel).expect("alloc da");
    let mut d_db: DeviceBuffer<bf16> = DeviceBuffer::zeros(&ctx, numel).expect("alloc db");

    // Select and launch the plan over contiguous tensors.
    let stride = contiguous_stride(shape);
    let descriptor = BinaryBackwardDescriptor {
        kind: BinaryKind::Maximum,
        shape,
        element: ElementKind::Bf16,
    };
    let plan =
        BinaryBackwardPlan::<bf16, 3>::select(&stream, &descriptor, PlanPreference::default())
            .expect("select");
    let args = BinaryBackwardArgs::<bf16, 3> {
        dy: TensorRef { data: d_dy.as_slice(), shape, stride },
        a: Some(TensorRef { data: d_a.as_slice(), shape, stride }),
        b: Some(TensorRef { data: d_b.as_slice(), shape, stride }),
        da: TensorMut { data: d_da.as_slice_mut(), shape, stride },
        db: TensorMut { data: d_db.as_slice_mut(), shape, stride },
    };
    plan.run(&stream, Workspace::None, args).expect("run");
    stream.synchronize().expect("sync");

    // Read the gradients back.
    let mut got_da = vec![bf16::from_f32(0.0); numel];
    let mut got_db = vec![bf16::from_f32(0.0); numel];
    d_da.copy_to_host(&mut got_da).expect("download da");
    d_db.copy_to_host(&mut got_db).expect("download db");

    // Tolerance: BF16_EPS scaled by max(|expected|, 1).
    let within = |got: f32, exp: f32| (got - exp).abs() <= exp.abs().max(1.0) * BF16_EPS;
    for (i, (&grad_a, &grad_b)) in got_da.iter().zip(&got_db).enumerate() {
        let (exp_da, exp_db) =
            ref_maximum_backward(host_dy[i].to_f32(), host_a[i].to_f32(), host_b[i].to_f32());
        let gd = grad_a.to_f32();
        assert!(within(gd, exp_da), "maximum bw bf16 da @ {i}: got {gd}, exp {exp_da}");
        let gdb = grad_b.to_f32();
        assert!(within(gdb, exp_db), "maximum bw bf16 db @ {i}: got {gdb}, exp {exp_db}");
    }
}