// proof_engine/render/compute/mod.rs

//! GPU Compute Pipeline for Proof Engine.
//!
//! Abstracts over GPU compute shaders for:
//! - 100K+ particle simulation (position/velocity integration on GPU)
//! - Compute shader dispatch with SSBO (Shader Storage Buffer Objects)
//! - Double-buffered state for ping-pong GPU updates
//! - Indirect draw command generation from compute results
//! - GPU particle sorting (bitonic sort compute shader)
//! - Force field evaluation on the GPU
//! - Fluid simulation compute passes

// Note: This module provides the CPU-side orchestration and data structures
// for GPU compute. The GLSL compute shader sources are embedded as string constants.

use std::collections::HashMap;

// ── GpuBufferId ───────────────────────────────────────────────────────────────

/// Resource handle used by render_graph.rs.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct ResourceHandle(pub u32);

/// Handle to a buffer registered with a `ComputePipeline`.
///
/// `ComputePipeline::add_buffer` hands these out from 1 upward; descriptors that
/// have not been registered yet carry the placeholder `BufferId(0)`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct BufferId(pub u32);

/// Handle to a compute pass registered with a `ComputePipeline` (assigned from 1 upward).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct ComputePassId(pub u32);

/// Identifier for a compute pipeline.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct PipelineId(pub u32);

// ── GpuBufferDesc ─────────────────────────────────────────────────────────────

/// How a GPU buffer is used by the compute pipeline.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum BufferUsage {
    /// Shader Storage Buffer Object: read/write by compute.
    Ssbo,
    /// Uniform Buffer Object: small read-only data.
    Ubo,
    /// Indirect draw command buffer.
    IndirectDraw,
    /// Atomic counter.
    Atomic,
    /// Vertex buffer populated by compute.
    VertexOut,
}

48#[derive(Debug, Clone)]
49pub struct GpuBufferDesc {
50    pub id:     BufferId,
51    pub name:   String,
52    pub size:   usize, // bytes
53    pub usage:  BufferUsage,
54    /// Optional initial data.
55    pub data:   Option<Vec<u8>>,
56    pub dynamic: bool,
57}
58
59impl GpuBufferDesc {
60    pub fn ssbo(name: &str, size: usize) -> Self {
61        Self { id: BufferId(0), name: name.to_string(), size, usage: BufferUsage::Ssbo, data: None, dynamic: true }
62    }
63
64    pub fn ubo(name: &str, size: usize) -> Self {
65        Self { id: BufferId(0), name: name.to_string(), size, usage: BufferUsage::Ubo, data: None, dynamic: false }
66    }
67
68    pub fn indirect(name: &str, max_draws: usize) -> Self {
69        // IndirectDrawArraysCommand = 4 × u32 = 16 bytes each
70        Self::ssbo(name, max_draws * 16).with_usage(BufferUsage::IndirectDraw)
71    }
72
73    fn with_usage(mut self, usage: BufferUsage) -> Self { self.usage = usage; self }
74
75    pub fn with_data(mut self, data: Vec<u8>) -> Self { self.data = Some(data); self }
76}
77
78// ── ComputePassDesc ───────────────────────────────────────────────────────────
79
80/// Descriptor for a single compute dispatch pass.
81#[derive(Debug, Clone)]
82pub struct ComputePassDesc {
83    pub id:          ComputePassId,
84    pub name:        String,
85    pub shader_src:  String,
86    pub work_groups: [u32; 3],
87    /// (binding_point, buffer_id)
88    pub ssbo_bindings: Vec<(u32, BufferId)>,
89    pub ubo_bindings:  Vec<(u32, BufferId)>,
90    pub uniforms:      HashMap<String, ComputeUniform>,
91    /// Barrier required before next pass.
92    pub barrier:       MemoryBarrier,
93}
94
/// Value for a named (non-block) shader uniform.
#[derive(Debug, Clone)]
pub enum ComputeUniform {
    /// GLSL `float`.
    Float(f32),
    /// GLSL `vec2`.
    Vec2([f32; 2]),
    /// GLSL `vec3`.
    Vec3([f32; 3]),
    /// GLSL `vec4`.
    Vec4([f32; 4]),
    /// GLSL `int`.
    Int(i32),
    /// GLSL `uint`.
    UInt(u32),
    /// GLSL `mat4` as 16 floats (element order is up to the backend — confirm).
    Mat4([f32; 16]),
}

/// Memory barrier to issue after a pass, before the next one runs.
///
/// NOTE(review): presumably maps to the GL `*_BARRIER_BIT` flags in the backend — confirm.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MemoryBarrier {
    /// No barrier.
    None,
    /// Shader-storage writes made visible to later passes.
    ShaderStorage,
    /// Buffer-update barrier.
    Buffer,
    /// Every barrier bit.
    All,
}

