/// Options that drive generation of the persistent FDTD CUDA kernel source.
///
/// `PartialEq`/`Eq` are derived so configurations can be compared directly
/// (for example in tests or when caching generated code per configuration).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PersistentFdtdConfig {
    /// Identifier emitted as the name of the generated `extern "C"` kernel.
    pub name: String,
    /// Tile extent `(x, y, z)` in cells. The generated kernel sizes its shared
    /// tile from these values and uses `x * y * z` threads per block.
    pub tile_size: (usize, usize, usize),
    /// When `true`, the generated code includes `<cooperative_groups.h>` and
    /// synchronizes with `grid.sync()`; otherwise it calls the generated
    /// `software_grid_sync` barrier.
    pub use_cooperative: bool,
    /// Step interval substituted into the kernel's periodic progress/energy
    /// modulo checks.
    pub progress_interval: u64,
    /// Energy-tracking flag.
    ///
    /// NOTE(review): none of the generator functions in this part of the file
    /// read this field — confirm whether it is consumed elsewhere.
    pub track_energy: bool,
}
48
49impl Default for PersistentFdtdConfig {
50 fn default() -> Self {
51 Self {
52 name: "persistent_fdtd3d".to_string(),
53 tile_size: (8, 8, 8),
54 use_cooperative: true,
55 progress_interval: 100,
56 track_energy: true,
57 }
58 }
59}
60
61impl PersistentFdtdConfig {
62 pub fn new(name: &str) -> Self {
64 Self {
65 name: name.to_string(),
66 ..Default::default()
67 }
68 }
69
70 pub fn with_tile_size(mut self, tx: usize, ty: usize, tz: usize) -> Self {
72 self.tile_size = (tx, ty, tz);
73 self
74 }
75
76 pub fn with_cooperative(mut self, use_coop: bool) -> Self {
78 self.use_cooperative = use_coop;
79 self
80 }
81
82 pub fn with_progress_interval(mut self, interval: u64) -> Self {
84 self.progress_interval = interval;
85 self
86 }
87
88 pub fn threads_per_block(&self) -> usize {
90 self.tile_size.0 * self.tile_size.1 * self.tile_size.2
91 }
92
93 pub fn shared_mem_size(&self) -> usize {
95 let (tx, ty, tz) = self.tile_size;
96 let with_halo = (tx + 2) * (ty + 2) * (tz + 2);
98 with_halo * std::mem::size_of::<f32>()
99 }
100}
101
102pub fn generate_persistent_fdtd_kernel(config: &PersistentFdtdConfig) -> String {
111 let mut code = String::new();
112
113 code.push_str(&generate_header(config));
115
116 code.push_str(&generate_structures());
118
119 code.push_str(&generate_device_functions(config));
121
122 code.push_str(&generate_main_kernel(config));
124
125 code
126}
127
128fn generate_header(config: &PersistentFdtdConfig) -> String {
129 let mut code = String::new();
130
131 code.push_str("// Generated Persistent FDTD Kernel\n");
132 code.push_str("// RingKernel GPU Actor System\n\n");
133
134 if config.use_cooperative {
135 code.push_str("#include <cooperative_groups.h>\n");
136 code.push_str("namespace cg = cooperative_groups;\n\n");
137 }
138
139 code.push_str("#include <cuda_runtime.h>\n");
140 code.push_str("#include <stdint.h>\n\n");
141
142 code
143}
144
/// Returns the fixed CUDA declarations shared by the generated kernel: the
/// control block, SPSC queue header, H2K/K2H message structs, command and
/// response codes, neighbor/route structs, and face indices.
///
/// The text takes no configuration input; per its own banner it must mirror
/// the host-side definitions in `persistent.rs`, so the layout is emitted
/// verbatim.
///
/// NOTE(review): the fields of `PersistentControlBlock` sum to 264 bytes, so
/// with `__align__(256)` its size would be 512 rather than the 256 stated in
/// the CUDA comment; likewise `H2KMessage` fields sum to 72 bytes (128 after
/// 64-byte alignment). Confirm against the host structs.
fn generate_structures() -> String {
    const STRUCTURE_DEFINITIONS: &str = r#"
// ============================================================================
// STRUCTURE DEFINITIONS (must match Rust persistent.rs)
// ============================================================================

// Control block for persistent kernel (256 bytes, cache-aligned)
typedef struct __align__(256) {
    // Lifecycle control (host -> GPU)
    uint32_t should_terminate;
    uint32_t _pad0;
    uint64_t steps_remaining;

    // State (GPU -> host)
    uint64_t current_step;
    uint32_t current_buffer;
    uint32_t has_terminated;
    float total_energy;
    uint32_t _pad1;

    // Configuration
    uint32_t grid_dim[3];
    uint32_t block_dim[3];
    uint32_t sim_size[3];
    uint32_t tile_size[3];

    // Acoustic parameters
    float c2_dt2;
    float damping;
    float cell_size;
    float dt;

    // Synchronization
    uint32_t barrier_counter;
    uint32_t barrier_generation;

    // Statistics
    uint64_t messages_processed;
    uint64_t k2k_messages_sent;
    uint64_t k2k_messages_received;

    uint64_t _reserved[16];
} PersistentControlBlock;

// SPSC queue header (128 bytes, cache-aligned)
typedef struct __align__(128) {
    uint64_t head;
    uint64_t tail;
    uint32_t capacity;
    uint32_t mask;
    uint64_t _padding[12];
} SpscQueueHeader;

// H2K message (64 bytes)
typedef struct __align__(64) {
    uint32_t cmd;
    uint32_t flags;
    uint64_t cmd_id;
    uint64_t param1;
    uint32_t param2;
    uint32_t param3;
    float param4;
    float param5;
    uint64_t _reserved[4];
} H2KMessage;

// K2H message (64 bytes)
typedef struct __align__(64) {
    uint32_t resp_type;
    uint32_t flags;
    uint64_t cmd_id;
    uint64_t step;
    uint64_t steps_remaining;
    float energy;
    uint32_t error_code;
    uint64_t _reserved[3];
} K2HMessage;

// Command types
#define CMD_NOP 0
#define CMD_RUN_STEPS 1
#define CMD_PAUSE 2
#define CMD_RESUME 3
#define CMD_TERMINATE 4
#define CMD_INJECT_IMPULSE 5
#define CMD_SET_SOURCE 6
#define CMD_GET_PROGRESS 7

// Response types
#define RESP_ACK 0
#define RESP_PROGRESS 1
#define RESP_ERROR 2
#define RESP_TERMINATED 3
#define RESP_ENERGY 4

// Neighbor block IDs
typedef struct {
    int32_t pos_x, neg_x;
    int32_t pos_y, neg_y;
    int32_t pos_z, neg_z;
    int32_t _padding[2];
} BlockNeighbors;

// K2K route entry
typedef struct {
    BlockNeighbors neighbors;
    uint32_t block_pos[3];
    uint32_t _padding;
    uint32_t cell_offset[3];
    uint32_t _padding2;
} K2KRouteEntry;

// Face indices
#define FACE_POS_X 0
#define FACE_NEG_X 1
#define FACE_POS_Y 2
#define FACE_NEG_Y 3
#define FACE_POS_Z 4
#define FACE_NEG_Z 5

"#;

    STRUCTURE_DEFINITIONS.to_owned()
}
268
269fn generate_device_functions(config: &PersistentFdtdConfig) -> String {
270 let (tx, ty, tz) = config.tile_size;
271 let face_size = tx * ty;
272
273 let mut code = String::new();
274
275 code.push_str(
277 r#"
278// ============================================================================
279// SYNCHRONIZATION FUNCTIONS
280// ============================================================================
281
282// Software grid barrier (atomic counter + generation)
283__device__ void software_grid_sync(
284 volatile uint32_t* barrier_counter,
285 volatile uint32_t* barrier_gen,
286 int num_blocks
287) {
288 __syncthreads(); // First sync within block
289
290 if (threadIdx.x == 0) {
291 unsigned int gen = *barrier_gen;
292
293 unsigned int arrived = atomicAdd((unsigned int*)barrier_counter, 1) + 1;
294 if (arrived == num_blocks) {
295 *barrier_counter = 0;
296 __threadfence();
297 atomicAdd((unsigned int*)barrier_gen, 1);
298 } else {
299 while (atomicAdd((unsigned int*)barrier_gen, 0) == gen) {
300 __threadfence();
301 }
302 }
303 }
304
305 __syncthreads();
306}
307
308"#,
309 );
310
311 code.push_str(
313 r#"
314// ============================================================================
315// MESSAGE QUEUE OPERATIONS
316// ============================================================================
317
318// Copy H2K message from volatile source (field-by-field to handle volatile)
319__device__ void copy_h2k_message(H2KMessage* dst, volatile H2KMessage* src) {
320 dst->cmd = src->cmd;
321 dst->flags = src->flags;
322 dst->cmd_id = src->cmd_id;
323 dst->param1 = src->param1;
324 dst->param2 = src->param2;
325 dst->param3 = src->param3;
326 dst->param4 = src->param4;
327 dst->param5 = src->param5;
328 // Skip reserved fields for performance
329}
330
331// Copy K2H message to volatile destination (field-by-field to handle volatile)
332__device__ void copy_k2h_message(volatile K2HMessage* dst, const K2HMessage* src) {
333 dst->resp_type = src->resp_type;
334 dst->flags = src->flags;
335 dst->cmd_id = src->cmd_id;
336 dst->step = src->step;
337 dst->steps_remaining = src->steps_remaining;
338 dst->energy = src->energy;
339 dst->error_code = src->error_code;
340 // Skip reserved fields for performance
341}
342
343// Try to receive H2K message (returns true if message available)
344__device__ bool h2k_try_recv(
345 volatile SpscQueueHeader* header,
346 volatile H2KMessage* slots,
347 H2KMessage* out_msg
348) {
349 // Fence BEFORE reading to ensure we see host writes
350 __threadfence_system();
351
352 uint64_t head = header->head;
353 uint64_t tail = header->tail;
354
355 if (head == tail) {
356 return false; // Empty
357 }
358
359 uint32_t slot = tail & header->mask;
360 copy_h2k_message(out_msg, &slots[slot]);
361
362 __threadfence();
363 header->tail = tail + 1;
364
365 return true;
366}
367
368// Send K2H message
369__device__ bool k2h_send(
370 volatile SpscQueueHeader* header,
371 volatile K2HMessage* slots,
372 const K2HMessage* msg
373) {
374 uint64_t head = header->head;
375 uint64_t tail = header->tail;
376 uint32_t capacity = header->capacity;
377
378 if (head - tail >= capacity) {
379 return false; // Full
380 }
381
382 uint32_t slot = head & header->mask;
383 copy_k2h_message(&slots[slot], msg);
384
385 __threadfence_system(); // Ensure host sees our writes
386 header->head = head + 1;
387
388 return true;
389}
390
391"#,
392 );
393
394 code.push_str(
396 r#"
397// ============================================================================
398// ENERGY CALCULATION (Parallel Reduction)
399// ============================================================================
400
401// Block-level parallel reduction for energy: E = sum(p^2)
402// Uses shared memory for efficient reduction within a block
403__device__ float block_reduce_energy(
404 float my_energy,
405 float* shared_reduce,
406 int threads_per_block
407) {
408 int tid = threadIdx.x;
409
410 // Store initial value in shared memory
411 shared_reduce[tid] = my_energy;
412 __syncthreads();
413
414 // Parallel reduction tree
415 for (int stride = threads_per_block / 2; stride > 0; stride >>= 1) {
416 if (tid < stride) {
417 shared_reduce[tid] += shared_reduce[tid + stride];
418 }
419 __syncthreads();
420 }
421
422 // Return the total for this block (only thread 0 has final sum)
423 return shared_reduce[0];
424}
425
426"#,
427 );
428
429 code.push_str(&format!(
431 r#"
432// ============================================================================
433// K2K HALO EXCHANGE
434// ============================================================================
435
436#define TILE_X {tx}
437#define TILE_Y {ty}
438#define TILE_Z {tz}
439#define FACE_SIZE {face_size}
440
441// Pack halo faces from shared memory to device halo buffer
442__device__ void pack_halo_faces(
443 float tile[TILE_Z + 2][TILE_Y + 2][TILE_X + 2],
444 float* halo_buffers,
445 int block_id,
446 int pingpong,
447 int num_blocks
448) {{
449 int tid = threadIdx.x;
450 int face_stride = FACE_SIZE;
451 int block_stride = 6 * face_stride * 2; // 6 faces, 2 ping-pong
452 int pp_offset = pingpong * 6 * face_stride;
453
454 float* block_halo = halo_buffers + block_id * block_stride + pp_offset;
455
456 if (tid < FACE_SIZE) {{
457 int fx = tid % TILE_X;
458 int fy = tid / TILE_X;
459
460 // +X face (x = TILE_X in local coords, which is index TILE_X)
461 block_halo[FACE_POS_X * face_stride + tid] = tile[fy + 1][fx + 1][TILE_X];
462 // -X face (x = 1 in local coords)
463 block_halo[FACE_NEG_X * face_stride + tid] = tile[fy + 1][fx + 1][1];
464 // +Y face
465 block_halo[FACE_POS_Y * face_stride + tid] = tile[fy + 1][TILE_Y][fx + 1];
466 // -Y face
467 block_halo[FACE_NEG_Y * face_stride + tid] = tile[fy + 1][1][fx + 1];
468 // +Z face
469 block_halo[FACE_POS_Z * face_stride + tid] = tile[TILE_Z][fy + 1][fx + 1];
470 // -Z face
471 block_halo[FACE_NEG_Z * face_stride + tid] = tile[1][fy + 1][fx + 1];
472 }}
473}}
474
475// Unpack halo faces from device buffer to shared memory ghost cells
476__device__ void unpack_halo_faces(
477 float tile[TILE_Z + 2][TILE_Y + 2][TILE_X + 2],
478 const float* halo_buffers,
479 const K2KRouteEntry* route,
480 int pingpong,
481 int num_blocks
482) {{
483 int tid = threadIdx.x;
484 int face_stride = FACE_SIZE;
485 int block_stride = 6 * face_stride * 2;
486 int pp_offset = pingpong * 6 * face_stride;
487
488 if (tid < FACE_SIZE) {{
489 int fx = tid % TILE_X;
490 int fy = tid / TILE_X;
491
492 // My +X ghost comes from neighbor's -X face
493 if (route->neighbors.pos_x >= 0) {{
494 const float* n_halo = halo_buffers + route->neighbors.pos_x * block_stride + pp_offset;
495 tile[fy + 1][fx + 1][TILE_X + 1] = n_halo[FACE_NEG_X * face_stride + tid];
496 }}
497
498 // My -X ghost comes from neighbor's +X face
499 if (route->neighbors.neg_x >= 0) {{
500 const float* n_halo = halo_buffers + route->neighbors.neg_x * block_stride + pp_offset;
501 tile[fy + 1][fx + 1][0] = n_halo[FACE_POS_X * face_stride + tid];
502 }}
503
504 // My +Y ghost
505 if (route->neighbors.pos_y >= 0) {{
506 const float* n_halo = halo_buffers + route->neighbors.pos_y * block_stride + pp_offset;
507 tile[fy + 1][TILE_Y + 1][fx + 1] = n_halo[FACE_NEG_Y * face_stride + tid];
508 }}
509
510 // My -Y ghost
511 if (route->neighbors.neg_y >= 0) {{
512 const float* n_halo = halo_buffers + route->neighbors.neg_y * block_stride + pp_offset;
513 tile[fy + 1][0][fx + 1] = n_halo[FACE_POS_Y * face_stride + tid];
514 }}
515
516 // My +Z ghost
517 if (route->neighbors.pos_z >= 0) {{
518 const float* n_halo = halo_buffers + route->neighbors.pos_z * block_stride + pp_offset;
519 tile[TILE_Z + 1][fy + 1][fx + 1] = n_halo[FACE_NEG_Z * face_stride + tid];
520 }}
521
522 // My -Z ghost
523 if (route->neighbors.neg_z >= 0) {{
524 const float* n_halo = halo_buffers + route->neighbors.neg_z * block_stride + pp_offset;
525 tile[0][fy + 1][fx + 1] = n_halo[FACE_POS_Z * face_stride + tid];
526 }}
527 }}
528}}
529
530"#,
531 tx = tx,
532 ty = ty,
533 tz = tz,
534 face_size = face_size
535 ));
536
537 code
538}
539
540fn generate_main_kernel(config: &PersistentFdtdConfig) -> String {
541 let (tx, ty, tz) = config.tile_size;
542 let threads_per_block = tx * ty * tz;
543 let progress_interval = config.progress_interval;
544
545 let grid_sync = if config.use_cooperative {
548 "grid.sync();"
549 } else {
550 "software_grid_sync(&ctrl->barrier_counter, &ctrl->barrier_generation, num_blocks);"
551 };
552
553 format!(
554 r#"
555// ============================================================================
556// MAIN PERSISTENT KERNEL
557// ============================================================================
558
559extern "C" __global__ void __launch_bounds__({threads_per_block}, 2)
560{name}(
561 PersistentControlBlock* __restrict__ ctrl,
562 float* __restrict__ pressure_a,
563 float* __restrict__ pressure_b,
564 SpscQueueHeader* __restrict__ h2k_header,
565 H2KMessage* __restrict__ h2k_slots,
566 SpscQueueHeader* __restrict__ k2h_header,
567 K2HMessage* __restrict__ k2h_slots,
568 const K2KRouteEntry* __restrict__ routes,
569 float* __restrict__ halo_buffers
570) {{
571 // Block and thread indices
572 int block_id = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y;
573 int num_blocks = gridDim.x * gridDim.y * gridDim.z;
574 int tid = threadIdx.x;
575 bool is_coordinator = (block_id == 0 && tid == 0);
576
577 // Shared memory for tile + ghost cells
578 __shared__ float tile[{tz} + 2][{ty} + 2][{tx} + 2];
579
580 // Shared memory for energy reduction
581 __shared__ float energy_reduce[{threads_per_block}];
582
583 // Get my routing info
584 const K2KRouteEntry* my_route = &routes[block_id];
585
586 // Pointers to pressure buffers (will swap)
587 float* p_curr = pressure_a;
588 float* p_prev = pressure_b;
589
590 // Grid dimensions for indexing
591 int sim_x = ctrl->sim_size[0];
592 int sim_y = ctrl->sim_size[1];
593 int sim_z = ctrl->sim_size[2];
594
595 // Local cell coordinates within tile
596 int lx = tid % {tx};
597 int ly = (tid / {tx}) % {ty};
598 int lz = tid / ({tx} * {ty});
599
600 // Global cell coordinates
601 int gx = my_route->cell_offset[0] + lx;
602 int gy = my_route->cell_offset[1] + ly;
603 int gz = my_route->cell_offset[2] + lz;
604 int global_idx = gx + gy * sim_x + gz * sim_x * sim_y;
605
606 // Physics parameters
607 float c2_dt2 = ctrl->c2_dt2;
608 float damping = ctrl->damping;
609
610 {grid_sync_init}
611
612 // ========== PERSISTENT LOOP ==========
613 while (true) {{
614 // --- Phase 1: Command Processing (coordinator only) ---
615 if (is_coordinator) {{
616 H2KMessage cmd;
617 while (h2k_try_recv(h2k_header, h2k_slots, &cmd)) {{
618 ctrl->messages_processed++;
619
620 switch (cmd.cmd) {{
621 case CMD_RUN_STEPS:
622 ctrl->steps_remaining += cmd.param1;
623 break;
624
625 case CMD_TERMINATE:
626 ctrl->should_terminate = 1;
627 break;
628
629 case CMD_INJECT_IMPULSE: {{
630 // Extract position from params
631 uint32_t ix = (cmd.param1 >> 32) & 0xFFFFFFFF;
632 uint32_t iy = cmd.param1 & 0xFFFFFFFF;
633 uint32_t iz = cmd.param2;
634 float amplitude = cmd.param4;
635
636 if (ix < sim_x && iy < sim_y && iz < sim_z) {{
637 int idx = ix + iy * sim_x + iz * sim_x * sim_y;
638 p_curr[idx] += amplitude;
639 }}
640 break;
641 }}
642
643 case CMD_GET_PROGRESS: {{
644 K2HMessage resp;
645 resp.resp_type = RESP_PROGRESS;
646 resp.cmd_id = cmd.cmd_id;
647 resp.step = ctrl->current_step;
648 resp.steps_remaining = ctrl->steps_remaining;
649 resp.energy = ctrl->total_energy;
650 k2h_send(k2h_header, k2h_slots, &resp);
651 break;
652 }}
653
654 default:
655 break;
656 }}
657 }}
658 }}
659
660 // Grid-wide sync so all blocks see updated control state
661 {grid_sync}
662
663 // Check for termination
664 if (ctrl->should_terminate) {{
665 break;
666 }}
667
668 // --- Phase 2: Check if we have work ---
669 if (ctrl->steps_remaining == 0) {{
670 // No work - brief spinwait then check again
671 // Use volatile counter to prevent optimization
672 volatile int spin_count = 0;
673 for (int i = 0; i < 1000; i++) {{
674 spin_count++;
675 }}
676 {grid_sync}
677 continue;
678 }}
679
680 // --- Phase 3: Load tile from global memory ---
681 // Load interior cells
682 if (gx < sim_x && gy < sim_y && gz < sim_z) {{
683 tile[lz + 1][ly + 1][lx + 1] = p_curr[global_idx];
684 }} else {{
685 tile[lz + 1][ly + 1][lx + 1] = 0.0f;
686 }}
687
688 // Initialize ghost cells to boundary (zero pressure)
689 if (lx == 0) tile[lz + 1][ly + 1][0] = 0.0f;
690 if (lx == {tx} - 1) tile[lz + 1][ly + 1][{tx} + 1] = 0.0f;
691 if (ly == 0) tile[lz + 1][0][lx + 1] = 0.0f;
692 if (ly == {ty} - 1) tile[lz + 1][{ty} + 1][lx + 1] = 0.0f;
693 if (lz == 0) tile[0][ly + 1][lx + 1] = 0.0f;
694 if (lz == {tz} - 1) tile[{tz} + 1][ly + 1][lx + 1] = 0.0f;
695
696 __syncthreads();
697
698 // --- Phase 4: K2K Halo Exchange ---
699 int pingpong = ctrl->current_step & 1;
700
701 // Pack my boundary faces into halo buffers
702 pack_halo_faces(tile, halo_buffers, block_id, pingpong, num_blocks);
703 __threadfence(); // Ensure writes visible to other blocks
704
705 // Grid-wide sync - wait for all blocks to finish packing
706 {grid_sync}
707
708 // Unpack neighbor faces into my ghost cells
709 unpack_halo_faces(tile, halo_buffers, my_route, pingpong, num_blocks);
710 __syncthreads();
711
712 // --- Phase 5: FDTD Computation ---
713 if (gx < sim_x && gy < sim_y && gz < sim_z) {{
714 // 7-point Laplacian from shared memory
715 float center = tile[lz + 1][ly + 1][lx + 1];
716 float lap = tile[lz + 1][ly + 1][lx + 2] // +X
717 + tile[lz + 1][ly + 1][lx] // -X
718 + tile[lz + 1][ly + 2][lx + 1] // +Y
719 + tile[lz + 1][ly][lx + 1] // -Y
720 + tile[lz + 2][ly + 1][lx + 1] // +Z
721 + tile[lz][ly + 1][lx + 1] // -Z
722 - 6.0f * center;
723
724 // FDTD update: p_new = 2*p - p_prev + c^2*dt^2*lap
725 float p_prev_val = p_prev[global_idx];
726 float p_new = 2.0f * center - p_prev_val + c2_dt2 * lap;
727 p_new *= damping;
728
729 // Write to "previous" buffer (will become current after swap)
730 p_prev[global_idx] = p_new;
731 }}
732
733 // --- Phase 5b: Energy Calculation (periodic) ---
734 // Check if this step will report progress (after counter increment)
735 bool is_progress_step = ((ctrl->current_step + 1) % {progress_interval}) == 0;
736
737 if (is_progress_step) {{
738 // Reset energy accumulator (coordinator only, before reduction)
739 if (is_coordinator) {{
740 ctrl->total_energy = 0.0f;
741 }}
742 __threadfence(); // Ensure reset is visible
743
744 // Compute this thread's energy contribution: E = p^2
745 float my_energy = 0.0f;
746 if (gx < sim_x && gy < sim_y && gz < sim_z) {{
747 float p_val = p_prev[global_idx]; // Just written value
748 my_energy = p_val * p_val;
749 }}
750
751 // Block-level parallel reduction
752 float block_energy = block_reduce_energy(my_energy, energy_reduce, {threads_per_block});
753
754 // Block leader accumulates to global energy
755 if (tid == 0) {{
756 atomicAdd(&ctrl->total_energy, block_energy);
757 }}
758 }}
759
760 // Grid-wide sync before buffer swap
761 {grid_sync}
762
763 // --- Phase 6: Buffer Swap & Progress ---
764 if (is_coordinator) {{
765 // Swap buffer pointers (toggle index)
766 ctrl->current_buffer ^= 1;
767
768 // Update counters
769 ctrl->current_step++;
770 ctrl->steps_remaining--;
771
772 // Send progress update periodically
773 if (ctrl->current_step % {progress_interval} == 0) {{
774 K2HMessage resp;
775 resp.resp_type = RESP_PROGRESS;
776 resp.cmd_id = 0;
777 resp.step = ctrl->current_step;
778 resp.steps_remaining = ctrl->steps_remaining;
779 resp.energy = ctrl->total_energy; // Computed in Phase 5b
780 k2h_send(k2h_header, k2h_slots, &resp);
781 }}
782 }}
783
784 // Swap our local pointers
785 float* tmp = p_curr;
786 p_curr = p_prev;
787 p_prev = tmp;
788
789 {grid_sync}
790 }}
791
792 // ========== FINAL ENERGY CALCULATION ==========
793 // Compute final energy before termination (all threads participate)
794 if (is_coordinator) {{
795 ctrl->total_energy = 0.0f;
796 }}
797 __threadfence();
798
799 // Each thread's final energy contribution
800 float final_energy = 0.0f;
801 if (gx < sim_x && gy < sim_y && gz < sim_z) {{
802 float p_val = p_curr[global_idx]; // Current buffer has final state
803 final_energy = p_val * p_val;
804 }}
805
806 // Block-level parallel reduction
807 float block_final_energy = block_reduce_energy(final_energy, energy_reduce, {threads_per_block});
808
809 // Block leader accumulates to global energy
810 if (tid == 0) {{
811 atomicAdd(&ctrl->total_energy, block_final_energy);
812 }}
813
814 {grid_sync}
815
816 // ========== CLEANUP ==========
817 if (is_coordinator) {{
818 ctrl->has_terminated = 1;
819
820 K2HMessage resp;
821 resp.resp_type = RESP_TERMINATED;
822 resp.cmd_id = 0;
823 resp.step = ctrl->current_step;
824 resp.steps_remaining = 0;
825 resp.energy = ctrl->total_energy; // Final computed energy
826 k2h_send(k2h_header, k2h_slots, &resp);
827 }}
828}}
829"#,
830 name = config.name,
831 tx = tx,
832 ty = ty,
833 tz = tz,
834 threads_per_block = threads_per_block,
835 progress_interval = progress_interval,
836 grid_sync_init = if config.use_cooperative {
837 "cg::grid_group grid = cg::this_grid();"
838 } else {
839 ""
840 },
841 grid_sync = grid_sync,
842 )
843}
844
/// Generates the kernel source for `config` and compiles it to PTX by invoking
/// `nvcc` (which must be on `PATH`).
///
/// The source and PTX are written to uniquely named files in the system temp
/// directory (process id + per-process counter), so concurrent calls, whether
/// from several threads or several processes, do not overwrite each other's
/// files. Both files are removed again before returning, on success and on
/// failure.
///
/// # Errors
///
/// Returns a descriptive `String` when the source file cannot be written,
/// `nvcc` cannot be started, `nvcc` exits unsuccessfully (its stderr is
/// included), or the PTX output cannot be read.
#[cfg(feature = "nvcc")]
pub fn compile_persistent_fdtd_to_ptx(config: &PersistentFdtdConfig) -> Result<String, String> {
    use std::process::Command;
    use std::sync::atomic::{AtomicU64, Ordering};

    // Distinguishes concurrent invocations within one process.
    static NEXT_BUILD_ID: AtomicU64 = AtomicU64::new(0);

    let cuda_code = generate_persistent_fdtd_kernel(config);

    let build_id = NEXT_BUILD_ID.fetch_add(1, Ordering::Relaxed);
    let stem = format!("persistent_fdtd_{}_{}", std::process::id(), build_id);

    let temp_dir = std::env::temp_dir();
    let cuda_file = temp_dir.join(format!("{}.cu", stem));
    let ptx_file = temp_dir.join(format!("{}.ptx", stem));

    let result = (|| -> Result<String, String> {
        std::fs::write(&cuda_file, &cuda_code)
            .map_err(|e| format!("Failed to write CUDA file: {}", e))?;

        // Paths are passed as OS strings so non-UTF-8 temp directories are
        // not mangled by a lossy conversion.
        let mut command = Command::new("nvcc");
        command
            .arg("-ptx")
            .arg("-o")
            .arg(&ptx_file)
            .arg(&cuda_file)
            .arg("-arch=native")
            .arg("-std=c++17");

        if config.use_cooperative {
            command.arg("-rdc=true");
        }

        let output = command
            .output()
            .map_err(|e| format!("Failed to run nvcc: {}", e))?;

        if !output.status.success() {
            return Err(format!(
                "nvcc compilation failed:\n{}",
                String::from_utf8_lossy(&output.stderr)
            ));
        }

        std::fs::read_to_string(&ptx_file).map_err(|e| format!("Failed to read PTX: {}", e))
    })();

    // Best-effort cleanup; a missing file (e.g. nvcc never produced the PTX)
    // is not an error worth reporting.
    let _ = std::fs::remove_file(&cuda_file);
    let _ = std::fs::remove_file(&ptx_file);

    result
}
893
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every fragment occurs somewhere in the generated code.
    fn assert_has_all(code: &str, fragments: &[&str]) {
        for fragment in fragments {
            assert!(
                code.contains(fragment),
                "generated code is missing `{}`",
                fragment
            );
        }
    }

    /// The default configuration uses the stock name and an 8x8x8 tile.
    #[test]
    fn test_default_config() {
        let defaults = PersistentFdtdConfig::default();

        assert_eq!(defaults.name, "persistent_fdtd3d");
        assert_eq!(defaults.tile_size, (8, 8, 8));
        assert_eq!(defaults.threads_per_block(), 512);
    }

    /// Builder methods override exactly the fields they name.
    #[test]
    fn test_config_builder() {
        let built = PersistentFdtdConfig::new("my_kernel")
            .with_tile_size(4, 4, 4)
            .with_cooperative(false)
            .with_progress_interval(50);

        assert_eq!(built.name, "my_kernel");
        assert_eq!(built.tile_size, (4, 4, 4));
        assert_eq!(built.threads_per_block(), 64);
        assert!(!built.use_cooperative);
        assert_eq!(built.progress_interval, 50);
    }

    /// A default tile with halo is 10 * 10 * 10 floats = 4000 bytes.
    #[test]
    fn test_shared_mem_calculation() {
        let defaults = PersistentFdtdConfig::default();
        assert_eq!(defaults.shared_mem_size(), 4000);
    }

    /// Cooperative mode emits the cooperative-groups plumbing plus all
    /// structures, device helpers and the kernel entry point.
    #[test]
    fn test_generate_kernel_cooperative() {
        let config = PersistentFdtdConfig::new("test_kernel").with_cooperative(true);
        let code = generate_persistent_fdtd_kernel(&config);

        assert_has_all(
            &code,
            &[
                "#include <cooperative_groups.h>",
                "namespace cg = cooperative_groups;",
                "typedef struct __align__(256)",
                "PersistentControlBlock",
                "SpscQueueHeader",
                "H2KMessage",
                "K2HMessage",
                "__device__ bool h2k_try_recv",
                "__device__ bool k2h_send",
                "__device__ void pack_halo_faces",
                "__device__ void unpack_halo_faces",
                "extern \"C\" __global__ void",
                "test_kernel",
                "cg::grid_group grid = cg::this_grid()",
                "grid.sync()",
                "7-point Laplacian",
                "c2_dt2 * lap",
            ],
        );
    }

    /// Software-sync mode omits the cooperative-groups include and relies on
    /// the atomic barrier.
    #[test]
    fn test_generate_kernel_software_sync() {
        let config = PersistentFdtdConfig::new("test_kernel").with_cooperative(false);
        let code = generate_persistent_fdtd_kernel(&config);

        assert!(!code.contains("#include <cooperative_groups.h>"));
        assert_has_all(&code, &["software_grid_sync", "atomicAdd"]);
    }

    /// The command and response identifiers appear in the generated code.
    #[test]
    fn test_generate_kernel_command_handling() {
        let code = generate_persistent_fdtd_kernel(&PersistentFdtdConfig::default());

        assert_has_all(
            &code,
            &[
                "CMD_RUN_STEPS",
                "CMD_TERMINATE",
                "CMD_INJECT_IMPULSE",
                "CMD_GET_PROGRESS",
                "RESP_PROGRESS",
                "RESP_TERMINATED",
            ],
        );
    }

    /// Tile macros reflect the configured tile size and the halo helpers are
    /// present.
    #[test]
    fn test_generate_kernel_halo_exchange() {
        let config = PersistentFdtdConfig::new("test").with_tile_size(8, 8, 8);
        let code = generate_persistent_fdtd_kernel(&config);

        assert_has_all(
            &code,
            &[
                "#define TILE_X 8",
                "#define TILE_Y 8",
                "#define TILE_Z 8",
                "#define FACE_SIZE 64",
                "FACE_POS_X",
                "FACE_NEG_Z",
                "pack_halo_faces",
                "unpack_halo_faces",
            ],
        );
    }

    /// The kernel contains the persistent loop, its termination check and the
    /// idle spin-wait.
    #[test]
    fn test_kernel_contains_persistent_loop() {
        let code = generate_persistent_fdtd_kernel(&PersistentFdtdConfig::default());

        assert_has_all(
            &code,
            &[
                "while (true)",
                "if (ctrl->should_terminate)",
                "break;",
                "if (ctrl->steps_remaining == 0)",
                "volatile int spin_count",
            ],
        );
    }

    /// The energy reduction, its periodic trigger and the final energy pass
    /// are all emitted.
    #[test]
    fn test_kernel_contains_energy_calculation() {
        let config = PersistentFdtdConfig::new("test_energy")
            .with_tile_size(8, 8, 8)
            .with_progress_interval(100);
        let code = generate_persistent_fdtd_kernel(&config);

        assert_has_all(
            &code,
            &[
                "block_reduce_energy",
                "E = sum(p^2)",
                "Parallel reduction tree",
                "energy_reduce[512]",
                "is_progress_step",
                "% 100",
                "ctrl->total_energy = 0.0f",
                "p_val * p_val",
                "atomicAdd(&ctrl->total_energy",
                "resp.energy = ctrl->total_energy",
                "FINAL ENERGY CALCULATION",
                "block_final_energy",
            ],
        );
    }
}