use crate::{
buffer::{read_back, storage_buffer_init, uniform_buffer},
context::ComputeContext,
pipeline::compute_pipeline,
wgsl::{
SHADER_BITONIC_SORT, SHADER_MAP_F32_TEMPLATE, SHADER_REDUCTION_SUM, SHADER_SPH_DENSITY,
SHADER_ZIP_MAP_F32_TEMPLATE,
},
};
use wgpu;
/// Single-precision π, used for the SPH kernel normalisation in `Dispatcher::sph_density`.
const PI: f32 = core::f32::consts::PI;
/// Checks that `op` is safe to splice into a WGSL shader template as a single
/// expression.
///
/// Only ASCII letters and digits, spaces/tabs, and the punctuation
/// `+ - * / % ! < > = ( ) , . _` are accepted. This rules out statement
/// separators (`;`), braces, brackets, attributes (`@`), newlines and `&`.
///
/// Comment openers (`//`, `/*`) are rejected as well. Although they are built
/// from allowed characters, they would comment out the template text that
/// follows the expression.
///
/// An empty or whitespace-only `op` is rejected because it cannot form an
/// expression and would yield an uncompilable shader.
///
/// # Errors
///
/// Returns a static description of the first problem found.
fn validate_wgsl_op(op: &str) -> Result<(), &'static str> {
    const ALLOWED_PUNCTUATION: &[u8] = b" \t+-*/%!<>=(),._";

    if op.trim().is_empty() {
        return Err("op must contain a non-empty WGSL expression");
    }
    let is_allowed = |b: u8| b.is_ascii_alphanumeric() || ALLOWED_PUNCTUATION.contains(&b);
    if !op.bytes().all(is_allowed) {
        return Err("op contains a character not permitted in a WGSL expression");
    }
    // `/` and `*` are legal operators, but these pairs open WGSL comments.
    if op.contains("//") || op.contains("/*") {
        return Err("op must not contain a WGSL comment");
    }
    Ok(())
}
pub struct Dispatcher<'a> {
ctx: &'a ComputeContext,
}
impl<'a> Dispatcher<'a> {
    /// Threads per workgroup assumed when sizing 1-D dispatches.
    ///
    /// NOTE(review): presumably matches the `@workgroup_size` declared by the
    /// map, zip-map, SPH, and bitonic shaders. Keep the two in sync.
    const WORKGROUP_SIZE: u32 = 64;

    /// Creates a dispatcher that submits work through `ctx`'s device and queue.
    pub fn new(ctx: &'a ComputeContext) -> Self {
        Self { ctx }
    }

    /// Applies the WGSL expression `op` to every element of `src` on the GPU and
    /// returns the results in input order.
    ///
    /// `op` replaces `%%OP%%` in `SHADER_MAP_F32_TEMPLATE`. Which identifiers it
    /// may use is defined by that template (e.g. `"x * 2.0"`).
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    ///
    /// - `src` is empty.
    /// - `op` is rejected by `validate_wgsl_op`.
    /// - `src.len()` exceeds `u32::MAX`.
    /// - The dispatch would exceed the device's per-dimension workgroup limit.
    pub fn map_f32(&self, src: &[f32], op: &str) -> Vec<f32> {
        assert!(!src.is_empty(), "Dispatcher::map_f32: src must be non-empty");
        if let Err(reason) = validate_wgsl_op(op) {
            panic!("Dispatcher::map_f32: invalid WGSL expression ({reason}): {op:?}");
        }
        let n = Self::checked_len(src.len(), "Dispatcher::map_f32");
        let wgsl = SHADER_MAP_F32_TEMPLATE.replace("%%OP%%", op);
        let device = &self.ctx.device;
        let src_buf = storage_buffer_init(device, "map-src", bytemuck::cast_slice(src));
        let dst_buf = storage_buffer_init(
            device,
            "map-dst",
            bytemuck::cast_slice(&vec![0.0_f32; src.len()]),
        );
        let n_buf = uniform_buffer(device, "map-n", bytemuck::bytes_of(&n));
        let pipeline = compute_pipeline(device, &wgsl, "main_map");
        let layout = pipeline.get_bind_group_layout(0);
        let bind_group = self.bind_buffers("map-bg", &layout, &[&src_buf, &dst_buf, &n_buf]);
        let workgroups = n.div_ceil(Self::WORKGROUP_SIZE);
        self.submit_dispatches("map", &pipeline, &[bind_group], workgroups);
        read_back::<f32>(device, &self.ctx.queue, &dst_buf, src.len())
    }