114impl ComputePassDesc {
115    pub fn new(name: &str, shader_src: &str) -> Self {
116        Self {
117            id: ComputePassId(0),
118            name: name.to_string(),
119            shader_src: shader_src.to_string(),
120            work_groups: [1, 1, 1],
121            ssbo_bindings: Vec::new(),
122            ubo_bindings: Vec::new(),
123            uniforms: HashMap::new(),
124            barrier: MemoryBarrier::ShaderStorage,
125        }
126    }
127
128    pub fn dispatch(mut self, x: u32, y: u32, z: u32) -> Self {
129        self.work_groups = [x, y, z];
130        self
131    }
132
133    pub fn bind_ssbo(mut self, binding: u32, buf: BufferId) -> Self {
134        self.ssbo_bindings.push((binding, buf));
135        self
136    }
137
138    pub fn bind_ubo(mut self, binding: u32, buf: BufferId) -> Self {
139        self.ubo_bindings.push((binding, buf));
140        self
141    }
142
143    pub fn set_uniform(mut self, name: &str, v: ComputeUniform) -> Self {
144        self.uniforms.insert(name.to_string(), v);
145        self
146    }
147}
148
// ── Particle layout (matches GPU struct) ──────────────────────────────────────

/// CPU-side mirror of the GLSL `Particle` struct used by the compute shaders.
///
/// `#[repr(C)]` keeps field order and packing identical to the std430 declaration:
/// three `vec4`s (48 bytes) followed by four 4-byte scalars, for a 64-byte stride.
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct GpuParticle {
    pub position:  [f32; 4], // xyz + lifetime
    pub velocity:  [f32; 4], // xyz + age
    pub color:     [f32; 4], // rgba
    pub size:      f32,
    pub mass:      f32,
    pub flags:     u32,      // bit 0: alive, bit 1: emitting, bit 2: collides
    pub attractor: u32,      // index of attractor affecting this particle
}

// Fail the build if the CPU layout drifts from the 64-byte std430 stride of the
// GLSL `Particle` struct (3 × vec4 + 4 × 4-byte scalars).
const _: () = assert!(std::mem::size_of::<GpuParticle>() == 64);

165impl GpuParticle {
166    pub const SIZE: usize = std::mem::size_of::<GpuParticle>();
167
168    pub fn alive(pos: [f32; 3], vel: [f32; 3], color: [f32; 4], lifetime: f32, size: f32) -> Self {
169        Self {
170            position: [pos[0], pos[1], pos[2], lifetime],
171            velocity: [vel[0], vel[1], vel[2], 0.0],
172            color,
173            size,
174            mass: 1.0,
175            flags: 1,
176            attractor: 0,
177        }
178    }
179
180    pub fn is_alive(&self) -> bool { self.flags & 1 != 0 }
181    pub fn lifetime(&self) -> f32 { self.position[3] }
182    pub fn age(&self) -> f32 { self.velocity[3] }
183}
184
// ── GPU Attractor (force field element) ──────────────────────────────────────

/// CPU-side mirror of the GLSL `Attractor` struct read by the particle
/// integration shader.
///
/// Layout: three `vec4`s (48 bytes), the `atype` word and three padding words,
/// giving a 64-byte std430 stride.
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct GpuAttractor {
    pub position:  [f32; 4], // xyz + strength
    pub params:    [f32; 4], // type, falloff_start, falloff_end, rotation
    pub color:     [f32; 4],
    pub attractor_type: u32, // 0=point, 1=vortex, 2=lorenz, 3=repulse
    _pad: [u32; 3],
}

// Fail the build if the CPU layout drifts from the 64-byte GLSL `Attractor` stride.
const _: () = assert!(std::mem::size_of::<GpuAttractor>() == 64);

197impl GpuAttractor {
198    pub fn point(pos: [f32; 3], strength: f32) -> Self {
199        Self {
200            position: [pos[0], pos[1], pos[2], strength],
201            params: [0.0, 0.5, 5.0, 0.0],
202            color: [1.0; 4],
203            attractor_type: 0,
204            _pad: [0; 3],
205        }
206    }
207
208    pub fn vortex(pos: [f32; 3], strength: f32, rotation: f32) -> Self {
209        let mut a = Self::point(pos, strength);
210        a.attractor_type = 1;
211        a.params[3] = rotation;
212        a
213    }
214}
215
// ── IndirectDrawCommand ───────────────────────────────────────────────────────

/// One non-indexed indirect draw command: four `u32`s, laid out like the
/// `IndirectBuffer` block written by the particle count shader
/// (`vertex_count`, `instance_count`, `first_vertex`, `base_instance`).
#[derive(Clone, Copy, Debug)]
#[repr(C)]
pub struct IndirectDrawCommand {
    /// Vertices to draw.
    pub count: u32,
    /// Instances to draw.
    pub prim_count: u32,
    /// First vertex.
    pub first: u32,
    /// Base instance.
    pub base_inst: u32,
}

227impl IndirectDrawCommand {
228    pub fn new(count: u32) -> Self {
229        Self { count, prim_count: 1, first: 0, base_inst: 0 }
230    }
231}
232
// ── GLSL Shader Sources ───────────────────────────────────────────────────────

