1use rayon::prelude::*;
25use std::sync::Arc;
26use thiserror::Error;
27
28#[derive(Debug, Clone, PartialEq, Error)]
32pub enum ShaderError {
33 #[error("Invalid group size: {0}")]
35 InvalidGroupSize(String),
36 #[error("Data slice is empty")]
38 EmptyData,
39 #[error("Kernel panicked: {0}")]
41 KernelPanic(String),
42}
43
/// Per-invocation indices handed to a [`ShaderKernel`] closure.
///
/// Because the fields are public, `global_id` is only guaranteed to equal
/// `group_id * group_size + local_id` for values built with
/// [`ThreadGroupContext::new`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ThreadGroupContext {
    /// Index of the work group this invocation belongs to.
    pub group_id: usize,
    /// Index of the invocation inside its work group.
    pub local_id: usize,
    /// Number of invocations per work group.
    pub group_size: usize,
    /// Flattened index across the whole dispatch:
    /// `group_id * group_size + local_id`.
    pub global_id: usize,
}

impl ThreadGroupContext {
    /// Builds a context and derives `global_id` as
    /// `group_id * group_size + local_id`.
    ///
    /// This is a `const fn`, so contexts can also be created in constant
    /// expressions.
    ///
    /// # Panics
    ///
    /// With overflow checks enabled (the default in debug builds), panics if
    /// `group_id * group_size + local_id` overflows `usize`. Without overflow
    /// checks, the arithmetic wraps.
    #[must_use]
    pub const fn new(group_id: usize, local_id: usize, group_size: usize) -> Self {
        Self {
            group_id,
            local_id,
            group_size,
            global_id: group_id * group_size + local_id,
        }
    }
}
75
76type KernelFn<T> = Arc<dyn Fn(&ThreadGroupContext, &mut T) + Send + Sync>;
80
81pub struct ShaderKernel<T: Send + Sync> {
83 kernel_fn: KernelFn<T>,
84 group_size: usize,
85 name: String,
86}
87
88impl<T: Send + Sync> ShaderKernel<T> {
89 #[must_use]
95 pub fn new(
96 name: impl Into<String>,
97 group_size: usize,
98 f: impl Fn(&ThreadGroupContext, &mut T) + Send + Sync + 'static,
99 ) -> Self {
100 Self {
101 kernel_fn: Arc::new(f),
102 group_size: group_size.max(1),
103 name: name.into(),
104 }
105 }
106
107 pub fn execute(&self, data: &mut [T], work_groups: usize) {
112 if data.is_empty() || work_groups == 0 {
113 return;
114 }
115 let gs = self.group_size;
116 let kfn = Arc::clone(&self.kernel_fn);
117
118 data.par_iter_mut().enumerate().for_each(|(i, elem)| {
119 let group_id = i / gs;
120 let local_id = i % gs;
121 if group_id < work_groups {
123 let ctx = ThreadGroupContext::new(group_id, local_id, gs);
124 kfn(&ctx, elem);
125 }
126 });
127 }
128
129 #[must_use]
131 pub fn group_size(&self) -> usize {
132 self.group_size
133 }
134
135 #[must_use]
137 pub fn name(&self) -> &str {
138 &self.name
139 }
140}
141
/// Plain-data description of a single kernel dispatch.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DispatchConfig {
    /// Number of work groups to launch.
    pub work_groups: usize,
    /// Invocations per work group, stored exactly as given (0 is not
    /// normalised here).
    pub group_size: usize,
    /// Free-form label for the dispatch.
    pub label: String,
}

impl DispatchConfig {
    /// Creates a config from its parts, converting `label` into an owned `String`.
    #[must_use]
    pub fn new(work_groups: usize, group_size: usize, label: impl Into<String>) -> Self {
        Self {
            work_groups,
            group_size,
            label: label.into(),
        }
    }
}
166
167#[derive(Debug, Clone)]
174pub struct ComputeShaderSimulator {
175 default_group_size: usize,
176}
177
178impl ComputeShaderSimulator {
179 #[must_use]
183 pub fn new(default_group_size: usize) -> Self {
184 Self {
185 default_group_size: if default_group_size == 0 {
186 64
187 } else {
188 default_group_size
189 },
190 }
191 }
192
193 #[must_use]
195 pub fn default_group_size(&self) -> usize {
196 self.default_group_size
197 }
198
199 #[must_use]
201 pub fn create_kernel<T: Send + Sync + 'static>(
202 &self,
203 name: impl Into<String>,
204 f: impl Fn(&ThreadGroupContext, &mut T) + Send + Sync + 'static,
205 ) -> ShaderKernel<T> {
206 ShaderKernel::new(name, self.default_group_size, f)
207 }
208
209 #[must_use]
211 pub fn create_kernel_with_group_size<T: Send + Sync + 'static>(
212 &self,
213 name: impl Into<String>,
214 group_size: usize,
215 f: impl Fn(&ThreadGroupContext, &mut T) + Send + Sync + 'static,
216 ) -> ShaderKernel<T> {
217 ShaderKernel::new(name, group_size, f)
218 }
219
220 pub fn dispatch<T: Send + Sync>(
222 &self,
223 kernel: &ShaderKernel<T>,
224 data: &mut [T],
225 work_groups: usize,
226 ) {
227 kernel.execute(data, work_groups);
228 }
229
230 pub fn dispatch_with_barrier<T: Send + Sync + Clone>(
238 &self,
239 kernel: &ShaderKernel<T>,
240 data: &mut [T],
241 work_groups: usize,
242 ) {
243 kernel.execute(data, work_groups);
246 }
248}
249
#[cfg(test)]
mod tests {
    //! Unit tests covering context arithmetic, kernel construction, execution
    //! coverage, and the simulator's dispatch helpers.

    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Number of work groups needed to cover `len` elements (ceiling division).
    fn work_groups_for(len: usize, group_size: usize) -> usize {
        (len + group_size - 1) / group_size
    }

