1use super::AcousticParams;
39use ringkernel::prelude::*;
40use ringkernel_core::hlc::HlcTimestamp;
41use ringkernel_core::k2k::{DeliveryStatus, K2KBroker, K2KBuilder, K2KEndpoint};
42use ringkernel_core::message::{MessageEnvelope, MessageHeader};
43use ringkernel_core::runtime::KernelId;
44use std::collections::HashMap;
45use std::sync::Arc;
46
47#[cfg(feature = "wgpu")]
49use super::gpu_compute::{init_wgpu, TileBuffers, TileGpuComputePool};
50
51#[cfg(feature = "cuda")]
53use super::cuda_compute::CudaTileBackend;
54#[cfg(any(feature = "wgpu", feature = "cuda"))]
55use super::gpu_backend::{Edge, FdtdParams, TileGpuBackend, TileGpuBuffers};
56#[cfg(feature = "wgpu")]
57use super::wgpu_compute::{WgpuBuffer, WgpuTileBackend};
58
/// Default edge length, in cells, of each square simulation tile.
pub const DEFAULT_TILE_SIZE: u32 = 16;
61
/// Cardinal direction of a halo (ghost-cell) exchange between tiles.
///
/// The discriminant order (North=0, South=1, East=2, West=3) is part of the
/// wire format: `create_halo_envelope` serializes `direction as u8` and
/// `parse_halo_envelope` decodes the same tags.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HaloDirection {
    North,
    South,
    East,
    West,
}

impl HaloDirection {
    /// All four directions, in discriminant order.
    pub const ALL: [HaloDirection; 4] = [
        HaloDirection::North,
        HaloDirection::South,
        HaloDirection::East,
        HaloDirection::West,
    ];

    /// The direction pointing back at the sender (N<->S, E<->W).
    pub fn opposite(self) -> Self {
        use HaloDirection::{East, North, South, West};
        match self {
            North => South,
            South => North,
            East => West,
            West => East,
        }
    }
}
90
/// K2K message-type id tagging halo-exchange envelopes.
const HALO_EXCHANGE_TYPE_ID: u64 = 200;
93
/// One FDTD tile: a `tile_size x tile_size` patch of the pressure field,
/// padded with a one-cell ghost border, that exchanges edge (halo) data with
/// its four neighbors via K2K messages.
pub struct TileActor {
    // Tile column in the tile grid.
    pub tile_x: u32,
    // Tile row in the tile grid.
    pub tile_y: u32,

    // Interior edge length in cells (excluding the ghost border).
    tile_size: u32,

    // Row stride of the padded buffers: `(tile_size + 2)`.
    buffer_width: usize,

    // Current pressure field, `(tile_size + 2)^2`, ghost border included.
    pressure: Vec<f32>,

    // Previous-step field; `compute_fdtd` writes the next step here before
    // `swap_buffers` promotes it.
    pressure_prev: Vec<f32>,

    // Staging buffers (length `tile_size`) for halo data received from each
    // neighbor, applied after the mailbox is drained.
    halo_north: Vec<f32>,
    halo_south: Vec<f32>,
    halo_east: Vec<f32>,
    halo_west: Vec<f32>,

    // Which halos arrived this step, indexed N=0, S=1, E=2, W=3.
    halos_received: [bool; 4],

    // Neighbor kernel ids; `None` on domain edges.
    neighbor_north: Option<KernelId>,
    neighbor_south: Option<KernelId>,
    neighbor_east: Option<KernelId>,
    neighbor_west: Option<KernelId>,

    // This tile's K2K mailbox, registered with the shared broker.
    endpoint: K2KEndpoint,

    // Id this tile registered under; kept for debugging.
    #[allow(dead_code)]
    kernel_id: KernelId,

    // Whether each side touches the domain boundary (no neighbor there).
    is_north_boundary: bool,
    is_south_boundary: bool,
    is_east_boundary: bool,
    is_west_boundary: bool,

