ringkernel_cuda_codegen/persistent_fdtd.rs

//! Persistent FDTD kernel code generation.
//!
//! Generates CUDA code for truly persistent GPU actors that run for the
//! entire simulation lifetime with:
//!
//! - Single kernel launch
//! - H2K/K2H command/response messaging
//! - K2K halo exchange between blocks
//! - Grid-wide synchronization via cooperative groups
//!
//! # Architecture
//!
//! The generated kernel is structured as:
//!
//! 1. **Coordinator (Block 0)**: Processes H2K commands, sends K2H responses
//! 2. **Worker Blocks (1..N)**: Compute FDTD steps, exchange halos via K2K
//! 3. **Persistent Loop**: Runs until Terminate command received
//!
//! # Memory Layout
//!
//! ```text
//! Kernel Parameters:
//! - control_ptr:      PersistentControlBlock* (mapped memory)
//! - pressure_a_ptr:   float* (device memory, buffer A)
//! - pressure_b_ptr:   float* (device memory, buffer B)
//! - h2k_header_ptr:   SpscQueueHeader* (mapped memory)
//! - h2k_slots_ptr:    H2KMessage* (mapped memory)
//! - k2h_header_ptr:   SpscQueueHeader* (mapped memory)
//! - k2h_slots_ptr:    K2HMessage* (mapped memory)
//! - routes_ptr:       K2KRouteEntry* (device memory)
//! - halo_ptr:         float* (device memory, halo buffers)
//! ```
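//!
//! # Example
//!
//! A minimal sketch of driving the generator (marked `ignore` because it only
//! builds a CUDA source string; the kernel name here is an arbitrary placeholder):
//!
//! ```ignore
//! let config = PersistentFdtdConfig::new("room_fdtd")
//!     .with_tile_size(8, 8, 8)
//!     .with_cooperative(true)
//!     .with_progress_interval(100);
//! let cuda_source = generate_persistent_fdtd_kernel(&config);
//! assert!(cuda_source.contains("room_fdtd"));
//! ```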

/// Configuration for persistent FDTD kernel generation.
#[derive(Debug, Clone)]
pub struct PersistentFdtdConfig {
    /// Kernel function name.
    pub name: String,
    /// Tile size per dimension (cells per block).
    pub tile_size: (usize, usize, usize),
    /// Whether to use cooperative groups for grid sync.
    pub use_cooperative: bool,
    /// Progress report interval (steps).
    pub progress_interval: u64,
    /// Enable energy calculation.
    pub track_energy: bool,
    /// Nanosleep duration for idle spin-wait (nanoseconds).
    /// Used when the kernel has no work to do, reducing power consumption.
    pub idle_sleep_ns: u32,
    /// Use libcu++ ordered atomics (cuda::atomic_ref) for queue operations.
    /// Requires CUDA 11.0+ toolkit. When false (default), uses legacy
    /// __threadfence_system() pairs.
    pub use_libcupp_atomics: bool,
}

impl Default for PersistentFdtdConfig {
    fn default() -> Self {
        Self {
            name: "persistent_fdtd3d".to_string(),
            tile_size: (8, 8, 8),
            use_cooperative: true,
            progress_interval: 100,
            track_energy: true,
            idle_sleep_ns: 1000,
            use_libcupp_atomics: false,
        }
    }
}

impl PersistentFdtdConfig {
    /// Create a new config with the given name.
    pub fn new(name: &str) -> Self {
        Self {
            name: name.to_string(),
            ..Default::default()
        }
    }

    /// Set tile size.
    pub fn with_tile_size(mut self, tx: usize, ty: usize, tz: usize) -> Self {
        self.tile_size = (tx, ty, tz);
        self
    }

    /// Set cooperative groups usage.
    pub fn with_cooperative(mut self, use_coop: bool) -> Self {
        self.use_cooperative = use_coop;
        self
    }

    /// Set progress reporting interval.
    pub fn with_progress_interval(mut self, interval: u64) -> Self {
        self.progress_interval = interval;
        self
    }

    /// Set idle sleep duration in nanoseconds.
    /// Controls power consumption during spin-wait when no work is available.
    pub fn with_idle_sleep(mut self, ns: u32) -> Self {
        self.idle_sleep_ns = ns;
        self
    }

    /// Enable libcu++ ordered atomics for queue operations.
    /// Requires CUDA 11.0+ toolkit.
    pub fn with_libcupp_atomics(mut self, enabled: bool) -> Self {
        self.use_libcupp_atomics = enabled;
        self
    }

    /// Calculate threads per block.
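    ///
    /// For the default `(8, 8, 8)` tile this is `8 * 8 * 8 = 512` threads.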
    pub fn threads_per_block(&self) -> usize {
        self.tile_size.0 * self.tile_size.1 * self.tile_size.2
    }

    /// Calculate shared memory size per block (tile + halo).
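    ///
    /// Each tile is padded with a one-cell halo on every side, so the default
    /// 8x8x8 tile needs `(8 + 2)^3 = 1000` floats, i.e. 4000 bytes.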
    pub fn shared_mem_size(&self) -> usize {
        let (tx, ty, tz) = self.tile_size;
        // Tile + 1 cell halo on each side
        let with_halo = (tx + 2) * (ty + 2) * (tz + 2);
        with_halo * std::mem::size_of::<f32>()
    }
}

/// Generate a complete persistent FDTD kernel.
///
/// The generated kernel includes:
/// - H2K command processing
/// - K2H response generation
/// - K2K halo exchange
/// - FDTD computation with shared memory
/// - Grid-wide synchronization
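///
/// When `use_cooperative` is true, the emitted kernel calls
/// `cg::this_grid().sync()`, so it must be started with a cooperative launch
/// (e.g. `cudaLaunchCooperativeKernel`) rather than a regular launch;
/// otherwise the software barrier in the control block is used instead.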
pub fn generate_persistent_fdtd_kernel(config: &PersistentFdtdConfig) -> String {
    let mut code = String::new();

    // Header with includes
    code.push_str(&generate_header(config));

    // Structure definitions
    code.push_str(&generate_structures());

    // Device helper functions
    code.push_str(&generate_device_functions(config));

    // Main kernel
    code.push_str(&generate_main_kernel(config));

    code
}

fn generate_header(config: &PersistentFdtdConfig) -> String {
    let mut code = String::new();

    code.push_str("// Generated Persistent FDTD Kernel\n");
    code.push_str("// RingKernel GPU Actor System\n\n");

    if config.use_cooperative {
        code.push_str("#include <cooperative_groups.h>\n");
        code.push_str("namespace cg = cooperative_groups;\n\n");
    }

    code.push_str("#include <cuda_runtime.h>\n");
    code.push_str("#include <stdint.h>\n");

    if config.use_libcupp_atomics {
        code.push_str("\n// libcu++ ordered atomics (CUDA 11.0+)\n");
        code.push_str("#if __CUDACC_VER_MAJOR__ < 11\n");
        code.push_str("#error \"libcu++ atomics require CUDA 11.0 or later\"\n");
        code.push_str("#endif\n");
        code.push_str("#include <cuda/atomic>\n");
    }

    code.push('\n');

    code
}

fn generate_structures() -> String {
    r#"
// ============================================================================
// STRUCTURE DEFINITIONS (must match Rust persistent.rs)
// ============================================================================

// Control block for persistent kernel (256 bytes, cache-aligned)
typedef struct __align__(256) {
    // Lifecycle control (host -> GPU)
    uint32_t should_terminate;
    uint32_t _pad0;
    uint64_t steps_remaining;

    // State (GPU -> host)
    uint64_t current_step;
    uint32_t current_buffer;
    uint32_t has_terminated;
    float total_energy;
    uint32_t _pad1;

    // Configuration
    uint32_t grid_dim[3];
    uint32_t block_dim[3];
    uint32_t sim_size[3];
    uint32_t tile_size[3];

    // Acoustic parameters
    float c2_dt2;
    float damping;
    float cell_size;
    float dt;

    // Synchronization
    uint32_t barrier_counter;
    uint32_t barrier_generation;

    // Statistics
    uint64_t messages_processed;
    uint64_t k2k_messages_sent;
    uint64_t k2k_messages_received;

    uint64_t _reserved[16];
} PersistentControlBlock;

// SPSC queue header (128 bytes, cache-aligned)
typedef struct __align__(128) {
    uint64_t head;
    uint64_t tail;
    uint32_t capacity;
    uint32_t mask;
    uint64_t _padding[12];
} SpscQueueHeader;

// H2K message (64 bytes)
typedef struct __align__(64) {
    uint32_t cmd;
    uint32_t flags;
    uint64_t cmd_id;
    uint64_t param1;
    uint32_t param2;
    uint32_t param3;
    float param4;
    float param5;
    uint64_t _reserved[4];
} H2KMessage;

// K2H message (64 bytes)
typedef struct __align__(64) {
    uint32_t resp_type;
    uint32_t flags;
    uint64_t cmd_id;
    uint64_t step;
    uint64_t steps_remaining;
    float energy;
    uint32_t error_code;
    uint64_t _reserved[3];
} K2HMessage;

// Command types
#define CMD_NOP 0
#define CMD_RUN_STEPS 1
#define CMD_PAUSE 2
#define CMD_RESUME 3
#define CMD_TERMINATE 4
#define CMD_INJECT_IMPULSE 5
#define CMD_SET_SOURCE 6
#define CMD_GET_PROGRESS 7

// Response types
#define RESP_ACK 0
#define RESP_PROGRESS 1
#define RESP_ERROR 2
#define RESP_TERMINATED 3
#define RESP_ENERGY 4

// Neighbor block IDs
typedef struct {
    int32_t pos_x, neg_x;
    int32_t pos_y, neg_y;
    int32_t pos_z, neg_z;
    int32_t _padding[2];
} BlockNeighbors;

// K2K route entry
typedef struct {
    BlockNeighbors neighbors;
    uint32_t block_pos[3];
    uint32_t _padding;
    uint32_t cell_offset[3];
    uint32_t _padding2;
} K2KRouteEntry;

// Face indices
#define FACE_POS_X 0
#define FACE_NEG_X 1
#define FACE_POS_Y 2
#define FACE_NEG_Y 3
#define FACE_POS_Z 4
#define FACE_NEG_Z 5

"#
    .to_string()
}

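/// Generate the device-side helpers: the software grid barrier, H2K/K2H
/// queue operations, the block-level energy reduction, and the K2K halo
/// pack/unpack routines.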
fn generate_device_functions(config: &PersistentFdtdConfig) -> String {
    let (tx, ty, tz) = config.tile_size;
    let face_size = tx * ty; // one tile face; the halo layout assumes cubic tiles (tx == ty == tz)

    let mut code = String::new();

    // Software grid barrier (fallback when cooperative not available)
    if config.use_libcupp_atomics {
        code.push_str(
            r#"
// ============================================================================
// SYNCHRONIZATION FUNCTIONS (libcu++ ordered atomics)
// ============================================================================

// Software grid barrier using cuda::atomic_ref with explicit memory ordering.
// Uses device scope (not system) since barrier is GPU-internal.
__device__ void software_grid_sync(
    uint32_t* barrier_counter,
    uint32_t* barrier_gen,
    int num_blocks
) {
    __syncthreads();  // First sync within block

    if (threadIdx.x == 0) {
        cuda::atomic_ref<unsigned int, cuda::thread_scope_device> gen_ref(*barrier_gen);
        cuda::atomic_ref<unsigned int, cuda::thread_scope_device> cnt_ref(*barrier_counter);

        unsigned int gen = gen_ref.load(cuda::memory_order_acquire);

        unsigned int arrived = cnt_ref.fetch_add(1, cuda::memory_order_acq_rel) + 1;
        if (arrived == num_blocks) {
            cnt_ref.store(0, cuda::memory_order_relaxed);
            gen_ref.fetch_add(1, cuda::memory_order_release);
        } else {
            while (gen_ref.load(cuda::memory_order_acquire) == gen) {
                __nanosleep(100);  // Reduce power in barrier spin
            }
        }
    }

    __syncthreads();
}

"#,
        );
    } else {
        code.push_str(
            r#"
// ============================================================================
// SYNCHRONIZATION FUNCTIONS
// ============================================================================

// Software grid barrier (atomic counter + generation)
__device__ void software_grid_sync(
    volatile uint32_t* barrier_counter,
    volatile uint32_t* barrier_gen,
    int num_blocks
) {
    __syncthreads();  // First sync within block

    if (threadIdx.x == 0) {
        unsigned int gen = *barrier_gen;

        unsigned int arrived = atomicAdd((unsigned int*)barrier_counter, 1) + 1;
        if (arrived == num_blocks) {
            *barrier_counter = 0;
            __threadfence();
            atomicAdd((unsigned int*)barrier_gen, 1);
        } else {
            while (atomicAdd((unsigned int*)barrier_gen, 0) == gen) {
                __threadfence();
                __nanosleep(100);  // Reduce power in barrier spin
            }
        }
    }

    __syncthreads();
}

"#,
        );
    }

    // H2K/K2H queue operations
    if config.use_libcupp_atomics {
        code.push_str(
            r#"
// ============================================================================
// MESSAGE QUEUE OPERATIONS (libcu++ ordered atomics)
// ============================================================================

// Try to receive H2K message using cuda::atomic_ref with memory ordering.
// acquire on tail read ensures we see the host's payload writes.
// release on tail publish ensures host sees our consumption.
__device__ bool h2k_try_recv(
    SpscQueueHeader* header,
    H2KMessage* slots,
    H2KMessage* out_msg
) {
    cuda::atomic_ref<uint64_t, cuda::thread_scope_system> head_ref(header->head);
    cuda::atomic_ref<uint64_t, cuda::thread_scope_system> tail_ref(header->tail);

    uint64_t tail = tail_ref.load(cuda::memory_order_relaxed);
    uint64_t head = head_ref.load(cuda::memory_order_acquire);

    if (head == tail) {
        return false;  // Empty
    }

    uint32_t slot = tail & header->mask;
    *out_msg = slots[slot];

    tail_ref.store(tail + 1, cuda::memory_order_release);

    return true;
}

// Send K2H message using cuda::atomic_ref with memory ordering.
// acquire on head read to see our own previous writes.
// release on head publish ensures host sees the payload.
__device__ bool k2h_send(
    SpscQueueHeader* header,
    K2HMessage* slots,
    const K2HMessage* msg
) {
    cuda::atomic_ref<uint64_t, cuda::thread_scope_system> head_ref(header->head);
    cuda::atomic_ref<uint64_t, cuda::thread_scope_system> tail_ref(header->tail);

    uint64_t head = head_ref.load(cuda::memory_order_relaxed);
    uint64_t tail = tail_ref.load(cuda::memory_order_acquire);
    uint32_t capacity = header->capacity;

    if (head - tail >= capacity) {
        return false;  // Full
    }

    uint32_t slot = head & header->mask;
    slots[slot] = *msg;

    head_ref.store(head + 1, cuda::memory_order_release);

    return true;
}

"#,
        );
    } else {
        code.push_str(
            r#"
// ============================================================================
// MESSAGE QUEUE OPERATIONS
// ============================================================================

// Copy H2K message from volatile source (field-by-field to handle volatile)
__device__ void copy_h2k_message(H2KMessage* dst, volatile H2KMessage* src) {
    dst->cmd = src->cmd;
    dst->flags = src->flags;
    dst->cmd_id = src->cmd_id;
    dst->param1 = src->param1;
    dst->param2 = src->param2;
    dst->param3 = src->param3;
    dst->param4 = src->param4;
    dst->param5 = src->param5;
    // Skip reserved fields for performance
}

// Copy K2H message to volatile destination (field-by-field to handle volatile)
__device__ void copy_k2h_message(volatile K2HMessage* dst, const K2HMessage* src) {
    dst->resp_type = src->resp_type;
    dst->flags = src->flags;
    dst->cmd_id = src->cmd_id;
    dst->step = src->step;
    dst->steps_remaining = src->steps_remaining;
    dst->energy = src->energy;
    dst->error_code = src->error_code;
    // Skip reserved fields for performance
}

// Try to receive H2K message (returns true if message available)
__device__ bool h2k_try_recv(
    volatile SpscQueueHeader* header,
    volatile H2KMessage* slots,
    H2KMessage* out_msg
) {
    // Fence BEFORE reading to ensure we see host writes
    __threadfence_system();

    uint64_t head = header->head;
    uint64_t tail = header->tail;

    if (head == tail) {
        return false;  // Empty
    }

    uint32_t slot = tail & header->mask;
    copy_h2k_message(out_msg, &slots[slot]);

    __threadfence();
    header->tail = tail + 1;

    return true;
}

// Send K2H message
__device__ bool k2h_send(
    volatile SpscQueueHeader* header,
    volatile K2HMessage* slots,
    const K2HMessage* msg
) {
    uint64_t head = header->head;
    uint64_t tail = header->tail;
    uint32_t capacity = header->capacity;

    if (head - tail >= capacity) {
        return false;  // Full
    }

    uint32_t slot = head & header->mask;
    copy_k2h_message(&slots[slot], msg);

    __threadfence_system();  // Ensure host sees our writes
    header->head = head + 1;

    return true;
}

"#,
        );
    }

    // Energy reduction function (two-phase warp-shuffle)
    code.push_str(
        r#"
// ============================================================================
// ENERGY CALCULATION (Warp-Shuffle Reduction)
// ============================================================================

// Block-level parallel reduction for energy: E = sum(p^2)
// Two-phase approach: intra-warp shuffle (no __syncthreads), then
// cross-warp reduction via shared memory (one __syncthreads).
__device__ float block_reduce_energy(
    float my_energy,
    float* shared_reduce,
    int threads_per_block
) {
    int tid = threadIdx.x;
    int warp_id = tid / 32;
    int lane_id = tid % 32;
    int num_warps = threads_per_block / 32;

    // Phase 1: Intra-warp reduction via shuffle (no syncthreads needed)
    float val = my_energy;
    for (int offset = 16; offset > 0; offset >>= 1) {
        val += __shfl_down_sync(0xFFFFFFFF, val, offset);
    }

    // Warp leaders write partial sums to shared memory
    if (lane_id == 0) {
        shared_reduce[warp_id] = val;
    }
    __syncthreads();

    // Phase 2: First warp reduces across all warp results
    val = (tid < num_warps) ? shared_reduce[tid] : 0.0f;
    if (warp_id == 0) {
        for (int offset = 16; offset > 0; offset >>= 1) {
            val += __shfl_down_sync(0xFFFFFFFF, val, offset);
        }
    }

    // Thread 0 has the final block sum
    if (tid == 0) {
        shared_reduce[0] = val;
    }

    return shared_reduce[0];
}

"#,
    );

    // Halo exchange functions
    code.push_str(&format!(
        r#"
// ============================================================================
// K2K HALO EXCHANGE
// ============================================================================

#define TILE_X {tx}
#define TILE_Y {ty}
#define TILE_Z {tz}
#define FACE_SIZE {face_size}

// Pack halo faces from shared memory to device halo buffer
__device__ void pack_halo_faces(
    float tile[TILE_Z + 2][TILE_Y + 2][TILE_X + 2],
    float* halo_buffers,
    int block_id,
    int pingpong,
    int num_blocks
) {{
    int tid = threadIdx.x;
    int face_stride = FACE_SIZE;
    int block_stride = 6 * face_stride * 2;  // 6 faces, 2 ping-pong
    int pp_offset = pingpong * 6 * face_stride;

    float* block_halo = halo_buffers + block_id * block_stride + pp_offset;

    if (tid < FACE_SIZE) {{
        int fx = tid % TILE_X;
        int fy = tid / TILE_X;

        // +X face (x = TILE_X in local coords, which is index TILE_X)
        block_halo[FACE_POS_X * face_stride + tid] = tile[fy + 1][fx + 1][TILE_X];
        // -X face (x = 1 in local coords)
        block_halo[FACE_NEG_X * face_stride + tid] = tile[fy + 1][fx + 1][1];
        // +Y face
        block_halo[FACE_POS_Y * face_stride + tid] = tile[fy + 1][TILE_Y][fx + 1];
        // -Y face
        block_halo[FACE_NEG_Y * face_stride + tid] = tile[fy + 1][1][fx + 1];
        // +Z face
        block_halo[FACE_POS_Z * face_stride + tid] = tile[TILE_Z][fy + 1][fx + 1];
        // -Z face
        block_halo[FACE_NEG_Z * face_stride + tid] = tile[1][fy + 1][fx + 1];
    }}
}}

// Unpack halo faces from device buffer to shared memory ghost cells
__device__ void unpack_halo_faces(
    float tile[TILE_Z + 2][TILE_Y + 2][TILE_X + 2],
    const float* halo_buffers,
    const K2KRouteEntry* route,
    int pingpong,
    int num_blocks
) {{
    int tid = threadIdx.x;
    int face_stride = FACE_SIZE;
    int block_stride = 6 * face_stride * 2;
    int pp_offset = pingpong * 6 * face_stride;

    if (tid < FACE_SIZE) {{
        int fx = tid % TILE_X;
        int fy = tid / TILE_X;

        // My +X ghost comes from neighbor's -X face
        if (route->neighbors.pos_x >= 0) {{
            const float* n_halo = halo_buffers + route->neighbors.pos_x * block_stride + pp_offset;
            tile[fy + 1][fx + 1][TILE_X + 1] = n_halo[FACE_NEG_X * face_stride + tid];
        }}

        // My -X ghost comes from neighbor's +X face
        if (route->neighbors.neg_x >= 0) {{
            const float* n_halo = halo_buffers + route->neighbors.neg_x * block_stride + pp_offset;
            tile[fy + 1][fx + 1][0] = n_halo[FACE_POS_X * face_stride + tid];
        }}

        // My +Y ghost
        if (route->neighbors.pos_y >= 0) {{
            const float* n_halo = halo_buffers + route->neighbors.pos_y * block_stride + pp_offset;
            tile[fy + 1][TILE_Y + 1][fx + 1] = n_halo[FACE_NEG_Y * face_stride + tid];
        }}

        // My -Y ghost
        if (route->neighbors.neg_y >= 0) {{
            const float* n_halo = halo_buffers + route->neighbors.neg_y * block_stride + pp_offset;
            tile[fy + 1][0][fx + 1] = n_halo[FACE_POS_Y * face_stride + tid];
        }}

        // My +Z ghost
        if (route->neighbors.pos_z >= 0) {{
            const float* n_halo = halo_buffers + route->neighbors.pos_z * block_stride + pp_offset;
            tile[TILE_Z + 1][fy + 1][fx + 1] = n_halo[FACE_NEG_Z * face_stride + tid];
        }}

        // My -Z ghost
        if (route->neighbors.neg_z >= 0) {{
            const float* n_halo = halo_buffers + route->neighbors.neg_z * block_stride + pp_offset;
            tile[0][fy + 1][fx + 1] = n_halo[FACE_POS_Z * face_stride + tid];
        }}
    }}
}}

"#,
        tx = tx,
        ty = ty,
        tz = tz,
        face_size = face_size
    ));

    code
}

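/// Generate the `__global__` entry point: the persistent command loop,
/// K2K halo exchange, the FDTD update, periodic energy reports, and the
/// final termination response.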
fn generate_main_kernel(config: &PersistentFdtdConfig) -> String {
    let (tx, ty, tz) = config.tile_size;
    let threads_per_block = tx * ty * tz;
    let progress_interval = config.progress_interval;

    // For cooperative groups: declare grid variable once at init, then just call sync
    // For software sync: no variable needed, just call the function each time
    let grid_sync = if config.use_cooperative {
        "grid.sync();"
    } else {
        "software_grid_sync(&ctrl->barrier_counter, &ctrl->barrier_generation, num_blocks);"
    };

    format!(
        r#"
// ============================================================================
// MAIN PERSISTENT KERNEL
// ============================================================================

extern "C" __global__ void __launch_bounds__({threads_per_block}, 2)
{name}(
    PersistentControlBlock* __restrict__ ctrl,
    float* __restrict__ pressure_a,
    float* __restrict__ pressure_b,
    SpscQueueHeader* __restrict__ h2k_header,
    H2KMessage* __restrict__ h2k_slots,
    SpscQueueHeader* __restrict__ k2h_header,
    K2HMessage* __restrict__ k2h_slots,
    const K2KRouteEntry* __restrict__ routes,
    float* __restrict__ halo_buffers
) {{
    // Block and thread indices
    int block_id = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y;
    int num_blocks = gridDim.x * gridDim.y * gridDim.z;
    int tid = threadIdx.x;
    bool is_coordinator = (block_id == 0 && tid == 0);

    // Shared memory for tile + ghost cells
    __shared__ float tile[{tz} + 2][{ty} + 2][{tx} + 2];

    // Shared memory for energy reduction
    __shared__ float energy_reduce[{threads_per_block}];

    // Get my routing info
    const K2KRouteEntry* my_route = &routes[block_id];

    // Pointers to pressure buffers (will swap)
    float* p_curr = pressure_a;
    float* p_prev = pressure_b;

    // Grid dimensions for indexing
    int sim_x = ctrl->sim_size[0];
    int sim_y = ctrl->sim_size[1];
    int sim_z = ctrl->sim_size[2];

    // Local cell coordinates within tile
    int lx = tid % {tx};
    int ly = (tid / {tx}) % {ty};
    int lz = tid / ({tx} * {ty});

    // Global cell coordinates
    int gx = my_route->cell_offset[0] + lx;
    int gy = my_route->cell_offset[1] + ly;
    int gz = my_route->cell_offset[2] + lz;
    int global_idx = gx + gy * sim_x + gz * sim_x * sim_y;

    // Physics parameters
    float c2_dt2 = ctrl->c2_dt2;
    float damping = ctrl->damping;

    {grid_sync_init}

    // ========== PERSISTENT LOOP ==========
    while (true) {{
        // --- Phase 1: Command Processing (coordinator only) ---
        if (is_coordinator) {{
            H2KMessage cmd;
            while (h2k_try_recv(h2k_header, h2k_slots, &cmd)) {{
                ctrl->messages_processed++;

                switch (cmd.cmd) {{
                    case CMD_RUN_STEPS:
                        ctrl->steps_remaining += cmd.param1;
                        break;

                    case CMD_TERMINATE:
                        ctrl->should_terminate = 1;
                        break;

                    case CMD_INJECT_IMPULSE: {{
                        // Extract position from params
                        uint32_t ix = (cmd.param1 >> 32) & 0xFFFFFFFF;
                        uint32_t iy = cmd.param1 & 0xFFFFFFFF;
                        uint32_t iz = cmd.param2;
                        float amplitude = cmd.param4;

                        if (ix < sim_x && iy < sim_y && iz < sim_z) {{
                            int idx = ix + iy * sim_x + iz * sim_x * sim_y;
                            p_curr[idx] += amplitude;
                        }}
                        break;
                    }}

                    case CMD_GET_PROGRESS: {{
                        K2HMessage resp;
                        resp.resp_type = RESP_PROGRESS;
                        resp.cmd_id = cmd.cmd_id;
                        resp.step = ctrl->current_step;
                        resp.steps_remaining = ctrl->steps_remaining;
                        resp.energy = ctrl->total_energy;
                        k2h_send(k2h_header, k2h_slots, &resp);
                        break;
                    }}

                    default:
                        break;
                }}
            }}
        }}

        // Grid-wide sync so all blocks see updated control state
        {grid_sync}

        // Check for termination
        if (ctrl->should_terminate) {{
            break;
        }}

        // --- Phase 2: Check if we have work ---
        if (ctrl->steps_remaining == 0) {{
            // No work - sleep to reduce power consumption, then check again
            __nanosleep({idle_sleep_ns});
            {grid_sync}
            continue;
        }}

        // --- Phase 3: Load tile from global memory ---
        // Load interior cells
        if (gx < sim_x && gy < sim_y && gz < sim_z) {{
            tile[lz + 1][ly + 1][lx + 1] = p_curr[global_idx];
        }} else {{
            tile[lz + 1][ly + 1][lx + 1] = 0.0f;
        }}

        // Initialize ghost cells to boundary (zero pressure)
        if (lx == 0) tile[lz + 1][ly + 1][0] = 0.0f;
        if (lx == {tx} - 1) tile[lz + 1][ly + 1][{tx} + 1] = 0.0f;
        if (ly == 0) tile[lz + 1][0][lx + 1] = 0.0f;
        if (ly == {ty} - 1) tile[lz + 1][{ty} + 1][lx + 1] = 0.0f;
        if (lz == 0) tile[0][ly + 1][lx + 1] = 0.0f;
        if (lz == {tz} - 1) tile[{tz} + 1][ly + 1][lx + 1] = 0.0f;

        __syncthreads();

        // --- Phase 4: K2K Halo Exchange ---
        int pingpong = ctrl->current_step & 1;

        // Pack my boundary faces into halo buffers
        pack_halo_faces(tile, halo_buffers, block_id, pingpong, num_blocks);
        __threadfence();  // Ensure writes visible to other blocks

        // Grid-wide sync - wait for all blocks to finish packing
        {grid_sync}

        // Unpack neighbor faces into my ghost cells
        unpack_halo_faces(tile, halo_buffers, my_route, pingpong, num_blocks);
        __syncthreads();

        // --- Phase 5: FDTD Computation ---
        if (gx < sim_x && gy < sim_y && gz < sim_z) {{
            // 7-point Laplacian from shared memory
            float center = tile[lz + 1][ly + 1][lx + 1];
            float lap = tile[lz + 1][ly + 1][lx + 2]   // +X
                      + tile[lz + 1][ly + 1][lx]       // -X
                      + tile[lz + 1][ly + 2][lx + 1]   // +Y
                      + tile[lz + 1][ly][lx + 1]       // -Y
                      + tile[lz + 2][ly + 1][lx + 1]   // +Z
                      + tile[lz][ly + 1][lx + 1]       // -Z
                      - 6.0f * center;

            // FDTD update: p_new = 2*p - p_prev + c^2*dt^2*lap
            float p_prev_val = p_prev[global_idx];
            float p_new = 2.0f * center - p_prev_val + c2_dt2 * lap;
            p_new *= damping;

            // Write to "previous" buffer (will become current after swap)
            p_prev[global_idx] = p_new;
        }}

        // --- Phase 5b: Energy Calculation (periodic) ---
        // Check if this step will report progress (after counter increment)
        bool is_progress_step = ((ctrl->current_step + 1) % {progress_interval}) == 0;

        if (is_progress_step) {{
            // Reset energy accumulator (coordinator only, before reduction)
            if (is_coordinator) {{
                ctrl->total_energy = 0.0f;
            }}
            __threadfence();  // Ensure reset is visible

            // Compute this thread's energy contribution: E = p^2
            float my_energy = 0.0f;
            if (gx < sim_x && gy < sim_y && gz < sim_z) {{
                float p_val = p_prev[global_idx];  // Just written value
                my_energy = p_val * p_val;
            }}

            // Block-level parallel reduction
            float block_energy = block_reduce_energy(my_energy, energy_reduce, {threads_per_block});

            // Block leader accumulates to global energy
            if (tid == 0) {{
                atomicAdd(&ctrl->total_energy, block_energy);
            }}
        }}

        // Grid-wide sync before buffer swap
        {grid_sync}

        // --- Phase 6: Buffer Swap & Progress ---
        if (is_coordinator) {{
            // Swap buffer pointers (toggle index)
            ctrl->current_buffer ^= 1;

            // Update counters
            ctrl->current_step++;
            ctrl->steps_remaining--;

            // Send progress update periodically
            if (ctrl->current_step % {progress_interval} == 0) {{
                K2HMessage resp;
                resp.resp_type = RESP_PROGRESS;
                resp.cmd_id = 0;
                resp.step = ctrl->current_step;
                resp.steps_remaining = ctrl->steps_remaining;
                resp.energy = ctrl->total_energy;  // Computed in Phase 5b
                k2h_send(k2h_header, k2h_slots, &resp);
            }}
        }}

        // Swap our local pointers
        float* tmp = p_curr;
        p_curr = p_prev;
        p_prev = tmp;

        {grid_sync}
    }}

    // ========== FINAL ENERGY CALCULATION ==========
    // Compute final energy before termination (all threads participate)
    if (is_coordinator) {{
        ctrl->total_energy = 0.0f;
    }}
    __threadfence();

    // Each thread's final energy contribution
    float final_energy = 0.0f;
    if (gx < sim_x && gy < sim_y && gz < sim_z) {{
        float p_val = p_curr[global_idx];  // Current buffer has final state
        final_energy = p_val * p_val;
    }}

    // Block-level parallel reduction
    float block_final_energy = block_reduce_energy(final_energy, energy_reduce, {threads_per_block});

    // Block leader accumulates to global energy
    if (tid == 0) {{
        atomicAdd(&ctrl->total_energy, block_final_energy);
    }}

    {grid_sync}

    // ========== CLEANUP ==========
    if (is_coordinator) {{
        ctrl->has_terminated = 1;

        K2HMessage resp;
        resp.resp_type = RESP_TERMINATED;
        resp.cmd_id = 0;
        resp.step = ctrl->current_step;
        resp.steps_remaining = 0;
        resp.energy = ctrl->total_energy;  // Final computed energy
        k2h_send(k2h_header, k2h_slots, &resp);
    }}
}}
"#,
        name = config.name,
        tx = tx,
        ty = ty,
        tz = tz,
        threads_per_block = threads_per_block,
        progress_interval = progress_interval,
        grid_sync_init = if config.use_cooperative {
            "cg::grid_group grid = cg::this_grid();"
        } else {
            ""
        },
        grid_sync = grid_sync,
        idle_sleep_ns = config.idle_sleep_ns,
    )
}

/// Generate PTX for the persistent FDTD kernel using nvcc.
///
/// This requires nvcc to be installed.
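///
/// # Example
///
/// A hedged sketch of the intended call (marked `ignore`; the exact PTX text
/// depends on the local toolkit and GPU):
///
/// ```ignore
/// let config = PersistentFdtdConfig::default();
/// let ptx = compile_persistent_fdtd_to_ptx(&config)?;
/// assert!(ptx.contains(".entry"));
/// ```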
#[cfg(feature = "nvcc")]
pub fn compile_persistent_fdtd_to_ptx(config: &PersistentFdtdConfig) -> Result<String, String> {
    let cuda_code = generate_persistent_fdtd_kernel(config);

    // Use nvcc to compile to PTX
    use std::process::Command;

    // Write to temp file
    let temp_dir = std::env::temp_dir();
    let cuda_file = temp_dir.join("persistent_fdtd.cu");
    let ptx_file = temp_dir.join("persistent_fdtd.ptx");

    std::fs::write(&cuda_file, &cuda_code)
        .map_err(|e| format!("Failed to write CUDA file: {}", e))?;

    // Compile with nvcc
    // Use -arch=native to automatically detect the GPU architecture
    // This ensures compatibility with newer CUDA versions that dropped older sm_* support
    let mut args = vec![
        "-ptx".to_string(),
        "-o".to_string(),
        ptx_file.to_string_lossy().to_string(),
        cuda_file.to_string_lossy().to_string(),
        "-arch=native".to_string(),
        "-std=c++17".to_string(),
    ];

    if config.use_cooperative {
        args.push("-rdc=true".to_string()); // Required for cooperative groups
    }

    let output = Command::new("nvcc")
        .args(&args)
        .output()
        .map_err(|e| format!("Failed to run nvcc: {}", e))?;

    if !output.status.success() {
        return Err(format!(
            "nvcc compilation failed:\n{}",
            String::from_utf8_lossy(&output.stderr)
        ));
    }

    std::fs::read_to_string(&ptx_file).map_err(|e| format!("Failed to read PTX: {}", e))
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = PersistentFdtdConfig::default();
        assert_eq!(config.name, "persistent_fdtd3d");
        assert_eq!(config.tile_size, (8, 8, 8));
        assert_eq!(config.threads_per_block(), 512);
    }

    #[test]
    fn test_config_builder() {
        let config = PersistentFdtdConfig::new("my_kernel")
            .with_tile_size(4, 4, 4)
            .with_cooperative(false)
            .with_progress_interval(50);

        assert_eq!(config.name, "my_kernel");
        assert_eq!(config.tile_size, (4, 4, 4));
        assert_eq!(config.threads_per_block(), 64);
        assert!(!config.use_cooperative);
        assert_eq!(config.progress_interval, 50);
    }

    #[test]
    fn test_shared_mem_calculation() {
        // Default 8x8x8 tile; with halo: 10x10x10 = 1000 floats = 4000 bytes
        let config = PersistentFdtdConfig::default();
        assert_eq!(config.shared_mem_size(), 4000);
    }

    #[test]
    fn test_generate_kernel_cooperative() {
        let config = PersistentFdtdConfig::new("test_kernel").with_cooperative(true);

        let code = generate_persistent_fdtd_kernel(&config);

        // Check includes
        assert!(code.contains("#include <cooperative_groups.h>"));
        assert!(code.contains("namespace cg = cooperative_groups;"));

        // Check structures
        assert!(code.contains("typedef struct __align__(256)"));
        assert!(code.contains("PersistentControlBlock"));
        assert!(code.contains("SpscQueueHeader"));
        assert!(code.contains("H2KMessage"));
        assert!(code.contains("K2HMessage"));

        // Check device functions
        assert!(code.contains("__device__ bool h2k_try_recv"));
        assert!(code.contains("__device__ bool k2h_send"));
        assert!(code.contains("__device__ void pack_halo_faces"));
        assert!(code.contains("__device__ void unpack_halo_faces"));

        // Check main kernel
        assert!(code.contains("extern \"C\" __global__ void"));
        assert!(code.contains("test_kernel"));
        assert!(code.contains("cg::grid_group grid = cg::this_grid()"));
        assert!(code.contains("grid.sync()"));

        // Check FDTD computation
        assert!(code.contains("7-point Laplacian"));
        assert!(code.contains("c2_dt2 * lap"));
    }

    #[test]
    fn test_generate_kernel_software_sync() {
        let config = PersistentFdtdConfig::new("test_kernel").with_cooperative(false);

        let code = generate_persistent_fdtd_kernel(&config);

        // Should NOT have cooperative groups
        assert!(!code.contains("#include <cooperative_groups.h>"));

        // Should have software barrier
        assert!(code.contains("software_grid_sync"));
        assert!(code.contains("atomicAdd"));
    }

    #[test]
    fn test_generate_kernel_command_handling() {
        let config = PersistentFdtdConfig::default();
        let code = generate_persistent_fdtd_kernel(&config);

        // Check command handling
        assert!(code.contains("CMD_RUN_STEPS"));
        assert!(code.contains("CMD_TERMINATE"));
        assert!(code.contains("CMD_INJECT_IMPULSE"));
        assert!(code.contains("CMD_GET_PROGRESS"));

        // Check response handling
        assert!(code.contains("RESP_PROGRESS"));
        assert!(code.contains("RESP_TERMINATED"));
    }

    #[test]
    fn test_generate_kernel_halo_exchange() {
        let config = PersistentFdtdConfig::new("test").with_tile_size(8, 8, 8);

        let code = generate_persistent_fdtd_kernel(&config);

        // Check halo defines
        assert!(code.contains("#define TILE_X 8"));
        assert!(code.contains("#define TILE_Y 8"));
        assert!(code.contains("#define TILE_Z 8"));
        assert!(code.contains("#define FACE_SIZE 64"));

        // Check face indices
        assert!(code.contains("FACE_POS_X"));
        assert!(code.contains("FACE_NEG_Z"));

        // Check K2K operations
        assert!(code.contains("pack_halo_faces"));
        assert!(code.contains("unpack_halo_faces"));
    }

    #[test]
    fn test_kernel_contains_persistent_loop() {
        let config = PersistentFdtdConfig::default();
        let code = generate_persistent_fdtd_kernel(&config);

        // Must have persistent loop structure
        assert!(code.contains("while (true)"));
        assert!(code.contains("if (ctrl->should_terminate)"));
        assert!(code.contains("break;"));

        // Must handle no-work case with nanosleep
        assert!(code.contains("if (ctrl->steps_remaining == 0)"));
        assert!(code.contains("__nanosleep(1000)"));
    }

    #[test]
    fn test_idle_sleep_configurable() {
        let config = PersistentFdtdConfig::new("test_sleep").with_idle_sleep(5000);
        let code = generate_persistent_fdtd_kernel(&config);

        assert!(code.contains("__nanosleep(5000)"));
        assert!(!code.contains("__nanosleep(1000)"));
    }

    #[test]
    fn test_software_barrier_nanosleep() {
        let config = PersistentFdtdConfig::new("test_barrier").with_cooperative(false);
        let code = generate_persistent_fdtd_kernel(&config);

        // Software barrier spin-loop should contain nanosleep
        assert!(code.contains("__nanosleep(100)"));
        assert!(code.contains("Reduce power in barrier spin"));
    }

    #[test]
    fn test_kernel_contains_energy_calculation() {
        let config = PersistentFdtdConfig::new("test_energy")
            .with_tile_size(8, 8, 8)
            .with_progress_interval(100);

        let code = generate_persistent_fdtd_kernel(&config);

        // Must have energy reduction function with warp-shuffle
        assert!(code.contains("block_reduce_energy"));
        assert!(code.contains("E = sum(p^2)"));
        assert!(code.contains("Warp-Shuffle Reduction"));
        assert!(code.contains("__shfl_down_sync(0xFFFFFFFF, val, offset)"));

        // Must have shared memory for energy reduction
        assert!(code.contains("energy_reduce[512]")); // 8*8*8 = 512 threads

        // Must compute energy periodically (progress interval check)
        assert!(code.contains("is_progress_step"));
        assert!(code.contains("% 100"));

        // Must reset energy accumulator before reduction
        assert!(code.contains("ctrl->total_energy = 0.0f"));

        // Must compute p^2 for energy
        assert!(code.contains("p_val * p_val"));

        // Must accumulate via atomicAdd
        assert!(code.contains("atomicAdd(&ctrl->total_energy"));

        // Must use energy in progress response
        assert!(code.contains("resp.energy = ctrl->total_energy"));

        // Must compute final energy at termination
        assert!(code.contains("FINAL ENERGY CALCULATION"));
        assert!(code.contains("block_final_energy"));
    }

    #[test]
    fn test_generate_kernel_libcupp_atomics() {
        let config = PersistentFdtdConfig::new("test_libcupp")
            .with_libcupp_atomics(true)
            .with_cooperative(false);

        let code = generate_persistent_fdtd_kernel(&config);

        // Must include libcu++ header with version guard
        assert!(code.contains("#include <cuda/atomic>"));
        assert!(code.contains("__CUDACC_VER_MAJOR__ < 11"));

        // Queue operations must use cuda::atomic_ref with system scope
        assert!(code.contains("cuda::atomic_ref<uint64_t, cuda::thread_scope_system>"));
        assert!(code.contains("memory_order_acquire"));
        assert!(code.contains("memory_order_release"));

        // Must NOT have legacy __threadfence_system() in queue ops
        assert!(!code.contains("__threadfence_system()"));

        // Software barrier must use device scope (not system)
        assert!(code.contains("cuda::atomic_ref<unsigned int, cuda::thread_scope_device>"));
        assert!(code.contains("memory_order_acq_rel"));
    }

    #[test]
    fn test_generate_kernel_default_no_libcupp() {
        let config = PersistentFdtdConfig::default();
        let code = generate_persistent_fdtd_kernel(&config);

        // Default must NOT include libcu++
        assert!(!code.contains("#include <cuda/atomic>"));
        assert!(!code.contains("cuda::atomic_ref"));

        // Must use legacy fencing
        assert!(code.contains("__threadfence_system()"));
        assert!(code.contains("volatile"));
    }
}