Skip to main content

ringkernel_wgpu_codegen/
validation.rs

1//! DSL validation for WGSL code generation.
2//!
3//! This module provides validation for the Rust DSL subset that can be
4//! transpiled to WGSL.
5
6use thiserror::Error;
7
8/// Errors that can occur during validation.
9#[derive(Error, Debug)]
10pub enum ValidationError {
11    /// Unsupported Rust construct.
12    #[error("Unsupported: {0}")]
13    Unsupported(String),
14
15    /// Invalid DSL usage.
16    #[error("Invalid DSL usage: {0}")]
17    InvalidDsl(String),
18
19    /// Loop not allowed in this mode.
20    #[error("Loops not allowed in {0} mode")]
21    LoopNotAllowed(String),
22}
23
/// Selects which Rust constructs the validator accepts.
///
/// Within this module the mode only influences how loop expressions are
/// treated; all other checks are identical across modes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ValidationMode {
    /// Grid-based stencil kernels, where each thread handles a single cell.
    /// Loop constructs are rejected.
    Stencil,

    /// General-purpose compute shaders; loop constructs are accepted.
    /// This is the default mode.
    #[default]
    Generic,

    /// Persistent ring kernels that process messages. Loops are accepted, and
    /// the mode reports that it needs one for the persistent message loop.
    RingKernel,
}

impl ValidationMode {
    /// Returns `true` when `for`, `while`, and `loop` expressions are
    /// permitted, i.e. for every mode except [`ValidationMode::Stencil`].
    pub fn allows_loops(&self) -> bool {
        !matches!(self, ValidationMode::Stencil)
    }

