Skip to main content

proof_engine/wgpu_backend/
compute.rs

1//! Compute shader support: storage buffers, dispatch, CPU fallback, profiling.
2
3use std::collections::HashMap;
4use std::time::{Duration, Instant};
5
6use super::backend::{
7    BackendCapabilities, BackendContext, BufferHandle, BufferUsage, ComputePipelineHandle,
8    GpuBackend, GpuCommand, PipelineLayout, ShaderHandle, ShaderStage, SoftwareContext,
9};
10
11// ---------------------------------------------------------------------------
12// Access mode
13// ---------------------------------------------------------------------------
14
/// How a compute shader is allowed to touch a bound resource.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AccessMode {
    /// The shader only reads from the binding.
    ReadOnly,
    /// The shader only writes to the binding.
    WriteOnly,
    /// The shader both reads from and writes to the binding.
    ReadWrite,
}
22
23// ---------------------------------------------------------------------------
24// Bind group layout
25// ---------------------------------------------------------------------------
26
27/// Describes one entry in a bind group.
28#[derive(Debug, Clone)]
29pub struct BindGroupEntry {
30    pub binding: u32,
31    pub buffer_or_texture: BindingResource,
32    pub access: AccessMode,
33}
34
35/// What is bound at a given slot.
36#[derive(Debug, Clone)]
37pub enum BindingResource {
38    Buffer(BufferHandle),
39    Texture(super::backend::TextureHandle),
40}
41
42/// Layout of a bind group (the descriptor, not the actual bound resources).
43#[derive(Debug, Clone)]
44pub struct BindGroupLayout {
45    pub entries: Vec<BindGroupLayoutEntry>,
46}
47
48impl BindGroupLayout {
49    pub fn new() -> Self { Self { entries: Vec::new() } }
50
51    pub fn push(mut self, binding: u32, access: AccessMode) -> Self {
52        self.entries.push(BindGroupLayoutEntry { binding, access });
53        self
54    }
55}
56
57impl Default for BindGroupLayout {
58    fn default() -> Self { Self::new() }
59}
60
61/// One entry in a bind-group layout descriptor.
62#[derive(Debug, Clone)]
63pub struct BindGroupLayoutEntry {
64    pub binding: u32,
65    pub access: AccessMode,
66}
67
68// ---------------------------------------------------------------------------
69// ComputePipeline
70// ---------------------------------------------------------------------------
71
72/// A compute pipeline ready for dispatch.
73#[derive(Debug, Clone)]
74pub struct ComputePipeline {
75    pub shader: ShaderHandle,
76    pub bind_group_layout: BindGroupLayout,
77    pub workgroup_size: [u32; 3],
78    pub handle: ComputePipelineHandle,
79}
80
81// ---------------------------------------------------------------------------
82// ComputeBuffer
83// ---------------------------------------------------------------------------
84
85/// A buffer suitable for compute shader storage.
86#[derive(Debug, Clone)]
87pub struct ComputeBuffer {
88    pub handle: BufferHandle,
89    pub size: usize,
90    pub element_size: usize,
91}
92
93impl ComputeBuffer {
94    /// Number of elements this buffer can hold.
95    pub fn element_count(&self) -> usize {
96        if self.element_size == 0 { 0 } else { self.size / self.element_size }
97    }
98}
99
100// ---------------------------------------------------------------------------
101// ComputeProfiler
102// ---------------------------------------------------------------------------
103
/// Tracks timing information for compute dispatches.
///
/// Acts as a rolling window: at most `max_records` of the most recent
/// dispatches are kept, and the oldest one is evicted when the window is full.
#[derive(Debug, Clone)]
pub struct ComputeProfiler {
    /// Recorded dispatches, oldest first.
    records: Vec<DispatchRecord>,
    /// Upper bound on `records.len()`.
    max_records: usize,
}

/// Timing sample for a single dispatch.
#[derive(Debug, Clone)]
pub struct DispatchRecord {
    /// Caller-supplied name of the dispatch.
    pub label: String,
    /// Workgroup counts `[x, y, z]` of the dispatch.
    pub workgroups: [u32; 3],
    /// Measured duration of the dispatch.
    pub duration: Duration,
}

