1use crate::bindings::{generate_bindings, AccessMode, BindingLayout};
6use crate::handler::WgslContextMethod;
7use crate::intrinsics::{IntrinsicRegistry, WgslIntrinsic};
8use crate::loops::RangeInfo;
9use crate::ring_kernel::RingKernelConfig;
10use crate::shared::{SharedMemoryConfig, SharedMemoryDecl};
11use crate::stencil::StencilConfig;
12use crate::types::{
13 is_grid_pos_type, is_mutable_reference, is_ring_context_type, TypeMapper, WgslType,
14};
15use crate::u64_workarounds::U64Helpers;
16use crate::validation::ValidationMode;
17use crate::{Result, TranspileError};
18use quote::ToTokens;
19use std::collections::HashMap;
20use syn::{
21 BinOp, Expr, ExprAssign, ExprBinary, ExprBreak, ExprCall, ExprCast, ExprContinue, ExprForLoop,
22 ExprIf, ExprIndex, ExprLit, ExprLoop, ExprMatch, ExprMethodCall, ExprParen, ExprPath,
23 ExprReturn, ExprStruct, ExprUnary, ExprWhile, FnArg, ItemFn, Lit, Pat, ReturnType, Stmt, UnOp,
24};
25
26pub struct WgslTranspiler {
28 config: Option<StencilConfig>,
30 #[allow(dead_code)]
32 ring_config: Option<RingKernelConfig>,
33 type_mapper: TypeMapper,
35 intrinsics: IntrinsicRegistry,
37 grid_pos_vars: Vec<String>,
39 context_vars: Vec<String>,
41 indent: usize,
43 validation_mode: ValidationMode,
45 shared_memory: SharedMemoryConfig,
47 pub shared_vars: HashMap<String, SharedVarInfo>,
49 ring_kernel_mode: bool,
51 needs_u64_helpers: bool,
53 needs_subgroup_extension: bool,
55 workgroup_size: (u32, u32, u32),
57 bindings: Vec<BindingLayout>,
59}
60
/// Metadata recorded for a workgroup shared-memory variable declared in a kernel body.
///
/// Derives `PartialEq`/`Eq` in addition to `Debug`/`Clone` so callers and tests can
/// compare the recorded metadata directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SharedVarInfo {
    /// Variable name as written in the Rust source.
    pub name: String,
    /// `true` when the declaration has exactly two dimensions.
    pub is_tile: bool,
    /// Extent of each declared dimension.
    pub dimensions: Vec<usize>,
    /// WGSL spelling of the element type.
    pub element_type: String,
}
73
74impl WgslTranspiler {
75 pub fn new_stencil(config: StencilConfig) -> Self {
77 Self {
78 config: Some(config),
79 ring_config: None,
80 type_mapper: TypeMapper::new(),
81 intrinsics: IntrinsicRegistry::new(),
82 grid_pos_vars: Vec::new(),
83 context_vars: Vec::new(),
84 indent: 1,
85 validation_mode: ValidationMode::Stencil,
86 shared_memory: SharedMemoryConfig::new(),
87 shared_vars: HashMap::new(),
88 ring_kernel_mode: false,
89 needs_u64_helpers: false,
90 needs_subgroup_extension: false,
91 workgroup_size: (256, 1, 1),
92 bindings: Vec::new(),
93 }
94 }
95
96 pub fn new_generic() -> Self {
98 Self {
99 config: None,
100 ring_config: None,
101 type_mapper: TypeMapper::new(),
102 intrinsics: IntrinsicRegistry::new(),
103 grid_pos_vars: Vec::new(),
104 context_vars: Vec::new(),
105 indent: 1,
106 validation_mode: ValidationMode::Generic,
107 shared_memory: SharedMemoryConfig::new(),
108 shared_vars: HashMap::new(),
109 ring_kernel_mode: false,
110 needs_u64_helpers: false,
111 needs_subgroup_extension: false,
112 workgroup_size: (256, 1, 1),
113 bindings: Vec::new(),
114 }
115 }
116
117 pub fn new_ring_kernel(config: RingKernelConfig) -> Self {
119 Self {
120 config: None,
121 ring_config: Some(config.clone()),
122 type_mapper: TypeMapper::new(),
123 intrinsics: IntrinsicRegistry::new(),
124 grid_pos_vars: Vec::new(),
125 context_vars: Vec::new(),
126 indent: 2,
127 validation_mode: ValidationMode::Generic,
128 shared_memory: SharedMemoryConfig::new(),
129 shared_vars: HashMap::new(),
130 ring_kernel_mode: true,
131 needs_u64_helpers: true, needs_subgroup_extension: false,
133 workgroup_size: (config.workgroup_size, 1, 1),
134 bindings: Vec::new(),
135 }
136 }
137
138 pub fn with_workgroup_size(mut self, x: u32, y: u32, z: u32) -> Self {
140 self.workgroup_size = (x, y, z);
141 self
142 }
143
144 fn indent_str(&self) -> String {
146 " ".repeat(self.indent)
147 }
148
149 pub fn transpile_stencil(&mut self, func: &ItemFn) -> Result<String> {
151 let config = self
152 .config
153 .as_ref()
154 .ok_or_else(|| TranspileError::Unsupported("No stencil config provided".into()))?
155 .clone();
156
157 for param in &func.sig.inputs {
159 if let FnArg::Typed(pat_type) = param {
160 if is_grid_pos_type(&pat_type.ty) {
161 if let Pat::Ident(ident) = pat_type.pat.as_ref() {
162 self.grid_pos_vars.push(ident.ident.to_string());
163 }
164 }
165 }
166 }
167
168 self.collect_bindings(func)?;
170
171 let mut output = String::new();
172
173 output.push_str(&generate_bindings(&self.bindings));
175 output.push_str("\n\n");
176
177 if config.use_shared_memory {
179 let buffer_width = config.buffer_width();
180 let buffer_height = config.buffer_height();
181 output.push_str(&format!(
182 "var<workgroup> tile: array<array<f32, {}>, {}>;\n\n",
183 buffer_width, buffer_height
184 ));
185 }
186
187 output.push_str("@compute ");
189 output.push_str(&config.workgroup_size_annotation());
190 output.push('\n');
191 output.push_str(&format!("fn {}(\n", func.sig.ident));
192 output.push_str(" @builtin(local_invocation_id) local_id: vec3<u32>,\n");
193 output.push_str(" @builtin(workgroup_id) workgroup_id: vec3<u32>,\n");
194 output.push_str(" @builtin(global_invocation_id) global_id: vec3<u32>\n");
195 output.push_str(") {\n");
196
197 output.push_str(&self.generate_stencil_preamble(&config));
199
200 let body = self.transpile_block(&func.block)?;
202 output.push_str(&body);
203
204 output.push_str("}\n");
205
206 Ok(output)
207 }
208
209 fn generate_stencil_preamble(&self, config: &StencilConfig) -> String {
211 let buffer_width = config.buffer_width();
212 let mut preamble = String::new();
213
214 preamble.push_str(" // Thread indices\n");
215 preamble.push_str(" let lx = local_id.x;\n");
216 preamble.push_str(" let ly = local_id.y;\n");
217 preamble.push_str(&format!(" let buffer_width = {}u;\n", buffer_width));
218 preamble.push('\n');
219
220 preamble.push_str(&format!(
222 " if (lx >= {}u || ly >= {}u) {{ return; }}\n\n",
223 config.tile_width, config.tile_height
224 ));
225
226 preamble.push_str(&format!(
228 " let idx = (ly + {}u) * buffer_width + (lx + {}u);\n\n",
229 config.halo, config.halo
230 ));
231
232 preamble
233 }
234
235 pub fn transpile_global_kernel(&mut self, func: &ItemFn) -> Result<String> {
237 self.collect_bindings(func)?;
239
240 let mut output = String::new();
241
242 if self.needs_subgroup_extension {
244 output.push_str("enable chromium_experimental_subgroups;\n\n");
245 }
246
247 output.push_str(&generate_bindings(&self.bindings));
249 output.push_str("\n\n");
250
251 if !self.shared_memory.declarations.is_empty() {
253 output.push_str(&self.shared_memory.to_wgsl());
254 output.push_str("\n\n");
255 }
256
257 output.push_str("@compute ");
259 output.push_str(&format!(
260 "@workgroup_size({}, {}, {})\n",
261 self.workgroup_size.0, self.workgroup_size.1, self.workgroup_size.2
262 ));
263 output.push_str(&format!("fn {}(\n", func.sig.ident));
264 output.push_str(" @builtin(local_invocation_id) local_invocation_id: vec3<u32>,\n");
265 output.push_str(" @builtin(workgroup_id) workgroup_id: vec3<u32>,\n");
266 output.push_str(" @builtin(global_invocation_id) global_invocation_id: vec3<u32>,\n");
267 output.push_str(" @builtin(num_workgroups) num_workgroups: vec3<u32>\n");
268 output.push_str(") {\n");
269
270 let body = self.transpile_block(&func.block)?;
272 output.push_str(&body);
273
274 output.push_str("}\n");
275
276 Ok(output)
277 }
278
279 pub fn transpile_ring_kernel(
281 &mut self,
282 handler: &ItemFn,
283 config: &RingKernelConfig,
284 ) -> Result<String> {
285 for param in &handler.sig.inputs {
287 if let FnArg::Typed(pat_type) = param {
288 if is_ring_context_type(&pat_type.ty) {
289 if let Pat::Ident(ident) = pat_type.pat.as_ref() {
290 self.context_vars.push(ident.ident.to_string());
291 }
292 }
293 }
294 }
295
296 self.ring_kernel_mode = true;
297 self.needs_u64_helpers = true;
298
299 let mut output = String::new();
300
301 output.push_str(&U64Helpers::generate_all());
303 output.push_str("\n\n");
304
305 output.push_str(&crate::ring_kernel::generate_control_block_struct(config));
307 output.push_str("\n\n");
308
309 output.push_str(crate::ring_kernel::generate_ring_kernel_bindings());
311 output.push_str("\n\n");
312
313 output.push_str("@compute ");
315 output.push_str(&config.workgroup_size_annotation());
316 output.push('\n');
317 output.push_str(&format!("fn ring_kernel_{}(\n", config.name));
318 output.push_str(" @builtin(local_invocation_id) local_invocation_id: vec3<u32>,\n");
319 output.push_str(" @builtin(workgroup_id) workgroup_id: vec3<u32>,\n");
320 output.push_str(" @builtin(global_invocation_id) global_invocation_id: vec3<u32>\n");
321 output.push_str(") {\n");
322
323 output.push_str(crate::ring_kernel::generate_ring_kernel_preamble());
325 output.push_str("\n\n");
326
327 output.push_str(" // === USER HANDLER CODE ===\n");
329 self.indent = 1;
330 let handler_body = self.transpile_block(&handler.block)?;
331 output.push_str(&handler_body);
332 output.push_str(" // === END HANDLER CODE ===\n\n");
333
334 output.push_str(" // Update message counter\n");
336 output.push_str(
337 " atomic_inc_u64(&control.messages_processed_lo, &control.messages_processed_hi);\n",
338 );
339
340 output.push_str("}\n");
341
342 Ok(output)
343 }
344
345 fn collect_bindings(&mut self, func: &ItemFn) -> Result<()> {
347 let mut binding_idx = 0u32;
348
349 for param in &func.sig.inputs {
350 if let FnArg::Typed(pat_type) = param {
351 if is_grid_pos_type(&pat_type.ty) || is_ring_context_type(&pat_type.ty) {
353 continue;
354 }
355
356 let param_name = match pat_type.pat.as_ref() {
357 Pat::Ident(ident) => ident.ident.to_string(),
358 _ => continue,
359 };
360
361 let wgsl_type = self
362 .type_mapper
363 .map_type(&pat_type.ty)
364 .map_err(TranspileError::Type)?;
365 let is_mutable = is_mutable_reference(&pat_type.ty);
366
367 match &wgsl_type {
369 WgslType::Ptr { inner, .. } => {
370 let access = if is_mutable {
371 AccessMode::ReadWrite
372 } else {
373 AccessMode::Read
374 };
375 self.bindings.push(BindingLayout::new(
376 0,
377 binding_idx,
378 ¶m_name,
379 WgslType::Array {
380 element: inner.clone(),
381 size: None,
382 },
383 access,
384 ));
385 binding_idx += 1;
386 }
387 WgslType::Array { .. } => {
388 let access = if is_mutable {
389 AccessMode::ReadWrite
390 } else {
391 AccessMode::Read
392 };
393 self.bindings.push(BindingLayout::new(
394 0,
395 binding_idx,
396 ¶m_name,
397 wgsl_type.clone(),
398 access,
399 ));
400 binding_idx += 1;
401 }
402 _ => {
404 self.bindings.push(BindingLayout::uniform(
405 binding_idx,
406 ¶m_name,
407 wgsl_type.clone(),
408 ));
409 binding_idx += 1;
410 }
411 }
412 }
413 }
414
415 Ok(())
416 }
417
418 fn transpile_block(&mut self, block: &syn::Block) -> Result<String> {
420 let mut output = String::new();
421
422 for stmt in &block.stmts {
423 let stmt_str = self.transpile_stmt(stmt)?;
424 if !stmt_str.is_empty() {
425 output.push_str(&stmt_str);
426 }
427 }
428
429 Ok(output)
430 }
431
432 fn transpile_stmt(&mut self, stmt: &Stmt) -> Result<String> {
434 match stmt {
435 Stmt::Local(local) => {
436 let indent = self.indent_str();
437
438 let var_name = match &local.pat {
440 Pat::Ident(ident) => ident.ident.to_string(),
441 Pat::Type(pat_type) => {
442 if let Pat::Ident(ident) = pat_type.pat.as_ref() {
443 ident.ident.to_string()
444 } else {
445 return Err(TranspileError::Unsupported(
446 "Complex pattern in let binding".into(),
447 ));
448 }
449 }
450 _ => {
451 return Err(TranspileError::Unsupported(
452 "Complex pattern in let binding".into(),
453 ))
454 }
455 };
456
457 if let Some(shared_decl) = self.try_parse_shared_declaration(local, &var_name)? {
459 self.shared_vars.insert(
460 var_name.clone(),
461 SharedVarInfo {
462 name: var_name.clone(),
463 is_tile: shared_decl.dimensions.len() == 2,
464 dimensions: shared_decl
465 .dimensions
466 .iter()
467 .map(|&d| d as usize)
468 .collect(),
469 element_type: shared_decl.element_type.to_wgsl(),
470 },
471 );
472 self.shared_memory.add(shared_decl.clone());
473 return Ok(format!(
474 "{indent}// shared memory: {} declared at module scope\n",
475 var_name
476 ));
477 }
478
479 if let Some(init) = &local.init {
481 let expr_str = self.transpile_expr(&init.expr)?;
482 let type_str = self.infer_wgsl_type(&init.expr);
483
484 Ok(format!(
485 "{indent}var {var_name}: {type_str} = {expr_str};\n"
486 ))
487 } else {
488 Ok(format!("{indent}var {var_name}: f32;\n"))
489 }
490 }
491 Stmt::Expr(expr, semi) => {
492 let indent = self.indent_str();
493
494 if let Expr::If(if_expr) = expr {
496 if let Some(Stmt::Expr(Expr::Return(_), _)) = if_expr.then_branch.stmts.first()
497 {
498 if if_expr.then_branch.stmts.len() == 1 && if_expr.else_branch.is_none() {
499 let expr_str = self.transpile_expr(expr)?;
500 return Ok(format!("{indent}{expr_str}\n"));
501 }
502 }
503 }
504
505 let expr_str = self.transpile_expr(expr)?;
506
507 if semi.is_some()
508 || matches!(expr, Expr::Return(_))
509 || expr_str.starts_with("return")
510 || expr_str.starts_with("if (")
511 {
512 Ok(format!("{indent}{expr_str};\n"))
513 } else {
514 Ok(format!("{indent}return {expr_str};\n"))
515 }
516 }
517 Stmt::Item(_) => Err(TranspileError::Unsupported("Item in function body".into())),
518 Stmt::Macro(_) => Err(TranspileError::Unsupported("Macro in function body".into())),
519 }
520 }
521
522 fn transpile_expr(&self, expr: &Expr) -> Result<String> {
524 match expr {
525 Expr::Lit(lit) => self.transpile_lit(lit),
526 Expr::Path(path) => self.transpile_path(path),
527 Expr::Binary(bin) => self.transpile_binary(bin),
528 Expr::Unary(unary) => self.transpile_unary(unary),
529 Expr::Paren(paren) => self.transpile_paren(paren),
530 Expr::Index(index) => self.transpile_index(index),
531 Expr::Call(call) => self.transpile_call(call),
532 Expr::MethodCall(method) => self.transpile_method_call(method),
533 Expr::If(if_expr) => self.transpile_if(if_expr),
534 Expr::Assign(assign) => self.transpile_assign(assign),
535 Expr::Cast(cast) => self.transpile_cast(cast),
536 Expr::Match(match_expr) => self.transpile_match(match_expr),
537 Expr::Block(block) => {
538 if let Some(Stmt::Expr(expr, None)) = block.block.stmts.last() {
539 self.transpile_expr(expr)
540 } else {
541 Err(TranspileError::Unsupported(
542 "Complex block expression".into(),
543 ))
544 }
545 }
546 Expr::Field(field) => {
547 let base = self.transpile_expr(&field.base)?;
548 let member = match &field.member {
549 syn::Member::Named(ident) => ident.to_string(),
550 syn::Member::Unnamed(idx) => idx.index.to_string(),
551 };
552 Ok(format!("{base}.{member}"))
553 }
554 Expr::Return(ret) => self.transpile_return(ret),
555 Expr::ForLoop(for_loop) => self.transpile_for_loop(for_loop),
556 Expr::While(while_loop) => self.transpile_while_loop(while_loop),
557 Expr::Loop(loop_expr) => self.transpile_infinite_loop(loop_expr),
558 Expr::Break(break_expr) => self.transpile_break(break_expr),
559 Expr::Continue(cont_expr) => self.transpile_continue(cont_expr),
560 Expr::Struct(struct_expr) => self.transpile_struct_literal(struct_expr),
561 Expr::Reference(ref_expr) => {
562 let inner = self.transpile_expr(&ref_expr.expr)?;
564 Ok(format!("&{inner}"))
565 }
566 _ => Err(TranspileError::Unsupported(format!(
567 "Expression type: {}",
568 expr.to_token_stream()
569 ))),
570 }
571 }
572
573 fn transpile_lit(&self, lit: &ExprLit) -> Result<String> {
575 match &lit.lit {
576 Lit::Float(f) => {
577 let s = f.to_string();
578 let num = s.trim_end_matches("f32").trim_end_matches("f64");
580 if num.contains('.') {
582 Ok(num.to_string())
583 } else {
584 Ok(format!("{}.0", num))
585 }
586 }
587 Lit::Int(i) => {
588 let s = i.to_string();
589 if s.ends_with("u32") || s.ends_with("usize") {
591 Ok(format!(
592 "{}u",
593 s.trim_end_matches("u32").trim_end_matches("usize")
594 ))
595 } else if s.ends_with("i32") || s.ends_with("isize") {
596 Ok(format!(
597 "{}i",
598 s.trim_end_matches("i32").trim_end_matches("isize")
599 ))
600 } else {
601 Ok(s)
603 }
604 }
605 Lit::Bool(b) => Ok(if b.value { "true" } else { "false" }.to_string()),
606 _ => Err(TranspileError::Unsupported(format!(
607 "Literal type: {}",
608 lit.to_token_stream()
609 ))),
610 }
611 }
612
613 fn transpile_path(&self, path: &ExprPath) -> Result<String> {
615 let segments: Vec<_> = path
616 .path
617 .segments
618 .iter()
619 .map(|s| s.ident.to_string())
620 .collect();
621
622 if segments.len() == 1 {
623 Ok(segments[0].clone())
624 } else {
625 Ok(segments.join("_"))
626 }
627 }
628
629 fn transpile_binary(&self, bin: &ExprBinary) -> Result<String> {
631 let left = self.transpile_expr(&bin.left)?;
632 let right = self.transpile_expr(&bin.right)?;
633
634 let op = match bin.op {
635 BinOp::Add(_) => "+",
636 BinOp::Sub(_) => "-",
637 BinOp::Mul(_) => "*",
638 BinOp::Div(_) => "/",
639 BinOp::Rem(_) => "%",
640 BinOp::And(_) => "&&",
641 BinOp::Or(_) => "||",
642 BinOp::BitXor(_) => "^",
643 BinOp::BitAnd(_) => "&",
644 BinOp::BitOr(_) => "|",
645 BinOp::Shl(_) => "<<",
646 BinOp::Shr(_) => ">>",
647 BinOp::Eq(_) => "==",
648 BinOp::Lt(_) => "<",
649 BinOp::Le(_) => "<=",
650 BinOp::Ne(_) => "!=",
651 BinOp::Ge(_) => ">=",
652 BinOp::Gt(_) => ">",
653 BinOp::AddAssign(_) => "+=",
654 BinOp::SubAssign(_) => "-=",
655 BinOp::MulAssign(_) => "*=",
656 BinOp::DivAssign(_) => "/=",
657 BinOp::RemAssign(_) => "%=",
658 BinOp::BitXorAssign(_) => "^=",
659 BinOp::BitAndAssign(_) => "&=",
660 BinOp::BitOrAssign(_) => "|=",
661 BinOp::ShlAssign(_) => "<<=",
662 BinOp::ShrAssign(_) => ">>=",
663 _ => {
664 return Err(TranspileError::Unsupported(format!(
665 "Binary operator: {}",
666 bin.to_token_stream()
667 )))
668 }
669 };
670
671 Ok(format!("{left} {op} {right}"))
672 }
673
674 fn transpile_unary(&self, unary: &ExprUnary) -> Result<String> {
676 let expr = self.transpile_expr(&unary.expr)?;
677
678 let op = match unary.op {
679 UnOp::Neg(_) => "-",
680 UnOp::Not(_) => "!",
681 UnOp::Deref(_) => "*",
682 _ => {
683 return Err(TranspileError::Unsupported(format!(
684 "Unary operator: {}",
685 unary.to_token_stream()
686 )))
687 }
688 };
689
690 Ok(format!("{op}({expr})"))
691 }
692
693 fn transpile_paren(&self, paren: &ExprParen) -> Result<String> {
695 let inner = self.transpile_expr(&paren.expr)?;
696 Ok(format!("({inner})"))
697 }
698
699 fn transpile_index(&self, index: &ExprIndex) -> Result<String> {
701 let base = self.transpile_expr(&index.expr)?;
702 let idx = self.transpile_expr(&index.index)?;
703 Ok(format!("{base}[{idx}]"))
704 }
705
706 fn transpile_call(&self, call: &ExprCall) -> Result<String> {
708 let func = self.transpile_expr(&call.func)?;
709
710 if let Some(intrinsic) = self.intrinsics.lookup(&func) {
712 return self.transpile_intrinsic_call(intrinsic, &call.args);
713 }
714
715 let args: Vec<String> = call
717 .args
718 .iter()
719 .map(|a| self.transpile_expr(a))
720 .collect::<Result<_>>()?;
721
722 Ok(format!("{}({})", func, args.join(", ")))
723 }
724
725 fn transpile_intrinsic_call(
727 &self,
728 intrinsic: WgslIntrinsic,
729 args: &syn::punctuated::Punctuated<Expr, syn::token::Comma>,
730 ) -> Result<String> {
731 let wgsl_name = intrinsic.to_wgsl();
732
733 match intrinsic {
735 WgslIntrinsic::LocalInvocationIdX
736 | WgslIntrinsic::LocalInvocationIdY
737 | WgslIntrinsic::LocalInvocationIdZ
738 | WgslIntrinsic::WorkgroupIdX
739 | WgslIntrinsic::WorkgroupIdY
740 | WgslIntrinsic::WorkgroupIdZ
741 | WgslIntrinsic::GlobalInvocationIdX
742 | WgslIntrinsic::GlobalInvocationIdY
743 | WgslIntrinsic::GlobalInvocationIdZ
744 | WgslIntrinsic::NumWorkgroupsX
745 | WgslIntrinsic::NumWorkgroupsY
746 | WgslIntrinsic::NumWorkgroupsZ => {
747 return Ok(format!("i32({})", wgsl_name));
749 }
750 WgslIntrinsic::WorkgroupSizeX => {
751 return Ok(format!("i32({}u)", self.workgroup_size.0));
752 }
753 WgslIntrinsic::WorkgroupSizeY => {
754 return Ok(format!("i32({}u)", self.workgroup_size.1));
755 }
756 WgslIntrinsic::WorkgroupSizeZ => {
757 return Ok(format!("i32({}u)", self.workgroup_size.2));
758 }
759 WgslIntrinsic::WorkgroupBarrier | WgslIntrinsic::StorageBarrier => {
760 return Ok(wgsl_name.to_string());
762 }
763 WgslIntrinsic::SubgroupInvocationId | WgslIntrinsic::SubgroupSize => {
764 return Ok(wgsl_name.to_string());
766 }
767 _ => {}
768 }
769
770 let transpiled_args: Vec<String> = args
772 .iter()
773 .map(|a| self.transpile_expr(a))
774 .collect::<Result<_>>()?;
775
776 Ok(format!("{}({})", wgsl_name, transpiled_args.join(", ")))
777 }
778
779 fn transpile_method_call(&self, method: &ExprMethodCall) -> Result<String> {
781 let receiver = self.transpile_expr(&method.receiver)?;
782 let method_name = method.method.to_string();
783
784 if let Some(result) =
786 self.try_transpile_shared_method_call(&receiver, &method_name, &method.args)
787 {
788 return result;
789 }
790
791 if self.ring_kernel_mode && self.context_vars.contains(&receiver) {
793 return self.transpile_context_method(&method_name, &method.args);
794 }
795
796 if self.grid_pos_vars.contains(&receiver) {
798 return self.transpile_stencil_intrinsic(&method_name, &method.args);
799 }
800
801 if let Some(intrinsic) = self.intrinsics.lookup(&method_name) {
803 let wgsl_name = intrinsic.to_wgsl();
804 let args: Vec<String> = std::iter::once(receiver)
805 .chain(
806 method
807 .args
808 .iter()
809 .map(|a| self.transpile_expr(a).unwrap_or_default()),
810 )
811 .collect();
812
813 return Ok(format!("{}({})", wgsl_name, args.join(", ")));
814 }
815
816 let args: Vec<String> = method
818 .args
819 .iter()
820 .map(|a| self.transpile_expr(a))
821 .collect::<Result<_>>()?;
822
823 Ok(format!("{}.{}({})", receiver, method_name, args.join(", ")))
824 }
825
826 fn transpile_context_method(
828 &self,
829 method: &str,
830 args: &syn::punctuated::Punctuated<Expr, syn::token::Comma>,
831 ) -> Result<String> {
832 let ctx_method = WgslContextMethod::from_name(method).ok_or_else(|| {
833 TranspileError::Unsupported(format!("Unknown context method: {}", method))
834 })?;
835
836 let wgsl_args: Vec<String> = args
837 .iter()
838 .map(|a| self.transpile_expr(a).unwrap_or_default())
839 .collect();
840
841 match ctx_method {
842 WgslContextMethod::AtomicAdd
843 | WgslContextMethod::AtomicLoad
844 | WgslContextMethod::AtomicStore => Ok(format!(
845 "{}({})",
846 ctx_method.to_wgsl(),
847 wgsl_args.join(", ")
848 )),
849 _ => Ok(ctx_method.to_wgsl().to_string()),
850 }
851 }
852
853 fn transpile_stencil_intrinsic(
855 &self,
856 method: &str,
857 args: &syn::punctuated::Punctuated<Expr, syn::token::Comma>,
858 ) -> Result<String> {
859 let config = self.config.as_ref().ok_or_else(|| {
860 TranspileError::Unsupported("Stencil intrinsic without config".into())
861 })?;
862
863 let buffer_width = config.buffer_width();
864
865 match method {
866 "idx" => Ok("idx".to_string()),
867 "x" => Ok("i32(lx)".to_string()),
868 "y" => Ok("i32(ly)".to_string()),
869 "north" => {
870 if args.is_empty() {
871 return Err(TranspileError::Unsupported(
872 "north() requires buffer argument".into(),
873 ));
874 }
875 let buffer = self.transpile_expr(&args[0])?;
876 Ok(format!("{buffer}[idx - {buffer_width}u]"))
877 }
878 "south" => {
879 if args.is_empty() {
880 return Err(TranspileError::Unsupported(
881 "south() requires buffer argument".into(),
882 ));
883 }
884 let buffer = self.transpile_expr(&args[0])?;
885 Ok(format!("{buffer}[idx + {buffer_width}u]"))
886 }
887 "east" => {
888 if args.is_empty() {
889 return Err(TranspileError::Unsupported(
890 "east() requires buffer argument".into(),
891 ));
892 }
893 let buffer = self.transpile_expr(&args[0])?;
894 Ok(format!("{buffer}[idx + 1u]"))
895 }
896 "west" => {
897 if args.is_empty() {
898 return Err(TranspileError::Unsupported(
899 "west() requires buffer argument".into(),
900 ));
901 }
902 let buffer = self.transpile_expr(&args[0])?;
903 Ok(format!("{buffer}[idx - 1u]"))
904 }
905 "at" => {
906 if args.len() < 3 {
907 return Err(TranspileError::Unsupported(
908 "at() requires buffer, dx, dy arguments".into(),
909 ));
910 }
911 let buffer = self.transpile_expr(&args[0])?;
912 let dx = self.transpile_expr(&args[1])?;
913 let dy = self.transpile_expr(&args[2])?;
914 Ok(format!(
915 "{buffer}[idx + u32(({dy}) * i32({buffer_width}u) + ({dx}))]"
916 ))
917 }
918 _ => Err(TranspileError::Unsupported(format!(
919 "Unknown stencil intrinsic: {}",
920 method
921 ))),
922 }
923 }
924
925 fn transpile_if(&self, if_expr: &ExprIf) -> Result<String> {
927 let cond = self.transpile_expr(&if_expr.cond)?;
928
929 if let Some(Stmt::Expr(Expr::Return(ret), _)) = if_expr.then_branch.stmts.first() {
931 if if_expr.then_branch.stmts.len() == 1 && if_expr.else_branch.is_none() {
932 if ret.expr.is_none() {
933 return Ok(format!("if ({cond}) {{ return; }}"));
934 }
935 let ret_val = self.transpile_expr(ret.expr.as_ref().unwrap())?;
936 return Ok(format!("if ({cond}) {{ return {ret_val}; }}"));
937 }
938 }
939
940 if let Some((_, else_branch)) = &if_expr.else_branch {
942 if let (Some(Stmt::Expr(then_expr, None)), Expr::Block(else_block)) =
943 (if_expr.then_branch.stmts.last(), else_branch.as_ref())
944 {
945 if let Some(Stmt::Expr(else_expr, None)) = else_block.block.stmts.last() {
946 let then_str = self.transpile_expr(then_expr)?;
947 let else_str = self.transpile_expr(else_expr)?;
948 return Ok(format!("select({else_str}, {then_str}, {cond})"));
949 }
950 }
951
952 if let Expr::If(else_if) = else_branch.as_ref() {
954 let then_body = self.transpile_if_body(&if_expr.then_branch)?;
955 let else_part = self.transpile_if(else_if)?;
956 return Ok(format!("if ({cond}) {{{then_body}}} else {else_part}"));
957 } else if let Expr::Block(else_block) = else_branch.as_ref() {
958 let then_body = self.transpile_if_body(&if_expr.then_branch)?;
959 let else_body = self.transpile_if_body(&else_block.block)?;
960 return Ok(format!("if ({cond}) {{{then_body}}} else {{{else_body}}}"));
961 }
962 }
963
964 let then_body = self.transpile_if_body(&if_expr.then_branch)?;
966 Ok(format!("if ({cond}) {{{then_body}}}"))
967 }
968
969 fn transpile_if_body(&self, block: &syn::Block) -> Result<String> {
971 let mut body = String::new();
972 for stmt in &block.stmts {
973 match stmt {
974 Stmt::Expr(expr, Some(_)) => {
975 let expr_str = self.transpile_expr(expr)?;
976 body.push_str(&format!(" {expr_str};"));
977 }
978 Stmt::Expr(Expr::Return(ret), None) => {
979 if let Some(ret_expr) = &ret.expr {
980 let expr_str = self.transpile_expr(ret_expr)?;
981 body.push_str(&format!(" return {expr_str};"));
982 } else {
983 body.push_str(" return;");
984 }
985 }
986 Stmt::Expr(expr, None) => {
987 let expr_str = self.transpile_expr(expr)?;
988 body.push_str(&format!(" return {expr_str};"));
989 }
990 _ => {}
991 }
992 }
993 Ok(body)
994 }
995
996 fn transpile_assign(&self, assign: &ExprAssign) -> Result<String> {
998 let left = self.transpile_expr(&assign.left)?;
999 let right = self.transpile_expr(&assign.right)?;
1000 Ok(format!("{left} = {right}"))
1001 }
1002
1003 fn transpile_cast(&self, cast: &ExprCast) -> Result<String> {
1005 let expr = self.transpile_expr(&cast.expr)?;
1006 let wgsl_type = self
1007 .type_mapper
1008 .map_type(&cast.ty)
1009 .map_err(TranspileError::Type)?;
1010 Ok(format!("{}({})", wgsl_type.to_wgsl(), expr))
1011 }
1012
1013 fn transpile_return(&self, ret: &ExprReturn) -> Result<String> {
1015 if let Some(expr) = &ret.expr {
1016 let expr_str = self.transpile_expr(expr)?;
1017 Ok(format!("return {expr_str}"))
1018 } else {
1019 Ok("return".to_string())
1020 }
1021 }
1022
1023 fn transpile_struct_literal(&self, struct_expr: &ExprStruct) -> Result<String> {
1025 let type_name = struct_expr
1026 .path
1027 .segments
1028 .iter()
1029 .map(|s| s.ident.to_string())
1030 .collect::<Vec<_>>()
1031 .join("_");
1032
1033 let mut fields = Vec::new();
1034 for field in &struct_expr.fields {
1035 let field_name = match &field.member {
1036 syn::Member::Named(ident) => ident.to_string(),
1037 syn::Member::Unnamed(idx) => idx.index.to_string(),
1038 };
1039 let value = self.transpile_expr(&field.expr)?;
1040 fields.push(format!("{}: {}", field_name, value));
1041 }
1042
1043 if struct_expr.rest.is_some() {
1044 return Err(TranspileError::Unsupported(
1045 "Struct update syntax (..base) is not supported in WGSL".into(),
1046 ));
1047 }
1048
1049 Ok(format!("{}({})", type_name, fields.join(", ")))
1050 }
1051
1052 fn transpile_for_loop(&self, for_loop: &ExprForLoop) -> Result<String> {
1056 if !self.validation_mode.allows_loops() {
1057 return Err(TranspileError::Unsupported(
1058 "Loops are not allowed in stencil kernels".into(),
1059 ));
1060 }
1061
1062 let var_name = extract_loop_var(&for_loop.pat)
1063 .ok_or_else(|| TranspileError::Unsupported("Complex pattern in for loop".into()))?;
1064
1065 let header = match for_loop.expr.as_ref() {
1066 Expr::Range(range) => {
1067 let range_info = RangeInfo::from_range(range, |e| self.transpile_expr(e));
1068 range_info.to_wgsl_for_header(&var_name)
1069 }
1070 _ => {
1071 return Err(TranspileError::Unsupported(
1072 "Only range expressions are supported in for loops".into(),
1073 ));
1074 }
1075 };
1076
1077 let body = self.transpile_loop_body(&for_loop.body)?;
1078 Ok(format!("{header} {{\n{body}}}"))
1079 }
1080
1081 fn transpile_while_loop(&self, while_loop: &ExprWhile) -> Result<String> {
1083 if !self.validation_mode.allows_loops() {
1084 return Err(TranspileError::Unsupported(
1085 "Loops are not allowed in stencil kernels".into(),
1086 ));
1087 }
1088
1089 let condition = self.transpile_expr(&while_loop.cond)?;
1090 let body = self.transpile_loop_body(&while_loop.body)?;
1091 Ok(format!("while ({condition}) {{\n{body}}}"))
1092 }
1093
1094 fn transpile_infinite_loop(&self, loop_expr: &ExprLoop) -> Result<String> {
1096 if !self.validation_mode.allows_loops() {
1097 return Err(TranspileError::Unsupported(
1098 "Loops are not allowed in stencil kernels".into(),
1099 ));
1100 }
1101
1102 let body = self.transpile_loop_body(&loop_expr.body)?;
1103 Ok(format!("loop {{\n{body}}}"))
1104 }
1105
1106 fn transpile_break(&self, break_expr: &ExprBreak) -> Result<String> {
1108 if break_expr.label.is_some() {
1109 return Err(TranspileError::Unsupported(
1110 "Labeled break is not supported in WGSL".into(),
1111 ));
1112 }
1113 if break_expr.expr.is_some() {
1114 return Err(TranspileError::Unsupported(
1115 "Break with value is not supported in WGSL".into(),
1116 ));
1117 }
1118 Ok("break".to_string())
1119 }
1120
1121 fn transpile_continue(&self, cont_expr: &ExprContinue) -> Result<String> {
1123 if cont_expr.label.is_some() {
1124 return Err(TranspileError::Unsupported(
1125 "Labeled continue is not supported in WGSL".into(),
1126 ));
1127 }
1128 Ok("continue".to_string())
1129 }
1130
1131 fn transpile_loop_body(&self, block: &syn::Block) -> Result<String> {
1133 let mut output = String::new();
1134 let inner_indent = " ".repeat(self.indent + 1);
1135
1136 for stmt in &block.stmts {
1137 match stmt {
1138 Stmt::Local(local) => {
1139 let var_name = match &local.pat {
1140 Pat::Ident(ident) => ident.ident.to_string(),
1141 Pat::Type(pat_type) => {
1142 if let Pat::Ident(ident) = pat_type.pat.as_ref() {
1143 ident.ident.to_string()
1144 } else {
1145 return Err(TranspileError::Unsupported(
1146 "Complex pattern in let binding".into(),
1147 ));
1148 }
1149 }
1150 _ => {
1151 return Err(TranspileError::Unsupported(
1152 "Complex pattern in let binding".into(),
1153 ))
1154 }
1155 };
1156
1157 if let Some(init) = &local.init {
1158 let expr_str = self.transpile_expr(&init.expr)?;
1159 let type_str = self.infer_wgsl_type(&init.expr);
1160 output.push_str(&format!(
1161 "{inner_indent}var {var_name}: {type_str} = {expr_str};\n"
1162 ));
1163 } else {
1164 output.push_str(&format!("{inner_indent}var {var_name}: f32;\n"));
1165 }
1166 }
1167 Stmt::Expr(expr, _semi) => {
1168 let expr_str = self.transpile_expr(expr)?;
1169 output.push_str(&format!("{inner_indent}{expr_str};\n"));
1170 }
1171 _ => {
1172 return Err(TranspileError::Unsupported(
1173 "Unsupported statement in loop body".into(),
1174 ));
1175 }
1176 }
1177 }
1178
1179 let closing_indent = " ".repeat(self.indent);
1180 output.push_str(&closing_indent);
1181
1182 Ok(output)
1183 }
1184
1185 fn transpile_match(&self, match_expr: &ExprMatch) -> Result<String> {
1187 let scrutinee = self.transpile_expr(&match_expr.expr)?;
1188 let mut output = format!("switch ({scrutinee}) {{\n");
1189
1190 for arm in &match_expr.arms {
1191 let case_label = self.transpile_match_pattern(&arm.pat)?;
1192
1193 if case_label == "default" || case_label.starts_with("/*") {
1194 output.push_str(" default: {\n");
1195 } else {
1196 output.push_str(&format!(" case {case_label}: {{\n"));
1197 }
1198
1199 match arm.body.as_ref() {
1200 Expr::Block(block) => {
1201 for stmt in &block.block.stmts {
1202 let stmt_str = self.transpile_stmt_inline(stmt)?;
1203 output.push_str(&format!(" {stmt_str}\n"));
1204 }
1205 }
1206 _ => {
1207 let body = self.transpile_expr(&arm.body)?;
1208 output.push_str(&format!(" {body};\n"));
1209 }
1210 }
1211
1212 output.push_str(" }\n");
1213 }
1214
1215 output.push('}');
1216 Ok(output)
1217 }
1218
1219 fn transpile_match_pattern(&self, pat: &Pat) -> Result<String> {
1221 match pat {
1222 Pat::Lit(pat_lit) => match &pat_lit.lit {
1223 Lit::Int(i) => Ok(i.to_string()),
1224 Lit::Bool(b) => Ok(if b.value { "true" } else { "false" }.to_string()),
1225 _ => Err(TranspileError::Unsupported(
1226 "Non-integer literal in match pattern".into(),
1227 )),
1228 },
1229 Pat::Wild(_) => Ok("default".to_string()),
1230 Pat::Ident(ident) => Ok(format!("/* {} */ default", ident.ident)),
1231 Pat::Or(pat_or) => {
1232 if let Some(first) = pat_or.cases.first() {
1233 self.transpile_match_pattern(first)
1234 } else {
1235 Err(TranspileError::Unsupported("Empty or pattern".into()))
1236 }
1237 }
1238 _ => Err(TranspileError::Unsupported(format!(
1239 "Match pattern: {}",
1240 pat.to_token_stream()
1241 ))),
1242 }
1243 }
1244
1245 fn transpile_stmt_inline(&self, stmt: &Stmt) -> Result<String> {
1247 match stmt {
1248 Stmt::Local(local) => {
1249 let var_name = match &local.pat {
1250 Pat::Ident(ident) => ident.ident.to_string(),
1251 Pat::Type(pat_type) => {
1252 if let Pat::Ident(ident) = pat_type.pat.as_ref() {
1253 ident.ident.to_string()
1254 } else {
1255 return Err(TranspileError::Unsupported(
1256 "Complex pattern in let binding".into(),
1257 ));
1258 }
1259 }
1260 _ => {
1261 return Err(TranspileError::Unsupported(
1262 "Complex pattern in let binding".into(),
1263 ))
1264 }
1265 };
1266
1267 if let Some(init) = &local.init {
1268 let expr_str = self.transpile_expr(&init.expr)?;
1269 let type_str = self.infer_wgsl_type(&init.expr);
1270 Ok(format!("var {var_name}: {type_str} = {expr_str};"))
1271 } else {
1272 Ok(format!("var {var_name}: f32;"))
1273 }
1274 }
1275 Stmt::Expr(expr, semi) => {
1276 let expr_str = self.transpile_expr(expr)?;
1277 if semi.is_some() {
1278 Ok(format!("{expr_str};"))
1279 } else {
1280 Ok(format!("return {expr_str};"))
1281 }
1282 }
1283 _ => Err(TranspileError::Unsupported(
1284 "Unsupported statement in match arm".into(),
1285 )),
1286 }
1287 }
1288
1289 fn try_parse_shared_declaration(
1293 &self,
1294 local: &syn::Local,
1295 var_name: &str,
1296 ) -> Result<Option<SharedMemoryDecl>> {
1297 if let Pat::Type(pat_type) = &local.pat {
1298 let type_str = pat_type.ty.to_token_stream().to_string();
1299 return self.parse_shared_type(&type_str, var_name);
1300 }
1301
1302 if let Some(init) = &local.init {
1303 if let Expr::Call(call) = init.expr.as_ref() {
1304 if let Expr::Path(path) = call.func.as_ref() {
1305 let path_str = path.to_token_stream().to_string();
1306 return self.parse_shared_type(&path_str, var_name);
1307 }
1308 }
1309 }
1310
1311 Ok(None)
1312 }
1313
1314 fn parse_shared_type(
1316 &self,
1317 type_str: &str,
1318 var_name: &str,
1319 ) -> Result<Option<SharedMemoryDecl>> {
1320 let type_str = type_str
1321 .replace(" :: ", "::")
1322 .replace(" ::", "::")
1323 .replace(":: ", "::");
1324
1325 if type_str.contains("SharedTile") {
1326 if let Some(start) = type_str.find('<') {
1327 if let Some(end) = type_str.rfind('>') {
1328 let params = &type_str[start + 1..end];
1329 let parts: Vec<&str> = params.split(',').map(|s| s.trim()).collect();
1330
1331 if parts.len() >= 3 {
1332 let rust_type = parts[0];
1333 let width: u32 = parts[1].parse().map_err(|_| {
1334 TranspileError::Unsupported("Invalid SharedTile width".into())
1335 })?;
1336 let height: u32 = parts[2].parse().map_err(|_| {
1337 TranspileError::Unsupported("Invalid SharedTile height".into())
1338 })?;
1339
1340 let wgsl_type = rust_type_to_wgsl(rust_type);
1341 return Ok(Some(SharedMemoryDecl::new_2d(
1342 var_name, wgsl_type, width, height,
1343 )));
1344 }
1345 }
1346 }
1347 }
1348
1349 if type_str.contains("SharedArray") {
1350 if let Some(start) = type_str.find('<') {
1351 if let Some(end) = type_str.rfind('>') {
1352 let params = &type_str[start + 1..end];
1353 let parts: Vec<&str> = params.split(',').map(|s| s.trim()).collect();
1354
1355 if parts.len() >= 2 {
1356 let rust_type = parts[0];
1357 let size: u32 = parts[1].parse().map_err(|_| {
1358 TranspileError::Unsupported("Invalid SharedArray size".into())
1359 })?;
1360
1361 let wgsl_type = rust_type_to_wgsl(rust_type);
1362 return Ok(Some(SharedMemoryDecl::new_1d(var_name, wgsl_type, size)));
1363 }
1364 }
1365 }
1366 }
1367
1368 Ok(None)
1369 }
1370
1371 fn try_transpile_shared_method_call(
1373 &self,
1374 receiver: &str,
1375 method_name: &str,
1376 args: &syn::punctuated::Punctuated<Expr, syn::token::Comma>,
1377 ) -> Option<Result<String>> {
1378 let shared_info = self.shared_vars.get(receiver)?;
1379
1380 match method_name {
1381 "get" => {
1382 if shared_info.is_tile {
1383 if args.len() >= 2 {
1384 let x = self.transpile_expr(&args[0]).ok()?;
1385 let y = self.transpile_expr(&args[1]).ok()?;
1386 Some(Ok(format!("{}[{}][{}]", receiver, y, x)))
1387 } else {
1388 Some(Err(TranspileError::Unsupported(
1389 "SharedTile.get requires x and y arguments".into(),
1390 )))
1391 }
1392 } else if !args.is_empty() {
1393 let idx = self.transpile_expr(&args[0]).ok()?;
1394 Some(Ok(format!("{}[{}]", receiver, idx)))
1395 } else {
1396 Some(Err(TranspileError::Unsupported(
1397 "SharedArray.get requires index argument".into(),
1398 )))
1399 }
1400 }
1401 "set" => {
1402 if shared_info.is_tile {
1403 if args.len() >= 3 {
1404 let x = self.transpile_expr(&args[0]).ok()?;
1405 let y = self.transpile_expr(&args[1]).ok()?;
1406 let val = self.transpile_expr(&args[2]).ok()?;
1407 Some(Ok(format!("{}[{}][{}] = {}", receiver, y, x, val)))
1408 } else {
1409 Some(Err(TranspileError::Unsupported(
1410 "SharedTile.set requires x, y, and value arguments".into(),
1411 )))
1412 }
1413 } else if args.len() >= 2 {
1414 let idx = self.transpile_expr(&args[0]).ok()?;
1415 let val = self.transpile_expr(&args[1]).ok()?;
1416 Some(Ok(format!("{}[{}] = {}", receiver, idx, val)))
1417 } else {
1418 Some(Err(TranspileError::Unsupported(
1419 "SharedArray.set requires index and value arguments".into(),
1420 )))
1421 }
1422 }
1423 "width" if shared_info.is_tile => Some(Ok(shared_info.dimensions[1].to_string())),
1424 "height" if shared_info.is_tile => Some(Ok(shared_info.dimensions[0].to_string())),
1425 "size" => {
1426 let total: usize = shared_info.dimensions.iter().product();
1427 Some(Ok(total.to_string()))
1428 }
1429 _ => None,
1430 }
1431 }
1432
1433 fn infer_wgsl_type(&self, expr: &Expr) -> &'static str {
1435 match expr {
1436 Expr::Lit(lit) => match &lit.lit {
1437 Lit::Float(_) => "f32",
1438 Lit::Int(i) => {
1439 let s = i.to_string();
1440 if s.ends_with("u32") || s.ends_with("usize") {
1441 "u32"
1442 } else {
1443 "i32"
1444 }
1445 }
1446 Lit::Bool(_) => "bool",
1447 _ => "f32",
1448 },
1449 Expr::Binary(bin) => {
1450 let left_type = self.infer_wgsl_type(&bin.left);
1451 let right_type = self.infer_wgsl_type(&bin.right);
1452 if left_type == "i32" && right_type == "i32" {
1453 "i32"
1454 } else if left_type == "u32" && right_type == "u32" {
1455 "u32"
1456 } else {
1457 "f32"
1458 }
1459 }
1460 Expr::Call(call) => {
1461 if let Ok(func) = self.transpile_expr(&call.func) {
1462 if let Some(
1463 WgslIntrinsic::LocalInvocationIdX
1464 | WgslIntrinsic::WorkgroupIdX
1465 | WgslIntrinsic::GlobalInvocationIdX,
1466 ) = self.intrinsics.lookup(&func)
1467 {
1468 return "i32";
1469 }
1470 }
1471 "f32"
1472 }
1473 Expr::Cast(cast) => {
1474 if let Ok(wgsl_type) = self.type_mapper.map_type(&cast.ty) {
1475 match wgsl_type {
1476 WgslType::I32 => return "i32",
1477 WgslType::U32 => return "u32",
1478 WgslType::F32 => return "f32",
1479 WgslType::Bool => return "bool",
1480 _ => {}
1481 }
1482 }
1483 "f32"
1484 }
1485 _ => "f32",
1486 }
1487 }
1488}
1489
1490fn extract_loop_var(pat: &Pat) -> Option<String> {
1492 match pat {
1493 Pat::Ident(ident) => Some(ident.ident.to_string()),
1494 _ => None,
1495 }
1496}
1497
1498fn rust_type_to_wgsl(rust_type: &str) -> WgslType {
1500 match rust_type {
1501 "f32" => WgslType::F32,
1502 "f64" => WgslType::F32, "i32" => WgslType::I32,
1504 "u32" => WgslType::U32,
1505 "bool" => WgslType::Bool,
1506 _ => WgslType::F32,
1507 }
1508}
1509
1510pub fn transpile_function(func: &ItemFn) -> Result<String> {
1512 let mut transpiler = WgslTranspiler::new_generic();
1513
1514 let name = func.sig.ident.to_string();
1515
1516 let mut params = Vec::new();
1517 for param in &func.sig.inputs {
1518 if let FnArg::Typed(pat_type) = param {
1519 let param_name = match pat_type.pat.as_ref() {
1520 Pat::Ident(ident) => ident.ident.to_string(),
1521 _ => continue,
1522 };
1523
1524 let wgsl_type = transpiler
1525 .type_mapper
1526 .map_type(&pat_type.ty)
1527 .map_err(TranspileError::Type)?;
1528 params.push(format!("{}: {}", param_name, wgsl_type.to_wgsl()));
1529 }
1530 }
1531
1532 let return_type = match &func.sig.output {
1533 ReturnType::Default => "".to_string(),
1534 ReturnType::Type(_, ty) => {
1535 let wgsl_type = transpiler
1536 .type_mapper
1537 .map_type(ty)
1538 .map_err(TranspileError::Type)?;
1539 format!(" -> {}", wgsl_type.to_wgsl())
1540 }
1541 };
1542
1543 let body = transpiler.transpile_block(&func.block)?;
1544
1545 Ok(format!(
1546 "fn {name}({}){return_type} {{\n{body}}}\n",
1547 params.join(", ")
1548 ))
1549}
1550
1551impl RangeInfo {
1554 pub fn from_range<F>(range: &syn::ExprRange, transpile: F) -> Self
1556 where
1557 F: Fn(&Expr) -> Result<String>,
1558 {
1559 let start = range.start.as_ref().and_then(|e| transpile(e).ok());
1560 let end = range.end.as_ref().and_then(|e| transpile(e).ok());
1561 let inclusive = matches!(range.limits, syn::RangeLimits::Closed(_));
1562
1563 Self::new(start, end, inclusive)
1564 }
1565
1566 pub fn to_wgsl_for_header(&self, var_name: &str) -> String {
1568 let start = self.start_or_default();
1569 let end = self.end.as_deref().unwrap_or("/* unbounded */");
1570 let op = if self.inclusive { "<=" } else { "<" };
1571
1572 format!(
1573 "for (var {var_name}: i32 = {start}; {var_name} {op} {end}; {var_name} = {var_name} + 1)"
1574 )
1575 }
1576}
1577
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    /// Transpiles `expr` with a freshly constructed generic transpiler and returns the WGSL.
    fn wgsl_for(expr: Expr) -> String {
        let transpiler = WgslTranspiler::new_generic();
        transpiler.transpile_expr(&expr).unwrap()
    }

    #[test]
    fn test_simple_arithmetic() {
        let wgsl = wgsl_for(parse_quote!(a + b * 2.0));
        assert_eq!(wgsl, "a + b * 2.0");
    }

    #[test]
    fn test_let_binding() {
        let stmt: Stmt = parse_quote!(let x = a + b;);
        let mut transpiler = WgslTranspiler::new_generic();
        let wgsl = transpiler.transpile_stmt(&stmt).unwrap();
        assert!(wgsl.contains("var x: f32 = a + b;"));
    }

    #[test]
    fn test_array_index() {
        let wgsl = wgsl_for(parse_quote!(data[idx]));
        assert_eq!(wgsl, "data[idx]");
    }

    #[test]
    fn test_intrinsic_thread_idx() {
        let wgsl = wgsl_for(parse_quote!(thread_idx_x()));
        assert!(wgsl.contains("local_invocation_id.x"));
    }

    #[test]
    fn test_for_loop() {
        let wgsl = wgsl_for(parse_quote! {
            for i in 0..n {
                data[i] = 0.0;
            }
        });
        assert!(wgsl.contains("for (var i: i32 = 0; i < n; i = i + 1)"));
    }

    #[test]
    fn test_while_loop() {
        let wgsl = wgsl_for(parse_quote! {
            while x < 10 {
                x += 1;
            }
        });
        assert!(wgsl.contains("while (x < 10)"));
    }

    #[test]
    fn test_early_return() {
        let wgsl = wgsl_for(parse_quote! {
            if idx >= n { return; }
        });
        for expected in ["if (idx >= n)", "return;"] {
            assert!(wgsl.contains(expected));
        }
    }

    #[test]
    fn test_sync_threads() {
        let wgsl = wgsl_for(parse_quote!(sync_threads()));
        assert_eq!(wgsl, "workgroupBarrier()");
    }
}