ringkernel_cuda_codegen/
persistent_fdtd.rs

1//! Persistent FDTD kernel code generation.
2//!
3//! Generates CUDA code for truly persistent GPU actors that run for the
4//! entire simulation lifetime with:
5//!
6//! - Single kernel launch
7//! - H2K/K2H command/response messaging
8//! - K2K halo exchange between blocks
9//! - Grid-wide synchronization via cooperative groups
10//!
11//! # Architecture
12//!
13//! The generated kernel is structured as:
14//!
15//! 1. **Coordinator (Block 0)**: Processes H2K commands, sends K2H responses
16//! 2. **Worker Blocks (1..N)**: Compute FDTD steps, exchange halos via K2K
17//! 3. **Persistent Loop**: Runs until Terminate command received
18//!
19//! # Memory Layout
20//!
21//! ```text
22//! Kernel Parameters:
23//! - control_ptr:      PersistentControlBlock* (mapped memory)
24//! - pressure_a_ptr:   float* (device memory, buffer A)
25//! - pressure_b_ptr:   float* (device memory, buffer B)
26//! - h2k_header_ptr:   SpscQueueHeader* (mapped memory)
27//! - h2k_slots_ptr:    H2KMessage* (mapped memory)
28//! - k2h_header_ptr:   SpscQueueHeader* (mapped memory)
29//! - k2h_slots_ptr:    K2HMessage* (mapped memory)
30//! - routes_ptr:       K2KRouteEntry* (device memory)
31//! - halo_ptr:         float* (device memory, halo buffers)
32//! ```
33
34/// Configuration for persistent FDTD kernel generation.
35#[derive(Debug, Clone)]
36pub struct PersistentFdtdConfig {
37    /// Kernel function name.
38    pub name: String,
39    /// Tile size per dimension (cells per block).
40    pub tile_size: (usize, usize, usize),
41    /// Whether to use cooperative groups for grid sync.
42    pub use_cooperative: bool,
43    /// Progress report interval (steps).
44    pub progress_interval: u64,
45    /// Enable energy calculation.
46    pub track_energy: bool,
47}
48
49impl Default for PersistentFdtdConfig {
50    fn default() -> Self {
51        Self {
52            name: "persistent_fdtd3d".to_string(),
53            tile_size: (8, 8, 8),
54            use_cooperative: true,
55            progress_interval: 100,
56            track_energy: true,
57        }
58    }
59}
60
61impl PersistentFdtdConfig {
62    /// Create a new config with the given name.
63    pub fn new(name: &str) -> Self {
64        Self {
65            name: name.to_string(),
66            ..Default::default()
67        }
68    }
69
70    /// Set tile size.
71    pub fn with_tile_size(mut self, tx: usize, ty: usize, tz: usize) -> Self {
72        self.tile_size = (tx, ty, tz);
73        self
74    }
75
76    /// Set cooperative groups usage.
77    pub fn with_cooperative(mut self, use_coop: bool) -> Self {
78        self.use_cooperative = use_coop;
79        self
80    }
81
82    /// Set progress reporting interval.
83    pub fn with_progress_interval(mut self, interval: u64) -> Self {
84        self.progress_interval = interval;
85        self
86    }
87
88    /// Calculate threads per block.
89    pub fn threads_per_block(&self) -> usize {
90        self.tile_size.0 * self.tile_size.1 * self.tile_size.2
91    }
92
93    /// Calculate shared memory size per block (tile + halo).
94    pub fn shared_mem_size(&self) -> usize {
95        let (tx, ty, tz) = self.tile_size;
96        // Tile + 1 cell halo on each side
97        let with_halo = (tx + 2) * (ty + 2) * (tz + 2);
98        with_halo * std::mem::size_of::<f32>()
99    }
100}
101
102/// Generate a complete persistent FDTD kernel.
103///
104/// The generated kernel includes:
105/// - H2K command processing
106/// - K2H response generation
107/// - K2K halo exchange
108/// - FDTD computation with shared memory
109/// - Grid-wide synchronization
110pub fn generate_persistent_fdtd_kernel(config: &PersistentFdtdConfig) -> String {
111    let mut code = String::new();
112
113    // Header with includes
114    code.push_str(&generate_header(config));
115
116    // Structure definitions
117    code.push_str(&generate_structures());
118
119    // Device helper functions
120    code.push_str(&generate_device_functions(config));
121
122    // Main kernel
123    code.push_str(&generate_main_kernel(config));
124
125    code
126}
127
128fn generate_header(config: &PersistentFdtdConfig) -> String {
129    let mut code = String::new();
130
131    code.push_str("// Generated Persistent FDTD Kernel\n");
132    code.push_str("// RingKernel GPU Actor System\n\n");
133
134    if config.use_cooperative {
135        code.push_str("#include <cooperative_groups.h>\n");
136        code.push_str("namespace cg = cooperative_groups;\n\n");
137    }
138
139    code.push_str("#include <cuda_runtime.h>\n");
140    code.push_str("#include <stdint.h>\n\n");
141
142    code
143}
144
145fn generate_structures() -> String {
146    r#"
147// ============================================================================
148// STRUCTURE DEFINITIONS (must match Rust persistent.rs)
149// ============================================================================
150
151// Control block for persistent kernel (256 bytes, cache-aligned)
152typedef struct __align__(256) {
153    // Lifecycle control (host -> GPU)
154    uint32_t should_terminate;
155    uint32_t _pad0;
156    uint64_t steps_remaining;
157
158    // State (GPU -> host)
159    uint64_t current_step;
160    uint32_t current_buffer;
161    uint32_t has_terminated;
162    float total_energy;
163    uint32_t _pad1;
164
165    // Configuration
166    uint32_t grid_dim[3];
167    uint32_t block_dim[3];
168    uint32_t sim_size[3];
169    uint32_t tile_size[3];
170
171    // Acoustic parameters
172    float c2_dt2;
173    float damping;
174    float cell_size;
175    float dt;
176
177    // Synchronization
178    uint32_t barrier_counter;
179    uint32_t barrier_generation;
180
181    // Statistics
182    uint64_t messages_processed;
183    uint64_t k2k_messages_sent;
184    uint64_t k2k_messages_received;
185
186    uint64_t _reserved[16];
187} PersistentControlBlock;
188
189// SPSC queue header (128 bytes, cache-aligned)
190typedef struct __align__(128) {
191    uint64_t head;
192    uint64_t tail;
193    uint32_t capacity;
194    uint32_t mask;
195    uint64_t _padding[12];
196} SpscQueueHeader;
197
198// H2K message (64 bytes)
199typedef struct __align__(64) {
200    uint32_t cmd;
201    uint32_t flags;
202    uint64_t cmd_id;
203    uint64_t param1;
204    uint32_t param2;
205    uint32_t param3;
206    float param4;
207    float param5;
208    uint64_t _reserved[4];
209} H2KMessage;
210
211// K2H message (64 bytes)
212typedef struct __align__(64) {
213    uint32_t resp_type;
214    uint32_t flags;
215    uint64_t cmd_id;
216    uint64_t step;
217    uint64_t steps_remaining;
218    float energy;
219    uint32_t error_code;
220    uint64_t _reserved[3];
221} K2HMessage;
222
223// Command types
224#define CMD_NOP 0
225#define CMD_RUN_STEPS 1
226#define CMD_PAUSE 2
227#define CMD_RESUME 3
228#define CMD_TERMINATE 4
229#define CMD_INJECT_IMPULSE 5
230#define CMD_SET_SOURCE 6
231#define CMD_GET_PROGRESS 7
232
233// Response types
234#define RESP_ACK 0
235#define RESP_PROGRESS 1
236#define RESP_ERROR 2
237#define RESP_TERMINATED 3
238#define RESP_ENERGY 4
239
240// Neighbor block IDs
241typedef struct {
242    int32_t pos_x, neg_x;
243    int32_t pos_y, neg_y;
244    int32_t pos_z, neg_z;
245    int32_t _padding[2];
246} BlockNeighbors;
247
248// K2K route entry
249typedef struct {
250    BlockNeighbors neighbors;
251    uint32_t block_pos[3];
252    uint32_t _padding;
253    uint32_t cell_offset[3];
254    uint32_t _padding2;
255} K2KRouteEntry;
256
257// Face indices
258#define FACE_POS_X 0
259#define FACE_NEG_X 1
260#define FACE_POS_Y 2
261#define FACE_NEG_Y 3
262#define FACE_POS_Z 4
263#define FACE_NEG_Z 5
264
265"#
266    .to_string()
267}
268
269fn generate_device_functions(config: &PersistentFdtdConfig) -> String {
270    let (tx, ty, tz) = config.tile_size;
271    let face_size = tx * ty;
272
273    let mut code = String::new();
274
275    // Software grid barrier (fallback when cooperative not available)
276    code.push_str(
277        r#"
278// ============================================================================
279// SYNCHRONIZATION FUNCTIONS
280// ============================================================================
281
282// Software grid barrier (atomic counter + generation)
283__device__ void software_grid_sync(
284    volatile uint32_t* barrier_counter,
285    volatile uint32_t* barrier_gen,
286    int num_blocks
287) {
288    __syncthreads();  // First sync within block
289
290    if (threadIdx.x == 0) {
291        unsigned int gen = *barrier_gen;
292
293        unsigned int arrived = atomicAdd((unsigned int*)barrier_counter, 1) + 1;
294        if (arrived == num_blocks) {
295            *barrier_counter = 0;
296            __threadfence();
297            atomicAdd((unsigned int*)barrier_gen, 1);
298        } else {
299            while (atomicAdd((unsigned int*)barrier_gen, 0) == gen) {
300                __threadfence();
301            }
302        }
303    }
304
305    __syncthreads();
306}
307
308"#,
309    );
310
311    // H2K/K2H queue operations
312    code.push_str(
313        r#"
314// ============================================================================
315// MESSAGE QUEUE OPERATIONS
316// ============================================================================
317
318// Copy H2K message from volatile source (field-by-field to handle volatile)
319__device__ void copy_h2k_message(H2KMessage* dst, volatile H2KMessage* src) {
320    dst->cmd = src->cmd;
321    dst->flags = src->flags;
322    dst->cmd_id = src->cmd_id;
323    dst->param1 = src->param1;
324    dst->param2 = src->param2;
325    dst->param3 = src->param3;
326    dst->param4 = src->param4;
327    dst->param5 = src->param5;
328    // Skip reserved fields for performance
329}
330
331// Copy K2H message to volatile destination (field-by-field to handle volatile)
332__device__ void copy_k2h_message(volatile K2HMessage* dst, const K2HMessage* src) {
333    dst->resp_type = src->resp_type;
334    dst->flags = src->flags;
335    dst->cmd_id = src->cmd_id;
336    dst->step = src->step;
337    dst->steps_remaining = src->steps_remaining;
338    dst->energy = src->energy;
339    dst->error_code = src->error_code;
340    // Skip reserved fields for performance
341}
342
343// Try to receive H2K message (returns true if message available)
344__device__ bool h2k_try_recv(
345    volatile SpscQueueHeader* header,
346    volatile H2KMessage* slots,
347    H2KMessage* out_msg
348) {
349    // Fence BEFORE reading to ensure we see host writes
350    __threadfence_system();
351
352    uint64_t head = header->head;
353    uint64_t tail = header->tail;
354
355    if (head == tail) {
356        return false;  // Empty
357    }
358
359    uint32_t slot = tail & header->mask;
360    copy_h2k_message(out_msg, &slots[slot]);
361
362    __threadfence();
363    header->tail = tail + 1;
364
365    return true;
366}
367
368// Send K2H message
369__device__ bool k2h_send(
370    volatile SpscQueueHeader* header,
371    volatile K2HMessage* slots,
372    const K2HMessage* msg
373) {
374    uint64_t head = header->head;
375    uint64_t tail = header->tail;
376    uint32_t capacity = header->capacity;
377
378    if (head - tail >= capacity) {
379        return false;  // Full
380    }
381
382    uint32_t slot = head & header->mask;
383    copy_k2h_message(&slots[slot], msg);
384
385    __threadfence_system();  // Ensure host sees our writes
386    header->head = head + 1;
387
388    return true;
389}
390
391"#,
392    );
393
394    // Halo exchange functions
395    code.push_str(&format!(
396        r#"
397// ============================================================================
398// K2K HALO EXCHANGE
399// ============================================================================
400
401#define TILE_X {tx}
402#define TILE_Y {ty}
403#define TILE_Z {tz}
404#define FACE_SIZE {face_size}
405
406// Pack halo faces from shared memory to device halo buffer
407__device__ void pack_halo_faces(
408    float tile[TILE_Z + 2][TILE_Y + 2][TILE_X + 2],
409    float* halo_buffers,
410    int block_id,
411    int pingpong,
412    int num_blocks
413) {{
414    int tid = threadIdx.x;
415    int face_stride = FACE_SIZE;
416    int block_stride = 6 * face_stride * 2;  // 6 faces, 2 ping-pong
417    int pp_offset = pingpong * 6 * face_stride;
418
419    float* block_halo = halo_buffers + block_id * block_stride + pp_offset;
420
421    if (tid < FACE_SIZE) {{
422        int fx = tid % TILE_X;
423        int fy = tid / TILE_X;
424
425        // +X face (x = TILE_X in local coords, which is index TILE_X)
426        block_halo[FACE_POS_X * face_stride + tid] = tile[fy + 1][fx + 1][TILE_X];
427        // -X face (x = 1 in local coords)
428        block_halo[FACE_NEG_X * face_stride + tid] = tile[fy + 1][fx + 1][1];
429        // +Y face
430        block_halo[FACE_POS_Y * face_stride + tid] = tile[fy + 1][TILE_Y][fx + 1];
431        // -Y face
432        block_halo[FACE_NEG_Y * face_stride + tid] = tile[fy + 1][1][fx + 1];
433        // +Z face
434        block_halo[FACE_POS_Z * face_stride + tid] = tile[TILE_Z][fy + 1][fx + 1];
435        // -Z face
436        block_halo[FACE_NEG_Z * face_stride + tid] = tile[1][fy + 1][fx + 1];
437    }}
438}}
439
440// Unpack halo faces from device buffer to shared memory ghost cells
441__device__ void unpack_halo_faces(
442    float tile[TILE_Z + 2][TILE_Y + 2][TILE_X + 2],
443    const float* halo_buffers,
444    const K2KRouteEntry* route,
445    int pingpong,
446    int num_blocks
447) {{
448    int tid = threadIdx.x;
449    int face_stride = FACE_SIZE;
450    int block_stride = 6 * face_stride * 2;
451    int pp_offset = pingpong * 6 * face_stride;
452
453    if (tid < FACE_SIZE) {{
454        int fx = tid % TILE_X;
455        int fy = tid / TILE_X;
456
457        // My +X ghost comes from neighbor's -X face
458        if (route->neighbors.pos_x >= 0) {{
459            const float* n_halo = halo_buffers + route->neighbors.pos_x * block_stride + pp_offset;
460            tile[fy + 1][fx + 1][TILE_X + 1] = n_halo[FACE_NEG_X * face_stride + tid];
461        }}
462
463        // My -X ghost comes from neighbor's +X face
464        if (route->neighbors.neg_x >= 0) {{
465            const float* n_halo = halo_buffers + route->neighbors.neg_x * block_stride + pp_offset;
466            tile[fy + 1][fx + 1][0] = n_halo[FACE_POS_X * face_stride + tid];
467        }}
468
469        // My +Y ghost
470        if (route->neighbors.pos_y >= 0) {{
471            const float* n_halo = halo_buffers + route->neighbors.pos_y * block_stride + pp_offset;
472            tile[fy + 1][TILE_Y + 1][fx + 1] = n_halo[FACE_NEG_Y * face_stride + tid];
473        }}
474
475        // My -Y ghost
476        if (route->neighbors.neg_y >= 0) {{
477            const float* n_halo = halo_buffers + route->neighbors.neg_y * block_stride + pp_offset;
478            tile[fy + 1][0][fx + 1] = n_halo[FACE_POS_Y * face_stride + tid];
479        }}
480
481        // My +Z ghost
482        if (route->neighbors.pos_z >= 0) {{
483            const float* n_halo = halo_buffers + route->neighbors.pos_z * block_stride + pp_offset;
484            tile[TILE_Z + 1][fy + 1][fx + 1] = n_halo[FACE_NEG_Z * face_stride + tid];
485        }}
486
487        // My -Z ghost
488        if (route->neighbors.neg_z >= 0) {{
489            const float* n_halo = halo_buffers + route->neighbors.neg_z * block_stride + pp_offset;
490            tile[0][fy + 1][fx + 1] = n_halo[FACE_POS_Z * face_stride + tid];
491        }}
492    }}
493}}
494
495"#,
496        tx = tx,
497        ty = ty,
498        tz = tz,
499        face_size = face_size
500    ));
501
502    code
503}
504
505fn generate_main_kernel(config: &PersistentFdtdConfig) -> String {
506    let (tx, ty, tz) = config.tile_size;
507    let threads_per_block = tx * ty * tz;
508    let progress_interval = config.progress_interval;
509
510    // For cooperative groups: declare grid variable once at init, then just call sync
511    // For software sync: no variable needed, just call the function each time
512    let grid_sync = if config.use_cooperative {
513        "grid.sync();"
514    } else {
515        "software_grid_sync(&ctrl->barrier_counter, &ctrl->barrier_generation, num_blocks);"
516    };
517
518    format!(
519        r#"
520// ============================================================================
521// MAIN PERSISTENT KERNEL
522// ============================================================================
523
524extern "C" __global__ void __launch_bounds__({threads_per_block}, 2)
525{name}(
526    PersistentControlBlock* __restrict__ ctrl,
527    float* __restrict__ pressure_a,
528    float* __restrict__ pressure_b,
529    SpscQueueHeader* __restrict__ h2k_header,
530    H2KMessage* __restrict__ h2k_slots,
531    SpscQueueHeader* __restrict__ k2h_header,
532    K2HMessage* __restrict__ k2h_slots,
533    const K2KRouteEntry* __restrict__ routes,
534    float* __restrict__ halo_buffers
535) {{
536    // Block and thread indices
537    int block_id = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y;
538    int num_blocks = gridDim.x * gridDim.y * gridDim.z;
539    int tid = threadIdx.x;
540    bool is_coordinator = (block_id == 0 && tid == 0);
541
542    // Shared memory for tile + ghost cells
543    __shared__ float tile[{tz} + 2][{ty} + 2][{tx} + 2];
544
545    // Get my routing info
546    const K2KRouteEntry* my_route = &routes[block_id];
547
548    // Pointers to pressure buffers (will swap)
549    float* p_curr = pressure_a;
550    float* p_prev = pressure_b;
551
552    // Grid dimensions for indexing
553    int sim_x = ctrl->sim_size[0];
554    int sim_y = ctrl->sim_size[1];
555    int sim_z = ctrl->sim_size[2];
556
557    // Local cell coordinates within tile
558    int lx = tid % {tx};
559    int ly = (tid / {tx}) % {ty};
560    int lz = tid / ({tx} * {ty});
561
562    // Global cell coordinates
563    int gx = my_route->cell_offset[0] + lx;
564    int gy = my_route->cell_offset[1] + ly;
565    int gz = my_route->cell_offset[2] + lz;
566    int global_idx = gx + gy * sim_x + gz * sim_x * sim_y;
567
568    // Physics parameters
569    float c2_dt2 = ctrl->c2_dt2;
570    float damping = ctrl->damping;
571
572    {grid_sync_init}
573
574    // ========== PERSISTENT LOOP ==========
575    while (true) {{
576        // --- Phase 1: Command Processing (coordinator only) ---
577        if (is_coordinator) {{
578            H2KMessage cmd;
579            while (h2k_try_recv(h2k_header, h2k_slots, &cmd)) {{
580                ctrl->messages_processed++;
581
582                switch (cmd.cmd) {{
583                    case CMD_RUN_STEPS:
584                        ctrl->steps_remaining += cmd.param1;
585                        break;
586
587                    case CMD_TERMINATE:
588                        ctrl->should_terminate = 1;
589                        break;
590
591                    case CMD_INJECT_IMPULSE: {{
592                        // Extract position from params
593                        uint32_t ix = (cmd.param1 >> 32) & 0xFFFFFFFF;
594                        uint32_t iy = cmd.param1 & 0xFFFFFFFF;
595                        uint32_t iz = cmd.param2;
596                        float amplitude = cmd.param4;
597
598                        if (ix < sim_x && iy < sim_y && iz < sim_z) {{
599                            int idx = ix + iy * sim_x + iz * sim_x * sim_y;
600                            p_curr[idx] += amplitude;
601                        }}
602                        break;
603                    }}
604
605                    case CMD_GET_PROGRESS: {{
606                        K2HMessage resp;
607                        resp.resp_type = RESP_PROGRESS;
608                        resp.cmd_id = cmd.cmd_id;
609                        resp.step = ctrl->current_step;
610                        resp.steps_remaining = ctrl->steps_remaining;
611                        resp.energy = ctrl->total_energy;
612                        k2h_send(k2h_header, k2h_slots, &resp);
613                        break;
614                    }}
615
616                    default:
617                        break;
618                }}
619            }}
620        }}
621
622        // Grid-wide sync so all blocks see updated control state
623        {grid_sync}
624
625        // Check for termination
626        if (ctrl->should_terminate) {{
627            break;
628        }}
629
630        // --- Phase 2: Check if we have work ---
631        if (ctrl->steps_remaining == 0) {{
632            // No work - brief spinwait then check again
633            // Use volatile counter to prevent optimization
634            volatile int spin_count = 0;
635            for (int i = 0; i < 1000; i++) {{
636                spin_count++;
637            }}
638            {grid_sync}
639            continue;
640        }}
641
642        // --- Phase 3: Load tile from global memory ---
643        // Load interior cells
644        if (gx < sim_x && gy < sim_y && gz < sim_z) {{
645            tile[lz + 1][ly + 1][lx + 1] = p_curr[global_idx];
646        }} else {{
647            tile[lz + 1][ly + 1][lx + 1] = 0.0f;
648        }}
649
650        // Initialize ghost cells to boundary (zero pressure)
651        if (lx == 0) tile[lz + 1][ly + 1][0] = 0.0f;
652        if (lx == {tx} - 1) tile[lz + 1][ly + 1][{tx} + 1] = 0.0f;
653        if (ly == 0) tile[lz + 1][0][lx + 1] = 0.0f;
654        if (ly == {ty} - 1) tile[lz + 1][{ty} + 1][lx + 1] = 0.0f;
655        if (lz == 0) tile[0][ly + 1][lx + 1] = 0.0f;
656        if (lz == {tz} - 1) tile[{tz} + 1][ly + 1][lx + 1] = 0.0f;
657
658        __syncthreads();
659
660        // --- Phase 4: K2K Halo Exchange ---
661        int pingpong = ctrl->current_step & 1;
662
663        // Pack my boundary faces into halo buffers
664        pack_halo_faces(tile, halo_buffers, block_id, pingpong, num_blocks);
665        __threadfence();  // Ensure writes visible to other blocks
666
667        // Grid-wide sync - wait for all blocks to finish packing
668        {grid_sync}
669
670        // Unpack neighbor faces into my ghost cells
671        unpack_halo_faces(tile, halo_buffers, my_route, pingpong, num_blocks);
672        __syncthreads();
673
674        // --- Phase 5: FDTD Computation ---
675        if (gx < sim_x && gy < sim_y && gz < sim_z) {{
676            // 7-point Laplacian from shared memory
677            float center = tile[lz + 1][ly + 1][lx + 1];
678            float lap = tile[lz + 1][ly + 1][lx + 2]   // +X
679                      + tile[lz + 1][ly + 1][lx]       // -X
680                      + tile[lz + 1][ly + 2][lx + 1]   // +Y
681                      + tile[lz + 1][ly][lx + 1]       // -Y
682                      + tile[lz + 2][ly + 1][lx + 1]   // +Z
683                      + tile[lz][ly + 1][lx + 1]       // -Z
684                      - 6.0f * center;
685
686            // FDTD update: p_new = 2*p - p_prev + c^2*dt^2*lap
687            float p_prev_val = p_prev[global_idx];
688            float p_new = 2.0f * center - p_prev_val + c2_dt2 * lap;
689            p_new *= damping;
690
691            // Write to "previous" buffer (will become current after swap)
692            p_prev[global_idx] = p_new;
693        }}
694
695        // Grid-wide sync before buffer swap
696        {grid_sync}
697
698        // --- Phase 6: Buffer Swap & Progress ---
699        if (is_coordinator) {{
700            // Swap buffer pointers (toggle index)
701            ctrl->current_buffer ^= 1;
702
703            // Update counters
704            ctrl->current_step++;
705            ctrl->steps_remaining--;
706
707            // Send progress update periodically
708            if (ctrl->current_step % {progress_interval} == 0) {{
709                K2HMessage resp;
710                resp.resp_type = RESP_PROGRESS;
711                resp.cmd_id = 0;
712                resp.step = ctrl->current_step;
713                resp.steps_remaining = ctrl->steps_remaining;
714                resp.energy = 0.0f;  // TODO: calculate energy
715                k2h_send(k2h_header, k2h_slots, &resp);
716            }}
717        }}
718
719        // Swap our local pointers
720        float* tmp = p_curr;
721        p_curr = p_prev;
722        p_prev = tmp;
723
724        {grid_sync}
725    }}
726
727    // ========== CLEANUP ==========
728    if (is_coordinator) {{
729        ctrl->has_terminated = 1;
730
731        K2HMessage resp;
732        resp.resp_type = RESP_TERMINATED;
733        resp.cmd_id = 0;
734        resp.step = ctrl->current_step;
735        resp.steps_remaining = 0;
736        resp.energy = 0.0f;
737        k2h_send(k2h_header, k2h_slots, &resp);
738    }}
739}}
740"#,
741        name = config.name,
742        tx = tx,
743        ty = ty,
744        tz = tz,
745        threads_per_block = threads_per_block,
746        progress_interval = progress_interval,
747        grid_sync_init = if config.use_cooperative {
748            "cg::grid_group grid = cg::this_grid();"
749        } else {
750            ""
751        },
752        grid_sync = grid_sync,
753    )
754}
755
756/// Generate PTX for the persistent FDTD kernel using nvcc.
757///
758/// This requires nvcc to be installed.
759#[cfg(feature = "nvcc")]
760pub fn compile_persistent_fdtd_to_ptx(config: &PersistentFdtdConfig) -> Result<String, String> {
761    let cuda_code = generate_persistent_fdtd_kernel(config);
762
763    // Use nvcc to compile to PTX
764    use std::io::Write;
765    use std::process::Command;
766
767    // Write to temp file
768    let temp_dir = std::env::temp_dir();
769    let cuda_file = temp_dir.join("persistent_fdtd.cu");
770    let ptx_file = temp_dir.join("persistent_fdtd.ptx");
771
772    std::fs::write(&cuda_file, &cuda_code)
773        .map_err(|e| format!("Failed to write CUDA file: {}", e))?;
774
775    // Compile with nvcc
776    // Use -arch=native to automatically detect the GPU architecture
777    // This ensures compatibility with newer CUDA versions that dropped older sm_* support
778    let mut args = vec![
779        "-ptx".to_string(),
780        "-o".to_string(),
781        ptx_file.to_string_lossy().to_string(),
782        cuda_file.to_string_lossy().to_string(),
783        "-arch=native".to_string(),
784        "-std=c++17".to_string(),
785    ];
786
787    if config.use_cooperative {
788        args.push("-rdc=true".to_string()); // Required for cooperative groups
789    }
790
791    let output = Command::new("nvcc")
792        .args(&args)
793        .output()
794        .map_err(|e| format!("Failed to run nvcc: {}", e))?;
795
796    if !output.status.success() {
797        return Err(format!(
798            "nvcc compilation failed:\n{}",
799            String::from_utf8_lossy(&output.stderr)
800        ));
801    }
802
803    std::fs::read_to_string(&ptx_file).map_err(|e| format!("Failed to read PTX: {}", e))
804}
805
806#[cfg(test)]
807mod tests {
808    use super::*;
809
810    #[test]
811    fn test_default_config() {
812        let config = PersistentFdtdConfig::default();
813        assert_eq!(config.name, "persistent_fdtd3d");
814        assert_eq!(config.tile_size, (8, 8, 8));
815        assert_eq!(config.threads_per_block(), 512);
816    }
817
818    #[test]
819    fn test_config_builder() {
820        let config = PersistentFdtdConfig::new("my_kernel")
821            .with_tile_size(4, 4, 4)
822            .with_cooperative(false)
823            .with_progress_interval(50);
824
825        assert_eq!(config.name, "my_kernel");
826        assert_eq!(config.tile_size, (4, 4, 4));
827        assert_eq!(config.threads_per_block(), 64);
828        assert!(!config.use_cooperative);
829        assert_eq!(config.progress_interval, 50);
830    }
831
832    #[test]
833    fn test_shared_mem_calculation() {
834        let config = PersistentFdtdConfig::default(); // 8x8x8
835                                                      // With halo: 10x10x10 = 1000 floats = 4000 bytes
836        assert_eq!(config.shared_mem_size(), 4000);
837    }
838
839    #[test]
840    fn test_generate_kernel_cooperative() {
841        let config = PersistentFdtdConfig::new("test_kernel").with_cooperative(true);
842
843        let code = generate_persistent_fdtd_kernel(&config);
844
845        // Check includes
846        assert!(code.contains("#include <cooperative_groups.h>"));
847        assert!(code.contains("namespace cg = cooperative_groups;"));
848
849        // Check structures
850        assert!(code.contains("typedef struct __align__(256)"));
851        assert!(code.contains("PersistentControlBlock"));
852        assert!(code.contains("SpscQueueHeader"));
853        assert!(code.contains("H2KMessage"));
854        assert!(code.contains("K2HMessage"));
855
856        // Check device functions
857        assert!(code.contains("__device__ bool h2k_try_recv"));
858        assert!(code.contains("__device__ bool k2h_send"));
859        assert!(code.contains("__device__ void pack_halo_faces"));
860        assert!(code.contains("__device__ void unpack_halo_faces"));
861
862        // Check main kernel
863        assert!(code.contains("extern \"C\" __global__ void"));
864        assert!(code.contains("test_kernel"));
865        assert!(code.contains("cg::grid_group grid = cg::this_grid()"));
866        assert!(code.contains("grid.sync()"));
867
868        // Check FDTD computation
869        assert!(code.contains("7-point Laplacian"));
870        assert!(code.contains("c2_dt2 * lap"));
871    }
872
873    #[test]
874    fn test_generate_kernel_software_sync() {
875        let config = PersistentFdtdConfig::new("test_kernel").with_cooperative(false);
876
877        let code = generate_persistent_fdtd_kernel(&config);
878
879        // Should NOT have cooperative groups
880        assert!(!code.contains("#include <cooperative_groups.h>"));
881
882        // Should have software barrier
883        assert!(code.contains("software_grid_sync"));
884        assert!(code.contains("atomicAdd"));
885    }
886
887    #[test]
888    fn test_generate_kernel_command_handling() {
889        let config = PersistentFdtdConfig::default();
890        let code = generate_persistent_fdtd_kernel(&config);
891
892        // Check command handling
893        assert!(code.contains("CMD_RUN_STEPS"));
894        assert!(code.contains("CMD_TERMINATE"));
895        assert!(code.contains("CMD_INJECT_IMPULSE"));
896        assert!(code.contains("CMD_GET_PROGRESS"));
897
898        // Check response handling
899        assert!(code.contains("RESP_PROGRESS"));
900        assert!(code.contains("RESP_TERMINATED"));
901    }
902
903    #[test]
904    fn test_generate_kernel_halo_exchange() {
905        let config = PersistentFdtdConfig::new("test").with_tile_size(8, 8, 8);
906
907        let code = generate_persistent_fdtd_kernel(&config);
908
909        // Check halo defines
910        assert!(code.contains("#define TILE_X 8"));
911        assert!(code.contains("#define TILE_Y 8"));
912        assert!(code.contains("#define TILE_Z 8"));
913        assert!(code.contains("#define FACE_SIZE 64"));
914
915        // Check face indices
916        assert!(code.contains("FACE_POS_X"));
917        assert!(code.contains("FACE_NEG_Z"));
918
919        // Check K2K operations
920        assert!(code.contains("pack_halo_faces"));
921        assert!(code.contains("unpack_halo_faces"));
922    }
923
924    #[test]
925    fn test_kernel_contains_persistent_loop() {
926        let config = PersistentFdtdConfig::default();
927        let code = generate_persistent_fdtd_kernel(&config);
928
929        // Must have persistent loop structure
930        assert!(code.contains("while (true)"));
931        assert!(code.contains("if (ctrl->should_terminate)"));
932        assert!(code.contains("break;"));
933
934        // Must handle no-work case
935        assert!(code.contains("if (ctrl->steps_remaining == 0)"));
936        assert!(code.contains("volatile int spin_count"));
937    }
938}