/// Particle integration compute shader (GLSL 4.30+).
///
/// One invocation per particle (`local_size_x = 256`). Bindings:
/// - SSBO 0: `Particle particles[]`, updated in place.
/// - SSBO 1: `Attractor attractors[]`, read-only.
/// - SSBO 2: dead list (`uint dead_count; uint dead_indices[];`). Expired particles
///   have their alive bit cleared and their index appended here.
/// - UBO 0: `Params`, std140 layout, 48 bytes:
///   `dt`@0, `time`@4, `gravity`@16, `drag`@28, `num_particles`@32,
///   `num_attractors`@36, `emit_rate`@40.
///
/// The attractor direction is `delta / length(delta)` guarded against zero length,
/// rather than `normalize(delta)`. `normalize` of a zero vector is undefined
/// (NaN on typical drivers), so a particle sitting exactly on an attractor would
/// get a NaN velocity. A burst spawned at the origin with a point attractor at the
/// origin hits this case on the very first step.
pub const PARTICLE_INTEGRATE_GLSL: &str = r#"
#version 430 core

layout(local_size_x = 256) in;

struct Particle {
    vec4 position;  // xyz + lifetime
    vec4 velocity;  // xyz + age
    vec4 color;
    float size;
    float mass;
    uint flags;
    uint attractor;
};

struct Attractor {
    vec4 position;   // xyz + strength
    vec4 params;     // type, falloff_start, falloff_end, rotation
    vec4 color;
    uint atype;
    uint _pad[3];
};

layout(std430, binding = 0) buffer ParticleBuffer {
    Particle particles[];
};

layout(std430, binding = 1) readonly buffer AttractorBuffer {
    Attractor attractors[];
};

layout(std430, binding = 2) buffer DeadList {
    uint dead_count;
    uint dead_indices[];
};

layout(std140, binding = 0) uniform Params {
    float dt;
    float time;
    vec3 gravity;
    float drag;
    uint num_particles;
    uint num_attractors;
    float emit_rate;
    float _pad;
};

// Lorenz attractor vector field
vec3 lorenz(vec3 p, float sigma, float rho, float beta) {
    return vec3(
        sigma * (p.y - p.x),
        p.x * (rho - p.z) - p.y,
        p.x * p.y - beta * p.z
    );
}

// Vortex force
vec3 vortex_force(vec3 particle_pos, vec3 center, float strength, float rotation) {
    vec3 r = particle_pos - center;
    float d = length(r) + 0.001;
    vec3 tangent = cross(r, vec3(0.0, 1.0, 0.0)) / d;
    return tangent * strength * rotation / (d * d + 1.0);
}

vec3 compute_attractor_force(Particle p, Attractor a) {
    vec3 pos = p.position.xyz;
    vec3 apos = a.position.xyz;
    float strength = a.position.w;
    float falloff_start = a.params.y;
    float falloff_end   = a.params.z;

    vec3 delta = apos - pos;
    float len = length(delta);
    float dist = len + 0.001;
    // normalize(vec3(0.0)) is undefined; a particle exactly on the attractor gets no pull.
    vec3 dir = len > 0.0 ? delta / len : vec3(0.0);

    // Falloff
    float t = clamp((dist - falloff_start) / (falloff_end - falloff_start + 0.001), 0.0, 1.0);
    float attenuation = 1.0 - t;

    switch (a.atype) {
        case 0u: // Point attractor
            return dir * strength * attenuation / (dist * dist + 1.0);
        case 1u: // Vortex
            return vortex_force(pos, apos, strength * attenuation, a.params.w);
        case 2u: // Lorenz field
            return lorenz(pos * 0.1, 10.0, 28.0, 2.667) * strength * 0.01 * attenuation;
        case 3u: // Repulse
            return -dir * strength * attenuation / (dist * dist + 0.5);
        default:
            return vec3(0.0);
    }
}

void main() {
    uint idx = gl_GlobalInvocationID.x;
    if (idx >= num_particles) return;

    Particle p = particles[idx];
    if ((p.flags & 1u) == 0u) return; // Skip dead particles

    // Accumulate forces
    vec3 force = gravity * p.mass;

    for (uint i = 0u; i < num_attractors; i++) {
        force += compute_attractor_force(p, attractors[i]);
    }

    // Drag
    force -= p.velocity.xyz * drag;

    // Semi-implicit Euler integration
    vec3 new_vel = p.velocity.xyz + (force / p.mass) * dt;
    vec3 new_pos = p.position.xyz + new_vel * dt;

    // Age
    float new_age      = p.velocity.w + dt;
    float lifetime     = p.position.w;

    // Kill if expired
    if (new_age >= lifetime) {
        p.flags &= ~1u; // clear alive bit
        uint dead_idx = atomicAdd(dead_count, 1u);
        dead_indices[dead_idx] = idx;
    } else {
        p.position.xyz = new_pos;
        p.velocity.xyz = new_vel;
        p.velocity.w   = new_age;
    }

    particles[idx] = p;
}
"#;

