ringkernel_wgpu_codegen/
shared.rs

//! Shared (workgroup) memory support for WGSL code generation.
//!
//! Maps Rust shared memory patterns (such as the `SharedArray` and `SharedTile` markers) to WGSL
//! `var<workgroup>` variable declarations.
4
5use crate::types::WgslType;
6
7/// Declaration of a shared memory variable.
8#[derive(Debug, Clone)]
9pub struct SharedMemoryDecl {
10    /// Variable name.
11    pub name: String,
12    /// Element type.
13    pub element_type: WgslType,
14    /// Dimensions (1D, 2D, etc.).
15    pub dimensions: Vec<u32>,
16}
17
18impl SharedMemoryDecl {
19    /// Create a 1D shared memory declaration.
20    pub fn new_1d(name: &str, element_type: WgslType, size: u32) -> Self {
21        Self {
22            name: name.to_string(),
23            element_type,
24            dimensions: vec![size],
25        }
26    }
27
28    /// Create a 2D shared memory declaration.
29    pub fn new_2d(name: &str, element_type: WgslType, width: u32, height: u32) -> Self {
30        Self {
31            name: name.to_string(),
32            element_type,
33            dimensions: vec![width, height],
34        }
35    }
36
37    /// Generate WGSL declaration.
38    pub fn to_wgsl(&self) -> String {
39        let type_str = self.element_type.to_wgsl();
40        match self.dimensions.len() {
41            1 => format!(
42                "var<workgroup> {}: array<{}, {}>;",
43                self.name, type_str, self.dimensions[0]
44            ),
45            2 => format!(
46                "var<workgroup> {}: array<array<{}, {}>, {}>;",
47                self.name, type_str, self.dimensions[0], self.dimensions[1]
48            ),
49            _ => format!(
50                "var<workgroup> {}: array<{}, {}>; // TODO: higher dimensions",
51                self.name, type_str, self.dimensions[0]
52            ),
53        }
54    }
55}
56
57/// Configuration for shared memory in a kernel.
58#[derive(Debug, Clone, Default)]
59pub struct SharedMemoryConfig {
60    /// List of shared memory declarations.
61    pub declarations: Vec<SharedMemoryDecl>,
62}
63
64impl SharedMemoryConfig {
65    /// Create a new empty configuration.
66    pub fn new() -> Self {
67        Self::default()
68    }
69
70    /// Add a shared memory declaration.
71    pub fn add(&mut self, decl: SharedMemoryDecl) {
72        self.declarations.push(decl);
73    }
74
75    /// Generate all WGSL declarations.
76    pub fn to_wgsl(&self) -> String {
77        self.declarations
78            .iter()
79            .map(|d| d.to_wgsl())
80            .collect::<Vec<_>>()
81            .join("\n")
82    }
83}
84
/// Compile-time marker for a two-dimensional shared memory tile.
///
/// `T` is the element type, and `W` / `H` are the tile width and height. The
/// type holds only a zero-sized `PhantomData`, so it carries no runtime data.
/// The transpiler recognizes where it is used and emits the corresponding
/// WGSL workgroup variable.
pub struct SharedTile<T, const W: usize, const H: usize> {
    _marker: std::marker::PhantomData<T>,
}

impl<Elem, const COLS: usize, const ROWS: usize> SharedTile<Elem, COLS, ROWS> {
    /// Tile width, i.e. the struct's `W` const parameter.
    pub const fn width() -> usize {
        COLS
    }

    /// Tile height, i.e. the struct's `H` const parameter.
    pub const fn height() -> usize {
        ROWS
    }
}
104
/// Compile-time marker for a one-dimensional shared memory array of `N`
/// elements of type `T`.
///
/// Only a zero-sized `PhantomData` is stored, so the type carries no runtime
/// data.
pub struct SharedArray<T, const N: usize> {
    _marker: std::marker::PhantomData<T>,
}

impl<Elem, const LEN: usize> SharedArray<Elem, LEN> {
    /// Number of elements in the array, i.e. the struct's `N` const parameter.
    pub const fn size() -> usize {
        LEN
    }
}
116
#[cfg(test)]
mod tests {
    use super::*;

    /// A 1D declaration renders as a single WGSL array.
    #[test]
    fn test_shared_memory_1d() {
        let expected = "var<workgroup> cache: array<f32, 256>;";
        let actual = SharedMemoryDecl::new_1d("cache", WgslType::F32, 256).to_wgsl();
        assert_eq!(actual, expected);
    }

    /// A 2D declaration renders as a nested WGSL array.
    #[test]
    fn test_shared_memory_2d() {
        let tile = SharedMemoryDecl::new_2d("tile", WgslType::F32, 16, 16);
        let expected = "var<workgroup> tile: array<array<f32, 16>, 16>;";
        assert_eq!(tile.to_wgsl(), expected);
    }

    /// Every declaration added to a config shows up in its rendered output.
    #[test]
    fn test_shared_memory_config() {
        let mut config = SharedMemoryConfig::new();
        for (name, ty, len) in [("a", WgslType::I32, 64), ("b", WgslType::F32, 128)] {
            config.add(SharedMemoryDecl::new_1d(name, ty, len));
        }

        let wgsl = config.to_wgsl();
        for expected in [
            "var<workgroup> a: array<i32, 64>;",
            "var<workgroup> b: array<f32, 128>;",
        ] {
            assert!(wgsl.contains(expected), "missing `{}` in:\n{}", expected, wgsl);
        }
    }
}