daedalus_wgsl_infer/
lib.rs

/// Resource kind and access mode inferred for a single WGSL binding.
///
/// When a value is produced by [`infer_bindings`], string payloads hold the
/// lowercased text taken from the shader (for example a texel format such as
/// `"rgba8unorm"` or a sample type such as `"f32"`); `None` means the
/// information was not present.
///
/// ```
/// use daedalus_wgsl_infer::InferredAccess;
/// let access = InferredAccess::StorageRead;
/// assert!(matches!(access, InferredAccess::StorageRead));
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum InferredAccess {
    /// Read-only storage buffer (`var<storage>` / `var<storage, read>`).
    StorageRead,
    /// Read-write storage buffer (`var<storage, read_write>`).
    StorageReadWrite,
    /// Storage buffer declared with `write` access only.
    StorageWrite,
    /// Uniform buffer (`var<uniform>`).
    Uniform,
    /// Storage texture, e.g. `texture_storage_2d<rgba8unorm, write>`.
    StorageTexture {
        /// Texel format, e.g. `"rgba8unorm"`.
        format: Option<String>,
        /// View dimension, e.g. `"2d"`.
        view: Option<String>,
    },
    /// Sampled texture, e.g. `texture_2d<f32>`.
    Texture {
        /// Texel format; [`infer_bindings`] leaves this `None`.
        format: Option<String>,
        /// Sampled component type, e.g. `"f32"`.
        sample_type: Option<String>,
        /// View dimension, e.g. `"2d"` or `"2d_array"`.
        view: Option<String>,
    },
    /// Sampler; [`infer_bindings`] reports `"filtering"` for `sampler` and
    /// `"comparison"` for `sampler_comparison`.
    Sampler(Option<String>),
}
25
26/// Inferred binding metadata.
27///
28/// ```
29/// use daedalus_wgsl_infer::{InferredAccess, InferredBinding};
30/// let binding = InferredBinding { binding: 0, access: InferredAccess::Uniform };
31/// assert_eq!(binding.binding, 0);
32/// ```
33#[derive(Clone, Debug, PartialEq, Eq)]
34pub struct InferredBinding {
35    pub binding: u32,
36    pub access: InferredAccess,
37}
38
39/// Inferred workgroup and bindings for a WGSL entry point.
40///
41/// ```
42/// use daedalus_wgsl_infer::InferredSpec;
43/// let spec = InferredSpec { workgroup: None, bindings: vec![] };
44/// assert!(spec.bindings.is_empty());
45/// ```
46#[derive(Clone, Debug, PartialEq, Eq)]
47pub struct InferredSpec {
48    pub workgroup: Option<[u32; 3]>,
49    pub bindings: Vec<InferredBinding>,
50}
51
/// Infer the `@workgroup_size` annotation from WGSL source.
///
/// Only the first `@workgroup_size` occurrence is considered. The arguments
/// inside its parentheses are parsed as WGSL integer literals: decimal or
/// `0x` hexadecimal, optionally suffixed with `u` or `i`. A missing or
/// unparseable `y`/`z` component defaults to `1`, and a `y`/`z` of `0` is
/// raised to `1`.
///
/// Returns `None` when there is no annotation, or when the `x` component is
/// missing, zero, or not an integer literal (e.g. a named constant).
///
/// ```
/// use daedalus_wgsl_infer::infer_workgroup_size;
/// let wgsl = "@compute @workgroup_size(8, 4, 1) fn main() {}";
/// assert_eq!(infer_workgroup_size(wgsl), Some([8, 4, 1]));
/// ```
pub fn infer_workgroup_size(src: &str) -> Option<[u32; 3]> {
    let idx = src.find("@workgroup_size")?;
    let rest = &src[idx..];
    let open = rest.find('(')?;
    let args = &rest[open + 1..];
    // Restrict parsing to the argument list: text after `)` belongs to the
    // function signature and must never be read as a dimension.
    let args = args.find(')').map_or(args, |close| &args[..close]);
    let mut dims = args.split(',').map(parse_wgsl_u32);
    let x = dims.next().flatten().unwrap_or(0);
    let y = dims.next().flatten().unwrap_or(1);
    let z = dims.next().flatten().unwrap_or(1);
    (x > 0).then_some([x, y.max(1), z.max(1)])
}

