Skip to main content

oxiphysics_gpu/
lib.rs

1// Copyright 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4//! GPU acceleration backends for the OxiPhysics engine.
5//!
6//! This crate provides a GPU compute abstraction layer that can work with any
7//! backend, with a CPU fallback as the default implementation. No heavy GPU
8//! dependencies (such as wgpu) are required.
9#![allow(missing_docs)]
10#![allow(ambiguous_glob_reexports)]
11#![allow(dead_code)]
12
13mod error;
14pub use error::*;
15
16pub mod bvh;
17pub mod cell_list;
18pub mod compute;
19pub mod compute_pipeline;
20pub mod flux_compute;
21pub mod gpu_bench;
22pub mod grid_reduce;
23pub mod kernels;
24pub mod lbm_gpu;
25pub mod neural_compute;
26pub mod parallel;
27pub mod parallel_sort;
28pub mod particle_system;
29pub mod pipeline;
30pub mod sdf_compute;
31pub mod shader_registry;
32pub mod shaders;
33pub mod sparse_gpu;
34pub mod sph_gpu;
35
36pub use compute::{BufferHandle, ComputeBackend, ComputeKernel, CpuBackend};
37pub use neural_compute::*;
38pub use particle_system::*;
39pub use sparse_gpu::*;
40
41// ── GPU compute utility functions ───────────────────────────────────────────
42
/// Number of work groups needed to cover `total` work items.
///
/// Returns `ceil(total / group_size)`: the smallest group count whose combined
/// capacity (`count * group_size`) is at least `total`. The last group may be
/// only partially occupied, so a kernel dispatched this way typically has to
/// bounds-check its global invocation index.
///
/// Returns `0` when `group_size` is `0` (no valid dispatch exists).
pub fn dispatch_count(total: usize, group_size: usize) -> usize {
    if group_size == 0 {
        return 0;
    }
    total.div_ceil(group_size)
}
52
/// Pad `size` up to the next multiple of `alignment`.
///
/// Sizes that are already aligned come back unchanged. An `alignment` of `0`
/// means "no alignment requirement" and also returns `size` as-is.
pub fn aligned_size(size: usize, alignment: usize) -> usize {
    if alignment == 0 {
        return size;
    }
    // Add just enough bytes to reach the next boundary (none if already on it).
    match size % alignment {
        0 => size,
        remainder => size + (alignment - remainder),
    }
}
62
/// Map 3-D grid coordinates `(x, y, z)` to a flat buffer index.
///
/// `x` varies fastest and `z` slowest. `dim_x` and `dim_y` are the grid extents
/// along the first two axes; the `z` extent is not needed.
#[allow(dead_code)]
pub fn linear_index_3d(x: usize, y: usize, z: usize, dim_x: usize, dim_y: usize) -> usize {
    let slab_offset = z * dim_x * dim_y;
    let row_offset = y * dim_x;
    slab_offset + row_offset + x
}
68
/// Split a flat buffer index back into `(x, y, z)` grid coordinates.
///
/// This is the inverse of `linear_index_3d` for the same `dim_x` / `dim_y`.
///
/// # Panics
///
/// Panics (division by zero) if `dim_x` or `dim_y` is `0`.
#[allow(dead_code)]
pub fn index_3d_from_linear(index: usize, dim_x: usize, dim_y: usize) -> (usize, usize, usize) {
    let plane = dim_x * dim_y;
    let (z, in_plane) = (index / plane, index % plane);
    let (y, x) = (in_plane / dim_x, in_plane % dim_x);
    (x, y, z)
}
78
/// Named slot for the duration of a (possibly CPU-emulated) GPU dispatch.
///
/// The timer does not measure anything by itself: callers time the work and
/// store the result with [`DispatchTimer::record`].
#[derive(Debug, Clone)]
pub struct DispatchTimer {
    /// Human-readable name of the dispatch being timed.
    pub label: String,
    /// Most recently recorded duration in seconds (`0.0` until recorded).
    pub elapsed_secs: f64,
}

