1use std::collections::HashMap;
6
7use super::types::ShaderMetadata;
8
/// SPH density pass (brute force, O(n²) per dispatch).
///
/// For every particle `i` it sums `masses[j] * q³` over all particles `j` (including `i`
/// itself) with `|r_ij|² < h²`, where `q = 1 - |r_ij|² / h²`. The sum is then scaled by
/// `315 / (64 π h³)`. This is the poly6 kernel rewritten in terms of `q`, and the math
/// assumes the host sets `params.h2 = h²` and `params.h3 = h³`.
///
/// Bindings (group 0): 0 = positions (read), 1 = masses (read),
/// 2 = densities (read_write, output), 3 = `SphParams` uniform.
pub const SPH_DENSITY_WGSL: &str = r#"
struct SphParams {
    n_particles: u32,
    h: f32,
    h2: f32,
    h3: f32,
}

@group(0) @binding(0) var<storage, read> positions: array<vec4<f32>>;
@group(0) @binding(1) var<storage, read> masses: array<f32>;
@group(0) @binding(2) var<storage, read_write> densities: array<f32>;
@group(0) @binding(3) var<uniform> params: SphParams;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
    let i = id.x;
    if (i >= params.n_particles) { return; }

    var density: f32 = 0.0;
    let pi = positions[i].xyz;

    for (var j: u32 = 0u; j < params.n_particles; j++) {
        let r = pi - positions[j].xyz;
        let r2 = dot(r, r);
        if (r2 < params.h2) {
            let q = 1.0 - r2 / params.h2;
            density += masses[j] * q * q * q;
        }
    }

    densities[i] = density * 315.0 / (64.0 * 3.14159265 * params.h3);
}
"#;
/// Particle integration pass: semi-implicit Euler with a constant gravity term along y.
///
/// The new velocity is computed first from `forces[i] + (0, gravity_y, 0)` and then used
/// to advance the position. The velocity `w` is written as 0.0 and the position `w` as 1.0.
pub const INTEGRATE_WGSL: &str = r#"
struct Particle { pos: vec4<f32>, vel: vec4<f32>, }
struct IntegParams { dt: f32, gravity_y: f32, n: u32, _pad: u32, }

@group(0) @binding(0) var<storage, read_write> particles: array<Particle>;
@group(0) @binding(1) var<storage, read> forces: array<vec4<f32>>;
@group(0) @binding(2) var<uniform> params: IntegParams;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
    let i = id.x;
    if (i >= params.n) { return; }

    var p = particles[i];
    let f = forces[i].xyz + vec3<f32>(0.0, params.gravity_y, 0.0);
    p.vel = vec4<f32>(p.vel.xyz + f * params.dt, 0.0);
    p.pos = vec4<f32>(p.pos.xyz + p.vel.xyz * params.dt, 1.0);
    particles[i] = p;
}
"#;
/// D2Q9 lattice-Boltzmann BGK collision pass (one invocation per lattice node).
///
/// The kernel reads the nine populations stored at `f[node * 9 + q]` and computes density
/// `rho` and velocity `(ux, uy)` from them. It writes those three values out, then relaxes
/// each population toward the second-order equilibrium in place:
/// `f += (feq - f) / tau`.
pub const LBM_BGK_D2Q9_WGSL: &str = r#"
// D2Q9 LBM BGK collision kernel
// Layout: f[node * 9 + direction]
struct LbmParams { nx: u32, ny: u32, tau: f32, _pad: u32, }

@group(0) @binding(0) var<storage, read_write> f: array<f32>;
@group(0) @binding(1) var<storage, read_write> rho: array<f32>;
@group(0) @binding(2) var<storage, read_write> ux: array<f32>;
@group(0) @binding(3) var<storage, read_write> uy: array<f32>;
@group(0) @binding(4) var<uniform> params: LbmParams;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
    let idx = id.x;
    if (idx >= params.nx * params.ny) { return; }

    // Compute macroscopic values
    var r: f32 = 0.0;
    var u: f32 = 0.0;
    var v: f32 = 0.0;
    let w = array<f32, 9>(4.0/9.0, 1.0/9.0, 1.0/9.0, 1.0/9.0, 1.0/9.0, 1.0/36.0, 1.0/36.0, 1.0/36.0, 1.0/36.0);
    let ex = array<f32, 9>(0.0, 1.0, 0.0, -1.0, 0.0, 1.0, -1.0, -1.0, 1.0);
    let ey = array<f32, 9>(0.0, 0.0, 1.0, 0.0, -1.0, 1.0, 1.0, -1.0, -1.0);

    for (var i: u32 = 0u; i < 9u; i++) {
        let fi = f[idx * 9u + i];
        r += fi;
        u += ex[i] * fi;
        v += ey[i] * fi;
    }
    // Guard against empty nodes so velocity stays finite.
    if (r > 0.0) { u /= r; v /= r; }

    rho[idx] = r; ux[idx] = u; uy[idx] = v;

    // BGK collision
    for (var i: u32 = 0u; i < 9u; i++) {
        let eu = ex[i]*u + ey[i]*v;
        let feq = w[i] * r * (1.0 + 3.0*eu + 4.5*eu*eu - 1.5*(u*u+v*v));
        f[idx * 9u + i] += (feq - f[idx * 9u + i]) / params.tau;
    }
}
"#;
/// Cell-list binning pass: assigns each atom the flattened index of the grid cell it
/// falls in.
///
/// Each axis index is `u32(pos / cell_size)` clamped to `[0, n - 1]`. The flattened index
/// is `iz * ny * nx + iy * nx + ix` (x fastest). `box_x`/`box_y`/`box_z` are declared in
/// the uniform but not read by this kernel.
/// NOTE(review): negative coordinates depend on WGSL's f32→u32 conversion behaviour
/// before the clamp — confirm this is the intended handling.
pub const CELL_LIST_WGSL: &str = r#"
struct CellParams {
    nx: u32,
    ny: u32,
    nz: u32,
    cell_size: f32,
    n_atoms: u32,
    box_x: f32,
    box_y: f32,
    box_z: f32,
}

