use super::backend_trait::{BackendCapabilities, BackendTrait, MemcpyKind};
use crate::{runtime_error, Result};
use async_trait::async_trait;
use parking_lot::Mutex;
use std::collections::HashMap;
use std::sync::Arc;
/// Backend that stores WASM modules, launches them through an external
/// command-line runtime, and serves memory requests from the host allocator.
pub struct WasmRuntime {
/// Capability record returned by `capabilities()`; filled in once by `new()`.
capabilities: BackendCapabilities,
/// Live allocations: pointer address (as `usize`) -> requested size in bytes.
allocations: Mutex<HashMap<usize, usize>>,
/// Stored module sources; a kernel handle is the little-endian `u32` index
/// of an entry in this vector.
compiled_modules: Mutex<Vec<Vec<u8>>>,
}
impl Default for WasmRuntime {
fn default() -> Self {
Self::new()
}
}
/// Alignment, in bytes, passed to `Layout::from_size_align` for every
/// allocation and deallocation performed by this backend.
const ALLOC_ALIGN: usize = 8;
impl WasmRuntime {
/// Creates a runtime with empty allocation and module tables.
///
/// The thread-related capability fields are taken from
/// `std::thread::available_parallelism()`, falling back to 1 when that
/// query fails.
pub fn new() -> Self {
    let cpu_count = match std::thread::available_parallelism() {
        Ok(parallelism) => parallelism.get() as u32,
        Err(_) => 1,
    };

    let capabilities = BackendCapabilities {
        name: String::from("WASM Runtime"),
        supports_cuda: false,
        supports_opencl: false,
        supports_vulkan: false,
        supports_webgpu: false,
        max_threads: cpu_count,
        max_threads_per_block: cpu_count,
        max_blocks_per_grid: 1024,
        max_shared_memory: 64 * 1024,
        supports_dynamic_parallelism: false,
        supports_unified_memory: true,
        max_grid_dim: [1024, 1024, 1],
        max_block_dim: [cpu_count, 1, 1],
        warp_size: 1,
    };

    Self {
        capabilities,
        allocations: Mutex::new(HashMap::new()),
        compiled_modules: Mutex::new(Vec::new()),
    }
}
/// Test-only helper: number of entries currently in the allocation table.
#[cfg(test)]
fn allocation_count(&self) -> usize {
    let table = self.allocations.lock();
    table.len()
}
/// Test-only helper: number of modules stored so far.
#[cfg(test)]
fn module_count(&self) -> usize {
    let modules = self.compiled_modules.lock();
    modules.len()
}
/// Looks for a usable WASM command-line runtime.
///
/// Each candidate is invoked as `<name> --version`. The first one that both
/// starts and exits successfully is returned; `wasmtime` is tried before
/// `wasmer`. Returns `None` when neither candidate is usable.
///
/// Every call spawns up to two short-lived child processes.
fn detect_wasm_runtime() -> Option<&'static str> {
    const CANDIDATES: [&str; 2] = ["wasmtime", "wasmer"];

