1use super::Subgroup;
2use super::{ConstantArray, shader::ComputeShader};
3use super::{Item, LocalArray, SharedArray};
4use crate::compiler::wgsl::{self, SharedValue};
5
6use cubecl_common::ExecutionMode;
7use cubecl_common::backtrace::BackTrace;
8use cubecl_core::post_processing::{
9 checked_io::CheckedIoProcessor, saturating::SaturatingArithmeticProcessor,
10};
11use cubecl_core::prelude::*;
12use cubecl_core::{
13 Metadata, WgpuCompilationOptions,
14 ir::{self as cube, Scope},
15 prelude::expand_erf,
16};
17use cubecl_core::{
18 ir::{ConstantScalarValue, Processor, UIntKind},
19 post_processing::unroll::UnrollProcessor,
20};
21use cubecl_runtime::compiler::CompilationError;
22use cubecl_runtime::kernel;
23
/// Largest line size handled by this compiler.
///
/// Binding and variable types whose line size exceeds this value are clamped to it, and the
/// same value is handed to the `UnrollProcessor` when a scope is compiled.
pub const MAX_LINE_SIZE: u32 = 4;
25
26#[derive(Clone, Default)]
28pub struct WgslCompiler {
29 metadata: Metadata,
30 ext_meta_pos: Vec<u32>,
31 local_invocation_index: bool,
32 local_invocation_id: bool,
33 global_invocation_id: bool,
35 workgroup_id: bool,
36 subgroup_size: bool,
37 subgroup_invocation_id: bool,
38 id: bool,
39 num_workgroups: bool,
40 workgroup_id_no_axis: bool,
41 workgroup_size_no_axis: bool,
42 num_workgroup_no_axis: bool,
43 shared_arrays: Vec<SharedArray>,
44 shared_values: Vec<SharedValue>,
45 const_arrays: Vec<ConstantArray>,
46 local_arrays: Vec<LocalArray>,
47 #[allow(dead_code)]
48 compilation_options: WgpuCompilationOptions,
49 strategy: ExecutionMode,
50 subgroup_instructions_used: bool,
51 f16_used: bool,
52}
53
54impl core::fmt::Debug for WgslCompiler {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 f.write_str("WgslCompiler")
57 }
58}
59
60impl cubecl_core::Compiler for WgslCompiler {
61 type Representation = ComputeShader;
62 type CompilationOptions = WgpuCompilationOptions;
63
64 fn compile(
65 &mut self,
66 shader: kernel::KernelDefinition,
67 compilation_options: &Self::CompilationOptions,
68 mode: ExecutionMode,
69 ) -> Result<Self::Representation, CompilationError> {
70 self.compilation_options = compilation_options.clone();
71 self.compile_shader(shader, mode)
72 }
73
74 fn elem_size(&self, elem: cube::ElemType) -> usize {
75 elem.size()
76 }
77
78 fn extension(&self) -> &'static str {
79 "wgsl"
80 }
81}
82
83impl WgslCompiler {
84 fn compile_shader(
85 &mut self,
86 mut value: kernel::KernelDefinition,
87 mode: ExecutionMode,
88 ) -> Result<wgsl::ComputeShader, CompilationError> {
89 let errors = value.body.pop_errors();
90 if !errors.is_empty() {
91 let mut reason = "Can't compile wgsl kernel".to_string();
92 for error in errors {
93 reason += error.as_str();
94 reason += "\n";
95 }
96
97 return Err(CompilationError::Validation {
98 reason,
99 backtrace: BackTrace::capture(),
100 });
101 }
102
103 self.strategy = mode;
104
105 let num_meta = value.buffers.len();
106
107 self.ext_meta_pos = Vec::new();
108 let mut num_ext = 0;
109
110 for binding in value.buffers.iter() {
111 self.ext_meta_pos.push(num_ext);
112 if binding.has_extended_meta {
113 num_ext += 1;
114 }
115 }
116
117 self.metadata = Metadata::new(num_meta as u32, num_ext);
118
119 let instructions = self.compile_scope(&mut value.body);
120 let extensions = register_extensions(&instructions);
121 let body = wgsl::Body {
122 instructions,
123 id: self.id,
124 };
125
126 Ok(wgsl::ComputeShader {
127 buffers: value
128 .buffers
129 .into_iter()
130 .map(|mut it| {
131 if it.ty.line_size() > MAX_LINE_SIZE {
134 it.ty = it.ty.line(MAX_LINE_SIZE);
135 }
136 self.compile_binding(it)
137 })
138 .collect(),
139 scalars: value
140 .scalars
141 .into_iter()
142 .map(|binding| (self.compile_storage_type(binding.ty), binding.count))
143 .collect(),
144 shared_arrays: self.shared_arrays.clone(),
145 shared_values: self.shared_values.clone(),
146 constant_arrays: self.const_arrays.clone(),
147 local_arrays: self.local_arrays.clone(),
148 has_metadata: self.metadata.static_len() > 0,
149 workgroup_size: value.cube_dim,
150 global_invocation_id: self.global_invocation_id || self.id,
151 local_invocation_index: self.local_invocation_index,
152 local_invocation_id: self.local_invocation_id,
153 num_workgroups: self.id
154 || self.num_workgroups
155 || self.num_workgroup_no_axis
156 || self.workgroup_id_no_axis,
157 workgroup_id: self.workgroup_id || self.workgroup_id_no_axis,
158 subgroup_size: self.subgroup_size,
159 subgroup_invocation_id: self.subgroup_invocation_id,
160 body,
161 extensions,
162 num_workgroups_no_axis: self.num_workgroup_no_axis,
163 workgroup_id_no_axis: self.workgroup_id_no_axis,
164 workgroup_size_no_axis: self.workgroup_size_no_axis,
165 subgroup_instructions_used: self.subgroup_instructions_used,
166 f16_used: self.f16_used,
167 kernel_name: value.options.kernel_name,
168 })
169 }
170
171 fn compile_type(&mut self, item: cube::Type) -> Item {
172 match item {
173 cube::Type::Scalar(ty) => wgsl::Item::Scalar(self.compile_storage_type(ty)),
174 cube::Type::Line(ty, size) => {
175 let elem = self.compile_storage_type(ty);
176 match size {
177 2 => wgsl::Item::Vec2(elem),
178 3 => wgsl::Item::Vec3(elem),
179 4 => wgsl::Item::Vec4(elem),
180 _ => panic!("Unsupported vectorizations scheme {:?}", item.line_size()),
181 }
182 }
183 cube::Type::Semantic(_) => unimplemented!("Can't compile semantic type"),
184 }
185 }
186
187 fn compile_storage_type(&mut self, ty: cube::StorageType) -> wgsl::Elem {
188 match ty {
189 cube::StorageType::Scalar(ty) => self.compile_elem(ty),
190 cube::StorageType::Atomic(ty) => match ty {
191 cube::ElemType::Float(i) => match i {
192 cube::FloatKind::F32 => wgsl::Elem::AtomicF32,
193 kind => panic!("atomic<{kind:?}> is not a valid WgpuElement"),
194 },
195 cube::ElemType::Int(i) => match i {
196 cube::IntKind::I32 => wgsl::Elem::AtomicI32,
197 kind => panic!("atomic<{kind:?}> is not a valid WgpuElement"),
198 },
199 cube::ElemType::UInt(kind) => match kind {
200 cube::UIntKind::U32 => wgsl::Elem::AtomicU32,
201 kind => panic!("{kind:?} is not a valid WgpuElement"),
202 },
203 other => panic!("{other:?} is not a valid WgpuElement"),
204 },
205 cube::StorageType::Packed(_, _) => {
206 unimplemented!("Packed types not yet supported in WGSL")
207 }
208 cube::StorageType::Opaque(ty) => match ty {
209 cube::OpaqueType::Barrier(_) => {
210 unimplemented!("Barrier objects not supported in WGSL")
211 }
212 },
213 }
214 }
215
216 fn compile_elem(&mut self, value: cube::ElemType) -> wgsl::Elem {
217 match value {
218 cube::ElemType::Float(f) => match f {
219 cube::FloatKind::E2M1
220 | cube::FloatKind::E2M3
221 | cube::FloatKind::E3M2
222 | cube::FloatKind::E4M3
223 | cube::FloatKind::E5M2
224 | cube::FloatKind::UE8M0 => panic!("Minifloat is not a valid WgpuElement"),
225 cube::FloatKind::F16 => {
226 self.f16_used = true;
227 wgsl::Elem::F16
228 }
229 cube::FloatKind::BF16 => panic!("bf16 is not a valid WgpuElement"),
230 cube::FloatKind::TF32 => panic!("tf32 is not a valid WgpuElement"),
231 cube::FloatKind::Flex32 => wgsl::Elem::F32,
232 cube::FloatKind::F32 => wgsl::Elem::F32,
233 cube::FloatKind::F64 => wgsl::Elem::F64,
234 },
235 cube::ElemType::Int(i) => match i {
236 cube::IntKind::I32 => wgsl::Elem::I32,
237 cube::IntKind::I64 => wgsl::Elem::I64,
238 kind => panic!("{kind:?} is not a valid WgpuElement"),
239 },
240 cube::ElemType::UInt(kind) => match kind {
241 cube::UIntKind::U32 => wgsl::Elem::U32,
242 cube::UIntKind::U64 => wgsl::Elem::U64,
243 kind => panic!("{kind:?} is not a valid WgpuElement"),
244 },
245 cube::ElemType::Bool => wgsl::Elem::Bool,
246 }
247 }
248
249 fn ext_meta_pos(&self, var: &cube::Variable) -> u32 {
250 let pos = var.index().expect("Variable should have index");
251 self.ext_meta_pos[pos as usize]
252 }
253
254 pub(crate) fn compile_variable(&mut self, value: cube::Variable) -> wgsl::Variable {
255 let item = value.ty;
256 match value.kind {
257 cube::VariableKind::GlobalInputArray(id) => {
258 wgsl::Variable::GlobalInputArray(id, self.compile_type(item))
259 }
260 cube::VariableKind::GlobalScalar(id) => {
261 wgsl::Variable::GlobalScalar(id, self.compile_storage_type(item.storage_type()))
262 }
263 cube::VariableKind::LocalMut { id } | cube::VariableKind::Versioned { id, .. } => {
264 wgsl::Variable::LocalMut {
265 id,
266 item: self.compile_type(item),
267 }
268 }
269 cube::VariableKind::LocalConst { id } => wgsl::Variable::LocalConst {
270 id,
271 item: self.compile_type(item),
272 },
273 cube::VariableKind::GlobalOutputArray(id) => {
274 wgsl::Variable::GlobalOutputArray(id, self.compile_type(item))
275 }
276 cube::VariableKind::ConstantScalar(value) => {
277 wgsl::Variable::ConstantScalar(value, self.compile_elem(value.elem_type()))
278 }
279 cube::VariableKind::SharedArray {
280 id,
281 length,
282 unroll_factor,
283 alignment,
284 } => {
285 let item = self.compile_type(item);
286 if !self.shared_arrays.iter().any(|s| s.index == id) {
287 self.shared_arrays.push(SharedArray::new(
288 id,
289 item,
290 length * unroll_factor,
291 alignment,
292 ));
293 }
294 wgsl::Variable::SharedArray(id, item, length)
295 }
296 cube::VariableKind::Shared { id } => {
297 let item = self.compile_type(item);
298 if !self.shared_values.iter().any(|s| s.index == id) {
299 self.shared_values.push(SharedValue::new(id, item));
300 }
301 wgsl::Variable::SharedValue(id, item)
302 }
303 cube::VariableKind::ConstantArray { id, length, .. } => {
304 let item = self.compile_type(item);
305 wgsl::Variable::ConstantArray(id, item, length)
306 }
307 cube::VariableKind::LocalArray {
308 id,
309 length,
310 unroll_factor,
311 } => {
312 let item = self.compile_type(item);
313 if !self.local_arrays.iter().any(|s| s.index == id) {
314 self.local_arrays
315 .push(LocalArray::new(id, item, length * unroll_factor));
316 }
317 wgsl::Variable::LocalArray(id, item, length)
318 }
319 cube::VariableKind::Builtin(builtin) => match builtin {
320 cube::Builtin::AbsolutePos => {
321 self.id = true;
322 wgsl::Variable::Id
323 }
324 cube::Builtin::UnitPos => {
325 self.local_invocation_index = true;
326 wgsl::Variable::LocalInvocationIndex
327 }
328 cube::Builtin::UnitPosX => {
329 self.local_invocation_id = true;
330 wgsl::Variable::LocalInvocationIdX
331 }
332 cube::Builtin::UnitPosY => {
333 self.local_invocation_id = true;
334 wgsl::Variable::LocalInvocationIdY
335 }
336 cube::Builtin::UnitPosZ => {
337 self.local_invocation_id = true;
338 wgsl::Variable::LocalInvocationIdZ
339 }
340 cube::Builtin::CubePosX => {
341 self.workgroup_id = true;
342 wgsl::Variable::WorkgroupIdX
343 }
344 cube::Builtin::CubePosY => {
345 self.workgroup_id = true;
346 wgsl::Variable::WorkgroupIdY
347 }
348 cube::Builtin::CubePosZ => {
349 self.workgroup_id = true;
350 wgsl::Variable::WorkgroupIdZ
351 }
352 cube::Builtin::CubePosCluster
353 | cube::Builtin::CubePosClusterX
354 | cube::Builtin::CubePosClusterY
355 | cube::Builtin::CubePosClusterZ => self.constant_var(1),
356 cube::Builtin::AbsolutePosX => {
357 self.global_invocation_id = true;
358 wgsl::Variable::GlobalInvocationIdX
359 }
360 cube::Builtin::AbsolutePosY => {
361 self.global_invocation_id = true;
362 wgsl::Variable::GlobalInvocationIdY
363 }
364 cube::Builtin::AbsolutePosZ => {
365 self.global_invocation_id = true;
366 wgsl::Variable::GlobalInvocationIdZ
367 }
368 cube::Builtin::CubeDimX => wgsl::Variable::WorkgroupSizeX,
369 cube::Builtin::CubeDimY => wgsl::Variable::WorkgroupSizeY,
370 cube::Builtin::CubeDimZ => wgsl::Variable::WorkgroupSizeZ,
371 cube::Builtin::CubeClusterDim
372 | cube::Builtin::CubeClusterDimX
373 | cube::Builtin::CubeClusterDimY
374 | cube::Builtin::CubeClusterDimZ => self.constant_var(1),
375 cube::Builtin::CubeCountX => {
376 self.num_workgroups = true;
377 wgsl::Variable::NumWorkgroupsX
378 }
379 cube::Builtin::CubeCountY => {
380 self.num_workgroups = true;
381 wgsl::Variable::NumWorkgroupsY
382 }
383 cube::Builtin::CubeCountZ => {
384 self.num_workgroups = true;
385 wgsl::Variable::NumWorkgroupsZ
386 }
387 cube::Builtin::CubePos => {
388 self.workgroup_id_no_axis = true;
389 wgsl::Variable::WorkgroupId
390 }
391 cube::Builtin::CubeDim => {
392 self.workgroup_size_no_axis = true;
393 wgsl::Variable::WorkgroupSize
394 }
395 cube::Builtin::CubeCount => {
396 self.num_workgroup_no_axis = true;
397 wgsl::Variable::NumWorkgroups
398 }
399 cube::Builtin::PlaneDim => {
400 self.subgroup_size = true;
401 wgsl::Variable::SubgroupSize
402 }
403 cube::Builtin::UnitPosPlane => {
404 self.subgroup_invocation_id = true;
405 wgsl::Variable::SubgroupInvocationId
406 }
407 },
408 cube::VariableKind::Matrix { .. } => {
409 panic!("Cooperative matrix-multiply and accumulate not supported.")
410 }
411 cube::VariableKind::Pipeline { .. } => {
412 panic!("Pipeline not supported.")
413 }
414 cube::VariableKind::BarrierToken { .. } => {
415 panic!("Barrier not supported.")
416 }
417 cube::VariableKind::TensorMapInput(_) => panic!("Tensor map not supported."),
418 cube::VariableKind::TensorMapOutput(_) => panic!("Tensor map not supported."),
419 }
420 }
421
422 fn constant_var(&mut self, value: u32) -> wgsl::Variable {
423 let var = cube::Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U32));
424 self.compile_variable(var)
425 }
426
427 fn compile_scope(&mut self, scope: &mut cube::Scope) -> Vec<wgsl::Instruction> {
428 let mut instructions = Vec::new();
429
430 let const_arrays = scope
431 .const_arrays
432 .drain(..)
433 .map(|(var, values)| ConstantArray {
434 index: var.index().unwrap(),
435 item: self.compile_type(var.ty),
436 size: values.len() as u32,
437 values: values
438 .into_iter()
439 .map(|val| self.compile_variable(val))
440 .collect(),
441 })
442 .collect::<Vec<_>>();
443 self.const_arrays.extend(const_arrays);
444
445 let checked_io: Box<dyn Processor> = Box::new(CheckedIoProcessor::new(self.strategy));
446 let unroll = Box::new(UnrollProcessor::new(MAX_LINE_SIZE));
447 let saturating = Box::new(SaturatingArithmeticProcessor::new(true));
448 let processing = scope.process([&*unroll, &*checked_io, &*saturating]);
449
450 for mut var in processing.variables {
451 if var.ty.line_size() > MAX_LINE_SIZE {
452 var.ty = var.ty.line(MAX_LINE_SIZE);
453 }
454 instructions.push(wgsl::Instruction::DeclareVariable {
455 var: self.compile_variable(var),
456 });
457 }
458
459 processing
460 .instructions
461 .into_iter()
462 .for_each(|op| self.compile_operation(&mut instructions, op.operation, op.out, scope));
463
464 instructions
465 }
466
467 fn compile_operation(
468 &mut self,
469 instructions: &mut Vec<wgsl::Instruction>,
470 operation: cube::Operation,
471 out: Option<cube::Variable>,
472 scope: &mut cube::Scope,
473 ) {
474 match operation {
475 cube::Operation::Copy(variable) => instructions.push(wgsl::Instruction::Assign {
476 input: self.compile_variable(variable),
477 out: self.compile_variable(out.unwrap()),
478 }),
479 cube::Operation::Arithmetic(op) => {
480 self.compile_arithmetic(op, out, instructions, scope)
481 }
482 cube::Operation::Comparison(op) => self.compile_cmp(op, out, instructions),
483 cube::Operation::Bitwise(op) => self.compile_bitwise(op, out, instructions),
484 cube::Operation::Operator(op) => self.compile_operator(op, out, instructions),
485 cube::Operation::Atomic(op) => instructions.push(self.compile_atomic(op, out)),
486 cube::Operation::Metadata(op) => instructions.push(self.compile_metadata(op, out)),
487 cube::Operation::Branch(val) => self.compile_branch(instructions, val),
488 cube::Operation::Synchronization(val) => {
489 self.compile_synchronization(instructions, val)
490 }
491 cube::Operation::Plane(op) => self.compile_subgroup(instructions, op, out),
492 cube::Operation::CoopMma(_) => {
493 panic!("Cooperative matrix-multiply and accumulate isn't supported on wgpu.")
494 }
495 cube::Operation::NonSemantic(cube::NonSemantic::Comment { content }) => {
496 self.compile_comment(instructions, content)
497 }
498 cube::Operation::NonSemantic(_) => {}
499 cube::Operation::Barrier(_) => {
500 panic!("Barrier isn't supported on wgpu.")
501 }
502 cube::Operation::Tma(_) => panic!("TMA isn't supported on wgpu."),
503 cube::Operation::Marker(_) => {}
504 }
505 }
506
507 fn compile_subgroup(
508 &mut self,
509 instructions: &mut Vec<wgsl::Instruction>,
510 subgroup: cube::Plane,
511 out: Option<cube::Variable>,
512 ) {
513 self.subgroup_instructions_used = true;
514
515 let out = out.unwrap();
516 let op = match subgroup {
517 cube::Plane::Elect => Subgroup::Elect {
518 out: self.compile_variable(out),
519 },
520 cube::Plane::All(op) => Subgroup::All {
521 input: self.compile_variable(op.input),
522 out: self.compile_variable(out),
523 },
524 cube::Plane::Any(op) => Subgroup::Any {
525 input: self.compile_variable(op.input),
526 out: self.compile_variable(out),
527 },
528 cube::Plane::Ballot(op) => Subgroup::Ballot {
529 input: self.compile_variable(op.input),
530 out: self.compile_variable(out),
531 },
532
533 cube::Plane::Broadcast(op) => Subgroup::Broadcast {
534 lhs: self.compile_variable(op.lhs),
535 rhs: self.compile_variable(op.rhs),
536 out: self.compile_variable(out),
537 },
538
539 cube::Plane::Sum(op) => Subgroup::Sum {
540 input: self.compile_variable(op.input),
541 out: self.compile_variable(out),
542 },
543
544 cube::Plane::ExclusiveSum(op) => Subgroup::ExclusiveSum {
545 input: self.compile_variable(op.input),
546 out: self.compile_variable(out),
547 },
548 cube::Plane::InclusiveSum(op) => Subgroup::InclusiveSum {
549 input: self.compile_variable(op.input),
550 out: self.compile_variable(out),
551 },
552 cube::Plane::Prod(op) => Subgroup::Prod {
553 input: self.compile_variable(op.input),
554 out: self.compile_variable(out),
555 },
556 cube::Plane::ExclusiveProd(op) => Subgroup::ExclusiveProd {
557 input: self.compile_variable(op.input),
558 out: self.compile_variable(out),
559 },
560 cube::Plane::InclusiveProd(op) => Subgroup::InclusiveProd {
561 input: self.compile_variable(op.input),
562 out: self.compile_variable(out),
563 },
564 cube::Plane::Min(op) => Subgroup::Min {
565 input: self.compile_variable(op.input),
566 out: self.compile_variable(out),
567 },
568 cube::Plane::Max(op) => Subgroup::Max {
569 input: self.compile_variable(op.input),
570 out: self.compile_variable(out),
571 },
572 cube::Plane::Shuffle(op) => Subgroup::Shuffle {
573 lhs: self.compile_variable(op.lhs),
574 rhs: self.compile_variable(op.rhs),
575 out: self.compile_variable(out),
576 },
577 cube::Plane::ShuffleXor(op) => Subgroup::ShuffleXor {
578 lhs: self.compile_variable(op.lhs),
579 rhs: self.compile_variable(op.rhs),
580 out: self.compile_variable(out),
581 },
582 cube::Plane::ShuffleUp(op) => Subgroup::ShuffleUp {
583 lhs: self.compile_variable(op.lhs),
584 rhs: self.compile_variable(op.rhs),
585 out: self.compile_variable(out),
586 },
587 cube::Plane::ShuffleDown(op) => Subgroup::ShuffleDown {
588 lhs: self.compile_variable(op.lhs),
589 rhs: self.compile_variable(op.rhs),
590 out: self.compile_variable(out),
591 },
592 };
593
594 instructions.push(wgsl::Instruction::Subgroup(op));
595 }
596
597 fn compile_branch(&mut self, instructions: &mut Vec<wgsl::Instruction>, branch: cube::Branch) {
598 match branch {
599 cube::Branch::If(mut op) => instructions.push(wgsl::Instruction::If {
600 cond: self.compile_variable(op.cond),
601 instructions: self.compile_scope(&mut op.scope),
602 }),
603 cube::Branch::IfElse(mut op) => instructions.push(wgsl::Instruction::IfElse {
604 cond: self.compile_variable(op.cond),
605 instructions_if: self.compile_scope(&mut op.scope_if),
606 instructions_else: self.compile_scope(&mut op.scope_else),
607 }),
608 cube::Branch::Switch(mut op) => instructions.push(wgsl::Instruction::Switch {
609 value: self.compile_variable(op.value),
610 instructions_default: self.compile_scope(&mut op.scope_default),
611 cases: op
612 .cases
613 .into_iter()
614 .map(|(val, mut scope)| {
615 (self.compile_variable(val), self.compile_scope(&mut scope))
616 })
617 .collect(),
618 }),
619 cube::Branch::Return => instructions.push(wgsl::Instruction::Return),
620 cube::Branch::Break => instructions.push(wgsl::Instruction::Break),
621 cube::Branch::RangeLoop(mut range_loop) => {
622 instructions.push(wgsl::Instruction::RangeLoop {
623 i: self.compile_variable(range_loop.i),
624 start: self.compile_variable(range_loop.start),
625 end: self.compile_variable(range_loop.end),
626 step: range_loop.step.map(|it| self.compile_variable(it)),
627 inclusive: range_loop.inclusive,
628 instructions: self.compile_scope(&mut range_loop.scope),
629 })
630 }
631 cube::Branch::Loop(mut op) => instructions.push(wgsl::Instruction::Loop {
632 instructions: self.compile_scope(&mut op.scope),
633 }),
634 };
635 }
636
637 fn compile_synchronization(
638 &mut self,
639 instructions: &mut Vec<wgsl::Instruction>,
640 synchronization: cube::Synchronization,
641 ) {
642 match synchronization {
643 cube::Synchronization::SyncCube => {
644 instructions.push(wgsl::Instruction::WorkgroupBarrier)
645 }
646 cube::Synchronization::SyncPlane => {
647 panic!("Synchronization within a plane is not supported in WGSL")
648 }
649 cube::Synchronization::SyncStorage => {
650 instructions.push(wgsl::Instruction::StorageBarrier)
651 }
652 cube::Synchronization::SyncAsyncProxyShared => panic!("TMA is not supported in WGSL"),
653 };
654 }
655
656 fn compile_comment(&mut self, instructions: &mut Vec<wgsl::Instruction>, content: String) {
657 instructions.push(wgsl::Instruction::Comment { content })
658 }
659
660 fn compile_metadata(
661 &mut self,
662 metadata: cube::Metadata,
663 out: Option<cube::Variable>,
664 ) -> wgsl::Instruction {
665 let out = out.unwrap();
666 match metadata {
667 cube::Metadata::Rank { var } => {
668 let position = self.ext_meta_pos(&var);
669 let offset = self.metadata.rank_index(position);
670 wgsl::Instruction::Metadata {
671 out: self.compile_variable(out),
672 info_offset: self.compile_variable(offset.into()),
673 }
674 }
675 cube::Metadata::Stride { dim, var } => {
676 let position = self.ext_meta_pos(&var);
677 let offset = self.metadata.stride_offset_index(position);
678 wgsl::Instruction::ExtendedMeta {
679 info_offset: self.compile_variable(offset.into()),
680 dim: self.compile_variable(dim),
681 out: self.compile_variable(out),
682 }
683 }
684 cube::Metadata::Shape { dim, var } => {
685 let position = self.ext_meta_pos(&var);
686 let offset = self.metadata.shape_offset_index(position);
687 wgsl::Instruction::ExtendedMeta {
688 info_offset: self.compile_variable(offset.into()),
689 dim: self.compile_variable(dim),
690 out: self.compile_variable(out),
691 }
692 }
693 cube::Metadata::Length { var } => match var.kind {
694 cube::VariableKind::GlobalInputArray(id) => {
695 let offset = self.metadata.len_index(id);
696 wgsl::Instruction::Metadata {
697 out: self.compile_variable(out),
698 info_offset: self.compile_variable(offset.into()),
699 }
700 }
701 cube::VariableKind::GlobalOutputArray(id) => {
702 let offset = self.metadata.len_index(id);
703 wgsl::Instruction::Metadata {
704 out: self.compile_variable(out),
705 info_offset: self.compile_variable(offset.into()),
706 }
707 }
708 _ => wgsl::Instruction::Length {
709 var: self.compile_variable(var),
710 out: self.compile_variable(out),
711 },
712 },
713 cube::Metadata::BufferLength { var } => match var.kind {
714 cube::VariableKind::GlobalInputArray(id) => {
715 let offset = self.metadata.buffer_len_index(id);
716 wgsl::Instruction::Metadata {
717 out: self.compile_variable(out),
718 info_offset: self.compile_variable(offset.into()),
719 }
720 }
721 cube::VariableKind::GlobalOutputArray(id) => {
722 let offset = self.metadata.buffer_len_index(id);
723 wgsl::Instruction::Metadata {
724 out: self.compile_variable(out),
725 info_offset: self.compile_variable(offset.into()),
726 }
727 }
728 _ => wgsl::Instruction::Length {
729 var: self.compile_variable(var),
730 out: self.compile_variable(out),
731 },
732 },
733 }
734 }
735
736 fn compile_arithmetic(
737 &mut self,
738 value: cube::Arithmetic,
739 out: Option<cube::Variable>,
740 instructions: &mut Vec<wgsl::Instruction>,
741 scope: &mut Scope,
742 ) {
743 let out = out.unwrap();
744 match value {
745 cube::Arithmetic::Max(op) => instructions.push(wgsl::Instruction::Max {
746 lhs: self.compile_variable(op.lhs),
747 rhs: self.compile_variable(op.rhs),
748 out: self.compile_variable(out),
749 }),
750 cube::Arithmetic::Min(op) => instructions.push(wgsl::Instruction::Min {
751 lhs: self.compile_variable(op.lhs),
752 rhs: self.compile_variable(op.rhs),
753 out: self.compile_variable(out),
754 }),
755 cube::Arithmetic::Add(op) => instructions.push(wgsl::Instruction::Add {
756 lhs: self.compile_variable(op.lhs),
757 rhs: self.compile_variable(op.rhs),
758 out: self.compile_variable(out),
759 }),
760 cube::Arithmetic::SaturatingAdd(_) => {
761 unreachable!("Saturating add should be removed by processor");
762 }
763 cube::Arithmetic::Fma(op) => instructions.push(wgsl::Instruction::Fma {
764 a: self.compile_variable(op.a),
765 b: self.compile_variable(op.b),
766 c: self.compile_variable(op.c),
767 out: self.compile_variable(out),
768 }),
769 cube::Arithmetic::Modulo(op) => instructions.push(wgsl::Instruction::Modulo {
770 lhs: self.compile_variable(op.lhs),
771 rhs: self.compile_variable(op.rhs),
772 out: self.compile_variable(out),
773 }),
774 cube::Arithmetic::Sub(op) => instructions.push(wgsl::Instruction::Sub {
775 lhs: self.compile_variable(op.lhs),
776 rhs: self.compile_variable(op.rhs),
777 out: self.compile_variable(out),
778 }),
779 cube::Arithmetic::SaturatingSub(_) => {
780 unreachable!("Saturating sub should be removed by processor");
781 }
782 cube::Arithmetic::Mul(op) => instructions.push(wgsl::Instruction::Mul {
783 lhs: self.compile_variable(op.lhs),
784 rhs: self.compile_variable(op.rhs),
785 out: self.compile_variable(out),
786 }),
787 cube::Arithmetic::Div(op) => instructions.push(wgsl::Instruction::Div {
788 lhs: self.compile_variable(op.lhs),
789 rhs: self.compile_variable(op.rhs),
790 out: self.compile_variable(out),
791 }),
792 cube::Arithmetic::Abs(op) => instructions.push(wgsl::Instruction::Abs {
793 input: self.compile_variable(op.input),
794 out: self.compile_variable(out),
795 }),
796 cube::Arithmetic::Exp(op) => instructions.push(wgsl::Instruction::Exp {
797 input: self.compile_variable(op.input),
798 out: self.compile_variable(out),
799 }),
800 cube::Arithmetic::Log(op) => instructions.push(wgsl::Instruction::Log {
801 input: self.compile_variable(op.input),
802 out: self.compile_variable(out),
803 }),
804 cube::Arithmetic::Log1p(op) => instructions.push(wgsl::Instruction::Log1p {
805 input: self.compile_variable(op.input),
806 out: self.compile_variable(out),
807 }),
808 cube::Arithmetic::Cos(op) => instructions.push(wgsl::Instruction::Cos {
809 input: self.compile_variable(op.input),
810 out: self.compile_variable(out),
811 }),
812 cube::Arithmetic::Sin(op) => instructions.push(wgsl::Instruction::Sin {
813 input: self.compile_variable(op.input),
814 out: self.compile_variable(out),
815 }),
816 cube::Arithmetic::Tan(op) => instructions.push(wgsl::Instruction::Tan {
817 input: self.compile_variable(op.input),
818 out: self.compile_variable(out),
819 }),
820 cube::Arithmetic::Tanh(op) => instructions.push(wgsl::Instruction::Tanh {
821 input: self.compile_variable(op.input),
822 out: self.compile_variable(out),
823 }),
824 cube::Arithmetic::Sinh(op) => instructions.push(wgsl::Instruction::Sinh {
825 input: self.compile_variable(op.input),
826 out: self.compile_variable(out),
827 }),
828 cube::Arithmetic::Cosh(op) => instructions.push(wgsl::Instruction::Cosh {
829 input: self.compile_variable(op.input),
830 out: self.compile_variable(out),
831 }),
832 cube::Arithmetic::ArcCos(op) => instructions.push(wgsl::Instruction::ArcCos {
833 input: self.compile_variable(op.input),
834 out: self.compile_variable(out),
835 }),
836 cube::Arithmetic::ArcSin(op) => instructions.push(wgsl::Instruction::ArcSin {
837 input: self.compile_variable(op.input),
838 out: self.compile_variable(out),
839 }),
840 cube::Arithmetic::ArcTan(op) => instructions.push(wgsl::Instruction::ArcTan {
841 input: self.compile_variable(op.input),
842 out: self.compile_variable(out),
843 }),
844 cube::Arithmetic::ArcSinh(op) => instructions.push(wgsl::Instruction::ArcSinh {
845 input: self.compile_variable(op.input),
846 out: self.compile_variable(out),
847 }),
848 cube::Arithmetic::ArcCosh(op) => instructions.push(wgsl::Instruction::ArcCosh {
849 input: self.compile_variable(op.input),
850 out: self.compile_variable(out),
851 }),
852 cube::Arithmetic::ArcTanh(op) => instructions.push(wgsl::Instruction::ArcTanh {
853 input: self.compile_variable(op.input),
854 out: self.compile_variable(out),
855 }),
856 cube::Arithmetic::Degrees(op) => instructions.push(wgsl::Instruction::Degrees {
857 input: self.compile_variable(op.input),
858 out: self.compile_variable(out),
859 }),
860 cube::Arithmetic::Radians(op) => instructions.push(wgsl::Instruction::Radians {
861 input: self.compile_variable(op.input),
862 out: self.compile_variable(out),
863 }),
864 cube::Arithmetic::ArcTan2(op) => instructions.push(wgsl::Instruction::ArcTan2 {
865 lhs: self.compile_variable(op.lhs),
866 rhs: self.compile_variable(op.rhs),
867 out: self.compile_variable(out),
868 }),
869 cube::Arithmetic::Powf(op) | cube::Arithmetic::Powi(op) => {
871 instructions.push(wgsl::Instruction::Powf {
872 lhs: self.compile_variable(op.lhs),
873 rhs: self.compile_variable(op.rhs),
874 out: self.compile_variable(out),
875 })
876 }
877 cube::Arithmetic::Sqrt(op) => instructions.push(wgsl::Instruction::Sqrt {
878 input: self.compile_variable(op.input),
879 out: self.compile_variable(out),
880 }),
881 cube::Arithmetic::InverseSqrt(op) => {
882 instructions.push(wgsl::Instruction::InverseSqrt {
883 input: self.compile_variable(op.input),
884 out: self.compile_variable(out),
885 })
886 }
887 cube::Arithmetic::Round(op) => instructions.push(wgsl::Instruction::Round {
888 input: self.compile_variable(op.input),
889 out: self.compile_variable(out),
890 }),
891 cube::Arithmetic::Floor(op) => instructions.push(wgsl::Instruction::Floor {
892 input: self.compile_variable(op.input),
893 out: self.compile_variable(out),
894 }),
895 cube::Arithmetic::Ceil(op) => instructions.push(wgsl::Instruction::Ceil {
896 input: self.compile_variable(op.input),
897 out: self.compile_variable(out),
898 }),
899 cube::Arithmetic::Trunc(op) => instructions.push(wgsl::Instruction::Trunc {
900 input: self.compile_variable(op.input),
901 out: self.compile_variable(out),
902 }),
903 cube::Arithmetic::Erf(op) => {
904 let mut scope = scope.child();
905 expand_erf(&mut scope, op.input, out);
906 instructions.extend(self.compile_scope(&mut scope));
907 }
908 cube::Arithmetic::MulHi(op) => {
909 let mut scope = scope.child();
910 match self.compilation_options.supports_u64 {
911 true => expand_himul_64(&mut scope, op.lhs, op.rhs, out),
912 false => expand_himul_sim(&mut scope, op.lhs, op.rhs, out),
913 }
914 instructions.extend(self.compile_scope(&mut scope));
915 }
916 cube::Arithmetic::Recip(op) => instructions.push(wgsl::Instruction::Recip {
917 input: self.compile_variable(op.input),
918 out: self.compile_variable(out),
919 }),
920 cube::Arithmetic::Clamp(op) => instructions.push(wgsl::Instruction::Clamp {
921 input: self.compile_variable(op.input),
922 min_value: self.compile_variable(op.min_value),
923 max_value: self.compile_variable(op.max_value),
924 out: self.compile_variable(out),
925 }),
926 cube::Arithmetic::Remainder(op) => instructions.push(wgsl::Instruction::Remainder {
927 lhs: self.compile_variable(op.lhs),
928 rhs: self.compile_variable(op.rhs),
929 out: self.compile_variable(out),
930 }),
931 cube::Arithmetic::Neg(op) => instructions.push(wgsl::Instruction::Negate {
932 input: self.compile_variable(op.input),
933 out: self.compile_variable(out),
934 }),
935 cube::Arithmetic::Magnitude(op) => instructions.push(wgsl::Instruction::Magnitude {
936 input: self.compile_variable(op.input),
937 out: self.compile_variable(out),
938 }),
939 cube::Arithmetic::Normalize(op) => instructions.push(wgsl::Instruction::Normalize {
940 input: self.compile_variable(op.input),
941 out: self.compile_variable(out),
942 }),
943 cube::Arithmetic::Dot(op) => instructions.push(wgsl::Instruction::Dot {
944 lhs: self.compile_variable(op.lhs),
945 rhs: self.compile_variable(op.rhs),
946 out: self.compile_variable(out),
947 }),
948 }
949 }
950
951 fn compile_cmp(
952 &mut self,
953 value: cube::Comparison,
954 out: Option<cube::Variable>,
955 instructions: &mut Vec<wgsl::Instruction>,
956 ) {
957 let out = out.unwrap();
958 match value {
959 cube::Comparison::Equal(op) => instructions.push(wgsl::Instruction::Equal {
960 lhs: self.compile_variable(op.lhs),
961 rhs: self.compile_variable(op.rhs),
962 out: self.compile_variable(out),
963 }),
964 cube::Comparison::Lower(op) => instructions.push(wgsl::Instruction::Lower {
965 lhs: self.compile_variable(op.lhs),
966 rhs: self.compile_variable(op.rhs),
967 out: self.compile_variable(out),
968 }),
969 cube::Comparison::Greater(op) => instructions.push(wgsl::Instruction::Greater {
970 lhs: self.compile_variable(op.lhs),
971 rhs: self.compile_variable(op.rhs),
972 out: self.compile_variable(out),
973 }),
974 cube::Comparison::LowerEqual(op) => instructions.push(wgsl::Instruction::LowerEqual {
975 lhs: self.compile_variable(op.lhs),
976 rhs: self.compile_variable(op.rhs),
977 out: self.compile_variable(out),
978 }),
979 cube::Comparison::GreaterEqual(op) => {
980 instructions.push(wgsl::Instruction::GreaterEqual {
981 lhs: self.compile_variable(op.lhs),
982 rhs: self.compile_variable(op.rhs),
983 out: self.compile_variable(out),
984 })
985 }
986 cube::Comparison::NotEqual(op) => instructions.push(wgsl::Instruction::NotEqual {
987 lhs: self.compile_variable(op.lhs),
988 rhs: self.compile_variable(op.rhs),
989 out: self.compile_variable(out),
990 }),
991 cube::Comparison::IsNan(op) => instructions.push(wgsl::Instruction::IsNan {
992 input: self.compile_variable(op.input),
993 out: self.compile_variable(out),
994 }),
995 cube::Comparison::IsInf(op) => instructions.push(wgsl::Instruction::IsInf {
996 input: self.compile_variable(op.input),
997 out: self.compile_variable(out),
998 }),
999 }
1000 }
1001
1002 fn compile_bitwise(
1003 &mut self,
1004 value: cube::Bitwise,
1005 out: Option<cube::Variable>,
1006 instructions: &mut Vec<wgsl::Instruction>,
1007 ) {
1008 let out = out.unwrap();
1009 match value {
1010 cube::Bitwise::BitwiseOr(op) => instructions.push(wgsl::Instruction::BitwiseOr {
1011 lhs: self.compile_variable(op.lhs),
1012 rhs: self.compile_variable(op.rhs),
1013 out: self.compile_variable(out),
1014 }),
1015 cube::Bitwise::BitwiseAnd(op) => instructions.push(wgsl::Instruction::BitwiseAnd {
1016 lhs: self.compile_variable(op.lhs),
1017 rhs: self.compile_variable(op.rhs),
1018 out: self.compile_variable(out),
1019 }),
1020 cube::Bitwise::BitwiseXor(op) => instructions.push(wgsl::Instruction::BitwiseXor {
1021 lhs: self.compile_variable(op.lhs),
1022 rhs: self.compile_variable(op.rhs),
1023 out: self.compile_variable(out),
1024 }),
1025 cube::Bitwise::CountOnes(op) => instructions.push(wgsl::Instruction::CountBits {
1026 input: self.compile_variable(op.input),
1027 out: self.compile_variable(out),
1028 }),
1029 cube::Bitwise::ReverseBits(op) => instructions.push(wgsl::Instruction::ReverseBits {
1030 input: self.compile_variable(op.input),
1031 out: self.compile_variable(out),
1032 }),
1033 cube::Bitwise::ShiftLeft(op) => instructions.push(wgsl::Instruction::ShiftLeft {
1034 lhs: self.compile_variable(op.lhs),
1035 rhs: self.compile_variable(op.rhs),
1036 out: self.compile_variable(out),
1037 }),
1038 cube::Bitwise::ShiftRight(op) => instructions.push(wgsl::Instruction::ShiftRight {
1039 lhs: self.compile_variable(op.lhs),
1040 rhs: self.compile_variable(op.rhs),
1041 out: self.compile_variable(out),
1042 }),
1043 cube::Bitwise::BitwiseNot(op) => instructions.push(wgsl::Instruction::BitwiseNot {
1044 input: self.compile_variable(op.input),
1045 out: self.compile_variable(out),
1046 }),
1047 cube::Bitwise::LeadingZeros(op) => instructions.push(wgsl::Instruction::LeadingZeros {
1048 input: self.compile_variable(op.input),
1049 out: self.compile_variable(out),
1050 }),
1051 cube::Bitwise::FindFirstSet(op) => instructions.push(wgsl::Instruction::FindFirstSet {
1052 input: self.compile_variable(op.input),
1053 out: self.compile_variable(out),
1054 }),
1055 }
1056 }
1057
1058 fn compile_operator(
1059 &mut self,
1060 value: cube::Operator,
1061 out: Option<cube::Variable>,
1062 instructions: &mut Vec<wgsl::Instruction>,
1063 ) {
1064 let out = out.unwrap();
1065 match value {
1066 cube::Operator::Cast(op) => instructions.push(wgsl::Instruction::Assign {
1067 input: self.compile_variable(op.input),
1068 out: self.compile_variable(out),
1069 }),
1070 cube::Operator::Index(op) | cube::Operator::UncheckedIndex(op) => {
1071 instructions.push(wgsl::Instruction::Index {
1072 lhs: self.compile_variable(op.list),
1073 rhs: self.compile_variable(op.index),
1074 out: self.compile_variable(out),
1075 });
1076 }
1077 cube::Operator::IndexAssign(op) | cube::Operator::UncheckedIndexAssign(op) => {
1078 instructions.push(wgsl::Instruction::IndexAssign {
1079 index: self.compile_variable(op.index),
1080 rhs: self.compile_variable(op.value),
1081 out: self.compile_variable(out),
1082 })
1083 }
1084 cube::Operator::And(op) => instructions.push(wgsl::Instruction::And {
1085 lhs: self.compile_variable(op.lhs),
1086 rhs: self.compile_variable(op.rhs),
1087 out: self.compile_variable(out),
1088 }),
1089 cube::Operator::Or(op) => instructions.push(wgsl::Instruction::Or {
1090 lhs: self.compile_variable(op.lhs),
1091 rhs: self.compile_variable(op.rhs),
1092 out: self.compile_variable(out),
1093 }),
1094 cube::Operator::Not(op) => instructions.push(wgsl::Instruction::Not {
1095 input: self.compile_variable(op.input),
1096 out: self.compile_variable(out),
1097 }),
1098 cube::Operator::Reinterpret(op) => instructions.push(wgsl::Instruction::Bitcast {
1099 input: self.compile_variable(op.input),
1100 out: self.compile_variable(out),
1101 }),
1102 cube::Operator::InitLine(op) => instructions.push(wgsl::Instruction::VecInit {
1103 inputs: op
1104 .inputs
1105 .into_iter()
1106 .map(|var| self.compile_variable(var))
1107 .collect(),
1108 out: self.compile_variable(out),
1109 }),
1110 cube::Operator::CopyMemory(op) => instructions.push(wgsl::Instruction::Copy {
1111 input: self.compile_variable(op.input),
1112 in_index: self.compile_variable(op.in_index),
1113 out: self.compile_variable(out),
1114 out_index: self.compile_variable(op.out_index),
1115 }),
1116 cube::Operator::CopyMemoryBulk(op) => instructions.push(wgsl::Instruction::CopyBulk {
1117 input: self.compile_variable(op.input),
1118 in_index: self.compile_variable(op.in_index),
1119 out: self.compile_variable(out),
1120 out_index: self.compile_variable(op.out_index),
1121 len: op.len,
1122 }),
1123 cube::Operator::Select(op) => instructions.push(wgsl::Instruction::Select {
1124 cond: self.compile_variable(op.cond),
1125 then: self.compile_variable(op.then),
1126 or_else: self.compile_variable(op.or_else),
1127 out: self.compile_variable(out),
1128 }),
1129 }
1130 }
1131
1132 fn compile_atomic(
1133 &mut self,
1134 atomic: cube::AtomicOp,
1135 out: Option<cube::Variable>,
1136 ) -> wgsl::Instruction {
1137 let out = out.unwrap();
1138 match atomic {
1139 cube::AtomicOp::Add(op) => wgsl::Instruction::AtomicAdd {
1140 lhs: self.compile_variable(op.lhs),
1141 rhs: self.compile_variable(op.rhs),
1142 out: self.compile_variable(out),
1143 },
1144 cube::AtomicOp::Sub(op) => wgsl::Instruction::AtomicSub {
1145 lhs: self.compile_variable(op.lhs),
1146 rhs: self.compile_variable(op.rhs),
1147 out: self.compile_variable(out),
1148 },
1149 cube::AtomicOp::Max(op) => wgsl::Instruction::AtomicMax {
1150 lhs: self.compile_variable(op.lhs),
1151 rhs: self.compile_variable(op.rhs),
1152 out: self.compile_variable(out),
1153 },
1154 cube::AtomicOp::Min(op) => wgsl::Instruction::AtomicMin {
1155 lhs: self.compile_variable(op.lhs),
1156 rhs: self.compile_variable(op.rhs),
1157 out: self.compile_variable(out),
1158 },
1159 cube::AtomicOp::And(op) => wgsl::Instruction::AtomicAnd {
1160 lhs: self.compile_variable(op.lhs),
1161 rhs: self.compile_variable(op.rhs),
1162 out: self.compile_variable(out),
1163 },
1164 cube::AtomicOp::Or(op) => wgsl::Instruction::AtomicOr {
1165 lhs: self.compile_variable(op.lhs),
1166 rhs: self.compile_variable(op.rhs),
1167 out: self.compile_variable(out),
1168 },
1169 cube::AtomicOp::Xor(op) => wgsl::Instruction::AtomicXor {
1170 lhs: self.compile_variable(op.lhs),
1171 rhs: self.compile_variable(op.rhs),
1172 out: self.compile_variable(out),
1173 },
1174 cube::AtomicOp::Load(op) => wgsl::Instruction::AtomicLoad {
1175 input: self.compile_variable(op.input),
1176 out: self.compile_variable(out),
1177 },
1178 cube::AtomicOp::Store(op) => wgsl::Instruction::AtomicStore {
1179 input: self.compile_variable(op.input),
1180 out: self.compile_variable(out),
1181 },
1182 cube::AtomicOp::Swap(op) => wgsl::Instruction::AtomicSwap {
1183 lhs: self.compile_variable(op.lhs),
1184 rhs: self.compile_variable(op.rhs),
1185 out: self.compile_variable(out),
1186 },
1187 cube::AtomicOp::CompareAndSwap(op) => wgsl::Instruction::AtomicCompareExchangeWeak {
1188 lhs: self.compile_variable(op.input),
1189 cmp: self.compile_variable(op.cmp),
1190 value: self.compile_variable(op.val),
1191 out: self.compile_variable(out),
1192 },
1193 }
1194 }
1195
1196 fn compile_location(value: kernel::Location) -> wgsl::Location {
1197 match value {
1198 kernel::Location::Storage => wgsl::Location::Storage,
1199 kernel::Location::Cube => wgsl::Location::Workgroup,
1200 }
1201 }
1202
1203 fn compile_binding(&mut self, value: kernel::Binding) -> wgsl::Binding {
1204 wgsl::Binding {
1205 id: value.id,
1206 visibility: value.visibility,
1207 location: Self::compile_location(value.location),
1208 item: self.compile_type(value.ty),
1209 size: value.size,
1210 }
1211 }
1212}
1213
1214fn register_extensions(instructions: &[wgsl::Instruction]) -> Vec<wgsl::Extension> {
1215 let mut extensions = Vec::new();
1216
1217 let mut register_extension = |extension: wgsl::Extension| {
1218 if !extensions.contains(&extension) {
1219 extensions.push(extension);
1220 }
1221 };
1222
1223 for instruction in instructions {
1225 match instruction {
1226 wgsl::Instruction::Powf { lhs: _, rhs, out } => {
1227 register_extension(wgsl::Extension::PowfPrimitive(out.elem()));
1228 register_extension(wgsl::powf_extension(rhs, out));
1229 }
1230 #[cfg(target_os = "macos")]
1231 wgsl::Instruction::Tanh { input, out: _ } => {
1232 register_extension(wgsl::Extension::SafeTanhPrimitive(input.elem()));
1233 register_extension(wgsl::Extension::SafeTanh(input.item()));
1234 }
1235 wgsl::Instruction::IsNan { input, out } => {
1236 register_extension(wgsl::Extension::IsNanPrimitive(input.elem()));
1237 register_extension(wgsl::Extension::IsNan(input.item(), out.item()));
1238 }
1239 wgsl::Instruction::IsInf { input, out } => {
1240 register_extension(wgsl::Extension::IsInfPrimitive(input.elem()));
1241 register_extension(wgsl::Extension::IsInf(input.item(), out.item()));
1242 }
1243 wgsl::Instruction::If { instructions, .. } => {
1244 for extension in register_extensions(instructions) {
1245 register_extension(extension);
1246 }
1247 }
1248 wgsl::Instruction::IfElse {
1249 instructions_if,
1250 instructions_else,
1251 ..
1252 } => {
1253 for extension in register_extensions(instructions_if) {
1254 register_extension(extension);
1255 }
1256 for extension in register_extensions(instructions_else) {
1257 register_extension(extension);
1258 }
1259 }
1260 wgsl::Instruction::Loop { instructions } => {
1261 for extension in register_extensions(instructions) {
1262 register_extension(extension);
1263 }
1264 }
1265 wgsl::Instruction::RangeLoop { instructions, .. } => {
1266 for extension in register_extensions(instructions) {
1267 register_extension(extension);
1268 }
1269 }
1270 _ => {}
1271 }
1272 }
1273
1274 extensions
1275}