ringkernel_wavesim/simulation/tile_grid.rs

//! Tile-based GPU-native actor grid for wave simulation.
//!
//! This module implements a scalable actor-based simulation where each actor
//! manages a tile (e.g., 16x16) of cells. Tiles communicate via K2K messaging
//! to exchange boundary (halo) data, showcasing RingKernel's GPU-native actor model.
//!
//! ## Architecture
//!
//! ```text
//! +--------+--------+--------+
//! | Tile   | Tile   | Tile   |
//! | (0,0)  | (1,0)  | (2,0)  |
//! +--------+--------+--------+
//! | Tile   | Tile   | Tile   |
//! | (0,1)  | (1,1)  | (2,1)  |
//! +--------+--------+--------+
//! ```
//!
//! Each tile:
//! - Has a pressure buffer with a 1-cell halo for neighbor access
//! - Exchanges halo data with up to 4 neighbors via K2K messaging
//! - Computes FDTD for its interior cells in parallel (GPU or CPU)
//!
//! ## Halo Exchange Pattern
//!
//! ```text
//! +--+----------------+--+
//! |  | North Halo     |  |  <- Received from north neighbor
//! +--+----------------+--+
//! |W |                |E |
//! |  |   Interior     |  |
//! |  |    16x16       |  |
//! +--+----------------+--+
//! |  | South Halo     |  |  <- Received from south neighbor
//! +--+----------------+--+
//! ```

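//! ## Usage (illustrative)
//!
//! A minimal sketch of driving the grid from an async context, assuming the
//! relevant types are already in scope (paths abbreviated, hence `ignore`):
//!
//! ```rust,ignore
//! // Create a 64x64 grid split into 16x16 tile actors on the CPU backend.
//! let params = AcousticParams::new(343.0, 1.0);
//! let mut grid = TileKernelGrid::new(64, 64, params, Backend::Cpu).await?;
//!
//! // Kick off a wave and advance the simulation.
//! grid.inject_impulse(32, 32, 1.0);
//! for _ in 0..100 {
//!     grid.step().await?;
//! }
//!
//! // Read back the full pressure field for rendering.
//! let frame = grid.get_pressure_grid();
//! ```
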
use super::AcousticParams;
use ringkernel::prelude::*;
use ringkernel_core::hlc::HlcTimestamp;
use ringkernel_core::k2k::{DeliveryStatus, K2KBroker, K2KBuilder, K2KEndpoint};
use ringkernel_core::message::{MessageEnvelope, MessageHeader};
use ringkernel_core::runtime::KernelId;
use std::collections::HashMap;
use std::sync::Arc;

// Legacy GPU compute (upload/download per step)
#[cfg(feature = "wgpu")]
use super::gpu_compute::{init_wgpu, TileBuffers, TileGpuComputePool};

// New GPU-persistent backends using unified trait
#[cfg(feature = "cuda")]
use super::cuda_compute::CudaTileBackend;
#[cfg(any(feature = "wgpu", feature = "cuda"))]
use super::gpu_backend::{BoundaryCondition, Edge, FdtdParams, TileGpuBackend, TileGpuBuffers};
#[cfg(feature = "wgpu")]
use super::wgpu_compute::{WgpuBuffer, WgpuTileBackend};

/// Default tile size (16x16 cells per tile).
pub const DEFAULT_TILE_SIZE: u32 = 16;

/// Direction for halo exchange.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HaloDirection {
    North,
    South,
    East,
    West,
}

impl HaloDirection {
    /// Get the opposite direction.
    pub fn opposite(self) -> Self {
        match self {
            HaloDirection::North => HaloDirection::South,
            HaloDirection::South => HaloDirection::North,
            HaloDirection::East => HaloDirection::West,
            HaloDirection::West => HaloDirection::East,
        }
    }

    /// All directions.
    pub const ALL: [HaloDirection; 4] = [
        HaloDirection::North,
        HaloDirection::South,
        HaloDirection::East,
        HaloDirection::West,
    ];
}

/// Message type IDs for K2K communication.
const HALO_EXCHANGE_TYPE_ID: u64 = 200;

/// A single tile actor managing a region of the simulation grid.
pub struct TileActor {
    /// Tile position in tile coordinates.
    pub tile_x: u32,
    pub tile_y: u32,

    /// Size of the tile (cells per side).
    tile_size: u32,

    /// Buffer width including halos: tile_size + 2.
    buffer_width: usize,

    /// Current pressure values (buffer_width * buffer_width).
    /// Layout: row-major with 1-cell halo on each side.
    pressure: Vec<f32>,

    /// Previous pressure values for time-stepping.
    pressure_prev: Vec<f32>,

    /// Halo buffers for async exchange (received from neighbors).
    halo_north: Vec<f32>,
    halo_south: Vec<f32>,
    halo_east: Vec<f32>,
    halo_west: Vec<f32>,

    /// Flags indicating which halos have been received this step.
    halos_received: [bool; 4],

    /// Neighbor tile IDs (None if at grid boundary).
    neighbor_north: Option<KernelId>,
    neighbor_south: Option<KernelId>,
    neighbor_east: Option<KernelId>,
    neighbor_west: Option<KernelId>,

    /// K2K endpoint for this tile.
    endpoint: K2KEndpoint,

    /// Kernel ID for this tile (used for K2K routing).
    #[allow(dead_code)]
    kernel_id: KernelId,

    /// Is this tile on the grid boundary?
    is_north_boundary: bool,
    is_south_boundary: bool,
    is_east_boundary: bool,
    is_west_boundary: bool,

    /// Reflection coefficient for boundaries.
    reflection_coeff: f32,
}

impl TileActor {
    /// Create a new tile actor.
    pub fn new(
        tile_x: u32,
        tile_y: u32,
        tile_size: u32,
        tiles_x: u32,
        tiles_y: u32,
        broker: &Arc<K2KBroker>,
        reflection_coeff: f32,
    ) -> Self {
        let buffer_width = (tile_size + 2) as usize;
        let buffer_size = buffer_width * buffer_width;

        let kernel_id = Self::tile_kernel_id(tile_x, tile_y);
        let endpoint = broker.register(kernel_id.clone());

        // Determine neighbors
        let neighbor_north = if tile_y > 0 {
            Some(Self::tile_kernel_id(tile_x, tile_y - 1))
        } else {
            None
        };
        let neighbor_south = if tile_y < tiles_y - 1 {
            Some(Self::tile_kernel_id(tile_x, tile_y + 1))
        } else {
            None
        };
        let neighbor_west = if tile_x > 0 {
            Some(Self::tile_kernel_id(tile_x - 1, tile_y))
        } else {
            None
        };
        let neighbor_east = if tile_x < tiles_x - 1 {
            Some(Self::tile_kernel_id(tile_x + 1, tile_y))
        } else {
            None
        };

        Self {
            tile_x,
            tile_y,
            tile_size,
            buffer_width,
            pressure: vec![0.0; buffer_size],
            pressure_prev: vec![0.0; buffer_size],
            halo_north: vec![0.0; tile_size as usize],
            halo_south: vec![0.0; tile_size as usize],
            halo_east: vec![0.0; tile_size as usize],
            halo_west: vec![0.0; tile_size as usize],
            halos_received: [false; 4],
            neighbor_north,
            neighbor_south,
            neighbor_east,
            neighbor_west,
            endpoint,
            kernel_id,
            is_north_boundary: tile_y == 0,
            is_south_boundary: tile_y == tiles_y - 1,
            is_east_boundary: tile_x == tiles_x - 1,
            is_west_boundary: tile_x == 0,
            reflection_coeff,
        }
    }

    /// Generate kernel ID for a tile.
    pub fn tile_kernel_id(tile_x: u32, tile_y: u32) -> KernelId {
        KernelId::new(format!("tile_{}_{}", tile_x, tile_y))
    }

    /// Convert local tile coordinates to buffer index.
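    ///
    /// The 1-cell halo shifts everything by one: local `(0, 0)` maps to
    /// buffer index `buffer_width + 1` (19 for a 16x16 tile, whose
    /// `buffer_width` is 18).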
    #[inline(always)]
    fn buffer_idx(&self, local_x: usize, local_y: usize) -> usize {
        // Add 1 to account for halo
        (local_y + 1) * self.buffer_width + (local_x + 1)
    }

    /// Get pressure at local tile coordinates.
    pub fn get_pressure(&self, local_x: u32, local_y: u32) -> f32 {
        self.pressure[self.buffer_idx(local_x as usize, local_y as usize)]
    }

    /// Set pressure at local tile coordinates.
    pub fn set_pressure(&mut self, local_x: u32, local_y: u32, value: f32) {
        let idx = self.buffer_idx(local_x as usize, local_y as usize);
        self.pressure[idx] = value;
    }

    /// Extract edge data for sending to neighbors.
    fn extract_edge(&self, direction: HaloDirection) -> Vec<f32> {
        let size = self.tile_size as usize;
        let mut edge = vec![0.0; size];

        match direction {
            HaloDirection::North => {
                // First interior row (y = 0)
                for (x, cell) in edge.iter_mut().enumerate().take(size) {
                    *cell = self.pressure[self.buffer_idx(x, 0)];
                }
            }
            HaloDirection::South => {
                // Last interior row (y = size - 1)
                for (x, cell) in edge.iter_mut().enumerate().take(size) {
                    *cell = self.pressure[self.buffer_idx(x, size - 1)];
                }
            }
            HaloDirection::West => {
                // First interior column (x = 0)
                for (y, cell) in edge.iter_mut().enumerate().take(size) {
                    *cell = self.pressure[self.buffer_idx(0, y)];
                }
            }
            HaloDirection::East => {
                // Last interior column (x = size - 1)
                for (y, cell) in edge.iter_mut().enumerate().take(size) {
                    *cell = self.pressure[self.buffer_idx(size - 1, y)];
                }
            }
        }

        edge
    }

    /// Apply received halo data to the buffer.
    fn apply_halo(&mut self, direction: HaloDirection, data: &[f32]) {
        let size = self.tile_size as usize;
        let bw = self.buffer_width;

        match direction {
            HaloDirection::North => {
                // Top halo row (y = -1 in local coords, index row 0 in buffer)
                self.pressure[1..(size + 1)].copy_from_slice(&data[..size]);
            }
            HaloDirection::South => {
                // Bottom halo row (y = size in local coords)
                for (x, &val) in data.iter().enumerate().take(size) {
                    self.pressure[(size + 1) * bw + (x + 1)] = val;
                }
            }
            HaloDirection::West => {
                // Left halo column (x = -1 in local coords, index col 0 in buffer)
                for (y, &val) in data.iter().enumerate().take(size) {
                    self.pressure[(y + 1) * bw] = val;
                }
            }
            HaloDirection::East => {
                // Right halo column (x = size in local coords)
                for (y, &val) in data.iter().enumerate().take(size) {
                    self.pressure[(y + 1) * bw + (size + 1)] = val;
                }
            }
        }
    }

    /// Apply boundary reflection for edges without neighbors.
    fn apply_boundary_reflection(&mut self) {
        let size = self.tile_size as usize;
        let bw = self.buffer_width;
        let refl = self.reflection_coeff;

        if self.is_north_boundary {
            // Reflect interior row 0 to halo row
            for x in 0..size {
                let interior_idx = self.buffer_idx(x, 0);
                self.pressure[x + 1] = self.pressure[interior_idx] * refl;
            }
        }

        if self.is_south_boundary {
            // Reflect interior row (size-1) to halo row
            for x in 0..size {
                let interior_idx = self.buffer_idx(x, size - 1);
                self.pressure[(size + 1) * bw + (x + 1)] = self.pressure[interior_idx] * refl;
            }
        }

        if self.is_west_boundary {
            // Reflect interior col 0 to halo col
            for y in 0..size {
                let interior_idx = self.buffer_idx(0, y);
                self.pressure[(y + 1) * bw] = self.pressure[interior_idx] * refl;
            }
        }

        if self.is_east_boundary {
            // Reflect interior col (size-1) to halo col
            for y in 0..size {
                let interior_idx = self.buffer_idx(size - 1, y);
                self.pressure[(y + 1) * bw + (size + 1)] = self.pressure[interior_idx] * refl;
            }
        }
    }

    /// Send halo data to all neighbors via K2K.
    pub async fn send_halos(&self) -> Result<()> {
        for direction in HaloDirection::ALL {
            let neighbor = match direction {
                HaloDirection::North => &self.neighbor_north,
                HaloDirection::South => &self.neighbor_south,
                HaloDirection::East => &self.neighbor_east,
                HaloDirection::West => &self.neighbor_west,
            };

            if let Some(neighbor_id) = neighbor {
                let edge_data = self.extract_edge(direction);

                // Serialize halo data into an envelope, tagged with the
                // direction from the receiver's point of view.
                let envelope = Self::create_halo_envelope(direction.opposite(), &edge_data);

                let receipt = self.endpoint.send(neighbor_id.clone(), envelope).await?;
                if receipt.status != DeliveryStatus::Delivered {
                    tracing::warn!(
                        "Halo send failed: tile ({},{}) -> {:?}, status: {:?}",
                        self.tile_x,
                        self.tile_y,
                        neighbor_id,
                        receipt.status
                    );
                }
            }
        }

        Ok(())
    }

    /// Receive halo data from all neighbors via K2K.
    pub fn receive_halos(&mut self) {
        // Reset received flags
        self.halos_received = [false; 4];

        // Process all pending messages
        while let Some(message) = self.endpoint.try_receive() {
            if let Some((direction, data)) = Self::parse_halo_envelope(&message.envelope) {
                // Store in halo buffer
                match direction {
                    HaloDirection::North => {
                        self.halo_north.copy_from_slice(&data);
                        self.halos_received[0] = true;
                    }
                    HaloDirection::South => {
                        self.halo_south.copy_from_slice(&data);
                        self.halos_received[1] = true;
                    }
                    HaloDirection::East => {
                        self.halo_east.copy_from_slice(&data);
                        self.halos_received[2] = true;
                    }
                    HaloDirection::West => {
                        self.halo_west.copy_from_slice(&data);
                        self.halos_received[3] = true;
                    }
                }
            }
        }

        // Apply received halos to pressure buffer
        if self.halos_received[0] && self.neighbor_north.is_some() {
            self.apply_halo(HaloDirection::North, &self.halo_north.clone());
        }
        if self.halos_received[1] && self.neighbor_south.is_some() {
            self.apply_halo(HaloDirection::South, &self.halo_south.clone());
        }
        if self.halos_received[2] && self.neighbor_east.is_some() {
            self.apply_halo(HaloDirection::East, &self.halo_east.clone());
        }
        if self.halos_received[3] && self.neighbor_west.is_some() {
            self.apply_halo(HaloDirection::West, &self.halo_west.clone());
        }

        // Apply boundary reflection for edges without neighbors
        self.apply_boundary_reflection();
    }

    /// Create a K2K envelope containing halo data.
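    ///
    /// For example, sending `[1.0_f32]` tagged `North` yields the 5-byte
    /// payload `[0x00, 0x00, 0x00, 0x80, 0x3F]`: the direction byte followed
    /// by the little-endian f32.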
    fn create_halo_envelope(direction: HaloDirection, data: &[f32]) -> MessageEnvelope {
        // Pack direction and data into payload
        // Format: [direction_byte, f32 data...]
        let mut payload = Vec::with_capacity(1 + data.len() * 4);
        payload.push(direction as u8);
        for &value in data {
            payload.extend_from_slice(&value.to_le_bytes());
        }

        let header = MessageHeader::new(
            HALO_EXCHANGE_TYPE_ID,
            0, // source_kernel (host)
            0, // dest_kernel (will be set by K2K routing)
            payload.len(),
            HlcTimestamp::now(0),
        );

        MessageEnvelope { header, payload }
    }

    /// Parse a K2K envelope to extract halo data.
    fn parse_halo_envelope(envelope: &MessageEnvelope) -> Option<(HaloDirection, Vec<f32>)> {
        if envelope.header.message_type != HALO_EXCHANGE_TYPE_ID {
            return None;
        }

        if envelope.payload.is_empty() {
            return None;
        }

        let direction = match envelope.payload[0] {
            0 => HaloDirection::North,
            1 => HaloDirection::South,
            2 => HaloDirection::East,
            3 => HaloDirection::West,
            _ => return None,
        };

        let data: Vec<f32> = envelope.payload[1..]
            .chunks_exact(4)
            .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
            .collect();

        Some((direction, data))
    }

    /// Compute FDTD for all interior cells.
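    ///
    /// Each interior cell is updated with the standard second-order scheme,
    /// where `c2` is the squared Courant number (see the loop body below):
    ///
    /// ```text
    /// lap    = p_north + p_south + p_east + p_west - 4 * p
    /// p_next = (2 * p - p_prev + c2 * lap) * damping
    /// ```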
    pub fn compute_fdtd(&mut self, c2: f32, damping: f32) {
        let size = self.tile_size as usize;
        let bw = self.buffer_width;

        // Process all interior cells
        for local_y in 0..size {
            for local_x in 0..size {
                let idx = (local_y + 1) * bw + (local_x + 1);

                let p_curr = self.pressure[idx];
                let p_prev = self.pressure_prev[idx];

                // Neighbor access (halos provide boundary values)
                let p_north = self.pressure[idx - bw];
                let p_south = self.pressure[idx + bw];
                let p_west = self.pressure[idx - 1];
                let p_east = self.pressure[idx + 1];

                // FDTD computation
                let laplacian = p_north + p_south + p_east + p_west - 4.0 * p_curr;
                let p_new = 2.0 * p_curr - p_prev + c2 * laplacian;

                // Write into pressure_prev; swap_buffers() then makes it current.
                self.pressure_prev[idx] = p_new * damping;
            }
        }
    }

    /// Swap pressure buffers after FDTD step.
    pub fn swap_buffers(&mut self) {
        std::mem::swap(&mut self.pressure, &mut self.pressure_prev);
    }

    /// Reset tile to initial state.
    pub fn reset(&mut self) {
        self.pressure.fill(0.0);
        self.pressure_prev.fill(0.0);
    }
}

/// GPU compute resources for tile-based FDTD (wgpu feature only).
/// This is the legacy mode with upload/download per step.
#[cfg(feature = "wgpu")]
struct GpuComputeResources {
    /// Shared compute pool for all tiles.
    pool: TileGpuComputePool,
    /// Per-tile GPU buffers.
    tile_buffers: HashMap<(u32, u32), TileBuffers>,
}

/// GPU-persistent state using WGPU backend.
/// Pressure stays GPU-resident, only halos are transferred.
#[cfg(feature = "wgpu")]
struct WgpuPersistentState {
    /// WGPU backend.
    backend: WgpuTileBackend,
    /// Per-tile GPU buffers.
    tile_buffers: HashMap<(u32, u32), TileGpuBuffers<WgpuBuffer>>,
}

/// GPU-persistent state using CUDA backend.
/// Pressure stays GPU-resident, only halos are transferred.
#[cfg(feature = "cuda")]
struct CudaPersistentState {
    /// CUDA backend.
    backend: CudaTileBackend,
    /// Per-tile GPU buffers.
    tile_buffers: HashMap<(u32, u32), TileGpuBuffers<ringkernel_cuda::CudaBuffer>>,
}

/// Which GPU backend is active for persistent state.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GpuPersistentBackend {
    /// WGPU (WebGPU) backend.
    Wgpu,
    /// CUDA (NVIDIA) backend.
    Cuda,
}

/// A simulation grid using tile actors with K2K messaging.
pub struct TileKernelGrid {
    /// Grid dimensions in cells.
    pub width: u32,
    pub height: u32,

    /// Tile size (cells per tile side).
    tile_size: u32,

    /// Grid dimensions in tiles.
    tiles_x: u32,
    tiles_y: u32,

    /// Tile actors indexed by (tile_x, tile_y).
    tiles: HashMap<(u32, u32), TileActor>,

    /// K2K broker for inter-tile messaging.
    broker: Arc<K2KBroker>,

    /// Acoustic simulation parameters.
    pub params: AcousticParams,

    /// The backend being used.
    backend: Backend,

    /// RingKernel runtime (for kernel handle management).
    #[allow(dead_code)]
    runtime: Arc<RingKernel>,

    /// Legacy GPU compute pool for tile FDTD (optional, requires wgpu feature).
    /// This mode uploads/downloads full buffers each step.
    #[cfg(feature = "wgpu")]
    gpu_compute: Option<GpuComputeResources>,

    /// GPU-persistent WGPU state (optional).
    /// Pressure stays GPU-resident, only halos are transferred.
    #[cfg(feature = "wgpu")]
    wgpu_persistent: Option<WgpuPersistentState>,

    /// GPU-persistent CUDA state (optional).
    /// Pressure stays GPU-resident, only halos are transferred.
    #[cfg(feature = "cuda")]
    cuda_persistent: Option<CudaPersistentState>,
}

impl std::fmt::Debug for TileKernelGrid {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TileKernelGrid")
            .field("width", &self.width)
            .field("height", &self.height)
            .field("tile_size", &self.tile_size)
            .field("tiles_x", &self.tiles_x)
            .field("tiles_y", &self.tiles_y)
            .field("tile_count", &self.tiles.len())
            .field("backend", &self.backend)
            .finish()
    }
}

impl TileKernelGrid {
    /// Create a new tile-based kernel grid.
    pub async fn new(
        width: u32,
        height: u32,
        params: AcousticParams,
        backend: Backend,
    ) -> Result<Self> {
        Self::with_tile_size(width, height, params, backend, DEFAULT_TILE_SIZE).await
    }

    /// Create a new tile-based kernel grid with custom tile size.
    pub async fn with_tile_size(
        width: u32,
        height: u32,
        params: AcousticParams,
        backend: Backend,
        tile_size: u32,
    ) -> Result<Self> {
        let runtime = Arc::new(RingKernel::with_backend(backend).await?);

        // Calculate tile grid dimensions
        let tiles_x = width.div_ceil(tile_size);
        let tiles_y = height.div_ceil(tile_size);

        // Create K2K broker
        let broker = K2KBuilder::new()
            .max_pending_messages(tiles_x as usize * tiles_y as usize * 8)
            .build();

        // Create tile actors
        let mut tiles = HashMap::new();
        for ty in 0..tiles_y {
            for tx in 0..tiles_x {
                let tile = TileActor::new(tx, ty, tile_size, tiles_x, tiles_y, &broker, 0.95);
                tiles.insert((tx, ty), tile);
            }
        }

        tracing::info!(
            "Created TileKernelGrid: {}x{} cells, {}x{} tiles ({}x{} per tile), {} tile actors",
            width,
            height,
            tiles_x,
            tiles_y,
            tile_size,
            tile_size,
            tiles.len()
        );

        Ok(Self {
            width,
            height,
            tile_size,
            tiles_x,
            tiles_y,
            tiles,
            broker,
            params,
            backend,
            runtime,
            #[cfg(feature = "wgpu")]
            gpu_compute: None,
            #[cfg(feature = "wgpu")]
            wgpu_persistent: None,
            #[cfg(feature = "cuda")]
            cuda_persistent: None,
        })
    }

    /// Convert global cell coordinates to (tile_x, tile_y, local_x, local_y).
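    ///
    /// For example, with `tile_size = 16`, global `(37, 5)` maps to
    /// `(tile_x, tile_y, local_x, local_y) = (2, 0, 5, 5)`.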
    fn global_to_tile_coords(&self, x: u32, y: u32) -> (u32, u32, u32, u32) {
        let tile_x = x / self.tile_size;
        let tile_y = y / self.tile_size;
        let local_x = x % self.tile_size;
        let local_y = y % self.tile_size;
        (tile_x, tile_y, local_x, local_y)
    }

    /// Perform one simulation step using tile actors with K2K messaging.
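    ///
    /// Runs four phases in lockstep across all tiles: halo send (K2K),
    /// halo receive plus boundary reflection, interior FDTD, and buffer swap.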
    pub async fn step(&mut self) -> Result<()> {
        let c2 = self.params.courant_number().powi(2);
        let damping = 1.0 - self.params.damping;

        // Phase 1: All tiles send halo data to neighbors (K2K)
        for tile in self.tiles.values() {
            tile.send_halos().await?;
        }

        // Phase 2: All tiles receive halo data from neighbors (K2K)
        for tile in self.tiles.values_mut() {
            tile.receive_halos();
        }

        // Phase 3: All tiles compute FDTD for interior cells
        for tile in self.tiles.values_mut() {
            tile.compute_fdtd(c2, damping);
        }

        // Phase 4: All tiles swap buffers
        for tile in self.tiles.values_mut() {
            tile.swap_buffers();
        }

        Ok(())
    }

    /// Enable GPU compute for tile FDTD acceleration.
    ///
    /// This initializes WGPU resources and creates GPU buffers for each tile.
    /// Once enabled, use `step_gpu()` for GPU-accelerated simulation steps.
    #[cfg(feature = "wgpu")]
    pub async fn enable_gpu_compute(&mut self) -> Result<()> {
        let (device, queue) = init_wgpu().await?;
        let pool = TileGpuComputePool::new(device, queue, self.tile_size)?;

        // Create GPU buffers for each tile
        let mut tile_buffers = HashMap::new();
        for tile_coords in self.tiles.keys() {
            let buffers = pool.create_tile_buffers();
            tile_buffers.insert(*tile_coords, buffers);
        }

        let num_tiles = tile_buffers.len();
        self.gpu_compute = Some(GpuComputeResources { pool, tile_buffers });

        tracing::info!(
            "GPU compute enabled for TileKernelGrid: {} tiles with GPU buffers",
            num_tiles
        );

        Ok(())
    }

    /// Check if GPU compute is enabled.
    #[cfg(feature = "wgpu")]
    pub fn is_gpu_enabled(&self) -> bool {
        self.gpu_compute.is_some()
    }

    /// Perform one simulation step using GPU compute for FDTD.
    ///
    /// This is the hybrid approach:
    /// - K2K messaging handles halo exchange between tiles (actor model showcase)
    /// - GPU compute shaders handle FDTD for tile interiors (parallel compute)
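    ///
    /// Illustrative call sequence (sketch; paths abbreviated, hence `ignore`):
    ///
    /// ```rust,ignore
    /// grid.enable_gpu_compute().await?;
    /// assert!(grid.is_gpu_enabled());
    /// for _ in 0..10 {
    ///     grid.step_gpu().await?;
    /// }
    /// ```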
    #[cfg(feature = "wgpu")]
    pub async fn step_gpu(&mut self) -> Result<()> {
        let gpu = self.gpu_compute.as_ref().ok_or_else(|| {
            ringkernel_core::error::RingKernelError::BackendError(
                "GPU compute not enabled. Call enable_gpu_compute() first.".to_string(),
            )
        })?;

        let c2 = self.params.courant_number().powi(2);
        let damping = 1.0 - self.params.damping;

        // Phase 1: All tiles send halo data to neighbors (K2K)
        for tile in self.tiles.values() {
            tile.send_halos().await?;
        }

        // Phase 2: All tiles receive halo data from neighbors (K2K)
        for tile in self.tiles.values_mut() {
            tile.receive_halos();
        }

        // Phase 3: GPU compute FDTD for all tile interiors
        // Each tile dispatches to GPU and reads back results
        for ((tx, ty), tile) in self.tiles.iter_mut() {
            if let Some(buffers) = gpu.tile_buffers.get(&(*tx, *ty)) {
                // Get tile's pressure buffer (with halos already applied from K2K)
                let pressure = &tile.pressure;
                let pressure_prev = &tile.pressure_prev;

                // Dispatch GPU compute
                let result = gpu
                    .pool
                    .compute_fdtd(buffers, pressure, pressure_prev, c2, damping);

                // Copy results back to tile's pressure_prev buffer
                tile.pressure_prev.copy_from_slice(&result);
            } else {
                // Fallback to CPU if no GPU buffers (shouldn't happen)
                tile.compute_fdtd(c2, damping);
            }
        }

        // Phase 4: All tiles swap buffers
        for tile in self.tiles.values_mut() {
            tile.swap_buffers();
        }

        Ok(())
    }

    /// Enable GPU-persistent WGPU compute.
    ///
    /// This mode keeps pressure state GPU-resident and only transfers halos
    /// for K2K communication. It is much faster than the legacy mode, which
    /// uploads and downloads full buffers each step.
    #[cfg(feature = "wgpu")]
    pub async fn enable_wgpu_persistent(&mut self) -> Result<()> {
        let backend = WgpuTileBackend::new(self.tile_size).await?;

        // Create GPU buffers for each tile and upload initial state
        let mut tile_buffers = HashMap::new();
        for ((tx, ty), tile) in &self.tiles {
            let buffers = backend.create_tile_buffers(self.tile_size)?;
            // Upload initial state (all zeros typically)
            backend.upload_initial_state(&buffers, &tile.pressure, &tile.pressure_prev)?;
            tile_buffers.insert((*tx, *ty), buffers);
        }

        let num_tiles = tile_buffers.len();
        self.wgpu_persistent = Some(WgpuPersistentState {
            backend,
            tile_buffers,
        });

        tracing::info!(
            "WGPU GPU-persistent compute enabled for TileKernelGrid: {} tiles",
            num_tiles
        );

        Ok(())
    }

    /// Enable GPU-persistent CUDA compute.
    ///
    /// This mode keeps pressure state GPU-resident and only transfers halos
    /// for K2K communication. It is much faster than uploading and downloading
    /// full buffers each step.
    #[cfg(feature = "cuda")]
    pub fn enable_cuda_persistent(&mut self) -> Result<()> {
        let backend = CudaTileBackend::new(self.tile_size)?;

        // Create GPU buffers for each tile and upload initial state
        let mut tile_buffers = HashMap::new();
        for ((tx, ty), tile) in &self.tiles {
            let buffers = backend.create_tile_buffers(self.tile_size)?;
            // Upload initial state
            backend.upload_initial_state(&buffers, &tile.pressure, &tile.pressure_prev)?;
            tile_buffers.insert((*tx, *ty), buffers);
        }

        let num_tiles = tile_buffers.len();
        self.cuda_persistent = Some(CudaPersistentState {
            backend,
            tile_buffers,
        });

        tracing::info!(
            "CUDA GPU-persistent compute enabled for TileKernelGrid: {} tiles",
            num_tiles
        );

        Ok(())
    }

    /// Check if GPU-persistent mode is enabled.
    pub fn is_gpu_persistent_enabled(&self) -> bool {
        #[cfg(feature = "wgpu")]
        if self.wgpu_persistent.is_some() {
            return true;
        }
        #[cfg(feature = "cuda")]
        if self.cuda_persistent.is_some() {
            return true;
        }
        false
    }

    /// Get which GPU-persistent backend is active.
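    ///
    /// CUDA is checked first, so it takes precedence when both persistent
    /// backends happen to be enabled.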
    pub fn gpu_persistent_backend(&self) -> Option<GpuPersistentBackend> {
        #[cfg(feature = "cuda")]
        if self.cuda_persistent.is_some() {
            return Some(GpuPersistentBackend::Cuda);
        }
        #[cfg(feature = "wgpu")]
        if self.wgpu_persistent.is_some() {
            return Some(GpuPersistentBackend::Wgpu);
        }
        None
    }

    /// Perform one simulation step using GPU-persistent WGPU compute.
    ///
    /// State stays GPU-resident; only halo rows and columns cross the
    /// host-device boundary each step.
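    ///
    /// A minimal driving loop (sketch; paths abbreviated, hence `ignore`):
    ///
    /// ```rust,ignore
    /// grid.enable_wgpu_persistent().await?;
    /// grid.inject_impulse_wgpu(32, 32, 1.0)?;
    /// for _ in 0..100 {
    ///     grid.step_wgpu_persistent()?;
    /// }
    /// let frame = grid.read_pressure_from_wgpu()?;
    /// ```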
    #[cfg(feature = "wgpu")]
    pub fn step_wgpu_persistent(&mut self) -> Result<()> {
        let state = self.wgpu_persistent.as_ref().ok_or_else(|| {
            ringkernel_core::error::RingKernelError::BackendError(
                "WGPU persistent compute not enabled. Call enable_wgpu_persistent() first."
                    .to_string(),
            )
        })?;

        let c2 = self.params.courant_number().powi(2);
        let damping = 1.0 - self.params.damping;
        let fdtd_params = FdtdParams::new(self.tile_size, c2, damping);

        // Phase 1: Extract halos from GPU and build halo messages
        // (routed directly to destination tiles in Phase 2)
        let mut halo_messages: Vec<(KernelId, HaloDirection, Vec<f32>)> = Vec::new();

        for (tx, ty) in self.tiles.keys() {
            if let Some(buffers) = state.tile_buffers.get(&(*tx, *ty)) {
                // Extract halos based on tile's neighbors
                let tile = self.tiles.get(&(*tx, *ty)).unwrap();

                if tile.neighbor_north.is_some() {
                    let halo = state.backend.extract_halo(buffers, Edge::North)?;
                    halo_messages.push((
                        tile.neighbor_north.clone().unwrap(),
                        HaloDirection::South,
                        halo,
                    ));
                }
                if tile.neighbor_south.is_some() {
                    let halo = state.backend.extract_halo(buffers, Edge::South)?;
                    halo_messages.push((
                        tile.neighbor_south.clone().unwrap(),
                        HaloDirection::North,
                        halo,
                    ));
                }
                if tile.neighbor_west.is_some() {
                    let halo = state.backend.extract_halo(buffers, Edge::West)?;
                    halo_messages.push((
                        tile.neighbor_west.clone().unwrap(),
                        HaloDirection::East,
                        halo,
                    ));
                }
                if tile.neighbor_east.is_some() {
                    let halo = state.backend.extract_halo(buffers, Edge::East)?;
                    halo_messages.push((
                        tile.neighbor_east.clone().unwrap(),
                        HaloDirection::West,
                        halo,
                    ));
                }
            }
        }

        // Phase 2: Inject halos into GPU buffers
        // Route each message to its destination tile by kernel ID
        for (dest_id, direction, halo_data) in halo_messages {
            // Find the tile coords from kernel ID
            for (tx, ty) in self.tiles.keys() {
                if TileActor::tile_kernel_id(*tx, *ty) == dest_id {
                    if let Some(buffers) = state.tile_buffers.get(&(*tx, *ty)) {
                        let edge = match direction {
                            HaloDirection::North => Edge::North,
                            HaloDirection::South => Edge::South,
                            HaloDirection::East => Edge::East,
                            HaloDirection::West => Edge::West,
                        };
                        state.backend.inject_halo(buffers, edge, &halo_data)?;
                    }
                    break;
                }
            }
        }

        // Phase 3: Apply boundary conditions for domain edge tiles
        // Tiles without neighbors on certain edges need boundary conditions applied
        for (tx, ty) in self.tiles.keys() {
            let tile = self.tiles.get(&(*tx, *ty)).unwrap();
            if let Some(buffers) = state.tile_buffers.get(&(*tx, *ty)) {
                // Apply absorbing boundary for edges without neighbors
                if tile.neighbor_north.is_none() {
                    state.backend.apply_boundary(
                        buffers,
                        Edge::North,
                        BoundaryCondition::Absorbing,
                    )?;
                }
                if tile.neighbor_south.is_none() {
                    state.backend.apply_boundary(
                        buffers,
                        Edge::South,
                        BoundaryCondition::Absorbing,
                    )?;
                }
                if tile.neighbor_west.is_none() {
                    state.backend.apply_boundary(
                        buffers,
                        Edge::West,
                        BoundaryCondition::Absorbing,
                    )?;
                }
                if tile.neighbor_east.is_none() {
                    state.backend.apply_boundary(
                        buffers,
                        Edge::East,
                        BoundaryCondition::Absorbing,
                    )?;
                }
            }
        }

        // Phase 4: Compute FDTD on GPU for all tiles
        for (tx, ty) in self.tiles.keys() {
            if let Some(buffers) = state.tile_buffers.get(&(*tx, *ty)) {
                state.backend.fdtd_step(buffers, &fdtd_params)?;
            }
        }

        // Synchronize
        state.backend.synchronize()?;

        // Phase 5: Swap buffers (pointer swap only, no data movement)
        // NOTE: We need mutable access to state for this
        let state = self.wgpu_persistent.as_mut().unwrap();
        for buffers in state.tile_buffers.values_mut() {
            state.backend.swap_buffers(buffers);
        }

        Ok(())
    }

    /// Perform one simulation step using GPU-persistent CUDA compute.
    ///
    /// State stays GPU-resident; only halo rows and columns cross the
    /// host-device boundary each step.
    #[cfg(feature = "cuda")]
    pub fn step_cuda_persistent(&mut self) -> Result<()> {
        let state = self.cuda_persistent.as_ref().ok_or_else(|| {
            ringkernel_core::error::RingKernelError::BackendError(
                "CUDA persistent compute not enabled. Call enable_cuda_persistent() first."
                    .to_string(),
            )
        })?;

        let c2 = self.params.courant_number().powi(2);
        let damping = 1.0 - self.params.damping;
        let fdtd_params = FdtdParams::new(self.tile_size, c2, damping);

        // Phase 1: Extract halos from GPU and build messages
        let mut halo_messages: Vec<(KernelId, HaloDirection, Vec<f32>)> = Vec::new();

        for (tx, ty) in self.tiles.keys() {
            if let Some(buffers) = state.tile_buffers.get(&(*tx, *ty)) {
                let tile = self.tiles.get(&(*tx, *ty)).unwrap();

                if tile.neighbor_north.is_some() {
                    let halo = state.backend.extract_halo(buffers, Edge::North)?;
                    halo_messages.push((
                        tile.neighbor_north.clone().unwrap(),
                        HaloDirection::South,
                        halo,
                    ));
                }
                if tile.neighbor_south.is_some() {
                    let halo = state.backend.extract_halo(buffers, Edge::South)?;
                    halo_messages.push((
                        tile.neighbor_south.clone().unwrap(),
                        HaloDirection::North,
                        halo,
                    ));
                }
                if tile.neighbor_west.is_some() {
                    let halo = state.backend.extract_halo(buffers, Edge::West)?;
                    halo_messages.push((
                        tile.neighbor_west.clone().unwrap(),
                        HaloDirection::East,
                        halo,
                    ));
                }
                if tile.neighbor_east.is_some() {
                    let halo = state.backend.extract_halo(buffers, Edge::East)?;
                    halo_messages.push((
                        tile.neighbor_east.clone().unwrap(),
                        HaloDirection::West,
                        halo,
                    ));
                }
            }
        }

        // Phase 2: Inject halos into GPU buffers
        for (dest_id, direction, halo_data) in halo_messages {
            for (tx, ty) in self.tiles.keys() {
                if TileActor::tile_kernel_id(*tx, *ty) == dest_id {
                    if let Some(buffers) = state.tile_buffers.get(&(*tx, *ty)) {
                        let edge = match direction {
                            HaloDirection::North => Edge::North,
                            HaloDirection::South => Edge::South,
                            HaloDirection::East => Edge::East,
                            HaloDirection::West => Edge::West,
                        };
                        state.backend.inject_halo(buffers, edge, &halo_data)?;
                    }
                    break;
                }
            }
        }

        // Phase 3: Compute FDTD on GPU for all tiles
        for (tx, ty) in self.tiles.keys() {
            if let Some(buffers) = state.tile_buffers.get(&(*tx, *ty)) {
                state.backend.fdtd_step(buffers, &fdtd_params)?;
            }
        }

        // Synchronize
        state.backend.synchronize()?;

        // Phase 4: Swap buffers
        let state = self.cuda_persistent.as_mut().unwrap();
        for buffers in state.tile_buffers.values_mut() {
            state.backend.swap_buffers(buffers);
        }

        Ok(())
    }

    /// Read pressure grid from GPU for visualization.
    ///
    /// This is the only time we transfer full tile data from GPU to host.
    /// Call this only when you need to render the simulation.
    #[cfg(feature = "wgpu")]
    pub fn read_pressure_from_wgpu(&self) -> Result<Vec<Vec<f32>>> {
        let state = self.wgpu_persistent.as_ref().ok_or_else(|| {
            ringkernel_core::error::RingKernelError::BackendError(
                "WGPU persistent compute not enabled.".to_string(),
            )
        })?;

        let mut grid = vec![vec![0.0; self.width as usize]; self.height as usize];

        for ((tx, ty), buffers) in &state.tile_buffers {
            let interior = state.backend.read_interior_pressure(buffers)?;

            // Copy tile interior to grid
            let tile_start_x = tx * self.tile_size;
            let tile_start_y = ty * self.tile_size;

            for ly in 0..self.tile_size {
                for lx in 0..self.tile_size {
                    let gx = tile_start_x + lx;
                    let gy = tile_start_y + ly;
                    if gx < self.width && gy < self.height {
                        grid[gy as usize][gx as usize] =
                            interior[(ly * self.tile_size + lx) as usize];
                    }
                }
            }
        }

        Ok(grid)
    }

    /// Read pressure grid from CUDA for visualization.
    #[cfg(feature = "cuda")]
    pub fn read_pressure_from_cuda(&self) -> Result<Vec<Vec<f32>>> {
        let state = self.cuda_persistent.as_ref().ok_or_else(|| {
            ringkernel_core::error::RingKernelError::BackendError(
                "CUDA persistent compute not enabled.".to_string(),
            )
        })?;

        let mut grid = vec![vec![0.0; self.width as usize]; self.height as usize];

        for ((tx, ty), buffers) in &state.tile_buffers {
            let interior = state.backend.read_interior_pressure(buffers)?;

            let tile_start_x = tx * self.tile_size;
            let tile_start_y = ty * self.tile_size;

            for ly in 0..self.tile_size {
                for lx in 0..self.tile_size {
                    let gx = tile_start_x + lx;
                    let gy = tile_start_y + ly;
                    if gx < self.width && gy < self.height {
                        grid[gy as usize][gx as usize] =
                            interior[(ly * self.tile_size + lx) as usize];
                    }
                }
            }
        }

        Ok(grid)
    }

    /// Upload a pressure impulse to GPU-persistent state.
    ///
    /// For GPU-persistent mode, we need to upload the impulse to the GPU.
    #[cfg(feature = "wgpu")]
    pub fn inject_impulse_wgpu(&mut self, x: u32, y: u32, amplitude: f32) -> Result<()> {
        if x >= self.width || y >= self.height {
            return Ok(());
        }

        let (tile_x, tile_y, local_x, local_y) = self.global_to_tile_coords(x, y);

        // First, update the CPU tile (for consistency)
        if let Some(tile) = self.tiles.get_mut(&(tile_x, tile_y)) {
            let current = tile.get_pressure(local_x, local_y);
            tile.set_pressure(local_x, local_y, current + amplitude);

            // Then upload to GPU
            if let Some(state) = &self.wgpu_persistent {
                if let Some(buffers) = state.tile_buffers.get(&(tile_x, tile_y)) {
                    state.backend.upload_initial_state(
                        buffers,
                        &tile.pressure,
                        &tile.pressure_prev,
                    )?;
                }
            }
        }

        Ok(())
    }

    /// Upload a pressure impulse to CUDA GPU-persistent state.
    #[cfg(feature = "cuda")]
    pub fn inject_impulse_cuda(&mut self, x: u32, y: u32, amplitude: f32) -> Result<()> {
        if x >= self.width || y >= self.height {
            return Ok(());
        }

        let (tile_x, tile_y, local_x, local_y) = self.global_to_tile_coords(x, y);

        if let Some(tile) = self.tiles.get_mut(&(tile_x, tile_y)) {
            let current = tile.get_pressure(local_x, local_y);
            tile.set_pressure(local_x, local_y, current + amplitude);

            if let Some(state) = &self.cuda_persistent {
                if let Some(buffers) = state.tile_buffers.get(&(tile_x, tile_y)) {
                    state.backend.upload_initial_state(
                        buffers,
                        &tile.pressure,
                        &tile.pressure_prev,
                    )?;
                }
            }
        }

        Ok(())
    }

    /// Inject an impulse at the given global grid position.
    pub fn inject_impulse(&mut self, x: u32, y: u32, amplitude: f32) {
        if x >= self.width || y >= self.height {
            return;
        }

        let (tile_x, tile_y, local_x, local_y) = self.global_to_tile_coords(x, y);

        if let Some(tile) = self.tiles.get_mut(&(tile_x, tile_y)) {
            let current = tile.get_pressure(local_x, local_y);
            tile.set_pressure(local_x, local_y, current + amplitude);
        }
    }

    /// Get the pressure grid for visualization.
    pub fn get_pressure_grid(&self) -> Vec<Vec<f32>> {
        let mut grid = vec![vec![0.0; self.width as usize]; self.height as usize];

        for y in 0..self.height {
            for x in 0..self.width {
                let (tile_x, tile_y, local_x, local_y) = self.global_to_tile_coords(x, y);

                if let Some(tile) = self.tiles.get(&(tile_x, tile_y)) {
                    // Handle tiles at the edge that may have fewer cells
                    if local_x < self.tile_size && local_y < self.tile_size {
                        grid[y as usize][x as usize] = tile.get_pressure(local_x, local_y);
                    }
                }
            }
        }

        grid
    }

    /// Get the maximum absolute pressure in the grid.
    pub fn max_pressure(&self) -> f32 {
        self.tiles
            .values()
            .flat_map(|tile| {
                (0..self.tile_size).flat_map(move |y| {
                    (0..self.tile_size).map(move |x| tile.get_pressure(x, y).abs())
                })
            })
            .fold(0.0, f32::max)
    }

    /// Get total energy in the system.
    pub fn total_energy(&self) -> f32 {
        self.tiles
            .values()
            .flat_map(|tile| {
                (0..self.tile_size).flat_map(move |y| {
                    (0..self.tile_size).map(move |x| {
                        let p = tile.get_pressure(x, y);
                        p * p
                    })
                })
            })
            .sum()
    }

    /// Reset all tiles to initial state.
    pub fn reset(&mut self) {
        for tile in self.tiles.values_mut() {
            tile.reset();
        }
    }

    /// Get the number of cells.
    pub fn cell_count(&self) -> usize {
        (self.width * self.height) as usize
    }

    /// Get the number of tile actors.
    pub fn tile_count(&self) -> usize {
        self.tiles.len()
    }

    /// Get the current backend.
    pub fn backend(&self) -> Backend {
        self.backend
    }

    /// Get K2K messaging statistics.
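    ///
    /// Useful for verifying halo traffic: with `step()`, each tile sends one
    /// message per existing neighbor, so a 4x4 tile grid exchanges 48 halo
    /// messages per step (24 neighbor pairs, both directions).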
    pub fn k2k_stats(&self) -> ringkernel_core::k2k::K2KStats {
        self.broker.stats()
    }

    /// Update the speed of sound.
    pub fn set_speed_of_sound(&mut self, speed: f32) {
        self.params.set_speed_of_sound(speed);
    }

    /// Update the cell size.
    pub fn set_cell_size(&mut self, size: f32) {
        self.params.set_cell_size(size);
    }

    /// Resize the grid.
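    ///
    /// This recreates the K2K broker and all tile actors, so any existing
    /// pressure state is discarded.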
    pub async fn resize(&mut self, new_width: u32, new_height: u32) -> Result<()> {
        self.width = new_width;
        self.height = new_height;

        // Recalculate tile grid dimensions
        self.tiles_x = new_width.div_ceil(self.tile_size);
        self.tiles_y = new_height.div_ceil(self.tile_size);

        // Create new K2K broker
        self.broker = K2KBuilder::new()
            .max_pending_messages(self.tiles_x as usize * self.tiles_y as usize * 8)
            .build();

        // Recreate tile actors
        self.tiles.clear();
        for ty in 0..self.tiles_y {
            for tx in 0..self.tiles_x {
                let tile = TileActor::new(
                    tx,
                    ty,
                    self.tile_size,
                    self.tiles_x,
                    self.tiles_y,
                    &self.broker,
                    0.95,
                );
                self.tiles.insert((tx, ty), tile);
            }
        }

        tracing::info!(
            "Resized TileKernelGrid: {}x{} cells, {} tile actors",
            new_width,
            new_height,
            self.tiles.len()
        );

        Ok(())
    }

    /// Shut down the grid.
    pub async fn shutdown(self) -> Result<()> {
        // K2K broker and tiles are dropped automatically
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_tile_grid_creation() {
        let params = AcousticParams::new(343.0, 1.0);
        let grid = TileKernelGrid::new(64, 64, params, Backend::Cpu)
            .await
            .unwrap();

        assert_eq!(grid.width, 64);
        assert_eq!(grid.height, 64);
        assert_eq!(grid.tile_size, DEFAULT_TILE_SIZE);
        assert_eq!(grid.tiles_x, 4); // 64 / 16 = 4
        assert_eq!(grid.tiles_y, 4);
        assert_eq!(grid.tile_count(), 16); // 4 * 4 = 16 tiles
    }

    #[tokio::test]
    async fn test_tile_grid_impulse() {
        let params = AcousticParams::new(343.0, 1.0);
        let mut grid = TileKernelGrid::with_tile_size(32, 32, params, Backend::Cpu, 16)
            .await
            .unwrap();

        grid.inject_impulse(16, 16, 1.0);

        let pressure_grid = grid.get_pressure_grid();
        assert_eq!(pressure_grid[16][16], 1.0);
    }

    #[tokio::test]
    async fn test_tile_grid_step() {
        let params = AcousticParams::new(343.0, 1.0);
        let mut grid = TileKernelGrid::with_tile_size(32, 32, params, Backend::Cpu, 16)
            .await
            .unwrap();

        // Inject impulse at center
        grid.inject_impulse(16, 16, 1.0);

        // Run several steps
        for _ in 0..10 {
            grid.step().await.unwrap();
        }

        // Wave should have propagated
        let pressure_grid = grid.get_pressure_grid();
        let neighbor_pressure = pressure_grid[16][17];
        assert!(
            neighbor_pressure.abs() > 0.0,
            "Wave should have propagated to neighbor"
        );
    }

    #[tokio::test]
    async fn test_tile_grid_k2k_stats() {
        let params = AcousticParams::new(343.0, 1.0);
        let mut grid = TileKernelGrid::with_tile_size(32, 32, params, Backend::Cpu, 16)
            .await
            .unwrap();

        grid.inject_impulse(16, 16, 1.0);
        grid.step().await.unwrap();

        let stats = grid.k2k_stats();
        assert!(
            stats.messages_delivered > 0,
            "K2K messages should have been exchanged"
        );
    }

    #[tokio::test]
    async fn test_tile_grid_reset() {
        let params = AcousticParams::new(343.0, 1.0);
        let mut grid = TileKernelGrid::with_tile_size(32, 32, params, Backend::Cpu, 16)
            .await
            .unwrap();

        grid.inject_impulse(16, 16, 1.0);
        grid.step().await.unwrap();

        grid.reset();

        let pressure_grid = grid.get_pressure_grid();
        for row in pressure_grid {
            for p in row {
                assert_eq!(p, 0.0);
            }
        }
    }

    #[tokio::test]
    async fn test_tile_boundary_handling() {
        let params = AcousticParams::new(343.0, 1.0);
        let mut grid = TileKernelGrid::with_tile_size(32, 32, params, Backend::Cpu, 16)
            .await
            .unwrap();

        // Inject near boundary
        grid.inject_impulse(1, 1, 1.0);

        // Run several steps
        for _ in 0..5 {
            grid.step().await.unwrap();
        }

        // Should not crash and energy should be bounded
        let energy = grid.total_energy();
        assert!(energy.is_finite(), "Energy should be finite");
    }

    #[cfg(feature = "wgpu")]
    #[tokio::test]
    #[ignore] // May not have GPU
    async fn test_tile_grid_gpu_step() {
        let params = AcousticParams::new(343.0, 1.0);
        let mut grid = TileKernelGrid::with_tile_size(32, 32, params, Backend::Cpu, 16)
            .await
            .unwrap();

        // Enable GPU compute
        grid.enable_gpu_compute().await.unwrap();
        assert!(grid.is_gpu_enabled());

        // Inject impulse at center
        grid.inject_impulse(16, 16, 1.0);

        // Run several steps using GPU
        for _ in 0..10 {
            grid.step_gpu().await.unwrap();
        }

        // Wave should have propagated
        let pressure_grid = grid.get_pressure_grid();
        let neighbor_pressure = pressure_grid[16][17];
        assert!(
            neighbor_pressure.abs() > 0.0,
            "Wave should have propagated to neighbor (GPU compute)"
        );
    }

    #[cfg(feature = "wgpu")]
    #[tokio::test]
    #[ignore] // May not have GPU
    async fn test_tile_grid_gpu_matches_cpu() {
        let params = AcousticParams::new(343.0, 1.0);

        // Create two grids - one CPU, one GPU
        let mut grid_cpu = TileKernelGrid::with_tile_size(32, 32, params.clone(), Backend::Cpu, 16)
            .await
            .unwrap();
        let mut grid_gpu = TileKernelGrid::with_tile_size(32, 32, params, Backend::Cpu, 16)
            .await
            .unwrap();

        grid_gpu.enable_gpu_compute().await.unwrap();

        // Inject same impulse
        grid_cpu.inject_impulse(16, 16, 1.0);
        grid_gpu.inject_impulse(16, 16, 1.0);

        // Run same number of steps
        for _ in 0..5 {
            grid_cpu.step().await.unwrap();
            grid_gpu.step_gpu().await.unwrap();
        }

        // Results should match closely
        let cpu_grid = grid_cpu.get_pressure_grid();
        let gpu_grid = grid_gpu.get_pressure_grid();

        for y in 0..32 {
            for x in 0..32 {
                let diff = (cpu_grid[y][x] - gpu_grid[y][x]).abs();
                assert!(
                    diff < 1e-4,
                    "CPU/GPU mismatch at ({},{}): cpu={}, gpu={}, diff={}",
                    x,
                    y,
                    cpu_grid[y][x],
                    gpu_grid[y][x],
                    diff
                );
            }
        }
    }
}
1579}