1use super::Subgroup;
2use super::{ConstantArray, shader::ComputeShader};
3use super::{Item, LocalArray, SharedMemory};
4use crate::compiler::wgsl;
5
6use cubecl_common::ExecutionMode;
7use cubecl_core::post_processing::{
8 checked_io::CheckedIoProcessor, saturating::SaturatingArithmeticProcessor,
9};
10use cubecl_core::prelude::*;
11use cubecl_core::{
12 Metadata, WgpuCompilationOptions, compute,
13 ir::{self as cube, Scope},
14 prelude::expand_erf,
15};
16use cubecl_core::{
17 ir::{ConstantScalarValue, Processor, UIntKind},
18 post_processing::unroll::UnrollProcessor,
19};
20
/// Largest line (vector) size emitted in the generated WGSL.
///
/// Buffers and declared variables whose line size exceeds this value are narrowed to it, and
/// the unroll processor is configured with it. `compile_type` only maps line sizes 2, 3 and 4
/// (to `vec2`/`vec3`/`vec4`).
pub const MAX_LINE_SIZE: u32 = 4;
22
23#[derive(Clone, Default)]
25pub struct WgslCompiler {
26 metadata: Metadata,
27 ext_meta_pos: Vec<u32>,
28 local_invocation_index: bool,
29 local_invocation_id: bool,
30 global_invocation_id: bool,
32 workgroup_id: bool,
33 subgroup_size: bool,
34 subgroup_invocation_id: bool,
35 id: bool,
36 num_workgroups: bool,
37 workgroup_id_no_axis: bool,
38 workgroup_size_no_axis: bool,
39 num_workgroup_no_axis: bool,
40 shared_memories: Vec<SharedMemory>,
41 const_arrays: Vec<ConstantArray>,
42 local_arrays: Vec<LocalArray>,
43 #[allow(dead_code)]
44 compilation_options: WgpuCompilationOptions,
45 strategy: ExecutionMode,
46 subgroup_instructions_used: bool,
47 f16_used: bool,
48}
49
50impl core::fmt::Debug for WgslCompiler {
51 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52 f.write_str("WgslCompiler")
53 }
54}
55
56impl cubecl_core::Compiler for WgslCompiler {
57 type Representation = ComputeShader;
58 type CompilationOptions = WgpuCompilationOptions;
59
60 fn compile(
61 &mut self,
62 shader: compute::KernelDefinition,
63 compilation_options: &Self::CompilationOptions,
64 mode: ExecutionMode,
65 ) -> Self::Representation {
66 self.compilation_options = compilation_options.clone();
67 self.compile_shader(shader, mode)
68 }
69
70 fn elem_size(&self, elem: cube::ElemType) -> usize {
71 elem.size()
72 }
73
74 fn extension(&self) -> &'static str {
75 "wgsl"
76 }
77}
78
79impl WgslCompiler {
80 fn compile_shader(
81 &mut self,
82 mut value: compute::KernelDefinition,
83 mode: ExecutionMode,
84 ) -> wgsl::ComputeShader {
85 self.strategy = mode;
86
87 let num_meta = value.buffers.len();
88
89 self.ext_meta_pos = Vec::new();
90 let mut num_ext = 0;
91
92 for binding in value.buffers.iter() {
93 self.ext_meta_pos.push(num_ext);
94 if binding.has_extended_meta {
95 num_ext += 1;
96 }
97 }
98
99 self.metadata = Metadata::new(num_meta as u32, num_ext);
100
101 let instructions = self.compile_scope(&mut value.body);
102 let extensions = register_extensions(&instructions);
103 let body = wgsl::Body {
104 instructions,
105 id: self.id,
106 };
107
108 wgsl::ComputeShader {
109 buffers: value
110 .buffers
111 .into_iter()
112 .map(|mut it| {
113 if it.ty.line_size() > MAX_LINE_SIZE {
116 it.ty = it.ty.line(MAX_LINE_SIZE);
117 }
118 self.compile_binding(it)
119 })
120 .collect(),
121 scalars: value
122 .scalars
123 .into_iter()
124 .map(|binding| (self.compile_storage_type(binding.ty), binding.count))
125 .collect(),
126 shared_memories: self.shared_memories.clone(),
127 constant_arrays: self.const_arrays.clone(),
128 local_arrays: self.local_arrays.clone(),
129 has_metadata: self.metadata.static_len() > 0,
130 workgroup_size: value.cube_dim,
131 global_invocation_id: self.global_invocation_id || self.id,
132 local_invocation_index: self.local_invocation_index,
133 local_invocation_id: self.local_invocation_id,
134 num_workgroups: self.id
135 || self.num_workgroups
136 || self.num_workgroup_no_axis
137 || self.workgroup_id_no_axis,
138 workgroup_id: self.workgroup_id || self.workgroup_id_no_axis,
139 subgroup_size: self.subgroup_size,
140 subgroup_invocation_id: self.subgroup_invocation_id,
141 body,
142 extensions,
143 num_workgroups_no_axis: self.num_workgroup_no_axis,
144 workgroup_id_no_axis: self.workgroup_id_no_axis,
145 workgroup_size_no_axis: self.workgroup_size_no_axis,
146 subgroup_instructions_used: self.subgroup_instructions_used,
147 f16_used: self.f16_used,
148 kernel_name: value.options.kernel_name,
149 }
150 }
151
152 fn compile_type(&mut self, item: cube::Type) -> Item {
153 match item {
154 cube::Type::Scalar(ty) => wgsl::Item::Scalar(self.compile_storage_type(ty)),
155 cube::Type::Line(ty, size) => {
156 let elem = self.compile_storage_type(ty);
157 match size {
158 2 => wgsl::Item::Vec2(elem),
159 3 => wgsl::Item::Vec3(elem),
160 4 => wgsl::Item::Vec4(elem),
161 _ => panic!("Unsupported vectorizations scheme {:?}", item.line_size()),
162 }
163 }
164 cube::Type::Semantic(_) => unimplemented!("Can't compile semantic type"),
165 }
166 }
167
168 fn compile_storage_type(&mut self, ty: cube::StorageType) -> wgsl::Elem {
169 match ty {
170 cube::StorageType::Scalar(ty) => self.compile_elem(ty),
171 cube::StorageType::Atomic(ty) => match ty {
172 cube::ElemType::Float(i) => match i {
173 cube::FloatKind::F32 => wgsl::Elem::AtomicF32,
174 kind => panic!("atomic<{kind:?}> is not a valid WgpuElement"),
175 },
176 cube::ElemType::Int(i) => match i {
177 cube::IntKind::I32 => wgsl::Elem::AtomicI32,
178 kind => panic!("atomic<{kind:?}> is not a valid WgpuElement"),
179 },
180 cube::ElemType::UInt(kind) => match kind {
181 cube::UIntKind::U32 => wgsl::Elem::AtomicU32,
182 kind => panic!("{kind:?} is not a valid WgpuElement"),
183 },
184 other => panic!("{other:?} is not a valid WgpuElement"),
185 },
186 cube::StorageType::Packed(_, _) => {
187 unimplemented!("Packed types not yet supported in WGSL")
188 }
189 }
190 }
191
192 fn compile_elem(&mut self, value: cube::ElemType) -> wgsl::Elem {
193 match value {
194 cube::ElemType::Float(f) => match f {
195 cube::FloatKind::E2M1
196 | cube::FloatKind::E2M3
197 | cube::FloatKind::E3M2
198 | cube::FloatKind::E4M3
199 | cube::FloatKind::E5M2
200 | cube::FloatKind::UE8M0 => panic!("Minifloat is not a valid WgpuElement"),
201 cube::FloatKind::F16 => {
202 self.f16_used = true;
203 wgsl::Elem::F16
204 }
205 cube::FloatKind::BF16 => panic!("bf16 is not a valid WgpuElement"),
206 cube::FloatKind::TF32 => panic!("tf32 is not a valid WgpuElement"),
207 cube::FloatKind::Flex32 => wgsl::Elem::F32,
208 cube::FloatKind::F32 => wgsl::Elem::F32,
209 cube::FloatKind::F64 => wgsl::Elem::F64,
210 },
211 cube::ElemType::Int(i) => match i {
212 cube::IntKind::I32 => wgsl::Elem::I32,
213 cube::IntKind::I64 => wgsl::Elem::I64,
214 kind => panic!("{kind:?} is not a valid WgpuElement"),
215 },
216 cube::ElemType::UInt(kind) => match kind {
217 cube::UIntKind::U32 => wgsl::Elem::U32,
218 cube::UIntKind::U64 => wgsl::Elem::U64,
219 kind => panic!("{kind:?} is not a valid WgpuElement"),
220 },
221 cube::ElemType::Bool => wgsl::Elem::Bool,
222 }
223 }
224
225 fn ext_meta_pos(&self, var: &cube::Variable) -> u32 {
226 let pos = var.index().expect("Variable should have index");
227 self.ext_meta_pos[pos as usize]
228 }
229
230 pub(crate) fn compile_variable(&mut self, value: cube::Variable) -> wgsl::Variable {
231 let item = value.ty;
232 match value.kind {
233 cube::VariableKind::GlobalInputArray(id) => {
234 wgsl::Variable::GlobalInputArray(id, self.compile_type(item))
235 }
236 cube::VariableKind::GlobalScalar(id) => {
237 wgsl::Variable::GlobalScalar(id, self.compile_storage_type(item.storage_type()))
238 }
239 cube::VariableKind::LocalMut { id } | cube::VariableKind::Versioned { id, .. } => {
240 wgsl::Variable::LocalMut {
241 id,
242 item: self.compile_type(item),
243 }
244 }
245 cube::VariableKind::LocalConst { id } => wgsl::Variable::LocalConst {
246 id,
247 item: self.compile_type(item),
248 },
249 cube::VariableKind::GlobalOutputArray(id) => {
250 wgsl::Variable::GlobalOutputArray(id, self.compile_type(item))
251 }
252 cube::VariableKind::ConstantScalar(value) => {
253 wgsl::Variable::ConstantScalar(value, self.compile_elem(value.elem_type()))
254 }
255 cube::VariableKind::SharedMemory {
256 id,
257 length,
258 unroll_factor,
259 alignment,
260 } => {
261 let item = self.compile_type(item);
262 if !self.shared_memories.iter().any(|s| s.index == id) {
263 self.shared_memories.push(SharedMemory::new(
264 id,
265 item,
266 length * unroll_factor,
267 alignment,
268 ));
269 }
270 wgsl::Variable::SharedMemory(id, item, length)
271 }
272 cube::VariableKind::ConstantArray { id, length, .. } => {
273 let item = self.compile_type(item);
274 wgsl::Variable::ConstantArray(id, item, length)
275 }
276 cube::VariableKind::LocalArray {
277 id,
278 length,
279 unroll_factor,
280 } => {
281 let item = self.compile_type(item);
282 if !self.local_arrays.iter().any(|s| s.index == id) {
283 self.local_arrays
284 .push(LocalArray::new(id, item, length * unroll_factor));
285 }
286 wgsl::Variable::LocalArray(id, item, length)
287 }
288 cube::VariableKind::Builtin(builtin) => match builtin {
289 cube::Builtin::AbsolutePos => {
290 self.id = true;
291 wgsl::Variable::Id
292 }
293 cube::Builtin::UnitPos => {
294 self.local_invocation_index = true;
295 wgsl::Variable::LocalInvocationIndex
296 }
297 cube::Builtin::UnitPosX => {
298 self.local_invocation_id = true;
299 wgsl::Variable::LocalInvocationIdX
300 }
301 cube::Builtin::UnitPosY => {
302 self.local_invocation_id = true;
303 wgsl::Variable::LocalInvocationIdY
304 }
305 cube::Builtin::UnitPosZ => {
306 self.local_invocation_id = true;
307 wgsl::Variable::LocalInvocationIdZ
308 }
309 cube::Builtin::CubePosX => {
310 self.workgroup_id = true;
311 wgsl::Variable::WorkgroupIdX
312 }
313 cube::Builtin::CubePosY => {
314 self.workgroup_id = true;
315 wgsl::Variable::WorkgroupIdY
316 }
317 cube::Builtin::CubePosZ => {
318 self.workgroup_id = true;
319 wgsl::Variable::WorkgroupIdZ
320 }
321 cube::Builtin::CubePosCluster
322 | cube::Builtin::CubePosClusterX
323 | cube::Builtin::CubePosClusterY
324 | cube::Builtin::CubePosClusterZ => self.constant_var(1),
325 cube::Builtin::AbsolutePosX => {
326 self.global_invocation_id = true;
327 wgsl::Variable::GlobalInvocationIdX
328 }
329 cube::Builtin::AbsolutePosY => {
330 self.global_invocation_id = true;
331 wgsl::Variable::GlobalInvocationIdY
332 }
333 cube::Builtin::AbsolutePosZ => {
334 self.global_invocation_id = true;
335 wgsl::Variable::GlobalInvocationIdZ
336 }
337 cube::Builtin::CubeDimX => wgsl::Variable::WorkgroupSizeX,
338 cube::Builtin::CubeDimY => wgsl::Variable::WorkgroupSizeY,
339 cube::Builtin::CubeDimZ => wgsl::Variable::WorkgroupSizeZ,
340 cube::Builtin::CubeClusterDim
341 | cube::Builtin::CubeClusterDimX
342 | cube::Builtin::CubeClusterDimY
343 | cube::Builtin::CubeClusterDimZ => self.constant_var(1),
344 cube::Builtin::CubeCountX => {
345 self.num_workgroups = true;
346 wgsl::Variable::NumWorkgroupsX
347 }
348 cube::Builtin::CubeCountY => {
349 self.num_workgroups = true;
350 wgsl::Variable::NumWorkgroupsY
351 }
352 cube::Builtin::CubeCountZ => {
353 self.num_workgroups = true;
354 wgsl::Variable::NumWorkgroupsZ
355 }
356 cube::Builtin::CubePos => {
357 self.workgroup_id_no_axis = true;
358 wgsl::Variable::WorkgroupId
359 }
360 cube::Builtin::CubeDim => {
361 self.workgroup_size_no_axis = true;
362 wgsl::Variable::WorkgroupSize
363 }
364 cube::Builtin::CubeCount => {
365 self.num_workgroup_no_axis = true;
366 wgsl::Variable::NumWorkgroups
367 }
368 cube::Builtin::PlaneDim => {
369 self.subgroup_size = true;
370 wgsl::Variable::SubgroupSize
371 }
372 cube::Builtin::UnitPosPlane => {
373 self.subgroup_invocation_id = true;
374 wgsl::Variable::SubgroupInvocationId
375 }
376 },
377 cube::VariableKind::Matrix { .. } => {
378 panic!("Cooperative matrix-multiply and accumulate not supported.")
379 }
380 cube::VariableKind::Pipeline { .. } => {
381 panic!("Pipeline not supported.")
382 }
383 cube::VariableKind::Barrier { .. } | cube::VariableKind::BarrierToken { .. } => {
384 panic!("Barrier not supported.")
385 }
386 cube::VariableKind::TensorMapInput(_) => panic!("Tensor map not supported."),
387 cube::VariableKind::TensorMapOutput(_) => panic!("Tensor map not supported."),
388 }
389 }
390
391 fn constant_var(&mut self, value: u32) -> wgsl::Variable {
392 let var = cube::Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U32));
393 self.compile_variable(var)
394 }
395
396 fn compile_scope(&mut self, scope: &mut cube::Scope) -> Vec<wgsl::Instruction> {
397 let mut instructions = Vec::new();
398
399 let const_arrays = scope
400 .const_arrays
401 .drain(..)
402 .map(|(var, values)| ConstantArray {
403 index: var.index().unwrap(),
404 item: self.compile_type(var.ty),
405 size: values.len() as u32,
406 values: values
407 .into_iter()
408 .map(|val| self.compile_variable(val))
409 .collect(),
410 })
411 .collect::<Vec<_>>();
412 self.const_arrays.extend(const_arrays);
413
414 let checked_io: Box<dyn Processor> = Box::new(CheckedIoProcessor::new(self.strategy));
415 let unroll = Box::new(UnrollProcessor::new(MAX_LINE_SIZE));
416 let saturating = Box::new(SaturatingArithmeticProcessor::new(true));
417 let processing = scope.process([&*unroll, &*checked_io, &*saturating]);
418
419 for mut var in processing.variables {
420 if var.ty.line_size() > MAX_LINE_SIZE {
421 var.ty = var.ty.line(MAX_LINE_SIZE);
422 }
423 instructions.push(wgsl::Instruction::DeclareVariable {
424 var: self.compile_variable(var),
425 });
426 }
427
428 processing
429 .instructions
430 .into_iter()
431 .for_each(|op| self.compile_operation(&mut instructions, op.operation, op.out, scope));
432
433 instructions
434 }
435
436 fn compile_operation(
437 &mut self,
438 instructions: &mut Vec<wgsl::Instruction>,
439 operation: cube::Operation,
440 out: Option<cube::Variable>,
441 scope: &mut cube::Scope,
442 ) {
443 match operation {
444 cube::Operation::Copy(variable) => instructions.push(wgsl::Instruction::Assign {
445 input: self.compile_variable(variable),
446 out: self.compile_variable(out.unwrap()),
447 }),
448 cube::Operation::Arithmetic(op) => {
449 self.compile_arithmetic(op, out, instructions, scope)
450 }
451 cube::Operation::Comparison(op) => self.compile_cmp(op, out, instructions),
452 cube::Operation::Bitwise(op) => self.compile_bitwise(op, out, instructions),
453 cube::Operation::Operator(op) => self.compile_operator(op, out, instructions),
454 cube::Operation::Atomic(op) => instructions.push(self.compile_atomic(op, out)),
455 cube::Operation::Metadata(op) => instructions.push(self.compile_metadata(op, out)),
456 cube::Operation::Branch(val) => self.compile_branch(instructions, val),
457 cube::Operation::Synchronization(val) => {
458 self.compile_synchronization(instructions, val)
459 }
460 cube::Operation::Plane(op) => self.compile_subgroup(instructions, op, out),
461 cube::Operation::CoopMma(_) => {
462 panic!("Cooperative matrix-multiply and accumulate isn't supported on wgpu.")
463 }
464 cube::Operation::NonSemantic(cube::NonSemantic::Comment { content }) => {
465 self.compile_comment(instructions, content)
466 }
467 cube::Operation::NonSemantic(_) => {}
468 cube::Operation::Barrier(_) => {
469 panic!("Barrier isn't supported on wgpu.")
470 }
471 cube::Operation::Tma(_) => panic!("TMA isn't supported on wgpu."),
472 cube::Operation::Marker(_) => {}
473 }
474 }
475
    /// Lowers a plane (subgroup) operation into a single `Subgroup` instruction.
    ///
    /// Marks the shader as using subgroup instructions before lowering anything. Each IR plane
    /// operation maps one-to-one onto the `Subgroup` variant of the same name.
    ///
    /// # Panics
    /// Panics if `out` is `None`; every plane operation produces a value.
    fn compile_subgroup(
        &mut self,
        instructions: &mut Vec<wgsl::Instruction>,
        subgroup: cube::Plane,
        out: Option<cube::Variable>,
    ) {
        self.subgroup_instructions_used = true;

        let out = out.unwrap();
        // Operands are compiled in field order (inputs before `out`); keep that order, since
        // `compile_variable` may register declarations as a side effect.
        let op = match subgroup {
            cube::Plane::Elect => Subgroup::Elect {
                out: self.compile_variable(out),
            },
            cube::Plane::All(op) => Subgroup::All {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            },
            cube::Plane::Any(op) => Subgroup::Any {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            },
            cube::Plane::Ballot(op) => Subgroup::Ballot {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            },

            cube::Plane::Broadcast(op) => Subgroup::Broadcast {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            },

            cube::Plane::Sum(op) => Subgroup::Sum {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            },

            cube::Plane::ExclusiveSum(op) => Subgroup::ExclusiveSum {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            },
            cube::Plane::InclusiveSum(op) => Subgroup::InclusiveSum {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            },
            cube::Plane::Prod(op) => Subgroup::Prod {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            },
            cube::Plane::ExclusiveProd(op) => Subgroup::ExclusiveProd {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            },
            cube::Plane::InclusiveProd(op) => Subgroup::InclusiveProd {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            },
            cube::Plane::Min(op) => Subgroup::Min {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            },
            cube::Plane::Max(op) => Subgroup::Max {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            },
            // Shuffle family: `lhs` is the value, `rhs` the lane/offset operand.
            // NOTE(review): operand meaning inferred from the field layout; confirm against
            // the Subgroup renderer.
            cube::Plane::Shuffle(op) => Subgroup::Shuffle {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            },
            cube::Plane::ShuffleXor(op) => Subgroup::ShuffleXor {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            },
            cube::Plane::ShuffleUp(op) => Subgroup::ShuffleUp {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            },
            cube::Plane::ShuffleDown(op) => Subgroup::ShuffleDown {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            },
        };

        instructions.push(wgsl::Instruction::Subgroup(op));
    }
