1use crate::bindings::{generate_bindings, AccessMode, BindingLayout};
6use crate::handler::WgslContextMethod;
7use crate::intrinsics::{IntrinsicRegistry, WgslIntrinsic};
8use crate::loops::RangeInfo;
9use crate::ring_kernel::RingKernelConfig;
10use crate::shared::{SharedMemoryConfig, SharedMemoryDecl};
11use crate::stencil::StencilConfig;
12use crate::types::{
13 is_grid_pos_type, is_mutable_reference, is_ring_context_type, TypeMapper, WgslType,
14};
15use crate::u64_workarounds::U64Helpers;
16use crate::validation::ValidationMode;
17use crate::{Result, TranspileError};
18use quote::ToTokens;
19use std::cell::Cell;
20use std::collections::HashMap;
21use syn::{
22 BinOp, Expr, ExprAssign, ExprBinary, ExprBreak, ExprCall, ExprCast, ExprContinue, ExprForLoop,
23 ExprIf, ExprIndex, ExprLit, ExprLoop, ExprMatch, ExprMethodCall, ExprParen, ExprPath,
24 ExprReturn, ExprStruct, ExprUnary, ExprWhile, FnArg, ItemFn, Lit, Pat, ReturnType, Stmt, UnOp,
25};
26
27pub struct WgslTranspiler {
29 config: Option<StencilConfig>,
31 #[allow(dead_code)]
33 ring_config: Option<RingKernelConfig>,
34 type_mapper: TypeMapper,
36 intrinsics: IntrinsicRegistry,
38 grid_pos_vars: Vec<String>,
40 context_vars: Vec<String>,
42 indent: usize,
44 validation_mode: ValidationMode,
46 shared_memory: SharedMemoryConfig,
48 pub shared_vars: HashMap<String, SharedVarInfo>,
50 ring_kernel_mode: bool,
52 needs_u64_helpers: bool,
54 needs_subgroup_extension: Cell<bool>,
57 workgroup_size: (u32, u32, u32),
59 bindings: Vec<BindingLayout>,
61}
62
/// Metadata about a workgroup-shared variable declared inside a kernel body.
#[derive(Debug)]
#[derive(Clone)]
pub struct SharedVarInfo {
    /// Variable name as written in the Rust source.
    pub name: String,
    /// `true` when the declaration has exactly two dimensions (a 2-D tile).
    pub is_tile: bool,
    /// Size of each declared dimension.
    pub dimensions: Vec<usize>,
    /// WGSL spelling of the element type.
    pub element_type: String,
}
75
76impl WgslTranspiler {
77 pub fn new_stencil(config: StencilConfig) -> Self {
79 Self {
80 config: Some(config),
81 ring_config: None,
82 type_mapper: TypeMapper::new(),
83 intrinsics: IntrinsicRegistry::new(),
84 grid_pos_vars: Vec::new(),
85 context_vars: Vec::new(),
86 indent: 1,
87 validation_mode: ValidationMode::Stencil,
88 shared_memory: SharedMemoryConfig::new(),
89 shared_vars: HashMap::new(),
90 ring_kernel_mode: false,
91 needs_u64_helpers: false,
92 needs_subgroup_extension: Cell::new(false),
93 workgroup_size: (256, 1, 1),
94 bindings: Vec::new(),
95 }
96 }
97
98 pub fn new_generic() -> Self {
100 Self {
101 config: None,
102 ring_config: None,
103 type_mapper: TypeMapper::new(),
104 intrinsics: IntrinsicRegistry::new(),
105 grid_pos_vars: Vec::new(),
106 context_vars: Vec::new(),
107 indent: 1,
108 validation_mode: ValidationMode::Generic,
109 shared_memory: SharedMemoryConfig::new(),
110 shared_vars: HashMap::new(),
111 ring_kernel_mode: false,
112 needs_u64_helpers: false,
113 needs_subgroup_extension: Cell::new(false),
114 workgroup_size: (256, 1, 1),
115 bindings: Vec::new(),
116 }
117 }
118
119 pub fn new_ring_kernel(config: RingKernelConfig) -> Self {
121 Self {
122 config: None,
123 ring_config: Some(config.clone()),
124 type_mapper: TypeMapper::new(),
125 intrinsics: IntrinsicRegistry::new(),
126 grid_pos_vars: Vec::new(),
127 context_vars: Vec::new(),
128 indent: 2,
129 validation_mode: ValidationMode::Generic,
130 shared_memory: SharedMemoryConfig::new(),
131 shared_vars: HashMap::new(),
132 ring_kernel_mode: true,
133 needs_u64_helpers: true, needs_subgroup_extension: Cell::new(false),
135 workgroup_size: (config.workgroup_size, 1, 1),
136 bindings: Vec::new(),
137 }
138 }
139
140 pub fn with_workgroup_size(mut self, x: u32, y: u32, z: u32) -> Self {
142 self.workgroup_size = (x, y, z);
143 self
144 }
145
146 fn indent_str(&self) -> String {
148 " ".repeat(self.indent)
149 }
150
151 pub fn transpile_stencil(&mut self, func: &ItemFn) -> Result<String> {
153 let config = self
154 .config
155 .as_ref()
156 .ok_or_else(|| TranspileError::Unsupported("No stencil config provided".into()))?
157 .clone();
158
159 for param in &func.sig.inputs {
161 if let FnArg::Typed(pat_type) = param {
162 if is_grid_pos_type(&pat_type.ty) {
163 if let Pat::Ident(ident) = pat_type.pat.as_ref() {
164 self.grid_pos_vars.push(ident.ident.to_string());
165 }
166 }
167 }
168 }
169
170 self.collect_bindings(func)?;
172
173 let mut output = String::new();
174
175 output.push_str(&generate_bindings(&self.bindings));
177 output.push_str("\n\n");
178
179 if config.use_shared_memory {
181 let buffer_width = config.buffer_width();
182 let buffer_height = config.buffer_height();
183 output.push_str(&format!(
184 "var<workgroup> tile: array<array<f32, {}>, {}>;\n\n",
185 buffer_width, buffer_height
186 ));
187 }
188
189 output.push_str("@compute ");
191 output.push_str(&config.workgroup_size_annotation());
192 output.push('\n');
193 output.push_str(&format!("fn {}(\n", func.sig.ident));
194 output.push_str(" @builtin(local_invocation_id) local_id: vec3<u32>,\n");
195 output.push_str(" @builtin(workgroup_id) workgroup_id: vec3<u32>,\n");
196 output.push_str(" @builtin(global_invocation_id) global_id: vec3<u32>\n");
197 output.push_str(") {\n");
198
199 output.push_str(&self.generate_stencil_preamble(&config));
201
202 let body = self.transpile_block(&func.block)?;
204 output.push_str(&body);
205
206 output.push_str("}\n");
207
208 Ok(output)
209 }
210
211 fn generate_stencil_preamble(&self, config: &StencilConfig) -> String {
213 let buffer_width = config.buffer_width();
214 let mut preamble = String::new();
215
216 preamble.push_str(" // Thread indices\n");
217 preamble.push_str(" let lx = local_id.x;\n");
218 preamble.push_str(" let ly = local_id.y;\n");
219 preamble.push_str(&format!(" let buffer_width = {}u;\n", buffer_width));
220 preamble.push('\n');
221
222 preamble.push_str(&format!(
224 " if (lx >= {}u || ly >= {}u) {{ return; }}\n\n",
225 config.tile_width, config.tile_height
226 ));
227
228 preamble.push_str(&format!(
230 " let idx = (ly + {}u) * buffer_width + (lx + {}u);\n\n",
231 config.halo, config.halo
232 ));
233
234 preamble
235 }
236
237 pub fn transpile_global_kernel(&mut self, func: &ItemFn) -> Result<String> {
239 self.collect_bindings(func)?;
241
242 let body = self.transpile_block(&func.block)?;
244
245 let mut output = String::new();
246
247 if self.needs_subgroup_extension.get() {
249 output.push_str("enable chromium_experimental_subgroups;\n\n");
250 }
251
252 output.push_str(&generate_bindings(&self.bindings));
254 output.push_str("\n\n");
255
256 if !self.shared_memory.declarations.is_empty() {
258 output.push_str(&self.shared_memory.to_wgsl());
259 output.push_str("\n\n");
260 }
261
262 output.push_str("@compute ");
264 output.push_str(&format!(
265 "@workgroup_size({}, {}, {})\n",
266 self.workgroup_size.0, self.workgroup_size.1, self.workgroup_size.2
267 ));
268 output.push_str(&format!("fn {}(\n", func.sig.ident));
269 output.push_str(" @builtin(local_invocation_id) local_invocation_id: vec3<u32>,\n");
270 output.push_str(" @builtin(workgroup_id) workgroup_id: vec3<u32>,\n");
271 output.push_str(" @builtin(global_invocation_id) global_invocation_id: vec3<u32>,\n");
272 output.push_str(" @builtin(num_workgroups) num_workgroups: vec3<u32>");
273
274 if self.needs_subgroup_extension.get() {
276 output
277 .push_str(",\n @builtin(subgroup_invocation_id) subgroup_invocation_id: u32,\n");
278 output.push_str(" @builtin(subgroup_size) subgroup_size: u32\n");
279 } else {
280 output.push('\n');
281 }
282 output.push_str(") {\n");
283
284 output.push_str(&body);
286
287 output.push_str("}\n");
288
289 Ok(output)
290 }
291
292 pub fn transpile_ring_kernel(
294 &mut self,
295 handler: &ItemFn,
296 config: &RingKernelConfig,
297 ) -> Result<String> {
298 for param in &handler.sig.inputs {
300 if let FnArg::Typed(pat_type) = param {
301 if is_ring_context_type(&pat_type.ty) {
302 if let Pat::Ident(ident) = pat_type.pat.as_ref() {
303 self.context_vars.push(ident.ident.to_string());
304 }
305 }
306 }
307 }
308
309 self.ring_kernel_mode = true;
310 self.needs_u64_helpers = true;
311
312 let mut output = String::new();
313
314 output.push_str(&U64Helpers::generate_all());
316 output.push_str("\n\n");
317
318 output.push_str(&crate::ring_kernel::generate_control_block_struct(config));
320 output.push_str("\n\n");
321
322 output.push_str(crate::ring_kernel::generate_ring_kernel_bindings());
324 output.push_str("\n\n");
325
326 output.push_str("@compute ");
328 output.push_str(&config.workgroup_size_annotation());
329 output.push('\n');
330 output.push_str(&format!("fn ring_kernel_{}(\n", config.name));
331 output.push_str(" @builtin(local_invocation_id) local_invocation_id: vec3<u32>,\n");
332 output.push_str(" @builtin(workgroup_id) workgroup_id: vec3<u32>,\n");
333 output.push_str(" @builtin(global_invocation_id) global_invocation_id: vec3<u32>\n");
334 output.push_str(") {\n");
335
336 output.push_str(crate::ring_kernel::generate_ring_kernel_preamble());
338 output.push_str("\n\n");
339
340 output.push_str(" // === USER HANDLER CODE ===\n");
342 self.indent = 1;
343 let handler_body = self.transpile_block(&handler.block)?;
344 output.push_str(&handler_body);
345 output.push_str(" // === END HANDLER CODE ===\n\n");
346
347 output.push_str(" // Update message counter\n");
349 output.push_str(
350 " atomic_inc_u64(&control.messages_processed_lo, &control.messages_processed_hi);\n",
351 );
352
353 output.push_str("}\n");
354
355 Ok(output)
356 }
357
358 fn collect_bindings(&mut self, func: &ItemFn) -> Result<()> {
360 let mut binding_idx = 0u32;
361
362 for param in &func.sig.inputs {
363 if let FnArg::Typed(pat_type) = param {
364 if is_grid_pos_type(&pat_type.ty) || is_ring_context_type(&pat_type.ty) {
366 continue;
367 }
368
369 let param_name = match pat_type.pat.as_ref() {
370 Pat::Ident(ident) => ident.ident.to_string(),
371 _ => continue,
372 };
373
374 let wgsl_type = self
375 .type_mapper
376 .map_type(&pat_type.ty)
377 .map_err(TranspileError::Type)?;
378 let is_mutable = is_mutable_reference(&pat_type.ty);
379
380 match &wgsl_type {
382 WgslType::Ptr { inner, .. } => {
383 let access = if is_mutable {
384 AccessMode::ReadWrite
385 } else {
386 AccessMode::Read
387 };
388 self.bindings.push(BindingLayout::new(
389 0,
390 binding_idx,
391 ¶m_name,
392 WgslType::Array {
393 element: inner.clone(),
394 size: None,
395 },
396 access,
397 ));
398 binding_idx += 1;
399 }
400 WgslType::Array { .. } => {
401 let access = if is_mutable {
402 AccessMode::ReadWrite
403 } else {
404 AccessMode::Read
405 };
406 self.bindings.push(BindingLayout::new(
407 0,
408 binding_idx,
409 ¶m_name,
410 wgsl_type.clone(),
411 access,
412 ));
413 binding_idx += 1;
414 }
415 _ => {
417 self.bindings.push(BindingLayout::uniform(
418 binding_idx,
419 ¶m_name,
420 wgsl_type.clone(),
421 ));
422 binding_idx += 1;
423 }
424 }
425 }
426 }
427
428 Ok(())
429 }
430
431 fn transpile_block(&mut self, block: &syn::Block) -> Result<String> {
433 let mut output = String::new();
434
435 for stmt in &block.stmts {
436 let stmt_str = self.transpile_stmt(stmt)?;
437 if !stmt_str.is_empty() {
438 output.push_str(&stmt_str);
439 }
440 }
441
442 Ok(output)
443 }
444
445 fn transpile_stmt(&mut self, stmt: &Stmt) -> Result<String> {
447 match stmt {
448 Stmt::Local(local) => {
449 let indent = self.indent_str();
450
451 let var_name = match &local.pat {
453 Pat::Ident(ident) => ident.ident.to_string(),
454 Pat::Type(pat_type) => {
455 if let Pat::Ident(ident) = pat_type.pat.as_ref() {
456 ident.ident.to_string()
457 } else {
458 return Err(TranspileError::Unsupported(
459 "Complex pattern in let binding".into(),
460 ));
461 }
462 }
463 _ => {
464 return Err(TranspileError::Unsupported(
465 "Complex pattern in let binding".into(),
466 ))
467 }
468 };
469
470 if let Some(shared_decl) = self.try_parse_shared_declaration(local, &var_name)? {
472 self.shared_vars.insert(
473 var_name.clone(),
474 SharedVarInfo {
475 name: var_name.clone(),
476 is_tile: shared_decl.dimensions.len() == 2,
477 dimensions: shared_decl
478 .dimensions
479 .iter()
480 .map(|&d| d as usize)
481 .collect(),
482 element_type: shared_decl.element_type.to_wgsl(),
483 },
484 );
485 self.shared_memory.add(shared_decl.clone());
486 return Ok(format!(
487 "{indent}// shared memory: {} declared at module scope\n",
488 var_name
489 ));
490 }
491
492 if let Some(init) = &local.init {
494 let expr_str = self.transpile_expr(&init.expr)?;
495 let type_str = self.infer_wgsl_type(&init.expr);
496
497 Ok(format!(
498 "{indent}var {var_name}: {type_str} = {expr_str};\n"
499 ))
500 } else {
501 Ok(format!("{indent}var {var_name}: f32;\n"))
502 }
503 }
504 Stmt::Expr(expr, semi) => {
505 let indent = self.indent_str();
506
507 if let Expr::If(if_expr) = expr {
509 if let Some(Stmt::Expr(Expr::Return(_), _)) = if_expr.then_branch.stmts.first()
510 {
511 if if_expr.then_branch.stmts.len() == 1 && if_expr.else_branch.is_none() {
512 let expr_str = self.transpile_expr(expr)?;
513 return Ok(format!("{indent}{expr_str}\n"));
514 }
515 }
516 }
517
518 let expr_str = self.transpile_expr(expr)?;
519
520 if semi.is_some()
521 || matches!(expr, Expr::Return(_))
522 || expr_str.starts_with("return")
523 || expr_str.starts_with("if (")
524 {
525 Ok(format!("{indent}{expr_str};\n"))
526 } else {
527 Ok(format!("{indent}return {expr_str};\n"))
528 }
529 }
530 Stmt::Item(_) => Err(TranspileError::Unsupported("Item in function body".into())),
531 Stmt::Macro(_) => Err(TranspileError::Unsupported("Macro in function body".into())),
532 }
533 }
534
535 fn transpile_expr(&self, expr: &Expr) -> Result<String> {
537 match expr {
538 Expr::Lit(lit) => self.transpile_lit(lit),
539 Expr::Path(path) => self.transpile_path(path),
540 Expr::Binary(bin) => self.transpile_binary(bin),
541 Expr::Unary(unary) => self.transpile_unary(unary),
542 Expr::Paren(paren) => self.transpile_paren(paren),
543 Expr::Index(index) => self.transpile_index(index),
544 Expr::Call(call) => self.transpile_call(call),
545 Expr::MethodCall(method) => self.transpile_method_call(method),
546 Expr::If(if_expr) => self.transpile_if(if_expr),
547 Expr::Assign(assign) => self.transpile_assign(assign),
548 Expr::Cast(cast) => self.transpile_cast(cast),
549 Expr::Match(match_expr) => self.transpile_match(match_expr),
550 Expr::Block(block) => {
551 if let Some(Stmt::Expr(expr, None)) = block.block.stmts.last() {
552 self.transpile_expr(expr)
553 } else {
554 Err(TranspileError::Unsupported(
555 "Complex block expression".into(),
556 ))
557 }
558 }
559 Expr::Field(field) => {
560 let base = self.transpile_expr(&field.base)?;
561 let member = match &field.member {
562 syn::Member::Named(ident) => ident.to_string(),
563 syn::Member::Unnamed(idx) => idx.index.to_string(),
564 };
565 Ok(format!("{base}.{member}"))
566 }
567 Expr::Return(ret) => self.transpile_return(ret),
568 Expr::ForLoop(for_loop) => self.transpile_for_loop(for_loop),
569 Expr::While(while_loop) => self.transpile_while_loop(while_loop),
570 Expr::Loop(loop_expr) => self.transpile_infinite_loop(loop_expr),
571 Expr::Break(break_expr) => self.transpile_break(break_expr),
572 Expr::Continue(cont_expr) => self.transpile_continue(cont_expr),
573 Expr::Struct(struct_expr) => self.transpile_struct_literal(struct_expr),
574 Expr::Reference(ref_expr) => {
575 let inner = self.transpile_expr(&ref_expr.expr)?;
577 Ok(format!("&{inner}"))
578 }
579 _ => Err(TranspileError::Unsupported(format!(
580 "Expression type: {}",
581 expr.to_token_stream()
582 ))),
583 }
584 }
585
586 fn transpile_lit(&self, lit: &ExprLit) -> Result<String> {
588 match &lit.lit {
589 Lit::Float(f) => {
590 let s = f.to_string();
591 let num = s.trim_end_matches("f32").trim_end_matches("f64");
593 if num.contains('.') {
595 Ok(num.to_string())
596 } else {
597 Ok(format!("{}.0", num))
598 }
599 }
600 Lit::Int(i) => {
601 let s = i.to_string();
602 if s.ends_with("u32") || s.ends_with("usize") {
604 Ok(format!(
605 "{}u",
606 s.trim_end_matches("u32").trim_end_matches("usize")
607 ))
608 } else if s.ends_with("i32") || s.ends_with("isize") {
609 Ok(format!(
610 "{}i",
611 s.trim_end_matches("i32").trim_end_matches("isize")
612 ))
613 } else {
614 Ok(s)
616 }
617 }
618 Lit::Bool(b) => Ok(if b.value { "true" } else { "false" }.to_string()),
619 _ => Err(TranspileError::Unsupported(format!(
620 "Literal type: {}",
621 lit.to_token_stream()
622 ))),
623 }
624 }
625
626 fn transpile_path(&self, path: &ExprPath) -> Result<String> {
628 let segments: Vec<_> = path
629 .path
630 .segments
631 .iter()
632 .map(|s| s.ident.to_string())
633 .collect();
634
635 if segments.len() == 1 {
636 Ok(segments[0].clone())
637 } else {
638 Ok(segments.join("_"))
639 }
640 }
641
642 fn transpile_binary(&self, bin: &ExprBinary) -> Result<String> {
644 let left = self.transpile_expr(&bin.left)?;
645 let right = self.transpile_expr(&bin.right)?;
646
647 let op = match bin.op {
648 BinOp::Add(_) => "+",
649 BinOp::Sub(_) => "-",
650 BinOp::Mul(_) => "*",
651 BinOp::Div(_) => "/",
652 BinOp::Rem(_) => "%",
653 BinOp::And(_) => "&&",
654 BinOp::Or(_) => "||",
655 BinOp::BitXor(_) => "^",
656 BinOp::BitAnd(_) => "&",
657 BinOp::BitOr(_) => "|",
658 BinOp::Shl(_) => "<<",
659 BinOp::Shr(_) => ">>",
660 BinOp::Eq(_) => "==",
661 BinOp::Lt(_) => "<",
662 BinOp::Le(_) => "<=",
663 BinOp::Ne(_) => "!=",
664 BinOp::Ge(_) => ">=",
665 BinOp::Gt(_) => ">",
666 BinOp::AddAssign(_) => "+=",
667 BinOp::SubAssign(_) => "-=",
668 BinOp::MulAssign(_) => "*=",
669 BinOp::DivAssign(_) => "/=",
670 BinOp::RemAssign(_) => "%=",
671 BinOp::BitXorAssign(_) => "^=",
672 BinOp::BitAndAssign(_) => "&=",
673 BinOp::BitOrAssign(_) => "|=",
674 BinOp::ShlAssign(_) => "<<=",
675 BinOp::ShrAssign(_) => ">>=",
676 _ => {
677 return Err(TranspileError::Unsupported(format!(
678 "Binary operator: {}",
679 bin.to_token_stream()
680 )))
681 }
682 };
683
684 Ok(format!("{left} {op} {right}"))
685 }
686
687 fn transpile_unary(&self, unary: &ExprUnary) -> Result<String> {
689 let expr = self.transpile_expr(&unary.expr)?;
690
691 let op = match unary.op {
692 UnOp::Neg(_) => "-",
693 UnOp::Not(_) => "!",
694 UnOp::Deref(_) => "*",
695 _ => {
696 return Err(TranspileError::Unsupported(format!(
697 "Unary operator: {}",
698 unary.to_token_stream()
699 )))
700 }
701 };
702
703 Ok(format!("{op}({expr})"))
704 }
705
706 fn transpile_paren(&self, paren: &ExprParen) -> Result<String> {
708 let inner = self.transpile_expr(&paren.expr)?;
709 Ok(format!("({inner})"))
710 }
711
712 fn transpile_index(&self, index: &ExprIndex) -> Result<String> {
714 let base = self.transpile_expr(&index.expr)?;
715 let idx = self.transpile_expr(&index.index)?;
716 Ok(format!("{base}[{idx}]"))
717 }
718
719 fn transpile_call(&self, call: &ExprCall) -> Result<String> {
721 let func = self.transpile_expr(&call.func)?;
722
723 if let Some(intrinsic) = self.intrinsics.lookup(&func) {
725 return self.transpile_intrinsic_call(intrinsic, &call.args);
726 }
727
728 let args: Vec<String> = call
730 .args
731 .iter()
732 .map(|a| self.transpile_expr(a))
733 .collect::<Result<_>>()?;
734
735 Ok(format!("{}({})", func, args.join(", ")))
736 }
737
738 fn transpile_intrinsic_call(
740 &self,
741 intrinsic: WgslIntrinsic,
742 args: &syn::punctuated::Punctuated<Expr, syn::token::Comma>,
743 ) -> Result<String> {
744 let wgsl_name = intrinsic.to_wgsl();
745
746 if intrinsic.requires_subgroup_extension() {
748 self.needs_subgroup_extension.set(true);
749 }
750
751 match intrinsic {
753 WgslIntrinsic::LocalInvocationIdX
754 | WgslIntrinsic::LocalInvocationIdY
755 | WgslIntrinsic::LocalInvocationIdZ
756 | WgslIntrinsic::WorkgroupIdX
757 | WgslIntrinsic::WorkgroupIdY
758 | WgslIntrinsic::WorkgroupIdZ
759 | WgslIntrinsic::GlobalInvocationIdX
760 | WgslIntrinsic::GlobalInvocationIdY
761 | WgslIntrinsic::GlobalInvocationIdZ
762 | WgslIntrinsic::NumWorkgroupsX
763 | WgslIntrinsic::NumWorkgroupsY
764 | WgslIntrinsic::NumWorkgroupsZ => {
765 return Ok(format!("i32({})", wgsl_name));
767 }
768 WgslIntrinsic::WorkgroupSizeX => {
769 return Ok(format!("i32({}u)", self.workgroup_size.0));
770 }
771 WgslIntrinsic::WorkgroupSizeY => {
772 return Ok(format!("i32({}u)", self.workgroup_size.1));
773 }
774 WgslIntrinsic::WorkgroupSizeZ => {
775 return Ok(format!("i32({}u)", self.workgroup_size.2));
776 }
777 WgslIntrinsic::WorkgroupBarrier | WgslIntrinsic::StorageBarrier => {
778 return Ok(wgsl_name.to_string());
780 }
781 WgslIntrinsic::SubgroupInvocationId | WgslIntrinsic::SubgroupSize => {
782 return Ok(wgsl_name.to_string());
784 }
785 WgslIntrinsic::SubgroupElect => {
786 return Ok(format!("{}()", wgsl_name));
788 }
789 _ => {}
790 }
791
792 let transpiled_args: Vec<String> = args
794 .iter()
795 .map(|a| self.transpile_expr(a))
796 .collect::<Result<_>>()?;
797
798 Ok(format!("{}({})", wgsl_name, transpiled_args.join(", ")))
799 }
800
801 fn transpile_method_call(&self, method: &ExprMethodCall) -> Result<String> {
803 let receiver = self.transpile_expr(&method.receiver)?;
804 let method_name = method.method.to_string();
805
806 if let Some(result) =
808 self.try_transpile_shared_method_call(&receiver, &method_name, &method.args)
809 {
810 return result;
811 }
812
813 if self.ring_kernel_mode && self.context_vars.contains(&receiver) {
815 return self.transpile_context_method(&method_name, &method.args);
816 }
817
818 if self.grid_pos_vars.contains(&receiver) {
820 return self.transpile_stencil_intrinsic(&method_name, &method.args);
821 }
822
823 if let Some(intrinsic) = self.intrinsics.lookup(&method_name) {
825 let wgsl_name = intrinsic.to_wgsl();
826 let args: Vec<String> = std::iter::once(receiver)
827 .chain(
828 method
829 .args
830 .iter()
831 .map(|a| self.transpile_expr(a).unwrap_or_default()),
832 )
833 .collect();
834
835 return Ok(format!("{}({})", wgsl_name, args.join(", ")));
836 }
837
838 let args: Vec<String> = method
840 .args
841 .iter()
842 .map(|a| self.transpile_expr(a))
843 .collect::<Result<_>>()?;
844
845 Ok(format!("{}.{}({})", receiver, method_name, args.join(", ")))
846 }
847
848 fn transpile_context_method(
850 &self,
851 method: &str,
852 args: &syn::punctuated::Punctuated<Expr, syn::token::Comma>,
853 ) -> Result<String> {
854 let ctx_method = WgslContextMethod::from_name(method).ok_or_else(|| {
855 TranspileError::Unsupported(format!("Unknown context method: {}", method))
856 })?;
857
858 let wgsl_args: Vec<String> = args
859 .iter()
860 .map(|a| self.transpile_expr(a).unwrap_or_default())
861 .collect();
862
863 match ctx_method {
864 WgslContextMethod::AtomicAdd
865 | WgslContextMethod::AtomicLoad
866 | WgslContextMethod::AtomicStore => Ok(format!(
867 "{}({})",
868 ctx_method.to_wgsl(),
869 wgsl_args.join(", ")
870 )),
871 _ => Ok(ctx_method.to_wgsl().to_string()),
872 }
873 }
874
875 fn transpile_stencil_intrinsic(
877 &self,
878 method: &str,
879 args: &syn::punctuated::Punctuated<Expr, syn::token::Comma>,
880 ) -> Result<String> {
881 let config = self.config.as_ref().ok_or_else(|| {
882 TranspileError::Unsupported("Stencil intrinsic without config".into())
883 })?;
884
885 let buffer_width = config.buffer_width();
886
887 match method {
888 "idx" => Ok("idx".to_string()),
889 "x" => Ok("i32(lx)".to_string()),
890 "y" => Ok("i32(ly)".to_string()),
891 "north" => {
892 if args.is_empty() {
893 return Err(TranspileError::Unsupported(
894 "north() requires buffer argument".into(),
895 ));
896 }
897 let buffer = self.transpile_expr(&args[0])?;
898 Ok(format!("{buffer}[idx - {buffer_width}u]"))
899 }
900 "south" => {
901 if args.is_empty() {
902 return Err(TranspileError::Unsupported(
903 "south() requires buffer argument".into(),
904 ));
905 }
906 let buffer = self.transpile_expr(&args[0])?;
907 Ok(format!("{buffer}[idx + {buffer_width}u]"))
908 }
909 "east" => {
910 if args.is_empty() {
911 return Err(TranspileError::Unsupported(
912 "east() requires buffer argument".into(),
913 ));
914 }
915 let buffer = self.transpile_expr(&args[0])?;
916 Ok(format!("{buffer}[idx + 1u]"))
917 }
918 "west" => {
919 if args.is_empty() {
920 return Err(TranspileError::Unsupported(
921 "west() requires buffer argument".into(),
922 ));
923 }
924 let buffer = self.transpile_expr(&args[0])?;
925 Ok(format!("{buffer}[idx - 1u]"))
926 }
927 "at" => {
928 if args.len() < 3 {
929 return Err(TranspileError::Unsupported(
930 "at() requires buffer, dx, dy arguments".into(),
931 ));
932 }
933 let buffer = self.transpile_expr(&args[0])?;
934 let dx = self.transpile_expr(&args[1])?;
935 let dy = self.transpile_expr(&args[2])?;
936 Ok(format!(
937 "{buffer}[idx + u32(({dy}) * i32({buffer_width}u) + ({dx}))]"
938 ))
939 }
940 _ => Err(TranspileError::Unsupported(format!(
941 "Unknown stencil intrinsic: {}",
942 method
943 ))),
944 }
945 }
946
947 fn transpile_if(&self, if_expr: &ExprIf) -> Result<String> {
949 let cond = self.transpile_expr(&if_expr.cond)?;
950
951 if let Some(Stmt::Expr(Expr::Return(ret), _)) = if_expr.then_branch.stmts.first() {
953 if if_expr.then_branch.stmts.len() == 1 && if_expr.else_branch.is_none() {
954 if ret.expr.is_none() {
955 return Ok(format!("if ({cond}) {{ return; }}"));
956 }
957 let ret_val = self.transpile_expr(ret.expr.as_ref().unwrap())?;
958 return Ok(format!("if ({cond}) {{ return {ret_val}; }}"));
959 }
960 }
961
962 if let Some((_, else_branch)) = &if_expr.else_branch {
964 if let (Some(Stmt::Expr(then_expr, None)), Expr::Block(else_block)) =
965 (if_expr.then_branch.stmts.last(), else_branch.as_ref())
966 {
967 if let Some(Stmt::Expr(else_expr, None)) = else_block.block.stmts.last() {
968 let then_str = self.transpile_expr(then_expr)?;
969 let else_str = self.transpile_expr(else_expr)?;
970 return Ok(format!("select({else_str}, {then_str}, {cond})"));
971 }
972 }
973
974 if let Expr::If(else_if) = else_branch.as_ref() {
976 let then_body = self.transpile_if_body(&if_expr.then_branch)?;
977 let else_part = self.transpile_if(else_if)?;
978 return Ok(format!("if ({cond}) {{{then_body}}} else {else_part}"));
979 } else if let Expr::Block(else_block) = else_branch.as_ref() {
980 let then_body = self.transpile_if_body(&if_expr.then_branch)?;
981 let else_body = self.transpile_if_body(&else_block.block)?;
982 return Ok(format!("if ({cond}) {{{then_body}}} else {{{else_body}}}"));
983 }
984 }
985
986 let then_body = self.transpile_if_body(&if_expr.then_branch)?;
988 Ok(format!("if ({cond}) {{{then_body}}}"))
989 }
990
991 fn transpile_if_body(&self, block: &syn::Block) -> Result<String> {
993 let mut body = String::new();
994 for stmt in &block.stmts {
995 match stmt {
996 Stmt::Expr(expr, Some(_)) => {
997 let expr_str = self.transpile_expr(expr)?;
998 body.push_str(&format!(" {expr_str};"));
999 }
1000 Stmt::Expr(Expr::Return(ret), None) => {
1001 if let Some(ret_expr) = &ret.expr {
1002 let expr_str = self.transpile_expr(ret_expr)?;
1003 body.push_str(&format!(" return {expr_str};"));
1004 } else {
1005 body.push_str(" return;");
1006 }
1007 }
1008 Stmt::Expr(expr, None) => {
1009 let expr_str = self.transpile_expr(expr)?;
1010 body.push_str(&format!(" return {expr_str};"));
1011 }
1012 _ => {}
1013 }
1014 }
1015 Ok(body)
1016 }
1017
1018 fn transpile_assign(&self, assign: &ExprAssign) -> Result<String> {
1020 let left = self.transpile_expr(&assign.left)?;
1021 let right = self.transpile_expr(&assign.right)?;
1022 Ok(format!("{left} = {right}"))
1023 }
1024
1025 fn transpile_cast(&self, cast: &ExprCast) -> Result<String> {
1027 let expr = self.transpile_expr(&cast.expr)?;
1028 let wgsl_type = self
1029 .type_mapper
1030 .map_type(&cast.ty)
1031 .map_err(TranspileError::Type)?;
1032 Ok(format!("{}({})", wgsl_type.to_wgsl(), expr))
1033 }
1034
1035 fn transpile_return(&self, ret: &ExprReturn) -> Result<String> {
1037 if let Some(expr) = &ret.expr {
1038 let expr_str = self.transpile_expr(expr)?;
1039 Ok(format!("return {expr_str}"))
1040 } else {
1041 Ok("return".to_string())
1042 }
1043 }
1044
1045 fn transpile_struct_literal(&self, struct_expr: &ExprStruct) -> Result<String> {
1047 let type_name = struct_expr
1048 .path
1049 .segments
1050 .iter()
1051 .map(|s| s.ident.to_string())
1052 .collect::<Vec<_>>()
1053 .join("_");
1054
1055 let mut fields = Vec::new();
1056 for field in &struct_expr.fields {
1057 let field_name = match &field.member {
1058 syn::Member::Named(ident) => ident.to_string(),
1059 syn::Member::Unnamed(idx) => idx.index.to_string(),
1060 };
1061 let value = self.transpile_expr(&field.expr)?;
1062 fields.push(format!("{}: {}", field_name, value));
1063 }
1064
1065 if struct_expr.rest.is_some() {
1066 return Err(TranspileError::Unsupported(
1067 "Struct update syntax (..base) is not supported in WGSL".into(),
1068 ));
1069 }
1070
1071 Ok(format!("{}({})", type_name, fields.join(", ")))
1072 }
1073
1074 fn transpile_for_loop(&self, for_loop: &ExprForLoop) -> Result<String> {
1078 if !self.validation_mode.allows_loops() {
1079 return Err(TranspileError::Unsupported(
1080 "Loops are not allowed in stencil kernels".into(),
1081 ));
1082 }
1083
1084 let var_name = extract_loop_var(&for_loop.pat)
1085 .ok_or_else(|| TranspileError::Unsupported("Complex pattern in for loop".into()))?;
1086
1087 let header = match for_loop.expr.as_ref() {
1088 Expr::Range(range) => {
1089 let range_info = RangeInfo::from_range(range, |e| self.transpile_expr(e));
1090 range_info.to_wgsl_for_header(&var_name)
1091 }
1092 _ => {
1093 return Err(TranspileError::Unsupported(
1094 "Only range expressions are supported in for loops".into(),
1095 ));
1096 }
1097 };
1098
1099 let body = self.transpile_loop_body(&for_loop.body)?;
1100 Ok(format!("{header} {{\n{body}}}"))
1101 }
1102
1103 fn transpile_while_loop(&self, while_loop: &ExprWhile) -> Result<String> {
1105 if !self.validation_mode.allows_loops() {
1106 return Err(TranspileError::Unsupported(
1107 "Loops are not allowed in stencil kernels".into(),
1108 ));
1109 }
1110
1111 let condition = self.transpile_expr(&while_loop.cond)?;
1112 let body = self.transpile_loop_body(&while_loop.body)?;
1113 Ok(format!("while ({condition}) {{\n{body}}}"))
1114 }
1115
1116 fn transpile_infinite_loop(&self, loop_expr: &ExprLoop) -> Result<String> {
1118 if !self.validation_mode.allows_loops() {
1119 return Err(TranspileError::Unsupported(
1120 "Loops are not allowed in stencil kernels".into(),
1121 ));
1122 }
1123
1124 let body = self.transpile_loop_body(&loop_expr.body)?;
1125 Ok(format!("loop {{\n{body}}}"))
1126 }
1127
1128 fn transpile_break(&self, break_expr: &ExprBreak) -> Result<String> {
1130 if break_expr.label.is_some() {
1131 return Err(TranspileError::Unsupported(
1132 "Labeled break is not supported in WGSL".into(),
1133 ));
1134 }
1135 if break_expr.expr.is_some() {
1136 return Err(TranspileError::Unsupported(
1137 "Break with value is not supported in WGSL".into(),
1138 ));
1139 }
1140 Ok("break".to_string())
1141 }
1142
1143 fn transpile_continue(&self, cont_expr: &ExprContinue) -> Result<String> {
1145 if cont_expr.label.is_some() {
1146 return Err(TranspileError::Unsupported(
1147 "Labeled continue is not supported in WGSL".into(),
1148 ));
1149 }
1150 Ok("continue".to_string())
1151 }
1152
1153 fn transpile_loop_body(&self, block: &syn::Block) -> Result<String> {
1155 let mut output = String::new();
1156 let inner_indent = " ".repeat(self.indent + 1);
1157
1158 for stmt in &block.stmts {
1159 match stmt {
1160 Stmt::Local(local) => {
1161 let var_name = match &local.pat {
1162 Pat::Ident(ident) => ident.ident.to_string(),
1163 Pat::Type(pat_type) => {
1164 if let Pat::Ident(ident) = pat_type.pat.as_ref() {
1165 ident.ident.to_string()
1166 } else {
1167 return Err(TranspileError::Unsupported(
1168 "Complex pattern in let binding".into(),
1169 ));
1170 }
1171 }
1172 _ => {
1173 return Err(TranspileError::Unsupported(
1174 "Complex pattern in let binding".into(),
1175 ))
1176 }
1177 };
1178
1179 if let Some(init) = &local.init {
1180 let expr_str = self.transpile_expr(&init.expr)?;
1181 let type_str = self.infer_wgsl_type(&init.expr);
1182 output.push_str(&format!(
1183 "{inner_indent}var {var_name}: {type_str} = {expr_str};\n"
1184 ));
1185 } else {
1186 output.push_str(&format!("{inner_indent}var {var_name}: f32;\n"));
1187 }
1188 }
1189 Stmt::Expr(expr, _semi) => {
1190 let expr_str = self.transpile_expr(expr)?;
1191 output.push_str(&format!("{inner_indent}{expr_str};\n"));
1192 }
1193 _ => {
1194 return Err(TranspileError::Unsupported(
1195 "Unsupported statement in loop body".into(),
1196 ));
1197 }
1198 }
1199 }
1200
1201 let closing_indent = " ".repeat(self.indent);
1202 output.push_str(&closing_indent);
1203
1204 Ok(output)
1205 }
1206
1207 fn transpile_match(&self, match_expr: &ExprMatch) -> Result<String> {
1209 let scrutinee = self.transpile_expr(&match_expr.expr)?;
1210 let mut output = format!("switch ({scrutinee}) {{\n");
1211
1212 for arm in &match_expr.arms {
1213 let case_label = self.transpile_match_pattern(&arm.pat)?;
1214
1215 if case_label == "default" || case_label.starts_with("/*") {
1216 output.push_str(" default: {\n");
1217 } else {
1218 output.push_str(&format!(" case {case_label}: {{\n"));
1219 }
1220
1221 match arm.body.as_ref() {
1222 Expr::Block(block) => {
1223 for stmt in &block.block.stmts {
1224 let stmt_str = self.transpile_stmt_inline(stmt)?;
1225 output.push_str(&format!(" {stmt_str}\n"));
1226 }
1227 }
1228 _ => {
1229 let body = self.transpile_expr(&arm.body)?;
1230 output.push_str(&format!(" {body};\n"));
1231 }
1232 }
1233
1234 output.push_str(" }\n");
1235 }
1236
1237 output.push('}');
1238 Ok(output)
1239 }
1240
1241 fn transpile_match_pattern(&self, pat: &Pat) -> Result<String> {
1243 match pat {
1244 Pat::Lit(pat_lit) => match &pat_lit.lit {
1245 Lit::Int(i) => Ok(i.to_string()),
1246 Lit::Bool(b) => Ok(if b.value { "true" } else { "false" }.to_string()),
1247 _ => Err(TranspileError::Unsupported(
1248 "Non-integer literal in match pattern".into(),
1249 )),
1250 },
1251 Pat::Wild(_) => Ok("default".to_string()),
1252 Pat::Ident(ident) => Ok(format!("/* {} */ default", ident.ident)),
1253 Pat::Or(pat_or) => {
1254 if let Some(first) = pat_or.cases.first() {
1255 self.transpile_match_pattern(first)
1256 } else {
1257 Err(TranspileError::Unsupported("Empty or pattern".into()))
1258 }
1259 }
1260 _ => Err(TranspileError::Unsupported(format!(
1261 "Match pattern: {}",
1262 pat.to_token_stream()
1263 ))),
1264 }
1265 }
1266
1267 fn transpile_stmt_inline(&self, stmt: &Stmt) -> Result<String> {
1269 match stmt {
1270 Stmt::Local(local) => {
1271 let var_name = match &local.pat {
1272 Pat::Ident(ident) => ident.ident.to_string(),
1273 Pat::Type(pat_type) => {
1274 if let Pat::Ident(ident) = pat_type.pat.as_ref() {
1275 ident.ident.to_string()
1276 } else {
1277 return Err(TranspileError::Unsupported(
1278 "Complex pattern in let binding".into(),
1279 ));
1280 }
1281 }
1282 _ => {
1283 return Err(TranspileError::Unsupported(
1284 "Complex pattern in let binding".into(),
1285 ))
1286 }
1287 };
1288
1289 if let Some(init) = &local.init {
1290 let expr_str = self.transpile_expr(&init.expr)?;
1291 let type_str = self.infer_wgsl_type(&init.expr);
1292 Ok(format!("var {var_name}: {type_str} = {expr_str};"))
1293 } else {
1294 Ok(format!("var {var_name}: f32;"))
1295 }
1296 }
1297 Stmt::Expr(expr, semi) => {
1298 let expr_str = self.transpile_expr(expr)?;
1299 if semi.is_some() {
1300 Ok(format!("{expr_str};"))
1301 } else {
1302 Ok(format!("return {expr_str};"))
1303 }
1304 }
1305 _ => Err(TranspileError::Unsupported(
1306 "Unsupported statement in match arm".into(),
1307 )),
1308 }
1309 }
1310
1311 fn try_parse_shared_declaration(
1315 &self,
1316 local: &syn::Local,
1317 var_name: &str,
1318 ) -> Result<Option<SharedMemoryDecl>> {
1319 if let Pat::Type(pat_type) = &local.pat {
1320 let type_str = pat_type.ty.to_token_stream().to_string();
1321 return self.parse_shared_type(&type_str, var_name);
1322 }
1323
1324 if let Some(init) = &local.init {
1325 if let Expr::Call(call) = init.expr.as_ref() {
1326 if let Expr::Path(path) = call.func.as_ref() {
1327 let path_str = path.to_token_stream().to_string();
1328 return self.parse_shared_type(&path_str, var_name);
1329 }
1330 }
1331 }
1332
1333 Ok(None)
1334 }
1335
1336 fn parse_shared_type(
1338 &self,
1339 type_str: &str,
1340 var_name: &str,
1341 ) -> Result<Option<SharedMemoryDecl>> {
1342 let type_str = type_str
1343 .replace(" :: ", "::")
1344 .replace(" ::", "::")
1345 .replace(":: ", "::");
1346
1347 if type_str.contains("SharedTile") {
1348 if let Some(start) = type_str.find('<') {
1349 if let Some(end) = type_str.rfind('>') {
1350 let params = &type_str[start + 1..end];
1351 let parts: Vec<&str> = params.split(',').map(|s| s.trim()).collect();
1352
1353 if parts.len() >= 3 {
1354 let rust_type = parts[0];
1355 let width: u32 = parts[1].parse().map_err(|_| {
1356 TranspileError::Unsupported("Invalid SharedTile width".into())
1357 })?;
1358 let height: u32 = parts[2].parse().map_err(|_| {
1359 TranspileError::Unsupported("Invalid SharedTile height".into())
1360 })?;
1361
1362 let wgsl_type = rust_type_to_wgsl(rust_type);
1363 return Ok(Some(SharedMemoryDecl::new_2d(
1364 var_name, wgsl_type, width, height,
1365 )));
1366 }
1367 }
1368 }
1369 }
1370
1371 if type_str.contains("SharedArray") {
1372 if let Some(start) = type_str.find('<') {
1373 if let Some(end) = type_str.rfind('>') {
1374 let params = &type_str[start + 1..end];
1375 let parts: Vec<&str> = params.split(',').map(|s| s.trim()).collect();
1376
1377 if parts.len() >= 2 {
1378 let rust_type = parts[0];
1379 let size: u32 = parts[1].parse().map_err(|_| {
1380 TranspileError::Unsupported("Invalid SharedArray size".into())
1381 })?;
1382
1383 let wgsl_type = rust_type_to_wgsl(rust_type);
1384 return Ok(Some(SharedMemoryDecl::new_1d(var_name, wgsl_type, size)));
1385 }
1386 }
1387 }
1388 }
1389
1390 Ok(None)
1391 }
1392
1393 fn try_transpile_shared_method_call(
1395 &self,
1396 receiver: &str,
1397 method_name: &str,
1398 args: &syn::punctuated::Punctuated<Expr, syn::token::Comma>,
1399 ) -> Option<Result<String>> {
1400 let shared_info = self.shared_vars.get(receiver)?;
1401
1402 match method_name {
1403 "get" => {
1404 if shared_info.is_tile {
1405 if args.len() >= 2 {
1406 let x = self.transpile_expr(&args[0]).ok()?;
1407 let y = self.transpile_expr(&args[1]).ok()?;
1408 Some(Ok(format!("{}[{}][{}]", receiver, y, x)))
1409 } else {
1410 Some(Err(TranspileError::Unsupported(
1411 "SharedTile.get requires x and y arguments".into(),
1412 )))
1413 }
1414 } else if !args.is_empty() {
1415 let idx = self.transpile_expr(&args[0]).ok()?;
1416 Some(Ok(format!("{}[{}]", receiver, idx)))
1417 } else {
1418 Some(Err(TranspileError::Unsupported(
1419 "SharedArray.get requires index argument".into(),
1420 )))
1421 }
1422 }
1423 "set" => {
1424 if shared_info.is_tile {
1425 if args.len() >= 3 {
1426 let x = self.transpile_expr(&args[0]).ok()?;
1427 let y = self.transpile_expr(&args[1]).ok()?;
1428 let val = self.transpile_expr(&args[2]).ok()?;
1429 Some(Ok(format!("{}[{}][{}] = {}", receiver, y, x, val)))
1430 } else {
1431 Some(Err(TranspileError::Unsupported(
1432 "SharedTile.set requires x, y, and value arguments".into(),
1433 )))
1434 }
1435 } else if args.len() >= 2 {
1436 let idx = self.transpile_expr(&args[0]).ok()?;
1437 let val = self.transpile_expr(&args[1]).ok()?;
1438 Some(Ok(format!("{}[{}] = {}", receiver, idx, val)))
1439 } else {
1440 Some(Err(TranspileError::Unsupported(
1441 "SharedArray.set requires index and value arguments".into(),
1442 )))
1443 }
1444 }
1445 "width" if shared_info.is_tile => Some(Ok(shared_info.dimensions[1].to_string())),
1446 "height" if shared_info.is_tile => Some(Ok(shared_info.dimensions[0].to_string())),
1447 "size" => {
1448 let total: usize = shared_info.dimensions.iter().product();
1449 Some(Ok(total.to_string()))
1450 }
1451 _ => None,
1452 }
1453 }
1454
1455 fn infer_wgsl_type(&self, expr: &Expr) -> &'static str {
1457 match expr {
1458 Expr::Lit(lit) => match &lit.lit {
1459 Lit::Float(_) => "f32",
1460 Lit::Int(i) => {
1461 let s = i.to_string();
1462 if s.ends_with("u32") || s.ends_with("usize") {
1463 "u32"
1464 } else {
1465 "i32"
1466 }
1467 }
1468 Lit::Bool(_) => "bool",
1469 _ => "f32",
1470 },
1471 Expr::Binary(bin) => {
1472 let left_type = self.infer_wgsl_type(&bin.left);
1473 let right_type = self.infer_wgsl_type(&bin.right);
1474 if left_type == "i32" && right_type == "i32" {
1475 "i32"
1476 } else if left_type == "u32" && right_type == "u32" {
1477 "u32"
1478 } else {
1479 "f32"
1480 }
1481 }
1482 Expr::Call(call) => {
1483 if let Ok(func) = self.transpile_expr(&call.func) {
1484 if let Some(
1485 WgslIntrinsic::LocalInvocationIdX
1486 | WgslIntrinsic::WorkgroupIdX
1487 | WgslIntrinsic::GlobalInvocationIdX,
1488 ) = self.intrinsics.lookup(&func)
1489 {
1490 return "i32";
1491 }
1492 }
1493 "f32"
1494 }
1495 Expr::Cast(cast) => {
1496 if let Ok(wgsl_type) = self.type_mapper.map_type(&cast.ty) {
1497 match wgsl_type {
1498 WgslType::I32 => return "i32",
1499 WgslType::U32 => return "u32",
1500 WgslType::F32 => return "f32",
1501 WgslType::Bool => return "bool",
1502 _ => {}
1503 }
1504 }
1505 "f32"
1506 }
1507 _ => "f32",
1508 }
1509 }
1510}
1511
1512fn extract_loop_var(pat: &Pat) -> Option<String> {
1514 match pat {
1515 Pat::Ident(ident) => Some(ident.ident.to_string()),
1516 _ => None,
1517 }
1518}
1519
1520fn rust_type_to_wgsl(rust_type: &str) -> WgslType {
1522 match rust_type {
1523 "f32" => WgslType::F32,
1524 "f64" => WgslType::F32, "i32" => WgslType::I32,
1526 "u32" => WgslType::U32,
1527 "bool" => WgslType::Bool,
1528 _ => WgslType::F32,
1529 }
1530}
1531
1532pub fn transpile_function(func: &ItemFn) -> Result<String> {
1534 let mut transpiler = WgslTranspiler::new_generic();
1535
1536 let name = func.sig.ident.to_string();
1537
1538 let mut params = Vec::new();
1539 for param in &func.sig.inputs {
1540 if let FnArg::Typed(pat_type) = param {
1541 let param_name = match pat_type.pat.as_ref() {
1542 Pat::Ident(ident) => ident.ident.to_string(),
1543 _ => continue,
1544 };
1545
1546 let wgsl_type = transpiler
1547 .type_mapper
1548 .map_type(&pat_type.ty)
1549 .map_err(TranspileError::Type)?;
1550 params.push(format!("{}: {}", param_name, wgsl_type.to_wgsl()));
1551 }
1552 }
1553
1554 let return_type = match &func.sig.output {
1555 ReturnType::Default => "".to_string(),
1556 ReturnType::Type(_, ty) => {
1557 let wgsl_type = transpiler
1558 .type_mapper
1559 .map_type(ty)
1560 .map_err(TranspileError::Type)?;
1561 format!(" -> {}", wgsl_type.to_wgsl())
1562 }
1563 };
1564
1565 let body = transpiler.transpile_block(&func.block)?;
1566
1567 Ok(format!(
1568 "fn {name}({}){return_type} {{\n{body}}}\n",
1569 params.join(", ")
1570 ))
1571}
1572
1573impl RangeInfo {
1576 pub fn from_range<F>(range: &syn::ExprRange, transpile: F) -> Self
1578 where
1579 F: Fn(&Expr) -> Result<String>,
1580 {
1581 let start = range.start.as_ref().and_then(|e| transpile(e).ok());
1582 let end = range.end.as_ref().and_then(|e| transpile(e).ok());
1583 let inclusive = matches!(range.limits, syn::RangeLimits::Closed(_));
1584
1585 Self::new(start, end, inclusive)
1586 }
1587
1588 pub fn to_wgsl_for_header(&self, var_name: &str) -> String {
1590 let start = self.start_or_default();
1591 let end = self.end.as_deref().unwrap_or("/* unbounded */");
1592 let op = if self.inclusive { "<=" } else { "<" };
1593
1594 format!(
1595 "for (var {var_name}: i32 = {start}; {var_name} {op} {end}; {var_name} = {var_name} + 1)"
1596 )
1597 }
1598}
1599
#[cfg(test)]
mod tests {
    //! Unit tests for lowering individual Rust expressions and statements to WGSL.

