Skip to main content

oxiphysics_gpu/kernels/
mod.rs

1// Copyright 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4//! GPU/CPU compute kernels for physics simulation.
5//!
6//! This module groups all low-level compute kernels.  Each sub-module exposes
7//! a CPU-mock implementation that mirrors a GPU kernel in its data layout and
8//! dispatch model, but executes on the CPU using Rayon for parallelism.
9
10#![allow(dead_code)]
11
12pub mod broadphase;
13pub mod md_force;
14pub mod rigid;
15pub mod sph;
16
17// ── Kernel registry helpers ──────────────────────────────────────────────────
18
/// Identifier for a built-in kernel family.
///
/// Its [`Display`](std::fmt::Display) output is the lowercase snake_case name
/// returned by [`KernelFamily::as_str`] (e.g. `"md_force"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KernelFamily {
    /// Smoothed Particle Hydrodynamics kernels.
    Sph,
    /// Rigid-body integration and collision kernels.
    Rigid,
    /// Broad-phase AABB/BVH traversal kernels.
    Broadphase,
    /// Molecular dynamics force kernels.
    MdForce,
    /// Signed distance field evaluation kernels.
    SdfCompute,
    /// Neural-network inference kernels.
    NeuralCompute,
    /// Grid-reduce / scan kernels.
    GridReduce,
}

impl KernelFamily {
    /// Static snake_case name of this family; identical to its `Display` output.
    ///
    /// Unlike `to_string()`, this does not allocate.
    pub const fn as_str(self) -> &'static str {
        match self {
            KernelFamily::Sph => "sph",
            KernelFamily::Rigid => "rigid",
            KernelFamily::Broadphase => "broadphase",
            KernelFamily::MdForce => "md_force",
            KernelFamily::SdfCompute => "sdf_compute",
            KernelFamily::NeuralCompute => "neural_compute",
            KernelFamily::GridReduce => "grid_reduce",
        }
    }
}

impl std::fmt::Display for KernelFamily {
    /// Writes the family name, honouring any width/fill/alignment flags
    /// supplied by the caller (e.g. `format!("{:>8}", KernelFamily::Sph)`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` applies the formatter's width/fill/alignment to the name;
        // `write!(f, "{name}")` would silently ignore them.
        f.pad(self.as_str())
    }
}
52
53// ── Dispatch descriptor ──────────────────────────────────────────────────────
54
/// Describes the 3-D work-group dispatch dimensions for a kernel launch.
///
/// Mirrors the `(group_count_x, group_count_y, group_count_z)` triple passed
/// to `vkCmdDispatch` / `wgpuComputePassEncoderDispatchWorkgroups`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DispatchDims {
    /// Number of work-groups in X.
    pub x: u32,
    /// Number of work-groups in Y.
    pub y: u32,
    /// Number of work-groups in Z.
    pub z: u32,
}

impl DispatchDims {
    /// Create a 1-D dispatch of `n` work-groups (`y = z = 1`).
    pub const fn linear(n: u32) -> Self {
        Self { x: n, y: 1, z: 1 }
    }

    /// Create a 2-D dispatch of `x × y` work-groups (`z = 1`).
    pub const fn grid2d(x: u32, y: u32) -> Self {
        Self { x, y, z: 1 }
    }

    /// Create a 3-D dispatch of `x × y × z` work-groups.
    pub const fn grid3d(x: u32, y: u32, z: u32) -> Self {
        Self { x, y, z }
    }

    /// Total number of work-groups, `x * y * z`.
    ///
    /// The product of three `u32` factors can exceed `u64::MAX`, so the result
    /// saturates at `u64::MAX` instead of panicking (debug builds) or wrapping
    /// (release builds).
    pub const fn total_groups(&self) -> u64 {
        // `as u64` is a lossless widening from u32 (usable in `const fn`).
        (self.x as u64)
            .saturating_mul(self.y as u64)
            .saturating_mul(self.z as u64)
    }