    CANDIDATES.iter().copied().find(|&candidate| {
        std::process::Command::new(candidate)
            .arg("--version")
            .output()
            // A binary that launches but exits non-zero is broken, so it is
            // treated the same as a missing one.
            .map(|probe| probe.status.success())
            .unwrap_or(false)
    })
}
}
#[async_trait(?Send)]
impl BackendTrait for WasmRuntime {
/// Returns the backend name stored in the capability record.
fn name(&self) -> &str {
    self.capabilities.name.as_str()
}
fn capabilities(&self) -> &BackendCapabilities {
&self.capabilities
}
/// Performs no work for this backend and always returns `Ok(())`.
async fn initialize(&mut self) -> Result<()> {
Ok(())
}
/// Validates and stores a WASM module, returning a handle to it.
///
/// `source` is accepted when it either starts with the WASM binary magic
/// (`\0asm`) or, after leading whitespace, starts with `(module` (WAT text).
/// The bytes are stored exactly as given; WAT text is not translated here.
///
/// Returns the module's position in the module table encoded as four
/// little-endian bytes.
///
/// # Errors
/// Fails when `source` matches neither format, or when the module table is
/// so large that the next index does not fit in the `u32` handle.
async fn compile_kernel(&self, source: &str) -> Result<Vec<u8>> {
    let bytes = source.as_bytes();
    let is_binary = bytes.starts_with(b"\0asm");
    let is_wat = source.trim_start().starts_with("(module");
    if !(is_binary || is_wat) {
        return Err(runtime_error!(
            "Invalid WASM module: expected WASM binary (\\0asm magic) or WAT text ((module prefix)"
        ));
    }

    let mut modules = self.compiled_modules.lock();
    // Checked conversion, done before the push: a truncating cast would hand
    // out a handle that aliases an earlier module.
    let handle = <u32 as std::convert::TryFrom<usize>>::try_from(modules.len()).map_err(|_| {
        runtime_error!(
            "Too many compiled modules: index {} does not fit in a u32 handle",
            modules.len()
        )
    })?;
    modules.push(bytes.to_vec());
    Ok(handle.to_le_bytes().to_vec())
}
/// Runs a stored module through an external WASM runtime.
///
/// `kernel` is a handle as produced by `compile_kernel`: its first four
/// bytes are read as a little-endian `u32` index into the module table
/// (any further bytes are ignored). The grid, block and argument parameters
/// are currently unused.
///
/// The module bytes are written to a temporary file, the runtime reported by
/// `detect_wasm_runtime` is invoked with that file as its only argument, and
/// the file is removed again whether or not the invocation succeeded.
///
/// The file system calls and `Command::output` block the calling thread
/// until the child process has finished.
///
/// # Errors
/// Fails when the handle is shorter than four bytes, the index is unknown,
/// no runtime is installed, the temporary file cannot be written, the
/// runtime cannot be started, or the runtime exits with a non-zero status
/// (its trimmed stderr is included in the message).
async fn launch_kernel(
    &self,
    kernel: &[u8],
    _grid: (u32, u32, u32),
    _block: (u32, u32, u32),
    _args: &[*const u8],
) -> Result<()> {
    // Per-process sequence number so that concurrent launches of the same
    // module never share (and delete) each other's temporary file.
    static LAUNCH_SEQ: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);

    let index = match kernel {
        [b0, b1, b2, b3, ..] => u32::from_le_bytes([*b0, *b1, *b2, *b3]) as usize,
        _ => {
            return Err(runtime_error!(
                "Invalid kernel handle: expected 4-byte module index"
            ));
        }
    };

    // The lock guard is a temporary of this statement, so it is released
    // before any file or process work starts.
    let module_bytes = self
        .compiled_modules
        .lock()
        .get(index)
        .cloned()
        .ok_or_else(|| runtime_error!("Module index {} not found", index))?;

    let runtime_name = Self::detect_wasm_runtime().ok_or_else(|| {
        runtime_error!("No WASM runtime found (install wasmtime or wasmer)")
    })?;

    let launch_id = LAUNCH_SEQ.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    // NOTE(review): WAT text accepted by `compile_kernel` is also written
    // under a `.wasm` name — confirm both runtimes accept text input this way.
    let tmp_path = std::env::temp_dir().join(format!(
        "cuda_wasm_module_{}_{}_{}.wasm",
        std::process::id(),
        index,
        launch_id
    ));

    if let Err(e) = std::fs::write(&tmp_path, &module_bytes) {
        // A failed write may still have created the file; clean up best-effort.
        let _ = std::fs::remove_file(&tmp_path);
        return Err(runtime_error!("Failed to write temp WASM file: {}", e));
    }

    // Pass the path as an `OsStr` so a non-UTF-8 temp directory cannot make
    // the runtime execute some unrelated fallback file name.
    let run_result = std::process::Command::new(runtime_name)
        .arg(&tmp_path)
        .output();

    // Remove the file before looking at the result so that a spawn failure
    // does not leak it.
    let _ = std::fs::remove_file(&tmp_path);

    let output = run_result
        .map_err(|e| runtime_error!("Failed to execute {}: {}", runtime_name, e))?;