    /// Returns `true` only for [`ValidationMode::RingKernel`], whose
    /// persistent message loop makes a loop mandatory.
    ///
    /// NOTE(review): the validation functions in this module never call this
    /// predicate — confirm where (if anywhere) the requirement is enforced.
    pub fn requires_loops(&self) -> bool {
        match self {
            ValidationMode::RingKernel => true,
            ValidationMode::Stencil | ValidationMode::Generic => false,
        }
    }
}
56
57/// Validate a function for WGSL transpilation.
58pub fn validate_function(func: &syn::ItemFn) -> Result<(), ValidationError> {
59    validate_function_with_mode(func, ValidationMode::default())
60}
61
62/// Validate a function with a specific validation mode.
63pub fn validate_function_with_mode(
64    func: &syn::ItemFn,
65    mode: ValidationMode,
66) -> Result<(), ValidationError> {
67    // Check for unsupported attributes
68    for attr in &func.attrs {
69        let path = attr.path().to_token_stream().to_string();
70        if path != "doc" && path != "allow" && path != "cfg" {
71            return Err(ValidationError::Unsupported(format!(
72                "Function attribute: {}",
73                path
74            )));
75        }
76    }
77
78    // Check for async functions (not supported)
79    if func.sig.asyncness.is_some() {
80        return Err(ValidationError::Unsupported(
81            "Async functions are not supported in WGSL".to_string(),
82        ));
83    }
84
85    // Check for generics (not supported)
86    if !func.sig.generics.params.is_empty() {
87        return Err(ValidationError::Unsupported(
88            "Generic functions are not supported in WGSL".to_string(),
89        ));
90    }
91
92    // Check for variadic functions (not supported)
93    if func.sig.variadic.is_some() {
94        return Err(ValidationError::Unsupported(
95            "Variadic functions are not supported in WGSL".to_string(),
96        ));
97    }
98
99    // Validate the function body
100    validate_block(&func.block, mode)?;
101
102    Ok(())
103}
104
105/// Validate a block of statements.
106fn validate_block(block: &syn::Block, mode: ValidationMode) -> Result<(), ValidationError> {
107    for stmt in &block.stmts {
108        validate_stmt(stmt, mode)?;
109    }
110    Ok(())
111}
112
113/// Validate a single statement.
114fn validate_stmt(stmt: &syn::Stmt, mode: ValidationMode) -> Result<(), ValidationError> {
115    match stmt {
116        syn::Stmt::Local(local) => {
117            // Validate the initializer if present
118            if let Some(init) = &local.init {
119                validate_expr(&init.expr, mode)?;
120            }
121            Ok(())
122        }
123        syn::Stmt::Expr(expr, _) => validate_expr(expr, mode),
124        syn::Stmt::Item(_) => Err(ValidationError::Unsupported(
125            "Nested items are not supported".to_string(),
126        )),
127        syn::Stmt::Macro(_) => Err(ValidationError::Unsupported(
128            "Macros are not supported in WGSL DSL".to_string(),
129        )),
130    }
131}
132
133/// Validate an expression.
134fn validate_expr(expr: &syn::Expr, mode: ValidationMode) -> Result<(), ValidationError> {
135    match expr {
136        // Simple expressions - always allowed
137        syn::Expr::Lit(_) => Ok(()),
138        syn::Expr::Path(_) => Ok(()),
139        syn::Expr::Paren(p) => validate_expr(&p.expr, mode),
140
141        // Binary and unary operations
142        syn::Expr::Binary(bin) => {
143            validate_expr(&bin.left, mode)?;
144            validate_expr(&bin.right, mode)
145        }
146        syn::Expr::Unary(unary) => validate_expr(&unary.expr, mode),
147
148        // Array indexing
149        syn::Expr::Index(idx) => {
150            validate_expr(&idx.expr, mode)?;
151            validate_expr(&idx.index, mode)
152        }
153
154        // Function and method calls
155        syn::Expr::Call(call) => {
156            validate_expr(&call.func, mode)?;
157            for arg in &call.args {
158                validate_expr(arg, mode)?;
159            }
160            Ok(())
161        }
162        syn::Expr::MethodCall(method) => {
163            validate_expr(&method.receiver, mode)?;
164            for arg in &method.args {
165                validate_expr(arg, mode)?;
166            }
167            Ok(())
168        }
169
170        // Control flow
171        syn::Expr::If(if_expr) => {
172            validate_expr(&if_expr.cond, mode)?;
173            validate_block(&if_expr.then_branch, mode)?;
174            if let Some((_, else_branch)) = &if_expr.else_branch {
175                validate_expr(else_branch, mode)?;
176            }
177            Ok(())
178        }
179        syn::Expr::Block(block) => validate_block(&block.block, mode),
180
181        // Loops - mode-dependent
182        syn::Expr::ForLoop(for_loop) => {
183            if !mode.allows_loops() {
184                return Err(ValidationError::LoopNotAllowed("stencil".to_string()));
185            }
186            validate_expr(&for_loop.expr, mode)?;
187            validate_block(&for_loop.body, mode)
188        }
189        syn::Expr::While(while_loop) => {
190            if !mode.allows_loops() {
191                return Err(ValidationError::LoopNotAllowed("stencil".to_string()));
192            }
193            validate_expr(&while_loop.cond, mode)?;
194            validate_block(&while_loop.body, mode)
195        }
196        syn::Expr::Loop(loop_expr) => {
197            if !mode.allows_loops() {
198                return Err(ValidationError::LoopNotAllowed("stencil".to_string()));
199            }
200            validate_block(&loop_expr.body, mode)
201        }
202
203        // Return and control
204        syn::Expr::Return(ret) => {
205            if let Some(expr) = &ret.expr {
206                validate_expr(expr, mode)?;
207            }
208            Ok(())
209        }
210        syn::Expr::Break(_) => Ok(()),
211        syn::Expr::Continue(_) => Ok(()),
212
213        // Assignments
214        syn::Expr::Assign(assign) => {
215            validate_expr(&assign.left, mode)?;
216            validate_expr(&assign.right, mode)
217        }
218
219        // Type casts
220        syn::Expr::Cast(cast) => validate_expr(&cast.expr, mode),
221
222        // Struct literals
223        syn::Expr::Struct(struct_expr) => {
224            for field in &struct_expr.fields {
225                validate_expr(&field.expr, mode)?;
226            }
227            Ok(())
228        }
229
230        // Field access
231        syn::Expr::Field(field) => validate_expr(&field.base, mode),
232
233        // Match expressions
234        syn::Expr::Match(match_expr) => {
235            validate_expr(&match_expr.expr, mode)?;
236            for arm in &match_expr.arms {
237                validate_expr(&arm.body, mode)?;
238            }
239            Ok(())
240        }
241
242        // Range expressions (for loops)
243        syn::Expr::Range(range) => {
244            if let Some(start) = &range.start {
245                validate_expr(start, mode)?;
246            }
247            if let Some(end) = &range.end {
248                validate_expr(end, mode)?;
249            }
250            Ok(())
251        }
252
253        // Reference and dereference
254        syn::Expr::Reference(ref_expr) => validate_expr(&ref_expr.expr, mode),
255
256        // Unsupported expressions
257        _ => Err(ValidationError::Unsupported(format!(
258            "Expression type: {}",
259            quote::ToTokens::to_token_stream(expr)
260        ))),
261    }
262}
263
264use quote::ToTokens;
265
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    #[test]
    fn test_simple_function_validates() {
        let func: syn::ItemFn = parse_quote! {
            fn add(a: f32, b: f32) -> f32 {
                a + b
            }
        };
        assert!(validate_function(&func).is_ok());
    }

    #[test]
    fn test_async_function_rejected() {
        let func: syn::ItemFn = parse_quote! {
            async fn process() {}
        };
        assert!(validate_function(&func).is_err());
    }

    #[test]
    fn test_generic_function_rejected() {
        let func: syn::ItemFn = parse_quote! {
            fn generic<T>(x: T) -> T { x }
        };
        assert!(validate_function(&func).is_err());
    }

    #[test]
    fn test_stencil_mode_rejects_loops() {
        let func: syn::ItemFn = parse_quote! {
            fn with_loop() {
                for i in 0..10 {
                    process(i);
                }
            }
        };
        assert!(validate_function_with_mode(&func, ValidationMode::Stencil).is_err());
    }

    #[test]
    fn test_generic_mode_allows_loops() {
        let func: syn::ItemFn = parse_quote! {
            fn with_loop() {
                for i in 0..10 {
                    process(i);
                }
            }
        };
        assert!(validate_function_with_mode(&func, ValidationMode::Generic).is_ok());
    }

    #[test]
    fn test_stencil_mode_rejects_while_and_loop() {
        let with_while: syn::ItemFn = parse_quote! {
            fn with_while(x: u32) {
                while x < 10 {
                    process(x);
                }
            }
        };
        let with_loop: syn::ItemFn = parse_quote! {
            fn with_loop() {
                loop {
                    break;
                }
            }
        };
        for func in [&with_while, &with_loop] {
            let err = validate_function_with_mode(func, ValidationMode::Stencil).unwrap_err();
            assert!(matches!(
                err,
                ValidationError::LoopNotAllowed(ref m) if m.as_str() == "stencil"
            ));
        }
    }

    #[test]
    fn test_ring_kernel_mode_allows_loops() {
        let func: syn::ItemFn = parse_quote! {
            fn persistent() {
                loop {
                    break;
                }
            }
        };
        assert!(validate_function_with_mode(&func, ValidationMode::RingKernel).is_ok());
    }

    #[test]
    fn test_attribute_allowlist() {
        let allowed: syn::ItemFn = parse_quote! {
            /// Documented kernel.
            #[allow(unused)]
            fn documented(a: f32) -> f32 { a }
        };
        assert!(validate_function(&allowed).is_ok());

        let disallowed: syn::ItemFn = parse_quote! {
            #[inline]
            fn hinted(a: f32) -> f32 { a }
        };
        assert!(matches!(
            validate_function(&disallowed),
            Err(ValidationError::Unsupported(_))
        ));
    }

    #[test]
    fn test_macros_and_nested_items_rejected() {
        let with_macro: syn::ItemFn = parse_quote! {
            fn noisy() {
                println!("hi");
            }
        };
        assert!(validate_function(&with_macro).is_err());

        let with_item: syn::ItemFn = parse_quote! {
            fn outer() {
                fn inner() {}
            }
        };
        assert!(validate_function(&with_item).is_err());
    }

    #[test]
    fn test_closure_rejected() {
        let func: syn::ItemFn = parse_quote! {
            fn with_closure() {
                let f = |a: f32| a;
            }
        };
        assert!(validate_function(&func).is_err());
    }
}