#![allow(dead_code)]
/// Lightweight, copyable handle that identifies a buffer owned by a backend.
///
/// The wrapped `usize` is the index of the buffer in the backend's internal
/// buffer list (handles are minted from `buffers.len()` at creation time), so
/// a handle is only meaningful for the backend instance that created it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct WgpuBufferHandle(pub usize);
/// Descriptive information about the compute device behind a backend.
///
/// `Default` yields empty strings, zeros and `false`; the stub backend relies
/// on that for every field it does not set explicitly.
#[derive(Debug, Clone, Default)]
pub struct WgpuDeviceInfo {
/// Human-readable device name (the stub backend reports `"CPU stub"`).
pub name: String,
/// Name of the graphics/compute backend in use (the stub reports `"None"`).
pub backend: String,
/// Driver description string as reported by the adapter.
pub driver_version: String,
/// Video memory size in bytes; `0` when it is not known.
pub vram_bytes: u64,
/// Whether shaders on this device may use 64-bit floats.
pub supports_f64: bool,
/// Maximum compute workgroup size along x, y and z.
pub max_workgroup_size: [u32; 3],
}
/// CPU-only stand-in for a GPU backend.
///
/// Buffers are plain host-side `Vec<f64>` shadows; no GPU resources are ever
/// created by this type.
#[derive(Debug)]
pub struct WgpuBackend {
/// Description of the (stub) device.
pub device_info: WgpuDeviceInfo,
// One entry per created buffer; a handle's inner value indexes this list.
buffers: Vec<WgpuBufferEntry>,
// Value returned by `is_available`; `new_stub` sets it to `false`.
available: bool,
}
/// Host-side storage backing one stub buffer.
#[derive(Debug, Clone)]
struct WgpuBufferEntry {
// Number of `f64` elements requested when the buffer was created.
capacity: usize,
// Buffer contents; allocated zero-filled with exactly `capacity` elements.
shadow: Vec<f64>,
}
impl WgpuBackend {
pub fn try_new() -> Result<Self, WgpuInitError> {
Err(WgpuInitError::NotAvailable)
}
pub fn new_stub() -> Self {
Self {
device_info: WgpuDeviceInfo {
name: "CPU stub".to_string(),
backend: "None".to_string(),
..Default::default()
},
buffers: Vec::new(),
available: false,
}
}
pub fn is_available(&self) -> bool {
self.available
}
pub fn device_info(&self) -> &WgpuDeviceInfo {
&self.device_info
}
pub fn create_buffer(&mut self, len: usize) -> WgpuBufferHandle {
let handle = WgpuBufferHandle(self.buffers.len());
self.buffers.push(WgpuBufferEntry {
capacity: len,
shadow: vec![0.0; len],
});
handle
}
pub fn write_buffer(&mut self, handle: WgpuBufferHandle, data: &[f64]) {
if let Some(entry) = self.buffers.get_mut(handle.0) {
let len = data.len().min(entry.capacity);
entry.shadow[..len].copy_from_slice(&data[..len]);
}
}
pub fn read_buffer(&self, handle: WgpuBufferHandle) -> Vec<f64> {
self.buffers
.get(handle.0)
.map(|e| e.shadow.clone())
.unwrap_or_default()
}
pub fn dispatch(
&mut self,
kernel_name: &str,
buffers: &[WgpuBufferHandle],
work_groups_x: u32,
) {
let _ = (kernel_name, buffers, work_groups_x);
}
pub fn register_shader(&mut self, name: &str, wgsl_source: &str) {
let _ = (name, wgsl_source);
}
}
/// WGSL source of an exclusive prefix-sum kernel (entry point `exclusive_scan`).
///
/// Each 256-thread workgroup loads one tile of `input` into workgroup memory,
/// runs a Blelloch up-sweep / down-sweep over that tile and stores the
/// exclusive scan of the tile into `output`. Elements at indices `>= n` are
/// treated as `0.0` on load and are not written back. Tiles are scanned
/// independently; the kernel does not combine results across workgroups.
///
/// The identifiers `scratch` and `pass_index` are deliberately not spelled
/// `shared` / `pass`: both of those are reserved words in WGSL and are
/// rejected as identifiers. Uniform buffers are matched by layout, not by
/// member name, so host code is unaffected by the rename.
pub const WGSL_PARALLEL_SCAN: &str = r#"
// Exclusive parallel prefix sum (Blelloch up-sweep / down-sweep)
// Workgroup size: 256 threads (each workgroup scans its own 256-element tile)
// Binding 0: input buffer (read)
// Binding 1: output buffer (write)
// Binding 2: uniform { n: u32, pass_index: u32 }
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
// pass_index is part of the uniform layout but is not read by this kernel.
struct Params { n: u32, pass_index: u32 }
@group(0) @binding(2) var<uniform> params: Params;
var<workgroup> scratch: array<f32, 256>;

