Skip to main content

oxiphysics_gpu/compute/
wgpu_backend.rs

1// Copyright 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4//! WebGPU (wgpu) compute backend for the OxiPhysics GPU acceleration layer.
5//!
6//! This module provides [`WgpuBackend`] which implements `ComputeBackend` using
7//! the `wgpu` crate for cross-platform GPU compute (Vulkan, Metal, DX12, WebGPU).
8//!
9//! ## Feature flag
10//!
11//! This module is gated behind the `wgpu-backend` Cargo feature:
12//!
13//! ```toml
14//! [dependencies]
15//! oxiphysics-gpu = { features = ["wgpu-backend"] }
16//! ```
17//!
//! When the feature is disabled only the CPU-side stub ([`WgpuBackend`]) is
//! compiled; the `real` submodule that talks to `wgpu` is omitted.  This allows
//! the crate to compile without the `wgpu` dependency on platforms or toolchains
//! where GPU support is not required.
21//!
22//! ## Enabling the dependency
23//!
24//! To activate the wgpu backend, add `wgpu` to the crate's `Cargo.toml`:
25//!
26//! ```toml
27//! [features]
28//! wgpu-backend = ["wgpu"]
29//!
30//! [dependencies]
//! wgpu = { version = "29", optional = true }
32//! ```
33//!
34//! ## Architecture
35//!
36//! ```text
37//!  WgpuBackend
38//!   ├── wgpu::Device / wgpu::Queue          ← GPU device & command queue
39//!   ├── Vec<WgpuBufferEntry>                 ← Registered GPU buffers
40//!   │     ├── wgpu::Buffer (device memory)
41//!   │     └── size, usage flags
42//!   └── ShaderRegistry                       ← Compiled WGSL compute shaders
43//!
44//!  Compute pipeline:
45//!    write_buffer → [upload via staging] → dispatch(kernel) → [readback via staging] → read_buffer
46//! ```
47//!
48//! ## Usage (when feature is enabled)
49//!
50//! ```ignore
51//! use oxiphysics_gpu::compute::wgpu_backend::WgpuBackend;
52//! use oxiphysics_gpu::compute::ComputeBackend;
53//!
//! let mut backend = WgpuBackend::try_new()?;
55//! let handle = backend.create_buffer(1024);
56//! backend.write_buffer(handle, &vec![1.0_f64; 128]);
57//! // ... dispatch kernel ...
58//! let data = backend.read_buffer(handle);
59//! ```
60
61#![allow(dead_code)]
62
63// ── BufferHandle (re-used from parent module) ─────────────────────────────────
64
/// Small, copyable identifier for a buffer owned by a wgpu backend.
///
/// The wrapped `usize` is the position of the buffer in the backend's internal
/// buffer list, so handles are only meaningful for the backend that issued
/// them.  The type is intended to mirror the buffer handle of the parent
/// `compute` module so that [`WgpuBackend`] can expose the same
/// `ComputeBackend`-style API.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct WgpuBufferHandle(pub usize);
71
72// ── WgpuDeviceInfo ────────────────────────────────────────────────────────────
73
/// Diagnostic description of the GPU device chosen through the wgpu adapter.
///
/// `Default` yields empty strings, zero VRAM, `supports_f64 = false` and an
/// all-zero workgroup size.
#[derive(Clone, Debug, Default)]
pub struct WgpuDeviceInfo {
    /// Human-readable device name (e.g. `"NVIDIA GeForce RTX 4090"`).
    pub name: String,
    /// Graphics API in use: `"Vulkan"`, `"Metal"`, `"Dx12"`, `"WebGpu"`, or `"None"`.
    pub backend: String,
    /// Driver version string; empty when the adapter does not report one.
    pub driver_version: String,
    /// Total VRAM in bytes; `0` means the adapter did not report a value.
    pub vram_bytes: u64,
    /// Whether the device supports 64-bit floating-point storage.
    pub supports_f64: bool,
    /// Maximum workgroup size along each axis, ordered `[x, y, z]`.
    pub max_workgroup_size: [u32; 3],
}
90
91// ── WgpuBackend ───────────────────────────────────────────────────────────────
92
/// WebGPU compute backend (CPU-side stub).
///
/// As written in this file the type never touches a GPU: every buffer is a
/// plain `Vec<f64>` kept in `buffers`, [`WgpuBackend::try_new`] always returns
/// an error, and [`WgpuBackend::new_stub`] builds an instance whose
/// `available` flag is `false`.  The TODO sketches inside the `impl` block
/// outline how a real wgpu `Device` / `Queue` pair is meant to replace the
/// stub once the `wgpu-backend` feature path is filled in.
#[derive(Debug)]
pub struct WgpuBackend {
    /// Device description returned by [`WgpuBackend::device_info`].
    ///
    /// The stub constructor fills in the name `"CPU stub"` and the backend
    /// `"None"`; all other fields keep their `Default` values.
    pub device_info: WgpuDeviceInfo,
    /// One entry per `create_buffer` call, indexed by `WgpuBufferHandle.0`.
    ///
    /// In the stub each entry is a CPU-side `Vec<f64>` standing in for a
    /// `wgpu::Buffer`.
    buffers: Vec<WgpuBufferEntry>,
    /// `true` only when a real GPU device backs this instance; the stub
    /// constructor always stores `false`.
    available: bool,
}
114
/// Bookkeeping record for one buffer held by the stub backend.
#[derive(Clone, Debug)]
struct WgpuBufferEntry {
    /// Number of `f64` elements the buffer can hold (the `len` passed to
    /// `create_buffer`); writes are clamped to this count.
    capacity: usize,
    /// CPU-side storage used for uploads and downloads, so the stub needs no
    /// wgpu dependency.
    shadow: Vec<f64>,
}
123
124impl WgpuBackend {
125    /// Attempt to create a wgpu backend.
126    ///
127    /// Returns `Ok(Self)` when a compatible GPU adapter is available, or
128    /// `Err(WgpuInitError::NotAvailable)` when no adapter can be found (e.g.
129    /// running headless without a GPU or without the `wgpu-backend` feature).
130    ///
131    /// In the current stub implementation this always returns a CPU-fallback
132    /// instance with `available = false`.  The full implementation calls
133    /// `wgpu::Instance::request_adapter` and `adapter.request_device`.
134    pub fn try_new() -> Result<Self, WgpuInitError> {
135        // ── TODO (wgpu-backend feature) ─────────────────────────────────────
136        // When `wgpu-backend` is enabled, replace this stub with:
137        //
138        //   let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
139        //       backends: wgpu::Backends::all(),
140        //       ..Default::default()
141        //   });
142        //   let adapter = pollster::block_on(instance.request_adapter(
143        //       &wgpu::RequestAdapterOptions {
144        //           power_preference: wgpu::PowerPreference::HighPerformance,
145        //           ..Default::default()
146        //       },
147        //   )).ok_or(WgpuInitError::NoAdapter)?;
148        //   let (device, queue) = pollster::block_on(adapter.request_device(
149        //       &wgpu::DeviceDescriptor::default(),
150        //       None,
151        //   ))?;
152        //   let info = adapter.get_info();
153        //   Ok(Self { device, queue, info, buffers: Vec::new(), available: true })
154        // ────────────────────────────────────────────────────────────────────
155
156        Err(WgpuInitError::NotAvailable)
157    }
158
159    /// Create a stub backend for testing that stores data in CPU memory.
160    ///
161    /// This is equivalent to what `try_new` would return on a headless system
162    /// but without returning an error — useful for unit testing backend logic.
163    pub fn new_stub() -> Self {
164        Self {
165            device_info: WgpuDeviceInfo {
166                name: "CPU stub".to_string(),
167                backend: "None".to_string(),
168                ..Default::default()
169            },
170            buffers: Vec::new(),
171            available: false,
172        }
173    }
174
175    /// Return `true` if a real GPU device is available.
176    pub fn is_available(&self) -> bool {
177        self.available
178    }
179
180    /// Return device information for diagnostics.
181    pub fn device_info(&self) -> &WgpuDeviceInfo {
182        &self.device_info
183    }
184
185    // ── Buffer management ────────────────────────────────────────────────────
186
187    /// Allocate a GPU buffer that can hold `len` `f64` values.
188    ///
189    /// Returns a [`WgpuBufferHandle`] that can be passed to [`Self::write_buffer`]
190    /// and [`Self::read_buffer`].
191    ///
192    /// In the stub implementation, a CPU-side shadow `Vec<f64>` is allocated.
193    /// In the full wgpu implementation, `wgpu::Device::create_buffer` is called
194    /// with `STORAGE | COPY_SRC | COPY_DST` usage flags.
195    pub fn create_buffer(&mut self, len: usize) -> WgpuBufferHandle {
196        let handle = WgpuBufferHandle(self.buffers.len());
197        self.buffers.push(WgpuBufferEntry {
198            capacity: len,
199            shadow: vec![0.0; len],
200        });
201        handle
202    }
203
204    /// Upload `data` to the GPU buffer at `handle`.
205    ///
206    /// In the stub, data is copied into the CPU-side shadow.
207    /// In the full implementation, `queue.write_buffer` is used.
208    pub fn write_buffer(&mut self, handle: WgpuBufferHandle, data: &[f64]) {
209        if let Some(entry) = self.buffers.get_mut(handle.0) {
210            let len = data.len().min(entry.capacity);
211            entry.shadow[..len].copy_from_slice(&data[..len]);
212        }
213    }
214
215    /// Download data from the GPU buffer at `handle`.
216    ///
217    /// In the stub, data is read from the CPU-side shadow.
218    /// In the full implementation, a staging buffer is created, the command
219    /// `encoder.copy_buffer_to_buffer` is executed, and the staging buffer is
220    /// mapped for reading.
221    pub fn read_buffer(&self, handle: WgpuBufferHandle) -> Vec<f64> {
222        self.buffers
223            .get(handle.0)
224            .map(|e| e.shadow.clone())
225            .unwrap_or_default()
226    }
227
228    // ── Dispatch ─────────────────────────────────────────────────────────────
229
230    /// Dispatch a compute kernel with `work_groups_x` workgroups.
231    ///
232    /// In the stub, the kernel's `execute` method is called on the CPU-side
233    /// shadow data.  In the full implementation a `ComputePipeline` is looked
234    /// up from the shader registry and `encoder.dispatch_workgroups` is called.
235    ///
236    /// # Arguments
237    ///
238    /// * `kernel_name` — name of the WGSL shader entry point
239    /// * `buffers`     — input/output buffer handles
240    /// * `work_groups_x` — number of workgroups in the X dimension
241    pub fn dispatch(
242        &mut self,
243        kernel_name: &str,
244        buffers: &[WgpuBufferHandle],
245        work_groups_x: u32,
246    ) {
247        // ── TODO (wgpu-backend feature) ─────────────────────────────────────
248        // When enabled:
249        //   let pipeline = self.shader_registry.get_pipeline(kernel_name)?;
250        //   let bind_group = self.device.create_bind_group(…);
251        //   let mut encoder = self.device.create_command_encoder(…);
252        //   {
253        //       let mut pass = encoder.begin_compute_pass(…);
254        //       pass.set_pipeline(&pipeline);
255        //       pass.set_bind_group(0, &bind_group, &[]);
256        //       pass.dispatch_workgroups(work_groups_x, 1, 1);
257        //   }
258        //   self.queue.submit([encoder.finish()]);
259        // ────────────────────────────────────────────────────────────────────
260
261        // Stub: identity kernel (pass-through, no-op)
262        let _ = (kernel_name, buffers, work_groups_x);
263    }
264
265    // ── WGSL shader registry ──────────────────────────────────────────────────
266
267    /// Register a WGSL compute shader source and associate it with a name.
268    ///
269    /// In the stub, the source is stored but not compiled.
270    /// In the full implementation, `device.create_shader_module` is called and
271    /// the resulting `ShaderModule` is cached.
272    pub fn register_shader(&mut self, name: &str, wgsl_source: &str) {
273        // ── TODO (wgpu-backend feature) ─────────────────────────────────────
274        // let module = self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
275        //     label: Some(name),
276        //     source: wgpu::ShaderSource::Wgsl(wgsl_source.into()),
277        // });
278        // self.shader_registry.insert(name.to_string(), module);
279        let _ = (name, wgsl_source);
280    }
281}
282
283// ── Built-in WGSL kernels ─────────────────────────────────────────────────────
284
/// WGSL source for a parallel prefix sum (exclusive scan) kernel.
///
/// This is the Blelloch up-sweep / down-sweep algorithm with a workgroup of
/// 256 threads.  Each workgroup scans its own 256-element tile in workgroup
/// memory; the kernel does not carry sums across workgroups, so inputs longer
/// than one tile need an additional pass on the host side.
///
/// Identifiers avoid WGSL reserved words (`shared`, `pass`), which the
/// language forbids as variable or member names.  The uniform layout is
/// unchanged: two consecutive `u32` values.
pub const WGSL_PARALLEL_SCAN: &str = r#"
// Exclusive parallel prefix sum (Blelloch up-sweep / down-sweep)
// Workgroup size: 256 threads
// Binding 0: input buffer (read)
// Binding 1: output buffer (write)
// Binding 2: uniform { n: u32, pass_index: u32 }

