ringkernel_cuda_codegen/
persistent_fdtd.rs

1//! Persistent FDTD kernel code generation.
2//!
3//! Generates CUDA code for truly persistent GPU actors that run for the
4//! entire simulation lifetime with:
5//!
6//! - Single kernel launch
7//! - H2K/K2H command/response messaging
8//! - K2K halo exchange between blocks
9//! - Grid-wide synchronization via cooperative groups
10//!
11//! # Architecture
12//!
13//! The generated kernel is structured as:
14//!
//! 1. **Coordinator (thread 0 of block 0)**: Processes H2K commands, sends K2H responses
//! 2. **All blocks (0..N)**: Compute FDTD steps for their tile, exchange halos via K2K
17//! 3. **Persistent Loop**: Runs until Terminate command received
18//!
19//! # Memory Layout
20//!
21//! ```text
22//! Kernel Parameters:
23//! - control_ptr:      PersistentControlBlock* (mapped memory)
24//! - pressure_a_ptr:   float* (device memory, buffer A)
25//! - pressure_b_ptr:   float* (device memory, buffer B)
26//! - h2k_header_ptr:   SpscQueueHeader* (mapped memory)
27//! - h2k_slots_ptr:    H2KMessage* (mapped memory)
28//! - k2h_header_ptr:   SpscQueueHeader* (mapped memory)
29//! - k2h_slots_ptr:    K2HMessage* (mapped memory)
30//! - routes_ptr:       K2KRouteEntry* (device memory)
31//! - halo_ptr:         float* (device memory, halo buffers)
32//! ```
33
/// Configuration for persistent FDTD kernel generation.
///
/// Start from [`PersistentFdtdConfig::new`] or [`Default::default`] and refine
/// the value with the chainable `with_*` methods.
#[derive(Debug, Clone)]
pub struct PersistentFdtdConfig {
    /// Kernel function name.
    pub name: String,
    /// Tile size per dimension (cells per block).
    pub tile_size: (usize, usize, usize),
    /// Whether to use cooperative groups for grid sync.
    pub use_cooperative: bool,
    /// Progress report interval (steps).
    pub progress_interval: u64,
    /// Enable energy calculation.
    pub track_energy: bool,
}

impl Default for PersistentFdtdConfig {
    /// Defaults: kernel name `persistent_fdtd3d`, 8x8x8 tiles, cooperative
    /// groups enabled, progress every 100 steps, energy tracking enabled.
    fn default() -> Self {
        Self {
            name: "persistent_fdtd3d".to_string(),
            tile_size: (8, 8, 8),
            use_cooperative: true,
            progress_interval: 100,
            track_energy: true,
        }
    }
}

impl PersistentFdtdConfig {
    /// Create a new config with the given name; every other field takes its
    /// default value.
    pub fn new(name: &str) -> Self {
        Self {
            name: name.to_string(),
            ..Default::default()
        }
    }

    /// Set tile size (cells per block along X, Y and Z).
    pub fn with_tile_size(mut self, tx: usize, ty: usize, tz: usize) -> Self {
        self.tile_size = (tx, ty, tz);
        self
    }

    /// Set cooperative groups usage.
    pub fn with_cooperative(mut self, use_coop: bool) -> Self {
        self.use_cooperative = use_coop;
        self
    }

    /// Set progress reporting interval (in steps).
    pub fn with_progress_interval(mut self, interval: u64) -> Self {
        self.progress_interval = interval;
        self
    }

    /// Set the energy-calculation flag.
    ///
    /// Provided so that every tunable field can be set in a builder chain;
    /// previously `track_energy` could only be changed by mutating the field.
    pub fn with_track_energy(mut self, track_energy: bool) -> Self {
        self.track_energy = track_energy;
        self
    }

    /// Calculate threads per block: one thread per tile cell
    /// (`tx * ty * tz`).
    pub fn threads_per_block(&self) -> usize {
        self.tile_size.0 * self.tile_size.1 * self.tile_size.2
    }

