Skip to main content

ringkernel_ir/
lower_wgsl.rs

1//! IR to WGSL lowering pass.
2//!
3//! Lowers IR to WebGPU Shading Language for cross-platform GPU compute.
4
5use std::collections::HashMap;
6use std::fmt::Write;
7
8use crate::{
9    nodes::*, BackendCapabilities, BlockId, CapabilityFlag, Dimension, IrModule, IrNode, IrType,
10    ScalarType, Terminator, ValueId,
11};
12
/// Options controlling how IR is lowered to WGSL.
///
/// Start from [`WgslLoweringConfig::default()`] and refine it with the `with_*` builders.
#[derive(Debug, Clone)]
pub struct WgslLoweringConfig {
    /// Emit `enable subgroups;` and permit subgroup (warp vote / shuffle) operations.
    pub subgroups: bool,
    /// Workgroup size `(x, y, z)` written into `@workgroup_size(...)`.
    pub workgroup_size: (u32, u32, u32),
    /// Accept modules requiring 64-bit atomics (WGSL has none); intended to be emulated
    /// with 32-bit pairs.
    pub emulate_atomic64: bool,
    /// Accept modules using f64 by lowering it to f32 (WGSL has no f64).
    pub downcast_f64: bool,
    /// Request debug comments in the generated code.
    pub debug: bool,
}

impl Default for WgslLoweringConfig {
    /// Baseline WebGPU settings.
    ///
    /// These are: no subgroups, a 256x1x1 workgroup, f64 downcasting and 64-bit atomic
    /// emulation both accepted, and debug comments off.
    fn default() -> Self {
        Self {
            downcast_f64: true,
            emulate_atomic64: true,
            subgroups: false,
            debug: false,
            workgroup_size: (256, 1, 1),
        }
    }
}

impl WgslLoweringConfig {
    /// Turn on subgroup operations (and the `enable subgroups;` directive).
    pub fn with_subgroups(self) -> Self {
        Self {
            subgroups: true,
            ..self
        }
    }