    // Amplitude factor applied to pressure reflected at domain boundaries.
    reflection_coeff: f32,
}
144
145impl TileActor {
146 pub fn new(
148 tile_x: u32,
149 tile_y: u32,
150 tile_size: u32,
151 tiles_x: u32,
152 tiles_y: u32,
153 broker: &Arc<K2KBroker>,
154 reflection_coeff: f32,
155 ) -> Self {
156 let buffer_width = (tile_size + 2) as usize;
157 let buffer_size = buffer_width * buffer_width;
158
159 let kernel_id = Self::tile_kernel_id(tile_x, tile_y);
160 let endpoint = broker.register(kernel_id.clone());
161
162 let neighbor_north = if tile_y > 0 {
164 Some(Self::tile_kernel_id(tile_x, tile_y - 1))
165 } else {
166 None
167 };
168 let neighbor_south = if tile_y < tiles_y - 1 {
169 Some(Self::tile_kernel_id(tile_x, tile_y + 1))
170 } else {
171 None
172 };
173 let neighbor_west = if tile_x > 0 {
174 Some(Self::tile_kernel_id(tile_x - 1, tile_y))
175 } else {
176 None
177 };
178 let neighbor_east = if tile_x < tiles_x - 1 {
179 Some(Self::tile_kernel_id(tile_x + 1, tile_y))
180 } else {
181 None
182 };
183
184 Self {
185 tile_x,
186 tile_y,
187 tile_size,
188 buffer_width,
189 pressure: vec![0.0; buffer_size],
190 pressure_prev: vec![0.0; buffer_size],
191 halo_north: vec![0.0; tile_size as usize],
192 halo_south: vec![0.0; tile_size as usize],
193 halo_east: vec![0.0; tile_size as usize],
194 halo_west: vec![0.0; tile_size as usize],
195 halos_received: [false; 4],
196 neighbor_north,
197 neighbor_south,
198 neighbor_east,
199 neighbor_west,
200 endpoint,
201 kernel_id,
202 is_north_boundary: tile_y == 0,
203 is_south_boundary: tile_y == tiles_y - 1,
204 is_east_boundary: tile_x == tiles_x - 1,
205 is_west_boundary: tile_x == 0,
206 reflection_coeff,
207 }
208 }
209
    /// Deterministic kernel id (`"tile_{x}_{y}"`) for the tile at
    /// `(tile_x, tile_y)` — used both for registration and for addressing
    /// neighbors.
    pub fn tile_kernel_id(tile_x: u32, tile_y: u32) -> KernelId {
        KernelId::new(format!("tile_{}_{}", tile_x, tile_y))
    }
214
215 #[inline(always)]
217 fn buffer_idx(&self, local_x: usize, local_y: usize) -> usize {
218 (local_y + 1) * self.buffer_width + (local_x + 1)
220 }
221
222 pub fn get_pressure(&self, local_x: u32, local_y: u32) -> f32 {
224 self.pressure[self.buffer_idx(local_x as usize, local_y as usize)]
225 }
226
227 pub fn set_pressure(&mut self, local_x: u32, local_y: u32, value: f32) {
229 let idx = self.buffer_idx(local_x as usize, local_y as usize);
230 self.pressure[idx] = value;
231 }
232
233 fn extract_edge(&self, direction: HaloDirection) -> Vec<f32> {
235 let size = self.tile_size as usize;
236 let mut edge = vec![0.0; size];
237
238 match direction {
239 HaloDirection::North => {
240 for (x, cell) in edge.iter_mut().enumerate().take(size) {
242 *cell = self.pressure[self.buffer_idx(x, 0)];
243 }
244 }
245 HaloDirection::South => {
246 for (x, cell) in edge.iter_mut().enumerate().take(size) {
248 *cell = self.pressure[self.buffer_idx(x, size - 1)];
249 }
250 }
251 HaloDirection::West => {
252 for (y, cell) in edge.iter_mut().enumerate().take(size) {
254 *cell = self.pressure[self.buffer_idx(0, y)];
255 }
256 }
257 HaloDirection::East => {
258 for (y, cell) in edge.iter_mut().enumerate().take(size) {
260 *cell = self.pressure[self.buffer_idx(size - 1, y)];
261 }
262 }
263 }
264
265 edge
266 }
267
268 fn apply_halo(&mut self, direction: HaloDirection, data: &[f32]) {
270 let size = self.tile_size as usize;
271 let bw = self.buffer_width;
272
273 match direction {
274 HaloDirection::North => {
275 self.pressure[1..(size + 1)].copy_from_slice(&data[..size]);
277 }
278 HaloDirection::South => {
279 for (x, &val) in data.iter().enumerate().take(size) {
281 self.pressure[(size + 1) * bw + (x + 1)] = val;
282 }
283 }
284 HaloDirection::West => {
285 for (y, &val) in data.iter().enumerate().take(size) {
287 self.pressure[(y + 1) * bw] = val;
288 }
289 }
290 HaloDirection::East => {
291 for (y, &val) in data.iter().enumerate().take(size) {
293 self.pressure[(y + 1) * bw + (size + 1)] = val;
294 }
295 }
296 }
297 }
298
    /// Fill ghost cells on domain-boundary sides with a scaled copy of the
    /// adjacent interior cell, approximating a partially reflective wall.
    ///
    /// Runs after halos are applied, so only sides without a neighbor are
    /// affected.
    fn apply_boundary_reflection(&mut self) {
        let size = self.tile_size as usize;
        let bw = self.buffer_width;
        let refl = self.reflection_coeff;

        if self.is_north_boundary {
            for x in 0..size {
                let interior_idx = self.buffer_idx(x, 0);
                // Ghost row 0 occupies indices 1..=size of the padded buffer.
                self.pressure[x + 1] = self.pressure[interior_idx] * refl;
            }
        }

        if self.is_south_boundary {
            for x in 0..size {
                let interior_idx = self.buffer_idx(x, size - 1);
                // Ghost row below the interior.
                self.pressure[(size + 1) * bw + (x + 1)] = self.pressure[interior_idx] * refl;
            }
        }

        if self.is_west_boundary {
            for y in 0..size {
                let interior_idx = self.buffer_idx(0, y);
                // Ghost column at the start of each row.
                self.pressure[(y + 1) * bw] = self.pressure[interior_idx] * refl;
            }
        }

        if self.is_east_boundary {
            for y in 0..size {
                let interior_idx = self.buffer_idx(size - 1, y);
                // Ghost column at the end of each row.
                self.pressure[(y + 1) * bw + (size + 1)] = self.pressure[interior_idx] * refl;
            }
        }
    }
337
338 pub async fn send_halos(&self) -> Result<()> {
340 for direction in HaloDirection::ALL {
341 let neighbor = match direction {
342 HaloDirection::North => &self.neighbor_north,
343 HaloDirection::South => &self.neighbor_south,
344 HaloDirection::East => &self.neighbor_east,
345 HaloDirection::West => &self.neighbor_west,
346 };
347
348 if let Some(neighbor_id) = neighbor {
349 let edge_data = self.extract_edge(direction);
350
351 let envelope = Self::create_halo_envelope(direction.opposite(), &edge_data);
353
354 let receipt = self.endpoint.send(neighbor_id.clone(), envelope).await?;
355 if receipt.status != DeliveryStatus::Delivered {
356 tracing::warn!(
357 "Halo send failed: tile ({},{}) -> {:?}, status: {:?}",
358 self.tile_x,
359 self.tile_y,
360 neighbor_id,
361 receipt.status
362 );
363 }
364 }
365 }
366
367 Ok(())
368 }
369
370 pub fn receive_halos(&mut self) {
372 self.halos_received = [false; 4];
374
375 while let Some(message) = self.endpoint.try_receive() {
377 if let Some((direction, data)) = Self::parse_halo_envelope(&message.envelope) {
378 match direction {
380 HaloDirection::North => {
381 self.halo_north.copy_from_slice(&data);
382 self.halos_received[0] = true;
383 }
384 HaloDirection::South => {
385 self.halo_south.copy_from_slice(&data);
386 self.halos_received[1] = true;
387 }
388 HaloDirection::East => {
389 self.halo_east.copy_from_slice(&data);
390 self.halos_received[2] = true;
391 }
392 HaloDirection::West => {
393 self.halo_west.copy_from_slice(&data);
394 self.halos_received[3] = true;
395 }
396 }
397 }
398 }
399
400 if self.halos_received[0] && self.neighbor_north.is_some() {
402 self.apply_halo(HaloDirection::North, &self.halo_north.clone());
403 }
404 if self.halos_received[1] && self.neighbor_south.is_some() {
405 self.apply_halo(HaloDirection::South, &self.halo_south.clone());
406 }
407 if self.halos_received[2] && self.neighbor_east.is_some() {
408 self.apply_halo(HaloDirection::East, &self.halo_east.clone());
409 }
410 if self.halos_received[3] && self.neighbor_west.is_some() {
411 self.apply_halo(HaloDirection::West, &self.halo_west.clone());
412 }
413
414 self.apply_boundary_reflection();
416 }
417
418 fn create_halo_envelope(direction: HaloDirection, data: &[f32]) -> MessageEnvelope {
420 let mut payload = Vec::with_capacity(1 + data.len() * 4);
423 payload.push(direction as u8);
424 for &value in data {
425 payload.extend_from_slice(&value.to_le_bytes());
426 }
427
428 let header = MessageHeader::new(
429 HALO_EXCHANGE_TYPE_ID,
430 0, 0, payload.len(),
433 HlcTimestamp::now(0),
434 );
435
436 MessageEnvelope { header, payload }
437 }
438
439 fn parse_halo_envelope(envelope: &MessageEnvelope) -> Option<(HaloDirection, Vec<f32>)> {
441 if envelope.header.message_type != HALO_EXCHANGE_TYPE_ID {
442 return None;
443 }
444
445 if envelope.payload.is_empty() {
446 return None;
447 }
448
449 let direction = match envelope.payload[0] {
450 0 => HaloDirection::North,
451 1 => HaloDirection::South,
452 2 => HaloDirection::East,
453 3 => HaloDirection::West,
454 _ => return None,
455 };
456
457 let data: Vec<f32> = envelope.payload[1..]
458 .chunks_exact(4)
459 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
460 .collect();
461
462 Some((direction, data))
463 }
464
465 pub fn compute_fdtd(&mut self, c2: f32, damping: f32) {
467 let size = self.tile_size as usize;
468 let bw = self.buffer_width;
469
470 for local_y in 0..size {
472 for local_x in 0..size {
473 let idx = (local_y + 1) * bw + (local_x + 1);
474
475 let p_curr = self.pressure[idx];
476 let p_prev = self.pressure_prev[idx];
477
478 let p_north = self.pressure[idx - bw];
480 let p_south = self.pressure[idx + bw];
481 let p_west = self.pressure[idx - 1];
482 let p_east = self.pressure[idx + 1];
483
484 let laplacian = p_north + p_south + p_east + p_west - 4.0 * p_curr;
486 let p_new = 2.0 * p_curr - p_prev + c2 * laplacian;
487
488 self.pressure_prev[idx] = p_new * damping;
489 }
490 }
491 }
492
    /// Promote the freshly computed field: `compute_fdtd` wrote the next
    /// step into `pressure_prev`, so swapping makes it current and recycles
    /// the old field as "previous".
    pub fn swap_buffers(&mut self) {
        std::mem::swap(&mut self.pressure, &mut self.pressure_prev);
    }
497
    /// Zero both pressure fields, ghost cells included.
    pub fn reset(&mut self) {
        self.pressure.fill(0.0);
        self.pressure_prev.fill(0.0);
    }
503}
504
/// Per-step WGPU compute path: host state is uploaded, the FDTD kernel runs,
/// and the result is downloaded every step (halos stay on the CPU).
#[cfg(feature = "wgpu")]
struct GpuComputeResources {
    // Shared device/queue/pipeline pool.
    pool: TileGpuComputePool,
    // GPU buffers per tile, keyed by `(tile_x, tile_y)`.
    tile_buffers: HashMap<(u32, u32), TileBuffers>,
}
514
/// GPU-persistent WGPU state: pressure lives in device buffers between steps.
#[cfg(feature = "wgpu")]
struct WgpuPersistentState {
    backend: WgpuTileBackend,
    // Device-resident buffers per tile, keyed by `(tile_x, tile_y)`.
    tile_buffers: HashMap<(u32, u32), TileGpuBuffers<WgpuBuffer>>,
}
524
/// GPU-persistent CUDA state: pressure lives in device buffers between steps.
#[cfg(feature = "cuda")]
struct CudaPersistentState {
    backend: CudaTileBackend,
    // Device-resident buffers per tile, keyed by `(tile_x, tile_y)`.
    tile_buffers: HashMap<(u32, u32), TileGpuBuffers<ringkernel_cuda::CudaBuffer>>,
}
534
/// Which GPU-persistent backend is currently active.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GpuPersistentBackend {
    // WebGPU (wgpu) backend.
    Wgpu,
    // NVIDIA CUDA backend.
    Cuda,
}
543
/// Actor-based 2D acoustic FDTD grid: the domain is partitioned into
/// `tile_size x tile_size` [`TileActor`]s that exchange halos over a shared
/// K2K broker every step. Optional feature-gated paths keep the pressure
/// state resident in WGPU or CUDA device buffers.
pub struct TileKernelGrid {
    // Grid width in cells.
    pub width: u32,
    // Grid height in cells.
    pub height: u32,

