1use quote::{quote, format_ident};
4use proc_macro2::TokenStream;
5use crate::{Result, translation_error};
6use crate::parser::ast::*;
7
8pub struct CodeGenerator {
10 code: TokenStream,
12}
13
14impl Default for CodeGenerator {
15 fn default() -> Self {
16 Self::new()
17 }
18}
19
20impl CodeGenerator {
21 pub fn new() -> Self {
23 Self {
24 code: TokenStream::new(),
25 }
26 }
27
28 pub fn generate(&mut self, ast: Ast) -> Result<String> {
30 let imports = self.generate_imports();
32
33 let items: Vec<TokenStream> = ast.items.into_iter()
35 .map(|item| self.generate_item(item))
36 .collect::<Result<Vec<_>>>()?;
37
38 let code = quote! {
39 #imports
40
41 #(#items)*
42 };
43
44 Ok(code.to_string())
45 }
46
47 fn generate_imports(&self) -> TokenStream {
49 quote! {
50 use cuda_rust_wasm::runtime::{Grid, Block, thread, block, grid};
51 use cuda_rust_wasm::memory::{DeviceBuffer, SharedMemory};
52 use cuda_rust_wasm::kernel::launch_kernel;
53 }
54 }
55
56 fn generate_item(&self, item: Item) -> Result<TokenStream> {
58 match item {
59 Item::Kernel(kernel) => self.generate_kernel(kernel),
60 Item::DeviceFunction(func) => self.generate_device_function(func),
61 Item::HostFunction(func) => self.generate_host_function(func),
62 Item::GlobalVar(var) => self.generate_global_var(var),
63 Item::TypeDef(typedef) => self.generate_typedef(typedef),
64 Item::Include(_) => Ok(TokenStream::new()), }
66 }
67
68 fn generate_kernel(&self, kernel: KernelDef) -> Result<TokenStream> {
70 let name = format_ident!("{}", kernel.name);
71 let params = self.generate_parameters(&kernel.params)?;
72 let body = self.generate_block(&kernel.body)?;
73
74 Ok(quote! {
75 #[kernel]
76 pub fn #name(#params) {
77 #body
78 }
79 })
80 }
81
82 fn generate_device_function(&self, func: FunctionDef) -> Result<TokenStream> {
84 let name = format_ident!("{}", func.name);
85 let params = self.generate_parameters(&func.params)?;
86 let return_type = self.generate_type(&func.return_type)?;
87 let body = self.generate_block(&func.body)?;
88
89 Ok(quote! {
90 #[device_function]
91 pub fn #name(#params) -> #return_type {
92 #body
93 }
94 })
95 }
96
97 fn generate_host_function(&self, func: FunctionDef) -> Result<TokenStream> {
99 let name = format_ident!("{}", func.name);
100 let params = self.generate_parameters(&func.params)?;
101 let return_type = self.generate_type(&func.return_type)?;
102 let body = self.generate_block(&func.body)?;
103
104 Ok(quote! {
105 pub fn #name(#params) -> #return_type {
106 #body
107 }
108 })
109 }
110
111 fn generate_parameters(&self, params: &[Parameter]) -> Result<TokenStream> {
113 let params: Vec<TokenStream> = params.iter()
114 .map(|p| {
115 let name = format_ident!("{}", p.name);
116 let ty = self.generate_type(&p.ty)?;
117 Ok(quote! { #name: #ty })
118 })
119 .collect::<Result<Vec<_>>>()?;
120
121 Ok(quote! { #(#params),* })
122 }
123
124 fn generate_type(&self, ty: &Type) -> Result<TokenStream> {
126 match ty {
127 Type::Void => Ok(quote! { () }),
128 Type::Bool => Ok(quote! { bool }),
129 Type::Int(int_ty) => Ok(match int_ty {
130 IntType::I8 => quote! { i8 },
131 IntType::I16 => quote! { i16 },
132 IntType::I32 => quote! { i32 },
133 IntType::I64 => quote! { i64 },
134 IntType::U8 => quote! { u8 },
135 IntType::U16 => quote! { u16 },
136 IntType::U32 => quote! { u32 },
137 IntType::U64 => quote! { u64 },
138 }),
139 Type::Float(float_ty) => Ok(match float_ty {
140 FloatType::F16 => quote! { f16 },
141 FloatType::F32 => quote! { f32 },
142 FloatType::F64 => quote! { f64 },
143 }),
144 Type::Pointer(inner) => {
145 let inner_ty = self.generate_type(inner)?;
146 Ok(quote! { &mut #inner_ty })
147 },
148 Type::Array(inner, size) => {
149 let inner_ty = self.generate_type(inner)?;
150 match size {
151 Some(n) => Ok(quote! { [#inner_ty; #n] }),
152 None => Ok(quote! { &[#inner_ty] }),
153 }
154 },
155 Type::Vector(vec_ty) => {
156 let elem_ty = self.generate_type(&vec_ty.element)?;
157 let size = vec_ty.size as usize;
158 Ok(quote! { [#elem_ty; #size] })
159 },
160 Type::Named(name) => {
161 let name = format_ident!("{}", name);
162 Ok(quote! { #name })
163 },
164 Type::Texture(_) => Err(translation_error!("Texture types not yet supported")),
165 }
166 }
167
168 fn generate_block(&self, block: &Block) -> Result<TokenStream> {
170 let statements: Vec<TokenStream> = block.statements.iter()
171 .map(|stmt| self.generate_statement(stmt))
172 .collect::<Result<Vec<_>>>()?;
173
174 Ok(quote! {
175 #(#statements)*
176 })
177 }
178
179 fn generate_statement(&self, stmt: &Statement) -> Result<TokenStream> {
181 match stmt {
182 Statement::VarDecl { name, ty, init, storage } => {
183 let name = format_ident!("{}", name);
184 let ty = self.generate_type(ty)?;
185 let storage_attr = self.generate_storage_class(storage)?;
186
187 match init {
188 Some(init_expr) => {
189 let expr = self.generate_expression(init_expr)?;
190 Ok(quote! {
191 #storage_attr
192 let #name: #ty = #expr;
193 })
194 },
195 None => Ok(quote! {
196 #storage_attr
197 let #name: #ty;
198 }),
199 }
200 },
201 Statement::Expr(expr) => {
202 let expr = self.generate_expression(expr)?;
203 Ok(quote! { #expr; })
204 },
205 Statement::Block(block) => {
206 let block = self.generate_block(block)?;
207 Ok(quote! { { #block } })
208 },
209 Statement::If { condition, then_branch, else_branch } => {
210 let cond = self.generate_expression(condition)?;
211 let then_stmt = self.generate_statement(then_branch)?;
212
213 match else_branch {
214 Some(else_stmt) => {
215 let else_stmt = self.generate_statement(else_stmt)?;
216 Ok(quote! {
217 if #cond {
218 #then_stmt
219 } else {
220 #else_stmt
221 }
222 })
223 },
224 None => Ok(quote! {
225 if #cond {
226 #then_stmt
227 }
228 }),
229 }
230 },
231 Statement::For { init, condition, update, body } => {
232 let init_stmt = match init {
234 Some(init) => match init.as_ref() {
235 Statement::VarDecl { name, ty, init, .. } => {
236 let name = format_ident!("{}", name);
237 let ty = self.generate_type(ty)?;
238 match init {
239 Some(init_expr) => {
240 let expr = self.generate_expression(init_expr)?;
241 quote! { let mut #name: #ty = #expr; }
242 },
243 None => quote! { let mut #name: #ty; },
244 }
245 },
246 Statement::Expr(expr) => {
247 let expr = self.generate_expression(expr)?;
248 quote! { #expr; }
249 },
250 _ => return Err(translation_error!("Invalid init statement in for loop")),
251 },
252 None => TokenStream::new(),
253 };
254
255 let cond = match condition {
257 Some(c) => {
258 let cond_expr = self.generate_expression(c)?;
259 quote! { #cond_expr }
260 },
261 None => quote! { true },
262 };
263
264 let update_stmt = match update {
266 Some(u) => {
267 let update_expr = self.generate_expression(u)?;
268 quote! { #update_expr; }
269 },
270 None => TokenStream::new(),
271 };
272
273 let body_stmt = self.generate_statement(body)?;
275
276 Ok(quote! {
278 {
279 #init_stmt
280 while #cond {
281 #body_stmt
282 #update_stmt
283 }
284 }
285 })
286 },
287 Statement::While { condition, body } => {
288 let cond = self.generate_expression(condition)?;
289 let body_stmt = self.generate_statement(body)?;
290 Ok(quote! {
291 while #cond {
292 #body_stmt
293 }
294 })
295 },
296 Statement::Return(expr) => {
297 match expr {
298 Some(e) => {
299 let expr = self.generate_expression(e)?;
300 Ok(quote! { return #expr; })
301 },
302 None => Ok(quote! { return; }),
303 }
304 },
305 Statement::Break => Ok(quote! { break; }),
306 Statement::Continue => Ok(quote! { continue; }),
307 Statement::SyncThreads => Ok(quote! { cuda_rust_wasm::runtime::sync_threads(); }),
308 }
309 }
310
311 fn generate_storage_class(&self, storage: &StorageClass) -> Result<TokenStream> {
313 match storage {
314 StorageClass::Shared => Ok(quote! { #[shared] }),
315 StorageClass::Constant => Ok(quote! { #[constant] }),
316 _ => Ok(TokenStream::new()),
317 }
318 }
319
320 fn generate_expression(&self, expr: &Expression) -> Result<TokenStream> {
322 match expr {
323 Expression::Literal(lit) => self.generate_literal(lit),
324 Expression::Var(name) => {
325 let name = format_ident!("{}", name);
326 Ok(quote! { #name })
327 },
328 Expression::Binary { op, left, right } => {
329 let left = self.generate_expression(left)?;
330 let right = self.generate_expression(right)?;
331 let op = self.generate_binary_op(op)?;
332 Ok(quote! { (#left #op #right) })
333 },
334 Expression::Unary { op, expr } => {
335 let expr = self.generate_expression(expr)?;
336 let op = self.generate_unary_op(op)?;
337 Ok(quote! { (#op #expr) })
338 },
339 Expression::Call { name, args } => {
340 let name = format_ident!("{}", name);
341 let args: Vec<TokenStream> = args.iter()
342 .map(|arg| self.generate_expression(arg))
343 .collect::<Result<Vec<_>>>()?;
344 Ok(quote! { #name(#(#args),*) })
345 },
346 Expression::Index { array, index } => {
347 let array = self.generate_expression(array)?;
348 let index = self.generate_expression(index)?;
349 Ok(quote! { #array[#index] })
350 },
351 Expression::Member { object, field } => {
352 let object = self.generate_expression(object)?;
353 let field = format_ident!("{}", field);
354 Ok(quote! { #object.#field })
355 },
356 Expression::Cast { ty, expr } => {
357 let ty = self.generate_type(ty)?;
358 let expr = self.generate_expression(expr)?;
359 Ok(quote! { #expr as #ty })
360 },
361 Expression::ThreadIdx(dim) => {
362 let dim = self.generate_dimension(dim)?;
363 Ok(quote! { thread::index().#dim })
364 },
365 Expression::BlockIdx(dim) => {
366 let dim = self.generate_dimension(dim)?;
367 Ok(quote! { block::index().#dim })
368 },
369 Expression::BlockDim(dim) => {
370 let dim = self.generate_dimension(dim)?;
371 Ok(quote! { block::dim().#dim })
372 },
373 Expression::GridDim(dim) => {
374 let dim = self.generate_dimension(dim)?;
375 Ok(quote! { grid::dim().#dim })
376 },
377 Expression::WarpPrimitive { op, args } => {
378 match op {
380 WarpOp::Shuffle => {
381 if args.len() != 2 {
382 return Err(translation_error!("Warp shuffle requires 2 arguments"));
383 }
384 let value = self.generate_expression(&args[0])?;
385 let lane = self.generate_expression(&args[1])?;
386 Ok(quote! { cuda_rust_wasm::runtime::warp_shuffle(#value, #lane) })
387 },
388 WarpOp::ShuffleXor => {
389 if args.len() != 2 {
390 return Err(translation_error!("Warp shuffle_xor requires 2 arguments"));
391 }
392 let value = self.generate_expression(&args[0])?;
393 let mask = self.generate_expression(&args[1])?;
394 Ok(quote! { cuda_rust_wasm::runtime::warp_shuffle_xor(#value, #mask) })
395 },
396 WarpOp::ShuffleUp => {
397 if args.len() != 2 {
398 return Err(translation_error!("Warp shuffle_up requires 2 arguments"));
399 }
400 let value = self.generate_expression(&args[0])?;
401 let delta = self.generate_expression(&args[1])?;
402 Ok(quote! { cuda_rust_wasm::runtime::warp_shuffle_up(#value, #delta) })
403 },
404 WarpOp::ShuffleDown => {
405 if args.len() != 2 {
406 return Err(translation_error!("Warp shuffle_down requires 2 arguments"));
407 }
408 let value = self.generate_expression(&args[0])?;
409 let delta = self.generate_expression(&args[1])?;
410 Ok(quote! { cuda_rust_wasm::runtime::warp_shuffle_down(#value, #delta) })
411 },
412 WarpOp::Vote => {
413 if args.len() != 1 {
414 return Err(translation_error!("Warp vote requires 1 argument"));
415 }
416 let predicate = self.generate_expression(&args[0])?;
417 Ok(quote! { cuda_rust_wasm::runtime::warp_vote_all(#predicate) })
418 },
419 WarpOp::Ballot => {
420 if args.len() != 1 {
421 return Err(translation_error!("Warp ballot requires 1 argument"));
422 }
423 let predicate = self.generate_expression(&args[0])?;
424 Ok(quote! { cuda_rust_wasm::runtime::warp_ballot(#predicate) })
425 },
426 WarpOp::ActiveMask => {
427 if !args.is_empty() {
428 return Err(translation_error!("Warp activemask takes no arguments"));
429 }
430 Ok(quote! { cuda_rust_wasm::runtime::warp_activemask() })
431 },
432 }
433 },
434 }
435 }
436
437 fn generate_literal(&self, lit: &Literal) -> Result<TokenStream> {
439 match lit {
440 Literal::Bool(b) => Ok(quote! { #b }),
441 Literal::Int(i) => Ok(quote! { #i }),
442 Literal::UInt(u) => Ok(quote! { #u }),
443 Literal::Float(f) => Ok(quote! { #f }),
444 Literal::String(s) => Ok(quote! { #s }),
445 }
446 }
447
448 fn generate_binary_op(&self, op: &BinaryOp) -> Result<TokenStream> {
450 Ok(match op {
451 BinaryOp::Add => quote! { + },
452 BinaryOp::Sub => quote! { - },
453 BinaryOp::Mul => quote! { * },
454 BinaryOp::Div => quote! { / },
455 BinaryOp::Mod => quote! { % },
456 BinaryOp::And => quote! { & },
457 BinaryOp::Or => quote! { | },
458 BinaryOp::Xor => quote! { ^ },
459 BinaryOp::Shl => quote! { << },
460 BinaryOp::Shr => quote! { >> },
461 BinaryOp::Eq => quote! { == },
462 BinaryOp::Ne => quote! { != },
463 BinaryOp::Lt => quote! { < },
464 BinaryOp::Le => quote! { <= },
465 BinaryOp::Gt => quote! { > },
466 BinaryOp::Ge => quote! { >= },
467 BinaryOp::LogicalAnd => quote! { && },
468 BinaryOp::LogicalOr => quote! { || },
469 BinaryOp::Assign => quote! { = },
470 })
471 }
472
473 fn generate_unary_op(&self, op: &UnaryOp) -> Result<TokenStream> {
475 Ok(match op {
476 UnaryOp::Not => quote! { ! },
477 UnaryOp::Neg => quote! { - },
478 UnaryOp::BitNot => quote! { ! },
479 UnaryOp::PreInc => quote! { ++ },
480 UnaryOp::PreDec => quote! { -- },
481 UnaryOp::PostInc => return Err(translation_error!("Post-increment not supported")),
482 UnaryOp::PostDec => return Err(translation_error!("Post-decrement not supported")),
483 UnaryOp::Deref => quote! { * },
484 UnaryOp::AddrOf => quote! { & },
485 })
486 }
487
488 fn generate_dimension(&self, dim: &Dimension) -> Result<TokenStream> {
490 Ok(match dim {
491 Dimension::X => quote! { x },
492 Dimension::Y => quote! { y },
493 Dimension::Z => quote! { z },
494 })
495 }
496
497 fn generate_global_var(&self, var: GlobalVar) -> Result<TokenStream> {
499 let name = format_ident!("{}", var.name);
500 let ty = self.generate_type(&var.ty)?;
501 let storage_attr = self.generate_storage_class(&var.storage)?;
502
503 match var.init {
504 Some(init) => {
505 let init_expr = self.generate_expression(&init)?;
506 Ok(quote! {
507 #storage_attr
508 static #name: #ty = #init_expr;
509 })
510 },
511 None => Ok(quote! {
512 #storage_attr
513 static #name: #ty;
514 }),
515 }
516 }
517
518 fn generate_typedef(&self, typedef: TypeDef) -> Result<TokenStream> {
520 let name = format_ident!("{}", typedef.name);
521 let ty = self.generate_type(&typedef.ty)?;
522 Ok(quote! {
523 type #name = #ty;
524 })
525 }
526}