@compute @workgroup_size(256)
fn exclusive_scan(@builtin(global_invocation_id) gid: vec3<u32>,
                  @builtin(local_invocation_id) lid: vec3<u32>) {
    let n = params.n;
    let i = gid.x;
    // Load (pad the tail of the last tile with zeros)
    scratch[lid.x] = select(0.0, input[i], i < n);
    workgroupBarrier();
    // Up-sweep (reduce)
    var stride: u32 = 1u;
    loop {
        if stride >= 256u { break; }
        if lid.x % (stride * 2u) == (stride * 2u - 1u) {
            scratch[lid.x] += scratch[lid.x - stride];
        }
        workgroupBarrier();
        stride = stride * 2u;
    }
    // Down-sweep
    if lid.x == 255u { scratch[255] = 0.0; }
    workgroupBarrier();
    stride = 128u;
    loop {
        if stride == 0u { break; }
        if lid.x % (stride * 2u) == (stride * 2u - 1u) {
            let tmp = scratch[lid.x - stride];
            scratch[lid.x - stride] = scratch[lid.x];
            scratch[lid.x] += tmp;
        }
        workgroupBarrier();
        stride = stride / 2u;
    }
    // Store
    if i < n { output[i] = scratch[lid.x]; }
}
"#;
/// WGSL source of a brute-force SPH density kernel (entry point `sph_density`).
///
/// One invocation handles one particle `i` (workgroup size 64) and sums
/// `mass * w_spline3(|x_i - x_j|, h)` over every particle `j` in `0..n`,
/// including `j == i`, so the cost per dispatch is quadratic in `n`.
/// `w_spline3` is a piecewise-cubic kernel that evaluates to zero for
/// `r / h >= 2`.
///
/// Positions are read from a flat `f32` array packed as `x, y, z` per
/// particle; one density value per particle is written to binding 1.
pub const WGSL_SPH_DENSITY: &str = r#"
// SPH density kernel — W_spline3 smoothing
// Binding 0: positions array (x0,y0,z0, x1,y1,z1, ...)
// Binding 1: densities output (one per particle)
// Binding 2: uniform { n: u32, h: f32, mass: f32 }
struct SphParams { n: u32, h: f32, mass: f32 }
@group(0) @binding(0) var<storage, read> positions: array<f32>;
@group(0) @binding(1) var<storage, read_write> densities: array<f32>;
@group(0) @binding(2) var<uniform> params: SphParams;
fn w_spline3(r: f32, h: f32) -> f32 {
let q = r / h;
let sigma = 3.0 / (2.0 * 3.14159265358979 * h * h * h);
if q < 1.0 {
return sigma * (2.0/3.0 - q*q + 0.5*q*q*q);
} else if q < 2.0 {
let t = 2.0 - q;
return sigma * (1.0/6.0) * t*t*t;
} else {
return 0.0;
}
}
@compute @workgroup_size(64)
fn sph_density(@builtin(global_invocation_id) gid: vec3<u32>) {
let i = gid.x;
let n = params.n;
if i >= n { return; }
let xi = vec3<f32>(positions[i*3u], positions[i*3u+1u], positions[i*3u+2u]);
var density: f32 = 0.0;
for (var j: u32 = 0u; j < n; j++) {
let xj = vec3<f32>(positions[j*3u], positions[j*3u+1u], positions[j*3u+2u]);
let r = length(xi - xj);
density += params.mass * w_spline3(r, params.h);
}
densities[i] = density;
}
"#;
/// WGSL source of a BVH ray-traversal stub (entry point `bvh_traverse`).
///
/// One invocation handles one ray (workgroup size 64). Traversal starts at
/// node index 0 and walks the tree iteratively with a fixed 32-entry stack;
/// children that do not fit on the stack are dropped. A leaf counts as a hit
/// when the ray intersects the leaf's bounding box, and the closest such
/// leaf is written to `results`. No primitive intersection test is done.
///
/// `ray_aabb` returns the entry distance clamped to zero, so a ray whose
/// origin lies inside a box hits it at `t = 0` instead of being reported as
/// a miss; `-1.0` is returned only for a genuine miss or when the box lies
/// beyond `t_max`.
pub const WGSL_BVH_TRAVERSAL: &str = r#"
// Parallel BVH ray traversal stub
// Each thread handles one ray; BVH nodes are in binding 0.
struct Ray { origin: vec3<f32>, dir: vec3<f32>, t_max: f32 }
struct BvhNode { lo: vec3<f32>, hi: vec3<f32>, left: u32, right: u32, is_leaf: u32, prim: u32 }
struct HitResult { hit: u32, t: f32, prim: u32 }
@group(0) @binding(0) var<storage, read> nodes: array<BvhNode>;
@group(0) @binding(1) var<storage, read> rays: array<Ray>;
@group(0) @binding(2) var<storage, read_write> results: array<HitResult>;
@group(0) @binding(3) var<uniform> num_rays: u32;

