cubecl_wgpu/compiler/wgsl/
compiler.rs

1use super::Subgroup;
2use super::{ConstantArray, shader::ComputeShader};
3use super::{Item, LocalArray, SharedArray};
4use crate::compiler::wgsl::{self, SharedValue};
5
6use cubecl_common::backtrace::BackTrace;
7use cubecl_core::post_processing::{
8    checked_io::CheckedIoProcessor, saturating::SaturatingArithmeticProcessor,
9};
10use cubecl_core::prelude::*;
11use cubecl_core::{
12    Metadata, WgpuCompilationOptions,
13    ir::{self as cube, Scope},
14    prelude::expand_erf,
15};
16use cubecl_core::{
17    ir::{ConstantScalarValue, Processor, UIntKind},
18    post_processing::unroll::UnrollProcessor,
19};
20use cubecl_runtime::compiler::CompilationError;
21use cubecl_runtime::kernel;
22
/// Largest line (vector) width emitted by this backend.
///
/// Bindings and declared variables with a wider line size are clamped to this value in
/// `compile_shader` / `compile_scope`, and the same limit is handed to `UnrollProcessor`.
/// `compile_type` only maps line sizes 2, 3 and 4 to WGSL vector items.
23pub const MAX_LINE_SIZE: u32 = 4;
24
25/// Wgsl Compiler.
///
/// Turns a `KernelDefinition` into a `wgsl::ComputeShader`. While compiling it records which
/// builtins, arrays and element types the kernel touches; `compile_shader` copies that
/// information onto the shader it returns.
26#[derive(Clone, Default)]
27pub struct WgslCompiler {
    // Rebuilt in `compile_shader` from the number of buffers and of extended-metadata buffers.
28    metadata: Metadata,
    // One entry per buffer, in order: how many earlier buffers have extended metadata.
29    ext_meta_pos: Vec<u32>,
    // The boolean flags below are raised by `compile_variable` when the matching builtin is
    // compiled, and read back by `compile_shader`.
    // Raised by `Builtin::UnitPos`.
30    local_invocation_index: bool,
    // Raised by `Builtin::UnitPosX/Y/Z`.
31    local_invocation_id: bool,
32    // Raised by `Builtin::AbsolutePosX/Y/Z`; `compile_shader` ORs it with `id`.
33    global_invocation_id: bool,
    // Raised by `Builtin::CubePosX/Y/Z`.
34    workgroup_id: bool,
    // Raised by `Builtin::PlaneDim`.
35    subgroup_size: bool,
    // Raised by `Builtin::UnitPosPlane`.
36    subgroup_invocation_id: bool,
    // Raised by `Builtin::AbsolutePos`; also turns on `global_invocation_id` and
    // `num_workgroups` on the emitted shader.
37    id: bool,
    // Raised by `Builtin::CubeCountX/Y/Z`.
38    num_workgroups: bool,
    // Raised by `Builtin::CubePos`.
39    workgroup_id_no_axis: bool,
    // Raised by `Builtin::CubeDim`.
40    workgroup_size_no_axis: bool,
    // Raised by `Builtin::CubeCount`.
41    num_workgroup_no_axis: bool,
    // Shared arrays, registered once per id the first time `compile_variable` sees them.
42    shared_arrays: Vec<SharedArray>,
    // Shared values, registered once per id the first time `compile_variable` sees them.
43    shared_values: Vec<SharedValue>,
    // Constant arrays drained from every scope visited by `compile_scope`.
44    const_arrays: Vec<ConstantArray>,
    // Local arrays, registered once per id the first time `compile_variable` sees them.
45    local_arrays: Vec<LocalArray>,
    // Stored by `Compiler::compile`; no read of this field is visible in this part of the file.
46    #[allow(dead_code)]
47    compilation_options: WgpuCompilationOptions,
    // Execution mode of the current compilation, passed to `CheckedIoProcessor`.
48    strategy: ExecutionMode,
    // Raised by `compile_subgroup` as soon as any plane operation is compiled.
49    subgroup_instructions_used: bool,
    // Raised by `compile_elem` when an `F16` element is compiled.
50    f16_used: bool,
51}
52
53impl core::fmt::Debug for WgslCompiler {
54    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55        f.write_str("WgslCompiler")
56    }
57}
58
59impl cubecl_core::Compiler for WgslCompiler {
60    type Representation = ComputeShader;
61    type CompilationOptions = WgpuCompilationOptions;
62
63    fn compile(
64        &mut self,
65        shader: kernel::KernelDefinition,
66        compilation_options: &Self::CompilationOptions,
67        mode: ExecutionMode,
68    ) -> Result<Self::Representation, CompilationError> {
69        self.compilation_options = compilation_options.clone();
70        self.compile_shader(shader, mode)
71    }
72
73    fn elem_size(&self, elem: cube::ElemType) -> usize {
74        elem.size()
75    }
76
77    fn extension(&self) -> &'static str {
78        "wgsl"
79    }
80}
81
82impl WgslCompiler {
83    fn compile_shader(
84        &mut self,
85        mut value: kernel::KernelDefinition,
86        mode: ExecutionMode,
87    ) -> Result<wgsl::ComputeShader, CompilationError> {
88        let errors = value.body.pop_errors();
89        if !errors.is_empty() {
90            let mut reason = "Can't compile wgsl kernel".to_string();
91            for error in errors {
92                reason += error.as_str();
93                reason += "\n";
94            }
95
96            return Err(CompilationError::Validation {
97                reason,
98                backtrace: BackTrace::capture(),
99            });
100        }
101
102        self.strategy = mode;
103
104        let num_meta = value.buffers.len();
105
106        self.ext_meta_pos = Vec::new();
107        let mut num_ext = 0;
108
109        for binding in value.buffers.iter() {
110            self.ext_meta_pos.push(num_ext);
111            if binding.has_extended_meta {
112                num_ext += 1;
113            }
114        }
115
116        self.metadata = Metadata::new(num_meta as u32, num_ext);
117
118        let instructions = self.compile_scope(&mut value.body);
119        let extensions = register_extensions(&instructions);
120        let body = wgsl::Body {
121            instructions,
122            id: self.id,
123        };
124
125        Ok(wgsl::ComputeShader {
126            buffers: value
127                .buffers
128                .into_iter()
129                .map(|mut it| {
130                    // This is safe when combined with the unroll transform that adjusts all indices.
131                    // Must not be used alone
132                    if it.ty.line_size() > MAX_LINE_SIZE {
133                        it.ty = it.ty.line(MAX_LINE_SIZE);
134                    }
135                    self.compile_binding(it)
136                })
137                .collect(),
138            scalars: value
139                .scalars
140                .into_iter()
141                .map(|binding| (self.compile_storage_type(binding.ty), binding.count))
142                .collect(),
143            shared_arrays: self.shared_arrays.clone(),
144            shared_values: self.shared_values.clone(),
145            constant_arrays: self.const_arrays.clone(),
146            local_arrays: self.local_arrays.clone(),
147            has_metadata: self.metadata.static_len() > 0,
148            workgroup_size: value.cube_dim,
149            global_invocation_id: self.global_invocation_id || self.id,
150            local_invocation_index: self.local_invocation_index,
151            local_invocation_id: self.local_invocation_id,
152            num_workgroups: self.id
153                || self.num_workgroups
154                || self.num_workgroup_no_axis
155                || self.workgroup_id_no_axis,
156            workgroup_id: self.workgroup_id || self.workgroup_id_no_axis,
157            subgroup_size: self.subgroup_size,
158            subgroup_invocation_id: self.subgroup_invocation_id,
159            body,
160            extensions,
161            num_workgroups_no_axis: self.num_workgroup_no_axis,
162            workgroup_id_no_axis: self.workgroup_id_no_axis,
163            workgroup_size_no_axis: self.workgroup_size_no_axis,
164            subgroup_instructions_used: self.subgroup_instructions_used,
165            f16_used: self.f16_used,
166            kernel_name: value.options.kernel_name,
167        })
168    }
169
170    fn compile_type(&mut self, item: cube::Type) -> Item {
171        match item {
172            cube::Type::Scalar(ty) => wgsl::Item::Scalar(self.compile_storage_type(ty)),
173            cube::Type::Line(ty, size) => {
174                let elem = self.compile_storage_type(ty);
175                match size {
176                    2 => wgsl::Item::Vec2(elem),
177                    3 => wgsl::Item::Vec3(elem),
178                    4 => wgsl::Item::Vec4(elem),
179                    _ => panic!("Unsupported vectorizations scheme {:?}", item.line_size()),
180                }
181            }
182            cube::Type::Semantic(_) => unimplemented!("Can't compile semantic type"),
183        }
184    }
185
186    fn compile_storage_type(&mut self, ty: cube::StorageType) -> wgsl::Elem {
187        match ty {
188            cube::StorageType::Scalar(ty) => self.compile_elem(ty),
189            cube::StorageType::Atomic(ty) => match ty {
190                cube::ElemType::Float(i) => match i {
191                    cube::FloatKind::F32 => wgsl::Elem::AtomicF32,
192                    kind => panic!("atomic<{kind:?}> is not a valid WgpuElement"),
193                },
194                cube::ElemType::Int(i) => match i {
195                    cube::IntKind::I32 => wgsl::Elem::AtomicI32,
196                    kind => panic!("atomic<{kind:?}> is not a valid WgpuElement"),
197                },
198                cube::ElemType::UInt(kind) => match kind {
199                    cube::UIntKind::U32 => wgsl::Elem::AtomicU32,
200                    kind => panic!("{kind:?} is not a valid WgpuElement"),
201                },
202                other => panic!("{other:?} is not a valid WgpuElement"),
203            },
204            cube::StorageType::Packed(_, _) => {
205                unimplemented!("Packed types not yet supported in WGSL")
206            }
207            cube::StorageType::Opaque(ty) => match ty {
208                cube::OpaqueType::Barrier(_) => {
209                    unimplemented!("Barrier objects not supported in WGSL")
210                }
211            },
212        }
213    }
214
215    fn compile_elem(&mut self, value: cube::ElemType) -> wgsl::Elem {
216        match value {
217            cube::ElemType::Float(f) => match f {
218                cube::FloatKind::E2M1
219                | cube::FloatKind::E2M3
220                | cube::FloatKind::E3M2
221                | cube::FloatKind::E4M3
222                | cube::FloatKind::E5M2
223                | cube::FloatKind::UE8M0 => panic!("Minifloat is not a valid WgpuElement"),
224                cube::FloatKind::F16 => {
225                    self.f16_used = true;
226                    wgsl::Elem::F16
227                }
228                cube::FloatKind::BF16 => panic!("bf16 is not a valid WgpuElement"),
229                cube::FloatKind::TF32 => panic!("tf32 is not a valid WgpuElement"),
230                cube::FloatKind::Flex32 => wgsl::Elem::F32,
231                cube::FloatKind::F32 => wgsl::Elem::F32,
232                cube::FloatKind::F64 => wgsl::Elem::F64,
233            },
234            cube::ElemType::Int(i) => match i {
235                cube::IntKind::I32 => wgsl::Elem::I32,
236                cube::IntKind::I64 => wgsl::Elem::I64,
237                kind => panic!("{kind:?} is not a valid WgpuElement"),
238            },
239            cube::ElemType::UInt(kind) => match kind {
240                cube::UIntKind::U32 => wgsl::Elem::U32,
241                cube::UIntKind::U64 => wgsl::Elem::U64,
242                kind => panic!("{kind:?} is not a valid WgpuElement"),
243            },
244            cube::ElemType::Bool => wgsl::Elem::Bool,
245        }
246    }
247
248    fn ext_meta_pos(&self, var: &cube::Variable) -> u32 {
249        let pos = var.index().expect("Variable should have index");
250        self.ext_meta_pos[pos as usize]
251    }
252
253    pub(crate) fn compile_variable(&mut self, value: cube::Variable) -> wgsl::Variable {
254        let item = value.ty;
255        match value.kind {
256            cube::VariableKind::GlobalInputArray(id) => {
257                wgsl::Variable::GlobalInputArray(id, self.compile_type(item))
258            }
259            cube::VariableKind::GlobalScalar(id) => {
260                wgsl::Variable::GlobalScalar(id, self.compile_storage_type(item.storage_type()))
261            }
262            cube::VariableKind::LocalMut { id } | cube::VariableKind::Versioned { id, .. } => {
263                wgsl::Variable::LocalMut {
264                    id,
265                    item: self.compile_type(item),
266                }
267            }
268            cube::VariableKind::LocalConst { id } => wgsl::Variable::LocalConst {
269                id,
270                item: self.compile_type(item),
271            },
272            cube::VariableKind::GlobalOutputArray(id) => {
273                wgsl::Variable::GlobalOutputArray(id, self.compile_type(item))
274            }
275            cube::VariableKind::ConstantScalar(value) => {
276                wgsl::Variable::ConstantScalar(value, self.compile_elem(value.elem_type()))
277            }
278            cube::VariableKind::SharedArray {
279                id,
280                length,
281                unroll_factor,
282                alignment,
283            } => {
284                let item = self.compile_type(item);
285                if !self.shared_arrays.iter().any(|s| s.index == id) {
286                    self.shared_arrays.push(SharedArray::new(
287                        id,
288                        item,
289                        length * unroll_factor,
290                        alignment,
291                    ));
292                }
293                wgsl::Variable::SharedArray(id, item, length)
294            }
295            cube::VariableKind::Shared { id } => {
296                let item = self.compile_type(item);
297                if !self.shared_values.iter().any(|s| s.index == id) {
298                    self.shared_values.push(SharedValue::new(id, item));
299                }
300                wgsl::Variable::SharedValue(id, item)
301            }
302            cube::VariableKind::ConstantArray { id, length, .. } => {
303                let item = self.compile_type(item);
304                wgsl::Variable::ConstantArray(id, item, length)
305            }
306            cube::VariableKind::LocalArray {
307                id,
308                length,
309                unroll_factor,
310            } => {
311                let item = self.compile_type(item);
312                if !self.local_arrays.iter().any(|s| s.index == id) {
313                    self.local_arrays
314                        .push(LocalArray::new(id, item, length * unroll_factor));
315                }
316                wgsl::Variable::LocalArray(id, item, length)
317            }
318            cube::VariableKind::Builtin(builtin) => match builtin {
319                cube::Builtin::AbsolutePos => {
320                    self.id = true;
321                    wgsl::Variable::Id
322                }
323                cube::Builtin::UnitPos => {
324                    self.local_invocation_index = true;
325                    wgsl::Variable::LocalInvocationIndex
326                }
327                cube::Builtin::UnitPosX => {
328                    self.local_invocation_id = true;
329                    wgsl::Variable::LocalInvocationIdX
330                }
331                cube::Builtin::UnitPosY => {
332                    self.local_invocation_id = true;
333                    wgsl::Variable::LocalInvocationIdY
334                }
335                cube::Builtin::UnitPosZ => {
336                    self.local_invocation_id = true;
337                    wgsl::Variable::LocalInvocationIdZ
338                }
339                cube::Builtin::CubePosX => {
340                    self.workgroup_id = true;
341                    wgsl::Variable::WorkgroupIdX
342                }
343                cube::Builtin::CubePosY => {
344                    self.workgroup_id = true;
345                    wgsl::Variable::WorkgroupIdY
346                }
347                cube::Builtin::CubePosZ => {
348                    self.workgroup_id = true;
349                    wgsl::Variable::WorkgroupIdZ
350                }
351                cube::Builtin::CubePosCluster
352                | cube::Builtin::CubePosClusterX
353                | cube::Builtin::CubePosClusterY
354                | cube::Builtin::CubePosClusterZ => self.constant_var(1),
355                cube::Builtin::AbsolutePosX => {
356                    self.global_invocation_id = true;
357                    wgsl::Variable::GlobalInvocationIdX
358                }
359                cube::Builtin::AbsolutePosY => {
360                    self.global_invocation_id = true;
361                    wgsl::Variable::GlobalInvocationIdY
362                }
363                cube::Builtin::AbsolutePosZ => {
364                    self.global_invocation_id = true;
365                    wgsl::Variable::GlobalInvocationIdZ
366                }
367                cube::Builtin::CubeDimX => wgsl::Variable::WorkgroupSizeX,
368                cube::Builtin::CubeDimY => wgsl::Variable::WorkgroupSizeY,
369                cube::Builtin::CubeDimZ => wgsl::Variable::WorkgroupSizeZ,
370                cube::Builtin::CubeClusterDim
371                | cube::Builtin::CubeClusterDimX
372                | cube::Builtin::CubeClusterDimY
373                | cube::Builtin::CubeClusterDimZ => self.constant_var(1),
374                cube::Builtin::CubeCountX => {
375                    self.num_workgroups = true;
376                    wgsl::Variable::NumWorkgroupsX
377                }
378                cube::Builtin::CubeCountY => {
379                    self.num_workgroups = true;
380                    wgsl::Variable::NumWorkgroupsY
381                }
382                cube::Builtin::CubeCountZ => {
383                    self.num_workgroups = true;
384                    wgsl::Variable::NumWorkgroupsZ
385                }
386                cube::Builtin::CubePos => {
387                    self.workgroup_id_no_axis = true;
388                    wgsl::Variable::WorkgroupId
389                }
390                cube::Builtin::CubeDim => {
391                    self.workgroup_size_no_axis = true;
392                    wgsl::Variable::WorkgroupSize
393                }
394                cube::Builtin::CubeCount => {
395                    self.num_workgroup_no_axis = true;
396                    wgsl::Variable::NumWorkgroups
397                }
398                cube::Builtin::PlaneDim => {
399                    self.subgroup_size = true;
400                    wgsl::Variable::SubgroupSize
401                }
402                cube::Builtin::UnitPosPlane => {
403                    self.subgroup_invocation_id = true;
404                    wgsl::Variable::SubgroupInvocationId
405                }
406            },
407            cube::VariableKind::Matrix { .. } => {
408                panic!("Cooperative matrix-multiply and accumulate not supported.")
409            }
410            cube::VariableKind::Pipeline { .. } => {
411                panic!("Pipeline not supported.")
412            }
413            cube::VariableKind::BarrierToken { .. } => {
414                panic!("Barrier not supported.")
415            }
416            cube::VariableKind::TensorMapInput(_) => panic!("Tensor map not supported."),
417            cube::VariableKind::TensorMapOutput(_) => panic!("Tensor map not supported."),
418        }
419    }
420
421    fn constant_var(&mut self, value: u32) -> wgsl::Variable {
422        let var = cube::Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U32));
423        self.compile_variable(var)
424    }
425
426    fn compile_scope(&mut self, scope: &mut cube::Scope) -> Vec<wgsl::Instruction> {
427        let mut instructions = Vec::new();
428
429        let const_arrays = scope
430            .const_arrays
431            .drain(..)
432            .map(|(var, values)| ConstantArray {
433                index: var.index().unwrap(),
434                item: self.compile_type(var.ty),
435                size: values.len() as u32,
436                values: values
437                    .into_iter()
438                    .map(|val| self.compile_variable(val))
439                    .collect(),
440            })
441            .collect::<Vec<_>>();
442        self.const_arrays.extend(const_arrays);
443
444        let checked_io: Box<dyn Processor> = Box::new(CheckedIoProcessor::new(self.strategy));
445        let unroll = Box::new(UnrollProcessor::new(MAX_LINE_SIZE));
446        let saturating = Box::new(SaturatingArithmeticProcessor::new(true));
447        let processing = scope.process([&*unroll, &*checked_io, &*saturating]);
448
449        for mut var in processing.variables {
450            if var.ty.line_size() > MAX_LINE_SIZE {
451                var.ty = var.ty.line(MAX_LINE_SIZE);
452            }
453            instructions.push(wgsl::Instruction::DeclareVariable {
454                var: self.compile_variable(var),
455            });
456        }
457
458        processing
459            .instructions
460            .into_iter()
461            .for_each(|op| self.compile_operation(&mut instructions, op.operation, op.out, scope));
462
463        instructions
464    }
465
466    fn compile_operation(
467        &mut self,
468        instructions: &mut Vec<wgsl::Instruction>,
469        operation: cube::Operation,
470        out: Option<cube::Variable>,
471        scope: &mut cube::Scope,
472    ) {
473        match operation {
474            cube::Operation::Copy(variable) => instructions.push(wgsl::Instruction::Assign {
475                input: self.compile_variable(variable),
476                out: self.compile_variable(out.unwrap()),
477            }),
478            cube::Operation::Arithmetic(op) => {
479                self.compile_arithmetic(op, out, instructions, scope)
480            }
481            cube::Operation::Comparison(op) => self.compile_cmp(op, out, instructions),
482            cube::Operation::Bitwise(op) => self.compile_bitwise(op, out, instructions),
483            cube::Operation::Operator(op) => self.compile_operator(op, out, instructions),
484            cube::Operation::Atomic(op) => instructions.push(self.compile_atomic(op, out)),
485            cube::Operation::Metadata(op) => instructions.push(self.compile_metadata(op, out)),
486            cube::Operation::Branch(val) => self.compile_branch(instructions, val),
487            cube::Operation::Synchronization(val) => {
488                self.compile_synchronization(instructions, val)
489            }
490            cube::Operation::Plane(op) => self.compile_subgroup(instructions, op, out),
491            cube::Operation::CoopMma(_) => {
492                panic!("Cooperative matrix-multiply and accumulate isn't supported on wgpu.")
493            }
494            cube::Operation::NonSemantic(cube::NonSemantic::Comment { content }) => {
495                self.compile_comment(instructions, content)
496            }
497            cube::Operation::NonSemantic(_) => {}
498            cube::Operation::Barrier(_) => {
499                panic!("Barrier isn't supported on wgpu.")
500            }
501            cube::Operation::Tma(_) => panic!("TMA isn't supported on wgpu."),
502            cube::Operation::Marker(_) => {}
503        }
504    }
505
506    fn compile_subgroup(
507        &mut self,
508        instructions: &mut Vec<wgsl::Instruction>,
509        subgroup: cube::Plane,
510        out: Option<cube::Variable>,
511    ) {
512        self.subgroup_instructions_used = true;
513
514        let out = out.unwrap();
515        let op = match subgroup {
516            cube::Plane::Elect => Subgroup::Elect {
517                out: self.compile_variable(out),
518            },
519            cube::Plane::All(op) => Subgroup::All {
520                input: self.compile_variable(op.input),
521                out: self.compile_variable(out),
522            },
523            cube::Plane::Any(op) => Subgroup::Any {
524                input: self.compile_variable(op.input),
525                out: self.compile_variable(out),
526            },
527            cube::Plane::Ballot(op) => Subgroup::Ballot {
528                input: self.compile_variable(op.input),
529                out: self.compile_variable(out),
530            },
531
532            cube::Plane::Broadcast(op) => Subgroup::Broadcast {
533                lhs: self.compile_variable(op.lhs),
534                rhs: self.compile_variable(op.rhs),
535                out: self.compile_variable(out),
536            },
537
538            cube::Plane::Sum(op) => Subgroup::Sum {
539                input: self.compile_variable(op.input),
540                out: self.compile_variable(out),
541            },
542
543            cube::Plane::ExclusiveSum(op) => Subgroup::ExclusiveSum {
544                input: self.compile_variable(op.input),
545                out: self.compile_variable(out),
546            },
547            cube::Plane::InclusiveSum(op) => Subgroup::InclusiveSum {
548                input: self.compile_variable(op.input),
549                out: self.compile_variable(out),
550            },
551            cube::Plane::Prod(op) => Subgroup::Prod {
552                input: self.compile_variable(op.input),
553                out: self.compile_variable(out),
554            },
555            cube::Plane::ExclusiveProd(op) => Subgroup::ExclusiveProd {
556                input: self.compile_variable(op.input),
557                out: self.compile_variable(out),
558            },
559            cube::Plane::InclusiveProd(op) => Subgroup::InclusiveProd {
560                input: self.compile_variable(op.input),
561                out: self.compile_variable(out),
562            },
563            cube::Plane::Min(op) => Subgroup::Min {
564                input: self.compile_variable(op.input),
565                out: self.compile_variable(out),
566            },
567            cube::Plane::Max(op) => Subgroup::Max {
568                input: self.compile_variable(op.input),
569                out: self.compile_variable(out),
570            },
571            cube::Plane::Shuffle(op) => Subgroup::Shuffle {
572                lhs: self.compile_variable(op.lhs),
573                rhs: self.compile_variable(op.rhs),
574                out: self.compile_variable(out),
575            },
576            cube::Plane::ShuffleXor(op) => Subgroup::ShuffleXor {
577                lhs: self.compile_variable(op.lhs),
578                rhs: self.compile_variable(op.rhs),
579                out: self.compile_variable(out),
580            },
581            cube::Plane::ShuffleUp(op) => Subgroup::ShuffleUp {
582                lhs: self.compile_variable(op.lhs),
583                rhs: self.compile_variable(op.rhs),
584                out: self.compile_variable(out),
585            },
586            cube::Plane::ShuffleDown(op) => Subgroup::ShuffleDown {
587                lhs: self.compile_variable(op.lhs),
588                rhs: self.compile_variable(op.rhs),
589                out: self.compile_variable(out),
590            },
591        };
592
593        instructions.push(wgsl::Instruction::Subgroup(op));
594    }
595
596    fn compile_branch(&mut self, instructions: &mut Vec<wgsl::Instruction>, branch: cube::Branch) {
597        match branch {
598            cube::Branch::If(mut op) => instructions.push(wgsl::Instruction::If {
599                cond: self.compile_variable(op.cond),
600                instructions: self.compile_scope(&mut op.scope),
601            }),
602            cube::Branch::IfElse(mut op) => instructions.push(wgsl::Instruction::IfElse {
603                cond: self.compile_variable(op.cond),
604                instructions_if: self.compile_scope(&mut op.scope_if),
605                instructions_else: self.compile_scope(&mut op.scope_else),
606            }),
607            cube::Branch::Switch(mut op) => instructions.push(wgsl::Instruction::Switch {
608                value: self.compile_variable(op.value),
609                instructions_default: self.compile_scope(&mut op.scope_default),
610                cases: op
611                    .cases
612                    .into_iter()
613                    .map(|(val, mut scope)| {
614                        (self.compile_variable(val), self.compile_scope(&mut scope))
615                    })
616                    .collect(),
617            }),
618            cube::Branch::Return => instructions.push(wgsl::Instruction::Return),
619            cube::Branch::Break => instructions.push(wgsl::Instruction::Break),
620            cube::Branch::RangeLoop(mut range_loop) => {
621                instructions.push(wgsl::Instruction::RangeLoop {
622                    i: self.compile_variable(range_loop.i),
623                    start: self.compile_variable(range_loop.start),
624                    end: self.compile_variable(range_loop.end),
625                    step: range_loop.step.map(|it| self.compile_variable(it)),
626                    inclusive: range_loop.inclusive,
627                    instructions: self.compile_scope(&mut range_loop.scope),
628                })
629            }
630            cube::Branch::Loop(mut op) => instructions.push(wgsl::Instruction::Loop {
631                instructions: self.compile_scope(&mut op.scope),
632            }),
633        };
634    }
635
636    fn compile_synchronization(
637        &mut self,
638        instructions: &mut Vec<wgsl::Instruction>,
639        synchronization: cube::Synchronization,
640    ) {
641        match synchronization {
642            cube::Synchronization::SyncCube => {
643                instructions.push(wgsl::Instruction::WorkgroupBarrier)
644            }
645            cube::Synchronization::SyncPlane => {
646                panic!("Synchronization within a plane is not supported in WGSL")
647            }
648            cube::Synchronization::SyncStorage => {
649                instructions.push(wgsl::Instruction::StorageBarrier)
650            }
651            cube::Synchronization::SyncAsyncProxyShared => panic!("TMA is not supported in WGSL"),
652        };
653    }
654
655    fn compile_comment(&mut self, instructions: &mut Vec<wgsl::Instruction>, content: String) {
656        instructions.push(wgsl::Instruction::Comment { content })
657    }
658
659    fn compile_metadata(
660        &mut self,
661        metadata: cube::Metadata,
662        out: Option<cube::Variable>,
663    ) -> wgsl::Instruction {
664        let out = out.unwrap();
665        match metadata {
666            cube::Metadata::Rank { var } => {
667                let position = self.ext_meta_pos(&var);
668                let offset = self.metadata.rank_index(position);
669                wgsl::Instruction::Metadata {
670                    out: self.compile_variable(out),
671                    info_offset: self.compile_variable(offset.into()),
672                }
673            }
674            cube::Metadata::Stride { dim, var } => {
675                let position = self.ext_meta_pos(&var);
676                let offset = self.metadata.stride_offset_index(position);
677                wgsl::Instruction::ExtendedMeta {
678                    info_offset: self.compile_variable(offset.into()),
679                    dim: self.compile_variable(dim),
680                    out: self.compile_variable(out),
681                }
682            }
683            cube::Metadata::Shape { dim, var } => {
684                let position = self.ext_meta_pos(&var);
685                let offset = self.metadata.shape_offset_index(position);
686                wgsl::Instruction::ExtendedMeta {
687                    info_offset: self.compile_variable(offset.into()),
688                    dim: self.compile_variable(dim),
689                    out: self.compile_variable(out),
690                }
691            }
692            cube::Metadata::Length { var } => match var.kind {
693                cube::VariableKind::GlobalInputArray(id) => {
694                    let offset = self.metadata.len_index(id);
695                    wgsl::Instruction::Metadata {
696                        out: self.compile_variable(out),
697                        info_offset: self.compile_variable(offset.into()),
698                    }
699                }
700                cube::VariableKind::GlobalOutputArray(id) => {
701                    let offset = self.metadata.len_index(id);
702                    wgsl::Instruction::Metadata {
703                        out: self.compile_variable(out),
704                        info_offset: self.compile_variable(offset.into()),
705                    }
706                }
707                _ => wgsl::Instruction::Length {
708                    var: self.compile_variable(var),
709                    out: self.compile_variable(out),
710                },
711            },
712            cube::Metadata::BufferLength { var } => match var.kind {
713                cube::VariableKind::GlobalInputArray(id) => {
714                    let offset = self.metadata.buffer_len_index(id);
715                    wgsl::Instruction::Metadata {
716                        out: self.compile_variable(out),
717                        info_offset: self.compile_variable(offset.into()),
718                    }
719                }
720                cube::VariableKind::GlobalOutputArray(id) => {
721                    let offset = self.metadata.buffer_len_index(id);
722                    wgsl::Instruction::Metadata {
723                        out: self.compile_variable(out),
724                        info_offset: self.compile_variable(offset.into()),
725                    }
726                }
727                _ => wgsl::Instruction::Length {
728                    var: self.compile_variable(var),
729                    out: self.compile_variable(out),
730                },
731            },
732        }
733    }
734
735    fn compile_arithmetic(
736        &mut self,
737        value: cube::Arithmetic,
738        out: Option<cube::Variable>,
739        instructions: &mut Vec<wgsl::Instruction>,
740        scope: &mut Scope,
741    ) {
742        let out = out.unwrap();
743        match value {
744            cube::Arithmetic::Max(op) => instructions.push(wgsl::Instruction::Max {
745                lhs: self.compile_variable(op.lhs),
746                rhs: self.compile_variable(op.rhs),
747                out: self.compile_variable(out),
748            }),
749            cube::Arithmetic::Min(op) => instructions.push(wgsl::Instruction::Min {
750                lhs: self.compile_variable(op.lhs),
751                rhs: self.compile_variable(op.rhs),
752                out: self.compile_variable(out),
753            }),
754            cube::Arithmetic::Add(op) => instructions.push(wgsl::Instruction::Add {
755                lhs: self.compile_variable(op.lhs),
756                rhs: self.compile_variable(op.rhs),
757                out: self.compile_variable(out),
758            }),
759            cube::Arithmetic::SaturatingAdd(_) => {
760                unreachable!("Saturating add should be removed by processor");
761            }
762            cube::Arithmetic::Fma(op) => instructions.push(wgsl::Instruction::Fma {
763                a: self.compile_variable(op.a),
764                b: self.compile_variable(op.b),
765                c: self.compile_variable(op.c),
766                out: self.compile_variable(out),
767            }),
768            cube::Arithmetic::Modulo(op) => instructions.push(wgsl::Instruction::Modulo {
769                lhs: self.compile_variable(op.lhs),
770                rhs: self.compile_variable(op.rhs),
771                out: self.compile_variable(out),
772            }),
773            cube::Arithmetic::Sub(op) => instructions.push(wgsl::Instruction::Sub {
774                lhs: self.compile_variable(op.lhs),
775                rhs: self.compile_variable(op.rhs),
776                out: self.compile_variable(out),
777            }),
778            cube::Arithmetic::SaturatingSub(_) => {
779                unreachable!("Saturating sub should be removed by processor");
780            }
781            cube::Arithmetic::Mul(op) => instructions.push(wgsl::Instruction::Mul {
782                lhs: self.compile_variable(op.lhs),
783                rhs: self.compile_variable(op.rhs),
784                out: self.compile_variable(out),
785            }),
786            cube::Arithmetic::Div(op) => instructions.push(wgsl::Instruction::Div {
787                lhs: self.compile_variable(op.lhs),
788                rhs: self.compile_variable(op.rhs),
789                out: self.compile_variable(out),
790            }),
791            cube::Arithmetic::Abs(op) => instructions.push(wgsl::Instruction::Abs {
792                input: self.compile_variable(op.input),
793                out: self.compile_variable(out),
794            }),
795            cube::Arithmetic::Exp(op) => instructions.push(wgsl::Instruction::Exp {
796                input: self.compile_variable(op.input),
797                out: self.compile_variable(out),
798            }),
799            cube::Arithmetic::Log(op) => instructions.push(wgsl::Instruction::Log {
800                input: self.compile_variable(op.input),
801                out: self.compile_variable(out),
802            }),
803            cube::Arithmetic::Log1p(op) => instructions.push(wgsl::Instruction::Log1p {
804                input: self.compile_variable(op.input),
805                out: self.compile_variable(out),
806            }),
807            cube::Arithmetic::Cos(op) => instructions.push(wgsl::Instruction::Cos {
808                input: self.compile_variable(op.input),
809                out: self.compile_variable(out),
810            }),
811            cube::Arithmetic::Sin(op) => instructions.push(wgsl::Instruction::Sin {
812                input: self.compile_variable(op.input),
813                out: self.compile_variable(out),
814            }),
815            cube::Arithmetic::Tan(op) => instructions.push(wgsl::Instruction::Tan {
816                input: self.compile_variable(op.input),
817                out: self.compile_variable(out),
818            }),
819            cube::Arithmetic::Tanh(op) => instructions.push(wgsl::Instruction::Tanh {
820                input: self.compile_variable(op.input),
821                out: self.compile_variable(out),
822            }),
823            cube::Arithmetic::Sinh(op) => instructions.push(wgsl::Instruction::Sinh {
824                input: self.compile_variable(op.input),
825                out: self.compile_variable(out),
826            }),
827            cube::Arithmetic::Cosh(op) => instructions.push(wgsl::Instruction::Cosh {
828                input: self.compile_variable(op.input),
829                out: self.compile_variable(out),
830            }),
831            cube::Arithmetic::ArcCos(op) => instructions.push(wgsl::Instruction::ArcCos {
832                input: self.compile_variable(op.input),
833                out: self.compile_variable(out),
834            }),
835            cube::Arithmetic::ArcSin(op) => instructions.push(wgsl::Instruction::ArcSin {
836                input: self.compile_variable(op.input),
837                out: self.compile_variable(out),
838            }),
839            cube::Arithmetic::ArcTan(op) => instructions.push(wgsl::Instruction::ArcTan {
840                input: self.compile_variable(op.input),
841                out: self.compile_variable(out),
842            }),
843            cube::Arithmetic::ArcSinh(op) => instructions.push(wgsl::Instruction::ArcSinh {
844                input: self.compile_variable(op.input),
845                out: self.compile_variable(out),
846            }),
847            cube::Arithmetic::ArcCosh(op) => instructions.push(wgsl::Instruction::ArcCosh {
848                input: self.compile_variable(op.input),
849                out: self.compile_variable(out),
850            }),
851            cube::Arithmetic::ArcTanh(op) => instructions.push(wgsl::Instruction::ArcTanh {
852                input: self.compile_variable(op.input),
853                out: self.compile_variable(out),
854            }),
855            cube::Arithmetic::Degrees(op) => instructions.push(wgsl::Instruction::Degrees {
856                input: self.compile_variable(op.input),
857                out: self.compile_variable(out),
858            }),
859            cube::Arithmetic::Radians(op) => instructions.push(wgsl::Instruction::Radians {
860                input: self.compile_variable(op.input),
861                out: self.compile_variable(out),
862            }),
863            cube::Arithmetic::ArcTan2(op) => instructions.push(wgsl::Instruction::ArcTan2 {
864                lhs: self.compile_variable(op.lhs),
865                rhs: self.compile_variable(op.rhs),
866                out: self.compile_variable(out),
867            }),
868            // No powi in WGSL
869            cube::Arithmetic::Powf(op) | cube::Arithmetic::Powi(op) => {
870                instructions.push(wgsl::Instruction::Powf {
871                    lhs: self.compile_variable(op.lhs),
872                    rhs: self.compile_variable(op.rhs),
873                    out: self.compile_variable(out),
874                })
875            }
876            cube::Arithmetic::Hypot(op) => {
877                let mut scope = scope.child();
878                expand_hypot(&mut scope, op.lhs, op.rhs, out);
879                instructions.extend(self.compile_scope(&mut scope));
880            }
881            cube::Arithmetic::Rhypot(op) => {
882                let mut scope = scope.child();
883                expand_rhypot(&mut scope, op.lhs, op.rhs, out);
884                instructions.extend(self.compile_scope(&mut scope));
885            }
886
887            cube::Arithmetic::Sqrt(op) => instructions.push(wgsl::Instruction::Sqrt {
888                input: self.compile_variable(op.input),
889                out: self.compile_variable(out),
890            }),
891            cube::Arithmetic::InverseSqrt(op) => {
892                instructions.push(wgsl::Instruction::InverseSqrt {
893                    input: self.compile_variable(op.input),
894                    out: self.compile_variable(out),
895                })
896            }
897            cube::Arithmetic::Round(op) => instructions.push(wgsl::Instruction::Round {
898                input: self.compile_variable(op.input),
899                out: self.compile_variable(out),
900            }),
901            cube::Arithmetic::Floor(op) => instructions.push(wgsl::Instruction::Floor {
902                input: self.compile_variable(op.input),
903                out: self.compile_variable(out),
904            }),
905            cube::Arithmetic::Ceil(op) => instructions.push(wgsl::Instruction::Ceil {
906                input: self.compile_variable(op.input),
907                out: self.compile_variable(out),
908            }),
909            cube::Arithmetic::Trunc(op) => instructions.push(wgsl::Instruction::Trunc {
910                input: self.compile_variable(op.input),
911                out: self.compile_variable(out),
912            }),
913            cube::Arithmetic::Erf(op) => {
914                let mut scope = scope.child();
915                expand_erf(&mut scope, op.input, out);
916                instructions.extend(self.compile_scope(&mut scope));
917            }
918            cube::Arithmetic::MulHi(op) => {
919                let mut scope = scope.child();
920                match self.compilation_options.supports_u64 {
921                    true => expand_himul_64(&mut scope, op.lhs, op.rhs, out),
922                    false => expand_himul_sim(&mut scope, op.lhs, op.rhs, out),
923                }
924                instructions.extend(self.compile_scope(&mut scope));
925            }
926            cube::Arithmetic::Recip(op) => instructions.push(wgsl::Instruction::Recip {
927                input: self.compile_variable(op.input),
928                out: self.compile_variable(out),
929            }),
930            cube::Arithmetic::Clamp(op) => instructions.push(wgsl::Instruction::Clamp {
931                input: self.compile_variable(op.input),
932                min_value: self.compile_variable(op.min_value),
933                max_value: self.compile_variable(op.max_value),
934                out: self.compile_variable(out),
935            }),
936            cube::Arithmetic::Remainder(op) => instructions.push(wgsl::Instruction::Remainder {
937                lhs: self.compile_variable(op.lhs),
938                rhs: self.compile_variable(op.rhs),
939                out: self.compile_variable(out),
940            }),
941            cube::Arithmetic::Neg(op) => instructions.push(wgsl::Instruction::Negate {
942                input: self.compile_variable(op.input),
943                out: self.compile_variable(out),
944            }),
945            cube::Arithmetic::Magnitude(op) => instructions.push(wgsl::Instruction::Magnitude {
946                input: self.compile_variable(op.input),
947                out: self.compile_variable(out),
948            }),
949            cube::Arithmetic::Normalize(op) => instructions.push(wgsl::Instruction::Normalize {
950                input: self.compile_variable(op.input),
951                out: self.compile_variable(out),
952            }),
953            cube::Arithmetic::Dot(op) => instructions.push(wgsl::Instruction::Dot {
954                lhs: self.compile_variable(op.lhs),
955                rhs: self.compile_variable(op.rhs),
956                out: self.compile_variable(out),
957            }),
958        }
959    }
960
961    fn compile_cmp(
962        &mut self,
963        value: cube::Comparison,
964        out: Option<cube::Variable>,
965        instructions: &mut Vec<wgsl::Instruction>,
966    ) {
967        let out = out.unwrap();
968        match value {
969            cube::Comparison::Equal(op) => instructions.push(wgsl::Instruction::Equal {
970                lhs: self.compile_variable(op.lhs),
971                rhs: self.compile_variable(op.rhs),
972                out: self.compile_variable(out),
973            }),
974            cube::Comparison::Lower(op) => instructions.push(wgsl::Instruction::Lower {
975                lhs: self.compile_variable(op.lhs),
976                rhs: self.compile_variable(op.rhs),
977                out: self.compile_variable(out),
978            }),
979            cube::Comparison::Greater(op) => instructions.push(wgsl::Instruction::Greater {
980                lhs: self.compile_variable(op.lhs),
981                rhs: self.compile_variable(op.rhs),
982                out: self.compile_variable(out),
983            }),
984            cube::Comparison::LowerEqual(op) => instructions.push(wgsl::Instruction::LowerEqual {
985                lhs: self.compile_variable(op.lhs),
986                rhs: self.compile_variable(op.rhs),
987                out: self.compile_variable(out),
988            }),
989            cube::Comparison::GreaterEqual(op) => {
990                instructions.push(wgsl::Instruction::GreaterEqual {
991                    lhs: self.compile_variable(op.lhs),
992                    rhs: self.compile_variable(op.rhs),
993                    out: self.compile_variable(out),
994                })
995            }
996            cube::Comparison::NotEqual(op) => instructions.push(wgsl::Instruction::NotEqual {
997                lhs: self.compile_variable(op.lhs),
998                rhs: self.compile_variable(op.rhs),
999                out: self.compile_variable(out),
1000            }),
1001            cube::Comparison::IsNan(op) => instructions.push(wgsl::Instruction::IsNan {
1002                input: self.compile_variable(op.input),
1003                out: self.compile_variable(out),
1004            }),
1005            cube::Comparison::IsInf(op) => instructions.push(wgsl::Instruction::IsInf {
1006                input: self.compile_variable(op.input),
1007                out: self.compile_variable(out),
1008            }),
1009        }
1010    }
1011
1012    fn compile_bitwise(
1013        &mut self,
1014        value: cube::Bitwise,
1015        out: Option<cube::Variable>,
1016        instructions: &mut Vec<wgsl::Instruction>,
1017    ) {
1018        let out = out.unwrap();
1019        match value {
1020            cube::Bitwise::BitwiseOr(op) => instructions.push(wgsl::Instruction::BitwiseOr {
1021                lhs: self.compile_variable(op.lhs),
1022                rhs: self.compile_variable(op.rhs),
1023                out: self.compile_variable(out),
1024            }),
1025            cube::Bitwise::BitwiseAnd(op) => instructions.push(wgsl::Instruction::BitwiseAnd {
1026                lhs: self.compile_variable(op.lhs),
1027                rhs: self.compile_variable(op.rhs),
1028                out: self.compile_variable(out),
1029            }),
1030            cube::Bitwise::BitwiseXor(op) => instructions.push(wgsl::Instruction::BitwiseXor {
1031                lhs: self.compile_variable(op.lhs),
1032                rhs: self.compile_variable(op.rhs),
1033                out: self.compile_variable(out),
1034            }),
1035            cube::Bitwise::CountOnes(op) => instructions.push(wgsl::Instruction::CountBits {
1036                input: self.compile_variable(op.input),
1037                out: self.compile_variable(out),
1038            }),
1039            cube::Bitwise::ReverseBits(op) => instructions.push(wgsl::Instruction::ReverseBits {
1040                input: self.compile_variable(op.input),
1041                out: self.compile_variable(out),
1042            }),
1043            cube::Bitwise::ShiftLeft(op) => instructions.push(wgsl::Instruction::ShiftLeft {
1044                lhs: self.compile_variable(op.lhs),
1045                rhs: self.compile_variable(op.rhs),
1046                out: self.compile_variable(out),
1047            }),
1048            cube::Bitwise::ShiftRight(op) => instructions.push(wgsl::Instruction::ShiftRight {
1049                lhs: self.compile_variable(op.lhs),
1050                rhs: self.compile_variable(op.rhs),
1051                out: self.compile_variable(out),
1052            }),
1053            cube::Bitwise::BitwiseNot(op) => instructions.push(wgsl::Instruction::BitwiseNot {
1054                input: self.compile_variable(op.input),
1055                out: self.compile_variable(out),
1056            }),
1057            cube::Bitwise::LeadingZeros(op) => instructions.push(wgsl::Instruction::LeadingZeros {
1058                input: self.compile_variable(op.input),
1059                out: self.compile_variable(out),
1060            }),
1061            cube::Bitwise::FindFirstSet(op) => instructions.push(wgsl::Instruction::FindFirstSet {
1062                input: self.compile_variable(op.input),
1063                out: self.compile_variable(out),
1064            }),
1065        }
1066    }
1067
1068    fn compile_operator(
1069        &mut self,
1070        value: cube::Operator,
1071        out: Option<cube::Variable>,
1072        instructions: &mut Vec<wgsl::Instruction>,
1073    ) {
1074        let out = out.unwrap();
1075        match value {
1076            cube::Operator::Cast(op) => instructions.push(wgsl::Instruction::Assign {
1077                input: self.compile_variable(op.input),
1078                out: self.compile_variable(out),
1079            }),
1080            cube::Operator::Index(op) | cube::Operator::UncheckedIndex(op) => {
1081                instructions.push(wgsl::Instruction::Index {
1082                    lhs: self.compile_variable(op.list),
1083                    rhs: self.compile_variable(op.index),
1084                    out: self.compile_variable(out),
1085                });
1086            }
1087            cube::Operator::IndexAssign(op) | cube::Operator::UncheckedIndexAssign(op) => {
1088                instructions.push(wgsl::Instruction::IndexAssign {
1089                    index: self.compile_variable(op.index),
1090                    rhs: self.compile_variable(op.value),
1091                    out: self.compile_variable(out),
1092                })
1093            }
1094            cube::Operator::And(op) => instructions.push(wgsl::Instruction::And {
1095                lhs: self.compile_variable(op.lhs),
1096                rhs: self.compile_variable(op.rhs),
1097                out: self.compile_variable(out),
1098            }),
1099            cube::Operator::Or(op) => instructions.push(wgsl::Instruction::Or {
1100                lhs: self.compile_variable(op.lhs),
1101                rhs: self.compile_variable(op.rhs),
1102                out: self.compile_variable(out),
1103            }),
1104            cube::Operator::Not(op) => instructions.push(wgsl::Instruction::Not {
1105                input: self.compile_variable(op.input),
1106                out: self.compile_variable(out),
1107            }),
1108            cube::Operator::Reinterpret(op) => instructions.push(wgsl::Instruction::Bitcast {
1109                input: self.compile_variable(op.input),
1110                out: self.compile_variable(out),
1111            }),
1112            cube::Operator::InitLine(op) => instructions.push(wgsl::Instruction::VecInit {
1113                inputs: op
1114                    .inputs
1115                    .into_iter()
1116                    .map(|var| self.compile_variable(var))
1117                    .collect(),
1118                out: self.compile_variable(out),
1119            }),
1120            cube::Operator::CopyMemory(op) => instructions.push(wgsl::Instruction::Copy {
1121                input: self.compile_variable(op.input),
1122                in_index: self.compile_variable(op.in_index),
1123                out: self.compile_variable(out),
1124                out_index: self.compile_variable(op.out_index),
1125            }),
1126            cube::Operator::CopyMemoryBulk(op) => instructions.push(wgsl::Instruction::CopyBulk {
1127                input: self.compile_variable(op.input),
1128                in_index: self.compile_variable(op.in_index),
1129                out: self.compile_variable(out),
1130                out_index: self.compile_variable(op.out_index),
1131                len: op.len,
1132            }),
1133            cube::Operator::Select(op) => instructions.push(wgsl::Instruction::Select {
1134                cond: self.compile_variable(op.cond),
1135                then: self.compile_variable(op.then),
1136                or_else: self.compile_variable(op.or_else),
1137                out: self.compile_variable(out),
1138            }),
1139        }
1140    }
1141
1142    fn compile_atomic(
1143        &mut self,
1144        atomic: cube::AtomicOp,
1145        out: Option<cube::Variable>,
1146    ) -> wgsl::Instruction {
1147        let out = out.unwrap();
1148        match atomic {
1149            cube::AtomicOp::Add(op) => wgsl::Instruction::AtomicAdd {
1150                lhs: self.compile_variable(op.lhs),
1151                rhs: self.compile_variable(op.rhs),
1152                out: self.compile_variable(out),
1153            },
1154            cube::AtomicOp::Sub(op) => wgsl::Instruction::AtomicSub {
1155                lhs: self.compile_variable(op.lhs),
1156                rhs: self.compile_variable(op.rhs),
1157                out: self.compile_variable(out),
1158            },
1159            cube::AtomicOp::Max(op) => wgsl::Instruction::AtomicMax {
1160                lhs: self.compile_variable(op.lhs),
1161                rhs: self.compile_variable(op.rhs),
1162                out: self.compile_variable(out),
1163            },
1164            cube::AtomicOp::Min(op) => wgsl::Instruction::AtomicMin {
1165                lhs: self.compile_variable(op.lhs),
1166                rhs: self.compile_variable(op.rhs),
1167                out: self.compile_variable(out),
1168            },
1169            cube::AtomicOp::And(op) => wgsl::Instruction::AtomicAnd {
1170                lhs: self.compile_variable(op.lhs),
1171                rhs: self.compile_variable(op.rhs),
1172                out: self.compile_variable(out),
1173            },
1174            cube::AtomicOp::Or(op) => wgsl::Instruction::AtomicOr {
1175                lhs: self.compile_variable(op.lhs),
1176                rhs: self.compile_variable(op.rhs),
1177                out: self.compile_variable(out),
1178            },
1179            cube::AtomicOp::Xor(op) => wgsl::Instruction::AtomicXor {
1180                lhs: self.compile_variable(op.lhs),
1181                rhs: self.compile_variable(op.rhs),
1182                out: self.compile_variable(out),
1183            },
1184            cube::AtomicOp::Load(op) => wgsl::Instruction::AtomicLoad {
1185                input: self.compile_variable(op.input),
1186                out: self.compile_variable(out),
1187            },
1188            cube::AtomicOp::Store(op) => wgsl::Instruction::AtomicStore {
1189                input: self.compile_variable(op.input),
1190                out: self.compile_variable(out),
1191            },
1192            cube::AtomicOp::Swap(op) => wgsl::Instruction::AtomicSwap {
1193                lhs: self.compile_variable(op.lhs),
1194                rhs: self.compile_variable(op.rhs),
1195                out: self.compile_variable(out),
1196            },
1197            cube::AtomicOp::CompareAndSwap(op) => wgsl::Instruction::AtomicCompareExchangeWeak {
1198                lhs: self.compile_variable(op.input),
1199                cmp: self.compile_variable(op.cmp),
1200                value: self.compile_variable(op.val),
1201                out: self.compile_variable(out),
1202            },
1203        }
1204    }
1205
1206    fn compile_location(value: kernel::Location) -> wgsl::Location {
1207        match value {
1208            kernel::Location::Storage => wgsl::Location::Storage,
1209            kernel::Location::Cube => wgsl::Location::Workgroup,
1210        }
1211    }
1212
1213    fn compile_binding(&mut self, value: kernel::Binding) -> wgsl::Binding {
1214        wgsl::Binding {
1215            id: value.id,
1216            visibility: value.visibility,
1217            location: Self::compile_location(value.location),
1218            item: self.compile_type(value.ty),
1219            size: value.size,
1220        }
1221    }
1222}
1223
1224fn register_extensions(instructions: &[wgsl::Instruction]) -> Vec<wgsl::Extension> {
1225    let mut extensions = Vec::new();
1226
1227    let mut register_extension = |extension: wgsl::Extension| {
1228        if !extensions.contains(&extension) {
1229            extensions.push(extension);
1230        }
1231    };
1232
1233    // Since not all instructions are native to WGSL, we need to add the custom ones.
1234    for instruction in instructions {
1235        match instruction {
1236            wgsl::Instruction::Powf { lhs: _, rhs, out } => {
1237                register_extension(wgsl::Extension::PowfPrimitive(out.elem()));
1238                register_extension(wgsl::powf_extension(rhs, out));
1239            }
1240            #[cfg(target_os = "macos")]
1241            wgsl::Instruction::Tanh { input, out: _ } => {
1242                register_extension(wgsl::Extension::SafeTanhPrimitive(input.elem()));
1243                register_extension(wgsl::Extension::SafeTanh(input.item()));
1244            }
1245            wgsl::Instruction::IsNan { input, out } => {
1246                register_extension(wgsl::Extension::IsNanPrimitive(input.elem()));
1247                register_extension(wgsl::Extension::IsNan(input.item(), out.item()));
1248            }
1249            wgsl::Instruction::IsInf { input, out } => {
1250                register_extension(wgsl::Extension::IsInfPrimitive(input.elem()));
1251                register_extension(wgsl::Extension::IsInf(input.item(), out.item()));
1252            }
1253            wgsl::Instruction::If { instructions, .. } => {
1254                for extension in register_extensions(instructions) {
1255                    register_extension(extension);
1256                }
1257            }
1258            wgsl::Instruction::IfElse {
1259                instructions_if,
1260                instructions_else,
1261                ..
1262            } => {
1263                for extension in register_extensions(instructions_if) {
1264                    register_extension(extension);
1265                }
1266                for extension in register_extensions(instructions_else) {
1267                    register_extension(extension);
1268                }
1269            }
1270            wgsl::Instruction::Loop { instructions } => {
1271                for extension in register_extensions(instructions) {
1272                    register_extension(extension);
1273                }
1274            }
1275            wgsl::Instruction::RangeLoop { instructions, .. } => {
1276                for extension in register_extensions(instructions) {
1277                    register_extension(extension);
1278                }
1279            }
1280            _ => {}
1281        }
1282    }
1283
1284    extensions
1285}