    /// Applies the WGSL expression `op` pairwise to `a[i]` and `b[i]` on the GPU
    /// and returns the results in input order.
    ///
    /// `op` replaces `%%OP%%` in `SHADER_ZIP_MAP_F32_TEMPLATE`. Which identifiers
    /// it may use is defined by that template (e.g. `"a + b"`).
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    ///
    /// - `a` is empty.
    /// - `a` and `b` differ in length.
    /// - `op` is rejected by `validate_wgsl_op`.
    /// - The length exceeds `u32::MAX`.
    /// - The dispatch would exceed the device's per-dimension workgroup limit.
    pub fn zip_map_f32(&self, a: &[f32], b: &[f32], op: &str) -> Vec<f32> {
        assert!(!a.is_empty(), "Dispatcher::zip_map_f32: a must be non-empty");
        assert_eq!(
            a.len(),
            b.len(),
            "Dispatcher::zip_map_f32: a and b must have equal length"
        );
        if let Err(reason) = validate_wgsl_op(op) {
            panic!("Dispatcher::zip_map_f32: invalid WGSL expression ({reason}): {op:?}");
        }
        let n = Self::checked_len(a.len(), "Dispatcher::zip_map_f32");
        let wgsl = SHADER_ZIP_MAP_F32_TEMPLATE.replace("%%OP%%", op);
        let device = &self.ctx.device;
        let a_buf = storage_buffer_init(device, "zip-a", bytemuck::cast_slice(a));
        let b_buf = storage_buffer_init(device, "zip-b", bytemuck::cast_slice(b));
        let dst_buf = storage_buffer_init(
            device,
            "zip-dst",
            bytemuck::cast_slice(&vec![0.0_f32; a.len()]),
        );
        let n_buf = uniform_buffer(device, "zip-n", bytemuck::bytes_of(&n));
        let pipeline = compute_pipeline(device, &wgsl, "main_zip_map");
        let layout = pipeline.get_bind_group_layout(0);
        let bind_group =
            self.bind_buffers("zip-bg", &layout, &[&a_buf, &b_buf, &dst_buf, &n_buf]);
        let workgroups = n.div_ceil(Self::WORKGROUP_SIZE);
        self.submit_dispatches("zip", &pipeline, &[bind_group], workgroups);
        read_back::<f32>(device, &self.ctx.queue, &dst_buf, a.len())
    }

    /// Sums `data` on the GPU with `SHADER_REDUCTION_SUM` and returns the total.
    ///
    /// NOTE(review): only a single workgroup is dispatched. Summing every
    /// element therefore relies on the shader striding over the whole input
    /// buffer. Confirm this for inputs larger than one workgroup.
    ///
    /// # Panics
    ///
    /// Panics if `data` is empty.
    pub fn reduce_sum_f32(&self, data: &[f32]) -> f32 {
        assert!(
            !data.is_empty(),
            "Dispatcher::reduce_sum_f32: data must be non-empty"
        );
        let device = &self.ctx.device;
        let input_buf = storage_buffer_init(device, "reduce-in", bytemuck::cast_slice(data));
        let output_buf =
            storage_buffer_init(device, "reduce-out", bytemuck::cast_slice(&[0.0_f32]));
        let pipeline = compute_pipeline(device, SHADER_REDUCTION_SUM, "main_cs");
        let layout = pipeline.get_bind_group_layout(0);
        let bind_group = self.bind_buffers("reduce-bg", &layout, &[&input_buf, &output_buf]);
        self.submit_dispatches("reduce", &pipeline, &[bind_group], 1);
        read_back::<f32>(device, &self.ctx.queue, &output_buf, 1)[0]
    }