    if !output.status.success() {
        let stderr = String::from_utf8_lossy(&output.stderr);
        return Err(runtime_error!(
            "{} execution failed (exit {}): {}",
            runtime_name,
            output.status.code().unwrap_or(-1),
            stderr.trim()
        ));
    }
    Ok(())
}
/// Allocates `size` bytes from the host allocator and records the block.
///
/// The block is aligned to `ALLOC_ALIGN` and its contents are uninitialized.
///
/// # Errors
/// Fails for a zero `size`, for a size/alignment pair rejected by
/// `Layout::from_size_align`, or when the allocator returns null.
fn allocate_memory(&self, size: usize) -> Result<*mut u8> {
    if size == 0 {
        return Err(runtime_error!("Cannot allocate 0 bytes"));
    }

    let layout = match std::alloc::Layout::from_size_align(size, ALLOC_ALIGN) {
        Ok(layout) => layout,
        Err(e) => return Err(runtime_error!("Invalid layout: {}", e)),
    };

    // SAFETY: `layout` has a non-zero size, as checked above.
    let block = unsafe { std::alloc::alloc(layout) };
    if block.is_null() {
        return Err(runtime_error!("Failed to allocate {} bytes", size));
    }

    // Remember the size so `free_memory` can rebuild the same layout.
    let mut table = self.allocations.lock();
    table.insert(block as usize, size);
    Ok(block)
}
/// Releases a block previously returned by `allocate_memory`.
///
/// # Errors
/// Fails for a null pointer and for any pointer that is not currently in
/// the allocation table (including one that was already freed).
fn free_memory(&self, ptr: *mut u8) -> Result<()> {
    if ptr.is_null() {
        return Err(runtime_error!("Cannot free null pointer"));
    }

    // Removing the entry first means a second free of the same pointer is
    // rejected instead of reaching the allocator.
    let removed = self.allocations.lock().remove(&(ptr as usize));
    let size = match removed {
        Some(bytes) => bytes,
        None => {
            return Err(runtime_error!(
                "Pointer {:p} was not allocated by this backend",
                ptr
            ));
        }
    };

    let layout = match std::alloc::Layout::from_size_align(size, ALLOC_ALIGN) {
        Ok(layout) => layout,
        Err(e) => return Err(runtime_error!("Invalid layout during free: {}", e)),
    };

    // SAFETY: `ptr` was found in the table, so it came from `alloc` with
    // this same size and alignment and has not been freed since.
    unsafe { std::alloc::dealloc(ptr, layout) };
    Ok(())
}
/// Copies `size` bytes from `src` to `dst`.
///
/// `_kind` is ignored: every copy is a plain host-side byte copy. The source
/// and destination ranges are allowed to overlap.
///
/// The caller must ensure `src` is readable and `dst` is writable for
/// `size` bytes; that is not checked here.
///
/// # Errors
/// Fails when either pointer is null or when `size` is zero.
fn copy_memory(
    &self,
    dst: *mut u8,
    src: *const u8,
    size: usize,
    _kind: MemcpyKind,
) -> Result<()> {
    if dst.is_null() {
        return Err(runtime_error!("Destination pointer is null"));
    }
    if src.is_null() {
        return Err(runtime_error!("Source pointer is null"));
    }
    if size == 0 {
        return Err(runtime_error!("Copy size must be greater than 0"));
    }
    // SAFETY: both pointers are non-null; their validity for `size` bytes is
    // the caller's responsibility. `ptr::copy` (memmove semantics) is used
    // instead of `copy_nonoverlapping` because nothing here rules out
    // overlapping ranges, which would be undefined behaviour with the latter.
    unsafe {
        std::ptr::copy(src, dst, size);
    }
    Ok(())
}
/// Performs no work for this backend and always returns `Ok(())`.
fn synchronize(&self) -> Result<()> {
Ok(())
}
}
// SAFETY (review): the fields declared in this file are two
// `parking_lot::Mutex`-wrapped collections of integers/bytes plus a
// `BackendCapabilities` value. This assumes `BackendCapabilities` is itself
// `Send + Sync` — TODO confirm against its definition. If it is, these manual
// impls add nothing beyond the auto traits.
unsafe impl Send for WasmRuntime {}
unsafe impl Sync for WasmRuntime {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh runtime for each test case.
    fn runtime() -> WasmRuntime {
        WasmRuntime::new()
    }

    /// Decodes the little-endian module index stored in a kernel handle.
    fn handle_index(handle: &[u8]) -> u32 {
        u32::from_le_bytes([handle[0], handle[1], handle[2], handle[3]])
    }

    // ---- capabilities -------------------------------------------------

    #[test]
    fn test_capabilities_reflect_cpu_count() {
        let backend = runtime();
        let cpu_count = std::thread::available_parallelism().map_or(1, |n| n.get() as u32);

        let caps = backend.capabilities();
        assert_eq!(caps.max_threads, cpu_count);
        assert!(!caps.supports_cuda);
        assert!(caps.supports_unified_memory);
    }

    // ---- allocation ---------------------------------------------------

    #[test]
    fn test_allocate_and_free() {
        let backend = runtime();

        let block = backend
            .allocate_memory(1024)
            .expect("allocation should succeed");
        assert!(!block.is_null());
        assert_eq!(backend.allocation_count(), 1);

        backend.free_memory(block).expect("free should succeed");
        assert_eq!(backend.allocation_count(), 0);
    }

