1use super::Subgroup;
2use super::{ConstantArray, shader::ComputeShader};
3use super::{Item, LocalArray, SharedMemory};
4use crate::compiler::wgsl;
5
6use cubecl_common::ExecutionMode;
7use cubecl_core::ir::{ConstantScalarValue, Processor, UIntKind};
8use cubecl_core::post_processing::checked_io::CheckedIoProcessor;
9use cubecl_core::prelude::*;
10use cubecl_core::{
11 Metadata, WgpuCompilationOptions, compute,
12 ir::{self as cube, Scope},
13 prelude::expand_erf,
14};
15
16#[derive(Clone, Default)]
18pub struct WgslCompiler {
19 metadata: Metadata,
20 ext_meta_pos: Vec<u32>,
21 local_invocation_index: bool,
22 local_invocation_id: bool,
23 global_invocation_id: bool,
25 workgroup_id: bool,
26 subgroup_size: bool,
27 subgroup_invocation_id: bool,
28 id: bool,
29 num_workgroups: bool,
30 workgroup_id_no_axis: bool,
31 workgroup_size_no_axis: bool,
32 num_workgroup_no_axis: bool,
33 shared_memories: Vec<SharedMemory>,
34 const_arrays: Vec<ConstantArray>,
35 local_arrays: Vec<LocalArray>,
36 #[allow(dead_code)]
37 compilation_options: WgpuCompilationOptions,
38 strategy: ExecutionMode,
39 subgroup_instructions_used: bool,
40 f16_used: bool,
41}
42
43impl core::fmt::Debug for WgslCompiler {
44 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45 f.write_str("WgslCompiler")
46 }
47}
48
49impl cubecl_core::Compiler for WgslCompiler {
50 type Representation = ComputeShader;
51 type CompilationOptions = WgpuCompilationOptions;
52
53 fn compile(
54 &mut self,
55 shader: compute::KernelDefinition,
56 compilation_options: &Self::CompilationOptions,
57 mode: ExecutionMode,
58 ) -> Self::Representation {
59 self.compilation_options = compilation_options.clone();
60 self.compile_shader(shader, mode)
61 }
62
63 fn elem_size(&self, elem: cube::Elem) -> usize {
64 elem.size()
65 }
66
67 fn extension(&self) -> &'static str {
68 "wgsl"
69 }
70}
71
72impl WgslCompiler {
73 fn compile_shader(
74 &mut self,
75 mut value: compute::KernelDefinition,
76 mode: ExecutionMode,
77 ) -> wgsl::ComputeShader {
78 self.strategy = mode;
79
80 let num_meta = value.buffers.len();
81
82 self.ext_meta_pos = Vec::new();
83 let mut num_ext = 0;
84
85 for binding in value.buffers.iter() {
86 self.ext_meta_pos.push(num_ext);
87 if binding.has_extended_meta {
88 num_ext += 1;
89 }
90 }
91
92 self.metadata = Metadata::new(num_meta as u32, num_ext);
93
94 let instructions = self.compile_scope(&mut value.body);
95 let extensions = register_extensions(&instructions);
96 let body = wgsl::Body {
97 instructions,
98 id: self.id,
99 };
100
101 wgsl::ComputeShader {
102 buffers: value
103 .buffers
104 .into_iter()
105 .map(|it| self.compile_binding(it))
106 .collect(),
107 scalars: value
108 .scalars
109 .into_iter()
110 .map(|binding| (self.compile_elem(binding.elem), binding.count))
111 .collect(),
112 shared_memories: self.shared_memories.clone(),
113 constant_arrays: self.const_arrays.clone(),
114 local_arrays: self.local_arrays.clone(),
115 has_metadata: self.metadata.static_len() > 0,
116 workgroup_size: value.cube_dim,
117 global_invocation_id: self.global_invocation_id || self.id,
118 local_invocation_index: self.local_invocation_index,
119 local_invocation_id: self.local_invocation_id,
120 num_workgroups: self.id
121 || self.num_workgroups
122 || self.num_workgroup_no_axis
123 || self.workgroup_id_no_axis,
124 workgroup_id: self.workgroup_id || self.workgroup_id_no_axis,
125 subgroup_size: self.subgroup_size,
126 subgroup_invocation_id: self.subgroup_invocation_id,
127 body,
128 extensions,
129 num_workgroups_no_axis: self.num_workgroup_no_axis,
130 workgroup_id_no_axis: self.workgroup_id_no_axis,
131 workgroup_size_no_axis: self.workgroup_size_no_axis,
132 subgroup_instructions_used: self.subgroup_instructions_used,
133 f16_used: self.f16_used,
134 kernel_name: value.options.kernel_name,
135 }
136 }
137
138 fn compile_item(&mut self, item: cube::Item) -> Item {
139 let elem = self.compile_elem(item.elem);
140 match item.vectorization.map(|it| it.get()).unwrap_or(1) {
141 1 => wgsl::Item::Scalar(elem),
142 2 => wgsl::Item::Vec2(elem),
143 3 => wgsl::Item::Vec3(elem),
144 4 => wgsl::Item::Vec4(elem),
145 _ => panic!("Unsupported vectorizations scheme {:?}", item.vectorization),
146 }
147 }
148
149 fn compile_elem(&mut self, value: cube::Elem) -> wgsl::Elem {
150 match value {
151 cube::Elem::Float(f) => match f {
152 cube::FloatKind::E2M1
153 | cube::FloatKind::E2M3
154 | cube::FloatKind::E3M2
155 | cube::FloatKind::E4M3
156 | cube::FloatKind::E5M2
157 | cube::FloatKind::UE8M0 => panic!("Minifloat is not a valid WgpuElement"),
158 cube::FloatKind::F16 => {
159 self.f16_used = true;
160 wgsl::Elem::F16
161 }
162 cube::FloatKind::BF16 => panic!("bf16 is not a valid WgpuElement"),
163 cube::FloatKind::TF32 => panic!("tf32 is not a valid WgpuElement"),
164 cube::FloatKind::Flex32 => wgsl::Elem::F32,
165 cube::FloatKind::F32 => wgsl::Elem::F32,
166 cube::FloatKind::F64 => wgsl::Elem::F64,
167 },
168 cube::Elem::Int(i) => match i {
169 cube::IntKind::I32 => wgsl::Elem::I32,
170 cube::IntKind::I64 => wgsl::Elem::I64,
171 kind => panic!("{kind:?} is not a valid WgpuElement"),
172 },
173 cube::Elem::UInt(kind) => match kind {
174 cube::UIntKind::U32 => wgsl::Elem::U32,
175 cube::UIntKind::U64 => wgsl::Elem::U64,
176 kind => panic!("{kind:?} is not a valid WgpuElement"),
177 },
178 cube::Elem::Bool => wgsl::Elem::Bool,
179 cube::Elem::AtomicFloat(i) => match i {
180 cube::FloatKind::F32 => wgsl::Elem::AtomicF32,
181 kind => panic!("atomic<{kind:?}> is not a valid WgpuElement"),
182 },
183 cube::Elem::AtomicInt(i) => match i {
184 cube::IntKind::I32 => wgsl::Elem::AtomicI32,
185 kind => panic!("atomic<{kind:?}> is not a valid WgpuElement"),
186 },
187 cube::Elem::AtomicUInt(kind) => match kind {
188 cube::UIntKind::U32 => wgsl::Elem::AtomicU32,
189 kind => panic!("{kind:?} is not a valid WgpuElement"),
190 },
191 }
192 }
193
194 fn ext_meta_pos(&self, var: &cube::Variable) -> u32 {
195 let pos = var.index().expect("Variable should have index");
196 self.ext_meta_pos[pos as usize]
197 }
198
199 pub(crate) fn compile_variable(&mut self, value: cube::Variable) -> wgsl::Variable {
200 let item = value.item;
201 match value.kind {
202 cube::VariableKind::GlobalInputArray(id) => {
203 wgsl::Variable::GlobalInputArray(id, self.compile_item(item))
204 }
205 cube::VariableKind::GlobalScalar(id) => {
206 wgsl::Variable::GlobalScalar(id, self.compile_elem(item.elem), item.elem)
207 }
208 cube::VariableKind::LocalMut { id } | cube::VariableKind::Versioned { id, .. } => {
209 wgsl::Variable::LocalMut {
210 id,
211 item: self.compile_item(item),
212 }
213 }
214 cube::VariableKind::LocalConst { id } => wgsl::Variable::LocalConst {
215 id,
216 item: self.compile_item(item),
217 },
218 cube::VariableKind::GlobalOutputArray(id) => {
219 wgsl::Variable::GlobalOutputArray(id, self.compile_item(item))
220 }
221 cube::VariableKind::ConstantScalar(value) => {
222 wgsl::Variable::ConstantScalar(value, self.compile_elem(value.elem()))
223 }
224 cube::VariableKind::SharedMemory {
225 id,
226 length,
227 alignment,
228 } => {
229 let item = self.compile_item(item);
230 if !self.shared_memories.iter().any(|s| s.index == id) {
231 self.shared_memories
232 .push(SharedMemory::new(id, item, length, alignment));
233 }
234 wgsl::Variable::SharedMemory(id, item, length)
235 }
236 cube::VariableKind::ConstantArray { id, length } => {
237 let item = self.compile_item(item);
238 wgsl::Variable::ConstantArray(id, item, length)
239 }
240 cube::VariableKind::LocalArray { id, length } => {
241 let item = self.compile_item(item);
242 if !self.local_arrays.iter().any(|s| s.index == id) {
243 self.local_arrays.push(LocalArray::new(id, item, length));
244 }
245 wgsl::Variable::LocalArray(id, item, length)
246 }
247 cube::VariableKind::Builtin(builtin) => match builtin {
248 cube::Builtin::AbsolutePos => {
249 self.id = true;
250 wgsl::Variable::Id
251 }
252 cube::Builtin::UnitPos => {
253 self.local_invocation_index = true;
254 wgsl::Variable::LocalInvocationIndex
255 }
256 cube::Builtin::UnitPosX => {
257 self.local_invocation_id = true;
258 wgsl::Variable::LocalInvocationIdX
259 }
260 cube::Builtin::UnitPosY => {
261 self.local_invocation_id = true;
262 wgsl::Variable::LocalInvocationIdY
263 }
264 cube::Builtin::UnitPosZ => {
265 self.local_invocation_id = true;
266 wgsl::Variable::LocalInvocationIdZ
267 }
268 cube::Builtin::CubePosX => {
269 self.workgroup_id = true;
270 wgsl::Variable::WorkgroupIdX
271 }
272 cube::Builtin::CubePosY => {
273 self.workgroup_id = true;
274 wgsl::Variable::WorkgroupIdY
275 }
276 cube::Builtin::CubePosZ => {
277 self.workgroup_id = true;
278 wgsl::Variable::WorkgroupIdZ
279 }
280 cube::Builtin::CubePosCluster
281 | cube::Builtin::CubePosClusterX
282 | cube::Builtin::CubePosClusterY
283 | cube::Builtin::CubePosClusterZ => self.constant_var(1),
284 cube::Builtin::AbsolutePosX => {
285 self.global_invocation_id = true;
286 wgsl::Variable::GlobalInvocationIdX
287 }
288 cube::Builtin::AbsolutePosY => {
289 self.global_invocation_id = true;
290 wgsl::Variable::GlobalInvocationIdY
291 }
292 cube::Builtin::AbsolutePosZ => {
293 self.global_invocation_id = true;
294 wgsl::Variable::GlobalInvocationIdZ
295 }
296 cube::Builtin::CubeDimX => wgsl::Variable::WorkgroupSizeX,
297 cube::Builtin::CubeDimY => wgsl::Variable::WorkgroupSizeY,
298 cube::Builtin::CubeDimZ => wgsl::Variable::WorkgroupSizeZ,
299 cube::Builtin::CubeClusterDim
300 | cube::Builtin::CubeClusterDimX
301 | cube::Builtin::CubeClusterDimY
302 | cube::Builtin::CubeClusterDimZ => self.constant_var(1),
303 cube::Builtin::CubeCountX => {
304 self.num_workgroups = true;
305 wgsl::Variable::NumWorkgroupsX
306 }
307 cube::Builtin::CubeCountY => {
308 self.num_workgroups = true;
309 wgsl::Variable::NumWorkgroupsY
310 }
311 cube::Builtin::CubeCountZ => {
312 self.num_workgroups = true;
313 wgsl::Variable::NumWorkgroupsZ
314 }
315 cube::Builtin::CubePos => {
316 self.workgroup_id_no_axis = true;
317 wgsl::Variable::WorkgroupId
318 }
319 cube::Builtin::CubeDim => {
320 self.workgroup_size_no_axis = true;
321 wgsl::Variable::WorkgroupSize
322 }
323 cube::Builtin::CubeCount => {
324 self.num_workgroup_no_axis = true;
325 wgsl::Variable::NumWorkgroups
326 }
327 cube::Builtin::PlaneDim => {
328 self.subgroup_size = true;
329 wgsl::Variable::SubgroupSize
330 }
331 cube::Builtin::UnitPosPlane => {
332 self.subgroup_invocation_id = true;
333 wgsl::Variable::SubgroupInvocationId
334 }
335 },
336 cube::VariableKind::Matrix { .. } => {
337 panic!("Cooperative matrix-multiply and accumulate not supported.")
338 }
339 cube::VariableKind::Pipeline { .. } => {
340 panic!("Pipeline not supported.")
341 }
342 cube::VariableKind::Barrier { .. } => {
343 panic!("Barrier not supported.")
344 }
345 cube::VariableKind::TensorMap(_) => panic!("Tensor map not supported."),
346 }
347 }
348
349 fn constant_var(&mut self, value: u32) -> wgsl::Variable {
350 let var = cube::Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U32));
351 self.compile_variable(var)
352 }
353
354 fn compile_scope(&mut self, scope: &mut cube::Scope) -> Vec<wgsl::Instruction> {
355 let mut instructions = Vec::new();
356
357 let const_arrays = scope
358 .const_arrays
359 .drain(..)
360 .map(|(var, values)| ConstantArray {
361 index: var.index().unwrap(),
362 item: self.compile_item(var.item),
363 size: values.len() as u32,
364 values: values
365 .into_iter()
366 .map(|val| self.compile_variable(val))
367 .collect(),
368 })
369 .collect::<Vec<_>>();
370 self.const_arrays.extend(const_arrays);
371
372 let checked_io: Box<dyn Processor> = Box::new(CheckedIoProcessor::new(self.strategy));
373 let processing = scope.process([checked_io]);
374
375 for var in processing.variables {
376 instructions.push(wgsl::Instruction::DeclareVariable {
377 var: self.compile_variable(var),
378 });
379 }
380
381 processing
382 .instructions
383 .into_iter()
384 .for_each(|op| self.compile_operation(&mut instructions, op.operation, op.out, scope));
385
386 instructions
387 }
388
389 fn compile_operation(
390 &mut self,
391 instructions: &mut Vec<wgsl::Instruction>,
392 operation: cube::Operation,
393 out: Option<cube::Variable>,
394 scope: &mut cube::Scope,
395 ) {
396 match operation {
397 cube::Operation::Copy(variable) => instructions.push(wgsl::Instruction::Assign {
398 input: self.compile_variable(variable),
399 out: self.compile_variable(out.unwrap()),
400 }),
401 cube::Operation::Arithmetic(op) => {
402 self.compile_arithmetic(op, out, instructions, scope)
403 }
404 cube::Operation::Comparison(op) => self.compile_cmp(op, out, instructions),
405 cube::Operation::Bitwise(op) => self.compile_bitwise(op, out, instructions),
406 cube::Operation::Operator(op) => self.compile_operator(op, out, instructions),
407 cube::Operation::Atomic(op) => instructions.push(self.compile_atomic(op, out)),
408 cube::Operation::Metadata(op) => instructions.push(self.compile_metadata(op, out)),
409 cube::Operation::Branch(val) => self.compile_branch(instructions, val),
410 cube::Operation::Synchronization(val) => {
411 self.compile_synchronization(instructions, val)
412 }
413 cube::Operation::Plane(op) => self.compile_subgroup(instructions, op, out),
414 cube::Operation::CoopMma(_) => {
415 panic!("Cooperative matrix-multiply and accumulate isn't supported on wgpu.")
416 }
417 cube::Operation::NonSemantic(cube::NonSemantic::Comment { content }) => {
418 self.compile_comment(instructions, content)
419 }
420 cube::Operation::NonSemantic(_) => {}
421 cube::Operation::Barrier(_) => {
422 panic!("Barrier isn't supported on wgpu.")
423 }
424 cube::Operation::Tma(_) => panic!("TMA isn't supported on wgpu."),
425 }
426 }
427
428 fn compile_subgroup(
429 &mut self,
430 instructions: &mut Vec<wgsl::Instruction>,
431 subgroup: cube::Plane,
432 out: Option<cube::Variable>,
433 ) {
434 self.subgroup_instructions_used = true;
435
436 let out = out.unwrap();
437 let op = match subgroup {
438 cube::Plane::Elect => Subgroup::Elect {
439 out: self.compile_variable(out),
440 },
441 cube::Plane::All(op) => Subgroup::All {
442 input: self.compile_variable(op.input),
443 out: self.compile_variable(out),
444 },
445 cube::Plane::Any(op) => Subgroup::Any {
446 input: self.compile_variable(op.input),
447 out: self.compile_variable(out),
448 },
449 cube::Plane::Ballot(op) => Subgroup::Ballot {
450 input: self.compile_variable(op.input),
451 out: self.compile_variable(out),
452 },
453 cube::Plane::Broadcast(op) => Subgroup::Broadcast {
454 lhs: self.compile_variable(op.lhs),
455 rhs: self.compile_variable(op.rhs),
456 out: self.compile_variable(out),
457 },
458 cube::Plane::Sum(op) => Subgroup::Sum {
459 input: self.compile_variable(op.input),
460 out: self.compile_variable(out),
461 },
462 cube::Plane::ExclusiveSum(op) => Subgroup::ExclusiveSum {
463 input: self.compile_variable(op.input),
464 out: self.compile_variable(out),
465 },
466 cube::Plane::InclusiveSum(op) => Subgroup::InclusiveSum {
467 input: self.compile_variable(op.input),
468 out: self.compile_variable(out),
469 },
470 cube::Plane::Prod(op) => Subgroup::Prod {
471 input: self.compile_variable(op.input),
472 out: self.compile_variable(out),
473 },
474 cube::Plane::ExclusiveProd(op) => Subgroup::ExclusiveProd {
475 input: self.compile_variable(op.input),
476 out: self.compile_variable(out),
477 },
478 cube::Plane::InclusiveProd(op) => Subgroup::InclusiveProd {
479 input: self.compile_variable(op.input),
480 out: self.compile_variable(out),
481 },
482 cube::Plane::Min(op) => Subgroup::Min {
483 input: self.compile_variable(op.input),
484 out: self.compile_variable(out),
485 },
486 cube::Plane::Max(op) => Subgroup::Max {
487 input: self.compile_variable(op.input),
488 out: self.compile_variable(out),
489 },
490 };
491
492 instructions.push(wgsl::Instruction::Subgroup(op));
493 }
494
495 fn compile_branch(&mut self, instructions: &mut Vec<wgsl::Instruction>, branch: cube::Branch) {
496 match branch {
497 cube::Branch::If(mut op) => instructions.push(wgsl::Instruction::If {
498 cond: self.compile_variable(op.cond),
499 instructions: self.compile_scope(&mut op.scope),
500 }),
501 cube::Branch::IfElse(mut op) => instructions.push(wgsl::Instruction::IfElse {
502 cond: self.compile_variable(op.cond),
503 instructions_if: self.compile_scope(&mut op.scope_if),
504 instructions_else: self.compile_scope(&mut op.scope_else),
505 }),
506 cube::Branch::Switch(mut op) => instructions.push(wgsl::Instruction::Switch {
507 value: self.compile_variable(op.value),
508 instructions_default: self.compile_scope(&mut op.scope_default),
509 cases: op
510 .cases
511 .into_iter()
512 .map(|(val, mut scope)| {
513 (self.compile_variable(val), self.compile_scope(&mut scope))
514 })
515 .collect(),
516 }),
517 cube::Branch::Return => instructions.push(wgsl::Instruction::Return),
518 cube::Branch::Break => instructions.push(wgsl::Instruction::Break),
519 cube::Branch::RangeLoop(mut range_loop) => {
520 instructions.push(wgsl::Instruction::RangeLoop {
521 i: self.compile_variable(range_loop.i),
522 start: self.compile_variable(range_loop.start),
523 end: self.compile_variable(range_loop.end),
524 step: range_loop.step.map(|it| self.compile_variable(it)),
525 inclusive: range_loop.inclusive,
526 instructions: self.compile_scope(&mut range_loop.scope),
527 })
528 }
529 cube::Branch::Loop(mut op) => instructions.push(wgsl::Instruction::Loop {
530 instructions: self.compile_scope(&mut op.scope),
531 }),
532 };
533 }
534
535 fn compile_synchronization(
536 &mut self,
537 instructions: &mut Vec<wgsl::Instruction>,
538 synchronization: cube::Synchronization,
539 ) {
540 match synchronization {
541 cube::Synchronization::SyncCube => {
542 instructions.push(wgsl::Instruction::WorkgroupBarrier)
543 }
544 cube::Synchronization::SyncPlane => {
545 panic!("Synchronization within a plane is not supported in WGSL")
546 }
547 cube::Synchronization::SyncStorage => {
548 instructions.push(wgsl::Instruction::StorageBarrier)
549 }
550 cube::Synchronization::SyncProxyShared => panic!("TMA is not supported in WGSL"),
551 };
552 }
553
554 fn compile_comment(&mut self, instructions: &mut Vec<wgsl::Instruction>, content: String) {
555 instructions.push(wgsl::Instruction::Comment { content })
556 }
557
558 fn compile_metadata(
559 &mut self,
560 metadata: cube::Metadata,
561 out: Option<cube::Variable>,
562 ) -> wgsl::Instruction {
563 let out = out.unwrap();
564 match metadata {
565 cube::Metadata::Rank { var } => {
566 let position = self.ext_meta_pos(&var);
567 let offset = self.metadata.rank_index(position);
568 wgsl::Instruction::Metadata {
569 out: self.compile_variable(out),
570 info_offset: self.compile_variable(offset.into()),
571 }
572 }
573 cube::Metadata::Stride { dim, var } => {
574 let position = self.ext_meta_pos(&var);
575 let offset = self.metadata.stride_offset_index(position);
576 wgsl::Instruction::ExtendedMeta {
577 info_offset: self.compile_variable(offset.into()),
578 dim: self.compile_variable(dim),
579 out: self.compile_variable(out),
580 }
581 }
582 cube::Metadata::Shape { dim, var } => {
583 let position = self.ext_meta_pos(&var);
584 let offset = self.metadata.shape_offset_index(position);
585 wgsl::Instruction::ExtendedMeta {
586 info_offset: self.compile_variable(offset.into()),
587 dim: self.compile_variable(dim),
588 out: self.compile_variable(out),
589 }
590 }
591 cube::Metadata::Length { var } => match var.kind {
592 cube::VariableKind::GlobalInputArray(id) => {
593 let offset = self.metadata.len_index(id);
594 wgsl::Instruction::Metadata {
595 out: self.compile_variable(out),
596 info_offset: self.compile_variable(offset.into()),
597 }
598 }
599 cube::VariableKind::GlobalOutputArray(id) => {
600 let offset = self.metadata.len_index(id);
601 wgsl::Instruction::Metadata {
602 out: self.compile_variable(out),
603 info_offset: self.compile_variable(offset.into()),
604 }
605 }
606 _ => wgsl::Instruction::Length {
607 var: self.compile_variable(var),
608 out: self.compile_variable(out),
609 },
610 },
611 cube::Metadata::BufferLength { var } => match var.kind {
612 cube::VariableKind::GlobalInputArray(id) => {
613 let offset = self.metadata.buffer_len_index(id);
614 wgsl::Instruction::Metadata {
615 out: self.compile_variable(out),
616 info_offset: self.compile_variable(offset.into()),
617 }
618 }
619 cube::VariableKind::GlobalOutputArray(id) => {
620 let offset = self.metadata.buffer_len_index(id);
621 wgsl::Instruction::Metadata {
622 out: self.compile_variable(out),
623 info_offset: self.compile_variable(offset.into()),
624 }
625 }
626 _ => wgsl::Instruction::Length {
627 var: self.compile_variable(var),
628 out: self.compile_variable(out),
629 },
630 },
631 }
632 }
633
634 fn compile_arithmetic(
635 &mut self,
636 value: cube::Arithmetic,
637 out: Option<cube::Variable>,
638 instructions: &mut Vec<wgsl::Instruction>,
639 scope: &mut Scope,
640 ) {
641 let out = out.unwrap();
642 match value {
643 cube::Arithmetic::Max(op) => instructions.push(wgsl::Instruction::Max {
644 lhs: self.compile_variable(op.lhs),
645 rhs: self.compile_variable(op.rhs),
646 out: self.compile_variable(out),
647 }),
648 cube::Arithmetic::Min(op) => instructions.push(wgsl::Instruction::Min {
649 lhs: self.compile_variable(op.lhs),
650 rhs: self.compile_variable(op.rhs),
651 out: self.compile_variable(out),
652 }),
653 cube::Arithmetic::Add(op) => instructions.push(wgsl::Instruction::Add {
654 lhs: self.compile_variable(op.lhs),
655 rhs: self.compile_variable(op.rhs),
656 out: self.compile_variable(out),
657 }),
658 cube::Arithmetic::Fma(op) => instructions.push(wgsl::Instruction::Fma {
659 a: self.compile_variable(op.a),
660 b: self.compile_variable(op.b),
661 c: self.compile_variable(op.c),
662 out: self.compile_variable(out),
663 }),
664 cube::Arithmetic::Modulo(op) => instructions.push(wgsl::Instruction::Modulo {
665 lhs: self.compile_variable(op.lhs),
666 rhs: self.compile_variable(op.rhs),
667 out: self.compile_variable(out),
668 }),
669 cube::Arithmetic::Sub(op) => instructions.push(wgsl::Instruction::Sub {
670 lhs: self.compile_variable(op.lhs),
671 rhs: self.compile_variable(op.rhs),
672 out: self.compile_variable(out),
673 }),
674 cube::Arithmetic::Mul(op) => instructions.push(wgsl::Instruction::Mul {
675 lhs: self.compile_variable(op.lhs),
676 rhs: self.compile_variable(op.rhs),
677 out: self.compile_variable(out),
678 }),
679 cube::Arithmetic::Div(op) => instructions.push(wgsl::Instruction::Div {
680 lhs: self.compile_variable(op.lhs),
681 rhs: self.compile_variable(op.rhs),
682 out: self.compile_variable(out),
683 }),
684 cube::Arithmetic::Abs(op) => instructions.push(wgsl::Instruction::Abs {
685 input: self.compile_variable(op.input),
686 out: self.compile_variable(out),
687 }),
688 cube::Arithmetic::Exp(op) => instructions.push(wgsl::Instruction::Exp {
689 input: self.compile_variable(op.input),
690 out: self.compile_variable(out),
691 }),
692 cube::Arithmetic::Log(op) => instructions.push(wgsl::Instruction::Log {
693 input: self.compile_variable(op.input),
694 out: self.compile_variable(out),
695 }),
696 cube::Arithmetic::Log1p(op) => instructions.push(wgsl::Instruction::Log1p {
697 input: self.compile_variable(op.input),
698 out: self.compile_variable(out),
699 }),
700 cube::Arithmetic::Cos(op) => instructions.push(wgsl::Instruction::Cos {
701 input: self.compile_variable(op.input),
702 out: self.compile_variable(out),
703 }),
704 cube::Arithmetic::Sin(op) => instructions.push(wgsl::Instruction::Sin {
705 input: self.compile_variable(op.input),
706 out: self.compile_variable(out),
707 }),
708 cube::Arithmetic::Tanh(op) => instructions.push(wgsl::Instruction::Tanh {
709 input: self.compile_variable(op.input),
710 out: self.compile_variable(out),
711 }),
712 cube::Arithmetic::Powf(op) => instructions.push(wgsl::Instruction::Powf {
713 lhs: self.compile_variable(op.lhs),
714 rhs: self.compile_variable(op.rhs),
715 out: self.compile_variable(out),
716 }),
717 cube::Arithmetic::Sqrt(op) => instructions.push(wgsl::Instruction::Sqrt {
718 input: self.compile_variable(op.input),
719 out: self.compile_variable(out),
720 }),
721 cube::Arithmetic::Round(op) => instructions.push(wgsl::Instruction::Round {
722 input: self.compile_variable(op.input),
723 out: self.compile_variable(out),
724 }),
725 cube::Arithmetic::Floor(op) => instructions.push(wgsl::Instruction::Floor {
726 input: self.compile_variable(op.input),
727 out: self.compile_variable(out),
728 }),
729 cube::Arithmetic::Ceil(op) => instructions.push(wgsl::Instruction::Ceil {
730 input: self.compile_variable(op.input),
731 out: self.compile_variable(out),
732 }),
733 cube::Arithmetic::Erf(op) => {
734 let mut scope = scope.child();
735 expand_erf(&mut scope, op.input, out);
736 instructions.extend(self.compile_scope(&mut scope));
737 }
738 cube::Arithmetic::MulHi(op) => {
739 let mut scope = scope.child();
740 match self.compilation_options.supports_u64 {
741 true => expand_himul_64(&mut scope, op.lhs, op.rhs, out),
742 false => expand_himul_sim(&mut scope, op.lhs, op.rhs, out),
743 }
744 instructions.extend(self.compile_scope(&mut scope));
745 }
746 cube::Arithmetic::Recip(op) => instructions.push(wgsl::Instruction::Recip {
747 input: self.compile_variable(op.input),
748 out: self.compile_variable(out),
749 }),
750 cube::Arithmetic::Clamp(op) => instructions.push(wgsl::Instruction::Clamp {
751 input: self.compile_variable(op.input),
752 min_value: self.compile_variable(op.min_value),
753 max_value: self.compile_variable(op.max_value),
754 out: self.compile_variable(out),
755 }),
756 cube::Arithmetic::Remainder(op) => instructions.push(wgsl::Instruction::Remainder {
757 lhs: self.compile_variable(op.lhs),
758 rhs: self.compile_variable(op.rhs),
759 out: self.compile_variable(out),
760 }),
761 cube::Arithmetic::Neg(op) => instructions.push(wgsl::Instruction::Negate {
762 input: self.compile_variable(op.input),
763 out: self.compile_variable(out),
764 }),
765 cube::Arithmetic::Magnitude(op) => instructions.push(wgsl::Instruction::Magnitude {
766 input: self.compile_variable(op.input),
767 out: self.compile_variable(out),
768 }),
769 cube::Arithmetic::Normalize(op) => instructions.push(wgsl::Instruction::Normalize {
770 input: self.compile_variable(op.input),
771 out: self.compile_variable(out),
772 }),
773 cube::Arithmetic::Dot(op) => instructions.push(wgsl::Instruction::Dot {
774 lhs: self.compile_variable(op.lhs),
775 rhs: self.compile_variable(op.rhs),
776 out: self.compile_variable(out),
777 }),
778 }
779 }
780
781 fn compile_cmp(
782 &mut self,
783 value: cube::Comparison,
784 out: Option<cube::Variable>,
785 instructions: &mut Vec<wgsl::Instruction>,
786 ) {
787 let out = out.unwrap();
788 match value {
789 cube::Comparison::Equal(op) => instructions.push(wgsl::Instruction::Equal {
790 lhs: self.compile_variable(op.lhs),
791 rhs: self.compile_variable(op.rhs),
792 out: self.compile_variable(out),
793 }),
794 cube::Comparison::Lower(op) => instructions.push(wgsl::Instruction::Lower {
795 lhs: self.compile_variable(op.lhs),
796 rhs: self.compile_variable(op.rhs),
797 out: self.compile_variable(out),
798 }),
799 cube::Comparison::Greater(op) => instructions.push(wgsl::Instruction::Greater {
800 lhs: self.compile_variable(op.lhs),
801 rhs: self.compile_variable(op.rhs),
802 out: self.compile_variable(out),
803 }),
804 cube::Comparison::LowerEqual(op) => instructions.push(wgsl::Instruction::LowerEqual {
805 lhs: self.compile_variable(op.lhs),
806 rhs: self.compile_variable(op.rhs),
807 out: self.compile_variable(out),
808 }),
809 cube::Comparison::GreaterEqual(op) => {
810 instructions.push(wgsl::Instruction::GreaterEqual {
811 lhs: self.compile_variable(op.lhs),
812 rhs: self.compile_variable(op.rhs),
813 out: self.compile_variable(out),
814 })
815 }
816 cube::Comparison::NotEqual(op) => instructions.push(wgsl::Instruction::NotEqual {
817 lhs: self.compile_variable(op.lhs),
818 rhs: self.compile_variable(op.rhs),
819 out: self.compile_variable(out),
820 }),
821 }
822 }
823
824 fn compile_bitwise(
825 &mut self,
826 value: cube::Bitwise,
827 out: Option<cube::Variable>,
828 instructions: &mut Vec<wgsl::Instruction>,
829 ) {
830 let out = out.unwrap();
831 match value {
832 cube::Bitwise::BitwiseOr(op) => instructions.push(wgsl::Instruction::BitwiseOr {
833 lhs: self.compile_variable(op.lhs),
834 rhs: self.compile_variable(op.rhs),
835 out: self.compile_variable(out),
836 }),
837 cube::Bitwise::BitwiseAnd(op) => instructions.push(wgsl::Instruction::BitwiseAnd {
838 lhs: self.compile_variable(op.lhs),
839 rhs: self.compile_variable(op.rhs),
840 out: self.compile_variable(out),
841 }),
842 cube::Bitwise::BitwiseXor(op) => instructions.push(wgsl::Instruction::BitwiseXor {
843 lhs: self.compile_variable(op.lhs),
844 rhs: self.compile_variable(op.rhs),
845 out: self.compile_variable(out),
846 }),
847 cube::Bitwise::CountOnes(op) => instructions.push(wgsl::Instruction::CountBits {
848 input: self.compile_variable(op.input),
849 out: self.compile_variable(out),
850 }),
851 cube::Bitwise::ReverseBits(op) => instructions.push(wgsl::Instruction::ReverseBits {
852 input: self.compile_variable(op.input),
853 out: self.compile_variable(out),
854 }),
855 cube::Bitwise::ShiftLeft(op) => instructions.push(wgsl::Instruction::ShiftLeft {
856 lhs: self.compile_variable(op.lhs),
857 rhs: self.compile_variable(op.rhs),
858 out: self.compile_variable(out),
859 }),
860 cube::Bitwise::ShiftRight(op) => instructions.push(wgsl::Instruction::ShiftRight {
861 lhs: self.compile_variable(op.lhs),
862 rhs: self.compile_variable(op.rhs),
863 out: self.compile_variable(out),
864 }),
865 cube::Bitwise::BitwiseNot(op) => instructions.push(wgsl::Instruction::BitwiseNot {
866 input: self.compile_variable(op.input),
867 out: self.compile_variable(out),
868 }),
869 cube::Bitwise::LeadingZeros(op) => instructions.push(wgsl::Instruction::LeadingZeros {
870 input: self.compile_variable(op.input),
871 out: self.compile_variable(out),
872 }),
873 cube::Bitwise::FindFirstSet(op) => instructions.push(wgsl::Instruction::FindFirstSet {
874 input: self.compile_variable(op.input),
875 out: self.compile_variable(out),
876 }),
877 }
878 }
879
880 fn compile_operator(
881 &mut self,
882 value: cube::Operator,
883 out: Option<cube::Variable>,
884 instructions: &mut Vec<wgsl::Instruction>,
885 ) {
886 let out = out.unwrap();
887 match value {
888 cube::Operator::Cast(op) => instructions.push(wgsl::Instruction::Assign {
889 input: self.compile_variable(op.input),
890 out: self.compile_variable(out),
891 }),
892 cube::Operator::Index(op) | cube::Operator::UncheckedIndex(op) => {
893 instructions.push(wgsl::Instruction::Index {
894 lhs: self.compile_variable(op.list),
895 rhs: self.compile_variable(op.index),
896 out: self.compile_variable(out),
897 });
898 }
899 cube::Operator::IndexAssign(op) | cube::Operator::UncheckedIndexAssign(op) => {
900 instructions.push(wgsl::Instruction::IndexAssign {
901 index: self.compile_variable(op.index),
902 rhs: self.compile_variable(op.value),
903 out: self.compile_variable(out),
904 })
905 }
906 cube::Operator::And(op) => instructions.push(wgsl::Instruction::And {
907 lhs: self.compile_variable(op.lhs),
908 rhs: self.compile_variable(op.rhs),
909 out: self.compile_variable(out),
910 }),
911 cube::Operator::Or(op) => instructions.push(wgsl::Instruction::Or {
912 lhs: self.compile_variable(op.lhs),
913 rhs: self.compile_variable(op.rhs),
914 out: self.compile_variable(out),
915 }),
916 cube::Operator::Not(op) => instructions.push(wgsl::Instruction::Not {
917 input: self.compile_variable(op.input),
918 out: self.compile_variable(out),
919 }),
920 cube::Operator::Reinterpret(op) => instructions.push(wgsl::Instruction::Bitcast {
921 input: self.compile_variable(op.input),
922 out: self.compile_variable(out),
923 }),
924 cube::Operator::InitLine(op) => instructions.push(wgsl::Instruction::VecInit {
925 inputs: op
926 .inputs
927 .into_iter()
928 .map(|var| self.compile_variable(var))
929 .collect(),
930 out: self.compile_variable(out),
931 }),
932 cube::Operator::CopyMemory(op) => instructions.push(wgsl::Instruction::Copy {
933 input: self.compile_variable(op.input),
934 in_index: self.compile_variable(op.in_index),
935 out: self.compile_variable(out),
936 out_index: self.compile_variable(op.out_index),
937 }),
938 cube::Operator::CopyMemoryBulk(op) => instructions.push(wgsl::Instruction::CopyBulk {
939 input: self.compile_variable(op.input),
940 in_index: self.compile_variable(op.in_index),
941 out: self.compile_variable(out),
942 out_index: self.compile_variable(op.out_index),
943 len: op.len.as_const().unwrap().as_u32(),
944 }),
945 cube::Operator::Select(op) => instructions.push(wgsl::Instruction::Select {
946 cond: self.compile_variable(op.cond),
947 then: self.compile_variable(op.then),
948 or_else: self.compile_variable(op.or_else),
949 out: self.compile_variable(out),
950 }),
951 }
952 }
953
954 fn compile_atomic(
955 &mut self,
956 atomic: cube::AtomicOp,
957 out: Option<cube::Variable>,
958 ) -> wgsl::Instruction {
959 let out = out.unwrap();
960 match atomic {
961 cube::AtomicOp::Add(op) => wgsl::Instruction::AtomicAdd {
962 lhs: self.compile_variable(op.lhs),
963 rhs: self.compile_variable(op.rhs),
964 out: self.compile_variable(out),
965 },
966 cube::AtomicOp::Sub(op) => wgsl::Instruction::AtomicSub {
967 lhs: self.compile_variable(op.lhs),
968 rhs: self.compile_variable(op.rhs),
969 out: self.compile_variable(out),
970 },
971 cube::AtomicOp::Max(op) => wgsl::Instruction::AtomicMax {
972 lhs: self.compile_variable(op.lhs),
973 rhs: self.compile_variable(op.rhs),
974 out: self.compile_variable(out),
975 },
976 cube::AtomicOp::Min(op) => wgsl::Instruction::AtomicMin {
977 lhs: self.compile_variable(op.lhs),
978 rhs: self.compile_variable(op.rhs),
979 out: self.compile_variable(out),
980 },
981 cube::AtomicOp::And(op) => wgsl::Instruction::AtomicAnd {
982 lhs: self.compile_variable(op.lhs),
983 rhs: self.compile_variable(op.rhs),
984 out: self.compile_variable(out),
985 },
986 cube::AtomicOp::Or(op) => wgsl::Instruction::AtomicOr {
987 lhs: self.compile_variable(op.lhs),
988 rhs: self.compile_variable(op.rhs),
989 out: self.compile_variable(out),
990 },
991 cube::AtomicOp::Xor(op) => wgsl::Instruction::AtomicXor {
992 lhs: self.compile_variable(op.lhs),
993 rhs: self.compile_variable(op.rhs),
994 out: self.compile_variable(out),
995 },
996 cube::AtomicOp::Load(op) => wgsl::Instruction::AtomicLoad {
997 input: self.compile_variable(op.input),
998 out: self.compile_variable(out),
999 },
1000 cube::AtomicOp::Store(op) => wgsl::Instruction::AtomicStore {
1001 input: self.compile_variable(op.input),
1002 out: self.compile_variable(out),
1003 },
1004 cube::AtomicOp::Swap(op) => wgsl::Instruction::AtomicSwap {
1005 lhs: self.compile_variable(op.lhs),
1006 rhs: self.compile_variable(op.rhs),
1007 out: self.compile_variable(out),
1008 },
1009 cube::AtomicOp::CompareAndSwap(op) => wgsl::Instruction::AtomicCompareExchangeWeak {
1010 lhs: self.compile_variable(op.input),
1011 cmp: self.compile_variable(op.cmp),
1012 value: self.compile_variable(op.val),
1013 out: self.compile_variable(out),
1014 },
1015 }
1016 }
1017
1018 fn compile_location(value: compute::Location) -> wgsl::Location {
1019 match value {
1020 compute::Location::Storage => wgsl::Location::Storage,
1021 compute::Location::Cube => wgsl::Location::Workgroup,
1022 }
1023 }
1024
1025 fn compile_binding(&mut self, value: compute::Binding) -> wgsl::Binding {
1026 wgsl::Binding {
1027 id: value.id,
1028 visibility: value.visibility,
1029 location: Self::compile_location(value.location),
1030 item: self.compile_item(value.item),
1031 size: value.size,
1032 }
1033 }
1034}
1035
1036fn register_extensions(instructions: &[wgsl::Instruction]) -> Vec<wgsl::Extension> {
1037 let mut extensions = Vec::new();
1038
1039 let mut register_extension = |extension: wgsl::Extension| {
1040 if !extensions.contains(&extension) {
1041 extensions.push(extension);
1042 }
1043 };
1044
1045 for instruction in instructions {
1047 match instruction {
1048 wgsl::Instruction::Powf { lhs: _, rhs, out } => {
1049 register_extension(wgsl::Extension::PowfPrimitive(out.item()));
1050
1051 if rhs.is_always_scalar() || rhs.item().vectorization_factor() == 1 {
1052 register_extension(wgsl::Extension::PowfScalar(out.item()));
1053 } else {
1054 register_extension(wgsl::Extension::Powf(out.item()));
1055 }
1056 }
1057 #[cfg(target_os = "macos")]
1058 wgsl::Instruction::Tanh { input, out: _ } => {
1059 register_extension(wgsl::Extension::SafeTanh(input.item()))
1060 }
1061 wgsl::Instruction::If { instructions, .. } => {
1062 for extension in register_extensions(instructions) {
1063 register_extension(extension);
1064 }
1065 }
1066 wgsl::Instruction::IfElse {
1067 instructions_if,
1068 instructions_else,
1069 ..
1070 } => {
1071 for extension in register_extensions(instructions_if) {
1072 register_extension(extension);
1073 }
1074 for extension in register_extensions(instructions_else) {
1075 register_extension(extension);
1076 }
1077 }
1078 wgsl::Instruction::Loop { instructions } => {
1079 for extension in register_extensions(instructions) {
1080 register_extension(extension);
1081 }
1082 }
1083 wgsl::Instruction::RangeLoop { instructions, .. } => {
1084 for extension in register_extensions(instructions) {
1085 register_extension(extension);
1086 }
1087 }
1088 _ => {}
1089 }
1090 }
1091
1092 extensions
1093}