Skip to main content

proof_engine/compute/
kernels.rs

1//! Built-in compute kernels as embedded GLSL source strings.
2//!
3//! Each kernel is a fully implemented GLSL compute shader. The Rust side
4//! provides parameter structs and convenience methods to compile and dispatch
5//! each kernel through the `dispatch` module.
6//!
7//! Kernels:
8//! 1. **particle_integrate** — position += velocity * dt, apply forces, age, kill dead
9//! 2. **particle_emit** — atomic counter for birth, initialize from emitter params
10//! 3. **force_field_sample** — evaluate multiple force fields at particle positions
//! 4. **math_function_gpu** — Lorenz, Rossler and Aizawa attractors; Mandelbrot and Julia set iteration
12//! 5. **fluid_diffuse** — Jacobi iteration for diffusion
13//! 6. **histogram_equalize** — compute histogram and equalize
14//! 7. **prefix_sum** — Blelloch parallel prefix sum (scan)
15//! 8. **radix_sort** — GPU radix sort
16//! 9. **frustum_cull** — per-instance frustum culling
17//! 10. **skinning** — bone matrix palette skinning
18
19use std::collections::HashMap;
20
21// ---------------------------------------------------------------------------
22// KernelId
23// ---------------------------------------------------------------------------
24
/// Identifier for one of the compute kernels built into this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KernelId {
    /// Kernel named `particle_integrate`.
    ParticleIntegrate,
    /// Kernel named `particle_emit`.
    ParticleEmit,
    /// Kernel named `force_field_sample`.
    ForceFieldSample,
    /// Kernel named `math_function_gpu`.
    MathFunctionGpu,
    /// Kernel named `fluid_diffuse`.
    FluidDiffuse,
    /// Kernel named `histogram_equalize`.
    HistogramEqualize,
    /// Kernel named `prefix_sum`.
    PrefixSum,
    /// Kernel named `radix_sort`.
    RadixSort,
    /// Kernel named `frustum_cull`.
    FrustumCull,
    /// Kernel named `skinning`.
    Skinning,
}

impl KernelId {
    /// Returns a static slice holding every kernel ID in declaration order.
    pub fn all() -> &'static [KernelId] {
        // Function-local table so the slice can be handed out with a 'static lifetime.
        static EVERY_KERNEL: [KernelId; 10] = [
            KernelId::ParticleIntegrate,
            KernelId::ParticleEmit,
            KernelId::ForceFieldSample,
            KernelId::MathFunctionGpu,
            KernelId::FluidDiffuse,
            KernelId::HistogramEqualize,
            KernelId::PrefixSum,
            KernelId::RadixSort,
            KernelId::FrustumCull,
            KernelId::Skinning,
        ];
        &EVERY_KERNEL
    }

    /// Returns the kernel's snake_case, human-readable name.
    pub fn name(&self) -> &'static str {
        use KernelId::*;

        // Exhaustive on purpose: adding a variant must force a name to be chosen.
        match *self {
            ParticleIntegrate => "particle_integrate",
            ParticleEmit => "particle_emit",
            ForceFieldSample => "force_field_sample",
            MathFunctionGpu => "math_function_gpu",
            FluidDiffuse => "fluid_diffuse",
            HistogramEqualize => "histogram_equalize",
            PrefixSum => "prefix_sum",
            RadixSort => "radix_sort",
            FrustumCull => "frustum_cull",
            Skinning => "skinning",
        }
    }
}
73
74// ---------------------------------------------------------------------------
75// Parameter structs
76// ---------------------------------------------------------------------------
77
/// Settings for the particle integration kernel.
#[derive(Debug, Clone, Copy)]
pub struct ParticleIntegrateParams {
    /// Time step. Defaults to `1.0 / 60.0`.
    pub dt: f32,
    /// Gravity vector. Defaults to `[0.0, -9.81, 0.0]`.
    pub gravity: [f32; 3],
    /// Damping factor. Defaults to `0.98`.
    pub damping: f32,
    /// Number of particles to process. Defaults to `0`.
    pub particle_count: u32,
    /// Maximum particle age. Defaults to `5.0`.
    pub max_age: f32,
    /// Wind vector. Defaults to all zeros.
    pub wind: [f32; 3],
    /// Turbulence strength. Defaults to `0.0`.
    pub turbulence_strength: f32,
    /// Current time value. Defaults to `0.0`.
    pub time: f32,
}

impl Default for ParticleIntegrateParams {
    /// Baseline configuration: 60 Hz step, downward gravity, no wind,
    /// no turbulence and an empty particle set.
    fn default() -> Self {
        let no_wind = [0.0_f32; 3];
        let downward = [0.0_f32, -9.81, 0.0];

        ParticleIntegrateParams {
            particle_count: 0,
            time: 0.0,
            dt: 1.0 / 60.0,
            max_age: 5.0,
            damping: 0.98,
            turbulence_strength: 0.0,
            gravity: downward,
            wind: no_wind,
        }
    }
}
105
/// Settings for the particle emission kernel.
#[derive(Debug, Clone, Copy)]
pub struct ParticleEmitParams {
    /// Number of particles to emit. Defaults to `100`.
    pub emit_count: u32,
    /// Upper bound on the particle count. Defaults to `100_000`.
    pub max_particles: u32,
    /// Emitter position. Defaults to the origin.
    pub emitter_position: [f32; 3],
    /// Emitter radius. Defaults to `0.1`.
    pub emitter_radius: f32,
    /// Lower bound of the initial speed. Defaults to `1.0`.
    pub initial_speed_min: f32,
    /// Upper bound of the initial speed. Defaults to `3.0`.
    pub initial_speed_max: f32,
    /// Initial emission direction. Defaults to `[0.0, 1.0, 0.0]`.
    pub initial_direction: [f32; 3],
    /// Spread angle around the direction. Defaults to `0.5`.
    pub spread_angle: f32,
    /// Lower bound of the particle lifetime. Defaults to `1.0`.
    pub lifetime_min: f32,
    /// Upper bound of the particle lifetime. Defaults to `3.0`.
    pub lifetime_max: f32,
    /// Current time value. Defaults to `0.0`.
    pub time: f32,
    /// Random seed. Defaults to `0`.
    pub seed: u32,
    /// Starting colour (four components). Defaults to `[1.0, 1.0, 1.0, 1.0]`.
    pub color_start: [f32; 4],
    /// Ending colour (four components). Defaults to `[1.0, 1.0, 1.0, 0.0]`.
    pub color_end: [f32; 4],
    /// Starting size. Defaults to `1.0`.
    pub size_start: f32,
    /// Ending size. Defaults to `0.0`.
    pub size_end: f32,
}

impl Default for ParticleEmitParams {
    /// Baseline emitter: 100 particles per emission from a small sphere at the
    /// origin, aimed along +Y, with colour and size running from full to zero.
    fn default() -> Self {
        let origin = [0.0_f32; 3];
        let plus_y = [0.0_f32, 1.0, 0.0];
        let opaque_white = [1.0_f32, 1.0, 1.0, 1.0];
        let clear_white = [1.0_f32, 1.0, 1.0, 0.0];

        ParticleEmitParams {
            // Counts and bookkeeping
            emit_count: 100,
            max_particles: 100_000,
            seed: 0,
            time: 0.0,
            // Emitter shape
            emitter_position: origin,
            emitter_radius: 0.1,
            initial_direction: plus_y,
            spread_angle: 0.5,
            // Randomised ranges
            initial_speed_min: 1.0,
            initial_speed_max: 3.0,
            lifetime_min: 1.0,
            lifetime_max: 3.0,
            // Appearance
            color_start: opaque_white,
            color_end: clear_white,
            size_start: 1.0,
            size_end: 0.0,
        }
    }
}
149
/// Description of a single force field for the force_field_sample kernel.
#[derive(Debug, Clone, Copy)]
pub struct ForceFieldDesc {
    /// Which kind of field this is. Defaults to [`ForceFieldType::Attractor`].
    pub field_type: ForceFieldType,
    /// Field position. Defaults to the origin.
    pub position: [f32; 3],
    /// Field strength. Defaults to `1.0`.
    pub strength: f32,
    /// Field radius. Defaults to `10.0`.
    pub radius: f32,
    /// Falloff value. Defaults to `2.0`.
    pub falloff: f32,
    /// Field direction. Defaults to `[0.0, 1.0, 0.0]`.
    pub direction: [f32; 3],
    /// Frequency value. Defaults to `1.0`.
    pub frequency: f32,
}

/// Kinds of force field.
///
/// The explicit discriminants agree with the type codes listed in the
/// force-field kernel source (0=attractor … 5=drag).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ForceFieldType {
    /// Type code 0.
    Attractor = 0,
    /// Type code 1.
    Repulsor = 1,
    /// Type code 2.
    Vortex = 2,
    /// Type code 3.
    Directional = 3,
    /// Type code 4.
    Noise = 4,
    /// Type code 5.
    Drag = 5,
}

