Skip to main content

oxiphysics_gpu/
lib.rs

1// Copyright 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4//! GPU acceleration backends for the OxiPhysics engine.
5//!
6//! This crate provides a GPU compute abstraction layer that can work with any
7//! backend, with a CPU fallback as the default implementation. No heavy GPU
8//! dependencies (such as wgpu) are required.
9#![allow(missing_docs)]
10#![allow(ambiguous_glob_reexports)]
11#![allow(dead_code)]
12
13mod error;
14pub use error::*;
15
16pub mod bvh;
17pub mod cell_list;
18pub mod compute;
19pub mod compute_pipeline;
20pub mod flux_compute;
21pub mod grid_reduce;
22pub mod kernels;
23pub mod neural_compute;
24pub mod parallel;
25pub mod parallel_sort;
26pub mod particle_system;
27pub mod pipeline;
28pub mod sdf_compute;
29pub mod shader_registry;
30pub mod shaders;
31pub mod sparse_gpu;
32
33pub use compute::{BufferHandle, ComputeBackend, ComputeKernel, CpuBackend};
34pub use neural_compute::*;
35pub use particle_system::*;
36pub use sparse_gpu::*;
37
38// ── GPU compute utility functions ───────────────────────────────────────────
39
/// Number of work groups needed to cover `total` work items.
///
/// Computes `ceil(total / group_size)`: the smallest group count such that
/// `count * group_size >= total`. When `total` is not a multiple of
/// `group_size`, the last group contains fewer than `group_size` real items.
///
/// Returns 0 when `group_size` is 0, since no group can hold any work.
pub fn dispatch_count(total: usize, group_size: usize) -> usize {
    if group_size == 0 {
        return 0;
    }
    total.div_ceil(group_size)
}
49
/// Pad `size` up to the next multiple of `alignment`.
///
/// The result is the smallest multiple of `alignment` that is `>= size`.
/// `alignment` does not have to be a power of two. An `alignment` of 0 means
/// "no alignment requirement", and `size` is returned as-is.
pub fn aligned_size(size: usize, alignment: usize) -> usize {
    match alignment {
        0 => size,
        align => size.div_ceil(align) * align,
    }
}
59
/// Flatten 3-D coordinates `(x, y, z)` into a linear index.
///
/// The layout is x-fastest: consecutive `x` values are adjacent, a step in `y`
/// skips `dim_x` cells, and a step in `z` skips a whole `dim_x * dim_y` plane.
pub fn linear_index_3d(x: usize, y: usize, z: usize, dim_x: usize, dim_y: usize) -> usize {
    let plane_offset = z * dim_x * dim_y;
    let row_offset = y * dim_x;
    plane_offset + row_offset + x
}
65
/// Recover `(x, y, z)` from a linear index in an x-fastest layout.
///
/// This is the inverse of `linear_index_3d` for the same `dim_x` / `dim_y`.
///
/// # Panics
///
/// Panics (division by zero) if `dim_x * dim_y` is 0.
pub fn index_3d_from_linear(index: usize, dim_x: usize, dim_y: usize) -> (usize, usize, usize) {
    let plane = dim_x * dim_y;
    let (z, within_plane) = (index / plane, index % plane);
    let (y, x) = (within_plane / dim_x, within_plane % dim_x);
    (x, y, z)
}
75
/// A labelled record of how long a GPU-like dispatch took.
///
/// The type does not measure time on its own. The caller times the work and
/// stores the measurement with [`DispatchTimer::record`].
#[derive(Debug, Clone)]
pub struct DispatchTimer {
    /// Human-readable name of the dispatch being profiled.
    pub label: String,
    /// Last recorded duration in seconds (0.0 until `record` is called).
    pub elapsed_secs: f64,
}

impl DispatchTimer {
    /// Build a timer tagged with `label`, with no elapsed time recorded yet.
    pub fn new(label: impl Into<String>) -> Self {
        let label: String = label.into();
        Self {
            label,
            elapsed_secs: 0.0,
        }
    }