    // Edge length of each square tile, in cells.
    tile_size: u32,

    // Tile-grid dimensions (ceil of width/height over tile_size).
    tiles_x: u32,
    tiles_y: u32,

    // Tile actors keyed by `(tile_x, tile_y)`.
    tiles: HashMap<(u32, u32), TileActor>,

    // Kernel-to-kernel message broker shared by all tiles.
    broker: Arc<K2KBroker>,

    // Physical simulation parameters.
    pub params: AcousticParams,

    // Requested execution backend.
    backend: Backend,

    // Runtime handle, kept alive for the grid's lifetime.
    #[allow(dead_code)]
    runtime: Arc<RingKernel>,

    // Per-step GPU compute resources (upload/compute/download path).
    #[cfg(feature = "wgpu")]
    gpu_compute: Option<GpuComputeResources>,

    // GPU-persistent WGPU state.
    #[cfg(feature = "wgpu")]
    wgpu_persistent: Option<WgpuPersistentState>,

    // GPU-persistent CUDA state.
    #[cfg(feature = "cuda")]
    cuda_persistent: Option<CudaPersistentState>,
}
588
// Manual Debug impl: tiles, broker, runtime and the GPU state do not
// implement Debug, so only the grid geometry and backend are reported.
impl std::fmt::Debug for TileKernelGrid {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TileKernelGrid")
            .field("width", &self.width)
            .field("height", &self.height)
            .field("tile_size", &self.tile_size)
            .field("tiles_x", &self.tiles_x)
            .field("tiles_y", &self.tiles_y)
            .field("tile_count", &self.tiles.len())
            .field("backend", &self.backend)
            .finish()
    }
}
602
603impl TileKernelGrid {
    /// Create a grid using the default tile size ([`DEFAULT_TILE_SIZE`]).
    ///
    /// See [`Self::with_tile_size`] for details and errors.
    pub async fn new(
        width: u32,
        height: u32,
        params: AcousticParams,
        backend: Backend,
    ) -> Result<Self> {
        Self::with_tile_size(width, height, params, backend, DEFAULT_TILE_SIZE).await
    }
613
614 pub async fn with_tile_size(
616 width: u32,
617 height: u32,
618 params: AcousticParams,
619 backend: Backend,
620 tile_size: u32,
621 ) -> Result<Self> {
622 let runtime = Arc::new(RingKernel::with_backend(backend).await?);
623
624 let tiles_x = width.div_ceil(tile_size);
626 let tiles_y = height.div_ceil(tile_size);
627
628 let broker = K2KBuilder::new()
630 .max_pending_messages(tiles_x as usize * tiles_y as usize * 8)
631 .build();
632
633 let mut tiles = HashMap::new();
635 for ty in 0..tiles_y {
636 for tx in 0..tiles_x {
637 let tile = TileActor::new(tx, ty, tile_size, tiles_x, tiles_y, &broker, 0.95);
638 tiles.insert((tx, ty), tile);
639 }
640 }
641
642 tracing::info!(
643 "Created TileKernelGrid: {}x{} cells, {}x{} tiles ({}x{} per tile), {} tile actors",
644 width,
645 height,
646 tiles_x,
647 tiles_y,
648 tile_size,
649 tile_size,
650 tiles.len()
651 );
652
653 Ok(Self {
654 width,
655 height,
656 tile_size,
657 tiles_x,
658 tiles_y,
659 tiles,
660 broker,
661 params,
662 backend,
663 runtime,
664 #[cfg(feature = "wgpu")]
665 gpu_compute: None,
666 #[cfg(feature = "wgpu")]
667 wgpu_persistent: None,
668 #[cfg(feature = "cuda")]
669 cuda_persistent: None,
670 })
671 }
672
673 fn global_to_tile_coords(&self, x: u32, y: u32) -> (u32, u32, u32, u32) {
675 let tile_x = x / self.tile_size;
676 let tile_y = y / self.tile_size;
677 let local_x = x % self.tile_size;
678 let local_y = y % self.tile_size;
679 (tile_x, tile_y, local_x, local_y)
680 }
681
682 pub async fn step(&mut self) -> Result<()> {
684 let c2 = self.params.courant_number().powi(2);
685 let damping = 1.0 - self.params.damping;
686
687 for tile in self.tiles.values() {
689 tile.send_halos().await?;
690 }
691
692 for tile in self.tiles.values_mut() {
694 tile.receive_halos();
695 }
696
697 for tile in self.tiles.values_mut() {
699 tile.compute_fdtd(c2, damping);
700 }
701
702 for tile in self.tiles.values_mut() {
704 tile.swap_buffers();
705 }
706
707 Ok(())
708 }
709
    /// Set up the per-step WGPU compute path: one shared pipeline pool plus
    /// GPU buffers for every tile. Halo exchange stays on the CPU; only the
    /// stencil runs on the GPU (see [`Self::step_gpu`]).
    ///
    /// # Errors
    /// Fails if no WGPU device can be initialized or the pool cannot be
    /// built.
    #[cfg(feature = "wgpu")]
    pub async fn enable_gpu_compute(&mut self) -> Result<()> {
        let (device, queue) = init_wgpu().await?;
        let pool = TileGpuComputePool::new(device, queue, self.tile_size)?;

        let mut tile_buffers = HashMap::new();
        for tile_coords in self.tiles.keys() {
            let buffers = pool.create_tile_buffers();
            tile_buffers.insert(*tile_coords, buffers);
        }

        let num_tiles = tile_buffers.len();
        self.gpu_compute = Some(GpuComputeResources { pool, tile_buffers });

        tracing::info!(
            "GPU compute enabled for TileKernelGrid: {} tiles with GPU buffers",
            num_tiles
        );

        Ok(())
    }
736
    /// Whether the per-step GPU compute path has been enabled.
    #[cfg(feature = "wgpu")]
    pub fn is_gpu_enabled(&self) -> bool {
        self.gpu_compute.is_some()
    }
742
    /// Advance one timestep using the per-step GPU path.
    ///
    /// Halo exchange runs on the CPU exactly as in [`Self::step`]; each
    /// tile's stencil is then dispatched to the GPU (with a CPU fallback for
    /// tiles lacking buffers). Host buffers are uploaded and results
    /// downloaded every step, so the same final `swap_buffers` applies.
    ///
    /// # Errors
    /// Fails if GPU compute was not enabled via
    /// [`Self::enable_gpu_compute`], or if a halo send fails.
    #[cfg(feature = "wgpu")]
    pub async fn step_gpu(&mut self) -> Result<()> {
        let gpu = self.gpu_compute.as_ref().ok_or_else(|| {
            ringkernel_core::error::RingKernelError::BackendError(
                "GPU compute not enabled. Call enable_gpu_compute() first.".to_string(),
            )
        })?;

        let c2 = self.params.courant_number().powi(2);
        let damping = 1.0 - self.params.damping;

        // CPU-side halo exchange, same two phases as the CPU step.
        for tile in self.tiles.values() {
            tile.send_halos().await?;
        }

        for tile in self.tiles.values_mut() {
            tile.receive_halos();
        }

        // GPU stencil per tile; like the CPU path, the result is written to
        // pressure_prev so the swap below promotes it.
        for ((tx, ty), tile) in self.tiles.iter_mut() {
            if let Some(buffers) = gpu.tile_buffers.get(&(*tx, *ty)) {
                let pressure = &tile.pressure;
                let pressure_prev = &tile.pressure_prev;

                let result = gpu
                    .pool
                    .compute_fdtd(buffers, pressure, pressure_prev, c2, damping);

                tile.pressure_prev.copy_from_slice(&result);
            } else {
                // No GPU buffers for this tile: CPU fallback.
                tile.compute_fdtd(c2, damping);
            }
        }

        for tile in self.tiles.values_mut() {
            tile.swap_buffers();
        }

        Ok(())
    }
797
    /// Set up the GPU-persistent WGPU path: current host state is uploaded
    /// once and thereafter lives in device buffers (see
    /// [`Self::step_wgpu_persistent`]).
    ///
    /// # Errors
    /// Fails if the WGPU backend, tile buffers, or initial upload fail.
    #[cfg(feature = "wgpu")]
    pub async fn enable_wgpu_persistent(&mut self) -> Result<()> {
        let backend = WgpuTileBackend::new(self.tile_size).await?;

        let mut tile_buffers = HashMap::new();
        for ((tx, ty), tile) in &self.tiles {
            let buffers = backend.create_tile_buffers(self.tile_size)?;
            backend.upload_initial_state(&buffers, &tile.pressure, &tile.pressure_prev)?;
            tile_buffers.insert((*tx, *ty), buffers);
        }

        let num_tiles = tile_buffers.len();
        self.wgpu_persistent = Some(WgpuPersistentState {
            backend,
            tile_buffers,
        });

        tracing::info!(
            "WGPU GPU-persistent compute enabled for TileKernelGrid: {} tiles",
            num_tiles
        );

        Ok(())
    }
829
    /// Set up the GPU-persistent CUDA path: current host state is uploaded
    /// once and thereafter lives in device buffers (see
    /// [`Self::step_cuda_persistent`]).
    ///
    /// # Errors
    /// Fails if the CUDA backend, tile buffers, or initial upload fail.
    #[cfg(feature = "cuda")]
    pub fn enable_cuda_persistent(&mut self) -> Result<()> {
        let backend = CudaTileBackend::new(self.tile_size)?;

        let mut tile_buffers = HashMap::new();
        for ((tx, ty), tile) in &self.tiles {
            let buffers = backend.create_tile_buffers(self.tile_size)?;
            backend.upload_initial_state(&buffers, &tile.pressure, &tile.pressure_prev)?;
            tile_buffers.insert((*tx, *ty), buffers);
        }

        let num_tiles = tile_buffers.len();
        self.cuda_persistent = Some(CudaPersistentState {
            backend,
            tile_buffers,
        });

        tracing::info!(
            "CUDA GPU-persistent compute enabled for TileKernelGrid: {} tiles",
            num_tiles
        );

        Ok(())
    }
860
    /// Whether any GPU-persistent backend (WGPU or CUDA) is active.
    pub fn is_gpu_persistent_enabled(&self) -> bool {
        #[cfg(feature = "wgpu")]
        if self.wgpu_persistent.is_some() {
            return true;
        }
        #[cfg(feature = "cuda")]
        if self.cuda_persistent.is_some() {
            return true;
        }
        false
    }
873
    /// The active GPU-persistent backend, if any.
    ///
    /// CUDA takes precedence when both features are enabled and active.
    pub fn gpu_persistent_backend(&self) -> Option<GpuPersistentBackend> {
        #[cfg(feature = "cuda")]
        if self.cuda_persistent.is_some() {
            return Some(GpuPersistentBackend::Cuda);
        }
        #[cfg(feature = "wgpu")]
        if self.wgpu_persistent.is_some() {
            return Some(GpuPersistentBackend::Wgpu);
        }
        None
    }
886
887 #[cfg(feature = "wgpu")]
891 pub fn step_wgpu_persistent(&mut self) -> Result<()> {
892 let state = self.wgpu_persistent.as_ref().ok_or_else(|| {
893 ringkernel_core::error::RingKernelError::BackendError(
894 "WGPU persistent compute not enabled. Call enable_wgpu_persistent() first."
895 .to_string(),
896 )
897 })?;
898
899 let c2 = self.params.courant_number().powi(2);
900 let damping = 1.0 - self.params.damping;
901 let fdtd_params = FdtdParams::new(self.tile_size, c2, damping);
902
903 let mut halo_messages: Vec<(KernelId, HaloDirection, Vec<f32>)> = Vec::new();
906
907 for (tx, ty) in self.tiles.keys() {
908 if let Some(buffers) = state.tile_buffers.get(&(*tx, *ty)) {
909 let tile = self.tiles.get(&(*tx, *ty)).unwrap();
911
912 if tile.neighbor_north.is_some() {
913 let halo = state.backend.extract_halo(buffers, Edge::North)?;
914 halo_messages.push((
915 tile.neighbor_north.clone().unwrap(),
916 HaloDirection::South,
917 halo,
918 ));
919 }
920 if tile.neighbor_south.is_some() {
921 let halo = state.backend.extract_halo(buffers, Edge::South)?;
922 halo_messages.push((
923 tile.neighbor_south.clone().unwrap(),
924 HaloDirection::North,
925 halo,
926 ));
927 }
928 if tile.neighbor_west.is_some() {
929 let halo = state.backend.extract_halo(buffers, Edge::West)?;
930 halo_messages.push((
931 tile.neighbor_west.clone().unwrap(),
932 HaloDirection::East,
933 halo,
934 ));
935 }
936 if tile.neighbor_east.is_some() {
937 let halo = state.backend.extract_halo(buffers, Edge::East)?;
938 halo_messages.push((
939 tile.neighbor_east.clone().unwrap(),
940 HaloDirection::West,
941 halo,
942 ));
943 }
944 }
945 }
946
947 for (dest_id, direction, halo_data) in halo_messages {
950 for (tx, ty) in self.tiles.keys() {
952 if TileActor::tile_kernel_id(*tx, *ty) == dest_id {
953 if let Some(buffers) = state.tile_buffers.get(&(*tx, *ty)) {
954 let edge = match direction {
955 HaloDirection::North => Edge::North,
956 HaloDirection::South => Edge::South,
957 HaloDirection::East => Edge::East,
958 HaloDirection::West => Edge::West,
959 };
960 state.backend.inject_halo(buffers, edge, &halo_data)?;
961 }
962 break;
963 }
964 }
965 }
966
967 for (tx, ty) in self.tiles.keys() {
973 if let Some(buffers) = state.tile_buffers.get(&(*tx, *ty)) {
974 state.backend.fdtd_step(buffers, &fdtd_params)?;
975 }
976 }
977
978 state.backend.synchronize()?;
980
981 let state = self.wgpu_persistent.as_mut().unwrap();
984 for buffers in state.tile_buffers.values_mut() {
985 state.backend.swap_buffers(buffers);
986 }
987
988 Ok(())
989 }
990
991 #[cfg(feature = "cuda")]
995 pub fn step_cuda_persistent(&mut self) -> Result<()> {
996 let state = self.cuda_persistent.as_ref().ok_or_else(|| {
997 ringkernel_core::error::RingKernelError::BackendError(
998 "CUDA persistent compute not enabled. Call enable_cuda_persistent() first."
999 .to_string(),
1000 )
1001 })?;
1002
1003 let c2 = self.params.courant_number().powi(2);
1004 let damping = 1.0 - self.params.damping;
1005 let fdtd_params = FdtdParams::new(self.tile_size, c2, damping);
1006
1007 let mut halo_messages: Vec<(KernelId, HaloDirection, Vec<f32>)> = Vec::new();
1009
1010 for (tx, ty) in self.tiles.keys() {
1011 if let Some(buffers) = state.tile_buffers.get(&(*tx, *ty)) {
1012 let tile = self.tiles.get(&(*tx, *ty)).unwrap();
1013
1014 if tile.neighbor_north.is_some() {
1015 let halo = state.backend.extract_halo(buffers, Edge::North)?;
1016 halo_messages.push((
1017 tile.neighbor_north.clone().unwrap(),
1018 HaloDirection::South,
1019 halo,
1020 ));
1021 }
1022 if tile.neighbor_south.is_some() {
1023 let halo = state.backend.extract_halo(buffers, Edge::South)?;
1024 halo_messages.push((
1025 tile.neighbor_south.clone().unwrap(),
1026 HaloDirection::North,
1027 halo,
1028 ));
1029 }
1030 if tile.neighbor_west.is_some() {
1031 let halo = state.backend.extract_halo(buffers, Edge::West)?;
1032 halo_messages.push((
1033 tile.neighbor_west.clone().unwrap(),
1034 HaloDirection::East,
1035 halo,
1036 ));
1037 }
1038 if tile.neighbor_east.is_some() {
1039 let halo = state.backend.extract_halo(buffers, Edge::East)?;
1040 halo_messages.push((
1041 tile.neighbor_east.clone().unwrap(),
1042 HaloDirection::West,
1043 halo,
1044 ));
1045 }
1046 }
1047 }
1048
1049 for (dest_id, direction, halo_data) in halo_messages {
1051 for (tx, ty) in self.tiles.keys() {
1052 if TileActor::tile_kernel_id(*tx, *ty) == dest_id {
1053 if let Some(buffers) = state.tile_buffers.get(&(*tx, *ty)) {
1054 let edge = match direction {
1055 HaloDirection::North => Edge::North,
1056 HaloDirection::South => Edge::South,
1057 HaloDirection::East => Edge::East,
1058 HaloDirection::West => Edge::West,
1059 };
1060 state.backend.inject_halo(buffers, edge, &halo_data)?;
1061 }
1062 break;
1063 }
1064 }
1065 }
1066
1067 for (tx, ty) in self.tiles.keys() {
1069 if let Some(buffers) = state.tile_buffers.get(&(*tx, *ty)) {
1070 state.backend.fdtd_step(buffers, &fdtd_params)?;
1071 }
1072 }
1073
1074 state.backend.synchronize()?;
1076
1077 let state = self.cuda_persistent.as_mut().unwrap();
1079 for buffers in state.tile_buffers.values_mut() {
1080 state.backend.swap_buffers(buffers);
1081 }
1082
1083 Ok(())
1084 }
1085
1086 #[cfg(feature = "wgpu")]
1091 pub fn read_pressure_from_wgpu(&self) -> Result<Vec<Vec<f32>>> {
1092 let state = self.wgpu_persistent.as_ref().ok_or_else(|| {
1093 ringkernel_core::error::RingKernelError::BackendError(
1094 "WGPU persistent compute not enabled.".to_string(),
1095 )
1096 })?;
1097
1098 let mut grid = vec![vec![0.0; self.width as usize]; self.height as usize];
1099
1100 for ((tx, ty), buffers) in &state.tile_buffers {
1101 let interior = state.backend.read_interior_pressure(buffers)?;
1102
1103 let tile_start_x = tx * self.tile_size;
1105 let tile_start_y = ty * self.tile_size;
1106
1107 for ly in 0..self.tile_size {
1108 for lx in 0..self.tile_size {
1109 let gx = tile_start_x + lx;
1110 let gy = tile_start_y + ly;
1111 if gx < self.width && gy < self.height {
1112 grid[gy as usize][gx as usize] =
1113 interior[(ly * self.tile_size + lx) as usize];
1114 }
1115 }
1116 }
1117 }
1118
1119 Ok(grid)
1120 }
1121
1122 #[cfg(feature = "cuda")]
1124 pub fn read_pressure_from_cuda(&self) -> Result<Vec<Vec<f32>>> {
1125 let state = self.cuda_persistent.as_ref().ok_or_else(|| {
1126 ringkernel_core::error::RingKernelError::BackendError(
1127 "CUDA persistent compute not enabled.".to_string(),
1128 )
1129 })?;
1130
1131 let mut grid = vec![vec![0.0; self.width as usize]; self.height as usize];
1132
1133 for ((tx, ty), buffers) in &state.tile_buffers {
1134 let interior = state.backend.read_interior_pressure(buffers)?;
1135
1136 let tile_start_x = tx * self.tile_size;
1137 let tile_start_y = ty * self.tile_size;
1138
1139 for ly in 0..self.tile_size {
1140 for lx in 0..self.tile_size {
1141 let gx = tile_start_x + lx;
1142 let gy = tile_start_y + ly;
1143 if gx < self.width && gy < self.height {
1144 grid[gy as usize][gx as usize] =
1145 interior[(ly * self.tile_size + lx) as usize];
1146 }
1147 }
1148 }
1149 }
1150
1151 Ok(grid)
1152 }
1153
1154 #[cfg(feature = "wgpu")]
1158 pub fn inject_impulse_wgpu(&mut self, x: u32, y: u32, amplitude: f32) -> Result<()> {
1159 if x >= self.width || y >= self.height {
1160 return Ok(());
1161 }
1162
1163 let (tile_x, tile_y, local_x, local_y) = self.global_to_tile_coords(x, y);
1164
1165 if let Some(tile) = self.tiles.get_mut(&(tile_x, tile_y)) {
1167 let current = tile.get_pressure(local_x, local_y);
1168 tile.set_pressure(local_x, local_y, current + amplitude);
1169
1170 if let Some(state) = &self.wgpu_persistent {
1172 if let Some(buffers) = state.tile_buffers.get(&(tile_x, tile_y)) {
1173 state.backend.upload_initial_state(
1174 buffers,
1175 &tile.pressure,
1176 &tile.pressure_prev,
1177 )?;
1178 }
1179 }
1180 }
1181
1182 Ok(())
1183 }
1184
1185 #[cfg(feature = "cuda")]
1187 pub fn inject_impulse_cuda(&mut self, x: u32, y: u32, amplitude: f32) -> Result<()> {
1188 if x >= self.width || y >= self.height {
1189 return Ok(());
1190 }
1191
1192 let (tile_x, tile_y, local_x, local_y) = self.global_to_tile_coords(x, y);
1193
1194 if let Some(tile) = self.tiles.get_mut(&(tile_x, tile_y)) {
1195 let current = tile.get_pressure(local_x, local_y);
1196 tile.set_pressure(local_x, local_y, current + amplitude);
1197
1198 if let Some(state) = &self.cuda_persistent {
1199 if let Some(buffers) = state.tile_buffers.get(&(tile_x, tile_y)) {
1200 state.backend.upload_initial_state(
1201 buffers,
1202 &tile.pressure,
1203 &tile.pressure_prev,
1204 )?;
1205 }
1206 }
1207 }
1208
1209 Ok(())
1210 }
1211
1212 pub fn inject_impulse(&mut self, x: u32, y: u32, amplitude: f32) {
1214 if x >= self.width || y >= self.height {
1215 return;
1216 }
1217
1218 let (tile_x, tile_y, local_x, local_y) = self.global_to_tile_coords(x, y);
1219
1220 if let Some(tile) = self.tiles.get_mut(&(tile_x, tile_y)) {
1221 let current = tile.get_pressure(local_x, local_y);
1222 tile.set_pressure(local_x, local_y, current + amplitude);
1223 }
1224 }
1225
1226 pub fn get_pressure_grid(&self) -> Vec<Vec<f32>> {
1228 let mut grid = vec![vec![0.0; self.width as usize]; self.height as usize];
1229
1230 for y in 0..self.height {
1231 for x in 0..self.width {
1232 let (tile_x, tile_y, local_x, local_y) = self.global_to_tile_coords(x, y);
1233
1234 if let Some(tile) = self.tiles.get(&(tile_x, tile_y)) {
1235 if local_x < self.tile_size && local_y < self.tile_size {
1237 grid[y as usize][x as usize] = tile.get_pressure(local_x, local_y);
1238 }
1239 }
1240 }
1241 }
1242
1243 grid
1244 }
1245
1246 pub fn max_pressure(&self) -> f32 {
1248 self.tiles
1249 .values()
1250 .flat_map(|tile| {
1251 (0..self.tile_size).flat_map(move |y| {
1252 (0..self.tile_size).map(move |x| tile.get_pressure(x, y).abs())
1253 })
1254 })
1255 .fold(0.0, f32::max)
1256 }
1257
1258 pub fn total_energy(&self) -> f32 {
1260 self.tiles
1261 .values()
1262 .flat_map(|tile| {
1263 (0..self.tile_size).flat_map(move |y| {
1264 (0..self.tile_size).map(move |x| {
1265 let p = tile.get_pressure(x, y);
1266 p * p
1267 })
1268 })
1269 })
1270 .sum()
1271 }
1272
1273 pub fn reset(&mut self) {
1275 for tile in self.tiles.values_mut() {
1276 tile.reset();
1277 }
1278 }
1279
    /// Total number of simulated cells (`width * height`).
    pub fn cell_count(&self) -> usize {
        (self.width * self.height) as usize
    }

    /// Number of tile actors in the grid.
    pub fn tile_count(&self) -> usize {
        self.tiles.len()
    }

    /// The execution backend this grid was created with.
    pub fn backend(&self) -> Backend {
        self.backend
    }

    /// Snapshot of the K2K broker's message statistics.
    pub fn k2k_stats(&self) -> ringkernel_core::k2k::K2KStats {
        self.broker.stats()
    }

    /// Update the speed of sound in the acoustic parameters.
    pub fn set_speed_of_sound(&mut self, speed: f32) {
        self.params.set_speed_of_sound(speed);
    }

    /// Update the physical cell size in the acoustic parameters.
    pub fn set_cell_size(&mut self, size: f32) {
        self.params.set_cell_size(size);
    }
1309
1310 pub async fn resize(&mut self, new_width: u32, new_height: u32) -> Result<()> {
1312 self.width = new_width;
1313 self.height = new_height;
1314
1315 self.tiles_x = new_width.div_ceil(self.tile_size);
1317 self.tiles_y = new_height.div_ceil(self.tile_size);
1318
1319 self.broker = K2KBuilder::new()
1321 .max_pending_messages(self.tiles_x as usize * self.tiles_y as usize * 8)
1322 .build();
1323
1324 self.tiles.clear();
1326 for ty in 0..self.tiles_y {
1327 for tx in 0..self.tiles_x {
1328 let tile = TileActor::new(
1329 tx,
1330 ty,
1331 self.tile_size,
1332 self.tiles_x,
1333 self.tiles_y,
1334 &self.broker,
1335 0.95,
1336 );
1337 self.tiles.insert((tx, ty), tile);
1338 }
1339 }
1340
1341 tracing::info!(
1342 "Resized TileKernelGrid: {}x{} cells, {} tile actors",
1343 new_width,
1344 new_height,
1345 self.tiles.len()
1346 );
1347
1348 Ok(())
1349 }
1350
    /// Shut down the grid and release its resources.
    ///
    /// Currently a no-op beyond dropping `self`: tiles, broker, runtime and
    /// any GPU state are released by their destructors.
    pub async fn shutdown(self) -> Result<()> {
        Ok(())
    }
1356}
1357
#[cfg(test)]
mod tests {
    use super::*;

