cubecl_wgpu/compiler/wgsl/
compiler.rs

1use std::{borrow::Cow, sync::Arc};
2
3use super::Subgroup;
4use super::{shader::ComputeShader, ConstantArray};
5use super::{Item, LocalArray, SharedMemory};
6use crate::{
7    compiler::{base::WgpuCompiler, wgsl},
8    WgpuServer,
9};
10
11use cubecl_core::ir::expand_checked_index_assign;
12use cubecl_core::{
13    ir::{self as cube, UIntKind},
14    prelude::CompiledKernel,
15    server::ComputeServer,
16    Feature, Metadata,
17};
18use cubecl_runtime::{DeviceProperties, ExecutionMode};
19use wgpu::{
20    BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingType, BufferBindingType,
21    ComputePipeline, DeviceDescriptor, PipelineLayoutDescriptor, ShaderModuleDescriptor,
22    ShaderStages,
23};
24
/// Options controlling how kernels are lowered to WGSL.
///
/// The WGSL backend currently exposes no options; this type is used as the
/// `CompilationOptions` associated type of the [`WgslCompiler`].
#[derive(Debug, Default, Clone)]
pub struct CompilationOptions {}
27
28/// Wgsl Compiler.
29#[derive(Clone, Default)]
30pub struct WgslCompiler {
31    num_inputs: usize,
32    num_outputs: usize,
33    metadata: Metadata,
34    ext_meta_pos: Vec<u32>,
35    local_invocation_index: bool,
36    local_invocation_id: bool,
37    global_invocation_id: bool,
38    workgroup_id: bool,
39    subgroup_size: bool,
40    subgroup_invocation_id: bool,
41    id: bool,
42    num_workgroups: bool,
43    workgroup_id_no_axis: bool,
44    workgroup_size_no_axis: bool,
45    num_workgroup_no_axis: bool,
46    shared_memories: Vec<SharedMemory>,
47    const_arrays: Vec<ConstantArray>,
48    local_arrays: Vec<LocalArray>,
49    #[allow(dead_code)]
50    compilation_options: CompilationOptions,
51    strategy: ExecutionMode,
52    subgroup_instructions_used: bool,
53}
54
55impl core::fmt::Debug for WgslCompiler {
56    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57        f.write_str("WgslCompiler")
58    }
59}
60
61impl cubecl_core::Compiler for WgslCompiler {
62    type Representation = ComputeShader;
63    type CompilationOptions = CompilationOptions;
64
65    fn compile(
66        shader: cube::KernelDefinition,
67        compilation_options: &Self::CompilationOptions,
68        mode: ExecutionMode,
69    ) -> Self::Representation {
70        let mut compiler = Self {
71            compilation_options: compilation_options.clone(),
72            ..Self::default()
73        };
74        compiler.compile_shader(shader, mode)
75    }
76
77    fn elem_size(elem: cube::Elem) -> usize {
78        Self::compile_elem(elem).size()
79    }
80
81    fn max_shared_memory_size() -> usize {
82        32768
83    }
84}
85
86impl WgpuCompiler for WgslCompiler {
87    fn create_pipeline(
88        server: &mut WgpuServer<Self>,
89        kernel: CompiledKernel<Self>,
90    ) -> Arc<ComputePipeline> {
91        let source = &kernel.source;
92        // Cube always in principle uses unchecked modules. Certain operations like
93        // indexing are instead checked by cube. The WebGPU specification only makes
94        // incredibly loose guarantees that Cube can't rely on. Additionally, kernels
95        // can opt in/out per operation whether checks should be performed which can be faster.
96        //
97        // SAFETY: Cube guarantees OOB safety when launching in checked mode. Launching in unchecked mode
98        // is only available through the use of unsafe code.
99        let module = unsafe {
100            server
101                .device
102                .create_shader_module_unchecked(ShaderModuleDescriptor {
103                    label: None,
104                    source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)),
105                })
106        };
107
108        let layout = kernel.repr.map(|repr| {
109            let bindings = repr
110                .inputs
111                .iter()
112                .chain(repr.outputs.iter())
113                .chain(repr.named.iter().map(|it| &it.1))
114                .enumerate();
115            let bindings = bindings
116                .map(|(i, _binding)| BindGroupLayoutEntry {
117                    binding: i as u32,
118                    visibility: ShaderStages::COMPUTE,
119                    ty: BindingType::Buffer {
120                        #[cfg(not(exclusive_memory_only))]
121                        ty: BufferBindingType::Storage { read_only: false },
122                        #[cfg(exclusive_memory_only)]
123                        ty: BufferBindingType::Storage {
124                            read_only: matches!(
125                                _binding.visibility,
126                                super::shader::Visibility::Read
127                            ),
128                        },
129                        has_dynamic_offset: false,
130                        min_binding_size: None,
131                    },
132                    count: None,
133                })
134                .collect::<Vec<_>>();
135            let layout = server
136                .device
137                .create_bind_group_layout(&BindGroupLayoutDescriptor {
138                    label: None,
139                    entries: &bindings,
140                });
141            server
142                .device
143                .create_pipeline_layout(&PipelineLayoutDescriptor {
144                    label: None,
145                    bind_group_layouts: &[&layout],
146                    push_constant_ranges: &[],
147                })
148        });
149
150        Arc::new(
151            server
152                .device
153                .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
154                    label: None,
155                    layout: layout.as_ref(),
156                    module: &module,
157                    entry_point: Some(&kernel.entrypoint_name),
158                    compilation_options: wgpu::PipelineCompilationOptions {
159                        zero_initialize_workgroup_memory: false,
160                        ..Default::default()
161                    },
162                    cache: None,
163                }),
164        )
165    }
166
167    fn compile(
168        server: &mut WgpuServer<Self>,
169        kernel: <WgpuServer<Self> as ComputeServer>::Kernel,
170        mode: ExecutionMode,
171    ) -> CompiledKernel<Self> {
172        kernel.compile(&server.compilation_options, mode)
173    }
174
175    async fn request_device(adapter: &wgpu::Adapter) -> (wgpu::Device, wgpu::Queue) {
176        let limits = adapter.limits();
177        adapter
178            .request_device(
179                &DeviceDescriptor {
180                    label: None,
181                    required_features: adapter.features(),
182                    required_limits: limits,
183                    // The default is MemoryHints::Performance, which tries to do some bigger
184                    // block allocations. However, we already batch allocations, so we
185                    // can use MemoryHints::MemoryUsage to lower memory usage.
186                    memory_hints: wgpu::MemoryHints::MemoryUsage,
187                },
188                None,
189            )
190            .await
191            .map_err(|err| {
192                format!(
193                    "Unable to request the device with the adapter {:?}, err {:?}",
194                    adapter.get_info(),
195                    err
196                )
197            })
198            .unwrap()
199    }
200
201    fn register_features(
202        _adapter: &wgpu::Adapter,
203        _device: &wgpu::Device,
204        props: &mut DeviceProperties<Feature>,
205    ) {
206        register_types(props);
207    }
208}
209
210fn register_types(props: &mut DeviceProperties<Feature>) {
211    use cubecl_core::ir::{Elem, FloatKind, IntKind};
212
213    let supported_types = [
214        Elem::UInt(UIntKind::U32),
215        Elem::Int(IntKind::I32),
216        Elem::AtomicInt(IntKind::I32),
217        Elem::AtomicUInt(UIntKind::U32),
218        Elem::Float(FloatKind::F32),
219        Elem::Float(FloatKind::Flex32),
220        Elem::Bool,
221    ];
222
223    for ty in supported_types {
224        props.register_feature(Feature::Type(ty));
225    }
226}
227
228impl WgslCompiler {
229    fn compile_shader(
230        &mut self,
231        mut value: cube::KernelDefinition,
232        mode: ExecutionMode,
233    ) -> wgsl::ComputeShader {
234        self.strategy = mode;
235
236        self.num_inputs = value.inputs.len();
237        self.num_outputs = value.outputs.len();
238        let num_meta = value.inputs.len() + value.outputs.len();
239
240        self.ext_meta_pos = Vec::new();
241        let mut num_ext = 0;
242
243        for binding in value.inputs.iter().chain(value.outputs.iter()) {
244            self.ext_meta_pos.push(num_ext);
245            if binding.has_extended_meta {
246                num_ext += 1;
247            }
248        }
249
250        self.metadata = Metadata::new(num_meta as u32, num_ext);
251
252        let instructions = self.compile_scope(&mut value.body);
253        let extensions = register_extensions(&instructions);
254        let body = wgsl::Body {
255            instructions,
256            id: self.id,
257        };
258
259        wgsl::ComputeShader {
260            inputs: value
261                .inputs
262                .into_iter()
263                .map(Self::compile_binding)
264                .collect(),
265            outputs: value
266                .outputs
267                .into_iter()
268                .map(Self::compile_binding)
269                .collect(),
270            named: value
271                .named
272                .into_iter()
273                .map(|(name, binding)| (name, Self::compile_binding(binding)))
274                .collect(),
275            shared_memories: self.shared_memories.clone(),
276            constant_arrays: self.const_arrays.clone(),
277            local_arrays: self.local_arrays.clone(),
278            workgroup_size: value.cube_dim,
279            global_invocation_id: self.global_invocation_id || self.id,
280            local_invocation_index: self.local_invocation_index,
281            local_invocation_id: self.local_invocation_id,
282            num_workgroups: self.id
283                || self.num_workgroups
284                || self.num_workgroup_no_axis
285                || self.workgroup_id_no_axis,
286            workgroup_id: self.workgroup_id || self.workgroup_id_no_axis,
287            subgroup_size: self.subgroup_size,
288            subgroup_invocation_id: self.subgroup_invocation_id,
289            body,
290            extensions,
291            num_workgroups_no_axis: self.num_workgroup_no_axis,
292            workgroup_id_no_axis: self.workgroup_id_no_axis,
293            workgroup_size_no_axis: self.workgroup_size_no_axis,
294            subgroup_instructions_used: self.subgroup_instructions_used,
295            kernel_name: value.kernel_name,
296        }
297    }
298
299    fn compile_item(item: cube::Item) -> Item {
300        let elem = Self::compile_elem(item.elem);
301        match item.vectorization.map(|it| it.get()).unwrap_or(1) {
302            1 => wgsl::Item::Scalar(elem),
303            2 => wgsl::Item::Vec2(elem),
304            3 => wgsl::Item::Vec3(elem),
305            4 => wgsl::Item::Vec4(elem),
306            _ => panic!("Unsupported vectorizations scheme {:?}", item.vectorization),
307        }
308    }
309
310    fn compile_elem(value: cube::Elem) -> wgsl::Elem {
311        match value {
312            cube::Elem::Float(f) => match f {
313                cube::FloatKind::F16 => panic!("f16 is not yet supported"),
314                cube::FloatKind::BF16 => panic!("bf16 is not a valid WgpuElement"),
315                cube::FloatKind::TF32 => panic!("tf32 is not a valid WgpuElement"),
316                cube::FloatKind::Flex32 => wgsl::Elem::F32,
317                cube::FloatKind::F32 => wgsl::Elem::F32,
318                cube::FloatKind::F64 => panic!("f64 is not a valid WgpuElement"),
319            },
320            cube::Elem::Int(i) => match i {
321                cube::IntKind::I32 => wgsl::Elem::I32,
322                kind => panic!("{kind:?} is not a valid WgpuElement"),
323            },
324            cube::Elem::UInt(kind) => match kind {
325                cube::UIntKind::U32 => wgsl::Elem::U32,
326                kind => panic!("{kind:?} is not a valid WgpuElement"),
327            },
328            cube::Elem::Bool => wgsl::Elem::Bool,
329            cube::Elem::AtomicFloat(i) => match i {
330                cube::FloatKind::F32 => wgsl::Elem::AtomicF32,
331                kind => panic!("atomic<{kind:?}> is not a valid WgpuElement"),
332            },
333            cube::Elem::AtomicInt(i) => match i {
334                cube::IntKind::I32 => wgsl::Elem::AtomicI32,
335                kind => panic!("atomic<{kind:?}> is not a valid WgpuElement"),
336            },
337            cube::Elem::AtomicUInt(kind) => match kind {
338                cube::UIntKind::U32 => wgsl::Elem::AtomicU32,
339                kind => panic!("{kind:?} is not a valid WgpuElement"),
340            },
341        }
342    }
343
344    fn ext_meta_pos(&self, var: &cube::Variable) -> u32 {
345        let pos = match var.kind {
346            cube::VariableKind::GlobalInputArray(id) => id as usize,
347            cube::VariableKind::GlobalOutputArray(id) => self.num_inputs + id as usize,
348            _ => panic!("Only global arrays have metadata"),
349        };
350        self.ext_meta_pos[pos]
351    }
352
353    pub(crate) fn compile_variable(&mut self, value: cube::Variable) -> wgsl::Variable {
354        let item = value.item;
355        match value.kind {
356            cube::VariableKind::GlobalInputArray(id) => {
357                wgsl::Variable::GlobalInputArray(id, Self::compile_item(item))
358            }
359            cube::VariableKind::GlobalScalar(id) => {
360                wgsl::Variable::GlobalScalar(id, Self::compile_elem(item.elem), item.elem)
361            }
362            cube::VariableKind::LocalMut { id } | cube::VariableKind::Versioned { id, .. } => {
363                wgsl::Variable::LocalMut {
364                    id,
365                    item: Self::compile_item(item),
366                }
367            }
368            cube::VariableKind::LocalConst { id } => wgsl::Variable::LocalConst {
369                id,
370                item: Self::compile_item(item),
371            },
372            cube::VariableKind::Slice { id } => wgsl::Variable::Slice {
373                id,
374                item: Self::compile_item(item),
375            },
376            cube::VariableKind::GlobalOutputArray(id) => {
377                wgsl::Variable::GlobalOutputArray(id, Self::compile_item(item))
378            }
379            cube::VariableKind::ConstantScalar(value) => {
380                wgsl::Variable::ConstantScalar(value, Self::compile_elem(value.elem()))
381            }
382            cube::VariableKind::SharedMemory { id, length } => {
383                let item = Self::compile_item(item);
384                if !self.shared_memories.iter().any(|s| s.index == id) {
385                    self.shared_memories
386                        .push(SharedMemory::new(id, item, length));
387                }
388                wgsl::Variable::SharedMemory(id, item, length)
389            }
390            cube::VariableKind::ConstantArray { id, length } => {
391                let item = Self::compile_item(item);
392                wgsl::Variable::ConstantArray(id, item, length)
393            }
394            cube::VariableKind::LocalArray { id, length } => {
395                let item = Self::compile_item(item);
396                if !self.local_arrays.iter().any(|s| s.index == id) {
397                    self.local_arrays.push(LocalArray::new(id, item, length));
398                }
399                wgsl::Variable::LocalArray(id, item, length)
400            }
401            cube::VariableKind::Builtin(builtin) => match builtin {
402                cube::Builtin::AbsolutePos => {
403                    self.id = true;
404                    wgsl::Variable::Id
405                }
406                cube::Builtin::UnitPos => {
407                    self.local_invocation_index = true;
408                    wgsl::Variable::LocalInvocationIndex
409                }
410                cube::Builtin::UnitPosX => {
411                    self.local_invocation_id = true;
412                    wgsl::Variable::LocalInvocationIdX
413                }
414                cube::Builtin::UnitPosY => {
415                    self.local_invocation_id = true;
416                    wgsl::Variable::LocalInvocationIdY
417                }
418                cube::Builtin::UnitPosZ => {
419                    self.local_invocation_id = true;
420                    wgsl::Variable::LocalInvocationIdZ
421                }
422                cube::Builtin::CubePosX => {
423                    self.workgroup_id = true;
424                    wgsl::Variable::WorkgroupIdX
425                }
426                cube::Builtin::CubePosY => {
427                    self.workgroup_id = true;
428                    wgsl::Variable::WorkgroupIdY
429                }
430                cube::Builtin::CubePosZ => {
431                    self.workgroup_id = true;
432                    wgsl::Variable::WorkgroupIdZ
433                }
434                cube::Builtin::AbsolutePosX => {
435                    self.global_invocation_id = true;
436                    wgsl::Variable::GlobalInvocationIdX
437                }
438                cube::Builtin::AbsolutePosY => {
439                    self.global_invocation_id = true;
440                    wgsl::Variable::GlobalInvocationIdY
441                }
442                cube::Builtin::AbsolutePosZ => {
443                    self.global_invocation_id = true;
444                    wgsl::Variable::GlobalInvocationIdZ
445                }
446                cube::Builtin::CubeDimX => wgsl::Variable::WorkgroupSizeX,
447                cube::Builtin::CubeDimY => wgsl::Variable::WorkgroupSizeY,
448                cube::Builtin::CubeDimZ => wgsl::Variable::WorkgroupSizeZ,
449                cube::Builtin::CubeCountX => {
450                    self.num_workgroups = true;
451                    wgsl::Variable::NumWorkgroupsX
452                }
453                cube::Builtin::CubeCountY => {
454                    self.num_workgroups = true;
455                    wgsl::Variable::NumWorkgroupsY
456                }
457                cube::Builtin::CubeCountZ => {
458                    self.num_workgroups = true;
459                    wgsl::Variable::NumWorkgroupsZ
460                }
461                cube::Builtin::CubePos => {
462                    self.workgroup_id_no_axis = true;
463                    wgsl::Variable::WorkgroupId
464                }
465                cube::Builtin::CubeDim => {
466                    self.workgroup_size_no_axis = true;
467                    wgsl::Variable::WorkgroupSize
468                }
469                cube::Builtin::CubeCount => {
470                    self.num_workgroup_no_axis = true;
471                    wgsl::Variable::NumWorkgroups
472                }
473                cube::Builtin::PlaneDim => {
474                    self.subgroup_size = true;
475                    wgsl::Variable::SubgroupSize
476                }
477                cube::Builtin::UnitPosPlane => {
478                    self.subgroup_invocation_id = true;
479                    wgsl::Variable::SubgroupInvocationId
480                }
481            },
482            cube::VariableKind::Matrix { .. } => {
483                panic!("Cooperative matrix-multiply and accumulate not supported.")
484            }
485        }
486    }
487
488    fn compile_scope(&mut self, scope: &mut cube::Scope) -> Vec<wgsl::Instruction> {
489        let mut instructions = Vec::new();
490
491        let const_arrays = scope
492            .const_arrays
493            .drain(..)
494            .map(|(var, values)| ConstantArray {
495                index: var.index().unwrap(),
496                item: Self::compile_item(var.item),
497                size: values.len() as u32,
498                values: values
499                    .into_iter()
500                    .map(|val| self.compile_variable(val))
501                    .collect(),
502            })
503            .collect::<Vec<_>>();
504        self.const_arrays.extend(const_arrays);
505
506        let processing = scope.process();
507
508        for var in processing.variables {
509            // We don't declare slices.
510            if let cube::VariableKind::Slice { .. } = var.kind {
511                continue;
512            }
513
514            instructions.push(wgsl::Instruction::DeclareVariable {
515                var: self.compile_variable(var),
516            });
517        }
518
519        processing
520            .operations
521            .into_iter()
522            .for_each(|op| self.compile_operation(&mut instructions, op.operation, op.out, scope));
523
524        instructions
525    }
526
527    fn compile_operation(
528        &mut self,
529        instructions: &mut Vec<wgsl::Instruction>,
530        operation: cube::Operation,
531        out: Option<cube::Variable>,
532        scope: &mut cube::Scope,
533    ) {
534        match operation {
535            cube::Operation::Copy(variable) => instructions.push(wgsl::Instruction::Assign {
536                input: self.compile_variable(variable),
537                out: self.compile_variable(out.unwrap()),
538            }),
539            cube::Operation::Operator(op) => self.compile_instruction(op, out, instructions, scope),
540            cube::Operation::Atomic(op) => instructions.push(self.compile_atomic(op, out)),
541            cube::Operation::Metadata(op) => instructions.push(self.compile_metadata(op, out)),
542            cube::Operation::Branch(val) => self.compile_branch(instructions, val),
543            cube::Operation::Synchronization(val) => {
544                self.compile_synchronization(instructions, val)
545            }
546            cube::Operation::Plane(op) => self.compile_subgroup(instructions, op, out),
547            cube::Operation::CoopMma(_) => {
548                panic!("Cooperative matrix-multiply and accumulate isn't supported on wgpu.")
549            }
550            cube::Operation::NonSemantic(cube::NonSemantic::Comment { content }) => {
551                self.compile_comment(instructions, content)
552            }
553            // No good way to attach debug info
554            cube::Operation::NonSemantic(_) => {}
555        }
556    }
557
558    fn compile_subgroup(
559        &mut self,
560        instructions: &mut Vec<wgsl::Instruction>,
561        subgroup: cube::Plane,
562        out: Option<cube::Variable>,
563    ) {
564        self.subgroup_instructions_used = true;
565
566        let out = out.unwrap();
567        let op = match subgroup {
568            cube::Plane::Elect => Subgroup::Elect {
569                out: self.compile_variable(out),
570            },
571            cube::Plane::All(op) => Subgroup::All {
572                input: self.compile_variable(op.input),
573                out: self.compile_variable(out),
574            },
575            cube::Plane::Any(op) => Subgroup::Any {
576                input: self.compile_variable(op.input),
577                out: self.compile_variable(out),
578            },
579            cube::Plane::Broadcast(op) => Subgroup::Broadcast {
580                lhs: self.compile_variable(op.lhs),
581                rhs: self.compile_variable(op.rhs),
582                out: self.compile_variable(out),
583            },
584            cube::Plane::Sum(op) => Subgroup::Sum {
585                input: self.compile_variable(op.input),
586                out: self.compile_variable(out),
587            },
588            cube::Plane::Prod(op) => Subgroup::Prod {
589                input: self.compile_variable(op.input),
590                out: self.compile_variable(out),
591            },
592            cube::Plane::Min(op) => Subgroup::Min {
593                input: self.compile_variable(op.input),
594                out: self.compile_variable(out),
595            },
596            cube::Plane::Max(op) => Subgroup::Max {
597                input: self.compile_variable(op.input),
598                out: self.compile_variable(out),
599            },
600        };
601
602        instructions.push(wgsl::Instruction::Subgroup(op));
603    }
604
605    fn compile_branch(&mut self, instructions: &mut Vec<wgsl::Instruction>, branch: cube::Branch) {
606        match branch {
607            cube::Branch::If(mut op) => instructions.push(wgsl::Instruction::If {
608                cond: self.compile_variable(op.cond),
609                instructions: self.compile_scope(&mut op.scope),
610            }),
611            cube::Branch::IfElse(mut op) => instructions.push(wgsl::Instruction::IfElse {
612                cond: self.compile_variable(op.cond),
613                instructions_if: self.compile_scope(&mut op.scope_if),
614                instructions_else: self.compile_scope(&mut op.scope_else),
615            }),
616            cube::Branch::Switch(mut op) => instructions.push(wgsl::Instruction::Switch {
617                value: self.compile_variable(op.value),
618                instructions_default: self.compile_scope(&mut op.scope_default),
619                cases: op
620                    .cases
621                    .into_iter()
622                    .map(|(val, mut scope)| {
623                        (self.compile_variable(val), self.compile_scope(&mut scope))
624                    })
625                    .collect(),
626            }),
627            cube::Branch::Return => instructions.push(wgsl::Instruction::Return),
628            cube::Branch::Break => instructions.push(wgsl::Instruction::Break),
629            cube::Branch::RangeLoop(mut range_loop) => {
630                instructions.push(wgsl::Instruction::RangeLoop {
631                    i: self.compile_variable(range_loop.i),
632                    start: self.compile_variable(range_loop.start),
633                    end: self.compile_variable(range_loop.end),
634                    step: range_loop.step.map(|it| self.compile_variable(it)),
635                    inclusive: range_loop.inclusive,
636                    instructions: self.compile_scope(&mut range_loop.scope),
637                })
638            }
639            cube::Branch::Loop(mut op) => instructions.push(wgsl::Instruction::Loop {
640                instructions: self.compile_scope(&mut op.scope),
641            }),
642        };
643    }
644
645    fn compile_synchronization(
646        &mut self,
647        instructions: &mut Vec<wgsl::Instruction>,
648        synchronization: cube::Synchronization,
649    ) {
650        match synchronization {
651            cube::Synchronization::SyncUnits => {
652                instructions.push(wgsl::Instruction::WorkgroupBarrier)
653            }
654            cube::Synchronization::SyncStorage => {
655                instructions.push(wgsl::Instruction::StorageBarrier)
656            }
657        };
658    }
659
660    fn compile_comment(&mut self, instructions: &mut Vec<wgsl::Instruction>, content: String) {
661        instructions.push(wgsl::Instruction::Comment { content })
662    }
663
664    fn compile_metadata(
665        &mut self,
666        metadata: cube::Metadata,
667        out: Option<cube::Variable>,
668    ) -> wgsl::Instruction {
669        let out = out.unwrap();
670        match metadata {
671            cube::Metadata::Rank { var } => {
672                let position = self.ext_meta_pos(&var);
673                let offset = self.metadata.rank_index(position);
674                wgsl::Instruction::Metadata {
675                    out: self.compile_variable(out),
676                    info_offset: self.compile_variable(offset.into()),
677                }
678            }
679            cube::Metadata::Stride { dim, var } => {
680                let position = self.ext_meta_pos(&var);
681                let offset = self.metadata.stride_offset_index(position);
682                wgsl::Instruction::ExtendedMeta {
683                    info_offset: self.compile_variable(offset.into()),
684                    dim: self.compile_variable(dim),
685                    out: self.compile_variable(out),
686                }
687            }
688            cube::Metadata::Shape { dim, var } => {
689                let position = self.ext_meta_pos(&var);
690                let offset = self.metadata.shape_offset_index(position);
691                wgsl::Instruction::ExtendedMeta {
692                    info_offset: self.compile_variable(offset.into()),
693                    dim: self.compile_variable(dim),
694                    out: self.compile_variable(out),
695                }
696            }
697            cube::Metadata::Length { var } => match var.kind {
698                cube::VariableKind::GlobalInputArray(id) => {
699                    let offset = self.metadata.len_index(id);
700                    wgsl::Instruction::Metadata {
701                        out: self.compile_variable(out),
702                        info_offset: self.compile_variable(offset.into()),
703                    }
704                }
705                cube::VariableKind::GlobalOutputArray(id) => {
706                    let offset = self.metadata.len_index(self.num_inputs as u32 + id);
707                    wgsl::Instruction::Metadata {
708                        out: self.compile_variable(out),
709                        info_offset: self.compile_variable(offset.into()),
710                    }
711                }
712                _ => wgsl::Instruction::Length {
713                    var: self.compile_variable(var),
714                    out: self.compile_variable(out),
715                },
716            },
717            cube::Metadata::BufferLength { var } => match var.kind {
718                cube::VariableKind::GlobalInputArray(id) => {
719                    let offset = self.metadata.buffer_len_index(id);
720                    wgsl::Instruction::Metadata {
721                        out: self.compile_variable(out),
722                        info_offset: self.compile_variable(offset.into()),
723                    }
724                }
725                cube::VariableKind::GlobalOutputArray(id) => {
726                    let id = self.num_inputs as u32 + id;
727                    let offset = self.metadata.buffer_len_index(id);
728                    wgsl::Instruction::Metadata {
729                        out: self.compile_variable(out),
730                        info_offset: self.compile_variable(offset.into()),
731                    }
732                }
733                _ => wgsl::Instruction::Length {
734                    var: self.compile_variable(var),
735                    out: self.compile_variable(out),
736                },
737            },
738        }
739    }
740
741    fn compile_instruction(
742        &mut self,
743        value: cube::Operator,
744        out: Option<cube::Variable>,
745        instructions: &mut Vec<wgsl::Instruction>,
746        scope: &mut cube::Scope,
747    ) {
748        let out = out.unwrap();
749        match value {
750            cube::Operator::Max(op) => instructions.push(wgsl::Instruction::Max {
751                lhs: self.compile_variable(op.lhs),
752                rhs: self.compile_variable(op.rhs),
753                out: self.compile_variable(out),
754            }),
755            cube::Operator::Min(op) => instructions.push(wgsl::Instruction::Min {
756                lhs: self.compile_variable(op.lhs),
757                rhs: self.compile_variable(op.rhs),
758                out: self.compile_variable(out),
759            }),
760            cube::Operator::Add(op) => instructions.push(wgsl::Instruction::Add {
761                lhs: self.compile_variable(op.lhs),
762                rhs: self.compile_variable(op.rhs),
763                out: self.compile_variable(out),
764            }),
765            cube::Operator::Fma(op) => instructions.push(wgsl::Instruction::Fma {
766                a: self.compile_variable(op.a),
767                b: self.compile_variable(op.b),
768                c: self.compile_variable(op.c),
769                out: self.compile_variable(out),
770            }),
771            cube::Operator::Modulo(op) => instructions.push(wgsl::Instruction::Modulo {
772                lhs: self.compile_variable(op.lhs),
773                rhs: self.compile_variable(op.rhs),
774                out: self.compile_variable(out),
775            }),
776            cube::Operator::Sub(op) => instructions.push(wgsl::Instruction::Sub {
777                lhs: self.compile_variable(op.lhs),
778                rhs: self.compile_variable(op.rhs),
779                out: self.compile_variable(out),
780            }),
781            cube::Operator::Mul(op) => instructions.push(wgsl::Instruction::Mul {
782                lhs: self.compile_variable(op.lhs),
783                rhs: self.compile_variable(op.rhs),
784                out: self.compile_variable(out),
785            }),
786            cube::Operator::Div(op) => instructions.push(wgsl::Instruction::Div {
787                lhs: self.compile_variable(op.lhs),
788                rhs: self.compile_variable(op.rhs),
789                out: self.compile_variable(out),
790            }),
791            cube::Operator::Abs(op) => instructions.push(wgsl::Instruction::Abs {
792                input: self.compile_variable(op.input),
793                out: self.compile_variable(out),
794            }),
795            cube::Operator::Exp(op) => instructions.push(wgsl::Instruction::Exp {
796                input: self.compile_variable(op.input),
797                out: self.compile_variable(out),
798            }),
799            cube::Operator::Log(op) => instructions.push(wgsl::Instruction::Log {
800                input: self.compile_variable(op.input),
801                out: self.compile_variable(out),
802            }),
803            cube::Operator::Log1p(op) => instructions.push(wgsl::Instruction::Log1p {
804                input: self.compile_variable(op.input),
805                out: self.compile_variable(out),
806            }),
807            cube::Operator::Cos(op) => instructions.push(wgsl::Instruction::Cos {
808                input: self.compile_variable(op.input),
809                out: self.compile_variable(out),
810            }),
811            cube::Operator::Sin(op) => instructions.push(wgsl::Instruction::Sin {
812                input: self.compile_variable(op.input),
813                out: self.compile_variable(out),
814            }),
815            cube::Operator::Tanh(op) => instructions.push(wgsl::Instruction::Tanh {
816                input: self.compile_variable(op.input),
817                out: self.compile_variable(out),
818            }),
819            cube::Operator::Powf(op) => instructions.push(wgsl::Instruction::Powf {
820                lhs: self.compile_variable(op.lhs),
821                rhs: self.compile_variable(op.rhs),
822                out: self.compile_variable(out),
823            }),
824            cube::Operator::Sqrt(op) => instructions.push(wgsl::Instruction::Sqrt {
825                input: self.compile_variable(op.input),
826                out: self.compile_variable(out),
827            }),
828            cube::Operator::Round(op) => instructions.push(wgsl::Instruction::Round {
829                input: self.compile_variable(op.input),
830                out: self.compile_variable(out),
831            }),
832            cube::Operator::Floor(op) => instructions.push(wgsl::Instruction::Floor {
833                input: self.compile_variable(op.input),
834                out: self.compile_variable(out),
835            }),
836            cube::Operator::Ceil(op) => instructions.push(wgsl::Instruction::Ceil {
837                input: self.compile_variable(op.input),
838                out: self.compile_variable(out),
839            }),
840            cube::Operator::Erf(op) => instructions.push(wgsl::Instruction::Erf {
841                input: self.compile_variable(op.input),
842                out: self.compile_variable(out),
843            }),
844            cube::Operator::Recip(op) => instructions.push(wgsl::Instruction::Recip {
845                input: self.compile_variable(op.input),
846                out: self.compile_variable(out),
847            }),
848            cube::Operator::Equal(op) => instructions.push(wgsl::Instruction::Equal {
849                lhs: self.compile_variable(op.lhs),
850                rhs: self.compile_variable(op.rhs),
851                out: self.compile_variable(out),
852            }),
853            cube::Operator::Lower(op) => instructions.push(wgsl::Instruction::Lower {
854                lhs: self.compile_variable(op.lhs),
855                rhs: self.compile_variable(op.rhs),
856                out: self.compile_variable(out),
857            }),
858            cube::Operator::Clamp(op) => instructions.push(wgsl::Instruction::Clamp {
859                input: self.compile_variable(op.input),
860                min_value: self.compile_variable(op.min_value),
861                max_value: self.compile_variable(op.max_value),
862                out: self.compile_variable(out),
863            }),
864            cube::Operator::Greater(op) => instructions.push(wgsl::Instruction::Greater {
865                lhs: self.compile_variable(op.lhs),
866                rhs: self.compile_variable(op.rhs),
867                out: self.compile_variable(out),
868            }),
869            cube::Operator::LowerEqual(op) => instructions.push(wgsl::Instruction::LowerEqual {
870                lhs: self.compile_variable(op.lhs),
871                rhs: self.compile_variable(op.rhs),
872                out: self.compile_variable(out),
873            }),
874            cube::Operator::GreaterEqual(op) => {
875                instructions.push(wgsl::Instruction::GreaterEqual {
876                    lhs: self.compile_variable(op.lhs),
877                    rhs: self.compile_variable(op.rhs),
878                    out: self.compile_variable(out),
879                })
880            }
881            cube::Operator::NotEqual(op) => instructions.push(wgsl::Instruction::NotEqual {
882                lhs: self.compile_variable(op.lhs),
883                rhs: self.compile_variable(op.rhs),
884                out: self.compile_variable(out),
885            }),
886            cube::Operator::Cast(op) => instructions.push(wgsl::Instruction::Assign {
887                input: self.compile_variable(op.input),
888                out: self.compile_variable(out),
889            }),
890            cube::Operator::Index(op) => {
891                if matches!(self.strategy, ExecutionMode::Checked) && op.lhs.has_length() {
892                    let lhs = op.lhs;
893                    let rhs = op.rhs;
894                    let array_len =
895                        scope.create_local(cube::Item::new(cube::Elem::UInt(cube::UIntKind::U32)));
896
897                    instructions.extend(self.compile_scope(scope));
898
899                    let length = match lhs.has_buffer_length() {
900                        true => cube::Metadata::BufferLength { var: lhs },
901                        false => cube::Metadata::Length { var: lhs },
902                    };
903                    instructions.push(self.compile_metadata(length, Some(array_len)));
904                    instructions.push(wgsl::Instruction::CheckedIndex {
905                        len: self.compile_variable(array_len),
906                        lhs: self.compile_variable(lhs),
907                        rhs: self.compile_variable(rhs),
908                        out: self.compile_variable(out),
909                    });
910                } else {
911                    instructions.push(wgsl::Instruction::Index {
912                        lhs: self.compile_variable(op.lhs),
913                        rhs: self.compile_variable(op.rhs),
914                        out: self.compile_variable(out),
915                    });
916                }
917            }
918            cube::Operator::UncheckedIndex(op) => instructions.push(wgsl::Instruction::Index {
919                lhs: self.compile_variable(op.lhs),
920                rhs: self.compile_variable(op.rhs),
921                out: self.compile_variable(out),
922            }),
923            cube::Operator::IndexAssign(op) => {
924                if let ExecutionMode::Checked = self.strategy {
925                    if out.has_length() {
926                        expand_checked_index_assign(scope, op.lhs, op.rhs, out);
927                        instructions.extend(self.compile_scope(scope));
928                        return;
929                    }
930                };
931                instructions.push(wgsl::Instruction::IndexAssign {
932                    lhs: self.compile_variable(op.lhs),
933                    rhs: self.compile_variable(op.rhs),
934                    out: self.compile_variable(out),
935                })
936            }
937            cube::Operator::UncheckedIndexAssign(op) => {
938                instructions.push(wgsl::Instruction::IndexAssign {
939                    lhs: self.compile_variable(op.lhs),
940                    rhs: self.compile_variable(op.rhs),
941                    out: self.compile_variable(out),
942                })
943            }
944            cube::Operator::And(op) => instructions.push(wgsl::Instruction::And {
945                lhs: self.compile_variable(op.lhs),
946                rhs: self.compile_variable(op.rhs),
947                out: self.compile_variable(out),
948            }),
949            cube::Operator::Or(op) => instructions.push(wgsl::Instruction::Or {
950                lhs: self.compile_variable(op.lhs),
951                rhs: self.compile_variable(op.rhs),
952                out: self.compile_variable(out),
953            }),
954            cube::Operator::Not(op) => instructions.push(wgsl::Instruction::Not {
955                input: self.compile_variable(op.input),
956                out: self.compile_variable(out),
957            }),
958            cube::Operator::BitwiseOr(op) => instructions.push(wgsl::Instruction::BitwiseOr {
959                lhs: self.compile_variable(op.lhs),
960                rhs: self.compile_variable(op.rhs),
961                out: self.compile_variable(out),
962            }),
963            cube::Operator::BitwiseAnd(op) => instructions.push(wgsl::Instruction::BitwiseAnd {
964                lhs: self.compile_variable(op.lhs),
965                rhs: self.compile_variable(op.rhs),
966                out: self.compile_variable(out),
967            }),
968            cube::Operator::BitwiseXor(op) => instructions.push(wgsl::Instruction::BitwiseXor {
969                lhs: self.compile_variable(op.lhs),
970                rhs: self.compile_variable(op.rhs),
971                out: self.compile_variable(out),
972            }),
973            cube::Operator::CountOnes(op) => instructions.push(wgsl::Instruction::CountBits {
974                input: self.compile_variable(op.input),
975                out: self.compile_variable(out),
976            }),
977            cube::Operator::ReverseBits(op) => instructions.push(wgsl::Instruction::ReverseBits {
978                input: self.compile_variable(op.input),
979                out: self.compile_variable(out),
980            }),
981            cube::Operator::ShiftLeft(op) => instructions.push(wgsl::Instruction::ShiftLeft {
982                lhs: self.compile_variable(op.lhs),
983                rhs: self.compile_variable(op.rhs),
984                out: self.compile_variable(out),
985            }),
986            cube::Operator::ShiftRight(op) => instructions.push(wgsl::Instruction::ShiftRight {
987                lhs: self.compile_variable(op.lhs),
988                rhs: self.compile_variable(op.rhs),
989                out: self.compile_variable(out),
990            }),
991            cube::Operator::Remainder(op) => instructions.push(wgsl::Instruction::Remainder {
992                lhs: self.compile_variable(op.lhs),
993                rhs: self.compile_variable(op.rhs),
994                out: self.compile_variable(out),
995            }),
996            cube::Operator::Slice(op) => {
997                if matches!(self.strategy, ExecutionMode::Checked) && op.input.has_length() {
998                    let input = op.input;
999                    let input_len = scope
1000                        .create_local_mut(cube::Item::new(cube::Elem::UInt(cube::UIntKind::U32)));
1001                    instructions.extend(self.compile_scope(scope));
1002
1003                    let length = match input.has_buffer_length() {
1004                        true => cube::Metadata::BufferLength { var: input },
1005                        false => cube::Metadata::Length { var: input },
1006                    };
1007
1008                    instructions.push(self.compile_metadata(length, Some(input_len)));
1009                    instructions.push(wgsl::Instruction::CheckedSlice {
1010                        input: self.compile_variable(input),
1011                        start: self.compile_variable(op.start),
1012                        end: self.compile_variable(op.end),
1013                        out: self.compile_variable(out),
1014                        len: self.compile_variable(input_len),
1015                    });
1016                } else {
1017                    instructions.push(wgsl::Instruction::Slice {
1018                        input: self.compile_variable(op.input),
1019                        start: self.compile_variable(op.start),
1020                        end: self.compile_variable(op.end),
1021                        out: self.compile_variable(out),
1022                    });
1023                }
1024            }
1025            cube::Operator::Bitcast(op) => instructions.push(wgsl::Instruction::Bitcast {
1026                input: self.compile_variable(op.input),
1027                out: self.compile_variable(out),
1028            }),
1029
1030            cube::Operator::Neg(op) => instructions.push(wgsl::Instruction::Negate {
1031                input: self.compile_variable(op.input),
1032                out: self.compile_variable(out),
1033            }),
1034            cube::Operator::Magnitude(op) => instructions.push(wgsl::Instruction::Magnitude {
1035                input: self.compile_variable(op.input),
1036                out: self.compile_variable(out),
1037            }),
1038            cube::Operator::Normalize(op) => instructions.push(wgsl::Instruction::Normalize {
1039                input: self.compile_variable(op.input),
1040                out: self.compile_variable(out),
1041            }),
1042            cube::Operator::Dot(op) => instructions.push(wgsl::Instruction::Dot {
1043                lhs: self.compile_variable(op.lhs),
1044                rhs: self.compile_variable(op.rhs),
1045                out: self.compile_variable(out),
1046            }),
1047            cube::Operator::InitLine(op) => instructions.push(wgsl::Instruction::VecInit {
1048                inputs: op
1049                    .inputs
1050                    .into_iter()
1051                    .map(|var| self.compile_variable(var))
1052                    .collect(),
1053                out: self.compile_variable(out),
1054            }),
1055            cube::Operator::CopyMemory(op) => instructions.push(wgsl::Instruction::Copy {
1056                input: self.compile_variable(op.input),
1057                in_index: self.compile_variable(op.in_index),
1058                out: self.compile_variable(out),
1059                out_index: self.compile_variable(op.out_index),
1060            }),
1061            cube::Operator::CopyMemoryBulk(op) => instructions.push(wgsl::Instruction::CopyBulk {
1062                input: self.compile_variable(op.input),
1063                in_index: self.compile_variable(op.in_index),
1064                out: self.compile_variable(out),
1065                out_index: self.compile_variable(op.out_index),
1066                len: op.len,
1067            }),
1068            cube::Operator::Select(op) => instructions.push(wgsl::Instruction::Select {
1069                cond: self.compile_variable(op.cond),
1070                then: self.compile_variable(op.then),
1071                or_else: self.compile_variable(op.or_else),
1072                out: self.compile_variable(out),
1073            }),
1074        }
1075    }
1076
1077    fn compile_atomic(
1078        &mut self,
1079        atomic: cube::AtomicOp,
1080        out: Option<cube::Variable>,
1081    ) -> wgsl::Instruction {
1082        let out = out.unwrap();
1083        match atomic {
1084            cube::AtomicOp::Add(op) => wgsl::Instruction::AtomicAdd {
1085                lhs: self.compile_variable(op.lhs),
1086                rhs: self.compile_variable(op.rhs),
1087                out: self.compile_variable(out),
1088            },
1089            cube::AtomicOp::Sub(op) => wgsl::Instruction::AtomicSub {
1090                lhs: self.compile_variable(op.lhs),
1091                rhs: self.compile_variable(op.rhs),
1092                out: self.compile_variable(out),
1093            },
1094            cube::AtomicOp::Max(op) => wgsl::Instruction::AtomicMax {
1095                lhs: self.compile_variable(op.lhs),
1096                rhs: self.compile_variable(op.rhs),
1097                out: self.compile_variable(out),
1098            },
1099            cube::AtomicOp::Min(op) => wgsl::Instruction::AtomicMin {
1100                lhs: self.compile_variable(op.lhs),
1101                rhs: self.compile_variable(op.rhs),
1102                out: self.compile_variable(out),
1103            },
1104            cube::AtomicOp::And(op) => wgsl::Instruction::AtomicAnd {
1105                lhs: self.compile_variable(op.lhs),
1106                rhs: self.compile_variable(op.rhs),
1107                out: self.compile_variable(out),
1108            },
1109            cube::AtomicOp::Or(op) => wgsl::Instruction::AtomicOr {
1110                lhs: self.compile_variable(op.lhs),
1111                rhs: self.compile_variable(op.rhs),
1112                out: self.compile_variable(out),
1113            },
1114            cube::AtomicOp::Xor(op) => wgsl::Instruction::AtomicXor {
1115                lhs: self.compile_variable(op.lhs),
1116                rhs: self.compile_variable(op.rhs),
1117                out: self.compile_variable(out),
1118            },
1119            cube::AtomicOp::Load(op) => wgsl::Instruction::AtomicLoad {
1120                input: self.compile_variable(op.input),
1121                out: self.compile_variable(out),
1122            },
1123            cube::AtomicOp::Store(op) => wgsl::Instruction::AtomicStore {
1124                input: self.compile_variable(op.input),
1125                out: self.compile_variable(out),
1126            },
1127            cube::AtomicOp::Swap(op) => wgsl::Instruction::AtomicSwap {
1128                lhs: self.compile_variable(op.lhs),
1129                rhs: self.compile_variable(op.rhs),
1130                out: self.compile_variable(out),
1131            },
1132            cube::AtomicOp::CompareAndSwap(op) => wgsl::Instruction::AtomicCompareExchangeWeak {
1133                lhs: self.compile_variable(op.input),
1134                cmp: self.compile_variable(op.cmp),
1135                value: self.compile_variable(op.val),
1136                out: self.compile_variable(out),
1137            },
1138        }
1139    }
1140
1141    fn compile_location(value: cube::Location) -> wgsl::Location {
1142        match value {
1143            cube::Location::Storage => wgsl::Location::Storage,
1144            cube::Location::Cube => wgsl::Location::Workgroup,
1145        }
1146    }
1147
1148    fn compile_visibility(value: cube::Visibility) -> wgsl::Visibility {
1149        match value {
1150            cube::Visibility::Read => wgsl::Visibility::Read,
1151            cube::Visibility::ReadWrite => wgsl::Visibility::ReadWrite,
1152        }
1153    }
1154
1155    fn compile_binding(value: cube::Binding) -> wgsl::Binding {
1156        wgsl::Binding {
1157            visibility: Self::compile_visibility(value.visibility),
1158            location: Self::compile_location(value.location),
1159            item: Self::compile_item(value.item),
1160            size: value.size,
1161        }
1162    }
1163}
1164
1165fn register_extensions(instructions: &[wgsl::Instruction]) -> Vec<wgsl::Extension> {
1166    let mut extensions = Vec::new();
1167
1168    let mut register_extension = |extension: wgsl::Extension| {
1169        if !extensions.contains(&extension) {
1170            extensions.push(extension);
1171        }
1172    };
1173
1174    // Since not all instructions are native to WGSL, we need to add the custom ones.
1175    for instruction in instructions {
1176        match instruction {
1177            wgsl::Instruction::Powf { lhs: _, rhs, out } => {
1178                register_extension(wgsl::Extension::PowfPrimitive(out.item()));
1179
1180                if rhs.is_always_scalar() || rhs.item().vectorization_factor() == 1 {
1181                    register_extension(wgsl::Extension::PowfScalar(out.item()));
1182                } else {
1183                    register_extension(wgsl::Extension::Powf(out.item()));
1184                }
1185            }
1186            wgsl::Instruction::Erf { input, out: _ } => {
1187                register_extension(wgsl::Extension::Erf(input.item()));
1188            }
1189            #[cfg(target_os = "macos")]
1190            wgsl::Instruction::Tanh { input, out: _ } => {
1191                register_extension(wgsl::Extension::SafeTanh(input.item()))
1192            }
1193            wgsl::Instruction::If { instructions, .. } => {
1194                for extension in register_extensions(instructions) {
1195                    register_extension(extension);
1196                }
1197            }
1198            wgsl::Instruction::IfElse {
1199                instructions_if,
1200                instructions_else,
1201                ..
1202            } => {
1203                for extension in register_extensions(instructions_if) {
1204                    register_extension(extension);
1205                }
1206                for extension in register_extensions(instructions_else) {
1207                    register_extension(extension);
1208                }
1209            }
1210            wgsl::Instruction::Loop { instructions } => {
1211                for extension in register_extensions(instructions) {
1212                    register_extension(extension);
1213                }
1214            }
1215            wgsl::Instruction::RangeLoop { instructions, .. } => {
1216                for extension in register_extensions(instructions) {
1217                    register_extension(extension);
1218                }
1219            }
1220            _ => {}
1221        }
1222    }
1223
1224    extensions
1225}