@group(0) @binding(0) var<storage, read> positions: array<vec4<f32>>;
@group(0) @binding(1) var<storage, read_write> cell_indices: array<u32>;
@group(0) @binding(2) var<uniform> params: CellParams;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
    let atom_id = id.x;
    if (atom_id >= params.n_atoms) { return; }

    let pos = positions[atom_id].xyz;

    let ix = clamp(u32(pos.x / params.cell_size), 0u, params.nx - 1u);
    let iy = clamp(u32(pos.y / params.cell_size), 0u, params.ny - 1u);
    let iz = clamp(u32(pos.z / params.cell_size), 0u, params.nz - 1u);

    cell_indices[atom_id] = iz * params.ny * params.nx
        + iy * params.nx
        + ix;
}
"#;
/// Sphere signed-distance-field pass: writes `|p - c| - radius` for every grid cell
/// centre `p`.
///
/// Cell `(i, j, k)` has its centre at `origin + (index + 0.5) * dx`. It is stored at
/// `idx = (i * ny + j) * nz + k`, so x is the slowest axis and z the fastest. This is the
/// opposite ordering from the cell-list kernel's `iz * ny * nx + iy * nx + ix`.
pub const SDF_COMPUTE_WGSL: &str = r#"
struct SdfSphereParams {
    nx: u32,
    ny: u32,
    nz: u32,
    dx: f32,
    origin_x: f32,
    origin_y: f32,
    origin_z: f32,
    cx: f32,
    cy: f32,
    cz: f32,
    radius: f32,
    _pad: u32,
}

@group(0) @binding(0) var<storage, read_write> sdf_values: array<f32>;
@group(0) @binding(1) var<uniform> params: SdfSphereParams;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
    let total = params.nx * params.ny * params.nz;
    let idx = id.x;
    if (idx >= total) { return; }

    let i = idx / (params.ny * params.nz);
    let j = (idx / params.nz) % params.ny;
    let k = idx % params.nz;

    let px = params.origin_x + (f32(i) + 0.5) * params.dx;
    let py = params.origin_y + (f32(j) + 0.5) * params.dx;
    let pz = params.origin_z + (f32(k) + 0.5) * params.dx;

    let dist = distance(vec3<f32>(px, py, pz),
                        vec3<f32>(params.cx, params.cy, params.cz));
    sdf_values[idx] = dist - params.radius;
}
"#;
/// D3Q19 lattice-Boltzmann streaming pass (pull form, periodic on every axis).
///
/// Each destination node copies direction `q` from the neighbour at `node - e_q`, with
/// coordinates wrapped modulo the grid size. Nodes are linearised as
/// `(iz * ny + iy) * nx + ix`. Populations are stored at `f[node * 19 + q]`, and the source
/// and destination buffers are separate.
pub const LBM_STREAMING_SHADER: &str = r#"
// LBM D3Q19 streaming step
// Distributes f values from source to destination according to lattice velocities.
struct StreamParams { nx: u32, ny: u32, nz: u32, _pad: u32, }

@group(0) @binding(0) var<storage, read> f_src: array<f32>;
@group(0) @binding(1) var<storage, read_write> f_dst: array<f32>;
@group(0) @binding(2) var<uniform> params: StreamParams;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
    let nx = params.nx;
    let ny = params.ny;
    let nz = params.nz;
    let total = nx * ny * nz;
    let idx = id.x;
    if (idx >= total) { return; }

    let iz = idx / (ny * nx);
    let iy = (idx / nx) % ny;
    let ix = idx % nx;

    // D3Q19 lattice velocities (abbreviated: direction 0 = rest)
    let ex = array<i32, 19>( 0, 1,-1, 0, 0, 0, 0, 1,-1, 1,-1, 1,-1, 1,-1, 0, 0, 0, 0);
    let ey = array<i32, 19>( 0, 0, 0, 1,-1, 0, 0, 1, 1,-1,-1, 0, 0, 0, 0, 1,-1, 1,-1);
    let ez = array<i32, 19>( 0, 0, 0, 0, 0, 1,-1, 0, 0, 0, 0, 1, 1,-1,-1, 1, 1,-1,-1);

    for (var q: u32 = 0u; q < 19u; q++) {
        // Adding n before the modulo keeps the operand non-negative (|e| <= 1).
        let sx = (i32(ix) - ex[q] + i32(nx)) % i32(nx);
        let sy = (i32(iy) - ey[q] + i32(ny)) % i32(ny);
        let sz = (i32(iz) - ez[q] + i32(nz)) % i32(nz);
        let src_idx = u32(sz) * ny * nx + u32(sy) * nx + u32(sx);
        f_dst[idx * 19u + q] = f_src[src_idx * 19u + q];
    }
}
"#;
/// Rigid-body integration pass: semi-implicit Euler on the linear motion only.
///
/// Bodies with `pos.w` (mass) `<= 0` are skipped entirely, so their accumulated force and
/// torque are not cleared. For the other bodies, velocity is updated from `force / mass`,
/// then position from the new velocity, and finally force and torque are zeroed.
/// `ang_vel` is not integrated by this kernel.
pub const RIGID_INTEGRATE_SHADER: &str = r#"
// Semi-implicit Euler integration for rigid bodies.
struct RigidBody {
    pos: vec4<f32>,     // xyz = position, w = mass
    vel: vec4<f32>,     // xyz = linear velocity, w unused
    ang_vel: vec4<f32>, // xyz = angular velocity, w unused
    force: vec4<f32>,   // xyz = accumulated force, w unused
    torque: vec4<f32>,  // xyz = accumulated torque, w unused
}
struct IntegRigidParams { dt: f32, n: u32, _pad0: u32, _pad1: u32, }

