// oximedia_gpu/compute_shader.rs

//! GPU compute shader simulator.
//!
//! Simulates GPU-style work-group execution using Rayon thread-pool parallelism.
//! Each [`ShaderKernel`] receives a [`ThreadGroupContext`] per element that
//! mirrors the `gl_GlobalInvocationID` / `gl_LocalInvocationID` / `gl_WorkGroupID`
//! semantics of GLSL/HLSL compute shaders.
//!
//! # Example
//!
//! ```rust
//! use oximedia_gpu::compute_shader::{ComputeShaderSimulator, ThreadGroupContext};
//!
//! let sim = ComputeShaderSimulator::new(64);
//! let kernel = sim.create_kernel("double", |ctx: &ThreadGroupContext, v: &mut u32| {
//!     *v = *v * 2;
//! });
//!
//! let mut data = vec![1u32, 2, 3, 4];
//! let work_groups = (data.len() + sim.default_group_size() - 1) / sim.default_group_size();
//! kernel.execute(&mut data, work_groups);
//! assert_eq!(data, [2, 4, 6, 8]);
//! ```

use rayon::prelude::*;
use std::sync::Arc;
use thiserror::Error;

28// ─── Error ────────────────────────────────────────────────────────────────────
29
30/// Errors returned by compute shader operations.
31#[derive(Debug, Clone, PartialEq, Error)]
32pub enum ShaderError {
33    /// The requested group size is zero or otherwise invalid.
34    #[error("Invalid group size: {0}")]
35    InvalidGroupSize(String),
36    /// The data slice passed to the kernel is empty.
37    #[error("Data slice is empty")]
38    EmptyData,
39    /// A kernel closure panicked during execution.
40    #[error("Kernel panicked: {0}")]
41    KernelPanic(String),
42}
43
// ─── ThreadGroupContext ───────────────────────────────────────────────────────

/// Execution context passed to each invocation of a kernel closure.
///
/// Mirrors the built-in variables available in GLSL/HLSL/WGSL compute shaders,
/// flattened to a single dimension. The type is a small `Copy` value made of
/// four `usize` indices, so it can be compared, hashed and passed by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ThreadGroupContext {
    /// Index of the work group this thread belongs to (`gl_WorkGroupID`).
    pub group_id: usize,
    /// Thread index within the work group (`gl_LocalInvocationID`).
    pub local_id: usize,
    /// Number of threads per work group (`gl_WorkGroupSize`).
    pub group_size: usize,
    /// Flat global index across all work groups (`gl_GlobalInvocationID`).
    ///
    /// Equals `group_id * group_size + local_id`.
    pub global_id: usize,
}

