1use crate::types::WgslType;
6
/// Access mode attached to a shader resource binding.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AccessMode {
    /// Read-only access; spelled `read` in WGSL.
    Read,
    /// Write-only access; spelled `write` in WGSL.
    Write,
    /// Read and write access; spelled `read_write` in WGSL.
    ReadWrite,
}

impl AccessMode {
    /// Returns the WGSL keyword for this access mode: `read`, `write`, or `read_write`.
    pub fn to_wgsl(&self) -> &'static str {
        match *self {
            Self::ReadWrite => "read_write",
            Self::Write => "write",
            Self::Read => "read",
        }
    }
}
28
29#[derive(Debug, Clone)]
31pub struct BindingLayout {
32 pub group: u32,
34 pub binding: u32,
36 pub name: String,
38 pub ty: WgslType,
40 pub access: AccessMode,
42}
43
44impl BindingLayout {
45 pub fn new(group: u32, binding: u32, name: &str, ty: WgslType, access: AccessMode) -> Self {
47 Self {
48 group,
49 binding,
50 name: name.to_string(),
51 ty,
52 access,
53 }
54 }
55
56 pub fn storage_read(binding: u32, name: &str, element_type: WgslType) -> Self {
58 Self::new(
59 0,
60 binding,
61 name,
62 WgslType::Array {
63 element: Box::new(element_type),
64 size: None,
65 },
66 AccessMode::Read,
67 )
68 }
69
70 pub fn storage_read_write(binding: u32, name: &str, element_type: WgslType) -> Self {
72 Self::new(
73 0,
74 binding,
75 name,
76 WgslType::Array {
77 element: Box::new(element_type),
78 size: None,
79 },
80 AccessMode::ReadWrite,
81 )
82 }
83
84 pub fn uniform(binding: u32, name: &str, ty: WgslType) -> Self {
86 Self::new(0, binding, name, ty, AccessMode::Read)
87 }
88
89 pub fn to_wgsl(&self) -> String {
91 let type_str = self.ty.to_wgsl();
92
93 match &self.ty {
94 WgslType::Array { .. } => {
95 format!(
97 "@group({}) @binding({}) var<storage, {}> {}: {};",
98 self.group,
99 self.binding,
100 self.access.to_wgsl(),
101 self.name,
102 type_str
103 )
104 }
105 WgslType::Struct(_) if self.access == AccessMode::Read => {
106 format!(
108 "@group({}) @binding({}) var<uniform> {}: {};",
109 self.group, self.binding, self.name, type_str
110 )
111 }
112 _ => {
113 format!(
115 "@group({}) @binding({}) var<storage, {}> {}: {};",
116 self.group,
117 self.binding,
118 self.access.to_wgsl(),
119 self.name,
120 type_str
121 )
122 }
123 }
124 }
125}
126
127pub fn generate_bindings(bindings: &[BindingLayout]) -> String {
129 bindings
130 .iter()
131 .map(|b| b.to_wgsl())
132 .collect::<Vec<_>>()
133 .join("\n")
134}
135
136pub fn bindings_from_params(params: &[(String, WgslType, bool)]) -> Vec<BindingLayout> {
140 let mut bindings = Vec::new();
141 let mut binding_idx = 0u32;
142
143 for (name, ty, is_mutable) in params {
144 match ty {
145 WgslType::Ptr { inner, .. } => {
146 let access = if *is_mutable {
147 AccessMode::ReadWrite
148 } else {
149 AccessMode::Read
150 };
151 bindings.push(BindingLayout::new(
152 0,
153 binding_idx,
154 name,
155 WgslType::Array {
156 element: inner.clone(),
157 size: None,
158 },
159 access,
160 ));
161 binding_idx += 1;
162 }
163 WgslType::Array { element, .. } => {
164 let access = if *is_mutable {
165 AccessMode::ReadWrite
166 } else {
167 AccessMode::Read
168 };
169 bindings.push(BindingLayout::new(
170 0,
171 binding_idx,
172 name,
173 WgslType::Array {
174 element: element.clone(),
175 size: None,
176 },
177 access,
178 ));
179 binding_idx += 1;
180 }
181 _ => {
184 bindings.push(BindingLayout::new(
185 0,
186 binding_idx,
187 name,
188 ty.clone(),
189 AccessMode::Read,
190 ));
191 binding_idx += 1;
192 }
193 }
194 }
195
196 bindings
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_storage_read_binding() {
        let binding = BindingLayout::storage_read(0, "input", WgslType::F32);
        assert_eq!(
            binding.to_wgsl(),
            "@group(0) @binding(0) var<storage, read> input: array<f32>;"
        );
    }

    #[test]
    fn test_storage_read_write_binding() {
        let binding = BindingLayout::storage_read_write(1, "output", WgslType::F32);
        assert_eq!(
            binding.to_wgsl(),
            "@group(0) @binding(1) var<storage, read_write> output: array<f32>;"
        );
    }

    #[test]
    fn test_generate_bindings() {
        let bindings = vec![
            BindingLayout::storage_read(0, "input", WgslType::F32),
            BindingLayout::storage_read_write(1, "output", WgslType::F32),
        ];

        let wgsl = generate_bindings(&bindings);
        assert!(wgsl.contains("@binding(0)"));
        assert!(wgsl.contains("@binding(1)"));
        assert!(wgsl.contains("read>"));
        assert!(wgsl.contains("read_write>"));
    }

    #[test]
    fn test_generate_bindings_empty() {
        // No bindings means no declarations and no stray newline.
        assert_eq!(generate_bindings(&[]), "");
    }

    #[test]
    fn test_bindings_from_params_indices_and_access() {
        let params = vec![
            (
                "data".to_string(),
                WgslType::Array {
                    element: Box::new(WgslType::F32),
                    size: None,
                },
                true,
            ),
            ("scale".to_string(), WgslType::F32, true),
        ];

        let bindings = bindings_from_params(&params);
        assert_eq!(bindings.len(), 2);

        // Mutable array parameter -> read_write runtime-sized storage array.
        assert_eq!(bindings[0].binding, 0);
        assert_eq!(bindings[0].access, AccessMode::ReadWrite);
        assert_eq!(
            bindings[0].to_wgsl(),
            "@group(0) @binding(0) var<storage, read_write> data: array<f32>;"
        );

        // Non-array parameter keeps its type and is read-only even when mutable.
        assert_eq!(bindings[1].group, 0);
        assert_eq!(bindings[1].binding, 1);
        assert_eq!(bindings[1].access, AccessMode::Read);
        assert!(matches!(bindings[1].ty, WgslType::F32));
    }
}