// ringkernel_wavesim/simulation/tile_grid.rs

1//! Tile-based GPU-native actor grid for wave simulation.
2//!
3//! This module implements a scalable actor-based simulation where each actor
4//! manages a tile (e.g., 16x16) of cells. Tiles communicate via K2K messaging
5//! to exchange boundary (halo) data, showcasing RingKernel's GPU-native actor model.
6//!
7//! ## Architecture
8//!
9//! ```text
10//! +--------+--------+--------+
11//! | Tile   | Tile   | Tile   |
12//! | (0,0)  | (1,0)  | (2,0)  |
13//! +--------+--------+--------+
14//! | Tile   | Tile   | Tile   |
15//! | (0,1)  | (1,1)  | (2,1)  |
16//! +--------+--------+--------+
17//! ```
18//!
19//! Each tile:
20//! - Has a pressure buffer with 1-cell halo for neighbor access
21//! - Exchanges halo data with 4 neighbors via K2K messaging
22//! - Computes FDTD for its interior cells in parallel (GPU or CPU)
23//!
24//! ## Halo Exchange Pattern
25//!
26//! ```text
27//! +--+----------------+--+
28//! |  | North Halo     |  |  <- Received from north neighbor
29//! +--+----------------+--+
30//! |W |                |E |
31//! |  |   Interior     |  |
32//! |  |    16x16       |  |
33//! +--+----------------+--+
34//! |  | South Halo     |  |  <- Received from south neighbor
35//! +--+----------------+--+
36//! ```
37
38use super::AcousticParams;
39use ringkernel::prelude::*;
40use ringkernel_core::hlc::HlcTimestamp;
41use ringkernel_core::k2k::{DeliveryStatus, K2KBroker, K2KBuilder, K2KEndpoint};
42use ringkernel_core::message::{MessageEnvelope, MessageHeader};
43use ringkernel_core::runtime::KernelId;
44use std::collections::HashMap;
45use std::sync::Arc;
46
47// Legacy GPU compute (upload/download per step)
48#[cfg(feature = "wgpu")]
49use super::gpu_compute::{init_wgpu, TileBuffers, TileGpuComputePool};
50
51// New GPU-persistent backends using unified trait
52#[cfg(feature = "cuda")]
53use super::cuda_compute::CudaTileBackend;
54#[cfg(any(feature = "wgpu", feature = "cuda"))]
55use super::gpu_backend::{Edge, FdtdParams, TileGpuBackend, TileGpuBuffers};
56#[cfg(feature = "wgpu")]
57use super::wgpu_compute::{WgpuBuffer, WgpuTileBackend};
58
/// Default tile size (16x16 cells per tile).
/// Used by `TileKernelGrid::new` when no explicit tile size is given.
pub const DEFAULT_TILE_SIZE: u32 = 16;
61
/// Direction for halo exchange.
///
/// NOTE: declaration order is part of the wire format — halo envelopes encode
/// the direction as `direction as u8` (North=0, South=1, East=2, West=3), and
/// `parse_halo_envelope` decodes the same mapping. Do not reorder variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HaloDirection {
    North,
    South,
    East,
    West,
}

impl HaloDirection {
    /// Get the opposite direction (North<->South, East<->West).
    pub fn opposite(self) -> Self {
        use HaloDirection::*;
        match self {
            North => South,
            South => North,
            East => West,
            West => East,
        }
    }

    /// All directions, in wire-discriminant order.
    pub const ALL: [HaloDirection; 4] = [
        HaloDirection::North,
        HaloDirection::South,
        HaloDirection::East,
        HaloDirection::West,
    ];
}
90
/// Message type IDs for K2K communication.
/// Halo-exchange envelopes carry this `message_type` so receivers can
/// discard unrelated traffic (see `parse_halo_envelope`).
const HALO_EXCHANGE_TYPE_ID: u64 = 200;
93
/// A single tile actor managing a region of the simulation grid.
pub struct TileActor {
    /// Tile position in tile coordinates.
    pub tile_x: u32,
    pub tile_y: u32,

    /// Size of the tile (cells per side).
    tile_size: u32,

    /// Buffer width including halos: tile_size + 2.
    buffer_width: usize,

    /// Current pressure values (buffer_width * buffer_width).
    /// Layout: row-major with 1-cell halo on each side.
    pressure: Vec<f32>,

    /// Previous pressure values for time-stepping.
    /// `compute_fdtd` writes the t+1 field here; `swap_buffers` promotes it.
    pressure_prev: Vec<f32>,

    /// Halo buffers for async exchange (received from neighbors).
    /// Each holds `tile_size` samples for one edge.
    halo_north: Vec<f32>,
    halo_south: Vec<f32>,
    halo_east: Vec<f32>,
    halo_west: Vec<f32>,

    /// Flags indicating which halos have been received this step.
    /// Index order: [north, south, east, west].
    halos_received: [bool; 4],

    /// Neighbor tile IDs (None if at grid boundary).
    neighbor_north: Option<KernelId>,
    neighbor_south: Option<KernelId>,
    neighbor_east: Option<KernelId>,
    neighbor_west: Option<KernelId>,

    /// K2K endpoint for this tile.
    endpoint: K2KEndpoint,

    /// Kernel ID for this tile (used for K2K routing).
    #[allow(dead_code)]
    kernel_id: KernelId,

    /// Is this tile on the grid boundary?
    is_north_boundary: bool,
    is_south_boundary: bool,
    is_east_boundary: bool,
    is_west_boundary: bool,

