1use super::AcousticParams;
39use ringkernel::prelude::*;
40use ringkernel_core::hlc::HlcTimestamp;
41use ringkernel_core::k2k::{DeliveryStatus, K2KBroker, K2KBuilder, K2KEndpoint};
42use ringkernel_core::message::{MessageEnvelope, MessageHeader};
43use ringkernel_core::runtime::KernelId;
44use std::collections::HashMap;
45use std::sync::Arc;
46
47#[cfg(feature = "wgpu")]
49use super::gpu_compute::{init_wgpu, TileBuffers, TileGpuComputePool};
50
51#[cfg(feature = "cuda")]
53use super::cuda_compute::CudaTileBackend;
54#[cfg(any(feature = "wgpu", feature = "cuda"))]
55use super::gpu_backend::{BoundaryCondition, Edge, FdtdParams, TileGpuBackend, TileGpuBuffers};
56#[cfg(feature = "wgpu")]
57use super::wgpu_compute::{WgpuBuffer, WgpuTileBackend};
58
/// Default edge length (in cells) of each square simulation tile.
pub const DEFAULT_TILE_SIZE: u32 = 16;
61
/// Cardinal direction of a halo (ghost-cell) exchange between neighboring tiles.
///
/// NOTE: the declaration order fixes the wire encoding — `direction as u8`
/// (0 = North, 1 = South, 2 = East, 3 = West) is written into halo message
/// payloads and decoded by `parse_halo_envelope`, so variants must not be
/// reordered.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HaloDirection {
    North,
    South,
    East,
    West,
}

impl HaloDirection {
    /// Returns the opposing direction (the edge a neighbor receives this halo on).
    pub fn opposite(self) -> Self {
        match self {
            HaloDirection::North => HaloDirection::South,
            HaloDirection::South => HaloDirection::North,
            HaloDirection::East => HaloDirection::West,
            HaloDirection::West => HaloDirection::East,
        }
    }

    /// All four directions, in the order halos are sent each step.
    pub const ALL: [HaloDirection; 4] = [
        HaloDirection::North,
        HaloDirection::South,
        HaloDirection::East,
        HaloDirection::West,
    ];
}
90
91const HALO_EXCHANGE_TYPE_ID: u64 = 200;
93
/// One tile of the simulation grid, acting as an independent actor that
/// exchanges halo (ghost-cell) edges with its four neighbors over K2K messaging.
///
/// Pressure is stored in a `(tile_size + 2)^2` buffer: the interior cells plus
/// a one-cell halo ring on every side.
pub struct TileActor {
    /// Tile column index within the tile grid.
    pub tile_x: u32,
    /// Tile row index within the tile grid.
    pub tile_y: u32,

    /// Edge length of the tile interior, in cells.
    tile_size: u32,

    /// Row stride of the padded buffers: `tile_size + 2` (halo on both sides).
    buffer_width: usize,

    /// Current pressure field, including the halo ring.
    pressure: Vec<f32>,

    /// Previous-step pressure field; `compute_fdtd` also writes the freshly
    /// computed step into it before `swap_buffers` promotes it to current.
    pressure_prev: Vec<f32>,

    /// Staging buffers (one interior-edge long) for halos received per side.
    halo_north: Vec<f32>,
    halo_south: Vec<f32>,
    halo_east: Vec<f32>,
    halo_west: Vec<f32>,

    /// Which halos arrived this step, indexed [north, south, east, west].
    halos_received: [bool; 4],

    /// Kernel ids of the four neighbors; `None` on a grid boundary.
    neighbor_north: Option<KernelId>,
    neighbor_south: Option<KernelId>,
    neighbor_east: Option<KernelId>,
    neighbor_west: Option<KernelId>,

    /// K2K endpoint used to send and receive halo messages.
    endpoint: K2KEndpoint,

    #[allow(dead_code)]
    kernel_id: KernelId,

    /// Boundary flags: true when this tile sits on the corresponding grid edge.
    is_north_boundary: bool,
    is_south_boundary: bool,
    is_east_boundary: bool,
    is_west_boundary: bool,