    #[test]
    fn test_allocate_zero_bytes_fails() {
        let outcome = runtime().allocate_memory(0);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_free_null_pointer_fails() {
        let outcome = runtime().free_memory(std::ptr::null_mut());
        assert!(outcome.is_err());
    }

    #[test]
    fn test_free_unknown_pointer_fails() {
        let backend = runtime();
        // Never returned by `allocate_memory`, so it must be rejected.
        let stray = 0xDEAD_BEEF as *mut u8;
        assert!(backend.free_memory(stray).is_err());
    }

    #[test]
    fn test_double_free_fails() {
        let backend = runtime();
        let block = backend.allocate_memory(64).unwrap();

        backend.free_memory(block).unwrap();
        let second_free = backend.free_memory(block);
        assert!(second_free.is_err());
    }

    // ---- copies -------------------------------------------------------

    #[test]
    fn test_copy_memory_roundtrip() {
        const LEN: usize = 4;
        let backend = runtime();
        let src = backend.allocate_memory(LEN).unwrap();
        let dst = backend.allocate_memory(LEN).unwrap();

        unsafe { std::ptr::write_bytes(src, 0xAB, LEN) };
        backend
            .copy_memory(dst, src, LEN, MemcpyKind::HostToHost)
            .expect("copy should succeed");

        let copied = unsafe { std::slice::from_raw_parts(dst as *const u8, LEN) };
        assert_eq!(copied, &[0xABu8; LEN][..]);

        backend.free_memory(src).unwrap();
        backend.free_memory(dst).unwrap();
    }

    #[test]
    fn test_copy_memory_null_dst_fails() {
        let backend = runtime();
        let src = backend.allocate_memory(4).unwrap();

        let outcome = backend.copy_memory(std::ptr::null_mut(), src, 4, MemcpyKind::HostToHost);
        assert!(outcome.is_err());

        backend.free_memory(src).unwrap();
    }

    #[test]
    fn test_copy_memory_null_src_fails() {
        let backend = runtime();
        let dst = backend.allocate_memory(4).unwrap();

        let outcome = backend.copy_memory(dst, std::ptr::null(), 4, MemcpyKind::HostToHost);
        assert!(outcome.is_err());

        backend.free_memory(dst).unwrap();
    }

    #[test]
    fn test_copy_memory_zero_size_fails() {
        let backend = runtime();
        let src = backend.allocate_memory(4).unwrap();
        let dst = backend.allocate_memory(4).unwrap();

        let outcome = backend.copy_memory(dst, src, 0, MemcpyKind::HostToHost);
        assert!(outcome.is_err());

        backend.free_memory(src).unwrap();
        backend.free_memory(dst).unwrap();
    }

    // ---- compilation --------------------------------------------------

    #[tokio::test]
    async fn test_compile_wasm_binary() {
        let backend = runtime();
        // Magic number followed by the version field.
        let binary = "\0asm\x01\x00\x00\x00";

        let handle = backend.compile_kernel(binary).await.unwrap();

        assert_eq!(handle.len(), 4);
        assert_eq!(backend.module_count(), 1);
        assert_eq!(handle_index(&handle), 0);
    }

    #[tokio::test]
    async fn test_compile_wat_text() {
        let backend = runtime();

        let handle = backend.compile_kernel("(module)").await.unwrap();

        assert_eq!(handle.len(), 4);
        assert_eq!(backend.module_count(), 1);
    }

    #[tokio::test]
    async fn test_compile_invalid_source_fails() {
        let backend = runtime();
        let outcome = backend.compile_kernel("not wasm at all").await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_compile_multiple_modules() {
        let backend = runtime();

        let first = backend.compile_kernel("(module)").await.unwrap();
        let second = backend.compile_kernel("(module (func))").await.unwrap();

        assert_eq!(handle_index(&first), 0);
        assert_eq!(handle_index(&second), 1);
        assert_eq!(backend.module_count(), 2);
    }

    // ---- launching ----------------------------------------------------

    #[tokio::test]
    async fn test_launch_kernel_invalid_handle() {
        let backend = runtime();
        // Two bytes only: too short to hold a module index.
        let short_handle = [0u8, 1];

        let outcome = backend
            .launch_kernel(&short_handle, (1, 1, 1), (1, 1, 1), &[])
            .await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_launch_kernel_missing_module() {
        let backend = runtime();
        let unknown = 99u32.to_le_bytes();

        let outcome = backend
            .launch_kernel(&unknown, (1, 1, 1), (1, 1, 1), &[])
            .await;

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("not found"));
    }

    // ---- misc ---------------------------------------------------------

    #[test]
    fn test_synchronize() {
        assert!(runtime().synchronize().is_ok());
    }

    #[test]
    fn test_default() {
        let backend = WasmRuntime::default();
        assert_eq!(backend.name(), "WASM Runtime");
    }
}