ringkernel_wgpu_codegen/
transpiler.rs

1//! Core Rust-to-WGSL transpiler.
2//!
3//! This module handles the translation of Rust AST to WGSL code.
4
5use crate::bindings::{generate_bindings, AccessMode, BindingLayout};
6use crate::handler::WgslContextMethod;
7use crate::intrinsics::{IntrinsicRegistry, WgslIntrinsic};
8use crate::loops::RangeInfo;
9use crate::ring_kernel::RingKernelConfig;
10use crate::shared::{SharedMemoryConfig, SharedMemoryDecl};
11use crate::stencil::StencilConfig;
12use crate::types::{
13    is_grid_pos_type, is_mutable_reference, is_ring_context_type, TypeMapper, WgslType,
14};
15use crate::u64_workarounds::U64Helpers;
16use crate::validation::ValidationMode;
17use crate::{Result, TranspileError};
18use quote::ToTokens;
19use std::collections::HashMap;
20use syn::{
21    BinOp, Expr, ExprAssign, ExprBinary, ExprBreak, ExprCall, ExprCast, ExprContinue, ExprForLoop,
22    ExprIf, ExprIndex, ExprLit, ExprLoop, ExprMatch, ExprMethodCall, ExprParen, ExprPath,
23    ExprReturn, ExprStruct, ExprUnary, ExprWhile, FnArg, ItemFn, Lit, Pat, ReturnType, Stmt, UnOp,
24};
25
26/// WGSL code transpiler.
27pub struct WgslTranspiler {
28    /// Stencil configuration (if generating a stencil kernel).
29    config: Option<StencilConfig>,
30    /// Ring kernel configuration (if generating a ring kernel).
31    #[allow(dead_code)]
32    ring_config: Option<RingKernelConfig>,
33    /// Type mapper.
34    type_mapper: TypeMapper,
35    /// Intrinsic registry.
36    intrinsics: IntrinsicRegistry,
37    /// Variables known to be the GridPos context.
38    grid_pos_vars: Vec<String>,
39    /// Variables known to be RingContext references.
40    context_vars: Vec<String>,
41    /// Current indentation level.
42    indent: usize,
43    /// Validation mode for loop handling.
44    validation_mode: ValidationMode,
45    /// Shared memory configuration.
46    shared_memory: SharedMemoryConfig,
47    /// Variables that are SharedTile or SharedArray types.
48    pub shared_vars: HashMap<String, SharedVarInfo>,
49    /// Whether we're in ring kernel mode (enables context method inlining).
50    ring_kernel_mode: bool,
51    /// Whether to include u64 helper functions.
52    needs_u64_helpers: bool,
53    /// Whether subgroup operations are used (need extension).
54    needs_subgroup_extension: bool,
55    /// Workgroup size for generic kernels.
56    workgroup_size: (u32, u32, u32),
57    /// Collected buffer bindings.
58    bindings: Vec<BindingLayout>,
59}
60
/// Information about a shared memory variable declared in a kernel body.
///
/// Plain data, so it supports equality as well as cloning and debugging.
/// This lets callers and tests compare the recorded shape of a shared variable.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SharedVarInfo {
    /// Variable name as written in the Rust source.
    pub name: String,
    /// Whether it's a 2D tile (`true`) or a 1D array (`false`).
    pub is_tile: bool,
    /// Dimensions: `[size]` for 1D, `[height, width]` for 2D.
    pub dimensions: Vec<usize>,
    /// Element type, as a WGSL type string (e.g. `f32`).
    pub element_type: String,
}
73
74impl WgslTranspiler {
75    /// Create a new transpiler with stencil configuration.
76    pub fn new_stencil(config: StencilConfig) -> Self {
77        Self {
78            config: Some(config),
79            ring_config: None,
80            type_mapper: TypeMapper::new(),
81            intrinsics: IntrinsicRegistry::new(),
82            grid_pos_vars: Vec::new(),
83            context_vars: Vec::new(),
84            indent: 1,
85            validation_mode: ValidationMode::Stencil,
86            shared_memory: SharedMemoryConfig::new(),
87            shared_vars: HashMap::new(),
88            ring_kernel_mode: false,
89            needs_u64_helpers: false,
90            needs_subgroup_extension: false,
91            workgroup_size: (256, 1, 1),
92            bindings: Vec::new(),
93        }
94    }
95
96    /// Create a new transpiler without stencil configuration.
97    pub fn new_generic() -> Self {
98        Self {
99            config: None,
100            ring_config: None,
101            type_mapper: TypeMapper::new(),
102            intrinsics: IntrinsicRegistry::new(),
103            grid_pos_vars: Vec::new(),
104            context_vars: Vec::new(),
105            indent: 1,
106            validation_mode: ValidationMode::Generic,
107            shared_memory: SharedMemoryConfig::new(),
108            shared_vars: HashMap::new(),
109            ring_kernel_mode: false,
110            needs_u64_helpers: false,
111            needs_subgroup_extension: false,
112            workgroup_size: (256, 1, 1),
113            bindings: Vec::new(),
114        }
115    }
116
117    /// Create a new transpiler for ring kernel generation.
118    pub fn new_ring_kernel(config: RingKernelConfig) -> Self {
119        Self {
120            config: None,
121            ring_config: Some(config.clone()),
122            type_mapper: TypeMapper::new(),
123            intrinsics: IntrinsicRegistry::new(),
124            grid_pos_vars: Vec::new(),
125            context_vars: Vec::new(),
126            indent: 2,
127            validation_mode: ValidationMode::Generic,
128            shared_memory: SharedMemoryConfig::new(),
129            shared_vars: HashMap::new(),
130            ring_kernel_mode: true,
131            needs_u64_helpers: true, // Ring kernels typically need 64-bit counters
132            needs_subgroup_extension: false,
133            workgroup_size: (config.workgroup_size, 1, 1),
134            bindings: Vec::new(),
135        }
136    }
137
138    /// Set the workgroup size for generic kernels.
139    pub fn with_workgroup_size(mut self, x: u32, y: u32, z: u32) -> Self {
140        self.workgroup_size = (x, y, z);
141        self
142    }
143
144    /// Get current indentation string.
145    fn indent_str(&self) -> String {
146        "    ".repeat(self.indent)
147    }
148
149    /// Transpile a stencil kernel function.
150    pub fn transpile_stencil(&mut self, func: &ItemFn) -> Result<String> {
151        let config = self
152            .config
153            .as_ref()
154            .ok_or_else(|| TranspileError::Unsupported("No stencil config provided".into()))?
155            .clone();
156
157        // Identify GridPos parameters
158        for param in &func.sig.inputs {
159            if let FnArg::Typed(pat_type) = param {
160                if is_grid_pos_type(&pat_type.ty) {
161                    if let Pat::Ident(ident) = pat_type.pat.as_ref() {
162                        self.grid_pos_vars.push(ident.ident.to_string());
163                    }
164                }
165            }
166        }
167
168        // Generate bindings from parameters
169        self.collect_bindings(func)?;
170
171        let mut output = String::new();
172
173        // Generate buffer bindings
174        output.push_str(&generate_bindings(&self.bindings));
175        output.push_str("\n\n");
176
177        // Generate shared memory declarations
178        if config.use_shared_memory {
179            let buffer_width = config.buffer_width();
180            let buffer_height = config.buffer_height();
181            output.push_str(&format!(
182                "var<workgroup> tile: array<array<f32, {}>, {}>;\n\n",
183                buffer_width, buffer_height
184            ));
185        }
186
187        // Generate kernel signature with builtins
188        output.push_str("@compute ");
189        output.push_str(&config.workgroup_size_annotation());
190        output.push('\n');
191        output.push_str(&format!("fn {}(\n", func.sig.ident));
192        output.push_str("    @builtin(local_invocation_id) local_id: vec3<u32>,\n");
193        output.push_str("    @builtin(workgroup_id) workgroup_id: vec3<u32>,\n");
194        output.push_str("    @builtin(global_invocation_id) global_id: vec3<u32>\n");
195        output.push_str(") {\n");
196
197        // Generate preamble
198        output.push_str(&self.generate_stencil_preamble(&config));
199
200        // Generate function body
201        let body = self.transpile_block(&func.block)?;
202        output.push_str(&body);
203
204        output.push_str("}\n");
205
206        Ok(output)
207    }
208
209    /// Generate stencil kernel preamble.
210    fn generate_stencil_preamble(&self, config: &StencilConfig) -> String {
211        let buffer_width = config.buffer_width();
212        let mut preamble = String::new();
213
214        preamble.push_str("    // Thread indices\n");
215        preamble.push_str("    let lx = local_id.x;\n");
216        preamble.push_str("    let ly = local_id.y;\n");
217        preamble.push_str(&format!("    let buffer_width = {}u;\n", buffer_width));
218        preamble.push('\n');
219
220        // Bounds check
221        preamble.push_str(&format!(
222            "    if (lx >= {}u || ly >= {}u) {{ return; }}\n\n",
223            config.tile_width, config.tile_height
224        ));
225
226        // Calculate buffer index with halo offset
227        preamble.push_str(&format!(
228            "    let idx = (ly + {}u) * buffer_width + (lx + {}u);\n\n",
229            config.halo, config.halo
230        ));
231
232        preamble
233    }
234
235    /// Transpile a generic (non-stencil) kernel function.
236    pub fn transpile_global_kernel(&mut self, func: &ItemFn) -> Result<String> {
237        // Collect bindings from parameters
238        self.collect_bindings(func)?;
239
240        let mut output = String::new();
241
242        // Add extensions if needed
243        if self.needs_subgroup_extension {
244            output.push_str("enable chromium_experimental_subgroups;\n\n");
245        }
246
247        // Generate buffer bindings
248        output.push_str(&generate_bindings(&self.bindings));
249        output.push_str("\n\n");
250
251        // Generate shared memory declarations
252        if !self.shared_memory.declarations.is_empty() {
253            output.push_str(&self.shared_memory.to_wgsl());
254            output.push_str("\n\n");
255        }
256
257        // Generate kernel signature
258        output.push_str("@compute ");
259        output.push_str(&format!(
260            "@workgroup_size({}, {}, {})\n",
261            self.workgroup_size.0, self.workgroup_size.1, self.workgroup_size.2
262        ));
263        output.push_str(&format!("fn {}(\n", func.sig.ident));
264        output.push_str("    @builtin(local_invocation_id) local_invocation_id: vec3<u32>,\n");
265        output.push_str("    @builtin(workgroup_id) workgroup_id: vec3<u32>,\n");
266        output.push_str("    @builtin(global_invocation_id) global_invocation_id: vec3<u32>,\n");
267        output.push_str("    @builtin(num_workgroups) num_workgroups: vec3<u32>\n");
268        output.push_str(") {\n");
269
270        // Generate function body
271        let body = self.transpile_block(&func.block)?;
272        output.push_str(&body);
273
274        output.push_str("}\n");
275
276        Ok(output)
277    }
278
279    /// Transpile a handler function into a ring kernel.
280    pub fn transpile_ring_kernel(
281        &mut self,
282        handler: &ItemFn,
283        config: &RingKernelConfig,
284    ) -> Result<String> {
285        // Track context variables for method inlining
286        for param in &handler.sig.inputs {
287            if let FnArg::Typed(pat_type) = param {
288                if is_ring_context_type(&pat_type.ty) {
289                    if let Pat::Ident(ident) = pat_type.pat.as_ref() {
290                        self.context_vars.push(ident.ident.to_string());
291                    }
292                }
293            }
294        }
295
296        self.ring_kernel_mode = true;
297        self.needs_u64_helpers = true;
298
299        let mut output = String::new();
300
301        // Generate 64-bit helper functions
302        output.push_str(&U64Helpers::generate_all());
303        output.push_str("\n\n");
304
305        // Generate ControlBlock struct
306        output.push_str(&crate::ring_kernel::generate_control_block_struct(config));
307        output.push_str("\n\n");
308
309        // Generate bindings
310        output.push_str(crate::ring_kernel::generate_ring_kernel_bindings());
311        output.push_str("\n\n");
312
313        // Generate kernel
314        output.push_str("@compute ");
315        output.push_str(&config.workgroup_size_annotation());
316        output.push('\n');
317        output.push_str(&format!("fn ring_kernel_{}(\n", config.name));
318        output.push_str("    @builtin(local_invocation_id) local_invocation_id: vec3<u32>,\n");
319        output.push_str("    @builtin(workgroup_id) workgroup_id: vec3<u32>,\n");
320        output.push_str("    @builtin(global_invocation_id) global_invocation_id: vec3<u32>\n");
321        output.push_str(") {\n");
322
323        // Generate preamble (activation/termination checks)
324        output.push_str(crate::ring_kernel::generate_ring_kernel_preamble());
325        output.push_str("\n\n");
326
327        // Generate handler body
328        output.push_str("    // === USER HANDLER CODE ===\n");
329        self.indent = 1;
330        let handler_body = self.transpile_block(&handler.block)?;
331        output.push_str(&handler_body);
332        output.push_str("    // === END HANDLER CODE ===\n\n");
333
334        // Update message counter
335        output.push_str("    // Update message counter\n");
336        output.push_str(
337            "    atomic_inc_u64(&control.messages_processed_lo, &control.messages_processed_hi);\n",
338        );
339
340        output.push_str("}\n");
341
342        Ok(output)
343    }
344
345    /// Collect buffer bindings from function parameters.
346    fn collect_bindings(&mut self, func: &ItemFn) -> Result<()> {
347        let mut binding_idx = 0u32;
348
349        for param in &func.sig.inputs {
350            if let FnArg::Typed(pat_type) = param {
351                // Skip GridPos and RingContext parameters
352                if is_grid_pos_type(&pat_type.ty) || is_ring_context_type(&pat_type.ty) {
353                    continue;
354                }
355
356                let param_name = match pat_type.pat.as_ref() {
357                    Pat::Ident(ident) => ident.ident.to_string(),
358                    _ => continue,
359                };
360
361                let wgsl_type = self
362                    .type_mapper
363                    .map_type(&pat_type.ty)
364                    .map_err(TranspileError::Type)?;
365                let is_mutable = is_mutable_reference(&pat_type.ty);
366
367                // Determine if this is a buffer or scalar
368                match &wgsl_type {
369                    WgslType::Ptr { inner, .. } => {
370                        let access = if is_mutable {
371                            AccessMode::ReadWrite
372                        } else {
373                            AccessMode::Read
374                        };
375                        self.bindings.push(BindingLayout::new(
376                            0,
377                            binding_idx,
378                            &param_name,
379                            WgslType::Array {
380                                element: inner.clone(),
381                                size: None,
382                            },
383                            access,
384                        ));
385                        binding_idx += 1;
386                    }
387                    WgslType::Array { .. } => {
388                        let access = if is_mutable {
389                            AccessMode::ReadWrite
390                        } else {
391                            AccessMode::Read
392                        };
393                        self.bindings.push(BindingLayout::new(
394                            0,
395                            binding_idx,
396                            &param_name,
397                            wgsl_type.clone(),
398                            access,
399                        ));
400                        binding_idx += 1;
401                    }
402                    // Scalars are typically passed via uniforms
403                    _ => {
404                        self.bindings.push(BindingLayout::uniform(
405                            binding_idx,
406                            &param_name,
407                            wgsl_type.clone(),
408                        ));
409                        binding_idx += 1;
410                    }
411                }
412            }
413        }
414
415        Ok(())
416    }
417
418    /// Transpile a block of statements.
419    fn transpile_block(&mut self, block: &syn::Block) -> Result<String> {
420        let mut output = String::new();
421
422        for stmt in &block.stmts {
423            let stmt_str = self.transpile_stmt(stmt)?;
424            if !stmt_str.is_empty() {
425                output.push_str(&stmt_str);
426            }
427        }
428
429        Ok(output)
430    }
431
432    /// Transpile a single statement.
433    fn transpile_stmt(&mut self, stmt: &Stmt) -> Result<String> {
434        match stmt {
435            Stmt::Local(local) => {
436                let indent = self.indent_str();
437
438                // Get variable name
439                let var_name = match &local.pat {
440                    Pat::Ident(ident) => ident.ident.to_string(),
441                    Pat::Type(pat_type) => {
442                        if let Pat::Ident(ident) = pat_type.pat.as_ref() {
443                            ident.ident.to_string()
444                        } else {
445                            return Err(TranspileError::Unsupported(
446                                "Complex pattern in let binding".into(),
447                            ));
448                        }
449                    }
450                    _ => {
451                        return Err(TranspileError::Unsupported(
452                            "Complex pattern in let binding".into(),
453                        ))
454                    }
455                };
456
457                // Check for SharedTile or SharedArray type annotation
458                if let Some(shared_decl) = self.try_parse_shared_declaration(local, &var_name)? {
459                    self.shared_vars.insert(
460                        var_name.clone(),
461                        SharedVarInfo {
462                            name: var_name.clone(),
463                            is_tile: shared_decl.dimensions.len() == 2,
464                            dimensions: shared_decl
465                                .dimensions
466                                .iter()
467                                .map(|&d| d as usize)
468                                .collect(),
469                            element_type: shared_decl.element_type.to_wgsl(),
470                        },
471                    );
472                    self.shared_memory.add(shared_decl.clone());
473                    return Ok(format!(
474                        "{indent}// shared memory: {} declared at module scope\n",
475                        var_name
476                    ));
477                }
478
479                // Get initializer
480                if let Some(init) = &local.init {
481                    let expr_str = self.transpile_expr(&init.expr)?;
482                    let type_str = self.infer_wgsl_type(&init.expr);
483
484                    Ok(format!(
485                        "{indent}var {var_name}: {type_str} = {expr_str};\n"
486                    ))
487                } else {
488                    Ok(format!("{indent}var {var_name}: f32;\n"))
489                }
490            }
491            Stmt::Expr(expr, semi) => {
492                let indent = self.indent_str();
493
494                // Handle early return pattern in if statements
495                if let Expr::If(if_expr) = expr {
496                    if let Some(Stmt::Expr(Expr::Return(_), _)) = if_expr.then_branch.stmts.first()
497                    {
498                        if if_expr.then_branch.stmts.len() == 1 && if_expr.else_branch.is_none() {
499                            let expr_str = self.transpile_expr(expr)?;
500                            return Ok(format!("{indent}{expr_str}\n"));
501                        }
502                    }
503                }
504
505                let expr_str = self.transpile_expr(expr)?;
506
507                if semi.is_some()
508                    || matches!(expr, Expr::Return(_))
509                    || expr_str.starts_with("return")
510                    || expr_str.starts_with("if (")
511                {
512                    Ok(format!("{indent}{expr_str};\n"))
513                } else {
514                    Ok(format!("{indent}return {expr_str};\n"))
515                }
516            }
517            Stmt::Item(_) => Err(TranspileError::Unsupported("Item in function body".into())),
518            Stmt::Macro(_) => Err(TranspileError::Unsupported("Macro in function body".into())),
519        }
520    }
521
522    /// Transpile an expression.
523    fn transpile_expr(&self, expr: &Expr) -> Result<String> {
524        match expr {
525            Expr::Lit(lit) => self.transpile_lit(lit),
526            Expr::Path(path) => self.transpile_path(path),
527            Expr::Binary(bin) => self.transpile_binary(bin),
528            Expr::Unary(unary) => self.transpile_unary(unary),
529            Expr::Paren(paren) => self.transpile_paren(paren),
530            Expr::Index(index) => self.transpile_index(index),
531            Expr::Call(call) => self.transpile_call(call),
532            Expr::MethodCall(method) => self.transpile_method_call(method),
533            Expr::If(if_expr) => self.transpile_if(if_expr),
534            Expr::Assign(assign) => self.transpile_assign(assign),
535            Expr::Cast(cast) => self.transpile_cast(cast),
536            Expr::Match(match_expr) => self.transpile_match(match_expr),
537            Expr::Block(block) => {
538                if let Some(Stmt::Expr(expr, None)) = block.block.stmts.last() {
539                    self.transpile_expr(expr)
540                } else {
541                    Err(TranspileError::Unsupported(
542                        "Complex block expression".into(),
543                    ))
544                }
545            }
546            Expr::Field(field) => {
547                let base = self.transpile_expr(&field.base)?;
548                let member = match &field.member {
549                    syn::Member::Named(ident) => ident.to_string(),
550                    syn::Member::Unnamed(idx) => idx.index.to_string(),
551                };
552                Ok(format!("{base}.{member}"))
553            }
554            Expr::Return(ret) => self.transpile_return(ret),
555            Expr::ForLoop(for_loop) => self.transpile_for_loop(for_loop),
556            Expr::While(while_loop) => self.transpile_while_loop(while_loop),
557            Expr::Loop(loop_expr) => self.transpile_infinite_loop(loop_expr),
558            Expr::Break(break_expr) => self.transpile_break(break_expr),
559            Expr::Continue(cont_expr) => self.transpile_continue(cont_expr),
560            Expr::Struct(struct_expr) => self.transpile_struct_literal(struct_expr),
561            Expr::Reference(ref_expr) => {
562                // In WGSL, we often need &var for atomics
563                let inner = self.transpile_expr(&ref_expr.expr)?;
564                Ok(format!("&{inner}"))
565            }
566            _ => Err(TranspileError::Unsupported(format!(
567                "Expression type: {}",
568                expr.to_token_stream()
569            ))),
570        }
571    }
572
573    /// Transpile a literal.
574    fn transpile_lit(&self, lit: &ExprLit) -> Result<String> {
575        match &lit.lit {
576            Lit::Float(f) => {
577                let s = f.to_string();
578                // WGSL float literals don't need suffix
579                let num = s.trim_end_matches("f32").trim_end_matches("f64");
580                // Ensure there's a decimal point
581                if num.contains('.') {
582                    Ok(num.to_string())
583                } else {
584                    Ok(format!("{}.0", num))
585                }
586            }
587            Lit::Int(i) => {
588                let s = i.to_string();
589                // Check for unsigned suffix
590                if s.ends_with("u32") || s.ends_with("usize") {
591                    Ok(format!(
592                        "{}u",
593                        s.trim_end_matches("u32").trim_end_matches("usize")
594                    ))
595                } else if s.ends_with("i32") || s.ends_with("isize") {
596                    Ok(format!(
597                        "{}i",
598                        s.trim_end_matches("i32").trim_end_matches("isize")
599                    ))
600                } else {
601                    // Default to signed int
602                    Ok(s)
603                }
604            }
605            Lit::Bool(b) => Ok(if b.value { "true" } else { "false" }.to_string()),
606            _ => Err(TranspileError::Unsupported(format!(
607                "Literal type: {}",
608                lit.to_token_stream()
609            ))),
610        }
611    }
612
613    /// Transpile a path (variable reference).
614    fn transpile_path(&self, path: &ExprPath) -> Result<String> {
615        let segments: Vec<_> = path
616            .path
617            .segments
618            .iter()
619            .map(|s| s.ident.to_string())
620            .collect();
621
622        if segments.len() == 1 {
623            Ok(segments[0].clone())
624        } else {
625            Ok(segments.join("_"))
626        }
627    }
628
629    /// Transpile a binary expression.
630    fn transpile_binary(&self, bin: &ExprBinary) -> Result<String> {
631        let left = self.transpile_expr(&bin.left)?;
632        let right = self.transpile_expr(&bin.right)?;
633
634        let op = match bin.op {
635            BinOp::Add(_) => "+",
636            BinOp::Sub(_) => "-",
637            BinOp::Mul(_) => "*",
638            BinOp::Div(_) => "/",
639            BinOp::Rem(_) => "%",
640            BinOp::And(_) => "&&",
641            BinOp::Or(_) => "||",
642            BinOp::BitXor(_) => "^",
643            BinOp::BitAnd(_) => "&",
644            BinOp::BitOr(_) => "|",
645            BinOp::Shl(_) => "<<",
646            BinOp::Shr(_) => ">>",
647            BinOp::Eq(_) => "==",
648            BinOp::Lt(_) => "<",
649            BinOp::Le(_) => "<=",
650            BinOp::Ne(_) => "!=",
651            BinOp::Ge(_) => ">=",
652            BinOp::Gt(_) => ">",
653            BinOp::AddAssign(_) => "+=",
654            BinOp::SubAssign(_) => "-=",
655            BinOp::MulAssign(_) => "*=",
656            BinOp::DivAssign(_) => "/=",
657            BinOp::RemAssign(_) => "%=",
658            BinOp::BitXorAssign(_) => "^=",
659            BinOp::BitAndAssign(_) => "&=",
660            BinOp::BitOrAssign(_) => "|=",
661            BinOp::ShlAssign(_) => "<<=",
662            BinOp::ShrAssign(_) => ">>=",
663            _ => {
664                return Err(TranspileError::Unsupported(format!(
665                    "Binary operator: {}",
666                    bin.to_token_stream()
667                )))
668            }
669        };
670
671        Ok(format!("{left} {op} {right}"))
672    }
673
674    /// Transpile a unary expression.
675    fn transpile_unary(&self, unary: &ExprUnary) -> Result<String> {
676        let expr = self.transpile_expr(&unary.expr)?;
677
678        let op = match unary.op {
679            UnOp::Neg(_) => "-",
680            UnOp::Not(_) => "!",
681            UnOp::Deref(_) => "*",
682            _ => {
683                return Err(TranspileError::Unsupported(format!(
684                    "Unary operator: {}",
685                    unary.to_token_stream()
686                )))
687            }
688        };
689
690        Ok(format!("{op}({expr})"))
691    }
692
693    /// Transpile a parenthesized expression.
694    fn transpile_paren(&self, paren: &ExprParen) -> Result<String> {
695        let inner = self.transpile_expr(&paren.expr)?;
696        Ok(format!("({inner})"))
697    }
698
699    /// Transpile an index expression.
700    fn transpile_index(&self, index: &ExprIndex) -> Result<String> {
701        let base = self.transpile_expr(&index.expr)?;
702        let idx = self.transpile_expr(&index.index)?;
703        Ok(format!("{base}[{idx}]"))
704    }
705
706    /// Transpile a function call.
707    fn transpile_call(&self, call: &ExprCall) -> Result<String> {
708        let func = self.transpile_expr(&call.func)?;
709
710        // Check for intrinsics
711        if let Some(intrinsic) = self.intrinsics.lookup(&func) {
712            return self.transpile_intrinsic_call(intrinsic, &call.args);
713        }
714
715        // Regular function call
716        let args: Vec<String> = call
717            .args
718            .iter()
719            .map(|a| self.transpile_expr(a))
720            .collect::<Result<_>>()?;
721
722        Ok(format!("{}({})", func, args.join(", ")))
723    }
724
725    /// Transpile an intrinsic call.
726    fn transpile_intrinsic_call(
727        &self,
728        intrinsic: WgslIntrinsic,
729        args: &syn::punctuated::Punctuated<Expr, syn::token::Comma>,
730    ) -> Result<String> {
731        let wgsl_name = intrinsic.to_wgsl();
732
733        // Check for value intrinsics (builtins accessed as variables)
734        match intrinsic {
735            WgslIntrinsic::LocalInvocationIdX
736            | WgslIntrinsic::LocalInvocationIdY
737            | WgslIntrinsic::LocalInvocationIdZ
738            | WgslIntrinsic::WorkgroupIdX
739            | WgslIntrinsic::WorkgroupIdY
740            | WgslIntrinsic::WorkgroupIdZ
741            | WgslIntrinsic::GlobalInvocationIdX
742            | WgslIntrinsic::GlobalInvocationIdY
743            | WgslIntrinsic::GlobalInvocationIdZ
744            | WgslIntrinsic::NumWorkgroupsX
745            | WgslIntrinsic::NumWorkgroupsY
746            | WgslIntrinsic::NumWorkgroupsZ => {
747                // These are variable accesses, not function calls
748                return Ok(format!("i32({})", wgsl_name));
749            }
750            WgslIntrinsic::WorkgroupSizeX => {
751                return Ok(format!("i32({}u)", self.workgroup_size.0));
752            }
753            WgslIntrinsic::WorkgroupSizeY => {
754                return Ok(format!("i32({}u)", self.workgroup_size.1));
755            }
756            WgslIntrinsic::WorkgroupSizeZ => {
757                return Ok(format!("i32({}u)", self.workgroup_size.2));
758            }
759            WgslIntrinsic::WorkgroupBarrier | WgslIntrinsic::StorageBarrier => {
760                // Zero-arg function intrinsics
761                return Ok(wgsl_name.to_string());
762            }
763            WgslIntrinsic::SubgroupInvocationId | WgslIntrinsic::SubgroupSize => {
764                // Subgroup builtins
765                return Ok(wgsl_name.to_string());
766            }
767            _ => {}
768        }
769
770        // Function intrinsics with arguments
771        let transpiled_args: Vec<String> = args
772            .iter()
773            .map(|a| self.transpile_expr(a))
774            .collect::<Result<_>>()?;
775
776        Ok(format!("{}({})", wgsl_name, transpiled_args.join(", ")))
777    }
778
779    /// Transpile a method call.
780    fn transpile_method_call(&self, method: &ExprMethodCall) -> Result<String> {
781        let receiver = self.transpile_expr(&method.receiver)?;
782        let method_name = method.method.to_string();
783
784        // Check if this is a SharedTile/SharedArray method call
785        if let Some(result) =
786            self.try_transpile_shared_method_call(&receiver, &method_name, &method.args)
787        {
788            return result;
789        }
790
791        // Check if this is a RingContext method call
792        if self.ring_kernel_mode && self.context_vars.contains(&receiver) {
793            return self.transpile_context_method(&method_name, &method.args);
794        }
795
796        // Check if this is a GridPos method call
797        if self.grid_pos_vars.contains(&receiver) {
798            return self.transpile_stencil_intrinsic(&method_name, &method.args);
799        }
800
801        // Check for f32 math methods
802        if let Some(intrinsic) = self.intrinsics.lookup(&method_name) {
803            let wgsl_name = intrinsic.to_wgsl();
804            let args: Vec<String> = std::iter::once(receiver)
805                .chain(
806                    method
807                        .args
808                        .iter()
809                        .map(|a| self.transpile_expr(a).unwrap_or_default()),
810                )
811                .collect();
812
813            return Ok(format!("{}({})", wgsl_name, args.join(", ")));
814        }
815
816        // Regular method call
817        let args: Vec<String> = method
818            .args
819            .iter()
820            .map(|a| self.transpile_expr(a))
821            .collect::<Result<_>>()?;
822
823        Ok(format!("{}.{}({})", receiver, method_name, args.join(", ")))
824    }
825
826    /// Transpile a RingContext method call.
827    fn transpile_context_method(
828        &self,
829        method: &str,
830        args: &syn::punctuated::Punctuated<Expr, syn::token::Comma>,
831    ) -> Result<String> {
832        let ctx_method = WgslContextMethod::from_name(method).ok_or_else(|| {
833            TranspileError::Unsupported(format!("Unknown context method: {}", method))
834        })?;
835
836        let wgsl_args: Vec<String> = args
837            .iter()
838            .map(|a| self.transpile_expr(a).unwrap_or_default())
839            .collect();
840
841        match ctx_method {
842            WgslContextMethod::AtomicAdd
843            | WgslContextMethod::AtomicLoad
844            | WgslContextMethod::AtomicStore => Ok(format!(
845                "{}({})",
846                ctx_method.to_wgsl(),
847                wgsl_args.join(", ")
848            )),
849            _ => Ok(ctx_method.to_wgsl().to_string()),
850        }
851    }
852
853    /// Transpile a stencil intrinsic method call.
854    fn transpile_stencil_intrinsic(
855        &self,
856        method: &str,
857        args: &syn::punctuated::Punctuated<Expr, syn::token::Comma>,
858    ) -> Result<String> {
859        let config = self.config.as_ref().ok_or_else(|| {
860            TranspileError::Unsupported("Stencil intrinsic without config".into())
861        })?;
862
863        let buffer_width = config.buffer_width();
864
865        match method {
866            "idx" => Ok("idx".to_string()),
867            "x" => Ok("i32(lx)".to_string()),
868            "y" => Ok("i32(ly)".to_string()),
869            "north" => {
870                if args.is_empty() {
871                    return Err(TranspileError::Unsupported(
872                        "north() requires buffer argument".into(),
873                    ));
874                }
875                let buffer = self.transpile_expr(&args[0])?;
876                Ok(format!("{buffer}[idx - {buffer_width}u]"))
877            }
878            "south" => {
879                if args.is_empty() {
880                    return Err(TranspileError::Unsupported(
881                        "south() requires buffer argument".into(),
882                    ));
883                }
884                let buffer = self.transpile_expr(&args[0])?;
885                Ok(format!("{buffer}[idx + {buffer_width}u]"))
886            }
887            "east" => {
888                if args.is_empty() {
889                    return Err(TranspileError::Unsupported(
890                        "east() requires buffer argument".into(),
891                    ));
892                }
893                let buffer = self.transpile_expr(&args[0])?;
894                Ok(format!("{buffer}[idx + 1u]"))
895            }
896            "west" => {
897                if args.is_empty() {
898                    return Err(TranspileError::Unsupported(
899                        "west() requires buffer argument".into(),
900                    ));
901                }
902                let buffer = self.transpile_expr(&args[0])?;
903                Ok(format!("{buffer}[idx - 1u]"))
904            }
905            "at" => {
906                if args.len() < 3 {
907                    return Err(TranspileError::Unsupported(
908                        "at() requires buffer, dx, dy arguments".into(),
909                    ));
910                }
911                let buffer = self.transpile_expr(&args[0])?;
912                let dx = self.transpile_expr(&args[1])?;
913                let dy = self.transpile_expr(&args[2])?;
914                Ok(format!(
915                    "{buffer}[idx + u32(({dy}) * i32({buffer_width}u) + ({dx}))]"
916                ))
917            }
918            _ => Err(TranspileError::Unsupported(format!(
919                "Unknown stencil intrinsic: {}",
920                method
921            ))),
922        }
923    }
924
925    /// Transpile an if expression.
926    fn transpile_if(&self, if_expr: &ExprIf) -> Result<String> {
927        let cond = self.transpile_expr(&if_expr.cond)?;
928
929        // Check for early return pattern
930        if let Some(Stmt::Expr(Expr::Return(ret), _)) = if_expr.then_branch.stmts.first() {
931            if if_expr.then_branch.stmts.len() == 1 && if_expr.else_branch.is_none() {
932                if ret.expr.is_none() {
933                    return Ok(format!("if ({cond}) {{ return; }}"));
934                }
935                let ret_val = self.transpile_expr(ret.expr.as_ref().unwrap())?;
936                return Ok(format!("if ({cond}) {{ return {ret_val}; }}"));
937            }
938        }
939
940        // Handle if-else as select when possible
941        if let Some((_, else_branch)) = &if_expr.else_branch {
942            if let (Some(Stmt::Expr(then_expr, None)), Expr::Block(else_block)) =
943                (if_expr.then_branch.stmts.last(), else_branch.as_ref())
944            {
945                if let Some(Stmt::Expr(else_expr, None)) = else_block.block.stmts.last() {
946                    let then_str = self.transpile_expr(then_expr)?;
947                    let else_str = self.transpile_expr(else_expr)?;
948                    return Ok(format!("select({else_str}, {then_str}, {cond})"));
949                }
950            }
951
952            // else if chain
953            if let Expr::If(else_if) = else_branch.as_ref() {
954                let then_body = self.transpile_if_body(&if_expr.then_branch)?;
955                let else_part = self.transpile_if(else_if)?;
956                return Ok(format!("if ({cond}) {{{then_body}}} else {else_part}"));
957            } else if let Expr::Block(else_block) = else_branch.as_ref() {
958                let then_body = self.transpile_if_body(&if_expr.then_branch)?;
959                let else_body = self.transpile_if_body(&else_block.block)?;
960                return Ok(format!("if ({cond}) {{{then_body}}} else {{{else_body}}}"));
961            }
962        }
963
964        // If without else
965        let then_body = self.transpile_if_body(&if_expr.then_branch)?;
966        Ok(format!("if ({cond}) {{{then_body}}}"))
967    }
968
969    /// Transpile the body of an if branch.
970    fn transpile_if_body(&self, block: &syn::Block) -> Result<String> {
971        let mut body = String::new();
972        for stmt in &block.stmts {
973            match stmt {
974                Stmt::Expr(expr, Some(_)) => {
975                    let expr_str = self.transpile_expr(expr)?;
976                    body.push_str(&format!(" {expr_str};"));
977                }
978                Stmt::Expr(Expr::Return(ret), None) => {
979                    if let Some(ret_expr) = &ret.expr {
980                        let expr_str = self.transpile_expr(ret_expr)?;
981                        body.push_str(&format!(" return {expr_str};"));
982                    } else {
983                        body.push_str(" return;");
984                    }
985                }
986                Stmt::Expr(expr, None) => {
987                    let expr_str = self.transpile_expr(expr)?;
988                    body.push_str(&format!(" return {expr_str};"));
989                }
990                _ => {}
991            }
992        }
993        Ok(body)
994    }
995
996    /// Transpile an assignment expression.
997    fn transpile_assign(&self, assign: &ExprAssign) -> Result<String> {
998        let left = self.transpile_expr(&assign.left)?;
999        let right = self.transpile_expr(&assign.right)?;
1000        Ok(format!("{left} = {right}"))
1001    }
1002
1003    /// Transpile a cast expression.
1004    fn transpile_cast(&self, cast: &ExprCast) -> Result<String> {
1005        let expr = self.transpile_expr(&cast.expr)?;
1006        let wgsl_type = self
1007            .type_mapper
1008            .map_type(&cast.ty)
1009            .map_err(TranspileError::Type)?;
1010        Ok(format!("{}({})", wgsl_type.to_wgsl(), expr))
1011    }
1012
1013    /// Transpile a return expression.
1014    fn transpile_return(&self, ret: &ExprReturn) -> Result<String> {
1015        if let Some(expr) = &ret.expr {
1016            let expr_str = self.transpile_expr(expr)?;
1017            Ok(format!("return {expr_str}"))
1018        } else {
1019            Ok("return".to_string())
1020        }
1021    }
1022
1023    /// Transpile a struct literal.
1024    fn transpile_struct_literal(&self, struct_expr: &ExprStruct) -> Result<String> {
1025        let type_name = struct_expr
1026            .path
1027            .segments
1028            .iter()
1029            .map(|s| s.ident.to_string())
1030            .collect::<Vec<_>>()
1031            .join("_");
1032
1033        let mut fields = Vec::new();
1034        for field in &struct_expr.fields {
1035            let field_name = match &field.member {
1036                syn::Member::Named(ident) => ident.to_string(),
1037                syn::Member::Unnamed(idx) => idx.index.to_string(),
1038            };
1039            let value = self.transpile_expr(&field.expr)?;
1040            fields.push(format!("{}: {}", field_name, value));
1041        }
1042
1043        if struct_expr.rest.is_some() {
1044            return Err(TranspileError::Unsupported(
1045                "Struct update syntax (..base) is not supported in WGSL".into(),
1046            ));
1047        }
1048
1049        Ok(format!("{}({})", type_name, fields.join(", ")))
1050    }
1051
1052    // === Loop Transpilation ===
1053
1054    /// Transpile a for loop.
1055    fn transpile_for_loop(&self, for_loop: &ExprForLoop) -> Result<String> {
1056        if !self.validation_mode.allows_loops() {
1057            return Err(TranspileError::Unsupported(
1058                "Loops are not allowed in stencil kernels".into(),
1059            ));
1060        }
1061
1062        let var_name = extract_loop_var(&for_loop.pat)
1063            .ok_or_else(|| TranspileError::Unsupported("Complex pattern in for loop".into()))?;
1064
1065        let header = match for_loop.expr.as_ref() {
1066            Expr::Range(range) => {
1067                let range_info = RangeInfo::from_range(range, |e| self.transpile_expr(e));
1068                range_info.to_wgsl_for_header(&var_name)
1069            }
1070            _ => {
1071                return Err(TranspileError::Unsupported(
1072                    "Only range expressions are supported in for loops".into(),
1073                ));
1074            }
1075        };
1076
1077        let body = self.transpile_loop_body(&for_loop.body)?;
1078        Ok(format!("{header} {{\n{body}}}"))
1079    }
1080
1081    /// Transpile a while loop.
1082    fn transpile_while_loop(&self, while_loop: &ExprWhile) -> Result<String> {
1083        if !self.validation_mode.allows_loops() {
1084            return Err(TranspileError::Unsupported(
1085                "Loops are not allowed in stencil kernels".into(),
1086            ));
1087        }
1088
1089        let condition = self.transpile_expr(&while_loop.cond)?;
1090        let body = self.transpile_loop_body(&while_loop.body)?;
1091        Ok(format!("while ({condition}) {{\n{body}}}"))
1092    }
1093
1094    /// Transpile an infinite loop.
1095    fn transpile_infinite_loop(&self, loop_expr: &ExprLoop) -> Result<String> {
1096        if !self.validation_mode.allows_loops() {
1097            return Err(TranspileError::Unsupported(
1098                "Loops are not allowed in stencil kernels".into(),
1099            ));
1100        }
1101
1102        let body = self.transpile_loop_body(&loop_expr.body)?;
1103        Ok(format!("loop {{\n{body}}}"))
1104    }
1105
1106    /// Transpile a break expression.
1107    fn transpile_break(&self, break_expr: &ExprBreak) -> Result<String> {
1108        if break_expr.label.is_some() {
1109            return Err(TranspileError::Unsupported(
1110                "Labeled break is not supported in WGSL".into(),
1111            ));
1112        }
1113        if break_expr.expr.is_some() {
1114            return Err(TranspileError::Unsupported(
1115                "Break with value is not supported in WGSL".into(),
1116            ));
1117        }
1118        Ok("break".to_string())
1119    }
1120
1121    /// Transpile a continue expression.
1122    fn transpile_continue(&self, cont_expr: &ExprContinue) -> Result<String> {
1123        if cont_expr.label.is_some() {
1124            return Err(TranspileError::Unsupported(
1125                "Labeled continue is not supported in WGSL".into(),
1126            ));
1127        }
1128        Ok("continue".to_string())
1129    }
1130
1131    /// Transpile a loop body.
1132    fn transpile_loop_body(&self, block: &syn::Block) -> Result<String> {
1133        let mut output = String::new();
1134        let inner_indent = "    ".repeat(self.indent + 1);
1135
1136        for stmt in &block.stmts {
1137            match stmt {
1138                Stmt::Local(local) => {
1139                    let var_name = match &local.pat {
1140                        Pat::Ident(ident) => ident.ident.to_string(),
1141                        Pat::Type(pat_type) => {
1142                            if let Pat::Ident(ident) = pat_type.pat.as_ref() {
1143                                ident.ident.to_string()
1144                            } else {
1145                                return Err(TranspileError::Unsupported(
1146                                    "Complex pattern in let binding".into(),
1147                                ));
1148                            }
1149                        }
1150                        _ => {
1151                            return Err(TranspileError::Unsupported(
1152                                "Complex pattern in let binding".into(),
1153                            ))
1154                        }
1155                    };
1156
1157                    if let Some(init) = &local.init {
1158                        let expr_str = self.transpile_expr(&init.expr)?;
1159                        let type_str = self.infer_wgsl_type(&init.expr);
1160                        output.push_str(&format!(
1161                            "{inner_indent}var {var_name}: {type_str} = {expr_str};\n"
1162                        ));
1163                    } else {
1164                        output.push_str(&format!("{inner_indent}var {var_name}: f32;\n"));
1165                    }
1166                }
1167                Stmt::Expr(expr, _semi) => {
1168                    let expr_str = self.transpile_expr(expr)?;
1169                    output.push_str(&format!("{inner_indent}{expr_str};\n"));
1170                }
1171                _ => {
1172                    return Err(TranspileError::Unsupported(
1173                        "Unsupported statement in loop body".into(),
1174                    ));
1175                }
1176            }
1177        }
1178
1179        let closing_indent = "    ".repeat(self.indent);
1180        output.push_str(&closing_indent);
1181
1182        Ok(output)
1183    }
1184
1185    /// Transpile a match expression to switch.
1186    fn transpile_match(&self, match_expr: &ExprMatch) -> Result<String> {
1187        let scrutinee = self.transpile_expr(&match_expr.expr)?;
1188        let mut output = format!("switch ({scrutinee}) {{\n");
1189
1190        for arm in &match_expr.arms {
1191            let case_label = self.transpile_match_pattern(&arm.pat)?;
1192
1193            if case_label == "default" || case_label.starts_with("/*") {
1194                output.push_str("    default: {\n");
1195            } else {
1196                output.push_str(&format!("    case {case_label}: {{\n"));
1197            }
1198
1199            match arm.body.as_ref() {
1200                Expr::Block(block) => {
1201                    for stmt in &block.block.stmts {
1202                        let stmt_str = self.transpile_stmt_inline(stmt)?;
1203                        output.push_str(&format!("        {stmt_str}\n"));
1204                    }
1205                }
1206                _ => {
1207                    let body = self.transpile_expr(&arm.body)?;
1208                    output.push_str(&format!("        {body};\n"));
1209                }
1210            }
1211
1212            output.push_str("    }\n");
1213        }
1214
1215        output.push('}');
1216        Ok(output)
1217    }
1218
1219    /// Transpile a match pattern.
1220    fn transpile_match_pattern(&self, pat: &Pat) -> Result<String> {
1221        match pat {
1222            Pat::Lit(pat_lit) => match &pat_lit.lit {
1223                Lit::Int(i) => Ok(i.to_string()),
1224                Lit::Bool(b) => Ok(if b.value { "true" } else { "false" }.to_string()),
1225                _ => Err(TranspileError::Unsupported(
1226                    "Non-integer literal in match pattern".into(),
1227                )),
1228            },
1229            Pat::Wild(_) => Ok("default".to_string()),
1230            Pat::Ident(ident) => Ok(format!("/* {} */ default", ident.ident)),
1231            Pat::Or(pat_or) => {
1232                if let Some(first) = pat_or.cases.first() {
1233                    self.transpile_match_pattern(first)
1234                } else {
1235                    Err(TranspileError::Unsupported("Empty or pattern".into()))
1236                }
1237            }
1238            _ => Err(TranspileError::Unsupported(format!(
1239                "Match pattern: {}",
1240                pat.to_token_stream()
1241            ))),
1242        }
1243    }
1244
1245    /// Transpile a statement inline.
1246    fn transpile_stmt_inline(&self, stmt: &Stmt) -> Result<String> {
1247        match stmt {
1248            Stmt::Local(local) => {
1249                let var_name = match &local.pat {
1250                    Pat::Ident(ident) => ident.ident.to_string(),
1251                    Pat::Type(pat_type) => {
1252                        if let Pat::Ident(ident) = pat_type.pat.as_ref() {
1253                            ident.ident.to_string()
1254                        } else {
1255                            return Err(TranspileError::Unsupported(
1256                                "Complex pattern in let binding".into(),
1257                            ));
1258                        }
1259                    }
1260                    _ => {
1261                        return Err(TranspileError::Unsupported(
1262                            "Complex pattern in let binding".into(),
1263                        ))
1264                    }
1265                };
1266
1267                if let Some(init) = &local.init {
1268                    let expr_str = self.transpile_expr(&init.expr)?;
1269                    let type_str = self.infer_wgsl_type(&init.expr);
1270                    Ok(format!("var {var_name}: {type_str} = {expr_str};"))
1271                } else {
1272                    Ok(format!("var {var_name}: f32;"))
1273                }
1274            }
1275            Stmt::Expr(expr, semi) => {
1276                let expr_str = self.transpile_expr(expr)?;
1277                if semi.is_some() {
1278                    Ok(format!("{expr_str};"))
1279                } else {
1280                    Ok(format!("return {expr_str};"))
1281                }
1282            }
1283            _ => Err(TranspileError::Unsupported(
1284                "Unsupported statement in match arm".into(),
1285            )),
1286        }
1287    }
1288
1289    // === Shared Memory Support ===
1290
1291    /// Try to parse a shared memory declaration.
1292    fn try_parse_shared_declaration(
1293        &self,
1294        local: &syn::Local,
1295        var_name: &str,
1296    ) -> Result<Option<SharedMemoryDecl>> {
1297        if let Pat::Type(pat_type) = &local.pat {
1298            let type_str = pat_type.ty.to_token_stream().to_string();
1299            return self.parse_shared_type(&type_str, var_name);
1300        }
1301
1302        if let Some(init) = &local.init {
1303            if let Expr::Call(call) = init.expr.as_ref() {
1304                if let Expr::Path(path) = call.func.as_ref() {
1305                    let path_str = path.to_token_stream().to_string();
1306                    return self.parse_shared_type(&path_str, var_name);
1307                }
1308            }
1309        }
1310
1311        Ok(None)
1312    }
1313
1314    /// Parse a type string for shared memory info.
1315    fn parse_shared_type(
1316        &self,
1317        type_str: &str,
1318        var_name: &str,
1319    ) -> Result<Option<SharedMemoryDecl>> {
1320        let type_str = type_str
1321            .replace(" :: ", "::")
1322            .replace(" ::", "::")
1323            .replace(":: ", "::");
1324
1325        if type_str.contains("SharedTile") {
1326            if let Some(start) = type_str.find('<') {
1327                if let Some(end) = type_str.rfind('>') {
1328                    let params = &type_str[start + 1..end];
1329                    let parts: Vec<&str> = params.split(',').map(|s| s.trim()).collect();
1330
1331                    if parts.len() >= 3 {
1332                        let rust_type = parts[0];
1333                        let width: u32 = parts[1].parse().map_err(|_| {
1334                            TranspileError::Unsupported("Invalid SharedTile width".into())
1335                        })?;
1336                        let height: u32 = parts[2].parse().map_err(|_| {
1337                            TranspileError::Unsupported("Invalid SharedTile height".into())
1338                        })?;
1339
1340                        let wgsl_type = rust_type_to_wgsl(rust_type);
1341                        return Ok(Some(SharedMemoryDecl::new_2d(
1342                            var_name, wgsl_type, width, height,
1343                        )));
1344                    }
1345                }
1346            }
1347        }
1348
1349        if type_str.contains("SharedArray") {
1350            if let Some(start) = type_str.find('<') {
1351                if let Some(end) = type_str.rfind('>') {
1352                    let params = &type_str[start + 1..end];
1353                    let parts: Vec<&str> = params.split(',').map(|s| s.trim()).collect();
1354
1355                    if parts.len() >= 2 {
1356                        let rust_type = parts[0];
1357                        let size: u32 = parts[1].parse().map_err(|_| {
1358                            TranspileError::Unsupported("Invalid SharedArray size".into())
1359                        })?;
1360
1361                        let wgsl_type = rust_type_to_wgsl(rust_type);
1362                        return Ok(Some(SharedMemoryDecl::new_1d(var_name, wgsl_type, size)));
1363                    }
1364                }
1365            }
1366        }
1367
1368        Ok(None)
1369    }
1370
1371    /// Try to transpile a shared memory method call.
1372    fn try_transpile_shared_method_call(
1373        &self,
1374        receiver: &str,
1375        method_name: &str,
1376        args: &syn::punctuated::Punctuated<Expr, syn::token::Comma>,
1377    ) -> Option<Result<String>> {
1378        let shared_info = self.shared_vars.get(receiver)?;
1379
1380        match method_name {
1381            "get" => {
1382                if shared_info.is_tile {
1383                    if args.len() >= 2 {
1384                        let x = self.transpile_expr(&args[0]).ok()?;
1385                        let y = self.transpile_expr(&args[1]).ok()?;
1386                        Some(Ok(format!("{}[{}][{}]", receiver, y, x)))
1387                    } else {
1388                        Some(Err(TranspileError::Unsupported(
1389                            "SharedTile.get requires x and y arguments".into(),
1390                        )))
1391                    }
1392                } else if !args.is_empty() {
1393                    let idx = self.transpile_expr(&args[0]).ok()?;
1394                    Some(Ok(format!("{}[{}]", receiver, idx)))
1395                } else {
1396                    Some(Err(TranspileError::Unsupported(
1397                        "SharedArray.get requires index argument".into(),
1398                    )))
1399                }
1400            }
1401            "set" => {
1402                if shared_info.is_tile {
1403                    if args.len() >= 3 {
1404                        let x = self.transpile_expr(&args[0]).ok()?;
1405                        let y = self.transpile_expr(&args[1]).ok()?;
1406                        let val = self.transpile_expr(&args[2]).ok()?;
1407                        Some(Ok(format!("{}[{}][{}] = {}", receiver, y, x, val)))
1408                    } else {
1409                        Some(Err(TranspileError::Unsupported(
1410                            "SharedTile.set requires x, y, and value arguments".into(),
1411                        )))
1412                    }
1413                } else if args.len() >= 2 {
1414                    let idx = self.transpile_expr(&args[0]).ok()?;
1415                    let val = self.transpile_expr(&args[1]).ok()?;
1416                    Some(Ok(format!("{}[{}] = {}", receiver, idx, val)))
1417                } else {
1418                    Some(Err(TranspileError::Unsupported(
1419                        "SharedArray.set requires index and value arguments".into(),
1420                    )))
1421                }
1422            }
1423            "width" if shared_info.is_tile => Some(Ok(shared_info.dimensions[1].to_string())),
1424            "height" if shared_info.is_tile => Some(Ok(shared_info.dimensions[0].to_string())),
1425            "size" => {
1426                let total: usize = shared_info.dimensions.iter().product();
1427                Some(Ok(total.to_string()))
1428            }
1429            _ => None,
1430        }
1431    }
1432
1433    /// Infer WGSL type from expression.
1434    fn infer_wgsl_type(&self, expr: &Expr) -> &'static str {
1435        match expr {
1436            Expr::Lit(lit) => match &lit.lit {
1437                Lit::Float(_) => "f32",
1438                Lit::Int(i) => {
1439                    let s = i.to_string();
1440                    if s.ends_with("u32") || s.ends_with("usize") {
1441                        "u32"
1442                    } else {
1443                        "i32"
1444                    }
1445                }
1446                Lit::Bool(_) => "bool",
1447                _ => "f32",
1448            },
1449            Expr::Binary(bin) => {
1450                let left_type = self.infer_wgsl_type(&bin.left);
1451                let right_type = self.infer_wgsl_type(&bin.right);
1452                if left_type == "i32" && right_type == "i32" {
1453                    "i32"
1454                } else if left_type == "u32" && right_type == "u32" {
1455                    "u32"
1456                } else {
1457                    "f32"
1458                }
1459            }
1460            Expr::Call(call) => {
1461                if let Ok(func) = self.transpile_expr(&call.func) {
1462                    if let Some(
1463                        WgslIntrinsic::LocalInvocationIdX
1464                        | WgslIntrinsic::WorkgroupIdX
1465                        | WgslIntrinsic::GlobalInvocationIdX,
1466                    ) = self.intrinsics.lookup(&func)
1467                    {
1468                        return "i32";
1469                    }
1470                }
1471                "f32"
1472            }
1473            Expr::Cast(cast) => {
1474                if let Ok(wgsl_type) = self.type_mapper.map_type(&cast.ty) {
1475                    match wgsl_type {
1476                        WgslType::I32 => return "i32",
1477                        WgslType::U32 => return "u32",
1478                        WgslType::F32 => return "f32",
1479                        WgslType::Bool => return "bool",
1480                        _ => {}
1481                    }
1482                }
1483                "f32"
1484            }
1485            _ => "f32",
1486        }
1487    }
1488}
1489
1490/// Extract loop variable name from pattern.
1491fn extract_loop_var(pat: &Pat) -> Option<String> {
1492    match pat {
1493        Pat::Ident(ident) => Some(ident.ident.to_string()),
1494        _ => None,
1495    }
1496}
1497
1498/// Convert Rust primitive type to WGSL type.
1499fn rust_type_to_wgsl(rust_type: &str) -> WgslType {
1500    match rust_type {
1501        "f32" => WgslType::F32,
1502        "f64" => WgslType::F32, // WGSL doesn't have f64
1503        "i32" => WgslType::I32,
1504        "u32" => WgslType::U32,
1505        "bool" => WgslType::Bool,
1506        _ => WgslType::F32,
1507    }
1508}
1509
1510/// Transpile a function to WGSL without kernel attributes.
1511pub fn transpile_function(func: &ItemFn) -> Result<String> {
1512    let mut transpiler = WgslTranspiler::new_generic();
1513
1514    let name = func.sig.ident.to_string();
1515
1516    let mut params = Vec::new();
1517    for param in &func.sig.inputs {
1518        if let FnArg::Typed(pat_type) = param {
1519            let param_name = match pat_type.pat.as_ref() {
1520                Pat::Ident(ident) => ident.ident.to_string(),
1521                _ => continue,
1522            };
1523
1524            let wgsl_type = transpiler
1525                .type_mapper
1526                .map_type(&pat_type.ty)
1527                .map_err(TranspileError::Type)?;
1528            params.push(format!("{}: {}", param_name, wgsl_type.to_wgsl()));
1529        }
1530    }
1531
1532    let return_type = match &func.sig.output {
1533        ReturnType::Default => "".to_string(),
1534        ReturnType::Type(_, ty) => {
1535            let wgsl_type = transpiler
1536                .type_mapper
1537                .map_type(ty)
1538                .map_err(TranspileError::Type)?;
1539            format!(" -> {}", wgsl_type.to_wgsl())
1540        }
1541    };
1542
1543    let body = transpiler.transpile_block(&func.block)?;
1544
1545    Ok(format!(
1546        "fn {name}({}){return_type} {{\n{body}}}\n",
1547        params.join(", ")
1548    ))
1549}
1550
1551// === Range Info Extension for WGSL ===
1552
1553impl RangeInfo {
1554    /// Create from a syn::ExprRange.
1555    pub fn from_range<F>(range: &syn::ExprRange, transpile: F) -> Self
1556    where
1557        F: Fn(&Expr) -> Result<String>,
1558    {
1559        let start = range.start.as_ref().and_then(|e| transpile(e).ok());
1560        let end = range.end.as_ref().and_then(|e| transpile(e).ok());
1561        let inclusive = matches!(range.limits, syn::RangeLimits::Closed(_));
1562
1563        Self::new(start, end, inclusive)
1564    }
1565
1566    /// Generate WGSL for loop header.
1567    pub fn to_wgsl_for_header(&self, var_name: &str) -> String {
1568        let start = self.start_or_default();
1569        let end = self.end.as_deref().unwrap_or("/* unbounded */");
1570        let op = if self.inclusive { "<=" } else { "<" };
1571
1572        format!(
1573            "for (var {var_name}: i32 = {start}; {var_name} {op} {end}; {var_name} = {var_name} + 1)"
1574        )
1575    }
1576}
1577
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    #[test]
    fn test_simple_arithmetic() {
        let transpiler = WgslTranspiler::new_generic();

        let expr: Expr = parse_quote!(a + b * 2.0);
        let result = transpiler.transpile_expr(&expr).unwrap();
        assert_eq!(result, "a + b * 2.0");
    }

    #[test]
    fn test_let_binding() {
        let mut transpiler = WgslTranspiler::new_generic();

        let stmt: Stmt = parse_quote!(let x = a + b;);
        let result = transpiler.transpile_stmt(&stmt).unwrap();
        assert!(result.contains("var x: f32 = a + b;"));
    }

    #[test]
    fn test_array_index() {
        let transpiler = WgslTranspiler::new_generic();

        let expr: Expr = parse_quote!(data[idx]);
        let result = transpiler.transpile_expr(&expr).unwrap();
        assert_eq!(result, "data[idx]");
    }

    #[test]
    fn test_intrinsic_thread_idx() {
        let transpiler = WgslTranspiler::new_generic();

        let expr: Expr = parse_quote!(thread_idx_x());
        let result = transpiler.transpile_expr(&expr).unwrap();
        assert!(result.contains("local_invocation_id.x"));
    }

    #[test]
    fn test_for_loop() {
        let transpiler = WgslTranspiler::new_generic();

        let expr: Expr = parse_quote! {
            for i in 0..n {
                data[i] = 0.0;
            }
        };

        let result = transpiler.transpile_expr(&expr).unwrap();
        assert!(result.contains("for (var i: i32 = 0; i < n; i = i + 1)"));
    }

    #[test]
    fn test_while_loop() {
        let transpiler = WgslTranspiler::new_generic();

        let expr: Expr = parse_quote! {
            while x < 10 {
                x += 1;
            }
        };

        let result = transpiler.transpile_expr(&expr).unwrap();
        assert!(result.contains("while (x < 10)"));
    }

    #[test]
    fn test_early_return() {
        let transpiler = WgslTranspiler::new_generic();

        let expr: Expr = parse_quote! {
            if idx >= n { return; }
        };

        let result = transpiler.transpile_expr(&expr).unwrap();
        assert!(result.contains("if (idx >= n)"));
        assert!(result.contains("return;"));
    }

    #[test]
    fn test_sync_threads() {
        let transpiler = WgslTranspiler::new_generic();

        let expr: Expr = parse_quote!(sync_threads());
        let result = transpiler.transpile_expr(&expr).unwrap();
        assert_eq!(result, "workgroupBarrier()");
    }

    /// `transpile_function` emits a plain `fn` with mapped parameter and
    /// return types, wrapping the transpiled body in braces.
    #[test]
    fn test_transpile_function_signature() {
        let func: ItemFn = parse_quote! {
            fn identity(x: f32) -> f32 {
                return x;
            }
        };

        let result = transpile_function(&func).unwrap();
        assert!(result.starts_with("fn identity(x: f32) -> f32 {\n"));
        assert!(result.ends_with("}\n"));
    }

    /// The loop header uses `<` for half-open ranges and `<=` for inclusive ones.
    #[test]
    fn test_range_header_comparison_operator() {
        let half_open = RangeInfo::new(Some("0".to_string()), Some("n".to_string()), false);
        assert_eq!(
            half_open.to_wgsl_for_header("i"),
            "for (var i: i32 = 0; i < n; i = i + 1)"
        );

        let inclusive = RangeInfo::new(Some("1".to_string()), Some("n".to_string()), true);
        assert_eq!(
            inclusive.to_wgsl_for_header("j"),
            "for (var j: i32 = 1; j <= n; j = j + 1)"
        );
    }
}