@group(0) @binding(0) var<storage, read_write> bodies: array<RigidBody>;
@group(0) @binding(1) var<uniform> params: IntegRigidParams;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
    let i = id.x;
    if (i >= params.n) { return; }

    var b = bodies[i];
    let mass = b.pos.w;
    if (mass <= 0.0) { return; }

    // Semi-implicit Euler: update velocity first, then position.
    let acc = b.force.xyz / mass;
    b.vel = vec4<f32>(b.vel.xyz + acc * params.dt, 0.0);
    b.pos = vec4<f32>(b.pos.xyz + b.vel.xyz * params.dt, mass);

    // Clear forces.
    b.force = vec4<f32>(0.0, 0.0, 0.0, 0.0);
    b.torque = vec4<f32>(0.0, 0.0, 0.0, 0.0);
    bodies[i] = b;
}
"#;
/// One bitonic-sort compare-and-exchange pass over paired `keys`/`values` buffers.
///
/// Each invocation handles the pair `(i, i + step)`, where
/// `i = (idx - (idx & (step - 1))) * 2 + (idx & (step - 1))`. It swaps both the key and
/// the value so the pair ends up ordered ascending when `(i / (2 * stage)) & 1 == 0`, and
/// descending otherwise. The host is expected to dispatch one pass per `(stage, step)`
/// combination with `step` a power of two — TODO confirm against the dispatch code.
pub const BROADPHASE_SORT_SHADER: &str = r#"
// Bitonic sort pass for SAP broadphase.
// Each invocation compares and optionally swaps one pair of elements.
struct SortParams { n: u32, step: u32, stage: u32, _pad: u32, }

@group(0) @binding(0) var<storage, read_write> keys: array<f32>;
@group(0) @binding(1) var<storage, read_write> values: array<u32>;
@group(0) @binding(2) var<uniform> params: SortParams;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
    let idx = id.x;
    let n = params.n;
    if (idx >= n / 2u) { return; }

    let step = params.step;
    let stage = params.stage;

    let j = idx & (step - 1u);
    let i = (idx - j) * 2u + j;
    let i2 = i + step;

    if (i2 >= n) { return; }

    // Ascending sort when the block bit is 0.
    let ascending = ((i / (stage * 2u)) & 1u) == 0u;
    let swap = (keys[i] > keys[i2]) == ascending;

    if (swap) {
        let tmp_k = keys[i]; keys[i] = keys[i2]; keys[i2] = tmp_k;
        let tmp_v = values[i]; values[i] = values[i2]; values[i2] = tmp_v;
    }
}
"#;
287pub const SPH_FORCE_WGSL: &str = r#"
289pub(super) struct SphForceParams {
290 n_particles: u32,
291 h: f32,
292 mu: f32,
293 _pad: u32,
294}
295
296@group(0) @binding(0) var<storage, read> positions: array<vec4<f32>>;
297@group(0) @binding(1) var<storage, read> velocities: array<vec4<f32>>;
298@group(0) @binding(2) var<storage, read> densities: array<f32>;
299@group(0) @binding(3) var<storage, read> pressures: array<f32>;
300@group(0) @binding(4) var<storage, read> masses: array<f32>;
301@group(0) @binding(5) var<storage, read_write> forces: array<vec4<f32>>;
302@group(0) @binding(6) var<uniform> params: SphForceParams;
303
304@compute @workgroup_size(64)
305pub(super) fn main(@builtin(global_invocation_id) id: vec3<u32>) {
306 let i = id.x;
307 if (i >= params.n_particles) { return; }
308
309 let pi = positions[i].xyz;
310 let vi = velocities[i].xyz;
311 var force: vec3<f32> = vec3<f32>(0.0, 0.0, 0.0);
312 let h2 = params.h * params.h;
313
314 for (var j: u32 = 0u; j < params.n_particles; j++) {
315 if (j == i) { continue; }
316 let r_vec = pi - positions[j].xyz;
317 let r2 = dot(r_vec, r_vec);
318 let r = sqrt(r2);
319 if (r < 0.0001 || r >= params.h) { continue; }
320
321 // Spiky gradient
322 let h_r = params.h - r;
323 let grad_mag = -45.0 / (3.14159265 * pow(params.h, 6.0)) * h_r * h_r;
324 let grad = r_vec / r * grad_mag;
325
326 // Pressure force
327 let p_term = -masses[j] * (pressures[i] + pressures[j]) / (2.0 * densities[j]);
328 force += p_term * grad;
329
330 // Viscosity
331 let v_lap = 45.0 / (3.14159265 * pow(params.h, 6.0)) * h_r;
332 force += params.mu * masses[j] * (velocities[j].xyz - vi) / densities[j] * v_lap;
333 }
334
335 forces[i] = vec4<f32>(force, 0.0);
336}
337"#;
/// Axis-aligned box boundary pass.
///
/// Any coordinate outside `[box_min, box_max]` is clamped onto the violated face, and the
/// matching velocity component is reflected and scaled by `restitution`. `positions[i].w`
/// is preserved and `velocities[i].w` is written as 0.0.
pub const BOUNDARY_ENFORCE_WGSL: &str = r#"
struct BoundaryParams {
    n: u32,
    box_min_x: f32,
    box_min_y: f32,
    box_min_z: f32,
    box_max_x: f32,
    box_max_y: f32,
    box_max_z: f32,
    restitution: f32,
}

