Skip to main content

oxiphysics_gpu/compute/
functions.rs

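//! Compute backend/kernel traits plus free helpers for dispatch planning:
//! workgroup-count calculation, pipeline-barrier and buffer-aliasing
//! analysis, and warp-divergence estimation.
//!
//! Auto-generated module.
//!
//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)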
4
5use super::types::{
6    BufferBinding, BufferHandle, BufferId, BufferUsage, PipelineBarrier, WarpDivergenceRecord,
7};
8
9/// Trait for a compute backend (GPU or CPU fallback).
10pub trait ComputeBackend {
11    /// Human-readable name of this backend.
12    fn name(&self) -> &str;
13    /// Allocate a buffer that can hold `size` f64 elements.
14    fn create_buffer(&self, size: usize) -> BufferHandle;
15    /// Write `data` into the buffer referenced by `handle`.
16    fn write_buffer(&self, handle: BufferHandle, data: &[f64]);
17    /// Read the full contents of the buffer referenced by `handle`.
18    fn read_buffer(&self, handle: BufferHandle) -> Vec<f64>;
19    /// Dispatch a compute kernel over `work_size` work items.
20    fn dispatch(&self, kernel: &dyn ComputeKernel, work_size: usize);
21}
/// A unit of compute work that a compute backend can dispatch.
pub trait ComputeKernel {
    /// Returns a human-readable identifier for this kernel.
    fn name(&self) -> &str;

    /// Runs the kernel body for `work_size` logical work items.
    ///
    /// # Arguments
    ///
    /// * `inputs` – read-only input slices.
    /// * `outputs` – output vectors the kernel writes into; the caller
    ///   allocates them before the call.
    /// * `work_size` – number of logical work items to process.
    fn execute(&self, inputs: &[&[f64]], outputs: &mut [Vec<f64>], work_size: usize);
}
/// Number of 1D workgroups needed to cover `total_items` work items.
///
/// The result is rounded up, so a trailing partial workgroup counts as a full
/// one (`100` items with a workgroup size of `64` need `2` groups). The
/// computation cannot overflow, even for `total_items` near `u32::MAX`.
///
/// # Panics
///
/// Panics with a division-by-zero error if `workgroup_size` is `0`.
#[allow(dead_code)]
pub fn compute_num_workgroups(total_items: u32, workgroup_size: u32) -> u32 {
    let full_groups = total_items / workgroup_size;
    // Any leftover items need one extra, partially filled workgroup.
    let has_partial_group = total_items % workgroup_size != 0;
    full_groups + u32::from(has_partial_group)
}
/// Per-axis workgroup counts for a 3D dispatch.
///
/// Each axis `k` independently gets `ceil(total[k] / workgroup_size[k])`
/// workgroups, so partial groups at the edge of the domain are included.
///
/// # Panics
///
/// Panics with a division-by-zero error if any entry of `workgroup_size`
/// is `0`.
#[allow(dead_code)]
pub fn compute_num_workgroups_3d(total: [u32; 3], workgroup_size: [u32; 3]) -> [u32; 3] {
    // Axes are visited in order x, y, z.
    std::array::from_fn(|axis| total[axis].div_ceil(workgroup_size[axis]))
}
47/// Determine the required pipeline barrier between two kernel passes.
48///
49/// If the output buffers of pass A overlap with the input buffers of pass B,
50/// a read-after-write barrier is needed.
51#[allow(dead_code)]
52pub fn required_barrier(
53    pass_a_outputs: &[BufferId],
54    pass_b_inputs: &[BufferId],
55) -> PipelineBarrier {
56    let overlap = pass_a_outputs.iter().any(|out| pass_b_inputs.contains(out));
57    if overlap {
58        PipelineBarrier::StorageReadAfterWrite
59    } else {
60        PipelineBarrier::None
61    }
62}
63/// Detect whether any buffers in a pass alias the same storage.
64///
65/// Two bindings alias if they reference the same `BufferId` with incompatible
66/// usages (e.g., one is write and the other is read in the same pass).
67#[allow(dead_code)]
68pub fn detect_aliasing(bindings: &[BufferBinding]) -> Vec<(u32, u32)> {
69    let mut conflicts = Vec::new();
70    for i in 0..bindings.len() {
71        for j in (i + 1)..bindings.len() {
72            if bindings[i].buffer_id == bindings[j].buffer_id {
73                let write_i = matches!(
74                    bindings[i].usage,
75                    BufferUsage::WriteOnly | BufferUsage::ReadWrite
76                );
77                let read_j = matches!(
78                    bindings[j].usage,
79                    BufferUsage::ReadOnly | BufferUsage::ReadWrite
80                );
81                let write_j = matches!(
82                    bindings[j].usage,
83                    BufferUsage::WriteOnly | BufferUsage::ReadWrite
84                );
85                let read_i = matches!(
86                    bindings[i].usage,
87                    BufferUsage::ReadOnly | BufferUsage::ReadWrite
88                );
89                if write_i && read_j || write_j && read_i {
90                    conflicts.push((bindings[i].binding, bindings[j].binding));
91                }
92            }
93        }
94    }
95    conflicts
96}
97/// Simulate warp divergence by analysing a boolean predicate over work items.
98///
99/// Groups work items into warps and checks if all threads take the same branch.
100/// Returns a [`WarpDivergenceRecord`].
101#[allow(dead_code)]
102pub fn analyse_warp_divergence(predicates: &[bool], warp_size: usize) -> WarpDivergenceRecord {
103    if predicates.is_empty() || warp_size == 0 {
104        return WarpDivergenceRecord::default();
105    }
106    let mut total = 0u64;
107    let mut divergent = 0u64;
108    let n_warps = predicates.len().div_ceil(warp_size);
109    for w in 0..n_warps {
110        let start = w * warp_size;
111        let end = (start + warp_size).min(predicates.len());
112        let slice = &predicates[start..end];
113        total += 1;
114        let all_true = slice.iter().all(|&v| v);
115        let all_false = slice.iter().all(|&v| !v);
116        if !all_true && !all_false {
117            divergent += 1;
118        }
119    }
120    WarpDivergenceRecord {
121        total_branches: total,
122        divergent_branches: divergent,
123    }
124}
#[cfg(test)]
mod tests {
    // Unit tests for this module's free functions, plus smoke tests for the
    // related types pulled in from `crate::compute` and `crate::CpuBackend`.
    use super::*;
    use crate::CpuBackend;
    use crate::compute::ComputeDispatcher;
    use crate::compute::ComputePass;
    use crate::compute::GpuBuffer;
    use crate::compute::GpuCommand;
    use crate::compute::GpuCommandEncoder;
    use crate::compute::GpuError;
    use crate::compute::KernelSpec;
    use crate::compute::MemoryBandwidthModel;
    use crate::compute::OccupancyModel;
    use crate::compute::ResourceLifecycle;
    use crate::compute::TimelineSemaphore;

    #[test]
    fn cpu_backend_buffer_roundtrip() {
        let backend = CpuBackend::new();
        let buf = backend.create_buffer(4);
        backend.write_buffer(buf, &[1.0, 2.0, 3.0, 4.0]);
        let data = backend.read_buffer(buf);
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn dispatcher_buffer_write_read_roundtrip() {
        let mut d = ComputeDispatcher::new();
        let id = d.create_buffer(5, None);
        d.write_buffer(id, &[1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let out = d.read_buffer(id).unwrap();
        assert_eq!(out, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn dispatcher_buffer_initial_data() {
        let mut d = ComputeDispatcher::new();
        let id = d.create_buffer(3, Some(&[10.0, 20.0, 30.0]));
        let out = d.read_buffer(id).unwrap();
        assert_eq!(out, vec![10.0, 20.0, 30.0]);
    }

    #[test]
    fn dispatcher_invalid_buffer_read_errors() {
        let d = ComputeDispatcher::new();
        let bad_id = BufferId(99);
        assert_eq!(d.read_buffer(bad_id), Err(GpuError::InvalidBuffer(bad_id)));
    }

    #[test]
    fn dispatch_map_identity() {
        let mut d = ComputeDispatcher::new();
        let src = d.create_buffer(4, Some(&[1.0, 2.0, 3.0, 4.0]));
        let dst = d.create_buffer(4, None);
        d.dispatch_map(src, dst, |x| x).unwrap();
        assert_eq!(d.read_buffer(dst).unwrap(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn dispatch_map_scale_by_two() {
        let mut d = ComputeDispatcher::new();
        let src = d.create_buffer(3, Some(&[1.0, 2.0, 3.0]));
        let dst = d.create_buffer(3, None);
        d.dispatch_map(src, dst, |x| x * 2.0).unwrap();
        assert_eq!(d.read_buffer(dst).unwrap(), vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn dispatch_reduce_sum() {
        let mut d = ComputeDispatcher::new();
        let id = d.create_buffer(5, Some(&[1.0, 2.0, 3.0, 4.0, 5.0]));
        let sum = d.dispatch_reduce(id, |a, b| a + b).unwrap();
        assert!((sum - 15.0).abs() < 1e-12);
    }

    #[test]
    fn dispatch_reduce_max() {
        let mut d = ComputeDispatcher::new();
        let id = d.create_buffer(5, Some(&[3.0, 1.0, 7.0, 2.0, 5.0]));
        let max = d.dispatch_reduce(id, f64::max).unwrap();
        assert!((max - 7.0).abs() < 1e-12);
    }

    #[test]
    fn dispatch_reduce_empty_errors() {
        let mut d = ComputeDispatcher::new();
        let id = d.create_buffer(0, None);
        assert_eq!(
            d.dispatch_reduce(id, |a, b| a + b),
            Err(GpuError::EmptyBuffer)
        );
    }

    #[test]
    fn sph_density_single_particle_self_contribution_positive() {
        let mut d = ComputeDispatcher::new();
        let pos = d.create_buffer(3, Some(&[0.0, 0.0, 0.0]));
        let mass = d.create_buffer(1, Some(&[1.0]));
        let out = d.create_buffer(1, None);
        d.dispatch_sph_density(pos, mass, 1.0, out).unwrap();
        let density = d.read_buffer(out).unwrap();
        assert_eq!(density.len(), 1);
        assert!((density[0] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn sph_density_two_particles_within_kernel_positive() {
        let mut d = ComputeDispatcher::new();
        let pos = d.create_buffer(6, Some(&[0.0, 0.0, 0.0, 0.5, 0.0, 0.0]));
        let mass = d.create_buffer(2, Some(&[1.0, 1.0]));
        let out = d.create_buffer(2, None);
        d.dispatch_sph_density(pos, mass, 2.0, out).unwrap();
        let density = d.read_buffer(out).unwrap();
        assert_eq!(density.len(), 2);
        assert!(
            density[0] > 0.0,
            "density[0] should be positive: {}",
            density[0]
        );
        assert!(
            density[1] > 0.0,
            "density[1] should be positive: {}",
            density[1]
        );
    }

    #[test]
    fn sph_density_particles_outside_kernel_zero_cross_contribution() {
        let mut d = ComputeDispatcher::new();
        let pos = d.create_buffer(6, Some(&[0.0, 0.0, 0.0, 100.0, 0.0, 0.0]));
        let mass = d.create_buffer(2, Some(&[1.0, 1.0]));
        let out = d.create_buffer(2, None);
        d.dispatch_sph_density(pos, mass, 1.0, out).unwrap();
        let density = d.read_buffer(out).unwrap();
        assert!((density[0] - 1.0).abs() < 1e-12);
        assert!((density[1] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn kernel_spec_creation() {
        let b0 = BufferId(0);
        let b1 = BufferId(1);
        let spec = KernelSpec::new("sph_density", 64, vec![b0, b1]);
        assert_eq!(spec.name, "sph_density");
        assert_eq!(spec.workgroup_size, [64, 1, 1]);
        assert_eq!(spec.buffer_bindings.len(), 2);
    }

    #[test]
    fn gpu_buffer_new_zeros() {
        let buf = GpuBuffer::new(8);
        assert_eq!(buf.size, 8);
        assert!(buf.data.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_buffer_binding_shorthands() {
        let id = BufferId(5);
        let br = BufferBinding::read(0, id);
        assert_eq!(br.usage, BufferUsage::ReadOnly);
        let bw = BufferBinding::write(1, id);
        assert_eq!(bw.usage, BufferUsage::WriteOnly);
        let brw = BufferBinding::read_write(2, id);
        assert_eq!(brw.usage, BufferUsage::ReadWrite);
        let bu = BufferBinding::uniform(3, id);
        assert_eq!(bu.usage, BufferUsage::Uniform);
    }

    #[test]
    fn test_kernel_spec_3d_workgroup() {
        let spec = KernelSpec::with_workgroup_3d("test", [8, 8, 4], vec![]);
        assert_eq!(spec.workgroup_size, [8, 8, 4]);
        assert_eq!(spec.threads_per_workgroup(), 256);
    }

    #[test]
    fn test_kernel_spec_num_workgroups() {
        let spec = KernelSpec::new("test", 64, vec![]);
        assert_eq!(spec.num_workgroups_x(100), 2);
        assert_eq!(spec.num_workgroups_x(64), 1);
        assert_eq!(spec.num_workgroups_x(65), 2);
    }

    #[test]
    fn test_gpu_buffer_fill_and_clear() {
        let mut buf = GpuBuffer::new(5);
        buf.fill(42.0);
        assert!(buf.data.iter().all(|&v| (v - 42.0).abs() < 1e-12));
        buf.clear();
        assert!(buf.data.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_gpu_buffer_byte_size() {
        let buf = GpuBuffer::new(10);
        assert_eq!(buf.byte_size(), 80);
    }

    #[test]
    fn test_gpu_buffer_as_slice() {
        let buf = GpuBuffer::from_data(vec![1.0, 2.0, 3.0]);
        assert_eq!(buf.as_slice(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_cpu_backend_num_buffers() {
        let backend = CpuBackend::new();
        assert_eq!(backend.num_buffers(), 0);
        backend.create_buffer(10);
        assert_eq!(backend.num_buffers(), 1);
        backend.create_buffer(5);
        assert_eq!(backend.num_buffers(), 2);
    }

    #[test]
    fn test_cpu_backend_total_elements() {
        let backend = CpuBackend::new();
        backend.create_buffer(10);
        backend.create_buffer(5);
        assert_eq!(backend.total_elements(), 15);
    }

    #[test]
    fn test_dispatcher_num_buffers() {
        let mut d = ComputeDispatcher::new();
        assert_eq!(d.num_buffers(), 0);
        d.create_buffer(5, None);
        assert_eq!(d.num_buffers(), 1);
    }

    #[test]
    fn test_dispatcher_has_buffer() {
        let mut d = ComputeDispatcher::new();
        let id = d.create_buffer(5, None);
        assert!(d.has_buffer(id));
        assert!(!d.has_buffer(BufferId(999)));
    }

    #[test]
    fn test_dispatcher_buffer_size() {
        let mut d = ComputeDispatcher::new();
        let id = d.create_buffer(7, None);
        assert_eq!(d.buffer_size(id).unwrap(), 7);
    }

    #[test]
    fn test_dispatcher_destroy_buffer() {
        let mut d = ComputeDispatcher::new();
        let id = d.create_buffer(5, None);
        assert!(d.has_buffer(id));
        d.destroy_buffer(id).unwrap();
        assert!(!d.has_buffer(id));
    }

    #[test]
    fn test_dispatcher_destroy_invalid_buffer_errors() {
        let mut d = ComputeDispatcher::new();
        assert_eq!(
            d.destroy_buffer(BufferId(42)),
            Err(GpuError::InvalidBuffer(BufferId(42)))
        );
    }

    #[test]
    fn test_dispatcher_copy_buffer() {
        let mut d = ComputeDispatcher::new();
        let src = d.create_buffer(3, Some(&[1.0, 2.0, 3.0]));
        let dst = d.create_buffer(3, None);
        d.copy_buffer(src, dst).unwrap();
        assert_eq!(d.read_buffer(dst).unwrap(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_dispatcher_copy_buffer_size_mismatch() {
        let mut d = ComputeDispatcher::new();
        let src = d.create_buffer(3, Some(&[1.0, 2.0, 3.0]));
        let dst = d.create_buffer(5, None);
        assert!(d.copy_buffer(src, dst).is_err());
    }

    #[test]
    fn test_dispatch_map_indexed() {
        let mut d = ComputeDispatcher::new();
        let src = d.create_buffer(4, Some(&[10.0, 20.0, 30.0, 40.0]));
        let dst = d.create_buffer(4, None);
        d.dispatch_map_indexed(src, dst, |i, x| x + i as f64)
            .unwrap();
        assert_eq!(d.read_buffer(dst).unwrap(), vec![10.0, 21.0, 32.0, 43.0]);
    }

    #[test]
    fn test_dispatch_zip_map() {
        let mut d = ComputeDispatcher::new();
        let a = d.create_buffer(3, Some(&[1.0, 2.0, 3.0]));
        let b = d.create_buffer(3, Some(&[10.0, 20.0, 30.0]));
        let out = d.create_buffer(3, None);
        d.dispatch_zip_map(a, b, out, |x, y| x + y).unwrap();
        assert_eq!(d.read_buffer(out).unwrap(), vec![11.0, 22.0, 33.0]);
    }

    #[test]
    fn test_compute_pass_recording() {
        let mut pass = ComputePass::new();
        assert_eq!(pass.num_commands(), 0);
        pass.dispatch("density", 1000);
        pass.dispatch("force", 1000);
        pass.dispatch("integrate", 1000);
        assert_eq!(pass.num_commands(), 3);
        assert_eq!(pass.total_work_items(), 3000);
        assert_eq!(pass.commands()[0].0, "density");
        assert_eq!(pass.commands()[1].1, 1000);
    }

    #[test]
    fn test_compute_pass_clear() {
        let mut pass = ComputePass::new();
        pass.dispatch("test", 100);
        assert_eq!(pass.num_commands(), 1);
        pass.clear();
        assert_eq!(pass.num_commands(), 0);
    }

    #[test]
    fn test_resource_lifecycle_tracking() {
        let mut lifecycle = ResourceLifecycle::new();
        assert!(lifecycle.is_empty());
        let id = BufferId(0);
        lifecycle.record_create(id, 100);
        lifecycle.record_write(id);
        lifecycle.record_write(id);
        lifecycle.record_read(id);
        assert_eq!(lifecycle.len(), 4);
        assert_eq!(lifecycle.count_writes(id), 2);
        assert_eq!(lifecycle.count_reads(id), 1);
    }

    #[test]
    fn test_resource_lifecycle_clear() {
        let mut lifecycle = ResourceLifecycle::new();
        lifecycle.record_create(BufferId(0), 10);
        lifecycle.clear();
        assert!(lifecycle.is_empty());
    }

    #[test]
    fn test_compute_num_workgroups() {
        assert_eq!(compute_num_workgroups(100, 64), 2);
        assert_eq!(compute_num_workgroups(64, 64), 1);
        assert_eq!(compute_num_workgroups(1, 64), 1);
        // No work items means no workgroups.
        assert_eq!(compute_num_workgroups(0, 64), 0);
    }

    #[test]
    fn test_compute_num_workgroups_3d() {
        let wg = compute_num_workgroups_3d([100, 100, 100], [8, 8, 8]);
        assert_eq!(wg, [13, 13, 13]);
    }

    #[test]
    fn test_gpu_error_display() {
        let e = GpuError::InvalidBuffer(BufferId(5));
        assert!(format!("{e}").contains("5"));
        let e2 = GpuError::SizeMismatch {
            expected: 10,
            got: 5,
        };
        assert!(format!("{e2}").contains("10"));
        let e3 = GpuError::EmptyBuffer;
        assert!(format!("{e3}").contains("empty"));
        let e4 = GpuError::NotFound("test".to_string());
        assert!(format!("{e4}").contains("test"));
    }

    #[test]
    fn test_command_encoder_basic() {
        let mut enc = GpuCommandEncoder::new("test_pass");
        assert_eq!(enc.label(), "test_pass");
        assert_eq!(enc.command_count(), 0);
        enc.dispatch_compute("density", [64, 1, 1]);
        enc.dispatch_compute("force", [64, 1, 1]);
        enc.insert_barrier(PipelineBarrier::StorageReadAfterWrite);
        assert_eq!(enc.command_count(), 3);
    }

    #[test]
    fn test_command_encoder_reset() {
        let mut enc = GpuCommandEncoder::new("enc");
        enc.dispatch_compute("k", [1, 1, 1]);
        enc.reset();
        assert_eq!(enc.command_count(), 0);
    }

    #[test]
    fn test_command_encoder_submit_copies() {
        let mut enc = GpuCommandEncoder::new("enc");
        let mut d = ComputeDispatcher::new();
        let src = d.create_buffer(3, Some(&[1.0, 2.0, 3.0]));
        let dst = d.create_buffer(3, None);
        enc.copy_buffer(src, dst, 3);
        enc.submit(&mut d).unwrap();
        assert_eq!(d.read_buffer(dst).unwrap(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_command_encoder_push_constant() {
        let mut enc = GpuCommandEncoder::new("enc");
        enc.push_constant("dt", 0.001);
        assert_eq!(enc.command_count(), 1);
        match &enc.commands()[0] {
            GpuCommand::PushConstant { name, value } => {
                assert_eq!(name, "dt");
                assert!((value - 0.001).abs() < 1e-15);
            }
            _ => panic!("expected PushConstant"),
        }
    }

    #[test]
    fn test_required_barrier_overlap() {
        let a_out = vec![BufferId(0), BufferId(1)];
        let b_in = vec![BufferId(1), BufferId(2)];
        let barrier = required_barrier(&a_out, &b_in);
        assert_eq!(barrier, PipelineBarrier::StorageReadAfterWrite);
    }

    #[test]
    fn test_required_barrier_no_overlap() {
        let a_out = vec![BufferId(0)];
        let b_in = vec![BufferId(5)];
        let barrier = required_barrier(&a_out, &b_in);
        assert_eq!(barrier, PipelineBarrier::None);
    }

    #[test]
    fn test_detect_aliasing_conflict() {
        let bindings = vec![
            BufferBinding::write(0, BufferId(10)),
            BufferBinding::read(1, BufferId(10)),
        ];
        let conflicts = detect_aliasing(&bindings);
        // Pin the exact pair, not just "something was reported".
        assert_eq!(conflicts, vec![(0, 1)], "should detect aliasing conflict");
    }

    #[test]
    fn test_detect_aliasing_no_conflict() {
        let bindings = vec![
            BufferBinding::read(0, BufferId(10)),
            BufferBinding::read(1, BufferId(11)),
        ];
        let conflicts = detect_aliasing(&bindings);
        assert!(conflicts.is_empty(), "no conflict expected");
    }

    #[test]
    fn test_detect_aliasing_same_buffer_two_reads() {
        let bindings = vec![
            BufferBinding::read(0, BufferId(5)),
            BufferBinding::read(1, BufferId(5)),
        ];
        let conflicts = detect_aliasing(&bindings);
        assert!(conflicts.is_empty());
    }

    #[test]
    fn test_timeline_semaphore_signal_and_wait() {
        let mut sem = TimelineSemaphore::new();
        assert_eq!(sem.current_value(), 0);
        sem.signal(1);
        assert_eq!(sem.current_value(), 1);
        assert!(sem.wait(1));
        assert!(!sem.wait(2));
        sem.signal(3);
        assert!(sem.wait(3));
        assert_eq!(sem.signal_count(), 2);
    }

    #[test]
    fn test_timeline_semaphore_default() {
        let sem = TimelineSemaphore::default();
        assert_eq!(sem.current_value(), 0);
    }

    #[test]
    fn test_occupancy_full_when_unconstrained() {
        let model = OccupancyModel::mid_range();
        let occ = model.estimate_occupancy(64, 0, 32);
        assert!(
            occ > 0.5,
            "occupancy should be high for small workgroup, got {occ}"
        );
    }

    #[test]
    fn test_occupancy_limited_by_shared_memory() {
        let model = OccupancyModel::mid_range();
        // Using all of a CU's shared memory should never beat using half of it.
        let occ_full_smem = model.estimate_occupancy(64, model.shared_mem_per_cu, 1);
        let occ_half_smem = model.estimate_occupancy(64, model.shared_mem_per_cu / 2, 1);
        assert!(
            occ_full_smem <= occ_half_smem,
            "more smem usage should give lower or equal occupancy"
        );
    }

    #[test]
    fn test_occupancy_bounded_to_one() {
        let model = OccupancyModel::mid_range();
        let occ = model.estimate_occupancy(1, 0, 0);
        assert!((0.0..=1.0).contains(&occ));
    }

    #[test]
    fn test_peak_gflops_positive() {
        let model = OccupancyModel::mid_range();
        let gflops = model.peak_gflops(1500.0);
        assert!(gflops > 0.0);
    }

    #[test]
    fn test_warp_divergence_none() {
        let predicates = vec![true; 32];
        let rec = analyse_warp_divergence(&predicates, 32);
        assert_eq!(rec.divergent_branches, 0);
        assert!((rec.divergence_rate()).abs() < 1e-12);
    }

    #[test]
    fn test_warp_divergence_full() {
        let predicates: Vec<bool> = (0..32).map(|i| i % 2 == 0).collect();
        let rec = analyse_warp_divergence(&predicates, 32);
        assert_eq!(rec.total_branches, 1);
        assert_eq!(rec.divergent_branches, 1);
        assert!((rec.divergence_rate() - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_warp_divergence_partial_last_warp() {
        // 36 items with warp size 32: one uniform full warp, then a 4-lane
        // tail warp that mixes true and false.
        let mut predicates = vec![true; 32];
        predicates.extend_from_slice(&[true, false, true, false]);
        let rec = analyse_warp_divergence(&predicates, 32);
        assert_eq!(rec.total_branches, 2);
        assert_eq!(rec.divergent_branches, 1);
    }

    #[test]
    fn test_warp_divergence_penalty() {
        let rec = WarpDivergenceRecord {
            total_branches: 10,
            divergent_branches: 5,
        };
        let penalty = rec.performance_penalty(32);
        assert!(
            penalty > 1.0 && penalty < 2.0,
            "penalty should be in (1, 2), got {penalty}"
        );
    }

    #[test]
    fn test_warp_divergence_empty() {
        let rec = analyse_warp_divergence(&[], 32);
        assert_eq!(rec.total_branches, 0);
        assert!((rec.divergence_rate()).abs() < 1e-12);
    }

    #[test]
    fn test_memory_bandwidth_arithmetic_intensity() {
        let intensity = MemoryBandwidthModel::arithmetic_intensity(1000.0, 100.0);
        assert!((intensity - 10.0).abs() < 1e-12);
    }

    #[test]
    fn test_memory_bandwidth_zero_bytes() {
        let intensity = MemoryBandwidthModel::arithmetic_intensity(100.0, 0.0);
        assert!(intensity.is_infinite());
    }

    #[test]
    fn test_roofline_bandwidth_bound() {
        let model = MemoryBandwidthModel::mid_range();
        let perf = model.roofline_performance(0.1);
        let expected = 0.1 * model.peak_bandwidth_gbs;
        assert!(
            (perf - expected).abs() < 1e-6,
            "bandwidth-bound perf mismatch"
        );
    }

    #[test]
    fn test_roofline_compute_bound() {
        let model = MemoryBandwidthModel::mid_range();
        let perf = model.roofline_performance(1e9);
        assert!((perf - model.peak_compute_gflops).abs() < 1e-6);
    }

    #[test]
    fn test_is_bandwidth_bound() {
        let model = MemoryBandwidthModel::mid_range();
        let ridge = model.peak_compute_gflops / model.peak_bandwidth_gbs;
        assert!(model.is_bandwidth_bound(ridge * 0.5));
        assert!(!model.is_bandwidth_bound(ridge * 2.0));
    }

    #[test]
    fn test_estimated_runtime_ms_positive() {
        let model = MemoryBandwidthModel::mid_range();
        let t = model.estimated_runtime_ms(1e12, 1e9);
        assert!(t > 0.0 && t.is_finite());
    }

    #[test]
    fn test_reduction_tree_sum() {
        let mut d = ComputeDispatcher::new();
        let buf = d.create_buffer(4, Some(&[1.0, 2.0, 3.0, 4.0]));
        let result = d.dispatch_reduction_tree(buf).unwrap();
        assert!(
            (result - 10.0).abs() < 1e-12,
            "sum should be 10, got {result}"
        );
    }

    #[test]
    fn test_reduction_tree_empty() {
        let mut d = ComputeDispatcher::new();
        let buf = d.create_buffer(0, Some(&[]));
        let result = d.dispatch_reduction_tree(buf).unwrap();
        assert_eq!(result, 0.0);
    }

    #[test]
    fn test_reduction_tree_single_element() {
        let mut d = ComputeDispatcher::new();
        let buf = d.create_buffer(1, Some(&[42.0]));
        let result = d.dispatch_reduction_tree(buf).unwrap();
        assert!((result - 42.0).abs() < 1e-12);
    }

    #[test]
    fn test_reduction_tree_power_of_two() {
        let data: Vec<f64> = (1..=8).map(|x| x as f64).collect();
        let mut d = ComputeDispatcher::new();
        let buf = d.create_buffer(8, Some(&data));
        let result = d.dispatch_reduction_tree(buf).unwrap();
        assert!((result - 36.0).abs() < 1e-12, "1+2+…+8=36, got {result}");
    }

    #[test]
    fn test_inclusive_scan_basic() {
        let mut d = ComputeDispatcher::new();
        let buf_in = d.create_buffer(4, Some(&[1.0, 2.0, 3.0, 4.0]));
        let buf_out = d.create_buffer(4, None);
        d.dispatch_inclusive_scan(buf_in, buf_out).unwrap();
        let result = d.read_buffer(buf_out).unwrap();
        let expected = [1.0, 3.0, 6.0, 10.0];
        // `zip` stops at the shorter side, so check the length explicitly.
        assert_eq!(result.len(), expected.len(), "scan output length mismatch");
        for (a, b) in result.iter().zip(expected.iter()) {
            assert!((a - b).abs() < 1e-12, "mismatch: {a} vs {b}");
        }
    }

    #[test]
    fn test_inclusive_scan_single() {
        let mut d = ComputeDispatcher::new();
        let buf_in = d.create_buffer(1, Some(&[7.0]));
        let buf_out = d.create_buffer(1, None);
        d.dispatch_inclusive_scan(buf_in, buf_out).unwrap();
        let result = d.read_buffer(buf_out).unwrap();
        assert!((result[0] - 7.0).abs() < 1e-12);
    }

    #[test]
    fn test_radix_sort_basic() {
        let data = vec![5.0, 1.0, 3.0, 2.0, 4.0];
        let mut d = ComputeDispatcher::new();
        let buf = d.create_buffer(5, Some(&data));
        let sorted = d.dispatch_radix_sort(buf).unwrap();
        for w in sorted.windows(2) {
            assert!(w[0] <= w[1], "not sorted: {} > {}", w[0], w[1]);
        }
        // Ordering alone would accept dropped or duplicated values; the output
        // must also be a permutation of the input.
        let mut expected = data.clone();
        expected.sort_by(f64::total_cmp);
        assert_eq!(sorted, expected);
    }

    #[test]
    fn test_radix_sort_empty() {
        let mut d = ComputeDispatcher::new();
        let buf = d.create_buffer(0, Some(&[]));
        let sorted = d.dispatch_radix_sort(buf).unwrap();
        assert!(sorted.is_empty());
    }

    #[test]
    fn test_radix_sort_already_sorted() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let mut d = ComputeDispatcher::new();
        let buf = d.create_buffer(5, Some(&data));
        let sorted = d.dispatch_radix_sort(buf).unwrap();
        assert_eq!(sorted, data);
    }

    #[test]
    fn test_radix_sort_length_preserved() {
        let data: Vec<f64> = (0..16).map(|i| (16 - i) as f64).collect();
        let mut d = ComputeDispatcher::new();
        let buf = d.create_buffer(16, Some(&data));
        let sorted = d.dispatch_radix_sort(buf).unwrap();
        assert_eq!(sorted.len(), 16);
    }
}
744}