1use super::Subgroup;
2use super::{ConstantArray, shader::ComputeShader};
3use super::{Item, LocalArray, SharedMemory};
4use crate::compiler::wgsl;
5
6use cubecl_common::ExecutionMode;
7use cubecl_core::post_processing::{
8 checked_io::CheckedIoProcessor, saturating::SaturatingArithmeticProcessor,
9};
10use cubecl_core::prelude::*;
11use cubecl_core::{
12 Metadata, WgpuCompilationOptions, compute,
13 ir::{self as cube, Scope},
14 prelude::expand_erf,
15};
16use cubecl_core::{
17 ir::{ConstantScalarValue, Processor, UIntKind},
18 post_processing::unroll::UnrollProcessor,
19};
20
/// Widest line (vector) size this compiler works with.
///
/// Buffer and variable types whose line size exceeds this value are re-lined
/// to it before being compiled, and the unroll processor is constructed with it.
pub const MAX_LINE_SIZE: u32 = 4;
22
23#[derive(Clone, Default)]
25pub struct WgslCompiler {
26 metadata: Metadata,
27 ext_meta_pos: Vec<u32>,
28 local_invocation_index: bool,
29 local_invocation_id: bool,
30 global_invocation_id: bool,
32 workgroup_id: bool,
33 subgroup_size: bool,
34 subgroup_invocation_id: bool,
35 id: bool,
36 num_workgroups: bool,
37 workgroup_id_no_axis: bool,
38 workgroup_size_no_axis: bool,
39 num_workgroup_no_axis: bool,
40 shared_memories: Vec<SharedMemory>,
41 const_arrays: Vec<ConstantArray>,
42 local_arrays: Vec<LocalArray>,
43 #[allow(dead_code)]
44 compilation_options: WgpuCompilationOptions,
45 strategy: ExecutionMode,
46 subgroup_instructions_used: bool,
47 f16_used: bool,
48}
49
50impl core::fmt::Debug for WgslCompiler {
51 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52 f.write_str("WgslCompiler")
53 }
54}
55
56impl cubecl_core::Compiler for WgslCompiler {
57 type Representation = ComputeShader;
58 type CompilationOptions = WgpuCompilationOptions;
59
60 fn compile(
61 &mut self,
62 shader: compute::KernelDefinition,
63 compilation_options: &Self::CompilationOptions,
64 mode: ExecutionMode,
65 ) -> Self::Representation {
66 self.compilation_options = compilation_options.clone();
67 self.compile_shader(shader, mode)
68 }
69
70 fn elem_size(&self, elem: cube::ElemType) -> usize {
71 elem.size()
72 }
73
74 fn extension(&self) -> &'static str {
75 "wgsl"
76 }
77}
78
79impl WgslCompiler {
80 fn compile_shader(
81 &mut self,
82 mut value: compute::KernelDefinition,
83 mode: ExecutionMode,
84 ) -> wgsl::ComputeShader {
85 self.strategy = mode;
86
87 let num_meta = value.buffers.len();
88
89 self.ext_meta_pos = Vec::new();
90 let mut num_ext = 0;
91
92 for binding in value.buffers.iter() {
93 self.ext_meta_pos.push(num_ext);
94 if binding.has_extended_meta {
95 num_ext += 1;
96 }
97 }
98
99 self.metadata = Metadata::new(num_meta as u32, num_ext);
100
101 let instructions = self.compile_scope(&mut value.body);
102 let extensions = register_extensions(&instructions);
103 let body = wgsl::Body {
104 instructions,
105 id: self.id,
106 };
107
108 wgsl::ComputeShader {
109 buffers: value
110 .buffers
111 .into_iter()
112 .map(|mut it| {
113 if it.ty.line_size() > MAX_LINE_SIZE {
116 it.ty = it.ty.line(MAX_LINE_SIZE);
117 }
118 self.compile_binding(it)
119 })
120 .collect(),
121 scalars: value
122 .scalars
123 .into_iter()
124 .map(|binding| (self.compile_storage_type(binding.ty), binding.count))
125 .collect(),
126 shared_memories: self.shared_memories.clone(),
127 constant_arrays: self.const_arrays.clone(),
128 local_arrays: self.local_arrays.clone(),
129 has_metadata: self.metadata.static_len() > 0,
130 workgroup_size: value.cube_dim,
131 global_invocation_id: self.global_invocation_id || self.id,
132 local_invocation_index: self.local_invocation_index,
133 local_invocation_id: self.local_invocation_id,
134 num_workgroups: self.id
135 || self.num_workgroups
136 || self.num_workgroup_no_axis
137 || self.workgroup_id_no_axis,
138 workgroup_id: self.workgroup_id || self.workgroup_id_no_axis,
139 subgroup_size: self.subgroup_size,
140 subgroup_invocation_id: self.subgroup_invocation_id,
141 body,
142 extensions,
143 num_workgroups_no_axis: self.num_workgroup_no_axis,
144 workgroup_id_no_axis: self.workgroup_id_no_axis,
145 workgroup_size_no_axis: self.workgroup_size_no_axis,
146 subgroup_instructions_used: self.subgroup_instructions_used,
147 f16_used: self.f16_used,
148 kernel_name: value.options.kernel_name,
149 }
150 }
151
152 fn compile_type(&mut self, item: cube::Type) -> Item {
153 match item {
154 cube::Type::Scalar(ty) => wgsl::Item::Scalar(self.compile_storage_type(ty)),
155 cube::Type::Line(ty, size) => {
156 let elem = self.compile_storage_type(ty);
157 match size {
158 2 => wgsl::Item::Vec2(elem),
159 3 => wgsl::Item::Vec3(elem),
160 4 => wgsl::Item::Vec4(elem),
161 _ => panic!("Unsupported vectorizations scheme {:?}", item.line_size()),
162 }
163 }
164 cube::Type::Semantic(_) => unimplemented!("Can't compile semantic type"),
165 }
166 }
167
168 fn compile_storage_type(&mut self, ty: cube::StorageType) -> wgsl::Elem {
169 match ty {
170 cube::StorageType::Scalar(ty) => self.compile_elem(ty),
171 cube::StorageType::Atomic(ty) => match ty {
172 cube::ElemType::Float(i) => match i {
173 cube::FloatKind::F32 => wgsl::Elem::AtomicF32,
174 kind => panic!("atomic<{kind:?}> is not a valid WgpuElement"),
175 },
176 cube::ElemType::Int(i) => match i {
177 cube::IntKind::I32 => wgsl::Elem::AtomicI32,
178 kind => panic!("atomic<{kind:?}> is not a valid WgpuElement"),
179 },
180 cube::ElemType::UInt(kind) => match kind {
181 cube::UIntKind::U32 => wgsl::Elem::AtomicU32,
182 kind => panic!("{kind:?} is not a valid WgpuElement"),
183 },
184 other => panic!("{other:?} is not a valid WgpuElement"),
185 },
186 cube::StorageType::Packed(_, _) => {
187 unimplemented!("Packed types not yet supported in WGSL")
188 }
189 }
190 }
191
192 fn compile_elem(&mut self, value: cube::ElemType) -> wgsl::Elem {
193 match value {
194 cube::ElemType::Float(f) => match f {
195 cube::FloatKind::E2M1
196 | cube::FloatKind::E2M3
197 | cube::FloatKind::E3M2
198 | cube::FloatKind::E4M3
199 | cube::FloatKind::E5M2
200 | cube::FloatKind::UE8M0 => panic!("Minifloat is not a valid WgpuElement"),
201 cube::FloatKind::F16 => {
202 self.f16_used = true;
203 wgsl::Elem::F16
204 }
205 cube::FloatKind::BF16 => panic!("bf16 is not a valid WgpuElement"),
206 cube::FloatKind::TF32 => panic!("tf32 is not a valid WgpuElement"),
207 cube::FloatKind::Flex32 => wgsl::Elem::F32,
208 cube::FloatKind::F32 => wgsl::Elem::F32,
209 cube::FloatKind::F64 => wgsl::Elem::F64,
210 },
211 cube::ElemType::Int(i) => match i {
212 cube::IntKind::I32 => wgsl::Elem::I32,
213 cube::IntKind::I64 => wgsl::Elem::I64,
214 kind => panic!("{kind:?} is not a valid WgpuElement"),
215 },
216 cube::ElemType::UInt(kind) => match kind {
217 cube::UIntKind::U32 => wgsl::Elem::U32,
218 cube::UIntKind::U64 => wgsl::Elem::U64,
219 kind => panic!("{kind:?} is not a valid WgpuElement"),
220 },
221 cube::ElemType::Bool => wgsl::Elem::Bool,
222 }
223 }
224
225 fn ext_meta_pos(&self, var: &cube::Variable) -> u32 {
226 let pos = var.index().expect("Variable should have index");
227 self.ext_meta_pos[pos as usize]
228 }
229
    /// Lowers an IR variable into its WGSL counterpart.
    ///
    /// As a side effect, records on `self` which builtins the kernel reads
    /// and registers shared memories and local arrays the first time their id
    /// is seen.
    ///
    /// # Panics
    ///
    /// Panics for matrix, pipeline, barrier and tensor-map variables, and
    /// whenever the variable's type cannot be compiled.
    pub(crate) fn compile_variable(&mut self, value: cube::Variable) -> wgsl::Variable {
        let item = value.ty;
        match value.kind {
            cube::VariableKind::GlobalInputArray(id) => {
                wgsl::Variable::GlobalInputArray(id, self.compile_type(item))
            }
            cube::VariableKind::GlobalScalar(id) => {
                wgsl::Variable::GlobalScalar(id, self.compile_storage_type(item.storage_type()))
            }
            // Versioned variables are emitted as plain mutable locals.
            cube::VariableKind::LocalMut { id } | cube::VariableKind::Versioned { id, .. } => {
                wgsl::Variable::LocalMut {
                    id,
                    item: self.compile_type(item),
                }
            }
            cube::VariableKind::LocalConst { id } => wgsl::Variable::LocalConst {
                id,
                item: self.compile_type(item),
            },
            cube::VariableKind::GlobalOutputArray(id) => {
                wgsl::Variable::GlobalOutputArray(id, self.compile_type(item))
            }
            cube::VariableKind::ConstantScalar(value) => {
                wgsl::Variable::ConstantScalar(value, self.compile_elem(value.elem_type()))
            }
            cube::VariableKind::SharedMemory {
                id,
                length,
                unroll_factor,
                alignment,
            } => {
                let item = self.compile_type(item);
                // Register the shared memory only once per id. The registered
                // size is `length * unroll_factor`, while the returned variable
                // carries the plain `length`.
                if !self.shared_memories.iter().any(|s| s.index == id) {
                    self.shared_memories.push(SharedMemory::new(
                        id,
                        item,
                        length * unroll_factor,
                        alignment,
                    ));
                }
                wgsl::Variable::SharedMemory(id, item, length)
            }
            cube::VariableKind::ConstantArray { id, length, .. } => {
                let item = self.compile_type(item);
                wgsl::Variable::ConstantArray(id, item, length)
            }
            cube::VariableKind::LocalArray {
                id,
                length,
                unroll_factor,
            } => {
                let item = self.compile_type(item);
                // Same once-per-id registration as for shared memories.
                if !self.local_arrays.iter().any(|s| s.index == id) {
                    self.local_arrays
                        .push(LocalArray::new(id, item, length * unroll_factor));
                }
                wgsl::Variable::LocalArray(id, item, length)
            }
            // Each builtin sets the flag that `compile_shader` later copies
            // into the shader description.
            cube::VariableKind::Builtin(builtin) => match builtin {
                cube::Builtin::AbsolutePos => {
                    self.id = true;
                    wgsl::Variable::Id
                }
                cube::Builtin::UnitPos => {
                    self.local_invocation_index = true;
                    wgsl::Variable::LocalInvocationIndex
                }
                cube::Builtin::UnitPosX => {
                    self.local_invocation_id = true;
                    wgsl::Variable::LocalInvocationIdX
                }
                cube::Builtin::UnitPosY => {
                    self.local_invocation_id = true;
                    wgsl::Variable::LocalInvocationIdY
                }
                cube::Builtin::UnitPosZ => {
                    self.local_invocation_id = true;
                    wgsl::Variable::LocalInvocationIdZ
                }
                cube::Builtin::CubePosX => {
                    self.workgroup_id = true;
                    wgsl::Variable::WorkgroupIdX
                }
                cube::Builtin::CubePosY => {
                    self.workgroup_id = true;
                    wgsl::Variable::WorkgroupIdY
                }
                cube::Builtin::CubePosZ => {
                    self.workgroup_id = true;
                    wgsl::Variable::WorkgroupIdZ
                }
                // Cluster positions are lowered to the constant 1.
                cube::Builtin::CubePosCluster
                | cube::Builtin::CubePosClusterX
                | cube::Builtin::CubePosClusterY
                | cube::Builtin::CubePosClusterZ => self.constant_var(1),
                cube::Builtin::AbsolutePosX => {
                    self.global_invocation_id = true;
                    wgsl::Variable::GlobalInvocationIdX
                }
                cube::Builtin::AbsolutePosY => {
                    self.global_invocation_id = true;
                    wgsl::Variable::GlobalInvocationIdY
                }
                cube::Builtin::AbsolutePosZ => {
                    self.global_invocation_id = true;
                    wgsl::Variable::GlobalInvocationIdZ
                }
                // Per-axis workgroup sizes set no flag.
                cube::Builtin::CubeDimX => wgsl::Variable::WorkgroupSizeX,
                cube::Builtin::CubeDimY => wgsl::Variable::WorkgroupSizeY,
                cube::Builtin::CubeDimZ => wgsl::Variable::WorkgroupSizeZ,
                // Cluster dimensions are lowered to the constant 1.
                cube::Builtin::CubeClusterDim
                | cube::Builtin::CubeClusterDimX
                | cube::Builtin::CubeClusterDimY
                | cube::Builtin::CubeClusterDimZ => self.constant_var(1),
                cube::Builtin::CubeCountX => {
                    self.num_workgroups = true;
                    wgsl::Variable::NumWorkgroupsX
                }
                cube::Builtin::CubeCountY => {
                    self.num_workgroups = true;
                    wgsl::Variable::NumWorkgroupsY
                }
                cube::Builtin::CubeCountZ => {
                    self.num_workgroups = true;
                    wgsl::Variable::NumWorkgroupsZ
                }
                cube::Builtin::CubePos => {
                    self.workgroup_id_no_axis = true;
                    wgsl::Variable::WorkgroupId
                }
                cube::Builtin::CubeDim => {
                    self.workgroup_size_no_axis = true;
                    wgsl::Variable::WorkgroupSize
                }
                cube::Builtin::CubeCount => {
                    self.num_workgroup_no_axis = true;
                    wgsl::Variable::NumWorkgroups
                }
                cube::Builtin::PlaneDim => {
                    self.subgroup_size = true;
                    wgsl::Variable::SubgroupSize
                }
                cube::Builtin::UnitPosPlane => {
                    self.subgroup_invocation_id = true;
                    wgsl::Variable::SubgroupInvocationId
                }
            },
            // Variable kinds this backend rejects outright.
            cube::VariableKind::Matrix { .. } => {
                panic!("Cooperative matrix-multiply and accumulate not supported.")
            }
            cube::VariableKind::Pipeline { .. } => {
                panic!("Pipeline not supported.")
            }
            cube::VariableKind::Barrier { .. } => {
                panic!("Barrier not supported.")
            }
            cube::VariableKind::TensorMapInput(_) => panic!("Tensor map not supported."),
            cube::VariableKind::TensorMapOutput(_) => panic!("Tensor map not supported."),
        }
    }
390
391 fn constant_var(&mut self, value: u32) -> wgsl::Variable {
392 let var = cube::Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U32));
393 self.compile_variable(var)
394 }
395
396 fn compile_scope(&mut self, scope: &mut cube::Scope) -> Vec<wgsl::Instruction> {
397 let mut instructions = Vec::new();
398
399 let const_arrays = scope
400 .const_arrays
401 .drain(..)
402 .map(|(var, values)| ConstantArray {
403 index: var.index().unwrap(),
404 item: self.compile_type(var.ty),
405 size: values.len() as u32,
406 values: values
407 .into_iter()
408 .map(|val| self.compile_variable(val))
409 .collect(),
410 })
411 .collect::<Vec<_>>();
412 self.const_arrays.extend(const_arrays);
413
414 let checked_io: Box<dyn Processor> = Box::new(CheckedIoProcessor::new(self.strategy));
415 let unroll = Box::new(UnrollProcessor::new(MAX_LINE_SIZE));
416 let saturating = Box::new(SaturatingArithmeticProcessor::new(true));
417 let processing = scope.process([&*unroll, &*checked_io, &*saturating]);
418
419 for mut var in processing.variables {
420 if var.ty.line_size() > MAX_LINE_SIZE {
421 var.ty = var.ty.line(MAX_LINE_SIZE);
422 }
423 instructions.push(wgsl::Instruction::DeclareVariable {
424 var: self.compile_variable(var),
425 });
426 }
427
428 processing
429 .instructions
430 .into_iter()
431 .for_each(|op| self.compile_operation(&mut instructions, op.operation, op.out, scope));
432
433 instructions
434 }
435
436 fn compile_operation(
437 &mut self,
438 instructions: &mut Vec<wgsl::Instruction>,
439 operation: cube::Operation,
440 out: Option<cube::Variable>,
441 scope: &mut cube::Scope,
442 ) {
443 match operation {
444 cube::Operation::Copy(variable) => instructions.push(wgsl::Instruction::Assign {
445 input: self.compile_variable(variable),
446 out: self.compile_variable(out.unwrap()),
447 }),
448 cube::Operation::Arithmetic(op) => {
449 self.compile_arithmetic(op, out, instructions, scope)
450 }
451 cube::Operation::Comparison(op) => self.compile_cmp(op, out, instructions),
452 cube::Operation::Bitwise(op) => self.compile_bitwise(op, out, instructions),
453 cube::Operation::Operator(op) => self.compile_operator(op, out, instructions),
454 cube::Operation::Atomic(op) => instructions.push(self.compile_atomic(op, out)),
455 cube::Operation::Metadata(op) => instructions.push(self.compile_metadata(op, out)),
456 cube::Operation::Branch(val) => self.compile_branch(instructions, val),
457 cube::Operation::Synchronization(val) => {
458 self.compile_synchronization(instructions, val)
459 }
460 cube::Operation::Plane(op) => self.compile_subgroup(instructions, op, out),
461 cube::Operation::CoopMma(_) => {
462 panic!("Cooperative matrix-multiply and accumulate isn't supported on wgpu.")
463 }
464 cube::Operation::NonSemantic(cube::NonSemantic::Comment { content }) => {
465 self.compile_comment(instructions, content)
466 }
467 cube::Operation::NonSemantic(_) => {}
468 cube::Operation::Barrier(_) => {
469 panic!("Barrier isn't supported on wgpu.")
470 }
471 cube::Operation::Tma(_) => panic!("TMA isn't supported on wgpu."),
472 cube::Operation::Free(_) => {}
473 }
474 }
475
    /// Lowers one plane (subgroup) operation and appends it to `instructions`
    /// wrapped in `Instruction::Subgroup`.
    ///
    /// Marks the kernel as using subgroup instructions before anything else.
    ///
    /// # Panics
    ///
    /// Panics if `out` is `None`; every plane operation, including `Elect`,
    /// needs an output variable.
    fn compile_subgroup(
        &mut self,
        instructions: &mut Vec<wgsl::Instruction>,
        subgroup: cube::Plane,
        out: Option<cube::Variable>,
    ) {
        self.subgroup_instructions_used = true;

        let out = out.unwrap();
        // In every arm the operands are compiled before the output variable.
        let op = match subgroup {
            cube::Plane::Elect => Subgroup::Elect {
                out: self.compile_variable(out),
            },
            cube::Plane::All(op) => Subgroup::All {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            },
            cube::Plane::Any(op) => Subgroup::Any {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            },
            cube::Plane::Ballot(op) => Subgroup::Ballot {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            },

            cube::Plane::Broadcast(op) => Subgroup::Broadcast {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            },

            cube::Plane::Sum(op) => Subgroup::Sum {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            },

            cube::Plane::ExclusiveSum(op) => Subgroup::ExclusiveSum {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            },
            cube::Plane::InclusiveSum(op) => Subgroup::InclusiveSum {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            },
            cube::Plane::Prod(op) => Subgroup::Prod {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            },
            cube::Plane::ExclusiveProd(op) => Subgroup::ExclusiveProd {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            },
            cube::Plane::InclusiveProd(op) => Subgroup::InclusiveProd {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            },
            cube::Plane::Min(op) => Subgroup::Min {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            },
            cube::Plane::Max(op) => Subgroup::Max {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            },
            cube::Plane::Shuffle(op) => Subgroup::Shuffle {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            },
            cube::Plane::ShuffleXor(op) => Subgroup::ShuffleXor {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            },
            cube::Plane::ShuffleUp(op) => Subgroup::ShuffleUp {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            },
            cube::Plane::ShuffleDown(op) => Subgroup::ShuffleDown {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            },
        };

        instructions.push(wgsl::Instruction::Subgroup(op));
    }