    /// Reflection coefficient for boundaries.
    /// Applied by `apply_boundary_reflection` on edges without neighbors.
    reflection_coeff: f32,
}
144
145impl TileActor {
146    /// Create a new tile actor.
147    pub fn new(
148        tile_x: u32,
149        tile_y: u32,
150        tile_size: u32,
151        tiles_x: u32,
152        tiles_y: u32,
153        broker: &Arc<K2KBroker>,
154        reflection_coeff: f32,
155    ) -> Self {
156        let buffer_width = (tile_size + 2) as usize;
157        let buffer_size = buffer_width * buffer_width;
158
159        let kernel_id = Self::tile_kernel_id(tile_x, tile_y);
160        let endpoint = broker.register(kernel_id.clone());
161
162        // Determine neighbors
163        let neighbor_north = if tile_y > 0 {
164            Some(Self::tile_kernel_id(tile_x, tile_y - 1))
165        } else {
166            None
167        };
168        let neighbor_south = if tile_y < tiles_y - 1 {
169            Some(Self::tile_kernel_id(tile_x, tile_y + 1))
170        } else {
171            None
172        };
173        let neighbor_west = if tile_x > 0 {
174            Some(Self::tile_kernel_id(tile_x - 1, tile_y))
175        } else {
176            None
177        };
178        let neighbor_east = if tile_x < tiles_x - 1 {
179            Some(Self::tile_kernel_id(tile_x + 1, tile_y))
180        } else {
181            None
182        };
183
184        Self {
185            tile_x,
186            tile_y,
187            tile_size,
188            buffer_width,
189            pressure: vec![0.0; buffer_size],
190            pressure_prev: vec![0.0; buffer_size],
191            halo_north: vec![0.0; tile_size as usize],
192            halo_south: vec![0.0; tile_size as usize],
193            halo_east: vec![0.0; tile_size as usize],
194            halo_west: vec![0.0; tile_size as usize],
195            halos_received: [false; 4],
196            neighbor_north,
197            neighbor_south,
198            neighbor_east,
199            neighbor_west,
200            endpoint,
201            kernel_id,
202            is_north_boundary: tile_y == 0,
203            is_south_boundary: tile_y == tiles_y - 1,
204            is_east_boundary: tile_x == tiles_x - 1,
205            is_west_boundary: tile_x == 0,
206            reflection_coeff,
207        }
208    }
209
    /// Generate kernel ID for a tile.
    ///
    /// Deterministic: tile (x, y) always maps to the id "tile_x_y", so any
    /// tile can address a neighbor without a lookup table.
    pub fn tile_kernel_id(tile_x: u32, tile_y: u32) -> KernelId {
        KernelId::new(format!("tile_{}_{}", tile_x, tile_y))
    }
214
    /// Convert local tile coordinates to buffer index.
    ///
    /// Local (0, 0) maps to buffer cell (1, 1) because the buffer carries a
    /// 1-cell halo on every side.
    #[inline(always)]
    fn buffer_idx(&self, local_x: usize, local_y: usize) -> usize {
        // Add 1 to account for halo
        (local_y + 1) * self.buffer_width + (local_x + 1)
    }
221
222    /// Get pressure at local tile coordinates.
223    pub fn get_pressure(&self, local_x: u32, local_y: u32) -> f32 {
224        self.pressure[self.buffer_idx(local_x as usize, local_y as usize)]
225    }
226
227    /// Set pressure at local tile coordinates.
228    pub fn set_pressure(&mut self, local_x: u32, local_y: u32, value: f32) {
229        let idx = self.buffer_idx(local_x as usize, local_y as usize);
230        self.pressure[idx] = value;
231    }
232
233    /// Extract edge data for sending to neighbors.
234    fn extract_edge(&self, direction: HaloDirection) -> Vec<f32> {
235        let size = self.tile_size as usize;
236        let mut edge = vec![0.0; size];
237
238        match direction {
239            HaloDirection::North => {
240                // First interior row (y = 0)
241                for (x, cell) in edge.iter_mut().enumerate().take(size) {
242                    *cell = self.pressure[self.buffer_idx(x, 0)];
243                }
244            }
245            HaloDirection::South => {
246                // Last interior row (y = size - 1)
247                for (x, cell) in edge.iter_mut().enumerate().take(size) {
248                    *cell = self.pressure[self.buffer_idx(x, size - 1)];
249                }
250            }
251            HaloDirection::West => {
252                // First interior column (x = 0)
253                for (y, cell) in edge.iter_mut().enumerate().take(size) {
254                    *cell = self.pressure[self.buffer_idx(0, y)];
255                }
256            }
257            HaloDirection::East => {
258                // Last interior column (x = size - 1)
259                for (y, cell) in edge.iter_mut().enumerate().take(size) {
260                    *cell = self.pressure[self.buffer_idx(size - 1, y)];
261                }
262            }
263        }
264
265        edge
266    }
267
268    /// Apply received halo data to the buffer.
269    fn apply_halo(&mut self, direction: HaloDirection, data: &[f32]) {
270        let size = self.tile_size as usize;
271        let bw = self.buffer_width;
272
273        match direction {
274            HaloDirection::North => {
275                // Top halo row (y = -1 in local coords, index row 0 in buffer)
276                self.pressure[1..(size + 1)].copy_from_slice(&data[..size]);
277            }
278            HaloDirection::South => {
279                // Bottom halo row (y = size in local coords)
280                for (x, &val) in data.iter().enumerate().take(size) {
281                    self.pressure[(size + 1) * bw + (x + 1)] = val;
282                }
283            }
284            HaloDirection::West => {
285                // Left halo column (x = -1 in local coords, index col 0 in buffer)
286                for (y, &val) in data.iter().enumerate().take(size) {
287                    self.pressure[(y + 1) * bw] = val;
288                }
289            }
290            HaloDirection::East => {
291                // Right halo column (x = size in local coords)
292                for (y, &val) in data.iter().enumerate().take(size) {
293                    self.pressure[(y + 1) * bw + (size + 1)] = val;
294                }
295            }
296        }
297    }
298
299    /// Apply boundary reflection for edges without neighbors.
300    fn apply_boundary_reflection(&mut self) {
301        let size = self.tile_size as usize;
302        let bw = self.buffer_width;
303        let refl = self.reflection_coeff;
304
305        if self.is_north_boundary {
306            // Reflect interior row 0 to halo row
307            for x in 0..size {
308                let interior_idx = self.buffer_idx(x, 0);
309                self.pressure[x + 1] = self.pressure[interior_idx] * refl;
310            }
311        }
312
313        if self.is_south_boundary {
314            // Reflect interior row (size-1) to halo row
315            for x in 0..size {
316                let interior_idx = self.buffer_idx(x, size - 1);
317                self.pressure[(size + 1) * bw + (x + 1)] = self.pressure[interior_idx] * refl;
318            }
319        }
320
321        if self.is_west_boundary {
322            // Reflect interior col 0 to halo col
323            for y in 0..size {
324                let interior_idx = self.buffer_idx(0, y);
325                self.pressure[(y + 1) * bw] = self.pressure[interior_idx] * refl;
326            }
327        }
328
329        if self.is_east_boundary {
330            // Reflect interior col (size-1) to halo col
331            for y in 0..size {
332                let interior_idx = self.buffer_idx(size - 1, y);
333                self.pressure[(y + 1) * bw + (size + 1)] = self.pressure[interior_idx] * refl;
334            }
335        }
336    }
337
338    /// Send halo data to all neighbors via K2K.
339    pub async fn send_halos(&self) -> Result<()> {
340        for direction in HaloDirection::ALL {
341            let neighbor = match direction {
342                HaloDirection::North => &self.neighbor_north,
343                HaloDirection::South => &self.neighbor_south,
344                HaloDirection::East => &self.neighbor_east,
345                HaloDirection::West => &self.neighbor_west,
346            };
347
348            if let Some(neighbor_id) = neighbor {
349                let edge_data = self.extract_edge(direction);
350
351                // Serialize halo data into envelope
352                let envelope = Self::create_halo_envelope(direction.opposite(), &edge_data);
353
354                let receipt = self.endpoint.send(neighbor_id.clone(), envelope).await?;
355                if receipt.status != DeliveryStatus::Delivered {
356                    tracing::warn!(
357                        "Halo send failed: tile ({},{}) -> {:?}, status: {:?}",
358                        self.tile_x,
359                        self.tile_y,
360                        neighbor_id,
361                        receipt.status
362                    );
363                }
364            }
365        }
366
367        Ok(())
368    }
369
370    /// Receive halo data from all neighbors via K2K.
371    pub fn receive_halos(&mut self) {
372        // Reset received flags
373        self.halos_received = [false; 4];
374
375        // Process all pending messages
376        while let Some(message) = self.endpoint.try_receive() {
377            if let Some((direction, data)) = Self::parse_halo_envelope(&message.envelope) {
378                // Store in halo buffer
379                match direction {
380                    HaloDirection::North => {
381                        self.halo_north.copy_from_slice(&data);
382                        self.halos_received[0] = true;
383                    }
384                    HaloDirection::South => {
385                        self.halo_south.copy_from_slice(&data);
386                        self.halos_received[1] = true;
387                    }
388                    HaloDirection::East => {
389                        self.halo_east.copy_from_slice(&data);
390                        self.halos_received[2] = true;
391                    }
392                    HaloDirection::West => {
393                        self.halo_west.copy_from_slice(&data);
394                        self.halos_received[3] = true;
395                    }
396                }
397            }
398        }
399
400        // Apply received halos to pressure buffer
401        if self.halos_received[0] && self.neighbor_north.is_some() {
402            self.apply_halo(HaloDirection::North, &self.halo_north.clone());
403        }
404        if self.halos_received[1] && self.neighbor_south.is_some() {
405            self.apply_halo(HaloDirection::South, &self.halo_south.clone());
406        }
407        if self.halos_received[2] && self.neighbor_east.is_some() {
408            self.apply_halo(HaloDirection::East, &self.halo_east.clone());
409        }
410        if self.halos_received[3] && self.neighbor_west.is_some() {
411            self.apply_halo(HaloDirection::West, &self.halo_west.clone());
412        }
413
414        // Apply boundary reflection for edges without neighbors
415        self.apply_boundary_reflection();
416    }
417
418    /// Create a K2K envelope containing halo data.
419    fn create_halo_envelope(direction: HaloDirection, data: &[f32]) -> MessageEnvelope {
420        // Pack direction and data into payload
421        // Format: [direction_byte, f32 data...]
422        let mut payload = Vec::with_capacity(1 + data.len() * 4);
423        payload.push(direction as u8);
424        for &value in data {
425            payload.extend_from_slice(&value.to_le_bytes());
426        }
427
428        let header = MessageHeader::new(
429            HALO_EXCHANGE_TYPE_ID,
430            0, // source_kernel (host)
431            0, // dest_kernel (will be set by K2K routing)
432            payload.len(),
433            HlcTimestamp::now(0),
434        );
435
436        MessageEnvelope { header, payload }
437    }
438
439    /// Parse a K2K envelope to extract halo data.
440    fn parse_halo_envelope(envelope: &MessageEnvelope) -> Option<(HaloDirection, Vec<f32>)> {
441        if envelope.header.message_type != HALO_EXCHANGE_TYPE_ID {
442            return None;
443        }
444
445        if envelope.payload.is_empty() {
446            return None;
447        }
448
449        let direction = match envelope.payload[0] {
450            0 => HaloDirection::North,
451            1 => HaloDirection::South,
452            2 => HaloDirection::East,
453            3 => HaloDirection::West,
454            _ => return None,
455        };
456
457        let data: Vec<f32> = envelope.payload[1..]
458            .chunks_exact(4)
459            .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
460            .collect();
461
462        Some((direction, data))
463    }
464
465    /// Compute FDTD for all interior cells.
466    pub fn compute_fdtd(&mut self, c2: f32, damping: f32) {
467        let size = self.tile_size as usize;
468        let bw = self.buffer_width;
469
470        // Process all interior cells
471        for local_y in 0..size {
472            for local_x in 0..size {
473                let idx = (local_y + 1) * bw + (local_x + 1);
474
475                let p_curr = self.pressure[idx];
476                let p_prev = self.pressure_prev[idx];
477
478                // Neighbor access (halos provide boundary values)
479                let p_north = self.pressure[idx - bw];
480                let p_south = self.pressure[idx + bw];
481                let p_west = self.pressure[idx - 1];
482                let p_east = self.pressure[idx + 1];
483
484                // FDTD computation
485                let laplacian = p_north + p_south + p_east + p_west - 4.0 * p_curr;
486                let p_new = 2.0 * p_curr - p_prev + c2 * laplacian;
487
488                self.pressure_prev[idx] = p_new * damping;
489            }
490        }
491    }
492
    /// Swap pressure buffers after FDTD step.
    ///
    /// `compute_fdtd` writes the new field into `pressure_prev`; swapping
    /// makes it current and recycles the old field for the next step.
    pub fn swap_buffers(&mut self) {
        std::mem::swap(&mut self.pressure, &mut self.pressure_prev);
    }
497
498    /// Reset tile to initial state.
499    pub fn reset(&mut self) {
500        self.pressure.fill(0.0);
501        self.pressure_prev.fill(0.0);
502    }
503}
504
/// GPU compute resources for tile-based FDTD (wgpu feature only).
/// This is the legacy mode with upload/download per step.
#[cfg(feature = "wgpu")]
struct GpuComputeResources {
    /// Shared compute pool for all tiles.
    pool: TileGpuComputePool,
    /// Per-tile GPU buffers, keyed by (tile_x, tile_y).
    tile_buffers: HashMap<(u32, u32), TileBuffers>,
}
514
/// GPU-persistent state using WGPU backend.
/// Pressure stays GPU-resident, only halos are transferred.
#[cfg(feature = "wgpu")]
struct WgpuPersistentState {
    /// WGPU backend.
    backend: WgpuTileBackend,
    /// Per-tile GPU buffers, keyed by (tile_x, tile_y).
    tile_buffers: HashMap<(u32, u32), TileGpuBuffers<WgpuBuffer>>,
}
524
/// GPU-persistent state using CUDA backend.
/// Pressure stays GPU-resident, only halos are transferred.
#[cfg(feature = "cuda")]
struct CudaPersistentState {
    /// CUDA backend.
    backend: CudaTileBackend,
    /// Per-tile GPU buffers, keyed by (tile_x, tile_y).
    tile_buffers: HashMap<(u32, u32), TileGpuBuffers<ringkernel_cuda::CudaBuffer>>,
}
534
/// Which GPU backend is active for persistent state.
/// Reported by `TileKernelGrid::gpu_persistent_backend()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GpuPersistentBackend {
    /// WGPU (WebGPU) backend.
    Wgpu,
    /// CUDA (NVIDIA) backend.
    Cuda,
}
543
/// A simulation grid using tile actors with K2K messaging.
pub struct TileKernelGrid {
    /// Grid dimensions in cells.
    pub width: u32,
    pub height: u32,

