1#![allow(missing_docs)]
10#![allow(ambiguous_glob_reexports)]
11#![allow(dead_code)]
12
13mod error;
14pub use error::*;
15
16pub mod bvh;
17pub mod cell_list;
18pub mod compute;
19pub mod compute_pipeline;
20pub mod flux_compute;
21pub mod grid_reduce;
22pub mod kernels;
23pub mod neural_compute;
24pub mod parallel;
25pub mod parallel_sort;
26pub mod particle_system;
27pub mod pipeline;
28pub mod sdf_compute;
29pub mod shader_registry;
30pub mod shaders;
31pub mod sparse_gpu;
32
33pub use compute::{BufferHandle, ComputeBackend, ComputeKernel, CpuBackend};
34pub use neural_compute::*;
35pub use particle_system::*;
36pub use sparse_gpu::*;
37
/// Returns how many groups of `group_size` items are needed to cover `total` items.
///
/// The quotient is rounded up, so a trailing partial group counts as a whole
/// group. A `group_size` of zero yields `0` instead of dividing by zero.
pub fn dispatch_count(total: usize, group_size: usize) -> usize {
    match group_size {
        0 => 0,
        width => total.div_ceil(width),
    }
}
49
/// Rounds `size` up to the nearest multiple of `alignment`.
///
/// An `alignment` of zero is treated as "no alignment requirement": `size` is
/// returned unchanged.
pub fn aligned_size(size: usize, alignment: usize) -> usize {
    if alignment != 0 {
        size.next_multiple_of(alignment)
    } else {
        size
    }
}
59
/// Flattens the 3-D coordinate `(x, y, z)` into a single linear index.
///
/// `x` varies fastest, then `y`, then `z`; the result is
/// `z * dim_x * dim_y + y * dim_x + x`. The coordinates are not checked
/// against the dimensions.
#[allow(dead_code)]
pub fn linear_index_3d(x: usize, y: usize, z: usize, dim_x: usize, dim_y: usize) -> usize {
    let slab_offset = z * dim_x * dim_y;
    let row_offset = y * dim_x;
    slab_offset + row_offset + x
}
65
/// Splits a linear index back into its `(x, y, z)` coordinate.
///
/// This undoes the layout `z * dim_x * dim_y + y * dim_x + x`, in which `x`
/// varies fastest.
///
/// # Panics
///
/// Panics with a division by zero if `dim_x` or `dim_y` is zero.
#[allow(dead_code)]
pub fn index_3d_from_linear(index: usize, dim_x: usize, dim_y: usize) -> (usize, usize, usize) {
    // Number of cells in one z-slab.
    let slab = dim_x * dim_y;
    let (z, within_slab) = (index / slab, index % slab);
    let (y, x) = (within_slab / dim_x, within_slab % dim_x);
    (x, y, z)
}
75
/// A labelled slot that stores the elapsed time of a single dispatch.
#[derive(Debug, Clone)]
pub struct DispatchTimer {
    /// Name identifying what was timed.
    pub label: String,
    /// Most recently recorded duration (seconds, going by the field name).
    pub elapsed_secs: f64,
}

impl DispatchTimer {
    /// Creates a timer called `label` whose elapsed time starts at zero.
    pub fn new(label: impl Into<String>) -> Self {
        let label = label.into();
        Self {
            label,
            elapsed_secs: 0.0,
        }
    }