impl ComputeProfiler {
    /// Creates a profiler that retains at most `max_records` samples.
    ///
    /// The initial allocation is capped at 4096 entries so that a very large
    /// limit does not reserve memory up front.
    pub fn new(max_records: usize) -> Self {
        Self {
            records: Vec::with_capacity(max_records.min(4096)),
            max_records,
        }
    }

    /// Appends a sample, evicting the oldest one if the window is full.
    ///
    /// With `max_records == 0` the sample is discarded.
    pub fn record(&mut self, label: &str, workgroups: [u32; 3], duration: Duration) {
        // A zero-capacity window keeps nothing. Without this guard the
        // eviction below would call `remove(0)` on an empty Vec and panic.
        if self.max_records == 0 {
            return;
        }
        if self.records.len() >= self.max_records {
            // O(n) shift, acceptable for a bounded window; a Vec is kept so
            // `records()` can hand out a contiguous slice.
            self.records.remove(0);
        }
        self.records.push(DispatchRecord {
            label: label.to_string(),
            workgroups,
            duration,
        });
    }

    /// Mean duration of the retained samples, or `Duration::ZERO` when there
    /// are none.
    pub fn average_duration(&self) -> Duration {
        if self.records.is_empty() {
            return Duration::ZERO;
        }
        let total: Duration = self.records.iter().map(|r| r.duration).sum();
        // Clamp instead of truncating: a plain `as u32` could wrap to 0 for
        // enormous windows and turn the division into a panic.
        let count = self.records.len().min(u32::MAX as usize) as u32;
        total / count
    }

    /// Number of samples currently retained (not the lifetime total).
    pub fn total_dispatches(&self) -> usize {
        self.records.len()
    }

    /// Drops every retained sample.
    pub fn clear(&mut self) {
        self.records.clear();
    }

    /// Most recently recorded sample, if any.
    pub fn last(&self) -> Option<&DispatchRecord> {
        self.records.last()
    }

    /// All retained samples, oldest first.
    pub fn records(&self) -> &[DispatchRecord] {
        &self.records
    }
}
160
161// ---------------------------------------------------------------------------
162// CPU fallback kernel
163// ---------------------------------------------------------------------------
164
/// A CPU compute kernel: runs a closure once for every invocation in the
/// workgroup grid.
///
/// Execution is strictly sequential on the calling thread; see
/// `cpu_parallel_dispatch` for the multi-threaded variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CpuKernel {
    /// Local workgroup dimensions `[x, y, z]`.
    pub workgroup_size: [u32; 3],
}

impl CpuKernel {
    /// Creates a kernel with the given local workgroup dimensions.
    pub fn new(workgroup_size: [u32; 3]) -> Self {
        Self { workgroup_size }
    }

    /// Dispatch the kernel on CPU.  Calls `f(global_x, global_y, global_z)`
    /// for every invocation.
    ///
    /// Workgroups are visited with x fastest, then y, then z; inside each
    /// workgroup the local ids follow the same order. A zero in any dimension
    /// of `groups` or of the workgroup size results in no calls.
    ///
    /// Global ids are computed in `u32` (`group * size + local`), so callers
    /// must keep `groups[i] * workgroup_size[i]` within `u32` range.
    pub fn dispatch<F>(&self, groups: [u32; 3], mut f: F)
    where
        F: FnMut(u32, u32, u32),
    {
        let [sx, sy, sz] = self.workgroup_size;
        let [gx, gy, gz] = groups;
        for gz_i in 0..gz {
            for gy_i in 0..gy {
                for gx_i in 0..gx {
                    for lz in 0..sz {
                        for ly in 0..sy {
                            for lx in 0..sx {
                                let x = gx_i * sx + lx;
                                let y = gy_i * sy + ly;
                                let z = gz_i * sz + lz;
                                f(x, y, z);
                            }
                        }
                    }
                }
            }
        }
    }

