cubecl_wgpu/compiler/wgsl/
compiler.rs

1use super::Subgroup;
2use super::{ConstantArray, shader::ComputeShader};
3use super::{Item, LocalArray, SharedMemory};
4use crate::compiler::wgsl;
5
6use cubecl_common::ExecutionMode;
7use cubecl_core::post_processing::{
8    checked_io::CheckedIoProcessor, saturating::SaturatingArithmeticProcessor,
9};
10use cubecl_core::prelude::*;
11use cubecl_core::{
12    Metadata, WgpuCompilationOptions, compute,
13    ir::{self as cube, Scope},
14    prelude::expand_erf,
15};
16use cubecl_core::{
17    ir::{ConstantScalarValue, Processor, UIntKind},
18    post_processing::unroll::UnrollProcessor,
19};
20
/// Widest line (vector) size emitted by this compiler.
///
/// Buffer bindings and processed variables whose line size exceeds this value are
/// re-lined to it, and the same value is handed to the unroll processor in
/// `compile_scope`. `compile_type` only maps line sizes 2, 3 and 4 to WGSL vector items.
pub const MAX_LINE_SIZE: u32 = 4;
22
23/// Wgsl Compiler.
24#[derive(Clone, Default)]
25pub struct WgslCompiler {
26    metadata: Metadata,
27    ext_meta_pos: Vec<u32>,
28    local_invocation_index: bool,
29    local_invocation_id: bool,
30    // TODO: possible cleanup, this bool seems to not be used
31    global_invocation_id: bool,
32    workgroup_id: bool,
33    subgroup_size: bool,
34    subgroup_invocation_id: bool,
35    id: bool,
36    num_workgroups: bool,
37    workgroup_id_no_axis: bool,
38    workgroup_size_no_axis: bool,
39    num_workgroup_no_axis: bool,
40    shared_memories: Vec<SharedMemory>,
41    const_arrays: Vec<ConstantArray>,
42    local_arrays: Vec<LocalArray>,
43    #[allow(dead_code)]
44    compilation_options: WgpuCompilationOptions,
45    strategy: ExecutionMode,
46    subgroup_instructions_used: bool,
47    f16_used: bool,
48}
49
50impl core::fmt::Debug for WgslCompiler {
51    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52        f.write_str("WgslCompiler")
53    }
54}
55
56impl cubecl_core::Compiler for WgslCompiler {
57    type Representation = ComputeShader;
58    type CompilationOptions = WgpuCompilationOptions;
59
60    fn compile(
61        &mut self,
62        shader: compute::KernelDefinition,
63        compilation_options: &Self::CompilationOptions,
64        mode: ExecutionMode,
65    ) -> Self::Representation {
66        self.compilation_options = compilation_options.clone();
67        self.compile_shader(shader, mode)
68    }
69
70    fn elem_size(&self, elem: cube::ElemType) -> usize {
71        elem.size()
72    }
73
74    fn extension(&self) -> &'static str {
75        "wgsl"
76    }
77}
78
79impl WgslCompiler {
80    fn compile_shader(
81        &mut self,
82        mut value: compute::KernelDefinition,
83        mode: ExecutionMode,
84    ) -> wgsl::ComputeShader {
85        self.strategy = mode;
86
87        let num_meta = value.buffers.len();
88
89        self.ext_meta_pos = Vec::new();
90        let mut num_ext = 0;
91
92        for binding in value.buffers.iter() {
93            self.ext_meta_pos.push(num_ext);
94            if binding.has_extended_meta {
95                num_ext += 1;
96            }
97        }
98
99        self.metadata = Metadata::new(num_meta as u32, num_ext);
100
101        let instructions = self.compile_scope(&mut value.body);
102        let extensions = register_extensions(&instructions);
103        let body = wgsl::Body {
104            instructions,
105            id: self.id,
106        };
107
108        wgsl::ComputeShader {
109            buffers: value
110                .buffers
111                .into_iter()
112                .map(|mut it| {
113                    // This is safe when combined with the unroll transform that adjusts all indices.
114                    // Must not be used alone
115                    if it.ty.line_size() > MAX_LINE_SIZE {
116                        it.ty = it.ty.line(MAX_LINE_SIZE);
117                    }
118                    self.compile_binding(it)
119                })
120                .collect(),
121            scalars: value
122                .scalars
123                .into_iter()
124                .map(|binding| (self.compile_storage_type(binding.ty), binding.count))
125                .collect(),
126            shared_memories: self.shared_memories.clone(),
127            constant_arrays: self.const_arrays.clone(),
128            local_arrays: self.local_arrays.clone(),
129            has_metadata: self.metadata.static_len() > 0,
130            workgroup_size: value.cube_dim,
131            global_invocation_id: self.global_invocation_id || self.id,
132            local_invocation_index: self.local_invocation_index,
133            local_invocation_id: self.local_invocation_id,
134            num_workgroups: self.id
135                || self.num_workgroups
136                || self.num_workgroup_no_axis
137                || self.workgroup_id_no_axis,
138            workgroup_id: self.workgroup_id || self.workgroup_id_no_axis,
139            subgroup_size: self.subgroup_size,
140            subgroup_invocation_id: self.subgroup_invocation_id,
141            body,
142            extensions,
143            num_workgroups_no_axis: self.num_workgroup_no_axis,
144            workgroup_id_no_axis: self.workgroup_id_no_axis,
145            workgroup_size_no_axis: self.workgroup_size_no_axis,
146            subgroup_instructions_used: self.subgroup_instructions_used,
147            f16_used: self.f16_used,
148            kernel_name: value.options.kernel_name,
149        }
150    }
151
152    fn compile_type(&mut self, item: cube::Type) -> Item {
153        match item {
154            cube::Type::Scalar(ty) => wgsl::Item::Scalar(self.compile_storage_type(ty)),
155            cube::Type::Line(ty, size) => {
156                let elem = self.compile_storage_type(ty);
157                match size {
158                    2 => wgsl::Item::Vec2(elem),
159                    3 => wgsl::Item::Vec3(elem),
160                    4 => wgsl::Item::Vec4(elem),
161                    _ => panic!("Unsupported vectorizations scheme {:?}", item.line_size()),
162                }
163            }
164            cube::Type::Semantic(_) => unimplemented!("Can't compile semantic type"),
165        }
166    }
167
168    fn compile_storage_type(&mut self, ty: cube::StorageType) -> wgsl::Elem {
169        match ty {
170            cube::StorageType::Scalar(ty) => self.compile_elem(ty),
171            cube::StorageType::Atomic(ty) => match ty {
172                cube::ElemType::Float(i) => match i {
173                    cube::FloatKind::F32 => wgsl::Elem::AtomicF32,
174                    kind => panic!("atomic<{kind:?}> is not a valid WgpuElement"),
175                },
176                cube::ElemType::Int(i) => match i {
177                    cube::IntKind::I32 => wgsl::Elem::AtomicI32,
178                    kind => panic!("atomic<{kind:?}> is not a valid WgpuElement"),
179                },
180                cube::ElemType::UInt(kind) => match kind {
181                    cube::UIntKind::U32 => wgsl::Elem::AtomicU32,
182                    kind => panic!("{kind:?} is not a valid WgpuElement"),
183                },
184                other => panic!("{other:?} is not a valid WgpuElement"),
185            },
186            cube::StorageType::Packed(_, _) => {
187                unimplemented!("Packed types not yet supported in WGSL")
188            }
189        }
190    }
191
192    fn compile_elem(&mut self, value: cube::ElemType) -> wgsl::Elem {
193        match value {
194            cube::ElemType::Float(f) => match f {
195                cube::FloatKind::E2M1
196                | cube::FloatKind::E2M3
197                | cube::FloatKind::E3M2
198                | cube::FloatKind::E4M3
199                | cube::FloatKind::E5M2
200                | cube::FloatKind::UE8M0 => panic!("Minifloat is not a valid WgpuElement"),
201                cube::FloatKind::F16 => {
202                    self.f16_used = true;
203                    wgsl::Elem::F16
204                }
205                cube::FloatKind::BF16 => panic!("bf16 is not a valid WgpuElement"),
206                cube::FloatKind::TF32 => panic!("tf32 is not a valid WgpuElement"),
207                cube::FloatKind::Flex32 => wgsl::Elem::F32,
208                cube::FloatKind::F32 => wgsl::Elem::F32,
209                cube::FloatKind::F64 => wgsl::Elem::F64,
210            },
211            cube::ElemType::Int(i) => match i {
212                cube::IntKind::I32 => wgsl::Elem::I32,
213                cube::IntKind::I64 => wgsl::Elem::I64,
214                kind => panic!("{kind:?} is not a valid WgpuElement"),
215            },
216            cube::ElemType::UInt(kind) => match kind {
217                cube::UIntKind::U32 => wgsl::Elem::U32,
218                cube::UIntKind::U64 => wgsl::Elem::U64,
219                kind => panic!("{kind:?} is not a valid WgpuElement"),
220            },
221            cube::ElemType::Bool => wgsl::Elem::Bool,
222        }
223    }
224
225    fn ext_meta_pos(&self, var: &cube::Variable) -> u32 {
226        let pos = var.index().expect("Variable should have index");
227        self.ext_meta_pos[pos as usize]
228    }
229
230    pub(crate) fn compile_variable(&mut self, value: cube::Variable) -> wgsl::Variable {
231        let item = value.ty;
232        match value.kind {
233            cube::VariableKind::GlobalInputArray(id) => {
234                wgsl::Variable::GlobalInputArray(id, self.compile_type(item))
235            }
236            cube::VariableKind::GlobalScalar(id) => {
237                wgsl::Variable::GlobalScalar(id, self.compile_storage_type(item.storage_type()))
238            }
239            cube::VariableKind::LocalMut { id } | cube::VariableKind::Versioned { id, .. } => {
240                wgsl::Variable::LocalMut {
241                    id,
242                    item: self.compile_type(item),
243                }
244            }
245            cube::VariableKind::LocalConst { id } => wgsl::Variable::LocalConst {
246                id,
247                item: self.compile_type(item),
248            },
249            cube::VariableKind::GlobalOutputArray(id) => {
250                wgsl::Variable::GlobalOutputArray(id, self.compile_type(item))
251            }
252            cube::VariableKind::ConstantScalar(value) => {
253                wgsl::Variable::ConstantScalar(value, self.compile_elem(value.elem_type()))
254            }
255            cube::VariableKind::SharedMemory {
256                id,
257                length,
258                unroll_factor,
259                alignment,
260            } => {
261                let item = self.compile_type(item);
262                if !self.shared_memories.iter().any(|s| s.index == id) {
263                    self.shared_memories.push(SharedMemory::new(
264                        id,
265                        item,
266                        length * unroll_factor,
267                        alignment,
268                    ));
269                }
270                wgsl::Variable::SharedMemory(id, item, length)
271            }
272            cube::VariableKind::ConstantArray { id, length, .. } => {
273                let item = self.compile_type(item);
274                wgsl::Variable::ConstantArray(id, item, length)
275            }
276            cube::VariableKind::LocalArray {
277                id,
278                length,
279                unroll_factor,
280            } => {
281                let item = self.compile_type(item);
282                if !self.local_arrays.iter().any(|s| s.index == id) {
283                    self.local_arrays
284                        .push(LocalArray::new(id, item, length * unroll_factor));
285                }
286                wgsl::Variable::LocalArray(id, item, length)
287            }
288            cube::VariableKind::Builtin(builtin) => match builtin {
289                cube::Builtin::AbsolutePos => {
290                    self.id = true;
291                    wgsl::Variable::Id
292                }
293                cube::Builtin::UnitPos => {
294                    self.local_invocation_index = true;
295                    wgsl::Variable::LocalInvocationIndex
296                }
297                cube::Builtin::UnitPosX => {
298                    self.local_invocation_id = true;
299                    wgsl::Variable::LocalInvocationIdX
300                }
301                cube::Builtin::UnitPosY => {
302                    self.local_invocation_id = true;
303                    wgsl::Variable::LocalInvocationIdY
304                }
305                cube::Builtin::UnitPosZ => {
306                    self.local_invocation_id = true;
307                    wgsl::Variable::LocalInvocationIdZ
308                }
309                cube::Builtin::CubePosX => {
310                    self.workgroup_id = true;
311                    wgsl::Variable::WorkgroupIdX
312                }
313                cube::Builtin::CubePosY => {
314                    self.workgroup_id = true;
315                    wgsl::Variable::WorkgroupIdY
316                }
317                cube::Builtin::CubePosZ => {
318                    self.workgroup_id = true;
319                    wgsl::Variable::WorkgroupIdZ
320                }
321                cube::Builtin::CubePosCluster
322                | cube::Builtin::CubePosClusterX
323                | cube::Builtin::CubePosClusterY
324                | cube::Builtin::CubePosClusterZ => self.constant_var(1),
325                cube::Builtin::AbsolutePosX => {
326                    self.global_invocation_id = true;
327                    wgsl::Variable::GlobalInvocationIdX
328                }
329                cube::Builtin::AbsolutePosY => {
330                    self.global_invocation_id = true;
331                    wgsl::Variable::GlobalInvocationIdY
332                }
333                cube::Builtin::AbsolutePosZ => {
334                    self.global_invocation_id = true;
335                    wgsl::Variable::GlobalInvocationIdZ
336                }
337                cube::Builtin::CubeDimX => wgsl::Variable::WorkgroupSizeX,
338                cube::Builtin::CubeDimY => wgsl::Variable::WorkgroupSizeY,
339                cube::Builtin::CubeDimZ => wgsl::Variable::WorkgroupSizeZ,
340                cube::Builtin::CubeClusterDim
341                | cube::Builtin::CubeClusterDimX
342                | cube::Builtin::CubeClusterDimY
343                | cube::Builtin::CubeClusterDimZ => self.constant_var(1),
344                cube::Builtin::CubeCountX => {
345                    self.num_workgroups = true;
346                    wgsl::Variable::NumWorkgroupsX
347                }
348                cube::Builtin::CubeCountY => {
349                    self.num_workgroups = true;
350                    wgsl::Variable::NumWorkgroupsY
351                }
352                cube::Builtin::CubeCountZ => {
353                    self.num_workgroups = true;
354                    wgsl::Variable::NumWorkgroupsZ
355                }
356                cube::Builtin::CubePos => {
357                    self.workgroup_id_no_axis = true;
358                    wgsl::Variable::WorkgroupId
359                }
360                cube::Builtin::CubeDim => {
361                    self.workgroup_size_no_axis = true;
362                    wgsl::Variable::WorkgroupSize
363                }
364                cube::Builtin::CubeCount => {
365                    self.num_workgroup_no_axis = true;
366                    wgsl::Variable::NumWorkgroups
367                }
368                cube::Builtin::PlaneDim => {
369                    self.subgroup_size = true;
370                    wgsl::Variable::SubgroupSize
371                }
372                cube::Builtin::UnitPosPlane => {
373                    self.subgroup_invocation_id = true;
374                    wgsl::Variable::SubgroupInvocationId
375                }
376            },
377            cube::VariableKind::Matrix { .. } => {
378                panic!("Cooperative matrix-multiply and accumulate not supported.")
379            }
380            cube::VariableKind::Pipeline { .. } => {
381                panic!("Pipeline not supported.")
382            }
383            cube::VariableKind::Barrier { .. } => {
384                panic!("Barrier not supported.")
385            }
386            cube::VariableKind::TensorMapInput(_) => panic!("Tensor map not supported."),
387            cube::VariableKind::TensorMapOutput(_) => panic!("Tensor map not supported."),
388        }
389    }
390
391    fn constant_var(&mut self, value: u32) -> wgsl::Variable {
392        let var = cube::Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U32));
393        self.compile_variable(var)
394    }
395
396    fn compile_scope(&mut self, scope: &mut cube::Scope) -> Vec<wgsl::Instruction> {
397        let mut instructions = Vec::new();
398
399        let const_arrays = scope
400            .const_arrays
401            .drain(..)
402            .map(|(var, values)| ConstantArray {
403                index: var.index().unwrap(),
404                item: self.compile_type(var.ty),
405                size: values.len() as u32,
406                values: values
407                    .into_iter()
408                    .map(|val| self.compile_variable(val))
409                    .collect(),
410            })
411            .collect::<Vec<_>>();
412        self.const_arrays.extend(const_arrays);
413
414        let checked_io: Box<dyn Processor> = Box::new(CheckedIoProcessor::new(self.strategy));
415        let unroll = Box::new(UnrollProcessor::new(MAX_LINE_SIZE));
416        let saturating = Box::new(SaturatingArithmeticProcessor::new(true));
417        let processing = scope.process([&*unroll, &*checked_io, &*saturating]);
418
419        for mut var in processing.variables {
420            if var.ty.line_size() > MAX_LINE_SIZE {
421                var.ty = var.ty.line(MAX_LINE_SIZE);
422            }
423            instructions.push(wgsl::Instruction::DeclareVariable {
424                var: self.compile_variable(var),
425            });
426        }
427
428        processing
429            .instructions
430            .into_iter()
431            .for_each(|op| self.compile_operation(&mut instructions, op.operation, op.out, scope));
432
433        instructions
434    }
435
436    fn compile_operation(
437        &mut self,
438        instructions: &mut Vec<wgsl::Instruction>,
439        operation: cube::Operation,
440        out: Option<cube::Variable>,
441        scope: &mut cube::Scope,
442    ) {
443        match operation {
444            cube::Operation::Copy(variable) => instructions.push(wgsl::Instruction::Assign {
445                input: self.compile_variable(variable),
446                out: self.compile_variable(out.unwrap()),
447            }),
448            cube::Operation::Arithmetic(op) => {
449                self.compile_arithmetic(op, out, instructions, scope)
450            }
451            cube::Operation::Comparison(op) => self.compile_cmp(op, out, instructions),
452            cube::Operation::Bitwise(op) => self.compile_bitwise(op, out, instructions),
453            cube::Operation::Operator(op) => self.compile_operator(op, out, instructions),
454            cube::Operation::Atomic(op) => instructions.push(self.compile_atomic(op, out)),
455            cube::Operation::Metadata(op) => instructions.push(self.compile_metadata(op, out)),
456            cube::Operation::Branch(val) => self.compile_branch(instructions, val),
457            cube::Operation::Synchronization(val) => {
458                self.compile_synchronization(instructions, val)
459            }
460            cube::Operation::Plane(op) => self.compile_subgroup(instructions, op, out),
461            cube::Operation::CoopMma(_) => {
462                panic!("Cooperative matrix-multiply and accumulate isn't supported on wgpu.")
463            }
464            cube::Operation::NonSemantic(cube::NonSemantic::Comment { content }) => {
465                self.compile_comment(instructions, content)
466            }
467            cube::Operation::NonSemantic(_) => {}
468            cube::Operation::Barrier(_) => {
469                panic!("Barrier isn't supported on wgpu.")
470            }
471            cube::Operation::Tma(_) => panic!("TMA isn't supported on wgpu."),
472            cube::Operation::Free(_) => {}
473        }
474    }
475
476    fn compile_subgroup(
477        &mut self,
478        instructions: &mut Vec<wgsl::Instruction>,
479        subgroup: cube::Plane,
480        out: Option<cube::Variable>,
481    ) {
482        self.subgroup_instructions_used = true;
483
484        let out = out.unwrap();
485        let op = match subgroup {
486            cube::Plane::Elect => Subgroup::Elect {
487                out: self.compile_variable(out),
488            },
489            cube::Plane::All(op) => Subgroup::All {
490                input: self.compile_variable(op.input),
491                out: self.compile_variable(out),
492            },
493            cube::Plane::Any(op) => Subgroup::Any {
494                input: self.compile_variable(op.input),
495                out: self.compile_variable(out),
496            },
497            cube::Plane::Ballot(op) => Subgroup::Ballot {
498                input: self.compile_variable(op.input),
499                out: self.compile_variable(out),
500            },
501
502            cube::Plane::Broadcast(op) => Subgroup::Broadcast {
503                lhs: self.compile_variable(op.lhs),
504                rhs: self.compile_variable(op.rhs),
505                out: self.compile_variable(out),
506            },
507
508            cube::Plane::Sum(op) => Subgroup::Sum {
509                input: self.compile_variable(op.input),
510                out: self.compile_variable(out),
511            },
512
513            cube::Plane::ExclusiveSum(op) => Subgroup::ExclusiveSum {
514                input: self.compile_variable(op.input),
515                out: self.compile_variable(out),
516            },
517            cube::Plane::InclusiveSum(op) => Subgroup::InclusiveSum {
518                input: self.compile_variable(op.input),
519                out: self.compile_variable(out),
520            },
521            cube::Plane::Prod(op) => Subgroup::Prod {
522                input: self.compile_variable(op.input),
523                out: self.compile_variable(out),
524            },
525            cube::Plane::ExclusiveProd(op) => Subgroup::ExclusiveProd {
526                input: self.compile_variable(op.input),
527                out: self.compile_variable(out),
528            },
529            cube::Plane::InclusiveProd(op) => Subgroup::InclusiveProd {
530                input: self.compile_variable(op.input),
531                out: self.compile_variable(out),
532            },
533            cube::Plane::Min(op) => Subgroup::Min {
534                input: self.compile_variable(op.input),
535                out: self.compile_variable(out),
536            },
537            cube::Plane::Max(op) => Subgroup::Max {
538                input: self.compile_variable(op.input),
539                out: self.compile_variable(out),
540            },
541            cube::Plane::Shuffle(op) => Subgroup::Shuffle {
542                lhs: self.compile_variable(op.lhs),
543                rhs: self.compile_variable(op.rhs),
544                out: self.compile_variable(out),
545            },
546            cube::Plane::ShuffleXor(op) => Subgroup::ShuffleXor {
547                lhs: self.compile_variable(op.lhs),
548                rhs: self.compile_variable(op.rhs),
549                out: self.compile_variable(out),
550            },
551            cube::Plane::ShuffleUp(op) => Subgroup::ShuffleUp {
552                lhs: self.compile_variable(op.lhs),
553                rhs: self.compile_variable(op.rhs),
554                out: self.compile_variable(out),
555            },
556            cube::Plane::ShuffleDown(op) => Subgroup::ShuffleDown {
557                lhs: self.compile_variable(op.lhs),
558                rhs: self.compile_variable(op.rhs),
559                out: self.compile_variable(out),
560            },
561        };
562
563        instructions.push(wgsl::Instruction::Subgroup(op));
564    }
565
566    fn compile_branch(&mut self, instructions: &mut Vec<wgsl::Instruction>, branch: cube::Branch) {
567        match branch {
568            cube::Branch::If(mut op) => instructions.push(wgsl::Instruction::If {
569                cond: self.compile_variable(op.cond),
570                instructions: self.compile_scope(&mut op.scope),
571            }),
572            cube::Branch::IfElse(mut op) => instructions.push(wgsl::Instruction::IfElse {
573                cond: self.compile_variable(op.cond),
574                instructions_if: self.compile_scope(&mut op.scope_if),
575                instructions_else: self.compile_scope(&mut op.scope_else),
576            }),
577            cube::Branch::Switch(mut op) => instructions.push(wgsl::Instruction::Switch {
578                value: self.compile_variable(op.value),
579                instructions_default: self.compile_scope(&mut op.scope_default),
580                cases: op
581                    .cases
582                    .into_iter()
583                    .map(|(val, mut scope)| {
584                        (self.compile_variable(val), self.compile_scope(&mut scope))
585                    })
586                    .collect(),
587            }),
588            cube::Branch::Return => instructions.push(wgsl::Instruction::Return),
589            cube::Branch::Break => instructions.push(wgsl::Instruction::Break),
590            cube::Branch::RangeLoop(mut range_loop) => {
591                instructions.push(wgsl::Instruction::RangeLoop {
592                    i: self.compile_variable(range_loop.i),
593                    start: self.compile_variable(range_loop.start),
594                    end: self.compile_variable(range_loop.end),
595                    step: range_loop.step.map(|it| self.compile_variable(it)),
596                    inclusive: range_loop.inclusive,
597                    instructions: self.compile_scope(&mut range_loop.scope),
598                })
599            }
600            cube::Branch::Loop(mut op) => instructions.push(wgsl::Instruction::Loop {
601                instructions: self.compile_scope(&mut op.scope),
602            }),
603        };
604    }
605
606    fn compile_synchronization(
607        &mut self,
608        instructions: &mut Vec<wgsl::Instruction>,
609        synchronization: cube::Synchronization,
610    ) {
611        match synchronization {
612            cube::Synchronization::SyncCube => {
613                instructions.push(wgsl::Instruction::WorkgroupBarrier)
614            }
615            cube::Synchronization::SyncPlane => {
616                panic!("Synchronization within a plane is not supported in WGSL")
617            }
618            cube::Synchronization::SyncStorage => {
619                instructions.push(wgsl::Instruction::StorageBarrier)
620            }
621            cube::Synchronization::SyncProxyShared => panic!("TMA is not supported in WGSL"),
622        };
623    }
624
625    fn compile_comment(&mut self, instructions: &mut Vec<wgsl::Instruction>, content: String) {
626        instructions.push(wgsl::Instruction::Comment { content })
627    }
628
629    fn compile_metadata(
630        &mut self,
631        metadata: cube::Metadata,
632        out: Option<cube::Variable>,
633    ) -> wgsl::Instruction {
634        let out = out.unwrap();
635        match metadata {
636            cube::Metadata::Rank { var } => {
637                let position = self.ext_meta_pos(&var);
638                let offset = self.metadata.rank_index(position);
639                wgsl::Instruction::Metadata {
640                    out: self.compile_variable(out),
641                    info_offset: self.compile_variable(offset.into()),
642                }
643            }
644            cube::Metadata::Stride { dim, var } => {
645                let position = self.ext_meta_pos(&var);
646                let offset = self.metadata.stride_offset_index(position);
647                wgsl::Instruction::ExtendedMeta {
648                    info_offset: self.compile_variable(offset.into()),
649                    dim: self.compile_variable(dim),
650                    out: self.compile_variable(out),
651                }
652            }
653            cube::Metadata::Shape { dim, var } => {
654                let position = self.ext_meta_pos(&var);
655                let offset = self.metadata.shape_offset_index(position);
656                wgsl::Instruction::ExtendedMeta {
657                    info_offset: self.compile_variable(offset.into()),
658                    dim: self.compile_variable(dim),
659                    out: self.compile_variable(out),
660                }
661            }
662            cube::Metadata::Length { var } => match var.kind {
663                cube::VariableKind::GlobalInputArray(id) => {
664                    let offset = self.metadata.len_index(id);
665                    wgsl::Instruction::Metadata {
666                        out: self.compile_variable(out),
667                        info_offset: self.compile_variable(offset.into()),
668                    }
669                }
670                cube::VariableKind::GlobalOutputArray(id) => {
671                    let offset = self.metadata.len_index(id);
672                    wgsl::Instruction::Metadata {
673                        out: self.compile_variable(out),
674                        info_offset: self.compile_variable(offset.into()),
675                    }
676                }
677                _ => wgsl::Instruction::Length {
678                    var: self.compile_variable(var),
679                    out: self.compile_variable(out),
680                },
681            },
682            cube::Metadata::BufferLength { var } => match var.kind {
683                cube::VariableKind::GlobalInputArray(id) => {
684                    let offset = self.metadata.buffer_len_index(id);
685                    wgsl::Instruction::Metadata {
686                        out: self.compile_variable(out),
687                        info_offset: self.compile_variable(offset.into()),
688                    }
689                }
690                cube::VariableKind::GlobalOutputArray(id) => {
691                    let offset = self.metadata.buffer_len_index(id);
692                    wgsl::Instruction::Metadata {
693                        out: self.compile_variable(out),
694                        info_offset: self.compile_variable(offset.into()),
695                    }
696                }
697                _ => wgsl::Instruction::Length {
698                    var: self.compile_variable(var),
699                    out: self.compile_variable(out),
700                },
701            },
702        }
703    }
704
705    fn compile_arithmetic(
706        &mut self,
707        value: cube::Arithmetic,
708        out: Option<cube::Variable>,
709        instructions: &mut Vec<wgsl::Instruction>,
710        scope: &mut Scope,
711    ) {
712        let out = out.unwrap();
713        match value {
714            cube::Arithmetic::Max(op) => instructions.push(wgsl::Instruction::Max {
715                lhs: self.compile_variable(op.lhs),
716                rhs: self.compile_variable(op.rhs),
717                out: self.compile_variable(out),
718            }),
719            cube::Arithmetic::Min(op) => instructions.push(wgsl::Instruction::Min {
720                lhs: self.compile_variable(op.lhs),
721                rhs: self.compile_variable(op.rhs),
722                out: self.compile_variable(out),
723            }),
724            cube::Arithmetic::Add(op) => instructions.push(wgsl::Instruction::Add {
725                lhs: self.compile_variable(op.lhs),
726                rhs: self.compile_variable(op.rhs),
727                out: self.compile_variable(out),
728            }),
729            cube::Arithmetic::SaturatingAdd(_) => {
730                unreachable!("Saturating add should be removed by processor");
731            }
732            cube::Arithmetic::Fma(op) => instructions.push(wgsl::Instruction::Fma {
733                a: self.compile_variable(op.a),
734                b: self.compile_variable(op.b),
735                c: self.compile_variable(op.c),
736                out: self.compile_variable(out),
737            }),
738            cube::Arithmetic::Modulo(op) => instructions.push(wgsl::Instruction::Modulo {
739                lhs: self.compile_variable(op.lhs),
740                rhs: self.compile_variable(op.rhs),
741                out: self.compile_variable(out),
742            }),
743            cube::Arithmetic::Sub(op) => instructions.push(wgsl::Instruction::Sub {
744                lhs: self.compile_variable(op.lhs),
745                rhs: self.compile_variable(op.rhs),
746                out: self.compile_variable(out),
747            }),
748            cube::Arithmetic::SaturatingSub(_) => {
749                unreachable!("Saturating sub should be removed by processor");
750            }
751            cube::Arithmetic::Mul(op) => instructions.push(wgsl::Instruction::Mul {
752                lhs: self.compile_variable(op.lhs),
753                rhs: self.compile_variable(op.rhs),
754                out: self.compile_variable(out),
755            }),
756            cube::Arithmetic::Div(op) => instructions.push(wgsl::Instruction::Div {
757                lhs: self.compile_variable(op.lhs),
758                rhs: self.compile_variable(op.rhs),
759                out: self.compile_variable(out),
760            }),
761            cube::Arithmetic::Abs(op) => instructions.push(wgsl::Instruction::Abs {
762                input: self.compile_variable(op.input),
763                out: self.compile_variable(out),
764            }),
765            cube::Arithmetic::Exp(op) => instructions.push(wgsl::Instruction::Exp {
766                input: self.compile_variable(op.input),
767                out: self.compile_variable(out),
768            }),
769            cube::Arithmetic::Log(op) => instructions.push(wgsl::Instruction::Log {
770                input: self.compile_variable(op.input),
771                out: self.compile_variable(out),
772            }),
773            cube::Arithmetic::Log1p(op) => instructions.push(wgsl::Instruction::Log1p {
774                input: self.compile_variable(op.input),
775                out: self.compile_variable(out),
776            }),
777            cube::Arithmetic::Cos(op) => instructions.push(wgsl::Instruction::Cos {
778                input: self.compile_variable(op.input),
779                out: self.compile_variable(out),
780            }),
781            cube::Arithmetic::Sin(op) => instructions.push(wgsl::Instruction::Sin {
782                input: self.compile_variable(op.input),
783                out: self.compile_variable(out),
784            }),
785            cube::Arithmetic::Tanh(op) => instructions.push(wgsl::Instruction::Tanh {
786                input: self.compile_variable(op.input),
787                out: self.compile_variable(out),
788            }),
789            // No powi in WGSL
790            cube::Arithmetic::Powf(op) | cube::Arithmetic::Powi(op) => {
791                instructions.push(wgsl::Instruction::Powf {
792                    lhs: self.compile_variable(op.lhs),
793                    rhs: self.compile_variable(op.rhs),
794                    out: self.compile_variable(out),
795                })
796            }
797            cube::Arithmetic::Sqrt(op) => instructions.push(wgsl::Instruction::Sqrt {
798                input: self.compile_variable(op.input),
799                out: self.compile_variable(out),
800            }),
801            cube::Arithmetic::Round(op) => instructions.push(wgsl::Instruction::Round {
802                input: self.compile_variable(op.input),
803                out: self.compile_variable(out),
804            }),
805            cube::Arithmetic::Floor(op) => instructions.push(wgsl::Instruction::Floor {
806                input: self.compile_variable(op.input),
807                out: self.compile_variable(out),
808            }),
809            cube::Arithmetic::Ceil(op) => instructions.push(wgsl::Instruction::Ceil {
810                input: self.compile_variable(op.input),
811                out: self.compile_variable(out),
812            }),
813            cube::Arithmetic::Trunc(op) => instructions.push(wgsl::Instruction::Trunc {
814                input: self.compile_variable(op.input),
815                out: self.compile_variable(out),
816            }),
817            cube::Arithmetic::Erf(op) => {
818                let mut scope = scope.child();
819                expand_erf(&mut scope, op.input, out);
820                instructions.extend(self.compile_scope(&mut scope));
821            }
822            cube::Arithmetic::MulHi(op) => {
823                let mut scope = scope.child();
824                match self.compilation_options.supports_u64 {
825                    true => expand_himul_64(&mut scope, op.lhs, op.rhs, out),
826                    false => expand_himul_sim(&mut scope, op.lhs, op.rhs, out),
827                }
828                instructions.extend(self.compile_scope(&mut scope));
829            }
830            cube::Arithmetic::Recip(op) => instructions.push(wgsl::Instruction::Recip {
831                input: self.compile_variable(op.input),
832                out: self.compile_variable(out),
833            }),
834            cube::Arithmetic::Clamp(op) => instructions.push(wgsl::Instruction::Clamp {
835                input: self.compile_variable(op.input),
836                min_value: self.compile_variable(op.min_value),
837                max_value: self.compile_variable(op.max_value),
838                out: self.compile_variable(out),
839            }),
840            cube::Arithmetic::Remainder(op) => instructions.push(wgsl::Instruction::Remainder {
841                lhs: self.compile_variable(op.lhs),
842                rhs: self.compile_variable(op.rhs),
843                out: self.compile_variable(out),
844            }),
845            cube::Arithmetic::Neg(op) => instructions.push(wgsl::Instruction::Negate {
846                input: self.compile_variable(op.input),
847                out: self.compile_variable(out),
848            }),
849            cube::Arithmetic::Magnitude(op) => instructions.push(wgsl::Instruction::Magnitude {
850                input: self.compile_variable(op.input),
851                out: self.compile_variable(out),
852            }),
853            cube::Arithmetic::Normalize(op) => instructions.push(wgsl::Instruction::Normalize {
854                input: self.compile_variable(op.input),
855                out: self.compile_variable(out),
856            }),
857            cube::Arithmetic::Dot(op) => instructions.push(wgsl::Instruction::Dot {
858                lhs: self.compile_variable(op.lhs),
859                rhs: self.compile_variable(op.rhs),
860                out: self.compile_variable(out),
861            }),
862        }
863    }
864
865    fn compile_cmp(
866        &mut self,
867        value: cube::Comparison,
868        out: Option<cube::Variable>,
869        instructions: &mut Vec<wgsl::Instruction>,
870    ) {
871        let out = out.unwrap();
872        match value {
873            cube::Comparison::Equal(op) => instructions.push(wgsl::Instruction::Equal {
874                lhs: self.compile_variable(op.lhs),
875                rhs: self.compile_variable(op.rhs),
876                out: self.compile_variable(out),
877            }),
878            cube::Comparison::Lower(op) => instructions.push(wgsl::Instruction::Lower {
879                lhs: self.compile_variable(op.lhs),
880                rhs: self.compile_variable(op.rhs),
881                out: self.compile_variable(out),
882            }),
883            cube::Comparison::Greater(op) => instructions.push(wgsl::Instruction::Greater {
884                lhs: self.compile_variable(op.lhs),
885                rhs: self.compile_variable(op.rhs),
886                out: self.compile_variable(out),
887            }),
888            cube::Comparison::LowerEqual(op) => instructions.push(wgsl::Instruction::LowerEqual {
889                lhs: self.compile_variable(op.lhs),
890                rhs: self.compile_variable(op.rhs),
891                out: self.compile_variable(out),
892            }),
893            cube::Comparison::GreaterEqual(op) => {
894                instructions.push(wgsl::Instruction::GreaterEqual {
895                    lhs: self.compile_variable(op.lhs),
896                    rhs: self.compile_variable(op.rhs),
897                    out: self.compile_variable(out),
898                })
899            }
900            cube::Comparison::NotEqual(op) => instructions.push(wgsl::Instruction::NotEqual {
901                lhs: self.compile_variable(op.lhs),
902                rhs: self.compile_variable(op.rhs),
903                out: self.compile_variable(out),
904            }),
905            cube::Comparison::IsNan(op) => instructions.push(wgsl::Instruction::IsNan {
906                input: self.compile_variable(op.input),
907                out: self.compile_variable(out),
908            }),
909            cube::Comparison::IsInf(op) => instructions.push(wgsl::Instruction::IsInf {
910                input: self.compile_variable(op.input),
911                out: self.compile_variable(out),
912            }),
913        }
914    }
915
916    fn compile_bitwise(
917        &mut self,
918        value: cube::Bitwise,
919        out: Option<cube::Variable>,
920        instructions: &mut Vec<wgsl::Instruction>,
921    ) {
922        let out = out.unwrap();
923        match value {
924            cube::Bitwise::BitwiseOr(op) => instructions.push(wgsl::Instruction::BitwiseOr {
925                lhs: self.compile_variable(op.lhs),
926                rhs: self.compile_variable(op.rhs),
927                out: self.compile_variable(out),
928            }),
929            cube::Bitwise::BitwiseAnd(op) => instructions.push(wgsl::Instruction::BitwiseAnd {
930                lhs: self.compile_variable(op.lhs),
931                rhs: self.compile_variable(op.rhs),
932                out: self.compile_variable(out),
933            }),
934            cube::Bitwise::BitwiseXor(op) => instructions.push(wgsl::Instruction::BitwiseXor {
935                lhs: self.compile_variable(op.lhs),
936                rhs: self.compile_variable(op.rhs),
937                out: self.compile_variable(out),
938            }),
939            cube::Bitwise::CountOnes(op) => instructions.push(wgsl::Instruction::CountBits {
940                input: self.compile_variable(op.input),
941                out: self.compile_variable(out),
942            }),
943            cube::Bitwise::ReverseBits(op) => instructions.push(wgsl::Instruction::ReverseBits {
944                input: self.compile_variable(op.input),
945                out: self.compile_variable(out),
946            }),
947            cube::Bitwise::ShiftLeft(op) => instructions.push(wgsl::Instruction::ShiftLeft {
948                lhs: self.compile_variable(op.lhs),
949                rhs: self.compile_variable(op.rhs),
950                out: self.compile_variable(out),
951            }),
952            cube::Bitwise::ShiftRight(op) => instructions.push(wgsl::Instruction::ShiftRight {
953                lhs: self.compile_variable(op.lhs),
954                rhs: self.compile_variable(op.rhs),
955                out: self.compile_variable(out),
956            }),
957            cube::Bitwise::BitwiseNot(op) => instructions.push(wgsl::Instruction::BitwiseNot {
958                input: self.compile_variable(op.input),
959                out: self.compile_variable(out),
960            }),
961            cube::Bitwise::LeadingZeros(op) => instructions.push(wgsl::Instruction::LeadingZeros {
962                input: self.compile_variable(op.input),
963                out: self.compile_variable(out),
964            }),
965            cube::Bitwise::FindFirstSet(op) => instructions.push(wgsl::Instruction::FindFirstSet {
966                input: self.compile_variable(op.input),
967                out: self.compile_variable(out),
968            }),
969        }
970    }
971
972    fn compile_operator(
973        &mut self,
974        value: cube::Operator,
975        out: Option<cube::Variable>,
976        instructions: &mut Vec<wgsl::Instruction>,
977    ) {
978        let out = out.unwrap();
979        match value {
980            cube::Operator::Cast(op) => instructions.push(wgsl::Instruction::Assign {
981                input: self.compile_variable(op.input),
982                out: self.compile_variable(out),
983            }),
984            cube::Operator::Index(op) | cube::Operator::UncheckedIndex(op) => {
985                instructions.push(wgsl::Instruction::Index {
986                    lhs: self.compile_variable(op.list),
987                    rhs: self.compile_variable(op.index),
988                    out: self.compile_variable(out),
989                });
990            }
991            cube::Operator::IndexAssign(op) | cube::Operator::UncheckedIndexAssign(op) => {
992                instructions.push(wgsl::Instruction::IndexAssign {
993                    index: self.compile_variable(op.index),
994                    rhs: self.compile_variable(op.value),
995                    out: self.compile_variable(out),
996                })
997            }
998            cube::Operator::And(op) => instructions.push(wgsl::Instruction::And {
999                lhs: self.compile_variable(op.lhs),
1000                rhs: self.compile_variable(op.rhs),
1001                out: self.compile_variable(out),
1002            }),
1003            cube::Operator::Or(op) => instructions.push(wgsl::Instruction::Or {
1004                lhs: self.compile_variable(op.lhs),
1005                rhs: self.compile_variable(op.rhs),
1006                out: self.compile_variable(out),
1007            }),
1008            cube::Operator::Not(op) => instructions.push(wgsl::Instruction::Not {
1009                input: self.compile_variable(op.input),
1010                out: self.compile_variable(out),
1011            }),
1012            cube::Operator::Reinterpret(op) => instructions.push(wgsl::Instruction::Bitcast {
1013                input: self.compile_variable(op.input),
1014                out: self.compile_variable(out),
1015            }),
1016            cube::Operator::InitLine(op) => instructions.push(wgsl::Instruction::VecInit {
1017                inputs: op
1018                    .inputs
1019                    .into_iter()
1020                    .map(|var| self.compile_variable(var))
1021                    .collect(),
1022                out: self.compile_variable(out),
1023            }),
1024            cube::Operator::CopyMemory(op) => instructions.push(wgsl::Instruction::Copy {
1025                input: self.compile_variable(op.input),
1026                in_index: self.compile_variable(op.in_index),
1027                out: self.compile_variable(out),
1028                out_index: self.compile_variable(op.out_index),
1029            }),
1030            cube::Operator::CopyMemoryBulk(op) => instructions.push(wgsl::Instruction::CopyBulk {
1031                input: self.compile_variable(op.input),
1032                in_index: self.compile_variable(op.in_index),
1033                out: self.compile_variable(out),
1034                out_index: self.compile_variable(op.out_index),
1035                len: op.len,
1036            }),
1037            cube::Operator::Select(op) => instructions.push(wgsl::Instruction::Select {
1038                cond: self.compile_variable(op.cond),
1039                then: self.compile_variable(op.then),
1040                or_else: self.compile_variable(op.or_else),
1041                out: self.compile_variable(out),
1042            }),
1043        }
1044    }
1045
1046    fn compile_atomic(
1047        &mut self,
1048        atomic: cube::AtomicOp,
1049        out: Option<cube::Variable>,
1050    ) -> wgsl::Instruction {
1051        let out = out.unwrap();
1052        match atomic {
1053            cube::AtomicOp::Add(op) => wgsl::Instruction::AtomicAdd {
1054                lhs: self.compile_variable(op.lhs),
1055                rhs: self.compile_variable(op.rhs),
1056                out: self.compile_variable(out),
1057            },
1058            cube::AtomicOp::Sub(op) => wgsl::Instruction::AtomicSub {
1059                lhs: self.compile_variable(op.lhs),
1060                rhs: self.compile_variable(op.rhs),
1061                out: self.compile_variable(out),
1062            },
1063            cube::AtomicOp::Max(op) => wgsl::Instruction::AtomicMax {
1064                lhs: self.compile_variable(op.lhs),
1065                rhs: self.compile_variable(op.rhs),
1066                out: self.compile_variable(out),
1067            },
1068            cube::AtomicOp::Min(op) => wgsl::Instruction::AtomicMin {
1069                lhs: self.compile_variable(op.lhs),
1070                rhs: self.compile_variable(op.rhs),
1071                out: self.compile_variable(out),
1072            },
1073            cube::AtomicOp::And(op) => wgsl::Instruction::AtomicAnd {
1074                lhs: self.compile_variable(op.lhs),
1075                rhs: self.compile_variable(op.rhs),
1076                out: self.compile_variable(out),
1077            },
1078            cube::AtomicOp::Or(op) => wgsl::Instruction::AtomicOr {
1079                lhs: self.compile_variable(op.lhs),
1080                rhs: self.compile_variable(op.rhs),
1081                out: self.compile_variable(out),
1082            },
1083            cube::AtomicOp::Xor(op) => wgsl::Instruction::AtomicXor {
1084                lhs: self.compile_variable(op.lhs),
1085                rhs: self.compile_variable(op.rhs),
1086                out: self.compile_variable(out),
1087            },
1088            cube::AtomicOp::Load(op) => wgsl::Instruction::AtomicLoad {
1089                input: self.compile_variable(op.input),
1090                out: self.compile_variable(out),
1091            },
1092            cube::AtomicOp::Store(op) => wgsl::Instruction::AtomicStore {
1093                input: self.compile_variable(op.input),
1094                out: self.compile_variable(out),
1095            },
1096            cube::AtomicOp::Swap(op) => wgsl::Instruction::AtomicSwap {
1097                lhs: self.compile_variable(op.lhs),
1098                rhs: self.compile_variable(op.rhs),
1099                out: self.compile_variable(out),
1100            },
1101            cube::AtomicOp::CompareAndSwap(op) => wgsl::Instruction::AtomicCompareExchangeWeak {
1102                lhs: self.compile_variable(op.input),
1103                cmp: self.compile_variable(op.cmp),
1104                value: self.compile_variable(op.val),
1105                out: self.compile_variable(out),
1106            },
1107        }
1108    }
1109
1110    fn compile_location(value: compute::Location) -> wgsl::Location {
1111        match value {
1112            compute::Location::Storage => wgsl::Location::Storage,
1113            compute::Location::Cube => wgsl::Location::Workgroup,
1114        }
1115    }
1116
1117    fn compile_binding(&mut self, value: compute::Binding) -> wgsl::Binding {
1118        wgsl::Binding {
1119            id: value.id,
1120            visibility: value.visibility,
1121            location: Self::compile_location(value.location),
1122            item: self.compile_type(value.ty),
1123            size: value.size,
1124        }
1125    }
1126}
1127
1128fn register_extensions(instructions: &[wgsl::Instruction]) -> Vec<wgsl::Extension> {
1129    let mut extensions = Vec::new();
1130
1131    let mut register_extension = |extension: wgsl::Extension| {
1132        if !extensions.contains(&extension) {
1133            extensions.push(extension);
1134        }
1135    };
1136
1137    // Since not all instructions are native to WGSL, we need to add the custom ones.
1138    for instruction in instructions {
1139        match instruction {
1140            wgsl::Instruction::Powf { lhs: _, rhs, out } => {
1141                register_extension(wgsl::Extension::PowfPrimitive(out.elem()));
1142                register_extension(wgsl::powf_extension(rhs, out));
1143            }
1144            #[cfg(target_os = "macos")]
1145            wgsl::Instruction::Tanh { input, out: _ } => {
1146                register_extension(wgsl::Extension::SafeTanhPrimitive(input.elem()));
1147                register_extension(wgsl::Extension::SafeTanh(input.item()));
1148            }
1149            wgsl::Instruction::IsNan { input, out } => {
1150                register_extension(wgsl::Extension::IsNanPrimitive(input.elem()));
1151                register_extension(wgsl::Extension::IsNan(input.item(), out.item()));
1152            }
1153            wgsl::Instruction::IsInf { input, out } => {
1154                register_extension(wgsl::Extension::IsInfPrimitive(input.elem()));
1155                register_extension(wgsl::Extension::IsInf(input.item(), out.item()));
1156            }
1157            wgsl::Instruction::If { instructions, .. } => {
1158                for extension in register_extensions(instructions) {
1159                    register_extension(extension);
1160                }
1161            }
1162            wgsl::Instruction::IfElse {
1163                instructions_if,
1164                instructions_else,
1165                ..
1166            } => {
1167                for extension in register_extensions(instructions_if) {
1168                    register_extension(extension);
1169                }
1170                for extension in register_extensions(instructions_else) {
1171                    register_extension(extension);
1172                }
1173            }
1174            wgsl::Instruction::Loop { instructions } => {
1175                for extension in register_extensions(instructions) {
1176                    register_extension(extension);
1177                }
1178            }
1179            wgsl::Instruction::RangeLoop { instructions, .. } => {
1180                for extension in register_extensions(instructions) {
1181                    register_extension(extension);
1182                }
1183            }
1184            _ => {}
1185        }
1186    }
1187
1188    extensions
1189}