Skip to main content

ringkernel_wgpu_codegen/
transpiler.rs

1//! Core Rust-to-WGSL transpiler.
2//!
3//! This module handles the translation of Rust AST to WGSL code.
4
5use crate::bindings::{generate_bindings, AccessMode, BindingLayout};
6use crate::handler::WgslContextMethod;
7use crate::intrinsics::{IntrinsicRegistry, WgslIntrinsic};
8use crate::loops::RangeInfo;
9use crate::ring_kernel::RingKernelConfig;
10use crate::shared::{SharedMemoryConfig, SharedMemoryDecl};
11use crate::stencil::StencilConfig;
12use crate::types::{
13    is_grid_pos_type, is_mutable_reference, is_ring_context_type, TypeMapper, WgslType,
14};
15use crate::u64_workarounds::U64Helpers;
16use crate::validation::ValidationMode;
17use crate::{Result, TranspileError};
18use quote::ToTokens;
19use std::cell::Cell;
20use std::collections::HashMap;
21use syn::{
22    BinOp, Expr, ExprAssign, ExprBinary, ExprBreak, ExprCall, ExprCast, ExprContinue, ExprForLoop,
23    ExprIf, ExprIndex, ExprLit, ExprLoop, ExprMatch, ExprMethodCall, ExprParen, ExprPath,
24    ExprReturn, ExprStruct, ExprUnary, ExprWhile, FnArg, ItemFn, Lit, Pat, ReturnType, Stmt, UnOp,
25};
26
/// WGSL code transpiler.
///
/// Holds the configuration chosen by one of the `new_*` constructors together
/// with the state accumulated while walking a kernel function: recognised
/// context variables, shared-memory declarations, buffer bindings and the
/// extensions the generated code requires.
pub struct WgslTranspiler {
    /// Stencil configuration (if generating a stencil kernel).
    /// `transpile_stencil` fails with `Unsupported` when this is `None`.
    config: Option<StencilConfig>,
    /// Ring kernel configuration (if generating a ring kernel).
    // Stored by `new_ring_kernel`; nothing in the visible code reads it back,
    // hence the lint allowance.
    #[allow(dead_code)]
    ring_config: Option<RingKernelConfig>,
    /// Type mapper used to translate Rust parameter types into WGSL types.
    type_mapper: TypeMapper,
    /// Intrinsic registry consulted for free-function and method names.
    intrinsics: IntrinsicRegistry,
    /// Variables known to be the GridPos context.
    grid_pos_vars: Vec<String>,
    /// Variables known to be RingContext references.
    context_vars: Vec<String>,
    /// Current indentation level (each level is rendered as four spaces).
    indent: usize,
    /// Validation mode for loop handling.
    validation_mode: ValidationMode,
    /// Shared memory configuration; collects declarations found in `let` bindings.
    shared_memory: SharedMemoryConfig,
    /// Variables that are SharedTile or SharedArray types, keyed by variable name.
    pub shared_vars: HashMap<String, SharedVarInfo>,
    /// Whether we're in ring kernel mode (enables context method inlining).
    ring_kernel_mode: bool,
    /// Whether to include u64 helper functions.
    needs_u64_helpers: bool,
    /// Whether subgroup operations are used (need extension).
    /// Uses Cell for interior mutability during transpilation, because the
    /// flag is set from expression-level methods that only hold `&self`.
    needs_subgroup_extension: Cell<bool>,
    /// Workgroup size for generic kernels, as `(x, y, z)`.
    workgroup_size: (u32, u32, u32),
    /// Collected buffer bindings, in parameter order.
    bindings: Vec<BindingLayout>,
}
62
/// Information about a shared memory variable.
///
/// One entry is recorded per `let` binding that is recognised as a shared
/// memory declaration while transpiling a function body.
#[derive(Debug, Clone)]
pub struct SharedVarInfo {
    /// Variable name.
    pub name: String,
    /// Whether it's a 2D tile (true) or 1D array (false).
    /// Set to `true` exactly when two dimensions were declared.
    pub is_tile: bool,
    /// Dimensions: [size] for 1D, [height, width] for 2D.
    pub dimensions: Vec<usize>,
    /// Element type (WGSL type string).
    pub element_type: String,
}
75
76impl WgslTranspiler {
77    /// Create a new transpiler with stencil configuration.
78    pub fn new_stencil(config: StencilConfig) -> Self {
79        Self {
80            config: Some(config),
81            ring_config: None,
82            type_mapper: TypeMapper::new(),
83            intrinsics: IntrinsicRegistry::new(),
84            grid_pos_vars: Vec::new(),
85            context_vars: Vec::new(),
86            indent: 1,
87            validation_mode: ValidationMode::Stencil,
88            shared_memory: SharedMemoryConfig::new(),
89            shared_vars: HashMap::new(),
90            ring_kernel_mode: false,
91            needs_u64_helpers: false,
92            needs_subgroup_extension: Cell::new(false),
93            workgroup_size: (256, 1, 1),
94            bindings: Vec::new(),
95        }
96    }
97
98    /// Create a new transpiler without stencil configuration.
99    pub fn new_generic() -> Self {
100        Self {
101            config: None,
102            ring_config: None,
103            type_mapper: TypeMapper::new(),
104            intrinsics: IntrinsicRegistry::new(),
105            grid_pos_vars: Vec::new(),
106            context_vars: Vec::new(),
107            indent: 1,
108            validation_mode: ValidationMode::Generic,
109            shared_memory: SharedMemoryConfig::new(),
110            shared_vars: HashMap::new(),
111            ring_kernel_mode: false,
112            needs_u64_helpers: false,
113            needs_subgroup_extension: Cell::new(false),
114            workgroup_size: (256, 1, 1),
115            bindings: Vec::new(),
116        }
117    }
118
119    /// Create a new transpiler for ring kernel generation.
120    pub fn new_ring_kernel(config: RingKernelConfig) -> Self {
121        Self {
122            config: None,
123            ring_config: Some(config.clone()),
124            type_mapper: TypeMapper::new(),
125            intrinsics: IntrinsicRegistry::new(),
126            grid_pos_vars: Vec::new(),
127            context_vars: Vec::new(),
128            indent: 2,
129            validation_mode: ValidationMode::Generic,
130            shared_memory: SharedMemoryConfig::new(),
131            shared_vars: HashMap::new(),
132            ring_kernel_mode: true,
133            needs_u64_helpers: true, // Ring kernels typically need 64-bit counters
134            needs_subgroup_extension: Cell::new(false),
135            workgroup_size: (config.workgroup_size, 1, 1),
136            bindings: Vec::new(),
137        }
138    }
139
140    /// Set the workgroup size for generic kernels.
141    pub fn with_workgroup_size(mut self, x: u32, y: u32, z: u32) -> Self {
142        self.workgroup_size = (x, y, z);
143        self
144    }
145
146    /// Get current indentation string.
147    fn indent_str(&self) -> String {
148        "    ".repeat(self.indent)
149    }
150
151    /// Transpile a stencil kernel function.
152    pub fn transpile_stencil(&mut self, func: &ItemFn) -> Result<String> {
153        let config = self
154            .config
155            .as_ref()
156            .ok_or_else(|| TranspileError::Unsupported("No stencil config provided".into()))?
157            .clone();
158
159        // Identify GridPos parameters
160        for param in &func.sig.inputs {
161            if let FnArg::Typed(pat_type) = param {
162                if is_grid_pos_type(&pat_type.ty) {
163                    if let Pat::Ident(ident) = pat_type.pat.as_ref() {
164                        self.grid_pos_vars.push(ident.ident.to_string());
165                    }
166                }
167            }
168        }
169
170        // Generate bindings from parameters
171        self.collect_bindings(func)?;
172
173        let mut output = String::new();
174
175        // Generate buffer bindings
176        output.push_str(&generate_bindings(&self.bindings));
177        output.push_str("\n\n");
178
179        // Generate shared memory declarations
180        if config.use_shared_memory {
181            let buffer_width = config.buffer_width();
182            let buffer_height = config.buffer_height();
183            output.push_str(&format!(
184                "var<workgroup> tile: array<array<f32, {}>, {}>;\n\n",
185                buffer_width, buffer_height
186            ));
187        }
188
189        // Generate kernel signature with builtins
190        output.push_str("@compute ");
191        output.push_str(&config.workgroup_size_annotation());
192        output.push('\n');
193        output.push_str(&format!("fn {}(\n", func.sig.ident));
194        output.push_str("    @builtin(local_invocation_id) local_id: vec3<u32>,\n");
195        output.push_str("    @builtin(workgroup_id) workgroup_id: vec3<u32>,\n");
196        output.push_str("    @builtin(global_invocation_id) global_id: vec3<u32>\n");
197        output.push_str(") {\n");
198
199        // Generate preamble
200        output.push_str(&self.generate_stencil_preamble(&config));
201
202        // Generate function body
203        let body = self.transpile_block(&func.block)?;
204        output.push_str(&body);
205
206        output.push_str("}\n");
207
208        Ok(output)
209    }
210
211    /// Generate stencil kernel preamble.
212    fn generate_stencil_preamble(&self, config: &StencilConfig) -> String {
213        let buffer_width = config.buffer_width();
214        let mut preamble = String::new();
215
216        preamble.push_str("    // Thread indices\n");
217        preamble.push_str("    let lx = local_id.x;\n");
218        preamble.push_str("    let ly = local_id.y;\n");
219        preamble.push_str(&format!("    let buffer_width = {}u;\n", buffer_width));
220        preamble.push('\n');
221
222        // Bounds check
223        preamble.push_str(&format!(
224            "    if (lx >= {}u || ly >= {}u) {{ return; }}\n\n",
225            config.tile_width, config.tile_height
226        ));
227
228        // Calculate buffer index with halo offset
229        preamble.push_str(&format!(
230            "    let idx = (ly + {}u) * buffer_width + (lx + {}u);\n\n",
231            config.halo, config.halo
232        ));
233
234        preamble
235    }
236
237    /// Transpile a generic (non-stencil) kernel function.
238    pub fn transpile_global_kernel(&mut self, func: &ItemFn) -> Result<String> {
239        // Collect bindings from parameters
240        self.collect_bindings(func)?;
241
242        // Generate function body first to detect subgroup usage
243        let body = self.transpile_block(&func.block)?;
244
245        let mut output = String::new();
246
247        // Add extensions if subgroup operations were used
248        if self.needs_subgroup_extension.get() {
249            output.push_str("enable chromium_experimental_subgroups;\n\n");
250        }
251
252        // Generate buffer bindings
253        output.push_str(&generate_bindings(&self.bindings));
254        output.push_str("\n\n");
255
256        // Generate shared memory declarations
257        if !self.shared_memory.declarations.is_empty() {
258            output.push_str(&self.shared_memory.to_wgsl());
259            output.push_str("\n\n");
260        }
261
262        // Generate kernel signature with subgroup builtins if needed
263        output.push_str("@compute ");
264        output.push_str(&format!(
265            "@workgroup_size({}, {}, {})\n",
266            self.workgroup_size.0, self.workgroup_size.1, self.workgroup_size.2
267        ));
268        output.push_str(&format!("fn {}(\n", func.sig.ident));
269        output.push_str("    @builtin(local_invocation_id) local_invocation_id: vec3<u32>,\n");
270        output.push_str("    @builtin(workgroup_id) workgroup_id: vec3<u32>,\n");
271        output.push_str("    @builtin(global_invocation_id) global_invocation_id: vec3<u32>,\n");
272        output.push_str("    @builtin(num_workgroups) num_workgroups: vec3<u32>");
273
274        // Add subgroup builtins if needed
275        if self.needs_subgroup_extension.get() {
276            output
277                .push_str(",\n    @builtin(subgroup_invocation_id) subgroup_invocation_id: u32,\n");
278            output.push_str("    @builtin(subgroup_size) subgroup_size: u32\n");
279        } else {
280            output.push('\n');
281        }
282        output.push_str(") {\n");
283
284        // Add the already-transpiled body
285        output.push_str(&body);
286
287        output.push_str("}\n");
288
289        Ok(output)
290    }
291
292    /// Transpile a handler function into a ring kernel.
293    pub fn transpile_ring_kernel(
294        &mut self,
295        handler: &ItemFn,
296        config: &RingKernelConfig,
297    ) -> Result<String> {
298        // Track context variables for method inlining
299        for param in &handler.sig.inputs {
300            if let FnArg::Typed(pat_type) = param {
301                if is_ring_context_type(&pat_type.ty) {
302                    if let Pat::Ident(ident) = pat_type.pat.as_ref() {
303                        self.context_vars.push(ident.ident.to_string());
304                    }
305                }
306            }
307        }
308
309        self.ring_kernel_mode = true;
310        self.needs_u64_helpers = true;
311
312        let mut output = String::new();
313
314        // Generate 64-bit helper functions
315        output.push_str(&U64Helpers::generate_all());
316        output.push_str("\n\n");
317
318        // Generate ControlBlock struct
319        output.push_str(&crate::ring_kernel::generate_control_block_struct(config));
320        output.push_str("\n\n");
321
322        // Generate bindings
323        output.push_str(crate::ring_kernel::generate_ring_kernel_bindings());
324        output.push_str("\n\n");
325
326        // Generate kernel
327        output.push_str("@compute ");
328        output.push_str(&config.workgroup_size_annotation());
329        output.push('\n');
330        output.push_str(&format!("fn ring_kernel_{}(\n", config.name));
331        output.push_str("    @builtin(local_invocation_id) local_invocation_id: vec3<u32>,\n");
332        output.push_str("    @builtin(workgroup_id) workgroup_id: vec3<u32>,\n");
333        output.push_str("    @builtin(global_invocation_id) global_invocation_id: vec3<u32>\n");
334        output.push_str(") {\n");
335
336        // Generate preamble (activation/termination checks)
337        output.push_str(crate::ring_kernel::generate_ring_kernel_preamble());
338        output.push_str("\n\n");
339
340        // Generate handler body
341        output.push_str("    // === USER HANDLER CODE ===\n");
342        self.indent = 1;
343        let handler_body = self.transpile_block(&handler.block)?;
344        output.push_str(&handler_body);
345        output.push_str("    // === END HANDLER CODE ===\n\n");
346
347        // Update message counter
348        output.push_str("    // Update message counter\n");
349        output.push_str(
350            "    atomic_inc_u64(&control.messages_processed_lo, &control.messages_processed_hi);\n",
351        );
352
353        output.push_str("}\n");
354
355        Ok(output)
356    }
357
358    /// Collect buffer bindings from function parameters.
359    fn collect_bindings(&mut self, func: &ItemFn) -> Result<()> {
360        let mut binding_idx = 0u32;
361
362        for param in &func.sig.inputs {
363            if let FnArg::Typed(pat_type) = param {
364                // Skip GridPos and RingContext parameters
365                if is_grid_pos_type(&pat_type.ty) || is_ring_context_type(&pat_type.ty) {
366                    continue;
367                }
368
369                let param_name = match pat_type.pat.as_ref() {
370                    Pat::Ident(ident) => ident.ident.to_string(),
371                    _ => continue,
372                };
373
374                let wgsl_type = self
375                    .type_mapper
376                    .map_type(&pat_type.ty)
377                    .map_err(TranspileError::Type)?;
378                let is_mutable = is_mutable_reference(&pat_type.ty);
379
380                // Determine if this is a buffer or scalar
381                match &wgsl_type {
382                    WgslType::Ptr { inner, .. } => {
383                        let access = if is_mutable {
384                            AccessMode::ReadWrite
385                        } else {
386                            AccessMode::Read
387                        };
388                        self.bindings.push(BindingLayout::new(
389                            0,
390                            binding_idx,
391                            &param_name,
392                            WgslType::Array {
393                                element: inner.clone(),
394                                size: None,
395                            },
396                            access,
397                        ));
398                        binding_idx += 1;
399                    }
400                    WgslType::Array { .. } => {
401                        let access = if is_mutable {
402                            AccessMode::ReadWrite
403                        } else {
404                            AccessMode::Read
405                        };
406                        self.bindings.push(BindingLayout::new(
407                            0,
408                            binding_idx,
409                            &param_name,
410                            wgsl_type.clone(),
411                            access,
412                        ));
413                        binding_idx += 1;
414                    }
415                    // Scalars are typically passed via uniforms
416                    _ => {
417                        self.bindings.push(BindingLayout::uniform(
418                            binding_idx,
419                            &param_name,
420                            wgsl_type.clone(),
421                        ));
422                        binding_idx += 1;
423                    }
424                }
425            }
426        }
427
428        Ok(())
429    }
430
431    /// Transpile a block of statements.
432    fn transpile_block(&mut self, block: &syn::Block) -> Result<String> {
433        let mut output = String::new();
434
435        for stmt in &block.stmts {
436            let stmt_str = self.transpile_stmt(stmt)?;
437            if !stmt_str.is_empty() {
438                output.push_str(&stmt_str);
439            }
440        }
441
442        Ok(output)
443    }
444
445    /// Transpile a single statement.
446    fn transpile_stmt(&mut self, stmt: &Stmt) -> Result<String> {
447        match stmt {
448            Stmt::Local(local) => {
449                let indent = self.indent_str();
450
451                // Get variable name
452                let var_name = match &local.pat {
453                    Pat::Ident(ident) => ident.ident.to_string(),
454                    Pat::Type(pat_type) => {
455                        if let Pat::Ident(ident) = pat_type.pat.as_ref() {
456                            ident.ident.to_string()
457                        } else {
458                            return Err(TranspileError::Unsupported(
459                                "Complex pattern in let binding".into(),
460                            ));
461                        }
462                    }
463                    _ => {
464                        return Err(TranspileError::Unsupported(
465                            "Complex pattern in let binding".into(),
466                        ))
467                    }
468                };
469
470                // Check for SharedTile or SharedArray type annotation
471                if let Some(shared_decl) = self.try_parse_shared_declaration(local, &var_name)? {
472                    self.shared_vars.insert(
473                        var_name.clone(),
474                        SharedVarInfo {
475                            name: var_name.clone(),
476                            is_tile: shared_decl.dimensions.len() == 2,
477                            dimensions: shared_decl
478                                .dimensions
479                                .iter()
480                                .map(|&d| d as usize)
481                                .collect(),
482                            element_type: shared_decl.element_type.to_wgsl(),
483                        },
484                    );
485                    self.shared_memory.add(shared_decl.clone());
486                    return Ok(format!(
487                        "{indent}// shared memory: {} declared at module scope\n",
488                        var_name
489                    ));
490                }
491
492                // Get initializer
493                if let Some(init) = &local.init {
494                    let expr_str = self.transpile_expr(&init.expr)?;
495                    let type_str = self.infer_wgsl_type(&init.expr);
496
497                    Ok(format!(
498                        "{indent}var {var_name}: {type_str} = {expr_str};\n"
499                    ))
500                } else {
501                    Ok(format!("{indent}var {var_name}: f32;\n"))
502                }
503            }
504            Stmt::Expr(expr, semi) => {
505                let indent = self.indent_str();
506
507                // Handle early return pattern in if statements
508                if let Expr::If(if_expr) = expr {
509                    if let Some(Stmt::Expr(Expr::Return(_), _)) = if_expr.then_branch.stmts.first()
510                    {
511                        if if_expr.then_branch.stmts.len() == 1 && if_expr.else_branch.is_none() {
512                            let expr_str = self.transpile_expr(expr)?;
513                            return Ok(format!("{indent}{expr_str}\n"));
514                        }
515                    }
516                }
517
518                let expr_str = self.transpile_expr(expr)?;
519
520                if semi.is_some()
521                    || matches!(expr, Expr::Return(_))
522                    || expr_str.starts_with("return")
523                    || expr_str.starts_with("if (")
524                {
525                    Ok(format!("{indent}{expr_str};\n"))
526                } else {
527                    Ok(format!("{indent}return {expr_str};\n"))
528                }
529            }
530            Stmt::Item(_) => Err(TranspileError::Unsupported("Item in function body".into())),
531            Stmt::Macro(_) => Err(TranspileError::Unsupported("Macro in function body".into())),
532        }
533    }
534
535    /// Transpile an expression.
536    fn transpile_expr(&self, expr: &Expr) -> Result<String> {
537        match expr {
538            Expr::Lit(lit) => self.transpile_lit(lit),
539            Expr::Path(path) => self.transpile_path(path),
540            Expr::Binary(bin) => self.transpile_binary(bin),
541            Expr::Unary(unary) => self.transpile_unary(unary),
542            Expr::Paren(paren) => self.transpile_paren(paren),
543            Expr::Index(index) => self.transpile_index(index),
544            Expr::Call(call) => self.transpile_call(call),
545            Expr::MethodCall(method) => self.transpile_method_call(method),
546            Expr::If(if_expr) => self.transpile_if(if_expr),
547            Expr::Assign(assign) => self.transpile_assign(assign),
548            Expr::Cast(cast) => self.transpile_cast(cast),
549            Expr::Match(match_expr) => self.transpile_match(match_expr),
550            Expr::Block(block) => {
551                if let Some(Stmt::Expr(expr, None)) = block.block.stmts.last() {
552                    self.transpile_expr(expr)
553                } else {
554                    Err(TranspileError::Unsupported(
555                        "Complex block expression".into(),
556                    ))
557                }
558            }
559            Expr::Field(field) => {
560                let base = self.transpile_expr(&field.base)?;
561                let member = match &field.member {
562                    syn::Member::Named(ident) => ident.to_string(),
563                    syn::Member::Unnamed(idx) => idx.index.to_string(),
564                };
565                Ok(format!("{base}.{member}"))
566            }
567            Expr::Return(ret) => self.transpile_return(ret),
568            Expr::ForLoop(for_loop) => self.transpile_for_loop(for_loop),
569            Expr::While(while_loop) => self.transpile_while_loop(while_loop),
570            Expr::Loop(loop_expr) => self.transpile_infinite_loop(loop_expr),
571            Expr::Break(break_expr) => self.transpile_break(break_expr),
572            Expr::Continue(cont_expr) => self.transpile_continue(cont_expr),
573            Expr::Struct(struct_expr) => self.transpile_struct_literal(struct_expr),
574            Expr::Reference(ref_expr) => {
575                // In WGSL, we often need &var for atomics
576                let inner = self.transpile_expr(&ref_expr.expr)?;
577                Ok(format!("&{inner}"))
578            }
579            _ => Err(TranspileError::Unsupported(format!(
580                "Expression type: {}",
581                expr.to_token_stream()
582            ))),
583        }
584    }
585
586    /// Transpile a literal.
587    fn transpile_lit(&self, lit: &ExprLit) -> Result<String> {
588        match &lit.lit {
589            Lit::Float(f) => {
590                let s = f.to_string();
591                // WGSL float literals don't need suffix
592                let num = s.trim_end_matches("f32").trim_end_matches("f64");
593                // Ensure there's a decimal point
594                if num.contains('.') {
595                    Ok(num.to_string())
596                } else {
597                    Ok(format!("{}.0", num))
598                }
599            }
600            Lit::Int(i) => {
601                let s = i.to_string();
602                // Check for unsigned suffix
603                if s.ends_with("u32") || s.ends_with("usize") {
604                    Ok(format!(
605                        "{}u",
606                        s.trim_end_matches("u32").trim_end_matches("usize")
607                    ))
608                } else if s.ends_with("i32") || s.ends_with("isize") {
609                    Ok(format!(
610                        "{}i",
611                        s.trim_end_matches("i32").trim_end_matches("isize")
612                    ))
613                } else {
614                    // Default to signed int
615                    Ok(s)
616                }
617            }
618            Lit::Bool(b) => Ok(if b.value { "true" } else { "false" }.to_string()),
619            _ => Err(TranspileError::Unsupported(format!(
620                "Literal type: {}",
621                lit.to_token_stream()
622            ))),
623        }
624    }
625
626    /// Transpile a path (variable reference).
627    fn transpile_path(&self, path: &ExprPath) -> Result<String> {
628        let segments: Vec<_> = path
629            .path
630            .segments
631            .iter()
632            .map(|s| s.ident.to_string())
633            .collect();
634
635        if segments.len() == 1 {
636            Ok(segments[0].clone())
637        } else {
638            Ok(segments.join("_"))
639        }
640    }
641
642    /// Transpile a binary expression.
643    fn transpile_binary(&self, bin: &ExprBinary) -> Result<String> {
644        let left = self.transpile_expr(&bin.left)?;
645        let right = self.transpile_expr(&bin.right)?;
646
647        let op = match bin.op {
648            BinOp::Add(_) => "+",
649            BinOp::Sub(_) => "-",
650            BinOp::Mul(_) => "*",
651            BinOp::Div(_) => "/",
652            BinOp::Rem(_) => "%",
653            BinOp::And(_) => "&&",
654            BinOp::Or(_) => "||",
655            BinOp::BitXor(_) => "^",
656            BinOp::BitAnd(_) => "&",
657            BinOp::BitOr(_) => "|",
658            BinOp::Shl(_) => "<<",
659            BinOp::Shr(_) => ">>",
660            BinOp::Eq(_) => "==",
661            BinOp::Lt(_) => "<",
662            BinOp::Le(_) => "<=",
663            BinOp::Ne(_) => "!=",
664            BinOp::Ge(_) => ">=",
665            BinOp::Gt(_) => ">",
666            BinOp::AddAssign(_) => "+=",
667            BinOp::SubAssign(_) => "-=",
668            BinOp::MulAssign(_) => "*=",
669            BinOp::DivAssign(_) => "/=",
670            BinOp::RemAssign(_) => "%=",
671            BinOp::BitXorAssign(_) => "^=",
672            BinOp::BitAndAssign(_) => "&=",
673            BinOp::BitOrAssign(_) => "|=",
674            BinOp::ShlAssign(_) => "<<=",
675            BinOp::ShrAssign(_) => ">>=",
676            _ => {
677                return Err(TranspileError::Unsupported(format!(
678                    "Binary operator: {}",
679                    bin.to_token_stream()
680                )))
681            }
682        };
683
684        Ok(format!("{left} {op} {right}"))
685    }
686
687    /// Transpile a unary expression.
688    fn transpile_unary(&self, unary: &ExprUnary) -> Result<String> {
689        let expr = self.transpile_expr(&unary.expr)?;
690
691        let op = match unary.op {
692            UnOp::Neg(_) => "-",
693            UnOp::Not(_) => "!",
694            UnOp::Deref(_) => "*",
695            _ => {
696                return Err(TranspileError::Unsupported(format!(
697                    "Unary operator: {}",
698                    unary.to_token_stream()
699                )))
700            }
701        };
702
703        Ok(format!("{op}({expr})"))
704    }
705
706    /// Transpile a parenthesized expression.
707    fn transpile_paren(&self, paren: &ExprParen) -> Result<String> {
708        let inner = self.transpile_expr(&paren.expr)?;
709        Ok(format!("({inner})"))
710    }
711
712    /// Transpile an index expression.
713    fn transpile_index(&self, index: &ExprIndex) -> Result<String> {
714        let base = self.transpile_expr(&index.expr)?;
715        let idx = self.transpile_expr(&index.index)?;
716        Ok(format!("{base}[{idx}]"))
717    }
718
719    /// Transpile a function call.
720    fn transpile_call(&self, call: &ExprCall) -> Result<String> {
721        let func = self.transpile_expr(&call.func)?;
722
723        // Check for intrinsics
724        if let Some(intrinsic) = self.intrinsics.lookup(&func) {
725            return self.transpile_intrinsic_call(intrinsic, &call.args);
726        }
727
728        // Regular function call
729        let args: Vec<String> = call
730            .args
731            .iter()
732            .map(|a| self.transpile_expr(a))
733            .collect::<Result<_>>()?;
734
735        Ok(format!("{}({})", func, args.join(", ")))
736    }
737
738    /// Transpile an intrinsic call.
739    fn transpile_intrinsic_call(
740        &self,
741        intrinsic: WgslIntrinsic,
742        args: &syn::punctuated::Punctuated<Expr, syn::token::Comma>,
743    ) -> Result<String> {
744        let wgsl_name = intrinsic.to_wgsl();
745
746        // Track if we need subgroup extension
747        if intrinsic.requires_subgroup_extension() {
748            self.needs_subgroup_extension.set(true);
749        }
750
751        // Check for value intrinsics (builtins accessed as variables)
752        match intrinsic {
753            WgslIntrinsic::LocalInvocationIdX
754            | WgslIntrinsic::LocalInvocationIdY
755            | WgslIntrinsic::LocalInvocationIdZ
756            | WgslIntrinsic::WorkgroupIdX
757            | WgslIntrinsic::WorkgroupIdY
758            | WgslIntrinsic::WorkgroupIdZ
759            | WgslIntrinsic::GlobalInvocationIdX
760            | WgslIntrinsic::GlobalInvocationIdY
761            | WgslIntrinsic::GlobalInvocationIdZ
762            | WgslIntrinsic::NumWorkgroupsX
763            | WgslIntrinsic::NumWorkgroupsY
764            | WgslIntrinsic::NumWorkgroupsZ => {
765                // These are variable accesses, not function calls
766                return Ok(format!("i32({})", wgsl_name));
767            }
768            WgslIntrinsic::WorkgroupSizeX => {
769                return Ok(format!("i32({}u)", self.workgroup_size.0));
770            }
771            WgslIntrinsic::WorkgroupSizeY => {
772                return Ok(format!("i32({}u)", self.workgroup_size.1));
773            }
774            WgslIntrinsic::WorkgroupSizeZ => {
775                return Ok(format!("i32({}u)", self.workgroup_size.2));
776            }
777            WgslIntrinsic::WorkgroupBarrier | WgslIntrinsic::StorageBarrier => {
778                // Zero-arg function intrinsics
779                return Ok(wgsl_name.to_string());
780            }
781            WgslIntrinsic::SubgroupInvocationId | WgslIntrinsic::SubgroupSize => {
782                // Subgroup builtins - accessed as variables
783                return Ok(wgsl_name.to_string());
784            }
785            WgslIntrinsic::SubgroupElect => {
786                // Zero-arg subgroup function
787                return Ok(format!("{}()", wgsl_name));
788            }
789            _ => {}
790        }
791
792        // Function intrinsics with arguments
793        let transpiled_args: Vec<String> = args
794            .iter()
795            .map(|a| self.transpile_expr(a))
796            .collect::<Result<_>>()?;
797
798        Ok(format!("{}({})", wgsl_name, transpiled_args.join(", ")))
799    }
800
801    /// Transpile a method call.
802    fn transpile_method_call(&self, method: &ExprMethodCall) -> Result<String> {
803        let receiver = self.transpile_expr(&method.receiver)?;
804        let method_name = method.method.to_string();
805
806        // Check if this is a SharedTile/SharedArray method call
807        if let Some(result) =
808            self.try_transpile_shared_method_call(&receiver, &method_name, &method.args)
809        {
810            return result;
811        }
812
813        // Check if this is a RingContext method call
814        if self.ring_kernel_mode && self.context_vars.contains(&receiver) {
815            return self.transpile_context_method(&method_name, &method.args);
816        }
817
818        // Check if this is a GridPos method call
819        if self.grid_pos_vars.contains(&receiver) {
820            return self.transpile_stencil_intrinsic(&method_name, &method.args);
821        }
822
823        // Check for f32 math methods
824        if let Some(intrinsic) = self.intrinsics.lookup(&method_name) {
825            let wgsl_name = intrinsic.to_wgsl();
826            let args: Vec<String> = std::iter::once(receiver)
827                .chain(
828                    method
829                        .args
830                        .iter()
831                        .map(|a| self.transpile_expr(a).unwrap_or_default()),
832                )
833                .collect();
834
835            return Ok(format!("{}({})", wgsl_name, args.join(", ")));
836        }
837
838        // Regular method call
839        let args: Vec<String> = method
840            .args
841            .iter()
842            .map(|a| self.transpile_expr(a))
843            .collect::<Result<_>>()?;
844
845        Ok(format!("{}.{}({})", receiver, method_name, args.join(", ")))
846    }
847
848    /// Transpile a RingContext method call.
849    fn transpile_context_method(
850        &self,
851        method: &str,
852        args: &syn::punctuated::Punctuated<Expr, syn::token::Comma>,
853    ) -> Result<String> {
854        let ctx_method = WgslContextMethod::from_name(method).ok_or_else(|| {
855            TranspileError::Unsupported(format!("Unknown context method: {}", method))
856        })?;
857
858        let wgsl_args: Vec<String> = args
859            .iter()
860            .map(|a| self.transpile_expr(a).unwrap_or_default())
861            .collect();
862
863        match ctx_method {
864            WgslContextMethod::AtomicAdd
865            | WgslContextMethod::AtomicLoad
866            | WgslContextMethod::AtomicStore => Ok(format!(
867                "{}({})",
868                ctx_method.to_wgsl(),
869                wgsl_args.join(", ")
870            )),
871            _ => Ok(ctx_method.to_wgsl().to_string()),
872        }
873    }
874
875    /// Transpile a stencil intrinsic method call.
876    fn transpile_stencil_intrinsic(
877        &self,
878        method: &str,
879        args: &syn::punctuated::Punctuated<Expr, syn::token::Comma>,
880    ) -> Result<String> {
881        let config = self.config.as_ref().ok_or_else(|| {
882            TranspileError::Unsupported("Stencil intrinsic without config".into())
883        })?;
884
885        let buffer_width = config.buffer_width();
886
887        match method {
888            "idx" => Ok("idx".to_string()),
889            "x" => Ok("i32(lx)".to_string()),
890            "y" => Ok("i32(ly)".to_string()),
891            "north" => {
892                if args.is_empty() {
893                    return Err(TranspileError::Unsupported(
894                        "north() requires buffer argument".into(),
895                    ));
896                }
897                let buffer = self.transpile_expr(&args[0])?;
898                Ok(format!("{buffer}[idx - {buffer_width}u]"))
899            }
900            "south" => {
901                if args.is_empty() {
902                    return Err(TranspileError::Unsupported(
903                        "south() requires buffer argument".into(),
904                    ));
905                }
906                let buffer = self.transpile_expr(&args[0])?;
907                Ok(format!("{buffer}[idx + {buffer_width}u]"))
908            }
909            "east" => {
910                if args.is_empty() {
911                    return Err(TranspileError::Unsupported(
912                        "east() requires buffer argument".into(),
913                    ));
914                }
915                let buffer = self.transpile_expr(&args[0])?;
916                Ok(format!("{buffer}[idx + 1u]"))
917            }
918            "west" => {
919                if args.is_empty() {
920                    return Err(TranspileError::Unsupported(
921                        "west() requires buffer argument".into(),
922                    ));
923                }
924                let buffer = self.transpile_expr(&args[0])?;
925                Ok(format!("{buffer}[idx - 1u]"))
926            }
927            "at" => {
928                if args.len() < 3 {
929                    return Err(TranspileError::Unsupported(
930                        "at() requires buffer, dx, dy arguments".into(),
931                    ));
932                }
933                let buffer = self.transpile_expr(&args[0])?;
934                let dx = self.transpile_expr(&args[1])?;
935                let dy = self.transpile_expr(&args[2])?;
936                Ok(format!(
937                    "{buffer}[idx + u32(({dy}) * i32({buffer_width}u) + ({dx}))]"
938                ))
939            }
940            _ => Err(TranspileError::Unsupported(format!(
941                "Unknown stencil intrinsic: {}",
942                method
943            ))),
944        }
945    }
946
947    /// Transpile an if expression.
948    fn transpile_if(&self, if_expr: &ExprIf) -> Result<String> {
949        let cond = self.transpile_expr(&if_expr.cond)?;
950
951        // Check for early return pattern
952        if let Some(Stmt::Expr(Expr::Return(ret), _)) = if_expr.then_branch.stmts.first() {
953            if if_expr.then_branch.stmts.len() == 1 && if_expr.else_branch.is_none() {
954                if ret.expr.is_none() {
955                    return Ok(format!("if ({cond}) {{ return; }}"));
956                }
957                let ret_val = self.transpile_expr(ret.expr.as_ref().unwrap())?;
958                return Ok(format!("if ({cond}) {{ return {ret_val}; }}"));
959            }
960        }
961
962        // Handle if-else as select when possible
963        if let Some((_, else_branch)) = &if_expr.else_branch {
964            if let (Some(Stmt::Expr(then_expr, None)), Expr::Block(else_block)) =
965                (if_expr.then_branch.stmts.last(), else_branch.as_ref())
966            {
967                if let Some(Stmt::Expr(else_expr, None)) = else_block.block.stmts.last() {
968                    let then_str = self.transpile_expr(then_expr)?;
969                    let else_str = self.transpile_expr(else_expr)?;
970                    return Ok(format!("select({else_str}, {then_str}, {cond})"));
971                }
972            }
973
974            // else if chain
975            if let Expr::If(else_if) = else_branch.as_ref() {
976                let then_body = self.transpile_if_body(&if_expr.then_branch)?;
977                let else_part = self.transpile_if(else_if)?;
978                return Ok(format!("if ({cond}) {{{then_body}}} else {else_part}"));
979            } else if let Expr::Block(else_block) = else_branch.as_ref() {
980                let then_body = self.transpile_if_body(&if_expr.then_branch)?;
981                let else_body = self.transpile_if_body(&else_block.block)?;
982                return Ok(format!("if ({cond}) {{{then_body}}} else {{{else_body}}}"));
983            }
984        }
985
986        // If without else
987        let then_body = self.transpile_if_body(&if_expr.then_branch)?;
988        Ok(format!("if ({cond}) {{{then_body}}}"))
989    }
990
991    /// Transpile the body of an if branch.
992    fn transpile_if_body(&self, block: &syn::Block) -> Result<String> {
993        let mut body = String::new();
994        for stmt in &block.stmts {
995            match stmt {
996                Stmt::Expr(expr, Some(_)) => {
997                    let expr_str = self.transpile_expr(expr)?;
998                    body.push_str(&format!(" {expr_str};"));
999                }
1000                Stmt::Expr(Expr::Return(ret), None) => {
1001                    if let Some(ret_expr) = &ret.expr {
1002                        let expr_str = self.transpile_expr(ret_expr)?;
1003                        body.push_str(&format!(" return {expr_str};"));
1004                    } else {
1005                        body.push_str(" return;");
1006                    }
1007                }
1008                Stmt::Expr(expr, None) => {
1009                    let expr_str = self.transpile_expr(expr)?;
1010                    body.push_str(&format!(" return {expr_str};"));
1011                }
1012                _ => {}
1013            }
1014        }
1015        Ok(body)
1016    }
1017
1018    /// Transpile an assignment expression.
1019    fn transpile_assign(&self, assign: &ExprAssign) -> Result<String> {
1020        let left = self.transpile_expr(&assign.left)?;
1021        let right = self.transpile_expr(&assign.right)?;
1022        Ok(format!("{left} = {right}"))
1023    }
1024
1025    /// Transpile a cast expression.
1026    fn transpile_cast(&self, cast: &ExprCast) -> Result<String> {
1027        let expr = self.transpile_expr(&cast.expr)?;
1028        let wgsl_type = self
1029            .type_mapper
1030            .map_type(&cast.ty)
1031            .map_err(TranspileError::Type)?;
1032        Ok(format!("{}({})", wgsl_type.to_wgsl(), expr))
1033    }
1034
1035    /// Transpile a return expression.
1036    fn transpile_return(&self, ret: &ExprReturn) -> Result<String> {
1037        if let Some(expr) = &ret.expr {
1038            let expr_str = self.transpile_expr(expr)?;
1039            Ok(format!("return {expr_str}"))
1040        } else {
1041            Ok("return".to_string())
1042        }
1043    }
1044
1045    /// Transpile a struct literal.
1046    fn transpile_struct_literal(&self, struct_expr: &ExprStruct) -> Result<String> {
1047        let type_name = struct_expr
1048            .path
1049            .segments
1050            .iter()
1051            .map(|s| s.ident.to_string())
1052            .collect::<Vec<_>>()
1053            .join("_");
1054
1055        let mut fields = Vec::new();
1056        for field in &struct_expr.fields {
1057            let field_name = match &field.member {
1058                syn::Member::Named(ident) => ident.to_string(),
1059                syn::Member::Unnamed(idx) => idx.index.to_string(),
1060            };
1061            let value = self.transpile_expr(&field.expr)?;
1062            fields.push(format!("{}: {}", field_name, value));
1063        }
1064
1065        if struct_expr.rest.is_some() {
1066            return Err(TranspileError::Unsupported(
1067                "Struct update syntax (..base) is not supported in WGSL".into(),
1068            ));
1069        }
1070
1071        Ok(format!("{}({})", type_name, fields.join(", ")))
1072    }
1073
1074    // === Loop Transpilation ===
1075
1076    /// Transpile a for loop.
1077    fn transpile_for_loop(&self, for_loop: &ExprForLoop) -> Result<String> {
1078        if !self.validation_mode.allows_loops() {
1079            return Err(TranspileError::Unsupported(
1080                "Loops are not allowed in stencil kernels".into(),
1081            ));
1082        }
1083
1084        let var_name = extract_loop_var(&for_loop.pat)
1085            .ok_or_else(|| TranspileError::Unsupported("Complex pattern in for loop".into()))?;
1086
1087        let header = match for_loop.expr.as_ref() {
1088            Expr::Range(range) => {
1089                let range_info = RangeInfo::from_range(range, |e| self.transpile_expr(e));
1090                range_info.to_wgsl_for_header(&var_name)
1091            }
1092            _ => {
1093                return Err(TranspileError::Unsupported(
1094                    "Only range expressions are supported in for loops".into(),
1095                ));
1096            }
1097        };
1098
1099        let body = self.transpile_loop_body(&for_loop.body)?;
1100        Ok(format!("{header} {{\n{body}}}"))
1101    }
1102
1103    /// Transpile a while loop.
1104    fn transpile_while_loop(&self, while_loop: &ExprWhile) -> Result<String> {
1105        if !self.validation_mode.allows_loops() {
1106            return Err(TranspileError::Unsupported(
1107                "Loops are not allowed in stencil kernels".into(),
1108            ));
1109        }
1110
1111        let condition = self.transpile_expr(&while_loop.cond)?;
1112        let body = self.transpile_loop_body(&while_loop.body)?;
1113        Ok(format!("while ({condition}) {{\n{body}}}"))
1114    }
1115
1116    /// Transpile an infinite loop.
1117    fn transpile_infinite_loop(&self, loop_expr: &ExprLoop) -> Result<String> {
1118        if !self.validation_mode.allows_loops() {
1119            return Err(TranspileError::Unsupported(
1120                "Loops are not allowed in stencil kernels".into(),
1121            ));
1122        }
1123
1124        let body = self.transpile_loop_body(&loop_expr.body)?;
1125        Ok(format!("loop {{\n{body}}}"))
1126    }
1127
1128    /// Transpile a break expression.
1129    fn transpile_break(&self, break_expr: &ExprBreak) -> Result<String> {
1130        if break_expr.label.is_some() {
1131            return Err(TranspileError::Unsupported(
1132                "Labeled break is not supported in WGSL".into(),
1133            ));
1134        }
1135        if break_expr.expr.is_some() {
1136            return Err(TranspileError::Unsupported(
1137                "Break with value is not supported in WGSL".into(),
1138            ));
1139        }
1140        Ok("break".to_string())
1141    }
1142
1143    /// Transpile a continue expression.
1144    fn transpile_continue(&self, cont_expr: &ExprContinue) -> Result<String> {
1145        if cont_expr.label.is_some() {
1146            return Err(TranspileError::Unsupported(
1147                "Labeled continue is not supported in WGSL".into(),
1148            ));
1149        }
1150        Ok("continue".to_string())
1151    }
1152
1153    /// Transpile a loop body.
1154    fn transpile_loop_body(&self, block: &syn::Block) -> Result<String> {
1155        let mut output = String::new();
1156        let inner_indent = "    ".repeat(self.indent + 1);
1157
1158        for stmt in &block.stmts {
1159            match stmt {
1160                Stmt::Local(local) => {
1161                    let var_name = match &local.pat {
1162                        Pat::Ident(ident) => ident.ident.to_string(),
1163                        Pat::Type(pat_type) => {
1164                            if let Pat::Ident(ident) = pat_type.pat.as_ref() {
1165                                ident.ident.to_string()
1166                            } else {
1167                                return Err(TranspileError::Unsupported(
1168                                    "Complex pattern in let binding".into(),
1169                                ));
1170                            }
1171                        }
1172                        _ => {
1173                            return Err(TranspileError::Unsupported(
1174                                "Complex pattern in let binding".into(),
1175                            ))
1176                        }
1177                    };
1178
1179                    if let Some(init) = &local.init {
1180                        let expr_str = self.transpile_expr(&init.expr)?;
1181                        let type_str = self.infer_wgsl_type(&init.expr);
1182                        output.push_str(&format!(
1183                            "{inner_indent}var {var_name}: {type_str} = {expr_str};\n"
1184                        ));
1185                    } else {
1186                        output.push_str(&format!("{inner_indent}var {var_name}: f32;\n"));
1187                    }
1188                }
1189                Stmt::Expr(expr, _semi) => {
1190                    let expr_str = self.transpile_expr(expr)?;
1191                    output.push_str(&format!("{inner_indent}{expr_str};\n"));
1192                }
1193                _ => {
1194                    return Err(TranspileError::Unsupported(
1195                        "Unsupported statement in loop body".into(),
1196                    ));
1197                }
1198            }
1199        }
1200
1201        let closing_indent = "    ".repeat(self.indent);
1202        output.push_str(&closing_indent);
1203
1204        Ok(output)
1205    }
1206
1207    /// Transpile a match expression to switch.
1208    fn transpile_match(&self, match_expr: &ExprMatch) -> Result<String> {
1209        let scrutinee = self.transpile_expr(&match_expr.expr)?;
1210        let mut output = format!("switch ({scrutinee}) {{\n");
1211
1212        for arm in &match_expr.arms {
1213            let case_label = self.transpile_match_pattern(&arm.pat)?;
1214
1215            if case_label == "default" || case_label.starts_with("/*") {
1216                output.push_str("    default: {\n");
1217            } else {
1218                output.push_str(&format!("    case {case_label}: {{\n"));
1219            }
1220
1221            match arm.body.as_ref() {
1222                Expr::Block(block) => {
1223                    for stmt in &block.block.stmts {
1224                        let stmt_str = self.transpile_stmt_inline(stmt)?;
1225                        output.push_str(&format!("        {stmt_str}\n"));
1226                    }
1227                }
1228                _ => {
1229                    let body = self.transpile_expr(&arm.body)?;
1230                    output.push_str(&format!("        {body};\n"));
1231                }
1232            }
1233
1234            output.push_str("    }\n");
1235        }
1236
1237        output.push('}');
1238        Ok(output)
1239    }
1240
1241    /// Transpile a match pattern.
1242    fn transpile_match_pattern(&self, pat: &Pat) -> Result<String> {
1243        match pat {
1244            Pat::Lit(pat_lit) => match &pat_lit.lit {
1245                Lit::Int(i) => Ok(i.to_string()),
1246                Lit::Bool(b) => Ok(if b.value { "true" } else { "false" }.to_string()),
1247                _ => Err(TranspileError::Unsupported(
1248                    "Non-integer literal in match pattern".into(),
1249                )),
1250            },
1251            Pat::Wild(_) => Ok("default".to_string()),
1252            Pat::Ident(ident) => Ok(format!("/* {} */ default", ident.ident)),
1253            Pat::Or(pat_or) => {
1254                if let Some(first) = pat_or.cases.first() {
1255                    self.transpile_match_pattern(first)
1256                } else {
1257                    Err(TranspileError::Unsupported("Empty or pattern".into()))
1258                }
1259            }
1260            _ => Err(TranspileError::Unsupported(format!(
1261                "Match pattern: {}",
1262                pat.to_token_stream()
1263            ))),
1264        }
1265    }
1266
1267    /// Transpile a statement inline.
1268    fn transpile_stmt_inline(&self, stmt: &Stmt) -> Result<String> {
1269        match stmt {
1270            Stmt::Local(local) => {
1271                let var_name = match &local.pat {
1272                    Pat::Ident(ident) => ident.ident.to_string(),
1273                    Pat::Type(pat_type) => {
1274                        if let Pat::Ident(ident) = pat_type.pat.as_ref() {
1275                            ident.ident.to_string()
1276                        } else {
1277                            return Err(TranspileError::Unsupported(
1278                                "Complex pattern in let binding".into(),
1279                            ));
1280                        }
1281                    }
1282                    _ => {
1283                        return Err(TranspileError::Unsupported(
1284                            "Complex pattern in let binding".into(),
1285                        ))
1286                    }
1287                };
1288
1289                if let Some(init) = &local.init {
1290                    let expr_str = self.transpile_expr(&init.expr)?;
1291                    let type_str = self.infer_wgsl_type(&init.expr);
1292                    Ok(format!("var {var_name}: {type_str} = {expr_str};"))
1293                } else {
1294                    Ok(format!("var {var_name}: f32;"))
1295                }
1296            }
1297            Stmt::Expr(expr, semi) => {
1298                let expr_str = self.transpile_expr(expr)?;
1299                if semi.is_some() {
1300                    Ok(format!("{expr_str};"))
1301                } else {
1302                    Ok(format!("return {expr_str};"))
1303                }
1304            }
1305            _ => Err(TranspileError::Unsupported(
1306                "Unsupported statement in match arm".into(),
1307            )),
1308        }
1309    }
1310
1311    // === Shared Memory Support ===
1312
1313    /// Try to parse a shared memory declaration.
1314    fn try_parse_shared_declaration(
1315        &self,
1316        local: &syn::Local,
1317        var_name: &str,
1318    ) -> Result<Option<SharedMemoryDecl>> {
1319        if let Pat::Type(pat_type) = &local.pat {
1320            let type_str = pat_type.ty.to_token_stream().to_string();
1321            return self.parse_shared_type(&type_str, var_name);
1322        }
1323
1324        if let Some(init) = &local.init {
1325            if let Expr::Call(call) = init.expr.as_ref() {
1326                if let Expr::Path(path) = call.func.as_ref() {
1327                    let path_str = path.to_token_stream().to_string();
1328                    return self.parse_shared_type(&path_str, var_name);
1329                }
1330            }
1331        }
1332
1333        Ok(None)
1334    }
1335
1336    /// Parse a type string for shared memory info.
1337    fn parse_shared_type(
1338        &self,
1339        type_str: &str,
1340        var_name: &str,
1341    ) -> Result<Option<SharedMemoryDecl>> {
1342        let type_str = type_str
1343            .replace(" :: ", "::")
1344            .replace(" ::", "::")
1345            .replace(":: ", "::");
1346
1347        if type_str.contains("SharedTile") {
1348            if let Some(start) = type_str.find('<') {
1349                if let Some(end) = type_str.rfind('>') {
1350                    let params = &type_str[start + 1..end];
1351                    let parts: Vec<&str> = params.split(',').map(|s| s.trim()).collect();
1352
1353                    if parts.len() >= 3 {
1354                        let rust_type = parts[0];
1355                        let width: u32 = parts[1].parse().map_err(|_| {
1356                            TranspileError::Unsupported("Invalid SharedTile width".into())
1357                        })?;
1358                        let height: u32 = parts[2].parse().map_err(|_| {
1359                            TranspileError::Unsupported("Invalid SharedTile height".into())
1360                        })?;
1361
1362                        let wgsl_type = rust_type_to_wgsl(rust_type);
1363                        return Ok(Some(SharedMemoryDecl::new_2d(
1364                            var_name, wgsl_type, width, height,
1365                        )));
1366                    }
1367                }
1368            }
1369        }
1370
1371        if type_str.contains("SharedArray") {
1372            if let Some(start) = type_str.find('<') {
1373                if let Some(end) = type_str.rfind('>') {
1374                    let params = &type_str[start + 1..end];
1375                    let parts: Vec<&str> = params.split(',').map(|s| s.trim()).collect();
1376
1377                    if parts.len() >= 2 {
1378                        let rust_type = parts[0];
1379                        let size: u32 = parts[1].parse().map_err(|_| {
1380                            TranspileError::Unsupported("Invalid SharedArray size".into())
1381                        })?;
1382
1383                        let wgsl_type = rust_type_to_wgsl(rust_type);
1384                        return Ok(Some(SharedMemoryDecl::new_1d(var_name, wgsl_type, size)));
1385                    }
1386                }
1387            }
1388        }
1389
1390        Ok(None)
1391    }
1392
1393    /// Try to transpile a shared memory method call.
1394    fn try_transpile_shared_method_call(
1395        &self,
1396        receiver: &str,
1397        method_name: &str,
1398        args: &syn::punctuated::Punctuated<Expr, syn::token::Comma>,
1399    ) -> Option<Result<String>> {
1400        let shared_info = self.shared_vars.get(receiver)?;
1401
1402        match method_name {
1403            "get" => {
1404                if shared_info.is_tile {
1405                    if args.len() >= 2 {
1406                        let x = self.transpile_expr(&args[0]).ok()?;
1407                        let y = self.transpile_expr(&args[1]).ok()?;
1408                        Some(Ok(format!("{}[{}][{}]", receiver, y, x)))
1409                    } else {
1410                        Some(Err(TranspileError::Unsupported(
1411                            "SharedTile.get requires x and y arguments".into(),
1412                        )))
1413                    }
1414                } else if !args.is_empty() {
1415                    let idx = self.transpile_expr(&args[0]).ok()?;
1416                    Some(Ok(format!("{}[{}]", receiver, idx)))
1417                } else {
1418                    Some(Err(TranspileError::Unsupported(
1419                        "SharedArray.get requires index argument".into(),
1420                    )))
1421                }
1422            }
1423            "set" => {
1424                if shared_info.is_tile {
1425                    if args.len() >= 3 {
1426                        let x = self.transpile_expr(&args[0]).ok()?;
1427                        let y = self.transpile_expr(&args[1]).ok()?;
1428                        let val = self.transpile_expr(&args[2]).ok()?;
1429                        Some(Ok(format!("{}[{}][{}] = {}", receiver, y, x, val)))
1430                    } else {
1431                        Some(Err(TranspileError::Unsupported(
1432                            "SharedTile.set requires x, y, and value arguments".into(),
1433                        )))
1434                    }
1435                } else if args.len() >= 2 {
1436                    let idx = self.transpile_expr(&args[0]).ok()?;
1437                    let val = self.transpile_expr(&args[1]).ok()?;
1438                    Some(Ok(format!("{}[{}] = {}", receiver, idx, val)))
1439                } else {
1440                    Some(Err(TranspileError::Unsupported(
1441                        "SharedArray.set requires index and value arguments".into(),
1442                    )))
1443                }
1444            }
1445            "width" if shared_info.is_tile => Some(Ok(shared_info.dimensions[1].to_string())),
1446            "height" if shared_info.is_tile => Some(Ok(shared_info.dimensions[0].to_string())),
1447            "size" => {
1448                let total: usize = shared_info.dimensions.iter().product();
1449                Some(Ok(total.to_string()))
1450            }
1451            _ => None,
1452        }
1453    }
1454
1455    /// Infer WGSL type from expression.
1456    fn infer_wgsl_type(&self, expr: &Expr) -> &'static str {
1457        match expr {
1458            Expr::Lit(lit) => match &lit.lit {
1459                Lit::Float(_) => "f32",
1460                Lit::Int(i) => {
1461                    let s = i.to_string();
1462                    if s.ends_with("u32") || s.ends_with("usize") {
1463                        "u32"
1464                    } else {
1465                        "i32"
1466                    }
1467                }
1468                Lit::Bool(_) => "bool",
1469                _ => "f32",
1470            },
1471            Expr::Binary(bin) => {
1472                let left_type = self.infer_wgsl_type(&bin.left);
1473                let right_type = self.infer_wgsl_type(&bin.right);
1474                if left_type == "i32" && right_type == "i32" {
1475                    "i32"
1476                } else if left_type == "u32" && right_type == "u32" {
1477                    "u32"
1478                } else {
1479                    "f32"
1480                }
1481            }
1482            Expr::Call(call) => {
1483                if let Ok(func) = self.transpile_expr(&call.func) {
1484                    if let Some(
1485                        WgslIntrinsic::LocalInvocationIdX
1486                        | WgslIntrinsic::WorkgroupIdX
1487                        | WgslIntrinsic::GlobalInvocationIdX,
1488                    ) = self.intrinsics.lookup(&func)
1489                    {
1490                        return "i32";
1491                    }
1492                }
1493                "f32"
1494            }
1495            Expr::Cast(cast) => {
1496                if let Ok(wgsl_type) = self.type_mapper.map_type(&cast.ty) {
1497                    match wgsl_type {
1498                        WgslType::I32 => return "i32",
1499                        WgslType::U32 => return "u32",
1500                        WgslType::F32 => return "f32",
1501                        WgslType::Bool => return "bool",
1502                        _ => {}
1503                    }
1504                }
1505                "f32"
1506            }
1507            _ => "f32",
1508        }
1509    }
1510}
1511
1512/// Extract loop variable name from pattern.
1513fn extract_loop_var(pat: &Pat) -> Option<String> {
1514    match pat {
1515        Pat::Ident(ident) => Some(ident.ident.to_string()),
1516        _ => None,
1517    }
1518}
1519
1520/// Convert Rust primitive type to WGSL type.
1521fn rust_type_to_wgsl(rust_type: &str) -> WgslType {
1522    match rust_type {
1523        "f32" => WgslType::F32,
1524        "f64" => WgslType::F32, // WGSL doesn't have f64
1525        "i32" => WgslType::I32,
1526        "u32" => WgslType::U32,
1527        "bool" => WgslType::Bool,
1528        _ => WgslType::F32,
1529    }
1530}
1531
1532/// Transpile a function to WGSL without kernel attributes.
1533pub fn transpile_function(func: &ItemFn) -> Result<String> {
1534    let mut transpiler = WgslTranspiler::new_generic();
1535
1536    let name = func.sig.ident.to_string();
1537
1538    let mut params = Vec::new();
1539    for param in &func.sig.inputs {
1540        if let FnArg::Typed(pat_type) = param {
1541            let param_name = match pat_type.pat.as_ref() {
1542                Pat::Ident(ident) => ident.ident.to_string(),
1543                _ => continue,
1544            };
1545
1546            let wgsl_type = transpiler
1547                .type_mapper
1548                .map_type(&pat_type.ty)
1549                .map_err(TranspileError::Type)?;
1550            params.push(format!("{}: {}", param_name, wgsl_type.to_wgsl()));
1551        }
1552    }
1553
1554    let return_type = match &func.sig.output {
1555        ReturnType::Default => "".to_string(),
1556        ReturnType::Type(_, ty) => {
1557            let wgsl_type = transpiler
1558                .type_mapper
1559                .map_type(ty)
1560                .map_err(TranspileError::Type)?;
1561            format!(" -> {}", wgsl_type.to_wgsl())
1562        }
1563    };
1564
1565    let body = transpiler.transpile_block(&func.block)?;
1566
1567    Ok(format!(
1568        "fn {name}({}){return_type} {{\n{body}}}\n",
1569        params.join(", ")
1570    ))
1571}
1572
1573// === Range Info Extension for WGSL ===
1574
1575impl RangeInfo {
1576    /// Create from a syn::ExprRange.
1577    pub fn from_range<F>(range: &syn::ExprRange, transpile: F) -> Self
1578    where
1579        F: Fn(&Expr) -> Result<String>,
1580    {
1581        let start = range.start.as_ref().and_then(|e| transpile(e).ok());
1582        let end = range.end.as_ref().and_then(|e| transpile(e).ok());
1583        let inclusive = matches!(range.limits, syn::RangeLimits::Closed(_));
1584
1585        Self::new(start, end, inclusive)
1586    }
1587
1588    /// Generate WGSL for loop header.
1589    pub fn to_wgsl_for_header(&self, var_name: &str) -> String {
1590        let start = self.start_or_default();
1591        let end = self.end.as_deref().unwrap_or("/* unbounded */");
1592        let op = if self.inclusive { "<=" } else { "<" };
1593
1594        format!(
1595            "for (var {var_name}: i32 = {start}; {var_name} {op} {end}; {var_name} = {var_name} + 1)"
1596        )
1597    }
1598}
1599
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    /// Transpile `expr` with a fresh generic transpiler, panicking on error.
    fn wgsl_of(expr: &Expr) -> String {
        WgslTranspiler::new_generic().transpile_expr(expr).unwrap()
    }

    /// A mixed `+` / `*` expression is expected to come out unchanged.
    #[test]
    fn test_simple_arithmetic() {
        let arithmetic: Expr = parse_quote!(a + b * 2.0);
        assert_eq!(wgsl_of(&arithmetic), "a + b * 2.0");
    }

    /// An untyped `let` is expected to become a `var` declared as `f32`.
    #[test]
    fn test_let_binding() {
        let mut transpiler = WgslTranspiler::new_generic();

        let binding: Stmt = parse_quote!(let x = a + b;);
        let wgsl = transpiler.transpile_stmt(&binding).unwrap();
        assert!(wgsl.contains("var x: f32 = a + b;"));
    }

    /// Indexing is expected to keep the bracket syntax.
    #[test]
    fn test_array_index() {
        let indexed: Expr = parse_quote!(data[idx]);
        assert_eq!(wgsl_of(&indexed), "data[idx]");
    }

    /// `thread_idx_x()` is expected to reference `local_invocation_id.x`.
    #[test]
    fn test_intrinsic_thread_idx() {
        let call: Expr = parse_quote!(thread_idx_x());
        assert!(wgsl_of(&call).contains("local_invocation_id.x"));
    }

    /// A half-open range loop is expected to become a counted `for` header.
    #[test]
    fn test_for_loop() {
        let range_loop: Expr = parse_quote! {
            for i in 0..n {
                data[i] = 0.0;
            }
        };

        let wgsl = wgsl_of(&range_loop);
        assert!(wgsl.contains("for (var i: i32 = 0; i < n; i = i + 1)"));
    }

    /// The `while` condition is expected to be wrapped in parentheses.
    #[test]
    fn test_while_loop() {
        let while_loop: Expr = parse_quote! {
            while x < 10 {
                x += 1;
            }
        };

        let wgsl = wgsl_of(&while_loop);
        assert!(wgsl.contains("while (x < 10)"));
    }

    /// A guard with a bare `return` is expected to keep both the `if` and the `return;`.
    #[test]
    fn test_early_return() {
        let guard: Expr = parse_quote! {
            if idx >= n { return; }
        };

        let wgsl = wgsl_of(&guard);
        assert!(wgsl.contains("if (idx >= n)"));
        assert!(wgsl.contains("return;"));
    }

    /// `sync_threads()` is expected to become `workgroupBarrier()`.
    #[test]
    fn test_sync_threads() {
        let barrier: Expr = parse_quote!(sync_threads());
        assert_eq!(wgsl_of(&barrier), "workgroupBarrier()");
    }
}