ringkernel_cuda_codegen/
transpiler.rs

1//! Core Rust-to-CUDA transpiler.
2//!
3//! This module handles the translation of Rust AST to CUDA C code.
4
5use crate::handler::{ContextMethod, HandlerSignature};
6use crate::intrinsics::{IntrinsicRegistry, RingKernelIntrinsic, StencilIntrinsic};
7use crate::loops::{extract_loop_var, RangeInfo};
8use crate::shared::{rust_to_cuda_element_type, SharedMemoryConfig, SharedMemoryDecl};
9use crate::stencil::StencilConfig;
10use crate::types::{is_grid_pos_type, is_ring_context_type, TypeMapper};
11use crate::validation::ValidationMode;
12use crate::{Result, TranspileError};
13use quote::ToTokens;
14use syn::{
15    BinOp, Expr, ExprAssign, ExprBinary, ExprBreak, ExprCall, ExprCast, ExprContinue, ExprForLoop,
16    ExprIf, ExprIndex, ExprLet, ExprLit, ExprLoop, ExprMatch, ExprMethodCall, ExprParen, ExprPath,
17    ExprReference, ExprReturn, ExprStruct, ExprUnary, ExprWhile, FnArg, ItemFn, Lit, Pat,
18    ReturnType, Stmt, UnOp,
19};
20
21/// CUDA code transpiler.
22pub struct CudaTranspiler {
23    /// Stencil configuration (if generating a stencil kernel).
24    config: Option<StencilConfig>,
25    /// Type mapper.
26    type_mapper: TypeMapper,
27    /// Intrinsic registry.
28    intrinsics: IntrinsicRegistry,
29    /// Variables known to be the GridPos context.
30    grid_pos_vars: Vec<String>,
31    /// Variables known to be RingContext references.
32    context_vars: Vec<String>,
33    /// Current indentation level.
34    indent: usize,
35    /// Validation mode for loop handling.
36    validation_mode: ValidationMode,
37    /// Shared memory configuration.
38    shared_memory: SharedMemoryConfig,
39    /// Variables that are SharedTile or SharedArray types.
40    pub shared_vars: std::collections::HashMap<String, SharedVarInfo>,
41    /// Whether we're in ring kernel mode (enables context method inlining).
42    ring_kernel_mode: bool,
43    /// Variables known to be pointer types (for -> operator usage).
44    pointer_vars: std::collections::HashSet<String>,
45}
46
/// Bookkeeping for a variable declared with a `SharedTile` or `SharedArray`
/// type annotation.
///
/// All fields are plain data, so the type is comparable (`PartialEq`/`Eq`),
/// which lets callers and tests check the recorded shared variables directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SharedVarInfo {
    /// Variable name.
    pub name: String,
    /// `true` for a 2-D tile, `false` for a 1-D array.
    pub is_tile: bool,
    /// Extents: `[size]` for 1-D, `[height, width]` for 2-D.
    pub dimensions: Vec<usize>,
    /// CUDA type string of each element.
    pub element_type: String,
}
59
60impl CudaTranspiler {
61    /// Create a new transpiler with stencil configuration.
62    pub fn new(config: StencilConfig) -> Self {
63        Self {
64            config: Some(config),
65            type_mapper: TypeMapper::new(),
66            intrinsics: IntrinsicRegistry::new(),
67            grid_pos_vars: Vec::new(),
68            context_vars: Vec::new(),
69            indent: 1, // Start with 1 level for function body
70            validation_mode: ValidationMode::Stencil,
71            shared_memory: SharedMemoryConfig::new(),
72            shared_vars: std::collections::HashMap::new(),
73            ring_kernel_mode: false,
74            pointer_vars: std::collections::HashSet::new(),
75        }
76    }
77
78    /// Create a new transpiler without stencil configuration.
79    pub fn new_generic() -> Self {
80        Self {
81            config: None,
82            type_mapper: TypeMapper::new(),
83            intrinsics: IntrinsicRegistry::new(),
84            grid_pos_vars: Vec::new(),
85            context_vars: Vec::new(),
86            indent: 1,
87            validation_mode: ValidationMode::Generic,
88            shared_memory: SharedMemoryConfig::new(),
89            shared_vars: std::collections::HashMap::new(),
90            ring_kernel_mode: false,
91            pointer_vars: std::collections::HashSet::new(),
92        }
93    }
94
95    /// Create a new transpiler with a specific validation mode.
96    pub fn with_mode(mode: ValidationMode) -> Self {
97        Self {
98            config: None,
99            type_mapper: TypeMapper::new(),
100            intrinsics: IntrinsicRegistry::new(),
101            grid_pos_vars: Vec::new(),
102            context_vars: Vec::new(),
103            indent: 1,
104            validation_mode: mode,
105            shared_memory: SharedMemoryConfig::new(),
106            shared_vars: std::collections::HashMap::new(),
107            ring_kernel_mode: false,
108            pointer_vars: std::collections::HashSet::new(),
109        }
110    }
111
112    /// Create a transpiler configured for ring kernel handler transpilation.
113    pub fn for_ring_kernel() -> Self {
114        Self {
115            config: None,
116            type_mapper: crate::types::ring_kernel_type_mapper(),
117            intrinsics: IntrinsicRegistry::new(),
118            grid_pos_vars: Vec::new(),
119            context_vars: Vec::new(),
120            indent: 2, // Inside kernel + loop
121            validation_mode: ValidationMode::Generic,
122            shared_memory: SharedMemoryConfig::new(),
123            shared_vars: std::collections::HashMap::new(),
124            ring_kernel_mode: true,
125            pointer_vars: std::collections::HashSet::new(),
126        }
127    }
128
129    /// Set the validation mode.
130    pub fn set_validation_mode(&mut self, mode: ValidationMode) {
131        self.validation_mode = mode;
132    }
133
134    /// Get the shared memory configuration.
135    pub fn shared_memory(&self) -> &SharedMemoryConfig {
136        &self.shared_memory
137    }
138
139    /// Get current indentation string.
140    fn indent_str(&self) -> String {
141        "    ".repeat(self.indent)
142    }
143
144    /// Transpile a stencil kernel function.
145    pub fn transpile_stencil(&mut self, func: &ItemFn) -> Result<String> {
146        let config = self
147            .config
148            .as_ref()
149            .ok_or_else(|| TranspileError::Unsupported("No stencil config provided".into()))?
150            .clone();
151
152        // Identify GridPos parameters
153        for param in &func.sig.inputs {
154            if let FnArg::Typed(pat_type) = param {
155                if is_grid_pos_type(&pat_type.ty) {
156                    if let Pat::Ident(ident) = pat_type.pat.as_ref() {
157                        self.grid_pos_vars.push(ident.ident.to_string());
158                    }
159                }
160            }
161        }
162
163        // Generate function signature (without GridPos params)
164        let signature = self.transpile_kernel_signature(func)?;
165
166        // Generate preamble (thread index calculations)
167        let preamble = config.generate_preamble();
168
169        // Generate function body
170        let body = self.transpile_block(&func.block)?;
171
172        Ok(format!(
173            "extern \"C\" __global__ void {signature} {{\n{preamble}\n{body}}}\n"
174        ))
175    }
176
177    /// Transpile a generic (non-stencil) kernel function.
178    ///
179    /// This generates a `__global__` kernel without stencil-specific preamble.
180    /// The kernel code can use `thread_idx_x()`, `block_idx_x()` etc. to access
181    /// CUDA thread indices directly.
182    pub fn transpile_generic_kernel(&mut self, func: &ItemFn) -> Result<String> {
183        // Generate function signature with all params
184        let signature = self.transpile_generic_kernel_signature(func)?;
185
186        // Generate function body
187        let body = self.transpile_block(&func.block)?;
188
189        Ok(format!(
190            "extern \"C\" __global__ void {signature} {{\n{body}}}\n"
191        ))
192    }
193
194    /// Transpile a handler function into a persistent ring kernel.
195    ///
196    /// This wraps the handler body in a persistent message-processing loop
197    /// with control block integration, queue operations, and HLC support.
198    pub fn transpile_ring_kernel(
199        &mut self,
200        handler: &ItemFn,
201        config: &crate::ring_kernel::RingKernelConfig,
202    ) -> Result<String> {
203        use std::fmt::Write;
204
205        // Parse handler signature
206        let handler_sig = HandlerSignature::parse(handler, &self.type_mapper)?;
207
208        // Track context variables for method inlining
209        for param in &handler.sig.inputs {
210            if let FnArg::Typed(pat_type) = param {
211                if is_ring_context_type(&pat_type.ty) {
212                    if let Pat::Ident(ident) = pat_type.pat.as_ref() {
213                        self.context_vars.push(ident.ident.to_string());
214                    }
215                }
216            }
217        }
218
219        // Enable ring kernel mode for context method inlining
220        self.ring_kernel_mode = true;
221
222        let mut output = String::new();
223
224        // Generate struct definitions
225        output.push_str(&crate::ring_kernel::generate_control_block_struct());
226        output.push('\n');
227
228        if config.enable_hlc {
229            output.push_str(&crate::ring_kernel::generate_hlc_struct());
230            output.push('\n');
231        }
232
233        if config.enable_k2k {
234            output.push_str(&crate::ring_kernel::generate_k2k_structs());
235            output.push('\n');
236        }
237
238        // Generate message/response struct definitions if needed
239        if let Some(ref msg_param) = handler_sig.message_param {
240            // Extract type name from the parameter
241            let type_name = msg_param
242                .rust_type
243                .trim_start_matches('&')
244                .trim_start_matches("mut ")
245                .trim();
246            if !type_name.is_empty() && type_name != "f32" && type_name != "i32" {
247                writeln!(output, "// Message type: {}", type_name).unwrap();
248            }
249        }
250
251        if let Some(ref ret_type) = handler_sig.return_type {
252            if ret_type.is_struct {
253                writeln!(output, "// Response type: {}", ret_type.rust_type).unwrap();
254            }
255        }
256
257        // Generate kernel signature
258        output.push_str(&config.generate_signature());
259        output.push_str(" {\n");
260
261        // Generate preamble
262        output.push_str(&config.generate_preamble("    "));
263
264        // Generate message loop header
265        output.push_str(&config.generate_loop_header("    "));
266
267        // Generate message deserialization if handler has message param
268        if let Some(ref msg_param) = handler_sig.message_param {
269            let type_name = msg_param
270                .rust_type
271                .trim_start_matches('&')
272                .trim_start_matches("mut ")
273                .trim();
274            if !type_name.is_empty() {
275                writeln!(output, "        // Message deserialization").unwrap();
276                writeln!(
277                    output,
278                    "        // {}* {} = ({}*)msg_ptr;",
279                    type_name, msg_param.name, type_name
280                )
281                .unwrap();
282                output.push('\n');
283            }
284        }
285
286        // Transpile handler body
287        self.indent = 2; // Inside the message loop
288        let handler_body = self.transpile_block(&handler.block)?;
289
290        // Insert handler code with proper indentation
291        writeln!(output, "        // === USER HANDLER CODE ===").unwrap();
292        for line in handler_body.lines() {
293            if !line.trim().is_empty() {
294                // Add extra indent for being inside the loop
295                writeln!(output, "    {}", line).unwrap();
296            }
297        }
298        writeln!(output, "        // === END HANDLER CODE ===").unwrap();
299
300        // Generate response serialization if handler returns a value
301        if let Some(ref ret_type) = handler_sig.return_type {
302            writeln!(output).unwrap();
303            writeln!(output, "        // Response serialization").unwrap();
304            if ret_type.is_struct {
305                writeln!(output, "        // memcpy(&output_buffer[_out_idx * RESP_SIZE], &response, sizeof({}));",
306                    ret_type.cuda_type).unwrap();
307            }
308        }
309
310        // Generate message completion
311        output.push_str(&config.generate_message_complete("    "));
312
313        // Generate loop footer
314        output.push_str(&config.generate_loop_footer("    "));
315
316        // Generate epilogue
317        output.push_str(&config.generate_epilogue("    "));
318
319        output.push_str("}\n");
320
321        Ok(output)
322    }
323
324    /// Transpile a generic kernel function signature (keeps all params).
325    fn transpile_generic_kernel_signature(&self, func: &ItemFn) -> Result<String> {
326        let name = func.sig.ident.to_string();
327
328        let mut params = Vec::new();
329        for param in &func.sig.inputs {
330            if let FnArg::Typed(pat_type) = param {
331                let param_name = match pat_type.pat.as_ref() {
332                    Pat::Ident(ident) => ident.ident.to_string(),
333                    _ => {
334                        return Err(TranspileError::Unsupported(
335                            "Complex pattern in parameter".into(),
336                        ))
337                    }
338                };
339
340                let cuda_type = self.type_mapper.map_type(&pat_type.ty)?;
341                params.push(format!("{} {}", cuda_type.to_cuda_string(), param_name));
342            }
343        }
344
345        Ok(format!("{}({})", name, params.join(", ")))
346    }
347
348    /// Transpile the kernel function signature.
349    fn transpile_kernel_signature(&self, func: &ItemFn) -> Result<String> {
350        let name = func.sig.ident.to_string();
351
352        let mut params = Vec::new();
353        for param in &func.sig.inputs {
354            if let FnArg::Typed(pat_type) = param {
355                // Skip GridPos parameters
356                if is_grid_pos_type(&pat_type.ty) {
357                    continue;
358                }
359
360                let param_name = match pat_type.pat.as_ref() {
361                    Pat::Ident(ident) => ident.ident.to_string(),
362                    _ => {
363                        return Err(TranspileError::Unsupported(
364                            "Complex pattern in parameter".into(),
365                        ))
366                    }
367                };
368
369                let cuda_type = self.type_mapper.map_type(&pat_type.ty)?;
370                params.push(format!("{} {}", cuda_type.to_cuda_string(), param_name));
371            }
372        }
373
374        Ok(format!("{}({})", name, params.join(", ")))
375    }
376
377    /// Transpile a block of statements.
378    fn transpile_block(&mut self, block: &syn::Block) -> Result<String> {
379        let mut output = String::new();
380
381        for stmt in &block.stmts {
382            let stmt_str = self.transpile_stmt(stmt)?;
383            if !stmt_str.is_empty() {
384                output.push_str(&stmt_str);
385            }
386        }
387
388        Ok(output)
389    }
390
391    /// Transpile a single statement.
392    fn transpile_stmt(&mut self, stmt: &Stmt) -> Result<String> {
393        match stmt {
394            Stmt::Local(local) => {
395                let indent = self.indent_str();
396
397                // Get variable name
398                let var_name = match &local.pat {
399                    Pat::Ident(ident) => ident.ident.to_string(),
400                    Pat::Type(pat_type) => {
401                        if let Pat::Ident(ident) = pat_type.pat.as_ref() {
402                            ident.ident.to_string()
403                        } else {
404                            return Err(TranspileError::Unsupported(
405                                "Complex pattern in let binding".into(),
406                            ));
407                        }
408                    }
409                    _ => {
410                        return Err(TranspileError::Unsupported(
411                            "Complex pattern in let binding".into(),
412                        ))
413                    }
414                };
415
416                // Check for SharedTile or SharedArray type annotation
417                if let Some(shared_decl) = self.try_parse_shared_declaration(local, &var_name)? {
418                    // Register the shared variable
419                    self.shared_vars.insert(
420                        var_name.clone(),
421                        SharedVarInfo {
422                            name: var_name.clone(),
423                            is_tile: shared_decl.dimensions.len() == 2,
424                            dimensions: shared_decl.dimensions.clone(),
425                            element_type: shared_decl.element_type.clone(),
426                        },
427                    );
428
429                    // Add to shared memory config
430                    self.shared_memory.add(shared_decl.clone());
431
432                    // Return the __shared__ declaration
433                    return Ok(format!("{indent}{}\n", shared_decl.to_cuda_decl()));
434                }
435
436                // Get initializer
437                if let Some(init) = &local.init {
438                    let expr_str = self.transpile_expr(&init.expr)?;
439
440                    // Infer type from expression (default to float for now)
441                    // In a more complete implementation, we'd do proper type inference
442                    let type_str = self.infer_cuda_type(&init.expr);
443
444                    // Track if this is a pointer variable for -> access
445                    if type_str.ends_with('*') {
446                        self.pointer_vars.insert(var_name.clone());
447                    }
448
449                    Ok(format!("{indent}{type_str} {var_name} = {expr_str};\n"))
450                } else {
451                    // Uninitialized variable - need type annotation
452                    Ok(format!("{indent}float {var_name};\n"))
453                }
454            }
455            Stmt::Expr(expr, semi) => {
456                let indent = self.indent_str();
457
458                // Check if this is an if statement with early return (shouldn't wrap in return)
459                if let Expr::If(if_expr) = expr {
460                    // Check if the body contains only a return statement
461                    if let Some(Stmt::Expr(Expr::Return(_), _)) = if_expr.then_branch.stmts.first()
462                    {
463                        if if_expr.then_branch.stmts.len() == 1 && if_expr.else_branch.is_none() {
464                            let expr_str = self.transpile_expr(expr)?;
465                            return Ok(format!("{indent}{expr_str};\n"));
466                        }
467                    }
468                }
469
470                let expr_str = self.transpile_expr(expr)?;
471
472                if semi.is_some() {
473                    Ok(format!("{indent}{expr_str};\n"))
474                } else {
475                    // Expression without semicolon (implicit return)
476                    // But check if it's already a return or an if with return
477                    if matches!(expr, Expr::Return(_))
478                        || expr_str.starts_with("return")
479                        || expr_str.starts_with("if (")
480                    {
481                        Ok(format!("{indent}{expr_str};\n"))
482                    } else {
483                        Ok(format!("{indent}return {expr_str};\n"))
484                    }
485                }
486            }
487            Stmt::Item(_) => {
488                // Items in function body not supported
489                Err(TranspileError::Unsupported("Item in function body".into()))
490            }
491            Stmt::Macro(_) => Err(TranspileError::Unsupported("Macro in function body".into())),
492        }
493    }
494
495    /// Transpile an expression.
496    fn transpile_expr(&self, expr: &Expr) -> Result<String> {
497        match expr {
498            Expr::Lit(lit) => self.transpile_lit(lit),
499            Expr::Path(path) => self.transpile_path(path),
500            Expr::Binary(bin) => self.transpile_binary(bin),
501            Expr::Unary(unary) => self.transpile_unary(unary),
502            Expr::Paren(paren) => self.transpile_paren(paren),
503            Expr::Index(index) => self.transpile_index(index),
504            Expr::Call(call) => self.transpile_call(call),
505            Expr::MethodCall(method) => self.transpile_method_call(method),
506            Expr::If(if_expr) => self.transpile_if(if_expr),
507            Expr::Assign(assign) => self.transpile_assign(assign),
508            Expr::Cast(cast) => self.transpile_cast(cast),
509            Expr::Match(match_expr) => self.transpile_match(match_expr),
510            Expr::Block(block) => {
511                // For block expressions, we just return the last expression
512                if let Some(Stmt::Expr(expr, None)) = block.block.stmts.last() {
513                    self.transpile_expr(expr)
514                } else {
515                    Err(TranspileError::Unsupported(
516                        "Complex block expression".into(),
517                    ))
518                }
519            }
520            Expr::Field(field) => {
521                // Struct field access: obj.field -> obj.field or obj->field for pointers
522                let base = self.transpile_expr(&field.base)?;
523                let member = match &field.member {
524                    syn::Member::Named(ident) => ident.to_string(),
525                    syn::Member::Unnamed(idx) => idx.index.to_string(),
526                };
527
528                // Check if base is a pointer variable - use -> instead of .
529                let accessor = if self.pointer_vars.contains(&base) {
530                    "->"
531                } else {
532                    "."
533                };
534                Ok(format!("{base}{accessor}{member}"))
535            }
536            Expr::Return(ret) => self.transpile_return(ret),
537            Expr::ForLoop(for_loop) => self.transpile_for_loop(for_loop),
538            Expr::While(while_loop) => self.transpile_while_loop(while_loop),
539            Expr::Loop(loop_expr) => self.transpile_infinite_loop(loop_expr),
540            Expr::Break(break_expr) => self.transpile_break(break_expr),
541            Expr::Continue(cont_expr) => self.transpile_continue(cont_expr),
542            Expr::Struct(struct_expr) => self.transpile_struct_literal(struct_expr),
543            Expr::Reference(ref_expr) => self.transpile_reference(ref_expr),
544            Expr::Let(let_expr) => self.transpile_let_expr(let_expr),
545            Expr::Tuple(tuple) => {
546                // Transpile tuple as a comma-separated list (for multi-value returns, etc.)
547                let elements: Vec<String> = tuple
548                    .elems
549                    .iter()
550                    .map(|e| self.transpile_expr(e))
551                    .collect::<Result<_>>()?;
552                Ok(format!("({})", elements.join(", ")))
553            }
554            _ => Err(TranspileError::Unsupported(format!(
555                "Expression type: {}",
556                expr.to_token_stream()
557            ))),
558        }
559    }
560
561    /// Transpile a literal.
562    fn transpile_lit(&self, lit: &ExprLit) -> Result<String> {
563        match &lit.lit {
564            Lit::Float(f) => {
565                let s = f.to_string();
566                // Ensure f32 literals have 'f' suffix in CUDA
567                if s.ends_with("f32") || !s.contains('.') {
568                    let num = s.trim_end_matches("f32").trim_end_matches("f64");
569                    Ok(format!("{num}f"))
570                } else if s.ends_with("f64") {
571                    Ok(s.trim_end_matches("f64").to_string())
572                } else {
573                    // Plain float literal - add 'f' suffix for float
574                    Ok(format!("{s}f"))
575                }
576            }
577            Lit::Int(i) => Ok(i.to_string()),
578            Lit::Bool(b) => Ok(if b.value { "1" } else { "0" }.to_string()),
579            _ => Err(TranspileError::Unsupported(format!(
580                "Literal type: {}",
581                lit.to_token_stream()
582            ))),
583        }
584    }
585
586    /// Transpile a path (variable reference).
587    fn transpile_path(&self, path: &ExprPath) -> Result<String> {
588        let segments: Vec<_> = path
589            .path
590            .segments
591            .iter()
592            .map(|s| s.ident.to_string())
593            .collect();
594
595        if segments.len() == 1 {
596            Ok(segments[0].clone())
597        } else {
598            Ok(segments.join("::"))
599        }
600    }
601
602    /// Transpile a binary expression.
603    fn transpile_binary(&self, bin: &ExprBinary) -> Result<String> {
604        let left = self.transpile_expr(&bin.left)?;
605        let right = self.transpile_expr(&bin.right)?;
606
607        let op = match bin.op {
608            BinOp::Add(_) => "+",
609            BinOp::Sub(_) => "-",
610            BinOp::Mul(_) => "*",
611            BinOp::Div(_) => "/",
612            BinOp::Rem(_) => "%",
613            BinOp::And(_) => "&&",
614            BinOp::Or(_) => "||",
615            BinOp::BitXor(_) => "^",
616            BinOp::BitAnd(_) => "&",
617            BinOp::BitOr(_) => "|",
618            BinOp::Shl(_) => "<<",
619            BinOp::Shr(_) => ">>",
620            BinOp::Eq(_) => "==",
621            BinOp::Lt(_) => "<",
622            BinOp::Le(_) => "<=",
623            BinOp::Ne(_) => "!=",
624            BinOp::Ge(_) => ">=",
625            BinOp::Gt(_) => ">",
626            BinOp::AddAssign(_) => "+=",
627            BinOp::SubAssign(_) => "-=",
628            BinOp::MulAssign(_) => "*=",
629            BinOp::DivAssign(_) => "/=",
630            BinOp::RemAssign(_) => "%=",
631            BinOp::BitXorAssign(_) => "^=",
632            BinOp::BitAndAssign(_) => "&=",
633            BinOp::BitOrAssign(_) => "|=",
634            BinOp::ShlAssign(_) => "<<=",
635            BinOp::ShrAssign(_) => ">>=",
636            _ => {
637                return Err(TranspileError::Unsupported(format!(
638                    "Binary operator: {}",
639                    bin.to_token_stream()
640                )))
641            }
642        };
643
644        Ok(format!("{left} {op} {right}"))
645    }
646
647    /// Transpile a unary expression.
648    fn transpile_unary(&self, unary: &ExprUnary) -> Result<String> {
649        let expr = self.transpile_expr(&unary.expr)?;
650
651        let op = match unary.op {
652            UnOp::Neg(_) => "-",
653            UnOp::Not(_) => "!",
654            UnOp::Deref(_) => "*",
655            _ => {
656                return Err(TranspileError::Unsupported(format!(
657                    "Unary operator: {}",
658                    unary.to_token_stream()
659                )))
660            }
661        };
662
663        Ok(format!("{op}({expr})"))
664    }
665
666    /// Transpile a parenthesized expression.
667    fn transpile_paren(&self, paren: &ExprParen) -> Result<String> {
668        let inner = self.transpile_expr(&paren.expr)?;
669        Ok(format!("({inner})"))
670    }
671
672    /// Transpile an index expression.
673    fn transpile_index(&self, index: &ExprIndex) -> Result<String> {
674        let base = self.transpile_expr(&index.expr)?;
675        let idx = self.transpile_expr(&index.index)?;
676        Ok(format!("{base}[{idx}]"))
677    }
678
679    /// Transpile a function call.
680    fn transpile_call(&self, call: &ExprCall) -> Result<String> {
681        let func = self.transpile_expr(&call.func)?;
682
683        // Check for intrinsics
684        if let Some(intrinsic) = self.intrinsics.lookup(&func) {
685            let cuda_name = intrinsic.to_cuda_string();
686
687            // Check if this is a "value" intrinsic (like threadIdx.x) vs a "function" intrinsic
688            // Value intrinsics don't have parentheses in CUDA
689            let is_value_intrinsic = cuda_name.contains("Idx.")
690                || cuda_name.contains("Dim.")
691                || cuda_name.starts_with("threadIdx")
692                || cuda_name.starts_with("blockIdx")
693                || cuda_name.starts_with("blockDim")
694                || cuda_name.starts_with("gridDim");
695
696            if is_value_intrinsic && call.args.is_empty() {
697                // Value intrinsics: threadIdx.x, blockIdx.y, etc. - no parens
698                return Ok(cuda_name.to_string());
699            }
700
701            if call.args.is_empty() && cuda_name.ends_with("()") {
702                // Zero-arg function intrinsics: __syncthreads()
703                return Ok(cuda_name.to_string());
704            }
705
706            let args: Vec<String> = call
707                .args
708                .iter()
709                .map(|a| self.transpile_expr(a))
710                .collect::<Result<_>>()?;
711
712            return Ok(format!(
713                "{}({})",
714                cuda_name.trim_end_matches("()"),
715                args.join(", ")
716            ));
717        }
718
719        // Regular function call
720        let args: Vec<String> = call
721            .args
722            .iter()
723            .map(|a| self.transpile_expr(a))
724            .collect::<Result<_>>()?;
725
726        Ok(format!("{}({})", func, args.join(", ")))
727    }
728
729    /// Transpile a method call, handling stencil intrinsics.
730    fn transpile_method_call(&self, method: &ExprMethodCall) -> Result<String> {
731        let receiver = self.transpile_expr(&method.receiver)?;
732        let method_name = method.method.to_string();
733
734        // Check if this is a SharedTile/SharedArray method call
735        if let Some(result) =
736            self.try_transpile_shared_method_call(&receiver, &method_name, &method.args)
737        {
738            return result;
739        }
740
741        // Check if this is a RingContext method call (in ring kernel mode)
742        if self.ring_kernel_mode && self.context_vars.contains(&receiver) {
743            return self.transpile_context_method(&method_name, &method.args);
744        }
745
746        // Check if this is a GridPos method call
747        if self.grid_pos_vars.contains(&receiver) {
748            return self.transpile_stencil_intrinsic(&method_name, &method.args);
749        }
750
751        // Check for ring kernel intrinsics (like is_active(), should_terminate())
752        if self.ring_kernel_mode {
753            if let Some(intrinsic) = RingKernelIntrinsic::from_name(&method_name) {
754                let args: Vec<String> = method
755                    .args
756                    .iter()
757                    .map(|a| self.transpile_expr(a).unwrap_or_default())
758                    .collect();
759                return Ok(intrinsic.to_cuda(&args));
760            }
761        }
762
763        // Check for f32/f64 math methods
764        if let Some(intrinsic) = self.intrinsics.lookup(&method_name) {
765            let cuda_name = intrinsic.to_cuda_string();
766            let args: Vec<String> = std::iter::once(receiver)
767                .chain(
768                    method
769                        .args
770                        .iter()
771                        .map(|a| self.transpile_expr(a).unwrap_or_default()),
772                )
773                .collect();
774
775            return Ok(format!("{}({})", cuda_name, args.join(", ")));
776        }
777
778        // Regular method call (treat as field access + call for C structs)
779        let args: Vec<String> = method
780            .args
781            .iter()
782            .map(|a| self.transpile_expr(a))
783            .collect::<Result<_>>()?;
784
785        Ok(format!("{}.{}({})", receiver, method_name, args.join(", ")))
786    }
787
788    /// Transpile a RingContext method call to CUDA intrinsics.
789    fn transpile_context_method(
790        &self,
791        method: &str,
792        args: &syn::punctuated::Punctuated<Expr, syn::token::Comma>,
793    ) -> Result<String> {
794        let ctx_method = ContextMethod::from_name(method).ok_or_else(|| {
795            TranspileError::Unsupported(format!("Unknown context method: {}", method))
796        })?;
797
798        let cuda_args: Vec<String> = args
799            .iter()
800            .map(|a| self.transpile_expr(a).unwrap_or_default())
801            .collect();
802
803        Ok(ctx_method.to_cuda(&cuda_args))
804    }
805
806    /// Transpile a stencil intrinsic method call.
807    fn transpile_stencil_intrinsic(
808        &self,
809        method: &str,
810        args: &syn::punctuated::Punctuated<Expr, syn::token::Comma>,
811    ) -> Result<String> {
812        let config = self.config.as_ref().ok_or_else(|| {
813            TranspileError::Unsupported("Stencil intrinsic without config".into())
814        })?;
815
816        let buffer_width = config.buffer_width().to_string();
817        let buffer_slice = format!("{}", config.buffer_width() * config.buffer_height());
818        let is_3d = config.grid == crate::stencil::Grid::Grid3D;
819
820        let intrinsic = StencilIntrinsic::from_method_name(method).ok_or_else(|| {
821            TranspileError::Unsupported(format!("Unknown stencil intrinsic: {method}"))
822        })?;
823
824        // Check if 3D intrinsic used in non-3D kernel
825        if intrinsic.is_3d_only() && !is_3d {
826            return Err(TranspileError::Unsupported(format!(
827                "3D stencil intrinsic '{}' requires Grid3D configuration",
828                method
829            )));
830        }
831
832        match intrinsic {
833            StencilIntrinsic::Index => {
834                // pos.idx() -> idx
835                Ok("idx".to_string())
836            }
837            StencilIntrinsic::North
838            | StencilIntrinsic::South
839            | StencilIntrinsic::East
840            | StencilIntrinsic::West => {
841                // pos.north(buf) -> buf[idx - buffer_width]
842                if args.is_empty() {
843                    return Err(TranspileError::Unsupported(
844                        "Stencil accessor requires buffer argument".into(),
845                    ));
846                }
847                let buffer = self.transpile_expr(&args[0])?;
848                if is_3d {
849                    Ok(intrinsic.to_cuda_index_3d(&buffer, &buffer_width, &buffer_slice, "idx"))
850                } else {
851                    Ok(intrinsic.to_cuda_index_2d(&buffer, &buffer_width, "idx"))
852                }
853            }
854            StencilIntrinsic::Up | StencilIntrinsic::Down => {
855                // 3D intrinsics: pos.up(buf) -> buf[idx - buffer_slice]
856                if args.is_empty() {
857                    return Err(TranspileError::Unsupported(
858                        "3D stencil accessor requires buffer argument".into(),
859                    ));
860                }
861                let buffer = self.transpile_expr(&args[0])?;
862                Ok(intrinsic.to_cuda_index_3d(&buffer, &buffer_width, &buffer_slice, "idx"))
863            }
864            StencilIntrinsic::At => {
865                // 2D: pos.at(buf, dx, dy) -> buf[idx + dy * buffer_width + dx]
866                // 3D: pos.at(buf, dx, dy, dz) -> buf[idx + dz * buffer_slice + dy * buffer_width + dx]
867                if is_3d {
868                    if args.len() < 4 {
869                        return Err(TranspileError::Unsupported(
870                            "at() in 3D requires buffer, dx, dy, dz arguments".into(),
871                        ));
872                    }
873                    let buffer = self.transpile_expr(&args[0])?;
874                    let dx = self.transpile_expr(&args[1])?;
875                    let dy = self.transpile_expr(&args[2])?;
876                    let dz = self.transpile_expr(&args[3])?;
877                    Ok(format!(
878                        "{buffer}[idx + ({dz}) * {buffer_slice} + ({dy}) * {buffer_width} + ({dx})]"
879                    ))
880                } else {
881                    if args.len() < 3 {
882                        return Err(TranspileError::Unsupported(
883                            "at() requires buffer, dx, dy arguments".into(),
884                        ));
885                    }
886                    let buffer = self.transpile_expr(&args[0])?;
887                    let dx = self.transpile_expr(&args[1])?;
888                    let dy = self.transpile_expr(&args[2])?;
889                    Ok(format!("{buffer}[idx + ({dy}) * {buffer_width} + ({dx})]"))
890                }
891            }
892        }
893    }
894
895    /// Transpile an if expression.
896    fn transpile_if(&self, if_expr: &ExprIf) -> Result<String> {
897        let cond = self.transpile_expr(&if_expr.cond)?;
898
899        // Check if the body contains only a return statement (early return pattern)
900        if let Some(Stmt::Expr(Expr::Return(ret), _)) = if_expr.then_branch.stmts.first() {
901            if if_expr.then_branch.stmts.len() == 1 && if_expr.else_branch.is_none() {
902                // Simple early return: if (cond) return;
903                if ret.expr.is_none() {
904                    return Ok(format!("if ({cond}) return"));
905                }
906                let ret_val = self.transpile_expr(ret.expr.as_ref().unwrap())?;
907                return Ok(format!("if ({cond}) return {ret_val}"));
908            }
909        }
910
911        // For now, only handle if-else as ternary when it's an expression
912        if let Some((_, else_branch)) = &if_expr.else_branch {
913            // If both branches are simple expressions, use ternary
914            if let (Some(Stmt::Expr(then_expr, None)), Expr::Block(else_block)) =
915                (if_expr.then_branch.stmts.last(), else_branch.as_ref())
916            {
917                if let Some(Stmt::Expr(else_expr, None)) = else_block.block.stmts.last() {
918                    let then_str = self.transpile_expr(then_expr)?;
919                    let else_str = self.transpile_expr(else_expr)?;
920                    return Ok(format!("({cond}) ? ({then_str}) : ({else_str})"));
921                }
922            }
923
924            // Otherwise, generate if statement
925            if let Expr::If(else_if) = else_branch.as_ref() {
926                // else if chain
927                let then_body = self.transpile_if_body(&if_expr.then_branch)?;
928                let else_part = self.transpile_if(else_if)?;
929                return Ok(format!("if ({cond}) {{{then_body}}} else {else_part}"));
930            } else if let Expr::Block(else_block) = else_branch.as_ref() {
931                // else block
932                let then_body = self.transpile_if_body(&if_expr.then_branch)?;
933                let else_body = self.transpile_if_body(&else_block.block)?;
934                return Ok(format!("if ({cond}) {{{then_body}}} else {{{else_body}}}"));
935            }
936        }
937
938        // If without else
939        let then_body = self.transpile_if_body(&if_expr.then_branch)?;
940        Ok(format!("if ({cond}) {{{then_body}}}"))
941    }
942
943    /// Transpile the body of an if branch.
944    fn transpile_if_body(&self, block: &syn::Block) -> Result<String> {
945        let mut body = String::new();
946        for stmt in &block.stmts {
947            match stmt {
948                Stmt::Expr(expr, Some(_)) => {
949                    let expr_str = self.transpile_expr(expr)?;
950                    body.push_str(&format!(" {expr_str};"));
951                }
952                Stmt::Expr(Expr::Return(ret), None) => {
953                    // Handle explicit return without semicolon
954                    if let Some(ret_expr) = &ret.expr {
955                        let expr_str = self.transpile_expr(ret_expr)?;
956                        body.push_str(&format!(" return {expr_str};"));
957                    } else {
958                        body.push_str(" return;");
959                    }
960                }
961                Stmt::Expr(expr, None) => {
962                    let expr_str = self.transpile_expr(expr)?;
963                    body.push_str(&format!(" return {expr_str};"));
964                }
965                _ => {}
966            }
967        }
968        Ok(body)
969    }
970
971    /// Transpile an assignment expression.
972    fn transpile_assign(&self, assign: &ExprAssign) -> Result<String> {
973        let left = self.transpile_expr(&assign.left)?;
974        let right = self.transpile_expr(&assign.right)?;
975        Ok(format!("{left} = {right}"))
976    }
977
978    /// Transpile a cast expression.
979    fn transpile_cast(&self, cast: &ExprCast) -> Result<String> {
980        let expr = self.transpile_expr(&cast.expr)?;
981        let cuda_type = self.type_mapper.map_type(&cast.ty)?;
982        Ok(format!("({})({})", cuda_type.to_cuda_string(), expr))
983    }
984
985    /// Transpile a return expression.
986    fn transpile_return(&self, ret: &ExprReturn) -> Result<String> {
987        if let Some(expr) = &ret.expr {
988            let expr_str = self.transpile_expr(expr)?;
989            Ok(format!("return {expr_str}"))
990        } else {
991            Ok("return".to_string())
992        }
993    }
994
995    /// Transpile a struct literal expression.
996    ///
997    /// Converts Rust struct literals to C-style compound literals:
998    /// `Point { x: 1.0, y: 2.0 }` -> `(Point){ .x = 1.0f, .y = 2.0f }`
999    fn transpile_struct_literal(&self, struct_expr: &ExprStruct) -> Result<String> {
1000        // Get the struct type name
1001        let type_name = struct_expr
1002            .path
1003            .segments
1004            .iter()
1005            .map(|s| s.ident.to_string())
1006            .collect::<Vec<_>>()
1007            .join("::");
1008
1009        // Transpile each field
1010        let mut fields = Vec::new();
1011        for field in &struct_expr.fields {
1012            let field_name = match &field.member {
1013                syn::Member::Named(ident) => ident.to_string(),
1014                syn::Member::Unnamed(idx) => idx.index.to_string(),
1015            };
1016            let value = self.transpile_expr(&field.expr)?;
1017            fields.push(format!(".{} = {}", field_name, value));
1018        }
1019
1020        // Check for struct update syntax (not supported in C)
1021        if struct_expr.rest.is_some() {
1022            return Err(TranspileError::Unsupported(
1023                "Struct update syntax (..base) is not supported in CUDA".into(),
1024            ));
1025        }
1026
1027        // Generate C compound literal: (TypeName){ .field1 = val1, .field2 = val2 }
1028        Ok(format!("({}){{ {} }}", type_name, fields.join(", ")))
1029    }
1030
1031    /// Transpile a reference expression.
1032    ///
1033    /// In CUDA C, we typically need pointers. This handles:
1034    /// - `&arr[idx]` -> `&arr[idx]` (pointer to element)
1035    /// - `&mut arr[idx]` -> `&arr[idx]` (same in C)
1036    /// - `&variable` -> `&variable`
1037    fn transpile_reference(&self, ref_expr: &ExprReference) -> Result<String> {
1038        let inner = self.transpile_expr(&ref_expr.expr)?;
1039
1040        // In CUDA C, taking a reference is the same as taking address
1041        // For array indexing like &arr[idx], we produce &arr[idx]
1042        // This creates a pointer to that element
1043        Ok(format!("&{inner}"))
1044    }
1045
1046    /// Transpile a let expression (used in if-let patterns).
1047    ///
1048    /// Note: Full pattern matching is not supported in CUDA, but we can
1049    /// handle simple `if let Some(x) = expr` patterns by transpiling
1050    /// to a simple conditional check.
1051    fn transpile_let_expr(&self, let_expr: &ExprLet) -> Result<String> {
1052        // For now, we treat let expressions as unsupported since
1053        // CUDA doesn't have Option types. But we can add special cases.
1054        let _ = let_expr; // Silence unused warning
1055        Err(TranspileError::Unsupported(
1056            "let expressions (if-let patterns) are not directly supported in CUDA. \
1057             Use explicit comparisons instead."
1058                .into(),
1059        ))
1060    }
1061
1062    // === Loop Transpilation ===
1063
1064    /// Transpile a for loop to CUDA.
1065    ///
1066    /// Handles `for i in start..end` and `for i in start..=end` patterns.
1067    ///
1068    /// # Example
1069    ///
1070    /// ```ignore
1071    /// // Rust
1072    /// for i in 0..n {
1073    ///     data[i] = 0.0;
1074    /// }
1075    ///
1076    /// // CUDA
1077    /// for (int i = 0; i < n; i++) {
1078    ///     data[i] = 0.0f;
1079    /// }
1080    /// ```
1081    fn transpile_for_loop(&self, for_loop: &ExprForLoop) -> Result<String> {
1082        // Check if loops are allowed
1083        if !self.validation_mode.allows_loops() {
1084            return Err(TranspileError::Unsupported(
1085                "Loops are not allowed in stencil kernels".into(),
1086            ));
1087        }
1088
1089        // Extract loop variable name
1090        let var_name = extract_loop_var(&for_loop.pat)
1091            .ok_or_else(|| TranspileError::Unsupported("Complex pattern in for loop".into()))?;
1092
1093        // The iterator expression should be a range
1094        let header = match for_loop.expr.as_ref() {
1095            Expr::Range(range) => {
1096                let range_info = RangeInfo::from_range(range, |e| self.transpile_expr(e));
1097                range_info.to_cuda_for_header(&var_name, "int")
1098            }
1099            _ => {
1100                // For non-range iterators, we can't directly transpile
1101                return Err(TranspileError::Unsupported(
1102                    "Only range expressions (start..end) are supported in for loops".into(),
1103                ));
1104            }
1105        };
1106
1107        // Transpile the loop body
1108        let body = self.transpile_loop_body(&for_loop.body)?;
1109
1110        Ok(format!("{header} {{\n{body}}}"))
1111    }
1112
1113    /// Transpile a while loop to CUDA.
1114    ///
1115    /// # Example
1116    ///
1117    /// ```ignore
1118    /// // Rust
1119    /// while !done {
1120    ///     process();
1121    /// }
1122    ///
1123    /// // CUDA
1124    /// while (!done) {
1125    ///     process();
1126    /// }
1127    /// ```
1128    fn transpile_while_loop(&self, while_loop: &ExprWhile) -> Result<String> {
1129        // Check if loops are allowed
1130        if !self.validation_mode.allows_loops() {
1131            return Err(TranspileError::Unsupported(
1132                "Loops are not allowed in stencil kernels".into(),
1133            ));
1134        }
1135
1136        // Transpile the condition
1137        let condition = self.transpile_expr(&while_loop.cond)?;
1138
1139        // Transpile the loop body
1140        let body = self.transpile_loop_body(&while_loop.body)?;
1141
1142        Ok(format!("while ({condition}) {{\n{body}}}"))
1143    }
1144
1145    /// Transpile an infinite loop to CUDA.
1146    ///
1147    /// # Example
1148    ///
1149    /// ```ignore
1150    /// // Rust
1151    /// loop {
1152    ///     if should_exit { break; }
1153    /// }
1154    ///
1155    /// // CUDA
1156    /// while (true) {
1157    ///     if (should_exit) { break; }
1158    /// }
1159    /// ```
1160    fn transpile_infinite_loop(&self, loop_expr: &ExprLoop) -> Result<String> {
1161        // Check if loops are allowed
1162        if !self.validation_mode.allows_loops() {
1163            return Err(TranspileError::Unsupported(
1164                "Loops are not allowed in stencil kernels".into(),
1165            ));
1166        }
1167
1168        // Transpile the loop body
1169        let body = self.transpile_loop_body(&loop_expr.body)?;
1170
1171        // Use while(true) for infinite loops
1172        Ok(format!("while (true) {{\n{body}}}"))
1173    }
1174
1175    /// Transpile a break expression.
1176    fn transpile_break(&self, break_expr: &ExprBreak) -> Result<String> {
1177        // Check for labeled break (not supported)
1178        if break_expr.label.is_some() {
1179            return Err(TranspileError::Unsupported(
1180                "Labeled break is not supported in CUDA".into(),
1181            ));
1182        }
1183
1184        // Check for break with value (not supported in CUDA)
1185        if break_expr.expr.is_some() {
1186            return Err(TranspileError::Unsupported(
1187                "Break with value is not supported in CUDA".into(),
1188            ));
1189        }
1190
1191        Ok("break".to_string())
1192    }
1193
1194    /// Transpile a continue expression.
1195    fn transpile_continue(&self, cont_expr: &ExprContinue) -> Result<String> {
1196        // Check for labeled continue (not supported)
1197        if cont_expr.label.is_some() {
1198            return Err(TranspileError::Unsupported(
1199                "Labeled continue is not supported in CUDA".into(),
1200            ));
1201        }
1202
1203        Ok("continue".to_string())
1204    }
1205
1206    /// Transpile a loop body (block of statements).
1207    fn transpile_loop_body(&self, block: &syn::Block) -> Result<String> {
1208        let mut output = String::new();
1209        let inner_indent = "    ".repeat(self.indent + 1);
1210
1211        for stmt in &block.stmts {
1212            match stmt {
1213                Stmt::Local(local) => {
1214                    // Variable declaration
1215                    let var_name = match &local.pat {
1216                        Pat::Ident(ident) => ident.ident.to_string(),
1217                        Pat::Type(pat_type) => {
1218                            if let Pat::Ident(ident) = pat_type.pat.as_ref() {
1219                                ident.ident.to_string()
1220                            } else {
1221                                return Err(TranspileError::Unsupported(
1222                                    "Complex pattern in let binding".into(),
1223                                ));
1224                            }
1225                        }
1226                        _ => {
1227                            return Err(TranspileError::Unsupported(
1228                                "Complex pattern in let binding".into(),
1229                            ))
1230                        }
1231                    };
1232
1233                    if let Some(init) = &local.init {
1234                        let expr_str = self.transpile_expr(&init.expr)?;
1235                        let type_str = self.infer_cuda_type(&init.expr);
1236                        output.push_str(&format!(
1237                            "{inner_indent}{type_str} {var_name} = {expr_str};\n"
1238                        ));
1239                    } else {
1240                        output.push_str(&format!("{inner_indent}float {var_name};\n"));
1241                    }
1242                }
1243                Stmt::Expr(expr, semi) => {
1244                    let expr_str = self.transpile_expr(expr)?;
1245                    if semi.is_some() {
1246                        output.push_str(&format!("{inner_indent}{expr_str};\n"));
1247                    } else {
1248                        // Expression without semicolon at end of block
1249                        output.push_str(&format!("{inner_indent}{expr_str};\n"));
1250                    }
1251                }
1252                _ => {
1253                    return Err(TranspileError::Unsupported(
1254                        "Unsupported statement in loop body".into(),
1255                    ));
1256                }
1257            }
1258        }
1259
1260        // Add closing indentation
1261        let closing_indent = "    ".repeat(self.indent);
1262        output.push_str(&closing_indent);
1263
1264        Ok(output)
1265    }
1266
1267    // === Shared Memory Support ===
1268
1269    /// Try to parse a local variable declaration as a shared memory declaration.
1270    ///
1271    /// Recognizes patterns like:
1272    /// - `let tile = SharedTile::<f32, 16, 16>::new();`
1273    /// - `let buffer = SharedArray::<f32, 256>::new();`
1274    /// - `let tile: SharedTile<f32, 16, 16> = SharedTile::new();`
1275    fn try_parse_shared_declaration(
1276        &self,
1277        local: &syn::Local,
1278        var_name: &str,
1279    ) -> Result<Option<SharedMemoryDecl>> {
1280        // Check if there's a type annotation
1281        if let Pat::Type(pat_type) = &local.pat {
1282            let type_str = pat_type.ty.to_token_stream().to_string();
1283            return self.parse_shared_type(&type_str, var_name);
1284        }
1285
1286        // Check the initializer expression for SharedTile::new() or SharedArray::new()
1287        if let Some(init) = &local.init {
1288            if let Expr::Call(call) = init.expr.as_ref() {
1289                if let Expr::Path(path) = call.func.as_ref() {
1290                    let path_str = path.to_token_stream().to_string();
1291                    return self.parse_shared_type(&path_str, var_name);
1292                }
1293            }
1294        }
1295
1296        Ok(None)
1297    }
1298
1299    /// Parse a type string to extract shared memory info.
1300    fn parse_shared_type(
1301        &self,
1302        type_str: &str,
1303        var_name: &str,
1304    ) -> Result<Option<SharedMemoryDecl>> {
1305        // Clean up the type string (remove spaces around ::)
1306        let type_str = type_str
1307            .replace(" :: ", "::")
1308            .replace(" ::", "::")
1309            .replace(":: ", "::");
1310
1311        // Check for SharedTile<T, W, H> or SharedTile::<T, W, H>::new
1312        if type_str.contains("SharedTile") {
1313            // Extract the generic parameters
1314            if let Some(start) = type_str.find('<') {
1315                if let Some(end) = type_str.rfind('>') {
1316                    let params = &type_str[start + 1..end];
1317                    let parts: Vec<&str> = params.split(',').map(|s| s.trim()).collect();
1318
1319                    if parts.len() >= 3 {
1320                        let rust_type = parts[0];
1321                        let width: usize = parts[1].parse().map_err(|_| {
1322                            TranspileError::Unsupported("Invalid SharedTile width".into())
1323                        })?;
1324                        let height: usize = parts[2].parse().map_err(|_| {
1325                            TranspileError::Unsupported("Invalid SharedTile height".into())
1326                        })?;
1327
1328                        let cuda_type = rust_to_cuda_element_type(rust_type);
1329                        return Ok(Some(SharedMemoryDecl::tile(
1330                            var_name, cuda_type, width, height,
1331                        )));
1332                    }
1333                }
1334            }
1335        }
1336
1337        // Check for SharedArray<T, N> or SharedArray::<T, N>::new
1338        if type_str.contains("SharedArray") {
1339            if let Some(start) = type_str.find('<') {
1340                if let Some(end) = type_str.rfind('>') {
1341                    let params = &type_str[start + 1..end];
1342                    let parts: Vec<&str> = params.split(',').map(|s| s.trim()).collect();
1343
1344                    if parts.len() >= 2 {
1345                        let rust_type = parts[0];
1346                        let size: usize = parts[1].parse().map_err(|_| {
1347                            TranspileError::Unsupported("Invalid SharedArray size".into())
1348                        })?;
1349
1350                        let cuda_type = rust_to_cuda_element_type(rust_type);
1351                        return Ok(Some(SharedMemoryDecl::array(var_name, cuda_type, size)));
1352                    }
1353                }
1354            }
1355        }
1356
1357        Ok(None)
1358    }
1359
1360    /// Check if a variable is a shared memory variable and handle method calls.
1361    fn try_transpile_shared_method_call(
1362        &self,
1363        receiver: &str,
1364        method_name: &str,
1365        args: &syn::punctuated::Punctuated<Expr, syn::token::Comma>,
1366    ) -> Option<Result<String>> {
1367        let shared_info = self.shared_vars.get(receiver)?;
1368
1369        match method_name {
1370            "get" => {
1371                // tile.get(x, y) -> tile[y][x] (for 2D) or arr[idx] (for 1D)
1372                if shared_info.is_tile {
1373                    if args.len() >= 2 {
1374                        let x = self.transpile_expr(&args[0]).ok()?;
1375                        let y = self.transpile_expr(&args[1]).ok()?;
1376                        // CUDA uses row-major: tile[row][col] = tile[y][x]
1377                        Some(Ok(format!("{}[{}][{}]", receiver, y, x)))
1378                    } else {
1379                        Some(Err(TranspileError::Unsupported(
1380                            "SharedTile.get requires x and y arguments".into(),
1381                        )))
1382                    }
1383                } else {
1384                    // 1D array
1385                    if !args.is_empty() {
1386                        let idx = self.transpile_expr(&args[0]).ok()?;
1387                        Some(Ok(format!("{}[{}]", receiver, idx)))
1388                    } else {
1389                        Some(Err(TranspileError::Unsupported(
1390                            "SharedArray.get requires index argument".into(),
1391                        )))
1392                    }
1393                }
1394            }
1395            "set" => {
1396                // tile.set(x, y, val) -> tile[y][x] = val
1397                if shared_info.is_tile {
1398                    if args.len() >= 3 {
1399                        let x = self.transpile_expr(&args[0]).ok()?;
1400                        let y = self.transpile_expr(&args[1]).ok()?;
1401                        let val = self.transpile_expr(&args[2]).ok()?;
1402                        Some(Ok(format!("{}[{}][{}] = {}", receiver, y, x, val)))
1403                    } else {
1404                        Some(Err(TranspileError::Unsupported(
1405                            "SharedTile.set requires x, y, and value arguments".into(),
1406                        )))
1407                    }
1408                } else {
1409                    // 1D array
1410                    if args.len() >= 2 {
1411                        let idx = self.transpile_expr(&args[0]).ok()?;
1412                        let val = self.transpile_expr(&args[1]).ok()?;
1413                        Some(Ok(format!("{}[{}] = {}", receiver, idx, val)))
1414                    } else {
1415                        Some(Err(TranspileError::Unsupported(
1416                            "SharedArray.set requires index and value arguments".into(),
1417                        )))
1418                    }
1419                }
1420            }
1421            "width" | "height" | "size" => {
1422                // These are compile-time constants
1423                match method_name {
1424                    "width" if shared_info.is_tile => {
1425                        Some(Ok(shared_info.dimensions[1].to_string()))
1426                    }
1427                    "height" if shared_info.is_tile => {
1428                        Some(Ok(shared_info.dimensions[0].to_string()))
1429                    }
1430                    "size" => {
1431                        let total: usize = shared_info.dimensions.iter().product();
1432                        Some(Ok(total.to_string()))
1433                    }
1434                    _ => None,
1435                }
1436            }
1437            _ => None,
1438        }
1439    }
1440
1441    /// Transpile a match expression to switch/case.
1442    fn transpile_match(&self, match_expr: &ExprMatch) -> Result<String> {
1443        let scrutinee = self.transpile_expr(&match_expr.expr)?;
1444        let mut output = format!("switch ({scrutinee}) {{\n");
1445
1446        for arm in &match_expr.arms {
1447            // Handle the pattern
1448            let case_label = self.transpile_match_pattern(&arm.pat)?;
1449
1450            if case_label == "default" || case_label.starts_with("/*") {
1451                output.push_str("        default: {\n");
1452            } else {
1453                output.push_str(&format!("        case {case_label}: {{\n"));
1454            }
1455
1456            // Handle the arm body - check if it's a block with statements
1457            match arm.body.as_ref() {
1458                Expr::Block(block) => {
1459                    // Block expression with multiple statements
1460                    for stmt in &block.block.stmts {
1461                        let stmt_str = self.transpile_stmt_inline(stmt)?;
1462                        output.push_str(&format!("            {stmt_str}\n"));
1463                    }
1464                }
1465                _ => {
1466                    // Single expression - wrap in statement
1467                    let body = self.transpile_expr(&arm.body)?;
1468                    output.push_str(&format!("            {body};\n"));
1469                }
1470            }
1471
1472            output.push_str("            break;\n");
1473            output.push_str("        }\n");
1474        }
1475
1476        output.push_str("    }");
1477        Ok(output)
1478    }
1479
1480    /// Transpile a match pattern to a case label.
1481    fn transpile_match_pattern(&self, pat: &Pat) -> Result<String> {
1482        match pat {
1483            Pat::Lit(pat_lit) => {
1484                // Integer literal pattern - pat_lit.lit contains the literal
1485                match &pat_lit.lit {
1486                    Lit::Int(i) => Ok(i.to_string()),
1487                    Lit::Bool(b) => Ok(if b.value { "1" } else { "0" }.to_string()),
1488                    _ => Err(TranspileError::Unsupported(
1489                        "Non-integer literal in match pattern".into(),
1490                    )),
1491                }
1492            }
1493            Pat::Wild(_) => {
1494                // _ pattern becomes default
1495                Ok("default".to_string())
1496            }
1497            Pat::Ident(ident) => {
1498                // Named pattern - treat as default case for now
1499                // This handles things like `x => ...` which bind a value
1500                Ok(format!("/* {} */ default", ident.ident))
1501            }
1502            Pat::Or(pat_or) => {
1503                // Multiple patterns: 0 | 1 | 2 => ...
1504                // CUDA switch doesn't support this directly, we need multiple case labels
1505                // For now, just use the first pattern and note the limitation
1506                if let Some(first) = pat_or.cases.first() {
1507                    self.transpile_match_pattern(first)
1508                } else {
1509                    Err(TranspileError::Unsupported("Empty or pattern".into()))
1510                }
1511            }
1512            _ => Err(TranspileError::Unsupported(format!(
1513                "Match pattern: {}",
1514                pat.to_token_stream()
1515            ))),
1516        }
1517    }
1518
1519    /// Transpile a statement without indentation (for inline use in switch).
1520    fn transpile_stmt_inline(&self, stmt: &Stmt) -> Result<String> {
1521        match stmt {
1522            Stmt::Local(local) => {
1523                let var_name = match &local.pat {
1524                    Pat::Ident(ident) => ident.ident.to_string(),
1525                    Pat::Type(pat_type) => {
1526                        if let Pat::Ident(ident) = pat_type.pat.as_ref() {
1527                            ident.ident.to_string()
1528                        } else {
1529                            return Err(TranspileError::Unsupported(
1530                                "Complex pattern in let binding".into(),
1531                            ));
1532                        }
1533                    }
1534                    _ => {
1535                        return Err(TranspileError::Unsupported(
1536                            "Complex pattern in let binding".into(),
1537                        ))
1538                    }
1539                };
1540
1541                if let Some(init) = &local.init {
1542                    let expr_str = self.transpile_expr(&init.expr)?;
1543                    let type_str = self.infer_cuda_type(&init.expr);
1544                    Ok(format!("{type_str} {var_name} = {expr_str};"))
1545                } else {
1546                    Ok(format!("float {var_name};"))
1547                }
1548            }
1549            Stmt::Expr(expr, semi) => {
1550                let expr_str = self.transpile_expr(expr)?;
1551                if semi.is_some() {
1552                    Ok(format!("{expr_str};"))
1553                } else {
1554                    Ok(format!("return {expr_str};"))
1555                }
1556            }
1557            _ => Err(TranspileError::Unsupported(
1558                "Unsupported statement in match arm".into(),
1559            )),
1560        }
1561    }
1562
1563    /// Infer CUDA type from expression (simple heuristic).
1564    fn infer_cuda_type(&self, expr: &Expr) -> &'static str {
1565        match expr {
1566            Expr::Lit(lit) => match &lit.lit {
1567                Lit::Float(_) => "float",
1568                Lit::Int(_) => "int",
1569                Lit::Bool(_) => "int",
1570                _ => "float",
1571            },
1572            Expr::Binary(bin) => {
1573                // Check if this is an integer operation
1574                let left_type = self.infer_cuda_type(&bin.left);
1575                let right_type = self.infer_cuda_type(&bin.right);
1576                // If both sides are int, result is int
1577                if left_type == "int" && right_type == "int" {
1578                    "int"
1579                } else {
1580                    "float"
1581                }
1582            }
1583            Expr::Call(call) => {
1584                // Check for intrinsics that return int
1585                if let Ok(func) = self.transpile_expr(&call.func) {
1586                    if let Some(intrinsic) = self.intrinsics.lookup(&func) {
1587                        let cuda_name = intrinsic.to_cuda_string();
1588                        // Thread/block indices return int
1589                        if cuda_name.contains("Idx") || cuda_name.contains("Dim") {
1590                            return "int";
1591                        }
1592                    }
1593                }
1594                "float"
1595            }
1596            Expr::Index(_) => "float", // Array access - could be any type, default to float
1597            Expr::Cast(cast) => {
1598                // Use the target type of the cast
1599                if let Ok(cuda_type) = self.type_mapper.map_type(&cast.ty) {
1600                    let s = cuda_type.to_cuda_string();
1601                    if s.contains("int") || s.contains("size_t") || s == "unsigned long long" {
1602                        return "int";
1603                    }
1604                }
1605                "float"
1606            }
1607            Expr::Reference(ref_expr) => {
1608                // Reference expression - try to determine pointer type from inner
1609                // For &arr[idx], we get a pointer to the element type
1610                match ref_expr.expr.as_ref() {
1611                    Expr::Index(idx_expr) => {
1612                        // &arr[idx] - get the base array name and look it up
1613                        if let Expr::Path(path) = &*idx_expr.expr {
1614                            let name = path
1615                                .path
1616                                .segments
1617                                .iter()
1618                                .map(|s| s.ident.to_string())
1619                                .collect::<Vec<_>>()
1620                                .join("::");
1621                            // Common GPU struct names -> pointer to that struct
1622                            if name.contains("transaction") || name.contains("Transaction") {
1623                                return "GpuTransaction*";
1624                            }
1625                            if name.contains("profile") || name.contains("Profile") {
1626                                return "GpuCustomerProfile*";
1627                            }
1628                            if name.contains("alert") || name.contains("Alert") {
1629                                return "GpuAlert*";
1630                            }
1631                        }
1632                        "float*" // Default element pointer
1633                    }
1634                    _ => "void*",
1635                }
1636            }
1637            Expr::MethodCall(_) => "float",
1638            Expr::Field(field) => {
1639                // Field access - try to infer type from field name
1640                let member_name = match &field.member {
1641                    syn::Member::Named(ident) => ident.to_string(),
1642                    syn::Member::Unnamed(idx) => idx.index.to_string(),
1643                };
1644                // Common field name patterns
1645                if member_name.contains("count") || member_name.contains("_count") {
1646                    return "unsigned int";
1647                }
1648                if member_name.contains("threshold") || member_name.ends_with("_id") {
1649                    return "unsigned long long";
1650                }
1651                if member_name.ends_with("_pct") {
1652                    return "unsigned char";
1653                }
1654                "float"
1655            }
1656            Expr::Path(path) => {
1657                // Variable access - check if it's a known variable
1658                let name = path
1659                    .path
1660                    .segments
1661                    .iter()
1662                    .map(|s| s.ident.to_string())
1663                    .collect::<Vec<_>>()
1664                    .join("::");
1665                if name.contains("threshold")
1666                    || name.contains("count")
1667                    || name == "idx"
1668                    || name == "n"
1669                {
1670                    return "int";
1671                }
1672                "float"
1673            }
1674            Expr::If(if_expr) => {
1675                // For ternary (if-else), infer type from branches
1676                if let Some((_, else_branch)) = &if_expr.else_branch {
1677                    if let Expr::Block(block) = else_branch.as_ref() {
1678                        if let Some(Stmt::Expr(expr, None)) = block.block.stmts.last() {
1679                            return self.infer_cuda_type(expr);
1680                        }
1681                    }
1682                }
1683                // Try from then branch
1684                if let Some(Stmt::Expr(expr, None)) = if_expr.then_branch.stmts.last() {
1685                    return self.infer_cuda_type(expr);
1686                }
1687                "float"
1688            }
1689            _ => "float",
1690        }
1691    }
1692}
1693
1694/// Transpile a function to CUDA without stencil configuration.
1695pub fn transpile_function(func: &ItemFn) -> Result<String> {
1696    let mut transpiler = CudaTranspiler::new_generic();
1697
1698    // Generate function signature
1699    let name = func.sig.ident.to_string();
1700
1701    let mut params = Vec::new();
1702    for param in &func.sig.inputs {
1703        if let FnArg::Typed(pat_type) = param {
1704            let param_name = match pat_type.pat.as_ref() {
1705                Pat::Ident(ident) => ident.ident.to_string(),
1706                _ => continue,
1707            };
1708
1709            let cuda_type = transpiler.type_mapper.map_type(&pat_type.ty)?;
1710            params.push(format!("{} {}", cuda_type.to_cuda_string(), param_name));
1711        }
1712    }
1713
1714    // Return type
1715    let return_type = match &func.sig.output {
1716        ReturnType::Default => "void".to_string(),
1717        ReturnType::Type(_, ty) => transpiler.type_mapper.map_type(ty)?.to_cuda_string(),
1718    };
1719
1720    // Generate body
1721    let body = transpiler.transpile_block(&func.block)?;
1722
1723    Ok(format!(
1724        "__device__ {return_type} {name}({params}) {{\n{body}}}\n",
1725        params = params.join(", ")
1726    ))
1727}
1728
1729#[cfg(test)]
1730mod tests {
1731    use super::*;
1732    use syn::parse_quote;
1733
1734    #[test]
1735    fn test_simple_arithmetic() {
1736        let transpiler = CudaTranspiler::new_generic();
1737
1738        let expr: Expr = parse_quote!(a + b * 2.0);
1739        let result = transpiler.transpile_expr(&expr).unwrap();
1740        assert_eq!(result, "a + b * 2.0f");
1741    }
1742
1743    #[test]
1744    fn test_let_binding() {
1745        let mut transpiler = CudaTranspiler::new_generic();
1746
1747        let stmt: Stmt = parse_quote!(let x = a + b;);
1748        let result = transpiler.transpile_stmt(&stmt).unwrap();
1749        assert!(result.contains("float x = a + b;"));
1750    }
1751
1752    #[test]
1753    fn test_array_index() {
1754        let transpiler = CudaTranspiler::new_generic();
1755
1756        let expr: Expr = parse_quote!(data[idx]);
1757        let result = transpiler.transpile_expr(&expr).unwrap();
1758        assert_eq!(result, "data[idx]");
1759    }
1760
1761    #[test]
1762    fn test_stencil_intrinsics() {
1763        let config = StencilConfig::new("test")
1764            .with_tile_size(16, 16)
1765            .with_halo(1);
1766        let mut transpiler = CudaTranspiler::new(config);
1767        transpiler.grid_pos_vars.push("pos".to_string());
1768
1769        // Test pos.idx()
1770        let expr: Expr = parse_quote!(pos.idx());
1771        let result = transpiler.transpile_expr(&expr).unwrap();
1772        assert_eq!(result, "idx");
1773
1774        // Test pos.north(p)
1775        let expr: Expr = parse_quote!(pos.north(p));
1776        let result = transpiler.transpile_expr(&expr).unwrap();
1777        assert_eq!(result, "p[idx - 18]");
1778
1779        // Test pos.east(p)
1780        let expr: Expr = parse_quote!(pos.east(p));
1781        let result = transpiler.transpile_expr(&expr).unwrap();
1782        assert_eq!(result, "p[idx + 1]");
1783    }
1784
1785    #[test]
1786    fn test_ternary_if() {
1787        let transpiler = CudaTranspiler::new_generic();
1788
1789        let expr: Expr = parse_quote!(if x > 0.0 { x } else { -x });
1790        let result = transpiler.transpile_expr(&expr).unwrap();
1791        assert!(result.contains("?"));
1792        assert!(result.contains(":"));
1793    }
1794
1795    #[test]
1796    fn test_full_stencil_kernel() {
1797        let func: ItemFn = parse_quote! {
1798            fn fdtd(p: &[f32], p_prev: &mut [f32], c2: f32, pos: GridPos) {
1799                let curr = p[pos.idx()];
1800                let prev = p_prev[pos.idx()];
1801                let lap = pos.north(p) + pos.south(p) + pos.east(p) + pos.west(p) - 4.0 * curr;
1802                p_prev[pos.idx()] = (2.0 * curr - prev + c2 * lap);
1803            }
1804        };
1805
1806        let config = StencilConfig::new("fdtd")
1807            .with_tile_size(16, 16)
1808            .with_halo(1);
1809
1810        let mut transpiler = CudaTranspiler::new(config);
1811        let cuda = transpiler.transpile_stencil(&func).unwrap();
1812
1813        // Check key features
1814        assert!(cuda.contains("extern \"C\" __global__"));
1815        assert!(cuda.contains("threadIdx.x"));
1816        assert!(cuda.contains("threadIdx.y"));
1817        assert!(cuda.contains("buffer_width = 18"));
1818        assert!(cuda.contains("const float* __restrict__ p"));
1819        assert!(cuda.contains("float* __restrict__ p_prev"));
1820        assert!(!cuda.contains("GridPos")); // GridPos should be removed
1821
1822        println!("Generated CUDA:\n{}", cuda);
1823    }
1824
1825    #[test]
1826    fn test_early_return() {
1827        let mut transpiler = CudaTranspiler::new_generic();
1828
1829        let stmt: Stmt = parse_quote!(return;);
1830        let result = transpiler.transpile_stmt(&stmt).unwrap();
1831        assert!(result.contains("return;"));
1832
1833        let stmt_val: Stmt = parse_quote!(return 42;);
1834        let result_val = transpiler.transpile_stmt(&stmt_val).unwrap();
1835        assert!(result_val.contains("return 42;"));
1836    }
1837
1838    #[test]
1839    fn test_match_to_switch() {
1840        let transpiler = CudaTranspiler::new_generic();
1841
1842        let expr: Expr = parse_quote! {
1843            match edge {
1844                0 => { idx = 1 * 18 + i; }
1845                1 => { idx = 16 * 18 + i; }
1846                _ => { idx = 0; }
1847            }
1848        };
1849
1850        let result = transpiler.transpile_expr(&expr).unwrap();
1851        assert!(
1852            result.contains("switch (edge)"),
1853            "Should generate switch: {}",
1854            result
1855        );
1856        assert!(result.contains("case 0:"), "Should have case 0: {}", result);
1857        assert!(result.contains("case 1:"), "Should have case 1: {}", result);
1858        assert!(
1859            result.contains("default:"),
1860            "Should have default: {}",
1861            result
1862        );
1863        assert!(result.contains("break;"), "Should have break: {}", result);
1864
1865        println!("Generated switch:\n{}", result);
1866    }
1867
1868    #[test]
1869    fn test_block_idx_intrinsics() {
1870        let transpiler = CudaTranspiler::new_generic();
1871
1872        // Test block_idx_x() call
1873        let expr: Expr = parse_quote!(block_idx_x());
1874        let result = transpiler.transpile_expr(&expr).unwrap();
1875        assert_eq!(result, "blockIdx.x");
1876
1877        // Test thread_idx_y() call
1878        let expr2: Expr = parse_quote!(thread_idx_y());
1879        let result2 = transpiler.transpile_expr(&expr2).unwrap();
1880        assert_eq!(result2, "threadIdx.y");
1881
1882        // Test grid_dim_x() call
1883        let expr3: Expr = parse_quote!(grid_dim_x());
1884        let result3 = transpiler.transpile_expr(&expr3).unwrap();
1885        assert_eq!(result3, "gridDim.x");
1886    }
1887
1888    #[test]
1889    fn test_global_index_calculation() {
1890        let transpiler = CudaTranspiler::new_generic();
1891
1892        // Common CUDA pattern: gx = blockIdx.x * blockDim.x + threadIdx.x
1893        let expr: Expr = parse_quote!(block_idx_x() * block_dim_x() + thread_idx_x());
1894        let result = transpiler.transpile_expr(&expr).unwrap();
1895        assert!(result.contains("blockIdx.x"), "Should contain blockIdx.x");
1896        assert!(result.contains("blockDim.x"), "Should contain blockDim.x");
1897        assert!(result.contains("threadIdx.x"), "Should contain threadIdx.x");
1898
1899        println!("Global index expression: {}", result);
1900    }
1901
1902    // === Loop Transpilation Tests ===
1903
1904    #[test]
1905    fn test_for_loop_transpile() {
1906        let transpiler = CudaTranspiler::new_generic();
1907
1908        let expr: Expr = parse_quote! {
1909            for i in 0..n {
1910                data[i] = 0.0;
1911            }
1912        };
1913
1914        let result = transpiler.transpile_expr(&expr).unwrap();
1915        assert!(
1916            result.contains("for (int i = 0; i < n; i++)"),
1917            "Should generate for loop header: {}",
1918            result
1919        );
1920        assert!(
1921            result.contains("data[i] = 0.0f"),
1922            "Should contain loop body: {}",
1923            result
1924        );
1925
1926        println!("Generated for loop:\n{}", result);
1927    }
1928
1929    #[test]
1930    fn test_for_loop_inclusive_range() {
1931        let transpiler = CudaTranspiler::new_generic();
1932
1933        let expr: Expr = parse_quote! {
1934            for i in 1..=10 {
1935                sum += i;
1936            }
1937        };
1938
1939        let result = transpiler.transpile_expr(&expr).unwrap();
1940        assert!(
1941            result.contains("for (int i = 1; i <= 10; i++)"),
1942            "Should generate inclusive range: {}",
1943            result
1944        );
1945
1946        println!("Generated inclusive for loop:\n{}", result);
1947    }
1948
1949    #[test]
1950    fn test_while_loop_transpile() {
1951        let transpiler = CudaTranspiler::new_generic();
1952
1953        let expr: Expr = parse_quote! {
1954            while i < 10 {
1955                i += 1;
1956            }
1957        };
1958
1959        let result = transpiler.transpile_expr(&expr).unwrap();
1960        assert!(
1961            result.contains("while (i < 10)"),
1962            "Should generate while loop: {}",
1963            result
1964        );
1965        assert!(
1966            result.contains("i += 1"),
1967            "Should contain loop body: {}",
1968            result
1969        );
1970
1971        println!("Generated while loop:\n{}", result);
1972    }
1973
1974    #[test]
1975    fn test_while_loop_negation() {
1976        let transpiler = CudaTranspiler::new_generic();
1977
1978        let expr: Expr = parse_quote! {
1979            while !done {
1980                process();
1981            }
1982        };
1983
1984        let result = transpiler.transpile_expr(&expr).unwrap();
1985        assert!(
1986            result.contains("while (!(done))"),
1987            "Should negate condition: {}",
1988            result
1989        );
1990
1991        println!("Generated while loop with negation:\n{}", result);
1992    }
1993
1994    #[test]
1995    fn test_infinite_loop_transpile() {
1996        let transpiler = CudaTranspiler::new_generic();
1997
1998        let expr: Expr = parse_quote! {
1999            loop {
2000                process();
2001            }
2002        };
2003
2004        let result = transpiler.transpile_expr(&expr).unwrap();
2005        assert!(
2006            result.contains("while (true)"),
2007            "Should generate infinite loop: {}",
2008            result
2009        );
2010        assert!(
2011            result.contains("process()"),
2012            "Should contain loop body: {}",
2013            result
2014        );
2015
2016        println!("Generated infinite loop:\n{}", result);
2017    }
2018
2019    #[test]
2020    fn test_break_transpile() {
2021        let transpiler = CudaTranspiler::new_generic();
2022
2023        let expr: Expr = parse_quote!(break);
2024        let result = transpiler.transpile_expr(&expr).unwrap();
2025        assert_eq!(result, "break");
2026    }
2027
2028    #[test]
2029    fn test_continue_transpile() {
2030        let transpiler = CudaTranspiler::new_generic();
2031
2032        let expr: Expr = parse_quote!(continue);
2033        let result = transpiler.transpile_expr(&expr).unwrap();
2034        assert_eq!(result, "continue");
2035    }
2036
2037    #[test]
2038    fn test_loop_with_break() {
2039        let transpiler = CudaTranspiler::new_generic();
2040
2041        let expr: Expr = parse_quote! {
2042            loop {
2043                if done {
2044                    break;
2045                }
2046            }
2047        };
2048
2049        let result = transpiler.transpile_expr(&expr).unwrap();
2050        assert!(
2051            result.contains("while (true)"),
2052            "Should generate infinite loop: {}",
2053            result
2054        );
2055        assert!(result.contains("break"), "Should contain break: {}", result);
2056
2057        println!("Generated loop with break:\n{}", result);
2058    }
2059
2060    #[test]
2061    fn test_nested_loops() {
2062        let transpiler = CudaTranspiler::new_generic();
2063
2064        let expr: Expr = parse_quote! {
2065            for i in 0..m {
2066                for j in 0..n {
2067                    matrix[i * n + j] = 0.0;
2068                }
2069            }
2070        };
2071
2072        let result = transpiler.transpile_expr(&expr).unwrap();
2073        assert!(
2074            result.contains("for (int i = 0; i < m; i++)"),
2075            "Should have outer loop: {}",
2076            result
2077        );
2078        assert!(
2079            result.contains("for (int j = 0; j < n; j++)"),
2080            "Should have inner loop: {}",
2081            result
2082        );
2083
2084        println!("Generated nested loops:\n{}", result);
2085    }
2086
2087    #[test]
2088    fn test_stencil_mode_rejects_loops() {
2089        let config = StencilConfig::new("test")
2090            .with_tile_size(16, 16)
2091            .with_halo(1);
2092        let transpiler = CudaTranspiler::new(config);
2093
2094        let expr: Expr = parse_quote! {
2095            for i in 0..n {
2096                data[i] = 0.0;
2097            }
2098        };
2099
2100        let result = transpiler.transpile_expr(&expr);
2101        assert!(result.is_err(), "Stencil mode should reject loops");
2102    }
2103
2104    #[test]
2105    fn test_labeled_break_rejected() {
2106        let transpiler = CudaTranspiler::new_generic();
2107
2108        // Note: We can't directly parse `break 'label` without a labeled block,
2109        // so we test that the error path exists by checking the function handles labels
2110        let break_expr = syn::ExprBreak {
2111            attrs: Vec::new(),
2112            break_token: syn::token::Break::default(),
2113            label: Some(syn::Lifetime::new("'outer", proc_macro2::Span::call_site())),
2114            expr: None,
2115        };
2116
2117        let result = transpiler.transpile_break(&break_expr);
2118        assert!(result.is_err(), "Labeled break should be rejected");
2119    }
2120
2121    #[test]
2122    fn test_full_kernel_with_loop() {
2123        let func: ItemFn = parse_quote! {
2124            fn fill_array(data: &mut [f32], n: i32) {
2125                for i in 0..n {
2126                    data[i as usize] = 0.0;
2127                }
2128            }
2129        };
2130
2131        let mut transpiler = CudaTranspiler::new_generic();
2132        let cuda = transpiler.transpile_generic_kernel(&func).unwrap();
2133
2134        assert!(
2135            cuda.contains("extern \"C\" __global__"),
2136            "Should be global kernel: {}",
2137            cuda
2138        );
2139        assert!(
2140            cuda.contains("for (int i = 0; i < n; i++)"),
2141            "Should have for loop: {}",
2142            cuda
2143        );
2144
2145        println!("Generated kernel with loop:\n{}", cuda);
2146    }
2147
2148    #[test]
2149    fn test_persistent_kernel_pattern() {
2150        // Test the pattern used for ring/actor kernels
2151        let transpiler = CudaTranspiler::with_mode(ValidationMode::RingKernel);
2152
2153        let expr: Expr = parse_quote! {
2154            while !should_terminate {
2155                if has_message {
2156                    process_message();
2157                }
2158            }
2159        };
2160
2161        let result = transpiler.transpile_expr(&expr).unwrap();
2162        assert!(
2163            result.contains("while (!(should_terminate))"),
2164            "Should have persistent loop: {}",
2165            result
2166        );
2167        assert!(
2168            result.contains("if (has_message)"),
2169            "Should have message check: {}",
2170            result
2171        );
2172
2173        println!("Generated persistent kernel pattern:\n{}", result);
2174    }
2175
2176    // ==================== Shared Memory Tests ====================
2177
2178    #[test]
2179    fn test_shared_tile_declaration() {
2180        use crate::shared::{SharedMemoryConfig, SharedMemoryDecl};
2181
2182        let decl = SharedMemoryDecl::tile("tile", "float", 16, 16);
2183        assert_eq!(decl.to_cuda_decl(), "__shared__ float tile[16][16];");
2184
2185        let mut config = SharedMemoryConfig::new();
2186        config.add_tile("tile", "float", 16, 16);
2187        assert_eq!(config.total_bytes(), 16 * 16 * 4); // 1024 bytes
2188
2189        let decls = config.generate_declarations("    ");
2190        assert!(decls.contains("__shared__ float tile[16][16];"));
2191    }
2192
2193    #[test]
2194    fn test_shared_array_declaration() {
2195        use crate::shared::{SharedMemoryConfig, SharedMemoryDecl};
2196
2197        let decl = SharedMemoryDecl::array("buffer", "float", 256);
2198        assert_eq!(decl.to_cuda_decl(), "__shared__ float buffer[256];");
2199
2200        let mut config = SharedMemoryConfig::new();
2201        config.add_array("buffer", "float", 256);
2202        assert_eq!(config.total_bytes(), 256 * 4); // 1024 bytes
2203    }
2204
2205    #[test]
2206    fn test_shared_memory_access_expressions() {
2207        use crate::shared::SharedMemoryDecl;
2208
2209        let tile = SharedMemoryDecl::tile("tile", "float", 16, 16);
2210        assert_eq!(
2211            tile.to_cuda_access(&["y".to_string(), "x".to_string()]),
2212            "tile[y][x]"
2213        );
2214
2215        let arr = SharedMemoryDecl::array("buf", "int", 128);
2216        assert_eq!(arr.to_cuda_access(&["i".to_string()]), "buf[i]");
2217    }
2218
2219    #[test]
2220    fn test_parse_shared_tile_type() {
2221        use crate::shared::parse_shared_tile_type;
2222
2223        let result = parse_shared_tile_type("SharedTile::<f32, 16, 16>");
2224        assert_eq!(result, Some(("f32".to_string(), 16, 16)));
2225
2226        let result2 = parse_shared_tile_type("SharedTile<i32, 32, 8>");
2227        assert_eq!(result2, Some(("i32".to_string(), 32, 8)));
2228
2229        let invalid = parse_shared_tile_type("Vec<f32>");
2230        assert_eq!(invalid, None);
2231    }
2232
2233    #[test]
2234    fn test_parse_shared_array_type() {
2235        use crate::shared::parse_shared_array_type;
2236
2237        let result = parse_shared_array_type("SharedArray::<f32, 256>");
2238        assert_eq!(result, Some(("f32".to_string(), 256)));
2239
2240        let result2 = parse_shared_array_type("SharedArray<u32, 1024>");
2241        assert_eq!(result2, Some(("u32".to_string(), 1024)));
2242
2243        let invalid = parse_shared_array_type("Vec<f32>");
2244        assert_eq!(invalid, None);
2245    }
2246
2247    #[test]
2248    fn test_rust_to_cuda_element_types() {
2249        use crate::shared::rust_to_cuda_element_type;
2250
2251        assert_eq!(rust_to_cuda_element_type("f32"), "float");
2252        assert_eq!(rust_to_cuda_element_type("f64"), "double");
2253        assert_eq!(rust_to_cuda_element_type("i32"), "int");
2254        assert_eq!(rust_to_cuda_element_type("u32"), "unsigned int");
2255        assert_eq!(rust_to_cuda_element_type("i64"), "long long");
2256        assert_eq!(rust_to_cuda_element_type("u64"), "unsigned long long");
2257        assert_eq!(rust_to_cuda_element_type("bool"), "int");
2258    }
2259
2260    #[test]
2261    fn test_shared_memory_total_bytes() {
2262        use crate::shared::SharedMemoryConfig;
2263
2264        let mut config = SharedMemoryConfig::new();
2265        config.add_tile("tile1", "float", 16, 16); // 16*16*4 = 1024
2266        config.add_tile("tile2", "double", 8, 8); // 8*8*8 = 512
2267        config.add_array("temp", "int", 64); // 64*4 = 256
2268
2269        assert_eq!(config.total_bytes(), 1024 + 512 + 256);
2270    }
2271
2272    #[test]
2273    fn test_transpiler_shared_var_tracking() {
2274        let mut transpiler = CudaTranspiler::new_generic();
2275
2276        // Manually register a shared variable
2277        transpiler.shared_vars.insert(
2278            "tile".to_string(),
2279            SharedVarInfo {
2280                name: "tile".to_string(),
2281                is_tile: true,
2282                dimensions: vec![16, 16],
2283                element_type: "float".to_string(),
2284            },
2285        );
2286
2287        // Test that transpiler tracks it
2288        assert!(transpiler.shared_vars.contains_key("tile"));
2289        assert!(transpiler.shared_vars.get("tile").unwrap().is_tile);
2290    }
2291
2292    #[test]
2293    fn test_shared_tile_get_transpilation() {
2294        let mut transpiler = CudaTranspiler::new_generic();
2295
2296        // Register a shared tile
2297        transpiler.shared_vars.insert(
2298            "tile".to_string(),
2299            SharedVarInfo {
2300                name: "tile".to_string(),
2301                is_tile: true,
2302                dimensions: vec![16, 16],
2303                element_type: "float".to_string(),
2304            },
2305        );
2306
2307        // Test method call transpilation
2308        let result = transpiler.try_transpile_shared_method_call(
2309            "tile",
2310            "get",
2311            &syn::punctuated::Punctuated::new(),
2312        );
2313
2314        // With no args, it should return None (args required)
2315        assert!(result.is_none() || result.unwrap().is_err());
2316    }
2317
2318    #[test]
2319    fn test_shared_array_access() {
2320        let mut transpiler = CudaTranspiler::new_generic();
2321
2322        // Register a shared array
2323        transpiler.shared_vars.insert(
2324            "buffer".to_string(),
2325            SharedVarInfo {
2326                name: "buffer".to_string(),
2327                is_tile: false,
2328                dimensions: vec![256],
2329                element_type: "float".to_string(),
2330            },
2331        );
2332
2333        assert!(!transpiler.shared_vars.get("buffer").unwrap().is_tile);
2334        assert_eq!(
2335            transpiler.shared_vars.get("buffer").unwrap().dimensions,
2336            vec![256]
2337        );
2338    }
2339
2340    #[test]
2341    fn test_full_kernel_with_shared_memory() {
2342        // Test that we can generate declarations correctly
2343        use crate::shared::SharedMemoryConfig;
2344
2345        let mut config = SharedMemoryConfig::new();
2346        config.add_tile("smem", "float", 16, 16);
2347
2348        let decls = config.generate_declarations("    ");
2349        assert!(decls.contains("__shared__ float smem[16][16];"));
2350        assert!(!config.is_empty());
2351    }
2352
2353    // === Struct Literal Tests ===
2354
2355    #[test]
2356    fn test_struct_literal_transpile() {
2357        let transpiler = CudaTranspiler::new_generic();
2358
2359        let expr: Expr = parse_quote! {
2360            Point { x: 1.0, y: 2.0 }
2361        };
2362
2363        let result = transpiler.transpile_expr(&expr).unwrap();
2364        assert!(
2365            result.contains("Point"),
2366            "Should contain struct name: {}",
2367            result
2368        );
2369        assert!(result.contains(".x ="), "Should have field x: {}", result);
2370        assert!(result.contains(".y ="), "Should have field y: {}", result);
2371        assert!(
2372            result.contains("1.0f"),
2373            "Should have value 1.0f: {}",
2374            result
2375        );
2376        assert!(
2377            result.contains("2.0f"),
2378            "Should have value 2.0f: {}",
2379            result
2380        );
2381
2382        println!("Generated struct literal: {}", result);
2383    }
2384
2385    #[test]
2386    fn test_struct_literal_with_expressions() {
2387        let transpiler = CudaTranspiler::new_generic();
2388
2389        let expr: Expr = parse_quote! {
2390            Response { value: x * 2.0, id: idx as u64 }
2391        };
2392
2393        let result = transpiler.transpile_expr(&expr).unwrap();
2394        assert!(
2395            result.contains("Response"),
2396            "Should contain struct name: {}",
2397            result
2398        );
2399        assert!(
2400            result.contains(".value = x * 2.0f"),
2401            "Should have computed value: {}",
2402            result
2403        );
2404        assert!(result.contains(".id ="), "Should have id field: {}", result);
2405
2406        println!("Generated struct with expressions: {}", result);
2407    }
2408
2409    #[test]
2410    fn test_struct_literal_in_return() {
2411        let mut transpiler = CudaTranspiler::new_generic();
2412
2413        let stmt: Stmt = parse_quote! {
2414            return MyStruct { a: 1, b: 2.0 };
2415        };
2416
2417        let result = transpiler.transpile_stmt(&stmt).unwrap();
2418        assert!(result.contains("return"), "Should have return: {}", result);
2419        assert!(
2420            result.contains("MyStruct"),
2421            "Should contain struct name: {}",
2422            result
2423        );
2424
2425        println!("Generated return with struct: {}", result);
2426    }
2427
2428    #[test]
2429    fn test_struct_literal_compound_literal_format() {
2430        let transpiler = CudaTranspiler::new_generic();
2431
2432        let expr: Expr = parse_quote! {
2433            Vec3 { x: a, y: b, z: c }
2434        };
2435
2436        let result = transpiler.transpile_expr(&expr).unwrap();
2437        // Check for C compound literal format: (Type){ .field = val, ... }
2438        assert!(
2439            result.starts_with("(Vec3){"),
2440            "Should use compound literal format: {}",
2441            result
2442        );
2443        assert!(
2444            result.ends_with("}"),
2445            "Should end with closing brace: {}",
2446            result
2447        );
2448
2449        println!("Generated compound literal: {}", result);
2450    }
2451
2452    // === Reference Expression Tests ===
2453
2454    #[test]
2455    fn test_reference_to_array_element() {
2456        let transpiler = CudaTranspiler::new_generic();
2457
2458        let expr: Expr = parse_quote! {
2459            &arr[idx]
2460        };
2461
2462        let result = transpiler.transpile_expr(&expr).unwrap();
2463        assert_eq!(
2464            result, "&arr[idx]",
2465            "Should produce address-of array element"
2466        );
2467    }
2468
2469    #[test]
2470    fn test_mutable_reference_to_array_element() {
2471        let transpiler = CudaTranspiler::new_generic();
2472
2473        let expr: Expr = parse_quote! {
2474            &mut arr[idx * 4 + offset]
2475        };
2476
2477        let result = transpiler.transpile_expr(&expr).unwrap();
2478        assert!(
2479            result.contains("&arr["),
2480            "Should produce address-of: {}",
2481            result
2482        );
2483        assert!(
2484            result.contains("idx * 4"),
2485            "Should have index expression: {}",
2486            result
2487        );
2488    }
2489
2490    #[test]
2491    fn test_reference_to_variable() {
2492        let transpiler = CudaTranspiler::new_generic();
2493
2494        let expr: Expr = parse_quote! {
2495            &value
2496        };
2497
2498        let result = transpiler.transpile_expr(&expr).unwrap();
2499        assert_eq!(result, "&value", "Should produce address-of variable");
2500    }
2501
2502    #[test]
2503    fn test_reference_to_struct_field() {
2504        let transpiler = CudaTranspiler::new_generic();
2505
2506        let expr: Expr = parse_quote! {
2507            &alerts[(idx as usize) * 4 + alert_idx as usize]
2508        };
2509
2510        let result = transpiler.transpile_expr(&expr).unwrap();
2511        assert!(
2512            result.starts_with("&alerts["),
2513            "Should have address-of array: {}",
2514            result
2515        );
2516
2517        println!("Generated reference: {}", result);
2518    }
2519
2520    #[test]
2521    fn test_complex_reference_pattern() {
2522        let mut transpiler = CudaTranspiler::new_generic();
2523
2524        // This is the pattern from txmon batch kernel
2525        let stmt: Stmt = parse_quote! {
2526            let alert = &mut alerts[(idx as usize) * 4 + alert_idx as usize];
2527        };
2528
2529        let result = transpiler.transpile_stmt(&stmt).unwrap();
2530        assert!(
2531            result.contains("alert ="),
2532            "Should have variable assignment: {}",
2533            result
2534        );
2535        assert!(
2536            result.contains("&alerts["),
2537            "Should have reference to array: {}",
2538            result
2539        );
2540
2541        println!("Generated statement: {}", result);
2542    }
2543}