565
566 fn compile_branch(&mut self, instructions: &mut Vec<wgsl::Instruction>, branch: cube::Branch) {
567 match branch {
568 cube::Branch::If(mut op) => instructions.push(wgsl::Instruction::If {
569 cond: self.compile_variable(op.cond),
570 instructions: self.compile_scope(&mut op.scope),
571 }),
572 cube::Branch::IfElse(mut op) => instructions.push(wgsl::Instruction::IfElse {
573 cond: self.compile_variable(op.cond),
574 instructions_if: self.compile_scope(&mut op.scope_if),
575 instructions_else: self.compile_scope(&mut op.scope_else),
576 }),
577 cube::Branch::Switch(mut op) => instructions.push(wgsl::Instruction::Switch {
578 value: self.compile_variable(op.value),
579 instructions_default: self.compile_scope(&mut op.scope_default),
580 cases: op
581 .cases
582 .into_iter()
583 .map(|(val, mut scope)| {
584 (self.compile_variable(val), self.compile_scope(&mut scope))
585 })
586 .collect(),
587 }),
588 cube::Branch::Return => instructions.push(wgsl::Instruction::Return),
589 cube::Branch::Break => instructions.push(wgsl::Instruction::Break),
590 cube::Branch::RangeLoop(mut range_loop) => {
591 instructions.push(wgsl::Instruction::RangeLoop {
592 i: self.compile_variable(range_loop.i),
593 start: self.compile_variable(range_loop.start),
594 end: self.compile_variable(range_loop.end),
595 step: range_loop.step.map(|it| self.compile_variable(it)),
596 inclusive: range_loop.inclusive,
597 instructions: self.compile_scope(&mut range_loop.scope),
598 })
599 }
600 cube::Branch::Loop(mut op) => instructions.push(wgsl::Instruction::Loop {
601 instructions: self.compile_scope(&mut op.scope),
602 }),
603 };
604 }
605
606 fn compile_synchronization(
607 &mut self,
608 instructions: &mut Vec<wgsl::Instruction>,
609 synchronization: cube::Synchronization,
610 ) {
611 match synchronization {
612 cube::Synchronization::SyncCube => {
613 instructions.push(wgsl::Instruction::WorkgroupBarrier)
614 }
615 cube::Synchronization::SyncPlane => {
616 panic!("Synchronization within a plane is not supported in WGSL")
617 }
618 cube::Synchronization::SyncStorage => {
619 instructions.push(wgsl::Instruction::StorageBarrier)
620 }
621 cube::Synchronization::SyncProxyShared => panic!("TMA is not supported in WGSL"),
622 };
623 }
624
625 fn compile_comment(&mut self, instructions: &mut Vec<wgsl::Instruction>, content: String) {
626 instructions.push(wgsl::Instruction::Comment { content })
627 }
628
629 fn compile_metadata(
630 &mut self,
631 metadata: cube::Metadata,
632 out: Option<cube::Variable>,
633 ) -> wgsl::Instruction {
634 let out = out.unwrap();
635 match metadata {
636 cube::Metadata::Rank { var } => {
637 let position = self.ext_meta_pos(&var);
638 let offset = self.metadata.rank_index(position);
639 wgsl::Instruction::Metadata {
640 out: self.compile_variable(out),
641 info_offset: self.compile_variable(offset.into()),
642 }
643 }
644 cube::Metadata::Stride { dim, var } => {
645 let position = self.ext_meta_pos(&var);
646 let offset = self.metadata.stride_offset_index(position);
647 wgsl::Instruction::ExtendedMeta {
648 info_offset: self.compile_variable(offset.into()),
649 dim: self.compile_variable(dim),
650 out: self.compile_variable(out),
651 }
652 }
653 cube::Metadata::Shape { dim, var } => {
654 let position = self.ext_meta_pos(&var);
655 let offset = self.metadata.shape_offset_index(position);
656 wgsl::Instruction::ExtendedMeta {
657 info_offset: self.compile_variable(offset.into()),
658 dim: self.compile_variable(dim),
659 out: self.compile_variable(out),
660 }
661 }
662 cube::Metadata::Length { var } => match var.kind {
663 cube::VariableKind::GlobalInputArray(id) => {
664 let offset = self.metadata.len_index(id);
665 wgsl::Instruction::Metadata {
666 out: self.compile_variable(out),
667 info_offset: self.compile_variable(offset.into()),
668 }
669 }
670 cube::VariableKind::GlobalOutputArray(id) => {
671 let offset = self.metadata.len_index(id);
672 wgsl::Instruction::Metadata {
673 out: self.compile_variable(out),
674 info_offset: self.compile_variable(offset.into()),
675 }
676 }
677 _ => wgsl::Instruction::Length {
678 var: self.compile_variable(var),
679 out: self.compile_variable(out),
680 },
681 },
682 cube::Metadata::BufferLength { var } => match var.kind {
683 cube::VariableKind::GlobalInputArray(id) => {
684 let offset = self.metadata.buffer_len_index(id);
685 wgsl::Instruction::Metadata {
686 out: self.compile_variable(out),
687 info_offset: self.compile_variable(offset.into()),
688 }
689 }
690 cube::VariableKind::GlobalOutputArray(id) => {
691 let offset = self.metadata.buffer_len_index(id);
692 wgsl::Instruction::Metadata {
693 out: self.compile_variable(out),
694 info_offset: self.compile_variable(offset.into()),
695 }
696 }
697 _ => wgsl::Instruction::Length {
698 var: self.compile_variable(var),
699 out: self.compile_variable(out),
700 },
701 },
702 }
703 }
704
    /// Lowers one arithmetic operation and appends the resulting
    /// instruction(s) to `instructions`.
    ///
    /// Most operations map to a single instruction. `Erf` and `MulHi` are
    /// expanded into a child scope of `scope`, which is then compiled and
    /// appended in place.
    ///
    /// # Panics
    ///
    /// Panics if `out` is `None`, and on `SaturatingAdd` / `SaturatingSub`,
    /// which are expected to have been rewritten before this point.
    fn compile_arithmetic(
        &mut self,
        value: cube::Arithmetic,
        out: Option<cube::Variable>,
        instructions: &mut Vec<wgsl::Instruction>,
        scope: &mut Scope,
    ) {
        let out = out.unwrap();
        match value {
            cube::Arithmetic::Max(op) => instructions.push(wgsl::Instruction::Max {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Min(op) => instructions.push(wgsl::Instruction::Min {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Add(op) => instructions.push(wgsl::Instruction::Add {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            }),
            // `compile_scope` runs the saturating-arithmetic processor first.
            cube::Arithmetic::SaturatingAdd(_) => {
                unreachable!("Saturating add should be removed by processor");
            }
            cube::Arithmetic::Fma(op) => instructions.push(wgsl::Instruction::Fma {
                a: self.compile_variable(op.a),
                b: self.compile_variable(op.b),
                c: self.compile_variable(op.c),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Modulo(op) => instructions.push(wgsl::Instruction::Modulo {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Sub(op) => instructions.push(wgsl::Instruction::Sub {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::SaturatingSub(_) => {
                unreachable!("Saturating sub should be removed by processor");
            }
            cube::Arithmetic::Mul(op) => instructions.push(wgsl::Instruction::Mul {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Div(op) => instructions.push(wgsl::Instruction::Div {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Abs(op) => instructions.push(wgsl::Instruction::Abs {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Exp(op) => instructions.push(wgsl::Instruction::Exp {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Log(op) => instructions.push(wgsl::Instruction::Log {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Log1p(op) => instructions.push(wgsl::Instruction::Log1p {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Cos(op) => instructions.push(wgsl::Instruction::Cos {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Sin(op) => instructions.push(wgsl::Instruction::Sin {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Tanh(op) => instructions.push(wgsl::Instruction::Tanh {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            // `Powi` is lowered to the same instruction as `Powf`.
            // NOTE(review): presumably the `Powf` emitter copes with an
            // integer exponent — confirm.
            cube::Arithmetic::Powf(op) | cube::Arithmetic::Powi(op) => {
                instructions.push(wgsl::Instruction::Powf {
                    lhs: self.compile_variable(op.lhs),
                    rhs: self.compile_variable(op.rhs),
                    out: self.compile_variable(out),
                })
            }
            cube::Arithmetic::Sqrt(op) => instructions.push(wgsl::Instruction::Sqrt {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Round(op) => instructions.push(wgsl::Instruction::Round {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Floor(op) => instructions.push(wgsl::Instruction::Floor {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Ceil(op) => instructions.push(wgsl::Instruction::Ceil {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Trunc(op) => instructions.push(wgsl::Instruction::Trunc {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            // Expanded into IR in a child scope, then compiled like any scope.
            cube::Arithmetic::Erf(op) => {
                let mut scope = scope.child();
                expand_erf(&mut scope, op.input, out);
                instructions.extend(self.compile_scope(&mut scope));
            }
            // The expansion used depends on the `supports_u64` option.
            cube::Arithmetic::MulHi(op) => {
                let mut scope = scope.child();
                match self.compilation_options.supports_u64 {
                    true => expand_himul_64(&mut scope, op.lhs, op.rhs, out),
                    false => expand_himul_sim(&mut scope, op.lhs, op.rhs, out),
                }
                instructions.extend(self.compile_scope(&mut scope));
            }
            cube::Arithmetic::Recip(op) => instructions.push(wgsl::Instruction::Recip {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Clamp(op) => instructions.push(wgsl::Instruction::Clamp {
                input: self.compile_variable(op.input),
                min_value: self.compile_variable(op.min_value),
                max_value: self.compile_variable(op.max_value),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Remainder(op) => instructions.push(wgsl::Instruction::Remainder {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Neg(op) => instructions.push(wgsl::Instruction::Negate {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Magnitude(op) => instructions.push(wgsl::Instruction::Magnitude {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Normalize(op) => instructions.push(wgsl::Instruction::Normalize {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Dot(op) => instructions.push(wgsl::Instruction::Dot {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            }),
        }
    }
864
865 fn compile_cmp(
866 &mut self,
867 value: cube::Comparison,
868 out: Option<cube::Variable>,
869 instructions: &mut Vec<wgsl::Instruction>,
870 ) {
871 let out = out.unwrap();
872 match value {
873 cube::Comparison::Equal(op) => instructions.push(wgsl::Instruction::Equal {
874 lhs: self.compile_variable(op.lhs),
875 rhs: self.compile_variable(op.rhs),
876 out: self.compile_variable(out),
877 }),
878 cube::Comparison::Lower(op) => instructions.push(wgsl::Instruction::Lower {
879 lhs: self.compile_variable(op.lhs),
880 rhs: self.compile_variable(op.rhs),
881 out: self.compile_variable(out),
882 }),
883 cube::Comparison::Greater(op) => instructions.push(wgsl::Instruction::Greater {
884 lhs: self.compile_variable(op.lhs),
885 rhs: self.compile_variable(op.rhs),
886 out: self.compile_variable(out),
887 }),
888 cube::Comparison::LowerEqual(op) => instructions.push(wgsl::Instruction::LowerEqual {
889 lhs: self.compile_variable(op.lhs),
890 rhs: self.compile_variable(op.rhs),
891 out: self.compile_variable(out),
892 }),
893 cube::Comparison::GreaterEqual(op) => {
894 instructions.push(wgsl::Instruction::GreaterEqual {
895 lhs: self.compile_variable(op.lhs),
896 rhs: self.compile_variable(op.rhs),
897 out: self.compile_variable(out),
898 })
899 }
900 cube::Comparison::NotEqual(op) => instructions.push(wgsl::Instruction::NotEqual {
901 lhs: self.compile_variable(op.lhs),
902 rhs: self.compile_variable(op.rhs),
903 out: self.compile_variable(out),
904 }),
905 cube::Comparison::IsNan(op) => instructions.push(wgsl::Instruction::IsNan {
906 input: self.compile_variable(op.input),
907 out: self.compile_variable(out),
908 }),
909 cube::Comparison::IsInf(op) => instructions.push(wgsl::Instruction::IsInf {
910 input: self.compile_variable(op.input),
911 out: self.compile_variable(out),
912 }),
913 }
914 }
915
916 fn compile_bitwise(
917 &mut self,
918 value: cube::Bitwise,
919 out: Option<cube::Variable>,
920 instructions: &mut Vec<wgsl::Instruction>,
921 ) {
922 let out = out.unwrap();
923 match value {
924 cube::Bitwise::BitwiseOr(op) => instructions.push(wgsl::Instruction::BitwiseOr {
925 lhs: self.compile_variable(op.lhs),
926 rhs: self.compile_variable(op.rhs),
927 out: self.compile_variable(out),
928 }),
929 cube::Bitwise::BitwiseAnd(op) => instructions.push(wgsl::Instruction::BitwiseAnd {
930 lhs: self.compile_variable(op.lhs),
931 rhs: self.compile_variable(op.rhs),
932 out: self.compile_variable(out),
933 }),
934 cube::Bitwise::BitwiseXor(op) => instructions.push(wgsl::Instruction::BitwiseXor {
935 lhs: self.compile_variable(op.lhs),
936 rhs: self.compile_variable(op.rhs),
937 out: self.compile_variable(out),
938 }),
939 cube::Bitwise::CountOnes(op) => instructions.push(wgsl::Instruction::CountBits {
940 input: self.compile_variable(op.input),
941 out: self.compile_variable(out),
942 }),
943 cube::Bitwise::ReverseBits(op) => instructions.push(wgsl::Instruction::ReverseBits {
944 input: self.compile_variable(op.input),
945 out: self.compile_variable(out),
946 }),
947 cube::Bitwise::ShiftLeft(op) => instructions.push(wgsl::Instruction::ShiftLeft {
948 lhs: self.compile_variable(op.lhs),
949 rhs: self.compile_variable(op.rhs),
950 out: self.compile_variable(out),
951 }),
952 cube::Bitwise::ShiftRight(op) => instructions.push(wgsl::Instruction::ShiftRight {
953 lhs: self.compile_variable(op.lhs),
954 rhs: self.compile_variable(op.rhs),
955 out: self.compile_variable(out),
956 }),
957 cube::Bitwise::BitwiseNot(op) => instructions.push(wgsl::Instruction::BitwiseNot {
958 input: self.compile_variable(op.input),
959 out: self.compile_variable(out),
960 }),
961 cube::Bitwise::LeadingZeros(op) => instructions.push(wgsl::Instruction::LeadingZeros {
962 input: self.compile_variable(op.input),
963 out: self.compile_variable(out),
964 }),
965 cube::Bitwise::FindFirstSet(op) => instructions.push(wgsl::Instruction::FindFirstSet {
966 input: self.compile_variable(op.input),
967 out: self.compile_variable(out),
968 }),
969 }
970 }
971
972 fn compile_operator(
973 &mut self,
974 value: cube::Operator,
975 out: Option<cube::Variable>,
976 instructions: &mut Vec<wgsl::Instruction>,
977 ) {
978 let out = out.unwrap();
979 match value {
980 cube::Operator::Cast(op) => instructions.push(wgsl::Instruction::Assign {
981 input: self.compile_variable(op.input),
982 out: self.compile_variable(out),
983 }),
984 cube::Operator::Index(op) | cube::Operator::UncheckedIndex(op) => {
985 instructions.push(wgsl::Instruction::Index {
986 lhs: self.compile_variable(op.list),
987 rhs: self.compile_variable(op.index),
988 out: self.compile_variable(out),
989 });
990 }
991 cube::Operator::IndexAssign(op) | cube::Operator::UncheckedIndexAssign(op) => {
992 instructions.push(wgsl::Instruction::IndexAssign {
993 index: self.compile_variable(op.index),
994 rhs: self.compile_variable(op.value),
995 out: self.compile_variable(out),
996 })
997 }
998 cube::Operator::And(op) => instructions.push(wgsl::Instruction::And {
999 lhs: self.compile_variable(op.lhs),
1000 rhs: self.compile_variable(op.rhs),
1001 out: self.compile_variable(out),
1002 }),
1003 cube::Operator::Or(op) => instructions.push(wgsl::Instruction::Or {
1004 lhs: self.compile_variable(op.lhs),
1005 rhs: self.compile_variable(op.rhs),
1006 out: self.compile_variable(out),
1007 }),
1008 cube::Operator::Not(op) => instructions.push(wgsl::Instruction::Not {
1009 input: self.compile_variable(op.input),
1010 out: self.compile_variable(out),
1011 }),
1012 cube::Operator::Reinterpret(op) => instructions.push(wgsl::Instruction::Bitcast {
1013 input: self.compile_variable(op.input),
1014 out: self.compile_variable(out),
1015 }),
1016 cube::Operator::InitLine(op) => instructions.push(wgsl::Instruction::VecInit {
1017 inputs: op
1018 .inputs
1019 .into_iter()
1020 .map(|var| self.compile_variable(var))
1021 .collect(),
1022 out: self.compile_variable(out),
1023 }),
1024 cube::Operator::CopyMemory(op) => instructions.push(wgsl::Instruction::Copy {
1025 input: self.compile_variable(op.input),
1026 in_index: self.compile_variable(op.in_index),
1027 out: self.compile_variable(out),
1028 out_index: self.compile_variable(op.out_index),
1029 }),
1030 cube::Operator::CopyMemoryBulk(op) => instructions.push(wgsl::Instruction::CopyBulk {
1031 input: self.compile_variable(op.input),
1032 in_index: self.compile_variable(op.in_index),
1033 out: self.compile_variable(out),
1034 out_index: self.compile_variable(op.out_index),
1035 len: op.len,
1036 }),
1037 cube::Operator::Select(op) => instructions.push(wgsl::Instruction::Select {
1038 cond: self.compile_variable(op.cond),
1039 then: self.compile_variable(op.then),
1040 or_else: self.compile_variable(op.or_else),
1041 out: self.compile_variable(out),
1042 }),
1043 }
1044 }
1045
1046 fn compile_atomic(
1047 &mut self,
1048 atomic: cube::AtomicOp,
1049 out: Option<cube::Variable>,
1050 ) -> wgsl::Instruction {
1051 let out = out.unwrap();
1052 match atomic {
1053 cube::AtomicOp::Add(op) => wgsl::Instruction::AtomicAdd {
1054 lhs: self.compile_variable(op.lhs),
1055 rhs: self.compile_variable(op.rhs),
1056 out: self.compile_variable(out),
1057 },
1058 cube::AtomicOp::Sub(op) => wgsl::Instruction::AtomicSub {
1059 lhs: self.compile_variable(op.lhs),
1060 rhs: self.compile_variable(op.rhs),
1061 out: self.compile_variable(out),
1062 },
1063 cube::AtomicOp::Max(op) => wgsl::Instruction::AtomicMax {
1064 lhs: self.compile_variable(op.lhs),
1065 rhs: self.compile_variable(op.rhs),
1066 out: self.compile_variable(out),
1067 },
1068 cube::AtomicOp::Min(op) => wgsl::Instruction::AtomicMin {
1069 lhs: self.compile_variable(op.lhs),
1070 rhs: self.compile_variable(op.rhs),
1071 out: self.compile_variable(out),
1072 },
1073 cube::AtomicOp::And(op) => wgsl::Instruction::AtomicAnd {
1074 lhs: self.compile_variable(op.lhs),
1075 rhs: self.compile_variable(op.rhs),
1076 out: self.compile_variable(out),
1077 },
1078 cube::AtomicOp::Or(op) => wgsl::Instruction::AtomicOr {
1079 lhs: self.compile_variable(op.lhs),
1080 rhs: self.compile_variable(op.rhs),
1081 out: self.compile_variable(out),
1082 },
1083 cube::AtomicOp::Xor(op) => wgsl::Instruction::AtomicXor {
1084 lhs: self.compile_variable(op.lhs),
1085 rhs: self.compile_variable(op.rhs),
1086 out: self.compile_variable(out),
1087 },
1088 cube::AtomicOp::Load(op) => wgsl::Instruction::AtomicLoad {
1089 input: self.compile_variable(op.input),
1090 out: self.compile_variable(out),
1091 },
1092 cube::AtomicOp::Store(op) => wgsl::Instruction::AtomicStore {
1093 input: self.compile_variable(op.input),
1094 out: self.compile_variable(out),
1095 },
1096 cube::AtomicOp::Swap(op) => wgsl::Instruction::AtomicSwap {
1097 lhs: self.compile_variable(op.lhs),
1098 rhs: self.compile_variable(op.rhs),
1099 out: self.compile_variable(out),
1100 },
1101 cube::AtomicOp::CompareAndSwap(op) => wgsl::Instruction::AtomicCompareExchangeWeak {
1102 lhs: self.compile_variable(op.input),
1103 cmp: self.compile_variable(op.cmp),
1104 value: self.compile_variable(op.val),
1105 out: self.compile_variable(out),
1106 },
1107 }
1108 }
1109
1110 fn compile_location(value: compute::Location) -> wgsl::Location {
1111 match value {
1112 compute::Location::Storage => wgsl::Location::Storage,
1113 compute::Location::Cube => wgsl::Location::Workgroup,
1114 }
1115 }
1116
1117 fn compile_binding(&mut self, value: compute::Binding) -> wgsl::Binding {
1118 wgsl::Binding {
1119 id: value.id,
1120 visibility: value.visibility,
1121 location: Self::compile_location(value.location),
1122 item: self.compile_type(value.ty),
1123 size: value.size,
1124 }
1125 }
1126}
1127
1128fn register_extensions(instructions: &[wgsl::Instruction]) -> Vec<wgsl::Extension> {
1129 let mut extensions = Vec::new();
1130
1131 let mut register_extension = |extension: wgsl::Extension| {
1132 if !extensions.contains(&extension) {
1133 extensions.push(extension);
1134 }
1135 };
1136
1137 for instruction in instructions {
1139 match instruction {
1140 wgsl::Instruction::Powf { lhs: _, rhs, out } => {
1141 register_extension(wgsl::Extension::PowfPrimitive(out.elem()));
1142 register_extension(wgsl::powf_extension(rhs, out));
1143 }
1144 #[cfg(target_os = "macos")]
1145 wgsl::Instruction::Tanh { input, out: _ } => {
1146 register_extension(wgsl::Extension::SafeTanhPrimitive(input.elem()));
1147 register_extension(wgsl::Extension::SafeTanh(input.item()));
1148 }
1149 wgsl::Instruction::IsNan { input, out } => {
1150 register_extension(wgsl::Extension::IsNanPrimitive(input.elem()));
1151 register_extension(wgsl::Extension::IsNan(input.item(), out.item()));
1152 }
1153 wgsl::Instruction::IsInf { input, out } => {
1154 register_extension(wgsl::Extension::IsInfPrimitive(input.elem()));
1155 register_extension(wgsl::Extension::IsInf(input.item(), out.item()));
1156 }
1157 wgsl::Instruction::If { instructions, .. } => {
1158 for extension in register_extensions(instructions) {
1159 register_extension(extension);
1160 }
1161 }
1162 wgsl::Instruction::IfElse {
1163 instructions_if,
1164 instructions_else,
1165 ..
1166 } => {
1167 for extension in register_extensions(instructions_if) {
1168 register_extension(extension);
1169 }
1170 for extension in register_extensions(instructions_else) {
1171 register_extension(extension);
1172 }
1173 }
1174 wgsl::Instruction::Loop { instructions } => {
1175 for extension in register_extensions(instructions) {
1176 register_extension(extension);
1177 }
1178 }
1179 wgsl::Instruction::RangeLoop { instructions, .. } => {
1180 for extension in register_extensions(instructions) {
1181 register_extension(extension);
1182 }
1183 }
1184 _ => {}
1185 }
1186 }
1187
1188 extensions
1189}