    /// Store `elapsed` (in seconds), replacing any earlier measurement.
    pub fn record(&mut self, elapsed: f64) {
        let slot = &mut self.elapsed_secs;
        *slot = elapsed;
    }
}
99
/// Estimate achieved memory bandwidth in decimal gigabytes per second
/// (1 GB = 1e9 bytes).
///
/// * `bytes_transferred` - Total bytes read plus bytes written.
/// * `elapsed_secs` - Duration of the transfer, in seconds.
///
/// Returns `0.0` when `elapsed_secs` is zero or negative.
pub fn bandwidth_gb_s(bytes_transferred: usize, elapsed_secs: f64) -> f64 {
    if elapsed_secs <= 0.0 {
        0.0
    } else {
        let bytes = bytes_transferred as f64;
        bytes / elapsed_secs / 1e9
    }
}
111
/// How many whole elements fit in a memory budget.
///
/// * `budget_bytes` - Available memory in bytes.
/// * `element_size` - Size of one element in bytes.
///
/// Any leftover bytes that cannot hold a full element are ignored. A zero
/// `element_size` yields 0 rather than dividing by zero.
pub fn elements_in_budget(budget_bytes: usize, element_size: usize) -> usize {
    budget_bytes.checked_div(element_size).unwrap_or(0)
}
123
124// ── GPU buffer utilities ─────────────────────────────────────────────────────
125
126/// Stride (in bytes) of a row in a 2-D buffer, given the element count per row
127/// and the required alignment.
128///
129/// This mirrors `wgpuDeviceGetSupportedSurfaceFormats` style pitch calculation.
130#[allow(dead_code)]
131pub fn row_pitch(elements_per_row: usize, element_size: usize, alignment: usize) -> usize {
132    let raw = elements_per_row * element_size;
133    aligned_size(raw, alignment)
134}
135
136/// Compute the 2-D buffer size (rows × pitch) for a texture-like allocation.
137#[allow(dead_code)]
138pub fn buffer_size_2d(
139    width: usize,
140    height: usize,
141    element_size: usize,
142    row_alignment: usize,
143) -> usize {
144    row_pitch(width, element_size, row_alignment) * height
145}
146
/// Round `value` up to the next power of two.
///
/// Returns `value` unchanged when it is already a power of two, and 1 when
/// `value` is 0.
///
/// # Panics
///
/// Panics if the result does not fit in `usize`, i.e. when `value` is greater
/// than the largest power of two representable in `usize`.
pub fn next_power_of_two(value: usize) -> usize {
    // A manual `p <<= 1` loop would shift `p` to 0 on overflow and never
    // terminate. `checked_*` turns that case into an explicit panic.
    value
        .checked_next_power_of_two()
        .expect("next_power_of_two: result overflows usize")
}
161
/// Whether `value` is an exact power of two (1, 2, 4, ...).
///
/// Zero is not a power of two and yields `false`.
pub fn is_power_of_two(value: usize) -> bool {
    value.is_power_of_two()
}
166
/// Base-2 logarithm of a power-of-two value.
///
/// Computed as the number of trailing zero bits. For inputs that are not
/// powers of two, release builds return that trailing-zero count instead.
///
/// # Panics
///
/// In debug builds, panics if `v` is not a power of two.
pub fn log2_pow2(v: usize) -> u32 {
    debug_assert!(v.is_power_of_two(), "{v} is not a power of two");
    v.trailing_zeros()
}
173
174// ── Work-group tiling helpers ─────────────────────────────────────────────────
175
/// Number of `(tw × th)` tiles needed along each axis to cover a
/// `(width × height)` problem, returned as `(tiles_x, tiles_y)`.
///
/// Each axis is rounded up, so partially covered edge tiles are counted.
///
/// # Panics
///
/// Panics (division by zero) if `tw` or `th` is 0.
pub fn tile_count_2d(width: usize, height: usize, tw: usize, th: usize) -> (usize, usize) {
    (width.div_ceil(tw), height.div_ceil(th))
}
185
186/// Total number of tiles for a 2-D problem.
187pub fn total_tiles_2d(width: usize, height: usize, tw: usize, th: usize) -> usize {
188    let (tx, ty) = tile_count_2d(width, height, tw, th);
189    tx * ty
190}
191
/// Split a row-major flat tile index into `(tile_x, tile_y)` for a grid that
/// is `tiles_x` tiles wide.
///
/// # Panics
///
/// Panics (division by zero) if `tiles_x` is 0.
pub fn tile_index_to_2d(flat: usize, tiles_x: usize) -> (usize, usize) {
    let column = flat % tiles_x;
    let row = flat / tiles_x;
    (column, row)
}
197
198// ── Numeric helpers used across GPU kernels ──────────────────────────────────
199
/// Clamp `v` into `[lo, hi]` by applying `max(lo)` and then `min(hi)`.
///
/// Unlike `f64::clamp`, this never panics:
/// - If `lo > hi`, the result is `hi`.
/// - A NaN `v` is replaced by `lo` before the upper bound is applied.
pub fn clamp_f64(v: f64, lo: f64, hi: f64) -> f64 {
    let floored = f64::max(v, lo);
    f64::min(floored, hi)
}
204
/// Hermite smooth-step: `3t² - 2t³`, where `t = (v - lo) / (hi - lo)` is
/// clamped to `[0, 1]`.
///
/// When `lo == hi`, the ratio is infinite or NaN and the clamp turns the
/// result into a hard step: 0.0 for `v <= lo`, 1.0 for `v > lo`.
pub fn smoothstep(lo: f64, hi: f64, v: f64) -> f64 {
    let ratio = (v - lo) / (hi - lo);
    // Same max-then-min clamp as `clamp_f64(ratio, 0.0, 1.0)`.
    let t = ratio.max(0.0).min(1.0);
    t * t * (3.0 - 2.0 * t)
}
210
/// Ken Perlin's quintic smoother-step: `6t⁵ − 15t⁴ + 10t³`, where
/// `t = (v - lo) / (hi - lo)` is clamped to `[0, 1]`.
///
/// The polynomial is evaluated in the factored form `t³ (t (6t − 15) + 10)`.
pub fn smootherstep(lo: f64, hi: f64, v: f64) -> f64 {
    let ratio = (v - lo) / (hi - lo);
    // Same max-then-min clamp as `clamp_f64(ratio, 0.0, 1.0)`.
    let t = ratio.max(0.0).min(1.0);
    t * t * t * (t * (t * 6.0 - 15.0) + 10.0)
}
216
/// Linear interpolation between `a` and `b`: `a + t * (b - a)`.
///
/// `t` is not clamped, so values outside `[0, 1]` extrapolate.
pub fn lerp(a: f64, b: f64, t: f64) -> f64 {
    let span = b - a;
    a + t * span
}
221
/// Inverse of `lerp`: the `t` for which `lerp(a, b, t) == v`,
/// i.e. `(v - a) / (b - a)`.
///
/// Returns `0.0` whenever `|b - a| < f64::EPSILON`. This is an absolute
/// threshold, so it also catches very small but non-zero ranges, not only
/// `a == b`.
pub fn inv_lerp(a: f64, b: f64, v: f64) -> f64 {
    let span = b - a;
    if span.abs() < f64::EPSILON {
        0.0
    } else {
        (v - a) / span
    }
}
229
230// ── FP utilities ─────────────────────────────────────────────────────────────
231
/// Reciprocal that refuses to blow up near zero.
///
/// Returns `1 / x` when `|x| > eps`. Otherwise returns `0.0`, and that
/// includes a NaN `x`, because the comparison is then false.
pub fn safe_recip(x: f64, eps: f64) -> f64 {
    let magnitude = x.abs();
    if magnitude > eps {
        return 1.0 / x;
    }
    0.0
}
236
/// Square root that treats negative (and NaN) inputs as 0.
///
/// The input is first raised to at least 0.0 with `f64::max`, which also maps
/// NaN to 0.0, and then the root is taken.
pub fn safe_sqrt(x: f64) -> f64 {
    let non_negative = f64::max(x, 0.0);
    f64::sqrt(non_negative)
}
241
/// Wrap an angle in radians into the half-open interval `(-π, π]`.
///
/// The angle is first reduced with `%` (remainder, sign of `theta`). A single
/// ±2π correction then moves it into range; `-π` itself maps to `+π`.
pub fn wrap_angle(theta: f64) -> f64 {
    use std::f64::consts::{PI, TAU};
    let reduced = theta % TAU;
    if reduced > PI {
        reduced - TAU
    } else if reduced <= -PI {
        reduced + TAU
    } else {
        reduced
    }
}
254
255// ── Vector math (3-D, f64) ────────────────────────────────────────────────────
256
/// Dot product of two 3-element vectors: `a·b = ax*bx + ay*by + az*bz`.
pub fn dot3(a: [f64; 3], b: [f64; 3]) -> f64 {
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;
    ax * bx + ay * by + az * bz
}
261
/// Cross product `a × b` of two 3-element vectors (right-handed).
pub fn cross3(a: [f64; 3], b: [f64; 3]) -> [f64; 3] {
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;
    let cx = ay * bz - az * by;
    let cy = az * bx - ax * bz;
    let cz = ax * by - ay * bx;
    [cx, cy, cz]
}
270
/// Euclidean length of a 3-D vector: `sqrt(x² + y² + z²)`.
pub fn length3(v: [f64; 3]) -> f64 {
    let [x, y, z] = v;
    let squared = x * x + y * y + z * z;
    squared.sqrt()
}
275
/// Scale a 3-D vector to unit length.
///
/// Vectors whose Euclidean length is below `1e-15` have no meaningful
/// direction, so the zero vector is returned for them instead.
pub fn normalize3(v: [f64; 3]) -> [f64; 3] {
    let [x, y, z] = v;
    let len = (x * x + y * y + z * z).sqrt();
    if len < 1e-15 {
        [0.0; 3]
    } else {
        [x / len, y / len, z / len]
    }
}
284
/// Reflect direction `d` about a surface with normal `n`:
/// `r = d - 2 (d·n) n`.
///
/// `n` is assumed to be unit length; it is not normalised here.
pub fn reflect3(d: [f64; 3], n: [f64; 3]) -> [f64; 3] {
    let [dx, dy, dz] = d;
    let [nx, ny, nz] = n;
    let twice_dn = 2.0 * (dx * nx + dy * ny + dz * nz);
    [dx - twice_dn * nx, dy - twice_dn * ny, dz - twice_dn * nz]
}
290
291// ── Parallel prefix sum (scan) ───────────────────────────────────────────────
292
/// Exclusive prefix sum (scan) of `data`: `result[i] = sum(data[0..i])`.
///
/// This is a sequential CPU reference implementation. The output has the same
/// length as the input and always starts with `0.0` when non-empty.
pub fn exclusive_scan(data: &[f64]) -> Vec<f64> {
    let mut out = Vec::with_capacity(data.len());
    out.extend(data.iter().scan(0.0_f64, |running, &value| {
        let before = *running;
        *running += value;
        Some(before)
    }));
    out
}
306
/// Inclusive prefix sum (scan) of `data`: `result[i] = sum(data[0..=i])`.
///
/// This is a sequential CPU reference implementation. The output has the same
/// length as the input.
pub fn inclusive_scan(data: &[f64]) -> Vec<f64> {
    let mut out = Vec::with_capacity(data.len());
    out.extend(data.iter().scan(0.0_f64, |running, &value| {
        *running += value;
        Some(*running)
    }));
    out
}
317
/// Sum of all elements (sequential reduction; an empty slice sums to zero).
pub fn reduce_sum(data: &[f64]) -> f64 {
    data.iter().sum::<f64>()
}
322
/// Maximum element (sequential reduction).
///
/// Returns `f64::NEG_INFINITY` for an empty slice. NaN elements are skipped,
/// following `f64::max`'s NaN handling.
pub fn reduce_max(data: &[f64]) -> f64 {
    data.iter()
        .fold(f64::NEG_INFINITY, |best, &candidate| best.max(candidate))
}
327
/// Minimum element (sequential reduction).
///
/// Returns `f64::INFINITY` for an empty slice. NaN elements are skipped,
/// following `f64::min`'s NaN handling.
pub fn reduce_min(data: &[f64]) -> f64 {
    data.iter()
        .fold(f64::INFINITY, |best, &candidate| best.min(candidate))
}
332
#[cfg(test)]
mod gpu_util_tests {
    use super::*;
    use std::f64::consts::PI;

