Skip to main content

ringkernel_wavesim3d/simulation/
grid3d.rs

1//! 3D FDTD simulation grid for acoustic wave propagation.
2//!
3//! Implements the 3D wave equation using finite differences:
4//! ∂²p/∂t² = c² (∂²p/∂x² + ∂²p/∂y² + ∂²p/∂z²) - γ·∂p/∂t
5//!
6//! Uses a 7-point stencil (6 neighbors + center) for the Laplacian.
7
8use crate::simulation::physics::{AcousticParams3D, Position3D};
9use rayon::prelude::*;
10
/// Role a grid cell plays in the simulation; selects its boundary treatment.
///
/// Only [`CellType::Normal`] cells are advanced by the finite-difference stencil. The other
/// variants are left to the boundary-condition pass.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum CellType {
    /// Free medium: the wave propagates through the cell unhindered.
    #[default]
    Normal,
    /// Absorbing cell, used to build anechoic boundaries.
    Absorber,
    /// Hard wall that reflects incoming waves.
    Reflector,
    /// Solid object inside the domain that reflects only part of the wave.
    Obstacle,
}
24
25/// 3D simulation grid using Structure of Arrays (SoA) layout.
26///
27/// Memory layout is z-major (z varies slowest):
28/// `index = z * (width * height) + y * width + x`
29pub struct SimulationGrid3D {
30    /// Grid dimensions
31    pub width: usize,
32    pub height: usize,
33    pub depth: usize,
34
35    /// Current pressure field (row-major, z-major ordering)
36    pub pressure: Vec<f32>,
37
38    /// Previous pressure field (for time-stepping)
39    pub pressure_prev: Vec<f32>,
40
41    /// Cell types (boundary conditions)
42    pub cell_types: Vec<CellType>,
43
44    /// Reflection coefficients per cell (0 = full absorption, 1 = full reflection)
45    pub reflection_coeff: Vec<f32>,
46
47    /// Acoustic parameters
48    pub params: AcousticParams3D,
49
50    /// Total number of cells
51    total_cells: usize,
52
53    /// Slice size (width * height) for z-indexing
54    slice_size: usize,
55
56    /// Current simulation step
57    pub step: u64,
58
59    /// Accumulated simulation time
60    pub time: f32,
61}
62
63impl SimulationGrid3D {
64    /// Create a new 3D simulation grid.
65    ///
66    /// # Arguments
67    /// * `width` - Number of cells in X direction
68    /// * `height` - Number of cells in Y direction
69    /// * `depth` - Number of cells in Z direction
70    /// * `params` - Acoustic simulation parameters
71    pub fn new(width: usize, height: usize, depth: usize, params: AcousticParams3D) -> Self {
72        let total_cells = width * height * depth;
73        let slice_size = width * height;
74
75        let mut grid = Self {
76            width,
77            height,
78            depth,
79            pressure: vec![0.0; total_cells],
80            pressure_prev: vec![0.0; total_cells],
81            cell_types: vec![CellType::Normal; total_cells],
82            reflection_coeff: vec![1.0; total_cells],
83            params,
84            total_cells,
85            slice_size,
86            step: 0,
87            time: 0.0,
88        };
89
90        // Set boundary conditions
91        grid.setup_boundaries();
92        grid
93    }
94
95    /// Set up default boundary conditions (absorbing boundaries).
96    fn setup_boundaries(&mut self) {
97        // Set all boundary cells to absorbers
98        for z in 0..self.depth {
99            for y in 0..self.height {
100                for x in 0..self.width {
101                    let is_boundary = x == 0
102                        || x == self.width - 1
103                        || y == 0
104                        || y == self.height - 1
105                        || z == 0
106                        || z == self.depth - 1;
107
108                    if is_boundary {
109                        let idx = self.index(x, y, z);
110                        self.cell_types[idx] = CellType::Absorber;
111                        self.reflection_coeff[idx] = 0.1; // 10% reflection
112                    }
113                }
114            }
115        }
116    }
117
118    /// Convert (x, y, z) coordinates to linear index.
119    #[inline]
120    pub fn index(&self, x: usize, y: usize, z: usize) -> usize {
121        z * self.slice_size + y * self.width + x
122    }
123
124    /// Convert linear index to (x, y, z) coordinates.
125    #[inline]
126    pub fn coords(&self, idx: usize) -> (usize, usize, usize) {
127        let z = idx / self.slice_size;
128        let remainder = idx % self.slice_size;
129        let y = remainder / self.width;
130        let x = remainder % self.width;
131        (x, y, z)
132    }
133
134    /// Check if coordinates are within bounds.
135    #[inline]
136    pub fn in_bounds(&self, x: i32, y: i32, z: i32) -> bool {
137        x >= 0
138            && x < self.width as i32
139            && y >= 0
140            && y < self.height as i32
141            && z >= 0
142            && z < self.depth as i32
143    }
144
145    /// Inject a pressure impulse at a specific position.
146    pub fn inject_impulse(&mut self, x: usize, y: usize, z: usize, amplitude: f32) {
147        if x < self.width && y < self.height && z < self.depth {
148            let idx = self.index(x, y, z);
149            self.pressure[idx] += amplitude;
150            self.pressure_prev[idx] += amplitude * 0.5;
151        }
152    }
153
154    /// Inject a pressure impulse at a 3D position (in meters).
155    pub fn inject_impulse_at(&mut self, pos: Position3D, amplitude: f32) {
156        let (x, y, z) = pos.to_grid_indices(self.params.cell_size);
157        self.inject_impulse(x, y, z, amplitude);
158    }
159
160    /// Inject a spherical impulse (Gaussian distribution).
161    pub fn inject_spherical_impulse(
162        &mut self,
163        center: Position3D,
164        amplitude: f32,
165        radius_cells: f32,
166    ) {
167        let (cx, cy, cz) = center.to_grid_indices(self.params.cell_size);
168        let r = radius_cells as i32;
169        let r_sq = radius_cells * radius_cells;
170
171        for dz in -r..=r {
172            for dy in -r..=r {
173                for dx in -r..=r {
174                    let x = cx as i32 + dx;
175                    let y = cy as i32 + dy;
176                    let z = cz as i32 + dz;
177
178                    if self.in_bounds(x, y, z) {
179                        let dist_sq = (dx * dx + dy * dy + dz * dz) as f32;
180                        if dist_sq <= r_sq {
181                            // Gaussian falloff
182                            let factor = (-dist_sq / (2.0 * radius_cells)).exp();
183                            let idx = self.index(x as usize, y as usize, z as usize);
184                            self.pressure[idx] += amplitude * factor;
185                            self.pressure_prev[idx] += amplitude * factor * 0.5;
186                        }
187                    }
188                }
189            }
190        }
191    }
192
193    /// Set cell type at a specific location.
194    pub fn set_cell_type(&mut self, x: usize, y: usize, z: usize, cell_type: CellType) {
195        if x < self.width && y < self.height && z < self.depth {
196            let idx = self.index(x, y, z);
197            self.cell_types[idx] = cell_type;
198            self.reflection_coeff[idx] = match cell_type {
199                CellType::Normal => 1.0,
200                CellType::Absorber => 0.0,
201                CellType::Reflector => 1.0,
202                CellType::Obstacle => 0.8,
203            };
204        }
205    }
206
207    /// Create a rectangular obstacle.
208    pub fn create_obstacle(&mut self, min: Position3D, max: Position3D, cell_type: CellType) {
209        let (x0, y0, z0) = min.to_grid_indices(self.params.cell_size);
210        let (x1, y1, z1) = max.to_grid_indices(self.params.cell_size);
211
212        for z in z0..=z1.min(self.depth - 1) {
213            for y in y0..=y1.min(self.height - 1) {
214                for x in x0..=x1.min(self.width - 1) {
215                    self.set_cell_type(x, y, z, cell_type);
216                }
217            }
218        }
219    }
220
221    /// Perform one FDTD time step (CPU, sequential).
222    pub fn step_sequential(&mut self) {
223        let c2 = self.params.c_squared;
224        let damping = self.params.simple_damping;
225
226        // Update interior cells only
227        for z in 1..self.depth - 1 {
228            for y in 1..self.height - 1 {
229                for x in 1..self.width - 1 {
230                    let idx = self.index(x, y, z);
231
232                    // Skip non-normal cells (handle separately)
233                    if self.cell_types[idx] != CellType::Normal {
234                        continue;
235                    }
236
237                    // Current and previous pressure
238                    let p = self.pressure[idx];
239                    let p_prev = self.pressure_prev[idx];
240
241                    // 6 neighbors for 3D Laplacian
242                    let p_west = self.pressure[idx - 1];
243                    let p_east = self.pressure[idx + 1];
244                    let p_south = self.pressure[idx - self.width];
245                    let p_north = self.pressure[idx + self.width];
246                    let p_down = self.pressure[idx - self.slice_size];
247                    let p_up = self.pressure[idx + self.slice_size];
248
249                    // 7-point stencil Laplacian
250                    let laplacian = p_west + p_east + p_south + p_north + p_down + p_up - 6.0 * p;
251
252                    // FDTD update equation
253                    let p_new = 2.0 * p - p_prev + c2 * laplacian;
254
255                    // Apply damping
256                    self.pressure_prev[idx] = p_new * damping;
257                }
258            }
259        }
260
261        // Handle boundary cells
262        self.apply_boundary_conditions();
263
264        // Swap buffers
265        std::mem::swap(&mut self.pressure, &mut self.pressure_prev);
266
267        self.step += 1;
268        self.time += self.params.time_step;
269    }
270
271    /// Perform one FDTD time step (CPU, parallel with Rayon).
272    pub fn step_parallel(&mut self) {
273        let c2 = self.params.c_squared;
274        let damping = self.params.simple_damping;
275        let width = self.width;
276        let height = self.height;
277        let depth = self.depth;
278        let slice_size = self.slice_size;
279
280        // Take references to avoid capturing self in closures
281        let pressure = &self.pressure;
282        let pressure_prev = &self.pressure_prev;
283        let cell_types = &self.cell_types;
284
285        // Compute new values in parallel and store in a new buffer
286        // Use par_iter over z indices to avoid flat_map memory explosion
287        let num_interior_z = depth - 2;
288
289        // Pre-allocate output buffer
290        let mut new_values = vec![0.0f32; num_interior_z * (height - 2) * (width - 2)];
291
292        // Process in parallel
293        new_values
294            .par_chunks_mut((height - 2) * (width - 2))
295            .enumerate()
296            .for_each(|(zi, chunk)| {
297                let z = zi + 1; // Actual z index (skip boundary)
298                let mut out_idx = 0;
299
300                for y in 1..height - 1 {
301                    for x in 1..width - 1 {
302                        let idx = z * slice_size + y * width + x;
303
304                        let p_new = if cell_types[idx] != CellType::Normal {
305                            // Keep previous value for non-normal cells
306                            pressure_prev[idx]
307                        } else {
308                            let p = pressure[idx];
309                            let p_prev = pressure_prev[idx];
310
311                            // 6 neighbors
312                            let p_west = pressure[idx - 1];
313                            let p_east = pressure[idx + 1];
314                            let p_south = pressure[idx - width];
315                            let p_north = pressure[idx + width];
316                            let p_down = pressure[idx - slice_size];
317                            let p_up = pressure[idx + slice_size];
318
319                            let laplacian =
320                                p_west + p_east + p_south + p_north + p_down + p_up - 6.0 * p;
321                            (2.0 * p - p_prev + c2 * laplacian) * damping
322                        };
323
324                        chunk[out_idx] = p_new;
325                        out_idx += 1;
326                    }
327                }
328            });
329
330        // Copy results back to pressure_prev
331        let mut result_idx = 0;
332        for z in 1..depth - 1 {
333            for y in 1..height - 1 {
334                for x in 1..width - 1 {
335                    let idx = self.index(x, y, z);
336                    self.pressure_prev[idx] = new_values[result_idx];
337                    result_idx += 1;
338                }
339            }
340        }
341
342        // Handle boundary conditions
343        self.apply_boundary_conditions();
344
345        // Swap buffers
346        std::mem::swap(&mut self.pressure, &mut self.pressure_prev);
347
348        self.step += 1;
349        self.time += self.params.time_step;
350    }
351
352    /// Apply boundary conditions.
353    fn apply_boundary_conditions(&mut self) {
354        for idx in 0..self.total_cells {
355            match self.cell_types[idx] {
356                CellType::Absorber => {
357                    // Gradually absorb energy
358                    self.pressure_prev[idx] *= self.reflection_coeff[idx];
359                }
360                CellType::Reflector => {
361                    // Perfect reflection (handled implicitly by stencil)
362                }
363                CellType::Obstacle => {
364                    // Partial reflection
365                    self.pressure_prev[idx] *= self.reflection_coeff[idx];
366                }
367                CellType::Normal => {}
368            }
369        }
370    }
371
372    /// Get pressure value at a specific position (interpolated).
373    pub fn sample_pressure(&self, pos: Position3D) -> f32 {
374        // Trilinear interpolation
375        let fx = pos.x / self.params.cell_size;
376        let fy = pos.y / self.params.cell_size;
377        let fz = pos.z / self.params.cell_size;
378
379        let x0 = fx.floor() as usize;
380        let y0 = fy.floor() as usize;
381        let z0 = fz.floor() as usize;
382
383        let x1 = (x0 + 1).min(self.width - 1);
384        let y1 = (y0 + 1).min(self.height - 1);
385        let z1 = (z0 + 1).min(self.depth - 1);
386
387        let tx = fx - x0 as f32;
388        let ty = fy - y0 as f32;
389        let tz = fz - z0 as f32;
390
391        // 8 corner samples
392        let c000 = self.pressure[self.index(x0, y0, z0)];
393        let c100 = self.pressure[self.index(x1, y0, z0)];
394        let c010 = self.pressure[self.index(x0, y1, z0)];
395        let c110 = self.pressure[self.index(x1, y1, z0)];
396        let c001 = self.pressure[self.index(x0, y0, z1)];
397        let c101 = self.pressure[self.index(x1, y0, z1)];
398        let c011 = self.pressure[self.index(x0, y1, z1)];
399        let c111 = self.pressure[self.index(x1, y1, z1)];
400
401        // Trilinear interpolation
402        let c00 = c000 * (1.0 - tx) + c100 * tx;
403        let c01 = c001 * (1.0 - tx) + c101 * tx;
404        let c10 = c010 * (1.0 - tx) + c110 * tx;
405        let c11 = c011 * (1.0 - tx) + c111 * tx;
406
407        let c0 = c00 * (1.0 - ty) + c10 * ty;
408        let c1 = c01 * (1.0 - ty) + c11 * ty;
409
410        c0 * (1.0 - tz) + c1 * tz
411    }
412
413    /// Get an XY slice of the pressure field at a given Z.
414    pub fn get_xy_slice(&self, z: usize) -> Vec<f32> {
415        let z = z.min(self.depth - 1);
416        let start = z * self.slice_size;
417        self.pressure[start..start + self.slice_size].to_vec()
418    }
419
420    /// Get an XZ slice of the pressure field at a given Y.
421    pub fn get_xz_slice(&self, y: usize) -> Vec<f32> {
422        let y = y.min(self.height - 1);
423        let mut slice = Vec::with_capacity(self.width * self.depth);
424        for z in 0..self.depth {
425            for x in 0..self.width {
426                slice.push(self.pressure[self.index(x, y, z)]);
427            }
428        }
429        slice
430    }
431
432    /// Get a YZ slice of the pressure field at a given X.
433    pub fn get_yz_slice(&self, x: usize) -> Vec<f32> {
434        let x = x.min(self.width - 1);
435        let mut slice = Vec::with_capacity(self.height * self.depth);
436        for z in 0..self.depth {
437            for y in 0..self.height {
438                slice.push(self.pressure[self.index(x, y, z)]);
439            }
440        }
441        slice
442    }
443
444    /// Get the total energy in the simulation.
445    pub fn total_energy(&self) -> f32 {
446        self.pressure.iter().map(|&p| p * p).sum::<f32>()
447    }
448
449    /// Get the maximum absolute pressure.
450    pub fn max_pressure(&self) -> f32 {
451        self.pressure
452            .iter()
453            .map(|&p| p.abs())
454            .fold(0.0_f32, f32::max)
455    }
456
457    /// Reset the simulation to initial state.
458    pub fn reset(&mut self) {
459        self.pressure.fill(0.0);
460        self.pressure_prev.fill(0.0);
461        self.step = 0;
462        self.time = 0.0;
463    }
464
465    /// Get grid dimensions as (width, height, depth).
466    pub fn dimensions(&self) -> (usize, usize, usize) {
467        (self.width, self.height, self.depth)
468    }
469
470    /// Get the physical size of the simulation domain in meters.
471    pub fn physical_size(&self) -> (f32, f32, f32) {
472        let cell = self.params.cell_size;
473        (
474            self.width as f32 * cell,
475            self.height as f32 * cell,
476            self.depth as f32 * cell,
477        )
478    }
479}
480
481/// GPU-compatible data layout for 3D grid.
482///
483/// Packed format for efficient GPU memory access.
484#[repr(C)]
485#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
486pub struct GridParams {
487    /// Grid dimensions
488    pub width: u32,
489    pub height: u32,
490    pub depth: u32,
491    pub _padding0: u32,
492
493    /// Courant number squared
494    pub c_squared: f32,
495    /// Damping factor
496    pub damping: f32,
497    /// Slice size (width * height)
498    pub slice_size: u32,
499    pub _padding1: u32,
500}
501
502impl From<&SimulationGrid3D> for GridParams {
503    fn from(grid: &SimulationGrid3D) -> Self {
504        Self {
505            width: grid.width as u32,
506            height: grid.height as u32,
507            depth: grid.depth as u32,
508            _padding0: 0,
509            c_squared: grid.params.c_squared,
510            damping: grid.params.simple_damping,
511            slice_size: grid.slice_size as u32,
512            _padding1: 0,
513        }
514    }
515}
516
#[cfg(test)]
mod tests {
    use super::*;
    use crate::simulation::physics::Environment;