    /// Tile size (cells per tile side).
    tile_size: u32,

    /// Grid dimensions in tiles (ceil of cells / tile_size).
    tiles_x: u32,
    tiles_y: u32,

    /// Tile actors indexed by (tile_x, tile_y).
    tiles: HashMap<(u32, u32), TileActor>,

    /// K2K broker for inter-tile messaging (shared by all tile endpoints).
    broker: Arc<K2KBroker>,

    /// Acoustic simulation parameters.
    pub params: AcousticParams,

    /// The backend being used.
    backend: Backend,

    /// RingKernel runtime (for kernel handle management).
    #[allow(dead_code)]
    runtime: Arc<RingKernel>,

    /// Legacy GPU compute pool for tile FDTD (optional, requires wgpu feature).
    /// This mode uploads/downloads full buffers each step.
    #[cfg(feature = "wgpu")]
    gpu_compute: Option<GpuComputeResources>,

    /// GPU-persistent WGPU state (optional).
    /// Pressure stays GPU-resident, only halos are transferred.
    #[cfg(feature = "wgpu")]
    wgpu_persistent: Option<WgpuPersistentState>,

    /// GPU-persistent CUDA state (optional).
    /// Pressure stays GPU-resident, only halos are transferred.
    #[cfg(feature = "cuda")]
    cuda_persistent: Option<CudaPersistentState>,
}
588
// Manual Debug: summarizes grid geometry, tile count, and backend; the
// broker, runtime, tile contents, and GPU state are deliberately omitted.
impl std::fmt::Debug for TileKernelGrid {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TileKernelGrid")
            .field("width", &self.width)
            .field("height", &self.height)
            .field("tile_size", &self.tile_size)
            .field("tiles_x", &self.tiles_x)
            .field("tiles_y", &self.tiles_y)
            .field("tile_count", &self.tiles.len())
            .field("backend", &self.backend)
            .finish()
    }
}
602
603impl TileKernelGrid {
    /// Create a new tile-based kernel grid.
    ///
    /// Convenience wrapper around [`Self::with_tile_size`] using
    /// [`DEFAULT_TILE_SIZE`] (16x16 cells per tile).
    pub async fn new(
        width: u32,
        height: u32,
        params: AcousticParams,
        backend: Backend,
    ) -> Result<Self> {
        Self::with_tile_size(width, height, params, backend, DEFAULT_TILE_SIZE).await
    }
613
614    /// Create a new tile-based kernel grid with custom tile size.
615    pub async fn with_tile_size(
616        width: u32,
617        height: u32,
618        params: AcousticParams,
619        backend: Backend,
620        tile_size: u32,
621    ) -> Result<Self> {
622        let runtime = Arc::new(RingKernel::with_backend(backend).await?);
623
624        // Calculate tile grid dimensions
625        let tiles_x = width.div_ceil(tile_size);
626        let tiles_y = height.div_ceil(tile_size);
627
628        // Create K2K broker
629        let broker = K2KBuilder::new()
630            .max_pending_messages(tiles_x as usize * tiles_y as usize * 8)
631            .build();
632
633        // Create tile actors
634        let mut tiles = HashMap::new();
635        for ty in 0..tiles_y {
636            for tx in 0..tiles_x {
637                let tile = TileActor::new(tx, ty, tile_size, tiles_x, tiles_y, &broker, 0.95);
638                tiles.insert((tx, ty), tile);
639            }
640        }
641
642        tracing::info!(
643            "Created TileKernelGrid: {}x{} cells, {}x{} tiles ({}x{} per tile), {} tile actors",
644            width,
645            height,
646            tiles_x,
647            tiles_y,
648            tile_size,
649            tile_size,
650            tiles.len()
651        );
652
653        Ok(Self {
654            width,
655            height,
656            tile_size,
657            tiles_x,
658            tiles_y,
659            tiles,
660            broker,
661            params,
662            backend,
663            runtime,
664            #[cfg(feature = "wgpu")]
665            gpu_compute: None,
666            #[cfg(feature = "wgpu")]
667            wgpu_persistent: None,
668            #[cfg(feature = "cuda")]
669            cuda_persistent: None,
670        })
671    }
672
673    /// Convert global cell coordinates to (tile_x, tile_y, local_x, local_y).
674    fn global_to_tile_coords(&self, x: u32, y: u32) -> (u32, u32, u32, u32) {
675        let tile_x = x / self.tile_size;
676        let tile_y = y / self.tile_size;
677        let local_x = x % self.tile_size;
678        let local_y = y % self.tile_size;
679        (tile_x, tile_y, local_x, local_y)
680    }
681
682    /// Perform one simulation step using tile actors with K2K messaging.
683    pub async fn step(&mut self) -> Result<()> {
684        let c2 = self.params.courant_number().powi(2);
685        let damping = 1.0 - self.params.damping;
686
687        // Phase 1: All tiles send halo data to neighbors (K2K)
688        for tile in self.tiles.values() {
689            tile.send_halos().await?;
690        }
691
692        // Phase 2: All tiles receive halo data from neighbors (K2K)
693        for tile in self.tiles.values_mut() {
694            tile.receive_halos();
695        }
696
697        // Phase 3: All tiles compute FDTD for interior cells
698        for tile in self.tiles.values_mut() {
699            tile.compute_fdtd(c2, damping);
700        }
701
702        // Phase 4: All tiles swap buffers
703        for tile in self.tiles.values_mut() {
704            tile.swap_buffers();
705        }
706
707        Ok(())
708    }
709
    /// Enable GPU compute for tile FDTD acceleration.
    ///
    /// This initializes WGPU resources and creates GPU buffers for each tile.
    /// Once enabled, use `step_gpu()` for GPU-accelerated simulation steps.
    ///
    /// # Errors
    /// Propagates failures from WGPU device initialization or pool creation.
    #[cfg(feature = "wgpu")]
    pub async fn enable_gpu_compute(&mut self) -> Result<()> {
        let (device, queue) = init_wgpu().await?;
        let pool = TileGpuComputePool::new(device, queue, self.tile_size)?;

        // Create GPU buffers for each tile
        let mut tile_buffers = HashMap::new();
        for tile_coords in self.tiles.keys() {
            let buffers = pool.create_tile_buffers();
            tile_buffers.insert(*tile_coords, buffers);
        }

        let num_tiles = tile_buffers.len();
        self.gpu_compute = Some(GpuComputeResources { pool, tile_buffers });

        tracing::info!(
            "GPU compute enabled for TileKernelGrid: {} tiles with GPU buffers",
            num_tiles
        );

        Ok(())
    }
736
    /// Check if GPU compute is enabled.
    ///
    /// Refers to the legacy upload/download mode set up by
    /// `enable_gpu_compute()`, not the GPU-persistent modes.
    #[cfg(feature = "wgpu")]
    pub fn is_gpu_enabled(&self) -> bool {
        self.gpu_compute.is_some()
    }
742
    /// Perform one simulation step using GPU compute for FDTD.
    ///
    /// This is the hybrid approach:
    /// - K2K messaging handles halo exchange between tiles (actor model showcase)
    /// - GPU compute shaders handle FDTD for tile interiors (parallel compute)
    ///
    /// Note: this legacy path uploads each tile's full buffers and reads the
    /// result back every step; the persistent modes transfer halos only.
    ///
    /// # Errors
    /// Returns an error if `enable_gpu_compute()` has not been called, or if
    /// a halo send fails.
    #[cfg(feature = "wgpu")]
    pub async fn step_gpu(&mut self) -> Result<()> {
        let gpu = self.gpu_compute.as_ref().ok_or_else(|| {
            ringkernel_core::error::RingKernelError::BackendError(
                "GPU compute not enabled. Call enable_gpu_compute() first.".to_string(),
            )
        })?;

        let c2 = self.params.courant_number().powi(2);
        let damping = 1.0 - self.params.damping;

        // Phase 1: All tiles send halo data to neighbors (K2K)
        for tile in self.tiles.values() {
            tile.send_halos().await?;
        }

        // Phase 2: All tiles receive halo data from neighbors (K2K)
        for tile in self.tiles.values_mut() {
            tile.receive_halos();
        }

        // Phase 3: GPU compute FDTD for all tile interiors
        // Each tile dispatches to GPU and reads back results
        for ((tx, ty), tile) in self.tiles.iter_mut() {
            if let Some(buffers) = gpu.tile_buffers.get(&(*tx, *ty)) {
                // Get tile's pressure buffer (with halos already applied from K2K)
                let pressure = &tile.pressure;
                let pressure_prev = &tile.pressure_prev;

                // Dispatch GPU compute
                let result = gpu
                    .pool
                    .compute_fdtd(buffers, pressure, pressure_prev, c2, damping);

                // Copy results back to tile's pressure_prev buffer
                // (mirrors the CPU path: the new field lands in pressure_prev
                // and becomes current after the Phase 4 swap).
                tile.pressure_prev.copy_from_slice(&result);
            } else {
                // Fallback to CPU if no GPU buffers (shouldn't happen)
                tile.compute_fdtd(c2, damping);
            }
        }

        // Phase 4: All tiles swap buffers
        for tile in self.tiles.values_mut() {
            tile.swap_buffers();
        }

        Ok(())
    }
797
    /// Enable GPU-persistent WGPU compute.
    ///
    /// This mode keeps pressure state GPU-resident and only transfers halos
    /// for K2K communication. Much faster than the legacy mode which uploads
    /// and downloads full buffers each step.
    ///
    /// Replaces any previously created WGPU persistent state.
    ///
    /// # Errors
    /// Propagates failures from backend creation, buffer allocation, or the
    /// initial state upload.
    #[cfg(feature = "wgpu")]
    pub async fn enable_wgpu_persistent(&mut self) -> Result<()> {
        let backend = WgpuTileBackend::new(self.tile_size).await?;

        // Create GPU buffers for each tile and upload initial state
        let mut tile_buffers = HashMap::new();
        for ((tx, ty), tile) in &self.tiles {
            let buffers = backend.create_tile_buffers(self.tile_size)?;
            // Upload initial state (all zeros typically)
            backend.upload_initial_state(&buffers, &tile.pressure, &tile.pressure_prev)?;
            tile_buffers.insert((*tx, *ty), buffers);
        }

        let num_tiles = tile_buffers.len();
        self.wgpu_persistent = Some(WgpuPersistentState {
            backend,
            tile_buffers,
        });

        tracing::info!(
            "WGPU GPU-persistent compute enabled for TileKernelGrid: {} tiles",
            num_tiles
        );

        Ok(())
    }
829
    /// Enable GPU-persistent CUDA compute.
    ///
    /// This mode keeps pressure state GPU-resident and only transfers halos
    /// for K2K communication. Much faster than upload/download per step.
    ///
    /// Replaces any previously created CUDA persistent state.
    ///
    /// # Errors
    /// Propagates failures from backend creation, buffer allocation, or the
    /// initial state upload.
    #[cfg(feature = "cuda")]
    pub fn enable_cuda_persistent(&mut self) -> Result<()> {
        let backend = CudaTileBackend::new(self.tile_size)?;

        // Create GPU buffers for each tile and upload initial state
        let mut tile_buffers = HashMap::new();
        for ((tx, ty), tile) in &self.tiles {
            let buffers = backend.create_tile_buffers(self.tile_size)?;
            // Upload initial state
            backend.upload_initial_state(&buffers, &tile.pressure, &tile.pressure_prev)?;
            tile_buffers.insert((*tx, *ty), buffers);
        }

        let num_tiles = tile_buffers.len();
        self.cuda_persistent = Some(CudaPersistentState {
            backend,
            tile_buffers,
        });

        tracing::info!(
            "CUDA GPU-persistent compute enabled for TileKernelGrid: {} tiles",
            num_tiles
        );

        Ok(())
    }
860
    /// Check if GPU-persistent mode is enabled.
    ///
    /// True when either the WGPU or the CUDA persistent state exists;
    /// always false when neither feature is compiled in.
    pub fn is_gpu_persistent_enabled(&self) -> bool {
        #[cfg(feature = "wgpu")]
        if self.wgpu_persistent.is_some() {
            return true;
        }
        #[cfg(feature = "cuda")]
        if self.cuda_persistent.is_some() {
            return true;
        }
        false
    }
873
    /// Get which GPU-persistent backend is active.
    ///
    /// CUDA takes precedence when both persistent states exist; returns
    /// `None` when no persistent mode has been enabled.
    pub fn gpu_persistent_backend(&self) -> Option<GpuPersistentBackend> {
        #[cfg(feature = "cuda")]
        if self.cuda_persistent.is_some() {
            return Some(GpuPersistentBackend::Cuda);
        }
        #[cfg(feature = "wgpu")]
        if self.wgpu_persistent.is_some() {
            return Some(GpuPersistentBackend::Wgpu);
        }
        None
    }
886
887    /// Perform one simulation step using GPU-persistent WGPU compute.
888    ///
889    /// State stays GPU-resident. Only halos are transferred for K2K messaging.
890    #[cfg(feature = "wgpu")]
891    pub fn step_wgpu_persistent(&mut self) -> Result<()> {
892        let state = self.wgpu_persistent.as_ref().ok_or_else(|| {
893            ringkernel_core::error::RingKernelError::BackendError(
894                "WGPU persistent compute not enabled. Call enable_wgpu_persistent() first."
895                    .to_string(),
896            )
897        })?;
898
899        let c2 = self.params.courant_number().powi(2);
900        let damping = 1.0 - self.params.damping;
901        let fdtd_params = FdtdParams::new(self.tile_size, c2, damping);
902
903        // Phase 1: Extract halos from GPU and send via K2K
904        // Build halo messages for all tiles
905        let mut halo_messages: Vec<(KernelId, HaloDirection, Vec<f32>)> = Vec::new();
906
907        for (tx, ty) in self.tiles.keys() {
908            if let Some(buffers) = state.tile_buffers.get(&(*tx, *ty)) {
909                // Extract halos based on tile's neighbors
910                let tile = self.tiles.get(&(*tx, *ty)).unwrap();
911
912                if tile.neighbor_north.is_some() {
913                    let halo = state.backend.extract_halo(buffers, Edge::North)?;
914                    halo_messages.push((
915                        tile.neighbor_north.clone().unwrap(),
916                        HaloDirection::South,
917                        halo,
918                    ));
919                }
920                if tile.neighbor_south.is_some() {
921                    let halo = state.backend.extract_halo(buffers, Edge::South)?;
922                    halo_messages.push((
923                        tile.neighbor_south.clone().unwrap(),
924                        HaloDirection::North,
925                        halo,
926                    ));
927                }
928                if tile.neighbor_west.is_some() {
929                    let halo = state.backend.extract_halo(buffers, Edge::West)?;
930                    halo_messages.push((
931                        tile.neighbor_west.clone().unwrap(),
932                        HaloDirection::East,
933                        halo,
934                    ));
935                }
936                if tile.neighbor_east.is_some() {
937                    let halo = state.backend.extract_halo(buffers, Edge::East)?;
938                    halo_messages.push((
939                        tile.neighbor_east.clone().unwrap(),
940                        HaloDirection::West,
941                        halo,
942                    ));
943                }
944            }
945        }
946
947        // Phase 2: Inject halos into GPU buffers
948        // Group messages by destination tile
949        for (dest_id, direction, halo_data) in halo_messages {
950            // Find the tile coords from kernel ID
951            for (tx, ty) in self.tiles.keys() {
952                if TileActor::tile_kernel_id(*tx, *ty) == dest_id {
953                    if let Some(buffers) = state.tile_buffers.get(&(*tx, *ty)) {
954                        let edge = match direction {
955                            HaloDirection::North => Edge::North,
956                            HaloDirection::South => Edge::South,
957                            HaloDirection::East => Edge::East,
958                            HaloDirection::West => Edge::West,
959                        };
960                        state.backend.inject_halo(buffers, edge, &halo_data)?;
961                    }
962                    break;
963                }
964            }
965        }
966
967        // Phase 3: Apply boundary reflection for edge tiles (on GPU)
968        // For now, we handle this by keeping boundaries at zero (absorbing)
969        // TODO: Add GPU kernel for boundary reflection
970
971        // Phase 4: Compute FDTD on GPU for all tiles
972        for (tx, ty) in self.tiles.keys() {
973            if let Some(buffers) = state.tile_buffers.get(&(*tx, *ty)) {
974                state.backend.fdtd_step(buffers, &fdtd_params)?;
975            }
976        }
977
978        // Synchronize
979        state.backend.synchronize()?;
980
981        // Phase 5: Swap buffers (pointer swap only, no data movement)
982        // NOTE: We need mutable access to state for this
983        let state = self.wgpu_persistent.as_mut().unwrap();
984        for buffers in state.tile_buffers.values_mut() {
985            state.backend.swap_buffers(buffers);
986        }
987
988        Ok(())
989    }
990
991    /// Perform one simulation step using GPU-persistent CUDA compute.
992    ///
993    /// State stays GPU-resident. Only halos are transferred for K2K messaging.
994    #[cfg(feature = "cuda")]
995    pub fn step_cuda_persistent(&mut self) -> Result<()> {
996        let state = self.cuda_persistent.as_ref().ok_or_else(|| {
997            ringkernel_core::error::RingKernelError::BackendError(
998                "CUDA persistent compute not enabled. Call enable_cuda_persistent() first."
999                    .to_string(),
1000            )
1001        })?;
1002
1003        let c2 = self.params.courant_number().powi(2);
1004        let damping = 1.0 - self.params.damping;
1005        let fdtd_params = FdtdParams::new(self.tile_size, c2, damping);
1006
1007        // Phase 1: Extract halos from GPU and build messages
1008        let mut halo_messages: Vec<(KernelId, HaloDirection, Vec<f32>)> = Vec::new();
1009
1010        for (tx, ty) in self.tiles.keys() {
1011            if let Some(buffers) = state.tile_buffers.get(&(*tx, *ty)) {
1012                let tile = self.tiles.get(&(*tx, *ty)).unwrap();
1013
1014                if tile.neighbor_north.is_some() {
1015                    let halo = state.backend.extract_halo(buffers, Edge::North)?;
1016                    halo_messages.push((
1017                        tile.neighbor_north.clone().unwrap(),
1018                        HaloDirection::South,
1019                        halo,
1020                    ));
1021                }
1022                if tile.neighbor_south.is_some() {
1023                    let halo = state.backend.extract_halo(buffers, Edge::South)?;
1024                    halo_messages.push((
1025                        tile.neighbor_south.clone().unwrap(),
1026                        HaloDirection::North,
1027                        halo,
1028                    ));
1029                }
1030                if tile.neighbor_west.is_some() {
1031                    let halo = state.backend.extract_halo(buffers, Edge::West)?;
1032                    halo_messages.push((
1033                        tile.neighbor_west.clone().unwrap(),
1034                        HaloDirection::East,
1035                        halo,
1036                    ));
1037                }
1038                if tile.neighbor_east.is_some() {
1039                    let halo = state.backend.extract_halo(buffers, Edge::East)?;
1040                    halo_messages.push((
1041                        tile.neighbor_east.clone().unwrap(),
1042                        HaloDirection::West,
1043                        halo,
1044                    ));
1045                }
1046            }
1047        }
1048
1049        // Phase 2: Inject halos into GPU buffers
1050        for (dest_id, direction, halo_data) in halo_messages {
1051            for (tx, ty) in self.tiles.keys() {
1052                if TileActor::tile_kernel_id(*tx, *ty) == dest_id {
1053                    if let Some(buffers) = state.tile_buffers.get(&(*tx, *ty)) {
1054                        let edge = match direction {
1055                            HaloDirection::North => Edge::North,
1056                            HaloDirection::South => Edge::South,
1057                            HaloDirection::East => Edge::East,
1058                            HaloDirection::West => Edge::West,
1059                        };
1060                        state.backend.inject_halo(buffers, edge, &halo_data)?;
1061                    }
1062                    break;
1063                }
1064            }
1065        }
1066
1067        // Phase 3: Compute FDTD on GPU for all tiles
1068        for (tx, ty) in self.tiles.keys() {
1069            if let Some(buffers) = state.tile_buffers.get(&(*tx, *ty)) {
1070                state.backend.fdtd_step(buffers, &fdtd_params)?;
1071            }
1072        }
1073
1074        // Synchronize
1075        state.backend.synchronize()?;
1076
1077        // Phase 4: Swap buffers
1078        let state = self.cuda_persistent.as_mut().unwrap();
1079        for buffers in state.tile_buffers.values_mut() {
1080            state.backend.swap_buffers(buffers);
1081        }
1082
1083        Ok(())
1084    }
1085
1086    /// Read pressure grid from GPU for visualization.
1087    ///
1088    /// This is the only time we transfer full tile data from GPU to host.
1089    /// Call this only when you need to render the simulation.
1090    #[cfg(feature = "wgpu")]
1091    pub fn read_pressure_from_wgpu(&self) -> Result<Vec<Vec<f32>>> {
1092        let state = self.wgpu_persistent.as_ref().ok_or_else(|| {
1093            ringkernel_core::error::RingKernelError::BackendError(
1094                "WGPU persistent compute not enabled.".to_string(),
1095            )
1096        })?;
1097
1098        let mut grid = vec![vec![0.0; self.width as usize]; self.height as usize];
1099
1100        for ((tx, ty), buffers) in &state.tile_buffers {
1101            let interior = state.backend.read_interior_pressure(buffers)?;
1102
1103            // Copy tile interior to grid
1104            let tile_start_x = tx * self.tile_size;
1105            let tile_start_y = ty * self.tile_size;
1106
1107            for ly in 0..self.tile_size {
1108                for lx in 0..self.tile_size {
1109                    let gx = tile_start_x + lx;
1110                    let gy = tile_start_y + ly;
1111                    if gx < self.width && gy < self.height {
1112                        grid[gy as usize][gx as usize] =
1113                            interior[(ly * self.tile_size + lx) as usize];
1114                    }
1115                }
1116            }
1117        }
1118
1119        Ok(grid)
1120    }
1121
1122    /// Read pressure grid from CUDA for visualization.
1123    #[cfg(feature = "cuda")]
1124    pub fn read_pressure_from_cuda(&self) -> Result<Vec<Vec<f32>>> {
1125        let state = self.cuda_persistent.as_ref().ok_or_else(|| {
1126            ringkernel_core::error::RingKernelError::BackendError(
1127                "CUDA persistent compute not enabled.".to_string(),
1128            )
1129        })?;
1130
1131        let mut grid = vec![vec![0.0; self.width as usize]; self.height as usize];
1132
1133        for ((tx, ty), buffers) in &state.tile_buffers {
1134            let interior = state.backend.read_interior_pressure(buffers)?;
1135
1136            let tile_start_x = tx * self.tile_size;
1137            let tile_start_y = ty * self.tile_size;
1138
1139            for ly in 0..self.tile_size {
1140                for lx in 0..self.tile_size {
1141                    let gx = tile_start_x + lx;
1142                    let gy = tile_start_y + ly;
1143                    if gx < self.width && gy < self.height {
1144                        grid[gy as usize][gx as usize] =
1145                            interior[(ly * self.tile_size + lx) as usize];
1146                    }
1147                }
1148            }
1149        }
1150
1151        Ok(grid)
1152    }
1153
    /// Upload a pressure impulse to GPU-persistent state.
    ///
    /// For GPU-persistent mode, we need to upload the impulse to the GPU.
    /// Out-of-bounds coordinates are silently ignored (returns `Ok(())`).
    ///
    /// NOTE(review): this uploads the *CPU-side* tile state (plus the
    /// impulse). The CPU tiles are not updated during GPU-persistent
    /// stepping, so calling this after `step_wgpu_persistent()` overwrites
    /// the evolved GPU pressure field with stale host data — confirm whether
    /// a GPU read-back should happen before the upload.
    #[cfg(feature = "wgpu")]
    pub fn inject_impulse_wgpu(&mut self, x: u32, y: u32, amplitude: f32) -> Result<()> {
        if x >= self.width || y >= self.height {
            return Ok(());
        }

        // Map the global cell to its owning tile and local coordinates.
        let (tile_x, tile_y, local_x, local_y) = self.global_to_tile_coords(x, y);

        // First, update the CPU tile (for consistency)
        if let Some(tile) = self.tiles.get_mut(&(tile_x, tile_y)) {
            let current = tile.get_pressure(local_x, local_y);
            tile.set_pressure(local_x, local_y, current + amplitude);

            // Then upload to GPU (full tile state, both pressure fields)
            if let Some(state) = &self.wgpu_persistent {
                if let Some(buffers) = state.tile_buffers.get(&(tile_x, tile_y)) {
                    state.backend.upload_initial_state(
                        buffers,
                        &tile.pressure,
                        &tile.pressure_prev,
                    )?;
                }
            }
        }

        Ok(())
    }
1184
    /// Upload a pressure impulse to CUDA GPU-persistent state.
    ///
    /// Out-of-bounds coordinates are silently ignored (returns `Ok(())`).
    ///
    /// NOTE(review): uploads the *CPU-side* tile state (plus the impulse).
    /// CPU tiles are not updated during GPU-persistent stepping, so calling
    /// this after `step_cuda_persistent()` overwrites the evolved GPU state
    /// with stale host data — confirm whether a GPU read-back should happen
    /// before the upload.
    #[cfg(feature = "cuda")]
    pub fn inject_impulse_cuda(&mut self, x: u32, y: u32, amplitude: f32) -> Result<()> {
        if x >= self.width || y >= self.height {
            return Ok(());
        }

        // Map the global cell to its owning tile and local coordinates.
        let (tile_x, tile_y, local_x, local_y) = self.global_to_tile_coords(x, y);

        if let Some(tile) = self.tiles.get_mut(&(tile_x, tile_y)) {
            let current = tile.get_pressure(local_x, local_y);
            tile.set_pressure(local_x, local_y, current + amplitude);

            // Push the updated full tile state (both pressure fields) to GPU.
            if let Some(state) = &self.cuda_persistent {
                if let Some(buffers) = state.tile_buffers.get(&(tile_x, tile_y)) {
                    state.backend.upload_initial_state(
                        buffers,
                        &tile.pressure,
                        &tile.pressure_prev,
                    )?;
                }
            }
        }

        Ok(())
    }
1211
1212    /// Inject an impulse at the given global grid position.
1213    pub fn inject_impulse(&mut self, x: u32, y: u32, amplitude: f32) {
1214        if x >= self.width || y >= self.height {
1215            return;
1216        }
1217
1218        let (tile_x, tile_y, local_x, local_y) = self.global_to_tile_coords(x, y);
1219
1220        if let Some(tile) = self.tiles.get_mut(&(tile_x, tile_y)) {
1221            let current = tile.get_pressure(local_x, local_y);
1222            tile.set_pressure(local_x, local_y, current + amplitude);
1223        }
1224    }
1225
1226    /// Get the pressure grid for visualization.
1227    pub fn get_pressure_grid(&self) -> Vec<Vec<f32>> {
1228        let mut grid = vec![vec![0.0; self.width as usize]; self.height as usize];
1229
1230        for y in 0..self.height {
1231            for x in 0..self.width {
1232                let (tile_x, tile_y, local_x, local_y) = self.global_to_tile_coords(x, y);
1233
1234                if let Some(tile) = self.tiles.get(&(tile_x, tile_y)) {
1235                    // Handle tiles at the edge that may have fewer cells
1236                    if local_x < self.tile_size && local_y < self.tile_size {
1237                        grid[y as usize][x as usize] = tile.get_pressure(local_x, local_y);
1238                    }
1239                }
1240            }
1241        }
1242
1243        grid
1244    }
1245
1246    /// Get the maximum absolute pressure in the grid.
1247    pub fn max_pressure(&self) -> f32 {
1248        self.tiles
1249            .values()
1250            .flat_map(|tile| {
1251                (0..self.tile_size).flat_map(move |y| {
1252                    (0..self.tile_size).map(move |x| tile.get_pressure(x, y).abs())
1253                })
1254            })
1255            .fold(0.0, f32::max)
1256    }
1257
1258    /// Get total energy in the system.
1259    pub fn total_energy(&self) -> f32 {
1260        self.tiles
1261            .values()
1262            .flat_map(|tile| {
1263                (0..self.tile_size).flat_map(move |y| {
1264                    (0..self.tile_size).map(move |x| {
1265                        let p = tile.get_pressure(x, y);
1266                        p * p
1267                    })
1268                })
1269            })
1270            .sum()
1271    }
1272
1273    /// Reset all tiles to initial state.
1274    pub fn reset(&mut self) {
1275        for tile in self.tiles.values_mut() {
1276            tile.reset();
1277        }
1278    }
1279
1280    /// Get the number of cells.
1281    pub fn cell_count(&self) -> usize {
1282        (self.width * self.height) as usize
1283    }
1284
    /// Get the number of tile actors.
    ///
    /// Equals `tiles_x * tiles_y` after construction or `resize`.
    pub fn tile_count(&self) -> usize {
        self.tiles.len()
    }
1289
    /// Get the current backend.
    ///
    /// Returns the stored backend selection; independent of whether a
    /// GPU-persistent mode has additionally been enabled.
    pub fn backend(&self) -> Backend {
        self.backend
    }
1294
    /// Get K2K messaging statistics.
    ///
    /// Snapshot of the broker's counters (e.g. delivered halo messages);
    /// exact fields are defined by `ringkernel_core::k2k::K2KStats`.
    pub fn k2k_stats(&self) -> ringkernel_core::k2k::K2KStats {
        self.broker.stats()
    }
1299
    /// Update acoustic parameters.
    ///
    /// Delegates to the stored `AcousticParams`; subsequent steps pick up
    /// the new value through `params.courant_number()`.
    pub fn set_speed_of_sound(&mut self, speed: f32) {
        self.params.set_speed_of_sound(speed);
    }
1304
    /// Update cell size.
    ///
    /// Delegates to the stored `AcousticParams`; subsequent steps pick up
    /// the new value through `params.courant_number()`.
    pub fn set_cell_size(&mut self, size: f32) {
        self.params.set_cell_size(size);
    }
1309
1310    /// Resize the grid.
1311    pub async fn resize(&mut self, new_width: u32, new_height: u32) -> Result<()> {
1312        self.width = new_width;
1313        self.height = new_height;
1314
1315        // Recalculate tile grid dimensions
1316        self.tiles_x = new_width.div_ceil(self.tile_size);
1317        self.tiles_y = new_height.div_ceil(self.tile_size);
1318
1319        // Create new K2K broker
1320        self.broker = K2KBuilder::new()
1321            .max_pending_messages(self.tiles_x as usize * self.tiles_y as usize * 8)
1322            .build();
1323
1324        // Recreate tile actors
1325        self.tiles.clear();
1326        for ty in 0..self.tiles_y {
1327            for tx in 0..self.tiles_x {
1328                let tile = TileActor::new(
1329                    tx,
1330                    ty,
1331                    self.tile_size,
1332                    self.tiles_x,
1333                    self.tiles_y,
1334                    &self.broker,
1335                    0.95,
1336                );
1337                self.tiles.insert((tx, ty), tile);
1338            }
1339        }
1340
1341        tracing::info!(
1342            "Resized TileKernelGrid: {}x{} cells, {} tile actors",
1343            new_width,
1344            new_height,
1345            self.tiles.len()
1346        );
1347
1348        Ok(())
1349    }
1350
    /// Shutdown the grid.
    ///
    /// Consumes the grid; currently performs no explicit async teardown.
    pub async fn shutdown(self) -> Result<()> {
        // K2K broker and tiles are dropped automatically
        Ok(())
    }
1356}
1357
#[cfg(test)]
mod tests {
    use super::*;