impl Default for ForceFieldDesc {
    /// Baseline field: a unit-strength attractor at the origin with radius 10.
    fn default() -> Self {
        let origin = [0.0_f32; 3];
        let plus_y = [0.0_f32, 1.0, 0.0];

        ForceFieldDesc {
            field_type: ForceFieldType::Attractor,
            strength: 1.0,
            frequency: 1.0,
            falloff: 2.0,
            radius: 10.0,
            position: origin,
            direction: plus_y,
        }
    }
}
186
/// Math function types for the math_function_gpu kernel.
///
/// The explicit discriminants match the selector values listed beside the
/// kernel's `u_function_type` uniform (0=Lorenz, 1=Mandelbrot, 2=Julia,
/// 3=Rossler, 4=Aizawa).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MathFunctionType {
    /// Lorenz attractor (selector value 0).
    LorenzAttractor = 0,
    /// Mandelbrot iteration (selector value 1).
    MandelbrotIteration = 1,
    /// Julia set iteration (selector value 2).
    JuliaSet = 2,
    /// Rossler attractor (selector value 3).
    RosslerAttractor = 3,
    /// Aizawa attractor (selector value 4).
    AizawaAttractor = 4,
}
196
/// Settings for the Jacobi-iteration fluid diffusion kernel.
#[derive(Debug, Clone, Copy)]
pub struct FluidDiffuseParams {
    /// Grid width. Defaults to `256`.
    pub grid_width: u32,
    /// Grid height. Defaults to `256`.
    pub grid_height: u32,
    /// Diffusion rate. Defaults to `0.001`.
    pub diffusion_rate: f32,
    /// Time step. Defaults to `1.0 / 60.0`.
    pub dt: f32,
    /// Iteration count. Defaults to `20`.
    pub iterations: u32,
}

impl Default for FluidDiffuseParams {
    /// Baseline configuration: a 256x256 grid, 20 iterations, 60 Hz step.
    fn default() -> Self {
        const GRID_SIDE: u32 = 256;

        FluidDiffuseParams {
            iterations: 20,
            dt: 1.0 / 60.0,
            diffusion_rate: 0.001,
            grid_height: GRID_SIDE,
            grid_width: GRID_SIDE,
        }
    }
}
218
/// Settings for histogram equalization.
#[derive(Debug, Clone, Copy)]
pub struct HistogramParams {
    /// Input width. Defaults to `1920`.
    pub width: u32,
    /// Input height. Defaults to `1080`.
    pub height: u32,
    /// Number of histogram bins. Defaults to `256`.
    pub bin_count: u32,
    /// Lower end of the value range. Defaults to `0.0`.
    pub min_value: f32,
    /// Upper end of the value range. Defaults to `1.0`.
    pub max_value: f32,
}

impl Default for HistogramParams {
    /// Baseline configuration: 1920x1080 input, 256 bins over `0.0..=1.0`.
    fn default() -> Self {
        let (lowest, highest) = (0.0_f32, 1.0_f32);

        HistogramParams {
            bin_count: 256,
            min_value: lowest,
            max_value: highest,
            height: 1080,
            width: 1920,
        }
    }
}
240
/// Dispatch plan for the prefix sum (Blelloch algorithm).
#[derive(Debug, Clone)]
pub struct PrefixSumPlan {
    /// Number of elements to scan. Defaults to `1024`.
    pub element_count: u32,
    /// Workgroup size. Defaults to `256`.
    pub workgroup_size: u32,
    /// Whether an inclusive scan is requested. Defaults to `false`.
    pub inclusive: bool,
}

impl Default for PrefixSumPlan {
    /// Baseline plan: 1024 elements, 256-wide workgroups, not inclusive.
    fn default() -> Self {
        PrefixSumPlan {
            inclusive: false,
            workgroup_size: 256,
            element_count: 1024,
        }
    }
}
258
/// Plan for radix sort.
#[derive(Debug, Clone)]
pub struct RadixSortPlan {
    /// Number of elements to sort. Defaults to `1024`.
    pub element_count: u32,
    /// Key bits consumed by each pass. Defaults to `4`.
    pub bits_per_pass: u32,
    /// Total number of key bits to sort on. Defaults to `32`.
    pub total_bits: u32,
    /// Workgroup size. Defaults to `256`.
    pub workgroup_size: u32,
}

impl Default for RadixSortPlan {
    /// Baseline plan: 1024 elements, 32-bit keys, 4 bits per pass, 256-wide workgroups.
    fn default() -> Self {
        Self {
            element_count: 1024,
            bits_per_pass: 4,
            total_bits: 32,
            workgroup_size: 256,
        }
    }
}

impl RadixSortPlan {
    /// Number of passes needed: `total_bits / bits_per_pass`, rounded up.
    ///
    /// Returns `0` when `bits_per_pass` is `0` (no progress can be made per
    /// pass) instead of panicking on a division by zero.
    pub fn pass_count(&self) -> u32 {
        let per_pass = self.bits_per_pass;
        if per_pass == 0 {
            return 0;
        }
        // Quotient-plus-remainder form of ceiling division: unlike
        // `(a + b - 1) / b` it cannot overflow for large `total_bits`.
        self.total_bits / per_pass + u32::from(self.total_bits % per_pass != 0)
    }

    /// Number of bins per pass (2^bits_per_pass).
    ///
    /// Saturates at `u32::MAX` when `bits_per_pass >= 32`, where the exact
    /// value no longer fits in a `u32` and a plain shift would overflow.
    pub fn radix(&self) -> u32 {
        1u32.checked_shl(self.bits_per_pass).unwrap_or(u32::MAX)
    }
}
290
/// Settings for frustum culling.
#[derive(Debug, Clone, Copy)]
pub struct FrustumCullParams {
    /// Number of instances to test. Defaults to `0`.
    pub instance_count: u32,
    /// Six frustum planes, four coefficients each. Defaults to all zeros.
    pub frustum_planes: [[f32; 4]; 6],
    /// Four LOD distances. Defaults to `[50.0, 150.0, 500.0, 1000.0]`.
    pub lod_distances: [f32; 4],
    /// Camera position. Defaults to the origin.
    pub camera_position: [f32; 3],
    /// Whether LOD is enabled. Defaults to `true`.
    pub enable_lod: bool,
}

impl Default for FrustumCullParams {
    /// Baseline configuration: no instances, zeroed planes, camera at the
    /// origin and LOD enabled with four distance values.
    fn default() -> Self {
        let zeroed_plane = [0.0_f32; 4];

        FrustumCullParams {
            enable_lod: true,
            instance_count: 0,
            camera_position: [0.0; 3],
            lod_distances: [50.0, 150.0, 500.0, 1000.0],
            frustum_planes: [zeroed_plane; 6],
        }
    }
}
312
/// Settings for skeletal skinning.
#[derive(Debug, Clone, Copy)]
pub struct SkinningParams {
    /// Number of vertices to skin. Defaults to `0`.
    pub vertex_count: u32,
    /// Number of bones. Defaults to `64`.
    pub bone_count: u32,
    /// Maximum bones influencing one vertex. Defaults to `4`.
    pub max_bones_per_vertex: u32,
}

impl Default for SkinningParams {
    /// Baseline configuration: no vertices, 64 bones, 4 bones per vertex.
    fn default() -> Self {
        SkinningParams {
            max_bones_per_vertex: 4,
            bone_count: 64,
            vertex_count: 0,
        }
    }
}
330
331// ---------------------------------------------------------------------------
332// GLSL kernel sources
333// ---------------------------------------------------------------------------
334
/// Particle integration kernel: advance positions, apply forces, age, kill dead.
///
/// GLSL compute shader source with a 1-D workgroup of 256 invocations. Each
/// invocation handles the particle at `gl_GlobalInvocationID.x` and returns
/// immediately when that index is `>= u_particle_count`.
///
/// Buffers: `Particle` records (`pos_age`, `vel_life`) are read from SSBO
/// binding 0 and the updated record is written to SSBO binding 1.
///
/// Per particle the shader:
/// 1. adds `u_dt` to the age;
/// 2. if the new age is `>= lifetime` or `>= u_max_age`, writes the particle
///    back with age `lifetime + 1.0` and zero velocity, then returns. Note
///    that the in-shader comment says "setting lifetime to 0", but the code
///    marks death through the age field and leaves the lifetime untouched;
/// 3. otherwise adds `u_gravity * u_dt` and `u_wind * u_dt` to the velocity,
///    adds hash-based turbulence when `u_turbulence_strength > 0.0`, scales
///    the velocity by `pow(u_damping, u_dt)` and advances the position by
///    `vel * u_dt`.
///
/// The source contains no `#version` directive — presumably the dispatch
/// code prepends one; TODO confirm.
pub const KERNEL_PARTICLE_INTEGRATE: &str = r#"
// Particle integration kernel
// Reads from SSBO binding 0, writes to SSBO binding 1 (ping-pong).
// Each particle: vec4 position (xyz + age), vec4 velocity (xyz + lifetime).

