cubecl_wgpu/compiler/wgsl/
compiler.rs

1use super::Subgroup;
2use super::{ConstantArray, shader::ComputeShader};
3use super::{Item, LocalArray, SharedMemory};
4use crate::compiler::wgsl;
5
6use cubecl_common::ExecutionMode;
7use cubecl_core::ir::{ConstantScalarValue, Processor, UIntKind};
8use cubecl_core::post_processing::checked_io::CheckedIoProcessor;
9use cubecl_core::prelude::*;
10use cubecl_core::{
11    Metadata, WgpuCompilationOptions, compute,
12    ir::{self as cube, Scope},
13    prelude::expand_erf,
14};
15
/// Wgsl Compiler.
///
/// Carries the state accumulated while a kernel definition is lowered to a WGSL
/// [`ComputeShader`]: the metadata layout, the shared/constant/local arrays that were
/// encountered, and one flag per builtin or feature the generated code relies on.
#[derive(Clone, Default)]
pub struct WgslCompiler {
    /// Metadata layout, rebuilt in `compile_shader` from the number of buffers and the number
    /// of buffers that have extended metadata.
    metadata: Metadata,
    /// One entry per buffer, in buffer order: how many earlier buffers have extended metadata.
    ext_meta_pos: Vec<u32>,
    /// Set when `Builtin::UnitPos` is compiled.
    local_invocation_index: bool,
    /// Set when `Builtin::UnitPosX`, `UnitPosY` or `UnitPosZ` is compiled.
    local_invocation_id: bool,
    /// Set when `Builtin::AbsolutePosX`, `AbsolutePosY` or `AbsolutePosZ` is compiled; read by
    /// `compile_shader`, which also turns the shader-level flag on whenever `id` is set.
    global_invocation_id: bool,
    /// Set when `Builtin::CubePosX`, `CubePosY` or `CubePosZ` is compiled.
    workgroup_id: bool,
    /// Set when `Builtin::PlaneDim` is compiled.
    subgroup_size: bool,
    /// Set when `Builtin::UnitPosPlane` is compiled.
    subgroup_invocation_id: bool,
    /// Set when `Builtin::AbsolutePos` is compiled; `compile_shader` then also requests the
    /// global invocation id and the number of workgroups.
    id: bool,
    /// Set when `Builtin::CubeCountX`, `CubeCountY` or `CubeCountZ` is compiled.
    num_workgroups: bool,
    /// Set when `Builtin::CubePos` is compiled; `compile_shader` then also requests the
    /// workgroup id and the number of workgroups.
    workgroup_id_no_axis: bool,
    /// Set when `Builtin::CubeDim` is compiled.
    workgroup_size_no_axis: bool,
    /// Set when `Builtin::CubeCount` is compiled; `compile_shader` then also requests the
    /// number of workgroups.
    num_workgroup_no_axis: bool,
    /// Shared memories seen by `compile_variable`, registered once per id.
    shared_memories: Vec<SharedMemory>,
    /// Constant arrays collected from every scope passed to `compile_scope`.
    const_arrays: Vec<ConstantArray>,
    /// Local arrays seen by `compile_variable`, registered once per id.
    local_arrays: Vec<LocalArray>,
    /// Copy of the options handed to `Compiler::compile`.
    #[allow(dead_code)]
    compilation_options: WgpuCompilationOptions,
    /// Execution mode of the current compilation; forwarded to the checked-IO processor.
    strategy: ExecutionMode,
    /// Set as soon as any plane (subgroup) operation is compiled.
    subgroup_instructions_used: bool,
    /// Set as soon as an `f16` element is compiled.
    f16_used: bool,
}
42
43impl core::fmt::Debug for WgslCompiler {
44    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45        f.write_str("WgslCompiler")
46    }
47}
48
49impl cubecl_core::Compiler for WgslCompiler {
50    type Representation = ComputeShader;
51    type CompilationOptions = WgpuCompilationOptions;
52
53    fn compile(
54        &mut self,
55        shader: compute::KernelDefinition,
56        compilation_options: &Self::CompilationOptions,
57        mode: ExecutionMode,
58    ) -> Self::Representation {
59        self.compilation_options = compilation_options.clone();
60        self.compile_shader(shader, mode)
61    }
62
63    fn elem_size(&self, elem: cube::Elem) -> usize {
64        elem.size()
65    }
66
67    fn extension(&self) -> &'static str {
68        "wgsl"
69    }
70}
71
72impl WgslCompiler {
73    fn compile_shader(
74        &mut self,
75        mut value: compute::KernelDefinition,
76        mode: ExecutionMode,
77    ) -> wgsl::ComputeShader {
78        self.strategy = mode;
79
80        let num_meta = value.buffers.len();
81
82        self.ext_meta_pos = Vec::new();
83        let mut num_ext = 0;
84
85        for binding in value.buffers.iter() {
86            self.ext_meta_pos.push(num_ext);
87            if binding.has_extended_meta {
88                num_ext += 1;
89            }
90        }
91
92        self.metadata = Metadata::new(num_meta as u32, num_ext);
93
94        let instructions = self.compile_scope(&mut value.body);
95        let extensions = register_extensions(&instructions);
96        let body = wgsl::Body {
97            instructions,
98            id: self.id,
99        };
100
101        wgsl::ComputeShader {
102            buffers: value
103                .buffers
104                .into_iter()
105                .map(|it| self.compile_binding(it))
106                .collect(),
107            scalars: value
108                .scalars
109                .into_iter()
110                .map(|binding| (self.compile_elem(binding.elem), binding.count))
111                .collect(),
112            shared_memories: self.shared_memories.clone(),
113            constant_arrays: self.const_arrays.clone(),
114            local_arrays: self.local_arrays.clone(),
115            has_metadata: self.metadata.static_len() > 0,
116            workgroup_size: value.cube_dim,
117            global_invocation_id: self.global_invocation_id || self.id,
118            local_invocation_index: self.local_invocation_index,
119            local_invocation_id: self.local_invocation_id,
120            num_workgroups: self.id
121                || self.num_workgroups
122                || self.num_workgroup_no_axis
123                || self.workgroup_id_no_axis,
124            workgroup_id: self.workgroup_id || self.workgroup_id_no_axis,
125            subgroup_size: self.subgroup_size,
126            subgroup_invocation_id: self.subgroup_invocation_id,
127            body,
128            extensions,
129            num_workgroups_no_axis: self.num_workgroup_no_axis,
130            workgroup_id_no_axis: self.workgroup_id_no_axis,
131            workgroup_size_no_axis: self.workgroup_size_no_axis,
132            subgroup_instructions_used: self.subgroup_instructions_used,
133            f16_used: self.f16_used,
134            kernel_name: value.options.kernel_name,
135        }
136    }
137
138    fn compile_item(&mut self, item: cube::Item) -> Item {
139        let elem = self.compile_elem(item.elem);
140        match item.vectorization.map(|it| it.get()).unwrap_or(1) {
141            1 => wgsl::Item::Scalar(elem),
142            2 => wgsl::Item::Vec2(elem),
143            3 => wgsl::Item::Vec3(elem),
144            4 => wgsl::Item::Vec4(elem),
145            _ => panic!("Unsupported vectorizations scheme {:?}", item.vectorization),
146        }
147    }
148
149    fn compile_elem(&mut self, value: cube::Elem) -> wgsl::Elem {
150        match value {
151            cube::Elem::Float(f) => match f {
152                cube::FloatKind::E2M1
153                | cube::FloatKind::E2M3
154                | cube::FloatKind::E3M2
155                | cube::FloatKind::E4M3
156                | cube::FloatKind::E5M2
157                | cube::FloatKind::UE8M0 => panic!("Minifloat is not a valid WgpuElement"),
158                cube::FloatKind::F16 => {
159                    self.f16_used = true;
160                    wgsl::Elem::F16
161                }
162                cube::FloatKind::BF16 => panic!("bf16 is not a valid WgpuElement"),
163                cube::FloatKind::TF32 => panic!("tf32 is not a valid WgpuElement"),
164                cube::FloatKind::Flex32 => wgsl::Elem::F32,
165                cube::FloatKind::F32 => wgsl::Elem::F32,
166                cube::FloatKind::F64 => wgsl::Elem::F64,
167            },
168            cube::Elem::Int(i) => match i {
169                cube::IntKind::I32 => wgsl::Elem::I32,
170                cube::IntKind::I64 => wgsl::Elem::I64,
171                kind => panic!("{kind:?} is not a valid WgpuElement"),
172            },
173            cube::Elem::UInt(kind) => match kind {
174                cube::UIntKind::U32 => wgsl::Elem::U32,
175                cube::UIntKind::U64 => wgsl::Elem::U64,
176                kind => panic!("{kind:?} is not a valid WgpuElement"),
177            },
178            cube::Elem::Bool => wgsl::Elem::Bool,
179            cube::Elem::AtomicFloat(i) => match i {
180                cube::FloatKind::F32 => wgsl::Elem::AtomicF32,
181                kind => panic!("atomic<{kind:?}> is not a valid WgpuElement"),
182            },
183            cube::Elem::AtomicInt(i) => match i {
184                cube::IntKind::I32 => wgsl::Elem::AtomicI32,
185                kind => panic!("atomic<{kind:?}> is not a valid WgpuElement"),
186            },
187            cube::Elem::AtomicUInt(kind) => match kind {
188                cube::UIntKind::U32 => wgsl::Elem::AtomicU32,
189                kind => panic!("{kind:?} is not a valid WgpuElement"),
190            },
191        }
192    }
193
194    fn ext_meta_pos(&self, var: &cube::Variable) -> u32 {
195        let pos = var.index().expect("Variable should have index");
196        self.ext_meta_pos[pos as usize]
197    }
198
199    pub(crate) fn compile_variable(&mut self, value: cube::Variable) -> wgsl::Variable {
200        let item = value.item;
201        match value.kind {
202            cube::VariableKind::GlobalInputArray(id) => {
203                wgsl::Variable::GlobalInputArray(id, self.compile_item(item))
204            }
205            cube::VariableKind::GlobalScalar(id) => {
206                wgsl::Variable::GlobalScalar(id, self.compile_elem(item.elem), item.elem)
207            }
208            cube::VariableKind::LocalMut { id } | cube::VariableKind::Versioned { id, .. } => {
209                wgsl::Variable::LocalMut {
210                    id,
211                    item: self.compile_item(item),
212                }
213            }
214            cube::VariableKind::LocalConst { id } => wgsl::Variable::LocalConst {
215                id,
216                item: self.compile_item(item),
217            },
218            cube::VariableKind::GlobalOutputArray(id) => {
219                wgsl::Variable::GlobalOutputArray(id, self.compile_item(item))
220            }
221            cube::VariableKind::ConstantScalar(value) => {
222                wgsl::Variable::ConstantScalar(value, self.compile_elem(value.elem()))
223            }
224            cube::VariableKind::SharedMemory {
225                id,
226                length,
227                alignment,
228            } => {
229                let item = self.compile_item(item);
230                if !self.shared_memories.iter().any(|s| s.index == id) {
231                    self.shared_memories
232                        .push(SharedMemory::new(id, item, length, alignment));
233                }
234                wgsl::Variable::SharedMemory(id, item, length)
235            }
236            cube::VariableKind::ConstantArray { id, length } => {
237                let item = self.compile_item(item);
238                wgsl::Variable::ConstantArray(id, item, length)
239            }
240            cube::VariableKind::LocalArray { id, length } => {
241                let item = self.compile_item(item);
242                if !self.local_arrays.iter().any(|s| s.index == id) {
243                    self.local_arrays.push(LocalArray::new(id, item, length));
244                }
245                wgsl::Variable::LocalArray(id, item, length)
246            }
247            cube::VariableKind::Builtin(builtin) => match builtin {
248                cube::Builtin::AbsolutePos => {
249                    self.id = true;
250                    wgsl::Variable::Id
251                }
252                cube::Builtin::UnitPos => {
253                    self.local_invocation_index = true;
254                    wgsl::Variable::LocalInvocationIndex
255                }
256                cube::Builtin::UnitPosX => {
257                    self.local_invocation_id = true;
258                    wgsl::Variable::LocalInvocationIdX
259                }
260                cube::Builtin::UnitPosY => {
261                    self.local_invocation_id = true;
262                    wgsl::Variable::LocalInvocationIdY
263                }
264                cube::Builtin::UnitPosZ => {
265                    self.local_invocation_id = true;
266                    wgsl::Variable::LocalInvocationIdZ
267                }
268                cube::Builtin::CubePosX => {
269                    self.workgroup_id = true;
270                    wgsl::Variable::WorkgroupIdX
271                }
272                cube::Builtin::CubePosY => {
273                    self.workgroup_id = true;
274                    wgsl::Variable::WorkgroupIdY
275                }
276                cube::Builtin::CubePosZ => {
277                    self.workgroup_id = true;
278                    wgsl::Variable::WorkgroupIdZ
279                }
280                cube::Builtin::CubePosCluster
281                | cube::Builtin::CubePosClusterX
282                | cube::Builtin::CubePosClusterY
283                | cube::Builtin::CubePosClusterZ => self.constant_var(1),
284                cube::Builtin::AbsolutePosX => {
285                    self.global_invocation_id = true;
286                    wgsl::Variable::GlobalInvocationIdX
287                }
288                cube::Builtin::AbsolutePosY => {
289                    self.global_invocation_id = true;
290                    wgsl::Variable::GlobalInvocationIdY
291                }
292                cube::Builtin::AbsolutePosZ => {
293                    self.global_invocation_id = true;
294                    wgsl::Variable::GlobalInvocationIdZ
295                }
296                cube::Builtin::CubeDimX => wgsl::Variable::WorkgroupSizeX,
297                cube::Builtin::CubeDimY => wgsl::Variable::WorkgroupSizeY,
298                cube::Builtin::CubeDimZ => wgsl::Variable::WorkgroupSizeZ,
299                cube::Builtin::CubeClusterDim
300                | cube::Builtin::CubeClusterDimX
301                | cube::Builtin::CubeClusterDimY
302                | cube::Builtin::CubeClusterDimZ => self.constant_var(1),
303                cube::Builtin::CubeCountX => {
304                    self.num_workgroups = true;
305                    wgsl::Variable::NumWorkgroupsX
306                }
307                cube::Builtin::CubeCountY => {
308                    self.num_workgroups = true;
309                    wgsl::Variable::NumWorkgroupsY
310                }
311                cube::Builtin::CubeCountZ => {
312                    self.num_workgroups = true;
313                    wgsl::Variable::NumWorkgroupsZ
314                }
315                cube::Builtin::CubePos => {
316                    self.workgroup_id_no_axis = true;
317                    wgsl::Variable::WorkgroupId
318                }
319                cube::Builtin::CubeDim => {
320                    self.workgroup_size_no_axis = true;
321                    wgsl::Variable::WorkgroupSize
322                }
323                cube::Builtin::CubeCount => {
324                    self.num_workgroup_no_axis = true;
325                    wgsl::Variable::NumWorkgroups
326                }
327                cube::Builtin::PlaneDim => {
328                    self.subgroup_size = true;
329                    wgsl::Variable::SubgroupSize
330                }
331                cube::Builtin::UnitPosPlane => {
332                    self.subgroup_invocation_id = true;
333                    wgsl::Variable::SubgroupInvocationId
334                }
335            },
336            cube::VariableKind::Matrix { .. } => {
337                panic!("Cooperative matrix-multiply and accumulate not supported.")
338            }
339            cube::VariableKind::Pipeline { .. } => {
340                panic!("Pipeline not supported.")
341            }
342            cube::VariableKind::Barrier { .. } => {
343                panic!("Barrier not supported.")
344            }
345            cube::VariableKind::TensorMap(_) => panic!("Tensor map not supported."),
346        }
347    }
348
349    fn constant_var(&mut self, value: u32) -> wgsl::Variable {
350        let var = cube::Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U32));
351        self.compile_variable(var)
352    }
353
354    fn compile_scope(&mut self, scope: &mut cube::Scope) -> Vec<wgsl::Instruction> {
355        let mut instructions = Vec::new();
356
357        let const_arrays = scope
358            .const_arrays
359            .drain(..)
360            .map(|(var, values)| ConstantArray {
361                index: var.index().unwrap(),
362                item: self.compile_item(var.item),
363                size: values.len() as u32,
364                values: values
365                    .into_iter()
366                    .map(|val| self.compile_variable(val))
367                    .collect(),
368            })
369            .collect::<Vec<_>>();
370        self.const_arrays.extend(const_arrays);
371
372        let checked_io: Box<dyn Processor> = Box::new(CheckedIoProcessor::new(self.strategy));
373        let processing = scope.process([checked_io]);
374
375        for var in processing.variables {
376            instructions.push(wgsl::Instruction::DeclareVariable {
377                var: self.compile_variable(var),
378            });
379        }
380
381        processing
382            .instructions
383            .into_iter()
384            .for_each(|op| self.compile_operation(&mut instructions, op.operation, op.out, scope));
385
386        instructions
387    }
388
389    fn compile_operation(
390        &mut self,
391        instructions: &mut Vec<wgsl::Instruction>,
392        operation: cube::Operation,
393        out: Option<cube::Variable>,
394        scope: &mut cube::Scope,
395    ) {
396        match operation {
397            cube::Operation::Copy(variable) => instructions.push(wgsl::Instruction::Assign {
398                input: self.compile_variable(variable),
399                out: self.compile_variable(out.unwrap()),
400            }),
401            cube::Operation::Arithmetic(op) => {
402                self.compile_arithmetic(op, out, instructions, scope)
403            }
404            cube::Operation::Comparison(op) => self.compile_cmp(op, out, instructions),
405            cube::Operation::Bitwise(op) => self.compile_bitwise(op, out, instructions),
406            cube::Operation::Operator(op) => self.compile_operator(op, out, instructions),
407            cube::Operation::Atomic(op) => instructions.push(self.compile_atomic(op, out)),
408            cube::Operation::Metadata(op) => instructions.push(self.compile_metadata(op, out)),
409            cube::Operation::Branch(val) => self.compile_branch(instructions, val),
410            cube::Operation::Synchronization(val) => {
411                self.compile_synchronization(instructions, val)
412            }
413            cube::Operation::Plane(op) => self.compile_subgroup(instructions, op, out),
414            cube::Operation::CoopMma(_) => {
415                panic!("Cooperative matrix-multiply and accumulate isn't supported on wgpu.")
416            }
417            cube::Operation::NonSemantic(cube::NonSemantic::Comment { content }) => {
418                self.compile_comment(instructions, content)
419            }
420            cube::Operation::NonSemantic(_) => {}
421            cube::Operation::Barrier(_) => {
422                panic!("Barrier isn't supported on wgpu.")
423            }
424            cube::Operation::Tma(_) => panic!("TMA isn't supported on wgpu."),
425        }
426    }
427
428    fn compile_subgroup(
429        &mut self,
430        instructions: &mut Vec<wgsl::Instruction>,
431        subgroup: cube::Plane,
432        out: Option<cube::Variable>,
433    ) {
434        self.subgroup_instructions_used = true;
435
436        let out = out.unwrap();
437        let op = match subgroup {
438            cube::Plane::Elect => Subgroup::Elect {
439                out: self.compile_variable(out),
440            },
441            cube::Plane::All(op) => Subgroup::All {
442                input: self.compile_variable(op.input),
443                out: self.compile_variable(out),
444            },
445            cube::Plane::Any(op) => Subgroup::Any {
446                input: self.compile_variable(op.input),
447                out: self.compile_variable(out),
448            },
449            cube::Plane::Ballot(op) => Subgroup::Ballot {
450                input: self.compile_variable(op.input),
451                out: self.compile_variable(out),
452            },
453            cube::Plane::Broadcast(op) => Subgroup::Broadcast {
454                lhs: self.compile_variable(op.lhs),
455                rhs: self.compile_variable(op.rhs),
456                out: self.compile_variable(out),
457            },
458            cube::Plane::Sum(op) => Subgroup::Sum {
459                input: self.compile_variable(op.input),
460                out: self.compile_variable(out),
461            },
462            cube::Plane::ExclusiveSum(op) => Subgroup::ExclusiveSum {
463                input: self.compile_variable(op.input),
464                out: self.compile_variable(out),
465            },
466            cube::Plane::InclusiveSum(op) => Subgroup::InclusiveSum {
467                input: self.compile_variable(op.input),
468                out: self.compile_variable(out),
469            },
470            cube::Plane::Prod(op) => Subgroup::Prod {
471                input: self.compile_variable(op.input),
472                out: self.compile_variable(out),
473            },
474            cube::Plane::ExclusiveProd(op) => Subgroup::ExclusiveProd {
475                input: self.compile_variable(op.input),
476                out: self.compile_variable(out),
477            },
478            cube::Plane::InclusiveProd(op) => Subgroup::InclusiveProd {
479                input: self.compile_variable(op.input),
480                out: self.compile_variable(out),
481            },
482            cube::Plane::Min(op) => Subgroup::Min {
483                input: self.compile_variable(op.input),
484                out: self.compile_variable(out),
485            },
486            cube::Plane::Max(op) => Subgroup::Max {
487                input: self.compile_variable(op.input),
488                out: self.compile_variable(out),
489            },
490        };
491
492        instructions.push(wgsl::Instruction::Subgroup(op));
493    }
494
495    fn compile_branch(&mut self, instructions: &mut Vec<wgsl::Instruction>, branch: cube::Branch) {
496        match branch {
497            cube::Branch::If(mut op) => instructions.push(wgsl::Instruction::If {
498                cond: self.compile_variable(op.cond),
499                instructions: self.compile_scope(&mut op.scope),
500            }),
501            cube::Branch::IfElse(mut op) => instructions.push(wgsl::Instruction::IfElse {
502                cond: self.compile_variable(op.cond),
503                instructions_if: self.compile_scope(&mut op.scope_if),
504                instructions_else: self.compile_scope(&mut op.scope_else),
505            }),
506            cube::Branch::Switch(mut op) => instructions.push(wgsl::Instruction::Switch {
507                value: self.compile_variable(op.value),
508                instructions_default: self.compile_scope(&mut op.scope_default),
509                cases: op
510                    .cases
511                    .into_iter()
512                    .map(|(val, mut scope)| {
513                        (self.compile_variable(val), self.compile_scope(&mut scope))
514                    })
515                    .collect(),
516            }),
517            cube::Branch::Return => instructions.push(wgsl::Instruction::Return),
518            cube::Branch::Break => instructions.push(wgsl::Instruction::Break),
519            cube::Branch::RangeLoop(mut range_loop) => {
520                instructions.push(wgsl::Instruction::RangeLoop {
521                    i: self.compile_variable(range_loop.i),
522                    start: self.compile_variable(range_loop.start),
523                    end: self.compile_variable(range_loop.end),
524                    step: range_loop.step.map(|it| self.compile_variable(it)),
525                    inclusive: range_loop.inclusive,
526                    instructions: self.compile_scope(&mut range_loop.scope),
527                })
528            }
529            cube::Branch::Loop(mut op) => instructions.push(wgsl::Instruction::Loop {
530                instructions: self.compile_scope(&mut op.scope),
531            }),
532        };
533    }
534
535    fn compile_synchronization(
536        &mut self,
537        instructions: &mut Vec<wgsl::Instruction>,
538        synchronization: cube::Synchronization,
539    ) {
540        match synchronization {
541            cube::Synchronization::SyncCube => {
542                instructions.push(wgsl::Instruction::WorkgroupBarrier)
543            }
544            cube::Synchronization::SyncPlane => {
545                panic!("Synchronization within a plane is not supported in WGSL")
546            }
547            cube::Synchronization::SyncStorage => {
548                instructions.push(wgsl::Instruction::StorageBarrier)
549            }
550            cube::Synchronization::SyncProxyShared => panic!("TMA is not supported in WGSL"),
551        };
552    }
553
554    fn compile_comment(&mut self, instructions: &mut Vec<wgsl::Instruction>, content: String) {
555        instructions.push(wgsl::Instruction::Comment { content })
556    }
557
558    fn compile_metadata(
559        &mut self,
560        metadata: cube::Metadata,
561        out: Option<cube::Variable>,
562    ) -> wgsl::Instruction {
563        let out = out.unwrap();
564        match metadata {
565            cube::Metadata::Rank { var } => {
566                let position = self.ext_meta_pos(&var);
567                let offset = self.metadata.rank_index(position);
568                wgsl::Instruction::Metadata {
569                    out: self.compile_variable(out),
570                    info_offset: self.compile_variable(offset.into()),
571                }
572            }
573            cube::Metadata::Stride { dim, var } => {
574                let position = self.ext_meta_pos(&var);
575                let offset = self.metadata.stride_offset_index(position);
576                wgsl::Instruction::ExtendedMeta {
577                    info_offset: self.compile_variable(offset.into()),
578                    dim: self.compile_variable(dim),
579                    out: self.compile_variable(out),
580                }
581            }
582            cube::Metadata::Shape { dim, var } => {
583                let position = self.ext_meta_pos(&var);
584                let offset = self.metadata.shape_offset_index(position);
585                wgsl::Instruction::ExtendedMeta {
586                    info_offset: self.compile_variable(offset.into()),
587                    dim: self.compile_variable(dim),
588                    out: self.compile_variable(out),
589                }
590            }
591            cube::Metadata::Length { var } => match var.kind {
592                cube::VariableKind::GlobalInputArray(id) => {
593                    let offset = self.metadata.len_index(id);
594                    wgsl::Instruction::Metadata {
595                        out: self.compile_variable(out),
596                        info_offset: self.compile_variable(offset.into()),
597                    }
598                }
599                cube::VariableKind::GlobalOutputArray(id) => {
600                    let offset = self.metadata.len_index(id);
601                    wgsl::Instruction::Metadata {
602                        out: self.compile_variable(out),
603                        info_offset: self.compile_variable(offset.into()),
604                    }
605                }
606                _ => wgsl::Instruction::Length {
607                    var: self.compile_variable(var),
608                    out: self.compile_variable(out),
609                },
610            },
611            cube::Metadata::BufferLength { var } => match var.kind {
612                cube::VariableKind::GlobalInputArray(id) => {
613                    let offset = self.metadata.buffer_len_index(id);
614                    wgsl::Instruction::Metadata {
615                        out: self.compile_variable(out),
616                        info_offset: self.compile_variable(offset.into()),
617                    }
618                }
619                cube::VariableKind::GlobalOutputArray(id) => {
620                    let offset = self.metadata.buffer_len_index(id);
621                    wgsl::Instruction::Metadata {
622                        out: self.compile_variable(out),
623                        info_offset: self.compile_variable(offset.into()),
624                    }
625                }
626                _ => wgsl::Instruction::Length {
627                    var: self.compile_variable(var),
628                    out: self.compile_variable(out),
629                },
630            },
631        }
632    }
633
634    fn compile_arithmetic(
635        &mut self,
636        value: cube::Arithmetic,
637        out: Option<cube::Variable>,
638        instructions: &mut Vec<wgsl::Instruction>,
639        scope: &mut Scope,
640    ) {
641        let out = out.unwrap();
642        match value {
643            cube::Arithmetic::Max(op) => instructions.push(wgsl::Instruction::Max {
644                lhs: self.compile_variable(op.lhs),
645                rhs: self.compile_variable(op.rhs),
646                out: self.compile_variable(out),
647            }),
648            cube::Arithmetic::Min(op) => instructions.push(wgsl::Instruction::Min {
649                lhs: self.compile_variable(op.lhs),
650                rhs: self.compile_variable(op.rhs),
651                out: self.compile_variable(out),
652            }),
653            cube::Arithmetic::Add(op) => instructions.push(wgsl::Instruction::Add {
654                lhs: self.compile_variable(op.lhs),
655                rhs: self.compile_variable(op.rhs),
656                out: self.compile_variable(out),
657            }),
658            cube::Arithmetic::Fma(op) => instructions.push(wgsl::Instruction::Fma {
659                a: self.compile_variable(op.a),
660                b: self.compile_variable(op.b),
661                c: self.compile_variable(op.c),
662                out: self.compile_variable(out),
663            }),
664            cube::Arithmetic::Modulo(op) => instructions.push(wgsl::Instruction::Modulo {
665                lhs: self.compile_variable(op.lhs),
666                rhs: self.compile_variable(op.rhs),
667                out: self.compile_variable(out),
668            }),
669            cube::Arithmetic::Sub(op) => instructions.push(wgsl::Instruction::Sub {
670                lhs: self.compile_variable(op.lhs),
671                rhs: self.compile_variable(op.rhs),
672                out: self.compile_variable(out),
673            }),
674            cube::Arithmetic::Mul(op) => instructions.push(wgsl::Instruction::Mul {
675                lhs: self.compile_variable(op.lhs),
676                rhs: self.compile_variable(op.rhs),
677                out: self.compile_variable(out),
678            }),
679            cube::Arithmetic::Div(op) => instructions.push(wgsl::Instruction::Div {
680                lhs: self.compile_variable(op.lhs),
681                rhs: self.compile_variable(op.rhs),
682                out: self.compile_variable(out),
683            }),
684            cube::Arithmetic::Abs(op) => instructions.push(wgsl::Instruction::Abs {
685                input: self.compile_variable(op.input),
686                out: self.compile_variable(out),
687            }),
688            cube::Arithmetic::Exp(op) => instructions.push(wgsl::Instruction::Exp {
689                input: self.compile_variable(op.input),
690                out: self.compile_variable(out),
691            }),
692            cube::Arithmetic::Log(op) => instructions.push(wgsl::Instruction::Log {
693                input: self.compile_variable(op.input),
694                out: self.compile_variable(out),
695            }),
696            cube::Arithmetic::Log1p(op) => instructions.push(wgsl::Instruction::Log1p {
697                input: self.compile_variable(op.input),
698                out: self.compile_variable(out),
699            }),
700            cube::Arithmetic::Cos(op) => instructions.push(wgsl::Instruction::Cos {
701                input: self.compile_variable(op.input),
702                out: self.compile_variable(out),
703            }),
704            cube::Arithmetic::Sin(op) => instructions.push(wgsl::Instruction::Sin {
705                input: self.compile_variable(op.input),
706                out: self.compile_variable(out),
707            }),
708            cube::Arithmetic::Tanh(op) => instructions.push(wgsl::Instruction::Tanh {
709                input: self.compile_variable(op.input),
710                out: self.compile_variable(out),
711            }),
712            cube::Arithmetic::Powf(op) => instructions.push(wgsl::Instruction::Powf {
713                lhs: self.compile_variable(op.lhs),
714                rhs: self.compile_variable(op.rhs),
715                out: self.compile_variable(out),
716            }),
717            cube::Arithmetic::Sqrt(op) => instructions.push(wgsl::Instruction::Sqrt {
718                input: self.compile_variable(op.input),
719                out: self.compile_variable(out),
720            }),
721            cube::Arithmetic::Round(op) => instructions.push(wgsl::Instruction::Round {
722                input: self.compile_variable(op.input),
723                out: self.compile_variable(out),
724            }),
725            cube::Arithmetic::Floor(op) => instructions.push(wgsl::Instruction::Floor {
726                input: self.compile_variable(op.input),
727                out: self.compile_variable(out),
728            }),
729            cube::Arithmetic::Ceil(op) => instructions.push(wgsl::Instruction::Ceil {
730                input: self.compile_variable(op.input),
731                out: self.compile_variable(out),
732            }),
733            cube::Arithmetic::Erf(op) => {
734                let mut scope = scope.child();
735                expand_erf(&mut scope, op.input, out);
736                instructions.extend(self.compile_scope(&mut scope));
737            }
738            cube::Arithmetic::MulHi(op) => {
739                let mut scope = scope.child();
740                match self.compilation_options.supports_u64 {
741                    true => expand_himul_64(&mut scope, op.lhs, op.rhs, out),
742                    false => expand_himul_sim(&mut scope, op.lhs, op.rhs, out),
743                }
744                instructions.extend(self.compile_scope(&mut scope));
745            }
746            cube::Arithmetic::Recip(op) => instructions.push(wgsl::Instruction::Recip {
747                input: self.compile_variable(op.input),
748                out: self.compile_variable(out),
749            }),
750            cube::Arithmetic::Clamp(op) => instructions.push(wgsl::Instruction::Clamp {
751                input: self.compile_variable(op.input),
752                min_value: self.compile_variable(op.min_value),
753                max_value: self.compile_variable(op.max_value),
754                out: self.compile_variable(out),
755            }),
756            cube::Arithmetic::Remainder(op) => instructions.push(wgsl::Instruction::Remainder {
757                lhs: self.compile_variable(op.lhs),
758                rhs: self.compile_variable(op.rhs),
759                out: self.compile_variable(out),
760            }),
761            cube::Arithmetic::Neg(op) => instructions.push(wgsl::Instruction::Negate {
762                input: self.compile_variable(op.input),
763                out: self.compile_variable(out),
764            }),
765            cube::Arithmetic::Magnitude(op) => instructions.push(wgsl::Instruction::Magnitude {
766                input: self.compile_variable(op.input),
767                out: self.compile_variable(out),
768            }),
769            cube::Arithmetic::Normalize(op) => instructions.push(wgsl::Instruction::Normalize {
770                input: self.compile_variable(op.input),
771                out: self.compile_variable(out),
772            }),
773            cube::Arithmetic::Dot(op) => instructions.push(wgsl::Instruction::Dot {
774                lhs: self.compile_variable(op.lhs),
775                rhs: self.compile_variable(op.rhs),
776                out: self.compile_variable(out),
777            }),
778        }
779    }
780
781    fn compile_cmp(
782        &mut self,
783        value: cube::Comparison,
784        out: Option<cube::Variable>,
785        instructions: &mut Vec<wgsl::Instruction>,
786    ) {
787        let out = out.unwrap();
788        match value {
789            cube::Comparison::Equal(op) => instructions.push(wgsl::Instruction::Equal {
790                lhs: self.compile_variable(op.lhs),
791                rhs: self.compile_variable(op.rhs),
792                out: self.compile_variable(out),
793            }),
794            cube::Comparison::Lower(op) => instructions.push(wgsl::Instruction::Lower {
795                lhs: self.compile_variable(op.lhs),
796                rhs: self.compile_variable(op.rhs),
797                out: self.compile_variable(out),
798            }),
799            cube::Comparison::Greater(op) => instructions.push(wgsl::Instruction::Greater {
800                lhs: self.compile_variable(op.lhs),
801                rhs: self.compile_variable(op.rhs),
802                out: self.compile_variable(out),
803            }),
804            cube::Comparison::LowerEqual(op) => instructions.push(wgsl::Instruction::LowerEqual {
805                lhs: self.compile_variable(op.lhs),
806                rhs: self.compile_variable(op.rhs),
807                out: self.compile_variable(out),
808            }),
809            cube::Comparison::GreaterEqual(op) => {
810                instructions.push(wgsl::Instruction::GreaterEqual {
811                    lhs: self.compile_variable(op.lhs),
812                    rhs: self.compile_variable(op.rhs),
813                    out: self.compile_variable(out),
814                })
815            }
816            cube::Comparison::NotEqual(op) => instructions.push(wgsl::Instruction::NotEqual {
817                lhs: self.compile_variable(op.lhs),
818                rhs: self.compile_variable(op.rhs),
819                out: self.compile_variable(out),
820            }),
821        }
822    }
823
824    fn compile_bitwise(
825        &mut self,
826        value: cube::Bitwise,
827        out: Option<cube::Variable>,
828        instructions: &mut Vec<wgsl::Instruction>,
829    ) {
830        let out = out.unwrap();
831        match value {
832            cube::Bitwise::BitwiseOr(op) => instructions.push(wgsl::Instruction::BitwiseOr {
833                lhs: self.compile_variable(op.lhs),
834                rhs: self.compile_variable(op.rhs),
835                out: self.compile_variable(out),
836            }),
837            cube::Bitwise::BitwiseAnd(op) => instructions.push(wgsl::Instruction::BitwiseAnd {
838                lhs: self.compile_variable(op.lhs),
839                rhs: self.compile_variable(op.rhs),
840                out: self.compile_variable(out),
841            }),
842            cube::Bitwise::BitwiseXor(op) => instructions.push(wgsl::Instruction::BitwiseXor {
843                lhs: self.compile_variable(op.lhs),
844                rhs: self.compile_variable(op.rhs),
845                out: self.compile_variable(out),
846            }),
847            cube::Bitwise::CountOnes(op) => instructions.push(wgsl::Instruction::CountBits {
848                input: self.compile_variable(op.input),
849                out: self.compile_variable(out),
850            }),
851            cube::Bitwise::ReverseBits(op) => instructions.push(wgsl::Instruction::ReverseBits {
852                input: self.compile_variable(op.input),
853                out: self.compile_variable(out),
854            }),
855            cube::Bitwise::ShiftLeft(op) => instructions.push(wgsl::Instruction::ShiftLeft {
856                lhs: self.compile_variable(op.lhs),
857                rhs: self.compile_variable(op.rhs),
858                out: self.compile_variable(out),
859            }),
860            cube::Bitwise::ShiftRight(op) => instructions.push(wgsl::Instruction::ShiftRight {
861                lhs: self.compile_variable(op.lhs),
862                rhs: self.compile_variable(op.rhs),
863                out: self.compile_variable(out),
864            }),
865            cube::Bitwise::BitwiseNot(op) => instructions.push(wgsl::Instruction::BitwiseNot {
866                input: self.compile_variable(op.input),
867                out: self.compile_variable(out),
868            }),
869            cube::Bitwise::LeadingZeros(op) => instructions.push(wgsl::Instruction::LeadingZeros {
870                input: self.compile_variable(op.input),
871                out: self.compile_variable(out),
872            }),
873            cube::Bitwise::FindFirstSet(op) => instructions.push(wgsl::Instruction::FindFirstSet {
874                input: self.compile_variable(op.input),
875                out: self.compile_variable(out),
876            }),
877        }
878    }
879
880    fn compile_operator(
881        &mut self,
882        value: cube::Operator,
883        out: Option<cube::Variable>,
884        instructions: &mut Vec<wgsl::Instruction>,
885    ) {
886        let out = out.unwrap();
887        match value {
888            cube::Operator::Cast(op) => instructions.push(wgsl::Instruction::Assign {
889                input: self.compile_variable(op.input),
890                out: self.compile_variable(out),
891            }),
892            cube::Operator::Index(op) | cube::Operator::UncheckedIndex(op) => {
893                instructions.push(wgsl::Instruction::Index {
894                    lhs: self.compile_variable(op.list),
895                    rhs: self.compile_variable(op.index),
896                    out: self.compile_variable(out),
897                });
898            }
899            cube::Operator::IndexAssign(op) | cube::Operator::UncheckedIndexAssign(op) => {
900                instructions.push(wgsl::Instruction::IndexAssign {
901                    index: self.compile_variable(op.index),
902                    rhs: self.compile_variable(op.value),
903                    out: self.compile_variable(out),
904                })
905            }
906            cube::Operator::And(op) => instructions.push(wgsl::Instruction::And {
907                lhs: self.compile_variable(op.lhs),
908                rhs: self.compile_variable(op.rhs),
909                out: self.compile_variable(out),
910            }),
911            cube::Operator::Or(op) => instructions.push(wgsl::Instruction::Or {
912                lhs: self.compile_variable(op.lhs),
913                rhs: self.compile_variable(op.rhs),
914                out: self.compile_variable(out),
915            }),
916            cube::Operator::Not(op) => instructions.push(wgsl::Instruction::Not {
917                input: self.compile_variable(op.input),
918                out: self.compile_variable(out),
919            }),
920            cube::Operator::Reinterpret(op) => instructions.push(wgsl::Instruction::Bitcast {
921                input: self.compile_variable(op.input),
922                out: self.compile_variable(out),
923            }),
924            cube::Operator::InitLine(op) => instructions.push(wgsl::Instruction::VecInit {
925                inputs: op
926                    .inputs
927                    .into_iter()
928                    .map(|var| self.compile_variable(var))
929                    .collect(),
930                out: self.compile_variable(out),
931            }),
932            cube::Operator::CopyMemory(op) => instructions.push(wgsl::Instruction::Copy {
933                input: self.compile_variable(op.input),
934                in_index: self.compile_variable(op.in_index),
935                out: self.compile_variable(out),
936                out_index: self.compile_variable(op.out_index),
937            }),
938            cube::Operator::CopyMemoryBulk(op) => instructions.push(wgsl::Instruction::CopyBulk {
939                input: self.compile_variable(op.input),
940                in_index: self.compile_variable(op.in_index),
941                out: self.compile_variable(out),
942                out_index: self.compile_variable(op.out_index),
943                len: op.len.as_const().unwrap().as_u32(),
944            }),
945            cube::Operator::Select(op) => instructions.push(wgsl::Instruction::Select {
946                cond: self.compile_variable(op.cond),
947                then: self.compile_variable(op.then),
948                or_else: self.compile_variable(op.or_else),
949                out: self.compile_variable(out),
950            }),
951        }
952    }
953
954    fn compile_atomic(
955        &mut self,
956        atomic: cube::AtomicOp,
957        out: Option<cube::Variable>,
958    ) -> wgsl::Instruction {
959        let out = out.unwrap();
960        match atomic {
961            cube::AtomicOp::Add(op) => wgsl::Instruction::AtomicAdd {
962                lhs: self.compile_variable(op.lhs),
963                rhs: self.compile_variable(op.rhs),
964                out: self.compile_variable(out),
965            },
966            cube::AtomicOp::Sub(op) => wgsl::Instruction::AtomicSub {
967                lhs: self.compile_variable(op.lhs),
968                rhs: self.compile_variable(op.rhs),
969                out: self.compile_variable(out),
970            },
971            cube::AtomicOp::Max(op) => wgsl::Instruction::AtomicMax {
972                lhs: self.compile_variable(op.lhs),
973                rhs: self.compile_variable(op.rhs),
974                out: self.compile_variable(out),
975            },
976            cube::AtomicOp::Min(op) => wgsl::Instruction::AtomicMin {
977                lhs: self.compile_variable(op.lhs),
978                rhs: self.compile_variable(op.rhs),
979                out: self.compile_variable(out),
980            },
981            cube::AtomicOp::And(op) => wgsl::Instruction::AtomicAnd {
982                lhs: self.compile_variable(op.lhs),
983                rhs: self.compile_variable(op.rhs),
984                out: self.compile_variable(out),
985            },
986            cube::AtomicOp::Or(op) => wgsl::Instruction::AtomicOr {
987                lhs: self.compile_variable(op.lhs),
988                rhs: self.compile_variable(op.rhs),
989                out: self.compile_variable(out),
990            },
991            cube::AtomicOp::Xor(op) => wgsl::Instruction::AtomicXor {
992                lhs: self.compile_variable(op.lhs),
993                rhs: self.compile_variable(op.rhs),
994                out: self.compile_variable(out),
995            },
996            cube::AtomicOp::Load(op) => wgsl::Instruction::AtomicLoad {
997                input: self.compile_variable(op.input),
998                out: self.compile_variable(out),
999            },
1000            cube::AtomicOp::Store(op) => wgsl::Instruction::AtomicStore {
1001                input: self.compile_variable(op.input),
1002                out: self.compile_variable(out),
1003            },
1004            cube::AtomicOp::Swap(op) => wgsl::Instruction::AtomicSwap {
1005                lhs: self.compile_variable(op.lhs),
1006                rhs: self.compile_variable(op.rhs),
1007                out: self.compile_variable(out),
1008            },
1009            cube::AtomicOp::CompareAndSwap(op) => wgsl::Instruction::AtomicCompareExchangeWeak {
1010                lhs: self.compile_variable(op.input),
1011                cmp: self.compile_variable(op.cmp),
1012                value: self.compile_variable(op.val),
1013                out: self.compile_variable(out),
1014            },
1015        }
1016    }
1017
1018    fn compile_location(value: compute::Location) -> wgsl::Location {
1019        match value {
1020            compute::Location::Storage => wgsl::Location::Storage,
1021            compute::Location::Cube => wgsl::Location::Workgroup,
1022        }
1023    }
1024
1025    fn compile_binding(&mut self, value: compute::Binding) -> wgsl::Binding {
1026        wgsl::Binding {
1027            id: value.id,
1028            visibility: value.visibility,
1029            location: Self::compile_location(value.location),
1030            item: self.compile_item(value.item),
1031            size: value.size,
1032        }
1033    }
1034}
1035
1036fn register_extensions(instructions: &[wgsl::Instruction]) -> Vec<wgsl::Extension> {
1037    let mut extensions = Vec::new();
1038
1039    let mut register_extension = |extension: wgsl::Extension| {
1040        if !extensions.contains(&extension) {
1041            extensions.push(extension);
1042        }
1043    };
1044
1045    // Since not all instructions are native to WGSL, we need to add the custom ones.
1046    for instruction in instructions {
1047        match instruction {
1048            wgsl::Instruction::Powf { lhs: _, rhs, out } => {
1049                register_extension(wgsl::Extension::PowfPrimitive(out.item()));
1050
1051                if rhs.is_always_scalar() || rhs.item().vectorization_factor() == 1 {
1052                    register_extension(wgsl::Extension::PowfScalar(out.item()));
1053                } else {
1054                    register_extension(wgsl::Extension::Powf(out.item()));
1055                }
1056            }
1057            #[cfg(target_os = "macos")]
1058            wgsl::Instruction::Tanh { input, out: _ } => {
1059                register_extension(wgsl::Extension::SafeTanh(input.item()))
1060            }
1061            wgsl::Instruction::If { instructions, .. } => {
1062                for extension in register_extensions(instructions) {
1063                    register_extension(extension);
1064                }
1065            }
1066            wgsl::Instruction::IfElse {
1067                instructions_if,
1068                instructions_else,
1069                ..
1070            } => {
1071                for extension in register_extensions(instructions_if) {
1072                    register_extension(extension);
1073                }
1074                for extension in register_extensions(instructions_else) {
1075                    register_extension(extension);
1076                }
1077            }
1078            wgsl::Instruction::Loop { instructions } => {
1079                for extension in register_extensions(instructions) {
1080                    register_extension(extension);
1081                }
1082            }
1083            wgsl::Instruction::RangeLoop { instructions, .. } => {
1084                for extension in register_extensions(instructions) {
1085                    register_extension(extension);
1086                }
1087            }
1088            _ => {}
1089        }
1090    }
1091
1092    extensions
1093}