oxiphysics_gpu/compute/wgpu_backend.rs
1// Copyright 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4//! WebGPU (wgpu) compute backend for the OxiPhysics GPU acceleration layer.
5//!
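//! This module provides two types:
//!
//! * [`WgpuBackend`] — an always-compiled CPU stub that mirrors the buffer API
//!   with in-memory `Vec<f64>` shadows and treats `dispatch` as a no-op.
//! * `real::WgpuBackendReal` — the GPU implementation built on the `wgpu` crate
//!   for cross-platform compute (Vulkan, Metal, DX12, WebGPU); compiled only
//!   with the `wgpu-backend` feature.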
8//!
9//! ## Feature flag
10//!
11//! This module is gated behind the `wgpu-backend` Cargo feature:
12//!
13//! ```toml
14//! [dependencies]
15//! oxiphysics-gpu = { features = ["wgpu-backend"] }
16//! ```
17//!
//! When the feature is disabled, only the `real` submodule is compiled out. The
//! CPU stub [`WgpuBackend`], the error type, and the WGSL kernel sources remain
//! available. This allows the crate to compile without the `wgpu` dependency on
//! platforms or toolchains where GPU support is not required.
21//!
22//! ## Enabling the dependency
23//!
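//! To activate the wgpu backend, declare `wgpu` as an optional dependency in the
//! crate's `Cargo.toml`. Also declare the `pollster` and `bytemuck` helpers,
//! which `real::WgpuBackendReal` uses for blocking initialisation and byte
//! casting:
//!
//! ```toml
//! [features]
//! wgpu-backend = ["wgpu", "pollster", "bytemuck"]
//!
//! [dependencies]
//! wgpu = { version = "29", optional = true }
//! pollster = { version = "0.4", optional = true }
//! bytemuck = { version = "1", optional = true }
//! ```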
33//!
34//! ## Architecture
35//!
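//! ```text
//! real::WgpuBackendReal                       (feature `wgpu-backend`)
//! ├── wgpu::Device / wgpu::Queue             ← GPU device & command queue
//! ├── Vec<Option<Arc<wgpu::Buffer>>>         ← GPU storage buffers, indexed by handle
//! │   └── buffer_sizes: Vec<u64>             (byte size of each buffer)
//! └── Mutex<HashMap<u64, ShaderCacheEntry>>  ← Compiled compute pipelines, keyed by
//!                                              a hash of (WGSL source, entry point)
//!
//! WgpuBackend                                 (always compiled; CPU stub)
//! └── Vec<WgpuBufferEntry>                    ← CPU-side shadow copies (Vec<f64>)
//!
//! Compute pipeline (real backend):
//! write_buffer_f64 → dispatch_wgsl(kernel) → [readback via staging buffer] → read_buffer_f64
//! ```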
47//!
//! ## Usage (when feature is enabled)
//!
//! ```ignore
//! use oxiphysics_gpu::compute::wgpu_backend::real::WgpuBackendReal;
//!
//! let mut backend = WgpuBackendReal::try_new_async().await?;
//! let handle = backend.create_buffer_f64(128);
//! backend.write_buffer_f64(handle, &vec![1.0_f64; 128]);
//! // ... dispatch a kernel with `backend.dispatch_wgsl(...)` ...
//! let data = backend.read_buffer_f64(handle); // values round-trip through f32
//! ```
60
61#![allow(dead_code)]
62
// ── WgpuBufferHandle ──────────────────────────────────────────────────────────
64
/// Lightweight, copyable identifier for a buffer owned by one of this
/// module's backends.
///
/// The wrapped `usize` is the buffer's position in the owning backend's
/// allocation list. Handles are only meaningful to the backend that created
/// them.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct WgpuBufferHandle(pub usize);
71
72// ── WgpuDeviceInfo ────────────────────────────────────────────────────────────
73
/// Descriptive metadata about the device a backend runs on.
///
/// The CPU stub fills in only `name` and `backend`. The real backend fills the
/// struct in at initialisation; fields wgpu does not expose (such as VRAM) keep
/// their default values.
#[derive(Clone, Debug, Default)]
pub struct WgpuDeviceInfo {
    /// Adapter name as reported by the driver (e.g. `"NVIDIA GeForce RTX 4090"`).
    pub name: String,
    /// Graphics API in use, e.g. `"Vulkan"`, `"Metal"` or `"Dx12"`; the CPU
    /// stub reports `"None"`.
    pub backend: String,
    /// Driver version / details string; empty when unavailable.
    pub driver_version: String,
    /// Total VRAM in bytes, or `0` when the adapter does not report it.
    pub vram_bytes: u64,
    /// Whether 64-bit floating-point shader storage is usable on the device.
    pub supports_f64: bool,
    /// Maximum compute workgroup size along x, y and z.
    pub max_workgroup_size: [u32; 3],
}
90
91// ── WgpuBackend ───────────────────────────────────────────────────────────────
92
93/// WebGPU compute backend.
94///
95/// When compiled **without** the `wgpu-backend` feature this struct is a no-op
96/// stub that will return an error from [`WgpuBackend::try_new`]. When compiled
97/// **with** the feature, a real wgpu `Device` / `Queue` pair is created.
98///
99/// For the real implementation, `try_new` should be called within an async
100/// runtime (tokio or wasm-bindgen-futures for browser targets).
101#[derive(Debug)]
102pub struct WgpuBackend {
103 /// Device info (populated at initialisation).
104 pub device_info: WgpuDeviceInfo,
105 /// Allocated CPU-side buffers (mirrors GPU allocations).
106 ///
107 /// In the stub implementation these are plain `Vec<f64>` acting as
108 /// stand-ins for actual `wgpu::Buffer` objects. A full implementation
109 /// wraps `wgpu::Buffer` behind `Arc<Mutex<…>>` to allow async reads.
110 buffers: Vec<WgpuBufferEntry>,
111 /// Whether the backend is operational.
112 available: bool,
113}
114
/// Bookkeeping for one stub buffer: its element capacity and CPU-side storage.
#[derive(Debug, Clone)]
struct WgpuBufferEntry {
    /// Number of `f64` elements the buffer holds. This is an element count,
    /// not a byte count, and always equals `shadow.len()`.
    capacity: usize,
    /// CPU-side storage standing in for device memory.
    shadow: Vec<f64>,
}
123
124impl WgpuBackend {
125 /// Attempt to create a wgpu backend.
126 ///
127 /// Returns `Ok(Self)` when a compatible GPU adapter is available, or
128 /// `Err(WgpuInitError::NotAvailable)` when no adapter can be found (e.g.
129 /// running headless without a GPU or without the `wgpu-backend` feature).
130 ///
131 /// In the current stub implementation this always returns a CPU-fallback
132 /// instance with `available = false`. The full implementation calls
133 /// `wgpu::Instance::request_adapter` and `adapter.request_device`.
134 pub fn try_new() -> Result<Self, WgpuInitError> {
135 // ── TODO (wgpu-backend feature) ─────────────────────────────────────
136 // When `wgpu-backend` is enabled, replace this stub with:
137 //
138 // let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
139 // backends: wgpu::Backends::all(),
140 // ..Default::default()
141 // });
142 // let adapter = pollster::block_on(instance.request_adapter(
143 // &wgpu::RequestAdapterOptions {
144 // power_preference: wgpu::PowerPreference::HighPerformance,
145 // ..Default::default()
146 // },
147 // )).ok_or(WgpuInitError::NoAdapter)?;
148 // let (device, queue) = pollster::block_on(adapter.request_device(
149 // &wgpu::DeviceDescriptor::default(),
150 // None,
151 // ))?;
152 // let info = adapter.get_info();
153 // Ok(Self { device, queue, info, buffers: Vec::new(), available: true })
154 // ────────────────────────────────────────────────────────────────────
155
156 Err(WgpuInitError::NotAvailable)
157 }
158
159 /// Create a stub backend for testing that stores data in CPU memory.
160 ///
161 /// This is equivalent to what `try_new` would return on a headless system
162 /// but without returning an error — useful for unit testing backend logic.
163 pub fn new_stub() -> Self {
164 Self {
165 device_info: WgpuDeviceInfo {
166 name: "CPU stub".to_string(),
167 backend: "None".to_string(),
168 ..Default::default()
169 },
170 buffers: Vec::new(),
171 available: false,
172 }
173 }
174
175 /// Return `true` if a real GPU device is available.
176 pub fn is_available(&self) -> bool {
177 self.available
178 }
179
180 /// Return device information for diagnostics.
181 pub fn device_info(&self) -> &WgpuDeviceInfo {
182 &self.device_info
183 }
184
185 // ── Buffer management ────────────────────────────────────────────────────
186
187 /// Allocate a GPU buffer that can hold `len` `f64` values.
188 ///
189 /// Returns a [`WgpuBufferHandle`] that can be passed to [`Self::write_buffer`]
190 /// and [`Self::read_buffer`].
191 ///
192 /// In the stub implementation, a CPU-side shadow `Vec<f64>` is allocated.
193 /// In the full wgpu implementation, `wgpu::Device::create_buffer` is called
194 /// with `STORAGE | COPY_SRC | COPY_DST` usage flags.
195 pub fn create_buffer(&mut self, len: usize) -> WgpuBufferHandle {
196 let handle = WgpuBufferHandle(self.buffers.len());
197 self.buffers.push(WgpuBufferEntry {
198 capacity: len,
199 shadow: vec![0.0; len],
200 });
201 handle
202 }
203
204 /// Upload `data` to the GPU buffer at `handle`.
205 ///
206 /// In the stub, data is copied into the CPU-side shadow.
207 /// In the full implementation, `queue.write_buffer` is used.
208 pub fn write_buffer(&mut self, handle: WgpuBufferHandle, data: &[f64]) {
209 if let Some(entry) = self.buffers.get_mut(handle.0) {
210 let len = data.len().min(entry.capacity);
211 entry.shadow[..len].copy_from_slice(&data[..len]);
212 }
213 }
214
215 /// Download data from the GPU buffer at `handle`.
216 ///
217 /// In the stub, data is read from the CPU-side shadow.
218 /// In the full implementation, a staging buffer is created, the command
219 /// `encoder.copy_buffer_to_buffer` is executed, and the staging buffer is
220 /// mapped for reading.
221 pub fn read_buffer(&self, handle: WgpuBufferHandle) -> Vec<f64> {
222 self.buffers
223 .get(handle.0)
224 .map(|e| e.shadow.clone())
225 .unwrap_or_default()
226 }
227
228 // ── Dispatch ─────────────────────────────────────────────────────────────
229
230 /// Dispatch a compute kernel with `work_groups_x` workgroups.
231 ///
232 /// In the stub, the kernel's `execute` method is called on the CPU-side
233 /// shadow data. In the full implementation a `ComputePipeline` is looked
234 /// up from the shader registry and `encoder.dispatch_workgroups` is called.
235 ///
236 /// # Arguments
237 ///
238 /// * `kernel_name` — name of the WGSL shader entry point
239 /// * `buffers` — input/output buffer handles
240 /// * `work_groups_x` — number of workgroups in the X dimension
241 pub fn dispatch(
242 &mut self,
243 kernel_name: &str,
244 buffers: &[WgpuBufferHandle],
245 work_groups_x: u32,
246 ) {
247 // ── TODO (wgpu-backend feature) ─────────────────────────────────────
248 // When enabled:
249 // let pipeline = self.shader_registry.get_pipeline(kernel_name)?;
250 // let bind_group = self.device.create_bind_group(…);
251 // let mut encoder = self.device.create_command_encoder(…);
252 // {
253 // let mut pass = encoder.begin_compute_pass(…);
254 // pass.set_pipeline(&pipeline);
255 // pass.set_bind_group(0, &bind_group, &[]);
256 // pass.dispatch_workgroups(work_groups_x, 1, 1);
257 // }
258 // self.queue.submit([encoder.finish()]);
259 // ────────────────────────────────────────────────────────────────────
260
261 // Stub: identity kernel (pass-through, no-op)
262 let _ = (kernel_name, buffers, work_groups_x);
263 }
264
265 // ── WGSL shader registry ──────────────────────────────────────────────────
266
267 /// Register a WGSL compute shader source and associate it with a name.
268 ///
269 /// In the stub, the source is stored but not compiled.
270 /// In the full implementation, `device.create_shader_module` is called and
271 /// the resulting `ShaderModule` is cached.
272 pub fn register_shader(&mut self, name: &str, wgsl_source: &str) {
273 // ── TODO (wgpu-backend feature) ─────────────────────────────────────
274 // let module = self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
275 // label: Some(name),
276 // source: wgpu::ShaderSource::Wgsl(wgsl_source.into()),
277 // });
278 // self.shader_registry.insert(name.to_string(), module);
279 let _ = (name, wgsl_source);
280 }
281}
282
283// ── Built-in WGSL kernels ─────────────────────────────────────────────────────
284
/// WGSL source for a workgroup-local exclusive prefix sum (Blelloch scan).
///
/// Entry point: `exclusive_scan`, workgroup size 256.
///
/// Each workgroup scans its own contiguous tile of 256 elements independently.
/// For `n > 256` the output is therefore a per-tile exclusive scan. A full-array
/// scan would additionally need the tile totals scanned and added back, which
/// this kernel does not provide.
///
/// Bindings (group 0):
///
/// * `0` — `input: array<f32>` (read-only storage)
/// * `1` — `output: array<f32>` (read-write storage)
/// * `2` — uniform `Params { n: u32, pass_index: u32 }` (8 bytes;
///   `pass_index` is currently unused by the shader)
///
/// The workgroup array and parameter field are deliberately not named `shared`
/// or `pass`. Both are reserved words in WGSL and would make the module fail to
/// compile.
pub const WGSL_PARALLEL_SCAN: &str = r#"
// Exclusive parallel prefix sum (Blelloch up-sweep / down-sweep), per workgroup
// Workgroup size: 256 threads
// Binding 0: input buffer (read)
// Binding 1: output buffer (write)
// Binding 2: uniform { n: u32, pass_index: u32 }

@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

struct Params { n: u32, pass_index: u32 }
@group(0) @binding(2) var<uniform> params: Params;

var<workgroup> tile: array<f32, 256>;

@compute @workgroup_size(256)
fn exclusive_scan(@builtin(global_invocation_id) gid: vec3<u32>,
                  @builtin(local_invocation_id) lid: vec3<u32>) {
    let n = params.n;
    let i = gid.x;

    // Load (lanes past the end contribute 0)
    var value: f32 = 0.0;
    if i < n { value = input[i]; }
    tile[lid.x] = value;
    workgroupBarrier();

    // Up-sweep (reduce)
    var stride: u32 = 1u;
    loop {
        if stride >= 256u { break; }
        if lid.x % (stride * 2u) == (stride * 2u - 1u) {
            tile[lid.x] += tile[lid.x - stride];
        }
        workgroupBarrier();
        stride = stride * 2u;
    }

    // Down-sweep
    if lid.x == 255u { tile[255] = 0.0; }
    workgroupBarrier();
    stride = 128u;
    loop {
        if stride == 0u { break; }
        if lid.x % (stride * 2u) == (stride * 2u - 1u) {
            let tmp = tile[lid.x - stride];
            tile[lid.x - stride] = tile[lid.x];
            tile[lid.x] += tmp;
        }
        workgroupBarrier();
        stride = stride / 2u;
    }

    // Store
    if i < n { output[i] = tile[lid.x]; }
}
"#;
343
/// WGSL source for a brute-force SPH density kernel.
///
/// Entry point: `sph_density`, workgroup size 64, one invocation per particle.
/// Each particle sums `mass * W(r, h)` over **all** `n` particles, including
/// itself. The cost is therefore O(n²); no neighbour search is performed.
///
/// `W` is the 3-D cubic-spline kernel with normalisation `3 / (2π h³)` and
/// compact support radius `2h`; it returns 0 for `r ≥ 2h`.
///
/// Bindings (group 0):
///
/// * `0` — packed positions `[x0, y0, z0, x1, ...]` (`array<f32>`, read-only
///   storage)
/// * `1` — densities output (`array<f32>`, one per particle)
/// * `2` — uniform `SphParams { n: u32, h: f32, mass: f32 }`
pub const WGSL_SPH_DENSITY: &str = r#"
// SPH density kernel — W_spline3 smoothing
// Binding 0: positions array (x0,y0,z0, x1,y1,z1, ...)
// Binding 1: densities output (one per particle)
// Binding 2: uniform { n: u32, h: f32, mass: f32 }

struct SphParams { n: u32, h: f32, mass: f32 }
@group(0) @binding(0) var<storage, read> positions: array<f32>;
@group(0) @binding(1) var<storage, read_write> densities: array<f32>;
@group(0) @binding(2) var<uniform> params: SphParams;

fn w_spline3(r: f32, h: f32) -> f32 {
    let q = r / h;
    let sigma = 3.0 / (2.0 * 3.14159265358979 * h * h * h);
    if q < 1.0 {
        return sigma * (2.0/3.0 - q*q + 0.5*q*q*q);
    } else if q < 2.0 {
        let t = 2.0 - q;
        return sigma * (1.0/6.0) * t*t*t;
    } else {
        return 0.0;
    }
}

@compute @workgroup_size(64)
fn sph_density(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    let n = params.n;
    if i >= n { return; }

    let xi = vec3<f32>(positions[i*3u], positions[i*3u+1u], positions[i*3u+2u]);
    var density: f32 = 0.0;

    for (var j: u32 = 0u; j < n; j++) {
        let xj = vec3<f32>(positions[j*3u], positions[j*3u+1u], positions[j*3u+2u]);
        let r = length(xi - xj);
        density += params.mass * w_spline3(r, params.h);
    }

    densities[i] = density;
}
"#;
389
/// WGSL source for parallel ray traversal of a linearised BVH.
///
/// Entry point: `bvh_traverse`, workgroup size 64, one invocation per ray.
/// Each invocation walks the tree depth-first from node `0` (the root) using a
/// fixed 32-entry stack; children that do not fit on the stack are skipped. It
/// records in `results[rid]` the nearest leaf whose bounding box the ray enters
/// within `[0, t_max]`.
///
/// This is a stub. Leaves are "hit" at their box entry distance, and no
/// primitive intersection test is performed. Boxes that contain the ray origin
/// are entered at `t = 0`, so rays starting inside the scene are not culled.
///
/// Bindings (group 0):
///
/// * `0` — `nodes: array<BvhNode>` (read-only storage)
/// * `1` — `rays: array<Ray>` (read-only storage)
/// * `2` — `results: array<HitResult>` (read-write storage)
/// * `3` — uniform `num_rays: u32`
pub const WGSL_BVH_TRAVERSAL: &str = r#"
// Parallel BVH ray traversal stub
// Each thread handles one ray; BVH nodes are in binding 0.

struct Ray { origin: vec3<f32>, dir: vec3<f32>, t_max: f32 }
struct BvhNode { lo: vec3<f32>, hi: vec3<f32>, left: u32, right: u32, is_leaf: u32, prim: u32 }
struct HitResult { hit: u32, t: f32, prim: u32 }

@group(0) @binding(0) var<storage, read> nodes: array<BvhNode>;
@group(0) @binding(1) var<storage, read> rays: array<Ray>;
@group(0) @binding(2) var<storage, read_write> results: array<HitResult>;
@group(0) @binding(3) var<uniform> num_rays: u32;

// Slab test. Returns the distance at which the ray enters the box, clamped to
// the ray origin (a box containing the origin is entered at t = 0), or -1.0 if
// the ray misses the box within [0, t_max].
fn ray_aabb(ray: Ray, lo: vec3<f32>, hi: vec3<f32>) -> f32 {
    let inv_dir = 1.0 / ray.dir;
    let t0 = (lo - ray.origin) * inv_dir;
    let t1 = (hi - ray.origin) * inv_dir;
    let t_min = max(max(min(t0.x, t1.x), min(t0.y, t1.y)), min(t0.z, t1.z));
    let t_max_box = min(min(max(t0.x, t1.x), max(t0.y, t1.y)), max(t0.z, t1.z));
    let t_enter = max(t_min, 0.0);
    if t_max_box < t_enter || t_enter > ray.t_max { return -1.0; }
    return t_enter;
}

@compute @workgroup_size(64)
fn bvh_traverse(@builtin(global_invocation_id) gid: vec3<u32>) {
    let rid = gid.x;
    if rid >= num_rays { return; }
    let ray = rays[rid];
    results[rid] = HitResult(0u, ray.t_max, 0xFFFFFFFFu);

    // Iterative DFS stack (max depth 32)
    var stack: array<u32, 32>;
    var sp: i32 = 0;
    stack[0] = 0u;

    loop {
        if sp < 0 { break; }
        let node_idx = stack[sp]; sp--;
        let node = nodes[node_idx];

        let t = ray_aabb(ray, node.lo, node.hi);
        if t < 0.0 { continue; }

        if node.is_leaf != 0u {
            if t < results[rid].t {
                results[rid] = HitResult(1u, t, node.prim);
            }
        } else {
            // Push children while there is room (valid indices 0..=31).
            if sp < 31 { sp++; stack[sp] = node.left; }
            if sp < 31 { sp++; stack[sp] = node.right; }
        }
    }
}
"#;
448
449// ── WgpuInitError ─────────────────────────────────────────────────────────────
450
/// Errors reported by the wgpu backends.
///
/// Despite the name, this type covers both initialisation failures and runtime
/// failures such as invalid buffer handles.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WgpuInitError {
    /// No compatible GPU adapter was found.
    NoAdapter,
    /// No GPU backend is available on this code path. The CPU stub's
    /// [`WgpuBackend::try_new`] always returns this, whether or not the
    /// `wgpu-backend` feature is enabled.
    NotAvailable,
    /// The device request failed (e.g. out of memory).
    ///
    /// Retained for compatibility. The real backend's `try_new_async` reports
    /// device-creation failures as [`WgpuInitError::DeviceRequest`].
    DeviceRequestFailed(String),
    /// A required GPU feature is disabled or not supported.
    FeatureDisabled,
    /// Device creation failed with the given error string.
    DeviceRequest(String),
    /// A buffer handle is out of range (the payload is the handle index).
    InvalidHandle(usize),
    /// A mutex was poisoned (should not occur in practice).
    PoisonedLock,
}

impl std::fmt::Display for WgpuInitError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NoAdapter => f.write_str("No compatible GPU adapter found"),
            Self::NotAvailable => f.write_str("wgpu-backend feature not enabled"),
            Self::DeviceRequestFailed(s) => write!(f, "Device request failed: {s}"),
            Self::FeatureDisabled => f.write_str("Required GPU feature is not available"),
            Self::DeviceRequest(s) => write!(f, "Device request error: {s}"),
            Self::InvalidHandle(h) => write!(f, "Invalid buffer handle: {h}"),
            Self::PoisonedLock => f.write_str("Internal mutex was poisoned"),
        }
    }
}

impl std::error::Error for WgpuInitError {}
485
486// ── Tests ─────────────────────────────────────────────────────────────────────
487
#[cfg(test)]
mod tests {
    //! Unit tests for the CPU stub backend, the WGSL sources and the error type.

    use super::*;

    #[test]
    fn try_new_returns_not_available_in_stub_build() {
        let result = WgpuBackend::try_new();
        assert!(matches!(result, Err(WgpuInitError::NotAvailable)));
    }

    #[test]
    fn stub_backend_write_read_roundtrip() {
        let mut backend = WgpuBackend::new_stub();
        let handle = backend.create_buffer(4);
        let data = vec![1.0_f64, 2.0, 3.0, 4.0];
        backend.write_buffer(handle, &data);
        let out = backend.read_buffer(handle);
        assert_eq!(out, data);
    }

    #[test]
    fn stub_write_truncates_to_capacity() {
        let mut backend = WgpuBackend::new_stub();
        let h = backend.create_buffer(2);
        backend.write_buffer(h, &[1.0, 2.0, 3.0]);
        assert_eq!(backend.read_buffer(h), vec![1.0, 2.0]);
    }

    #[test]
    fn stub_short_write_keeps_tail() {
        let mut backend = WgpuBackend::new_stub();
        let h = backend.create_buffer(3);
        backend.write_buffer(h, &[7.0, 8.0, 9.0]);
        backend.write_buffer(h, &[5.0]);
        assert_eq!(backend.read_buffer(h), vec![5.0, 8.0, 9.0]);
    }

    #[test]
    fn stub_invalid_handle_is_ignored() {
        let mut backend = WgpuBackend::new_stub();
        let bogus = WgpuBufferHandle(42);
        backend.write_buffer(bogus, &[1.0]);
        assert!(backend.read_buffer(bogus).is_empty());
    }

    #[test]
    fn stub_handles_are_sequential_and_independent() {
        let mut backend = WgpuBackend::new_stub();
        let a = backend.create_buffer(1);
        let b = backend.create_buffer(1);
        assert_eq!(a, WgpuBufferHandle(0));
        assert_eq!(b, WgpuBufferHandle(1));
        backend.write_buffer(a, &[3.0]);
        assert_eq!(backend.read_buffer(a), vec![3.0]);
        assert_eq!(backend.read_buffer(b), vec![0.0]);
    }

    #[test]
    fn stub_is_not_available() {
        assert!(!WgpuBackend::new_stub().is_available());
    }

    #[test]
    fn stub_dispatch_is_noop() {
        let mut backend = WgpuBackend::new_stub();
        let h = backend.create_buffer(8);
        let before = backend.read_buffer(h);
        backend.dispatch("sph_density", &[h], 1);
        let after = backend.read_buffer(h);
        assert_eq!(before, after, "stub dispatch should not modify buffers");
    }

    #[test]
    fn stub_register_shader_is_noop() {
        let mut backend = WgpuBackend::new_stub();
        let h = backend.create_buffer(2);
        backend.write_buffer(h, &[1.0, 2.0]);
        backend.register_shader("scan", WGSL_PARALLEL_SCAN);
        assert_eq!(backend.read_buffer(h), vec![1.0, 2.0]);
    }

    #[test]
    fn wgsl_kernels_are_non_empty() {
        assert!(!WGSL_PARALLEL_SCAN.is_empty());
        assert!(!WGSL_SPH_DENSITY.is_empty());
        assert!(!WGSL_BVH_TRAVERSAL.is_empty());
    }

    #[test]
    fn wgsl_kernels_declare_expected_entry_points() {
        assert!(WGSL_PARALLEL_SCAN.contains("fn exclusive_scan("));
        assert!(WGSL_SPH_DENSITY.contains("fn sph_density("));
        assert!(WGSL_BVH_TRAVERSAL.contains("fn bvh_traverse("));
    }

    #[test]
    fn device_info_stub_has_name() {
        let backend = WgpuBackend::new_stub();
        assert!(!backend.device_info().name.is_empty());
        assert_eq!(backend.device_info().backend, "None");
    }

    #[test]
    fn wgpu_init_error_display() {
        assert!(!WgpuInitError::NotAvailable.to_string().is_empty());
        assert!(!WgpuInitError::NoAdapter.to_string().is_empty());
        assert!(!WgpuInitError::FeatureDisabled.to_string().is_empty());
        assert!(!WgpuInitError::DeviceRequest("oom".into())
            .to_string()
            .is_empty());
        assert!(!WgpuInitError::DeviceRequestFailed("oom".into())
            .to_string()
            .is_empty());
        assert!(!WgpuInitError::InvalidHandle(7).to_string().is_empty());
        assert!(!WgpuInitError::PoisonedLock.to_string().is_empty());
    }
}
545
546// ── Real wgpu backend (feature-gated) ─────────────────────────────────────────
547
548/// Real wgpu compute backend, enabled only with the `wgpu-backend` feature.
549///
550/// Provides GPU buffer management, WGSL shader dispatch, and CPU-side readback
551/// using `wgpu` 29's cross-platform Vulkan / Metal / DX12 backends.
552///
553/// # Thread safety
554///
555/// `wgpu::Device` and `wgpu::Queue` are `Send + Sync`. The shader cache is
556/// protected by a `Mutex`, making `WgpuBackendReal` safe to share across
557/// threads (though individual dispatches are synchronous on the calling thread).
558///
559/// # Usage
560///
561/// ```ignore
562/// // With the wgpu-backend feature enabled:
563/// use oxiphysics_gpu::compute::wgpu_backend::real::WgpuBackendReal;
564///
565/// let mut backend = WgpuBackendReal::try_new()?;
566/// let h = backend.create_buffer_f64(128);
567/// backend.write_buffer_f64(h, &vec![1.0_f64; 128]);
568/// backend.dispatch_wgsl(
569/// WGSL_SPH_DENSITY, "sph_density",
570/// &[(h, wgpu::BufferBindingType::Storage { read_only: false })],
571/// [2, 1, 1],
572/// )?;
573/// let out = backend.read_buffer_f64(h);
574/// ```
575#[cfg(feature = "wgpu-backend")]
576pub mod real {
577 use super::{WgpuBufferHandle, WgpuDeviceInfo, WgpuInitError};
578 use std::collections::HashMap;
579 use std::hash::{DefaultHasher, Hash, Hasher};
580 use std::sync::{Arc, Mutex};
581
582 // ── Internal shader-cache entry ──────────────────────────────────────────
583
584 struct ShaderCacheEntry {
585 pipeline: Arc<wgpu::ComputePipeline>,
586 }
587
588 // ── WgpuBackendReal ──────────────────────────────────────────────────────
589
590 /// Real GPU compute backend backed by `wgpu` 29.
591 ///
592 /// Obtain an instance via [`WgpuBackendReal::try_new`] (synchronous,
593 /// blocks the thread) or [`WgpuBackendReal::try_new_async`] from within
594 /// an async context.
595 pub struct WgpuBackendReal {
596 device: Arc<wgpu::Device>,
597 queue: Arc<wgpu::Queue>,
598 /// Device information (name, backend, driver).
599 pub device_info: WgpuDeviceInfo,
600 /// Allocated GPU buffers, indexed by `WgpuBufferHandle.0`.
601 buffers: Vec<Option<Arc<wgpu::Buffer>>>,
602 /// Byte size of each buffer (parallel to `buffers`).
603 buffer_sizes: Vec<u64>,
604 /// Compiled pipeline cache, keyed by a hash of WGSL source + entry point.
605 shader_cache: Mutex<HashMap<u64, ShaderCacheEntry>>,
606 }
607
608 impl WgpuBackendReal {
609 // ── Construction ─────────────────────────────────────────────────────
610
611 /// Create a real GPU backend, blocking the calling thread.
612 ///
613 /// Returns `Err` if no compatible GPU adapter is found or if device
614 /// creation fails. Prefer [`try_new_async`](Self::try_new_async) from
615 /// within an `async` context.
616 pub fn try_new() -> Result<Self, WgpuInitError> {
617 pollster::block_on(Self::try_new_async())
618 }
619
620 /// Create a real GPU backend asynchronously.
621 ///
622 /// This is the preferred entry point from `async` contexts (tokio,
623 /// wasm-bindgen-futures, etc.).
624 pub async fn try_new_async() -> Result<Self, WgpuInitError> {
625 let instance =
626 wgpu::Instance::new(wgpu::InstanceDescriptor::new_without_display_handle());
627
628 let adapter = instance
629 .request_adapter(&wgpu::RequestAdapterOptions {
630 power_preference: wgpu::PowerPreference::HighPerformance,
631 compatible_surface: None,
632 force_fallback_adapter: false,
633 })
634 .await
635 .map_err(|_| WgpuInitError::NoAdapter)?;
636
637 let info = adapter.get_info();
638
639 let desc = wgpu::DeviceDescriptor {
640 label: Some("oxiphysics-wgpu"),
641 required_features: wgpu::Features::empty(),
642 required_limits: adapter.limits(),
643 ..Default::default()
644 };
645
646 let (device, queue) = adapter
647 .request_device(&desc)
648 .await
649 .map_err(|e| WgpuInitError::DeviceRequest(e.to_string()))?;
650
651 let device_info = WgpuDeviceInfo {
652 name: info.name.clone(),
653 backend: format!("{:?}", info.backend),
654 driver_version: info.driver_info.clone(),
655 // VRAM is not exposed by wgpu's AdapterInfo; use 0 as sentinel.
656 vram_bytes: 0,
657 // GPU-native f64 requires a device extension not in the base profile.
658 supports_f64: false,
659 // Conservative defaults matching most desktop GPU limits.
660 max_workgroup_size: [256, 256, 64],
661 };
662
663 Ok(Self {
664 device: Arc::new(device),
665 queue: Arc::new(queue),
666 device_info,
667 buffers: Vec::new(),
668 buffer_sizes: Vec::new(),
669 shader_cache: Mutex::new(HashMap::new()),
670 })
671 }
672
673 /// Return `true` — this struct always wraps a real GPU device.
674 pub fn is_available(&self) -> bool {
675 true
676 }
677
678 // ── Buffer management ─────────────────────────────────────────────────
679
680 /// Allocate a GPU storage buffer of `size_bytes` bytes.
681 ///
682 /// The buffer is created with `STORAGE | COPY_SRC | COPY_DST` usage
683 /// flags so that it can be used as a shader binding and for staged
684 /// CPU read/write.
685 pub fn create_buffer_storage(&mut self, size_bytes: u64) -> WgpuBufferHandle {
686 let handle = WgpuBufferHandle(self.buffers.len());
687 let buf = self.device.create_buffer(&wgpu::BufferDescriptor {
688 label: None,
689 size: size_bytes,
690 usage: wgpu::BufferUsages::STORAGE
691 | wgpu::BufferUsages::COPY_SRC
692 | wgpu::BufferUsages::COPY_DST,
693 mapped_at_creation: false,
694 });
695 self.buffers.push(Some(Arc::new(buf)));
696 self.buffer_sizes.push(size_bytes);
697 handle
698 }
699
700 /// Allocate a GPU buffer sized for `len` `f64` values.
701 ///
702 /// Internally the data is stored as `f32` on the GPU (8 bytes per
703 /// element to maintain the same stride).
704 pub fn create_buffer_f64(&mut self, len: usize) -> WgpuBufferHandle {
705 // We store f64 values packed as two f32s to preserve stride; or
706 // simply allocate 8 bytes per element and use the f32 path with
707 // two floats per logical element. For simplicity, the current
708 // implementation casts f64→f32 on write and f32→f64 on read, so
709 // we only need 4 bytes per element on the GPU.
710 self.create_buffer_storage((len * 4) as u64)
711 }
712
713 /// Upload `data` to the GPU buffer at `handle`, casting `f64` → `f32`.
714 ///
715 /// # Panics
716 ///
717 /// Does nothing (silently returns) if `handle` is out of range.
718 pub fn write_buffer_f64(&self, handle: WgpuBufferHandle, data: &[f64]) {
719 if let Some(Some(buf)) = self.buffers.get(handle.0) {
720 let f32_data: Vec<f32> = data.iter().map(|&v| v as f32).collect();
721 self.queue
722 .write_buffer(buf, 0, bytemuck::cast_slice(&f32_data));
723 }
724 }
725
726 /// Download data from the GPU buffer at `handle`, casting `f32` → `f64`.
727 ///
728 /// This blocks the calling thread until the GPU has finished all
729 /// outstanding work and the readback mapping is complete.
730 ///
731 /// Returns an empty `Vec` if the handle is invalid or the readback fails.
732 pub fn read_buffer_f64(&self, handle: WgpuBufferHandle) -> Vec<f64> {
733 let buf = match self.buffers.get(handle.0).and_then(|b| b.as_ref()) {
734 Some(b) => b.clone(),
735 None => return Vec::new(),
736 };
737 let size = self.buffer_sizes[handle.0];
738
739 // Create a CPU-visible staging buffer for the readback.
740 let staging = self.device.create_buffer(&wgpu::BufferDescriptor {
741 label: Some("oxiphysics_staging_readback"),
742 size,
743 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
744 mapped_at_creation: false,
745 });
746
747 // Record and submit the copy command.
748 let mut encoder = self
749 .device
750 .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
751 encoder.copy_buffer_to_buffer(&buf, 0, &staging, 0, size);
752 self.queue.submit(std::iter::once(encoder.finish()));
753
754 // Map the staging buffer for reading.
755 let slice = staging.slice(..);
756 let (tx, rx) = std::sync::mpsc::channel();
757 slice.map_async(wgpu::MapMode::Read, move |result| {
758 let _ = tx.send(result);
759 });
760
761 // Block until the GPU has completed and the mapping is ready.
762 if let Err(_e) = self.device.poll(wgpu::PollType::Wait {
763 submission_index: None,
764 timeout: None,
765 }) {
766 return Vec::new();
767 }
768
769 // Check that the mapping succeeded.
770 if rx.recv().ok().and_then(|r| r.ok()).is_none() {
771 return Vec::new();
772 }
773
774 let mapped = slice.get_mapped_range();
775 let f32_data: &[f32] = bytemuck::cast_slice(&mapped);
776 let result: Vec<f64> = f32_data.iter().map(|&v| v as f64).collect();
777 drop(mapped);
778 staging.unmap();
779 result
780 }
781
        // ── Raw and f32 buffer I/O ────────────────────────────────────────────
783
784 /// Upload raw bytes to the GPU buffer at `handle`.
785 ///
786 /// The byte slice must fit within the buffer's allocated size.
787 /// Does nothing (silently returns) if `handle` is out of range.
788 pub fn queue_write_buffer_raw(&self, handle: &WgpuBufferHandle, data: &[u8]) {
789 if let Some(Some(buf)) = self.buffers.get(handle.0) {
790 self.queue.write_buffer(buf, 0, data);
791 }
792 }
793
794 /// Upload `f32` data directly to the GPU buffer at `handle` (no f64→f32 cast).
795 ///
796 /// Does nothing (silently returns) if `handle` is out of range.
797 pub fn queue_write_buffer_f32(&self, handle: &WgpuBufferHandle, data: &[f32]) {
798 if let Some(Some(buf)) = self.buffers.get(handle.0) {
799 self.queue.write_buffer(buf, 0, bytemuck::cast_slice(data));
800 }
801 }
802
803 /// Download raw `f32` values from the GPU buffer at `handle`.
804 ///
805 /// Returns an empty `Vec` if the handle is invalid or the readback fails.
806 pub fn read_buffer_f32(&self, handle: WgpuBufferHandle) -> Vec<f32> {
807 let buf = match self.buffers.get(handle.0).and_then(|b| b.as_ref()) {
808 Some(b) => b.clone(),
809 None => return Vec::new(),
810 };
811 let size = self.buffer_sizes[handle.0];
812
813 let staging = self.device.create_buffer(&wgpu::BufferDescriptor {
814 label: Some("oxiphysics_staging_readback_f32"),
815 size,
816 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
817 mapped_at_creation: false,
818 });
819
820 let mut encoder = self
821 .device
822 .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
823 encoder.copy_buffer_to_buffer(&buf, 0, &staging, 0, size);
824 self.queue.submit(std::iter::once(encoder.finish()));
825
826 let slice = staging.slice(..);
827 let (tx, rx) = std::sync::mpsc::channel();
828 slice.map_async(wgpu::MapMode::Read, move |result| {
829 let _ = tx.send(result);
830 });
831
832 if let Err(_e) = self.device.poll(wgpu::PollType::Wait {
833 submission_index: None,
834 timeout: None,
835 }) {
836 return Vec::new();
837 }
838
839 if rx.recv().ok().and_then(|r| r.ok()).is_none() {
840 return Vec::new();
841 }
842
843 let mapped = slice.get_mapped_range();
844 let result: Vec<f32> = bytemuck::cast_slice::<u8, f32>(&mapped).to_vec();
845 drop(mapped);
846 staging.unmap();
847 result
848 }
849
850 // ── Dispatch ──────────────────────────────────────────────────────────
851
852 /// Compute the 3-D workgroup dispatch counts for `n_items` elements.
853 ///
854 /// Returns `[0, 1, 1]` for `n_items == 0` (no-op dispatch).
855 pub fn dispatch_count_for(n_items: usize, workgroup_size: u32) -> [u32; 3] {
856 crate::compute::timestamp::dispatch_count_for(n_items, workgroup_size)
857 }
858
859 /// Compile and dispatch a WGSL compute shader.
860 ///
861 /// The pipeline is compiled lazily and cached by a hash of
862 /// `(wgsl_src, entry_point)`, so repeated calls with the same shader
863 /// do not recompile.
864 ///
865 /// # Parameters
866 ///
867 /// * `wgsl_src` — WGSL shader source code.
868 /// * `entry_point` — Name of the `@compute` entry point function.
869 /// * `buffers` — Ordered list of `(handle, binding_type)` pairs.
870 /// Binding index in the WGSL shader corresponds to the position in
871 /// this slice (binding 0 = `buffers[0]`, etc.).
872 /// * `workgroups` — `[x, y, z]` dispatch counts.
873 ///
874 /// # Errors
875 ///
876 /// Returns `Err(WgpuInitError::InvalidHandle)` if any buffer handle is
877 /// out of range. Returns `Err(WgpuInitError::PoisonedLock)` if the
878 /// shader-cache mutex is poisoned (should not occur in practice).
879 pub fn dispatch_wgsl(
880 &self,
881 wgsl_src: &str,
882 entry_point: &str,
883 buffers: &[(WgpuBufferHandle, wgpu::BufferBindingType)],
884 workgroups: [u32; 3],
885 ) -> Result<(), WgpuInitError> {
886 // Hash the shader source + entry point to key the pipeline cache.
887 let mut hasher = DefaultHasher::new();
888 wgsl_src.hash(&mut hasher);
889 entry_point.hash(&mut hasher);
890 let key = hasher.finish();
891
892 // Obtain or compile the pipeline.
893 let pipeline: Arc<wgpu::ComputePipeline> = {
894 let mut cache = self.shader_cache.lock().unwrap_or_else(|e| e.into_inner());
895
896 if let Some(entry) = cache.get(&key) {
897 entry.pipeline.clone()
898 } else {
899 let module = self
900 .device
901 .create_shader_module(wgpu::ShaderModuleDescriptor {
902 label: Some(entry_point),
903 source: wgpu::ShaderSource::Wgsl(wgsl_src.into()),
904 });
905 let pipeline = Arc::new(self.device.create_compute_pipeline(
906 &wgpu::ComputePipelineDescriptor {
907 label: Some(entry_point),
908 layout: None,
909 module: &module,
910 entry_point: Some(entry_point),
911 compilation_options: wgpu::PipelineCompilationOptions::default(),
912 cache: None,
913 },
914 ));
915 cache.insert(
916 key,
917 ShaderCacheEntry {
918 pipeline: pipeline.clone(),
919 },
920 );
921 pipeline
922 }
923 };
924
925 // Derive the bind-group layout from the compiled pipeline.
926 let bg_layout = pipeline.get_bind_group_layout(0);
927
928 // Build the bind-group entries.
929 let mut entries: Vec<wgpu::BindGroupEntry> = Vec::with_capacity(buffers.len());
930 for (i, (handle, _binding_type)) in buffers.iter().enumerate() {
931 let buf = self
932 .buffers
933 .get(handle.0)
934 .and_then(|b| b.as_ref())
935 .ok_or(WgpuInitError::InvalidHandle(handle.0))?;
936 entries.push(wgpu::BindGroupEntry {
937 binding: i as u32,
938 resource: buf.as_entire_binding(),
939 });
940 }
941
942 let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
943 label: None,
944 layout: &bg_layout,
945 entries: &entries,
946 });
947
948 // Record and submit the compute pass.
949 let mut encoder = self
950 .device
951 .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
952 {
953 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
954 label: None,
955 timestamp_writes: None,
956 });
957 pass.set_pipeline(&pipeline);
958 pass.set_bind_group(0, &bind_group, &[]);
959 pass.dispatch_workgroups(workgroups[0], workgroups[1], workgroups[2]);
960 }
961 self.queue.submit(std::iter::once(encoder.finish()));
962
963 // Block until the GPU has finished (synchronous dispatch).
964 self.device
965 .poll(wgpu::PollType::Wait {
966 submission_index: None,
967 timeout: None,
968 })
969 .map_err(|_| WgpuInitError::DeviceRequest("poll failed".into()))?;
970
971 Ok(())
972 }
973 }
974
975 // ── Feature-gated tests ───────────────────────────────────────────────────
976
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a real backend, or return `None` when no GPU adapter is usable
    /// (for example in headless CI). GPU-dependent tests return early on `None`.
    fn try_backend() -> Option<WgpuBackendReal> {
        match WgpuBackendReal::try_new() {
            Ok(backend) => Some(backend),
            Err(_) => None,
        }
    }

    #[test]
    fn real_backend_try_new_succeeds_or_gracefully_fails() {
        // This test passes either way. With a GPU, the backend must report
        // itself available with a non-empty backend name. Without one,
        // `try_new` returns an error, which is only logged.
        match WgpuBackendReal::try_new() {
            Err(e) => eprintln!("No GPU adapter available: {e}"),
            Ok(gpu) => {
                assert!(gpu.is_available());
                assert!(!gpu.device_info.backend.is_empty());
            }
        }
    }

    #[test]
    fn real_backend_create_and_write_buffer() {
        let Some(mut gpu) = try_backend() else { return };

        let values = vec![1.0_f64, 2.0, 3.0, 4.0];
        let handle = gpu.create_buffer_f64(values.len());
        // The write is not read back here; the test only checks that it does
        // not panic and that the handle indexes a registered buffer.
        gpu.write_buffer_f64(handle, &values);
        assert!(handle.0 < gpu.buffers.len());
    }

    #[test]
    fn real_backend_buffer_roundtrip() {
        let Some(mut gpu) = try_backend() else { return };

        let input = vec![1.0_f64, 2.0, 3.0, 4.0];
        let handle = gpu.create_buffer_f64(input.len());
        gpu.write_buffer_f64(handle, &input);
        let output = gpu.read_buffer_f64(handle);

        // Values are compared at f32 precision; the tolerance allows for
        // f64 -> f32 -> f64 conversion along the round trip.
        assert_eq!(output.len(), input.len());
        for (expected, got) in input.iter().copied().zip(output.iter().copied()) {
            let diff = (expected as f32 - got as f32).abs();
            assert!(diff < 1e-5, "roundtrip mismatch: expected {expected}, got {got}");
        }
    }

    #[test]
    fn real_backend_dispatch_scale_shader() {
        use super::super::WgpuBackend;

        let Some(mut gpu) = try_backend() else { return };

        // Writes `input_buf[i] * 2.0` into `output_buf[i]` for every element.
        const SCALE_BY_TWO: &str = r#"
@group(0) @binding(0) var<storage, read> input_buf: array<f32>;
@group(0) @binding(1) var<storage, read_write> output_buf: array<f32>;

@compute @workgroup_size(64)
fn scale_by_two(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    if i < arrayLength(&input_buf) {
        output_buf[i] = input_buf[i] * 2.0;
    }
}
"#;

        const N: usize = 4;
        let byte_len = (N * std::mem::size_of::<f32>()) as u64;
        let input: Vec<f32> = (1..=N as u32).map(|x| x as f32).collect();

        let src = gpu.create_buffer_storage(byte_len);
        let dst = gpu.create_buffer_storage(byte_len);

        let src_buffer = gpu.buffers[src.0].as_ref().unwrap();
        gpu.queue
            .write_buffer(src_buffer, 0, bytemuck::cast_slice(&input));

        // A single workgroup of 64 invocations covers all N = 4 elements.
        let bindings = [
            (src, wgpu::BufferBindingType::Storage { read_only: true }),
            (dst, wgpu::BufferBindingType::Storage { read_only: false }),
        ];
        let result = gpu.dispatch_wgsl(SCALE_BY_TWO, "scale_by_two", &bindings, [1, 1, 1]);
        assert!(result.is_ok(), "dispatch_wgsl failed: {:?}", result.err());

        // Read the output back through staging and check each doubled value.
        let output = gpu.read_buffer_f64(dst);
        assert_eq!(output.len(), N);
        for (i, &v) in output.iter().enumerate() {
            let expected = (i + 1) as f64 * 2.0;
            assert!(
                (v - expected).abs() < 0.01,
                "element {i}: expected {expected}, got {v}"
            );
        }

        // Regression guard: the CPU stub backend must keep working.
        let mut stub = WgpuBackend::new_stub();
        let h = stub.create_buffer(4);
        let _ = stub.read_buffer(h);
    }

    #[test]
    fn dispatch_count_for_zero_items() {
        let counts = WgpuBackendReal::dispatch_count_for(0, 64);
        assert_eq!(counts, [0, 1, 1]);
    }

    #[test]
    fn dispatch_count_for_65_items() {
        let counts = WgpuBackendReal::dispatch_count_for(65, 64);
        assert_eq!(counts, [2, 1, 1]);
    }

    #[test]
    fn dispatch_count_for_exact_workgroup() {
        let counts = WgpuBackendReal::dispatch_count_for(256, 64);
        assert_eq!(counts, [4, 1, 1]);
    }
}
1115}