    /// Stores `elapsed`, replacing (not accumulating onto) any earlier value.
    pub fn record(&mut self, elapsed: f64) {
        self.elapsed_secs = elapsed;
    }
}
99
/// Computes throughput in gigabytes per second, with 1 GB taken as `1e9` bytes.
///
/// Returns `0.0` when `elapsed_secs` is zero or negative, so callers never
/// see an infinite or negative rate. A NaN `elapsed_secs` is not caught by
/// that guard and propagates into the result.
#[allow(dead_code)]
pub fn bandwidth_gb_s(bytes_transferred: usize, elapsed_secs: f64) -> f64 {
    match elapsed_secs {
        secs if secs <= 0.0 => 0.0,
        secs => (bytes_transferred as f64) / secs / 1e9,
    }
}
111
/// Returns how many whole elements of `element_size` bytes fit into
/// `budget_bytes`.
///
/// A zero `element_size` yields `0` rather than dividing by zero.
#[allow(dead_code)]
pub fn elements_in_budget(budget_bytes: usize, element_size: usize) -> usize {
    // `checked_div` is `None` only for a zero divisor.
    budget_bytes.checked_div(element_size).unwrap_or(0)
}
123
124#[allow(dead_code)]
131pub fn row_pitch(elements_per_row: usize, element_size: usize, alignment: usize) -> usize {
132 let raw = elements_per_row * element_size;
133 aligned_size(raw, alignment)
134}
135
136#[allow(dead_code)]
138pub fn buffer_size_2d(
139 width: usize,
140 height: usize,
141 element_size: usize,
142 row_alignment: usize,
143) -> usize {
144 row_pitch(width, element_size, row_alignment) * height
145}
146
/// Returns the smallest power of two that is greater than or equal to `value`.
///
/// `0` maps to `1`. A value that is already a power of two is returned
/// unchanged.
///
/// # Panics
///
/// Panics if the result does not fit in a `usize`, i.e. when `value` exceeds
/// `1 << (usize::BITS - 1)`.
pub fn next_power_of_two(value: usize) -> usize {
    // Use the checked std routine instead of a hand-rolled `p <<= 1` loop:
    // such a loop shifts the top bit out for inputs above the largest
    // representable power of two, leaving `p == 0` and never terminating.
    value
        .checked_next_power_of_two()
        .expect("next_power_of_two: result does not fit in usize")
}
161
/// Reports whether `value` is an exact power of two.
///
/// Zero is not a power of two.
pub fn is_power_of_two(value: usize) -> bool {
    // A power of two has exactly one bit set; zero has none.
    value.count_ones() == 1
}
166
167pub fn log2_pow2(v: usize) -> u32 {
170 debug_assert!(is_power_of_two(v), "{v} is not a power of two");
171 v.trailing_zeros()
172}
173
/// Returns how many tiles of size `tw` x `th` are needed to cover a
/// `width` x `height` area, as `(tiles_x, tiles_y)`.
///
/// Partial tiles at the right and bottom edges count as whole tiles. A zero
/// tile dimension gives a count of `0` along that axis.
pub fn tile_count_2d(width: usize, height: usize, tw: usize, th: usize) -> (usize, usize) {
    // Guard the divisor the same way `dispatch_count` does for a zero group
    // size, so a zero tile extent reports "no tiles" instead of panicking.
    let tiles_along = |extent: usize, tile: usize| -> usize {
        if tile == 0 {
            0
        } else {
            extent.div_ceil(tile)
        }
    };
    (tiles_along(width, tw), tiles_along(height, th))
}
185
186pub fn total_tiles_2d(width: usize, height: usize, tw: usize, th: usize) -> usize {
188 let (tx, ty) = tile_count_2d(width, height, tw, th);
189 tx * ty
190}
191
/// Converts a flat tile index into `(column, row)` for a grid that is
/// `tiles_x` tiles wide.
///
/// # Panics
///
/// Panics if `tiles_x` is zero.
pub fn tile_index_to_2d(flat: usize, tiles_x: usize) -> (usize, usize) {
    let column = flat % tiles_x;
    let row = flat / tiles_x;
    (column, row)
}
197
/// Restricts `v` to the interval `[lo, hi]`.
///
/// The lower bound is applied first, then the upper bound. Unlike
/// [`f64::clamp`], this never panics: if `lo > hi` the result is `hi`, and a
/// NaN `v` is replaced by the bounds because `f64::max`/`f64::min` ignore a
/// NaN operand.
pub fn clamp_f64(v: f64, lo: f64, hi: f64) -> f64 {
    let raised = f64::max(v, lo);
    f64::min(raised, hi)
}
204
205pub fn smoothstep(lo: f64, hi: f64, v: f64) -> f64 {
207 let t = clamp_f64((v - lo) / (hi - lo), 0.0, 1.0);
208 t * t * (3.0 - 2.0 * t)
209}
210
211pub fn smootherstep(lo: f64, hi: f64, v: f64) -> f64 {
213 let t = clamp_f64((v - lo) / (hi - lo), 0.0, 1.0);
214 t * t * t * (t * (t * 6.0 - 15.0) + 10.0)
215}
216
/// Linearly interpolates from `a` (at `t = 0`) towards `b` (at `t = 1`).
///
/// `t` is not clamped, so values outside `[0, 1]` extrapolate.
pub fn lerp(a: f64, b: f64, t: f64) -> f64 {
    let delta = b - a;
    a + t * delta
}
221
/// Returns the interpolation parameter at which `v` lies between `a` and `b`.
///
/// This is the inverse of [`lerp`]: `(v - a) / (b - a)`. When `a` and `b`
/// differ by less than `f64::EPSILON` the range is treated as degenerate and
/// `0.0` is returned.
pub fn inv_lerp(a: f64, b: f64, v: f64) -> f64 {
    let span = b - a;
    if span.abs() < f64::EPSILON {
        0.0
    } else {
        (v - a) / span
    }
}
229
/// Returns `1 / x`, or `0.0` when `x` is too close to zero.
///
/// The reciprocal is taken only when `|x|` is strictly greater than `eps`;
/// every other input, including NaN, yields `0.0`.
pub fn safe_recip(x: f64, eps: f64) -> f64 {
    let magnitude = x.abs();
    if magnitude > eps {
        x.recip()
    } else {
        0.0
    }
}
236
/// Square root that treats negative inputs as zero.
///
/// `x` is raised to at least `0.0` before the root is taken, so the result is
/// never NaN for a negative argument.
pub fn safe_sqrt(x: f64) -> f64 {
    let non_negative = f64::max(x, 0.0);
    non_negative.sqrt()
}
241
/// Wraps an angle in radians into the half-open interval `(-PI, PI]`.
///
/// NaN and infinite inputs produce NaN.
pub fn wrap_angle(theta: f64) -> f64 {
    use std::f64::consts::PI;
    let full_turn = 2.0 * PI;
    // `%` keeps the sign of `theta`, so `rem` lies strictly inside
    // (-2*PI, 2*PI) and at most one correction by a full turn is needed.
    let rem = theta % full_turn;
    if rem > PI {
        rem - full_turn
    } else if rem <= -PI {
        rem + full_turn
    } else {
        rem
    }
}
254
/// Dot product of two 3-component vectors.
pub fn dot3(a: [f64; 3], b: [f64; 3]) -> f64 {
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;
    ax * bx + ay * by + az * bz
}
261
/// Cross product `a x b` of two 3-component vectors.
pub fn cross3(a: [f64; 3], b: [f64; 3]) -> [f64; 3] {
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;
    let x = ay * bz - az * by;
    let y = az * bx - ax * bz;
    let z = ax * by - ay * bx;
    [x, y, z]
}
270
271pub fn length3(v: [f64; 3]) -> f64 {
273 dot3(v, v).sqrt()
274}
275
276pub fn normalize3(v: [f64; 3]) -> [f64; 3] {
278 let len = length3(v);
279 if len < 1e-15 {
280 return [0.0; 3];
281 }
282 [v[0] / len, v[1] / len, v[2] / len]
283}
284
285pub fn reflect3(d: [f64; 3], n: [f64; 3]) -> [f64; 3] {
287 let dn2 = 2.0 * dot3(d, n);
288 [d[0] - dn2 * n[0], d[1] - dn2 * n[1], d[2] - dn2 * n[2]]
289}
290
/// Exclusive prefix sum: element `i` of the result is the sum of
/// `data[..i]`.
///
/// The first output is therefore `0.0`, and the output has the same length
/// as `data`.
pub fn exclusive_scan(data: &[f64]) -> Vec<f64> {
    let mut prefix = Vec::with_capacity(data.len());
    prefix.extend(data.iter().scan(0.0_f64, |running, &value| {
        // Emit the total *before* this element is added.
        let before = *running;
        *running += value;
        Some(before)
    }));
    prefix
}
306
/// Inclusive prefix sum: element `i` of the result is the sum of
/// `data[..=i]`.
///
/// The output has the same length as `data`.
pub fn inclusive_scan(data: &[f64]) -> Vec<f64> {
    let mut prefix = Vec::with_capacity(data.len());
    prefix.extend(data.iter().scan(0.0_f64, |running, &value| {
        // Emit the total *after* this element is added.
        *running += value;
        Some(*running)
    }));
    prefix
}
317
/// Sums all elements of `data` in slice order.
pub fn reduce_sum(data: &[f64]) -> f64 {
    data.iter().sum::<f64>()
}
322
/// Returns the largest element of `data`.
///
/// An empty slice yields `f64::NEG_INFINITY`. NaN elements are skipped
/// because `f64::max` returns its non-NaN operand.
pub fn reduce_max(data: &[f64]) -> f64 {
    let mut best = f64::NEG_INFINITY;
    for &candidate in data {
        best = best.max(candidate);
    }
    best
}
327
/// Returns the smallest element of `data`.
///
/// An empty slice yields `f64::INFINITY`. NaN elements are skipped because
/// `f64::min` returns its non-NaN operand.
pub fn reduce_min(data: &[f64]) -> f64 {
    let mut best = f64::INFINITY;
    for &candidate in data {
        best = best.min(candidate);
    }
    best
}
332
#[cfg(test)]
mod gpu_util_tests {
    use super::*;
    use std::f64::consts::PI;

