1use std::collections::HashMap;
6use std::fmt::Write;
7
8use crate::{
9 nodes::*, BackendCapabilities, BlockId, CapabilityFlag, Dimension, IrModule, IrNode, IrType,
10 ScalarType, Terminator, ValueId,
11};
12
/// Options controlling how an IR module is lowered to WGSL.
///
/// `PartialEq`/`Eq` are derived so two configurations can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WgslLoweringConfig {
    /// When `true`, the shader begins with `enable subgroups;` and warp
    /// vote/shuffle nodes are lowered to subgroup builtins.
    pub subgroups: bool,
    /// Workgroup size `(x, y, z)` written into `@workgroup_size(...)`.
    pub workgroup_size: (u32, u32, u32),
    /// When `false`, modules that require 64-bit atomics are rejected.
    pub emulate_atomic64: bool,
    /// When `false`, modules that require `f64` are rejected.
    pub downcast_f64: bool,
    /// Debug flag; not read by the lowering code in this file.
    pub debug: bool,
}
27
28impl Default for WgslLoweringConfig {
29 fn default() -> Self {
30 Self {
31 subgroups: false,
32 workgroup_size: (256, 1, 1),
33 emulate_atomic64: true,
34 downcast_f64: true,
35 debug: false,
36 }
37 }
38}
39
40impl WgslLoweringConfig {
41 pub fn with_subgroups(mut self) -> Self {
43 self.subgroups = true;
44 self
45 }
46
47 pub fn with_workgroup_size(mut self, x: u32, y: u32, z: u32) -> Self {
49 self.workgroup_size = (x, y, z);
50 self
51 }
52}
53
54pub struct WgslLowering {
56 config: WgslLoweringConfig,
57 output: String,
58 indent: usize,
59 value_names: HashMap<ValueId, String>,
60 name_counter: usize,
61 block_labels: HashMap<BlockId, String>,
62 #[allow(dead_code)]
63 has_f64_warning: bool,
64}
65
66impl WgslLowering {
67 pub fn new(config: WgslLoweringConfig) -> Self {
69 Self {
70 config,
71 output: String::new(),
72 indent: 0,
73 value_names: HashMap::new(),
74 name_counter: 0,
75 block_labels: HashMap::new(),
76 has_f64_warning: false,
77 }
78 }
79
80 pub fn lower(mut self, module: &IrModule) -> Result<String, WgslLoweringError> {
82 self.check_capabilities(module)?;
84
85 self.emit_header(module);
87
88 self.emit_compute_shader(module)?;
90
91 Ok(self.output)
92 }
93
94 fn check_capabilities(&self, module: &IrModule) -> Result<(), WgslLoweringError> {
95 let _wgpu_caps = if self.config.subgroups {
97 BackendCapabilities::wgpu_with_subgroups()
98 } else {
99 BackendCapabilities::wgpu_baseline()
100 };
101
102 if module.required_capabilities.has(CapabilityFlag::Float64) && !self.config.downcast_f64 {
104 return Err(WgslLoweringError::UnsupportedCapability(
105 "f64 not supported in WGSL (use downcast_f64 option)".to_string(),
106 ));
107 }
108
109 if module.required_capabilities.has(CapabilityFlag::Atomic64)
111 && !self.config.emulate_atomic64
112 {
113 return Err(WgslLoweringError::UnsupportedCapability(
114 "64-bit atomics not supported in WGSL (use emulate_atomic64 option)".to_string(),
115 ));
116 }
117
118 if module
120 .required_capabilities
121 .has(CapabilityFlag::CooperativeGroups)
122 {
123 return Err(WgslLoweringError::UnsupportedCapability(
124 "Cooperative groups / grid sync not supported in WebGPU".to_string(),
125 ));
126 }
127
128 Ok(())
129 }
130
131 fn emit_header(&mut self, module: &IrModule) {
132 self.emit_line("// Generated by ringkernel-ir WGSL lowering");
133 self.emit_line("");
134
135 if self.config.subgroups {
137 self.emit_line("enable subgroups;");
138 self.emit_line("");
139 }
140
141 if !module.parameters.is_empty() {
143 self.emit_line("// Parameters");
144 self.emit_line("struct Params {");
145 self.indent += 1;
146 for param in module.parameters.iter() {
147 if !matches!(param.ty, IrType::Ptr(_) | IrType::Slice(_)) {
149 let ty = self.lower_type(¶m.ty);
150 self.emit_line(&format!("{}: {},", param.name, ty));
151 }
152 }
153 self.indent -= 1;
154 self.emit_line("}");
155 self.emit_line("");
156 }
157
158 self.emit_line("// Bindings");
160 let mut binding_idx = 0;
161
162 let has_uniforms = module
164 .parameters
165 .iter()
166 .any(|p| !matches!(p.ty, IrType::Ptr(_) | IrType::Slice(_)));
167 if has_uniforms {
168 self.emit_line(&format!(
169 "@group(0) @binding({}) var<uniform> params: Params;",
170 binding_idx
171 ));
172 binding_idx += 1;
173 }
174
175 for param in &module.parameters {
177 if let IrType::Ptr(inner) | IrType::Slice(inner) = ¶m.ty {
178 let elem_ty = self.lower_type(inner);
179 self.emit_line(&format!(
180 "@group(0) @binding({}) var<storage, read_write> {}: array<{}>;",
181 binding_idx, param.name, elem_ty
182 ));
183 binding_idx += 1;
184 }
185 }
186
187 self.emit_line("");
188 }
189
190 fn emit_compute_shader(&mut self, module: &IrModule) -> Result<(), WgslLoweringError> {
191 self.assign_names(module);
193
194 let (wx, wy, wz) = self.config.workgroup_size;
196
197 self.emit_line(&format!("@compute @workgroup_size({}, {}, {})", wx, wy, wz));
198 self.emit_line(&format!("fn {}(", module.name));
199 self.indent += 1;
200 self.emit_line("@builtin(global_invocation_id) global_id: vec3<u32>,");
201 self.emit_line("@builtin(local_invocation_id) local_id: vec3<u32>,");
202 self.emit_line("@builtin(workgroup_id) workgroup_id: vec3<u32>,");
203 self.emit_line("@builtin(num_workgroups) num_workgroups: vec3<u32>,");
204 self.indent -= 1;
205 self.emit_line(") {");
206 self.indent += 1;
207
208 self.emit_block(module, module.entry_block)?;
210
211 self.indent -= 1;
212 self.emit_line("}");
213
214 Ok(())
215 }
216
217 fn assign_names(&mut self, module: &IrModule) {
218 for param in &module.parameters {
219 self.value_names.insert(param.value_id, param.name.clone());
221 }
222
223 for (block_id, block) in &module.blocks {
224 self.block_labels.insert(*block_id, block.label.clone());
225 }
226 }
227
228 fn emit_block(
229 &mut self,
230 module: &IrModule,
231 block_id: BlockId,
232 ) -> Result<(), WgslLoweringError> {
233 let block = module
234 .blocks
235 .get(&block_id)
236 .ok_or(WgslLoweringError::UndefinedBlock(block_id))?;
237
238 if block_id != module.entry_block {
241 self.emit_line(&format!("// Block: {}", block.label));
242 }
243
244 for inst in &block.instructions {
246 self.emit_instruction(module, &inst.result, &inst.result_type, &inst.node)?;
247 }
248
249 if let Some(term) = &block.terminator {
251 self.emit_terminator(module, term)?;
252 }
253
254 Ok(())
255 }
256
257 fn emit_instruction(
258 &mut self,
259 _module: &IrModule,
260 result: &ValueId,
261 result_type: &IrType,
262 node: &IrNode,
263 ) -> Result<(), WgslLoweringError> {
264 let result_name = self.get_or_create_name(*result);
265 let ty = self.lower_type(result_type);
266
267 match node {
268 IrNode::Constant(c) => {
270 let val = self.lower_constant(c);
271 self.emit_line(&format!("var {}: {} = {};", result_name, ty, val));
272 }
273
274 IrNode::BinaryOp(op, lhs, rhs) => {
276 let lhs_name = self.get_value_name(*lhs);
277 let rhs_name = self.get_value_name(*rhs);
278 let expr = self.lower_binary_op(op, &lhs_name, &rhs_name);
279 self.emit_line(&format!("var {}: {} = {};", result_name, ty, expr));
280 }
281
282 IrNode::UnaryOp(op, val) => {
284 let val_name = self.get_value_name(*val);
285 let expr = self.lower_unary_op(op, &val_name);
286 self.emit_line(&format!("var {}: {} = {};", result_name, ty, expr));
287 }
288
289 IrNode::Compare(op, lhs, rhs) => {
291 let lhs_name = self.get_value_name(*lhs);
292 let rhs_name = self.get_value_name(*rhs);
293 let cmp_op = self.lower_compare_op(op);
294 self.emit_line(&format!(
295 "var {}: bool = {} {} {};",
296 result_name, lhs_name, cmp_op, rhs_name
297 ));
298 }
299
300 IrNode::Load(ptr) => {
302 let ptr_name = self.get_value_name(*ptr);
303 self.emit_line(&format!("var {}: {} = {};", result_name, ty, ptr_name));
305 }
306
307 IrNode::Store(ptr, val) => {
308 let ptr_name = self.get_value_name(*ptr);
309 let val_name = self.get_value_name(*val);
310 self.emit_line(&format!("{} = {};", ptr_name, val_name));
311 }
312
313 IrNode::GetElementPtr(ptr, indices) => {
314 let ptr_name = self.get_value_name(*ptr);
315 let idx_name = self.get_value_name(indices[0]);
316 self.emit_line(&format!(
318 "var {}: {} = {}[{}];",
319 result_name, ty, ptr_name, idx_name
320 ));
321 }
322
323 IrNode::SharedAlloc(_elem_ty, _count) => {
324 self.emit_line(&format!("// Workgroup var: {}", result_name));
327 }
328
329 IrNode::ThreadId(dim) => {
331 let idx = self.lower_dimension(dim, "local_id");
332 self.emit_line(&format!("var {}: {} = {};", result_name, ty, idx));
333 }
334
335 IrNode::BlockId(dim) => {
336 let idx = self.lower_dimension(dim, "workgroup_id");
337 self.emit_line(&format!("var {}: {} = {};", result_name, ty, idx));
338 }
339
340 IrNode::BlockDim(dim) => {
341 let size = match dim {
343 Dimension::X => self.config.workgroup_size.0,
344 Dimension::Y => self.config.workgroup_size.1,
345 Dimension::Z => self.config.workgroup_size.2,
346 };
347 self.emit_line(&format!("var {}: {} = {}u;", result_name, ty, size));
348 }
349
350 IrNode::GridDim(dim) => {
351 let idx = self.lower_dimension(dim, "num_workgroups");
352 self.emit_line(&format!("var {}: {} = {};", result_name, ty, idx));
353 }
354
355 IrNode::GlobalThreadId(dim) => {
356 let idx = self.lower_dimension(dim, "global_id");
357 self.emit_line(&format!("var {}: {} = {};", result_name, ty, idx));
358 }
359
360 IrNode::WarpId => {
361 self.emit_line(&format!("var {}: {} = local_id.x / 32u;", result_name, ty));
363 }
364
365 IrNode::LaneId => {
366 self.emit_line(&format!("var {}: {} = local_id.x % 32u;", result_name, ty));
367 }
368
369 IrNode::Barrier => {
371 self.emit_line("workgroupBarrier();");
372 }
373
374 IrNode::MemoryFence(_scope) => {
375 self.emit_line("storageBarrier();");
376 }
377
378 IrNode::GridSync => {
379 return Err(WgslLoweringError::UnsupportedOperation(
380 "Grid sync not supported in WGSL".to_string(),
381 ));
382 }
383
384 IrNode::Atomic(op, ptr, val) => {
386 let ptr_name = self.get_value_name(*ptr);
387 let val_name = self.get_value_name(*val);
388 let atomic_fn = match op {
389 AtomicOp::Add => "atomicAdd",
390 AtomicOp::Sub => "atomicSub",
391 AtomicOp::Exchange => "atomicExchange",
392 AtomicOp::Min => "atomicMin",
393 AtomicOp::Max => "atomicMax",
394 AtomicOp::And => "atomicAnd",
395 AtomicOp::Or => "atomicOr",
396 AtomicOp::Xor => "atomicXor",
397 AtomicOp::Load => "atomicLoad",
398 AtomicOp::Store => {
399 self.emit_line(&format!("atomicStore(&{}, {});", ptr_name, val_name));
400 return Ok(());
401 }
402 };
403 self.emit_line(&format!(
404 "var {}: {} = {}(&{}, {});",
405 result_name, ty, atomic_fn, ptr_name, val_name
406 ));
407 }
408
409 IrNode::AtomicCas(ptr, expected, desired) => {
410 let ptr_name = self.get_value_name(*ptr);
411 let exp_name = self.get_value_name(*expected);
412 let des_name = self.get_value_name(*desired);
413 self.emit_line(&format!(
414 "var {}: {} = atomicCompareExchangeWeak(&{}, {}, {}).old_value;",
415 result_name, ty, ptr_name, exp_name, des_name
416 ));
417 }
418
419 IrNode::WarpVote(op, val) => {
421 if !self.config.subgroups {
422 return Err(WgslLoweringError::UnsupportedOperation(
423 "Subgroup operations require subgroups feature".to_string(),
424 ));
425 }
426 let val_name = self.get_value_name(*val);
427 let vote_fn = match op {
428 WarpVoteOp::All => "subgroupAll",
429 WarpVoteOp::Any => "subgroupAny",
430 WarpVoteOp::Ballot => "subgroupBallot",
431 };
432 self.emit_line(&format!(
433 "var {}: {} = {}({});",
434 result_name, ty, vote_fn, val_name
435 ));
436 }
437
438 IrNode::WarpShuffle(op, val, lane) => {
439 if !self.config.subgroups {
440 return Err(WgslLoweringError::UnsupportedOperation(
441 "Subgroup shuffle requires subgroups feature".to_string(),
442 ));
443 }
444 let val_name = self.get_value_name(*val);
445 let lane_name = self.get_value_name(*lane);
446 let shfl_fn = match op {
447 WarpShuffleOp::Index => "subgroupShuffle",
448 WarpShuffleOp::Up => "subgroupShuffleUp",
449 WarpShuffleOp::Down => "subgroupShuffleDown",
450 WarpShuffleOp::Xor => "subgroupShuffleXor",
451 };
452 self.emit_line(&format!(
453 "var {}: {} = {}({}, {});",
454 result_name, ty, shfl_fn, val_name, lane_name
455 ));
456 }
457
458 IrNode::Select(cond, then_val, else_val) => {
460 let cond_name = self.get_value_name(*cond);
461 let then_name = self.get_value_name(*then_val);
462 let else_name = self.get_value_name(*else_val);
463 self.emit_line(&format!(
464 "var {}: {} = select({}, {}, {});",
465 result_name, ty, else_name, then_name, cond_name
466 ));
467 }
468
469 IrNode::Math(op, args) => {
471 let fn_name = self.lower_math_op(op);
472 let args_str: Vec<String> = args.iter().map(|a| self.get_value_name(*a)).collect();
473 self.emit_line(&format!(
474 "var {}: {} = {}({});",
475 result_name,
476 ty,
477 fn_name,
478 args_str.join(", ")
479 ));
480 }
481
482 IrNode::Parameter(_) | IrNode::Undef | IrNode::Phi(_) => {}
484
485 IrNode::K2HEnqueue(_)
487 | IrNode::H2KDequeue
488 | IrNode::H2KIsEmpty
489 | IrNode::K2KSend(_, _)
490 | IrNode::K2KRecv
491 | IrNode::K2KTryRecv
492 | IrNode::HlcNow
493 | IrNode::HlcTick
494 | IrNode::HlcUpdate(_) => {
495 self.emit_line(&format!("// Not supported in WGSL: {:?}", node));
496 }
497
498 _ => {
499 self.emit_line(&format!("// Unhandled: {:?}", node));
500 }
501 }
502
503 Ok(())
504 }
505
506 fn emit_terminator(
507 &mut self,
508 module: &IrModule,
509 term: &Terminator,
510 ) -> Result<(), WgslLoweringError> {
511 match term {
512 Terminator::Return(None) => {
513 self.emit_line("return;");
514 }
515 Terminator::Return(Some(val)) => {
516 let val_name = self.get_value_name(*val);
518 self.emit_line(&format!("// Return: {}", val_name));
519 self.emit_line("return;");
520 }
521 Terminator::Branch(target) => {
522 self.emit_block(module, *target)?;
524 }
525 Terminator::CondBranch(cond, then_block, else_block) => {
526 let cond_name = self.get_value_name(*cond);
527 self.emit_line(&format!("if ({}) {{", cond_name));
528 self.indent += 1;
529 self.emit_block(module, *then_block)?;
530 self.indent -= 1;
531 self.emit_line("} else {");
532 self.indent += 1;
533 self.emit_block(module, *else_block)?;
534 self.indent -= 1;
535 self.emit_line("}");
536 }
537 Terminator::Switch(val, default, cases) => {
538 let val_name = self.get_value_name(*val);
539 self.emit_line(&format!("switch ({}) {{", val_name));
540 self.indent += 1;
541 for (case_val, target) in cases {
542 let case_str = self.lower_constant(case_val);
543 self.emit_line(&format!("case {}: {{", case_str));
544 self.indent += 1;
545 self.emit_block(module, *target)?;
546 self.indent -= 1;
547 self.emit_line("}");
548 }
549 self.emit_line("default: {");
550 self.indent += 1;
551 self.emit_block(module, *default)?;
552 self.indent -= 1;
553 self.emit_line("}");
554 self.indent -= 1;
555 self.emit_line("}");
556 }
557 Terminator::Unreachable => {
558 self.emit_line("// unreachable");
559 }
560 }
561 Ok(())
562 }
563
564 fn lower_type(&self, ty: &IrType) -> String {
565 match ty {
566 IrType::Void => "void".to_string(),
567 IrType::Scalar(s) => self.lower_scalar_type(s),
568 IrType::Vector(v) => format!("vec{}<{}>", v.count, self.lower_scalar_type(&v.element)),
569 IrType::Ptr(inner) => format!("ptr<storage, {}>", self.lower_type(inner)),
570 IrType::Array(inner, size) => format!("array<{}, {}>", self.lower_type(inner), size),
571 IrType::Slice(inner) => format!("array<{}>", self.lower_type(inner)),
572 IrType::Struct(s) => s.name.clone(),
573 IrType::Function(_) => "void".to_string(),
574 }
575 }
576
577 fn lower_scalar_type(&self, ty: &ScalarType) -> String {
578 match ty {
579 ScalarType::Bool => "bool",
580 ScalarType::I8 | ScalarType::I16 | ScalarType::I32 => "i32",
581 ScalarType::I64 => "i32",
583 ScalarType::U8 | ScalarType::U16 | ScalarType::U32 => "u32",
584 ScalarType::U64 => "u32", ScalarType::F16 => "f16",
586 ScalarType::F32 => "f32",
587 ScalarType::F64 => "f32", }
589 .to_string()
590 }
591
592 fn lower_constant(&self, c: &ConstantValue) -> String {
593 match c {
594 ConstantValue::Bool(b) => if *b { "true" } else { "false" }.to_string(),
595 ConstantValue::I32(v) => format!("{}i", v),
596 ConstantValue::I64(v) => format!("{}i", *v as i32), ConstantValue::U32(v) => format!("{}u", v),
598 ConstantValue::U64(v) => format!("{}u", *v as u32), ConstantValue::F32(v) => {
600 if v.is_nan() {
601 "0.0f".to_string()
602 } else if v.is_infinite() {
603 if *v > 0.0 { "1e38f" } else { "-1e38f" }.to_string()
604 } else {
605 format!("{}f", v)
606 }
607 }
608 ConstantValue::F64(v) => format!("{}f", *v as f32), ConstantValue::Null => "0u".to_string(),
610 ConstantValue::Array(elems) => {
611 let elems_str: Vec<String> = elems.iter().map(|e| self.lower_constant(e)).collect();
612 format!("array({})", elems_str.join(", "))
613 }
614 ConstantValue::Struct(fields) => {
615 let fields_str: Vec<String> =
616 fields.iter().map(|f| self.lower_constant(f)).collect();
617 format!("({})", fields_str.join(", "))
618 }
619 }
620 }
621
622 fn lower_binary_op(&self, op: &BinaryOp, lhs: &str, rhs: &str) -> String {
623 match op {
624 BinaryOp::Add => format!("({} + {})", lhs, rhs),
625 BinaryOp::Sub => format!("({} - {})", lhs, rhs),
626 BinaryOp::Mul => format!("({} * {})", lhs, rhs),
627 BinaryOp::Div => format!("({} / {})", lhs, rhs),
628 BinaryOp::Rem => format!("({} % {})", lhs, rhs),
629 BinaryOp::And => format!("({} & {})", lhs, rhs),
630 BinaryOp::Or => format!("({} | {})", lhs, rhs),
631 BinaryOp::Xor => format!("({} ^ {})", lhs, rhs),
632 BinaryOp::Shl => format!("({} << {})", lhs, rhs),
633 BinaryOp::Shr => format!("({} >> {})", lhs, rhs),
634 BinaryOp::Sar => format!("({} >> {})", lhs, rhs),
635 BinaryOp::Fma => format!("fma({}, {}, 0.0)", lhs, rhs),
636 BinaryOp::Pow => format!("pow({}, {})", lhs, rhs),
637 BinaryOp::Min => format!("min({}, {})", lhs, rhs),
638 BinaryOp::Max => format!("max({}, {})", lhs, rhs),
639 }
640 }
641
642 fn lower_unary_op(&self, op: &UnaryOp, val: &str) -> String {
643 match op {
644 UnaryOp::Neg => format!("(-{})", val),
645 UnaryOp::Not => format!("(~{})", val),
646 UnaryOp::LogicalNot => format!("(!{})", val),
647 UnaryOp::Abs => format!("abs({})", val),
648 UnaryOp::Sqrt => format!("sqrt({})", val),
649 UnaryOp::Rsqrt => format!("inverseSqrt({})", val),
650 UnaryOp::Floor => format!("floor({})", val),
651 UnaryOp::Ceil => format!("ceil({})", val),
652 UnaryOp::Round => format!("round({})", val),
653 UnaryOp::Trunc => format!("trunc({})", val),
654 UnaryOp::Sign => format!("sign({})", val),
655 }
656 }
657
658 fn lower_compare_op(&self, op: &CompareOp) -> &'static str {
659 match op {
660 CompareOp::Eq => "==",
661 CompareOp::Ne => "!=",
662 CompareOp::Lt => "<",
663 CompareOp::Le => "<=",
664 CompareOp::Gt => ">",
665 CompareOp::Ge => ">=",
666 }
667 }
668
669 fn lower_dimension(&self, dim: &Dimension, prefix: &str) -> String {
670 match dim {
671 Dimension::X => format!("{}.x", prefix),
672 Dimension::Y => format!("{}.y", prefix),
673 Dimension::Z => format!("{}.z", prefix),
674 }
675 }
676
677 fn lower_math_op(&self, op: &MathOp) -> &'static str {
678 match op {
679 MathOp::Sin => "sin",
680 MathOp::Cos => "cos",
681 MathOp::Tan => "tan",
682 MathOp::Asin => "asin",
683 MathOp::Acos => "acos",
684 MathOp::Atan => "atan",
685 MathOp::Atan2 => "atan2",
686 MathOp::Sinh => "sinh",
687 MathOp::Cosh => "cosh",
688 MathOp::Tanh => "tanh",
689 MathOp::Exp => "exp",
690 MathOp::Exp2 => "exp2",
691 MathOp::Log => "log",
692 MathOp::Log2 => "log2",
693 MathOp::Log10 => "log", MathOp::Lerp => "mix",
695 MathOp::Clamp => "clamp",
696 MathOp::Step => "step",
697 MathOp::SmoothStep => "smoothstep",
698 MathOp::Fract => "fract",
699 MathOp::CopySign => "sign", }
701 }
702
703 fn get_value_name(&self, id: ValueId) -> String {
704 self.value_names
705 .get(&id)
706 .cloned()
707 .unwrap_or_else(|| format!("v{}", id.raw()))
708 }
709
710 fn get_or_create_name(&mut self, id: ValueId) -> String {
711 if let Some(name) = self.value_names.get(&id) {
712 return name.clone();
713 }
714 let name = format!("t{}", self.name_counter);
715 self.name_counter += 1;
716 self.value_names.insert(id, name.clone());
717 name
718 }
719
720 fn emit_line(&mut self, line: &str) {
721 let indent = " ".repeat(self.indent);
722 writeln!(self.output, "{}{}", indent, line).unwrap();
723 }
724}
725
726#[derive(Debug, Clone)]
728pub enum WgslLoweringError {
729 UnsupportedCapability(String),
731 UnsupportedOperation(String),
733 UndefinedBlock(BlockId),
735 UndefinedValue(ValueId),
737 TypeError(String),
739}
740
741impl std::fmt::Display for WgslLoweringError {
742 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
743 match self {
744 WgslLoweringError::UnsupportedCapability(cap) => {
745 write!(f, "Unsupported capability: {}", cap)
746 }
747 WgslLoweringError::UnsupportedOperation(op) => {
748 write!(f, "Unsupported operation: {}", op)
749 }
750 WgslLoweringError::UndefinedBlock(id) => write!(f, "Undefined block: {}", id),
751 WgslLoweringError::UndefinedValue(id) => write!(f, "Undefined value: {}", id),
752 WgslLoweringError::TypeError(msg) => write!(f, "Type error: {}", msg),
753 }
754 }
755}
756
757impl std::error::Error for WgslLoweringError {}
758
759pub fn lower_to_wgsl(module: &IrModule) -> Result<String, WgslLoweringError> {
761 WgslLowering::new(WgslLoweringConfig::default()).lower(module)
762}
763
764pub fn lower_to_wgsl_with_config(
766 module: &IrModule,
767 config: WgslLoweringConfig,
768) -> Result<String, WgslLoweringError> {
769 WgslLowering::new(config).lower(module)
770}
771
#[cfg(test)]
mod tests {
    use super::*;
    use crate::IrBuilder;