    /// Replace the `@workgroup_size(x, y, z)` used for the entry point.
    pub fn with_workgroup_size(self, x: u32, y: u32, z: u32) -> Self {
        Self {
            workgroup_size: (x, y, z),
            ..self
        }
    }
}
53
54/// WGSL code generator.
55pub struct WgslLowering {
56    config: WgslLoweringConfig,
57    output: String,
58    indent: usize,
59    value_names: HashMap<ValueId, String>,
60    name_counter: usize,
61    block_labels: HashMap<BlockId, String>,
62    #[allow(dead_code)]
63    has_f64_warning: bool,
64}
65
66impl WgslLowering {
67    /// Create a new WGSL lowering pass.
68    pub fn new(config: WgslLoweringConfig) -> Self {
69        Self {
70            config,
71            output: String::new(),
72            indent: 0,
73            value_names: HashMap::new(),
74            name_counter: 0,
75            block_labels: HashMap::new(),
76            has_f64_warning: false,
77        }
78    }
79
80    /// Lower an IR module to WGSL code.
81    pub fn lower(mut self, module: &IrModule) -> Result<String, WgslLoweringError> {
82        // Check capabilities
83        self.check_capabilities(module)?;
84
85        // Generate bindings and structs
86        self.emit_header(module);
87
88        // Generate compute shader
89        self.emit_compute_shader(module)?;
90
91        Ok(self.output)
92    }
93
94    fn check_capabilities(&self, module: &IrModule) -> Result<(), WgslLoweringError> {
95        // Capability tracking for future use
96        let _wgpu_caps = if self.config.subgroups {
97            BackendCapabilities::wgpu_with_subgroups()
98        } else {
99            BackendCapabilities::wgpu_baseline()
100        };
101
102        // Check for f64 usage
103        if module.required_capabilities.has(CapabilityFlag::Float64) && !self.config.downcast_f64 {
104            return Err(WgslLoweringError::UnsupportedCapability(
105                "f64 not supported in WGSL (use downcast_f64 option)".to_string(),
106            ));
107        }
108
109        // Check for atomic64
110        if module.required_capabilities.has(CapabilityFlag::Atomic64)
111            && !self.config.emulate_atomic64
112        {
113            return Err(WgslLoweringError::UnsupportedCapability(
114                "64-bit atomics not supported in WGSL (use emulate_atomic64 option)".to_string(),
115            ));
116        }
117
118        // Check for cooperative groups
119        if module
120            .required_capabilities
121            .has(CapabilityFlag::CooperativeGroups)
122        {
123            return Err(WgslLoweringError::UnsupportedCapability(
124                "Cooperative groups / grid sync not supported in WebGPU".to_string(),
125            ));
126        }
127
128        Ok(())
129    }
130
131    fn emit_header(&mut self, module: &IrModule) {
132        self.emit_line("// Generated by ringkernel-ir WGSL lowering");
133        self.emit_line("");
134
135        // Emit subgroup enable if needed
136        if self.config.subgroups {
137            self.emit_line("enable subgroups;");
138            self.emit_line("");
139        }
140
141        // Emit parameter structs
142        if !module.parameters.is_empty() {
143            self.emit_line("// Parameters");
144            self.emit_line("struct Params {");
145            self.indent += 1;
146            for param in module.parameters.iter() {
147                // Only emit non-pointer params in struct
148                if !matches!(param.ty, IrType::Ptr(_) | IrType::Slice(_)) {
149                    let ty = self.lower_type(&param.ty);
150                    self.emit_line(&format!("{}: {},", param.name, ty));
151                }
152            }
153            self.indent -= 1;
154            self.emit_line("}");
155            self.emit_line("");
156        }
157
158        // Emit bindings
159        self.emit_line("// Bindings");
160        let mut binding_idx = 0;
161
162        // Uniform buffer for params
163        let has_uniforms = module
164            .parameters
165            .iter()
166            .any(|p| !matches!(p.ty, IrType::Ptr(_) | IrType::Slice(_)));
167        if has_uniforms {
168            self.emit_line(&format!(
169                "@group(0) @binding({}) var<uniform> params: Params;",
170                binding_idx
171            ));
172            binding_idx += 1;
173        }
174
175        // Storage buffers for pointers/slices
176        for param in &module.parameters {
177            if let IrType::Ptr(inner) | IrType::Slice(inner) = &param.ty {
178                let elem_ty = self.lower_type(inner);
179                self.emit_line(&format!(
180                    "@group(0) @binding({}) var<storage, read_write> {}: array<{}>;",
181                    binding_idx, param.name, elem_ty
182                ));
183                binding_idx += 1;
184            }
185        }
186
187        self.emit_line("");
188    }
189
190    fn emit_compute_shader(&mut self, module: &IrModule) -> Result<(), WgslLoweringError> {
191        // Assign names
192        self.assign_names(module);
193
194        // Workgroup size
195        let (wx, wy, wz) = self.config.workgroup_size;
196
197        self.emit_line(&format!("@compute @workgroup_size({}, {}, {})", wx, wy, wz));
198        self.emit_line(&format!("fn {}(", module.name));
199        self.indent += 1;
200        self.emit_line("@builtin(global_invocation_id) global_id: vec3<u32>,");
201        self.emit_line("@builtin(local_invocation_id) local_id: vec3<u32>,");
202        self.emit_line("@builtin(workgroup_id) workgroup_id: vec3<u32>,");
203        self.emit_line("@builtin(num_workgroups) num_workgroups: vec3<u32>,");
204        self.indent -= 1;
205        self.emit_line(") {");
206        self.indent += 1;
207
208        // Emit blocks
209        self.emit_block(module, module.entry_block)?;
210
211        self.indent -= 1;
212        self.emit_line("}");
213
214        Ok(())
215    }
216
217    fn assign_names(&mut self, module: &IrModule) {
218        for param in &module.parameters {
219            // For pointer/slice params, they become array accesses
220            self.value_names.insert(param.value_id, param.name.clone());
221        }
222
223        for (block_id, block) in &module.blocks {
224            self.block_labels.insert(*block_id, block.label.clone());
225        }
226    }
227
228    fn emit_block(
229        &mut self,
230        module: &IrModule,
231        block_id: BlockId,
232    ) -> Result<(), WgslLoweringError> {
233        let block = module
234            .blocks
235            .get(&block_id)
236            .ok_or(WgslLoweringError::UndefinedBlock(block_id))?;
237
238        // Note: WGSL doesn't have goto, so we use structured control flow
239        // For now, emit as a sequence with comments for block labels
240        if block_id != module.entry_block {
241            self.emit_line(&format!("// Block: {}", block.label));
242        }
243
244        // Instructions
245        for inst in &block.instructions {
246            self.emit_instruction(module, &inst.result, &inst.result_type, &inst.node)?;
247        }
248
249        // Terminator
250        if let Some(term) = &block.terminator {
251            self.emit_terminator(module, term)?;
252        }
253
254        Ok(())
255    }
256
257    fn emit_instruction(
258        &mut self,
259        _module: &IrModule,
260        result: &ValueId,
261        result_type: &IrType,
262        node: &IrNode,
263    ) -> Result<(), WgslLoweringError> {
264        let result_name = self.get_or_create_name(*result);
265        let ty = self.lower_type(result_type);
266
267        match node {
268            // Constants
269            IrNode::Constant(c) => {
270                let val = self.lower_constant(c);
271                self.emit_line(&format!("var {}: {} = {};", result_name, ty, val));
272            }
273
274            // Binary operations
275            IrNode::BinaryOp(op, lhs, rhs) => {
276                let lhs_name = self.get_value_name(*lhs);
277                let rhs_name = self.get_value_name(*rhs);
278                let expr = self.lower_binary_op(op, &lhs_name, &rhs_name);
279                self.emit_line(&format!("var {}: {} = {};", result_name, ty, expr));
280            }
281
282            // Unary operations
283            IrNode::UnaryOp(op, val) => {
284                let val_name = self.get_value_name(*val);
285                let expr = self.lower_unary_op(op, &val_name);
286                self.emit_line(&format!("var {}: {} = {};", result_name, ty, expr));
287            }
288
289            // Comparisons
290            IrNode::Compare(op, lhs, rhs) => {
291                let lhs_name = self.get_value_name(*lhs);
292                let rhs_name = self.get_value_name(*rhs);
293                let cmp_op = self.lower_compare_op(op);
294                self.emit_line(&format!(
295                    "var {}: bool = {} {} {};",
296                    result_name, lhs_name, cmp_op, rhs_name
297                ));
298            }
299
300            // Memory operations
301            IrNode::Load(ptr) => {
302                let ptr_name = self.get_value_name(*ptr);
303                // In WGSL, arrays use [] indexing
304                self.emit_line(&format!("var {}: {} = {};", result_name, ty, ptr_name));
305            }
306
307            IrNode::Store(ptr, val) => {
308                let ptr_name = self.get_value_name(*ptr);
309                let val_name = self.get_value_name(*val);
310                self.emit_line(&format!("{} = {};", ptr_name, val_name));
311            }
312
313            IrNode::GetElementPtr(ptr, indices) => {
314                let ptr_name = self.get_value_name(*ptr);
315                let idx_name = self.get_value_name(indices[0]);
316                // In WGSL, this becomes an array index
317                self.emit_line(&format!(
318                    "var {}: {} = {}[{}];",
319                    result_name, ty, ptr_name, idx_name
320                ));
321            }
322
323            IrNode::SharedAlloc(_elem_ty, _count) => {
324                // In WGSL, workgroup vars are declared at module scope
325                // For now, emit a comment
326                self.emit_line(&format!("// Workgroup var: {}", result_name));
327            }
328
329            // GPU indexing
330            IrNode::ThreadId(dim) => {
331                let idx = self.lower_dimension(dim, "local_id");
332                self.emit_line(&format!("var {}: {} = {};", result_name, ty, idx));
333            }
334
335            IrNode::BlockId(dim) => {
336                let idx = self.lower_dimension(dim, "workgroup_id");
337                self.emit_line(&format!("var {}: {} = {};", result_name, ty, idx));
338            }
339
340            IrNode::BlockDim(dim) => {
341                // In WGSL, workgroup size is a compile-time constant
342                let size = match dim {
343                    Dimension::X => self.config.workgroup_size.0,
344                    Dimension::Y => self.config.workgroup_size.1,
345                    Dimension::Z => self.config.workgroup_size.2,
346                };
347                self.emit_line(&format!("var {}: {} = {}u;", result_name, ty, size));
348            }
349
350            IrNode::GridDim(dim) => {
351                let idx = self.lower_dimension(dim, "num_workgroups");
352                self.emit_line(&format!("var {}: {} = {};", result_name, ty, idx));
353            }
354
355            IrNode::GlobalThreadId(dim) => {
356                let idx = self.lower_dimension(dim, "global_id");
357                self.emit_line(&format!("var {}: {} = {};", result_name, ty, idx));
358            }
359
360            IrNode::WarpId => {
361                // Approximate warp ID
362                self.emit_line(&format!("var {}: {} = local_id.x / 32u;", result_name, ty));
363            }
364
365            IrNode::LaneId => {
366                self.emit_line(&format!("var {}: {} = local_id.x % 32u;", result_name, ty));
367            }
368
369            // Synchronization
370            IrNode::Barrier => {
371                self.emit_line("workgroupBarrier();");
372            }
373
374            IrNode::MemoryFence(_scope) => {
375                self.emit_line("storageBarrier();");
376            }
377
378            IrNode::GridSync => {
379                return Err(WgslLoweringError::UnsupportedOperation(
380                    "Grid sync not supported in WGSL".to_string(),
381                ));
382            }
383
384            // Atomics
385            IrNode::Atomic(op, ptr, val) => {
386                let ptr_name = self.get_value_name(*ptr);
387                let val_name = self.get_value_name(*val);
388                let atomic_fn = match op {
389                    AtomicOp::Add => "atomicAdd",
390                    AtomicOp::Sub => "atomicSub",
391                    AtomicOp::Exchange => "atomicExchange",
392                    AtomicOp::Min => "atomicMin",
393                    AtomicOp::Max => "atomicMax",
394                    AtomicOp::And => "atomicAnd",
395                    AtomicOp::Or => "atomicOr",
396                    AtomicOp::Xor => "atomicXor",
397                    AtomicOp::Load => "atomicLoad",
398                    AtomicOp::Store => {
399                        self.emit_line(&format!("atomicStore(&{}, {});", ptr_name, val_name));
400                        return Ok(());
401                    }
402                };
403                self.emit_line(&format!(
404                    "var {}: {} = {}(&{}, {});",
405                    result_name, ty, atomic_fn, ptr_name, val_name
406                ));
407            }
408
409            IrNode::AtomicCas(ptr, expected, desired) => {
410                let ptr_name = self.get_value_name(*ptr);
411                let exp_name = self.get_value_name(*expected);
412                let des_name = self.get_value_name(*desired);
413                self.emit_line(&format!(
414                    "var {}: {} = atomicCompareExchangeWeak(&{}, {}, {}).old_value;",
415                    result_name, ty, ptr_name, exp_name, des_name
416                ));
417            }
418
419            // Warp/subgroup operations
420            IrNode::WarpVote(op, val) => {
421                if !self.config.subgroups {
422                    return Err(WgslLoweringError::UnsupportedOperation(
423                        "Subgroup operations require subgroups feature".to_string(),
424                    ));
425                }
426                let val_name = self.get_value_name(*val);
427                let vote_fn = match op {
428                    WarpVoteOp::All => "subgroupAll",
429                    WarpVoteOp::Any => "subgroupAny",
430                    WarpVoteOp::Ballot => "subgroupBallot",
431                };
432                self.emit_line(&format!(
433                    "var {}: {} = {}({});",
434                    result_name, ty, vote_fn, val_name
435                ));
436            }
437
438            IrNode::WarpShuffle(op, val, lane) => {
439                if !self.config.subgroups {
440                    return Err(WgslLoweringError::UnsupportedOperation(
441                        "Subgroup shuffle requires subgroups feature".to_string(),
442                    ));
443                }
444                let val_name = self.get_value_name(*val);
445                let lane_name = self.get_value_name(*lane);
446                let shfl_fn = match op {
447                    WarpShuffleOp::Index => "subgroupShuffle",
448                    WarpShuffleOp::Up => "subgroupShuffleUp",
449                    WarpShuffleOp::Down => "subgroupShuffleDown",
450                    WarpShuffleOp::Xor => "subgroupShuffleXor",
451                };
452                self.emit_line(&format!(
453                    "var {}: {} = {}({}, {});",
454                    result_name, ty, shfl_fn, val_name, lane_name
455                ));
456            }
457
458            // Select
459            IrNode::Select(cond, then_val, else_val) => {
460                let cond_name = self.get_value_name(*cond);
461                let then_name = self.get_value_name(*then_val);
462                let else_name = self.get_value_name(*else_val);
463                self.emit_line(&format!(
464                    "var {}: {} = select({}, {}, {});",
465                    result_name, ty, else_name, then_name, cond_name
466                ));
467            }
468
469            // Math functions
470            IrNode::Math(op, args) => {
471                let fn_name = self.lower_math_op(op);
472                let args_str: Vec<String> = args.iter().map(|a| self.get_value_name(*a)).collect();
473                self.emit_line(&format!(
474                    "var {}: {} = {}({});",
475                    result_name,
476                    ty,
477                    fn_name,
478                    args_str.join(", ")
479                ));
480            }
481
482            // Skip nodes that don't produce WGSL output
483            IrNode::Parameter(_) | IrNode::Undef | IrNode::Phi(_) => {}
484
485            // Messaging (not supported in WGSL)
486            IrNode::K2HEnqueue(_)
487            | IrNode::H2KDequeue
488            | IrNode::H2KIsEmpty
489            | IrNode::K2KSend(_, _)
490            | IrNode::K2KRecv
491            | IrNode::K2KTryRecv
492            | IrNode::HlcNow
493            | IrNode::HlcTick
494            | IrNode::HlcUpdate(_) => {
495                self.emit_line(&format!("// Not supported in WGSL: {:?}", node));
496            }
497
498            _ => {
499                self.emit_line(&format!("// Unhandled: {:?}", node));
500            }
501        }
502
503        Ok(())
504    }
505
506    fn emit_terminator(
507        &mut self,
508        module: &IrModule,
509        term: &Terminator,
510    ) -> Result<(), WgslLoweringError> {
511        match term {
512            Terminator::Return(None) => {
513                self.emit_line("return;");
514            }
515            Terminator::Return(Some(val)) => {
516                // Compute shaders don't return values
517                let val_name = self.get_value_name(*val);
518                self.emit_line(&format!("// Return: {}", val_name));
519                self.emit_line("return;");
520            }
521            Terminator::Branch(target) => {
522                // WGSL doesn't have goto, emit the target block inline
523                self.emit_block(module, *target)?;
524            }
525            Terminator::CondBranch(cond, then_block, else_block) => {
526                let cond_name = self.get_value_name(*cond);
527                self.emit_line(&format!("if ({}) {{", cond_name));
528                self.indent += 1;
529                self.emit_block(module, *then_block)?;
530                self.indent -= 1;
531                self.emit_line("} else {");
532                self.indent += 1;
533                self.emit_block(module, *else_block)?;
534                self.indent -= 1;
535                self.emit_line("}");
536            }
537            Terminator::Switch(val, default, cases) => {
538                let val_name = self.get_value_name(*val);
539                self.emit_line(&format!("switch ({}) {{", val_name));
540                self.indent += 1;
541                for (case_val, target) in cases {
542                    let case_str = self.lower_constant(case_val);
543                    self.emit_line(&format!("case {}: {{", case_str));
544                    self.indent += 1;
545                    self.emit_block(module, *target)?;
546                    self.indent -= 1;
547                    self.emit_line("}");
548                }
549                self.emit_line("default: {");
550                self.indent += 1;
551                self.emit_block(module, *default)?;
552                self.indent -= 1;
553                self.emit_line("}");
554                self.indent -= 1;
555                self.emit_line("}");
556            }
557            Terminator::Unreachable => {
558                self.emit_line("// unreachable");
559            }
560        }
561        Ok(())
562    }
563
564    fn lower_type(&self, ty: &IrType) -> String {
565        match ty {
566            IrType::Void => "void".to_string(),
567            IrType::Scalar(s) => self.lower_scalar_type(s),
568            IrType::Vector(v) => format!("vec{}<{}>", v.count, self.lower_scalar_type(&v.element)),
569            IrType::Ptr(inner) => format!("ptr<storage, {}>", self.lower_type(inner)),
570            IrType::Array(inner, size) => format!("array<{}, {}>", self.lower_type(inner), size),
571            IrType::Slice(inner) => format!("array<{}>", self.lower_type(inner)),
572            IrType::Struct(s) => s.name.clone(),
573            IrType::Function(_) => "void".to_string(),
574        }
575    }
576
577    fn lower_scalar_type(&self, ty: &ScalarType) -> String {
578        match ty {
579            ScalarType::Bool => "bool",
580            ScalarType::I8 | ScalarType::I16 | ScalarType::I32 => "i32",
581            // WGSL doesn't have i64, always downcast
582            ScalarType::I64 => "i32",
583            ScalarType::U8 | ScalarType::U16 | ScalarType::U32 => "u32",
584            ScalarType::U64 => "u32", // Downcast
585            ScalarType::F16 => "f16",
586            ScalarType::F32 => "f32",
587            ScalarType::F64 => "f32", // Downcast (WGSL doesn't support f64)
588        }
589        .to_string()
590    }
591
592    fn lower_constant(&self, c: &ConstantValue) -> String {
593        match c {
594            ConstantValue::Bool(b) => if *b { "true" } else { "false" }.to_string(),
595            ConstantValue::I32(v) => format!("{}i", v),
596            ConstantValue::I64(v) => format!("{}i", *v as i32), // Downcast
597            ConstantValue::U32(v) => format!("{}u", v),
598            ConstantValue::U64(v) => format!("{}u", *v as u32), // Downcast
599            ConstantValue::F32(v) => {
600                if v.is_nan() {
601                    "0.0f".to_string()
602                } else if v.is_infinite() {
603                    if *v > 0.0 { "1e38f" } else { "-1e38f" }.to_string()
604                } else {
605                    format!("{}f", v)
606                }
607            }
608            ConstantValue::F64(v) => format!("{}f", *v as f32), // Downcast
609            ConstantValue::Null => "0u".to_string(),
610            ConstantValue::Array(elems) => {
611                let elems_str: Vec<String> = elems.iter().map(|e| self.lower_constant(e)).collect();
612                format!("array({})", elems_str.join(", "))
613            }
614            ConstantValue::Struct(fields) => {
615                let fields_str: Vec<String> =
616                    fields.iter().map(|f| self.lower_constant(f)).collect();
617                format!("({})", fields_str.join(", "))
618            }
619        }
620    }
621
622    fn lower_binary_op(&self, op: &BinaryOp, lhs: &str, rhs: &str) -> String {
623        match op {
624            BinaryOp::Add => format!("({} + {})", lhs, rhs),
625            BinaryOp::Sub => format!("({} - {})", lhs, rhs),
626            BinaryOp::Mul => format!("({} * {})", lhs, rhs),
627            BinaryOp::Div => format!("({} / {})", lhs, rhs),
628            BinaryOp::Rem => format!("({} % {})", lhs, rhs),
629            BinaryOp::And => format!("({} & {})", lhs, rhs),
630            BinaryOp::Or => format!("({} | {})", lhs, rhs),
631            BinaryOp::Xor => format!("({} ^ {})", lhs, rhs),
632            BinaryOp::Shl => format!("({} << {})", lhs, rhs),
633            BinaryOp::Shr => format!("({} >> {})", lhs, rhs),
634            BinaryOp::Sar => format!("({} >> {})", lhs, rhs),
635            BinaryOp::Fma => format!("fma({}, {}, 0.0)", lhs, rhs),
636            BinaryOp::Pow => format!("pow({}, {})", lhs, rhs),
637            BinaryOp::Min => format!("min({}, {})", lhs, rhs),
638            BinaryOp::Max => format!("max({}, {})", lhs, rhs),
639        }
640    }
641
642    fn lower_unary_op(&self, op: &UnaryOp, val: &str) -> String {
643        match op {
644            UnaryOp::Neg => format!("(-{})", val),
645            UnaryOp::Not => format!("(~{})", val),
646            UnaryOp::LogicalNot => format!("(!{})", val),
647            UnaryOp::Abs => format!("abs({})", val),
648            UnaryOp::Sqrt => format!("sqrt({})", val),
649            UnaryOp::Rsqrt => format!("inverseSqrt({})", val),
650            UnaryOp::Floor => format!("floor({})", val),
651            UnaryOp::Ceil => format!("ceil({})", val),
652            UnaryOp::Round => format!("round({})", val),
653            UnaryOp::Trunc => format!("trunc({})", val),
654            UnaryOp::Sign => format!("sign({})", val),
655        }
656    }
657
658    fn lower_compare_op(&self, op: &CompareOp) -> &'static str {
659        match op {
660            CompareOp::Eq => "==",
661            CompareOp::Ne => "!=",
662            CompareOp::Lt => "<",
663            CompareOp::Le => "<=",
664            CompareOp::Gt => ">",
665            CompareOp::Ge => ">=",
666        }
667    }
668
669    fn lower_dimension(&self, dim: &Dimension, prefix: &str) -> String {
670        match dim {
671            Dimension::X => format!("{}.x", prefix),
672            Dimension::Y => format!("{}.y", prefix),
673            Dimension::Z => format!("{}.z", prefix),
674        }
675    }
676
677    fn lower_math_op(&self, op: &MathOp) -> &'static str {
678        match op {
679            MathOp::Sin => "sin",
680            MathOp::Cos => "cos",
681            MathOp::Tan => "tan",
682            MathOp::Asin => "asin",
683            MathOp::Acos => "acos",
684            MathOp::Atan => "atan",
685            MathOp::Atan2 => "atan2",
686            MathOp::Sinh => "sinh",
687            MathOp::Cosh => "cosh",
688            MathOp::Tanh => "tanh",
689            MathOp::Exp => "exp",
690            MathOp::Exp2 => "exp2",
691            MathOp::Log => "log",
692            MathOp::Log2 => "log2",
693            MathOp::Log10 => "log", // log10 not in WGSL, would need emulation
694            MathOp::Lerp => "mix",
695            MathOp::Clamp => "clamp",
696            MathOp::Step => "step",
697            MathOp::SmoothStep => "smoothstep",
698            MathOp::Fract => "fract",
699            MathOp::CopySign => "sign", // Approximate
700        }
701    }
702
703    fn get_value_name(&self, id: ValueId) -> String {
704        self.value_names
705            .get(&id)
706            .cloned()
707            .unwrap_or_else(|| format!("v{}", id.raw()))
708    }
709
710    fn get_or_create_name(&mut self, id: ValueId) -> String {
711        if let Some(name) = self.value_names.get(&id) {
712            return name.clone();
713        }
714        let name = format!("t{}", self.name_counter);
715        self.name_counter += 1;
716        self.value_names.insert(id, name.clone());
717        name
718    }
719
720    fn emit_line(&mut self, line: &str) {
721        let indent = "    ".repeat(self.indent);
722        writeln!(self.output, "{}{}", indent, line).unwrap();
723    }
724}
725
726/// WGSL lowering errors.
727#[derive(Debug, Clone)]
728pub enum WgslLoweringError {
729    /// Unsupported capability.
730    UnsupportedCapability(String),
731    /// Unsupported operation.
732    UnsupportedOperation(String),
733    /// Undefined block reference.
734    UndefinedBlock(BlockId),
735    /// Undefined value reference.
736    UndefinedValue(ValueId),
737    /// Type error.
738    TypeError(String),
739}
740
741impl std::fmt::Display for WgslLoweringError {
742    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
743        match self {
744            WgslLoweringError::UnsupportedCapability(cap) => {
745                write!(f, "Unsupported capability: {}", cap)
746            }
747            WgslLoweringError::UnsupportedOperation(op) => {
748                write!(f, "Unsupported operation: {}", op)
749            }
750            WgslLoweringError::UndefinedBlock(id) => write!(f, "Undefined block: {}", id),
751            WgslLoweringError::UndefinedValue(id) => write!(f, "Undefined value: {}", id),
752            WgslLoweringError::TypeError(msg) => write!(f, "Type error: {}", msg),
753        }
754    }
755}
756
757impl std::error::Error for WgslLoweringError {}
758
759/// Convenience function to lower IR to WGSL.
760pub fn lower_to_wgsl(module: &IrModule) -> Result<String, WgslLoweringError> {
761    WgslLowering::new(WgslLoweringConfig::default()).lower(module)
762}
763
764/// Lower IR to WGSL with custom config.
765pub fn lower_to_wgsl_with_config(
766    module: &IrModule,
767    config: WgslLoweringConfig,
768) -> Result<String, WgslLoweringError> {
769    WgslLowering::new(config).lower(module)
770}
771
#[cfg(test)]
mod tests {
    use super::*;
    use crate::IrBuilder;

