1use crate::handler::{ContextMethod, HandlerSignature};
6use crate::intrinsics::{IntrinsicRegistry, RingKernelIntrinsic, StencilIntrinsic};
7use crate::loops::{extract_loop_var, RangeInfo};
8use crate::shared::{rust_to_cuda_element_type, SharedMemoryConfig, SharedMemoryDecl};
9use crate::stencil::StencilConfig;
10use crate::types::{is_grid_pos_type, is_ring_context_type, TypeMapper};
11use crate::validation::ValidationMode;
12use crate::{Result, TranspileError};
13use quote::ToTokens;
14use syn::{
15 BinOp, Expr, ExprAssign, ExprBinary, ExprBreak, ExprCall, ExprCast, ExprContinue, ExprForLoop,
16 ExprIf, ExprIndex, ExprLet, ExprLit, ExprLoop, ExprMatch, ExprMethodCall, ExprParen, ExprPath,
17 ExprReference, ExprReturn, ExprStruct, ExprUnary, ExprWhile, FnArg, ItemFn, Lit, Pat,
18 ReturnType, Stmt, UnOp,
19};
20
21pub struct CudaTranspiler {
23 config: Option<StencilConfig>,
25 type_mapper: TypeMapper,
27 intrinsics: IntrinsicRegistry,
29 grid_pos_vars: Vec<String>,
31 context_vars: Vec<String>,
33 indent: usize,
35 validation_mode: ValidationMode,
37 shared_memory: SharedMemoryConfig,
39 pub shared_vars: std::collections::HashMap<String, SharedVarInfo>,
41 ring_kernel_mode: bool,
43 pointer_vars: std::collections::HashSet<String>,
45}
46
/// Book-keeping record for one shared-memory variable declared in a kernel.
#[derive(Debug, Clone)]
pub struct SharedVarInfo {
    /// Variable name as written in the Rust source.
    pub name: String,
    /// `true` when the declaration has exactly two dimensions.
    pub is_tile: bool,
    /// Extent of each dimension of the declaration.
    pub dimensions: Vec<usize>,
    /// Element type of the declaration.
    pub element_type: String,
}
59
60impl CudaTranspiler {
61 pub fn new(config: StencilConfig) -> Self {
63 Self {
64 config: Some(config),
65 type_mapper: TypeMapper::new(),
66 intrinsics: IntrinsicRegistry::new(),
67 grid_pos_vars: Vec::new(),
68 context_vars: Vec::new(),
69 indent: 1, validation_mode: ValidationMode::Stencil,
71 shared_memory: SharedMemoryConfig::new(),
72 shared_vars: std::collections::HashMap::new(),
73 ring_kernel_mode: false,
74 pointer_vars: std::collections::HashSet::new(),
75 }
76 }
77
78 pub fn new_generic() -> Self {
80 Self {
81 config: None,
82 type_mapper: TypeMapper::new(),
83 intrinsics: IntrinsicRegistry::new(),
84 grid_pos_vars: Vec::new(),
85 context_vars: Vec::new(),
86 indent: 1,
87 validation_mode: ValidationMode::Generic,
88 shared_memory: SharedMemoryConfig::new(),
89 shared_vars: std::collections::HashMap::new(),
90 ring_kernel_mode: false,
91 pointer_vars: std::collections::HashSet::new(),
92 }
93 }
94
95 pub fn with_mode(mode: ValidationMode) -> Self {
97 Self {
98 config: None,
99 type_mapper: TypeMapper::new(),
100 intrinsics: IntrinsicRegistry::new(),
101 grid_pos_vars: Vec::new(),
102 context_vars: Vec::new(),
103 indent: 1,
104 validation_mode: mode,
105 shared_memory: SharedMemoryConfig::new(),
106 shared_vars: std::collections::HashMap::new(),
107 ring_kernel_mode: false,
108 pointer_vars: std::collections::HashSet::new(),
109 }
110 }
111
112 pub fn for_ring_kernel() -> Self {
114 Self {
115 config: None,
116 type_mapper: crate::types::ring_kernel_type_mapper(),
117 intrinsics: IntrinsicRegistry::new(),
118 grid_pos_vars: Vec::new(),
119 context_vars: Vec::new(),
120 indent: 2, validation_mode: ValidationMode::Generic,
122 shared_memory: SharedMemoryConfig::new(),
123 shared_vars: std::collections::HashMap::new(),
124 ring_kernel_mode: true,
125 pointer_vars: std::collections::HashSet::new(),
126 }
127 }
128
129 pub fn set_validation_mode(&mut self, mode: ValidationMode) {
131 self.validation_mode = mode;
132 }
133
134 pub fn shared_memory(&self) -> &SharedMemoryConfig {
136 &self.shared_memory
137 }
138
139 fn indent_str(&self) -> String {
141 " ".repeat(self.indent)
142 }
143
144 pub fn transpile_stencil(&mut self, func: &ItemFn) -> Result<String> {
146 let config = self
147 .config
148 .as_ref()
149 .ok_or_else(|| TranspileError::Unsupported("No stencil config provided".into()))?
150 .clone();
151
152 for param in &func.sig.inputs {
154 if let FnArg::Typed(pat_type) = param {
155 if is_grid_pos_type(&pat_type.ty) {
156 if let Pat::Ident(ident) = pat_type.pat.as_ref() {
157 self.grid_pos_vars.push(ident.ident.to_string());
158 }
159 }
160 }
161 }
162
163 let signature = self.transpile_kernel_signature(func)?;
165
166 let preamble = config.generate_preamble();
168
169 let body = self.transpile_block(&func.block)?;
171
172 Ok(format!(
173 "extern \"C\" __global__ void {signature} {{\n{preamble}\n{body}}}\n"
174 ))
175 }
176
177 pub fn transpile_generic_kernel(&mut self, func: &ItemFn) -> Result<String> {
183 let signature = self.transpile_generic_kernel_signature(func)?;
185
186 let body = self.transpile_block(&func.block)?;
188
189 Ok(format!(
190 "extern \"C\" __global__ void {signature} {{\n{body}}}\n"
191 ))
192 }
193
194 pub fn transpile_ring_kernel(
199 &mut self,
200 handler: &ItemFn,
201 config: &crate::ring_kernel::RingKernelConfig,
202 ) -> Result<String> {
203 use std::fmt::Write;
204
205 let handler_sig = HandlerSignature::parse(handler, &self.type_mapper)?;
207
208 for param in &handler.sig.inputs {
210 if let FnArg::Typed(pat_type) = param {
211 if is_ring_context_type(&pat_type.ty) {
212 if let Pat::Ident(ident) = pat_type.pat.as_ref() {
213 self.context_vars.push(ident.ident.to_string());
214 }
215 }
216 }
217 }
218
219 self.ring_kernel_mode = true;
221
222 let mut output = String::new();
223
224 output.push_str(&crate::ring_kernel::generate_control_block_struct());
226 output.push('\n');
227
228 if config.enable_hlc {
229 output.push_str(&crate::ring_kernel::generate_hlc_struct());
230 output.push('\n');
231 }
232
233 if config.enable_k2k {
234 output.push_str(&crate::ring_kernel::generate_k2k_structs());
235 output.push('\n');
236 }
237
238 if let Some(ref msg_param) = handler_sig.message_param {
240 let type_name = msg_param
242 .rust_type
243 .trim_start_matches('&')
244 .trim_start_matches("mut ")
245 .trim();
246 if !type_name.is_empty() && type_name != "f32" && type_name != "i32" {
247 writeln!(output, "// Message type: {}", type_name).unwrap();
248 }
249 }
250
251 if let Some(ref ret_type) = handler_sig.return_type {
252 if ret_type.is_struct {
253 writeln!(output, "// Response type: {}", ret_type.rust_type).unwrap();
254 }
255 }
256
257 output.push_str(&config.generate_signature());
259 output.push_str(" {\n");
260
261 output.push_str(&config.generate_preamble(" "));
263
264 output.push_str(&config.generate_loop_header(" "));
266
267 if let Some(ref msg_param) = handler_sig.message_param {
269 let type_name = msg_param
270 .rust_type
271 .trim_start_matches('&')
272 .trim_start_matches("mut ")
273 .trim();
274 if !type_name.is_empty() {
275 writeln!(output, " // Message deserialization").unwrap();
276 writeln!(
277 output,
278 " // {}* {} = ({}*)msg_ptr;",
279 type_name, msg_param.name, type_name
280 )
281 .unwrap();
282 output.push('\n');
283 }
284 }
285
286 self.indent = 2; let handler_body = self.transpile_block(&handler.block)?;
289
290 writeln!(output, " // === USER HANDLER CODE ===").unwrap();
292 for line in handler_body.lines() {
293 if !line.trim().is_empty() {
294 writeln!(output, " {}", line).unwrap();
296 }
297 }
298 writeln!(output, " // === END HANDLER CODE ===").unwrap();
299
300 if let Some(ref ret_type) = handler_sig.return_type {
302 writeln!(output).unwrap();
303 writeln!(output, " // Response serialization").unwrap();
304 if ret_type.is_struct {
305 writeln!(output, " // memcpy(&output_buffer[_out_idx * RESP_SIZE], &response, sizeof({}));",
306 ret_type.cuda_type).unwrap();
307 }
308 }
309
310 output.push_str(&config.generate_message_complete(" "));
312
313 output.push_str(&config.generate_loop_footer(" "));
315
316 output.push_str(&config.generate_epilogue(" "));
318
319 output.push_str("}\n");
320
321 Ok(output)
322 }
323
324 fn transpile_generic_kernel_signature(&self, func: &ItemFn) -> Result<String> {
326 let name = func.sig.ident.to_string();
327
328 let mut params = Vec::new();
329 for param in &func.sig.inputs {
330 if let FnArg::Typed(pat_type) = param {
331 let param_name = match pat_type.pat.as_ref() {
332 Pat::Ident(ident) => ident.ident.to_string(),
333 _ => {
334 return Err(TranspileError::Unsupported(
335 "Complex pattern in parameter".into(),
336 ))
337 }
338 };
339
340 let cuda_type = self.type_mapper.map_type(&pat_type.ty)?;
341 params.push(format!("{} {}", cuda_type.to_cuda_string(), param_name));
342 }
343 }
344
345 Ok(format!("{}({})", name, params.join(", ")))
346 }
347
348 fn transpile_kernel_signature(&self, func: &ItemFn) -> Result<String> {
350 let name = func.sig.ident.to_string();
351
352 let mut params = Vec::new();
353 for param in &func.sig.inputs {
354 if let FnArg::Typed(pat_type) = param {
355 if is_grid_pos_type(&pat_type.ty) {
357 continue;
358 }
359
360 let param_name = match pat_type.pat.as_ref() {
361 Pat::Ident(ident) => ident.ident.to_string(),
362 _ => {
363 return Err(TranspileError::Unsupported(
364 "Complex pattern in parameter".into(),
365 ))
366 }
367 };
368
369 let cuda_type = self.type_mapper.map_type(&pat_type.ty)?;
370 params.push(format!("{} {}", cuda_type.to_cuda_string(), param_name));
371 }
372 }
373
374 Ok(format!("{}({})", name, params.join(", ")))
375 }
376
377 fn transpile_block(&mut self, block: &syn::Block) -> Result<String> {
379 let mut output = String::new();
380
381 for stmt in &block.stmts {
382 let stmt_str = self.transpile_stmt(stmt)?;
383 if !stmt_str.is_empty() {
384 output.push_str(&stmt_str);
385 }
386 }
387
388 Ok(output)
389 }
390
391 fn transpile_stmt(&mut self, stmt: &Stmt) -> Result<String> {
393 match stmt {
394 Stmt::Local(local) => {
395 let indent = self.indent_str();
396
397 let var_name = match &local.pat {
399 Pat::Ident(ident) => ident.ident.to_string(),
400 Pat::Type(pat_type) => {
401 if let Pat::Ident(ident) = pat_type.pat.as_ref() {
402 ident.ident.to_string()
403 } else {
404 return Err(TranspileError::Unsupported(
405 "Complex pattern in let binding".into(),
406 ));
407 }
408 }
409 _ => {
410 return Err(TranspileError::Unsupported(
411 "Complex pattern in let binding".into(),
412 ))
413 }
414 };
415
416 if let Some(shared_decl) = self.try_parse_shared_declaration(local, &var_name)? {
418 self.shared_vars.insert(
420 var_name.clone(),
421 SharedVarInfo {
422 name: var_name.clone(),
423 is_tile: shared_decl.dimensions.len() == 2,
424 dimensions: shared_decl.dimensions.clone(),
425 element_type: shared_decl.element_type.clone(),
426 },
427 );
428
429 self.shared_memory.add(shared_decl.clone());
431
432 return Ok(format!("{indent}{}\n", shared_decl.to_cuda_decl()));
434 }
435
436 if let Some(init) = &local.init {
438 let expr_str = self.transpile_expr(&init.expr)?;
439
440 let type_str = self.infer_cuda_type(&init.expr);
443
444 if type_str.ends_with('*') {
446 self.pointer_vars.insert(var_name.clone());
447 }
448
449 Ok(format!("{indent}{type_str} {var_name} = {expr_str};\n"))
450 } else {
451 Ok(format!("{indent}float {var_name};\n"))
453 }
454 }
455 Stmt::Expr(expr, semi) => {
456 let indent = self.indent_str();
457
458 if let Expr::If(if_expr) = expr {
460 if let Some(Stmt::Expr(Expr::Return(_), _)) = if_expr.then_branch.stmts.first()
462 {
463 if if_expr.then_branch.stmts.len() == 1 && if_expr.else_branch.is_none() {
464 let expr_str = self.transpile_expr(expr)?;
465 return Ok(format!("{indent}{expr_str};\n"));
466 }
467 }
468 }
469
470 let expr_str = self.transpile_expr(expr)?;
471
472 if semi.is_some() {
473 Ok(format!("{indent}{expr_str};\n"))
474 } else {
475 if matches!(expr, Expr::Return(_))
478 || expr_str.starts_with("return")
479 || expr_str.starts_with("if (")
480 {
481 Ok(format!("{indent}{expr_str};\n"))
482 } else {
483 Ok(format!("{indent}return {expr_str};\n"))
484 }
485 }
486 }
487 Stmt::Item(_) => {
488 Err(TranspileError::Unsupported("Item in function body".into()))
490 }
491 Stmt::Macro(_) => Err(TranspileError::Unsupported("Macro in function body".into())),
492 }
493 }
494
495 fn transpile_expr(&self, expr: &Expr) -> Result<String> {
497 match expr {
498 Expr::Lit(lit) => self.transpile_lit(lit),
499 Expr::Path(path) => self.transpile_path(path),
500 Expr::Binary(bin) => self.transpile_binary(bin),
501 Expr::Unary(unary) => self.transpile_unary(unary),
502 Expr::Paren(paren) => self.transpile_paren(paren),
503 Expr::Index(index) => self.transpile_index(index),
504 Expr::Call(call) => self.transpile_call(call),
505 Expr::MethodCall(method) => self.transpile_method_call(method),
506 Expr::If(if_expr) => self.transpile_if(if_expr),
507 Expr::Assign(assign) => self.transpile_assign(assign),
508 Expr::Cast(cast) => self.transpile_cast(cast),
509 Expr::Match(match_expr) => self.transpile_match(match_expr),
510 Expr::Block(block) => {
511 if let Some(Stmt::Expr(expr, None)) = block.block.stmts.last() {
513 self.transpile_expr(expr)
514 } else {
515 Err(TranspileError::Unsupported(
516 "Complex block expression".into(),
517 ))
518 }
519 }
520 Expr::Field(field) => {
521 let base = self.transpile_expr(&field.base)?;
523 let member = match &field.member {
524 syn::Member::Named(ident) => ident.to_string(),
525 syn::Member::Unnamed(idx) => idx.index.to_string(),
526 };
527
528 let accessor = if self.pointer_vars.contains(&base) {
530 "->"
531 } else {
532 "."
533 };
534 Ok(format!("{base}{accessor}{member}"))
535 }
536 Expr::Return(ret) => self.transpile_return(ret),
537 Expr::ForLoop(for_loop) => self.transpile_for_loop(for_loop),
538 Expr::While(while_loop) => self.transpile_while_loop(while_loop),
539 Expr::Loop(loop_expr) => self.transpile_infinite_loop(loop_expr),
540 Expr::Break(break_expr) => self.transpile_break(break_expr),
541 Expr::Continue(cont_expr) => self.transpile_continue(cont_expr),
542 Expr::Struct(struct_expr) => self.transpile_struct_literal(struct_expr),
543 Expr::Reference(ref_expr) => self.transpile_reference(ref_expr),
544 Expr::Let(let_expr) => self.transpile_let_expr(let_expr),
545 Expr::Tuple(tuple) => {
546 let elements: Vec<String> = tuple
548 .elems
549 .iter()
550 .map(|e| self.transpile_expr(e))
551 .collect::<Result<_>>()?;
552 Ok(format!("({})", elements.join(", ")))
553 }
554 _ => Err(TranspileError::Unsupported(format!(
555 "Expression type: {}",
556 expr.to_token_stream()
557 ))),
558 }
559 }
560
561 fn transpile_lit(&self, lit: &ExprLit) -> Result<String> {
563 match &lit.lit {
564 Lit::Float(f) => {
565 let s = f.to_string();
566 if s.ends_with("f32") || !s.contains('.') {
568 let num = s.trim_end_matches("f32").trim_end_matches("f64");
569 Ok(format!("{num}f"))
570 } else if s.ends_with("f64") {
571 Ok(s.trim_end_matches("f64").to_string())
572 } else {
573 Ok(format!("{s}f"))
575 }
576 }
577 Lit::Int(i) => Ok(i.to_string()),
578 Lit::Bool(b) => Ok(if b.value { "1" } else { "0" }.to_string()),
579 _ => Err(TranspileError::Unsupported(format!(
580 "Literal type: {}",
581 lit.to_token_stream()
582 ))),
583 }
584 }
585
586 fn transpile_path(&self, path: &ExprPath) -> Result<String> {
588 let segments: Vec<_> = path
589 .path
590 .segments
591 .iter()
592 .map(|s| s.ident.to_string())
593 .collect();
594
595 if segments.len() == 1 {
596 Ok(segments[0].clone())
597 } else {
598 Ok(segments.join("::"))
599 }
600 }
601
602 fn transpile_binary(&self, bin: &ExprBinary) -> Result<String> {
604 let left = self.transpile_expr(&bin.left)?;
605 let right = self.transpile_expr(&bin.right)?;
606
607 let op = match bin.op {
608 BinOp::Add(_) => "+",
609 BinOp::Sub(_) => "-",
610 BinOp::Mul(_) => "*",
611 BinOp::Div(_) => "/",
612 BinOp::Rem(_) => "%",
613 BinOp::And(_) => "&&",
614 BinOp::Or(_) => "||",
615 BinOp::BitXor(_) => "^",
616 BinOp::BitAnd(_) => "&",
617 BinOp::BitOr(_) => "|",
618 BinOp::Shl(_) => "<<",
619 BinOp::Shr(_) => ">>",
620 BinOp::Eq(_) => "==",
621 BinOp::Lt(_) => "<",
622 BinOp::Le(_) => "<=",
623 BinOp::Ne(_) => "!=",
624 BinOp::Ge(_) => ">=",
625 BinOp::Gt(_) => ">",
626 BinOp::AddAssign(_) => "+=",
627 BinOp::SubAssign(_) => "-=",
628 BinOp::MulAssign(_) => "*=",
629 BinOp::DivAssign(_) => "/=",
630 BinOp::RemAssign(_) => "%=",
631 BinOp::BitXorAssign(_) => "^=",
632 BinOp::BitAndAssign(_) => "&=",
633 BinOp::BitOrAssign(_) => "|=",
634 BinOp::ShlAssign(_) => "<<=",
635 BinOp::ShrAssign(_) => ">>=",
636 _ => {
637 return Err(TranspileError::Unsupported(format!(
638 "Binary operator: {}",
639 bin.to_token_stream()
640 )))
641 }
642 };
643
644 Ok(format!("{left} {op} {right}"))
645 }
646
647 fn transpile_unary(&self, unary: &ExprUnary) -> Result<String> {
649 let expr = self.transpile_expr(&unary.expr)?;
650
651 let op = match unary.op {
652 UnOp::Neg(_) => "-",
653 UnOp::Not(_) => "!",
654 UnOp::Deref(_) => "*",
655 _ => {
656 return Err(TranspileError::Unsupported(format!(
657 "Unary operator: {}",
658 unary.to_token_stream()
659 )))
660 }
661 };
662
663 Ok(format!("{op}({expr})"))
664 }
665
666 fn transpile_paren(&self, paren: &ExprParen) -> Result<String> {
668 let inner = self.transpile_expr(&paren.expr)?;
669 Ok(format!("({inner})"))
670 }
671
672 fn transpile_index(&self, index: &ExprIndex) -> Result<String> {
674 let base = self.transpile_expr(&index.expr)?;
675 let idx = self.transpile_expr(&index.index)?;
676 Ok(format!("{base}[{idx}]"))
677 }
678
679 fn transpile_call(&self, call: &ExprCall) -> Result<String> {
681 let func = self.transpile_expr(&call.func)?;
682
683 if let Some(intrinsic) = self.intrinsics.lookup(&func) {
685 let cuda_name = intrinsic.to_cuda_string();
686
687 let is_value_intrinsic = cuda_name.contains("Idx.")
690 || cuda_name.contains("Dim.")
691 || cuda_name.starts_with("threadIdx")
692 || cuda_name.starts_with("blockIdx")
693 || cuda_name.starts_with("blockDim")
694 || cuda_name.starts_with("gridDim");
695
696 if is_value_intrinsic && call.args.is_empty() {
697 return Ok(cuda_name.to_string());
699 }
700
701 if call.args.is_empty() && cuda_name.ends_with("()") {
702 return Ok(cuda_name.to_string());
704 }
705
706 let args: Vec<String> = call
707 .args
708 .iter()
709 .map(|a| self.transpile_expr(a))
710 .collect::<Result<_>>()?;
711
712 return Ok(format!(
713 "{}({})",
714 cuda_name.trim_end_matches("()"),
715 args.join(", ")
716 ));
717 }
718
719 let args: Vec<String> = call
721 .args
722 .iter()
723 .map(|a| self.transpile_expr(a))
724 .collect::<Result<_>>()?;
725
726 Ok(format!("{}({})", func, args.join(", ")))
727 }
728
729 fn transpile_method_call(&self, method: &ExprMethodCall) -> Result<String> {
731 let receiver = self.transpile_expr(&method.receiver)?;
732 let method_name = method.method.to_string();
733
734 if let Some(result) =
736 self.try_transpile_shared_method_call(&receiver, &method_name, &method.args)
737 {
738 return result;
739 }
740
741 if self.ring_kernel_mode && self.context_vars.contains(&receiver) {
743 return self.transpile_context_method(&method_name, &method.args);
744 }
745
746 if self.grid_pos_vars.contains(&receiver) {
748 return self.transpile_stencil_intrinsic(&method_name, &method.args);
749 }
750
751 if self.ring_kernel_mode {
753 if let Some(intrinsic) = RingKernelIntrinsic::from_name(&method_name) {
754 let args: Vec<String> = method
755 .args
756 .iter()
757 .map(|a| self.transpile_expr(a).unwrap_or_default())
758 .collect();
759 return Ok(intrinsic.to_cuda(&args));
760 }
761 }
762
763 if let Some(intrinsic) = self.intrinsics.lookup(&method_name) {
765 let cuda_name = intrinsic.to_cuda_string();
766 let args: Vec<String> = std::iter::once(receiver)
767 .chain(
768 method
769 .args
770 .iter()
771 .map(|a| self.transpile_expr(a).unwrap_or_default()),
772 )
773 .collect();
774
775 return Ok(format!("{}({})", cuda_name, args.join(", ")));
776 }
777
778 let args: Vec<String> = method
780 .args
781 .iter()
782 .map(|a| self.transpile_expr(a))
783 .collect::<Result<_>>()?;
784
785 Ok(format!("{}.{}({})", receiver, method_name, args.join(", ")))
786 }
787
788 fn transpile_context_method(
790 &self,
791 method: &str,
792 args: &syn::punctuated::Punctuated<Expr, syn::token::Comma>,
793 ) -> Result<String> {
794 let ctx_method = ContextMethod::from_name(method).ok_or_else(|| {
795 TranspileError::Unsupported(format!("Unknown context method: {}", method))
796 })?;
797
798 let cuda_args: Vec<String> = args
799 .iter()
800 .map(|a| self.transpile_expr(a).unwrap_or_default())
801 .collect();
802
803 Ok(ctx_method.to_cuda(&cuda_args))
804 }
805
806 fn transpile_stencil_intrinsic(
808 &self,
809 method: &str,
810 args: &syn::punctuated::Punctuated<Expr, syn::token::Comma>,
811 ) -> Result<String> {
812 let config = self.config.as_ref().ok_or_else(|| {
813 TranspileError::Unsupported("Stencil intrinsic without config".into())
814 })?;
815
816 let buffer_width = config.buffer_width().to_string();
817
818 let intrinsic = StencilIntrinsic::from_method_name(method).ok_or_else(|| {
819 TranspileError::Unsupported(format!("Unknown stencil intrinsic: {method}"))
820 })?;
821
822 match intrinsic {
823 StencilIntrinsic::Index => {
824 Ok("idx".to_string())
826 }
827 StencilIntrinsic::North
828 | StencilIntrinsic::South
829 | StencilIntrinsic::East
830 | StencilIntrinsic::West => {
831 if args.is_empty() {
833 return Err(TranspileError::Unsupported(
834 "Stencil accessor requires buffer argument".into(),
835 ));
836 }
837 let buffer = self.transpile_expr(&args[0])?;
838 Ok(intrinsic.to_cuda_index_2d(&buffer, &buffer_width, "idx"))
839 }
840 StencilIntrinsic::At => {
841 if args.len() < 3 {
843 return Err(TranspileError::Unsupported(
844 "at() requires buffer, dx, dy arguments".into(),
845 ));
846 }
847 let buffer = self.transpile_expr(&args[0])?;
848 let dx = self.transpile_expr(&args[1])?;
849 let dy = self.transpile_expr(&args[2])?;
850 Ok(format!("{buffer}[idx + ({dy}) * {buffer_width} + ({dx})]"))
851 }
852 StencilIntrinsic::Up | StencilIntrinsic::Down => {
853 Err(TranspileError::Unsupported(
855 "3D stencil intrinsics not yet implemented".into(),
856 ))
857 }
858 }
859 }
860
861 fn transpile_if(&self, if_expr: &ExprIf) -> Result<String> {
863 let cond = self.transpile_expr(&if_expr.cond)?;
864
865 if let Some(Stmt::Expr(Expr::Return(ret), _)) = if_expr.then_branch.stmts.first() {
867 if if_expr.then_branch.stmts.len() == 1 && if_expr.else_branch.is_none() {
868 if ret.expr.is_none() {
870 return Ok(format!("if ({cond}) return"));
871 }
872 let ret_val = self.transpile_expr(ret.expr.as_ref().unwrap())?;
873 return Ok(format!("if ({cond}) return {ret_val}"));
874 }
875 }
876
877 if let Some((_, else_branch)) = &if_expr.else_branch {
879 if let (Some(Stmt::Expr(then_expr, None)), Expr::Block(else_block)) =
881 (if_expr.then_branch.stmts.last(), else_branch.as_ref())
882 {
883 if let Some(Stmt::Expr(else_expr, None)) = else_block.block.stmts.last() {
884 let then_str = self.transpile_expr(then_expr)?;
885 let else_str = self.transpile_expr(else_expr)?;
886 return Ok(format!("({cond}) ? ({then_str}) : ({else_str})"));
887 }
888 }
889
890 if let Expr::If(else_if) = else_branch.as_ref() {
892 let then_body = self.transpile_if_body(&if_expr.then_branch)?;
894 let else_part = self.transpile_if(else_if)?;
895 return Ok(format!("if ({cond}) {{{then_body}}} else {else_part}"));
896 } else if let Expr::Block(else_block) = else_branch.as_ref() {
897 let then_body = self.transpile_if_body(&if_expr.then_branch)?;
899 let else_body = self.transpile_if_body(&else_block.block)?;
900 return Ok(format!("if ({cond}) {{{then_body}}} else {{{else_body}}}"));
901 }
902 }
903
904 let then_body = self.transpile_if_body(&if_expr.then_branch)?;
906 Ok(format!("if ({cond}) {{{then_body}}}"))
907 }
908
909 fn transpile_if_body(&self, block: &syn::Block) -> Result<String> {
911 let mut body = String::new();
912 for stmt in &block.stmts {
913 match stmt {
914 Stmt::Expr(expr, Some(_)) => {
915 let expr_str = self.transpile_expr(expr)?;
916 body.push_str(&format!(" {expr_str};"));
917 }
918 Stmt::Expr(Expr::Return(ret), None) => {
919 if let Some(ret_expr) = &ret.expr {
921 let expr_str = self.transpile_expr(ret_expr)?;
922 body.push_str(&format!(" return {expr_str};"));
923 } else {
924 body.push_str(" return;");
925 }
926 }
927 Stmt::Expr(expr, None) => {
928 let expr_str = self.transpile_expr(expr)?;
929 body.push_str(&format!(" return {expr_str};"));
930 }
931 _ => {}
932 }
933 }
934 Ok(body)
935 }
936
937 fn transpile_assign(&self, assign: &ExprAssign) -> Result<String> {
939 let left = self.transpile_expr(&assign.left)?;
940 let right = self.transpile_expr(&assign.right)?;
941 Ok(format!("{left} = {right}"))
942 }
943
944 fn transpile_cast(&self, cast: &ExprCast) -> Result<String> {
946 let expr = self.transpile_expr(&cast.expr)?;
947 let cuda_type = self.type_mapper.map_type(&cast.ty)?;
948 Ok(format!("({})({})", cuda_type.to_cuda_string(), expr))
949 }
950
951 fn transpile_return(&self, ret: &ExprReturn) -> Result<String> {
953 if let Some(expr) = &ret.expr {
954 let expr_str = self.transpile_expr(expr)?;
955 Ok(format!("return {expr_str}"))
956 } else {
957 Ok("return".to_string())
958 }
959 }
960
961 fn transpile_struct_literal(&self, struct_expr: &ExprStruct) -> Result<String> {
966 let type_name = struct_expr
968 .path
969 .segments
970 .iter()
971 .map(|s| s.ident.to_string())
972 .collect::<Vec<_>>()
973 .join("::");
974
975 let mut fields = Vec::new();
977 for field in &struct_expr.fields {
978 let field_name = match &field.member {
979 syn::Member::Named(ident) => ident.to_string(),
980 syn::Member::Unnamed(idx) => idx.index.to_string(),
981 };
982 let value = self.transpile_expr(&field.expr)?;
983 fields.push(format!(".{} = {}", field_name, value));
984 }
985
986 if struct_expr.rest.is_some() {
988 return Err(TranspileError::Unsupported(
989 "Struct update syntax (..base) is not supported in CUDA".into(),
990 ));
991 }
992
993 Ok(format!("({}){{ {} }}", type_name, fields.join(", ")))
995 }
996
997 fn transpile_reference(&self, ref_expr: &ExprReference) -> Result<String> {
1004 let inner = self.transpile_expr(&ref_expr.expr)?;
1005
1006 Ok(format!("&{inner}"))
1010 }
1011
1012 fn transpile_let_expr(&self, let_expr: &ExprLet) -> Result<String> {
1018 let _ = let_expr; Err(TranspileError::Unsupported(
1022 "let expressions (if-let patterns) are not directly supported in CUDA. \
1023 Use explicit comparisons instead."
1024 .into(),
1025 ))
1026 }
1027
1028 fn transpile_for_loop(&self, for_loop: &ExprForLoop) -> Result<String> {
1048 if !self.validation_mode.allows_loops() {
1050 return Err(TranspileError::Unsupported(
1051 "Loops are not allowed in stencil kernels".into(),
1052 ));
1053 }
1054
1055 let var_name = extract_loop_var(&for_loop.pat)
1057 .ok_or_else(|| TranspileError::Unsupported("Complex pattern in for loop".into()))?;
1058
1059 let header = match for_loop.expr.as_ref() {
1061 Expr::Range(range) => {
1062 let range_info = RangeInfo::from_range(range, |e| self.transpile_expr(e));
1063 range_info.to_cuda_for_header(&var_name, "int")
1064 }
1065 _ => {
1066 return Err(TranspileError::Unsupported(
1068 "Only range expressions (start..end) are supported in for loops".into(),
1069 ));
1070 }
1071 };
1072
1073 let body = self.transpile_loop_body(&for_loop.body)?;
1075
1076 Ok(format!("{header} {{\n{body}}}"))
1077 }
1078
1079 fn transpile_while_loop(&self, while_loop: &ExprWhile) -> Result<String> {
1095 if !self.validation_mode.allows_loops() {
1097 return Err(TranspileError::Unsupported(
1098 "Loops are not allowed in stencil kernels".into(),
1099 ));
1100 }
1101
1102 let condition = self.transpile_expr(&while_loop.cond)?;
1104
1105 let body = self.transpile_loop_body(&while_loop.body)?;
1107
1108 Ok(format!("while ({condition}) {{\n{body}}}"))
1109 }
1110
1111 fn transpile_infinite_loop(&self, loop_expr: &ExprLoop) -> Result<String> {
1127 if !self.validation_mode.allows_loops() {
1129 return Err(TranspileError::Unsupported(
1130 "Loops are not allowed in stencil kernels".into(),
1131 ));
1132 }
1133
1134 let body = self.transpile_loop_body(&loop_expr.body)?;
1136
1137 Ok(format!("while (true) {{\n{body}}}"))
1139 }
1140
1141 fn transpile_break(&self, break_expr: &ExprBreak) -> Result<String> {
1143 if break_expr.label.is_some() {
1145 return Err(TranspileError::Unsupported(
1146 "Labeled break is not supported in CUDA".into(),
1147 ));
1148 }
1149
1150 if break_expr.expr.is_some() {
1152 return Err(TranspileError::Unsupported(
1153 "Break with value is not supported in CUDA".into(),
1154 ));
1155 }
1156
1157 Ok("break".to_string())
1158 }
1159
1160 fn transpile_continue(&self, cont_expr: &ExprContinue) -> Result<String> {
1162 if cont_expr.label.is_some() {
1164 return Err(TranspileError::Unsupported(
1165 "Labeled continue is not supported in CUDA".into(),
1166 ));
1167 }
1168
1169 Ok("continue".to_string())
1170 }
1171
1172 fn transpile_loop_body(&self, block: &syn::Block) -> Result<String> {
1174 let mut output = String::new();
1175 let inner_indent = " ".repeat(self.indent + 1);
1176
1177 for stmt in &block.stmts {
1178 match stmt {
1179 Stmt::Local(local) => {
1180 let var_name = match &local.pat {
1182 Pat::Ident(ident) => ident.ident.to_string(),
1183 Pat::Type(pat_type) => {
1184 if let Pat::Ident(ident) = pat_type.pat.as_ref() {
1185 ident.ident.to_string()
1186 } else {
1187 return Err(TranspileError::Unsupported(
1188 "Complex pattern in let binding".into(),
1189 ));
1190 }
1191 }
1192 _ => {
1193 return Err(TranspileError::Unsupported(
1194 "Complex pattern in let binding".into(),
1195 ))
1196 }
1197 };
1198
1199 if let Some(init) = &local.init {
1200 let expr_str = self.transpile_expr(&init.expr)?;
1201 let type_str = self.infer_cuda_type(&init.expr);
1202 output.push_str(&format!(
1203 "{inner_indent}{type_str} {var_name} = {expr_str};\n"
1204 ));
1205 } else {
1206 output.push_str(&format!("{inner_indent}float {var_name};\n"));
1207 }
1208 }
1209 Stmt::Expr(expr, semi) => {
1210 let expr_str = self.transpile_expr(expr)?;
1211 if semi.is_some() {
1212 output.push_str(&format!("{inner_indent}{expr_str};\n"));
1213 } else {
1214 output.push_str(&format!("{inner_indent}{expr_str};\n"));
1216 }
1217 }
1218 _ => {
1219 return Err(TranspileError::Unsupported(
1220 "Unsupported statement in loop body".into(),
1221 ));
1222 }
1223 }
1224 }
1225
1226 let closing_indent = " ".repeat(self.indent);
1228 output.push_str(&closing_indent);
1229
1230 Ok(output)
1231 }
1232
1233 fn try_parse_shared_declaration(
1242 &self,
1243 local: &syn::Local,
1244 var_name: &str,
1245 ) -> Result<Option<SharedMemoryDecl>> {
1246 if let Pat::Type(pat_type) = &local.pat {
1248 let type_str = pat_type.ty.to_token_stream().to_string();
1249 return self.parse_shared_type(&type_str, var_name);
1250 }
1251
1252 if let Some(init) = &local.init {
1254 if let Expr::Call(call) = init.expr.as_ref() {
1255 if let Expr::Path(path) = call.func.as_ref() {
1256 let path_str = path.to_token_stream().to_string();
1257 return self.parse_shared_type(&path_str, var_name);
1258 }
1259 }
1260 }
1261
1262 Ok(None)
1263 }
1264
1265 fn parse_shared_type(
1267 &self,
1268 type_str: &str,
1269 var_name: &str,
1270 ) -> Result<Option<SharedMemoryDecl>> {
1271 let type_str = type_str
1273 .replace(" :: ", "::")
1274 .replace(" ::", "::")
1275 .replace(":: ", "::");
1276
1277 if type_str.contains("SharedTile") {
1279 if let Some(start) = type_str.find('<') {
1281 if let Some(end) = type_str.rfind('>') {
1282 let params = &type_str[start + 1..end];
1283 let parts: Vec<&str> = params.split(',').map(|s| s.trim()).collect();
1284
1285 if parts.len() >= 3 {
1286 let rust_type = parts[0];
1287 let width: usize = parts[1].parse().map_err(|_| {
1288 TranspileError::Unsupported("Invalid SharedTile width".into())
1289 })?;
1290 let height: usize = parts[2].parse().map_err(|_| {
1291 TranspileError::Unsupported("Invalid SharedTile height".into())
1292 })?;
1293
1294 let cuda_type = rust_to_cuda_element_type(rust_type);
1295 return Ok(Some(SharedMemoryDecl::tile(
1296 var_name, cuda_type, width, height,
1297 )));
1298 }
1299 }
1300 }
1301 }
1302
1303 if type_str.contains("SharedArray") {
1305 if let Some(start) = type_str.find('<') {
1306 if let Some(end) = type_str.rfind('>') {
1307 let params = &type_str[start + 1..end];
1308 let parts: Vec<&str> = params.split(',').map(|s| s.trim()).collect();
1309
1310 if parts.len() >= 2 {
1311 let rust_type = parts[0];
1312 let size: usize = parts[1].parse().map_err(|_| {
1313 TranspileError::Unsupported("Invalid SharedArray size".into())
1314 })?;
1315
1316 let cuda_type = rust_to_cuda_element_type(rust_type);
1317 return Ok(Some(SharedMemoryDecl::array(var_name, cuda_type, size)));
1318 }
1319 }
1320 }
1321 }
1322
1323 Ok(None)
1324 }
1325
1326 fn try_transpile_shared_method_call(
1328 &self,
1329 receiver: &str,
1330 method_name: &str,
1331 args: &syn::punctuated::Punctuated<Expr, syn::token::Comma>,
1332 ) -> Option<Result<String>> {
1333 let shared_info = self.shared_vars.get(receiver)?;
1334
1335 match method_name {
1336 "get" => {
1337 if shared_info.is_tile {
1339 if args.len() >= 2 {
1340 let x = self.transpile_expr(&args[0]).ok()?;
1341 let y = self.transpile_expr(&args[1]).ok()?;
1342 Some(Ok(format!("{}[{}][{}]", receiver, y, x)))
1344 } else {
1345 Some(Err(TranspileError::Unsupported(
1346 "SharedTile.get requires x and y arguments".into(),
1347 )))
1348 }
1349 } else {
1350 if !args.is_empty() {
1352 let idx = self.transpile_expr(&args[0]).ok()?;
1353 Some(Ok(format!("{}[{}]", receiver, idx)))
1354 } else {
1355 Some(Err(TranspileError::Unsupported(
1356 "SharedArray.get requires index argument".into(),
1357 )))
1358 }
1359 }
1360 }
1361 "set" => {
1362 if shared_info.is_tile {
1364 if args.len() >= 3 {
1365 let x = self.transpile_expr(&args[0]).ok()?;
1366 let y = self.transpile_expr(&args[1]).ok()?;
1367 let val = self.transpile_expr(&args[2]).ok()?;
1368 Some(Ok(format!("{}[{}][{}] = {}", receiver, y, x, val)))
1369 } else {
1370 Some(Err(TranspileError::Unsupported(
1371 "SharedTile.set requires x, y, and value arguments".into(),
1372 )))
1373 }
1374 } else {
1375 if args.len() >= 2 {
1377 let idx = self.transpile_expr(&args[0]).ok()?;
1378 let val = self.transpile_expr(&args[1]).ok()?;
1379 Some(Ok(format!("{}[{}] = {}", receiver, idx, val)))
1380 } else {
1381 Some(Err(TranspileError::Unsupported(
1382 "SharedArray.set requires index and value arguments".into(),
1383 )))
1384 }
1385 }
1386 }
1387 "width" | "height" | "size" => {
1388 match method_name {
1390 "width" if shared_info.is_tile => {
1391 Some(Ok(shared_info.dimensions[1].to_string()))
1392 }
1393 "height" if shared_info.is_tile => {
1394 Some(Ok(shared_info.dimensions[0].to_string()))
1395 }
1396 "size" => {
1397 let total: usize = shared_info.dimensions.iter().product();
1398 Some(Ok(total.to_string()))
1399 }
1400 _ => None,
1401 }
1402 }
1403 _ => None,
1404 }
1405 }
1406
1407 fn transpile_match(&self, match_expr: &ExprMatch) -> Result<String> {
1409 let scrutinee = self.transpile_expr(&match_expr.expr)?;
1410 let mut output = format!("switch ({scrutinee}) {{\n");
1411
1412 for arm in &match_expr.arms {
1413 let case_label = self.transpile_match_pattern(&arm.pat)?;
1415
1416 if case_label == "default" || case_label.starts_with("/*") {
1417 output.push_str(" default: {\n");
1418 } else {
1419 output.push_str(&format!(" case {case_label}: {{\n"));
1420 }
1421
1422 match arm.body.as_ref() {
1424 Expr::Block(block) => {
1425 for stmt in &block.block.stmts {
1427 let stmt_str = self.transpile_stmt_inline(stmt)?;
1428 output.push_str(&format!(" {stmt_str}\n"));
1429 }
1430 }
1431 _ => {
1432 let body = self.transpile_expr(&arm.body)?;
1434 output.push_str(&format!(" {body};\n"));
1435 }
1436 }
1437
1438 output.push_str(" break;\n");
1439 output.push_str(" }\n");
1440 }
1441
1442 output.push_str(" }");
1443 Ok(output)
1444 }
1445
1446 fn transpile_match_pattern(&self, pat: &Pat) -> Result<String> {
1448 match pat {
1449 Pat::Lit(pat_lit) => {
1450 match &pat_lit.lit {
1452 Lit::Int(i) => Ok(i.to_string()),
1453 Lit::Bool(b) => Ok(if b.value { "1" } else { "0" }.to_string()),
1454 _ => Err(TranspileError::Unsupported(
1455 "Non-integer literal in match pattern".into(),
1456 )),
1457 }
1458 }
1459 Pat::Wild(_) => {
1460 Ok("default".to_string())
1462 }
1463 Pat::Ident(ident) => {
1464 Ok(format!("/* {} */ default", ident.ident))
1467 }
1468 Pat::Or(pat_or) => {
1469 if let Some(first) = pat_or.cases.first() {
1473 self.transpile_match_pattern(first)
1474 } else {
1475 Err(TranspileError::Unsupported("Empty or pattern".into()))
1476 }
1477 }
1478 _ => Err(TranspileError::Unsupported(format!(
1479 "Match pattern: {}",
1480 pat.to_token_stream()
1481 ))),
1482 }
1483 }
1484
1485 fn transpile_stmt_inline(&self, stmt: &Stmt) -> Result<String> {
1487 match stmt {
1488 Stmt::Local(local) => {
1489 let var_name = match &local.pat {
1490 Pat::Ident(ident) => ident.ident.to_string(),
1491 Pat::Type(pat_type) => {
1492 if let Pat::Ident(ident) = pat_type.pat.as_ref() {
1493 ident.ident.to_string()
1494 } else {
1495 return Err(TranspileError::Unsupported(
1496 "Complex pattern in let binding".into(),
1497 ));
1498 }
1499 }
1500 _ => {
1501 return Err(TranspileError::Unsupported(
1502 "Complex pattern in let binding".into(),
1503 ))
1504 }
1505 };
1506
1507 if let Some(init) = &local.init {
1508 let expr_str = self.transpile_expr(&init.expr)?;
1509 let type_str = self.infer_cuda_type(&init.expr);
1510 Ok(format!("{type_str} {var_name} = {expr_str};"))
1511 } else {
1512 Ok(format!("float {var_name};"))
1513 }
1514 }
1515 Stmt::Expr(expr, semi) => {
1516 let expr_str = self.transpile_expr(expr)?;
1517 if semi.is_some() {
1518 Ok(format!("{expr_str};"))
1519 } else {
1520 Ok(format!("return {expr_str};"))
1521 }
1522 }
1523 _ => Err(TranspileError::Unsupported(
1524 "Unsupported statement in match arm".into(),
1525 )),
1526 }
1527 }
1528
1529 fn infer_cuda_type(&self, expr: &Expr) -> &'static str {
1531 match expr {
1532 Expr::Lit(lit) => match &lit.lit {
1533 Lit::Float(_) => "float",
1534 Lit::Int(_) => "int",
1535 Lit::Bool(_) => "int",
1536 _ => "float",
1537 },
1538 Expr::Binary(bin) => {
1539 let left_type = self.infer_cuda_type(&bin.left);
1541 let right_type = self.infer_cuda_type(&bin.right);
1542 if left_type == "int" && right_type == "int" {
1544 "int"
1545 } else {
1546 "float"
1547 }
1548 }
1549 Expr::Call(call) => {
1550 if let Ok(func) = self.transpile_expr(&call.func) {
1552 if let Some(intrinsic) = self.intrinsics.lookup(&func) {
1553 let cuda_name = intrinsic.to_cuda_string();
1554 if cuda_name.contains("Idx") || cuda_name.contains("Dim") {
1556 return "int";
1557 }
1558 }
1559 }
1560 "float"
1561 }
1562 Expr::Index(_) => "float", Expr::Cast(cast) => {
1564 if let Ok(cuda_type) = self.type_mapper.map_type(&cast.ty) {
1566 let s = cuda_type.to_cuda_string();
1567 if s.contains("int") || s.contains("size_t") || s == "unsigned long long" {
1568 return "int";
1569 }
1570 }
1571 "float"
1572 }
1573 Expr::Reference(ref_expr) => {
1574 match ref_expr.expr.as_ref() {
1577 Expr::Index(idx_expr) => {
1578 if let Expr::Path(path) = &*idx_expr.expr {
1580 let name = path
1581 .path
1582 .segments
1583 .iter()
1584 .map(|s| s.ident.to_string())
1585 .collect::<Vec<_>>()
1586 .join("::");
1587 if name.contains("transaction") || name.contains("Transaction") {
1589 return "GpuTransaction*";
1590 }
1591 if name.contains("profile") || name.contains("Profile") {
1592 return "GpuCustomerProfile*";
1593 }
1594 if name.contains("alert") || name.contains("Alert") {
1595 return "GpuAlert*";
1596 }
1597 }
1598 "float*" }
1600 _ => "void*",
1601 }
1602 }
1603 Expr::MethodCall(_) => "float",
1604 Expr::Field(field) => {
1605 let member_name = match &field.member {
1607 syn::Member::Named(ident) => ident.to_string(),
1608 syn::Member::Unnamed(idx) => idx.index.to_string(),
1609 };
1610 if member_name.contains("count") || member_name.contains("_count") {
1612 return "unsigned int";
1613 }
1614 if member_name.contains("threshold") || member_name.ends_with("_id") {
1615 return "unsigned long long";
1616 }
1617 if member_name.ends_with("_pct") {
1618 return "unsigned char";
1619 }
1620 "float"
1621 }
1622 Expr::Path(path) => {
1623 let name = path
1625 .path
1626 .segments
1627 .iter()
1628 .map(|s| s.ident.to_string())
1629 .collect::<Vec<_>>()
1630 .join("::");
1631 if name.contains("threshold")
1632 || name.contains("count")
1633 || name == "idx"
1634 || name == "n"
1635 {
1636 return "int";
1637 }
1638 "float"
1639 }
1640 Expr::If(if_expr) => {
1641 if let Some((_, else_branch)) = &if_expr.else_branch {
1643 if let Expr::Block(block) = else_branch.as_ref() {
1644 if let Some(Stmt::Expr(expr, None)) = block.block.stmts.last() {
1645 return self.infer_cuda_type(expr);
1646 }
1647 }
1648 }
1649 if let Some(Stmt::Expr(expr, None)) = if_expr.then_branch.stmts.last() {
1651 return self.infer_cuda_type(expr);
1652 }
1653 "float"
1654 }
1655 _ => "float",
1656 }
1657 }
1658}
1659
1660pub fn transpile_function(func: &ItemFn) -> Result<String> {
1662 let mut transpiler = CudaTranspiler::new_generic();
1663
1664 let name = func.sig.ident.to_string();
1666
1667 let mut params = Vec::new();
1668 for param in &func.sig.inputs {
1669 if let FnArg::Typed(pat_type) = param {
1670 let param_name = match pat_type.pat.as_ref() {
1671 Pat::Ident(ident) => ident.ident.to_string(),
1672 _ => continue,
1673 };
1674
1675 let cuda_type = transpiler.type_mapper.map_type(&pat_type.ty)?;
1676 params.push(format!("{} {}", cuda_type.to_cuda_string(), param_name));
1677 }
1678 }
1679
1680 let return_type = match &func.sig.output {
1682 ReturnType::Default => "void".to_string(),
1683 ReturnType::Type(_, ty) => transpiler.type_mapper.map_type(ty)?.to_cuda_string(),
1684 };
1685
1686 let body = transpiler.transpile_block(&func.block)?;
1688
1689 Ok(format!(
1690 "__device__ {return_type} {name}({params}) {{\n{body}}}\n",
1691 params = params.join(", ")
1692 ))
1693}
1694
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    // ---- shared fixtures -------------------------------------------------

    /// Lowers `expr` with a fresh generic-mode transpiler, panicking on error.
    fn generic_cuda(expr: &Expr) -> String {
        CudaTranspiler::new_generic().transpile_expr(expr).unwrap()
    }

    /// Stencil configuration used by the stencil tests: 16x16 tile, halo 1.
    fn stencil_16x16(name: &'static str) -> StencilConfig {
        StencilConfig::new(name).with_tile_size(16, 16).with_halo(1)
    }

    /// Registers a `float` shared variable with the transpiler under test.
    fn track_shared(
        transpiler: &mut CudaTranspiler,
        name: &str,
        is_tile: bool,
        dimensions: Vec<usize>,
    ) {
        transpiler.shared_vars.insert(
            name.to_string(),
            SharedVarInfo {
                name: name.to_string(),
                is_tile,
                dimensions,
                element_type: "float".to_string(),
            },
        );
    }

    // ---- basic expressions and statements --------------------------------

    #[test]
    fn test_simple_arithmetic() {
        let arithmetic: Expr = parse_quote!(a + b * 2.0);
        assert_eq!(generic_cuda(&arithmetic), "a + b * 2.0f");
    }

    #[test]
    fn test_let_binding() {
        let binding: Stmt = parse_quote!(let x = a + b;);
        let mut transpiler = CudaTranspiler::new_generic();
        let cuda = transpiler.transpile_stmt(&binding).unwrap();
        assert!(cuda.contains("float x = a + b;"));
    }

    #[test]
    fn test_array_index() {
        let indexed: Expr = parse_quote!(data[idx]);
        assert_eq!(generic_cuda(&indexed), "data[idx]");
    }

    #[test]
    fn test_stencil_intrinsics() {
        let mut transpiler = CudaTranspiler::new(stencil_16x16("test"));
        transpiler.grid_pos_vars.push("pos".to_string());

        // (source expression, expected CUDA) for a buffer width of 18.
        let cases: [(Expr, &str); 3] = [
            (parse_quote!(pos.idx()), "idx"),
            (parse_quote!(pos.north(p)), "p[idx - 18]"),
            (parse_quote!(pos.east(p)), "p[idx + 1]"),
        ];
        for (source, expected) in &cases {
            let cuda = transpiler.transpile_expr(source).unwrap();
            assert_eq!(cuda, *expected);
        }
    }

    #[test]
    fn test_ternary_if() {
        let conditional: Expr = parse_quote!(if x > 0.0 { x } else { -x });
        let cuda = generic_cuda(&conditional);
        assert!(cuda.contains("?"));
        assert!(cuda.contains(":"));
    }

    #[test]
    fn test_full_stencil_kernel() {
        let kernel: ItemFn = parse_quote! {
            fn fdtd(p: &[f32], p_prev: &mut [f32], c2: f32, pos: GridPos) {
                let curr = p[pos.idx()];
                let prev = p_prev[pos.idx()];
                let lap = pos.north(p) + pos.south(p) + pos.east(p) + pos.west(p) - 4.0 * curr;
                p_prev[pos.idx()] = (2.0 * curr - prev + c2 * lap);
            }
        };

        let mut transpiler = CudaTranspiler::new(stencil_16x16("fdtd"));
        let cuda = transpiler.transpile_stencil(&kernel).unwrap();

        let required = [
            "extern \"C\" __global__",
            "threadIdx.x",
            "threadIdx.y",
            "buffer_width = 18",
            "const float* __restrict__ p",
            "float* __restrict__ p_prev",
        ];
        for needle in &required {
            assert!(cuda.contains(*needle));
        }
        // The grid-position parameter must not leak into the signature.
        assert!(!cuda.contains("GridPos"));

        println!("Generated CUDA:\n{}", cuda);
    }

    #[test]
    fn test_early_return() {
        let mut transpiler = CudaTranspiler::new_generic();

        let bare: Stmt = parse_quote!(return;);
        let bare_cuda = transpiler.transpile_stmt(&bare).unwrap();
        assert!(bare_cuda.contains("return;"));

        let with_value: Stmt = parse_quote!(return 42;);
        let value_cuda = transpiler.transpile_stmt(&with_value).unwrap();
        assert!(value_cuda.contains("return 42;"));
    }

    #[test]
    fn test_match_to_switch() {
        let edge_match: Expr = parse_quote! {
            match edge {
                0 => { idx = 1 * 18 + i; }
                1 => { idx = 16 * 18 + i; }
                _ => { idx = 0; }
            }
        };

        let cuda = generic_cuda(&edge_match);
        assert!(
            cuda.contains("switch (edge)"),
            "Should generate switch: {}",
            cuda
        );
        assert!(cuda.contains("case 0:"), "Should have case 0: {}", cuda);
        assert!(cuda.contains("case 1:"), "Should have case 1: {}", cuda);
        assert!(cuda.contains("default:"), "Should have default: {}", cuda);
        assert!(cuda.contains("break;"), "Should have break: {}", cuda);

        println!("Generated switch:\n{}", cuda);
    }

    // ---- thread / block index intrinsics ---------------------------------

    #[test]
    fn test_block_idx_intrinsics() {
        let transpiler = CudaTranspiler::new_generic();

        let cases: [(Expr, &str); 3] = [
            (parse_quote!(block_idx_x()), "blockIdx.x"),
            (parse_quote!(thread_idx_y()), "threadIdx.y"),
            (parse_quote!(grid_dim_x()), "gridDim.x"),
        ];
        for (call, expected) in &cases {
            let cuda = transpiler.transpile_expr(call).unwrap();
            assert_eq!(cuda, *expected);
        }
    }

    #[test]
    fn test_global_index_calculation() {
        let global_index: Expr = parse_quote!(block_idx_x() * block_dim_x() + thread_idx_x());
        let cuda = generic_cuda(&global_index);

        assert!(cuda.contains("blockIdx.x"), "Should contain blockIdx.x");
        assert!(cuda.contains("blockDim.x"), "Should contain blockDim.x");
        assert!(cuda.contains("threadIdx.x"), "Should contain threadIdx.x");

        println!("Global index expression: {}", cuda);
    }

    // ---- loops -------------------------------------------------------------

    #[test]
    fn test_for_loop_transpile() {
        let zero_fill: Expr = parse_quote! {
            for i in 0..n {
                data[i] = 0.0;
            }
        };

        let cuda = generic_cuda(&zero_fill);
        assert!(
            cuda.contains("for (int i = 0; i < n; i++)"),
            "Should generate for loop header: {}",
            cuda
        );
        assert!(
            cuda.contains("data[i] = 0.0f"),
            "Should contain loop body: {}",
            cuda
        );

        println!("Generated for loop:\n{}", cuda);
    }

    #[test]
    fn test_for_loop_inclusive_range() {
        let inclusive: Expr = parse_quote! {
            for i in 1..=10 {
                sum += i;
            }
        };

        let cuda = generic_cuda(&inclusive);
        assert!(
            cuda.contains("for (int i = 1; i <= 10; i++)"),
            "Should generate inclusive range: {}",
            cuda
        );

        println!("Generated inclusive for loop:\n{}", cuda);
    }

    #[test]
    fn test_while_loop_transpile() {
        let counting: Expr = parse_quote! {
            while i < 10 {
                i += 1;
            }
        };

        let cuda = generic_cuda(&counting);
        assert!(
            cuda.contains("while (i < 10)"),
            "Should generate while loop: {}",
            cuda
        );
        assert!(
            cuda.contains("i += 1"),
            "Should contain loop body: {}",
            cuda
        );

        println!("Generated while loop:\n{}", cuda);
    }

    #[test]
    fn test_while_loop_negation() {
        let until_done: Expr = parse_quote! {
            while !done {
                process();
            }
        };

        let cuda = generic_cuda(&until_done);
        assert!(
            cuda.contains("while (!(done))"),
            "Should negate condition: {}",
            cuda
        );

        println!("Generated while loop with negation:\n{}", cuda);
    }

    #[test]
    fn test_infinite_loop_transpile() {
        let forever: Expr = parse_quote! {
            loop {
                process();
            }
        };

        let cuda = generic_cuda(&forever);
        assert!(
            cuda.contains("while (true)"),
            "Should generate infinite loop: {}",
            cuda
        );
        assert!(
            cuda.contains("process()"),
            "Should contain loop body: {}",
            cuda
        );

        println!("Generated infinite loop:\n{}", cuda);
    }

    #[test]
    fn test_break_transpile() {
        let brk: Expr = parse_quote!(break);
        assert_eq!(generic_cuda(&brk), "break");
    }

    #[test]
    fn test_continue_transpile() {
        let cont: Expr = parse_quote!(continue);
        assert_eq!(generic_cuda(&cont), "continue");
    }

    #[test]
    fn test_loop_with_break() {
        let breaking_loop: Expr = parse_quote! {
            loop {
                if done {
                    break;
                }
            }
        };

        let cuda = generic_cuda(&breaking_loop);
        assert!(
            cuda.contains("while (true)"),
            "Should generate infinite loop: {}",
            cuda
        );
        assert!(cuda.contains("break"), "Should contain break: {}", cuda);

        println!("Generated loop with break:\n{}", cuda);
    }

    #[test]
    fn test_nested_loops() {
        let nested: Expr = parse_quote! {
            for i in 0..m {
                for j in 0..n {
                    matrix[i * n + j] = 0.0;
                }
            }
        };

        let cuda = generic_cuda(&nested);
        assert!(
            cuda.contains("for (int i = 0; i < m; i++)"),
            "Should have outer loop: {}",
            cuda
        );
        assert!(
            cuda.contains("for (int j = 0; j < n; j++)"),
            "Should have inner loop: {}",
            cuda
        );

        println!("Generated nested loops:\n{}", cuda);
    }

    #[test]
    fn test_stencil_mode_rejects_loops() {
        let looped: Expr = parse_quote! {
            for i in 0..n {
                data[i] = 0.0;
            }
        };

        let outcome = CudaTranspiler::new(stencil_16x16("test")).transpile_expr(&looped);
        assert!(outcome.is_err(), "Stencil mode should reject loops");
    }

    #[test]
    fn test_labeled_break_rejected() {
        // Built by hand: a labeled `break` outside a loop is awkward to quote.
        let labeled = syn::ExprBreak {
            attrs: Vec::new(),
            break_token: syn::token::Break::default(),
            label: Some(syn::Lifetime::new("'outer", proc_macro2::Span::call_site())),
            expr: None,
        };

        let outcome = CudaTranspiler::new_generic().transpile_break(&labeled);
        assert!(outcome.is_err(), "Labeled break should be rejected");
    }

    #[test]
    fn test_full_kernel_with_loop() {
        let kernel: ItemFn = parse_quote! {
            fn fill_array(data: &mut [f32], n: i32) {
                for i in 0..n {
                    data[i as usize] = 0.0;
                }
            }
        };

        let mut transpiler = CudaTranspiler::new_generic();
        let cuda = transpiler.transpile_generic_kernel(&kernel).unwrap();

        assert!(
            cuda.contains("extern \"C\" __global__"),
            "Should be global kernel: {}",
            cuda
        );
        assert!(
            cuda.contains("for (int i = 0; i < n; i++)"),
            "Should have for loop: {}",
            cuda
        );

        println!("Generated kernel with loop:\n{}", cuda);
    }

    #[test]
    fn test_persistent_kernel_pattern() {
        let message_loop: Expr = parse_quote! {
            while !should_terminate {
                if has_message {
                    process_message();
                }
            }
        };

        let transpiler = CudaTranspiler::with_mode(ValidationMode::RingKernel);
        let cuda = transpiler.transpile_expr(&message_loop).unwrap();
        assert!(
            cuda.contains("while (!(should_terminate))"),
            "Should have persistent loop: {}",
            cuda
        );
        assert!(
            cuda.contains("if (has_message)"),
            "Should have message check: {}",
            cuda
        );

        println!("Generated persistent kernel pattern:\n{}", cuda);
    }

    // ---- shared memory -----------------------------------------------------

    #[test]
    fn test_shared_tile_declaration() {
        use crate::shared::{SharedMemoryConfig, SharedMemoryDecl};

        let tile_decl = SharedMemoryDecl::tile("tile", "float", 16, 16);
        assert_eq!(tile_decl.to_cuda_decl(), "__shared__ float tile[16][16];");

        let mut shared = SharedMemoryConfig::new();
        shared.add_tile("tile", "float", 16, 16);
        assert_eq!(shared.total_bytes(), 16 * 16 * 4);

        let rendered = shared.generate_declarations(" ");
        assert!(rendered.contains("__shared__ float tile[16][16];"));
    }

    #[test]
    fn test_shared_array_declaration() {
        use crate::shared::{SharedMemoryConfig, SharedMemoryDecl};

        let array_decl = SharedMemoryDecl::array("buffer", "float", 256);
        assert_eq!(array_decl.to_cuda_decl(), "__shared__ float buffer[256];");

        let mut shared = SharedMemoryConfig::new();
        shared.add_array("buffer", "float", 256);
        assert_eq!(shared.total_bytes(), 256 * 4);
    }

    #[test]
    fn test_shared_memory_access_expressions() {
        use crate::shared::SharedMemoryDecl;

        let tile_indices = ["y".to_string(), "x".to_string()];
        let tile_decl = SharedMemoryDecl::tile("tile", "float", 16, 16);
        assert_eq!(tile_decl.to_cuda_access(&tile_indices), "tile[y][x]");

        let array_indices = ["i".to_string()];
        let array_decl = SharedMemoryDecl::array("buf", "int", 128);
        assert_eq!(array_decl.to_cuda_access(&array_indices), "buf[i]");
    }

    #[test]
    fn test_parse_shared_tile_type() {
        use crate::shared::parse_shared_tile_type;

        let turbofish = parse_shared_tile_type("SharedTile::<f32, 16, 16>");
        assert_eq!(turbofish, Some(("f32".to_string(), 16, 16)));

        let plain = parse_shared_tile_type("SharedTile<i32, 32, 8>");
        assert_eq!(plain, Some(("i32".to_string(), 32, 8)));

        let unrelated = parse_shared_tile_type("Vec<f32>");
        assert_eq!(unrelated, None);
    }

    #[test]
    fn test_parse_shared_array_type() {
        use crate::shared::parse_shared_array_type;

        let turbofish = parse_shared_array_type("SharedArray::<f32, 256>");
        assert_eq!(turbofish, Some(("f32".to_string(), 256)));

        let plain = parse_shared_array_type("SharedArray<u32, 1024>");
        assert_eq!(plain, Some(("u32".to_string(), 1024)));

        let unrelated = parse_shared_array_type("Vec<f32>");
        assert_eq!(unrelated, None);
    }

    #[test]
    fn test_rust_to_cuda_element_types() {
        use crate::shared::rust_to_cuda_element_type;

        let expectations = [
            ("f32", "float"),
            ("f64", "double"),
            ("i32", "int"),
            ("u32", "unsigned int"),
            ("i64", "long long"),
            ("u64", "unsigned long long"),
            ("bool", "int"),
        ];
        for (rust_type, cuda_type) in &expectations {
            assert_eq!(rust_to_cuda_element_type(rust_type), *cuda_type);
        }
    }

    #[test]
    fn test_shared_memory_total_bytes() {
        use crate::shared::SharedMemoryConfig;

        let mut shared = SharedMemoryConfig::new();
        shared.add_tile("tile1", "float", 16, 16);
        shared.add_tile("tile2", "double", 8, 8);
        shared.add_array("temp", "int", 64);

        assert_eq!(shared.total_bytes(), 1024 + 512 + 256);
    }

    #[test]
    fn test_transpiler_shared_var_tracking() {
        let mut transpiler = CudaTranspiler::new_generic();
        track_shared(&mut transpiler, "tile", true, vec![16, 16]);

        assert!(transpiler.shared_vars.contains_key("tile"));
        assert!(transpiler.shared_vars.get("tile").unwrap().is_tile);
    }

    #[test]
    fn test_shared_tile_get_transpilation() {
        let mut transpiler = CudaTranspiler::new_generic();
        track_shared(&mut transpiler, "tile", true, vec![16, 16]);

        // `get` with no arguments must not yield a successful lowering.
        let no_args = syn::punctuated::Punctuated::new();
        let outcome = transpiler.try_transpile_shared_method_call("tile", "get", &no_args);

        assert!(outcome.map_or(true, |lowered| lowered.is_err()));
    }

    #[test]
    fn test_shared_array_access() {
        let mut transpiler = CudaTranspiler::new_generic();
        track_shared(&mut transpiler, "buffer", false, vec![256]);

        let info = &transpiler.shared_vars["buffer"];
        assert!(!info.is_tile);
        assert_eq!(info.dimensions, vec![256]);
    }

    #[test]
    fn test_full_kernel_with_shared_memory() {
        use crate::shared::SharedMemoryConfig;

        let mut shared = SharedMemoryConfig::new();
        shared.add_tile("smem", "float", 16, 16);

        let rendered = shared.generate_declarations(" ");
        assert!(rendered.contains("__shared__ float smem[16][16];"));
        assert!(!shared.is_empty());
    }

    // ---- struct literals ---------------------------------------------------

    #[test]
    fn test_struct_literal_transpile() {
        let point: Expr = parse_quote! {
            Point { x: 1.0, y: 2.0 }
        };

        let cuda = generic_cuda(&point);
        assert!(
            cuda.contains("Point"),
            "Should contain struct name: {}",
            cuda
        );
        assert!(cuda.contains(".x ="), "Should have field x: {}", cuda);
        assert!(cuda.contains(".y ="), "Should have field y: {}", cuda);
        assert!(cuda.contains("1.0f"), "Should have value 1.0f: {}", cuda);
        assert!(cuda.contains("2.0f"), "Should have value 2.0f: {}", cuda);

        println!("Generated struct literal: {}", cuda);
    }

    #[test]
    fn test_struct_literal_with_expressions() {
        let response: Expr = parse_quote! {
            Response { value: x * 2.0, id: idx as u64 }
        };

        let cuda = generic_cuda(&response);
        assert!(
            cuda.contains("Response"),
            "Should contain struct name: {}",
            cuda
        );
        assert!(
            cuda.contains(".value = x * 2.0f"),
            "Should have computed value: {}",
            cuda
        );
        assert!(cuda.contains(".id ="), "Should have id field: {}", cuda);

        println!("Generated struct with expressions: {}", cuda);
    }

    #[test]
    fn test_struct_literal_in_return() {
        let returning: Stmt = parse_quote! {
            return MyStruct { a: 1, b: 2.0 };
        };

        let mut transpiler = CudaTranspiler::new_generic();
        let cuda = transpiler.transpile_stmt(&returning).unwrap();
        assert!(cuda.contains("return"), "Should have return: {}", cuda);
        assert!(
            cuda.contains("MyStruct"),
            "Should contain struct name: {}",
            cuda
        );

        println!("Generated return with struct: {}", cuda);
    }

    #[test]
    fn test_struct_literal_compound_literal_format() {
        let vec3: Expr = parse_quote! {
            Vec3 { x: a, y: b, z: c }
        };

        let cuda = generic_cuda(&vec3);
        assert!(
            cuda.starts_with("(Vec3){"),
            "Should use compound literal format: {}",
            cuda
        );
        assert!(
            cuda.ends_with("}"),
            "Should end with closing brace: {}",
            cuda
        );

        println!("Generated compound literal: {}", cuda);
    }

    // ---- references --------------------------------------------------------

    #[test]
    fn test_reference_to_array_element() {
        let element_ref: Expr = parse_quote! {
            &arr[idx]
        };

        assert_eq!(
            generic_cuda(&element_ref),
            "&arr[idx]",
            "Should produce address-of array element"
        );
    }

    #[test]
    fn test_mutable_reference_to_array_element() {
        let element_mut_ref: Expr = parse_quote! {
            &mut arr[idx * 4 + offset]
        };

        let cuda = generic_cuda(&element_mut_ref);
        assert!(
            cuda.contains("&arr["),
            "Should produce address-of: {}",
            cuda
        );
        assert!(
            cuda.contains("idx * 4"),
            "Should have index expression: {}",
            cuda
        );
    }

    #[test]
    fn test_reference_to_variable() {
        let variable_ref: Expr = parse_quote! {
            &value
        };

        assert_eq!(
            generic_cuda(&variable_ref),
            "&value",
            "Should produce address-of variable"
        );
    }

    #[test]
    fn test_reference_to_struct_field() {
        let alert_ref: Expr = parse_quote! {
            &alerts[(idx as usize) * 4 + alert_idx as usize]
        };

        let cuda = generic_cuda(&alert_ref);
        assert!(
            cuda.starts_with("&alerts["),
            "Should have address-of array: {}",
            cuda
        );

        println!("Generated reference: {}", cuda);
    }

    #[test]
    fn test_complex_reference_pattern() {
        let alert_binding: Stmt = parse_quote! {
            let alert = &mut alerts[(idx as usize) * 4 + alert_idx as usize];
        };

        let mut transpiler = CudaTranspiler::new_generic();
        let cuda = transpiler.transpile_stmt(&alert_binding).unwrap();
        assert!(
            cuda.contains("alert ="),
            "Should have variable assignment: {}",
            cuda
        );
        assert!(
            cuda.contains("&alerts["),
            "Should have reference to array: {}",
            cuda
        );

        println!("Generated statement: {}", cuda);
    }
}