1use crate::{Result, translation_error};
4use crate::parser::ast::*;
5use std::fmt::Write;
6
/// Translates a parsed CUDA AST into WGSL source text.
///
/// Output is accumulated in an internal buffer while the AST is walked and
/// handed back to the caller by `generate`.
#[derive(Debug, Clone)]
pub struct WgslGenerator {
    /// WGSL text emitted so far.
    code: String,
    /// Current nesting depth, used to indent emitted lines.
    indent_level: usize,
    /// `(x, y, z)` written into `@workgroup_size(...)` for each kernel.
    workgroup_size: (u32, u32, u32),
}
16
17impl WgslGenerator {
18 pub fn new() -> Self {
20 Self {
21 code: String::new(),
22 indent_level: 0,
23 workgroup_size: (64, 1, 1), }
25 }
26
27 pub fn with_workgroup_size(mut self, x: u32, y: u32, z: u32) -> Self {
29 self.workgroup_size = (x, y, z);
30 self
31 }
32
33 pub fn generate(&mut self, ast: Ast) -> Result<String> {
35 self.generate_structs(&ast)?;
37
38 for item in &ast.items {
40 if let Item::GlobalVar(var) = item {
41 self.generate_global_var(var)?;
42 }
43 }
44
45 for item in &ast.items {
47 if let Item::DeviceFunction(func) = item {
48 self.generate_device_function(func)?;
49 }
50 }
51
52 for item in &ast.items {
54 if let Item::Kernel(kernel) = item {
55 self.generate_kernel(kernel)?;
56 }
57 }
58
59 Ok(self.code.clone())
60 }
61
62 fn generate_structs(&mut self, ast: &Ast) -> Result<()> {
64 let mut binding_index = 0;
66
67 for item in &ast.items {
68 if let Item::Kernel(kernel) = item {
69 for param in &kernel.params {
71 if matches!(param.ty, Type::Pointer(_)) {
72 self.writeln(&format!(
73 "@group(0) @binding({binding_index})"
74 ))?;
75
76 let buffer_type = match ¶m.ty {
77 Type::Pointer(inner) => {
78 let wgsl_type = self.type_to_wgsl(inner)?;
79 if param.qualifiers.iter().any(|q| matches!(q, ParamQualifier::Const)) {
80 format!("var<storage, read> {}: array<{}>;", param.name, wgsl_type)
81 } else {
82 format!("var<storage, read_write> {}: array<{}>;", param.name, wgsl_type)
83 }
84 },
85 _ => unreachable!(),
86 };
87
88 self.writeln(&buffer_type)?;
89 self.writeln("")?;
90 binding_index += 1;
91 }
92 }
93 }
94 }
95
96 Ok(())
97 }
98
99 fn generate_kernel(&mut self, kernel: &KernelDef) -> Result<()> {
101 self.writeln(&format!(
103 "@compute @workgroup_size({}, {}, {})",
104 self.workgroup_size.0, self.workgroup_size.1, self.workgroup_size.2
105 ))?;
106
107 self.write(&format!("fn {}(", kernel.name))?;
109
110 self.write("@builtin(global_invocation_id) global_id: vec3<u32>")?;
112 self.write(", @builtin(local_invocation_id) local_id: vec3<u32>")?;
113 self.write(", @builtin(workgroup_id) workgroup_id: vec3<u32>")?;
114
115 self.writeln(") {")?;
116 self.indent();
117
118 self.writeln("// Map CUDA thread/block indices to WGSL")?;
120 self.writeln("let threadIdx = local_id;")?;
121 self.writeln("let blockIdx = workgroup_id;")?;
122 self.writeln("let blockDim = vec3<u32>(64u, 1u, 1u);")?; self.writeln("let gridDim = vec3<u32>(1u, 1u, 1u);")?; self.writeln("")?;
125
126 self.generate_block(&kernel.body)?;
128
129 self.dedent();
130 self.writeln("}")?;
131 self.writeln("")?;
132
133 Ok(())
134 }
135
136 fn generate_device_function(&mut self, func: &FunctionDef) -> Result<()> {
138 self.write(&format!("fn {}(", func.name))?;
139
140 for (i, param) in func.params.iter().enumerate() {
142 if i > 0 {
143 self.write(", ")?;
144 }
145 self.write(&format!("{}: {}", param.name, self.type_to_wgsl(¶m.ty)?))?;
146 }
147
148 self.write(") -> ")?;
149 self.write(&self.type_to_wgsl(&func.return_type)?)?;
150 self.writeln(" {")?;
151
152 self.indent();
153 self.generate_block(&func.body)?;
154 self.dedent();
155
156 self.writeln("}")?;
157 self.writeln("")?;
158
159 Ok(())
160 }
161
162 fn generate_global_var(&mut self, var: &GlobalVar) -> Result<()> {
164 match var.storage {
165 StorageClass::Constant => {
166 self.write("const ")?;
167 },
168 StorageClass::Shared => {
169 self.write("var<workgroup> ")?;
170 },
171 _ => {
172 self.write("var<private> ")?;
173 },
174 }
175
176 self.write(&format!("{}: {}", var.name, self.type_to_wgsl(&var.ty)?))?;
177
178 if let Some(init) = &var.init {
179 self.write(" = ")?;
180 self.generate_expression(init)?;
181 }
182
183 self.writeln(";")?;
184 self.writeln("")?;
185
186 Ok(())
187 }
188
189 fn generate_block(&mut self, block: &Block) -> Result<()> {
191 for stmt in &block.statements {
192 self.generate_statement(stmt)?;
193 }
194 Ok(())
195 }
196
197 fn generate_statement(&mut self, stmt: &Statement) -> Result<()> {
199 match stmt {
200 Statement::VarDecl { name, ty, init, storage } => {
201 match storage {
202 StorageClass::Shared => self.write("var<workgroup> ")?,
203 _ => self.write("var ")?,
204 }
205
206 self.write(&format!("{}: {}", name, self.type_to_wgsl(ty)?))?;
207
208 if let Some(init_expr) = init {
209 self.write(" = ")?;
210 self.generate_expression(init_expr)?;
211 }
212
213 self.writeln(";")?;
214 },
215 Statement::Expr(expr) => {
216 self.generate_expression(expr)?;
217 self.writeln(";")?;
218 },
219 Statement::Block(block) => {
220 self.writeln("{")?;
221 self.indent();
222 self.generate_block(block)?;
223 self.dedent();
224 self.writeln("}")?;
225 },
226 Statement::If { condition, then_branch, else_branch } => {
227 self.write("if (")?;
228 self.generate_expression(condition)?;
229 self.writeln(") {")?;
230
231 self.indent();
232 self.generate_statement(then_branch)?;
233 self.dedent();
234
235 if let Some(else_stmt) = else_branch {
236 self.writeln("} else {")?;
237 self.indent();
238 self.generate_statement(else_stmt)?;
239 self.dedent();
240 }
241
242 self.writeln("}")?;
243 },
244 Statement::For { init, condition, update, body } => {
245 self.writeln("{")?;
247 self.indent();
248
249 if let Some(init) = init {
251 match init.as_ref() {
252 Statement::VarDecl { name, ty, init, .. } => {
253 self.write(&format!("var {}: {}", name, self.type_to_wgsl(ty)?))?;
254 if let Some(init_expr) = init {
255 self.write(" = ")?;
256 self.generate_expression(init_expr)?;
257 }
258 self.writeln(";")?;
259 },
260 Statement::Expr(expr) => {
261 self.generate_expression(expr)?;
262 self.writeln(";")?;
263 },
264 _ => return Err(translation_error!("Invalid init statement in for loop")),
265 }
266 }
267
268 self.write("while (")?;
270 if let Some(cond) = condition {
271 self.generate_expression(cond)?;
272 } else {
273 self.write("true")?;
274 }
275 self.writeln(") {")?;
276
277 self.indent();
278 self.generate_statement(body)?;
279
280 if let Some(update_expr) = update {
282 self.generate_expression(update_expr)?;
283 self.writeln(";")?;
284 }
285
286 self.dedent();
287 self.writeln("}")?;
288
289 self.dedent();
290 self.writeln("}")?;
291 },
292 Statement::While { condition, body } => {
293 self.write("while (")?;
294 self.generate_expression(condition)?;
295 self.writeln(") {")?;
296
297 self.indent();
298 self.generate_statement(body)?;
299 self.dedent();
300
301 self.writeln("}")?;
302 },
303 Statement::Return(expr) => {
304 self.write("return")?;
305 if let Some(e) = expr {
306 self.write(" ")?;
307 self.generate_expression(e)?;
308 }
309 self.writeln(";")?;
310 },
311 Statement::Break => self.writeln("break;")?,
312 Statement::Continue => self.writeln("continue;")?,
313 Statement::SyncThreads => self.writeln("workgroupBarrier();")?,
314 }
315
316 Ok(())
317 }
318
319 fn generate_expression(&mut self, expr: &Expression) -> Result<()> {
321 match expr {
322 Expression::Literal(lit) => self.generate_literal(lit)?,
323 Expression::Var(name) => self.write(name)?,
324 Expression::Binary { op, left, right } => {
325 self.write("(")?;
326 self.generate_expression(left)?;
327 self.write(" ")?;
328 self.write(self.binary_op_to_wgsl(op)?)?;
329 self.write(" ")?;
330 self.generate_expression(right)?;
331 self.write(")")?;
332 },
333 Expression::Unary { op, expr } => {
334 self.write("(")?;
335 self.write(self.unary_op_to_wgsl(op)?)?;
336 self.generate_expression(expr)?;
337 self.write(")")?;
338 },
339 Expression::Call { name, args } => {
340 self.write(&format!("{name}("))?;
341 for (i, arg) in args.iter().enumerate() {
342 if i > 0 {
343 self.write(", ")?;
344 }
345 self.generate_expression(arg)?;
346 }
347 self.write(")")?;
348 },
349 Expression::Index { array, index } => {
350 self.generate_expression(array)?;
351 self.write("[")?;
352 self.generate_expression(index)?;
353 self.write("]")?;
354 },
355 Expression::Member { object, field } => {
356 self.generate_expression(object)?;
357 self.write(&format!(".{field}"))?;
358 },
359 Expression::Cast { ty, expr } => {
360 let wgsl_type = self.type_to_wgsl(ty)?;
361 self.write(&format!("{wgsl_type}("))?;
362 self.generate_expression(expr)?;
363 self.write(")")?;
364 },
365 Expression::ThreadIdx(dim) => {
366 self.write(&format!("threadIdx.{}", self.dimension_to_wgsl(dim)))?;
367 },
368 Expression::BlockIdx(dim) => {
369 self.write(&format!("blockIdx.{}", self.dimension_to_wgsl(dim)))?;
370 },
371 Expression::BlockDim(dim) => {
372 self.write(&format!("blockDim.{}", self.dimension_to_wgsl(dim)))?;
373 },
374 Expression::GridDim(dim) => {
375 self.write(&format!("gridDim.{}", self.dimension_to_wgsl(dim)))?;
376 },
377 Expression::WarpPrimitive { op, args } => {
378 self.write(&format!("/* warp_{op:?}("))?;
380 for (i, arg) in args.iter().enumerate() {
381 if i > 0 {
382 self.write(", ")?;
383 }
384 self.generate_expression(arg)?;
385 }
386 self.write(") */")?;
387 self.write("0")?;
389 },
390 }
391
392 Ok(())
393 }
394
395 fn generate_literal(&mut self, lit: &Literal) -> Result<()> {
397 match lit {
398 Literal::Bool(b) => self.write(&format!("{b}"))?,
399 Literal::Int(i) => self.write(&format!("{i}i"))?,
400 Literal::UInt(u) => self.write(&format!("{u}u"))?,
401 Literal::Float(f) => self.write(&format!("{f}f"))?,
402 Literal::String(s) => self.write(&format!("\"{s}\""))?,
403 }
404 Ok(())
405 }
406
407 fn type_to_wgsl(&self, ty: &Type) -> Result<String> {
409 Ok(match ty {
410 Type::Void => return Err(translation_error!("void type not supported in WGSL")),
411 Type::Bool => "bool".to_string(),
412 Type::Int(int_ty) => match int_ty {
413 IntType::I8 | IntType::I16 | IntType::I32 => "i32".to_string(),
414 IntType::I64 => return Err(translation_error!("i64 not supported in WGSL")),
415 IntType::U8 | IntType::U16 | IntType::U32 => "u32".to_string(),
416 IntType::U64 => return Err(translation_error!("u64 not supported in WGSL")),
417 },
418 Type::Float(float_ty) => match float_ty {
419 FloatType::F16 => "f16".to_string(),
420 FloatType::F32 => "f32".to_string(),
421 FloatType::F64 => return Err(translation_error!("f64 not supported in WGSL")),
422 },
423 Type::Pointer(inner) => {
424 format!("ptr<storage, {}>", self.type_to_wgsl(inner)?)
426 },
427 Type::Array(inner, size) => {
428 match size {
429 Some(n) => format!("array<{}, {}>", self.type_to_wgsl(inner)?, n),
430 None => format!("array<{}>", self.type_to_wgsl(inner)?),
431 }
432 },
433 Type::Vector(vec_ty) => {
434 let elem_type = self.type_to_wgsl(&vec_ty.element)?;
435 format!("vec{}<{}>", vec_ty.size, elem_type)
436 },
437 Type::Named(name) => name.clone(),
438 Type::Texture(_) => return Err(translation_error!("Texture types not yet supported")),
439 })
440 }
441
442 fn binary_op_to_wgsl(&self, op: &BinaryOp) -> Result<&'static str> {
444 Ok(match op {
445 BinaryOp::Add => "+",
446 BinaryOp::Sub => "-",
447 BinaryOp::Mul => "*",
448 BinaryOp::Div => "/",
449 BinaryOp::Mod => "%",
450 BinaryOp::And => "&",
451 BinaryOp::Or => "|",
452 BinaryOp::Xor => "^",
453 BinaryOp::Shl => "<<",
454 BinaryOp::Shr => ">>",
455 BinaryOp::Eq => "==",
456 BinaryOp::Ne => "!=",
457 BinaryOp::Lt => "<",
458 BinaryOp::Le => "<=",
459 BinaryOp::Gt => ">",
460 BinaryOp::Ge => ">=",
461 BinaryOp::LogicalAnd => "&&",
462 BinaryOp::LogicalOr => "||",
463 BinaryOp::Assign => "=",
464 })
465 }
466
467 fn unary_op_to_wgsl(&self, op: &UnaryOp) -> Result<&'static str> {
469 Ok(match op {
470 UnaryOp::Not => "!",
471 UnaryOp::Neg => "-",
472 UnaryOp::BitNot => "~",
473 UnaryOp::PreInc => return Err(translation_error!("Pre-increment not supported in WGSL")),
474 UnaryOp::PreDec => return Err(translation_error!("Pre-decrement not supported in WGSL")),
475 UnaryOp::PostInc => return Err(translation_error!("Post-increment not supported in WGSL")),
476 UnaryOp::PostDec => return Err(translation_error!("Post-decrement not supported in WGSL")),
477 UnaryOp::Deref => "*",
478 UnaryOp::AddrOf => "&",
479 })
480 }
481
482 fn dimension_to_wgsl(&self, dim: &Dimension) -> &'static str {
484 match dim {
485 Dimension::X => "x",
486 Dimension::Y => "y",
487 Dimension::Z => "z",
488 }
489 }
490
491 fn write(&mut self, s: &str) -> Result<()> {
493 self.code.push_str(s);
494 Ok(())
495 }
496
497 fn writeln(&mut self, s: &str) -> Result<()> {
499 if !s.is_empty() {
500 for _ in 0..self.indent_level {
501 self.code.push_str(" ");
502 }
503 self.code.push_str(s);
504 }
505 self.code.push('\n');
506 Ok(())
507 }
508
509 fn indent(&mut self) {
511 self.indent_level += 1;
512 }
513
514 fn dedent(&mut self) {
516 if self.indent_level > 0 {
517 self.indent_level -= 1;
518 }
519 }
520}
521
522impl Default for WgslGenerator {
523 fn default() -> Self {
524 Self::new()
525 }
526}