1use syn::visit::Visit;
15use syn::{
16 Expr, ExprAsync, ExprAwait, ExprClosure, ExprForLoop, ExprLoop, ExprWhile, ExprYield, ItemFn,
17 Stmt,
18};
19use thiserror::Error;
20
/// Selects which rule set the validator applies to a kernel function.
///
/// The default mode is [`ValidationMode::Stencil`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum ValidationMode {
    /// Loops are rejected in this mode.
    #[default]
    Stencil,

    /// Loops are permitted but not required.
    Generic,

    /// Loops are permitted, and at least one loop must be present.
    RingKernel,
}
40
41impl ValidationMode {
42 pub fn allows_loops(&self) -> bool {
44 matches!(self, ValidationMode::Generic | ValidationMode::RingKernel)
45 }
46
47 pub fn requires_loops(&self) -> bool {
49 matches!(self, ValidationMode::RingKernel)
50 }
51}
52
53#[derive(Error, Debug, Clone)]
55pub enum ValidationError {
56 #[error(
58 "Loops are not supported in stencil kernels. Use parallel threads instead. Found: {0}"
59 )]
60 LoopNotAllowed(String),
61
62 #[error("Closures are not supported in stencil kernels. Found at: {0}")]
64 ClosureNotAllowed(String),
65
66 #[error("Async/await is not supported in GPU kernels")]
68 AsyncNotAllowed,
69
70 #[error("Heap allocation ({0}) is not supported in GPU kernels")]
72 HeapAllocationNotAllowed(String),
73
74 #[error("Unsupported expression: {0}")]
76 UnsupportedExpression(String),
77
78 #[error("Recursion is not supported in GPU kernels")]
80 RecursionNotAllowed,
81
82 #[error("Invalid function signature: {0}")]
84 InvalidSignature(String),
85
86 #[error("Ring kernels require a message processing loop")]
88 LoopRequired,
89}
90
91struct DslValidator {
93 errors: Vec<ValidationError>,
94 function_name: String,
95 mode: ValidationMode,
96 loop_count: usize,
97}
98
99impl DslValidator {
100 fn with_mode(function_name: String, mode: ValidationMode) -> Self {
101 Self {
102 errors: Vec::new(),
103 function_name,
104 mode,
105 loop_count: 0,
106 }
107 }
108
109 fn check_heap_allocation(&mut self, expr: &Expr) {
110 if let Expr::Call(call) = expr {
112 if let Expr::Path(path) = call.func.as_ref() {
113 let path_str = path
114 .path
115 .segments
116 .iter()
117 .map(|s| s.ident.to_string())
118 .collect::<Vec<_>>()
119 .join("::");
120
121 let heap_types = ["Vec", "Box", "String", "Rc", "Arc", "HashMap", "HashSet"];
123 for heap_type in &heap_types {
124 if path_str.contains(heap_type) {
125 self.errors
126 .push(ValidationError::HeapAllocationNotAllowed(path_str.clone()));
127 return;
128 }
129 }
130 }
131 }
132
133 if let Expr::Macro(mac) = expr {
135 let macro_name = mac.mac.path.segments.last().map(|s| s.ident.to_string());
136 if macro_name == Some("vec".to_string()) {
137 self.errors.push(ValidationError::HeapAllocationNotAllowed(
138 "vec!".to_string(),
139 ));
140 }
141 }
142 }
143
144 fn check_recursion(&mut self, expr: &Expr) {
145 if let Expr::Call(call) = expr {
147 if let Expr::Path(path) = call.func.as_ref() {
148 if let Some(segment) = path.path.segments.last() {
149 if segment.ident == self.function_name {
150 self.errors.push(ValidationError::RecursionNotAllowed);
151 }
152 }
153 }
154 }
155 }
156}
157
158impl<'ast> Visit<'ast> for DslValidator {
159 fn visit_expr_for_loop(&mut self, node: &'ast ExprForLoop) {
160 if !self.mode.allows_loops() {
161 self.errors
162 .push(ValidationError::LoopNotAllowed("for loop".to_string()));
163 }
164 self.loop_count += 1;
165 syn::visit::visit_expr_for_loop(self, node);
167 }
168
169 fn visit_expr_while(&mut self, node: &'ast ExprWhile) {
170 if !self.mode.allows_loops() {
171 self.errors
172 .push(ValidationError::LoopNotAllowed("while loop".to_string()));
173 }
174 self.loop_count += 1;
175 syn::visit::visit_expr_while(self, node);
176 }
177
178 fn visit_expr_loop(&mut self, node: &'ast ExprLoop) {
179 if !self.mode.allows_loops() {
180 self.errors
181 .push(ValidationError::LoopNotAllowed("loop".to_string()));
182 }
183 self.loop_count += 1;
184 syn::visit::visit_expr_loop(self, node);
185 }
186
187 fn visit_expr_closure(&mut self, node: &'ast ExprClosure) {
188 self.errors.push(ValidationError::ClosureNotAllowed(
189 "closure expression".to_string(),
190 ));
191 syn::visit::visit_expr_closure(self, node);
192 }
193
194 fn visit_expr_async(&mut self, _node: &'ast ExprAsync) {
195 self.errors.push(ValidationError::AsyncNotAllowed);
196 }
197
198 fn visit_expr_await(&mut self, _node: &'ast ExprAwait) {
199 self.errors.push(ValidationError::AsyncNotAllowed);
200 }
201
202 fn visit_expr_yield(&mut self, _node: &'ast ExprYield) {
203 self.errors.push(ValidationError::UnsupportedExpression(
204 "yield expression".to_string(),
205 ));
206 }
207
208 fn visit_expr(&mut self, node: &'ast Expr) {
209 self.check_heap_allocation(node);
211
212 self.check_recursion(node);
214
215 syn::visit::visit_expr(self, node);
217 }
218}
219
220pub fn validate_function(func: &ItemFn) -> Result<(), ValidationError> {
235 validate_function_with_mode(func, ValidationMode::Stencil)
236}
237
238pub fn validate_function_with_mode(
274 func: &ItemFn,
275 mode: ValidationMode,
276) -> Result<(), ValidationError> {
277 let function_name = func.sig.ident.to_string();
278 let mut validator = DslValidator::with_mode(function_name, mode);
279
280 for stmt in &func.block.stmts {
282 match stmt {
283 Stmt::Expr(expr, _) => {
284 validator.visit_expr(expr);
285 }
286 Stmt::Local(syn::Local {
287 init: Some(syn::LocalInit { expr, .. }),
288 ..
289 }) => {
290 validator.visit_expr(expr);
291 }
292 _ => {}
293 }
294 }
295
296 syn::visit::visit_block(&mut validator, &func.block);
298
299 if mode.requires_loops() && validator.loop_count == 0 {
301 return Err(ValidationError::LoopRequired);
302 }
303
304 if let Some(error) = validator.errors.into_iter().next() {
306 Err(error)
307 } else {
308 Ok(())
309 }
310}
311
312pub fn validate_stencil_signature(func: &ItemFn) -> Result<(), ValidationError> {
317 if func.sig.asyncness.is_some() {
319 return Err(ValidationError::AsyncNotAllowed);
320 }
321
322 if !func.sig.generics.params.is_empty() {
324 return Err(ValidationError::InvalidSignature(
325 "Generic parameters are not supported in stencil kernels".to_string(),
326 ));
327 }
328
329 Ok(())
330}
331
332pub fn is_simple_assignment(stmt: &Stmt) -> bool {
334 matches!(stmt, Stmt::Local(_) | Stmt::Expr(Expr::Assign(_), _))
335}
336
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    // ---- Stencil mode (default) -------------------------------------------

    #[test]
    fn test_valid_function() {
        let func: ItemFn = parse_quote! {
            fn add(a: f32, b: f32) -> f32 {
                let c = a + b;
                c * 2.0
            }
        };

        assert!(validate_function(&func).is_ok());
    }

    #[test]
    fn test_for_loop_rejected() {
        let func: ItemFn = parse_quote! {
            fn sum(data: &[f32]) -> f32 {
                let mut total = 0.0;
                for x in data {
                    total += x;
                }
                total
            }
        };

        let result = validate_function(&func);
        assert!(matches!(result, Err(ValidationError::LoopNotAllowed(_))));
    }

    #[test]
    fn test_while_loop_rejected() {
        let func: ItemFn = parse_quote! {
            fn countdown(n: i32) -> i32 {
                let mut i = n;
                while i > 0 {
                    i -= 1;
                }
                i
            }
        };

        let result = validate_function(&func);
        assert!(matches!(result, Err(ValidationError::LoopNotAllowed(_))));
    }

    #[test]
    fn test_closure_rejected() {
        let func: ItemFn = parse_quote! {
            fn apply(x: f32) -> f32 {
                let f = |v| v * 2.0;
                f(x)
            }
        };

        let result = validate_function(&func);
        assert!(matches!(result, Err(ValidationError::ClosureNotAllowed(_))));
    }

    #[test]
    fn test_if_else_allowed() {
        let func: ItemFn = parse_quote! {
            fn clamp(x: f32, min: f32, max: f32) -> f32 {
                if x < min {
                    min
                } else if x > max {
                    max
                } else {
                    x
                }
            }
        };

        assert!(validate_function(&func).is_ok());
    }

    #[test]
    fn test_stencil_pattern_allowed() {
        let func: ItemFn = parse_quote! {
            fn fdtd(p: &[f32], p_prev: &mut [f32], c2: f32, pos: GridPos) {
                let curr = p[pos.idx()];
                let lap = pos.north(p) + pos.south(p) + pos.east(p) + pos.west(p) - 4.0 * curr;
                p_prev[pos.idx()] = 2.0 * curr - p_prev[pos.idx()] + c2 * lap;
            }
        };

        assert!(validate_function(&func).is_ok());
    }

    // ---- Signature checks --------------------------------------------------

    #[test]
    fn test_async_rejected() {
        let func: ItemFn = parse_quote! {
            async fn fetch(x: f32) -> f32 {
                x
            }
        };

        let result = validate_stencil_signature(&func);
        assert!(matches!(result, Err(ValidationError::AsyncNotAllowed)));
    }

    #[test]
    fn test_generic_signature_rejected() {
        let func: ItemFn = parse_quote! {
            fn identity<T>(x: T) -> T {
                x
            }
        };

        let result = validate_stencil_signature(&func);
        assert!(matches!(result, Err(ValidationError::InvalidSignature(_))));
    }

    #[test]
    fn test_plain_signature_accepted() {
        let func: ItemFn = parse_quote! {
            fn scale(x: f32) -> f32 {
                x * 2.0
            }
        };

        assert!(validate_stencil_signature(&func).is_ok());
    }

    // ---- Generic mode ------------------------------------------------------

    #[test]
    fn test_generic_mode_allows_for_loop() {
        let func: ItemFn = parse_quote! {
            fn process(data: &mut [f32], n: i32) {
                for i in 0..n {
                    data[i as usize] = data[i as usize] * 2.0;
                }
            }
        };

        // The same function is rejected in stencil mode...
        assert!(matches!(
            validate_function_with_mode(&func, ValidationMode::Stencil),
            Err(ValidationError::LoopNotAllowed(_))
        ));

        // ...and accepted in generic mode.
        assert!(validate_function_with_mode(&func, ValidationMode::Generic).is_ok());
    }

    #[test]
    fn test_generic_mode_allows_while_loop() {
        let func: ItemFn = parse_quote! {
            fn process(data: &mut [f32]) {
                let mut i = 0;
                while i < 10 {
                    data[i] = 0.0;
                    i += 1;
                }
            }
        };

        assert!(validate_function_with_mode(&func, ValidationMode::Generic).is_ok());
    }

    #[test]
    fn test_generic_mode_allows_infinite_loop() {
        let func: ItemFn = parse_quote! {
            fn process(active: &u32) {
                loop {
                    if *active == 0 {
                        return;
                    }
                }
            }
        };

        assert!(validate_function_with_mode(&func, ValidationMode::Generic).is_ok());
    }

    // ---- Ring kernel mode --------------------------------------------------

    #[test]
    fn test_ring_kernel_mode_requires_loop() {
        let func_no_loop: ItemFn = parse_quote! {
            fn handler(x: f32) -> f32 {
                x * 2.0
            }
        };

        assert!(matches!(
            validate_function_with_mode(&func_no_loop, ValidationMode::RingKernel),
            Err(ValidationError::LoopRequired)
        ));

        let func_with_loop: ItemFn = parse_quote! {
            fn handler(active: &u32) {
                while *active != 0 {
                }
            }
        };

        assert!(validate_function_with_mode(&func_with_loop, ValidationMode::RingKernel).is_ok());
    }

    // ---- Mode predicates ---------------------------------------------------

    #[test]
    fn test_validation_mode_allows_loops() {
        assert!(!ValidationMode::Stencil.allows_loops());
        assert!(ValidationMode::Generic.allows_loops());
        assert!(ValidationMode::RingKernel.allows_loops());
    }

    #[test]
    fn test_validation_mode_requires_loops() {
        assert!(!ValidationMode::Stencil.requires_loops());
        assert!(!ValidationMode::Generic.requires_loops());
        assert!(ValidationMode::RingKernel.requires_loops());
    }

    // ---- Rules that apply in every mode ------------------------------------

    #[test]
    fn test_closures_rejected_in_all_modes() {
        let func: ItemFn = parse_quote! {
            fn apply(x: f32) -> f32 {
                let f = |v| v * 2.0;
                f(x)
            }
        };

        assert!(matches!(
            validate_function_with_mode(&func, ValidationMode::Stencil),
            Err(ValidationError::ClosureNotAllowed(_))
        ));
        assert!(matches!(
            validate_function_with_mode(&func, ValidationMode::Generic),
            Err(ValidationError::ClosureNotAllowed(_))
        ));

        // Ring kernel mode needs a loop, so wrap the closure in one to get
        // past the loop requirement and reach the closure check.
        let func_with_loop: ItemFn = parse_quote! {
            fn apply(x: f32) -> f32 {
                loop {
                    let f = |v| v * 2.0;
                    if f(x) > 0.0 { break; }
                }
                x
            }
        };

        assert!(matches!(
            validate_function_with_mode(&func_with_loop, ValidationMode::RingKernel),
            Err(ValidationError::ClosureNotAllowed(_))
        ));
    }

    #[test]
    fn test_heap_constructor_rejected() {
        let with_vec: ItemFn = parse_quote! {
            fn build(n: usize) -> usize {
                let v = Vec::new();
                n
            }
        };
        assert!(matches!(
            validate_function(&with_vec),
            Err(ValidationError::HeapAllocationNotAllowed(_))
        ));

        let with_box: ItemFn = parse_quote! {
            fn wrap(x: f32) -> f32 {
                let b = Box::new(x);
                x
            }
        };
        assert!(matches!(
            validate_function_with_mode(&with_box, ValidationMode::Generic),
            Err(ValidationError::HeapAllocationNotAllowed(_))
        ));
    }

    #[test]
    fn test_vec_macro_rejected() {
        let func: ItemFn = parse_quote! {
            fn build(x: f32) -> f32 {
                let v = vec![x, x];
                x
            }
        };

        match validate_function(&func) {
            Err(ValidationError::HeapAllocationNotAllowed(what)) => assert_eq!(what, "vec!"),
            other => panic!("expected heap allocation error, got {:?}", other),
        }
    }

    #[test]
    fn test_recursion_rejected() {
        let func: ItemFn = parse_quote! {
            fn fact(n: u32) -> u32 {
                if n == 0 {
                    1
                } else {
                    n * fact(n - 1)
                }
            }
        };

        let result = validate_function(&func);
        assert!(matches!(result, Err(ValidationError::RecursionNotAllowed)));
    }

    // ---- Statement helpers -------------------------------------------------

    #[test]
    fn test_is_simple_assignment() {
        let binding: Stmt = parse_quote! { let x = 1; };
        assert!(is_simple_assignment(&binding));

        let assignment: Stmt = parse_quote! { x = 1; };
        assert!(is_simple_assignment(&assignment));

        let call: Stmt = parse_quote! { foo(); };
        assert!(!is_simple_assignment(&call));
    }
}