layout(local_size_x = 256) in;

struct Particle {
    vec4 pos_age;    // xyz = position, w = age
    vec4 vel_life;   // xyz = velocity, w = lifetime
};

layout(std430, binding = 0) readonly buffer ParticlesIn {
    Particle particles_in[];
};

layout(std430, binding = 1) writeonly buffer ParticlesOut {
    Particle particles_out[];
};

uniform float u_dt;
uniform vec3 u_gravity;
uniform float u_damping;
uniform uint u_particle_count;
uniform float u_max_age;
uniform vec3 u_wind;
uniform float u_turbulence_strength;
uniform float u_time;

// Simple hash for turbulence
float hash(vec3 p) {
    p = fract(p * vec3(443.897, 441.423, 437.195));
    p += dot(p, p.yzx + 19.19);
    return fract((p.x + p.y) * p.z);
}

vec3 turbulence(vec3 pos, float time) {
    float n1 = hash(pos + vec3(time * 0.3));
    float n2 = hash(pos + vec3(time * 0.7, 0.0, 0.0));
    float n3 = hash(pos + vec3(0.0, time * 0.5, 0.0));
    return vec3(n1 - 0.5, n2 - 0.5, n3 - 0.5) * 2.0;
}

void main() {
    uint idx = gl_GlobalInvocationID.x;
    if (idx >= u_particle_count) return;

    Particle p = particles_in[idx];

    float age = p.pos_age.w;
    float lifetime = p.vel_life.w;

    // Advance age
    age += u_dt;

    // Kill dead particles by setting lifetime to 0
    if (age >= lifetime || age >= u_max_age) {
        p.pos_age.w = lifetime + 1.0; // Mark as dead
        p.vel_life.xyz = vec3(0.0);
        particles_out[idx] = p;
        return;
    }

    // Apply forces
    vec3 vel = p.vel_life.xyz;
    vec3 pos = p.pos_age.xyz;

    // Gravity
    vel += u_gravity * u_dt;

    // Wind
    vel += u_wind * u_dt;

    // Turbulence
    if (u_turbulence_strength > 0.0) {
        vec3 turb = turbulence(pos * 0.1, u_time);
        vel += turb * u_turbulence_strength * u_dt;
    }

    // Damping
    vel *= pow(u_damping, u_dt);

    // Integrate position
    pos += vel * u_dt;

    // Write output
    particles_out[idx].pos_age = vec4(pos, age);
    particles_out[idx].vel_life = vec4(vel, lifetime);
}
"#;
426
/// Particle emission kernel: spawn new particles using atomic counter.
///
/// GLSL compute shader source with a 1-D workgroup of 64 invocations. Each
/// invocation with index `< u_emit_count` claims a slot by incrementing the
/// atomic counter `u_alive_count` (atomic-counter binding 0) and, when the
/// slot is `< u_max_particles`, writes a new `Particle` into SSBO binding 1.
///
/// When the claimed slot is out of range the increment is undone, so the
/// counter settles at `u_max_particles` once the buffer is full instead of
/// growing with every rejected emission.
///
/// Spawn values come from a PCG-style generator seeded from `u_seed`, the
/// invocation index and `u_time`: a position inside a sphere of radius
/// `u_emitter_radius` around `u_emitter_pos`, a direction spread around
/// `normalize(u_direction)`, a speed in `[u_speed_min, u_speed_max]` and a
/// lifetime in `[u_life_min, u_life_max]`. New particles start at age `0.0`.
///
/// `u_color_start` and `u_color_end` are declared but not read by the shader.
pub const KERNEL_PARTICLE_EMIT: &str = r#"
// Particle emission kernel
// Uses an atomic counter to allocate slots in the particle buffer.

layout(local_size_x = 64) in;

struct Particle {
    vec4 pos_age;
    vec4 vel_life;
};

layout(std430, binding = 1) writeonly buffer ParticlesOut {
    Particle particles[];
};

layout(binding = 0, offset = 0) uniform atomic_uint u_alive_count;

uniform uint u_emit_count;
uniform uint u_max_particles;
uniform vec3 u_emitter_pos;
uniform float u_emitter_radius;
uniform float u_speed_min;
uniform float u_speed_max;
uniform vec3 u_direction;
uniform float u_spread;
uniform float u_life_min;
uniform float u_life_max;
uniform float u_time;
uniform uint u_seed;
uniform vec4 u_color_start;
uniform vec4 u_color_end;

// PCG random number generator
uint pcg(uint state) {
    uint s = state * 747796405u + 2891336453u;
    uint word = ((s >> ((s >> 28u) + 4u)) ^ s) * 277803737u;
    return (word >> 22u) ^ word;
}

float rand01(inout uint seed) {
    seed = pcg(seed);
    return float(seed) / 4294967295.0;
}

vec3 random_direction(inout uint seed, vec3 dir, float spread) {
    float phi = rand01(seed) * 6.283185307;
    float cos_theta = 1.0 - rand01(seed) * spread;
    float sin_theta = sqrt(max(0.0, 1.0 - cos_theta * cos_theta));

    vec3 random_dir = vec3(
        sin_theta * cos(phi),
        sin_theta * sin(phi),
        cos_theta
    );

    // Rotate random_dir to align with dir
    vec3 up = abs(dir.y) < 0.999 ? vec3(0, 1, 0) : vec3(1, 0, 0);
    vec3 right = normalize(cross(up, dir));
    up = cross(dir, right);

    return right * random_dir.x + up * random_dir.y + dir * random_dir.z;
}

vec3 random_sphere(inout uint seed, float radius) {
    float phi = rand01(seed) * 6.283185307;
    float cos_theta = rand01(seed) * 2.0 - 1.0;
    // Clamp before sqrt (as random_direction does) so rounding in rand01
    // can never produce a negative argument and a NaN position.
    float sin_theta = sqrt(max(0.0, 1.0 - cos_theta * cos_theta));
    float r = pow(rand01(seed), 1.0 / 3.0) * radius;
    return r * vec3(sin_theta * cos(phi), sin_theta * sin(phi), cos_theta);
}

void main() {
    uint idx = gl_GlobalInvocationID.x;
    if (idx >= u_emit_count) return;

    // Allocate a slot
    uint slot = atomicCounterIncrement(u_alive_count);
    if (slot >= u_max_particles) {
        // Buffer full: undo the increment so the alive counter does not
        // stay above u_max_particles after rejected emissions.
        atomicCounterDecrement(u_alive_count);
        return;
    }

    // Seed RNG
    uint seed = u_seed + idx * 1973u + uint(u_time * 1000.0) * 7919u;

    // Random position within emitter sphere
    vec3 pos = u_emitter_pos + random_sphere(seed, u_emitter_radius);

    // Random direction with spread
    vec3 dir = random_direction(seed, normalize(u_direction), u_spread);

    // Random speed
    float speed = mix(u_speed_min, u_speed_max, rand01(seed));

    // Random lifetime
    float life = mix(u_life_min, u_life_max, rand01(seed));

    particles[slot].pos_age = vec4(pos, 0.0);
    particles[slot].vel_life = vec4(dir * speed, life);
}
"#;
526
/// Force field sampling kernel: evaluate multiple force fields at particle positions.
///
/// GLSL compute shader source with a 1-D workgroup of 256 invocations. Each
/// invocation handles one particle of the read/write SSBO at binding 0,
/// skipping indices `>= u_particle_count` and particles whose age is
/// `>= lifetime`.
///
/// For a live particle the shader sums `eval_field` over the first
/// `u_field_count` entries of the read-only `ForceField` SSBO at binding 2
/// and adds `total_force * u_dt` to the particle's velocity.
///
/// Field type codes (stored in `params.z`): 0=attractor, 1=repulsor,
/// 2=vortex, 3=directional, 4=noise, 5=drag. When `radius > 0.0` the force
/// is attenuated by `pow(1 - clamp(dist / radius, 0, 1), falloff)`.
///
/// The drag branch uses `-delta` directly rather than
/// `-normalize(delta) * dist`, which is the same vector but stays finite
/// when the particle sits exactly on the field position.
pub const KERNEL_FORCE_FIELD_SAMPLE: &str = r#"
// Force field sampling kernel
// Reads particle positions, evaluates force fields, accumulates forces into velocity.

layout(local_size_x = 256) in;

struct Particle {
    vec4 pos_age;
    vec4 vel_life;
};

layout(std430, binding = 0) buffer Particles {
    Particle particles[];
};

// Force field types: 0=attractor, 1=repulsor, 2=vortex, 3=directional, 4=noise, 5=drag
struct ForceField {
    vec4 pos_strength;   // xyz = position, w = strength
    vec4 dir_radius;     // xyz = direction, w = radius
    vec4 params;         // x = falloff, y = frequency, z = type, w = unused
};