    // ---- ThreadGroupContext ----

    #[test]
    fn test_thread_group_context_global_id() {
        let ctx = ThreadGroupContext::new(3, 5, 8);
        assert_eq!(ctx.group_id, 3);
        assert_eq!(ctx.local_id, 5);
        assert_eq!(ctx.group_size, 8);
        assert_eq!(ctx.global_id, 3 * 8 + 5);
    }

    #[test]
    fn test_thread_group_context_zero_group() {
        let ctx = ThreadGroupContext::new(0, 0, 64);
        assert_eq!(ctx.global_id, 0);
    }

    // ---- ShaderKernel ----

    #[test]
    fn test_shader_kernel_name_and_group_size() {
        let k = ShaderKernel::new(
            "test_kernel",
            32,
            |_ctx: &ThreadGroupContext, _v: &mut u32| {},
        );
        assert_eq!(k.name(), "test_kernel");
        assert_eq!(k.group_size(), 32);
    }

    #[test]
    fn test_shader_kernel_group_size_zero_normalised() {
        let k = ShaderKernel::new("k", 0, |_ctx: &ThreadGroupContext, _v: &mut u32| {});
        assert_eq!(k.group_size(), 1);
    }

    #[test]
    fn test_execute_multiply_by_two() {
        let k = ShaderKernel::new("double", 4, |_ctx: &ThreadGroupContext, v: &mut u32| {
            *v *= 2;
        });
        let mut data = vec![1u32, 2, 3, 4, 5, 6, 7, 8];
        let wg = work_groups_for(data.len(), 4);
        k.execute(&mut data, wg);
        assert_eq!(data, [2, 4, 6, 8, 10, 12, 14, 16]);
    }

    #[test]
    fn test_execute_fill_with_global_id() {
        let k = ShaderKernel::new("fill_id", 8, |ctx: &ThreadGroupContext, v: &mut usize| {
            *v = ctx.global_id;
        });
        let mut data = vec![0usize; 16];
        let wg = work_groups_for(data.len(), 8);
        k.execute(&mut data, wg);
        for (i, &v) in data.iter().enumerate() {
            assert_eq!(v, i, "element {i} should equal its global_id");
        }
    }

    #[test]
    fn test_execute_work_groups_larger_than_needed() {
        let k = ShaderKernel::new("k", 4, |_ctx: &ThreadGroupContext, v: &mut u32| {
            *v += 10;
        });
        // 100 groups of 4 far exceed the 6 elements. Surplus groups must be
        // idle, and each element must still be processed exactly once.
        let mut data = vec![0u32; 6];
        k.execute(&mut data, 100);
        assert!(data.iter().all(|&v| v == 10));
    }

    #[test]
    fn test_execute_fewer_work_groups_leaves_tail_untouched() {
        let k = ShaderKernel::new("k", 4, |_ctx: &ThreadGroupContext, v: &mut u32| {
            *v = 1;
        });
        // Two groups of four cover only the first eight of ten elements.
        let mut data = vec![0u32; 10];
        k.execute(&mut data, 2);
        assert_eq!(data, [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]);
    }

    #[test]
    fn test_execute_zero_work_groups_is_noop() {
        let k = ShaderKernel::new("k", 4, |_ctx: &ThreadGroupContext, v: &mut u32| {
            *v = 1;
        });
        let mut data = vec![0u32; 8];
        k.execute(&mut data, 0);
        assert!(data.iter().all(|&v| v == 0));
    }

    #[test]
    fn test_execute_single_work_group() {
        let k = ShaderKernel::new("k", 8, |_ctx: &ThreadGroupContext, v: &mut u32| {
            *v = 42;
        });
        let mut data = vec![0u32; 8];
        k.execute(&mut data, 1);
        assert!(data.iter().all(|&v| v == 42));
    }