    /// Total threads launched when every work-group has `threads_per_group`
    /// threads, i.e. `total_groups() * threads_per_group`.
    ///
    /// Saturates at `u64::MAX` on overflow.
    pub const fn total_threads(&self, threads_per_group: u32) -> u64 {
        self.total_groups()
            .saturating_mul(threads_per_group as u64)
    }
}
95
/// Number of work-groups needed so that groups of `group_size` threads cover
/// `n` items, i.e. `ceil(n / group_size)`.
///
/// A `group_size` of `0` yields `0` rather than dividing by zero.
pub fn dispatch_size_1d(n: u32, group_size: u32) -> u32 {
    match group_size {
        0 => 0,
        width => n.div_ceil(width),
    }
}
104
105// ── Kernel performance counters ──────────────────────────────────────────────
106
107/// Lightweight performance counters attached to a single kernel invocation.
108#[derive(Debug, Clone, Default)]
109pub struct KernelPerfCounters {
110    /// Number of times the kernel was dispatched.
111    pub dispatch_count: u64,
112    /// Total elements processed across all dispatches.
113    pub elements_processed: u64,
114    /// Estimated floating-point operations (MACs counted as 2 FLOPs).
115    pub flop_count: u64,
116    /// Total bytes read from global memory (mock).
117    pub bytes_read: u64,
118    /// Total bytes written to global memory (mock).
119    pub bytes_written: u64,
120}
121
122impl KernelPerfCounters {
123    /// Record one dispatch that processed `n` elements.
124    pub fn record_dispatch(&mut self, elements: u64, flops: u64, bytes_r: u64, bytes_w: u64) {
125        self.dispatch_count += 1;
126        self.elements_processed += elements;
127        self.flop_count += flops;
128        self.bytes_read += bytes_r;
129        self.bytes_written += bytes_w;
130    }
131
132    /// Arithmetic intensity (FLOPs per byte).
133    pub fn arithmetic_intensity(&self) -> f64 {
134        let bytes = self.bytes_read + self.bytes_written;
135        if bytes == 0 {
136            return 0.0;
137        }
138        self.flop_count as f64 / bytes as f64
139    }
140
141    /// Reset all counters.
142    pub fn reset(&mut self) {
143        *self = KernelPerfCounters::default();
144    }
145}
146
147// ── Shared-memory size helper ────────────────────────────────────────────────
148
/// Shared-memory footprint, in bytes, of a tiled matrix-multiply kernel that
/// stages two `tile` × `tile` tiles of `T`-sized elements.
///
/// Equal to `2 * tile * tile * size_of::<T>()`.
pub fn smem_bytes_matmul<T>(tile: usize) -> usize {
    // Element count of both tiles together.
    let elements = 2 * tile * tile;
    elements * std::mem::size_of::<T>()
}
154
155// ── Barrier simulation ───────────────────────────────────────────────────────
156
/// Simulated GPU work-group barrier.
///
/// Marks a synchronisation point that a future GPU backend would lower to a
/// real work-group barrier. In the CPU mock it only issues a sequentially
/// consistent memory fence: it does **not** block or wait for other threads.
#[inline]
pub fn workgroup_barrier() {
    // Not a rendezvous. Cross-thread synchronisation in the CPU mock is
    // expected to come from Rayon's fork-join boundaries; the SeqCst fence only
    // constrains reordering of atomic operations around this call site.
    std::sync::atomic::fence(std::sync::atomic::Ordering::SeqCst);
}
164
165// ── Predefined group sizes ───────────────────────────────────────────────────
166
/// Typical work-group sizes used by NVIDIA/AMD GPUs.
///
/// Every size is a power of two and is written as a left shift of `1`.
pub mod group_sizes {
    /// 64 threads (`1 << 6`) — common on AMD RDNA and for register-heavy kernels.
    pub const WG_64: u32 = 1 << 6;
    /// 128 threads (`1 << 7`) — a common general-purpose choice.
    pub const WG_128: u32 = 1 << 7;
    /// 256 threads (`1 << 8`) — the default for many CUDA/Vulkan kernels.
    pub const WG_256: u32 = 1 << 8;
    /// 512 threads (`1 << 9`) — useful for reduction passes.
    pub const WG_512: u32 = 1 << 9;
    /// 1024 threads (`1 << 10`) — the maximum work-group size on most hardware.
    pub const WG_1024: u32 = 1 << 10;
}
180
181// ── Tests ────────────────────────────────────────────────────────────────────
182
#[cfg(test)]
mod kernel_mod_tests {
    use super::*;

    /// Every `KernelFamily` variant renders to its snake_case name.
    #[test]
    fn test_kernel_family_display() {
        let cases = [
            (KernelFamily::Sph, "sph"),
            (KernelFamily::Rigid, "rigid"),
            (KernelFamily::Broadphase, "broadphase"),
            (KernelFamily::MdForce, "md_force"),
            (KernelFamily::SdfCompute, "sdf_compute"),
            (KernelFamily::NeuralCompute, "neural_compute"),
            (KernelFamily::GridReduce, "grid_reduce"),
        ];
        for (family, expected) in cases {
            assert_eq!(family.to_string(), expected, "{family:?}");
        }
    }

