1use super::Subgroup;
2use super::{ConstantArray, shader::ComputeShader};
3use super::{Item, LocalArray, SharedArray};
4use crate::compiler::wgsl::{self, SharedValue};
5
6use cubecl_common::backtrace::BackTrace;
7use cubecl_core::post_processing::{
8 checked_io::CheckedIoProcessor, saturating::SaturatingArithmeticProcessor,
9};
10use cubecl_core::prelude::*;
11use cubecl_core::{
12 Metadata, WgpuCompilationOptions,
13 ir::{self as cube, Scope},
14 prelude::expand_erf,
15};
16use cubecl_core::{
17 ir::{Processor, UIntKind},
18 post_processing::unroll::UnrollProcessor,
19};
20use cubecl_runtime::compiler::CompilationError;
21use cubecl_runtime::kernel;
22
/// Widest line (vector) size emitted in WGSL.
///
/// Buffers and declared variables with a wider line are narrowed to this width, and the
/// unroll processor is configured with it; `compile_type` only maps line sizes 2, 3 and 4.
pub const MAX_LINE_SIZE: usize = 4;
24
25#[derive(Clone, Default)]
27pub struct WgslCompiler {
28 metadata: Metadata,
29 ext_meta_pos: Vec<u32>,
30 local_invocation_index: bool,
31 local_invocation_id: bool,
32 global_invocation_id: bool,
34 workgroup_id: bool,
35 subgroup_size: bool,
36 subgroup_id: bool,
37 subgroup_invocation_id: bool,
38 id: bool,
39 num_workgroups: bool,
40 workgroup_id_no_axis: bool,
41 workgroup_size_no_axis: bool,
42 num_workgroup_no_axis: bool,
43 shared_arrays: Vec<SharedArray>,
44 shared_values: Vec<SharedValue>,
45 const_arrays: Vec<ConstantArray>,
46 local_arrays: Vec<LocalArray>,
47 #[allow(dead_code)]
48 compilation_options: WgpuCompilationOptions,
49 strategy: ExecutionMode,
50 subgroup_instructions_used: bool,
51 f16_used: bool,
52}
53
54impl core::fmt::Debug for WgslCompiler {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 f.write_str("WgslCompiler")
57 }
58}
59
60impl cubecl_core::Compiler for WgslCompiler {
61 type Representation = ComputeShader;
62 type CompilationOptions = WgpuCompilationOptions;
63
64 fn compile(
65 &mut self,
66 shader: kernel::KernelDefinition,
67 compilation_options: &Self::CompilationOptions,
68 mode: ExecutionMode,
69 address_type: StorageType,
70 ) -> Result<Self::Representation, CompilationError> {
71 self.compilation_options = compilation_options.clone();
72 self.compile_shader(shader, mode, address_type)
73 }
74
75 fn elem_size(&self, elem: cube::ElemType) -> usize {
76 elem.size()
77 }
78
79 fn extension(&self) -> &'static str {
80 "wgsl"
81 }
82}
83
84impl WgslCompiler {
85 fn compile_shader(
86 &mut self,
87 mut value: kernel::KernelDefinition,
88 mode: ExecutionMode,
89 address_type: StorageType,
90 ) -> Result<wgsl::ComputeShader, CompilationError> {
91 let errors = value.body.pop_errors();
92 if !errors.is_empty() {
93 let mut reason = "Can't compile wgsl kernel".to_string();
94 for error in errors {
95 reason += error.as_str();
96 reason += "\n";
97 }
98
99 return Err(CompilationError::Validation {
100 reason,
101 backtrace: BackTrace::capture(),
102 });
103 }
104
105 self.strategy = mode;
106
107 let num_meta = value.buffers.len();
108
109 self.ext_meta_pos = Vec::new();
110 let mut num_ext = 0;
111
112 for binding in value.buffers.iter() {
113 self.ext_meta_pos.push(num_ext);
114 if binding.has_extended_meta {
115 num_ext += 1;
116 }
117 }
118
119 self.metadata = Metadata::new(num_meta as u32, num_ext);
120
121 let address_type = self.compile_storage_type(address_type);
122 let instructions = self.compile_scope(&mut value.body);
123 let extensions = register_extensions(&instructions);
124 let body = wgsl::Body {
125 instructions,
126 id: self.id,
127 address_type,
128 };
129
130 Ok(wgsl::ComputeShader {
131 address_type,
132 buffers: value
133 .buffers
134 .into_iter()
135 .map(|mut it| {
136 if it.ty.line_size() > MAX_LINE_SIZE {
139 it.ty = it.ty.line(MAX_LINE_SIZE);
140 }
141 self.compile_binding(it)
142 })
143 .collect(),
144 scalars: value
145 .scalars
146 .into_iter()
147 .map(|binding| (self.compile_storage_type(binding.ty), binding.count))
148 .collect(),
149 shared_arrays: self.shared_arrays.clone(),
150 shared_values: self.shared_values.clone(),
151 constant_arrays: self.const_arrays.clone(),
152 local_arrays: self.local_arrays.clone(),
153 has_metadata: self.metadata.static_len() > 0,
154 workgroup_size: value.cube_dim,
155 global_invocation_id: self.global_invocation_id || self.id,
156 local_invocation_index: self.local_invocation_index,
157 local_invocation_id: self.local_invocation_id,
158 num_workgroups: self.id
159 || self.num_workgroups
160 || self.num_workgroup_no_axis
161 || self.workgroup_id_no_axis,
162 workgroup_id: self.workgroup_id || self.workgroup_id_no_axis,
163 subgroup_size: self.subgroup_size,
164 subgroup_id: self.subgroup_id,
165 subgroup_invocation_id: self.subgroup_invocation_id,
166 body,
167 extensions,
168 num_workgroups_no_axis: self.num_workgroup_no_axis,
169 workgroup_id_no_axis: self.workgroup_id_no_axis,
170 workgroup_size_no_axis: self.workgroup_size_no_axis,
171 subgroup_instructions_used: self.subgroup_instructions_used,
172 f16_used: self.f16_used,
173 kernel_name: value.options.kernel_name,
174 })
175 }
176
177 fn compile_type(&mut self, item: cube::Type) -> Item {
178 match item {
179 cube::Type::Scalar(ty) => wgsl::Item::Scalar(self.compile_storage_type(ty)),
180 cube::Type::Line(ty, size) => {
181 let elem = self.compile_storage_type(ty);
182 match size {
183 2 => wgsl::Item::Vec2(elem),
184 3 => wgsl::Item::Vec3(elem),
185 4 => wgsl::Item::Vec4(elem),
186 _ => panic!("Unsupported vectorizations scheme {:?}", item.line_size()),
187 }
188 }
189 cube::Type::Semantic(_) => unimplemented!("Can't compile semantic type"),
190 }
191 }
192
193 fn compile_storage_type(&mut self, ty: cube::StorageType) -> wgsl::Elem {
194 match ty {
195 cube::StorageType::Scalar(ty) => self.compile_elem(ty),
196 cube::StorageType::Atomic(ty) => match ty {
197 cube::ElemType::Float(i) => match i {
198 cube::FloatKind::F32 => wgsl::Elem::AtomicF32,
199 kind => panic!("atomic<{kind:?}> is not a valid WgpuElement"),
200 },
201 cube::ElemType::Int(i) => match i {
202 cube::IntKind::I32 => wgsl::Elem::AtomicI32,
203 kind => panic!("atomic<{kind:?}> is not a valid WgpuElement"),
204 },
205 cube::ElemType::UInt(kind) => match kind {
206 cube::UIntKind::U32 => wgsl::Elem::AtomicU32,
207 kind => panic!("{kind:?} is not a valid WgpuElement"),
208 },
209 other => panic!("{other:?} is not a valid WgpuElement"),
210 },
211 cube::StorageType::Packed(_, _) => {
212 unimplemented!("Packed types not yet supported in WGSL")
213 }
214 cube::StorageType::Opaque(ty) => match ty {
215 cube::OpaqueType::Barrier(_) => {
216 unimplemented!("Barrier objects not supported in WGSL")
217 }
218 },
219 }
220 }
221
222 fn compile_elem(&mut self, value: cube::ElemType) -> wgsl::Elem {
223 match value {
224 cube::ElemType::Float(f) => match f {
225 cube::FloatKind::E2M1
226 | cube::FloatKind::E2M3
227 | cube::FloatKind::E3M2
228 | cube::FloatKind::E4M3
229 | cube::FloatKind::E5M2
230 | cube::FloatKind::UE8M0 => panic!("Minifloat is not a valid WgpuElement"),
231 cube::FloatKind::F16 => {
232 self.f16_used = true;
233 wgsl::Elem::F16
234 }
235 cube::FloatKind::BF16 => panic!("bf16 is not a valid WgpuElement"),
236 cube::FloatKind::TF32 => panic!("tf32 is not a valid WgpuElement"),
237 cube::FloatKind::Flex32 => wgsl::Elem::F32,
238 cube::FloatKind::F32 => wgsl::Elem::F32,
239 cube::FloatKind::F64 => wgsl::Elem::F64,
240 },
241 cube::ElemType::Int(i) => match i {
242 cube::IntKind::I32 => wgsl::Elem::I32,
243 cube::IntKind::I64 => wgsl::Elem::I64,
244 kind => panic!("{kind:?} is not a valid WgpuElement"),
245 },
246 cube::ElemType::UInt(kind) => match kind {
247 cube::UIntKind::U32 => wgsl::Elem::U32,
248 cube::UIntKind::U64 => wgsl::Elem::U64,
249 kind => panic!("{kind:?} is not a valid WgpuElement"),
250 },
251 cube::ElemType::Bool => wgsl::Elem::Bool,
252 }
253 }
254
255 fn ext_meta_pos(&self, var: &cube::Variable) -> u32 {
256 let pos = var.index().expect("Variable should have index");
257 self.ext_meta_pos[pos as usize]
258 }
259
    /// Lowers a CubeCL IR variable into its WGSL counterpart.
    ///
    /// Besides converting, this records what the shader must declare:
    /// - Shared arrays, shared values and local arrays are registered once per id.
    /// - Referencing a builtin sets the matching usage flag, which `compile_shader` reads
    ///   when building the entry point.
    ///
    /// # Panics
    /// On matrix, pipeline, barrier-token and tensor-map variables.
    pub(crate) fn compile_variable(&mut self, value: cube::Variable) -> wgsl::Variable {
        let item = value.ty;
        match value.kind {
            cube::VariableKind::GlobalInputArray(id) => {
                wgsl::Variable::GlobalInputArray(id, self.compile_type(item))
            }
            cube::VariableKind::GlobalScalar(id) => {
                wgsl::Variable::GlobalScalar(id, self.compile_storage_type(item.storage_type()))
            }
            // Versioned variables are emitted as plain mutable locals.
            cube::VariableKind::LocalMut { id } | cube::VariableKind::Versioned { id, .. } => {
                wgsl::Variable::LocalMut {
                    id,
                    item: self.compile_type(item),
                }
            }
            cube::VariableKind::LocalConst { id } => wgsl::Variable::LocalConst {
                id,
                item: self.compile_type(item),
            },
            cube::VariableKind::GlobalOutputArray(id) => {
                wgsl::Variable::GlobalOutputArray(id, self.compile_type(item))
            }
            cube::VariableKind::Constant(value) => {
                wgsl::Variable::Constant(value, self.compile_type(item))
            }
            cube::VariableKind::SharedArray {
                id,
                length,
                unroll_factor,
                alignment,
            } => {
                let item = self.compile_type(item);
                // Declare each shared array once. The declaration reserves
                // `length * unroll_factor` elements while the returned variable keeps
                // `length`; presumably this accounts for unrolled line widths (TODO confirm).
                if !self.shared_arrays.iter().any(|s| s.index == id) {
                    self.shared_arrays.push(SharedArray::new(
                        id,
                        item,
                        (length * unroll_factor) as u32,
                        alignment.map(|it| it as u32),
                    ));
                }
                wgsl::Variable::SharedArray(id, item, length as u32)
            }
            cube::VariableKind::Shared { id } => {
                let item = self.compile_type(item);
                // Declare each shared value once.
                if !self.shared_values.iter().any(|s| s.index == id) {
                    self.shared_values.push(SharedValue::new(id, item));
                }
                wgsl::Variable::SharedValue(id, item)
            }
            // Constant arrays are declared from the scope in `compile_scope`, not here.
            cube::VariableKind::ConstantArray { id, length, .. } => {
                let item = self.compile_type(item);
                wgsl::Variable::ConstantArray(id, item, length as u32)
            }
            cube::VariableKind::LocalArray {
                id,
                length,
                unroll_factor,
            } => {
                let item = self.compile_type(item);
                // Same once-per-id registration and sizing as shared arrays.
                if !self.local_arrays.iter().any(|s| s.index == id) {
                    self.local_arrays.push(LocalArray::new(
                        id,
                        item,
                        (length * unroll_factor) as u32,
                    ));
                }
                wgsl::Variable::LocalArray(id, item, length as u32)
            }
            // Each builtin sets the flag that makes `compile_shader` expose it.
            cube::VariableKind::Builtin(builtin) => match builtin {
                cube::Builtin::AbsolutePos => {
                    self.id = true;
                    wgsl::Variable::Id
                }
                cube::Builtin::UnitPos => {
                    self.local_invocation_index = true;
                    wgsl::Variable::LocalInvocationIndex
                }
                cube::Builtin::UnitPosX => {
                    self.local_invocation_id = true;
                    wgsl::Variable::LocalInvocationIdX
                }
                cube::Builtin::UnitPosY => {
                    self.local_invocation_id = true;
                    wgsl::Variable::LocalInvocationIdY
                }
                cube::Builtin::UnitPosZ => {
                    self.local_invocation_id = true;
                    wgsl::Variable::LocalInvocationIdZ
                }
                cube::Builtin::CubePosX => {
                    self.workgroup_id = true;
                    wgsl::Variable::WorkgroupIdX
                }
                cube::Builtin::CubePosY => {
                    self.workgroup_id = true;
                    wgsl::Variable::WorkgroupIdY
                }
                cube::Builtin::CubePosZ => {
                    self.workgroup_id = true;
                    wgsl::Variable::WorkgroupIdZ
                }
                // Cluster builtins lower to the constant 1, i.e. as if each cube were its own
                // single-cube cluster.
                cube::Builtin::CubePosCluster
                | cube::Builtin::CubePosClusterX
                | cube::Builtin::CubePosClusterY
                | cube::Builtin::CubePosClusterZ => self.constant_var(1),
                cube::Builtin::AbsolutePosX => {
                    self.global_invocation_id = true;
                    wgsl::Variable::GlobalInvocationIdX
                }
                cube::Builtin::AbsolutePosY => {
                    self.global_invocation_id = true;
                    wgsl::Variable::GlobalInvocationIdY
                }
                cube::Builtin::AbsolutePosZ => {
                    self.global_invocation_id = true;
                    wgsl::Variable::GlobalInvocationIdZ
                }
                // No usage flag is recorded for the per-axis workgroup sizes.
                cube::Builtin::CubeDimX => wgsl::Variable::WorkgroupSizeX,
                cube::Builtin::CubeDimY => wgsl::Variable::WorkgroupSizeY,
                cube::Builtin::CubeDimZ => wgsl::Variable::WorkgroupSizeZ,
                cube::Builtin::CubeClusterDim
                | cube::Builtin::CubeClusterDimX
                | cube::Builtin::CubeClusterDimY
                | cube::Builtin::CubeClusterDimZ => self.constant_var(1),
                cube::Builtin::CubeCountX => {
                    self.num_workgroups = true;
                    wgsl::Variable::NumWorkgroupsX
                }
                cube::Builtin::CubeCountY => {
                    self.num_workgroups = true;
                    wgsl::Variable::NumWorkgroupsY
                }
                cube::Builtin::CubeCountZ => {
                    self.num_workgroups = true;
                    wgsl::Variable::NumWorkgroupsZ
                }
                cube::Builtin::CubePos => {
                    self.workgroup_id_no_axis = true;
                    wgsl::Variable::WorkgroupId
                }
                cube::Builtin::CubeDim => {
                    self.workgroup_size_no_axis = true;
                    wgsl::Variable::WorkgroupSize
                }
                cube::Builtin::CubeCount => {
                    self.num_workgroup_no_axis = true;
                    wgsl::Variable::NumWorkgroups
                }
                cube::Builtin::PlaneDim => {
                    self.subgroup_size = true;
                    wgsl::Variable::SubgroupSize
                }
                cube::Builtin::PlanePos => {
                    self.subgroup_id = true;
                    wgsl::Variable::SubgroupId
                }
                cube::Builtin::UnitPosPlane => {
                    self.subgroup_invocation_id = true;
                    wgsl::Variable::SubgroupInvocationId
                }
            },
            // Unsupported variable kinds on this backend.
            cube::VariableKind::Matrix { .. } => {
                panic!("Cooperative matrix-multiply and accumulate not supported.")
            }
            cube::VariableKind::Pipeline { .. } => {
                panic!("Pipeline not supported.")
            }
            cube::VariableKind::BarrierToken { .. } => {
                panic!("Barrier not supported.")
            }
            cube::VariableKind::TensorMapInput(_) => panic!("Tensor map not supported."),
            cube::VariableKind::TensorMapOutput(_) => panic!("Tensor map not supported."),
        }
    }
434
435 fn constant_var(&mut self, value: u32) -> wgsl::Variable {
436 let var = cube::Variable::constant(value.into(), UIntKind::U32);
437 self.compile_variable(var)
438 }
439
440 fn compile_scope(&mut self, scope: &mut cube::Scope) -> Vec<wgsl::Instruction> {
441 let mut instructions = Vec::new();
442
443 let const_arrays = scope
444 .const_arrays
445 .drain(..)
446 .map(|(var, values)| ConstantArray {
447 index: var.index().unwrap(),
448 item: self.compile_type(var.ty),
449 size: values.len() as u32,
450 values: values
451 .into_iter()
452 .map(|val| self.compile_variable(val))
453 .collect(),
454 })
455 .collect::<Vec<_>>();
456 self.const_arrays.extend(const_arrays);
457
458 let checked_io: Box<dyn Processor> = Box::new(CheckedIoProcessor::new(self.strategy));
459 let unroll = Box::new(UnrollProcessor::new(MAX_LINE_SIZE));
460 let saturating = Box::new(SaturatingArithmeticProcessor::new(true));
461 let processing = scope.process([&*unroll, &*checked_io, &*saturating]);
462
463 for mut var in processing.variables {
464 if var.ty.line_size() > MAX_LINE_SIZE {
465 var.ty = var.ty.line(MAX_LINE_SIZE);
466 }
467 instructions.push(wgsl::Instruction::DeclareVariable {
468 var: self.compile_variable(var),
469 });
470 }
471
472 processing
473 .instructions
474 .into_iter()
475 .for_each(|op| self.compile_operation(&mut instructions, op.operation, op.out, scope));
476
477 instructions
478 }
479
480 fn compile_operation(
481 &mut self,
482 instructions: &mut Vec<wgsl::Instruction>,
483 operation: cube::Operation,
484 out: Option<cube::Variable>,
485 scope: &mut cube::Scope,
486 ) {
487 match operation {
488 cube::Operation::Copy(variable) => instructions.push(wgsl::Instruction::Assign {
489 input: self.compile_variable(variable),
490 out: self.compile_variable(out.unwrap()),
491 }),
492 cube::Operation::Arithmetic(op) => {
493 self.compile_arithmetic(op, out, instructions, scope)
494 }
495 cube::Operation::Comparison(op) => self.compile_cmp(op, out, instructions),
496 cube::Operation::Bitwise(op) => self.compile_bitwise(op, out, instructions),
497 cube::Operation::Operator(op) => self.compile_operator(op, out, instructions),
498 cube::Operation::Atomic(op) => instructions.push(self.compile_atomic(op, out)),
499 cube::Operation::Metadata(op) => instructions.push(self.compile_metadata(op, out)),
500 cube::Operation::Branch(val) => self.compile_branch(instructions, val),
501 cube::Operation::Synchronization(val) => {
502 self.compile_synchronization(instructions, val)
503 }
504 cube::Operation::Plane(op) => self.compile_subgroup(instructions, op, out),
505 cube::Operation::CoopMma(_) => {
506 panic!("Cooperative matrix-multiply and accumulate isn't supported on wgpu.")
507 }
508 cube::Operation::NonSemantic(cube::NonSemantic::Comment { content }) => {
509 self.compile_comment(instructions, content)
510 }
511 cube::Operation::NonSemantic(_) => {}
512 cube::Operation::Barrier(_) => {
513 panic!("Barrier isn't supported on wgpu.")
514 }
515 cube::Operation::Tma(_) => panic!("TMA isn't supported on wgpu."),
516 cube::Operation::Marker(_) => {}
517 }
518 }
519
    /// Lowers a plane (subgroup) operation into a single `Subgroup` instruction.
    ///
    /// It also marks subgroup instructions as used; `compile_shader` reports this flag on
    /// the shader.
    ///
    /// # Panics
    /// If `out` is `None`; every plane operation here produces a value.
    fn compile_subgroup(
        &mut self,
        instructions: &mut Vec<wgsl::Instruction>,
        subgroup: cube::Plane,
        out: Option<cube::Variable>,
    ) {
        self.subgroup_instructions_used = true;

        let out = out.unwrap();
        // One-to-one mapping from plane operations to WGSL subgroup operations; operands
        // are compiled before the output in every arm.
        let op = match subgroup {
            cube::Plane::Elect => Subgroup::Elect {
                out: self.compile_variable(out),
            },
            cube::Plane::All(op) => Subgroup::All {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            },
            cube::Plane::Any(op) => Subgroup::Any {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            },
            cube::Plane::Ballot(op) => Subgroup::Ballot {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            },

            cube::Plane::Broadcast(op) => Subgroup::Broadcast {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            },

            cube::Plane::Sum(op) => Subgroup::Sum {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            },

            cube::Plane::ExclusiveSum(op) => Subgroup::ExclusiveSum {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            },
            cube::Plane::InclusiveSum(op) => Subgroup::InclusiveSum {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            },
            cube::Plane::Prod(op) => Subgroup::Prod {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            },
            cube::Plane::ExclusiveProd(op) => Subgroup::ExclusiveProd {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            },
            cube::Plane::InclusiveProd(op) => Subgroup::InclusiveProd {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            },
            cube::Plane::Min(op) => Subgroup::Min {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            },
            cube::Plane::Max(op) => Subgroup::Max {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            },
            cube::Plane::Shuffle(op) => Subgroup::Shuffle {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            },
            cube::Plane::ShuffleXor(op) => Subgroup::ShuffleXor {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            },
            cube::Plane::ShuffleUp(op) => Subgroup::ShuffleUp {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            },
            cube::Plane::ShuffleDown(op) => Subgroup::ShuffleDown {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            },
        };

        instructions.push(wgsl::Instruction::Subgroup(op));
    }
609
610 fn compile_branch(&mut self, instructions: &mut Vec<wgsl::Instruction>, branch: cube::Branch) {
611 match branch {
612 cube::Branch::If(mut op) => instructions.push(wgsl::Instruction::If {
613 cond: self.compile_variable(op.cond),
614 instructions: self.compile_scope(&mut op.scope),
615 }),
616 cube::Branch::IfElse(mut op) => instructions.push(wgsl::Instruction::IfElse {
617 cond: self.compile_variable(op.cond),
618 instructions_if: self.compile_scope(&mut op.scope_if),
619 instructions_else: self.compile_scope(&mut op.scope_else),
620 }),
621 cube::Branch::Switch(mut op) => instructions.push(wgsl::Instruction::Switch {
622 value: self.compile_variable(op.value),
623 instructions_default: self.compile_scope(&mut op.scope_default),
624 cases: op
625 .cases
626 .into_iter()
627 .map(|(val, mut scope)| {
628 (self.compile_variable(val), self.compile_scope(&mut scope))
629 })
630 .collect(),
631 }),
632 cube::Branch::Return => instructions.push(wgsl::Instruction::Return),
633 cube::Branch::Break => instructions.push(wgsl::Instruction::Break),
634 cube::Branch::RangeLoop(mut range_loop) => {
635 instructions.push(wgsl::Instruction::RangeLoop {
636 i: self.compile_variable(range_loop.i),
637 start: self.compile_variable(range_loop.start),
638 end: self.compile_variable(range_loop.end),
639 step: range_loop.step.map(|it| self.compile_variable(it)),
640 inclusive: range_loop.inclusive,
641 instructions: self.compile_scope(&mut range_loop.scope),
642 })
643 }
644 cube::Branch::Loop(mut op) => instructions.push(wgsl::Instruction::Loop {
645 instructions: self.compile_scope(&mut op.scope),
646 }),
647 };
648 }
649
650 fn compile_synchronization(
651 &mut self,
652 instructions: &mut Vec<wgsl::Instruction>,
653 synchronization: cube::Synchronization,
654 ) {
655 match synchronization {
656 cube::Synchronization::SyncCube => {
657 instructions.push(wgsl::Instruction::WorkgroupBarrier)
658 }
659 cube::Synchronization::SyncPlane => {
660 panic!("Synchronization within a plane is not supported in WGSL")
661 }
662 cube::Synchronization::SyncStorage => {
663 instructions.push(wgsl::Instruction::StorageBarrier)
664 }
665 cube::Synchronization::SyncAsyncProxyShared => panic!("TMA is not supported in WGSL"),
666 };
667 }
668
669 fn compile_comment(&mut self, instructions: &mut Vec<wgsl::Instruction>, content: String) {
670 instructions.push(wgsl::Instruction::Comment { content })
671 }
672
673 fn compile_metadata(
674 &mut self,
675 metadata: cube::Metadata,
676 out: Option<cube::Variable>,
677 ) -> wgsl::Instruction {
678 let out = out.unwrap();
679 match metadata {
680 cube::Metadata::Rank { var } => {
681 let position = self.ext_meta_pos(&var);
682 let offset = self.metadata.rank_index(position);
683 wgsl::Instruction::Metadata {
684 out: self.compile_variable(out),
685 info_offset: self.compile_variable(offset.into()),
686 }
687 }
688 cube::Metadata::Stride { dim, var } => {
689 let position = self.ext_meta_pos(&var);
690 let offset = self.metadata.stride_offset_index(position);
691 wgsl::Instruction::ExtendedMeta {
692 info_offset: self.compile_variable(offset.into()),
693 dim: self.compile_variable(dim),
694 out: self.compile_variable(out),
695 }
696 }
697 cube::Metadata::Shape { dim, var } => {
698 let position = self.ext_meta_pos(&var);
699 let offset = self.metadata.shape_offset_index(position);
700 wgsl::Instruction::ExtendedMeta {
701 info_offset: self.compile_variable(offset.into()),
702 dim: self.compile_variable(dim),
703 out: self.compile_variable(out),
704 }
705 }
706 cube::Metadata::Length { var } => match var.kind {
707 cube::VariableKind::GlobalInputArray(id) => {
708 let offset = self.metadata.len_index(id);
709 wgsl::Instruction::Metadata {
710 out: self.compile_variable(out),
711 info_offset: self.compile_variable(offset.into()),
712 }
713 }
714 cube::VariableKind::GlobalOutputArray(id) => {
715 let offset = self.metadata.len_index(id);
716 wgsl::Instruction::Metadata {
717 out: self.compile_variable(out),
718 info_offset: self.compile_variable(offset.into()),
719 }
720 }
721 _ => wgsl::Instruction::Length {
722 var: self.compile_variable(var),
723 out: self.compile_variable(out),
724 },
725 },
726 cube::Metadata::BufferLength { var } => match var.kind {
727 cube::VariableKind::GlobalInputArray(id) => {
728 let offset = self.metadata.buffer_len_index(id);
729 wgsl::Instruction::Metadata {
730 out: self.compile_variable(out),
731 info_offset: self.compile_variable(offset.into()),
732 }
733 }
734 cube::VariableKind::GlobalOutputArray(id) => {
735 let offset = self.metadata.buffer_len_index(id);
736 wgsl::Instruction::Metadata {
737 out: self.compile_variable(out),
738 info_offset: self.compile_variable(offset.into()),
739 }
740 }
741 _ => wgsl::Instruction::Length {
742 var: self.compile_variable(var),
743 out: self.compile_variable(out),
744 },
745 },
746 }
747 }
748
    /// Lowers an arithmetic operation into WGSL instructions.
    ///
    /// Most operations map one-to-one onto a WGSL instruction. Some are instead expanded
    /// into a child scope, which is then compiled and spliced in: `Hypot`, `Rhypot`, `Erf`,
    /// and `MulHi` (using the 64-bit variant when `supports_u64` is set, else a simulation).
    ///
    /// # Panics
    /// If `out` is `None`, or if a saturating add/sub reaches this point.
    fn compile_arithmetic(
        &mut self,
        value: cube::Arithmetic,
        out: Option<cube::Variable>,
        instructions: &mut Vec<wgsl::Instruction>,
        scope: &mut Scope,
    ) {
        let out = out.unwrap();
        match value {
            cube::Arithmetic::Max(op) => instructions.push(wgsl::Instruction::Max {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Min(op) => instructions.push(wgsl::Instruction::Min {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Add(op) => instructions.push(wgsl::Instruction::Add {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            }),
            // Rewritten by `SaturatingArithmeticProcessor` in `compile_scope`.
            cube::Arithmetic::SaturatingAdd(_) => {
                unreachable!("Saturating add should be removed by processor");
            }
            cube::Arithmetic::Fma(op) => instructions.push(wgsl::Instruction::Fma {
                a: self.compile_variable(op.a),
                b: self.compile_variable(op.b),
                c: self.compile_variable(op.c),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Modulo(op) => instructions.push(wgsl::Instruction::Modulo {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Sub(op) => instructions.push(wgsl::Instruction::Sub {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            }),
            // Rewritten by `SaturatingArithmeticProcessor` in `compile_scope`.
            cube::Arithmetic::SaturatingSub(_) => {
                unreachable!("Saturating sub should be removed by processor");
            }
            cube::Arithmetic::Mul(op) => instructions.push(wgsl::Instruction::Mul {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Div(op) => instructions.push(wgsl::Instruction::Div {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Abs(op) => instructions.push(wgsl::Instruction::Abs {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Exp(op) => instructions.push(wgsl::Instruction::Exp {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Log(op) => instructions.push(wgsl::Instruction::Log {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Log1p(op) => instructions.push(wgsl::Instruction::Log1p {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Cos(op) => instructions.push(wgsl::Instruction::Cos {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Sin(op) => instructions.push(wgsl::Instruction::Sin {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Tan(op) => instructions.push(wgsl::Instruction::Tan {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Tanh(op) => instructions.push(wgsl::Instruction::Tanh {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Sinh(op) => instructions.push(wgsl::Instruction::Sinh {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Cosh(op) => instructions.push(wgsl::Instruction::Cosh {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::ArcCos(op) => instructions.push(wgsl::Instruction::ArcCos {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::ArcSin(op) => instructions.push(wgsl::Instruction::ArcSin {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::ArcTan(op) => instructions.push(wgsl::Instruction::ArcTan {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::ArcSinh(op) => instructions.push(wgsl::Instruction::ArcSinh {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::ArcCosh(op) => instructions.push(wgsl::Instruction::ArcCosh {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::ArcTanh(op) => instructions.push(wgsl::Instruction::ArcTanh {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Degrees(op) => instructions.push(wgsl::Instruction::Degrees {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Radians(op) => instructions.push(wgsl::Instruction::Radians {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::ArcTan2(op) => instructions.push(wgsl::Instruction::ArcTan2 {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            }),
            // Integer powers share the float `Powf` instruction.
            // NOTE(review): presumes the `Powf` emitter handles an integer exponent; confirm.
            cube::Arithmetic::Powf(op) | cube::Arithmetic::Powi(op) => {
                instructions.push(wgsl::Instruction::Powf {
                    lhs: self.compile_variable(op.lhs),
                    rhs: self.compile_variable(op.rhs),
                    out: self.compile_variable(out),
                })
            }
            // Expanded into simpler operations in a child scope, which is compiled inline.
            cube::Arithmetic::Hypot(op) => {
                let mut scope = scope.child();
                expand_hypot(&mut scope, op.lhs, op.rhs, out);
                instructions.extend(self.compile_scope(&mut scope));
            }
            cube::Arithmetic::Rhypot(op) => {
                let mut scope = scope.child();
                expand_rhypot(&mut scope, op.lhs, op.rhs, out);
                instructions.extend(self.compile_scope(&mut scope));
            }

            cube::Arithmetic::Sqrt(op) => instructions.push(wgsl::Instruction::Sqrt {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::InverseSqrt(op) => {
                instructions.push(wgsl::Instruction::InverseSqrt {
                    input: self.compile_variable(op.input),
                    out: self.compile_variable(out),
                })
            }
            cube::Arithmetic::Round(op) => instructions.push(wgsl::Instruction::Round {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Floor(op) => instructions.push(wgsl::Instruction::Floor {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Ceil(op) => instructions.push(wgsl::Instruction::Ceil {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Trunc(op) => instructions.push(wgsl::Instruction::Trunc {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Erf(op) => {
                let mut scope = scope.child();
                expand_erf(&mut scope, op.input, out);
                instructions.extend(self.compile_scope(&mut scope));
            }
            // High half of a multiplication: native 64-bit product when u64 is available,
            // otherwise the simulated variant.
            cube::Arithmetic::MulHi(op) => {
                let mut scope = scope.child();
                match self.compilation_options.supports_u64 {
                    true => expand_himul_64(&mut scope, op.lhs, op.rhs, out),
                    false => expand_himul_sim(&mut scope, op.lhs, op.rhs, out),
                }
                instructions.extend(self.compile_scope(&mut scope));
            }
            cube::Arithmetic::Recip(op) => instructions.push(wgsl::Instruction::Recip {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Clamp(op) => instructions.push(wgsl::Instruction::Clamp {
                input: self.compile_variable(op.input),
                min_value: self.compile_variable(op.min_value),
                max_value: self.compile_variable(op.max_value),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Remainder(op) => instructions.push(wgsl::Instruction::Remainder {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Neg(op) => instructions.push(wgsl::Instruction::Negate {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Magnitude(op) => instructions.push(wgsl::Instruction::Magnitude {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Normalize(op) => instructions.push(wgsl::Instruction::Normalize {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Dot(op) => instructions.push(wgsl::Instruction::Dot {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            }),
        }
    }
974
975 fn compile_cmp(
976 &mut self,
977 value: cube::Comparison,
978 out: Option<cube::Variable>,
979 instructions: &mut Vec<wgsl::Instruction>,
980 ) {
981 let out = out.unwrap();
982 match value {
983 cube::Comparison::Equal(op) => instructions.push(wgsl::Instruction::Equal {
984 lhs: self.compile_variable(op.lhs),
985 rhs: self.compile_variable(op.rhs),
986 out: self.compile_variable(out),
987 }),
988 cube::Comparison::Lower(op) => instructions.push(wgsl::Instruction::Lower {
989 lhs: self.compile_variable(op.lhs),
990 rhs: self.compile_variable(op.rhs),
991 out: self.compile_variable(out),
992 }),
993 cube::Comparison::Greater(op) => instructions.push(wgsl::Instruction::Greater {
994 lhs: self.compile_variable(op.lhs),
995 rhs: self.compile_variable(op.rhs),
996 out: self.compile_variable(out),
997 }),
998 cube::Comparison::LowerEqual(op) => instructions.push(wgsl::Instruction::LowerEqual {
999 lhs: self.compile_variable(op.lhs),
1000 rhs: self.compile_variable(op.rhs),
1001 out: self.compile_variable(out),
1002 }),
1003 cube::Comparison::GreaterEqual(op) => {
1004 instructions.push(wgsl::Instruction::GreaterEqual {
1005 lhs: self.compile_variable(op.lhs),
1006 rhs: self.compile_variable(op.rhs),
1007 out: self.compile_variable(out),
1008 })
1009 }
1010 cube::Comparison::NotEqual(op) => instructions.push(wgsl::Instruction::NotEqual {
1011 lhs: self.compile_variable(op.lhs),
1012 rhs: self.compile_variable(op.rhs),
1013 out: self.compile_variable(out),
1014 }),
1015 cube::Comparison::IsNan(op) => instructions.push(wgsl::Instruction::IsNan {
1016 input: self.compile_variable(op.input),
1017 out: self.compile_variable(out),
1018 }),
1019 cube::Comparison::IsInf(op) => instructions.push(wgsl::Instruction::IsInf {
1020 input: self.compile_variable(op.input),
1021 out: self.compile_variable(out),
1022 }),
1023 }
1024 }
1025
1026 fn compile_bitwise(
1027 &mut self,
1028 value: cube::Bitwise,
1029 out: Option<cube::Variable>,
1030 instructions: &mut Vec<wgsl::Instruction>,
1031 ) {
1032 let out = out.unwrap();
1033 match value {
1034 cube::Bitwise::BitwiseOr(op) => instructions.push(wgsl::Instruction::BitwiseOr {
1035 lhs: self.compile_variable(op.lhs),
1036 rhs: self.compile_variable(op.rhs),
1037 out: self.compile_variable(out),
1038 }),
1039 cube::Bitwise::BitwiseAnd(op) => instructions.push(wgsl::Instruction::BitwiseAnd {
1040 lhs: self.compile_variable(op.lhs),
1041 rhs: self.compile_variable(op.rhs),
1042 out: self.compile_variable(out),
1043 }),
1044 cube::Bitwise::BitwiseXor(op) => instructions.push(wgsl::Instruction::BitwiseXor {
1045 lhs: self.compile_variable(op.lhs),
1046 rhs: self.compile_variable(op.rhs),
1047 out: self.compile_variable(out),
1048 }),
1049 cube::Bitwise::CountOnes(op) => instructions.push(wgsl::Instruction::CountBits {
1050 input: self.compile_variable(op.input),
1051 out: self.compile_variable(out),
1052 }),
1053 cube::Bitwise::ReverseBits(op) => instructions.push(wgsl::Instruction::ReverseBits {
1054 input: self.compile_variable(op.input),
1055 out: self.compile_variable(out),
1056 }),
1057 cube::Bitwise::ShiftLeft(op) => instructions.push(wgsl::Instruction::ShiftLeft {
1058 lhs: self.compile_variable(op.lhs),
1059 rhs: self.compile_variable(op.rhs),
1060 out: self.compile_variable(out),
1061 }),
1062 cube::Bitwise::ShiftRight(op) => instructions.push(wgsl::Instruction::ShiftRight {
1063 lhs: self.compile_variable(op.lhs),
1064 rhs: self.compile_variable(op.rhs),
1065 out: self.compile_variable(out),
1066 }),
1067 cube::Bitwise::BitwiseNot(op) => instructions.push(wgsl::Instruction::BitwiseNot {
1068 input: self.compile_variable(op.input),
1069 out: self.compile_variable(out),
1070 }),
1071 cube::Bitwise::LeadingZeros(op) => instructions.push(wgsl::Instruction::LeadingZeros {
1072 input: self.compile_variable(op.input),
1073 out: self.compile_variable(out),
1074 }),
1075 cube::Bitwise::TrailingZeros(op) => {
1076 instructions.push(wgsl::Instruction::TrailingZeros {
1077 input: self.compile_variable(op.input),
1078 out: self.compile_variable(out),
1079 })
1080 }
1081 cube::Bitwise::FindFirstSet(op) => instructions.push(wgsl::Instruction::FindFirstSet {
1082 input: self.compile_variable(op.input),
1083 out: self.compile_variable(out),
1084 }),
1085 }
1086 }
1087
1088 fn compile_operator(
1089 &mut self,
1090 value: cube::Operator,
1091 out: Option<cube::Variable>,
1092 instructions: &mut Vec<wgsl::Instruction>,
1093 ) {
1094 let out = out.unwrap();
1095 match value {
1096 cube::Operator::Cast(op) => instructions.push(wgsl::Instruction::Assign {
1097 input: self.compile_variable(op.input),
1098 out: self.compile_variable(out),
1099 }),
1100 cube::Operator::Index(op) | cube::Operator::UncheckedIndex(op) => {
1101 instructions.push(wgsl::Instruction::Index {
1102 lhs: self.compile_variable(op.list),
1103 rhs: self.compile_variable(op.index),
1104 out: self.compile_variable(out),
1105 });
1106 }
1107 cube::Operator::IndexAssign(op) | cube::Operator::UncheckedIndexAssign(op) => {
1108 instructions.push(wgsl::Instruction::IndexAssign {
1109 index: self.compile_variable(op.index),
1110 rhs: self.compile_variable(op.value),
1111 out: self.compile_variable(out),
1112 })
1113 }
1114 cube::Operator::And(op) => instructions.push(wgsl::Instruction::And {
1115 lhs: self.compile_variable(op.lhs),
1116 rhs: self.compile_variable(op.rhs),
1117 out: self.compile_variable(out),
1118 }),
1119 cube::Operator::Or(op) => instructions.push(wgsl::Instruction::Or {
1120 lhs: self.compile_variable(op.lhs),
1121 rhs: self.compile_variable(op.rhs),
1122 out: self.compile_variable(out),
1123 }),
1124 cube::Operator::Not(op) => instructions.push(wgsl::Instruction::Not {
1125 input: self.compile_variable(op.input),
1126 out: self.compile_variable(out),
1127 }),
1128 cube::Operator::Reinterpret(op) => instructions.push(wgsl::Instruction::Bitcast {
1129 input: self.compile_variable(op.input),
1130 out: self.compile_variable(out),
1131 }),
1132 cube::Operator::InitLine(op) => instructions.push(wgsl::Instruction::VecInit {
1133 inputs: op
1134 .inputs
1135 .into_iter()
1136 .map(|var| self.compile_variable(var))
1137 .collect(),
1138 out: self.compile_variable(out),
1139 }),
1140 cube::Operator::CopyMemory(op) => instructions.push(wgsl::Instruction::Copy {
1141 input: self.compile_variable(op.input),
1142 in_index: self.compile_variable(op.in_index),
1143 out: self.compile_variable(out),
1144 out_index: self.compile_variable(op.out_index),
1145 }),
1146 cube::Operator::CopyMemoryBulk(op) => instructions.push(wgsl::Instruction::CopyBulk {
1147 input: self.compile_variable(op.input),
1148 in_index: self.compile_variable(op.in_index),
1149 out: self.compile_variable(out),
1150 out_index: self.compile_variable(op.out_index),
1151 len: op.len as u32,
1152 }),
1153 cube::Operator::Select(op) => instructions.push(wgsl::Instruction::Select {
1154 cond: self.compile_variable(op.cond),
1155 then: self.compile_variable(op.then),
1156 or_else: self.compile_variable(op.or_else),
1157 out: self.compile_variable(out),
1158 }),
1159 }
1160 }
1161
1162 fn compile_atomic(
1163 &mut self,
1164 atomic: cube::AtomicOp,
1165 out: Option<cube::Variable>,
1166 ) -> wgsl::Instruction {
1167 let out = out.unwrap();
1168 match atomic {
1169 cube::AtomicOp::Add(op) => wgsl::Instruction::AtomicAdd {
1170 lhs: self.compile_variable(op.lhs),
1171 rhs: self.compile_variable(op.rhs),
1172 out: self.compile_variable(out),
1173 },
1174 cube::AtomicOp::Sub(op) => wgsl::Instruction::AtomicSub {
1175 lhs: self.compile_variable(op.lhs),
1176 rhs: self.compile_variable(op.rhs),
1177 out: self.compile_variable(out),
1178 },
1179 cube::AtomicOp::Max(op) => wgsl::Instruction::AtomicMax {
1180 lhs: self.compile_variable(op.lhs),
1181 rhs: self.compile_variable(op.rhs),
1182 out: self.compile_variable(out),
1183 },
1184 cube::AtomicOp::Min(op) => wgsl::Instruction::AtomicMin {
1185 lhs: self.compile_variable(op.lhs),
1186 rhs: self.compile_variable(op.rhs),
1187 out: self.compile_variable(out),
1188 },
1189 cube::AtomicOp::And(op) => wgsl::Instruction::AtomicAnd {
1190 lhs: self.compile_variable(op.lhs),
1191 rhs: self.compile_variable(op.rhs),
1192 out: self.compile_variable(out),
1193 },
1194 cube::AtomicOp::Or(op) => wgsl::Instruction::AtomicOr {
1195 lhs: self.compile_variable(op.lhs),
1196 rhs: self.compile_variable(op.rhs),
1197 out: self.compile_variable(out),
1198 },
1199 cube::AtomicOp::Xor(op) => wgsl::Instruction::AtomicXor {
1200 lhs: self.compile_variable(op.lhs),
1201 rhs: self.compile_variable(op.rhs),
1202 out: self.compile_variable(out),
1203 },
1204 cube::AtomicOp::Load(op) => wgsl::Instruction::AtomicLoad {
1205 input: self.compile_variable(op.input),
1206 out: self.compile_variable(out),
1207 },
1208 cube::AtomicOp::Store(op) => wgsl::Instruction::AtomicStore {
1209 input: self.compile_variable(op.input),
1210 out: self.compile_variable(out),
1211 },
1212 cube::AtomicOp::Swap(op) => wgsl::Instruction::AtomicSwap {
1213 lhs: self.compile_variable(op.lhs),
1214 rhs: self.compile_variable(op.rhs),
1215 out: self.compile_variable(out),
1216 },
1217 cube::AtomicOp::CompareAndSwap(op) => wgsl::Instruction::AtomicCompareExchangeWeak {
1218 lhs: self.compile_variable(op.input),
1219 cmp: self.compile_variable(op.cmp),
1220 value: self.compile_variable(op.val),
1221 out: self.compile_variable(out),
1222 },
1223 }
1224 }
1225
1226 fn compile_location(value: kernel::Location) -> wgsl::Location {
1227 match value {
1228 kernel::Location::Storage => wgsl::Location::Storage,
1229 kernel::Location::Cube => wgsl::Location::Workgroup,
1230 }
1231 }
1232
1233 fn compile_binding(&mut self, value: kernel::Binding) -> wgsl::Binding {
1234 wgsl::Binding {
1235 id: value.id,
1236 visibility: value.visibility,
1237 location: Self::compile_location(value.location),
1238 item: self.compile_type(value.ty),
1239 size: value.size,
1240 }
1241 }
1242}
1243
1244fn register_extensions(instructions: &[wgsl::Instruction]) -> Vec<wgsl::Extension> {
1245 let mut extensions = Vec::new();
1246
1247 let mut register_extension = |extension: wgsl::Extension| {
1248 if !extensions.contains(&extension) {
1249 extensions.push(extension);
1250 }
1251 };
1252
1253 for instruction in instructions {
1255 match instruction {
1256 wgsl::Instruction::Powf { lhs: _, rhs, out } => {
1257 register_extension(wgsl::Extension::PowfPrimitive(out.elem()));
1258 register_extension(wgsl::powf_extension(rhs, out));
1259 }
1260 #[cfg(target_os = "macos")]
1261 wgsl::Instruction::Tanh { input, out: _ } => {
1262 register_extension(wgsl::Extension::SafeTanhPrimitive(input.elem()));
1263 register_extension(wgsl::Extension::SafeTanh(input.item()));
1264 }
1265 wgsl::Instruction::IsNan { input, out } => {
1266 register_extension(wgsl::Extension::IsNanPrimitive(input.elem()));
1267 register_extension(wgsl::Extension::IsNan(input.item(), out.item()));
1268 }
1269 wgsl::Instruction::IsInf { input, out } => {
1270 register_extension(wgsl::Extension::IsInfPrimitive(input.elem()));
1271 register_extension(wgsl::Extension::IsInf(input.item(), out.item()));
1272 }
1273 wgsl::Instruction::If { instructions, .. } => {
1274 for extension in register_extensions(instructions) {
1275 register_extension(extension);
1276 }
1277 }
1278 wgsl::Instruction::IfElse {
1279 instructions_if,
1280 instructions_else,
1281 ..
1282 } => {
1283 for extension in register_extensions(instructions_if) {
1284 register_extension(extension);
1285 }
1286 for extension in register_extensions(instructions_else) {
1287 register_extension(extension);
1288 }
1289 }
1290 wgsl::Instruction::Loop { instructions } => {
1291 for extension in register_extensions(instructions) {
1292 register_extension(extension);
1293 }
1294 }
1295 wgsl::Instruction::RangeLoop { instructions, .. } => {
1296 for extension in register_extensions(instructions) {
1297 register_extension(extension);
1298 }
1299 }
1300 _ => {}
1301 }
1302 }
1303
1304 extensions
1305}