    /// Reflection coefficient applied at grid boundaries (1.0 = fully reflective).
    reflection_coeff: f32,
}
144
145impl TileActor {
146 pub fn new(
148 tile_x: u32,
149 tile_y: u32,
150 tile_size: u32,
151 tiles_x: u32,
152 tiles_y: u32,
153 broker: &Arc<K2KBroker>,
154 reflection_coeff: f32,
155 ) -> Self {
156 let buffer_width = (tile_size + 2) as usize;
157 let buffer_size = buffer_width * buffer_width;
158
159 let kernel_id = Self::tile_kernel_id(tile_x, tile_y);
160 let endpoint = broker.register(kernel_id.clone());
161
162 let neighbor_north = if tile_y > 0 {
164 Some(Self::tile_kernel_id(tile_x, tile_y - 1))
165 } else {
166 None
167 };
168 let neighbor_south = if tile_y < tiles_y - 1 {
169 Some(Self::tile_kernel_id(tile_x, tile_y + 1))
170 } else {
171 None
172 };
173 let neighbor_west = if tile_x > 0 {
174 Some(Self::tile_kernel_id(tile_x - 1, tile_y))
175 } else {
176 None
177 };
178 let neighbor_east = if tile_x < tiles_x - 1 {
179 Some(Self::tile_kernel_id(tile_x + 1, tile_y))
180 } else {
181 None
182 };
183
184 Self {
185 tile_x,
186 tile_y,
187 tile_size,
188 buffer_width,
189 pressure: vec![0.0; buffer_size],
190 pressure_prev: vec![0.0; buffer_size],
191 halo_north: vec![0.0; tile_size as usize],
192 halo_south: vec![0.0; tile_size as usize],
193 halo_east: vec![0.0; tile_size as usize],
194 halo_west: vec![0.0; tile_size as usize],
195 halos_received: [false; 4],
196 neighbor_north,
197 neighbor_south,
198 neighbor_east,
199 neighbor_west,
200 endpoint,
201 kernel_id,
202 is_north_boundary: tile_y == 0,
203 is_south_boundary: tile_y == tiles_y - 1,
204 is_east_boundary: tile_x == tiles_x - 1,
205 is_west_boundary: tile_x == 0,
206 reflection_coeff,
207 }
208 }
209
210 pub fn tile_kernel_id(tile_x: u32, tile_y: u32) -> KernelId {
212 KernelId::new(format!("tile_{}_{}", tile_x, tile_y))
213 }
214
    /// Maps interior coordinates (0-based, halo excluded) to an index in the
    /// padded pressure buffers; the `+ 1` offsets skip the halo ring.
    #[inline(always)]
    fn buffer_idx(&self, local_x: usize, local_y: usize) -> usize {
        (local_y + 1) * self.buffer_width + (local_x + 1)
    }
221
    /// Reads the pressure at interior cell (`local_x`, `local_y`).
    pub fn get_pressure(&self, local_x: u32, local_y: u32) -> f32 {
        self.pressure[self.buffer_idx(local_x as usize, local_y as usize)]
    }

    /// Overwrites the pressure at interior cell (`local_x`, `local_y`).
    pub fn set_pressure(&mut self, local_x: u32, local_y: u32, value: f32) {
        let idx = self.buffer_idx(local_x as usize, local_y as usize);
        self.pressure[idx] = value;
    }
232
233 fn extract_edge(&self, direction: HaloDirection) -> Vec<f32> {
235 let size = self.tile_size as usize;
236 let mut edge = vec![0.0; size];
237
238 match direction {
239 HaloDirection::North => {
240 for (x, cell) in edge.iter_mut().enumerate().take(size) {
242 *cell = self.pressure[self.buffer_idx(x, 0)];
243 }
244 }
245 HaloDirection::South => {
246 for (x, cell) in edge.iter_mut().enumerate().take(size) {
248 *cell = self.pressure[self.buffer_idx(x, size - 1)];
249 }
250 }
251 HaloDirection::West => {
252 for (y, cell) in edge.iter_mut().enumerate().take(size) {
254 *cell = self.pressure[self.buffer_idx(0, y)];
255 }
256 }
257 HaloDirection::East => {
258 for (y, cell) in edge.iter_mut().enumerate().take(size) {
260 *cell = self.pressure[self.buffer_idx(size - 1, y)];
261 }
262 }
263 }
264
265 edge
266 }
267
268 fn apply_halo(&mut self, direction: HaloDirection, data: &[f32]) {
270 let size = self.tile_size as usize;
271 let bw = self.buffer_width;
272
273 match direction {
274 HaloDirection::North => {
275 self.pressure[1..(size + 1)].copy_from_slice(&data[..size]);
277 }
278 HaloDirection::South => {
279 for (x, &val) in data.iter().enumerate().take(size) {
281 self.pressure[(size + 1) * bw + (x + 1)] = val;
282 }
283 }
284 HaloDirection::West => {
285 for (y, &val) in data.iter().enumerate().take(size) {
287 self.pressure[(y + 1) * bw] = val;
288 }
289 }
290 HaloDirection::East => {
291 for (y, &val) in data.iter().enumerate().take(size) {
293 self.pressure[(y + 1) * bw + (size + 1)] = val;
294 }
295 }
296 }
297 }
298
299 fn apply_boundary_reflection(&mut self) {
301 let size = self.tile_size as usize;
302 let bw = self.buffer_width;
303 let refl = self.reflection_coeff;
304
305 if self.is_north_boundary {
306 for x in 0..size {
308 let interior_idx = self.buffer_idx(x, 0);
309 self.pressure[x + 1] = self.pressure[interior_idx] * refl;
310 }
311 }
312
313 if self.is_south_boundary {
314 for x in 0..size {
316 let interior_idx = self.buffer_idx(x, size - 1);
317 self.pressure[(size + 1) * bw + (x + 1)] = self.pressure[interior_idx] * refl;
318 }
319 }
320
321 if self.is_west_boundary {
322 for y in 0..size {
324 let interior_idx = self.buffer_idx(0, y);
325 self.pressure[(y + 1) * bw] = self.pressure[interior_idx] * refl;
326 }
327 }
328
329 if self.is_east_boundary {
330 for y in 0..size {
332 let interior_idx = self.buffer_idx(size - 1, y);
333 self.pressure[(y + 1) * bw + (size + 1)] = self.pressure[interior_idx] * refl;
334 }
335 }
336 }
337
338 pub async fn send_halos(&self) -> Result<()> {
340 for direction in HaloDirection::ALL {
341 let neighbor = match direction {
342 HaloDirection::North => &self.neighbor_north,
343 HaloDirection::South => &self.neighbor_south,
344 HaloDirection::East => &self.neighbor_east,
345 HaloDirection::West => &self.neighbor_west,
346 };
347
348 if let Some(neighbor_id) = neighbor {
349 let edge_data = self.extract_edge(direction);
350
351 let envelope = Self::create_halo_envelope(direction.opposite(), &edge_data);
353
354 let receipt = self.endpoint.send(neighbor_id.clone(), envelope).await?;
355 if receipt.status != DeliveryStatus::Delivered {
356 tracing::warn!(
357 "Halo send failed: tile ({},{}) -> {:?}, status: {:?}",
358 self.tile_x,
359 self.tile_y,
360 neighbor_id,
361 receipt.status
362 );
363 }
364 }
365 }
366
367 Ok(())
368 }
369
370 pub fn receive_halos(&mut self) {
372 self.halos_received = [false; 4];
374
375 while let Some(message) = self.endpoint.try_receive() {
377 if let Some((direction, data)) = Self::parse_halo_envelope(&message.envelope) {
378 match direction {
380 HaloDirection::North => {
381 self.halo_north.copy_from_slice(&data);
382 self.halos_received[0] = true;
383 }
384 HaloDirection::South => {
385 self.halo_south.copy_from_slice(&data);
386 self.halos_received[1] = true;
387 }
388 HaloDirection::East => {
389 self.halo_east.copy_from_slice(&data);
390 self.halos_received[2] = true;
391 }
392 HaloDirection::West => {
393 self.halo_west.copy_from_slice(&data);
394 self.halos_received[3] = true;
395 }
396 }
397 }
398 }
399
400 if self.halos_received[0] && self.neighbor_north.is_some() {
402 self.apply_halo(HaloDirection::North, &self.halo_north.clone());
403 }
404 if self.halos_received[1] && self.neighbor_south.is_some() {
405 self.apply_halo(HaloDirection::South, &self.halo_south.clone());
406 }
407 if self.halos_received[2] && self.neighbor_east.is_some() {
408 self.apply_halo(HaloDirection::East, &self.halo_east.clone());
409 }
410 if self.halos_received[3] && self.neighbor_west.is_some() {
411 self.apply_halo(HaloDirection::West, &self.halo_west.clone());
412 }
413
414 self.apply_boundary_reflection();
416 }
417
418 fn create_halo_envelope(direction: HaloDirection, data: &[f32]) -> MessageEnvelope {
420 let mut payload = Vec::with_capacity(1 + data.len() * 4);
423 payload.push(direction as u8);
424 for &value in data {
425 payload.extend_from_slice(&value.to_le_bytes());
426 }
427
428 let header = MessageHeader::new(
429 HALO_EXCHANGE_TYPE_ID,
430 0, 0, payload.len(),
433 HlcTimestamp::now(0),
434 );
435
436 MessageEnvelope { header, payload }
437 }
438
439 fn parse_halo_envelope(envelope: &MessageEnvelope) -> Option<(HaloDirection, Vec<f32>)> {
441 if envelope.header.message_type != HALO_EXCHANGE_TYPE_ID {
442 return None;
443 }
444
445 if envelope.payload.is_empty() {
446 return None;
447 }
448
449 let direction = match envelope.payload[0] {
450 0 => HaloDirection::North,
451 1 => HaloDirection::South,
452 2 => HaloDirection::East,
453 3 => HaloDirection::West,
454 _ => return None,
455 };
456
457 let data: Vec<f32> = envelope.payload[1..]
458 .chunks_exact(4)
459 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
460 .collect();
461
462 Some((direction, data))
463 }
464
    /// Advances the interior cells by one FDTD step of the 2-D wave equation.
    ///
    /// Leapfrog update: `p_new = 2*p - p_prev + c2 * laplacian(p)`, then scaled
    /// by `damping`. The result is written into `pressure_prev`; the caller
    /// must invoke `swap_buffers` afterwards to promote it to the current
    /// field. Stencil reads at the tile edge fall into the halo ring, which
    /// must have been filled by `receive_halos` beforehand.
    ///
    /// `c2` is the squared Courant number; `damping` is a per-step amplitude
    /// factor (1.0 = lossless).
    pub fn compute_fdtd(&mut self, c2: f32, damping: f32) {
        let size = self.tile_size as usize;
        let bw = self.buffer_width;

        for local_y in 0..size {
            for local_x in 0..size {
                // Index into the padded buffer; interior starts at (1, 1).
                let idx = (local_y + 1) * bw + (local_x + 1);

                let p_curr = self.pressure[idx];
                let p_prev = self.pressure_prev[idx];

                // 4-neighbor stencil; edge reads land in the halo ring.
                let p_north = self.pressure[idx - bw];
                let p_south = self.pressure[idx + bw];
                let p_west = self.pressure[idx - 1];
                let p_east = self.pressure[idx + 1];

                let laplacian = p_north + p_south + p_east + p_west - 4.0 * p_curr;
                let p_new = 2.0 * p_curr - p_prev + c2 * laplacian;

                // Overwrite the old field in place; swap_buffers makes it current.
                self.pressure_prev[idx] = p_new * damping;
            }
        }
    }
492
    /// Promotes the freshly computed field (written into `pressure_prev` by
    /// `compute_fdtd`) to current, demoting the old current field to previous.
    pub fn swap_buffers(&mut self) {
        std::mem::swap(&mut self.pressure, &mut self.pressure_prev);
    }

    /// Zeroes both pressure fields (interior and halo ring).
    pub fn reset(&mut self) {
        self.pressure.fill(0.0);
        self.pressure_prev.fill(0.0);
    }
503}
504
/// Per-dispatch WGPU compute resources: a shared pipeline pool plus one set of
/// device buffers per tile, keyed by tile coordinates.
#[cfg(feature = "wgpu")]
struct GpuComputeResources {
    pool: TileGpuComputePool,
    tile_buffers: HashMap<(u32, u32), TileBuffers>,
}