    #[test]
    fn test_dispatch_dims_linear() {
        let d = DispatchDims::linear(128);
        assert_eq!(d, DispatchDims { x: 128, y: 1, z: 1 });
        assert_eq!(d.total_groups(), 128);
        assert_eq!(d.total_threads(256), 128 * 256);
    }

    #[test]
    fn test_dispatch_dims_grid2d() {
        let d = DispatchDims::grid2d(8, 3);
        assert_eq!(d, DispatchDims { x: 8, y: 3, z: 1 });
        assert_eq!(d.total_groups(), 24);
        assert_eq!(d.total_threads(64), 24 * 64);
    }

    #[test]
    fn test_dispatch_dims_grid3d() {
        let d = DispatchDims::grid3d(4, 4, 4);
        assert_eq!(d.total_groups(), 64);
    }

    #[test]
    fn test_dispatch_size_1d_exact() {
        assert_eq!(dispatch_size_1d(256, 64), 4);
    }

    #[test]
    fn test_dispatch_size_1d_remainder() {
        assert_eq!(dispatch_size_1d(257, 64), 5);
    }

    #[test]
    fn test_dispatch_size_1d_zero_group() {
        assert_eq!(dispatch_size_1d(100, 0), 0);
    }

    #[test]
    fn test_dispatch_size_1d_zero_items() {
        assert_eq!(dispatch_size_1d(0, 64), 0);
    }

    /// Rounding up must not overflow at the top of the `u32` range.
    #[test]
    fn test_dispatch_size_1d_max_items() {
        assert_eq!(dispatch_size_1d(u32::MAX, 256), 16_777_216);
    }

    #[test]
    fn test_perf_counters_arithmetic_intensity() {
        let mut c = KernelPerfCounters::default();
        c.record_dispatch(1024, 8192, 4096, 4096);
        // intensity = 8192 / (4096 + 4096) = 1.0
        assert!((c.arithmetic_intensity() - 1.0).abs() < 1e-10);
    }

    /// With no recorded traffic the intensity is defined as 0.0 (no division).
    #[test]
    fn test_perf_counters_intensity_no_bytes() {
        let c = KernelPerfCounters::default();
        assert!(c.arithmetic_intensity().abs() < 1e-12);
    }

    #[test]
    fn test_perf_counters_accumulate() {
        let mut c = KernelPerfCounters::default();
        c.record_dispatch(10, 20, 30, 40);
        c.record_dispatch(1, 2, 3, 4);
        assert_eq!(c.dispatch_count, 2);
        assert_eq!(c.elements_processed, 11);
        assert_eq!(c.flop_count, 22);
        assert_eq!(c.bytes_read, 33);
        assert_eq!(c.bytes_written, 44);
    }

    /// `reset` must clear every counter, not just a subset.
    #[test]
    fn test_perf_counters_reset() {
        let mut c = KernelPerfCounters::default();
        c.record_dispatch(512, 1024, 512, 512);
        c.reset();
        assert_eq!(c.dispatch_count, 0);
        assert_eq!(c.elements_processed, 0);
        assert_eq!(c.flop_count, 0);
        assert_eq!(c.bytes_read, 0);
        assert_eq!(c.bytes_written, 0);
    }

    #[test]
    fn test_smem_bytes_matmul_f32() {
        // 2 * 16 * 16 * 4 bytes = 2048
        assert_eq!(smem_bytes_matmul::<f32>(16), 2048);
    }

    #[test]
    fn test_smem_bytes_matmul_f64() {
        // 2 * 16 * 16 * 8 bytes = 4096
        assert_eq!(smem_bytes_matmul::<f64>(16), 4096);
    }

    #[test]
    fn test_workgroup_barrier_no_panic() {
        workgroup_barrier(); // must not panic
    }

    /// Group sizes are strictly increasing powers of two.
    #[test]
    fn test_group_sizes_constants() {
        use group_sizes::*;
        let sizes = [WG_64, WG_128, WG_256, WG_512, WG_1024];
        assert!(sizes.windows(2).all(|pair| pair[0] < pair[1]));
        assert!(sizes.iter().all(|size| size.is_power_of_two()));
    }
}