565
566 fn compile_branch(&mut self, instructions: &mut Vec<wgsl::Instruction>, branch: cube::Branch) {
567 match branch {
568 cube::Branch::If(mut op) => instructions.push(wgsl::Instruction::If {
569 cond: self.compile_variable(op.cond),
570 instructions: self.compile_scope(&mut op.scope),
571 }),
572 cube::Branch::IfElse(mut op) => instructions.push(wgsl::Instruction::IfElse {
573 cond: self.compile_variable(op.cond),
574 instructions_if: self.compile_scope(&mut op.scope_if),
575 instructions_else: self.compile_scope(&mut op.scope_else),
576 }),
577 cube::Branch::Switch(mut op) => instructions.push(wgsl::Instruction::Switch {
578 value: self.compile_variable(op.value),
579 instructions_default: self.compile_scope(&mut op.scope_default),
580 cases: op
581 .cases
582 .into_iter()
583 .map(|(val, mut scope)| {
584 (self.compile_variable(val), self.compile_scope(&mut scope))
585 })
586 .collect(),
587 }),
588 cube::Branch::Return => instructions.push(wgsl::Instruction::Return),
589 cube::Branch::Break => instructions.push(wgsl::Instruction::Break),
590 cube::Branch::RangeLoop(mut range_loop) => {
591 instructions.push(wgsl::Instruction::RangeLoop {
592 i: self.compile_variable(range_loop.i),
593 start: self.compile_variable(range_loop.start),
594 end: self.compile_variable(range_loop.end),
595 step: range_loop.step.map(|it| self.compile_variable(it)),
596 inclusive: range_loop.inclusive,
597 instructions: self.compile_scope(&mut range_loop.scope),
598 })
599 }
600 cube::Branch::Loop(mut op) => instructions.push(wgsl::Instruction::Loop {
601 instructions: self.compile_scope(&mut op.scope),
602 }),
603 };
604 }
605
606 fn compile_synchronization(
607 &mut self,
608 instructions: &mut Vec<wgsl::Instruction>,
609 synchronization: cube::Synchronization,
610 ) {
611 match synchronization {
612 cube::Synchronization::SyncCube => {
613 instructions.push(wgsl::Instruction::WorkgroupBarrier)
614 }
615 cube::Synchronization::SyncPlane => {
616 panic!("Synchronization within a plane is not supported in WGSL")
617 }
618 cube::Synchronization::SyncStorage => {
619 instructions.push(wgsl::Instruction::StorageBarrier)
620 }
621 cube::Synchronization::SyncAsyncProxyShared => panic!("TMA is not supported in WGSL"),
622 };
623 }
624
625 fn compile_comment(&mut self, instructions: &mut Vec<wgsl::Instruction>, content: String) {
626 instructions.push(wgsl::Instruction::Comment { content })
627 }
628
629 fn compile_metadata(
630 &mut self,
631 metadata: cube::Metadata,
632 out: Option<cube::Variable>,
633 ) -> wgsl::Instruction {
634 let out = out.unwrap();
635 match metadata {
636 cube::Metadata::Rank { var } => {
637 let position = self.ext_meta_pos(&var);
638 let offset = self.metadata.rank_index(position);
639 wgsl::Instruction::Metadata {
640 out: self.compile_variable(out),
641 info_offset: self.compile_variable(offset.into()),
642 }
643 }
644 cube::Metadata::Stride { dim, var } => {
645 let position = self.ext_meta_pos(&var);
646 let offset = self.metadata.stride_offset_index(position);
647 wgsl::Instruction::ExtendedMeta {
648 info_offset: self.compile_variable(offset.into()),
649 dim: self.compile_variable(dim),
650 out: self.compile_variable(out),
651 }
652 }
653 cube::Metadata::Shape { dim, var } => {
654 let position = self.ext_meta_pos(&var);
655 let offset = self.metadata.shape_offset_index(position);
656 wgsl::Instruction::ExtendedMeta {
657 info_offset: self.compile_variable(offset.into()),
658 dim: self.compile_variable(dim),
659 out: self.compile_variable(out),
660 }
661 }
662 cube::Metadata::Length { var } => match var.kind {
663 cube::VariableKind::GlobalInputArray(id) => {
664 let offset = self.metadata.len_index(id);
665 wgsl::Instruction::Metadata {
666 out: self.compile_variable(out),
667 info_offset: self.compile_variable(offset.into()),
668 }
669 }
670 cube::VariableKind::GlobalOutputArray(id) => {
671 let offset = self.metadata.len_index(id);
672 wgsl::Instruction::Metadata {
673 out: self.compile_variable(out),
674 info_offset: self.compile_variable(offset.into()),
675 }
676 }
677 _ => wgsl::Instruction::Length {
678 var: self.compile_variable(var),
679 out: self.compile_variable(out),
680 },
681 },
682 cube::Metadata::BufferLength { var } => match var.kind {
683 cube::VariableKind::GlobalInputArray(id) => {
684 let offset = self.metadata.buffer_len_index(id);
685 wgsl::Instruction::Metadata {
686 out: self.compile_variable(out),
687 info_offset: self.compile_variable(offset.into()),
688 }
689 }
690 cube::VariableKind::GlobalOutputArray(id) => {
691 let offset = self.metadata.buffer_len_index(id);
692 wgsl::Instruction::Metadata {
693 out: self.compile_variable(out),
694 info_offset: self.compile_variable(offset.into()),
695 }
696 }
697 _ => wgsl::Instruction::Length {
698 var: self.compile_variable(var),
699 out: self.compile_variable(out),
700 },
701 },
702 }
703 }
704
    /// Lowers an arithmetic operation into WGSL instructions appended to `instructions`.
    ///
    /// Most operations map one-to-one onto a WGSL instruction. `Erf` and `MulHi` are instead
    /// expanded into a child of `scope`, which is then compiled recursively.
    ///
    /// # Panics
    /// Panics if `out` is `None`. Also panics on saturating add/sub, which are expected to have
    /// been rewritten by the saturating-arithmetic processor that `compile_scope` runs.
    fn compile_arithmetic(
        &mut self,
        value: cube::Arithmetic,
        out: Option<cube::Variable>,
        instructions: &mut Vec<wgsl::Instruction>,
        scope: &mut Scope,
    ) {
        let out = out.unwrap();
        match value {
            cube::Arithmetic::Max(op) => instructions.push(wgsl::Instruction::Max {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Min(op) => instructions.push(wgsl::Instruction::Min {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Add(op) => instructions.push(wgsl::Instruction::Add {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::SaturatingAdd(_) => {
                // Rewritten by `SaturatingArithmeticProcessor` before lowering.
                unreachable!("Saturating add should be removed by processor");
            }
            cube::Arithmetic::Fma(op) => instructions.push(wgsl::Instruction::Fma {
                a: self.compile_variable(op.a),
                b: self.compile_variable(op.b),
                c: self.compile_variable(op.c),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Modulo(op) => instructions.push(wgsl::Instruction::Modulo {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Sub(op) => instructions.push(wgsl::Instruction::Sub {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::SaturatingSub(_) => {
                // Rewritten by `SaturatingArithmeticProcessor` before lowering.
                unreachable!("Saturating sub should be removed by processor");
            }
            cube::Arithmetic::Mul(op) => instructions.push(wgsl::Instruction::Mul {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Div(op) => instructions.push(wgsl::Instruction::Div {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Abs(op) => instructions.push(wgsl::Instruction::Abs {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Exp(op) => instructions.push(wgsl::Instruction::Exp {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Log(op) => instructions.push(wgsl::Instruction::Log {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Log1p(op) => instructions.push(wgsl::Instruction::Log1p {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Cos(op) => instructions.push(wgsl::Instruction::Cos {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Sin(op) => instructions.push(wgsl::Instruction::Sin {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Tan(op) => instructions.push(wgsl::Instruction::Tan {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Tanh(op) => instructions.push(wgsl::Instruction::Tanh {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Sinh(op) => instructions.push(wgsl::Instruction::Sinh {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Cosh(op) => instructions.push(wgsl::Instruction::Cosh {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::ArcCos(op) => instructions.push(wgsl::Instruction::ArcCos {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::ArcSin(op) => instructions.push(wgsl::Instruction::ArcSin {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::ArcTan(op) => instructions.push(wgsl::Instruction::ArcTan {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::ArcSinh(op) => instructions.push(wgsl::Instruction::ArcSinh {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::ArcCosh(op) => instructions.push(wgsl::Instruction::ArcCosh {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::ArcTanh(op) => instructions.push(wgsl::Instruction::ArcTanh {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Degrees(op) => instructions.push(wgsl::Instruction::Degrees {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Radians(op) => instructions.push(wgsl::Instruction::Radians {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::ArcTan2(op) => instructions.push(wgsl::Instruction::ArcTan2 {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Powf(op) | cube::Arithmetic::Powi(op) => {
                // Both the float and the integer-exponent power are emitted as `Powf`.
                // NOTE(review): assumes the `Powf` rendering copes with an integer `rhs` — confirm.
                instructions.push(wgsl::Instruction::Powf {
                    lhs: self.compile_variable(op.lhs),
                    rhs: self.compile_variable(op.rhs),
                    out: self.compile_variable(out),
                })
            }
            cube::Arithmetic::Sqrt(op) => instructions.push(wgsl::Instruction::Sqrt {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::InverseSqrt(op) => {
                instructions.push(wgsl::Instruction::InverseSqrt {
                    input: self.compile_variable(op.input),
                    out: self.compile_variable(out),
                })
            }
            cube::Arithmetic::Round(op) => instructions.push(wgsl::Instruction::Round {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Floor(op) => instructions.push(wgsl::Instruction::Floor {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Ceil(op) => instructions.push(wgsl::Instruction::Ceil {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Trunc(op) => instructions.push(wgsl::Instruction::Trunc {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Erf(op) => {
                // No single instruction: expand `erf` into IR in a child scope and compile that.
                let mut scope = scope.child();
                expand_erf(&mut scope, op.input, out);
                instructions.extend(self.compile_scope(&mut scope));
            }
            cube::Arithmetic::MulHi(op) => {
                // Expanded in a child scope. Use the 64-bit expansion when the runtime reports
                // `u64` support, otherwise the `expand_himul_sim` fallback.
                let mut scope = scope.child();
                match self.compilation_options.supports_u64 {
                    true => expand_himul_64(&mut scope, op.lhs, op.rhs, out),
                    false => expand_himul_sim(&mut scope, op.lhs, op.rhs, out),
                }
                instructions.extend(self.compile_scope(&mut scope));
            }
            cube::Arithmetic::Recip(op) => instructions.push(wgsl::Instruction::Recip {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Clamp(op) => instructions.push(wgsl::Instruction::Clamp {
                input: self.compile_variable(op.input),
                min_value: self.compile_variable(op.min_value),
                max_value: self.compile_variable(op.max_value),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Remainder(op) => instructions.push(wgsl::Instruction::Remainder {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Neg(op) => instructions.push(wgsl::Instruction::Negate {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Magnitude(op) => instructions.push(wgsl::Instruction::Magnitude {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Normalize(op) => instructions.push(wgsl::Instruction::Normalize {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Arithmetic::Dot(op) => instructions.push(wgsl::Instruction::Dot {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            }),
        }
    }
919
    /// Lowers a comparison operation into the corresponding WGSL instruction.
    ///
    /// Binary comparisons compile `lhs`, then `rhs`, then `out`. `IsNan`/`IsInf` compile
    /// `input`, then `out`.
    ///
    /// # Panics
    /// Panics if `out` is `None`.
    fn compile_cmp(
        &mut self,
        value: cube::Comparison,
        out: Option<cube::Variable>,
        instructions: &mut Vec<wgsl::Instruction>,
    ) {
        let out = out.unwrap();
        match value {
            cube::Comparison::Equal(op) => instructions.push(wgsl::Instruction::Equal {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            }),
            cube::Comparison::Lower(op) => instructions.push(wgsl::Instruction::Lower {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            }),
            cube::Comparison::Greater(op) => instructions.push(wgsl::Instruction::Greater {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            }),
            cube::Comparison::LowerEqual(op) => instructions.push(wgsl::Instruction::LowerEqual {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            }),
            cube::Comparison::GreaterEqual(op) => {
                instructions.push(wgsl::Instruction::GreaterEqual {
                    lhs: self.compile_variable(op.lhs),
                    rhs: self.compile_variable(op.rhs),
                    out: self.compile_variable(out),
                })
            }
            cube::Comparison::NotEqual(op) => instructions.push(wgsl::Instruction::NotEqual {
                lhs: self.compile_variable(op.lhs),
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(out),
            }),
            cube::Comparison::IsNan(op) => instructions.push(wgsl::Instruction::IsNan {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
            cube::Comparison::IsInf(op) => instructions.push(wgsl::Instruction::IsInf {
                input: self.compile_variable(op.input),
                out: self.compile_variable(out),
            }),
        }
    }
970
971 fn compile_bitwise(
972 &mut self,
973 value: cube::Bitwise,
974 out: Option<cube::Variable>,
975 instructions: &mut Vec<wgsl::Instruction>,
976 ) {
977 let out = out.unwrap();
978 match value {
979 cube::Bitwise::BitwiseOr(op) => instructions.push(wgsl::Instruction::BitwiseOr {
980 lhs: self.compile_variable(op.lhs),
981 rhs: self.compile_variable(op.rhs),
982 out: self.compile_variable(out),
983 }),
984 cube::Bitwise::BitwiseAnd(op) => instructions.push(wgsl::Instruction::BitwiseAnd {
985 lhs: self.compile_variable(op.lhs),
986 rhs: self.compile_variable(op.rhs),
987 out: self.compile_variable(out),
988 }),
989 cube::Bitwise::BitwiseXor(op) => instructions.push(wgsl::Instruction::BitwiseXor {
990 lhs: self.compile_variable(op.lhs),
991 rhs: self.compile_variable(op.rhs),
992 out: self.compile_variable(out),
993 }),
994 cube::Bitwise::CountOnes(op) => instructions.push(wgsl::Instruction::CountBits {
995 input: self.compile_variable(op.input),
996 out: self.compile_variable(out),
997 }),
998 cube::Bitwise::ReverseBits(op) => instructions.push(wgsl::Instruction::ReverseBits {
999 input: self.compile_variable(op.input),
1000 out: self.compile_variable(out),
1001 }),
1002 cube::Bitwise::ShiftLeft(op) => instructions.push(wgsl::Instruction::ShiftLeft {
1003 lhs: self.compile_variable(op.lhs),
1004 rhs: self.compile_variable(op.rhs),
1005 out: self.compile_variable(out),
1006 }),
1007 cube::Bitwise::ShiftRight(op) => instructions.push(wgsl::Instruction::ShiftRight {
1008 lhs: self.compile_variable(op.lhs),
1009 rhs: self.compile_variable(op.rhs),
1010 out: self.compile_variable(out),
1011 }),
1012 cube::Bitwise::BitwiseNot(op) => instructions.push(wgsl::Instruction::BitwiseNot {
1013 input: self.compile_variable(op.input),
1014 out: self.compile_variable(out),
1015 }),
1016 cube::Bitwise::LeadingZeros(op) => instructions.push(wgsl::Instruction::LeadingZeros {
1017 input: self.compile_variable(op.input),
1018 out: self.compile_variable(out),
1019 }),
1020 cube::Bitwise::FindFirstSet(op) => instructions.push(wgsl::Instruction::FindFirstSet {
1021 input: self.compile_variable(op.input),
1022 out: self.compile_variable(out),
1023 }),
1024 }
1025 }
1026
1027 fn compile_operator(
1028 &mut self,
1029 value: cube::Operator,
1030 out: Option<cube::Variable>,
1031 instructions: &mut Vec<wgsl::Instruction>,
1032 ) {
1033 let out = out.unwrap();
1034 match value {
1035 cube::Operator::Cast(op) => instructions.push(wgsl::Instruction::Assign {
1036 input: self.compile_variable(op.input),
1037 out: self.compile_variable(out),
1038 }),
1039 cube::Operator::Index(op) | cube::Operator::UncheckedIndex(op) => {
1040 instructions.push(wgsl::Instruction::Index {
1041 lhs: self.compile_variable(op.list),
1042 rhs: self.compile_variable(op.index),
1043 out: self.compile_variable(out),
1044 });
1045 }
1046 cube::Operator::IndexAssign(op) | cube::Operator::UncheckedIndexAssign(op) => {
1047 instructions.push(wgsl::Instruction::IndexAssign {
1048 index: self.compile_variable(op.index),
1049 rhs: self.compile_variable(op.value),
1050 out: self.compile_variable(out),
1051 })
1052 }
1053 cube::Operator::And(op) => instructions.push(wgsl::Instruction::And {
1054 lhs: self.compile_variable(op.lhs),
1055 rhs: self.compile_variable(op.rhs),
1056 out: self.compile_variable(out),
1057 }),
1058 cube::Operator::Or(op) => instructions.push(wgsl::Instruction::Or {
1059 lhs: self.compile_variable(op.lhs),
1060 rhs: self.compile_variable(op.rhs),
1061 out: self.compile_variable(out),
1062 }),
1063 cube::Operator::Not(op) => instructions.push(wgsl::Instruction::Not {
1064 input: self.compile_variable(op.input),
1065 out: self.compile_variable(out),
1066 }),
1067 cube::Operator::Reinterpret(op) => instructions.push(wgsl::Instruction::Bitcast {
1068 input: self.compile_variable(op.input),
1069 out: self.compile_variable(out),
1070 }),
1071 cube::Operator::InitLine(op) => instructions.push(wgsl::Instruction::VecInit {
1072 inputs: op
1073 .inputs
1074 .into_iter()
1075 .map(|var| self.compile_variable(var))
1076 .collect(),
1077 out: self.compile_variable(out),
1078 }),
1079 cube::Operator::CopyMemory(op) => instructions.push(wgsl::Instruction::Copy {
1080 input: self.compile_variable(op.input),
1081 in_index: self.compile_variable(op.in_index),
1082 out: self.compile_variable(out),
1083 out_index: self.compile_variable(op.out_index),
1084 }),
1085 cube::Operator::CopyMemoryBulk(op) => instructions.push(wgsl::Instruction::CopyBulk {
1086 input: self.compile_variable(op.input),
1087 in_index: self.compile_variable(op.in_index),
1088 out: self.compile_variable(out),
1089 out_index: self.compile_variable(op.out_index),
1090 len: op.len,
1091 }),
1092 cube::Operator::Select(op) => instructions.push(wgsl::Instruction::Select {
1093 cond: self.compile_variable(op.cond),
1094 then: self.compile_variable(op.then),
1095 or_else: self.compile_variable(op.or_else),
1096 out: self.compile_variable(out),
1097 }),
1098 }
1099 }
1100
1101 fn compile_atomic(
1102 &mut self,
1103 atomic: cube::AtomicOp,
1104 out: Option<cube::Variable>,
1105 ) -> wgsl::Instruction {
1106 let out = out.unwrap();
1107 match atomic {
1108 cube::AtomicOp::Add(op) => wgsl::Instruction::AtomicAdd {
1109 lhs: self.compile_variable(op.lhs),
1110 rhs: self.compile_variable(op.rhs),
1111 out: self.compile_variable(out),
1112 },
1113 cube::AtomicOp::Sub(op) => wgsl::Instruction::AtomicSub {
1114 lhs: self.compile_variable(op.lhs),
1115 rhs: self.compile_variable(op.rhs),
1116 out: self.compile_variable(out),
1117 },
1118 cube::AtomicOp::Max(op) => wgsl::Instruction::AtomicMax {
1119 lhs: self.compile_variable(op.lhs),
1120 rhs: self.compile_variable(op.rhs),
1121 out: self.compile_variable(out),
1122 },
1123 cube::AtomicOp::Min(op) => wgsl::Instruction::AtomicMin {
1124 lhs: self.compile_variable(op.lhs),
1125 rhs: self.compile_variable(op.rhs),
1126 out: self.compile_variable(out),
1127 },
1128 cube::AtomicOp::And(op) => wgsl::Instruction::AtomicAnd {
1129 lhs: self.compile_variable(op.lhs),
1130 rhs: self.compile_variable(op.rhs),
1131 out: self.compile_variable(out),
1132 },
1133 cube::AtomicOp::Or(op) => wgsl::Instruction::AtomicOr {
1134 lhs: self.compile_variable(op.lhs),
1135 rhs: self.compile_variable(op.rhs),
1136 out: self.compile_variable(out),
1137 },
1138 cube::AtomicOp::Xor(op) => wgsl::Instruction::AtomicXor {
1139 lhs: self.compile_variable(op.lhs),
1140 rhs: self.compile_variable(op.rhs),
1141 out: self.compile_variable(out),
1142 },
1143 cube::AtomicOp::Load(op) => wgsl::Instruction::AtomicLoad {
1144 input: self.compile_variable(op.input),
1145 out: self.compile_variable(out),
1146 },
1147 cube::AtomicOp::Store(op) => wgsl::Instruction::AtomicStore {
1148 input: self.compile_variable(op.input),
1149 out: self.compile_variable(out),
1150 },
1151 cube::AtomicOp::Swap(op) => wgsl::Instruction::AtomicSwap {
1152 lhs: self.compile_variable(op.lhs),
1153 rhs: self.compile_variable(op.rhs),
1154 out: self.compile_variable(out),
1155 },
1156 cube::AtomicOp::CompareAndSwap(op) => wgsl::Instruction::AtomicCompareExchangeWeak {
1157 lhs: self.compile_variable(op.input),
1158 cmp: self.compile_variable(op.cmp),
1159 value: self.compile_variable(op.val),
1160 out: self.compile_variable(out),
1161 },
1162 }
1163 }
1164
1165 fn compile_location(value: compute::Location) -> wgsl::Location {
1166 match value {
1167 compute::Location::Storage => wgsl::Location::Storage,
1168 compute::Location::Cube => wgsl::Location::Workgroup,
1169 }
1170 }
1171
1172 fn compile_binding(&mut self, value: compute::Binding) -> wgsl::Binding {
1173 wgsl::Binding {
1174 id: value.id,
1175 visibility: value.visibility,
1176 location: Self::compile_location(value.location),
1177 item: self.compile_type(value.ty),
1178 size: value.size,
1179 }
1180 }
1181}
1182
1183fn register_extensions(instructions: &[wgsl::Instruction]) -> Vec<wgsl::Extension> {
1184 let mut extensions = Vec::new();
1185
1186 let mut register_extension = |extension: wgsl::Extension| {
1187 if !extensions.contains(&extension) {
1188 extensions.push(extension);
1189 }
1190 };
1191
1192 for instruction in instructions {
1194 match instruction {
1195 wgsl::Instruction::Powf { lhs: _, rhs, out } => {
1196 register_extension(wgsl::Extension::PowfPrimitive(out.elem()));
1197 register_extension(wgsl::powf_extension(rhs, out));
1198 }
1199 #[cfg(target_os = "macos")]
1200 wgsl::Instruction::Tanh { input, out: _ } => {
1201 register_extension(wgsl::Extension::SafeTanhPrimitive(input.elem()));
1202 register_extension(wgsl::Extension::SafeTanh(input.item()));
1203 }
1204 wgsl::Instruction::IsNan { input, out } => {
1205 register_extension(wgsl::Extension::IsNanPrimitive(input.elem()));
1206 register_extension(wgsl::Extension::IsNan(input.item(), out.item()));
1207 }
1208 wgsl::Instruction::IsInf { input, out } => {
1209 register_extension(wgsl::Extension::IsInfPrimitive(input.elem()));
1210 register_extension(wgsl::Extension::IsInf(input.item(), out.item()));
1211 }
1212 wgsl::Instruction::If { instructions, .. } => {
1213 for extension in register_extensions(instructions) {
1214 register_extension(extension);
1215 }
1216 }
1217 wgsl::Instruction::IfElse {
1218 instructions_if,
1219 instructions_else,
1220 ..
1221 } => {
1222 for extension in register_extensions(instructions_if) {
1223 register_extension(extension);
1224 }
1225 for extension in register_extensions(instructions_else) {
1226 register_extension(extension);
1227 }
1228 }
1229 wgsl::Instruction::Loop { instructions } => {
1230 for extension in register_extensions(instructions) {
1231 register_extension(extension);
1232 }
1233 }
1234 wgsl::Instruction::RangeLoop { instructions, .. } => {
1235 for extension in register_extensions(instructions) {
1236 register_extension(extension);
1237 }
1238 }
1239 _ => {}
1240 }
1241 }
1242
1243 extensions
1244}