@group(0) @binding(0) var<storage, read> input:  array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

struct Params { n: u32, pass_index: u32 }
@group(0) @binding(2) var<uniform> params: Params;

var<workgroup> scratch: array<f32, 256>;

@compute @workgroup_size(256)
fn exclusive_scan(@builtin(global_invocation_id) gid: vec3<u32>,
                  @builtin(local_invocation_id) lid: vec3<u32>) {
    let n = params.n;
    let i = gid.x;

    // Load
    scratch[lid.x] = select(0.0, input[i], i < n);
    workgroupBarrier();

    // Up-sweep (reduce)
    var stride: u32 = 1u;
    loop {
        if stride >= 256u { break; }
        if lid.x % (stride * 2u) == (stride * 2u - 1u) {
            scratch[lid.x] += scratch[lid.x - stride];
        }
        workgroupBarrier();
        stride = stride * 2u;
    }

    // Down-sweep
    if lid.x == 255u { scratch[255] = 0.0; }
    workgroupBarrier();
    stride = 128u;
    loop {
        if stride == 0u { break; }
        if lid.x % (stride * 2u) == (stride * 2u - 1u) {
            let tmp = scratch[lid.x - stride];
            scratch[lid.x - stride] = scratch[lid.x];
            scratch[lid.x] += tmp;
        }
        workgroupBarrier();
        stride = stride / 2u;
    }

    // Store
    if i < n { output[i] = scratch[lid.x]; }
}
"#;
343
/// WGSL source for a simple SPH density kernel.
///
/// Entry point: `sph_density`, workgroup size 64, one invocation per particle.
/// Each invocation loops over all `n` particles (itself included) and sums
/// `mass * W(r, h)`, where `W` is the cubic-spline kernel `w_spline3` that
/// vanishes for `r >= 2h`.
///
/// Bindings (group 0):
/// * 0 — `positions`: flat `f32` array, three consecutive values per particle
/// * 1 — `densities`: one `f32` written per particle
/// * 2 — uniform `SphParams { n, h, mass }`
pub const WGSL_SPH_DENSITY: &str = r#"
// SPH density kernel — W_spline3 smoothing
// Binding 0: positions array (x0,y0,z0, x1,y1,z1, ...)
// Binding 1: densities output (one per particle)
// Binding 2: uniform { n: u32, h: f32, mass: f32 }