    // CPU-backend tests run everywhere; GPU tests are `#[ignore]`d since the
    // CI machine may not have a GPU.

    // Grid construction: dimensions, default tile size, and tile layout.
    #[tokio::test]
    async fn test_tile_grid_creation() {
        let params = AcousticParams::new(343.0, 1.0);
        let grid = TileKernelGrid::new(64, 64, params, Backend::Cpu)
            .await
            .unwrap();

        assert_eq!(grid.width, 64);
        assert_eq!(grid.height, 64);
        assert_eq!(grid.tile_size, DEFAULT_TILE_SIZE);
        assert_eq!(grid.tiles_x, 4); // 64 / 16 = 4
        assert_eq!(grid.tiles_y, 4);
        assert_eq!(grid.tile_count(), 16); // 4 * 4 = 16 tiles
    }

    // An injected impulse is readable back at the same global coordinate.
    #[tokio::test]
    async fn test_tile_grid_impulse() {
        let params = AcousticParams::new(343.0, 1.0);
        let mut grid = TileKernelGrid::with_tile_size(32, 32, params, Backend::Cpu, 16)
            .await
            .unwrap();

        grid.inject_impulse(16, 16, 1.0);

        let pressure_grid = grid.get_pressure_grid();
        assert_eq!(pressure_grid[16][16], 1.0);
    }