    #[test]
    fn test_execute_empty_data_no_panic() {
        let k = ShaderKernel::new("k", 8, |_ctx: &ThreadGroupContext, v: &mut u32| {
            *v = 1;
        });
        let mut data: Vec<u32> = vec![];
        // Must return cleanly even though work groups were requested.
        k.execute(&mut data, 4);
        assert!(data.is_empty());
    }

    #[test]
    fn test_execute_f32_scale() {
        let factor = 2.5_f32;
        let k = ShaderKernel::new(
            "scale_f32",
            4,
            move |_ctx: &ThreadGroupContext, v: &mut f32| {
                *v *= factor;
            },
        );
        let mut data = vec![1.0_f32, 2.0, 3.0, 4.0];
        k.execute(&mut data, 1);
        for (i, &v) in data.iter().enumerate() {
            let expected = (i as f32 + 1.0) * factor;
            assert!(
                (v - expected).abs() < 1e-5,
                "element {i}: got {v}, expected {expected}"
            );
        }
    }

    // ---- ComputeShaderSimulator ----

    #[test]
    fn test_simulator_default_group_size() {
        let sim = ComputeShaderSimulator::new(64);
        assert_eq!(sim.default_group_size(), 64);
    }

    #[test]
    fn test_simulator_zero_group_size_normalised() {
        let sim = ComputeShaderSimulator::new(0);
        assert_eq!(sim.default_group_size(), 64);
    }

    #[test]
    fn test_simulator_create_kernel_and_dispatch() {
        let sim = ComputeShaderSimulator::new(4);
        let kernel = sim.create_kernel("incr", |_ctx: &ThreadGroupContext, v: &mut u32| {
            *v += 1;
        });
        let mut data = vec![0u32; 8];
        let wg = work_groups_for(data.len(), sim.default_group_size());
        sim.dispatch(&kernel, &mut data, wg);
        assert!(data.iter().all(|&v| v == 1));
    }

    #[test]
    fn test_simulator_create_kernel_with_group_size() {
        let sim = ComputeShaderSimulator::new(64);
        let kernel = sim.create_kernel_with_group_size(
            "k16",
            16,
            |_ctx: &ThreadGroupContext, v: &mut u32| {
                *v = 99;
            },
        );
        assert_eq!(kernel.group_size(), 16);
        let mut data = vec![0u32; 32];
        let wg = work_groups_for(data.len(), 16);
        kernel.execute(&mut data, wg);
        assert!(data.iter().all(|&v| v == 99));
    }

    #[test]
    fn test_dispatch_with_barrier() {
        let sim = ComputeShaderSimulator::new(8);
        let k = sim.create_kernel("b_k", |_ctx: &ThreadGroupContext, v: &mut u32| {
            *v = 7;
        });
        let mut data = vec![0u32; 8];
        sim.dispatch_with_barrier(&k, &mut data, 1);
        assert!(data.iter().all(|&v| v == 7));
    }

    #[test]
    fn test_multiple_kernels_on_same_data() {
        let sim = ComputeShaderSimulator::new(4);
        let k1 = sim.create_kernel("add1", |_ctx: &ThreadGroupContext, v: &mut u32| {
            *v += 1;
        });
        let k2 = sim.create_kernel("mul3", |_ctx: &ThreadGroupContext, v: &mut u32| {
            *v *= 3;
        });
        let mut data = vec![0u32; 4];
        let wg = 1;
        sim.dispatch(&k1, &mut data, wg);
        sim.dispatch(&k2, &mut data, wg);
        // (0 + 1) * 3: the second dispatch sees the first one's results.
        assert!(data.iter().all(|&v| v == 3));
    }

    #[test]
    fn test_large_data_set() {
        let sim = ComputeShaderSimulator::new(64);
        let k = sim.create_kernel("large", |_ctx: &ThreadGroupContext, v: &mut u32| {
            *v += 1;
        });
        let mut data = vec![0u32; 10_000];
        let wg = work_groups_for(data.len(), sim.default_group_size());
        sim.dispatch(&k, &mut data, wg);
        assert!(data.iter().all(|&v| v == 1));
    }

    #[test]
    fn test_kernel_captures_closure_state_with_atomic() {
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = Arc::clone(&counter);
        let sim = ComputeShaderSimulator::new(8);
        let k = sim.create_kernel(
            "counter_k",
            move |_ctx: &ThreadGroupContext, _v: &mut u32| {
                counter_clone.fetch_add(1, Ordering::Relaxed);
            },
        );
        let mut data = vec![0u32; 16];
        let wg = work_groups_for(data.len(), sim.default_group_size());
        k.execute(&mut data, wg);
        // One invocation per element.
        assert_eq!(counter.load(Ordering::Relaxed), 16);
    }
}