// Slab test. Returns the distance at which the ray enters the box (0 when the
// origin is already inside it), or -1.0 when the box is missed, lies entirely
// behind the origin, or starts beyond ray.t_max.
fn ray_aabb(ray: Ray, lo: vec3<f32>, hi: vec3<f32>) -> f32 {
    let inv_dir = 1.0 / ray.dir;
    let t0 = (lo - ray.origin) * inv_dir;
    let t1 = (hi - ray.origin) * inv_dir;
    let t_min = max(max(min(t0.x, t1.x), min(t0.y, t1.y)), min(t0.z, t1.z));
    let t_max_box = min(min(max(t0.x, t1.x), max(t0.y, t1.y)), max(t0.z, t1.z));
    // Clamp so that an origin inside the box is not mistaken for a miss.
    let t_enter = max(t_min, 0.0);
    if t_max_box < t_enter || t_enter > ray.t_max { return -1.0; }
    return t_enter;
}

@compute @workgroup_size(64)
fn bvh_traverse(@builtin(global_invocation_id) gid: vec3<u32>) {
    let rid = gid.x;
    if rid >= num_rays { return; }
    let ray = rays[rid];
    results[rid] = HitResult(0u, ray.t_max, 0xFFFFFFFFu);
    // Iterative DFS stack (32 entries, valid indices 0..31)
    var stack: array<u32, 32>;
    var sp: i32 = 0;
    stack[0] = 0u;
    loop {
        if sp < 0 { break; }
        let node_idx = stack[sp]; sp--;
        let node = nodes[node_idx];
        let t = ray_aabb(ray, node.lo, node.hi);
        if t < 0.0 { continue; }
        if node.is_leaf != 0u {
            if t < results[rid].t {
                results[rid] = HitResult(1u, t, node.prim);
            }
        } else {
            // sp is the index of the current top; 31 is the last usable slot.
            if sp < 31 { sp++; stack[sp] = node.left; }
            if sp < 31 { sp++; stack[sp] = node.right; }
        }
    }
}
"#;
/// Errors reported by the wgpu backends in this module.
#[derive(Debug, Clone, PartialEq)]
pub enum WgpuInitError {
    /// No compatible GPU adapter could be found.
    NoAdapter,
    /// The `wgpu-backend` feature is not enabled in this build.
    NotAvailable,
    /// Requesting a device failed; carries the reason text.
    DeviceRequestFailed(String),
    /// A required GPU feature is not available.
    FeatureDisabled,
    /// A device request (or a later device operation) reported an error;
    /// carries the error text.
    DeviceRequest(String),
    /// A buffer handle did not refer to a live buffer; carries the raw index.
    InvalidHandle(usize),
    /// An internal mutex was poisoned.
    PoisonedLock,
}

impl std::fmt::Display for WgpuInitError {
    /// Writes the human-readable message for the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Fixed messages.
            Self::NoAdapter => f.write_str("No compatible GPU adapter found"),
            Self::NotAvailable => f.write_str("wgpu-backend feature not enabled"),
            Self::FeatureDisabled => f.write_str("Required GPU feature is not available"),
            Self::PoisonedLock => f.write_str("Internal mutex was poisoned"),
            // Messages that embed a payload.
            Self::DeviceRequestFailed(reason) => write!(f, "Device request failed: {reason}"),
            Self::DeviceRequest(reason) => write!(f, "Device request error: {reason}"),
            Self::InvalidHandle(index) => write!(f, "Invalid buffer handle: {index}"),
        }
    }
}

impl std::error::Error for WgpuInitError {}
#[cfg(test)]
mod tests {
    use super::*;

    /// The stub build never yields a GPU backend.
    #[test]
    fn try_new_returns_not_available_in_stub_build() {
        let outcome = WgpuBackend::try_new();
        assert!(matches!(outcome, Err(WgpuInitError::NotAvailable)));
    }