/// GPU-persistent WGPU state: pressure fields stay device-resident between steps.
#[cfg(feature = "wgpu")]
struct WgpuPersistentState {
    backend: WgpuTileBackend,
    tile_buffers: HashMap<(u32, u32), TileGpuBuffers<WgpuBuffer>>,
}

/// GPU-persistent CUDA state: pressure fields stay device-resident between steps.
#[cfg(feature = "cuda")]
struct CudaPersistentState {
    backend: CudaTileBackend,
    tile_buffers: HashMap<(u32, u32), TileGpuBuffers<ringkernel_cuda::CudaBuffer>>,
}

/// Which GPU-persistent backend is currently active.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GpuPersistentBackend {
    Wgpu,
    Cuda,
}
543
/// Tile-based 2-D acoustic FDTD simulation grid.
///
/// The grid is partitioned into `tiles_x * tiles_y` square tiles of
/// `tile_size` cells; each tile is a [`TileActor`] exchanging halos over K2K.
pub struct TileKernelGrid {
    /// Grid width in cells.
    pub width: u32,
    /// Grid height in cells.
    pub height: u32,

    /// Edge length of each square tile, in cells.
    tile_size: u32,

    /// Number of tiles along each axis (`div_ceil` of the dimensions).
    tiles_x: u32,
    tiles_y: u32,

    /// Tile actors keyed by (tile_x, tile_y).
    tiles: HashMap<(u32, u32), TileActor>,

    /// Kernel-to-kernel broker routing halo messages between tiles.
    broker: Arc<K2KBroker>,

    /// Acoustic simulation parameters.
    pub params: AcousticParams,

    /// Compute backend selected at construction.
    backend: Backend,

    #[allow(dead_code)]
    runtime: Arc<RingKernel>,

    /// Per-dispatch WGPU resources; `Some` after `enable_gpu_compute`.
    #[cfg(feature = "wgpu")]
    gpu_compute: Option<GpuComputeResources>,

    /// GPU-persistent WGPU state; `Some` after `enable_wgpu_persistent`.
    #[cfg(feature = "wgpu")]
    wgpu_persistent: Option<WgpuPersistentState>,

