Skip to main content

oxiphysics_gpu/shaders/
functions.rs

1//! Auto-generated module
2//!
3//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5use std::collections::HashMap;
6
7use super::types::ShaderMetadata;
8
/// WGSL compute shader for SPH density computation.
///
/// Sums the poly6 smoothing kernel over every particle pair (all-pairs,
/// O(n²), no neighbour list) and writes one density per particle. A particle
/// contributes to its own density (the `j == i` term is not skipped).
///
/// Bindings (group 0):
/// - `0` `positions`: `array<vec4<f32>>`, read-only (only `xyz` is used)
/// - `1` `masses`: `array<f32>`, read-only
/// - `2` `densities`: `array<f32>`, read-write (output)
/// - `3` `params`: uniform `SphParams { n_particles, h, h2, h3 }`
///
/// The shader uses `h2` and `h3` as given; presumably the host must supply
/// `h*h` and `h*h*h` (they are not derived from `h` here).
///
/// Entry point `main`, `@workgroup_size(64)`, one invocation per particle.
pub const SPH_DENSITY_WGSL: &str = r#"
struct SphParams {
    n_particles: u32,
    h: f32,
    h2: f32,
    h3: f32,
}

@group(0) @binding(0) var<storage, read> positions: array<vec4<f32>>;
@group(0) @binding(1) var<storage, read> masses: array<f32>;
@group(0) @binding(2) var<storage, read_write> densities: array<f32>;
@group(0) @binding(3) var<uniform> params: SphParams;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
    let i = id.x;
    if (i >= params.n_particles) { return; }

    var density: f32 = 0.0;
    let pi = positions[i].xyz;

    // Poly6 with q = 1 - r^2/h^2: (h^2 - r^2)^3 = h^6 * q^3, so the kernel's
    // 1/h^9 normalisation reduces to the 1/h^3 applied after the loop.
    for (var j: u32 = 0u; j < params.n_particles; j++) {
        let r = pi - positions[j].xyz;
        let r2 = dot(r, r);
        if (r2 < params.h2) {
            let q = 1.0 - r2 / params.h2;
            density += masses[j] * q * q * q;
        }
    }

    densities[i] = density * 315.0 / (64.0 * 3.14159265 * params.h3);
}
"#;
/// WGSL shader for particle integration (semi-implicit Euler step).
///
/// For each particle the velocity is updated first, `v += (force + g) * dt`,
/// and the position is then advanced with the *new* velocity, `x += v * dt`.
/// `forces[i].xyz` is applied directly as an acceleration (there is no
/// division by mass). Gravity acts along `y` only. The kernel writes `w = 0.0`
/// into the velocity and `w = 1.0` into the position.
///
/// Bindings (group 0):
/// - `0` `particles`: `array<Particle { pos, vel }>`, read-write
/// - `1` `forces`: `array<vec4<f32>>`, read-only
/// - `2` `params`: uniform `IntegParams { dt, gravity_y, n, _pad }`
///
/// Entry point `main`, `@workgroup_size(64)`, one invocation per particle.
pub const INTEGRATE_WGSL: &str = r#"
struct Particle { pos: vec4<f32>, vel: vec4<f32>, }
struct IntegParams { dt: f32, gravity_y: f32, n: u32, _pad: u32, }

@group(0) @binding(0) var<storage, read_write> particles: array<Particle>;
@group(0) @binding(1) var<storage, read> forces: array<vec4<f32>>;
@group(0) @binding(2) var<uniform> params: IntegParams;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
    let i = id.x;
    if (i >= params.n) { return; }

    var p = particles[i];
    let f = forces[i].xyz + vec3<f32>(0.0, params.gravity_y, 0.0);
    p.vel = vec4<f32>(p.vel.xyz + f * params.dt, 0.0);
    p.pos = vec4<f32>(p.pos.xyz + p.vel.xyz * params.dt, 1.0);
    particles[i] = p;
}
"#;
/// WGSL shader for LBM BGK collision (D2Q9).
///
/// Per lattice node it computes the macroscopic density and velocity from
/// the nine populations, stores them in `rho`/`ux`/`uy`, and relaxes every
/// population toward the second-order equilibrium with rate `1 / tau`.
/// Populations are updated in place; no streaming is performed here.
/// Velocity is only normalised by density when `rho > 0`.
///
/// Layout: `f[node * 9 + direction]`, `node` in `0 .. nx * ny`; the direction
/// order is given by the `ex`/`ey` tables inside the shader (0 = rest).
///
/// Bindings (group 0): `0` `f`, `1` `rho`, `2` `ux`, `3` `uy` (all
/// `array<f32>`, read-write), `4` `params` (uniform `LbmParams`).
///
/// Entry point `main`, `@workgroup_size(64)`, one invocation per node.
pub const LBM_BGK_D2Q9_WGSL: &str = r#"
// D2Q9 LBM BGK collision kernel
// Layout: f[node * 9 + direction]
struct LbmParams { nx: u32, ny: u32, tau: f32, _pad: u32, }