struct SphParams { n: u32, h: f32, mass: f32 }
@group(0) @binding(0) var<storage, read>       positions: array<f32>;
@group(0) @binding(1) var<storage, read_write> densities: array<f32>;
@group(0) @binding(2) var<uniform>             params:    SphParams;

fn w_spline3(r: f32, h: f32) -> f32 {
    let q = r / h;
    let sigma = 3.0 / (2.0 * 3.14159265358979 * h * h * h);
    if q < 1.0 {
        return sigma * (2.0/3.0 - q*q + 0.5*q*q*q);
    } else if q < 2.0 {
        let t = 2.0 - q;
        return sigma * (1.0/6.0) * t*t*t;
    } else {
        return 0.0;
    }
}

@compute @workgroup_size(64)
fn sph_density(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    let n = params.n;
    if i >= n { return; }

    let xi = vec3<f32>(positions[i*3u], positions[i*3u+1u], positions[i*3u+2u]);
    var density: f32 = 0.0;

    for (var j: u32 = 0u; j < n; j++) {
        let xj = vec3<f32>(positions[j*3u], positions[j*3u+1u], positions[j*3u+2u]);
        let r = length(xi - xj);
        density += params.mass * w_spline3(r, params.h);
    }

    densities[i] = density;
}
"#;
389
/// WGSL source for parallel BVH ray traversal.
///
/// Entry point: `bvh_traverse`, workgroup size 64, one invocation per ray.
/// Each invocation walks a linearized BVH with an explicit 32-slot stack and
/// records the nearest leaf bounding box hit by the ray.  This is a stub: only
/// ray–box tests are performed; real traversal requires the full BVH node and
/// primitive buffer layout.
///
/// `ray_aabb` returns the entry distance clamped to the ray start, so a ray
/// whose origin lies inside a box reports a hit at `t = 0` rather than a
/// negative distance that the caller would mistake for the `-1.0` "miss"
/// sentinel (which would skip the whole subtree).
pub const WGSL_BVH_TRAVERSAL: &str = r#"
// Parallel BVH ray traversal stub
// Each thread handles one ray; BVH nodes are in binding 0.