layout(std430, binding = 2) readonly buffer ForceFields {
    ForceField fields[];
};

uniform uint u_particle_count;
uniform uint u_field_count;
uniform float u_dt;
uniform float u_time;

float hash31(vec3 p) {
    p = fract(p * vec3(443.8975, 441.4230, 437.1950));
    p += dot(p, p.yzx + 19.19);
    return fract((p.x + p.y) * p.z);
}

vec3 eval_field(ForceField f, vec3 pos, float time) {
    vec3 field_pos = f.pos_strength.xyz;
    float strength = f.pos_strength.w;
    float radius = f.dir_radius.w;
    vec3 direction = f.dir_radius.xyz;
    float falloff = f.params.x;
    float freq = f.params.y;
    int ftype = int(f.params.z);

    vec3 delta = field_pos - pos;
    float dist = length(delta);
    float atten = 1.0;
    if (radius > 0.0) {
        atten = 1.0 - clamp(dist / radius, 0.0, 1.0);
        atten = pow(atten, falloff);
    }

    vec3 force = vec3(0.0);

    if (ftype == 0) {
        // Attractor: pull toward center
        if (dist > 0.001) {
            force = normalize(delta) * strength * atten;
        }
    } else if (ftype == 1) {
        // Repulsor: push away from center
        if (dist > 0.001) {
            force = -normalize(delta) * strength * atten;
        }
    } else if (ftype == 2) {
        // Vortex: swirl around axis (direction)
        vec3 axis = normalize(direction);
        vec3 radial = delta - dot(delta, axis) * axis;
        if (length(radial) > 0.001) {
            vec3 tangent = cross(axis, normalize(radial));
            force = tangent * strength * atten;
        }
    } else if (ftype == 3) {
        // Directional: constant force in a direction within radius
        force = normalize(direction) * strength * atten;
    } else if (ftype == 4) {
        // Noise: pseudo-random force based on position and time
        float n1 = hash31(pos * freq + vec3(time));
        float n2 = hash31(pos * freq + vec3(0.0, time, 0.0));
        float n3 = hash31(pos * freq + vec3(0.0, 0.0, time));
        force = (vec3(n1, n2, n3) * 2.0 - 1.0) * strength * atten;
    } else if (ftype == 5) {
        // Drag: opposes velocity (we approximate by opposing position change).
        // -normalize(delta) * dist is just -delta; using delta directly avoids
        // normalizing a zero-length vector when the particle is at the field position.
        force = -delta * strength * atten;
    }

    return force;
}

void main() {
    uint idx = gl_GlobalInvocationID.x;
    if (idx >= u_particle_count) return;

    Particle p = particles[idx];
    vec3 pos = p.pos_age.xyz;
    vec3 vel = p.vel_life.xyz;

    // Skip dead particles
    if (p.pos_age.w >= p.vel_life.w) return;

    // Accumulate forces from all fields
    vec3 total_force = vec3(0.0);
    for (uint i = 0u; i < u_field_count; i++) {
        total_force += eval_field(fields[i], pos, u_time);
    }

    // Apply accumulated force
    vel += total_force * u_dt;
    particles[idx].vel_life.xyz = vel;
}
"#;
641
/// Math function GPU kernel: Lorenz, Mandelbrot, Julia, Rossler, Aizawa.
///
/// GLSL compute shader source with a 1-D workgroup of 256 invocations. Each
/// invocation updates, in place, one `Point` (`pos_age`, `vel_param`) of the
/// SSBO at binding 0 and returns early for indices `>= u_point_count`.
///
/// `u_function_type` selects what one dispatch computes:
/// - `0` — one RK4 step of the Lorenz system with `sigma = u_param_a`,
///   `rho = u_param_b`, `beta = u_param_c`;
/// - `1` — one Mandelbrot iteration `z = z*z + c`, with `z` in `pos.xy` and
///   `c` in `vel_param.xy`;
/// - `2` — one Julia iteration with `c = u_julia_c`;
/// - `3` — one RK4 step of the Rossler system (`u_param_a..c`);
/// - `4` — one RK4 step of the Aizawa system (`u_param_a..d`).
///
/// Attractor modes store the first RK4 slope `k1` in `vel_param.xyz` and add
/// `u_dt` to the age. The two fractal modes treat the age as an iteration
/// counter, only iterate while it is `< u_max_iterations` and
/// `dot(z, z) < 4.0`, and store `dot(z, z)` in `pos.z`. Any other selector
/// value writes the point back unchanged.
///
/// `u_time` is declared but not read by the shader.
pub const KERNEL_MATH_FUNCTION_GPU: &str = r#"
// GPU math function kernel
// Computes one step of various mathematical functions.
// Mode uniform selects the function.

layout(local_size_x = 256) in;

struct Point {
    vec4 pos_age;    // xyz = position, w = iteration count or age
    vec4 vel_param;  // xyz = velocity/derivative, w = parameter
};

layout(std430, binding = 0) buffer Points {
    Point points[];
};

uniform uint u_point_count;
uniform uint u_function_type;  // 0=Lorenz, 1=Mandelbrot, 2=Julia, 3=Rossler, 4=Aizawa
uniform float u_dt;
uniform float u_time;
uniform float u_param_a;
uniform float u_param_b;
uniform float u_param_c;
uniform float u_param_d;
uniform uint u_max_iterations;
uniform vec2 u_julia_c;  // Julia set constant

// Lorenz attractor: dx/dt = sigma*(y-x), dy/dt = x*(rho-z)-y, dz/dt = x*y - beta*z
vec3 lorenz(vec3 p, float sigma, float rho, float beta) {
    return vec3(
        sigma * (p.y - p.x),
        p.x * (rho - p.z) - p.y,
        p.x * p.y - beta * p.z
    );
}

// Rossler attractor: dx/dt = -y-z, dy/dt = x+a*y, dz/dt = b+z*(x-c)
vec3 rossler(vec3 p, float a, float b, float c) {
    return vec3(
        -p.y - p.z,
        p.x + a * p.y,
        b + p.z * (p.x - c)
    );
}

// Aizawa attractor
vec3 aizawa(vec3 p, float a, float b, float c, float d) {
    float x = p.x, y = p.y, z = p.z;
    return vec3(
        (z - b) * x - d * y,
        d * x + (z - b) * y,
        c + a * z - z * z * z / 3.0 - (x * x + y * y) * (1.0 + 0.25 * z) + 0.1 * z * x * x * x
    );
}

// Complex multiply
vec2 cmul(vec2 a, vec2 b) {
    return vec2(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x);
}