    /// 32³ grid with default environment and a 0.1 cell size, shared by most tests.
    fn create_test_grid() -> SimulationGrid3D {
        SimulationGrid3D::new(
            32,
            32,
            32,
            AcousticParams3D::new(Environment::default(), 0.1),
        )
    }

    /// Dimensions and the cached cell count match the constructor arguments.
    #[test]
    fn test_grid_creation() {
        let grid = create_test_grid();
        assert_eq!(grid.width, 32);
        assert_eq!(grid.height, 32);
        assert_eq!(grid.depth, 32);
        assert_eq!(grid.total_cells, 32 * 32 * 32);
    }

    /// `index` and `coords` are inverses, and the extreme corners map to the array ends.
    #[test]
    fn test_indexing() {
        let grid = create_test_grid();

        // Round trip through the linear index.
        let linear = grid.index(5, 10, 15);
        let (x, y, z) = grid.coords(linear);
        assert_eq!((x, y, z), (5, 10, 15));

        // First and last cell.
        assert_eq!(grid.index(0, 0, 0), 0);
        assert_eq!(grid.index(31, 31, 31), 32 * 32 * 32 - 1);
    }

    /// A unit impulse lands in the addressed cell of the current field.
    #[test]
    fn test_impulse_injection() {
        let mut grid = create_test_grid();
        let center_idx = grid.index(16, 16, 16);

        assert_eq!(grid.pressure[center_idx], 0.0);
        grid.inject_impulse(16, 16, 16, 1.0);
        assert_eq!(grid.pressure[center_idx], 1.0);
    }