    /// Data written to a stub buffer is read back unchanged.
    #[test]
    fn stub_backend_write_read_roundtrip() {
        let payload = vec![1.0_f64, 2.0, 3.0, 4.0];
        let mut stub = WgpuBackend::new_stub();
        let handle = stub.create_buffer(4);
        stub.write_buffer(handle, &payload);
        assert_eq!(stub.read_buffer(handle), payload);
    }

    /// Dispatching on the stub must leave buffer contents alone.
    #[test]
    fn stub_dispatch_is_noop() {
        let mut stub = WgpuBackend::new_stub();
        let handle = stub.create_buffer(8);
        let before = stub.read_buffer(handle);
        stub.dispatch("sph_density", &[handle], 1);
        let after = stub.read_buffer(handle);
        assert_eq!(before, after, "stub dispatch should not modify buffers");
    }

    /// Every embedded WGSL kernel carries some source text.
    #[test]
    fn wgsl_kernels_are_non_empty() {
        let kernels = [WGSL_PARALLEL_SCAN, WGSL_SPH_DENSITY, WGSL_BVH_TRAVERSAL];
        for source in kernels {
            assert!(!source.is_empty());
        }
    }

    /// The stub device reports a non-empty name.
    #[test]
    fn device_info_stub_has_name() {
        let stub = WgpuBackend::new_stub();
        let info = stub.device_info();
        assert!(!info.name.is_empty());
    }

