1use crate::handler::{ContextMethod, HandlerSignature};
6use crate::intrinsics::{IntrinsicRegistry, RingKernelIntrinsic, StencilIntrinsic};
7use crate::loops::{extract_loop_var, RangeInfo};
8use crate::shared::{rust_to_cuda_element_type, SharedMemoryConfig, SharedMemoryDecl};
9use crate::stencil::StencilConfig;
10use crate::types::{is_grid_pos_type, is_ring_context_type, TypeMapper};
11use crate::validation::ValidationMode;
12use crate::{Result, TranspileError};
13use quote::ToTokens;
14use syn::{
15 BinOp, Expr, ExprAssign, ExprBinary, ExprBreak, ExprCall, ExprCast, ExprContinue, ExprForLoop,
16 ExprIf, ExprIndex, ExprLet, ExprLit, ExprLoop, ExprMatch, ExprMethodCall, ExprParen, ExprPath,
17 ExprReference, ExprReturn, ExprStruct, ExprUnary, ExprWhile, FnArg, ItemFn, Lit, Pat,
18 ReturnType, Stmt, UnOp,
19};
20
21pub struct CudaTranspiler {
23 config: Option<StencilConfig>,
25 type_mapper: TypeMapper,
27 intrinsics: IntrinsicRegistry,
29 grid_pos_vars: Vec<String>,
31 context_vars: Vec<String>,
33 indent: usize,
35 validation_mode: ValidationMode,
37 shared_memory: SharedMemoryConfig,
39 pub shared_vars: std::collections::HashMap<String, SharedVarInfo>,
41 ring_kernel_mode: bool,
43 pointer_vars: std::collections::HashSet<String>,
45}
46
/// Bookkeeping for a local bound to shared memory (`SharedTile` or `SharedArray`).
#[derive(Debug, Clone)]
pub struct SharedVarInfo {
    /// Variable name as written in the Rust source.
    pub name: String,
    /// `true` when the declaration has exactly two dimensions (a tile).
    pub is_tile: bool,
    /// Extent of each declared dimension.
    pub dimensions: Vec<usize>,
    /// Element type string copied from the shared-memory declaration.
    pub element_type: String,
}
59
60impl CudaTranspiler {
61 pub fn new(config: StencilConfig) -> Self {
63 Self {
64 config: Some(config),
65 type_mapper: TypeMapper::new(),
66 intrinsics: IntrinsicRegistry::new(),
67 grid_pos_vars: Vec::new(),
68 context_vars: Vec::new(),
69 indent: 1, validation_mode: ValidationMode::Stencil,
71 shared_memory: SharedMemoryConfig::new(),
72 shared_vars: std::collections::HashMap::new(),
73 ring_kernel_mode: false,
74 pointer_vars: std::collections::HashSet::new(),
75 }
76 }
77
78 pub fn new_generic() -> Self {
80 Self {
81 config: None,
82 type_mapper: TypeMapper::new(),
83 intrinsics: IntrinsicRegistry::new(),
84 grid_pos_vars: Vec::new(),
85 context_vars: Vec::new(),
86 indent: 1,
87 validation_mode: ValidationMode::Generic,
88 shared_memory: SharedMemoryConfig::new(),
89 shared_vars: std::collections::HashMap::new(),
90 ring_kernel_mode: false,
91 pointer_vars: std::collections::HashSet::new(),
92 }
93 }
94
95 pub fn with_mode(mode: ValidationMode) -> Self {
97 Self {
98 config: None,
99 type_mapper: TypeMapper::new(),
100 intrinsics: IntrinsicRegistry::new(),
101 grid_pos_vars: Vec::new(),
102 context_vars: Vec::new(),
103 indent: 1,
104 validation_mode: mode,
105 shared_memory: SharedMemoryConfig::new(),
106 shared_vars: std::collections::HashMap::new(),
107 ring_kernel_mode: false,
108 pointer_vars: std::collections::HashSet::new(),
109 }
110 }
111
112 pub fn for_ring_kernel() -> Self {
114 Self {
115 config: None,
116 type_mapper: crate::types::ring_kernel_type_mapper(),
117 intrinsics: IntrinsicRegistry::new(),
118 grid_pos_vars: Vec::new(),
119 context_vars: Vec::new(),
120 indent: 2, validation_mode: ValidationMode::Generic,
122 shared_memory: SharedMemoryConfig::new(),
123 shared_vars: std::collections::HashMap::new(),
124 ring_kernel_mode: true,
125 pointer_vars: std::collections::HashSet::new(),
126 }
127 }
128
129 pub fn set_validation_mode(&mut self, mode: ValidationMode) {
131 self.validation_mode = mode;
132 }
133
134 pub fn shared_memory(&self) -> &SharedMemoryConfig {
136 &self.shared_memory
137 }
138
139 fn indent_str(&self) -> String {
141 " ".repeat(self.indent)
142 }
143
144 pub fn transpile_stencil(&mut self, func: &ItemFn) -> Result<String> {
146 let config = self
147 .config
148 .as_ref()
149 .ok_or_else(|| TranspileError::Unsupported("No stencil config provided".into()))?
150 .clone();
151
152 for param in &func.sig.inputs {
154 if let FnArg::Typed(pat_type) = param {
155 if is_grid_pos_type(&pat_type.ty) {
156 if let Pat::Ident(ident) = pat_type.pat.as_ref() {
157 self.grid_pos_vars.push(ident.ident.to_string());
158 }
159 }
160 }
161 }
162
163 let signature = self.transpile_kernel_signature(func)?;
165
166 let preamble = config.generate_preamble();
168
169 let body = self.transpile_block(&func.block)?;
171
172 Ok(format!(
173 "extern \"C\" __global__ void {signature} {{\n{preamble}\n{body}}}\n"
174 ))
175 }
176
177 pub fn transpile_generic_kernel(&mut self, func: &ItemFn) -> Result<String> {
183 let signature = self.transpile_generic_kernel_signature(func)?;
185
186 let body = self.transpile_block(&func.block)?;
188
189 Ok(format!(
190 "extern \"C\" __global__ void {signature} {{\n{body}}}\n"
191 ))
192 }
193
194 pub fn transpile_ring_kernel(
199 &mut self,
200 handler: &ItemFn,
201 config: &crate::ring_kernel::RingKernelConfig,
202 ) -> Result<String> {
203 use std::fmt::Write;
204
205 let handler_sig = HandlerSignature::parse(handler, &self.type_mapper)?;
207
208 for param in &handler.sig.inputs {
210 if let FnArg::Typed(pat_type) = param {
211 if is_ring_context_type(&pat_type.ty) {
212 if let Pat::Ident(ident) = pat_type.pat.as_ref() {
213 self.context_vars.push(ident.ident.to_string());
214 }
215 }
216 }
217 }
218
219 self.ring_kernel_mode = true;
221
222 let mut output = String::new();
223
224 output.push_str(&crate::ring_kernel::generate_control_block_struct());
226 output.push('\n');
227
228 if config.enable_hlc {
229 output.push_str(&crate::ring_kernel::generate_hlc_struct());
230 output.push('\n');
231 }
232
233 if config.enable_k2k {
234 output.push_str(&crate::ring_kernel::generate_k2k_structs());
235 output.push('\n');
236 }
237
238 if let Some(ref msg_param) = handler_sig.message_param {
240 let type_name = msg_param
242 .rust_type
243 .trim_start_matches('&')
244 .trim_start_matches("mut ")
245 .trim();
246 if !type_name.is_empty() && type_name != "f32" && type_name != "i32" {
247 writeln!(output, "// Message type: {}", type_name).unwrap();
248 }
249 }
250
251 if let Some(ref ret_type) = handler_sig.return_type {
252 if ret_type.is_struct {
253 writeln!(output, "// Response type: {}", ret_type.rust_type).unwrap();
254 }
255 }
256
257 output.push_str(&config.generate_signature());
259 output.push_str(" {\n");
260
261 output.push_str(&config.generate_preamble(" "));
263
264 output.push_str(&config.generate_loop_header(" "));
266
267 if let Some(ref msg_param) = handler_sig.message_param {
269 let type_name = msg_param
270 .rust_type
271 .trim_start_matches('&')
272 .trim_start_matches("mut ")
273 .trim();
274 if !type_name.is_empty() {
275 writeln!(output, " // Message deserialization").unwrap();
276 writeln!(
277 output,
278 " // {}* {} = ({}*)msg_ptr;",
279 type_name, msg_param.name, type_name
280 )
281 .unwrap();
282 output.push('\n');
283 }
284 }
285
286 self.indent = 2; let handler_body = self.transpile_block(&handler.block)?;
289
290 writeln!(output, " // === USER HANDLER CODE ===").unwrap();
292 for line in handler_body.lines() {
293 if !line.trim().is_empty() {
294 writeln!(output, " {}", line).unwrap();
296 }
297 }
298 writeln!(output, " // === END HANDLER CODE ===").unwrap();
299
300 if let Some(ref ret_type) = handler_sig.return_type {
302 writeln!(output).unwrap();
303 writeln!(output, " // Response serialization").unwrap();
304 if ret_type.is_struct {
305 writeln!(output, " // memcpy(&output_buffer[_out_idx * RESP_SIZE], &response, sizeof({}));",
306 ret_type.cuda_type).unwrap();
307 }
308 }
309
310 output.push_str(&config.generate_message_complete(" "));
312
313 output.push_str(&config.generate_loop_footer(" "));
315
316 output.push_str(&config.generate_epilogue(" "));
318
319 output.push_str("}\n");
320
321 Ok(output)
322 }
323
324 fn transpile_generic_kernel_signature(&self, func: &ItemFn) -> Result<String> {
326 let name = func.sig.ident.to_string();
327
328 let mut params = Vec::new();
329 for param in &func.sig.inputs {
330 if let FnArg::Typed(pat_type) = param {
331 let param_name = match pat_type.pat.as_ref() {
332 Pat::Ident(ident) => ident.ident.to_string(),
333 _ => {
334 return Err(TranspileError::Unsupported(
335 "Complex pattern in parameter".into(),
336 ))
337 }
338 };
339
340 let cuda_type = self.type_mapper.map_type(&pat_type.ty)?;
341 params.push(format!("{} {}", cuda_type.to_cuda_string(), param_name));
342 }
343 }
344
345 Ok(format!("{}({})", name, params.join(", ")))
346 }
347
348 fn transpile_kernel_signature(&self, func: &ItemFn) -> Result<String> {
350 let name = func.sig.ident.to_string();
351
352 let mut params = Vec::new();
353 for param in &func.sig.inputs {
354 if let FnArg::Typed(pat_type) = param {
355 if is_grid_pos_type(&pat_type.ty) {
357 continue;
358 }
359
360 let param_name = match pat_type.pat.as_ref() {
361 Pat::Ident(ident) => ident.ident.to_string(),
362 _ => {
363 return Err(TranspileError::Unsupported(
364 "Complex pattern in parameter".into(),
365 ))
366 }
367 };
368
369 let cuda_type = self.type_mapper.map_type(&pat_type.ty)?;
370 params.push(format!("{} {}", cuda_type.to_cuda_string(), param_name));
371 }
372 }
373
374 Ok(format!("{}({})", name, params.join(", ")))
375 }
376
377 fn transpile_block(&mut self, block: &syn::Block) -> Result<String> {
379 let mut output = String::new();
380
381 for stmt in &block.stmts {
382 let stmt_str = self.transpile_stmt(stmt)?;
383 if !stmt_str.is_empty() {
384 output.push_str(&stmt_str);
385 }
386 }
387
388 Ok(output)
389 }
390
391 fn transpile_stmt(&mut self, stmt: &Stmt) -> Result<String> {
393 match stmt {
394 Stmt::Local(local) => {
395 let indent = self.indent_str();
396
397 let var_name = match &local.pat {
399 Pat::Ident(ident) => ident.ident.to_string(),
400 Pat::Type(pat_type) => {
401 if let Pat::Ident(ident) = pat_type.pat.as_ref() {
402 ident.ident.to_string()
403 } else {
404 return Err(TranspileError::Unsupported(
405 "Complex pattern in let binding".into(),
406 ));
407 }
408 }
409 _ => {
410 return Err(TranspileError::Unsupported(
411 "Complex pattern in let binding".into(),
412 ))
413 }
414 };
415
416 if let Some(shared_decl) = self.try_parse_shared_declaration(local, &var_name)? {
418 self.shared_vars.insert(
420 var_name.clone(),
421 SharedVarInfo {
422 name: var_name.clone(),
423 is_tile: shared_decl.dimensions.len() == 2,
424 dimensions: shared_decl.dimensions.clone(),
425 element_type: shared_decl.element_type.clone(),
426 },
427 );
428
429 self.shared_memory.add(shared_decl.clone());
431
432 return Ok(format!("{indent}{}\n", shared_decl.to_cuda_decl()));
434 }
435
436 if let Some(init) = &local.init {
438 let expr_str = self.transpile_expr(&init.expr)?;
439
440 let type_str = self.infer_cuda_type(&init.expr);
443
444 if type_str.ends_with('*') {
446 self.pointer_vars.insert(var_name.clone());
447 }
448
449 Ok(format!("{indent}{type_str} {var_name} = {expr_str};\n"))
450 } else {
451 Ok(format!("{indent}float {var_name};\n"))
453 }
454 }
455 Stmt::Expr(expr, semi) => {
456 let indent = self.indent_str();
457
458 if let Expr::If(if_expr) = expr {
460 if let Some(Stmt::Expr(Expr::Return(_), _)) = if_expr.then_branch.stmts.first()
462 {
463 if if_expr.then_branch.stmts.len() == 1 && if_expr.else_branch.is_none() {
464 let expr_str = self.transpile_expr(expr)?;
465 return Ok(format!("{indent}{expr_str};\n"));
466 }
467 }
468 }
469
470 let expr_str = self.transpile_expr(expr)?;
471
472 if semi.is_some() {
473 Ok(format!("{indent}{expr_str};\n"))
474 } else {
475 if matches!(expr, Expr::Return(_))
478 || expr_str.starts_with("return")
479 || expr_str.starts_with("if (")
480 {
481 Ok(format!("{indent}{expr_str};\n"))
482 } else {
483 Ok(format!("{indent}return {expr_str};\n"))
484 }
485 }
486 }
487 Stmt::Item(_) => {
488 Err(TranspileError::Unsupported("Item in function body".into()))
490 }
491 Stmt::Macro(_) => Err(TranspileError::Unsupported("Macro in function body".into())),
492 }
493 }
494
495 fn transpile_expr(&self, expr: &Expr) -> Result<String> {
497 match expr {
498 Expr::Lit(lit) => self.transpile_lit(lit),
499 Expr::Path(path) => self.transpile_path(path),
500 Expr::Binary(bin) => self.transpile_binary(bin),
501 Expr::Unary(unary) => self.transpile_unary(unary),
502 Expr::Paren(paren) => self.transpile_paren(paren),
503 Expr::Index(index) => self.transpile_index(index),
504 Expr::Call(call) => self.transpile_call(call),
505 Expr::MethodCall(method) => self.transpile_method_call(method),
506 Expr::If(if_expr) => self.transpile_if(if_expr),
507 Expr::Assign(assign) => self.transpile_assign(assign),
508 Expr::Cast(cast) => self.transpile_cast(cast),
509 Expr::Match(match_expr) => self.transpile_match(match_expr),
510 Expr::Block(block) => {
511 if let Some(Stmt::Expr(expr, None)) = block.block.stmts.last() {
513 self.transpile_expr(expr)
514 } else {
515 Err(TranspileError::Unsupported(
516 "Complex block expression".into(),
517 ))
518 }
519 }
520 Expr::Field(field) => {
521 let base = self.transpile_expr(&field.base)?;
523 let member = match &field.member {
524 syn::Member::Named(ident) => ident.to_string(),
525 syn::Member::Unnamed(idx) => idx.index.to_string(),
526 };
527
528 let accessor = if self.pointer_vars.contains(&base) {
530 "->"
531 } else {
532 "."
533 };
534 Ok(format!("{base}{accessor}{member}"))
535 }
536 Expr::Return(ret) => self.transpile_return(ret),
537 Expr::ForLoop(for_loop) => self.transpile_for_loop(for_loop),
538 Expr::While(while_loop) => self.transpile_while_loop(while_loop),
539 Expr::Loop(loop_expr) => self.transpile_infinite_loop(loop_expr),
540 Expr::Break(break_expr) => self.transpile_break(break_expr),
541 Expr::Continue(cont_expr) => self.transpile_continue(cont_expr),
542 Expr::Struct(struct_expr) => self.transpile_struct_literal(struct_expr),
543 Expr::Reference(ref_expr) => self.transpile_reference(ref_expr),
544 Expr::Let(let_expr) => self.transpile_let_expr(let_expr),
545 Expr::Tuple(tuple) => {
546 let elements: Vec<String> = tuple
548 .elems
549 .iter()
550 .map(|e| self.transpile_expr(e))
551 .collect::<Result<_>>()?;
552 Ok(format!("({})", elements.join(", ")))
553 }
554 _ => Err(TranspileError::Unsupported(format!(
555 "Expression type: {}",
556 expr.to_token_stream()
557 ))),
558 }
559 }
560
561 fn transpile_lit(&self, lit: &ExprLit) -> Result<String> {
563 match &lit.lit {
564 Lit::Float(f) => {
565 let s = f.to_string();
566 if s.ends_with("f32") || !s.contains('.') {
568 let num = s.trim_end_matches("f32").trim_end_matches("f64");
569 Ok(format!("{num}f"))
570 } else if s.ends_with("f64") {
571 Ok(s.trim_end_matches("f64").to_string())
572 } else {
573 Ok(format!("{s}f"))
575 }
576 }
577 Lit::Int(i) => Ok(i.to_string()),
578 Lit::Bool(b) => Ok(if b.value { "1" } else { "0" }.to_string()),
579 _ => Err(TranspileError::Unsupported(format!(
580 "Literal type: {}",
581 lit.to_token_stream()
582 ))),
583 }
584 }
585
586 fn transpile_path(&self, path: &ExprPath) -> Result<String> {
588 let segments: Vec<_> = path
589 .path
590 .segments
591 .iter()
592 .map(|s| s.ident.to_string())
593 .collect();
594
595 if segments.len() == 1 {
596 Ok(segments[0].clone())
597 } else {
598 Ok(segments.join("::"))
599 }
600 }
601
602 fn transpile_binary(&self, bin: &ExprBinary) -> Result<String> {
604 let left = self.transpile_expr(&bin.left)?;
605 let right = self.transpile_expr(&bin.right)?;
606
607 let op = match bin.op {
608 BinOp::Add(_) => "+",
609 BinOp::Sub(_) => "-",
610 BinOp::Mul(_) => "*",
611 BinOp::Div(_) => "/",
612 BinOp::Rem(_) => "%",
613 BinOp::And(_) => "&&",
614 BinOp::Or(_) => "||",
615 BinOp::BitXor(_) => "^",
616 BinOp::BitAnd(_) => "&",
617 BinOp::BitOr(_) => "|",
618 BinOp::Shl(_) => "<<",
619 BinOp::Shr(_) => ">>",
620 BinOp::Eq(_) => "==",
621 BinOp::Lt(_) => "<",
622 BinOp::Le(_) => "<=",
623 BinOp::Ne(_) => "!=",
624 BinOp::Ge(_) => ">=",
625 BinOp::Gt(_) => ">",
626 BinOp::AddAssign(_) => "+=",
627 BinOp::SubAssign(_) => "-=",
628 BinOp::MulAssign(_) => "*=",
629 BinOp::DivAssign(_) => "/=",
630 BinOp::RemAssign(_) => "%=",
631 BinOp::BitXorAssign(_) => "^=",
632 BinOp::BitAndAssign(_) => "&=",
633 BinOp::BitOrAssign(_) => "|=",
634 BinOp::ShlAssign(_) => "<<=",
635 BinOp::ShrAssign(_) => ">>=",
636 _ => {
637 return Err(TranspileError::Unsupported(format!(
638 "Binary operator: {}",
639 bin.to_token_stream()
640 )))
641 }
642 };
643
644 Ok(format!("{left} {op} {right}"))
645 }
646
647 fn transpile_unary(&self, unary: &ExprUnary) -> Result<String> {
649 let expr = self.transpile_expr(&unary.expr)?;
650
651 let op = match unary.op {
652 UnOp::Neg(_) => "-",
653 UnOp::Not(_) => "!",
654 UnOp::Deref(_) => "*",
655 _ => {
656 return Err(TranspileError::Unsupported(format!(
657 "Unary operator: {}",
658 unary.to_token_stream()
659 )))
660 }
661 };
662
663 Ok(format!("{op}({expr})"))
664 }
665
666 fn transpile_paren(&self, paren: &ExprParen) -> Result<String> {
668 let inner = self.transpile_expr(&paren.expr)?;
669 Ok(format!("({inner})"))
670 }
671
672 fn transpile_index(&self, index: &ExprIndex) -> Result<String> {
674 let base = self.transpile_expr(&index.expr)?;
675 let idx = self.transpile_expr(&index.index)?;
676 Ok(format!("{base}[{idx}]"))
677 }
678
679 fn transpile_call(&self, call: &ExprCall) -> Result<String> {
681 let func = self.transpile_expr(&call.func)?;
682
683 if let Some(intrinsic) = self.intrinsics.lookup(&func) {
685 let cuda_name = intrinsic.to_cuda_string();
686
687 let is_value_intrinsic = cuda_name.contains("Idx.")
690 || cuda_name.contains("Dim.")
691 || cuda_name.starts_with("threadIdx")
692 || cuda_name.starts_with("blockIdx")
693 || cuda_name.starts_with("blockDim")
694 || cuda_name.starts_with("gridDim");
695
696 if is_value_intrinsic && call.args.is_empty() {
697 return Ok(cuda_name.to_string());
699 }
700
701 if call.args.is_empty() && cuda_name.ends_with("()") {
702 return Ok(cuda_name.to_string());
704 }
705
706 let args: Vec<String> = call
707 .args
708 .iter()
709 .map(|a| self.transpile_expr(a))
710 .collect::<Result<_>>()?;
711
712 return Ok(format!(
713 "{}({})",
714 cuda_name.trim_end_matches("()"),
715 args.join(", ")
716 ));
717 }
718
719 let args: Vec<String> = call
721 .args
722 .iter()
723 .map(|a| self.transpile_expr(a))
724 .collect::<Result<_>>()?;
725
726 Ok(format!("{}({})", func, args.join(", ")))
727 }
728
729 fn transpile_method_call(&self, method: &ExprMethodCall) -> Result<String> {
731 let receiver = self.transpile_expr(&method.receiver)?;
732 let method_name = method.method.to_string();
733
734 if let Some(result) =
736 self.try_transpile_shared_method_call(&receiver, &method_name, &method.args)
737 {
738 return result;
739 }
740
741 if self.ring_kernel_mode && self.context_vars.contains(&receiver) {
743 return self.transpile_context_method(&method_name, &method.args);
744 }
745
746 if self.grid_pos_vars.contains(&receiver) {
748 return self.transpile_stencil_intrinsic(&method_name, &method.args);
749 }
750
751 if self.ring_kernel_mode {
753 if let Some(intrinsic) = RingKernelIntrinsic::from_name(&method_name) {
754 let args: Vec<String> = method
755 .args
756 .iter()
757 .map(|a| self.transpile_expr(a).unwrap_or_default())
758 .collect();
759 return Ok(intrinsic.to_cuda(&args));
760 }
761 }
762
763 if let Some(intrinsic) = self.intrinsics.lookup(&method_name) {
765 let cuda_name = intrinsic.to_cuda_string();
766 let args: Vec<String> = std::iter::once(receiver)
767 .chain(
768 method
769 .args
770 .iter()
771 .map(|a| self.transpile_expr(a).unwrap_or_default()),
772 )
773 .collect();
774
775 return Ok(format!("{}({})", cuda_name, args.join(", ")));
776 }
777
778 let args: Vec<String> = method
780 .args
781 .iter()
782 .map(|a| self.transpile_expr(a))
783 .collect::<Result<_>>()?;
784
785 Ok(format!("{}.{}({})", receiver, method_name, args.join(", ")))
786 }
787
788 fn transpile_context_method(
790 &self,
791 method: &str,
792 args: &syn::punctuated::Punctuated<Expr, syn::token::Comma>,
793 ) -> Result<String> {
794 let ctx_method = ContextMethod::from_name(method).ok_or_else(|| {
795 TranspileError::Unsupported(format!("Unknown context method: {}", method))
796 })?;
797
798 let cuda_args: Vec<String> = args
799 .iter()
800 .map(|a| self.transpile_expr(a).unwrap_or_default())
801 .collect();
802
803 Ok(ctx_method.to_cuda(&cuda_args))
804 }
805
806 fn transpile_stencil_intrinsic(
808 &self,
809 method: &str,
810 args: &syn::punctuated::Punctuated<Expr, syn::token::Comma>,
811 ) -> Result<String> {
812 let config = self.config.as_ref().ok_or_else(|| {
813 TranspileError::Unsupported("Stencil intrinsic without config".into())
814 })?;
815
816 let buffer_width = config.buffer_width().to_string();
817 let buffer_slice = format!("{}", config.buffer_width() * config.buffer_height());
818 let is_3d = config.grid == crate::stencil::Grid::Grid3D;
819
820 let intrinsic = StencilIntrinsic::from_method_name(method).ok_or_else(|| {
821 TranspileError::Unsupported(format!("Unknown stencil intrinsic: {method}"))
822 })?;
823
824 if intrinsic.is_3d_only() && !is_3d {
826 return Err(TranspileError::Unsupported(format!(
827 "3D stencil intrinsic '{}' requires Grid3D configuration",
828 method
829 )));
830 }
831
832 match intrinsic {
833 StencilIntrinsic::Index => {
834 Ok("idx".to_string())
836 }
837 StencilIntrinsic::North
838 | StencilIntrinsic::South
839 | StencilIntrinsic::East
840 | StencilIntrinsic::West => {
841 if args.is_empty() {
843 return Err(TranspileError::Unsupported(
844 "Stencil accessor requires buffer argument".into(),
845 ));
846 }
847 let buffer = self.transpile_expr(&args[0])?;
848 if is_3d {
849 Ok(intrinsic.to_cuda_index_3d(&buffer, &buffer_width, &buffer_slice, "idx"))
850 } else {
851 Ok(intrinsic.to_cuda_index_2d(&buffer, &buffer_width, "idx"))
852 }
853 }
854 StencilIntrinsic::Up | StencilIntrinsic::Down => {
855 if args.is_empty() {
857 return Err(TranspileError::Unsupported(
858 "3D stencil accessor requires buffer argument".into(),
859 ));
860 }
861 let buffer = self.transpile_expr(&args[0])?;
862 Ok(intrinsic.to_cuda_index_3d(&buffer, &buffer_width, &buffer_slice, "idx"))
863 }
864 StencilIntrinsic::At => {
865 if is_3d {
868 if args.len() < 4 {
869 return Err(TranspileError::Unsupported(
870 "at() in 3D requires buffer, dx, dy, dz arguments".into(),
871 ));
872 }
873 let buffer = self.transpile_expr(&args[0])?;
874 let dx = self.transpile_expr(&args[1])?;
875 let dy = self.transpile_expr(&args[2])?;
876 let dz = self.transpile_expr(&args[3])?;
877 Ok(format!(
878 "{buffer}[idx + ({dz}) * {buffer_slice} + ({dy}) * {buffer_width} + ({dx})]"
879 ))
880 } else {
881 if args.len() < 3 {
882 return Err(TranspileError::Unsupported(
883 "at() requires buffer, dx, dy arguments".into(),
884 ));
885 }
886 let buffer = self.transpile_expr(&args[0])?;
887 let dx = self.transpile_expr(&args[1])?;
888 let dy = self.transpile_expr(&args[2])?;
889 Ok(format!("{buffer}[idx + ({dy}) * {buffer_width} + ({dx})]"))
890 }
891 }
892 }
893 }
894
895 fn transpile_if(&self, if_expr: &ExprIf) -> Result<String> {
897 let cond = self.transpile_expr(&if_expr.cond)?;
898
899 if let Some(Stmt::Expr(Expr::Return(ret), _)) = if_expr.then_branch.stmts.first() {
901 if if_expr.then_branch.stmts.len() == 1 && if_expr.else_branch.is_none() {
902 if ret.expr.is_none() {
904 return Ok(format!("if ({cond}) return"));
905 }
906 let ret_val = self.transpile_expr(ret.expr.as_ref().unwrap())?;
907 return Ok(format!("if ({cond}) return {ret_val}"));
908 }
909 }
910
911 if let Some((_, else_branch)) = &if_expr.else_branch {
913 if let (Some(Stmt::Expr(then_expr, None)), Expr::Block(else_block)) =
915 (if_expr.then_branch.stmts.last(), else_branch.as_ref())
916 {
917 if let Some(Stmt::Expr(else_expr, None)) = else_block.block.stmts.last() {
918 let then_str = self.transpile_expr(then_expr)?;
919 let else_str = self.transpile_expr(else_expr)?;
920 return Ok(format!("({cond}) ? ({then_str}) : ({else_str})"));
921 }
922 }
923
924 if let Expr::If(else_if) = else_branch.as_ref() {
926 let then_body = self.transpile_if_body(&if_expr.then_branch)?;
928 let else_part = self.transpile_if(else_if)?;
929 return Ok(format!("if ({cond}) {{{then_body}}} else {else_part}"));
930 } else if let Expr::Block(else_block) = else_branch.as_ref() {
931 let then_body = self.transpile_if_body(&if_expr.then_branch)?;
933 let else_body = self.transpile_if_body(&else_block.block)?;
934 return Ok(format!("if ({cond}) {{{then_body}}} else {{{else_body}}}"));
935 }
936 }
937
938 let then_body = self.transpile_if_body(&if_expr.then_branch)?;
940 Ok(format!("if ({cond}) {{{then_body}}}"))
941 }
942
943 fn transpile_if_body(&self, block: &syn::Block) -> Result<String> {
945 let mut body = String::new();
946 for stmt in &block.stmts {
947 match stmt {
948 Stmt::Expr(expr, Some(_)) => {
949 let expr_str = self.transpile_expr(expr)?;
950 body.push_str(&format!(" {expr_str};"));
951 }
952 Stmt::Expr(Expr::Return(ret), None) => {
953 if let Some(ret_expr) = &ret.expr {
955 let expr_str = self.transpile_expr(ret_expr)?;
956 body.push_str(&format!(" return {expr_str};"));
957 } else {
958 body.push_str(" return;");
959 }
960 }
961 Stmt::Expr(expr, None) => {
962 let expr_str = self.transpile_expr(expr)?;
963 body.push_str(&format!(" return {expr_str};"));
964 }
965 _ => {}
966 }
967 }
968 Ok(body)
969 }
970
971 fn transpile_assign(&self, assign: &ExprAssign) -> Result<String> {
973 let left = self.transpile_expr(&assign.left)?;
974 let right = self.transpile_expr(&assign.right)?;
975 Ok(format!("{left} = {right}"))
976 }
977
978 fn transpile_cast(&self, cast: &ExprCast) -> Result<String> {
980 let expr = self.transpile_expr(&cast.expr)?;
981 let cuda_type = self.type_mapper.map_type(&cast.ty)?;
982 Ok(format!("({})({})", cuda_type.to_cuda_string(), expr))
983 }
984
985 fn transpile_return(&self, ret: &ExprReturn) -> Result<String> {
987 if let Some(expr) = &ret.expr {
988 let expr_str = self.transpile_expr(expr)?;
989 Ok(format!("return {expr_str}"))
990 } else {
991 Ok("return".to_string())
992 }
993 }
994
995 fn transpile_struct_literal(&self, struct_expr: &ExprStruct) -> Result<String> {
1000 let type_name = struct_expr
1002 .path
1003 .segments
1004 .iter()
1005 .map(|s| s.ident.to_string())
1006 .collect::<Vec<_>>()
1007 .join("::");
1008
1009 let mut fields = Vec::new();
1011 for field in &struct_expr.fields {
1012 let field_name = match &field.member {
1013 syn::Member::Named(ident) => ident.to_string(),
1014 syn::Member::Unnamed(idx) => idx.index.to_string(),
1015 };
1016 let value = self.transpile_expr(&field.expr)?;
1017 fields.push(format!(".{} = {}", field_name, value));
1018 }
1019
1020 if struct_expr.rest.is_some() {
1022 return Err(TranspileError::Unsupported(
1023 "Struct update syntax (..base) is not supported in CUDA".into(),
1024 ));
1025 }
1026
1027 Ok(format!("({}){{ {} }}", type_name, fields.join(", ")))
1029 }
1030
1031 fn transpile_reference(&self, ref_expr: &ExprReference) -> Result<String> {
1038 let inner = self.transpile_expr(&ref_expr.expr)?;
1039
1040 Ok(format!("&{inner}"))
1044 }
1045
1046 fn transpile_let_expr(&self, let_expr: &ExprLet) -> Result<String> {
1052 let _ = let_expr; Err(TranspileError::Unsupported(
1056 "let expressions (if-let patterns) are not directly supported in CUDA. \
1057 Use explicit comparisons instead."
1058 .into(),
1059 ))
1060 }
1061
1062 fn transpile_for_loop(&self, for_loop: &ExprForLoop) -> Result<String> {
1082 if !self.validation_mode.allows_loops() {
1084 return Err(TranspileError::Unsupported(
1085 "Loops are not allowed in stencil kernels".into(),
1086 ));
1087 }
1088
1089 let var_name = extract_loop_var(&for_loop.pat)
1091 .ok_or_else(|| TranspileError::Unsupported("Complex pattern in for loop".into()))?;
1092
1093 let header = match for_loop.expr.as_ref() {
1095 Expr::Range(range) => {
1096 let range_info = RangeInfo::from_range(range, |e| self.transpile_expr(e));
1097 range_info.to_cuda_for_header(&var_name, "int")
1098 }
1099 _ => {
1100 return Err(TranspileError::Unsupported(
1102 "Only range expressions (start..end) are supported in for loops".into(),
1103 ));
1104 }
1105 };
1106
1107 let body = self.transpile_loop_body(&for_loop.body)?;
1109
1110 Ok(format!("{header} {{\n{body}}}"))
1111 }
1112
1113 fn transpile_while_loop(&self, while_loop: &ExprWhile) -> Result<String> {
1129 if !self.validation_mode.allows_loops() {
1131 return Err(TranspileError::Unsupported(
1132 "Loops are not allowed in stencil kernels".into(),
1133 ));
1134 }
1135
1136 let condition = self.transpile_expr(&while_loop.cond)?;
1138
1139 let body = self.transpile_loop_body(&while_loop.body)?;
1141
1142 Ok(format!("while ({condition}) {{\n{body}}}"))
1143 }
1144
1145 fn transpile_infinite_loop(&self, loop_expr: &ExprLoop) -> Result<String> {
1161 if !self.validation_mode.allows_loops() {
1163 return Err(TranspileError::Unsupported(
1164 "Loops are not allowed in stencil kernels".into(),
1165 ));
1166 }
1167
1168 let body = self.transpile_loop_body(&loop_expr.body)?;
1170
1171 Ok(format!("while (true) {{\n{body}}}"))
1173 }
1174
1175 fn transpile_break(&self, break_expr: &ExprBreak) -> Result<String> {
1177 if break_expr.label.is_some() {
1179 return Err(TranspileError::Unsupported(
1180 "Labeled break is not supported in CUDA".into(),
1181 ));
1182 }
1183
1184 if break_expr.expr.is_some() {
1186 return Err(TranspileError::Unsupported(
1187 "Break with value is not supported in CUDA".into(),
1188 ));
1189 }
1190
1191 Ok("break".to_string())
1192 }
1193
1194 fn transpile_continue(&self, cont_expr: &ExprContinue) -> Result<String> {
1196 if cont_expr.label.is_some() {
1198 return Err(TranspileError::Unsupported(
1199 "Labeled continue is not supported in CUDA".into(),
1200 ));
1201 }
1202
1203 Ok("continue".to_string())
1204 }
1205
1206 fn transpile_loop_body(&self, block: &syn::Block) -> Result<String> {
1208 let mut output = String::new();
1209 let inner_indent = " ".repeat(self.indent + 1);
1210
1211 for stmt in &block.stmts {
1212 match stmt {
1213 Stmt::Local(local) => {
1214 let var_name = match &local.pat {
1216 Pat::Ident(ident) => ident.ident.to_string(),
1217 Pat::Type(pat_type) => {
1218 if let Pat::Ident(ident) = pat_type.pat.as_ref() {
1219 ident.ident.to_string()
1220 } else {
1221 return Err(TranspileError::Unsupported(
1222 "Complex pattern in let binding".into(),
1223 ));
1224 }
1225 }
1226 _ => {
1227 return Err(TranspileError::Unsupported(
1228 "Complex pattern in let binding".into(),
1229 ))
1230 }
1231 };
1232
1233 if let Some(init) = &local.init {
1234 let expr_str = self.transpile_expr(&init.expr)?;
1235 let type_str = self.infer_cuda_type(&init.expr);
1236 output.push_str(&format!(
1237 "{inner_indent}{type_str} {var_name} = {expr_str};\n"
1238 ));
1239 } else {
1240 output.push_str(&format!("{inner_indent}float {var_name};\n"));
1241 }
1242 }
1243 Stmt::Expr(expr, semi) => {
1244 let expr_str = self.transpile_expr(expr)?;
1245 if semi.is_some() {
1246 output.push_str(&format!("{inner_indent}{expr_str};\n"));
1247 } else {
1248 output.push_str(&format!("{inner_indent}{expr_str};\n"));
1250 }
1251 }
1252 _ => {
1253 return Err(TranspileError::Unsupported(
1254 "Unsupported statement in loop body".into(),
1255 ));
1256 }
1257 }
1258 }
1259
1260 let closing_indent = " ".repeat(self.indent);
1262 output.push_str(&closing_indent);
1263
1264 Ok(output)
1265 }
1266
1267 fn try_parse_shared_declaration(
1276 &self,
1277 local: &syn::Local,
1278 var_name: &str,
1279 ) -> Result<Option<SharedMemoryDecl>> {
1280 if let Pat::Type(pat_type) = &local.pat {
1282 let type_str = pat_type.ty.to_token_stream().to_string();
1283 return self.parse_shared_type(&type_str, var_name);
1284 }
1285
1286 if let Some(init) = &local.init {
1288 if let Expr::Call(call) = init.expr.as_ref() {
1289 if let Expr::Path(path) = call.func.as_ref() {
1290 let path_str = path.to_token_stream().to_string();
1291 return self.parse_shared_type(&path_str, var_name);
1292 }
1293 }
1294 }
1295
1296 Ok(None)
1297 }
1298
1299 fn parse_shared_type(
1301 &self,
1302 type_str: &str,
1303 var_name: &str,
1304 ) -> Result<Option<SharedMemoryDecl>> {
1305 let type_str = type_str
1307 .replace(" :: ", "::")
1308 .replace(" ::", "::")
1309 .replace(":: ", "::");
1310
1311 if type_str.contains("SharedTile") {
1313 if let Some(start) = type_str.find('<') {
1315 if let Some(end) = type_str.rfind('>') {
1316 let params = &type_str[start + 1..end];
1317 let parts: Vec<&str> = params.split(',').map(|s| s.trim()).collect();
1318
1319 if parts.len() >= 3 {
1320 let rust_type = parts[0];
1321 let width: usize = parts[1].parse().map_err(|_| {
1322 TranspileError::Unsupported("Invalid SharedTile width".into())
1323 })?;
1324 let height: usize = parts[2].parse().map_err(|_| {
1325 TranspileError::Unsupported("Invalid SharedTile height".into())
1326 })?;
1327
1328 let cuda_type = rust_to_cuda_element_type(rust_type);
1329 return Ok(Some(SharedMemoryDecl::tile(
1330 var_name, cuda_type, width, height,
1331 )));
1332 }
1333 }
1334 }
1335 }
1336
1337 if type_str.contains("SharedArray") {
1339 if let Some(start) = type_str.find('<') {
1340 if let Some(end) = type_str.rfind('>') {
1341 let params = &type_str[start + 1..end];
1342 let parts: Vec<&str> = params.split(',').map(|s| s.trim()).collect();
1343
1344 if parts.len() >= 2 {
1345 let rust_type = parts[0];
1346 let size: usize = parts[1].parse().map_err(|_| {
1347 TranspileError::Unsupported("Invalid SharedArray size".into())
1348 })?;
1349
1350 let cuda_type = rust_to_cuda_element_type(rust_type);
1351 return Ok(Some(SharedMemoryDecl::array(var_name, cuda_type, size)));
1352 }
1353 }
1354 }
1355 }
1356
1357 Ok(None)
1358 }
1359
1360 fn try_transpile_shared_method_call(
1362 &self,
1363 receiver: &str,
1364 method_name: &str,
1365 args: &syn::punctuated::Punctuated<Expr, syn::token::Comma>,
1366 ) -> Option<Result<String>> {
1367 let shared_info = self.shared_vars.get(receiver)?;
1368
1369 match method_name {
1370 "get" => {
1371 if shared_info.is_tile {
1373 if args.len() >= 2 {
1374 let x = self.transpile_expr(&args[0]).ok()?;
1375 let y = self.transpile_expr(&args[1]).ok()?;
1376 Some(Ok(format!("{}[{}][{}]", receiver, y, x)))
1378 } else {
1379 Some(Err(TranspileError::Unsupported(
1380 "SharedTile.get requires x and y arguments".into(),
1381 )))
1382 }
1383 } else {
1384 if !args.is_empty() {
1386 let idx = self.transpile_expr(&args[0]).ok()?;
1387 Some(Ok(format!("{}[{}]", receiver, idx)))
1388 } else {
1389 Some(Err(TranspileError::Unsupported(
1390 "SharedArray.get requires index argument".into(),
1391 )))
1392 }
1393 }
1394 }
1395 "set" => {
1396 if shared_info.is_tile {
1398 if args.len() >= 3 {
1399 let x = self.transpile_expr(&args[0]).ok()?;
1400 let y = self.transpile_expr(&args[1]).ok()?;
1401 let val = self.transpile_expr(&args[2]).ok()?;
1402 Some(Ok(format!("{}[{}][{}] = {}", receiver, y, x, val)))
1403 } else {
1404 Some(Err(TranspileError::Unsupported(
1405 "SharedTile.set requires x, y, and value arguments".into(),
1406 )))
1407 }
1408 } else {
1409 if args.len() >= 2 {
1411 let idx = self.transpile_expr(&args[0]).ok()?;
1412 let val = self.transpile_expr(&args[1]).ok()?;
1413 Some(Ok(format!("{}[{}] = {}", receiver, idx, val)))
1414 } else {
1415 Some(Err(TranspileError::Unsupported(
1416 "SharedArray.set requires index and value arguments".into(),
1417 )))
1418 }
1419 }
1420 }
1421 "width" | "height" | "size" => {
1422 match method_name {
1424 "width" if shared_info.is_tile => {
1425 Some(Ok(shared_info.dimensions[1].to_string()))
1426 }
1427 "height" if shared_info.is_tile => {
1428 Some(Ok(shared_info.dimensions[0].to_string()))
1429 }
1430 "size" => {
1431 let total: usize = shared_info.dimensions.iter().product();
1432 Some(Ok(total.to_string()))
1433 }
1434 _ => None,
1435 }
1436 }
1437 _ => None,
1438 }
1439 }
1440
1441 fn transpile_match(&self, match_expr: &ExprMatch) -> Result<String> {
1443 let scrutinee = self.transpile_expr(&match_expr.expr)?;
1444 let mut output = format!("switch ({scrutinee}) {{\n");
1445
1446 for arm in &match_expr.arms {
1447 let case_label = self.transpile_match_pattern(&arm.pat)?;
1449
1450 if case_label == "default" || case_label.starts_with("/*") {
1451 output.push_str(" default: {\n");
1452 } else {
1453 output.push_str(&format!(" case {case_label}: {{\n"));
1454 }
1455
1456 match arm.body.as_ref() {
1458 Expr::Block(block) => {
1459 for stmt in &block.block.stmts {
1461 let stmt_str = self.transpile_stmt_inline(stmt)?;
1462 output.push_str(&format!(" {stmt_str}\n"));
1463 }
1464 }
1465 _ => {
1466 let body = self.transpile_expr(&arm.body)?;
1468 output.push_str(&format!(" {body};\n"));
1469 }
1470 }
1471
1472 output.push_str(" break;\n");
1473 output.push_str(" }\n");
1474 }
1475
1476 output.push_str(" }");
1477 Ok(output)
1478 }
1479
1480 fn transpile_match_pattern(&self, pat: &Pat) -> Result<String> {
1482 match pat {
1483 Pat::Lit(pat_lit) => {
1484 match &pat_lit.lit {
1486 Lit::Int(i) => Ok(i.to_string()),
1487 Lit::Bool(b) => Ok(if b.value { "1" } else { "0" }.to_string()),
1488 _ => Err(TranspileError::Unsupported(
1489 "Non-integer literal in match pattern".into(),
1490 )),
1491 }
1492 }
1493 Pat::Wild(_) => {
1494 Ok("default".to_string())
1496 }
1497 Pat::Ident(ident) => {
1498 Ok(format!("/* {} */ default", ident.ident))
1501 }
1502 Pat::Or(pat_or) => {
1503 if let Some(first) = pat_or.cases.first() {
1507 self.transpile_match_pattern(first)
1508 } else {
1509 Err(TranspileError::Unsupported("Empty or pattern".into()))
1510 }
1511 }
1512 _ => Err(TranspileError::Unsupported(format!(
1513 "Match pattern: {}",
1514 pat.to_token_stream()
1515 ))),
1516 }
1517 }
1518
1519 fn transpile_stmt_inline(&self, stmt: &Stmt) -> Result<String> {
1521 match stmt {
1522 Stmt::Local(local) => {
1523 let var_name = match &local.pat {
1524 Pat::Ident(ident) => ident.ident.to_string(),
1525 Pat::Type(pat_type) => {
1526 if let Pat::Ident(ident) = pat_type.pat.as_ref() {
1527 ident.ident.to_string()
1528 } else {
1529 return Err(TranspileError::Unsupported(
1530 "Complex pattern in let binding".into(),
1531 ));
1532 }
1533 }
1534 _ => {
1535 return Err(TranspileError::Unsupported(
1536 "Complex pattern in let binding".into(),
1537 ))
1538 }
1539 };
1540
1541 if let Some(init) = &local.init {
1542 let expr_str = self.transpile_expr(&init.expr)?;
1543 let type_str = self.infer_cuda_type(&init.expr);
1544 Ok(format!("{type_str} {var_name} = {expr_str};"))
1545 } else {
1546 Ok(format!("float {var_name};"))
1547 }
1548 }
1549 Stmt::Expr(expr, semi) => {
1550 let expr_str = self.transpile_expr(expr)?;
1551 if semi.is_some() {
1552 Ok(format!("{expr_str};"))
1553 } else {
1554 Ok(format!("return {expr_str};"))
1555 }
1556 }
1557 _ => Err(TranspileError::Unsupported(
1558 "Unsupported statement in match arm".into(),
1559 )),
1560 }
1561 }
1562
1563 fn infer_cuda_type(&self, expr: &Expr) -> &'static str {
1565 match expr {
1566 Expr::Lit(lit) => match &lit.lit {
1567 Lit::Float(_) => "float",
1568 Lit::Int(_) => "int",
1569 Lit::Bool(_) => "int",
1570 _ => "float",
1571 },
1572 Expr::Binary(bin) => {
1573 let left_type = self.infer_cuda_type(&bin.left);
1575 let right_type = self.infer_cuda_type(&bin.right);
1576 if left_type == "int" && right_type == "int" {
1578 "int"
1579 } else {
1580 "float"
1581 }
1582 }
1583 Expr::Call(call) => {
1584 if let Ok(func) = self.transpile_expr(&call.func) {
1586 if let Some(intrinsic) = self.intrinsics.lookup(&func) {
1587 let cuda_name = intrinsic.to_cuda_string();
1588 if cuda_name.contains("Idx") || cuda_name.contains("Dim") {
1590 return "int";
1591 }
1592 }
1593 }
1594 "float"
1595 }
1596 Expr::Index(_) => "float", Expr::Cast(cast) => {
1598 if let Ok(cuda_type) = self.type_mapper.map_type(&cast.ty) {
1600 let s = cuda_type.to_cuda_string();
1601 if s.contains("int") || s.contains("size_t") || s == "unsigned long long" {
1602 return "int";
1603 }
1604 }
1605 "float"
1606 }
1607 Expr::Reference(ref_expr) => {
1608 match ref_expr.expr.as_ref() {
1611 Expr::Index(idx_expr) => {
1612 if let Expr::Path(path) = &*idx_expr.expr {
1614 let name = path
1615 .path
1616 .segments
1617 .iter()
1618 .map(|s| s.ident.to_string())
1619 .collect::<Vec<_>>()
1620 .join("::");
1621 if name.contains("transaction") || name.contains("Transaction") {
1623 return "GpuTransaction*";
1624 }
1625 if name.contains("profile") || name.contains("Profile") {
1626 return "GpuCustomerProfile*";
1627 }
1628 if name.contains("alert") || name.contains("Alert") {
1629 return "GpuAlert*";
1630 }
1631 }
1632 "float*" }
1634 _ => "void*",
1635 }
1636 }
1637 Expr::MethodCall(_) => "float",
1638 Expr::Field(field) => {
1639 let member_name = match &field.member {
1641 syn::Member::Named(ident) => ident.to_string(),
1642 syn::Member::Unnamed(idx) => idx.index.to_string(),
1643 };
1644 if member_name.contains("count") || member_name.contains("_count") {
1646 return "unsigned int";
1647 }
1648 if member_name.contains("threshold") || member_name.ends_with("_id") {
1649 return "unsigned long long";
1650 }
1651 if member_name.ends_with("_pct") {
1652 return "unsigned char";
1653 }
1654 "float"
1655 }
1656 Expr::Path(path) => {
1657 let name = path
1659 .path
1660 .segments
1661 .iter()
1662 .map(|s| s.ident.to_string())
1663 .collect::<Vec<_>>()
1664 .join("::");
1665 if name.contains("threshold")
1666 || name.contains("count")
1667 || name == "idx"
1668 || name == "n"
1669 {
1670 return "int";
1671 }
1672 "float"
1673 }
1674 Expr::If(if_expr) => {
1675 if let Some((_, else_branch)) = &if_expr.else_branch {
1677 if let Expr::Block(block) = else_branch.as_ref() {
1678 if let Some(Stmt::Expr(expr, None)) = block.block.stmts.last() {
1679 return self.infer_cuda_type(expr);
1680 }
1681 }
1682 }
1683 if let Some(Stmt::Expr(expr, None)) = if_expr.then_branch.stmts.last() {
1685 return self.infer_cuda_type(expr);
1686 }
1687 "float"
1688 }
1689 _ => "float",
1690 }
1691 }
1692}
1693
1694pub fn transpile_function(func: &ItemFn) -> Result<String> {
1696 let mut transpiler = CudaTranspiler::new_generic();
1697
1698 let name = func.sig.ident.to_string();
1700
1701 let mut params = Vec::new();
1702 for param in &func.sig.inputs {
1703 if let FnArg::Typed(pat_type) = param {
1704 let param_name = match pat_type.pat.as_ref() {
1705 Pat::Ident(ident) => ident.ident.to_string(),
1706 _ => continue,
1707 };
1708
1709 let cuda_type = transpiler.type_mapper.map_type(&pat_type.ty)?;
1710 params.push(format!("{} {}", cuda_type.to_cuda_string(), param_name));
1711 }
1712 }
1713
1714 let return_type = match &func.sig.output {
1716 ReturnType::Default => "void".to_string(),
1717 ReturnType::Type(_, ty) => transpiler.type_mapper.map_type(ty)?.to_cuda_string(),
1718 };
1719
1720 let body = transpiler.transpile_block(&func.block)?;
1722
1723 Ok(format!(
1724 "__device__ {return_type} {name}({params}) {{\n{body}}}\n",
1725 params = params.join(", ")
1726 ))
1727}
1728
1729#[cfg(test)]
1730mod tests {
1731 use super::*;
1732 use syn::parse_quote;
1733
1734 #[test]
1735 fn test_simple_arithmetic() {
1736 let transpiler = CudaTranspiler::new_generic();
1737
1738 let expr: Expr = parse_quote!(a + b * 2.0);
1739 let result = transpiler.transpile_expr(&expr).unwrap();
1740 assert_eq!(result, "a + b * 2.0f");
1741 }
1742
1743 #[test]
1744 fn test_let_binding() {
1745 let mut transpiler = CudaTranspiler::new_generic();
1746
1747 let stmt: Stmt = parse_quote!(let x = a + b;);
1748 let result = transpiler.transpile_stmt(&stmt).unwrap();
1749 assert!(result.contains("float x = a + b;"));
1750 }
1751
1752 #[test]
1753 fn test_array_index() {
1754 let transpiler = CudaTranspiler::new_generic();
1755
1756 let expr: Expr = parse_quote!(data[idx]);
1757 let result = transpiler.transpile_expr(&expr).unwrap();
1758 assert_eq!(result, "data[idx]");
1759 }
1760
1761 #[test]
1762 fn test_stencil_intrinsics() {
1763 let config = StencilConfig::new("test")
1764 .with_tile_size(16, 16)
1765 .with_halo(1);
1766 let mut transpiler = CudaTranspiler::new(config);
1767 transpiler.grid_pos_vars.push("pos".to_string());
1768
1769 let expr: Expr = parse_quote!(pos.idx());
1771 let result = transpiler.transpile_expr(&expr).unwrap();
1772 assert_eq!(result, "idx");
1773
1774 let expr: Expr = parse_quote!(pos.north(p));
1776 let result = transpiler.transpile_expr(&expr).unwrap();
1777 assert_eq!(result, "p[idx - 18]");
1778
1779 let expr: Expr = parse_quote!(pos.east(p));
1781 let result = transpiler.transpile_expr(&expr).unwrap();
1782 assert_eq!(result, "p[idx + 1]");
1783 }
1784
1785 #[test]
1786 fn test_ternary_if() {
1787 let transpiler = CudaTranspiler::new_generic();
1788
1789 let expr: Expr = parse_quote!(if x > 0.0 { x } else { -x });
1790 let result = transpiler.transpile_expr(&expr).unwrap();
1791 assert!(result.contains("?"));
1792 assert!(result.contains(":"));
1793 }
1794
1795 #[test]
1796 fn test_full_stencil_kernel() {
1797 let func: ItemFn = parse_quote! {
1798 fn fdtd(p: &[f32], p_prev: &mut [f32], c2: f32, pos: GridPos) {
1799 let curr = p[pos.idx()];
1800 let prev = p_prev[pos.idx()];
1801 let lap = pos.north(p) + pos.south(p) + pos.east(p) + pos.west(p) - 4.0 * curr;
1802 p_prev[pos.idx()] = (2.0 * curr - prev + c2 * lap);
1803 }
1804 };
1805
1806 let config = StencilConfig::new("fdtd")
1807 .with_tile_size(16, 16)
1808 .with_halo(1);
1809
1810 let mut transpiler = CudaTranspiler::new(config);
1811 let cuda = transpiler.transpile_stencil(&func).unwrap();
1812
1813 assert!(cuda.contains("extern \"C\" __global__"));
1815 assert!(cuda.contains("threadIdx.x"));
1816 assert!(cuda.contains("threadIdx.y"));
1817 assert!(cuda.contains("buffer_width = 18"));
1818 assert!(cuda.contains("const float* __restrict__ p"));
1819 assert!(cuda.contains("float* __restrict__ p_prev"));
1820 assert!(!cuda.contains("GridPos")); println!("Generated CUDA:\n{}", cuda);
1823 }
1824
1825 #[test]
1826 fn test_early_return() {
1827 let mut transpiler = CudaTranspiler::new_generic();
1828
1829 let stmt: Stmt = parse_quote!(return;);
1830 let result = transpiler.transpile_stmt(&stmt).unwrap();
1831 assert!(result.contains("return;"));
1832
1833 let stmt_val: Stmt = parse_quote!(return 42;);
1834 let result_val = transpiler.transpile_stmt(&stmt_val).unwrap();
1835 assert!(result_val.contains("return 42;"));
1836 }
1837
1838 #[test]
1839 fn test_match_to_switch() {
1840 let transpiler = CudaTranspiler::new_generic();
1841
1842 let expr: Expr = parse_quote! {
1843 match edge {
1844 0 => { idx = 1 * 18 + i; }
1845 1 => { idx = 16 * 18 + i; }
1846 _ => { idx = 0; }
1847 }
1848 };
1849
1850 let result = transpiler.transpile_expr(&expr).unwrap();
1851 assert!(
1852 result.contains("switch (edge)"),
1853 "Should generate switch: {}",
1854 result
1855 );
1856 assert!(result.contains("case 0:"), "Should have case 0: {}", result);
1857 assert!(result.contains("case 1:"), "Should have case 1: {}", result);
1858 assert!(
1859 result.contains("default:"),
1860 "Should have default: {}",
1861 result
1862 );
1863 assert!(result.contains("break;"), "Should have break: {}", result);
1864
1865 println!("Generated switch:\n{}", result);
1866 }
1867
1868 #[test]
1869 fn test_block_idx_intrinsics() {
1870 let transpiler = CudaTranspiler::new_generic();
1871
1872 let expr: Expr = parse_quote!(block_idx_x());
1874 let result = transpiler.transpile_expr(&expr).unwrap();
1875 assert_eq!(result, "blockIdx.x");
1876
1877 let expr2: Expr = parse_quote!(thread_idx_y());
1879 let result2 = transpiler.transpile_expr(&expr2).unwrap();
1880 assert_eq!(result2, "threadIdx.y");
1881
1882 let expr3: Expr = parse_quote!(grid_dim_x());
1884 let result3 = transpiler.transpile_expr(&expr3).unwrap();
1885 assert_eq!(result3, "gridDim.x");
1886 }
1887
1888 #[test]
1889 fn test_global_index_calculation() {
1890 let transpiler = CudaTranspiler::new_generic();
1891
1892 let expr: Expr = parse_quote!(block_idx_x() * block_dim_x() + thread_idx_x());
1894 let result = transpiler.transpile_expr(&expr).unwrap();
1895 assert!(result.contains("blockIdx.x"), "Should contain blockIdx.x");
1896 assert!(result.contains("blockDim.x"), "Should contain blockDim.x");
1897 assert!(result.contains("threadIdx.x"), "Should contain threadIdx.x");
1898
1899 println!("Global index expression: {}", result);
1900 }
1901
1902 #[test]
1905 fn test_for_loop_transpile() {
1906 let transpiler = CudaTranspiler::new_generic();
1907
1908 let expr: Expr = parse_quote! {
1909 for i in 0..n {
1910 data[i] = 0.0;
1911 }
1912 };
1913
1914 let result = transpiler.transpile_expr(&expr).unwrap();
1915 assert!(
1916 result.contains("for (int i = 0; i < n; i++)"),
1917 "Should generate for loop header: {}",
1918 result
1919 );
1920 assert!(
1921 result.contains("data[i] = 0.0f"),
1922 "Should contain loop body: {}",
1923 result
1924 );
1925
1926 println!("Generated for loop:\n{}", result);
1927 }
1928
1929 #[test]
1930 fn test_for_loop_inclusive_range() {
1931 let transpiler = CudaTranspiler::new_generic();
1932
1933 let expr: Expr = parse_quote! {
1934 for i in 1..=10 {
1935 sum += i;
1936 }
1937 };
1938
1939 let result = transpiler.transpile_expr(&expr).unwrap();
1940 assert!(
1941 result.contains("for (int i = 1; i <= 10; i++)"),
1942 "Should generate inclusive range: {}",
1943 result
1944 );
1945
1946 println!("Generated inclusive for loop:\n{}", result);
1947 }
1948
1949 #[test]
1950 fn test_while_loop_transpile() {
1951 let transpiler = CudaTranspiler::new_generic();
1952
1953 let expr: Expr = parse_quote! {
1954 while i < 10 {
1955 i += 1;
1956 }
1957 };
1958
1959 let result = transpiler.transpile_expr(&expr).unwrap();
1960 assert!(
1961 result.contains("while (i < 10)"),
1962 "Should generate while loop: {}",
1963 result
1964 );
1965 assert!(
1966 result.contains("i += 1"),
1967 "Should contain loop body: {}",
1968 result
1969 );
1970
1971 println!("Generated while loop:\n{}", result);
1972 }
1973
1974 #[test]
1975 fn test_while_loop_negation() {
1976 let transpiler = CudaTranspiler::new_generic();
1977
1978 let expr: Expr = parse_quote! {
1979 while !done {
1980 process();
1981 }
1982 };
1983
1984 let result = transpiler.transpile_expr(&expr).unwrap();
1985 assert!(
1986 result.contains("while (!(done))"),
1987 "Should negate condition: {}",
1988 result
1989 );
1990
1991 println!("Generated while loop with negation:\n{}", result);
1992 }
1993
1994 #[test]
1995 fn test_infinite_loop_transpile() {
1996 let transpiler = CudaTranspiler::new_generic();
1997
1998 let expr: Expr = parse_quote! {
1999 loop {
2000 process();
2001 }
2002 };
2003
2004 let result = transpiler.transpile_expr(&expr).unwrap();
2005 assert!(
2006 result.contains("while (true)"),
2007 "Should generate infinite loop: {}",
2008 result
2009 );
2010 assert!(
2011 result.contains("process()"),
2012 "Should contain loop body: {}",
2013 result
2014 );
2015
2016 println!("Generated infinite loop:\n{}", result);
2017 }
2018
2019 #[test]
2020 fn test_break_transpile() {
2021 let transpiler = CudaTranspiler::new_generic();
2022
2023 let expr: Expr = parse_quote!(break);
2024 let result = transpiler.transpile_expr(&expr).unwrap();
2025 assert_eq!(result, "break");
2026 }
2027
2028 #[test]
2029 fn test_continue_transpile() {
2030 let transpiler = CudaTranspiler::new_generic();
2031
2032 let expr: Expr = parse_quote!(continue);
2033 let result = transpiler.transpile_expr(&expr).unwrap();
2034 assert_eq!(result, "continue");
2035 }
2036
2037 #[test]
2038 fn test_loop_with_break() {
2039 let transpiler = CudaTranspiler::new_generic();
2040
2041 let expr: Expr = parse_quote! {
2042 loop {
2043 if done {
2044 break;
2045 }
2046 }
2047 };
2048
2049 let result = transpiler.transpile_expr(&expr).unwrap();
2050 assert!(
2051 result.contains("while (true)"),
2052 "Should generate infinite loop: {}",
2053 result
2054 );
2055 assert!(result.contains("break"), "Should contain break: {}", result);
2056
2057 println!("Generated loop with break:\n{}", result);
2058 }
2059
2060 #[test]
2061 fn test_nested_loops() {
2062 let transpiler = CudaTranspiler::new_generic();
2063
2064 let expr: Expr = parse_quote! {
2065 for i in 0..m {
2066 for j in 0..n {
2067 matrix[i * n + j] = 0.0;
2068 }
2069 }
2070 };
2071
2072 let result = transpiler.transpile_expr(&expr).unwrap();
2073 assert!(
2074 result.contains("for (int i = 0; i < m; i++)"),
2075 "Should have outer loop: {}",
2076 result
2077 );
2078 assert!(
2079 result.contains("for (int j = 0; j < n; j++)"),
2080 "Should have inner loop: {}",
2081 result
2082 );
2083
2084 println!("Generated nested loops:\n{}", result);
2085 }
2086
2087 #[test]
2088 fn test_stencil_mode_rejects_loops() {
2089 let config = StencilConfig::new("test")
2090 .with_tile_size(16, 16)
2091 .with_halo(1);
2092 let transpiler = CudaTranspiler::new(config);
2093
2094 let expr: Expr = parse_quote! {
2095 for i in 0..n {
2096 data[i] = 0.0;
2097 }
2098 };
2099
2100 let result = transpiler.transpile_expr(&expr);
2101 assert!(result.is_err(), "Stencil mode should reject loops");
2102 }
2103
2104 #[test]
2105 fn test_labeled_break_rejected() {
2106 let transpiler = CudaTranspiler::new_generic();
2107
2108 let break_expr = syn::ExprBreak {
2111 attrs: Vec::new(),
2112 break_token: syn::token::Break::default(),
2113 label: Some(syn::Lifetime::new("'outer", proc_macro2::Span::call_site())),
2114 expr: None,
2115 };
2116
2117 let result = transpiler.transpile_break(&break_expr);
2118 assert!(result.is_err(), "Labeled break should be rejected");
2119 }
2120
2121 #[test]
2122 fn test_full_kernel_with_loop() {
2123 let func: ItemFn = parse_quote! {
2124 fn fill_array(data: &mut [f32], n: i32) {
2125 for i in 0..n {
2126 data[i as usize] = 0.0;
2127 }
2128 }
2129 };
2130
2131 let mut transpiler = CudaTranspiler::new_generic();
2132 let cuda = transpiler.transpile_generic_kernel(&func).unwrap();
2133
2134 assert!(
2135 cuda.contains("extern \"C\" __global__"),
2136 "Should be global kernel: {}",
2137 cuda
2138 );
2139 assert!(
2140 cuda.contains("for (int i = 0; i < n; i++)"),
2141 "Should have for loop: {}",
2142 cuda
2143 );
2144
2145 println!("Generated kernel with loop:\n{}", cuda);
2146 }
2147
2148 #[test]
2149 fn test_persistent_kernel_pattern() {
2150 let transpiler = CudaTranspiler::with_mode(ValidationMode::RingKernel);
2152
2153 let expr: Expr = parse_quote! {
2154 while !should_terminate {
2155 if has_message {
2156 process_message();
2157 }
2158 }
2159 };
2160
2161 let result = transpiler.transpile_expr(&expr).unwrap();
2162 assert!(
2163 result.contains("while (!(should_terminate))"),
2164 "Should have persistent loop: {}",
2165 result
2166 );
2167 assert!(
2168 result.contains("if (has_message)"),
2169 "Should have message check: {}",
2170 result
2171 );
2172
2173 println!("Generated persistent kernel pattern:\n{}", result);
2174 }
2175
2176 #[test]
2179 fn test_shared_tile_declaration() {
2180 use crate::shared::{SharedMemoryConfig, SharedMemoryDecl};
2181
2182 let decl = SharedMemoryDecl::tile("tile", "float", 16, 16);
2183 assert_eq!(decl.to_cuda_decl(), "__shared__ float tile[16][16];");
2184
2185 let mut config = SharedMemoryConfig::new();
2186 config.add_tile("tile", "float", 16, 16);
2187 assert_eq!(config.total_bytes(), 16 * 16 * 4); let decls = config.generate_declarations(" ");
2190 assert!(decls.contains("__shared__ float tile[16][16];"));
2191 }
2192
2193 #[test]
2194 fn test_shared_array_declaration() {
2195 use crate::shared::{SharedMemoryConfig, SharedMemoryDecl};
2196
2197 let decl = SharedMemoryDecl::array("buffer", "float", 256);
2198 assert_eq!(decl.to_cuda_decl(), "__shared__ float buffer[256];");
2199
2200 let mut config = SharedMemoryConfig::new();
2201 config.add_array("buffer", "float", 256);
2202 assert_eq!(config.total_bytes(), 256 * 4); }
2204
2205 #[test]
2206 fn test_shared_memory_access_expressions() {
2207 use crate::shared::SharedMemoryDecl;
2208
2209 let tile = SharedMemoryDecl::tile("tile", "float", 16, 16);
2210 assert_eq!(
2211 tile.to_cuda_access(&["y".to_string(), "x".to_string()]),
2212 "tile[y][x]"
2213 );
2214
2215 let arr = SharedMemoryDecl::array("buf", "int", 128);
2216 assert_eq!(arr.to_cuda_access(&["i".to_string()]), "buf[i]");
2217 }
2218
2219 #[test]
2220 fn test_parse_shared_tile_type() {
2221 use crate::shared::parse_shared_tile_type;
2222
2223 let result = parse_shared_tile_type("SharedTile::<f32, 16, 16>");
2224 assert_eq!(result, Some(("f32".to_string(), 16, 16)));
2225
2226 let result2 = parse_shared_tile_type("SharedTile<i32, 32, 8>");
2227 assert_eq!(result2, Some(("i32".to_string(), 32, 8)));
2228
2229 let invalid = parse_shared_tile_type("Vec<f32>");
2230 assert_eq!(invalid, None);
2231 }
2232
2233 #[test]
2234 fn test_parse_shared_array_type() {
2235 use crate::shared::parse_shared_array_type;
2236
2237 let result = parse_shared_array_type("SharedArray::<f32, 256>");
2238 assert_eq!(result, Some(("f32".to_string(), 256)));
2239
2240 let result2 = parse_shared_array_type("SharedArray<u32, 1024>");
2241 assert_eq!(result2, Some(("u32".to_string(), 1024)));
2242
2243 let invalid = parse_shared_array_type("Vec<f32>");
2244 assert_eq!(invalid, None);
2245 }
2246
2247 #[test]
2248 fn test_rust_to_cuda_element_types() {
2249 use crate::shared::rust_to_cuda_element_type;
2250
2251 assert_eq!(rust_to_cuda_element_type("f32"), "float");
2252 assert_eq!(rust_to_cuda_element_type("f64"), "double");
2253 assert_eq!(rust_to_cuda_element_type("i32"), "int");
2254 assert_eq!(rust_to_cuda_element_type("u32"), "unsigned int");
2255 assert_eq!(rust_to_cuda_element_type("i64"), "long long");
2256 assert_eq!(rust_to_cuda_element_type("u64"), "unsigned long long");
2257 assert_eq!(rust_to_cuda_element_type("bool"), "int");
2258 }
2259
2260 #[test]
2261 fn test_shared_memory_total_bytes() {
2262 use crate::shared::SharedMemoryConfig;
2263
2264 let mut config = SharedMemoryConfig::new();
2265 config.add_tile("tile1", "float", 16, 16); config.add_tile("tile2", "double", 8, 8); config.add_array("temp", "int", 64); assert_eq!(config.total_bytes(), 1024 + 512 + 256);
2270 }
2271
2272 #[test]
2273 fn test_transpiler_shared_var_tracking() {
2274 let mut transpiler = CudaTranspiler::new_generic();
2275
2276 transpiler.shared_vars.insert(
2278 "tile".to_string(),
2279 SharedVarInfo {
2280 name: "tile".to_string(),
2281 is_tile: true,
2282 dimensions: vec![16, 16],
2283 element_type: "float".to_string(),
2284 },
2285 );
2286
2287 assert!(transpiler.shared_vars.contains_key("tile"));
2289 assert!(transpiler.shared_vars.get("tile").unwrap().is_tile);
2290 }
2291
2292 #[test]
2293 fn test_shared_tile_get_transpilation() {
2294 let mut transpiler = CudaTranspiler::new_generic();
2295
2296 transpiler.shared_vars.insert(
2298 "tile".to_string(),
2299 SharedVarInfo {
2300 name: "tile".to_string(),
2301 is_tile: true,
2302 dimensions: vec![16, 16],
2303 element_type: "float".to_string(),
2304 },
2305 );
2306
2307 let result = transpiler.try_transpile_shared_method_call(
2309 "tile",
2310 "get",
2311 &syn::punctuated::Punctuated::new(),
2312 );
2313
2314 assert!(result.is_none() || result.unwrap().is_err());
2316 }
2317
2318 #[test]
2319 fn test_shared_array_access() {
2320 let mut transpiler = CudaTranspiler::new_generic();
2321
2322 transpiler.shared_vars.insert(
2324 "buffer".to_string(),
2325 SharedVarInfo {
2326 name: "buffer".to_string(),
2327 is_tile: false,
2328 dimensions: vec![256],
2329 element_type: "float".to_string(),
2330 },
2331 );
2332
2333 assert!(!transpiler.shared_vars.get("buffer").unwrap().is_tile);
2334 assert_eq!(
2335 transpiler.shared_vars.get("buffer").unwrap().dimensions,
2336 vec![256]
2337 );
2338 }
2339
2340 #[test]
2341 fn test_full_kernel_with_shared_memory() {
2342 use crate::shared::SharedMemoryConfig;
2344
2345 let mut config = SharedMemoryConfig::new();
2346 config.add_tile("smem", "float", 16, 16);
2347
2348 let decls = config.generate_declarations(" ");
2349 assert!(decls.contains("__shared__ float smem[16][16];"));
2350 assert!(!config.is_empty());
2351 }
2352
2353 #[test]
2356 fn test_struct_literal_transpile() {
2357 let transpiler = CudaTranspiler::new_generic();
2358
2359 let expr: Expr = parse_quote! {
2360 Point { x: 1.0, y: 2.0 }
2361 };
2362
2363 let result = transpiler.transpile_expr(&expr).unwrap();
2364 assert!(
2365 result.contains("Point"),
2366 "Should contain struct name: {}",
2367 result
2368 );
2369 assert!(result.contains(".x ="), "Should have field x: {}", result);
2370 assert!(result.contains(".y ="), "Should have field y: {}", result);
2371 assert!(
2372 result.contains("1.0f"),
2373 "Should have value 1.0f: {}",
2374 result
2375 );
2376 assert!(
2377 result.contains("2.0f"),
2378 "Should have value 2.0f: {}",
2379 result
2380 );
2381
2382 println!("Generated struct literal: {}", result);
2383 }
2384
2385 #[test]
2386 fn test_struct_literal_with_expressions() {
2387 let transpiler = CudaTranspiler::new_generic();
2388
2389 let expr: Expr = parse_quote! {
2390 Response { value: x * 2.0, id: idx as u64 }
2391 };
2392
2393 let result = transpiler.transpile_expr(&expr).unwrap();
2394 assert!(
2395 result.contains("Response"),
2396 "Should contain struct name: {}",
2397 result
2398 );
2399 assert!(
2400 result.contains(".value = x * 2.0f"),
2401 "Should have computed value: {}",
2402 result
2403 );
2404 assert!(result.contains(".id ="), "Should have id field: {}", result);
2405
2406 println!("Generated struct with expressions: {}", result);
2407 }
2408
2409 #[test]
2410 fn test_struct_literal_in_return() {
2411 let mut transpiler = CudaTranspiler::new_generic();
2412
2413 let stmt: Stmt = parse_quote! {
2414 return MyStruct { a: 1, b: 2.0 };
2415 };
2416
2417 let result = transpiler.transpile_stmt(&stmt).unwrap();
2418 assert!(result.contains("return"), "Should have return: {}", result);
2419 assert!(
2420 result.contains("MyStruct"),
2421 "Should contain struct name: {}",
2422 result
2423 );
2424
2425 println!("Generated return with struct: {}", result);
2426 }
2427
2428 #[test]
2429 fn test_struct_literal_compound_literal_format() {
2430 let transpiler = CudaTranspiler::new_generic();
2431
2432 let expr: Expr = parse_quote! {
2433 Vec3 { x: a, y: b, z: c }
2434 };
2435
2436 let result = transpiler.transpile_expr(&expr).unwrap();
2437 assert!(
2439 result.starts_with("(Vec3){"),
2440 "Should use compound literal format: {}",
2441 result
2442 );
2443 assert!(
2444 result.ends_with("}"),
2445 "Should end with closing brace: {}",
2446 result
2447 );
2448
2449 println!("Generated compound literal: {}", result);
2450 }
2451
2452 #[test]
2455 fn test_reference_to_array_element() {
2456 let transpiler = CudaTranspiler::new_generic();
2457
2458 let expr: Expr = parse_quote! {
2459 &arr[idx]
2460 };
2461
2462 let result = transpiler.transpile_expr(&expr).unwrap();
2463 assert_eq!(
2464 result, "&arr[idx]",
2465 "Should produce address-of array element"
2466 );
2467 }
2468
2469 #[test]
2470 fn test_mutable_reference_to_array_element() {
2471 let transpiler = CudaTranspiler::new_generic();
2472
2473 let expr: Expr = parse_quote! {
2474 &mut arr[idx * 4 + offset]
2475 };
2476
2477 let result = transpiler.transpile_expr(&expr).unwrap();
2478 assert!(
2479 result.contains("&arr["),
2480 "Should produce address-of: {}",
2481 result
2482 );
2483 assert!(
2484 result.contains("idx * 4"),
2485 "Should have index expression: {}",
2486 result
2487 );
2488 }
2489
2490 #[test]
2491 fn test_reference_to_variable() {
2492 let transpiler = CudaTranspiler::new_generic();
2493
2494 let expr: Expr = parse_quote! {
2495 &value
2496 };
2497
2498 let result = transpiler.transpile_expr(&expr).unwrap();
2499 assert_eq!(result, "&value", "Should produce address-of variable");
2500 }
2501
2502 #[test]
2503 fn test_reference_to_struct_field() {
2504 let transpiler = CudaTranspiler::new_generic();
2505
2506 let expr: Expr = parse_quote! {
2507 &alerts[(idx as usize) * 4 + alert_idx as usize]
2508 };
2509
2510 let result = transpiler.transpile_expr(&expr).unwrap();
2511 assert!(
2512 result.starts_with("&alerts["),
2513 "Should have address-of array: {}",
2514 result
2515 );
2516
2517 println!("Generated reference: {}", result);
2518 }
2519
2520 #[test]
2521 fn test_complex_reference_pattern() {
2522 let mut transpiler = CudaTranspiler::new_generic();
2523
2524 let stmt: Stmt = parse_quote! {
2526 let alert = &mut alerts[(idx as usize) * 4 + alert_idx as usize];
2527 };
2528
2529 let result = transpiler.transpile_stmt(&stmt).unwrap();
2530 assert!(
2531 result.contains("alert ="),
2532 "Should have variable assignment: {}",
2533 result
2534 );
2535 assert!(
2536 result.contains("&alerts["),
2537 "Should have reference to array: {}",
2538 result
2539 );
2540
2541 println!("Generated statement: {}", result);
2542 }
2543}