// daedalus_wgsl_infer/lib.rs
/// How a shader resource binding is accessed, as inferred from its WGSL declaration by
/// [`infer_bindings`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum InferredAccess {
    /// Read-only storage buffer (`var<storage>` or `var<storage, read>`).
    StorageRead,
    /// Read-write storage buffer (`var<storage, read_write>`).
    StorageReadWrite,
    /// Storage buffer whose declaration was classified as write-only.
    StorageWrite,
    /// Uniform buffer (`var<uniform>`).
    Uniform,
    /// Storage texture, e.g. `texture_storage_2d<rgba8unorm, write>`.
    StorageTexture {
        /// Texel format taken from the first template argument, e.g. `"rgba8unorm"`.
        format: Option<String>,
        /// View dimension, e.g. `"2d"` or `"2d_array"`.
        view: Option<String>,
    },
    /// Sampled texture, e.g. `texture_2d<f32>` or `texture_2d_array<f32>`.
    Texture {
        /// Texel format; [`infer_bindings`] never fills this in and always sets `None`.
        format: Option<String>,
        /// Sample type taken from the template argument, e.g. `"f32"`.
        sample_type: Option<String>,
        /// View dimension, e.g. `"2d"` or `"2d_array"`.
        view: Option<String>,
    },
    /// Sampler. When produced by [`infer_bindings`] the payload is `"filtering"` for
    /// `sampler` and `"comparison"` for `sampler_comparison`.
    Sampler(Option<String>),
}
25
26#[derive(Clone, Debug, PartialEq, Eq)]
34pub struct InferredBinding {
35 pub binding: u32,
36 pub access: InferredAccess,
37}
38
39#[derive(Clone, Debug, PartialEq, Eq)]
47pub struct InferredSpec {
48 pub workgroup: Option<[u32; 3]>,
49 pub bindings: Vec<InferredBinding>,
50}
51
/// Extracts the compute workgroup size from the first `@workgroup_size(...)` attribute in `src`.
///
/// Only the comma-separated arguments inside the attribute's parentheses are read (up to the
/// closing `)`, or to the end of `src` if it is missing). Each argument may be a decimal or
/// `0x` hexadecimal integer literal with an optional WGSL `u`/`i` suffix. Missing, unparsable
/// or zero `y`/`z` dimensions become `1`.
///
/// Returns `None` when there is no `@workgroup_size` attribute, no `(` follows it, or the `x`
/// dimension is missing, unparsable or zero.
///
/// This is a textual scan rather than a WGSL parser: an attribute inside a comment is matched
/// as well, and dimensions given as identifiers or override expressions are not evaluated.
pub fn infer_workgroup_size(src: &str) -> Option<[u32; 3]> {
    /// Parses a WGSL integer literal such as `64`, `64u`, `64i` or `0x40`.
    fn parse_wgsl_uint(text: &str) -> Option<u32> {
        let text = text.trim();
        let digits = text
            .strip_suffix(|c: char| c == 'u' || c == 'i')
            .unwrap_or(text);
        match digits.strip_prefix("0x").or_else(|| digits.strip_prefix("0X")) {
            Some(hex) => u32::from_str_radix(hex, 16).ok(),
            None => digits.parse().ok(),
        }
    }

    let attr = &src[src.find("@workgroup_size")?..];
    let args = &attr[attr.find('(')? + 1..];
    // Stop at the closing parenthesis so text after the attribute cannot leak into y/z.
    let args = args.find(')').map_or(args, |end| &args[..end]);
    let mut dims = args.split(',').map(parse_wgsl_uint);
    let x = dims.next().flatten().unwrap_or(0);
    let y = dims.next().flatten().unwrap_or(1);
    let z = dims.next().flatten().unwrap_or(1);
    (x > 0).then(|| [x, y.max(1), z.max(1)])
}
84
85pub fn infer_bindings(src: &str) -> Vec<InferredBinding> {
94 let mut bindings = Vec::new();
95 let mut offset = 0;
96 while let Some(rel_idx) = src[offset..].find("@binding") {
97 let idx = offset + rel_idx;
98 let rest = &src[idx..];
99 let decl = if let Some(end) = rest.find(';') {
100 &rest[..=end]
101 } else {
102 rest
103 };
104 let num = rest
105 .split(|c: char| !c.is_ascii_digit())
106 .find_map(|chunk| chunk.parse::<u32>().ok());
107 let Some(binding) = num else {
108 offset = idx + "@binding".len();
109 continue;
110 };
111 let lower = decl.to_ascii_lowercase();
112 let mut format = None;
113 let mut sample_type = None;
114 let access = if lower.contains("var<storage") {
115 if lower.contains("read_write") {
116 InferredAccess::StorageReadWrite
117 } else if lower.contains("write") && !lower.contains("read") {
118 InferredAccess::StorageWrite
119 } else {
120 InferredAccess::StorageRead
121 }
122 } else if lower.contains("var<uniform") {
123 InferredAccess::Uniform
124 } else if lower.contains("texture_storage_2d") {
125 let view = Some("2d".to_string());
126 if let Some(lt) = lower.find('<')
127 && let Some(gt) = lower[lt..].find('>')
128 {
129 let inner = &lower[lt + 1..lt + gt];
130 let parts: Vec<&str> = inner.split(',').map(|s| s.trim()).collect();
131 if let Some(fmt) = parts.first() {
132 format = Some(fmt.to_string());
133 }
134 }
135 InferredAccess::StorageTexture { format, view }
136 } else if lower.contains("texture_2d_array") {
137 let view = Some("2d_array".to_string());
138 if let Some(lt) = lower.find('<')
139 && let Some(gt) = lower[lt..].find('>')
140 {
141 let inner = &lower[lt + 1..lt + gt];
142 sample_type = Some(inner.trim().to_string());
143 }
144 InferredAccess::Texture {
145 format: None,
146 sample_type,
147 view,
148 }
149 } else if lower.contains("texture_2d") {
150 let view = Some("2d".to_string());
151 if let Some(lt) = lower.find('<')
152 && let Some(gt) = lower[lt..].find('>')
153 {
154 let inner = &lower[lt + 1..lt + gt];
155 sample_type = Some(inner.trim().to_string());
156 }
157 InferredAccess::Texture {
158 format: None,
159 sample_type,
160 view,
161 }
162 } else if lower.contains("sampler_comparison") {
163 InferredAccess::Sampler(Some("comparison".into()))
164 } else if lower.contains("sampler") {
165 InferredAccess::Sampler(Some("filtering".into()))
166 } else {
167 offset = idx + "@binding".len();
168 continue;
169 };
170 bindings.push(InferredBinding { binding, access });
171 offset = idx + decl.len();
172 }
173 bindings
174}
175
176pub fn infer_spec(src: &str) -> InferredSpec {
185 InferredSpec {
186 workgroup: infer_workgroup_size(src),
187 bindings: infer_bindings(src),
188 }
189}