Skip to main content

ringkernel_wgpu_codegen/
shared.rs

//! Shared (workgroup) memory support for WGSL code generation.
//!
//! Maps Rust shared memory patterns to WGSL workgroup variables.
4
5use crate::types::WgslType;
6
7/// Declaration of a shared memory variable.
8#[derive(Debug, Clone)]
9pub struct SharedMemoryDecl {
10    /// Variable name.
11    pub name: String,
12    /// Element type.
13    pub element_type: WgslType,
14    /// Dimensions (1D, 2D, etc.).
15    pub dimensions: Vec<u32>,
16}
17
18impl SharedMemoryDecl {
19    /// Create a 1D shared memory declaration.
20    pub fn new_1d(name: &str, element_type: WgslType, size: u32) -> Self {
21        Self {
22            name: name.to_string(),
23            element_type,
24            dimensions: vec![size],
25        }
26    }
27
28    /// Create a 2D shared memory declaration.
29    pub fn new_2d(name: &str, element_type: WgslType, width: u32, height: u32) -> Self {
30        Self {
31            name: name.to_string(),
32            element_type,
33            dimensions: vec![width, height],
34        }
35    }
36
37    /// Create a 3D shared memory declaration.
38    pub fn new_3d(name: &str, element_type: WgslType, width: u32, height: u32, depth: u32) -> Self {
39        Self {
40            name: name.to_string(),
41            element_type,
42            dimensions: vec![width, height, depth],
43        }
44    }
45
46    /// Total number of elements across all dimensions.
47    pub fn total_elements(&self) -> u32 {
48        self.dimensions.iter().product()
49    }
50
51    /// Generate WGSL declaration.
52    ///
53    /// Generates nested arrays for multi-dimensional shared memory:
54    /// - 1D: `array<T, N>`
55    /// - 2D: `array<array<T, W>, H>` (accessed as `arr[y][x]`)
56    /// - 3D: `array<array<array<T, X>, Y>, Z>` (accessed as `arr[z][y][x]`)
57    /// - 4D+: Linearized 1D array with accessor comment
58    pub fn to_wgsl(&self) -> String {
59        let type_str = self.element_type.to_wgsl();
60        match self.dimensions.len() {
61            0 => format!("var<workgroup> {}: {};", self.name, type_str),
62            1 => format!(
63                "var<workgroup> {}: array<{}, {}>;",
64                self.name, type_str, self.dimensions[0]
65            ),
66            2 => {
67                // 2D: array<array<T, W>, H> - [height][width] indexing (row-major)
68                format!(
69                    "var<workgroup> {}: array<array<{}, {}>, {}>;",
70                    self.name, type_str, self.dimensions[0], self.dimensions[1]
71                )
72            }
73            3 => {
74                // 3D: array<array<array<T, X>, Y>, Z> - [depth][height][width] indexing
75                format!(
76                    "var<workgroup> {}: array<array<array<{}, {}>, {}>, {}>;",
77                    self.name, type_str, self.dimensions[0], self.dimensions[1], self.dimensions[2]
78                )
79            }
80            _ => {
81                // 4D+ dimensions: linearize to 1D array
82                // Include a comment showing the dimensions for reference
83                let total = self.total_elements();
84                let dims_str = self
85                    .dimensions
86                    .iter()
87                    .map(|d| d.to_string())
88                    .collect::<Vec<_>>()
89                    .join("x");
90                format!(
91                    "var<workgroup> {}: array<{}, {}>; // linearized {}D ({})",
92                    self.name,
93                    type_str,
94                    total,
95                    self.dimensions.len(),
96                    dims_str
97                )
98            }
99        }
100    }
101
102    /// Generate index calculation for linearized access to higher-dimensional arrays.
103    ///
104    /// For 4D+ arrays that are linearized, this generates the index formula.
105    /// E.g., for dims [W, H, D, T]: `x + y * W + z * W * H + t * W * H * D`
106    pub fn linearized_index_formula(&self, index_vars: &[&str]) -> Option<String> {
107        if self.dimensions.len() < 4 || index_vars.len() != self.dimensions.len() {
108            return None;
109        }
110
111        let mut terms = Vec::new();
112        let mut stride = 1u32;
113
114        for (i, var) in index_vars.iter().enumerate() {
115            if stride == 1 {
116                terms.push(var.to_string());
117            } else {
118                terms.push(format!("{} * {}u", var, stride));
119            }
120            stride *= self.dimensions[i];
121        }
122
123        Some(terms.join(" + "))
124    }
125}
126
127/// Configuration for shared memory in a kernel.
128#[derive(Debug, Clone, Default)]
129pub struct SharedMemoryConfig {
130    /// List of shared memory declarations.
131    pub declarations: Vec<SharedMemoryDecl>,
132}
133
134impl SharedMemoryConfig {
135    /// Create a new empty configuration.
136    pub fn new() -> Self {
137        Self::default()
138    }
139
140    /// Add a shared memory declaration.
141    pub fn add(&mut self, decl: SharedMemoryDecl) {
142        self.declarations.push(decl);
143    }
144
145    /// Generate all WGSL declarations.
146    pub fn to_wgsl(&self) -> String {
147        self.declarations
148            .iter()
149            .map(|d| d.to_wgsl())
150            .collect::<Vec<_>>()
151            .join("\n")
152    }
153}
154
/// Marker type describing a 2D shared memory tile of `W` x `H` elements of `T`.
///
/// The type stores no data (its only field is a `PhantomData`); the tile shape
/// lives entirely in the const parameters. It is a compile-time marker: the
/// transpiler recognizes usage and generates appropriate WGSL workgroup
/// variables.
pub struct SharedTile<T, const W: usize, const H: usize> {
    _marker: std::marker::PhantomData<T>,
}