63impl ThreadGroupContext {
64    /// Construct a context from its constituent indices.
65    #[must_use]
66    pub fn new(group_id: usize, local_id: usize, group_size: usize) -> Self {
67        Self {
68            group_id,
69            local_id,
70            group_size,
71            global_id: group_id * group_size + local_id,
72        }
73    }
74}
75
76// ─── ShaderKernel ─────────────────────────────────────────────────────────────
77
78/// Type-erased kernel function: closure that receives context + mutable element.
79type KernelFn<T> = Arc<dyn Fn(&ThreadGroupContext, &mut T) + Send + Sync>;
80
81/// A named, parameterised GPU kernel ready for parallel dispatch.
82pub struct ShaderKernel<T: Send + Sync> {
83    kernel_fn: KernelFn<T>,
84    group_size: usize,
85    name: String,
86}
87
88impl<T: Send + Sync> ShaderKernel<T> {
89    /// Create a new kernel with an explicit group size.
90    ///
91    /// # Panics
92    ///
93    /// Does not panic; `group_size = 0` is normalised to 1 at runtime.
94    #[must_use]
95    pub fn new(
96        name: impl Into<String>,
97        group_size: usize,
98        f: impl Fn(&ThreadGroupContext, &mut T) + Send + Sync + 'static,
99    ) -> Self {
100        Self {
101            kernel_fn: Arc::new(f),
102            group_size: group_size.max(1),
103            name: name.into(),
104        }
105    }
106
107    /// Execute the kernel over `data` using `work_groups` groups in parallel.
108    ///
109    /// Elements at indices `>= work_groups * group_size` are silently ignored,
110    /// matching GPU semantics where excess invocations are masked out.
111    pub fn execute(&self, data: &mut [T], work_groups: usize) {
112        if data.is_empty() || work_groups == 0 {
113            return;
114        }
115        let gs = self.group_size;
116        let kfn = Arc::clone(&self.kernel_fn);
117
118        data.par_iter_mut().enumerate().for_each(|(i, elem)| {
119            let group_id = i / gs;
120            let local_id = i % gs;
121            // Only process elements within the declared work-group count.
122            if group_id < work_groups {
123                let ctx = ThreadGroupContext::new(group_id, local_id, gs);
124                kfn(&ctx, elem);
125            }
126        });
127    }
128
129    /// The number of threads per work group.
130    #[must_use]
131    pub fn group_size(&self) -> usize {
132        self.group_size
133    }
134
135    /// The human-readable label for this kernel.
136    #[must_use]
137    pub fn name(&self) -> &str {
138        &self.name
139    }
140}
141
// ─── DispatchConfig ──────────────────────────────────────────────────────────

/// Configuration bundle for a single kernel dispatch.
///
/// A plain value type: it only carries the three settings below and can be
/// cloned and compared for equality.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DispatchConfig {
    /// Number of work groups to dispatch.
    pub work_groups: usize,
    /// Threads per work group (overrides kernel default when > 0).
    pub group_size: usize,
    /// Human-readable label used for profiling/logging.
    pub label: String,
}

155impl DispatchConfig {
156    /// Convenience constructor.
157    #[must_use]
158    pub fn new(work_groups: usize, group_size: usize, label: impl Into<String>) -> Self {
159        Self {
160            work_groups,
161            group_size,
162            label: label.into(),
163        }
164    }
165}
166
// ─── ComputeShaderSimulator ───────────────────────────────────────────────────

/// High-level entry point for simulated GPU compute.
///
/// Manages a default group size and provides factory methods for creating
/// typed [`ShaderKernel`] instances. The simulator holds no other state, so
/// two simulators compare equal exactly when their default group sizes match.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ComputeShaderSimulator {
    /// Threads per work group given to kernels created without an explicit size.
    default_group_size: usize,
}