@group(0) @binding(0) var<storage, read_write> positions: array<vec4<f32>>;
@group(0) @binding(1) var<storage, read_write> velocities: array<vec4<f32>>;
@group(0) @binding(2) var<uniform> params: BoundaryParams;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
    let i = id.x;
    if (i >= params.n) { return; }

    var pos = positions[i].xyz;
    var vel = velocities[i].xyz;
    let e = params.restitution;

    if (pos.x < params.box_min_x) { pos.x = params.box_min_x; vel.x = -vel.x * e; }
    if (pos.x > params.box_max_x) { pos.x = params.box_max_x; vel.x = -vel.x * e; }
    if (pos.y < params.box_min_y) { pos.y = params.box_min_y; vel.y = -vel.y * e; }
    if (pos.y > params.box_max_y) { pos.y = params.box_max_y; vel.y = -vel.y * e; }
    if (pos.z < params.box_min_z) { pos.z = params.box_min_z; vel.z = -vel.z * e; }
    if (pos.z > params.box_max_z) { pos.z = params.box_max_z; vel.z = -vel.z * e; }

    positions[i] = vec4<f32>(pos, positions[i].w);
    velocities[i] = vec4<f32>(vel, 0.0);
}
"#;
/// Expands `#include "name"` directives in `source` using the `includes` table.
///
/// The input is processed line by line, and every output line ends with `'\n'`.
/// A line whose trimmed text starts with `#include` is replaced by the registered content
/// for the name found between its first pair of double quotes. If no name can be
/// extracted, or the name is not in `includes`, the original line is kept but prefixed
/// with `// UNRESOLVED: `. Included content is inserted verbatim and is not itself
/// scanned for further includes.
pub fn resolve_includes(source: &str, includes: &HashMap<&str, &str>) -> String {
    let mut out = String::with_capacity(source.len());
    for raw in source.lines() {
        let directive = raw.trim();
        if !directive.starts_with("#include") {
            out.push_str(raw);
            out.push('\n');
            continue;
        }

        // The include name is whatever sits between the first two '"' characters.
        let body = directive
            .split_once('"')
            .and_then(|(_, after_open)| after_open.split_once('"'))
            .and_then(|(name, _)| includes.get(name));

        match body {
            Some(text) => out.push_str(text),
            None => {
                out.push_str("// UNRESOLVED: ");
                out.push_str(raw);
            }
        }
        out.push('\n');
    }
    out
}
/// Performs a cheap structural sanity check on WGSL `source`.
///
/// Returns `true` only when all of the following hold:
/// - `source` contains the substring `fn`;
/// - it has at least one `{`;
/// - its braces are properly nested: no `}` appears before a matching `{`, and every `{`
///   is eventually closed.
///
/// The previous version only compared total counts, so `"fn main() } {"` passed. This is
/// not a parser: braces inside comments are counted like any other.
pub fn validate_wgsl_structure(source: &str) -> bool {
    if !source.contains("fn") {
        return false;
    }
    let mut depth: usize = 0;
    let mut opened: usize = 0;
    for byte in source.bytes() {
        match byte {
            b'{' => {
                depth += 1;
                opened += 1;
            }
            b'}' => match depth.checked_sub(1) {
                Some(d) => depth = d,
                // A closing brace with nothing open can never be balanced later.
                None => return false,
            },
            _ => {}
        }
    }
    depth == 0 && opened > 0
}
417pub fn validate_shader_sources() -> bool {
419 SPH_DENSITY_WGSL.contains("@compute")
420 && INTEGRATE_WGSL.contains("@compute")
421 && LBM_BGK_D2Q9_WGSL.contains("@compute")
422 && CELL_LIST_WGSL.contains("@compute")
423 && SDF_COMPUTE_WGSL.contains("@compute")
424 && LBM_STREAMING_SHADER.contains("@compute")
425 && RIGID_INTEGRATE_SHADER.contains("@compute")
426 && BROADPHASE_SORT_SHADER.contains("@compute")
427 && SPH_FORCE_WGSL.contains("@compute")
428 && BOUNDARY_ENFORCE_WGSL.contains("@compute")
429}
430pub fn count_total_bindings() -> usize {
432 let all_shaders = [
433 SPH_DENSITY_WGSL,
434 INTEGRATE_WGSL,
435 LBM_BGK_D2Q9_WGSL,
436 CELL_LIST_WGSL,
437 SDF_COMPUTE_WGSL,
438 LBM_STREAMING_SHADER,
439 RIGID_INTEGRATE_SHADER,
440 BROADPHASE_SORT_SHADER,
441 SPH_FORCE_WGSL,
442 BOUNDARY_ENFORCE_WGSL,
443 ];
444 all_shaders
445 .iter()
446 .map(|s| s.matches("@binding(").count())
447 .sum()
448}
449pub(super) fn simple_hash(s: &str) -> u64 {
451 let mut h: u64 = 0xcbf2_9ce4_8422_2325;
452 for byte in s.bytes() {
453 h ^= byte as u64;
454 h = h.wrapping_mul(0x0000_0100_0000_01b3);
455 }
456 h
457}
458pub fn mock_compile_to_spirv(source: &str, entry_point: &str) -> Vec<u8> {
464 let mut out = vec![0x07u8, 0x23, 0x02, 0x03];
465 out.extend_from_slice(&[0x00, 0x01, 0x05, 0x00]);
466 let src_len = source.len() as u32;
467 out.extend_from_slice(&src_len.to_le_bytes());
468 let ep_hash = simple_hash(entry_point) as u32;
469 out.extend_from_slice(&ep_hash.to_le_bytes());
470 let bindings = source.matches("@binding(").count() as u32;
471 out.extend_from_slice(&bindings.to_le_bytes());
472 while !out.len().is_multiple_of(4) {
473 out.push(0x00);
474 }
475 out
476}
477pub(super) fn parse_workgroup_size(source: &str) -> [u32; 3] {
479 if let Some(pos) = source.find("@workgroup_size(") {
480 let rest = &source[pos + 16..];
481 if let Some(end) = rest.find(')') {
482 let inner = &rest[..end];
483 let parts: Vec<u32> = inner
484 .split(',')
485 .map(|s| s.trim().parse::<u32>().unwrap_or(1))
486 .collect();
487 return match parts.len() {
488 1 => [parts[0], 1, 1],
489 2 => [parts[0], parts[1], 1],
490 3 => [parts[0], parts[1], parts[2]],
491 _ => [1, 1, 1],
492 };
493 }
494 }
495 [1, 1, 1]
496}
497pub fn validate_shader_metadata(meta: &ShaderMetadata) -> Result<(), String> {
505 if meta.workgroup_size[0] == 0 || meta.workgroup_size[1] == 0 || meta.workgroup_size[2] == 0 {
506 return Err("workgroup_size must be non-zero in every dimension".to_string());
507 }
508 let total = meta.workgroup_size[0]
509 .saturating_mul(meta.workgroup_size[1])
510 .saturating_mul(meta.workgroup_size[2]);
511 if total > 1024 {
512 return Err(format!(
513 "total threads-per-workgroup {} exceeds 1024",
514 total
515 ));
516 }
517 if meta.bind_group_count > 4 {
518 return Err(format!(
519 "bind_group_count {} exceeds maximum of 4",
520 meta.bind_group_count
521 ));
522 }
523 if meta.entry_point.is_empty() {
524 return Err("entry_point must not be empty".to_string());
525 }
526 Ok(())
527}
#[cfg(test)]
mod tests {
    // Unit tests for the built-in WGSL sources and helper functions in the parent module,
    // plus smoke tests for the shader descriptor / registry / cache types re-exported from
    // `crate::shaders`.
    use super::*;
    use crate::shaders::AddressMode;
    use crate::shaders::BindGroupLayout;
    use crate::shaders::ComputeShaderDesc;
    use crate::shaders::DescriptorSetLayout;
    use crate::shaders::DescriptorType;
    use crate::shaders::FilterMode;
    use crate::shaders::PushConstantRange;
    use crate::shaders::RenderPassDesc;
    use crate::shaders::SamplerDesc;
    use crate::shaders::ShaderCache;
    use crate::shaders::ShaderCompilationPipeline;
    use crate::shaders::ShaderHotReloadManager;
    use crate::shaders::ShaderRegistry;
    use crate::shaders::ShaderStage;
    use crate::shaders::ShaderTemplate;
    use crate::shaders::SpecializationMap;
    use crate::shaders::SpirVModule;
    use crate::shaders::TextureFormat;
    use crate::shaders::UniformBufferDesc;