    #[test]
    fn test_lower_simple_kernel() {
        let mut builder = IrBuilder::new("add_one");

        let _x = builder.parameter("x", IrType::slice(IrType::F32));
        let _n = builder.parameter("n", IrType::I32);

        let _idx = builder.global_thread_id(Dimension::X);

        builder.ret();

        let module = builder.build();
        let wgsl = lower_to_wgsl(&module).unwrap();

        assert!(wgsl.contains("@compute @workgroup_size"));
        assert!(wgsl.contains("fn add_one"));
        assert!(wgsl.contains("global_id"));
    }

    #[test]
    fn test_lower_bindings_layout() {
        // Scalars are packed into the uniform `Params` struct at binding 0; buffers follow.
        let mut builder = IrBuilder::new("bindings");

        let _x = builder.parameter("x", IrType::slice(IrType::F32));
        let _n = builder.parameter("n", IrType::I32);

        builder.ret();

        let module = builder.build();
        let wgsl = lower_to_wgsl(&module).unwrap();

        assert!(wgsl.contains("struct Params {"));
        assert!(wgsl.contains("n: i32,"));
        assert!(wgsl.contains("@group(0) @binding(0) var<uniform> params: Params;"));
        assert!(wgsl.contains("@group(0) @binding(1) var<storage, read_write> x: array<f32>;"));
    }

