Skip to main content

cubecl_wgpu/compiler/wgsl/
compiler.rs

1use super::Subgroup;
2use super::{ConstantArray, shader::ComputeShader};
3use super::{Item, LocalArray, SharedArray};
4use crate::compiler::wgsl::{self, SharedValue};
5
6use cubecl_common::backtrace::BackTrace;
7use cubecl_core::prelude::*;
8use cubecl_core::{
9    Info,
10    post_processing::{checked_io::CheckedIoProcessor, saturating::SaturatingArithmeticProcessor},
11};
12use cubecl_core::{
13    Metadata, WgpuCompilationOptions,
14    ir::{self as cube, Scope},
15    prelude::expand_erf,
16};
17use cubecl_core::{
18    ir::{Processor, UIntKind},
19    post_processing::unroll::UnrollProcessor,
20};
21use cubecl_runtime::compiler::CompilationError;
22use cubecl_runtime::kernel;
23
/// Widest vector (in lanes) emitted by this compiler; WGSL items only go up to `vec4`.
///
/// The unroll processor is configured with this limit, and buffer bindings and declared
/// variables wider than it are clamped to it.
pub const MAX_VECTOR_SIZE: usize = 4;
25
26/// Wgsl Compiler.
27#[derive(Clone, Default)]
28pub struct WgslCompiler {
29    kernel_name: String,
30    info: Info,
31    ext_meta_pos: Vec<u32>,
32    local_invocation_index: bool,
33    local_invocation_id: bool,
34    // TODO: possible cleanup, this bool seems to not be used
35    global_invocation_id: bool,
36    workgroup_id: bool,
37    subgroup_size: bool,
38    subgroup_id: bool,
39    subgroup_invocation_id: bool,
40    id: bool,
41    num_workgroups: bool,
42    workgroup_id_no_axis: bool,
43    workgroup_size_no_axis: bool,
44    num_workgroup_no_axis: bool,
45    shared_arrays: Vec<SharedArray>,
46    shared_values: Vec<SharedValue>,
47    const_arrays: Vec<ConstantArray>,
48    local_arrays: Vec<LocalArray>,
49    #[allow(dead_code)]
50    compilation_options: WgpuCompilationOptions,
51    strategy: ExecutionMode,
52    subgroup_instructions_used: bool,
53    f16_used: bool,
54}
55
56impl core::fmt::Debug for WgslCompiler {
57    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58        f.write_str("WgslCompiler")
59    }
60}
61
62impl cubecl_core::Compiler for WgslCompiler {
63    type Representation = ComputeShader;
64    type CompilationOptions = WgpuCompilationOptions;
65
66    fn compile(
67        &mut self,
68        shader: kernel::KernelDefinition,
69        compilation_options: &Self::CompilationOptions,
70        mode: ExecutionMode,
71        address_type: StorageType,
72    ) -> Result<Self::Representation, CompilationError> {
73        self.compilation_options = *compilation_options;
74        self.compile_shader(shader, mode, address_type)
75    }
76
77    fn elem_size(&self, elem: cube::ElemType) -> usize {
78        elem.size()
79    }
80
81    fn extension(&self) -> &'static str {
82        "wgsl"
83    }
84}
85
86impl WgslCompiler {
87    fn compile_shader(
88        &mut self,
89        mut value: kernel::KernelDefinition,
90        mode: ExecutionMode,
91        address_type: StorageType,
92    ) -> Result<wgsl::ComputeShader, CompilationError> {
93        let errors = value.body.pop_errors();
94        if !errors.is_empty() {
95            let mut reason = "Can't compile wgsl kernel".to_string();
96            for error in errors {
97                reason += error.as_str();
98                reason += "\n";
99            }
100
101            return Err(CompilationError::Validation {
102                reason,
103                backtrace: BackTrace::capture(),
104            });
105        }
106
107        self.strategy = mode;
108        self.kernel_name = value.options.kernel_name.clone();
109
110        let num_meta = value.buffers.len();
111
112        self.ext_meta_pos = Vec::new();
113        let mut num_ext = 0;
114
115        for binding in value.buffers.iter() {
116            self.ext_meta_pos.push(num_ext);
117            if binding.has_extended_meta {
118                num_ext += 1;
119            }
120        }
121
122        let metadata = Metadata::new(num_meta as u32, num_ext);
123        self.info = Info::new(&value.scalars, metadata, address_type);
124
125        let address_type = self.compile_storage_type(address_type);
126        let instructions = self.compile_scope(&mut value.body);
127        let extensions = register_extensions(&instructions);
128        let body = wgsl::Body {
129            instructions,
130            id: self.id,
131            address_type,
132        };
133
134        Ok(wgsl::ComputeShader {
135            address_type,
136            buffers: value
137                .buffers
138                .into_iter()
139                .map(|mut it| {
140                    // This is safe when combined with the unroll transform that adjusts all indices.
141                    // Must not be used alone
142                    if it.ty.vector_size() > MAX_VECTOR_SIZE {
143                        it.ty = it.ty.with_vector_size(MAX_VECTOR_SIZE);
144                    }
145                    self.compile_binding(it)
146                })
147                .collect(),
148            scalars: value
149                .scalars
150                .into_iter()
151                .map(|binding| (self.compile_storage_type(binding.ty), binding.count))
152                .collect(),
153            shared_arrays: self.shared_arrays.clone(),
154            shared_values: self.shared_values.clone(),
155            constant_arrays: self.const_arrays.clone(),
156            local_arrays: self.local_arrays.clone(),
157            static_meta_len: self.info.metadata.static_len() as usize,
158            info: self.info.clone(),
159            workgroup_size: value.cube_dim,
160            global_invocation_id: self.global_invocation_id || self.id,
161            local_invocation_index: self.local_invocation_index,
162            local_invocation_id: self.local_invocation_id,
163            num_workgroups: self.id
164                || self.num_workgroups
165                || self.num_workgroup_no_axis
166                || self.workgroup_id_no_axis,
167            workgroup_id: self.workgroup_id || self.workgroup_id_no_axis,
168            subgroup_size: self.subgroup_size,
169            subgroup_id: self.subgroup_id,
170            subgroup_invocation_id: self.subgroup_invocation_id,
171            body,
172            extensions,
173            num_workgroups_no_axis: self.num_workgroup_no_axis,
174            workgroup_id_no_axis: self.workgroup_id_no_axis,
175            workgroup_size_no_axis: self.workgroup_size_no_axis,
176            subgroup_instructions_used: self.subgroup_instructions_used,
177            f16_used: self.f16_used,
178            kernel_name: value.options.kernel_name,
179        })
180    }
181
182    fn compile_type(&mut self, item: cube::Type) -> Item {
183        match item {
184            cube::Type::Scalar(ty) => wgsl::Item::Scalar(self.compile_storage_type(ty)),
185            cube::Type::Vector(ty, size) => {
186                let elem = self.compile_storage_type(ty);
187                match size {
188                    2 => wgsl::Item::Vec2(elem),
189                    3 => wgsl::Item::Vec3(elem),
190                    4 => wgsl::Item::Vec4(elem),
191                    _ => panic!("Unsupported vectorizations scheme {:?}", item.vector_size()),
192                }
193            }
194            cube::Type::Semantic(_) => unimplemented!("Can't compile semantic type"),
195        }
196    }
197
198    fn compile_storage_type(&mut self, ty: cube::StorageType) -> wgsl::Elem {
199        match ty {
200            cube::StorageType::Scalar(ty) => self.compile_elem(ty),
201            cube::StorageType::Atomic(ty) => match ty {
202                cube::ElemType::Float(i) => match i {
203                    cube::FloatKind::F32 => wgsl::Elem::AtomicF32,
204                    kind => panic!("atomic<{kind:?}> is not a valid WgpuElement"),
205                },
206                cube::ElemType::Int(i) => match i {
207                    cube::IntKind::I32 => wgsl::Elem::AtomicI32,
208                    kind => panic!("atomic<{kind:?}> is not a valid WgpuElement"),
209                },
210                cube::ElemType::UInt(kind) => match kind {
211                    cube::UIntKind::U32 => wgsl::Elem::AtomicU32,
212                    kind => panic!("{kind:?} is not a valid WgpuElement"),
213                },
214                other => panic!("{other:?} is not a valid WgpuElement"),
215            },
216            cube::StorageType::Packed(_, _) => {
217                unimplemented!("Packed types not yet supported in WGSL")
218            }
219            cube::StorageType::Opaque(ty) => match ty {
220                cube::OpaqueType::Barrier(_) => {
221                    unimplemented!("Barrier objects not supported in WGSL")
222                }
223            },
224        }
225    }
226
227    fn compile_elem(&mut self, value: cube::ElemType) -> wgsl::Elem {
228        match value {
229            cube::ElemType::Float(f) => match f {
230                cube::FloatKind::E2M1
231                | cube::FloatKind::E2M3
232                | cube::FloatKind::E3M2
233                | cube::FloatKind::E4M3
234                | cube::FloatKind::E5M2
235                | cube::FloatKind::UE8M0 => panic!("Minifloat is not a valid WgpuElement"),
236                cube::FloatKind::F16 => {
237                    self.f16_used = true;
238                    wgsl::Elem::F16
239                }
240                cube::FloatKind::BF16 => panic!("bf16 is not a valid WgpuElement"),
241                cube::FloatKind::TF32 => panic!("tf32 is not a valid WgpuElement"),
242                cube::FloatKind::Flex32 => wgsl::Elem::F32,
243                cube::FloatKind::F32 => wgsl::Elem::F32,
244                cube::FloatKind::F64 => wgsl::Elem::F64,
245            },
246            cube::ElemType::Int(i) => match i {
247                cube::IntKind::I32 => wgsl::Elem::I32,
248                cube::IntKind::I64 => wgsl::Elem::I64,
249                kind => panic!("{kind:?} is not a valid WgpuElement"),
250            },
251            cube::ElemType::UInt(kind) => match kind {
252                cube::UIntKind::U32 => wgsl::Elem::U32,
253                cube::UIntKind::U64 => wgsl::Elem::U64,
254                kind => panic!("{kind:?} is not a valid WgpuElement"),
255            },
256            cube::ElemType::Bool => wgsl::Elem::Bool,
257        }
258    }
259
260    fn ext_meta_pos(&self, var: &cube::Variable) -> u32 {
261        let pos = var.index().expect("Variable should have index");
262        self.ext_meta_pos[pos as usize]
263    }
264
265    pub(crate) fn compile_variable(&mut self, value: cube::Variable) -> wgsl::Variable {
266        let item = value.ty;
267        match value.kind {
268            cube::VariableKind::GlobalInputArray(id) => {
269                wgsl::Variable::GlobalInputArray(id, self.compile_type(item))
270            }
271            cube::VariableKind::GlobalScalar(id) => {
272                wgsl::Variable::GlobalScalar(id, self.compile_storage_type(item.storage_type()))
273            }
274            cube::VariableKind::LocalMut { id } | cube::VariableKind::Versioned { id, .. } => {
275                wgsl::Variable::LocalMut {
276                    id,
277                    item: self.compile_type(item),
278                }
279            }
280            cube::VariableKind::LocalConst { id } => wgsl::Variable::LocalConst {
281                id,
282                item: self.compile_type(item),
283            },
284            cube::VariableKind::GlobalOutputArray(id) => {
285                wgsl::Variable::GlobalOutputArray(id, self.compile_type(item))
286            }
287            cube::VariableKind::Constant(value) => {
288                wgsl::Variable::Constant(value, self.compile_type(item))
289            }
290            cube::VariableKind::SharedArray {
291                id,
292                length,
293                unroll_factor,
294                alignment,
295            } => {
296                let item = self.compile_type(item);
297                if !self.shared_arrays.iter().any(|s| s.index == id) {
298                    self.shared_arrays.push(SharedArray::new(
299                        id,
300                        item,
301                        (length * unroll_factor) as u32,
302                        alignment.map(|it| it as u32),
303                    ));
304                }
305                wgsl::Variable::SharedArray(id, item, length as u32)
306            }
307            cube::VariableKind::Shared { id } => {
308                let item = self.compile_type(item);
309                if !self.shared_values.iter().any(|s| s.index == id) {
310                    self.shared_values.push(SharedValue::new(id, item));
311                }
312                wgsl::Variable::SharedValue(id, item)
313            }
314            cube::VariableKind::ConstantArray { id, length, .. } => {
315                let item = self.compile_type(item);
316                wgsl::Variable::ConstantArray(id, item, length as u32)
317            }
318            cube::VariableKind::LocalArray {
319                id,
320                length,
321                unroll_factor,
322            } => {
323                let item = self.compile_type(item);
324                if !self.local_arrays.iter().any(|s| s.index == id) {
325                    self.local_arrays.push(LocalArray::new(
326                        id,
327                        item,
328                        (length * unroll_factor) as u32,
329                    ));
330                }
331                wgsl::Variable::LocalArray(id, item, length as u32)
332            }
333            cube::VariableKind::Builtin(builtin) => match builtin {
334                cube::Builtin::AbsolutePos => {
335                    self.id = true;
336                    wgsl::Variable::Id
337                }
338                cube::Builtin::UnitPos => {
339                    self.local_invocation_index = true;
340                    wgsl::Variable::LocalInvocationIndex
341                }
342                cube::Builtin::UnitPosX => {
343                    self.local_invocation_id = true;
344                    wgsl::Variable::LocalInvocationIdX
345                }
346                cube::Builtin::UnitPosY => {
347                    self.local_invocation_id = true;
348                    wgsl::Variable::LocalInvocationIdY
349                }
350                cube::Builtin::UnitPosZ => {
351                    self.local_invocation_id = true;
352                    wgsl::Variable::LocalInvocationIdZ
353                }
354                cube::Builtin::CubePosX => {
355                    self.workgroup_id = true;
356                    wgsl::Variable::WorkgroupIdX
357                }
358                cube::Builtin::CubePosY => {
359                    self.workgroup_id = true;
360                    wgsl::Variable::WorkgroupIdY
361                }
362                cube::Builtin::CubePosZ => {
363                    self.workgroup_id = true;
364                    wgsl::Variable::WorkgroupIdZ
365                }
366                cube::Builtin::CubePosCluster
367                | cube::Builtin::CubePosClusterX
368                | cube::Builtin::CubePosClusterY
369                | cube::Builtin::CubePosClusterZ => self.constant_var(1),
370                cube::Builtin::AbsolutePosX => {
371                    self.global_invocation_id = true;
372                    wgsl::Variable::GlobalInvocationIdX
373                }
374                cube::Builtin::AbsolutePosY => {
375                    self.global_invocation_id = true;
376                    wgsl::Variable::GlobalInvocationIdY
377                }
378                cube::Builtin::AbsolutePosZ => {
379                    self.global_invocation_id = true;
380                    wgsl::Variable::GlobalInvocationIdZ
381                }
382                cube::Builtin::CubeDimX => wgsl::Variable::WorkgroupSizeX,
383                cube::Builtin::CubeDimY => wgsl::Variable::WorkgroupSizeY,
384                cube::Builtin::CubeDimZ => wgsl::Variable::WorkgroupSizeZ,
385                cube::Builtin::CubeClusterDim
386                | cube::Builtin::CubeClusterDimX
387                | cube::Builtin::CubeClusterDimY
388                | cube::Builtin::CubeClusterDimZ => self.constant_var(1),
389                cube::Builtin::CubeCountX => {
390                    self.num_workgroups = true;
391                    wgsl::Variable::NumWorkgroupsX
392                }
393                cube::Builtin::CubeCountY => {
394                    self.num_workgroups = true;
395                    wgsl::Variable::NumWorkgroupsY
396                }
397                cube::Builtin::CubeCountZ => {
398                    self.num_workgroups = true;
399                    wgsl::Variable::NumWorkgroupsZ
400                }
401                cube::Builtin::CubePos => {
402                    self.workgroup_id_no_axis = true;
403                    wgsl::Variable::WorkgroupId
404                }
405                cube::Builtin::CubeDim => {
406                    self.workgroup_size_no_axis = true;
407                    wgsl::Variable::WorkgroupSize
408                }
409                cube::Builtin::CubeCount => {
410                    self.num_workgroup_no_axis = true;
411                    wgsl::Variable::NumWorkgroups
412                }
413                cube::Builtin::PlaneDim => {
414                    self.subgroup_size = true;
415                    wgsl::Variable::SubgroupSize
416                }
417                cube::Builtin::PlanePos => {
418                    self.subgroup_id = true;
419                    wgsl::Variable::SubgroupId
420                }
421                cube::Builtin::UnitPosPlane => {
422                    self.subgroup_invocation_id = true;
423                    wgsl::Variable::SubgroupInvocationId
424                }
425            },
426            cube::VariableKind::Matrix { .. } => {
427                panic!("Cooperative matrix-multiply and accumulate not supported.")
428            }
429            cube::VariableKind::Pipeline { .. } => {
430                panic!("Pipeline not supported.")
431            }
432            cube::VariableKind::BarrierToken { .. } => {
433                panic!("Barrier not supported.")
434            }
435            cube::VariableKind::TensorMapInput(_) => panic!("Tensor map not supported."),
436            cube::VariableKind::TensorMapOutput(_) => panic!("Tensor map not supported."),
437        }
438    }
439
440    fn constant_var(&mut self, value: u32) -> wgsl::Variable {
441        let var = cube::Variable::constant(value.into(), UIntKind::U32);
442        self.compile_variable(var)
443    }
444
445    fn compile_scope(&mut self, scope: &mut cube::Scope) -> Vec<wgsl::Instruction> {
446        let mut instructions = Vec::new();
447
448        let const_arrays = scope
449            .const_arrays
450            .drain(..)
451            .map(|(var, values)| ConstantArray {
452                index: var.index().unwrap(),
453                item: self.compile_type(var.ty),
454                size: values.len() as u32,
455                values: values
456                    .into_iter()
457                    .map(|val| self.compile_variable(val))
458                    .collect(),
459            })
460            .collect::<Vec<_>>();
461        self.const_arrays.extend(const_arrays);
462
463        let checked_io: Box<dyn Processor> = Box::new(CheckedIoProcessor::new(
464            self.strategy,
465            self.kernel_name.clone(),
466        ));
467        let unroll = Box::new(UnrollProcessor::new(MAX_VECTOR_SIZE));
468        let saturating = Box::new(SaturatingArithmeticProcessor::new(true));
469        let processing = scope.process([&*unroll, &*checked_io, &*saturating]);
470
471        for mut var in processing.variables {
472            if var.ty.vector_size() > MAX_VECTOR_SIZE {
473                var.ty = var.ty.with_vector_size(MAX_VECTOR_SIZE);
474            }
475            instructions.push(wgsl::Instruction::DeclareVariable {
476                var: self.compile_variable(var),
477            });
478        }
479
480        processing
481            .instructions
482            .into_iter()
483            .for_each(|op| self.compile_operation(&mut instructions, op.operation, op.out, scope));
484
485        instructions
486    }
487
488    fn compile_operation(
489        &mut self,
490        instructions: &mut Vec<wgsl::Instruction>,
491        operation: cube::Operation,
492        out: Option<cube::Variable>,
493        scope: &mut cube::Scope,
494    ) {
495        match operation {
496            cube::Operation::Copy(variable) => instructions.push(wgsl::Instruction::Assign {
497                input: self.compile_variable(variable),
498                out: self.compile_variable(out.unwrap()),
499            }),
500            cube::Operation::Arithmetic(op) => {
501                self.compile_arithmetic(op, out, instructions, scope)
502            }
503            cube::Operation::Comparison(op) => self.compile_cmp(op, out, instructions),
504            cube::Operation::Bitwise(op) => self.compile_bitwise(op, out, instructions),
505            cube::Operation::Operator(op) => self.compile_operator(op, out, instructions),
506            cube::Operation::Atomic(op) => instructions.push(self.compile_atomic(op, out)),
507            cube::Operation::Metadata(op) => instructions.push(self.compile_metadata(op, out)),
508            cube::Operation::Branch(val) => self.compile_branch(instructions, val),
509            cube::Operation::Synchronization(val) => {
510                self.compile_synchronization(instructions, val)
511            }
512            cube::Operation::Plane(op) => self.compile_subgroup(instructions, op, out),
513            cube::Operation::CoopMma(_) => {
514                panic!("Cooperative matrix-multiply and accumulate isn't supported on wgpu.")
515            }
516            cube::Operation::NonSemantic(cube::NonSemantic::Comment { content }) => {
517                self.compile_comment(instructions, content)
518            }
519            cube::Operation::NonSemantic(_) => {}
520            cube::Operation::Barrier(_) => {
521                panic!("Barrier isn't supported on wgpu.")
522            }
523            cube::Operation::Tma(_) => panic!("TMA isn't supported on wgpu."),
524            cube::Operation::Marker(_) => {}
525        }
526    }
527
528    fn compile_subgroup(
529        &mut self,
530        instructions: &mut Vec<wgsl::Instruction>,
531        subgroup: cube::Plane,
532        out: Option<cube::Variable>,
533    ) {
534        self.subgroup_instructions_used = true;
535
536        let out = out.unwrap();
537        let op = match subgroup {
538            cube::Plane::Elect => Subgroup::Elect {
539                out: self.compile_variable(out),
540            },
541            cube::Plane::All(op) => Subgroup::All {
542                input: self.compile_variable(op.input),
543                out: self.compile_variable(out),
544            },
545            cube::Plane::Any(op) => Subgroup::Any {
546                input: self.compile_variable(op.input),
547                out: self.compile_variable(out),
548            },
549            cube::Plane::Ballot(op) => Subgroup::Ballot {
550                input: self.compile_variable(op.input),
551                out: self.compile_variable(out),
552            },
553
554            cube::Plane::Broadcast(op) => Subgroup::Broadcast {
555                lhs: self.compile_variable(op.lhs),
556                rhs: self.compile_variable(op.rhs),
557                out: self.compile_variable(out),
558            },
559
560            cube::Plane::Sum(op) => Subgroup::Sum {
561                input: self.compile_variable(op.input),
562                out: self.compile_variable(out),
563            },
564
565            cube::Plane::ExclusiveSum(op) => Subgroup::ExclusiveSum {
566                input: self.compile_variable(op.input),
567                out: self.compile_variable(out),
568            },
569            cube::Plane::InclusiveSum(op) => Subgroup::InclusiveSum {
570                input: self.compile_variable(op.input),
571                out: self.compile_variable(out),
572            },
573            cube::Plane::Prod(op) => Subgroup::Prod {
574                input: self.compile_variable(op.input),
575                out: self.compile_variable(out),
576            },
577            cube::Plane::ExclusiveProd(op) => Subgroup::ExclusiveProd {
578                input: self.compile_variable(op.input),
579                out: self.compile_variable(out),
580            },
581            cube::Plane::InclusiveProd(op) => Subgroup::InclusiveProd {
582                input: self.compile_variable(op.input),
583                out: self.compile_variable(out),
584            },
585            cube::Plane::Min(op) => Subgroup::Min {
586                input: self.compile_variable(op.input),
587                out: self.compile_variable(out),
588            },
589            cube::Plane::Max(op) => Subgroup::Max {
590                input: self.compile_variable(op.input),
591                out: self.compile_variable(out),
592            },
593            cube::Plane::Shuffle(op) => Subgroup::Shuffle {
594                lhs: self.compile_variable(op.lhs),
595                rhs: self.compile_variable(op.rhs),
596                out: self.compile_variable(out),
597            },
598            cube::Plane::ShuffleXor(op) => Subgroup::ShuffleXor {
599                lhs: self.compile_variable(op.lhs),
600                rhs: self.compile_variable(op.rhs),
601                out: self.compile_variable(out),
602            },
603            cube::Plane::ShuffleUp(op) => Subgroup::ShuffleUp {
604                lhs: self.compile_variable(op.lhs),
605                rhs: self.compile_variable(op.rhs),
606                out: self.compile_variable(out),
607            },
608            cube::Plane::ShuffleDown(op) => Subgroup::ShuffleDown {
609                lhs: self.compile_variable(op.lhs),
610                rhs: self.compile_variable(op.rhs),
611                out: self.compile_variable(out),
612            },
613        };
614
615        instructions.push(wgsl::Instruction::Subgroup(op));
616    }
617
618    fn compile_branch(&mut self, instructions: &mut Vec<wgsl::Instruction>, branch: cube::Branch) {
619        match branch {
620            cube::Branch::If(mut op) => instructions.push(wgsl::Instruction::If {
621                cond: self.compile_variable(op.cond),
622                instructions: self.compile_scope(&mut op.scope),
623            }),
624            cube::Branch::IfElse(mut op) => instructions.push(wgsl::Instruction::IfElse {
625                cond: self.compile_variable(op.cond),
626                instructions_if: self.compile_scope(&mut op.scope_if),
627                instructions_else: self.compile_scope(&mut op.scope_else),
628            }),
629            cube::Branch::Switch(mut op) => instructions.push(wgsl::Instruction::Switch {
630                value: self.compile_variable(op.value),
631                instructions_default: self.compile_scope(&mut op.scope_default),
632                cases: op
633                    .cases
634                    .into_iter()
635                    .map(|(val, mut scope)| {
636                        (self.compile_variable(val), self.compile_scope(&mut scope))
637                    })
638                    .collect(),
639            }),
640            cube::Branch::Return => instructions.push(wgsl::Instruction::Return),
641            // No unreachable hint in WGSL
642            cube::Branch::Unreachable => instructions.push(wgsl::Instruction::Return),
643            cube::Branch::Break => instructions.push(wgsl::Instruction::Break),
644            cube::Branch::RangeLoop(mut range_loop) => {
645                instructions.push(wgsl::Instruction::RangeLoop {
646                    i: self.compile_variable(range_loop.i),
647                    start: self.compile_variable(range_loop.start),
648                    end: self.compile_variable(range_loop.end),
649                    step: range_loop.step.map(|it| self.compile_variable(it)),
650                    inclusive: range_loop.inclusive,
651                    instructions: self.compile_scope(&mut range_loop.scope),
652                })
653            }
654            cube::Branch::Loop(mut op) => instructions.push(wgsl::Instruction::Loop {
655                instructions: self.compile_scope(&mut op.scope),
656            }),
657        };
658    }
659
660    fn compile_synchronization(
661        &mut self,
662        instructions: &mut Vec<wgsl::Instruction>,
663        synchronization: cube::Synchronization,
664    ) {
665        match synchronization {
666            cube::Synchronization::SyncCube => {
667                instructions.push(wgsl::Instruction::WorkgroupBarrier)
668            }
669            cube::Synchronization::SyncPlane => {
670                panic!("Synchronization within a plane is not supported in WGSL")
671            }
672            cube::Synchronization::SyncStorage => {
673                instructions.push(wgsl::Instruction::StorageBarrier)
674            }
675            cube::Synchronization::SyncAsyncProxyShared => panic!("TMA is not supported in WGSL"),
676        };
677    }
678
679    fn compile_comment(&mut self, instructions: &mut Vec<wgsl::Instruction>, content: String) {
680        instructions.push(wgsl::Instruction::Comment { content })
681    }
682
683    fn compile_metadata(
684        &mut self,
685        metadata: cube::Metadata,
686        out: Option<cube::Variable>,
687    ) -> wgsl::Instruction {
688        let out = out.unwrap();
689        match metadata {
690            cube::Metadata::Rank { var } => {
691                let position = self.ext_meta_pos(&var);
692                let offset = self.info.metadata.rank_index(position);
693                wgsl::Instruction::Metadata {
694                    out: self.compile_variable(out),
695                    info_offset: self.compile_variable(offset.into()),
696                }
697            }
698            cube::Metadata::Stride { dim, var } => {
699                let position = self.ext_meta_pos(&var);
700                let offset = self.info.metadata.stride_offset_index(position);
701                wgsl::Instruction::ExtendedMeta {
702                    info_offset: self.compile_variable(offset.into()),
703                    dim: self.compile_variable(dim),
704                    out: self.compile_variable(out),
705                }
706            }
707            cube::Metadata::Shape { dim, var } => {
708                let position = self.ext_meta_pos(&var);
709                let offset = self.info.metadata.shape_offset_index(position);
710                wgsl::Instruction::ExtendedMeta {
711                    info_offset: self.compile_variable(offset.into()),
712                    dim: self.compile_variable(dim),
713                    out: self.compile_variable(out),
714                }
715            }
716            cube::Metadata::Length { var } => match var.kind {
717                cube::VariableKind::GlobalInputArray(id) => {
718                    let offset = self.info.metadata.len_index(id);
719                    wgsl::Instruction::Metadata {
720                        out: self.compile_variable(out),
721                        info_offset: self.compile_variable(offset.into()),
722                    }
723                }
724                cube::VariableKind::GlobalOutputArray(id) => {
725                    let offset = self.info.metadata.len_index(id);
726                    wgsl::Instruction::Metadata {
727                        out: self.compile_variable(out),
728                        info_offset: self.compile_variable(offset.into()),
729                    }
730                }
731                _ => wgsl::Instruction::Length {
732                    var: self.compile_variable(var),
733                    out: self.compile_variable(out),
734                },
735            },
736            cube::Metadata::BufferLength { var } => match var.kind {
737                cube::VariableKind::GlobalInputArray(id) => {
738                    let offset = self.info.metadata.buffer_len_index(id);
739                    wgsl::Instruction::Metadata {
740                        out: self.compile_variable(out),
741                        info_offset: self.compile_variable(offset.into()),
742                    }
743                }
744                cube::VariableKind::GlobalOutputArray(id) => {
745                    let offset = self.info.metadata.buffer_len_index(id);
746                    wgsl::Instruction::Metadata {
747                        out: self.compile_variable(out),
748                        info_offset: self.compile_variable(offset.into()),
749                    }
750                }
751                _ => wgsl::Instruction::Length {
752                    var: self.compile_variable(var),
753                    out: self.compile_variable(out),
754                },
755            },
756        }
757    }
758
759    fn compile_arithmetic(
760        &mut self,
761        value: cube::Arithmetic,
762        out: Option<cube::Variable>,
763        instructions: &mut Vec<wgsl::Instruction>,
764        scope: &mut Scope,
765    ) {
766        let out = out.unwrap();
767        match value {
768            cube::Arithmetic::Max(op) => instructions.push(wgsl::Instruction::Max {
769                lhs: self.compile_variable(op.lhs),
770                rhs: self.compile_variable(op.rhs),
771                out: self.compile_variable(out),
772            }),
773            cube::Arithmetic::Min(op) => instructions.push(wgsl::Instruction::Min {
774                lhs: self.compile_variable(op.lhs),
775                rhs: self.compile_variable(op.rhs),
776                out: self.compile_variable(out),
777            }),
778            cube::Arithmetic::Add(op) => instructions.push(wgsl::Instruction::Add {
779                lhs: self.compile_variable(op.lhs),
780                rhs: self.compile_variable(op.rhs),
781                out: self.compile_variable(out),
782            }),
783            cube::Arithmetic::SaturatingAdd(_) => {
784                unreachable!("Saturating add should be removed by processor");
785            }
786            cube::Arithmetic::Fma(op) => instructions.push(wgsl::Instruction::Fma {
787                a: self.compile_variable(op.a),
788                b: self.compile_variable(op.b),
789                c: self.compile_variable(op.c),
790                out: self.compile_variable(out),
791            }),
792            cube::Arithmetic::Modulo(op) => instructions.push(wgsl::Instruction::Modulo {
793                lhs: self.compile_variable(op.lhs),
794                rhs: self.compile_variable(op.rhs),
795                out: self.compile_variable(out),
796            }),
797            cube::Arithmetic::Sub(op) => instructions.push(wgsl::Instruction::Sub {
798                lhs: self.compile_variable(op.lhs),
799                rhs: self.compile_variable(op.rhs),
800                out: self.compile_variable(out),
801            }),
802            cube::Arithmetic::SaturatingSub(_) => {
803                unreachable!("Saturating sub should be removed by processor");
804            }
805            cube::Arithmetic::Mul(op) => instructions.push(wgsl::Instruction::Mul {
806                lhs: self.compile_variable(op.lhs),
807                rhs: self.compile_variable(op.rhs),
808                out: self.compile_variable(out),
809            }),
810            cube::Arithmetic::Div(op) => instructions.push(wgsl::Instruction::Div {
811                lhs: self.compile_variable(op.lhs),
812                rhs: self.compile_variable(op.rhs),
813                out: self.compile_variable(out),
814            }),
815            cube::Arithmetic::Abs(op) => instructions.push(wgsl::Instruction::Abs {
816                input: self.compile_variable(op.input),
817                out: self.compile_variable(out),
818            }),
819            cube::Arithmetic::Exp(op) => instructions.push(wgsl::Instruction::Exp {
820                input: self.compile_variable(op.input),
821                out: self.compile_variable(out),
822            }),
823            cube::Arithmetic::Log(op) => instructions.push(wgsl::Instruction::Log {
824                input: self.compile_variable(op.input),
825                out: self.compile_variable(out),
826            }),
827            cube::Arithmetic::Log1p(op) => instructions.push(wgsl::Instruction::Log1p {
828                input: self.compile_variable(op.input),
829                out: self.compile_variable(out),
830            }),
831            cube::Arithmetic::Cos(op) => instructions.push(wgsl::Instruction::Cos {
832                input: self.compile_variable(op.input),
833                out: self.compile_variable(out),
834            }),
835            cube::Arithmetic::Sin(op) => instructions.push(wgsl::Instruction::Sin {
836                input: self.compile_variable(op.input),
837                out: self.compile_variable(out),
838            }),
839            cube::Arithmetic::Tan(op) => instructions.push(wgsl::Instruction::Tan {
840                input: self.compile_variable(op.input),
841                out: self.compile_variable(out),
842            }),
843            cube::Arithmetic::Tanh(op) => instructions.push(wgsl::Instruction::Tanh {
844                input: self.compile_variable(op.input),
845                out: self.compile_variable(out),
846            }),
847            cube::Arithmetic::Sinh(op) => instructions.push(wgsl::Instruction::Sinh {
848                input: self.compile_variable(op.input),
849                out: self.compile_variable(out),
850            }),
851            cube::Arithmetic::Cosh(op) => instructions.push(wgsl::Instruction::Cosh {
852                input: self.compile_variable(op.input),
853                out: self.compile_variable(out),
854            }),
855            cube::Arithmetic::ArcCos(op) => instructions.push(wgsl::Instruction::ArcCos {
856                input: self.compile_variable(op.input),
857                out: self.compile_variable(out),
858            }),
859            cube::Arithmetic::ArcSin(op) => instructions.push(wgsl::Instruction::ArcSin {
860                input: self.compile_variable(op.input),
861                out: self.compile_variable(out),
862            }),
863            cube::Arithmetic::ArcTan(op) => instructions.push(wgsl::Instruction::ArcTan {
864                input: self.compile_variable(op.input),
865                out: self.compile_variable(out),
866            }),
867            cube::Arithmetic::ArcSinh(op) => instructions.push(wgsl::Instruction::ArcSinh {
868                input: self.compile_variable(op.input),
869                out: self.compile_variable(out),
870            }),
871            cube::Arithmetic::ArcCosh(op) => instructions.push(wgsl::Instruction::ArcCosh {
872                input: self.compile_variable(op.input),
873                out: self.compile_variable(out),
874            }),
875            cube::Arithmetic::ArcTanh(op) => instructions.push(wgsl::Instruction::ArcTanh {
876                input: self.compile_variable(op.input),
877                out: self.compile_variable(out),
878            }),
879            cube::Arithmetic::Degrees(op) => instructions.push(wgsl::Instruction::Degrees {
880                input: self.compile_variable(op.input),
881                out: self.compile_variable(out),
882            }),
883            cube::Arithmetic::Radians(op) => instructions.push(wgsl::Instruction::Radians {
884                input: self.compile_variable(op.input),
885                out: self.compile_variable(out),
886            }),
887            cube::Arithmetic::ArcTan2(op) => instructions.push(wgsl::Instruction::ArcTan2 {
888                lhs: self.compile_variable(op.lhs),
889                rhs: self.compile_variable(op.rhs),
890                out: self.compile_variable(out),
891            }),
892            // No powi in WGSL
893            cube::Arithmetic::Powf(op) | cube::Arithmetic::Powi(op) => {
894                instructions.push(wgsl::Instruction::Powf {
895                    lhs: self.compile_variable(op.lhs),
896                    rhs: self.compile_variable(op.rhs),
897                    out: self.compile_variable(out),
898                })
899            }
900            cube::Arithmetic::Hypot(op) => {
901                let mut scope = scope.child();
902                expand_hypot(&mut scope, op.lhs, op.rhs, out);
903                instructions.extend(self.compile_scope(&mut scope));
904            }
905            cube::Arithmetic::Rhypot(op) => {
906                let mut scope = scope.child();
907                expand_rhypot(&mut scope, op.lhs, op.rhs, out);
908                instructions.extend(self.compile_scope(&mut scope));
909            }
910
911            cube::Arithmetic::Sqrt(op) => instructions.push(wgsl::Instruction::Sqrt {
912                input: self.compile_variable(op.input),
913                out: self.compile_variable(out),
914            }),
915            cube::Arithmetic::InverseSqrt(op) => {
916                instructions.push(wgsl::Instruction::InverseSqrt {
917                    input: self.compile_variable(op.input),
918                    out: self.compile_variable(out),
919                })
920            }
921            cube::Arithmetic::Round(op) => instructions.push(wgsl::Instruction::Round {
922                input: self.compile_variable(op.input),
923                out: self.compile_variable(out),
924            }),
925            cube::Arithmetic::Floor(op) => instructions.push(wgsl::Instruction::Floor {
926                input: self.compile_variable(op.input),
927                out: self.compile_variable(out),
928            }),
929            cube::Arithmetic::Ceil(op) => instructions.push(wgsl::Instruction::Ceil {
930                input: self.compile_variable(op.input),
931                out: self.compile_variable(out),
932            }),
933            cube::Arithmetic::Trunc(op) => instructions.push(wgsl::Instruction::Trunc {
934                input: self.compile_variable(op.input),
935                out: self.compile_variable(out),
936            }),
937            cube::Arithmetic::Erf(op) => {
938                let mut scope = scope.child();
939                expand_erf(&mut scope, op.input, out);
940                instructions.extend(self.compile_scope(&mut scope));
941            }
942            cube::Arithmetic::MulHi(op) => {
943                let mut scope = scope.child();
944                match self.compilation_options.supports_u64 {
945                    true => expand_himul_64(&mut scope, op.lhs, op.rhs, out),
946                    false => expand_himul_sim(&mut scope, op.lhs, op.rhs, out),
947                }
948                instructions.extend(self.compile_scope(&mut scope));
949            }
950            cube::Arithmetic::Recip(op) => instructions.push(wgsl::Instruction::Recip {
951                input: self.compile_variable(op.input),
952                out: self.compile_variable(out),
953            }),
954            cube::Arithmetic::Clamp(op) => instructions.push(wgsl::Instruction::Clamp {
955                input: self.compile_variable(op.input),
956                min_value: self.compile_variable(op.min_value),
957                max_value: self.compile_variable(op.max_value),
958                out: self.compile_variable(out),
959            }),
960            cube::Arithmetic::Remainder(op) => instructions.push(wgsl::Instruction::Remainder {
961                lhs: self.compile_variable(op.lhs),
962                rhs: self.compile_variable(op.rhs),
963                out: self.compile_variable(out),
964            }),
965            cube::Arithmetic::Neg(op) => instructions.push(wgsl::Instruction::Negate {
966                input: self.compile_variable(op.input),
967                out: self.compile_variable(out),
968            }),
969            cube::Arithmetic::Magnitude(op) => instructions.push(wgsl::Instruction::Magnitude {
970                input: self.compile_variable(op.input),
971                out: self.compile_variable(out),
972            }),
973            cube::Arithmetic::Normalize(op) => instructions.push(wgsl::Instruction::Normalize {
974                input: self.compile_variable(op.input),
975                out: self.compile_variable(out),
976            }),
977            cube::Arithmetic::Dot(op) => instructions.push(wgsl::Instruction::Dot {
978                lhs: self.compile_variable(op.lhs),
979                rhs: self.compile_variable(op.rhs),
980                out: self.compile_variable(out),
981            }),
982        }
983    }
984
985    fn compile_cmp(
986        &mut self,
987        value: cube::Comparison,
988        out: Option<cube::Variable>,
989        instructions: &mut Vec<wgsl::Instruction>,
990    ) {
991        let out = out.unwrap();
992        match value {
993            cube::Comparison::Equal(op) => instructions.push(wgsl::Instruction::Equal {
994                lhs: self.compile_variable(op.lhs),
995                rhs: self.compile_variable(op.rhs),
996                out: self.compile_variable(out),
997            }),
998            cube::Comparison::Lower(op) => instructions.push(wgsl::Instruction::Lower {
999                lhs: self.compile_variable(op.lhs),
1000                rhs: self.compile_variable(op.rhs),
1001                out: self.compile_variable(out),
1002            }),
1003            cube::Comparison::Greater(op) => instructions.push(wgsl::Instruction::Greater {
1004                lhs: self.compile_variable(op.lhs),
1005                rhs: self.compile_variable(op.rhs),
1006                out: self.compile_variable(out),
1007            }),
1008            cube::Comparison::LowerEqual(op) => instructions.push(wgsl::Instruction::LowerEqual {
1009                lhs: self.compile_variable(op.lhs),
1010                rhs: self.compile_variable(op.rhs),
1011                out: self.compile_variable(out),
1012            }),
1013            cube::Comparison::GreaterEqual(op) => {
1014                instructions.push(wgsl::Instruction::GreaterEqual {
1015                    lhs: self.compile_variable(op.lhs),
1016                    rhs: self.compile_variable(op.rhs),
1017                    out: self.compile_variable(out),
1018                })
1019            }
1020            cube::Comparison::NotEqual(op) => instructions.push(wgsl::Instruction::NotEqual {
1021                lhs: self.compile_variable(op.lhs),
1022                rhs: self.compile_variable(op.rhs),
1023                out: self.compile_variable(out),
1024            }),
1025            cube::Comparison::IsNan(op) => instructions.push(wgsl::Instruction::IsNan {
1026                input: self.compile_variable(op.input),
1027                out: self.compile_variable(out),
1028            }),
1029            cube::Comparison::IsInf(op) => instructions.push(wgsl::Instruction::IsInf {
1030                input: self.compile_variable(op.input),
1031                out: self.compile_variable(out),
1032            }),
1033        }
1034    }
1035
1036    fn compile_bitwise(
1037        &mut self,
1038        value: cube::Bitwise,
1039        out: Option<cube::Variable>,
1040        instructions: &mut Vec<wgsl::Instruction>,
1041    ) {
1042        let out = out.unwrap();
1043        match value {
1044            cube::Bitwise::BitwiseOr(op) => instructions.push(wgsl::Instruction::BitwiseOr {
1045                lhs: self.compile_variable(op.lhs),
1046                rhs: self.compile_variable(op.rhs),
1047                out: self.compile_variable(out),
1048            }),
1049            cube::Bitwise::BitwiseAnd(op) => instructions.push(wgsl::Instruction::BitwiseAnd {
1050                lhs: self.compile_variable(op.lhs),
1051                rhs: self.compile_variable(op.rhs),
1052                out: self.compile_variable(out),
1053            }),
1054            cube::Bitwise::BitwiseXor(op) => instructions.push(wgsl::Instruction::BitwiseXor {
1055                lhs: self.compile_variable(op.lhs),
1056                rhs: self.compile_variable(op.rhs),
1057                out: self.compile_variable(out),
1058            }),
1059            cube::Bitwise::CountOnes(op) => instructions.push(wgsl::Instruction::CountBits {
1060                input: self.compile_variable(op.input),
1061                out: self.compile_variable(out),
1062            }),
1063            cube::Bitwise::ReverseBits(op) => instructions.push(wgsl::Instruction::ReverseBits {
1064                input: self.compile_variable(op.input),
1065                out: self.compile_variable(out),
1066            }),
1067            cube::Bitwise::ShiftLeft(op) => instructions.push(wgsl::Instruction::ShiftLeft {
1068                lhs: self.compile_variable(op.lhs),
1069                rhs: self.compile_variable(op.rhs),
1070                out: self.compile_variable(out),
1071            }),
1072            cube::Bitwise::ShiftRight(op) => instructions.push(wgsl::Instruction::ShiftRight {
1073                lhs: self.compile_variable(op.lhs),
1074                rhs: self.compile_variable(op.rhs),
1075                out: self.compile_variable(out),
1076            }),
1077            cube::Bitwise::BitwiseNot(op) => instructions.push(wgsl::Instruction::BitwiseNot {
1078                input: self.compile_variable(op.input),
1079                out: self.compile_variable(out),
1080            }),
1081            cube::Bitwise::LeadingZeros(op) => instructions.push(wgsl::Instruction::LeadingZeros {
1082                input: self.compile_variable(op.input),
1083                out: self.compile_variable(out),
1084            }),
1085            cube::Bitwise::TrailingZeros(op) => {
1086                instructions.push(wgsl::Instruction::TrailingZeros {
1087                    input: self.compile_variable(op.input),
1088                    out: self.compile_variable(out),
1089                })
1090            }
1091            cube::Bitwise::FindFirstSet(op) => instructions.push(wgsl::Instruction::FindFirstSet {
1092                input: self.compile_variable(op.input),
1093                out: self.compile_variable(out),
1094            }),
1095        }
1096    }
1097
1098    fn compile_operator(
1099        &mut self,
1100        value: cube::Operator,
1101        out: Option<cube::Variable>,
1102        instructions: &mut Vec<wgsl::Instruction>,
1103    ) {
1104        let out = out.unwrap();
1105        match value {
1106            cube::Operator::Cast(op) => instructions.push(wgsl::Instruction::Assign {
1107                input: self.compile_variable(op.input),
1108                out: self.compile_variable(out),
1109            }),
1110            cube::Operator::Index(op) | cube::Operator::UncheckedIndex(op) => {
1111                instructions.push(wgsl::Instruction::Index {
1112                    lhs: self.compile_variable(op.list),
1113                    rhs: self.compile_variable(op.index),
1114                    out: self.compile_variable(out),
1115                });
1116            }
1117            cube::Operator::IndexAssign(op) | cube::Operator::UncheckedIndexAssign(op) => {
1118                instructions.push(wgsl::Instruction::IndexAssign {
1119                    index: self.compile_variable(op.index),
1120                    rhs: self.compile_variable(op.value),
1121                    out: self.compile_variable(out),
1122                })
1123            }
1124            cube::Operator::And(op) => instructions.push(wgsl::Instruction::And {
1125                lhs: self.compile_variable(op.lhs),
1126                rhs: self.compile_variable(op.rhs),
1127                out: self.compile_variable(out),
1128            }),
1129            cube::Operator::Or(op) => instructions.push(wgsl::Instruction::Or {
1130                lhs: self.compile_variable(op.lhs),
1131                rhs: self.compile_variable(op.rhs),
1132                out: self.compile_variable(out),
1133            }),
1134            cube::Operator::Not(op) => instructions.push(wgsl::Instruction::Not {
1135                input: self.compile_variable(op.input),
1136                out: self.compile_variable(out),
1137            }),
1138            cube::Operator::Reinterpret(op) => instructions.push(wgsl::Instruction::Bitcast {
1139                input: self.compile_variable(op.input),
1140                out: self.compile_variable(out),
1141            }),
1142            cube::Operator::InitVector(op) => instructions.push(wgsl::Instruction::VecInit {
1143                inputs: op
1144                    .inputs
1145                    .into_iter()
1146                    .map(|var| self.compile_variable(var))
1147                    .collect(),
1148                out: self.compile_variable(out),
1149            }),
1150            cube::Operator::CopyMemory(op) => instructions.push(wgsl::Instruction::Copy {
1151                input: self.compile_variable(op.input),
1152                in_index: self.compile_variable(op.in_index),
1153                out: self.compile_variable(out),
1154                out_index: self.compile_variable(op.out_index),
1155            }),
1156            cube::Operator::CopyMemoryBulk(op) => instructions.push(wgsl::Instruction::CopyBulk {
1157                input: self.compile_variable(op.input),
1158                in_index: self.compile_variable(op.in_index),
1159                out: self.compile_variable(out),
1160                out_index: self.compile_variable(op.out_index),
1161                len: op.len as u32,
1162            }),
1163            cube::Operator::Select(op) => instructions.push(wgsl::Instruction::Select {
1164                cond: self.compile_variable(op.cond),
1165                then: self.compile_variable(op.then),
1166                or_else: self.compile_variable(op.or_else),
1167                out: self.compile_variable(out),
1168            }),
1169        }
1170    }
1171
1172    fn compile_atomic(
1173        &mut self,
1174        atomic: cube::AtomicOp,
1175        out: Option<cube::Variable>,
1176    ) -> wgsl::Instruction {
1177        let out = out.unwrap();
1178        match atomic {
1179            cube::AtomicOp::Add(op) => wgsl::Instruction::AtomicAdd {
1180                lhs: self.compile_variable(op.lhs),
1181                rhs: self.compile_variable(op.rhs),
1182                out: self.compile_variable(out),
1183            },
1184            cube::AtomicOp::Sub(op) => wgsl::Instruction::AtomicSub {
1185                lhs: self.compile_variable(op.lhs),
1186                rhs: self.compile_variable(op.rhs),
1187                out: self.compile_variable(out),
1188            },
1189            cube::AtomicOp::Max(op) => wgsl::Instruction::AtomicMax {
1190                lhs: self.compile_variable(op.lhs),
1191                rhs: self.compile_variable(op.rhs),
1192                out: self.compile_variable(out),
1193            },
1194            cube::AtomicOp::Min(op) => wgsl::Instruction::AtomicMin {
1195                lhs: self.compile_variable(op.lhs),
1196                rhs: self.compile_variable(op.rhs),
1197                out: self.compile_variable(out),
1198            },
1199            cube::AtomicOp::And(op) => wgsl::Instruction::AtomicAnd {
1200                lhs: self.compile_variable(op.lhs),
1201                rhs: self.compile_variable(op.rhs),
1202                out: self.compile_variable(out),
1203            },
1204            cube::AtomicOp::Or(op) => wgsl::Instruction::AtomicOr {
1205                lhs: self.compile_variable(op.lhs),
1206                rhs: self.compile_variable(op.rhs),
1207                out: self.compile_variable(out),
1208            },
1209            cube::AtomicOp::Xor(op) => wgsl::Instruction::AtomicXor {
1210                lhs: self.compile_variable(op.lhs),
1211                rhs: self.compile_variable(op.rhs),
1212                out: self.compile_variable(out),
1213            },
1214            cube::AtomicOp::Load(op) => wgsl::Instruction::AtomicLoad {
1215                input: self.compile_variable(op.input),
1216                out: self.compile_variable(out),
1217            },
1218            cube::AtomicOp::Store(op) => wgsl::Instruction::AtomicStore {
1219                input: self.compile_variable(op.input),
1220                out: self.compile_variable(out),
1221            },
1222            cube::AtomicOp::Swap(op) => wgsl::Instruction::AtomicSwap {
1223                lhs: self.compile_variable(op.lhs),
1224                rhs: self.compile_variable(op.rhs),
1225                out: self.compile_variable(out),
1226            },
1227            cube::AtomicOp::CompareAndSwap(op) => wgsl::Instruction::AtomicCompareExchangeWeak {
1228                lhs: self.compile_variable(op.input),
1229                cmp: self.compile_variable(op.cmp),
1230                value: self.compile_variable(op.val),
1231                out: self.compile_variable(out),
1232            },
1233        }
1234    }
1235
1236    fn compile_binding(&mut self, value: kernel::KernelArg) -> wgsl::KernelArg {
1237        wgsl::KernelArg {
1238            id: value.id,
1239            visibility: value.visibility,
1240            location: wgsl::Location::Storage,
1241            item: self.compile_type(value.ty),
1242            size: value.size,
1243        }
1244    }
1245}
1246
1247fn register_extensions(instructions: &[wgsl::Instruction]) -> Vec<wgsl::Extension> {
1248    let mut extensions = Vec::new();
1249
1250    let mut register_extension = |extension: wgsl::Extension| {
1251        if !extensions.contains(&extension) {
1252            extensions.push(extension);
1253        }
1254    };
1255
1256    // Since not all instructions are native to WGSL, we need to add the custom ones.
1257    for instruction in instructions {
1258        match instruction {
1259            wgsl::Instruction::Powf { lhs: _, rhs, out } => {
1260                register_extension(wgsl::Extension::PowfPrimitive(out.elem()));
1261                register_extension(wgsl::powf_extension(rhs, out));
1262            }
1263            #[cfg(target_os = "macos")]
1264            wgsl::Instruction::Tanh { input, out: _ } => {
1265                register_extension(wgsl::Extension::SafeTanhPrimitive(input.elem()));
1266                register_extension(wgsl::Extension::SafeTanh(input.item()));
1267            }
1268            wgsl::Instruction::IsNan { input, out } => {
1269                register_extension(wgsl::Extension::IsNanPrimitive(input.elem()));
1270                register_extension(wgsl::Extension::IsNan(input.item(), out.item()));
1271            }
1272            wgsl::Instruction::IsInf { input, out } => {
1273                register_extension(wgsl::Extension::IsInfPrimitive(input.elem()));
1274                register_extension(wgsl::Extension::IsInf(input.item(), out.item()));
1275            }
1276            wgsl::Instruction::If { instructions, .. } => {
1277                for extension in register_extensions(instructions) {
1278                    register_extension(extension);
1279                }
1280            }
1281            wgsl::Instruction::IfElse {
1282                instructions_if,
1283                instructions_else,
1284                ..
1285            } => {
1286                for extension in register_extensions(instructions_if) {
1287                    register_extension(extension);
1288                }
1289                for extension in register_extensions(instructions_else) {
1290                    register_extension(extension);
1291                }
1292            }
1293            wgsl::Instruction::Loop { instructions } => {
1294                for extension in register_extensions(instructions) {
1295                    register_extension(extension);
1296                }
1297            }
1298            wgsl::Instruction::RangeLoop { instructions, .. } => {
1299                for extension in register_extensions(instructions) {
1300                    register_extension(extension);
1301                }
1302            }
1303            _ => {}
1304        }
1305    }
1306
1307    extensions
1308}