use oxibonsai_kernels::gpu_backend::{
gpu_gemv_1bit, gpu_matmul, select_backend, CpuBackend, DeviceBuffer, GpuBackendTrait,
LaunchConfig,
};
fn backend() -> CpuBackend {
CpuBackend::new()
}
#[test]
fn cpu_backend_name() {
let b = backend();
assert_eq!(b.name(), "cpu");
}
#[test]
fn cpu_backend_not_accelerated() {
let b = backend();
assert!(!b.is_accelerated());
}
#[test]
fn cpu_backend_device_count() {
let b = backend();
assert!(b.device_count() >= 1);
}
#[test]
fn cpu_backend_alloc() {
let b = backend();
let buf = b.alloc(100, 0).expect("alloc should succeed");
assert_eq!(buf.size(), 100);
assert_eq!(buf.device_id(), 0);
assert!(buf.data.iter().all(|&v| v == 0.0_f32));
}
#[test]
fn cpu_backend_host_to_device() {
let b = backend();
let src: Vec<f32> = (0..8).map(|i| i as f32).collect();
let buf = b
.host_to_device(&src, 0)
.expect("host_to_device should succeed");
assert_eq!(buf.size(), 8);
assert_eq!(buf.to_vec(), src);
}
#[test]
fn cpu_backend_device_to_host() {
let b = backend();
let src = vec![1.0_f32, 2.0, 3.0, 4.0];
let buf = b.host_to_device(&src, 0).expect("host_to_device");
let out = b.device_to_host(&buf).expect("device_to_host");
assert_eq!(out, src);
}
#[test]
fn cpu_backend_matvec_2x2() {
let b = backend();
let a = b.host_to_device(&[1.0, 0.0, 0.0, 1.0], 0).expect("h2d a");
let x = b.host_to_device(&[3.0, 4.0], 0).expect("h2d x");
let y = b.matvec(&a, &x, 2, 2, 0).expect("matvec");
let result = b.device_to_host(&y).expect("d2h");
assert!((result[0] - 3.0).abs() < 1e-6, "y[0] = {}", result[0]);
assert!((result[1] - 4.0).abs() < 1e-6, "y[1] = {}", result[1]);
}
#[test]
fn cpu_backend_matvec_identity() {
let b = backend();
let n = 4_usize;
let mut identity = vec![0.0_f32; n * n];
for i in 0..n {
identity[i * n + i] = 1.0;
}
let x_data: Vec<f32> = (1..=n as u32).map(|v| v as f32).collect();
let a_buf = b.host_to_device(&identity, 0).expect("h2d a");
let x_buf = b.host_to_device(&x_data, 0).expect("h2d x");
let y_buf = b.matvec(&a_buf, &x_buf, n, n, 0).expect("matvec");
let y = b.device_to_host(&y_buf).expect("d2h");
for i in 0..n {
assert!(
(y[i] - x_data[i]).abs() < 1e-5,
"y[{}] = {} != {}",
i,
y[i],
x_data[i]
);
}
}
#[test]
fn cpu_backend_relu_positive() {
let b = backend();
let input = vec![1.0_f32, -1.0, 2.0];
let buf = b.host_to_device(&input, 0).expect("h2d");
let out_buf = b.relu(&buf, 0).expect("relu");
let out = b.device_to_host(&out_buf).expect("d2h");
assert!((out[0] - 1.0).abs() < 1e-6);
assert!((out[1] - 0.0).abs() < 1e-6);
assert!((out[2] - 2.0).abs() < 1e-6);
}
#[test]
fn cpu_backend_relu_all_negative() {
let b = backend();
let input = vec![-1.0_f32, -2.0];
let buf = b.host_to_device(&input, 0).expect("h2d");
let out_buf = b.relu(&buf, 0).expect("relu");
let out = b.device_to_host(&out_buf).expect("d2h");
for &v in &out {
assert!((v - 0.0).abs() < 1e-6, "expected 0, got {v}");
}
}
#[test]
fn cpu_backend_softmax_sums_to_one() {
let b = backend();
let input = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0];
let size = input.len();
let buf = b.host_to_device(&input, 0).expect("h2d");
let out_buf = b.softmax(&buf, size, 0).expect("softmax");
let out = b.device_to_host(&out_buf).expect("d2h");
let sum: f32 = out.iter().sum();
assert!(
(sum - 1.0).abs() < 1e-5,
"softmax sum = {sum}, expected ~1.0"
);
}
#[test]
fn cpu_backend_softmax_uniform() {
let b = backend();
let input = vec![0.0_f32; 3];
let size = input.len();
let buf = b.host_to_device(&input, 0).expect("h2d");
let out_buf = b.softmax(&buf, size, 0).expect("softmax");
let out = b.device_to_host(&out_buf).expect("d2h");
let expected = 1.0 / 3.0_f32;
for &v in &out {
assert!((v - expected).abs() < 1e-5, "expected ~{expected}, got {v}");
}
}
#[test]
fn cpu_backend_synchronize_ok() {
let b = backend();
b.synchronize(0).expect("synchronize should succeed");
}
#[test]
fn cpu_backend_memory_info() {
let b = backend();
let (free, total) = b.memory_info(0).expect("memory_info should succeed");
assert!(free <= total, "free ({free}) must be <= total ({total})");
assert!(total > 0, "total memory must be > 0");
}
#[test]
fn select_backend_returns_usable() {
let b = select_backend();
assert!(!b.name().is_empty());
#[cfg(not(feature = "gpu"))]
assert!(!b.is_accelerated());
}
#[test]
fn gpu_matmul_identity() {
let b = backend();
let n = 3_usize;
let mut identity = vec![0.0_f32; n * n];
for i in 0..n {
identity[i * n + i] = 1.0;
}
let b_mat = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let result = gpu_matmul(&b, &identity, &b_mat, n, n, 2, 0).expect("gpu_matmul");
for (r, expected) in result.iter().zip(b_mat.iter()) {
assert!((r - expected).abs() < 1e-5, "got {r}, expected {expected}");
}
}
#[test]
fn launch_config_for_n() {
let cfg = LaunchConfig::for_n_elements(1024);
assert_eq!(cfg.block_dim.0, 256);
assert_eq!(cfg.grid_dim.0, 4);
assert_eq!(cfg.grid_dim.1, 1);
assert_eq!(cfg.grid_dim.2, 1);
assert_eq!(cfg.block_dim.1, 1);
assert_eq!(cfg.block_dim.2, 1);
}
#[test]
fn device_buffer_size() {
let sizes = [0_usize, 1, 64, 1024, 65536];
for &s in &sizes {
let buf = DeviceBuffer::new(s, 0);
assert_eq!(buf.size(), s, "DeviceBuffer::size() mismatch for size={s}");
}
}
fn make_q1_block(scale: f32, bits: [u8; 16]) -> Vec<u8> {
let d = half::f16::from_f32(scale);
let mut block = Vec::with_capacity(18);
block.extend_from_slice(&d.to_bits().to_le_bytes());
block.extend_from_slice(&bits);
block
}
#[test]
fn gemv_1bit_all_ones_scale_one() {
let block = make_q1_block(1.0, [0xFF; 16]);
let input: Vec<f32> = (0..128).map(|i| i as f32).collect();
let expected: f32 = input.iter().sum(); let result = gpu_gemv_1bit(&block, &input, 1, 128).expect("gemv_1bit");
assert!(
(result[0] - expected).abs() < 1.0,
"got {} expected {}",
result[0],
expected,
);
}
#[test]
fn gemv_1bit_all_zeros_neg_sum() {
let block = make_q1_block(1.0, [0x00; 16]);
let input = vec![1.0_f32; 128];
let result = gpu_gemv_1bit(&block, &input, 1, 128).expect("gemv_1bit");
assert!(
(result[0] - (-128.0)).abs() < 1.0,
"got {} expected -128",
result[0],
);
}
#[test]
fn gemv_1bit_multi_row() {
let row0 = make_q1_block(2.0, [0xFF; 16]); let row1 = make_q1_block(0.5, [0x00; 16]); let mut blocks = row0;
blocks.extend_from_slice(&row1);
let input = vec![1.0_f32; 128];
let result = gpu_gemv_1bit(&blocks, &input, 2, 128).expect("gemv_1bit 2 rows");
assert!((result[0] - 256.0).abs() < 1.0, "row0: got {}", result[0]);
assert!((result[1] - (-64.0)).abs() < 1.0, "row1: got {}", result[1]);
}
#[test]
fn gemv_1bit_two_blocks_per_row() {
let b0 = make_q1_block(1.0, [0xFF; 16]); let b1 = make_q1_block(1.0, [0x00; 16]); let mut blocks = b0;
blocks.extend_from_slice(&b1);
let input = vec![1.0_f32; 256];
let result = gpu_gemv_1bit(&blocks, &input, 1, 256).expect("gemv_1bit k=256");
assert!(result[0].abs() < 1.0, "expected ~0, got {}", result[0],);
}
#[test]
fn gemv_1bit_bad_k_not_multiple_128() {
let result = gpu_gemv_1bit(&[], &[], 0, 64);
assert!(result.is_err());
}
#[test]
fn gemv_1bit_input_size_mismatch() {
let block = make_q1_block(1.0, [0xFF; 16]);
let input = vec![1.0_f32; 64]; let result = gpu_gemv_1bit(&block, &input, 1, 128);
assert!(result.is_err());
}