    /// Calculate shared memory size per block (tile + halo), in bytes.
    ///
    /// Only the `(tx + 2) * (ty + 2) * (tz + 2)` array of `f32` is counted.
    /// The `energy_reduce` scratch array that the generated kernel also
    /// declares in shared memory is not included.
    pub fn shared_mem_size(&self) -> usize {
        let (tx, ty, tz) = self.tile_size;
        // Tile + 1 cell halo on each side
        let with_halo = (tx + 2) * (ty + 2) * (tz + 2);
        with_halo * std::mem::size_of::<f32>()
    }
}
101
102/// Generate a complete persistent FDTD kernel.
103///
104/// The generated kernel includes:
105/// - H2K command processing
106/// - K2H response generation
107/// - K2K halo exchange
108/// - FDTD computation with shared memory
109/// - Grid-wide synchronization
110pub fn generate_persistent_fdtd_kernel(config: &PersistentFdtdConfig) -> String {
111    let mut code = String::new();
112
113    // Header with includes
114    code.push_str(&generate_header(config));
115
116    // Structure definitions
117    code.push_str(&generate_structures());
118
119    // Device helper functions
120    code.push_str(&generate_device_functions(config));
121
122    // Main kernel
123    code.push_str(&generate_main_kernel(config));
124
125    code
126}
127
128fn generate_header(config: &PersistentFdtdConfig) -> String {
129    let mut code = String::new();
130
131    code.push_str("// Generated Persistent FDTD Kernel\n");
132    code.push_str("// RingKernel GPU Actor System\n\n");
133
134    if config.use_cooperative {
135        code.push_str("#include <cooperative_groups.h>\n");
136        code.push_str("namespace cg = cooperative_groups;\n\n");
137    }
138
139    code.push_str("#include <cuda_runtime.h>\n");
140    code.push_str("#include <stdint.h>\n\n");
141
142    code
143}
144
/// Return the CUDA-side type and constant definitions shared with the host:
/// the control block, the SPSC queue header, the H2K/K2H message structs, the
/// command/response codes, the K2K routing structs and the face indices.
///
/// The text is fixed; it does not depend on any configuration.
///
/// NOTE(review): the size remarks embedded in the CUDA comments do not all
/// agree with the declared fields. The fields of `PersistentControlBlock` add
/// up to 264 bytes (so `sizeof` is 512 after `__align__(256)`, not 256), and
/// those of `H2KMessage` add up to 72 bytes (128 after `__align__(64)`, not
/// 64). Confirm against the Rust-side definitions before relying on either
/// number.
fn generate_structures() -> String {
    const STRUCTURE_DEFINITIONS: &str = r#"
// ============================================================================
// STRUCTURE DEFINITIONS (must match Rust persistent.rs)
// ============================================================================

// Control block for persistent kernel (256 bytes, cache-aligned)
typedef struct __align__(256) {
    // Lifecycle control (host -> GPU)
    uint32_t should_terminate;
    uint32_t _pad0;
    uint64_t steps_remaining;

    // State (GPU -> host)
    uint64_t current_step;
    uint32_t current_buffer;
    uint32_t has_terminated;
    float total_energy;
    uint32_t _pad1;

    // Configuration
    uint32_t grid_dim[3];
    uint32_t block_dim[3];
    uint32_t sim_size[3];
    uint32_t tile_size[3];

    // Acoustic parameters
    float c2_dt2;
    float damping;
    float cell_size;
    float dt;

    // Synchronization
    uint32_t barrier_counter;
    uint32_t barrier_generation;

    // Statistics
    uint64_t messages_processed;
    uint64_t k2k_messages_sent;
    uint64_t k2k_messages_received;

    uint64_t _reserved[16];
} PersistentControlBlock;

// SPSC queue header (128 bytes, cache-aligned)
typedef struct __align__(128) {
    uint64_t head;
    uint64_t tail;
    uint32_t capacity;
    uint32_t mask;
    uint64_t _padding[12];
} SpscQueueHeader;

// H2K message (64 bytes)
typedef struct __align__(64) {
    uint32_t cmd;
    uint32_t flags;
    uint64_t cmd_id;
    uint64_t param1;
    uint32_t param2;
    uint32_t param3;
    float param4;
    float param5;
    uint64_t _reserved[4];
} H2KMessage;

// K2H message (64 bytes)
typedef struct __align__(64) {
    uint32_t resp_type;
    uint32_t flags;
    uint64_t cmd_id;
    uint64_t step;
    uint64_t steps_remaining;
    float energy;
    uint32_t error_code;
    uint64_t _reserved[3];
} K2HMessage;

// Command types
#define CMD_NOP 0
#define CMD_RUN_STEPS 1
#define CMD_PAUSE 2
#define CMD_RESUME 3
#define CMD_TERMINATE 4
#define CMD_INJECT_IMPULSE 5
#define CMD_SET_SOURCE 6
#define CMD_GET_PROGRESS 7

// Response types
#define RESP_ACK 0
#define RESP_PROGRESS 1
#define RESP_ERROR 2
#define RESP_TERMINATED 3
#define RESP_ENERGY 4

// Neighbor block IDs
typedef struct {
    int32_t pos_x, neg_x;
    int32_t pos_y, neg_y;
    int32_t pos_z, neg_z;
    int32_t _padding[2];
} BlockNeighbors;

// K2K route entry
typedef struct {
    BlockNeighbors neighbors;
    uint32_t block_pos[3];
    uint32_t _padding;
    uint32_t cell_offset[3];
    uint32_t _padding2;
} K2KRouteEntry;

// Face indices
#define FACE_POS_X 0
#define FACE_NEG_X 1
#define FACE_POS_Y 2
#define FACE_NEG_Y 3
#define FACE_POS_Z 4
#define FACE_NEG_Z 5

"#;

    String::from(STRUCTURE_DEFINITIONS)
}
268
269fn generate_device_functions(config: &PersistentFdtdConfig) -> String {
270    let (tx, ty, tz) = config.tile_size;
271    let face_size = tx * ty;
272
273    let mut code = String::new();
274
275    // Software grid barrier (fallback when cooperative not available)
276    code.push_str(
277        r#"
278// ============================================================================
279// SYNCHRONIZATION FUNCTIONS
280// ============================================================================
281
282// Software grid barrier (atomic counter + generation)
283__device__ void software_grid_sync(
284    volatile uint32_t* barrier_counter,
285    volatile uint32_t* barrier_gen,
286    int num_blocks
287) {
288    __syncthreads();  // First sync within block
289
290    if (threadIdx.x == 0) {
291        unsigned int gen = *barrier_gen;
292
293        unsigned int arrived = atomicAdd((unsigned int*)barrier_counter, 1) + 1;
294        if (arrived == num_blocks) {
295            *barrier_counter = 0;
296            __threadfence();
297            atomicAdd((unsigned int*)barrier_gen, 1);
298        } else {
299            while (atomicAdd((unsigned int*)barrier_gen, 0) == gen) {
300                __threadfence();
301            }
302        }
303    }
304
305    __syncthreads();
306}
307
308"#,
309    );
310
311    // H2K/K2H queue operations
312    code.push_str(
313        r#"
314// ============================================================================
315// MESSAGE QUEUE OPERATIONS
316// ============================================================================
317
318// Copy H2K message from volatile source (field-by-field to handle volatile)
319__device__ void copy_h2k_message(H2KMessage* dst, volatile H2KMessage* src) {
320    dst->cmd = src->cmd;
321    dst->flags = src->flags;
322    dst->cmd_id = src->cmd_id;
323    dst->param1 = src->param1;
324    dst->param2 = src->param2;
325    dst->param3 = src->param3;
326    dst->param4 = src->param4;
327    dst->param5 = src->param5;
328    // Skip reserved fields for performance
329}
330
331// Copy K2H message to volatile destination (field-by-field to handle volatile)
332__device__ void copy_k2h_message(volatile K2HMessage* dst, const K2HMessage* src) {
333    dst->resp_type = src->resp_type;
334    dst->flags = src->flags;
335    dst->cmd_id = src->cmd_id;
336    dst->step = src->step;
337    dst->steps_remaining = src->steps_remaining;
338    dst->energy = src->energy;
339    dst->error_code = src->error_code;
340    // Skip reserved fields for performance
341}
342
343// Try to receive H2K message (returns true if message available)
344__device__ bool h2k_try_recv(
345    volatile SpscQueueHeader* header,
346    volatile H2KMessage* slots,
347    H2KMessage* out_msg
348) {
349    // Fence BEFORE reading to ensure we see host writes
350    __threadfence_system();
351
352    uint64_t head = header->head;
353    uint64_t tail = header->tail;
354
355    if (head == tail) {
356        return false;  // Empty
357    }
358
359    uint32_t slot = tail & header->mask;
360    copy_h2k_message(out_msg, &slots[slot]);
361
362    __threadfence();
363    header->tail = tail + 1;
364
365    return true;
366}
367
368// Send K2H message
369__device__ bool k2h_send(
370    volatile SpscQueueHeader* header,
371    volatile K2HMessage* slots,
372    const K2HMessage* msg
373) {
374    uint64_t head = header->head;
375    uint64_t tail = header->tail;
376    uint32_t capacity = header->capacity;
377
378    if (head - tail >= capacity) {
379        return false;  // Full
380    }
381
382    uint32_t slot = head & header->mask;
383    copy_k2h_message(&slots[slot], msg);
384
385    __threadfence_system();  // Ensure host sees our writes
386    header->head = head + 1;
387
388    return true;
389}
390
391"#,
392    );
393
394    // Energy reduction function
395    code.push_str(
396        r#"
397// ============================================================================
398// ENERGY CALCULATION (Parallel Reduction)
399// ============================================================================
400
401// Block-level parallel reduction for energy: E = sum(p^2)
402// Uses shared memory for efficient reduction within a block
403__device__ float block_reduce_energy(
404    float my_energy,
405    float* shared_reduce,
406    int threads_per_block
407) {
408    int tid = threadIdx.x;
409
410    // Store initial value in shared memory
411    shared_reduce[tid] = my_energy;
412    __syncthreads();
413
414    // Parallel reduction tree
415    for (int stride = threads_per_block / 2; stride > 0; stride >>= 1) {
416        if (tid < stride) {
417            shared_reduce[tid] += shared_reduce[tid + stride];
418        }
419        __syncthreads();
420    }
421
422    // Return the total for this block (only thread 0 has final sum)
423    return shared_reduce[0];
424}
425
426"#,
427    );
428
429    // Halo exchange functions
430    code.push_str(&format!(
431        r#"
432// ============================================================================
433// K2K HALO EXCHANGE
434// ============================================================================
435
436#define TILE_X {tx}
437#define TILE_Y {ty}
438#define TILE_Z {tz}
439#define FACE_SIZE {face_size}
440
441// Pack halo faces from shared memory to device halo buffer
442__device__ void pack_halo_faces(
443    float tile[TILE_Z + 2][TILE_Y + 2][TILE_X + 2],
444    float* halo_buffers,
445    int block_id,
446    int pingpong,
447    int num_blocks
448) {{
449    int tid = threadIdx.x;
450    int face_stride = FACE_SIZE;
451    int block_stride = 6 * face_stride * 2;  // 6 faces, 2 ping-pong
452    int pp_offset = pingpong * 6 * face_stride;
453
454    float* block_halo = halo_buffers + block_id * block_stride + pp_offset;
455
456    if (tid < FACE_SIZE) {{
457        int fx = tid % TILE_X;
458        int fy = tid / TILE_X;
459
460        // +X face (x = TILE_X in local coords, which is index TILE_X)
461        block_halo[FACE_POS_X * face_stride + tid] = tile[fy + 1][fx + 1][TILE_X];
462        // -X face (x = 1 in local coords)
463        block_halo[FACE_NEG_X * face_stride + tid] = tile[fy + 1][fx + 1][1];
464        // +Y face
465        block_halo[FACE_POS_Y * face_stride + tid] = tile[fy + 1][TILE_Y][fx + 1];
466        // -Y face
467        block_halo[FACE_NEG_Y * face_stride + tid] = tile[fy + 1][1][fx + 1];
468        // +Z face
469        block_halo[FACE_POS_Z * face_stride + tid] = tile[TILE_Z][fy + 1][fx + 1];
470        // -Z face
471        block_halo[FACE_NEG_Z * face_stride + tid] = tile[1][fy + 1][fx + 1];
472    }}
473}}
474
475// Unpack halo faces from device buffer to shared memory ghost cells
476__device__ void unpack_halo_faces(
477    float tile[TILE_Z + 2][TILE_Y + 2][TILE_X + 2],
478    const float* halo_buffers,
479    const K2KRouteEntry* route,
480    int pingpong,
481    int num_blocks
482) {{
483    int tid = threadIdx.x;
484    int face_stride = FACE_SIZE;
485    int block_stride = 6 * face_stride * 2;
486    int pp_offset = pingpong * 6 * face_stride;
487
488    if (tid < FACE_SIZE) {{
489        int fx = tid % TILE_X;
490        int fy = tid / TILE_X;
491
492        // My +X ghost comes from neighbor's -X face
493        if (route->neighbors.pos_x >= 0) {{
494            const float* n_halo = halo_buffers + route->neighbors.pos_x * block_stride + pp_offset;
495            tile[fy + 1][fx + 1][TILE_X + 1] = n_halo[FACE_NEG_X * face_stride + tid];
496        }}
497
498        // My -X ghost comes from neighbor's +X face
499        if (route->neighbors.neg_x >= 0) {{
500            const float* n_halo = halo_buffers + route->neighbors.neg_x * block_stride + pp_offset;
501            tile[fy + 1][fx + 1][0] = n_halo[FACE_POS_X * face_stride + tid];
502        }}
503
504        // My +Y ghost
505        if (route->neighbors.pos_y >= 0) {{
506            const float* n_halo = halo_buffers + route->neighbors.pos_y * block_stride + pp_offset;
507            tile[fy + 1][TILE_Y + 1][fx + 1] = n_halo[FACE_NEG_Y * face_stride + tid];
508        }}
509
510        // My -Y ghost
511        if (route->neighbors.neg_y >= 0) {{
512            const float* n_halo = halo_buffers + route->neighbors.neg_y * block_stride + pp_offset;
513            tile[fy + 1][0][fx + 1] = n_halo[FACE_POS_Y * face_stride + tid];
514        }}
515
516        // My +Z ghost
517        if (route->neighbors.pos_z >= 0) {{
518            const float* n_halo = halo_buffers + route->neighbors.pos_z * block_stride + pp_offset;
519            tile[TILE_Z + 1][fy + 1][fx + 1] = n_halo[FACE_NEG_Z * face_stride + tid];
520        }}
521
522        // My -Z ghost
523        if (route->neighbors.neg_z >= 0) {{
524            const float* n_halo = halo_buffers + route->neighbors.neg_z * block_stride + pp_offset;
525            tile[0][fy + 1][fx + 1] = n_halo[FACE_POS_Z * face_stride + tid];
526        }}
527    }}
528}}
529
530"#,
531        tx = tx,
532        ty = ty,
533        tz = tz,
534        face_size = face_size
535    ));
536
537    code
538}
539
540fn generate_main_kernel(config: &PersistentFdtdConfig) -> String {
541    let (tx, ty, tz) = config.tile_size;
542    let threads_per_block = tx * ty * tz;
543    let progress_interval = config.progress_interval;
544
545    // For cooperative groups: declare grid variable once at init, then just call sync
546    // For software sync: no variable needed, just call the function each time
547    let grid_sync = if config.use_cooperative {
548        "grid.sync();"
549    } else {
550        "software_grid_sync(&ctrl->barrier_counter, &ctrl->barrier_generation, num_blocks);"
551    };
552
553    format!(
554        r#"
555// ============================================================================
556// MAIN PERSISTENT KERNEL
557// ============================================================================
558
559extern "C" __global__ void __launch_bounds__({threads_per_block}, 2)
560{name}(
561    PersistentControlBlock* __restrict__ ctrl,
562    float* __restrict__ pressure_a,
563    float* __restrict__ pressure_b,
564    SpscQueueHeader* __restrict__ h2k_header,
565    H2KMessage* __restrict__ h2k_slots,
566    SpscQueueHeader* __restrict__ k2h_header,
567    K2HMessage* __restrict__ k2h_slots,
568    const K2KRouteEntry* __restrict__ routes,
569    float* __restrict__ halo_buffers
570) {{
571    // Block and thread indices
572    int block_id = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y;
573    int num_blocks = gridDim.x * gridDim.y * gridDim.z;
574    int tid = threadIdx.x;
575    bool is_coordinator = (block_id == 0 && tid == 0);
576
577    // Shared memory for tile + ghost cells
578    __shared__ float tile[{tz} + 2][{ty} + 2][{tx} + 2];
579
580    // Shared memory for energy reduction
581    __shared__ float energy_reduce[{threads_per_block}];
582
583    // Get my routing info
584    const K2KRouteEntry* my_route = &routes[block_id];
585
586    // Pointers to pressure buffers (will swap)
587    float* p_curr = pressure_a;
588    float* p_prev = pressure_b;
589
590    // Grid dimensions for indexing
591    int sim_x = ctrl->sim_size[0];
592    int sim_y = ctrl->sim_size[1];
593    int sim_z = ctrl->sim_size[2];
594
595    // Local cell coordinates within tile
596    int lx = tid % {tx};
597    int ly = (tid / {tx}) % {ty};
598    int lz = tid / ({tx} * {ty});
599
600    // Global cell coordinates
601    int gx = my_route->cell_offset[0] + lx;
602    int gy = my_route->cell_offset[1] + ly;
603    int gz = my_route->cell_offset[2] + lz;
604    int global_idx = gx + gy * sim_x + gz * sim_x * sim_y;
605
606    // Physics parameters
607    float c2_dt2 = ctrl->c2_dt2;
608    float damping = ctrl->damping;
609
610    {grid_sync_init}
611
612    // ========== PERSISTENT LOOP ==========
613    while (true) {{
614        // --- Phase 1: Command Processing (coordinator only) ---
615        if (is_coordinator) {{
616            H2KMessage cmd;
617            while (h2k_try_recv(h2k_header, h2k_slots, &cmd)) {{
618                ctrl->messages_processed++;
619
620                switch (cmd.cmd) {{
621                    case CMD_RUN_STEPS:
622                        ctrl->steps_remaining += cmd.param1;
623                        break;
624
625                    case CMD_TERMINATE:
626                        ctrl->should_terminate = 1;
627                        break;
628
629                    case CMD_INJECT_IMPULSE: {{
630                        // Extract position from params
631                        uint32_t ix = (cmd.param1 >> 32) & 0xFFFFFFFF;
632                        uint32_t iy = cmd.param1 & 0xFFFFFFFF;
633                        uint32_t iz = cmd.param2;
634                        float amplitude = cmd.param4;
635
636                        if (ix < sim_x && iy < sim_y && iz < sim_z) {{
637                            int idx = ix + iy * sim_x + iz * sim_x * sim_y;
638                            p_curr[idx] += amplitude;
639                        }}
640                        break;
641                    }}
642
643                    case CMD_GET_PROGRESS: {{
644                        K2HMessage resp;
645                        resp.resp_type = RESP_PROGRESS;
646                        resp.cmd_id = cmd.cmd_id;
647                        resp.step = ctrl->current_step;
648                        resp.steps_remaining = ctrl->steps_remaining;
649                        resp.energy = ctrl->total_energy;
650                        k2h_send(k2h_header, k2h_slots, &resp);
651                        break;
652                    }}
653
654                    default:
655                        break;
656                }}
657            }}
658        }}
659
660        // Grid-wide sync so all blocks see updated control state
661        {grid_sync}
662
663        // Check for termination
664        if (ctrl->should_terminate) {{
665            break;
666        }}
667
668        // --- Phase 2: Check if we have work ---
669        if (ctrl->steps_remaining == 0) {{
670            // No work - brief spinwait then check again
671            // Use volatile counter to prevent optimization
672            volatile int spin_count = 0;
673            for (int i = 0; i < 1000; i++) {{
674                spin_count++;
675            }}
676            {grid_sync}
677            continue;
678        }}
679
680        // --- Phase 3: Load tile from global memory ---
681        // Load interior cells
682        if (gx < sim_x && gy < sim_y && gz < sim_z) {{
683            tile[lz + 1][ly + 1][lx + 1] = p_curr[global_idx];
684        }} else {{
685            tile[lz + 1][ly + 1][lx + 1] = 0.0f;
686        }}
687
688        // Initialize ghost cells to boundary (zero pressure)
689        if (lx == 0) tile[lz + 1][ly + 1][0] = 0.0f;
690        if (lx == {tx} - 1) tile[lz + 1][ly + 1][{tx} + 1] = 0.0f;
691        if (ly == 0) tile[lz + 1][0][lx + 1] = 0.0f;
692        if (ly == {ty} - 1) tile[lz + 1][{ty} + 1][lx + 1] = 0.0f;
693        if (lz == 0) tile[0][ly + 1][lx + 1] = 0.0f;
694        if (lz == {tz} - 1) tile[{tz} + 1][ly + 1][lx + 1] = 0.0f;
695
696        __syncthreads();
697
698        // --- Phase 4: K2K Halo Exchange ---
699        int pingpong = ctrl->current_step & 1;
700
701        // Pack my boundary faces into halo buffers
702        pack_halo_faces(tile, halo_buffers, block_id, pingpong, num_blocks);
703        __threadfence();  // Ensure writes visible to other blocks
704
705        // Grid-wide sync - wait for all blocks to finish packing
706        {grid_sync}
707
708        // Unpack neighbor faces into my ghost cells
709        unpack_halo_faces(tile, halo_buffers, my_route, pingpong, num_blocks);
710        __syncthreads();
711
712        // --- Phase 5: FDTD Computation ---
713        if (gx < sim_x && gy < sim_y && gz < sim_z) {{
714            // 7-point Laplacian from shared memory
715            float center = tile[lz + 1][ly + 1][lx + 1];
716            float lap = tile[lz + 1][ly + 1][lx + 2]   // +X
717                      + tile[lz + 1][ly + 1][lx]       // -X
718                      + tile[lz + 1][ly + 2][lx + 1]   // +Y
719                      + tile[lz + 1][ly][lx + 1]       // -Y
720                      + tile[lz + 2][ly + 1][lx + 1]   // +Z
721                      + tile[lz][ly + 1][lx + 1]       // -Z
722                      - 6.0f * center;
723
724            // FDTD update: p_new = 2*p - p_prev + c^2*dt^2*lap
725            float p_prev_val = p_prev[global_idx];
726            float p_new = 2.0f * center - p_prev_val + c2_dt2 * lap;
727            p_new *= damping;
728
729            // Write to "previous" buffer (will become current after swap)
730            p_prev[global_idx] = p_new;
731        }}
732
733        // --- Phase 5b: Energy Calculation (periodic) ---
734        // Check if this step will report progress (after counter increment)
735        bool is_progress_step = ((ctrl->current_step + 1) % {progress_interval}) == 0;
736
737        if (is_progress_step) {{
738            // Reset energy accumulator (coordinator only, before reduction)
739            if (is_coordinator) {{
740                ctrl->total_energy = 0.0f;
741            }}
742            __threadfence();  // Ensure reset is visible
743
744            // Compute this thread's energy contribution: E = p^2
745            float my_energy = 0.0f;
746            if (gx < sim_x && gy < sim_y && gz < sim_z) {{
747                float p_val = p_prev[global_idx];  // Just written value
748                my_energy = p_val * p_val;
749            }}
750
751            // Block-level parallel reduction
752            float block_energy = block_reduce_energy(my_energy, energy_reduce, {threads_per_block});
753
754            // Block leader accumulates to global energy
755            if (tid == 0) {{
756                atomicAdd(&ctrl->total_energy, block_energy);
757            }}
758        }}
759
760        // Grid-wide sync before buffer swap
761        {grid_sync}
762
763        // --- Phase 6: Buffer Swap & Progress ---
764        if (is_coordinator) {{
765            // Swap buffer pointers (toggle index)
766            ctrl->current_buffer ^= 1;
767
768            // Update counters
769            ctrl->current_step++;
770            ctrl->steps_remaining--;
771
772            // Send progress update periodically
773            if (ctrl->current_step % {progress_interval} == 0) {{
774                K2HMessage resp;
775                resp.resp_type = RESP_PROGRESS;
776                resp.cmd_id = 0;
777                resp.step = ctrl->current_step;
778                resp.steps_remaining = ctrl->steps_remaining;
779                resp.energy = ctrl->total_energy;  // Computed in Phase 5b
780                k2h_send(k2h_header, k2h_slots, &resp);
781            }}
782        }}
783
784        // Swap our local pointers
785        float* tmp = p_curr;
786        p_curr = p_prev;
787        p_prev = tmp;
788
789        {grid_sync}
790    }}
791
792    // ========== FINAL ENERGY CALCULATION ==========
793    // Compute final energy before termination (all threads participate)
794    if (is_coordinator) {{
795        ctrl->total_energy = 0.0f;
796    }}
797    __threadfence();
798
799    // Each thread's final energy contribution
800    float final_energy = 0.0f;
801    if (gx < sim_x && gy < sim_y && gz < sim_z) {{
802        float p_val = p_curr[global_idx];  // Current buffer has final state
803        final_energy = p_val * p_val;
804    }}
805
806    // Block-level parallel reduction
807    float block_final_energy = block_reduce_energy(final_energy, energy_reduce, {threads_per_block});
808
809    // Block leader accumulates to global energy
810    if (tid == 0) {{
811        atomicAdd(&ctrl->total_energy, block_final_energy);
812    }}
813
814    {grid_sync}
815
816    // ========== CLEANUP ==========
817    if (is_coordinator) {{
818        ctrl->has_terminated = 1;
819
820        K2HMessage resp;
821        resp.resp_type = RESP_TERMINATED;
822        resp.cmd_id = 0;
823        resp.step = ctrl->current_step;
824        resp.steps_remaining = 0;
825        resp.energy = ctrl->total_energy;  // Final computed energy
826        k2h_send(k2h_header, k2h_slots, &resp);
827    }}
828}}
829"#,
830        name = config.name,
831        tx = tx,
832        ty = ty,
833        tz = tz,
834        threads_per_block = threads_per_block,
835        progress_interval = progress_interval,
836        grid_sync_init = if config.use_cooperative {
837            "cg::grid_group grid = cg::this_grid();"
838        } else {
839            ""
840        },
841        grid_sync = grid_sync,
842    )
843}
844
/// Generate PTX for the persistent FDTD kernel using nvcc.
///
/// The CUDA source produced by [`generate_persistent_fdtd_kernel`] is written
/// to a temporary `.cu` file, compiled with `nvcc -ptx`, and the resulting
/// PTX text is returned. Both temporary files are removed before returning.
///
/// Temporary file names contain the process id and a per-process counter, so
/// concurrent calls (from several threads or several processes) do not
/// overwrite each other's input or output.
///
/// This requires nvcc to be installed.
///
/// # Errors
///
/// Returns a descriptive message when the source file cannot be written,
/// `nvcc` cannot be started, `nvcc` exits unsuccessfully (its stderr is
/// included), or the PTX output cannot be read.
#[cfg(feature = "nvcc")]
pub fn compile_persistent_fdtd_to_ptx(config: &PersistentFdtdConfig) -> Result<String, String> {
    use std::process::Command;
    use std::sync::atomic::{AtomicU64, Ordering};

    // Distinguishes compilations started by the same process.
    static NEXT_BUILD_ID: AtomicU64 = AtomicU64::new(0);

    let cuda_code = generate_persistent_fdtd_kernel(config);

    // Unique per (process, call) so parallel builds never share files.
    let build_id = NEXT_BUILD_ID.fetch_add(1, Ordering::Relaxed);
    let stem = format!("persistent_fdtd_{}_{}", std::process::id(), build_id);
    let temp_dir = std::env::temp_dir();
    let cuda_file = temp_dir.join(format!("{}.cu", stem));
    let ptx_file = temp_dir.join(format!("{}.ptx", stem));

    std::fs::write(&cuda_file, &cuda_code)
        .map_err(|e| format!("Failed to write CUDA file: {}", e))?;

    // Compile with nvcc
    // Use -arch=native to automatically detect the GPU architecture
    // This ensures compatibility with newer CUDA versions that dropped older sm_* support
    //
    // Paths are handed over as OS strings so that non-UTF-8 temp directories
    // are passed through unchanged instead of being lossily re-encoded.
    let mut nvcc = Command::new("nvcc");
    nvcc.arg("-ptx")
        .arg("-o")
        .arg(&ptx_file)
        .arg(&cuda_file)
        .arg("-arch=native")
        .arg("-std=c++17");

    if config.use_cooperative {
        nvcc.arg("-rdc=true"); // Required for cooperative groups
    }

    let result = nvcc
        .output()
        .map_err(|e| format!("Failed to run nvcc: {}", e))
        .and_then(|output| {
            if output.status.success() {
                std::fs::read_to_string(&ptx_file)
                    .map_err(|e| format!("Failed to read PTX: {}", e))
            } else {
                Err(format!(
                    "nvcc compilation failed:\n{}",
                    String::from_utf8_lossy(&output.stderr)
                ))
            }
        });

    // Best-effort cleanup: a failure to delete must not mask the real result,
    // and the PTX file may not exist if compilation failed.
    let _ = std::fs::remove_file(&cuda_file);
    let _ = std::fs::remove_file(&ptx_file);

    result
}
893
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that each entry of `snippets` occurs in `code`, checking them
    /// in the order given and panicking on the first one that is absent.
    fn assert_contains_all(code: &str, snippets: &[&str]) {
        for &snippet in snippets {
            assert!(
                code.contains(snippet),
                "generated code is missing `{}`",
                snippet
            );
        }
    }

    /// The default configuration uses the stock kernel name and an 8x8x8 tile.
    #[test]
    fn test_default_config() {
        let defaults = PersistentFdtdConfig::default();

        assert_eq!(defaults.name, "persistent_fdtd3d");
        assert_eq!(defaults.tile_size, (8, 8, 8));
        assert_eq!(defaults.threads_per_block(), 512);
    }

    /// Every builder method is reflected in the resulting configuration.
    #[test]
    fn test_config_builder() {
        let built = PersistentFdtdConfig::new("my_kernel")
            .with_tile_size(4, 4, 4)
            .with_cooperative(false)
            .with_progress_interval(50);

        assert_eq!(built.name, "my_kernel");
        assert_eq!(built.tile_size, (4, 4, 4));
        assert_eq!(built.threads_per_block(), 64);
        assert!(!built.use_cooperative);
        assert_eq!(built.progress_interval, 50);
    }

    /// Shared memory is sized for the tile plus its halo.
    #[test]
    fn test_shared_mem_calculation() {
        // Default tile is 8x8x8; with halo: 10x10x10 = 1000 floats = 4000 bytes.
        let defaults = PersistentFdtdConfig::default();

        assert_eq!(defaults.shared_mem_size(), 4000);
    }

    /// With cooperative sync enabled, the generated source carries the
    /// cooperative-groups scaffolding alongside the common kernel pieces.
    #[test]
    fn test_generate_kernel_cooperative() {
        let config = PersistentFdtdConfig::new("test_kernel").with_cooperative(true);
        let source = generate_persistent_fdtd_kernel(&config);

        assert_contains_all(
            &source,
            &[
                // Includes
                "#include <cooperative_groups.h>",
                "namespace cg = cooperative_groups;",
                // Structures
                "typedef struct __align__(256)",
                "PersistentControlBlock",
                "SpscQueueHeader",
                "H2KMessage",
                "K2HMessage",
                // Device functions
                "__device__ bool h2k_try_recv",
                "__device__ bool k2h_send",
                "__device__ void pack_halo_faces",
                "__device__ void unpack_halo_faces",
                // Main kernel
                "extern \"C\" __global__ void",
                "test_kernel",
                "cg::grid_group grid = cg::this_grid()",
                "grid.sync()",
                // FDTD computation
                "7-point Laplacian",
                "c2_dt2 * lap",
            ],
        );
    }

    /// With cooperative sync disabled, the cooperative-groups include is gone
    /// and a software barrier appears instead.
    #[test]
    fn test_generate_kernel_software_sync() {
        let config = PersistentFdtdConfig::new("test_kernel").with_cooperative(false);
        let source = generate_persistent_fdtd_kernel(&config);

        // The cooperative groups header must not be included.
        assert!(!source.contains("#include <cooperative_groups.h>"));

        // The software barrier must be present.
        assert_contains_all(&source, &["software_grid_sync", "atomicAdd"]);
    }

    /// The generated source names every command and response it handles.
    #[test]
    fn test_generate_kernel_command_handling() {
        let source = generate_persistent_fdtd_kernel(&PersistentFdtdConfig::default());

        assert_contains_all(
            &source,
            &[
                // Commands
                "CMD_RUN_STEPS",
                "CMD_TERMINATE",
                "CMD_INJECT_IMPULSE",
                "CMD_GET_PROGRESS",
                // Responses
                "RESP_PROGRESS",
                "RESP_TERMINATED",
            ],
        );
    }

    /// Tile dimensions and halo helpers show up in the generated source.
    #[test]
    fn test_generate_kernel_halo_exchange() {
        let config = PersistentFdtdConfig::new("test").with_tile_size(8, 8, 8);
        let source = generate_persistent_fdtd_kernel(&config);

        assert_contains_all(
            &source,
            &[
                // Halo defines
                "#define TILE_X 8",
                "#define TILE_Y 8",
                "#define TILE_Z 8",
                "#define FACE_SIZE 64",
                // Face indices
                "FACE_POS_X",
                "FACE_NEG_Z",
                // K2K operations
                "pack_halo_faces",
                "unpack_halo_faces",
            ],
        );
    }

    /// The generated source contains the persistent loop and its idle path.
    #[test]
    fn test_kernel_contains_persistent_loop() {
        let source = generate_persistent_fdtd_kernel(&PersistentFdtdConfig::default());

        assert_contains_all(
            &source,
            &[
                // Persistent loop structure
                "while (true)",
                "if (ctrl->should_terminate)",
                "break;",
                // No-work case
                "if (ctrl->steps_remaining == 0)",
                "volatile int spin_count",
            ],
        );
    }

    /// The generated source contains the energy reduction, its periodic
    /// trigger, and the final energy computation.
    #[test]
    fn test_kernel_contains_energy_calculation() {
        let config = PersistentFdtdConfig::new("test_energy")
            .with_tile_size(8, 8, 8)
            .with_progress_interval(100);
        let source = generate_persistent_fdtd_kernel(&config);

        assert_contains_all(
            &source,
            &[
                // Energy reduction function
                "block_reduce_energy",
                "E = sum(p^2)",
                "Parallel reduction tree",
                // Shared memory for the reduction: 8*8*8 = 512 threads
                "energy_reduce[512]",
                // Periodic energy computation (progress interval check)
                "is_progress_step",
                "% 100",
                // Accumulator reset before reduction
                "ctrl->total_energy = 0.0f",
                // p^2 term
                "p_val * p_val",
                // Accumulation via atomicAdd
                "atomicAdd(&ctrl->total_energy",
                // Energy reported in the progress response
                "resp.energy = ctrl->total_energy",
                // Final energy at termination
                "FINAL ENERGY CALCULATION",
                "block_final_energy",
            ],
        );
    }
}