    #[test]
    fn test_dispatch_count_exact() {
        assert_eq!(dispatch_count(256, 64), 4);
    }

    #[test]
    fn test_dispatch_count_remainder() {
        assert_eq!(dispatch_count(257, 64), 5);
    }

    #[test]
    fn test_dispatch_count_zero_group() {
        assert_eq!(dispatch_count(100, 0), 0);
    }

    #[test]
    fn test_aligned_size_exact() {
        assert_eq!(aligned_size(256, 64), 256);
    }

    #[test]
    fn test_aligned_size_pad() {
        assert_eq!(aligned_size(257, 64), 320);
    }

    #[test]
    fn test_aligned_size_zero_alignment() {
        assert_eq!(aligned_size(100, 0), 100);
    }

    #[test]
    fn test_linear_index_3d() {
        // Grid 4x3x2: one z-plane holds 4 * 3 = 12 cells.
        assert_eq!(linear_index_3d(0, 0, 0, 4, 3), 0);
        assert_eq!(linear_index_3d(3, 2, 1, 4, 3), 12 + 2 * 4 + 3);
    }

    #[test]
    fn test_index_3d_roundtrip() {
        let (dx, dy) = (4, 3);
        for z in 0..2 {
            for y in 0..dy {
                for x in 0..dx {
                    let idx = linear_index_3d(x, y, z, dx, dy);
                    let (rx, ry, rz) = index_3d_from_linear(idx, dx, dy);
                    assert_eq!((rx, ry, rz), (x, y, z));
                }
            }
        }
    }