    // 64x64 cells at the default 16-cell tile size -> a 4x4 tile grid.
    #[tokio::test]
    async fn test_tile_grid_creation() {
        let params = AcousticParams::new(343.0, 1.0);
        let grid = TileKernelGrid::new(64, 64, params, Backend::Cpu)
            .await
            .unwrap();

        assert_eq!(grid.width, 64);
        assert_eq!(grid.height, 64);
        assert_eq!(grid.tile_size, DEFAULT_TILE_SIZE);
        assert_eq!(grid.tiles_x, 4); assert_eq!(grid.tiles_y, 4);
        assert_eq!(grid.tile_count(), 16); }

    // An injected impulse must be visible in the assembled pressure grid.
    #[tokio::test]
    async fn test_tile_grid_impulse() {
        let params = AcousticParams::new(343.0, 1.0);
        let mut grid = TileKernelGrid::with_tile_size(32, 32, params, Backend::Cpu, 16)
            .await
            .unwrap();

        grid.inject_impulse(16, 16, 1.0);

        let pressure_grid = grid.get_pressure_grid();
        assert_eq!(pressure_grid[16][16], 1.0);
    }

    // After several CPU steps the wave must have reached a neighboring cell.
    #[tokio::test]
    async fn test_tile_grid_step() {
        let params = AcousticParams::new(343.0, 1.0);
        let mut grid = TileKernelGrid::with_tile_size(32, 32, params, Backend::Cpu, 16)
            .await
            .unwrap();

        grid.inject_impulse(16, 16, 1.0);

        for _ in 0..10 {
            grid.step().await.unwrap();
        }

        let pressure_grid = grid.get_pressure_grid();
        let neighbor_pressure = pressure_grid[16][17];
        assert!(
            neighbor_pressure.abs() > 0.0,
            "Wave should have propagated to neighbor"
        );
    }

