1#![allow(missing_docs)]
10#![allow(ambiguous_glob_reexports)]
11#![allow(dead_code)]
12
13mod error;
14pub use error::*;
15
16pub mod bvh;
17pub mod cell_list;
18pub mod compute;
19pub mod compute_pipeline;
20pub mod flux_compute;
21pub mod gpu_bench;
22pub mod grid_reduce;
23pub mod kernels;
24pub mod lbm_gpu;
25pub mod neural_compute;
26pub mod parallel;
27pub mod parallel_sort;
28pub mod particle_system;
29pub mod pipeline;
30pub mod sdf_compute;
31pub mod shader_registry;
32pub mod shaders;
33pub mod sparse_gpu;
34pub mod sph_gpu;
35
36pub use compute::{BufferHandle, ComputeBackend, ComputeKernel, CpuBackend};
37pub use neural_compute::*;
38pub use particle_system::*;
39pub use sparse_gpu::*;
40
/// Number of groups of `group_size` items needed to cover `total` items
/// (ceiling division).
///
/// A `group_size` of zero yields `0` instead of dividing by zero.
pub fn dispatch_count(total: usize, group_size: usize) -> usize {
    match group_size {
        0 => 0,
        per_group => total.div_ceil(per_group),
    }
}
52
/// Rounds `size` up to the nearest multiple of `alignment`.
///
/// An `alignment` of zero means "no alignment": `size` is returned unchanged.
pub fn aligned_size(size: usize, alignment: usize) -> usize {
    if alignment != 0 {
        size.next_multiple_of(alignment)
    } else {
        size
    }
}
62
/// Flattens a 3-D coordinate into a linear index with `x` varying fastest,
/// then `y`, then `z`.
///
/// `dim_x` and `dim_y` are the extents of the two fastest-varying axes.
#[allow(dead_code)]
pub fn linear_index_3d(x: usize, y: usize, z: usize, dim_x: usize, dim_y: usize) -> usize {
    let slice_offset = z * dim_x * dim_y;
    let row_offset = y * dim_x;
    slice_offset + row_offset + x
}
68
/// Splits a linear index back into `(x, y, z)`, the inverse of the layout
/// where `x` varies fastest, then `y`, then `z`.
///
/// # Panics
///
/// Panics with a division-by-zero error when `dim_x * dim_y` is zero.
#[allow(dead_code)]
pub fn index_3d_from_linear(index: usize, dim_x: usize, dim_y: usize) -> (usize, usize, usize) {
    let slice_len = dim_x * dim_y;
    // Division is evaluated first, exactly as before, so a zero slice length
    // fails in the same place.
    let (z, within_slice) = (index / slice_len, index % slice_len);
    (within_slice % dim_x, within_slice / dim_x, z)
}
78
/// A labelled slot holding one elapsed-time measurement.
#[derive(Clone, Debug)]
pub struct DispatchTimer {
    /// Name identifying what is being timed.
    pub label: String,
    /// Last value passed to [`DispatchTimer::record`]; `0.0` until then.
    pub elapsed_secs: f64,
}

impl DispatchTimer {
    /// Creates a timer with the given label and an elapsed time of `0.0`.
    pub fn new(label: impl Into<String>) -> Self {
        let label = label.into();
        Self {
            label,
            elapsed_secs: 0.0,
        }
    }