void main() {
    uint idx = gl_GlobalInvocationID.x;
    if (idx >= u_point_count) return;

    Point pt = points[idx];
    vec3 pos = pt.pos_age.xyz;
    float age = pt.pos_age.w;

    if (u_function_type == 0u) {
        // Lorenz attractor (RK4 integration)
        float sigma = u_param_a;  // default 10.0
        float rho = u_param_b;    // default 28.0
        float beta = u_param_c;   // default 8.0/3.0

        vec3 k1 = lorenz(pos, sigma, rho, beta);
        vec3 k2 = lorenz(pos + 0.5 * u_dt * k1, sigma, rho, beta);
        vec3 k3 = lorenz(pos + 0.5 * u_dt * k2, sigma, rho, beta);
        vec3 k4 = lorenz(pos + u_dt * k3, sigma, rho, beta);

        pos += (u_dt / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4);
        pt.vel_param.xyz = k1;  // Store derivative for visualization
        age += u_dt;

    } else if (u_function_type == 1u) {
        // Mandelbrot iteration
        // pos.xy = current z, vel_param.xy = c (constant)
        vec2 z = pos.xy;
        vec2 c = pt.vel_param.xy;
        uint iter = uint(age);

        if (iter < u_max_iterations && dot(z, z) < 4.0) {
            z = cmul(z, z) + c;
            pos.xy = z;
            pos.z = dot(z, z);  // magnitude squared for coloring
            age = float(iter + 1u);
        }

    } else if (u_function_type == 2u) {
        // Julia set iteration
        vec2 z = pos.xy;
        vec2 c = u_julia_c;
        uint iter = uint(age);

        if (iter < u_max_iterations && dot(z, z) < 4.0) {
            z = cmul(z, z) + c;
            pos.xy = z;
            pos.z = dot(z, z);
            age = float(iter + 1u);
        }

    } else if (u_function_type == 3u) {
        // Rossler attractor (RK4)
        float a = u_param_a;  // default 0.2
        float b = u_param_b;  // default 0.2
        float c = u_param_c;  // default 5.7

        vec3 k1 = rossler(pos, a, b, c);
        vec3 k2 = rossler(pos + 0.5 * u_dt * k1, a, b, c);
        vec3 k3 = rossler(pos + 0.5 * u_dt * k2, a, b, c);
        vec3 k4 = rossler(pos + u_dt * k3, a, b, c);

        pos += (u_dt / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4);
        pt.vel_param.xyz = k1;
        age += u_dt;

    } else if (u_function_type == 4u) {
        // Aizawa attractor (RK4)
        float a = u_param_a;  // default 0.95
        float b = u_param_b;  // default 0.7
        float c = u_param_c;  // default 0.6
        float d = u_param_d;  // default 3.5

        vec3 k1 = aizawa(pos, a, b, c, d);
        vec3 k2 = aizawa(pos + 0.5 * u_dt * k1, a, b, c, d);
        vec3 k3 = aizawa(pos + 0.5 * u_dt * k2, a, b, c, d);
        vec3 k4 = aizawa(pos + u_dt * k3, a, b, c, d);

        pos += (u_dt / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4);
        pt.vel_param.xyz = k1;
        age += u_dt;
    }

    points[idx].pos_age = vec4(pos, age);
    points[idx].vel_param = pt.vel_param;
}
"#;
789
/// Fluid diffusion kernel using Jacobi iteration.
///
/// GLSL compute shader source with a 16x16 workgroup. Each invocation owns
/// the grid cell `(gl_GlobalInvocationID.x, gl_GlobalInvocationID.y)` and
/// returns early when it lies outside `u_width` x `u_height`.
///
/// The cell and its four neighbours are read from the `float` grid at SSBO
/// binding 0 (indexed as `y * u_width + x`, neighbour coordinates clamped to
/// the grid edge) and the value
/// `(center + u_alpha * (left + right + down + up)) * u_r_beta`
/// is written to the grid at SSBO binding 1. One dispatch therefore writes
/// each cell once.
///
/// Per the in-shader comments, `u_alpha` is `diffusion_rate * dt / (dx * dx)`
/// and `u_r_beta` is `1.0 / (1.0 + 4.0 * alpha)`; both are supplied as
/// uniforms rather than computed in the shader.
///
/// NOTE(review): presumably the host dispatches this repeatedly, swapping the
/// two buffers, to run several Jacobi iterations — confirm against the
/// dispatch code.
pub const KERNEL_FLUID_DIFFUSE: &str = r#"
// Jacobi iteration for 2D fluid diffusion
// Reads from one grid, writes to another (ping-pong).
// d(x)/dt = k * laplacian(x)
// Jacobi: x_new[i,j] = (x_old[i,j] + alpha * (x[i-1,j] + x[i+1,j] + x[i,j-1] + x[i,j+1])) / (1 + 4*alpha)
// where alpha = k * dt / (dx * dx)

layout(local_size_x = 16, local_size_y = 16) in;

layout(std430, binding = 0) readonly buffer GridIn {
    float grid_in[];
};

layout(std430, binding = 1) writeonly buffer GridOut {
    float grid_out[];
};

uniform uint u_width;
uniform uint u_height;
uniform float u_alpha;  // diffusion_rate * dt / (dx * dx)
uniform float u_r_beta;  // 1.0 / (1.0 + 4.0 * alpha)

uint idx2d(uint x, uint y) {
    return y * u_width + x;
}

void main() {
    uint x = gl_GlobalInvocationID.x;
    uint y = gl_GlobalInvocationID.y;

    if (x >= u_width || y >= u_height) return;

    // Boundary: clamp to edge
    uint x0 = max(x, 1u) - 1u;
    uint x1 = min(x + 1u, u_width - 1u);
    uint y0 = max(y, 1u) - 1u;
    uint y1 = min(y + 1u, u_height - 1u);

    float center = grid_in[idx2d(x, y)];
    float left   = grid_in[idx2d(x0, y)];
    float right  = grid_in[idx2d(x1, y)];
    float down   = grid_in[idx2d(x, y0)];
    float up     = grid_in[idx2d(x, y1)];

    float result = (center + u_alpha * (left + right + down + up)) * u_r_beta;
    grid_out[idx2d(x, y)] = result;
}
"#;
839
/// Histogram equalization kernel (three passes: histogram, CDF, equalize).
///
/// GLSL compute shader source with a 1-D workgroup of 256 invocations. The
/// pass is chosen at compile time through the `PASS_MODE` define (default
/// `0`); the bin count comes from the `BIN_COUNT` define (default `256`).
///
/// - `PASS_MODE == 0` — each workgroup accumulates a shared-memory histogram
///   of `input_data` (binding 0), then atomically merges it into `histogram`
///   (binding 1).
/// - `PASS_MODE == 1` — invocation 0 alone walks the histogram and writes the
///   running total divided by `u_element_count` into `cdf` (binding 2).
/// - `PASS_MODE == 2` — each element is remapped through `cdf` and written to
///   `output_data` (binding 3), rescaled to `[u_min_value, u_max_value]`.
///
/// Shared-histogram initialisation and merging stride over the bins by the
/// workgroup size, so every bin is covered even when `BIN_COUNT` is larger
/// than the 256 invocations in a workgroup.
pub const KERNEL_HISTOGRAM_EQUALIZE: &str = r#"
// Histogram computation and equalization.
// Selected by PASS_MODE define:
//   0 = build histogram (atomically increment bins),
//   1 = build CDF from histogram (prefix sum),
//   2 = equalize (use CDF to remap values).

layout(local_size_x = 256) in;

#ifndef PASS_MODE
#define PASS_MODE 0
#endif

#ifndef BIN_COUNT
#define BIN_COUNT 256
#endif

layout(std430, binding = 0) buffer InputData {
    float input_data[];
};

layout(std430, binding = 1) buffer Histogram {
    uint histogram[];
};

layout(std430, binding = 2) buffer CDF {
    float cdf[];
};

layout(std430, binding = 3) buffer OutputData {
    float output_data[];
};

uniform uint u_element_count;
uniform float u_min_value;
uniform float u_max_value;

// Shared memory for local histogram accumulation
shared uint local_hist[BIN_COUNT];

void main() {
    uint idx = gl_GlobalInvocationID.x;
    uint lid = gl_LocalInvocationID.x;

#if PASS_MODE == 0
    // Pass 0: Build histogram
    // Initialize shared histogram (strided so BIN_COUNT may exceed the workgroup size)
    for (uint i = lid; i < uint(BIN_COUNT); i += gl_WorkGroupSize.x) {
        local_hist[i] = 0u;
    }
    barrier();

    if (idx < u_element_count) {
        float val = input_data[idx];
        float norm = clamp((val - u_min_value) / (u_max_value - u_min_value), 0.0, 1.0);
        uint bin = min(uint(norm * float(BIN_COUNT - 1)), uint(BIN_COUNT - 1));
        atomicAdd(local_hist[bin], 1u);
    }
    barrier();

    // Merge local histogram into global (same stride as the initialization)
    for (uint i = lid; i < uint(BIN_COUNT); i += gl_WorkGroupSize.x) {
        atomicAdd(histogram[i], local_hist[i]);
    }

#elif PASS_MODE == 1
    // Pass 1: Build CDF from histogram (single workgroup, sequential for simplicity)
    if (idx == 0u) {
        uint running = 0u;
        for (uint i = 0u; i < uint(BIN_COUNT); i++) {
            running += histogram[i];
            cdf[i] = float(running) / float(u_element_count);
        }
    }

#elif PASS_MODE == 2
    // Pass 2: Apply equalization using CDF
    if (idx < u_element_count) {
        float val = input_data[idx];
        float norm = clamp((val - u_min_value) / (u_max_value - u_min_value), 0.0, 1.0);
        uint bin = min(uint(norm * float(BIN_COUNT - 1)), uint(BIN_COUNT - 1));
        float equalized = cdf[bin];
        output_data[idx] = equalized * (u_max_value - u_min_value) + u_min_value;
    }

#endif
}
"#;
928
/// Blelloch parallel prefix sum (exclusive scan).
///
/// GLSL compute shader source with a 1-D workgroup of 256 invocations and a
/// 512-entry shared array; each invocation handles two elements (`lid` and
/// `lid + 256`) of the block starting at `gl_WorkGroupID.x * u_block_size`.
/// The phase is chosen at compile time through the `PHASE` define (default
/// `0`).
///
/// - `PHASE == 0` — loads the block from `data` (binding 0, zero-padding
///   indices `>= u_n`), runs the up-sweep, stores the block total in
///   `block_sums[gid]` (binding 1), clears the last element, runs the
///   down-sweep and writes the scanned values back to `data`. Both sweeps
///   happen in this one phase.
/// - `PHASE == 2` — for workgroups with `gid > 0`, adds `block_sums[gid]` to
///   both of the invocation's elements.
///
/// The in-shader header lists `1 = down-sweep`, but the source has no
/// `PHASE == 1` branch; compiling with `PHASE` set to `1` yields an empty
/// `main` apart from the index setup.
///
/// NOTE(review): the loads hard-code a 512-element block (`lid + 256u`), so
/// `u_block_size` is presumably expected to be 512 — confirm with the caller.
/// NOTE(review): phase 2 adds `block_sums[gid]` as-is, so it presumably
/// expects `block_sums` to have been scanned between the two phases —
/// confirm against the dispatch code.
pub const KERNEL_PREFIX_SUM: &str = r#"
// Blelloch parallel prefix sum (exclusive scan)
// Two-phase: up-sweep (reduce) then down-sweep.
// Works on a single workgroup; for larger arrays, use multi-block with auxiliary sums.
//
// PHASE define: 0 = up-sweep, 1 = down-sweep, 2 = add block offsets