    /// Computes an SPH density for every particle on the GPU.
    ///
    /// `positions[i]` and `masses[i]` describe particle `i`, and `h` is the
    /// smoothing length. The host precomputes `h²` and the kernel coefficient
    /// `315 / (64 π h⁹)`, which is the Poly6 normalisation constant. It passes
    /// them, together with the particle count, to `SHADER_SPH_DENSITY` as a
    /// 16-byte uniform.
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    ///
    /// - `positions` is empty.
    /// - `positions` and `masses` differ in length.
    /// - `h` is not finite and positive.
    /// - The particle count exceeds `u32::MAX`.
    /// - The dispatch would exceed the device's per-dimension workgroup limit.
    pub fn sph_density(&self, positions: &[[f32; 3]], masses: &[f32], h: f32) -> Vec<f32> {
        assert!(
            !positions.is_empty(),
            "Dispatcher::sph_density: positions must be non-empty"
        );
        assert_eq!(
            positions.len(),
            masses.len(),
            "Dispatcher::sph_density: positions and masses must have equal length"
        );
        // A non-positive or non-finite h makes kernel_coeff infinite or NaN and
        // every density meaningless, so reject it up front.
        assert!(
            h.is_finite() && h > 0.0,
            "Dispatcher::sph_density: smoothing length h must be finite and positive, got {h}"
        );
        let n = Self::checked_len(positions.len(), "Dispatcher::sph_density");
        let h_sq = h * h;
        let kernel_coeff = 315.0_f32 / (64.0 * PI * h.powi(9));
        // Pad each position to four floats so element i starts at byte 16 * i.
        // NOTE(review): presumably the shader reads vec4<f32> (or vec3<f32>, which
        // also has a 16-byte array stride). Confirm against SHADER_SPH_DENSITY.
        let positions_vec4: Vec<f32> = positions
            .iter()
            .flat_map(|&[x, y, z]| [x, y, z, 0.0_f32])
            .collect();
        let device = &self.ctx.device;
        let pos_buf = storage_buffer_init(device, "sph-pos", bytemuck::cast_slice(&positions_vec4));
        let mass_buf = storage_buffer_init(device, "sph-mass", bytemuck::cast_slice(masses));
        let density_buf = storage_buffer_init(
            device,
            "sph-density",
            bytemuck::cast_slice(&vec![0.0_f32; positions.len()]),
        );
        // Native-endian words: [n: u32, h_sq: f32, kernel_coeff: f32, padding: u32].
        let params: [u32; 4] = [n, h_sq.to_bits(), kernel_coeff.to_bits(), 0];
        let params_buf = uniform_buffer(device, "sph-params", bytemuck::bytes_of(&params));
        let pipeline = compute_pipeline(device, SHADER_SPH_DENSITY, "main_sph");
        let layout = pipeline.get_bind_group_layout(0);
        let bind_group = self.bind_buffers(
            "sph-bg",
            &layout,
            &[&pos_buf, &mass_buf, &density_buf, &params_buf],
        );
        let workgroups = n.div_ceil(Self::WORKGROUP_SIZE);
        self.submit_dispatches("sph", &pipeline, &[bind_group], workgroups);
        read_back::<f32>(device, &self.ctx.queue, &density_buf, positions.len())
    }

    /// Sorts `data` in ascending order on the GPU with a bitonic sorting network.
    ///
    /// The input is padded to the next power of two with `f32::INFINITY`. It is
    /// sorted in place, and the first `data.len()` values are returned. All
    /// network stages are recorded into one compute pass and submitted once.
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    ///
    /// - `data` is empty.
    /// - The padded length exceeds `u32::MAX`.
    /// - The dispatch would exceed the device's per-dimension workgroup limit.
    pub fn sort_f32(&self, data: &[f32]) -> Vec<f32> {
        assert!(
            !data.is_empty(),
            "Dispatcher::sort_f32: data must be non-empty"
        );
        let original_len = data.len();
        let padded_len = original_len.next_power_of_two();
        // The padding must compare >= every real input. With f32::MAX, +∞ inputs
        // would sort after the padding and be truncated away.
        let mut padded = Vec::with_capacity(padded_len);
        padded.extend_from_slice(data);
        padded.resize(padded_len, f32::INFINITY);
        let n = Self::checked_len(padded_len, "Dispatcher::sort_f32");
        let device = &self.ctx.device;
        let data_buf = storage_buffer_init(device, "sort-data", bytemuck::cast_slice(&padded));
        let pipeline = compute_pipeline(device, SHADER_BITONIC_SORT, "main_bitonic");
        let layout = pipeline.get_bind_group_layout(0);
        // One 16-byte uniform per stage: [n, k, j, padding]. These handles stay
        // alive until the function returns, i.e. past the submission below.
        let step_bufs: Vec<_> = Self::bitonic_steps(n)
            .map(|(k, j)| uniform_buffer(device, "sort-step", bytemuck::bytes_of(&[n, k, j, 0_u32])))
            .collect();
        let bind_groups: Vec<wgpu::BindGroup> = step_bufs
            .iter()
            .map(|step_buf| self.bind_buffers("sort-bg", &layout, &[&data_buf, step_buf]))
            .collect();
        // A single element (n == 1) needs no compare-exchange stages.
        if !bind_groups.is_empty() {
            let workgroups = n.div_ceil(Self::WORKGROUP_SIZE);
            self.submit_dispatches("sort-step", &pipeline, &bind_groups, workgroups);
        }
        let mut sorted = read_back::<f32>(device, &self.ctx.queue, &data_buf, padded_len);
        sorted.truncate(original_len);
        sorted
    }