178impl ComputeShaderSimulator {
179    /// Create a simulator with the given default work-group size.
180    ///
181    /// Sizes of 0 are normalised to 64 (a common GPU default).
182    #[must_use]
183    pub fn new(default_group_size: usize) -> Self {
184        Self {
185            default_group_size: if default_group_size == 0 {
186                64
187            } else {
188                default_group_size
189            },
190        }
191    }
192
193    /// The default number of threads per work group.
194    #[must_use]
195    pub fn default_group_size(&self) -> usize {
196        self.default_group_size
197    }
198
199    /// Create a kernel using the simulator's default group size.
200    #[must_use]
201    pub fn create_kernel<T: Send + Sync + 'static>(
202        &self,
203        name: impl Into<String>,
204        f: impl Fn(&ThreadGroupContext, &mut T) + Send + Sync + 'static,
205    ) -> ShaderKernel<T> {
206        ShaderKernel::new(name, self.default_group_size, f)
207    }
208
209    /// Create a kernel with a custom group size, ignoring the simulator default.
210    #[must_use]
211    pub fn create_kernel_with_group_size<T: Send + Sync + 'static>(
212        &self,
213        name: impl Into<String>,
214        group_size: usize,
215        f: impl Fn(&ThreadGroupContext, &mut T) + Send + Sync + 'static,
216    ) -> ShaderKernel<T> {
217        ShaderKernel::new(name, group_size, f)
218    }
219
220    /// Dispatch `kernel` over `data` using `work_groups` work groups.
221    pub fn dispatch<T: Send + Sync>(
222        &self,
223        kernel: &ShaderKernel<T>,
224        data: &mut [T],
225        work_groups: usize,
226    ) {
227        kernel.execute(data, work_groups);
228    }
229
230    /// Dispatch `kernel` and wait for all threads to complete.
231    ///
232    /// Rayon's `par_iter_mut` already joins all threads before returning, so
233    /// this is semantically equivalent to [`dispatch`].  The method exists to
234    /// model GPU barriers explicitly in calling code.
235    ///
236    /// [`dispatch`]: ComputeShaderSimulator::dispatch
237    pub fn dispatch_with_barrier<T: Send + Sync + Clone>(
238        &self,
239        kernel: &ShaderKernel<T>,
240        data: &mut [T],
241        work_groups: usize,
242    ) {
243        // Rayon join semantics: all parallel work completes before the call
244        // returns — no additional synchronisation primitive needed.
245        kernel.execute(data, work_groups);
246        // Conceptual barrier point; Rayon's fork-join ensures this.
247    }
248}
249
// ─── Tests ───────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Number of work groups needed to cover `len` elements (ceiling division).
    fn work_groups_for(len: usize, group_size: usize) -> usize {
        (len + group_size - 1) / group_size
    }

    // ── ShaderError ──────────────────────────────────────────────────────────

    #[test]
    fn test_shader_error_messages() {
        assert_eq!(
            ShaderError::InvalidGroupSize("0".to_string()).to_string(),
            "Invalid group size: 0"
        );
        assert_eq!(ShaderError::EmptyData.to_string(), "Data slice is empty");
        assert_eq!(
            ShaderError::KernelPanic("boom".to_string()).to_string(),
            "Kernel panicked: boom"
        );
    }

    // ── ThreadGroupContext ────────────────────────────────────────────────────

    #[test]
    fn test_thread_group_context_global_id() {
        let ctx = ThreadGroupContext::new(3, 5, 8);
        assert_eq!(ctx.group_id, 3);
        assert_eq!(ctx.local_id, 5);
        assert_eq!(ctx.group_size, 8);
        assert_eq!(ctx.global_id, 3 * 8 + 5);
    }

    #[test]
    fn test_thread_group_context_zero_group() {
        let ctx = ThreadGroupContext::new(0, 0, 64);
        assert_eq!(ctx.global_id, 0);
    }

    // ── ShaderKernel ─────────────────────────────────────────────────────────

    #[test]
    fn test_shader_kernel_name_and_group_size() {
        let k = ShaderKernel::new(
            "test_kernel",
            32,
            |_ctx: &ThreadGroupContext, _v: &mut u32| {},
        );
        assert_eq!(k.name(), "test_kernel");
        assert_eq!(k.group_size(), 32);
    }

    #[test]
    fn test_shader_kernel_group_size_zero_normalised() {
        let k = ShaderKernel::new("k", 0, |_ctx: &ThreadGroupContext, _v: &mut u32| {});
        assert_eq!(k.group_size(), 1);
    }

    #[test]
    fn test_execute_multiply_by_two() {
        let k = ShaderKernel::new("double", 4, |_ctx: &ThreadGroupContext, v: &mut u32| {
            *v *= 2;
        });
        let mut data = vec![1u32, 2, 3, 4, 5, 6, 7, 8];
        let wg = work_groups_for(data.len(), 4);
        k.execute(&mut data, wg);
        assert_eq!(data, [2, 4, 6, 8, 10, 12, 14, 16]);
    }

    #[test]
    fn test_execute_fill_with_global_id() {
        let k = ShaderKernel::new("fill_id", 8, |ctx: &ThreadGroupContext, v: &mut usize| {
            *v = ctx.global_id;
        });
        let mut data = vec![0usize; 16];
        let wg = work_groups_for(data.len(), 8);
        k.execute(&mut data, wg);
        for (i, &v) in data.iter().enumerate() {
            assert_eq!(v, i, "element {i} should equal its global_id");
        }
    }

    #[test]
    fn test_execute_reports_group_and_local_ids() {
        let k = ShaderKernel::new(
            "ids",
            4,
            |ctx: &ThreadGroupContext, v: &mut (usize, usize, usize)| {
                *v = (ctx.group_id, ctx.local_id, ctx.group_size);
            },
        );
        let mut data = vec![(0usize, 0usize, 0usize); 10];
        let wg = work_groups_for(data.len(), 4);
        k.execute(&mut data, wg);
        for (i, &(group, local, size)) in data.iter().enumerate() {
            assert_eq!(group, i / 4, "element {i}: wrong group_id");
            assert_eq!(local, i % 4, "element {i}: wrong local_id");
            assert_eq!(size, 4, "element {i}: wrong group_size");
        }
    }

    #[test]
    fn test_execute_work_groups_larger_than_needed() {
        // Extra work groups simply have no data elements to process.
        let k = ShaderKernel::new("k", 4, |_ctx: &ThreadGroupContext, v: &mut u32| {
            *v += 10;
        });
        let mut data = vec![0u32; 6]; // 6 elements, group_size=4 → 2 groups
        k.execute(&mut data, 100); // 100 work groups requested – fine
        assert!(data.iter().all(|&v| v == 10));
    }

    #[test]
    fn test_execute_fewer_work_groups_masks_tail() {
        // Only the first work group is dispatched; the rest must stay untouched.
        let k = ShaderKernel::new("k", 4, |_ctx: &ThreadGroupContext, v: &mut u32| {
            *v = 1;
        });
        let mut data = vec![0u32; 10];
        k.execute(&mut data, 1);
        assert!(data[..4].iter().all(|&v| v == 1));
        assert!(data[4..].iter().all(|&v| v == 0));
    }

    #[test]
    fn test_execute_zero_work_groups_is_noop() {
        let k = ShaderKernel::new("k", 4, |_ctx: &ThreadGroupContext, v: &mut u32| {
            *v = 1;
        });
        let mut data = vec![0u32; 8];
        k.execute(&mut data, 0);
        assert!(data.iter().all(|&v| v == 0));
    }

    #[test]
    fn test_execute_single_work_group() {
        let k = ShaderKernel::new("k", 8, |_ctx: &ThreadGroupContext, v: &mut u32| {
            *v = 42;
        });
        let mut data = vec![0u32; 8];
        k.execute(&mut data, 1);
        assert!(data.iter().all(|&v| v == 42));
    }

    #[test]
    fn test_execute_empty_data_no_panic() {
        let k = ShaderKernel::new("k", 8, |_ctx: &ThreadGroupContext, v: &mut u32| {
            *v = 1;
        });
        let mut data: Vec<u32> = vec![];
        // Must not panic
        k.execute(&mut data, 4);
        assert!(data.is_empty());
    }

    #[test]
    fn test_execute_f32_scale() {
        let factor = 2.5_f32;
        let k = ShaderKernel::new(
            "scale_f32",
            4,
            move |_ctx: &ThreadGroupContext, v: &mut f32| {
                *v *= factor;
            },
        );
        let mut data = vec![1.0_f32, 2.0, 3.0, 4.0];
        k.execute(&mut data, 1);
        for (i, &v) in data.iter().enumerate() {
            let expected = (i as f32 + 1.0) * factor;
            assert!(
                (v - expected).abs() < 1e-5,
                "element {i}: got {v}, expected {expected}"
            );
        }
    }

    // ── DispatchConfig ───────────────────────────────────────────────────────

    #[test]
    fn test_dispatch_config_new() {
        let cfg = DispatchConfig::new(2, 64, "blur");
        assert_eq!(cfg.work_groups, 2);
        assert_eq!(cfg.group_size, 64);
        assert_eq!(cfg.label, "blur");
    }

    // ── ComputeShaderSimulator ────────────────────────────────────────────────

    #[test]
    fn test_simulator_default_group_size() {
        let sim = ComputeShaderSimulator::new(64);
        assert_eq!(sim.default_group_size(), 64);
    }

    #[test]
    fn test_simulator_zero_group_size_normalised() {
        let sim = ComputeShaderSimulator::new(0);
        assert_eq!(sim.default_group_size(), 64);
    }

    #[test]
    fn test_simulator_create_kernel_and_dispatch() {
        let sim = ComputeShaderSimulator::new(4);
        let kernel = sim.create_kernel("incr", |_ctx: &ThreadGroupContext, v: &mut u32| {
            *v += 1;
        });
        let mut data = vec![0u32; 8];
        let wg = work_groups_for(data.len(), sim.default_group_size());
        sim.dispatch(&kernel, &mut data, wg);
        assert!(data.iter().all(|&v| v == 1));
    }

    #[test]
    fn test_simulator_create_kernel_with_group_size() {
        let sim = ComputeShaderSimulator::new(64);
        let kernel = sim.create_kernel_with_group_size(
            "k16",
            16,
            |_ctx: &ThreadGroupContext, v: &mut u32| {
                *v = 99;
            },
        );
        assert_eq!(kernel.group_size(), 16);
        let mut data = vec![0u32; 32];
        let wg = work_groups_for(data.len(), 16);
        kernel.execute(&mut data, wg);
        assert!(data.iter().all(|&v| v == 99));
    }

    #[test]
    fn test_dispatch_with_barrier() {
        let sim = ComputeShaderSimulator::new(8);
        let k = sim.create_kernel("b_k", |_ctx: &ThreadGroupContext, v: &mut u32| {
            *v = 7;
        });
        let mut data = vec![0u32; 8];
        sim.dispatch_with_barrier(&k, &mut data, 1);
        assert!(data.iter().all(|&v| v == 7));
    }

    #[test]
    fn test_multiple_kernels_on_same_data() {
        let sim = ComputeShaderSimulator::new(4);
        let k1 = sim.create_kernel("add1", |_ctx: &ThreadGroupContext, v: &mut u32| {
            *v += 1;
        });
        let k2 = sim.create_kernel("mul3", |_ctx: &ThreadGroupContext, v: &mut u32| {
            *v *= 3;
        });
        let mut data = vec![0u32; 4];
        let wg = 1;
        sim.dispatch(&k1, &mut data, wg);
        sim.dispatch(&k2, &mut data, wg);
        // Each element: (0+1)*3 = 3
        assert!(data.iter().all(|&v| v == 3));
    }

    #[test]
    fn test_large_data_set() {
        let sim = ComputeShaderSimulator::new(64);
        let k = sim.create_kernel("large", |_ctx: &ThreadGroupContext, v: &mut u32| {
            *v += 1;
        });
        let mut data = vec![0u32; 10_000];
        let wg = work_groups_for(data.len(), sim.default_group_size());
        sim.dispatch(&k, &mut data, wg);
        assert!(data.iter().all(|&v| v == 1));
    }

    #[test]
    fn test_kernel_captures_closure_state_with_atomic() {
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = Arc::clone(&counter);
        let sim = ComputeShaderSimulator::new(8);
        let k = sim.create_kernel(
            "counter_k",
            move |_ctx: &ThreadGroupContext, _v: &mut u32| {
                counter_clone.fetch_add(1, Ordering::Relaxed);
            },
        );
        let mut data = vec![0u32; 16];
        let wg = work_groups_for(data.len(), sim.default_group_size());
        k.execute(&mut data, wg);
        assert_eq!(counter.load(Ordering::Relaxed), 16);
    }
}
473}