/// Particle emit compute shader: spawns new particles from dead list.
///
/// One invocation per requested emission (`local_size_x = 64`), gated by
/// `emit_count`. Bindings:
/// - SSBO 0: particles.
/// - SSBO 1: dead list (`dead_count`, `dead_indices[]`).
/// - SSBO 2: `EmitBuffer`, std430: `emit_count`@0, `uvec4 emit_data[]`@16.
/// - UBO 0: `EmitParams`, std140 layout, 80 bytes:
///   `origin`@0, `spread`@12, `color_a`@16, `color_b`@32, `lifetime_min`@48,
///   `lifetime_max`@52, `speed_min`@56, `speed_max`@60, `size_min`@64,
///   `size_max`@68, `time`@72, `seed`@76.
///
/// Slots are claimed with a compare-and-swap loop that only decrements a non-zero
/// `dead_count`. An unconditional `atomicAdd(dead_count, uint(-1))` would wrap an
/// empty list to `0xFFFFFFFF`, and every later claim would then index far out of
/// bounds.
pub const PARTICLE_EMIT_GLSL: &str = r#"
#version 430 core

layout(local_size_x = 64) in;

struct Particle {
    vec4 position;
    vec4 velocity;
    vec4 color;
    float size;
    float mass;
    uint flags;
    uint attractor;
};

layout(std430, binding = 0) buffer ParticleBuffer { Particle particles[]; };
layout(std430, binding = 1) buffer DeadList       { uint dead_count; uint dead_indices[]; };
layout(std430, binding = 2) readonly buffer EmitBuffer { uint emit_count; uvec4 emit_data[]; };

layout(std140, binding = 0) uniform EmitParams {
    vec3 origin;
    float spread;
    vec4 color_a;
    vec4 color_b;
    float lifetime_min;
    float lifetime_max;
    float speed_min;
    float speed_max;
    float size_min;
    float size_max;
    float time;
    uint seed;
};

// Simple hash function for pseudo-randomness
float hash(uint n) {
    n = (n ^ 61u) ^ (n >> 16u);
    n *= 9u; n ^= n >> 4u;
    n *= 0x27d4eb2du; n ^= n >> 15u;
    return float(n) / float(0xFFFFFFFFu);
}

vec3 random_dir(uint seed) {
    float theta = hash(seed)       * 6.2831853;
    float phi   = hash(seed + 1u)  * 3.1415927;
    return vec3(sin(phi)*cos(theta), cos(phi), sin(phi)*sin(theta));
}

void main() {
    uint idx = gl_GlobalInvocationID.x;
    if (idx >= emit_count) return;

    // Claim a dead particle slot, decrementing dead_count only while it is non-zero
    // so an empty list never wraps to 0xFFFFFFFF.
    uint available = atomicAdd(dead_count, 0u);
    while (available != 0u) {
        uint observed = atomicCompSwap(dead_count, available, available - 1u);
        if (observed == available) break;
        available = observed;
    }
    if (available == 0u) return; // No dead particles available
    uint slot = dead_indices[available - 1u];

    uint s = seed + idx * 7u;
    float lifetime = mix(lifetime_min, lifetime_max, hash(s));
    float speed    = mix(speed_min,    speed_max,    hash(s + 2u));
    float psize    = mix(size_min,     size_max,     hash(s + 3u));
    vec3 dir = random_dir(s + 4u); // consumes hash(s + 4u) and hash(s + 5u)
    // s + 1u is otherwise unused; reusing s + 5u would tie the spread radius to the polar angle.
    vec3 pos = origin + dir * spread * hash(s + 1u);

    particles[slot].position = vec4(pos, lifetime);
    particles[slot].velocity = vec4(dir * speed, 0.0);
    particles[slot].color    = mix(color_a, color_b, hash(s + 6u));
    particles[slot].size     = psize;
    particles[slot].mass     = 1.0;
    particles[slot].flags    = 1u;
    particles[slot].attractor = 0u;
}
"#;

/// Indirect draw generation — count alive particles and build draw command.
///
/// Each 256-invocation workgroup works in two stages:
/// - It counts, in shared memory, the particles with the alive bit (`flags & 1`)
///   set among the first `num_particles`. `num_particles` is a plain `uniform`,
///   not a block member.
/// - It then adds its total to `instance_count` with a single `atomicAdd`.
///
/// Invocation 0 writes `vertex_count = 4`. `first_vertex` and `base_instance` are
/// never written.
///
/// The shader only ever adds to `instance_count`, so the host must zero it (and
/// should initialise `first_vertex` / `base_instance`) before each dispatch.
/// Otherwise counts from earlier frames accumulate.
pub const PARTICLE_COUNT_GLSL: &str = r#"
#version 430 core

layout(local_size_x = 256) in;

struct Particle { vec4 position; vec4 velocity; vec4 color; float size; float mass; uint flags; uint attractor; };

layout(std430, binding = 0) readonly buffer ParticleBuffer { Particle particles[]; };
layout(std430, binding = 1) buffer IndirectBuffer {
    uint vertex_count;
    uint instance_count;
    uint first_vertex;
    uint base_instance;
};

uniform uint num_particles;