    /// GPU-persistent CUDA state; `Some` after `enable_cuda_persistent`.
    #[cfg(feature = "cuda")]
    cuda_persistent: Option<CudaPersistentState>,
}
588
/// Manual `Debug`: summarizes dimensions and tile count instead of dumping
/// every tile's pressure buffers (and skips the GPU state fields).
impl std::fmt::Debug for TileKernelGrid {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TileKernelGrid")
            .field("width", &self.width)
            .field("height", &self.height)
            .field("tile_size", &self.tile_size)
            .field("tiles_x", &self.tiles_x)
            .field("tiles_y", &self.tiles_y)
            .field("tile_count", &self.tiles.len())
            .field("backend", &self.backend)
            .finish()
    }
}
602
603impl TileKernelGrid {
    /// Builds a grid using [`DEFAULT_TILE_SIZE`] tiles.
    pub async fn new(
        width: u32,
        height: u32,
        params: AcousticParams,
        backend: Backend,
    ) -> Result<Self> {
        Self::with_tile_size(width, height, params, backend, DEFAULT_TILE_SIZE).await
    }
613
    /// Builds the grid with an explicit tile size.
    ///
    /// The grid is covered by `div_ceil` tiles per axis, so the outermost
    /// tiles may extend past `width`/`height`; readers clip to the grid size.
    /// Every tile uses a fixed reflection coefficient of 0.95.
    pub async fn with_tile_size(
        width: u32,
        height: u32,
        params: AcousticParams,
        backend: Backend,
        tile_size: u32,
    ) -> Result<Self> {
        let runtime = Arc::new(RingKernel::with_backend(backend).await?);

        let tiles_x = width.div_ceil(tile_size);
        let tiles_y = height.div_ceil(tile_size);

        // Capacity heuristic: 8 pending messages per tile — presumably the 4
        // halo directions plus headroom; confirm against broker behavior.
        let broker = K2KBuilder::new()
            .max_pending_messages(tiles_x as usize * tiles_y as usize * 8)
            .build();

        let mut tiles = HashMap::new();
        for ty in 0..tiles_y {
            for tx in 0..tiles_x {
                let tile = TileActor::new(tx, ty, tile_size, tiles_x, tiles_y, &broker, 0.95);
                tiles.insert((tx, ty), tile);
            }
        }

        tracing::info!(
            "Created TileKernelGrid: {}x{} cells, {}x{} tiles ({}x{} per tile), {} tile actors",
            width,
            height,
            tiles_x,
            tiles_y,
            tile_size,
            tile_size,
            tiles.len()
        );

        Ok(Self {
            width,
            height,
            tile_size,
            tiles_x,
            tiles_y,
            tiles,
            broker,
            params,
            backend,
            runtime,
            #[cfg(feature = "wgpu")]
            gpu_compute: None,
            #[cfg(feature = "wgpu")]
            wgpu_persistent: None,
            #[cfg(feature = "cuda")]
            cuda_persistent: None,
        })
    }
672
673 fn global_to_tile_coords(&self, x: u32, y: u32) -> (u32, u32, u32, u32) {
675 let tile_x = x / self.tile_size;
676 let tile_y = y / self.tile_size;
677 let local_x = x % self.tile_size;
678 let local_y = y % self.tile_size;
679 (tile_x, tile_y, local_x, local_y)
680 }
681
    /// Advances the simulation by one timestep on the CPU.
    ///
    /// Four phases, each completed across all tiles before the next starts:
    /// 1. every tile sends its edges to its neighbors;
    /// 2. every tile drains its inbox, applies halos and boundary reflection;
    /// 3. every tile runs the FDTD stencil;
    /// 4. every tile swaps its pressure buffers.
    pub async fn step(&mut self) -> Result<()> {
        // c2 = squared Courant number; damping converts the loss parameter
        // into a per-step amplitude factor.
        let c2 = self.params.courant_number().powi(2);
        let damping = 1.0 - self.params.damping;

        for tile in self.tiles.values() {
            tile.send_halos().await?;
        }

        for tile in self.tiles.values_mut() {
            tile.receive_halos();
        }

        for tile in self.tiles.values_mut() {
            tile.compute_fdtd(c2, damping);
        }

        for tile in self.tiles.values_mut() {
            tile.swap_buffers();
        }

        Ok(())
    }
709
    /// Initializes per-dispatch WGPU compute: one compute pool plus a set of
    /// device buffers for every tile. Required before `step_gpu`.
    #[cfg(feature = "wgpu")]
    pub async fn enable_gpu_compute(&mut self) -> Result<()> {
        let (device, queue) = init_wgpu().await?;
        let pool = TileGpuComputePool::new(device, queue, self.tile_size)?;

        let mut tile_buffers = HashMap::new();
        for tile_coords in self.tiles.keys() {
            let buffers = pool.create_tile_buffers();
            tile_buffers.insert(*tile_coords, buffers);
        }

        let num_tiles = tile_buffers.len();
        self.gpu_compute = Some(GpuComputeResources { pool, tile_buffers });

        tracing::info!(
            "GPU compute enabled for TileKernelGrid: {} tiles with GPU buffers",
            num_tiles
        );

        Ok(())
    }

    /// True when per-dispatch GPU compute has been enabled.
    #[cfg(feature = "wgpu")]
    pub fn is_gpu_enabled(&self) -> bool {
        self.gpu_compute.is_some()
    }
742
    /// Advances one timestep, offloading the FDTD stencil to the GPU.
    ///
    /// Halo exchange still happens on the CPU via K2K; only the stencil runs
    /// on the device. Tiles without GPU buffers fall back to the CPU kernel.
    /// Like the CPU path, results land in `pressure_prev` and are promoted by
    /// `swap_buffers` at the end.
    #[cfg(feature = "wgpu")]
    pub async fn step_gpu(&mut self) -> Result<()> {
        let gpu = self.gpu_compute.as_ref().ok_or_else(|| {
            ringkernel_core::error::RingKernelError::BackendError(
                "GPU compute not enabled. Call enable_gpu_compute() first.".to_string(),
            )
        })?;

        let c2 = self.params.courant_number().powi(2);
        let damping = 1.0 - self.params.damping;

        for tile in self.tiles.values() {
            tile.send_halos().await?;
        }

        for tile in self.tiles.values_mut() {
            tile.receive_halos();
        }

        for ((tx, ty), tile) in self.tiles.iter_mut() {
            if let Some(buffers) = gpu.tile_buffers.get(&(*tx, *ty)) {
                let pressure = &tile.pressure;
                let pressure_prev = &tile.pressure_prev;

                let result = gpu
                    .pool
                    .compute_fdtd(buffers, pressure, pressure_prev, c2, damping);

                // Mirrors the CPU kernel: new field goes into pressure_prev.
                tile.pressure_prev.copy_from_slice(&result);
            } else {
                // No device buffers for this tile: compute on the CPU instead.
                tile.compute_fdtd(c2, damping);
            }
        }

        for tile in self.tiles.values_mut() {
            tile.swap_buffers();
        }

        Ok(())
    }
797
    /// Enables GPU-persistent WGPU compute: allocates device buffers per tile
    /// and uploads the current CPU pressure state once. Afterwards
    /// `step_wgpu_persistent` advances the simulation without per-step uploads.
    #[cfg(feature = "wgpu")]
    pub async fn enable_wgpu_persistent(&mut self) -> Result<()> {
        let backend = WgpuTileBackend::new(self.tile_size).await?;

        let mut tile_buffers = HashMap::new();
        for ((tx, ty), tile) in &self.tiles {
            let buffers = backend.create_tile_buffers(self.tile_size)?;
            backend.upload_initial_state(&buffers, &tile.pressure, &tile.pressure_prev)?;
            tile_buffers.insert((*tx, *ty), buffers);
        }

        let num_tiles = tile_buffers.len();
        self.wgpu_persistent = Some(WgpuPersistentState {
            backend,
            tile_buffers,
        });

        tracing::info!(
            "WGPU GPU-persistent compute enabled for TileKernelGrid: {} tiles",
            num_tiles
        );

        Ok(())
    }
829
    /// Enables GPU-persistent CUDA compute: allocates device buffers per tile
    /// and uploads the current CPU pressure state once. Afterwards
    /// `step_cuda_persistent` advances the simulation without per-step uploads.
    #[cfg(feature = "cuda")]
    pub fn enable_cuda_persistent(&mut self) -> Result<()> {
        let backend = CudaTileBackend::new(self.tile_size)?;

        let mut tile_buffers = HashMap::new();
        for ((tx, ty), tile) in &self.tiles {
            let buffers = backend.create_tile_buffers(self.tile_size)?;
            backend.upload_initial_state(&buffers, &tile.pressure, &tile.pressure_prev)?;
            tile_buffers.insert((*tx, *ty), buffers);
        }

        let num_tiles = tile_buffers.len();
        self.cuda_persistent = Some(CudaPersistentState {
            backend,
            tile_buffers,
        });

        tracing::info!(
            "CUDA GPU-persistent compute enabled for TileKernelGrid: {} tiles",
            num_tiles
        );

        Ok(())
    }
860
    /// True when any GPU-persistent backend (WGPU or CUDA) is active.
    pub fn is_gpu_persistent_enabled(&self) -> bool {
        #[cfg(feature = "wgpu")]
        if self.wgpu_persistent.is_some() {
            return true;
        }
        #[cfg(feature = "cuda")]
        if self.cuda_persistent.is_some() {
            return true;
        }
        false
    }

    /// Reports the active GPU-persistent backend; CUDA takes precedence when
    /// both are enabled.
    pub fn gpu_persistent_backend(&self) -> Option<GpuPersistentBackend> {
        #[cfg(feature = "cuda")]
        if self.cuda_persistent.is_some() {
            return Some(GpuPersistentBackend::Cuda);
        }
        #[cfg(feature = "wgpu")]
        if self.wgpu_persistent.is_some() {
            return Some(GpuPersistentBackend::Wgpu);
        }
        None
    }
886
887 #[cfg(feature = "wgpu")]
891 pub fn step_wgpu_persistent(&mut self) -> Result<()> {
892 let state = self.wgpu_persistent.as_ref().ok_or_else(|| {
893 ringkernel_core::error::RingKernelError::BackendError(
894 "WGPU persistent compute not enabled. Call enable_wgpu_persistent() first."
895 .to_string(),
896 )
897 })?;
898
899 let c2 = self.params.courant_number().powi(2);
900 let damping = 1.0 - self.params.damping;
901 let fdtd_params = FdtdParams::new(self.tile_size, c2, damping);
902
903 let mut halo_messages: Vec<(KernelId, HaloDirection, Vec<f32>)> = Vec::new();
906
907 for (tx, ty) in self.tiles.keys() {
908 if let Some(buffers) = state.tile_buffers.get(&(*tx, *ty)) {
909 let tile = self.tiles.get(&(*tx, *ty)).unwrap();
911
912 if tile.neighbor_north.is_some() {
913 let halo = state.backend.extract_halo(buffers, Edge::North)?;
914 halo_messages.push((
915 tile.neighbor_north.clone().unwrap(),
916 HaloDirection::South,
917 halo,
918 ));
919 }
920 if tile.neighbor_south.is_some() {
921 let halo = state.backend.extract_halo(buffers, Edge::South)?;
922 halo_messages.push((
923 tile.neighbor_south.clone().unwrap(),
924 HaloDirection::North,
925 halo,
926 ));
927 }
928 if tile.neighbor_west.is_some() {
929 let halo = state.backend.extract_halo(buffers, Edge::West)?;
930 halo_messages.push((
931 tile.neighbor_west.clone().unwrap(),
932 HaloDirection::East,
933 halo,
934 ));
935 }
936 if tile.neighbor_east.is_some() {
937 let halo = state.backend.extract_halo(buffers, Edge::East)?;
938 halo_messages.push((
939 tile.neighbor_east.clone().unwrap(),
940 HaloDirection::West,
941 halo,
942 ));
943 }
944 }
945 }
946
947 for (dest_id, direction, halo_data) in halo_messages {
950 for (tx, ty) in self.tiles.keys() {
952 if TileActor::tile_kernel_id(*tx, *ty) == dest_id {
953 if let Some(buffers) = state.tile_buffers.get(&(*tx, *ty)) {
954 let edge = match direction {
955 HaloDirection::North => Edge::North,
956 HaloDirection::South => Edge::South,
957 HaloDirection::East => Edge::East,
958 HaloDirection::West => Edge::West,
959 };
960 state.backend.inject_halo(buffers, edge, &halo_data)?;
961 }
962 break;
963 }
964 }
965 }
966
967 for (tx, ty) in self.tiles.keys() {
970 let tile = self.tiles.get(&(*tx, *ty)).unwrap();
971 if let Some(buffers) = state.tile_buffers.get(&(*tx, *ty)) {
972 if tile.neighbor_north.is_none() {
974 state.backend.apply_boundary(
975 buffers,
976 Edge::North,
977 BoundaryCondition::Absorbing,
978 )?;
979 }
980 if tile.neighbor_south.is_none() {
981 state.backend.apply_boundary(
982 buffers,
983 Edge::South,
984 BoundaryCondition::Absorbing,
985 )?;
986 }
987 if tile.neighbor_west.is_none() {
988 state.backend.apply_boundary(
989 buffers,
990 Edge::West,
991 BoundaryCondition::Absorbing,
992 )?;
993 }
994 if tile.neighbor_east.is_none() {
995 state.backend.apply_boundary(
996 buffers,
997 Edge::East,
998 BoundaryCondition::Absorbing,
999 )?;
1000 }
1001 }
1002 }
1003
1004 for (tx, ty) in self.tiles.keys() {
1006 if let Some(buffers) = state.tile_buffers.get(&(*tx, *ty)) {
1007 state.backend.fdtd_step(buffers, &fdtd_params)?;
1008 }
1009 }
1010
1011 state.backend.synchronize()?;
1013
1014 let state = self.wgpu_persistent.as_mut().unwrap();
1017 for buffers in state.tile_buffers.values_mut() {
1018 state.backend.swap_buffers(buffers);
1019 }
1020
1021 Ok(())
1022 }
1023
1024 #[cfg(feature = "cuda")]
1028 pub fn step_cuda_persistent(&mut self) -> Result<()> {
1029 let state = self.cuda_persistent.as_ref().ok_or_else(|| {
1030 ringkernel_core::error::RingKernelError::BackendError(
1031 "CUDA persistent compute not enabled. Call enable_cuda_persistent() first."
1032 .to_string(),
1033 )
1034 })?;
1035
1036 let c2 = self.params.courant_number().powi(2);
1037 let damping = 1.0 - self.params.damping;
1038 let fdtd_params = FdtdParams::new(self.tile_size, c2, damping);
1039
1040 let mut halo_messages: Vec<(KernelId, HaloDirection, Vec<f32>)> = Vec::new();
1042
1043 for (tx, ty) in self.tiles.keys() {
1044 if let Some(buffers) = state.tile_buffers.get(&(*tx, *ty)) {
1045 let tile = self.tiles.get(&(*tx, *ty)).unwrap();
1046
1047 if tile.neighbor_north.is_some() {
1048 let halo = state.backend.extract_halo(buffers, Edge::North)?;
1049 halo_messages.push((
1050 tile.neighbor_north.clone().unwrap(),
1051 HaloDirection::South,
1052 halo,
1053 ));
1054 }
1055 if tile.neighbor_south.is_some() {
1056 let halo = state.backend.extract_halo(buffers, Edge::South)?;
1057 halo_messages.push((
1058 tile.neighbor_south.clone().unwrap(),
1059 HaloDirection::North,
1060 halo,
1061 ));
1062 }
1063 if tile.neighbor_west.is_some() {
1064 let halo = state.backend.extract_halo(buffers, Edge::West)?;
1065 halo_messages.push((
1066 tile.neighbor_west.clone().unwrap(),
1067 HaloDirection::East,
1068 halo,
1069 ));
1070 }
1071 if tile.neighbor_east.is_some() {
1072 let halo = state.backend.extract_halo(buffers, Edge::East)?;
1073 halo_messages.push((
1074 tile.neighbor_east.clone().unwrap(),
1075 HaloDirection::West,
1076 halo,
1077 ));
1078 }
1079 }
1080 }
1081
1082 for (dest_id, direction, halo_data) in halo_messages {
1084 for (tx, ty) in self.tiles.keys() {
1085 if TileActor::tile_kernel_id(*tx, *ty) == dest_id {
1086 if let Some(buffers) = state.tile_buffers.get(&(*tx, *ty)) {
1087 let edge = match direction {
1088 HaloDirection::North => Edge::North,
1089 HaloDirection::South => Edge::South,
1090 HaloDirection::East => Edge::East,
1091 HaloDirection::West => Edge::West,
1092 };
1093 state.backend.inject_halo(buffers, edge, &halo_data)?;
1094 }
1095 break;
1096 }
1097 }
1098 }
1099
1100 for (tx, ty) in self.tiles.keys() {
1102 if let Some(buffers) = state.tile_buffers.get(&(*tx, *ty)) {
1103 state.backend.fdtd_step(buffers, &fdtd_params)?;
1104 }
1105 }
1106
1107 state.backend.synchronize()?;
1109
1110 let state = self.cuda_persistent.as_mut().unwrap();
1112 for buffers in state.tile_buffers.values_mut() {
1113 state.backend.swap_buffers(buffers);
1114 }
1115
1116 Ok(())
1117 }
1118
    /// Downloads the full pressure field from WGPU device memory and
    /// reassembles it as `grid[y][x]`, clipping tiles that extend past the
    /// grid edge (the last row/column of tiles when dimensions are not a
    /// multiple of `tile_size`).
    #[cfg(feature = "wgpu")]
    pub fn read_pressure_from_wgpu(&self) -> Result<Vec<Vec<f32>>> {
        let state = self.wgpu_persistent.as_ref().ok_or_else(|| {
            ringkernel_core::error::RingKernelError::BackendError(
                "WGPU persistent compute not enabled.".to_string(),
            )
        })?;

        let mut grid = vec![vec![0.0; self.width as usize]; self.height as usize];

        for ((tx, ty), buffers) in &state.tile_buffers {
            // Interior cells only (halo ring excluded), row-major per tile.
            let interior = state.backend.read_interior_pressure(buffers)?;

            let tile_start_x = tx * self.tile_size;
            let tile_start_y = ty * self.tile_size;

            for ly in 0..self.tile_size {
                for lx in 0..self.tile_size {
                    let gx = tile_start_x + lx;
                    let gy = tile_start_y + ly;
                    if gx < self.width && gy < self.height {
                        grid[gy as usize][gx as usize] =
                            interior[(ly * self.tile_size + lx) as usize];
                    }
                }
            }
        }

        Ok(grid)
    }

    /// CUDA counterpart of `read_pressure_from_wgpu`: downloads and
    /// reassembles the device-resident pressure field as `grid[y][x]`.
    #[cfg(feature = "cuda")]
    pub fn read_pressure_from_cuda(&self) -> Result<Vec<Vec<f32>>> {
        let state = self.cuda_persistent.as_ref().ok_or_else(|| {
            ringkernel_core::error::RingKernelError::BackendError(
                "CUDA persistent compute not enabled.".to_string(),
            )
        })?;

        let mut grid = vec![vec![0.0; self.width as usize]; self.height as usize];

        for ((tx, ty), buffers) in &state.tile_buffers {
            let interior = state.backend.read_interior_pressure(buffers)?;

            let tile_start_x = tx * self.tile_size;
            let tile_start_y = ty * self.tile_size;

            for ly in 0..self.tile_size {
                for lx in 0..self.tile_size {
                    let gx = tile_start_x + lx;
                    let gy = tile_start_y + ly;
                    if gx < self.width && gy < self.height {
                        grid[gy as usize][gx as usize] =
                            interior[(ly * self.tile_size + lx) as usize];
                    }
                }
            }
        }

        Ok(grid)
    }
1186
    /// Adds `amplitude` to the pressure at global cell (`x`, `y`) and, when the
    /// WGPU-persistent backend is active, re-uploads the affected tile's state
    /// so the device copy stays in sync. Out-of-range coordinates are ignored.
    #[cfg(feature = "wgpu")]
    pub fn inject_impulse_wgpu(&mut self, x: u32, y: u32, amplitude: f32) -> Result<()> {
        if x >= self.width || y >= self.height {
            return Ok(());
        }

        let (tile_x, tile_y, local_x, local_y) = self.global_to_tile_coords(x, y);

        if let Some(tile) = self.tiles.get_mut(&(tile_x, tile_y)) {
            let current = tile.get_pressure(local_x, local_y);
            tile.set_pressure(local_x, local_y, current + amplitude);

            // Keep the device-resident copy consistent with the CPU state.
            if let Some(state) = &self.wgpu_persistent {
                if let Some(buffers) = state.tile_buffers.get(&(tile_x, tile_y)) {
                    state.backend.upload_initial_state(
                        buffers,
                        &tile.pressure,
                        &tile.pressure_prev,
                    )?;
                }
            }
        }

        Ok(())
    }

    /// CUDA counterpart of `inject_impulse_wgpu`: adds `amplitude` at the cell
    /// and re-uploads the affected tile to the device when CUDA is active.
    #[cfg(feature = "cuda")]
    pub fn inject_impulse_cuda(&mut self, x: u32, y: u32, amplitude: f32) -> Result<()> {
        if x >= self.width || y >= self.height {
            return Ok(());
        }

        let (tile_x, tile_y, local_x, local_y) = self.global_to_tile_coords(x, y);

        if let Some(tile) = self.tiles.get_mut(&(tile_x, tile_y)) {
            let current = tile.get_pressure(local_x, local_y);
            tile.set_pressure(local_x, local_y, current + amplitude);

            // Keep the device-resident copy consistent with the CPU state.
            if let Some(state) = &self.cuda_persistent {
                if let Some(buffers) = state.tile_buffers.get(&(tile_x, tile_y)) {
                    state.backend.upload_initial_state(
                        buffers,
                        &tile.pressure,
                        &tile.pressure_prev,
                    )?;
                }
            }
        }

        Ok(())
    }
1244
1245 pub fn inject_impulse(&mut self, x: u32, y: u32, amplitude: f32) {
1247 if x >= self.width || y >= self.height {
1248 return;
1249 }
1250
1251 let (tile_x, tile_y, local_x, local_y) = self.global_to_tile_coords(x, y);
1252
1253 if let Some(tile) = self.tiles.get_mut(&(tile_x, tile_y)) {
1254 let current = tile.get_pressure(local_x, local_y);
1255 tile.set_pressure(local_x, local_y, current + amplitude);
1256 }
1257 }
1258
1259 pub fn get_pressure_grid(&self) -> Vec<Vec<f32>> {
1261 let mut grid = vec![vec![0.0; self.width as usize]; self.height as usize];
1262
1263 for y in 0..self.height {
1264 for x in 0..self.width {
1265 let (tile_x, tile_y, local_x, local_y) = self.global_to_tile_coords(x, y);
1266
1267 if let Some(tile) = self.tiles.get(&(tile_x, tile_y)) {
1268 if local_x < self.tile_size && local_y < self.tile_size {
1270 grid[y as usize][x as usize] = tile.get_pressure(local_x, local_y);
1271 }
1272 }
1273 }
1274 }
1275
1276 grid
1277 }
1278
1279 pub fn max_pressure(&self) -> f32 {
1281 self.tiles
1282 .values()
1283 .flat_map(|tile| {
1284 (0..self.tile_size).flat_map(move |y| {
1285 (0..self.tile_size).map(move |x| tile.get_pressure(x, y).abs())
1286 })
1287 })
1288 .fold(0.0, f32::max)
1289 }
1290
1291 pub fn total_energy(&self) -> f32 {
1293 self.tiles
1294 .values()
1295 .flat_map(|tile| {
1296 (0..self.tile_size).flat_map(move |y| {
1297 (0..self.tile_size).map(move |x| {
1298 let p = tile.get_pressure(x, y);
1299 p * p
1300 })
1301 })
1302 })
1303 .sum()
1304 }
1305
    /// Zeroes all CPU-side pressure fields in every tile. GPU-resident
    /// buffers, if any, are not cleared here.
    pub fn reset(&mut self) {
        for tile in self.tiles.values_mut() {
            tile.reset();
        }
    }

    /// Total number of grid cells (`width * height`).
    pub fn cell_count(&self) -> usize {
        (self.width * self.height) as usize
    }

    /// Number of tile actors in the grid.
    pub fn tile_count(&self) -> usize {
        self.tiles.len()
    }

    /// The compute backend selected at construction.
    pub fn backend(&self) -> Backend {
        self.backend
    }

    /// Snapshot of kernel-to-kernel messaging statistics from the broker.
    pub fn k2k_stats(&self) -> ringkernel_core::k2k::K2KStats {
        self.broker.stats()
    }

    /// Updates the speed of sound in the acoustic parameters.
    pub fn set_speed_of_sound(&mut self, speed: f32) {
        self.params.set_speed_of_sound(speed);
    }

    /// Updates the physical cell size in the acoustic parameters.
    pub fn set_cell_size(&mut self, size: f32) {
        self.params.set_cell_size(size);
    }
1342
1343 pub async fn resize(&mut self, new_width: u32, new_height: u32) -> Result<()> {
1345 self.width = new_width;
1346 self.height = new_height;
1347
1348 self.tiles_x = new_width.div_ceil(self.tile_size);
1350 self.tiles_y = new_height.div_ceil(self.tile_size);
1351
1352 self.broker = K2KBuilder::new()
1354 .max_pending_messages(self.tiles_x as usize * self.tiles_y as usize * 8)
1355 .build();
1356
1357 self.tiles.clear();
1359 for ty in 0..self.tiles_y {
1360 for tx in 0..self.tiles_x {
1361 let tile = TileActor::new(
1362 tx,
1363 ty,
1364 self.tile_size,
1365 self.tiles_x,
1366 self.tiles_y,
1367 &self.broker,
1368 0.95,
1369 );
1370 self.tiles.insert((tx, ty), tile);
1371 }
1372 }
1373
1374 tracing::info!(
1375 "Resized TileKernelGrid: {}x{} cells, {} tile actors",
1376 new_width,
1377 new_height,
1378 self.tiles.len()
1379 );
1380
1381 Ok(())
1382 }
1383
    /// Shuts the grid down. Currently a no-op beyond consuming `self`: tiles,
    /// broker, and runtime are released when `self` is dropped.
    pub async fn shutdown(self) -> Result<()> {
        Ok(())
    }
1389}
1390
#[cfg(test)]
mod tests {
    use super::*;

