// ringkernel_cuda_codegen/validation.rs

//! DSL constraint validation for CUDA code generation.
//!
//! This module validates that Rust code conforms to the restricted DSL
//! that can be transpiled to CUDA.
//!
//! # Validation Modes
//!
//! Different kernel types have different validation requirements:
//!
//! - **Stencil**: No loops allowed (use parallel threads)
//! - **Generic**: Loops allowed for general CUDA kernels
//! - **RingKernel**: Loops required for persistent actor kernels
//!
//! In every mode, closures, async/await, heap allocation and recursion are rejected.
13
14use syn::visit::Visit;
15use syn::{
16    Expr, ExprAsync, ExprAwait, ExprClosure, ExprForLoop, ExprLoop, ExprWhile, ExprYield, ItemFn,
17    Stmt,
18};
19use thiserror::Error;
20
/// Selects which DSL rules apply when validating a kernel body.
///
/// The kernel pattern decides whether Rust loop constructs are forbidden,
/// permitted, or mandatory. [`ValidationMode::Stencil`] is the default.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum ValidationMode {
    /// Stencil kernels: `for`, `while` and `loop` are rejected.
    ///
    /// Work is parallelised across threads, so a loop would serialise work
    /// that is meant to run in parallel.
    #[default]
    Stencil,

    /// General-purpose CUDA kernels: loops are permitted, e.g. for sequential
    /// processing performed inside a single thread.
    Generic,

    /// Persistent ring/actor kernels: at least one loop is mandatory, because
    /// these kernels stay resident and process messages in a loop.
    RingKernel,
}

impl ValidationMode {
    /// Returns `true` when loop constructs are accepted in this mode
    /// (every mode except [`ValidationMode::Stencil`]).
    pub fn allows_loops(&self) -> bool {
        match *self {
            ValidationMode::Stencil => false,
            ValidationMode::Generic | ValidationMode::RingKernel => true,
        }
    }

