1use thiserror::Error;
7
8#[derive(Error, Debug)]
10pub enum ValidationError {
11 #[error("Unsupported: {0}")]
13 Unsupported(String),
14
15 #[error("Invalid DSL usage: {0}")]
17 InvalidDsl(String),
18
19 #[error("Loops not allowed in {0} mode")]
21 LoopNotAllowed(String),
22}
23
/// Selects which family of kernels a function is being validated for.
///
/// The mode controls structural rules enforced by the validator, chiefly
/// whether loop expressions (`for`, `while`, `loop`) are accepted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ValidationMode {
    /// Stencil kernels: loop expressions are rejected.
    Stencil,

    /// General-purpose kernels. This is the default mode; loops are accepted.
    #[default]
    Generic,

    /// Ring kernels: loops are accepted, and `requires_loops` reports `true`.
    RingKernel,
}

impl ValidationMode {
    /// Returns `true` when loop expressions are acceptable in this mode.
    ///
    /// Every mode except [`ValidationMode::Stencil`] accepts loops.
    pub fn allows_loops(&self) -> bool {
        !matches!(self, Self::Stencil)
    }

    /// Returns `true` when this mode requires the kernel to contain loops.
    ///
    /// Only [`ValidationMode::RingKernel`] reports `true`. The validator
    /// functions in this module do not enforce this themselves.
    pub fn requires_loops(&self) -> bool {
        match self {
            Self::RingKernel => true,
            Self::Stencil | Self::Generic => false,
        }
    }
}
56
57pub fn validate_function(func: &syn::ItemFn) -> Result<(), ValidationError> {
59 validate_function_with_mode(func, ValidationMode::default())
60}
61
62pub fn validate_function_with_mode(
64 func: &syn::ItemFn,
65 mode: ValidationMode,
66) -> Result<(), ValidationError> {
67 for attr in &func.attrs {
69 let path = attr.path().to_token_stream().to_string();
70 if path != "doc" && path != "allow" && path != "cfg" {
71 return Err(ValidationError::Unsupported(format!(
72 "Function attribute: {}",
73 path
74 )));
75 }
76 }
77
78 if func.sig.asyncness.is_some() {
80 return Err(ValidationError::Unsupported(
81 "Async functions are not supported in WGSL".to_string(),
82 ));
83 }
84
85 if !func.sig.generics.params.is_empty() {
87 return Err(ValidationError::Unsupported(
88 "Generic functions are not supported in WGSL".to_string(),
89 ));
90 }
91
92 if func.sig.variadic.is_some() {
94 return Err(ValidationError::Unsupported(
95 "Variadic functions are not supported in WGSL".to_string(),
96 ));
97 }
98
99 validate_block(&func.block, mode)?;
101
102 Ok(())
103}
104
105fn validate_block(block: &syn::Block, mode: ValidationMode) -> Result<(), ValidationError> {
107 for stmt in &block.stmts {
108 validate_stmt(stmt, mode)?;
109 }
110 Ok(())
111}
112
113fn validate_stmt(stmt: &syn::Stmt, mode: ValidationMode) -> Result<(), ValidationError> {
115 match stmt {
116 syn::Stmt::Local(local) => {
117 if let Some(init) = &local.init {
119 validate_expr(&init.expr, mode)?;
120 }
121 Ok(())
122 }
123 syn::Stmt::Expr(expr, _) => validate_expr(expr, mode),
124 syn::Stmt::Item(_) => Err(ValidationError::Unsupported(
125 "Nested items are not supported".to_string(),
126 )),
127 syn::Stmt::Macro(_) => Err(ValidationError::Unsupported(
128 "Macros are not supported in WGSL DSL".to_string(),
129 )),
130 }
131}
132
133fn validate_expr(expr: &syn::Expr, mode: ValidationMode) -> Result<(), ValidationError> {
135 match expr {
136 syn::Expr::Lit(_) => Ok(()),
138 syn::Expr::Path(_) => Ok(()),
139 syn::Expr::Paren(p) => validate_expr(&p.expr, mode),
140
141 syn::Expr::Binary(bin) => {
143 validate_expr(&bin.left, mode)?;
144 validate_expr(&bin.right, mode)
145 }
146 syn::Expr::Unary(unary) => validate_expr(&unary.expr, mode),
147
148 syn::Expr::Index(idx) => {
150 validate_expr(&idx.expr, mode)?;
151 validate_expr(&idx.index, mode)
152 }
153
154 syn::Expr::Call(call) => {
156 validate_expr(&call.func, mode)?;
157 for arg in &call.args {
158 validate_expr(arg, mode)?;
159 }
160 Ok(())
161 }
162 syn::Expr::MethodCall(method) => {
163 validate_expr(&method.receiver, mode)?;
164 for arg in &method.args {
165 validate_expr(arg, mode)?;
166 }
167 Ok(())
168 }
169
170 syn::Expr::If(if_expr) => {
172 validate_expr(&if_expr.cond, mode)?;
173 validate_block(&if_expr.then_branch, mode)?;
174 if let Some((_, else_branch)) = &if_expr.else_branch {
175 validate_expr(else_branch, mode)?;
176 }
177 Ok(())
178 }
179 syn::Expr::Block(block) => validate_block(&block.block, mode),
180
181 syn::Expr::ForLoop(for_loop) => {
183 if !mode.allows_loops() {
184 return Err(ValidationError::LoopNotAllowed("stencil".to_string()));
185 }
186 validate_expr(&for_loop.expr, mode)?;
187 validate_block(&for_loop.body, mode)
188 }
189 syn::Expr::While(while_loop) => {
190 if !mode.allows_loops() {
191 return Err(ValidationError::LoopNotAllowed("stencil".to_string()));
192 }
193 validate_expr(&while_loop.cond, mode)?;
194 validate_block(&while_loop.body, mode)
195 }
196 syn::Expr::Loop(loop_expr) => {
197 if !mode.allows_loops() {
198 return Err(ValidationError::LoopNotAllowed("stencil".to_string()));
199 }
200 validate_block(&loop_expr.body, mode)
201 }
202
203 syn::Expr::Return(ret) => {
205 if let Some(expr) = &ret.expr {
206 validate_expr(expr, mode)?;
207 }
208 Ok(())
209 }
210 syn::Expr::Break(_) => Ok(()),
211 syn::Expr::Continue(_) => Ok(()),
212
213 syn::Expr::Assign(assign) => {
215 validate_expr(&assign.left, mode)?;
216 validate_expr(&assign.right, mode)
217 }
218
219 syn::Expr::Cast(cast) => validate_expr(&cast.expr, mode),
221
222 syn::Expr::Struct(struct_expr) => {
224 for field in &struct_expr.fields {
225 validate_expr(&field.expr, mode)?;
226 }
227 Ok(())
228 }
229
230 syn::Expr::Field(field) => validate_expr(&field.base, mode),
232
233 syn::Expr::Match(match_expr) => {
235 validate_expr(&match_expr.expr, mode)?;
236 for arm in &match_expr.arms {
237 validate_expr(&arm.body, mode)?;
238 }
239 Ok(())
240 }
241
242 syn::Expr::Range(range) => {
244 if let Some(start) = &range.start {
245 validate_expr(start, mode)?;
246 }
247 if let Some(end) = &range.end {
248 validate_expr(end, mode)?;
249 }
250 Ok(())
251 }
252
253 syn::Expr::Reference(ref_expr) => validate_expr(&ref_expr.expr, mode),
255
256 _ => Err(ValidationError::Unsupported(format!(
258 "Expression type: {}",
259 quote::ToTokens::to_token_stream(expr)
260 ))),
261 }
262}
263
264use quote::ToTokens;
265
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    #[test]
    fn test_simple_function_validates() {
        let func: syn::ItemFn = parse_quote! {
            fn add(a: f32, b: f32) -> f32 {
                a + b
            }
        };
        assert!(validate_function(&func).is_ok());
    }

    #[test]
    fn test_async_function_rejected() {
        let func: syn::ItemFn = parse_quote! {
            async fn process() {}
        };
        assert!(validate_function(&func).is_err());
    }

    #[test]
    fn test_generic_function_rejected() {
        let func: syn::ItemFn = parse_quote! {
            fn generic<T>(x: T) -> T { x }
        };
        assert!(validate_function(&func).is_err());
    }

    #[test]
    fn test_stencil_mode_rejects_loops() {
        let func: syn::ItemFn = parse_quote! {
            fn with_loop() {
                for i in 0..10 {
                    process(i);
                }
            }
        };
        assert!(validate_function_with_mode(&func, ValidationMode::Stencil).is_err());
    }

    #[test]
    fn test_generic_mode_allows_loops() {
        let func: syn::ItemFn = parse_quote! {
            fn with_loop() {
                for i in 0..10 {
                    process(i);
                }
            }
        };
        assert!(validate_function_with_mode(&func, ValidationMode::Generic).is_ok());
    }

    // Loops nested inside other control flow must still be caught, and the
    // error must name the offending mode.
    #[test]
    fn test_stencil_mode_rejects_nested_while() {
        let func: syn::ItemFn = parse_quote! {
            fn f(x: f32) -> f32 {
                if x > 0.0 {
                    while x > 1.0 {}
                }
                x
            }
        };
        assert!(matches!(
            validate_function_with_mode(&func, ValidationMode::Stencil),
            Err(ValidationError::LoopNotAllowed(ref mode)) if mode == "stencil"
        ));
    }

    #[test]
    fn test_ring_kernel_mode_allows_loop() {
        let func: syn::ItemFn = parse_quote! {
            fn run() {
                loop {
                    break;
                }
            }
        };
        assert!(validate_function_with_mode(&func, ValidationMode::RingKernel).is_ok());
    }

    #[test]
    fn test_disallowed_attribute_rejected() {
        let func: syn::ItemFn = parse_quote! {
            #[inline]
            fn f(x: f32) -> f32 { x }
        };
        assert!(matches!(
            validate_function(&func),
            Err(ValidationError::Unsupported(ref msg)) if msg == "Function attribute: inline"
        ));
    }

    #[test]
    fn test_doc_and_allow_attributes_accepted() {
        let func: syn::ItemFn = parse_quote! {
            /// Adds one.
            #[allow(unused)]
            fn inc(x: f32) -> f32 { x + 1.0 }
        };
        assert!(validate_function(&func).is_ok());
    }

    #[test]
    fn test_macro_statement_rejected() {
        let func: syn::ItemFn = parse_quote! {
            fn f() {
                println!("hi");
            }
        };
        assert!(validate_function(&func).is_err());
    }

    #[test]
    fn test_nested_item_rejected() {
        let func: syn::ItemFn = parse_quote! {
            fn outer() {
                fn inner() {}
            }
        };
        assert!(validate_function(&func).is_err());
    }

    #[test]
    fn test_closure_rejected() {
        let func: syn::ItemFn = parse_quote! {
            fn f() -> f32 {
                let g = |x: f32| x;
                g(1.0)
            }
        };
        assert!(matches!(
            validate_function(&func),
            Err(ValidationError::Unsupported(_))
        ));
    }
}