use super::*;
use std::sync::OnceLock;
use crate::backends::gpu::GpuDevice;
/// Lazily-initialized GPU device cache used by the tests in this file.
///
/// Populated on first access through `get_shared_device`; holds `None` when
/// no device could be obtained at initialization time.
static SHARED_DEVICE: OnceLock<Option<GpuDevice>> = OnceLock::new();
/// Returns a clone of the GPU device cached in `SHARED_DEVICE`.
///
/// The first call checks `GpuDevice::is_available()` and, only if that
/// reports `true`, attempts `GpuDevice::new()`. The outcome is stored once,
/// so subsequent calls reuse it without probing again.
///
/// Returns `None` when the GPU is reported unavailable or device creation
/// returned an error.
fn get_shared_device() -> Option<GpuDevice> {
    let cached = SHARED_DEVICE.get_or_init(|| {
        if !GpuDevice::is_available() {
            return None;
        }
        // A creation error is collapsed into `None`, so callers handle it
        // exactly like an absent GPU.
        GpuDevice::new().ok()
    });
    cached.clone()
}
/// Uploading two slices must register two buffers with distinct handles.
///
/// Skips (returns early) when no GPU device is available.
#[test]
fn test_buffer_allocation() {
    let device = match get_shared_device() {
        Some(device) => device,
        None => {
            eprintln!("GPU not available, skipping");
            return;
        }
    };

    let mut batch = GpuCommandBatch::new(device);
    let first = batch.upload(&[1.0, 2.0, 3.0]);
    let second = batch.upload(&[4.0, 5.0, 6.0]);

    // One buffer per upload, and the two handles must not compare equal.
    assert_eq!(batch.num_buffers(), 2);
    assert_ne!(first, second);
}
/// Queues relu -> scale -> add and checks the batch's bookkeeping counters.
///
/// Nothing is executed here; only the number of queued operations and
/// tracked buffers is inspected. Skips when no GPU device is available.
#[test]
fn test_operation_queuing() {
    let device = match get_shared_device() {
        Some(device) => device,
        None => {
            eprintln!("GPU not available, skipping");
            return;
        }
    };

    let mut batch = GpuCommandBatch::new(device);

    let source = batch.upload(&[1.0, 2.0, -3.0, 4.0]);
    let activated = batch.relu(source);
    let doubled = batch.scale(activated, 2.0);
    let addend = batch.upload(&[0.5, 0.5, 0.5, 0.5]);
    let _sum = batch.add(doubled, addend);

    // Three operations were queued: relu, scale, add.
    assert_eq!(batch.num_operations(), 3);
    // The expected count of 5 matches 2 uploads + 3 operation results.
    assert_eq!(batch.num_buffers(), 5);
}
/// `add` on a 2-element and a 3-element buffer is expected to panic with
/// a "Buffer size mismatch" message.
#[test]
#[should_panic(expected = "Buffer size mismatch")]
fn test_size_mismatch_add() {
    let device = match get_shared_device() {
        Some(device) => device,
        // A `should_panic` test fails if it returns normally, so without a
        // GPU the expected message is raised here to keep the test passing.
        None => panic!("Buffer size mismatch"),
    };

    let mut batch = GpuCommandBatch::new(device);
    let short = batch.upload(&[1.0, 2.0]);
    let long = batch.upload(&[1.0, 2.0, 3.0]);
    batch.add(short, long);
}
/// `mul` on a 2-element and a 3-element buffer is expected to panic with
/// a "Buffer size mismatch" message.
#[test]
#[should_panic(expected = "Buffer size mismatch")]
fn test_size_mismatch_mul() {
    let device = match get_shared_device() {
        Some(device) => device,
        // A `should_panic` test fails if it returns normally, so without a
        // GPU the expected message is raised here to keep the test passing.
        None => panic!("Buffer size mismatch"),
    };

    let mut batch = GpuCommandBatch::new(device);
    let short = batch.upload(&[1.0, 2.0]);
    let long = batch.upload(&[1.0, 2.0, 3.0]);
    batch.mul(short, long);
}
/// `dot` on a 2-element and a 3-element buffer is expected to panic with
/// a "Buffer size mismatch" message.
#[test]
#[should_panic(expected = "Buffer size mismatch")]
fn test_size_mismatch_dot() {
    let device = match get_shared_device() {
        Some(device) => device,
        // A `should_panic` test fails if it returns normally, so without a
        // GPU the expected message is raised here to keep the test passing.
        None => panic!("Buffer size mismatch"),
    };

    let mut batch = GpuCommandBatch::new(device);
    let short = batch.upload(&[1.0, 2.0]);
    let long = batch.upload(&[1.0, 2.0, 3.0]);
    batch.dot(short, long);
}
/// Queues every batch operation exercised in this file, executes the batch
/// once, then reads back and checks each result.
///
/// Covered: relu, scale, add, mul, dot, sigmoid, tanh, swish, gelu, sub, and
/// a relu -> sigmoid -> tanh chain. All operations are queued before the
/// single `execute()` call. Skips when no GPU device is available.
#[tokio::test]
async fn test_all_batch_operations() {
    let device = match get_shared_device() {
        Some(device) => device,
        None => {
            eprintln!("GPU not available, skipping");
            return;
        }
    };
    let mut batch = GpuCommandBatch::new(device);

    // --- Queue phase -------------------------------------------------------

    // relu -> scale(2.0) -> add(0.5)
    let pipeline_in = batch.upload(&[1.0, 2.0, -3.0, 4.0]);
    let pipeline_relu = batch.relu(pipeline_in);
    let pipeline_scaled = batch.scale(pipeline_relu, 2.0);
    let pipeline_bias = batch.upload(&[0.5, 0.5, 0.5, 0.5]);
    let add_handle = batch.add(pipeline_scaled, pipeline_bias);

    // Element-wise multiply.
    let mul_lhs = batch.upload(&[1.0, 2.0, 3.0, 4.0]);
    let mul_rhs = batch.upload(&[2.0, 3.0, 4.0, 5.0]);
    let mul_handle = batch.mul(mul_lhs, mul_rhs);

    // Dot product.
    let dot_lhs = batch.upload(&[1.0, 2.0, 3.0, 4.0]);
    let dot_rhs = batch.upload(&[2.0, 3.0, 4.0, 5.0]);
    let dot_handle = batch.dot(dot_lhs, dot_rhs);

    // Activations.
    let sigmoid_in = batch.upload(&[-2.0, -1.0, 0.0, 1.0, 2.0]);
    let sigmoid_handle = batch.sigmoid(sigmoid_in);

    let tanh_in = batch.upload(&[-1.0, 0.0, 1.0]);
    let tanh_handle = batch.tanh(tanh_in);

    let swish_in = batch.upload(&[0.0, 1.0, 2.0]);
    let swish_handle = batch.swish(swish_in);

    let gelu_in = batch.upload(&[-1.0, 0.0, 1.0]);
    let gelu_handle = batch.gelu(gelu_in);

    // Element-wise subtract.
    let sub_lhs = batch.upload(&[5.0, 10.0, 15.0, 20.0]);
    let sub_rhs = batch.upload(&[1.0, 2.0, 3.0, 4.0]);
    let sub_handle = batch.sub(sub_lhs, sub_rhs);

    // Three activations chained back to back.
    let chain_in = batch.upload(&[-2.0, -1.0, 0.0, 1.0, 2.0]);
    let chain_stage1 = batch.relu(chain_in);
    let chain_stage2 = batch.sigmoid(chain_stage1);
    let chain_handle = batch.tanh(chain_stage2);

    // --- Execute phase -----------------------------------------------------

    batch.execute().await.unwrap();

    // --- Verify phase ------------------------------------------------------

    // relu zeroes the -3.0, scale doubles, add shifts by 0.5.
    let add_out = batch.read(add_handle).await.unwrap();
    assert_eq!(add_out.len(), 4);
    assert!((add_out[0] - 2.5).abs() < 1e-5);
    assert!((add_out[1] - 4.5).abs() < 1e-5);
    assert!((add_out[2] - 0.5).abs() < 1e-5);
    assert!((add_out[3] - 8.5).abs() < 1e-5);

    let mul_out = batch.read(mul_handle).await.unwrap();
    assert_eq!(mul_out, vec![2.0, 6.0, 12.0, 20.0]);

    // Only checks that the dot read-back produced something; the value
    // itself is not pinned by this test.
    let dot_out = batch.read(dot_handle).await.unwrap();
    assert!(!dot_out.is_empty());

    // Sigmoid is spot-checked at inputs -2, 0, and 2.
    let sigmoid_out = batch.read(sigmoid_handle).await.unwrap();
    assert_eq!(sigmoid_out.len(), 5);
    assert!((sigmoid_out[0] - 0.119).abs() < 0.01);
    assert!((sigmoid_out[2] - 0.5).abs() < 0.01);
    assert!((sigmoid_out[4] - 0.881).abs() < 0.01);

    let tanh_out = batch.read(tanh_handle).await.unwrap();
    assert_eq!(tanh_out.len(), 3);
    assert!((tanh_out[0] - (-0.762)).abs() < 0.01);
    assert!(tanh_out[1].abs() < 0.01);
    assert!((tanh_out[2] - 0.762).abs() < 0.01);

    // Swish is checked at inputs 0 and 1 only.
    let swish_out = batch.read(swish_handle).await.unwrap();
    assert_eq!(swish_out.len(), 3);
    assert!(swish_out[0].abs() < 0.01);
    assert!((swish_out[1] - 0.731).abs() < 0.01);

    // GELU is checked at inputs 0 and 1 only, with a looser tolerance at 1.
    let gelu_out = batch.read(gelu_handle).await.unwrap();
    assert_eq!(gelu_out.len(), 3);
    assert!(gelu_out[1].abs() < 0.01);
    assert!((gelu_out[2] - 0.841).abs() < 0.05);

    let sub_out = batch.read(sub_handle).await.unwrap();
    assert_eq!(sub_out, vec![4.0, 8.0, 12.0, 16.0]);

    // The chained result is only range-checked against [-1, 1].
    let chain_out = batch.read(chain_handle).await.unwrap();
    assert_eq!(chain_out.len(), 5);
    for val in chain_out.iter().copied() {
        assert!((-1.0..=1.0).contains(&val), "Value {} out of range", val);
    }
}