    // 64x64 grid with default 16-cell tiles should yield a 4x4 tile layout.
    #[tokio::test]
    async fn test_tile_grid_creation() {
        let params = AcousticParams::new(343.0, 1.0);
        let grid = TileKernelGrid::new(64, 64, params, Backend::Cpu)
            .await
            .unwrap();

        assert_eq!(grid.width, 64);
        assert_eq!(grid.height, 64);
        assert_eq!(grid.tile_size, DEFAULT_TILE_SIZE);
        assert_eq!(grid.tiles_x, 4); assert_eq!(grid.tiles_y, 4);
        assert_eq!(grid.tile_count(), 16); }

    // An injected impulse must be visible at the same cell when read back.
    #[tokio::test]
    async fn test_tile_grid_impulse() {
        let params = AcousticParams::new(343.0, 1.0);
        let mut grid = TileKernelGrid::with_tile_size(32, 32, params, Backend::Cpu, 16)
            .await
            .unwrap();

        grid.inject_impulse(16, 16, 1.0);

        let pressure_grid = grid.get_pressure_grid();
        assert_eq!(pressure_grid[16][16], 1.0);
    }

    // After several CPU steps the wave must reach a neighboring cell.
    #[tokio::test]
    async fn test_tile_grid_step() {
        let params = AcousticParams::new(343.0, 1.0);
        let mut grid = TileKernelGrid::with_tile_size(32, 32, params, Backend::Cpu, 16)
            .await
            .unwrap();

        grid.inject_impulse(16, 16, 1.0);

        for _ in 0..10 {
            grid.step().await.unwrap();
        }

        let pressure_grid = grid.get_pressure_grid();
        let neighbor_pressure = pressure_grid[16][17];
        assert!(
            neighbor_pressure.abs() > 0.0,
            "Wave should have propagated to neighbor"
        );
    }

