// proof_engine/compute/dispatch.rs

1//! Compute shader compilation, dispatch, pipeline caching, and profiling.
2//!
3//! Provides the core compute dispatch pipeline:
4//! - `ShaderSource` with `#define` injection
5//! - `ComputeProgram` with compile/link/validate
6//! - `WorkgroupSize` calculation with hardware-limit awareness
7//! - `ComputeDispatch` for 1D/2D/3D and indirect dispatch
8//! - `PipelineCache` for shader reuse
9//! - `SpecializationConstant` for compile-time constants
10//! - `ComputeProfiler` with GPU timer queries
11
12use std::collections::HashMap;
13
14// ---------------------------------------------------------------------------
15// GL constants
16// ---------------------------------------------------------------------------
17
// Shader type token for glCreateShader (compute stage).
const GL_COMPUTE_SHADER: u32 = 0x91B9;
// glMemoryBarrier bit covering shader storage buffer (SSBO) writes.
const GL_SHADER_STORAGE_BARRIER_BIT: u32 = 0x00002000;
// Buffer binding target holding glDispatchComputeIndirect arguments.
const GL_DISPATCH_INDIRECT_BUFFER: u32 = 0x90EE;
// Timer query target: GPU time elapsed between begin/end, in nanoseconds.
const GL_TIME_ELAPSED: u32 = 0x88BF;
// Query parameter: fetch the query result (blocks until ready).
const GL_QUERY_RESULT: u32 = 0x8866;
// Query parameter: non-blocking check whether the result is ready.
const GL_QUERY_RESULT_AVAILABLE: u32 = 0x8867;
24
25// ---------------------------------------------------------------------------
26// ShaderSource
27// ---------------------------------------------------------------------------
28
/// A compute shader source with support for `#define` injection.
///
/// `assemble` emits: the `#version` line, the injected defines, then the
/// base source with any pre-existing `#version` line stripped.
#[derive(Debug, Clone)]
pub struct ShaderSource {
    /// Base GLSL source code.
    source: String,
    /// `(name, value)` pairs injected as `#define`s after the `#version`
    /// line; an empty value produces a bare `#define NAME` flag.
    defines: Vec<(String, String)>,
    /// GLSL version string (e.g., "430"); the "core" profile is appended.
    version: String,
    /// Optional label for debugging.
    label: Option<String>,
}
41
42impl ShaderSource {
43    /// Create a new shader source from GLSL code.
44    pub fn new(source: &str) -> Self {
45        Self {
46            source: source.to_string(),
47            defines: Vec::new(),
48            version: "430".to_string(),
49            label: None,
50        }
51    }
52
53    /// Create with explicit version.
54    pub fn with_version(source: &str, version: &str) -> Self {
55        Self {
56            source: source.to_string(),
57            defines: Vec::new(),
58            version: version.to_string(),
59            label: None,
60        }
61    }
62
63    /// Add a `#define NAME VALUE` to be injected.
64    pub fn define(&mut self, name: &str, value: &str) -> &mut Self {
65        self.defines.push((name.to_string(), value.to_string()));
66        self
67    }
68
69    /// Add a `#define NAME` (flag, no value).
70    pub fn define_flag(&mut self, name: &str) -> &mut Self {
71        self.defines.push((name.to_string(), String::new()));
72        self
73    }
74
75    /// Set the debug label.
76    pub fn set_label(&mut self, label: &str) -> &mut Self {
77        self.label = Some(label.to_string());
78        self
79    }
80
81    /// Get the label.
82    pub fn label(&self) -> Option<&str> {
83        self.label.as_deref()
84    }
85
86    /// Produce the final GLSL source with version and defines injected.
87    pub fn assemble(&self) -> String {
88        let mut result = String::with_capacity(self.source.len() + 256);
89        result.push_str(&format!("#version {} core\n", self.version));
90
91        for (name, value) in &self.defines {
92            if value.is_empty() {
93                result.push_str(&format!("#define {}\n", name));
94            } else {
95                result.push_str(&format!("#define {} {}\n", name, value));
96            }
97        }
98        result.push('\n');
99
100        // Strip any existing #version line from the source
101        for line in self.source.lines() {
102            let trimmed = line.trim();
103            if trimmed.starts_with("#version") {
104                continue;
105            }
106            result.push_str(line);
107            result.push('\n');
108        }
109        result
110    }
111
112    /// Generate a cache key based on source + defines (for PipelineCache).
113    pub fn cache_key(&self) -> u64 {
114        // Simple FNV-1a hash
115        let assembled = self.assemble();
116        let mut hash: u64 = 0xcbf29ce484222325;
117        for byte in assembled.bytes() {
118            hash ^= byte as u64;
119            hash = hash.wrapping_mul(0x100000001b3);
120        }
121        hash
122    }
123}
124
125// ---------------------------------------------------------------------------
126// SpecializationConstant
127// ---------------------------------------------------------------------------
128
/// A specialization constant that can be set at compile time.
/// In OpenGL this is simulated via `#define` injection; the `id` exists
/// only for Vulkan-style bookkeeping and is unused by the GL path here.
#[derive(Debug, Clone)]
pub struct SpecializationConstant {
    /// Constant name (becomes a #define).
    pub name: String,
    /// Value as string (will be injected as-is).
    pub value: String,
    /// Constant ID (for Vulkan compatibility tracking).
    pub id: u32,
}
140
141impl SpecializationConstant {
142    /// Create a new integer specialization constant.
143    pub fn int(name: &str, id: u32, value: i32) -> Self {
144        Self {
145            name: name.to_string(),
146            value: value.to_string(),
147            id,
148        }
149    }
150
151    /// Create a new unsigned integer specialization constant.
152    pub fn uint(name: &str, id: u32, value: u32) -> Self {
153        Self {
154            name: name.to_string(),
155            value: format!("{}u", value),
156            id,
157        }
158    }
159
160    /// Create a new float specialization constant.
161    pub fn float(name: &str, id: u32, value: f32) -> Self {
162        Self {
163            name: name.to_string(),
164            value: format!("{:.8}", value),
165            id,
166        }
167    }
168
169    /// Create a boolean specialization constant (0 or 1).
170    pub fn boolean(name: &str, id: u32, value: bool) -> Self {
171        Self {
172            name: name.to_string(),
173            value: if value { "1".to_string() } else { "0".to_string() },
174            id,
175        }
176    }
177
178    /// Apply this constant to a shader source as a define.
179    pub fn apply(&self, source: &mut ShaderSource) {
180        source.define(&self.name, &self.value);
181    }
182}
183
184/// Apply a set of specialization constants to a shader source.
185pub fn apply_specializations(source: &mut ShaderSource, constants: &[SpecializationConstant]) {
186    for c in constants {
187        c.apply(source);
188    }
189}
190
191// ---------------------------------------------------------------------------
192// ComputeProgram
193// ---------------------------------------------------------------------------
194
/// A compiled and linked compute shader program.
///
/// Owns the GL program object; call `destroy` with a live context to
/// release it (there is no `Drop` impl, since deletion needs the context).
pub struct ComputeProgram {
    /// GL program object.
    program: glow::NativeProgram,
    /// FNV-1a hash of the assembled source; used by `PipelineCache`.
    cache_key: u64,
    /// Local workgroup size declared in the shader.
    /// NOTE(review): currently always the [64, 1, 1] fallback — see
    /// `query_work_group_size`.
    local_size: [u32; 3],
    /// Debug label.
    label: Option<String>,
}
206
207impl ComputeProgram {
208    /// Compile and link a compute shader from source.
209    pub fn compile(
210        gl: &glow::Context,
211        source: &ShaderSource,
212    ) -> Result<Self, String> {
213        use glow::HasContext;
214        let assembled = source.assemble();
215        let cache_key = source.cache_key();
216
217        unsafe {
218            let shader = gl
219                .create_shader(GL_COMPUTE_SHADER)
220                .map_err(|e| format!("Failed to create compute shader: {}", e))?;
221
222            gl.shader_source(shader, &assembled);
223            gl.compile_shader(shader);
224
225            if !gl.get_shader_compile_status(shader) {
226                let log = gl.get_shader_info_log(shader);
227                gl.delete_shader(shader);
228                return Err(format!("Compute shader compilation failed:\n{}", log));
229            }
230
231            let program = gl
232                .create_program()
233                .map_err(|e| format!("Failed to create program: {}", e))?;
234
235            gl.attach_shader(program, shader);
236            gl.link_program(program);
237
238            if !gl.get_program_link_status(program) {
239                let log = gl.get_program_info_log(program);
240                gl.delete_program(program);
241                gl.delete_shader(shader);
242                return Err(format!("Compute program link failed:\n{}", log));
243            }
244
245            gl.detach_shader(program, shader);
246            gl.delete_shader(shader);
247
248            // Query local workgroup size
249            let local_size = Self::query_work_group_size(gl, program);
250
251            Ok(Self {
252                program,
253                cache_key,
254                local_size,
255                label: source.label().map(|s| s.to_string()),
256            })
257        }
258    }
259
260    /// Compile with specialization constants applied.
261    pub fn compile_specialized(
262        gl: &glow::Context,
263        source: &ShaderSource,
264        constants: &[SpecializationConstant],
265    ) -> Result<Self, String> {
266        let mut src = source.clone();
267        apply_specializations(&mut src, constants);
268        Self::compile(gl, &src)
269    }
270
271    /// Validate the program by checking the link status and info log.
272    pub fn validate(&self, gl: &glow::Context) -> Result<(), String> {
273        use glow::HasContext;
274        unsafe {
275            if !gl.get_program_link_status(self.program) {
276                let log = gl.get_program_info_log(self.program);
277                return Err(format!("Program validation failed:\n{}", log));
278            }
279        }
280        Ok(())
281    }
282
283    /// Use (bind) this program.
284    pub fn bind(&self, gl: &glow::Context) {
285        use glow::HasContext;
286        unsafe {
287            gl.use_program(Some(self.program));
288        }
289    }
290
291    /// Unbind the current program.
292    pub fn unbind(&self, gl: &glow::Context) {
293        use glow::HasContext;
294        unsafe {
295            gl.use_program(None);
296        }
297    }
298
299    /// Set a uniform int value.
300    pub fn set_uniform_int(&self, gl: &glow::Context, name: &str, value: i32) {
301        use glow::HasContext;
302        unsafe {
303            let loc = gl.get_uniform_location(self.program, name);
304            if let Some(loc) = loc {
305                gl.uniform_1_i32(Some(&loc), value);
306            }
307        }
308    }
309
310    /// Set a uniform uint value.
311    pub fn set_uniform_uint(&self, gl: &glow::Context, name: &str, value: u32) {
312        use glow::HasContext;
313        unsafe {
314            let loc = gl.get_uniform_location(self.program, name);
315            if let Some(loc) = loc {
316                gl.uniform_1_u32(Some(&loc), value);
317            }
318        }
319    }
320
321    /// Set a uniform float value.
322    pub fn set_uniform_float(&self, gl: &glow::Context, name: &str, value: f32) {
323        use glow::HasContext;
324        unsafe {
325            let loc = gl.get_uniform_location(self.program, name);
326            if let Some(loc) = loc {
327                gl.uniform_1_f32(Some(&loc), value);
328            }
329        }
330    }
331
332    /// Set a uniform vec2.
333    pub fn set_uniform_vec2(&self, gl: &glow::Context, name: &str, x: f32, y: f32) {
334        use glow::HasContext;
335        unsafe {
336            let loc = gl.get_uniform_location(self.program, name);
337            if let Some(loc) = loc {
338                gl.uniform_2_f32(Some(&loc), x, y);
339            }
340        }
341    }
342
343    /// Set a uniform vec3.
344    pub fn set_uniform_vec3(&self, gl: &glow::Context, name: &str, x: f32, y: f32, z: f32) {
345        use glow::HasContext;
346        unsafe {
347            let loc = gl.get_uniform_location(self.program, name);
348            if let Some(loc) = loc {
349                gl.uniform_3_f32(Some(&loc), x, y, z);
350            }
351        }
352    }
353
354    /// Set a uniform vec4.
355    pub fn set_uniform_vec4(
356        &self,
357        gl: &glow::Context,
358        name: &str,
359        x: f32,
360        y: f32,
361        z: f32,
362        w: f32,
363    ) {
364        use glow::HasContext;
365        unsafe {
366            let loc = gl.get_uniform_location(self.program, name);
367            if let Some(loc) = loc {
368                gl.uniform_4_f32(Some(&loc), x, y, z, w);
369            }
370        }
371    }
372
373    /// Set a uniform mat4 (column-major).
374    pub fn set_uniform_mat4(&self, gl: &glow::Context, name: &str, data: &[f32; 16]) {
375        use glow::HasContext;
376        unsafe {
377            let loc = gl.get_uniform_location(self.program, name);
378            if let Some(loc) = loc {
379                gl.uniform_matrix_4_f32_slice(Some(&loc), false, data);
380            }
381        }
382    }
383
384    /// Get the local workgroup size declared in the shader.
385    pub fn local_size(&self) -> [u32; 3] {
386        self.local_size
387    }
388
389    /// Get the cache key.
390    pub fn cache_key(&self) -> u64 {
391        self.cache_key
392    }
393
394    /// Get the label.
395    pub fn label(&self) -> Option<&str> {
396        self.label.as_deref()
397    }
398
399    /// Get the raw GL program.
400    pub fn raw_program(&self) -> glow::NativeProgram {
401        self.program
402    }
403
404    /// Destroy the program.
405    pub fn destroy(self, gl: &glow::Context) {
406        use glow::HasContext;
407        unsafe {
408            gl.delete_program(self.program);
409        }
410    }
411
412    /// Query the work group size from a linked compute program.
413    /// Falls back to [64, 1, 1] since glow does not directly expose
414    /// glGetProgramiv(GL_COMPUTE_WORK_GROUP_SIZE).
415    fn query_work_group_size(_gl: &glow::Context, _program: glow::NativeProgram) -> [u32; 3] {
416        // glow's abstraction does not expose glGetProgramiv for
417        // GL_COMPUTE_WORK_GROUP_SIZE directly. We use a sensible default
418        // and let callers override via WorkgroupSize.
419        [64, 1, 1]
420    }
421}
422
423// ---------------------------------------------------------------------------
424// WorkgroupSize
425// ---------------------------------------------------------------------------
426
/// Computes optimal workgroup sizes for dispatch, respecting hardware limits.
///
/// This is the *local* size (invocations per group); workgroup counts are
/// derived from it via the `dispatch_count_*` helpers.
#[derive(Debug, Clone, Copy)]
pub struct WorkgroupSize {
    /// Local size X.
    pub x: u32,
    /// Local size Y.
    pub y: u32,
    /// Local size Z.
    pub z: u32,
}
437
438impl WorkgroupSize {
439    /// Create a 1D workgroup size.
440    pub fn new_1d(x: u32) -> Self {
441        Self { x, y: 1, z: 1 }
442    }
443
444    /// Create a 2D workgroup size.
445    pub fn new_2d(x: u32, y: u32) -> Self {
446        Self { x, y, z: 1 }
447    }
448
449    /// Create a 3D workgroup size.
450    pub fn new_3d(x: u32, y: u32, z: u32) -> Self {
451        Self { x, y, z }
452    }
453
454    /// Total number of invocations per workgroup.
455    pub fn total_invocations(&self) -> u32 {
456        self.x * self.y * self.z
457    }
458
459    /// Auto-fit a 1D workgroup to a total element count, clamped to hardware max.
460    pub fn auto_fit_1d(total_elements: u32, max_invocations: u32) -> Self {
461        let size = total_elements.min(max_invocations).max(1);
462        // Round down to nearest power of 2 for efficiency
463        let size = Self::round_down_pow2(size);
464        Self::new_1d(size)
465    }
466
467    /// Auto-fit a 2D workgroup to a width x height, clamped to limits.
468    pub fn auto_fit_2d(width: u32, height: u32, max_invocations: u32) -> Self {
469        let mut sx = 8u32;
470        let mut sy = 8u32;
471        while sx * sy > max_invocations {
472            if sx > sy {
473                sx /= 2;
474            } else {
475                sy /= 2;
476            }
477        }
478        sx = sx.min(width).max(1);
479        sy = sy.min(height).max(1);
480        Self::new_2d(sx, sy)
481    }
482
483    /// Compute the number of workgroups needed to cover `total` elements in 1D.
484    pub fn dispatch_count_1d(&self, total: u32) -> u32 {
485        (total + self.x - 1) / self.x
486    }
487
488    /// Compute the number of workgroups needed to cover (width, height) in 2D.
489    pub fn dispatch_count_2d(&self, width: u32, height: u32) -> (u32, u32) {
490        ((width + self.x - 1) / self.x, (height + self.y - 1) / self.y)
491    }
492
493    /// Compute dispatch counts for 3D.
494    pub fn dispatch_count_3d(&self, w: u32, h: u32, d: u32) -> (u32, u32, u32) {
495        (
496            (w + self.x - 1) / self.x,
497            (h + self.y - 1) / self.y,
498            (d + self.z - 1) / self.z,
499        )
500    }
501
502    /// Round down to nearest power of 2.
503    fn round_down_pow2(v: u32) -> u32 {
504        if v == 0 {
505            return 1;
506        }
507        let mut r = v;
508        r |= r >> 1;
509        r |= r >> 2;
510        r |= r >> 4;
511        r |= r >> 8;
512        r |= r >> 16;
513        (r >> 1) + 1
514    }
515
516    /// Query hardware limits from GL context.
517    pub fn query_limits(gl: &glow::Context) -> WorkgroupLimits {
518        use glow::HasContext;
519        unsafe {
520            let max_invocations = gl.get_parameter_i32(0x90EB) as u32; // GL_MAX_COMPUTE_WORK_GROUP_INVOCATIONS
521            let max_x = gl.get_parameter_indexed_i32(0x91BE, 0) as u32; // GL_MAX_COMPUTE_WORK_GROUP_SIZE
522            let max_y = gl.get_parameter_indexed_i32(0x91BE, 1) as u32;
523            let max_z = gl.get_parameter_indexed_i32(0x91BE, 2) as u32;
524            let max_count_x = gl.get_parameter_indexed_i32(0x91BF, 0) as u32; // GL_MAX_COMPUTE_WORK_GROUP_COUNT
525            let max_count_y = gl.get_parameter_indexed_i32(0x91BF, 1) as u32;
526            let max_count_z = gl.get_parameter_indexed_i32(0x91BF, 2) as u32;
527            let max_shared = gl.get_parameter_i32(0x8262) as u32; // GL_MAX_COMPUTE_SHARED_MEMORY_SIZE
528            WorkgroupLimits {
529                max_invocations,
530                max_size: [max_x, max_y, max_z],
531                max_count: [max_count_x, max_count_y, max_count_z],
532                max_shared_memory: max_shared,
533            }
534        }
535    }
536}
537
/// Hardware workgroup limits queried from the GL context
/// (see `WorkgroupSize::query_limits`), with conservative `Default` values.
#[derive(Debug, Clone, Copy)]
pub struct WorkgroupLimits {
    /// Maximum total invocations per workgroup.
    pub max_invocations: u32,
    /// Maximum local_size in each dimension.
    pub max_size: [u32; 3],
    /// Maximum dispatch count in each dimension.
    pub max_count: [u32; 3],
    /// Maximum shared memory in bytes.
    pub max_shared_memory: u32,
}
550
551impl Default for WorkgroupLimits {
552    fn default() -> Self {
553        Self {
554            max_invocations: 1024,
555            max_size: [1024, 1024, 64],
556            max_count: [65535, 65535, 65535],
557            max_shared_memory: 49152,
558        }
559    }
560}
561
562// ---------------------------------------------------------------------------
563// DispatchDimension
564// ---------------------------------------------------------------------------
565
/// Describes the dispatch dimensionality (workgroup *counts*, not local
/// sizes); missing dimensions are treated as 1.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DispatchDimension {
    /// 1D dispatch: (groups_x, 1, 1).
    D1(u32),
    /// 2D dispatch: (groups_x, groups_y, 1).
    D2(u32, u32),
    /// 3D dispatch: (groups_x, groups_y, groups_z).
    D3(u32, u32, u32),
}
576
577impl DispatchDimension {
578    /// Total number of workgroups.
579    pub fn total_groups(&self) -> u64 {
580        match *self {
581            DispatchDimension::D1(x) => x as u64,
582            DispatchDimension::D2(x, y) => x as u64 * y as u64,
583            DispatchDimension::D3(x, y, z) => x as u64 * y as u64 * z as u64,
584        }
585    }
586
587    /// Unpack to (x, y, z).
588    pub fn as_tuple(&self) -> (u32, u32, u32) {
589        match *self {
590            DispatchDimension::D1(x) => (x, 1, 1),
591            DispatchDimension::D2(x, y) => (x, y, 1),
592            DispatchDimension::D3(x, y, z) => (x, y, z),
593        }
594    }
595}
596
597// ---------------------------------------------------------------------------
598// IndirectDispatchArgs
599// ---------------------------------------------------------------------------
600
/// Arguments for an indirect dispatch (read from a buffer on the GPU).
///
/// `#[repr(C)]` so the three consecutive `u32` counts match the layout
/// glDispatchComputeIndirect reads from the bound buffer.
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct IndirectDispatchArgs {
    /// Number of workgroups along X.
    pub num_groups_x: u32,
    /// Number of workgroups along Y.
    pub num_groups_y: u32,
    /// Number of workgroups along Z.
    pub num_groups_z: u32,
}
609
610impl IndirectDispatchArgs {
611    pub fn new(x: u32, y: u32, z: u32) -> Self {
612        Self {
613            num_groups_x: x,
614            num_groups_y: y,
615            num_groups_z: z,
616        }
617    }
618}
619
620// ---------------------------------------------------------------------------
621// ComputeDispatch
622// ---------------------------------------------------------------------------
623
/// Executes compute shader dispatches.
///
/// By default a shader-storage memory barrier is issued after each
/// dispatch; this can be disabled or retargeted via the setters.
pub struct ComputeDispatch {
    /// Default barrier bits to issue after each dispatch.
    default_barrier: u32,
    /// Whether to automatically issue a barrier after dispatch.
    auto_barrier: bool,
}
631
632impl ComputeDispatch {
633    /// Create a new dispatcher.
634    pub fn new() -> Self {
635        Self {
636            default_barrier: GL_SHADER_STORAGE_BARRIER_BIT,
637            auto_barrier: true,
638        }
639    }
640
641    /// Set whether barriers are automatically issued after dispatch.
642    pub fn set_auto_barrier(&mut self, enabled: bool) {
643        self.auto_barrier = enabled;
644    }
645
646    /// Set the default barrier bits.
647    pub fn set_default_barrier(&mut self, bits: u32) {
648        self.default_barrier = bits;
649    }
650
651    /// Dispatch a compute shader with the given program and dimensions.
652    pub fn dispatch(
653        &self,
654        gl: &glow::Context,
655        program: &ComputeProgram,
656        dim: DispatchDimension,
657    ) {
658        use glow::HasContext;
659        program.bind(gl);
660        let (x, y, z) = dim.as_tuple();
661        unsafe {
662            gl.dispatch_compute(x, y, z);
663        }
664        if self.auto_barrier {
665            unsafe {
666                gl.memory_barrier(self.default_barrier);
667            }
668        }
669    }
670
671    /// Dispatch 1D: convenience for dispatching over N elements.
672    pub fn dispatch_1d(
673        &self,
674        gl: &glow::Context,
675        program: &ComputeProgram,
676        total_elements: u32,
677        local_size_x: u32,
678    ) {
679        let groups = (total_elements + local_size_x - 1) / local_size_x;
680        self.dispatch(gl, program, DispatchDimension::D1(groups));
681    }
682
683    /// Dispatch 2D: convenience for dispatching over a width x height grid.
684    pub fn dispatch_2d(
685        &self,
686        gl: &glow::Context,
687        program: &ComputeProgram,
688        width: u32,
689        height: u32,
690        local_size: WorkgroupSize,
691    ) {
692        let gx = (width + local_size.x - 1) / local_size.x;
693        let gy = (height + local_size.y - 1) / local_size.y;
694        self.dispatch(gl, program, DispatchDimension::D2(gx, gy));
695    }
696
697    /// Dispatch 3D.
698    pub fn dispatch_3d(
699        &self,
700        gl: &glow::Context,
701        program: &ComputeProgram,
702        w: u32,
703        h: u32,
704        d: u32,
705        local_size: WorkgroupSize,
706    ) {
707        let gx = (w + local_size.x - 1) / local_size.x;
708        let gy = (h + local_size.y - 1) / local_size.y;
709        let gz = (d + local_size.z - 1) / local_size.z;
710        self.dispatch(gl, program, DispatchDimension::D3(gx, gy, gz));
711    }
712
713    /// Indirect dispatch: read dispatch arguments from a buffer on the GPU.
714    pub fn dispatch_indirect(
715        &self,
716        gl: &glow::Context,
717        program: &ComputeProgram,
718        buffer: super::buffer::BufferHandle,
719        offset: usize,
720    ) {
721        use glow::HasContext;
722        program.bind(gl);
723        let buf = glow::NativeBuffer(std::num::NonZeroU32::new(buffer.raw).unwrap());
724        unsafe {
725            gl.bind_buffer(GL_DISPATCH_INDIRECT_BUFFER, Some(buf));
726            gl.dispatch_compute_indirect(offset as i32);
727            gl.bind_buffer(GL_DISPATCH_INDIRECT_BUFFER, None);
728        }
729        if self.auto_barrier {
730            unsafe {
731                gl.memory_barrier(self.default_barrier);
732            }
733        }
734    }
735
736    /// Dispatch with an explicit barrier type (overrides auto).
737    pub fn dispatch_with_barrier(
738        &self,
739        gl: &glow::Context,
740        program: &ComputeProgram,
741        dim: DispatchDimension,
742        barrier: super::buffer::BufferBarrierType,
743    ) {
744        use glow::HasContext;
745        program.bind(gl);
746        let (x, y, z) = dim.as_tuple();
747        unsafe {
748            gl.dispatch_compute(x, y, z);
749            gl.memory_barrier(barrier.to_gl_bits());
750        }
751    }
752
753    /// Dispatch multiple passes of the same program with different dimensions.
754    pub fn dispatch_multi(
755        &self,
756        gl: &glow::Context,
757        program: &ComputeProgram,
758        dims: &[DispatchDimension],
759    ) {
760        use glow::HasContext;
761        program.bind(gl);
762        for dim in dims {
763            let (x, y, z) = dim.as_tuple();
764            unsafe {
765                gl.dispatch_compute(x, y, z);
766                if self.auto_barrier {
767                    gl.memory_barrier(self.default_barrier);
768                }
769            }
770        }
771    }
772}
773
774impl Default for ComputeDispatch {
775    fn default() -> Self {
776        Self::new()
777    }
778}
779
780// ---------------------------------------------------------------------------
781// PipelineCache
782// ---------------------------------------------------------------------------
783
/// Caches compiled compute programs by their assembled-source hash.
///
/// Entries own GL program objects; call `clear`/`destroy` with a live
/// context to release them.
pub struct PipelineCache {
    /// Internal cache map, keyed by `ShaderSource::cache_key()`.
    pub(crate) cache: HashMap<u64, ComputeProgram>,
}
789
790impl PipelineCache {
791    /// Create a new empty cache.
792    pub fn new() -> Self {
793        Self {
794            cache: HashMap::new(),
795        }
796    }
797
798    /// Get or compile a program. If already cached, returns a reference.
799    pub fn get_or_compile(
800        &mut self,
801        gl: &glow::Context,
802        source: &ShaderSource,
803    ) -> Result<&ComputeProgram, String> {
804        let key = source.cache_key();
805        if !self.cache.contains_key(&key) {
806            let program = ComputeProgram::compile(gl, source)?;
807            self.cache.insert(key, program);
808        }
809        Ok(self.cache.get(&key).unwrap())
810    }
811
812    /// Get or compile with specialization constants.
813    pub fn get_or_compile_specialized(
814        &mut self,
815        gl: &glow::Context,
816        source: &ShaderSource,
817        constants: &[SpecializationConstant],
818    ) -> Result<&ComputeProgram, String> {
819        let mut src = source.clone();
820        apply_specializations(&mut src, constants);
821        let key = src.cache_key();
822        if !self.cache.contains_key(&key) {
823            let program = ComputeProgram::compile(gl, &src)?;
824            self.cache.insert(key, program);
825        }
826        Ok(self.cache.get(&key).unwrap())
827    }
828
829    /// Check if a program is cached.
830    pub fn contains(&self, source: &ShaderSource) -> bool {
831        self.cache.contains_key(&source.cache_key())
832    }
833
834    /// Number of cached programs.
835    pub fn len(&self) -> usize {
836        self.cache.len()
837    }
838
839    /// Whether the cache is empty.
840    pub fn is_empty(&self) -> bool {
841        self.cache.is_empty()
842    }
843
844    /// Evict a specific entry.
845    pub fn evict(&mut self, gl: &glow::Context, source: &ShaderSource) {
846        let key = source.cache_key();
847        if let Some(prog) = self.cache.remove(&key) {
848            prog.destroy(gl);
849        }
850    }
851
852    /// Clear the entire cache, deleting all programs.
853    pub fn clear(&mut self, gl: &glow::Context) {
854        for (_key, prog) in self.cache.drain() {
855            prog.destroy(gl);
856        }
857    }
858
859    /// Destroy the cache.
860    pub fn destroy(mut self, gl: &glow::Context) {
861        self.clear(gl);
862    }
863}
864
865impl Default for PipelineCache {
866    fn default() -> Self {
867        Self::new()
868    }
869}
870
871// ---------------------------------------------------------------------------
872// TimingQuery
873// ---------------------------------------------------------------------------
874
/// A GPU timer query for measuring dispatch duration.
pub struct TimingQuery {
    /// Underlying GL query object.
    query: glow::NativeQuery,
    /// True between `begin` and `end`.
    active: bool,
    /// Most recently fetched elapsed time, in nanoseconds.
    last_result_ns: u64,
}
881
882impl TimingQuery {
883    /// Create a new timing query.
884    pub fn create(gl: &glow::Context) -> Self {
885        use glow::HasContext;
886        let query = unsafe {
887            gl.create_query().expect("Failed to create timer query")
888        };
889        Self {
890            query,
891            active: false,
892            last_result_ns: 0,
893        }
894    }
895
896    /// Begin the timer query.
897    pub fn begin(&mut self, gl: &glow::Context) {
898        use glow::HasContext;
899        unsafe {
900            gl.begin_query(GL_TIME_ELAPSED, self.query);
901        }
902        self.active = true;
903    }
904
905    /// End the timer query.
906    pub fn end(&mut self, gl: &glow::Context) {
907        use glow::HasContext;
908        unsafe {
909            gl.end_query(GL_TIME_ELAPSED);
910        }
911        self.active = false;
912    }
913
914    /// Check if the result is available (non-blocking).
915    pub fn is_available(&self, gl: &glow::Context) -> bool {
916        use glow::HasContext;
917        unsafe {
918            let available = gl.get_query_parameter_u32(self.query, GL_QUERY_RESULT_AVAILABLE);
919            available != 0
920        }
921    }
922
923    /// Retrieve the elapsed time in nanoseconds (blocks until available).
924    pub fn result_ns(&mut self, gl: &glow::Context) -> u64 {
925        use glow::HasContext;
926        let ns = unsafe { gl.get_query_parameter_u32(self.query, GL_QUERY_RESULT) };
927        self.last_result_ns = ns as u64;
928        self.last_result_ns
929    }
930
931    /// Get the last retrieved result without re-querying.
932    pub fn last_result_ns(&self) -> u64 {
933        self.last_result_ns
934    }
935
936    /// Last result in milliseconds.
937    pub fn last_result_ms(&self) -> f64 {
938        self.last_result_ns as f64 / 1_000_000.0
939    }
940
941    /// Whether a query is currently active (between begin/end).
942    pub fn is_active(&self) -> bool {
943        self.active
944    }
945
946    /// Destroy the query.
947    pub fn destroy(self, gl: &glow::Context) {
948        use glow::HasContext;
949        unsafe {
950            gl.delete_query(self.query);
951        }
952    }
953}
954
955// ---------------------------------------------------------------------------
956// ComputeProfiler
957// ---------------------------------------------------------------------------
958
/// Profiles compute dispatches with per-dispatch GPU timing.
pub struct ComputeProfiler {
    /// One reusable timer query per dispatch name, created lazily.
    queries: HashMap<String, TimingQuery>,
    /// Whether profiling is enabled; when false, begin/end are no-ops.
    enabled: bool,
    /// History of frame timings (dispatch_name -> Vec of ms values).
    history: HashMap<String, Vec<f64>>,
    /// Maximum history length per dispatch.
    max_history: usize,
}
970
971impl ComputeProfiler {
972    /// Create a new profiler.
973    pub fn new(max_history: usize) -> Self {
974        Self {
975            queries: HashMap::new(),
976            enabled: true,
977            history: HashMap::new(),
978            max_history,
979        }
980    }
981
982    /// Enable or disable profiling.
983    pub fn set_enabled(&mut self, enabled: bool) {
984        self.enabled = enabled;
985    }
986
987    /// Whether profiling is enabled.
988    pub fn is_enabled(&self) -> bool {
989        self.enabled
990    }
991
992    /// Begin timing a named dispatch.
993    pub fn begin(&mut self, gl: &glow::Context, name: &str) {
994        if !self.enabled {
995            return;
996        }
997        if !self.queries.contains_key(name) {
998            self.queries
999                .insert(name.to_string(), TimingQuery::create(gl));
1000        }
1001        if let Some(q) = self.queries.get_mut(name) {
1002            q.begin(gl);
1003        }
1004    }
1005
1006    /// End timing a named dispatch.
1007    pub fn end(&mut self, gl: &glow::Context, name: &str) {
1008        if !self.enabled {
1009            return;
1010        }
1011        if let Some(q) = self.queries.get_mut(name) {
1012            q.end(gl);
1013        }
1014    }
1015
1016    /// Collect results for all completed queries.
1017    pub fn collect_results(&mut self, gl: &glow::Context) {
1018        if !self.enabled {
1019            return;
1020        }
1021        let names: Vec<String> = self.queries.keys().cloned().collect();
1022        for name in names {
1023            if let Some(q) = self.queries.get_mut(&name) {
1024                if !q.is_active() && q.is_available(gl) {
1025                    let ns = q.result_ns(gl);
1026                    let ms = ns as f64 / 1_000_000.0;
1027                    let hist = self.history.entry(name).or_insert_with(Vec::new);
1028                    hist.push(ms);
1029                    if hist.len() > self.max_history {
1030                        hist.remove(0);
1031                    }
1032                }
1033            }
1034        }
1035    }
1036
1037    /// Get the last timing for a named dispatch in milliseconds.
1038    pub fn last_ms(&self, name: &str) -> Option<f64> {
1039        self.queries.get(name).map(|q| q.last_result_ms())
1040    }
1041
1042    /// Get the average timing for a named dispatch over the history window.
1043    pub fn average_ms(&self, name: &str) -> Option<f64> {
1044        self.history.get(name).and_then(|h| {
1045            if h.is_empty() {
1046                None
1047            } else {
1048                Some(h.iter().sum::<f64>() / h.len() as f64)
1049            }
1050        })
1051    }
1052
1053    /// Get the min/max timing for a named dispatch.
1054    pub fn min_max_ms(&self, name: &str) -> Option<(f64, f64)> {
1055        self.history.get(name).and_then(|h| {
1056            if h.is_empty() {
1057                None
1058            } else {
1059                let min = h.iter().cloned().fold(f64::MAX, f64::min);
1060                let max = h.iter().cloned().fold(f64::MIN, f64::max);
1061                Some((min, max))
1062            }
1063        })
1064    }
1065
1066    /// Get all dispatch names that have been profiled.
1067    pub fn dispatch_names(&self) -> Vec<&str> {
1068        self.queries.keys().map(|s| s.as_str()).collect()
1069    }
1070
1071    /// Print a summary of all profiled dispatches.
1072    pub fn summary(&self) -> String {
1073        let mut s = String::from("=== Compute Profiler Summary ===\n");
1074        let mut names: Vec<&str> = self.dispatch_names();
1075        names.sort();
1076        for name in names {
1077            let avg = self.average_ms(name).unwrap_or(0.0);
1078            let (min, max) = self.min_max_ms(name).unwrap_or((0.0, 0.0));
1079            let last = self.last_ms(name).unwrap_or(0.0);
1080            s.push_str(&format!(
1081                "  {}: last={:.3}ms avg={:.3}ms min={:.3}ms max={:.3}ms\n",
1082                name, last, avg, min, max
1083            ));
1084        }
1085        s
1086    }
1087
1088    /// Reset all history.
1089    pub fn reset_history(&mut self) {
1090        self.history.clear();
1091    }
1092
1093    /// Destroy all queries.
1094    pub fn destroy(self, gl: &glow::Context) {
1095        for (_name, query) in self.queries {
1096            query.destroy(gl);
1097        }
1098    }
1099}
1100
1101// ---------------------------------------------------------------------------
1102// PipelineState — immutable snapshot for caching dispatch configurations
1103// ---------------------------------------------------------------------------
1104
/// Snapshot of the state needed for a compute dispatch.
///
/// Built up via `bind_ssbo` / `set_uniform` / `set_barrier`, then replayed
/// against a `PipelineCache` with `execute`.
#[derive(Debug, Clone)]
pub struct PipelineState {
    /// Key of the compiled program in the `PipelineCache`.
    pub program_key: u64,
    /// Dispatch dimension (workgroup counts fed to `dispatch_compute`).
    pub dimension: DispatchDimension,
    /// Memory-barrier bits issued after dispatch
    /// (defaults to `GL_SHADER_STORAGE_BARRIER_BIT`).
    pub barrier_bits: u32,
    /// SSBO bindings: (binding_index, buffer_raw_id). A raw id of 0 is
    /// skipped at execute time (it is GL's "no buffer" name).
    pub ssbo_bindings: Vec<(u32, u32)>,
    /// Uniform values applied before the dispatch.
    pub uniforms: Vec<UniformValue>,
}
1119
/// A uniform value to set before dispatch; the `String` is the uniform's name
/// in the shader.
#[derive(Debug, Clone)]
pub enum UniformValue {
    /// GLSL `int`.
    Int(String, i32),
    /// GLSL `uint`.
    Uint(String, u32),
    /// GLSL `float`.
    Float(String, f32),
    /// GLSL `vec2`.
    Vec2(String, f32, f32),
    /// GLSL `vec3`.
    Vec3(String, f32, f32, f32),
    /// GLSL `vec4`.
    Vec4(String, f32, f32, f32, f32),
}
1130
1131impl PipelineState {
1132    /// Create a new pipeline state.
1133    pub fn new(program_key: u64, dimension: DispatchDimension) -> Self {
1134        Self {
1135            program_key,
1136            dimension,
1137            barrier_bits: GL_SHADER_STORAGE_BARRIER_BIT,
1138            ssbo_bindings: Vec::new(),
1139            uniforms: Vec::new(),
1140        }
1141    }
1142
1143    /// Add an SSBO binding.
1144    pub fn bind_ssbo(&mut self, binding: u32, buffer_raw: u32) -> &mut Self {
1145        self.ssbo_bindings.push((binding, buffer_raw));
1146        self
1147    }
1148
1149    /// Add a uniform.
1150    pub fn set_uniform(&mut self, value: UniformValue) -> &mut Self {
1151        self.uniforms.push(value);
1152        self
1153    }
1154
1155    /// Set barrier bits.
1156    pub fn set_barrier(&mut self, bits: u32) -> &mut Self {
1157        self.barrier_bits = bits;
1158        self
1159    }
1160
1161    /// Execute this pipeline state: bind SSBOs, set uniforms, dispatch.
1162    pub fn execute(
1163        &self,
1164        gl: &glow::Context,
1165        cache: &PipelineCache,
1166    ) {
1167        use glow::HasContext;
1168        // Find program in cache
1169        let program = match cache.cache.get(&self.program_key) {
1170            Some(p) => p,
1171            None => return, // Program not found
1172        };
1173
1174        program.bind(gl);
1175
1176        // Bind SSBOs
1177        for &(binding, raw) in &self.ssbo_bindings {
1178            if let Some(nz) = std::num::NonZeroU32::new(raw) {
1179                let buf = glow::NativeBuffer(nz);
1180                unsafe {
1181                    gl.bind_buffer_base(0x90D2, binding, Some(buf)); // GL_SHADER_STORAGE_BUFFER
1182                }
1183            }
1184        }
1185
1186        // Set uniforms
1187        for u in &self.uniforms {
1188            match u {
1189                UniformValue::Int(name, v) => program.set_uniform_int(gl, name, *v),
1190                UniformValue::Uint(name, v) => program.set_uniform_uint(gl, name, *v),
1191                UniformValue::Float(name, v) => program.set_uniform_float(gl, name, *v),
1192                UniformValue::Vec2(name, x, y) => program.set_uniform_vec2(gl, name, *x, *y),
1193                UniformValue::Vec3(name, x, y, z) => {
1194                    program.set_uniform_vec3(gl, name, *x, *y, *z)
1195                }
1196                UniformValue::Vec4(name, x, y, z, w) => {
1197                    program.set_uniform_vec4(gl, name, *x, *y, *z, *w)
1198                }
1199            }
1200        }
1201
1202        // Dispatch
1203        let (gx, gy, gz) = self.dimension.as_tuple();
1204        unsafe {
1205            gl.dispatch_compute(gx, gy, gz);
1206            gl.memory_barrier(self.barrier_bits);
1207        }
1208    }
1209}
1210
1211// ---------------------------------------------------------------------------
1212// ComputeChain — chain multiple dispatches
1213// ---------------------------------------------------------------------------
1214
/// A chain of compute dispatches to be executed in sequence with barriers between them.
///
/// Steps run in insertion order; each step's barrier (when non-zero) is issued
/// immediately after its dispatch.
pub struct ComputeChain {
    /// Ordered list of dispatch steps.
    steps: Vec<ChainStep>,
}
1219
/// A single step in a compute chain.
pub struct ChainStep {
    /// Key of the program to dispatch (looked up in the `PipelineCache`).
    pub program_key: u64,
    /// Dispatch dimensions (workgroup counts).
    pub dimension: DispatchDimension,
    /// Uniforms to set for this step.
    pub uniforms: Vec<UniformValue>,
    /// Barrier bits after this step (0 = no barrier).
    pub barrier_bits: u32,
}
1231
1232impl ComputeChain {
1233    /// Create a new empty chain.
1234    pub fn new() -> Self {
1235        Self { steps: Vec::new() }
1236    }
1237
1238    /// Add a step.
1239    pub fn add_step(&mut self, step: ChainStep) -> &mut Self {
1240        self.steps.push(step);
1241        self
1242    }
1243
1244    /// Number of steps.
1245    pub fn len(&self) -> usize {
1246        self.steps.len()
1247    }
1248
1249    /// Whether the chain is empty.
1250    pub fn is_empty(&self) -> bool {
1251        self.steps.is_empty()
1252    }
1253
1254    /// Execute the entire chain.
1255    pub fn execute(&self, gl: &glow::Context, cache: &PipelineCache) {
1256        use glow::HasContext;
1257        for step in &self.steps {
1258            if let Some(program) = cache.cache.get(&step.program_key) {
1259                program.bind(gl);
1260                for u in &step.uniforms {
1261                    match u {
1262                        UniformValue::Int(name, v) => program.set_uniform_int(gl, name, *v),
1263                        UniformValue::Uint(name, v) => program.set_uniform_uint(gl, name, *v),
1264                        UniformValue::Float(name, v) => program.set_uniform_float(gl, name, *v),
1265                        UniformValue::Vec2(name, x, y) => {
1266                            program.set_uniform_vec2(gl, name, *x, *y)
1267                        }
1268                        UniformValue::Vec3(name, x, y, z) => {
1269                            program.set_uniform_vec3(gl, name, *x, *y, *z)
1270                        }
1271                        UniformValue::Vec4(name, x, y, z, w) => {
1272                            program.set_uniform_vec4(gl, name, *x, *y, *z, *w)
1273                        }
1274                    }
1275                }
1276                let (gx, gy, gz) = step.dimension.as_tuple();
1277                unsafe {
1278                    gl.dispatch_compute(gx, gy, gz);
1279                    if step.barrier_bits != 0 {
1280                        gl.memory_barrier(step.barrier_bits);
1281                    }
1282                }
1283            }
1284        }
1285    }
1286}
1287
impl Default for ComputeChain {
    /// Equivalent to [`ComputeChain::new`]: an empty chain.
    fn default() -> Self {
        Self::new()
    }
}
1293
1294// ---------------------------------------------------------------------------
1295// ShaderPreprocessor — handle #include directives
1296// ---------------------------------------------------------------------------
1297
/// Simple shader preprocessor that resolves `#include "name"` directives
/// from a registered library of snippets.
///
/// Includes are resolved recursively up to `MAX_INCLUDE_DEPTH` levels, so
/// snippets may include other snippets; the depth limit also terminates
/// include cycles.
pub struct ShaderPreprocessor {
    /// Registered snippets, keyed by include name.
    snippets: HashMap<String, String>,
}

