Skip to main content

cubecl_wgpu/compiler/wgsl/
compiler.rs

1use super::Subgroup;
2use super::{ConstantArray, shader::ComputeShader};
3use super::{Item, LocalArray, SharedArray};
4use crate::compiler::wgsl::{self, SharedValue};
5
6use cubecl_common::backtrace::BackTrace;
7use cubecl_core::post_processing::{
8    checked_io::CheckedIoProcessor, saturating::SaturatingArithmeticProcessor,
9};
10use cubecl_core::prelude::*;
11use cubecl_core::{
12    Metadata, WgpuCompilationOptions,
13    ir::{self as cube, Scope},
14    prelude::expand_erf,
15};
16use cubecl_core::{
17    ir::{Processor, UIntKind},
18    post_processing::unroll::UnrollProcessor,
19};
20use cubecl_runtime::compiler::CompilationError;
21use cubecl_runtime::kernel;
22
/// Widest line (vector) width emitted by this compiler.
///
/// `compile_type` only maps line sizes 2, 3 and 4 to WGSL vector items. Bindings and declared
/// variables with a wider line are clamped to this value, and the same value configures the
/// unroll processor used in `compile_scope`.
pub const MAX_LINE_SIZE: usize = 4;
24
/// Wgsl Compiler.
///
/// Carries the state gathered while lowering a kernel definition into a [`ComputeShader`]:
/// the metadata layout, which builtins the kernel body referenced, and the shared, constant
/// and local arrays it declared.
#[derive(Clone, Default)]
pub struct WgslCompiler {
    /// Metadata layout, rebuilt by `compile_shader` from the number of buffers and the number
    /// of buffers that carry extended metadata.
    metadata: Metadata,
    /// For each buffer (by position), how many preceding buffers have extended metadata.
    ext_meta_pos: Vec<u32>,
    /// Set when `Builtin::UnitPos` is compiled.
    local_invocation_index: bool,
    /// Set when `Builtin::UnitPosX/Y/Z` is compiled.
    local_invocation_id: bool,
    /// Set when `Builtin::AbsolutePosX/Y/Z` is compiled; read by `compile_shader`.
    // NOTE(review): an earlier TODO here said this flag seemed unused, but it is written in
    // `compile_variable` and read in `compile_shader`.
    global_invocation_id: bool,
    /// Set when `Builtin::CubePosX/Y/Z` is compiled.
    workgroup_id: bool,
    /// Set when `Builtin::PlaneDim` is compiled.
    subgroup_size: bool,
    /// Set when `Builtin::PlanePos` is compiled.
    subgroup_id: bool,
    /// Set when `Builtin::UnitPosPlane` is compiled.
    subgroup_invocation_id: bool,
    /// Set when `Builtin::AbsolutePos` is compiled; also forces the global invocation id and
    /// workgroup count inputs in `compile_shader`.
    id: bool,
    /// Set when `Builtin::CubeCountX/Y/Z` is compiled.
    num_workgroups: bool,
    /// Set when `Builtin::CubePos` is compiled.
    workgroup_id_no_axis: bool,
    /// Set when `Builtin::CubeDim` is compiled.
    workgroup_size_no_axis: bool,
    /// Set when `Builtin::CubeCount` is compiled.
    num_workgroup_no_axis: bool,
    /// Shared arrays seen so far, one entry per distinct id.
    shared_arrays: Vec<SharedArray>,
    /// Shared values seen so far, one entry per distinct id.
    shared_values: Vec<SharedValue>,
    /// Constant arrays drained from every compiled scope.
    const_arrays: Vec<ConstantArray>,
    /// Local arrays seen so far, one entry per distinct id.
    local_arrays: Vec<LocalArray>,
    /// Copy of the options passed to the most recent `compile` call.
    #[allow(dead_code)]
    compilation_options: WgpuCompilationOptions,
    /// Execution mode of the kernel being compiled; handed to the checked-IO processor.
    strategy: ExecutionMode,
    /// Set as soon as any plane (subgroup) operation is compiled.
    subgroup_instructions_used: bool,
    /// Set as soon as an `f16` element is compiled.
    f16_used: bool,
}
53
54impl core::fmt::Debug for WgslCompiler {
55    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56        f.write_str("WgslCompiler")
57    }
58}
59
60impl cubecl_core::Compiler for WgslCompiler {
61    type Representation = ComputeShader;
62    type CompilationOptions = WgpuCompilationOptions;
63
64    fn compile(
65        &mut self,
66        shader: kernel::KernelDefinition,
67        compilation_options: &Self::CompilationOptions,
68        mode: ExecutionMode,
69        address_type: StorageType,
70    ) -> Result<Self::Representation, CompilationError> {
71        self.compilation_options = compilation_options.clone();
72        self.compile_shader(shader, mode, address_type)
73    }
74
75    fn elem_size(&self, elem: cube::ElemType) -> usize {
76        elem.size()
77    }
78
79    fn extension(&self) -> &'static str {
80        "wgsl"
81    }
82}
83
84impl WgslCompiler {
85    fn compile_shader(
86        &mut self,
87        mut value: kernel::KernelDefinition,
88        mode: ExecutionMode,
89        address_type: StorageType,
90    ) -> Result<wgsl::ComputeShader, CompilationError> {
91        let errors = value.body.pop_errors();
92        if !errors.is_empty() {
93            let mut reason = "Can't compile wgsl kernel".to_string();
94            for error in errors {
95                reason += error.as_str();
96                reason += "\n";
97            }
98
99            return Err(CompilationError::Validation {
100                reason,
101                backtrace: BackTrace::capture(),
102            });
103        }
104
105        self.strategy = mode;
106
107        let num_meta = value.buffers.len();
108
109        self.ext_meta_pos = Vec::new();
110        let mut num_ext = 0;
111
112        for binding in value.buffers.iter() {
113            self.ext_meta_pos.push(num_ext);
114            if binding.has_extended_meta {
115                num_ext += 1;
116            }
117        }
118
119        self.metadata = Metadata::new(num_meta as u32, num_ext);
120
121        let address_type = self.compile_storage_type(address_type);
122        let instructions = self.compile_scope(&mut value.body);
123        let extensions = register_extensions(&instructions);
124        let body = wgsl::Body {
125            instructions,
126            id: self.id,
127            address_type,
128        };
129
130        Ok(wgsl::ComputeShader {
131            address_type,
132            buffers: value
133                .buffers
134                .into_iter()
135                .map(|mut it| {
136                    // This is safe when combined with the unroll transform that adjusts all indices.
137                    // Must not be used alone
138                    if it.ty.line_size() > MAX_LINE_SIZE {
139                        it.ty = it.ty.line(MAX_LINE_SIZE);
140                    }
141                    self.compile_binding(it)
142                })
143                .collect(),
144            scalars: value
145                .scalars
146                .into_iter()
147                .map(|binding| (self.compile_storage_type(binding.ty), binding.count))
148                .collect(),
149            shared_arrays: self.shared_arrays.clone(),
150            shared_values: self.shared_values.clone(),
151            constant_arrays: self.const_arrays.clone(),
152            local_arrays: self.local_arrays.clone(),
153            has_metadata: self.metadata.static_len() > 0,
154            workgroup_size: value.cube_dim,
155            global_invocation_id: self.global_invocation_id || self.id,
156            local_invocation_index: self.local_invocation_index,
157            local_invocation_id: self.local_invocation_id,
158            num_workgroups: self.id
159                || self.num_workgroups
160                || self.num_workgroup_no_axis
161                || self.workgroup_id_no_axis,
162            workgroup_id: self.workgroup_id || self.workgroup_id_no_axis,
163            subgroup_size: self.subgroup_size,
164            subgroup_id: self.subgroup_id,
165            subgroup_invocation_id: self.subgroup_invocation_id,
166            body,
167            extensions,
168            num_workgroups_no_axis: self.num_workgroup_no_axis,
169            workgroup_id_no_axis: self.workgroup_id_no_axis,
170            workgroup_size_no_axis: self.workgroup_size_no_axis,
171            subgroup_instructions_used: self.subgroup_instructions_used,
172            f16_used: self.f16_used,
173            kernel_name: value.options.kernel_name,
174        })
175    }
176
177    fn compile_type(&mut self, item: cube::Type) -> Item {
178        match item {
179            cube::Type::Scalar(ty) => wgsl::Item::Scalar(self.compile_storage_type(ty)),
180            cube::Type::Line(ty, size) => {
181                let elem = self.compile_storage_type(ty);
182                match size {
183                    2 => wgsl::Item::Vec2(elem),
184                    3 => wgsl::Item::Vec3(elem),
185                    4 => wgsl::Item::Vec4(elem),
186                    _ => panic!("Unsupported vectorizations scheme {:?}", item.line_size()),
187                }
188            }
189            cube::Type::Semantic(_) => unimplemented!("Can't compile semantic type"),
190        }
191    }
192
193    fn compile_storage_type(&mut self, ty: cube::StorageType) -> wgsl::Elem {
194        match ty {
195            cube::StorageType::Scalar(ty) => self.compile_elem(ty),
196            cube::StorageType::Atomic(ty) => match ty {
197                cube::ElemType::Float(i) => match i {
198                    cube::FloatKind::F32 => wgsl::Elem::AtomicF32,
199                    kind => panic!("atomic<{kind:?}> is not a valid WgpuElement"),
200                },
201                cube::ElemType::Int(i) => match i {
202                    cube::IntKind::I32 => wgsl::Elem::AtomicI32,
203                    kind => panic!("atomic<{kind:?}> is not a valid WgpuElement"),
204                },
205                cube::ElemType::UInt(kind) => match kind {
206                    cube::UIntKind::U32 => wgsl::Elem::AtomicU32,
207                    kind => panic!("{kind:?} is not a valid WgpuElement"),
208                },
209                other => panic!("{other:?} is not a valid WgpuElement"),
210            },
211            cube::StorageType::Packed(_, _) => {
212                unimplemented!("Packed types not yet supported in WGSL")
213            }
214            cube::StorageType::Opaque(ty) => match ty {
215                cube::OpaqueType::Barrier(_) => {
216                    unimplemented!("Barrier objects not supported in WGSL")
217                }
218            },
219        }
220    }
221
222    fn compile_elem(&mut self, value: cube::ElemType) -> wgsl::Elem {
223        match value {
224            cube::ElemType::Float(f) => match f {
225                cube::FloatKind::E2M1
226                | cube::FloatKind::E2M3
227                | cube::FloatKind::E3M2
228                | cube::FloatKind::E4M3
229                | cube::FloatKind::E5M2
230                | cube::FloatKind::UE8M0 => panic!("Minifloat is not a valid WgpuElement"),
231                cube::FloatKind::F16 => {
232                    self.f16_used = true;
233                    wgsl::Elem::F16
234                }
235                cube::FloatKind::BF16 => panic!("bf16 is not a valid WgpuElement"),
236                cube::FloatKind::TF32 => panic!("tf32 is not a valid WgpuElement"),
237                cube::FloatKind::Flex32 => wgsl::Elem::F32,
238                cube::FloatKind::F32 => wgsl::Elem::F32,
239                cube::FloatKind::F64 => wgsl::Elem::F64,
240            },
241            cube::ElemType::Int(i) => match i {
242                cube::IntKind::I32 => wgsl::Elem::I32,
243                cube::IntKind::I64 => wgsl::Elem::I64,
244                kind => panic!("{kind:?} is not a valid WgpuElement"),
245            },
246            cube::ElemType::UInt(kind) => match kind {
247                cube::UIntKind::U32 => wgsl::Elem::U32,
248                cube::UIntKind::U64 => wgsl::Elem::U64,
249                kind => panic!("{kind:?} is not a valid WgpuElement"),
250            },
251            cube::ElemType::Bool => wgsl::Elem::Bool,
252        }
253    }
254
255    fn ext_meta_pos(&self, var: &cube::Variable) -> u32 {
256        let pos = var.index().expect("Variable should have index");
257        self.ext_meta_pos[pos as usize]
258    }
259
260    pub(crate) fn compile_variable(&mut self, value: cube::Variable) -> wgsl::Variable {
261        let item = value.ty;
262        match value.kind {
263            cube::VariableKind::GlobalInputArray(id) => {
264                wgsl::Variable::GlobalInputArray(id, self.compile_type(item))
265            }
266            cube::VariableKind::GlobalScalar(id) => {
267                wgsl::Variable::GlobalScalar(id, self.compile_storage_type(item.storage_type()))
268            }
269            cube::VariableKind::LocalMut { id } | cube::VariableKind::Versioned { id, .. } => {
270                wgsl::Variable::LocalMut {
271                    id,
272                    item: self.compile_type(item),
273                }
274            }
275            cube::VariableKind::LocalConst { id } => wgsl::Variable::LocalConst {
276                id,
277                item: self.compile_type(item),
278            },
279            cube::VariableKind::GlobalOutputArray(id) => {
280                wgsl::Variable::GlobalOutputArray(id, self.compile_type(item))
281            }
282            cube::VariableKind::Constant(value) => {
283                wgsl::Variable::Constant(value, self.compile_type(item))
284            }
285            cube::VariableKind::SharedArray {
286                id,
287                length,
288                unroll_factor,
289                alignment,
290            } => {
291                let item = self.compile_type(item);
292                if !self.shared_arrays.iter().any(|s| s.index == id) {
293                    self.shared_arrays.push(SharedArray::new(
294                        id,
295                        item,
296                        (length * unroll_factor) as u32,
297                        alignment.map(|it| it as u32),
298                    ));
299                }
300                wgsl::Variable::SharedArray(id, item, length as u32)
301            }
302            cube::VariableKind::Shared { id } => {
303                let item = self.compile_type(item);
304                if !self.shared_values.iter().any(|s| s.index == id) {
305                    self.shared_values.push(SharedValue::new(id, item));
306                }
307                wgsl::Variable::SharedValue(id, item)
308            }
309            cube::VariableKind::ConstantArray { id, length, .. } => {
310                let item = self.compile_type(item);
311                wgsl::Variable::ConstantArray(id, item, length as u32)
312            }
313            cube::VariableKind::LocalArray {
314                id,
315                length,
316                unroll_factor,
317            } => {
318                let item = self.compile_type(item);
319                if !self.local_arrays.iter().any(|s| s.index == id) {
320                    self.local_arrays.push(LocalArray::new(
321                        id,
322                        item,
323                        (length * unroll_factor) as u32,
324                    ));
325                }
326                wgsl::Variable::LocalArray(id, item, length as u32)
327            }
328            cube::VariableKind::Builtin(builtin) => match builtin {
329                cube::Builtin::AbsolutePos => {
330                    self.id = true;
331                    wgsl::Variable::Id
332                }
333                cube::Builtin::UnitPos => {
334                    self.local_invocation_index = true;
335                    wgsl::Variable::LocalInvocationIndex
336                }
337                cube::Builtin::UnitPosX => {
338                    self.local_invocation_id = true;
339                    wgsl::Variable::LocalInvocationIdX
340                }
341                cube::Builtin::UnitPosY => {
342                    self.local_invocation_id = true;
343                    wgsl::Variable::LocalInvocationIdY
344                }
345                cube::Builtin::UnitPosZ => {
346                    self.local_invocation_id = true;
347                    wgsl::Variable::LocalInvocationIdZ
348                }
349                cube::Builtin::CubePosX => {
350                    self.workgroup_id = true;
351                    wgsl::Variable::WorkgroupIdX
352                }
353                cube::Builtin::CubePosY => {
354                    self.workgroup_id = true;
355                    wgsl::Variable::WorkgroupIdY
356                }
357                cube::Builtin::CubePosZ => {
358                    self.workgroup_id = true;
359                    wgsl::Variable::WorkgroupIdZ
360                }
361                cube::Builtin::CubePosCluster
362                | cube::Builtin::CubePosClusterX
363                | cube::Builtin::CubePosClusterY
364                | cube::Builtin::CubePosClusterZ => self.constant_var(1),
365                cube::Builtin::AbsolutePosX => {
366                    self.global_invocation_id = true;
367                    wgsl::Variable::GlobalInvocationIdX
368                }
369                cube::Builtin::AbsolutePosY => {
370                    self.global_invocation_id = true;
371                    wgsl::Variable::GlobalInvocationIdY
372                }
373                cube::Builtin::AbsolutePosZ => {
374                    self.global_invocation_id = true;
375                    wgsl::Variable::GlobalInvocationIdZ
376                }
377                cube::Builtin::CubeDimX => wgsl::Variable::WorkgroupSizeX,
378                cube::Builtin::CubeDimY => wgsl::Variable::WorkgroupSizeY,
379                cube::Builtin::CubeDimZ => wgsl::Variable::WorkgroupSizeZ,
380                cube::Builtin::CubeClusterDim
381                | cube::Builtin::CubeClusterDimX
382                | cube::Builtin::CubeClusterDimY
383                | cube::Builtin::CubeClusterDimZ => self.constant_var(1),
384                cube::Builtin::CubeCountX => {
385                    self.num_workgroups = true;
386                    wgsl::Variable::NumWorkgroupsX
387                }
388                cube::Builtin::CubeCountY => {
389                    self.num_workgroups = true;
390                    wgsl::Variable::NumWorkgroupsY
391                }
392                cube::Builtin::CubeCountZ => {
393                    self.num_workgroups = true;
394                    wgsl::Variable::NumWorkgroupsZ
395                }
396                cube::Builtin::CubePos => {
397                    self.workgroup_id_no_axis = true;
398                    wgsl::Variable::WorkgroupId
399                }
400                cube::Builtin::CubeDim => {
401                    self.workgroup_size_no_axis = true;
402                    wgsl::Variable::WorkgroupSize
403                }
404                cube::Builtin::CubeCount => {
405                    self.num_workgroup_no_axis = true;
406                    wgsl::Variable::NumWorkgroups
407                }
408                cube::Builtin::PlaneDim => {
409                    self.subgroup_size = true;
410                    wgsl::Variable::SubgroupSize
411                }
412                cube::Builtin::PlanePos => {
413                    self.subgroup_id = true;
414                    wgsl::Variable::SubgroupId
415                }
416                cube::Builtin::UnitPosPlane => {
417                    self.subgroup_invocation_id = true;
418                    wgsl::Variable::SubgroupInvocationId
419                }
420            },
421            cube::VariableKind::Matrix { .. } => {
422                panic!("Cooperative matrix-multiply and accumulate not supported.")
423            }
424            cube::VariableKind::Pipeline { .. } => {
425                panic!("Pipeline not supported.")
426            }
427            cube::VariableKind::BarrierToken { .. } => {
428                panic!("Barrier not supported.")
429            }
430            cube::VariableKind::TensorMapInput(_) => panic!("Tensor map not supported."),
431            cube::VariableKind::TensorMapOutput(_) => panic!("Tensor map not supported."),
432        }
433    }
434
435    fn constant_var(&mut self, value: u32) -> wgsl::Variable {
436        let var = cube::Variable::constant(value.into(), UIntKind::U32);
437        self.compile_variable(var)
438    }
439
440    fn compile_scope(&mut self, scope: &mut cube::Scope) -> Vec<wgsl::Instruction> {
441        let mut instructions = Vec::new();
442
443        let const_arrays = scope
444            .const_arrays
445            .drain(..)
446            .map(|(var, values)| ConstantArray {
447                index: var.index().unwrap(),
448                item: self.compile_type(var.ty),
449                size: values.len() as u32,
450                values: values
451                    .into_iter()
452                    .map(|val| self.compile_variable(val))
453                    .collect(),
454            })
455            .collect::<Vec<_>>();
456        self.const_arrays.extend(const_arrays);
457
458        let checked_io: Box<dyn Processor> = Box::new(CheckedIoProcessor::new(self.strategy));
459        let unroll = Box::new(UnrollProcessor::new(MAX_LINE_SIZE));
460        let saturating = Box::new(SaturatingArithmeticProcessor::new(true));
461        let processing = scope.process([&*unroll, &*checked_io, &*saturating]);
462
463        for mut var in processing.variables {
464            if var.ty.line_size() > MAX_LINE_SIZE {
465                var.ty = var.ty.line(MAX_LINE_SIZE);
466            }
467            instructions.push(wgsl::Instruction::DeclareVariable {
468                var: self.compile_variable(var),
469            });
470        }
471
472        processing
473            .instructions
474            .into_iter()
475            .for_each(|op| self.compile_operation(&mut instructions, op.operation, op.out, scope));
476
477        instructions
478    }
479
480    fn compile_operation(
481        &mut self,
482        instructions: &mut Vec<wgsl::Instruction>,
483        operation: cube::Operation,
484        out: Option<cube::Variable>,
485        scope: &mut cube::Scope,
486    ) {
487        match operation {
488            cube::Operation::Copy(variable) => instructions.push(wgsl::Instruction::Assign {
489                input: self.compile_variable(variable),
490                out: self.compile_variable(out.unwrap()),
491            }),
492            cube::Operation::Arithmetic(op) => {
493                self.compile_arithmetic(op, out, instructions, scope)
494            }
495            cube::Operation::Comparison(op) => self.compile_cmp(op, out, instructions),
496            cube::Operation::Bitwise(op) => self.compile_bitwise(op, out, instructions),
497            cube::Operation::Operator(op) => self.compile_operator(op, out, instructions),
498            cube::Operation::Atomic(op) => instructions.push(self.compile_atomic(op, out)),
499            cube::Operation::Metadata(op) => instructions.push(self.compile_metadata(op, out)),
500            cube::Operation::Branch(val) => self.compile_branch(instructions, val),
501            cube::Operation::Synchronization(val) => {
502                self.compile_synchronization(instructions, val)
503            }
504            cube::Operation::Plane(op) => self.compile_subgroup(instructions, op, out),
505            cube::Operation::CoopMma(_) => {
506                panic!("Cooperative matrix-multiply and accumulate isn't supported on wgpu.")
507            }
508            cube::Operation::NonSemantic(cube::NonSemantic::Comment { content }) => {
509                self.compile_comment(instructions, content)
510            }
511            cube::Operation::NonSemantic(_) => {}
512            cube::Operation::Barrier(_) => {
513                panic!("Barrier isn't supported on wgpu.")
514            }
515            cube::Operation::Tma(_) => panic!("TMA isn't supported on wgpu."),
516            cube::Operation::Marker(_) => {}
517        }
518    }
519
520    fn compile_subgroup(
521        &mut self,
522        instructions: &mut Vec<wgsl::Instruction>,
523        subgroup: cube::Plane,
524        out: Option<cube::Variable>,
525    ) {
526        self.subgroup_instructions_used = true;
527
528        let out = out.unwrap();
529        let op = match subgroup {
530            cube::Plane::Elect => Subgroup::Elect {
531                out: self.compile_variable(out),
532            },
533            cube::Plane::All(op) => Subgroup::All {
534                input: self.compile_variable(op.input),
535                out: self.compile_variable(out),
536            },
537            cube::Plane::Any(op) => Subgroup::Any {
538                input: self.compile_variable(op.input),
539                out: self.compile_variable(out),
540            },
541            cube::Plane::Ballot(op) => Subgroup::Ballot {
542                input: self.compile_variable(op.input),
543                out: self.compile_variable(out),
544            },
545
546            cube::Plane::Broadcast(op) => Subgroup::Broadcast {
547                lhs: self.compile_variable(op.lhs),
548                rhs: self.compile_variable(op.rhs),
549                out: self.compile_variable(out),
550            },
551
552            cube::Plane::Sum(op) => Subgroup::Sum {
553                input: self.compile_variable(op.input),
554                out: self.compile_variable(out),
555            },
556
557            cube::Plane::ExclusiveSum(op) => Subgroup::ExclusiveSum {
558                input: self.compile_variable(op.input),
559                out: self.compile_variable(out),
560            },
561            cube::Plane::InclusiveSum(op) => Subgroup::InclusiveSum {
562                input: self.compile_variable(op.input),
563                out: self.compile_variable(out),
564            },
565            cube::Plane::Prod(op) => Subgroup::Prod {
566                input: self.compile_variable(op.input),
567                out: self.compile_variable(out),
568            },
569            cube::Plane::ExclusiveProd(op) => Subgroup::ExclusiveProd {
570                input: self.compile_variable(op.input),
571                out: self.compile_variable(out),
572            },
573            cube::Plane::InclusiveProd(op) => Subgroup::InclusiveProd {
574                input: self.compile_variable(op.input),
575                out: self.compile_variable(out),
576            },
577            cube::Plane::Min(op) => Subgroup::Min {
578                input: self.compile_variable(op.input),
579                out: self.compile_variable(out),
580            },
581            cube::Plane::Max(op) => Subgroup::Max {
582                input: self.compile_variable(op.input),
583                out: self.compile_variable(out),
584            },
585            cube::Plane::Shuffle(op) => Subgroup::Shuffle {
586                lhs: self.compile_variable(op.lhs),
587                rhs: self.compile_variable(op.rhs),
588                out: self.compile_variable(out),
589            },
590            cube::Plane::ShuffleXor(op) => Subgroup::ShuffleXor {
591                lhs: self.compile_variable(op.lhs),
592                rhs: self.compile_variable(op.rhs),
593                out: self.compile_variable(out),
594            },
595            cube::Plane::ShuffleUp(op) => Subgroup::ShuffleUp {
596                lhs: self.compile_variable(op.lhs),
597                rhs: self.compile_variable(op.rhs),
598                out: self.compile_variable(out),
599            },
600            cube::Plane::ShuffleDown(op) => Subgroup::ShuffleDown {
601                lhs: self.compile_variable(op.lhs),
602                rhs: self.compile_variable(op.rhs),
603                out: self.compile_variable(out),
604            },
605        };
606
607        instructions.push(wgsl::Instruction::Subgroup(op));
608    }
609
610    fn compile_branch(&mut self, instructions: &mut Vec<wgsl::Instruction>, branch: cube::Branch) {
611        match branch {
612            cube::Branch::If(mut op) => instructions.push(wgsl::Instruction::If {
613                cond: self.compile_variable(op.cond),
614                instructions: self.compile_scope(&mut op.scope),
615            }),
616            cube::Branch::IfElse(mut op) => instructions.push(wgsl::Instruction::IfElse {
617                cond: self.compile_variable(op.cond),
618                instructions_if: self.compile_scope(&mut op.scope_if),
619                instructions_else: self.compile_scope(&mut op.scope_else),
620            }),
621            cube::Branch::Switch(mut op) => instructions.push(wgsl::Instruction::Switch {
622                value: self.compile_variable(op.value),
623                instructions_default: self.compile_scope(&mut op.scope_default),
624                cases: op
625                    .cases
626                    .into_iter()
627                    .map(|(val, mut scope)| {
628                        (self.compile_variable(val), self.compile_scope(&mut scope))
629                    })
630                    .collect(),
631            }),
632            cube::Branch::Return => instructions.push(wgsl::Instruction::Return),
633            cube::Branch::Break => instructions.push(wgsl::Instruction::Break),
634            cube::Branch::RangeLoop(mut range_loop) => {
635                instructions.push(wgsl::Instruction::RangeLoop {
636                    i: self.compile_variable(range_loop.i),
637                    start: self.compile_variable(range_loop.start),
638                    end: self.compile_variable(range_loop.end),
639                    step: range_loop.step.map(|it| self.compile_variable(it)),
640                    inclusive: range_loop.inclusive,
641                    instructions: self.compile_scope(&mut range_loop.scope),
642                })
643            }
644            cube::Branch::Loop(mut op) => instructions.push(wgsl::Instruction::Loop {
645                instructions: self.compile_scope(&mut op.scope),
646            }),
647        };
648    }
649
650    fn compile_synchronization(
651        &mut self,
652        instructions: &mut Vec<wgsl::Instruction>,
653        synchronization: cube::Synchronization,
654    ) {
655        match synchronization {
656            cube::Synchronization::SyncCube => {
657                instructions.push(wgsl::Instruction::WorkgroupBarrier)
658            }
659            cube::Synchronization::SyncPlane => {
660                panic!("Synchronization within a plane is not supported in WGSL")
661            }
662            cube::Synchronization::SyncStorage => {
663                instructions.push(wgsl::Instruction::StorageBarrier)
664            }
665            cube::Synchronization::SyncAsyncProxyShared => panic!("TMA is not supported in WGSL"),
666        };
667    }
668
669    fn compile_comment(&mut self, instructions: &mut Vec<wgsl::Instruction>, content: String) {
670        instructions.push(wgsl::Instruction::Comment { content })
671    }
672
673    fn compile_metadata(
674        &mut self,
675        metadata: cube::Metadata,
676        out: Option<cube::Variable>,
677    ) -> wgsl::Instruction {
678        let out = out.unwrap();
679        match metadata {
680            cube::Metadata::Rank { var } => {
681                let position = self.ext_meta_pos(&var);
682                let offset = self.metadata.rank_index(position);
683                wgsl::Instruction::Metadata {
684                    out: self.compile_variable(out),
685                    info_offset: self.compile_variable(offset.into()),
686                }
687            }
688            cube::Metadata::Stride { dim, var } => {
689                let position = self.ext_meta_pos(&var);
690                let offset = self.metadata.stride_offset_index(position);
691                wgsl::Instruction::ExtendedMeta {
692                    info_offset: self.compile_variable(offset.into()),
693                    dim: self.compile_variable(dim),
694                    out: self.compile_variable(out),
695                }
696            }
697            cube::Metadata::Shape { dim, var } => {
698                let position = self.ext_meta_pos(&var);
699                let offset = self.metadata.shape_offset_index(position);
700                wgsl::Instruction::ExtendedMeta {
701                    info_offset: self.compile_variable(offset.into()),
702                    dim: self.compile_variable(dim),
703                    out: self.compile_variable(out),
704                }
705            }
706            cube::Metadata::Length { var } => match var.kind {
707                cube::VariableKind::GlobalInputArray(id) => {
708                    let offset = self.metadata.len_index(id);
709                    wgsl::Instruction::Metadata {
710                        out: self.compile_variable(out),
711                        info_offset: self.compile_variable(offset.into()),
712                    }
713                }
714                cube::VariableKind::GlobalOutputArray(id) => {
715                    let offset = self.metadata.len_index(id);
716                    wgsl::Instruction::Metadata {
717                        out: self.compile_variable(out),
718                        info_offset: self.compile_variable(offset.into()),
719                    }
720                }
721                _ => wgsl::Instruction::Length {
722                    var: self.compile_variable(var),
723                    out: self.compile_variable(out),
724                },
725            },
726            cube::Metadata::BufferLength { var } => match var.kind {
727                cube::VariableKind::GlobalInputArray(id) => {
728                    let offset = self.metadata.buffer_len_index(id);
729                    wgsl::Instruction::Metadata {
730                        out: self.compile_variable(out),
731                        info_offset: self.compile_variable(offset.into()),
732                    }
733                }
734                cube::VariableKind::GlobalOutputArray(id) => {
735                    let offset = self.metadata.buffer_len_index(id);
736                    wgsl::Instruction::Metadata {
737                        out: self.compile_variable(out),
738                        info_offset: self.compile_variable(offset.into()),
739                    }
740                }
741                _ => wgsl::Instruction::Length {
742                    var: self.compile_variable(var),
743                    out: self.compile_variable(out),
744                },
745            },
746        }
747    }
748
749    fn compile_arithmetic(
750        &mut self,
751        value: cube::Arithmetic,
752        out: Option<cube::Variable>,
753        instructions: &mut Vec<wgsl::Instruction>,
754        scope: &mut Scope,
755    ) {
756        let out = out.unwrap();
757        match value {
758            cube::Arithmetic::Max(op) => instructions.push(wgsl::Instruction::Max {
759                lhs: self.compile_variable(op.lhs),
760                rhs: self.compile_variable(op.rhs),
761                out: self.compile_variable(out),
762            }),
763            cube::Arithmetic::Min(op) => instructions.push(wgsl::Instruction::Min {
764                lhs: self.compile_variable(op.lhs),
765                rhs: self.compile_variable(op.rhs),
766                out: self.compile_variable(out),
767            }),
768            cube::Arithmetic::Add(op) => instructions.push(wgsl::Instruction::Add {
769                lhs: self.compile_variable(op.lhs),
770                rhs: self.compile_variable(op.rhs),
771                out: self.compile_variable(out),
772            }),
773            cube::Arithmetic::SaturatingAdd(_) => {
774                unreachable!("Saturating add should be removed by processor");
775            }
776            cube::Arithmetic::Fma(op) => instructions.push(wgsl::Instruction::Fma {
777                a: self.compile_variable(op.a),
778                b: self.compile_variable(op.b),
779                c: self.compile_variable(op.c),
780                out: self.compile_variable(out),
781            }),
782            cube::Arithmetic::Modulo(op) => instructions.push(wgsl::Instruction::Modulo {
783                lhs: self.compile_variable(op.lhs),
784                rhs: self.compile_variable(op.rhs),
785                out: self.compile_variable(out),
786            }),
787            cube::Arithmetic::Sub(op) => instructions.push(wgsl::Instruction::Sub {
788                lhs: self.compile_variable(op.lhs),
789                rhs: self.compile_variable(op.rhs),
790                out: self.compile_variable(out),
791            }),
792            cube::Arithmetic::SaturatingSub(_) => {
793                unreachable!("Saturating sub should be removed by processor");
794            }
795            cube::Arithmetic::Mul(op) => instructions.push(wgsl::Instruction::Mul {
796                lhs: self.compile_variable(op.lhs),
797                rhs: self.compile_variable(op.rhs),
798                out: self.compile_variable(out),
799            }),
800            cube::Arithmetic::Div(op) => instructions.push(wgsl::Instruction::Div {
801                lhs: self.compile_variable(op.lhs),
802                rhs: self.compile_variable(op.rhs),
803                out: self.compile_variable(out),
804            }),
805            cube::Arithmetic::Abs(op) => instructions.push(wgsl::Instruction::Abs {
806                input: self.compile_variable(op.input),
807                out: self.compile_variable(out),
808            }),
809            cube::Arithmetic::Exp(op) => instructions.push(wgsl::Instruction::Exp {
810                input: self.compile_variable(op.input),
811                out: self.compile_variable(out),
812            }),
813            cube::Arithmetic::Log(op) => instructions.push(wgsl::Instruction::Log {
814                input: self.compile_variable(op.input),
815                out: self.compile_variable(out),
816            }),
817            cube::Arithmetic::Log1p(op) => instructions.push(wgsl::Instruction::Log1p {
818                input: self.compile_variable(op.input),
819                out: self.compile_variable(out),
820            }),
821            cube::Arithmetic::Cos(op) => instructions.push(wgsl::Instruction::Cos {
822                input: self.compile_variable(op.input),
823                out: self.compile_variable(out),
824            }),
825            cube::Arithmetic::Sin(op) => instructions.push(wgsl::Instruction::Sin {
826                input: self.compile_variable(op.input),
827                out: self.compile_variable(out),
828            }),
829            cube::Arithmetic::Tan(op) => instructions.push(wgsl::Instruction::Tan {
830                input: self.compile_variable(op.input),
831                out: self.compile_variable(out),
832            }),
833            cube::Arithmetic::Tanh(op) => instructions.push(wgsl::Instruction::Tanh {
834                input: self.compile_variable(op.input),
835                out: self.compile_variable(out),
836            }),
837            cube::Arithmetic::Sinh(op) => instructions.push(wgsl::Instruction::Sinh {
838                input: self.compile_variable(op.input),
839                out: self.compile_variable(out),
840            }),
841            cube::Arithmetic::Cosh(op) => instructions.push(wgsl::Instruction::Cosh {
842                input: self.compile_variable(op.input),
843                out: self.compile_variable(out),
844            }),
845            cube::Arithmetic::ArcCos(op) => instructions.push(wgsl::Instruction::ArcCos {
846                input: self.compile_variable(op.input),
847                out: self.compile_variable(out),
848            }),
849            cube::Arithmetic::ArcSin(op) => instructions.push(wgsl::Instruction::ArcSin {
850                input: self.compile_variable(op.input),
851                out: self.compile_variable(out),
852            }),
853            cube::Arithmetic::ArcTan(op) => instructions.push(wgsl::Instruction::ArcTan {
854                input: self.compile_variable(op.input),
855                out: self.compile_variable(out),
856            }),
857            cube::Arithmetic::ArcSinh(op) => instructions.push(wgsl::Instruction::ArcSinh {
858                input: self.compile_variable(op.input),
859                out: self.compile_variable(out),
860            }),
861            cube::Arithmetic::ArcCosh(op) => instructions.push(wgsl::Instruction::ArcCosh {
862                input: self.compile_variable(op.input),
863                out: self.compile_variable(out),
864            }),
865            cube::Arithmetic::ArcTanh(op) => instructions.push(wgsl::Instruction::ArcTanh {
866                input: self.compile_variable(op.input),
867                out: self.compile_variable(out),
868            }),
869            cube::Arithmetic::Degrees(op) => instructions.push(wgsl::Instruction::Degrees {
870                input: self.compile_variable(op.input),
871                out: self.compile_variable(out),
872            }),
873            cube::Arithmetic::Radians(op) => instructions.push(wgsl::Instruction::Radians {
874                input: self.compile_variable(op.input),
875                out: self.compile_variable(out),
876            }),
877            cube::Arithmetic::ArcTan2(op) => instructions.push(wgsl::Instruction::ArcTan2 {
878                lhs: self.compile_variable(op.lhs),
879                rhs: self.compile_variable(op.rhs),
880                out: self.compile_variable(out),
881            }),
882            // No powi in WGSL
883            cube::Arithmetic::Powf(op) | cube::Arithmetic::Powi(op) => {
884                instructions.push(wgsl::Instruction::Powf {
885                    lhs: self.compile_variable(op.lhs),
886                    rhs: self.compile_variable(op.rhs),
887                    out: self.compile_variable(out),
888                })
889            }
890            cube::Arithmetic::Hypot(op) => {
891                let mut scope = scope.child();
892                expand_hypot(&mut scope, op.lhs, op.rhs, out);
893                instructions.extend(self.compile_scope(&mut scope));
894            }
895            cube::Arithmetic::Rhypot(op) => {
896                let mut scope = scope.child();
897                expand_rhypot(&mut scope, op.lhs, op.rhs, out);
898                instructions.extend(self.compile_scope(&mut scope));
899            }
900
901            cube::Arithmetic::Sqrt(op) => instructions.push(wgsl::Instruction::Sqrt {
902                input: self.compile_variable(op.input),
903                out: self.compile_variable(out),
904            }),
905            cube::Arithmetic::InverseSqrt(op) => {
906                instructions.push(wgsl::Instruction::InverseSqrt {
907                    input: self.compile_variable(op.input),
908                    out: self.compile_variable(out),
909                })
910            }
911            cube::Arithmetic::Round(op) => instructions.push(wgsl::Instruction::Round {
912                input: self.compile_variable(op.input),
913                out: self.compile_variable(out),
914            }),
915            cube::Arithmetic::Floor(op) => instructions.push(wgsl::Instruction::Floor {
916                input: self.compile_variable(op.input),
917                out: self.compile_variable(out),
918            }),
919            cube::Arithmetic::Ceil(op) => instructions.push(wgsl::Instruction::Ceil {
920                input: self.compile_variable(op.input),
921                out: self.compile_variable(out),
922            }),
923            cube::Arithmetic::Trunc(op) => instructions.push(wgsl::Instruction::Trunc {
924                input: self.compile_variable(op.input),
925                out: self.compile_variable(out),
926            }),
927            cube::Arithmetic::Erf(op) => {
928                let mut scope = scope.child();
929                expand_erf(&mut scope, op.input, out);
930                instructions.extend(self.compile_scope(&mut scope));
931            }
932            cube::Arithmetic::MulHi(op) => {
933                let mut scope = scope.child();
934                match self.compilation_options.supports_u64 {
935                    true => expand_himul_64(&mut scope, op.lhs, op.rhs, out),
936                    false => expand_himul_sim(&mut scope, op.lhs, op.rhs, out),
937                }
938                instructions.extend(self.compile_scope(&mut scope));
939            }
940            cube::Arithmetic::Recip(op) => instructions.push(wgsl::Instruction::Recip {
941                input: self.compile_variable(op.input),
942                out: self.compile_variable(out),
943            }),
944            cube::Arithmetic::Clamp(op) => instructions.push(wgsl::Instruction::Clamp {
945                input: self.compile_variable(op.input),
946                min_value: self.compile_variable(op.min_value),
947                max_value: self.compile_variable(op.max_value),
948                out: self.compile_variable(out),
949            }),
950            cube::Arithmetic::Remainder(op) => instructions.push(wgsl::Instruction::Remainder {
951                lhs: self.compile_variable(op.lhs),
952                rhs: self.compile_variable(op.rhs),
953                out: self.compile_variable(out),
954            }),
955            cube::Arithmetic::Neg(op) => instructions.push(wgsl::Instruction::Negate {
956                input: self.compile_variable(op.input),
957                out: self.compile_variable(out),
958            }),
959            cube::Arithmetic::Magnitude(op) => instructions.push(wgsl::Instruction::Magnitude {
960                input: self.compile_variable(op.input),
961                out: self.compile_variable(out),
962            }),
963            cube::Arithmetic::Normalize(op) => instructions.push(wgsl::Instruction::Normalize {
964                input: self.compile_variable(op.input),
965                out: self.compile_variable(out),
966            }),
967            cube::Arithmetic::Dot(op) => instructions.push(wgsl::Instruction::Dot {
968                lhs: self.compile_variable(op.lhs),
969                rhs: self.compile_variable(op.rhs),
970                out: self.compile_variable(out),
971            }),
972        }
973    }
974
975    fn compile_cmp(
976        &mut self,
977        value: cube::Comparison,
978        out: Option<cube::Variable>,
979        instructions: &mut Vec<wgsl::Instruction>,
980    ) {
981        let out = out.unwrap();
982        match value {
983            cube::Comparison::Equal(op) => instructions.push(wgsl::Instruction::Equal {
984                lhs: self.compile_variable(op.lhs),
985                rhs: self.compile_variable(op.rhs),
986                out: self.compile_variable(out),
987            }),
988            cube::Comparison::Lower(op) => instructions.push(wgsl::Instruction::Lower {
989                lhs: self.compile_variable(op.lhs),
990                rhs: self.compile_variable(op.rhs),
991                out: self.compile_variable(out),
992            }),
993            cube::Comparison::Greater(op) => instructions.push(wgsl::Instruction::Greater {
994                lhs: self.compile_variable(op.lhs),
995                rhs: self.compile_variable(op.rhs),
996                out: self.compile_variable(out),
997            }),
998            cube::Comparison::LowerEqual(op) => instructions.push(wgsl::Instruction::LowerEqual {
999                lhs: self.compile_variable(op.lhs),
1000                rhs: self.compile_variable(op.rhs),
1001                out: self.compile_variable(out),
1002            }),
1003            cube::Comparison::GreaterEqual(op) => {
1004                instructions.push(wgsl::Instruction::GreaterEqual {
1005                    lhs: self.compile_variable(op.lhs),
1006                    rhs: self.compile_variable(op.rhs),
1007                    out: self.compile_variable(out),
1008                })
1009            }
1010            cube::Comparison::NotEqual(op) => instructions.push(wgsl::Instruction::NotEqual {
1011                lhs: self.compile_variable(op.lhs),
1012                rhs: self.compile_variable(op.rhs),
1013                out: self.compile_variable(out),
1014            }),
1015            cube::Comparison::IsNan(op) => instructions.push(wgsl::Instruction::IsNan {
1016                input: self.compile_variable(op.input),
1017                out: self.compile_variable(out),
1018            }),
1019            cube::Comparison::IsInf(op) => instructions.push(wgsl::Instruction::IsInf {
1020                input: self.compile_variable(op.input),
1021                out: self.compile_variable(out),
1022            }),
1023        }
1024    }
1025
1026    fn compile_bitwise(
1027        &mut self,
1028        value: cube::Bitwise,
1029        out: Option<cube::Variable>,
1030        instructions: &mut Vec<wgsl::Instruction>,
1031    ) {
1032        let out = out.unwrap();
1033        match value {
1034            cube::Bitwise::BitwiseOr(op) => instructions.push(wgsl::Instruction::BitwiseOr {
1035                lhs: self.compile_variable(op.lhs),
1036                rhs: self.compile_variable(op.rhs),
1037                out: self.compile_variable(out),
1038            }),
1039            cube::Bitwise::BitwiseAnd(op) => instructions.push(wgsl::Instruction::BitwiseAnd {
1040                lhs: self.compile_variable(op.lhs),
1041                rhs: self.compile_variable(op.rhs),
1042                out: self.compile_variable(out),
1043            }),
1044            cube::Bitwise::BitwiseXor(op) => instructions.push(wgsl::Instruction::BitwiseXor {
1045                lhs: self.compile_variable(op.lhs),
1046                rhs: self.compile_variable(op.rhs),
1047                out: self.compile_variable(out),
1048            }),
1049            cube::Bitwise::CountOnes(op) => instructions.push(wgsl::Instruction::CountBits {
1050                input: self.compile_variable(op.input),
1051                out: self.compile_variable(out),
1052            }),
1053            cube::Bitwise::ReverseBits(op) => instructions.push(wgsl::Instruction::ReverseBits {
1054                input: self.compile_variable(op.input),
1055                out: self.compile_variable(out),
1056            }),
1057            cube::Bitwise::ShiftLeft(op) => instructions.push(wgsl::Instruction::ShiftLeft {
1058                lhs: self.compile_variable(op.lhs),
1059                rhs: self.compile_variable(op.rhs),
1060                out: self.compile_variable(out),
1061            }),
1062            cube::Bitwise::ShiftRight(op) => instructions.push(wgsl::Instruction::ShiftRight {
1063                lhs: self.compile_variable(op.lhs),
1064                rhs: self.compile_variable(op.rhs),
1065                out: self.compile_variable(out),
1066            }),
1067            cube::Bitwise::BitwiseNot(op) => instructions.push(wgsl::Instruction::BitwiseNot {
1068                input: self.compile_variable(op.input),
1069                out: self.compile_variable(out),
1070            }),
1071            cube::Bitwise::LeadingZeros(op) => instructions.push(wgsl::Instruction::LeadingZeros {
1072                input: self.compile_variable(op.input),
1073                out: self.compile_variable(out),
1074            }),
1075            cube::Bitwise::TrailingZeros(op) => {
1076                instructions.push(wgsl::Instruction::TrailingZeros {
1077                    input: self.compile_variable(op.input),
1078                    out: self.compile_variable(out),
1079                })
1080            }
1081            cube::Bitwise::FindFirstSet(op) => instructions.push(wgsl::Instruction::FindFirstSet {
1082                input: self.compile_variable(op.input),
1083                out: self.compile_variable(out),
1084            }),
1085        }
1086    }
1087
1088    fn compile_operator(
1089        &mut self,
1090        value: cube::Operator,
1091        out: Option<cube::Variable>,
1092        instructions: &mut Vec<wgsl::Instruction>,
1093    ) {
1094        let out = out.unwrap();
1095        match value {
1096            cube::Operator::Cast(op) => instructions.push(wgsl::Instruction::Assign {
1097                input: self.compile_variable(op.input),
1098                out: self.compile_variable(out),
1099            }),
1100            cube::Operator::Index(op) | cube::Operator::UncheckedIndex(op) => {
1101                instructions.push(wgsl::Instruction::Index {
1102                    lhs: self.compile_variable(op.list),
1103                    rhs: self.compile_variable(op.index),
1104                    out: self.compile_variable(out),
1105                });
1106            }
1107            cube::Operator::IndexAssign(op) | cube::Operator::UncheckedIndexAssign(op) => {
1108                instructions.push(wgsl::Instruction::IndexAssign {
1109                    index: self.compile_variable(op.index),
1110                    rhs: self.compile_variable(op.value),
1111                    out: self.compile_variable(out),
1112                })
1113            }
1114            cube::Operator::And(op) => instructions.push(wgsl::Instruction::And {
1115                lhs: self.compile_variable(op.lhs),
1116                rhs: self.compile_variable(op.rhs),
1117                out: self.compile_variable(out),
1118            }),
1119            cube::Operator::Or(op) => instructions.push(wgsl::Instruction::Or {
1120                lhs: self.compile_variable(op.lhs),
1121                rhs: self.compile_variable(op.rhs),
1122                out: self.compile_variable(out),
1123            }),
1124            cube::Operator::Not(op) => instructions.push(wgsl::Instruction::Not {
1125                input: self.compile_variable(op.input),
1126                out: self.compile_variable(out),
1127            }),
1128            cube::Operator::Reinterpret(op) => instructions.push(wgsl::Instruction::Bitcast {
1129                input: self.compile_variable(op.input),
1130                out: self.compile_variable(out),
1131            }),
1132            cube::Operator::InitLine(op) => instructions.push(wgsl::Instruction::VecInit {
1133                inputs: op
1134                    .inputs
1135                    .into_iter()
1136                    .map(|var| self.compile_variable(var))
1137                    .collect(),
1138                out: self.compile_variable(out),
1139            }),
1140            cube::Operator::CopyMemory(op) => instructions.push(wgsl::Instruction::Copy {
1141                input: self.compile_variable(op.input),
1142                in_index: self.compile_variable(op.in_index),
1143                out: self.compile_variable(out),
1144                out_index: self.compile_variable(op.out_index),
1145            }),
1146            cube::Operator::CopyMemoryBulk(op) => instructions.push(wgsl::Instruction::CopyBulk {
1147                input: self.compile_variable(op.input),
1148                in_index: self.compile_variable(op.in_index),
1149                out: self.compile_variable(out),
1150                out_index: self.compile_variable(op.out_index),
1151                len: op.len as u32,
1152            }),
1153            cube::Operator::Select(op) => instructions.push(wgsl::Instruction::Select {
1154                cond: self.compile_variable(op.cond),
1155                then: self.compile_variable(op.then),
1156                or_else: self.compile_variable(op.or_else),
1157                out: self.compile_variable(out),
1158            }),
1159        }
1160    }
1161
1162    fn compile_atomic(
1163        &mut self,
1164        atomic: cube::AtomicOp,
1165        out: Option<cube::Variable>,
1166    ) -> wgsl::Instruction {
1167        let out = out.unwrap();
1168        match atomic {
1169            cube::AtomicOp::Add(op) => wgsl::Instruction::AtomicAdd {
1170                lhs: self.compile_variable(op.lhs),
1171                rhs: self.compile_variable(op.rhs),
1172                out: self.compile_variable(out),
1173            },
1174            cube::AtomicOp::Sub(op) => wgsl::Instruction::AtomicSub {
1175                lhs: self.compile_variable(op.lhs),
1176                rhs: self.compile_variable(op.rhs),
1177                out: self.compile_variable(out),
1178            },
1179            cube::AtomicOp::Max(op) => wgsl::Instruction::AtomicMax {
1180                lhs: self.compile_variable(op.lhs),
1181                rhs: self.compile_variable(op.rhs),
1182                out: self.compile_variable(out),
1183            },
1184            cube::AtomicOp::Min(op) => wgsl::Instruction::AtomicMin {
1185                lhs: self.compile_variable(op.lhs),
1186                rhs: self.compile_variable(op.rhs),
1187                out: self.compile_variable(out),
1188            },
1189            cube::AtomicOp::And(op) => wgsl::Instruction::AtomicAnd {
1190                lhs: self.compile_variable(op.lhs),
1191                rhs: self.compile_variable(op.rhs),
1192                out: self.compile_variable(out),
1193            },
1194            cube::AtomicOp::Or(op) => wgsl::Instruction::AtomicOr {
1195                lhs: self.compile_variable(op.lhs),
1196                rhs: self.compile_variable(op.rhs),
1197                out: self.compile_variable(out),
1198            },
1199            cube::AtomicOp::Xor(op) => wgsl::Instruction::AtomicXor {
1200                lhs: self.compile_variable(op.lhs),
1201                rhs: self.compile_variable(op.rhs),
1202                out: self.compile_variable(out),
1203            },
1204            cube::AtomicOp::Load(op) => wgsl::Instruction::AtomicLoad {
1205                input: self.compile_variable(op.input),
1206                out: self.compile_variable(out),
1207            },
1208            cube::AtomicOp::Store(op) => wgsl::Instruction::AtomicStore {
1209                input: self.compile_variable(op.input),
1210                out: self.compile_variable(out),
1211            },
1212            cube::AtomicOp::Swap(op) => wgsl::Instruction::AtomicSwap {
1213                lhs: self.compile_variable(op.lhs),
1214                rhs: self.compile_variable(op.rhs),
1215                out: self.compile_variable(out),
1216            },
1217            cube::AtomicOp::CompareAndSwap(op) => wgsl::Instruction::AtomicCompareExchangeWeak {
1218                lhs: self.compile_variable(op.input),
1219                cmp: self.compile_variable(op.cmp),
1220                value: self.compile_variable(op.val),
1221                out: self.compile_variable(out),
1222            },
1223        }
1224    }
1225
1226    fn compile_location(value: kernel::Location) -> wgsl::Location {
1227        match value {
1228            kernel::Location::Storage => wgsl::Location::Storage,
1229            kernel::Location::Cube => wgsl::Location::Workgroup,
1230        }
1231    }
1232
1233    fn compile_binding(&mut self, value: kernel::Binding) -> wgsl::Binding {
1234        wgsl::Binding {
1235            id: value.id,
1236            visibility: value.visibility,
1237            location: Self::compile_location(value.location),
1238            item: self.compile_type(value.ty),
1239            size: value.size,
1240        }
1241    }
1242}
1243
1244fn register_extensions(instructions: &[wgsl::Instruction]) -> Vec<wgsl::Extension> {
1245    let mut extensions = Vec::new();
1246
1247    let mut register_extension = |extension: wgsl::Extension| {
1248        if !extensions.contains(&extension) {
1249            extensions.push(extension);
1250        }
1251    };
1252
1253    // Since not all instructions are native to WGSL, we need to add the custom ones.
1254    for instruction in instructions {
1255        match instruction {
1256            wgsl::Instruction::Powf { lhs: _, rhs, out } => {
1257                register_extension(wgsl::Extension::PowfPrimitive(out.elem()));
1258                register_extension(wgsl::powf_extension(rhs, out));
1259            }
1260            #[cfg(target_os = "macos")]
1261            wgsl::Instruction::Tanh { input, out: _ } => {
1262                register_extension(wgsl::Extension::SafeTanhPrimitive(input.elem()));
1263                register_extension(wgsl::Extension::SafeTanh(input.item()));
1264            }
1265            wgsl::Instruction::IsNan { input, out } => {
1266                register_extension(wgsl::Extension::IsNanPrimitive(input.elem()));
1267                register_extension(wgsl::Extension::IsNan(input.item(), out.item()));
1268            }
1269            wgsl::Instruction::IsInf { input, out } => {
1270                register_extension(wgsl::Extension::IsInfPrimitive(input.elem()));
1271                register_extension(wgsl::Extension::IsInf(input.item(), out.item()));
1272            }
1273            wgsl::Instruction::If { instructions, .. } => {
1274                for extension in register_extensions(instructions) {
1275                    register_extension(extension);
1276                }
1277            }
1278            wgsl::Instruction::IfElse {
1279                instructions_if,
1280                instructions_else,
1281                ..
1282            } => {
1283                for extension in register_extensions(instructions_if) {
1284                    register_extension(extension);
1285                }
1286                for extension in register_extensions(instructions_else) {
1287                    register_extension(extension);
1288                }
1289            }
1290            wgsl::Instruction::Loop { instructions } => {
1291                for extension in register_extensions(instructions) {
1292                    register_extension(extension);
1293                }
1294            }
1295            wgsl::Instruction::RangeLoop { instructions, .. } => {
1296                for extension in register_extensions(instructions) {
1297                    register_extension(extension);
1298                }
1299            }
1300            _ => {}
1301        }
1302    }
1303
1304    extensions
1305}