impl DispatchTimer {
    /// Build a timer called `label` with no time recorded yet.
    pub fn new(label: impl Into<String>) -> Self {
        let label: String = label.into();
        DispatchTimer {
            label,
            elapsed_secs: 0.0,
        }
    }

    /// Store `elapsed` (seconds), overwriting any earlier value.
    pub fn record(&mut self, elapsed: f64) {
        let DispatchTimer { elapsed_secs, .. } = self;
        *elapsed_secs = elapsed;
    }
}
102
/// Effective memory bandwidth in GB/s, using decimal gigabytes (10⁹ bytes).
///
/// * `bytes_transferred` - Total bytes read plus written.
/// * `elapsed_secs` - Time the transfer took, in seconds.
///
/// Returns `0.0` when `elapsed_secs` is zero or negative.
#[allow(dead_code)]
pub fn bandwidth_gb_s(bytes_transferred: usize, elapsed_secs: f64) -> f64 {
    match elapsed_secs {
        secs if secs <= 0.0 => 0.0,
        secs => bytes_transferred as f64 / secs / 1e9,
    }
}
114
/// How many whole elements of `element_size` bytes fit in `budget_bytes`.
///
/// * `budget_bytes` - Available memory in bytes.
/// * `element_size` - Size of one element in bytes; `0` yields `0`.
#[allow(dead_code)]
pub fn elements_in_budget(budget_bytes: usize, element_size: usize) -> usize {
    // `checked_div` is `None` only for a zero divisor, which maps to 0.
    budget_bytes.checked_div(element_size).unwrap_or(0)
}
126
127// ── GPU buffer utilities ─────────────────────────────────────────────────────
128
129/// Stride (in bytes) of a row in a 2-D buffer, given the element count per row
130/// and the required alignment.
131///
132/// This mirrors `wgpuDeviceGetSupportedSurfaceFormats` style pitch calculation.
133#[allow(dead_code)]
134pub fn row_pitch(elements_per_row: usize, element_size: usize, alignment: usize) -> usize {
135    let raw = elements_per_row * element_size;
136    aligned_size(raw, alignment)
137}
138
139/// Compute the 2-D buffer size (rows × pitch) for a texture-like allocation.
140#[allow(dead_code)]
141pub fn buffer_size_2d(
142    width: usize,
143    height: usize,
144    element_size: usize,
145    row_alignment: usize,
146) -> usize {
147    row_pitch(width, element_size, row_alignment) * height
148}
149
/// Round `value` up to the next power of two.
///
/// Returns `value` unchanged when it is already a power of two, and `1` when
/// `value` is `0`.
///
/// # Panics
///
/// Panics if the result does not fit in `usize`, i.e. when `value` is larger
/// than `1 << (usize::BITS - 1)`.
pub fn next_power_of_two(value: usize) -> usize {
    // A hand-rolled `p <<= 1` loop wraps `p` to 0 on overflow and then never
    // terminates; the std helper reports overflow instead.
    value
        .checked_next_power_of_two()
        .expect("next_power_of_two: result overflows usize")
}
164
/// Returns `true` when `value` has exactly one bit set (1, 2, 4, 8, ...).
///
/// Zero is not a power of two.
pub fn is_power_of_two(value: usize) -> bool {
    value.count_ones() == 1
}
169
170/// Log2 of a power-of-two value.  Panics in debug mode if `v` is not a power
171/// of two.
172pub fn log2_pow2(v: usize) -> u32 {
173    debug_assert!(is_power_of_two(v), "{v} is not a power of two");
174    v.trailing_zeros()
175}
176
177// ── Work-group tiling helpers ─────────────────────────────────────────────────
178
/// Number of `tw × th` tiles needed to cover a `width × height` problem,
/// returned as `(tiles_x, tiles_y)`.
///
/// Each dimension is rounded up so the full problem is covered, which means
/// edge tiles may be partially occupied. A zero tile extent yields `0` tiles
/// along that axis instead of panicking on division by zero. This matches how
/// `dispatch_count` treats a zero group size.
pub fn tile_count_2d(width: usize, height: usize, tw: usize, th: usize) -> (usize, usize) {
    // Same rounding rule as `dispatch_count`, kept local so this helper is
    // self-contained.
    let tiles_along = |extent: usize, tile: usize| -> usize {
        if tile == 0 { 0 } else { extent.div_ceil(tile) }
    };
    (tiles_along(width, tw), tiles_along(height, th))
}
188
189/// Total number of tiles for a 2-D problem.
190pub fn total_tiles_2d(width: usize, height: usize, tw: usize, th: usize) -> usize {
191    let (tx, ty) = tile_count_2d(width, height, tw, th);
192    tx * ty
193}
194
/// Split a flat tile index into `(tile_x, tile_y)` for a grid that is
/// `tiles_x` tiles wide (tiles numbered along x first).
///
/// # Panics
///
/// Panics if `tiles_x` is `0`.
pub fn tile_index_to_2d(flat: usize, tiles_x: usize) -> (usize, usize) {
    let column = flat % tiles_x;
    let row = flat / tiles_x;
    (column, row)
}
200
201// ── Numeric helpers used across GPU kernels ──────────────────────────────────
202
/// Restrict `v` to the interval `[lo, hi]`.
///
/// Unlike [`f64::clamp`], this never panics. A NaN `v` is treated as below the
/// range, so it maps to `lo` when `lo <= hi`. When `lo > hi`, the result is
/// always `hi`.
pub fn clamp_f64(v: f64, lo: f64, hi: f64) -> f64 {
    let at_least_lo = v.max(lo);
    at_least_lo.min(hi)
}
207
208/// Smooth-step function: `3t² - 2t³` with `t = (v - lo) / (hi - lo)`.
209pub fn smoothstep(lo: f64, hi: f64, v: f64) -> f64 {
210    let t = clamp_f64((v - lo) / (hi - lo), 0.0, 1.0);
211    t * t * (3.0 - 2.0 * t)
212}
213
214/// Smoother-step (Ken Perlin's quintic): `6t⁵ − 15t⁴ + 10t³`.
215pub fn smootherstep(lo: f64, hi: f64, v: f64) -> f64 {
216    let t = clamp_f64((v - lo) / (hi - lo), 0.0, 1.0);
217    t * t * t * (t * (t * 6.0 - 15.0) + 10.0)
218}
219
/// Linear interpolation from `a` (at `t = 0`) to `b` (at `t = 1`).
///
/// Computed as `a + t * (b - a)`. `t` is not clamped, so values outside
/// `[0, 1]` extrapolate. With this formula, rounding can make the result at
/// `t = 1` differ slightly from `b`.
pub fn lerp(a: f64, b: f64, t: f64) -> f64 {
    let span = b - a;
    a + t * span
}
224
/// Inverse of `lerp`: returns `t` such that `lerp(a, b, t) == v`.
///
/// Returns `0.0` when `b - a` is exactly zero (a degenerate range, where every
/// `t` maps to the same value). The result is not clamped, so a `v` outside
/// `[a, b]` gives a `t` outside `[0, 1]`.
pub fn inv_lerp(a: f64, b: f64, v: f64) -> f64 {
    let span = b - a;
    // Compare against exact zero rather than an absolute tolerance such as
    // `f64::EPSILON`. A fixed tolerance would misclassify genuinely tiny
    // ranges (e.g. SI-unit charges around 1e-19 C) as degenerate.
    if span == 0.0 {
        return 0.0;
    }
    (v - a) / span
}
232
233// ── FP utilities ─────────────────────────────────────────────────────────────
234
/// Reciprocal guarded against near-zero inputs.
///
/// Returns `1 / x` when `|x| > eps`. Otherwise, and for NaN `x` or NaN `eps`,
/// it returns `0.0` rather than ±∞ or NaN.
pub fn safe_recip(x: f64, eps: f64) -> f64 {
    let invertible = x.abs() > eps;
    if !invertible {
        return 0.0;
    }
    1.0 / x
}
239
/// Square root that treats negative inputs as `0.0` instead of returning NaN.
///
/// Small negative values typically come from round-off. A NaN input also maps
/// to `0.0`.
pub fn safe_sqrt(x: f64) -> f64 {
    let non_negative = x.max(0.0);
    non_negative.sqrt()
}
244
/// Wrap an angle in radians into the half-open interval `(-π, π]`.
///
/// Non-finite inputs produce NaN.
pub fn wrap_angle(theta: f64) -> f64 {
    use std::f64::consts::PI;
    let full_turn = 2.0 * PI;
    // `%` keeps the sign of `theta`, so the remainder lies in (-2π, 2π).
    let rem = theta % full_turn;
    let rem = if rem > PI { rem - full_turn } else { rem };
    if rem <= -PI { rem + full_turn } else { rem }
}
257
258// ── Vector math (3-D, f64) ────────────────────────────────────────────────────
259
/// Dot product `a · b` of two 3-vectors.
pub fn dot3(a: [f64; 3], b: [f64; 3]) -> f64 {
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;
    ax * bx + ay * by + az * bz
}
264
/// Cross product `a × b` of two 3-vectors, so that `x̂ × ŷ = ẑ`.
pub fn cross3(a: [f64; 3], b: [f64; 3]) -> [f64; 3] {
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;
    [ay * bz - az * by, az * bx - ax * bz, ax * by - ay * bx]
}
273
274/// Length of a 3-D vector.
275pub fn length3(v: [f64; 3]) -> f64 {
276    dot3(v, v).sqrt()
277}
278
279/// Normalise a 3-D vector.  Returns the zero vector if the length is < eps.
280pub fn normalize3(v: [f64; 3]) -> [f64; 3] {
281    let len = length3(v);
282    if len < 1e-15 {
283        return [0.0; 3];
284    }
285    [v[0] / len, v[1] / len, v[2] / len]
286}
287
288/// Reflect vector `d` about normal `n` (both assumed normalised).
289pub fn reflect3(d: [f64; 3], n: [f64; 3]) -> [f64; 3] {
290    let dn2 = 2.0 * dot3(d, n);
291    [d[0] - dn2 * n[0], d[1] - dn2 * n[1], d[2] - dn2 * n[2]]
292}
293
// ── Prefix sums (scans) and reductions — sequential CPU implementations ─────
295
/// Exclusive prefix sum: `result[i] = data[0] + … + data[i - 1]`.
///
/// `result[0]` is `0.0` and the output has the same length as `data`. This is
/// a sequential, left-to-right CPU implementation.
pub fn exclusive_scan(data: &[f64]) -> Vec<f64> {
    data.iter()
        .scan(0.0, |running, &x| {
            let before = *running;
            *running += x;
            Some(before)
        })
        .collect()
}
309
/// Inclusive prefix sum: `result[i] = data[0] + … + data[i]`.
///
/// The output has the same length as `data`. This is a sequential,
/// left-to-right CPU implementation.
pub fn inclusive_scan(data: &[f64]) -> Vec<f64> {
    let mut running = 0.0;
    data.iter()
        .map(|&x| {
            running += x;
            running
        })
        .collect()
}
320
/// Sum of all elements, accumulated sequentially left to right.
pub fn reduce_sum(data: &[f64]) -> f64 {
    data.iter().sum()
}
325
/// Largest element, found with a sequential fold.
///
/// NaN elements are ignored (as with [`f64::max`]); an empty slice yields
/// `f64::NEG_INFINITY`.
pub fn reduce_max(data: &[f64]) -> f64 {
    data.iter()
        .fold(f64::NEG_INFINITY, |best, &x| best.max(x))
}
330
/// Smallest element, found with a sequential fold.
///
/// NaN elements are ignored (as with [`f64::min`]); an empty slice yields
/// `f64::INFINITY`.
pub fn reduce_min(data: &[f64]) -> f64 {
    data.iter()
        .fold(f64::INFINITY, |best, &x| best.min(x))
}
335
/// Unit tests for the crate-root GPU utility helpers.
#[cfg(test)]
mod gpu_util_tests {
    use super::*;
    use std::f64::consts::PI;

