use std::collections::HashMap;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::{Arc, Mutex};
use wgpu::*;
use crate::error::{CudaRustError, Result};
use crate::memory::MemoryPool;
use crate::profiling::{CounterType, time_operation};

/// Configuration options for the optimized WebGPU backend.
#[derive(Debug, Clone)]
pub struct WebGPUConfig {
    /// Cache compiled compute pipelines for reuse across launches.
    pub enable_kernel_cache: bool,
    /// Pick workgroup sizes automatically when compiling a kernel.
    pub enable_auto_tuning: bool,
    /// Recycle GPU buffers through a size-keyed pool instead of freeing them.
    pub enable_memory_pooling: bool,
    /// Maximum number of kernels kept in the cache before eviction.
    pub max_cache_size: usize,
    /// Power preference forwarded to the adapter request.
    pub power_preference: PowerPreference,
    /// Largest buffer size requested from the device, in bytes.
    pub max_buffer_size: u64,
    /// Maximum workgroups per dispatch dimension requested from the device.
    pub max_workgroups_per_dimension: u32,
}

impl Default for WebGPUConfig {
    fn default() -> Self {
        Self {
            enable_kernel_cache: true,
            enable_auto_tuning: true,
            enable_memory_pooling: true,
            max_cache_size: 100,
            power_preference: PowerPreference::HighPerformance,
            max_buffer_size: 256 * 1024 * 1024, // 256 MiB
            max_workgroups_per_dimension: 65535,
        }
    }
}

/// A compiled kernel together with the metadata used for cache
/// eviction and adaptive scheduling.
#[derive(Debug, Clone)]
pub struct CachedKernel {
    /// Compiled compute pipeline.
    pub pipeline: Arc<ComputePipeline>,
    /// Bind group layout derived from the shader.
    pub bind_group_layout: Arc<BindGroupLayout>,
    /// Workgroup size chosen by auto-tuning (or the default).
    pub optimal_workgroup_size: [u32; 3],
    /// Exponential moving average of execution time, in milliseconds.
    pub avg_execution_time: f64,
    /// Number of times this kernel has been launched.
    pub usage_count: u64,
    /// Total bytes processed by this kernel.
    pub total_data_processed: u64,
}

/// WebGPU backend with kernel caching, buffer pooling, and
/// execution statistics.
pub struct OptimizedWebGPUBackend {
    device: Arc<Device>,
    queue: Arc<Queue>,
    config: WebGPUConfig,
    kernel_cache: Arc<Mutex<HashMap<String, CachedKernel>>>,
    memory_pool: Arc<MemoryPool>,
    /// Reusable buffers, keyed by size in bytes.
    buffer_cache: Arc<Mutex<HashMap<u64, Vec<Buffer>>>>,
    stats: Arc<Mutex<BackendStats>>,
}

/// Aggregate counters describing backend activity.
#[derive(Debug, Clone, Default)]
pub struct BackendStats {
    pub kernels_executed: u64,
    pub cache_hits: u64,
    pub cache_misses: u64,
    /// Accumulated kernel execution time, in milliseconds.
    pub total_execution_time: f64,
    pub total_data_transferred: u64,
    pub memory_allocations: u64,
    /// Number of buffer requests served from the pool.
    pub buffer_reuse_count: u64,
}

/// Outcome of benchmarking one candidate workgroup size.
#[derive(Debug, Clone)]
pub struct AutoTuneResult {
    pub workgroup_size: [u32; 3],
    /// Measured performance score for this configuration.
    pub performance: f64,
    /// Fraction of peak memory bandwidth achieved.
    pub memory_bandwidth: f64,
    /// Fraction of peak compute throughput achieved.
    pub compute_utilization: f64,
}

impl OptimizedWebGPUBackend {
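    /// Creates a backend with the default [`WebGPUConfig`].
    ///
    /// A minimal usage sketch, assuming an async runtime (e.g. tokio) and a
    /// host with a WebGPU-capable adapter; `WGSL_SOURCE` and the entry point
    /// below are illustrative placeholders, not shipped with this crate:
    ///
    /// ```ignore
    /// let backend = OptimizedWebGPUBackend::new().await?;
    /// let key = backend.compile_kernel(WGSL_SOURCE, "main")?;
    /// let buf = backend.create_buffer(4096, BufferUsages::STORAGE)?;
    /// let ms = backend.execute_kernel(&key, &[&buf], [16, 1, 1]).await?;
    /// println!("kernel ran in {ms:.2} ms");
    /// ```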
    pub async fn new() -> Result<Self> {
        Self::with_config(WebGPUConfig::default()).await
    }

    /// Creates a backend with an explicit configuration.
    pub async fn with_config(config: WebGPUConfig) -> Result<Self> {
        let _timer = time_operation(CounterType::Custom("webgpu_init".to_string()));

        let instance = Instance::new(InstanceDescriptor {
            backends: Backends::BROWSER_WEBGPU | Backends::GL,
            flags: InstanceFlags::default(),
            dx12_shader_compiler: Dx12Compiler::default(),
            gles_minor_version: Gles3MinorVersion::default(),
        });

        let adapter = instance
            .request_adapter(&RequestAdapterOptions {
                power_preference: config.power_preference,
                compatible_surface: None,
                force_fallback_adapter: false,
            })
            .await
            .ok_or_else(|| CudaRustError::Backend("Failed to find suitable WebGPU adapter".to_string()))?;

        // NOTE: the timestamp and pipeline-statistics features are required
        // for profiling; device creation fails on adapters that lack them.
        let (device, queue) = adapter
            .request_device(
                &DeviceDescriptor {
                    label: Some("CUDA-Rust Optimized Device"),
                    required_features: Features::TIMESTAMP_QUERY
                        | Features::TIMESTAMP_QUERY_INSIDE_PASSES
                        | Features::PIPELINE_STATISTICS_QUERY,
                    required_limits: Limits {
                        max_buffer_size: config.max_buffer_size,
                        max_compute_workgroup_storage_size: 32768,
                        max_compute_invocations_per_workgroup: 1024,
                        max_compute_workgroup_size_x: 1024,
                        max_compute_workgroup_size_y: 1024,
                        max_compute_workgroup_size_z: 64,
                        max_compute_workgroups_per_dimension: config.max_workgroups_per_dimension,
                        ..Default::default()
                    },
                },
                None,
            )
            .await
            .map_err(|e| CudaRustError::Backend(format!("Failed to create WebGPU device: {e}")))?;

        Ok(Self {
            device: Arc::new(device),
            queue: Arc::new(queue),
            config,
            kernel_cache: Arc::new(Mutex::new(HashMap::new())),
            memory_pool: Arc::new(MemoryPool::new()),
            buffer_cache: Arc::new(Mutex::new(HashMap::new())),
            stats: Arc::new(Mutex::new(BackendStats::default())),
        })
    }

    /// Compiles a WGSL kernel and caches the resulting pipeline, returning
    /// the cache key used to launch it later via [`Self::execute_kernel`].
    pub fn compile_kernel(&self, shader_source: &str, entry_point: &str) -> Result<String> {
        let _timer = time_operation(CounterType::Compilation)
            .with_size(shader_source.len());

        // Key the cache on a hash of the full source rather than its length,
        // so two different shaders of equal length cannot collide.
        let mut hasher = DefaultHasher::new();
        shader_source.hash(&mut hasher);
        let cache_key = format!("{:x}:{}", hasher.finish(), entry_point);

        {
            let cache = self.kernel_cache.lock().unwrap();
            if cache.contains_key(&cache_key) {
                let mut stats = self.stats.lock().unwrap();
                stats.cache_hits += 1;
                return Ok(cache_key);
            }
        }

        let shader_module = self.device.create_shader_module(ShaderModuleDescriptor {
            label: Some("CUDA Kernel"),
            source: ShaderSource::Wgsl(shader_source.into()),
        });

        // Let wgpu derive the bind group layout from the shader itself
        // (`layout: None`); a hand-written single-binding layout would reject
        // kernels that bind more than one buffer.
        let pipeline = self.device.create_compute_pipeline(&ComputePipelineDescriptor {
            label: Some("CUDA Kernel Pipeline"),
            layout: None,
            module: &shader_module,
            entry_point,
        });

        let bind_group_layout = pipeline.get_bind_group_layout(0);

        let optimal_workgroup_size = if self.config.enable_auto_tuning {
            self.auto_tune_workgroup_size(&pipeline, &bind_group_layout)?
        } else {
            [64, 1, 1] // conservative default for 1-D workloads
        };

        let cached_kernel = CachedKernel {
            pipeline: Arc::new(pipeline),
            bind_group_layout: Arc::new(bind_group_layout),
            optimal_workgroup_size,
            avg_execution_time: 0.0,
            usage_count: 0,
            total_data_processed: 0,
        };

        {
            let mut cache = self.kernel_cache.lock().unwrap();

            if cache.len() >= self.config.max_cache_size {
                self.evict_least_used_kernel(&mut cache);
            }

            cache.insert(cache_key.clone(), cached_kernel);
        }

        {
            let mut stats = self.stats.lock().unwrap();
            stats.cache_misses += 1;
        }

        Ok(cache_key)
    }

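    /// Executes a previously compiled kernel with the given buffers bound at
    /// bindings `0..buffers.len()`, returning the measured wall-clock time in
    /// milliseconds.
    ///
    /// A hedged dispatch sketch (names are illustrative): for a 1-D problem of
    /// `n` elements with a workgroup width of 64, the caller covers every
    /// element by rounding the group count up:
    ///
    /// ```ignore
    /// let n: u32 = 1_000_000;
    /// let groups = [(n + 63) / 64, 1, 1]; // ceiling division
    /// let ms = backend.execute_kernel(&key, &[&input, &output], groups).await?;
    /// ```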
    pub async fn execute_kernel(
        &self,
        cache_key: &str,
        buffers: &[&Buffer],
        workgroup_count: [u32; 3]
    ) -> Result<f64> {
        let _timer = time_operation(CounterType::KernelExecution);

        let (pipeline, bind_group_layout, _optimal_workgroup_size) = {
            let mut cache = self.kernel_cache.lock().unwrap();
            let cached = cache.get_mut(cache_key)
                .ok_or_else(|| CudaRustError::Backend("Kernel not found in cache".to_string()))?;

            cached.usage_count += 1;
            (
                cached.pipeline.clone(),
                cached.bind_group_layout.clone(),
                // Currently unused: callers pass an explicit workgroup count.
                cached.optimal_workgroup_size
            )
        };

        let entries: Vec<BindGroupEntry> = buffers.iter().enumerate()
            .map(|(i, buffer)| BindGroupEntry {
                binding: i as u32,
                resource: buffer.as_entire_binding(),
            })
            .collect();

        let bind_group = self.device.create_bind_group(&BindGroupDescriptor {
            label: Some("Kernel Bind Group"),
            layout: &bind_group_layout,
            entries: &entries,
        });

        let mut encoder = self.device.create_command_encoder(&CommandEncoderDescriptor {
            label: Some("Kernel Execution"),
        });

        {
            let mut compute_pass = encoder.begin_compute_pass(&ComputePassDescriptor {
                label: Some("CUDA Kernel Pass"),
                timestamp_writes: None,
            });

            compute_pass.set_pipeline(&pipeline);
            compute_pass.set_bind_group(0, &bind_group, &[]);

            compute_pass.dispatch_workgroups(
                workgroup_count[0],
                workgroup_count[1],
                workgroup_count[2]
            );
        }

        // Wall-clock timing: performance.now() on wasm, Instant elsewhere.
        #[cfg(target_arch = "wasm32")]
        let start_time = web_sys::window()
            .and_then(|w| w.performance())
            .map(|p| p.now())
            .unwrap_or(0.0);
        #[cfg(not(target_arch = "wasm32"))]
        let start_instant = std::time::Instant::now();

        self.queue.submit(std::iter::once(encoder.finish()));

        self.device.poll(Maintain::Wait);

        #[cfg(target_arch = "wasm32")]
        let end_time = web_sys::window()
            .and_then(|w| w.performance())
            .map(|p| p.now())
            .unwrap_or(0.0);

        #[cfg(target_arch = "wasm32")]
        let execution_time = end_time - start_time;
        #[cfg(not(target_arch = "wasm32"))]
        let execution_time = start_instant.elapsed().as_secs_f64() * 1000.0;

        {
            let mut stats = self.stats.lock().unwrap();
            stats.kernels_executed += 1;
            stats.total_execution_time += execution_time;
        }

        {
            let mut cache = self.kernel_cache.lock().unwrap();
            if let Some(cached) = cache.get_mut(cache_key) {
                // Exponential moving average: the latest launch is weighted
                // at alpha, the accumulated history at (1 - alpha).
                let alpha = 0.1;
                cached.avg_execution_time =
                    alpha * execution_time + (1.0 - alpha) * cached.avg_execution_time;
            }
        }

        Ok(execution_time)
    }

    fn auto_tune_workgroup_size(
        &self,
        _pipeline: &ComputePipeline,
        _bind_group_layout: &BindGroupLayout
    ) -> Result<[u32; 3]> {
        // Benchmarking each candidate is not implemented yet; until it is,
        // return a conservative default that works well for 1-D workloads.
        let _candidate_sizes = [
            [32, 1, 1],
            [64, 1, 1],
            [128, 1, 1],
            [256, 1, 1],
            [16, 16, 1],
            [8, 8, 8],
        ];

        Ok([64, 1, 1])
    }

    /// Removes the cached kernel with the lowest launch count.
    fn evict_least_used_kernel(&self, cache: &mut HashMap<String, CachedKernel>) {
        if let Some((key_to_remove, _)) = cache.iter()
            .min_by_key(|(_, cached)| cached.usage_count) {
            let key_to_remove = key_to_remove.clone();
            cache.remove(&key_to_remove);
        }
    }

    /// Allocates a buffer, serving it from the pool when possible.
    ///
    /// NOTE: the pool is keyed by size alone, so callers should request
    /// consistent usage flags for a given size class.
    pub fn create_buffer(&self, size: u64, usage: BufferUsages) -> Result<Buffer> {
        let _timer = time_operation(CounterType::MemoryAllocation)
            .with_size(size as usize);

        if self.config.enable_memory_pooling {
            let mut buffer_cache = self.buffer_cache.lock().unwrap();
            if let Some(buffers) = buffer_cache.get_mut(&size) {
                if let Some(buffer) = buffers.pop() {
                    let mut stats = self.stats.lock().unwrap();
                    stats.buffer_reuse_count += 1;
                    return Ok(buffer);
                }
            }
        }

        let buffer = self.device.create_buffer(&BufferDescriptor {
            label: Some("CUDA Buffer"),
            size,
            usage,
            mapped_at_creation: false,
        });

        {
            let mut stats = self.stats.lock().unwrap();
            stats.memory_allocations += 1;
        }

        Ok(buffer)
    }

    /// Returns a buffer to the pool for later reuse.
    pub fn return_buffer(&self, buffer: Buffer) {
        if !self.config.enable_memory_pooling {
            return;
        }

        let size = buffer.size();
        let mut buffer_cache = self.buffer_cache.lock().unwrap();

        let buffers = buffer_cache.entry(size).or_default();

        // Cap each size class to bound the pool's memory footprint.
        if buffers.len() < 10 {
            buffers.push(buffer);
        }
    }

    /// Returns a snapshot of the backend counters.
    pub fn get_stats(&self) -> BackendStats {
        self.stats.lock().unwrap().clone()
    }

    /// Fraction of kernel compilations served from the cache.
    pub fn cache_hit_ratio(&self) -> f64 {
        let stats = self.stats.lock().unwrap();
        let total = stats.cache_hits + stats.cache_misses;
        if total == 0 {
            0.0
        } else {
            stats.cache_hits as f64 / total as f64
        }
    }

    /// Clears the kernel and buffer caches and resets the statistics.
    pub fn clear_caches(&self) {
        self.kernel_cache.lock().unwrap().clear();
        self.buffer_cache.lock().unwrap().clear();
        *self.stats.lock().unwrap() = BackendStats::default();
    }

    /// Renders a human-readable summary of the backend counters.
    pub fn performance_report(&self) -> String {
        let stats = self.get_stats();
        let cache_ratio = self.cache_hit_ratio();
        let kernel_cache_size = self.kernel_cache.lock().unwrap().len();
        let buffer_cache_size: usize = self.buffer_cache.lock().unwrap()
            .values()
            .map(|v| v.len())
            .sum();

        format!(
            "=== WebGPU Backend Performance Report ===\n\
             Kernels Executed: {}\n\
             Cache Hit Ratio: {:.1}%\n\
             Avg Execution Time: {:.2}ms\n\
             Total Data Transferred: {:.2}MB\n\
             Memory Allocations: {}\n\
             Buffer Reuse Count: {}\n\
             Kernel Cache Size: {}\n\
             Buffer Cache Size: {}\n\
             Memory Pool Stats: {:?}",
            stats.kernels_executed,
            cache_ratio * 100.0,
            if stats.kernels_executed > 0 {
                stats.total_execution_time / stats.kernels_executed as f64
            } else {
                0.0
            },
            stats.total_data_transferred as f64 / 1_000_000.0,
            stats.memory_allocations,
            stats.buffer_reuse_count,
            kernel_cache_size,
            buffer_cache_size,
            self.memory_pool.stats()
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_webgpu_backend_creation() {
        // Skipped silently on hosts without a WebGPU adapter.
        if let Ok(backend) = OptimizedWebGPUBackend::new().await {
            assert_eq!(backend.cache_hit_ratio(), 0.0);
        }
    }

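    // A hedged sketch, not part of the original suite: round-trips a buffer
    // through the pool when an adapter is available (and passes trivially
    // otherwise). The size and usage flags are illustrative choices.
    #[tokio::test]
    async fn test_buffer_pool_reuse() {
        if let Ok(backend) = OptimizedWebGPUBackend::new().await {
            let buffer = backend
                .create_buffer(1024, BufferUsages::STORAGE)
                .expect("buffer creation should succeed");
            backend.return_buffer(buffer);

            // A same-size request should now be served from the pool.
            let _reused = backend
                .create_buffer(1024, BufferUsages::STORAGE)
                .expect("buffer creation should succeed");
            assert_eq!(backend.get_stats().buffer_reuse_count, 1);
        }
    }
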
    #[test]
    fn test_auto_tune_result() {
        let result = AutoTuneResult {
            workgroup_size: [64, 1, 1],
            performance: 1000.0,
            memory_bandwidth: 0.8,
            compute_utilization: 0.9,
        };

        assert_eq!(result.workgroup_size, [64, 1, 1]);
        assert_eq!(result.performance, 1000.0);
    }

    #[test]
    fn test_backend_stats() {
        let stats = BackendStats {
            kernels_executed: 100,
            cache_hits: 80,
            cache_misses: 20,
            total_execution_time: 1000.0,
            ..Default::default()
        };

        assert_eq!(stats.kernels_executed, 100);
        assert_eq!(stats.cache_hits, 80);
        assert_eq!(stats.cache_misses, 20);
    }
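
    // A hedged sketch, not part of the original suite: verifies the plain
    // arithmetic this module relies on (the alpha = 0.1 exponential moving
    // average in execute_kernel and the hit ratio in cache_hit_ratio)
    // without needing a GPU device.
    #[test]
    fn test_ema_and_hit_ratio_arithmetic() {
        // EMA step: new_avg = alpha * sample + (1 - alpha) * old_avg.
        let (alpha, old_avg, sample) = (0.1_f64, 10.0_f64, 20.0_f64);
        let new_avg = alpha * sample + (1.0 - alpha) * old_avg;
        assert!((new_avg - 11.0).abs() < 1e-9);

        // Hit ratio: hits / (hits + misses).
        let stats = BackendStats {
            cache_hits: 80,
            cache_misses: 20,
            ..Default::default()
        };
        let total = stats.cache_hits + stats.cache_misses;
        let ratio = stats.cache_hits as f64 / total as f64;
        assert!((ratio - 0.8).abs() < 1e-9);
    }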
}