    /// Each checked error variant renders a non-empty message.
    #[test]
    fn wgpu_init_error_display() {
        let samples = [
            WgpuInitError::NotAvailable,
            WgpuInitError::NoAdapter,
            WgpuInitError::FeatureDisabled,
            WgpuInitError::DeviceRequest("oom".into()),
            WgpuInitError::InvalidHandle(7),
            WgpuInitError::PoisonedLock,
        ];
        for err in &samples {
            assert!(!err.to_string().is_empty());
        }
    }
}
#[cfg(feature = "wgpu-backend")]
pub mod real {
use super::{WgpuBufferHandle, WgpuDeviceInfo, WgpuInitError};
use std::collections::HashMap;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::{Arc, Mutex};
/// Value stored in the shader cache: a compiled compute pipeline that can be
/// shared between dispatches.
struct ShaderCacheEntry {
// Reference-counted so a dispatch can keep using it after the cache lock is released.
pipeline: Arc<wgpu::ComputePipeline>,
}
/// GPU backend built on a real wgpu device and queue.
pub struct WgpuBackendReal {
// Logical device used to create buffers, shader modules and pipelines.
device: Arc<wgpu::Device>,
// Queue used for buffer uploads and command submission.
queue: Arc<wgpu::Queue>,
/// Description of the adapter the device was created from.
pub device_info: WgpuDeviceInfo,
// GPU buffers indexed by `WgpuBufferHandle.0`.
buffers: Vec<Option<Arc<wgpu::Buffer>>>,
// Byte size of each buffer, kept index-aligned with `buffers`.
buffer_sizes: Vec<u64>,
// Compiled pipelines keyed by a hash of (WGSL source, entry point).
shader_cache: Mutex<HashMap<u64, ShaderCacheEntry>>,
}
impl WgpuBackendReal {
/// Blocking constructor: drives [`Self::try_new_async`] to completion on the
/// calling thread.
///
/// # Errors
///
/// Returns whatever error `try_new_async` produces.
pub fn try_new() -> Result<Self, WgpuInitError> {
    let init = Self::try_new_async();
    pollster::block_on(init)
}
/// Asynchronously creates a backend on a high-performance adapter.
///
/// The device is requested with no optional features and with the adapter's
/// own limits as the required limits. `device_info.max_workgroup_size` is
/// taken from those limits rather than assumed, so it reflects what the
/// created device actually supports. `supports_f64` stays `false` because no
/// optional feature is requested, and `vram_bytes` is reported as `0`
/// (unknown).
///
/// # Errors
///
/// * [`WgpuInitError::NoAdapter`] when no adapter can be obtained.
/// * [`WgpuInitError::DeviceRequest`] with the error text when the device
///   request fails.
pub async fn try_new_async() -> Result<Self, WgpuInitError> {
    let instance = wgpu::Instance::new(wgpu::InstanceDescriptor::new_without_display_handle());
    let adapter = instance
        .request_adapter(&wgpu::RequestAdapterOptions {
            power_preference: wgpu::PowerPreference::HighPerformance,
            compatible_surface: None,
            force_fallback_adapter: false,
        })
        .await
        .map_err(|_| WgpuInitError::NoAdapter)?;
    let info = adapter.get_info();

    // Read the workgroup limits before the limits value is moved into the
    // device descriptor.
    let limits = adapter.limits();
    let max_workgroup_size = [
        limits.max_compute_workgroup_size_x,
        limits.max_compute_workgroup_size_y,
        limits.max_compute_workgroup_size_z,
    ];

    let desc = wgpu::DeviceDescriptor {
        label: Some("oxiphysics-wgpu"),
        required_features: wgpu::Features::empty(),
        required_limits: limits,
        ..Default::default()
    };
    let (device, queue) = adapter
        .request_device(&desc)
        .await
        .map_err(|e| WgpuInitError::DeviceRequest(e.to_string()))?;

    let device_info = WgpuDeviceInfo {
        name: info.name.clone(),
        backend: format!("{:?}", info.backend),
        driver_version: info.driver_info.clone(),
        vram_bytes: 0,
        supports_f64: false,
        max_workgroup_size,
    };
    Ok(Self {
        device: Arc::new(device),
        queue: Arc::new(queue),
        device_info,
        buffers: Vec::new(),
        buffer_sizes: Vec::new(),
        shader_cache: Mutex::new(HashMap::new()),
    })
}
/// Always returns `true`: a value of this type only exists once a device and
/// queue have been created successfully.
pub fn is_available(&self) -> bool {
true
}
/// Creates an unlabeled GPU storage buffer of `size_bytes` bytes and returns
/// its handle.
///
/// The buffer is usable as shader storage and as both source and destination
/// of copies. The handle wraps the index of the new buffer in the internal
/// list; the byte size is recorded alongside it.
pub fn create_buffer_storage(&mut self, size_bytes: u64) -> WgpuBufferHandle {
    let usage = wgpu::BufferUsages::STORAGE
        | wgpu::BufferUsages::COPY_SRC
        | wgpu::BufferUsages::COPY_DST;
    let descriptor = wgpu::BufferDescriptor {
        label: None,
        size: size_bytes,
        usage,
        mapped_at_creation: false,
    };

    // The handle is the slot the buffer is about to occupy.
    let slot = self.buffers.len();
    let buffer = Arc::new(self.device.create_buffer(&descriptor));
    self.buffers.push(Some(buffer));
    self.buffer_sizes.push(size_bytes);
    WgpuBufferHandle(slot)
}
/// Creates a storage buffer able to hold `len` values.
///
/// Despite the name, values are stored on the GPU as `f32` (4 bytes each);
/// the `*_f64` accessors convert on upload and readback.
///
/// The byte size is computed in `u64` with a saturating multiply, so a huge
/// `len` cannot wrap around to a small allocation on targets where `usize`
/// is 32 bits wide.
pub fn create_buffer_f64(&mut self, len: usize) -> WgpuBufferHandle {
    let elem_bytes = std::mem::size_of::<f32>() as u64;
    let size_bytes = (len as u64).saturating_mul(elem_bytes);
    self.create_buffer_storage(size_bytes)
}
/// Uploads `data` to the start of the buffer, narrowing each value to `f32`.
///
/// Like the stub backend's `write_buffer`, elements that do not fit into the
/// buffer are dropped instead of being handed to the queue, where an
/// overrunning write is a validation error. Unknown or empty handles, and
/// writes of zero elements, are ignored.
pub fn write_buffer_f64(&self, handle: WgpuBufferHandle, data: &[f64]) {
    let Some(Some(buf)) = self.buffers.get(handle.0) else {
        return;
    };
    let size_bytes = self.buffer_sizes.get(handle.0).copied().unwrap_or(0);
    let elem_bytes = std::mem::size_of::<f32>() as u64;
    // Bounded by `data.len()`, so the cast back to `usize` cannot truncate.
    let count = (size_bytes / elem_bytes).min(data.len() as u64) as usize;
    if count == 0 {
        return;
    }
    let f32_data: Vec<f32> = data[..count].iter().map(|&v| v as f32).collect();
    self.queue
        .write_buffer(buf, 0, bytemuck::cast_slice(&f32_data));
}
pub fn read_buffer_f64(&self, handle: WgpuBufferHandle) -> Vec<f64> {
let buf = match self.buffers.get(handle.0).and_then(|b| b.as_ref()) {
Some(b) => b.clone(),
None => return Vec::new(),
};
let size = self.buffer_sizes[handle.0];
let staging = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("oxiphysics_staging_readback"),
size,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let mut encoder = self
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
encoder.copy_buffer_to_buffer(&buf, 0, &staging, 0, size);
self.queue.submit(std::iter::once(encoder.finish()));
let slice = staging.slice(..);
let (tx, rx) = std::sync::mpsc::channel();
slice.map_async(wgpu::MapMode::Read, move |result| {
let _ = tx.send(result);
});
if let Err(_e) = self.device.poll(wgpu::PollType::Wait {
submission_index: None,
timeout: None,
}) {
return Vec::new();
}
if rx.recv().ok().and_then(|r| r.ok()).is_none() {
return Vec::new();
}
let mapped = slice.get_mapped_range();
let f32_data: &[f32] = bytemuck::cast_slice(&mapped);
let result: Vec<f64> = f32_data.iter().map(|&v| v as f64).collect();
drop(mapped);
staging.unmap();
result
}
/// Queues a raw byte upload to the start of the buffer.
///
/// Bytes beyond the buffer's recorded size are dropped so the write can never
/// overrun the buffer; unknown handles and empty writes are ignored.
///
/// NOTE(review): wgpu additionally expects the written length to be a
/// multiple of its copy alignment; callers should pass suitably sized
/// data. Confirm against the wgpu version in use.
pub fn queue_write_buffer_raw(&self, handle: &WgpuBufferHandle, data: &[u8]) {
    let Some(Some(buf)) = self.buffers.get(handle.0) else {
        return;
    };
    let size_bytes = self.buffer_sizes.get(handle.0).copied().unwrap_or(0);
    // Bounded by `data.len()`, so the cast back to `usize` cannot truncate.
    let count = size_bytes.min(data.len() as u64) as usize;
    if count == 0 {
        return;
    }
    self.queue.write_buffer(buf, 0, &data[..count]);
}
/// Queues an upload of `f32` values to the start of the buffer.
///
/// Elements that do not fit into the buffer are dropped so the write can
/// never overrun it; unknown handles and empty writes are ignored.
pub fn queue_write_buffer_f32(&self, handle: &WgpuBufferHandle, data: &[f32]) {
    let Some(Some(buf)) = self.buffers.get(handle.0) else {
        return;
    };
    let size_bytes = self.buffer_sizes.get(handle.0).copied().unwrap_or(0);
    let elem_bytes = std::mem::size_of::<f32>() as u64;
    // Bounded by `data.len()`, so the cast back to `usize` cannot truncate.
    let count = (size_bytes / elem_bytes).min(data.len() as u64) as usize;
    if count == 0 {
        return;
    }
    self.queue
        .write_buffer(buf, 0, bytemuck::cast_slice(&data[..count]));
}
/// Reads the whole buffer back from the GPU as `f32` values.
///
/// The contents are copied into a temporary mappable staging buffer, the
/// device is polled until the work completes, and the mapped bytes are
/// reinterpreted as `f32`.
///
/// Returns an empty vector when the handle is unknown, when the buffer has
/// zero size (an empty buffer cannot be sliced for mapping), or when polling
/// or mapping fails.
pub fn read_buffer_f32(&self, handle: WgpuBufferHandle) -> Vec<f32> {
    let buf = match self.buffers.get(handle.0).and_then(|b| b.as_ref()) {
        Some(b) => b.clone(),
        None => return Vec::new(),
    };
    let size = self.buffer_sizes.get(handle.0).copied().unwrap_or(0);
    if size == 0 {
        // Nothing to read, and slicing an empty staging buffer would panic.
        return Vec::new();
    }

    let staging = self.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("oxiphysics_staging_readback_f32"),
        size,
        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });
    let mut encoder = self
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
    encoder.copy_buffer_to_buffer(&buf, 0, &staging, 0, size);
    self.queue.submit(std::iter::once(encoder.finish()));