    /// After ten sequential steps there is still signal at the source or its neighbour.
    #[test]
    fn test_wave_propagation() {
        let mut grid = create_test_grid();

        grid.inject_impulse(16, 16, 16, 1.0);
        (0..10).for_each(|_| grid.step_sequential());

        let at_center = grid.pressure[grid.index(16, 16, 16)];
        let at_neighbor = grid.pressure[grid.index(17, 16, 16)];

        // Neighbor should have received energy (some energy left center)
        assert!(at_neighbor.abs() > 0.0 || at_center.abs() > 0.0);
    }

    /// With damping disabled, an impulse reaches the cells next to the source within 5 steps.
    #[test]
    fn test_wave_spreading() {
        let mut grid = create_test_grid();

        // Damping factor 1.0 leaves the updated values unscaled.
        grid.params.simple_damping = 1.0;

        grid.inject_impulse(16, 16, 16, 1.0);

        // The impulse must be visible before stepping.
        let initial_max = grid.max_pressure();
        assert!(initial_max > 0.5);

        (0..5).for_each(|_| grid.step_sequential());

        // The field must still carry energy...
        let total_energy = grid.total_energy();
        assert!(total_energy > 0.0, "Wave should contain energy");

        // ...and at least one adjacent cell must have picked some of it up.
        let east = grid.pressure[grid.index(17, 16, 16)];
        let north = grid.pressure[grid.index(16, 17, 16)];
        let has_spread = east.abs() > 0.0 || north.abs() > 0.0;
        assert!(has_spread, "Wave should propagate to neighbors");
    }

