1use super::Subgroup;
2use super::{ConstantArray, shader::ComputeShader};
3use super::{Item, LocalArray, SharedMemory};
4use crate::compiler::wgsl;
5
6use cubecl_common::ExecutionMode;
7use cubecl_core::post_processing::{
8 checked_io::CheckedIoProcessor, saturating::SaturatingArithmeticProcessor,
9};
10use cubecl_core::prelude::*;
11use cubecl_core::{
12 Metadata, WgpuCompilationOptions,
13 ir::{self as cube, Scope},
14 prelude::expand_erf,
15};
16use cubecl_core::{
17 ir::{ConstantScalarValue, Processor, UIntKind},
18 post_processing::unroll::UnrollProcessor,
19};
20use cubecl_runtime::compiler::CompilationError;
21use cubecl_runtime::kernel;
22
/// Widest line (vector) size this compiler emits, corresponding to `vec4`.
///
/// Buffers and declared variables with a wider line size are clamped to this
/// value, and the unroll processor is configured with it.
pub const MAX_LINE_SIZE: u32 = 4_u32;
24
25#[derive(Clone, Default)]
27pub struct WgslCompiler {
28 metadata: Metadata,
29 ext_meta_pos: Vec<u32>,
30 local_invocation_index: bool,
31 local_invocation_id: bool,
32 global_invocation_id: bool,
34 workgroup_id: bool,
35 subgroup_size: bool,
36 subgroup_invocation_id: bool,
37 id: bool,
38 num_workgroups: bool,
39 workgroup_id_no_axis: bool,
40 workgroup_size_no_axis: bool,
41 num_workgroup_no_axis: bool,
42 shared_memories: Vec<SharedMemory>,
43 const_arrays: Vec<ConstantArray>,
44 local_arrays: Vec<LocalArray>,
45 #[allow(dead_code)]
46 compilation_options: WgpuCompilationOptions,
47 strategy: ExecutionMode,
48 subgroup_instructions_used: bool,
49 f16_used: bool,
50}
51
52impl core::fmt::Debug for WgslCompiler {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 f.write_str("WgslCompiler")
55 }
56}
57
58impl cubecl_core::Compiler for WgslCompiler {
59 type Representation = ComputeShader;
60 type CompilationOptions = WgpuCompilationOptions;
61
62 fn compile(
63 &mut self,
64 shader: kernel::KernelDefinition,
65 compilation_options: &Self::CompilationOptions,
66 mode: ExecutionMode,
67 ) -> Result<Self::Representation, CompilationError> {
68 self.compilation_options = compilation_options.clone();
69 self.compile_shader(shader, mode)
70 }
71
72 fn elem_size(&self, elem: cube::ElemType) -> usize {
73 elem.size()
74 }
75
76 fn extension(&self) -> &'static str {
77 "wgsl"
78 }
79}
80
81impl WgslCompiler {
82 fn compile_shader(
83 &mut self,
84 mut value: kernel::KernelDefinition,
85 mode: ExecutionMode,
86 ) -> Result<wgsl::ComputeShader, CompilationError> {
87 self.strategy = mode;
88
89 let num_meta = value.buffers.len();
90
91 self.ext_meta_pos = Vec::new();
92 let mut num_ext = 0;
93
94 for binding in value.buffers.iter() {
95 self.ext_meta_pos.push(num_ext);
96 if binding.has_extended_meta {
97 num_ext += 1;
98 }
99 }
100
101 self.metadata = Metadata::new(num_meta as u32, num_ext);
102
103 let instructions = self.compile_scope(&mut value.body);
104 let extensions = register_extensions(&instructions);
105 let body = wgsl::Body {
106 instructions,
107 id: self.id,
108 };
109
110 Ok(wgsl::ComputeShader {
111 buffers: value
112 .buffers
113 .into_iter()
114 .map(|mut it| {
115 if it.ty.line_size() > MAX_LINE_SIZE {
118 it.ty = it.ty.line(MAX_LINE_SIZE);
119 }
120 self.compile_binding(it)
121 })
122 .collect(),
123 scalars: value
124 .scalars
125 .into_iter()
126 .map(|binding| (self.compile_storage_type(binding.ty), binding.count))
127 .collect(),
128 shared_memories: self.shared_memories.clone(),
129 constant_arrays: self.const_arrays.clone(),
130 local_arrays: self.local_arrays.clone(),
131 has_metadata: self.metadata.static_len() > 0,
132 workgroup_size: value.cube_dim,
133 global_invocation_id: self.global_invocation_id || self.id,
134 local_invocation_index: self.local_invocation_index,
135 local_invocation_id: self.local_invocation_id,
136 num_workgroups: self.id
137 || self.num_workgroups
138 || self.num_workgroup_no_axis
139 || self.workgroup_id_no_axis,
140 workgroup_id: self.workgroup_id || self.workgroup_id_no_axis,
141 subgroup_size: self.subgroup_size,
142 subgroup_invocation_id: self.subgroup_invocation_id,
143 body,
144 extensions,
145 num_workgroups_no_axis: self.num_workgroup_no_axis,
146 workgroup_id_no_axis: self.workgroup_id_no_axis,
147 workgroup_size_no_axis: self.workgroup_size_no_axis,
148 subgroup_instructions_used: self.subgroup_instructions_used,
149 f16_used: self.f16_used,
150 kernel_name: value.options.kernel_name,
151 })
152 }
153
154 fn compile_type(&mut self, item: cube::Type) -> Item {
155 match item {
156 cube::Type::Scalar(ty) => wgsl::Item::Scalar(self.compile_storage_type(ty)),
157 cube::Type::Line(ty, size) => {
158 let elem = self.compile_storage_type(ty);
159 match size {
160 2 => wgsl::Item::Vec2(elem),
161 3 => wgsl::Item::Vec3(elem),
162 4 => wgsl::Item::Vec4(elem),
163 _ => panic!("Unsupported vectorizations scheme {:?}", item.line_size()),
164 }
165 }
166 cube::Type::Semantic(_) => unimplemented!("Can't compile semantic type"),
167 }
168 }
169
170 fn compile_storage_type(&mut self, ty: cube::StorageType) -> wgsl::Elem {
171 match ty {
172 cube::StorageType::Scalar(ty) => self.compile_elem(ty),
173 cube::StorageType::Atomic(ty) => match ty {
174 cube::ElemType::Float(i) => match i {
175 cube::FloatKind::F32 => wgsl::Elem::AtomicF32,
176 kind => panic!("atomic<{kind:?}> is not a valid WgpuElement"),
177 },
178 cube::ElemType::Int(i) => match i {
179 cube::IntKind::I32 => wgsl::Elem::AtomicI32,
180 kind => panic!("atomic<{kind:?}> is not a valid WgpuElement"),
181 },
182 cube::ElemType::UInt(kind) => match kind {
183 cube::UIntKind::U32 => wgsl::Elem::AtomicU32,
184 kind => panic!("{kind:?} is not a valid WgpuElement"),
185 },
186 other => panic!("{other:?} is not a valid WgpuElement"),
187 },
188 cube::StorageType::Packed(_, _) => {
189 unimplemented!("Packed types not yet supported in WGSL")
190 }
191 }
192 }
193
194 fn compile_elem(&mut self, value: cube::ElemType) -> wgsl::Elem {
195 match value {
196 cube::ElemType::Float(f) => match f {
197 cube::FloatKind::E2M1
198 | cube::FloatKind::E2M3
199 | cube::FloatKind::E3M2
200 | cube::FloatKind::E4M3
201 | cube::FloatKind::E5M2
202 | cube::FloatKind::UE8M0 => panic!("Minifloat is not a valid WgpuElement"),
203 cube::FloatKind::F16 => {
204 self.f16_used = true;
205 wgsl::Elem::F16
206 }
207 cube::FloatKind::BF16 => panic!("bf16 is not a valid WgpuElement"),
208 cube::FloatKind::TF32 => panic!("tf32 is not a valid WgpuElement"),
209 cube::FloatKind::Flex32 => wgsl::Elem::F32,
210 cube::FloatKind::F32 => wgsl::Elem::F32,
211 cube::FloatKind::F64 => wgsl::Elem::F64,
212 },
213 cube::ElemType::Int(i) => match i {
214 cube::IntKind::I32 => wgsl::Elem::I32,
215 cube::IntKind::I64 => wgsl::Elem::I64,
216 kind => panic!("{kind:?} is not a valid WgpuElement"),
217 },
218 cube::ElemType::UInt(kind) => match kind {
219 cube::UIntKind::U32 => wgsl::Elem::U32,
220 cube::UIntKind::U64 => wgsl::Elem::U64,
221 kind => panic!("{kind:?} is not a valid WgpuElement"),
222 },
223 cube::ElemType::Bool => wgsl::Elem::Bool,
224 }
225 }
226
227 fn ext_meta_pos(&self, var: &cube::Variable) -> u32 {
228 let pos = var.index().expect("Variable should have index");
229 self.ext_meta_pos[pos as usize]
230 }
231
232 pub(crate) fn compile_variable(&mut self, value: cube::Variable) -> wgsl::Variable {
233 let item = value.ty;
234 match value.kind {
235 cube::VariableKind::GlobalInputArray(id) => {
236 wgsl::Variable::GlobalInputArray(id, self.compile_type(item))
237 }
238 cube::VariableKind::GlobalScalar(id) => {
239 wgsl::Variable::GlobalScalar(id, self.compile_storage_type(item.storage_type()))
240 }
241 cube::VariableKind::LocalMut { id } | cube::VariableKind::Versioned { id, .. } => {
242 wgsl::Variable::LocalMut {
243 id,
244 item: self.compile_type(item),
245 }
246 }
247 cube::VariableKind::LocalConst { id } => wgsl::Variable::LocalConst {
248 id,
249 item: self.compile_type(item),
250 },
251 cube::VariableKind::GlobalOutputArray(id) => {
252 wgsl::Variable::GlobalOutputArray(id, self.compile_type(item))
253 }
254 cube::VariableKind::ConstantScalar(value) => {
255 wgsl::Variable::ConstantScalar(value, self.compile_elem(value.elem_type()))
256 }
257 cube::VariableKind::SharedMemory {
258 id,
259 length,
260 unroll_factor,
261 alignment,
262 } => {
263 let item = self.compile_type(item);
264 if !self.shared_memories.iter().any(|s| s.index == id) {
265 self.shared_memories.push(SharedMemory::new(
266 id,
267 item,
268 length * unroll_factor,
269 alignment,
270 ));
271 }
272 wgsl::Variable::SharedMemory(id, item, length)
273 }
274 cube::VariableKind::ConstantArray { id, length, .. } => {
275 let item = self.compile_type(item);
276 wgsl::Variable::ConstantArray(id, item, length)
277 }
278 cube::VariableKind::LocalArray {
279 id,
280 length,
281 unroll_factor,
282 } => {
283 let item = self.compile_type(item);
284 if !self.local_arrays.iter().any(|s| s.index == id) {
285 self.local_arrays
286 .push(LocalArray::new(id, item, length * unroll_factor));
287 }
288 wgsl::Variable::LocalArray(id, item, length)
289 }
290 cube::VariableKind::Builtin(builtin) => match builtin {
291 cube::Builtin::AbsolutePos => {
292 self.id = true;
293 wgsl::Variable::Id
294 }
295 cube::Builtin::UnitPos => {
296 self.local_invocation_index = true;
297 wgsl::Variable::LocalInvocationIndex
298 }
299 cube::Builtin::UnitPosX => {
300 self.local_invocation_id = true;
301 wgsl::Variable::LocalInvocationIdX
302 }
303 cube::Builtin::UnitPosY => {
304 self.local_invocation_id = true;
305 wgsl::Variable::LocalInvocationIdY
306 }
307 cube::Builtin::UnitPosZ => {
308 self.local_invocation_id = true;
309 wgsl::Variable::LocalInvocationIdZ
310 }
311 cube::Builtin::CubePosX => {
312 self.workgroup_id = true;
313 wgsl::Variable::WorkgroupIdX
314 }
315 cube::Builtin::CubePosY => {
316 self.workgroup_id = true;
317 wgsl::Variable::WorkgroupIdY
318 }
319 cube::Builtin::CubePosZ => {
320 self.workgroup_id = true;
321 wgsl::Variable::WorkgroupIdZ
322 }
323 cube::Builtin::CubePosCluster
324 | cube::Builtin::CubePosClusterX
325 | cube::Builtin::CubePosClusterY
326 | cube::Builtin::CubePosClusterZ => self.constant_var(1),
327 cube::Builtin::AbsolutePosX => {
328 self.global_invocation_id = true;
329 wgsl::Variable::GlobalInvocationIdX
330 }
331 cube::Builtin::AbsolutePosY => {
332 self.global_invocation_id = true;
333 wgsl::Variable::GlobalInvocationIdY
334 }
335 cube::Builtin::AbsolutePosZ => {
336 self.global_invocation_id = true;
337 wgsl::Variable::GlobalInvocationIdZ
338 }
339 cube::Builtin::CubeDimX => wgsl::Variable::WorkgroupSizeX,
340 cube::Builtin::CubeDimY => wgsl::Variable::WorkgroupSizeY,
341 cube::Builtin::CubeDimZ => wgsl::Variable::WorkgroupSizeZ,
342 cube::Builtin::CubeClusterDim
343 | cube::Builtin::CubeClusterDimX
344 | cube::Builtin::CubeClusterDimY
345 | cube::Builtin::CubeClusterDimZ => self.constant_var(1),
346 cube::Builtin::CubeCountX => {
347 self.num_workgroups = true;
348 wgsl::Variable::NumWorkgroupsX
349 }
350 cube::Builtin::CubeCountY => {
351 self.num_workgroups = true;
352 wgsl::Variable::NumWorkgroupsY
353 }
354 cube::Builtin::CubeCountZ => {
355 self.num_workgroups = true;
356 wgsl::Variable::NumWorkgroupsZ
357 }
358 cube::Builtin::CubePos => {
359 self.workgroup_id_no_axis = true;
360 wgsl::Variable::WorkgroupId
361 }
362 cube::Builtin::CubeDim => {
363 self.workgroup_size_no_axis = true;
364 wgsl::Variable::WorkgroupSize
365 }
366 cube::Builtin::CubeCount => {
367 self.num_workgroup_no_axis = true;
368 wgsl::Variable::NumWorkgroups
369 }
370 cube::Builtin::PlaneDim => {
371 self.subgroup_size = true;
372 wgsl::Variable::SubgroupSize
373 }
374 cube::Builtin::UnitPosPlane => {
375 self.subgroup_invocation_id = true;
376 wgsl::Variable::SubgroupInvocationId
377 }
378 },
379 cube::VariableKind::Matrix { .. } => {
380 panic!("Cooperative matrix-multiply and accumulate not supported.")
381 }
382 cube::VariableKind::Pipeline { .. } => {
383 panic!("Pipeline not supported.")
384 }
385 cube::VariableKind::Barrier { .. } | cube::VariableKind::BarrierToken { .. } => {
386 panic!("Barrier not supported.")
387 }
388 cube::VariableKind::TensorMapInput(_) => panic!("Tensor map not supported."),
389 cube::VariableKind::TensorMapOutput(_) => panic!("Tensor map not supported."),
390 }
391 }
392
393 fn constant_var(&mut self, value: u32) -> wgsl::Variable {
394 let var = cube::Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U32));
395 self.compile_variable(var)
396 }
397
398 fn compile_scope(&mut self, scope: &mut cube::Scope) -> Vec<wgsl::Instruction> {
399 let mut instructions = Vec::new();
400
401 let const_arrays = scope
402 .const_arrays
403 .drain(..)
404 .map(|(var, values)| ConstantArray {
405 index: var.index().unwrap(),
406 item: self.compile_type(var.ty),
407 size: values.len() as u32,
408 values: values
409 .into_iter()
410 .map(|val| self.compile_variable(val))
411 .collect(),
412 })
413 .collect::<Vec<_>>();
414 self.const_arrays.extend(const_arrays);
415
416 let checked_io: Box<dyn Processor> = Box::new(CheckedIoProcessor::new(self.strategy));
417 let unroll = Box::new(UnrollProcessor::new(MAX_LINE_SIZE));
418 let saturating = Box::new(SaturatingArithmeticProcessor::new(true));
419 let processing = scope.process([&*unroll, &*checked_io, &*saturating]);
420
421 for mut var in processing.variables {
422 if var.ty.line_size() > MAX_LINE_SIZE {
423 var.ty = var.ty.line(MAX_LINE_SIZE);
424 }
425 instructions.push(wgsl::Instruction::DeclareVariable {
426 var: self.compile_variable(var),
427 });
428 }
429
430 processing
431 .instructions
432 .into_iter()
433 .for_each(|op| self.compile_operation(&mut instructions, op.operation, op.out, scope));
434
435 instructions
436 }
437
438 fn compile_operation(
439 &mut self,
440 instructions: &mut Vec<wgsl::Instruction>,
441 operation: cube::Operation,
442 out: Option<cube::Variable>,
443 scope: &mut cube::Scope,
444 ) {
445 match operation {
446 cube::Operation::Copy(variable) => instructions.push(wgsl::Instruction::Assign {
447 input: self.compile_variable(variable),
448 out: self.compile_variable(out.unwrap()),
449 }),
450 cube::Operation::Arithmetic(op) => {
451 self.compile_arithmetic(op, out, instructions, scope)
452 }
453 cube::Operation::Comparison(op) => self.compile_cmp(op, out, instructions),
454 cube::Operation::Bitwise(op) => self.compile_bitwise(op, out, instructions),
455 cube::Operation::Operator(op) => self.compile_operator(op, out, instructions),
456 cube::Operation::Atomic(op) => instructions.push(self.compile_atomic(op, out)),
457 cube::Operation::Metadata(op) => instructions.push(self.compile_metadata(op, out)),
458 cube::Operation::Branch(val) => self.compile_branch(instructions, val),
459 cube::Operation::Synchronization(val) => {
460 self.compile_synchronization(instructions, val)
461 }
462 cube::Operation::Plane(op) => self.compile_subgroup(instructions, op, out),
463 cube::Operation::CoopMma(_) => {
464 panic!("Cooperative matrix-multiply and accumulate isn't supported on wgpu.")
465 }
466 cube::Operation::NonSemantic(cube::NonSemantic::Comment { content }) => {
467 self.compile_comment(instructions, content)
468 }
469 cube::Operation::NonSemantic(_) => {}
470 cube::Operation::Barrier(_) => {
471 panic!("Barrier isn't supported on wgpu.")
472 }
473 cube::Operation::Tma(_) => panic!("TMA isn't supported on wgpu."),
474 cube::Operation::Marker(_) => {}
475 }
476 }
477
478 fn compile_subgroup(
479 &mut self,
480 instructions: &mut Vec<wgsl::Instruction>,
481 subgroup: cube::Plane,
482 out: Option<cube::Variable>,
483 ) {
484 self.subgroup_instructions_used = true;
485
486 let out = out.unwrap();
487 let op = match subgroup {
488 cube::Plane::Elect => Subgroup::Elect {
489 out: self.compile_variable(out),
490 },
491 cube::Plane::All(op) => Subgroup::All {
492 input: self.compile_variable(op.input),
493 out: self.compile_variable(out),
494 },
495 cube::Plane::Any(op) => Subgroup::Any {
496 input: self.compile_variable(op.input),
497 out: self.compile_variable(out),
498 },
499 cube::Plane::Ballot(op) => Subgroup::Ballot {
500 input: self.compile_variable(op.input),
501 out: self.compile_variable(out),
502 },
503
504 cube::Plane::Broadcast(op) => Subgroup::Broadcast {
505 lhs: self.compile_variable(op.lhs),
506 rhs: self.compile_variable(op.rhs),
507 out: self.compile_variable(out),
508 },
509
510 cube::Plane::Sum(op) => Subgroup::Sum {
511 input: self.compile_variable(op.input),
512 out: self.compile_variable(out),
513 },
514
515 cube::Plane::ExclusiveSum(op) => Subgroup::ExclusiveSum {
516 input: self.compile_variable(op.input),
517 out: self.compile_variable(out),
518 },
519 cube::Plane::InclusiveSum(op) => Subgroup::InclusiveSum {
520 input: self.compile_variable(op.input),
521 out: self.compile_variable(out),
522 },
523 cube::Plane::Prod(op) => Subgroup::Prod {
524 input: self.compile_variable(op.input),
525 out: self.compile_variable(out),
526 },
527 cube::Plane::ExclusiveProd(op) => Subgroup::ExclusiveProd {
528 input: self.compile_variable(op.input),
529 out: self.compile_variable(out),
530 },
531 cube::Plane::InclusiveProd(op) => Subgroup::InclusiveProd {
532 input: self.compile_variable(op.input),
533 out: self.compile_variable(out),
534 },
535 cube::Plane::Min(op) => Subgroup::Min {
536 input: self.compile_variable(op.input),
537 out: self.compile_variable(out),
538 },
539 cube::Plane::Max(op) => Subgroup::Max {
540 input: self.compile_variable(op.input),
541 out: self.compile_variable(out),
542 },
543 cube::Plane::Shuffle(op) => Subgroup::Shuffle {
544 lhs: self.compile_variable(op.lhs),
545 rhs: self.compile_variable(op.rhs),
546 out: self.compile_variable(out),
547 },
548 cube::Plane::ShuffleXor(op) => Subgroup::ShuffleXor {
549 lhs: self.compile_variable(op.lhs),
550 rhs: self.compile_variable(op.rhs),
551 out: self.compile_variable(out),
552 },
553 cube::Plane::ShuffleUp(op) => Subgroup::ShuffleUp {
554 lhs: self.compile_variable(op.lhs),
555 rhs: self.compile_variable(op.rhs),
556 out: self.compile_variable(out),
557 },
558 cube::Plane::ShuffleDown(op) => Subgroup::ShuffleDown {
559 lhs: self.compile_variable(op.lhs),
560 rhs: self.compile_variable(op.rhs),
561 out: self.compile_variable(out),
562 },
563 };
564
565 instructions.push(wgsl::Instruction::Subgroup(op));
566 }
567
568 fn compile_branch(&mut self, instructions: &mut Vec<wgsl::Instruction>, branch: cube::Branch) {
569 match branch {
570 cube::Branch::If(mut op) => instructions.push(wgsl::Instruction::If {
571 cond: self.compile_variable(op.cond),
572 instructions: self.compile_scope(&mut op.scope),
573 }),
574 cube::Branch::IfElse(mut op) => instructions.push(wgsl::Instruction::IfElse {
575 cond: self.compile_variable(op.cond),
576 instructions_if: self.compile_scope(&mut op.scope_if),
577 instructions_else: self.compile_scope(&mut op.scope_else),
578 }),
579 cube::Branch::Switch(mut op) => instructions.push(wgsl::Instruction::Switch {
580 value: self.compile_variable(op.value),
581 instructions_default: self.compile_scope(&mut op.scope_default),
582 cases: op
583 .cases
584 .into_iter()
585 .map(|(val, mut scope)| {
586 (self.compile_variable(val), self.compile_scope(&mut scope))
587 })
588 .collect(),
589 }),
590 cube::Branch::Return => instructions.push(wgsl::Instruction::Return),
591 cube::Branch::Break => instructions.push(wgsl::Instruction::Break),
592 cube::Branch::RangeLoop(mut range_loop) => {
593 instructions.push(wgsl::Instruction::RangeLoop {
594 i: self.compile_variable(range_loop.i),
595 start: self.compile_variable(range_loop.start),
596 end: self.compile_variable(range_loop.end),
597 step: range_loop.step.map(|it| self.compile_variable(it)),
598 inclusive: range_loop.inclusive,
599 instructions: self.compile_scope(&mut range_loop.scope),
600 })
601 }
602 cube::Branch::Loop(mut op) => instructions.push(wgsl::Instruction::Loop {
603 instructions: self.compile_scope(&mut op.scope),
604 }),
605 };
606 }
607
608 fn compile_synchronization(
609 &mut self,
610 instructions: &mut Vec<wgsl::Instruction>,
611 synchronization: cube::Synchronization,
612 ) {
613 match synchronization {
614 cube::Synchronization::SyncCube => {
615 instructions.push(wgsl::Instruction::WorkgroupBarrier)
616 }
617 cube::Synchronization::SyncPlane => {
618 panic!("Synchronization within a plane is not supported in WGSL")
619 }
620 cube::Synchronization::SyncStorage => {
621 instructions.push(wgsl::Instruction::StorageBarrier)
622 }
623 cube::Synchronization::SyncAsyncProxyShared => panic!("TMA is not supported in WGSL"),
624 };
625 }
626
627 fn compile_comment(&mut self, instructions: &mut Vec<wgsl::Instruction>, content: String) {
628 instructions.push(wgsl::Instruction::Comment { content })
629 }
630
631 fn compile_metadata(
632 &mut self,
633 metadata: cube::Metadata,
634 out: Option<cube::Variable>,
635 ) -> wgsl::Instruction {
636 let out = out.unwrap();
637 match metadata {
638 cube::Metadata::Rank { var } => {
639 let position = self.ext_meta_pos(&var);
640 let offset = self.metadata.rank_index(position);
641 wgsl::Instruction::Metadata {
642 out: self.compile_variable(out),
643 info_offset: self.compile_variable(offset.into()),
644 }
645 }
646 cube::Metadata::Stride { dim, var } => {
647 let position = self.ext_meta_pos(&var);
648 let offset = self.metadata.stride_offset_index(position);
649 wgsl::Instruction::ExtendedMeta {
650 info_offset: self.compile_variable(offset.into()),
651 dim: self.compile_variable(dim),
652 out: self.compile_variable(out),
653 }
654 }
655 cube::Metadata::Shape { dim, var } => {
656 let position = self.ext_meta_pos(&var);
657 let offset = self.metadata.shape_offset_index(position);
658 wgsl::Instruction::ExtendedMeta {
659 info_offset: self.compile_variable(offset.into()),
660 dim: self.compile_variable(dim),
661 out: self.compile_variable(out),
662 }
663 }
664 cube::Metadata::Length { var } => match var.kind {
665 cube::VariableKind::GlobalInputArray(id) => {
666 let offset = self.metadata.len_index(id);
667 wgsl::Instruction::Metadata {
668 out: self.compile_variable(out),
669 info_offset: self.compile_variable(offset.into()),
670 }
671 }
672 cube::VariableKind::GlobalOutputArray(id) => {
673 let offset = self.metadata.len_index(id);
674 wgsl::Instruction::Metadata {
675 out: self.compile_variable(out),
676 info_offset: self.compile_variable(offset.into()),
677 }
678 }
679 _ => wgsl::Instruction::Length {
680 var: self.compile_variable(var),
681 out: self.compile_variable(out),
682 },
683 },
684 cube::Metadata::BufferLength { var } => match var.kind {
685 cube::VariableKind::GlobalInputArray(id) => {
686 let offset = self.metadata.buffer_len_index(id);
687 wgsl::Instruction::Metadata {
688 out: self.compile_variable(out),
689 info_offset: self.compile_variable(offset.into()),
690 }
691 }
692 cube::VariableKind::GlobalOutputArray(id) => {
693 let offset = self.metadata.buffer_len_index(id);
694 wgsl::Instruction::Metadata {
695 out: self.compile_variable(out),
696 info_offset: self.compile_variable(offset.into()),
697 }
698 }
699 _ => wgsl::Instruction::Length {
700 var: self.compile_variable(var),
701 out: self.compile_variable(out),
702 },
703 },
704 }
705 }
706
707 fn compile_arithmetic(
708 &mut self,
709 value: cube::Arithmetic,
710 out: Option<cube::Variable>,
711 instructions: &mut Vec<wgsl::Instruction>,
712 scope: &mut Scope,
713 ) {
714 let out = out.unwrap();
715 match value {
716 cube::Arithmetic::Max(op) => instructions.push(wgsl::Instruction::Max {
717 lhs: self.compile_variable(op.lhs),
718 rhs: self.compile_variable(op.rhs),
719 out: self.compile_variable(out),
720 }),
721 cube::Arithmetic::Min(op) => instructions.push(wgsl::Instruction::Min {
722 lhs: self.compile_variable(op.lhs),
723 rhs: self.compile_variable(op.rhs),
724 out: self.compile_variable(out),
725 }),
726 cube::Arithmetic::Add(op) => instructions.push(wgsl::Instruction::Add {
727 lhs: self.compile_variable(op.lhs),
728 rhs: self.compile_variable(op.rhs),
729 out: self.compile_variable(out),
730 }),
731 cube::Arithmetic::SaturatingAdd(_) => {
732 unreachable!("Saturating add should be removed by processor");
733 }
734 cube::Arithmetic::Fma(op) => instructions.push(wgsl::Instruction::Fma {
735 a: self.compile_variable(op.a),
736 b: self.compile_variable(op.b),
737 c: self.compile_variable(op.c),
738 out: self.compile_variable(out),
739 }),
740 cube::Arithmetic::Modulo(op) => instructions.push(wgsl::Instruction::Modulo {
741 lhs: self.compile_variable(op.lhs),
742 rhs: self.compile_variable(op.rhs),
743 out: self.compile_variable(out),
744 }),
745 cube::Arithmetic::Sub(op) => instructions.push(wgsl::Instruction::Sub {
746 lhs: self.compile_variable(op.lhs),
747 rhs: self.compile_variable(op.rhs),
748 out: self.compile_variable(out),
749 }),
750 cube::Arithmetic::SaturatingSub(_) => {
751 unreachable!("Saturating sub should be removed by processor");
752 }
753 cube::Arithmetic::Mul(op) => instructions.push(wgsl::Instruction::Mul {
754 lhs: self.compile_variable(op.lhs),
755 rhs: self.compile_variable(op.rhs),
756 out: self.compile_variable(out),
757 }),
758 cube::Arithmetic::Div(op) => instructions.push(wgsl::Instruction::Div {
759 lhs: self.compile_variable(op.lhs),
760 rhs: self.compile_variable(op.rhs),
761 out: self.compile_variable(out),
762 }),
763 cube::Arithmetic::Abs(op) => instructions.push(wgsl::Instruction::Abs {
764 input: self.compile_variable(op.input),
765 out: self.compile_variable(out),
766 }),
767 cube::Arithmetic::Exp(op) => instructions.push(wgsl::Instruction::Exp {
768 input: self.compile_variable(op.input),
769 out: self.compile_variable(out),
770 }),
771 cube::Arithmetic::Log(op) => instructions.push(wgsl::Instruction::Log {
772 input: self.compile_variable(op.input),
773 out: self.compile_variable(out),
774 }),
775 cube::Arithmetic::Log1p(op) => instructions.push(wgsl::Instruction::Log1p {
776 input: self.compile_variable(op.input),
777 out: self.compile_variable(out),
778 }),
779 cube::Arithmetic::Cos(op) => instructions.push(wgsl::Instruction::Cos {
780 input: self.compile_variable(op.input),
781 out: self.compile_variable(out),
782 }),
783 cube::Arithmetic::Sin(op) => instructions.push(wgsl::Instruction::Sin {
784 input: self.compile_variable(op.input),
785 out: self.compile_variable(out),
786 }),
787 cube::Arithmetic::Tan(op) => instructions.push(wgsl::Instruction::Tan {
788 input: self.compile_variable(op.input),
789 out: self.compile_variable(out),
790 }),
791 cube::Arithmetic::Tanh(op) => instructions.push(wgsl::Instruction::Tanh {
792 input: self.compile_variable(op.input),
793 out: self.compile_variable(out),
794 }),
795 cube::Arithmetic::Sinh(op) => instructions.push(wgsl::Instruction::Sinh {
796 input: self.compile_variable(op.input),
797 out: self.compile_variable(out),
798 }),
799 cube::Arithmetic::Cosh(op) => instructions.push(wgsl::Instruction::Cosh {
800 input: self.compile_variable(op.input),
801 out: self.compile_variable(out),
802 }),
803 cube::Arithmetic::ArcCos(op) => instructions.push(wgsl::Instruction::ArcCos {
804 input: self.compile_variable(op.input),
805 out: self.compile_variable(out),
806 }),
807 cube::Arithmetic::ArcSin(op) => instructions.push(wgsl::Instruction::ArcSin {
808 input: self.compile_variable(op.input),
809 out: self.compile_variable(out),
810 }),
811 cube::Arithmetic::ArcTan(op) => instructions.push(wgsl::Instruction::ArcTan {
812 input: self.compile_variable(op.input),
813 out: self.compile_variable(out),
814 }),
815 cube::Arithmetic::ArcSinh(op) => instructions.push(wgsl::Instruction::ArcSinh {
816 input: self.compile_variable(op.input),
817 out: self.compile_variable(out),
818 }),
819 cube::Arithmetic::ArcCosh(op) => instructions.push(wgsl::Instruction::ArcCosh {
820 input: self.compile_variable(op.input),
821 out: self.compile_variable(out),
822 }),
823 cube::Arithmetic::ArcTanh(op) => instructions.push(wgsl::Instruction::ArcTanh {
824 input: self.compile_variable(op.input),
825 out: self.compile_variable(out),
826 }),
827 cube::Arithmetic::Degrees(op) => instructions.push(wgsl::Instruction::Degrees {
828 input: self.compile_variable(op.input),
829 out: self.compile_variable(out),
830 }),
831 cube::Arithmetic::Radians(op) => instructions.push(wgsl::Instruction::Radians {
832 input: self.compile_variable(op.input),
833 out: self.compile_variable(out),
834 }),
835 cube::Arithmetic::ArcTan2(op) => instructions.push(wgsl::Instruction::ArcTan2 {
836 lhs: self.compile_variable(op.lhs),
837 rhs: self.compile_variable(op.rhs),
838 out: self.compile_variable(out),
839 }),
840 cube::Arithmetic::Powf(op) | cube::Arithmetic::Powi(op) => {
842 instructions.push(wgsl::Instruction::Powf {
843 lhs: self.compile_variable(op.lhs),
844 rhs: self.compile_variable(op.rhs),
845 out: self.compile_variable(out),
846 })
847 }
848 cube::Arithmetic::Sqrt(op) => instructions.push(wgsl::Instruction::Sqrt {
849 input: self.compile_variable(op.input),
850 out: self.compile_variable(out),
851 }),
852 cube::Arithmetic::InverseSqrt(op) => {
853 instructions.push(wgsl::Instruction::InverseSqrt {
854 input: self.compile_variable(op.input),
855 out: self.compile_variable(out),
856 })
857 }
858 cube::Arithmetic::Round(op) => instructions.push(wgsl::Instruction::Round {
859 input: self.compile_variable(op.input),
860 out: self.compile_variable(out),
861 }),
862 cube::Arithmetic::Floor(op) => instructions.push(wgsl::Instruction::Floor {
863 input: self.compile_variable(op.input),
864 out: self.compile_variable(out),
865 }),
866 cube::Arithmetic::Ceil(op) => instructions.push(wgsl::Instruction::Ceil {
867 input: self.compile_variable(op.input),
868 out: self.compile_variable(out),
869 }),
870 cube::Arithmetic::Trunc(op) => instructions.push(wgsl::Instruction::Trunc {
871 input: self.compile_variable(op.input),
872 out: self.compile_variable(out),
873 }),
874 cube::Arithmetic::Erf(op) => {
875 let mut scope = scope.child();
876 expand_erf(&mut scope, op.input, out);
877 instructions.extend(self.compile_scope(&mut scope));
878 }
879 cube::Arithmetic::MulHi(op) => {
880 let mut scope = scope.child();
881 match self.compilation_options.supports_u64 {
882 true => expand_himul_64(&mut scope, op.lhs, op.rhs, out),
883 false => expand_himul_sim(&mut scope, op.lhs, op.rhs, out),
884 }
885 instructions.extend(self.compile_scope(&mut scope));
886 }
887 cube::Arithmetic::Recip(op) => instructions.push(wgsl::Instruction::Recip {
888 input: self.compile_variable(op.input),
889 out: self.compile_variable(out),
890 }),
891 cube::Arithmetic::Clamp(op) => instructions.push(wgsl::Instruction::Clamp {
892 input: self.compile_variable(op.input),
893 min_value: self.compile_variable(op.min_value),
894 max_value: self.compile_variable(op.max_value),
895 out: self.compile_variable(out),
896 }),
897 cube::Arithmetic::Remainder(op) => instructions.push(wgsl::Instruction::Remainder {
898 lhs: self.compile_variable(op.lhs),
899 rhs: self.compile_variable(op.rhs),
900 out: self.compile_variable(out),
901 }),
902 cube::Arithmetic::Neg(op) => instructions.push(wgsl::Instruction::Negate {
903 input: self.compile_variable(op.input),
904 out: self.compile_variable(out),
905 }),
906 cube::Arithmetic::Magnitude(op) => instructions.push(wgsl::Instruction::Magnitude {
907 input: self.compile_variable(op.input),
908 out: self.compile_variable(out),
909 }),
910 cube::Arithmetic::Normalize(op) => instructions.push(wgsl::Instruction::Normalize {
911 input: self.compile_variable(op.input),
912 out: self.compile_variable(out),
913 }),
914 cube::Arithmetic::Dot(op) => instructions.push(wgsl::Instruction::Dot {
915 lhs: self.compile_variable(op.lhs),
916 rhs: self.compile_variable(op.rhs),
917 out: self.compile_variable(out),
918 }),
919 }
920 }
921
922 fn compile_cmp(
923 &mut self,
924 value: cube::Comparison,
925 out: Option<cube::Variable>,
926 instructions: &mut Vec<wgsl::Instruction>,
927 ) {
928 let out = out.unwrap();
929 match value {
930 cube::Comparison::Equal(op) => instructions.push(wgsl::Instruction::Equal {
931 lhs: self.compile_variable(op.lhs),
932 rhs: self.compile_variable(op.rhs),
933 out: self.compile_variable(out),
934 }),
935 cube::Comparison::Lower(op) => instructions.push(wgsl::Instruction::Lower {
936 lhs: self.compile_variable(op.lhs),
937 rhs: self.compile_variable(op.rhs),
938 out: self.compile_variable(out),
939 }),
940 cube::Comparison::Greater(op) => instructions.push(wgsl::Instruction::Greater {
941 lhs: self.compile_variable(op.lhs),
942 rhs: self.compile_variable(op.rhs),
943 out: self.compile_variable(out),
944 }),
945 cube::Comparison::LowerEqual(op) => instructions.push(wgsl::Instruction::LowerEqual {
946 lhs: self.compile_variable(op.lhs),
947 rhs: self.compile_variable(op.rhs),
948 out: self.compile_variable(out),
949 }),
950 cube::Comparison::GreaterEqual(op) => {
951 instructions.push(wgsl::Instruction::GreaterEqual {
952 lhs: self.compile_variable(op.lhs),
953 rhs: self.compile_variable(op.rhs),
954 out: self.compile_variable(out),
955 })
956 }
957 cube::Comparison::NotEqual(op) => instructions.push(wgsl::Instruction::NotEqual {
958 lhs: self.compile_variable(op.lhs),
959 rhs: self.compile_variable(op.rhs),
960 out: self.compile_variable(out),
961 }),
962 cube::Comparison::IsNan(op) => instructions.push(wgsl::Instruction::IsNan {
963 input: self.compile_variable(op.input),
964 out: self.compile_variable(out),
965 }),
966 cube::Comparison::IsInf(op) => instructions.push(wgsl::Instruction::IsInf {
967 input: self.compile_variable(op.input),
968 out: self.compile_variable(out),
969 }),
970 }
971 }
972
973 fn compile_bitwise(
974 &mut self,
975 value: cube::Bitwise,
976 out: Option<cube::Variable>,
977 instructions: &mut Vec<wgsl::Instruction>,
978 ) {
979 let out = out.unwrap();
980 match value {
981 cube::Bitwise::BitwiseOr(op) => instructions.push(wgsl::Instruction::BitwiseOr {
982 lhs: self.compile_variable(op.lhs),
983 rhs: self.compile_variable(op.rhs),
984 out: self.compile_variable(out),
985 }),
986 cube::Bitwise::BitwiseAnd(op) => instructions.push(wgsl::Instruction::BitwiseAnd {
987 lhs: self.compile_variable(op.lhs),
988 rhs: self.compile_variable(op.rhs),
989 out: self.compile_variable(out),
990 }),
991 cube::Bitwise::BitwiseXor(op) => instructions.push(wgsl::Instruction::BitwiseXor {
992 lhs: self.compile_variable(op.lhs),
993 rhs: self.compile_variable(op.rhs),
994 out: self.compile_variable(out),
995 }),
996 cube::Bitwise::CountOnes(op) => instructions.push(wgsl::Instruction::CountBits {
997 input: self.compile_variable(op.input),
998 out: self.compile_variable(out),
999 }),
1000 cube::Bitwise::ReverseBits(op) => instructions.push(wgsl::Instruction::ReverseBits {
1001 input: self.compile_variable(op.input),
1002 out: self.compile_variable(out),
1003 }),
1004 cube::Bitwise::ShiftLeft(op) => instructions.push(wgsl::Instruction::ShiftLeft {
1005 lhs: self.compile_variable(op.lhs),
1006 rhs: self.compile_variable(op.rhs),
1007 out: self.compile_variable(out),
1008 }),
1009 cube::Bitwise::ShiftRight(op) => instructions.push(wgsl::Instruction::ShiftRight {
1010 lhs: self.compile_variable(op.lhs),
1011 rhs: self.compile_variable(op.rhs),
1012 out: self.compile_variable(out),
1013 }),
1014 cube::Bitwise::BitwiseNot(op) => instructions.push(wgsl::Instruction::BitwiseNot {
1015 input: self.compile_variable(op.input),
1016 out: self.compile_variable(out),
1017 }),
1018 cube::Bitwise::LeadingZeros(op) => instructions.push(wgsl::Instruction::LeadingZeros {
1019 input: self.compile_variable(op.input),
1020 out: self.compile_variable(out),
1021 }),
1022 cube::Bitwise::FindFirstSet(op) => instructions.push(wgsl::Instruction::FindFirstSet {
1023 input: self.compile_variable(op.input),
1024 out: self.compile_variable(out),
1025 }),
1026 }
1027 }
1028
1029 fn compile_operator(
1030 &mut self,
1031 value: cube::Operator,
1032 out: Option<cube::Variable>,
1033 instructions: &mut Vec<wgsl::Instruction>,
1034 ) {
1035 let out = out.unwrap();
1036 match value {
1037 cube::Operator::Cast(op) => instructions.push(wgsl::Instruction::Assign {
1038 input: self.compile_variable(op.input),
1039 out: self.compile_variable(out),
1040 }),
1041 cube::Operator::Index(op) | cube::Operator::UncheckedIndex(op) => {
1042 instructions.push(wgsl::Instruction::Index {
1043 lhs: self.compile_variable(op.list),
1044 rhs: self.compile_variable(op.index),
1045 out: self.compile_variable(out),
1046 });
1047 }
1048 cube::Operator::IndexAssign(op) | cube::Operator::UncheckedIndexAssign(op) => {
1049 instructions.push(wgsl::Instruction::IndexAssign {
1050 index: self.compile_variable(op.index),
1051 rhs: self.compile_variable(op.value),
1052 out: self.compile_variable(out),
1053 })
1054 }
1055 cube::Operator::And(op) => instructions.push(wgsl::Instruction::And {
1056 lhs: self.compile_variable(op.lhs),
1057 rhs: self.compile_variable(op.rhs),
1058 out: self.compile_variable(out),
1059 }),
1060 cube::Operator::Or(op) => instructions.push(wgsl::Instruction::Or {
1061 lhs: self.compile_variable(op.lhs),
1062 rhs: self.compile_variable(op.rhs),
1063 out: self.compile_variable(out),
1064 }),
1065 cube::Operator::Not(op) => instructions.push(wgsl::Instruction::Not {
1066 input: self.compile_variable(op.input),
1067 out: self.compile_variable(out),
1068 }),
1069 cube::Operator::Reinterpret(op) => instructions.push(wgsl::Instruction::Bitcast {
1070 input: self.compile_variable(op.input),
1071 out: self.compile_variable(out),
1072 }),
1073 cube::Operator::InitLine(op) => instructions.push(wgsl::Instruction::VecInit {
1074 inputs: op
1075 .inputs
1076 .into_iter()
1077 .map(|var| self.compile_variable(var))
1078 .collect(),
1079 out: self.compile_variable(out),
1080 }),
1081 cube::Operator::CopyMemory(op) => instructions.push(wgsl::Instruction::Copy {
1082 input: self.compile_variable(op.input),
1083 in_index: self.compile_variable(op.in_index),
1084 out: self.compile_variable(out),
1085 out_index: self.compile_variable(op.out_index),
1086 }),
1087 cube::Operator::CopyMemoryBulk(op) => instructions.push(wgsl::Instruction::CopyBulk {
1088 input: self.compile_variable(op.input),
1089 in_index: self.compile_variable(op.in_index),
1090 out: self.compile_variable(out),
1091 out_index: self.compile_variable(op.out_index),
1092 len: op.len,
1093 }),
1094 cube::Operator::Select(op) => instructions.push(wgsl::Instruction::Select {
1095 cond: self.compile_variable(op.cond),
1096 then: self.compile_variable(op.then),
1097 or_else: self.compile_variable(op.or_else),
1098 out: self.compile_variable(out),
1099 }),
1100 }
1101 }
1102
1103 fn compile_atomic(
1104 &mut self,
1105 atomic: cube::AtomicOp,
1106 out: Option<cube::Variable>,
1107 ) -> wgsl::Instruction {
1108 let out = out.unwrap();
1109 match atomic {
1110 cube::AtomicOp::Add(op) => wgsl::Instruction::AtomicAdd {
1111 lhs: self.compile_variable(op.lhs),
1112 rhs: self.compile_variable(op.rhs),
1113 out: self.compile_variable(out),
1114 },
1115 cube::AtomicOp::Sub(op) => wgsl::Instruction::AtomicSub {
1116 lhs: self.compile_variable(op.lhs),
1117 rhs: self.compile_variable(op.rhs),
1118 out: self.compile_variable(out),
1119 },
1120 cube::AtomicOp::Max(op) => wgsl::Instruction::AtomicMax {
1121 lhs: self.compile_variable(op.lhs),
1122 rhs: self.compile_variable(op.rhs),
1123 out: self.compile_variable(out),
1124 },
1125 cube::AtomicOp::Min(op) => wgsl::Instruction::AtomicMin {
1126 lhs: self.compile_variable(op.lhs),
1127 rhs: self.compile_variable(op.rhs),
1128 out: self.compile_variable(out),
1129 },
1130 cube::AtomicOp::And(op) => wgsl::Instruction::AtomicAnd {
1131 lhs: self.compile_variable(op.lhs),
1132 rhs: self.compile_variable(op.rhs),
1133 out: self.compile_variable(out),
1134 },
1135 cube::AtomicOp::Or(op) => wgsl::Instruction::AtomicOr {
1136 lhs: self.compile_variable(op.lhs),
1137 rhs: self.compile_variable(op.rhs),
1138 out: self.compile_variable(out),
1139 },
1140 cube::AtomicOp::Xor(op) => wgsl::Instruction::AtomicXor {
1141 lhs: self.compile_variable(op.lhs),
1142 rhs: self.compile_variable(op.rhs),
1143 out: self.compile_variable(out),
1144 },
1145 cube::AtomicOp::Load(op) => wgsl::Instruction::AtomicLoad {
1146 input: self.compile_variable(op.input),
1147 out: self.compile_variable(out),
1148 },
1149 cube::AtomicOp::Store(op) => wgsl::Instruction::AtomicStore {
1150 input: self.compile_variable(op.input),
1151 out: self.compile_variable(out),
1152 },
1153 cube::AtomicOp::Swap(op) => wgsl::Instruction::AtomicSwap {
1154 lhs: self.compile_variable(op.lhs),
1155 rhs: self.compile_variable(op.rhs),
1156 out: self.compile_variable(out),
1157 },
1158 cube::AtomicOp::CompareAndSwap(op) => wgsl::Instruction::AtomicCompareExchangeWeak {
1159 lhs: self.compile_variable(op.input),
1160 cmp: self.compile_variable(op.cmp),
1161 value: self.compile_variable(op.val),
1162 out: self.compile_variable(out),
1163 },
1164 }
1165 }
1166
1167 fn compile_location(value: kernel::Location) -> wgsl::Location {
1168 match value {
1169 kernel::Location::Storage => wgsl::Location::Storage,
1170 kernel::Location::Cube => wgsl::Location::Workgroup,
1171 }
1172 }
1173
1174 fn compile_binding(&mut self, value: kernel::Binding) -> wgsl::Binding {
1175 wgsl::Binding {
1176 id: value.id,
1177 visibility: value.visibility,
1178 location: Self::compile_location(value.location),
1179 item: self.compile_type(value.ty),
1180 size: value.size,
1181 }
1182 }
1183}
1184
1185fn register_extensions(instructions: &[wgsl::Instruction]) -> Vec<wgsl::Extension> {
1186 let mut extensions = Vec::new();
1187
1188 let mut register_extension = |extension: wgsl::Extension| {
1189 if !extensions.contains(&extension) {
1190 extensions.push(extension);
1191 }
1192 };
1193
1194 for instruction in instructions {
1196 match instruction {
1197 wgsl::Instruction::Powf { lhs: _, rhs, out } => {
1198 register_extension(wgsl::Extension::PowfPrimitive(out.elem()));
1199 register_extension(wgsl::powf_extension(rhs, out));
1200 }
1201 #[cfg(target_os = "macos")]
1202 wgsl::Instruction::Tanh { input, out: _ } => {
1203 register_extension(wgsl::Extension::SafeTanhPrimitive(input.elem()));
1204 register_extension(wgsl::Extension::SafeTanh(input.item()));
1205 }
1206 wgsl::Instruction::IsNan { input, out } => {
1207 register_extension(wgsl::Extension::IsNanPrimitive(input.elem()));
1208 register_extension(wgsl::Extension::IsNan(input.item(), out.item()));
1209 }
1210 wgsl::Instruction::IsInf { input, out } => {
1211 register_extension(wgsl::Extension::IsInfPrimitive(input.elem()));
1212 register_extension(wgsl::Extension::IsInf(input.item(), out.item()));
1213 }
1214 wgsl::Instruction::If { instructions, .. } => {
1215 for extension in register_extensions(instructions) {
1216 register_extension(extension);
1217 }
1218 }
1219 wgsl::Instruction::IfElse {
1220 instructions_if,
1221 instructions_else,
1222 ..
1223 } => {
1224 for extension in register_extensions(instructions_if) {
1225 register_extension(extension);
1226 }
1227 for extension in register_extensions(instructions_else) {
1228 register_extension(extension);
1229 }
1230 }
1231 wgsl::Instruction::Loop { instructions } => {
1232 for extension in register_extensions(instructions) {
1233 register_extension(extension);
1234 }
1235 }
1236 wgsl::Instruction::RangeLoop { instructions, .. } => {
1237 for extension in register_extensions(instructions) {
1238 register_extension(extension);
1239 }
1240 }
1241 _ => {}
1242 }
1243 }
1244
1245 extensions
1246}