shared uint local_count;

void main() {
    if (gl_LocalInvocationID.x == 0) local_count = 0;
    barrier();

    uint idx = gl_GlobalInvocationID.x;
    if (idx < num_particles && (particles[idx].flags & 1u) != 0u) {
        atomicAdd(local_count, 1u);
    }
    barrier();

    if (gl_LocalInvocationID.x == 0) {
        atomicAdd(instance_count, local_count);
    }

    if (idx == 0) vertex_count = 4u; // Billboard quad = 4 verts
}
"#;

/// Fluid simulation velocity advection compute pass.
///
/// Semi-Lagrangian step over a `grid_w` × `grid_h` grid, with 16×16 invocations
/// per workgroup. For each cell it:
/// - traces back by `-(vel_x, vel_y) * dt`, treating velocity as grid cells per
///   unit time;
/// - bilinearly samples the velocity field there, with coordinates clamped to the
///   grid edge by `idx`;
/// - scales the sample by `dissipation` and writes it to `vel_xn` / `vel_yn`.
///
/// The `Density` buffer at binding 4 is declared but never read. Reads and writes
/// target different buffers, so the caller presumably swaps `vel_x`/`vel_xn` (and
/// the y pair) between steps — confirm against the caller.
pub const FLUID_ADVECT_GLSL: &str = r#"
#version 430 core

layout(local_size_x = 16, local_size_y = 16) in;

layout(std430, binding = 0) buffer VelocityX  { float vel_x[]; };
layout(std430, binding = 1) buffer VelocityY  { float vel_y[]; };
layout(std430, binding = 2) buffer VelocityXn { float vel_xn[]; };
layout(std430, binding = 3) buffer VelocityYn { float vel_yn[]; };
layout(std430, binding = 4) readonly buffer Density { float density[]; };

uniform int grid_w;
uniform int grid_h;
uniform float dt;
uniform float dissipation;

int idx(int x, int y) { return clamp(x, 0, grid_w-1) + clamp(y, 0, grid_h-1) * grid_w; }

float sample_x(float px, float py) {
    int x0 = int(floor(px)); int x1 = x0 + 1;
    int y0 = int(floor(py)); int y1 = y0 + 1;
    float tx = fract(px); float ty = fract(py);
    return mix(mix(vel_x[idx(x0,y0)], vel_x[idx(x1,y0)], tx),
               mix(vel_x[idx(x0,y1)], vel_x[idx(x1,y1)], tx), ty);
}

float sample_y(float px, float py) {
    int x0 = int(floor(px)); int x1 = x0 + 1;
    int y0 = int(floor(py)); int y1 = y0 + 1;
    float tx = fract(px); float ty = fract(py);
    return mix(mix(vel_y[idx(x0,y0)], vel_y[idx(x1,y0)], tx),
               mix(vel_y[idx(x0,y1)], vel_y[idx(x1,y1)], tx), ty);
}

void main() {
    int x = int(gl_GlobalInvocationID.x);
    int y = int(gl_GlobalInvocationID.y);
    if (x >= grid_w || y >= grid_h) return;

    int i = idx(x, y);
    float vx = vel_x[i];
    float vy = vel_y[i];

    // Backtrace
    float px = float(x) - vx * dt;
    float py = float(y) - vy * dt;

    vel_xn[i] = sample_x(px, py) * dissipation;
    vel_yn[i] = sample_y(px, py) * dissipation;
}
"#;

/// Bitonic sort for GPU particle depth ordering (local, shared-memory stage).
///
/// Sorts `(keys[i], values[i])` pairs, i.e. `f32` depths (SSBO 0) and `u32`
/// particle indices (SSBO 1), inside 512-element tiles.
///
/// Comparator scheme:
/// - Every compare-exchange puts the key that comes first in sort order at the
///   lower index; `ascending` selects the order.
/// - Each merge stage starts with a "flip" step (partner `i ^ (block_size - 1)`),
///   followed by "disperse" steps (partner `i ^ stride`).
///
/// Padding slots at or beyond `num_elements` hold ±infinity. These comparators
/// never move padding toward lower indices, so slots past `num_elements` only
/// ever hold padding and `num_elements` need not be a power of two.
///
/// Host contract per dispatch:
/// - `block_size` is the merge-stage size `k`, a power of two.
/// - `sub_block_size` is the first stride to run: a power of two, `< 512`.
///   Passing `k / 2` runs the flip step.
/// - The dispatch then runs strides `sub_block_size, sub_block_size / 2, …, 1`
///   within each tile.
/// - Steps with stride `>= 512` cross tile boundaries. They need a global-memory
///   pass that this shader does not provide, and are skipped here.
pub const BITONIC_SORT_GLSL: &str = r#"
#version 430 core

layout(local_size_x = 512) in;

layout(std430, binding = 0) buffer Keys   { float keys[]; };   // depth values
layout(std430, binding = 1) buffer Values { uint  values[]; }; // particle indices

uniform uint num_elements;
uniform uint block_size;
uniform uint sub_block_size;
uniform bool ascending;