    /// Returns `true` when the kernel body must contain at least one loop
    /// (only [`ValidationMode::RingKernel`]).
    pub fn requires_loops(&self) -> bool {
        *self == ValidationMode::RingKernel
    }
}
52
53/// Validation errors for DSL constraint violations.
54#[derive(Error, Debug, Clone)]
55pub enum ValidationError {
56    /// Loops are not supported (use parallel threads instead).
57    #[error(
58        "Loops are not supported in stencil kernels. Use parallel threads instead. Found: {0}"
59    )]
60    LoopNotAllowed(String),
61
62    /// Closures are not supported.
63    #[error("Closures are not supported in stencil kernels. Found at: {0}")]
64    ClosureNotAllowed(String),
65
66    /// Async/await is not supported.
67    #[error("Async/await is not supported in GPU kernels")]
68    AsyncNotAllowed,
69
70    /// Heap allocation is not supported.
71    #[error("Heap allocation ({0}) is not supported in GPU kernels")]
72    HeapAllocationNotAllowed(String),
73
74    /// Unsupported expression type.
75    #[error("Unsupported expression: {0}")]
76    UnsupportedExpression(String),
77
78    /// Recursion detected.
79    #[error("Recursion is not supported in GPU kernels")]
80    RecursionNotAllowed,
81
82    /// Invalid function signature.
83    #[error("Invalid function signature: {0}")]
84    InvalidSignature(String),
85
86    /// Missing required loop for ring kernels.
87    #[error("Ring kernels require a message processing loop")]
88    LoopRequired,
89}
90
91/// Visitor that checks for DSL constraint violations.
92struct DslValidator {
93    errors: Vec<ValidationError>,
94    function_name: String,
95    mode: ValidationMode,
96    loop_count: usize,
97}
98
99impl DslValidator {
100    fn with_mode(function_name: String, mode: ValidationMode) -> Self {
101        Self {
102            errors: Vec::new(),
103            function_name,
104            mode,
105            loop_count: 0,
106        }
107    }
108
109    fn check_heap_allocation(&mut self, expr: &Expr) {
110        // Check for common heap-allocating patterns
111        if let Expr::Call(call) = expr {
112            if let Expr::Path(path) = call.func.as_ref() {
113                let path_str = path
114                    .path
115                    .segments
116                    .iter()
117                    .map(|s| s.ident.to_string())
118                    .collect::<Vec<_>>()
119                    .join("::");
120
121                // Check for Vec, Box, String, etc.
122                let heap_types = ["Vec", "Box", "String", "Rc", "Arc", "HashMap", "HashSet"];
123                for heap_type in &heap_types {
124                    if path_str.contains(heap_type) {
125                        self.errors
126                            .push(ValidationError::HeapAllocationNotAllowed(path_str.clone()));
127                        return;
128                    }
129                }
130            }
131        }
132
133        // Check for vec! macro
134        if let Expr::Macro(mac) = expr {
135            let macro_name = mac.mac.path.segments.last().map(|s| s.ident.to_string());
136            if macro_name == Some("vec".to_string()) {
137                self.errors.push(ValidationError::HeapAllocationNotAllowed(
138                    "vec!".to_string(),
139                ));
140            }
141        }
142    }
143
144    fn check_recursion(&mut self, expr: &Expr) {
145        // Check for potential recursive calls
146        if let Expr::Call(call) = expr {
147            if let Expr::Path(path) = call.func.as_ref() {
148                if let Some(segment) = path.path.segments.last() {
149                    if segment.ident == self.function_name {
150                        self.errors.push(ValidationError::RecursionNotAllowed);
151                    }
152                }
153            }
154        }
155    }
156}
157
158impl<'ast> Visit<'ast> for DslValidator {
159    fn visit_expr_for_loop(&mut self, node: &'ast ExprForLoop) {
160        if !self.mode.allows_loops() {
161            self.errors
162                .push(ValidationError::LoopNotAllowed("for loop".to_string()));
163        }
164        self.loop_count += 1;
165        // Still visit children to find more errors
166        syn::visit::visit_expr_for_loop(self, node);
167    }
168
169    fn visit_expr_while(&mut self, node: &'ast ExprWhile) {
170        if !self.mode.allows_loops() {
171            self.errors
172                .push(ValidationError::LoopNotAllowed("while loop".to_string()));
173        }
174        self.loop_count += 1;
175        syn::visit::visit_expr_while(self, node);
176    }
177
178    fn visit_expr_loop(&mut self, node: &'ast ExprLoop) {
179        if !self.mode.allows_loops() {
180            self.errors
181                .push(ValidationError::LoopNotAllowed("loop".to_string()));
182        }
183        self.loop_count += 1;
184        syn::visit::visit_expr_loop(self, node);
185    }
186
187    fn visit_expr_closure(&mut self, node: &'ast ExprClosure) {
188        self.errors.push(ValidationError::ClosureNotAllowed(
189            "closure expression".to_string(),
190        ));
191        syn::visit::visit_expr_closure(self, node);
192    }
193
194    fn visit_expr_async(&mut self, _node: &'ast ExprAsync) {
195        self.errors.push(ValidationError::AsyncNotAllowed);
196    }
197
198    fn visit_expr_await(&mut self, _node: &'ast ExprAwait) {
199        self.errors.push(ValidationError::AsyncNotAllowed);
200    }
201
202    fn visit_expr_yield(&mut self, _node: &'ast ExprYield) {
203        self.errors.push(ValidationError::UnsupportedExpression(
204            "yield expression".to_string(),
205        ));
206    }
207
208    fn visit_expr(&mut self, node: &'ast Expr) {
209        // Check for heap allocations
210        self.check_heap_allocation(node);
211
212        // Check for recursion
213        self.check_recursion(node);
214
215        // Continue visiting children
216        syn::visit::visit_expr(self, node);
217    }
218}
219
220/// Validate that a function conforms to the stencil kernel DSL.
221///
222/// # Constraints
223///
224/// - No loops (`for`, `while`, `loop`) - use parallel threads
225/// - No closures
226/// - No async/await
227/// - No heap allocations (Vec, Box, String, etc.)
228/// - No recursion
229/// - No trait objects
230///
231/// # Returns
232///
233/// `Ok(())` if validation passes, `Err` with the first validation error otherwise.
234pub fn validate_function(func: &ItemFn) -> Result<(), ValidationError> {
235    validate_function_with_mode(func, ValidationMode::Stencil)
236}
237
238/// Validate a function with a specific validation mode.
239///
240/// Different kernel types have different validation requirements:
241///
242/// - `ValidationMode::Stencil`: No loops allowed (current behavior for stencil kernels)
243/// - `ValidationMode::Generic`: Loops allowed for general CUDA kernels
244/// - `ValidationMode::RingKernel`: Loops required for persistent actor kernels
245///
246/// # Arguments
247///
248/// * `func` - The function to validate
249/// * `mode` - The validation mode to use
250///
251/// # Returns
252///
253/// `Ok(())` if validation passes, `Err` with the first validation error otherwise.
254///
255/// # Example
256///
257/// ```ignore
258/// use ringkernel_cuda_codegen::{validate_function_with_mode, ValidationMode};
259/// use syn::parse_quote;
260///
261/// // Generic kernel with loops allowed
262/// let func: syn::ItemFn = parse_quote! {
263///     fn process(data: &mut [f32], n: i32) {
264///         for i in 0..n {
265///             data[i as usize] = data[i as usize] * 2.0;
266///         }
267///     }
268/// };
269///
270/// // This would fail with Stencil mode, but passes with Generic mode
271/// assert!(validate_function_with_mode(&func, ValidationMode::Generic).is_ok());
272/// ```
273pub fn validate_function_with_mode(
274    func: &ItemFn,
275    mode: ValidationMode,
276) -> Result<(), ValidationError> {
277    let function_name = func.sig.ident.to_string();
278    let mut validator = DslValidator::with_mode(function_name, mode);
279
280    // Visit all statements in the function body
281    for stmt in &func.block.stmts {
282        match stmt {
283            Stmt::Expr(expr, _) => {
284                validator.visit_expr(expr);
285            }
286            Stmt::Local(syn::Local {
287                init: Some(syn::LocalInit { expr, .. }),
288                ..
289            }) => {
290                validator.visit_expr(expr);
291            }
292            _ => {}
293        }
294    }
295
296    // Also visit the entire block to catch nested constructs
297    syn::visit::visit_block(&mut validator, &func.block);
298
299    // Check if loops are required but none found
300    if mode.requires_loops() && validator.loop_count == 0 {
301        return Err(ValidationError::LoopRequired);
302    }
303
304    // Return first error if any
305    if let Some(error) = validator.errors.into_iter().next() {
306        Err(error)
307    } else {
308        Ok(())
309    }
310}
311
312/// Validate function signature for stencil kernels.
313///
314/// Ensures the function has appropriate parameter types and
315/// returns void or a simple type.
316pub fn validate_stencil_signature(func: &ItemFn) -> Result<(), ValidationError> {
317    // Check that the function is not async
318    if func.sig.asyncness.is_some() {
319        return Err(ValidationError::AsyncNotAllowed);
320    }
321
322    // Check that the function doesn't use generics (for now)
323    if !func.sig.generics.params.is_empty() {
324        return Err(ValidationError::InvalidSignature(
325            "Generic parameters are not supported in stencil kernels".to_string(),
326        ));
327    }
328
329    Ok(())
330}
331
332/// Check if a statement might need special handling.
333pub fn is_simple_assignment(stmt: &Stmt) -> bool {
334    matches!(stmt, Stmt::Local(_) | Stmt::Expr(Expr::Assign(_), _))
335}
336
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    /// Asserts that `$result` is an `Err` matching `$pattern`, printing the
    /// actual value when it is not.
    macro_rules! assert_err {
        ($result:expr, $pattern:pat) => {{
            let result = $result;
            assert!(
                matches!(result, Err($pattern)),
                "expected Err({}), got {:?}",
                stringify!($pattern),
                result
            );
        }};
    }

    /// Asserts that `$result` is `Ok`, printing the error when it is not.
    macro_rules! assert_ok {
        ($result:expr) => {{
            let result = $result;
            assert!(result.is_ok(), "expected Ok(()), got {:?}", result);
        }};
    }

    // === Stencil (default) mode ===

    #[test]
    fn test_valid_function() {
        let add: ItemFn = parse_quote! {
            fn add(a: f32, b: f32) -> f32 {
                let c = a + b;
                c * 2.0
            }
        };

        assert_ok!(validate_function(&add));
    }

    #[test]
    fn test_for_loop_rejected() {
        let sum: ItemFn = parse_quote! {
            fn sum(data: &[f32]) -> f32 {
                let mut total = 0.0;
                for x in data {
                    total += x;
                }
                total
            }
        };

        assert_err!(validate_function(&sum), ValidationError::LoopNotAllowed(_));
    }

    #[test]
    fn test_while_loop_rejected() {
        let countdown: ItemFn = parse_quote! {
            fn countdown(n: i32) -> i32 {
                let mut i = n;
                while i > 0 {
                    i -= 1;
                }
                i
            }
        };

        assert_err!(validate_function(&countdown), ValidationError::LoopNotAllowed(_));
    }

    #[test]
    fn test_closure_rejected() {
        let apply: ItemFn = parse_quote! {
            fn apply(x: f32) -> f32 {
                let f = |v| v * 2.0;
                f(x)
            }
        };

        assert_err!(validate_function(&apply), ValidationError::ClosureNotAllowed(_));
    }

    #[test]
    fn test_if_else_allowed() {
        let clamp: ItemFn = parse_quote! {
            fn clamp(x: f32, min: f32, max: f32) -> f32 {
                if x < min {
                    min
                } else if x > max {
                    max
                } else {
                    x
                }
            }
        };

        assert_ok!(validate_function(&clamp));
    }

    #[test]
    fn test_stencil_pattern_allowed() {
        let fdtd: ItemFn = parse_quote! {
            fn fdtd(p: &[f32], p_prev: &mut [f32], c2: f32, pos: GridPos) {
                let curr = p[pos.idx()];
                let lap = pos.north(p) + pos.south(p) + pos.east(p) + pos.west(p) - 4.0 * curr;
                p_prev[pos.idx()] = 2.0 * curr - p_prev[pos.idx()] + c2 * lap;
            }
        };

        assert_ok!(validate_function(&fdtd));
    }

    #[test]
    fn test_async_rejected() {
        let fetch: ItemFn = parse_quote! {
            async fn fetch(x: f32) -> f32 {
                x
            }
        };

        assert_err!(validate_stencil_signature(&fetch), ValidationError::AsyncNotAllowed);
    }

    // === ValidationMode tests ===

    #[test]
    fn test_generic_mode_allows_for_loop() {
        let process: ItemFn = parse_quote! {
            fn process(data: &mut [f32], n: i32) {
                for i in 0..n {
                    data[i as usize] = data[i as usize] * 2.0;
                }
            }
        };

        // The same body is rejected under Stencil but accepted under Generic.
        assert_err!(
            validate_function_with_mode(&process, ValidationMode::Stencil),
            ValidationError::LoopNotAllowed(_)
        );
        assert_ok!(validate_function_with_mode(&process, ValidationMode::Generic));
    }

    #[test]
    fn test_generic_mode_allows_while_loop() {
        let process: ItemFn = parse_quote! {
            fn process(data: &mut [f32]) {
                let mut i = 0;
                while i < 10 {
                    data[i] = 0.0;
                    i += 1;
                }
            }
        };

        assert_ok!(validate_function_with_mode(&process, ValidationMode::Generic));
    }

    #[test]
    fn test_generic_mode_allows_infinite_loop() {
        let process: ItemFn = parse_quote! {
            fn process(active: &u32) {
                loop {
                    if *active == 0 {
                        return;
                    }
                }
            }
        };

        assert_ok!(validate_function_with_mode(&process, ValidationMode::Generic));
    }

    #[test]
    fn test_ring_kernel_mode_requires_loop() {
        let without_loop: ItemFn = parse_quote! {
            fn handler(x: f32) -> f32 {
                x * 2.0
            }
        };
        assert_err!(
            validate_function_with_mode(&without_loop, ValidationMode::RingKernel),
            ValidationError::LoopRequired
        );

        let with_loop: ItemFn = parse_quote! {
            fn handler(active: &u32) {
                while *active != 0 {
                    // Process messages
                }
            }
        };
        assert_ok!(validate_function_with_mode(&with_loop, ValidationMode::RingKernel));
    }

    #[test]
    fn test_validation_mode_allows_loops() {
        let cases = [
            (ValidationMode::Stencil, false),
            (ValidationMode::Generic, true),
            (ValidationMode::RingKernel, true),
        ];
        for &(mode, expected) in &cases {
            assert_eq!(mode.allows_loops(), expected, "allows_loops for {:?}", mode);
        }
    }

    #[test]
    fn test_validation_mode_requires_loops() {
        let cases = [
            (ValidationMode::Stencil, false),
            (ValidationMode::Generic, false),
            (ValidationMode::RingKernel, true),
        ];
        for &(mode, expected) in &cases {
            assert_eq!(mode.requires_loops(), expected, "requires_loops for {:?}", mode);
        }
    }

    #[test]
    fn test_closures_rejected_in_all_modes() {
        // Loop-free body for the modes that do not require a loop.
        let apply: ItemFn = parse_quote! {
            fn apply(x: f32) -> f32 {
                let f = |v| v * 2.0;
                f(x)
            }
        };
        for &mode in &[ValidationMode::Stencil, ValidationMode::Generic] {
            assert_err!(
                validate_function_with_mode(&apply, mode),
                ValidationError::ClosureNotAllowed(_)
            );
        }

        // RingKernel needs a loop, so the closure lives inside one.
        let apply_in_loop: ItemFn = parse_quote! {
            fn apply(x: f32) -> f32 {
                loop {
                    let f = |v| v * 2.0;
                    if f(x) > 0.0 { break; }
                }
                x
            }
        };
        assert_err!(
            validate_function_with_mode(&apply_in_loop, ValidationMode::RingKernel),
            ValidationError::ClosureNotAllowed(_)
        );
    }
}