use cuda_rust_wasm::{
runtime::{Runtime, Device, BackendType, Grid, Block, Dim3},
kernel::{launch_kernel, LaunchConfig, KernelFunction, ThreadContext},
memory::{MemoryPool, PoolConfig, PoolStats, DeviceBuffer},
error::CudaRustError,
};
use std::sync::Arc;
use std::thread;
use std::time::Instant;
#[cfg(test)]
mod runtime_tests {
use super::*;
/// `Runtime::new()` must return `Ok`.
#[test]
fn test_runtime_initialization() {
    assert!(
        Runtime::new().is_ok(),
        "Runtime should initialize successfully"
    );
}
/// The device exposed by a freshly created runtime must report a non-empty
/// name and a positive `max_threads_per_block`.
#[test]
fn test_runtime_device_access() {
    let rt = Runtime::new().unwrap();
    let dev = rt.device();
    let properties = dev.properties();

    assert!(!properties.name.is_empty(), "Device should have a name");
    assert!(properties.max_threads_per_block > 0);
}
/// `synchronize()` on a freshly created runtime must return `Ok`.
#[test]
fn test_runtime_synchronize() {
    let rt = Runtime::new().unwrap();
    assert!(rt.synchronize().is_ok(), "Synchronize should succeed");
}
/// `create_stream()` on a freshly created runtime must return `Ok`.
#[test]
fn test_runtime_create_stream() {
    let rt = Runtime::new().unwrap();
    assert!(rt.create_stream().is_ok(), "Stream creation should succeed");
}
/// The default device must be obtainable and must report id `0`.
#[test]
fn test_device_default() {
    let default_device =
        Device::get_default().expect("Default device should be available");
    assert_eq!(default_device.id(), 0, "Default device should have id 0");
}
/// `Device::count()` must succeed and report at least one device.
#[test]
fn test_device_count() {
    // `unwrap` fails the test on `Err`, covering the "count is Ok" check.
    let available = Device::count().unwrap();
    assert!(available >= 1, "Should have at least one device");
}
/// The default device's backend must be one of the known `BackendType`
/// variants.
#[test]
fn test_device_backend_type() {
    let device = Device::get_default().unwrap();

    // No wildcard arm: the match stays exhaustive over exactly these
    // variants, so there is nothing further to assert at runtime.
    match device.backend() {
        BackendType::CPU => {}
        BackendType::Native => {}
        BackendType::WebGPU => {}
    }
}
/// The default device's properties must carry a non-empty name and positive
/// `max_threads_per_block` / `max_blocks_per_grid` limits.
#[test]
fn test_device_properties() {
    let default_device = Device::get_default().unwrap();
    let properties = default_device.properties();

    assert!(!properties.name.is_empty());
    assert!(properties.max_threads_per_block > 0);
    assert!(properties.max_blocks_per_grid > 0);
}
/// Allocating a 1024-element `f32` buffer on the default device must succeed
/// and yield a non-empty buffer of that length.
#[test]
fn test_device_buffer_allocation() {
    let device = Device::get_default().unwrap();
    let allocated = DeviceBuffer::<f32>::new(1024, device)
        .expect("Buffer allocation should succeed");

    assert_eq!(allocated.len(), 1024);
    assert!(!allocated.is_empty());
}
/// Requesting a zero-element buffer must be rejected with `Err`.
#[test]
fn test_device_buffer_zero_length() {
    let device = Device::get_default().unwrap();
    let outcome = DeviceBuffer::<f32>::new(0, device);
    assert!(outcome.is_err(), "Zero-length buffer should fail");
}
/// Data copied into a buffer with `copy_from_host` must come back unchanged
/// through `copy_to_host`.
#[test]
fn test_device_buffer_copy_roundtrip() {
    const LEN: usize = 256;

    let host_data: Vec<f32> = (0..LEN).map(|i| i as f32 * 1.5).collect();
    let mut readback = vec![0.0f32; LEN];

    let device = Device::get_default().unwrap();
    let mut buffer = DeviceBuffer::<f32>::new(LEN, device).unwrap();
    buffer.copy_from_host(&host_data).unwrap();
    buffer.copy_to_host(&mut readback).unwrap();

    assert_eq!(host_data, readback, "Data should survive round-trip copy");
}
/// Copying a 50-element host slice into a 100-element buffer must return
/// `Err`.
#[test]
fn test_device_buffer_copy_length_mismatch() {
    let short_input = vec![0.0f32; 50];

    let device = Device::get_default().unwrap();
    let mut buffer = DeviceBuffer::<f32>::new(100, device).unwrap();

    assert!(
        buffer.copy_from_host(&short_input).is_err(),
        "Mismatched copy should fail"
    );
}
/// After `fill(42.0)`, every element read back from the buffer must equal
/// `42.0`.
#[test]
fn test_device_buffer_fill() {
    let n = 100;
    let device = Device::get_default().unwrap();

    let mut buffer = DeviceBuffer::<f32>::new(n, device).unwrap();
    buffer.fill(42.0f32).unwrap();

    let mut readback = vec![0.0f32; n];
    buffer.copy_to_host(&mut readback).unwrap();

    // Whole-vector comparison: equal lengths, so this is the same
    // element-by-element equality check.
    assert_eq!(readback, vec![42.0f32; n]);
}
/// Four threads, released together by a barrier, each fetch the default
/// device five times and check that its name is non-empty. Any panic inside a
/// worker fails the test via `join().unwrap()`.
#[test]
fn test_concurrent_device_access() {
    const WORKERS: usize = 4;
    const ROUNDS: usize = 5;

    let start_gate = Arc::new(std::sync::Barrier::new(WORKERS));

    // Spawn every worker before joining any of them.
    let mut workers = Vec::with_capacity(WORKERS);
    for _ in 0..WORKERS {
        let gate = Arc::clone(&start_gate);
        workers.push(thread::spawn(move || {
            // Wait until all workers have been started.
            gate.wait();
            for _ in 0..ROUNDS {
                let device = Device::get_default().unwrap();
                let props = device.properties();
                assert!(!props.name.is_empty());
            }
        }));
    }

    for worker in workers {
        worker.join().unwrap();
    }
}
/// A single allocate/deallocate cycle must hand out a buffer of the requested
/// length and be counted in the pool statistics.
#[test]
fn test_memory_pool_basic() {
    let pool = MemoryPool::new();

    let block = pool.allocate(4096);
    assert_eq!(block.len(), 4096);
    pool.deallocate(block);

    assert!(pool.stats().total_allocations >= 1);
}
/// Allocating the same size again after returning a buffer must leave the
/// pool with a positive `hit_ratio()`.
#[test]
fn test_memory_pool_reuse() {
    let pool = MemoryPool::new();

    // First cycle: allocate and hand the buffer straight back.
    pool.deallocate(pool.allocate(2048));

    // Second cycle with the same size.
    let second = pool.allocate(2048);
    assert_eq!(second.len(), 2048);
    pool.deallocate(second);

    assert!(pool.hit_ratio() > 0.0, "Should have cache hits after reuse");
}
/// Ten allocate/deallocate cycles of length 1024 must be reflected in the
/// pool's allocation count, cache-hit count and allocated-bytes total.
#[test]
fn test_memory_pool_stats() {
    let pool = MemoryPool::new();
    for _ in 0..10 {
        pool.deallocate(pool.allocate(1024));
    }

    let snapshot = pool.stats();
    assert_eq!(snapshot.total_allocations, 10);
    assert!(snapshot.cache_hits > 0, "Should have cache hits");
    assert!(snapshot.total_bytes_allocated >= 10 * 1024);
}
/// `clear()` must reset the allocation and cache-hit counters to zero.
#[test]
fn test_memory_pool_clear() {
    let pool = MemoryPool::new();

    // Make sure there is something recorded before clearing.
    pool.deallocate(pool.allocate(4096));
    assert!(pool.stats().total_allocations > 0);

    pool.clear();

    let after_clear = pool.stats();
    assert_eq!(after_clear.total_allocations, 0);
    assert_eq!(after_clear.cache_hits, 0);
}
/// A pool built from an explicit `PoolConfig` must still hand out a buffer of
/// the requested length.
#[test]
fn test_memory_pool_custom_config() {
    let config = PoolConfig {
        max_pool_size: 4 * 1024 * 1024,
        min_pooled_size: 256,
        // Written without a leading `1 *` factor, which is a no-op that
        // `clippy::identity_op` reports.
        max_pooled_size: 1024 * 1024,
        prealloc_count: 4,
    };
    let pool = MemoryPool::with_config(config);

    // 512 lies between `min_pooled_size` and `max_pooled_size`.
    let buf = pool.allocate(512);
    assert_eq!(buf.len(), 512);
    pool.deallocate(buf);
}
/// Launches a kernel that doubles each element of a shared vector and checks
/// that every element of `[0, 1, ..., 31]` ends up doubled.
#[test]
fn test_kernel_launch_simple() {
    /// Kernel that doubles the element at the calling thread's global id.
    struct DoubleKernel;

    impl KernelFunction<Arc<std::sync::Mutex<Vec<f32>>>> for DoubleKernel {
        fn execute(&self, args: Arc<std::sync::Mutex<Vec<f32>>>, ctx: ThreadContext) {
            let idx = ctx.global_thread_id();
            let mut data = args.lock().unwrap();
            // Thread ids past the end of the data are ignored.
            if idx < data.len() {
                data[idx] *= 2.0;
            }
        }

        fn name(&self) -> &str {
            "double_kernel"
        }
    }

    let n = 32;
    let data = Arc::new(std::sync::Mutex::new(
        (0..n).map(|i| i as f32).collect::<Vec<f32>>(),
    ));

    // Launch configuration built from `Grid::new(1)` and `Block::new(n)`.
    let config = LaunchConfig::new(Grid::new(1u32), Block::new(n as u32));
    let result = launch_kernel(DoubleKernel, config, Arc::clone(&data));
    assert!(result.is_ok(), "Kernel launch should succeed");

    let result_data = data.lock().unwrap();
    // Pin the length explicitly so the element check below is known to cover
    // all `n` inputs.
    assert_eq!(
        result_data.len(),
        n,
        "Kernel must not change the number of elements"
    );
    for (i, value) in result_data.iter().enumerate() {
        assert_eq!(*value, (i as f32) * 2.0, "Element {} should be doubled", i);
    }
}
/// `with_shared_memory` must record the requested size while leaving the grid
/// and block dimensions as constructed.
#[test]
fn test_launch_config_with_shared_memory() {
    let grid = Grid::new(4u32);
    let block = Block::new(128u32);
    let config = LaunchConfig::new(grid, block).with_shared_memory(1024);

    assert_eq!(config.shared_memory_bytes, 1024);
    assert_eq!(config.grid.dim.x, 4);
    assert_eq!(config.block.dim.x, 128);
}
/// Grid and block built from `(x, y)` tuples must expose those values through
/// `dim.x` and `dim.y`.
#[test]
fn test_launch_config_2d() {
    let grid = Grid::new((4u32, 4u32));
    let block = Block::new((16u32, 16u32));
    let config = LaunchConfig::new(grid, block);

    assert_eq!(config.grid.dim.x, 4);
    assert_eq!(config.grid.dim.y, 4);
    assert_eq!(config.block.dim.x, 16);
    assert_eq!(config.block.dim.y, 16);
}
/// A block built from `256` must pass `validate()`.
#[test]
fn test_block_validation_succeeds() {
    assert!(Block::new(256u32).validate().is_ok());
}
/// A block built from `2048` must be rejected by `validate()`.
#[test]
fn test_block_validation_exceeds_limit() {
    let oversized = Block::new(2048u32);
    let verdict = oversized.validate();
    assert!(
        verdict.is_err(),
        "Block size 2048 should exceed max threads per block"
    );
}
/// Times 1000 allocate/deallocate cycles on a fresh pool, prints the measured
/// throughput, and fails only if the loop takes 10 whole seconds or more.
#[test]
fn test_performance_memory_pool_allocation() {
    let pool = MemoryPool::new();
    let iterations = 1000;

    let timer = Instant::now();
    for _ in 0..iterations {
        pool.deallocate(pool.allocate(4096));
    }
    let elapsed = timer.elapsed();

    let ops_per_sec = iterations as f64 / elapsed.as_secs_f64();
    println!(
        "Memory pool: {} allocations/deallocations in {:?} ({:.0} ops/sec)",
        iterations, elapsed, ops_per_sec
    );

    // Deliberately generous bound; it only catches gross slowdowns.
    assert!(
        elapsed.as_secs() < 10,
        "Memory pool operations should be fast"
    );
}
/// The string form of `RuntimeError` and `MemoryError` must contain the
/// message each was constructed with.
#[test]
fn test_error_types() {
    let runtime_err = CudaRustError::RuntimeError(String::from("test error"));
    assert!(
        runtime_err.to_string().contains("test error"),
        "Error message should contain the original text"
    );

    let memory_err = CudaRustError::MemoryError(String::from("out of memory"));
    assert!(memory_err.to_string().contains("out of memory"));
}
/// Host-to-buffer-to-host round trips must preserve data for `f32`, `u32` and
/// `i64` element types.
#[test]
fn test_device_buffer_multiple_types() {
    let device = Device::get_default().unwrap();

    // f32: 1.0 ..= 10.0
    let floats_in: Vec<f32> = (1..=10).map(|v| v as f32).collect();
    let mut floats_out = vec![0.0f32; 10];
    let mut float_buf = DeviceBuffer::<f32>::new(10, device.clone()).unwrap();
    float_buf.copy_from_host(&floats_in).unwrap();
    float_buf.copy_to_host(&mut floats_out).unwrap();
    assert_eq!(floats_in, floats_out);

    // u32
    let uints_in: Vec<u32> = vec![100, 200, 300, 400, 500];
    let mut uints_out = vec![0u32; 5];
    let mut uint_buf = DeviceBuffer::<u32>::new(5, device.clone()).unwrap();
    uint_buf.copy_from_host(&uints_in).unwrap();
    uint_buf.copy_to_host(&mut uints_out).unwrap();
    assert_eq!(uints_in, uints_out);

    // i64, including a negative value; the device handle is moved into this
    // last buffer.
    let longs_in: Vec<i64> = vec![-1, 0, 1];
    let mut longs_out = vec![0i64; 3];
    let mut long_buf = DeviceBuffer::<i64>::new(3, device).unwrap();
    long_buf.copy_from_host(&longs_in).unwrap();
    long_buf.copy_to_host(&mut longs_out).unwrap();
    assert_eq!(longs_in, longs_out);
}
/// The module-level `allocate`/`deallocate` helpers must hand out a buffer of
/// the requested length, and the pool returned by `global_pool()` must report
/// at least one allocation afterwards.
#[test]
fn test_global_memory_pool() {
    use cuda_rust_wasm::memory::{allocate, deallocate, global_pool};

    let buf = allocate(2048);
    assert_eq!(buf.len(), 2048);
    deallocate(buf);

    let shared_pool = global_pool();
    assert!(shared_pool.stats().total_allocations > 0);
}
}