    #[test]
    fn test_dispatch_timer() {
        let mut timer = DispatchTimer::new("test");
        assert_eq!(timer.label, "test");
        timer.record(0.5);
        assert!((timer.elapsed_secs - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_bandwidth_gb_s() {
        // 1 GB in 1 second = 1 GB/s
        let bw = bandwidth_gb_s(1_000_000_000, 1.0);
        assert!((bw - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_bandwidth_zero_time() {
        assert!((bandwidth_gb_s(1000, 0.0)).abs() < 1e-10);
    }

    #[test]
    fn test_elements_in_budget() {
        assert_eq!(elements_in_budget(1024, 4), 256);
        assert_eq!(elements_in_budget(1024, 0), 0);
    }

    #[test]
    fn test_exclusive_scan() {
        let data = [1.0, 2.0, 3.0, 4.0];
        let result = exclusive_scan(&data);
        assert_eq!(result, vec![0.0, 1.0, 3.0, 6.0]);
    }

    #[test]
    fn test_inclusive_scan() {
        let data = [1.0, 2.0, 3.0, 4.0];
        let result = inclusive_scan(&data);
        assert_eq!(result, vec![1.0, 3.0, 6.0, 10.0]);
    }

    #[test]
    fn test_reduce_sum() {
        assert!((reduce_sum(&[1.0, 2.0, 3.0]) - 6.0).abs() < 1e-10);
    }

    #[test]
    fn test_reduce_max() {
        assert!((reduce_max(&[1.0, 5.0, 3.0]) - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_reduce_min() {
        assert!((reduce_min(&[1.0, 5.0, 3.0]) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_exclusive_scan_empty() {
        let result = exclusive_scan(&[]);
        assert!(result.is_empty());
    }

    #[test]
    fn test_inclusive_scan_single() {
        let result = inclusive_scan(&[42.0]);
        assert_eq!(result, vec![42.0]);
    }

    // ── Buffer utility tests ─────────────────────────────────────────────

    #[test]
    fn test_row_pitch_aligned() {
        // 128 elements × 4 bytes = 512 bytes, already aligned to 256
        assert_eq!(row_pitch(128, 4, 256), 512);
    }

    #[test]
    fn test_row_pitch_needs_padding() {
        // 100 × 4 = 400; aligned to 256 => 512
        assert_eq!(row_pitch(100, 4, 256), 512);
    }

    #[test]
    fn test_buffer_size_2d() {
        // 4 rows × pitch(64 elems × 4 bytes, align 256) = 4 × 256 = 1024
        assert_eq!(buffer_size_2d(64, 4, 4, 256), 1024);
    }

    #[test]
    fn test_next_power_of_two() {
        assert_eq!(next_power_of_two(0), 1);
        assert_eq!(next_power_of_two(1), 1);
        assert_eq!(next_power_of_two(5), 8);
        assert_eq!(next_power_of_two(8), 8);
        assert_eq!(next_power_of_two(9), 16);
    }

    #[test]
    fn test_next_power_of_two_largest() {
        // The largest representable power of two must map to itself.
        let top = usize::MAX / 2 + 1;
        assert_eq!(next_power_of_two(top), top);
    }

    #[test]
    fn test_is_power_of_two() {
        assert!(is_power_of_two(1));
        assert!(is_power_of_two(16));
        assert!(!is_power_of_two(0));
        assert!(!is_power_of_two(7));
    }

    #[test]
    fn test_log2_pow2() {
        assert_eq!(log2_pow2(1), 0);
        assert_eq!(log2_pow2(2), 1);
        assert_eq!(log2_pow2(256), 8);
    }

    #[test]
    fn test_tile_count_2d_exact() {
        let (tx, ty) = tile_count_2d(64, 64, 16, 16);
        assert_eq!(tx, 4);
        assert_eq!(ty, 4);
    }

    #[test]
    fn test_tile_count_2d_remainder() {
        let (tx, ty) = tile_count_2d(65, 65, 16, 16);
        assert_eq!(tx, 5);
        assert_eq!(ty, 5);
    }

    #[test]
    fn test_total_tiles_2d() {
        assert_eq!(total_tiles_2d(64, 64, 16, 16), 16);
    }

    #[test]
    fn test_tile_index_to_2d() {
        // tiles_x = 4; flat=5 => (1, 1)
        assert_eq!(tile_index_to_2d(5, 4), (1, 1));
        assert_eq!(tile_index_to_2d(0, 4), (0, 0));
    }

    // ── Numeric helpers tests ────────────────────────────────────────────

    #[test]
    fn test_smoothstep_edges() {
        assert!((smoothstep(0.0, 1.0, 0.0) - 0.0).abs() < 1e-12);
        assert!((smoothstep(0.0, 1.0, 1.0) - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_smoothstep_midpoint() {
        // at t=0.5: 3*(0.25) - 2*(0.125) = 0.75 - 0.25 = 0.5
        assert!((smoothstep(0.0, 1.0, 0.5) - 0.5).abs() < 1e-12);
    }

    #[test]
    fn test_smootherstep_edges() {
        assert!((smootherstep(0.0, 1.0, 0.0)).abs() < 1e-12);
        assert!((smootherstep(0.0, 1.0, 1.0) - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_lerp_inv_lerp_roundtrip() {
        let a = 10.0;
        let b = 20.0;
        let t = 0.3;
        let v = lerp(a, b, t);
        assert!((inv_lerp(a, b, v) - t).abs() < 1e-12);
    }

    #[test]
    fn test_safe_recip_normal() {
        assert!((safe_recip(2.0, 1e-9) - 0.5).abs() < 1e-12);
    }

    #[test]
    fn test_safe_recip_near_zero() {
        assert!((safe_recip(1e-15, 1e-9)).abs() < 1e-12);
    }

    #[test]
    fn test_safe_sqrt_positive() {
        assert!((safe_sqrt(9.0) - 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_safe_sqrt_negative() {
        assert!((safe_sqrt(-1.0)).abs() < 1e-12);
    }

    #[test]
    fn test_wrap_angle_in_range() {
        // The documented range is the half-open interval (-π, π].
        let wrapped = wrap_angle(3.0 * PI);
        assert!(wrapped > -PI && wrapped <= PI, "wrapped = {wrapped}");
    }

    #[test]
    fn test_wrap_angle_boundaries() {
        // -π lies outside (-π, π] and must map to +π; +π and 0 are kept.
        assert_eq!(wrap_angle(-PI), PI);
        assert_eq!(wrap_angle(PI), PI);
        assert_eq!(wrap_angle(0.0), 0.0);
    }

    // ── Vector math tests ────────────────────────────────────────────────

    #[test]
    fn test_dot3() {
        let a = [1.0, 2.0, 3.0];
        let b = [4.0, 5.0, 6.0];
        assert!((dot3(a, b) - 32.0).abs() < 1e-12);
    }

    #[test]
    fn test_cross3() {
        let i = [1.0, 0.0, 0.0];
        let j = [0.0, 1.0, 0.0];
        let k = cross3(i, j);
        assert!((k[0]).abs() < 1e-12);
        assert!((k[1]).abs() < 1e-12);
        assert!((k[2] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_length3() {
        let v = [3.0, 4.0, 0.0];
        assert!((length3(v) - 5.0).abs() < 1e-12);
    }

    #[test]
    fn test_normalize3() {
        let v = [0.0, 0.0, 5.0];
        let n = normalize3(v);
        assert!((length3(n) - 1.0).abs() < 1e-12);
        assert!((n[2] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_normalize3_zero_vec() {
        let n = normalize3([0.0; 3]);
        assert_eq!(n, [0.0; 3]);
    }

    #[test]
    fn test_reflect3() {
        // A ray pointing straight down hits a surface whose normal points up:
        // r = d - 2*(d·n)*n = [0,-1,0] - 2*(-1)*[0,1,0] = [0,1,0]
        let d = [0.0, -1.0, 0.0];
        let n = [0.0, 1.0, 0.0];
        let r = reflect3(d, n);
        assert!(r[0].abs() < 1e-12);
        assert!((r[1] - 1.0).abs() < 1e-12);
        assert!(r[2].abs() < 1e-12);
    }

    #[test]
    fn test_reflect3_oblique() {
        // Only the component along the normal flips: (1,-1,0) -> (1,1,0).
        let r = reflect3([1.0, -1.0, 0.0], [0.0, 1.0, 0.0]);
        assert!((r[0] - 1.0).abs() < 1e-12);
        assert!((r[1] - 1.0).abs() < 1e-12);
        assert!(r[2].abs() < 1e-12);
    }
}
630pub mod collision_gpu;
631pub mod deformable_gpu;
632pub mod fluid_gpu;
633pub mod fluid_sim_gpu;
634pub mod gpu_cloth;
635pub mod gpu_collision_detection;
636pub mod gpu_collision_ext;
637pub mod gpu_fem_assembly;
638pub mod gpu_fluid;
639pub mod gpu_fluid_euler;
640pub mod gpu_lbm;
641pub mod gpu_md_solver;
642pub mod gpu_mesh_processing;
643pub mod gpu_neural_solver;
644pub mod gpu_nn;
645pub mod gpu_particle_system;
646pub mod gpu_particles;
647pub mod gpu_ray_tracing;
648pub mod gpu_reduction;
649pub mod gpu_rigid;
650pub mod gpu_sdf;
651pub mod gpu_sort;
652pub mod gpu_sparse_solver;
653pub mod gpu_sph_density;
654pub mod gpu_sph_pressure;
655pub mod gpu_sph_solver;
656pub mod gpu_thermal;
657pub mod gpu_voxel;
658pub mod memory;
659pub mod neural_physics;
660pub mod path_tracer;
661pub mod ray_marching;
662pub mod ray_tracing_gpu;
663pub mod raytracing;
664pub mod scheduler;