Skip to main content

ringkernel_wavesim/simulation/
kernels.rs

1//! CUDA kernel definitions using the Rust DSL.
2//!
3//! This module contains all CUDA kernels for the wave simulation, defined in
4//! a Rust DSL that gets transpiled to CUDA C at compile time.
5//!
6//! The generated CUDA code is designed to match the handwritten versions in
7//! `shaders/fdtd_tile.cu` and `shaders/fdtd_packed.cu` exactly.
8
9#[cfg(feature = "cuda-codegen")]
10use ringkernel_cuda_codegen::{
11    transpile_global_kernel, transpile_ring_kernel, transpile_stencil_kernel, Grid,
12    RingKernelConfig, StencilConfig,
13};
14
15// ============================================================================
16// Tile-Based Kernels (fdtd_tile.cu equivalent)
17// ============================================================================
18
/// Assemble the full CUDA translation unit for the tile-based kernels.
///
/// The output is `TILE_KERNELS_HEADER`, followed by an `extern "C"` block that
/// contains, in this order and separated by a single newline:
/// - `fdtd_tile_step`
/// - `extract_halo`
/// - `inject_halo`
/// - `read_interior`
/// - `apply_boundary_reflection`
///
/// This is intended to mirror the handwritten `shaders/fdtd_tile.cu`.
#[cfg(feature = "cuda-codegen")]
pub fn generate_tile_kernels() -> String {
    // Kernel sources in emission order; each generator returns either CUDA
    // text or a `// Transpilation error: ...` comment line.
    let kernels = [
        generate_fdtd_tile_step(),
        generate_extract_halo(),
        generate_inject_halo(),
        generate_read_interior(),
        generate_apply_boundary_reflection(),
    ];

    let mut source = String::from(TILE_KERNELS_HEADER);
    source.push_str("\nextern \"C\" {\n\n");
    source.push_str(&kernels.join("\n"));
    source.push_str("\n}  // extern \"C\"\n");
    source
}
57
/// Comment banner placed at the top of the generated tile-kernel source.
///
/// It documents the 18x18 tile buffer layout (16x16 interior plus halo rows
/// and columns) and the index formulas used by the kernels.
pub const TILE_KERNELS_HEADER: &str = concat!(
    "// CUDA Kernels for Tile-Based FDTD Wave Simulation\n",
    "// Generated by ringkernel-cuda-codegen from Rust DSL\n",
    "//\n",
    "// Buffer Layout (18x18 = 324 floats):\n",
    "//   +---+----------------+---+\n",
    "//   | NW|   North Halo   |NE |  <- Row 0\n",
    "//   +---+----------------+---+\n",
    "//   |   |   16x16 Tile   |   |  <- Rows 1-16\n",
    "//   | W |    Interior    | E |\n",
    "//   +---+----------------+---+\n",
    "//   | SW|   South Halo   |SE |  <- Row 17\n",
    "//   +---+----------------+---+\n",
    "//\n",
    "// Index: idx = y * 18 + x\n",
    "// Interior cell (lx, ly): idx = (ly + 1) * 18 + (lx + 1)\n",
);
75
/// Build the CUDA source for the `fdtd_tile_step` stencil kernel.
///
/// The kernel body is written in the Rust DSL and passed to
/// `transpile_stencil_kernel` with a 2D-grid stencil configuration
/// (16x16 tile, halo of 1). On success the transpiled CUDA text is returned;
/// on failure the result is a single `// Transpilation error: ...` comment
/// line carrying the error message.
#[cfg(feature = "cuda-codegen")]
fn generate_fdtd_tile_step() -> String {
    use syn::parse_quote;

    let stencil = StencilConfig::new("fdtd_tile_step")
        .with_grid(Grid::Grid2D)
        .with_tile_size(16, 16)
        .with_halo(1);

    // NOTE: the tokens inside `parse_quote!` are the transpiler's input and
    // determine the generated CUDA; only plain `//` comments may be edited.
    let step_fn: syn::ItemFn = parse_quote! {
        fn fdtd_tile_step(
            pressure: &[f32],
            pressure_prev: &mut [f32],
            c2: f32,
            damping: f32,
            pos: GridPos,
        ) {
            // Current and previous pressure at this cell.
            let p = pressure[pos.idx()];
            let p_prev = pressure_prev[pos.idx()];

            // Four direct neighbours of the current pressure field.
            let p_n = pos.north(pressure);
            let p_s = pos.south(pressure);
            let p_w = pos.west(pressure);
            let p_e = pos.east(pressure);

            // 5-point Laplacian followed by the time-stepping update.
            let laplacian = p_n + p_s + p_e + p_w - 4.0 * p;
            let p_new = 2.0 * p - p_prev + c2 * laplacian;

            // The damped result overwrites the previous-step buffer.
            pressure_prev[pos.idx()] = p_new * damping;
        }
    };

    transpile_stencil_kernel(&step_fn, &stencil)
        .unwrap_or_else(|err| format!("// Transpilation error: {}\n", err))
}
114
/// Build the CUDA source for the `extract_halo` kernel.
///
/// The kernel copies 16 cells from one interior edge of the 18-wide tile
/// buffer into a linear `halo_out` buffer. `edge` selects the source:
/// 0 = row 1, 1 = row 16, 2 = column 1, anything else = column 16.
///
/// Returns the transpiled CUDA prefixed with a two-line description comment,
/// or a `// Transpilation error: ...` comment line if transpilation fails.
#[cfg(feature = "cuda-codegen")]
fn generate_extract_halo() -> String {
    use syn::parse_quote;

    // NOTE: the tokens inside `parse_quote!` are the transpiler's input and
    // determine the generated CUDA; only plain `//` comments may be edited.
    let halo_fn: syn::ItemFn = parse_quote! {
        fn extract_halo(
            pressure: &[f32],
            halo_out: &mut [f32],
            edge: i32,
        ) {
            // One thread per edge cell; surplus threads exit.
            let i = thread_idx_x();
            if i >= 16 {
                return;
            }

            let idx = match edge {
                0 => 1 * 18 + (i + 1),       // North: interior row 1
                1 => 16 * 18 + (i + 1),      // South: interior row 16
                2 => (i + 1) * 18 + 1,       // West: interior col 1
                _ => (i + 1) * 18 + 16,      // East: interior col 16
            };

            halo_out[i as usize] = pressure[idx as usize];
        }
    };

    transpile_global_kernel(&halo_fn)
        .map(|cuda| format!("// Extract Halo - extracts halo data from interior edge\n// edge: 0=North, 1=South, 2=West, 3=East\n{}", cuda))
        .unwrap_or_else(|err| format!("// Transpilation error: {}\n", err))
}
147
/// Build the CUDA source for the `inject_halo` kernel.
///
/// The kernel writes 16 values from the linear `halo_in` buffer into one halo
/// edge of the 18-wide tile buffer. `edge` selects the destination:
/// 0 = row 0, 1 = row 17, 2 = column 0, anything else = column 17.
///
/// Returns the transpiled CUDA prefixed with a two-line description comment,
/// or a `// Transpilation error: ...` comment line if transpilation fails.
#[cfg(feature = "cuda-codegen")]
fn generate_inject_halo() -> String {
    use syn::parse_quote;

    // NOTE: the tokens inside `parse_quote!` are the transpiler's input and
    // determine the generated CUDA; only plain `//` comments may be edited.
    let halo_fn: syn::ItemFn = parse_quote! {
        fn inject_halo(
            pressure: &mut [f32],
            halo_in: &[f32],
            edge: i32,
        ) {
            // One thread per edge cell; surplus threads exit.
            let i = thread_idx_x();
            if i >= 16 {
                return;
            }

            let idx = match edge {
                0 => 0 * 18 + (i + 1),       // North: halo row 0
                1 => 17 * 18 + (i + 1),      // South: halo row 17
                2 => (i + 1) * 18 + 0,       // West: halo col 0
                _ => (i + 1) * 18 + 17,      // East: halo col 17
            };

            pressure[idx as usize] = halo_in[i as usize];
        }
    };

    transpile_global_kernel(&halo_fn)
        .map(|cuda| format!("// Inject Halo - injects halo data from linear buffer to halo region\n// edge: 0=North, 1=South, 2=West, 3=East\n{}", cuda))
        .unwrap_or_else(|err| format!("// Transpilation error: {}\n", err))
}
180
/// Build the CUDA source for the `read_interior` kernel.
///
/// The kernel copies the 16x16 interior of the 18-wide tile buffer into a
/// densely packed 16-wide `output` buffer, one thread per cell.
///
/// Returns the transpiled CUDA prefixed with a description comment, or a
/// `// Transpilation error: ...` comment line if transpilation fails.
#[cfg(feature = "cuda-codegen")]
fn generate_read_interior() -> String {
    use syn::parse_quote;

    // NOTE: the tokens inside `parse_quote!` are the transpiler's input and
    // determine the generated CUDA; only plain `//` comments may be edited.
    let interior_fn: syn::ItemFn = parse_quote! {
        fn read_interior(
            pressure: &[f32],
            output: &mut [f32],
        ) {
            let lx = thread_idx_x();
            let ly = thread_idx_y();

            // Threads outside the 16x16 interior do nothing.
            if lx >= 16 || ly >= 16 {
                return;
            }

            // Source skips the halo (offset by one row and one column);
            // destination is the packed 16-wide layout.
            let src_idx = (ly + 1) * 18 + (lx + 1);
            let dst_idx = ly * 16 + lx;

            output[dst_idx as usize] = pressure[src_idx as usize];
        }
    };

    transpile_global_kernel(&interior_fn)
        .map(|cuda| {
            format!(
                "// Read Interior - reads interior pressure to linear buffer for visualization\n{}",
                cuda
            )
        })
        .unwrap_or_else(|err| format!("// Transpilation error: {}\n", err))
}
213
/// Build the CUDA source for the `apply_boundary_reflection` kernel.
///
/// For the selected `edge`, the kernel reads the 16 interior cells adjacent to
/// that edge, multiplies each by `reflection_coeff`, and stores the result in
/// the matching halo cell of the same 18-wide tile buffer.
///
/// Returns the transpiled CUDA prefixed with a two-line description comment,
/// or a `// Transpilation error: ...` comment line if transpilation fails.
#[cfg(feature = "cuda-codegen")]
fn generate_apply_boundary_reflection() -> String {
    use syn::parse_quote;

    // NOTE: the tokens inside `parse_quote!` are the transpiler's input and
    // determine the generated CUDA; only plain `//` comments may be edited.
    let reflect_fn: syn::ItemFn = parse_quote! {
        fn apply_boundary_reflection(
            pressure: &mut [f32],
            edge: i32,
            reflection_coeff: f32,
        ) {
            // One thread per edge cell; surplus threads exit.
            let i = thread_idx_x();
            if i >= 16 {
                return;
            }

            // Interior cell next to the chosen edge.
            let src_idx = match edge {
                0 => 1 * 18 + (i + 1),       // North: row 1
                1 => 16 * 18 + (i + 1),      // South: row 16
                2 => (i + 1) * 18 + 1,       // West: col 1
                _ => (i + 1) * 18 + 16,      // East: col 16
            };

            // Halo cell on the chosen edge.
            let dst_idx = match edge {
                0 => 0 * 18 + (i + 1),       // North: row 0
                1 => 17 * 18 + (i + 1),      // South: row 17
                2 => (i + 1) * 18 + 0,       // West: col 0
                _ => (i + 1) * 18 + 17,      // East: col 17
            };

            pressure[dst_idx as usize] = pressure[src_idx as usize] * reflection_coeff;
        }
    };

    transpile_global_kernel(&reflect_fn)
        .map(|cuda| format!("// Apply Boundary Reflection - applies boundary conditions for tiles at grid edges\n// edge: 0=North, 1=South, 2=West, 3=East\n{}", cuda))
        .unwrap_or_else(|err| format!("// Transpilation error: {}\n", err))
}
254
255// ============================================================================
256// Packed Tile Kernels (fdtd_packed.cu equivalent)
257// ============================================================================
258
/// Assemble the full CUDA translation unit for the packed-tile kernels.
///
/// The output is `PACKED_KERNELS_HEADER`, followed by an `extern "C"` block
/// that contains, in this order and separated by a single newline:
/// - `exchange_all_halos`
/// - `fdtd_all_tiles`
/// - `upload_tile_data`
/// - `read_all_interiors`
/// - `inject_impulse`
/// - `apply_boundary_conditions`
///
/// This is intended to mirror the handwritten `shaders/fdtd_packed.cu`.
#[cfg(feature = "cuda-codegen")]
pub fn generate_packed_kernels() -> String {
    // Kernel sources in emission order; each generator returns either CUDA
    // text or a `// Transpilation error: ...` comment line.
    let kernels = [
        generate_exchange_all_halos(),
        generate_fdtd_all_tiles(),
        generate_upload_tile_data(),
        generate_read_all_interiors(),
        generate_inject_impulse(),
        generate_apply_boundary_conditions(),
    ];

    let mut source = String::from(PACKED_KERNELS_HEADER);
    source.push_str("\nextern \"C\" {\n\n");
    source.push_str(&kernels.join("\n"));
    source.push_str("\n}  // extern \"C\"\n");
    source
}
296
/// Comment banner placed at the top of the generated packed-kernel source.
///
/// It describes the packed layout (all 18x18 tiles stored contiguously) and
/// the stated benefits of that layout.
pub const PACKED_KERNELS_HEADER: &str = concat!(
    "// CUDA Kernels for Packed Tile-Based FDTD Wave Simulation\n",
    "// Generated by ringkernel-cuda-codegen from Rust DSL\n",
    "//\n",
    "// All tiles packed contiguously: [Tile(0,0)][Tile(1,0)]...[Tile(n,m)]\n",
    "// Each tile is 18x18 floats (16x16 interior + 1-cell halo)\n",
    "//\n",
    "// Benefits:\n",
    "// - Zero host<->GPU transfers during simulation\n",
    "// - All tiles computed in parallel\n",
    "// - Halo exchange is just GPU memory copies\n",
);
309
/// Build the CUDA source for the `exchange_all_halos` kernel.
///
/// Each thread handles one entry of `copies`, which is read as consecutive
/// `(source index, destination index)` pairs, and performs a single
/// element copy within `packed_buffer`. Threads with a global index at or
/// beyond `num_copies` exit immediately.
///
/// Returns the transpiled CUDA prefixed with a description comment, or a
/// `// Transpilation error: ...` comment line if transpilation fails.
#[cfg(feature = "cuda-codegen")]
fn generate_exchange_all_halos() -> String {
    use syn::parse_quote;

    // NOTE: the tokens inside `parse_quote!` are the transpiler's input and
    // determine the generated CUDA; only plain `//` comments may be edited.
    let exchange_fn: syn::ItemFn = parse_quote! {
        fn exchange_all_halos(
            packed_buffer: &mut [f32],
            copies: &[u32],
            num_copies: i32,
        ) {
            // Flat global thread index over a 1D launch.
            let idx = block_idx_x() * block_dim_x() + thread_idx_x();
            if idx >= num_copies {
                return;
            }

            // Pair layout: [src0, dst0, src1, dst1, ...].
            let src_idx = copies[(idx * 2) as usize];
            let dst_idx = copies[(idx * 2 + 1) as usize];

            packed_buffer[dst_idx as usize] = packed_buffer[src_idx as usize];
        }
    };

    transpile_global_kernel(&exchange_fn)
        .map(|cuda| {
            format!(
                "// Halo Exchange Kernel - copies all halo edges between adjacent tiles\n{}",
                cuda
            )
        })
        .unwrap_or_else(|err| format!("// Transpilation error: {}\n", err))
}
341
/// Build the CUDA source for the `fdtd_all_tiles` kernel.
///
/// The kernel uses the block index to pick a tile and the thread index to
/// pick a cell inside it, then applies the same update as the single-tile
/// step: a 5-point Laplacian on `packed_curr`, combined with the previous
/// value from `packed_prev`, scaled by `damping`, and written back into
/// `packed_prev`.
///
/// Returns the transpiled CUDA prefixed with a description comment, or a
/// `// Transpilation error: ...` comment line if transpilation fails.
#[cfg(feature = "cuda-codegen")]
fn generate_fdtd_all_tiles() -> String {
    use syn::parse_quote;

    // NOTE: the tokens inside `parse_quote!` are the transpiler's input and
    // determine the generated CUDA; only plain `//` comments may be edited.
    let batched_fn: syn::ItemFn = parse_quote! {
        fn fdtd_all_tiles(
            packed_curr: &[f32],
            packed_prev: &mut [f32],
            tiles_x: i32,
            tiles_y: i32,
            tile_size: i32,
            buffer_width: i32,
            c2: f32,
            damping: f32,
        ) {
            // Block selects the tile, thread selects the cell within it.
            let tile_x = block_idx_x();
            let tile_y = block_idx_y();
            let lx = thread_idx_x();
            let ly = thread_idx_y();

            // Guard against blocks/threads outside the tile grid or tile interior.
            if tile_x >= tiles_x || tile_y >= tiles_y {
                return;
            }
            if lx >= tile_size || ly >= tile_size {
                return;
            }

            // Start of this tile inside the packed buffer.
            let tile_buffer_size = buffer_width * buffer_width;
            let tile_idx = tile_y * tiles_x + tile_x;
            let tile_offset = tile_idx * tile_buffer_size;

            // Interior cell index (shifted by one row/column past the halo).
            let idx = tile_offset + (ly + 1) * buffer_width + (lx + 1);

            let p = packed_curr[idx as usize];
            let p_prev_val = packed_prev[idx as usize];

            // Neighbours one row up/down and one column left/right.
            let p_n = packed_curr[(idx - buffer_width) as usize];
            let p_s = packed_curr[(idx + buffer_width) as usize];
            let p_w = packed_curr[(idx - 1) as usize];
            let p_e = packed_curr[(idx + 1) as usize];

            let laplacian = p_n + p_s + p_e + p_w - 4.0 * p;
            let p_new = 2.0 * p - p_prev_val + c2 * laplacian;

            // The damped result overwrites the previous-step buffer.
            packed_prev[idx as usize] = p_new * damping;
        }
    };

    transpile_global_kernel(&batched_fn)
        .map(|cuda| {
            format!(
                "// Batched FDTD Kernel - computes FDTD for ALL tiles in single launch\n{}",
                cuda
            )
        })
        .unwrap_or_else(|err| format!("// Transpilation error: {}\n", err))
}
399
/// Build the CUDA source for the `upload_tile_data` kernel.
///
/// The kernel copies a full `buffer_width` x `buffer_width` tile (halo cells
/// included, since the guard and indexing use `buffer_width`) from `staging`
/// into the slot of tile `(tile_x, tile_y)` inside `packed_buffer`.
///
/// Returns the transpiled CUDA prefixed with a description comment, or a
/// `// Transpilation error: ...` comment line if transpilation fails.
#[cfg(feature = "cuda-codegen")]
fn generate_upload_tile_data() -> String {
    use syn::parse_quote;

    // NOTE: the tokens inside `parse_quote!` are the transpiler's input and
    // determine the generated CUDA; only plain `//` comments may be edited.
    let upload_fn: syn::ItemFn = parse_quote! {
        fn upload_tile_data(
            packed_buffer: &mut [f32],
            staging: &[f32],
            tile_x: i32,
            tile_y: i32,
            tiles_x: i32,
            buffer_width: i32,
        ) {
            let lx = thread_idx_x();
            let ly = thread_idx_y();

            // Threads outside the tile buffer do nothing.
            if lx >= buffer_width || ly >= buffer_width {
                return;
            }

            // Start of the destination tile inside the packed buffer.
            let tile_buffer_size = buffer_width * buffer_width;
            let tile_idx = tile_y * tiles_x + tile_x;
            let tile_offset = tile_idx * tile_buffer_size;

            // Same row-major position in staging and in the tile slot.
            let local_idx = ly * buffer_width + lx;
            let global_idx = tile_offset + local_idx;

            packed_buffer[global_idx as usize] = staging[local_idx as usize];
        }
    };

    transpile_global_kernel(&upload_fn)
        .map(|cuda| {
            format!(
                "// Upload Initial State - copies initial data to packed buffer\n{}",
                cuda
            )
        })
        .unwrap_or_else(|err| format!("// Transpilation error: {}\n", err))
}
440
/// Build the CUDA source for the `read_all_interiors` kernel.
///
/// Each thread maps to one global grid cell `(gx, gy)`, locates the tile and
/// in-tile position by dividing/taking the remainder with `tile_size`, and
/// copies that interior cell from `packed_buffer` into the row-major
/// `grid_width`-wide `output`. The `tiles_y` parameter is accepted but not
/// read by the kernel body.
///
/// Returns the transpiled CUDA prefixed with a description comment, or a
/// `// Transpilation error: ...` comment line if transpilation fails.
#[cfg(feature = "cuda-codegen")]
fn generate_read_all_interiors() -> String {
    use syn::parse_quote;

    // NOTE: the tokens inside `parse_quote!` are the transpiler's input and
    // determine the generated CUDA; only plain `//` comments may be edited.
    let readback_fn: syn::ItemFn = parse_quote! {
        fn read_all_interiors(
            packed_buffer: &[f32],
            output: &mut [f32],
            tiles_x: i32,
            tiles_y: i32,
            tile_size: i32,
            buffer_width: i32,
            grid_width: i32,
            grid_height: i32,
        ) {
            // Global cell coordinates over a 2D launch.
            let gx = block_idx_x() * block_dim_x() + thread_idx_x();
            let gy = block_idx_y() * block_dim_y() + thread_idx_y();

            if gx >= grid_width || gy >= grid_height {
                return;
            }

            // Which tile the cell belongs to...
            let tile_x = gx / tile_size;
            let tile_y = gy / tile_size;

            // ...and where it sits inside that tile.
            let lx = gx % tile_size;
            let ly = gy % tile_size;

            // Source index skips the halo row and column of the tile.
            let tile_buffer_size = buffer_width * buffer_width;
            let tile_idx = tile_y * tiles_x + tile_x;
            let tile_offset = tile_idx * tile_buffer_size;
            let src_idx = tile_offset + (ly + 1) * buffer_width + (lx + 1);

            let dst_idx = gy * grid_width + gx;

            output[dst_idx as usize] = packed_buffer[src_idx as usize];
        }
    };

    transpile_global_kernel(&readback_fn)
        .map(|cuda| {
            format!(
                "// Read All Interiors - extracts all tile interiors for visualization\n{}",
                cuda
            )
        })
        .unwrap_or_else(|err| format!("// Transpilation error: {}\n", err))
}
489
/// Build the CUDA source for the `inject_impulse` kernel.
///
/// The kernel adds `amplitude` to one interior cell, addressed by tile
/// coordinates `(tile_x, tile_y)` and in-tile coordinates
/// `(local_x, local_y)`. The body contains no thread-index guard.
///
/// Returns the transpiled CUDA prefixed with a description comment, or a
/// `// Transpilation error: ...` comment line if transpilation fails.
#[cfg(feature = "cuda-codegen")]
fn generate_inject_impulse() -> String {
    use syn::parse_quote;

    // NOTE: the tokens inside `parse_quote!` are the transpiler's input and
    // determine the generated CUDA; only plain `//` comments may be edited.
    let impulse_fn: syn::ItemFn = parse_quote! {
        fn inject_impulse(
            packed_buffer: &mut [f32],
            tile_x: i32,
            tile_y: i32,
            local_x: i32,
            local_y: i32,
            tiles_x: i32,
            buffer_width: i32,
            amplitude: f32,
        ) {
            // Locate the target tile, then the interior cell past the halo.
            let tile_buffer_size = buffer_width * buffer_width;
            let tile_idx = tile_y * tiles_x + tile_x;
            let tile_offset = tile_idx * tile_buffer_size;
            let idx = tile_offset + (local_y + 1) * buffer_width + (local_x + 1);

            packed_buffer[idx as usize] = packed_buffer[idx as usize] + amplitude;
        }
    };

    transpile_global_kernel(&impulse_fn)
        .map(|cuda| format!("// Inject Impulse - adds energy to specific cell\n{}", cuda))
        .unwrap_or_else(|err| format!("// Transpilation error: {}\n", err))
}
520
/// Build the CUDA source for the `apply_boundary_conditions` kernel.
///
/// The block index selects a domain edge (0 = north, 1 = south, 2 = west,
/// 3 = east; other values do nothing). The thread index is split with
/// `tile_size` into a tile coordinate along that edge and a cell within the
/// tile. Each thread then copies the interior cell next to the edge into the
/// adjacent halo cell, multiplied by `reflection_coeff`.
///
/// Returns the transpiled CUDA prefixed with a description comment, or a
/// `// Transpilation error: ...` comment line if transpilation fails.
#[cfg(feature = "cuda-codegen")]
fn generate_apply_boundary_conditions() -> String {
    use syn::parse_quote;

    // NOTE: the tokens inside `parse_quote!` are the transpiler's input and
    // determine the generated CUDA; only plain `//` comments may be edited.
    let boundary_fn: syn::ItemFn = parse_quote! {
        fn apply_boundary_conditions(
            packed_buffer: &mut [f32],
            tiles_x: i32,
            tiles_y: i32,
            tile_size: i32,
            buffer_width: i32,
            reflection_coeff: f32,
        ) {
            let edge = block_idx_x();
            let idx = thread_idx_x();

            let tile_buffer_size = buffer_width * buffer_width;

            if edge == 0 {
                // North edge of the domain: top row of tiles (tile_y == 0).
                let tile_x = idx / tile_size;
                let cell_x = idx % tile_size;
                if tile_x >= tiles_x {
                    return;
                }

                // Row 1 (interior) -> row 0 (halo).
                let tile_idx = 0 * tiles_x + tile_x;
                let tile_offset = tile_idx * tile_buffer_size;
                let src_idx = tile_offset + 1 * buffer_width + (cell_x + 1);
                let dst_idx = tile_offset + 0 * buffer_width + (cell_x + 1);
                packed_buffer[dst_idx as usize] = packed_buffer[src_idx as usize] * reflection_coeff;
            } else if edge == 1 {
                // South edge of the domain: bottom row of tiles (tile_y == tiles_y - 1).
                let tile_x = idx / tile_size;
                let cell_x = idx % tile_size;
                if tile_x >= tiles_x {
                    return;
                }

                // Row tile_size (interior) -> row tile_size + 1 (halo).
                let tile_idx = (tiles_y - 1) * tiles_x + tile_x;
                let tile_offset = tile_idx * tile_buffer_size;
                let src_idx = tile_offset + tile_size * buffer_width + (cell_x + 1);
                let dst_idx = tile_offset + (tile_size + 1) * buffer_width + (cell_x + 1);
                packed_buffer[dst_idx as usize] = packed_buffer[src_idx as usize] * reflection_coeff;
            } else if edge == 2 {
                // West edge of the domain: left column of tiles (tile_x == 0).
                let tile_y = idx / tile_size;
                let cell_y = idx % tile_size;
                if tile_y >= tiles_y {
                    return;
                }

                // Column 1 (interior) -> column 0 (halo).
                let tile_idx = tile_y * tiles_x + 0;
                let tile_offset = tile_idx * tile_buffer_size;
                let src_idx = tile_offset + (cell_y + 1) * buffer_width + 1;
                let dst_idx = tile_offset + (cell_y + 1) * buffer_width + 0;
                packed_buffer[dst_idx as usize] = packed_buffer[src_idx as usize] * reflection_coeff;
            } else if edge == 3 {
                // East edge of the domain: right column of tiles (tile_x == tiles_x - 1).
                let tile_y = idx / tile_size;
                let cell_y = idx % tile_size;
                if tile_y >= tiles_y {
                    return;
                }

                // Column tile_size (interior) -> column tile_size + 1 (halo).
                let tile_idx = tile_y * tiles_x + (tiles_x - 1);
                let tile_offset = tile_idx * tile_buffer_size;
                let src_idx = tile_offset + (cell_y + 1) * buffer_width + tile_size;
                let dst_idx = tile_offset + (cell_y + 1) * buffer_width + (tile_size + 1);
                packed_buffer[dst_idx as usize] = packed_buffer[src_idx as usize] * reflection_coeff;
            }
        }
    };

    transpile_global_kernel(&boundary_fn)
        .map(|cuda| {
            format!(
                "// Apply Boundary Conditions - handles domain edges for packed tiles\n{}",
                cuda
            )
        })
        .unwrap_or_else(|err| format!("// Transpilation error: {}\n", err))
}
604
605// ============================================================================
606// Ring Kernel Actor (Persistent Tile Actor with K2K Messaging)
607// ============================================================================
608
/// Build the CUDA source for the persistent `tile_actor` ring kernel.
///
/// A message handler written in the Rust DSL is passed to
/// `transpile_ring_kernel` together with a `RingKernelConfig` that sets a
/// block size of 256, a queue capacity of 64, and enables the envelope
/// format, K2K and HLC options. `tile_id` is supplied as the kernel id and
/// `node_id` as the HLC node id.
///
/// The handler itself only echoes the message direction, increments the step
/// and sets the ack flag; no FDTD arithmetic appears in the handler body.
///
/// Returns the transpiled CUDA, or a `// Transpilation error: ...` comment
/// line if transpilation fails.
#[cfg(feature = "cuda-codegen")]
pub fn generate_tile_actor_kernel(tile_id: u64, node_id: u64) -> String {
    use syn::parse_quote;

    let actor_config = RingKernelConfig::new("tile_actor")
        .with_block_size(256)
        .with_queue_capacity(64)
        .with_envelope_format(true)
        .with_k2k(true)
        .with_hlc(true)
        .with_kernel_id(tile_id)
        .with_hlc_node_id(node_id);

    // NOTE: the tokens inside `parse_quote!` are the transpiler's input and
    // determine the generated CUDA; only plain `//` comments may be edited.
    let handler_fn: syn::ItemFn = parse_quote! {
        fn tile_actor_handler(ctx: &RingContext, msg: &HaloMessage) -> HaloResponse {
            // Thread index as reported by the ring context.
            let tid = ctx.global_thread_id();

            // Barrier before the message is handled.
            ctx.sync_threads();

            // Direction: 0=North, 1=South, 2=East, 3=West
            // NOTE(review): this numbering differs from the `edge` numbering
            // of the tile kernels in this file (2=West, 3=East) - confirm
            // that the difference is intentional.
            let direction = msg.direction;
            let step = msg.step;

            // Reply with the same direction, the next step, and ack set.
            HaloResponse {
                direction: direction,
                step: step + 1,
                ack: 1,
            }
        }
    };

    transpile_ring_kernel(&handler_fn, &actor_config)
        .unwrap_or_else(|err| format!("// Transpilation error: {}\n", err))
}
661
/// Assemble the full CUDA source for the actor-based tile kernels.
///
/// The output consists of three parts, each separated by one newline:
/// 1. `ACTOR_TILE_KERNELS_HEADER`
/// 2. the halo message struct definitions
/// 3. the `tile_actor` ring kernel, generated with kernel id 0 and HLC
///    node id 0
#[cfg(feature = "cuda-codegen")]
pub fn generate_actor_tile_kernels() -> String {
    // Arguments are evaluated left to right, matching the emission order.
    format!(
        "{}\n{}\n{}",
        ACTOR_TILE_KERNELS_HEADER,
        generate_halo_message_types(),
        generate_tile_actor_kernel(0, 0)
    )
}
684
/// Comment banner placed at the top of the generated actor-kernel source.
///
/// It summarises the actor architecture and the halo exchange protocol as
/// plain CUDA comments.
pub const ACTOR_TILE_KERNELS_HEADER: &str = concat!(
    "// CUDA Kernels for Actor-Based Tile FDTD Wave Simulation\n",
    "// Generated by ringkernel-cuda-codegen from Rust DSL\n",
    "//\n",
    "// Architecture:\n",
    "// - Each 16x16 tile is a persistent ring kernel actor\n",
    "// - Actors communicate via K2K messaging with envelope format\n",
    "// - HLC timestamps ensure causal ordering of halo exchanges\n",
    "// - MessageEnvelope format: 256-byte header + payload\n",
    "//\n",
    "// Halo Exchange Protocol:\n",
    "// 1. Receive halo data from neighbors (envelope contains direction + data)\n",
    "// 2. Apply halos to shared memory buffer\n",
    "// 3. Compute FDTD for interior cells\n",
    "// 4. Extract edges and send to neighbors\n",
    "// 5. Loop until termination signaled\n",
);
702
703/// Generate the halo message type definitions for CUDA.
704#[cfg(feature = "cuda-codegen")]
fn generate_halo_message_types() -> String {
    // CUDA text emitted verbatim: the `HaloMessage` and `HaloResponse`
    // structs plus the `HALO_MESSAGE_TYPE_ID` define. The text starts with
    // one blank line and ends with one blank line.
    const HALO_TYPE_DEFINITIONS: &str = r#"
// Halo message types for K2K communication
// Direction: 0=North, 1=South, 2=East, 3=West

struct HaloMessage {
    unsigned int direction;     // Which edge this halo is for
    unsigned int step;          // Simulation step for ordering
    float data[16];             // Halo data (16 cells for 16x16 tile)
};

struct HaloResponse {
    unsigned int direction;     // Which edge we're acknowledging
    unsigned int step;          // Updated step count
    unsigned int ack;           // Acknowledgment flag
    unsigned int _padding;      // Align to 8 bytes
};

// Message type ID for halo exchange
#define HALO_MESSAGE_TYPE_ID 200ULL

"#;

    HALO_TYPE_DEFINITIONS.to_owned()
}
729
730// ============================================================================
731// Fallback implementations (when cuda-codegen is not enabled)
732// ============================================================================
733
/// Fallback used when the `cuda-codegen` feature is disabled.
///
/// Returns a one-line CUDA comment pointing at the handwritten
/// `shaders/fdtd_tile.cu` instead of generated source.
#[cfg(not(feature = "cuda-codegen"))]
pub fn generate_tile_kernels() -> String {
    String::from("// CUDA codegen not enabled - use handwritten shaders/fdtd_tile.cu")
}
738
/// Fallback used when the `cuda-codegen` feature is disabled.
///
/// Returns a one-line CUDA comment pointing at the handwritten
/// `shaders/fdtd_packed.cu` instead of generated source.
#[cfg(not(feature = "cuda-codegen"))]
pub fn generate_packed_kernels() -> String {
    String::from("// CUDA codegen not enabled - use handwritten shaders/fdtd_packed.cu")
}
743
/// Fallback used when the `cuda-codegen` feature is disabled.
///
/// Returns a one-line CUDA comment stating that actor kernels require the
/// `cuda-codegen` feature.
#[cfg(not(feature = "cuda-codegen"))]
pub fn generate_actor_tile_kernels() -> String {
    String::from("// CUDA codegen not enabled - actor kernels require cuda-codegen feature")
}
748
/// Fallback used when the `cuda-codegen` feature is disabled.
///
/// Both identifiers are ignored; the result is always the same one-line
/// CUDA comment.
#[cfg(not(feature = "cuda-codegen"))]
pub fn generate_tile_actor_kernel(_tile_id: u64, _node_id: u64) -> String {
    String::from("// CUDA codegen not enabled")
}
753
754// ============================================================================
755// Tests
756// ============================================================================
757
758#[cfg(test)]
759mod tests {
760    #[allow(unused_imports)]
761    use super::*;
762
763    #[test]
764    #[cfg(feature = "cuda-codegen")]
765    fn test_tile_kernels_structure() {
766        let source = generate_tile_kernels();
767
768        // Check all kernels are present
769        assert!(source.contains("fdtd_tile_step"), "Missing fdtd_tile_step");
770        assert!(source.contains("extract_halo"), "Missing extract_halo");
771        assert!(source.contains("inject_halo"), "Missing inject_halo");
772        assert!(source.contains("read_interior"), "Missing read_interior");
773        assert!(
774            source.contains("apply_boundary_reflection"),
775            "Missing apply_boundary_reflection"
776        );
777
778        // Check extern "C" wrapper
779        assert!(source.contains("extern \"C\""), "Missing extern C");
780
781        // Check FDTD kernel has correct structure
782        assert!(source.contains("__global__ void fdtd_tile_step"));
783        assert!(source.contains("threadIdx.x"));
784        assert!(source.contains("threadIdx.y"));
785        assert!(source.contains("buffer_width = 18") || source.contains("* 18"));
786    }
787
788    #[test]
789    #[cfg(feature = "cuda-codegen")]
790    fn test_packed_kernels_structure() {
791        let source = generate_packed_kernels();
792
793        // Check all kernels are present
794        assert!(
795            source.contains("exchange_all_halos"),
796            "Missing exchange_all_halos"
797        );
798        assert!(source.contains("fdtd_all_tiles"), "Missing fdtd_all_tiles");
799        assert!(
800            source.contains("upload_tile_data"),
801            "Missing upload_tile_data"
802        );
803        assert!(
804            source.contains("read_all_interiors"),
805            "Missing read_all_interiors"
806        );
807        assert!(source.contains("inject_impulse"), "Missing inject_impulse");
808        assert!(
809            source.contains("apply_boundary_conditions"),
810            "Missing apply_boundary_conditions"
811        );
812
813        // Check batched FDTD uses blockIdx
814        assert!(source.contains("blockIdx.x"), "Missing blockIdx usage");
815        assert!(source.contains("blockIdx.y"), "Missing blockIdx.y usage");
816    }
817
818    #[test]
819    #[cfg(feature = "cuda-codegen")]
820    fn test_fdtd_tile_step_matches_handwritten() {
821        let generated = generate_fdtd_tile_step();
822
823        // Verify key structural elements match handwritten version
824        assert!(generated.contains("const float* __restrict__ pressure"));
825        assert!(generated.contains("float* __restrict__ pressure_prev"));
826        assert!(generated.contains("float c2"));
827        assert!(generated.contains("float damping"));
828        assert!(generated.contains("if (lx >= 16 || ly >= 16) return;"));
829        assert!(
830            generated.contains("idx = (ly + 1) * buffer_width + (lx + 1)")
831                || generated.contains("(ly + 1) * 18 + (lx + 1)")
832        );
833        assert!(generated.contains("laplacian"));
834        assert!(generated.contains("* damping"));
835
836        println!("Generated fdtd_tile_step:\n{}", generated);
837    }
838
839    #[test]
840    #[cfg(feature = "cuda-codegen")]
841    fn test_generated_vs_handwritten_tile() {
842        let generated = generate_tile_kernels();
843        let handwritten = include_str!("../shaders/fdtd_tile.cu");
844
845        // Count kernels in both
846        let gen_kernel_count = generated.matches("__global__").count();
847        let hw_kernel_count = handwritten.matches("__global__").count();
848
849        assert_eq!(
850            gen_kernel_count, hw_kernel_count,
851            "Kernel count mismatch: generated={}, handwritten={}",
852            gen_kernel_count, hw_kernel_count
853        );
854    }
855
856    #[test]
857    #[cfg(feature = "cuda-codegen")]
858    fn test_generated_vs_handwritten_packed() {
859        let generated = generate_packed_kernels();
860        let handwritten = include_str!("../shaders/fdtd_packed.cu");
861
862        // Count kernels in both
863        let gen_kernel_count = generated.matches("__global__").count();
864        let hw_kernel_count = handwritten.matches("__global__").count();
865
866        assert_eq!(
867            gen_kernel_count, hw_kernel_count,
868            "Kernel count mismatch: generated={}, handwritten={}",
869            gen_kernel_count, hw_kernel_count
870        );
871    }
872
873    #[test]
874    #[cfg(feature = "cuda-codegen")]
875    fn test_match_expression_transpiles_to_switch() {
876        // Test extract_halo uses switch for edge selection
877        let extract = generate_extract_halo();
878        assert!(
879            extract.contains("switch (edge)"),
880            "extract_halo should use switch: {}",
881            extract
882        );
883        assert!(
884            extract.contains("case 0:"),
885            "extract_halo should have case 0"
886        );
887        assert!(
888            extract.contains("case 1:"),
889            "extract_halo should have case 1"
890        );
891        assert!(
892            extract.contains("case 2:"),
893            "extract_halo should have case 2"
894        );
895        assert!(
896            extract.contains("default:"),
897            "extract_halo should have default"
898        );
899
900        // Test inject_halo uses switch for edge selection
901        let inject = generate_inject_halo();
902        assert!(
903            inject.contains("switch (edge)"),
904            "inject_halo should use switch: {}",
905            inject
906        );
907
908        // Test apply_boundary_reflection uses switch
909        let boundary = generate_apply_boundary_reflection();
910        assert!(
911            boundary.contains("switch (edge)"),
912            "apply_boundary_reflection should use switch: {}",
913            boundary
914        );
915
916        println!("Generated extract_halo:\n{}", extract);
917    }
918
919    #[test]
920    #[cfg(feature = "cuda-codegen")]
921    fn test_all_kernels_transpile_successfully() {
922        // Verify all tile kernels transpile without errors
923        let tile_source = generate_tile_kernels();
924        assert!(
925            !tile_source.contains("Transpilation error"),
926            "Tile kernels had transpilation errors:\n{}",
927            tile_source
928        );
929
930        // Verify all packed kernels transpile without errors
931        let packed_source = generate_packed_kernels();
932        assert!(
933            !packed_source.contains("Transpilation error"),
934            "Packed kernels had transpilation errors:\n{}",
935            packed_source
936        );
937
938        // Count __global__ functions to ensure all generated
939        let tile_count = tile_source.matches("__global__").count();
940        let packed_count = packed_source.matches("__global__").count();
941
942        assert_eq!(tile_count, 5, "Expected 5 tile kernels, got {}", tile_count);
943        assert_eq!(
944            packed_count, 6,
945            "Expected 6 packed kernels, got {}",
946            packed_count
947        );
948
949        println!(
950            "Successfully generated {} tile kernels and {} packed kernels",
951            tile_count, packed_count
952        );
953    }
954
955    #[test]
956    #[cfg(feature = "cuda-codegen")]
957    fn test_actor_tile_kernel_generates() {
958        let source = generate_actor_tile_kernels();
959
960        // Check header is present
961        assert!(
962            source.contains("Actor-Based Tile FDTD"),
963            "Missing actor kernel header"
964        );
965
966        // Check message types are defined
967        assert!(
968            source.contains("struct HaloMessage"),
969            "Missing HaloMessage struct"
970        );
971        assert!(
972            source.contains("struct HaloResponse"),
973            "Missing HaloResponse struct"
974        );
975
976        // Check ring kernel structure
977        assert!(
978            source.contains("ring_kernel_tile_actor"),
979            "Missing ring kernel function"
980        );
981        assert!(
982            source.contains("ControlBlock"),
983            "Missing ControlBlock structure"
984        );
985
986        // Verify no transpilation errors
987        assert!(
988            !source.contains("Transpilation error"),
989            "Actor kernel had transpilation errors:\n{}",
990            source
991        );
992
993        println!("Generated actor tile kernel ({} bytes)", source.len());
994    }
995
996    #[test]
997    #[cfg(feature = "cuda-codegen")]
998    fn test_actor_kernel_has_envelope_format() {
999        let source = generate_tile_actor_kernel(42, 7);
1000
1001        // Verify no transpilation errors first
1002        assert!(
1003            !source.contains("Transpilation error"),
1004            "Actor kernel had transpilation errors:\n{}",
1005            source
1006        );
1007
1008        // Check kernel identity constants
1009        assert!(
1010            source.contains("KERNEL_ID = 42"),
1011            "Missing KERNEL_ID constant: {}",
1012            source
1013        );
1014        assert!(
1015            source.contains("HLC_NODE_ID = 7"),
1016            "Missing HLC_NODE_ID constant"
1017        );
1018
1019        // Check envelope format structures (MessageHeader struct)
1020        assert!(
1021            source.contains("MessageHeader"),
1022            "Missing MessageHeader struct for envelope format"
1023        );
1024
1025        // Check HLC support
1026        assert!(
1027            source.contains("HlcTimestamp") || source.contains("hlc_physical"),
1028            "Missing HLC support"
1029        );
1030
1031        // Check K2K support
1032        assert!(
1033            source.contains("K2KRoutingTable") || source.contains("k2k_"),
1034            "Missing K2K routing support"
1035        );
1036    }
1037
1038    #[test]
1039    #[cfg(feature = "cuda-codegen")]
1040    fn test_actor_kernel_ring_kernel_structure() {
1041        let source = generate_tile_actor_kernel(1, 1);
1042
1043        // Verify no transpilation errors
1044        assert!(
1045            !source.contains("Transpilation error"),
1046            "Actor kernel had transpilation errors:\n{}",
1047            source
1048        );
1049
1050        // Check persistent loop structure
1051        assert!(
1052            source.contains("while (true)") || source.contains("while(true)"),
1053            "Missing persistent loop"
1054        );
1055
1056        // Check termination handling
1057        assert!(
1058            source.contains("should_terminate"),
1059            "Missing termination check"
1060        );
1061
1062        // Check synchronization
1063        assert!(
1064            source.contains("__syncthreads()"),
1065            "Missing thread synchronization"
1066        );
1067
1068        // Check message processing
1069        assert!(
1070            source.contains("HaloMessage") || source.contains("msg->"),
1071            "Missing message processing"
1072        );
1073
1074        println!("Actor kernel ring structure verified");
1075    }
1076}