    // After several CPU steps the wave front reaches a neighboring cell.
    #[tokio::test]
    async fn test_tile_grid_step() {
        let params = AcousticParams::new(343.0, 1.0);
        let mut grid = TileKernelGrid::with_tile_size(32, 32, params, Backend::Cpu, 16)
            .await
            .unwrap();

        // Inject impulse at center
        grid.inject_impulse(16, 16, 1.0);

        // Run several steps
        for _ in 0..10 {
            grid.step().await.unwrap();
        }

        // Wave should have propagated
        let pressure_grid = grid.get_pressure_grid();
        let neighbor_pressure = pressure_grid[16][17];
        assert!(
            neighbor_pressure.abs() > 0.0,
            "Wave should have propagated to neighbor"
        );
    }

    // Stepping exchanges halos between tiles, so the K2K delivered counter
    // must be non-zero.
    #[tokio::test]
    async fn test_tile_grid_k2k_stats() {
        let params = AcousticParams::new(343.0, 1.0);
        let mut grid = TileKernelGrid::with_tile_size(32, 32, params, Backend::Cpu, 16)
            .await
            .unwrap();

        grid.inject_impulse(16, 16, 1.0);
        grid.step().await.unwrap();

        let stats = grid.k2k_stats();
        assert!(
            stats.messages_delivered > 0,
            "K2K messages should have been exchanged"
        );
    }