    // Request the mapping; the callback forwards the result through a channel.
    let slice = staging.slice(..);
    let (tx, rx) = std::sync::mpsc::channel();
    slice.map_async(wgpu::MapMode::Read, move |result| {
        let _ = tx.send(result);
    });
    if let Err(_e) = self.device.poll(wgpu::PollType::Wait {
        submission_index: None,
        timeout: None,
    }) {
        return Vec::new();
    }
    if rx.recv().ok().and_then(|r| r.ok()).is_none() {
        return Vec::new();
    }

    let mapped = slice.get_mapped_range();
    let result: Vec<f32> = bytemuck::cast_slice::<u8, f32>(&mapped).to_vec();
    // The mapped view must be released before the buffer is unmapped.
    drop(mapped);
    staging.unmap();
    result
}
/// Convenience wrapper that forwards to
/// `crate::compute::timestamp::dispatch_count_for` and returns its
/// `[x, y, z]` workgroup counts for `n_items` items at `workgroup_size`.
///
/// The tests in this module expect `[0, 1, 1]` for 0 items, `[2, 1, 1]` for
/// 65 items and `[4, 1, 1]` for 256 items, each at workgroup size 64.
pub fn dispatch_count_for(n_items: usize, workgroup_size: u32) -> [u32; 3] {
crate::compute::timestamp::dispatch_count_for(n_items, workgroup_size)
}
/// Compiles (or fetches from the cache) a WGSL compute pipeline, binds the
/// given buffers to group 0 and runs one dispatch, waiting for completion.
///
/// * `wgsl_src` / `entry_point`: shader source and entry function; a hash
///   of the pair is the cache key.
/// * `buffers`: buffers to bind. The position in the slice becomes the
///   binding index. The `BufferBindingType` in each tuple is currently
///   ignored.
/// * `workgroups`: workgroup counts along x, y and z.
///
/// # Errors
///
/// * [`WgpuInitError::InvalidHandle`] for the first handle that does not refer
///   to a live buffer.
/// * [`WgpuInitError::DeviceRequest`] with the text `"poll failed"` when
///   waiting on the device fails.
pub fn dispatch_wgsl(
    &self,
    wgsl_src: &str,
    entry_point: &str,
    buffers: &[(WgpuBufferHandle, wgpu::BufferBindingType)],
    workgroups: [u32; 3],
) -> Result<(), WgpuInitError> {
    // Cache key: hash of the source text followed by the entry point name.
    let key = {
        let mut state = DefaultHasher::new();
        wgsl_src.hash(&mut state);
        entry_point.hash(&mut state);
        state.finish()
    };

    let pipeline: Arc<wgpu::ComputePipeline> = {
        // A poisoned lock is recovered rather than treated as an error.
        let mut cache = match self.shader_cache.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        match cache.get(&key) {
            Some(cached) => Arc::clone(&cached.pipeline),
            None => {
                let module = self
                    .device
                    .create_shader_module(wgpu::ShaderModuleDescriptor {
                        label: Some(entry_point),
                        source: wgpu::ShaderSource::Wgsl(wgsl_src.into()),
                    });
                let descriptor = wgpu::ComputePipelineDescriptor {
                    label: Some(entry_point),
                    layout: None,
                    module: &module,
                    entry_point: Some(entry_point),
                    compilation_options: wgpu::PipelineCompilationOptions::default(),
                    cache: None,
                };
                let compiled = Arc::new(self.device.create_compute_pipeline(&descriptor));
                cache.insert(
                    key,
                    ShaderCacheEntry {
                        pipeline: Arc::clone(&compiled),
                    },
                );
                compiled
            }
        }
    };

    // The layout is the one derived for group 0 of the pipeline.
    let bg_layout = pipeline.get_bind_group_layout(0);
    let entries = buffers
        .iter()
        .enumerate()
        .map(|(slot, (handle, _binding_type))| {
            self.buffers
                .get(handle.0)
                .and_then(|b| b.as_ref())
                .map(|buf| wgpu::BindGroupEntry {
                    binding: slot as u32,
                    resource: buf.as_entire_binding(),
                })
                .ok_or(WgpuInitError::InvalidHandle(handle.0))
        })
        .collect::<Result<Vec<wgpu::BindGroupEntry>, WgpuInitError>>()?;
    let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: None,
        layout: &bg_layout,
        entries: &entries,
    });

    let mut encoder = self
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
    let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: None,
        timestamp_writes: None,
    });
    pass.set_pipeline(&pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    let [groups_x, groups_y, groups_z] = workgroups;
    pass.dispatch_workgroups(groups_x, groups_y, groups_z);
    // The pass must end before the encoder can be finished.
    drop(pass);
    self.queue.submit(std::iter::once(encoder.finish()));

    let waited = self.device.poll(wgpu::PollType::Wait {
        submission_index: None,
        timeout: None,
    });
    match waited {
        Ok(_) => Ok(()),
        Err(_) => Err(WgpuInitError::DeviceRequest("poll failed".into())),
    }
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a live backend, or `None` when one cannot be created, so that
    /// GPU tests can return early on machines without a usable adapter.
    fn try_backend() -> Option<WgpuBackendReal> {
        match WgpuBackendReal::try_new() {
            Ok(backend) => Some(backend),
            Err(_) => None,
        }
    }

    /// Construction either succeeds with sane metadata or fails with an error
    /// that is only logged.
    #[test]
    fn real_backend_try_new_succeeds_or_gracefully_fails() {
        let outcome = WgpuBackendReal::try_new();
        if let Err(e) = &outcome {
            eprintln!("No GPU adapter available: {e}");
        }
        if let Ok(b) = outcome {
            assert!(b.is_available());
            assert!(!b.device_info.backend.is_empty());
        }
    }

    /// Creating and writing a buffer yields a handle inside the buffer list.
    #[test]
    fn real_backend_create_and_write_buffer() {
        let Some(mut gpu) = try_backend() else {
            return;
        };
        let payload = vec![1.0_f64, 2.0, 3.0, 4.0];
        let handle = gpu.create_buffer_f64(payload.len());
        gpu.write_buffer_f64(handle, &payload);
        assert!(handle.0 < gpu.buffers.len());
    }

    /// Values survive an upload/readback cycle to `f32` precision.
    #[test]
    fn real_backend_buffer_roundtrip() {
        let Some(mut gpu) = try_backend() else {
            return;
        };
        let data = vec![1.0_f64, 2.0, 3.0, 4.0];
        let handle = gpu.create_buffer_f64(data.len());
        gpu.write_buffer_f64(handle, &data);
        let out = gpu.read_buffer_f64(handle);
        assert_eq!(out.len(), data.len());
        for (&expected, &got) in data.iter().zip(out.iter()) {
            let delta = (expected as f32 - got as f32).abs();
            assert!(
                delta < 1e-5,
                "roundtrip mismatch: expected {expected}, got {got}"
            );
        }
    }

    /// A trivial shader that doubles its input runs end to end.
    #[test]
    fn real_backend_dispatch_scale_shader() {
        use super::super::WgpuBackend;

        const SCALE_BY_TWO: &str = r#"
@group(0) @binding(0) var<storage, read> input_buf: array<f32>;
@group(0) @binding(1) var<storage, read_write> output_buf: array<f32>;
@compute @workgroup_size(64)
fn scale_by_two(@builtin(global_invocation_id) gid: vec3<u32>) {
let i = gid.x;
if i < arrayLength(&input_buf) {
output_buf[i] = input_buf[i] * 2.0;
}
}
"#;

        let Some(mut gpu) = try_backend() else {
            return;
        };
        let n: usize = 4;
        let byte_len = (n * 4) as u64;
        let input_data: Vec<f32> = (1..=n as u32).map(|x| x as f32).collect();
        let in_handle = gpu.create_buffer_storage(byte_len);
        let out_handle = gpu.create_buffer_storage(byte_len);

        // Upload the input directly through the queue.
        gpu.queue.write_buffer(
            gpu.buffers[in_handle.0].as_ref().unwrap(),
            0,
            bytemuck::cast_slice(&input_data),
        );

        let bindings = [
            (
                in_handle,
                wgpu::BufferBindingType::Storage { read_only: true },
            ),
            (
                out_handle,
                wgpu::BufferBindingType::Storage { read_only: false },
            ),
        ];
        let result = gpu.dispatch_wgsl(SCALE_BY_TWO, "scale_by_two", &bindings, [1, 1, 1]);
        assert!(result.is_ok(), "dispatch_wgsl failed: {:?}", result.err());

        let out = gpu.read_buffer_f64(out_handle);
        assert_eq!(out.len(), n);
        for (i, &v) in out.iter().enumerate() {
            let expected = (i + 1) as f64 * 2.0;
            assert!(
                (v - expected).abs() < 0.01,
                "element {i}: expected {expected}, got {v}"
            );
        }

        // Exercise the stub backend's buffer API as well.
        let mut stub = WgpuBackend::new_stub();
        let stub_handle = stub.create_buffer(4);
        let _ = stub.read_buffer(stub_handle);
    }

    /// Zero items need zero workgroups along x.
    #[test]
    fn dispatch_count_for_zero_items() {
        let counts = WgpuBackendReal::dispatch_count_for(0, 64);
        assert_eq!(counts, [0, 1, 1]);
    }

    /// One item past a full workgroup rounds up to two workgroups.
    #[test]
    fn dispatch_count_for_65_items() {
        let counts = WgpuBackendReal::dispatch_count_for(65, 64);
        assert_eq!(counts, [2, 1, 1]);
    }

    /// An exact multiple of the workgroup size needs no extra workgroup.
    #[test]
    fn dispatch_count_for_exact_workgroup() {
        let counts = WgpuBackendReal::dispatch_count_for(256, 64);
        assert_eq!(counts, [4, 1, 1]);
    }
}
}