    // Stepping a multi-tile grid must exchange K2K halo messages.
    #[tokio::test]
    async fn test_tile_grid_k2k_stats() {
        let params = AcousticParams::new(343.0, 1.0);
        let mut grid = TileKernelGrid::with_tile_size(32, 32, params, Backend::Cpu, 16)
            .await
            .unwrap();

        grid.inject_impulse(16, 16, 1.0);
        grid.step().await.unwrap();

        let stats = grid.k2k_stats();
        assert!(
            stats.messages_delivered > 0,
            "K2K messages should have been exchanged"
        );
    }

    // reset() must zero every cell, even after stepping.
    #[tokio::test]
    async fn test_tile_grid_reset() {
        let params = AcousticParams::new(343.0, 1.0);
        let mut grid = TileKernelGrid::with_tile_size(32, 32, params, Backend::Cpu, 16)
            .await
            .unwrap();

        grid.inject_impulse(16, 16, 1.0);
        grid.step().await.unwrap();

        grid.reset();

        let pressure_grid = grid.get_pressure_grid();
        for row in pressure_grid {
            for p in row {
                assert_eq!(p, 0.0);
            }
        }
    }

    // An impulse near the domain boundary must not blow up the field.
    #[tokio::test]
    async fn test_tile_boundary_handling() {
        let params = AcousticParams::new(343.0, 1.0);
        let mut grid = TileKernelGrid::with_tile_size(32, 32, params, Backend::Cpu, 16)
            .await
            .unwrap();

        grid.inject_impulse(1, 1, 1.0);

        for _ in 0..5 {
            grid.step().await.unwrap();
        }

        let energy = grid.total_energy();
        assert!(energy.is_finite(), "Energy should be finite");
    }

