cubecl_wgpu/compiler/wgsl/
compiler.rs

1use super::Subgroup;
2use super::{ConstantArray, shader::ComputeShader};
3use super::{Item, LocalArray, SharedArray};
4use crate::compiler::wgsl::{self, SharedValue};
5
6use cubecl_common::ExecutionMode;
7use cubecl_common::backtrace::BackTrace;
8use cubecl_core::post_processing::{
9    checked_io::CheckedIoProcessor, saturating::SaturatingArithmeticProcessor,
10};
11use cubecl_core::prelude::*;
12use cubecl_core::{
13    Metadata, WgpuCompilationOptions,
14    ir::{self as cube, Scope},
15    prelude::expand_erf,
16};
17use cubecl_core::{
18    ir::{ConstantScalarValue, Processor, UIntKind},
19    post_processing::unroll::UnrollProcessor,
20};
21use cubecl_runtime::compiler::CompilationError;
22use cubecl_runtime::kernel;
23
/// Largest line (vector) width emitted by this compiler.
///
/// Buffer and variable types wider than this are clamped to this width, and the same
/// value is handed to the unroll processor when a scope is compiled.
pub const MAX_LINE_SIZE: u32 = 4;
25
26/// Wgsl Compiler.
27#[derive(Clone, Default)]
28pub struct WgslCompiler {
29    metadata: Metadata,
30    ext_meta_pos: Vec<u32>,
31    local_invocation_index: bool,
32    local_invocation_id: bool,
33    // TODO: possible cleanup, this bool seems to not be used
34    global_invocation_id: bool,
35    workgroup_id: bool,
36    subgroup_size: bool,
37    subgroup_invocation_id: bool,
38    id: bool,
39    num_workgroups: bool,
40    workgroup_id_no_axis: bool,
41    workgroup_size_no_axis: bool,
42    num_workgroup_no_axis: bool,
43    shared_arrays: Vec<SharedArray>,
44    shared_values: Vec<SharedValue>,
45    const_arrays: Vec<ConstantArray>,
46    local_arrays: Vec<LocalArray>,
47    #[allow(dead_code)]
48    compilation_options: WgpuCompilationOptions,
49    strategy: ExecutionMode,
50    subgroup_instructions_used: bool,
51    f16_used: bool,
52}
53
54impl core::fmt::Debug for WgslCompiler {
55    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56        f.write_str("WgslCompiler")
57    }
58}
59
60impl cubecl_core::Compiler for WgslCompiler {
61    type Representation = ComputeShader;
62    type CompilationOptions = WgpuCompilationOptions;
63
64    fn compile(
65        &mut self,
66        shader: kernel::KernelDefinition,
67        compilation_options: &Self::CompilationOptions,
68        mode: ExecutionMode,
69    ) -> Result<Self::Representation, CompilationError> {
70        self.compilation_options = compilation_options.clone();
71        self.compile_shader(shader, mode)
72    }
73
74    fn elem_size(&self, elem: cube::ElemType) -> usize {
75        elem.size()
76    }
77
78    fn extension(&self) -> &'static str {
79        "wgsl"
80    }
81}
82
83impl WgslCompiler {
84    fn compile_shader(
85        &mut self,
86        mut value: kernel::KernelDefinition,
87        mode: ExecutionMode,
88    ) -> Result<wgsl::ComputeShader, CompilationError> {
89        let errors = value.body.pop_errors();
90        if !errors.is_empty() {
91            let mut reason = "Can't compile wgsl kernel".to_string();
92            for error in errors {
93                reason += error.as_str();
94                reason += "\n";
95            }
96
97            return Err(CompilationError::Validation {
98                reason,
99                backtrace: BackTrace::capture(),
100            });
101        }
102
103        self.strategy = mode;
104
105        let num_meta = value.buffers.len();
106
107        self.ext_meta_pos = Vec::new();
108        let mut num_ext = 0;
109
110        for binding in value.buffers.iter() {
111            self.ext_meta_pos.push(num_ext);
112            if binding.has_extended_meta {
113                num_ext += 1;
114            }
115        }
116
117        self.metadata = Metadata::new(num_meta as u32, num_ext);
118
119        let instructions = self.compile_scope(&mut value.body);
120        let extensions = register_extensions(&instructions);
121        let body = wgsl::Body {
122            instructions,
123            id: self.id,
124        };
125
126        Ok(wgsl::ComputeShader {
127            buffers: value
128                .buffers
129                .into_iter()
130                .map(|mut it| {
131                    // This is safe when combined with the unroll transform that adjusts all indices.
132                    // Must not be used alone
133                    if it.ty.line_size() > MAX_LINE_SIZE {
134                        it.ty = it.ty.line(MAX_LINE_SIZE);
135                    }
136                    self.compile_binding(it)
137                })
138                .collect(),
139            scalars: value
140                .scalars
141                .into_iter()
142                .map(|binding| (self.compile_storage_type(binding.ty), binding.count))
143                .collect(),
144            shared_arrays: self.shared_arrays.clone(),
145            shared_values: self.shared_values.clone(),
146            constant_arrays: self.const_arrays.clone(),
147            local_arrays: self.local_arrays.clone(),
148            has_metadata: self.metadata.static_len() > 0,
149            workgroup_size: value.cube_dim,
150            global_invocation_id: self.global_invocation_id || self.id,
151            local_invocation_index: self.local_invocation_index,
152            local_invocation_id: self.local_invocation_id,
153            num_workgroups: self.id
154                || self.num_workgroups
155                || self.num_workgroup_no_axis
156                || self.workgroup_id_no_axis,
157            workgroup_id: self.workgroup_id || self.workgroup_id_no_axis,
158            subgroup_size: self.subgroup_size,
159            subgroup_invocation_id: self.subgroup_invocation_id,
160            body,
161            extensions,
162            num_workgroups_no_axis: self.num_workgroup_no_axis,
163            workgroup_id_no_axis: self.workgroup_id_no_axis,
164            workgroup_size_no_axis: self.workgroup_size_no_axis,
165            subgroup_instructions_used: self.subgroup_instructions_used,
166            f16_used: self.f16_used,
167            kernel_name: value.options.kernel_name,
168        })
169    }
170
171    fn compile_type(&mut self, item: cube::Type) -> Item {
172        match item {
173            cube::Type::Scalar(ty) => wgsl::Item::Scalar(self.compile_storage_type(ty)),
174            cube::Type::Line(ty, size) => {
175                let elem = self.compile_storage_type(ty);
176                match size {
177                    2 => wgsl::Item::Vec2(elem),
178                    3 => wgsl::Item::Vec3(elem),
179                    4 => wgsl::Item::Vec4(elem),
180                    _ => panic!("Unsupported vectorizations scheme {:?}", item.line_size()),
181                }
182            }
183            cube::Type::Semantic(_) => unimplemented!("Can't compile semantic type"),
184        }
185    }
186
187    fn compile_storage_type(&mut self, ty: cube::StorageType) -> wgsl::Elem {
188        match ty {
189            cube::StorageType::Scalar(ty) => self.compile_elem(ty),
190            cube::StorageType::Atomic(ty) => match ty {
191                cube::ElemType::Float(i) => match i {
192                    cube::FloatKind::F32 => wgsl::Elem::AtomicF32,
193                    kind => panic!("atomic<{kind:?}> is not a valid WgpuElement"),
194                },
195                cube::ElemType::Int(i) => match i {
196                    cube::IntKind::I32 => wgsl::Elem::AtomicI32,
197                    kind => panic!("atomic<{kind:?}> is not a valid WgpuElement"),
198                },
199                cube::ElemType::UInt(kind) => match kind {
200                    cube::UIntKind::U32 => wgsl::Elem::AtomicU32,
201                    kind => panic!("{kind:?} is not a valid WgpuElement"),
202                },
203                other => panic!("{other:?} is not a valid WgpuElement"),
204            },
205            cube::StorageType::Packed(_, _) => {
206                unimplemented!("Packed types not yet supported in WGSL")
207            }
208            cube::StorageType::Opaque(ty) => match ty {
209                cube::OpaqueType::Barrier(_) => {
210                    unimplemented!("Barrier objects not supported in WGSL")
211                }
212            },
213        }
214    }
215
216    fn compile_elem(&mut self, value: cube::ElemType) -> wgsl::Elem {
217        match value {
218            cube::ElemType::Float(f) => match f {
219                cube::FloatKind::E2M1
220                | cube::FloatKind::E2M3
221                | cube::FloatKind::E3M2
222                | cube::FloatKind::E4M3
223                | cube::FloatKind::E5M2
224                | cube::FloatKind::UE8M0 => panic!("Minifloat is not a valid WgpuElement"),
225                cube::FloatKind::F16 => {
226                    self.f16_used = true;
227                    wgsl::Elem::F16
228                }
229                cube::FloatKind::BF16 => panic!("bf16 is not a valid WgpuElement"),
230                cube::FloatKind::TF32 => panic!("tf32 is not a valid WgpuElement"),
231                cube::FloatKind::Flex32 => wgsl::Elem::F32,
232                cube::FloatKind::F32 => wgsl::Elem::F32,
233                cube::FloatKind::F64 => wgsl::Elem::F64,
234            },
235            cube::ElemType::Int(i) => match i {
236                cube::IntKind::I32 => wgsl::Elem::I32,
237                cube::IntKind::I64 => wgsl::Elem::I64,
238                kind => panic!("{kind:?} is not a valid WgpuElement"),
239            },
240            cube::ElemType::UInt(kind) => match kind {
241                cube::UIntKind::U32 => wgsl::Elem::U32,
242                cube::UIntKind::U64 => wgsl::Elem::U64,
243                kind => panic!("{kind:?} is not a valid WgpuElement"),
244            },
245            cube::ElemType::Bool => wgsl::Elem::Bool,
246        }
247    }
248
249    fn ext_meta_pos(&self, var: &cube::Variable) -> u32 {
250        let pos = var.index().expect("Variable should have index");
251        self.ext_meta_pos[pos as usize]
252    }
253
254    pub(crate) fn compile_variable(&mut self, value: cube::Variable) -> wgsl::Variable {
255        let item = value.ty;
256        match value.kind {
257            cube::VariableKind::GlobalInputArray(id) => {
258                wgsl::Variable::GlobalInputArray(id, self.compile_type(item))
259            }
260            cube::VariableKind::GlobalScalar(id) => {
261                wgsl::Variable::GlobalScalar(id, self.compile_storage_type(item.storage_type()))
262            }
263            cube::VariableKind::LocalMut { id } | cube::VariableKind::Versioned { id, .. } => {
264                wgsl::Variable::LocalMut {
265                    id,
266                    item: self.compile_type(item),
267                }
268            }
269            cube::VariableKind::LocalConst { id } => wgsl::Variable::LocalConst {
270                id,
271                item: self.compile_type(item),
272            },
273            cube::VariableKind::GlobalOutputArray(id) => {
274                wgsl::Variable::GlobalOutputArray(id, self.compile_type(item))
275            }
276            cube::VariableKind::ConstantScalar(value) => {
277                wgsl::Variable::ConstantScalar(value, self.compile_elem(value.elem_type()))
278            }
279            cube::VariableKind::SharedArray {
280                id,
281                length,
282                unroll_factor,
283                alignment,
284            } => {
285                let item = self.compile_type(item);
286                if !self.shared_arrays.iter().any(|s| s.index == id) {
287                    self.shared_arrays.push(SharedArray::new(
288                        id,
289                        item,
290                        length * unroll_factor,
291                        alignment,
292                    ));
293                }
294                wgsl::Variable::SharedArray(id, item, length)
295            }
296            cube::VariableKind::Shared { id } => {
297                let item = self.compile_type(item);
298                if !self.shared_values.iter().any(|s| s.index == id) {
299                    self.shared_values.push(SharedValue::new(id, item));
300                }
301                wgsl::Variable::SharedValue(id, item)
302            }
303            cube::VariableKind::ConstantArray { id, length, .. } => {
304                let item = self.compile_type(item);
305                wgsl::Variable::ConstantArray(id, item, length)
306            }
307            cube::VariableKind::LocalArray {
308                id,
309                length,
310                unroll_factor,
311            } => {
312                let item = self.compile_type(item);
313                if !self.local_arrays.iter().any(|s| s.index == id) {
314                    self.local_arrays
315                        .push(LocalArray::new(id, item, length * unroll_factor));
316                }
317                wgsl::Variable::LocalArray(id, item, length)
318            }
319            cube::VariableKind::Builtin(builtin) => match builtin {
320                cube::Builtin::AbsolutePos => {
321                    self.id = true;
322                    wgsl::Variable::Id
323                }
324                cube::Builtin::UnitPos => {
325                    self.local_invocation_index = true;
326                    wgsl::Variable::LocalInvocationIndex
327                }
328                cube::Builtin::UnitPosX => {
329                    self.local_invocation_id = true;
330                    wgsl::Variable::LocalInvocationIdX
331                }
332                cube::Builtin::UnitPosY => {
333                    self.local_invocation_id = true;
334                    wgsl::Variable::LocalInvocationIdY
335                }
336                cube::Builtin::UnitPosZ => {
337                    self.local_invocation_id = true;
338                    wgsl::Variable::LocalInvocationIdZ
339                }
340                cube::Builtin::CubePosX => {
341                    self.workgroup_id = true;
342                    wgsl::Variable::WorkgroupIdX
343                }
344                cube::Builtin::CubePosY => {
345                    self.workgroup_id = true;
346                    wgsl::Variable::WorkgroupIdY
347                }
348                cube::Builtin::CubePosZ => {
349                    self.workgroup_id = true;
350                    wgsl::Variable::WorkgroupIdZ
351                }
352                cube::Builtin::CubePosCluster
353                | cube::Builtin::CubePosClusterX
354                | cube::Builtin::CubePosClusterY
355                | cube::Builtin::CubePosClusterZ => self.constant_var(1),
356                cube::Builtin::AbsolutePosX => {
357                    self.global_invocation_id = true;
358                    wgsl::Variable::GlobalInvocationIdX
359                }
360                cube::Builtin::AbsolutePosY => {
361                    self.global_invocation_id = true;
362                    wgsl::Variable::GlobalInvocationIdY
363                }
364                cube::Builtin::AbsolutePosZ => {
365                    self.global_invocation_id = true;
366                    wgsl::Variable::GlobalInvocationIdZ
367                }
368                cube::Builtin::CubeDimX => wgsl::Variable::WorkgroupSizeX,
369                cube::Builtin::CubeDimY => wgsl::Variable::WorkgroupSizeY,
370                cube::Builtin::CubeDimZ => wgsl::Variable::WorkgroupSizeZ,
371                cube::Builtin::CubeClusterDim
372                | cube::Builtin::CubeClusterDimX
373                | cube::Builtin::CubeClusterDimY
374                | cube::Builtin::CubeClusterDimZ => self.constant_var(1),
375                cube::Builtin::CubeCountX => {
376                    self.num_workgroups = true;
377                    wgsl::Variable::NumWorkgroupsX
378                }
379                cube::Builtin::CubeCountY => {
380                    self.num_workgroups = true;
381                    wgsl::Variable::NumWorkgroupsY
382                }
383                cube::Builtin::CubeCountZ => {
384                    self.num_workgroups = true;
385                    wgsl::Variable::NumWorkgroupsZ
386                }
387                cube::Builtin::CubePos => {
388                    self.workgroup_id_no_axis = true;
389                    wgsl::Variable::WorkgroupId
390                }
391                cube::Builtin::CubeDim => {
392                    self.workgroup_size_no_axis = true;
393                    wgsl::Variable::WorkgroupSize
394                }
395                cube::Builtin::CubeCount => {
396                    self.num_workgroup_no_axis = true;
397                    wgsl::Variable::NumWorkgroups
398                }
399                cube::Builtin::PlaneDim => {
400                    self.subgroup_size = true;
401                    wgsl::Variable::SubgroupSize
402                }
403                cube::Builtin::UnitPosPlane => {
404                    self.subgroup_invocation_id = true;
405                    wgsl::Variable::SubgroupInvocationId
406                }
407            },
408            cube::VariableKind::Matrix { .. } => {
409                panic!("Cooperative matrix-multiply and accumulate not supported.")
410            }
411            cube::VariableKind::Pipeline { .. } => {
412                panic!("Pipeline not supported.")
413            }
414            cube::VariableKind::BarrierToken { .. } => {
415                panic!("Barrier not supported.")
416            }
417            cube::VariableKind::TensorMapInput(_) => panic!("Tensor map not supported."),
418            cube::VariableKind::TensorMapOutput(_) => panic!("Tensor map not supported."),
419        }
420    }
421
422    fn constant_var(&mut self, value: u32) -> wgsl::Variable {
423        let var = cube::Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U32));
424        self.compile_variable(var)
425    }
426
427    fn compile_scope(&mut self, scope: &mut cube::Scope) -> Vec<wgsl::Instruction> {
428        let mut instructions = Vec::new();
429
430        let const_arrays = scope
431            .const_arrays
432            .drain(..)
433            .map(|(var, values)| ConstantArray {
434                index: var.index().unwrap(),
435                item: self.compile_type(var.ty),
436                size: values.len() as u32,
437                values: values
438                    .into_iter()
439                    .map(|val| self.compile_variable(val))
440                    .collect(),
441            })
442            .collect::<Vec<_>>();
443        self.const_arrays.extend(const_arrays);
444
445        let checked_io: Box<dyn Processor> = Box::new(CheckedIoProcessor::new(self.strategy));
446        let unroll = Box::new(UnrollProcessor::new(MAX_LINE_SIZE));
447        let saturating = Box::new(SaturatingArithmeticProcessor::new(true));
448        let processing = scope.process([&*unroll, &*checked_io, &*saturating]);
449
450        for mut var in processing.variables {
451            if var.ty.line_size() > MAX_LINE_SIZE {
452                var.ty = var.ty.line(MAX_LINE_SIZE);
453            }
454            instructions.push(wgsl::Instruction::DeclareVariable {
455                var: self.compile_variable(var),
456            });
457        }
458
459        processing
460            .instructions
461            .into_iter()
462            .for_each(|op| self.compile_operation(&mut instructions, op.operation, op.out, scope));
463
464        instructions
465    }
466
467    fn compile_operation(
468        &mut self,
469        instructions: &mut Vec<wgsl::Instruction>,
470        operation: cube::Operation,
471        out: Option<cube::Variable>,
472        scope: &mut cube::Scope,
473    ) {
474        match operation {
475            cube::Operation::Copy(variable) => instructions.push(wgsl::Instruction::Assign {
476                input: self.compile_variable(variable),
477                out: self.compile_variable(out.unwrap()),
478            }),
479            cube::Operation::Arithmetic(op) => {
480                self.compile_arithmetic(op, out, instructions, scope)
481            }
482            cube::Operation::Comparison(op) => self.compile_cmp(op, out, instructions),
483            cube::Operation::Bitwise(op) => self.compile_bitwise(op, out, instructions),
484            cube::Operation::Operator(op) => self.compile_operator(op, out, instructions),
485            cube::Operation::Atomic(op) => instructions.push(self.compile_atomic(op, out)),
486            cube::Operation::Metadata(op) => instructions.push(self.compile_metadata(op, out)),
487            cube::Operation::Branch(val) => self.compile_branch(instructions, val),
488            cube::Operation::Synchronization(val) => {
489                self.compile_synchronization(instructions, val)
490            }
491            cube::Operation::Plane(op) => self.compile_subgroup(instructions, op, out),
492            cube::Operation::CoopMma(_) => {
493                panic!("Cooperative matrix-multiply and accumulate isn't supported on wgpu.")
494            }
495            cube::Operation::NonSemantic(cube::NonSemantic::Comment { content }) => {
496                self.compile_comment(instructions, content)
497            }
498            cube::Operation::NonSemantic(_) => {}
499            cube::Operation::Barrier(_) => {
500                panic!("Barrier isn't supported on wgpu.")
501            }
502            cube::Operation::Tma(_) => panic!("TMA isn't supported on wgpu."),
503            cube::Operation::Marker(_) => {}
504        }
505    }
506
507    fn compile_subgroup(
508        &mut self,
509        instructions: &mut Vec<wgsl::Instruction>,
510        subgroup: cube::Plane,
511        out: Option<cube::Variable>,
512    ) {
513        self.subgroup_instructions_used = true;
514
515        let out = out.unwrap();
516        let op = match subgroup {
517            cube::Plane::Elect => Subgroup::Elect {
518                out: self.compile_variable(out),
519            },
520            cube::Plane::All(op) => Subgroup::All {
521                input: self.compile_variable(op.input),
522                out: self.compile_variable(out),
523            },
524            cube::Plane::Any(op) => Subgroup::Any {
525                input: self.compile_variable(op.input),
526                out: self.compile_variable(out),
527            },
528            cube::Plane::Ballot(op) => Subgroup::Ballot {
529                input: self.compile_variable(op.input),
530                out: self.compile_variable(out),
531            },
532
533            cube::Plane::Broadcast(op) => Subgroup::Broadcast {
534                lhs: self.compile_variable(op.lhs),
535                rhs: self.compile_variable(op.rhs),
536                out: self.compile_variable(out),
537            },
538
539            cube::Plane::Sum(op) => Subgroup::Sum {
540                input: self.compile_variable(op.input),
541                out: self.compile_variable(out),
542            },
543
544            cube::Plane::ExclusiveSum(op) => Subgroup::ExclusiveSum {
545                input: self.compile_variable(op.input),
546                out: self.compile_variable(out),
547            },
548            cube::Plane::InclusiveSum(op) => Subgroup::InclusiveSum {
549                input: self.compile_variable(op.input),
550                out: self.compile_variable(out),
551            },
552            cube::Plane::Prod(op) => Subgroup::Prod {
553                input: self.compile_variable(op.input),
554                out: self.compile_variable(out),
555            },
556            cube::Plane::ExclusiveProd(op) => Subgroup::ExclusiveProd {
557                input: self.compile_variable(op.input),
558                out: self.compile_variable(out),
559            },
560            cube::Plane::InclusiveProd(op) => Subgroup::InclusiveProd {
561                input: self.compile_variable(op.input),
562                out: self.compile_variable(out),
563            },
564            cube::Plane::Min(op) => Subgroup::Min {
565                input: self.compile_variable(op.input),
566                out: self.compile_variable(out),
567            },
568            cube::Plane::Max(op) => Subgroup::Max {
569                input: self.compile_variable(op.input),
570                out: self.compile_variable(out),
571            },
572            cube::Plane::Shuffle(op) => Subgroup::Shuffle {
573                lhs: self.compile_variable(op.lhs),
574                rhs: self.compile_variable(op.rhs),
575                out: self.compile_variable(out),
576            },
577            cube::Plane::ShuffleXor(op) => Subgroup::ShuffleXor {
578                lhs: self.compile_variable(op.lhs),
579                rhs: self.compile_variable(op.rhs),
580                out: self.compile_variable(out),
581            },
582            cube::Plane::ShuffleUp(op) => Subgroup::ShuffleUp {
583                lhs: self.compile_variable(op.lhs),
584                rhs: self.compile_variable(op.rhs),
585                out: self.compile_variable(out),
586            },
587            cube::Plane::ShuffleDown(op) => Subgroup::ShuffleDown {
588                lhs: self.compile_variable(op.lhs),
589                rhs: self.compile_variable(op.rhs),
590                out: self.compile_variable(out),
591            },
592        };
593
594        instructions.push(wgsl::Instruction::Subgroup(op));
595    }
596
597    fn compile_branch(&mut self, instructions: &mut Vec<wgsl::Instruction>, branch: cube::Branch) {
598        match branch {
599            cube::Branch::If(mut op) => instructions.push(wgsl::Instruction::If {
600                cond: self.compile_variable(op.cond),
601                instructions: self.compile_scope(&mut op.scope),
602            }),
603            cube::Branch::IfElse(mut op) => instructions.push(wgsl::Instruction::IfElse {
604                cond: self.compile_variable(op.cond),
605                instructions_if: self.compile_scope(&mut op.scope_if),
606                instructions_else: self.compile_scope(&mut op.scope_else),
607            }),
608            cube::Branch::Switch(mut op) => instructions.push(wgsl::Instruction::Switch {
609                value: self.compile_variable(op.value),
610                instructions_default: self.compile_scope(&mut op.scope_default),
611                cases: op
612                    .cases
613                    .into_iter()
614                    .map(|(val, mut scope)| {
615                        (self.compile_variable(val), self.compile_scope(&mut scope))
616                    })
617                    .collect(),
618            }),
619            cube::Branch::Return => instructions.push(wgsl::Instruction::Return),
620            cube::Branch::Break => instructions.push(wgsl::Instruction::Break),
621            cube::Branch::RangeLoop(mut range_loop) => {
622                instructions.push(wgsl::Instruction::RangeLoop {
623                    i: self.compile_variable(range_loop.i),
624                    start: self.compile_variable(range_loop.start),
625                    end: self.compile_variable(range_loop.end),
626                    step: range_loop.step.map(|it| self.compile_variable(it)),
627                    inclusive: range_loop.inclusive,
628                    instructions: self.compile_scope(&mut range_loop.scope),
629                })
630            }
631            cube::Branch::Loop(mut op) => instructions.push(wgsl::Instruction::Loop {
632                instructions: self.compile_scope(&mut op.scope),
633            }),
634        };
635    }
636
637    fn compile_synchronization(
638        &mut self,
639        instructions: &mut Vec<wgsl::Instruction>,
640        synchronization: cube::Synchronization,
641    ) {
642        match synchronization {
643            cube::Synchronization::SyncCube => {
644                instructions.push(wgsl::Instruction::WorkgroupBarrier)
645            }
646            cube::Synchronization::SyncPlane => {
647                panic!("Synchronization within a plane is not supported in WGSL")
648            }
649            cube::Synchronization::SyncStorage => {
650                instructions.push(wgsl::Instruction::StorageBarrier)
651            }
652            cube::Synchronization::SyncAsyncProxyShared => panic!("TMA is not supported in WGSL"),
653        };
654    }
655
656    fn compile_comment(&mut self, instructions: &mut Vec<wgsl::Instruction>, content: String) {
657        instructions.push(wgsl::Instruction::Comment { content })
658    }
659
660    fn compile_metadata(
661        &mut self,
662        metadata: cube::Metadata,
663        out: Option<cube::Variable>,
664    ) -> wgsl::Instruction {
665        let out = out.unwrap();
666        match metadata {
667            cube::Metadata::Rank { var } => {
668                let position = self.ext_meta_pos(&var);
669                let offset = self.metadata.rank_index(position);
670                wgsl::Instruction::Metadata {
671                    out: self.compile_variable(out),
672                    info_offset: self.compile_variable(offset.into()),
673                }
674            }
675            cube::Metadata::Stride { dim, var } => {
676                let position = self.ext_meta_pos(&var);
677                let offset = self.metadata.stride_offset_index(position);
678                wgsl::Instruction::ExtendedMeta {
679                    info_offset: self.compile_variable(offset.into()),
680                    dim: self.compile_variable(dim),
681                    out: self.compile_variable(out),
682                }
683            }
684            cube::Metadata::Shape { dim, var } => {
685                let position = self.ext_meta_pos(&var);
686                let offset = self.metadata.shape_offset_index(position);
687                wgsl::Instruction::ExtendedMeta {
688                    info_offset: self.compile_variable(offset.into()),
689                    dim: self.compile_variable(dim),
690                    out: self.compile_variable(out),
691                }
692            }
693            cube::Metadata::Length { var } => match var.kind {
694                cube::VariableKind::GlobalInputArray(id) => {
695                    let offset = self.metadata.len_index(id);
696                    wgsl::Instruction::Metadata {
697                        out: self.compile_variable(out),
698                        info_offset: self.compile_variable(offset.into()),
699                    }
700                }
701                cube::VariableKind::GlobalOutputArray(id) => {
702                    let offset = self.metadata.len_index(id);
703                    wgsl::Instruction::Metadata {
704                        out: self.compile_variable(out),
705                        info_offset: self.compile_variable(offset.into()),
706                    }
707                }
708                _ => wgsl::Instruction::Length {
709                    var: self.compile_variable(var),
710                    out: self.compile_variable(out),
711                },
712            },
713            cube::Metadata::BufferLength { var } => match var.kind {
714                cube::VariableKind::GlobalInputArray(id) => {
715                    let offset = self.metadata.buffer_len_index(id);
716                    wgsl::Instruction::Metadata {
717                        out: self.compile_variable(out),
718                        info_offset: self.compile_variable(offset.into()),
719                    }
720                }
721                cube::VariableKind::GlobalOutputArray(id) => {
722                    let offset = self.metadata.buffer_len_index(id);
723                    wgsl::Instruction::Metadata {
724                        out: self.compile_variable(out),
725                        info_offset: self.compile_variable(offset.into()),
726                    }
727                }
728                _ => wgsl::Instruction::Length {
729                    var: self.compile_variable(var),
730                    out: self.compile_variable(out),
731                },
732            },
733        }
734    }
735
736    fn compile_arithmetic(
737        &mut self,
738        value: cube::Arithmetic,
739        out: Option<cube::Variable>,
740        instructions: &mut Vec<wgsl::Instruction>,
741        scope: &mut Scope,
742    ) {
743        let out = out.unwrap();
744        match value {
745            cube::Arithmetic::Max(op) => instructions.push(wgsl::Instruction::Max {
746                lhs: self.compile_variable(op.lhs),
747                rhs: self.compile_variable(op.rhs),
748                out: self.compile_variable(out),
749            }),
750            cube::Arithmetic::Min(op) => instructions.push(wgsl::Instruction::Min {
751                lhs: self.compile_variable(op.lhs),
752                rhs: self.compile_variable(op.rhs),
753                out: self.compile_variable(out),
754            }),
755            cube::Arithmetic::Add(op) => instructions.push(wgsl::Instruction::Add {
756                lhs: self.compile_variable(op.lhs),
757                rhs: self.compile_variable(op.rhs),
758                out: self.compile_variable(out),
759            }),
760            cube::Arithmetic::SaturatingAdd(_) => {
761                unreachable!("Saturating add should be removed by processor");
762            }
763            cube::Arithmetic::Fma(op) => instructions.push(wgsl::Instruction::Fma {
764                a: self.compile_variable(op.a),
765                b: self.compile_variable(op.b),
766                c: self.compile_variable(op.c),
767                out: self.compile_variable(out),
768            }),
769            cube::Arithmetic::Modulo(op) => instructions.push(wgsl::Instruction::Modulo {
770                lhs: self.compile_variable(op.lhs),
771                rhs: self.compile_variable(op.rhs),
772                out: self.compile_variable(out),
773            }),
774            cube::Arithmetic::Sub(op) => instructions.push(wgsl::Instruction::Sub {
775                lhs: self.compile_variable(op.lhs),
776                rhs: self.compile_variable(op.rhs),
777                out: self.compile_variable(out),
778            }),
779            cube::Arithmetic::SaturatingSub(_) => {
780                unreachable!("Saturating sub should be removed by processor");
781            }
782            cube::Arithmetic::Mul(op) => instructions.push(wgsl::Instruction::Mul {
783                lhs: self.compile_variable(op.lhs),
784                rhs: self.compile_variable(op.rhs),
785                out: self.compile_variable(out),
786            }),
787            cube::Arithmetic::Div(op) => instructions.push(wgsl::Instruction::Div {
788                lhs: self.compile_variable(op.lhs),
789                rhs: self.compile_variable(op.rhs),
790                out: self.compile_variable(out),
791            }),
792            cube::Arithmetic::Abs(op) => instructions.push(wgsl::Instruction::Abs {
793                input: self.compile_variable(op.input),
794                out: self.compile_variable(out),
795            }),
796            cube::Arithmetic::Exp(op) => instructions.push(wgsl::Instruction::Exp {
797                input: self.compile_variable(op.input),
798                out: self.compile_variable(out),
799            }),
800            cube::Arithmetic::Log(op) => instructions.push(wgsl::Instruction::Log {
801                input: self.compile_variable(op.input),
802                out: self.compile_variable(out),
803            }),
804            cube::Arithmetic::Log1p(op) => instructions.push(wgsl::Instruction::Log1p {
805                input: self.compile_variable(op.input),
806                out: self.compile_variable(out),
807            }),
808            cube::Arithmetic::Cos(op) => instructions.push(wgsl::Instruction::Cos {
809                input: self.compile_variable(op.input),
810                out: self.compile_variable(out),
811            }),
812            cube::Arithmetic::Sin(op) => instructions.push(wgsl::Instruction::Sin {
813                input: self.compile_variable(op.input),
814                out: self.compile_variable(out),
815            }),
816            cube::Arithmetic::Tan(op) => instructions.push(wgsl::Instruction::Tan {
817                input: self.compile_variable(op.input),
818                out: self.compile_variable(out),
819            }),
820            cube::Arithmetic::Tanh(op) => instructions.push(wgsl::Instruction::Tanh {
821                input: self.compile_variable(op.input),
822                out: self.compile_variable(out),
823            }),
824            cube::Arithmetic::Sinh(op) => instructions.push(wgsl::Instruction::Sinh {
825                input: self.compile_variable(op.input),
826                out: self.compile_variable(out),
827            }),
828            cube::Arithmetic::Cosh(op) => instructions.push(wgsl::Instruction::Cosh {
829                input: self.compile_variable(op.input),
830                out: self.compile_variable(out),
831            }),
832            cube::Arithmetic::ArcCos(op) => instructions.push(wgsl::Instruction::ArcCos {
833                input: self.compile_variable(op.input),
834                out: self.compile_variable(out),
835            }),
836            cube::Arithmetic::ArcSin(op) => instructions.push(wgsl::Instruction::ArcSin {
837                input: self.compile_variable(op.input),
838                out: self.compile_variable(out),
839            }),
840            cube::Arithmetic::ArcTan(op) => instructions.push(wgsl::Instruction::ArcTan {
841                input: self.compile_variable(op.input),
842                out: self.compile_variable(out),
843            }),
844            cube::Arithmetic::ArcSinh(op) => instructions.push(wgsl::Instruction::ArcSinh {
845                input: self.compile_variable(op.input),
846                out: self.compile_variable(out),
847            }),
848            cube::Arithmetic::ArcCosh(op) => instructions.push(wgsl::Instruction::ArcCosh {
849                input: self.compile_variable(op.input),
850                out: self.compile_variable(out),
851            }),
852            cube::Arithmetic::ArcTanh(op) => instructions.push(wgsl::Instruction::ArcTanh {
853                input: self.compile_variable(op.input),
854                out: self.compile_variable(out),
855            }),
856            cube::Arithmetic::Degrees(op) => instructions.push(wgsl::Instruction::Degrees {
857                input: self.compile_variable(op.input),
858                out: self.compile_variable(out),
859            }),
860            cube::Arithmetic::Radians(op) => instructions.push(wgsl::Instruction::Radians {
861                input: self.compile_variable(op.input),
862                out: self.compile_variable(out),
863            }),
864            cube::Arithmetic::ArcTan2(op) => instructions.push(wgsl::Instruction::ArcTan2 {
865                lhs: self.compile_variable(op.lhs),
866                rhs: self.compile_variable(op.rhs),
867                out: self.compile_variable(out),
868            }),
869            // No powi in WGSL
870            cube::Arithmetic::Powf(op) | cube::Arithmetic::Powi(op) => {
871                instructions.push(wgsl::Instruction::Powf {
872                    lhs: self.compile_variable(op.lhs),
873                    rhs: self.compile_variable(op.rhs),
874                    out: self.compile_variable(out),
875                })
876            }
877            cube::Arithmetic::Sqrt(op) => instructions.push(wgsl::Instruction::Sqrt {
878                input: self.compile_variable(op.input),
879                out: self.compile_variable(out),
880            }),
881            cube::Arithmetic::InverseSqrt(op) => {
882                instructions.push(wgsl::Instruction::InverseSqrt {
883                    input: self.compile_variable(op.input),
884                    out: self.compile_variable(out),
885                })
886            }
887            cube::Arithmetic::Round(op) => instructions.push(wgsl::Instruction::Round {
888                input: self.compile_variable(op.input),
889                out: self.compile_variable(out),
890            }),
891            cube::Arithmetic::Floor(op) => instructions.push(wgsl::Instruction::Floor {
892                input: self.compile_variable(op.input),
893                out: self.compile_variable(out),
894            }),
895            cube::Arithmetic::Ceil(op) => instructions.push(wgsl::Instruction::Ceil {
896                input: self.compile_variable(op.input),
897                out: self.compile_variable(out),
898            }),
899            cube::Arithmetic::Trunc(op) => instructions.push(wgsl::Instruction::Trunc {
900                input: self.compile_variable(op.input),
901                out: self.compile_variable(out),
902            }),
903            cube::Arithmetic::Erf(op) => {
904                let mut scope = scope.child();
905                expand_erf(&mut scope, op.input, out);
906                instructions.extend(self.compile_scope(&mut scope));
907            }
908            cube::Arithmetic::MulHi(op) => {
909                let mut scope = scope.child();
910                match self.compilation_options.supports_u64 {
911                    true => expand_himul_64(&mut scope, op.lhs, op.rhs, out),
912                    false => expand_himul_sim(&mut scope, op.lhs, op.rhs, out),
913                }
914                instructions.extend(self.compile_scope(&mut scope));
915            }
916            cube::Arithmetic::Recip(op) => instructions.push(wgsl::Instruction::Recip {
917                input: self.compile_variable(op.input),
918                out: self.compile_variable(out),
919            }),
920            cube::Arithmetic::Clamp(op) => instructions.push(wgsl::Instruction::Clamp {
921                input: self.compile_variable(op.input),
922                min_value: self.compile_variable(op.min_value),
923                max_value: self.compile_variable(op.max_value),
924                out: self.compile_variable(out),
925            }),
926            cube::Arithmetic::Remainder(op) => instructions.push(wgsl::Instruction::Remainder {
927                lhs: self.compile_variable(op.lhs),
928                rhs: self.compile_variable(op.rhs),
929                out: self.compile_variable(out),
930            }),
931            cube::Arithmetic::Neg(op) => instructions.push(wgsl::Instruction::Negate {
932                input: self.compile_variable(op.input),
933                out: self.compile_variable(out),
934            }),
935            cube::Arithmetic::Magnitude(op) => instructions.push(wgsl::Instruction::Magnitude {
936                input: self.compile_variable(op.input),
937                out: self.compile_variable(out),
938            }),
939            cube::Arithmetic::Normalize(op) => instructions.push(wgsl::Instruction::Normalize {
940                input: self.compile_variable(op.input),
941                out: self.compile_variable(out),
942            }),
943            cube::Arithmetic::Dot(op) => instructions.push(wgsl::Instruction::Dot {
944                lhs: self.compile_variable(op.lhs),
945                rhs: self.compile_variable(op.rhs),
946                out: self.compile_variable(out),
947            }),
948        }
949    }
950
951    fn compile_cmp(
952        &mut self,
953        value: cube::Comparison,
954        out: Option<cube::Variable>,
955        instructions: &mut Vec<wgsl::Instruction>,
956    ) {
957        let out = out.unwrap();
958        match value {
959            cube::Comparison::Equal(op) => instructions.push(wgsl::Instruction::Equal {
960                lhs: self.compile_variable(op.lhs),
961                rhs: self.compile_variable(op.rhs),
962                out: self.compile_variable(out),
963            }),
964            cube::Comparison::Lower(op) => instructions.push(wgsl::Instruction::Lower {
965                lhs: self.compile_variable(op.lhs),
966                rhs: self.compile_variable(op.rhs),
967                out: self.compile_variable(out),
968            }),
969            cube::Comparison::Greater(op) => instructions.push(wgsl::Instruction::Greater {
970                lhs: self.compile_variable(op.lhs),
971                rhs: self.compile_variable(op.rhs),
972                out: self.compile_variable(out),
973            }),
974            cube::Comparison::LowerEqual(op) => instructions.push(wgsl::Instruction::LowerEqual {
975                lhs: self.compile_variable(op.lhs),
976                rhs: self.compile_variable(op.rhs),
977                out: self.compile_variable(out),
978            }),
979            cube::Comparison::GreaterEqual(op) => {
980                instructions.push(wgsl::Instruction::GreaterEqual {
981                    lhs: self.compile_variable(op.lhs),
982                    rhs: self.compile_variable(op.rhs),
983                    out: self.compile_variable(out),
984                })
985            }
986            cube::Comparison::NotEqual(op) => instructions.push(wgsl::Instruction::NotEqual {
987                lhs: self.compile_variable(op.lhs),
988                rhs: self.compile_variable(op.rhs),
989                out: self.compile_variable(out),
990            }),
991            cube::Comparison::IsNan(op) => instructions.push(wgsl::Instruction::IsNan {
992                input: self.compile_variable(op.input),
993                out: self.compile_variable(out),
994            }),
995            cube::Comparison::IsInf(op) => instructions.push(wgsl::Instruction::IsInf {
996                input: self.compile_variable(op.input),
997                out: self.compile_variable(out),
998            }),
999        }
1000    }
1001
1002    fn compile_bitwise(
1003        &mut self,
1004        value: cube::Bitwise,
1005        out: Option<cube::Variable>,
1006        instructions: &mut Vec<wgsl::Instruction>,
1007    ) {
1008        let out = out.unwrap();
1009        match value {
1010            cube::Bitwise::BitwiseOr(op) => instructions.push(wgsl::Instruction::BitwiseOr {
1011                lhs: self.compile_variable(op.lhs),
1012                rhs: self.compile_variable(op.rhs),
1013                out: self.compile_variable(out),
1014            }),
1015            cube::Bitwise::BitwiseAnd(op) => instructions.push(wgsl::Instruction::BitwiseAnd {
1016                lhs: self.compile_variable(op.lhs),
1017                rhs: self.compile_variable(op.rhs),
1018                out: self.compile_variable(out),
1019            }),
1020            cube::Bitwise::BitwiseXor(op) => instructions.push(wgsl::Instruction::BitwiseXor {
1021                lhs: self.compile_variable(op.lhs),
1022                rhs: self.compile_variable(op.rhs),
1023                out: self.compile_variable(out),
1024            }),
1025            cube::Bitwise::CountOnes(op) => instructions.push(wgsl::Instruction::CountBits {
1026                input: self.compile_variable(op.input),
1027                out: self.compile_variable(out),
1028            }),
1029            cube::Bitwise::ReverseBits(op) => instructions.push(wgsl::Instruction::ReverseBits {
1030                input: self.compile_variable(op.input),
1031                out: self.compile_variable(out),
1032            }),
1033            cube::Bitwise::ShiftLeft(op) => instructions.push(wgsl::Instruction::ShiftLeft {
1034                lhs: self.compile_variable(op.lhs),
1035                rhs: self.compile_variable(op.rhs),
1036                out: self.compile_variable(out),
1037            }),
1038            cube::Bitwise::ShiftRight(op) => instructions.push(wgsl::Instruction::ShiftRight {
1039                lhs: self.compile_variable(op.lhs),
1040                rhs: self.compile_variable(op.rhs),
1041                out: self.compile_variable(out),
1042            }),
1043            cube::Bitwise::BitwiseNot(op) => instructions.push(wgsl::Instruction::BitwiseNot {
1044                input: self.compile_variable(op.input),
1045                out: self.compile_variable(out),
1046            }),
1047            cube::Bitwise::LeadingZeros(op) => instructions.push(wgsl::Instruction::LeadingZeros {
1048                input: self.compile_variable(op.input),
1049                out: self.compile_variable(out),
1050            }),
1051            cube::Bitwise::FindFirstSet(op) => instructions.push(wgsl::Instruction::FindFirstSet {
1052                input: self.compile_variable(op.input),
1053                out: self.compile_variable(out),
1054            }),
1055        }
1056    }
1057
1058    fn compile_operator(
1059        &mut self,
1060        value: cube::Operator,
1061        out: Option<cube::Variable>,
1062        instructions: &mut Vec<wgsl::Instruction>,
1063    ) {
1064        let out = out.unwrap();
1065        match value {
1066            cube::Operator::Cast(op) => instructions.push(wgsl::Instruction::Assign {
1067                input: self.compile_variable(op.input),
1068                out: self.compile_variable(out),
1069            }),
1070            cube::Operator::Index(op) | cube::Operator::UncheckedIndex(op) => {
1071                instructions.push(wgsl::Instruction::Index {
1072                    lhs: self.compile_variable(op.list),
1073                    rhs: self.compile_variable(op.index),
1074                    out: self.compile_variable(out),
1075                });
1076            }
1077            cube::Operator::IndexAssign(op) | cube::Operator::UncheckedIndexAssign(op) => {
1078                instructions.push(wgsl::Instruction::IndexAssign {
1079                    index: self.compile_variable(op.index),
1080                    rhs: self.compile_variable(op.value),
1081                    out: self.compile_variable(out),
1082                })
1083            }
1084            cube::Operator::And(op) => instructions.push(wgsl::Instruction::And {
1085                lhs: self.compile_variable(op.lhs),
1086                rhs: self.compile_variable(op.rhs),
1087                out: self.compile_variable(out),
1088            }),
1089            cube::Operator::Or(op) => instructions.push(wgsl::Instruction::Or {
1090                lhs: self.compile_variable(op.lhs),
1091                rhs: self.compile_variable(op.rhs),
1092                out: self.compile_variable(out),
1093            }),
1094            cube::Operator::Not(op) => instructions.push(wgsl::Instruction::Not {
1095                input: self.compile_variable(op.input),
1096                out: self.compile_variable(out),
1097            }),
1098            cube::Operator::Reinterpret(op) => instructions.push(wgsl::Instruction::Bitcast {
1099                input: self.compile_variable(op.input),
1100                out: self.compile_variable(out),
1101            }),
1102            cube::Operator::InitLine(op) => instructions.push(wgsl::Instruction::VecInit {
1103                inputs: op
1104                    .inputs
1105                    .into_iter()
1106                    .map(|var| self.compile_variable(var))
1107                    .collect(),
1108                out: self.compile_variable(out),
1109            }),
1110            cube::Operator::CopyMemory(op) => instructions.push(wgsl::Instruction::Copy {
1111                input: self.compile_variable(op.input),
1112                in_index: self.compile_variable(op.in_index),
1113                out: self.compile_variable(out),
1114                out_index: self.compile_variable(op.out_index),
1115            }),
1116            cube::Operator::CopyMemoryBulk(op) => instructions.push(wgsl::Instruction::CopyBulk {
1117                input: self.compile_variable(op.input),
1118                in_index: self.compile_variable(op.in_index),
1119                out: self.compile_variable(out),
1120                out_index: self.compile_variable(op.out_index),
1121                len: op.len,
1122            }),
1123            cube::Operator::Select(op) => instructions.push(wgsl::Instruction::Select {
1124                cond: self.compile_variable(op.cond),
1125                then: self.compile_variable(op.then),
1126                or_else: self.compile_variable(op.or_else),
1127                out: self.compile_variable(out),
1128            }),
1129        }
1130    }
1131
1132    fn compile_atomic(
1133        &mut self,
1134        atomic: cube::AtomicOp,
1135        out: Option<cube::Variable>,
1136    ) -> wgsl::Instruction {
1137        let out = out.unwrap();
1138        match atomic {
1139            cube::AtomicOp::Add(op) => wgsl::Instruction::AtomicAdd {
1140                lhs: self.compile_variable(op.lhs),
1141                rhs: self.compile_variable(op.rhs),
1142                out: self.compile_variable(out),
1143            },
1144            cube::AtomicOp::Sub(op) => wgsl::Instruction::AtomicSub {
1145                lhs: self.compile_variable(op.lhs),
1146                rhs: self.compile_variable(op.rhs),
1147                out: self.compile_variable(out),
1148            },
1149            cube::AtomicOp::Max(op) => wgsl::Instruction::AtomicMax {
1150                lhs: self.compile_variable(op.lhs),
1151                rhs: self.compile_variable(op.rhs),
1152                out: self.compile_variable(out),
1153            },
1154            cube::AtomicOp::Min(op) => wgsl::Instruction::AtomicMin {
1155                lhs: self.compile_variable(op.lhs),
1156                rhs: self.compile_variable(op.rhs),
1157                out: self.compile_variable(out),
1158            },
1159            cube::AtomicOp::And(op) => wgsl::Instruction::AtomicAnd {
1160                lhs: self.compile_variable(op.lhs),
1161                rhs: self.compile_variable(op.rhs),
1162                out: self.compile_variable(out),
1163            },
1164            cube::AtomicOp::Or(op) => wgsl::Instruction::AtomicOr {
1165                lhs: self.compile_variable(op.lhs),
1166                rhs: self.compile_variable(op.rhs),
1167                out: self.compile_variable(out),
1168            },
1169            cube::AtomicOp::Xor(op) => wgsl::Instruction::AtomicXor {
1170                lhs: self.compile_variable(op.lhs),
1171                rhs: self.compile_variable(op.rhs),
1172                out: self.compile_variable(out),
1173            },
1174            cube::AtomicOp::Load(op) => wgsl::Instruction::AtomicLoad {
1175                input: self.compile_variable(op.input),
1176                out: self.compile_variable(out),
1177            },
1178            cube::AtomicOp::Store(op) => wgsl::Instruction::AtomicStore {
1179                input: self.compile_variable(op.input),
1180                out: self.compile_variable(out),
1181            },
1182            cube::AtomicOp::Swap(op) => wgsl::Instruction::AtomicSwap {
1183                lhs: self.compile_variable(op.lhs),
1184                rhs: self.compile_variable(op.rhs),
1185                out: self.compile_variable(out),
1186            },
1187            cube::AtomicOp::CompareAndSwap(op) => wgsl::Instruction::AtomicCompareExchangeWeak {
1188                lhs: self.compile_variable(op.input),
1189                cmp: self.compile_variable(op.cmp),
1190                value: self.compile_variable(op.val),
1191                out: self.compile_variable(out),
1192            },
1193        }
1194    }
1195
1196    fn compile_location(value: kernel::Location) -> wgsl::Location {
1197        match value {
1198            kernel::Location::Storage => wgsl::Location::Storage,
1199            kernel::Location::Cube => wgsl::Location::Workgroup,
1200        }
1201    }
1202
1203    fn compile_binding(&mut self, value: kernel::Binding) -> wgsl::Binding {
1204        wgsl::Binding {
1205            id: value.id,
1206            visibility: value.visibility,
1207            location: Self::compile_location(value.location),
1208            item: self.compile_type(value.ty),
1209            size: value.size,
1210        }
1211    }
1212}
1213
1214fn register_extensions(instructions: &[wgsl::Instruction]) -> Vec<wgsl::Extension> {
1215    let mut extensions = Vec::new();
1216
1217    let mut register_extension = |extension: wgsl::Extension| {
1218        if !extensions.contains(&extension) {
1219            extensions.push(extension);
1220        }
1221    };
1222
1223    // Since not all instructions are native to WGSL, we need to add the custom ones.
1224    for instruction in instructions {
1225        match instruction {
1226            wgsl::Instruction::Powf { lhs: _, rhs, out } => {
1227                register_extension(wgsl::Extension::PowfPrimitive(out.elem()));
1228                register_extension(wgsl::powf_extension(rhs, out));
1229            }
1230            #[cfg(target_os = "macos")]
1231            wgsl::Instruction::Tanh { input, out: _ } => {
1232                register_extension(wgsl::Extension::SafeTanhPrimitive(input.elem()));
1233                register_extension(wgsl::Extension::SafeTanh(input.item()));
1234            }
1235            wgsl::Instruction::IsNan { input, out } => {
1236                register_extension(wgsl::Extension::IsNanPrimitive(input.elem()));
1237                register_extension(wgsl::Extension::IsNan(input.item(), out.item()));
1238            }
1239            wgsl::Instruction::IsInf { input, out } => {
1240                register_extension(wgsl::Extension::IsInfPrimitive(input.elem()));
1241                register_extension(wgsl::Extension::IsInf(input.item(), out.item()));
1242            }
1243            wgsl::Instruction::If { instructions, .. } => {
1244                for extension in register_extensions(instructions) {
1245                    register_extension(extension);
1246                }
1247            }
1248            wgsl::Instruction::IfElse {
1249                instructions_if,
1250                instructions_else,
1251                ..
1252            } => {
1253                for extension in register_extensions(instructions_if) {
1254                    register_extension(extension);
1255                }
1256                for extension in register_extensions(instructions_else) {
1257                    register_extension(extension);
1258                }
1259            }
1260            wgsl::Instruction::Loop { instructions } => {
1261                for extension in register_extensions(instructions) {
1262                    register_extension(extension);
1263                }
1264            }
1265            wgsl::Instruction::RangeLoop { instructions, .. } => {
1266                for extension in register_extensions(instructions) {
1267                    register_extension(extension);
1268                }
1269            }
1270            _ => {}
1271        }
1272    }
1273
1274    extensions
1275}