oxiphysics_gpu/kernels/
mod.rs1#![allow(dead_code)]
11
12pub mod broadphase;
13pub mod md_force;
14pub mod rigid;
15pub mod sph;
16
/// Identifies which family a compute kernel belongs to.
///
/// The [`Display`](std::fmt::Display) implementation renders every variant as
/// a lowercase `snake_case` identifier (for example `"md_force"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KernelFamily {
    /// Displayed as `"sph"`.
    Sph,
    /// Displayed as `"rigid"`.
    Rigid,
    /// Displayed as `"broadphase"`.
    Broadphase,
    /// Displayed as `"md_force"`.
    MdForce,
    /// Displayed as `"sdf_compute"`.
    SdfCompute,
    /// Displayed as `"neural_compute"`.
    NeuralCompute,
    /// Displayed as `"grid_reduce"`.
    GridReduce,
}

impl std::fmt::Display for KernelFamily {
    /// Writes the family's lowercase identifier.
    ///
    /// The name is emitted through [`std::fmt::Formatter::pad`] so that the
    /// caller's width, fill, alignment and precision flags are honoured
    /// (e.g. `format!("{:>12}", family)` for column-aligned output). With a
    /// plain `{}` the output is exactly the bare identifier.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive on purpose: adding a variant must force a decision here.
        let name = match self {
            KernelFamily::Sph => "sph",
            KernelFamily::Rigid => "rigid",
            KernelFamily::Broadphase => "broadphase",
            KernelFamily::MdForce => "md_force",
            KernelFamily::SdfCompute => "sdf_compute",
            KernelFamily::NeuralCompute => "neural_compute",
            KernelFamily::GridReduce => "grid_reduce",
        };
        // `write!(f, "{name}")` would start a fresh format spec and silently
        // drop any padding the caller asked for; `pad` applies it.
        f.pad(name)
    }
}
52
/// Group counts along three axes of a dispatch.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DispatchDims {
    /// Number of groups along the x axis.
    pub x: u32,
    /// Number of groups along the y axis.
    pub y: u32,
    /// Number of groups along the z axis.
    pub z: u32,
}

impl DispatchDims {
    /// Builds a one-dimensional dispatch: `n` groups along x, one along y and z.
    pub fn linear(n: u32) -> Self {
        Self { x: n, y: 1, z: 1 }
    }

    /// Builds a two-dimensional dispatch: `x` by `y` groups, one along z.
    pub fn grid2d(x: u32, y: u32) -> Self {
        Self { x, y, z: 1 }
    }

    /// Builds a three-dimensional dispatch of `x` by `y` by `z` groups.
    pub fn grid3d(x: u32, y: u32, z: u32) -> Self {
        Self { x, y, z }
    }

    /// Returns the total number of groups, `x * y * z`.
    ///
    /// The product of three `u32` values can exceed `u64::MAX`; in that case
    /// the result saturates at `u64::MAX` instead of panicking (debug builds)
    /// or wrapping around (release builds).
    pub fn total_groups(&self) -> u64 {
        // `x * y` always fits in a u64, so only the final factor can saturate;
        // a zero axis therefore still yields exactly 0.
        u64::from(self.x)
            .saturating_mul(u64::from(self.y))
            .saturating_mul(u64::from(self.z))
    }