    #[test]
    fn test_shader_sources_valid() {
        assert!(validate_shader_sources());
    }

    #[test]
    fn test_shader_sph_contains_density() {
        assert!(SPH_DENSITY_WGSL.contains("density"));
    }

    #[test]
    fn test_shader_lbm_contains_bgk() {
        assert!(LBM_BGK_D2Q9_WGSL.contains("feq"));
    }

    #[test]
    fn test_template_instantiation() {
        let tmpl = ShaderTemplate::new(
            "@compute @workgroup_size({WORKGROUP_SIZE})\nfn main() { var x: {BUFFER_TYPE}; }",
        );
        let mut params = HashMap::new();
        params.insert("WORKGROUP_SIZE", "128");
        params.insert("BUFFER_TYPE", "f32");
        let result = tmpl.instantiate(&params);
        assert!(
            result.contains("@workgroup_size(128)"),
            "Workgroup size not substituted: {result}"
        );
        assert!(
            result.contains("var x: f32;"),
            "Buffer type not substituted: {result}"
        );
        assert!(
            !result.contains("{WORKGROUP_SIZE}"),
            "Placeholder not removed"
        );
    }

    #[test]
    fn test_template_unknown_placeholder() {
        let tmpl = ShaderTemplate::new("fn main() { var x: {UNKNOWN}; }");
        let params = HashMap::new();
        let result = tmpl.instantiate(&params);
        assert!(
            result.contains("{UNKNOWN}"),
            "Unknown placeholder should be preserved"
        );
    }