    /// Converts an element count to the `u32` passed to the shaders.
    ///
    /// Panics with `caller` in the message instead of silently truncating.
    fn checked_len(len: usize, caller: &str) -> u32 {
        u32::try_from(len)
            .unwrap_or_else(|_| panic!("{caller}: length {len} does not fit in a u32"))
    }

    /// Yields the `(k, j)` stage parameters of a bitonic sorting network over `n`
    /// elements, in execution order.
    ///
    /// `k` takes the values `2, 4, …, n`. For each `k`, `j` takes the values
    /// `k/2, k/4, …, 1`. Doubling is checked, so `n = 2^31` terminates instead
    /// of wrapping `k` to 0.
    fn bitonic_steps(n: u32) -> impl Iterator<Item = (u32, u32)> {
        std::iter::successors(Some(2_u32), |&k| k.checked_mul(2))
            .take_while(move |&k| k <= n)
            .flat_map(|k| {
                std::iter::successors(Some(k / 2), |&j| (j > 1).then_some(j / 2))
                    .map(move |j| (k, j))
            })
    }

    /// Creates a bind group for `layout` in which binding `i` is the whole of
    /// `buffers[i]`.
    fn bind_buffers(
        &self,
        label: &str,
        layout: &wgpu::BindGroupLayout,
        buffers: &[&wgpu::Buffer],
    ) -> wgpu::BindGroup {
        let entries: Vec<wgpu::BindGroupEntry<'_>> = buffers
            .iter()
            .zip(0_u32..)
            .map(|(buffer, binding)| wgpu::BindGroupEntry {
                binding,
                resource: buffer.as_entire_binding(),
            })
            .collect();
        self.ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some(label),
            layout,
            entries: &entries,
        })
    }

    /// Records and submits a single compute pass for `pipeline`.
    ///
    /// For each entry of `bind_groups`, in order, the pass binds it at group 0
    /// and dispatches `workgroups × 1 × 1` workgroups. The command buffer is
    /// then submitted on the context's queue. The encoder and pass are labelled
    /// `"{label}-encoder"` and `"{label}-pass"`.
    ///
    /// WebGPU treats each dispatch as its own usage scope, so storage writes
    /// from one dispatch are visible to the next within the same pass.
    ///
    /// # Panics
    ///
    /// Panics if `workgroups` exceeds the device's
    /// `max_compute_workgroups_per_dimension` limit.
    fn submit_dispatches(
        &self,
        label: &str,
        pipeline: &wgpu::ComputePipeline,
        bind_groups: &[wgpu::BindGroup],
        workgroups: u32,
    ) {
        let limit = self.ctx.device.limits().max_compute_workgroups_per_dimension;
        assert!(
            workgroups <= limit,
            "{label}: {workgroups} workgroups exceed the device limit of {limit} per dimension"
        );
        let encoder_label = format!("{label}-encoder");
        let pass_label = format!("{label}-pass");
        let mut encoder = self
            .ctx
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some(encoder_label.as_str()),
            });
        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some(pass_label.as_str()),
                timestamp_writes: None,
            });
            pass.set_pipeline(pipeline);
            for bind_group in bind_groups {
                pass.set_bind_group(0, bind_group, &[]);
                pass.dispatch_workgroups(workgroups, 1, 1);
            }
        }
        self.ctx.queue.submit(std::iter::once(encoder.finish()));
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn map_f32_doubles() {
        oxiui_core::require_gpu!(ctx, ComputeContext::try_new());
        let d = Dispatcher::new(&ctx);
        let out = d.map_f32(&[1.0_f32, 2.0, 3.0], "x * 2.0");
        assert_eq!(out.len(), 3);
        assert!((out[0] - 2.0).abs() < 1e-5, "expected 2.0, got {}", out[0]);
        assert!((out[1] - 4.0).abs() < 1e-5, "expected 4.0, got {}", out[1]);
        assert!((out[2] - 6.0).abs() < 1e-5, "expected 6.0, got {}", out[2]);
    }

    #[test]
    fn map_f32_handles_partial_workgroup() {
        oxiui_core::require_gpu!(ctx, ComputeContext::try_new());
        let d = Dispatcher::new(&ctx);
        // The dispatch uses 64-thread workgroups. A length of 100 leaves the
        // second workgroup only partially filled.
        let src: Vec<f32> = (0..100_u16).map(f32::from).collect();
        let out = d.map_f32(&src, "x + 1.0");
        assert_eq!(out.len(), src.len());
        for (i, (&got, &input)) in out.iter().zip(&src).enumerate() {
            let expected = input + 1.0;
            assert!(
                (got - expected).abs() < 1e-5,
                "index {i}: expected {expected}, got {got}"
            );
        }
    }

    #[test]
    fn zip_map_f32_adds() {
        oxiui_core::require_gpu!(ctx, ComputeContext::try_new());
        let d = Dispatcher::new(&ctx);
        let out = d.zip_map_f32(&[1.0_f32, 2.0], &[3.0, 4.0], "a + b");
        assert_eq!(out.len(), 2);
        assert!((out[0] - 4.0).abs() < 1e-5, "expected 4.0, got {}", out[0]);
        assert!((out[1] - 6.0).abs() < 1e-5, "expected 6.0, got {}", out[1]);
    }

    #[test]
    fn reduce_sum_f32_correct() {
        oxiui_core::require_gpu!(ctx, ComputeContext::try_new());
        let d = Dispatcher::new(&ctx);
        let sum = d.reduce_sum_f32(&[1.0_f32, 2.0, 3.0, 4.0]);
        assert!((sum - 10.0).abs() < 1e-3, "expected 10.0, got {sum}");
    }

    #[test]
    fn sph_density_single_particle() {
        oxiui_core::require_gpu!(ctx, ComputeContext::try_new());
        let d = Dispatcher::new(&ctx);
        let positions = [[0.0_f32, 0.0, 0.0]];
        let masses = [1.0_f32];
        let densities = d.sph_density(&positions, &masses, 1.0);
        assert_eq!(densities.len(), 1);
        assert!(
            densities[0] > 0.0,
            "single-particle density must be > 0, got {}",
            densities[0]
        );
    }

    #[test]
    fn sort_f32_small() {
        oxiui_core::require_gpu!(ctx, ComputeContext::try_new());
        let d = Dispatcher::new(&ctx);
        let out = d.sort_f32(&[4.0_f32, 2.0, 3.0, 1.0]);
        assert_eq!(out, vec![1.0_f32, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn sort_f32_non_power_of_two_len() {
        oxiui_core::require_gpu!(ctx, ComputeContext::try_new());
        let d = Dispatcher::new(&ctx);
        // Length 5 is padded to 8 internally. The padding must not leak into
        // the returned values.
        let out = d.sort_f32(&[5.0_f32, -1.0, 3.0, 0.5, -7.25]);
        assert_eq!(out, vec![-7.25_f32, -1.0, 0.5, 3.0, 5.0]);
    }

    #[test]
    fn sort_f32_single_element() {
        oxiui_core::require_gpu!(ctx, ComputeContext::try_new());
        let d = Dispatcher::new(&ctx);
        assert_eq!(d.sort_f32(&[42.0_f32]), vec![42.0_f32]);
    }

    #[test]
    fn map_f32_rejects_injection() {
        // map_f32 asserts on this validator before touching the GPU, so
        // checking the validator directly covers the rejection path.
        assert!(
            validate_wgsl_op("x; } @compute fn evil() {").is_err(),
            "injection expression should be rejected"
        );
    }

    #[test]
    fn validate_wgsl_op_accepts_valid_expressions() {
        for expr in &[
            "x * 2.0",
            "a + b",
            "sqrt(x)",
            "max(a, b)",
            "sin(x) + cos(x)",
            "x * x + 1.0",
            "a / (b + 1.0)",
        ] {
            assert!(
                validate_wgsl_op(expr).is_ok(),
                "should accept valid expression: {expr}"
            );
        }
    }

    #[test]
    fn validate_wgsl_op_rejects_injection_chars() {
        for bad in &[
            "x; }",
            "x\n@compute",
            "x{evil}",
            "x; @group(0)",
            "x // comment\n}",
            "x: f32",
        ] {
            assert!(validate_wgsl_op(bad).is_err(), "should reject: {bad:?}");
        }
    }
}