    /// A kernel with one buffer and one scalar parameter lowers to a compute
    /// entry point that mentions the global invocation id.
    #[test]
    fn test_lower_simple_kernel() {
        let mut kernel = IrBuilder::new("add_one");

        let _x = kernel.parameter("x", IrType::slice(IrType::F32));
        let _n = kernel.parameter("n", IrType::I32);
        let _idx = kernel.global_thread_id(Dimension::X);
        kernel.ret();

        let module = kernel.build();
        let wgsl = lower_to_wgsl(&module).unwrap();

        let expected = ["@compute @workgroup_size", "fn add_one", "global_id"];
        for fragment in expected.iter() {
            assert!(wgsl.contains(*fragment));
        }
    }

    /// A barrier instruction lowers to `workgroupBarrier()`.
    #[test]
    fn test_lower_with_barrier() {
        let mut kernel = IrBuilder::new("sync");
        kernel.barrier();
        kernel.ret();

        let module = kernel.build();
        let wgsl = lower_to_wgsl(&module).unwrap();

        assert!(wgsl.contains("workgroupBarrier()"));
    }

    /// A conditional branch lowers to an `if`/`else` pair.
    #[test]
    fn test_lower_with_control_flow() {
        let mut kernel = IrBuilder::new("branch");

        let cond = kernel.const_bool(true);
        let then_block = kernel.create_block("then");
        let else_block = kernel.create_block("else");
        kernel.cond_branch(cond, then_block, else_block);

        kernel.switch_to_block(then_block);
        kernel.ret();

        kernel.switch_to_block(else_block);
        kernel.ret();

        let module = kernel.build();
        let wgsl = lower_to_wgsl(&module).unwrap();

        assert!(wgsl.contains("if ("));
        assert!(wgsl.contains("} else {"));
    }

    /// Grid-wide synchronisation is rejected.
    #[test]
    fn test_lower_rejects_grid_sync() {
        let mut kernel = IrBuilder::new("grid");
        kernel.grid_sync();
        kernel.ret();

        let module = kernel.build();

        assert!(lower_to_wgsl(&module).is_err());
    }

    /// Enabling subgroups adds the `enable subgroups;` directive.
    #[test]
    fn test_lower_with_subgroups() {
        let mut kernel = IrBuilder::new("subgroup");
        let _val = kernel.const_bool(true);
        kernel.ret();

        let module = kernel.build();
        let config = WgslLoweringConfig::default().with_subgroups();
        let wgsl = lower_to_wgsl_with_config(&module, config).unwrap();

        assert!(wgsl.contains("enable subgroups;"));
    }
}