use super::{
    BridgeConfig, BufferHandle, CompiledKernel, DeviceInfo, GpuBackendTrait, GpuDevice,
    NeuralIntegrationError, NeuralOperation, NeuralResult, Precision, BindingType,
};
use crate::backend::backend_trait::BackendTrait;
use crate::runtime::Runtime;
use crate::transpiler::Transpiler;
use std::collections::HashMap;
use std::sync::{Arc, Mutex, RwLock};

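/// WebGPU compute backend that bridges transpiled CUDA kernels to `wgpu`
/// pipelines. `device` and `queue` stay `None` until `init_webgpu` succeeds
/// (on wasm32, initialization is deferred to the runtime).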
pub struct WebGpuBackend {
    device: Option<wgpu::Device>,
    queue: Option<wgpu::Queue>,
    adapter_info: Option<wgpu::AdapterInfo>,
    runtime: Arc<Runtime>,
    kernel_cache: Arc<RwLock<HashMap<String, CompiledKernel>>>,
    buffer_pool: Arc<Mutex<BufferPool>>,
    config: BridgeConfig,
}

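/// Pool of GPU buffers keyed by handle; returned buffers are kept in a free
/// list and reused to avoid repeated allocations.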
struct BufferPool {
    buffers: HashMap<BufferHandle, wgpu::Buffer>,
    free_buffers: Vec<(usize, BufferHandle)>,
    next_handle: u64,
}

impl BufferPool {
    fn new() -> Self {
        Self {
            buffers: HashMap::new(),
            free_buffers: Vec::new(),
            next_handle: 1,
        }
    }

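    /// Returns a free buffer of at least `size` bytes if one is available
    /// (first fit); otherwise allocates a new buffer with the given usage.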
    fn get_or_create(&mut self, device: &wgpu::Device, size: usize, usage: wgpu::BufferUsages) -> BufferHandle {
        if let Some(pos) = self.free_buffers.iter().position(|(s, _)| *s >= size) {
            let (_, handle) = self.free_buffers.remove(pos);
            return handle;
        }

        let buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("Neural operation buffer"),
            size: size as u64,
            usage,
            mapped_at_creation: false,
        });

        let handle = BufferHandle(self.next_handle);
        self.next_handle += 1;

        self.buffers.insert(handle, buffer);
        handle
    }

    fn return_buffer(&mut self, handle: BufferHandle, size: usize) {
        self.free_buffers.push((size, handle));
    }

    fn get_buffer(&self, handle: BufferHandle) -> Option<&wgpu::Buffer> {
        self.buffers.get(&handle)
    }
}

impl WebGpuBackend {
    pub fn new(config: &BridgeConfig) -> NeuralResult<Self> {
        let runtime = Arc::new(Runtime::new().map_err(|e| {
            NeuralIntegrationError::GpuInitError(format!("Failed to create runtime: {e}"))
        })?);

        let mut backend = Self {
            device: None,
            queue: None,
            adapter_info: None,
            runtime,
            kernel_cache: Arc::new(RwLock::new(HashMap::new())),
            buffer_pool: Arc::new(Mutex::new(BufferPool::new())),
            config: config.clone(),
        };

        if let Err(e) = backend.init_webgpu() {
            log::warn!("WebGPU initialization failed: {e}");
            if !config.auto_fallback {
                return Err(e);
            }
        }

        Ok(backend)
    }

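    /// Native initialization: blocks on the async adapter and device requests
    /// with `pollster`, honoring the configured power preference.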
    #[cfg(not(target_arch = "wasm32"))]
    fn init_webgpu(&mut self) -> NeuralResult<()> {
        use pollster::FutureExt;

        let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
            backends: wgpu::Backends::all(),
            dx12_shader_compiler: Default::default(),
            flags: wgpu::InstanceFlags::default(),
            gles_minor_version: wgpu::Gles3MinorVersion::default(),
        });

        let adapter = instance.request_adapter(&wgpu::RequestAdapterOptions {
            power_preference: match self.config.gpu_device {
                GpuDevice::HighPerformance => wgpu::PowerPreference::HighPerformance,
                GpuDevice::LowPower => wgpu::PowerPreference::LowPower,
                _ => wgpu::PowerPreference::default(),
            },
            compatible_surface: None,
            force_fallback_adapter: false,
        }).block_on().ok_or_else(|| {
            NeuralIntegrationError::GpuInitError("No suitable GPU adapter found".to_string())
        })?;

        self.adapter_info = Some(adapter.get_info());

        let (device, queue) = adapter.request_device(
            &wgpu::DeviceDescriptor {
                required_features: wgpu::Features::empty(),
                required_limits: wgpu::Limits::default(),
                label: Some("Neural Bridge Device"),
            },
            None,
        ).block_on().map_err(|e| {
            NeuralIntegrationError::GpuInitError(format!("Failed to create device: {e}"))
        })?;

        self.device = Some(device);
        self.queue = Some(queue);

        log::info!("WebGPU initialized successfully");
        Ok(())
    }

    #[cfg(target_arch = "wasm32")]
    fn init_webgpu(&mut self) -> NeuralResult<()> {
        log::info!("WASM WebGPU initialization deferred to runtime");
        Ok(())
    }

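    /// Transpiles CUDA source to WGSL and caches the compiled kernel by
    /// `name`; repeated calls with the same name hit the cache.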
    fn compile_kernel(&self, cuda_source: &str, name: &str) -> NeuralResult<CompiledKernel> {
        if let Ok(cache) = self.kernel_cache.read() {
            if let Some(kernel) = cache.get(name) {
                return Ok(kernel.clone());
            }
        }

        let wgsl_source = self.transpile_cuda_to_wgsl(cuda_source)?;

        let kernel = CompiledKernel {
            name: name.to_string(),
            wgsl_source,
            entry_point: "main".to_string(),
            workgroup_size: [64, 1, 1],
            bind_group_layout: vec![
                BindingType::Buffer { read_only: true },
                BindingType::Buffer { read_only: false },
            ],
        };

        if let Ok(mut cache) = self.kernel_cache.write() {
            cache.insert(name.to_string(), kernel.clone());
        }

        Ok(kernel)
    }

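    /// Parses CUDA source into an AST and lowers it to WGSL via the transpiler.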
    fn transpile_cuda_to_wgsl(&self, cuda_source: &str) -> NeuralResult<String> {
        let transpiler = Transpiler::new();

        let ast = crate::parser::CudaParser::new()
            .parse(cuda_source)
            .map_err(|e| NeuralIntegrationError::TranspilationError(e.to_string()))?;

        let wgsl = transpiler
            .to_wgsl(ast)
            .map_err(|e| NeuralIntegrationError::TranspilationError(e.to_string()))?;

        Ok(wgsl)
    }
}

impl GpuBackendTrait for WebGpuBackend {
    fn initialize(&self) -> NeuralResult<()> {
        if self.device.is_some() && self.queue.is_some() {
            Ok(())
        } else {
            Err(NeuralIntegrationError::GpuInitError("Device not initialized".to_string()))
        }
    }

    fn is_available(&self) -> bool {
        self.device.is_some() && self.queue.is_some()
    }

    fn get_device_info(&self) -> DeviceInfo {
        if let Some(ref info) = self.adapter_info {
            DeviceInfo {
                name: info.name.clone(),
                vendor: format!("{:?}", info.vendor),
                device_type: format!("{:?}", info.device_type),
                memory_size: 0,
                compute_units: 0,
                max_workgroup_size: 256,
                supports_f16: false,
                supports_f64: false,
            }
        } else {
            DeviceInfo {
                name: "Unknown".to_string(),
                vendor: "Unknown".to_string(),
                device_type: "Unknown".to_string(),
                memory_size: 0,
                compute_units: 0,
                max_workgroup_size: 64,
                supports_f16: false,
                supports_f64: false,
            }
        }
    }

    fn create_buffer(&self, size: usize) -> NeuralResult<BufferHandle> {
        let device = self.device.as_ref().ok_or_else(|| {
            NeuralIntegrationError::GpuInitError("Device not initialized".to_string())
        })?;

        let mut pool = self.buffer_pool.lock().unwrap();
        let handle = pool.get_or_create(
            device,
            size,
            wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC | wgpu::BufferUsages::COPY_DST,
        );

        Ok(handle)
    }

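    /// Executes a compiled kernel: builds the shader module, bind group
    /// layout, and compute pipeline, binds the inputs plus a freshly created
    /// output buffer, dispatches enough workgroups to cover the data, and
    /// returns a handle to the output buffer.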
    fn execute_kernel(&self, kernel: &CompiledKernel, inputs: &[BufferHandle]) -> NeuralResult<BufferHandle> {
        let device = self.device.as_ref().ok_or_else(|| {
            NeuralIntegrationError::GpuInitError("Device not initialized".to_string())
        })?;

        let queue = self.queue.as_ref().ok_or_else(|| {
            NeuralIntegrationError::GpuInitError("Queue not initialized".to_string())
        })?;

        let shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some(&format!("{} shader", kernel.name)),
            source: wgpu::ShaderSource::Wgsl(kernel.wgsl_source.as_str().into()),
        });

        let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: Some(&format!("{} bind group layout", kernel.name)),
            entries: &kernel.bind_group_layout.iter().enumerate().map(|(i, binding_type)| {
                wgpu::BindGroupLayoutEntry {
                    binding: i as u32,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: match binding_type {
                        BindingType::Buffer { read_only } => wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: *read_only },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        BindingType::UniformBuffer => wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Uniform,
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        BindingType::StorageTexture => wgpu::BindingType::StorageTexture {
                            access: wgpu::StorageTextureAccess::WriteOnly,
                            format: wgpu::TextureFormat::Rgba8Unorm,
                            view_dimension: wgpu::TextureViewDimension::D2,
                        },
                    },
                    count: None,
                }
            }).collect::<Vec<_>>(),
        });

        let compute_pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: Some(&format!("{} pipeline layout", kernel.name)),
            bind_group_layouts: &[&bind_group_layout],
            push_constant_ranges: &[],
        });

        let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some(&format!("{} pipeline", kernel.name)),
            layout: Some(&compute_pipeline_layout),
            module: &shader_module,
            entry_point: &kernel.entry_point,
        });

        let pool = self.buffer_pool.lock().unwrap();
        let input_buffers: Vec<&wgpu::Buffer> = inputs.iter()
            .map(|handle| pool.get_buffer(*handle))
            .collect::<Option<Vec<_>>>()
            .ok_or_else(|| NeuralIntegrationError::OperationError("Invalid buffer handle".to_string()))?;

        if input_buffers.is_empty() {
            return Err(NeuralIntegrationError::OperationError("No input buffers provided".to_string()));
        }

        // Size the output to match the first input buffer.
        let output_buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("Output buffer"),
            size: input_buffers[0].size(),
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });

        let mut bind_group_entries = Vec::new();
        for (i, buffer) in input_buffers.iter().enumerate() {
            bind_group_entries.push(wgpu::BindGroupEntry {
                binding: i as u32,
                resource: buffer.as_entire_binding(),
            });
        }
        bind_group_entries.push(wgpu::BindGroupEntry {
            binding: input_buffers.len() as u32,
            resource: output_buffer.as_entire_binding(),
        });

        let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some(&format!("{} bind group", kernel.name)),
            layout: &bind_group_layout,
            entries: &bind_group_entries,
        });

        let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some(&format!("{} encoder", kernel.name)),
        });

        {
            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some(&format!("{} pass", kernel.name)),
                timestamp_writes: None,
            });

            compute_pass.set_pipeline(&compute_pipeline);
            compute_pass.set_bind_group(0, &bind_group, &[]);

            // Each f32 element is 4 bytes; dispatch enough workgroups to
            // cover every element (ceiling division).
            let element_count = input_buffers[0].size() as u32 / 4;
            let workgroup_count = (element_count + kernel.workgroup_size[0] - 1) / kernel.workgroup_size[0];
            compute_pass.dispatch_workgroups(workgroup_count, 1, 1);
        }

        queue.submit(std::iter::once(encoder.finish()));

        // Release the earlier guard so the pool can be locked again mutably
        // to register the output buffer under a new handle.
        drop(pool);
        let mut pool = self.buffer_pool.lock().unwrap();
        let handle = BufferHandle(pool.next_handle);
        pool.next_handle += 1;
        pool.buffers.insert(handle, output_buffer);

        Ok(handle)
    }
}

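/// Builds a `CompiledKernel` from Rust source by generating a matching WGSL
/// shader via `generate_basic_wgsl`.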
pub fn extract_wgsl_from_rust(rust_code: &str) -> NeuralResult<CompiledKernel> {
    let wgsl_source = generate_basic_wgsl(rust_code)?;

    Ok(CompiledKernel {
        name: "extracted_kernel".to_string(),
        wgsl_source,
        entry_point: "main".to_string(),
        workgroup_size: [64, 1, 1],
        bind_group_layout: vec![
            BindingType::Buffer { read_only: true },
            BindingType::Buffer { read_only: false },
        ],
    })
}

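/// Generates WGSL from Rust source using substring heuristics: matrix
/// multiply maps to the bundled matrix-vector shader, vector add and sigmoid
/// to elementwise shaders, and everything else to an identity copy.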
fn generate_basic_wgsl(rust_code: &str) -> NeuralResult<String> {
    if rust_code.contains("matrix_multiply") || rust_code.contains("matmul") {
        Ok(include_str!("../webgpu/shaders/matrix_vector_multiply.wgsl").to_string())
    } else if rust_code.contains("vector_add") || rust_code.contains("add") {
        Ok(r#"
@group(0) @binding(0) var<storage, read> input_a: array<f32>;
@group(0) @binding(1) var<storage, read> input_b: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let index = global_id.x;
    if (index >= arrayLength(&input_a)) {
        return;
    }
    output[index] = input_a[index] + input_b[index];
}
"#.to_string())
    } else if rust_code.contains("sigmoid") {
        Ok(r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let index = global_id.x;
    if (index >= arrayLength(&input)) {
        return;
    }
    output[index] = 1.0 / (1.0 + exp(-input[index]));
}
"#.to_string())
    } else {
        Ok(r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let index = global_id.x;
    if (index >= arrayLength(&input)) {
        return;
    }
    output[index] = input[index];
}
"#.to_string())
    }
}

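/// CPU fallback for neural operations when no GPU backend is available.
/// Supports vector addition, the standard activation functions, and matrix
/// multiplication; all other operations return an `OperationError`.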
pub fn execute_cpu_fallback<T>(operation: NeuralOperation<T>, inputs: &[T]) -> NeuralResult<Vec<T>>
where
    T: Clone + Send + Sync + 'static + num_traits::Float,
{
    match operation {
        NeuralOperation::VectorAdd { size, _phantom } => {
            if inputs.len() < size * 2 {
                return Err(NeuralIntegrationError::OperationError("Insufficient input data".to_string()));
            }

            let mut result = Vec::with_capacity(size);
            for i in 0..size {
                result.push(inputs[i] + inputs[i + size]);
            }
            Ok(result)
        }

        NeuralOperation::ActivationFunction { function, size, _phantom } => {
            if inputs.len() < size {
                return Err(NeuralIntegrationError::OperationError("Insufficient input data".to_string()));
            }

            let mut result = Vec::with_capacity(size);
            for i in 0..size {
                let value = match function {
                    super::ActivationFunction::Sigmoid => {
                        T::one() / (T::one() + (-inputs[i]).exp())
                    }
                    super::ActivationFunction::ReLU => {
                        if inputs[i] > T::zero() { inputs[i] } else { T::zero() }
                    }
                    super::ActivationFunction::Tanh => inputs[i].tanh(),
                    super::ActivationFunction::LeakyReLU => {
                        if inputs[i] > T::zero() {
                            inputs[i]
                        } else {
                            inputs[i] * T::from(0.01).unwrap_or(T::zero())
                        }
                    }
                    super::ActivationFunction::Swish => {
                        inputs[i] * (T::one() / (T::one() + (-inputs[i]).exp()))
                    }
                    super::ActivationFunction::GELU => {
                        // tanh approximation of GELU; 0.7978845608 ≈ sqrt(2/π).
                        let sqrt_2_pi = T::from(0.7978845608).unwrap_or(T::one());
                        let x = inputs[i];
                        x * T::from(0.5).unwrap_or(T::one()) *
                            (T::one() + (sqrt_2_pi * (x + T::from(0.044715).unwrap_or(T::zero()) * x * x * x)).tanh())
                    }
                };
                result.push(value);
            }
            Ok(result)
        }

        NeuralOperation::MatrixMultiply { a_rows, a_cols, b_cols, _phantom } => {
            if inputs.len() < a_rows * a_cols + a_cols * b_cols {
                return Err(NeuralIntegrationError::OperationError("Insufficient input data for matrix multiplication".to_string()));
            }

            let mut result = Vec::with_capacity(a_rows * b_cols);
            let matrix_a = &inputs[0..a_rows * a_cols];
            let matrix_b = &inputs[a_rows * a_cols..];

            for i in 0..a_rows {
                for j in 0..b_cols {
                    let mut sum = T::zero();
                    for k in 0..a_cols {
                        sum = sum + matrix_a[i * a_cols + k] * matrix_b[k * b_cols + j];
                    }
                    result.push(sum);
                }
            }
            Ok(result)
        }

        _ => {
            Err(NeuralIntegrationError::OperationError(format!("CPU fallback not implemented for operation: {}", operation.name())))
        }
    }
}


#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cpu_vector_add() {
        let operation = NeuralOperation::VectorAdd { size: 3, _phantom: std::marker::PhantomData };
        let inputs = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let result = execute_cpu_fallback(operation, &inputs).unwrap();
        assert_eq!(result, vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_cpu_sigmoid() {
        let operation = NeuralOperation::ActivationFunction {
            function: super::super::ActivationFunction::Sigmoid,
            size: 3,
            _phantom: std::marker::PhantomData,
        };
        let inputs = vec![0.0f32, 1.0, -1.0];
        let result = execute_cpu_fallback(operation, &inputs).unwrap();

        assert!((result[0] - 0.5).abs() < 1e-6); // sigmoid(0) = 0.5
        assert!(result[1] > 0.5); // sigmoid(1) > 0.5
        assert!(result[2] < 0.5); // sigmoid(-1) < 0.5
    }

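    // Sanity check for the matrix-multiply fallback above: multiplying a
    // 2x2 identity matrix A by B should return B unchanged. Inputs are laid
    // out as A (row-major) followed by B (row-major), matching the slicing
    // in `execute_cpu_fallback`.
    #[test]
    fn test_cpu_matrix_multiply_identity() {
        let operation = NeuralOperation::MatrixMultiply {
            a_rows: 2,
            a_cols: 2,
            b_cols: 2,
            _phantom: std::marker::PhantomData,
        };
        let inputs = vec![1.0f32, 0.0, 0.0, 1.0, 5.0, 6.0, 7.0, 8.0];
        let result = execute_cpu_fallback(operation, &inputs).unwrap();
        assert_eq!(result, vec![5.0, 6.0, 7.0, 8.0]);
    }
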
    #[test]
    fn test_wgsl_generation() {
        let rust_code = "fn vector_add(a: &[f32], b: &[f32]) -> Vec<f32> { ... }";
        let wgsl = generate_basic_wgsl(rust_code).unwrap();
        assert!(wgsl.contains("vector_add") || wgsl.contains("input_a"));
        assert!(wgsl.contains("@compute"));
    }
}