#[cfg(test)]
#[cfg(feature = "cuda")]
mod tests {
use super::*;
fn create_executor() -> Option<CudaExecutor> {
CudaExecutor::new(0).ok()
}
#[test]
fn test_gemm_identity_matrix() {
let Some(mut exec) = create_executor() else {
return;
};
let a = vec![1.0, 0.0, 0.0, 1.0];
let b = vec![3.0, 4.0, 5.0, 6.0];
let mut c = vec![0.0; 4];
exec.gemm(&a, &b, &mut c, 2, 2, 2).unwrap();
assert!((c[0] - 3.0).abs() < 0.1);
assert!((c[3] - 6.0).abs() < 0.1);
}
#[test]
fn test_gemm_small_matrix() {
let Some(mut exec) = create_executor() else {
return;
};
let a = vec![1.0, 2.0, 3.0, 4.0];
let b = vec![5.0, 6.0, 7.0, 8.0];
let mut c = vec![0.0; 4];
exec.gemm(&a, &b, &mut c, 2, 2, 2).unwrap();
assert!((c[0] - 19.0).abs() < 1.0);
assert!((c[3] - 50.0).abs() < 1.0);
}
#[test]
fn test_gemm_m1_uses_gemv() {
let Some(mut exec) = create_executor() else {
return;
};
let a = vec![1.0, 2.0, 3.0, 4.0]; let b = vec![1.0, 0.5, 0.25, 0.125, 2.0, 1.0, 0.5, 0.25]; let mut c = vec![0.0; 2]; exec.gemm(&a, &b, &mut c, 1, 2, 4).unwrap();
assert!(c[0] > 0.0);
}
#[test]
fn test_gemm_size_mismatch_error() {
let Some(mut exec) = create_executor() else {
return;
};
let a = vec![1.0, 2.0]; let b = vec![1.0, 2.0, 3.0, 4.0];
let mut c = vec![0.0; 4];
let result = exec.gemm(&a, &b, &mut c, 2, 2, 2);
assert!(result.is_err());
}
#[test]
fn test_gemm_optimized_tile32() {
let Some(mut exec) = create_executor() else {
return;
};
let a = vec![1.0; 64]; let b = vec![1.0; 64]; let mut c = vec![0.0; 64]; exec.gemm_optimized(&a, &b, &mut c, 8, 8, 8, 32).unwrap();
assert!(c[0] > 0.0);
}
#[test]
fn test_gemm_fused_no_activation() {
let Some(mut exec) = create_executor() else {
return;
};
let a = vec![1.0; 4];
let b = vec![1.0; 4];
let mut c = vec![0.0; 4];
exec.gemm_fused(&a, &b, None, &mut c, 2, 2, 2, 0).unwrap();
assert!(c[0] > 0.0);
}
#[test]
fn test_gemm_fused_with_relu() {
let Some(mut exec) = create_executor() else {
return;
};
let a = vec![1.0, -1.0, 1.0, -1.0]; let b = vec![1.0; 4];
let mut c = vec![0.0; 4];
exec.gemm_fused(&a, &b, None, &mut c, 2, 2, 2, 1).unwrap();
for val in &c {
assert!(*val >= 0.0);
}
}
#[test]
fn test_gemm_fused_with_bias() {
let Some(mut exec) = create_executor() else {
return;
};
let a = vec![1.0; 4];
let b = vec![1.0; 4];
let bias = vec![10.0, 20.0];
let mut c = vec![0.0; 4];
exec.gemm_fused(&a, &b, Some(&bias), &mut c, 2, 2, 2, 0)
.unwrap();
assert!(c[0] > 10.0);
}
#[test]
fn test_gemm_fused_bias_size_mismatch() {
let Some(mut exec) = create_executor() else {
return;
};
let a = vec![1.0; 4];
let b = vec![1.0; 4];
let bias = vec![10.0]; let mut c = vec![0.0; 4];
let result = exec.gemm_fused(&a, &b, Some(&bias), &mut c, 2, 2, 2, 0);
assert!(result.is_err());
}
#[test]
fn test_softmax_uniform() {
let Some(mut exec) = create_executor() else {
return;
};
let mut data = vec![1.0, 1.0, 1.0, 1.0];
exec.softmax(&mut data).unwrap();
let expected = 0.25;
for val in &data {
assert!((*val - expected).abs() < 0.01);
}
}
#[test]
fn test_softmax_single_max() {
let Some(mut exec) = create_executor() else {
return;
};
let mut data = vec![10.0, 0.0, 0.0, 0.0];
exec.softmax(&mut data).unwrap();
assert!(data[0] > 0.9);
let sum: f32 = data.iter().sum();
assert!((sum - 1.0).abs() < 0.01);
}
#[test]
fn test_softmax_sum_to_one() {
let Some(mut exec) = create_executor() else {
return;
};
let mut data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
exec.softmax(&mut data).unwrap();
let sum: f32 = data.iter().sum();
assert!((sum - 1.0).abs() < 0.01);
}
#[test]
fn test_softmax_large_values() {
let Some(mut exec) = create_executor() else {
return;
};
let mut data = vec![100.0, 101.0, 102.0, 103.0];
exec.softmax(&mut data).unwrap();
let sum: f32 = data.iter().sum();
assert!((sum - 1.0).abs() < 0.01);
}
#[test]
fn test_allocate_buffer_basic() {
let Some(exec) = create_executor() else {
return;
};
let buf = exec.allocate_buffer(1024).unwrap();
assert_eq!(buf.len(), 1024);
}
#[test]
fn test_allocate_buffer_small() {
let Some(exec) = create_executor() else {
return;
};
let buf = exec.allocate_buffer(1).unwrap();
assert_eq!(buf.len(), 1);
}
#[test]
fn test_allocate_buffer_large() {
let Some(exec) = create_executor() else {
return;
};
let buf = exec.allocate_buffer(1024 * 1024).unwrap();
assert_eq!(buf.len(), 1024 * 1024);
}
#[test]
fn test_synchronize_compute() {
let Some(exec) = create_executor() else {
return;
};
assert!(exec.synchronize_compute().is_ok());
}
#[test]
fn test_synchronize_transfer() {
let Some(exec) = create_executor() else {
return;
};
assert!(exec.synchronize_transfer().is_ok());
}
#[test]
fn test_synchronize_all() {
let Some(exec) = create_executor() else {
return;
};
assert!(exec.synchronize_all().is_ok());
}
#[test]
fn test_gemm_cached_weight_not_found() {
let Some(mut exec) = create_executor() else {
return;
};
let b = vec![1.0; 4];
let mut c = vec![0.0; 4];
let result = exec.gemm_cached("nonexistent_weight", &b, &mut c, 2, 2, 2);
assert!(result.is_err());
}
#[test]
fn test_gemm_b_cached_weight_not_found() {
let Some(mut exec) = create_executor() else {
return;
};
let a = vec![1.0; 4];
let mut c = vec![0.0; 4];
let result = exec.gemm_b_cached("nonexistent_weight", &a, &mut c, 2, 2, 2);
assert!(result.is_err());
}
#[test]
fn test_gemv_cached_weight_not_found() {
let Some(mut exec) = create_executor() else {
return;
};
let x = vec![1.0; 4];
let mut y = vec![0.0; 4];
let result = exec.gemv_cached("nonexistent_weight", &x, &mut y, 4, 4);
assert!(result.is_err());
}
#[test]
fn test_gemv_cached_input_size_mismatch() {
let Some(mut exec) = create_executor() else {
return;
};
let x = vec![1.0; 2]; let mut y = vec![0.0; 4];
let result = exec.gemv_cached("test", &x, &mut y, 4, 4);
assert!(result.is_err());
}
#[test]
fn test_gemv_cached_output_size_mismatch() {
let Some(mut exec) = create_executor() else {
return;
};
let x = vec![1.0; 4];
let mut y = vec![0.0; 2]; let result = exec.gemv_cached("test", &x, &mut y, 4, 4);
assert!(result.is_err());
}
#[test]
fn test_async_copy_roundtrip() {
let Some(exec) = create_executor() else {
return;
};
let data = vec![1.0f32, 2.0, 3.0, 4.0];
let mut buf = exec.allocate_buffer(4).unwrap();
unsafe {
exec.copy_to_gpu_async(&mut buf, &data).unwrap();
}
exec.synchronize_transfer().unwrap();
let mut result = vec![0.0f32; 4];
unsafe {
exec.copy_from_gpu_async(&buf, &mut result).unwrap();
}
exec.synchronize_transfer().unwrap();
assert_eq!(result, data);
}
}