shared float shared_keys[512];
shared uint  shared_vals[512];

void main() {
    uint gid = gl_GlobalInvocationID.x;
    uint lid = gl_LocalInvocationID.x;

    if (gid < num_elements) {
        shared_keys[lid] = keys[gid];
        shared_vals[lid] = values[gid];
    } else {
        // Padding must sort after every real key so it never leaves the tail.
        float inf = uintBitsToFloat(0x7F800000u);
        shared_keys[lid] = ascending ? inf : -inf;
        shared_vals[lid] = gid;
    }
    barrier();

    for (uint stride = sub_block_size; stride > 0u; stride >>= 1u) {
        // The first step of a merge stage mirrors within the block; later steps are plain XOR.
        uint mask = (stride * 2u == block_size) ? (block_size - 1u) : stride;
        uint partner = lid ^ mask;
        if (partner > lid && partner < 512u) {
            float ka = shared_keys[lid];
            float kb = shared_keys[partner];
            bool out_of_order = ascending ? (ka > kb) : (ka < kb);
            if (out_of_order) {
                shared_keys[lid] = kb;
                shared_keys[partner] = ka;

                uint tmp_v = shared_vals[lid];
                shared_vals[lid] = shared_vals[partner];
                shared_vals[partner] = tmp_v;
            }
        }
        barrier();
    }

    if (gid < num_elements) {
        keys[gid]   = shared_keys[lid];
        values[gid] = shared_vals[lid];
    }
}
"#;