@group(0) @binding(0) var<storage, read_write> f: array<f32>;
@group(0) @binding(1) var<storage, read_write> rho: array<f32>;
@group(0) @binding(2) var<storage, read_write> ux: array<f32>;
@group(0) @binding(3) var<storage, read_write> uy: array<f32>;
@group(0) @binding(4) var<uniform> params: LbmParams;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
    let idx = id.x;
    if (idx >= params.nx * params.ny) { return; }

    // Compute macroscopic values
    var r: f32 = 0.0;
    var u: f32 = 0.0;
    var v: f32 = 0.0;
    let w = array<f32, 9>(4.0/9.0, 1.0/9.0, 1.0/9.0, 1.0/9.0, 1.0/9.0, 1.0/36.0, 1.0/36.0, 1.0/36.0, 1.0/36.0);
    let ex = array<f32, 9>(0.0, 1.0, 0.0, -1.0, 0.0, 1.0, -1.0, -1.0, 1.0);
    let ey = array<f32, 9>(0.0, 0.0, 1.0, 0.0, -1.0, 1.0, 1.0, -1.0, -1.0);

    for (var i: u32 = 0u; i < 9u; i++) {
        let fi = f[idx * 9u + i];
        r += fi;
        u += ex[i] * fi;
        v += ey[i] * fi;
    }
    if (r > 0.0) { u /= r; v /= r; }

    rho[idx] = r; ux[idx] = u; uy[idx] = v;

    // BGK collision
    for (var i: u32 = 0u; i < 9u; i++) {
        let eu = ex[i]*u + ey[i]*v;
        let feq = w[i] * r * (1.0 + 3.0*eu + 4.5*eu*eu - 1.5*(u*u+v*v));
        f[idx * 9u + i] += (feq - f[idx * 9u + i]) / params.tau;
    }
}
"#;
/// WGSL compute shader for parallel cell-list construction (binning step).
///
/// Writes, for every atom, the flat index of the grid cell containing it:
/// `iz * ny * nx + iy * nx + ix` (x varies fastest). Each cell coordinate is
/// `u32(pos / cell_size)` clamped to `0 ..= n - 1`, so the grid origin is
/// at 0 and out-of-range atoms land in the border cells. This kernel only
/// fills `cell_indices`; grouping atoms by cell is not done here.
/// `box_x`/`box_y`/`box_z` are declared in `CellParams` but not read.
///
/// NOTE(review): negative coordinates rely on WGSL's float-to-`u32`
/// conversion behaviour before the clamp; confirm inputs are non-negative.
///
/// Bindings (group 0): `0` `positions` (`array<vec4<f32>>`, read-only),
/// `1` `cell_indices` (`array<u32>`, read-write), `2` `params` (uniform).
///
/// Entry point `main`, `@workgroup_size(64)`, one invocation per atom.
pub const CELL_LIST_WGSL: &str = r#"
struct CellParams {
    nx:      u32,
    ny:      u32,
    nz:      u32,
    cell_size: f32,
    n_atoms: u32,
    box_x:   f32,
    box_y:   f32,
    box_z:   f32,
}

@group(0) @binding(0) var<storage, read>       positions:    array<vec4<f32>>;
@group(0) @binding(1) var<storage, read_write> cell_indices: array<u32>;
@group(0) @binding(2) var<uniform>             params:       CellParams;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
    let atom_id = id.x;
    if (atom_id >= params.n_atoms) { return; }

    let pos = positions[atom_id].xyz;

    let ix = clamp(u32(pos.x / params.cell_size), 0u, params.nx - 1u);
    let iy = clamp(u32(pos.y / params.cell_size), 0u, params.ny - 1u);
    let iz = clamp(u32(pos.z / params.cell_size), 0u, params.nz - 1u);

    cell_indices[atom_id] = iz * params.ny * params.nx
                          + iy * params.nx
                          + ix;
}
"#;
/// WGSL compute shader that computes a sphere SDF on the GPU.
///
/// Samples a regular `nx * ny * nz` grid at cell centres
/// `origin + (index + 0.5) * dx` and stores `|p - c| - radius`. The value is
/// negative inside the sphere, zero on its surface and positive outside.
///
/// The flat index is `i * ny * nz + j * nz + k` (z varies fastest). This is the
/// opposite ordering from the cell-list and LBM streaming kernels, which vary
/// x fastest.
///
/// Bindings (group 0): `0` `sdf_values` (`array<f32>`, read-write),
/// `1` `params` (uniform `SdfSphereParams`).
///
/// Entry point `main`, `@workgroup_size(64)`, one invocation per grid sample.
pub const SDF_COMPUTE_WGSL: &str = r#"
struct SdfSphereParams {
    nx:       u32,
    ny:       u32,
    nz:       u32,
    dx:       f32,
    origin_x: f32,
    origin_y: f32,
    origin_z: f32,
    cx:       f32,
    cy:       f32,
    cz:       f32,
    radius:   f32,
    _pad:     u32,
}

@group(0) @binding(0) var<storage, read_write> sdf_values: array<f32>;
@group(0) @binding(1) var<uniform>             params:     SdfSphereParams;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
    let total = params.nx * params.ny * params.nz;
    let idx   = id.x;
    if (idx >= total) { return; }

    let i = idx / (params.ny * params.nz);
    let j = (idx / params.nz) % params.ny;
    let k = idx % params.nz;

    let px = params.origin_x + (f32(i) + 0.5) * params.dx;
    let py = params.origin_y + (f32(j) + 0.5) * params.dx;
    let pz = params.origin_z + (f32(k) + 0.5) * params.dx;

    let dist = distance(vec3<f32>(px, py, pz),
                        vec3<f32>(params.cx, params.cy, params.cz));
    sdf_values[idx] = dist - params.radius;
}
"#;
/// WGSL compute shader for LBM streaming step (D3Q19).
///
/// Pull-scheme streaming: each destination node reads population `q` from
/// the neighbour at `x - e_q` in `f_src` and writes it to `f_dst`. Indices
/// wrap periodically in all three axes. Source and destination are separate
/// buffers, so the pass is race-free.
///
/// Layout: `f[node * 19 + q]` with `node = iz * ny * nx + iy * nx + ix`
/// (x varies fastest). The direction order is defined by the `ex`/`ey`/`ez`
/// tables in the shader, with `q = 0` as the rest population.
///
/// Bindings (group 0): `0` `f_src` (read-only), `1` `f_dst` (read-write),
/// `2` `params` (uniform `StreamParams { nx, ny, nz, _pad }`).
///
/// Entry point `main`, `@workgroup_size(64)`, one invocation per node.
pub const LBM_STREAMING_SHADER: &str = r#"
// LBM D3Q19 streaming step
// Distributes f values from source to destination according to lattice velocities.
struct StreamParams { nx: u32, ny: u32, nz: u32, _pad: u32, }