    #[test]
    fn test_shader_registry_register_get() {
        let mut reg = ShaderRegistry::new();
        assert!(reg.is_empty());
        let desc = ComputeShaderDesc::new("main", [64, 1, 1], LBM_STREAMING_SHADER);
        reg.register("lbm_stream", desc);
        assert_eq!(reg.len(), 1);
        let retrieved = reg.get("lbm_stream").expect("Shader should be present");
        assert_eq!(retrieved.entry_point, "main");
        assert_eq!(retrieved.workgroup_size, [64, 1, 1]);
        assert!(retrieved.source.contains("@compute"));
    }

    #[test]
    fn test_shader_registry_builtins() {
        let reg = ShaderRegistry::with_builtins();
        assert!(reg.get("sph_density").is_some());
        assert!(reg.get("lbm_streaming").is_some());
        assert!(reg.get("rigid_integrate").is_some());
        assert!(reg.get("broadphase_sort").is_some());
        assert!(reg.get("sph_force").is_some());
        assert!(reg.get("boundary_enforce").is_some());
        assert!(reg.get("integrate").is_some());
        assert!(reg.get("lbm_bgk_d2q9").is_some());
    }

    #[test]
    fn test_validate_wgsl_structure_valid() {
        let src = "@compute @workgroup_size(64)\nfn main() { let x = 1u; }";
        assert!(
            validate_wgsl_structure(src),
            "Valid WGSL should pass validation"
        );
    }

    #[test]
    fn test_validate_wgsl_structure_unbalanced() {
        let src = "fn main() { let x = 1u;";
        assert!(
            !validate_wgsl_structure(src),
            "Unbalanced WGSL should fail validation"
        );
    }

    #[test]
    fn test_all_builtins_pass_structural_validation() {
        let shaders = [
            SPH_DENSITY_WGSL,
            INTEGRATE_WGSL,
            LBM_BGK_D2Q9_WGSL,
            CELL_LIST_WGSL,
            SDF_COMPUTE_WGSL,
            LBM_STREAMING_SHADER,
            RIGID_INTEGRATE_SHADER,
            BROADPHASE_SORT_SHADER,
            SPH_FORCE_WGSL,
            BOUNDARY_ENFORCE_WGSL,
        ];
        for (i, src) in shaders.iter().enumerate() {
            assert!(
                validate_wgsl_structure(src),
                "Shader {i} failed structural validation"
            );
        }
    }

    #[test]
    fn test_template_placeholders() {
        let tmpl = ShaderTemplate::new("fn main() { var x: {TYPE}; var y: {SIZE}; }");
        let placeholders = tmpl.placeholders();
        assert!(placeholders.contains(&"TYPE".to_string()));
        assert!(placeholders.contains(&"SIZE".to_string()));
    }

    #[test]
    fn test_template_all_placeholders_provided() {
        let tmpl = ShaderTemplate::new("fn main() { var x: {TYPE}; }");
        let mut params = HashMap::new();
        assert!(!tmpl.all_placeholders_provided(&params));
        params.insert("TYPE", "f32");
        assert!(tmpl.all_placeholders_provided(&params));
    }

    #[test]
    fn test_specialization_map() {
        let mut sm = SpecializationMap::new();
        assert!(sm.is_empty());
        sm.define("WORKGROUP_SIZE", "64", "Threads per workgroup");
        sm.define("MAX_NEIGHBORS", "128", "Max neighbors per particle");
        assert_eq!(sm.len(), 2);
        assert_eq!(sm.get("WORKGROUP_SIZE"), Some("64"));
        sm.set("WORKGROUP_SIZE", "128");
        assert_eq!(sm.get("WORKGROUP_SIZE"), Some("128"));
        assert_eq!(sm.get("MAX_NEIGHBORS"), Some("128"));
    }

    #[test]
    fn test_specialization_map_apply() {
        let mut sm = SpecializationMap::new();
        sm.define("WG", "64", "workgroup");
        sm.set("WG", "128");
        let source = "const WG = 64;\nfn main() {}";
        let result = sm.apply(source);
        assert!(
            result.contains("const WG = 128;"),
            "specialization not applied: {result}"
        );
    }

    #[test]
    fn test_resolve_includes() {
        let mut includes = HashMap::new();
        includes.insert("common.wgsl", "// common code\nfn helper() { }");
        let source = "#include \"common.wgsl\"\nfn main() { }";
        let resolved = resolve_includes(source, &includes);
        assert!(
            resolved.contains("common code"),
            "include not resolved: {resolved}"
        );
        assert!(resolved.contains("fn main()"));
    }

    #[test]
    fn test_resolve_includes_unresolved() {
        let includes = HashMap::new();
        let source = "#include \"missing.wgsl\"\nfn main() { }";
        let resolved = resolve_includes(source, &includes);
        assert!(
            resolved.contains("UNRESOLVED"),
            "unresolved include should be marked"
        );
    }