    #[test]
    fn test_dispatch_count_exact() {
        assert_eq!(dispatch_count(256, 64), 4);
    }

    #[test]
    fn test_dispatch_count_remainder() {
        assert_eq!(dispatch_count(257, 64), 5);
    }

    #[test]
    fn test_dispatch_count_zero_group() {
        assert_eq!(dispatch_count(100, 0), 0);
    }

    #[test]
    fn test_aligned_size_exact() {
        assert_eq!(aligned_size(256, 64), 256);
    }

    #[test]
    fn test_aligned_size_pad() {
        assert_eq!(aligned_size(257, 64), 320);
    }

    #[test]
    fn test_aligned_size_zero_alignment() {
        assert_eq!(aligned_size(100, 0), 100);
    }

    #[test]
    fn test_linear_index_3d() {
        // Grid 4x3x2: one z-slab holds 4 * 3 = 12 cells.
        assert_eq!(linear_index_3d(0, 0, 0, 4, 3), 0);
        assert_eq!(linear_index_3d(3, 2, 1, 4, 3), 12 + 2 * 4 + 3);
    }

    #[test]
    fn test_index_3d_roundtrip() {
        let (dx, dy) = (4, 3);
        for z in 0..2 {
            for y in 0..dy {
                for x in 0..dx {
                    let idx = linear_index_3d(x, y, z, dx, dy);
                    let (rx, ry, rz) = index_3d_from_linear(idx, dx, dy);
                    assert_eq!((rx, ry, rz), (x, y, z));
                }
            }
        }
    }