@group(0) @binding(0) var<storage, read>       f_src:  array<f32>;
@group(0) @binding(1) var<storage, read_write> f_dst:  array<f32>;
@group(0) @binding(2) var<uniform>             params: StreamParams;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
    let nx = params.nx;
    let ny = params.ny;
    let nz = params.nz;
    let total = nx * ny * nz;
    let idx = id.x;
    if (idx >= total) { return; }

    let iz = idx / (ny * nx);
    let iy = (idx / nx) % ny;
    let ix = idx % nx;

    // D3Q19 lattice velocities (abbreviated: direction 0 = rest)
    let ex = array<i32, 19>( 0, 1,-1, 0, 0, 0, 0, 1,-1, 1,-1, 1,-1, 1,-1, 0, 0, 0, 0);
    let ey = array<i32, 19>( 0, 0, 0, 1,-1, 0, 0, 1, 1,-1,-1, 0, 0, 0, 0, 1,-1, 1,-1);
    let ez = array<i32, 19>( 0, 0, 0, 0, 0, 1,-1, 0, 0, 0, 0, 1, 1,-1,-1, 1, 1,-1,-1);

    for (var q: u32 = 0u; q < 19u; q++) {
        let sx = (i32(ix) - ex[q] + i32(nx)) % i32(nx);
        let sy = (i32(iy) - ey[q] + i32(ny)) % i32(ny);
        let sz = (i32(iz) - ez[q] + i32(nz)) % i32(nz);
        let src_idx = u32(sz) * ny * nx + u32(sy) * nx + u32(sx);
        f_dst[idx * 19u + q] = f_src[src_idx * 19u + q];
    }
}
"#;
/// WGSL compute shader for rigid body semi-implicit Euler integration.
///
/// Only linear motion is integrated. The velocity is updated from
/// `force / mass` first, and the position is then advanced with the new
/// velocity. `ang_vel` is not touched and no orientation is stored.
/// After the step, `force` and `torque` are zeroed.
///
/// Mass is read from `pos.w`. Bodies with `mass <= 0` are skipped entirely,
/// including the force/torque clearing.
///
/// Bindings (group 0): `0` `bodies` (`array<RigidBody>`, read-write),
/// `1` `params` (uniform `IntegRigidParams { dt, n, _pad0, _pad1 }`).
///
/// Entry point `main`, `@workgroup_size(64)`, one invocation per body.
pub const RIGID_INTEGRATE_SHADER: &str = r#"
// Semi-implicit Euler integration for rigid bodies.
struct RigidBody {
    pos:       vec4<f32>,  // xyz = position, w = mass
    vel:       vec4<f32>,  // xyz = linear velocity, w unused
    ang_vel:   vec4<f32>,  // xyz = angular velocity, w unused
    force:     vec4<f32>,  // xyz = accumulated force, w unused
    torque:    vec4<f32>,  // xyz = accumulated torque, w unused
}
struct IntegRigidParams { dt: f32, n: u32, _pad0: u32, _pad1: u32, }

@group(0) @binding(0) var<storage, read_write> bodies: array<RigidBody>;
@group(0) @binding(1) var<uniform>             params: IntegRigidParams;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
    let i = id.x;
    if (i >= params.n) { return; }

    var b = bodies[i];
    let mass = b.pos.w;
    if (mass <= 0.0) { return; }

    // Semi-implicit Euler: update velocity first, then position.
    let acc = b.force.xyz / mass;
    b.vel = vec4<f32>(b.vel.xyz + acc * params.dt, 0.0);
    b.pos = vec4<f32>(b.pos.xyz + b.vel.xyz * params.dt, mass);

    // Clear forces.
    b.force  = vec4<f32>(0.0, 0.0, 0.0, 0.0);
    b.torque = vec4<f32>(0.0, 0.0, 0.0, 0.0);
    bodies[i] = b;
}
"#;
/// WGSL compute shader for bitonic sort used in SAP broadphase.
///
/// Performs ONE compare-exchange pass of a bitonic sorting network over the
/// parallel `keys` (`f32`) / `values` (`u32`) arrays. Each of the `n / 2`
/// invocations handles the pair `(i, i + step)`. The pair is ascending when
/// `(i / (2 * stage)) & 1 == 0` and descending otherwise.
///
/// A full sort needs one dispatch per `(stage, step)` combination; presumably
/// the host drives those loops. Assumptions (hedged, not enforced):
/// - `step` is a power of two, since the index math masks with `step - 1`;
/// - `n` is expected to be a power of two, since pairs with `i + step >= n`
///   are skipped rather than padded.
///
/// Bindings (group 0): `0` `keys`, `1` `values` (both read-write),
/// `2` `params` (uniform `SortParams { n, step, stage, _pad }`).
///
/// Entry point `main`, `@workgroup_size(64)`.
pub const BROADPHASE_SORT_SHADER: &str = r#"
// Bitonic sort pass for SAP broadphase.
// Each invocation compares and optionally swaps one pair of elements.
struct SortParams { n: u32, step: u32, stage: u32, _pad: u32, }

@group(0) @binding(0) var<storage, read_write> keys:   array<f32>;
@group(0) @binding(1) var<storage, read_write> values: array<u32>;
@group(0) @binding(2) var<uniform>             params: SortParams;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
    let idx = id.x;
    let n   = params.n;
    if (idx >= n / 2u) { return; }

    let step  = params.step;
    let stage = params.stage;

    let j   = idx & (step - 1u);
    let i   = (idx - j) * 2u + j;
    let i2  = i + step;

    if (i2 >= n) { return; }

    // Ascending sort when the block bit is 0.
    let ascending = ((i / (stage * 2u)) & 1u) == 0u;
    let swap = (keys[i] > keys[i2]) == ascending;

    if (swap) {
        let tmp_k   = keys[i];   keys[i]   = keys[i2];   keys[i2]   = tmp_k;
        let tmp_v   = values[i]; values[i] = values[i2]; values[i2] = tmp_v;
    }
}
"#;
/// WGSL compute shader for SPH pressure force computation.
///
/// All-pairs (O(n²)) accumulation of two terms per particle:
/// - pressure: `-m_j (p_i + p_j) / (2 rho_j) * grad W_spiky`, with
///   `grad W_spiky = -45 / (pi h^6) (h - r)^2 r_hat`;
/// - viscosity: `mu m_j (v_j - v_i) / rho_j * 45 / (pi h^6) (h - r)`.
///
/// Pairs are skipped when `j == i`, when `r < 1e-4`, or when `r >= h`. The
/// `densities` and `pressures` buffers are inputs computed elsewhere. The sum
/// is written to `forces[i]` with `w = 0.0`, without dividing by `rho_i`.
///
/// Bindings (group 0): `0` `positions`, `1` `velocities`, `2` `densities`,
/// `3` `pressures`, `4` `masses` (all read-only), `5` `forces` (read-write),
/// `6` `params` (uniform `SphForceParams { n_particles, h, mu, _pad }`).
///
/// Entry point `main`, `@workgroup_size(64)`, one invocation per particle.
pub const SPH_FORCE_WGSL: &str = r#"
struct SphForceParams {
    n_particles: u32,
    h: f32,
    mu: f32,
    _pad: u32,
}