    #[test]
    fn test_shader_cache() {
        let mut cache = ShaderCache::new();
        assert!(cache.is_empty());
        let source = cache.get_or_insert("test", || "fn main() { }".to_string());
        assert_eq!(source, "fn main() { }");
        assert!(cache.contains("test"));
        assert_eq!(cache.len(), 1);
        cache.remove("test");
        assert!(!cache.contains("test"));
    }

    #[test]
    fn test_shader_cache_returns_cached() {
        let mut cache = ShaderCache::new();
        cache.get_or_insert("test", || "first".to_string());
        let second = cache.get_or_insert("test", || "second".to_string());
        assert_eq!(second, "first", "cache should return first insertion");
    }

    #[test]
    fn test_compute_shader_desc_binding_count() {
        let desc = ComputeShaderDesc::new("main", [64, 1, 1], SPH_DENSITY_WGSL);
        assert!(desc.binding_count() > 0);
        assert_eq!(desc.threads_per_workgroup(), 64);
    }

    #[test]
    fn test_registry_unregister() {
        let mut reg = ShaderRegistry::new();
        let desc = ComputeShaderDesc::new("main", [64, 1, 1], "fn main() { }");
        reg.register("test", desc);
        assert!(reg.contains("test"));
        reg.unregister("test");
        assert!(!reg.contains("test"));
    }