    /// Asserts that `actual` lies strictly within `tol` of `expected`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!((actual - expected).abs() < tol);
    }

    // --- dispatch and alignment helpers ---

    #[test]
    fn test_dispatch_count_exact() {
        assert_eq!(dispatch_count(256, 64), 4);
    }

    #[test]
    fn test_dispatch_count_remainder() {
        assert_eq!(dispatch_count(257, 64), 5);
    }

    #[test]
    fn test_dispatch_count_zero_group() {
        assert_eq!(dispatch_count(100, 0), 0);
    }

    #[test]
    fn test_aligned_size_exact() {
        assert_eq!(aligned_size(256, 64), 256);
    }

    #[test]
    fn test_aligned_size_pad() {
        assert_eq!(aligned_size(257, 64), 320);
    }

    #[test]
    fn test_aligned_size_zero_alignment() {
        assert_eq!(aligned_size(100, 0), 100);
    }

    // --- 3-D index mapping ---

    #[test]
    fn test_linear_index_3d() {
        assert_eq!(linear_index_3d(0, 0, 0, 4, 3), 0);
        // One full 4x3 slab (12), two rows of 4, then three columns.
        assert_eq!(linear_index_3d(3, 2, 1, 4, 3), 23);
    }

    #[test]
    fn test_index_3d_roundtrip() {
        let (dim_x, dim_y, dim_z) = (4, 3, 2);
        for z in 0..dim_z {
            for y in 0..dim_y {
                for x in 0..dim_x {
                    let flat = linear_index_3d(x, y, z, dim_x, dim_y);
                    assert_eq!(index_3d_from_linear(flat, dim_x, dim_y), (x, y, z));
                }
            }
        }
    }

    // --- timing and budgeting ---

    #[test]
    fn test_dispatch_timer() {
        let mut timer = DispatchTimer::new("test");
        assert_eq!(timer.label, "test");
        timer.record(0.5);
        assert_close(timer.elapsed_secs, 0.5, 1e-10);
    }

    #[test]
    fn test_bandwidth_gb_s() {
        // 1e9 bytes in one second is exactly 1 GB/s.
        assert_close(bandwidth_gb_s(1_000_000_000, 1.0), 1.0, 1e-6);
    }

    #[test]
    fn test_bandwidth_zero_time() {
        assert_close(bandwidth_gb_s(1000, 0.0), 0.0, 1e-10);
    }

    #[test]
    fn test_elements_in_budget() {
        assert_eq!(elements_in_budget(1024, 4), 256);
        assert_eq!(elements_in_budget(1024, 0), 0);
    }

    // --- scans and reductions ---

    #[test]
    fn test_exclusive_scan() {
        let scanned = exclusive_scan(&[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(scanned, vec![0.0, 1.0, 3.0, 6.0]);
    }

    #[test]
    fn test_inclusive_scan() {
        let scanned = inclusive_scan(&[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(scanned, vec![1.0, 3.0, 6.0, 10.0]);
    }

    #[test]
    fn test_reduce_sum() {
        assert_close(reduce_sum(&[1.0, 2.0, 3.0]), 6.0, 1e-10);
    }

    #[test]
    fn test_reduce_max() {
        assert_close(reduce_max(&[1.0, 5.0, 3.0]), 5.0, 1e-10);
    }

    #[test]
    fn test_reduce_min() {
        assert_close(reduce_min(&[1.0, 5.0, 3.0]), 1.0, 1e-10);
    }

    #[test]
    fn test_exclusive_scan_empty() {
        assert!(exclusive_scan(&[]).is_empty());
    }

    #[test]
    fn test_inclusive_scan_single() {
        assert_eq!(inclusive_scan(&[42.0]), vec![42.0]);
    }

    // --- row pitch and 2-D buffers ---

    #[test]
    fn test_row_pitch_aligned() {
        // 128 * 4 = 512 bytes is already a multiple of 256.
        assert_eq!(row_pitch(128, 4, 256), 512);
    }

    #[test]
    fn test_row_pitch_needs_padding() {
        // 100 * 4 = 400 bytes rounds up to 512.
        assert_eq!(row_pitch(100, 4, 256), 512);
    }

    #[test]
    fn test_buffer_size_2d() {
        // Row pitch of 256 bytes times 4 rows.
        assert_eq!(buffer_size_2d(64, 4, 4, 256), 1024);
    }

    // --- powers of two ---

    #[test]
    fn test_next_power_of_two() {
        let cases = [(0, 1), (1, 1), (5, 8), (8, 8), (9, 16)];
        for (input, expected) in cases {
            assert_eq!(next_power_of_two(input), expected);
        }
    }

    #[test]
    fn test_is_power_of_two() {
        for yes in [1, 16] {
            assert!(is_power_of_two(yes));
        }
        for no in [0, 7] {
            assert!(!is_power_of_two(no));
        }
    }

    #[test]
    fn test_log2_pow2() {
        for (input, expected) in [(1, 0), (2, 1), (256, 8)] {
            assert_eq!(log2_pow2(input), expected);
        }
    }

    // --- tiling ---

    #[test]
    fn test_tile_count_2d_exact() {
        assert_eq!(tile_count_2d(64, 64, 16, 16), (4, 4));
    }

    #[test]
    fn test_tile_count_2d_remainder() {
        assert_eq!(tile_count_2d(65, 65, 16, 16), (5, 5));
    }

    #[test]
    fn test_total_tiles_2d() {
        assert_eq!(total_tiles_2d(64, 64, 16, 16), 16);
    }

    #[test]
    fn test_tile_index_to_2d() {
        // Index 5 in a grid 4 tiles wide is column 1 of row 1.
        assert_eq!(tile_index_to_2d(5, 4), (1, 1));
        assert_eq!(tile_index_to_2d(0, 4), (0, 0));
    }

    // --- scalar math ---

    #[test]
    fn test_smoothstep_edges() {
        assert_close(smoothstep(0.0, 1.0, 0.0), 0.0, 1e-12);
        assert_close(smoothstep(0.0, 1.0, 1.0), 1.0, 1e-12);
    }

    #[test]
    fn test_smoothstep_midpoint() {
        assert_close(smoothstep(0.0, 1.0, 0.5), 0.5, 1e-12);
    }

    #[test]
    fn test_smootherstep_edges() {
        assert_close(smootherstep(0.0, 1.0, 0.0), 0.0, 1e-12);
        assert_close(smootherstep(0.0, 1.0, 1.0), 1.0, 1e-12);
    }

    #[test]
    fn test_lerp_inv_lerp_roundtrip() {
        let (a, b, t) = (10.0, 20.0, 0.3);
        let v = lerp(a, b, t);
        assert_close(inv_lerp(a, b, v), t, 1e-12);
    }

    #[test]
    fn test_safe_recip_normal() {
        assert_close(safe_recip(2.0, 1e-9), 0.5, 1e-12);
    }

    #[test]
    fn test_safe_recip_near_zero() {
        assert_close(safe_recip(1e-15, 1e-9), 0.0, 1e-12);
    }

    #[test]
    fn test_safe_sqrt_positive() {
        assert_close(safe_sqrt(9.0), 3.0, 1e-12);
    }

    #[test]
    fn test_safe_sqrt_negative() {
        assert_close(safe_sqrt(-1.0), 0.0, 1e-12);
    }

    #[test]
    fn test_wrap_angle_in_range() {
        let wrapped = wrap_angle(3.0 * PI);
        assert!(wrapped.abs() <= PI + 1e-12, "wrapped = {wrapped}");
    }

    // --- 3-component vector math ---

    #[test]
    fn test_dot3() {
        assert_close(dot3([1.0, 2.0, 3.0], [4.0, 5.0, 6.0]), 32.0, 1e-12);
    }

    #[test]
    fn test_cross3() {
        // x-axis cross y-axis gives the z-axis.
        let k = cross3([1.0, 0.0, 0.0], [0.0, 1.0, 0.0]);
        assert_close(k[0], 0.0, 1e-12);
        assert_close(k[1], 0.0, 1e-12);
        assert_close(k[2], 1.0, 1e-12);
    }

    #[test]
    fn test_length3() {
        assert_close(length3([3.0, 4.0, 0.0]), 5.0, 1e-12);
    }

    #[test]
    fn test_normalize3() {
        let unit = normalize3([0.0, 0.0, 5.0]);
        assert_close(length3(unit), 1.0, 1e-12);
        assert_close(unit[2], 1.0, 1e-12);
    }

    #[test]
    fn test_normalize3_zero_vec() {
        assert_eq!(normalize3([0.0; 3]), [0.0; 3]);
    }

    #[test]
    fn test_reflect3() {
        // A ray heading straight down bounces straight up off a +y normal.
        let incoming = [0.0, -1.0, 0.0];
        let normal = [0.0, 1.0, 0.0];
        let reflected = reflect3(incoming, normal);
        assert_close(reflected[1], 1.0, 1e-12);
    }
}
630pub mod collision_gpu;
631pub mod deformable_gpu;
632pub mod fluid_gpu;
633pub mod fluid_sim_gpu;
634pub mod gpu_cloth;
635pub mod gpu_collision_detection;
636pub mod gpu_collision_ext;
637pub mod gpu_fem_assembly;
638pub mod gpu_fluid;
639pub mod gpu_fluid_euler;
640pub mod gpu_lbm;
641pub mod gpu_md_solver;
642pub mod gpu_mesh_processing;
643pub mod gpu_neural_solver;
644pub mod gpu_nn;
645pub mod gpu_particle_system;
646pub mod gpu_particles;
647pub mod gpu_ray_tracing;
648pub mod gpu_reduction;
649pub mod gpu_rigid;
650pub mod gpu_sdf;
651pub mod gpu_sort;
652pub mod gpu_sparse_solver;
653pub mod gpu_sph_density;
654pub mod gpu_sph_pressure;
655pub mod gpu_sph_solver;
656pub mod gpu_thermal;
657pub mod gpu_voxel;
658pub mod memory;
659pub mod neural_physics;
660pub mod path_tracer;
661pub mod ray_marching;
662pub mod ray_tracing_gpu;
663pub mod raytracing;
664pub mod scheduler;