@group(0) @binding(0) var<storage, read>       positions:  array<vec4<f32>>;
@group(0) @binding(1) var<storage, read>       velocities: array<vec4<f32>>;
@group(0) @binding(2) var<storage, read>       densities:  array<f32>;
@group(0) @binding(3) var<storage, read>       pressures:  array<f32>;
@group(0) @binding(4) var<storage, read>       masses:     array<f32>;
@group(0) @binding(5) var<storage, read_write> forces:     array<vec4<f32>>;
@group(0) @binding(6) var<uniform>             params:     SphForceParams;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
    let i = id.x;
    if (i >= params.n_particles) { return; }

    let pi = positions[i].xyz;
    let vi = velocities[i].xyz;
    var force: vec3<f32> = vec3<f32>(0.0, 0.0, 0.0);
    let h2 = params.h * params.h;

    for (var j: u32 = 0u; j < params.n_particles; j++) {
        if (j == i) { continue; }
        let r_vec = pi - positions[j].xyz;
        let r2 = dot(r_vec, r_vec);
        let r = sqrt(r2);
        if (r < 0.0001 || r >= params.h) { continue; }

        // Spiky gradient
        let h_r = params.h - r;
        let grad_mag = -45.0 / (3.14159265 * pow(params.h, 6.0)) * h_r * h_r;
        let grad = r_vec / r * grad_mag;

        // Pressure force
        let p_term = -masses[j] * (pressures[i] + pressures[j]) / (2.0 * densities[j]);
        force += p_term * grad;

        // Viscosity
        let v_lap = 45.0 / (3.14159265 * pow(params.h, 6.0)) * h_r;
        force += params.mu * masses[j] * (velocities[j].xyz - vi) / densities[j] * v_lap;
    }

    forces[i] = vec4<f32>(force, 0.0);
}
"#;
/// WGSL compute shader for boundary condition enforcement.
///
/// Keeps each particle inside the axis-aligned box `[box_min, box_max]`.
/// For every axis on which the position lies outside the box, the position
/// is clamped to the violated face and that velocity component is reflected
/// and scaled: `v = -v * restitution`.
///
/// The position's `w` component is preserved. The velocity is written back
/// with `w = 0.0`.
///
/// Bindings (group 0): `0` `positions`, `1` `velocities` (both
/// `array<vec4<f32>>`, read-write), `2` `params` (uniform `BoundaryParams`).
///
/// Entry point `main`, `@workgroup_size(64)`, one invocation per particle.
pub const BOUNDARY_ENFORCE_WGSL: &str = r#"
struct BoundaryParams {
    n: u32,
    box_min_x: f32,
    box_min_y: f32,
    box_min_z: f32,
    box_max_x: f32,
    box_max_y: f32,
    box_max_z: f32,
    restitution: f32,
}

