/// Configuration for the generated persistent FDTD kernel.
#[derive(Debug, Clone)]
pub struct PersistentFdtdConfig {
    /// Kernel (entry point) name.
    pub name: String,
    /// Tile dimensions (x, y, z) handled by one block.
    pub tile_size: (usize, usize, usize),
    /// Use cooperative-groups `grid.sync()` instead of the software barrier.
    pub use_cooperative: bool,
    /// Number of steps between progress reports to the host.
    pub progress_interval: u64,
    /// Track total field energy.
    pub track_energy: bool,
    /// Nanoseconds to sleep when idle (no steps remaining).
    pub idle_sleep_ns: u32,
    /// Use libcu++ `cuda::atomic_ref` ordered atomics instead of volatile + fences.
    pub use_libcupp_atomics: bool,
}

impl Default for PersistentFdtdConfig {
    fn default() -> Self {
        Self {
            name: "persistent_fdtd3d".to_string(),
            tile_size: (8, 8, 8),
            use_cooperative: true,
            progress_interval: 100,
            track_energy: true,
            idle_sleep_ns: 1000,
            use_libcupp_atomics: false,
        }
    }
}

impl PersistentFdtdConfig {
    /// Create a config with the given kernel name and default settings.
    pub fn new(name: &str) -> Self {
        Self {
            name: name.to_string(),
            ..Default::default()
        }
    }

    /// Set the tile dimensions handled by each block.
    pub fn with_tile_size(mut self, tx: usize, ty: usize, tz: usize) -> Self {
        self.tile_size = (tx, ty, tz);
        self
    }

    /// Enable or disable cooperative-groups grid synchronization.
    pub fn with_cooperative(mut self, use_coop: bool) -> Self {
        self.use_cooperative = use_coop;
        self
    }

    /// Set how many steps elapse between progress reports.
    pub fn with_progress_interval(mut self, interval: u64) -> Self {
        self.progress_interval = interval;
        self
    }

    /// Set the idle sleep duration in nanoseconds.
    pub fn with_idle_sleep(mut self, ns: u32) -> Self {
        self.idle_sleep_ns = ns;
        self
    }

    /// Enable or disable libcu++ ordered atomics for the host/GPU queues.
    pub fn with_libcupp_atomics(mut self, enabled: bool) -> Self {
        self.use_libcupp_atomics = enabled;
        self
    }

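    /// Threads per block: one thread per tile cell.
    ///
    /// For the default 8x8x8 tile this is 8 * 8 * 8 = 512.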
    pub fn threads_per_block(&self) -> usize {
        self.tile_size.0 * self.tile_size.1 * self.tile_size.2
    }

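    /// Shared-memory bytes for one tile plus a one-cell halo on each side.
    ///
    /// For the default 8x8x8 tile: (8 + 2)^3 cells * 4 bytes = 4000 bytes.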
    pub fn shared_mem_size(&self) -> usize {
        let (tx, ty, tz) = self.tile_size;
        let with_halo = (tx + 2) * (ty + 2) * (tz + 2);
        with_halo * std::mem::size_of::<f32>()
    }
}

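/// Generate the complete CUDA source for a persistent FDTD kernel.
///
/// The output concatenates the include header, the host-shared structure
/// definitions, the device helper functions, and the main persistent kernel.
///
/// # Example
///
/// A minimal sketch of the intended use (marked `ignore`; the `use` path
/// depends on your crate layout):
///
/// ```ignore
/// let config = PersistentFdtdConfig::new("my_fdtd")
///     .with_tile_size(8, 8, 8)
///     .with_cooperative(true);
/// let cuda_source = generate_persistent_fdtd_kernel(&config);
/// assert!(cuda_source.contains("my_fdtd"));
/// ```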
pub fn generate_persistent_fdtd_kernel(config: &PersistentFdtdConfig) -> String {
    let mut code = String::new();

    code.push_str(&generate_header(config));

    code.push_str(&generate_structures());

    code.push_str(&generate_device_functions(config));

    code.push_str(&generate_main_kernel(config));

    code
}

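/// Emit the `#include` preamble. Cooperative-groups and libcu++ headers are
/// included only when the corresponding config flags are set.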
fn generate_header(config: &PersistentFdtdConfig) -> String {
    let mut code = String::new();

    code.push_str("// Generated Persistent FDTD Kernel\n");
    code.push_str("// RingKernel GPU Actor System\n\n");

    if config.use_cooperative {
        code.push_str("#include <cooperative_groups.h>\n");
        code.push_str("namespace cg = cooperative_groups;\n\n");
    }

    code.push_str("#include <cuda_runtime.h>\n");
    code.push_str("#include <stdint.h>\n");

    if config.use_libcupp_atomics {
        code.push_str("\n// libcu++ ordered atomics (CUDA 11.0+)\n");
        code.push_str("#if __CUDACC_VER_MAJOR__ < 11\n");
        code.push_str("#error \"libcu++ atomics require CUDA 11.0 or later\"\n");
        code.push_str("#endif\n");
        code.push_str("#include <cuda/atomic>\n");
    }

    code.push('\n');

    code
}

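/// Emit the C struct definitions and command/response constants shared with
/// the host. As noted in the generated comments, these layouts must match the
/// Rust-side definitions in `persistent.rs`.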
fn generate_structures() -> String {
    r#"
// ============================================================================
// STRUCTURE DEFINITIONS (must match Rust persistent.rs)
// ============================================================================

// Control block for persistent kernel (256 bytes, cache-aligned)
typedef struct __align__(256) {
    // Lifecycle control (host -> GPU)
    uint32_t should_terminate;
    uint32_t _pad0;
    uint64_t steps_remaining;

    // State (GPU -> host)
    uint64_t current_step;
    uint32_t current_buffer;
    uint32_t has_terminated;
    float total_energy;
    uint32_t _pad1;

    // Configuration
    uint32_t grid_dim[3];
    uint32_t block_dim[3];
    uint32_t sim_size[3];
    uint32_t tile_size[3];

    // Acoustic parameters
    float c2_dt2;
    float damping;
    float cell_size;
    float dt;

    // Synchronization
    uint32_t barrier_counter;
    uint32_t barrier_generation;

    // Statistics
    uint64_t messages_processed;
    uint64_t k2k_messages_sent;
    uint64_t k2k_messages_received;

    uint64_t _reserved[16];
} PersistentControlBlock;

// SPSC queue header (128 bytes, cache-aligned)
typedef struct __align__(128) {
    uint64_t head;
    uint64_t tail;
    uint32_t capacity;
    uint32_t mask;
    uint64_t _padding[12];
} SpscQueueHeader;

// H2K message (64 bytes)
typedef struct __align__(64) {
    uint32_t cmd;
    uint32_t flags;
    uint64_t cmd_id;
    uint64_t param1;
    uint32_t param2;
    uint32_t param3;
    float param4;
    float param5;
    uint64_t _reserved[4];
} H2KMessage;

// K2H message (64 bytes)
typedef struct __align__(64) {
    uint32_t resp_type;
    uint32_t flags;
    uint64_t cmd_id;
    uint64_t step;
    uint64_t steps_remaining;
    float energy;
    uint32_t error_code;
    uint64_t _reserved[3];
} K2HMessage;

// Command types
#define CMD_NOP 0
#define CMD_RUN_STEPS 1
#define CMD_PAUSE 2
#define CMD_RESUME 3
#define CMD_TERMINATE 4
#define CMD_INJECT_IMPULSE 5
#define CMD_SET_SOURCE 6
#define CMD_GET_PROGRESS 7

// Response types
#define RESP_ACK 0
#define RESP_PROGRESS 1
#define RESP_ERROR 2
#define RESP_TERMINATED 3
#define RESP_ENERGY 4

// Neighbor block IDs
typedef struct {
    int32_t pos_x, neg_x;
    int32_t pos_y, neg_y;
    int32_t pos_z, neg_z;
    int32_t _padding[2];
} BlockNeighbors;

// K2K route entry
typedef struct {
    BlockNeighbors neighbors;
    uint32_t block_pos[3];
    uint32_t _padding;
    uint32_t cell_offset[3];
    uint32_t _padding2;
} K2KRouteEntry;

// Face indices
#define FACE_POS_X 0
#define FACE_NEG_X 1
#define FACE_POS_Y 2
#define FACE_NEG_Y 3
#define FACE_POS_Z 4
#define FACE_NEG_Z 5

"#
    .to_string()
}

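/// Emit the device helper functions: the software grid barrier, the H2K/K2H
/// queue operations (volatile + fences, or libcu++ `cuda::atomic_ref` when
/// `use_libcupp_atomics` is set), the block-level energy reduction, and the
/// K2K halo pack/unpack routines sized from `tile_size`.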
fn generate_device_functions(config: &PersistentFdtdConfig) -> String {
    let (tx, ty, tz) = config.tile_size;
    let face_size = tx * ty;

    let mut code = String::new();

    if config.use_libcupp_atomics {
        code.push_str(
            r#"
// ============================================================================
// SYNCHRONIZATION FUNCTIONS (libcu++ ordered atomics)
// ============================================================================

// Software grid barrier using cuda::atomic_ref with explicit memory ordering.
// Uses device scope (not system) since barrier is GPU-internal.
__device__ void software_grid_sync(
    uint32_t* barrier_counter,
    uint32_t* barrier_gen,
    int num_blocks
) {
    __syncthreads(); // First sync within block

    if (threadIdx.x == 0) {
        cuda::atomic_ref<unsigned int, cuda::thread_scope_device> gen_ref(*barrier_gen);
        cuda::atomic_ref<unsigned int, cuda::thread_scope_device> cnt_ref(*barrier_counter);

        unsigned int gen = gen_ref.load(cuda::memory_order_acquire);

        unsigned int arrived = cnt_ref.fetch_add(1, cuda::memory_order_acq_rel) + 1;
        if (arrived == num_blocks) {
            cnt_ref.store(0, cuda::memory_order_relaxed);
            gen_ref.fetch_add(1, cuda::memory_order_release);
        } else {
            while (gen_ref.load(cuda::memory_order_acquire) == gen) {
                __nanosleep(100); // Reduce power in barrier spin
            }
        }
    }

    __syncthreads();
}

"#,
        );
    } else {
        code.push_str(
            r#"
// ============================================================================
// SYNCHRONIZATION FUNCTIONS
// ============================================================================

// Software grid barrier (atomic counter + generation)
__device__ void software_grid_sync(
    volatile uint32_t* barrier_counter,
    volatile uint32_t* barrier_gen,
    int num_blocks
) {
    __syncthreads(); // First sync within block

    if (threadIdx.x == 0) {
        unsigned int gen = *barrier_gen;

        unsigned int arrived = atomicAdd((unsigned int*)barrier_counter, 1) + 1;
        if (arrived == num_blocks) {
            *barrier_counter = 0;
            __threadfence();
            atomicAdd((unsigned int*)barrier_gen, 1);
        } else {
            while (atomicAdd((unsigned int*)barrier_gen, 0) == gen) {
                __threadfence();
                __nanosleep(100); // Reduce power in barrier spin
            }
        }
    }

    __syncthreads();
}

"#,
        );
    }

    if config.use_libcupp_atomics {
        code.push_str(
            r#"
// ============================================================================
// MESSAGE QUEUE OPERATIONS (libcu++ ordered atomics)
// ============================================================================

// Try to receive H2K message using cuda::atomic_ref with memory ordering.
// acquire on the head read ensures we see the host's payload writes.
// release on the tail publish ensures the host sees our consumption.
__device__ bool h2k_try_recv(
    SpscQueueHeader* header,
    H2KMessage* slots,
    H2KMessage* out_msg
) {
    cuda::atomic_ref<uint64_t, cuda::thread_scope_system> head_ref(header->head);
    cuda::atomic_ref<uint64_t, cuda::thread_scope_system> tail_ref(header->tail);

    uint64_t tail = tail_ref.load(cuda::memory_order_relaxed);
    uint64_t head = head_ref.load(cuda::memory_order_acquire);

    if (head == tail) {
        return false; // Empty
    }

    uint32_t slot = tail & header->mask;
    *out_msg = slots[slot];

    tail_ref.store(tail + 1, cuda::memory_order_release);

    return true;
}

// Send K2H message using cuda::atomic_ref with memory ordering.
// acquire on the tail read ensures the host has finished with the slot we reuse.
// release on the head publish ensures the host sees the payload.
__device__ bool k2h_send(
    SpscQueueHeader* header,
    K2HMessage* slots,
    const K2HMessage* msg
) {
    cuda::atomic_ref<uint64_t, cuda::thread_scope_system> head_ref(header->head);
    cuda::atomic_ref<uint64_t, cuda::thread_scope_system> tail_ref(header->tail);

    uint64_t head = head_ref.load(cuda::memory_order_relaxed);
    uint64_t tail = tail_ref.load(cuda::memory_order_acquire);
    uint32_t capacity = header->capacity;

    if (head - tail >= capacity) {
        return false; // Full
    }

    uint32_t slot = head & header->mask;
    slots[slot] = *msg;

    head_ref.store(head + 1, cuda::memory_order_release);

    return true;
}

"#,
        );
    } else {
        code.push_str(
            r#"
// ============================================================================
// MESSAGE QUEUE OPERATIONS
// ============================================================================

// Copy H2K message from volatile source (field-by-field to handle volatile)
__device__ void copy_h2k_message(H2KMessage* dst, volatile H2KMessage* src) {
    dst->cmd = src->cmd;
    dst->flags = src->flags;
    dst->cmd_id = src->cmd_id;
    dst->param1 = src->param1;
    dst->param2 = src->param2;
    dst->param3 = src->param3;
    dst->param4 = src->param4;
    dst->param5 = src->param5;
    // Skip reserved fields for performance
}

// Copy K2H message to volatile destination (field-by-field to handle volatile)
__device__ void copy_k2h_message(volatile K2HMessage* dst, const K2HMessage* src) {
    dst->resp_type = src->resp_type;
    dst->flags = src->flags;
    dst->cmd_id = src->cmd_id;
    dst->step = src->step;
    dst->steps_remaining = src->steps_remaining;
    dst->energy = src->energy;
    dst->error_code = src->error_code;
    // Skip reserved fields for performance
}

// Try to receive H2K message (returns true if message available)
__device__ bool h2k_try_recv(
    volatile SpscQueueHeader* header,
    volatile H2KMessage* slots,
    H2KMessage* out_msg
) {
    // Fence BEFORE reading to ensure we see host writes
    __threadfence_system();

    uint64_t head = header->head;
    uint64_t tail = header->tail;

    if (head == tail) {
        return false; // Empty
    }

    uint32_t slot = tail & header->mask;
    copy_h2k_message(out_msg, &slots[slot]);

    __threadfence();
    header->tail = tail + 1;

    return true;
}

// Send K2H message
__device__ bool k2h_send(
    volatile SpscQueueHeader* header,
    volatile K2HMessage* slots,
    const K2HMessage* msg
) {
    uint64_t head = header->head;
    uint64_t tail = header->tail;
    uint32_t capacity = header->capacity;

    if (head - tail >= capacity) {
        return false; // Full
    }

    uint32_t slot = head & header->mask;
    copy_k2h_message(&slots[slot], msg);

    __threadfence_system(); // Ensure host sees our writes
    header->head = head + 1;

    return true;
}

"#,
        );
    }

    code.push_str(
        r#"
// ============================================================================
// ENERGY CALCULATION (Warp-Shuffle Reduction)
// ============================================================================

// Block-level parallel reduction for energy: E = sum(p^2)
// Two-phase approach: intra-warp shuffle (no __syncthreads), then
// cross-warp reduction via shared memory (one __syncthreads).
__device__ float block_reduce_energy(
    float my_energy,
    float* shared_reduce,
    int threads_per_block
) {
    int tid = threadIdx.x;
    int warp_id = tid / 32;
    int lane_id = tid % 32;
    int num_warps = threads_per_block / 32;

    // Phase 1: Intra-warp reduction via shuffle (no syncthreads needed)
    float val = my_energy;
    for (int offset = 16; offset > 0; offset >>= 1) {
        val += __shfl_down_sync(0xFFFFFFFF, val, offset);
    }

    // Warp leaders write partial sums to shared memory
    if (lane_id == 0) {
        shared_reduce[warp_id] = val;
    }
    __syncthreads();

    // Phase 2: First warp reduces across all warp results
    val = (tid < num_warps) ? shared_reduce[tid] : 0.0f;
    if (warp_id == 0) {
        for (int offset = 16; offset > 0; offset >>= 1) {
            val += __shfl_down_sync(0xFFFFFFFF, val, offset);
        }
    }

    // Thread 0 has the final block sum
    if (tid == 0) {
        shared_reduce[0] = val;
    }

    return shared_reduce[0];
}

"#,
    );

    code.push_str(&format!(
        r#"
// ============================================================================
// K2K HALO EXCHANGE
// ============================================================================

#define TILE_X {tx}
#define TILE_Y {ty}
#define TILE_Z {tz}
#define FACE_SIZE {face_size}

// Pack halo faces from shared memory to device halo buffer
__device__ void pack_halo_faces(
    float tile[TILE_Z + 2][TILE_Y + 2][TILE_X + 2],
    float* halo_buffers,
    int block_id,
    int pingpong,
    int num_blocks
) {{
    int tid = threadIdx.x;
    int face_stride = FACE_SIZE;
    int block_stride = 6 * face_stride * 2; // 6 faces, 2 ping-pong
    int pp_offset = pingpong * 6 * face_stride;

    float* block_halo = halo_buffers + block_id * block_stride + pp_offset;

    if (tid < FACE_SIZE) {{
        int fx = tid % TILE_X;
        int fy = tid / TILE_X;

        // +X face (x = TILE_X in local coords, which is index TILE_X)
        block_halo[FACE_POS_X * face_stride + tid] = tile[fy + 1][fx + 1][TILE_X];
        // -X face (x = 1 in local coords)
        block_halo[FACE_NEG_X * face_stride + tid] = tile[fy + 1][fx + 1][1];
        // +Y face
        block_halo[FACE_POS_Y * face_stride + tid] = tile[fy + 1][TILE_Y][fx + 1];
        // -Y face
        block_halo[FACE_NEG_Y * face_stride + tid] = tile[fy + 1][1][fx + 1];
        // +Z face
        block_halo[FACE_POS_Z * face_stride + tid] = tile[TILE_Z][fy + 1][fx + 1];
        // -Z face
        block_halo[FACE_NEG_Z * face_stride + tid] = tile[1][fy + 1][fx + 1];
    }}
}}

// Unpack halo faces from device buffer to shared memory ghost cells
__device__ void unpack_halo_faces(
    float tile[TILE_Z + 2][TILE_Y + 2][TILE_X + 2],
    const float* halo_buffers,
    const K2KRouteEntry* route,
    int pingpong,
    int num_blocks
) {{
    int tid = threadIdx.x;
    int face_stride = FACE_SIZE;
    int block_stride = 6 * face_stride * 2;
    int pp_offset = pingpong * 6 * face_stride;

    if (tid < FACE_SIZE) {{
        int fx = tid % TILE_X;
        int fy = tid / TILE_X;

        // My +X ghost comes from neighbor's -X face
        if (route->neighbors.pos_x >= 0) {{
            const float* n_halo = halo_buffers + route->neighbors.pos_x * block_stride + pp_offset;
            tile[fy + 1][fx + 1][TILE_X + 1] = n_halo[FACE_NEG_X * face_stride + tid];
        }}

        // My -X ghost comes from neighbor's +X face
        if (route->neighbors.neg_x >= 0) {{
            const float* n_halo = halo_buffers + route->neighbors.neg_x * block_stride + pp_offset;
            tile[fy + 1][fx + 1][0] = n_halo[FACE_POS_X * face_stride + tid];
        }}

        // My +Y ghost
        if (route->neighbors.pos_y >= 0) {{
            const float* n_halo = halo_buffers + route->neighbors.pos_y * block_stride + pp_offset;
            tile[fy + 1][TILE_Y + 1][fx + 1] = n_halo[FACE_NEG_Y * face_stride + tid];
        }}

        // My -Y ghost
        if (route->neighbors.neg_y >= 0) {{
            const float* n_halo = halo_buffers + route->neighbors.neg_y * block_stride + pp_offset;
            tile[fy + 1][0][fx + 1] = n_halo[FACE_POS_Y * face_stride + tid];
        }}

        // My +Z ghost
        if (route->neighbors.pos_z >= 0) {{
            const float* n_halo = halo_buffers + route->neighbors.pos_z * block_stride + pp_offset;
            tile[TILE_Z + 1][fy + 1][fx + 1] = n_halo[FACE_NEG_Z * face_stride + tid];
        }}

        // My -Z ghost
        if (route->neighbors.neg_z >= 0) {{
            const float* n_halo = halo_buffers + route->neighbors.neg_z * block_stride + pp_offset;
            tile[0][fy + 1][fx + 1] = n_halo[FACE_POS_Z * face_stride + tid];
        }}
    }}
}}

"#,
        tx = tx,
        ty = ty,
        tz = tz,
        face_size = face_size
    ));

    code
}

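/// Emit the persistent `__global__` entry point: the command-processing loop,
/// K2K halo exchange, the 7-point FDTD update, periodic energy reduction and
/// progress reporting, and the per-step buffer swap. Grid-wide synchronization
/// uses `grid.sync()` or `software_grid_sync()` depending on `use_cooperative`.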
fn generate_main_kernel(config: &PersistentFdtdConfig) -> String {
    let (tx, ty, tz) = config.tile_size;
    let threads_per_block = tx * ty * tz;
    let progress_interval = config.progress_interval;

    let grid_sync = if config.use_cooperative {
        "grid.sync();"
    } else {
        "software_grid_sync(&ctrl->barrier_counter, &ctrl->barrier_generation, num_blocks);"
    };

    format!(
        r#"
// ============================================================================
// MAIN PERSISTENT KERNEL
// ============================================================================

extern "C" __global__ void __launch_bounds__({threads_per_block}, 2)
{name}(
    PersistentControlBlock* __restrict__ ctrl,
    float* __restrict__ pressure_a,
    float* __restrict__ pressure_b,
    SpscQueueHeader* __restrict__ h2k_header,
    H2KMessage* __restrict__ h2k_slots,
    SpscQueueHeader* __restrict__ k2h_header,
    K2HMessage* __restrict__ k2h_slots,
    const K2KRouteEntry* __restrict__ routes,
    float* __restrict__ halo_buffers
) {{
    // Block and thread indices
    int block_id = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y;
    int num_blocks = gridDim.x * gridDim.y * gridDim.z;
    int tid = threadIdx.x;
    bool is_coordinator = (block_id == 0 && tid == 0);

    // Shared memory for tile + ghost cells
    __shared__ float tile[{tz} + 2][{ty} + 2][{tx} + 2];

    // Shared memory for energy reduction
    __shared__ float energy_reduce[{threads_per_block}];

    // Get my routing info
    const K2KRouteEntry* my_route = &routes[block_id];

    // Pointers to pressure buffers (will swap)
    float* p_curr = pressure_a;
    float* p_prev = pressure_b;

    // Grid dimensions for indexing
    int sim_x = ctrl->sim_size[0];
    int sim_y = ctrl->sim_size[1];
    int sim_z = ctrl->sim_size[2];

    // Local cell coordinates within tile
    int lx = tid % {tx};
    int ly = (tid / {tx}) % {ty};
    int lz = tid / ({tx} * {ty});

    // Global cell coordinates
    int gx = my_route->cell_offset[0] + lx;
    int gy = my_route->cell_offset[1] + ly;
    int gz = my_route->cell_offset[2] + lz;
    int global_idx = gx + gy * sim_x + gz * sim_x * sim_y;

    // Physics parameters
    float c2_dt2 = ctrl->c2_dt2;
    float damping = ctrl->damping;

    {grid_sync_init}

    // ========== PERSISTENT LOOP ==========
    while (true) {{
        // --- Phase 1: Command Processing (coordinator only) ---
        if (is_coordinator) {{
            H2KMessage cmd;
            while (h2k_try_recv(h2k_header, h2k_slots, &cmd)) {{
                ctrl->messages_processed++;

                switch (cmd.cmd) {{
                    case CMD_RUN_STEPS:
                        ctrl->steps_remaining += cmd.param1;
                        break;

                    case CMD_TERMINATE:
                        ctrl->should_terminate = 1;
                        break;

                    case CMD_INJECT_IMPULSE: {{
                        // Extract position from params
                        uint32_t ix = (cmd.param1 >> 32) & 0xFFFFFFFF;
                        uint32_t iy = cmd.param1 & 0xFFFFFFFF;
                        uint32_t iz = cmd.param2;
                        float amplitude = cmd.param4;

                        if (ix < sim_x && iy < sim_y && iz < sim_z) {{
                            int idx = ix + iy * sim_x + iz * sim_x * sim_y;
                            p_curr[idx] += amplitude;
                        }}
                        break;
                    }}

                    case CMD_GET_PROGRESS: {{
                        K2HMessage resp;
                        resp.resp_type = RESP_PROGRESS;
                        resp.cmd_id = cmd.cmd_id;
                        resp.step = ctrl->current_step;
                        resp.steps_remaining = ctrl->steps_remaining;
                        resp.energy = ctrl->total_energy;
                        k2h_send(k2h_header, k2h_slots, &resp);
                        break;
                    }}

                    default:
                        break;
                }}
            }}
        }}

        // Grid-wide sync so all blocks see updated control state
        {grid_sync}

        // Check for termination
        if (ctrl->should_terminate) {{
            break;
        }}

        // --- Phase 2: Check if we have work ---
        if (ctrl->steps_remaining == 0) {{
            // No work - sleep to reduce power consumption, then check again
            __nanosleep({idle_sleep_ns});
            {grid_sync}
            continue;
        }}

        // --- Phase 3: Load tile from global memory ---
        // Load interior cells
        if (gx < sim_x && gy < sim_y && gz < sim_z) {{
            tile[lz + 1][ly + 1][lx + 1] = p_curr[global_idx];
        }} else {{
            tile[lz + 1][ly + 1][lx + 1] = 0.0f;
        }}

        // Initialize ghost cells to boundary (zero pressure)
        if (lx == 0) tile[lz + 1][ly + 1][0] = 0.0f;
        if (lx == {tx} - 1) tile[lz + 1][ly + 1][{tx} + 1] = 0.0f;
        if (ly == 0) tile[lz + 1][0][lx + 1] = 0.0f;
        if (ly == {ty} - 1) tile[lz + 1][{ty} + 1][lx + 1] = 0.0f;
        if (lz == 0) tile[0][ly + 1][lx + 1] = 0.0f;
        if (lz == {tz} - 1) tile[{tz} + 1][ly + 1][lx + 1] = 0.0f;

        __syncthreads();

        // --- Phase 4: K2K Halo Exchange ---
        int pingpong = ctrl->current_step & 1;

        // Pack my boundary faces into halo buffers
        pack_halo_faces(tile, halo_buffers, block_id, pingpong, num_blocks);
        __threadfence(); // Ensure writes visible to other blocks

        // Grid-wide sync - wait for all blocks to finish packing
        {grid_sync}

        // Unpack neighbor faces into my ghost cells
        unpack_halo_faces(tile, halo_buffers, my_route, pingpong, num_blocks);
        __syncthreads();

        // --- Phase 5: FDTD Computation ---
        if (gx < sim_x && gy < sim_y && gz < sim_z) {{
            // 7-point Laplacian from shared memory
            float center = tile[lz + 1][ly + 1][lx + 1];
            float lap = tile[lz + 1][ly + 1][lx + 2]  // +X
                + tile[lz + 1][ly + 1][lx]            // -X
                + tile[lz + 1][ly + 2][lx + 1]        // +Y
                + tile[lz + 1][ly][lx + 1]            // -Y
                + tile[lz + 2][ly + 1][lx + 1]        // +Z
                + tile[lz][ly + 1][lx + 1]            // -Z
                - 6.0f * center;

            // FDTD update: p_new = 2*p - p_prev + c^2*dt^2*lap
            float p_prev_val = p_prev[global_idx];
            float p_new = 2.0f * center - p_prev_val + c2_dt2 * lap;
            p_new *= damping;

            // Write to "previous" buffer (will become current after swap)
            p_prev[global_idx] = p_new;
        }}

        // --- Phase 5b: Energy Calculation (periodic) ---
        // Check if this step will report progress (after counter increment)
        bool is_progress_step = ((ctrl->current_step + 1) % {progress_interval}) == 0;

        if (is_progress_step) {{
            // Reset energy accumulator (coordinator only, before reduction)
            if (is_coordinator) {{
                ctrl->total_energy = 0.0f;
            }}
            __threadfence(); // Ensure reset is visible

            // Compute this thread's energy contribution: E = p^2
            float my_energy = 0.0f;
            if (gx < sim_x && gy < sim_y && gz < sim_z) {{
                float p_val = p_prev[global_idx]; // Just written value
                my_energy = p_val * p_val;
            }}

            // Block-level parallel reduction
            float block_energy = block_reduce_energy(my_energy, energy_reduce, {threads_per_block});

            // Block leader accumulates to global energy
            if (tid == 0) {{
                atomicAdd(&ctrl->total_energy, block_energy);
            }}
        }}

        // Grid-wide sync before buffer swap
        {grid_sync}

        // --- Phase 6: Buffer Swap & Progress ---
        if (is_coordinator) {{
            // Swap buffer pointers (toggle index)
            ctrl->current_buffer ^= 1;

            // Update counters
            ctrl->current_step++;
            ctrl->steps_remaining--;

            // Send progress update periodically
            if (ctrl->current_step % {progress_interval} == 0) {{
                K2HMessage resp;
                resp.resp_type = RESP_PROGRESS;
                resp.cmd_id = 0;
                resp.step = ctrl->current_step;
                resp.steps_remaining = ctrl->steps_remaining;
                resp.energy = ctrl->total_energy; // Computed in Phase 5b
                k2h_send(k2h_header, k2h_slots, &resp);
            }}
        }}

        // Swap our local pointers
        float* tmp = p_curr;
        p_curr = p_prev;
        p_prev = tmp;

        {grid_sync}
    }}

    // ========== FINAL ENERGY CALCULATION ==========
    // Compute final energy before termination (all threads participate)
    if (is_coordinator) {{
        ctrl->total_energy = 0.0f;
    }}
    __threadfence();

    // Each thread's final energy contribution
    float final_energy = 0.0f;
    if (gx < sim_x && gy < sim_y && gz < sim_z) {{
        float p_val = p_curr[global_idx]; // Current buffer has final state
        final_energy = p_val * p_val;
    }}

    // Block-level parallel reduction
    float block_final_energy = block_reduce_energy(final_energy, energy_reduce, {threads_per_block});

    // Block leader accumulates to global energy
    if (tid == 0) {{
        atomicAdd(&ctrl->total_energy, block_final_energy);
    }}

    {grid_sync}

    // ========== CLEANUP ==========
    if (is_coordinator) {{
        ctrl->has_terminated = 1;

        K2HMessage resp;
        resp.resp_type = RESP_TERMINATED;
        resp.cmd_id = 0;
        resp.step = ctrl->current_step;
        resp.steps_remaining = 0;
        resp.energy = ctrl->total_energy; // Final computed energy
        k2h_send(k2h_header, k2h_slots, &resp);
    }}
}}
"#,
        name = config.name,
        tx = tx,
        ty = ty,
        tz = tz,
        threads_per_block = threads_per_block,
        progress_interval = progress_interval,
        grid_sync_init = if config.use_cooperative {
            "cg::grid_group grid = cg::this_grid();"
        } else {
            ""
        },
        grid_sync = grid_sync,
        idle_sleep_ns = config.idle_sleep_ns,
    )
}

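/// Compile the generated kernel to PTX by invoking `nvcc` (which must be on
/// `PATH`). Only built with the `nvcc` feature. `-arch=native` targets the GPU
/// visible on the build machine, and cooperative kernels are compiled with
/// `-rdc=true`.
///
/// ```ignore
/// let ptx = compile_persistent_fdtd_to_ptx(&PersistentFdtdConfig::default())?;
/// ```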
#[cfg(feature = "nvcc")]
pub fn compile_persistent_fdtd_to_ptx(config: &PersistentFdtdConfig) -> Result<String, String> {
    let cuda_code = generate_persistent_fdtd_kernel(config);

    use std::process::Command;

    let temp_dir = std::env::temp_dir();
    let cuda_file = temp_dir.join("persistent_fdtd.cu");
    let ptx_file = temp_dir.join("persistent_fdtd.ptx");

    std::fs::write(&cuda_file, &cuda_code)
        .map_err(|e| format!("Failed to write CUDA file: {}", e))?;

    let mut args = vec![
        "-ptx".to_string(),
        "-o".to_string(),
        ptx_file.to_string_lossy().to_string(),
        cuda_file.to_string_lossy().to_string(),
        "-arch=native".to_string(),
        "-std=c++17".to_string(),
    ];

    if config.use_cooperative {
        args.push("-rdc=true".to_string());
    }

    let output = Command::new("nvcc")
        .args(&args)
        .output()
        .map_err(|e| format!("Failed to run nvcc: {}", e))?;

    if !output.status.success() {
        return Err(format!(
            "nvcc compilation failed:\n{}",
            String::from_utf8_lossy(&output.stderr)
        ));
    }

    std::fs::read_to_string(&ptx_file).map_err(|e| format!("Failed to read PTX: {}", e))
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = PersistentFdtdConfig::default();
        assert_eq!(config.name, "persistent_fdtd3d");
        assert_eq!(config.tile_size, (8, 8, 8));
        assert_eq!(config.threads_per_block(), 512);
    }

    #[test]
    fn test_config_builder() {
        let config = PersistentFdtdConfig::new("my_kernel")
            .with_tile_size(4, 4, 4)
            .with_cooperative(false)
            .with_progress_interval(50);

        assert_eq!(config.name, "my_kernel");
        assert_eq!(config.tile_size, (4, 4, 4));
        assert_eq!(config.threads_per_block(), 64);
        assert!(!config.use_cooperative);
        assert_eq!(config.progress_interval, 50);
    }

    #[test]
    fn test_shared_mem_calculation() {
        let config = PersistentFdtdConfig::default();
        assert_eq!(config.shared_mem_size(), 4000);
    }

    #[test]
    fn test_generate_kernel_cooperative() {
        let config = PersistentFdtdConfig::new("test_kernel").with_cooperative(true);

        let code = generate_persistent_fdtd_kernel(&config);

        assert!(code.contains("#include <cooperative_groups.h>"));
        assert!(code.contains("namespace cg = cooperative_groups;"));

        assert!(code.contains("typedef struct __align__(256)"));
        assert!(code.contains("PersistentControlBlock"));
        assert!(code.contains("SpscQueueHeader"));
        assert!(code.contains("H2KMessage"));
        assert!(code.contains("K2HMessage"));

        assert!(code.contains("__device__ bool h2k_try_recv"));
        assert!(code.contains("__device__ bool k2h_send"));
        assert!(code.contains("__device__ void pack_halo_faces"));
        assert!(code.contains("__device__ void unpack_halo_faces"));

        assert!(code.contains("extern \"C\" __global__ void"));
        assert!(code.contains("test_kernel"));
        assert!(code.contains("cg::grid_group grid = cg::this_grid()"));
        assert!(code.contains("grid.sync()"));

        assert!(code.contains("7-point Laplacian"));
        assert!(code.contains("c2_dt2 * lap"));
    }

    #[test]
    fn test_generate_kernel_software_sync() {
        let config = PersistentFdtdConfig::new("test_kernel").with_cooperative(false);

        let code = generate_persistent_fdtd_kernel(&config);

        assert!(!code.contains("#include <cooperative_groups.h>"));

        assert!(code.contains("software_grid_sync"));
        assert!(code.contains("atomicAdd"));
    }

    #[test]
    fn test_generate_kernel_command_handling() {
        let config = PersistentFdtdConfig::default();
        let code = generate_persistent_fdtd_kernel(&config);

        assert!(code.contains("CMD_RUN_STEPS"));
        assert!(code.contains("CMD_TERMINATE"));
        assert!(code.contains("CMD_INJECT_IMPULSE"));
        assert!(code.contains("CMD_GET_PROGRESS"));

        assert!(code.contains("RESP_PROGRESS"));
        assert!(code.contains("RESP_TERMINATED"));
    }

    #[test]
    fn test_generate_kernel_halo_exchange() {
        let config = PersistentFdtdConfig::new("test").with_tile_size(8, 8, 8);

        let code = generate_persistent_fdtd_kernel(&config);

        assert!(code.contains("#define TILE_X 8"));
        assert!(code.contains("#define TILE_Y 8"));
        assert!(code.contains("#define TILE_Z 8"));
        assert!(code.contains("#define FACE_SIZE 64"));

        assert!(code.contains("FACE_POS_X"));
        assert!(code.contains("FACE_NEG_Z"));

        assert!(code.contains("pack_halo_faces"));
        assert!(code.contains("unpack_halo_faces"));
    }

    #[test]
    fn test_kernel_contains_persistent_loop() {
        let config = PersistentFdtdConfig::default();
        let code = generate_persistent_fdtd_kernel(&config);

        assert!(code.contains("while (true)"));
        assert!(code.contains("if (ctrl->should_terminate)"));
        assert!(code.contains("break;"));

        assert!(code.contains("if (ctrl->steps_remaining == 0)"));
        assert!(code.contains("__nanosleep(1000)"));
    }

    #[test]
    fn test_idle_sleep_configurable() {
        let config = PersistentFdtdConfig::new("test_sleep").with_idle_sleep(5000);
        let code = generate_persistent_fdtd_kernel(&config);

        assert!(code.contains("__nanosleep(5000)"));
        assert!(!code.contains("__nanosleep(1000)"));
    }

    #[test]
    fn test_software_barrier_nanosleep() {
        let config = PersistentFdtdConfig::new("test_barrier").with_cooperative(false);
        let code = generate_persistent_fdtd_kernel(&config);

        assert!(code.contains("__nanosleep(100)"));
        assert!(code.contains("Reduce power in barrier spin"));
    }

    #[test]
    fn test_kernel_contains_energy_calculation() {
        let config = PersistentFdtdConfig::new("test_energy")
            .with_tile_size(8, 8, 8)
            .with_progress_interval(100);

        let code = generate_persistent_fdtd_kernel(&config);

        assert!(code.contains("block_reduce_energy"));
        assert!(code.contains("E = sum(p^2)"));
        assert!(code.contains("Warp-Shuffle Reduction"));
        assert!(code.contains("__shfl_down_sync(0xFFFFFFFF, val, offset)"));

        assert!(code.contains("energy_reduce[512]"));
        assert!(code.contains("is_progress_step"));
        assert!(code.contains("% 100"));

        assert!(code.contains("ctrl->total_energy = 0.0f"));

        assert!(code.contains("p_val * p_val"));

        assert!(code.contains("atomicAdd(&ctrl->total_energy"));

        assert!(code.contains("resp.energy = ctrl->total_energy"));

        assert!(code.contains("FINAL ENERGY CALCULATION"));
        assert!(code.contains("block_final_energy"));
    }

    #[test]
    fn test_generate_kernel_libcupp_atomics() {
        let config = PersistentFdtdConfig::new("test_libcupp")
            .with_libcupp_atomics(true)
            .with_cooperative(false);

        let code = generate_persistent_fdtd_kernel(&config);

        assert!(code.contains("#include <cuda/atomic>"));
        assert!(code.contains("__CUDACC_VER_MAJOR__ < 11"));

        assert!(code.contains("cuda::atomic_ref<uint64_t, cuda::thread_scope_system>"));
        assert!(code.contains("memory_order_acquire"));
        assert!(code.contains("memory_order_release"));

        assert!(!code.contains("__threadfence_system()"));

        assert!(code.contains("cuda::atomic_ref<unsigned int, cuda::thread_scope_device>"));
        assert!(code.contains("memory_order_acq_rel"));
    }

    #[test]
    fn test_generate_kernel_default_no_libcupp() {
        let config = PersistentFdtdConfig::default();
        let code = generate_persistent_fdtd_kernel(&config);

        assert!(!code.contains("#include <cuda/atomic>"));
        assert!(!code.contains("cuda::atomic_ref"));

        assert!(code.contains("__threadfence_system()"));
        assert!(code.contains("volatile"));
    }
}
1273}