    // A step on a multi-tile grid must exchange K2K halo messages.
    #[tokio::test]
    async fn test_tile_grid_k2k_stats() {
        let params = AcousticParams::new(343.0, 1.0);
        let mut grid = TileKernelGrid::with_tile_size(32, 32, params, Backend::Cpu, 16)
            .await
            .unwrap();

        grid.inject_impulse(16, 16, 1.0);
        grid.step().await.unwrap();

        let stats = grid.k2k_stats();
        assert!(
            stats.messages_delivered > 0,
            "K2K messages should have been exchanged"
        );
    }

    // reset() must zero every cell after activity.
    #[tokio::test]
    async fn test_tile_grid_reset() {
        let params = AcousticParams::new(343.0, 1.0);
        let mut grid = TileKernelGrid::with_tile_size(32, 32, params, Backend::Cpu, 16)
            .await
            .unwrap();

        grid.inject_impulse(16, 16, 1.0);
        grid.step().await.unwrap();

        grid.reset();

        let pressure_grid = grid.get_pressure_grid();
        for row in pressure_grid {
            for p in row {
                assert_eq!(p, 0.0);
            }
        }
    }

    // An impulse near the grid corner must not blow up the energy.
    #[tokio::test]
    async fn test_tile_boundary_handling() {
        let params = AcousticParams::new(343.0, 1.0);
        let mut grid = TileKernelGrid::with_tile_size(32, 32, params, Backend::Cpu, 16)
            .await
            .unwrap();

        grid.inject_impulse(1, 1, 1.0);

        for _ in 0..5 {
            grid.step().await.unwrap();
        }

        let energy = grid.total_energy();
        assert!(energy.is_finite(), "Energy should be finite");
    }