    /// Returns the total number of threads, `total_groups() * threads_per_group`.
    ///
    /// Saturates at `u64::MAX` on overflow, like [`Self::total_groups`].
    pub fn total_threads(&self, threads_per_group: u32) -> u64 {
        self.total_groups()
            .saturating_mul(u64::from(threads_per_group))
    }
}
95
/// Returns how many groups of `group_size` elements are needed to cover `n`
/// elements, i.e. `n / group_size` rounded up.
///
/// A `group_size` of zero yields `0` rather than dividing by zero.
pub fn dispatch_size_1d(n: u32, group_size: u32) -> u32 {
    match group_size {
        // No valid grouping exists; report zero groups.
        0 => 0,
        per_group => n.div_ceil(per_group),
    }
}
104
/// Running totals describing the work recorded for a kernel.
///
/// All counters start at zero ([`Default`]) and only grow until
/// [`KernelPerfCounters::reset`] is called. Accumulation saturates at
/// `u64::MAX` rather than overflowing.
#[derive(Debug, Clone, Default)]
pub struct KernelPerfCounters {
    /// Number of calls to [`KernelPerfCounters::record_dispatch`].
    pub dispatch_count: u64,
    /// Sum of the `elements` values recorded so far.
    pub elements_processed: u64,
    /// Sum of the `flops` values recorded so far.
    pub flop_count: u64,
    /// Sum of the `bytes_r` values recorded so far.
    pub bytes_read: u64,
    /// Sum of the `bytes_w` values recorded so far.
    pub bytes_written: u64,
}

impl KernelPerfCounters {
    /// Records one dispatch and adds its statistics to the running totals.
    ///
    /// * `elements` - elements processed by the dispatch.
    /// * `flops` - floating-point operations attributed to the dispatch.
    /// * `bytes_r` - bytes read by the dispatch.
    /// * `bytes_w` - bytes written by the dispatch.
    ///
    /// Every counter uses saturating addition: a long-running accumulation
    /// pins at `u64::MAX` instead of panicking in debug builds or silently
    /// wrapping to a small value in release builds.
    pub fn record_dispatch(&mut self, elements: u64, flops: u64, bytes_r: u64, bytes_w: u64) {
        self.dispatch_count = self.dispatch_count.saturating_add(1);
        self.elements_processed = self.elements_processed.saturating_add(elements);
        self.flop_count = self.flop_count.saturating_add(flops);
        self.bytes_read = self.bytes_read.saturating_add(bytes_r);
        self.bytes_written = self.bytes_written.saturating_add(bytes_w);
    }

    /// Returns `flop_count / (bytes_read + bytes_written)` as an `f64`.
    ///
    /// Returns `0.0` when no bytes have been recorded, avoiding a division by
    /// zero.
    pub fn arithmetic_intensity(&self) -> f64 {
        // Sum in u128 so two large u64 counters cannot overflow; for sums that
        // fit in a u64 the converted value is identical to the u64 path.
        let total_bytes = u128::from(self.bytes_read) + u128::from(self.bytes_written);
        if total_bytes == 0 {
            return 0.0;
        }
        self.flop_count as f64 / total_bytes as f64
    }

    /// Resets every counter to zero.
    pub fn reset(&mut self) {
        *self = Self::default();
    }
}
146
/// Returns `2 * tile * tile * size_of::<T>()`: the byte size of two square
/// `tile` x `tile` buffers holding elements of type `T`.
///
/// NOTE(review): going by the name, this is presumably the shared-memory
/// footprint of a tiled matrix multiply (one tile per operand); confirm
/// against the kernels that use it.
pub fn smem_bytes_matmul<T>(tile: usize) -> usize {
    use std::mem::size_of;

    let elem_bytes = size_of::<T>();
    let doubled_side = 2 * tile;
    doubled_side * tile * elem_bytes
}
154
/// Issues a sequentially consistent memory fence on the calling thread.
///
/// This is a plain [`std::sync::atomic::fence`] with `Ordering::SeqCst`: it
/// constrains memory-operation ordering but does not wait for any other
/// thread.
#[inline(always)]
pub fn workgroup_barrier() {
    use std::sync::atomic::{fence, Ordering};

    fence(Ordering::SeqCst);
}
164
/// Named group-size constants, from 64 up to 1024 in powers of two.
pub mod group_sizes {
    /// Group size of 64.
    pub const WG_64: u32 = 64;
    /// Group size of 128.
    pub const WG_128: u32 = 128;
    /// Group size of 256.
    pub const WG_256: u32 = 256;
    /// Group size of 512.
    pub const WG_512: u32 = 512;
    /// Group size of 1024.
    pub const WG_1024: u32 = 1024;
}
180
/// Unit tests for the helpers defined in this module.
#[cfg(test)]
mod kernel_mod_tests {
    use super::*;

