cuda_rust_wasm/transpiler/code_generator.rs

1//! Rust code generation from CUDA AST
2
3use quote::{quote, format_ident};
4use proc_macro2::TokenStream;
5use crate::{Result, translation_error};
6use crate::parser::ast::*;
7
8/// Code generator for converting CUDA AST to Rust
9pub struct CodeGenerator {
10    /// Generated Rust code
11    code: TokenStream,
12}
13
14impl Default for CodeGenerator {
15    fn default() -> Self {
16        Self::new()
17    }
18}
19
20impl CodeGenerator {
21    /// Create a new code generator
22    pub fn new() -> Self {
23        Self {
24            code: TokenStream::new(),
25        }
26    }
27    
28    /// Generate Rust code from AST
29    pub fn generate(&mut self, ast: Ast) -> Result<String> {
30        // Generate module imports
31        let imports = self.generate_imports();
32        
33        // Generate code for each item
34        let items: Vec<TokenStream> = ast.items.into_iter()
35            .map(|item| self.generate_item(item))
36            .collect::<Result<Vec<_>>>()?;
37        
38        let code = quote! {
39            #imports
40            
41            #(#items)*
42        };
43        
44        Ok(code.to_string())
45    }
46    
47    /// Generate standard imports
48    fn generate_imports(&self) -> TokenStream {
49        quote! {
50            use cuda_rust_wasm::runtime::{Grid, Block, thread, block, grid};
51            use cuda_rust_wasm::memory::{DeviceBuffer, SharedMemory};
52            use cuda_rust_wasm::kernel::launch_kernel;
53        }
54    }
55    
56    /// Generate code for a single AST item
57    fn generate_item(&self, item: Item) -> Result<TokenStream> {
58        match item {
59            Item::Kernel(kernel) => self.generate_kernel(kernel),
60            Item::DeviceFunction(func) => self.generate_device_function(func),
61            Item::HostFunction(func) => self.generate_host_function(func),
62            Item::GlobalVar(var) => self.generate_global_var(var),
63            Item::TypeDef(typedef) => self.generate_typedef(typedef),
64            Item::Include(_) => Ok(TokenStream::new()), // Includes handled separately
65        }
66    }
67    
68    /// Generate code for a kernel function
69    fn generate_kernel(&self, kernel: KernelDef) -> Result<TokenStream> {
70        let name = format_ident!("{}", kernel.name);
71        let params = self.generate_parameters(&kernel.params)?;
72        let body = self.generate_block(&kernel.body)?;
73        
74        Ok(quote! {
75            #[kernel]
76            pub fn #name(#params) {
77                #body
78            }
79        })
80    }
81    
82    /// Generate code for a device function
83    fn generate_device_function(&self, func: FunctionDef) -> Result<TokenStream> {
84        let name = format_ident!("{}", func.name);
85        let params = self.generate_parameters(&func.params)?;
86        let return_type = self.generate_type(&func.return_type)?;
87        let body = self.generate_block(&func.body)?;
88        
89        Ok(quote! {
90            #[device_function]
91            pub fn #name(#params) -> #return_type {
92                #body
93            }
94        })
95    }
96    
97    /// Generate code for a host function
98    fn generate_host_function(&self, func: FunctionDef) -> Result<TokenStream> {
99        let name = format_ident!("{}", func.name);
100        let params = self.generate_parameters(&func.params)?;
101        let return_type = self.generate_type(&func.return_type)?;
102        let body = self.generate_block(&func.body)?;
103        
104        Ok(quote! {
105            pub fn #name(#params) -> #return_type {
106                #body
107            }
108        })
109    }
110    
111    /// Generate function parameters
112    fn generate_parameters(&self, params: &[Parameter]) -> Result<TokenStream> {
113        let params: Vec<TokenStream> = params.iter()
114            .map(|p| {
115                let name = format_ident!("{}", p.name);
116                let ty = self.generate_type(&p.ty)?;
117                Ok(quote! { #name: #ty })
118            })
119            .collect::<Result<Vec<_>>>()?;
120        
121        Ok(quote! { #(#params),* })
122    }
123    
124    /// Generate Rust type from CUDA type
125    fn generate_type(&self, ty: &Type) -> Result<TokenStream> {
126        match ty {
127            Type::Void => Ok(quote! { () }),
128            Type::Bool => Ok(quote! { bool }),
129            Type::Int(int_ty) => Ok(match int_ty {
130                IntType::I8 => quote! { i8 },
131                IntType::I16 => quote! { i16 },
132                IntType::I32 => quote! { i32 },
133                IntType::I64 => quote! { i64 },
134                IntType::U8 => quote! { u8 },
135                IntType::U16 => quote! { u16 },
136                IntType::U32 => quote! { u32 },
137                IntType::U64 => quote! { u64 },
138            }),
139            Type::Float(float_ty) => Ok(match float_ty {
140                FloatType::F16 => quote! { f16 },
141                FloatType::F32 => quote! { f32 },
142                FloatType::F64 => quote! { f64 },
143            }),
144            Type::Pointer(inner) => {
145                let inner_ty = self.generate_type(inner)?;
146                Ok(quote! { &mut #inner_ty })
147            },
148            Type::Array(inner, size) => {
149                let inner_ty = self.generate_type(inner)?;
150                match size {
151                    Some(n) => Ok(quote! { [#inner_ty; #n] }),
152                    None => Ok(quote! { &[#inner_ty] }),
153                }
154            },
155            Type::Vector(vec_ty) => {
156                let elem_ty = self.generate_type(&vec_ty.element)?;
157                let size = vec_ty.size as usize;
158                Ok(quote! { [#elem_ty; #size] })
159            },
160            Type::Named(name) => {
161                let name = format_ident!("{}", name);
162                Ok(quote! { #name })
163            },
164            Type::Texture(_) => Err(translation_error!("Texture types not yet supported")),
165        }
166    }
167    
168    /// Generate code for a block of statements
169    fn generate_block(&self, block: &Block) -> Result<TokenStream> {
170        let statements: Vec<TokenStream> = block.statements.iter()
171            .map(|stmt| self.generate_statement(stmt))
172            .collect::<Result<Vec<_>>>()?;
173        
174        Ok(quote! {
175            #(#statements)*
176        })
177    }
178    
179    /// Generate code for a statement
180    fn generate_statement(&self, stmt: &Statement) -> Result<TokenStream> {
181        match stmt {
182            Statement::VarDecl { name, ty, init, storage } => {
183                let name = format_ident!("{}", name);
184                let ty = self.generate_type(ty)?;
185                let storage_attr = self.generate_storage_class(storage)?;
186                
187                match init {
188                    Some(init_expr) => {
189                        let expr = self.generate_expression(init_expr)?;
190                        Ok(quote! {
191                            #storage_attr
192                            let #name: #ty = #expr;
193                        })
194                    },
195                    None => Ok(quote! {
196                        #storage_attr
197                        let #name: #ty;
198                    }),
199                }
200            },
201            Statement::Expr(expr) => {
202                let expr = self.generate_expression(expr)?;
203                Ok(quote! { #expr; })
204            },
205            Statement::Block(block) => {
206                let block = self.generate_block(block)?;
207                Ok(quote! { { #block } })
208            },
209            Statement::If { condition, then_branch, else_branch } => {
210                let cond = self.generate_expression(condition)?;
211                let then_stmt = self.generate_statement(then_branch)?;
212                
213                match else_branch {
214                    Some(else_stmt) => {
215                        let else_stmt = self.generate_statement(else_stmt)?;
216                        Ok(quote! {
217                            if #cond {
218                                #then_stmt
219                            } else {
220                                #else_stmt
221                            }
222                        })
223                    },
224                    None => Ok(quote! {
225                        if #cond {
226                            #then_stmt
227                        }
228                    }),
229                }
230            },
231            Statement::For { init, condition, update, body } => {
232                // Generate init as variable declaration or expression
233                let init_stmt = match init {
234                    Some(init) => match init.as_ref() {
235                        Statement::VarDecl { name, ty, init, .. } => {
236                            let name = format_ident!("{}", name);
237                            let ty = self.generate_type(ty)?;
238                            match init {
239                                Some(init_expr) => {
240                                    let expr = self.generate_expression(init_expr)?;
241                                    quote! { let mut #name: #ty = #expr; }
242                                },
243                                None => quote! { let mut #name: #ty; },
244                            }
245                        },
246                        Statement::Expr(expr) => {
247                            let expr = self.generate_expression(expr)?;
248                            quote! { #expr; }
249                        },
250                        _ => return Err(translation_error!("Invalid init statement in for loop")),
251                    },
252                    None => TokenStream::new(),
253                };
254                
255                // Generate condition
256                let cond = match condition {
257                    Some(c) => {
258                        let cond_expr = self.generate_expression(c)?;
259                        quote! { #cond_expr }
260                    },
261                    None => quote! { true },
262                };
263                
264                // Generate update
265                let update_stmt = match update {
266                    Some(u) => {
267                        let update_expr = self.generate_expression(u)?;
268                        quote! { #update_expr; }
269                    },
270                    None => TokenStream::new(),
271                };
272                
273                // Generate body
274                let body_stmt = self.generate_statement(body)?;
275                
276                // Construct the for loop as a while loop with init/update
277                Ok(quote! {
278                    {
279                        #init_stmt
280                        while #cond {
281                            #body_stmt
282                            #update_stmt
283                        }
284                    }
285                })
286            },
287            Statement::While { condition, body } => {
288                let cond = self.generate_expression(condition)?;
289                let body_stmt = self.generate_statement(body)?;
290                Ok(quote! {
291                    while #cond {
292                        #body_stmt
293                    }
294                })
295            },
296            Statement::Return(expr) => {
297                match expr {
298                    Some(e) => {
299                        let expr = self.generate_expression(e)?;
300                        Ok(quote! { return #expr; })
301                    },
302                    None => Ok(quote! { return; }),
303                }
304            },
305            Statement::Break => Ok(quote! { break; }),
306            Statement::Continue => Ok(quote! { continue; }),
307            Statement::SyncThreads => Ok(quote! { cuda_rust_wasm::runtime::sync_threads(); }),
308        }
309    }
310    
311    /// Generate storage class attributes
312    fn generate_storage_class(&self, storage: &StorageClass) -> Result<TokenStream> {
313        match storage {
314            StorageClass::Shared => Ok(quote! { #[shared] }),
315            StorageClass::Constant => Ok(quote! { #[constant] }),
316            _ => Ok(TokenStream::new()),
317        }
318    }
319    
320    /// Generate code for an expression
321    fn generate_expression(&self, expr: &Expression) -> Result<TokenStream> {
322        match expr {
323            Expression::Literal(lit) => self.generate_literal(lit),
324            Expression::Var(name) => {
325                let name = format_ident!("{}", name);
326                Ok(quote! { #name })
327            },
328            Expression::Binary { op, left, right } => {
329                let left = self.generate_expression(left)?;
330                let right = self.generate_expression(right)?;
331                let op = self.generate_binary_op(op)?;
332                Ok(quote! { (#left #op #right) })
333            },
334            Expression::Unary { op, expr } => {
335                let expr = self.generate_expression(expr)?;
336                let op = self.generate_unary_op(op)?;
337                Ok(quote! { (#op #expr) })
338            },
339            Expression::Call { name, args } => {
340                let name = format_ident!("{}", name);
341                let args: Vec<TokenStream> = args.iter()
342                    .map(|arg| self.generate_expression(arg))
343                    .collect::<Result<Vec<_>>>()?;
344                Ok(quote! { #name(#(#args),*) })
345            },
346            Expression::Index { array, index } => {
347                let array = self.generate_expression(array)?;
348                let index = self.generate_expression(index)?;
349                Ok(quote! { #array[#index] })
350            },
351            Expression::Member { object, field } => {
352                let object = self.generate_expression(object)?;
353                let field = format_ident!("{}", field);
354                Ok(quote! { #object.#field })
355            },
356            Expression::Cast { ty, expr } => {
357                let ty = self.generate_type(ty)?;
358                let expr = self.generate_expression(expr)?;
359                Ok(quote! { #expr as #ty })
360            },
361            Expression::ThreadIdx(dim) => {
362                let dim = self.generate_dimension(dim)?;
363                Ok(quote! { thread::index().#dim })
364            },
365            Expression::BlockIdx(dim) => {
366                let dim = self.generate_dimension(dim)?;
367                Ok(quote! { block::index().#dim })
368            },
369            Expression::BlockDim(dim) => {
370                let dim = self.generate_dimension(dim)?;
371                Ok(quote! { block::dim().#dim })
372            },
373            Expression::GridDim(dim) => {
374                let dim = self.generate_dimension(dim)?;
375                Ok(quote! { grid::dim().#dim })
376            },
377            Expression::WarpPrimitive { op, args } => {
378                // Generate warp primitive operations
379                match op {
380                    WarpOp::Shuffle => {
381                        if args.len() != 2 {
382                            return Err(translation_error!("Warp shuffle requires 2 arguments"));
383                        }
384                        let value = self.generate_expression(&args[0])?;
385                        let lane = self.generate_expression(&args[1])?;
386                        Ok(quote! { cuda_rust_wasm::runtime::warp_shuffle(#value, #lane) })
387                    },
388                    WarpOp::ShuffleXor => {
389                        if args.len() != 2 {
390                            return Err(translation_error!("Warp shuffle_xor requires 2 arguments"));
391                        }
392                        let value = self.generate_expression(&args[0])?;
393                        let mask = self.generate_expression(&args[1])?;
394                        Ok(quote! { cuda_rust_wasm::runtime::warp_shuffle_xor(#value, #mask) })
395                    },
396                    WarpOp::ShuffleUp => {
397                        if args.len() != 2 {
398                            return Err(translation_error!("Warp shuffle_up requires 2 arguments"));
399                        }
400                        let value = self.generate_expression(&args[0])?;
401                        let delta = self.generate_expression(&args[1])?;
402                        Ok(quote! { cuda_rust_wasm::runtime::warp_shuffle_up(#value, #delta) })
403                    },
404                    WarpOp::ShuffleDown => {
405                        if args.len() != 2 {
406                            return Err(translation_error!("Warp shuffle_down requires 2 arguments"));
407                        }
408                        let value = self.generate_expression(&args[0])?;
409                        let delta = self.generate_expression(&args[1])?;
410                        Ok(quote! { cuda_rust_wasm::runtime::warp_shuffle_down(#value, #delta) })
411                    },
412                    WarpOp::Vote => {
413                        if args.len() != 1 {
414                            return Err(translation_error!("Warp vote requires 1 argument"));
415                        }
416                        let predicate = self.generate_expression(&args[0])?;
417                        Ok(quote! { cuda_rust_wasm::runtime::warp_vote_all(#predicate) })
418                    },
419                    WarpOp::Ballot => {
420                        if args.len() != 1 {
421                            return Err(translation_error!("Warp ballot requires 1 argument"));
422                        }
423                        let predicate = self.generate_expression(&args[0])?;
424                        Ok(quote! { cuda_rust_wasm::runtime::warp_ballot(#predicate) })
425                    },
426                    WarpOp::ActiveMask => {
427                        if !args.is_empty() {
428                            return Err(translation_error!("Warp activemask takes no arguments"));
429                        }
430                        Ok(quote! { cuda_rust_wasm::runtime::warp_activemask() })
431                    },
432                }
433            },
434        }
435    }
436    
437    /// Generate literal values
438    fn generate_literal(&self, lit: &Literal) -> Result<TokenStream> {
439        match lit {
440            Literal::Bool(b) => Ok(quote! { #b }),
441            Literal::Int(i) => Ok(quote! { #i }),
442            Literal::UInt(u) => Ok(quote! { #u }),
443            Literal::Float(f) => Ok(quote! { #f }),
444            Literal::String(s) => Ok(quote! { #s }),
445        }
446    }
447    
448    /// Generate binary operator
449    fn generate_binary_op(&self, op: &BinaryOp) -> Result<TokenStream> {
450        Ok(match op {
451            BinaryOp::Add => quote! { + },
452            BinaryOp::Sub => quote! { - },
453            BinaryOp::Mul => quote! { * },
454            BinaryOp::Div => quote! { / },
455            BinaryOp::Mod => quote! { % },
456            BinaryOp::And => quote! { & },
457            BinaryOp::Or => quote! { | },
458            BinaryOp::Xor => quote! { ^ },
459            BinaryOp::Shl => quote! { << },
460            BinaryOp::Shr => quote! { >> },
461            BinaryOp::Eq => quote! { == },
462            BinaryOp::Ne => quote! { != },
463            BinaryOp::Lt => quote! { < },
464            BinaryOp::Le => quote! { <= },
465            BinaryOp::Gt => quote! { > },
466            BinaryOp::Ge => quote! { >= },
467            BinaryOp::LogicalAnd => quote! { && },
468            BinaryOp::LogicalOr => quote! { || },
469            BinaryOp::Assign => quote! { = },
470        })
471    }
472    
473    /// Generate unary operator
474    fn generate_unary_op(&self, op: &UnaryOp) -> Result<TokenStream> {
475        Ok(match op {
476            UnaryOp::Not => quote! { ! },
477            UnaryOp::Neg => quote! { - },
478            UnaryOp::BitNot => quote! { ! },
479            UnaryOp::PreInc => quote! { ++ },
480            UnaryOp::PreDec => quote! { -- },
481            UnaryOp::PostInc => return Err(translation_error!("Post-increment not supported")),
482            UnaryOp::PostDec => return Err(translation_error!("Post-decrement not supported")),
483            UnaryOp::Deref => quote! { * },
484            UnaryOp::AddrOf => quote! { & },
485        })
486    }
487    
488    /// Generate dimension accessor
489    fn generate_dimension(&self, dim: &Dimension) -> Result<TokenStream> {
490        Ok(match dim {
491            Dimension::X => quote! { x },
492            Dimension::Y => quote! { y },
493            Dimension::Z => quote! { z },
494        })
495    }
496    
497    /// Generate global variable
498    fn generate_global_var(&self, var: GlobalVar) -> Result<TokenStream> {
499        let name = format_ident!("{}", var.name);
500        let ty = self.generate_type(&var.ty)?;
501        let storage_attr = self.generate_storage_class(&var.storage)?;
502        
503        match var.init {
504            Some(init) => {
505                let init_expr = self.generate_expression(&init)?;
506                Ok(quote! {
507                    #storage_attr
508                    static #name: #ty = #init_expr;
509                })
510            },
511            None => Ok(quote! {
512                #storage_attr
513                static #name: #ty;
514            }),
515        }
516    }
517    
518    /// Generate type definition
519    fn generate_typedef(&self, typedef: TypeDef) -> Result<TokenStream> {
520        let name = format_ident!("{}", typedef.name);
521        let ty = self.generate_type(&typedef.ty)?;
522        Ok(quote! {
523            type #name = #ty;
524        })
525    }
526}