layout(local_size_x = 256) in;

#ifndef PHASE
#define PHASE 0
#endif

layout(std430, binding = 0) buffer Data {
    uint data[];
};

layout(std430, binding = 1) buffer BlockSums {
    uint block_sums[];
};

uniform uint u_n;           // number of elements
uniform uint u_block_size;  // elements per block (2 * local_size)

shared uint temp[512]; // 2 * local_size_x

void main() {
    uint lid = gl_LocalInvocationID.x;
    uint gid = gl_WorkGroupID.x;
    uint block_offset = gid * u_block_size;

#if PHASE == 0
    // Load into shared memory
    uint ai = lid;
    uint bi = lid + 256u;
    uint a_idx = block_offset + ai;
    uint b_idx = block_offset + bi;

    temp[ai] = (a_idx < u_n) ? data[a_idx] : 0u;
    temp[bi] = (b_idx < u_n) ? data[b_idx] : 0u;
    barrier();

    // Up-sweep (reduce)
    uint offset = 1u;
    for (uint d = 512u >> 1u; d > 0u; d >>= 1u) {
        barrier();
        if (lid < d) {
            uint ai2 = offset * (2u * lid + 1u) - 1u;
            uint bi2 = offset * (2u * lid + 2u) - 1u;
            temp[bi2] += temp[ai2];
        }
        offset <<= 1u;
    }
    barrier();

    // Store block sum and clear last element
    if (lid == 0u) {
        block_sums[gid] = temp[511u];
        temp[511u] = 0u;
    }
    barrier();

    // Down-sweep
    for (uint d = 1u; d < 512u; d <<= 1u) {
        offset >>= 1u;
        barrier();
        if (lid < d) {
            uint ai2 = offset * (2u * lid + 1u) - 1u;
            uint bi2 = offset * (2u * lid + 2u) - 1u;
            uint t = temp[ai2];
            temp[ai2] = temp[bi2];
            temp[bi2] += t;
        }
    }
    barrier();

    // Write back
    if (a_idx < u_n) data[a_idx] = temp[ai];
    if (b_idx < u_n) data[b_idx] = temp[bi];

#elif PHASE == 2
    // Add block offsets for multi-block scan
    if (gid > 0u) {
        uint a_idx2 = block_offset + lid;
        uint b_idx2 = block_offset + lid + 256u;
        uint block_sum = block_sums[gid];

        if (a_idx2 < u_n) data[a_idx2] += block_sum;
        if (b_idx2 < u_n) data[b_idx2] += block_sum;
    }

#endif
}
"#;
1024
/// GPU radix sort kernel (least-significant digit first, `RADIX_BITS` bits per pass).
///
/// The `PASS` define selects the stage that is compiled:
///
/// - `PASS == 0` (count): every workgroup histograms the current digit of its
///   keys in shared memory and writes the counts to `offsets`, laid out
///   digit-major (`digit * num_workgroups + workgroup`). The per-bin loops are
///   strided by the workgroup size, so all bins are covered even when
///   `RADIX_BITS` is overridden to a value whose `RADIX` exceeds 256.
/// - `PASS == 1` (scatter): every element claims an output slot with
///   `atomicAdd` on `global_offsets[digit]` and copies its key and value there.
///
/// Because slots are claimed atomically, elements sharing a digit are written
/// in an unspecified relative order, and `global_offsets` is modified by the
/// scatter dispatch.
pub const KERNEL_RADIX_SORT: &str = r#"
// Radix sort kernel (LSB first, RADIX_BITS bits per pass; 4 by default)
// PASS define: 0 = count, 1 = scatter
// Uses prefix sums (from prefix_sum kernel) between passes.

layout(local_size_x = 256) in;

#ifndef PASS
#define PASS 0
#endif

#ifndef RADIX_BITS
#define RADIX_BITS 4
#endif

#define RADIX (1 << RADIX_BITS)

layout(std430, binding = 0) buffer KeysIn {
    uint keys_in[];
};

layout(std430, binding = 1) buffer KeysOut {
    uint keys_out[];
};

layout(std430, binding = 2) buffer ValuesIn {
    uint values_in[];
};

layout(std430, binding = 3) buffer ValuesOut {
    uint values_out[];
};

layout(std430, binding = 4) buffer Offsets {
    uint offsets[];      // RADIX * num_blocks
};

layout(std430, binding = 5) buffer GlobalOffsets {
    uint global_offsets[];  // RADIX prefix sums
};

uniform uint u_n;
uniform uint u_bit_offset;  // bit position of the current digit (0, 4, 8, 12, ... for 4-bit digits)

shared uint local_counts[RADIX];

uint extract_digit(uint key, uint bit_offset) {
    return (key >> bit_offset) & uint(RADIX - 1);
}

void main() {
    uint lid = gl_LocalInvocationID.x;
    uint gid = gl_WorkGroupID.x;
    uint idx = gl_GlobalInvocationID.x;

#if PASS == 0
    // Count pass: count occurrences of each digit in this block.
    // The bin loops are strided by the workgroup size so that every bin is
    // handled even when RADIX is larger than the number of invocations.
    for (uint d = lid; d < uint(RADIX); d += gl_WorkGroupSize.x) {
        local_counts[d] = 0u;
    }
    barrier();

    if (idx < u_n) {
        uint digit = extract_digit(keys_in[idx], u_bit_offset);
        atomicAdd(local_counts[digit], 1u);
    }
    barrier();

    // Write local counts to global offset table (digit-major layout)
    for (uint d = lid; d < uint(RADIX); d += gl_WorkGroupSize.x) {
        offsets[d * gl_NumWorkGroups.x + gid] = local_counts[d];
    }

#elif PASS == 1
    // Scatter pass: place each element at its globally sorted position
    if (idx < u_n) {
        uint key = keys_in[idx];
        uint digit = extract_digit(key, u_bit_offset);

        // global_offsets[digit] = prefix sum of all counts for this digit
        // We need the exact output position: global prefix + local prefix
        // This is a simplified scatter; a production sort uses local prefix sums.
        // NOTE: atomicAdd hands out slots in arbitrary invocation order, so
        // elements with equal digits are not kept in their input order.
        uint dest = atomicAdd(global_offsets[digit], 1u);
        if (dest < u_n) {
            keys_out[dest] = key;
            values_out[dest] = values_in[idx];
        }
    }

#endif
}
"#;
1118
/// Frustum culling kernel: per-instance visibility testing.
///
/// Each invocation tests one instance's bounding sphere against the six planes
/// in `u_planes`. The sphere radius is the stored radius multiplied by the
/// largest *absolute* scale component, so mirrored (negative-scale) instances
/// keep a positive radius. Visible instances are appended to `visible` at a
/// slot taken from the `u_visible_count` atomic counter together with a LOD
/// level: the per-instance override when `extra.w >= 0`, otherwise a level
/// picked from `u_lod_distances` (always 0 when `u_enable_lod` is 0).
pub const KERNEL_FRUSTUM_CULL: &str = r#"
// Per-instance frustum culling with optional LOD selection.
// Tests each instance's bounding sphere against 6 frustum planes.

layout(local_size_x = 256) in;

struct Instance {
    vec4 position_radius;   // xyz = position, w = bounding sphere radius
    vec4 extra;             // xyz = scale, w = LOD override (-1 = auto)
};

struct VisibleInstance {
    uint original_index;
    uint lod_level;
    vec2 _padding;
};

layout(std430, binding = 0) readonly buffer Instances {
    Instance instances[];
};

layout(std430, binding = 1) writeonly buffer VisibleOut {
    VisibleInstance visible[];
};

layout(binding = 0, offset = 0) uniform atomic_uint u_visible_count;

uniform uint u_instance_count;
uniform vec4 u_planes[6];      // frustum planes (normal.xyz, distance.w)
uniform vec3 u_camera_pos;
uniform vec4 u_lod_distances;  // x=lod0, y=lod1, z=lod2, w=lod3
uniform uint u_enable_lod;

bool sphere_vs_frustum(vec3 center, float radius) {
    for (int i = 0; i < 6; i++) {
        float dist = dot(u_planes[i].xyz, center) + u_planes[i].w;
        if (dist < -radius) {
            return false;  // fully outside this plane
        }
    }
    return true;
}