struct Ray { origin: vec3<f32>, dir: vec3<f32>, t_max: f32 }
struct BvhNode { lo: vec3<f32>, hi: vec3<f32>, left: u32, right: u32, is_leaf: u32, prim: u32 }
struct HitResult { hit: u32, t: f32, prim: u32 }

@group(0) @binding(0) var<storage, read>       nodes:   array<BvhNode>;
@group(0) @binding(1) var<storage, read>       rays:    array<Ray>;
@group(0) @binding(2) var<storage, read_write> results: array<HitResult>;
@group(0) @binding(3) var<uniform>             num_rays: u32;

// Slab test. Returns the distance at which the ray enters the box (0 when the
// origin is already inside), or -1.0 when the box is missed, lies behind the
// ray, or starts beyond ray.t_max.
fn ray_aabb(ray: Ray, lo: vec3<f32>, hi: vec3<f32>) -> f32 {
    let inv_dir = 1.0 / ray.dir;
    let t0 = (lo - ray.origin) * inv_dir;
    let t1 = (hi - ray.origin) * inv_dir;
    let t_near = max(max(min(t0.x, t1.x), min(t0.y, t1.y)), min(t0.z, t1.z));
    let t_far = min(min(max(t0.x, t1.x), max(t0.y, t1.y)), max(t0.z, t1.z));
    // Clamp to the ray start so an origin inside the box is not reported as a
    // negative distance (indistinguishable from the miss sentinel).
    let t_enter = max(t_near, 0.0);
    if t_far < t_enter || t_enter > ray.t_max { return -1.0; }
    return t_enter;
}

@compute @workgroup_size(64)
fn bvh_traverse(@builtin(global_invocation_id) gid: vec3<u32>) {
    let rid = gid.x;
    if rid >= num_rays { return; }
    let ray = rays[rid];
    results[rid] = HitResult(0u, ray.t_max, 0xFFFFFFFFu);

    // Iterative DFS stack (max depth 32)
    var stack: array<u32, 32>;
    var sp: i32 = 0;
    stack[0] = 0u;

    loop {
        if sp < 0 { break; }
        let node_idx = stack[sp]; sp--;
        let node = nodes[node_idx];

        let t = ray_aabb(ray, node.lo, node.hi);
        if t < 0.0 { continue; }

        if node.is_leaf != 0u {
            if t < results[rid].t {
                results[rid] = HitResult(1u, t, node.prim);
            }
        } else {
            if sp < 30 { sp++; stack[sp] = node.left; }
            if sp < 30 { sp++; stack[sp] = node.right; }
        }
    }
}
"#;
448
449// ── WgpuInitError ─────────────────────────────────────────────────────────────
450
/// Failure modes reported by the wgpu backends in this module.
#[derive(Clone, Debug, PartialEq)]
pub enum WgpuInitError {
    /// No compatible GPU adapter was found.
    NoAdapter,
    /// The `wgpu-backend` feature is not enabled; this is a stub build.
    NotAvailable,
    /// The device request failed (e.g. out of memory); carries the reason.
    DeviceRequestFailed(String),
    /// A required GPU feature is disabled or not supported.
    FeatureDisabled,
    /// Device creation failed; carries the underlying error text.
    DeviceRequest(String),
    /// A buffer handle is out of range; carries the offending index.
    InvalidHandle(usize),
    /// A mutex was poisoned (should not occur in practice).
    PoisonedLock,
}

impl std::fmt::Display for WgpuInitError {
    /// Write a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Variants with a payload format it inline and return early; the
        // remaining variants map to a fixed message written once at the end.
        let fixed_message = match self {
            Self::DeviceRequestFailed(reason) => {
                return write!(f, "Device request failed: {reason}");
            }
            Self::DeviceRequest(reason) => {
                return write!(f, "Device request error: {reason}");
            }
            Self::InvalidHandle(index) => {
                return write!(f, "Invalid buffer handle: {index}");
            }
            Self::NoAdapter => "No compatible GPU adapter found",
            Self::NotAvailable => "wgpu-backend feature not enabled",
            Self::FeatureDisabled => "Required GPU feature is not available",
            Self::PoisonedLock => "Internal mutex was poisoned",
        };
        f.write_str(fixed_message)
    }
}

impl std::error::Error for WgpuInitError {}
485
486// ── Tests ─────────────────────────────────────────────────────────────────────
487
#[cfg(test)]
mod tests {
    use super::*;

    /// `try_new` has no GPU path yet and must report `NotAvailable`.
    #[test]
    fn try_new_returns_not_available_in_stub_build() {
        let outcome = WgpuBackend::try_new();
        assert!(matches!(outcome, Err(WgpuInitError::NotAvailable)));
    }

    /// Data written to a stub buffer is read back unchanged.
    #[test]
    fn stub_backend_write_read_roundtrip() {
        let expected = vec![1.0_f64, 2.0, 3.0, 4.0];
        let mut backend = WgpuBackend::new_stub();
        let handle = backend.create_buffer(4);
        backend.write_buffer(handle, &expected);
        assert_eq!(backend.read_buffer(handle), expected);
    }

    /// Dispatching on the stub leaves buffer contents untouched.
    #[test]
    fn stub_dispatch_is_noop() {
        let mut backend = WgpuBackend::new_stub();
        let handle = backend.create_buffer(8);
        let snapshot = backend.read_buffer(handle);
        backend.dispatch("sph_density", &[handle], 1);
        assert_eq!(
            snapshot,
            backend.read_buffer(handle),
            "stub dispatch should not modify buffers"
        );
    }

    /// Every built-in kernel ships with some source text.
    #[test]
    fn wgsl_kernels_are_non_empty() {
        let kernels = [WGSL_PARALLEL_SCAN, WGSL_SPH_DENSITY, WGSL_BVH_TRAVERSAL];
        for source in kernels.iter() {
            assert!(!source.is_empty());
        }
    }

    /// The stub advertises a non-empty device name.
    #[test]
    fn device_info_stub_has_name() {
        let backend = WgpuBackend::new_stub();
        let name = &backend.device_info().name;
        assert!(!name.is_empty());
    }

