Skip to main content

ringkernel_wavesim/simulation/
fdtd_dsl.rs

1//! FDTD kernel defined using the Rust DSL.
2//!
//! This module demonstrates how to define stencil kernels in pure Rust
//! that are transpiled to CUDA by `ringkernel_cuda_codegen` (in this module,
//! [`generate_fdtd_cuda`] invokes the transpiler at run time).
5//!
6//! The generated CUDA code is equivalent to the handwritten version in
7//! `shaders/fdtd_tile.cu`, but is written in a high-level Rust DSL.
8
9#[cfg(feature = "cuda-codegen")]
10use ringkernel_cuda_codegen::{transpile_stencil_kernel, Grid, StencilConfig};
11
12// Re-export or define GridPos
13#[cfg(feature = "cuda-codegen")]
14pub use ringkernel_cuda_codegen::GridPos;
15
/// CPU stand-in for the codegen `GridPos`, compiled when the `cuda-codegen`
/// feature is disabled.
///
/// A position is the linear index of one cell in a row-major buffer plus the
/// buffer's row width. Every neighbour lookup is plain index arithmetic on
/// those two numbers, which mirrors the addressing of the CUDA kernel, so the
/// same DSL code can be run and verified on the CPU.
#[cfg(not(feature = "cuda-codegen"))]
#[derive(Debug, Clone, Copy)]
pub struct GridPos {
    /// Linear index of the current cell within the buffer.
    index: usize,
    /// Number of elements per buffer row (`tile_size + 2 * halo`).
    buffer_width: usize,
}

#[cfg(not(feature = "cuda-codegen"))]
impl GridPos {
    /// Builds a position from a raw linear index.
    ///
    /// # Arguments
    /// * `index` - linear index of the cell in the buffer
    /// * `buffer_width` - elements per buffer row, halo columns included
    pub fn new(index: usize, buffer_width: usize) -> Self {
        GridPos {
            index,
            buffer_width,
        }
    }

    /// Builds a position from coordinates relative to the tile interior.
    ///
    /// `(lx, ly)` count from the first interior cell (both expected in
    /// `0..tile_size`). The halo is added to each before the pair is
    /// linearised with a row width of `tile_size + 2 * halo`.
    ///
    /// # Arguments
    /// * `lx` - interior-local column
    /// * `ly` - interior-local row
    /// * `tile_size` - interior edge length
    /// * `halo` - halo width on every side
    pub fn from_local(lx: usize, ly: usize, tile_size: usize, halo: usize) -> Self {
        let width = tile_size + 2 * halo;
        let row = ly + halo;
        let col = lx + halo;
        Self::new(row * width + col, width)
    }

    /// Linear index of the current cell.
    #[inline]
    pub fn idx(&self) -> usize {
        self.index
    }

    /// Reads the cell one row above (`y - 1`).
    ///
    /// # Panics
    /// Panics if that cell lies outside `buf`.
    #[inline]
    pub fn north<T: Copy>(&self, buf: &[T]) -> T {
        let above = self.index - self.buffer_width;
        buf[above]
    }

    /// Reads the cell one row below (`y + 1`).
    ///
    /// # Panics
    /// Panics if that cell lies outside `buf`.
    #[inline]
    pub fn south<T: Copy>(&self, buf: &[T]) -> T {
        let below = self.index + self.buffer_width;
        buf[below]
    }

    /// Reads the cell one column to the right (`x + 1`).
    ///
    /// # Panics
    /// Panics if that cell lies outside `buf`.
    #[inline]
    pub fn east<T: Copy>(&self, buf: &[T]) -> T {
        let right = self.index + 1;
        buf[right]
    }

    /// Reads the cell one column to the left (`x - 1`).
    ///
    /// # Panics
    /// Panics if that cell lies outside `buf`.
    #[inline]
    pub fn west<T: Copy>(&self, buf: &[T]) -> T {
        let left = self.index - 1;
        buf[left]
    }

    /// Reads the cell displaced by `(dx, dy)` from the current one.
    ///
    /// # Panics
    /// Panics if the displaced cell lies outside `buf`.
    #[inline]
    pub fn at<T: Copy>(&self, buf: &[T], dx: i32, dy: i32) -> T {
        let stride = self.buffer_width as isize;
        let displacement = dy as isize * stride + dx as isize;
        let target = self.index as isize + displacement;
        buf[target as usize]
    }

    /// Reads the "above" neighbour (`z - 1` on 3D grids).
    ///
    /// This CPU implementation is 2D only, so it is identical to
    /// [`north`](Self::north).
    #[inline]
    pub fn up<T: Copy>(&self, buf: &[T]) -> T {
        self.north(buf)
    }