593// ── ComputePipeline ───────────────────────────────────────────────────────────
594
595/// Manages a set of compute passes as a pipeline.
596pub struct ComputePipeline {
597    pub name:    String,
598    passes:      Vec<ComputePassDesc>,
599    buffers:     Vec<GpuBufferDesc>,
600    next_buf_id: u32,
601    next_pass_id: u32,
602    pub enabled: bool,
603    /// Execution order (pass indices).
604    pub order:   Vec<usize>,
605}
606
607impl ComputePipeline {
608    pub fn new(name: &str) -> Self {
609        Self { name: name.to_string(), passes: Vec::new(), buffers: Vec::new(),
610               next_buf_id: 1, next_pass_id: 1, enabled: true, order: Vec::new() }
611    }
612
613    pub fn add_buffer(&mut self, mut desc: GpuBufferDesc) -> BufferId {
614        let id = BufferId(self.next_buf_id);
615        self.next_buf_id += 1;
616        desc.id = id;
617        self.buffers.push(desc);
618        id
619    }
620
621    pub fn add_pass(&mut self, mut desc: ComputePassDesc) -> ComputePassId {
622        let id = ComputePassId(self.next_pass_id);
623        self.next_pass_id += 1;
624        desc.id = id;
625        let idx = self.passes.len();
626        self.passes.push(desc);
627        self.order.push(idx);
628        id
629    }
630
631    pub fn pass(&self, id: ComputePassId) -> Option<&ComputePassDesc> {
632        self.passes.iter().find(|p| p.id == id)
633    }
634
635    pub fn buffer(&self, id: BufferId) -> Option<&GpuBufferDesc> {
636        self.buffers.iter().find(|b| b.id == id)
637    }
638
639    pub fn buffer_by_name(&self, name: &str) -> Option<&GpuBufferDesc> {
640        self.buffers.iter().find(|b| b.name == name)
641    }
642
643    pub fn total_buffer_size(&self) -> usize {
644        self.buffers.iter().map(|b| b.size).sum()
645    }
646}
647
648// ── GpuParticleSystem ─────────────────────────────────────────────────────────
649
650/// Complete GPU particle system configuration and state.
651pub struct GpuParticleSystem {
652    pub pipeline: ComputePipeline,
653    pub max_particles: usize,
654    pub particle_buf_a: BufferId,
655    pub particle_buf_b: BufferId,
656    pub attractor_buf:  BufferId,
657    pub dead_list_buf:  BufferId,
658    pub indirect_buf:   BufferId,
659    pub params_ubo:     BufferId,
660    pub integrate_pass: ComputePassId,
661    pub emit_pass:      ComputePassId,
662    pub count_pass:     ComputePassId,
663    pub sort_pass:      ComputePassId,
664    /// Which buffer is "current" (ping/pong).
665    pub frame:          u64,
666    // CPU-side particle state for initial upload
667    pub initial_particles: Vec<GpuParticle>,
668    pub attractors:        Vec<GpuAttractor>,
669    pub gravity:           [f32; 3],
670    pub drag:              f32,
671    pub emit_rate:         f32,
672    pub do_sort:           bool,
673}
674
675impl GpuParticleSystem {
676    /// Build a complete 100K particle system pipeline.
677    pub fn new(max_particles: usize) -> Self {
678        let mut pipeline = ComputePipeline::new("gpu_particles");
679
680        // Buffers
681        let particle_size = GpuParticle::SIZE * max_particles;
682        let attractor_size = std::mem::size_of::<GpuAttractor>() * 64;
683        let dead_size = 4 + 4 * max_particles; // count + indices
684        let indirect_size = std::mem::size_of::<IndirectDrawCommand>();
685        let params_size = 64; // Params UBO
686
687        let particle_buf_a = pipeline.add_buffer(GpuBufferDesc::ssbo("particles_a", particle_size));
688        let particle_buf_b = pipeline.add_buffer(GpuBufferDesc::ssbo("particles_b", particle_size));
689        let attractor_buf  = pipeline.add_buffer(GpuBufferDesc::ssbo("attractors", attractor_size));
690        let dead_list_buf  = pipeline.add_buffer(GpuBufferDesc::ssbo("dead_list", dead_size));
691        let indirect_buf   = pipeline.add_buffer(GpuBufferDesc::indirect("indirect", 1));
692        let params_ubo     = pipeline.add_buffer(GpuBufferDesc::ubo("params", params_size));
693
694        // Work groups: 256 threads per group, ceil(N/256) groups
695        let integrate_groups = ((max_particles + 255) / 256) as u32;
696
697        let integrate_pass = pipeline.add_pass(
698            ComputePassDesc::new("integrate", PARTICLE_INTEGRATE_GLSL)
699                .dispatch(integrate_groups, 1, 1)
700                .bind_ssbo(0, particle_buf_a)
701                .bind_ssbo(1, attractor_buf)
702                .bind_ssbo(2, dead_list_buf)
703                .bind_ubo(0, params_ubo)
704        );
705
706        let emit_pass = pipeline.add_pass(
707            ComputePassDesc::new("emit", PARTICLE_EMIT_GLSL)
708                .dispatch(4, 1, 1)
709                .bind_ssbo(0, particle_buf_a)
710                .bind_ssbo(1, dead_list_buf)
711                .bind_ubo(0, params_ubo)
712        );
713
714        let count_pass = pipeline.add_pass(
715            ComputePassDesc::new("count", PARTICLE_COUNT_GLSL)
716                .dispatch(integrate_groups, 1, 1)
717                .bind_ssbo(0, particle_buf_a)
718                .bind_ssbo(1, indirect_buf)
719                .set_uniform("num_particles", ComputeUniform::UInt(max_particles as u32))
720        );
721
722        let sort_pass = pipeline.add_pass(
723            ComputePassDesc::new("sort", BITONIC_SORT_GLSL)
724                .dispatch((max_particles / 512 + 1) as u32, 1, 1)
725                .bind_ssbo(0, particle_buf_a)
726        );
727
728        Self {
729            pipeline,
730            max_particles,
731            particle_buf_a, particle_buf_b,
732            attractor_buf, dead_list_buf, indirect_buf, params_ubo,
733            integrate_pass, emit_pass, count_pass, sort_pass,
734            frame: 0,
735            initial_particles: Vec::new(),
736            attractors: Vec::new(),
737            gravity: [0.0, -9.81, 0.0],
738            drag: 0.02,
739            emit_rate: 1000.0,
740            do_sort: false,
741        }
742    }
743
744    /// Add an attractor.
745    pub fn add_attractor(&mut self, a: GpuAttractor) {
746        self.attractors.push(a);
747    }
748
749    /// Spawn initial particles (CPU-side setup for upload).
750    pub fn spawn_burst(&mut self, origin: [f32; 3], count: usize, speed: f32, lifetime: f32) {
751        for i in 0..count {
752            let theta = i as f32 * 2.399963; // golden angle
753            let phi   = (i as f32 / count as f32).acos();
754            let vel = [phi.sin() * theta.cos() * speed,
755                       phi.cos() * speed,
756                       phi.sin() * theta.sin() * speed];
757            self.initial_particles.push(GpuParticle::alive(
758                origin, vel, [1.0, 0.8, 0.2, 1.0], lifetime, 2.0,
759            ));
760        }
761    }
762
763    pub fn advance_frame(&mut self) { self.frame += 1; }
764
765    /// Current particle buffer (ping-pong).
766    pub fn current_buffer(&self) -> BufferId {
767        if self.frame % 2 == 0 { self.particle_buf_a } else { self.particle_buf_b }
768    }
769
770    /// Build dispatch parameters for this frame.
771    pub fn frame_params(&self, dt: f32) -> HashMap<String, f32> {
772        let mut p = HashMap::new();
773        p.insert("dt".to_string(), dt);
774        p.insert("time".to_string(), self.frame as f32 * dt);
775        p.insert("gravity_x".to_string(), self.gravity[0]);
776        p.insert("gravity_y".to_string(), self.gravity[1]);
777        p.insert("gravity_z".to_string(), self.gravity[2]);
778        p.insert("drag".to_string(), self.drag);
779        p.insert("num_particles".to_string(), self.max_particles as f32);
780        p.insert("emit_rate".to_string(), self.emit_rate);
781        p
782    }
783}
784
// ── Presets ───────────────────────────────────────────────────────────────────

/// Namespace for ready-made `GpuParticleSystem` configurations.
pub struct ComputePresets;