    #[test]
    fn test_dispatch_timer() {
        let mut timer = DispatchTimer::new("test");
        assert_eq!(timer.label, "test");
        timer.record(0.5);
        assert!((timer.elapsed_secs - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_bandwidth_gb_s() {
        // 1 GB in 1 second = 1 GB/s
        let bw = bandwidth_gb_s(1_000_000_000, 1.0);
        assert!((bw - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_bandwidth_zero_time() {
        assert!((bandwidth_gb_s(1000, 0.0)).abs() < 1e-10);
    }

    #[test]
    fn test_elements_in_budget() {
        assert_eq!(elements_in_budget(1024, 4), 256);
        assert_eq!(elements_in_budget(1024, 0), 0);
    }

    #[test]
    fn test_exclusive_scan() {
        let data = [1.0, 2.0, 3.0, 4.0];
        let result = exclusive_scan(&data);
        assert_eq!(result, vec![0.0, 1.0, 3.0, 6.0]);
    }

    #[test]
    fn test_inclusive_scan() {
        let data = [1.0, 2.0, 3.0, 4.0];
        let result = inclusive_scan(&data);
        assert_eq!(result, vec![1.0, 3.0, 6.0, 10.0]);
    }

    #[test]
    fn test_reduce_sum() {
        assert!((reduce_sum(&[1.0, 2.0, 3.0]) - 6.0).abs() < 1e-10);
    }

    #[test]
    fn test_reduce_max() {
        assert!((reduce_max(&[1.0, 5.0, 3.0]) - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_reduce_min() {
        assert!((reduce_min(&[1.0, 5.0, 3.0]) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_exclusive_scan_empty() {
        let result = exclusive_scan(&[]);
        assert!(result.is_empty());
    }

    #[test]
    fn test_inclusive_scan_single() {
        let result = inclusive_scan(&[42.0]);
        assert_eq!(result, vec![42.0]);
    }

    // ── Buffer utility tests ─────────────────────────────────────────────

    #[test]
    fn test_row_pitch_aligned() {
        // 128 elements × 4 bytes = 512 bytes, already aligned to 256
        assert_eq!(row_pitch(128, 4, 256), 512);
    }

    #[test]
    fn test_row_pitch_needs_padding() {
        // 100 × 4 = 400; aligned to 256 => 512
        assert_eq!(row_pitch(100, 4, 256), 512);
    }

    #[test]
    fn test_buffer_size_2d() {
        // 4 rows × pitch(64 elems × 4 bytes, align 256) = 4 × 256 = 1024
        assert_eq!(buffer_size_2d(64, 4, 4, 256), 1024);
    }

    #[test]
    fn test_next_power_of_two() {
        assert_eq!(next_power_of_two(0), 1);
        assert_eq!(next_power_of_two(1), 1);
        assert_eq!(next_power_of_two(5), 8);
        assert_eq!(next_power_of_two(8), 8);
        assert_eq!(next_power_of_two(9), 16);
    }

    #[test]
    fn test_is_power_of_two() {
        assert!(is_power_of_two(1));
        assert!(is_power_of_two(16));
        assert!(!is_power_of_two(0));
        assert!(!is_power_of_two(7));
    }

    #[test]
    fn test_log2_pow2() {
        assert_eq!(log2_pow2(1), 0);
        assert_eq!(log2_pow2(2), 1);
        assert_eq!(log2_pow2(256), 8);
    }

    #[test]
    fn test_tile_count_2d_exact() {
        let (tx, ty) = tile_count_2d(64, 64, 16, 16);
        assert_eq!(tx, 4);
        assert_eq!(ty, 4);
    }

    #[test]
    fn test_tile_count_2d_remainder() {
        let (tx, ty) = tile_count_2d(65, 65, 16, 16);
        assert_eq!(tx, 5);
        assert_eq!(ty, 5);
    }

    #[test]
    fn test_total_tiles_2d() {
        assert_eq!(total_tiles_2d(64, 64, 16, 16), 16);
    }

    #[test]
    fn test_tile_index_to_2d() {
        // tiles_x = 4; flat=5 => (1, 1)
        assert_eq!(tile_index_to_2d(5, 4), (1, 1));
        assert_eq!(tile_index_to_2d(0, 4), (0, 0));
    }

    // ── Numeric helpers tests ────────────────────────────────────────────

    #[test]
    fn test_smoothstep_edges() {
        assert!((smoothstep(0.0, 1.0, 0.0) - 0.0).abs() < 1e-12);
        assert!((smoothstep(0.0, 1.0, 1.0) - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_smoothstep_midpoint() {
        // at t=0.5: 3*(0.25) - 2*(0.125) = 0.75 - 0.25 = 0.5
        assert!((smoothstep(0.0, 1.0, 0.5) - 0.5).abs() < 1e-12);
    }

    #[test]
    fn test_smootherstep_edges() {
        assert!((smootherstep(0.0, 1.0, 0.0)).abs() < 1e-12);
        assert!((smootherstep(0.0, 1.0, 1.0) - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_lerp_inv_lerp_roundtrip() {
        let a = 10.0;
        let b = 20.0;
        let t = 0.3;
        let v = lerp(a, b, t);
        assert!((inv_lerp(a, b, v) - t).abs() < 1e-12);
    }

    #[test]
    fn test_safe_recip_normal() {
        assert!((safe_recip(2.0, 1e-9) - 0.5).abs() < 1e-12);
    }

    #[test]
    fn test_safe_recip_near_zero() {
        assert!((safe_recip(1e-15, 1e-9)).abs() < 1e-12);
    }

    #[test]
    fn test_safe_sqrt_positive() {
        assert!((safe_sqrt(9.0) - 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_safe_sqrt_negative() {
        assert!((safe_sqrt(-1.0)).abs() < 1e-12);
    }

    #[test]
    fn test_wrap_angle_in_range() {
        // 3π is an odd multiple of π, so it lands on the ±π boundary. Rounding
        // of `3.0 * PI` decides which side, so pin the half-open range and the
        // magnitude rather than the sign.
        let wrapped = wrap_angle(3.0 * PI);
        assert!(wrapped > -PI && wrapped <= PI, "wrapped = {wrapped}");
        assert!((wrapped.abs() - PI).abs() < 1e-12, "wrapped = {wrapped}");
    }

    #[test]
    fn test_wrap_angle_lower_bound_maps_to_pi() {
        // The interval is (-π, π], so exactly -π must wrap to +π.
        assert_eq!(wrap_angle(-PI), PI);
        // Angles already inside the interval are returned unchanged.
        assert_eq!(wrap_angle(1.0), 1.0);
    }

    // ── Vector math tests ────────────────────────────────────────────────

    #[test]
    fn test_dot3() {
        let a = [1.0, 2.0, 3.0];
        let b = [4.0, 5.0, 6.0];
        assert!((dot3(a, b) - 32.0).abs() < 1e-12);
    }

    #[test]
    fn test_cross3() {
        let i = [1.0, 0.0, 0.0];
        let j = [0.0, 1.0, 0.0];
        let k = cross3(i, j);
        assert!((k[0]).abs() < 1e-12);
        assert!((k[1]).abs() < 1e-12);
        assert!((k[2] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_length3() {
        let v = [3.0, 4.0, 0.0];
        assert!((length3(v) - 5.0).abs() < 1e-12);
    }

    #[test]
    fn test_normalize3() {
        let v = [0.0, 0.0, 5.0];
        let n = normalize3(v);
        assert!((length3(n) - 1.0).abs() < 1e-12);
        assert!((n[2] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_normalize3_zero_vec() {
        let n = normalize3([0.0; 3]);
        assert_eq!(n, [0.0; 3]);
    }

    #[test]
    fn test_reflect3() {
        // A ray pointing straight down hitting a floor whose normal points up
        // bounces straight back up: only the y component changes sign.
        let d = [0.0, -1.0, 0.0]; // pointing down
        let n = [0.0, 1.0, 0.0]; // surface normal up
        let r = reflect3(d, n);
        // r = d - 2*(d·n)*n = [0,-1,0] - 2*(-1)*[0,1,0] = [0,1,0]
        assert!(r[0].abs() < 1e-12);
        assert!((r[1] - 1.0).abs() < 1e-12);
        assert!(r[2].abs() < 1e-12);
    }
}
633pub mod collision_gpu;
634pub mod deformable_gpu;
635pub mod fluid_gpu;
636pub mod fluid_sim_gpu;
637pub mod gpu_cloth;
638pub mod gpu_collision_detection;
639pub mod gpu_collision_ext;
640pub mod gpu_fem_assembly;
641pub mod gpu_fluid;
642pub mod gpu_fluid_euler;
643pub mod gpu_lbm;
644pub mod gpu_md_solver;
645pub mod gpu_mesh_processing;
646pub mod gpu_neural_solver;
647pub mod gpu_nn;
648pub mod gpu_particle_system;
649pub mod gpu_particles;
650pub mod gpu_ray_tracing;
651pub mod gpu_reduction;
652pub mod gpu_rigid;
653pub mod gpu_sdf;
654pub mod gpu_sort;
655pub mod gpu_sparse_solver;
656pub mod gpu_sph_density;
657pub mod gpu_sph_pressure;
658pub mod gpu_sph_solver;
659pub mod gpu_thermal;
660pub mod gpu_voxel;
661pub mod memory;
662pub mod neural_physics;
663pub mod path_tracer;
664pub mod ray_marching;
665pub mod ray_tracing_gpu;
666pub mod raytracing;
667pub mod scheduler;