ringkernel_cuda_codegen/
transpiler.rs

1//! Core Rust-to-CUDA transpiler.
2//!
3//! This module handles the translation of Rust AST to CUDA C code.
4
5use crate::handler::{ContextMethod, HandlerSignature};
6use crate::intrinsics::{IntrinsicRegistry, RingKernelIntrinsic, StencilIntrinsic};
7use crate::loops::{extract_loop_var, RangeInfo};
8use crate::shared::{rust_to_cuda_element_type, SharedMemoryConfig, SharedMemoryDecl};
9use crate::stencil::StencilConfig;
10use crate::types::{is_grid_pos_type, is_ring_context_type, TypeMapper};
11use crate::validation::ValidationMode;
12use crate::{Result, TranspileError};
13use quote::ToTokens;
14use syn::{
15    BinOp, Expr, ExprAssign, ExprBinary, ExprBreak, ExprCall, ExprCast, ExprContinue, ExprForLoop,
16    ExprIf, ExprIndex, ExprLet, ExprLit, ExprLoop, ExprMatch, ExprMethodCall, ExprParen, ExprPath,
17    ExprReference, ExprReturn, ExprStruct, ExprUnary, ExprWhile, FnArg, ItemFn, Lit, Pat,
18    ReturnType, Stmt, UnOp,
19};
20
21/// CUDA code transpiler.
22pub struct CudaTranspiler {
23    /// Stencil configuration (if generating a stencil kernel).
24    config: Option<StencilConfig>,
25    /// Type mapper.
26    type_mapper: TypeMapper,
27    /// Intrinsic registry.
28    intrinsics: IntrinsicRegistry,
29    /// Variables known to be the GridPos context.
30    grid_pos_vars: Vec<String>,
31    /// Variables known to be RingContext references.
32    context_vars: Vec<String>,
33    /// Current indentation level.
34    indent: usize,
35    /// Validation mode for loop handling.
36    validation_mode: ValidationMode,
37    /// Shared memory configuration.
38    shared_memory: SharedMemoryConfig,
39    /// Variables that are SharedTile or SharedArray types.
40    pub shared_vars: std::collections::HashMap<String, SharedVarInfo>,
41    /// Whether we're in ring kernel mode (enables context method inlining).
42    ring_kernel_mode: bool,
43    /// Variables known to be pointer types (for -> operator usage).
44    pointer_vars: std::collections::HashSet<String>,
45}
46
47/// Information about a shared memory variable.
48#[derive(Debug, Clone)]
49pub struct SharedVarInfo {
50    /// Variable name.
51    pub name: String,
52    /// Whether it's a 2D tile (true) or 1D array (false).
53    pub is_tile: bool,
54    /// Dimensions: [size] for 1D, [height, width] for 2D.
55    pub dimensions: Vec<usize>,
56    /// Element type (CUDA type string).
57    pub element_type: String,
58}
59
60impl CudaTranspiler {
61    /// Create a new transpiler with stencil configuration.
62    pub fn new(config: StencilConfig) -> Self {
63        Self {
64            config: Some(config),
65            type_mapper: TypeMapper::new(),
66            intrinsics: IntrinsicRegistry::new(),
67            grid_pos_vars: Vec::new(),
68            context_vars: Vec::new(),
69            indent: 1, // Start with 1 level for function body
70            validation_mode: ValidationMode::Stencil,
71            shared_memory: SharedMemoryConfig::new(),
72            shared_vars: std::collections::HashMap::new(),
73            ring_kernel_mode: false,
74            pointer_vars: std::collections::HashSet::new(),
75        }
76    }
77
78    /// Create a new transpiler without stencil configuration.
79    pub fn new_generic() -> Self {
80        Self {
81            config: None,
82            type_mapper: TypeMapper::new(),
83            intrinsics: IntrinsicRegistry::new(),
84            grid_pos_vars: Vec::new(),
85            context_vars: Vec::new(),
86            indent: 1,
87            validation_mode: ValidationMode::Generic,
88            shared_memory: SharedMemoryConfig::new(),
89            shared_vars: std::collections::HashMap::new(),
90            ring_kernel_mode: false,
91            pointer_vars: std::collections::HashSet::new(),
92        }
93    }
94
95    /// Create a new transpiler with a specific validation mode.
96    pub fn with_mode(mode: ValidationMode) -> Self {
97        Self {
98            config: None,
99            type_mapper: TypeMapper::new(),
100            intrinsics: IntrinsicRegistry::new(),
101            grid_pos_vars: Vec::new(),
102            context_vars: Vec::new(),
103            indent: 1,
104            validation_mode: mode,
105            shared_memory: SharedMemoryConfig::new(),
106            shared_vars: std::collections::HashMap::new(),
107            ring_kernel_mode: false,
108            pointer_vars: std::collections::HashSet::new(),
109        }
110    }
111
112    /// Create a transpiler configured for ring kernel handler transpilation.
113    pub fn for_ring_kernel() -> Self {
114        Self {
115            config: None,
116            type_mapper: crate::types::ring_kernel_type_mapper(),
117            intrinsics: IntrinsicRegistry::new(),
118            grid_pos_vars: Vec::new(),
119            context_vars: Vec::new(),
120            indent: 2, // Inside kernel + loop
121            validation_mode: ValidationMode::Generic,
122            shared_memory: SharedMemoryConfig::new(),
123            shared_vars: std::collections::HashMap::new(),
124            ring_kernel_mode: true,
125            pointer_vars: std::collections::HashSet::new(),
126        }
127    }
128
129    /// Set the validation mode.
130    pub fn set_validation_mode(&mut self, mode: ValidationMode) {
131        self.validation_mode = mode;
132    }
133
134    /// Get the shared memory configuration.
135    pub fn shared_memory(&self) -> &SharedMemoryConfig {
136        &self.shared_memory
137    }
138
139    /// Get current indentation string.
140    fn indent_str(&self) -> String {
141        "    ".repeat(self.indent)
142    }
143
144    /// Transpile a stencil kernel function.
145    pub fn transpile_stencil(&mut self, func: &ItemFn) -> Result<String> {
146        let config = self
147            .config
148            .as_ref()
149            .ok_or_else(|| TranspileError::Unsupported("No stencil config provided".into()))?
150            .clone();
151
152        // Identify GridPos parameters
153        for param in &func.sig.inputs {
154            if let FnArg::Typed(pat_type) = param {
155                if is_grid_pos_type(&pat_type.ty) {
156                    if let Pat::Ident(ident) = pat_type.pat.as_ref() {
157                        self.grid_pos_vars.push(ident.ident.to_string());
158                    }
159                }
160            }
161        }
162
163        // Generate function signature (without GridPos params)
164        let signature = self.transpile_kernel_signature(func)?;
165
166        // Generate preamble (thread index calculations)
167        let preamble = config.generate_preamble();
168
169        // Generate function body
170        let body = self.transpile_block(&func.block)?;
171
172        Ok(format!(
173            "extern \"C\" __global__ void {signature} {{\n{preamble}\n{body}}}\n"
174        ))
175    }
176
177    /// Transpile a generic (non-stencil) kernel function.
178    ///
179    /// This generates a `__global__` kernel without stencil-specific preamble.
180    /// The kernel code can use `thread_idx_x()`, `block_idx_x()` etc. to access
181    /// CUDA thread indices directly.
182    pub fn transpile_generic_kernel(&mut self, func: &ItemFn) -> Result<String> {
183        // Generate function signature with all params
184        let signature = self.transpile_generic_kernel_signature(func)?;
185
186        // Generate function body
187        let body = self.transpile_block(&func.block)?;
188
189        Ok(format!(
190            "extern \"C\" __global__ void {signature} {{\n{body}}}\n"
191        ))
192    }
193
194    /// Transpile a handler function into a persistent ring kernel.
195    ///
196    /// This wraps the handler body in a persistent message-processing loop
197    /// with control block integration, queue operations, and HLC support.
198    pub fn transpile_ring_kernel(
199        &mut self,
200        handler: &ItemFn,
201        config: &crate::ring_kernel::RingKernelConfig,
202    ) -> Result<String> {
203        use std::fmt::Write;
204
205        // Parse handler signature
206        let handler_sig = HandlerSignature::parse(handler, &self.type_mapper)?;
207
208        // Track context variables for method inlining
209        for param in &handler.sig.inputs {
210            if let FnArg::Typed(pat_type) = param {
211                if is_ring_context_type(&pat_type.ty) {
212                    if let Pat::Ident(ident) = pat_type.pat.as_ref() {
213                        self.context_vars.push(ident.ident.to_string());
214                    }
215                }
216            }
217        }
218
219        // Enable ring kernel mode for context method inlining
220        self.ring_kernel_mode = true;
221
222        let mut output = String::new();
223
224        // Generate struct definitions
225        output.push_str(&crate::ring_kernel::generate_control_block_struct());
226        output.push('\n');
227
228        if config.enable_hlc {
229            output.push_str(&crate::ring_kernel::generate_hlc_struct());
230            output.push('\n');
231        }
232
233        if config.enable_k2k {
234            output.push_str(&crate::ring_kernel::generate_k2k_structs());
235            output.push('\n');
236        }
237
238        // Generate message/response struct definitions if needed
239        if let Some(ref msg_param) = handler_sig.message_param {
240            // Extract type name from the parameter
241            let type_name = msg_param
242                .rust_type
243                .trim_start_matches('&')
244                .trim_start_matches("mut ")
245                .trim();
246            if !type_name.is_empty() && type_name != "f32" && type_name != "i32" {
247                writeln!(output, "// Message type: {}", type_name).unwrap();
248            }
249        }
250
251        if let Some(ref ret_type) = handler_sig.return_type {
252            if ret_type.is_struct {
253                writeln!(output, "// Response type: {}", ret_type.rust_type).unwrap();
254            }
255        }
256
257        // Generate kernel signature
258        output.push_str(&config.generate_signature());
259        output.push_str(" {\n");
260
261        // Generate preamble
262        output.push_str(&config.generate_preamble("    "));
263
264        // Generate message loop header
265        output.push_str(&config.generate_loop_header("    "));
266
267        // Generate message deserialization if handler has message param
268        if let Some(ref msg_param) = handler_sig.message_param {
269            let type_name = msg_param
270                .rust_type
271                .trim_start_matches('&')
272                .trim_start_matches("mut ")
273                .trim();
274            if !type_name.is_empty() {
275                writeln!(output, "        // Message deserialization").unwrap();
276                writeln!(
277                    output,
278                    "        // {}* {} = ({}*)msg_ptr;",
279                    type_name, msg_param.name, type_name
280                )
281                .unwrap();
282                output.push('\n');
283            }
284        }
285
286        // Transpile handler body
287        self.indent = 2; // Inside the message loop
288        let handler_body = self.transpile_block(&handler.block)?;
289
290        // Insert handler code with proper indentation
291        writeln!(output, "        // === USER HANDLER CODE ===").unwrap();
292        for line in handler_body.lines() {
293            if !line.trim().is_empty() {
294                // Add extra indent for being inside the loop
295                writeln!(output, "    {}", line).unwrap();
296            }
297        }
298        writeln!(output, "        // === END HANDLER CODE ===").unwrap();
299
300        // Generate response serialization if handler returns a value
301        if let Some(ref ret_type) = handler_sig.return_type {
302            writeln!(output).unwrap();
303            writeln!(output, "        // Response serialization").unwrap();
304            if ret_type.is_struct {
305                writeln!(output, "        // memcpy(&output_buffer[_out_idx * RESP_SIZE], &response, sizeof({}));",
306                    ret_type.cuda_type).unwrap();
307            }
308        }
309
310        // Generate message completion
311        output.push_str(&config.generate_message_complete("    "));
312
313        // Generate loop footer
314        output.push_str(&config.generate_loop_footer("    "));
315
316        // Generate epilogue
317        output.push_str(&config.generate_epilogue("    "));
318
319        output.push_str("}\n");
320
321        Ok(output)
322    }
323
324    /// Transpile a generic kernel function signature (keeps all params).
325    fn transpile_generic_kernel_signature(&self, func: &ItemFn) -> Result<String> {
326        let name = func.sig.ident.to_string();
327
328        let mut params = Vec::new();
329        for param in &func.sig.inputs {
330            if let FnArg::Typed(pat_type) = param {
331                let param_name = match pat_type.pat.as_ref() {
332                    Pat::Ident(ident) => ident.ident.to_string(),
333                    _ => {
334                        return Err(TranspileError::Unsupported(
335                            "Complex pattern in parameter".into(),
336                        ))
337                    }
338                };
339
340                let cuda_type = self.type_mapper.map_type(&pat_type.ty)?;
341                params.push(format!("{} {}", cuda_type.to_cuda_string(), param_name));
342            }
343        }
344
345        Ok(format!("{}({})", name, params.join(", ")))
346    }
347
348    /// Transpile the kernel function signature.
349    fn transpile_kernel_signature(&self, func: &ItemFn) -> Result<String> {
350        let name = func.sig.ident.to_string();
351
352        let mut params = Vec::new();
353        for param in &func.sig.inputs {
354            if let FnArg::Typed(pat_type) = param {
355                // Skip GridPos parameters
356                if is_grid_pos_type(&pat_type.ty) {
357                    continue;
358                }
359
360                let param_name = match pat_type.pat.as_ref() {
361                    Pat::Ident(ident) => ident.ident.to_string(),
362                    _ => {
363                        return Err(TranspileError::Unsupported(
364                            "Complex pattern in parameter".into(),
365                        ))
366                    }
367                };
368
369                let cuda_type = self.type_mapper.map_type(&pat_type.ty)?;
370                params.push(format!("{} {}", cuda_type.to_cuda_string(), param_name));
371            }
372        }
373
374        Ok(format!("{}({})", name, params.join(", ")))
375    }
376
377    /// Transpile a block of statements.
378    fn transpile_block(&mut self, block: &syn::Block) -> Result<String> {
379        let mut output = String::new();
380
381        for stmt in &block.stmts {
382            let stmt_str = self.transpile_stmt(stmt)?;
383            if !stmt_str.is_empty() {
384                output.push_str(&stmt_str);
385            }
386        }
387
388        Ok(output)
389    }
390
391    /// Transpile a single statement.
392    fn transpile_stmt(&mut self, stmt: &Stmt) -> Result<String> {
393        match stmt {
394            Stmt::Local(local) => {
395                let indent = self.indent_str();
396
397                // Get variable name
398                let var_name = match &local.pat {
399                    Pat::Ident(ident) => ident.ident.to_string(),
400                    Pat::Type(pat_type) => {
401                        if let Pat::Ident(ident) = pat_type.pat.as_ref() {
402                            ident.ident.to_string()
403                        } else {
404                            return Err(TranspileError::Unsupported(
405                                "Complex pattern in let binding".into(),
406                            ));
407                        }
408                    }
409                    _ => {
410                        return Err(TranspileError::Unsupported(
411                            "Complex pattern in let binding".into(),
412                        ))
413                    }
414                };
415
416                // Check for SharedTile or SharedArray type annotation
417                if let Some(shared_decl) = self.try_parse_shared_declaration(local, &var_name)? {
418                    // Register the shared variable
419                    self.shared_vars.insert(
420                        var_name.clone(),
421                        SharedVarInfo {
422                            name: var_name.clone(),
423                            is_tile: shared_decl.dimensions.len() == 2,
424                            dimensions: shared_decl.dimensions.clone(),
425                            element_type: shared_decl.element_type.clone(),
426                        },
427                    );
428
429                    // Add to shared memory config
430                    self.shared_memory.add(shared_decl.clone());
431
432                    // Return the __shared__ declaration
433                    return Ok(format!("{indent}{}\n", shared_decl.to_cuda_decl()));
434                }
435
436                // Get initializer
437                if let Some(init) = &local.init {
438                    let expr_str = self.transpile_expr(&init.expr)?;
439
440                    // Infer type from expression (default to float for now)
441                    // In a more complete implementation, we'd do proper type inference
442                    let type_str = self.infer_cuda_type(&init.expr);
443
444                    // Track if this is a pointer variable for -> access
445                    if type_str.ends_with('*') {
446                        self.pointer_vars.insert(var_name.clone());
447                    }
448
449                    Ok(format!("{indent}{type_str} {var_name} = {expr_str};\n"))
450                } else {
451                    // Uninitialized variable - need type annotation
452                    Ok(format!("{indent}float {var_name};\n"))
453                }
454            }
455            Stmt::Expr(expr, semi) => {
456                let indent = self.indent_str();
457
458                // Check if this is an if statement with early return (shouldn't wrap in return)
459                if let Expr::If(if_expr) = expr {
460                    // Check if the body contains only a return statement
461                    if let Some(Stmt::Expr(Expr::Return(_), _)) = if_expr.then_branch.stmts.first()
462                    {
463                        if if_expr.then_branch.stmts.len() == 1 && if_expr.else_branch.is_none() {
464                            let expr_str = self.transpile_expr(expr)?;
465                            return Ok(format!("{indent}{expr_str};\n"));
466                        }
467                    }
468                }
469
470                let expr_str = self.transpile_expr(expr)?;
471
472                if semi.is_some() {
473                    Ok(format!("{indent}{expr_str};\n"))
474                } else {
475                    // Expression without semicolon (implicit return)
476                    // But check if it's already a return or an if with return
477                    if matches!(expr, Expr::Return(_))
478                        || expr_str.starts_with("return")
479                        || expr_str.starts_with("if (")
480                    {
481                        Ok(format!("{indent}{expr_str};\n"))
482                    } else {
483                        Ok(format!("{indent}return {expr_str};\n"))
484                    }
485                }
486            }
487            Stmt::Item(_) => {
488                // Items in function body not supported
489                Err(TranspileError::Unsupported("Item in function body".into()))
490            }
491            Stmt::Macro(_) => Err(TranspileError::Unsupported("Macro in function body".into())),
492        }
493    }
494
495    /// Transpile an expression.
496    fn transpile_expr(&self, expr: &Expr) -> Result<String> {
497        match expr {
498            Expr::Lit(lit) => self.transpile_lit(lit),
499            Expr::Path(path) => self.transpile_path(path),
500            Expr::Binary(bin) => self.transpile_binary(bin),
501            Expr::Unary(unary) => self.transpile_unary(unary),
502            Expr::Paren(paren) => self.transpile_paren(paren),
503            Expr::Index(index) => self.transpile_index(index),
504            Expr::Call(call) => self.transpile_call(call),
505            Expr::MethodCall(method) => self.transpile_method_call(method),
506            Expr::If(if_expr) => self.transpile_if(if_expr),
507            Expr::Assign(assign) => self.transpile_assign(assign),
508            Expr::Cast(cast) => self.transpile_cast(cast),
509            Expr::Match(match_expr) => self.transpile_match(match_expr),
510            Expr::Block(block) => {
511                // For block expressions, we just return the last expression
512                if let Some(Stmt::Expr(expr, None)) = block.block.stmts.last() {
513                    self.transpile_expr(expr)
514                } else {
515                    Err(TranspileError::Unsupported(
516                        "Complex block expression".into(),
517                    ))
518                }
519            }
520            Expr::Field(field) => {
521                // Struct field access: obj.field -> obj.field or obj->field for pointers
522                let base = self.transpile_expr(&field.base)?;
523                let member = match &field.member {
524                    syn::Member::Named(ident) => ident.to_string(),
525                    syn::Member::Unnamed(idx) => idx.index.to_string(),
526                };
527
528                // Check if base is a pointer variable - use -> instead of .
529                let accessor = if self.pointer_vars.contains(&base) {
530                    "->"
531                } else {
532                    "."
533                };
534                Ok(format!("{base}{accessor}{member}"))
535            }
536            Expr::Return(ret) => self.transpile_return(ret),
537            Expr::ForLoop(for_loop) => self.transpile_for_loop(for_loop),
538            Expr::While(while_loop) => self.transpile_while_loop(while_loop),
539            Expr::Loop(loop_expr) => self.transpile_infinite_loop(loop_expr),
540            Expr::Break(break_expr) => self.transpile_break(break_expr),
541            Expr::Continue(cont_expr) => self.transpile_continue(cont_expr),
542            Expr::Struct(struct_expr) => self.transpile_struct_literal(struct_expr),
543            Expr::Reference(ref_expr) => self.transpile_reference(ref_expr),
544            Expr::Let(let_expr) => self.transpile_let_expr(let_expr),
545            Expr::Tuple(tuple) => {
546                // Transpile tuple as a comma-separated list (for multi-value returns, etc.)
547                let elements: Vec<String> = tuple
548                    .elems
549                    .iter()
550                    .map(|e| self.transpile_expr(e))
551                    .collect::<Result<_>>()?;
552                Ok(format!("({})", elements.join(", ")))
553            }
554            _ => Err(TranspileError::Unsupported(format!(
555                "Expression type: {}",
556                expr.to_token_stream()
557            ))),
558        }
559    }
560
561    /// Transpile a literal.
562    fn transpile_lit(&self, lit: &ExprLit) -> Result<String> {
563        match &lit.lit {
564            Lit::Float(f) => {
565                let s = f.to_string();
566                // Ensure f32 literals have 'f' suffix in CUDA
567                if s.ends_with("f32") || !s.contains('.') {
568                    let num = s.trim_end_matches("f32").trim_end_matches("f64");
569                    Ok(format!("{num}f"))
570                } else if s.ends_with("f64") {
571                    Ok(s.trim_end_matches("f64").to_string())
572                } else {
573                    // Plain float literal - add 'f' suffix for float
574                    Ok(format!("{s}f"))
575                }
576            }
577            Lit::Int(i) => Ok(i.to_string()),
578            Lit::Bool(b) => Ok(if b.value { "1" } else { "0" }.to_string()),
579            _ => Err(TranspileError::Unsupported(format!(
580                "Literal type: {}",
581                lit.to_token_stream()
582            ))),
583        }
584    }
585
586    /// Transpile a path (variable reference).
587    fn transpile_path(&self, path: &ExprPath) -> Result<String> {
588        let segments: Vec<_> = path
589            .path
590            .segments
591            .iter()
592            .map(|s| s.ident.to_string())
593            .collect();
594
595        if segments.len() == 1 {
596            Ok(segments[0].clone())
597        } else {
598            Ok(segments.join("::"))
599        }
600    }
601
602    /// Transpile a binary expression.
603    fn transpile_binary(&self, bin: &ExprBinary) -> Result<String> {
604        let left = self.transpile_expr(&bin.left)?;
605        let right = self.transpile_expr(&bin.right)?;
606
607        let op = match bin.op {
608            BinOp::Add(_) => "+",
609            BinOp::Sub(_) => "-",
610            BinOp::Mul(_) => "*",
611            BinOp::Div(_) => "/",
612            BinOp::Rem(_) => "%",
613            BinOp::And(_) => "&&",
614            BinOp::Or(_) => "||",
615            BinOp::BitXor(_) => "^",
616            BinOp::BitAnd(_) => "&",
617            BinOp::BitOr(_) => "|",
618            BinOp::Shl(_) => "<<",
619            BinOp::Shr(_) => ">>",
620            BinOp::Eq(_) => "==",
621            BinOp::Lt(_) => "<",
622            BinOp::Le(_) => "<=",
623            BinOp::Ne(_) => "!=",
624            BinOp::Ge(_) => ">=",
625            BinOp::Gt(_) => ">",
626            BinOp::AddAssign(_) => "+=",
627            BinOp::SubAssign(_) => "-=",
628            BinOp::MulAssign(_) => "*=",
629            BinOp::DivAssign(_) => "/=",
630            BinOp::RemAssign(_) => "%=",
631            BinOp::BitXorAssign(_) => "^=",
632            BinOp::BitAndAssign(_) => "&=",
633            BinOp::BitOrAssign(_) => "|=",
634            BinOp::ShlAssign(_) => "<<=",
635            BinOp::ShrAssign(_) => ">>=",
636            _ => {
637                return Err(TranspileError::Unsupported(format!(
638                    "Binary operator: {}",
639                    bin.to_token_stream()
640                )))
641            }
642        };
643
644        Ok(format!("{left} {op} {right}"))
645    }
646
647    /// Transpile a unary expression.
648    fn transpile_unary(&self, unary: &ExprUnary) -> Result<String> {
649        let expr = self.transpile_expr(&unary.expr)?;
650
651        let op = match unary.op {
652            UnOp::Neg(_) => "-",
653            UnOp::Not(_) => "!",
654            UnOp::Deref(_) => "*",
655            _ => {
656                return Err(TranspileError::Unsupported(format!(
657                    "Unary operator: {}",
658                    unary.to_token_stream()
659                )))
660            }
661        };
662
663        Ok(format!("{op}({expr})"))
664    }
665
666    /// Transpile a parenthesized expression.
667    fn transpile_paren(&self, paren: &ExprParen) -> Result<String> {
668        let inner = self.transpile_expr(&paren.expr)?;
669        Ok(format!("({inner})"))
670    }
671
672    /// Transpile an index expression.
673    fn transpile_index(&self, index: &ExprIndex) -> Result<String> {
674        let base = self.transpile_expr(&index.expr)?;
675        let idx = self.transpile_expr(&index.index)?;
676        Ok(format!("{base}[{idx}]"))
677    }
678
679    /// Transpile a function call.
680    fn transpile_call(&self, call: &ExprCall) -> Result<String> {
681        let func = self.transpile_expr(&call.func)?;
682
683        // Check for intrinsics
684        if let Some(intrinsic) = self.intrinsics.lookup(&func) {
685            let cuda_name = intrinsic.to_cuda_string();
686
687            // Check if this is a "value" intrinsic (like threadIdx.x) vs a "function" intrinsic
688            // Value intrinsics don't have parentheses in CUDA
689            let is_value_intrinsic = cuda_name.contains("Idx.")
690                || cuda_name.contains("Dim.")
691                || cuda_name.starts_with("threadIdx")
692                || cuda_name.starts_with("blockIdx")
693                || cuda_name.starts_with("blockDim")
694                || cuda_name.starts_with("gridDim");
695
696            if is_value_intrinsic && call.args.is_empty() {
697                // Value intrinsics: threadIdx.x, blockIdx.y, etc. - no parens
698                return Ok(cuda_name.to_string());
699            }
700
701            if call.args.is_empty() && cuda_name.ends_with("()") {
702                // Zero-arg function intrinsics: __syncthreads()
703                return Ok(cuda_name.to_string());
704            }
705
706            let args: Vec<String> = call
707                .args
708                .iter()
709                .map(|a| self.transpile_expr(a))
710                .collect::<Result<_>>()?;
711
712            return Ok(format!(
713                "{}({})",
714                cuda_name.trim_end_matches("()"),
715                args.join(", ")
716            ));
717        }
718
719        // Regular function call
720        let args: Vec<String> = call
721            .args
722            .iter()
723            .map(|a| self.transpile_expr(a))
724            .collect::<Result<_>>()?;
725
726        Ok(format!("{}({})", func, args.join(", ")))
727    }
728
729    /// Transpile a method call, handling stencil intrinsics.
730    fn transpile_method_call(&self, method: &ExprMethodCall) -> Result<String> {
731        let receiver = self.transpile_expr(&method.receiver)?;
732        let method_name = method.method.to_string();
733
734        // Check if this is a SharedTile/SharedArray method call
735        if let Some(result) =
736            self.try_transpile_shared_method_call(&receiver, &method_name, &method.args)
737        {
738            return result;
739        }
740
741        // Check if this is a RingContext method call (in ring kernel mode)
742        if self.ring_kernel_mode && self.context_vars.contains(&receiver) {
743            return self.transpile_context_method(&method_name, &method.args);
744        }
745
746        // Check if this is a GridPos method call
747        if self.grid_pos_vars.contains(&receiver) {
748            return self.transpile_stencil_intrinsic(&method_name, &method.args);
749        }
750
751        // Check for ring kernel intrinsics (like is_active(), should_terminate())
752        if self.ring_kernel_mode {
753            if let Some(intrinsic) = RingKernelIntrinsic::from_name(&method_name) {
754                let args: Vec<String> = method
755                    .args
756                    .iter()
757                    .map(|a| self.transpile_expr(a).unwrap_or_default())
758                    .collect();
759                return Ok(intrinsic.to_cuda(&args));
760            }
761        }
762
763        // Check for f32/f64 math methods
764        if let Some(intrinsic) = self.intrinsics.lookup(&method_name) {
765            let cuda_name = intrinsic.to_cuda_string();
766            let args: Vec<String> = std::iter::once(receiver)
767                .chain(
768                    method
769                        .args
770                        .iter()
771                        .map(|a| self.transpile_expr(a).unwrap_or_default()),
772                )
773                .collect();
774
775            return Ok(format!("{}({})", cuda_name, args.join(", ")));
776        }
777
778        // Regular method call (treat as field access + call for C structs)
779        let args: Vec<String> = method
780            .args
781            .iter()
782            .map(|a| self.transpile_expr(a))
783            .collect::<Result<_>>()?;
784
785        Ok(format!("{}.{}({})", receiver, method_name, args.join(", ")))
786    }
787
788    /// Transpile a RingContext method call to CUDA intrinsics.
789    fn transpile_context_method(
790        &self,
791        method: &str,
792        args: &syn::punctuated::Punctuated<Expr, syn::token::Comma>,
793    ) -> Result<String> {
794        let ctx_method = ContextMethod::from_name(method).ok_or_else(|| {
795            TranspileError::Unsupported(format!("Unknown context method: {}", method))
796        })?;
797
798        let cuda_args: Vec<String> = args
799            .iter()
800            .map(|a| self.transpile_expr(a).unwrap_or_default())
801            .collect();
802
803        Ok(ctx_method.to_cuda(&cuda_args))
804    }
805
806    /// Transpile a stencil intrinsic method call.
807    fn transpile_stencil_intrinsic(
808        &self,
809        method: &str,
810        args: &syn::punctuated::Punctuated<Expr, syn::token::Comma>,
811    ) -> Result<String> {
812        let config = self.config.as_ref().ok_or_else(|| {
813            TranspileError::Unsupported("Stencil intrinsic without config".into())
814        })?;
815
816        let buffer_width = config.buffer_width().to_string();
817
818        let intrinsic = StencilIntrinsic::from_method_name(method).ok_or_else(|| {
819            TranspileError::Unsupported(format!("Unknown stencil intrinsic: {method}"))
820        })?;
821
822        match intrinsic {
823            StencilIntrinsic::Index => {
824                // pos.idx() -> idx
825                Ok("idx".to_string())
826            }
827            StencilIntrinsic::North
828            | StencilIntrinsic::South
829            | StencilIntrinsic::East
830            | StencilIntrinsic::West => {
831                // pos.north(buf) -> buf[idx - buffer_width]
832                if args.is_empty() {
833                    return Err(TranspileError::Unsupported(
834                        "Stencil accessor requires buffer argument".into(),
835                    ));
836                }
837                let buffer = self.transpile_expr(&args[0])?;
838                Ok(intrinsic.to_cuda_index_2d(&buffer, &buffer_width, "idx"))
839            }
840            StencilIntrinsic::At => {
841                // pos.at(buf, dx, dy) -> buf[idx + dy * buffer_width + dx]
842                if args.len() < 3 {
843                    return Err(TranspileError::Unsupported(
844                        "at() requires buffer, dx, dy arguments".into(),
845                    ));
846                }
847                let buffer = self.transpile_expr(&args[0])?;
848                let dx = self.transpile_expr(&args[1])?;
849                let dy = self.transpile_expr(&args[2])?;
850                Ok(format!("{buffer}[idx + ({dy}) * {buffer_width} + ({dx})]"))
851            }
852            StencilIntrinsic::Up | StencilIntrinsic::Down => {
853                // 3D intrinsics
854                Err(TranspileError::Unsupported(
855                    "3D stencil intrinsics not yet implemented".into(),
856                ))
857            }
858        }
859    }
860
861    /// Transpile an if expression.
862    fn transpile_if(&self, if_expr: &ExprIf) -> Result<String> {
863        let cond = self.transpile_expr(&if_expr.cond)?;
864
865        // Check if the body contains only a return statement (early return pattern)
866        if let Some(Stmt::Expr(Expr::Return(ret), _)) = if_expr.then_branch.stmts.first() {
867            if if_expr.then_branch.stmts.len() == 1 && if_expr.else_branch.is_none() {
868                // Simple early return: if (cond) return;
869                if ret.expr.is_none() {
870                    return Ok(format!("if ({cond}) return"));
871                }
872                let ret_val = self.transpile_expr(ret.expr.as_ref().unwrap())?;
873                return Ok(format!("if ({cond}) return {ret_val}"));
874            }
875        }
876
877        // For now, only handle if-else as ternary when it's an expression
878        if let Some((_, else_branch)) = &if_expr.else_branch {
879            // If both branches are simple expressions, use ternary
880            if let (Some(Stmt::Expr(then_expr, None)), Expr::Block(else_block)) =
881                (if_expr.then_branch.stmts.last(), else_branch.as_ref())
882            {
883                if let Some(Stmt::Expr(else_expr, None)) = else_block.block.stmts.last() {
884                    let then_str = self.transpile_expr(then_expr)?;
885                    let else_str = self.transpile_expr(else_expr)?;
886                    return Ok(format!("({cond}) ? ({then_str}) : ({else_str})"));
887                }
888            }
889
890            // Otherwise, generate if statement
891            if let Expr::If(else_if) = else_branch.as_ref() {
892                // else if chain
893                let then_body = self.transpile_if_body(&if_expr.then_branch)?;
894                let else_part = self.transpile_if(else_if)?;
895                return Ok(format!("if ({cond}) {{{then_body}}} else {else_part}"));
896            } else if let Expr::Block(else_block) = else_branch.as_ref() {
897                // else block
898                let then_body = self.transpile_if_body(&if_expr.then_branch)?;
899                let else_body = self.transpile_if_body(&else_block.block)?;
900                return Ok(format!("if ({cond}) {{{then_body}}} else {{{else_body}}}"));
901            }
902        }
903
904        // If without else
905        let then_body = self.transpile_if_body(&if_expr.then_branch)?;
906        Ok(format!("if ({cond}) {{{then_body}}}"))
907    }
908
909    /// Transpile the body of an if branch.
910    fn transpile_if_body(&self, block: &syn::Block) -> Result<String> {
911        let mut body = String::new();
912        for stmt in &block.stmts {
913            match stmt {
914                Stmt::Expr(expr, Some(_)) => {
915                    let expr_str = self.transpile_expr(expr)?;
916                    body.push_str(&format!(" {expr_str};"));
917                }
918                Stmt::Expr(Expr::Return(ret), None) => {
919                    // Handle explicit return without semicolon
920                    if let Some(ret_expr) = &ret.expr {
921                        let expr_str = self.transpile_expr(ret_expr)?;
922                        body.push_str(&format!(" return {expr_str};"));
923                    } else {
924                        body.push_str(" return;");
925                    }
926                }
927                Stmt::Expr(expr, None) => {
928                    let expr_str = self.transpile_expr(expr)?;
929                    body.push_str(&format!(" return {expr_str};"));
930                }
931                _ => {}
932            }
933        }
934        Ok(body)
935    }
936
937    /// Transpile an assignment expression.
938    fn transpile_assign(&self, assign: &ExprAssign) -> Result<String> {
939        let left = self.transpile_expr(&assign.left)?;
940        let right = self.transpile_expr(&assign.right)?;
941        Ok(format!("{left} = {right}"))
942    }
943
944    /// Transpile a cast expression.
945    fn transpile_cast(&self, cast: &ExprCast) -> Result<String> {
946        let expr = self.transpile_expr(&cast.expr)?;
947        let cuda_type = self.type_mapper.map_type(&cast.ty)?;
948        Ok(format!("({})({})", cuda_type.to_cuda_string(), expr))
949    }
950
951    /// Transpile a return expression.
952    fn transpile_return(&self, ret: &ExprReturn) -> Result<String> {
953        if let Some(expr) = &ret.expr {
954            let expr_str = self.transpile_expr(expr)?;
955            Ok(format!("return {expr_str}"))
956        } else {
957            Ok("return".to_string())
958        }
959    }
960
961    /// Transpile a struct literal expression.
962    ///
963    /// Converts Rust struct literals to C-style compound literals:
964    /// `Point { x: 1.0, y: 2.0 }` -> `(Point){ .x = 1.0f, .y = 2.0f }`
965    fn transpile_struct_literal(&self, struct_expr: &ExprStruct) -> Result<String> {
966        // Get the struct type name
967        let type_name = struct_expr
968            .path
969            .segments
970            .iter()
971            .map(|s| s.ident.to_string())
972            .collect::<Vec<_>>()
973            .join("::");
974
975        // Transpile each field
976        let mut fields = Vec::new();
977        for field in &struct_expr.fields {
978            let field_name = match &field.member {
979                syn::Member::Named(ident) => ident.to_string(),
980                syn::Member::Unnamed(idx) => idx.index.to_string(),
981            };
982            let value = self.transpile_expr(&field.expr)?;
983            fields.push(format!(".{} = {}", field_name, value));
984        }
985
986        // Check for struct update syntax (not supported in C)
987        if struct_expr.rest.is_some() {
988            return Err(TranspileError::Unsupported(
989                "Struct update syntax (..base) is not supported in CUDA".into(),
990            ));
991        }
992
993        // Generate C compound literal: (TypeName){ .field1 = val1, .field2 = val2 }
994        Ok(format!("({}){{ {} }}", type_name, fields.join(", ")))
995    }
996
997    /// Transpile a reference expression.
998    ///
999    /// In CUDA C, we typically need pointers. This handles:
1000    /// - `&arr[idx]` -> `&arr[idx]` (pointer to element)
1001    /// - `&mut arr[idx]` -> `&arr[idx]` (same in C)
1002    /// - `&variable` -> `&variable`
1003    fn transpile_reference(&self, ref_expr: &ExprReference) -> Result<String> {
1004        let inner = self.transpile_expr(&ref_expr.expr)?;
1005
1006        // In CUDA C, taking a reference is the same as taking address
1007        // For array indexing like &arr[idx], we produce &arr[idx]
1008        // This creates a pointer to that element
1009        Ok(format!("&{inner}"))
1010    }
1011
1012    /// Transpile a let expression (used in if-let patterns).
1013    ///
1014    /// Note: Full pattern matching is not supported in CUDA, but we can
1015    /// handle simple `if let Some(x) = expr` patterns by transpiling
1016    /// to a simple conditional check.
1017    fn transpile_let_expr(&self, let_expr: &ExprLet) -> Result<String> {
1018        // For now, we treat let expressions as unsupported since
1019        // CUDA doesn't have Option types. But we can add special cases.
1020        let _ = let_expr; // Silence unused warning
1021        Err(TranspileError::Unsupported(
1022            "let expressions (if-let patterns) are not directly supported in CUDA. \
1023             Use explicit comparisons instead."
1024                .into(),
1025        ))
1026    }
1027
1028    // === Loop Transpilation ===
1029
1030    /// Transpile a for loop to CUDA.
1031    ///
1032    /// Handles `for i in start..end` and `for i in start..=end` patterns.
1033    ///
1034    /// # Example
1035    ///
1036    /// ```ignore
1037    /// // Rust
1038    /// for i in 0..n {
1039    ///     data[i] = 0.0;
1040    /// }
1041    ///
1042    /// // CUDA
1043    /// for (int i = 0; i < n; i++) {
1044    ///     data[i] = 0.0f;
1045    /// }
1046    /// ```
1047    fn transpile_for_loop(&self, for_loop: &ExprForLoop) -> Result<String> {
1048        // Check if loops are allowed
1049        if !self.validation_mode.allows_loops() {
1050            return Err(TranspileError::Unsupported(
1051                "Loops are not allowed in stencil kernels".into(),
1052            ));
1053        }
1054
1055        // Extract loop variable name
1056        let var_name = extract_loop_var(&for_loop.pat)
1057            .ok_or_else(|| TranspileError::Unsupported("Complex pattern in for loop".into()))?;
1058
1059        // The iterator expression should be a range
1060        let header = match for_loop.expr.as_ref() {
1061            Expr::Range(range) => {
1062                let range_info = RangeInfo::from_range(range, |e| self.transpile_expr(e));
1063                range_info.to_cuda_for_header(&var_name, "int")
1064            }
1065            _ => {
1066                // For non-range iterators, we can't directly transpile
1067                return Err(TranspileError::Unsupported(
1068                    "Only range expressions (start..end) are supported in for loops".into(),
1069                ));
1070            }
1071        };
1072
1073        // Transpile the loop body
1074        let body = self.transpile_loop_body(&for_loop.body)?;
1075
1076        Ok(format!("{header} {{\n{body}}}"))
1077    }
1078
1079    /// Transpile a while loop to CUDA.
1080    ///
1081    /// # Example
1082    ///
1083    /// ```ignore
1084    /// // Rust
1085    /// while !done {
1086    ///     process();
1087    /// }
1088    ///
1089    /// // CUDA
1090    /// while (!done) {
1091    ///     process();
1092    /// }
1093    /// ```
1094    fn transpile_while_loop(&self, while_loop: &ExprWhile) -> Result<String> {
1095        // Check if loops are allowed
1096        if !self.validation_mode.allows_loops() {
1097            return Err(TranspileError::Unsupported(
1098                "Loops are not allowed in stencil kernels".into(),
1099            ));
1100        }
1101
1102        // Transpile the condition
1103        let condition = self.transpile_expr(&while_loop.cond)?;
1104
1105        // Transpile the loop body
1106        let body = self.transpile_loop_body(&while_loop.body)?;
1107
1108        Ok(format!("while ({condition}) {{\n{body}}}"))
1109    }
1110
1111    /// Transpile an infinite loop to CUDA.
1112    ///
1113    /// # Example
1114    ///
1115    /// ```ignore
1116    /// // Rust
1117    /// loop {
1118    ///     if should_exit { break; }
1119    /// }
1120    ///
1121    /// // CUDA
1122    /// while (true) {
1123    ///     if (should_exit) { break; }
1124    /// }
1125    /// ```
1126    fn transpile_infinite_loop(&self, loop_expr: &ExprLoop) -> Result<String> {
1127        // Check if loops are allowed
1128        if !self.validation_mode.allows_loops() {
1129            return Err(TranspileError::Unsupported(
1130                "Loops are not allowed in stencil kernels".into(),
1131            ));
1132        }
1133
1134        // Transpile the loop body
1135        let body = self.transpile_loop_body(&loop_expr.body)?;
1136
1137        // Use while(true) for infinite loops
1138        Ok(format!("while (true) {{\n{body}}}"))
1139    }
1140
1141    /// Transpile a break expression.
1142    fn transpile_break(&self, break_expr: &ExprBreak) -> Result<String> {
1143        // Check for labeled break (not supported)
1144        if break_expr.label.is_some() {
1145            return Err(TranspileError::Unsupported(
1146                "Labeled break is not supported in CUDA".into(),
1147            ));
1148        }
1149
1150        // Check for break with value (not supported in CUDA)
1151        if break_expr.expr.is_some() {
1152            return Err(TranspileError::Unsupported(
1153                "Break with value is not supported in CUDA".into(),
1154            ));
1155        }
1156
1157        Ok("break".to_string())
1158    }
1159
1160    /// Transpile a continue expression.
1161    fn transpile_continue(&self, cont_expr: &ExprContinue) -> Result<String> {
1162        // Check for labeled continue (not supported)
1163        if cont_expr.label.is_some() {
1164            return Err(TranspileError::Unsupported(
1165                "Labeled continue is not supported in CUDA".into(),
1166            ));
1167        }
1168
1169        Ok("continue".to_string())
1170    }
1171
1172    /// Transpile a loop body (block of statements).
1173    fn transpile_loop_body(&self, block: &syn::Block) -> Result<String> {
1174        let mut output = String::new();
1175        let inner_indent = "    ".repeat(self.indent + 1);
1176
1177        for stmt in &block.stmts {
1178            match stmt {
1179                Stmt::Local(local) => {
1180                    // Variable declaration
1181                    let var_name = match &local.pat {
1182                        Pat::Ident(ident) => ident.ident.to_string(),
1183                        Pat::Type(pat_type) => {
1184                            if let Pat::Ident(ident) = pat_type.pat.as_ref() {
1185                                ident.ident.to_string()
1186                            } else {
1187                                return Err(TranspileError::Unsupported(
1188                                    "Complex pattern in let binding".into(),
1189                                ));
1190                            }
1191                        }
1192                        _ => {
1193                            return Err(TranspileError::Unsupported(
1194                                "Complex pattern in let binding".into(),
1195                            ))
1196                        }
1197                    };
1198
1199                    if let Some(init) = &local.init {
1200                        let expr_str = self.transpile_expr(&init.expr)?;
1201                        let type_str = self.infer_cuda_type(&init.expr);
1202                        output.push_str(&format!(
1203                            "{inner_indent}{type_str} {var_name} = {expr_str};\n"
1204                        ));
1205                    } else {
1206                        output.push_str(&format!("{inner_indent}float {var_name};\n"));
1207                    }
1208                }
1209                Stmt::Expr(expr, semi) => {
1210                    let expr_str = self.transpile_expr(expr)?;
1211                    if semi.is_some() {
1212                        output.push_str(&format!("{inner_indent}{expr_str};\n"));
1213                    } else {
1214                        // Expression without semicolon at end of block
1215                        output.push_str(&format!("{inner_indent}{expr_str};\n"));
1216                    }
1217                }
1218                _ => {
1219                    return Err(TranspileError::Unsupported(
1220                        "Unsupported statement in loop body".into(),
1221                    ));
1222                }
1223            }
1224        }
1225
1226        // Add closing indentation
1227        let closing_indent = "    ".repeat(self.indent);
1228        output.push_str(&closing_indent);
1229
1230        Ok(output)
1231    }
1232
1233    // === Shared Memory Support ===
1234
1235    /// Try to parse a local variable declaration as a shared memory declaration.
1236    ///
1237    /// Recognizes patterns like:
1238    /// - `let tile = SharedTile::<f32, 16, 16>::new();`
1239    /// - `let buffer = SharedArray::<f32, 256>::new();`
1240    /// - `let tile: SharedTile<f32, 16, 16> = SharedTile::new();`
1241    fn try_parse_shared_declaration(
1242        &self,
1243        local: &syn::Local,
1244        var_name: &str,
1245    ) -> Result<Option<SharedMemoryDecl>> {
1246        // Check if there's a type annotation
1247        if let Pat::Type(pat_type) = &local.pat {
1248            let type_str = pat_type.ty.to_token_stream().to_string();
1249            return self.parse_shared_type(&type_str, var_name);
1250        }
1251
1252        // Check the initializer expression for SharedTile::new() or SharedArray::new()
1253        if let Some(init) = &local.init {
1254            if let Expr::Call(call) = init.expr.as_ref() {
1255                if let Expr::Path(path) = call.func.as_ref() {
1256                    let path_str = path.to_token_stream().to_string();
1257                    return self.parse_shared_type(&path_str, var_name);
1258                }
1259            }
1260        }
1261
1262        Ok(None)
1263    }
1264
1265    /// Parse a type string to extract shared memory info.
1266    fn parse_shared_type(
1267        &self,
1268        type_str: &str,
1269        var_name: &str,
1270    ) -> Result<Option<SharedMemoryDecl>> {
1271        // Clean up the type string (remove spaces around ::)
1272        let type_str = type_str
1273            .replace(" :: ", "::")
1274            .replace(" ::", "::")
1275            .replace(":: ", "::");
1276
1277        // Check for SharedTile<T, W, H> or SharedTile::<T, W, H>::new
1278        if type_str.contains("SharedTile") {
1279            // Extract the generic parameters
1280            if let Some(start) = type_str.find('<') {
1281                if let Some(end) = type_str.rfind('>') {
1282                    let params = &type_str[start + 1..end];
1283                    let parts: Vec<&str> = params.split(',').map(|s| s.trim()).collect();
1284
1285                    if parts.len() >= 3 {
1286                        let rust_type = parts[0];
1287                        let width: usize = parts[1].parse().map_err(|_| {
1288                            TranspileError::Unsupported("Invalid SharedTile width".into())
1289                        })?;
1290                        let height: usize = parts[2].parse().map_err(|_| {
1291                            TranspileError::Unsupported("Invalid SharedTile height".into())
1292                        })?;
1293
1294                        let cuda_type = rust_to_cuda_element_type(rust_type);
1295                        return Ok(Some(SharedMemoryDecl::tile(
1296                            var_name, cuda_type, width, height,
1297                        )));
1298                    }
1299                }
1300            }
1301        }
1302
1303        // Check for SharedArray<T, N> or SharedArray::<T, N>::new
1304        if type_str.contains("SharedArray") {
1305            if let Some(start) = type_str.find('<') {
1306                if let Some(end) = type_str.rfind('>') {
1307                    let params = &type_str[start + 1..end];
1308                    let parts: Vec<&str> = params.split(',').map(|s| s.trim()).collect();
1309
1310                    if parts.len() >= 2 {
1311                        let rust_type = parts[0];
1312                        let size: usize = parts[1].parse().map_err(|_| {
1313                            TranspileError::Unsupported("Invalid SharedArray size".into())
1314                        })?;
1315
1316                        let cuda_type = rust_to_cuda_element_type(rust_type);
1317                        return Ok(Some(SharedMemoryDecl::array(var_name, cuda_type, size)));
1318                    }
1319                }
1320            }
1321        }
1322
1323        Ok(None)
1324    }
1325
1326    /// Check if a variable is a shared memory variable and handle method calls.
1327    fn try_transpile_shared_method_call(
1328        &self,
1329        receiver: &str,
1330        method_name: &str,
1331        args: &syn::punctuated::Punctuated<Expr, syn::token::Comma>,
1332    ) -> Option<Result<String>> {
1333        let shared_info = self.shared_vars.get(receiver)?;
1334
1335        match method_name {
1336            "get" => {
1337                // tile.get(x, y) -> tile[y][x] (for 2D) or arr[idx] (for 1D)
1338                if shared_info.is_tile {
1339                    if args.len() >= 2 {
1340                        let x = self.transpile_expr(&args[0]).ok()?;
1341                        let y = self.transpile_expr(&args[1]).ok()?;
1342                        // CUDA uses row-major: tile[row][col] = tile[y][x]
1343                        Some(Ok(format!("{}[{}][{}]", receiver, y, x)))
1344                    } else {
1345                        Some(Err(TranspileError::Unsupported(
1346                            "SharedTile.get requires x and y arguments".into(),
1347                        )))
1348                    }
1349                } else {
1350                    // 1D array
1351                    if !args.is_empty() {
1352                        let idx = self.transpile_expr(&args[0]).ok()?;
1353                        Some(Ok(format!("{}[{}]", receiver, idx)))
1354                    } else {
1355                        Some(Err(TranspileError::Unsupported(
1356                            "SharedArray.get requires index argument".into(),
1357                        )))
1358                    }
1359                }
1360            }
1361            "set" => {
1362                // tile.set(x, y, val) -> tile[y][x] = val
1363                if shared_info.is_tile {
1364                    if args.len() >= 3 {
1365                        let x = self.transpile_expr(&args[0]).ok()?;
1366                        let y = self.transpile_expr(&args[1]).ok()?;
1367                        let val = self.transpile_expr(&args[2]).ok()?;
1368                        Some(Ok(format!("{}[{}][{}] = {}", receiver, y, x, val)))
1369                    } else {
1370                        Some(Err(TranspileError::Unsupported(
1371                            "SharedTile.set requires x, y, and value arguments".into(),
1372                        )))
1373                    }
1374                } else {
1375                    // 1D array
1376                    if args.len() >= 2 {
1377                        let idx = self.transpile_expr(&args[0]).ok()?;
1378                        let val = self.transpile_expr(&args[1]).ok()?;
1379                        Some(Ok(format!("{}[{}] = {}", receiver, idx, val)))
1380                    } else {
1381                        Some(Err(TranspileError::Unsupported(
1382                            "SharedArray.set requires index and value arguments".into(),
1383                        )))
1384                    }
1385                }
1386            }
1387            "width" | "height" | "size" => {
1388                // These are compile-time constants
1389                match method_name {
1390                    "width" if shared_info.is_tile => {
1391                        Some(Ok(shared_info.dimensions[1].to_string()))
1392                    }
1393                    "height" if shared_info.is_tile => {
1394                        Some(Ok(shared_info.dimensions[0].to_string()))
1395                    }
1396                    "size" => {
1397                        let total: usize = shared_info.dimensions.iter().product();
1398                        Some(Ok(total.to_string()))
1399                    }
1400                    _ => None,
1401                }
1402            }
1403            _ => None,
1404        }
1405    }
1406
1407    /// Transpile a match expression to switch/case.
1408    fn transpile_match(&self, match_expr: &ExprMatch) -> Result<String> {
1409        let scrutinee = self.transpile_expr(&match_expr.expr)?;
1410        let mut output = format!("switch ({scrutinee}) {{\n");
1411
1412        for arm in &match_expr.arms {
1413            // Handle the pattern
1414            let case_label = self.transpile_match_pattern(&arm.pat)?;
1415
1416            if case_label == "default" || case_label.starts_with("/*") {
1417                output.push_str("        default: {\n");
1418            } else {
1419                output.push_str(&format!("        case {case_label}: {{\n"));
1420            }
1421
1422            // Handle the arm body - check if it's a block with statements
1423            match arm.body.as_ref() {
1424                Expr::Block(block) => {
1425                    // Block expression with multiple statements
1426                    for stmt in &block.block.stmts {
1427                        let stmt_str = self.transpile_stmt_inline(stmt)?;
1428                        output.push_str(&format!("            {stmt_str}\n"));
1429                    }
1430                }
1431                _ => {
1432                    // Single expression - wrap in statement
1433                    let body = self.transpile_expr(&arm.body)?;
1434                    output.push_str(&format!("            {body};\n"));
1435                }
1436            }
1437
1438            output.push_str("            break;\n");
1439            output.push_str("        }\n");
1440        }
1441
1442        output.push_str("    }");
1443        Ok(output)
1444    }
1445
1446    /// Transpile a match pattern to a case label.
1447    fn transpile_match_pattern(&self, pat: &Pat) -> Result<String> {
1448        match pat {
1449            Pat::Lit(pat_lit) => {
1450                // Integer literal pattern - pat_lit.lit contains the literal
1451                match &pat_lit.lit {
1452                    Lit::Int(i) => Ok(i.to_string()),
1453                    Lit::Bool(b) => Ok(if b.value { "1" } else { "0" }.to_string()),
1454                    _ => Err(TranspileError::Unsupported(
1455                        "Non-integer literal in match pattern".into(),
1456                    )),
1457                }
1458            }
1459            Pat::Wild(_) => {
1460                // _ pattern becomes default
1461                Ok("default".to_string())
1462            }
1463            Pat::Ident(ident) => {
1464                // Named pattern - treat as default case for now
1465                // This handles things like `x => ...` which bind a value
1466                Ok(format!("/* {} */ default", ident.ident))
1467            }
1468            Pat::Or(pat_or) => {
1469                // Multiple patterns: 0 | 1 | 2 => ...
1470                // CUDA switch doesn't support this directly, we need multiple case labels
1471                // For now, just use the first pattern and note the limitation
1472                if let Some(first) = pat_or.cases.first() {
1473                    self.transpile_match_pattern(first)
1474                } else {
1475                    Err(TranspileError::Unsupported("Empty or pattern".into()))
1476                }
1477            }
1478            _ => Err(TranspileError::Unsupported(format!(
1479                "Match pattern: {}",
1480                pat.to_token_stream()
1481            ))),
1482        }
1483    }
1484
1485    /// Transpile a statement without indentation (for inline use in switch).
1486    fn transpile_stmt_inline(&self, stmt: &Stmt) -> Result<String> {
1487        match stmt {
1488            Stmt::Local(local) => {
1489                let var_name = match &local.pat {
1490                    Pat::Ident(ident) => ident.ident.to_string(),
1491                    Pat::Type(pat_type) => {
1492                        if let Pat::Ident(ident) = pat_type.pat.as_ref() {
1493                            ident.ident.to_string()
1494                        } else {
1495                            return Err(TranspileError::Unsupported(
1496                                "Complex pattern in let binding".into(),
1497                            ));
1498                        }
1499                    }
1500                    _ => {
1501                        return Err(TranspileError::Unsupported(
1502                            "Complex pattern in let binding".into(),
1503                        ))
1504                    }
1505                };
1506
1507                if let Some(init) = &local.init {
1508                    let expr_str = self.transpile_expr(&init.expr)?;
1509                    let type_str = self.infer_cuda_type(&init.expr);
1510                    Ok(format!("{type_str} {var_name} = {expr_str};"))
1511                } else {
1512                    Ok(format!("float {var_name};"))
1513                }
1514            }
1515            Stmt::Expr(expr, semi) => {
1516                let expr_str = self.transpile_expr(expr)?;
1517                if semi.is_some() {
1518                    Ok(format!("{expr_str};"))
1519                } else {
1520                    Ok(format!("return {expr_str};"))
1521                }
1522            }
1523            _ => Err(TranspileError::Unsupported(
1524                "Unsupported statement in match arm".into(),
1525            )),
1526        }
1527    }
1528
1529    /// Infer CUDA type from expression (simple heuristic).
1530    fn infer_cuda_type(&self, expr: &Expr) -> &'static str {
1531        match expr {
1532            Expr::Lit(lit) => match &lit.lit {
1533                Lit::Float(_) => "float",
1534                Lit::Int(_) => "int",
1535                Lit::Bool(_) => "int",
1536                _ => "float",
1537            },
1538            Expr::Binary(bin) => {
1539                // Check if this is an integer operation
1540                let left_type = self.infer_cuda_type(&bin.left);
1541                let right_type = self.infer_cuda_type(&bin.right);
1542                // If both sides are int, result is int
1543                if left_type == "int" && right_type == "int" {
1544                    "int"
1545                } else {
1546                    "float"
1547                }
1548            }
1549            Expr::Call(call) => {
1550                // Check for intrinsics that return int
1551                if let Ok(func) = self.transpile_expr(&call.func) {
1552                    if let Some(intrinsic) = self.intrinsics.lookup(&func) {
1553                        let cuda_name = intrinsic.to_cuda_string();
1554                        // Thread/block indices return int
1555                        if cuda_name.contains("Idx") || cuda_name.contains("Dim") {
1556                            return "int";
1557                        }
1558                    }
1559                }
1560                "float"
1561            }
1562            Expr::Index(_) => "float", // Array access - could be any type, default to float
1563            Expr::Cast(cast) => {
1564                // Use the target type of the cast
1565                if let Ok(cuda_type) = self.type_mapper.map_type(&cast.ty) {
1566                    let s = cuda_type.to_cuda_string();
1567                    if s.contains("int") || s.contains("size_t") || s == "unsigned long long" {
1568                        return "int";
1569                    }
1570                }
1571                "float"
1572            }
1573            Expr::Reference(ref_expr) => {
1574                // Reference expression - try to determine pointer type from inner
1575                // For &arr[idx], we get a pointer to the element type
1576                match ref_expr.expr.as_ref() {
1577                    Expr::Index(idx_expr) => {
1578                        // &arr[idx] - get the base array name and look it up
1579                        if let Expr::Path(path) = &*idx_expr.expr {
1580                            let name = path
1581                                .path
1582                                .segments
1583                                .iter()
1584                                .map(|s| s.ident.to_string())
1585                                .collect::<Vec<_>>()
1586                                .join("::");
1587                            // Common GPU struct names -> pointer to that struct
1588                            if name.contains("transaction") || name.contains("Transaction") {
1589                                return "GpuTransaction*";
1590                            }
1591                            if name.contains("profile") || name.contains("Profile") {
1592                                return "GpuCustomerProfile*";
1593                            }
1594                            if name.contains("alert") || name.contains("Alert") {
1595                                return "GpuAlert*";
1596                            }
1597                        }
1598                        "float*" // Default element pointer
1599                    }
1600                    _ => "void*",
1601                }
1602            }
1603            Expr::MethodCall(_) => "float",
1604            Expr::Field(field) => {
1605                // Field access - try to infer type from field name
1606                let member_name = match &field.member {
1607                    syn::Member::Named(ident) => ident.to_string(),
1608                    syn::Member::Unnamed(idx) => idx.index.to_string(),
1609                };
1610                // Common field name patterns
1611                if member_name.contains("count") || member_name.contains("_count") {
1612                    return "unsigned int";
1613                }
1614                if member_name.contains("threshold") || member_name.ends_with("_id") {
1615                    return "unsigned long long";
1616                }
1617                if member_name.ends_with("_pct") {
1618                    return "unsigned char";
1619                }
1620                "float"
1621            }
1622            Expr::Path(path) => {
1623                // Variable access - check if it's a known variable
1624                let name = path
1625                    .path
1626                    .segments
1627                    .iter()
1628                    .map(|s| s.ident.to_string())
1629                    .collect::<Vec<_>>()
1630                    .join("::");
1631                if name.contains("threshold")
1632                    || name.contains("count")
1633                    || name == "idx"
1634                    || name == "n"
1635                {
1636                    return "int";
1637                }
1638                "float"
1639            }
1640            Expr::If(if_expr) => {
1641                // For ternary (if-else), infer type from branches
1642                if let Some((_, else_branch)) = &if_expr.else_branch {
1643                    if let Expr::Block(block) = else_branch.as_ref() {
1644                        if let Some(Stmt::Expr(expr, None)) = block.block.stmts.last() {
1645                            return self.infer_cuda_type(expr);
1646                        }
1647                    }
1648                }
1649                // Try from then branch
1650                if let Some(Stmt::Expr(expr, None)) = if_expr.then_branch.stmts.last() {
1651                    return self.infer_cuda_type(expr);
1652                }
1653                "float"
1654            }
1655            _ => "float",
1656        }
1657    }
1658}
1659
1660/// Transpile a function to CUDA without stencil configuration.
1661pub fn transpile_function(func: &ItemFn) -> Result<String> {
1662    let mut transpiler = CudaTranspiler::new_generic();
1663
1664    // Generate function signature
1665    let name = func.sig.ident.to_string();
1666
1667    let mut params = Vec::new();
1668    for param in &func.sig.inputs {
1669        if let FnArg::Typed(pat_type) = param {
1670            let param_name = match pat_type.pat.as_ref() {
1671                Pat::Ident(ident) => ident.ident.to_string(),
1672                _ => continue,
1673            };
1674
1675            let cuda_type = transpiler.type_mapper.map_type(&pat_type.ty)?;
1676            params.push(format!("{} {}", cuda_type.to_cuda_string(), param_name));
1677        }
1678    }
1679
1680    // Return type
1681    let return_type = match &func.sig.output {
1682        ReturnType::Default => "void".to_string(),
1683        ReturnType::Type(_, ty) => transpiler.type_mapper.map_type(ty)?.to_cuda_string(),
1684    };
1685
1686    // Generate body
1687    let body = transpiler.transpile_block(&func.block)?;
1688
1689    Ok(format!(
1690        "__device__ {return_type} {name}({params}) {{\n{body}}}\n",
1691        params = params.join(", ")
1692    ))
1693}
1694
1695#[cfg(test)]
1696mod tests {
1697    use super::*;
1698    use syn::parse_quote;
1699
1700    #[test]
1701    fn test_simple_arithmetic() {
1702        let transpiler = CudaTranspiler::new_generic();
1703
1704        let expr: Expr = parse_quote!(a + b * 2.0);
1705        let result = transpiler.transpile_expr(&expr).unwrap();
1706        assert_eq!(result, "a + b * 2.0f");
1707    }
1708
1709    #[test]
1710    fn test_let_binding() {
1711        let mut transpiler = CudaTranspiler::new_generic();
1712
1713        let stmt: Stmt = parse_quote!(let x = a + b;);
1714        let result = transpiler.transpile_stmt(&stmt).unwrap();
1715        assert!(result.contains("float x = a + b;"));
1716    }
1717
1718    #[test]
1719    fn test_array_index() {
1720        let transpiler = CudaTranspiler::new_generic();
1721
1722        let expr: Expr = parse_quote!(data[idx]);
1723        let result = transpiler.transpile_expr(&expr).unwrap();
1724        assert_eq!(result, "data[idx]");
1725    }
1726
1727    #[test]
1728    fn test_stencil_intrinsics() {
1729        let config = StencilConfig::new("test")
1730            .with_tile_size(16, 16)
1731            .with_halo(1);
1732        let mut transpiler = CudaTranspiler::new(config);
1733        transpiler.grid_pos_vars.push("pos".to_string());
1734
1735        // Test pos.idx()
1736        let expr: Expr = parse_quote!(pos.idx());
1737        let result = transpiler.transpile_expr(&expr).unwrap();
1738        assert_eq!(result, "idx");
1739
1740        // Test pos.north(p)
1741        let expr: Expr = parse_quote!(pos.north(p));
1742        let result = transpiler.transpile_expr(&expr).unwrap();
1743        assert_eq!(result, "p[idx - 18]");
1744
1745        // Test pos.east(p)
1746        let expr: Expr = parse_quote!(pos.east(p));
1747        let result = transpiler.transpile_expr(&expr).unwrap();
1748        assert_eq!(result, "p[idx + 1]");
1749    }
1750
1751    #[test]
1752    fn test_ternary_if() {
1753        let transpiler = CudaTranspiler::new_generic();
1754
1755        let expr: Expr = parse_quote!(if x > 0.0 { x } else { -x });
1756        let result = transpiler.transpile_expr(&expr).unwrap();
1757        assert!(result.contains("?"));
1758        assert!(result.contains(":"));
1759    }
1760
1761    #[test]
1762    fn test_full_stencil_kernel() {
1763        let func: ItemFn = parse_quote! {
1764            fn fdtd(p: &[f32], p_prev: &mut [f32], c2: f32, pos: GridPos) {
1765                let curr = p[pos.idx()];
1766                let prev = p_prev[pos.idx()];
1767                let lap = pos.north(p) + pos.south(p) + pos.east(p) + pos.west(p) - 4.0 * curr;
1768                p_prev[pos.idx()] = (2.0 * curr - prev + c2 * lap);
1769            }
1770        };
1771
1772        let config = StencilConfig::new("fdtd")
1773            .with_tile_size(16, 16)
1774            .with_halo(1);
1775
1776        let mut transpiler = CudaTranspiler::new(config);
1777        let cuda = transpiler.transpile_stencil(&func).unwrap();
1778
1779        // Check key features
1780        assert!(cuda.contains("extern \"C\" __global__"));
1781        assert!(cuda.contains("threadIdx.x"));
1782        assert!(cuda.contains("threadIdx.y"));
1783        assert!(cuda.contains("buffer_width = 18"));
1784        assert!(cuda.contains("const float* __restrict__ p"));
1785        assert!(cuda.contains("float* __restrict__ p_prev"));
1786        assert!(!cuda.contains("GridPos")); // GridPos should be removed
1787
1788        println!("Generated CUDA:\n{}", cuda);
1789    }
1790
1791    #[test]
1792    fn test_early_return() {
1793        let mut transpiler = CudaTranspiler::new_generic();
1794
1795        let stmt: Stmt = parse_quote!(return;);
1796        let result = transpiler.transpile_stmt(&stmt).unwrap();
1797        assert!(result.contains("return;"));
1798
1799        let stmt_val: Stmt = parse_quote!(return 42;);
1800        let result_val = transpiler.transpile_stmt(&stmt_val).unwrap();
1801        assert!(result_val.contains("return 42;"));
1802    }
1803
1804    #[test]
1805    fn test_match_to_switch() {
1806        let transpiler = CudaTranspiler::new_generic();
1807
1808        let expr: Expr = parse_quote! {
1809            match edge {
1810                0 => { idx = 1 * 18 + i; }
1811                1 => { idx = 16 * 18 + i; }
1812                _ => { idx = 0; }
1813            }
1814        };
1815
1816        let result = transpiler.transpile_expr(&expr).unwrap();
1817        assert!(
1818            result.contains("switch (edge)"),
1819            "Should generate switch: {}",
1820            result
1821        );
1822        assert!(result.contains("case 0:"), "Should have case 0: {}", result);
1823        assert!(result.contains("case 1:"), "Should have case 1: {}", result);
1824        assert!(
1825            result.contains("default:"),
1826            "Should have default: {}",
1827            result
1828        );
1829        assert!(result.contains("break;"), "Should have break: {}", result);
1830
1831        println!("Generated switch:\n{}", result);
1832    }
1833
1834    #[test]
1835    fn test_block_idx_intrinsics() {
1836        let transpiler = CudaTranspiler::new_generic();
1837
1838        // Test block_idx_x() call
1839        let expr: Expr = parse_quote!(block_idx_x());
1840        let result = transpiler.transpile_expr(&expr).unwrap();
1841        assert_eq!(result, "blockIdx.x");
1842
1843        // Test thread_idx_y() call
1844        let expr2: Expr = parse_quote!(thread_idx_y());
1845        let result2 = transpiler.transpile_expr(&expr2).unwrap();
1846        assert_eq!(result2, "threadIdx.y");
1847
1848        // Test grid_dim_x() call
1849        let expr3: Expr = parse_quote!(grid_dim_x());
1850        let result3 = transpiler.transpile_expr(&expr3).unwrap();
1851        assert_eq!(result3, "gridDim.x");
1852    }
1853
1854    #[test]
1855    fn test_global_index_calculation() {
1856        let transpiler = CudaTranspiler::new_generic();
1857
1858        // Common CUDA pattern: gx = blockIdx.x * blockDim.x + threadIdx.x
1859        let expr: Expr = parse_quote!(block_idx_x() * block_dim_x() + thread_idx_x());
1860        let result = transpiler.transpile_expr(&expr).unwrap();
1861        assert!(result.contains("blockIdx.x"), "Should contain blockIdx.x");
1862        assert!(result.contains("blockDim.x"), "Should contain blockDim.x");
1863        assert!(result.contains("threadIdx.x"), "Should contain threadIdx.x");
1864
1865        println!("Global index expression: {}", result);
1866    }
1867
1868    // === Loop Transpilation Tests ===
1869
1870    #[test]
1871    fn test_for_loop_transpile() {
1872        let transpiler = CudaTranspiler::new_generic();
1873
1874        let expr: Expr = parse_quote! {
1875            for i in 0..n {
1876                data[i] = 0.0;
1877            }
1878        };
1879
1880        let result = transpiler.transpile_expr(&expr).unwrap();
1881        assert!(
1882            result.contains("for (int i = 0; i < n; i++)"),
1883            "Should generate for loop header: {}",
1884            result
1885        );
1886        assert!(
1887            result.contains("data[i] = 0.0f"),
1888            "Should contain loop body: {}",
1889            result
1890        );
1891
1892        println!("Generated for loop:\n{}", result);
1893    }
1894
1895    #[test]
1896    fn test_for_loop_inclusive_range() {
1897        let transpiler = CudaTranspiler::new_generic();
1898
1899        let expr: Expr = parse_quote! {
1900            for i in 1..=10 {
1901                sum += i;
1902            }
1903        };
1904
1905        let result = transpiler.transpile_expr(&expr).unwrap();
1906        assert!(
1907            result.contains("for (int i = 1; i <= 10; i++)"),
1908            "Should generate inclusive range: {}",
1909            result
1910        );
1911
1912        println!("Generated inclusive for loop:\n{}", result);
1913    }
1914
1915    #[test]
1916    fn test_while_loop_transpile() {
1917        let transpiler = CudaTranspiler::new_generic();
1918
1919        let expr: Expr = parse_quote! {
1920            while i < 10 {
1921                i += 1;
1922            }
1923        };
1924
1925        let result = transpiler.transpile_expr(&expr).unwrap();
1926        assert!(
1927            result.contains("while (i < 10)"),
1928            "Should generate while loop: {}",
1929            result
1930        );
1931        assert!(
1932            result.contains("i += 1"),
1933            "Should contain loop body: {}",
1934            result
1935        );
1936
1937        println!("Generated while loop:\n{}", result);
1938    }
1939
1940    #[test]
1941    fn test_while_loop_negation() {
1942        let transpiler = CudaTranspiler::new_generic();
1943
1944        let expr: Expr = parse_quote! {
1945            while !done {
1946                process();
1947            }
1948        };
1949
1950        let result = transpiler.transpile_expr(&expr).unwrap();
1951        assert!(
1952            result.contains("while (!(done))"),
1953            "Should negate condition: {}",
1954            result
1955        );
1956
1957        println!("Generated while loop with negation:\n{}", result);
1958    }
1959
1960    #[test]
1961    fn test_infinite_loop_transpile() {
1962        let transpiler = CudaTranspiler::new_generic();
1963
1964        let expr: Expr = parse_quote! {
1965            loop {
1966                process();
1967            }
1968        };
1969
1970        let result = transpiler.transpile_expr(&expr).unwrap();
1971        assert!(
1972            result.contains("while (true)"),
1973            "Should generate infinite loop: {}",
1974            result
1975        );
1976        assert!(
1977            result.contains("process()"),
1978            "Should contain loop body: {}",
1979            result
1980        );
1981
1982        println!("Generated infinite loop:\n{}", result);
1983    }
1984
1985    #[test]
1986    fn test_break_transpile() {
1987        let transpiler = CudaTranspiler::new_generic();
1988
1989        let expr: Expr = parse_quote!(break);
1990        let result = transpiler.transpile_expr(&expr).unwrap();
1991        assert_eq!(result, "break");
1992    }
1993
1994    #[test]
1995    fn test_continue_transpile() {
1996        let transpiler = CudaTranspiler::new_generic();
1997
1998        let expr: Expr = parse_quote!(continue);
1999        let result = transpiler.transpile_expr(&expr).unwrap();
2000        assert_eq!(result, "continue");
2001    }
2002
2003    #[test]
2004    fn test_loop_with_break() {
2005        let transpiler = CudaTranspiler::new_generic();
2006
2007        let expr: Expr = parse_quote! {
2008            loop {
2009                if done {
2010                    break;
2011                }
2012            }
2013        };
2014
2015        let result = transpiler.transpile_expr(&expr).unwrap();
2016        assert!(
2017            result.contains("while (true)"),
2018            "Should generate infinite loop: {}",
2019            result
2020        );
2021        assert!(result.contains("break"), "Should contain break: {}", result);
2022
2023        println!("Generated loop with break:\n{}", result);
2024    }
2025
2026    #[test]
2027    fn test_nested_loops() {
2028        let transpiler = CudaTranspiler::new_generic();
2029
2030        let expr: Expr = parse_quote! {
2031            for i in 0..m {
2032                for j in 0..n {
2033                    matrix[i * n + j] = 0.0;
2034                }
2035            }
2036        };
2037
2038        let result = transpiler.transpile_expr(&expr).unwrap();
2039        assert!(
2040            result.contains("for (int i = 0; i < m; i++)"),
2041            "Should have outer loop: {}",
2042            result
2043        );
2044        assert!(
2045            result.contains("for (int j = 0; j < n; j++)"),
2046            "Should have inner loop: {}",
2047            result
2048        );
2049
2050        println!("Generated nested loops:\n{}", result);
2051    }
2052
2053    #[test]
2054    fn test_stencil_mode_rejects_loops() {
2055        let config = StencilConfig::new("test")
2056            .with_tile_size(16, 16)
2057            .with_halo(1);
2058        let transpiler = CudaTranspiler::new(config);
2059
2060        let expr: Expr = parse_quote! {
2061            for i in 0..n {
2062                data[i] = 0.0;
2063            }
2064        };
2065
2066        let result = transpiler.transpile_expr(&expr);
2067        assert!(result.is_err(), "Stencil mode should reject loops");
2068    }
2069
2070    #[test]
2071    fn test_labeled_break_rejected() {
2072        let transpiler = CudaTranspiler::new_generic();
2073
2074        // Note: We can't directly parse `break 'label` without a labeled block,
2075        // so we test that the error path exists by checking the function handles labels
2076        let break_expr = syn::ExprBreak {
2077            attrs: Vec::new(),
2078            break_token: syn::token::Break::default(),
2079            label: Some(syn::Lifetime::new("'outer", proc_macro2::Span::call_site())),
2080            expr: None,
2081        };
2082
2083        let result = transpiler.transpile_break(&break_expr);
2084        assert!(result.is_err(), "Labeled break should be rejected");
2085    }
2086
2087    #[test]
2088    fn test_full_kernel_with_loop() {
2089        let func: ItemFn = parse_quote! {
2090            fn fill_array(data: &mut [f32], n: i32) {
2091                for i in 0..n {
2092                    data[i as usize] = 0.0;
2093                }
2094            }
2095        };
2096
2097        let mut transpiler = CudaTranspiler::new_generic();
2098        let cuda = transpiler.transpile_generic_kernel(&func).unwrap();
2099
2100        assert!(
2101            cuda.contains("extern \"C\" __global__"),
2102            "Should be global kernel: {}",
2103            cuda
2104        );
2105        assert!(
2106            cuda.contains("for (int i = 0; i < n; i++)"),
2107            "Should have for loop: {}",
2108            cuda
2109        );
2110
2111        println!("Generated kernel with loop:\n{}", cuda);
2112    }
2113
2114    #[test]
2115    fn test_persistent_kernel_pattern() {
2116        // Test the pattern used for ring/actor kernels
2117        let transpiler = CudaTranspiler::with_mode(ValidationMode::RingKernel);
2118
2119        let expr: Expr = parse_quote! {
2120            while !should_terminate {
2121                if has_message {
2122                    process_message();
2123                }
2124            }
2125        };
2126
2127        let result = transpiler.transpile_expr(&expr).unwrap();
2128        assert!(
2129            result.contains("while (!(should_terminate))"),
2130            "Should have persistent loop: {}",
2131            result
2132        );
2133        assert!(
2134            result.contains("if (has_message)"),
2135            "Should have message check: {}",
2136            result
2137        );
2138
2139        println!("Generated persistent kernel pattern:\n{}", result);
2140    }
2141
2142    // ==================== Shared Memory Tests ====================
2143
2144    #[test]
2145    fn test_shared_tile_declaration() {
2146        use crate::shared::{SharedMemoryConfig, SharedMemoryDecl};
2147
2148        let decl = SharedMemoryDecl::tile("tile", "float", 16, 16);
2149        assert_eq!(decl.to_cuda_decl(), "__shared__ float tile[16][16];");
2150
2151        let mut config = SharedMemoryConfig::new();
2152        config.add_tile("tile", "float", 16, 16);
2153        assert_eq!(config.total_bytes(), 16 * 16 * 4); // 1024 bytes
2154
2155        let decls = config.generate_declarations("    ");
2156        assert!(decls.contains("__shared__ float tile[16][16];"));
2157    }
2158
2159    #[test]
2160    fn test_shared_array_declaration() {
2161        use crate::shared::{SharedMemoryConfig, SharedMemoryDecl};
2162
2163        let decl = SharedMemoryDecl::array("buffer", "float", 256);
2164        assert_eq!(decl.to_cuda_decl(), "__shared__ float buffer[256];");
2165
2166        let mut config = SharedMemoryConfig::new();
2167        config.add_array("buffer", "float", 256);
2168        assert_eq!(config.total_bytes(), 256 * 4); // 1024 bytes
2169    }
2170
2171    #[test]
2172    fn test_shared_memory_access_expressions() {
2173        use crate::shared::SharedMemoryDecl;
2174
2175        let tile = SharedMemoryDecl::tile("tile", "float", 16, 16);
2176        assert_eq!(
2177            tile.to_cuda_access(&["y".to_string(), "x".to_string()]),
2178            "tile[y][x]"
2179        );
2180
2181        let arr = SharedMemoryDecl::array("buf", "int", 128);
2182        assert_eq!(arr.to_cuda_access(&["i".to_string()]), "buf[i]");
2183    }
2184
2185    #[test]
2186    fn test_parse_shared_tile_type() {
2187        use crate::shared::parse_shared_tile_type;
2188
2189        let result = parse_shared_tile_type("SharedTile::<f32, 16, 16>");
2190        assert_eq!(result, Some(("f32".to_string(), 16, 16)));
2191
2192        let result2 = parse_shared_tile_type("SharedTile<i32, 32, 8>");
2193        assert_eq!(result2, Some(("i32".to_string(), 32, 8)));
2194
2195        let invalid = parse_shared_tile_type("Vec<f32>");
2196        assert_eq!(invalid, None);
2197    }
2198
2199    #[test]
2200    fn test_parse_shared_array_type() {
2201        use crate::shared::parse_shared_array_type;
2202
2203        let result = parse_shared_array_type("SharedArray::<f32, 256>");
2204        assert_eq!(result, Some(("f32".to_string(), 256)));
2205
2206        let result2 = parse_shared_array_type("SharedArray<u32, 1024>");
2207        assert_eq!(result2, Some(("u32".to_string(), 1024)));
2208
2209        let invalid = parse_shared_array_type("Vec<f32>");
2210        assert_eq!(invalid, None);
2211    }
2212
2213    #[test]
2214    fn test_rust_to_cuda_element_types() {
2215        use crate::shared::rust_to_cuda_element_type;
2216
2217        assert_eq!(rust_to_cuda_element_type("f32"), "float");
2218        assert_eq!(rust_to_cuda_element_type("f64"), "double");
2219        assert_eq!(rust_to_cuda_element_type("i32"), "int");
2220        assert_eq!(rust_to_cuda_element_type("u32"), "unsigned int");
2221        assert_eq!(rust_to_cuda_element_type("i64"), "long long");
2222        assert_eq!(rust_to_cuda_element_type("u64"), "unsigned long long");
2223        assert_eq!(rust_to_cuda_element_type("bool"), "int");
2224    }
2225
2226    #[test]
2227    fn test_shared_memory_total_bytes() {
2228        use crate::shared::SharedMemoryConfig;
2229
2230        let mut config = SharedMemoryConfig::new();
2231        config.add_tile("tile1", "float", 16, 16); // 16*16*4 = 1024
2232        config.add_tile("tile2", "double", 8, 8); // 8*8*8 = 512
2233        config.add_array("temp", "int", 64); // 64*4 = 256
2234
2235        assert_eq!(config.total_bytes(), 1024 + 512 + 256);
2236    }
2237
2238    #[test]
2239    fn test_transpiler_shared_var_tracking() {
2240        let mut transpiler = CudaTranspiler::new_generic();
2241
2242        // Manually register a shared variable
2243        transpiler.shared_vars.insert(
2244            "tile".to_string(),
2245            SharedVarInfo {
2246                name: "tile".to_string(),
2247                is_tile: true,
2248                dimensions: vec![16, 16],
2249                element_type: "float".to_string(),
2250            },
2251        );
2252
2253        // Test that transpiler tracks it
2254        assert!(transpiler.shared_vars.contains_key("tile"));
2255        assert!(transpiler.shared_vars.get("tile").unwrap().is_tile);
2256    }
2257
2258    #[test]
2259    fn test_shared_tile_get_transpilation() {
2260        let mut transpiler = CudaTranspiler::new_generic();
2261
2262        // Register a shared tile
2263        transpiler.shared_vars.insert(
2264            "tile".to_string(),
2265            SharedVarInfo {
2266                name: "tile".to_string(),
2267                is_tile: true,
2268                dimensions: vec![16, 16],
2269                element_type: "float".to_string(),
2270            },
2271        );
2272
2273        // Test method call transpilation
2274        let result = transpiler.try_transpile_shared_method_call(
2275            "tile",
2276            "get",
2277            &syn::punctuated::Punctuated::new(),
2278        );
2279
2280        // With no args, it should return None (args required)
2281        assert!(result.is_none() || result.unwrap().is_err());
2282    }
2283
2284    #[test]
2285    fn test_shared_array_access() {
2286        let mut transpiler = CudaTranspiler::new_generic();
2287
2288        // Register a shared array
2289        transpiler.shared_vars.insert(
2290            "buffer".to_string(),
2291            SharedVarInfo {
2292                name: "buffer".to_string(),
2293                is_tile: false,
2294                dimensions: vec![256],
2295                element_type: "float".to_string(),
2296            },
2297        );
2298
2299        assert!(!transpiler.shared_vars.get("buffer").unwrap().is_tile);
2300        assert_eq!(
2301            transpiler.shared_vars.get("buffer").unwrap().dimensions,
2302            vec![256]
2303        );
2304    }
2305
2306    #[test]
2307    fn test_full_kernel_with_shared_memory() {
2308        // Test that we can generate declarations correctly
2309        use crate::shared::SharedMemoryConfig;
2310
2311        let mut config = SharedMemoryConfig::new();
2312        config.add_tile("smem", "float", 16, 16);
2313
2314        let decls = config.generate_declarations("    ");
2315        assert!(decls.contains("__shared__ float smem[16][16];"));
2316        assert!(!config.is_empty());
2317    }
2318
2319    // === Struct Literal Tests ===
2320
2321    #[test]
2322    fn test_struct_literal_transpile() {
2323        let transpiler = CudaTranspiler::new_generic();
2324
2325        let expr: Expr = parse_quote! {
2326            Point { x: 1.0, y: 2.0 }
2327        };
2328
2329        let result = transpiler.transpile_expr(&expr).unwrap();
2330        assert!(
2331            result.contains("Point"),
2332            "Should contain struct name: {}",
2333            result
2334        );
2335        assert!(result.contains(".x ="), "Should have field x: {}", result);
2336        assert!(result.contains(".y ="), "Should have field y: {}", result);
2337        assert!(
2338            result.contains("1.0f"),
2339            "Should have value 1.0f: {}",
2340            result
2341        );
2342        assert!(
2343            result.contains("2.0f"),
2344            "Should have value 2.0f: {}",
2345            result
2346        );
2347
2348        println!("Generated struct literal: {}", result);
2349    }
2350
2351    #[test]
2352    fn test_struct_literal_with_expressions() {
2353        let transpiler = CudaTranspiler::new_generic();
2354
2355        let expr: Expr = parse_quote! {
2356            Response { value: x * 2.0, id: idx as u64 }
2357        };
2358
2359        let result = transpiler.transpile_expr(&expr).unwrap();
2360        assert!(
2361            result.contains("Response"),
2362            "Should contain struct name: {}",
2363            result
2364        );
2365        assert!(
2366            result.contains(".value = x * 2.0f"),
2367            "Should have computed value: {}",
2368            result
2369        );
2370        assert!(result.contains(".id ="), "Should have id field: {}", result);
2371
2372        println!("Generated struct with expressions: {}", result);
2373    }
2374
2375    #[test]
2376    fn test_struct_literal_in_return() {
2377        let mut transpiler = CudaTranspiler::new_generic();
2378
2379        let stmt: Stmt = parse_quote! {
2380            return MyStruct { a: 1, b: 2.0 };
2381        };
2382
2383        let result = transpiler.transpile_stmt(&stmt).unwrap();
2384        assert!(result.contains("return"), "Should have return: {}", result);
2385        assert!(
2386            result.contains("MyStruct"),
2387            "Should contain struct name: {}",
2388            result
2389        );
2390
2391        println!("Generated return with struct: {}", result);
2392    }
2393
2394    #[test]
2395    fn test_struct_literal_compound_literal_format() {
2396        let transpiler = CudaTranspiler::new_generic();
2397
2398        let expr: Expr = parse_quote! {
2399            Vec3 { x: a, y: b, z: c }
2400        };
2401
2402        let result = transpiler.transpile_expr(&expr).unwrap();
2403        // Check for C compound literal format: (Type){ .field = val, ... }
2404        assert!(
2405            result.starts_with("(Vec3){"),
2406            "Should use compound literal format: {}",
2407            result
2408        );
2409        assert!(
2410            result.ends_with("}"),
2411            "Should end with closing brace: {}",
2412            result
2413        );
2414
2415        println!("Generated compound literal: {}", result);
2416    }
2417
2418    // === Reference Expression Tests ===
2419
2420    #[test]
2421    fn test_reference_to_array_element() {
2422        let transpiler = CudaTranspiler::new_generic();
2423
2424        let expr: Expr = parse_quote! {
2425            &arr[idx]
2426        };
2427
2428        let result = transpiler.transpile_expr(&expr).unwrap();
2429        assert_eq!(
2430            result, "&arr[idx]",
2431            "Should produce address-of array element"
2432        );
2433    }
2434
2435    #[test]
2436    fn test_mutable_reference_to_array_element() {
2437        let transpiler = CudaTranspiler::new_generic();
2438
2439        let expr: Expr = parse_quote! {
2440            &mut arr[idx * 4 + offset]
2441        };
2442
2443        let result = transpiler.transpile_expr(&expr).unwrap();
2444        assert!(
2445            result.contains("&arr["),
2446            "Should produce address-of: {}",
2447            result
2448        );
2449        assert!(
2450            result.contains("idx * 4"),
2451            "Should have index expression: {}",
2452            result
2453        );
2454    }
2455
2456    #[test]
2457    fn test_reference_to_variable() {
2458        let transpiler = CudaTranspiler::new_generic();
2459
2460        let expr: Expr = parse_quote! {
2461            &value
2462        };
2463
2464        let result = transpiler.transpile_expr(&expr).unwrap();
2465        assert_eq!(result, "&value", "Should produce address-of variable");
2466    }
2467
2468    #[test]
2469    fn test_reference_to_struct_field() {
2470        let transpiler = CudaTranspiler::new_generic();
2471
2472        let expr: Expr = parse_quote! {
2473            &alerts[(idx as usize) * 4 + alert_idx as usize]
2474        };
2475
2476        let result = transpiler.transpile_expr(&expr).unwrap();
2477        assert!(
2478            result.starts_with("&alerts["),
2479            "Should have address-of array: {}",
2480            result
2481        );
2482
2483        println!("Generated reference: {}", result);
2484    }
2485
2486    #[test]
2487    fn test_complex_reference_pattern() {
2488        let mut transpiler = CudaTranspiler::new_generic();
2489
2490        // This is the pattern from txmon batch kernel
2491        let stmt: Stmt = parse_quote! {
2492            let alert = &mut alerts[(idx as usize) * 4 + alert_idx as usize];
2493        };
2494
2495        let result = transpiler.transpile_stmt(&stmt).unwrap();
2496        assert!(
2497            result.contains("alert ="),
2498            "Should have variable assignment: {}",
2499            result
2500        );
2501        assert!(
2502            result.contains("&alerts["),
2503            "Should have reference to array: {}",
2504            result
2505        );
2506
2507        println!("Generated statement: {}", result);
2508    }
2509}