uint compute_lod(vec3 pos, float radius) {
    if (u_enable_lod == 0u) return 0u;

    float dist = length(pos - u_camera_pos) - radius;
    if (dist < u_lod_distances.x) return 0u;
    if (dist < u_lod_distances.y) return 1u;
    if (dist < u_lod_distances.z) return 2u;
    return 3u;
}

void main() {
    uint idx = gl_GlobalInvocationID.x;
    if (idx >= u_instance_count) return;

    Instance inst = instances[idx];
    vec3 center = inst.position_radius.xyz;

    // Use the magnitude of the scale: a negative (mirroring) scale must not
    // yield a negative radius, which would cull instances inside the frustum.
    vec3 abs_scale = abs(inst.extra.xyz);
    float radius = inst.position_radius.w * max(abs_scale.x, max(abs_scale.y, abs_scale.z));

    if (sphere_vs_frustum(center, radius)) {
        uint lod = (inst.extra.w >= 0.0)
            ? uint(inst.extra.w)
            : compute_lod(center, radius);

        uint slot = atomicCounterIncrement(u_visible_count);
        visible[slot].original_index = idx;
        visible[slot].lod_level = lod;
    }
}
"#;
1192
/// Skeletal skinning kernel: apply bone matrix palette transforms.
///
/// Each invocation blends up to four `bones[i] * inv_bind[i]` matrices by the
/// vertex's bone weights and transforms position, normal and tangent with the
/// result. An influence is skipped when its weight is not positive, its index
/// is `>= u_bone_count`, or its slot is beyond `u_max_bones_per_vertex`
/// (slot 0 is always considered).
///
/// The identity matrix is used when the vertex's weights sum to less than
/// `0.001` *or* when the weights that were actually applied sum to less than
/// `0.001`, so a vertex whose influences were all skipped passes through
/// unchanged instead of being multiplied by a zero matrix.
pub const KERNEL_SKINNING: &str = r#"
// GPU skinning kernel
// Transforms vertices by a weighted sum of bone matrices.

layout(local_size_x = 256) in;

struct Vertex {
    vec4 position;     // xyz = position, w = 1
    vec4 normal;       // xyz = normal, w = 0
    vec4 tangent;      // xyz = tangent, w = handedness
    vec4 bone_weights; // up to 4 bone weights
    uvec4 bone_indices; // up to 4 bone indices
};

struct SkinnedVertex {
    vec4 position;
    vec4 normal;
    vec4 tangent;
    vec4 _reserved;
};

layout(std430, binding = 0) readonly buffer VerticesIn {
    Vertex vertices_in[];
};

layout(std430, binding = 1) writeonly buffer VerticesOut {
    SkinnedVertex vertices_out[];
};

layout(std430, binding = 2) readonly buffer BoneMatrices {
    mat4 bones[];
};

layout(std430, binding = 3) readonly buffer InverseBindMatrices {
    mat4 inv_bind[];
};

uniform uint u_vertex_count;
uniform uint u_bone_count;
uniform uint u_max_bones_per_vertex;

// Blends the bone matrices referenced by `indices` using `weights`.
// `applied_weight` receives the sum of the weights that actually contributed.
mat4 get_skin_matrix(uvec4 indices, vec4 weights, out float applied_weight) {
    mat4 skin = mat4(0.0);
    applied_weight = 0.0;

    // Bone 0
    if (weights.x > 0.0 && indices.x < u_bone_count) {
        skin += weights.x * (bones[indices.x] * inv_bind[indices.x]);
        applied_weight += weights.x;
    }

    // Bone 1
    if (u_max_bones_per_vertex > 1u && weights.y > 0.0 && indices.y < u_bone_count) {
        skin += weights.y * (bones[indices.y] * inv_bind[indices.y]);
        applied_weight += weights.y;
    }

    // Bone 2
    if (u_max_bones_per_vertex > 2u && weights.z > 0.0 && indices.z < u_bone_count) {
        skin += weights.z * (bones[indices.z] * inv_bind[indices.z]);
        applied_weight += weights.z;
    }

    // Bone 3
    if (u_max_bones_per_vertex > 3u && weights.w > 0.0 && indices.w < u_bone_count) {
        skin += weights.w * (bones[indices.w] * inv_bind[indices.w]);
        applied_weight += weights.w;
    }

    return skin;
}