    // Requires a GPU; run with `cargo test -- --ignored`.
    #[cfg(feature = "wgpu")]
    #[tokio::test]
    #[ignore] async fn test_tile_grid_gpu_step() {
        let params = AcousticParams::new(343.0, 1.0);
        let mut grid = TileKernelGrid::with_tile_size(32, 32, params, Backend::Cpu, 16)
            .await
            .unwrap();

        grid.enable_gpu_compute().await.unwrap();
        assert!(grid.is_gpu_enabled());

        grid.inject_impulse(16, 16, 1.0);

        for _ in 0..10 {
            grid.step_gpu().await.unwrap();
        }

        let pressure_grid = grid.get_pressure_grid();
        let neighbor_pressure = pressure_grid[16][17];
        assert!(
            neighbor_pressure.abs() > 0.0,
            "Wave should have propagated to neighbor (GPU compute)"
        );
    }

    // GPU and CPU paths must agree within float tolerance.
    // Requires a GPU; run with `cargo test -- --ignored`.
    #[cfg(feature = "wgpu")]
    #[tokio::test]
    #[ignore] async fn test_tile_grid_gpu_matches_cpu() {
        let params = AcousticParams::new(343.0, 1.0);

        let mut grid_cpu = TileKernelGrid::with_tile_size(32, 32, params.clone(), Backend::Cpu, 16)
            .await
            .unwrap();
        let mut grid_gpu = TileKernelGrid::with_tile_size(32, 32, params, Backend::Cpu, 16)
            .await
            .unwrap();

        grid_gpu.enable_gpu_compute().await.unwrap();

        grid_cpu.inject_impulse(16, 16, 1.0);
        grid_gpu.inject_impulse(16, 16, 1.0);

        for _ in 0..5 {
            grid_cpu.step().await.unwrap();
            grid_gpu.step_gpu().await.unwrap();
        }

        let cpu_grid = grid_cpu.get_pressure_grid();
        let gpu_grid = grid_gpu.get_pressure_grid();

        for y in 0..32 {
            for x in 0..32 {
                let diff = (cpu_grid[y][x] - gpu_grid[y][x]).abs();
                assert!(
                    diff < 1e-4,
                    "CPU/GPU mismatch at ({},{}): cpu={}, gpu={}, diff={}",
                    x,
                    y,
                    cpu_grid[y][x],
                    gpu_grid[y][x],
                    diff
                );
            }
        }
    }
}