    /// `Display` yields the lowercase snake_case family name.
    #[test]
    fn test_kernel_family_display() {
        let sph = KernelFamily::Sph.to_string();
        let neural = KernelFamily::NeuralCompute.to_string();
        let reduce = KernelFamily::GridReduce.to_string();

        assert_eq!(sph, "sph");
        assert_eq!(neural, "neural_compute");
        assert_eq!(reduce, "grid_reduce");
    }

    /// A linear dispatch has `n` groups and `n * threads_per_group` threads.
    #[test]
    fn test_dispatch_dims_linear() {
        let dims = DispatchDims::linear(128);

        assert_eq!(dims.total_groups(), 128);
        assert_eq!(dims.total_threads(256), 128 * 256);
    }

    /// A 3D dispatch multiplies all three axes.
    #[test]
    fn test_dispatch_dims_grid3d() {
        let dims = DispatchDims::grid3d(4, 4, 4);

        assert_eq!(dims.total_groups(), 64);
    }

    /// Element count that is an exact multiple of the group size.
    #[test]
    fn test_dispatch_size_1d_exact() {
        let groups = dispatch_size_1d(256, 64);

        assert_eq!(groups, 4);
    }

    /// A partial trailing group rounds the count up.
    #[test]
    fn test_dispatch_size_1d_remainder() {
        let groups = dispatch_size_1d(257, 64);

        assert_eq!(groups, 5);
    }

    /// A zero group size yields zero groups instead of dividing by zero.
    #[test]
    fn test_dispatch_size_1d_zero_group() {
        let groups = dispatch_size_1d(100, 0);

        assert_eq!(groups, 0);
    }

    /// 8192 flops over 4096 + 4096 bytes gives an intensity of 1.0.
    #[test]
    fn test_perf_counters_arithmetic_intensity() {
        let mut counters = KernelPerfCounters::default();
        counters.record_dispatch(1024, 8192, 4096, 4096);

        let intensity = counters.arithmetic_intensity();
        assert!((intensity - 1.0).abs() < 1e-10);
    }

    /// `reset` clears previously recorded totals.
    #[test]
    fn test_perf_counters_reset() {
        let mut counters = KernelPerfCounters::default();
        counters.record_dispatch(512, 1024, 512, 512);
        counters.reset();

        assert_eq!(counters.dispatch_count, 0);
        assert_eq!(counters.flop_count, 0);
    }

    /// 2 * 16 * 16 * 4 bytes for `f32`.
    #[test]
    fn test_smem_bytes_matmul_f32() {
        let footprint = smem_bytes_matmul::<f32>(16);

        assert_eq!(footprint, 2048);
    }

    /// 2 * 16 * 16 * 8 bytes for `f64`.
    #[test]
    fn test_smem_bytes_matmul_f64() {
        let footprint = smem_bytes_matmul::<f64>(16);

        assert_eq!(footprint, 4096);
    }

    /// Smoke test: calling the barrier must not panic.
    #[test]
    fn test_workgroup_barrier_no_panic() {
        workgroup_barrier();
    }

    /// The group-size constants are strictly increasing.
    // The comparisons involve only constants, which clippy flags by default.
    #[allow(clippy::assertions_on_constants)]
    #[test]
    fn test_group_sizes_constants() {
        use super::group_sizes::{WG_1024, WG_128, WG_256, WG_512, WG_64};

        assert!(WG_64 < WG_128);
        assert!(WG_128 < WG_256);
        assert!(WG_256 < WG_512);
        assert!(WG_512 < WG_1024);
    }
}