1use super::Subgroup;
2use super::{ConstantArray, shader::ComputeShader};
3use super::{Item, LocalArray, SharedArray};
4use crate::compiler::wgsl::{self, SharedValue};
5
6use cubecl_common::backtrace::BackTrace;
7use cubecl_core::prelude::*;
8use cubecl_core::{
9 Info,
10 post_processing::{checked_io::CheckedIoProcessor, saturating::SaturatingArithmeticProcessor},
11};
12use cubecl_core::{
13 Metadata, WgpuCompilationOptions,
14 ir::{self as cube, Scope},
15 prelude::expand_erf,
16};
17use cubecl_core::{
18 ir::{Processor, UIntKind},
19 post_processing::unroll::UnrollProcessor,
20};
21use cubecl_runtime::compiler::CompilationError;
22use cubecl_runtime::kernel;
23
/// Widest vector size handled by this compiler.
///
/// Buffer bindings and declared variables whose type is wider than this are narrowed to this
/// width before being compiled, and the same value is handed to the unroll processor.
pub const MAX_VECTOR_SIZE: usize = 4;
25
/// Compiler state used to turn a kernel definition into a WGSL compute shader.
///
/// Most boolean fields are usage flags: they are raised while variables and operations are
/// compiled and are copied into the resulting shader description at the end.
#[derive(Clone, Default)]
pub struct WgslCompiler {
    /// Name of the kernel being compiled (taken from the kernel options).
    kernel_name: String,
    /// Kernel info built from the scalars, the metadata layout and the address type.
    info: Info,
    /// For each buffer, the number of preceding buffers that have extended metadata.
    ext_meta_pos: Vec<u32>,
    /// Raised when `Builtin::UnitPos` is referenced.
    local_invocation_index: bool,
    /// Raised when `Builtin::UnitPosX/Y/Z` is referenced.
    local_invocation_id: bool,
    /// Raised when `Builtin::AbsolutePosX/Y/Z` is referenced.
    global_invocation_id: bool,
    /// Raised when `Builtin::CubePosX/Y/Z` is referenced.
    workgroup_id: bool,
    /// Raised when `Builtin::PlaneDim` is referenced.
    subgroup_size: bool,
    /// Raised when `Builtin::PlanePos` is referenced.
    subgroup_id: bool,
    /// Raised when `Builtin::UnitPosPlane` is referenced.
    subgroup_invocation_id: bool,
    /// Raised when `Builtin::AbsolutePos` is referenced.
    id: bool,
    /// Raised when `Builtin::CubeCountX/Y/Z` is referenced.
    num_workgroups: bool,
    /// Raised when `Builtin::CubePos` is referenced.
    workgroup_id_no_axis: bool,
    /// Raised when `Builtin::CubeDim` is referenced.
    workgroup_size_no_axis: bool,
    /// Raised when `Builtin::CubeCount` is referenced.
    num_workgroup_no_axis: bool,
    /// Shared arrays seen so far, registered once per id.
    shared_arrays: Vec<SharedArray>,
    /// Shared values seen so far, registered once per id.
    shared_values: Vec<SharedValue>,
    /// Constant arrays collected from the compiled scopes.
    const_arrays: Vec<ConstantArray>,
    /// Local arrays seen so far, registered once per id.
    local_arrays: Vec<LocalArray>,
    // NOTE(review): this field is read when lowering `MulHi` (`supports_u64`), so the
    // `dead_code` allowance looks stale — confirm before removing it.
    #[allow(dead_code)]
    /// Options copied from the caller at the start of `compile`.
    compilation_options: WgpuCompilationOptions,
    /// Execution mode of the current compilation; forwarded to the checked-IO processor.
    strategy: ExecutionMode,
    /// Raised as soon as any plane (subgroup) operation is compiled.
    subgroup_instructions_used: bool,
    /// Raised when an `f16` element type is compiled.
    f16_used: bool,
}
55
56impl core::fmt::Debug for WgslCompiler {
57 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58 f.write_str("WgslCompiler")
59 }
60}
61
62impl cubecl_core::Compiler for WgslCompiler {
63 type Representation = ComputeShader;
64 type CompilationOptions = WgpuCompilationOptions;
65
66 fn compile(
67 &mut self,
68 shader: kernel::KernelDefinition,
69 compilation_options: &Self::CompilationOptions,
70 mode: ExecutionMode,
71 address_type: StorageType,
72 ) -> Result<Self::Representation, CompilationError> {
73 self.compilation_options = *compilation_options;
74 self.compile_shader(shader, mode, address_type)
75 }
76
77 fn elem_size(&self, elem: cube::ElemType) -> usize {
78 elem.size()
79 }
80
81 fn extension(&self) -> &'static str {
82 "wgsl"
83 }
84}
85
86impl WgslCompiler {
87 fn compile_shader(
88 &mut self,
89 mut value: kernel::KernelDefinition,
90 mode: ExecutionMode,
91 address_type: StorageType,
92 ) -> Result<wgsl::ComputeShader, CompilationError> {
93 let errors = value.body.pop_errors();
94 if !errors.is_empty() {
95 let mut reason = "Can't compile wgsl kernel".to_string();
96 for error in errors {
97 reason += error.as_str();
98 reason += "\n";
99 }
100
101 return Err(CompilationError::Validation {
102 reason,
103 backtrace: BackTrace::capture(),
104 });
105 }
106
107 self.strategy = mode;
108 self.kernel_name = value.options.kernel_name.clone();
109
110 let num_meta = value.buffers.len();
111
112 self.ext_meta_pos = Vec::new();
113 let mut num_ext = 0;
114
115 for binding in value.buffers.iter() {
116 self.ext_meta_pos.push(num_ext);
117 if binding.has_extended_meta {
118 num_ext += 1;
119 }
120 }
121
122 let metadata = Metadata::new(num_meta as u32, num_ext);
123 self.info = Info::new(&value.scalars, metadata, address_type);
124
125 let address_type = self.compile_storage_type(address_type);
126 let instructions = self.compile_scope(&mut value.body);
127 let extensions = register_extensions(&instructions);
128 let body = wgsl::Body {
129 instructions,
130 id: self.id,
131 address_type,
132 };
133
134 Ok(wgsl::ComputeShader {
135 address_type,
136 buffers: value
137 .buffers
138 .into_iter()
139 .map(|mut it| {
140 if it.ty.vector_size() > MAX_VECTOR_SIZE {
143 it.ty = it.ty.with_vector_size(MAX_VECTOR_SIZE);
144 }
145 self.compile_binding(it)
146 })
147 .collect(),
148 scalars: value
149 .scalars
150 .into_iter()
151 .map(|binding| (self.compile_storage_type(binding.ty), binding.count))
152 .collect(),
153 shared_arrays: self.shared_arrays.clone(),
154 shared_values: self.shared_values.clone(),
155 constant_arrays: self.const_arrays.clone(),
156 local_arrays: self.local_arrays.clone(),
157 static_meta_len: self.info.metadata.static_len() as usize,
158 info: self.info.clone(),
159 workgroup_size: value.cube_dim,
160 global_invocation_id: self.global_invocation_id || self.id,
161 local_invocation_index: self.local_invocation_index,
162 local_invocation_id: self.local_invocation_id,
163 num_workgroups: self.id
164 || self.num_workgroups
165 || self.num_workgroup_no_axis
166 || self.workgroup_id_no_axis,
167 workgroup_id: self.workgroup_id || self.workgroup_id_no_axis,
168 subgroup_size: self.subgroup_size,
169 subgroup_id: self.subgroup_id,
170 subgroup_invocation_id: self.subgroup_invocation_id,
171 body,
172 extensions,
173 num_workgroups_no_axis: self.num_workgroup_no_axis,
174 workgroup_id_no_axis: self.workgroup_id_no_axis,
175 workgroup_size_no_axis: self.workgroup_size_no_axis,
176 subgroup_instructions_used: self.subgroup_instructions_used,
177 f16_used: self.f16_used,
178 kernel_name: value.options.kernel_name,
179 })
180 }
181
182 fn compile_type(&mut self, item: cube::Type) -> Item {
183 match item {
184 cube::Type::Scalar(ty) => wgsl::Item::Scalar(self.compile_storage_type(ty)),
185 cube::Type::Vector(ty, size) => {
186 let elem = self.compile_storage_type(ty);
187 match size {
188 2 => wgsl::Item::Vec2(elem),
189 3 => wgsl::Item::Vec3(elem),
190 4 => wgsl::Item::Vec4(elem),
191 _ => panic!("Unsupported vectorizations scheme {:?}", item.vector_size()),
192 }
193 }
194 cube::Type::Semantic(_) => unimplemented!("Can't compile semantic type"),
195 }
196 }
197
198 fn compile_storage_type(&mut self, ty: cube::StorageType) -> wgsl::Elem {
199 match ty {
200 cube::StorageType::Scalar(ty) => self.compile_elem(ty),
201 cube::StorageType::Atomic(ty) => match ty {
202 cube::ElemType::Float(i) => match i {
203 cube::FloatKind::F32 => wgsl::Elem::AtomicF32,
204 kind => panic!("atomic<{kind:?}> is not a valid WgpuElement"),
205 },
206 cube::ElemType::Int(i) => match i {
207 cube::IntKind::I32 => wgsl::Elem::AtomicI32,
208 kind => panic!("atomic<{kind:?}> is not a valid WgpuElement"),
209 },
210 cube::ElemType::UInt(kind) => match kind {
211 cube::UIntKind::U32 => wgsl::Elem::AtomicU32,
212 kind => panic!("{kind:?} is not a valid WgpuElement"),
213 },
214 other => panic!("{other:?} is not a valid WgpuElement"),
215 },
216 cube::StorageType::Packed(_, _) => {
217 unimplemented!("Packed types not yet supported in WGSL")
218 }
219 cube::StorageType::Opaque(ty) => match ty {
220 cube::OpaqueType::Barrier(_) => {
221 unimplemented!("Barrier objects not supported in WGSL")
222 }
223 },
224 }
225 }
226
227 fn compile_elem(&mut self, value: cube::ElemType) -> wgsl::Elem {
228 match value {
229 cube::ElemType::Float(f) => match f {
230 cube::FloatKind::E2M1
231 | cube::FloatKind::E2M3
232 | cube::FloatKind::E3M2
233 | cube::FloatKind::E4M3
234 | cube::FloatKind::E5M2
235 | cube::FloatKind::UE8M0 => panic!("Minifloat is not a valid WgpuElement"),
236 cube::FloatKind::F16 => {
237 self.f16_used = true;
238 wgsl::Elem::F16
239 }
240 cube::FloatKind::BF16 => panic!("bf16 is not a valid WgpuElement"),
241 cube::FloatKind::TF32 => panic!("tf32 is not a valid WgpuElement"),
242 cube::FloatKind::Flex32 => wgsl::Elem::F32,
243 cube::FloatKind::F32 => wgsl::Elem::F32,
244 cube::FloatKind::F64 => wgsl::Elem::F64,
245 },
246 cube::ElemType::Int(i) => match i {
247 cube::IntKind::I32 => wgsl::Elem::I32,
248 cube::IntKind::I64 => wgsl::Elem::I64,
249 kind => panic!("{kind:?} is not a valid WgpuElement"),
250 },
251 cube::ElemType::UInt(kind) => match kind {
252 cube::UIntKind::U32 => wgsl::Elem::U32,
253 cube::UIntKind::U64 => wgsl::Elem::U64,
254 kind => panic!("{kind:?} is not a valid WgpuElement"),
255 },
256 cube::ElemType::Bool => wgsl::Elem::Bool,
257 }
258 }
259
260 fn ext_meta_pos(&self, var: &cube::Variable) -> u32 {
261 let pos = var.index().expect("Variable should have index");
262 self.ext_meta_pos[pos as usize]
263 }
264
265 pub(crate) fn compile_variable(&mut self, value: cube::Variable) -> wgsl::Variable {
266 let item = value.ty;
267 match value.kind {
268 cube::VariableKind::GlobalInputArray(id) => {
269 wgsl::Variable::GlobalInputArray(id, self.compile_type(item))
270 }
271 cube::VariableKind::GlobalScalar(id) => {
272 wgsl::Variable::GlobalScalar(id, self.compile_storage_type(item.storage_type()))
273 }
274 cube::VariableKind::LocalMut { id } | cube::VariableKind::Versioned { id, .. } => {
275 wgsl::Variable::LocalMut {
276 id,
277 item: self.compile_type(item),
278 }
279 }
280 cube::VariableKind::LocalConst { id } => wgsl::Variable::LocalConst {
281 id,
282 item: self.compile_type(item),
283 },
284 cube::VariableKind::GlobalOutputArray(id) => {
285 wgsl::Variable::GlobalOutputArray(id, self.compile_type(item))
286 }
287 cube::VariableKind::Constant(value) => {
288 wgsl::Variable::Constant(value, self.compile_type(item))
289 }
290 cube::VariableKind::SharedArray {
291 id,
292 length,
293 unroll_factor,
294 alignment,
295 } => {
296 let item = self.compile_type(item);
297 if !self.shared_arrays.iter().any(|s| s.index == id) {
298 self.shared_arrays.push(SharedArray::new(
299 id,
300 item,
301 (length * unroll_factor) as u32,
302 alignment.map(|it| it as u32),
303 ));
304 }
305 wgsl::Variable::SharedArray(id, item, length as u32)
306 }
307 cube::VariableKind::Shared { id } => {
308 let item = self.compile_type(item);
309 if !self.shared_values.iter().any(|s| s.index == id) {
310 self.shared_values.push(SharedValue::new(id, item));
311 }
312 wgsl::Variable::SharedValue(id, item)
313 }
314 cube::VariableKind::ConstantArray { id, length, .. } => {
315 let item = self.compile_type(item);
316 wgsl::Variable::ConstantArray(id, item, length as u32)
317 }
318 cube::VariableKind::LocalArray {
319 id,
320 length,
321 unroll_factor,
322 } => {
323 let item = self.compile_type(item);
324 if !self.local_arrays.iter().any(|s| s.index == id) {
325 self.local_arrays.push(LocalArray::new(
326 id,
327 item,
328 (length * unroll_factor) as u32,
329 ));
330 }
331 wgsl::Variable::LocalArray(id, item, length as u32)
332 }
333 cube::VariableKind::Builtin(builtin) => match builtin {
334 cube::Builtin::AbsolutePos => {
335 self.id = true;
336 wgsl::Variable::Id
337 }
338 cube::Builtin::UnitPos => {
339 self.local_invocation_index = true;
340 wgsl::Variable::LocalInvocationIndex
341 }
342 cube::Builtin::UnitPosX => {
343 self.local_invocation_id = true;
344 wgsl::Variable::LocalInvocationIdX
345 }
346 cube::Builtin::UnitPosY => {
347 self.local_invocation_id = true;
348 wgsl::Variable::LocalInvocationIdY
349 }
350 cube::Builtin::UnitPosZ => {
351 self.local_invocation_id = true;
352 wgsl::Variable::LocalInvocationIdZ
353 }
354 cube::Builtin::CubePosX => {
355 self.workgroup_id = true;
356 wgsl::Variable::WorkgroupIdX
357 }
358 cube::Builtin::CubePosY => {
359 self.workgroup_id = true;
360 wgsl::Variable::WorkgroupIdY
361 }
362 cube::Builtin::CubePosZ => {
363 self.workgroup_id = true;
364 wgsl::Variable::WorkgroupIdZ
365 }
366 cube::Builtin::CubePosCluster
367 | cube::Builtin::CubePosClusterX
368 | cube::Builtin::CubePosClusterY
369 | cube::Builtin::CubePosClusterZ => self.constant_var(1),
370 cube::Builtin::AbsolutePosX => {
371 self.global_invocation_id = true;
372 wgsl::Variable::GlobalInvocationIdX
373 }
374 cube::Builtin::AbsolutePosY => {
375 self.global_invocation_id = true;
376 wgsl::Variable::GlobalInvocationIdY
377 }
378 cube::Builtin::AbsolutePosZ => {
379 self.global_invocation_id = true;
380 wgsl::Variable::GlobalInvocationIdZ
381 }
382 cube::Builtin::CubeDimX => wgsl::Variable::WorkgroupSizeX,
383 cube::Builtin::CubeDimY => wgsl::Variable::WorkgroupSizeY,
384 cube::Builtin::CubeDimZ => wgsl::Variable::WorkgroupSizeZ,
385 cube::Builtin::CubeClusterDim
386 | cube::Builtin::CubeClusterDimX
387 | cube::Builtin::CubeClusterDimY
388 | cube::Builtin::CubeClusterDimZ => self.constant_var(1),
389 cube::Builtin::CubeCountX => {
390 self.num_workgroups = true;
391 wgsl::Variable::NumWorkgroupsX
392 }
393 cube::Builtin::CubeCountY => {
394 self.num_workgroups = true;
395 wgsl::Variable::NumWorkgroupsY
396 }
397 cube::Builtin::CubeCountZ => {
398 self.num_workgroups = true;
399 wgsl::Variable::NumWorkgroupsZ
400 }
401 cube::Builtin::CubePos => {
402 self.workgroup_id_no_axis = true;
403 wgsl::Variable::WorkgroupId
404 }
405 cube::Builtin::CubeDim => {
406 self.workgroup_size_no_axis = true;
407 wgsl::Variable::WorkgroupSize
408 }
409 cube::Builtin::CubeCount => {
410 self.num_workgroup_no_axis = true;
411 wgsl::Variable::NumWorkgroups
412 }
413 cube::Builtin::PlaneDim => {
414 self.subgroup_size = true;
415 wgsl::Variable::SubgroupSize
416 }
417 cube::Builtin::PlanePos => {
418 self.subgroup_id = true;
419 wgsl::Variable::SubgroupId
420 }
421 cube::Builtin::UnitPosPlane => {
422 self.subgroup_invocation_id = true;
423 wgsl::Variable::SubgroupInvocationId
424 }
425 },
426 cube::VariableKind::Matrix { .. } => {
427 panic!("Cooperative matrix-multiply and accumulate not supported.")
428 }
429 cube::VariableKind::Pipeline { .. } => {
430 panic!("Pipeline not supported.")
431 }
432 cube::VariableKind::BarrierToken { .. } => {
433 panic!("Barrier not supported.")
434 }
435 cube::VariableKind::TensorMapInput(_) => panic!("Tensor map not supported."),
436 cube::VariableKind::TensorMapOutput(_) => panic!("Tensor map not supported."),
437 }
438 }
439
440 fn constant_var(&mut self, value: u32) -> wgsl::Variable {
441 let var = cube::Variable::constant(value.into(), UIntKind::U32);
442 self.compile_variable(var)
443 }
444
445 fn compile_scope(&mut self, scope: &mut cube::Scope) -> Vec<wgsl::Instruction> {
446 let mut instructions = Vec::new();
447
448 let const_arrays = scope
449 .const_arrays
450 .drain(..)
451 .map(|(var, values)| ConstantArray {
452 index: var.index().unwrap(),
453 item: self.compile_type(var.ty),
454 size: values.len() as u32,
455 values: values
456 .into_iter()
457 .map(|val| self.compile_variable(val))
458 .collect(),
459 })
460 .collect::<Vec<_>>();
461 self.const_arrays.extend(const_arrays);
462
463 let checked_io: Box<dyn Processor> = Box::new(CheckedIoProcessor::new(
464 self.strategy,
465 self.kernel_name.clone(),
466 ));
467 let unroll = Box::new(UnrollProcessor::new(MAX_VECTOR_SIZE));
468 let saturating = Box::new(SaturatingArithmeticProcessor::new(true));
469 let processing = scope.process([&*unroll, &*checked_io, &*saturating]);
470
471 for mut var in processing.variables {
472 if var.ty.vector_size() > MAX_VECTOR_SIZE {
473 var.ty = var.ty.with_vector_size(MAX_VECTOR_SIZE);
474 }
475 instructions.push(wgsl::Instruction::DeclareVariable {
476 var: self.compile_variable(var),
477 });
478 }
479
480 processing
481 .instructions
482 .into_iter()
483 .for_each(|op| self.compile_operation(&mut instructions, op.operation, op.out, scope));
484
485 instructions
486 }
487
488 fn compile_operation(
489 &mut self,
490 instructions: &mut Vec<wgsl::Instruction>,
491 operation: cube::Operation,
492 out: Option<cube::Variable>,
493 scope: &mut cube::Scope,
494 ) {
495 match operation {
496 cube::Operation::Copy(variable) => instructions.push(wgsl::Instruction::Assign {
497 input: self.compile_variable(variable),
498 out: self.compile_variable(out.unwrap()),
499 }),
500 cube::Operation::Arithmetic(op) => {
501 self.compile_arithmetic(op, out, instructions, scope)
502 }
503 cube::Operation::Comparison(op) => self.compile_cmp(op, out, instructions),
504 cube::Operation::Bitwise(op) => self.compile_bitwise(op, out, instructions),
505 cube::Operation::Operator(op) => self.compile_operator(op, out, instructions),
506 cube::Operation::Atomic(op) => instructions.push(self.compile_atomic(op, out)),
507 cube::Operation::Metadata(op) => instructions.push(self.compile_metadata(op, out)),
508 cube::Operation::Branch(val) => self.compile_branch(instructions, val),
509 cube::Operation::Synchronization(val) => {
510 self.compile_synchronization(instructions, val)
511 }
512 cube::Operation::Plane(op) => self.compile_subgroup(instructions, op, out),
513 cube::Operation::CoopMma(_) => {
514 panic!("Cooperative matrix-multiply and accumulate isn't supported on wgpu.")
515 }
516 cube::Operation::NonSemantic(cube::NonSemantic::Comment { content }) => {
517 self.compile_comment(instructions, content)
518 }
519 cube::Operation::NonSemantic(_) => {}
520 cube::Operation::Barrier(_) => {
521 panic!("Barrier isn't supported on wgpu.")
522 }
523 cube::Operation::Tma(_) => panic!("TMA isn't supported on wgpu."),
524 cube::Operation::Marker(_) => {}
525 }
526 }
527
528 fn compile_subgroup(
529 &mut self,
530 instructions: &mut Vec<wgsl::Instruction>,
531 subgroup: cube::Plane,
532 out: Option<cube::Variable>,
533 ) {
534 self.subgroup_instructions_used = true;
535
536 let out = out.unwrap();
537 let op = match subgroup {
538 cube::Plane::Elect => Subgroup::Elect {
539 out: self.compile_variable(out),
540 },
541 cube::Plane::All(op) => Subgroup::All {
542 input: self.compile_variable(op.input),
543 out: self.compile_variable(out),
544 },
545 cube::Plane::Any(op) => Subgroup::Any {
546 input: self.compile_variable(op.input),
547 out: self.compile_variable(out),
548 },
549 cube::Plane::Ballot(op) => Subgroup::Ballot {
550 input: self.compile_variable(op.input),
551 out: self.compile_variable(out),
552 },
553
554 cube::Plane::Broadcast(op) => Subgroup::Broadcast {
555 lhs: self.compile_variable(op.lhs),
556 rhs: self.compile_variable(op.rhs),
557 out: self.compile_variable(out),
558 },
559
560 cube::Plane::Sum(op) => Subgroup::Sum {
561 input: self.compile_variable(op.input),
562 out: self.compile_variable(out),
563 },
564
565 cube::Plane::ExclusiveSum(op) => Subgroup::ExclusiveSum {
566 input: self.compile_variable(op.input),
567 out: self.compile_variable(out),
568 },
569 cube::Plane::InclusiveSum(op) => Subgroup::InclusiveSum {
570 input: self.compile_variable(op.input),
571 out: self.compile_variable(out),
572 },
573 cube::Plane::Prod(op) => Subgroup::Prod {
574 input: self.compile_variable(op.input),
575 out: self.compile_variable(out),
576 },
577 cube::Plane::ExclusiveProd(op) => Subgroup::ExclusiveProd {
578 input: self.compile_variable(op.input),
579 out: self.compile_variable(out),
580 },
581 cube::Plane::InclusiveProd(op) => Subgroup::InclusiveProd {
582 input: self.compile_variable(op.input),
583 out: self.compile_variable(out),
584 },
585 cube::Plane::Min(op) => Subgroup::Min {
586 input: self.compile_variable(op.input),
587 out: self.compile_variable(out),
588 },
589 cube::Plane::Max(op) => Subgroup::Max {
590 input: self.compile_variable(op.input),
591 out: self.compile_variable(out),
592 },
593 cube::Plane::Shuffle(op) => Subgroup::Shuffle {
594 lhs: self.compile_variable(op.lhs),
595 rhs: self.compile_variable(op.rhs),
596 out: self.compile_variable(out),
597 },
598 cube::Plane::ShuffleXor(op) => Subgroup::ShuffleXor {
599 lhs: self.compile_variable(op.lhs),
600 rhs: self.compile_variable(op.rhs),
601 out: self.compile_variable(out),
602 },
603 cube::Plane::ShuffleUp(op) => Subgroup::ShuffleUp {
604 lhs: self.compile_variable(op.lhs),
605 rhs: self.compile_variable(op.rhs),
606 out: self.compile_variable(out),
607 },
608 cube::Plane::ShuffleDown(op) => Subgroup::ShuffleDown {
609 lhs: self.compile_variable(op.lhs),
610 rhs: self.compile_variable(op.rhs),
611 out: self.compile_variable(out),
612 },
613 };
614
615 instructions.push(wgsl::Instruction::Subgroup(op));
616 }
617
618 fn compile_branch(&mut self, instructions: &mut Vec<wgsl::Instruction>, branch: cube::Branch) {
619 match branch {
620 cube::Branch::If(mut op) => instructions.push(wgsl::Instruction::If {
621 cond: self.compile_variable(op.cond),
622 instructions: self.compile_scope(&mut op.scope),
623 }),
624 cube::Branch::IfElse(mut op) => instructions.push(wgsl::Instruction::IfElse {
625 cond: self.compile_variable(op.cond),
626 instructions_if: self.compile_scope(&mut op.scope_if),
627 instructions_else: self.compile_scope(&mut op.scope_else),
628 }),
629 cube::Branch::Switch(mut op) => instructions.push(wgsl::Instruction::Switch {
630 value: self.compile_variable(op.value),
631 instructions_default: self.compile_scope(&mut op.scope_default),
632 cases: op
633 .cases
634 .into_iter()
635 .map(|(val, mut scope)| {
636 (self.compile_variable(val), self.compile_scope(&mut scope))
637 })
638 .collect(),
639 }),
640 cube::Branch::Return => instructions.push(wgsl::Instruction::Return),
641 cube::Branch::Unreachable => instructions.push(wgsl::Instruction::Return),
643 cube::Branch::Break => instructions.push(wgsl::Instruction::Break),
644 cube::Branch::RangeLoop(mut range_loop) => {
645 instructions.push(wgsl::Instruction::RangeLoop {
646 i: self.compile_variable(range_loop.i),
647 start: self.compile_variable(range_loop.start),
648 end: self.compile_variable(range_loop.end),
649 step: range_loop.step.map(|it| self.compile_variable(it)),
650 inclusive: range_loop.inclusive,
651 instructions: self.compile_scope(&mut range_loop.scope),
652 })
653 }
654 cube::Branch::Loop(mut op) => instructions.push(wgsl::Instruction::Loop {
655 instructions: self.compile_scope(&mut op.scope),
656 }),
657 };
658 }
659
660 fn compile_synchronization(
661 &mut self,
662 instructions: &mut Vec<wgsl::Instruction>,
663 synchronization: cube::Synchronization,
664 ) {
665 match synchronization {
666 cube::Synchronization::SyncCube => {
667 instructions.push(wgsl::Instruction::WorkgroupBarrier)
668 }
669 cube::Synchronization::SyncPlane => {
670 panic!("Synchronization within a plane is not supported in WGSL")
671 }
672 cube::Synchronization::SyncStorage => {
673 instructions.push(wgsl::Instruction::StorageBarrier)
674 }
675 cube::Synchronization::SyncAsyncProxyShared => panic!("TMA is not supported in WGSL"),
676 };
677 }
678
679 fn compile_comment(&mut self, instructions: &mut Vec<wgsl::Instruction>, content: String) {
680 instructions.push(wgsl::Instruction::Comment { content })
681 }
682
683 fn compile_metadata(
684 &mut self,
685 metadata: cube::Metadata,
686 out: Option<cube::Variable>,
687 ) -> wgsl::Instruction {
688 let out = out.unwrap();
689 match metadata {
690 cube::Metadata::Rank { var } => {
691 let position = self.ext_meta_pos(&var);
692 let offset = self.info.metadata.rank_index(position);
693 wgsl::Instruction::Metadata {
694 out: self.compile_variable(out),
695 info_offset: self.compile_variable(offset.into()),
696 }
697 }
698 cube::Metadata::Stride { dim, var } => {
699 let position = self.ext_meta_pos(&var);
700 let offset = self.info.metadata.stride_offset_index(position);
701 wgsl::Instruction::ExtendedMeta {
702 info_offset: self.compile_variable(offset.into()),
703 dim: self.compile_variable(dim),
704 out: self.compile_variable(out),
705 }
706 }
707 cube::Metadata::Shape { dim, var } => {
708 let position = self.ext_meta_pos(&var);
709 let offset = self.info.metadata.shape_offset_index(position);
710 wgsl::Instruction::ExtendedMeta {
711 info_offset: self.compile_variable(offset.into()),
712 dim: self.compile_variable(dim),
713 out: self.compile_variable(out),
714 }
715 }
716 cube::Metadata::Length { var } => match var.kind {
717 cube::VariableKind::GlobalInputArray(id) => {
718 let offset = self.info.metadata.len_index(id);
719 wgsl::Instruction::Metadata {
720 out: self.compile_variable(out),
721 info_offset: self.compile_variable(offset.into()),
722 }
723 }
724 cube::VariableKind::GlobalOutputArray(id) => {
725 let offset = self.info.metadata.len_index(id);
726 wgsl::Instruction::Metadata {
727 out: self.compile_variable(out),
728 info_offset: self.compile_variable(offset.into()),
729 }
730 }
731 _ => wgsl::Instruction::Length {
732 var: self.compile_variable(var),
733 out: self.compile_variable(out),
734 },
735 },
736 cube::Metadata::BufferLength { var } => match var.kind {
737 cube::VariableKind::GlobalInputArray(id) => {
738 let offset = self.info.metadata.buffer_len_index(id);
739 wgsl::Instruction::Metadata {
740 out: self.compile_variable(out),
741 info_offset: self.compile_variable(offset.into()),
742 }
743 }
744 cube::VariableKind::GlobalOutputArray(id) => {
745 let offset = self.info.metadata.buffer_len_index(id);
746 wgsl::Instruction::Metadata {
747 out: self.compile_variable(out),
748 info_offset: self.compile_variable(offset.into()),
749 }
750 }
751 _ => wgsl::Instruction::Length {
752 var: self.compile_variable(var),
753 out: self.compile_variable(out),
754 },
755 },
756 }
757 }
758
759 fn compile_arithmetic(
760 &mut self,
761 value: cube::Arithmetic,
762 out: Option<cube::Variable>,
763 instructions: &mut Vec<wgsl::Instruction>,
764 scope: &mut Scope,
765 ) {
766 let out = out.unwrap();
767 match value {
768 cube::Arithmetic::Max(op) => instructions.push(wgsl::Instruction::Max {
769 lhs: self.compile_variable(op.lhs),
770 rhs: self.compile_variable(op.rhs),
771 out: self.compile_variable(out),
772 }),
773 cube::Arithmetic::Min(op) => instructions.push(wgsl::Instruction::Min {
774 lhs: self.compile_variable(op.lhs),
775 rhs: self.compile_variable(op.rhs),
776 out: self.compile_variable(out),
777 }),
778 cube::Arithmetic::Add(op) => instructions.push(wgsl::Instruction::Add {
779 lhs: self.compile_variable(op.lhs),
780 rhs: self.compile_variable(op.rhs),
781 out: self.compile_variable(out),
782 }),
783 cube::Arithmetic::SaturatingAdd(_) => {
784 unreachable!("Saturating add should be removed by processor");
785 }
786 cube::Arithmetic::Fma(op) => instructions.push(wgsl::Instruction::Fma {
787 a: self.compile_variable(op.a),
788 b: self.compile_variable(op.b),
789 c: self.compile_variable(op.c),
790 out: self.compile_variable(out),
791 }),
792 cube::Arithmetic::Modulo(op) => instructions.push(wgsl::Instruction::Modulo {
793 lhs: self.compile_variable(op.lhs),
794 rhs: self.compile_variable(op.rhs),
795 out: self.compile_variable(out),
796 }),
797 cube::Arithmetic::Sub(op) => instructions.push(wgsl::Instruction::Sub {
798 lhs: self.compile_variable(op.lhs),
799 rhs: self.compile_variable(op.rhs),
800 out: self.compile_variable(out),
801 }),
802 cube::Arithmetic::SaturatingSub(_) => {
803 unreachable!("Saturating sub should be removed by processor");
804 }
805 cube::Arithmetic::Mul(op) => instructions.push(wgsl::Instruction::Mul {
806 lhs: self.compile_variable(op.lhs),
807 rhs: self.compile_variable(op.rhs),
808 out: self.compile_variable(out),
809 }),
810 cube::Arithmetic::Div(op) => instructions.push(wgsl::Instruction::Div {
811 lhs: self.compile_variable(op.lhs),
812 rhs: self.compile_variable(op.rhs),
813 out: self.compile_variable(out),
814 }),
815 cube::Arithmetic::Abs(op) => instructions.push(wgsl::Instruction::Abs {
816 input: self.compile_variable(op.input),
817 out: self.compile_variable(out),
818 }),
819 cube::Arithmetic::Exp(op) => instructions.push(wgsl::Instruction::Exp {
820 input: self.compile_variable(op.input),
821 out: self.compile_variable(out),
822 }),
823 cube::Arithmetic::Log(op) => instructions.push(wgsl::Instruction::Log {
824 input: self.compile_variable(op.input),
825 out: self.compile_variable(out),
826 }),
827 cube::Arithmetic::Log1p(op) => instructions.push(wgsl::Instruction::Log1p {
828 input: self.compile_variable(op.input),
829 out: self.compile_variable(out),
830 }),
831 cube::Arithmetic::Cos(op) => instructions.push(wgsl::Instruction::Cos {
832 input: self.compile_variable(op.input),
833 out: self.compile_variable(out),
834 }),
835 cube::Arithmetic::Sin(op) => instructions.push(wgsl::Instruction::Sin {
836 input: self.compile_variable(op.input),
837 out: self.compile_variable(out),
838 }),
839 cube::Arithmetic::Tan(op) => instructions.push(wgsl::Instruction::Tan {
840 input: self.compile_variable(op.input),
841 out: self.compile_variable(out),
842 }),
843 cube::Arithmetic::Tanh(op) => instructions.push(wgsl::Instruction::Tanh {
844 input: self.compile_variable(op.input),
845 out: self.compile_variable(out),
846 }),
847 cube::Arithmetic::Sinh(op) => instructions.push(wgsl::Instruction::Sinh {
848 input: self.compile_variable(op.input),
849 out: self.compile_variable(out),
850 }),
851 cube::Arithmetic::Cosh(op) => instructions.push(wgsl::Instruction::Cosh {
852 input: self.compile_variable(op.input),
853 out: self.compile_variable(out),
854 }),
855 cube::Arithmetic::ArcCos(op) => instructions.push(wgsl::Instruction::ArcCos {
856 input: self.compile_variable(op.input),
857 out: self.compile_variable(out),
858 }),
859 cube::Arithmetic::ArcSin(op) => instructions.push(wgsl::Instruction::ArcSin {
860 input: self.compile_variable(op.input),
861 out: self.compile_variable(out),
862 }),
863 cube::Arithmetic::ArcTan(op) => instructions.push(wgsl::Instruction::ArcTan {
864 input: self.compile_variable(op.input),
865 out: self.compile_variable(out),
866 }),
867 cube::Arithmetic::ArcSinh(op) => instructions.push(wgsl::Instruction::ArcSinh {
868 input: self.compile_variable(op.input),
869 out: self.compile_variable(out),
870 }),
871 cube::Arithmetic::ArcCosh(op) => instructions.push(wgsl::Instruction::ArcCosh {
872 input: self.compile_variable(op.input),
873 out: self.compile_variable(out),
874 }),
875 cube::Arithmetic::ArcTanh(op) => instructions.push(wgsl::Instruction::ArcTanh {
876 input: self.compile_variable(op.input),
877 out: self.compile_variable(out),
878 }),
879 cube::Arithmetic::Degrees(op) => instructions.push(wgsl::Instruction::Degrees {
880 input: self.compile_variable(op.input),
881 out: self.compile_variable(out),
882 }),
883 cube::Arithmetic::Radians(op) => instructions.push(wgsl::Instruction::Radians {
884 input: self.compile_variable(op.input),
885 out: self.compile_variable(out),
886 }),
887 cube::Arithmetic::ArcTan2(op) => instructions.push(wgsl::Instruction::ArcTan2 {
888 lhs: self.compile_variable(op.lhs),
889 rhs: self.compile_variable(op.rhs),
890 out: self.compile_variable(out),
891 }),
892 cube::Arithmetic::Powf(op) | cube::Arithmetic::Powi(op) => {
894 instructions.push(wgsl::Instruction::Powf {
895 lhs: self.compile_variable(op.lhs),
896 rhs: self.compile_variable(op.rhs),
897 out: self.compile_variable(out),
898 })
899 }
900 cube::Arithmetic::Hypot(op) => {
901 let mut scope = scope.child();
902 expand_hypot(&mut scope, op.lhs, op.rhs, out);
903 instructions.extend(self.compile_scope(&mut scope));
904 }
905 cube::Arithmetic::Rhypot(op) => {
906 let mut scope = scope.child();
907 expand_rhypot(&mut scope, op.lhs, op.rhs, out);
908 instructions.extend(self.compile_scope(&mut scope));
909 }
910
911 cube::Arithmetic::Sqrt(op) => instructions.push(wgsl::Instruction::Sqrt {
912 input: self.compile_variable(op.input),
913 out: self.compile_variable(out),
914 }),
915 cube::Arithmetic::InverseSqrt(op) => {
916 instructions.push(wgsl::Instruction::InverseSqrt {
917 input: self.compile_variable(op.input),
918 out: self.compile_variable(out),
919 })
920 }
921 cube::Arithmetic::Round(op) => instructions.push(wgsl::Instruction::Round {
922 input: self.compile_variable(op.input),
923 out: self.compile_variable(out),
924 }),
925 cube::Arithmetic::Floor(op) => instructions.push(wgsl::Instruction::Floor {
926 input: self.compile_variable(op.input),
927 out: self.compile_variable(out),
928 }),
929 cube::Arithmetic::Ceil(op) => instructions.push(wgsl::Instruction::Ceil {
930 input: self.compile_variable(op.input),
931 out: self.compile_variable(out),
932 }),
933 cube::Arithmetic::Trunc(op) => instructions.push(wgsl::Instruction::Trunc {
934 input: self.compile_variable(op.input),
935 out: self.compile_variable(out),
936 }),
937 cube::Arithmetic::Erf(op) => {
938 let mut scope = scope.child();
939 expand_erf(&mut scope, op.input, out);
940 instructions.extend(self.compile_scope(&mut scope));
941 }
942 cube::Arithmetic::MulHi(op) => {
943 let mut scope = scope.child();
944 match self.compilation_options.supports_u64 {
945 true => expand_himul_64(&mut scope, op.lhs, op.rhs, out),
946 false => expand_himul_sim(&mut scope, op.lhs, op.rhs, out),
947 }
948 instructions.extend(self.compile_scope(&mut scope));
949 }
950 cube::Arithmetic::Recip(op) => instructions.push(wgsl::Instruction::Recip {
951 input: self.compile_variable(op.input),
952 out: self.compile_variable(out),
953 }),
954 cube::Arithmetic::Clamp(op) => instructions.push(wgsl::Instruction::Clamp {
955 input: self.compile_variable(op.input),
956 min_value: self.compile_variable(op.min_value),
957 max_value: self.compile_variable(op.max_value),
958 out: self.compile_variable(out),
959 }),
960 cube::Arithmetic::Remainder(op) => instructions.push(wgsl::Instruction::Remainder {
961 lhs: self.compile_variable(op.lhs),
962 rhs: self.compile_variable(op.rhs),
963 out: self.compile_variable(out),
964 }),
965 cube::Arithmetic::Neg(op) => instructions.push(wgsl::Instruction::Negate {
966 input: self.compile_variable(op.input),
967 out: self.compile_variable(out),
968 }),
969 cube::Arithmetic::Magnitude(op) => instructions.push(wgsl::Instruction::Magnitude {
970 input: self.compile_variable(op.input),
971 out: self.compile_variable(out),
972 }),
973 cube::Arithmetic::Normalize(op) => instructions.push(wgsl::Instruction::Normalize {
974 input: self.compile_variable(op.input),
975 out: self.compile_variable(out),
976 }),
977 cube::Arithmetic::Dot(op) => instructions.push(wgsl::Instruction::Dot {
978 lhs: self.compile_variable(op.lhs),
979 rhs: self.compile_variable(op.rhs),
980 out: self.compile_variable(out),
981 }),
982 cube::Arithmetic::VectorSum(op) => instructions.push(wgsl::Instruction::VectorSum {
983 input: self.compile_variable(op.input),
984 out: self.compile_variable(out),
985 }),
986 }
987 }
988
989 fn compile_cmp(
990 &mut self,
991 value: cube::Comparison,
992 out: Option<cube::Variable>,
993 instructions: &mut Vec<wgsl::Instruction>,
994 ) {
995 let out = out.unwrap();
996 match value {
997 cube::Comparison::Equal(op) => instructions.push(wgsl::Instruction::Equal {
998 lhs: self.compile_variable(op.lhs),
999 rhs: self.compile_variable(op.rhs),
1000 out: self.compile_variable(out),
1001 }),
1002 cube::Comparison::Lower(op) => instructions.push(wgsl::Instruction::Lower {
1003 lhs: self.compile_variable(op.lhs),
1004 rhs: self.compile_variable(op.rhs),
1005 out: self.compile_variable(out),
1006 }),
1007 cube::Comparison::Greater(op) => instructions.push(wgsl::Instruction::Greater {
1008 lhs: self.compile_variable(op.lhs),
1009 rhs: self.compile_variable(op.rhs),
1010 out: self.compile_variable(out),
1011 }),
1012 cube::Comparison::LowerEqual(op) => instructions.push(wgsl::Instruction::LowerEqual {
1013 lhs: self.compile_variable(op.lhs),
1014 rhs: self.compile_variable(op.rhs),
1015 out: self.compile_variable(out),
1016 }),
1017 cube::Comparison::GreaterEqual(op) => {
1018 instructions.push(wgsl::Instruction::GreaterEqual {
1019 lhs: self.compile_variable(op.lhs),
1020 rhs: self.compile_variable(op.rhs),
1021 out: self.compile_variable(out),
1022 })
1023 }
1024 cube::Comparison::NotEqual(op) => instructions.push(wgsl::Instruction::NotEqual {
1025 lhs: self.compile_variable(op.lhs),
1026 rhs: self.compile_variable(op.rhs),
1027 out: self.compile_variable(out),
1028 }),
1029 cube::Comparison::IsNan(op) => instructions.push(wgsl::Instruction::IsNan {
1030 input: self.compile_variable(op.input),
1031 out: self.compile_variable(out),
1032 }),
1033 cube::Comparison::IsInf(op) => instructions.push(wgsl::Instruction::IsInf {
1034 input: self.compile_variable(op.input),
1035 out: self.compile_variable(out),
1036 }),
1037 }
1038 }
1039
1040 fn compile_bitwise(
1041 &mut self,
1042 value: cube::Bitwise,
1043 out: Option<cube::Variable>,
1044 instructions: &mut Vec<wgsl::Instruction>,
1045 ) {
1046 let out = out.unwrap();
1047 match value {
1048 cube::Bitwise::BitwiseOr(op) => instructions.push(wgsl::Instruction::BitwiseOr {
1049 lhs: self.compile_variable(op.lhs),
1050 rhs: self.compile_variable(op.rhs),
1051 out: self.compile_variable(out),
1052 }),
1053 cube::Bitwise::BitwiseAnd(op) => instructions.push(wgsl::Instruction::BitwiseAnd {
1054 lhs: self.compile_variable(op.lhs),
1055 rhs: self.compile_variable(op.rhs),
1056 out: self.compile_variable(out),
1057 }),
1058 cube::Bitwise::BitwiseXor(op) => instructions.push(wgsl::Instruction::BitwiseXor {
1059 lhs: self.compile_variable(op.lhs),
1060 rhs: self.compile_variable(op.rhs),
1061 out: self.compile_variable(out),
1062 }),
1063 cube::Bitwise::CountOnes(op) => instructions.push(wgsl::Instruction::CountBits {
1064 input: self.compile_variable(op.input),
1065 out: self.compile_variable(out),
1066 }),
1067 cube::Bitwise::ReverseBits(op) => instructions.push(wgsl::Instruction::ReverseBits {
1068 input: self.compile_variable(op.input),
1069 out: self.compile_variable(out),
1070 }),
1071 cube::Bitwise::ShiftLeft(op) => instructions.push(wgsl::Instruction::ShiftLeft {
1072 lhs: self.compile_variable(op.lhs),
1073 rhs: self.compile_variable(op.rhs),
1074 out: self.compile_variable(out),
1075 }),
1076 cube::Bitwise::ShiftRight(op) => instructions.push(wgsl::Instruction::ShiftRight {
1077 lhs: self.compile_variable(op.lhs),
1078 rhs: self.compile_variable(op.rhs),
1079 out: self.compile_variable(out),
1080 }),
1081 cube::Bitwise::BitwiseNot(op) => instructions.push(wgsl::Instruction::BitwiseNot {
1082 input: self.compile_variable(op.input),
1083 out: self.compile_variable(out),
1084 }),
1085 cube::Bitwise::LeadingZeros(op) => instructions.push(wgsl::Instruction::LeadingZeros {
1086 input: self.compile_variable(op.input),
1087 out: self.compile_variable(out),
1088 }),
1089 cube::Bitwise::TrailingZeros(op) => {
1090 instructions.push(wgsl::Instruction::TrailingZeros {
1091 input: self.compile_variable(op.input),
1092 out: self.compile_variable(out),
1093 })
1094 }
1095 cube::Bitwise::FindFirstSet(op) => instructions.push(wgsl::Instruction::FindFirstSet {
1096 input: self.compile_variable(op.input),
1097 out: self.compile_variable(out),
1098 }),
1099 }
1100 }
1101
1102 fn compile_operator(
1103 &mut self,
1104 value: cube::Operator,
1105 out: Option<cube::Variable>,
1106 instructions: &mut Vec<wgsl::Instruction>,
1107 ) {
1108 let out = out.unwrap();
1109 match value {
1110 cube::Operator::Cast(op) => instructions.push(wgsl::Instruction::Assign {
1111 input: self.compile_variable(op.input),
1112 out: self.compile_variable(out),
1113 }),
1114 cube::Operator::Index(op) | cube::Operator::UncheckedIndex(op) => {
1115 instructions.push(wgsl::Instruction::Index {
1116 lhs: self.compile_variable(op.list),
1117 rhs: self.compile_variable(op.index),
1118 out: self.compile_variable(out),
1119 });
1120 }
1121 cube::Operator::IndexAssign(op) | cube::Operator::UncheckedIndexAssign(op) => {
1122 instructions.push(wgsl::Instruction::IndexAssign {
1123 index: self.compile_variable(op.index),
1124 rhs: self.compile_variable(op.value),
1125 out: self.compile_variable(out),
1126 })
1127 }
1128 cube::Operator::And(op) => instructions.push(wgsl::Instruction::And {
1129 lhs: self.compile_variable(op.lhs),
1130 rhs: self.compile_variable(op.rhs),
1131 out: self.compile_variable(out),
1132 }),
1133 cube::Operator::Or(op) => instructions.push(wgsl::Instruction::Or {
1134 lhs: self.compile_variable(op.lhs),
1135 rhs: self.compile_variable(op.rhs),
1136 out: self.compile_variable(out),
1137 }),
1138 cube::Operator::Not(op) => instructions.push(wgsl::Instruction::Not {
1139 input: self.compile_variable(op.input),
1140 out: self.compile_variable(out),
1141 }),
1142 cube::Operator::Reinterpret(op) => instructions.push(wgsl::Instruction::Bitcast {
1143 input: self.compile_variable(op.input),
1144 out: self.compile_variable(out),
1145 }),
1146 cube::Operator::InitVector(op) => instructions.push(wgsl::Instruction::VecInit {
1147 inputs: op
1148 .inputs
1149 .into_iter()
1150 .map(|var| self.compile_variable(var))
1151 .collect(),
1152 out: self.compile_variable(out),
1153 }),
1154 cube::Operator::CopyMemory(op) => instructions.push(wgsl::Instruction::Copy {
1155 input: self.compile_variable(op.input),
1156 in_index: self.compile_variable(op.in_index),
1157 out: self.compile_variable(out),
1158 out_index: self.compile_variable(op.out_index),
1159 }),
1160 cube::Operator::CopyMemoryBulk(op) => instructions.push(wgsl::Instruction::CopyBulk {
1161 input: self.compile_variable(op.input),
1162 in_index: self.compile_variable(op.in_index),
1163 out: self.compile_variable(out),
1164 out_index: self.compile_variable(op.out_index),
1165 len: op.len as u32,
1166 }),
1167 cube::Operator::Select(op) => instructions.push(wgsl::Instruction::Select {
1168 cond: self.compile_variable(op.cond),
1169 then: self.compile_variable(op.then),
1170 or_else: self.compile_variable(op.or_else),
1171 out: self.compile_variable(out),
1172 }),
1173 }
1174 }
1175
1176 fn compile_atomic(
1177 &mut self,
1178 atomic: cube::AtomicOp,
1179 out: Option<cube::Variable>,
1180 ) -> wgsl::Instruction {
1181 let out = out.unwrap();
1182 match atomic {
1183 cube::AtomicOp::Add(op) => wgsl::Instruction::AtomicAdd {
1184 lhs: self.compile_variable(op.lhs),
1185 rhs: self.compile_variable(op.rhs),
1186 out: self.compile_variable(out),
1187 },
1188 cube::AtomicOp::Sub(op) => wgsl::Instruction::AtomicSub {
1189 lhs: self.compile_variable(op.lhs),
1190 rhs: self.compile_variable(op.rhs),
1191 out: self.compile_variable(out),
1192 },
1193 cube::AtomicOp::Max(op) => wgsl::Instruction::AtomicMax {
1194 lhs: self.compile_variable(op.lhs),
1195 rhs: self.compile_variable(op.rhs),
1196 out: self.compile_variable(out),
1197 },
1198 cube::AtomicOp::Min(op) => wgsl::Instruction::AtomicMin {
1199 lhs: self.compile_variable(op.lhs),
1200 rhs: self.compile_variable(op.rhs),
1201 out: self.compile_variable(out),
1202 },
1203 cube::AtomicOp::And(op) => wgsl::Instruction::AtomicAnd {
1204 lhs: self.compile_variable(op.lhs),
1205 rhs: self.compile_variable(op.rhs),
1206 out: self.compile_variable(out),
1207 },
1208 cube::AtomicOp::Or(op) => wgsl::Instruction::AtomicOr {
1209 lhs: self.compile_variable(op.lhs),
1210 rhs: self.compile_variable(op.rhs),
1211 out: self.compile_variable(out),
1212 },
1213 cube::AtomicOp::Xor(op) => wgsl::Instruction::AtomicXor {
1214 lhs: self.compile_variable(op.lhs),
1215 rhs: self.compile_variable(op.rhs),
1216 out: self.compile_variable(out),
1217 },
1218 cube::AtomicOp::Load(op) => wgsl::Instruction::AtomicLoad {
1219 input: self.compile_variable(op.input),
1220 out: self.compile_variable(out),
1221 },
1222 cube::AtomicOp::Store(op) => wgsl::Instruction::AtomicStore {
1223 input: self.compile_variable(op.input),
1224 out: self.compile_variable(out),
1225 },
1226 cube::AtomicOp::Swap(op) => wgsl::Instruction::AtomicSwap {
1227 lhs: self.compile_variable(op.lhs),
1228 rhs: self.compile_variable(op.rhs),
1229 out: self.compile_variable(out),
1230 },
1231 cube::AtomicOp::CompareAndSwap(op) => wgsl::Instruction::AtomicCompareExchangeWeak {
1232 lhs: self.compile_variable(op.input),
1233 cmp: self.compile_variable(op.cmp),
1234 value: self.compile_variable(op.val),
1235 out: self.compile_variable(out),
1236 },
1237 }
1238 }
1239
1240 fn compile_binding(&mut self, value: kernel::KernelArg) -> wgsl::KernelArg {
1241 wgsl::KernelArg {
1242 id: value.id,
1243 visibility: value.visibility,
1244 location: wgsl::Location::Storage,
1245 item: self.compile_type(value.ty),
1246 size: value.size,
1247 }
1248 }
1249}
1250
1251fn register_extensions(instructions: &[wgsl::Instruction]) -> Vec<wgsl::Extension> {
1252 let mut extensions = Vec::new();
1253
1254 let mut register_extension = |extension: wgsl::Extension| {
1255 if !extensions.contains(&extension) {
1256 extensions.push(extension);
1257 }
1258 };
1259
1260 for instruction in instructions {
1262 match instruction {
1263 wgsl::Instruction::Powf { lhs: _, rhs, out } => {
1264 register_extension(wgsl::Extension::PowfPrimitive(out.elem()));
1265 register_extension(wgsl::powf_extension(rhs, out));
1266 }
1267 #[cfg(target_os = "macos")]
1268 wgsl::Instruction::Tanh { input, out: _ } => {
1269 register_extension(wgsl::Extension::SafeTanhPrimitive(input.elem()));
1270 register_extension(wgsl::Extension::SafeTanh(input.item()));
1271 }
1272 wgsl::Instruction::IsNan { input, out } => {
1273 register_extension(wgsl::Extension::IsNanPrimitive(input.elem()));
1274 register_extension(wgsl::Extension::IsNan(input.item(), out.item()));
1275 }
1276 wgsl::Instruction::IsInf { input, out } => {
1277 register_extension(wgsl::Extension::IsInfPrimitive(input.elem()));
1278 register_extension(wgsl::Extension::IsInf(input.item(), out.item()));
1279 }
1280 wgsl::Instruction::If { instructions, .. } => {
1281 for extension in register_extensions(instructions) {
1282 register_extension(extension);
1283 }
1284 }
1285 wgsl::Instruction::IfElse {
1286 instructions_if,
1287 instructions_else,
1288 ..
1289 } => {
1290 for extension in register_extensions(instructions_if) {
1291 register_extension(extension);
1292 }
1293 for extension in register_extensions(instructions_else) {
1294 register_extension(extension);
1295 }
1296 }
1297 wgsl::Instruction::Loop { instructions } => {
1298 for extension in register_extensions(instructions) {
1299 register_extension(extension);
1300 }
1301 }
1302 wgsl::Instruction::RangeLoop { instructions, .. } => {
1303 for extension in register_extensions(instructions) {
1304 register_extension(extension);
1305 }
1306 }
1307 _ => {}
1308 }
1309 }
1310
1311 extensions
1312}