Skip to main content

ringkernel_wgpu_codegen/
bindings.rs

//! Buffer binding layout generation for WGSL.
//!
//! Generates `@group`/`@binding` variable declarations from kernel parameters.
4
5use crate::types::WgslType;
6
/// Access mode for a WGSL resource binding.
///
/// WGSL storage *buffers* accept only [`AccessMode::Read`] and [`AccessMode::ReadWrite`];
/// [`AccessMode::Write`] is valid in WGSL only for storage textures.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AccessMode {
    /// Read-only access (`read`).
    Read,
    /// Write-only access (`write`). WGSL rejects this access mode for storage buffers;
    /// it is only valid for storage textures.
    Write,
    /// Read-write access (`read_write`).
    ReadWrite,
}

impl AccessMode {
    /// Get the WGSL access-mode keyword: `read`, `write`, or `read_write`.
    #[must_use]
    pub fn to_wgsl(&self) -> &'static str {
        match self {
            AccessMode::Read => "read",
            AccessMode::Write => "write",
            AccessMode::ReadWrite => "read_write",
        }
    }
}
28
29/// Description of a buffer binding.
30#[derive(Debug, Clone)]
31pub struct BindingLayout {
32    /// Binding group (usually 0).
33    pub group: u32,
34    /// Binding number within the group.
35    pub binding: u32,
36    /// Variable name in the shader.
37    pub name: String,
38    /// Type of the binding.
39    pub ty: WgslType,
40    /// Access mode for storage buffers.
41    pub access: AccessMode,
42}
43
44impl BindingLayout {
45    /// Create a new binding layout.
46    pub fn new(group: u32, binding: u32, name: &str, ty: WgslType, access: AccessMode) -> Self {
47        Self {
48            group,
49            binding,
50            name: name.to_string(),
51            ty,
52            access,
53        }
54    }
55
56    /// Create a read-only storage buffer binding.
57    pub fn storage_read(binding: u32, name: &str, element_type: WgslType) -> Self {
58        Self::new(
59            0,
60            binding,
61            name,
62            WgslType::Array {
63                element: Box::new(element_type),
64                size: None,
65            },
66            AccessMode::Read,
67        )
68    }
69
70    /// Create a read-write storage buffer binding.
71    pub fn storage_read_write(binding: u32, name: &str, element_type: WgslType) -> Self {
72        Self::new(
73            0,
74            binding,
75            name,
76            WgslType::Array {
77                element: Box::new(element_type),
78                size: None,
79            },
80            AccessMode::ReadWrite,
81        )
82    }
83
84    /// Create a uniform buffer binding.
85    pub fn uniform(binding: u32, name: &str, ty: WgslType) -> Self {
86        Self::new(0, binding, name, ty, AccessMode::Read)
87    }
88
89    /// Generate the WGSL binding declaration.
90    pub fn to_wgsl(&self) -> String {
91        let type_str = self.ty.to_wgsl();
92
93        match &self.ty {
94            WgslType::Array { .. } => {
95                // Storage buffer
96                format!(
97                    "@group({}) @binding({}) var<storage, {}> {}: {};",
98                    self.group,
99                    self.binding,
100                    self.access.to_wgsl(),
101                    self.name,
102                    type_str
103                )
104            }
105            WgslType::Struct(_) if self.access == AccessMode::Read => {
106                // Uniform buffer
107                format!(
108                    "@group({}) @binding({}) var<uniform> {}: {};",
109                    self.group, self.binding, self.name, type_str
110                )
111            }
112            _ => {
113                // Generic storage
114                format!(
115                    "@group({}) @binding({}) var<storage, {}> {}: {};",
116                    self.group,
117                    self.binding,
118                    self.access.to_wgsl(),
119                    self.name,
120                    type_str
121                )
122            }
123        }
124    }
125}
126
127/// Generate binding declarations from a list of parameters.
128pub fn generate_bindings(bindings: &[BindingLayout]) -> String {
129    bindings
130        .iter()
131        .map(|b| b.to_wgsl())
132        .collect::<Vec<_>>()
133        .join("\n")
134}
135
136/// Generate bindings for kernel parameters.
137///
138/// Slices become storage buffers, scalars become push constants or uniforms.
139pub fn bindings_from_params(params: &[(String, WgslType, bool)]) -> Vec<BindingLayout> {
140    let mut bindings = Vec::new();
141    let mut binding_idx = 0u32;
142
143    for (name, ty, is_mutable) in params {
144        match ty {
145            WgslType::Ptr { inner, .. } => {
146                let access = if *is_mutable {
147                    AccessMode::ReadWrite
148                } else {
149                    AccessMode::Read
150                };
151                bindings.push(BindingLayout::new(
152                    0,
153                    binding_idx,
154                    name,
155                    WgslType::Array {
156                        element: inner.clone(),
157                        size: None,
158                    },
159                    access,
160                ));
161                binding_idx += 1;
162            }
163            WgslType::Array { element, .. } => {
164                let access = if *is_mutable {
165                    AccessMode::ReadWrite
166                } else {
167                    AccessMode::Read
168                };
169                bindings.push(BindingLayout::new(
170                    0,
171                    binding_idx,
172                    name,
173                    WgslType::Array {
174                        element: element.clone(),
175                        size: None,
176                    },
177                    access,
178                ));
179                binding_idx += 1;
180            }
181            // Scalars are typically passed via uniforms or push constants
182            // For now, we'll add them as uniform buffer fields
183            _ => {
184                bindings.push(BindingLayout::new(
185                    0,
186                    binding_idx,
187                    name,
188                    ty.clone(),
189                    AccessMode::Read,
190                ));
191                binding_idx += 1;
192            }
193        }
194    }
195
196    bindings
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_storage_read_binding() {
        let binding = BindingLayout::storage_read(0, "input", WgslType::F32);
        assert_eq!(
            binding.to_wgsl(),
            "@group(0) @binding(0) var<storage, read> input: array<f32>;"
        );
    }

    #[test]
    fn test_storage_read_write_binding() {
        let binding = BindingLayout::storage_read_write(1, "output", WgslType::F32);
        assert_eq!(
            binding.to_wgsl(),
            "@group(0) @binding(1) var<storage, read_write> output: array<f32>;"
        );
    }

    #[test]
    fn test_generate_bindings() {
        let bindings = vec![
            BindingLayout::storage_read(0, "input", WgslType::F32),
            BindingLayout::storage_read_write(1, "output", WgslType::F32),
        ];

        // Pin the exact output, including the single `\n` separator and no trailing newline.
        assert_eq!(
            generate_bindings(&bindings),
            "@group(0) @binding(0) var<storage, read> input: array<f32>;\n\
             @group(0) @binding(1) var<storage, read_write> output: array<f32>;"
        );
    }

    #[test]
    fn test_generate_bindings_empty() {
        assert_eq!(generate_bindings(&[]), "");
    }

    #[test]
    fn test_bindings_from_params_arrays() {
        let params = vec![
            (
                "input".to_string(),
                WgslType::Array {
                    element: Box::new(WgslType::F32),
                    size: None,
                },
                false,
            ),
            (
                "output".to_string(),
                WgslType::Array {
                    element: Box::new(WgslType::F32),
                    size: None,
                },
                true,
            ),
        ];

        let bindings = bindings_from_params(&params);
        assert_eq!(bindings.len(), 2);
        assert_eq!(bindings[0].group, 0);
        assert_eq!(bindings[0].binding, 0);
        assert_eq!(bindings[0].access, AccessMode::Read);
        assert_eq!(bindings[1].group, 0);
        assert_eq!(bindings[1].binding, 1);
        assert_eq!(bindings[1].access, AccessMode::ReadWrite);
        assert_eq!(
            generate_bindings(&bindings),
            "@group(0) @binding(0) var<storage, read> input: array<f32>;\n\
             @group(0) @binding(1) var<storage, read_write> output: array<f32>;"
        );
    }
}