/// Configuration for the generated persistent FDTD kernel.
#[derive(Debug, Clone)]
pub struct PersistentFdtdConfig {
    /// Kernel entry-point name used for the generated `extern "C"` function.
    pub name: String,
    /// Tile (thread block) dimensions in cells: (x, y, z).
    pub tile_size: (usize, usize, usize),
    /// Use cooperative groups `grid.sync()`; otherwise emit a software grid barrier.
    pub use_cooperative: bool,
    /// Emit a progress message to the host every this many simulation steps.
    pub progress_interval: u64,
    /// Whether to track total field energy.
    pub track_energy: bool,
}

impl Default for PersistentFdtdConfig {
    fn default() -> Self {
        Self {
            name: "persistent_fdtd3d".to_string(),
            tile_size: (8, 8, 8),
            use_cooperative: true,
            progress_interval: 100,
            track_energy: true,
        }
    }
}

impl PersistentFdtdConfig {
    pub fn new(name: &str) -> Self {
        Self {
            name: name.to_string(),
            ..Default::default()
        }
    }

    pub fn with_tile_size(mut self, tx: usize, ty: usize, tz: usize) -> Self {
        self.tile_size = (tx, ty, tz);
        self
    }

    pub fn with_cooperative(mut self, use_coop: bool) -> Self {
        self.use_cooperative = use_coop;
        self
    }

    pub fn with_progress_interval(mut self, interval: u64) -> Self {
        self.progress_interval = interval;
        self
    }

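    /// Total threads per block implied by the tile size (tx * ty * tz).
    /// For the default 8 x 8 x 8 tile this is 512 threads.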
    pub fn threads_per_block(&self) -> usize {
        self.tile_size.0 * self.tile_size.1 * self.tile_size.2
    }

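    /// Shared-memory bytes needed per block: one f32 per tile cell plus a
    /// one-cell halo on every side, i.e. (tx + 2) * (ty + 2) * (tz + 2) * 4.
    /// For the default 8 x 8 x 8 tile: 10 * 10 * 10 * 4 = 4000 bytes.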
    pub fn shared_mem_size(&self) -> usize {
        let (tx, ty, tz) = self.tile_size;
        let with_halo = (tx + 2) * (ty + 2) * (tz + 2);
        with_halo * std::mem::size_of::<f32>()
    }
}

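/// Generates the complete CUDA source for the persistent FDTD actor kernel:
/// header, struct definitions, device helpers, and the main persistent loop.
///
/// A minimal usage sketch (the output path below is illustrative only):
///
/// ```ignore
/// let config = PersistentFdtdConfig::new("fdtd_demo").with_tile_size(8, 8, 8);
/// let cuda_src = generate_persistent_fdtd_kernel(&config);
/// std::fs::write("/tmp/fdtd_demo.cu", &cuda_src).expect("write kernel source");
/// ```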
pub fn generate_persistent_fdtd_kernel(config: &PersistentFdtdConfig) -> String {
    let mut code = String::new();

    code.push_str(&generate_header(config));

    code.push_str(&generate_structures());

    code.push_str(&generate_device_functions(config));

    code.push_str(&generate_main_kernel(config));

    code
}

fn generate_header(config: &PersistentFdtdConfig) -> String {
    let mut code = String::new();

    code.push_str("// Generated Persistent FDTD Kernel\n");
    code.push_str("// RingKernel GPU Actor System\n\n");

    if config.use_cooperative {
        code.push_str("#include <cooperative_groups.h>\n");
        code.push_str("namespace cg = cooperative_groups;\n\n");
    }

    code.push_str("#include <cuda_runtime.h>\n");
    code.push_str("#include <stdint.h>\n\n");

    code
}

fn generate_structures() -> String {
    r#"
// ============================================================================
// STRUCTURE DEFINITIONS (must match Rust persistent.rs)
// ============================================================================

// Control block for persistent kernel (256 bytes, cache-aligned)
typedef struct __align__(256) {
    // Lifecycle control (host -> GPU)
    uint32_t should_terminate;
    uint32_t _pad0;
    uint64_t steps_remaining;

    // State (GPU -> host)
    uint64_t current_step;
    uint32_t current_buffer;
    uint32_t has_terminated;
    float total_energy;
    uint32_t _pad1;

    // Configuration
    uint32_t grid_dim[3];
    uint32_t block_dim[3];
    uint32_t sim_size[3];
    uint32_t tile_size[3];

    // Acoustic parameters
    float c2_dt2;
    float damping;
    float cell_size;
    float dt;

    // Synchronization
    uint32_t barrier_counter;
    uint32_t barrier_generation;

    // Statistics
    uint64_t messages_processed;
    uint64_t k2k_messages_sent;
    uint64_t k2k_messages_received;

    uint64_t _reserved[16];
} PersistentControlBlock;

// SPSC queue header (128 bytes, cache-aligned)
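// head is advanced by the producer, tail by the consumer. Capacity is assumed
// to be a power of two, with mask = capacity - 1, so slots index as (pos & mask).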
typedef struct __align__(128) {
    uint64_t head;
    uint64_t tail;
    uint32_t capacity;
    uint32_t mask;
    uint64_t _padding[12];
} SpscQueueHeader;

// H2K message (64 bytes)
typedef struct __align__(64) {
    uint32_t cmd;
    uint32_t flags;
    uint64_t cmd_id;
    uint64_t param1;
    uint32_t param2;
    uint32_t param3;
    float param4;
    float param5;
    uint64_t _reserved[4];
} H2KMessage;

// K2H message (64 bytes)
typedef struct __align__(64) {
    uint32_t resp_type;
    uint32_t flags;
    uint64_t cmd_id;
    uint64_t step;
    uint64_t steps_remaining;
    float energy;
    uint32_t error_code;
    uint64_t _reserved[3];
} K2HMessage;

// Command types
#define CMD_NOP 0
#define CMD_RUN_STEPS 1
#define CMD_PAUSE 2
#define CMD_RESUME 3
#define CMD_TERMINATE 4
#define CMD_INJECT_IMPULSE 5
#define CMD_SET_SOURCE 6
#define CMD_GET_PROGRESS 7
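
// Parameter conventions used by the main kernel:
//   CMD_RUN_STEPS:      param1 = number of steps to add to steps_remaining
//   CMD_INJECT_IMPULSE: param1 = (x << 32) | y, param2 = z, param4 = amplitude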

// Response types
#define RESP_ACK 0
#define RESP_PROGRESS 1
#define RESP_ERROR 2
#define RESP_TERMINATED 3
#define RESP_ENERGY 4

// Neighbor block IDs
typedef struct {
    int32_t pos_x, neg_x;
    int32_t pos_y, neg_y;
    int32_t pos_z, neg_z;
    int32_t _padding[2];
} BlockNeighbors;

// K2K route entry
typedef struct {
    BlockNeighbors neighbors;
    uint32_t block_pos[3];
    uint32_t _padding;
    uint32_t cell_offset[3];
    uint32_t _padding2;
} K2KRouteEntry;

// Face indices
#define FACE_POS_X 0
#define FACE_NEG_X 1
#define FACE_POS_Y 2
#define FACE_NEG_Y 3
#define FACE_POS_Z 4
#define FACE_NEG_Z 5

"#
    .to_string()
}

fn generate_device_functions(config: &PersistentFdtdConfig) -> String {
    let (tx, ty, tz) = config.tile_size;
    let face_size = tx * ty;

    let mut code = String::new();

    code.push_str(
        r#"
// ============================================================================
// SYNCHRONIZATION FUNCTIONS
// ============================================================================

// Software grid barrier (atomic counter + generation)
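// NOTE: this barrier only makes progress if every block of the launch is
// resident on the device simultaneously (persistent-kernel occupancy); it is
// the fallback used when cooperative launch is disabled.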
__device__ void software_grid_sync(
    volatile uint32_t* barrier_counter,
    volatile uint32_t* barrier_gen,
    int num_blocks
) {
    __syncthreads(); // First sync within block

    if (threadIdx.x == 0) {
        unsigned int gen = *barrier_gen;

        unsigned int arrived = atomicAdd((unsigned int*)barrier_counter, 1) + 1;
        if (arrived == num_blocks) {
            *barrier_counter = 0;
            __threadfence();
            atomicAdd((unsigned int*)barrier_gen, 1);
        } else {
            while (atomicAdd((unsigned int*)barrier_gen, 0) == gen) {
                __threadfence();
            }
        }
    }

    __syncthreads();
}

"#,
    );

    code.push_str(
        r#"
// ============================================================================
// MESSAGE QUEUE OPERATIONS
// ============================================================================

// Copy H2K message from volatile source (field-by-field to handle volatile)
__device__ void copy_h2k_message(H2KMessage* dst, volatile H2KMessage* src) {
    dst->cmd = src->cmd;
    dst->flags = src->flags;
    dst->cmd_id = src->cmd_id;
    dst->param1 = src->param1;
    dst->param2 = src->param2;
    dst->param3 = src->param3;
    dst->param4 = src->param4;
    dst->param5 = src->param5;
    // Skip reserved fields for performance
}

// Copy K2H message to volatile destination (field-by-field to handle volatile)
__device__ void copy_k2h_message(volatile K2HMessage* dst, const K2HMessage* src) {
    dst->resp_type = src->resp_type;
    dst->flags = src->flags;
    dst->cmd_id = src->cmd_id;
    dst->step = src->step;
    dst->steps_remaining = src->steps_remaining;
    dst->energy = src->energy;
    dst->error_code = src->error_code;
    // Skip reserved fields for performance
}

// Try to receive H2K message (returns true if message available)
__device__ bool h2k_try_recv(
    volatile SpscQueueHeader* header,
    volatile H2KMessage* slots,
    H2KMessage* out_msg
) {
    // Fence BEFORE reading to ensure we see host writes
    __threadfence_system();

    uint64_t head = header->head;
    uint64_t tail = header->tail;

    if (head == tail) {
        return false; // Empty
    }

    uint32_t slot = tail & header->mask;
    copy_h2k_message(out_msg, &slots[slot]);

    __threadfence();
    header->tail = tail + 1;

    return true;
}

// Send K2H message
__device__ bool k2h_send(
    volatile SpscQueueHeader* header,
    volatile K2HMessage* slots,
    const K2HMessage* msg
) {
    uint64_t head = header->head;
    uint64_t tail = header->tail;
    uint32_t capacity = header->capacity;

    if (head - tail >= capacity) {
        return false; // Full
    }

    uint32_t slot = head & header->mask;
    copy_k2h_message(&slots[slot], msg);

    __threadfence_system(); // Ensure host sees our writes
    header->head = head + 1;

    return true;
}

"#,
    );

    code.push_str(&format!(
        r#"
// ============================================================================
// K2K HALO EXCHANGE
// ============================================================================

#define TILE_X {tx}
#define TILE_Y {ty}
#define TILE_Z {tz}
#define FACE_SIZE {face_size}

// Pack halo faces from shared memory to device halo buffer
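// Halo buffer layout (floats): [block_id][pingpong][face][FACE_SIZE], so each
// block owns 6 * FACE_SIZE * 2 entries (block_stride below).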
__device__ void pack_halo_faces(
    float tile[TILE_Z + 2][TILE_Y + 2][TILE_X + 2],
    float* halo_buffers,
    int block_id,
    int pingpong,
    int num_blocks
) {{
    int tid = threadIdx.x;
    int face_stride = FACE_SIZE;
    int block_stride = 6 * face_stride * 2; // 6 faces, 2 ping-pong
    int pp_offset = pingpong * 6 * face_stride;

    float* block_halo = halo_buffers + block_id * block_stride + pp_offset;

    if (tid < FACE_SIZE) {{
        int fx = tid % TILE_X;
        int fy = tid / TILE_X;

        // +X face (x = TILE_X in local coords, which is index TILE_X)
        block_halo[FACE_POS_X * face_stride + tid] = tile[fy + 1][fx + 1][TILE_X];
        // -X face (x = 1 in local coords)
        block_halo[FACE_NEG_X * face_stride + tid] = tile[fy + 1][fx + 1][1];
        // +Y face
        block_halo[FACE_POS_Y * face_stride + tid] = tile[fy + 1][TILE_Y][fx + 1];
        // -Y face
        block_halo[FACE_NEG_Y * face_stride + tid] = tile[fy + 1][1][fx + 1];
        // +Z face
        block_halo[FACE_POS_Z * face_stride + tid] = tile[TILE_Z][fy + 1][fx + 1];
        // -Z face
        block_halo[FACE_NEG_Z * face_stride + tid] = tile[1][fy + 1][fx + 1];
    }}
}}

// Unpack halo faces from device buffer to shared memory ghost cells
__device__ void unpack_halo_faces(
    float tile[TILE_Z + 2][TILE_Y + 2][TILE_X + 2],
    const float* halo_buffers,
    const K2KRouteEntry* route,
    int pingpong,
    int num_blocks
) {{
    int tid = threadIdx.x;
    int face_stride = FACE_SIZE;
    int block_stride = 6 * face_stride * 2;
    int pp_offset = pingpong * 6 * face_stride;

    if (tid < FACE_SIZE) {{
        int fx = tid % TILE_X;
        int fy = tid / TILE_X;

        // My +X ghost comes from neighbor's -X face
        if (route->neighbors.pos_x >= 0) {{
            const float* n_halo = halo_buffers + route->neighbors.pos_x * block_stride + pp_offset;
            tile[fy + 1][fx + 1][TILE_X + 1] = n_halo[FACE_NEG_X * face_stride + tid];
        }}

        // My -X ghost comes from neighbor's +X face
        if (route->neighbors.neg_x >= 0) {{
            const float* n_halo = halo_buffers + route->neighbors.neg_x * block_stride + pp_offset;
            tile[fy + 1][fx + 1][0] = n_halo[FACE_POS_X * face_stride + tid];
        }}

        // My +Y ghost
        if (route->neighbors.pos_y >= 0) {{
            const float* n_halo = halo_buffers + route->neighbors.pos_y * block_stride + pp_offset;
            tile[fy + 1][TILE_Y + 1][fx + 1] = n_halo[FACE_NEG_Y * face_stride + tid];
        }}

        // My -Y ghost
        if (route->neighbors.neg_y >= 0) {{
            const float* n_halo = halo_buffers + route->neighbors.neg_y * block_stride + pp_offset;
            tile[fy + 1][0][fx + 1] = n_halo[FACE_POS_Y * face_stride + tid];
        }}

        // My +Z ghost
        if (route->neighbors.pos_z >= 0) {{
            const float* n_halo = halo_buffers + route->neighbors.pos_z * block_stride + pp_offset;
            tile[TILE_Z + 1][fy + 1][fx + 1] = n_halo[FACE_NEG_Z * face_stride + tid];
        }}

        // My -Z ghost
        if (route->neighbors.neg_z >= 0) {{
            const float* n_halo = halo_buffers + route->neighbors.neg_z * block_stride + pp_offset;
            tile[0][fy + 1][fx + 1] = n_halo[FACE_POS_Z * face_stride + tid];
        }}
    }}
}}

"#,
        tx = tx,
        ty = ty,
        tz = tz,
        face_size = face_size
    ));

    code
}

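/// Emits the main persistent kernel: an infinite loop that drains host
/// commands, exchanges halos between blocks, applies the FDTD update, and
/// swaps buffers, synchronizing grid-wide between phases.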
fn generate_main_kernel(config: &PersistentFdtdConfig) -> String {
    let (tx, ty, tz) = config.tile_size;
    let threads_per_block = tx * ty * tz;
    let progress_interval = config.progress_interval;

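    // grid.sync() is only valid when the kernel is launched cooperatively
    // (cudaLaunchCooperativeKernel); the software barrier is the fallback for
    // plain launches.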
    let grid_sync = if config.use_cooperative {
        "grid.sync();"
    } else {
        "software_grid_sync(&ctrl->barrier_counter, &ctrl->barrier_generation, num_blocks);"
    };

    format!(
        r#"
// ============================================================================
// MAIN PERSISTENT KERNEL
// ============================================================================

extern "C" __global__ void __launch_bounds__({threads_per_block}, 2)
{name}(
    PersistentControlBlock* __restrict__ ctrl,
    float* __restrict__ pressure_a,
    float* __restrict__ pressure_b,
    SpscQueueHeader* __restrict__ h2k_header,
    H2KMessage* __restrict__ h2k_slots,
    SpscQueueHeader* __restrict__ k2h_header,
    K2HMessage* __restrict__ k2h_slots,
    const K2KRouteEntry* __restrict__ routes,
    float* __restrict__ halo_buffers
) {{
    // Block and thread indices
    int block_id = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y;
    int num_blocks = gridDim.x * gridDim.y * gridDim.z;
    int tid = threadIdx.x;
    bool is_coordinator = (block_id == 0 && tid == 0);

    // Shared memory for tile + ghost cells
    __shared__ float tile[{tz} + 2][{ty} + 2][{tx} + 2];

    // Get my routing info
    const K2KRouteEntry* my_route = &routes[block_id];

    // Pointers to pressure buffers (will swap)
    float* p_curr = pressure_a;
    float* p_prev = pressure_b;

    // Grid dimensions for indexing
    int sim_x = ctrl->sim_size[0];
    int sim_y = ctrl->sim_size[1];
    int sim_z = ctrl->sim_size[2];

    // Local cell coordinates within tile
    int lx = tid % {tx};
    int ly = (tid / {tx}) % {ty};
    int lz = tid / ({tx} * {ty});

    // Global cell coordinates
    int gx = my_route->cell_offset[0] + lx;
    int gy = my_route->cell_offset[1] + ly;
    int gz = my_route->cell_offset[2] + lz;
    int global_idx = gx + gy * sim_x + gz * sim_x * sim_y;

    // Physics parameters
    float c2_dt2 = ctrl->c2_dt2;
    float damping = ctrl->damping;

    {grid_sync_init}

    // ========== PERSISTENT LOOP ==========
    while (true) {{
        // --- Phase 1: Command Processing (coordinator only) ---
        if (is_coordinator) {{
            H2KMessage cmd;
            while (h2k_try_recv(h2k_header, h2k_slots, &cmd)) {{
                ctrl->messages_processed++;

                switch (cmd.cmd) {{
                    case CMD_RUN_STEPS:
                        ctrl->steps_remaining += cmd.param1;
                        break;

                    case CMD_TERMINATE:
                        ctrl->should_terminate = 1;
                        break;

                    case CMD_INJECT_IMPULSE: {{
                        // Extract position from params
                        uint32_t ix = (cmd.param1 >> 32) & 0xFFFFFFFF;
                        uint32_t iy = cmd.param1 & 0xFFFFFFFF;
                        uint32_t iz = cmd.param2;
                        float amplitude = cmd.param4;

                        if (ix < sim_x && iy < sim_y && iz < sim_z) {{
                            int idx = ix + iy * sim_x + iz * sim_x * sim_y;
                            p_curr[idx] += amplitude;
                        }}
                        break;
                    }}

                    case CMD_GET_PROGRESS: {{
                        K2HMessage resp;
                        resp.flags = 0;
                        resp.error_code = 0;
                        resp.resp_type = RESP_PROGRESS;
                        resp.cmd_id = cmd.cmd_id;
                        resp.step = ctrl->current_step;
                        resp.steps_remaining = ctrl->steps_remaining;
                        resp.energy = ctrl->total_energy;
                        k2h_send(k2h_header, k2h_slots, &resp);
                        break;
                    }}

                    default:
                        break;
                }}
            }}
        }}

        // Grid-wide sync so all blocks see updated control state
        {grid_sync}

        // Check for termination
        if (ctrl->should_terminate) {{
            break;
        }}

        // --- Phase 2: Check if we have work ---
        if (ctrl->steps_remaining == 0) {{
            // No work - brief spinwait then check again
            // Use volatile counter to prevent optimization
            volatile int spin_count = 0;
            for (int i = 0; i < 1000; i++) {{
                spin_count++;
            }}
            {grid_sync}
            continue;
        }}

        // --- Phase 3: Load tile from global memory ---
        // Load interior cells
        if (gx < sim_x && gy < sim_y && gz < sim_z) {{
            tile[lz + 1][ly + 1][lx + 1] = p_curr[global_idx];
        }} else {{
            tile[lz + 1][ly + 1][lx + 1] = 0.0f;
        }}

        // Initialize ghost cells to boundary (zero pressure)
        if (lx == 0) tile[lz + 1][ly + 1][0] = 0.0f;
        if (lx == {tx} - 1) tile[lz + 1][ly + 1][{tx} + 1] = 0.0f;
        if (ly == 0) tile[lz + 1][0][lx + 1] = 0.0f;
        if (ly == {ty} - 1) tile[lz + 1][{ty} + 1][lx + 1] = 0.0f;
        if (lz == 0) tile[0][ly + 1][lx + 1] = 0.0f;
        if (lz == {tz} - 1) tile[{tz} + 1][ly + 1][lx + 1] = 0.0f;

        __syncthreads();

        // --- Phase 4: K2K Halo Exchange ---
        int pingpong = ctrl->current_step & 1;

        // Pack my boundary faces into halo buffers
        pack_halo_faces(tile, halo_buffers, block_id, pingpong, num_blocks);
        __threadfence(); // Ensure writes visible to other blocks

        // Grid-wide sync - wait for all blocks to finish packing
        {grid_sync}

        // Unpack neighbor faces into my ghost cells
        unpack_halo_faces(tile, halo_buffers, my_route, pingpong, num_blocks);
        __syncthreads();

        // --- Phase 5: FDTD Computation ---
        if (gx < sim_x && gy < sim_y && gz < sim_z) {{
            // 7-point Laplacian from shared memory
            float center = tile[lz + 1][ly + 1][lx + 1];
            float lap = tile[lz + 1][ly + 1][lx + 2]   // +X
                      + tile[lz + 1][ly + 1][lx]       // -X
                      + tile[lz + 1][ly + 2][lx + 1]   // +Y
                      + tile[lz + 1][ly][lx + 1]       // -Y
                      + tile[lz + 2][ly + 1][lx + 1]   // +Z
                      + tile[lz][ly + 1][lx + 1]       // -Z
                      - 6.0f * center;

            // FDTD update: p_new = 2*p - p_prev + c^2*dt^2*lap
            float p_prev_val = p_prev[global_idx];
            float p_new = 2.0f * center - p_prev_val + c2_dt2 * lap;
            p_new *= damping;

            // Write to "previous" buffer (will become current after swap)
            p_prev[global_idx] = p_new;
        }}

        // Grid-wide sync before buffer swap
        {grid_sync}

        // --- Phase 6: Buffer Swap & Progress ---
        if (is_coordinator) {{
            // Swap buffer pointers (toggle index)
            ctrl->current_buffer ^= 1;

            // Update counters
            ctrl->current_step++;
            ctrl->steps_remaining--;

            // Send progress update periodically
            if (ctrl->current_step % {progress_interval} == 0) {{
                K2HMessage resp;
                resp.flags = 0;
                resp.error_code = 0;
                resp.resp_type = RESP_PROGRESS;
                resp.cmd_id = 0;
                resp.step = ctrl->current_step;
                resp.steps_remaining = ctrl->steps_remaining;
                resp.energy = 0.0f; // TODO: calculate energy
                k2h_send(k2h_header, k2h_slots, &resp);
            }}
        }}

        // Swap our local pointers
        float* tmp = p_curr;
        p_curr = p_prev;
        p_prev = tmp;

        {grid_sync}
    }}

    // ========== CLEANUP ==========
    if (is_coordinator) {{
        ctrl->has_terminated = 1;

        K2HMessage resp;
        resp.flags = 0;
        resp.error_code = 0;
        resp.resp_type = RESP_TERMINATED;
        resp.cmd_id = 0;
        resp.step = ctrl->current_step;
        resp.steps_remaining = 0;
        resp.energy = 0.0f;
        k2h_send(k2h_header, k2h_slots, &resp);
    }}
}}
"#,
        name = config.name,
        tx = tx,
        ty = ty,
        tz = tz,
        threads_per_block = threads_per_block,
        progress_interval = progress_interval,
        grid_sync_init = if config.use_cooperative {
            "cg::grid_group grid = cg::this_grid();"
        } else {
            ""
        },
        grid_sync = grid_sync,
    )
}

#[cfg(feature = "nvcc")]
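/// Compiles the generated kernel to PTX by invoking `nvcc` (requires the CUDA
/// toolkit on PATH and the `nvcc` crate feature).
///
/// A usage sketch, assuming a caller that can handle the `Result`:
///
/// ```ignore
/// let config = PersistentFdtdConfig::default();
/// let ptx = compile_persistent_fdtd_to_ptx(&config)?;
/// println!("PTX size: {} bytes", ptx.len());
/// ```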
pub fn compile_persistent_fdtd_to_ptx(config: &PersistentFdtdConfig) -> Result<String, String> {
    let cuda_code = generate_persistent_fdtd_kernel(config);

    use std::process::Command;

    let temp_dir = std::env::temp_dir();
    let cuda_file = temp_dir.join("persistent_fdtd.cu");
    let ptx_file = temp_dir.join("persistent_fdtd.ptx");

    std::fs::write(&cuda_file, &cuda_code)
        .map_err(|e| format!("Failed to write CUDA file: {}", e))?;

    let mut args = vec![
        "-ptx".to_string(),
        "-o".to_string(),
        ptx_file.to_string_lossy().to_string(),
        cuda_file.to_string_lossy().to_string(),
        "-arch=native".to_string(),
        "-std=c++17".to_string(),
    ];

    if config.use_cooperative {
        args.push("-rdc=true".to_string());
    }

    let output = Command::new("nvcc")
        .args(&args)
        .output()
        .map_err(|e| format!("Failed to run nvcc: {}", e))?;

    if !output.status.success() {
        return Err(format!(
            "nvcc compilation failed:\n{}",
            String::from_utf8_lossy(&output.stderr)
        ));
    }

    std::fs::read_to_string(&ptx_file).map_err(|e| format!("Failed to read PTX: {}", e))
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = PersistentFdtdConfig::default();
        assert_eq!(config.name, "persistent_fdtd3d");
        assert_eq!(config.tile_size, (8, 8, 8));
        assert_eq!(config.threads_per_block(), 512);
    }

    #[test]
    fn test_config_builder() {
        let config = PersistentFdtdConfig::new("my_kernel")
            .with_tile_size(4, 4, 4)
            .with_cooperative(false)
            .with_progress_interval(50);

        assert_eq!(config.name, "my_kernel");
        assert_eq!(config.tile_size, (4, 4, 4));
        assert_eq!(config.threads_per_block(), 64);
        assert!(!config.use_cooperative);
        assert_eq!(config.progress_interval, 50);
    }

    #[test]
    fn test_shared_mem_calculation() {
        let config = PersistentFdtdConfig::default();
        assert_eq!(config.shared_mem_size(), 4000);
    }

    #[test]
    fn test_generate_kernel_cooperative() {
        let config = PersistentFdtdConfig::new("test_kernel").with_cooperative(true);

        let code = generate_persistent_fdtd_kernel(&config);

        assert!(code.contains("#include <cooperative_groups.h>"));
        assert!(code.contains("namespace cg = cooperative_groups;"));

        assert!(code.contains("typedef struct __align__(256)"));
        assert!(code.contains("PersistentControlBlock"));
        assert!(code.contains("SpscQueueHeader"));
        assert!(code.contains("H2KMessage"));
        assert!(code.contains("K2HMessage"));

        assert!(code.contains("__device__ bool h2k_try_recv"));
        assert!(code.contains("__device__ bool k2h_send"));
        assert!(code.contains("__device__ void pack_halo_faces"));
        assert!(code.contains("__device__ void unpack_halo_faces"));

        assert!(code.contains("extern \"C\" __global__ void"));
        assert!(code.contains("test_kernel"));
        assert!(code.contains("cg::grid_group grid = cg::this_grid()"));
        assert!(code.contains("grid.sync()"));

        assert!(code.contains("7-point Laplacian"));
        assert!(code.contains("c2_dt2 * lap"));
    }

    #[test]
    fn test_generate_kernel_software_sync() {
        let config = PersistentFdtdConfig::new("test_kernel").with_cooperative(false);

        let code = generate_persistent_fdtd_kernel(&config);

        assert!(!code.contains("#include <cooperative_groups.h>"));

        assert!(code.contains("software_grid_sync"));
        assert!(code.contains("atomicAdd"));
    }

    #[test]
    fn test_generate_kernel_command_handling() {
        let config = PersistentFdtdConfig::default();
        let code = generate_persistent_fdtd_kernel(&config);

        assert!(code.contains("CMD_RUN_STEPS"));
        assert!(code.contains("CMD_TERMINATE"));
        assert!(code.contains("CMD_INJECT_IMPULSE"));
        assert!(code.contains("CMD_GET_PROGRESS"));

        assert!(code.contains("RESP_PROGRESS"));
        assert!(code.contains("RESP_TERMINATED"));
    }

    #[test]
    fn test_generate_kernel_halo_exchange() {
        let config = PersistentFdtdConfig::new("test").with_tile_size(8, 8, 8);

        let code = generate_persistent_fdtd_kernel(&config);

        assert!(code.contains("#define TILE_X 8"));
        assert!(code.contains("#define TILE_Y 8"));
        assert!(code.contains("#define TILE_Z 8"));
        assert!(code.contains("#define FACE_SIZE 64"));

        assert!(code.contains("FACE_POS_X"));
        assert!(code.contains("FACE_NEG_Z"));

        assert!(code.contains("pack_halo_faces"));
        assert!(code.contains("unpack_halo_faces"));
    }

    #[test]
    fn test_kernel_contains_persistent_loop() {
        let config = PersistentFdtdConfig::default();
        let code = generate_persistent_fdtd_kernel(&config);

        assert!(code.contains("while (true)"));
        assert!(code.contains("if (ctrl->should_terminate)"));
        assert!(code.contains("break;"));

        assert!(code.contains("if (ctrl->steps_remaining == 0)"));
        assert!(code.contains("volatile int spin_count"));
    }
}
938}