    #[test]
    fn test_lower_custom_workgroup_size() {
        let mut builder = IrBuilder::new("sized");
        builder.ret();

        let module = builder.build();
        let config = WgslLoweringConfig::default().with_workgroup_size(64, 2, 1);
        let wgsl = lower_to_wgsl_with_config(&module, config).unwrap();

        assert!(wgsl.contains("@compute @workgroup_size(64, 2, 1)"));
    }

    #[test]
    fn test_lower_with_barrier() {
        let mut builder = IrBuilder::new("sync");

        builder.barrier();
        builder.ret();

        let module = builder.build();
        let wgsl = lower_to_wgsl(&module).unwrap();

        assert!(wgsl.contains("workgroupBarrier()"));
    }

    #[test]
    fn test_lower_with_control_flow() {
        let mut builder = IrBuilder::new("branch");

        let cond = builder.const_bool(true);
        let then_block = builder.create_block("then");
        let else_block = builder.create_block("else");

        builder.cond_branch(cond, then_block, else_block);

        builder.switch_to_block(then_block);
        builder.ret();

        builder.switch_to_block(else_block);
        builder.ret();

        let module = builder.build();
        let wgsl = lower_to_wgsl(&module).unwrap();

        assert!(wgsl.contains("if ("));
        assert!(wgsl.contains("} else {"));
    }

    #[test]
    fn test_lower_rejects_grid_sync() {
        let mut builder = IrBuilder::new("grid");
        builder.grid_sync();
        builder.ret();

        let module = builder.build();
        let result = lower_to_wgsl(&module);

        assert!(result.is_err());
    }

    #[test]
    fn test_lower_with_subgroups() {
        let mut builder = IrBuilder::new("subgroup");
        builder.ret();

        let module = builder.build();
        let config = WgslLoweringConfig::default().with_subgroups();
        let wgsl = lower_to_wgsl_with_config(&module, config).unwrap();

        assert!(wgsl.contains("enable subgroups;"));
    }

    #[test]
    fn test_lower_without_subgroups_omits_enable() {
        let mut builder = IrBuilder::new("plain");
        builder.ret();

        let module = builder.build();
        let wgsl = lower_to_wgsl(&module).unwrap();

        assert!(!wgsl.contains("enable subgroups;"));
    }
}