@group(0) @binding(0) var<storage, read_write> positions:  array<vec4<f32>>;
@group(0) @binding(1) var<storage, read_write> velocities: array<vec4<f32>>;
@group(0) @binding(2) var<uniform>             params:     BoundaryParams;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
    let i = id.x;
    if (i >= params.n) { return; }

    var pos = positions[i].xyz;
    var vel = velocities[i].xyz;
    let e = params.restitution;

    if (pos.x < params.box_min_x) { pos.x = params.box_min_x; vel.x = -vel.x * e; }
    if (pos.x > params.box_max_x) { pos.x = params.box_max_x; vel.x = -vel.x * e; }
    if (pos.y < params.box_min_y) { pos.y = params.box_min_y; vel.y = -vel.y * e; }
    if (pos.y > params.box_max_y) { pos.y = params.box_max_y; vel.y = -vel.y * e; }
    if (pos.z < params.box_min_z) { pos.z = params.box_min_z; vel.z = -vel.z * e; }
    if (pos.z > params.box_max_z) { pos.z = params.box_max_z; vel.z = -vel.z * e; }

    positions[i]  = vec4<f32>(pos, positions[i].w);
    velocities[i] = vec4<f32>(vel, 0.0);
}
"#;
/// Expands `#include "name"` directives in `source` using the `includes` table.
///
/// The source is processed line by line (via `str::lines`, so `\r\n` endings
/// become `\n`). Every emitted line is terminated with `\n`.
///
/// A line whose trimmed text starts with `#include` is handled as follows:
/// - If a double-quoted name can be extracted and is a key of `includes`, the
///   line is replaced by that fragment.
/// - Otherwise the original line (untrimmed) is kept behind a
///   `// UNRESOLVED: ` prefix, so the failure stays visible in the output.
///
/// Fragments are pasted verbatim; directives inside them are not expanded.
pub fn resolve_includes(source: &str, includes: &HashMap<&str, &str>) -> String {
    let mut out = String::new();
    for raw in source.lines() {
        let text = raw.trim();
        if !text.starts_with("#include") {
            out.push_str(raw);
            out.push('\n');
            continue;
        }
        // Text between the first pair of double quotes, if both are present.
        let quoted = text.find('"').and_then(|open| {
            let after_open = &text[open + 1..];
            after_open.find('"').map(|close| &after_open[..close])
        });
        match quoted.and_then(|name| includes.get(name)) {
            Some(fragment) => out.push_str(fragment),
            None => {
                out.push_str("// UNRESOLVED: ");
                out.push_str(raw);
            }
        }
        out.push('\n');
    }
    out
}
/// Validate that a WGSL source has basic structural correctness.
///
/// This is a cheap lexical sanity check, not a parser. It returns `true` only
/// when all of the following hold:
/// 1. The source contains the substring `fn` (plain substring match).
/// 2. The curly braces are balanced *and properly nested*. Scanning left to
///    right, a `}` never appears while no `{` is open, and every `{` is
///    eventually closed. Text such as `"} {"` is therefore rejected even
///    though the counts match.
/// 3. There is at least one `{`, i.e. some body exists.
///
/// Braces inside comments are counted like any other character.
pub fn validate_wgsl_structure(source: &str) -> bool {
    if !source.contains("fn") {
        return false;
    }
    let mut depth: usize = 0;
    let mut saw_open = false;
    for c in source.chars() {
        match c {
            '{' => {
                depth += 1;
                saw_open = true;
            }
            '}' => match depth.checked_sub(1) {
                Some(d) => depth = d,
                // A closing brace with nothing open can never be balanced.
                None => return false,
            },
            _ => {}
        }
    }
    saw_open && depth == 0
}
417/// Validate that all built-in shader sources contain the `@compute` entry point marker.
418pub fn validate_shader_sources() -> bool {
419    SPH_DENSITY_WGSL.contains("@compute")
420        && INTEGRATE_WGSL.contains("@compute")
421        && LBM_BGK_D2Q9_WGSL.contains("@compute")
422        && CELL_LIST_WGSL.contains("@compute")
423        && SDF_COMPUTE_WGSL.contains("@compute")
424        && LBM_STREAMING_SHADER.contains("@compute")
425        && RIGID_INTEGRATE_SHADER.contains("@compute")
426        && BROADPHASE_SORT_SHADER.contains("@compute")
427        && SPH_FORCE_WGSL.contains("@compute")
428        && BOUNDARY_ENFORCE_WGSL.contains("@compute")
429}
430/// Count the total number of binding annotations across all built-in shaders.
431pub fn count_total_bindings() -> usize {
432    let all_shaders = [
433        SPH_DENSITY_WGSL,
434        INTEGRATE_WGSL,
435        LBM_BGK_D2Q9_WGSL,
436        CELL_LIST_WGSL,
437        SDF_COMPUTE_WGSL,
438        LBM_STREAMING_SHADER,
439        RIGID_INTEGRATE_SHADER,
440        BROADPHASE_SORT_SHADER,
441        SPH_FORCE_WGSL,
442        BOUNDARY_ENFORCE_WGSL,
443    ];
444    all_shaders
445        .iter()
446        .map(|s| s.matches("@binding(").count())
447        .sum()
448}
449/// Simple 64-bit hash of a string (FNV-1a variant).
450pub(super) fn simple_hash(s: &str) -> u64 {
451    let mut h: u64 = 0xcbf2_9ce4_8422_2325;
452    for byte in s.bytes() {
453        h ^= byte as u64;
454        h = h.wrapping_mul(0x0000_0100_0000_01b3);
455    }
456    h
457}
458/// Produce a mock SPIR-V byte sequence from WGSL source.
459///
460/// In a real pipeline this would invoke naga or spirv-cross.  Here we
461/// produce a deterministic stub that starts with the SPIR-V magic number
462/// and encodes basic information from the source.
463pub fn mock_compile_to_spirv(source: &str, entry_point: &str) -> Vec<u8> {
464    let mut out = vec![0x07u8, 0x23, 0x02, 0x03];
465    out.extend_from_slice(&[0x00, 0x01, 0x05, 0x00]);
466    let src_len = source.len() as u32;
467    out.extend_from_slice(&src_len.to_le_bytes());
468    let ep_hash = simple_hash(entry_point) as u32;
469    out.extend_from_slice(&ep_hash.to_le_bytes());
470    let bindings = source.matches("@binding(").count() as u32;
471    out.extend_from_slice(&bindings.to_le_bytes());
472    while !out.len().is_multiple_of(4) {
473        out.push(0x00);
474    }
475    out
476}
477/// Parse `@workgroup_size(x)` or `@workgroup_size(x, y, z)` from WGSL.
478pub(super) fn parse_workgroup_size(source: &str) -> [u32; 3] {
479    if let Some(pos) = source.find("@workgroup_size(") {
480        let rest = &source[pos + 16..];
481        if let Some(end) = rest.find(')') {
482            let inner = &rest[..end];
483            let parts: Vec<u32> = inner
484                .split(',')
485                .map(|s| s.trim().parse::<u32>().unwrap_or(1))
486                .collect();
487            return match parts.len() {
488                1 => [parts[0], 1, 1],
489                2 => [parts[0], parts[1], 1],
490                3 => [parts[0], parts[1], parts[2]],
491                _ => [1, 1, 1],
492            };
493        }
494    }
495    [1, 1, 1]
496}
497/// Validate that a [`ShaderMetadata`] record is internally consistent.
498///
499/// Checks:
500/// 1. Workgroup size is non-zero in every dimension.
501/// 2. Total threads-per-workgroup does not exceed 1024 (common GPU limit).
502/// 3. `bind_group_count` is <= 4 (common GPU limit).
503/// 4. `entry_point` is non-empty.
504pub fn validate_shader_metadata(meta: &ShaderMetadata) -> Result<(), String> {
505    if meta.workgroup_size[0] == 0 || meta.workgroup_size[1] == 0 || meta.workgroup_size[2] == 0 {
506        return Err("workgroup_size must be non-zero in every dimension".to_string());
507    }
508    let total = meta.workgroup_size[0]
509        .saturating_mul(meta.workgroup_size[1])
510        .saturating_mul(meta.workgroup_size[2]);
511    if total > 1024 {
512        return Err(format!(
513            "total threads-per-workgroup {} exceeds 1024",
514            total
515        ));
516    }
517    if meta.bind_group_count > 4 {
518        return Err(format!(
519            "bind_group_count {} exceeds maximum of 4",
520            meta.bind_group_count
521        ));
522    }
523    if meta.entry_point.is_empty() {
524        return Err("entry_point must not be empty".to_string());
525    }
526    Ok(())
527}
528#[cfg(test)]
529mod tests {
530    use super::*;
531    use crate::shaders::AddressMode;
532    use crate::shaders::BindGroupLayout;
533    use crate::shaders::ComputeShaderDesc;
534    use crate::shaders::DescriptorSetLayout;
535    use crate::shaders::DescriptorType;
536    use crate::shaders::FilterMode;
537    use crate::shaders::PushConstantRange;
538    use crate::shaders::RenderPassDesc;
539    use crate::shaders::SamplerDesc;
540    use crate::shaders::ShaderCache;
541    use crate::shaders::ShaderCompilationPipeline;
542    use crate::shaders::ShaderHotReloadManager;
543    use crate::shaders::ShaderRegistry;
544    use crate::shaders::ShaderStage;
545    use crate::shaders::ShaderTemplate;
546    use crate::shaders::SpecializationMap;
547    use crate::shaders::SpirVModule;
548    use crate::shaders::TextureFormat;
549    use crate::shaders::UniformBufferDesc;
550    #[test]
551    fn test_shader_sources_valid() {
552        assert!(validate_shader_sources());
553    }
554    #[test]
555    fn test_shader_sph_contains_density() {
556        assert!(SPH_DENSITY_WGSL.contains("density"));
557    }
558    #[test]
559    fn test_shader_lbm_contains_bgk() {
560        assert!(LBM_BGK_D2Q9_WGSL.contains("feq"));
561    }
562    /// ShaderTemplate replaces placeholders correctly.
563    #[test]
564    fn test_template_instantiation() {
565        let tmpl = ShaderTemplate::new(
566            "@compute @workgroup_size({WORKGROUP_SIZE})\nfn main() { var x: {BUFFER_TYPE}; }",
567        );
568        let mut params = HashMap::new();
569        params.insert("WORKGROUP_SIZE", "128");
570        params.insert("BUFFER_TYPE", "f32");
571        let result = tmpl.instantiate(&params);
572        assert!(
573            result.contains("@workgroup_size(128)"),
574            "Workgroup size not substituted: {result}"
575        );
576        assert!(
577            result.contains("var x: f32;"),
578            "Buffer type not substituted: {result}"
579        );
580        assert!(
581            !result.contains("{WORKGROUP_SIZE}"),
582            "Placeholder not removed"
583        );
584    }
585    /// ShaderTemplate leaves unknown placeholders unchanged.
586    #[test]
587    fn test_template_unknown_placeholder() {
588        let tmpl = ShaderTemplate::new("fn main() { var x: {UNKNOWN}; }");
589        let params = HashMap::new();
590        let result = tmpl.instantiate(&params);
591        assert!(
592            result.contains("{UNKNOWN}"),
593            "Unknown placeholder should be preserved"
594        );
595    }
596    /// ShaderRegistry registers and retrieves shaders by name.
597    #[test]
598    fn test_shader_registry_register_get() {
599        let mut reg = ShaderRegistry::new();
600        assert!(reg.is_empty());
601        let desc = ComputeShaderDesc::new("main", [64, 1, 1], LBM_STREAMING_SHADER);
602        reg.register("lbm_stream", desc);
603        assert_eq!(reg.len(), 1);
604        let retrieved = reg.get("lbm_stream").expect("Shader should be present");
605        assert_eq!(retrieved.entry_point, "main");
606        assert_eq!(retrieved.workgroup_size, [64, 1, 1]);
607        assert!(retrieved.source.contains("@compute"));
608    }
609    /// ShaderRegistry::with_builtins contains expected shaders.
610    #[test]
611    fn test_shader_registry_builtins() {
612        let reg = ShaderRegistry::with_builtins();
613        assert!(reg.get("sph_density").is_some());
614        assert!(reg.get("lbm_streaming").is_some());
615        assert!(reg.get("rigid_integrate").is_some());
616        assert!(reg.get("broadphase_sort").is_some());
617        assert!(reg.get("sph_force").is_some());
618        assert!(reg.get("boundary_enforce").is_some());
619        assert!(reg.get("integrate").is_some());
620        assert!(reg.get("lbm_bgk_d2q9").is_some());
621    }
622    /// validate_wgsl_structure returns true for balanced, valid WGSL.
623    #[test]
624    fn test_validate_wgsl_structure_valid() {
625        let src = "@compute @workgroup_size(64)\nfn main() { let x = 1u; }";
626        assert!(
627            validate_wgsl_structure(src),
628            "Valid WGSL should pass validation"
629        );
630    }
631    /// validate_wgsl_structure returns false for unbalanced braces.
632    #[test]
633    fn test_validate_wgsl_structure_unbalanced() {
634        let src = "fn main() { let x = 1u;";
635        assert!(
636            !validate_wgsl_structure(src),
637            "Unbalanced WGSL should fail validation"
638        );
639    }
640    /// All built-in shaders pass the structural validator.
641    #[test]
642    fn test_all_builtins_pass_structural_validation() {
643        let shaders = [
644            SPH_DENSITY_WGSL,
645            INTEGRATE_WGSL,
646            LBM_BGK_D2Q9_WGSL,
647            CELL_LIST_WGSL,
648            SDF_COMPUTE_WGSL,
649            LBM_STREAMING_SHADER,
650            RIGID_INTEGRATE_SHADER,
651            BROADPHASE_SORT_SHADER,
652            SPH_FORCE_WGSL,
653            BOUNDARY_ENFORCE_WGSL,
654        ];
655        for (i, src) in shaders.iter().enumerate() {
656            assert!(
657                validate_wgsl_structure(src),
658                "Shader {i} failed structural validation"
659            );
660        }
661    }
662    #[test]
663    fn test_template_placeholders() {
664        let tmpl = ShaderTemplate::new("fn main() { var x: {TYPE}; var y: {SIZE}; }");
665        let placeholders = tmpl.placeholders();
666        assert!(placeholders.contains(&"TYPE".to_string()));
667        assert!(placeholders.contains(&"SIZE".to_string()));
668    }
669    #[test]
670    fn test_template_all_placeholders_provided() {
671        let tmpl = ShaderTemplate::new("fn main() { var x: {TYPE}; }");
672        let mut params = HashMap::new();
673        assert!(!tmpl.all_placeholders_provided(&params));
674        params.insert("TYPE", "f32");
675        assert!(tmpl.all_placeholders_provided(&params));
676    }
677    #[test]
678    fn test_specialization_map() {
679        let mut sm = SpecializationMap::new();
680        assert!(sm.is_empty());
681        sm.define("WORKGROUP_SIZE", "64", "Threads per workgroup");
682        sm.define("MAX_NEIGHBORS", "128", "Max neighbors per particle");
683        assert_eq!(sm.len(), 2);
684        assert_eq!(sm.get("WORKGROUP_SIZE"), Some("64"));
685        sm.set("WORKGROUP_SIZE", "128");
686        assert_eq!(sm.get("WORKGROUP_SIZE"), Some("128"));
687        assert_eq!(sm.get("MAX_NEIGHBORS"), Some("128"));
688    }
689    #[test]
690    fn test_specialization_map_apply() {
691        let mut sm = SpecializationMap::new();
692        sm.define("WG", "64", "workgroup");
693        sm.set("WG", "128");
694        let source = "const WG = 64;\nfn main() {}";
695        let result = sm.apply(source);
696        assert!(
697            result.contains("const WG = 128;"),
698            "specialization not applied: {result}"
699        );
700    }
701    #[test]
702    fn test_resolve_includes() {
703        let mut includes = HashMap::new();
704        includes.insert("common.wgsl", "// common code\nfn helper() { }");
705        let source = "#include \"common.wgsl\"\nfn main() { }";
706        let resolved = resolve_includes(source, &includes);
707        assert!(
708            resolved.contains("common code"),
709            "include not resolved: {resolved}"
710        );
711        assert!(resolved.contains("fn main()"));
712    }
713    #[test]
714    fn test_resolve_includes_unresolved() {
715        let includes = HashMap::new();
716        let source = "#include \"missing.wgsl\"\nfn main() { }";
717        let resolved = resolve_includes(source, &includes);
718        assert!(
719            resolved.contains("UNRESOLVED"),
720            "unresolved include should be marked"
721        );
722    }
723    #[test]
724    fn test_shader_cache() {
725        let mut cache = ShaderCache::new();
726        assert!(cache.is_empty());
727        let source = cache.get_or_insert("test", || "fn main() { }".to_string());
728        assert_eq!(source, "fn main() { }");
729        assert!(cache.contains("test"));
730        assert_eq!(cache.len(), 1);
731        cache.remove("test");
732        assert!(!cache.contains("test"));
733    }
734    #[test]
735    fn test_shader_cache_returns_cached() {
736        let mut cache = ShaderCache::new();
737        cache.get_or_insert("test", || "first".to_string());
738        let second = cache.get_or_insert("test", || "second".to_string());
739        assert_eq!(second, "first", "cache should return first insertion");
740    }
741    #[test]
742    fn test_compute_shader_desc_binding_count() {
743        let desc = ComputeShaderDesc::new("main", [64, 1, 1], SPH_DENSITY_WGSL);
744        assert!(desc.binding_count() > 0);
745        assert_eq!(desc.threads_per_workgroup(), 64);
746    }
747    #[test]
748    fn test_registry_unregister() {
749        let mut reg = ShaderRegistry::new();
750        let desc = ComputeShaderDesc::new("main", [64, 1, 1], "fn main() { }");
751        reg.register("test", desc);
752        assert!(reg.contains("test"));
753        reg.unregister("test");
754        assert!(!reg.contains("test"));
755    }
756    #[test]
757    fn test_registry_names() {
758        let mut reg = ShaderRegistry::new();
759        reg.register(
760            "a",
761            ComputeShaderDesc::new("main", [64, 1, 1], "fn main() { }"),
762        );
763        reg.register(
764            "b",
765            ComputeShaderDesc::new("main", [64, 1, 1], "fn main() { }"),
766        );
767        let names: Vec<&str> = reg.names().collect();
768        assert_eq!(names.len(), 2);
769        assert!(names.contains(&"a"));
770        assert!(names.contains(&"b"));
771    }
772    #[test]
773    fn test_compilation_pipeline() {
774        let mut pipeline = ShaderCompilationPipeline::new();
775        let source = "@compute @workgroup_size(64)\nfn main() { let x = 1u; }";
776        let result = pipeline.compile("test", source, None).unwrap();
777        assert!(result.contains("fn main()"));
778        assert_eq!(pipeline.cache_size(), 1);
779        let result2 = pipeline.compile("test", source, None).unwrap();
780        assert_eq!(result, result2);
781    }
782    #[test]
783    fn test_compilation_pipeline_with_includes() {
784        let mut pipeline = ShaderCompilationPipeline::new();
785        pipeline.add_include("common.wgsl", "// included code");
786        let source = "#include \"common.wgsl\"\n@compute @workgroup_size(64)\nfn main() { }";
787        let result = pipeline.compile("test_inc", source, None).unwrap();
788        assert!(result.contains("included code"));
789    }
790    #[test]
791    fn test_compilation_pipeline_invalid_source() {
792        let mut pipeline = ShaderCompilationPipeline::new();
793        let source = "no functions here";
794        let result = pipeline.compile("bad", source, None);
795        assert!(result.is_err());
796    }
797    #[test]
798    fn test_count_total_bindings() {
799        let total = count_total_bindings();
800        assert!(
801            total > 10,
802            "should have many bindings across all shaders, got {total}"
803        );
804    }
805    #[test]
806    fn test_sph_force_shader_valid() {
807        assert!(validate_wgsl_structure(SPH_FORCE_WGSL));
808        assert!(SPH_FORCE_WGSL.contains("@compute"));
809        assert!(SPH_FORCE_WGSL.contains("forces"));
810    }
811    #[test]
812    fn test_boundary_enforce_shader_valid() {
813        assert!(validate_wgsl_structure(BOUNDARY_ENFORCE_WGSL));
814        assert!(BOUNDARY_ENFORCE_WGSL.contains("@compute"));
815        assert!(BOUNDARY_ENFORCE_WGSL.contains("restitution"));
816    }
817    #[test]
818    fn test_registry_with_builtins_count() {
819        let reg = ShaderRegistry::with_builtins();
820        assert!(
821            reg.len() >= 8,
822            "should have at least 8 built-in shaders, got {}",
823            reg.len()
824        );
825    }
826    #[test]
827    fn test_uniform_buffer_binding_desc() {
828        let desc = UniformBufferDesc::new("SceneParams", 0, 0, 256);
829        assert_eq!(desc.name, "SceneParams");
830        assert_eq!(desc.group, 0);
831        assert_eq!(desc.binding, 0);
832        assert_eq!(desc.size_bytes, 256);
833    }
834    #[test]
835    fn test_uniform_buffer_binding_layout() {
836        let mut layout = BindGroupLayout::new();
837        layout.add_uniform("CameraUbo", 0, 0, 64);
838        layout.add_uniform("LightUbo", 0, 1, 128);
839        assert_eq!(layout.binding_count(), 2);
840    }
841    #[test]
842    fn test_push_constant_range() {
843        let range = PushConstantRange::new(0, 128, ShaderStage::Compute);
844        assert_eq!(range.offset, 0);
845        assert_eq!(range.size, 128);
846        assert_eq!(range.stage, ShaderStage::Compute);
847    }
848    #[test]
849    fn test_push_constants_exceed_limit_detection() {
850        let range = PushConstantRange::new(0, 256, ShaderStage::Compute);
851        assert!(range.size > 128, "Detected large push constant range");
852    }
853    #[test]
854    fn test_descriptor_set_layout_storage() {
855        let mut dsl = DescriptorSetLayout::new(0);
856        dsl.add_storage_buffer(0, ShaderStage::Compute, false);
857        dsl.add_storage_buffer(1, ShaderStage::Compute, true);
858        assert_eq!(dsl.bindings.len(), 2);
859        assert!(!dsl.bindings[0].read_only);
860        assert!(dsl.bindings[1].read_only);
861    }
862    #[test]
863    fn test_descriptor_set_layout_uniform() {
864        let mut dsl = DescriptorSetLayout::new(0);
865        dsl.add_uniform_buffer(2, ShaderStage::Vertex);
866        assert_eq!(dsl.bindings.len(), 1);
867        assert_eq!(dsl.bindings[0].binding, 2);
868        assert_eq!(
869            dsl.bindings[0].descriptor_type,
870            DescriptorType::UniformBuffer
871        );
872    }
873    #[test]
874    fn test_sampler_descriptor_linear() {
875        let sd = SamplerDesc::linear();
876        assert_eq!(sd.filter_min, FilterMode::Linear);
877        assert_eq!(sd.filter_mag, FilterMode::Linear);
878        assert_eq!(sd.address_mode, AddressMode::ClampToEdge);
879    }
880    #[test]
881    fn test_sampler_descriptor_nearest() {
882        let sd = SamplerDesc::nearest();
883        assert_eq!(sd.filter_min, FilterMode::Nearest);
884        assert_eq!(sd.filter_mag, FilterMode::Nearest);
885    }
886    #[test]
887    fn test_sampler_descriptor_repeat() {
888        let sd = SamplerDesc {
889            address_mode: AddressMode::Repeat,
890            ..SamplerDesc::linear()
891        };
892        assert_eq!(sd.address_mode, AddressMode::Repeat);
893    }
894    #[test]
895    fn test_render_pass_desc_color_attachment() {
896        let rp = RenderPassDesc::new_simple_color();
897        assert_eq!(rp.color_attachments.len(), 1);
898        assert_eq!(rp.color_attachments[0].format, TextureFormat::Rgba8Unorm);
899    }
900    #[test]
901    fn test_render_pass_desc_depth_attachment() {
902        let rp = RenderPassDesc::new_with_depth();
903        assert!(rp.depth_attachment.is_some());
904        let depth = rp.depth_attachment.unwrap();
905        assert_eq!(depth.format, TextureFormat::Depth32Float);
906    }
907    #[test]
908    fn test_render_pass_desc_no_depth() {
909        let rp = RenderPassDesc::new_simple_color();
910        assert!(rp.depth_attachment.is_none());
911    }
912    #[test]
913    fn test_hot_reload_manager_watch() {
914        let mut mgr = ShaderHotReloadManager::new();
915        mgr.watch("particles.wgsl", "fn main() { }");
916        assert!(mgr.is_watched("particles.wgsl"));
917    }
918    #[test]
919    fn test_hot_reload_manager_update() {
920        let mut mgr = ShaderHotReloadManager::new();
921        mgr.watch("test.wgsl", "fn main() { let x = 1u; }");
922        let changed = mgr.update("test.wgsl", "@compute\nfn main() { let x = 2u; }");
923        assert!(changed);
924        assert!(mgr.get_source("test.wgsl").unwrap().contains("2u"));
925    }
926    #[test]
927    fn test_hot_reload_manager_no_change() {
928        let src = "fn main() { }";
929        let mut mgr = ShaderHotReloadManager::new();
930        mgr.watch("test.wgsl", src);
931        let changed = mgr.update("test.wgsl", src);
932        assert!(!changed, "Identical content should not trigger a reload");
933    }
934    #[test]
935    fn test_hot_reload_manager_unwatch() {
936        let mut mgr = ShaderHotReloadManager::new();
937        mgr.watch("a.wgsl", "fn main() {}");
938        mgr.unwatch("a.wgsl");
939        assert!(!mgr.is_watched("a.wgsl"));
940    }
941    #[test]
942    fn test_spirv_compilation_mock_produces_bytes() {
943        let src = "@compute @workgroup_size(64)\nfn main() { let x = 1u; }";
944        let spirv = mock_compile_to_spirv(src, "main");
945        assert!(
946            !spirv.is_empty(),
947            "Mock SPIR-V should produce non-empty bytes"
948        );
949        assert_eq!(spirv[0], 0x07, "SPIR-V first byte mismatch");
950    }
951    #[test]
952    fn test_spirv_compilation_deterministic() {
953        let src = "@compute @workgroup_size(64)\nfn main() { }";
954        let s1 = mock_compile_to_spirv(src, "main");
955        let s2 = mock_compile_to_spirv(src, "main");
956        assert_eq!(s1, s2, "Compilation should be deterministic");
957    }
958    #[test]
959    fn test_spirv_module_reflection() {
960        let src = "@group(0) @binding(0) var<storage, read_write> buf: array<f32>;\n@compute @workgroup_size(64)\nfn main() { }";
961        let module = SpirVModule::from_wgsl(src);
962        assert!(module.entry_points.contains(&"main".to_string()));
963        assert!(module.binding_count >= 1);
964    }
965    #[test]
966    fn test_spirv_module_workgroup_size() {
967        let src = "@compute @workgroup_size(128)\nfn compute_main() { }";
968        let module = SpirVModule::from_wgsl(src);
969        assert_eq!(module.workgroup_size[0], 128);
970        assert!(module.entry_points.contains(&"compute_main".to_string()));
971    }
972    #[test]
973    fn test_bind_group_layout_serialize() {
974        let mut layout = BindGroupLayout::new();
975        layout.add_uniform("Ubo", 0, 0, 64);
976        layout.add_storage("Sbo", 0, 1, false);
977        let desc = layout.to_wgsl_snippet();
978        assert!(desc.contains("@binding(0)"));
979        assert!(desc.contains("@binding(1)"));
980    }
981    #[test]
982    fn test_bind_group_layout_empty() {
983        let layout = BindGroupLayout::new();
984        assert_eq!(layout.binding_count(), 0);
985        assert!(layout.is_empty());
986    }
987}