    // reset() zeroes every cell after activity.
    #[tokio::test]
    async fn test_tile_grid_reset() {
        let params = AcousticParams::new(343.0, 1.0);
        let mut grid = TileKernelGrid::with_tile_size(32, 32, params, Backend::Cpu, 16)
            .await
            .unwrap();

        grid.inject_impulse(16, 16, 1.0);
        grid.step().await.unwrap();

        grid.reset();

        let pressure_grid = grid.get_pressure_grid();
        for row in pressure_grid {
            for p in row {
                assert_eq!(p, 0.0);
            }
        }
    }

    // Impulses near the domain boundary must not crash or blow up energy.
    #[tokio::test]
    async fn test_tile_boundary_handling() {
        let params = AcousticParams::new(343.0, 1.0);
        let mut grid = TileKernelGrid::with_tile_size(32, 32, params, Backend::Cpu, 16)
            .await
            .unwrap();

        // Inject near boundary
        grid.inject_impulse(1, 1, 1.0);

        // Run several steps
        for _ in 0..5 {
            grid.step().await.unwrap();
        }

        // Should not crash and energy should be bounded
        let energy = grid.total_energy();
        assert!(energy.is_finite(), "Energy should be finite");
    }

    // Same propagation check as test_tile_grid_step, but via the legacy
    // WGPU upload/download compute path.
    #[cfg(feature = "wgpu")]
    #[tokio::test]
    #[ignore] // May not have GPU
    async fn test_tile_grid_gpu_step() {
        let params = AcousticParams::new(343.0, 1.0);
        let mut grid = TileKernelGrid::with_tile_size(32, 32, params, Backend::Cpu, 16)
            .await
            .unwrap();

        // Enable GPU compute
        grid.enable_gpu_compute().await.unwrap();
        assert!(grid.is_gpu_enabled());

        // Inject impulse at center
        grid.inject_impulse(16, 16, 1.0);

        // Run several steps using GPU
        for _ in 0..10 {
            grid.step_gpu().await.unwrap();
        }

        // Wave should have propagated
        let pressure_grid = grid.get_pressure_grid();
        let neighbor_pressure = pressure_grid[16][17];
        assert!(
            neighbor_pressure.abs() > 0.0,
            "Wave should have propagated to neighbor (GPU compute)"
        );
    }

