cuda_rust_wasm/transpiler/
wgsl.rs

1//! WebGPU Shading Language (WGSL) generation from CUDA AST
2
3use crate::{Result, translation_error};
4use crate::parser::ast::*;
5use std::fmt::Write;
6
/// WGSL code generator for converting CUDA AST to WebGPU shaders.
///
/// The generator appends text to an internal buffer while it walks the AST and
/// tracks the current nesting depth so emitted lines can be indented.
#[derive(Debug, Clone)]
pub struct WgslGenerator {
    /// Generated WGSL code
    code: String,
    /// Current indentation level (one level is four spaces)
    indent_level: usize,
    /// Workgroup size configuration as `(x, y, z)`
    workgroup_size: (u32, u32, u32),
}
16
17impl WgslGenerator {
18    /// Create a new WGSL generator
19    pub fn new() -> Self {
20        Self {
21            code: String::new(),
22            indent_level: 0,
23            workgroup_size: (64, 1, 1), // Default workgroup size
24        }
25    }
26    
27    /// Set workgroup size for compute shaders
28    pub fn with_workgroup_size(mut self, x: u32, y: u32, z: u32) -> Self {
29        self.workgroup_size = (x, y, z);
30        self
31    }
32    
33    /// Generate WGSL code from AST
34    pub fn generate(&mut self, ast: Ast) -> Result<String> {
35        // Generate struct definitions for kernel parameters
36        self.generate_structs(&ast)?;
37        
38        // Generate global variables
39        for item in &ast.items {
40            if let Item::GlobalVar(var) = item {
41                self.generate_global_var(var)?;
42            }
43        }
44        
45        // Generate device functions
46        for item in &ast.items {
47            if let Item::DeviceFunction(func) = item {
48                self.generate_device_function(func)?;
49            }
50        }
51        
52        // Generate compute kernels
53        for item in &ast.items {
54            if let Item::Kernel(kernel) = item {
55                self.generate_kernel(kernel)?;
56            }
57        }
58        
59        Ok(self.code.clone())
60    }
61    
62    /// Generate struct definitions for kernel parameters
63    fn generate_structs(&mut self, ast: &Ast) -> Result<()> {
64        // For each kernel, generate binding structs
65        let mut binding_index = 0;
66        
67        for item in &ast.items {
68            if let Item::Kernel(kernel) = item {
69                // Generate buffer bindings for pointer parameters
70                for param in &kernel.params {
71                    if matches!(param.ty, Type::Pointer(_)) {
72                        self.writeln(&format!(
73                            "@group(0) @binding({binding_index})"
74                        ))?;
75                        
76                        let buffer_type = match &param.ty {
77                            Type::Pointer(inner) => {
78                                let wgsl_type = self.type_to_wgsl(inner)?;
79                                if param.qualifiers.iter().any(|q| matches!(q, ParamQualifier::Const)) {
80                                    format!("var<storage, read> {}: array<{}>;", param.name, wgsl_type)
81                                } else {
82                                    format!("var<storage, read_write> {}: array<{}>;", param.name, wgsl_type)
83                                }
84                            },
85                            _ => unreachable!(),
86                        };
87                        
88                        self.writeln(&buffer_type)?;
89                        self.writeln("")?;
90                        binding_index += 1;
91                    }
92                }
93            }
94        }
95        
96        Ok(())
97    }
98    
99    /// Generate WGSL code for a kernel
100    fn generate_kernel(&mut self, kernel: &KernelDef) -> Result<()> {
101        // Generate workgroup size attribute
102        self.writeln(&format!(
103            "@compute @workgroup_size({}, {}, {})",
104            self.workgroup_size.0, self.workgroup_size.1, self.workgroup_size.2
105        ))?;
106        
107        // Generate function signature
108        self.write(&format!("fn {}(", kernel.name))?;
109        
110        // Add built-in parameters
111        self.write("@builtin(global_invocation_id) global_id: vec3<u32>")?;
112        self.write(", @builtin(local_invocation_id) local_id: vec3<u32>")?;
113        self.write(", @builtin(workgroup_id) workgroup_id: vec3<u32>")?;
114        
115        self.writeln(") {")?;
116        self.indent();
117        
118        // Map CUDA built-ins to WGSL
119        self.writeln("// Map CUDA thread/block indices to WGSL")?;
120        self.writeln("let threadIdx = local_id;")?;
121        self.writeln("let blockIdx = workgroup_id;")?;
122        self.writeln("let blockDim = vec3<u32>(64u, 1u, 1u);")?; // Match workgroup size
123        self.writeln("let gridDim = vec3<u32>(1u, 1u, 1u);")?; // Would need to be computed
124        self.writeln("")?;
125        
126        // Generate kernel body
127        self.generate_block(&kernel.body)?;
128        
129        self.dedent();
130        self.writeln("}")?;
131        self.writeln("")?;
132        
133        Ok(())
134    }
135    
136    /// Generate WGSL code for a device function
137    fn generate_device_function(&mut self, func: &FunctionDef) -> Result<()> {
138        self.write(&format!("fn {}(", func.name))?;
139        
140        // Generate parameters
141        for (i, param) in func.params.iter().enumerate() {
142            if i > 0 {
143                self.write(", ")?;
144            }
145            self.write(&format!("{}: {}", param.name, self.type_to_wgsl(&param.ty)?))?;
146        }
147        
148        self.write(") -> ")?;
149        self.write(&self.type_to_wgsl(&func.return_type)?)?;
150        self.writeln(" {")?;
151        
152        self.indent();
153        self.generate_block(&func.body)?;
154        self.dedent();
155        
156        self.writeln("}")?;
157        self.writeln("")?;
158        
159        Ok(())
160    }
161    
162    /// Generate global variable
163    fn generate_global_var(&mut self, var: &GlobalVar) -> Result<()> {
164        match var.storage {
165            StorageClass::Constant => {
166                self.write("const ")?;
167            },
168            StorageClass::Shared => {
169                self.write("var<workgroup> ")?;
170            },
171            _ => {
172                self.write("var<private> ")?;
173            },
174        }
175        
176        self.write(&format!("{}: {}", var.name, self.type_to_wgsl(&var.ty)?))?;
177        
178        if let Some(init) = &var.init {
179            self.write(" = ")?;
180            self.generate_expression(init)?;
181        }
182        
183        self.writeln(";")?;
184        self.writeln("")?;
185        
186        Ok(())
187    }
188    
189    /// Generate code for a block
190    fn generate_block(&mut self, block: &Block) -> Result<()> {
191        for stmt in &block.statements {
192            self.generate_statement(stmt)?;
193        }
194        Ok(())
195    }
196    
197    /// Generate code for a statement
198    fn generate_statement(&mut self, stmt: &Statement) -> Result<()> {
199        match stmt {
200            Statement::VarDecl { name, ty, init, storage } => {
201                match storage {
202                    StorageClass::Shared => self.write("var<workgroup> ")?,
203                    _ => self.write("var ")?,
204                }
205                
206                self.write(&format!("{}: {}", name, self.type_to_wgsl(ty)?))?;
207                
208                if let Some(init_expr) = init {
209                    self.write(" = ")?;
210                    self.generate_expression(init_expr)?;
211                }
212                
213                self.writeln(";")?;
214            },
215            Statement::Expr(expr) => {
216                self.generate_expression(expr)?;
217                self.writeln(";")?;
218            },
219            Statement::Block(block) => {
220                self.writeln("{")?;
221                self.indent();
222                self.generate_block(block)?;
223                self.dedent();
224                self.writeln("}")?;
225            },
226            Statement::If { condition, then_branch, else_branch } => {
227                self.write("if (")?;
228                self.generate_expression(condition)?;
229                self.writeln(") {")?;
230                
231                self.indent();
232                self.generate_statement(then_branch)?;
233                self.dedent();
234                
235                if let Some(else_stmt) = else_branch {
236                    self.writeln("} else {")?;
237                    self.indent();
238                    self.generate_statement(else_stmt)?;
239                    self.dedent();
240                }
241                
242                self.writeln("}")?;
243            },
244            Statement::For { init, condition, update, body } => {
245                // WGSL doesn't have traditional for loops, convert to while
246                self.writeln("{")?;
247                self.indent();
248                
249                // Initialize
250                if let Some(init) = init {
251                    match init.as_ref() {
252                        Statement::VarDecl { name, ty, init, .. } => {
253                            self.write(&format!("var {}: {}", name, self.type_to_wgsl(ty)?))?;
254                            if let Some(init_expr) = init {
255                                self.write(" = ")?;
256                                self.generate_expression(init_expr)?;
257                            }
258                            self.writeln(";")?;
259                        },
260                        Statement::Expr(expr) => {
261                            self.generate_expression(expr)?;
262                            self.writeln(";")?;
263                        },
264                        _ => return Err(translation_error!("Invalid init statement in for loop")),
265                    }
266                }
267                
268                // While loop
269                self.write("while (")?;
270                if let Some(cond) = condition {
271                    self.generate_expression(cond)?;
272                } else {
273                    self.write("true")?;
274                }
275                self.writeln(") {")?;
276                
277                self.indent();
278                self.generate_statement(body)?;
279                
280                // Update
281                if let Some(update_expr) = update {
282                    self.generate_expression(update_expr)?;
283                    self.writeln(";")?;
284                }
285                
286                self.dedent();
287                self.writeln("}")?;
288                
289                self.dedent();
290                self.writeln("}")?;
291            },
292            Statement::While { condition, body } => {
293                self.write("while (")?;
294                self.generate_expression(condition)?;
295                self.writeln(") {")?;
296                
297                self.indent();
298                self.generate_statement(body)?;
299                self.dedent();
300                
301                self.writeln("}")?;
302            },
303            Statement::Return(expr) => {
304                self.write("return")?;
305                if let Some(e) = expr {
306                    self.write(" ")?;
307                    self.generate_expression(e)?;
308                }
309                self.writeln(";")?;
310            },
311            Statement::Break => self.writeln("break;")?,
312            Statement::Continue => self.writeln("continue;")?,
313            Statement::SyncThreads => self.writeln("workgroupBarrier();")?,
314        }
315        
316        Ok(())
317    }
318    
319    /// Generate code for an expression
320    fn generate_expression(&mut self, expr: &Expression) -> Result<()> {
321        match expr {
322            Expression::Literal(lit) => self.generate_literal(lit)?,
323            Expression::Var(name) => self.write(name)?,
324            Expression::Binary { op, left, right } => {
325                self.write("(")?;
326                self.generate_expression(left)?;
327                self.write(" ")?;
328                self.write(self.binary_op_to_wgsl(op)?)?;
329                self.write(" ")?;
330                self.generate_expression(right)?;
331                self.write(")")?;
332            },
333            Expression::Unary { op, expr } => {
334                self.write("(")?;
335                self.write(self.unary_op_to_wgsl(op)?)?;
336                self.generate_expression(expr)?;
337                self.write(")")?;
338            },
339            Expression::Call { name, args } => {
340                self.write(&format!("{name}("))?;
341                for (i, arg) in args.iter().enumerate() {
342                    if i > 0 {
343                        self.write(", ")?;
344                    }
345                    self.generate_expression(arg)?;
346                }
347                self.write(")")?;
348            },
349            Expression::Index { array, index } => {
350                self.generate_expression(array)?;
351                self.write("[")?;
352                self.generate_expression(index)?;
353                self.write("]")?;
354            },
355            Expression::Member { object, field } => {
356                self.generate_expression(object)?;
357                self.write(&format!(".{field}"))?;
358            },
359            Expression::Cast { ty, expr } => {
360                let wgsl_type = self.type_to_wgsl(ty)?;
361                self.write(&format!("{wgsl_type}("))?;
362                self.generate_expression(expr)?;
363                self.write(")")?;
364            },
365            Expression::ThreadIdx(dim) => {
366                self.write(&format!("threadIdx.{}", self.dimension_to_wgsl(dim)))?;
367            },
368            Expression::BlockIdx(dim) => {
369                self.write(&format!("blockIdx.{}", self.dimension_to_wgsl(dim)))?;
370            },
371            Expression::BlockDim(dim) => {
372                self.write(&format!("blockDim.{}", self.dimension_to_wgsl(dim)))?;
373            },
374            Expression::GridDim(dim) => {
375                self.write(&format!("gridDim.{}", self.dimension_to_wgsl(dim)))?;
376            },
377            Expression::WarpPrimitive { op, args } => {
378                // WGSL doesn't have direct warp primitives, emit a comment
379                self.write(&format!("/* warp_{op:?}("))?;
380                for (i, arg) in args.iter().enumerate() {
381                    if i > 0 {
382                        self.write(", ")?;
383                    }
384                    self.generate_expression(arg)?;
385                }
386                self.write(") */")?;
387                // Emit a placeholder value
388                self.write("0")?;
389            },
390        }
391        
392        Ok(())
393    }
394    
395    /// Generate literal
396    fn generate_literal(&mut self, lit: &Literal) -> Result<()> {
397        match lit {
398            Literal::Bool(b) => self.write(&format!("{b}"))?,
399            Literal::Int(i) => self.write(&format!("{i}i"))?,
400            Literal::UInt(u) => self.write(&format!("{u}u"))?,
401            Literal::Float(f) => self.write(&format!("{f}f"))?,
402            Literal::String(s) => self.write(&format!("\"{s}\""))?,
403        }
404        Ok(())
405    }
406    
407    /// Convert CUDA type to WGSL type
408    fn type_to_wgsl(&self, ty: &Type) -> Result<String> {
409        Ok(match ty {
410            Type::Void => return Err(translation_error!("void type not supported in WGSL")),
411            Type::Bool => "bool".to_string(),
412            Type::Int(int_ty) => match int_ty {
413                IntType::I8 | IntType::I16 | IntType::I32 => "i32".to_string(),
414                IntType::I64 => return Err(translation_error!("i64 not supported in WGSL")),
415                IntType::U8 | IntType::U16 | IntType::U32 => "u32".to_string(),
416                IntType::U64 => return Err(translation_error!("u64 not supported in WGSL")),
417            },
418            Type::Float(float_ty) => match float_ty {
419                FloatType::F16 => "f16".to_string(),
420                FloatType::F32 => "f32".to_string(),
421                FloatType::F64 => return Err(translation_error!("f64 not supported in WGSL")),
422            },
423            Type::Pointer(inner) => {
424                // Pointers are handled as array references in bindings
425                format!("ptr<storage, {}>", self.type_to_wgsl(inner)?)
426            },
427            Type::Array(inner, size) => {
428                match size {
429                    Some(n) => format!("array<{}, {}>", self.type_to_wgsl(inner)?, n),
430                    None => format!("array<{}>", self.type_to_wgsl(inner)?),
431                }
432            },
433            Type::Vector(vec_ty) => {
434                let elem_type = self.type_to_wgsl(&vec_ty.element)?;
435                format!("vec{}<{}>", vec_ty.size, elem_type)
436            },
437            Type::Named(name) => name.clone(),
438            Type::Texture(_) => return Err(translation_error!("Texture types not yet supported")),
439        })
440    }
441    
442    /// Convert binary operator to WGSL
443    fn binary_op_to_wgsl(&self, op: &BinaryOp) -> Result<&'static str> {
444        Ok(match op {
445            BinaryOp::Add => "+",
446            BinaryOp::Sub => "-",
447            BinaryOp::Mul => "*",
448            BinaryOp::Div => "/",
449            BinaryOp::Mod => "%",
450            BinaryOp::And => "&",
451            BinaryOp::Or => "|",
452            BinaryOp::Xor => "^",
453            BinaryOp::Shl => "<<",
454            BinaryOp::Shr => ">>",
455            BinaryOp::Eq => "==",
456            BinaryOp::Ne => "!=",
457            BinaryOp::Lt => "<",
458            BinaryOp::Le => "<=",
459            BinaryOp::Gt => ">",
460            BinaryOp::Ge => ">=",
461            BinaryOp::LogicalAnd => "&&",
462            BinaryOp::LogicalOr => "||",
463            BinaryOp::Assign => "=",
464        })
465    }
466    
467    /// Convert unary operator to WGSL
468    fn unary_op_to_wgsl(&self, op: &UnaryOp) -> Result<&'static str> {
469        Ok(match op {
470            UnaryOp::Not => "!",
471            UnaryOp::Neg => "-",
472            UnaryOp::BitNot => "~",
473            UnaryOp::PreInc => return Err(translation_error!("Pre-increment not supported in WGSL")),
474            UnaryOp::PreDec => return Err(translation_error!("Pre-decrement not supported in WGSL")),
475            UnaryOp::PostInc => return Err(translation_error!("Post-increment not supported in WGSL")),
476            UnaryOp::PostDec => return Err(translation_error!("Post-decrement not supported in WGSL")),
477            UnaryOp::Deref => "*",
478            UnaryOp::AddrOf => "&",
479        })
480    }
481    
482    /// Convert dimension to WGSL component
483    fn dimension_to_wgsl(&self, dim: &Dimension) -> &'static str {
484        match dim {
485            Dimension::X => "x",
486            Dimension::Y => "y",
487            Dimension::Z => "z",
488        }
489    }
490    
491    /// Helper: Write with indentation
492    fn write(&mut self, s: &str) -> Result<()> {
493        self.code.push_str(s);
494        Ok(())
495    }
496    
497    /// Helper: Write line with indentation
498    fn writeln(&mut self, s: &str) -> Result<()> {
499        if !s.is_empty() {
500            for _ in 0..self.indent_level {
501                self.code.push_str("    ");
502            }
503            self.code.push_str(s);
504        }
505        self.code.push('\n');
506        Ok(())
507    }
508    
509    /// Helper: Increase indentation
510    fn indent(&mut self) {
511        self.indent_level += 1;
512    }
513    
514    /// Helper: Decrease indentation
515    fn dedent(&mut self) {
516        if self.indent_level > 0 {
517            self.indent_level -= 1;
518        }
519    }
520}
521
522impl Default for WgslGenerator {
523    fn default() -> Self {
524        Self::new()
525    }
526}