cuda_rust_wasm/backend/
wasm_runtime.rs1use super::backend_trait::{BackendTrait, BackendCapabilities, MemcpyKind};
4use crate::{Result, runtime_error};
5use std::sync::Arc;
6use async_trait::async_trait;
7
8pub struct WasmRuntime {
10 capabilities: BackendCapabilities,
11}
12
13impl Default for WasmRuntime {
14 fn default() -> Self {
15 Self::new()
16 }
17}
18
19impl WasmRuntime {
20 pub fn new() -> Self {
22 Self {
23 capabilities: BackendCapabilities {
24 name: "WASM Runtime".to_string(),
25 supports_cuda: false,
26 supports_opencl: false,
27 supports_vulkan: false,
28 supports_webgpu: false,
29 max_threads: 1,
30 max_threads_per_block: 1,
31 max_blocks_per_grid: 1,
32 max_shared_memory: 0,
33 supports_dynamic_parallelism: false,
34 supports_unified_memory: false,
35 max_grid_dim: [1, 1, 1],
36 max_block_dim: [1, 1, 1],
37 warp_size: 1,
38 },
39 }
40 }
41}
42
43#[async_trait]
44impl BackendTrait for WasmRuntime {
45 fn name(&self) -> &str {
46 &self.capabilities.name
47 }
48 fn capabilities(&self) -> &BackendCapabilities {
49 &self.capabilities
50 }
51
52 async fn initialize(&mut self) -> Result<()> {
53 Ok(())
55 }
56
57 async fn compile_kernel(&self, _source: &str) -> Result<Vec<u8>> {
58 Err(runtime_error!("Kernel compilation not supported on WASM runtime backend"))
60 }
61
62 async fn launch_kernel(
63 &self,
64 _kernel: &[u8],
65 _grid: (u32, u32, u32),
66 _block: (u32, u32, u32),
67 _args: &[*const u8],
68 ) -> Result<()> {
69 Err(runtime_error!("Kernel launch not supported on WASM runtime backend"))
70 }
71
72 fn allocate_memory(&self, size: usize) -> Result<*mut u8> {
73 let layout = std::alloc::Layout::from_size_align(size, 8)
75 .map_err(|e| runtime_error!("Invalid layout: {}", e))?;
76
77 let ptr = unsafe { std::alloc::alloc(layout) };
78
79 if ptr.is_null() {
80 return Err(runtime_error!("Failed to allocate {} bytes", size));
81 }
82
83 Ok(ptr)
84 }
85
86 fn free_memory(&self, ptr: *mut u8) -> Result<()> {
87 Ok(())
91 }
92
93 fn copy_memory(
94 &self,
95 dst: *mut u8,
96 src: *const u8,
97 size: usize,
98 _kind: MemcpyKind,
99 ) -> Result<()> {
100 unsafe {
103 std::ptr::copy_nonoverlapping(src, dst, size);
104 }
105 Ok(())
106 }
107
108 fn synchronize(&self) -> Result<()> {
109 Ok(())
111 }
112
113}