    /// Sampling exactly on a grid point returns (close to) that cell's pressure.
    #[test]
    fn test_pressure_sampling() {
        let mut grid = create_test_grid();

        let idx = grid.index(10, 10, 10);
        grid.pressure[idx] = 1.0;

        // Physical position of cell (10, 10, 10).
        let pos = Position3D::new(
            10.0 * grid.params.cell_size,
            10.0 * grid.params.cell_size,
            10.0 * grid.params.cell_size,
        );
        let sampled = grid.sample_pressure(pos);

        // Should be close to 1.0 (trilinear interpolation)
        assert!(sampled > 0.5, "Sampled value too low: {}", sampled);
    }

    /// The outer shell is absorbing by default while the interior stays normal.
    #[test]
    fn test_boundary_setup() {
        let grid = create_test_grid();
        let kind_at = |x, y, z| grid.cell_types[grid.index(x, y, z)];

        assert_eq!(kind_at(0, 0, 0), CellType::Absorber);
        assert_eq!(kind_at(31, 31, 31), CellType::Absorber);
        assert_eq!(kind_at(16, 16, 16), CellType::Normal);
    }

    /// Each slice accessor has the right length and places a marked cell where expected.
    #[test]
    fn test_slices() {
        let mut grid = create_test_grid();

        // Mark a single cell at (5, 10, 15).
        let idx = grid.index(5, 10, 15);
        grid.pressure[idx] = 1.0;

        // XY plane at z = 15, addressed as y * width + x.
        let xy_slice = grid.get_xy_slice(15);
        assert_eq!(xy_slice.len(), grid.width * grid.height);
        assert_eq!(xy_slice[10 * grid.width + 5], 1.0);

        // XZ plane at y = 10, addressed as z * width + x.
        let xz_slice = grid.get_xz_slice(10);
        assert_eq!(xz_slice.len(), grid.width * grid.depth);
        assert_eq!(xz_slice[15 * grid.width + 5], 1.0);

        // YZ plane at x = 5, addressed as z * height + y.
        let yz_slice = grid.get_yz_slice(5);
        assert_eq!(yz_slice.len(), grid.height * grid.depth);
        assert_eq!(yz_slice[15 * grid.height + 10], 1.0);
    }

    /// Sequential and parallel steppers stay in agreement over ten steps.
    #[test]
    fn test_parallel_vs_sequential() {
        let params = AcousticParams3D::new(Environment::default(), 0.1);

        let mut grid_seq = SimulationGrid3D::new(16, 16, 16, params.clone());
        let mut grid_par = SimulationGrid3D::new(16, 16, 16, params);

        // Identical starting conditions.
        grid_seq.inject_impulse(8, 8, 8, 1.0);
        grid_par.inject_impulse(8, 8, 8, 1.0);

        for _ in 0..10 {
            grid_seq.step_sequential();
            grid_par.step_parallel();
        }

        // Largest per-cell deviation between the two fields.
        let max_diff: f32 = grid_seq
            .pressure
            .iter()
            .zip(grid_par.pressure.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0, f32::max);

        assert!(
            max_diff < 0.001,
            "Sequential and parallel results differ by {}",
            max_diff
        );
    }
}