Skip to main content

ringkernel_wavesim/simulation/
kernel_grid.rs

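//! Kernel-based simulation grid using RingKernel actors.
//!
//! This module provides an implementation of the wave simulation in which
//! every grid cell is paired with a persistent kernel actor launched on the
//! selected RingKernel backend. In the current implementation the per-step
//! pressure exchange and update are computed on the host over the cell
//! states; the actors are not yet messaged (K2K) during a step.
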
6use super::{AcousticParams, CellState};
7use ringkernel::prelude::*;
8use std::collections::HashMap;
9use std::sync::Arc;
10
/// A wave-simulation grid backed by RingKernel actors.
///
/// The grid keeps one [`CellState`] per `(x, y)` position (used both for
/// stepping and for visualization) and one launched [`KernelHandle`] per
/// cell, all created on a single shared [`RingKernel`] runtime for the
/// selected [`Backend`].
///
/// `Debug` is implemented by hand rather than derived because `RingKernel`
/// and `KernelHandle` do not implement it.
pub struct KernelGrid {
    /// Grid width (number of columns).
    pub width: u32,
    /// Grid height (number of rows).
    pub height: u32,
    /// Per-cell simulation state keyed by `(x, y)`; also the data source for
    /// visualization.
    cells: HashMap<(u32, u32), CellState>,
    /// Acoustic simulation parameters handed to every cell on each step.
    pub params: AcousticParams,
    /// The RingKernel runtime on which the cell kernels are launched.
    runtime: Arc<RingKernel>,
    /// Launched kernel handles, keyed by each cell's kernel ID.
    kernels: HashMap<String, KernelHandle>,
    /// The backend the runtime was created for.
    backend: Backend,
    // NOTE(review): fields drop in declaration order, so `runtime` is released
    // before `kernels`. Whether RingKernel depends on that order is not visible
    // here; confirm before reordering these fields.
}

33impl std::fmt::Debug for KernelGrid {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        f.debug_struct("KernelGrid")
36            .field("width", &self.width)
37            .field("height", &self.height)
38            .field("backend", &self.backend)
39            .field("kernel_count", &self.kernels.len())
40            .finish()
41    }
42}
43
44impl KernelGrid {
45    /// Create a new kernel-based simulation grid.
46    ///
47    /// # Arguments
48    /// * `width` - Number of columns
49    /// * `height` - Number of rows
50    /// * `params` - Acoustic simulation parameters
51    /// * `backend` - The compute backend to use
52    pub async fn new(
53        width: u32,
54        height: u32,
55        params: AcousticParams,
56        backend: Backend,
57    ) -> Result<Self> {
58        let runtime = Arc::new(RingKernel::with_backend(backend).await?);
59
60        let mut grid = Self {
61            width,
62            height,
63            cells: HashMap::with_capacity((width * height) as usize),
64            params,
65            runtime,
66            kernels: HashMap::new(),
67            backend,
68        };
69
70        grid.initialize_cells();
71        grid.launch_kernels().await?;
72
73        Ok(grid)
74    }
75
76    /// Initialize cell states.
77    fn initialize_cells(&mut self) {
78        self.cells.clear();
79
80        for y in 0..self.height {
81            for x in 0..self.width {
82                let mut cell = CellState::new(x, y);
83
84                // Mark boundary cells
85                let is_boundary = x == 0 || y == 0 || x == self.width - 1 || y == self.height - 1;
86                cell.is_boundary = is_boundary;
87                cell.reflection_coeff = if is_boundary { 0.95 } else { 1.0 };
88
89                self.cells.insert((x, y), cell);
90            }
91        }
92    }
93
94    /// Launch kernel actors for each cell.
95    async fn launch_kernels(&mut self) -> Result<()> {
96        for ((_x, _y), cell) in &self.cells {
97            let kernel_id = cell.kernel_id();
98
99            let options = LaunchOptions::default();
100
101            let kernel = self.runtime.launch(&kernel_id, options).await?;
102            // Note: CPU runtime launches kernels in Active state by default,
103            // so we don't call activate() here to avoid InvalidStateTransition errors
104
105            self.kernels.insert(kernel_id, kernel);
106        }
107
108        tracing::info!(
109            "Launched {} kernel actors on {:?} backend",
110            self.kernels.len(),
111            self.backend
112        );
113
114        Ok(())
115    }
116
117    /// Perform one simulation step using kernel actors.
118    ///
119    /// This coordinates the pressure exchange and computation across
120    /// all kernel actors.
121    pub async fn step(&mut self) -> Result<()> {
122        // Phase 1: Exchange pressure values between neighbors
123        self.exchange_pressures();
124
125        // Phase 2: Each cell computes its next pressure value
126        for cell in self.cells.values_mut() {
127            cell.step(&self.params);
128        }
129
130        Ok(())
131    }
132
133    /// Exchange pressure values between neighboring cells.
134    fn exchange_pressures(&mut self) {
135        // Collect current pressure values
136        let pressures: HashMap<(u32, u32), f32> = self
137            .cells
138            .iter()
139            .map(|(&pos, cell)| (pos, cell.pressure))
140            .collect();
141
142        // Update each cell with neighbor pressures
143        for ((x, y), cell) in self.cells.iter_mut() {
144            // North neighbor (y - 1)
145            cell.p_north = if *y > 0 {
146                pressures.get(&(*x, y - 1)).copied().unwrap_or(0.0)
147            } else {
148                cell.pressure * cell.reflection_coeff // Reflect at boundary
149            };
150
151            // South neighbor (y + 1)
152            cell.p_south = if *y < self.height - 1 {
153                pressures.get(&(*x, y + 1)).copied().unwrap_or(0.0)
154            } else {
155                cell.pressure * cell.reflection_coeff
156            };
157
158            // West neighbor (x - 1)
159            cell.p_west = if *x > 0 {
160                pressures.get(&(x - 1, *y)).copied().unwrap_or(0.0)
161            } else {
162                cell.pressure * cell.reflection_coeff
163            };
164
165            // East neighbor (x + 1)
166            cell.p_east = if *x < self.width - 1 {
167                pressures.get(&(x + 1, *y)).copied().unwrap_or(0.0)
168            } else {
169                cell.pressure * cell.reflection_coeff
170            };
171        }
172    }
173
174    /// Inject an impulse at the given grid position.
175    pub fn inject_impulse(&mut self, x: u32, y: u32, amplitude: f32) {
176        if let Some(cell) = self.cells.get_mut(&(x, y)) {
177            cell.inject_impulse(amplitude);
178        }
179    }
180
181    /// Get the pressure grid for visualization.
182    pub fn get_pressure_grid(&self) -> Vec<Vec<f32>> {
183        let mut grid = vec![vec![0.0; self.width as usize]; self.height as usize];
184        for ((x, y), cell) in &self.cells {
185            if (*y as usize) < grid.len() && (*x as usize) < grid[0].len() {
186                grid[*y as usize][*x as usize] = cell.pressure;
187            }
188        }
189        grid
190    }
191
192    /// Get the maximum absolute pressure in the grid.
193    pub fn max_pressure(&self) -> f32 {
194        self.cells
195            .values()
196            .map(|c| c.pressure.abs())
197            .fold(0.0, f32::max)
198    }
199
200    /// Get total energy in the system.
201    pub fn total_energy(&self) -> f32 {
202        self.cells.values().map(|c| c.pressure * c.pressure).sum()
203    }
204
205    /// Reset all cells to initial state.
206    pub fn reset(&mut self) {
207        for cell in self.cells.values_mut() {
208            cell.reset();
209        }
210    }
211
212    /// Resize the grid.
213    pub async fn resize(&mut self, new_width: u32, new_height: u32) -> Result<()> {
214        // Terminate existing kernels and recreate runtime to clear registered kernel IDs
215        self.shutdown_kernels().await?;
216
217        // Create new runtime to clear kernel registrations
218        self.runtime = Arc::new(RingKernel::with_backend(self.backend).await?);
219
220        self.width = new_width;
221        self.height = new_height;
222
223        self.initialize_cells();
224        self.launch_kernels().await
225    }
226
227    /// Update acoustic parameters.
228    pub fn set_speed_of_sound(&mut self, speed: f32) {
229        self.params.set_speed_of_sound(speed);
230    }
231
232    /// Update cell size.
233    pub fn set_cell_size(&mut self, size: f32) {
234        self.params.set_cell_size(size);
235    }
236
237    /// Get the number of cells.
238    pub fn cell_count(&self) -> usize {
239        self.cells.len()
240    }
241
242    /// Get the current backend.
243    pub fn backend(&self) -> Backend {
244        self.backend
245    }
246
247    /// Switch to a different backend.
248    pub async fn switch_backend(&mut self, backend: Backend) -> Result<()> {
249        if self.backend == backend {
250            return Ok(());
251        }
252
253        // Shutdown current kernels
254        self.shutdown_kernels().await?;
255
256        // Create new runtime with new backend
257        self.runtime = Arc::new(RingKernel::with_backend(backend).await?);
258        self.backend = backend;
259
260        // Relaunch kernels
261        self.launch_kernels().await?;
262
263        tracing::info!("Switched to {:?} backend", backend);
264        Ok(())
265    }
266
267    /// Shutdown all kernel actors.
268    async fn shutdown_kernels(&mut self) -> Result<()> {
269        for (id, kernel) in self.kernels.drain() {
270            if let Err(e) = kernel.terminate().await {
271                tracing::warn!("Failed to terminate kernel {}: {:?}", id, e);
272            }
273        }
274        Ok(())
275    }
276
277    /// Shutdown the grid and release resources.
278    pub async fn shutdown(mut self) -> Result<()> {
279        self.shutdown_kernels().await
280    }
281}
282
#[cfg(test)]
mod tests {
    use super::*;

    /// Create a CPU-backed grid with the acoustic parameters shared by all tests.
    async fn cpu_grid(width: u32, height: u32) -> KernelGrid {
        let params = AcousticParams::new(343.0, 1.0);
        KernelGrid::new(width, height, params, Backend::Cpu)
            .await
            .expect("creating a CPU-backed grid should succeed")
    }

    /// Borrow the cell at `(x, y)`, panicking if it does not exist.
    fn cell(grid: &KernelGrid, x: u32, y: u32) -> &CellState {
        &grid.cells[&(x, y)]
    }

    #[tokio::test]
    async fn test_kernel_grid_creation() {
        let grid = cpu_grid(8, 8).await;

        assert_eq!(grid.width, 8);
        assert_eq!(grid.height, 8);
        assert_eq!(grid.cell_count(), 64);
        assert_eq!(grid.backend(), Backend::Cpu);
        // One kernel actor is launched per cell.
        assert_eq!(grid.kernels.len(), 64);
    }

    #[tokio::test]
    async fn test_kernel_grid_boundaries() {
        let grid = cpu_grid(4, 4).await;

        let edge: [(u32, u32); 8] = [
            (0, 0),
            (3, 0),
            (0, 3),
            (3, 3),
            (1, 0),
            (0, 2),
            (3, 1),
            (2, 3),
        ];
        for &(x, y) in &edge {
            let c = cell(&grid, x, y);
            assert!(c.is_boundary, "({}, {}) should be a boundary cell", x, y);
            assert_eq!(c.reflection_coeff, 0.95);
        }

        let interior: [(u32, u32); 4] = [(1, 1), (1, 2), (2, 1), (2, 2)];
        for &(x, y) in &interior {
            let c = cell(&grid, x, y);
            assert!(!c.is_boundary, "({}, {}) should be an interior cell", x, y);
            assert_eq!(c.reflection_coeff, 1.0);
        }
    }

    #[tokio::test]
    async fn test_kernel_grid_impulse() {
        let mut grid = cpu_grid(4, 4).await;

        grid.inject_impulse(2, 2, 1.0);

        let pressure_grid = grid.get_pressure_grid();
        assert_eq!(pressure_grid[2][2], 1.0);
    }

    #[tokio::test]
    async fn test_kernel_grid_impulse_out_of_bounds_is_ignored() {
        let mut grid = cpu_grid(4, 4).await;
        grid.reset();

        grid.inject_impulse(4, 0, 1.0);
        grid.inject_impulse(0, 100, 1.0);

        assert_eq!(grid.total_energy(), 0.0);
    }

    #[tokio::test]
    async fn test_kernel_grid_metrics() {
        let mut grid = cpu_grid(4, 4).await;
        grid.reset();
        assert_eq!(grid.max_pressure(), 0.0);
        assert_eq!(grid.total_energy(), 0.0);

        grid.inject_impulse(2, 2, 1.0);

        assert_eq!(grid.max_pressure(), 1.0);
        assert_eq!(grid.total_energy(), 1.0);
    }

    #[tokio::test]
    async fn test_kernel_grid_step() {
        let mut grid = cpu_grid(8, 8).await;

        // Inject impulse at center
        grid.inject_impulse(4, 4, 1.0);

        // Run several steps
        for _ in 0..10 {
            grid.step().await.unwrap();
        }

        // Wave should have propagated to neighbors
        let pressure_grid = grid.get_pressure_grid();
        let neighbor_pressure = pressure_grid[4][5];
        assert!(
            neighbor_pressure.abs() > 0.0,
            "Wave should have propagated to neighbor"
        );
    }

    #[tokio::test]
    async fn test_kernel_grid_reset() {
        let mut grid = cpu_grid(4, 4).await;

        grid.inject_impulse(2, 2, 1.0);
        grid.step().await.unwrap();

        grid.reset();

        let pressure_grid = grid.get_pressure_grid();
        for row in pressure_grid {
            for p in row {
                assert_eq!(p, 0.0);
            }
        }
    }

    #[tokio::test]
    async fn test_kernel_grid_resize() {
        let mut grid = cpu_grid(4, 4).await;

        grid.resize(8, 6).await.unwrap();

        assert_eq!(grid.width, 8);
        assert_eq!(grid.height, 6);
        assert_eq!(grid.cell_count(), 48);
        assert_eq!(grid.kernels.len(), 48);
        // Boundary flags are recomputed for the new dimensions: the old 4x4
        // corner is now interior, while the new far corner is a boundary.
        assert!(!cell(&grid, 3, 3).is_boundary);
        assert!(cell(&grid, 7, 5).is_boundary);
    }

    #[tokio::test]
    async fn test_kernel_grid_switch_to_same_backend_is_noop() {
        let mut grid = cpu_grid(4, 4).await;
        let kernels_before = grid.kernels.len();

        grid.switch_backend(Backend::Cpu).await.unwrap();

        assert_eq!(grid.backend(), Backend::Cpu);
        assert_eq!(grid.kernels.len(), kernels_before);
    }

    #[tokio::test]
    async fn test_kernel_grid_debug_summary() {
        let grid = cpu_grid(4, 4).await;

        let rendered = format!("{:?}", grid);
        assert!(rendered.contains("kernel_count: 16"), "got: {}", rendered);
    }

    #[tokio::test]
    async fn test_kernel_grid_shutdown() {
        let grid = cpu_grid(4, 4).await;

        // Should shutdown cleanly
        grid.shutdown().await.unwrap();
    }
}