impl<T, const W: usize, const H: usize> SharedTile<T, W, H> {
    /// Extent along the width axis, taken from the `W` parameter.
    const COLUMNS: usize = W;
    /// Extent along the height axis, taken from the `H` parameter.
    const ROWS: usize = H;

    /// Width of the tile in elements.
    pub const fn width() -> usize {
        Self::COLUMNS
    }

    /// Height of the tile in elements.
    pub const fn height() -> usize {
        Self::ROWS
    }
}
174
/// Marker type describing a 1D shared memory array of `N` elements of `T`.
///
/// The type stores no data (its only field is a `PhantomData`); the length
/// lives entirely in the `N` const parameter.
pub struct SharedArray<T, const N: usize> {
    _marker: std::marker::PhantomData<T>,
}

impl<T, const N: usize> SharedArray<T, N> {
    /// Element count, taken from the `N` parameter.
    const LEN: usize = N;

    /// Number of elements in the array.
    pub const fn size() -> usize {
        Self::LEN
    }
}
186
/// Marker type describing a 3D shared memory volume of `X` x `Y` x `Z` elements of `T`.
///
/// This is a compile-time marker for volumetric shared memory (e.g., 3D
/// stencil tiles); it stores no data beyond a `PhantomData`. The transpiler
/// recognizes usage and generates appropriate WGSL workgroup variables, which
/// are accessed as `volume[z][y][x]` in WGSL.
pub struct SharedVolume<T, const X: usize, const Y: usize, const Z: usize> {
    _marker: std::marker::PhantomData<T>,
}

impl<T, const X: usize, const Y: usize, const Z: usize> SharedVolume<T, X, Y, Z> {
    /// Extent along X, taken from the `X` parameter.
    const EXTENT_X: usize = X;
    /// Extent along Y, taken from the `Y` parameter.
    const EXTENT_Y: usize = Y;
    /// Extent along Z, taken from the `Z` parameter.
    const EXTENT_Z: usize = Z;

    /// Width of the volume (X dimension) in elements.
    pub const fn width() -> usize {
        Self::EXTENT_X
    }

    /// Height of the volume (Y dimension) in elements.
    pub const fn height() -> usize {
        Self::EXTENT_Y
    }

    /// Depth of the volume (Z dimension) in elements.
    pub const fn depth() -> usize {
        Self::EXTENT_Z
    }

    /// Total number of elements in the volume (`X * Y * Z`).
    pub const fn total() -> usize {
        Self::width() * Self::height() * Self::depth()
    }
}
217
#[cfg(test)]
mod tests {
    //! Unit tests for shared memory declarations, the linearized index
    //! formula, and the compile-time marker types.

    use super::*;

    #[test]
    fn test_shared_memory_scalar() {
        // No dimensions: the variable is declared with the bare element type.
        let decl = SharedMemoryDecl {
            name: "flag".to_string(),
            element_type: WgslType::I32,
            dimensions: vec![],
        };
        assert_eq!(decl.to_wgsl(), "var<workgroup> flag: i32;");
    }

    #[test]
    fn test_shared_memory_1d() {
        let decl = SharedMemoryDecl::new_1d("cache", WgslType::F32, 256);
        assert_eq!(decl.to_wgsl(), "var<workgroup> cache: array<f32, 256>;");
    }

    #[test]
    fn test_shared_memory_2d() {
        let decl = SharedMemoryDecl::new_2d("tile", WgslType::F32, 16, 16);
        assert_eq!(
            decl.to_wgsl(),
            "var<workgroup> tile: array<array<f32, 16>, 16>;"
        );
    }

    #[test]
    fn test_shared_memory_2d_asymmetric() {
        // Distinct extents so that swapping width and height would be caught:
        // width is the inner array length, height the outer one.
        let decl = SharedMemoryDecl::new_2d("tile", WgslType::F32, 32, 8);
        assert_eq!(
            decl.to_wgsl(),
            "var<workgroup> tile: array<array<f32, 32>, 8>;"
        );
        assert_eq!(decl.total_elements(), 256);
    }