    use super::*;
    use syn::parse_quote;

    /// Lowers `expr` with a fresh generic transpiler, panicking if lowering fails.
    fn lower(expr: &Expr) -> String {
        WgslTranspiler::new_generic()
            .transpile_expr(expr)
            .expect("expression should transpile")
    }

    #[test]
    fn test_simple_arithmetic() {
        let expr: Expr = parse_quote!(a + b * 2.0);
        assert_eq!(lower(&expr), "a + b * 2.0");
    }

    #[test]
    fn test_let_binding() {
        let stmt: Stmt = parse_quote!(let x = a + b;);
        let lowered = WgslTranspiler::new_generic()
            .transpile_stmt(&stmt)
            .unwrap();
        assert!(lowered.contains("var x: f32 = a + b;"));
    }

    #[test]
    fn test_array_index() {
        let expr: Expr = parse_quote!(data[idx]);
        assert_eq!(lower(&expr), "data[idx]");
    }

    #[test]
    fn test_intrinsic_thread_idx() {
        let expr: Expr = parse_quote!(thread_idx_x());
        assert!(lower(&expr).contains("local_invocation_id.x"));
    }

    #[test]
    fn test_for_loop() {
        let expr: Expr = parse_quote! {
            for i in 0..n {
                data[i] = 0.0;
            }
        };
        assert!(lower(&expr).contains("for (var i: i32 = 0; i < n; i = i + 1)"));
    }

    #[test]
    fn test_while_loop() {
        let expr: Expr = parse_quote! {
            while x < 10 {
                x += 1;
            }
        };
        assert!(lower(&expr).contains("while (x < 10)"));
    }

    #[test]
    fn test_early_return() {
        let expr: Expr = parse_quote! {
            if idx >= n { return; }
        };
        let lowered = lower(&expr);
        assert!(lowered.contains("if (idx >= n)"));
        assert!(lowered.contains("return;"));
    }

    #[test]
    fn test_sync_threads() {
        let expr: Expr = parse_quote!(sync_threads());
        assert_eq!(lower(&expr), "workgroupBarrier()");
    }
}