    // Ignored by default: requires a WGPU-capable device.
    #[cfg(feature = "wgpu")]
    #[tokio::test]
    #[ignore] async fn test_tile_grid_gpu_step() {
        let params = AcousticParams::new(343.0, 1.0);
        let mut grid = TileKernelGrid::with_tile_size(32, 32, params, Backend::Cpu, 16)
            .await
            .unwrap();

        grid.enable_gpu_compute().await.unwrap();
        assert!(grid.is_gpu_enabled());

        grid.inject_impulse(16, 16, 1.0);

        for _ in 0..10 {
            grid.step_gpu().await.unwrap();
        }

        let pressure_grid = grid.get_pressure_grid();
        let neighbor_pressure = pressure_grid[16][17];
        assert!(
            neighbor_pressure.abs() > 0.0,
            "Wave should have propagated to neighbor (GPU compute)"
        );
    }

    // Ignored by default: requires a WGPU-capable device. Compares GPU output
    // against the CPU reference cell-by-cell within a small tolerance.
    #[cfg(feature = "wgpu")]
    #[tokio::test]
    #[ignore] async fn test_tile_grid_gpu_matches_cpu() {
        let params = AcousticParams::new(343.0, 1.0);

        let mut grid_cpu = TileKernelGrid::with_tile_size(32, 32, params.clone(), Backend::Cpu, 16)
            .await
            .unwrap();
        let mut grid_gpu = TileKernelGrid::with_tile_size(32, 32, params, Backend::Cpu, 16)
            .await
            .unwrap();

        grid_gpu.enable_gpu_compute().await.unwrap();

        grid_cpu.inject_impulse(16, 16, 1.0);
        grid_gpu.inject_impulse(16, 16, 1.0);

        for _ in 0..5 {
            grid_cpu.step().await.unwrap();
            grid_gpu.step_gpu().await.unwrap();
        }

        let cpu_grid = grid_cpu.get_pressure_grid();
        let gpu_grid = grid_gpu.get_pressure_grid();

        for y in 0..32 {
            for x in 0..32 {
                let diff = (cpu_grid[y][x] - gpu_grid[y][x]).abs();
                assert!(
                    diff < 1e-4,
                    "CPU/GPU mismatch at ({},{}): cpu={}, gpu={}, diff={}",
                    x,
                    y,
                    cpu_grid[y][x],
                    gpu_grid[y][x],
                    diff
                );
            }
        }
    }
}