    /// Total number of invocations for the given dispatch size.
    ///
    /// The product of six `u32` values can exceed `u64`; in that case the
    /// result saturates at `u64::MAX` instead of overflowing.
    pub fn total_invocations(&self, groups: [u32; 3]) -> u64 {
        self.workgroup_size
            .iter()
            .chain(groups.iter())
            .fold(1u64, |acc, &dim| acc.saturating_mul(u64::from(dim)))
    }
}
210
211// ---------------------------------------------------------------------------
212// ComputeContext
213// ---------------------------------------------------------------------------
214
215/// High-level compute context wrapping a backend.
216pub struct ComputeContext {
217    pub backend_type: GpuBackend,
218    backend: Box<dyn BackendContext>,
219    capabilities: BackendCapabilities,
220    profiler: ComputeProfiler,
221    pipelines: HashMap<u64, ComputePipeline>,
222}
223
224impl ComputeContext {
225    pub fn new(backend: Box<dyn BackendContext>, backend_type: GpuBackend) -> Self {
226        let capabilities = BackendCapabilities::for_backend(backend_type);
227        Self {
228            backend_type,
229            backend,
230            capabilities,
231            profiler: ComputeProfiler::new(1024),
232            pipelines: HashMap::new(),
233        }
234    }
235
236    /// Create a compute context with a software backend.
237    pub fn software() -> Self {
238        Self::new(Box::new(SoftwareContext::new()), GpuBackend::Software)
239    }
240
241    /// Create a typed storage buffer from a slice, copying element data.
242    pub fn create_storage_buffer<T: Copy>(&mut self, data: &[T]) -> ComputeBuffer {
243        let element_size = std::mem::size_of::<T>();
244        let byte_size = element_size * data.len();
245        let handle = self.backend.create_buffer(byte_size, BufferUsage::STORAGE);
246
247        // Copy data bytes
248        let byte_slice = unsafe {
249            std::slice::from_raw_parts(data.as_ptr() as *const u8, byte_size)
250        };
251        self.backend.write_buffer(handle, byte_slice);
252
253        ComputeBuffer {
254            handle,
255            size: byte_size,
256            element_size,
257        }
258    }
259
260    /// Create an empty storage buffer with room for `count` elements of type T.
261    pub fn create_empty_buffer<T>(&mut self, count: usize) -> ComputeBuffer {
262        let element_size = std::mem::size_of::<T>();
263        let byte_size = element_size * count;
264        let handle = self.backend.create_buffer(byte_size, BufferUsage::STORAGE);
265        ComputeBuffer {
266            handle,
267            size: byte_size,
268            element_size,
269        }
270    }
271
272    /// Create a compute pipeline.
273    pub fn create_pipeline(
274        &mut self,
275        source: &str,
276        layout: BindGroupLayout,
277        workgroup_size: [u32; 3],
278    ) -> ComputePipeline {
279        let shader = self.backend.create_shader(source, ShaderStage::Compute);
280        let pl = PipelineLayout::default();
281        let handle = self.backend.create_compute_pipeline(shader, &pl);
282        let pipeline = ComputePipeline {
283            shader,
284            bind_group_layout: layout,
285            workgroup_size,
286            handle,
287        };
288        self.pipelines.insert(handle.0, pipeline.clone());
289        pipeline
290    }
291
292    /// Dispatch a compute pipeline.
293    pub fn dispatch(&mut self, pipeline: &ComputePipeline, x: u32, y: u32, z: u32) {
294        let start = Instant::now();
295
296        if self.backend_type == GpuBackend::Software {
297            // CPU fallback: nothing to actually execute — the software backend
298            // records the command but doesn't run shader code.
299        }
300
301        self.backend.submit(&[GpuCommand::Dispatch {
302            pipeline: pipeline.handle,
303            x,
304            y,
305            z,
306        }]);
307
308        let elapsed = start.elapsed();
309        self.profiler.record("dispatch", [x, y, z], elapsed);
310    }
311
312    /// Indirect dispatch: the workgroup counts come from a GPU buffer.
313    pub fn indirect_dispatch(&mut self, pipeline: &ComputePipeline, args_buffer: &ComputeBuffer) {
314        // Read the indirect args from the buffer (3 x u32).
315        let data = self.backend.read_buffer(args_buffer.handle);
316        let mut groups = [1u32, 1, 1];
317        if data.len() >= 12 {
318            for i in 0..3 {
319                let bytes = [data[i * 4], data[i * 4 + 1], data[i * 4 + 2], data[i * 4 + 3]];
320                groups[i] = u32::from_le_bytes(bytes);
321            }
322        }
323        self.dispatch(pipeline, groups[0], groups[1], groups[2]);
324    }
325
326    /// Insert a memory barrier.
327    pub fn memory_barrier(&mut self) {
328        self.backend.submit(&[GpuCommand::Barrier]);
329    }
330
331    /// Read back buffer contents as a typed slice.
332    pub fn read_back<T: Copy + Default>(&self, buffer: &ComputeBuffer) -> Vec<T> {
333        let data = self.backend.read_buffer(buffer.handle);
334        let elem_size = std::mem::size_of::<T>();
335        if elem_size == 0 {
336            return Vec::new();
337        }
338        let count = data.len() / elem_size;
339        let mut result = vec![T::default(); count];
340        unsafe {
341            let dst = std::slice::from_raw_parts_mut(
342                result.as_mut_ptr() as *mut u8,
343                count * elem_size,
344            );
345            dst.copy_from_slice(&data[..count * elem_size]);
346        }
347        result
348    }
349
350    /// Write typed data into an existing buffer.
351    pub fn write_buffer<T: Copy>(&mut self, buffer: &ComputeBuffer, data: &[T]) {
352        let byte_size = std::mem::size_of::<T>() * data.len();
353        let byte_slice = unsafe {
354            std::slice::from_raw_parts(data.as_ptr() as *const u8, byte_size)
355        };
356        self.backend.write_buffer(buffer.handle, byte_slice);
357    }
358
359    /// Get the profiler.
360    pub fn profiler(&self) -> &ComputeProfiler {
361        &self.profiler
362    }
363
364    /// Mutable access to the profiler.
365    pub fn profiler_mut(&mut self) -> &mut ComputeProfiler {
366        &mut self.profiler
367    }
368
369    /// Whether the backend supports compute shaders natively.
370    pub fn supports_compute(&self) -> bool {
371        self.capabilities.compute_shaders
372    }
373
374    /// Destroy a compute buffer.
375    pub fn destroy_buffer(&mut self, buffer: &ComputeBuffer) {
376        self.backend.destroy_buffer(buffer.handle);
377    }
378}
379
380// ---------------------------------------------------------------------------
381// CPU fallback dispatch (parallel-ish, no rayon dep — uses std threads)
382// ---------------------------------------------------------------------------
383
/// Run a CPU compute kernel across multiple threads.
/// `f` receives `(thread_id, global_x, global_y, global_z)`.
///
/// Workgroups are numbered in a flat order (x fastest, then y, then z) and
/// split into contiguous chunks, one per worker thread. `thread_id` is the
/// index of the chunk, starting at 0. `num_threads` is clamped to
/// `1..=total workgroups`; when that leaves a single worker, everything runs
/// on the calling thread with `thread_id == 0`. A zero in any dimension of
/// `groups` results in no calls.
///
/// Global ids are computed in `u32` (`group * size + local`), so callers must
/// keep `groups[i] * workgroup_size[i]` within `u32` range.
///
/// # Panics
///
/// Panics if the total number of workgroups does not fit in `u64`.
pub fn cpu_parallel_dispatch<F>(
    workgroup_size: [u32; 3],
    groups: [u32; 3],
    num_threads: usize,
    f: F,
) where
    F: Fn(usize, u32, u32, u32) + Send + Sync,
{
    let [sx, sy, sz] = workgroup_size;
    let [gx, gy, gz] = groups;

    // Count groups in u64 so the product cannot silently wrap on targets with
    // a 32-bit `usize`. `row * gy` always fits: (2^32 - 1)^2 < 2^64.
    let row = u64::from(gx);
    let plane = row * u64::from(gy);
    let total_groups = plane
        .checked_mul(u64::from(gz))
        .expect("cpu_parallel_dispatch: total workgroup count overflows u64");
    if total_groups == 0 {
        return;
    }

    // Runs every invocation of the workgroups in `flat_range` on behalf of
    // worker `tid`. The casts are lossless: each quotient/remainder is
    // bounded by the matching `u32` group count.
    let run_groups = |tid: usize, flat_range: std::ops::Range<u64>| {
        for flat in flat_range {
            let gz_i = (flat / plane) as u32;
            let in_plane = flat % plane;
            let gy_i = (in_plane / row) as u32;
            let gx_i = (in_plane % row) as u32;
            for lz in 0..sz {
                for ly in 0..sy {
                    for lx in 0..sx {
                        let x = gx_i * sx + lx;
                        let y = gy_i * sy + ly;
                        let z = gz_i * sz + lz;
                        f(tid, x, y, z);
                    }
                }
            }
        }
    };

    // `usize` is at most 64 bits on supported targets, so this is lossless.
    let workers = (num_threads.max(1) as u64).min(total_groups);
    if workers <= 1 {
        // Single-threaded fast path: no thread is spawned.
        run_groups(0, 0..total_groups);
        return;
    }

    // Ceiling division without the overflow risk of `(a + b - 1) / b`.
    let per_worker = total_groups / workers + u64::from(total_groups % workers != 0);
    let run = &run_groups;

    std::thread::scope(|scope| {
        let mut tid = 0usize;
        let mut start = 0u64;
        // Stops once every group is assigned, so no thread is spawned for an
        // empty trailing chunk.
        while start < total_groups {
            let end = start.saturating_add(per_worker).min(total_groups);
            scope.spawn(move || run(tid, start..end));
            start = end;
            tid += 1;
        }
    });
}
435
436// ---------------------------------------------------------------------------
437// Tests
438// ---------------------------------------------------------------------------
439
/// Unit tests for the compute module. Tests that need a backend run against
/// the software backend via `ComputeContext::software()`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn compute_buffer_element_count() {
        let buf = ComputeBuffer {
            handle: BufferHandle(1),
            size: 40,
            element_size: 4,
        };
        assert_eq!(buf.element_count(), 10);
    }

    #[test]
    fn create_storage_buffer_f32() {
        let mut ctx = ComputeContext::software();
        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let buf = ctx.create_storage_buffer(&data);
        assert_eq!(buf.size, 16);
        assert_eq!(buf.element_size, 4);
        assert_eq!(buf.element_count(), 4);

        let readback: Vec<f32> = ctx.read_back(&buf);
        assert_eq!(readback, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn create_storage_buffer_u32() {
        let mut ctx = ComputeContext::software();
        let data: Vec<u32> = vec![10, 20, 30];
        let buf = ctx.create_storage_buffer(&data);
        let readback: Vec<u32> = ctx.read_back(&buf);
        assert_eq!(readback, vec![10, 20, 30]);
    }

    #[test]
    fn create_empty_buffer() {
        let mut ctx = ComputeContext::software();
        let buf = ctx.create_empty_buffer::<f32>(8);
        assert_eq!(buf.size, 32);
        assert_eq!(buf.element_count(), 8);
    }

    #[test]
    fn dispatch_pipeline() {
        let mut ctx = ComputeContext::software();
        let layout = BindGroupLayout::new().push(0, AccessMode::ReadWrite);
        let pipeline = ctx.create_pipeline("void main(){}", layout, [64, 1, 1]);
        ctx.dispatch(&pipeline, 4, 1, 1);
        assert_eq!(ctx.profiler().total_dispatches(), 1);

        // The profiler entry must describe the dispatch that was issued.
        let record = ctx.profiler().last().unwrap();
        assert_eq!(record.label, "dispatch");
        assert_eq!(record.workgroups, [4, 1, 1]);
    }

    #[test]
    fn indirect_dispatch() {
        let mut ctx = ComputeContext::software();
        let layout = BindGroupLayout::new();
        let pipeline = ctx.create_pipeline("void main(){}", layout, [1, 1, 1]);

        // Create an indirect args buffer with [2, 1, 1]
        let args: Vec<u32> = vec![2, 1, 1];
        let args_buf = ctx.create_storage_buffer(&args);
        ctx.indirect_dispatch(&pipeline, &args_buf);
        assert_eq!(ctx.profiler().total_dispatches(), 1);

        // The group counts must have been decoded from the buffer rather than
        // taken from the [1, 1, 1] fallback.
        assert_eq!(ctx.profiler().last().unwrap().workgroups, [2, 1, 1]);
    }

    #[test]
    fn memory_barrier() {
        // Smoke test: submitting a barrier must not panic.
        let mut ctx = ComputeContext::software();
        ctx.memory_barrier();
    }

    #[test]
    fn write_and_read_back() {
        let mut ctx = ComputeContext::software();
        let buf = ctx.create_empty_buffer::<u32>(4);
        ctx.write_buffer(&buf, &[100u32, 200, 300, 400]);
        let result: Vec<u32> = ctx.read_back(&buf);
        assert_eq!(result, vec![100, 200, 300, 400]);
    }

    #[test]
    fn profiler_average() {
        let mut profiler = ComputeProfiler::new(10);
        profiler.record("a", [1, 1, 1], Duration::from_millis(10));
        profiler.record("b", [1, 1, 1], Duration::from_millis(20));
        assert_eq!(profiler.total_dispatches(), 2);
        let avg = profiler.average_duration();
        assert_eq!(avg, Duration::from_millis(15));
    }

    #[test]
    fn profiler_rolling() {
        let mut profiler = ComputeProfiler::new(3);
        for i in 0..5 {
            profiler.record(&format!("d{}", i), [1, 1, 1], Duration::from_millis(i as u64));
        }
        assert_eq!(profiler.total_dispatches(), 3);
        assert_eq!(profiler.last().unwrap().label, "d4");

        // The two oldest samples were evicted, oldest first.
        assert_eq!(profiler.records()[0].label, "d2");
    }

    #[test]
    fn profiler_clear() {
        let mut profiler = ComputeProfiler::new(10);
        profiler.record("x", [1, 1, 1], Duration::from_millis(5));
        profiler.clear();
        assert_eq!(profiler.total_dispatches(), 0);
        assert_eq!(profiler.average_duration(), Duration::ZERO);
    }

    #[test]
    fn cpu_kernel_dispatch() {
        let kernel = CpuKernel::new([2, 2, 1]);
        let mut invocations = Vec::new();
        kernel.dispatch([2, 1, 1], |x, y, z| {
            invocations.push((x, y, z));
        });
        // 2 groups * (2*2*1) local = 8 invocations
        assert_eq!(invocations.len(), 8);
        assert!(invocations.contains(&(0, 0, 0)));
        assert!(invocations.contains(&(3, 1, 0)));
    }

    #[test]
    fn cpu_kernel_total_invocations() {
        let kernel = CpuKernel::new([8, 8, 1]);
        assert_eq!(kernel.total_invocations([4, 4, 1]), 8 * 8 * 4 * 4);
    }

    #[test]
    fn cpu_parallel_dispatch_runs() {
        use std::sync::Mutex;
        let seen = Mutex::new(Vec::new());
        cpu_parallel_dispatch([2, 1, 1], [4, 1, 1], 2, |_tid, x, y, z| {
            seen.lock().unwrap().push((x, y, z));
        });
        let mut seen = seen.into_inner().unwrap();
        seen.sort();
        // 4 groups * 2 local = 8 invocations, each global id exactly once.
        let expected: Vec<(u32, u32, u32)> = (0..8).map(|x| (x, 0, 0)).collect();
        assert_eq!(seen, expected);
    }

    #[test]
    fn cpu_parallel_dispatch_single_thread() {
        use std::sync::atomic::{AtomicU32, Ordering};
        let counter = AtomicU32::new(0);
        cpu_parallel_dispatch([1, 1, 1], [3, 2, 1], 1, |_tid, _x, _y, _z| {
            counter.fetch_add(1, Ordering::Relaxed);
        });
        assert_eq!(counter.load(Ordering::Relaxed), 6);
    }

    #[test]
    fn bind_group_layout_builder() {
        let layout = BindGroupLayout::new()
            .push(0, AccessMode::ReadOnly)
            .push(1, AccessMode::WriteOnly)
            .push(2, AccessMode::ReadWrite);
        assert_eq!(layout.entries.len(), 3);
        assert_eq!(layout.entries[1].access, AccessMode::WriteOnly);
    }

    #[test]
    fn supports_compute_software() {
        let ctx = ComputeContext::software();
        assert!(ctx.supports_compute());
    }

    #[test]
    fn destroy_buffer() {
        let mut ctx = ComputeContext::software();
        let buf = ctx.create_storage_buffer(&[1u32, 2, 3]);
        ctx.destroy_buffer(&buf);
        // Reading a destroyed buffer yields no data on the software backend.
        let readback: Vec<u32> = ctx.read_back(&buf);
        assert!(readback.is_empty());
    }
}