    /// Stores `elapsed`, overwriting any previously recorded value.
    pub fn record(&mut self, elapsed: f64) {
        self.elapsed_secs = elapsed;
    }
}
102
/// Converts a byte count and a duration in seconds into gigabytes per second
/// (1 GB = 1e9 bytes).
///
/// Returns `0.0` when `elapsed_secs` is zero or negative. A NaN duration is
/// not caught by that guard and produces NaN.
#[allow(dead_code)]
pub fn bandwidth_gb_s(bytes_transferred: usize, elapsed_secs: f64) -> f64 {
    let no_time_elapsed = elapsed_secs <= 0.0;
    if no_time_elapsed {
        0.0
    } else {
        let bytes_per_sec = bytes_transferred as f64 / elapsed_secs;
        bytes_per_sec / 1e9
    }
}
114
/// Number of whole elements of `element_size` bytes that fit into
/// `budget_bytes`.
///
/// Returns `0` when `element_size` is zero.
#[allow(dead_code)]
pub fn elements_in_budget(budget_bytes: usize, element_size: usize) -> usize {
    // `checked_div` yields `None` only for a zero divisor.
    budget_bytes.checked_div(element_size).unwrap_or(0)
}
126
127#[allow(dead_code)]
134pub fn row_pitch(elements_per_row: usize, element_size: usize, alignment: usize) -> usize {
135 let raw = elements_per_row * element_size;
136 aligned_size(raw, alignment)
137}
138
139#[allow(dead_code)]
141pub fn buffer_size_2d(
142 width: usize,
143 height: usize,
144 element_size: usize,
145 row_alignment: usize,
146) -> usize {
147 row_pitch(width, element_size, row_alignment) * height
148}
149
/// Returns the smallest power of two that is greater than or equal to
/// `value`. An input of `0` yields `1`.
///
/// # Panics
///
/// Panics when no such power of two fits in `usize`, i.e. when `value`
/// exceeds `1 << (usize::BITS - 1)`. The previous shift loop never terminated
/// for those inputs: the shift dropped the top bit, leaving `0`, which stayed
/// below `value` forever.
pub fn next_power_of_two(value: usize) -> usize {
    // `max(1)` keeps the documented `0 -> 1` mapping explicit.
    value
        .max(1)
        .checked_next_power_of_two()
        .unwrap_or_else(|| panic!("next_power_of_two: no power of two >= {value} fits in usize"))
}
164
/// Returns `true` when `value` is an exact power of two.
///
/// Zero is not a power of two.
pub fn is_power_of_two(value: usize) -> bool {
    // A power of two has exactly one bit set; zero has none.
    value.count_ones() == 1
}
169
170pub fn log2_pow2(v: usize) -> u32 {
173 debug_assert!(is_power_of_two(v), "{v} is not a power of two");
174 v.trailing_zeros()
175}
176
/// Number of tiles of size `tw` x `th` needed to cover a `width` x `height`
/// area, returned as `(tiles_x, tiles_y)`. Partial tiles at the edges count
/// as full tiles.
///
/// A zero tile extent yields zero tiles along that axis rather than a
/// division-by-zero panic, matching how `dispatch_count` treats a zero group
/// size.
pub fn tile_count_2d(width: usize, height: usize, tw: usize, th: usize) -> (usize, usize) {
    /// Ceiling division that maps a zero tile extent to zero tiles.
    fn tiles_along(extent: usize, tile: usize) -> usize {
        if tile == 0 {
            0
        } else {
            extent.div_ceil(tile)
        }
    }

    (tiles_along(width, tw), tiles_along(height, th))
}
188
189pub fn total_tiles_2d(width: usize, height: usize, tw: usize, th: usize) -> usize {
191 let (tx, ty) = tile_count_2d(width, height, tw, th);
192 tx * ty
193}
194
/// Converts a flat tile index into `(column, row)` for a grid that is
/// `tiles_x` tiles wide.
///
/// # Panics
///
/// Panics when `tiles_x` is zero.
pub fn tile_index_to_2d(flat: usize, tiles_x: usize) -> (usize, usize) {
    let column = flat % tiles_x;
    let row = flat / tiles_x;
    (column, row)
}
200
/// Restricts `v` to the range `[lo, hi]` by applying `max(lo)` and then
/// `min(hi)`.
///
/// Unlike [`f64::clamp`] this never panics: a NaN `v` resolves to `lo`
/// (then capped by `hi`), and when `lo > hi` the result is `hi`.
pub fn clamp_f64(v: f64, lo: f64, hi: f64) -> f64 {
    let raised = f64::max(v, lo);
    f64::min(raised, hi)
}
207
208pub fn smoothstep(lo: f64, hi: f64, v: f64) -> f64 {
210 let t = clamp_f64((v - lo) / (hi - lo), 0.0, 1.0);
211 t * t * (3.0 - 2.0 * t)
212}
213
214pub fn smootherstep(lo: f64, hi: f64, v: f64) -> f64 {
216 let t = clamp_f64((v - lo) / (hi - lo), 0.0, 1.0);
217 t * t * t * (t * (t * 6.0 - 15.0) + 10.0)
218}
219
/// Linear interpolation from `a` (at `t = 0`) to `b` (at `t = 1`).
///
/// `t` is not clamped, so values outside `[0, 1]` extrapolate.
pub fn lerp(a: f64, b: f64, t: f64) -> f64 {
    let span = b - a;
    a + t * span
}
224
/// Inverse of linear interpolation: the parameter `t` for which
/// `a + t * (b - a)` equals `v`. The result is not clamped.
///
/// Returns `0.0` when `|b - a|` is smaller than [`f64::EPSILON`]. This is an
/// absolute threshold, so any range narrower than that is treated as
/// degenerate.
pub fn inv_lerp(a: f64, b: f64, v: f64) -> f64 {
    let span = b - a;
    if span.abs() < f64::EPSILON {
        0.0
    } else {
        (v - a) / span
    }
}
232
/// Reciprocal `1 / x`, or `0.0` when `|x|` does not exceed `eps`.
///
/// A NaN `x` also yields `0.0`, because the magnitude comparison is false.
pub fn safe_recip(x: f64, eps: f64) -> f64 {
    let invertible = x.abs() > eps;
    if invertible {
        x.recip()
    } else {
        0.0
    }
}
239
/// Square root that treats negative inputs as zero instead of returning NaN.
///
/// A NaN input also yields `0.0`, since `f64::max` ignores a NaN operand.
pub fn safe_sqrt(x: f64) -> f64 {
    let non_negative = f64::max(x, 0.0);
    f64::sqrt(non_negative)
}
244
/// Wraps an angle in radians into the half-open interval `(-PI, PI]`.
///
/// Non-finite inputs produce NaN.
pub fn wrap_angle(theta: f64) -> f64 {
    use std::f64::consts::PI;
    let full_turn = 2.0 * PI;
    // `%` keeps the sign of `theta`, so the remainder lies in (-2*PI, 2*PI).
    let rem = theta % full_turn;
    if rem > PI {
        rem - full_turn
    } else if rem <= -PI {
        rem + full_turn
    } else {
        rem
    }
}
257
/// Dot product of two 3-component vectors.
pub fn dot3(a: [f64; 3], b: [f64; 3]) -> f64 {
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;
    ax * bx + ay * by + az * bz
}
264
/// Cross product `a x b` of two 3-component vectors.
pub fn cross3(a: [f64; 3], b: [f64; 3]) -> [f64; 3] {
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;
    let x = ay * bz - az * by;
    let y = az * bx - ax * bz;
    let z = ax * by - ay * bx;
    [x, y, z]
}
273
274pub fn length3(v: [f64; 3]) -> f64 {
276 dot3(v, v).sqrt()
277}
278
279pub fn normalize3(v: [f64; 3]) -> [f64; 3] {
281 let len = length3(v);
282 if len < 1e-15 {
283 return [0.0; 3];
284 }
285 [v[0] / len, v[1] / len, v[2] / len]
286}
287
288pub fn reflect3(d: [f64; 3], n: [f64; 3]) -> [f64; 3] {
290 let dn2 = 2.0 * dot3(d, n);
291 [d[0] - dn2 * n[0], d[1] - dn2 * n[1], d[2] - dn2 * n[2]]
292}
293
/// Exclusive prefix sum: element `i` of the output is the sum of
/// `data[..i]`, so the first output element is always `0.0`.
///
/// The output has the same length as `data`.
pub fn exclusive_scan(data: &[f64]) -> Vec<f64> {
    let mut running = 0.0;
    data.iter()
        .map(|&value| {
            // Emit the total *before* adding the current element.
            let before = running;
            running += value;
            before
        })
        .collect()
}
309
/// Inclusive prefix sum: element `i` of the output is the sum of
/// `data[..=i]`.
///
/// The output has the same length as `data`.
pub fn inclusive_scan(data: &[f64]) -> Vec<f64> {
    let mut running = 0.0;
    data.iter()
        .map(|&value| {
            running += value;
            running
        })
        .collect()
}
320
/// Sum of all elements, accumulated in slice order. An empty slice sums to
/// zero.
pub fn reduce_sum(data: &[f64]) -> f64 {
    data.iter().sum()
}
325
/// Largest element of `data`.
///
/// Returns [`f64::NEG_INFINITY`] for an empty slice. NaN elements are
/// skipped, because `f64::max` returns the non-NaN operand.
pub fn reduce_max(data: &[f64]) -> f64 {
    let mut largest = f64::NEG_INFINITY;
    for &value in data {
        largest = largest.max(value);
    }
    largest
}
330
/// Smallest element of `data`.
///
/// Returns [`f64::INFINITY`] for an empty slice. NaN elements are skipped,
/// because `f64::min` returns the non-NaN operand.
pub fn reduce_min(data: &[f64]) -> f64 {
    let mut smallest = f64::INFINITY;
    for &value in data {
        smallest = smallest.min(value);
    }
    smallest
}
335
/// Unit tests for the free helper functions and `DispatchTimer` defined in
/// this file.
#[cfg(test)]
mod gpu_util_tests {
    use super::*;
    use std::f64::consts::PI;