impl ShaderPreprocessor {
    /// Maximum nesting depth for `#include` resolution. Deeper (or cyclic)
    /// includes are left in the output as `// UNRESOLVED:` comments.
    const MAX_INCLUDE_DEPTH: usize = 8;

    /// Create a new preprocessor with no registered snippets.
    pub fn new() -> Self {
        Self {
            snippets: HashMap::new(),
        }
    }

    /// Register a named snippet. Re-registering a name replaces the old snippet.
    pub fn register(&mut self, name: &str, source: &str) {
        self.snippets.insert(name.to_string(), source.to_string());
    }

    /// Process a shader source, resolving `#include "name"` directives.
    ///
    /// Included snippets are themselves processed, so nested includes resolve
    /// too. Any include that cannot be resolved (unknown name, malformed
    /// quotes, or depth limit reached) is kept in the output as a
    /// `// UNRESOLVED:` comment so the GLSL compiler gives a clear diagnostic
    /// instead of choking on a stray `#include`.
    pub fn process(&self, source: &str) -> String {
        self.process_at_depth(source, 0)
    }

    /// Recursive worker for `process`; `depth` guards against include cycles.
    fn process_at_depth(&self, source: &str, depth: usize) -> String {
        let mut result = String::with_capacity(source.len());
        for line in source.lines() {
            let trimmed = line.trim();
            if trimmed.starts_with("#include") {
                let snippet = Self::include_name(trimmed)
                    .filter(|_| depth < Self::MAX_INCLUDE_DEPTH)
                    .and_then(|name| self.snippets.get(name));
                if let Some(snippet) = snippet {
                    // Recurse so snippets may include other snippets.
                    result.push_str(&self.process_at_depth(snippet, depth + 1));
                } else {
                    // Unresolved — keep the original line, commented out.
                    result.push_str("// UNRESOLVED: ");
                    result.push_str(line);
                    result.push('\n');
                }
            } else {
                result.push_str(line);
                result.push('\n');
            }
        }
        result
    }

    /// Extract the quoted name from an `#include "name"` line, if well-formed.
    fn include_name(line: &str) -> Option<&str> {
        let start = line.find('"')? + 1;
        let len = line[start..].find('"')?;
        Some(&line[start..start + len])
    }
}
1346
impl Default for ShaderPreprocessor {
    /// Equivalent to [`ShaderPreprocessor::new`]: no registered snippets.
    fn default() -> Self {
        Self::new()
    }
}