1use super::types::{
6 BufferBinding, BufferHandle, BufferId, BufferUsage, PipelineBarrier, WarpDivergenceRecord,
7};
8
/// Interface of a compute backend: buffer creation, `f64` buffer transfer and
/// kernel dispatch.
///
/// Every method takes `&self`, so an implementation that stores buffer
/// contents has to rely on interior mutability.
pub trait ComputeBackend {
    /// Returns the backend's name.
    fn name(&self) -> &str;
    /// Creates a buffer described by `size` and returns a handle to it.
    // NOTE(review): `size` presumably counts `f64` elements rather than bytes — confirm.
    fn create_buffer(&self, size: usize) -> BufferHandle;
    /// Writes `data` into the buffer identified by `handle`.
    fn write_buffer(&self, handle: BufferHandle, data: &[f64]);
    /// Returns the contents of the buffer identified by `handle` as an owned vector.
    fn read_buffer(&self, handle: BufferHandle) -> Vec<f64>;
    /// Runs `kernel` with the given `work_size`.
    fn dispatch(&self, kernel: &dyn ComputeKernel, work_size: usize);
}
/// A named unit of work that a [`ComputeBackend`] can dispatch.
pub trait ComputeKernel {
    /// Returns the kernel's name.
    fn name(&self) -> &str;
    /// Executes the kernel.
    ///
    /// `inputs` are read-only `f64` slices, `outputs` are vectors the kernel may
    /// modify, and `work_size` is the value handed to `ComputeBackend::dispatch`.
    // NOTE(review): presumably `work_size` is the number of work items — confirm.
    fn execute(&self, inputs: &[&[f64]], outputs: &mut [Vec<f64>], work_size: usize);
}
/// Number of workgroups of `workgroup_size` items needed to cover
/// `total_items` items, i.e. the quotient rounded up.
///
/// # Panics
///
/// Panics if `workgroup_size` is zero (division by zero).
#[allow(dead_code)]
pub fn compute_num_workgroups(total_items: u32, workgroup_size: u32) -> u32 {
    let full_groups = total_items / workgroup_size;
    // A non-zero remainder means one extra, partially filled workgroup.
    let has_partial_group = total_items % workgroup_size != 0;
    full_groups + u32::from(has_partial_group)
}
/// Per-axis workgroup counts for a 3D dispatch: each component of `total` is
/// divided by the matching component of `workgroup_size`, rounding up.
///
/// # Panics
///
/// Panics if any component of `workgroup_size` is zero.
#[allow(dead_code)]
pub fn compute_num_workgroups_3d(total: [u32; 3], workgroup_size: [u32; 3]) -> [u32; 3] {
    std::array::from_fn(|axis| total[axis].div_ceil(workgroup_size[axis]))
}
47#[allow(dead_code)]
52pub fn required_barrier(
53 pass_a_outputs: &[BufferId],
54 pass_b_inputs: &[BufferId],
55) -> PipelineBarrier {
56 let overlap = pass_a_outputs.iter().any(|out| pass_b_inputs.contains(out));
57 if overlap {
58 PipelineBarrier::StorageReadAfterWrite
59 } else {
60 PipelineBarrier::None
61 }
62}
63#[allow(dead_code)]
68pub fn detect_aliasing(bindings: &[BufferBinding]) -> Vec<(u32, u32)> {
69 let mut conflicts = Vec::new();
70 for i in 0..bindings.len() {
71 for j in (i + 1)..bindings.len() {
72 if bindings[i].buffer_id == bindings[j].buffer_id {
73 let write_i = matches!(
74 bindings[i].usage,
75 BufferUsage::WriteOnly | BufferUsage::ReadWrite
76 );
77 let read_j = matches!(
78 bindings[j].usage,
79 BufferUsage::ReadOnly | BufferUsage::ReadWrite
80 );
81 let write_j = matches!(
82 bindings[j].usage,
83 BufferUsage::WriteOnly | BufferUsage::ReadWrite
84 );
85 let read_i = matches!(
86 bindings[i].usage,
87 BufferUsage::ReadOnly | BufferUsage::ReadWrite
88 );
89 if write_i && read_j || write_j && read_i {
90 conflicts.push((bindings[i].binding, bindings[j].binding));
91 }
92 }
93 }
94 }
95 conflicts
96}
97#[allow(dead_code)]
102pub fn analyse_warp_divergence(predicates: &[bool], warp_size: usize) -> WarpDivergenceRecord {
103 if predicates.is_empty() || warp_size == 0 {
104 return WarpDivergenceRecord::default();
105 }
106 let mut total = 0u64;
107 let mut divergent = 0u64;
108 let n_warps = predicates.len().div_ceil(warp_size);
109 for w in 0..n_warps {
110 let start = w * warp_size;
111 let end = (start + warp_size).min(predicates.len());
112 let slice = &predicates[start..end];
113 total += 1;
114 let all_true = slice.iter().all(|&v| v);
115 let all_false = slice.iter().all(|&v| !v);
116 if !all_true && !all_false {
117 divergent += 1;
118 }
119 }
120 WarpDivergenceRecord {
121 total_branches: total,
122 divergent_branches: divergent,
123 }
124}
#[cfg(test)]
mod tests {
    // Unit tests for the helper functions of this file and for the compute
    // types imported from the crate below.
    use super::*;
    use crate::CpuBackend;
    use crate::compute::ComputeDispatcher;
    use crate::compute::ComputePass;
    use crate::compute::GpuBuffer;
    use crate::compute::GpuCommand;
    use crate::compute::GpuCommandEncoder;
    use crate::compute::GpuError;
    use crate::compute::KernelSpec;
    use crate::compute::MemoryBandwidthModel;
    use crate::compute::OccupancyModel;
    use crate::compute::ResourceLifecycle;
    use crate::compute::TimelineSemaphore;
    // Data written to a CpuBackend buffer is read back unchanged.
    #[test]
    fn cpu_backend_buffer_roundtrip() {
        let backend = CpuBackend::new();
        let buf = backend.create_buffer(4);
        backend.write_buffer(buf, &[1.0, 2.0, 3.0, 4.0]);
        let data = backend.read_buffer(buf);
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0]);
    }
    // Same round trip through the dispatcher's fallible write/read API.
    #[test]
    fn dispatcher_buffer_write_read_roundtrip() {
        let mut d = ComputeDispatcher::new();
        let id = d.create_buffer(5, None);
        d.write_buffer(id, &[1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let out = d.read_buffer(id).unwrap();
        assert_eq!(out, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    }
    // Initial data passed to `create_buffer` is visible on the first read.
    #[test]
    fn dispatcher_buffer_initial_data() {
        let mut d = ComputeDispatcher::new();
        let id = d.create_buffer(3, Some(&[10.0, 20.0, 30.0]));
        let out = d.read_buffer(id).unwrap();
        assert_eq!(out, vec![10.0, 20.0, 30.0]);
    }
    // Reading an id that was never created reports `InvalidBuffer` with that id.
    #[test]
    fn dispatcher_invalid_buffer_read_errors() {
        let d = ComputeDispatcher::new();
        let bad_id = BufferId(99);
        assert_eq!(d.read_buffer(bad_id), Err(GpuError::InvalidBuffer(bad_id)));
    }
    // Mapping with the identity closure copies the source into the destination.
    #[test]
    fn dispatch_map_identity() {
        let mut d = ComputeDispatcher::new();
        let src = d.create_buffer(4, Some(&[1.0, 2.0, 3.0, 4.0]));
        let dst = d.create_buffer(4, None);
        d.dispatch_map(src, dst, |x| x).unwrap();
        assert_eq!(d.read_buffer(dst).unwrap(), vec![1.0, 2.0, 3.0, 4.0]);
    }
    // The map closure is applied to every element.
    #[test]
    fn dispatch_map_scale_by_two() {
        let mut d = ComputeDispatcher::new();
        let src = d.create_buffer(3, Some(&[1.0, 2.0, 3.0]));
        let dst = d.create_buffer(3, None);
        d.dispatch_map(src, dst, |x| x * 2.0).unwrap();
        assert_eq!(d.read_buffer(dst).unwrap(), vec![2.0, 4.0, 6.0]);
    }
    // Reducing with addition yields the sum 1+2+3+4+5 = 15.
    #[test]
    fn dispatch_reduce_sum() {
        let mut d = ComputeDispatcher::new();
        let id = d.create_buffer(5, Some(&[1.0, 2.0, 3.0, 4.0, 5.0]));
        let sum = d.dispatch_reduce(id, |a, b| a + b).unwrap();
        assert!((sum - 15.0).abs() < 1e-12);
    }
    // Reducing with `f64::max` yields the largest element.
    #[test]
    fn dispatch_reduce_max() {
        let mut d = ComputeDispatcher::new();
        let id = d.create_buffer(5, Some(&[3.0, 1.0, 7.0, 2.0, 5.0]));
        let max = d.dispatch_reduce(id, f64::max).unwrap();
        assert!((max - 7.0).abs() < 1e-12);
    }
    // Reducing a zero-length buffer is an `EmptyBuffer` error.
    #[test]
    fn dispatch_reduce_empty_errors() {
        let mut d = ComputeDispatcher::new();
        let id = d.create_buffer(0, None);
        assert_eq!(
            d.dispatch_reduce(id, |a, b| a + b),
            Err(GpuError::EmptyBuffer)
        );
    }
    // One particle with mass 1.0: the output holds a single density equal to 1.0.
    #[test]
    fn sph_density_single_particle_self_contribution_positive() {
        let mut d = ComputeDispatcher::new();
        let pos = d.create_buffer(3, Some(&[0.0, 0.0, 0.0]));
        let mass = d.create_buffer(1, Some(&[1.0]));
        let out = d.create_buffer(1, None);
        d.dispatch_sph_density(pos, mass, 1.0, out).unwrap();
        let density = d.read_buffer(out).unwrap();
        assert_eq!(density.len(), 1);
        assert!((density[0] - 1.0).abs() < 1e-12);
    }
    // Two particles 0.5 apart with the third argument set to 2.0: both densities are positive.
    // NOTE(review): the third argument is presumably the kernel radius / smoothing length — confirm.
    #[test]
    fn sph_density_two_particles_within_kernel_positive() {
        let mut d = ComputeDispatcher::new();
        let pos = d.create_buffer(6, Some(&[0.0, 0.0, 0.0, 0.5, 0.0, 0.0]));
        let mass = d.create_buffer(2, Some(&[1.0, 1.0]));
        let out = d.create_buffer(2, None);
        d.dispatch_sph_density(pos, mass, 2.0, out).unwrap();
        let density = d.read_buffer(out).unwrap();
        assert_eq!(density.len(), 2);
        assert!(
            density[0] > 0.0,
            "density[0] should be positive: {}",
            density[0]
        );
        assert!(
            density[1] > 0.0,
            "density[1] should be positive: {}",
            density[1]
        );
    }
    // Two particles 100 apart: each density equals the single-particle value 1.0.
    #[test]
    fn sph_density_particles_outside_kernel_zero_cross_contribution() {
        let mut d = ComputeDispatcher::new();
        let pos = d.create_buffer(6, Some(&[0.0, 0.0, 0.0, 100.0, 0.0, 0.0]));
        let mass = d.create_buffer(2, Some(&[1.0, 1.0]));
        let out = d.create_buffer(2, None);
        d.dispatch_sph_density(pos, mass, 1.0, out).unwrap();
        let density = d.read_buffer(out).unwrap();
        assert!((density[0] - 1.0).abs() < 1e-12);
        assert!((density[1] - 1.0).abs() < 1e-12);
    }
    // `KernelSpec::new` with a scalar size produces a `[size, 1, 1]` workgroup.
    #[test]
    fn kernel_spec_creation() {
        let b0 = BufferId(0);
        let b1 = BufferId(1);
        let spec = KernelSpec::new("sph_density", 64, vec![b0, b1]);
        assert_eq!(spec.name, "sph_density");
        assert_eq!(spec.workgroup_size, [64, 1, 1]);
        assert_eq!(spec.buffer_bindings.len(), 2);
    }
    // A new GpuBuffer records its size and is zero-filled.
    #[test]
    fn gpu_buffer_new_zeros() {
        let buf = GpuBuffer::new(8);
        assert_eq!(buf.size, 8);
        assert!(buf.data.iter().all(|&v| v == 0.0));
    }
    // Each BufferBinding shorthand constructor sets the matching usage.
    #[test]
    fn test_buffer_binding_shorthands() {
        let id = BufferId(5);
        let br = BufferBinding::read(0, id);
        assert_eq!(br.usage, BufferUsage::ReadOnly);
        let bw = BufferBinding::write(1, id);
        assert_eq!(bw.usage, BufferUsage::WriteOnly);
        let brw = BufferBinding::read_write(2, id);
        assert_eq!(brw.usage, BufferUsage::ReadWrite);
        let bu = BufferBinding::uniform(3, id);
        assert_eq!(bu.usage, BufferUsage::Uniform);
    }
    // 8 * 8 * 4 = 256 threads per workgroup.
    #[test]
    fn test_kernel_spec_3d_workgroup() {
        let spec = KernelSpec::with_workgroup_3d("test", [8, 8, 4], vec![]);
        assert_eq!(spec.workgroup_size, [8, 8, 4]);
        assert_eq!(spec.threads_per_workgroup(), 256);
    }
    // Workgroup count along x rounds up: 100 -> 2, 64 -> 1, 65 -> 2 for size 64.
    #[test]
    fn test_kernel_spec_num_workgroups() {
        let spec = KernelSpec::new("test", 64, vec![]);
        assert_eq!(spec.num_workgroups_x(100), 2);
        assert_eq!(spec.num_workgroups_x(64), 1);
        assert_eq!(spec.num_workgroups_x(65), 2);
    }
    // `fill` sets every element; `clear` resets them to zero.
    #[test]
    fn test_gpu_buffer_fill_and_clear() {
        let mut buf = GpuBuffer::new(5);
        buf.fill(42.0);
        assert!(buf.data.iter().all(|&v| (v - 42.0).abs() < 1e-12));
        buf.clear();
        assert!(buf.data.iter().all(|&v| v == 0.0));
    }
    // 10 elements report a byte size of 80 (8 bytes per element).
    #[test]
    fn test_gpu_buffer_byte_size() {
        let buf = GpuBuffer::new(10);
        assert_eq!(buf.byte_size(), 80);
    }
    // `as_slice` exposes the data the buffer was built from.
    #[test]
    fn test_gpu_buffer_as_slice() {
        let buf = GpuBuffer::from_data(vec![1.0, 2.0, 3.0]);
        assert_eq!(buf.as_slice(), &[1.0, 2.0, 3.0]);
    }
    // The backend's buffer count grows by one per `create_buffer` call.
    #[test]
    fn test_cpu_backend_num_buffers() {
        let backend = CpuBackend::new();
        assert_eq!(backend.num_buffers(), 0);
        backend.create_buffer(10);
        assert_eq!(backend.num_buffers(), 1);
        backend.create_buffer(5);
        assert_eq!(backend.num_buffers(), 2);
    }
    // Total element count is the sum of the created buffer sizes (10 + 5).
    #[test]
    fn test_cpu_backend_total_elements() {
        let backend = CpuBackend::new();
        backend.create_buffer(10);
        backend.create_buffer(5);
        assert_eq!(backend.total_elements(), 15);
    }
    // The dispatcher starts with no buffers and counts each creation.
    #[test]
    fn test_dispatcher_num_buffers() {
        let mut d = ComputeDispatcher::new();
        assert_eq!(d.num_buffers(), 0);
        d.create_buffer(5, None);
        assert_eq!(d.num_buffers(), 1);
    }
    // `has_buffer` is true for a created id and false for an unknown one.
    #[test]
    fn test_dispatcher_has_buffer() {
        let mut d = ComputeDispatcher::new();
        let id = d.create_buffer(5, None);
        assert!(d.has_buffer(id));
        assert!(!d.has_buffer(BufferId(999)));
    }
    // `buffer_size` returns the size requested at creation.
    #[test]
    fn test_dispatcher_buffer_size() {
        let mut d = ComputeDispatcher::new();
        let id = d.create_buffer(7, None);
        assert_eq!(d.buffer_size(id).unwrap(), 7);
    }
    // A destroyed buffer is no longer reported by `has_buffer`.
    #[test]
    fn test_dispatcher_destroy_buffer() {
        let mut d = ComputeDispatcher::new();
        let id = d.create_buffer(5, None);
        assert!(d.has_buffer(id));
        d.destroy_buffer(id).unwrap();
        assert!(!d.has_buffer(id));
    }
    // Destroying an unknown id reports `InvalidBuffer` with that id.
    #[test]
    fn test_dispatcher_destroy_invalid_buffer_errors() {
        let mut d = ComputeDispatcher::new();
        assert_eq!(
            d.destroy_buffer(BufferId(42)),
            Err(GpuError::InvalidBuffer(BufferId(42)))
        );
    }
    // Copying between equally sized buffers transfers the contents.
    #[test]
    fn test_dispatcher_copy_buffer() {
        let mut d = ComputeDispatcher::new();
        let src = d.create_buffer(3, Some(&[1.0, 2.0, 3.0]));
        let dst = d.create_buffer(3, None);
        d.copy_buffer(src, dst).unwrap();
        assert_eq!(d.read_buffer(dst).unwrap(), vec![1.0, 2.0, 3.0]);
    }
    // Copying between buffers of different sizes (3 vs 5) is an error.
    #[test]
    fn test_dispatcher_copy_buffer_size_mismatch() {
        let mut d = ComputeDispatcher::new();
        let src = d.create_buffer(3, Some(&[1.0, 2.0, 3.0]));
        let dst = d.create_buffer(5, None);
        assert!(d.copy_buffer(src, dst).is_err());
    }
    // The indexed map passes the element index: value + index is stored.
    #[test]
    fn test_dispatch_map_indexed() {
        let mut d = ComputeDispatcher::new();
        let src = d.create_buffer(4, Some(&[10.0, 20.0, 30.0, 40.0]));
        let dst = d.create_buffer(4, None);
        d.dispatch_map_indexed(src, dst, |i, x| x + i as f64)
            .unwrap();
        assert_eq!(d.read_buffer(dst).unwrap(), vec![10.0, 21.0, 32.0, 43.0]);
    }
    // The zip map combines two buffers element by element.
    #[test]
    fn test_dispatch_zip_map() {
        let mut d = ComputeDispatcher::new();
        let a = d.create_buffer(3, Some(&[1.0, 2.0, 3.0]));
        let b = d.create_buffer(3, Some(&[10.0, 20.0, 30.0]));
        let out = d.create_buffer(3, None);
        d.dispatch_zip_map(a, b, out, |x, y| x + y).unwrap();
        assert_eq!(d.read_buffer(out).unwrap(), vec![11.0, 22.0, 33.0]);
    }
    // A pass records dispatches in order and sums their work items.
    #[test]
    fn test_compute_pass_recording() {
        let mut pass = ComputePass::new();
        assert_eq!(pass.num_commands(), 0);
        pass.dispatch("density", 1000);
        pass.dispatch("force", 1000);
        pass.dispatch("integrate", 1000);
        assert_eq!(pass.num_commands(), 3);
        assert_eq!(pass.total_work_items(), 3000);
        assert_eq!(pass.commands()[0].0, "density");
        assert_eq!(pass.commands()[1].1, 1000);
    }
    // `clear` removes all recorded commands.
    #[test]
    fn test_compute_pass_clear() {
        let mut pass = ComputePass::new();
        pass.dispatch("test", 100);
        assert_eq!(pass.num_commands(), 1);
        pass.clear();
        assert_eq!(pass.num_commands(), 0);
    }
    // One create, two writes and one read give four events in total.
    #[test]
    fn test_resource_lifecycle_tracking() {
        let mut lifecycle = ResourceLifecycle::new();
        assert!(lifecycle.is_empty());
        let id = BufferId(0);
        lifecycle.record_create(id, 100);
        lifecycle.record_write(id);
        lifecycle.record_write(id);
        lifecycle.record_read(id);
        assert_eq!(lifecycle.len(), 4);
        assert_eq!(lifecycle.count_writes(id), 2);
        assert_eq!(lifecycle.count_reads(id), 1);
    }
    // `clear` empties the lifecycle log.
    #[test]
    fn test_resource_lifecycle_clear() {
        let mut lifecycle = ResourceLifecycle::new();
        lifecycle.record_create(BufferId(0), 10);
        lifecycle.clear();
        assert!(lifecycle.is_empty());
    }
    // Rounding up: 100/64 -> 2, 64/64 -> 1, 1/64 -> 1.
    #[test]
    fn test_compute_num_workgroups() {
        assert_eq!(compute_num_workgroups(100, 64), 2);
        assert_eq!(compute_num_workgroups(64, 64), 1);
        assert_eq!(compute_num_workgroups(1, 64), 1);
    }
    // 100/8 rounds up to 13 on every axis.
    #[test]
    fn test_compute_num_workgroups_3d() {
        let wg = compute_num_workgroups_3d([100, 100, 100], [8, 8, 8]);
        assert_eq!(wg, [13, 13, 13]);
    }
    // Each error's Display output mentions its payload (or "empty").
    #[test]
    fn test_gpu_error_display() {
        let e = GpuError::InvalidBuffer(BufferId(5));
        assert!(format!("{e}").contains("5"));
        let e2 = GpuError::SizeMismatch {
            expected: 10,
            got: 5,
        };
        assert!(format!("{e2}").contains("10"));
        let e3 = GpuError::EmptyBuffer;
        assert!(format!("{e3}").contains("empty"));
        let e4 = GpuError::NotFound("test".to_string());
        assert!(format!("{e4}").contains("test"));
    }
    // Dispatches and barriers each count as one recorded command.
    #[test]
    fn test_command_encoder_basic() {
        let mut enc = GpuCommandEncoder::new("test_pass");
        assert_eq!(enc.label(), "test_pass");
        assert_eq!(enc.command_count(), 0);
        enc.dispatch_compute("density", [64, 1, 1]);
        enc.dispatch_compute("force", [64, 1, 1]);
        enc.insert_barrier(PipelineBarrier::StorageReadAfterWrite);
        assert_eq!(enc.command_count(), 3);
    }
    // `reset` drops all recorded commands.
    #[test]
    fn test_command_encoder_reset() {
        let mut enc = GpuCommandEncoder::new("enc");
        enc.dispatch_compute("k", [1, 1, 1]);
        enc.reset();
        assert_eq!(enc.command_count(), 0);
    }
    // Submitting an encoder with a copy command performs the copy on the dispatcher.
    #[test]
    fn test_command_encoder_submit_copies() {
        let mut enc = GpuCommandEncoder::new("enc");
        let mut d = ComputeDispatcher::new();
        let src = d.create_buffer(3, Some(&[1.0, 2.0, 3.0]));
        let dst = d.create_buffer(3, None);
        enc.copy_buffer(src, dst, 3);
        enc.submit(&mut d).unwrap();
        assert_eq!(d.read_buffer(dst).unwrap(), vec![1.0, 2.0, 3.0]);
    }
    // A push constant is recorded as `GpuCommand::PushConstant` with its name and value.
    #[test]
    fn test_command_encoder_push_constant() {
        let mut enc = GpuCommandEncoder::new("enc");
        enc.push_constant("dt", 0.001);
        assert_eq!(enc.command_count(), 1);
        match &enc.commands()[0] {
            GpuCommand::PushConstant { name, value } => {
                assert_eq!(name, "dt");
                assert!((value - 0.001).abs() < 1e-15);
            }
            _ => panic!("expected PushConstant"),
        }
    }
    // BufferId(1) is written by pass A and read by pass B, so a barrier is required.
    #[test]
    fn test_required_barrier_overlap() {
        let a_out = vec![BufferId(0), BufferId(1)];
        let b_in = vec![BufferId(1), BufferId(2)];
        let barrier = required_barrier(&a_out, &b_in);
        assert_eq!(barrier, PipelineBarrier::StorageReadAfterWrite);
    }
    // Disjoint buffer sets need no barrier.
    #[test]
    fn test_required_barrier_no_overlap() {
        let a_out = vec![BufferId(0)];
        let b_in = vec![BufferId(5)];
        let barrier = required_barrier(&a_out, &b_in);
        assert_eq!(barrier, PipelineBarrier::None);
    }
    // A write and a read binding on the same buffer are reported.
    #[test]
    fn test_detect_aliasing_conflict() {
        let bindings = vec![
            BufferBinding::write(0, BufferId(10)),
            BufferBinding::read(1, BufferId(10)),
        ];
        let conflicts = detect_aliasing(&bindings);
        assert!(!conflicts.is_empty(), "should detect aliasing conflict");
    }
    // Bindings on different buffers never conflict.
    #[test]
    fn test_detect_aliasing_no_conflict() {
        let bindings = vec![
            BufferBinding::read(0, BufferId(10)),
            BufferBinding::read(1, BufferId(11)),
        ];
        let conflicts = detect_aliasing(&bindings);
        assert!(conflicts.is_empty(), "no conflict expected");
    }
    // Two read bindings on the same buffer are not a conflict.
    #[test]
    fn test_detect_aliasing_same_buffer_two_reads() {
        let bindings = vec![
            BufferBinding::read(0, BufferId(5)),
            BufferBinding::read(1, BufferId(5)),
        ];
        let conflicts = detect_aliasing(&bindings);
        assert!(conflicts.is_empty());
    }
    // `wait(v)` succeeds once the signalled value has reached `v`; two signals are counted.
    #[test]
    fn test_timeline_semaphore_signal_and_wait() {
        let mut sem = TimelineSemaphore::new();
        assert_eq!(sem.current_value(), 0);
        sem.signal(1);
        assert_eq!(sem.current_value(), 1);
        assert!(sem.wait(1));
        assert!(!sem.wait(2));
        sem.signal(3);
        assert!(sem.wait(3));
        assert_eq!(sem.signal_count(), 2);
    }
    // The default semaphore starts at value 0.
    #[test]
    fn test_timeline_semaphore_default() {
        let sem = TimelineSemaphore::default();
        assert_eq!(sem.current_value(), 0);
    }
    // With the arguments (64, 0, 32) the estimated occupancy exceeds 0.5.
    #[test]
    fn test_occupancy_full_when_unconstrained() {
        let model = OccupancyModel::mid_range();
        let occ = model.estimate_occupancy(64, 0, 32);
        assert!(
            occ > 0.5,
            "occupancy should be high for small workgroup, got {occ}"
        );
    }
    // `occ` passes the full `shared_mem_per_cu` and `occ_limited` only half of it,
    // so despite its name `occ_limited` is the less constrained of the two.
    #[test]
    fn test_occupancy_limited_by_shared_memory() {
        let model = OccupancyModel::mid_range();
        let occ = model.estimate_occupancy(64, model.shared_mem_per_cu, 1);
        let occ_limited = model.estimate_occupancy(64, model.shared_mem_per_cu / 2, 1);
        assert!(
            occ <= occ_limited,
            "more smem usage should give lower or equal occupancy"
        );
    }
    // The estimate stays within [0, 1] for the arguments (1, 0, 0).
    #[test]
    fn test_occupancy_bounded_to_one() {
        let model = OccupancyModel::mid_range();
        let occ = model.estimate_occupancy(1, 0, 0);
        assert!((0.0..=1.0).contains(&occ));
    }
    // Peak GFLOPS for the argument 1500.0 is positive.
    #[test]
    fn test_peak_gflops_positive() {
        let model = OccupancyModel::mid_range();
        let gflops = model.peak_gflops(1500.0);
        assert!(gflops > 0.0);
    }
    // A warp whose predicates are all `true` does not diverge.
    #[test]
    fn test_warp_divergence_none() {
        let predicates = vec![true; 32];
        let rec = analyse_warp_divergence(&predicates, 32);
        assert_eq!(rec.divergent_branches, 0);
        assert!((rec.divergence_rate()).abs() < 1e-12);
    }
    // Alternating predicates in a single warp: that warp diverges, rate 1.0.
    #[test]
    fn test_warp_divergence_full() {
        let predicates: Vec<bool> = (0..32).map(|i| i % 2 == 0).collect();
        let rec = analyse_warp_divergence(&predicates, 32);
        assert_eq!(rec.divergent_branches, 1);
        assert!((rec.divergence_rate() - 1.0).abs() < 1e-12);
    }
    // With 5 of 10 branches divergent the penalty lies strictly between 1 and 2.
    #[test]
    fn test_warp_divergence_penalty() {
        let rec = WarpDivergenceRecord {
            total_branches: 10,
            divergent_branches: 5,
        };
        let penalty = rec.performance_penalty(32);
        assert!(
            penalty > 1.0 && penalty < 2.0,
            "penalty should be > 1, got {penalty}"
        );
    }
    // Empty input yields zero branches and a divergence rate of 0.
    #[test]
    fn test_warp_divergence_empty() {
        let rec = analyse_warp_divergence(&[], 32);
        assert_eq!(rec.total_branches, 0);
        assert!((rec.divergence_rate()).abs() < 1e-12);
    }
    // 1000 / 100 = 10.
    #[test]
    fn test_memory_bandwidth_arithmetic_intensity() {
        let intensity = MemoryBandwidthModel::arithmetic_intensity(1000.0, 100.0);
        assert!((intensity - 10.0).abs() < 1e-12);
    }
    // A zero second argument gives an infinite intensity.
    #[test]
    fn test_memory_bandwidth_zero_bytes() {
        let intensity = MemoryBandwidthModel::arithmetic_intensity(100.0, 0.0);
        assert!(intensity.is_infinite());
    }
    // At low intensity the roofline equals intensity * peak bandwidth.
    #[test]
    fn test_roofline_bandwidth_bound() {
        let model = MemoryBandwidthModel::mid_range();
        let perf = model.roofline_performance(0.1);
        let expected = 0.1 * model.peak_bandwidth_gbs;
        assert!(
            (perf - expected).abs() < 1e-6,
            "bandwidth-bound perf mismatch"
        );
    }
    // At very high intensity the roofline equals the peak compute rate.
    #[test]
    fn test_roofline_compute_bound() {
        let model = MemoryBandwidthModel::mid_range();
        let perf = model.roofline_performance(1e9);
        assert!((perf - model.peak_compute_gflops).abs() < 1e-6);
    }
    // Below the ridge point (peak compute / peak bandwidth) is bandwidth bound; above is not.
    #[test]
    fn test_is_bandwidth_bound() {
        let model = MemoryBandwidthModel::mid_range();
        let ridge = model.peak_compute_gflops / model.peak_bandwidth_gbs;
        assert!(model.is_bandwidth_bound(ridge * 0.5));
        assert!(!model.is_bandwidth_bound(ridge * 2.0));
    }
    // The runtime estimate is positive and finite for non-zero inputs.
    #[test]
    fn test_estimated_runtime_ms_positive() {
        let model = MemoryBandwidthModel::mid_range();
        let t = model.estimated_runtime_ms(1e12, 1e9);
        assert!(t > 0.0 && t.is_finite());
    }
    // The reduction tree sums 1+2+3+4 = 10.
    #[test]
    fn test_reduction_tree_sum() {
        let mut d = ComputeDispatcher::new();
        let buf = d.create_buffer(4, Some(&[1.0, 2.0, 3.0, 4.0]));
        let result = d.dispatch_reduction_tree(buf).unwrap();
        assert!(
            (result - 10.0).abs() < 1e-12,
            "sum should be 10, got {result}"
        );
    }
    // Unlike `dispatch_reduce`, the reduction tree returns 0.0 for an empty buffer.
    #[test]
    fn test_reduction_tree_empty() {
        let mut d = ComputeDispatcher::new();
        let buf = d.create_buffer(0, Some(&[]));
        let result = d.dispatch_reduction_tree(buf).unwrap();
        assert_eq!(result, 0.0);
    }
    // A single element is returned as is.
    #[test]
    fn test_reduction_tree_single_element() {
        let mut d = ComputeDispatcher::new();
        let buf = d.create_buffer(1, Some(&[42.0]));
        let result = d.dispatch_reduction_tree(buf).unwrap();
        assert!((result - 42.0).abs() < 1e-12);
    }
    // Eight elements (a power of two): 1+2+…+8 = 36.
    #[test]
    fn test_reduction_tree_power_of_two() {
        let data: Vec<f64> = (1..=8).map(|x| x as f64).collect();
        let mut d = ComputeDispatcher::new();
        let buf = d.create_buffer(8, Some(&data));
        let result = d.dispatch_reduction_tree(buf).unwrap();
        assert!((result - 36.0).abs() < 1e-12, "1+2+…+8=36, got {result}");
    }
    // The inclusive scan of [1, 2, 3, 4] is the running sum [1, 3, 6, 10].
    #[test]
    fn test_inclusive_scan_basic() {
        let mut d = ComputeDispatcher::new();
        let buf_in = d.create_buffer(4, Some(&[1.0, 2.0, 3.0, 4.0]));
        let buf_out = d.create_buffer(4, None);
        d.dispatch_inclusive_scan(buf_in, buf_out).unwrap();
        let result = d.read_buffer(buf_out).unwrap();
        let expected = [1.0, 3.0, 6.0, 10.0];
        for (a, b) in result.iter().zip(expected.iter()) {
            assert!((a - b).abs() < 1e-12, "mismatch: {a} vs {b}");
        }
    }
    // Scanning a single element returns that element.
    #[test]
    fn test_inclusive_scan_single() {
        let mut d = ComputeDispatcher::new();
        let buf_in = d.create_buffer(1, Some(&[7.0]));
        let buf_out = d.create_buffer(1, None);
        d.dispatch_inclusive_scan(buf_in, buf_out).unwrap();
        let result = d.read_buffer(buf_out).unwrap();
        assert!((result[0] - 7.0).abs() < 1e-12);
    }
    // The sort output is in non-decreasing order (only ordering is checked here).
    #[test]
    fn test_radix_sort_basic() {
        let data = vec![5.0, 1.0, 3.0, 2.0, 4.0];
        let mut d = ComputeDispatcher::new();
        let buf = d.create_buffer(5, Some(&data));
        let sorted = d.dispatch_radix_sort(buf).unwrap();
        for w in sorted.windows(2) {
            assert!(w[0] <= w[1], "not sorted: {} > {}", w[0], w[1]);
        }
    }
    // Sorting an empty buffer yields an empty result.
    #[test]
    fn test_radix_sort_empty() {
        let mut d = ComputeDispatcher::new();
        let buf = d.create_buffer(0, Some(&[]));
        let sorted = d.dispatch_radix_sort(buf).unwrap();
        assert!(sorted.is_empty());
    }
    // Already sorted input comes back unchanged.
    #[test]
    fn test_radix_sort_already_sorted() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let mut d = ComputeDispatcher::new();
        let buf = d.create_buffer(5, Some(&data));
        let sorted = d.dispatch_radix_sort(buf).unwrap();
        assert_eq!(sorted, data);
    }
    // Sorting 16 descending values returns 16 values.
    #[test]
    fn test_radix_sort_length_preserved() {
        let data: Vec<f64> = (0..16).map(|i| (16 - i) as f64).collect();
        let mut d = ComputeDispatcher::new();
        let buf = d.create_buffer(16, Some(&data));
        let sorted = d.dispatch_radix_sort(buf).unwrap();
        assert_eq!(sorted.len(), 16);
    }
}