cubecl_wgpu/compiler/wgsl/
compiler.rs

1use super::Subgroup;
2use super::{ConstantArray, shader::ComputeShader};
3use super::{Item, LocalArray, SharedArray};
4use crate::compiler::wgsl::{self, SharedValue};
5
6use cubecl_common::backtrace::BackTrace;
7use cubecl_core::post_processing::{
8    checked_io::CheckedIoProcessor, saturating::SaturatingArithmeticProcessor,
9};
10use cubecl_core::prelude::*;
11use cubecl_core::{
12    Metadata, WgpuCompilationOptions,
13    ir::{self as cube, Scope},
14    prelude::expand_erf,
15};
16use cubecl_core::{
17    ir::{Processor, UIntKind},
18    post_processing::unroll::UnrollProcessor,
19};
20use cubecl_runtime::compiler::CompilationError;
21use cubecl_runtime::kernel;
22
/// Largest line width this compiler emits: wider buffer and variable types are clamped to this
/// size, and the unroll processor is configured with the same limit.
pub const MAX_LINE_SIZE: usize = 4;
24
25/// Wgsl Compiler.
26#[derive(Clone, Default)]
27pub struct WgslCompiler {
28    metadata: Metadata,
29    ext_meta_pos: Vec<u32>,
30    local_invocation_index: bool,
31    local_invocation_id: bool,
32    // TODO: possible cleanup, this bool seems to not be used
33    global_invocation_id: bool,
34    workgroup_id: bool,
35    subgroup_size: bool,
36    subgroup_invocation_id: bool,
37    id: bool,
38    num_workgroups: bool,
39    workgroup_id_no_axis: bool,
40    workgroup_size_no_axis: bool,
41    num_workgroup_no_axis: bool,
42    shared_arrays: Vec<SharedArray>,
43    shared_values: Vec<SharedValue>,
44    const_arrays: Vec<ConstantArray>,
45    local_arrays: Vec<LocalArray>,
46    #[allow(dead_code)]
47    compilation_options: WgpuCompilationOptions,
48    strategy: ExecutionMode,
49    subgroup_instructions_used: bool,
50    f16_used: bool,
51}
52
53impl core::fmt::Debug for WgslCompiler {
54    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55        f.write_str("WgslCompiler")
56    }
57}
58
59impl cubecl_core::Compiler for WgslCompiler {
60    type Representation = ComputeShader;
61    type CompilationOptions = WgpuCompilationOptions;
62
63    fn compile(
64        &mut self,
65        shader: kernel::KernelDefinition,
66        compilation_options: &Self::CompilationOptions,
67        mode: ExecutionMode,
68        address_type: StorageType,
69    ) -> Result<Self::Representation, CompilationError> {
70        self.compilation_options = compilation_options.clone();
71        self.compile_shader(shader, mode, address_type)
72    }
73
74    fn elem_size(&self, elem: cube::ElemType) -> usize {
75        elem.size()
76    }
77
78    fn extension(&self) -> &'static str {
79        "wgsl"
80    }
81}
82
83impl WgslCompiler {
84    fn compile_shader(
85        &mut self,
86        mut value: kernel::KernelDefinition,
87        mode: ExecutionMode,
88        address_type: StorageType,
89    ) -> Result<wgsl::ComputeShader, CompilationError> {
90        let errors = value.body.pop_errors();
91        if !errors.is_empty() {
92            let mut reason = "Can't compile wgsl kernel".to_string();
93            for error in errors {
94                reason += error.as_str();
95                reason += "\n";
96            }
97
98            return Err(CompilationError::Validation {
99                reason,
100                backtrace: BackTrace::capture(),
101            });
102        }
103
104        self.strategy = mode;
105
106        let num_meta = value.buffers.len();
107
108        self.ext_meta_pos = Vec::new();
109        let mut num_ext = 0;
110
111        for binding in value.buffers.iter() {
112            self.ext_meta_pos.push(num_ext);
113            if binding.has_extended_meta {
114                num_ext += 1;
115            }
116        }
117
118        self.metadata = Metadata::new(num_meta as u32, num_ext);
119
120        let address_type = self.compile_storage_type(address_type);
121        let instructions = self.compile_scope(&mut value.body);
122        let extensions = register_extensions(&instructions);
123        let body = wgsl::Body {
124            instructions,
125            id: self.id,
126            address_type,
127        };
128
129        Ok(wgsl::ComputeShader {
130            address_type,
131            buffers: value
132                .buffers
133                .into_iter()
134                .map(|mut it| {
135                    // This is safe when combined with the unroll transform that adjusts all indices.
136                    // Must not be used alone
137                    if it.ty.line_size() > MAX_LINE_SIZE {
138                        it.ty = it.ty.line(MAX_LINE_SIZE);
139                    }
140                    self.compile_binding(it)
141                })
142                .collect(),
143            scalars: value
144                .scalars
145                .into_iter()
146                .map(|binding| (self.compile_storage_type(binding.ty), binding.count))
147                .collect(),
148            shared_arrays: self.shared_arrays.clone(),
149            shared_values: self.shared_values.clone(),
150            constant_arrays: self.const_arrays.clone(),
151            local_arrays: self.local_arrays.clone(),
152            has_metadata: self.metadata.static_len() > 0,
153            workgroup_size: value.cube_dim,
154            global_invocation_id: self.global_invocation_id || self.id,
155            local_invocation_index: self.local_invocation_index,
156            local_invocation_id: self.local_invocation_id,
157            num_workgroups: self.id
158                || self.num_workgroups
159                || self.num_workgroup_no_axis
160                || self.workgroup_id_no_axis,
161            workgroup_id: self.workgroup_id || self.workgroup_id_no_axis,
162            subgroup_size: self.subgroup_size,
163            subgroup_invocation_id: self.subgroup_invocation_id,
164            body,
165            extensions,
166            num_workgroups_no_axis: self.num_workgroup_no_axis,
167            workgroup_id_no_axis: self.workgroup_id_no_axis,
168            workgroup_size_no_axis: self.workgroup_size_no_axis,
169            subgroup_instructions_used: self.subgroup_instructions_used,
170            f16_used: self.f16_used,
171            kernel_name: value.options.kernel_name,
172        })
173    }
174
175    fn compile_type(&mut self, item: cube::Type) -> Item {
176        match item {
177            cube::Type::Scalar(ty) => wgsl::Item::Scalar(self.compile_storage_type(ty)),
178            cube::Type::Line(ty, size) => {
179                let elem = self.compile_storage_type(ty);
180                match size {
181                    2 => wgsl::Item::Vec2(elem),
182                    3 => wgsl::Item::Vec3(elem),
183                    4 => wgsl::Item::Vec4(elem),
184                    _ => panic!("Unsupported vectorizations scheme {:?}", item.line_size()),
185                }
186            }
187            cube::Type::Semantic(_) => unimplemented!("Can't compile semantic type"),
188        }
189    }
190
191    fn compile_storage_type(&mut self, ty: cube::StorageType) -> wgsl::Elem {
192        match ty {
193            cube::StorageType::Scalar(ty) => self.compile_elem(ty),
194            cube::StorageType::Atomic(ty) => match ty {
195                cube::ElemType::Float(i) => match i {
196                    cube::FloatKind::F32 => wgsl::Elem::AtomicF32,
197                    kind => panic!("atomic<{kind:?}> is not a valid WgpuElement"),
198                },
199                cube::ElemType::Int(i) => match i {
200                    cube::IntKind::I32 => wgsl::Elem::AtomicI32,
201                    kind => panic!("atomic<{kind:?}> is not a valid WgpuElement"),
202                },
203                cube::ElemType::UInt(kind) => match kind {
204                    cube::UIntKind::U32 => wgsl::Elem::AtomicU32,
205                    kind => panic!("{kind:?} is not a valid WgpuElement"),
206                },
207                other => panic!("{other:?} is not a valid WgpuElement"),
208            },
209            cube::StorageType::Packed(_, _) => {
210                unimplemented!("Packed types not yet supported in WGSL")
211            }
212            cube::StorageType::Opaque(ty) => match ty {
213                cube::OpaqueType::Barrier(_) => {
214                    unimplemented!("Barrier objects not supported in WGSL")
215                }
216            },
217        }
218    }
219
220    fn compile_elem(&mut self, value: cube::ElemType) -> wgsl::Elem {
221        match value {
222            cube::ElemType::Float(f) => match f {
223                cube::FloatKind::E2M1
224                | cube::FloatKind::E2M3
225                | cube::FloatKind::E3M2
226                | cube::FloatKind::E4M3
227                | cube::FloatKind::E5M2
228                | cube::FloatKind::UE8M0 => panic!("Minifloat is not a valid WgpuElement"),
229                cube::FloatKind::F16 => {
230                    self.f16_used = true;
231                    wgsl::Elem::F16
232                }
233                cube::FloatKind::BF16 => panic!("bf16 is not a valid WgpuElement"),
234                cube::FloatKind::TF32 => panic!("tf32 is not a valid WgpuElement"),
235                cube::FloatKind::Flex32 => wgsl::Elem::F32,
236                cube::FloatKind::F32 => wgsl::Elem::F32,
237                cube::FloatKind::F64 => wgsl::Elem::F64,
238            },
239            cube::ElemType::Int(i) => match i {
240                cube::IntKind::I32 => wgsl::Elem::I32,
241                cube::IntKind::I64 => wgsl::Elem::I64,
242                kind => panic!("{kind:?} is not a valid WgpuElement"),
243            },
244            cube::ElemType::UInt(kind) => match kind {
245                cube::UIntKind::U32 => wgsl::Elem::U32,
246                cube::UIntKind::U64 => wgsl::Elem::U64,
247                kind => panic!("{kind:?} is not a valid WgpuElement"),
248            },
249            cube::ElemType::Bool => wgsl::Elem::Bool,
250        }
251    }
252
253    fn ext_meta_pos(&self, var: &cube::Variable) -> u32 {
254        let pos = var.index().expect("Variable should have index");
255        self.ext_meta_pos[pos as usize]
256    }
257
258    pub(crate) fn compile_variable(&mut self, value: cube::Variable) -> wgsl::Variable {
259        let item = value.ty;
260        match value.kind {
261            cube::VariableKind::GlobalInputArray(id) => {
262                wgsl::Variable::GlobalInputArray(id, self.compile_type(item))
263            }
264            cube::VariableKind::GlobalScalar(id) => {
265                wgsl::Variable::GlobalScalar(id, self.compile_storage_type(item.storage_type()))
266            }
267            cube::VariableKind::LocalMut { id } | cube::VariableKind::Versioned { id, .. } => {
268                wgsl::Variable::LocalMut {
269                    id,
270                    item: self.compile_type(item),
271                }
272            }
273            cube::VariableKind::LocalConst { id } => wgsl::Variable::LocalConst {
274                id,
275                item: self.compile_type(item),
276            },
277            cube::VariableKind::GlobalOutputArray(id) => {
278                wgsl::Variable::GlobalOutputArray(id, self.compile_type(item))
279            }
280            cube::VariableKind::Constant(value) => {
281                wgsl::Variable::Constant(value, self.compile_type(item))
282            }
283            cube::VariableKind::SharedArray {
284                id,
285                length,
286                unroll_factor,
287                alignment,
288            } => {
289                let item = self.compile_type(item);
290                if !self.shared_arrays.iter().any(|s| s.index == id) {
291                    self.shared_arrays.push(SharedArray::new(
292                        id,
293                        item,
294                        (length * unroll_factor) as u32,
295                        alignment.map(|it| it as u32),
296                    ));
297                }
298                wgsl::Variable::SharedArray(id, item, length as u32)
299            }
300            cube::VariableKind::Shared { id } => {
301                let item = self.compile_type(item);
302                if !self.shared_values.iter().any(|s| s.index == id) {
303                    self.shared_values.push(SharedValue::new(id, item));
304                }
305                wgsl::Variable::SharedValue(id, item)
306            }
307            cube::VariableKind::ConstantArray { id, length, .. } => {
308                let item = self.compile_type(item);
309                wgsl::Variable::ConstantArray(id, item, length as u32)
310            }
311            cube::VariableKind::LocalArray {
312                id,
313                length,
314                unroll_factor,
315            } => {
316                let item = self.compile_type(item);
317                if !self.local_arrays.iter().any(|s| s.index == id) {
318                    self.local_arrays.push(LocalArray::new(
319                        id,
320                        item,
321                        (length * unroll_factor) as u32,
322                    ));
323                }
324                wgsl::Variable::LocalArray(id, item, length as u32)
325            }
326            cube::VariableKind::Builtin(builtin) => match builtin {
327                cube::Builtin::AbsolutePos => {
328                    self.id = true;
329                    wgsl::Variable::Id
330                }
331                cube::Builtin::UnitPos => {
332                    self.local_invocation_index = true;
333                    wgsl::Variable::LocalInvocationIndex
334                }
335                cube::Builtin::UnitPosX => {
336                    self.local_invocation_id = true;
337                    wgsl::Variable::LocalInvocationIdX
338                }
339                cube::Builtin::UnitPosY => {
340                    self.local_invocation_id = true;
341                    wgsl::Variable::LocalInvocationIdY
342                }
343                cube::Builtin::UnitPosZ => {
344                    self.local_invocation_id = true;
345                    wgsl::Variable::LocalInvocationIdZ
346                }
347                cube::Builtin::CubePosX => {
348                    self.workgroup_id = true;
349                    wgsl::Variable::WorkgroupIdX
350                }
351                cube::Builtin::CubePosY => {
352                    self.workgroup_id = true;
353                    wgsl::Variable::WorkgroupIdY
354                }
355                cube::Builtin::CubePosZ => {
356                    self.workgroup_id = true;
357                    wgsl::Variable::WorkgroupIdZ
358                }
359                cube::Builtin::CubePosCluster
360                | cube::Builtin::CubePosClusterX
361                | cube::Builtin::CubePosClusterY
362                | cube::Builtin::CubePosClusterZ => self.constant_var(1),
363                cube::Builtin::AbsolutePosX => {
364                    self.global_invocation_id = true;
365                    wgsl::Variable::GlobalInvocationIdX
366                }
367                cube::Builtin::AbsolutePosY => {
368                    self.global_invocation_id = true;
369                    wgsl::Variable::GlobalInvocationIdY
370                }
371                cube::Builtin::AbsolutePosZ => {
372                    self.global_invocation_id = true;
373                    wgsl::Variable::GlobalInvocationIdZ
374                }
375                cube::Builtin::CubeDimX => wgsl::Variable::WorkgroupSizeX,
376                cube::Builtin::CubeDimY => wgsl::Variable::WorkgroupSizeY,
377                cube::Builtin::CubeDimZ => wgsl::Variable::WorkgroupSizeZ,
378                cube::Builtin::CubeClusterDim
379                | cube::Builtin::CubeClusterDimX
380                | cube::Builtin::CubeClusterDimY
381                | cube::Builtin::CubeClusterDimZ => self.constant_var(1),
382                cube::Builtin::CubeCountX => {
383                    self.num_workgroups = true;
384                    wgsl::Variable::NumWorkgroupsX
385                }
386                cube::Builtin::CubeCountY => {
387                    self.num_workgroups = true;
388                    wgsl::Variable::NumWorkgroupsY
389                }
390                cube::Builtin::CubeCountZ => {
391                    self.num_workgroups = true;
392                    wgsl::Variable::NumWorkgroupsZ
393                }
394                cube::Builtin::CubePos => {
395                    self.workgroup_id_no_axis = true;
396                    wgsl::Variable::WorkgroupId
397                }
398                cube::Builtin::CubeDim => {
399                    self.workgroup_size_no_axis = true;
400                    wgsl::Variable::WorkgroupSize
401                }
402                cube::Builtin::CubeCount => {
403                    self.num_workgroup_no_axis = true;
404                    wgsl::Variable::NumWorkgroups
405                }
406                cube::Builtin::PlaneDim => {
407                    self.subgroup_size = true;
408                    wgsl::Variable::SubgroupSize
409                }
410                cube::Builtin::UnitPosPlane => {
411                    self.subgroup_invocation_id = true;
412                    wgsl::Variable::SubgroupInvocationId
413                }
414            },
415            cube::VariableKind::Matrix { .. } => {
416                panic!("Cooperative matrix-multiply and accumulate not supported.")
417            }
418            cube::VariableKind::Pipeline { .. } => {
419                panic!("Pipeline not supported.")
420            }
421            cube::VariableKind::BarrierToken { .. } => {
422                panic!("Barrier not supported.")
423            }
424            cube::VariableKind::TensorMapInput(_) => panic!("Tensor map not supported."),
425            cube::VariableKind::TensorMapOutput(_) => panic!("Tensor map not supported."),
426        }
427    }
428
429    fn constant_var(&mut self, value: u32) -> wgsl::Variable {
430        let var = cube::Variable::constant(value.into(), UIntKind::U32);
431        self.compile_variable(var)
432    }
433
434    fn compile_scope(&mut self, scope: &mut cube::Scope) -> Vec<wgsl::Instruction> {
435        let mut instructions = Vec::new();
436
437        let const_arrays = scope
438            .const_arrays
439            .drain(..)
440            .map(|(var, values)| ConstantArray {
441                index: var.index().unwrap(),
442                item: self.compile_type(var.ty),
443                size: values.len() as u32,
444                values: values
445                    .into_iter()
446                    .map(|val| self.compile_variable(val))
447                    .collect(),
448            })
449            .collect::<Vec<_>>();
450        self.const_arrays.extend(const_arrays);
451
452        let checked_io: Box<dyn Processor> = Box::new(CheckedIoProcessor::new(self.strategy));
453        let unroll = Box::new(UnrollProcessor::new(MAX_LINE_SIZE));
454        let saturating = Box::new(SaturatingArithmeticProcessor::new(true));
455        let processing = scope.process([&*unroll, &*checked_io, &*saturating]);
456
457        for mut var in processing.variables {
458            if var.ty.line_size() > MAX_LINE_SIZE {
459                var.ty = var.ty.line(MAX_LINE_SIZE);
460            }
461            instructions.push(wgsl::Instruction::DeclareVariable {
462                var: self.compile_variable(var),
463            });
464        }
465
466        processing
467            .instructions
468            .into_iter()
469            .for_each(|op| self.compile_operation(&mut instructions, op.operation, op.out, scope));
470
471        instructions
472    }
473
474    fn compile_operation(
475        &mut self,
476        instructions: &mut Vec<wgsl::Instruction>,
477        operation: cube::Operation,
478        out: Option<cube::Variable>,
479        scope: &mut cube::Scope,
480    ) {
481        match operation {
482            cube::Operation::Copy(variable) => instructions.push(wgsl::Instruction::Assign {
483                input: self.compile_variable(variable),
484                out: self.compile_variable(out.unwrap()),
485            }),
486            cube::Operation::Arithmetic(op) => {
487                self.compile_arithmetic(op, out, instructions, scope)
488            }
489            cube::Operation::Comparison(op) => self.compile_cmp(op, out, instructions),
490            cube::Operation::Bitwise(op) => self.compile_bitwise(op, out, instructions),
491            cube::Operation::Operator(op) => self.compile_operator(op, out, instructions),
492            cube::Operation::Atomic(op) => instructions.push(self.compile_atomic(op, out)),
493            cube::Operation::Metadata(op) => instructions.push(self.compile_metadata(op, out)),
494            cube::Operation::Branch(val) => self.compile_branch(instructions, val),
495            cube::Operation::Synchronization(val) => {
496                self.compile_synchronization(instructions, val)
497            }
498            cube::Operation::Plane(op) => self.compile_subgroup(instructions, op, out),
499            cube::Operation::CoopMma(_) => {
500                panic!("Cooperative matrix-multiply and accumulate isn't supported on wgpu.")
501            }
502            cube::Operation::NonSemantic(cube::NonSemantic::Comment { content }) => {
503                self.compile_comment(instructions, content)
504            }
505            cube::Operation::NonSemantic(_) => {}
506            cube::Operation::Barrier(_) => {
507                panic!("Barrier isn't supported on wgpu.")
508            }
509            cube::Operation::Tma(_) => panic!("TMA isn't supported on wgpu."),
510            cube::Operation::Marker(_) => {}
511        }
512    }
513
514    fn compile_subgroup(
515        &mut self,
516        instructions: &mut Vec<wgsl::Instruction>,
517        subgroup: cube::Plane,
518        out: Option<cube::Variable>,
519    ) {
520        self.subgroup_instructions_used = true;
521
522        let out = out.unwrap();
523        let op = match subgroup {
524            cube::Plane::Elect => Subgroup::Elect {
525                out: self.compile_variable(out),
526            },
527            cube::Plane::All(op) => Subgroup::All {
528                input: self.compile_variable(op.input),
529                out: self.compile_variable(out),
530            },
531            cube::Plane::Any(op) => Subgroup::Any {
532                input: self.compile_variable(op.input),
533                out: self.compile_variable(out),
534            },
535            cube::Plane::Ballot(op) => Subgroup::Ballot {
536                input: self.compile_variable(op.input),
537                out: self.compile_variable(out),
538            },
539
540            cube::Plane::Broadcast(op) => Subgroup::Broadcast {
541                lhs: self.compile_variable(op.lhs),
542                rhs: self.compile_variable(op.rhs),
543                out: self.compile_variable(out),
544            },
545
546            cube::Plane::Sum(op) => Subgroup::Sum {
547                input: self.compile_variable(op.input),
548                out: self.compile_variable(out),
549            },
550
551            cube::Plane::ExclusiveSum(op) => Subgroup::ExclusiveSum {
552                input: self.compile_variable(op.input),
553                out: self.compile_variable(out),
554            },
555            cube::Plane::InclusiveSum(op) => Subgroup::InclusiveSum {
556                input: self.compile_variable(op.input),
557                out: self.compile_variable(out),
558            },
559            cube::Plane::Prod(op) => Subgroup::Prod {
560                input: self.compile_variable(op.input),
561                out: self.compile_variable(out),
562            },
563            cube::Plane::ExclusiveProd(op) => Subgroup::ExclusiveProd {
564                input: self.compile_variable(op.input),
565                out: self.compile_variable(out),
566            },
567            cube::Plane::InclusiveProd(op) => Subgroup::InclusiveProd {
568                input: self.compile_variable(op.input),
569                out: self.compile_variable(out),
570            },
571            cube::Plane::Min(op) => Subgroup::Min {
572                input: self.compile_variable(op.input),
573                out: self.compile_variable(out),
574            },
575            cube::Plane::Max(op) => Subgroup::Max {
576                input: self.compile_variable(op.input),
577                out: self.compile_variable(out),
578            },
579            cube::Plane::Shuffle(op) => Subgroup::Shuffle {
580                lhs: self.compile_variable(op.lhs),
581                rhs: self.compile_variable(op.rhs),
582                out: self.compile_variable(out),
583            },
584            cube::Plane::ShuffleXor(op) => Subgroup::ShuffleXor {
585                lhs: self.compile_variable(op.lhs),
586                rhs: self.compile_variable(op.rhs),
587                out: self.compile_variable(out),
588            },
589            cube::Plane::ShuffleUp(op) => Subgroup::ShuffleUp {
590                lhs: self.compile_variable(op.lhs),
591                rhs: self.compile_variable(op.rhs),
592                out: self.compile_variable(out),
593            },
594            cube::Plane::ShuffleDown(op) => Subgroup::ShuffleDown {
595                lhs: self.compile_variable(op.lhs),
596                rhs: self.compile_variable(op.rhs),
597                out: self.compile_variable(out),
598            },
599        };
600
601        instructions.push(wgsl::Instruction::Subgroup(op));
602    }
603
604    fn compile_branch(&mut self, instructions: &mut Vec<wgsl::Instruction>, branch: cube::Branch) {
605        match branch {
606            cube::Branch::If(mut op) => instructions.push(wgsl::Instruction::If {
607                cond: self.compile_variable(op.cond),
608                instructions: self.compile_scope(&mut op.scope),
609            }),
610            cube::Branch::IfElse(mut op) => instructions.push(wgsl::Instruction::IfElse {
611                cond: self.compile_variable(op.cond),
612                instructions_if: self.compile_scope(&mut op.scope_if),
613                instructions_else: self.compile_scope(&mut op.scope_else),
614            }),
615            cube::Branch::Switch(mut op) => instructions.push(wgsl::Instruction::Switch {
616                value: self.compile_variable(op.value),
617                instructions_default: self.compile_scope(&mut op.scope_default),
618                cases: op
619                    .cases
620                    .into_iter()
621                    .map(|(val, mut scope)| {
622                        (self.compile_variable(val), self.compile_scope(&mut scope))
623                    })
624                    .collect(),
625            }),
626            cube::Branch::Return => instructions.push(wgsl::Instruction::Return),
627            cube::Branch::Break => instructions.push(wgsl::Instruction::Break),
628            cube::Branch::RangeLoop(mut range_loop) => {
629                instructions.push(wgsl::Instruction::RangeLoop {
630                    i: self.compile_variable(range_loop.i),
631                    start: self.compile_variable(range_loop.start),
632                    end: self.compile_variable(range_loop.end),
633                    step: range_loop.step.map(|it| self.compile_variable(it)),
634                    inclusive: range_loop.inclusive,
635                    instructions: self.compile_scope(&mut range_loop.scope),
636                })
637            }
638            cube::Branch::Loop(mut op) => instructions.push(wgsl::Instruction::Loop {
639                instructions: self.compile_scope(&mut op.scope),
640            }),
641        };
642    }
643
644    fn compile_synchronization(
645        &mut self,
646        instructions: &mut Vec<wgsl::Instruction>,
647        synchronization: cube::Synchronization,
648    ) {
649        match synchronization {
650            cube::Synchronization::SyncCube => {
651                instructions.push(wgsl::Instruction::WorkgroupBarrier)
652            }
653            cube::Synchronization::SyncPlane => {
654                panic!("Synchronization within a plane is not supported in WGSL")
655            }
656            cube::Synchronization::SyncStorage => {
657                instructions.push(wgsl::Instruction::StorageBarrier)
658            }
659            cube::Synchronization::SyncAsyncProxyShared => panic!("TMA is not supported in WGSL"),
660        };
661    }
662
663    fn compile_comment(&mut self, instructions: &mut Vec<wgsl::Instruction>, content: String) {
664        instructions.push(wgsl::Instruction::Comment { content })
665    }
666
667    fn compile_metadata(
668        &mut self,
669        metadata: cube::Metadata,
670        out: Option<cube::Variable>,
671    ) -> wgsl::Instruction {
672        let out = out.unwrap();
673        match metadata {
674            cube::Metadata::Rank { var } => {
675                let position = self.ext_meta_pos(&var);
676                let offset = self.metadata.rank_index(position);
677                wgsl::Instruction::Metadata {
678                    out: self.compile_variable(out),
679                    info_offset: self.compile_variable(offset.into()),
680                }
681            }
682            cube::Metadata::Stride { dim, var } => {
683                let position = self.ext_meta_pos(&var);
684                let offset = self.metadata.stride_offset_index(position);
685                wgsl::Instruction::ExtendedMeta {
686                    info_offset: self.compile_variable(offset.into()),
687                    dim: self.compile_variable(dim),
688                    out: self.compile_variable(out),
689                }
690            }
691            cube::Metadata::Shape { dim, var } => {
692                let position = self.ext_meta_pos(&var);
693                let offset = self.metadata.shape_offset_index(position);
694                wgsl::Instruction::ExtendedMeta {
695                    info_offset: self.compile_variable(offset.into()),
696                    dim: self.compile_variable(dim),
697                    out: self.compile_variable(out),
698                }
699            }
700            cube::Metadata::Length { var } => match var.kind {
701                cube::VariableKind::GlobalInputArray(id) => {
702                    let offset = self.metadata.len_index(id);
703                    wgsl::Instruction::Metadata {
704                        out: self.compile_variable(out),
705                        info_offset: self.compile_variable(offset.into()),
706                    }
707                }
708                cube::VariableKind::GlobalOutputArray(id) => {
709                    let offset = self.metadata.len_index(id);
710                    wgsl::Instruction::Metadata {
711                        out: self.compile_variable(out),
712                        info_offset: self.compile_variable(offset.into()),
713                    }
714                }
715                _ => wgsl::Instruction::Length {
716                    var: self.compile_variable(var),
717                    out: self.compile_variable(out),
718                },
719            },
720            cube::Metadata::BufferLength { var } => match var.kind {
721                cube::VariableKind::GlobalInputArray(id) => {
722                    let offset = self.metadata.buffer_len_index(id);
723                    wgsl::Instruction::Metadata {
724                        out: self.compile_variable(out),
725                        info_offset: self.compile_variable(offset.into()),
726                    }
727                }
728                cube::VariableKind::GlobalOutputArray(id) => {
729                    let offset = self.metadata.buffer_len_index(id);
730                    wgsl::Instruction::Metadata {
731                        out: self.compile_variable(out),
732                        info_offset: self.compile_variable(offset.into()),
733                    }
734                }
735                _ => wgsl::Instruction::Length {
736                    var: self.compile_variable(var),
737                    out: self.compile_variable(out),
738                },
739            },
740        }
741    }
742
743    fn compile_arithmetic(
744        &mut self,
745        value: cube::Arithmetic,
746        out: Option<cube::Variable>,
747        instructions: &mut Vec<wgsl::Instruction>,
748        scope: &mut Scope,
749    ) {
750        let out = out.unwrap();
751        match value {
752            cube::Arithmetic::Max(op) => instructions.push(wgsl::Instruction::Max {
753                lhs: self.compile_variable(op.lhs),
754                rhs: self.compile_variable(op.rhs),
755                out: self.compile_variable(out),
756            }),
757            cube::Arithmetic::Min(op) => instructions.push(wgsl::Instruction::Min {
758                lhs: self.compile_variable(op.lhs),
759                rhs: self.compile_variable(op.rhs),
760                out: self.compile_variable(out),
761            }),
762            cube::Arithmetic::Add(op) => instructions.push(wgsl::Instruction::Add {
763                lhs: self.compile_variable(op.lhs),
764                rhs: self.compile_variable(op.rhs),
765                out: self.compile_variable(out),
766            }),
767            cube::Arithmetic::SaturatingAdd(_) => {
768                unreachable!("Saturating add should be removed by processor");
769            }
770            cube::Arithmetic::Fma(op) => instructions.push(wgsl::Instruction::Fma {
771                a: self.compile_variable(op.a),
772                b: self.compile_variable(op.b),
773                c: self.compile_variable(op.c),
774                out: self.compile_variable(out),
775            }),
776            cube::Arithmetic::Modulo(op) => instructions.push(wgsl::Instruction::Modulo {
777                lhs: self.compile_variable(op.lhs),
778                rhs: self.compile_variable(op.rhs),
779                out: self.compile_variable(out),
780            }),
781            cube::Arithmetic::Sub(op) => instructions.push(wgsl::Instruction::Sub {
782                lhs: self.compile_variable(op.lhs),
783                rhs: self.compile_variable(op.rhs),
784                out: self.compile_variable(out),
785            }),
786            cube::Arithmetic::SaturatingSub(_) => {
787                unreachable!("Saturating sub should be removed by processor");
788            }
789            cube::Arithmetic::Mul(op) => instructions.push(wgsl::Instruction::Mul {
790                lhs: self.compile_variable(op.lhs),
791                rhs: self.compile_variable(op.rhs),
792                out: self.compile_variable(out),
793            }),
794            cube::Arithmetic::Div(op) => instructions.push(wgsl::Instruction::Div {
795                lhs: self.compile_variable(op.lhs),
796                rhs: self.compile_variable(op.rhs),
797                out: self.compile_variable(out),
798            }),
799            cube::Arithmetic::Abs(op) => instructions.push(wgsl::Instruction::Abs {
800                input: self.compile_variable(op.input),
801                out: self.compile_variable(out),
802            }),
803            cube::Arithmetic::Exp(op) => instructions.push(wgsl::Instruction::Exp {
804                input: self.compile_variable(op.input),
805                out: self.compile_variable(out),
806            }),
807            cube::Arithmetic::Log(op) => instructions.push(wgsl::Instruction::Log {
808                input: self.compile_variable(op.input),
809                out: self.compile_variable(out),
810            }),
811            cube::Arithmetic::Log1p(op) => instructions.push(wgsl::Instruction::Log1p {
812                input: self.compile_variable(op.input),
813                out: self.compile_variable(out),
814            }),
815            cube::Arithmetic::Cos(op) => instructions.push(wgsl::Instruction::Cos {
816                input: self.compile_variable(op.input),
817                out: self.compile_variable(out),
818            }),
819            cube::Arithmetic::Sin(op) => instructions.push(wgsl::Instruction::Sin {
820                input: self.compile_variable(op.input),
821                out: self.compile_variable(out),
822            }),
823            cube::Arithmetic::Tan(op) => instructions.push(wgsl::Instruction::Tan {
824                input: self.compile_variable(op.input),
825                out: self.compile_variable(out),
826            }),
827            cube::Arithmetic::Tanh(op) => instructions.push(wgsl::Instruction::Tanh {
828                input: self.compile_variable(op.input),
829                out: self.compile_variable(out),
830            }),
831            cube::Arithmetic::Sinh(op) => instructions.push(wgsl::Instruction::Sinh {
832                input: self.compile_variable(op.input),
833                out: self.compile_variable(out),
834            }),
835            cube::Arithmetic::Cosh(op) => instructions.push(wgsl::Instruction::Cosh {
836                input: self.compile_variable(op.input),
837                out: self.compile_variable(out),
838            }),
839            cube::Arithmetic::ArcCos(op) => instructions.push(wgsl::Instruction::ArcCos {
840                input: self.compile_variable(op.input),
841                out: self.compile_variable(out),
842            }),
843            cube::Arithmetic::ArcSin(op) => instructions.push(wgsl::Instruction::ArcSin {
844                input: self.compile_variable(op.input),
845                out: self.compile_variable(out),
846            }),
847            cube::Arithmetic::ArcTan(op) => instructions.push(wgsl::Instruction::ArcTan {
848                input: self.compile_variable(op.input),
849                out: self.compile_variable(out),
850            }),
851            cube::Arithmetic::ArcSinh(op) => instructions.push(wgsl::Instruction::ArcSinh {
852                input: self.compile_variable(op.input),
853                out: self.compile_variable(out),
854            }),
855            cube::Arithmetic::ArcCosh(op) => instructions.push(wgsl::Instruction::ArcCosh {
856                input: self.compile_variable(op.input),
857                out: self.compile_variable(out),
858            }),
859            cube::Arithmetic::ArcTanh(op) => instructions.push(wgsl::Instruction::ArcTanh {
860                input: self.compile_variable(op.input),
861                out: self.compile_variable(out),
862            }),
863            cube::Arithmetic::Degrees(op) => instructions.push(wgsl::Instruction::Degrees {
864                input: self.compile_variable(op.input),
865                out: self.compile_variable(out),
866            }),
867            cube::Arithmetic::Radians(op) => instructions.push(wgsl::Instruction::Radians {
868                input: self.compile_variable(op.input),
869                out: self.compile_variable(out),
870            }),
871            cube::Arithmetic::ArcTan2(op) => instructions.push(wgsl::Instruction::ArcTan2 {
872                lhs: self.compile_variable(op.lhs),
873                rhs: self.compile_variable(op.rhs),
874                out: self.compile_variable(out),
875            }),
876            // No powi in WGSL
877            cube::Arithmetic::Powf(op) | cube::Arithmetic::Powi(op) => {
878                instructions.push(wgsl::Instruction::Powf {
879                    lhs: self.compile_variable(op.lhs),
880                    rhs: self.compile_variable(op.rhs),
881                    out: self.compile_variable(out),
882                })
883            }
884            cube::Arithmetic::Hypot(op) => {
885                let mut scope = scope.child();
886                expand_hypot(&mut scope, op.lhs, op.rhs, out);
887                instructions.extend(self.compile_scope(&mut scope));
888            }
889            cube::Arithmetic::Rhypot(op) => {
890                let mut scope = scope.child();
891                expand_rhypot(&mut scope, op.lhs, op.rhs, out);
892                instructions.extend(self.compile_scope(&mut scope));
893            }
894
895            cube::Arithmetic::Sqrt(op) => instructions.push(wgsl::Instruction::Sqrt {
896                input: self.compile_variable(op.input),
897                out: self.compile_variable(out),
898            }),
899            cube::Arithmetic::InverseSqrt(op) => {
900                instructions.push(wgsl::Instruction::InverseSqrt {
901                    input: self.compile_variable(op.input),
902                    out: self.compile_variable(out),
903                })
904            }
905            cube::Arithmetic::Round(op) => instructions.push(wgsl::Instruction::Round {
906                input: self.compile_variable(op.input),
907                out: self.compile_variable(out),
908            }),
909            cube::Arithmetic::Floor(op) => instructions.push(wgsl::Instruction::Floor {
910                input: self.compile_variable(op.input),
911                out: self.compile_variable(out),
912            }),
913            cube::Arithmetic::Ceil(op) => instructions.push(wgsl::Instruction::Ceil {
914                input: self.compile_variable(op.input),
915                out: self.compile_variable(out),
916            }),
917            cube::Arithmetic::Trunc(op) => instructions.push(wgsl::Instruction::Trunc {
918                input: self.compile_variable(op.input),
919                out: self.compile_variable(out),
920            }),
921            cube::Arithmetic::Erf(op) => {
922                let mut scope = scope.child();
923                expand_erf(&mut scope, op.input, out);
924                instructions.extend(self.compile_scope(&mut scope));
925            }
926            cube::Arithmetic::MulHi(op) => {
927                let mut scope = scope.child();
928                match self.compilation_options.supports_u64 {
929                    true => expand_himul_64(&mut scope, op.lhs, op.rhs, out),
930                    false => expand_himul_sim(&mut scope, op.lhs, op.rhs, out),
931                }
932                instructions.extend(self.compile_scope(&mut scope));
933            }
934            cube::Arithmetic::Recip(op) => instructions.push(wgsl::Instruction::Recip {
935                input: self.compile_variable(op.input),
936                out: self.compile_variable(out),
937            }),
938            cube::Arithmetic::Clamp(op) => instructions.push(wgsl::Instruction::Clamp {
939                input: self.compile_variable(op.input),
940                min_value: self.compile_variable(op.min_value),
941                max_value: self.compile_variable(op.max_value),
942                out: self.compile_variable(out),
943            }),
944            cube::Arithmetic::Remainder(op) => instructions.push(wgsl::Instruction::Remainder {
945                lhs: self.compile_variable(op.lhs),
946                rhs: self.compile_variable(op.rhs),
947                out: self.compile_variable(out),
948            }),
949            cube::Arithmetic::Neg(op) => instructions.push(wgsl::Instruction::Negate {
950                input: self.compile_variable(op.input),
951                out: self.compile_variable(out),
952            }),
953            cube::Arithmetic::Magnitude(op) => instructions.push(wgsl::Instruction::Magnitude {
954                input: self.compile_variable(op.input),
955                out: self.compile_variable(out),
956            }),
957            cube::Arithmetic::Normalize(op) => instructions.push(wgsl::Instruction::Normalize {
958                input: self.compile_variable(op.input),
959                out: self.compile_variable(out),
960            }),
961            cube::Arithmetic::Dot(op) => instructions.push(wgsl::Instruction::Dot {
962                lhs: self.compile_variable(op.lhs),
963                rhs: self.compile_variable(op.rhs),
964                out: self.compile_variable(out),
965            }),
966        }
967    }
968
969    fn compile_cmp(
970        &mut self,
971        value: cube::Comparison,
972        out: Option<cube::Variable>,
973        instructions: &mut Vec<wgsl::Instruction>,
974    ) {
975        let out = out.unwrap();
976        match value {
977            cube::Comparison::Equal(op) => instructions.push(wgsl::Instruction::Equal {
978                lhs: self.compile_variable(op.lhs),
979                rhs: self.compile_variable(op.rhs),
980                out: self.compile_variable(out),
981            }),
982            cube::Comparison::Lower(op) => instructions.push(wgsl::Instruction::Lower {
983                lhs: self.compile_variable(op.lhs),
984                rhs: self.compile_variable(op.rhs),
985                out: self.compile_variable(out),
986            }),
987            cube::Comparison::Greater(op) => instructions.push(wgsl::Instruction::Greater {
988                lhs: self.compile_variable(op.lhs),
989                rhs: self.compile_variable(op.rhs),
990                out: self.compile_variable(out),
991            }),
992            cube::Comparison::LowerEqual(op) => instructions.push(wgsl::Instruction::LowerEqual {
993                lhs: self.compile_variable(op.lhs),
994                rhs: self.compile_variable(op.rhs),
995                out: self.compile_variable(out),
996            }),
997            cube::Comparison::GreaterEqual(op) => {
998                instructions.push(wgsl::Instruction::GreaterEqual {
999                    lhs: self.compile_variable(op.lhs),
1000                    rhs: self.compile_variable(op.rhs),
1001                    out: self.compile_variable(out),
1002                })
1003            }
1004            cube::Comparison::NotEqual(op) => instructions.push(wgsl::Instruction::NotEqual {
1005                lhs: self.compile_variable(op.lhs),
1006                rhs: self.compile_variable(op.rhs),
1007                out: self.compile_variable(out),
1008            }),
1009            cube::Comparison::IsNan(op) => instructions.push(wgsl::Instruction::IsNan {
1010                input: self.compile_variable(op.input),
1011                out: self.compile_variable(out),
1012            }),
1013            cube::Comparison::IsInf(op) => instructions.push(wgsl::Instruction::IsInf {
1014                input: self.compile_variable(op.input),
1015                out: self.compile_variable(out),
1016            }),
1017        }
1018    }
1019
1020    fn compile_bitwise(
1021        &mut self,
1022        value: cube::Bitwise,
1023        out: Option<cube::Variable>,
1024        instructions: &mut Vec<wgsl::Instruction>,
1025    ) {
1026        let out = out.unwrap();
1027        match value {
1028            cube::Bitwise::BitwiseOr(op) => instructions.push(wgsl::Instruction::BitwiseOr {
1029                lhs: self.compile_variable(op.lhs),
1030                rhs: self.compile_variable(op.rhs),
1031                out: self.compile_variable(out),
1032            }),
1033            cube::Bitwise::BitwiseAnd(op) => instructions.push(wgsl::Instruction::BitwiseAnd {
1034                lhs: self.compile_variable(op.lhs),
1035                rhs: self.compile_variable(op.rhs),
1036                out: self.compile_variable(out),
1037            }),
1038            cube::Bitwise::BitwiseXor(op) => instructions.push(wgsl::Instruction::BitwiseXor {
1039                lhs: self.compile_variable(op.lhs),
1040                rhs: self.compile_variable(op.rhs),
1041                out: self.compile_variable(out),
1042            }),
1043            cube::Bitwise::CountOnes(op) => instructions.push(wgsl::Instruction::CountBits {
1044                input: self.compile_variable(op.input),
1045                out: self.compile_variable(out),
1046            }),
1047            cube::Bitwise::ReverseBits(op) => instructions.push(wgsl::Instruction::ReverseBits {
1048                input: self.compile_variable(op.input),
1049                out: self.compile_variable(out),
1050            }),
1051            cube::Bitwise::ShiftLeft(op) => instructions.push(wgsl::Instruction::ShiftLeft {
1052                lhs: self.compile_variable(op.lhs),
1053                rhs: self.compile_variable(op.rhs),
1054                out: self.compile_variable(out),
1055            }),
1056            cube::Bitwise::ShiftRight(op) => instructions.push(wgsl::Instruction::ShiftRight {
1057                lhs: self.compile_variable(op.lhs),
1058                rhs: self.compile_variable(op.rhs),
1059                out: self.compile_variable(out),
1060            }),
1061            cube::Bitwise::BitwiseNot(op) => instructions.push(wgsl::Instruction::BitwiseNot {
1062                input: self.compile_variable(op.input),
1063                out: self.compile_variable(out),
1064            }),
1065            cube::Bitwise::LeadingZeros(op) => instructions.push(wgsl::Instruction::LeadingZeros {
1066                input: self.compile_variable(op.input),
1067                out: self.compile_variable(out),
1068            }),
1069            cube::Bitwise::FindFirstSet(op) => instructions.push(wgsl::Instruction::FindFirstSet {
1070                input: self.compile_variable(op.input),
1071                out: self.compile_variable(out),
1072            }),
1073        }
1074    }
1075
1076    fn compile_operator(
1077        &mut self,
1078        value: cube::Operator,
1079        out: Option<cube::Variable>,
1080        instructions: &mut Vec<wgsl::Instruction>,
1081    ) {
1082        let out = out.unwrap();
1083        match value {
1084            cube::Operator::Cast(op) => instructions.push(wgsl::Instruction::Assign {
1085                input: self.compile_variable(op.input),
1086                out: self.compile_variable(out),
1087            }),
1088            cube::Operator::Index(op) | cube::Operator::UncheckedIndex(op) => {
1089                instructions.push(wgsl::Instruction::Index {
1090                    lhs: self.compile_variable(op.list),
1091                    rhs: self.compile_variable(op.index),
1092                    out: self.compile_variable(out),
1093                });
1094            }
1095            cube::Operator::IndexAssign(op) | cube::Operator::UncheckedIndexAssign(op) => {
1096                instructions.push(wgsl::Instruction::IndexAssign {
1097                    index: self.compile_variable(op.index),
1098                    rhs: self.compile_variable(op.value),
1099                    out: self.compile_variable(out),
1100                })
1101            }
1102            cube::Operator::And(op) => instructions.push(wgsl::Instruction::And {
1103                lhs: self.compile_variable(op.lhs),
1104                rhs: self.compile_variable(op.rhs),
1105                out: self.compile_variable(out),
1106            }),
1107            cube::Operator::Or(op) => instructions.push(wgsl::Instruction::Or {
1108                lhs: self.compile_variable(op.lhs),
1109                rhs: self.compile_variable(op.rhs),
1110                out: self.compile_variable(out),
1111            }),
1112            cube::Operator::Not(op) => instructions.push(wgsl::Instruction::Not {
1113                input: self.compile_variable(op.input),
1114                out: self.compile_variable(out),
1115            }),
1116            cube::Operator::Reinterpret(op) => instructions.push(wgsl::Instruction::Bitcast {
1117                input: self.compile_variable(op.input),
1118                out: self.compile_variable(out),
1119            }),
1120            cube::Operator::InitLine(op) => instructions.push(wgsl::Instruction::VecInit {
1121                inputs: op
1122                    .inputs
1123                    .into_iter()
1124                    .map(|var| self.compile_variable(var))
1125                    .collect(),
1126                out: self.compile_variable(out),
1127            }),
1128            cube::Operator::CopyMemory(op) => instructions.push(wgsl::Instruction::Copy {
1129                input: self.compile_variable(op.input),
1130                in_index: self.compile_variable(op.in_index),
1131                out: self.compile_variable(out),
1132                out_index: self.compile_variable(op.out_index),
1133            }),
1134            cube::Operator::CopyMemoryBulk(op) => instructions.push(wgsl::Instruction::CopyBulk {
1135                input: self.compile_variable(op.input),
1136                in_index: self.compile_variable(op.in_index),
1137                out: self.compile_variable(out),
1138                out_index: self.compile_variable(op.out_index),
1139                len: op.len as u32,
1140            }),
1141            cube::Operator::Select(op) => instructions.push(wgsl::Instruction::Select {
1142                cond: self.compile_variable(op.cond),
1143                then: self.compile_variable(op.then),
1144                or_else: self.compile_variable(op.or_else),
1145                out: self.compile_variable(out),
1146            }),
1147        }
1148    }
1149
1150    fn compile_atomic(
1151        &mut self,
1152        atomic: cube::AtomicOp,
1153        out: Option<cube::Variable>,
1154    ) -> wgsl::Instruction {
1155        let out = out.unwrap();
1156        match atomic {
1157            cube::AtomicOp::Add(op) => wgsl::Instruction::AtomicAdd {
1158                lhs: self.compile_variable(op.lhs),
1159                rhs: self.compile_variable(op.rhs),
1160                out: self.compile_variable(out),
1161            },
1162            cube::AtomicOp::Sub(op) => wgsl::Instruction::AtomicSub {
1163                lhs: self.compile_variable(op.lhs),
1164                rhs: self.compile_variable(op.rhs),
1165                out: self.compile_variable(out),
1166            },
1167            cube::AtomicOp::Max(op) => wgsl::Instruction::AtomicMax {
1168                lhs: self.compile_variable(op.lhs),
1169                rhs: self.compile_variable(op.rhs),
1170                out: self.compile_variable(out),
1171            },
1172            cube::AtomicOp::Min(op) => wgsl::Instruction::AtomicMin {
1173                lhs: self.compile_variable(op.lhs),
1174                rhs: self.compile_variable(op.rhs),
1175                out: self.compile_variable(out),
1176            },
1177            cube::AtomicOp::And(op) => wgsl::Instruction::AtomicAnd {
1178                lhs: self.compile_variable(op.lhs),
1179                rhs: self.compile_variable(op.rhs),
1180                out: self.compile_variable(out),
1181            },
1182            cube::AtomicOp::Or(op) => wgsl::Instruction::AtomicOr {
1183                lhs: self.compile_variable(op.lhs),
1184                rhs: self.compile_variable(op.rhs),
1185                out: self.compile_variable(out),
1186            },
1187            cube::AtomicOp::Xor(op) => wgsl::Instruction::AtomicXor {
1188                lhs: self.compile_variable(op.lhs),
1189                rhs: self.compile_variable(op.rhs),
1190                out: self.compile_variable(out),
1191            },
1192            cube::AtomicOp::Load(op) => wgsl::Instruction::AtomicLoad {
1193                input: self.compile_variable(op.input),
1194                out: self.compile_variable(out),
1195            },
1196            cube::AtomicOp::Store(op) => wgsl::Instruction::AtomicStore {
1197                input: self.compile_variable(op.input),
1198                out: self.compile_variable(out),
1199            },
1200            cube::AtomicOp::Swap(op) => wgsl::Instruction::AtomicSwap {
1201                lhs: self.compile_variable(op.lhs),
1202                rhs: self.compile_variable(op.rhs),
1203                out: self.compile_variable(out),
1204            },
1205            cube::AtomicOp::CompareAndSwap(op) => wgsl::Instruction::AtomicCompareExchangeWeak {
1206                lhs: self.compile_variable(op.input),
1207                cmp: self.compile_variable(op.cmp),
1208                value: self.compile_variable(op.val),
1209                out: self.compile_variable(out),
1210            },
1211        }
1212    }
1213
1214    fn compile_location(value: kernel::Location) -> wgsl::Location {
1215        match value {
1216            kernel::Location::Storage => wgsl::Location::Storage,
1217            kernel::Location::Cube => wgsl::Location::Workgroup,
1218        }
1219    }
1220
1221    fn compile_binding(&mut self, value: kernel::Binding) -> wgsl::Binding {
1222        wgsl::Binding {
1223            id: value.id,
1224            visibility: value.visibility,
1225            location: Self::compile_location(value.location),
1226            item: self.compile_type(value.ty),
1227            size: value.size,
1228        }
1229    }
1230}
1231
1232fn register_extensions(instructions: &[wgsl::Instruction]) -> Vec<wgsl::Extension> {
1233    let mut extensions = Vec::new();
1234
1235    let mut register_extension = |extension: wgsl::Extension| {
1236        if !extensions.contains(&extension) {
1237            extensions.push(extension);
1238        }
1239    };
1240
1241    // Since not all instructions are native to WGSL, we need to add the custom ones.
1242    for instruction in instructions {
1243        match instruction {
1244            wgsl::Instruction::Powf { lhs: _, rhs, out } => {
1245                register_extension(wgsl::Extension::PowfPrimitive(out.elem()));
1246                register_extension(wgsl::powf_extension(rhs, out));
1247            }
1248            #[cfg(target_os = "macos")]
1249            wgsl::Instruction::Tanh { input, out: _ } => {
1250                register_extension(wgsl::Extension::SafeTanhPrimitive(input.elem()));
1251                register_extension(wgsl::Extension::SafeTanh(input.item()));
1252            }
1253            wgsl::Instruction::IsNan { input, out } => {
1254                register_extension(wgsl::Extension::IsNanPrimitive(input.elem()));
1255                register_extension(wgsl::Extension::IsNan(input.item(), out.item()));
1256            }
1257            wgsl::Instruction::IsInf { input, out } => {
1258                register_extension(wgsl::Extension::IsInfPrimitive(input.elem()));
1259                register_extension(wgsl::Extension::IsInf(input.item(), out.item()));
1260            }
1261            wgsl::Instruction::If { instructions, .. } => {
1262                for extension in register_extensions(instructions) {
1263                    register_extension(extension);
1264                }
1265            }
1266            wgsl::Instruction::IfElse {
1267                instructions_if,
1268                instructions_else,
1269                ..
1270            } => {
1271                for extension in register_extensions(instructions_if) {
1272                    register_extension(extension);
1273                }
1274                for extension in register_extensions(instructions_else) {
1275                    register_extension(extension);
1276                }
1277            }
1278            wgsl::Instruction::Loop { instructions } => {
1279                for extension in register_extensions(instructions) {
1280                    register_extension(extension);
1281                }
1282            }
1283            wgsl::Instruction::RangeLoop { instructions, .. } => {
1284                for extension in register_extensions(instructions) {
1285                    register_extension(extension);
1286                }
1287            }
1288            _ => {}
1289        }
1290    }
1291
1292    extensions
1293}