789impl ComputePresets {
790    /// 100K chaos field: mix of point attractors + Lorenz
791    pub fn chaos_field() -> GpuParticleSystem {
792        let mut sys = GpuParticleSystem::new(100_000);
793        sys.add_attractor(GpuAttractor::point([0.0, 0.0, 0.0], 5.0));
794        sys.add_attractor(GpuAttractor::vortex([10.0, 0.0, 0.0], 3.0, 2.0));
795        sys.add_attractor(GpuAttractor::vortex([-10.0, 0.0, 0.0], 3.0, -2.0));
796        sys.spawn_burst([0.0; 3], 50_000, 0.5, 5.0);
797        sys.gravity = [0.0; 3];
798        sys.drag = 0.01;
799        sys
800    }
801
802    /// Fireworks burst: particles explode from center with gravity
803    pub fn fireworks() -> GpuParticleSystem {
804        let mut sys = GpuParticleSystem::new(50_000);
805        sys.spawn_burst([0.0, 0.0, 0.0], 50_000, 5.0, 3.0);
806        sys.gravity = [0.0, -9.81, 0.0];
807        sys.drag = 0.05;
808        sys
809    }
810
811    /// Fluid simulation particles
812    pub fn fluid_particles() -> GpuParticleSystem {
813        let mut sys = GpuParticleSystem::new(200_000);
814        sys.do_sort = true;
815        sys.gravity = [0.0, -2.0, 0.0];
816        sys.drag = 0.1;
817        sys
818    }
819}
820
// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gpu_particle_size() {
        // 3 × vec4 (48 bytes) + size, mass, flags, attractor (16 bytes) = 64-byte
        // std430 stride. A 48-byte expectation would drop the trailing scalars.
        assert_eq!(GpuParticle::SIZE, 64, "Particle must be exactly 64 bytes for std430");
    }

    #[test]
    fn test_gpu_attractor_size() {
        // 3 × vec4 + atype + 3 padding words, matching the GLSL `Attractor` struct.
        assert_eq!(std::mem::size_of::<GpuAttractor>(), 64);
    }

    #[test]
    fn test_pipeline_builds() {
        let sys = GpuParticleSystem::new(1024);
        assert!(sys.pipeline.total_buffer_size() > 0);
        assert_eq!(sys.pipeline.passes.len(), 4);
    }

    #[test]
    fn test_particle_alive() {
        let p = GpuParticle::alive([1.0, 2.0, 3.0], [0.1, 0.2, 0.3], [1.0; 4], 5.0, 2.0);
        assert!(p.is_alive());
        assert!((p.lifetime() - 5.0).abs() < 0.001);
    }

    #[test]
    fn test_attractor_types() {
        let a = GpuAttractor::point([0.0; 3], 10.0);
        assert_eq!(a.attractor_type, 0);
        let v = GpuAttractor::vortex([1.0, 0.0, 0.0], 5.0, 1.5);
        assert_eq!(v.attractor_type, 1);
    }

    #[test]
    fn test_chaos_field_preset() {
        let sys = ComputePresets::chaos_field();
        assert_eq!(sys.max_particles, 100_000);
        assert_eq!(sys.attractors.len(), 3);
        assert!(!sys.initial_particles.is_empty());
    }

    #[test]
    fn test_frame_params() {
        let sys = GpuParticleSystem::new(1024);
        let params = sys.frame_params(0.016);
        assert!((params["dt"] - 0.016).abs() < 0.0001);
        assert_eq!(params["num_particles"] as usize, 1024);
    }

    #[test]
    fn test_spawn_burst() {
        let mut sys = GpuParticleSystem::new(10_000);
        sys.spawn_burst([0.0; 3], 100, 1.0, 3.0);
        assert_eq!(sys.initial_particles.len(), 100);
        for p in &sys.initial_particles {
            assert!(p.is_alive());
        }
    }

    #[test]
    fn test_pipeline_buffers() {
        let sys = GpuParticleSystem::new(1000);
        assert!(sys.pipeline.buffer(sys.particle_buf_a).is_some());
        assert!(sys.pipeline.buffer(sys.attractor_buf).is_some());
        assert!(sys.pipeline.buffer(sys.indirect_buf).is_some());
    }

    #[test]
    fn test_shader_sources_not_empty() {
        assert!(!PARTICLE_INTEGRATE_GLSL.is_empty());
        assert!(!PARTICLE_EMIT_GLSL.is_empty());
        assert!(!FLUID_ADVECT_GLSL.is_empty());
        assert!(!BITONIC_SORT_GLSL.is_empty());
    }

    #[test]
    fn test_compute_pass_desc_builder() {
        let buf_id = BufferId(1);
        let pass = ComputePassDesc::new("test", "#version 430")
            .dispatch(32, 1, 1)
            .bind_ssbo(0, buf_id)
            .set_uniform("num_particles", ComputeUniform::UInt(1000));
        assert_eq!(pass.work_groups, [32, 1, 1]);
        assert_eq!(pass.ssbo_bindings.len(), 1);
        assert!(pass.uniforms.contains_key("num_particles"));
    }
}