1use super::Subgroup;
2use super::{ConstantArray, shader::ComputeShader};
3use super::{Item, LocalArray, SharedArray};
4use crate::compiler::wgsl::{self, SharedValue};
5
6use cubecl_common::backtrace::BackTrace;
7use cubecl_core::post_processing::{
8 checked_io::CheckedIoProcessor, saturating::SaturatingArithmeticProcessor,
9};
10use cubecl_core::prelude::*;
11use cubecl_core::{
12 Metadata, WgpuCompilationOptions,
13 ir::{self as cube, Scope},
14 prelude::expand_erf,
15};
16use cubecl_core::{
17 ir::{Processor, UIntKind},
18 post_processing::unroll::UnrollProcessor,
19};
20use cubecl_runtime::compiler::CompilationError;
21use cubecl_runtime::kernel;
22
/// Widest line (vectorization) size this backend emits.
///
/// `compile_type` only maps line sizes 2, 3 and 4 onto `vec2`/`vec3`/`vec4`, so buffer
/// bindings and declared variables with a wider line size are clamped to this value, and the
/// unroll processor is configured with it (presumably so wider line operations get split into
/// chunks of this width — confirm against `UnrollProcessor`).
pub const MAX_LINE_SIZE: usize = 4;
24
25#[derive(Clone, Default)]
27pub struct WgslCompiler {
28 metadata: Metadata,
29 ext_meta_pos: Vec<u32>,
30 local_invocation_index: bool,
31 local_invocation_id: bool,
32 global_invocation_id: bool,
34 workgroup_id: bool,
35 subgroup_size: bool,
36 subgroup_invocation_id: bool,
37 id: bool,
38 num_workgroups: bool,
39 workgroup_id_no_axis: bool,
40 workgroup_size_no_axis: bool,
41 num_workgroup_no_axis: bool,
42 shared_arrays: Vec<SharedArray>,
43 shared_values: Vec<SharedValue>,
44 const_arrays: Vec<ConstantArray>,
45 local_arrays: Vec<LocalArray>,
46 #[allow(dead_code)]
47 compilation_options: WgpuCompilationOptions,
48 strategy: ExecutionMode,
49 subgroup_instructions_used: bool,
50 f16_used: bool,
51}
52
53impl core::fmt::Debug for WgslCompiler {
54 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55 f.write_str("WgslCompiler")
56 }
57}
58
59impl cubecl_core::Compiler for WgslCompiler {
60 type Representation = ComputeShader;
61 type CompilationOptions = WgpuCompilationOptions;
62
63 fn compile(
64 &mut self,
65 shader: kernel::KernelDefinition,
66 compilation_options: &Self::CompilationOptions,
67 mode: ExecutionMode,
68 address_type: StorageType,
69 ) -> Result<Self::Representation, CompilationError> {
70 self.compilation_options = compilation_options.clone();
71 self.compile_shader(shader, mode, address_type)
72 }
73
74 fn elem_size(&self, elem: cube::ElemType) -> usize {
75 elem.size()
76 }
77
78 fn extension(&self) -> &'static str {
79 "wgsl"
80 }
81}
82
83impl WgslCompiler {
84 fn compile_shader(
85 &mut self,
86 mut value: kernel::KernelDefinition,
87 mode: ExecutionMode,
88 address_type: StorageType,
89 ) -> Result<wgsl::ComputeShader, CompilationError> {
90 let errors = value.body.pop_errors();
91 if !errors.is_empty() {
92 let mut reason = "Can't compile wgsl kernel".to_string();
93 for error in errors {
94 reason += error.as_str();
95 reason += "\n";
96 }
97
98 return Err(CompilationError::Validation {
99 reason,
100 backtrace: BackTrace::capture(),
101 });
102 }
103
104 self.strategy = mode;
105
106 let num_meta = value.buffers.len();
107
108 self.ext_meta_pos = Vec::new();
109 let mut num_ext = 0;
110
111 for binding in value.buffers.iter() {
112 self.ext_meta_pos.push(num_ext);
113 if binding.has_extended_meta {
114 num_ext += 1;
115 }
116 }
117
118 self.metadata = Metadata::new(num_meta as u32, num_ext);
119
120 let address_type = self.compile_storage_type(address_type);
121 let instructions = self.compile_scope(&mut value.body);
122 let extensions = register_extensions(&instructions);
123 let body = wgsl::Body {
124 instructions,
125 id: self.id,
126 address_type,
127 };
128
129 Ok(wgsl::ComputeShader {
130 address_type,
131 buffers: value
132 .buffers
133 .into_iter()
134 .map(|mut it| {
135 if it.ty.line_size() > MAX_LINE_SIZE {
138 it.ty = it.ty.line(MAX_LINE_SIZE);
139 }
140 self.compile_binding(it)
141 })
142 .collect(),
143 scalars: value
144 .scalars
145 .into_iter()
146 .map(|binding| (self.compile_storage_type(binding.ty), binding.count))
147 .collect(),
148 shared_arrays: self.shared_arrays.clone(),
149 shared_values: self.shared_values.clone(),
150 constant_arrays: self.const_arrays.clone(),
151 local_arrays: self.local_arrays.clone(),
152 has_metadata: self.metadata.static_len() > 0,
153 workgroup_size: value.cube_dim,
154 global_invocation_id: self.global_invocation_id || self.id,
155 local_invocation_index: self.local_invocation_index,
156 local_invocation_id: self.local_invocation_id,
157 num_workgroups: self.id
158 || self.num_workgroups
159 || self.num_workgroup_no_axis
160 || self.workgroup_id_no_axis,
161 workgroup_id: self.workgroup_id || self.workgroup_id_no_axis,
162 subgroup_size: self.subgroup_size,
163 subgroup_invocation_id: self.subgroup_invocation_id,
164 body,
165 extensions,
166 num_workgroups_no_axis: self.num_workgroup_no_axis,
167 workgroup_id_no_axis: self.workgroup_id_no_axis,
168 workgroup_size_no_axis: self.workgroup_size_no_axis,
169 subgroup_instructions_used: self.subgroup_instructions_used,
170 f16_used: self.f16_used,
171 kernel_name: value.options.kernel_name,
172 })
173 }
174
175 fn compile_type(&mut self, item: cube::Type) -> Item {
176 match item {
177 cube::Type::Scalar(ty) => wgsl::Item::Scalar(self.compile_storage_type(ty)),
178 cube::Type::Line(ty, size) => {
179 let elem = self.compile_storage_type(ty);
180 match size {
181 2 => wgsl::Item::Vec2(elem),
182 3 => wgsl::Item::Vec3(elem),
183 4 => wgsl::Item::Vec4(elem),
184 _ => panic!("Unsupported vectorizations scheme {:?}", item.line_size()),
185 }
186 }
187 cube::Type::Semantic(_) => unimplemented!("Can't compile semantic type"),
188 }
189 }
190
191 fn compile_storage_type(&mut self, ty: cube::StorageType) -> wgsl::Elem {
192 match ty {
193 cube::StorageType::Scalar(ty) => self.compile_elem(ty),
194 cube::StorageType::Atomic(ty) => match ty {
195 cube::ElemType::Float(i) => match i {
196 cube::FloatKind::F32 => wgsl::Elem::AtomicF32,
197 kind => panic!("atomic<{kind:?}> is not a valid WgpuElement"),
198 },
199 cube::ElemType::Int(i) => match i {
200 cube::IntKind::I32 => wgsl::Elem::AtomicI32,
201 kind => panic!("atomic<{kind:?}> is not a valid WgpuElement"),
202 },
203 cube::ElemType::UInt(kind) => match kind {
204 cube::UIntKind::U32 => wgsl::Elem::AtomicU32,
205 kind => panic!("{kind:?} is not a valid WgpuElement"),
206 },
207 other => panic!("{other:?} is not a valid WgpuElement"),
208 },
209 cube::StorageType::Packed(_, _) => {
210 unimplemented!("Packed types not yet supported in WGSL")
211 }
212 cube::StorageType::Opaque(ty) => match ty {
213 cube::OpaqueType::Barrier(_) => {
214 unimplemented!("Barrier objects not supported in WGSL")
215 }
216 },
217 }
218 }
219
220 fn compile_elem(&mut self, value: cube::ElemType) -> wgsl::Elem {
221 match value {
222 cube::ElemType::Float(f) => match f {
223 cube::FloatKind::E2M1
224 | cube::FloatKind::E2M3
225 | cube::FloatKind::E3M2
226 | cube::FloatKind::E4M3
227 | cube::FloatKind::E5M2
228 | cube::FloatKind::UE8M0 => panic!("Minifloat is not a valid WgpuElement"),
229 cube::FloatKind::F16 => {
230 self.f16_used = true;
231 wgsl::Elem::F16
232 }
233 cube::FloatKind::BF16 => panic!("bf16 is not a valid WgpuElement"),
234 cube::FloatKind::TF32 => panic!("tf32 is not a valid WgpuElement"),
235 cube::FloatKind::Flex32 => wgsl::Elem::F32,
236 cube::FloatKind::F32 => wgsl::Elem::F32,
237 cube::FloatKind::F64 => wgsl::Elem::F64,
238 },
239 cube::ElemType::Int(i) => match i {
240 cube::IntKind::I32 => wgsl::Elem::I32,
241 cube::IntKind::I64 => wgsl::Elem::I64,
242 kind => panic!("{kind:?} is not a valid WgpuElement"),
243 },
244 cube::ElemType::UInt(kind) => match kind {
245 cube::UIntKind::U32 => wgsl::Elem::U32,
246 cube::UIntKind::U64 => wgsl::Elem::U64,
247 kind => panic!("{kind:?} is not a valid WgpuElement"),
248 },
249 cube::ElemType::Bool => wgsl::Elem::Bool,
250 }
251 }
252
253 fn ext_meta_pos(&self, var: &cube::Variable) -> u32 {
254 let pos = var.index().expect("Variable should have index");
255 self.ext_meta_pos[pos as usize]
256 }
257
258 pub(crate) fn compile_variable(&mut self, value: cube::Variable) -> wgsl::Variable {
259 let item = value.ty;
260 match value.kind {
261 cube::VariableKind::GlobalInputArray(id) => {
262 wgsl::Variable::GlobalInputArray(id, self.compile_type(item))
263 }
264 cube::VariableKind::GlobalScalar(id) => {
265 wgsl::Variable::GlobalScalar(id, self.compile_storage_type(item.storage_type()))
266 }
267 cube::VariableKind::LocalMut { id } | cube::VariableKind::Versioned { id, .. } => {
268 wgsl::Variable::LocalMut {
269 id,
270 item: self.compile_type(item),
271 }
272 }
273 cube::VariableKind::LocalConst { id } => wgsl::Variable::LocalConst {
274 id,
275 item: self.compile_type(item),
276 },
277 cube::VariableKind::GlobalOutputArray(id) => {
278 wgsl::Variable::GlobalOutputArray(id, self.compile_type(item))
279 }
280 cube::VariableKind::Constant(value) => {
281 wgsl::Variable::Constant(value, self.compile_type(item))
282 }
283 cube::VariableKind::SharedArray {
284 id,
285 length,
286 unroll_factor,
287 alignment,
288 } => {
289 let item = self.compile_type(item);
290 if !self.shared_arrays.iter().any(|s| s.index == id) {
291 self.shared_arrays.push(SharedArray::new(
292 id,
293 item,
294 (length * unroll_factor) as u32,
295 alignment.map(|it| it as u32),
296 ));
297 }
298 wgsl::Variable::SharedArray(id, item, length as u32)
299 }
300 cube::VariableKind::Shared { id } => {
301 let item = self.compile_type(item);
302 if !self.shared_values.iter().any(|s| s.index == id) {
303 self.shared_values.push(SharedValue::new(id, item));
304 }
305 wgsl::Variable::SharedValue(id, item)
306 }
307 cube::VariableKind::ConstantArray { id, length, .. } => {
308 let item = self.compile_type(item);
309 wgsl::Variable::ConstantArray(id, item, length as u32)
310 }
311 cube::VariableKind::LocalArray {
312 id,
313 length,
314 unroll_factor,
315 } => {
316 let item = self.compile_type(item);
317 if !self.local_arrays.iter().any(|s| s.index == id) {
318 self.local_arrays.push(LocalArray::new(
319 id,
320 item,
321 (length * unroll_factor) as u32,
322 ));
323 }
324 wgsl::Variable::LocalArray(id, item, length as u32)
325 }
326 cube::VariableKind::Builtin(builtin) => match builtin {
327 cube::Builtin::AbsolutePos => {
328 self.id = true;
329 wgsl::Variable::Id
330 }
331 cube::Builtin::UnitPos => {
332 self.local_invocation_index = true;
333 wgsl::Variable::LocalInvocationIndex
334 }
335 cube::Builtin::UnitPosX => {
336 self.local_invocation_id = true;
337 wgsl::Variable::LocalInvocationIdX
338 }
339 cube::Builtin::UnitPosY => {
340 self.local_invocation_id = true;
341 wgsl::Variable::LocalInvocationIdY
342 }
343 cube::Builtin::UnitPosZ => {
344 self.local_invocation_id = true;
345 wgsl::Variable::LocalInvocationIdZ
346 }
347 cube::Builtin::CubePosX => {
348 self.workgroup_id = true;
349 wgsl::Variable::WorkgroupIdX
350 }
351 cube::Builtin::CubePosY => {
352 self.workgroup_id = true;
353 wgsl::Variable::WorkgroupIdY
354 }
355 cube::Builtin::CubePosZ => {
356 self.workgroup_id = true;
357 wgsl::Variable::WorkgroupIdZ
358 }
359 cube::Builtin::CubePosCluster
360 | cube::Builtin::CubePosClusterX
361 | cube::Builtin::CubePosClusterY
362 | cube::Builtin::CubePosClusterZ => self.constant_var(1),
363 cube::Builtin::AbsolutePosX => {
364 self.global_invocation_id = true;
365 wgsl::Variable::GlobalInvocationIdX
366 }
367 cube::Builtin::AbsolutePosY => {
368 self.global_invocation_id = true;
369 wgsl::Variable::GlobalInvocationIdY
370 }
371 cube::Builtin::AbsolutePosZ => {
372 self.global_invocation_id = true;
373 wgsl::Variable::GlobalInvocationIdZ
374 }
375 cube::Builtin::CubeDimX => wgsl::Variable::WorkgroupSizeX,
376 cube::Builtin::CubeDimY => wgsl::Variable::WorkgroupSizeY,
377 cube::Builtin::CubeDimZ => wgsl::Variable::WorkgroupSizeZ,
378 cube::Builtin::CubeClusterDim
379 | cube::Builtin::CubeClusterDimX
380 | cube::Builtin::CubeClusterDimY
381 | cube::Builtin::CubeClusterDimZ => self.constant_var(1),
382 cube::Builtin::CubeCountX => {
383 self.num_workgroups = true;
384 wgsl::Variable::NumWorkgroupsX
385 }
386 cube::Builtin::CubeCountY => {
387 self.num_workgroups = true;
388 wgsl::Variable::NumWorkgroupsY
389 }
390 cube::Builtin::CubeCountZ => {
391 self.num_workgroups = true;
392 wgsl::Variable::NumWorkgroupsZ
393 }
394 cube::Builtin::CubePos => {
395 self.workgroup_id_no_axis = true;
396 wgsl::Variable::WorkgroupId
397 }
398 cube::Builtin::CubeDim => {
399 self.workgroup_size_no_axis = true;
400 wgsl::Variable::WorkgroupSize
401 }
402 cube::Builtin::CubeCount => {
403 self.num_workgroup_no_axis = true;
404 wgsl::Variable::NumWorkgroups
405 }
406 cube::Builtin::PlaneDim => {
407 self.subgroup_size = true;
408 wgsl::Variable::SubgroupSize
409 }
410 cube::Builtin::UnitPosPlane => {
411 self.subgroup_invocation_id = true;
412 wgsl::Variable::SubgroupInvocationId
413 }
414 },
415 cube::VariableKind::Matrix { .. } => {
416 panic!("Cooperative matrix-multiply and accumulate not supported.")
417 }
418 cube::VariableKind::Pipeline { .. } => {
419 panic!("Pipeline not supported.")
420 }
421 cube::VariableKind::BarrierToken { .. } => {
422 panic!("Barrier not supported.")
423 }
424 cube::VariableKind::TensorMapInput(_) => panic!("Tensor map not supported."),
425 cube::VariableKind::TensorMapOutput(_) => panic!("Tensor map not supported."),
426 }
427 }
428
429 fn constant_var(&mut self, value: u32) -> wgsl::Variable {
430 let var = cube::Variable::constant(value.into(), UIntKind::U32);
431 self.compile_variable(var)
432 }
433
434 fn compile_scope(&mut self, scope: &mut cube::Scope) -> Vec<wgsl::Instruction> {
435 let mut instructions = Vec::new();
436
437 let const_arrays = scope
438 .const_arrays
439 .drain(..)
440 .map(|(var, values)| ConstantArray {
441 index: var.index().unwrap(),
442 item: self.compile_type(var.ty),
443 size: values.len() as u32,
444 values: values
445 .into_iter()
446 .map(|val| self.compile_variable(val))
447 .collect(),
448 })
449 .collect::<Vec<_>>();
450 self.const_arrays.extend(const_arrays);
451
452 let checked_io: Box<dyn Processor> = Box::new(CheckedIoProcessor::new(self.strategy));
453 let unroll = Box::new(UnrollProcessor::new(MAX_LINE_SIZE));
454 let saturating = Box::new(SaturatingArithmeticProcessor::new(true));
455 let processing = scope.process([&*unroll, &*checked_io, &*saturating]);
456
457 for mut var in processing.variables {
458 if var.ty.line_size() > MAX_LINE_SIZE {
459 var.ty = var.ty.line(MAX_LINE_SIZE);
460 }
461 instructions.push(wgsl::Instruction::DeclareVariable {
462 var: self.compile_variable(var),
463 });
464 }
465
466 processing
467 .instructions
468 .into_iter()
469 .for_each(|op| self.compile_operation(&mut instructions, op.operation, op.out, scope));
470
471 instructions
472 }
473
474 fn compile_operation(
475 &mut self,
476 instructions: &mut Vec<wgsl::Instruction>,
477 operation: cube::Operation,
478 out: Option<cube::Variable>,
479 scope: &mut cube::Scope,
480 ) {
481 match operation {
482 cube::Operation::Copy(variable) => instructions.push(wgsl::Instruction::Assign {
483 input: self.compile_variable(variable),
484 out: self.compile_variable(out.unwrap()),
485 }),
486 cube::Operation::Arithmetic(op) => {
487 self.compile_arithmetic(op, out, instructions, scope)
488 }
489 cube::Operation::Comparison(op) => self.compile_cmp(op, out, instructions),
490 cube::Operation::Bitwise(op) => self.compile_bitwise(op, out, instructions),
491 cube::Operation::Operator(op) => self.compile_operator(op, out, instructions),
492 cube::Operation::Atomic(op) => instructions.push(self.compile_atomic(op, out)),
493 cube::Operation::Metadata(op) => instructions.push(self.compile_metadata(op, out)),
494 cube::Operation::Branch(val) => self.compile_branch(instructions, val),
495 cube::Operation::Synchronization(val) => {
496 self.compile_synchronization(instructions, val)
497 }
498 cube::Operation::Plane(op) => self.compile_subgroup(instructions, op, out),
499 cube::Operation::CoopMma(_) => {
500 panic!("Cooperative matrix-multiply and accumulate isn't supported on wgpu.")
501 }
502 cube::Operation::NonSemantic(cube::NonSemantic::Comment { content }) => {
503 self.compile_comment(instructions, content)
504 }
505 cube::Operation::NonSemantic(_) => {}
506 cube::Operation::Barrier(_) => {
507 panic!("Barrier isn't supported on wgpu.")
508 }
509 cube::Operation::Tma(_) => panic!("TMA isn't supported on wgpu."),
510 cube::Operation::Marker(_) => {}
511 }
512 }
513
514 fn compile_subgroup(
515 &mut self,
516 instructions: &mut Vec<wgsl::Instruction>,
517 subgroup: cube::Plane,
518 out: Option<cube::Variable>,
519 ) {
520 self.subgroup_instructions_used = true;
521
522 let out = out.unwrap();
523 let op = match subgroup {
524 cube::Plane::Elect => Subgroup::Elect {
525 out: self.compile_variable(out),
526 },
527 cube::Plane::All(op) => Subgroup::All {
528 input: self.compile_variable(op.input),
529 out: self.compile_variable(out),
530 },
531 cube::Plane::Any(op) => Subgroup::Any {
532 input: self.compile_variable(op.input),
533 out: self.compile_variable(out),
534 },
535 cube::Plane::Ballot(op) => Subgroup::Ballot {
536 input: self.compile_variable(op.input),
537 out: self.compile_variable(out),
538 },
539
540 cube::Plane::Broadcast(op) => Subgroup::Broadcast {
541 lhs: self.compile_variable(op.lhs),
542 rhs: self.compile_variable(op.rhs),
543 out: self.compile_variable(out),
544 },
545
546 cube::Plane::Sum(op) => Subgroup::Sum {
547 input: self.compile_variable(op.input),
548 out: self.compile_variable(out),
549 },
550
551 cube::Plane::ExclusiveSum(op) => Subgroup::ExclusiveSum {
552 input: self.compile_variable(op.input),
553 out: self.compile_variable(out),
554 },
555 cube::Plane::InclusiveSum(op) => Subgroup::InclusiveSum {
556 input: self.compile_variable(op.input),
557 out: self.compile_variable(out),
558 },
559 cube::Plane::Prod(op) => Subgroup::Prod {
560 input: self.compile_variable(op.input),
561 out: self.compile_variable(out),
562 },
563 cube::Plane::ExclusiveProd(op) => Subgroup::ExclusiveProd {
564 input: self.compile_variable(op.input),
565 out: self.compile_variable(out),
566 },
567 cube::Plane::InclusiveProd(op) => Subgroup::InclusiveProd {
568 input: self.compile_variable(op.input),
569 out: self.compile_variable(out),
570 },
571 cube::Plane::Min(op) => Subgroup::Min {
572 input: self.compile_variable(op.input),
573 out: self.compile_variable(out),
574 },
575 cube::Plane::Max(op) => Subgroup::Max {
576 input: self.compile_variable(op.input),
577 out: self.compile_variable(out),
578 },
579 cube::Plane::Shuffle(op) => Subgroup::Shuffle {
580 lhs: self.compile_variable(op.lhs),
581 rhs: self.compile_variable(op.rhs),
582 out: self.compile_variable(out),
583 },
584 cube::Plane::ShuffleXor(op) => Subgroup::ShuffleXor {
585 lhs: self.compile_variable(op.lhs),
586 rhs: self.compile_variable(op.rhs),
587 out: self.compile_variable(out),
588 },
589 cube::Plane::ShuffleUp(op) => Subgroup::ShuffleUp {
590 lhs: self.compile_variable(op.lhs),
591 rhs: self.compile_variable(op.rhs),
592 out: self.compile_variable(out),
593 },
594 cube::Plane::ShuffleDown(op) => Subgroup::ShuffleDown {
595 lhs: self.compile_variable(op.lhs),
596 rhs: self.compile_variable(op.rhs),
597 out: self.compile_variable(out),
598 },
599 };
600
601 instructions.push(wgsl::Instruction::Subgroup(op));
602 }
603
604 fn compile_branch(&mut self, instructions: &mut Vec<wgsl::Instruction>, branch: cube::Branch) {
605 match branch {
606 cube::Branch::If(mut op) => instructions.push(wgsl::Instruction::If {
607 cond: self.compile_variable(op.cond),
608 instructions: self.compile_scope(&mut op.scope),
609 }),
610 cube::Branch::IfElse(mut op) => instructions.push(wgsl::Instruction::IfElse {
611 cond: self.compile_variable(op.cond),
612 instructions_if: self.compile_scope(&mut op.scope_if),
613 instructions_else: self.compile_scope(&mut op.scope_else),
614 }),
615 cube::Branch::Switch(mut op) => instructions.push(wgsl::Instruction::Switch {
616 value: self.compile_variable(op.value),
617 instructions_default: self.compile_scope(&mut op.scope_default),
618 cases: op
619 .cases
620 .into_iter()
621 .map(|(val, mut scope)| {
622 (self.compile_variable(val), self.compile_scope(&mut scope))
623 })
624 .collect(),
625 }),
626 cube::Branch::Return => instructions.push(wgsl::Instruction::Return),
627 cube::Branch::Break => instructions.push(wgsl::Instruction::Break),
628 cube::Branch::RangeLoop(mut range_loop) => {
629 instructions.push(wgsl::Instruction::RangeLoop {
630 i: self.compile_variable(range_loop.i),
631 start: self.compile_variable(range_loop.start),
632 end: self.compile_variable(range_loop.end),
633 step: range_loop.step.map(|it| self.compile_variable(it)),
634 inclusive: range_loop.inclusive,
635 instructions: self.compile_scope(&mut range_loop.scope),
636 })
637 }
638 cube::Branch::Loop(mut op) => instructions.push(wgsl::Instruction::Loop {
639 instructions: self.compile_scope(&mut op.scope),
640 }),
641 };
642 }
643
644 fn compile_synchronization(
645 &mut self,
646 instructions: &mut Vec<wgsl::Instruction>,
647 synchronization: cube::Synchronization,
648 ) {
649 match synchronization {
650 cube::Synchronization::SyncCube => {
651 instructions.push(wgsl::Instruction::WorkgroupBarrier)
652 }
653 cube::Synchronization::SyncPlane => {
654 panic!("Synchronization within a plane is not supported in WGSL")
655 }
656 cube::Synchronization::SyncStorage => {
657 instructions.push(wgsl::Instruction::StorageBarrier)
658 }
659 cube::Synchronization::SyncAsyncProxyShared => panic!("TMA is not supported in WGSL"),
660 };
661 }
662
663 fn compile_comment(&mut self, instructions: &mut Vec<wgsl::Instruction>, content: String) {
664 instructions.push(wgsl::Instruction::Comment { content })
665 }
666
667 fn compile_metadata(
668 &mut self,
669 metadata: cube::Metadata,
670 out: Option<cube::Variable>,
671 ) -> wgsl::Instruction {
672 let out = out.unwrap();
673 match metadata {
674 cube::Metadata::Rank { var } => {
675 let position = self.ext_meta_pos(&var);
676 let offset = self.metadata.rank_index(position);
677 wgsl::Instruction::Metadata {
678 out: self.compile_variable(out),
679 info_offset: self.compile_variable(offset.into()),
680 }
681 }
682 cube::Metadata::Stride { dim, var } => {
683 let position = self.ext_meta_pos(&var);
684 let offset = self.metadata.stride_offset_index(position);
685 wgsl::Instruction::ExtendedMeta {
686 info_offset: self.compile_variable(offset.into()),
687 dim: self.compile_variable(dim),
688 out: self.compile_variable(out),
689 }
690 }
691 cube::Metadata::Shape { dim, var } => {
692 let position = self.ext_meta_pos(&var);
693 let offset = self.metadata.shape_offset_index(position);
694 wgsl::Instruction::ExtendedMeta {
695 info_offset: self.compile_variable(offset.into()),
696 dim: self.compile_variable(dim),
697 out: self.compile_variable(out),
698 }
699 }
700 cube::Metadata::Length { var } => match var.kind {
701 cube::VariableKind::GlobalInputArray(id) => {
702 let offset = self.metadata.len_index(id);
703 wgsl::Instruction::Metadata {
704 out: self.compile_variable(out),
705 info_offset: self.compile_variable(offset.into()),
706 }
707 }
708 cube::VariableKind::GlobalOutputArray(id) => {
709 let offset = self.metadata.len_index(id);
710 wgsl::Instruction::Metadata {
711 out: self.compile_variable(out),
712 info_offset: self.compile_variable(offset.into()),
713 }
714 }
715 _ => wgsl::Instruction::Length {
716 var: self.compile_variable(var),
717 out: self.compile_variable(out),
718 },
719 },
720 cube::Metadata::BufferLength { var } => match var.kind {
721 cube::VariableKind::GlobalInputArray(id) => {
722 let offset = self.metadata.buffer_len_index(id);
723 wgsl::Instruction::Metadata {
724 out: self.compile_variable(out),
725 info_offset: self.compile_variable(offset.into()),
726 }
727 }
728 cube::VariableKind::GlobalOutputArray(id) => {
729 let offset = self.metadata.buffer_len_index(id);
730 wgsl::Instruction::Metadata {
731 out: self.compile_variable(out),
732 info_offset: self.compile_variable(offset.into()),
733 }
734 }
735 _ => wgsl::Instruction::Length {
736 var: self.compile_variable(var),
737 out: self.compile_variable(out),
738 },
739 },
740 }
741 }
742
743 fn compile_arithmetic(
744 &mut self,
745 value: cube::Arithmetic,
746 out: Option<cube::Variable>,
747 instructions: &mut Vec<wgsl::Instruction>,
748 scope: &mut Scope,
749 ) {
750 let out = out.unwrap();
751 match value {
752 cube::Arithmetic::Max(op) => instructions.push(wgsl::Instruction::Max {
753 lhs: self.compile_variable(op.lhs),
754 rhs: self.compile_variable(op.rhs),
755 out: self.compile_variable(out),
756 }),
757 cube::Arithmetic::Min(op) => instructions.push(wgsl::Instruction::Min {
758 lhs: self.compile_variable(op.lhs),
759 rhs: self.compile_variable(op.rhs),
760 out: self.compile_variable(out),
761 }),
762 cube::Arithmetic::Add(op) => instructions.push(wgsl::Instruction::Add {
763 lhs: self.compile_variable(op.lhs),
764 rhs: self.compile_variable(op.rhs),
765 out: self.compile_variable(out),
766 }),
767 cube::Arithmetic::SaturatingAdd(_) => {
768 unreachable!("Saturating add should be removed by processor");
769 }
770 cube::Arithmetic::Fma(op) => instructions.push(wgsl::Instruction::Fma {
771 a: self.compile_variable(op.a),
772 b: self.compile_variable(op.b),
773 c: self.compile_variable(op.c),
774 out: self.compile_variable(out),
775 }),
776 cube::Arithmetic::Modulo(op) => instructions.push(wgsl::Instruction::Modulo {
777 lhs: self.compile_variable(op.lhs),
778 rhs: self.compile_variable(op.rhs),
779 out: self.compile_variable(out),
780 }),
781 cube::Arithmetic::Sub(op) => instructions.push(wgsl::Instruction::Sub {
782 lhs: self.compile_variable(op.lhs),
783 rhs: self.compile_variable(op.rhs),
784 out: self.compile_variable(out),
785 }),
786 cube::Arithmetic::SaturatingSub(_) => {
787 unreachable!("Saturating sub should be removed by processor");
788 }
789 cube::Arithmetic::Mul(op) => instructions.push(wgsl::Instruction::Mul {
790 lhs: self.compile_variable(op.lhs),
791 rhs: self.compile_variable(op.rhs),
792 out: self.compile_variable(out),
793 }),
794 cube::Arithmetic::Div(op) => instructions.push(wgsl::Instruction::Div {
795 lhs: self.compile_variable(op.lhs),
796 rhs: self.compile_variable(op.rhs),
797 out: self.compile_variable(out),
798 }),
799 cube::Arithmetic::Abs(op) => instructions.push(wgsl::Instruction::Abs {
800 input: self.compile_variable(op.input),
801 out: self.compile_variable(out),
802 }),
803 cube::Arithmetic::Exp(op) => instructions.push(wgsl::Instruction::Exp {
804 input: self.compile_variable(op.input),
805 out: self.compile_variable(out),
806 }),
807 cube::Arithmetic::Log(op) => instructions.push(wgsl::Instruction::Log {
808 input: self.compile_variable(op.input),
809 out: self.compile_variable(out),
810 }),
811 cube::Arithmetic::Log1p(op) => instructions.push(wgsl::Instruction::Log1p {
812 input: self.compile_variable(op.input),
813 out: self.compile_variable(out),
814 }),
815 cube::Arithmetic::Cos(op) => instructions.push(wgsl::Instruction::Cos {
816 input: self.compile_variable(op.input),
817 out: self.compile_variable(out),
818 }),
819 cube::Arithmetic::Sin(op) => instructions.push(wgsl::Instruction::Sin {
820 input: self.compile_variable(op.input),
821 out: self.compile_variable(out),
822 }),
823 cube::Arithmetic::Tan(op) => instructions.push(wgsl::Instruction::Tan {
824 input: self.compile_variable(op.input),
825 out: self.compile_variable(out),
826 }),
827 cube::Arithmetic::Tanh(op) => instructions.push(wgsl::Instruction::Tanh {
828 input: self.compile_variable(op.input),
829 out: self.compile_variable(out),
830 }),
831 cube::Arithmetic::Sinh(op) => instructions.push(wgsl::Instruction::Sinh {
832 input: self.compile_variable(op.input),
833 out: self.compile_variable(out),
834 }),
835 cube::Arithmetic::Cosh(op) => instructions.push(wgsl::Instruction::Cosh {
836 input: self.compile_variable(op.input),
837 out: self.compile_variable(out),
838 }),
839 cube::Arithmetic::ArcCos(op) => instructions.push(wgsl::Instruction::ArcCos {
840 input: self.compile_variable(op.input),
841 out: self.compile_variable(out),
842 }),
843 cube::Arithmetic::ArcSin(op) => instructions.push(wgsl::Instruction::ArcSin {
844 input: self.compile_variable(op.input),
845 out: self.compile_variable(out),
846 }),
847 cube::Arithmetic::ArcTan(op) => instructions.push(wgsl::Instruction::ArcTan {
848 input: self.compile_variable(op.input),
849 out: self.compile_variable(out),
850 }),
851 cube::Arithmetic::ArcSinh(op) => instructions.push(wgsl::Instruction::ArcSinh {
852 input: self.compile_variable(op.input),
853 out: self.compile_variable(out),
854 }),
855 cube::Arithmetic::ArcCosh(op) => instructions.push(wgsl::Instruction::ArcCosh {
856 input: self.compile_variable(op.input),
857 out: self.compile_variable(out),
858 }),
859 cube::Arithmetic::ArcTanh(op) => instructions.push(wgsl::Instruction::ArcTanh {
860 input: self.compile_variable(op.input),
861 out: self.compile_variable(out),
862 }),
863 cube::Arithmetic::Degrees(op) => instructions.push(wgsl::Instruction::Degrees {
864 input: self.compile_variable(op.input),
865 out: self.compile_variable(out),
866 }),
867 cube::Arithmetic::Radians(op) => instructions.push(wgsl::Instruction::Radians {
868 input: self.compile_variable(op.input),
869 out: self.compile_variable(out),
870 }),
871 cube::Arithmetic::ArcTan2(op) => instructions.push(wgsl::Instruction::ArcTan2 {
872 lhs: self.compile_variable(op.lhs),
873 rhs: self.compile_variable(op.rhs),
874 out: self.compile_variable(out),
875 }),
876 cube::Arithmetic::Powf(op) | cube::Arithmetic::Powi(op) => {
878 instructions.push(wgsl::Instruction::Powf {
879 lhs: self.compile_variable(op.lhs),
880 rhs: self.compile_variable(op.rhs),
881 out: self.compile_variable(out),
882 })
883 }
884 cube::Arithmetic::Hypot(op) => {
885 let mut scope = scope.child();
886 expand_hypot(&mut scope, op.lhs, op.rhs, out);
887 instructions.extend(self.compile_scope(&mut scope));
888 }
889 cube::Arithmetic::Rhypot(op) => {
890 let mut scope = scope.child();
891 expand_rhypot(&mut scope, op.lhs, op.rhs, out);
892 instructions.extend(self.compile_scope(&mut scope));
893 }
894
895 cube::Arithmetic::Sqrt(op) => instructions.push(wgsl::Instruction::Sqrt {
896 input: self.compile_variable(op.input),
897 out: self.compile_variable(out),
898 }),
899 cube::Arithmetic::InverseSqrt(op) => {
900 instructions.push(wgsl::Instruction::InverseSqrt {
901 input: self.compile_variable(op.input),
902 out: self.compile_variable(out),
903 })
904 }
905 cube::Arithmetic::Round(op) => instructions.push(wgsl::Instruction::Round {
906 input: self.compile_variable(op.input),
907 out: self.compile_variable(out),
908 }),
909 cube::Arithmetic::Floor(op) => instructions.push(wgsl::Instruction::Floor {
910 input: self.compile_variable(op.input),
911 out: self.compile_variable(out),
912 }),
913 cube::Arithmetic::Ceil(op) => instructions.push(wgsl::Instruction::Ceil {
914 input: self.compile_variable(op.input),
915 out: self.compile_variable(out),
916 }),
917 cube::Arithmetic::Trunc(op) => instructions.push(wgsl::Instruction::Trunc {
918 input: self.compile_variable(op.input),
919 out: self.compile_variable(out),
920 }),
921 cube::Arithmetic::Erf(op) => {
922 let mut scope = scope.child();
923 expand_erf(&mut scope, op.input, out);
924 instructions.extend(self.compile_scope(&mut scope));
925 }
926 cube::Arithmetic::MulHi(op) => {
927 let mut scope = scope.child();
928 match self.compilation_options.supports_u64 {
929 true => expand_himul_64(&mut scope, op.lhs, op.rhs, out),
930 false => expand_himul_sim(&mut scope, op.lhs, op.rhs, out),
931 }
932 instructions.extend(self.compile_scope(&mut scope));
933 }
934 cube::Arithmetic::Recip(op) => instructions.push(wgsl::Instruction::Recip {
935 input: self.compile_variable(op.input),
936 out: self.compile_variable(out),
937 }),
938 cube::Arithmetic::Clamp(op) => instructions.push(wgsl::Instruction::Clamp {
939 input: self.compile_variable(op.input),
940 min_value: self.compile_variable(op.min_value),
941 max_value: self.compile_variable(op.max_value),
942 out: self.compile_variable(out),
943 }),
944 cube::Arithmetic::Remainder(op) => instructions.push(wgsl::Instruction::Remainder {
945 lhs: self.compile_variable(op.lhs),
946 rhs: self.compile_variable(op.rhs),
947 out: self.compile_variable(out),
948 }),
949 cube::Arithmetic::Neg(op) => instructions.push(wgsl::Instruction::Negate {
950 input: self.compile_variable(op.input),
951 out: self.compile_variable(out),
952 }),
953 cube::Arithmetic::Magnitude(op) => instructions.push(wgsl::Instruction::Magnitude {
954 input: self.compile_variable(op.input),
955 out: self.compile_variable(out),
956 }),
957 cube::Arithmetic::Normalize(op) => instructions.push(wgsl::Instruction::Normalize {
958 input: self.compile_variable(op.input),
959 out: self.compile_variable(out),
960 }),
961 cube::Arithmetic::Dot(op) => instructions.push(wgsl::Instruction::Dot {
962 lhs: self.compile_variable(op.lhs),
963 rhs: self.compile_variable(op.rhs),
964 out: self.compile_variable(out),
965 }),
966 }
967 }
968
969 fn compile_cmp(
970 &mut self,
971 value: cube::Comparison,
972 out: Option<cube::Variable>,
973 instructions: &mut Vec<wgsl::Instruction>,
974 ) {
975 let out = out.unwrap();
976 match value {
977 cube::Comparison::Equal(op) => instructions.push(wgsl::Instruction::Equal {
978 lhs: self.compile_variable(op.lhs),
979 rhs: self.compile_variable(op.rhs),
980 out: self.compile_variable(out),
981 }),
982 cube::Comparison::Lower(op) => instructions.push(wgsl::Instruction::Lower {
983 lhs: self.compile_variable(op.lhs),
984 rhs: self.compile_variable(op.rhs),
985 out: self.compile_variable(out),
986 }),
987 cube::Comparison::Greater(op) => instructions.push(wgsl::Instruction::Greater {
988 lhs: self.compile_variable(op.lhs),
989 rhs: self.compile_variable(op.rhs),
990 out: self.compile_variable(out),
991 }),
992 cube::Comparison::LowerEqual(op) => instructions.push(wgsl::Instruction::LowerEqual {
993 lhs: self.compile_variable(op.lhs),
994 rhs: self.compile_variable(op.rhs),
995 out: self.compile_variable(out),
996 }),
997 cube::Comparison::GreaterEqual(op) => {
998 instructions.push(wgsl::Instruction::GreaterEqual {
999 lhs: self.compile_variable(op.lhs),
1000 rhs: self.compile_variable(op.rhs),
1001 out: self.compile_variable(out),
1002 })
1003 }
1004 cube::Comparison::NotEqual(op) => instructions.push(wgsl::Instruction::NotEqual {
1005 lhs: self.compile_variable(op.lhs),
1006 rhs: self.compile_variable(op.rhs),
1007 out: self.compile_variable(out),
1008 }),
1009 cube::Comparison::IsNan(op) => instructions.push(wgsl::Instruction::IsNan {
1010 input: self.compile_variable(op.input),
1011 out: self.compile_variable(out),
1012 }),
1013 cube::Comparison::IsInf(op) => instructions.push(wgsl::Instruction::IsInf {
1014 input: self.compile_variable(op.input),
1015 out: self.compile_variable(out),
1016 }),
1017 }
1018 }
1019
1020 fn compile_bitwise(
1021 &mut self,
1022 value: cube::Bitwise,
1023 out: Option<cube::Variable>,
1024 instructions: &mut Vec<wgsl::Instruction>,
1025 ) {
1026 let out = out.unwrap();
1027 match value {
1028 cube::Bitwise::BitwiseOr(op) => instructions.push(wgsl::Instruction::BitwiseOr {
1029 lhs: self.compile_variable(op.lhs),
1030 rhs: self.compile_variable(op.rhs),
1031 out: self.compile_variable(out),
1032 }),
1033 cube::Bitwise::BitwiseAnd(op) => instructions.push(wgsl::Instruction::BitwiseAnd {
1034 lhs: self.compile_variable(op.lhs),
1035 rhs: self.compile_variable(op.rhs),
1036 out: self.compile_variable(out),
1037 }),
1038 cube::Bitwise::BitwiseXor(op) => instructions.push(wgsl::Instruction::BitwiseXor {
1039 lhs: self.compile_variable(op.lhs),
1040 rhs: self.compile_variable(op.rhs),
1041 out: self.compile_variable(out),
1042 }),
1043 cube::Bitwise::CountOnes(op) => instructions.push(wgsl::Instruction::CountBits {
1044 input: self.compile_variable(op.input),
1045 out: self.compile_variable(out),
1046 }),
1047 cube::Bitwise::ReverseBits(op) => instructions.push(wgsl::Instruction::ReverseBits {
1048 input: self.compile_variable(op.input),
1049 out: self.compile_variable(out),
1050 }),
1051 cube::Bitwise::ShiftLeft(op) => instructions.push(wgsl::Instruction::ShiftLeft {
1052 lhs: self.compile_variable(op.lhs),
1053 rhs: self.compile_variable(op.rhs),
1054 out: self.compile_variable(out),
1055 }),
1056 cube::Bitwise::ShiftRight(op) => instructions.push(wgsl::Instruction::ShiftRight {
1057 lhs: self.compile_variable(op.lhs),
1058 rhs: self.compile_variable(op.rhs),
1059 out: self.compile_variable(out),
1060 }),
1061 cube::Bitwise::BitwiseNot(op) => instructions.push(wgsl::Instruction::BitwiseNot {
1062 input: self.compile_variable(op.input),
1063 out: self.compile_variable(out),
1064 }),
1065 cube::Bitwise::LeadingZeros(op) => instructions.push(wgsl::Instruction::LeadingZeros {
1066 input: self.compile_variable(op.input),
1067 out: self.compile_variable(out),
1068 }),
1069 cube::Bitwise::FindFirstSet(op) => instructions.push(wgsl::Instruction::FindFirstSet {
1070 input: self.compile_variable(op.input),
1071 out: self.compile_variable(out),
1072 }),
1073 }
1074 }
1075
1076 fn compile_operator(
1077 &mut self,
1078 value: cube::Operator,
1079 out: Option<cube::Variable>,
1080 instructions: &mut Vec<wgsl::Instruction>,
1081 ) {
1082 let out = out.unwrap();
1083 match value {
1084 cube::Operator::Cast(op) => instructions.push(wgsl::Instruction::Assign {
1085 input: self.compile_variable(op.input),
1086 out: self.compile_variable(out),
1087 }),
1088 cube::Operator::Index(op) | cube::Operator::UncheckedIndex(op) => {
1089 instructions.push(wgsl::Instruction::Index {
1090 lhs: self.compile_variable(op.list),
1091 rhs: self.compile_variable(op.index),
1092 out: self.compile_variable(out),
1093 });
1094 }
1095 cube::Operator::IndexAssign(op) | cube::Operator::UncheckedIndexAssign(op) => {
1096 instructions.push(wgsl::Instruction::IndexAssign {
1097 index: self.compile_variable(op.index),
1098 rhs: self.compile_variable(op.value),
1099 out: self.compile_variable(out),
1100 })
1101 }
1102 cube::Operator::And(op) => instructions.push(wgsl::Instruction::And {
1103 lhs: self.compile_variable(op.lhs),
1104 rhs: self.compile_variable(op.rhs),
1105 out: self.compile_variable(out),
1106 }),
1107 cube::Operator::Or(op) => instructions.push(wgsl::Instruction::Or {
1108 lhs: self.compile_variable(op.lhs),
1109 rhs: self.compile_variable(op.rhs),
1110 out: self.compile_variable(out),
1111 }),
1112 cube::Operator::Not(op) => instructions.push(wgsl::Instruction::Not {
1113 input: self.compile_variable(op.input),
1114 out: self.compile_variable(out),
1115 }),
1116 cube::Operator::Reinterpret(op) => instructions.push(wgsl::Instruction::Bitcast {
1117 input: self.compile_variable(op.input),
1118 out: self.compile_variable(out),
1119 }),
1120 cube::Operator::InitLine(op) => instructions.push(wgsl::Instruction::VecInit {
1121 inputs: op
1122 .inputs
1123 .into_iter()
1124 .map(|var| self.compile_variable(var))
1125 .collect(),
1126 out: self.compile_variable(out),
1127 }),
1128 cube::Operator::CopyMemory(op) => instructions.push(wgsl::Instruction::Copy {
1129 input: self.compile_variable(op.input),
1130 in_index: self.compile_variable(op.in_index),
1131 out: self.compile_variable(out),
1132 out_index: self.compile_variable(op.out_index),
1133 }),
1134 cube::Operator::CopyMemoryBulk(op) => instructions.push(wgsl::Instruction::CopyBulk {
1135 input: self.compile_variable(op.input),
1136 in_index: self.compile_variable(op.in_index),
1137 out: self.compile_variable(out),
1138 out_index: self.compile_variable(op.out_index),
1139 len: op.len as u32,
1140 }),
1141 cube::Operator::Select(op) => instructions.push(wgsl::Instruction::Select {
1142 cond: self.compile_variable(op.cond),
1143 then: self.compile_variable(op.then),
1144 or_else: self.compile_variable(op.or_else),
1145 out: self.compile_variable(out),
1146 }),
1147 }
1148 }
1149
1150 fn compile_atomic(
1151 &mut self,
1152 atomic: cube::AtomicOp,
1153 out: Option<cube::Variable>,
1154 ) -> wgsl::Instruction {
1155 let out = out.unwrap();
1156 match atomic {
1157 cube::AtomicOp::Add(op) => wgsl::Instruction::AtomicAdd {
1158 lhs: self.compile_variable(op.lhs),
1159 rhs: self.compile_variable(op.rhs),
1160 out: self.compile_variable(out),
1161 },
1162 cube::AtomicOp::Sub(op) => wgsl::Instruction::AtomicSub {
1163 lhs: self.compile_variable(op.lhs),
1164 rhs: self.compile_variable(op.rhs),
1165 out: self.compile_variable(out),
1166 },
1167 cube::AtomicOp::Max(op) => wgsl::Instruction::AtomicMax {
1168 lhs: self.compile_variable(op.lhs),
1169 rhs: self.compile_variable(op.rhs),
1170 out: self.compile_variable(out),
1171 },
1172 cube::AtomicOp::Min(op) => wgsl::Instruction::AtomicMin {
1173 lhs: self.compile_variable(op.lhs),
1174 rhs: self.compile_variable(op.rhs),
1175 out: self.compile_variable(out),
1176 },
1177 cube::AtomicOp::And(op) => wgsl::Instruction::AtomicAnd {
1178 lhs: self.compile_variable(op.lhs),
1179 rhs: self.compile_variable(op.rhs),
1180 out: self.compile_variable(out),
1181 },
1182 cube::AtomicOp::Or(op) => wgsl::Instruction::AtomicOr {
1183 lhs: self.compile_variable(op.lhs),
1184 rhs: self.compile_variable(op.rhs),
1185 out: self.compile_variable(out),
1186 },
1187 cube::AtomicOp::Xor(op) => wgsl::Instruction::AtomicXor {
1188 lhs: self.compile_variable(op.lhs),
1189 rhs: self.compile_variable(op.rhs),
1190 out: self.compile_variable(out),
1191 },
1192 cube::AtomicOp::Load(op) => wgsl::Instruction::AtomicLoad {
1193 input: self.compile_variable(op.input),
1194 out: self.compile_variable(out),
1195 },
1196 cube::AtomicOp::Store(op) => wgsl::Instruction::AtomicStore {
1197 input: self.compile_variable(op.input),
1198 out: self.compile_variable(out),
1199 },
1200 cube::AtomicOp::Swap(op) => wgsl::Instruction::AtomicSwap {
1201 lhs: self.compile_variable(op.lhs),
1202 rhs: self.compile_variable(op.rhs),
1203 out: self.compile_variable(out),
1204 },
1205 cube::AtomicOp::CompareAndSwap(op) => wgsl::Instruction::AtomicCompareExchangeWeak {
1206 lhs: self.compile_variable(op.input),
1207 cmp: self.compile_variable(op.cmp),
1208 value: self.compile_variable(op.val),
1209 out: self.compile_variable(out),
1210 },
1211 }
1212 }
1213
1214 fn compile_location(value: kernel::Location) -> wgsl::Location {
1215 match value {
1216 kernel::Location::Storage => wgsl::Location::Storage,
1217 kernel::Location::Cube => wgsl::Location::Workgroup,
1218 }
1219 }
1220
1221 fn compile_binding(&mut self, value: kernel::Binding) -> wgsl::Binding {
1222 wgsl::Binding {
1223 id: value.id,
1224 visibility: value.visibility,
1225 location: Self::compile_location(value.location),
1226 item: self.compile_type(value.ty),
1227 size: value.size,
1228 }
1229 }
1230}
1231
1232fn register_extensions(instructions: &[wgsl::Instruction]) -> Vec<wgsl::Extension> {
1233 let mut extensions = Vec::new();
1234
1235 let mut register_extension = |extension: wgsl::Extension| {
1236 if !extensions.contains(&extension) {
1237 extensions.push(extension);
1238 }
1239 };
1240
1241 for instruction in instructions {
1243 match instruction {
1244 wgsl::Instruction::Powf { lhs: _, rhs, out } => {
1245 register_extension(wgsl::Extension::PowfPrimitive(out.elem()));
1246 register_extension(wgsl::powf_extension(rhs, out));
1247 }
1248 #[cfg(target_os = "macos")]
1249 wgsl::Instruction::Tanh { input, out: _ } => {
1250 register_extension(wgsl::Extension::SafeTanhPrimitive(input.elem()));
1251 register_extension(wgsl::Extension::SafeTanh(input.item()));
1252 }
1253 wgsl::Instruction::IsNan { input, out } => {
1254 register_extension(wgsl::Extension::IsNanPrimitive(input.elem()));
1255 register_extension(wgsl::Extension::IsNan(input.item(), out.item()));
1256 }
1257 wgsl::Instruction::IsInf { input, out } => {
1258 register_extension(wgsl::Extension::IsInfPrimitive(input.elem()));
1259 register_extension(wgsl::Extension::IsInf(input.item(), out.item()));
1260 }
1261 wgsl::Instruction::If { instructions, .. } => {
1262 for extension in register_extensions(instructions) {
1263 register_extension(extension);
1264 }
1265 }
1266 wgsl::Instruction::IfElse {
1267 instructions_if,
1268 instructions_else,
1269 ..
1270 } => {
1271 for extension in register_extensions(instructions_if) {
1272 register_extension(extension);
1273 }
1274 for extension in register_extensions(instructions_else) {
1275 register_extension(extension);
1276 }
1277 }
1278 wgsl::Instruction::Loop { instructions } => {
1279 for extension in register_extensions(instructions) {
1280 register_extension(extension);
1281 }
1282 }
1283 wgsl::Instruction::RangeLoop { instructions, .. } => {
1284 for extension in register_extensions(instructions) {
1285 register_extension(extension);
1286 }
1287 }
1288 _ => {}
1289 }
1290 }
1291
1292 extensions
1293}