    // --- dispatch_count / aligned_size ---

    #[test]
    fn test_dispatch_count_exact() {
        assert_eq!(dispatch_count(256, 64), 4);
    }

    #[test]
    fn test_dispatch_count_remainder() {
        // One extra item forces an additional group.
        assert_eq!(dispatch_count(257, 64), 5);
    }

    #[test]
    fn test_dispatch_count_zero_group() {
        assert_eq!(dispatch_count(100, 0), 0);
    }

    #[test]
    fn test_aligned_size_exact() {
        assert_eq!(aligned_size(256, 64), 256);
    }

    #[test]
    fn test_aligned_size_pad() {
        assert_eq!(aligned_size(257, 64), 320);
    }

    #[test]
    fn test_aligned_size_zero_alignment() {
        // Zero alignment leaves the size unchanged.
        assert_eq!(aligned_size(100, 0), 100);
    }

    // --- 3-D index <-> linear index ---

    #[test]
    fn test_linear_index_3d() {
        assert_eq!(linear_index_3d(0, 0, 0, 4, 3), 0);
        // z * (dim_x * dim_y) + y * dim_x + x with dim_x = 4, dim_y = 3.
        assert_eq!(linear_index_3d(3, 2, 1, 4, 3), 12 + 2 * 4 + 3);
    }