/// Parse a WGSL integer literal (`64`, `64u`, `0x40`, `0x40u`) as `u32`.
fn parse_wgsl_u32(token: &str) -> Option<u32> {
    let token = token.trim();
    // WGSL allows a `u` (u32) or `i` (i32) suffix on integer literals.
    let token = token
        .strip_suffix(|c: char| c == 'u' || c == 'i')
        .unwrap_or(token);
    match token.strip_prefix("0x").or_else(|| token.strip_prefix("0X")) {
        Some(hex) => u32::from_str_radix(hex, 16).ok(),
        None => token.parse().ok(),
    }
}
84
85/// Infer bindings from WGSL source.
86///
87/// ```
88/// use daedalus_wgsl_infer::infer_bindings;
89/// let wgsl = "@group(0) @binding(0) var<uniform> Params: vec4<f32>;";
90/// let bindings = infer_bindings(wgsl);
91/// assert_eq!(bindings.len(), 1);
92/// ```
93pub fn infer_bindings(src: &str) -> Vec<InferredBinding> {
94    let mut bindings = Vec::new();
95    let mut offset = 0;
96    while let Some(rel_idx) = src[offset..].find("@binding") {
97        let idx = offset + rel_idx;
98        let rest = &src[idx..];
99        let decl = if let Some(end) = rest.find(';') {
100            &rest[..=end]
101        } else {
102            rest
103        };
104        let num = rest
105            .split(|c: char| !c.is_ascii_digit())
106            .find_map(|chunk| chunk.parse::<u32>().ok());
107        let Some(binding) = num else {
108            offset = idx + "@binding".len();
109            continue;
110        };
111        let lower = decl.to_ascii_lowercase();
112        let mut format = None;
113        let mut sample_type = None;
114        let access = if lower.contains("var<storage") {
115            if lower.contains("read_write") {
116                InferredAccess::StorageReadWrite
117            } else if lower.contains("write") && !lower.contains("read") {
118                InferredAccess::StorageWrite
119            } else {
120                InferredAccess::StorageRead
121            }
122        } else if lower.contains("var<uniform") {
123            InferredAccess::Uniform
124        } else if lower.contains("texture_storage_2d") {
125            let view = Some("2d".to_string());
126            if let Some(lt) = lower.find('<')
127                && let Some(gt) = lower[lt..].find('>')
128            {
129                let inner = &lower[lt + 1..lt + gt];
130                let parts: Vec<&str> = inner.split(',').map(|s| s.trim()).collect();
131                if let Some(fmt) = parts.first() {
132                    format = Some(fmt.to_string());
133                }
134            }
135            InferredAccess::StorageTexture { format, view }
136        } else if lower.contains("texture_2d_array") {
137            let view = Some("2d_array".to_string());
138            if let Some(lt) = lower.find('<')
139                && let Some(gt) = lower[lt..].find('>')
140            {
141                let inner = &lower[lt + 1..lt + gt];
142                sample_type = Some(inner.trim().to_string());
143            }
144            InferredAccess::Texture {
145                format: None,
146                sample_type,
147                view,
148            }
149        } else if lower.contains("texture_2d") {
150            let view = Some("2d".to_string());
151            if let Some(lt) = lower.find('<')
152                && let Some(gt) = lower[lt..].find('>')
153            {
154                let inner = &lower[lt + 1..lt + gt];
155                sample_type = Some(inner.trim().to_string());
156            }
157            InferredAccess::Texture {
158                format: None,
159                sample_type,
160                view,
161            }
162        } else if lower.contains("sampler_comparison") {
163            InferredAccess::Sampler(Some("comparison".into()))
164        } else if lower.contains("sampler") {
165            InferredAccess::Sampler(Some("filtering".into()))
166        } else {
167            offset = idx + "@binding".len();
168            continue;
169        };
170        bindings.push(InferredBinding { binding, access });
171        offset = idx + decl.len();
172    }
173    bindings
174}
175
176/// Infer both workgroup size and bindings from WGSL source.
177///
178/// ```
179/// use daedalus_wgsl_infer::infer_spec;
180/// let wgsl = "@compute @workgroup_size(1) fn main() {}";
181/// let spec = infer_spec(wgsl);
182/// assert!(spec.workgroup.is_some());
183/// ```
184pub fn infer_spec(src: &str) -> InferredSpec {
185    InferredSpec {
186        workgroup: infer_workgroup_size(src),
187        bindings: infer_bindings(src),
188    }
189}