// cubecl_wgpu/compiler/wgsl/compiler.rs

1use super::Subgroup;
2use super::{ConstantArray, shader::ComputeShader};
3use super::{Item, LocalArray, SharedMemory};
4use crate::compiler::wgsl;
5
6use cubecl_common::ExecutionMode;
7use cubecl_core::post_processing::{
8    checked_io::CheckedIoProcessor, saturating::SaturatingArithmeticProcessor,
9};
10use cubecl_core::prelude::*;
11use cubecl_core::{
12    Metadata, WgpuCompilationOptions,
13    ir::{self as cube, Scope},
14    prelude::expand_erf,
15};
16use cubecl_core::{
17    ir::{ConstantScalarValue, Processor, UIntKind},
18    post_processing::unroll::UnrollProcessor,
19};
20use cubecl_runtime::compiler::CompilationError;
21use cubecl_runtime::kernel;
22
/// Widest line (vector) size emitted in WGSL, matching the largest native vector
/// type (`vec4`). Wider lines are clamped to this width and split by the unroll
/// processor.
pub const MAX_LINE_SIZE: u32 = 4;
24
25/// Wgsl Compiler.
26#[derive(Clone, Default)]
27pub struct WgslCompiler {
28    metadata: Metadata,
29    ext_meta_pos: Vec<u32>,
30    local_invocation_index: bool,
31    local_invocation_id: bool,
32    // TODO: possible cleanup, this bool seems to not be used
33    global_invocation_id: bool,
34    workgroup_id: bool,
35    subgroup_size: bool,
36    subgroup_invocation_id: bool,
37    id: bool,
38    num_workgroups: bool,
39    workgroup_id_no_axis: bool,
40    workgroup_size_no_axis: bool,
41    num_workgroup_no_axis: bool,
42    shared_memories: Vec<SharedMemory>,
43    const_arrays: Vec<ConstantArray>,
44    local_arrays: Vec<LocalArray>,
45    #[allow(dead_code)]
46    compilation_options: WgpuCompilationOptions,
47    strategy: ExecutionMode,
48    subgroup_instructions_used: bool,
49    f16_used: bool,
50}
51
52impl core::fmt::Debug for WgslCompiler {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        f.write_str("WgslCompiler")
55    }
56}
57
58impl cubecl_core::Compiler for WgslCompiler {
59    type Representation = ComputeShader;
60    type CompilationOptions = WgpuCompilationOptions;
61
62    fn compile(
63        &mut self,
64        shader: kernel::KernelDefinition,
65        compilation_options: &Self::CompilationOptions,
66        mode: ExecutionMode,
67    ) -> Result<Self::Representation, CompilationError> {
68        self.compilation_options = compilation_options.clone();
69        self.compile_shader(shader, mode)
70    }
71
72    fn elem_size(&self, elem: cube::ElemType) -> usize {
73        elem.size()
74    }
75
76    fn extension(&self) -> &'static str {
77        "wgsl"
78    }
79}
80
81impl WgslCompiler {
82    fn compile_shader(
83        &mut self,
84        mut value: kernel::KernelDefinition,
85        mode: ExecutionMode,
86    ) -> Result<wgsl::ComputeShader, CompilationError> {
87        self.strategy = mode;
88
89        let num_meta = value.buffers.len();
90
91        self.ext_meta_pos = Vec::new();
92        let mut num_ext = 0;
93
94        for binding in value.buffers.iter() {
95            self.ext_meta_pos.push(num_ext);
96            if binding.has_extended_meta {
97                num_ext += 1;
98            }
99        }
100
101        self.metadata = Metadata::new(num_meta as u32, num_ext);
102
103        let instructions = self.compile_scope(&mut value.body);
104        let extensions = register_extensions(&instructions);
105        let body = wgsl::Body {
106            instructions,
107            id: self.id,
108        };
109
110        Ok(wgsl::ComputeShader {
111            buffers: value
112                .buffers
113                .into_iter()
114                .map(|mut it| {
115                    // This is safe when combined with the unroll transform that adjusts all indices.
116                    // Must not be used alone
117                    if it.ty.line_size() > MAX_LINE_SIZE {
118                        it.ty = it.ty.line(MAX_LINE_SIZE);
119                    }
120                    self.compile_binding(it)
121                })
122                .collect(),
123            scalars: value
124                .scalars
125                .into_iter()
126                .map(|binding| (self.compile_storage_type(binding.ty), binding.count))
127                .collect(),
128            shared_memories: self.shared_memories.clone(),
129            constant_arrays: self.const_arrays.clone(),
130            local_arrays: self.local_arrays.clone(),
131            has_metadata: self.metadata.static_len() > 0,
132            workgroup_size: value.cube_dim,
133            global_invocation_id: self.global_invocation_id || self.id,
134            local_invocation_index: self.local_invocation_index,
135            local_invocation_id: self.local_invocation_id,
136            num_workgroups: self.id
137                || self.num_workgroups
138                || self.num_workgroup_no_axis
139                || self.workgroup_id_no_axis,
140            workgroup_id: self.workgroup_id || self.workgroup_id_no_axis,
141            subgroup_size: self.subgroup_size,
142            subgroup_invocation_id: self.subgroup_invocation_id,
143            body,
144            extensions,
145            num_workgroups_no_axis: self.num_workgroup_no_axis,
146            workgroup_id_no_axis: self.workgroup_id_no_axis,
147            workgroup_size_no_axis: self.workgroup_size_no_axis,
148            subgroup_instructions_used: self.subgroup_instructions_used,
149            f16_used: self.f16_used,
150            kernel_name: value.options.kernel_name,
151        })
152    }
153
154    fn compile_type(&mut self, item: cube::Type) -> Item {
155        match item {
156            cube::Type::Scalar(ty) => wgsl::Item::Scalar(self.compile_storage_type(ty)),
157            cube::Type::Line(ty, size) => {
158                let elem = self.compile_storage_type(ty);
159                match size {
160                    2 => wgsl::Item::Vec2(elem),
161                    3 => wgsl::Item::Vec3(elem),
162                    4 => wgsl::Item::Vec4(elem),
163                    _ => panic!("Unsupported vectorizations scheme {:?}", item.line_size()),
164                }
165            }
166            cube::Type::Semantic(_) => unimplemented!("Can't compile semantic type"),
167        }
168    }
169
170    fn compile_storage_type(&mut self, ty: cube::StorageType) -> wgsl::Elem {
171        match ty {
172            cube::StorageType::Scalar(ty) => self.compile_elem(ty),
173            cube::StorageType::Atomic(ty) => match ty {
174                cube::ElemType::Float(i) => match i {
175                    cube::FloatKind::F32 => wgsl::Elem::AtomicF32,
176                    kind => panic!("atomic<{kind:?}> is not a valid WgpuElement"),
177                },
178                cube::ElemType::Int(i) => match i {
179                    cube::IntKind::I32 => wgsl::Elem::AtomicI32,
180                    kind => panic!("atomic<{kind:?}> is not a valid WgpuElement"),
181                },
182                cube::ElemType::UInt(kind) => match kind {
183                    cube::UIntKind::U32 => wgsl::Elem::AtomicU32,
184                    kind => panic!("{kind:?} is not a valid WgpuElement"),
185                },
186                other => panic!("{other:?} is not a valid WgpuElement"),
187            },
188            cube::StorageType::Packed(_, _) => {
189                unimplemented!("Packed types not yet supported in WGSL")
190            }
191        }
192    }
193
194    fn compile_elem(&mut self, value: cube::ElemType) -> wgsl::Elem {
195        match value {
196            cube::ElemType::Float(f) => match f {
197                cube::FloatKind::E2M1
198                | cube::FloatKind::E2M3
199                | cube::FloatKind::E3M2
200                | cube::FloatKind::E4M3
201                | cube::FloatKind::E5M2
202                | cube::FloatKind::UE8M0 => panic!("Minifloat is not a valid WgpuElement"),
203                cube::FloatKind::F16 => {
204                    self.f16_used = true;
205                    wgsl::Elem::F16
206                }
207                cube::FloatKind::BF16 => panic!("bf16 is not a valid WgpuElement"),
208                cube::FloatKind::TF32 => panic!("tf32 is not a valid WgpuElement"),
209                cube::FloatKind::Flex32 => wgsl::Elem::F32,
210                cube::FloatKind::F32 => wgsl::Elem::F32,
211                cube::FloatKind::F64 => wgsl::Elem::F64,
212            },
213            cube::ElemType::Int(i) => match i {
214                cube::IntKind::I32 => wgsl::Elem::I32,
215                cube::IntKind::I64 => wgsl::Elem::I64,
216                kind => panic!("{kind:?} is not a valid WgpuElement"),
217            },
218            cube::ElemType::UInt(kind) => match kind {
219                cube::UIntKind::U32 => wgsl::Elem::U32,
220                cube::UIntKind::U64 => wgsl::Elem::U64,
221                kind => panic!("{kind:?} is not a valid WgpuElement"),
222            },
223            cube::ElemType::Bool => wgsl::Elem::Bool,
224        }
225    }
226
227    fn ext_meta_pos(&self, var: &cube::Variable) -> u32 {
228        let pos = var.index().expect("Variable should have index");
229        self.ext_meta_pos[pos as usize]
230    }
231
232    pub(crate) fn compile_variable(&mut self, value: cube::Variable) -> wgsl::Variable {
233        let item = value.ty;
234        match value.kind {
235            cube::VariableKind::GlobalInputArray(id) => {
236                wgsl::Variable::GlobalInputArray(id, self.compile_type(item))
237            }
238            cube::VariableKind::GlobalScalar(id) => {
239                wgsl::Variable::GlobalScalar(id, self.compile_storage_type(item.storage_type()))
240            }
241            cube::VariableKind::LocalMut { id } | cube::VariableKind::Versioned { id, .. } => {
242                wgsl::Variable::LocalMut {
243                    id,
244                    item: self.compile_type(item),
245                }
246            }
247            cube::VariableKind::LocalConst { id } => wgsl::Variable::LocalConst {
248                id,
249                item: self.compile_type(item),
250            },
251            cube::VariableKind::GlobalOutputArray(id) => {
252                wgsl::Variable::GlobalOutputArray(id, self.compile_type(item))
253            }
254            cube::VariableKind::ConstantScalar(value) => {
255                wgsl::Variable::ConstantScalar(value, self.compile_elem(value.elem_type()))
256            }
257            cube::VariableKind::SharedMemory {
258                id,
259                length,
260                unroll_factor,
261                alignment,
262            } => {
263                let item = self.compile_type(item);
264                if !self.shared_memories.iter().any(|s| s.index == id) {
265                    self.shared_memories.push(SharedMemory::new(
266                        id,
267                        item,
268                        length * unroll_factor,
269                        alignment,
270                    ));
271                }
272                wgsl::Variable::SharedMemory(id, item, length)
273            }
274            cube::VariableKind::ConstantArray { id, length, .. } => {
275                let item = self.compile_type(item);
276                wgsl::Variable::ConstantArray(id, item, length)
277            }
278            cube::VariableKind::LocalArray {
279                id,
280                length,
281                unroll_factor,
282            } => {
283                let item = self.compile_type(item);
284                if !self.local_arrays.iter().any(|s| s.index == id) {
285                    self.local_arrays
286                        .push(LocalArray::new(id, item, length * unroll_factor));
287                }
288                wgsl::Variable::LocalArray(id, item, length)
289            }
290            cube::VariableKind::Builtin(builtin) => match builtin {
291                cube::Builtin::AbsolutePos => {
292                    self.id = true;
293                    wgsl::Variable::Id
294                }
295                cube::Builtin::UnitPos => {
296                    self.local_invocation_index = true;
297                    wgsl::Variable::LocalInvocationIndex
298                }
299                cube::Builtin::UnitPosX => {
300                    self.local_invocation_id = true;
301                    wgsl::Variable::LocalInvocationIdX
302                }
303                cube::Builtin::UnitPosY => {
304                    self.local_invocation_id = true;
305                    wgsl::Variable::LocalInvocationIdY
306                }
307                cube::Builtin::UnitPosZ => {
308                    self.local_invocation_id = true;
309                    wgsl::Variable::LocalInvocationIdZ
310                }
311                cube::Builtin::CubePosX => {
312                    self.workgroup_id = true;
313                    wgsl::Variable::WorkgroupIdX
314                }
315                cube::Builtin::CubePosY => {
316                    self.workgroup_id = true;
317                    wgsl::Variable::WorkgroupIdY
318                }
319                cube::Builtin::CubePosZ => {
320                    self.workgroup_id = true;
321                    wgsl::Variable::WorkgroupIdZ
322                }
323                cube::Builtin::CubePosCluster
324                | cube::Builtin::CubePosClusterX
325                | cube::Builtin::CubePosClusterY
326                | cube::Builtin::CubePosClusterZ => self.constant_var(1),
327                cube::Builtin::AbsolutePosX => {
328                    self.global_invocation_id = true;
329                    wgsl::Variable::GlobalInvocationIdX
330                }
331                cube::Builtin::AbsolutePosY => {
332                    self.global_invocation_id = true;
333                    wgsl::Variable::GlobalInvocationIdY
334                }
335                cube::Builtin::AbsolutePosZ => {
336                    self.global_invocation_id = true;
337                    wgsl::Variable::GlobalInvocationIdZ
338                }
339                cube::Builtin::CubeDimX => wgsl::Variable::WorkgroupSizeX,
340                cube::Builtin::CubeDimY => wgsl::Variable::WorkgroupSizeY,
341                cube::Builtin::CubeDimZ => wgsl::Variable::WorkgroupSizeZ,
342                cube::Builtin::CubeClusterDim
343                | cube::Builtin::CubeClusterDimX
344                | cube::Builtin::CubeClusterDimY
345                | cube::Builtin::CubeClusterDimZ => self.constant_var(1),
346                cube::Builtin::CubeCountX => {
347                    self.num_workgroups = true;
348                    wgsl::Variable::NumWorkgroupsX
349                }
350                cube::Builtin::CubeCountY => {
351                    self.num_workgroups = true;
352                    wgsl::Variable::NumWorkgroupsY
353                }
354                cube::Builtin::CubeCountZ => {
355                    self.num_workgroups = true;
356                    wgsl::Variable::NumWorkgroupsZ
357                }
358                cube::Builtin::CubePos => {
359                    self.workgroup_id_no_axis = true;
360                    wgsl::Variable::WorkgroupId
361                }
362                cube::Builtin::CubeDim => {
363                    self.workgroup_size_no_axis = true;
364                    wgsl::Variable::WorkgroupSize
365                }
366                cube::Builtin::CubeCount => {
367                    self.num_workgroup_no_axis = true;
368                    wgsl::Variable::NumWorkgroups
369                }
370                cube::Builtin::PlaneDim => {
371                    self.subgroup_size = true;
372                    wgsl::Variable::SubgroupSize
373                }
374                cube::Builtin::UnitPosPlane => {
375                    self.subgroup_invocation_id = true;
376                    wgsl::Variable::SubgroupInvocationId
377                }
378            },
379            cube::VariableKind::Matrix { .. } => {
380                panic!("Cooperative matrix-multiply and accumulate not supported.")
381            }
382            cube::VariableKind::Pipeline { .. } => {
383                panic!("Pipeline not supported.")
384            }
385            cube::VariableKind::Barrier { .. } | cube::VariableKind::BarrierToken { .. } => {
386                panic!("Barrier not supported.")
387            }
388            cube::VariableKind::TensorMapInput(_) => panic!("Tensor map not supported."),
389            cube::VariableKind::TensorMapOutput(_) => panic!("Tensor map not supported."),
390        }
391    }
392
393    fn constant_var(&mut self, value: u32) -> wgsl::Variable {
394        let var = cube::Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U32));
395        self.compile_variable(var)
396    }
397
398    fn compile_scope(&mut self, scope: &mut cube::Scope) -> Vec<wgsl::Instruction> {
399        let mut instructions = Vec::new();
400
401        let const_arrays = scope
402            .const_arrays
403            .drain(..)
404            .map(|(var, values)| ConstantArray {
405                index: var.index().unwrap(),
406                item: self.compile_type(var.ty),
407                size: values.len() as u32,
408                values: values
409                    .into_iter()
410                    .map(|val| self.compile_variable(val))
411                    .collect(),
412            })
413            .collect::<Vec<_>>();
414        self.const_arrays.extend(const_arrays);
415
416        let checked_io: Box<dyn Processor> = Box::new(CheckedIoProcessor::new(self.strategy));
417        let unroll = Box::new(UnrollProcessor::new(MAX_LINE_SIZE));
418        let saturating = Box::new(SaturatingArithmeticProcessor::new(true));
419        let processing = scope.process([&*unroll, &*checked_io, &*saturating]);
420
421        for mut var in processing.variables {
422            if var.ty.line_size() > MAX_LINE_SIZE {
423                var.ty = var.ty.line(MAX_LINE_SIZE);
424            }
425            instructions.push(wgsl::Instruction::DeclareVariable {
426                var: self.compile_variable(var),
427            });
428        }
429
430        processing
431            .instructions
432            .into_iter()
433            .for_each(|op| self.compile_operation(&mut instructions, op.operation, op.out, scope));
434
435        instructions
436    }
437
438    fn compile_operation(
439        &mut self,
440        instructions: &mut Vec<wgsl::Instruction>,
441        operation: cube::Operation,
442        out: Option<cube::Variable>,
443        scope: &mut cube::Scope,
444    ) {
445        match operation {
446            cube::Operation::Copy(variable) => instructions.push(wgsl::Instruction::Assign {
447                input: self.compile_variable(variable),
448                out: self.compile_variable(out.unwrap()),
449            }),
450            cube::Operation::Arithmetic(op) => {
451                self.compile_arithmetic(op, out, instructions, scope)
452            }
453            cube::Operation::Comparison(op) => self.compile_cmp(op, out, instructions),
454            cube::Operation::Bitwise(op) => self.compile_bitwise(op, out, instructions),
455            cube::Operation::Operator(op) => self.compile_operator(op, out, instructions),
456            cube::Operation::Atomic(op) => instructions.push(self.compile_atomic(op, out)),
457            cube::Operation::Metadata(op) => instructions.push(self.compile_metadata(op, out)),
458            cube::Operation::Branch(val) => self.compile_branch(instructions, val),
459            cube::Operation::Synchronization(val) => {
460                self.compile_synchronization(instructions, val)
461            }
462            cube::Operation::Plane(op) => self.compile_subgroup(instructions, op, out),
463            cube::Operation::CoopMma(_) => {
464                panic!("Cooperative matrix-multiply and accumulate isn't supported on wgpu.")
465            }
466            cube::Operation::NonSemantic(cube::NonSemantic::Comment { content }) => {
467                self.compile_comment(instructions, content)
468            }
469            cube::Operation::NonSemantic(_) => {}
470            cube::Operation::Barrier(_) => {
471                panic!("Barrier isn't supported on wgpu.")
472            }
473            cube::Operation::Tma(_) => panic!("TMA isn't supported on wgpu."),
474            cube::Operation::Marker(_) => {}
475        }
476    }
477
478    fn compile_subgroup(
479        &mut self,
480        instructions: &mut Vec<wgsl::Instruction>,
481        subgroup: cube::Plane,
482        out: Option<cube::Variable>,
483    ) {
484        self.subgroup_instructions_used = true;
485
486        let out = out.unwrap();
487        let op = match subgroup {
488            cube::Plane::Elect => Subgroup::Elect {
489                out: self.compile_variable(out),
490            },
491            cube::Plane::All(op) => Subgroup::All {
492                input: self.compile_variable(op.input),
493                out: self.compile_variable(out),
494            },
495            cube::Plane::Any(op) => Subgroup::Any {
496                input: self.compile_variable(op.input),
497                out: self.compile_variable(out),
498            },
499            cube::Plane::Ballot(op) => Subgroup::Ballot {
500                input: self.compile_variable(op.input),
501                out: self.compile_variable(out),
502            },
503
504            cube::Plane::Broadcast(op) => Subgroup::Broadcast {
505                lhs: self.compile_variable(op.lhs),
506                rhs: self.compile_variable(op.rhs),
507                out: self.compile_variable(out),
508            },
509
510            cube::Plane::Sum(op) => Subgroup::Sum {
511                input: self.compile_variable(op.input),
512                out: self.compile_variable(out),
513            },
514
515            cube::Plane::ExclusiveSum(op) => Subgroup::ExclusiveSum {
516                input: self.compile_variable(op.input),
517                out: self.compile_variable(out),
518            },
519            cube::Plane::InclusiveSum(op) => Subgroup::InclusiveSum {
520                input: self.compile_variable(op.input),
521                out: self.compile_variable(out),
522            },
523            cube::Plane::Prod(op) => Subgroup::Prod {
524                input: self.compile_variable(op.input),
525                out: self.compile_variable(out),
526            },
527            cube::Plane::ExclusiveProd(op) => Subgroup::ExclusiveProd {
528                input: self.compile_variable(op.input),
529                out: self.compile_variable(out),
530            },
531            cube::Plane::InclusiveProd(op) => Subgroup::InclusiveProd {
532                input: self.compile_variable(op.input),
533                out: self.compile_variable(out),
534            },
535            cube::Plane::Min(op) => Subgroup::Min {
536                input: self.compile_variable(op.input),
537                out: self.compile_variable(out),
538            },
539            cube::Plane::Max(op) => Subgroup::Max {
540                input: self.compile_variable(op.input),
541                out: self.compile_variable(out),
542            },
543            cube::Plane::Shuffle(op) => Subgroup::Shuffle {
544                lhs: self.compile_variable(op.lhs),
545                rhs: self.compile_variable(op.rhs),
546                out: self.compile_variable(out),
547            },
548            cube::Plane::ShuffleXor(op) => Subgroup::ShuffleXor {
549                lhs: self.compile_variable(op.lhs),
550                rhs: self.compile_variable(op.rhs),
551                out: self.compile_variable(out),
552            },
553            cube::Plane::ShuffleUp(op) => Subgroup::ShuffleUp {
554                lhs: self.compile_variable(op.lhs),
555                rhs: self.compile_variable(op.rhs),
556                out: self.compile_variable(out),
557            },
558            cube::Plane::ShuffleDown(op) => Subgroup::ShuffleDown {
559                lhs: self.compile_variable(op.lhs),
560                rhs: self.compile_variable(op.rhs),
561                out: self.compile_variable(out),
562            },
563        };
564
565        instructions.push(wgsl::Instruction::Subgroup(op));
566    }
567
568    fn compile_branch(&mut self, instructions: &mut Vec<wgsl::Instruction>, branch: cube::Branch) {
569        match branch {
570            cube::Branch::If(mut op) => instructions.push(wgsl::Instruction::If {
571                cond: self.compile_variable(op.cond),
572                instructions: self.compile_scope(&mut op.scope),
573            }),
574            cube::Branch::IfElse(mut op) => instructions.push(wgsl::Instruction::IfElse {
575                cond: self.compile_variable(op.cond),
576                instructions_if: self.compile_scope(&mut op.scope_if),
577                instructions_else: self.compile_scope(&mut op.scope_else),
578            }),
579            cube::Branch::Switch(mut op) => instructions.push(wgsl::Instruction::Switch {
580                value: self.compile_variable(op.value),
581                instructions_default: self.compile_scope(&mut op.scope_default),
582                cases: op
583                    .cases
584                    .into_iter()
585                    .map(|(val, mut scope)| {
586                        (self.compile_variable(val), self.compile_scope(&mut scope))
587                    })
588                    .collect(),
589            }),
590            cube::Branch::Return => instructions.push(wgsl::Instruction::Return),
591            cube::Branch::Break => instructions.push(wgsl::Instruction::Break),
592            cube::Branch::RangeLoop(mut range_loop) => {
593                instructions.push(wgsl::Instruction::RangeLoop {
594                    i: self.compile_variable(range_loop.i),
595                    start: self.compile_variable(range_loop.start),
596                    end: self.compile_variable(range_loop.end),
597                    step: range_loop.step.map(|it| self.compile_variable(it)),
598                    inclusive: range_loop.inclusive,
599                    instructions: self.compile_scope(&mut range_loop.scope),
600                })
601            }
602            cube::Branch::Loop(mut op) => instructions.push(wgsl::Instruction::Loop {
603                instructions: self.compile_scope(&mut op.scope),
604            }),
605        };
606    }
607
608    fn compile_synchronization(
609        &mut self,
610        instructions: &mut Vec<wgsl::Instruction>,
611        synchronization: cube::Synchronization,
612    ) {
613        match synchronization {
614            cube::Synchronization::SyncCube => {
615                instructions.push(wgsl::Instruction::WorkgroupBarrier)
616            }
617            cube::Synchronization::SyncPlane => {
618                panic!("Synchronization within a plane is not supported in WGSL")
619            }
620            cube::Synchronization::SyncStorage => {
621                instructions.push(wgsl::Instruction::StorageBarrier)
622            }
623            cube::Synchronization::SyncAsyncProxyShared => panic!("TMA is not supported in WGSL"),
624        };
625    }
626
627    fn compile_comment(&mut self, instructions: &mut Vec<wgsl::Instruction>, content: String) {
628        instructions.push(wgsl::Instruction::Comment { content })
629    }
630
631    fn compile_metadata(
632        &mut self,
633        metadata: cube::Metadata,
634        out: Option<cube::Variable>,
635    ) -> wgsl::Instruction {
636        let out = out.unwrap();
637        match metadata {
638            cube::Metadata::Rank { var } => {
639                let position = self.ext_meta_pos(&var);
640                let offset = self.metadata.rank_index(position);
641                wgsl::Instruction::Metadata {
642                    out: self.compile_variable(out),
643                    info_offset: self.compile_variable(offset.into()),
644                }
645            }
646            cube::Metadata::Stride { dim, var } => {
647                let position = self.ext_meta_pos(&var);
648                let offset = self.metadata.stride_offset_index(position);
649                wgsl::Instruction::ExtendedMeta {
650                    info_offset: self.compile_variable(offset.into()),
651                    dim: self.compile_variable(dim),
652                    out: self.compile_variable(out),
653                }
654            }
655            cube::Metadata::Shape { dim, var } => {
656                let position = self.ext_meta_pos(&var);
657                let offset = self.metadata.shape_offset_index(position);
658                wgsl::Instruction::ExtendedMeta {
659                    info_offset: self.compile_variable(offset.into()),
660                    dim: self.compile_variable(dim),
661                    out: self.compile_variable(out),
662                }
663            }
664            cube::Metadata::Length { var } => match var.kind {
665                cube::VariableKind::GlobalInputArray(id) => {
666                    let offset = self.metadata.len_index(id);
667                    wgsl::Instruction::Metadata {
668                        out: self.compile_variable(out),
669                        info_offset: self.compile_variable(offset.into()),
670                    }
671                }
672                cube::VariableKind::GlobalOutputArray(id) => {
673                    let offset = self.metadata.len_index(id);
674                    wgsl::Instruction::Metadata {
675                        out: self.compile_variable(out),
676                        info_offset: self.compile_variable(offset.into()),
677                    }
678                }
679                _ => wgsl::Instruction::Length {
680                    var: self.compile_variable(var),
681                    out: self.compile_variable(out),
682                },
683            },
684            cube::Metadata::BufferLength { var } => match var.kind {
685                cube::VariableKind::GlobalInputArray(id) => {
686                    let offset = self.metadata.buffer_len_index(id);
687                    wgsl::Instruction::Metadata {
688                        out: self.compile_variable(out),
689                        info_offset: self.compile_variable(offset.into()),
690                    }
691                }
692                cube::VariableKind::GlobalOutputArray(id) => {
693                    let offset = self.metadata.buffer_len_index(id);
694                    wgsl::Instruction::Metadata {
695                        out: self.compile_variable(out),
696                        info_offset: self.compile_variable(offset.into()),
697                    }
698                }
699                _ => wgsl::Instruction::Length {
700                    var: self.compile_variable(var),
701                    out: self.compile_variable(out),
702                },
703            },
704        }
705    }
706
707    fn compile_arithmetic(
708        &mut self,
709        value: cube::Arithmetic,
710        out: Option<cube::Variable>,
711        instructions: &mut Vec<wgsl::Instruction>,
712        scope: &mut Scope,
713    ) {
714        let out = out.unwrap();
715        match value {
716            cube::Arithmetic::Max(op) => instructions.push(wgsl::Instruction::Max {
717                lhs: self.compile_variable(op.lhs),
718                rhs: self.compile_variable(op.rhs),
719                out: self.compile_variable(out),
720            }),
721            cube::Arithmetic::Min(op) => instructions.push(wgsl::Instruction::Min {
722                lhs: self.compile_variable(op.lhs),
723                rhs: self.compile_variable(op.rhs),
724                out: self.compile_variable(out),
725            }),
726            cube::Arithmetic::Add(op) => instructions.push(wgsl::Instruction::Add {
727                lhs: self.compile_variable(op.lhs),
728                rhs: self.compile_variable(op.rhs),
729                out: self.compile_variable(out),
730            }),
731            cube::Arithmetic::SaturatingAdd(_) => {
732                unreachable!("Saturating add should be removed by processor");
733            }
734            cube::Arithmetic::Fma(op) => instructions.push(wgsl::Instruction::Fma {
735                a: self.compile_variable(op.a),
736                b: self.compile_variable(op.b),
737                c: self.compile_variable(op.c),
738                out: self.compile_variable(out),
739            }),
740            cube::Arithmetic::Modulo(op) => instructions.push(wgsl::Instruction::Modulo {
741                lhs: self.compile_variable(op.lhs),
742                rhs: self.compile_variable(op.rhs),
743                out: self.compile_variable(out),
744            }),
745            cube::Arithmetic::Sub(op) => instructions.push(wgsl::Instruction::Sub {
746                lhs: self.compile_variable(op.lhs),
747                rhs: self.compile_variable(op.rhs),
748                out: self.compile_variable(out),
749            }),
750            cube::Arithmetic::SaturatingSub(_) => {
751                unreachable!("Saturating sub should be removed by processor");
752            }
753            cube::Arithmetic::Mul(op) => instructions.push(wgsl::Instruction::Mul {
754                lhs: self.compile_variable(op.lhs),
755                rhs: self.compile_variable(op.rhs),
756                out: self.compile_variable(out),
757            }),
758            cube::Arithmetic::Div(op) => instructions.push(wgsl::Instruction::Div {
759                lhs: self.compile_variable(op.lhs),
760                rhs: self.compile_variable(op.rhs),
761                out: self.compile_variable(out),
762            }),
763            cube::Arithmetic::Abs(op) => instructions.push(wgsl::Instruction::Abs {
764                input: self.compile_variable(op.input),
765                out: self.compile_variable(out),
766            }),
767            cube::Arithmetic::Exp(op) => instructions.push(wgsl::Instruction::Exp {
768                input: self.compile_variable(op.input),
769                out: self.compile_variable(out),
770            }),
771            cube::Arithmetic::Log(op) => instructions.push(wgsl::Instruction::Log {
772                input: self.compile_variable(op.input),
773                out: self.compile_variable(out),
774            }),
775            cube::Arithmetic::Log1p(op) => instructions.push(wgsl::Instruction::Log1p {
776                input: self.compile_variable(op.input),
777                out: self.compile_variable(out),
778            }),
779            cube::Arithmetic::Cos(op) => instructions.push(wgsl::Instruction::Cos {
780                input: self.compile_variable(op.input),
781                out: self.compile_variable(out),
782            }),
783            cube::Arithmetic::Sin(op) => instructions.push(wgsl::Instruction::Sin {
784                input: self.compile_variable(op.input),
785                out: self.compile_variable(out),
786            }),
787            cube::Arithmetic::Tan(op) => instructions.push(wgsl::Instruction::Tan {
788                input: self.compile_variable(op.input),
789                out: self.compile_variable(out),
790            }),
791            cube::Arithmetic::Tanh(op) => instructions.push(wgsl::Instruction::Tanh {
792                input: self.compile_variable(op.input),
793                out: self.compile_variable(out),
794            }),
795            cube::Arithmetic::Sinh(op) => instructions.push(wgsl::Instruction::Sinh {
796                input: self.compile_variable(op.input),
797                out: self.compile_variable(out),
798            }),
799            cube::Arithmetic::Cosh(op) => instructions.push(wgsl::Instruction::Cosh {
800                input: self.compile_variable(op.input),
801                out: self.compile_variable(out),
802            }),
803            cube::Arithmetic::ArcCos(op) => instructions.push(wgsl::Instruction::ArcCos {
804                input: self.compile_variable(op.input),
805                out: self.compile_variable(out),
806            }),
807            cube::Arithmetic::ArcSin(op) => instructions.push(wgsl::Instruction::ArcSin {
808                input: self.compile_variable(op.input),
809                out: self.compile_variable(out),
810            }),
811            cube::Arithmetic::ArcTan(op) => instructions.push(wgsl::Instruction::ArcTan {
812                input: self.compile_variable(op.input),
813                out: self.compile_variable(out),
814            }),
815            cube::Arithmetic::ArcSinh(op) => instructions.push(wgsl::Instruction::ArcSinh {
816                input: self.compile_variable(op.input),
817                out: self.compile_variable(out),
818            }),
819            cube::Arithmetic::ArcCosh(op) => instructions.push(wgsl::Instruction::ArcCosh {
820                input: self.compile_variable(op.input),
821                out: self.compile_variable(out),
822            }),
823            cube::Arithmetic::ArcTanh(op) => instructions.push(wgsl::Instruction::ArcTanh {
824                input: self.compile_variable(op.input),
825                out: self.compile_variable(out),
826            }),
827            cube::Arithmetic::Degrees(op) => instructions.push(wgsl::Instruction::Degrees {
828                input: self.compile_variable(op.input),
829                out: self.compile_variable(out),
830            }),
831            cube::Arithmetic::Radians(op) => instructions.push(wgsl::Instruction::Radians {
832                input: self.compile_variable(op.input),
833                out: self.compile_variable(out),
834            }),
835            cube::Arithmetic::ArcTan2(op) => instructions.push(wgsl::Instruction::ArcTan2 {
836                lhs: self.compile_variable(op.lhs),
837                rhs: self.compile_variable(op.rhs),
838                out: self.compile_variable(out),
839            }),
840            // No powi in WGSL
841            cube::Arithmetic::Powf(op) | cube::Arithmetic::Powi(op) => {
842                instructions.push(wgsl::Instruction::Powf {
843                    lhs: self.compile_variable(op.lhs),
844                    rhs: self.compile_variable(op.rhs),
845                    out: self.compile_variable(out),
846                })
847            }
848            cube::Arithmetic::Sqrt(op) => instructions.push(wgsl::Instruction::Sqrt {
849                input: self.compile_variable(op.input),
850                out: self.compile_variable(out),
851            }),
852            cube::Arithmetic::InverseSqrt(op) => {
853                instructions.push(wgsl::Instruction::InverseSqrt {
854                    input: self.compile_variable(op.input),
855                    out: self.compile_variable(out),
856                })
857            }
858            cube::Arithmetic::Round(op) => instructions.push(wgsl::Instruction::Round {
859                input: self.compile_variable(op.input),
860                out: self.compile_variable(out),
861            }),
862            cube::Arithmetic::Floor(op) => instructions.push(wgsl::Instruction::Floor {
863                input: self.compile_variable(op.input),
864                out: self.compile_variable(out),
865            }),
866            cube::Arithmetic::Ceil(op) => instructions.push(wgsl::Instruction::Ceil {
867                input: self.compile_variable(op.input),
868                out: self.compile_variable(out),
869            }),
870            cube::Arithmetic::Trunc(op) => instructions.push(wgsl::Instruction::Trunc {
871                input: self.compile_variable(op.input),
872                out: self.compile_variable(out),
873            }),
874            cube::Arithmetic::Erf(op) => {
875                let mut scope = scope.child();
876                expand_erf(&mut scope, op.input, out);
877                instructions.extend(self.compile_scope(&mut scope));
878            }
879            cube::Arithmetic::MulHi(op) => {
880                let mut scope = scope.child();
881                match self.compilation_options.supports_u64 {
882                    true => expand_himul_64(&mut scope, op.lhs, op.rhs, out),
883                    false => expand_himul_sim(&mut scope, op.lhs, op.rhs, out),
884                }
885                instructions.extend(self.compile_scope(&mut scope));
886            }
887            cube::Arithmetic::Recip(op) => instructions.push(wgsl::Instruction::Recip {
888                input: self.compile_variable(op.input),
889                out: self.compile_variable(out),
890            }),
891            cube::Arithmetic::Clamp(op) => instructions.push(wgsl::Instruction::Clamp {
892                input: self.compile_variable(op.input),
893                min_value: self.compile_variable(op.min_value),
894                max_value: self.compile_variable(op.max_value),
895                out: self.compile_variable(out),
896            }),
897            cube::Arithmetic::Remainder(op) => instructions.push(wgsl::Instruction::Remainder {
898                lhs: self.compile_variable(op.lhs),
899                rhs: self.compile_variable(op.rhs),
900                out: self.compile_variable(out),
901            }),
902            cube::Arithmetic::Neg(op) => instructions.push(wgsl::Instruction::Negate {
903                input: self.compile_variable(op.input),
904                out: self.compile_variable(out),
905            }),
906            cube::Arithmetic::Magnitude(op) => instructions.push(wgsl::Instruction::Magnitude {
907                input: self.compile_variable(op.input),
908                out: self.compile_variable(out),
909            }),
910            cube::Arithmetic::Normalize(op) => instructions.push(wgsl::Instruction::Normalize {
911                input: self.compile_variable(op.input),
912                out: self.compile_variable(out),
913            }),
914            cube::Arithmetic::Dot(op) => instructions.push(wgsl::Instruction::Dot {
915                lhs: self.compile_variable(op.lhs),
916                rhs: self.compile_variable(op.rhs),
917                out: self.compile_variable(out),
918            }),
919        }
920    }
921
922    fn compile_cmp(
923        &mut self,
924        value: cube::Comparison,
925        out: Option<cube::Variable>,
926        instructions: &mut Vec<wgsl::Instruction>,
927    ) {
928        let out = out.unwrap();
929        match value {
930            cube::Comparison::Equal(op) => instructions.push(wgsl::Instruction::Equal {
931                lhs: self.compile_variable(op.lhs),
932                rhs: self.compile_variable(op.rhs),
933                out: self.compile_variable(out),
934            }),
935            cube::Comparison::Lower(op) => instructions.push(wgsl::Instruction::Lower {
936                lhs: self.compile_variable(op.lhs),
937                rhs: self.compile_variable(op.rhs),
938                out: self.compile_variable(out),
939            }),
940            cube::Comparison::Greater(op) => instructions.push(wgsl::Instruction::Greater {
941                lhs: self.compile_variable(op.lhs),
942                rhs: self.compile_variable(op.rhs),
943                out: self.compile_variable(out),
944            }),
945            cube::Comparison::LowerEqual(op) => instructions.push(wgsl::Instruction::LowerEqual {
946                lhs: self.compile_variable(op.lhs),
947                rhs: self.compile_variable(op.rhs),
948                out: self.compile_variable(out),
949            }),
950            cube::Comparison::GreaterEqual(op) => {
951                instructions.push(wgsl::Instruction::GreaterEqual {
952                    lhs: self.compile_variable(op.lhs),
953                    rhs: self.compile_variable(op.rhs),
954                    out: self.compile_variable(out),
955                })
956            }
957            cube::Comparison::NotEqual(op) => instructions.push(wgsl::Instruction::NotEqual {
958                lhs: self.compile_variable(op.lhs),
959                rhs: self.compile_variable(op.rhs),
960                out: self.compile_variable(out),
961            }),
962            cube::Comparison::IsNan(op) => instructions.push(wgsl::Instruction::IsNan {
963                input: self.compile_variable(op.input),
964                out: self.compile_variable(out),
965            }),
966            cube::Comparison::IsInf(op) => instructions.push(wgsl::Instruction::IsInf {
967                input: self.compile_variable(op.input),
968                out: self.compile_variable(out),
969            }),
970        }
971    }
972
973    fn compile_bitwise(
974        &mut self,
975        value: cube::Bitwise,
976        out: Option<cube::Variable>,
977        instructions: &mut Vec<wgsl::Instruction>,
978    ) {
979        let out = out.unwrap();
980        match value {
981            cube::Bitwise::BitwiseOr(op) => instructions.push(wgsl::Instruction::BitwiseOr {
982                lhs: self.compile_variable(op.lhs),
983                rhs: self.compile_variable(op.rhs),
984                out: self.compile_variable(out),
985            }),
986            cube::Bitwise::BitwiseAnd(op) => instructions.push(wgsl::Instruction::BitwiseAnd {
987                lhs: self.compile_variable(op.lhs),
988                rhs: self.compile_variable(op.rhs),
989                out: self.compile_variable(out),
990            }),
991            cube::Bitwise::BitwiseXor(op) => instructions.push(wgsl::Instruction::BitwiseXor {
992                lhs: self.compile_variable(op.lhs),
993                rhs: self.compile_variable(op.rhs),
994                out: self.compile_variable(out),
995            }),
996            cube::Bitwise::CountOnes(op) => instructions.push(wgsl::Instruction::CountBits {
997                input: self.compile_variable(op.input),
998                out: self.compile_variable(out),
999            }),
1000            cube::Bitwise::ReverseBits(op) => instructions.push(wgsl::Instruction::ReverseBits {
1001                input: self.compile_variable(op.input),
1002                out: self.compile_variable(out),
1003            }),
1004            cube::Bitwise::ShiftLeft(op) => instructions.push(wgsl::Instruction::ShiftLeft {
1005                lhs: self.compile_variable(op.lhs),
1006                rhs: self.compile_variable(op.rhs),
1007                out: self.compile_variable(out),
1008            }),
1009            cube::Bitwise::ShiftRight(op) => instructions.push(wgsl::Instruction::ShiftRight {
1010                lhs: self.compile_variable(op.lhs),
1011                rhs: self.compile_variable(op.rhs),
1012                out: self.compile_variable(out),
1013            }),
1014            cube::Bitwise::BitwiseNot(op) => instructions.push(wgsl::Instruction::BitwiseNot {
1015                input: self.compile_variable(op.input),
1016                out: self.compile_variable(out),
1017            }),
1018            cube::Bitwise::LeadingZeros(op) => instructions.push(wgsl::Instruction::LeadingZeros {
1019                input: self.compile_variable(op.input),
1020                out: self.compile_variable(out),
1021            }),
1022            cube::Bitwise::FindFirstSet(op) => instructions.push(wgsl::Instruction::FindFirstSet {
1023                input: self.compile_variable(op.input),
1024                out: self.compile_variable(out),
1025            }),
1026        }
1027    }
1028
1029    fn compile_operator(
1030        &mut self,
1031        value: cube::Operator,
1032        out: Option<cube::Variable>,
1033        instructions: &mut Vec<wgsl::Instruction>,
1034    ) {
1035        let out = out.unwrap();
1036        match value {
1037            cube::Operator::Cast(op) => instructions.push(wgsl::Instruction::Assign {
1038                input: self.compile_variable(op.input),
1039                out: self.compile_variable(out),
1040            }),
1041            cube::Operator::Index(op) | cube::Operator::UncheckedIndex(op) => {
1042                instructions.push(wgsl::Instruction::Index {
1043                    lhs: self.compile_variable(op.list),
1044                    rhs: self.compile_variable(op.index),
1045                    out: self.compile_variable(out),
1046                });
1047            }
1048            cube::Operator::IndexAssign(op) | cube::Operator::UncheckedIndexAssign(op) => {
1049                instructions.push(wgsl::Instruction::IndexAssign {
1050                    index: self.compile_variable(op.index),
1051                    rhs: self.compile_variable(op.value),
1052                    out: self.compile_variable(out),
1053                })
1054            }
1055            cube::Operator::And(op) => instructions.push(wgsl::Instruction::And {
1056                lhs: self.compile_variable(op.lhs),
1057                rhs: self.compile_variable(op.rhs),
1058                out: self.compile_variable(out),
1059            }),
1060            cube::Operator::Or(op) => instructions.push(wgsl::Instruction::Or {
1061                lhs: self.compile_variable(op.lhs),
1062                rhs: self.compile_variable(op.rhs),
1063                out: self.compile_variable(out),
1064            }),
1065            cube::Operator::Not(op) => instructions.push(wgsl::Instruction::Not {
1066                input: self.compile_variable(op.input),
1067                out: self.compile_variable(out),
1068            }),
1069            cube::Operator::Reinterpret(op) => instructions.push(wgsl::Instruction::Bitcast {
1070                input: self.compile_variable(op.input),
1071                out: self.compile_variable(out),
1072            }),
1073            cube::Operator::InitLine(op) => instructions.push(wgsl::Instruction::VecInit {
1074                inputs: op
1075                    .inputs
1076                    .into_iter()
1077                    .map(|var| self.compile_variable(var))
1078                    .collect(),
1079                out: self.compile_variable(out),
1080            }),
1081            cube::Operator::CopyMemory(op) => instructions.push(wgsl::Instruction::Copy {
1082                input: self.compile_variable(op.input),
1083                in_index: self.compile_variable(op.in_index),
1084                out: self.compile_variable(out),
1085                out_index: self.compile_variable(op.out_index),
1086            }),
1087            cube::Operator::CopyMemoryBulk(op) => instructions.push(wgsl::Instruction::CopyBulk {
1088                input: self.compile_variable(op.input),
1089                in_index: self.compile_variable(op.in_index),
1090                out: self.compile_variable(out),
1091                out_index: self.compile_variable(op.out_index),
1092                len: op.len,
1093            }),
1094            cube::Operator::Select(op) => instructions.push(wgsl::Instruction::Select {
1095                cond: self.compile_variable(op.cond),
1096                then: self.compile_variable(op.then),
1097                or_else: self.compile_variable(op.or_else),
1098                out: self.compile_variable(out),
1099            }),
1100        }
1101    }
1102
1103    fn compile_atomic(
1104        &mut self,
1105        atomic: cube::AtomicOp,
1106        out: Option<cube::Variable>,
1107    ) -> wgsl::Instruction {
1108        let out = out.unwrap();
1109        match atomic {
1110            cube::AtomicOp::Add(op) => wgsl::Instruction::AtomicAdd {
1111                lhs: self.compile_variable(op.lhs),
1112                rhs: self.compile_variable(op.rhs),
1113                out: self.compile_variable(out),
1114            },
1115            cube::AtomicOp::Sub(op) => wgsl::Instruction::AtomicSub {
1116                lhs: self.compile_variable(op.lhs),
1117                rhs: self.compile_variable(op.rhs),
1118                out: self.compile_variable(out),
1119            },
1120            cube::AtomicOp::Max(op) => wgsl::Instruction::AtomicMax {
1121                lhs: self.compile_variable(op.lhs),
1122                rhs: self.compile_variable(op.rhs),
1123                out: self.compile_variable(out),
1124            },
1125            cube::AtomicOp::Min(op) => wgsl::Instruction::AtomicMin {
1126                lhs: self.compile_variable(op.lhs),
1127                rhs: self.compile_variable(op.rhs),
1128                out: self.compile_variable(out),
1129            },
1130            cube::AtomicOp::And(op) => wgsl::Instruction::AtomicAnd {
1131                lhs: self.compile_variable(op.lhs),
1132                rhs: self.compile_variable(op.rhs),
1133                out: self.compile_variable(out),
1134            },
1135            cube::AtomicOp::Or(op) => wgsl::Instruction::AtomicOr {
1136                lhs: self.compile_variable(op.lhs),
1137                rhs: self.compile_variable(op.rhs),
1138                out: self.compile_variable(out),
1139            },
1140            cube::AtomicOp::Xor(op) => wgsl::Instruction::AtomicXor {
1141                lhs: self.compile_variable(op.lhs),
1142                rhs: self.compile_variable(op.rhs),
1143                out: self.compile_variable(out),
1144            },
1145            cube::AtomicOp::Load(op) => wgsl::Instruction::AtomicLoad {
1146                input: self.compile_variable(op.input),
1147                out: self.compile_variable(out),
1148            },
1149            cube::AtomicOp::Store(op) => wgsl::Instruction::AtomicStore {
1150                input: self.compile_variable(op.input),
1151                out: self.compile_variable(out),
1152            },
1153            cube::AtomicOp::Swap(op) => wgsl::Instruction::AtomicSwap {
1154                lhs: self.compile_variable(op.lhs),
1155                rhs: self.compile_variable(op.rhs),
1156                out: self.compile_variable(out),
1157            },
1158            cube::AtomicOp::CompareAndSwap(op) => wgsl::Instruction::AtomicCompareExchangeWeak {
1159                lhs: self.compile_variable(op.input),
1160                cmp: self.compile_variable(op.cmp),
1161                value: self.compile_variable(op.val),
1162                out: self.compile_variable(out),
1163            },
1164        }
1165    }
1166
1167    fn compile_location(value: kernel::Location) -> wgsl::Location {
1168        match value {
1169            kernel::Location::Storage => wgsl::Location::Storage,
1170            kernel::Location::Cube => wgsl::Location::Workgroup,
1171        }
1172    }
1173
1174    fn compile_binding(&mut self, value: kernel::Binding) -> wgsl::Binding {
1175        wgsl::Binding {
1176            id: value.id,
1177            visibility: value.visibility,
1178            location: Self::compile_location(value.location),
1179            item: self.compile_type(value.ty),
1180            size: value.size,
1181        }
1182    }
1183}
1184
1185fn register_extensions(instructions: &[wgsl::Instruction]) -> Vec<wgsl::Extension> {
1186    let mut extensions = Vec::new();
1187
1188    let mut register_extension = |extension: wgsl::Extension| {
1189        if !extensions.contains(&extension) {
1190            extensions.push(extension);
1191        }
1192    };
1193
1194    // Since not all instructions are native to WGSL, we need to add the custom ones.
1195    for instruction in instructions {
1196        match instruction {
1197            wgsl::Instruction::Powf { lhs: _, rhs, out } => {
1198                register_extension(wgsl::Extension::PowfPrimitive(out.elem()));
1199                register_extension(wgsl::powf_extension(rhs, out));
1200            }
1201            #[cfg(target_os = "macos")]
1202            wgsl::Instruction::Tanh { input, out: _ } => {
1203                register_extension(wgsl::Extension::SafeTanhPrimitive(input.elem()));
1204                register_extension(wgsl::Extension::SafeTanh(input.item()));
1205            }
1206            wgsl::Instruction::IsNan { input, out } => {
1207                register_extension(wgsl::Extension::IsNanPrimitive(input.elem()));
1208                register_extension(wgsl::Extension::IsNan(input.item(), out.item()));
1209            }
1210            wgsl::Instruction::IsInf { input, out } => {
1211                register_extension(wgsl::Extension::IsInfPrimitive(input.elem()));
1212                register_extension(wgsl::Extension::IsInf(input.item(), out.item()));
1213            }
1214            wgsl::Instruction::If { instructions, .. } => {
1215                for extension in register_extensions(instructions) {
1216                    register_extension(extension);
1217                }
1218            }
1219            wgsl::Instruction::IfElse {
1220                instructions_if,
1221                instructions_else,
1222                ..
1223            } => {
1224                for extension in register_extensions(instructions_if) {
1225                    register_extension(extension);
1226                }
1227                for extension in register_extensions(instructions_else) {
1228                    register_extension(extension);
1229                }
1230            }
1231            wgsl::Instruction::Loop { instructions } => {
1232                for extension in register_extensions(instructions) {
1233                    register_extension(extension);
1234                }
1235            }
1236            wgsl::Instruction::RangeLoop { instructions, .. } => {
1237                for extension in register_extensions(instructions) {
1238                    register_extension(extension);
1239                }
1240            }
1241            _ => {}
1242        }
1243    }
1244
1245    extensions
1246}