    /// Reads the "below" neighbour (`z + 1` on 3D grids).
    ///
    /// This CPU implementation is 2D only, so it is identical to
    /// [`south`](Self::south).
    #[inline]
    pub fn down<T: Copy>(&self, buf: &[T]) -> T {
        self.south(buf)
    }
}
112
/// One explicit second-order FDTD update of the 2D wave equation at a single cell.
///
/// Plain-index reference implementation used for CPU fallback and for
/// verifying the DSL and generated CUDA versions. The buffers are row-major
/// with `buffer_width` elements per row (interior plus halo), so any tile size
/// and halo width work, provided `idx` has all four axis neighbours inside
/// `pressure`.
///
/// `pressure_prev` is both input and output. On entry `pressure_prev[idx]`
/// holds the previous-step value; on return it holds the damped new value.
/// Only `pressure_prev[idx]` is read or written in that buffer, so calling this
/// for every interior cell of a tile gives the same result in any order.
///
/// # Arguments
/// * `pressure` - current-step field (read only)
/// * `pressure_prev` - previous-step field, overwritten at `idx` with the next step
/// * `c2` - coefficient multiplying the Laplacian (the `c²` term of the update)
/// * `damping` - multiplicative factor applied to the new value before storing
/// * `idx` - linear index of the cell to update
/// * `buffer_width` - elements per buffer row, halo columns included
///
/// # Panics
/// Panics if any of the four neighbour indices (`idx ± 1`, `idx ± buffer_width`)
/// falls outside `pressure`, or if `idx` is outside `pressure_prev`.
pub fn fdtd_wave_step_cpu(
    pressure: &[f32],
    pressure_prev: &mut [f32],
    c2: f32,
    damping: f32,
    idx: usize,
    buffer_width: usize,
) {
    let p = pressure[idx];
    let p_prev = pressure_prev[idx];

    // 5-point discrete Laplacian; vertical neighbours are one full row
    // (`buffer_width` elements) away in the row-major buffer.
    let laplacian = pressure[idx - buffer_width] // north
        + pressure[idx + buffer_width] // south
        + pressure[idx + 1] // east
        + pressure[idx - 1] // west
        - 4.0 * p;

    // Time update: p_new = 2*p - p_prev + c²*laplacian
    let p_new = 2.0 * p - p_prev + c2 * laplacian;

    // Damp and store over the previous-step value.
    pressure_prev[idx] = p_new * damping;
}
141
142/// FDTD wave equation kernel using the GridPos DSL.
143///
144/// This version uses the same API as the CUDA-transpiled version,
145/// demonstrating that the Rust DSL code works identically on CPU.
146///
147/// The function signature matches what `#[stencil_kernel]` would generate:
148/// ```ignore
149/// fn fdtd_wave_step(
150///     pressure: &[f32],
151///     pressure_prev: &mut [f32],
152///     c2: f32,
153///     damping: f32,
154///     pos: GridPos,
155/// )
156/// ```
157#[cfg(not(feature = "cuda-codegen"))]
158pub fn fdtd_wave_step_dsl(
159    pressure: &[f32],
160    pressure_prev: &mut [f32],
161    c2: f32,
162    damping: f32,
163    pos: GridPos,
164) {
165    let p = pressure[pos.idx()];
166    let p_prev = pressure_prev[pos.idx()];
167    let laplacian =
168        pos.north(pressure) + pos.south(pressure) + pos.east(pressure) + pos.west(pressure)
169            - 4.0 * p;
170    let p_new = 2.0 * p - p_prev + c2 * laplacian;
171    pressure_prev[pos.idx()] = p_new * damping;
172}
173
174/// Run FDTD on an entire tile using the DSL kernel.
175///
176/// This simulates what CUDA would do: run the kernel for each interior cell.
177#[cfg(not(feature = "cuda-codegen"))]
178pub fn fdtd_tile_step_dsl(
179    pressure: &[f32],
180    pressure_prev: &mut [f32],
181    c2: f32,
182    damping: f32,
183    tile_size: usize,
184    halo: usize,
185) {
186    for ly in 0..tile_size {
187        for lx in 0..tile_size {
188            let pos = GridPos::from_local(lx, ly, tile_size, halo);
189            fdtd_wave_step_dsl(pressure, pressure_prev, c2, damping, pos);
190        }
191    }
192}
193
/// Generates CUDA source for the FDTD kernel by running the DSL transpiler.
///
/// The kernel is written as a Rust function, turned into a `syn` AST with
/// `parse_quote!`, and passed to `transpile_stencil_kernel`. The configuration
/// is a 2D grid of 16x16 tiles with a 1-cell halo.
///
/// Transpiler errors are not propagated. On failure the returned string is a
/// single CUDA comment of the form `// Transpilation error: <message>`.
#[cfg(feature = "cuda-codegen")]
pub fn generate_fdtd_cuda() -> String {
    use syn::parse_quote;

    // Stencil launch configuration: 2D grid, 16x16 tiles, 1-cell halo.
    let config = StencilConfig::new("fdtd_wave_step")
        .with_grid(Grid::Grid2D)
        .with_tile_size(16, 16)
        .with_halo(1);

    // Kernel body in the GridPos DSL (same body as `fdtd_wave_step_dsl`).
    let kernel_fn: syn::ItemFn = parse_quote! {
        fn fdtd_wave_step(
            pressure: &[f32],
            pressure_prev: &mut [f32],
            c2: f32,
            damping: f32,
            pos: GridPos,
        ) {
            let p = pressure[pos.idx()];
            let p_prev = pressure_prev[pos.idx()];
            let laplacian = pos.north(pressure)
                + pos.south(pressure)
                + pos.east(pressure)
                + pos.west(pressure)
                - 4.0 * p;
            let p_new = 2.0 * p - p_prev + c2 * laplacian;
            pressure_prev[pos.idx()] = p_new * damping;
        }
    };

    transpile_stencil_kernel(&kernel_fn, &config)
        .unwrap_or_else(|e| format!("// Transpilation error: {}", e))
}
235
/// Fallback used when the `cuda-codegen` feature is disabled.
///
/// Returns a placeholder CUDA comment instead of generated source, so callers
/// can call `generate_fdtd_cuda` whether or not the feature is enabled.
#[cfg(not(feature = "cuda-codegen"))]
pub fn generate_fdtd_cuda() -> String {
    String::from("// CUDA codegen not enabled")
}
240
241/// Get the handwritten CUDA source for comparison.
242pub fn handwritten_fdtd_cuda() -> &'static str {
243    include_str!("../shaders/fdtd_tile.cu")
244}
245
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for f32 results that should agree (near-)exactly.
    const EPS: f32 = 1e-6;

    #[test]
    fn test_cpu_fdtd_step() {
        // 4x4 interior with a 1-cell halo -> 6x6 buffer.
        let buffer_width = 6;
        let mut pressure = vec![0.0f32; 36];
        let mut pressure_prev = vec![0.0f32; 36];

        // Unit spike at buffer coordinate (2, 2): 2 * 6 + 2 = 14.
        pressure[14] = 1.0;

        // Run one FDTD step on the spike cell.
        fdtd_wave_step_cpu(&pressure, &mut pressure_prev, 0.25, 0.99, 14, buffer_width);

        // The Laplacian of an isolated spike is -4 (all neighbours are 0):
        // p_new = 2*1 - 0 + 0.25*(-4) = 1.0; damped: 1.0 * 0.99 = 0.99.
        assert!((pressure_prev[14] - 0.99).abs() < EPS);

        // Only the updated cell may be written.
        for (i, &v) in pressure_prev.iter().enumerate() {
            if i != 14 {
                assert_eq!(v, 0.0, "cell {} should be untouched", i);
            }
        }
    }

    #[test]
    #[cfg(feature = "cuda-codegen")]
    fn test_generated_cuda_source() {
        let source = generate_fdtd_cuda();

        // Check that it looks like valid CUDA
        assert!(
            source.contains("extern \"C\" __global__"),
            "Should have CUDA kernel declaration"
        );
        assert!(source.contains("threadIdx.x"), "Should use thread index X");
        assert!(source.contains("threadIdx.y"), "Should use thread index Y");
        assert!(
            source.contains("buffer_width = 18"),
            "Should have correct buffer width (16 + 2*1)"
        );

        println!("Generated CUDA:\n{}", source);
    }

    #[test]
    #[cfg(feature = "cuda-codegen")]
    fn test_cuda_source_structure() {
        let source = generate_fdtd_cuda();

        // Verify key structural elements
        assert!(
            source.contains("if (lx >= 16 || ly >= 16) return;"),
            "Should have bounds check"
        );
        assert!(source.contains("float p ="), "Should have pressure variable");

        // Verify stencil access pattern (using buffer_width constant)
        assert!(
            source.contains("- 18") || source.contains("- buffer_width"),
            "Should access north neighbor"
        );
        assert!(
            source.contains("+ 18") || source.contains("+ buffer_width"),
            "Should access south neighbor"
        );
    }

    #[test]
    #[cfg(not(feature = "cuda-codegen"))]
    fn test_grid_pos_neighbor_access() {
        // Create a 4x4 interior with 1-cell halo = 6x6 buffer
        // Layout (indices):
        //  0  1  2  3  4  5
        //  6  7  8  9 10 11
        // 12 13 14 15 16 17
        // 18 19 20 21 22 23
        // 24 25 26 27 28 29
        // 30 31 32 33 34 35

        let buffer: Vec<f32> = (0..36).map(|i| i as f32).collect();
        let pos = GridPos::new(14, 6); // Center at (2,2)

        assert_eq!(pos.idx(), 14);
        assert_eq!(pos.north(&buffer), 8.0); // 14 - 6 = 8
        assert_eq!(pos.south(&buffer), 20.0); // 14 + 6 = 20
        assert_eq!(pos.east(&buffer), 15.0); // 14 + 1 = 15
        assert_eq!(pos.west(&buffer), 13.0); // 14 - 1 = 13
    }

    #[test]
    #[cfg(not(feature = "cuda-codegen"))]
    fn test_grid_pos_from_local() {
        // 16x16 tile with halo=1 -> 18x18 buffer
        let pos = GridPos::from_local(0, 0, 16, 1);
        // index = (0+1)*18 + (0+1) = 19
        assert_eq!(pos.idx(), 19);

        let pos = GridPos::from_local(15, 15, 16, 1);
        // index = (15+1)*18 + (15+1) = 16*18 + 16 = 288 + 16 = 304
        assert_eq!(pos.idx(), 304);
    }

    #[test]
    #[cfg(not(feature = "cuda-codegen"))]
    fn test_grid_pos_at_offset() {
        let buffer: Vec<f32> = (0..36).map(|i| i as f32).collect();
        let pos = GridPos::new(14, 6);

        // at(0, 0) should be the same as idx
        assert_eq!(pos.at(&buffer, 0, 0), buffer[14]);

        // at(1, 0) should be east
        assert_eq!(pos.at(&buffer, 1, 0), buffer[15]);

        // at(-1, 0) should be west
        assert_eq!(pos.at(&buffer, -1, 0), buffer[13]);

        // at(0, -1) should be north
        assert_eq!(pos.at(&buffer, 0, -1), buffer[8]);

        // at(0, 1) should be south
        assert_eq!(pos.at(&buffer, 0, 1), buffer[20]);

        // Diagonal: northeast
        assert_eq!(pos.at(&buffer, 1, -1), buffer[9]);
    }

    #[test]
    #[cfg(not(feature = "cuda-codegen"))]
    fn test_dsl_matches_cpu() {
        // The DSL version must produce the same results as the direct CPU version.
        let tile_size = 4;
        let halo = 1;
        let buffer_width = tile_size + 2 * halo;
        let len = buffer_width * buffer_width;

        // Deterministic, non-trivial fields (halo included), so that every term
        // of the update (neighbours, p and p_prev) contributes to the comparison.
        let pressure: Vec<f32> = (0..len)
            .map(|i| ((i * 7 % 11) as f32) * 0.1 - 0.5)
            .collect();
        let initial_prev: Vec<f32> = (0..len)
            .map(|i| ((i * 5 % 13) as f32) * 0.05 - 0.3)
            .collect();

        let c2 = 0.25f32;
        let damping = 0.99f32;

        // Run the CPU version for all interior cells.
        let mut cpu_prev = initial_prev.clone();
        for ly in 0..tile_size {
            for lx in 0..tile_size {
                let idx = (ly + halo) * buffer_width + (lx + halo);
                fdtd_wave_step_cpu(&pressure, &mut cpu_prev, c2, damping, idx, buffer_width);
            }
        }

        // Run the DSL version over the whole tile.
        let mut dsl_prev = initial_prev.clone();
        fdtd_tile_step_dsl(&pressure, &mut dsl_prev, c2, damping, tile_size, halo);

        // Compare results cell by cell.
        for (i, (cpu, dsl)) in cpu_prev.iter().zip(&dsl_prev).enumerate() {
            assert!(
                (cpu - dsl).abs() < EPS,
                "Mismatch at index {}: CPU={}, DSL={}",
                i,
                cpu,
                dsl
            );
        }

        // Halo cells are neighbours only and must never be written.
        let interior = halo..halo + tile_size;
        for (i, (&got, &orig)) in dsl_prev.iter().zip(&initial_prev).enumerate() {
            let (x, y) = (i % buffer_width, i / buffer_width);
            if !(interior.contains(&x) && interior.contains(&y)) {
                assert_eq!(got, orig, "halo cell {} was modified", i);
            }
        }
    }
}