Skip to main content

ringkernel_cuda_codegen/
transpiler.rs

1//! Core Rust-to-CUDA transpiler.
2//!
3//! This module handles the translation of Rust AST to CUDA C code.
4
5use crate::handler::{ContextMethod, HandlerSignature};
6use crate::intrinsics::{IntrinsicRegistry, RingKernelIntrinsic, StencilIntrinsic};
7use crate::loops::{extract_loop_var, RangeInfo};
8use crate::shared::{rust_to_cuda_element_type, SharedMemoryConfig, SharedMemoryDecl};
9use crate::stencil::StencilConfig;
10use crate::types::{is_grid_pos_type, is_ring_context_type, TypeMapper};
11use crate::validation::ValidationMode;
12use crate::{Result, TranspileError};
13use quote::ToTokens;
14use syn::{
15    BinOp, Expr, ExprAssign, ExprBinary, ExprBreak, ExprCall, ExprCast, ExprContinue, ExprForLoop,
16    ExprIf, ExprIndex, ExprLet, ExprLit, ExprLoop, ExprMatch, ExprMethodCall, ExprParen, ExprPath,
17    ExprReference, ExprReturn, ExprStruct, ExprUnary, ExprWhile, FnArg, ItemFn, Lit, Pat,
18    ReturnType, Stmt, UnOp,
19};
20
21/// CUDA code transpiler.
22pub struct CudaTranspiler {
23    /// Stencil configuration (if generating a stencil kernel).
24    config: Option<StencilConfig>,
25    /// Type mapper.
26    type_mapper: TypeMapper,
27    /// Intrinsic registry.
28    intrinsics: IntrinsicRegistry,
29    /// Variables known to be the GridPos context.
30    grid_pos_vars: Vec<String>,
31    /// Variables known to be RingContext references.
32    context_vars: Vec<String>,
33    /// Current indentation level.
34    indent: usize,
35    /// Validation mode for loop handling.
36    validation_mode: ValidationMode,
37    /// Shared memory configuration.
38    shared_memory: SharedMemoryConfig,
39    /// Variables that are SharedTile or SharedArray types.
40    pub shared_vars: std::collections::HashMap<String, SharedVarInfo>,
41    /// Whether we're in ring kernel mode (enables context method inlining).
42    ring_kernel_mode: bool,
43    /// Variables known to be pointer types (for -> operator usage).
44    pointer_vars: std::collections::HashSet<String>,
45}
46
/// Description of one shared memory variable recorded by the transpiler.
#[derive(Debug, Clone)]
pub struct SharedVarInfo {
    /// Name of the variable as written in the Rust source.
    pub name: String,
    /// `true` for a 2D tile, `false` for a 1D array.
    pub is_tile: bool,
    /// Extents of the variable: `[size]` when 1D, `[height, width]` when 2D.
    pub dimensions: Vec<usize>,
    /// CUDA spelling of the element type.
    pub element_type: String,
}
59
60impl CudaTranspiler {
61    /// Create a new transpiler with stencil configuration.
62    pub fn new(config: StencilConfig) -> Self {
63        Self {
64            config: Some(config),
65            type_mapper: TypeMapper::new(),
66            intrinsics: IntrinsicRegistry::new(),
67            grid_pos_vars: Vec::new(),
68            context_vars: Vec::new(),
69            indent: 1, // Start with 1 level for function body
70            validation_mode: ValidationMode::Stencil,
71            shared_memory: SharedMemoryConfig::new(),
72            shared_vars: std::collections::HashMap::new(),
73            ring_kernel_mode: false,
74            pointer_vars: std::collections::HashSet::new(),
75        }
76    }
77
78    /// Create a new transpiler without stencil configuration.
79    pub fn new_generic() -> Self {
80        Self {
81            config: None,
82            type_mapper: TypeMapper::new(),
83            intrinsics: IntrinsicRegistry::new(),
84            grid_pos_vars: Vec::new(),
85            context_vars: Vec::new(),
86            indent: 1,
87            validation_mode: ValidationMode::Generic,
88            shared_memory: SharedMemoryConfig::new(),
89            shared_vars: std::collections::HashMap::new(),
90            ring_kernel_mode: false,
91            pointer_vars: std::collections::HashSet::new(),
92        }
93    }
94
95    /// Create a new transpiler with a specific validation mode.
96    pub fn with_mode(mode: ValidationMode) -> Self {
97        Self {
98            config: None,
99            type_mapper: TypeMapper::new(),
100            intrinsics: IntrinsicRegistry::new(),
101            grid_pos_vars: Vec::new(),
102            context_vars: Vec::new(),
103            indent: 1,
104            validation_mode: mode,
105            shared_memory: SharedMemoryConfig::new(),
106            shared_vars: std::collections::HashMap::new(),
107            ring_kernel_mode: false,
108            pointer_vars: std::collections::HashSet::new(),
109        }
110    }
111
112    /// Create a transpiler configured for ring kernel handler transpilation.
113    pub fn for_ring_kernel() -> Self {
114        Self {
115            config: None,
116            type_mapper: crate::types::ring_kernel_type_mapper(),
117            intrinsics: IntrinsicRegistry::new(),
118            grid_pos_vars: Vec::new(),
119            context_vars: Vec::new(),
120            indent: 2, // Inside kernel + loop
121            validation_mode: ValidationMode::Generic,
122            shared_memory: SharedMemoryConfig::new(),
123            shared_vars: std::collections::HashMap::new(),
124            ring_kernel_mode: true,
125            pointer_vars: std::collections::HashSet::new(),
126        }
127    }
128
129    /// Set the validation mode.
130    pub fn set_validation_mode(&mut self, mode: ValidationMode) {
131        self.validation_mode = mode;
132    }
133
134    /// Get the shared memory configuration.
135    pub fn shared_memory(&self) -> &SharedMemoryConfig {
136        &self.shared_memory
137    }
138
139    /// Get current indentation string.
140    fn indent_str(&self) -> String {
141        "    ".repeat(self.indent)
142    }
143
144    /// Transpile a stencil kernel function.
145    pub fn transpile_stencil(&mut self, func: &ItemFn) -> Result<String> {
146        let config = self
147            .config
148            .as_ref()
149            .ok_or_else(|| TranspileError::Unsupported("No stencil config provided".into()))?
150            .clone();
151
152        // Identify GridPos parameters
153        for param in &func.sig.inputs {
154            if let FnArg::Typed(pat_type) = param {
155                if is_grid_pos_type(&pat_type.ty) {
156                    if let Pat::Ident(ident) = pat_type.pat.as_ref() {
157                        self.grid_pos_vars.push(ident.ident.to_string());
158                    }
159                }
160            }
161        }
162
163        // Generate function signature (without GridPos params)
164        let signature = self.transpile_kernel_signature(func)?;
165
166        // Generate preamble (thread index calculations)
167        let preamble = config.generate_preamble();
168
169        // Generate function body
170        let body = self.transpile_block(&func.block)?;
171
172        Ok(format!(
173            "extern \"C\" __global__ void {signature} {{\n{preamble}\n{body}}}\n"
174        ))
175    }
176
177    /// Transpile a generic (non-stencil) kernel function.
178    ///
179    /// This generates a `__global__` kernel without stencil-specific preamble.
180    /// The kernel code can use `thread_idx_x()`, `block_idx_x()` etc. to access
181    /// CUDA thread indices directly.
182    pub fn transpile_generic_kernel(&mut self, func: &ItemFn) -> Result<String> {
183        // Generate function signature with all params
184        let signature = self.transpile_generic_kernel_signature(func)?;
185
186        // Generate function body
187        let body = self.transpile_block(&func.block)?;
188
189        Ok(format!(
190            "extern \"C\" __global__ void {signature} {{\n{body}}}\n"
191        ))
192    }
193
194    /// Transpile a handler function into a persistent ring kernel.
195    ///
196    /// This wraps the handler body in a persistent message-processing loop
197    /// with control block integration, queue operations, and HLC support.
198    pub fn transpile_ring_kernel(
199        &mut self,
200        handler: &ItemFn,
201        config: &crate::ring_kernel::RingKernelConfig,
202    ) -> Result<String> {
203        use std::fmt::Write;
204
205        // Parse handler signature
206        let handler_sig = HandlerSignature::parse(handler, &self.type_mapper)?;
207
208        // Track context variables for method inlining
209        for param in &handler.sig.inputs {
210            if let FnArg::Typed(pat_type) = param {
211                if is_ring_context_type(&pat_type.ty) {
212                    if let Pat::Ident(ident) = pat_type.pat.as_ref() {
213                        self.context_vars.push(ident.ident.to_string());
214                    }
215                }
216            }
217        }
218
219        // Enable ring kernel mode for context method inlining
220        self.ring_kernel_mode = true;
221
222        let mut output = String::new();
223
224        // Generate struct definitions
225        output.push_str(&crate::ring_kernel::generate_control_block_struct());
226        output.push('\n');
227
228        if config.enable_hlc {
229            output.push_str(&crate::ring_kernel::generate_hlc_struct());
230            output.push('\n');
231        }
232
233        // MessageEnvelope structs (requires HLC for HlcTimestamp)
234        if config.use_envelope_format {
235            output.push_str(&crate::ring_kernel::generate_message_envelope_structs());
236            output.push('\n');
237        }
238
239        if config.enable_k2k {
240            output.push_str(&crate::ring_kernel::generate_k2k_structs());
241            output.push('\n');
242        }
243
244        // Generate message/response struct definitions if needed
245        if let Some(ref msg_param) = handler_sig.message_param {
246            // Extract type name from the parameter
247            let type_name = msg_param
248                .rust_type
249                .trim_start_matches('&')
250                .trim_start_matches("mut ")
251                .trim();
252            if !type_name.is_empty() && type_name != "f32" && type_name != "i32" {
253                writeln!(output, "// Message type: {}", type_name).unwrap();
254            }
255        }
256
257        if let Some(ref ret_type) = handler_sig.return_type {
258            if ret_type.is_struct {
259                writeln!(output, "// Response type: {}", ret_type.rust_type).unwrap();
260            }
261        }
262
263        // Generate kernel signature
264        output.push_str(&config.generate_signature());
265        output.push_str(" {\n");
266
267        // Generate preamble
268        output.push_str(&config.generate_preamble("    "));
269
270        // Generate message loop header
271        output.push_str(&config.generate_loop_header("    "));
272
273        // Generate message deserialization if handler has message param
274        if let Some(ref msg_param) = handler_sig.message_param {
275            let type_name = msg_param
276                .rust_type
277                .trim_start_matches('&')
278                .trim_start_matches("mut ")
279                .trim();
280            if !type_name.is_empty() {
281                if config.use_envelope_format {
282                    // Envelope format: cast msg_ptr to typed payload
283                    writeln!(
284                        output,
285                        "        // Message deserialization (envelope format)"
286                    )
287                    .unwrap();
288                    writeln!(
289                        output,
290                        "        // msg_header contains: message_id, correlation_id, source_kernel, timestamp"
291                    )
292                    .unwrap();
293                    writeln!(
294                        output,
295                        "        {}* {} = ({}*)msg_ptr;",
296                        type_name, msg_param.name, type_name
297                    )
298                    .unwrap();
299                } else {
300                    // Raw format
301                    writeln!(output, "        // Message deserialization (raw format)").unwrap();
302                    writeln!(
303                        output,
304                        "        {}* {} = ({}*)msg_ptr;",
305                        type_name, msg_param.name, type_name
306                    )
307                    .unwrap();
308                }
309                output.push('\n');
310            }
311        }
312
313        // Transpile handler body
314        self.indent = 2; // Inside the message loop
315        let handler_body = self.transpile_block(&handler.block)?;
316
317        // Insert handler code with proper indentation
318        writeln!(output, "        // === USER HANDLER CODE ===").unwrap();
319        for line in handler_body.lines() {
320            if !line.trim().is_empty() {
321                // Add extra indent for being inside the loop
322                writeln!(output, "    {}", line).unwrap();
323            }
324        }
325        writeln!(output, "        // === END HANDLER CODE ===").unwrap();
326
327        // Generate response serialization if handler returns a value
328        if let Some(ref ret_type) = handler_sig.return_type {
329            writeln!(output).unwrap();
330            if config.use_envelope_format {
331                // Envelope format: create full response envelope with header
332                writeln!(
333                    output,
334                    "        // Response serialization (envelope format)"
335                )
336                .unwrap();
337                writeln!(
338                    output,
339                    "        unsigned long long _out_idx = atomicAdd(&control->output_head, 1) & control->output_mask;"
340                )
341                .unwrap();
342                writeln!(
343                    output,
344                    "        unsigned char* resp_envelope = &output_buffer[_out_idx * RESP_SIZE];"
345                )
346                .unwrap();
347                writeln!(
348                    output,
349                    "        MessageHeader* resp_header = (MessageHeader*)resp_envelope;"
350                )
351                .unwrap();
352                writeln!(output, "        message_create_response_header(").unwrap();
353                writeln!(output, "            resp_header, msg_header, KERNEL_ID,").unwrap();
354                writeln!(
355                    output,
356                    "            sizeof({}), hlc_physical, hlc_logical, HLC_NODE_ID",
357                    ret_type.cuda_type
358                )
359                .unwrap();
360                writeln!(output, "        );").unwrap();
361                writeln!(
362                    output,
363                    "        memcpy(resp_envelope + MESSAGE_HEADER_SIZE, &response, sizeof({}));",
364                    ret_type.cuda_type
365                )
366                .unwrap();
367                writeln!(output, "        __threadfence();").unwrap();
368            } else {
369                // Raw format
370                writeln!(output, "        // Response serialization (raw format)").unwrap();
371                if ret_type.is_struct {
372                    writeln!(
373                        output,
374                        "        unsigned long long _out_idx = atomicAdd(&control->output_head, 1) & control->output_mask;"
375                    )
376                    .unwrap();
377                    writeln!(
378                        output,
379                        "        memcpy(&output_buffer[_out_idx * RESP_SIZE], &response, sizeof({}));",
380                        ret_type.cuda_type
381                    )
382                    .unwrap();
383                }
384            }
385        }
386
387        // Generate message completion
388        output.push_str(&config.generate_message_complete("    "));
389
390        // Generate loop footer
391        output.push_str(&config.generate_loop_footer("    "));
392
393        // Generate epilogue
394        output.push_str(&config.generate_epilogue("    "));
395
396        output.push_str("}\n");
397
398        Ok(output)
399    }
400
401    /// Transpile a generic kernel function signature (keeps all params).
402    fn transpile_generic_kernel_signature(&self, func: &ItemFn) -> Result<String> {
403        let name = func.sig.ident.to_string();
404
405        let mut params = Vec::new();
406        for param in &func.sig.inputs {
407            if let FnArg::Typed(pat_type) = param {
408                let param_name = match pat_type.pat.as_ref() {
409                    Pat::Ident(ident) => ident.ident.to_string(),
410                    _ => {
411                        return Err(TranspileError::Unsupported(
412                            "Complex pattern in parameter".into(),
413                        ))
414                    }
415                };
416
417                let cuda_type = self.type_mapper.map_type(&pat_type.ty)?;
418                params.push(format!("{} {}", cuda_type.to_cuda_string(), param_name));
419            }
420        }
421
422        Ok(format!("{}({})", name, params.join(", ")))
423    }
424
425    /// Transpile the kernel function signature.
426    fn transpile_kernel_signature(&self, func: &ItemFn) -> Result<String> {
427        let name = func.sig.ident.to_string();
428
429        let mut params = Vec::new();
430        for param in &func.sig.inputs {
431            if let FnArg::Typed(pat_type) = param {
432                // Skip GridPos parameters
433                if is_grid_pos_type(&pat_type.ty) {
434                    continue;
435                }
436
437                let param_name = match pat_type.pat.as_ref() {
438                    Pat::Ident(ident) => ident.ident.to_string(),
439                    _ => {
440                        return Err(TranspileError::Unsupported(
441                            "Complex pattern in parameter".into(),
442                        ))
443                    }
444                };
445
446                let cuda_type = self.type_mapper.map_type(&pat_type.ty)?;
447                params.push(format!("{} {}", cuda_type.to_cuda_string(), param_name));
448            }
449        }
450
451        Ok(format!("{}({})", name, params.join(", ")))
452    }
453
454    /// Transpile a block of statements.
455    fn transpile_block(&mut self, block: &syn::Block) -> Result<String> {
456        let mut output = String::new();
457
458        for stmt in &block.stmts {
459            let stmt_str = self.transpile_stmt(stmt)?;
460            if !stmt_str.is_empty() {
461                output.push_str(&stmt_str);
462            }
463        }
464
465        Ok(output)
466    }
467
468    /// Transpile a single statement.
469    fn transpile_stmt(&mut self, stmt: &Stmt) -> Result<String> {
470        match stmt {
471            Stmt::Local(local) => {
472                let indent = self.indent_str();
473
474                // Get variable name
475                let var_name = match &local.pat {
476                    Pat::Ident(ident) => ident.ident.to_string(),
477                    Pat::Type(pat_type) => {
478                        if let Pat::Ident(ident) = pat_type.pat.as_ref() {
479                            ident.ident.to_string()
480                        } else {
481                            return Err(TranspileError::Unsupported(
482                                "Complex pattern in let binding".into(),
483                            ));
484                        }
485                    }
486                    _ => {
487                        return Err(TranspileError::Unsupported(
488                            "Complex pattern in let binding".into(),
489                        ))
490                    }
491                };
492
493                // Check for SharedTile or SharedArray type annotation
494                if let Some(shared_decl) = self.try_parse_shared_declaration(local, &var_name)? {
495                    // Register the shared variable
496                    self.shared_vars.insert(
497                        var_name.clone(),
498                        SharedVarInfo {
499                            name: var_name.clone(),
500                            is_tile: shared_decl.dimensions.len() == 2,
501                            dimensions: shared_decl.dimensions.clone(),
502                            element_type: shared_decl.element_type.clone(),
503                        },
504                    );
505
506                    // Add to shared memory config
507                    self.shared_memory.add(shared_decl.clone());
508
509                    // Return the __shared__ declaration
510                    return Ok(format!("{indent}{}\n", shared_decl.to_cuda_decl()));
511                }
512
513                // Get initializer
514                if let Some(init) = &local.init {
515                    let expr_str = self.transpile_expr(&init.expr)?;
516
517                    // Infer type from expression (default to float for now)
518                    // In a more complete implementation, we'd do proper type inference
519                    let type_str = self.infer_cuda_type(&init.expr);
520
521                    // Track if this is a pointer variable for -> access
522                    if type_str.ends_with('*') {
523                        self.pointer_vars.insert(var_name.clone());
524                    }
525
526                    Ok(format!("{indent}{type_str} {var_name} = {expr_str};\n"))
527                } else {
528                    // Uninitialized variable - need type annotation
529                    Ok(format!("{indent}float {var_name};\n"))
530                }
531            }
532            Stmt::Expr(expr, semi) => {
533                let indent = self.indent_str();
534
535                // Check if this is an if statement with early return (shouldn't wrap in return)
536                if let Expr::If(if_expr) = expr {
537                    // Check if the body contains only a return statement
538                    if let Some(Stmt::Expr(Expr::Return(_), _)) = if_expr.then_branch.stmts.first()
539                    {
540                        if if_expr.then_branch.stmts.len() == 1 && if_expr.else_branch.is_none() {
541                            let expr_str = self.transpile_expr(expr)?;
542                            return Ok(format!("{indent}{expr_str};\n"));
543                        }
544                    }
545                }
546
547                let expr_str = self.transpile_expr(expr)?;
548
549                if semi.is_some() {
550                    Ok(format!("{indent}{expr_str};\n"))
551                } else {
552                    // Expression without semicolon (implicit return)
553                    // But check if it's already a return or an if with return
554                    if matches!(expr, Expr::Return(_))
555                        || expr_str.starts_with("return")
556                        || expr_str.starts_with("if (")
557                    {
558                        Ok(format!("{indent}{expr_str};\n"))
559                    } else {
560                        Ok(format!("{indent}return {expr_str};\n"))
561                    }
562                }
563            }
564            Stmt::Item(_) => {
565                // Items in function body not supported
566                Err(TranspileError::Unsupported("Item in function body".into()))
567            }
568            Stmt::Macro(_) => Err(TranspileError::Unsupported("Macro in function body".into())),
569        }
570    }
571
572    /// Transpile an expression.
573    fn transpile_expr(&self, expr: &Expr) -> Result<String> {
574        match expr {
575            Expr::Lit(lit) => self.transpile_lit(lit),
576            Expr::Path(path) => self.transpile_path(path),
577            Expr::Binary(bin) => self.transpile_binary(bin),
578            Expr::Unary(unary) => self.transpile_unary(unary),
579            Expr::Paren(paren) => self.transpile_paren(paren),
580            Expr::Index(index) => self.transpile_index(index),
581            Expr::Call(call) => self.transpile_call(call),
582            Expr::MethodCall(method) => self.transpile_method_call(method),
583            Expr::If(if_expr) => self.transpile_if(if_expr),
584            Expr::Assign(assign) => self.transpile_assign(assign),
585            Expr::Cast(cast) => self.transpile_cast(cast),
586            Expr::Match(match_expr) => self.transpile_match(match_expr),
587            Expr::Block(block) => {
588                // For block expressions, we just return the last expression
589                if let Some(Stmt::Expr(expr, None)) = block.block.stmts.last() {
590                    self.transpile_expr(expr)
591                } else {
592                    Err(TranspileError::Unsupported(
593                        "Complex block expression".into(),
594                    ))
595                }
596            }
597            Expr::Field(field) => {
598                // Struct field access: obj.field -> obj.field or obj->field for pointers
599                let base = self.transpile_expr(&field.base)?;
600                let member = match &field.member {
601                    syn::Member::Named(ident) => ident.to_string(),
602                    syn::Member::Unnamed(idx) => idx.index.to_string(),
603                };
604
605                // Check if base is a pointer variable - use -> instead of .
606                let accessor = if self.pointer_vars.contains(&base) {
607                    "->"
608                } else {
609                    "."
610                };
611                Ok(format!("{base}{accessor}{member}"))
612            }
613            Expr::Return(ret) => self.transpile_return(ret),
614            Expr::ForLoop(for_loop) => self.transpile_for_loop(for_loop),
615            Expr::While(while_loop) => self.transpile_while_loop(while_loop),
616            Expr::Loop(loop_expr) => self.transpile_infinite_loop(loop_expr),
617            Expr::Break(break_expr) => self.transpile_break(break_expr),
618            Expr::Continue(cont_expr) => self.transpile_continue(cont_expr),
619            Expr::Struct(struct_expr) => self.transpile_struct_literal(struct_expr),
620            Expr::Reference(ref_expr) => self.transpile_reference(ref_expr),
621            Expr::Let(let_expr) => self.transpile_let_expr(let_expr),
622            Expr::Tuple(tuple) => {
623                // Transpile tuple as a comma-separated list (for multi-value returns, etc.)
624                let elements: Vec<String> = tuple
625                    .elems
626                    .iter()
627                    .map(|e| self.transpile_expr(e))
628                    .collect::<Result<_>>()?;
629                Ok(format!("({})", elements.join(", ")))
630            }
631            _ => Err(TranspileError::Unsupported(format!(
632                "Expression type: {}",
633                expr.to_token_stream()
634            ))),
635        }
636    }
637
638    /// Transpile a literal.
639    fn transpile_lit(&self, lit: &ExprLit) -> Result<String> {
640        match &lit.lit {
641            Lit::Float(f) => {
642                let s = f.to_string();
643                // Ensure f32 literals have 'f' suffix in CUDA
644                if s.ends_with("f32") || !s.contains('.') {
645                    let num = s.trim_end_matches("f32").trim_end_matches("f64");
646                    Ok(format!("{num}f"))
647                } else if s.ends_with("f64") {
648                    Ok(s.trim_end_matches("f64").to_string())
649                } else {
650                    // Plain float literal - add 'f' suffix for float
651                    Ok(format!("{s}f"))
652                }
653            }
654            Lit::Int(i) => Ok(i.to_string()),
655            Lit::Bool(b) => Ok(if b.value { "1" } else { "0" }.to_string()),
656            _ => Err(TranspileError::Unsupported(format!(
657                "Literal type: {}",
658                lit.to_token_stream()
659            ))),
660        }
661    }
662
663    /// Transpile a path (variable reference).
664    fn transpile_path(&self, path: &ExprPath) -> Result<String> {
665        let segments: Vec<_> = path
666            .path
667            .segments
668            .iter()
669            .map(|s| s.ident.to_string())
670            .collect();
671
672        if segments.len() == 1 {
673            Ok(segments[0].clone())
674        } else {
675            Ok(segments.join("::"))
676        }
677    }
678
679    /// Transpile a binary expression.
680    fn transpile_binary(&self, bin: &ExprBinary) -> Result<String> {
681        let left = self.transpile_expr(&bin.left)?;
682        let right = self.transpile_expr(&bin.right)?;
683
684        let op = match bin.op {
685            BinOp::Add(_) => "+",
686            BinOp::Sub(_) => "-",
687            BinOp::Mul(_) => "*",
688            BinOp::Div(_) => "/",
689            BinOp::Rem(_) => "%",
690            BinOp::And(_) => "&&",
691            BinOp::Or(_) => "||",
692            BinOp::BitXor(_) => "^",
693            BinOp::BitAnd(_) => "&",
694            BinOp::BitOr(_) => "|",
695            BinOp::Shl(_) => "<<",
696            BinOp::Shr(_) => ">>",
697            BinOp::Eq(_) => "==",
698            BinOp::Lt(_) => "<",
699            BinOp::Le(_) => "<=",
700            BinOp::Ne(_) => "!=",
701            BinOp::Ge(_) => ">=",
702            BinOp::Gt(_) => ">",
703            BinOp::AddAssign(_) => "+=",
704            BinOp::SubAssign(_) => "-=",
705            BinOp::MulAssign(_) => "*=",
706            BinOp::DivAssign(_) => "/=",
707            BinOp::RemAssign(_) => "%=",
708            BinOp::BitXorAssign(_) => "^=",
709            BinOp::BitAndAssign(_) => "&=",
710            BinOp::BitOrAssign(_) => "|=",
711            BinOp::ShlAssign(_) => "<<=",
712            BinOp::ShrAssign(_) => ">>=",
713            _ => {
714                return Err(TranspileError::Unsupported(format!(
715                    "Binary operator: {}",
716                    bin.to_token_stream()
717                )))
718            }
719        };
720
721        Ok(format!("{left} {op} {right}"))
722    }
723
724    /// Transpile a unary expression.
725    fn transpile_unary(&self, unary: &ExprUnary) -> Result<String> {
726        let expr = self.transpile_expr(&unary.expr)?;
727
728        let op = match unary.op {
729            UnOp::Neg(_) => "-",
730            UnOp::Not(_) => "!",
731            UnOp::Deref(_) => "*",
732            _ => {
733                return Err(TranspileError::Unsupported(format!(
734                    "Unary operator: {}",
735                    unary.to_token_stream()
736                )))
737            }
738        };
739
740        Ok(format!("{op}({expr})"))
741    }
742
743    /// Transpile a parenthesized expression.
744    fn transpile_paren(&self, paren: &ExprParen) -> Result<String> {
745        let inner = self.transpile_expr(&paren.expr)?;
746        Ok(format!("({inner})"))
747    }
748
749    /// Transpile an index expression.
750    fn transpile_index(&self, index: &ExprIndex) -> Result<String> {
751        let base = self.transpile_expr(&index.expr)?;
752        let idx = self.transpile_expr(&index.index)?;
753        Ok(format!("{base}[{idx}]"))
754    }
755
756    /// Transpile a function call.
757    fn transpile_call(&self, call: &ExprCall) -> Result<String> {
758        let func = self.transpile_expr(&call.func)?;
759
760        // Check for intrinsics
761        if let Some(intrinsic) = self.intrinsics.lookup(&func) {
762            let cuda_name = intrinsic.to_cuda_string();
763
764            // Check if this is a "value" intrinsic (like threadIdx.x) vs a "function" intrinsic
765            // Value intrinsics don't have parentheses in CUDA
766            let is_value_intrinsic = cuda_name.contains("Idx.")
767                || cuda_name.contains("Dim.")
768                || cuda_name.starts_with("threadIdx")
769                || cuda_name.starts_with("blockIdx")
770                || cuda_name.starts_with("blockDim")
771                || cuda_name.starts_with("gridDim");
772
773            if is_value_intrinsic && call.args.is_empty() {
774                // Value intrinsics: threadIdx.x, blockIdx.y, etc. - no parens
775                return Ok(cuda_name.to_string());
776            }
777
778            if call.args.is_empty() && cuda_name.ends_with("()") {
779                // Zero-arg function intrinsics: __syncthreads()
780                return Ok(cuda_name.to_string());
781            }
782
783            let args: Vec<String> = call
784                .args
785                .iter()
786                .map(|a| self.transpile_expr(a))
787                .collect::<Result<_>>()?;
788
789            return Ok(format!(
790                "{}({})",
791                cuda_name.trim_end_matches("()"),
792                args.join(", ")
793            ));
794        }
795
796        // Regular function call
797        let args: Vec<String> = call
798            .args
799            .iter()
800            .map(|a| self.transpile_expr(a))
801            .collect::<Result<_>>()?;
802
803        Ok(format!("{}({})", func, args.join(", ")))
804    }
805
806    /// Transpile a method call, handling stencil intrinsics.
807    fn transpile_method_call(&self, method: &ExprMethodCall) -> Result<String> {
808        let receiver = self.transpile_expr(&method.receiver)?;
809        let method_name = method.method.to_string();
810
811        // Check if this is a SharedTile/SharedArray method call
812        if let Some(result) =
813            self.try_transpile_shared_method_call(&receiver, &method_name, &method.args)
814        {
815            return result;
816        }
817
818        // Check if this is a RingContext method call (in ring kernel mode)
819        if self.ring_kernel_mode && self.context_vars.contains(&receiver) {
820            return self.transpile_context_method(&method_name, &method.args);
821        }
822
823        // Check if this is a GridPos method call
824        if self.grid_pos_vars.contains(&receiver) {
825            return self.transpile_stencil_intrinsic(&method_name, &method.args);
826        }
827
828        // Check for ring kernel intrinsics (like is_active(), should_terminate())
829        if self.ring_kernel_mode {
830            if let Some(intrinsic) = RingKernelIntrinsic::from_name(&method_name) {
831                let args: Vec<String> = method
832                    .args
833                    .iter()
834                    .map(|a| self.transpile_expr(a).unwrap_or_default())
835                    .collect();
836                return Ok(intrinsic.to_cuda(&args));
837            }
838        }
839
840        // Check for f32/f64 math methods
841        if let Some(intrinsic) = self.intrinsics.lookup(&method_name) {
842            let cuda_name = intrinsic.to_cuda_string();
843            let args: Vec<String> = std::iter::once(receiver)
844                .chain(
845                    method
846                        .args
847                        .iter()
848                        .map(|a| self.transpile_expr(a).unwrap_or_default()),
849                )
850                .collect();
851
852            return Ok(format!("{}({})", cuda_name, args.join(", ")));
853        }
854
855        // Regular method call (treat as field access + call for C structs)
856        let args: Vec<String> = method
857            .args
858            .iter()
859            .map(|a| self.transpile_expr(a))
860            .collect::<Result<_>>()?;
861
862        Ok(format!("{}.{}({})", receiver, method_name, args.join(", ")))
863    }
864
865    /// Transpile a RingContext method call to CUDA intrinsics.
866    fn transpile_context_method(
867        &self,
868        method: &str,
869        args: &syn::punctuated::Punctuated<Expr, syn::token::Comma>,
870    ) -> Result<String> {
871        let ctx_method = ContextMethod::from_name(method).ok_or_else(|| {
872            TranspileError::Unsupported(format!("Unknown context method: {}", method))
873        })?;
874
875        let cuda_args: Vec<String> = args
876            .iter()
877            .map(|a| self.transpile_expr(a).unwrap_or_default())
878            .collect();
879
880        Ok(ctx_method.to_cuda(&cuda_args))
881    }
882
883    /// Transpile a stencil intrinsic method call.
884    fn transpile_stencil_intrinsic(
885        &self,
886        method: &str,
887        args: &syn::punctuated::Punctuated<Expr, syn::token::Comma>,
888    ) -> Result<String> {
889        let config = self.config.as_ref().ok_or_else(|| {
890            TranspileError::Unsupported("Stencil intrinsic without config".into())
891        })?;
892
893        let buffer_width = config.buffer_width().to_string();
894        let buffer_slice = format!("{}", config.buffer_width() * config.buffer_height());
895        let is_3d = config.grid == crate::stencil::Grid::Grid3D;
896
897        let intrinsic = StencilIntrinsic::from_method_name(method).ok_or_else(|| {
898            TranspileError::Unsupported(format!("Unknown stencil intrinsic: {method}"))
899        })?;
900
901        // Check if 3D intrinsic used in non-3D kernel
902        if intrinsic.is_3d_only() && !is_3d {
903            return Err(TranspileError::Unsupported(format!(
904                "3D stencil intrinsic '{}' requires Grid3D configuration",
905                method
906            )));
907        }
908
909        match intrinsic {
910            StencilIntrinsic::Index => {
911                // pos.idx() -> idx
912                Ok("idx".to_string())
913            }
914            StencilIntrinsic::North
915            | StencilIntrinsic::South
916            | StencilIntrinsic::East
917            | StencilIntrinsic::West => {
918                // pos.north(buf) -> buf[idx - buffer_width]
919                if args.is_empty() {
920                    return Err(TranspileError::Unsupported(
921                        "Stencil accessor requires buffer argument".into(),
922                    ));
923                }
924                let buffer = self.transpile_expr(&args[0])?;
925                if is_3d {
926                    Ok(intrinsic.to_cuda_index_3d(&buffer, &buffer_width, &buffer_slice, "idx"))
927                } else {
928                    Ok(intrinsic.to_cuda_index_2d(&buffer, &buffer_width, "idx"))
929                }
930            }
931            StencilIntrinsic::Up | StencilIntrinsic::Down => {
932                // 3D intrinsics: pos.up(buf) -> buf[idx - buffer_slice]
933                if args.is_empty() {
934                    return Err(TranspileError::Unsupported(
935                        "3D stencil accessor requires buffer argument".into(),
936                    ));
937                }
938                let buffer = self.transpile_expr(&args[0])?;
939                Ok(intrinsic.to_cuda_index_3d(&buffer, &buffer_width, &buffer_slice, "idx"))
940            }
941            StencilIntrinsic::At => {
942                // 2D: pos.at(buf, dx, dy) -> buf[idx + dy * buffer_width + dx]
943                // 3D: pos.at(buf, dx, dy, dz) -> buf[idx + dz * buffer_slice + dy * buffer_width + dx]
944                if is_3d {
945                    if args.len() < 4 {
946                        return Err(TranspileError::Unsupported(
947                            "at() in 3D requires buffer, dx, dy, dz arguments".into(),
948                        ));
949                    }
950                    let buffer = self.transpile_expr(&args[0])?;
951                    let dx = self.transpile_expr(&args[1])?;
952                    let dy = self.transpile_expr(&args[2])?;
953                    let dz = self.transpile_expr(&args[3])?;
954                    Ok(format!(
955                        "{buffer}[idx + ({dz}) * {buffer_slice} + ({dy}) * {buffer_width} + ({dx})]"
956                    ))
957                } else {
958                    if args.len() < 3 {
959                        return Err(TranspileError::Unsupported(
960                            "at() requires buffer, dx, dy arguments".into(),
961                        ));
962                    }
963                    let buffer = self.transpile_expr(&args[0])?;
964                    let dx = self.transpile_expr(&args[1])?;
965                    let dy = self.transpile_expr(&args[2])?;
966                    Ok(format!("{buffer}[idx + ({dy}) * {buffer_width} + ({dx})]"))
967                }
968            }
969        }
970    }
971
972    /// Transpile an if expression.
973    fn transpile_if(&self, if_expr: &ExprIf) -> Result<String> {
974        let cond = self.transpile_expr(&if_expr.cond)?;
975
976        // Check if the body contains only a return statement (early return pattern)
977        if let Some(Stmt::Expr(Expr::Return(ret), _)) = if_expr.then_branch.stmts.first() {
978            if if_expr.then_branch.stmts.len() == 1 && if_expr.else_branch.is_none() {
979                // Simple early return: if (cond) return;
980                if ret.expr.is_none() {
981                    return Ok(format!("if ({cond}) return"));
982                }
983                let ret_val = self.transpile_expr(ret.expr.as_ref().unwrap())?;
984                return Ok(format!("if ({cond}) return {ret_val}"));
985            }
986        }
987
988        // For now, only handle if-else as ternary when it's an expression
989        if let Some((_, else_branch)) = &if_expr.else_branch {
990            // If both branches are simple expressions, use ternary
991            if let (Some(Stmt::Expr(then_expr, None)), Expr::Block(else_block)) =
992                (if_expr.then_branch.stmts.last(), else_branch.as_ref())
993            {
994                if let Some(Stmt::Expr(else_expr, None)) = else_block.block.stmts.last() {
995                    let then_str = self.transpile_expr(then_expr)?;
996                    let else_str = self.transpile_expr(else_expr)?;
997                    return Ok(format!("({cond}) ? ({then_str}) : ({else_str})"));
998                }
999            }
1000
1001            // Otherwise, generate if statement
1002            if let Expr::If(else_if) = else_branch.as_ref() {
1003                // else if chain
1004                let then_body = self.transpile_if_body(&if_expr.then_branch)?;
1005                let else_part = self.transpile_if(else_if)?;
1006                return Ok(format!("if ({cond}) {{{then_body}}} else {else_part}"));
1007            } else if let Expr::Block(else_block) = else_branch.as_ref() {
1008                // else block
1009                let then_body = self.transpile_if_body(&if_expr.then_branch)?;
1010                let else_body = self.transpile_if_body(&else_block.block)?;
1011                return Ok(format!("if ({cond}) {{{then_body}}} else {{{else_body}}}"));
1012            }
1013        }
1014
1015        // If without else
1016        let then_body = self.transpile_if_body(&if_expr.then_branch)?;
1017        Ok(format!("if ({cond}) {{{then_body}}}"))
1018    }
1019
1020    /// Transpile the body of an if branch.
1021    fn transpile_if_body(&self, block: &syn::Block) -> Result<String> {
1022        let mut body = String::new();
1023        for stmt in &block.stmts {
1024            match stmt {
1025                Stmt::Expr(expr, Some(_)) => {
1026                    let expr_str = self.transpile_expr(expr)?;
1027                    body.push_str(&format!(" {expr_str};"));
1028                }
1029                Stmt::Expr(Expr::Return(ret), None) => {
1030                    // Handle explicit return without semicolon
1031                    if let Some(ret_expr) = &ret.expr {
1032                        let expr_str = self.transpile_expr(ret_expr)?;
1033                        body.push_str(&format!(" return {expr_str};"));
1034                    } else {
1035                        body.push_str(" return;");
1036                    }
1037                }
1038                Stmt::Expr(expr, None) => {
1039                    let expr_str = self.transpile_expr(expr)?;
1040                    body.push_str(&format!(" return {expr_str};"));
1041                }
1042                _ => {}
1043            }
1044        }
1045        Ok(body)
1046    }
1047
1048    /// Transpile an assignment expression.
1049    fn transpile_assign(&self, assign: &ExprAssign) -> Result<String> {
1050        let left = self.transpile_expr(&assign.left)?;
1051        let right = self.transpile_expr(&assign.right)?;
1052        Ok(format!("{left} = {right}"))
1053    }
1054
1055    /// Transpile a cast expression.
1056    fn transpile_cast(&self, cast: &ExprCast) -> Result<String> {
1057        let expr = self.transpile_expr(&cast.expr)?;
1058        let cuda_type = self.type_mapper.map_type(&cast.ty)?;
1059        Ok(format!("({})({})", cuda_type.to_cuda_string(), expr))
1060    }
1061
1062    /// Transpile a return expression.
1063    fn transpile_return(&self, ret: &ExprReturn) -> Result<String> {
1064        if let Some(expr) = &ret.expr {
1065            let expr_str = self.transpile_expr(expr)?;
1066            Ok(format!("return {expr_str}"))
1067        } else {
1068            Ok("return".to_string())
1069        }
1070    }
1071
1072    /// Transpile a struct literal expression.
1073    ///
1074    /// Converts Rust struct literals to C-style compound literals:
1075    /// `Point { x: 1.0, y: 2.0 }` -> `(Point){ .x = 1.0f, .y = 2.0f }`
1076    fn transpile_struct_literal(&self, struct_expr: &ExprStruct) -> Result<String> {
1077        // Get the struct type name
1078        let type_name = struct_expr
1079            .path
1080            .segments
1081            .iter()
1082            .map(|s| s.ident.to_string())
1083            .collect::<Vec<_>>()
1084            .join("::");
1085
1086        // Transpile each field
1087        let mut fields = Vec::new();
1088        for field in &struct_expr.fields {
1089            let field_name = match &field.member {
1090                syn::Member::Named(ident) => ident.to_string(),
1091                syn::Member::Unnamed(idx) => idx.index.to_string(),
1092            };
1093            let value = self.transpile_expr(&field.expr)?;
1094            fields.push(format!(".{} = {}", field_name, value));
1095        }
1096
1097        // Check for struct update syntax (not supported in C)
1098        if struct_expr.rest.is_some() {
1099            return Err(TranspileError::Unsupported(
1100                "Struct update syntax (..base) is not supported in CUDA".into(),
1101            ));
1102        }
1103
1104        // Generate C compound literal: (TypeName){ .field1 = val1, .field2 = val2 }
1105        Ok(format!("({}){{ {} }}", type_name, fields.join(", ")))
1106    }
1107
1108    /// Transpile a reference expression.
1109    ///
1110    /// In CUDA C, we typically need pointers. This handles:
1111    /// - `&arr[idx]` -> `&arr[idx]` (pointer to element)
1112    /// - `&mut arr[idx]` -> `&arr[idx]` (same in C)
1113    /// - `&variable` -> `&variable`
1114    fn transpile_reference(&self, ref_expr: &ExprReference) -> Result<String> {
1115        let inner = self.transpile_expr(&ref_expr.expr)?;
1116
1117        // In CUDA C, taking a reference is the same as taking address
1118        // For array indexing like &arr[idx], we produce &arr[idx]
1119        // This creates a pointer to that element
1120        Ok(format!("&{inner}"))
1121    }
1122
1123    /// Transpile a let expression (used in if-let patterns).
1124    ///
1125    /// Note: Full pattern matching is not supported in CUDA, but we can
1126    /// handle simple `if let Some(x) = expr` patterns by transpiling
1127    /// to a simple conditional check.
1128    fn transpile_let_expr(&self, let_expr: &ExprLet) -> Result<String> {
1129        // For now, we treat let expressions as unsupported since
1130        // CUDA doesn't have Option types. But we can add special cases.
1131        let _ = let_expr; // Silence unused warning
1132        Err(TranspileError::Unsupported(
1133            "let expressions (if-let patterns) are not directly supported in CUDA. \
1134             Use explicit comparisons instead."
1135                .into(),
1136        ))
1137    }
1138
1139    // === Loop Transpilation ===
1140
1141    /// Transpile a for loop to CUDA.
1142    ///
1143    /// Handles `for i in start..end` and `for i in start..=end` patterns.
1144    ///
1145    /// # Example
1146    ///
1147    /// ```ignore
1148    /// // Rust
1149    /// for i in 0..n {
1150    ///     data[i] = 0.0;
1151    /// }
1152    ///
1153    /// // CUDA
1154    /// for (int i = 0; i < n; i++) {
1155    ///     data[i] = 0.0f;
1156    /// }
1157    /// ```
1158    fn transpile_for_loop(&self, for_loop: &ExprForLoop) -> Result<String> {
1159        // Check if loops are allowed
1160        if !self.validation_mode.allows_loops() {
1161            return Err(TranspileError::Unsupported(
1162                "Loops are not allowed in stencil kernels".into(),
1163            ));
1164        }
1165
1166        // Extract loop variable name
1167        let var_name = extract_loop_var(&for_loop.pat)
1168            .ok_or_else(|| TranspileError::Unsupported("Complex pattern in for loop".into()))?;
1169
1170        // The iterator expression should be a range
1171        let header = match for_loop.expr.as_ref() {
1172            Expr::Range(range) => {
1173                let range_info = RangeInfo::from_range(range, |e| self.transpile_expr(e));
1174                range_info.to_cuda_for_header(&var_name, "int")
1175            }
1176            _ => {
1177                // For non-range iterators, we can't directly transpile
1178                return Err(TranspileError::Unsupported(
1179                    "Only range expressions (start..end) are supported in for loops".into(),
1180                ));
1181            }
1182        };
1183
1184        // Transpile the loop body
1185        let body = self.transpile_loop_body(&for_loop.body)?;
1186
1187        Ok(format!("{header} {{\n{body}}}"))
1188    }
1189
1190    /// Transpile a while loop to CUDA.
1191    ///
1192    /// # Example
1193    ///
1194    /// ```ignore
1195    /// // Rust
1196    /// while !done {
1197    ///     process();
1198    /// }
1199    ///
1200    /// // CUDA
1201    /// while (!done) {
1202    ///     process();
1203    /// }
1204    /// ```
1205    fn transpile_while_loop(&self, while_loop: &ExprWhile) -> Result<String> {
1206        // Check if loops are allowed
1207        if !self.validation_mode.allows_loops() {
1208            return Err(TranspileError::Unsupported(
1209                "Loops are not allowed in stencil kernels".into(),
1210            ));
1211        }
1212
1213        // Transpile the condition
1214        let condition = self.transpile_expr(&while_loop.cond)?;
1215
1216        // Transpile the loop body
1217        let body = self.transpile_loop_body(&while_loop.body)?;
1218
1219        Ok(format!("while ({condition}) {{\n{body}}}"))
1220    }
1221
1222    /// Transpile an infinite loop to CUDA.
1223    ///
1224    /// # Example
1225    ///
1226    /// ```ignore
1227    /// // Rust
1228    /// loop {
1229    ///     if should_exit { break; }
1230    /// }
1231    ///
1232    /// // CUDA
1233    /// while (true) {
1234    ///     if (should_exit) { break; }
1235    /// }
1236    /// ```
1237    fn transpile_infinite_loop(&self, loop_expr: &ExprLoop) -> Result<String> {
1238        // Check if loops are allowed
1239        if !self.validation_mode.allows_loops() {
1240            return Err(TranspileError::Unsupported(
1241                "Loops are not allowed in stencil kernels".into(),
1242            ));
1243        }
1244
1245        // Transpile the loop body
1246        let body = self.transpile_loop_body(&loop_expr.body)?;
1247
1248        // Use while(true) for infinite loops
1249        Ok(format!("while (true) {{\n{body}}}"))
1250    }
1251
1252    /// Transpile a break expression.
1253    fn transpile_break(&self, break_expr: &ExprBreak) -> Result<String> {
1254        // Check for labeled break (not supported)
1255        if break_expr.label.is_some() {
1256            return Err(TranspileError::Unsupported(
1257                "Labeled break is not supported in CUDA".into(),
1258            ));
1259        }
1260
1261        // Check for break with value (not supported in CUDA)
1262        if break_expr.expr.is_some() {
1263            return Err(TranspileError::Unsupported(
1264                "Break with value is not supported in CUDA".into(),
1265            ));
1266        }
1267
1268        Ok("break".to_string())
1269    }
1270
1271    /// Transpile a continue expression.
1272    fn transpile_continue(&self, cont_expr: &ExprContinue) -> Result<String> {
1273        // Check for labeled continue (not supported)
1274        if cont_expr.label.is_some() {
1275            return Err(TranspileError::Unsupported(
1276                "Labeled continue is not supported in CUDA".into(),
1277            ));
1278        }
1279
1280        Ok("continue".to_string())
1281    }
1282
1283    /// Transpile a loop body (block of statements).
1284    fn transpile_loop_body(&self, block: &syn::Block) -> Result<String> {
1285        let mut output = String::new();
1286        let inner_indent = "    ".repeat(self.indent + 1);
1287
1288        for stmt in &block.stmts {
1289            match stmt {
1290                Stmt::Local(local) => {
1291                    // Variable declaration
1292                    let var_name = match &local.pat {
1293                        Pat::Ident(ident) => ident.ident.to_string(),
1294                        Pat::Type(pat_type) => {
1295                            if let Pat::Ident(ident) = pat_type.pat.as_ref() {
1296                                ident.ident.to_string()
1297                            } else {
1298                                return Err(TranspileError::Unsupported(
1299                                    "Complex pattern in let binding".into(),
1300                                ));
1301                            }
1302                        }
1303                        _ => {
1304                            return Err(TranspileError::Unsupported(
1305                                "Complex pattern in let binding".into(),
1306                            ))
1307                        }
1308                    };
1309
1310                    if let Some(init) = &local.init {
1311                        let expr_str = self.transpile_expr(&init.expr)?;
1312                        let type_str = self.infer_cuda_type(&init.expr);
1313                        output.push_str(&format!(
1314                            "{inner_indent}{type_str} {var_name} = {expr_str};\n"
1315                        ));
1316                    } else {
1317                        output.push_str(&format!("{inner_indent}float {var_name};\n"));
1318                    }
1319                }
1320                Stmt::Expr(expr, semi) => {
1321                    let expr_str = self.transpile_expr(expr)?;
1322                    if semi.is_some() {
1323                        output.push_str(&format!("{inner_indent}{expr_str};\n"));
1324                    } else {
1325                        // Expression without semicolon at end of block
1326                        output.push_str(&format!("{inner_indent}{expr_str};\n"));
1327                    }
1328                }
1329                _ => {
1330                    return Err(TranspileError::Unsupported(
1331                        "Unsupported statement in loop body".into(),
1332                    ));
1333                }
1334            }
1335        }
1336
1337        // Add closing indentation
1338        let closing_indent = "    ".repeat(self.indent);
1339        output.push_str(&closing_indent);
1340
1341        Ok(output)
1342    }
1343
1344    // === Shared Memory Support ===
1345
1346    /// Try to parse a local variable declaration as a shared memory declaration.
1347    ///
1348    /// Recognizes patterns like:
1349    /// - `let tile = SharedTile::<f32, 16, 16>::new();`
1350    /// - `let buffer = SharedArray::<f32, 256>::new();`
1351    /// - `let tile: SharedTile<f32, 16, 16> = SharedTile::new();`
1352    fn try_parse_shared_declaration(
1353        &self,
1354        local: &syn::Local,
1355        var_name: &str,
1356    ) -> Result<Option<SharedMemoryDecl>> {
1357        // Check if there's a type annotation
1358        if let Pat::Type(pat_type) = &local.pat {
1359            let type_str = pat_type.ty.to_token_stream().to_string();
1360            return self.parse_shared_type(&type_str, var_name);
1361        }
1362
1363        // Check the initializer expression for SharedTile::new() or SharedArray::new()
1364        if let Some(init) = &local.init {
1365            if let Expr::Call(call) = init.expr.as_ref() {
1366                if let Expr::Path(path) = call.func.as_ref() {
1367                    let path_str = path.to_token_stream().to_string();
1368                    return self.parse_shared_type(&path_str, var_name);
1369                }
1370            }
1371        }
1372
1373        Ok(None)
1374    }
1375
1376    /// Parse a type string to extract shared memory info.
1377    fn parse_shared_type(
1378        &self,
1379        type_str: &str,
1380        var_name: &str,
1381    ) -> Result<Option<SharedMemoryDecl>> {
1382        // Clean up the type string (remove spaces around ::)
1383        let type_str = type_str
1384            .replace(" :: ", "::")
1385            .replace(" ::", "::")
1386            .replace(":: ", "::");
1387
1388        // Check for SharedTile<T, W, H> or SharedTile::<T, W, H>::new
1389        if type_str.contains("SharedTile") {
1390            // Extract the generic parameters
1391            if let Some(start) = type_str.find('<') {
1392                if let Some(end) = type_str.rfind('>') {
1393                    let params = &type_str[start + 1..end];
1394                    let parts: Vec<&str> = params.split(',').map(|s| s.trim()).collect();
1395
1396                    if parts.len() >= 3 {
1397                        let rust_type = parts[0];
1398                        let width: usize = parts[1].parse().map_err(|_| {
1399                            TranspileError::Unsupported("Invalid SharedTile width".into())
1400                        })?;
1401                        let height: usize = parts[2].parse().map_err(|_| {
1402                            TranspileError::Unsupported("Invalid SharedTile height".into())
1403                        })?;
1404
1405                        let cuda_type = rust_to_cuda_element_type(rust_type);
1406                        return Ok(Some(SharedMemoryDecl::tile(
1407                            var_name, cuda_type, width, height,
1408                        )));
1409                    }
1410                }
1411            }
1412        }
1413
1414        // Check for SharedArray<T, N> or SharedArray::<T, N>::new
1415        if type_str.contains("SharedArray") {
1416            if let Some(start) = type_str.find('<') {
1417                if let Some(end) = type_str.rfind('>') {
1418                    let params = &type_str[start + 1..end];
1419                    let parts: Vec<&str> = params.split(',').map(|s| s.trim()).collect();
1420
1421                    if parts.len() >= 2 {
1422                        let rust_type = parts[0];
1423                        let size: usize = parts[1].parse().map_err(|_| {
1424                            TranspileError::Unsupported("Invalid SharedArray size".into())
1425                        })?;
1426
1427                        let cuda_type = rust_to_cuda_element_type(rust_type);
1428                        return Ok(Some(SharedMemoryDecl::array(var_name, cuda_type, size)));
1429                    }
1430                }
1431            }
1432        }
1433
1434        Ok(None)
1435    }
1436
1437    /// Check if a variable is a shared memory variable and handle method calls.
1438    fn try_transpile_shared_method_call(
1439        &self,
1440        receiver: &str,
1441        method_name: &str,
1442        args: &syn::punctuated::Punctuated<Expr, syn::token::Comma>,
1443    ) -> Option<Result<String>> {
1444        let shared_info = self.shared_vars.get(receiver)?;
1445
1446        match method_name {
1447            "get" => {
1448                // tile.get(x, y) -> tile[y][x] (for 2D) or arr[idx] (for 1D)
1449                if shared_info.is_tile {
1450                    if args.len() >= 2 {
1451                        let x = self.transpile_expr(&args[0]).ok()?;
1452                        let y = self.transpile_expr(&args[1]).ok()?;
1453                        // CUDA uses row-major: tile[row][col] = tile[y][x]
1454                        Some(Ok(format!("{}[{}][{}]", receiver, y, x)))
1455                    } else {
1456                        Some(Err(TranspileError::Unsupported(
1457                            "SharedTile.get requires x and y arguments".into(),
1458                        )))
1459                    }
1460                } else {
1461                    // 1D array
1462                    if !args.is_empty() {
1463                        let idx = self.transpile_expr(&args[0]).ok()?;
1464                        Some(Ok(format!("{}[{}]", receiver, idx)))
1465                    } else {
1466                        Some(Err(TranspileError::Unsupported(
1467                            "SharedArray.get requires index argument".into(),
1468                        )))
1469                    }
1470                }
1471            }
1472            "set" => {
1473                // tile.set(x, y, val) -> tile[y][x] = val
1474                if shared_info.is_tile {
1475                    if args.len() >= 3 {
1476                        let x = self.transpile_expr(&args[0]).ok()?;
1477                        let y = self.transpile_expr(&args[1]).ok()?;
1478                        let val = self.transpile_expr(&args[2]).ok()?;
1479                        Some(Ok(format!("{}[{}][{}] = {}", receiver, y, x, val)))
1480                    } else {
1481                        Some(Err(TranspileError::Unsupported(
1482                            "SharedTile.set requires x, y, and value arguments".into(),
1483                        )))
1484                    }
1485                } else {
1486                    // 1D array
1487                    if args.len() >= 2 {
1488                        let idx = self.transpile_expr(&args[0]).ok()?;
1489                        let val = self.transpile_expr(&args[1]).ok()?;
1490                        Some(Ok(format!("{}[{}] = {}", receiver, idx, val)))
1491                    } else {
1492                        Some(Err(TranspileError::Unsupported(
1493                            "SharedArray.set requires index and value arguments".into(),
1494                        )))
1495                    }
1496                }
1497            }
1498            "width" | "height" | "size" => {
1499                // These are compile-time constants
1500                match method_name {
1501                    "width" if shared_info.is_tile => {
1502                        Some(Ok(shared_info.dimensions[1].to_string()))
1503                    }
1504                    "height" if shared_info.is_tile => {
1505                        Some(Ok(shared_info.dimensions[0].to_string()))
1506                    }
1507                    "size" => {
1508                        let total: usize = shared_info.dimensions.iter().product();
1509                        Some(Ok(total.to_string()))
1510                    }
1511                    _ => None,
1512                }
1513            }
1514            _ => None,
1515        }
1516    }
1517
1518    /// Transpile a match expression to switch/case.
1519    fn transpile_match(&self, match_expr: &ExprMatch) -> Result<String> {
1520        let scrutinee = self.transpile_expr(&match_expr.expr)?;
1521        let mut output = format!("switch ({scrutinee}) {{\n");
1522
1523        for arm in &match_expr.arms {
1524            // Handle the pattern
1525            let case_label = self.transpile_match_pattern(&arm.pat)?;
1526
1527            if case_label == "default" || case_label.starts_with("/*") {
1528                output.push_str("        default: {\n");
1529            } else {
1530                output.push_str(&format!("        case {case_label}: {{\n"));
1531            }
1532
1533            // Handle the arm body - check if it's a block with statements
1534            match arm.body.as_ref() {
1535                Expr::Block(block) => {
1536                    // Block expression with multiple statements
1537                    for stmt in &block.block.stmts {
1538                        let stmt_str = self.transpile_stmt_inline(stmt)?;
1539                        output.push_str(&format!("            {stmt_str}\n"));
1540                    }
1541                }
1542                _ => {
1543                    // Single expression - wrap in statement
1544                    let body = self.transpile_expr(&arm.body)?;
1545                    output.push_str(&format!("            {body};\n"));
1546                }
1547            }
1548
1549            output.push_str("            break;\n");
1550            output.push_str("        }\n");
1551        }
1552
1553        output.push_str("    }");
1554        Ok(output)
1555    }
1556
1557    /// Transpile a match pattern to a case label.
1558    fn transpile_match_pattern(&self, pat: &Pat) -> Result<String> {
1559        match pat {
1560            Pat::Lit(pat_lit) => {
1561                // Integer literal pattern - pat_lit.lit contains the literal
1562                match &pat_lit.lit {
1563                    Lit::Int(i) => Ok(i.to_string()),
1564                    Lit::Bool(b) => Ok(if b.value { "1" } else { "0" }.to_string()),
1565                    _ => Err(TranspileError::Unsupported(
1566                        "Non-integer literal in match pattern".into(),
1567                    )),
1568                }
1569            }
1570            Pat::Wild(_) => {
1571                // _ pattern becomes default
1572                Ok("default".to_string())
1573            }
1574            Pat::Ident(ident) => {
1575                // Named pattern - treat as default case for now
1576                // This handles things like `x => ...` which bind a value
1577                Ok(format!("/* {} */ default", ident.ident))
1578            }
1579            Pat::Or(pat_or) => {
1580                // Multiple patterns: 0 | 1 | 2 => ...
1581                // CUDA switch doesn't support this directly, we need multiple case labels
1582                // For now, just use the first pattern and note the limitation
1583                if let Some(first) = pat_or.cases.first() {
1584                    self.transpile_match_pattern(first)
1585                } else {
1586                    Err(TranspileError::Unsupported("Empty or pattern".into()))
1587                }
1588            }
1589            _ => Err(TranspileError::Unsupported(format!(
1590                "Match pattern: {}",
1591                pat.to_token_stream()
1592            ))),
1593        }
1594    }
1595
1596    /// Transpile a statement without indentation (for inline use in switch).
1597    fn transpile_stmt_inline(&self, stmt: &Stmt) -> Result<String> {
1598        match stmt {
1599            Stmt::Local(local) => {
1600                let var_name = match &local.pat {
1601                    Pat::Ident(ident) => ident.ident.to_string(),
1602                    Pat::Type(pat_type) => {
1603                        if let Pat::Ident(ident) = pat_type.pat.as_ref() {
1604                            ident.ident.to_string()
1605                        } else {
1606                            return Err(TranspileError::Unsupported(
1607                                "Complex pattern in let binding".into(),
1608                            ));
1609                        }
1610                    }
1611                    _ => {
1612                        return Err(TranspileError::Unsupported(
1613                            "Complex pattern in let binding".into(),
1614                        ))
1615                    }
1616                };
1617
1618                if let Some(init) = &local.init {
1619                    let expr_str = self.transpile_expr(&init.expr)?;
1620                    let type_str = self.infer_cuda_type(&init.expr);
1621                    Ok(format!("{type_str} {var_name} = {expr_str};"))
1622                } else {
1623                    Ok(format!("float {var_name};"))
1624                }
1625            }
1626            Stmt::Expr(expr, semi) => {
1627                let expr_str = self.transpile_expr(expr)?;
1628                if semi.is_some() {
1629                    Ok(format!("{expr_str};"))
1630                } else {
1631                    Ok(format!("return {expr_str};"))
1632                }
1633            }
1634            _ => Err(TranspileError::Unsupported(
1635                "Unsupported statement in match arm".into(),
1636            )),
1637        }
1638    }
1639
1640    /// Infer CUDA type from expression (simple heuristic).
1641    fn infer_cuda_type(&self, expr: &Expr) -> &'static str {
1642        match expr {
1643            Expr::Lit(lit) => match &lit.lit {
1644                Lit::Float(_) => "float",
1645                Lit::Int(_) => "int",
1646                Lit::Bool(_) => "int",
1647                _ => "float",
1648            },
1649            Expr::Binary(bin) => {
1650                // Check if this is an integer operation
1651                let left_type = self.infer_cuda_type(&bin.left);
1652                let right_type = self.infer_cuda_type(&bin.right);
1653                // If both sides are int, result is int
1654                if left_type == "int" && right_type == "int" {
1655                    "int"
1656                } else {
1657                    "float"
1658                }
1659            }
1660            Expr::Call(call) => {
1661                // Check for intrinsics that return int
1662                if let Ok(func) = self.transpile_expr(&call.func) {
1663                    if let Some(intrinsic) = self.intrinsics.lookup(&func) {
1664                        let cuda_name = intrinsic.to_cuda_string();
1665                        // Thread/block indices return int
1666                        if cuda_name.contains("Idx") || cuda_name.contains("Dim") {
1667                            return "int";
1668                        }
1669                    }
1670                }
1671                "float"
1672            }
1673            Expr::Index(_) => "float", // Array access - could be any type, default to float
1674            Expr::Cast(cast) => {
1675                // Use the target type of the cast
1676                if let Ok(cuda_type) = self.type_mapper.map_type(&cast.ty) {
1677                    let s = cuda_type.to_cuda_string();
1678                    if s.contains("int") || s.contains("size_t") || s == "unsigned long long" {
1679                        return "int";
1680                    }
1681                }
1682                "float"
1683            }
1684            Expr::Reference(ref_expr) => {
1685                // Reference expression - try to determine pointer type from inner
1686                // For &arr[idx], we get a pointer to the element type
1687                match ref_expr.expr.as_ref() {
1688                    Expr::Index(idx_expr) => {
1689                        // &arr[idx] - get the base array name and look it up
1690                        if let Expr::Path(path) = &*idx_expr.expr {
1691                            let name = path
1692                                .path
1693                                .segments
1694                                .iter()
1695                                .map(|s| s.ident.to_string())
1696                                .collect::<Vec<_>>()
1697                                .join("::");
1698                            // Common GPU struct names -> pointer to that struct
1699                            if name.contains("transaction") || name.contains("Transaction") {
1700                                return "GpuTransaction*";
1701                            }
1702                            if name.contains("profile") || name.contains("Profile") {
1703                                return "GpuCustomerProfile*";
1704                            }
1705                            if name.contains("alert") || name.contains("Alert") {
1706                                return "GpuAlert*";
1707                            }
1708                        }
1709                        "float*" // Default element pointer
1710                    }
1711                    _ => "void*",
1712                }
1713            }
1714            Expr::MethodCall(_) => "float",
1715            Expr::Field(field) => {
1716                // Field access - try to infer type from field name
1717                let member_name = match &field.member {
1718                    syn::Member::Named(ident) => ident.to_string(),
1719                    syn::Member::Unnamed(idx) => idx.index.to_string(),
1720                };
1721                // Common field name patterns
1722                if member_name.contains("count") || member_name.contains("_count") {
1723                    return "unsigned int";
1724                }
1725                if member_name.contains("threshold") || member_name.ends_with("_id") {
1726                    return "unsigned long long";
1727                }
1728                if member_name.ends_with("_pct") {
1729                    return "unsigned char";
1730                }
1731                "float"
1732            }
1733            Expr::Path(path) => {
1734                // Variable access - check if it's a known variable
1735                let name = path
1736                    .path
1737                    .segments
1738                    .iter()
1739                    .map(|s| s.ident.to_string())
1740                    .collect::<Vec<_>>()
1741                    .join("::");
1742                if name.contains("threshold")
1743                    || name.contains("count")
1744                    || name == "idx"
1745                    || name == "n"
1746                {
1747                    return "int";
1748                }
1749                "float"
1750            }
1751            Expr::If(if_expr) => {
1752                // For ternary (if-else), infer type from branches
1753                if let Some((_, else_branch)) = &if_expr.else_branch {
1754                    if let Expr::Block(block) = else_branch.as_ref() {
1755                        if let Some(Stmt::Expr(expr, None)) = block.block.stmts.last() {
1756                            return self.infer_cuda_type(expr);
1757                        }
1758                    }
1759                }
1760                // Try from then branch
1761                if let Some(Stmt::Expr(expr, None)) = if_expr.then_branch.stmts.last() {
1762                    return self.infer_cuda_type(expr);
1763                }
1764                "float"
1765            }
1766            _ => "float",
1767        }
1768    }
1769}
1770
1771/// Transpile a function to CUDA without stencil configuration.
1772pub fn transpile_function(func: &ItemFn) -> Result<String> {
1773    let mut transpiler = CudaTranspiler::new_generic();
1774
1775    // Generate function signature
1776    let name = func.sig.ident.to_string();
1777
1778    let mut params = Vec::new();
1779    for param in &func.sig.inputs {
1780        if let FnArg::Typed(pat_type) = param {
1781            let param_name = match pat_type.pat.as_ref() {
1782                Pat::Ident(ident) => ident.ident.to_string(),
1783                _ => continue,
1784            };
1785
1786            let cuda_type = transpiler.type_mapper.map_type(&pat_type.ty)?;
1787            params.push(format!("{} {}", cuda_type.to_cuda_string(), param_name));
1788        }
1789    }
1790
1791    // Return type
1792    let return_type = match &func.sig.output {
1793        ReturnType::Default => "void".to_string(),
1794        ReturnType::Type(_, ty) => transpiler.type_mapper.map_type(ty)?.to_cuda_string(),
1795    };
1796
1797    // Generate body
1798    let body = transpiler.transpile_block(&func.block)?;
1799
1800    Ok(format!(
1801        "__device__ {return_type} {name}({params}) {{\n{body}}}\n",
1802        params = params.join(", ")
1803    ))
1804}
1805
1806#[cfg(test)]
1807mod tests {
1808    use super::*;
1809    use syn::parse_quote;
1810
1811    #[test]
1812    fn test_simple_arithmetic() {
1813        let transpiler = CudaTranspiler::new_generic();
1814
1815        let expr: Expr = parse_quote!(a + b * 2.0);
1816        let result = transpiler.transpile_expr(&expr).unwrap();
1817        assert_eq!(result, "a + b * 2.0f");
1818    }
1819
1820    #[test]
1821    fn test_let_binding() {
1822        let mut transpiler = CudaTranspiler::new_generic();
1823
1824        let stmt: Stmt = parse_quote!(let x = a + b;);
1825        let result = transpiler.transpile_stmt(&stmt).unwrap();
1826        assert!(result.contains("float x = a + b;"));
1827    }
1828
1829    #[test]
1830    fn test_array_index() {
1831        let transpiler = CudaTranspiler::new_generic();
1832
1833        let expr: Expr = parse_quote!(data[idx]);
1834        let result = transpiler.transpile_expr(&expr).unwrap();
1835        assert_eq!(result, "data[idx]");
1836    }
1837
1838    #[test]
1839    fn test_stencil_intrinsics() {
1840        let config = StencilConfig::new("test")
1841            .with_tile_size(16, 16)
1842            .with_halo(1);
1843        let mut transpiler = CudaTranspiler::new(config);
1844        transpiler.grid_pos_vars.push("pos".to_string());
1845
1846        // Test pos.idx()
1847        let expr: Expr = parse_quote!(pos.idx());
1848        let result = transpiler.transpile_expr(&expr).unwrap();
1849        assert_eq!(result, "idx");
1850
1851        // Test pos.north(p)
1852        let expr: Expr = parse_quote!(pos.north(p));
1853        let result = transpiler.transpile_expr(&expr).unwrap();
1854        assert_eq!(result, "p[idx - 18]");
1855
1856        // Test pos.east(p)
1857        let expr: Expr = parse_quote!(pos.east(p));
1858        let result = transpiler.transpile_expr(&expr).unwrap();
1859        assert_eq!(result, "p[idx + 1]");
1860    }
1861
1862    #[test]
1863    fn test_ternary_if() {
1864        let transpiler = CudaTranspiler::new_generic();
1865
1866        let expr: Expr = parse_quote!(if x > 0.0 { x } else { -x });
1867        let result = transpiler.transpile_expr(&expr).unwrap();
1868        assert!(result.contains("?"));
1869        assert!(result.contains(":"));
1870    }
1871
1872    #[test]
1873    fn test_full_stencil_kernel() {
1874        let func: ItemFn = parse_quote! {
1875            fn fdtd(p: &[f32], p_prev: &mut [f32], c2: f32, pos: GridPos) {
1876                let curr = p[pos.idx()];
1877                let prev = p_prev[pos.idx()];
1878                let lap = pos.north(p) + pos.south(p) + pos.east(p) + pos.west(p) - 4.0 * curr;
1879                p_prev[pos.idx()] = (2.0 * curr - prev + c2 * lap);
1880            }
1881        };
1882
1883        let config = StencilConfig::new("fdtd")
1884            .with_tile_size(16, 16)
1885            .with_halo(1);
1886
1887        let mut transpiler = CudaTranspiler::new(config);
1888        let cuda = transpiler.transpile_stencil(&func).unwrap();
1889
1890        // Check key features
1891        assert!(cuda.contains("extern \"C\" __global__"));
1892        assert!(cuda.contains("threadIdx.x"));
1893        assert!(cuda.contains("threadIdx.y"));
1894        assert!(cuda.contains("buffer_width = 18"));
1895        assert!(cuda.contains("const float* __restrict__ p"));
1896        assert!(cuda.contains("float* __restrict__ p_prev"));
1897        assert!(!cuda.contains("GridPos")); // GridPos should be removed
1898
1899        println!("Generated CUDA:\n{}", cuda);
1900    }
1901
1902    #[test]
1903    fn test_early_return() {
1904        let mut transpiler = CudaTranspiler::new_generic();
1905
1906        let stmt: Stmt = parse_quote!(return;);
1907        let result = transpiler.transpile_stmt(&stmt).unwrap();
1908        assert!(result.contains("return;"));
1909
1910        let stmt_val: Stmt = parse_quote!(return 42;);
1911        let result_val = transpiler.transpile_stmt(&stmt_val).unwrap();
1912        assert!(result_val.contains("return 42;"));
1913    }
1914
1915    #[test]
1916    fn test_match_to_switch() {
1917        let transpiler = CudaTranspiler::new_generic();
1918
1919        let expr: Expr = parse_quote! {
1920            match edge {
1921                0 => { idx = 1 * 18 + i; }
1922                1 => { idx = 16 * 18 + i; }
1923                _ => { idx = 0; }
1924            }
1925        };
1926
1927        let result = transpiler.transpile_expr(&expr).unwrap();
1928        assert!(
1929            result.contains("switch (edge)"),
1930            "Should generate switch: {}",
1931            result
1932        );
1933        assert!(result.contains("case 0:"), "Should have case 0: {}", result);
1934        assert!(result.contains("case 1:"), "Should have case 1: {}", result);
1935        assert!(
1936            result.contains("default:"),
1937            "Should have default: {}",
1938            result
1939        );
1940        assert!(result.contains("break;"), "Should have break: {}", result);
1941
1942        println!("Generated switch:\n{}", result);
1943    }
1944
1945    #[test]
1946    fn test_block_idx_intrinsics() {
1947        let transpiler = CudaTranspiler::new_generic();
1948
1949        // Test block_idx_x() call
1950        let expr: Expr = parse_quote!(block_idx_x());
1951        let result = transpiler.transpile_expr(&expr).unwrap();
1952        assert_eq!(result, "blockIdx.x");
1953
1954        // Test thread_idx_y() call
1955        let expr2: Expr = parse_quote!(thread_idx_y());
1956        let result2 = transpiler.transpile_expr(&expr2).unwrap();
1957        assert_eq!(result2, "threadIdx.y");
1958
1959        // Test grid_dim_x() call
1960        let expr3: Expr = parse_quote!(grid_dim_x());
1961        let result3 = transpiler.transpile_expr(&expr3).unwrap();
1962        assert_eq!(result3, "gridDim.x");
1963    }
1964
1965    #[test]
1966    fn test_global_index_calculation() {
1967        let transpiler = CudaTranspiler::new_generic();
1968
1969        // Common CUDA pattern: gx = blockIdx.x * blockDim.x + threadIdx.x
1970        let expr: Expr = parse_quote!(block_idx_x() * block_dim_x() + thread_idx_x());
1971        let result = transpiler.transpile_expr(&expr).unwrap();
1972        assert!(result.contains("blockIdx.x"), "Should contain blockIdx.x");
1973        assert!(result.contains("blockDim.x"), "Should contain blockDim.x");
1974        assert!(result.contains("threadIdx.x"), "Should contain threadIdx.x");
1975
1976        println!("Global index expression: {}", result);
1977    }
1978
1979    // === Loop Transpilation Tests ===
1980
1981    #[test]
1982    fn test_for_loop_transpile() {
1983        let transpiler = CudaTranspiler::new_generic();
1984
1985        let expr: Expr = parse_quote! {
1986            for i in 0..n {
1987                data[i] = 0.0;
1988            }
1989        };
1990
1991        let result = transpiler.transpile_expr(&expr).unwrap();
1992        assert!(
1993            result.contains("for (int i = 0; i < n; i++)"),
1994            "Should generate for loop header: {}",
1995            result
1996        );
1997        assert!(
1998            result.contains("data[i] = 0.0f"),
1999            "Should contain loop body: {}",
2000            result
2001        );
2002
2003        println!("Generated for loop:\n{}", result);
2004    }
2005
2006    #[test]
2007    fn test_for_loop_inclusive_range() {
2008        let transpiler = CudaTranspiler::new_generic();
2009
2010        let expr: Expr = parse_quote! {
2011            for i in 1..=10 {
2012                sum += i;
2013            }
2014        };
2015
2016        let result = transpiler.transpile_expr(&expr).unwrap();
2017        assert!(
2018            result.contains("for (int i = 1; i <= 10; i++)"),
2019            "Should generate inclusive range: {}",
2020            result
2021        );
2022
2023        println!("Generated inclusive for loop:\n{}", result);
2024    }
2025
2026    #[test]
2027    fn test_while_loop_transpile() {
2028        let transpiler = CudaTranspiler::new_generic();
2029
2030        let expr: Expr = parse_quote! {
2031            while i < 10 {
2032                i += 1;
2033            }
2034        };
2035
2036        let result = transpiler.transpile_expr(&expr).unwrap();
2037        assert!(
2038            result.contains("while (i < 10)"),
2039            "Should generate while loop: {}",
2040            result
2041        );
2042        assert!(
2043            result.contains("i += 1"),
2044            "Should contain loop body: {}",
2045            result
2046        );
2047
2048        println!("Generated while loop:\n{}", result);
2049    }
2050
2051    #[test]
2052    fn test_while_loop_negation() {
2053        let transpiler = CudaTranspiler::new_generic();
2054
2055        let expr: Expr = parse_quote! {
2056            while !done {
2057                process();
2058            }
2059        };
2060
2061        let result = transpiler.transpile_expr(&expr).unwrap();
2062        assert!(
2063            result.contains("while (!(done))"),
2064            "Should negate condition: {}",
2065            result
2066        );
2067
2068        println!("Generated while loop with negation:\n{}", result);
2069    }
2070
2071    #[test]
2072    fn test_infinite_loop_transpile() {
2073        let transpiler = CudaTranspiler::new_generic();
2074
2075        let expr: Expr = parse_quote! {
2076            loop {
2077                process();
2078            }
2079        };
2080
2081        let result = transpiler.transpile_expr(&expr).unwrap();
2082        assert!(
2083            result.contains("while (true)"),
2084            "Should generate infinite loop: {}",
2085            result
2086        );
2087        assert!(
2088            result.contains("process()"),
2089            "Should contain loop body: {}",
2090            result
2091        );
2092
2093        println!("Generated infinite loop:\n{}", result);
2094    }
2095
2096    #[test]
2097    fn test_break_transpile() {
2098        let transpiler = CudaTranspiler::new_generic();
2099
2100        let expr: Expr = parse_quote!(break);
2101        let result = transpiler.transpile_expr(&expr).unwrap();
2102        assert_eq!(result, "break");
2103    }
2104
2105    #[test]
2106    fn test_continue_transpile() {
2107        let transpiler = CudaTranspiler::new_generic();
2108
2109        let expr: Expr = parse_quote!(continue);
2110        let result = transpiler.transpile_expr(&expr).unwrap();
2111        assert_eq!(result, "continue");
2112    }
2113
2114    #[test]
2115    fn test_loop_with_break() {
2116        let transpiler = CudaTranspiler::new_generic();
2117
2118        let expr: Expr = parse_quote! {
2119            loop {
2120                if done {
2121                    break;
2122                }
2123            }
2124        };
2125
2126        let result = transpiler.transpile_expr(&expr).unwrap();
2127        assert!(
2128            result.contains("while (true)"),
2129            "Should generate infinite loop: {}",
2130            result
2131        );
2132        assert!(result.contains("break"), "Should contain break: {}", result);
2133
2134        println!("Generated loop with break:\n{}", result);
2135    }
2136
2137    #[test]
2138    fn test_nested_loops() {
2139        let transpiler = CudaTranspiler::new_generic();
2140
2141        let expr: Expr = parse_quote! {
2142            for i in 0..m {
2143                for j in 0..n {
2144                    matrix[i * n + j] = 0.0;
2145                }
2146            }
2147        };
2148
2149        let result = transpiler.transpile_expr(&expr).unwrap();
2150        assert!(
2151            result.contains("for (int i = 0; i < m; i++)"),
2152            "Should have outer loop: {}",
2153            result
2154        );
2155        assert!(
2156            result.contains("for (int j = 0; j < n; j++)"),
2157            "Should have inner loop: {}",
2158            result
2159        );
2160
2161        println!("Generated nested loops:\n{}", result);
2162    }
2163
2164    #[test]
2165    fn test_stencil_mode_rejects_loops() {
2166        let config = StencilConfig::new("test")
2167            .with_tile_size(16, 16)
2168            .with_halo(1);
2169        let transpiler = CudaTranspiler::new(config);
2170
2171        let expr: Expr = parse_quote! {
2172            for i in 0..n {
2173                data[i] = 0.0;
2174            }
2175        };
2176
2177        let result = transpiler.transpile_expr(&expr);
2178        assert!(result.is_err(), "Stencil mode should reject loops");
2179    }
2180
2181    #[test]
2182    fn test_labeled_break_rejected() {
2183        let transpiler = CudaTranspiler::new_generic();
2184
2185        // Note: We can't directly parse `break 'label` without a labeled block,
2186        // so we test that the error path exists by checking the function handles labels
2187        let break_expr = syn::ExprBreak {
2188            attrs: Vec::new(),
2189            break_token: syn::token::Break::default(),
2190            label: Some(syn::Lifetime::new("'outer", proc_macro2::Span::call_site())),
2191            expr: None,
2192        };
2193
2194        let result = transpiler.transpile_break(&break_expr);
2195        assert!(result.is_err(), "Labeled break should be rejected");
2196    }
2197
2198    #[test]
2199    fn test_full_kernel_with_loop() {
2200        let func: ItemFn = parse_quote! {
2201            fn fill_array(data: &mut [f32], n: i32) {
2202                for i in 0..n {
2203                    data[i as usize] = 0.0;
2204                }
2205            }
2206        };
2207
2208        let mut transpiler = CudaTranspiler::new_generic();
2209        let cuda = transpiler.transpile_generic_kernel(&func).unwrap();
2210
2211        assert!(
2212            cuda.contains("extern \"C\" __global__"),
2213            "Should be global kernel: {}",
2214            cuda
2215        );
2216        assert!(
2217            cuda.contains("for (int i = 0; i < n; i++)"),
2218            "Should have for loop: {}",
2219            cuda
2220        );
2221
2222        println!("Generated kernel with loop:\n{}", cuda);
2223    }
2224
2225    #[test]
2226    fn test_persistent_kernel_pattern() {
2227        // Test the pattern used for ring/actor kernels
2228        let transpiler = CudaTranspiler::with_mode(ValidationMode::RingKernel);
2229
2230        let expr: Expr = parse_quote! {
2231            while !should_terminate {
2232                if has_message {
2233                    process_message();
2234                }
2235            }
2236        };
2237
2238        let result = transpiler.transpile_expr(&expr).unwrap();
2239        assert!(
2240            result.contains("while (!(should_terminate))"),
2241            "Should have persistent loop: {}",
2242            result
2243        );
2244        assert!(
2245            result.contains("if (has_message)"),
2246            "Should have message check: {}",
2247            result
2248        );
2249
2250        println!("Generated persistent kernel pattern:\n{}", result);
2251    }
2252
2253    // ==================== Shared Memory Tests ====================
2254
2255    #[test]
2256    fn test_shared_tile_declaration() {
2257        use crate::shared::{SharedMemoryConfig, SharedMemoryDecl};
2258
2259        let decl = SharedMemoryDecl::tile("tile", "float", 16, 16);
2260        assert_eq!(decl.to_cuda_decl(), "__shared__ float tile[16][16];");
2261
2262        let mut config = SharedMemoryConfig::new();
2263        config.add_tile("tile", "float", 16, 16);
2264        assert_eq!(config.total_bytes(), 16 * 16 * 4); // 1024 bytes
2265
2266        let decls = config.generate_declarations("    ");
2267        assert!(decls.contains("__shared__ float tile[16][16];"));
2268    }
2269
2270    #[test]
2271    fn test_shared_array_declaration() {
2272        use crate::shared::{SharedMemoryConfig, SharedMemoryDecl};
2273
2274        let decl = SharedMemoryDecl::array("buffer", "float", 256);
2275        assert_eq!(decl.to_cuda_decl(), "__shared__ float buffer[256];");
2276
2277        let mut config = SharedMemoryConfig::new();
2278        config.add_array("buffer", "float", 256);
2279        assert_eq!(config.total_bytes(), 256 * 4); // 1024 bytes
2280    }
2281
2282    #[test]
2283    fn test_shared_memory_access_expressions() {
2284        use crate::shared::SharedMemoryDecl;
2285
2286        let tile = SharedMemoryDecl::tile("tile", "float", 16, 16);
2287        assert_eq!(
2288            tile.to_cuda_access(&["y".to_string(), "x".to_string()]),
2289            "tile[y][x]"
2290        );
2291
2292        let arr = SharedMemoryDecl::array("buf", "int", 128);
2293        assert_eq!(arr.to_cuda_access(&["i".to_string()]), "buf[i]");
2294    }
2295
2296    #[test]
2297    fn test_parse_shared_tile_type() {
2298        use crate::shared::parse_shared_tile_type;
2299
2300        let result = parse_shared_tile_type("SharedTile::<f32, 16, 16>");
2301        assert_eq!(result, Some(("f32".to_string(), 16, 16)));
2302
2303        let result2 = parse_shared_tile_type("SharedTile<i32, 32, 8>");
2304        assert_eq!(result2, Some(("i32".to_string(), 32, 8)));
2305
2306        let invalid = parse_shared_tile_type("Vec<f32>");
2307        assert_eq!(invalid, None);
2308    }
2309
2310    #[test]
2311    fn test_parse_shared_array_type() {
2312        use crate::shared::parse_shared_array_type;
2313
2314        let result = parse_shared_array_type("SharedArray::<f32, 256>");
2315        assert_eq!(result, Some(("f32".to_string(), 256)));
2316
2317        let result2 = parse_shared_array_type("SharedArray<u32, 1024>");
2318        assert_eq!(result2, Some(("u32".to_string(), 1024)));
2319
2320        let invalid = parse_shared_array_type("Vec<f32>");
2321        assert_eq!(invalid, None);
2322    }
2323
2324    #[test]
2325    fn test_rust_to_cuda_element_types() {
2326        use crate::shared::rust_to_cuda_element_type;
2327
2328        assert_eq!(rust_to_cuda_element_type("f32"), "float");
2329        assert_eq!(rust_to_cuda_element_type("f64"), "double");
2330        assert_eq!(rust_to_cuda_element_type("i32"), "int");
2331        assert_eq!(rust_to_cuda_element_type("u32"), "unsigned int");
2332        assert_eq!(rust_to_cuda_element_type("i64"), "long long");
2333        assert_eq!(rust_to_cuda_element_type("u64"), "unsigned long long");
2334        assert_eq!(rust_to_cuda_element_type("bool"), "int");
2335    }
2336
2337    #[test]
2338    fn test_shared_memory_total_bytes() {
2339        use crate::shared::SharedMemoryConfig;
2340
2341        let mut config = SharedMemoryConfig::new();
2342        config.add_tile("tile1", "float", 16, 16); // 16*16*4 = 1024
2343        config.add_tile("tile2", "double", 8, 8); // 8*8*8 = 512
2344        config.add_array("temp", "int", 64); // 64*4 = 256
2345
2346        assert_eq!(config.total_bytes(), 1024 + 512 + 256);
2347    }
2348
2349    #[test]
2350    fn test_transpiler_shared_var_tracking() {
2351        let mut transpiler = CudaTranspiler::new_generic();
2352
2353        // Manually register a shared variable
2354        transpiler.shared_vars.insert(
2355            "tile".to_string(),
2356            SharedVarInfo {
2357                name: "tile".to_string(),
2358                is_tile: true,
2359                dimensions: vec![16, 16],
2360                element_type: "float".to_string(),
2361            },
2362        );
2363
2364        // Test that transpiler tracks it
2365        assert!(transpiler.shared_vars.contains_key("tile"));
2366        assert!(transpiler.shared_vars.get("tile").unwrap().is_tile);
2367    }
2368
2369    #[test]
2370    fn test_shared_tile_get_transpilation() {
2371        let mut transpiler = CudaTranspiler::new_generic();
2372
2373        // Register a shared tile
2374        transpiler.shared_vars.insert(
2375            "tile".to_string(),
2376            SharedVarInfo {
2377                name: "tile".to_string(),
2378                is_tile: true,
2379                dimensions: vec![16, 16],
2380                element_type: "float".to_string(),
2381            },
2382        );
2383
2384        // Test method call transpilation
2385        let result = transpiler.try_transpile_shared_method_call(
2386            "tile",
2387            "get",
2388            &syn::punctuated::Punctuated::new(),
2389        );
2390
2391        // With no args, it should return None (args required)
2392        assert!(result.is_none() || result.unwrap().is_err());
2393    }
2394
2395    #[test]
2396    fn test_shared_array_access() {
2397        let mut transpiler = CudaTranspiler::new_generic();
2398
2399        // Register a shared array
2400        transpiler.shared_vars.insert(
2401            "buffer".to_string(),
2402            SharedVarInfo {
2403                name: "buffer".to_string(),
2404                is_tile: false,
2405                dimensions: vec![256],
2406                element_type: "float".to_string(),
2407            },
2408        );
2409
2410        assert!(!transpiler.shared_vars.get("buffer").unwrap().is_tile);
2411        assert_eq!(
2412            transpiler.shared_vars.get("buffer").unwrap().dimensions,
2413            vec![256]
2414        );
2415    }
2416
2417    #[test]
2418    fn test_full_kernel_with_shared_memory() {
2419        // Test that we can generate declarations correctly
2420        use crate::shared::SharedMemoryConfig;
2421
2422        let mut config = SharedMemoryConfig::new();
2423        config.add_tile("smem", "float", 16, 16);
2424
2425        let decls = config.generate_declarations("    ");
2426        assert!(decls.contains("__shared__ float smem[16][16];"));
2427        assert!(!config.is_empty());
2428    }
2429
2430    // === Struct Literal Tests ===
2431
2432    #[test]
2433    fn test_struct_literal_transpile() {
2434        let transpiler = CudaTranspiler::new_generic();
2435
2436        let expr: Expr = parse_quote! {
2437            Point { x: 1.0, y: 2.0 }
2438        };
2439
2440        let result = transpiler.transpile_expr(&expr).unwrap();
2441        assert!(
2442            result.contains("Point"),
2443            "Should contain struct name: {}",
2444            result
2445        );
2446        assert!(result.contains(".x ="), "Should have field x: {}", result);
2447        assert!(result.contains(".y ="), "Should have field y: {}", result);
2448        assert!(
2449            result.contains("1.0f"),
2450            "Should have value 1.0f: {}",
2451            result
2452        );
2453        assert!(
2454            result.contains("2.0f"),
2455            "Should have value 2.0f: {}",
2456            result
2457        );
2458
2459        println!("Generated struct literal: {}", result);
2460    }
2461
2462    #[test]
2463    fn test_struct_literal_with_expressions() {
2464        let transpiler = CudaTranspiler::new_generic();
2465
2466        let expr: Expr = parse_quote! {
2467            Response { value: x * 2.0, id: idx as u64 }
2468        };
2469
2470        let result = transpiler.transpile_expr(&expr).unwrap();
2471        assert!(
2472            result.contains("Response"),
2473            "Should contain struct name: {}",
2474            result
2475        );
2476        assert!(
2477            result.contains(".value = x * 2.0f"),
2478            "Should have computed value: {}",
2479            result
2480        );
2481        assert!(result.contains(".id ="), "Should have id field: {}", result);
2482
2483        println!("Generated struct with expressions: {}", result);
2484    }
2485
2486    #[test]
2487    fn test_struct_literal_in_return() {
2488        let mut transpiler = CudaTranspiler::new_generic();
2489
2490        let stmt: Stmt = parse_quote! {
2491            return MyStruct { a: 1, b: 2.0 };
2492        };
2493
2494        let result = transpiler.transpile_stmt(&stmt).unwrap();
2495        assert!(result.contains("return"), "Should have return: {}", result);
2496        assert!(
2497            result.contains("MyStruct"),
2498            "Should contain struct name: {}",
2499            result
2500        );
2501
2502        println!("Generated return with struct: {}", result);
2503    }
2504
2505    #[test]
2506    fn test_struct_literal_compound_literal_format() {
2507        let transpiler = CudaTranspiler::new_generic();
2508
2509        let expr: Expr = parse_quote! {
2510            Vec3 { x: a, y: b, z: c }
2511        };
2512
2513        let result = transpiler.transpile_expr(&expr).unwrap();
2514        // Check for C compound literal format: (Type){ .field = val, ... }
2515        assert!(
2516            result.starts_with("(Vec3){"),
2517            "Should use compound literal format: {}",
2518            result
2519        );
2520        assert!(
2521            result.ends_with("}"),
2522            "Should end with closing brace: {}",
2523            result
2524        );
2525
2526        println!("Generated compound literal: {}", result);
2527    }
2528
2529    // === Reference Expression Tests ===
2530
2531    #[test]
2532    fn test_reference_to_array_element() {
2533        let transpiler = CudaTranspiler::new_generic();
2534
2535        let expr: Expr = parse_quote! {
2536            &arr[idx]
2537        };
2538
2539        let result = transpiler.transpile_expr(&expr).unwrap();
2540        assert_eq!(
2541            result, "&arr[idx]",
2542            "Should produce address-of array element"
2543        );
2544    }
2545
2546    #[test]
2547    fn test_mutable_reference_to_array_element() {
2548        let transpiler = CudaTranspiler::new_generic();
2549
2550        let expr: Expr = parse_quote! {
2551            &mut arr[idx * 4 + offset]
2552        };
2553
2554        let result = transpiler.transpile_expr(&expr).unwrap();
2555        assert!(
2556            result.contains("&arr["),
2557            "Should produce address-of: {}",
2558            result
2559        );
2560        assert!(
2561            result.contains("idx * 4"),
2562            "Should have index expression: {}",
2563            result
2564        );
2565    }
2566
2567    #[test]
2568    fn test_reference_to_variable() {
2569        let transpiler = CudaTranspiler::new_generic();
2570
2571        let expr: Expr = parse_quote! {
2572            &value
2573        };
2574
2575        let result = transpiler.transpile_expr(&expr).unwrap();
2576        assert_eq!(result, "&value", "Should produce address-of variable");
2577    }
2578
2579    #[test]
2580    fn test_reference_to_struct_field() {
2581        let transpiler = CudaTranspiler::new_generic();
2582
2583        let expr: Expr = parse_quote! {
2584            &alerts[(idx as usize) * 4 + alert_idx as usize]
2585        };
2586
2587        let result = transpiler.transpile_expr(&expr).unwrap();
2588        assert!(
2589            result.starts_with("&alerts["),
2590            "Should have address-of array: {}",
2591            result
2592        );
2593
2594        println!("Generated reference: {}", result);
2595    }
2596
2597    #[test]
2598    fn test_complex_reference_pattern() {
2599        let mut transpiler = CudaTranspiler::new_generic();
2600
2601        // This is the pattern from txmon batch kernel
2602        let stmt: Stmt = parse_quote! {
2603            let alert = &mut alerts[(idx as usize) * 4 + alert_idx as usize];
2604        };
2605
2606        let result = transpiler.transpile_stmt(&stmt).unwrap();
2607        assert!(
2608            result.contains("alert ="),
2609            "Should have variable assignment: {}",
2610            result
2611        );
2612        assert!(
2613            result.contains("&alerts["),
2614            "Should have reference to array: {}",
2615            result
2616        );
2617
2618        println!("Generated statement: {}", result);
2619    }
2620}