    #[test]
    fn test_index_3d_roundtrip() {
        // Every coordinate of a 4 x 3 x 2 grid must survive encode + decode.
        let (dx, dy) = (4, 3);
        for z in 0..2 {
            for y in 0..dy {
                for x in 0..dx {
                    let idx = linear_index_3d(x, y, z, dx, dy);
                    let (rx, ry, rz) = index_3d_from_linear(idx, dx, dy);
                    assert_eq!((rx, ry, rz), (x, y, z));
                }
            }
        }
    }

    // --- DispatchTimer / bandwidth / budget ---

    #[test]
    fn test_dispatch_timer() {
        let mut timer = DispatchTimer::new("test");
        assert_eq!(timer.label, "test");
        timer.record(0.5);
        assert!((timer.elapsed_secs - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_bandwidth_gb_s() {
        // 1e9 bytes in one second is exactly 1 GB/s.
        let bw = bandwidth_gb_s(1_000_000_000, 1.0);
        assert!((bw - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_bandwidth_zero_time() {
        assert!((bandwidth_gb_s(1000, 0.0)).abs() < 1e-10);
    }

    #[test]
    fn test_elements_in_budget() {
        assert_eq!(elements_in_budget(1024, 4), 256);
        assert_eq!(elements_in_budget(1024, 0), 0);
    }

    // --- scans and reductions ---

    #[test]
    fn test_exclusive_scan() {
        let data = [1.0, 2.0, 3.0, 4.0];
        let result = exclusive_scan(&data);
        assert_eq!(result, vec![0.0, 1.0, 3.0, 6.0]);
    }

    #[test]
    fn test_inclusive_scan() {
        let data = [1.0, 2.0, 3.0, 4.0];
        let result = inclusive_scan(&data);
        assert_eq!(result, vec![1.0, 3.0, 6.0, 10.0]);
    }

    #[test]
    fn test_reduce_sum() {
        assert!((reduce_sum(&[1.0, 2.0, 3.0]) - 6.0).abs() < 1e-10);
    }

    #[test]
    fn test_reduce_max() {
        assert!((reduce_max(&[1.0, 5.0, 3.0]) - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_reduce_min() {
        assert!((reduce_min(&[1.0, 5.0, 3.0]) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_exclusive_scan_empty() {
        let result = exclusive_scan(&[]);
        assert!(result.is_empty());
    }

    #[test]
    fn test_inclusive_scan_single() {
        let result = inclusive_scan(&[42.0]);
        assert_eq!(result, vec![42.0]);
    }

    // --- row pitch / 2-D buffer size ---

    #[test]
    fn test_row_pitch_aligned() {
        // 128 * 4 = 512 is already a multiple of 256.
        assert_eq!(row_pitch(128, 4, 256), 512);
    }

    #[test]
    fn test_row_pitch_needs_padding() {
        // 100 * 4 = 400 rounds up to the next multiple of 256.
        assert_eq!(row_pitch(100, 4, 256), 512);
    }

    #[test]
    fn test_buffer_size_2d() {
        // Row pitch of 64 * 4 = 256 bytes times 4 rows.
        assert_eq!(buffer_size_2d(64, 4, 4, 256), 1024);
    }

    // --- power-of-two helpers ---

    #[test]
    fn test_next_power_of_two() {
        assert_eq!(next_power_of_two(0), 1);
        assert_eq!(next_power_of_two(1), 1);
        assert_eq!(next_power_of_two(5), 8);
        assert_eq!(next_power_of_two(8), 8);
        assert_eq!(next_power_of_two(9), 16);
    }

    #[test]
    fn test_is_power_of_two() {
        assert!(is_power_of_two(1));
        assert!(is_power_of_two(16));
        assert!(!is_power_of_two(0));
        assert!(!is_power_of_two(7));
    }

    #[test]
    fn test_log2_pow2() {
        assert_eq!(log2_pow2(1), 0);
        assert_eq!(log2_pow2(2), 1);
        assert_eq!(log2_pow2(256), 8);
    }

    // --- 2-D tiling ---

    #[test]
    fn test_tile_count_2d_exact() {
        let (tx, ty) = tile_count_2d(64, 64, 16, 16);
        assert_eq!(tx, 4);
        assert_eq!(ty, 4);
    }

    #[test]
    fn test_tile_count_2d_remainder() {
        // One extra pixel per axis requires an extra partial tile.
        let (tx, ty) = tile_count_2d(65, 65, 16, 16);
        assert_eq!(tx, 5);
        assert_eq!(ty, 5);
    }

    #[test]
    fn test_total_tiles_2d() {
        assert_eq!(total_tiles_2d(64, 64, 16, 16), 16);
    }

    #[test]
    fn test_tile_index_to_2d() {
        // Index 5 in a grid 4 tiles wide is column 1 of row 1.
        assert_eq!(tile_index_to_2d(5, 4), (1, 1));
        assert_eq!(tile_index_to_2d(0, 4), (0, 0));
    }

    // --- scalar interpolation and guards ---

    #[test]
    fn test_smoothstep_edges() {
        assert!((smoothstep(0.0, 1.0, 0.0) - 0.0).abs() < 1e-12);
        assert!((smoothstep(0.0, 1.0, 1.0) - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_smoothstep_midpoint() {
        // 3 * 0.25 - 2 * 0.125 = 0.5 at t = 0.5.
        assert!((smoothstep(0.0, 1.0, 0.5) - 0.5).abs() < 1e-12);
    }

    #[test]
    fn test_smootherstep_edges() {
        assert!((smootherstep(0.0, 1.0, 0.0)).abs() < 1e-12);
        assert!((smootherstep(0.0, 1.0, 1.0) - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_lerp_inv_lerp_roundtrip() {
        let a = 10.0;
        let b = 20.0;
        let t = 0.3;
        let v = lerp(a, b, t);
        assert!((inv_lerp(a, b, v) - t).abs() < 1e-12);
    }

    #[test]
    fn test_safe_recip_normal() {
        assert!((safe_recip(2.0, 1e-9) - 0.5).abs() < 1e-12);
    }

    #[test]
    fn test_safe_recip_near_zero() {
        // |x| below eps yields 0 rather than a huge reciprocal.
        assert!((safe_recip(1e-15, 1e-9)).abs() < 1e-12);
    }

    #[test]
    fn test_safe_sqrt_positive() {
        assert!((safe_sqrt(9.0) - 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_safe_sqrt_negative() {
        assert!((safe_sqrt(-1.0)).abs() < 1e-12);
    }

    #[test]
    fn test_wrap_angle_in_range() {
        // Only the magnitude bound is checked, not the exact wrapped value.
        let wrapped = wrap_angle(3.0 * PI);
        assert!(wrapped.abs() <= PI + 1e-12, "wrapped = {wrapped}");
    }

    // --- 3-component vector helpers ---

    #[test]
    fn test_dot3() {
        let a = [1.0, 2.0, 3.0];
        let b = [4.0, 5.0, 6.0];
        // 1*4 + 2*5 + 3*6 = 32.
        assert!((dot3(a, b) - 32.0).abs() < 1e-12);
    }

    #[test]
    fn test_cross3() {
        // x-axis cross y-axis gives the z-axis.
        let i = [1.0, 0.0, 0.0];
        let j = [0.0, 1.0, 0.0];
        let k = cross3(i, j);
        assert!((k[0]).abs() < 1e-12);
        assert!((k[1]).abs() < 1e-12);
        assert!((k[2] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_length3() {
        // 3-4-5 right triangle.
        let v = [3.0, 4.0, 0.0];
        assert!((length3(v) - 5.0).abs() < 1e-12);
    }

    #[test]
    fn test_normalize3() {
        let v = [0.0, 0.0, 5.0];
        let n = normalize3(v);
        assert!((length3(n) - 1.0).abs() < 1e-12);
        assert!((n[2] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_normalize3_zero_vec() {
        let n = normalize3([0.0; 3]);
        assert_eq!(n, [0.0; 3]);
    }

    #[test]
    fn test_reflect3() {
        // A direction pointing straight down, reflected about an upward
        // normal, must point straight up. Only the y component is checked.
        let d = [0.0, -1.0, 0.0];
        let n = [0.0, 1.0, 0.0];
        let r = reflect3(d, n);
        assert!((r[1] - 1.0).abs() < 1e-12);
    }
}
633pub mod collision_gpu;
634pub mod deformable_gpu;
635pub mod fluid_gpu;
636pub mod fluid_sim_gpu;
637pub mod gpu_cloth;
638pub mod gpu_collision_detection;
639pub mod gpu_collision_ext;
640pub mod gpu_fem_assembly;
641pub mod gpu_fluid;
642pub mod gpu_fluid_euler;
643pub mod gpu_lbm;
644pub mod gpu_md_solver;
645pub mod gpu_mesh_processing;
646pub mod gpu_neural_solver;
647pub mod gpu_nn;
648pub mod gpu_particle_system;
649pub mod gpu_particles;
650pub mod gpu_ray_tracing;
651pub mod gpu_reduction;
652pub mod gpu_rigid;
653pub mod gpu_sdf;
654pub mod gpu_sort;
655pub mod gpu_sparse_solver;
656pub mod gpu_sph_density;
657pub mod gpu_sph_pressure;
658pub mod gpu_sph_solver;
659pub mod gpu_thermal;
660pub mod gpu_voxel;
661pub mod memory;
662pub mod neural_physics;
663pub mod path_tracer;
664pub mod ray_marching;
665pub mod ray_tracing_gpu;
666pub mod raytracing;
667pub mod scheduler;