    #[test]
    fn test_registry_names() {
        let mut reg = ShaderRegistry::new();
        reg.register(
            "a",
            ComputeShaderDesc::new("main", [64, 1, 1], "fn main() { }"),
        );
        reg.register(
            "b",
            ComputeShaderDesc::new("main", [64, 1, 1], "fn main() { }"),
        );
        let names: Vec<&str> = reg.names().collect();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"a"));
        assert!(names.contains(&"b"));
    }

    #[test]
    fn test_compilation_pipeline() {
        let mut pipeline = ShaderCompilationPipeline::new();
        let source = "@compute @workgroup_size(64)\nfn main() { let x = 1u; }";
        let result = pipeline.compile("test", source, None).unwrap();
        assert!(result.contains("fn main()"));
        assert_eq!(pipeline.cache_size(), 1);
        let result2 = pipeline.compile("test", source, None).unwrap();
        assert_eq!(result, result2);
    }

    #[test]
    fn test_compilation_pipeline_with_includes() {
        let mut pipeline = ShaderCompilationPipeline::new();
        pipeline.add_include("common.wgsl", "// included code");
        let source = "#include \"common.wgsl\"\n@compute @workgroup_size(64)\nfn main() { }";
        let result = pipeline.compile("test_inc", source, None).unwrap();
        assert!(result.contains("included code"));
    }

    #[test]
    fn test_compilation_pipeline_invalid_source() {
        let mut pipeline = ShaderCompilationPipeline::new();
        let source = "no functions here";
        let result = pipeline.compile("bad", source, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_count_total_bindings() {
        let total = count_total_bindings();
        assert!(
            total > 10,
            "should have many bindings across all shaders, got {total}"
        );
    }

    #[test]
    fn test_sph_force_shader_valid() {
        assert!(validate_wgsl_structure(SPH_FORCE_WGSL));
        assert!(SPH_FORCE_WGSL.contains("@compute"));
        assert!(SPH_FORCE_WGSL.contains("forces"));
    }

    #[test]
    fn test_boundary_enforce_shader_valid() {
        assert!(validate_wgsl_structure(BOUNDARY_ENFORCE_WGSL));
        assert!(BOUNDARY_ENFORCE_WGSL.contains("@compute"));
        assert!(BOUNDARY_ENFORCE_WGSL.contains("restitution"));
    }

    #[test]
    fn test_registry_with_builtins_count() {
        let reg = ShaderRegistry::with_builtins();
        assert!(
            reg.len() >= 8,
            "should have at least 8 built-in shaders, got {}",
            reg.len()
        );
    }

    #[test]
    fn test_uniform_buffer_binding_desc() {
        let desc = UniformBufferDesc::new("SceneParams", 0, 0, 256);
        assert_eq!(desc.name, "SceneParams");
        assert_eq!(desc.group, 0);
        assert_eq!(desc.binding, 0);
        assert_eq!(desc.size_bytes, 256);
    }

    #[test]
    fn test_uniform_buffer_binding_layout() {
        let mut layout = BindGroupLayout::new();
        layout.add_uniform("CameraUbo", 0, 0, 64);
        layout.add_uniform("LightUbo", 0, 1, 128);
        assert_eq!(layout.binding_count(), 2);
    }

    #[test]
    fn test_push_constant_range() {
        let range = PushConstantRange::new(0, 128, ShaderStage::Compute);
        assert_eq!(range.offset, 0);
        assert_eq!(range.size, 128);
        assert_eq!(range.stage, ShaderStage::Compute);
    }

    #[test]
    fn test_push_constants_exceed_limit_detection() {
        let range = PushConstantRange::new(0, 256, ShaderStage::Compute);
        assert!(range.size > 128, "Detected large push constant range");
    }

    #[test]
    fn test_descriptor_set_layout_storage() {
        let mut dsl = DescriptorSetLayout::new(0);
        dsl.add_storage_buffer(0, ShaderStage::Compute, false);
        dsl.add_storage_buffer(1, ShaderStage::Compute, true);
        assert_eq!(dsl.bindings.len(), 2);
        assert!(!dsl.bindings[0].read_only);
        assert!(dsl.bindings[1].read_only);
    }

    #[test]
    fn test_descriptor_set_layout_uniform() {
        let mut dsl = DescriptorSetLayout::new(0);
        dsl.add_uniform_buffer(2, ShaderStage::Vertex);
        assert_eq!(dsl.bindings.len(), 1);
        assert_eq!(dsl.bindings[0].binding, 2);
        assert_eq!(
            dsl.bindings[0].descriptor_type,
            DescriptorType::UniformBuffer
        );
    }

    #[test]
    fn test_sampler_descriptor_linear() {
        let sd = SamplerDesc::linear();
        assert_eq!(sd.filter_min, FilterMode::Linear);
        assert_eq!(sd.filter_mag, FilterMode::Linear);
        assert_eq!(sd.address_mode, AddressMode::ClampToEdge);
    }

    #[test]
    fn test_sampler_descriptor_nearest() {
        let sd = SamplerDesc::nearest();
        assert_eq!(sd.filter_min, FilterMode::Nearest);
        assert_eq!(sd.filter_mag, FilterMode::Nearest);
    }

    #[test]
    fn test_sampler_descriptor_repeat() {
        let sd = SamplerDesc {
            address_mode: AddressMode::Repeat,
            ..SamplerDesc::linear()
        };
        assert_eq!(sd.address_mode, AddressMode::Repeat);
    }

    #[test]
    fn test_render_pass_desc_color_attachment() {
        let rp = RenderPassDesc::new_simple_color();
        assert_eq!(rp.color_attachments.len(), 1);
        assert_eq!(rp.color_attachments[0].format, TextureFormat::Rgba8Unorm);
    }

    #[test]
    fn test_render_pass_desc_depth_attachment() {
        let rp = RenderPassDesc::new_with_depth();
        let depth = rp
            .depth_attachment
            .expect("new_with_depth should provide a depth attachment");
        assert_eq!(depth.format, TextureFormat::Depth32Float);
    }

    #[test]
    fn test_render_pass_desc_no_depth() {
        let rp = RenderPassDesc::new_simple_color();
        assert!(rp.depth_attachment.is_none());
    }

    #[test]
    fn test_hot_reload_manager_watch() {
        let mut mgr = ShaderHotReloadManager::new();
        mgr.watch("particles.wgsl", "fn main() { }");
        assert!(mgr.is_watched("particles.wgsl"));
    }

    #[test]
    fn test_hot_reload_manager_update() {
        let mut mgr = ShaderHotReloadManager::new();
        mgr.watch("test.wgsl", "fn main() { let x = 1u; }");
        let changed = mgr.update("test.wgsl", "@compute\nfn main() { let x = 2u; }");
        assert!(changed);
        assert!(mgr.get_source("test.wgsl").unwrap().contains("2u"));
    }

    #[test]
    fn test_hot_reload_manager_no_change() {
        let src = "fn main() { }";
        let mut mgr = ShaderHotReloadManager::new();
        mgr.watch("test.wgsl", src);
        let changed = mgr.update("test.wgsl", src);
        assert!(!changed, "Identical content should not trigger a reload");
    }

    #[test]
    fn test_hot_reload_manager_unwatch() {
        let mut mgr = ShaderHotReloadManager::new();
        mgr.watch("a.wgsl", "fn main() {}");
        mgr.unwatch("a.wgsl");
        assert!(!mgr.is_watched("a.wgsl"));
    }

    #[test]
    fn test_spirv_compilation_mock_produces_bytes() {
        let src = "@compute @workgroup_size(64)\nfn main() { let x = 1u; }";
        let spirv = mock_compile_to_spirv(src, "main");
        assert!(
            !spirv.is_empty(),
            "Mock SPIR-V should produce non-empty bytes"
        );
        assert_eq!(spirv[0], 0x07, "SPIR-V first byte mismatch");
    }

    #[test]
    fn test_spirv_compilation_deterministic() {
        let src = "@compute @workgroup_size(64)\nfn main() { }";
        let s1 = mock_compile_to_spirv(src, "main");
        let s2 = mock_compile_to_spirv(src, "main");
        assert_eq!(s1, s2, "Compilation should be deterministic");
    }

    #[test]
    fn test_spirv_module_reflection() {
        let src = "@group(0) @binding(0) var<storage, read_write> buf: array<f32>;\n@compute @workgroup_size(64)\nfn main() { }";
        let module = SpirVModule::from_wgsl(src);
        assert!(module.entry_points.contains(&"main".to_string()));
        assert!(module.binding_count >= 1);
    }

    #[test]
    fn test_spirv_module_workgroup_size() {
        let src = "@compute @workgroup_size(128)\nfn compute_main() { }";
        let module = SpirVModule::from_wgsl(src);
        assert_eq!(module.workgroup_size[0], 128);
        assert!(module.entry_points.contains(&"compute_main".to_string()));
    }

    #[test]
    fn test_bind_group_layout_serialize() {
        let mut layout = BindGroupLayout::new();
        layout.add_uniform("Ubo", 0, 0, 64);
        layout.add_storage("Sbo", 0, 1, false);
        let desc = layout.to_wgsl_snippet();
        assert!(desc.contains("@binding(0)"));
        assert!(desc.contains("@binding(1)"));
    }

    #[test]
    fn test_bind_group_layout_empty() {
        let layout = BindGroupLayout::new();
        assert_eq!(layout.binding_count(), 0);
        assert!(layout.is_empty());
    }
}