    // CPU and GPU paths must agree cell-by-cell within float tolerance.
    #[cfg(feature = "wgpu")]
    #[tokio::test]
    #[ignore] // May not have GPU
    async fn test_tile_grid_gpu_matches_cpu() {
        let params = AcousticParams::new(343.0, 1.0);

        // Create two grids - one CPU, one GPU
        let mut grid_cpu = TileKernelGrid::with_tile_size(32, 32, params.clone(), Backend::Cpu, 16)
            .await
            .unwrap();
        let mut grid_gpu = TileKernelGrid::with_tile_size(32, 32, params, Backend::Cpu, 16)
            .await
            .unwrap();

        grid_gpu.enable_gpu_compute().await.unwrap();

        // Inject same impulse
        grid_cpu.inject_impulse(16, 16, 1.0);
        grid_gpu.inject_impulse(16, 16, 1.0);

        // Run same number of steps
        for _ in 0..5 {
            grid_cpu.step().await.unwrap();
            grid_gpu.step_gpu().await.unwrap();
        }

        // Results should match closely
        let cpu_grid = grid_cpu.get_pressure_grid();
        let gpu_grid = grid_gpu.get_pressure_grid();

        for y in 0..32 {
            for x in 0..32 {
                let diff = (cpu_grid[y][x] - gpu_grid[y][x]).abs();
                assert!(
                    diff < 1e-4,
                    "CPU/GPU mismatch at ({},{}): cpu={}, gpu={}, diff={}",
                    x,
                    y,
                    cpu_grid[y][x],
                    gpu_grid[y][x],
                    diff
                );
            }
        }
    }
}