    #[test]
    fn test_shared_memory_config() {
        let mut config = SharedMemoryConfig::new();
        config.add(SharedMemoryDecl::new_1d("a", WgslType::I32, 64));
        config.add(SharedMemoryDecl::new_1d("b", WgslType::F32, 128));

        // Exact match pins both the insertion order and the newline separator.
        assert_eq!(
            config.to_wgsl(),
            "var<workgroup> a: array<i32, 64>;\nvar<workgroup> b: array<f32, 128>;"
        );
    }

    #[test]
    fn test_shared_memory_config_empty() {
        assert_eq!(SharedMemoryConfig::new().to_wgsl(), "");
    }

    #[test]
    fn test_shared_memory_3d() {
        let decl = SharedMemoryDecl::new_3d("volume", WgslType::F32, 8, 8, 8);
        assert_eq!(
            decl.to_wgsl(),
            "var<workgroup> volume: array<array<array<f32, 8>, 8>, 8>;"
        );
        assert_eq!(decl.total_elements(), 512);
    }

    #[test]
    fn test_shared_memory_3d_asymmetric() {
        // 3D tile with halo: 18x10x6 for a 16x8x4 interior with a 1-cell halo.
        // The extents differ on purpose so a wrong nesting order fails.
        let decl = SharedMemoryDecl::new_3d("tile_with_halo", WgslType::F32, 18, 10, 6);
        assert_eq!(
            decl.to_wgsl(),
            "var<workgroup> tile_with_halo: array<array<array<f32, 18>, 10>, 6>;"
        );
        assert_eq!(decl.total_elements(), 1080);
    }

    #[test]
    fn test_shared_memory_4d_linearized() {
        // 4D array gets linearized
        let decl = SharedMemoryDecl {
            name: "hypercube".to_string(),
            element_type: WgslType::F32,
            dimensions: vec![4, 4, 4, 4],
        };
        let wgsl = decl.to_wgsl();
        assert!(wgsl.contains("array<f32, 256>")); // 4*4*4*4 = 256
        assert!(wgsl.contains("linearized 4D"));
        assert!(wgsl.contains("4x4x4x4"));
    }

    #[test]
    fn test_linearized_index_formula() {
        let decl = SharedMemoryDecl {
            name: "data".to_string(),
            element_type: WgslType::F32,
            dimensions: vec![4, 8, 2, 3], // W=4, H=8, D=2, T=3
        };

        // Formula: x + y*4 + z*32 + t*64
        let formula = decl
            .linearized_index_formula(&["x", "y", "z", "t"])
            .unwrap();
        assert_eq!(formula, "x + y * 4u + z * 32u + t * 64u");
    }

    #[test]
    fn test_linearized_index_formula_returns_none_for_3d() {
        let decl = SharedMemoryDecl::new_3d("vol", WgslType::F32, 8, 8, 8);
        // 3D doesn't need linearization
        assert!(decl.linearized_index_formula(&["x", "y", "z"]).is_none());
    }

    #[test]
    fn test_linearized_index_formula_requires_one_var_per_dimension() {
        let decl = SharedMemoryDecl {
            name: "data".to_string(),
            element_type: WgslType::F32,
            dimensions: vec![4, 8, 2, 3],
        };
        // Too few and too many index variables are both rejected.
        assert!(decl.linearized_index_formula(&["x", "y", "z"]).is_none());
        assert!(decl
            .linearized_index_formula(&["x", "y", "z", "t", "u"])
            .is_none());
    }

    #[test]
    fn test_shared_tile_and_array_markers() {
        assert_eq!(SharedTile::<f32, 16, 8>::width(), 16);
        assert_eq!(SharedTile::<f32, 16, 8>::height(), 8);
        assert_eq!(SharedArray::<f32, 256>::size(), 256);
    }

    #[test]
    fn test_shared_volume_marker() {
        // Verify the marker type constants work at compile time
        assert_eq!(SharedVolume::<f32, 8, 8, 8>::width(), 8);
        assert_eq!(SharedVolume::<f32, 8, 8, 8>::height(), 8);
        assert_eq!(SharedVolume::<f32, 8, 8, 8>::depth(), 8);
        assert_eq!(SharedVolume::<f32, 8, 8, 8>::total(), 512);

        // Asymmetric volume
        assert_eq!(SharedVolume::<f32, 16, 8, 4>::width(), 16);
        assert_eq!(SharedVolume::<f32, 16, 8, 4>::height(), 8);
        assert_eq!(SharedVolume::<f32, 16, 8, 4>::depth(), 4);
        assert_eq!(SharedVolume::<f32, 16, 8, 4>::total(), 512);
    }
}