    /// Each checked error variant renders a non-empty message.
    #[test]
    fn wgpu_init_error_display() {
        let variants = [
            WgpuInitError::NotAvailable,
            WgpuInitError::NoAdapter,
            WgpuInitError::FeatureDisabled,
            WgpuInitError::DeviceRequest("oom".into()),
            WgpuInitError::InvalidHandle(7),
            WgpuInitError::PoisonedLock,
        ];
        for variant in variants.iter() {
            assert!(!variant.to_string().is_empty());
        }
    }
}
545
546// ── Real wgpu backend (feature-gated) ─────────────────────────────────────────
547
548/// Real wgpu compute backend, enabled only with the `wgpu-backend` feature.
549///
550/// Provides GPU buffer management, WGSL shader dispatch, and CPU-side readback
551/// using `wgpu` 29's cross-platform Vulkan / Metal / DX12 backends.
552///
553/// # Thread safety
554///
555/// `wgpu::Device` and `wgpu::Queue` are `Send + Sync`.  The shader cache is
556/// protected by a `Mutex`, making `WgpuBackendReal` safe to share across
557/// threads (though individual dispatches are synchronous on the calling thread).
558///
559/// # Usage
560///
561/// ```ignore
562/// // With the wgpu-backend feature enabled:
563/// use oxiphysics_gpu::compute::wgpu_backend::real::WgpuBackendReal;
564///
565/// let mut backend = WgpuBackendReal::try_new()?;
566/// let h = backend.create_buffer_f64(128);
567/// backend.write_buffer_f64(h, &vec![1.0_f64; 128]);
568/// backend.dispatch_wgsl(
569///     WGSL_SPH_DENSITY, "sph_density",
570///     &[(h, wgpu::BufferBindingType::Storage { read_only: false })],
571///     [2, 1, 1],
572/// )?;
573/// let out = backend.read_buffer_f64(h);
574/// ```
575#[cfg(feature = "wgpu-backend")]
576pub mod real {
577    use super::{WgpuBufferHandle, WgpuDeviceInfo, WgpuInitError};
578    use std::collections::HashMap;
579    use std::hash::{DefaultHasher, Hash, Hasher};
580    use std::sync::{Arc, Mutex};
581
582    // ── Internal shader-cache entry ──────────────────────────────────────────
583
/// Value type of the pipeline cache held in `WgpuBackendReal::shader_cache`.
struct ShaderCacheEntry {
    /// Compute pipeline behind an `Arc`, so it can be handed out without
    /// copying the pipeline itself.
    pipeline: Arc<wgpu::ComputePipeline>,
}
587
588    // ── WgpuBackendReal ──────────────────────────────────────────────────────
589
/// Real GPU compute backend backed by `wgpu` 29.
///
/// Obtain an instance via [`WgpuBackendReal::try_new`] (synchronous,
/// blocks the thread) or [`WgpuBackendReal::try_new_async`] from within
/// an async context.
///
/// Buffers are addressed by `WgpuBufferHandle`, whose inner index selects
/// the matching slot in both `buffers` and `buffer_sizes`.
pub struct WgpuBackendReal {
    /// Logical GPU device used to create buffers and command encoders.
    device: Arc<wgpu::Device>,
    /// Queue that receives buffer uploads and submitted command buffers.
    queue: Arc<wgpu::Queue>,
    /// Device information (name, backend, driver).
    pub device_info: WgpuDeviceInfo,
    /// Allocated GPU buffers, indexed by `WgpuBufferHandle.0`.
    ///
    /// NOTE(review): a `None` slot presumably marks a released buffer —
    /// no code in view stores `None`; confirm against the rest of the file.
    buffers: Vec<Option<Arc<wgpu::Buffer>>>,
    /// Byte size of each buffer (parallel to `buffers`).
    buffer_sizes: Vec<u64>,
    /// Compiled pipeline cache, keyed by a hash of WGSL source + entry point.
    shader_cache: Mutex<HashMap<u64, ShaderCacheEntry>>,
}
607
608    impl WgpuBackendReal {
609        // ── Construction ─────────────────────────────────────────────────────
610
611        /// Create a real GPU backend, blocking the calling thread.
612        ///
613        /// Returns `Err` if no compatible GPU adapter is found or if device
614        /// creation fails.  Prefer [`try_new_async`](Self::try_new_async) from
615        /// within an `async` context.
616        pub fn try_new() -> Result<Self, WgpuInitError> {
617            pollster::block_on(Self::try_new_async())
618        }
619
620        /// Create a real GPU backend asynchronously.
621        ///
622        /// This is the preferred entry point from `async` contexts (tokio,
623        /// wasm-bindgen-futures, etc.).
624        pub async fn try_new_async() -> Result<Self, WgpuInitError> {
625            let instance =
626                wgpu::Instance::new(wgpu::InstanceDescriptor::new_without_display_handle());
627
628            let adapter = instance
629                .request_adapter(&wgpu::RequestAdapterOptions {
630                    power_preference: wgpu::PowerPreference::HighPerformance,
631                    compatible_surface: None,
632                    force_fallback_adapter: false,
633                })
634                .await
635                .map_err(|_| WgpuInitError::NoAdapter)?;
636
637            let info = adapter.get_info();
638
639            let desc = wgpu::DeviceDescriptor {
640                label: Some("oxiphysics-wgpu"),
641                required_features: wgpu::Features::empty(),
642                required_limits: adapter.limits(),
643                ..Default::default()
644            };
645
646            let (device, queue) = adapter
647                .request_device(&desc)
648                .await
649                .map_err(|e| WgpuInitError::DeviceRequest(e.to_string()))?;
650
651            let device_info = WgpuDeviceInfo {
652                name: info.name.clone(),
653                backend: format!("{:?}", info.backend),
654                driver_version: info.driver_info.clone(),
655                // VRAM is not exposed by wgpu's AdapterInfo; use 0 as sentinel.
656                vram_bytes: 0,
657                // GPU-native f64 requires a device extension not in the base profile.
658                supports_f64: false,
659                // Conservative defaults matching most desktop GPU limits.
660                max_workgroup_size: [256, 256, 64],
661            };
662
663            Ok(Self {
664                device: Arc::new(device),
665                queue: Arc::new(queue),
666                device_info,
667                buffers: Vec::new(),
668                buffer_sizes: Vec::new(),
669                shader_cache: Mutex::new(HashMap::new()),
670            })
671        }
672
/// Report whether a GPU device is available.
///
/// Always `true`: a value of this type can only be built by the
/// constructors after device creation has succeeded.
pub fn is_available(&self) -> bool {
    true
}
677
678        // ── Buffer management ─────────────────────────────────────────────────
679
680        /// Allocate a GPU storage buffer of `size_bytes` bytes.
681        ///
682        /// The buffer is created with `STORAGE | COPY_SRC | COPY_DST` usage
683        /// flags so that it can be used as a shader binding and for staged
684        /// CPU read/write.
685        pub fn create_buffer_storage(&mut self, size_bytes: u64) -> WgpuBufferHandle {
686            let handle = WgpuBufferHandle(self.buffers.len());
687            let buf = self.device.create_buffer(&wgpu::BufferDescriptor {
688                label: None,
689                size: size_bytes,
690                usage: wgpu::BufferUsages::STORAGE
691                    | wgpu::BufferUsages::COPY_SRC
692                    | wgpu::BufferUsages::COPY_DST,
693                mapped_at_creation: false,
694            });
695            self.buffers.push(Some(Arc::new(buf)));
696            self.buffer_sizes.push(size_bytes);
697            handle
698        }
699
700        /// Allocate a GPU buffer sized for `len` `f64` values.
701        ///
702        /// Internally the data is stored as `f32` on the GPU (8 bytes per
703        /// element to maintain the same stride).
704        pub fn create_buffer_f64(&mut self, len: usize) -> WgpuBufferHandle {
705            // We store f64 values packed as two f32s to preserve stride; or
706            // simply allocate 8 bytes per element and use the f32 path with
707            // two floats per logical element. For simplicity, the current
708            // implementation casts f64→f32 on write and f32→f64 on read, so
709            // we only need 4 bytes per element on the GPU.
710            self.create_buffer_storage((len * 4) as u64)
711        }
712
713        /// Upload `data` to the GPU buffer at `handle`, casting `f64` → `f32`.
714        ///
715        /// # Panics
716        ///
717        /// Does nothing (silently returns) if `handle` is out of range.
718        pub fn write_buffer_f64(&self, handle: WgpuBufferHandle, data: &[f64]) {
719            if let Some(Some(buf)) = self.buffers.get(handle.0) {
720                let f32_data: Vec<f32> = data.iter().map(|&v| v as f32).collect();
721                self.queue
722                    .write_buffer(buf, 0, bytemuck::cast_slice(&f32_data));
723            }
724        }
725
726        /// Download data from the GPU buffer at `handle`, casting `f32` → `f64`.
727        ///
728        /// This blocks the calling thread until the GPU has finished all
729        /// outstanding work and the readback mapping is complete.
730        ///
731        /// Returns an empty `Vec` if the handle is invalid or the readback fails.
732        pub fn read_buffer_f64(&self, handle: WgpuBufferHandle) -> Vec<f64> {
733            let buf = match self.buffers.get(handle.0).and_then(|b| b.as_ref()) {
734                Some(b) => b.clone(),
735                None => return Vec::new(),
736            };
737            let size = self.buffer_sizes[handle.0];
738
739            // Create a CPU-visible staging buffer for the readback.
740            let staging = self.device.create_buffer(&wgpu::BufferDescriptor {
741                label: Some("oxiphysics_staging_readback"),
742                size,
743                usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
744                mapped_at_creation: false,
745            });
746
747            // Record and submit the copy command.
748            let mut encoder = self
749                .device
750                .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
751            encoder.copy_buffer_to_buffer(&buf, 0, &staging, 0, size);
752            self.queue.submit(std::iter::once(encoder.finish()));
753
754            // Map the staging buffer for reading.
755            let slice = staging.slice(..);
756            let (tx, rx) = std::sync::mpsc::channel();
757            slice.map_async(wgpu::MapMode::Read, move |result| {
758                let _ = tx.send(result);
759            });
760
761            // Block until the GPU has completed and the mapping is ready.
762            if let Err(_e) = self.device.poll(wgpu::PollType::Wait {
763                submission_index: None,
764                timeout: None,
765            }) {
766                return Vec::new();
767            }
768
769            // Check that the mapping succeeded.
770            if rx.recv().ok().and_then(|r| r.ok()).is_none() {
771                return Vec::new();
772            }
773
774            let mapped = slice.get_mapped_range();
775            let f32_data: &[f32] = bytemuck::cast_slice(&mapped);
776            let result: Vec<f64> = f32_data.iter().map(|&v| v as f64).collect();
777            drop(mapped);
778            staging.unmap();
779            result
780        }
781
782        // ── Dispatch ──────────────────────────────────────────────────────────
783
784        /// Upload raw bytes to the GPU buffer at `handle`.
785        ///
786        /// The byte slice must fit within the buffer's allocated size.
787        /// Does nothing (silently returns) if `handle` is out of range.
788        pub fn queue_write_buffer_raw(&self, handle: &WgpuBufferHandle, data: &[u8]) {
789            if let Some(Some(buf)) = self.buffers.get(handle.0) {
790                self.queue.write_buffer(buf, 0, data);
791            }
792        }
793
794        /// Upload `f32` data directly to the GPU buffer at `handle` (no f64→f32 cast).
795        ///
796        /// Does nothing (silently returns) if `handle` is out of range.
797        pub fn queue_write_buffer_f32(&self, handle: &WgpuBufferHandle, data: &[f32]) {
798            if let Some(Some(buf)) = self.buffers.get(handle.0) {
799                self.queue.write_buffer(buf, 0, bytemuck::cast_slice(data));
800            }
801        }
802
803        /// Download raw `f32` values from the GPU buffer at `handle`.
804        ///
805        /// Returns an empty `Vec` if the handle is invalid or the readback fails.
806        pub fn read_buffer_f32(&self, handle: WgpuBufferHandle) -> Vec<f32> {
807            let buf = match self.buffers.get(handle.0).and_then(|b| b.as_ref()) {
808                Some(b) => b.clone(),
809                None => return Vec::new(),
810            };
811            let size = self.buffer_sizes[handle.0];
812
813            let staging = self.device.create_buffer(&wgpu::BufferDescriptor {
814                label: Some("oxiphysics_staging_readback_f32"),
815                size,
816                usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
817                mapped_at_creation: false,
818            });
819
820            let mut encoder = self
821                .device
822                .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
823            encoder.copy_buffer_to_buffer(&buf, 0, &staging, 0, size);
824            self.queue.submit(std::iter::once(encoder.finish()));
825
826            let slice = staging.slice(..);
827            let (tx, rx) = std::sync::mpsc::channel();
828            slice.map_async(wgpu::MapMode::Read, move |result| {
829                let _ = tx.send(result);
830            });
831
832            if let Err(_e) = self.device.poll(wgpu::PollType::Wait {
833                submission_index: None,
834                timeout: None,
835            }) {
836                return Vec::new();
837            }
838
839            if rx.recv().ok().and_then(|r| r.ok()).is_none() {
840                return Vec::new();
841            }
842
843            let mapped = slice.get_mapped_range();
844            let result: Vec<f32> = bytemuck::cast_slice::<u8, f32>(&mapped).to_vec();
845            drop(mapped);
846            staging.unmap();
847            result
848        }
849
850        // ── Dispatch ──────────────────────────────────────────────────────────
851
852        /// Compute the 3-D workgroup dispatch counts for `n_items` elements.
853        ///
854        /// Returns `[0, 1, 1]` for `n_items == 0` (no-op dispatch).
855        pub fn dispatch_count_for(n_items: usize, workgroup_size: u32) -> [u32; 3] {
856            crate::compute::timestamp::dispatch_count_for(n_items, workgroup_size)
857        }
858
859        /// Compile and dispatch a WGSL compute shader.
860        ///
861        /// The pipeline is compiled lazily and cached by a hash of
862        /// `(wgsl_src, entry_point)`, so repeated calls with the same shader
863        /// do not recompile.
864        ///
865        /// # Parameters
866        ///
867        /// * `wgsl_src`    — WGSL shader source code.
868        /// * `entry_point` — Name of the `@compute` entry point function.
869        /// * `buffers`     — Ordered list of `(handle, binding_type)` pairs.
870        ///   Binding index in the WGSL shader corresponds to the position in
871        ///   this slice (binding 0 = `buffers[0]`, etc.).
872        /// * `workgroups`  — `[x, y, z]` dispatch counts.
873        ///
874        /// # Errors
875        ///
876        /// Returns `Err(WgpuInitError::InvalidHandle)` if any buffer handle is
877        /// out of range.  Returns `Err(WgpuInitError::PoisonedLock)` if the
878        /// shader-cache mutex is poisoned (should not occur in practice).
879        pub fn dispatch_wgsl(
880            &self,
881            wgsl_src: &str,
882            entry_point: &str,
883            buffers: &[(WgpuBufferHandle, wgpu::BufferBindingType)],
884            workgroups: [u32; 3],
885        ) -> Result<(), WgpuInitError> {
886            // Hash the shader source + entry point to key the pipeline cache.
887            let mut hasher = DefaultHasher::new();
888            wgsl_src.hash(&mut hasher);
889            entry_point.hash(&mut hasher);
890            let key = hasher.finish();
891
892            // Obtain or compile the pipeline.
893            let pipeline: Arc<wgpu::ComputePipeline> = {
894                let mut cache = self.shader_cache.lock().unwrap_or_else(|e| e.into_inner());
895
896                if let Some(entry) = cache.get(&key) {
897                    entry.pipeline.clone()
898                } else {
899                    let module = self
900                        .device
901                        .create_shader_module(wgpu::ShaderModuleDescriptor {
902                            label: Some(entry_point),
903                            source: wgpu::ShaderSource::Wgsl(wgsl_src.into()),
904                        });
905                    let pipeline = Arc::new(self.device.create_compute_pipeline(
906                        &wgpu::ComputePipelineDescriptor {
907                            label: Some(entry_point),
908                            layout: None,
909                            module: &module,
910                            entry_point: Some(entry_point),
911                            compilation_options: wgpu::PipelineCompilationOptions::default(),
912                            cache: None,
913                        },
914                    ));
915                    cache.insert(
916                        key,
917                        ShaderCacheEntry {
918                            pipeline: pipeline.clone(),
919                        },
920                    );
921                    pipeline
922                }
923            };
924
925            // Derive the bind-group layout from the compiled pipeline.
926            let bg_layout = pipeline.get_bind_group_layout(0);
927
928            // Build the bind-group entries.
929            let mut entries: Vec<wgpu::BindGroupEntry> = Vec::with_capacity(buffers.len());
930            for (i, (handle, _binding_type)) in buffers.iter().enumerate() {
931                let buf = self
932                    .buffers
933                    .get(handle.0)
934                    .and_then(|b| b.as_ref())
935                    .ok_or(WgpuInitError::InvalidHandle(handle.0))?;
936                entries.push(wgpu::BindGroupEntry {
937                    binding: i as u32,
938                    resource: buf.as_entire_binding(),
939                });
940            }
941
942            let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
943                label: None,
944                layout: &bg_layout,
945                entries: &entries,
946            });
947
948            // Record and submit the compute pass.
949            let mut encoder = self
950                .device
951                .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
952            {
953                let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
954                    label: None,
955                    timestamp_writes: None,
956                });
957                pass.set_pipeline(&pipeline);
958                pass.set_bind_group(0, &bind_group, &[]);
959                pass.dispatch_workgroups(workgroups[0], workgroups[1], workgroups[2]);
960            }
961            self.queue.submit(std::iter::once(encoder.finish()));
962
963            // Block until the GPU has finished (synchronous dispatch).
964            self.device
965                .poll(wgpu::PollType::Wait {
966                    submission_index: None,
967                    timeout: None,
968                })
969                .map_err(|_| WgpuInitError::DeviceRequest("poll failed".into()))?;
970
971            Ok(())
972        }
973    }
974
975    // ── Feature-gated tests ───────────────────────────────────────────────────
976
977    #[cfg(test)]
978    mod tests {
979        use super::*;
980
981        /// Helper: attempt to create a real backend, returning `None` if no GPU
982        /// is available (e.g. in headless CI).
983        fn try_backend() -> Option<WgpuBackendReal> {
984            WgpuBackendReal::try_new().ok()
985        }
986
987        #[test]
988        fn real_backend_try_new_succeeds_or_gracefully_fails() {
989            // This test always passes: it either succeeds (GPU present) or
990            // returns None (headless / CI environment).
991            match WgpuBackendReal::try_new() {
992                Ok(b) => {
993                    assert!(b.is_available());
994                    assert!(!b.device_info.backend.is_empty());
995                }
996                Err(e) => {
997                    // NoAdapter is the expected error in headless CI.
998                    eprintln!("No GPU adapter available: {e}");
999                }
1000            }
1001        }
1002
1003        #[test]
1004        fn real_backend_create_and_write_buffer() {
1005            let Some(mut backend) = try_backend() else {
1006                return;
1007            };
1008            let data = vec![1.0_f64, 2.0, 3.0, 4.0];
1009            let handle = backend.create_buffer_f64(data.len());
1010            backend.write_buffer_f64(handle, &data);
1011            // write_buffer_f64 is fire-and-forget; we just verify no panic.
1012            assert!(handle.0 < backend.buffers.len());
1013        }
1014
1015        #[test]
1016        fn real_backend_buffer_roundtrip() {
1017            let Some(mut backend) = try_backend() else {
1018                return;
1019            };
1020            let data = vec![1.0_f64, 2.0, 3.0, 4.0];
1021            let handle = backend.create_buffer_f64(data.len());
1022            backend.write_buffer_f64(handle, &data);
1023            let out = backend.read_buffer_f64(handle);
1024            // f64→f32→f64 loses precision; check within f32 rounding.
1025            assert_eq!(out.len(), data.len());
1026            for (&expected, &got) in data.iter().zip(out.iter()) {
1027                assert!(
1028                    (expected as f32 - got as f32).abs() < 1e-5,
1029                    "roundtrip mismatch: expected {expected}, got {got}"
1030                );
1031            }
1032        }
1033
1034        #[test]
1035        fn real_backend_dispatch_scale_shader() {
1036            let Some(mut backend) = try_backend() else {
1037                return;
1038            };
1039            use super::super::WgpuBackend;
1040
1041            // A simple WGSL shader that multiplies each f32 element by 2.
1042            const SCALE_BY_TWO: &str = r#"
1043@group(0) @binding(0) var<storage, read>       input_buf:  array<f32>;
1044@group(0) @binding(1) var<storage, read_write> output_buf: array<f32>;
1045
1046@compute @workgroup_size(64)
1047fn scale_by_two(@builtin(global_invocation_id) gid: vec3<u32>) {
1048    let i = gid.x;
1049    if i < arrayLength(&input_buf) {
1050        output_buf[i] = input_buf[i] * 2.0;
1051    }
1052}
1053"#;
1054            let n: usize = 4;
1055            let input_data: Vec<f32> = (1..=n as u32).map(|x| x as f32).collect();
1056            let in_handle = backend.create_buffer_storage((n * 4) as u64);
1057            let out_handle = backend.create_buffer_storage((n * 4) as u64);
1058
1059            backend.queue.write_buffer(
1060                backend.buffers[in_handle.0].as_ref().unwrap(),
1061                0,
1062                bytemuck::cast_slice(&input_data),
1063            );
1064
1065            // Dispatch: 1 workgroup of 64 threads covers n=4 elements.
1066            let result = backend.dispatch_wgsl(
1067                SCALE_BY_TWO,
1068                "scale_by_two",
1069                &[
1070                    (
1071                        in_handle,
1072                        wgpu::BufferBindingType::Storage { read_only: true },
1073                    ),
1074                    (
1075                        out_handle,
1076                        wgpu::BufferBindingType::Storage { read_only: false },
1077                    ),
1078                ],
1079                [1, 1, 1],
1080            );
1081            assert!(result.is_ok(), "dispatch_wgsl failed: {:?}", result.err());
1082
1083            // Readback via staging and verify.
1084            let out = backend.read_buffer_f64(out_handle);
1085            assert_eq!(out.len(), n);
1086            for (i, &v) in out.iter().enumerate() {
1087                let expected = (i + 1) as f64 * 2.0;
1088                assert!(
1089                    (v - expected).abs() < 0.01,
1090                    "element {i}: expected {expected}, got {v}"
1091                );
1092            }
1093
1094            // Regression guard: stub backend still works.
1095            let mut stub = WgpuBackend::new_stub();
1096            let h = stub.create_buffer(4);
1097            let _ = stub.read_buffer(h);
1098        }
1099
1100        #[test]
1101        fn dispatch_count_for_zero_items() {
1102            assert_eq!(WgpuBackendReal::dispatch_count_for(0, 64), [0, 1, 1]);
1103        }
1104
1105        #[test]
1106        fn dispatch_count_for_65_items() {
1107            assert_eq!(WgpuBackendReal::dispatch_count_for(65, 64), [2, 1, 1]);
1108        }
1109
1110        #[test]
1111        fn dispatch_count_for_exact_workgroup() {
1112            assert_eq!(WgpuBackendReal::dispatch_count_for(256, 64), [4, 1, 1]);
1113        }
1114    }
1115}