void main() {
    uint idx = gl_GlobalInvocationID.x;
    if (idx >= u_vertex_count) return;

    Vertex v = vertices_in[idx];
    float applied_weight;
    mat4 skin = get_skin_matrix(v.bone_indices, v.bone_weights, applied_weight);

    // If no bones affect this vertex, use identity. This also covers the case
    // where weights are present but every influence was skipped (index out of
    // range or beyond u_max_bones_per_vertex), which would otherwise leave a
    // zero matrix and collapse the vertex.
    float total_weight = v.bone_weights.x + v.bone_weights.y + v.bone_weights.z + v.bone_weights.w;
    if (total_weight < 0.001 || applied_weight < 0.001) {
        skin = mat4(1.0);
    }

    vec4 skinned_pos = skin * v.position;
    vec3 skinned_normal = normalize(mat3(skin) * v.normal.xyz);
    vec3 skinned_tangent = normalize(mat3(skin) * v.tangent.xyz);

    vertices_out[idx].position = skinned_pos;
    vertices_out[idx].normal = vec4(skinned_normal, 0.0);
    vertices_out[idx].tangent = vec4(skinned_tangent, v.tangent.w);
}
"#;
1282
1283// ---------------------------------------------------------------------------
1284// KernelLibrary
1285// ---------------------------------------------------------------------------
1286
1287/// Library of all built-in compute kernels, providing easy access to sources
1288/// and parameter setup.
1289pub struct KernelLibrary {
1290    sources: HashMap<KernelId, &'static str>,
1291}
1292
1293impl KernelLibrary {
1294    /// Create the library with all built-in kernels.
1295    pub fn new() -> Self {
1296        let mut sources = HashMap::new();
1297        sources.insert(KernelId::ParticleIntegrate, KERNEL_PARTICLE_INTEGRATE);
1298        sources.insert(KernelId::ParticleEmit, KERNEL_PARTICLE_EMIT);
1299        sources.insert(KernelId::ForceFieldSample, KERNEL_FORCE_FIELD_SAMPLE);
1300        sources.insert(KernelId::MathFunctionGpu, KERNEL_MATH_FUNCTION_GPU);
1301        sources.insert(KernelId::FluidDiffuse, KERNEL_FLUID_DIFFUSE);
1302        sources.insert(KernelId::HistogramEqualize, KERNEL_HISTOGRAM_EQUALIZE);
1303        sources.insert(KernelId::PrefixSum, KERNEL_PREFIX_SUM);
1304        sources.insert(KernelId::RadixSort, KERNEL_RADIX_SORT);
1305        sources.insert(KernelId::FrustumCull, KERNEL_FRUSTUM_CULL);
1306        sources.insert(KernelId::Skinning, KERNEL_SKINNING);
1307        Self { sources }
1308    }
1309
1310    /// Get the raw GLSL source for a kernel.
1311    pub fn source(&self, id: KernelId) -> Option<&'static str> {
1312        self.sources.get(&id).copied()
1313    }
1314
1315    /// Build a `ShaderSource` for a kernel, with version header and optional defines.
1316    pub fn shader_source(&self, id: KernelId) -> Option<super::dispatch::ShaderSource> {
1317        self.source(id).map(|src| {
1318            let mut ss = super::dispatch::ShaderSource::new(src);
1319            ss.set_label(id.name());
1320            ss
1321        })
1322    }
1323
1324    /// Compile a kernel into a `ComputeProgram`.
1325    pub fn compile(
1326        &self,
1327        gl: &glow::Context,
1328        id: KernelId,
1329    ) -> Result<super::dispatch::ComputeProgram, String> {
1330        let src = self
1331            .shader_source(id)
1332            .ok_or_else(|| format!("Unknown kernel: {:?}", id))?;
1333        super::dispatch::ComputeProgram::compile(gl, &src)
1334    }
1335
1336    /// Compile a kernel with extra defines.
1337    pub fn compile_with_defines(
1338        &self,
1339        gl: &glow::Context,
1340        id: KernelId,
1341        defines: &[(&str, &str)],
1342    ) -> Result<super::dispatch::ComputeProgram, String> {
1343        let mut src = self
1344            .shader_source(id)
1345            .ok_or_else(|| format!("Unknown kernel: {:?}", id))?;
1346        for (name, value) in defines {
1347            src.define(name, value);
1348        }
1349        super::dispatch::ComputeProgram::compile(gl, &src)
1350    }
1351
1352    /// List all available kernel IDs.
1353    pub fn available_kernels(&self) -> Vec<KernelId> {
1354        KernelId::all().to_vec()
1355    }
1356}
1357
1358impl Default for KernelLibrary {
1359    fn default() -> Self {
1360        Self::new()
1361    }
1362}
1363
1364// ---------------------------------------------------------------------------
1365// Convenience: set uniforms for each kernel
1366// ---------------------------------------------------------------------------
1367
1368/// Set uniforms for the particle integration kernel.
1369pub fn set_particle_integrate_uniforms(
1370    gl: &glow::Context,
1371    program: &super::dispatch::ComputeProgram,
1372    params: &ParticleIntegrateParams,
1373) {
1374    program.set_uniform_float(gl, "u_dt", params.dt);
1375    program.set_uniform_vec3(
1376        gl,
1377        "u_gravity",
1378        params.gravity[0],
1379        params.gravity[1],
1380        params.gravity[2],
1381    );
1382    program.set_uniform_float(gl, "u_damping", params.damping);
1383    program.set_uniform_uint(gl, "u_particle_count", params.particle_count);
1384    program.set_uniform_float(gl, "u_max_age", params.max_age);
1385    program.set_uniform_vec3(
1386        gl,
1387        "u_wind",
1388        params.wind[0],
1389        params.wind[1],
1390        params.wind[2],
1391    );
1392    program.set_uniform_float(gl, "u_turbulence_strength", params.turbulence_strength);
1393    program.set_uniform_float(gl, "u_time", params.time);
1394}
1395
1396/// Set uniforms for the particle emission kernel.
1397pub fn set_particle_emit_uniforms(
1398    gl: &glow::Context,
1399    program: &super::dispatch::ComputeProgram,
1400    params: &ParticleEmitParams,
1401) {
1402    program.set_uniform_uint(gl, "u_emit_count", params.emit_count);
1403    program.set_uniform_uint(gl, "u_max_particles", params.max_particles);
1404    program.set_uniform_vec3(
1405        gl,
1406        "u_emitter_pos",
1407        params.emitter_position[0],
1408        params.emitter_position[1],
1409        params.emitter_position[2],
1410    );
1411    program.set_uniform_float(gl, "u_emitter_radius", params.emitter_radius);
1412    program.set_uniform_float(gl, "u_speed_min", params.initial_speed_min);
1413    program.set_uniform_float(gl, "u_speed_max", params.initial_speed_max);
1414    program.set_uniform_vec3(
1415        gl,
1416        "u_direction",
1417        params.initial_direction[0],
1418        params.initial_direction[1],
1419        params.initial_direction[2],
1420    );
1421    program.set_uniform_float(gl, "u_spread", params.spread_angle);
1422    program.set_uniform_float(gl, "u_life_min", params.lifetime_min);
1423    program.set_uniform_float(gl, "u_life_max", params.lifetime_max);
1424    program.set_uniform_float(gl, "u_time", params.time);
1425    program.set_uniform_uint(gl, "u_seed", params.seed);
1426    program.set_uniform_vec4(
1427        gl,
1428        "u_color_start",
1429        params.color_start[0],
1430        params.color_start[1],
1431        params.color_start[2],
1432        params.color_start[3],
1433    );
1434    program.set_uniform_vec4(
1435        gl,
1436        "u_color_end",
1437        params.color_end[0],
1438        params.color_end[1],
1439        params.color_end[2],
1440        params.color_end[3],
1441    );
1442}
1443
1444/// Set uniforms for the fluid diffusion kernel.
1445pub fn set_fluid_diffuse_uniforms(
1446    gl: &glow::Context,
1447    program: &super::dispatch::ComputeProgram,
1448    params: &FluidDiffuseParams,
1449) {
1450    let dx = 1.0f32;
1451    let alpha = params.diffusion_rate * params.dt / (dx * dx);
1452    let r_beta = 1.0 / (1.0 + 4.0 * alpha);
1453    program.set_uniform_uint(gl, "u_width", params.grid_width);
1454    program.set_uniform_uint(gl, "u_height", params.grid_height);
1455    program.set_uniform_float(gl, "u_alpha", alpha);
1456    program.set_uniform_float(gl, "u_r_beta", r_beta);
1457}
1458
1459/// Set uniforms for the histogram equalization kernel.
1460pub fn set_histogram_uniforms(
1461    gl: &glow::Context,
1462    program: &super::dispatch::ComputeProgram,
1463    params: &HistogramParams,
1464) {
1465    program.set_uniform_uint(gl, "u_element_count", params.width * params.height);
1466    program.set_uniform_float(gl, "u_min_value", params.min_value);
1467    program.set_uniform_float(gl, "u_max_value", params.max_value);
1468}
1469
1470/// Set uniforms for the prefix sum kernel.
1471pub fn set_prefix_sum_uniforms(
1472    gl: &glow::Context,
1473    program: &super::dispatch::ComputeProgram,
1474    plan: &PrefixSumPlan,
1475) {
1476    program.set_uniform_uint(gl, "u_n", plan.element_count);
1477    program.set_uniform_uint(gl, "u_block_size", plan.workgroup_size * 2);
1478}
1479
1480/// Set uniforms for the radix sort kernel.
1481pub fn set_radix_sort_uniforms(
1482    gl: &glow::Context,
1483    program: &super::dispatch::ComputeProgram,
1484    plan: &RadixSortPlan,
1485    bit_offset: u32,
1486) {
1487    program.set_uniform_uint(gl, "u_n", plan.element_count);
1488    program.set_uniform_uint(gl, "u_bit_offset", bit_offset);
1489}
1490
1491/// Set uniforms for the frustum culling kernel.
1492pub fn set_frustum_cull_uniforms(
1493    gl: &glow::Context,
1494    program: &super::dispatch::ComputeProgram,
1495    params: &FrustumCullParams,
1496) {
1497    program.set_uniform_uint(gl, "u_instance_count", params.instance_count);
1498    // Set frustum planes as individual vec4 uniforms
1499    for (i, plane) in params.frustum_planes.iter().enumerate() {
1500        let name = format!("u_planes[{}]", i);
1501        program.set_uniform_vec4(gl, &name, plane[0], plane[1], plane[2], plane[3]);
1502    }
1503    program.set_uniform_vec3(
1504        gl,
1505        "u_camera_pos",
1506        params.camera_position[0],
1507        params.camera_position[1],
1508        params.camera_position[2],
1509    );
1510    program.set_uniform_vec4(
1511        gl,
1512        "u_lod_distances",
1513        params.lod_distances[0],
1514        params.lod_distances[1],
1515        params.lod_distances[2],
1516        params.lod_distances[3],
1517    );
1518    program.set_uniform_uint(gl, "u_enable_lod", if params.enable_lod { 1 } else { 0 });
1519}
1520
1521/// Set uniforms for the skinning kernel.
1522pub fn set_skinning_uniforms(
1523    gl: &glow::Context,
1524    program: &super::dispatch::ComputeProgram,
1525    params: &SkinningParams,
1526) {
1527    program.set_uniform_uint(gl, "u_vertex_count", params.vertex_count);
1528    program.set_uniform_uint(gl, "u_bone_count", params.bone_count);
1529    program.set_uniform_uint(gl, "u_max_bones_per_vertex", params.max_bones_per_vertex);
1530}
1531
1532/// Set uniforms for the math function GPU kernel.
1533pub fn set_math_function_uniforms(
1534    gl: &glow::Context,
1535    program: &super::dispatch::ComputeProgram,
1536    function_type: MathFunctionType,
1537    point_count: u32,
1538    dt: f32,
1539    time: f32,
1540    params: [f32; 4],
1541    max_iterations: u32,
1542    julia_c: [f32; 2],
1543) {
1544    program.set_uniform_uint(gl, "u_point_count", point_count);
1545    program.set_uniform_uint(gl, "u_function_type", function_type as u32);
1546    program.set_uniform_float(gl, "u_dt", dt);
1547    program.set_uniform_float(gl, "u_time", time);
1548    program.set_uniform_float(gl, "u_param_a", params[0]);
1549    program.set_uniform_float(gl, "u_param_b", params[1]);
1550    program.set_uniform_float(gl, "u_param_c", params[2]);
1551    program.set_uniform_float(gl, "u_param_d", params[3]);
1552    program.set_uniform_uint(gl, "u_max_iterations", max_iterations);
1553    program.set_uniform_vec2(gl, "u_julia_c", julia_c[0], julia_c[1]);
1554}
1555
1556/// Set uniforms for the force field sampling kernel.
1557pub fn set_force_field_uniforms(
1558    gl: &glow::Context,
1559    program: &super::dispatch::ComputeProgram,
1560    particle_count: u32,
1561    field_count: u32,
1562    dt: f32,
1563    time: f32,
1564) {
1565    program.set_uniform_uint(gl, "u_particle_count", particle_count);
1566    program.set_uniform_uint(gl, "u_field_count", field_count);
1567    program.set_uniform_float(gl, "u_dt", dt);
1568    program.set_uniform_float(gl, "u_time", time);
1569}