1use crate::handler::{ContextMethod, HandlerSignature};
6use crate::intrinsics::{IntrinsicRegistry, RingKernelIntrinsic, StencilIntrinsic};
7use crate::loops::{extract_loop_var, RangeInfo};
8use crate::shared::{rust_to_cuda_element_type, SharedMemoryConfig, SharedMemoryDecl};
9use crate::stencil::StencilConfig;
10use crate::types::{is_grid_pos_type, is_ring_context_type, TypeMapper};
11use crate::validation::ValidationMode;
12use crate::{Result, TranspileError};
13use quote::ToTokens;
14use syn::{
15 BinOp, Expr, ExprAssign, ExprBinary, ExprBreak, ExprCall, ExprCast, ExprContinue, ExprForLoop,
16 ExprIf, ExprIndex, ExprLet, ExprLit, ExprLoop, ExprMatch, ExprMethodCall, ExprParen, ExprPath,
17 ExprReference, ExprReturn, ExprStruct, ExprUnary, ExprWhile, FnArg, ItemFn, Lit, Pat,
18 ReturnType, Stmt, UnOp,
19};
20
/// Translates Rust kernel functions (parsed with `syn`) into CUDA C source text.
///
/// One instance carries the per-kernel state gathered while walking the function:
/// which parameters are grid positions or ring contexts, which locals are
/// pointers, and which shared-memory buffers were declared.
pub struct CudaTranspiler {
    /// Stencil configuration; `None` for generic and ring kernels.
    config: Option<StencilConfig>,
    /// Maps Rust types to CUDA types (used for signatures and `as` casts).
    type_mapper: TypeMapper,
    /// Registry of function/method intrinsics, looked up by name.
    intrinsics: IntrinsicRegistry,
    /// Names of grid-position parameters; method calls on them become stencil intrinsics.
    grid_pos_vars: Vec<String>,
    /// Names of ring-context parameters; method calls on them become context methods.
    context_vars: Vec<String>,
    /// Current indentation level (number of indent units emitted by `indent_str`).
    indent: usize,
    /// Validation rules in force (e.g. whether loops are allowed).
    validation_mode: ValidationMode,
    /// Shared-memory declarations collected from the kernel body.
    shared_memory: SharedMemoryConfig,
    /// Shared-memory variables declared in the body, keyed by variable name.
    pub shared_vars: std::collections::HashMap<String, SharedVarInfo>,
    /// Whether ring-kernel context methods and intrinsics are recognized.
    ring_kernel_mode: bool,
    /// Locals whose CUDA type is a pointer, so field access on them uses `->`.
    pointer_vars: std::collections::HashSet<String>,
}
46
/// Bookkeeping for one shared-memory variable declared in a kernel body.
#[derive(Debug, Clone)]
pub struct SharedVarInfo {
    /// Variable name as written in the Rust source.
    pub name: String,
    /// `true` when the declaration has exactly two dimensions (a tile).
    pub is_tile: bool,
    /// Extent of each dimension, copied from the parsed `SharedMemoryDecl`.
    pub dimensions: Vec<usize>,
    /// Element type string copied from the parsed `SharedMemoryDecl`
    /// (NOTE(review): confirm whether this is the Rust or the CUDA spelling).
    pub element_type: String,
}
59
60impl CudaTranspiler {
61 pub fn new(config: StencilConfig) -> Self {
63 Self {
64 config: Some(config),
65 type_mapper: TypeMapper::new(),
66 intrinsics: IntrinsicRegistry::new(),
67 grid_pos_vars: Vec::new(),
68 context_vars: Vec::new(),
69 indent: 1, validation_mode: ValidationMode::Stencil,
71 shared_memory: SharedMemoryConfig::new(),
72 shared_vars: std::collections::HashMap::new(),
73 ring_kernel_mode: false,
74 pointer_vars: std::collections::HashSet::new(),
75 }
76 }
77
78 pub fn new_generic() -> Self {
80 Self {
81 config: None,
82 type_mapper: TypeMapper::new(),
83 intrinsics: IntrinsicRegistry::new(),
84 grid_pos_vars: Vec::new(),
85 context_vars: Vec::new(),
86 indent: 1,
87 validation_mode: ValidationMode::Generic,
88 shared_memory: SharedMemoryConfig::new(),
89 shared_vars: std::collections::HashMap::new(),
90 ring_kernel_mode: false,
91 pointer_vars: std::collections::HashSet::new(),
92 }
93 }
94
95 pub fn with_mode(mode: ValidationMode) -> Self {
97 Self {
98 config: None,
99 type_mapper: TypeMapper::new(),
100 intrinsics: IntrinsicRegistry::new(),
101 grid_pos_vars: Vec::new(),
102 context_vars: Vec::new(),
103 indent: 1,
104 validation_mode: mode,
105 shared_memory: SharedMemoryConfig::new(),
106 shared_vars: std::collections::HashMap::new(),
107 ring_kernel_mode: false,
108 pointer_vars: std::collections::HashSet::new(),
109 }
110 }
111
112 pub fn for_ring_kernel() -> Self {
114 Self {
115 config: None,
116 type_mapper: crate::types::ring_kernel_type_mapper(),
117 intrinsics: IntrinsicRegistry::new(),
118 grid_pos_vars: Vec::new(),
119 context_vars: Vec::new(),
120 indent: 2, validation_mode: ValidationMode::Generic,
122 shared_memory: SharedMemoryConfig::new(),
123 shared_vars: std::collections::HashMap::new(),
124 ring_kernel_mode: true,
125 pointer_vars: std::collections::HashSet::new(),
126 }
127 }
128
129 pub fn set_validation_mode(&mut self, mode: ValidationMode) {
131 self.validation_mode = mode;
132 }
133
134 pub fn shared_memory(&self) -> &SharedMemoryConfig {
136 &self.shared_memory
137 }
138
139 fn indent_str(&self) -> String {
141 " ".repeat(self.indent)
142 }
143
144 pub fn transpile_stencil(&mut self, func: &ItemFn) -> Result<String> {
146 let config = self
147 .config
148 .as_ref()
149 .ok_or_else(|| TranspileError::Unsupported("No stencil config provided".into()))?
150 .clone();
151
152 for param in &func.sig.inputs {
154 if let FnArg::Typed(pat_type) = param {
155 if is_grid_pos_type(&pat_type.ty) {
156 if let Pat::Ident(ident) = pat_type.pat.as_ref() {
157 self.grid_pos_vars.push(ident.ident.to_string());
158 }
159 }
160 }
161 }
162
163 let signature = self.transpile_kernel_signature(func)?;
165
166 let preamble = config.generate_preamble();
168
169 let body = self.transpile_block(&func.block)?;
171
172 Ok(format!(
173 "extern \"C\" __global__ void {signature} {{\n{preamble}\n{body}}}\n"
174 ))
175 }
176
177 pub fn transpile_generic_kernel(&mut self, func: &ItemFn) -> Result<String> {
183 let signature = self.transpile_generic_kernel_signature(func)?;
185
186 let body = self.transpile_block(&func.block)?;
188
189 Ok(format!(
190 "extern \"C\" __global__ void {signature} {{\n{body}}}\n"
191 ))
192 }
193
    /// Transpiles a ring-kernel message handler into a complete persistent CUDA kernel.
    ///
    /// The output is assembled in this order:
    /// 1. Supporting struct definitions: the control block always; HLC, envelope
    ///    and K2K structs when the matching `config` flag is set.
    /// 2. Informational comments naming the message and response types.
    /// 3. The kernel signature, preamble and loop header produced by `config`.
    /// 4. A cast of `msg_ptr` to the handler's message type.
    /// 5. The translated handler body, re-indented.
    /// 6. Response serialization, when the handler has a return type.
    /// 7. The message-complete, loop-footer and epilogue code from `config`.
    ///
    /// Side effects on `self`: ring-context parameter names are appended to
    /// `context_vars`, `ring_kernel_mode` is set to `true`, and `indent` is
    /// left at 2.
    ///
    /// # Errors
    /// Propagates errors from `HandlerSignature::parse` and from body translation.
    pub fn transpile_ring_kernel(
        &mut self,
        handler: &ItemFn,
        config: &crate::ring_kernel::RingKernelConfig,
    ) -> Result<String> {
        use std::fmt::Write;

        let handler_sig = HandlerSignature::parse(handler, &self.type_mapper)?;

        // Remember ring-context parameters so `ctx.method(..)` calls are lowered
        // through `transpile_context_method`.
        for param in &handler.sig.inputs {
            if let FnArg::Typed(pat_type) = param {
                if is_ring_context_type(&pat_type.ty) {
                    if let Pat::Ident(ident) = pat_type.pat.as_ref() {
                        self.context_vars.push(ident.ident.to_string());
                    }
                }
            }
        }

        self.ring_kernel_mode = true;

        let mut output = String::new();

        output.push_str(&crate::ring_kernel::generate_control_block_struct());
        output.push('\n');

        if config.enable_hlc {
            output.push_str(&crate::ring_kernel::generate_hlc_struct());
            output.push('\n');
        }

        if config.use_envelope_format {
            output.push_str(&crate::ring_kernel::generate_message_envelope_structs());
            output.push('\n');
        }

        if config.enable_k2k {
            output.push_str(&crate::ring_kernel::generate_k2k_structs());
            output.push('\n');
        }

        // Informational comment only. The Rust type string is stripped of a
        // leading `&` and `mut `. NOTE(review): a form like "& mut T" (with a
        // space after `&`) would keep "mut"; confirm how HandlerSignature
        // renders `rust_type`.
        if let Some(ref msg_param) = handler_sig.message_param {
            let type_name = msg_param
                .rust_type
                .trim_start_matches('&')
                .trim_start_matches("mut ")
                .trim();
            if !type_name.is_empty() && type_name != "f32" && type_name != "i32" {
                writeln!(output, "// Message type: {}", type_name).unwrap();
            }
        }

        if let Some(ref ret_type) = handler_sig.return_type {
            if ret_type.is_struct {
                writeln!(output, "// Response type: {}", ret_type.rust_type).unwrap();
            }
        }

        output.push_str(&config.generate_signature());
        output.push_str(" {\n");

        output.push_str(&config.generate_preamble(" "));

        output.push_str(&config.generate_loop_header(" "));

        // Expose the incoming message as a typed pointer. `msg_ptr` is assumed to
        // be declared by the generated loop header (TODO confirm).
        // NOTE(review): `type_name` is the Rust spelling, so a primitive message
        // type such as `f32` would be emitted as `f32*`, which is not a CUDA
        // type. Confirm that only struct messages reach this point.
        if let Some(ref msg_param) = handler_sig.message_param {
            let type_name = msg_param
                .rust_type
                .trim_start_matches('&')
                .trim_start_matches("mut ")
                .trim();
            if !type_name.is_empty() {
                if config.use_envelope_format {
                    writeln!(
                        output,
                        " // Message deserialization (envelope format)"
                    )
                    .unwrap();
                    writeln!(
                        output,
                        " // msg_header contains: message_id, correlation_id, source_kernel, timestamp"
                    )
                    .unwrap();
                    writeln!(
                        output,
                        " {}* {} = ({}*)msg_ptr;",
                        type_name, msg_param.name, type_name
                    )
                    .unwrap();
                } else {
                    writeln!(output, " // Message deserialization (raw format)").unwrap();
                    writeln!(
                        output,
                        " {}* {} = ({}*)msg_ptr;",
                        type_name, msg_param.name, type_name
                    )
                    .unwrap();
                }
                output.push('\n');
            }
        }

        // Translate the handler at a fixed nesting level. Each non-blank line is
        // then prefixed again so it sits inside the generated loop.
        self.indent = 2;
        let handler_body = self.transpile_block(&handler.block)?;

        writeln!(output, " // === USER HANDLER CODE ===").unwrap();
        for line in handler_body.lines() {
            if !line.trim().is_empty() {
                writeln!(output, " {}", line).unwrap();
            }
        }
        writeln!(output, " // === END HANDLER CODE ===").unwrap();

        // Publish the handler's result. The emitted code refers to a local named
        // `response`, which is presumably declared by the user handler body, and
        // to `control`, `output_buffer`, `RESP_SIZE`, `msg_header`, `KERNEL_ID`,
        // `hlc_*`, `HLC_NODE_ID` and `MESSAGE_HEADER_SIZE`, which are assumed to
        // come from the generated structs/preamble (TODO confirm).
        if let Some(ref ret_type) = handler_sig.return_type {
            writeln!(output).unwrap();
            if config.use_envelope_format {
                writeln!(
                    output,
                    " // Response serialization (envelope format)"
                )
                .unwrap();
                writeln!(
                    output,
                    " unsigned long long _out_idx = atomicAdd(&control->output_head, 1) & control->output_mask;"
                )
                .unwrap();
                writeln!(
                    output,
                    " unsigned char* resp_envelope = &output_buffer[_out_idx * RESP_SIZE];"
                )
                .unwrap();
                writeln!(
                    output,
                    " MessageHeader* resp_header = (MessageHeader*)resp_envelope;"
                )
                .unwrap();
                writeln!(output, " message_create_response_header(").unwrap();
                writeln!(output, " resp_header, msg_header, KERNEL_ID,").unwrap();
                writeln!(
                    output,
                    " sizeof({}), hlc_physical, hlc_logical, HLC_NODE_ID",
                    ret_type.cuda_type
                )
                .unwrap();
                writeln!(output, " );").unwrap();
                writeln!(
                    output,
                    " memcpy(resp_envelope + MESSAGE_HEADER_SIZE, &response, sizeof({}));",
                    ret_type.cuda_type
                )
                .unwrap();
                // Make the response bytes visible before the message is marked complete.
                writeln!(output, " __threadfence();").unwrap();
            } else {
                writeln!(output, " // Response serialization (raw format)").unwrap();
                // NOTE(review): in raw format only struct return types are written to
                // the output buffer; a primitive return value is dropped. Confirm
                // that this is intended.
                if ret_type.is_struct {
                    writeln!(
                        output,
                        " unsigned long long _out_idx = atomicAdd(&control->output_head, 1) & control->output_mask;"
                    )
                    .unwrap();
                    writeln!(
                        output,
                        " memcpy(&output_buffer[_out_idx * RESP_SIZE], &response, sizeof({}));",
                        ret_type.cuda_type
                    )
                    .unwrap();
                }
            }
        }

        output.push_str(&config.generate_message_complete(" "));

        output.push_str(&config.generate_loop_footer(" "));

        output.push_str(&config.generate_epilogue(" "));

        output.push_str("}\n");

        Ok(output)
    }
400
401 fn transpile_generic_kernel_signature(&self, func: &ItemFn) -> Result<String> {
403 let name = func.sig.ident.to_string();
404
405 let mut params = Vec::new();
406 for param in &func.sig.inputs {
407 if let FnArg::Typed(pat_type) = param {
408 let param_name = match pat_type.pat.as_ref() {
409 Pat::Ident(ident) => ident.ident.to_string(),
410 _ => {
411 return Err(TranspileError::Unsupported(
412 "Complex pattern in parameter".into(),
413 ))
414 }
415 };
416
417 let cuda_type = self.type_mapper.map_type(&pat_type.ty)?;
418 params.push(format!("{} {}", cuda_type.to_cuda_string(), param_name));
419 }
420 }
421
422 Ok(format!("{}({})", name, params.join(", ")))
423 }
424
425 fn transpile_kernel_signature(&self, func: &ItemFn) -> Result<String> {
427 let name = func.sig.ident.to_string();
428
429 let mut params = Vec::new();
430 for param in &func.sig.inputs {
431 if let FnArg::Typed(pat_type) = param {
432 if is_grid_pos_type(&pat_type.ty) {
434 continue;
435 }
436
437 let param_name = match pat_type.pat.as_ref() {
438 Pat::Ident(ident) => ident.ident.to_string(),
439 _ => {
440 return Err(TranspileError::Unsupported(
441 "Complex pattern in parameter".into(),
442 ))
443 }
444 };
445
446 let cuda_type = self.type_mapper.map_type(&pat_type.ty)?;
447 params.push(format!("{} {}", cuda_type.to_cuda_string(), param_name));
448 }
449 }
450
451 Ok(format!("{}({})", name, params.join(", ")))
452 }
453
454 fn transpile_block(&mut self, block: &syn::Block) -> Result<String> {
456 let mut output = String::new();
457
458 for stmt in &block.stmts {
459 let stmt_str = self.transpile_stmt(stmt)?;
460 if !stmt_str.is_empty() {
461 output.push_str(&stmt_str);
462 }
463 }
464
465 Ok(output)
466 }
467
468 fn transpile_stmt(&mut self, stmt: &Stmt) -> Result<String> {
470 match stmt {
471 Stmt::Local(local) => {
472 let indent = self.indent_str();
473
474 let var_name = match &local.pat {
476 Pat::Ident(ident) => ident.ident.to_string(),
477 Pat::Type(pat_type) => {
478 if let Pat::Ident(ident) = pat_type.pat.as_ref() {
479 ident.ident.to_string()
480 } else {
481 return Err(TranspileError::Unsupported(
482 "Complex pattern in let binding".into(),
483 ));
484 }
485 }
486 _ => {
487 return Err(TranspileError::Unsupported(
488 "Complex pattern in let binding".into(),
489 ))
490 }
491 };
492
493 if let Some(shared_decl) = self.try_parse_shared_declaration(local, &var_name)? {
495 self.shared_vars.insert(
497 var_name.clone(),
498 SharedVarInfo {
499 name: var_name.clone(),
500 is_tile: shared_decl.dimensions.len() == 2,
501 dimensions: shared_decl.dimensions.clone(),
502 element_type: shared_decl.element_type.clone(),
503 },
504 );
505
506 self.shared_memory.add(shared_decl.clone());
508
509 return Ok(format!("{indent}{}\n", shared_decl.to_cuda_decl()));
511 }
512
513 if let Some(init) = &local.init {
515 let expr_str = self.transpile_expr(&init.expr)?;
516
517 let type_str = self.infer_cuda_type(&init.expr);
520
521 if type_str.ends_with('*') {
523 self.pointer_vars.insert(var_name.clone());
524 }
525
526 Ok(format!("{indent}{type_str} {var_name} = {expr_str};\n"))
527 } else {
528 Ok(format!("{indent}float {var_name};\n"))
530 }
531 }
532 Stmt::Expr(expr, semi) => {
533 let indent = self.indent_str();
534
535 if let Expr::If(if_expr) = expr {
537 if let Some(Stmt::Expr(Expr::Return(_), _)) = if_expr.then_branch.stmts.first()
539 {
540 if if_expr.then_branch.stmts.len() == 1 && if_expr.else_branch.is_none() {
541 let expr_str = self.transpile_expr(expr)?;
542 return Ok(format!("{indent}{expr_str};\n"));
543 }
544 }
545 }
546
547 let expr_str = self.transpile_expr(expr)?;
548
549 if semi.is_some() {
550 Ok(format!("{indent}{expr_str};\n"))
551 } else {
552 if matches!(expr, Expr::Return(_))
555 || expr_str.starts_with("return")
556 || expr_str.starts_with("if (")
557 {
558 Ok(format!("{indent}{expr_str};\n"))
559 } else {
560 Ok(format!("{indent}return {expr_str};\n"))
561 }
562 }
563 }
564 Stmt::Item(_) => {
565 Err(TranspileError::Unsupported("Item in function body".into()))
567 }
568 Stmt::Macro(_) => Err(TranspileError::Unsupported("Macro in function body".into())),
569 }
570 }
571
572 fn transpile_expr(&self, expr: &Expr) -> Result<String> {
574 match expr {
575 Expr::Lit(lit) => self.transpile_lit(lit),
576 Expr::Path(path) => self.transpile_path(path),
577 Expr::Binary(bin) => self.transpile_binary(bin),
578 Expr::Unary(unary) => self.transpile_unary(unary),
579 Expr::Paren(paren) => self.transpile_paren(paren),
580 Expr::Index(index) => self.transpile_index(index),
581 Expr::Call(call) => self.transpile_call(call),
582 Expr::MethodCall(method) => self.transpile_method_call(method),
583 Expr::If(if_expr) => self.transpile_if(if_expr),
584 Expr::Assign(assign) => self.transpile_assign(assign),
585 Expr::Cast(cast) => self.transpile_cast(cast),
586 Expr::Match(match_expr) => self.transpile_match(match_expr),
587 Expr::Block(block) => {
588 if let Some(Stmt::Expr(expr, None)) = block.block.stmts.last() {
590 self.transpile_expr(expr)
591 } else {
592 Err(TranspileError::Unsupported(
593 "Complex block expression".into(),
594 ))
595 }
596 }
597 Expr::Field(field) => {
598 let base = self.transpile_expr(&field.base)?;
600 let member = match &field.member {
601 syn::Member::Named(ident) => ident.to_string(),
602 syn::Member::Unnamed(idx) => idx.index.to_string(),
603 };
604
605 let accessor = if self.pointer_vars.contains(&base) {
607 "->"
608 } else {
609 "."
610 };
611 Ok(format!("{base}{accessor}{member}"))
612 }
613 Expr::Return(ret) => self.transpile_return(ret),
614 Expr::ForLoop(for_loop) => self.transpile_for_loop(for_loop),
615 Expr::While(while_loop) => self.transpile_while_loop(while_loop),
616 Expr::Loop(loop_expr) => self.transpile_infinite_loop(loop_expr),
617 Expr::Break(break_expr) => self.transpile_break(break_expr),
618 Expr::Continue(cont_expr) => self.transpile_continue(cont_expr),
619 Expr::Struct(struct_expr) => self.transpile_struct_literal(struct_expr),
620 Expr::Reference(ref_expr) => self.transpile_reference(ref_expr),
621 Expr::Let(let_expr) => self.transpile_let_expr(let_expr),
622 Expr::Tuple(tuple) => {
623 let elements: Vec<String> = tuple
625 .elems
626 .iter()
627 .map(|e| self.transpile_expr(e))
628 .collect::<Result<_>>()?;
629 Ok(format!("({})", elements.join(", ")))
630 }
631 _ => Err(TranspileError::Unsupported(format!(
632 "Expression type: {}",
633 expr.to_token_stream()
634 ))),
635 }
636 }
637
638 fn transpile_lit(&self, lit: &ExprLit) -> Result<String> {
640 match &lit.lit {
641 Lit::Float(f) => {
642 let s = f.to_string();
643 if s.ends_with("f32") || !s.contains('.') {
645 let num = s.trim_end_matches("f32").trim_end_matches("f64");
646 Ok(format!("{num}f"))
647 } else if s.ends_with("f64") {
648 Ok(s.trim_end_matches("f64").to_string())
649 } else {
650 Ok(format!("{s}f"))
652 }
653 }
654 Lit::Int(i) => Ok(i.to_string()),
655 Lit::Bool(b) => Ok(if b.value { "1" } else { "0" }.to_string()),
656 _ => Err(TranspileError::Unsupported(format!(
657 "Literal type: {}",
658 lit.to_token_stream()
659 ))),
660 }
661 }
662
663 fn transpile_path(&self, path: &ExprPath) -> Result<String> {
665 let segments: Vec<_> = path
666 .path
667 .segments
668 .iter()
669 .map(|s| s.ident.to_string())
670 .collect();
671
672 if segments.len() == 1 {
673 Ok(segments[0].clone())
674 } else {
675 Ok(segments.join("::"))
676 }
677 }
678
679 fn transpile_binary(&self, bin: &ExprBinary) -> Result<String> {
681 let left = self.transpile_expr(&bin.left)?;
682 let right = self.transpile_expr(&bin.right)?;
683
684 let op = match bin.op {
685 BinOp::Add(_) => "+",
686 BinOp::Sub(_) => "-",
687 BinOp::Mul(_) => "*",
688 BinOp::Div(_) => "/",
689 BinOp::Rem(_) => "%",
690 BinOp::And(_) => "&&",
691 BinOp::Or(_) => "||",
692 BinOp::BitXor(_) => "^",
693 BinOp::BitAnd(_) => "&",
694 BinOp::BitOr(_) => "|",
695 BinOp::Shl(_) => "<<",
696 BinOp::Shr(_) => ">>",
697 BinOp::Eq(_) => "==",
698 BinOp::Lt(_) => "<",
699 BinOp::Le(_) => "<=",
700 BinOp::Ne(_) => "!=",
701 BinOp::Ge(_) => ">=",
702 BinOp::Gt(_) => ">",
703 BinOp::AddAssign(_) => "+=",
704 BinOp::SubAssign(_) => "-=",
705 BinOp::MulAssign(_) => "*=",
706 BinOp::DivAssign(_) => "/=",
707 BinOp::RemAssign(_) => "%=",
708 BinOp::BitXorAssign(_) => "^=",
709 BinOp::BitAndAssign(_) => "&=",
710 BinOp::BitOrAssign(_) => "|=",
711 BinOp::ShlAssign(_) => "<<=",
712 BinOp::ShrAssign(_) => ">>=",
713 _ => {
714 return Err(TranspileError::Unsupported(format!(
715 "Binary operator: {}",
716 bin.to_token_stream()
717 )))
718 }
719 };
720
721 Ok(format!("{left} {op} {right}"))
722 }
723
724 fn transpile_unary(&self, unary: &ExprUnary) -> Result<String> {
726 let expr = self.transpile_expr(&unary.expr)?;
727
728 let op = match unary.op {
729 UnOp::Neg(_) => "-",
730 UnOp::Not(_) => "!",
731 UnOp::Deref(_) => "*",
732 _ => {
733 return Err(TranspileError::Unsupported(format!(
734 "Unary operator: {}",
735 unary.to_token_stream()
736 )))
737 }
738 };
739
740 Ok(format!("{op}({expr})"))
741 }
742
743 fn transpile_paren(&self, paren: &ExprParen) -> Result<String> {
745 let inner = self.transpile_expr(&paren.expr)?;
746 Ok(format!("({inner})"))
747 }
748
749 fn transpile_index(&self, index: &ExprIndex) -> Result<String> {
751 let base = self.transpile_expr(&index.expr)?;
752 let idx = self.transpile_expr(&index.index)?;
753 Ok(format!("{base}[{idx}]"))
754 }
755
756 fn transpile_call(&self, call: &ExprCall) -> Result<String> {
758 let func = self.transpile_expr(&call.func)?;
759
760 if let Some(intrinsic) = self.intrinsics.lookup(&func) {
762 let cuda_name = intrinsic.to_cuda_string();
763
764 let is_value_intrinsic = cuda_name.contains("Idx.")
767 || cuda_name.contains("Dim.")
768 || cuda_name.starts_with("threadIdx")
769 || cuda_name.starts_with("blockIdx")
770 || cuda_name.starts_with("blockDim")
771 || cuda_name.starts_with("gridDim");
772
773 if is_value_intrinsic && call.args.is_empty() {
774 return Ok(cuda_name.to_string());
776 }
777
778 if call.args.is_empty() && cuda_name.ends_with("()") {
779 return Ok(cuda_name.to_string());
781 }
782
783 let args: Vec<String> = call
784 .args
785 .iter()
786 .map(|a| self.transpile_expr(a))
787 .collect::<Result<_>>()?;
788
789 return Ok(format!(
790 "{}({})",
791 cuda_name.trim_end_matches("()"),
792 args.join(", ")
793 ));
794 }
795
796 let args: Vec<String> = call
798 .args
799 .iter()
800 .map(|a| self.transpile_expr(a))
801 .collect::<Result<_>>()?;
802
803 Ok(format!("{}({})", func, args.join(", ")))
804 }
805
806 fn transpile_method_call(&self, method: &ExprMethodCall) -> Result<String> {
808 let receiver = self.transpile_expr(&method.receiver)?;
809 let method_name = method.method.to_string();
810
811 if let Some(result) =
813 self.try_transpile_shared_method_call(&receiver, &method_name, &method.args)
814 {
815 return result;
816 }
817
818 if self.ring_kernel_mode && self.context_vars.contains(&receiver) {
820 return self.transpile_context_method(&method_name, &method.args);
821 }
822
823 if self.grid_pos_vars.contains(&receiver) {
825 return self.transpile_stencil_intrinsic(&method_name, &method.args);
826 }
827
828 if self.ring_kernel_mode {
830 if let Some(intrinsic) = RingKernelIntrinsic::from_name(&method_name) {
831 let args: Vec<String> = method
832 .args
833 .iter()
834 .map(|a| self.transpile_expr(a).unwrap_or_default())
835 .collect();
836 return Ok(intrinsic.to_cuda(&args));
837 }
838 }
839
840 if let Some(intrinsic) = self.intrinsics.lookup(&method_name) {
842 let cuda_name = intrinsic.to_cuda_string();
843 let args: Vec<String> = std::iter::once(receiver)
844 .chain(
845 method
846 .args
847 .iter()
848 .map(|a| self.transpile_expr(a).unwrap_or_default()),
849 )
850 .collect();
851
852 return Ok(format!("{}({})", cuda_name, args.join(", ")));
853 }
854
855 let args: Vec<String> = method
857 .args
858 .iter()
859 .map(|a| self.transpile_expr(a))
860 .collect::<Result<_>>()?;
861
862 Ok(format!("{}.{}({})", receiver, method_name, args.join(", ")))
863 }
864
865 fn transpile_context_method(
867 &self,
868 method: &str,
869 args: &syn::punctuated::Punctuated<Expr, syn::token::Comma>,
870 ) -> Result<String> {
871 let ctx_method = ContextMethod::from_name(method).ok_or_else(|| {
872 TranspileError::Unsupported(format!("Unknown context method: {}", method))
873 })?;
874
875 let cuda_args: Vec<String> = args
876 .iter()
877 .map(|a| self.transpile_expr(a).unwrap_or_default())
878 .collect();
879
880 Ok(ctx_method.to_cuda(&cuda_args))
881 }
882
883 fn transpile_stencil_intrinsic(
885 &self,
886 method: &str,
887 args: &syn::punctuated::Punctuated<Expr, syn::token::Comma>,
888 ) -> Result<String> {
889 let config = self.config.as_ref().ok_or_else(|| {
890 TranspileError::Unsupported("Stencil intrinsic without config".into())
891 })?;
892
893 let buffer_width = config.buffer_width().to_string();
894 let buffer_slice = format!("{}", config.buffer_width() * config.buffer_height());
895 let is_3d = config.grid == crate::stencil::Grid::Grid3D;
896
897 let intrinsic = StencilIntrinsic::from_method_name(method).ok_or_else(|| {
898 TranspileError::Unsupported(format!("Unknown stencil intrinsic: {method}"))
899 })?;
900
901 if intrinsic.is_3d_only() && !is_3d {
903 return Err(TranspileError::Unsupported(format!(
904 "3D stencil intrinsic '{}' requires Grid3D configuration",
905 method
906 )));
907 }
908
909 match intrinsic {
910 StencilIntrinsic::Index => {
911 Ok("idx".to_string())
913 }
914 StencilIntrinsic::North
915 | StencilIntrinsic::South
916 | StencilIntrinsic::East
917 | StencilIntrinsic::West => {
918 if args.is_empty() {
920 return Err(TranspileError::Unsupported(
921 "Stencil accessor requires buffer argument".into(),
922 ));
923 }
924 let buffer = self.transpile_expr(&args[0])?;
925 if is_3d {
926 Ok(intrinsic.to_cuda_index_3d(&buffer, &buffer_width, &buffer_slice, "idx"))
927 } else {
928 Ok(intrinsic.to_cuda_index_2d(&buffer, &buffer_width, "idx"))
929 }
930 }
931 StencilIntrinsic::Up | StencilIntrinsic::Down => {
932 if args.is_empty() {
934 return Err(TranspileError::Unsupported(
935 "3D stencil accessor requires buffer argument".into(),
936 ));
937 }
938 let buffer = self.transpile_expr(&args[0])?;
939 Ok(intrinsic.to_cuda_index_3d(&buffer, &buffer_width, &buffer_slice, "idx"))
940 }
941 StencilIntrinsic::At => {
942 if is_3d {
945 if args.len() < 4 {
946 return Err(TranspileError::Unsupported(
947 "at() in 3D requires buffer, dx, dy, dz arguments".into(),
948 ));
949 }
950 let buffer = self.transpile_expr(&args[0])?;
951 let dx = self.transpile_expr(&args[1])?;
952 let dy = self.transpile_expr(&args[2])?;
953 let dz = self.transpile_expr(&args[3])?;
954 Ok(format!(
955 "{buffer}[idx + ({dz}) * {buffer_slice} + ({dy}) * {buffer_width} + ({dx})]"
956 ))
957 } else {
958 if args.len() < 3 {
959 return Err(TranspileError::Unsupported(
960 "at() requires buffer, dx, dy arguments".into(),
961 ));
962 }
963 let buffer = self.transpile_expr(&args[0])?;
964 let dx = self.transpile_expr(&args[1])?;
965 let dy = self.transpile_expr(&args[2])?;
966 Ok(format!("{buffer}[idx + ({dy}) * {buffer_width} + ({dx})]"))
967 }
968 }
969 }
970 }
971
972 fn transpile_if(&self, if_expr: &ExprIf) -> Result<String> {
974 let cond = self.transpile_expr(&if_expr.cond)?;
975
976 if let Some(Stmt::Expr(Expr::Return(ret), _)) = if_expr.then_branch.stmts.first() {
978 if if_expr.then_branch.stmts.len() == 1 && if_expr.else_branch.is_none() {
979 if ret.expr.is_none() {
981 return Ok(format!("if ({cond}) return"));
982 }
983 let ret_val = self.transpile_expr(ret.expr.as_ref().unwrap())?;
984 return Ok(format!("if ({cond}) return {ret_val}"));
985 }
986 }
987
988 if let Some((_, else_branch)) = &if_expr.else_branch {
990 if let (Some(Stmt::Expr(then_expr, None)), Expr::Block(else_block)) =
992 (if_expr.then_branch.stmts.last(), else_branch.as_ref())
993 {
994 if let Some(Stmt::Expr(else_expr, None)) = else_block.block.stmts.last() {
995 let then_str = self.transpile_expr(then_expr)?;
996 let else_str = self.transpile_expr(else_expr)?;
997 return Ok(format!("({cond}) ? ({then_str}) : ({else_str})"));
998 }
999 }
1000
1001 if let Expr::If(else_if) = else_branch.as_ref() {
1003 let then_body = self.transpile_if_body(&if_expr.then_branch)?;
1005 let else_part = self.transpile_if(else_if)?;
1006 return Ok(format!("if ({cond}) {{{then_body}}} else {else_part}"));
1007 } else if let Expr::Block(else_block) = else_branch.as_ref() {
1008 let then_body = self.transpile_if_body(&if_expr.then_branch)?;
1010 let else_body = self.transpile_if_body(&else_block.block)?;
1011 return Ok(format!("if ({cond}) {{{then_body}}} else {{{else_body}}}"));
1012 }
1013 }
1014
1015 let then_body = self.transpile_if_body(&if_expr.then_branch)?;
1017 Ok(format!("if ({cond}) {{{then_body}}}"))
1018 }
1019
1020 fn transpile_if_body(&self, block: &syn::Block) -> Result<String> {
1022 let mut body = String::new();
1023 for stmt in &block.stmts {
1024 match stmt {
1025 Stmt::Expr(expr, Some(_)) => {
1026 let expr_str = self.transpile_expr(expr)?;
1027 body.push_str(&format!(" {expr_str};"));
1028 }
1029 Stmt::Expr(Expr::Return(ret), None) => {
1030 if let Some(ret_expr) = &ret.expr {
1032 let expr_str = self.transpile_expr(ret_expr)?;
1033 body.push_str(&format!(" return {expr_str};"));
1034 } else {
1035 body.push_str(" return;");
1036 }
1037 }
1038 Stmt::Expr(expr, None) => {
1039 let expr_str = self.transpile_expr(expr)?;
1040 body.push_str(&format!(" return {expr_str};"));
1041 }
1042 _ => {}
1043 }
1044 }
1045 Ok(body)
1046 }
1047
1048 fn transpile_assign(&self, assign: &ExprAssign) -> Result<String> {
1050 let left = self.transpile_expr(&assign.left)?;
1051 let right = self.transpile_expr(&assign.right)?;
1052 Ok(format!("{left} = {right}"))
1053 }
1054
1055 fn transpile_cast(&self, cast: &ExprCast) -> Result<String> {
1057 let expr = self.transpile_expr(&cast.expr)?;
1058 let cuda_type = self.type_mapper.map_type(&cast.ty)?;
1059 Ok(format!("({})({})", cuda_type.to_cuda_string(), expr))
1060 }
1061
1062 fn transpile_return(&self, ret: &ExprReturn) -> Result<String> {
1064 if let Some(expr) = &ret.expr {
1065 let expr_str = self.transpile_expr(expr)?;
1066 Ok(format!("return {expr_str}"))
1067 } else {
1068 Ok("return".to_string())
1069 }
1070 }
1071
1072 fn transpile_struct_literal(&self, struct_expr: &ExprStruct) -> Result<String> {
1077 let type_name = struct_expr
1079 .path
1080 .segments
1081 .iter()
1082 .map(|s| s.ident.to_string())
1083 .collect::<Vec<_>>()
1084 .join("::");
1085
1086 let mut fields = Vec::new();
1088 for field in &struct_expr.fields {
1089 let field_name = match &field.member {
1090 syn::Member::Named(ident) => ident.to_string(),
1091 syn::Member::Unnamed(idx) => idx.index.to_string(),
1092 };
1093 let value = self.transpile_expr(&field.expr)?;
1094 fields.push(format!(".{} = {}", field_name, value));
1095 }
1096
1097 if struct_expr.rest.is_some() {
1099 return Err(TranspileError::Unsupported(
1100 "Struct update syntax (..base) is not supported in CUDA".into(),
1101 ));
1102 }
1103
1104 Ok(format!("({}){{ {} }}", type_name, fields.join(", ")))
1106 }
1107
1108 fn transpile_reference(&self, ref_expr: &ExprReference) -> Result<String> {
1115 let inner = self.transpile_expr(&ref_expr.expr)?;
1116
1117 Ok(format!("&{inner}"))
1121 }
1122
1123 fn transpile_let_expr(&self, let_expr: &ExprLet) -> Result<String> {
1129 let _ = let_expr; Err(TranspileError::Unsupported(
1133 "let expressions (if-let patterns) are not directly supported in CUDA. \
1134 Use explicit comparisons instead."
1135 .into(),
1136 ))
1137 }
1138
1139 fn transpile_for_loop(&self, for_loop: &ExprForLoop) -> Result<String> {
1159 if !self.validation_mode.allows_loops() {
1161 return Err(TranspileError::Unsupported(
1162 "Loops are not allowed in stencil kernels".into(),
1163 ));
1164 }
1165
1166 let var_name = extract_loop_var(&for_loop.pat)
1168 .ok_or_else(|| TranspileError::Unsupported("Complex pattern in for loop".into()))?;
1169
1170 let header = match for_loop.expr.as_ref() {
1172 Expr::Range(range) => {
1173 let range_info = RangeInfo::from_range(range, |e| self.transpile_expr(e));
1174 range_info.to_cuda_for_header(&var_name, "int")
1175 }
1176 _ => {
1177 return Err(TranspileError::Unsupported(
1179 "Only range expressions (start..end) are supported in for loops".into(),
1180 ));
1181 }
1182 };
1183
1184 let body = self.transpile_loop_body(&for_loop.body)?;
1186
1187 Ok(format!("{header} {{\n{body}}}"))
1188 }
1189
1190 fn transpile_while_loop(&self, while_loop: &ExprWhile) -> Result<String> {
1206 if !self.validation_mode.allows_loops() {
1208 return Err(TranspileError::Unsupported(
1209 "Loops are not allowed in stencil kernels".into(),
1210 ));
1211 }
1212
1213 let condition = self.transpile_expr(&while_loop.cond)?;
1215
1216 let body = self.transpile_loop_body(&while_loop.body)?;
1218
1219 Ok(format!("while ({condition}) {{\n{body}}}"))
1220 }
1221
1222 fn transpile_infinite_loop(&self, loop_expr: &ExprLoop) -> Result<String> {
1238 if !self.validation_mode.allows_loops() {
1240 return Err(TranspileError::Unsupported(
1241 "Loops are not allowed in stencil kernels".into(),
1242 ));
1243 }
1244
1245 let body = self.transpile_loop_body(&loop_expr.body)?;
1247
1248 Ok(format!("while (true) {{\n{body}}}"))
1250 }
1251
1252 fn transpile_break(&self, break_expr: &ExprBreak) -> Result<String> {
1254 if break_expr.label.is_some() {
1256 return Err(TranspileError::Unsupported(
1257 "Labeled break is not supported in CUDA".into(),
1258 ));
1259 }
1260
1261 if break_expr.expr.is_some() {
1263 return Err(TranspileError::Unsupported(
1264 "Break with value is not supported in CUDA".into(),
1265 ));
1266 }
1267
1268 Ok("break".to_string())
1269 }
1270
1271 fn transpile_continue(&self, cont_expr: &ExprContinue) -> Result<String> {
1273 if cont_expr.label.is_some() {
1275 return Err(TranspileError::Unsupported(
1276 "Labeled continue is not supported in CUDA".into(),
1277 ));
1278 }
1279
1280 Ok("continue".to_string())
1281 }
1282
1283 fn transpile_loop_body(&self, block: &syn::Block) -> Result<String> {
1285 let mut output = String::new();
1286 let inner_indent = " ".repeat(self.indent + 1);
1287
1288 for stmt in &block.stmts {
1289 match stmt {
1290 Stmt::Local(local) => {
1291 let var_name = match &local.pat {
1293 Pat::Ident(ident) => ident.ident.to_string(),
1294 Pat::Type(pat_type) => {
1295 if let Pat::Ident(ident) = pat_type.pat.as_ref() {
1296 ident.ident.to_string()
1297 } else {
1298 return Err(TranspileError::Unsupported(
1299 "Complex pattern in let binding".into(),
1300 ));
1301 }
1302 }
1303 _ => {
1304 return Err(TranspileError::Unsupported(
1305 "Complex pattern in let binding".into(),
1306 ))
1307 }
1308 };
1309
1310 if let Some(init) = &local.init {
1311 let expr_str = self.transpile_expr(&init.expr)?;
1312 let type_str = self.infer_cuda_type(&init.expr);
1313 output.push_str(&format!(
1314 "{inner_indent}{type_str} {var_name} = {expr_str};\n"
1315 ));
1316 } else {
1317 output.push_str(&format!("{inner_indent}float {var_name};\n"));
1318 }
1319 }
1320 Stmt::Expr(expr, semi) => {
1321 let expr_str = self.transpile_expr(expr)?;
1322 if semi.is_some() {
1323 output.push_str(&format!("{inner_indent}{expr_str};\n"));
1324 } else {
1325 output.push_str(&format!("{inner_indent}{expr_str};\n"));
1327 }
1328 }
1329 _ => {
1330 return Err(TranspileError::Unsupported(
1331 "Unsupported statement in loop body".into(),
1332 ));
1333 }
1334 }
1335 }
1336
1337 let closing_indent = " ".repeat(self.indent);
1339 output.push_str(&closing_indent);
1340
1341 Ok(output)
1342 }
1343
1344 fn try_parse_shared_declaration(
1353 &self,
1354 local: &syn::Local,
1355 var_name: &str,
1356 ) -> Result<Option<SharedMemoryDecl>> {
1357 if let Pat::Type(pat_type) = &local.pat {
1359 let type_str = pat_type.ty.to_token_stream().to_string();
1360 return self.parse_shared_type(&type_str, var_name);
1361 }
1362
1363 if let Some(init) = &local.init {
1365 if let Expr::Call(call) = init.expr.as_ref() {
1366 if let Expr::Path(path) = call.func.as_ref() {
1367 let path_str = path.to_token_stream().to_string();
1368 return self.parse_shared_type(&path_str, var_name);
1369 }
1370 }
1371 }
1372
1373 Ok(None)
1374 }
1375
1376 fn parse_shared_type(
1378 &self,
1379 type_str: &str,
1380 var_name: &str,
1381 ) -> Result<Option<SharedMemoryDecl>> {
1382 let type_str = type_str
1384 .replace(" :: ", "::")
1385 .replace(" ::", "::")
1386 .replace(":: ", "::");
1387
1388 if type_str.contains("SharedTile") {
1390 if let Some(start) = type_str.find('<') {
1392 if let Some(end) = type_str.rfind('>') {
1393 let params = &type_str[start + 1..end];
1394 let parts: Vec<&str> = params.split(',').map(|s| s.trim()).collect();
1395
1396 if parts.len() >= 3 {
1397 let rust_type = parts[0];
1398 let width: usize = parts[1].parse().map_err(|_| {
1399 TranspileError::Unsupported("Invalid SharedTile width".into())
1400 })?;
1401 let height: usize = parts[2].parse().map_err(|_| {
1402 TranspileError::Unsupported("Invalid SharedTile height".into())
1403 })?;
1404
1405 let cuda_type = rust_to_cuda_element_type(rust_type);
1406 return Ok(Some(SharedMemoryDecl::tile(
1407 var_name, cuda_type, width, height,
1408 )));
1409 }
1410 }
1411 }
1412 }
1413
1414 if type_str.contains("SharedArray") {
1416 if let Some(start) = type_str.find('<') {
1417 if let Some(end) = type_str.rfind('>') {
1418 let params = &type_str[start + 1..end];
1419 let parts: Vec<&str> = params.split(',').map(|s| s.trim()).collect();
1420
1421 if parts.len() >= 2 {
1422 let rust_type = parts[0];
1423 let size: usize = parts[1].parse().map_err(|_| {
1424 TranspileError::Unsupported("Invalid SharedArray size".into())
1425 })?;
1426
1427 let cuda_type = rust_to_cuda_element_type(rust_type);
1428 return Ok(Some(SharedMemoryDecl::array(var_name, cuda_type, size)));
1429 }
1430 }
1431 }
1432 }
1433
1434 Ok(None)
1435 }
1436
1437 fn try_transpile_shared_method_call(
1439 &self,
1440 receiver: &str,
1441 method_name: &str,
1442 args: &syn::punctuated::Punctuated<Expr, syn::token::Comma>,
1443 ) -> Option<Result<String>> {
1444 let shared_info = self.shared_vars.get(receiver)?;
1445
1446 match method_name {
1447 "get" => {
1448 if shared_info.is_tile {
1450 if args.len() >= 2 {
1451 let x = self.transpile_expr(&args[0]).ok()?;
1452 let y = self.transpile_expr(&args[1]).ok()?;
1453 Some(Ok(format!("{}[{}][{}]", receiver, y, x)))
1455 } else {
1456 Some(Err(TranspileError::Unsupported(
1457 "SharedTile.get requires x and y arguments".into(),
1458 )))
1459 }
1460 } else {
1461 if !args.is_empty() {
1463 let idx = self.transpile_expr(&args[0]).ok()?;
1464 Some(Ok(format!("{}[{}]", receiver, idx)))
1465 } else {
1466 Some(Err(TranspileError::Unsupported(
1467 "SharedArray.get requires index argument".into(),
1468 )))
1469 }
1470 }
1471 }
1472 "set" => {
1473 if shared_info.is_tile {
1475 if args.len() >= 3 {
1476 let x = self.transpile_expr(&args[0]).ok()?;
1477 let y = self.transpile_expr(&args[1]).ok()?;
1478 let val = self.transpile_expr(&args[2]).ok()?;
1479 Some(Ok(format!("{}[{}][{}] = {}", receiver, y, x, val)))
1480 } else {
1481 Some(Err(TranspileError::Unsupported(
1482 "SharedTile.set requires x, y, and value arguments".into(),
1483 )))
1484 }
1485 } else {
1486 if args.len() >= 2 {
1488 let idx = self.transpile_expr(&args[0]).ok()?;
1489 let val = self.transpile_expr(&args[1]).ok()?;
1490 Some(Ok(format!("{}[{}] = {}", receiver, idx, val)))
1491 } else {
1492 Some(Err(TranspileError::Unsupported(
1493 "SharedArray.set requires index and value arguments".into(),
1494 )))
1495 }
1496 }
1497 }
1498 "width" | "height" | "size" => {
1499 match method_name {
1501 "width" if shared_info.is_tile => {
1502 Some(Ok(shared_info.dimensions[1].to_string()))
1503 }
1504 "height" if shared_info.is_tile => {
1505 Some(Ok(shared_info.dimensions[0].to_string()))
1506 }
1507 "size" => {
1508 let total: usize = shared_info.dimensions.iter().product();
1509 Some(Ok(total.to_string()))
1510 }
1511 _ => None,
1512 }
1513 }
1514 _ => None,
1515 }
1516 }
1517
1518 fn transpile_match(&self, match_expr: &ExprMatch) -> Result<String> {
1520 let scrutinee = self.transpile_expr(&match_expr.expr)?;
1521 let mut output = format!("switch ({scrutinee}) {{\n");
1522
1523 for arm in &match_expr.arms {
1524 let case_label = self.transpile_match_pattern(&arm.pat)?;
1526
1527 if case_label == "default" || case_label.starts_with("/*") {
1528 output.push_str(" default: {\n");
1529 } else {
1530 output.push_str(&format!(" case {case_label}: {{\n"));
1531 }
1532
1533 match arm.body.as_ref() {
1535 Expr::Block(block) => {
1536 for stmt in &block.block.stmts {
1538 let stmt_str = self.transpile_stmt_inline(stmt)?;
1539 output.push_str(&format!(" {stmt_str}\n"));
1540 }
1541 }
1542 _ => {
1543 let body = self.transpile_expr(&arm.body)?;
1545 output.push_str(&format!(" {body};\n"));
1546 }
1547 }
1548
1549 output.push_str(" break;\n");
1550 output.push_str(" }\n");
1551 }
1552
1553 output.push_str(" }");
1554 Ok(output)
1555 }
1556
1557 fn transpile_match_pattern(&self, pat: &Pat) -> Result<String> {
1559 match pat {
1560 Pat::Lit(pat_lit) => {
1561 match &pat_lit.lit {
1563 Lit::Int(i) => Ok(i.to_string()),
1564 Lit::Bool(b) => Ok(if b.value { "1" } else { "0" }.to_string()),
1565 _ => Err(TranspileError::Unsupported(
1566 "Non-integer literal in match pattern".into(),
1567 )),
1568 }
1569 }
1570 Pat::Wild(_) => {
1571 Ok("default".to_string())
1573 }
1574 Pat::Ident(ident) => {
1575 Ok(format!("/* {} */ default", ident.ident))
1578 }
1579 Pat::Or(pat_or) => {
1580 if let Some(first) = pat_or.cases.first() {
1584 self.transpile_match_pattern(first)
1585 } else {
1586 Err(TranspileError::Unsupported("Empty or pattern".into()))
1587 }
1588 }
1589 _ => Err(TranspileError::Unsupported(format!(
1590 "Match pattern: {}",
1591 pat.to_token_stream()
1592 ))),
1593 }
1594 }
1595
1596 fn transpile_stmt_inline(&self, stmt: &Stmt) -> Result<String> {
1598 match stmt {
1599 Stmt::Local(local) => {
1600 let var_name = match &local.pat {
1601 Pat::Ident(ident) => ident.ident.to_string(),
1602 Pat::Type(pat_type) => {
1603 if let Pat::Ident(ident) = pat_type.pat.as_ref() {
1604 ident.ident.to_string()
1605 } else {
1606 return Err(TranspileError::Unsupported(
1607 "Complex pattern in let binding".into(),
1608 ));
1609 }
1610 }
1611 _ => {
1612 return Err(TranspileError::Unsupported(
1613 "Complex pattern in let binding".into(),
1614 ))
1615 }
1616 };
1617
1618 if let Some(init) = &local.init {
1619 let expr_str = self.transpile_expr(&init.expr)?;
1620 let type_str = self.infer_cuda_type(&init.expr);
1621 Ok(format!("{type_str} {var_name} = {expr_str};"))
1622 } else {
1623 Ok(format!("float {var_name};"))
1624 }
1625 }
1626 Stmt::Expr(expr, semi) => {
1627 let expr_str = self.transpile_expr(expr)?;
1628 if semi.is_some() {
1629 Ok(format!("{expr_str};"))
1630 } else {
1631 Ok(format!("return {expr_str};"))
1632 }
1633 }
1634 _ => Err(TranspileError::Unsupported(
1635 "Unsupported statement in match arm".into(),
1636 )),
1637 }
1638 }
1639
1640 fn infer_cuda_type(&self, expr: &Expr) -> &'static str {
1642 match expr {
1643 Expr::Lit(lit) => match &lit.lit {
1644 Lit::Float(_) => "float",
1645 Lit::Int(_) => "int",
1646 Lit::Bool(_) => "int",
1647 _ => "float",
1648 },
1649 Expr::Binary(bin) => {
1650 let left_type = self.infer_cuda_type(&bin.left);
1652 let right_type = self.infer_cuda_type(&bin.right);
1653 if left_type == "int" && right_type == "int" {
1655 "int"
1656 } else {
1657 "float"
1658 }
1659 }
1660 Expr::Call(call) => {
1661 if let Ok(func) = self.transpile_expr(&call.func) {
1663 if let Some(intrinsic) = self.intrinsics.lookup(&func) {
1664 let cuda_name = intrinsic.to_cuda_string();
1665 if cuda_name.contains("Idx") || cuda_name.contains("Dim") {
1667 return "int";
1668 }
1669 }
1670 }
1671 "float"
1672 }
1673 Expr::Index(_) => "float", Expr::Cast(cast) => {
1675 if let Ok(cuda_type) = self.type_mapper.map_type(&cast.ty) {
1677 let s = cuda_type.to_cuda_string();
1678 if s.contains("int") || s.contains("size_t") || s == "unsigned long long" {
1679 return "int";
1680 }
1681 }
1682 "float"
1683 }
1684 Expr::Reference(ref_expr) => {
1685 match ref_expr.expr.as_ref() {
1688 Expr::Index(idx_expr) => {
1689 if let Expr::Path(path) = &*idx_expr.expr {
1691 let name = path
1692 .path
1693 .segments
1694 .iter()
1695 .map(|s| s.ident.to_string())
1696 .collect::<Vec<_>>()
1697 .join("::");
1698 if name.contains("transaction") || name.contains("Transaction") {
1700 return "GpuTransaction*";
1701 }
1702 if name.contains("profile") || name.contains("Profile") {
1703 return "GpuCustomerProfile*";
1704 }
1705 if name.contains("alert") || name.contains("Alert") {
1706 return "GpuAlert*";
1707 }
1708 }
1709 "float*" }
1711 _ => "void*",
1712 }
1713 }
1714 Expr::MethodCall(_) => "float",
1715 Expr::Field(field) => {
1716 let member_name = match &field.member {
1718 syn::Member::Named(ident) => ident.to_string(),
1719 syn::Member::Unnamed(idx) => idx.index.to_string(),
1720 };
1721 if member_name.contains("count") || member_name.contains("_count") {
1723 return "unsigned int";
1724 }
1725 if member_name.contains("threshold") || member_name.ends_with("_id") {
1726 return "unsigned long long";
1727 }
1728 if member_name.ends_with("_pct") {
1729 return "unsigned char";
1730 }
1731 "float"
1732 }
1733 Expr::Path(path) => {
1734 let name = path
1736 .path
1737 .segments
1738 .iter()
1739 .map(|s| s.ident.to_string())
1740 .collect::<Vec<_>>()
1741 .join("::");
1742 if name.contains("threshold")
1743 || name.contains("count")
1744 || name == "idx"
1745 || name == "n"
1746 {
1747 return "int";
1748 }
1749 "float"
1750 }
1751 Expr::If(if_expr) => {
1752 if let Some((_, else_branch)) = &if_expr.else_branch {
1754 if let Expr::Block(block) = else_branch.as_ref() {
1755 if let Some(Stmt::Expr(expr, None)) = block.block.stmts.last() {
1756 return self.infer_cuda_type(expr);
1757 }
1758 }
1759 }
1760 if let Some(Stmt::Expr(expr, None)) = if_expr.then_branch.stmts.last() {
1762 return self.infer_cuda_type(expr);
1763 }
1764 "float"
1765 }
1766 _ => "float",
1767 }
1768 }
1769}
1770
1771pub fn transpile_function(func: &ItemFn) -> Result<String> {
1773 let mut transpiler = CudaTranspiler::new_generic();
1774
1775 let name = func.sig.ident.to_string();
1777
1778 let mut params = Vec::new();
1779 for param in &func.sig.inputs {
1780 if let FnArg::Typed(pat_type) = param {
1781 let param_name = match pat_type.pat.as_ref() {
1782 Pat::Ident(ident) => ident.ident.to_string(),
1783 _ => continue,
1784 };
1785
1786 let cuda_type = transpiler.type_mapper.map_type(&pat_type.ty)?;
1787 params.push(format!("{} {}", cuda_type.to_cuda_string(), param_name));
1788 }
1789 }
1790
1791 let return_type = match &func.sig.output {
1793 ReturnType::Default => "void".to_string(),
1794 ReturnType::Type(_, ty) => transpiler.type_mapper.map_type(ty)?.to_cuda_string(),
1795 };
1796
1797 let body = transpiler.transpile_block(&func.block)?;
1799
1800 Ok(format!(
1801 "__device__ {return_type} {name}({params}) {{\n{body}}}\n",
1802 params = params.join(", ")
1803 ))
1804}
1805
1806#[cfg(test)]
1807mod tests {
1808 use super::*;
1809 use syn::parse_quote;
1810
1811 #[test]
1812 fn test_simple_arithmetic() {
1813 let transpiler = CudaTranspiler::new_generic();
1814
1815 let expr: Expr = parse_quote!(a + b * 2.0);
1816 let result = transpiler.transpile_expr(&expr).unwrap();
1817 assert_eq!(result, "a + b * 2.0f");
1818 }
1819
1820 #[test]
1821 fn test_let_binding() {
1822 let mut transpiler = CudaTranspiler::new_generic();
1823
1824 let stmt: Stmt = parse_quote!(let x = a + b;);
1825 let result = transpiler.transpile_stmt(&stmt).unwrap();
1826 assert!(result.contains("float x = a + b;"));
1827 }
1828
1829 #[test]
1830 fn test_array_index() {
1831 let transpiler = CudaTranspiler::new_generic();
1832
1833 let expr: Expr = parse_quote!(data[idx]);
1834 let result = transpiler.transpile_expr(&expr).unwrap();
1835 assert_eq!(result, "data[idx]");
1836 }
1837
1838 #[test]
1839 fn test_stencil_intrinsics() {
1840 let config = StencilConfig::new("test")
1841 .with_tile_size(16, 16)
1842 .with_halo(1);
1843 let mut transpiler = CudaTranspiler::new(config);
1844 transpiler.grid_pos_vars.push("pos".to_string());
1845
1846 let expr: Expr = parse_quote!(pos.idx());
1848 let result = transpiler.transpile_expr(&expr).unwrap();
1849 assert_eq!(result, "idx");
1850
1851 let expr: Expr = parse_quote!(pos.north(p));
1853 let result = transpiler.transpile_expr(&expr).unwrap();
1854 assert_eq!(result, "p[idx - 18]");
1855
1856 let expr: Expr = parse_quote!(pos.east(p));
1858 let result = transpiler.transpile_expr(&expr).unwrap();
1859 assert_eq!(result, "p[idx + 1]");
1860 }
1861
1862 #[test]
1863 fn test_ternary_if() {
1864 let transpiler = CudaTranspiler::new_generic();
1865
1866 let expr: Expr = parse_quote!(if x > 0.0 { x } else { -x });
1867 let result = transpiler.transpile_expr(&expr).unwrap();
1868 assert!(result.contains("?"));
1869 assert!(result.contains(":"));
1870 }
1871
1872 #[test]
1873 fn test_full_stencil_kernel() {
1874 let func: ItemFn = parse_quote! {
1875 fn fdtd(p: &[f32], p_prev: &mut [f32], c2: f32, pos: GridPos) {
1876 let curr = p[pos.idx()];
1877 let prev = p_prev[pos.idx()];
1878 let lap = pos.north(p) + pos.south(p) + pos.east(p) + pos.west(p) - 4.0 * curr;
1879 p_prev[pos.idx()] = (2.0 * curr - prev + c2 * lap);
1880 }
1881 };
1882
1883 let config = StencilConfig::new("fdtd")
1884 .with_tile_size(16, 16)
1885 .with_halo(1);
1886
1887 let mut transpiler = CudaTranspiler::new(config);
1888 let cuda = transpiler.transpile_stencil(&func).unwrap();
1889
1890 assert!(cuda.contains("extern \"C\" __global__"));
1892 assert!(cuda.contains("threadIdx.x"));
1893 assert!(cuda.contains("threadIdx.y"));
1894 assert!(cuda.contains("buffer_width = 18"));
1895 assert!(cuda.contains("const float* __restrict__ p"));
1896 assert!(cuda.contains("float* __restrict__ p_prev"));
1897 assert!(!cuda.contains("GridPos")); println!("Generated CUDA:\n{}", cuda);
1900 }
1901
1902 #[test]
1903 fn test_early_return() {
1904 let mut transpiler = CudaTranspiler::new_generic();
1905
1906 let stmt: Stmt = parse_quote!(return;);
1907 let result = transpiler.transpile_stmt(&stmt).unwrap();
1908 assert!(result.contains("return;"));
1909
1910 let stmt_val: Stmt = parse_quote!(return 42;);
1911 let result_val = transpiler.transpile_stmt(&stmt_val).unwrap();
1912 assert!(result_val.contains("return 42;"));
1913 }
1914
1915 #[test]
1916 fn test_match_to_switch() {
1917 let transpiler = CudaTranspiler::new_generic();
1918
1919 let expr: Expr = parse_quote! {
1920 match edge {
1921 0 => { idx = 1 * 18 + i; }
1922 1 => { idx = 16 * 18 + i; }
1923 _ => { idx = 0; }
1924 }
1925 };
1926
1927 let result = transpiler.transpile_expr(&expr).unwrap();
1928 assert!(
1929 result.contains("switch (edge)"),
1930 "Should generate switch: {}",
1931 result
1932 );
1933 assert!(result.contains("case 0:"), "Should have case 0: {}", result);
1934 assert!(result.contains("case 1:"), "Should have case 1: {}", result);
1935 assert!(
1936 result.contains("default:"),
1937 "Should have default: {}",
1938 result
1939 );
1940 assert!(result.contains("break;"), "Should have break: {}", result);
1941
1942 println!("Generated switch:\n{}", result);
1943 }
1944
1945 #[test]
1946 fn test_block_idx_intrinsics() {
1947 let transpiler = CudaTranspiler::new_generic();
1948
1949 let expr: Expr = parse_quote!(block_idx_x());
1951 let result = transpiler.transpile_expr(&expr).unwrap();
1952 assert_eq!(result, "blockIdx.x");
1953
1954 let expr2: Expr = parse_quote!(thread_idx_y());
1956 let result2 = transpiler.transpile_expr(&expr2).unwrap();
1957 assert_eq!(result2, "threadIdx.y");
1958
1959 let expr3: Expr = parse_quote!(grid_dim_x());
1961 let result3 = transpiler.transpile_expr(&expr3).unwrap();
1962 assert_eq!(result3, "gridDim.x");
1963 }
1964
1965 #[test]
1966 fn test_global_index_calculation() {
1967 let transpiler = CudaTranspiler::new_generic();
1968
1969 let expr: Expr = parse_quote!(block_idx_x() * block_dim_x() + thread_idx_x());
1971 let result = transpiler.transpile_expr(&expr).unwrap();
1972 assert!(result.contains("blockIdx.x"), "Should contain blockIdx.x");
1973 assert!(result.contains("blockDim.x"), "Should contain blockDim.x");
1974 assert!(result.contains("threadIdx.x"), "Should contain threadIdx.x");
1975
1976 println!("Global index expression: {}", result);
1977 }
1978
1979 #[test]
1982 fn test_for_loop_transpile() {
1983 let transpiler = CudaTranspiler::new_generic();
1984
1985 let expr: Expr = parse_quote! {
1986 for i in 0..n {
1987 data[i] = 0.0;
1988 }
1989 };
1990
1991 let result = transpiler.transpile_expr(&expr).unwrap();
1992 assert!(
1993 result.contains("for (int i = 0; i < n; i++)"),
1994 "Should generate for loop header: {}",
1995 result
1996 );
1997 assert!(
1998 result.contains("data[i] = 0.0f"),
1999 "Should contain loop body: {}",
2000 result
2001 );
2002
2003 println!("Generated for loop:\n{}", result);
2004 }
2005
2006 #[test]
2007 fn test_for_loop_inclusive_range() {
2008 let transpiler = CudaTranspiler::new_generic();
2009
2010 let expr: Expr = parse_quote! {
2011 for i in 1..=10 {
2012 sum += i;
2013 }
2014 };
2015
2016 let result = transpiler.transpile_expr(&expr).unwrap();
2017 assert!(
2018 result.contains("for (int i = 1; i <= 10; i++)"),
2019 "Should generate inclusive range: {}",
2020 result
2021 );
2022
2023 println!("Generated inclusive for loop:\n{}", result);
2024 }
2025
2026 #[test]
2027 fn test_while_loop_transpile() {
2028 let transpiler = CudaTranspiler::new_generic();
2029
2030 let expr: Expr = parse_quote! {
2031 while i < 10 {
2032 i += 1;
2033 }
2034 };
2035
2036 let result = transpiler.transpile_expr(&expr).unwrap();
2037 assert!(
2038 result.contains("while (i < 10)"),
2039 "Should generate while loop: {}",
2040 result
2041 );
2042 assert!(
2043 result.contains("i += 1"),
2044 "Should contain loop body: {}",
2045 result
2046 );
2047
2048 println!("Generated while loop:\n{}", result);
2049 }
2050
2051 #[test]
2052 fn test_while_loop_negation() {
2053 let transpiler = CudaTranspiler::new_generic();
2054
2055 let expr: Expr = parse_quote! {
2056 while !done {
2057 process();
2058 }
2059 };
2060
2061 let result = transpiler.transpile_expr(&expr).unwrap();
2062 assert!(
2063 result.contains("while (!(done))"),
2064 "Should negate condition: {}",
2065 result
2066 );
2067
2068 println!("Generated while loop with negation:\n{}", result);
2069 }
2070
2071 #[test]
2072 fn test_infinite_loop_transpile() {
2073 let transpiler = CudaTranspiler::new_generic();
2074
2075 let expr: Expr = parse_quote! {
2076 loop {
2077 process();
2078 }
2079 };
2080
2081 let result = transpiler.transpile_expr(&expr).unwrap();
2082 assert!(
2083 result.contains("while (true)"),
2084 "Should generate infinite loop: {}",
2085 result
2086 );
2087 assert!(
2088 result.contains("process()"),
2089 "Should contain loop body: {}",
2090 result
2091 );
2092
2093 println!("Generated infinite loop:\n{}", result);
2094 }
2095
2096 #[test]
2097 fn test_break_transpile() {
2098 let transpiler = CudaTranspiler::new_generic();
2099
2100 let expr: Expr = parse_quote!(break);
2101 let result = transpiler.transpile_expr(&expr).unwrap();
2102 assert_eq!(result, "break");
2103 }
2104
2105 #[test]
2106 fn test_continue_transpile() {
2107 let transpiler = CudaTranspiler::new_generic();
2108
2109 let expr: Expr = parse_quote!(continue);
2110 let result = transpiler.transpile_expr(&expr).unwrap();
2111 assert_eq!(result, "continue");
2112 }
2113
2114 #[test]
2115 fn test_loop_with_break() {
2116 let transpiler = CudaTranspiler::new_generic();
2117
2118 let expr: Expr = parse_quote! {
2119 loop {
2120 if done {
2121 break;
2122 }
2123 }
2124 };
2125
2126 let result = transpiler.transpile_expr(&expr).unwrap();
2127 assert!(
2128 result.contains("while (true)"),
2129 "Should generate infinite loop: {}",
2130 result
2131 );
2132 assert!(result.contains("break"), "Should contain break: {}", result);
2133
2134 println!("Generated loop with break:\n{}", result);
2135 }
2136
2137 #[test]
2138 fn test_nested_loops() {
2139 let transpiler = CudaTranspiler::new_generic();
2140
2141 let expr: Expr = parse_quote! {
2142 for i in 0..m {
2143 for j in 0..n {
2144 matrix[i * n + j] = 0.0;
2145 }
2146 }
2147 };
2148
2149 let result = transpiler.transpile_expr(&expr).unwrap();
2150 assert!(
2151 result.contains("for (int i = 0; i < m; i++)"),
2152 "Should have outer loop: {}",
2153 result
2154 );
2155 assert!(
2156 result.contains("for (int j = 0; j < n; j++)"),
2157 "Should have inner loop: {}",
2158 result
2159 );
2160
2161 println!("Generated nested loops:\n{}", result);
2162 }
2163
2164 #[test]
2165 fn test_stencil_mode_rejects_loops() {
2166 let config = StencilConfig::new("test")
2167 .with_tile_size(16, 16)
2168 .with_halo(1);
2169 let transpiler = CudaTranspiler::new(config);
2170
2171 let expr: Expr = parse_quote! {
2172 for i in 0..n {
2173 data[i] = 0.0;
2174 }
2175 };
2176
2177 let result = transpiler.transpile_expr(&expr);
2178 assert!(result.is_err(), "Stencil mode should reject loops");
2179 }
2180
2181 #[test]
2182 fn test_labeled_break_rejected() {
2183 let transpiler = CudaTranspiler::new_generic();
2184
2185 let break_expr = syn::ExprBreak {
2188 attrs: Vec::new(),
2189 break_token: syn::token::Break::default(),
2190 label: Some(syn::Lifetime::new("'outer", proc_macro2::Span::call_site())),
2191 expr: None,
2192 };
2193
2194 let result = transpiler.transpile_break(&break_expr);
2195 assert!(result.is_err(), "Labeled break should be rejected");
2196 }
2197
2198 #[test]
2199 fn test_full_kernel_with_loop() {
2200 let func: ItemFn = parse_quote! {
2201 fn fill_array(data: &mut [f32], n: i32) {
2202 for i in 0..n {
2203 data[i as usize] = 0.0;
2204 }
2205 }
2206 };
2207
2208 let mut transpiler = CudaTranspiler::new_generic();
2209 let cuda = transpiler.transpile_generic_kernel(&func).unwrap();
2210
2211 assert!(
2212 cuda.contains("extern \"C\" __global__"),
2213 "Should be global kernel: {}",
2214 cuda
2215 );
2216 assert!(
2217 cuda.contains("for (int i = 0; i < n; i++)"),
2218 "Should have for loop: {}",
2219 cuda
2220 );
2221
2222 println!("Generated kernel with loop:\n{}", cuda);
2223 }
2224
2225 #[test]
2226 fn test_persistent_kernel_pattern() {
2227 let transpiler = CudaTranspiler::with_mode(ValidationMode::RingKernel);
2229
2230 let expr: Expr = parse_quote! {
2231 while !should_terminate {
2232 if has_message {
2233 process_message();
2234 }
2235 }
2236 };
2237
2238 let result = transpiler.transpile_expr(&expr).unwrap();
2239 assert!(
2240 result.contains("while (!(should_terminate))"),
2241 "Should have persistent loop: {}",
2242 result
2243 );
2244 assert!(
2245 result.contains("if (has_message)"),
2246 "Should have message check: {}",
2247 result
2248 );
2249
2250 println!("Generated persistent kernel pattern:\n{}", result);
2251 }
2252
2253 #[test]
2256 fn test_shared_tile_declaration() {
2257 use crate::shared::{SharedMemoryConfig, SharedMemoryDecl};
2258
2259 let decl = SharedMemoryDecl::tile("tile", "float", 16, 16);
2260 assert_eq!(decl.to_cuda_decl(), "__shared__ float tile[16][16];");
2261
2262 let mut config = SharedMemoryConfig::new();
2263 config.add_tile("tile", "float", 16, 16);
2264 assert_eq!(config.total_bytes(), 16 * 16 * 4); let decls = config.generate_declarations(" ");
2267 assert!(decls.contains("__shared__ float tile[16][16];"));
2268 }
2269
2270 #[test]
2271 fn test_shared_array_declaration() {
2272 use crate::shared::{SharedMemoryConfig, SharedMemoryDecl};
2273
2274 let decl = SharedMemoryDecl::array("buffer", "float", 256);
2275 assert_eq!(decl.to_cuda_decl(), "__shared__ float buffer[256];");
2276
2277 let mut config = SharedMemoryConfig::new();
2278 config.add_array("buffer", "float", 256);
2279 assert_eq!(config.total_bytes(), 256 * 4); }
2281
2282 #[test]
2283 fn test_shared_memory_access_expressions() {
2284 use crate::shared::SharedMemoryDecl;
2285
2286 let tile = SharedMemoryDecl::tile("tile", "float", 16, 16);
2287 assert_eq!(
2288 tile.to_cuda_access(&["y".to_string(), "x".to_string()]),
2289 "tile[y][x]"
2290 );
2291
2292 let arr = SharedMemoryDecl::array("buf", "int", 128);
2293 assert_eq!(arr.to_cuda_access(&["i".to_string()]), "buf[i]");
2294 }
2295
2296 #[test]
2297 fn test_parse_shared_tile_type() {
2298 use crate::shared::parse_shared_tile_type;
2299
2300 let result = parse_shared_tile_type("SharedTile::<f32, 16, 16>");
2301 assert_eq!(result, Some(("f32".to_string(), 16, 16)));
2302
2303 let result2 = parse_shared_tile_type("SharedTile<i32, 32, 8>");
2304 assert_eq!(result2, Some(("i32".to_string(), 32, 8)));
2305
2306 let invalid = parse_shared_tile_type("Vec<f32>");
2307 assert_eq!(invalid, None);
2308 }
2309
2310 #[test]
2311 fn test_parse_shared_array_type() {
2312 use crate::shared::parse_shared_array_type;
2313
2314 let result = parse_shared_array_type("SharedArray::<f32, 256>");
2315 assert_eq!(result, Some(("f32".to_string(), 256)));
2316
2317 let result2 = parse_shared_array_type("SharedArray<u32, 1024>");
2318 assert_eq!(result2, Some(("u32".to_string(), 1024)));
2319
2320 let invalid = parse_shared_array_type("Vec<f32>");
2321 assert_eq!(invalid, None);
2322 }
2323
2324 #[test]
2325 fn test_rust_to_cuda_element_types() {
2326 use crate::shared::rust_to_cuda_element_type;
2327
2328 assert_eq!(rust_to_cuda_element_type("f32"), "float");
2329 assert_eq!(rust_to_cuda_element_type("f64"), "double");
2330 assert_eq!(rust_to_cuda_element_type("i32"), "int");
2331 assert_eq!(rust_to_cuda_element_type("u32"), "unsigned int");
2332 assert_eq!(rust_to_cuda_element_type("i64"), "long long");
2333 assert_eq!(rust_to_cuda_element_type("u64"), "unsigned long long");
2334 assert_eq!(rust_to_cuda_element_type("bool"), "int");
2335 }
2336
2337 #[test]
2338 fn test_shared_memory_total_bytes() {
2339 use crate::shared::SharedMemoryConfig;
2340
2341 let mut config = SharedMemoryConfig::new();
2342 config.add_tile("tile1", "float", 16, 16); config.add_tile("tile2", "double", 8, 8); config.add_array("temp", "int", 64); assert_eq!(config.total_bytes(), 1024 + 512 + 256);
2347 }
2348
2349 #[test]
2350 fn test_transpiler_shared_var_tracking() {
2351 let mut transpiler = CudaTranspiler::new_generic();
2352
2353 transpiler.shared_vars.insert(
2355 "tile".to_string(),
2356 SharedVarInfo {
2357 name: "tile".to_string(),
2358 is_tile: true,
2359 dimensions: vec![16, 16],
2360 element_type: "float".to_string(),
2361 },
2362 );
2363
2364 assert!(transpiler.shared_vars.contains_key("tile"));
2366 assert!(transpiler.shared_vars.get("tile").unwrap().is_tile);
2367 }
2368
2369 #[test]
2370 fn test_shared_tile_get_transpilation() {
2371 let mut transpiler = CudaTranspiler::new_generic();
2372
2373 transpiler.shared_vars.insert(
2375 "tile".to_string(),
2376 SharedVarInfo {
2377 name: "tile".to_string(),
2378 is_tile: true,
2379 dimensions: vec![16, 16],
2380 element_type: "float".to_string(),
2381 },
2382 );
2383
2384 let result = transpiler.try_transpile_shared_method_call(
2386 "tile",
2387 "get",
2388 &syn::punctuated::Punctuated::new(),
2389 );
2390
2391 assert!(result.is_none() || result.unwrap().is_err());
2393 }
2394
2395 #[test]
2396 fn test_shared_array_access() {
2397 let mut transpiler = CudaTranspiler::new_generic();
2398
2399 transpiler.shared_vars.insert(
2401 "buffer".to_string(),
2402 SharedVarInfo {
2403 name: "buffer".to_string(),
2404 is_tile: false,
2405 dimensions: vec![256],
2406 element_type: "float".to_string(),
2407 },
2408 );
2409
2410 assert!(!transpiler.shared_vars.get("buffer").unwrap().is_tile);
2411 assert_eq!(
2412 transpiler.shared_vars.get("buffer").unwrap().dimensions,
2413 vec![256]
2414 );
2415 }
2416
2417 #[test]
2418 fn test_full_kernel_with_shared_memory() {
2419 use crate::shared::SharedMemoryConfig;
2421
2422 let mut config = SharedMemoryConfig::new();
2423 config.add_tile("smem", "float", 16, 16);
2424
2425 let decls = config.generate_declarations(" ");
2426 assert!(decls.contains("__shared__ float smem[16][16];"));
2427 assert!(!config.is_empty());
2428 }
2429
2430 #[test]
2433 fn test_struct_literal_transpile() {
2434 let transpiler = CudaTranspiler::new_generic();
2435
2436 let expr: Expr = parse_quote! {
2437 Point { x: 1.0, y: 2.0 }
2438 };
2439
2440 let result = transpiler.transpile_expr(&expr).unwrap();
2441 assert!(
2442 result.contains("Point"),
2443 "Should contain struct name: {}",
2444 result
2445 );
2446 assert!(result.contains(".x ="), "Should have field x: {}", result);
2447 assert!(result.contains(".y ="), "Should have field y: {}", result);
2448 assert!(
2449 result.contains("1.0f"),
2450 "Should have value 1.0f: {}",
2451 result
2452 );
2453 assert!(
2454 result.contains("2.0f"),
2455 "Should have value 2.0f: {}",
2456 result
2457 );
2458
2459 println!("Generated struct literal: {}", result);
2460 }
2461
2462 #[test]
2463 fn test_struct_literal_with_expressions() {
2464 let transpiler = CudaTranspiler::new_generic();
2465
2466 let expr: Expr = parse_quote! {
2467 Response { value: x * 2.0, id: idx as u64 }
2468 };
2469
2470 let result = transpiler.transpile_expr(&expr).unwrap();
2471 assert!(
2472 result.contains("Response"),
2473 "Should contain struct name: {}",
2474 result
2475 );
2476 assert!(
2477 result.contains(".value = x * 2.0f"),
2478 "Should have computed value: {}",
2479 result
2480 );
2481 assert!(result.contains(".id ="), "Should have id field: {}", result);
2482
2483 println!("Generated struct with expressions: {}", result);
2484 }
2485
2486 #[test]
2487 fn test_struct_literal_in_return() {
2488 let mut transpiler = CudaTranspiler::new_generic();
2489
2490 let stmt: Stmt = parse_quote! {
2491 return MyStruct { a: 1, b: 2.0 };
2492 };
2493
2494 let result = transpiler.transpile_stmt(&stmt).unwrap();
2495 assert!(result.contains("return"), "Should have return: {}", result);
2496 assert!(
2497 result.contains("MyStruct"),
2498 "Should contain struct name: {}",
2499 result
2500 );
2501
2502 println!("Generated return with struct: {}", result);
2503 }
2504
2505 #[test]
2506 fn test_struct_literal_compound_literal_format() {
2507 let transpiler = CudaTranspiler::new_generic();
2508
2509 let expr: Expr = parse_quote! {
2510 Vec3 { x: a, y: b, z: c }
2511 };
2512
2513 let result = transpiler.transpile_expr(&expr).unwrap();
2514 assert!(
2516 result.starts_with("(Vec3){"),
2517 "Should use compound literal format: {}",
2518 result
2519 );
2520 assert!(
2521 result.ends_with("}"),
2522 "Should end with closing brace: {}",
2523 result
2524 );
2525
2526 println!("Generated compound literal: {}", result);
2527 }
2528
/// A shared borrow of an array element lowers to a C address-of expression,
/// reproduced verbatim.
#[test]
fn test_reference_to_array_element() {
    let tp = CudaTranspiler::new_generic();
    let input: Expr = parse_quote! { &arr[idx] };

    let cuda = tp.transpile_expr(&input).unwrap();
    assert_eq!(cuda, "&arr[idx]", "Should produce address-of array element");
}
2545
/// A mutable borrow of an array element (`&mut arr[...]`) lowers to a plain
/// address-of (`&arr[...]`). The whole index expression must survive, and the
/// Rust-only `mut` qualifier must not leak into the generated code.
#[test]
fn test_mutable_reference_to_array_element() {
    let transpiler = CudaTranspiler::new_generic();

    let expr: Expr = parse_quote! {
        &mut arr[idx * 4 + offset]
    };

    let result = transpiler.transpile_expr(&expr).unwrap();
    assert!(result.contains("&arr["), "Should produce address-of: {}", result);
    assert!(
        result.contains("idx * 4"),
        "Should have index expression: {}",
        result
    );
    // Checking only the leading `idx * 4` term would accept an output that
    // dropped the `+ offset` part of the index, so pin it explicitly.
    assert!(
        result.contains("offset"),
        "Should keep the full index expression: {}",
        result
    );
    // `mut` has no meaning in CUDA C++; emitting it would break compilation.
    assert!(
        !result.contains("mut"),
        "Should not emit the `mut` qualifier: {}",
        result
    );
}
2566
/// Borrowing a plain local lowers to taking its address, unchanged.
#[test]
fn test_reference_to_variable() {
    let tp = CudaTranspiler::new_generic();
    let input: Expr = parse_quote! { &value };

    let cuda = tp.transpile_expr(&input).unwrap();
    assert_eq!(cuda, "&value", "Should produce address-of variable");
}
2578
/// Address-of an element of `alerts` whose index mixes `as usize` casts with
/// arithmetic.
///
/// NOTE(review): despite the name, no struct field access occurs here. The
/// expression borrows a whole array element; presumably `alerts` holds structs
/// in the kernels this mirrors (TODO confirm). The name is kept so existing
/// test filters still match.
#[test]
fn test_reference_to_struct_field() {
    let transpiler = CudaTranspiler::new_generic();

    let expr: Expr = parse_quote! {
        &alerts[(idx as usize) * 4 + alert_idx as usize]
    };

    let result = transpiler.transpile_expr(&expr).unwrap();
    assert!(
        result.starts_with("&alerts["),
        "Should have address-of array: {}",
        result
    );
    // A prefix check alone would accept a truncated or garbled index, so pin
    // both operands of the sum and the closing bracket as well.
    assert!(
        result.contains("* 4"),
        "Should keep the scaled index term: {}",
        result
    );
    assert!(
        result.contains("alert_idx"),
        "Should keep the alert_idx offset term: {}",
        result
    );
    assert!(
        result.ends_with(']'),
        "Should end with the index's closing bracket: {}",
        result
    );

    println!("Generated reference: {}", result);
}
2596
/// A `let` binding initialised with `&mut array[...]` (the pattern used when a
/// kernel grabs one slot of an output buffer) keeps both the binding name and
/// the address-of expression.
#[test]
fn test_complex_reference_pattern() {
    // `transpile_stmt` is called through a mutable binding.
    let mut tp = CudaTranspiler::new_generic();
    let input: Stmt = parse_quote! {
        let alert = &mut alerts[(idx as usize) * 4 + alert_idx as usize];
    };
    let cuda = tp.transpile_stmt(&input).unwrap();

    let binds_alert = cuda.contains("alert =");
    let takes_address = cuda.contains("&alerts[");
    assert!(binds_alert, "Should have variable assignment: {}", cuda);
    assert!(takes_address, "Should have reference to array: {}", cuda);

    println!("Generated statement: {}", cuda);
}
2620}