1use super::Subgroup;
2use super::{ConstantArray, shader::ComputeShader};
3use super::{Item, LocalArray, SharedArray};
4use crate::compiler::wgsl::{self, SharedValue};
5
6use cubecl_common::backtrace::BackTrace;
7use cubecl_core::post_processing::{
8 checked_io::CheckedIoProcessor, saturating::SaturatingArithmeticProcessor,
9};
10use cubecl_core::prelude::*;
11use cubecl_core::{
12 Metadata, WgpuCompilationOptions,
13 ir::{self as cube, Scope},
14 prelude::expand_erf,
15};
16use cubecl_core::{
17 ir::{ConstantScalarValue, Processor, UIntKind},
18 post_processing::unroll::UnrollProcessor,
19};
20use cubecl_runtime::compiler::CompilationError;
21use cubecl_runtime::kernel;
22
/// Widest line (vector) size the WGSL backend emits.
///
/// Buffer and variable types with a larger line size are clamped to this value before being
/// compiled, and the unroll processor run on every scope is constructed with the same limit.
pub const MAX_LINE_SIZE: u32 = 4;
24
25#[derive(Clone, Default)]
27pub struct WgslCompiler {
28 metadata: Metadata,
29 ext_meta_pos: Vec<u32>,
30 local_invocation_index: bool,
31 local_invocation_id: bool,
32 global_invocation_id: bool,
34 workgroup_id: bool,
35 subgroup_size: bool,
36 subgroup_invocation_id: bool,
37 id: bool,
38 num_workgroups: bool,
39 workgroup_id_no_axis: bool,
40 workgroup_size_no_axis: bool,
41 num_workgroup_no_axis: bool,
42 shared_arrays: Vec<SharedArray>,
43 shared_values: Vec<SharedValue>,
44 const_arrays: Vec<ConstantArray>,
45 local_arrays: Vec<LocalArray>,
46 #[allow(dead_code)]
47 compilation_options: WgpuCompilationOptions,
48 strategy: ExecutionMode,
49 subgroup_instructions_used: bool,
50 f16_used: bool,
51}
52
53impl core::fmt::Debug for WgslCompiler {
54 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55 f.write_str("WgslCompiler")
56 }
57}
58
59impl cubecl_core::Compiler for WgslCompiler {
60 type Representation = ComputeShader;
61 type CompilationOptions = WgpuCompilationOptions;
62
63 fn compile(
64 &mut self,
65 shader: kernel::KernelDefinition,
66 compilation_options: &Self::CompilationOptions,
67 mode: ExecutionMode,
68 ) -> Result<Self::Representation, CompilationError> {
69 self.compilation_options = compilation_options.clone();
70 self.compile_shader(shader, mode)
71 }
72
73 fn elem_size(&self, elem: cube::ElemType) -> usize {
74 elem.size()
75 }
76
77 fn extension(&self) -> &'static str {
78 "wgsl"
79 }
80}
81
82impl WgslCompiler {
83 fn compile_shader(
84 &mut self,
85 mut value: kernel::KernelDefinition,
86 mode: ExecutionMode,
87 ) -> Result<wgsl::ComputeShader, CompilationError> {
88 let errors = value.body.pop_errors();
89 if !errors.is_empty() {
90 let mut reason = "Can't compile wgsl kernel".to_string();
91 for error in errors {
92 reason += error.as_str();
93 reason += "\n";
94 }
95
96 return Err(CompilationError::Validation {
97 reason,
98 backtrace: BackTrace::capture(),
99 });
100 }
101
102 self.strategy = mode;
103
104 let num_meta = value.buffers.len();
105
106 self.ext_meta_pos = Vec::new();
107 let mut num_ext = 0;
108
109 for binding in value.buffers.iter() {
110 self.ext_meta_pos.push(num_ext);
111 if binding.has_extended_meta {
112 num_ext += 1;
113 }
114 }
115
116 self.metadata = Metadata::new(num_meta as u32, num_ext);
117
118 let instructions = self.compile_scope(&mut value.body);
119 let extensions = register_extensions(&instructions);
120 let body = wgsl::Body {
121 instructions,
122 id: self.id,
123 };
124
125 Ok(wgsl::ComputeShader {
126 buffers: value
127 .buffers
128 .into_iter()
129 .map(|mut it| {
130 if it.ty.line_size() > MAX_LINE_SIZE {
133 it.ty = it.ty.line(MAX_LINE_SIZE);
134 }
135 self.compile_binding(it)
136 })
137 .collect(),
138 scalars: value
139 .scalars
140 .into_iter()
141 .map(|binding| (self.compile_storage_type(binding.ty), binding.count))
142 .collect(),
143 shared_arrays: self.shared_arrays.clone(),
144 shared_values: self.shared_values.clone(),
145 constant_arrays: self.const_arrays.clone(),
146 local_arrays: self.local_arrays.clone(),
147 has_metadata: self.metadata.static_len() > 0,
148 workgroup_size: value.cube_dim,
149 global_invocation_id: self.global_invocation_id || self.id,
150 local_invocation_index: self.local_invocation_index,
151 local_invocation_id: self.local_invocation_id,
152 num_workgroups: self.id
153 || self.num_workgroups
154 || self.num_workgroup_no_axis
155 || self.workgroup_id_no_axis,
156 workgroup_id: self.workgroup_id || self.workgroup_id_no_axis,
157 subgroup_size: self.subgroup_size,
158 subgroup_invocation_id: self.subgroup_invocation_id,
159 body,
160 extensions,
161 num_workgroups_no_axis: self.num_workgroup_no_axis,
162 workgroup_id_no_axis: self.workgroup_id_no_axis,
163 workgroup_size_no_axis: self.workgroup_size_no_axis,
164 subgroup_instructions_used: self.subgroup_instructions_used,
165 f16_used: self.f16_used,
166 kernel_name: value.options.kernel_name,
167 })
168 }
169
170 fn compile_type(&mut self, item: cube::Type) -> Item {
171 match item {
172 cube::Type::Scalar(ty) => wgsl::Item::Scalar(self.compile_storage_type(ty)),
173 cube::Type::Line(ty, size) => {
174 let elem = self.compile_storage_type(ty);
175 match size {
176 2 => wgsl::Item::Vec2(elem),
177 3 => wgsl::Item::Vec3(elem),
178 4 => wgsl::Item::Vec4(elem),
179 _ => panic!("Unsupported vectorizations scheme {:?}", item.line_size()),
180 }
181 }
182 cube::Type::Semantic(_) => unimplemented!("Can't compile semantic type"),
183 }
184 }
185
186 fn compile_storage_type(&mut self, ty: cube::StorageType) -> wgsl::Elem {
187 match ty {
188 cube::StorageType::Scalar(ty) => self.compile_elem(ty),
189 cube::StorageType::Atomic(ty) => match ty {
190 cube::ElemType::Float(i) => match i {
191 cube::FloatKind::F32 => wgsl::Elem::AtomicF32,
192 kind => panic!("atomic<{kind:?}> is not a valid WgpuElement"),
193 },
194 cube::ElemType::Int(i) => match i {
195 cube::IntKind::I32 => wgsl::Elem::AtomicI32,
196 kind => panic!("atomic<{kind:?}> is not a valid WgpuElement"),
197 },
198 cube::ElemType::UInt(kind) => match kind {
199 cube::UIntKind::U32 => wgsl::Elem::AtomicU32,
200 kind => panic!("{kind:?} is not a valid WgpuElement"),
201 },
202 other => panic!("{other:?} is not a valid WgpuElement"),
203 },
204 cube::StorageType::Packed(_, _) => {
205 unimplemented!("Packed types not yet supported in WGSL")
206 }
207 cube::StorageType::Opaque(ty) => match ty {
208 cube::OpaqueType::Barrier(_) => {
209 unimplemented!("Barrier objects not supported in WGSL")
210 }
211 },
212 }
213 }
214
215 fn compile_elem(&mut self, value: cube::ElemType) -> wgsl::Elem {
216 match value {
217 cube::ElemType::Float(f) => match f {
218 cube::FloatKind::E2M1
219 | cube::FloatKind::E2M3
220 | cube::FloatKind::E3M2
221 | cube::FloatKind::E4M3
222 | cube::FloatKind::E5M2
223 | cube::FloatKind::UE8M0 => panic!("Minifloat is not a valid WgpuElement"),
224 cube::FloatKind::F16 => {
225 self.f16_used = true;
226 wgsl::Elem::F16
227 }
228 cube::FloatKind::BF16 => panic!("bf16 is not a valid WgpuElement"),
229 cube::FloatKind::TF32 => panic!("tf32 is not a valid WgpuElement"),
230 cube::FloatKind::Flex32 => wgsl::Elem::F32,
231 cube::FloatKind::F32 => wgsl::Elem::F32,
232 cube::FloatKind::F64 => wgsl::Elem::F64,
233 },
234 cube::ElemType::Int(i) => match i {
235 cube::IntKind::I32 => wgsl::Elem::I32,
236 cube::IntKind::I64 => wgsl::Elem::I64,
237 kind => panic!("{kind:?} is not a valid WgpuElement"),
238 },
239 cube::ElemType::UInt(kind) => match kind {
240 cube::UIntKind::U32 => wgsl::Elem::U32,
241 cube::UIntKind::U64 => wgsl::Elem::U64,
242 kind => panic!("{kind:?} is not a valid WgpuElement"),
243 },
244 cube::ElemType::Bool => wgsl::Elem::Bool,
245 }
246 }
247
248 fn ext_meta_pos(&self, var: &cube::Variable) -> u32 {
249 let pos = var.index().expect("Variable should have index");
250 self.ext_meta_pos[pos as usize]
251 }
252
253 pub(crate) fn compile_variable(&mut self, value: cube::Variable) -> wgsl::Variable {
254 let item = value.ty;
255 match value.kind {
256 cube::VariableKind::GlobalInputArray(id) => {
257 wgsl::Variable::GlobalInputArray(id, self.compile_type(item))
258 }
259 cube::VariableKind::GlobalScalar(id) => {
260 wgsl::Variable::GlobalScalar(id, self.compile_storage_type(item.storage_type()))
261 }
262 cube::VariableKind::LocalMut { id } | cube::VariableKind::Versioned { id, .. } => {
263 wgsl::Variable::LocalMut {
264 id,
265 item: self.compile_type(item),
266 }
267 }
268 cube::VariableKind::LocalConst { id } => wgsl::Variable::LocalConst {
269 id,
270 item: self.compile_type(item),
271 },
272 cube::VariableKind::GlobalOutputArray(id) => {
273 wgsl::Variable::GlobalOutputArray(id, self.compile_type(item))
274 }
275 cube::VariableKind::ConstantScalar(value) => {
276 wgsl::Variable::ConstantScalar(value, self.compile_elem(value.elem_type()))
277 }
278 cube::VariableKind::SharedArray {
279 id,
280 length,
281 unroll_factor,
282 alignment,
283 } => {
284 let item = self.compile_type(item);
285 if !self.shared_arrays.iter().any(|s| s.index == id) {
286 self.shared_arrays.push(SharedArray::new(
287 id,
288 item,
289 length * unroll_factor,
290 alignment,
291 ));
292 }
293 wgsl::Variable::SharedArray(id, item, length)
294 }
295 cube::VariableKind::Shared { id } => {
296 let item = self.compile_type(item);
297 if !self.shared_values.iter().any(|s| s.index == id) {
298 self.shared_values.push(SharedValue::new(id, item));
299 }
300 wgsl::Variable::SharedValue(id, item)
301 }
302 cube::VariableKind::ConstantArray { id, length, .. } => {
303 let item = self.compile_type(item);
304 wgsl::Variable::ConstantArray(id, item, length)
305 }
306 cube::VariableKind::LocalArray {
307 id,
308 length,
309 unroll_factor,
310 } => {
311 let item = self.compile_type(item);
312 if !self.local_arrays.iter().any(|s| s.index == id) {
313 self.local_arrays
314 .push(LocalArray::new(id, item, length * unroll_factor));
315 }
316 wgsl::Variable::LocalArray(id, item, length)
317 }
318 cube::VariableKind::Builtin(builtin) => match builtin {
319 cube::Builtin::AbsolutePos => {
320 self.id = true;
321 wgsl::Variable::Id
322 }
323 cube::Builtin::UnitPos => {
324 self.local_invocation_index = true;
325 wgsl::Variable::LocalInvocationIndex
326 }
327 cube::Builtin::UnitPosX => {
328 self.local_invocation_id = true;
329 wgsl::Variable::LocalInvocationIdX
330 }
331 cube::Builtin::UnitPosY => {
332 self.local_invocation_id = true;
333 wgsl::Variable::LocalInvocationIdY
334 }
335 cube::Builtin::UnitPosZ => {
336 self.local_invocation_id = true;
337 wgsl::Variable::LocalInvocationIdZ
338 }
339 cube::Builtin::CubePosX => {
340 self.workgroup_id = true;
341 wgsl::Variable::WorkgroupIdX
342 }
343 cube::Builtin::CubePosY => {
344 self.workgroup_id = true;
345 wgsl::Variable::WorkgroupIdY
346 }
347 cube::Builtin::CubePosZ => {
348 self.workgroup_id = true;
349 wgsl::Variable::WorkgroupIdZ
350 }
351 cube::Builtin::CubePosCluster
352 | cube::Builtin::CubePosClusterX
353 | cube::Builtin::CubePosClusterY
354 | cube::Builtin::CubePosClusterZ => self.constant_var(1),
355 cube::Builtin::AbsolutePosX => {
356 self.global_invocation_id = true;
357 wgsl::Variable::GlobalInvocationIdX
358 }
359 cube::Builtin::AbsolutePosY => {
360 self.global_invocation_id = true;
361 wgsl::Variable::GlobalInvocationIdY
362 }
363 cube::Builtin::AbsolutePosZ => {
364 self.global_invocation_id = true;
365 wgsl::Variable::GlobalInvocationIdZ
366 }
367 cube::Builtin::CubeDimX => wgsl::Variable::WorkgroupSizeX,
368 cube::Builtin::CubeDimY => wgsl::Variable::WorkgroupSizeY,
369 cube::Builtin::CubeDimZ => wgsl::Variable::WorkgroupSizeZ,
370 cube::Builtin::CubeClusterDim
371 | cube::Builtin::CubeClusterDimX
372 | cube::Builtin::CubeClusterDimY
373 | cube::Builtin::CubeClusterDimZ => self.constant_var(1),
374 cube::Builtin::CubeCountX => {
375 self.num_workgroups = true;
376 wgsl::Variable::NumWorkgroupsX
377 }
378 cube::Builtin::CubeCountY => {
379 self.num_workgroups = true;
380 wgsl::Variable::NumWorkgroupsY
381 }
382 cube::Builtin::CubeCountZ => {
383 self.num_workgroups = true;
384 wgsl::Variable::NumWorkgroupsZ
385 }
386 cube::Builtin::CubePos => {
387 self.workgroup_id_no_axis = true;
388 wgsl::Variable::WorkgroupId
389 }
390 cube::Builtin::CubeDim => {
391 self.workgroup_size_no_axis = true;
392 wgsl::Variable::WorkgroupSize
393 }
394 cube::Builtin::CubeCount => {
395 self.num_workgroup_no_axis = true;
396 wgsl::Variable::NumWorkgroups
397 }
398 cube::Builtin::PlaneDim => {
399 self.subgroup_size = true;
400 wgsl::Variable::SubgroupSize
401 }
402 cube::Builtin::UnitPosPlane => {
403 self.subgroup_invocation_id = true;
404 wgsl::Variable::SubgroupInvocationId
405 }
406 },
407 cube::VariableKind::Matrix { .. } => {
408 panic!("Cooperative matrix-multiply and accumulate not supported.")
409 }
410 cube::VariableKind::Pipeline { .. } => {
411 panic!("Pipeline not supported.")
412 }
413 cube::VariableKind::BarrierToken { .. } => {
414 panic!("Barrier not supported.")
415 }
416 cube::VariableKind::TensorMapInput(_) => panic!("Tensor map not supported."),
417 cube::VariableKind::TensorMapOutput(_) => panic!("Tensor map not supported."),
418 }
419 }
420
421 fn constant_var(&mut self, value: u32) -> wgsl::Variable {
422 let var = cube::Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U32));
423 self.compile_variable(var)
424 }
425
426 fn compile_scope(&mut self, scope: &mut cube::Scope) -> Vec<wgsl::Instruction> {
427 let mut instructions = Vec::new();
428
429 let const_arrays = scope
430 .const_arrays
431 .drain(..)
432 .map(|(var, values)| ConstantArray {
433 index: var.index().unwrap(),
434 item: self.compile_type(var.ty),
435 size: values.len() as u32,
436 values: values
437 .into_iter()
438 .map(|val| self.compile_variable(val))
439 .collect(),
440 })
441 .collect::<Vec<_>>();
442 self.const_arrays.extend(const_arrays);
443
444 let checked_io: Box<dyn Processor> = Box::new(CheckedIoProcessor::new(self.strategy));
445 let unroll = Box::new(UnrollProcessor::new(MAX_LINE_SIZE));
446 let saturating = Box::new(SaturatingArithmeticProcessor::new(true));
447 let processing = scope.process([&*unroll, &*checked_io, &*saturating]);
448
449 for mut var in processing.variables {
450 if var.ty.line_size() > MAX_LINE_SIZE {
451 var.ty = var.ty.line(MAX_LINE_SIZE);
452 }
453 instructions.push(wgsl::Instruction::DeclareVariable {
454 var: self.compile_variable(var),
455 });
456 }
457
458 processing
459 .instructions
460 .into_iter()
461 .for_each(|op| self.compile_operation(&mut instructions, op.operation, op.out, scope));
462
463 instructions
464 }
465
466 fn compile_operation(
467 &mut self,
468 instructions: &mut Vec<wgsl::Instruction>,
469 operation: cube::Operation,
470 out: Option<cube::Variable>,
471 scope: &mut cube::Scope,
472 ) {
473 match operation {
474 cube::Operation::Copy(variable) => instructions.push(wgsl::Instruction::Assign {
475 input: self.compile_variable(variable),
476 out: self.compile_variable(out.unwrap()),
477 }),
478 cube::Operation::Arithmetic(op) => {
479 self.compile_arithmetic(op, out, instructions, scope)
480 }
481 cube::Operation::Comparison(op) => self.compile_cmp(op, out, instructions),
482 cube::Operation::Bitwise(op) => self.compile_bitwise(op, out, instructions),
483 cube::Operation::Operator(op) => self.compile_operator(op, out, instructions),
484 cube::Operation::Atomic(op) => instructions.push(self.compile_atomic(op, out)),
485 cube::Operation::Metadata(op) => instructions.push(self.compile_metadata(op, out)),
486 cube::Operation::Branch(val) => self.compile_branch(instructions, val),
487 cube::Operation::Synchronization(val) => {
488 self.compile_synchronization(instructions, val)
489 }
490 cube::Operation::Plane(op) => self.compile_subgroup(instructions, op, out),
491 cube::Operation::CoopMma(_) => {
492 panic!("Cooperative matrix-multiply and accumulate isn't supported on wgpu.")
493 }
494 cube::Operation::NonSemantic(cube::NonSemantic::Comment { content }) => {
495 self.compile_comment(instructions, content)
496 }
497 cube::Operation::NonSemantic(_) => {}
498 cube::Operation::Barrier(_) => {
499 panic!("Barrier isn't supported on wgpu.")
500 }
501 cube::Operation::Tma(_) => panic!("TMA isn't supported on wgpu."),
502 cube::Operation::Marker(_) => {}
503 }
504 }
505
506 fn compile_subgroup(
507 &mut self,
508 instructions: &mut Vec<wgsl::Instruction>,
509 subgroup: cube::Plane,
510 out: Option<cube::Variable>,
511 ) {
512 self.subgroup_instructions_used = true;
513
514 let out = out.unwrap();
515 let op = match subgroup {
516 cube::Plane::Elect => Subgroup::Elect {
517 out: self.compile_variable(out),
518 },
519 cube::Plane::All(op) => Subgroup::All {
520 input: self.compile_variable(op.input),
521 out: self.compile_variable(out),
522 },
523 cube::Plane::Any(op) => Subgroup::Any {
524 input: self.compile_variable(op.input),
525 out: self.compile_variable(out),
526 },
527 cube::Plane::Ballot(op) => Subgroup::Ballot {
528 input: self.compile_variable(op.input),
529 out: self.compile_variable(out),
530 },
531
532 cube::Plane::Broadcast(op) => Subgroup::Broadcast {
533 lhs: self.compile_variable(op.lhs),
534 rhs: self.compile_variable(op.rhs),
535 out: self.compile_variable(out),
536 },
537
538 cube::Plane::Sum(op) => Subgroup::Sum {
539 input: self.compile_variable(op.input),
540 out: self.compile_variable(out),
541 },
542
543 cube::Plane::ExclusiveSum(op) => Subgroup::ExclusiveSum {
544 input: self.compile_variable(op.input),
545 out: self.compile_variable(out),
546 },
547 cube::Plane::InclusiveSum(op) => Subgroup::InclusiveSum {
548 input: self.compile_variable(op.input),
549 out: self.compile_variable(out),
550 },
551 cube::Plane::Prod(op) => Subgroup::Prod {
552 input: self.compile_variable(op.input),
553 out: self.compile_variable(out),
554 },
555 cube::Plane::ExclusiveProd(op) => Subgroup::ExclusiveProd {
556 input: self.compile_variable(op.input),
557 out: self.compile_variable(out),
558 },
559 cube::Plane::InclusiveProd(op) => Subgroup::InclusiveProd {
560 input: self.compile_variable(op.input),
561 out: self.compile_variable(out),
562 },
563 cube::Plane::Min(op) => Subgroup::Min {
564 input: self.compile_variable(op.input),
565 out: self.compile_variable(out),
566 },
567 cube::Plane::Max(op) => Subgroup::Max {
568 input: self.compile_variable(op.input),
569 out: self.compile_variable(out),
570 },
571 cube::Plane::Shuffle(op) => Subgroup::Shuffle {
572 lhs: self.compile_variable(op.lhs),
573 rhs: self.compile_variable(op.rhs),
574 out: self.compile_variable(out),
575 },
576 cube::Plane::ShuffleXor(op) => Subgroup::ShuffleXor {
577 lhs: self.compile_variable(op.lhs),
578 rhs: self.compile_variable(op.rhs),
579 out: self.compile_variable(out),
580 },
581 cube::Plane::ShuffleUp(op) => Subgroup::ShuffleUp {
582 lhs: self.compile_variable(op.lhs),
583 rhs: self.compile_variable(op.rhs),
584 out: self.compile_variable(out),
585 },
586 cube::Plane::ShuffleDown(op) => Subgroup::ShuffleDown {
587 lhs: self.compile_variable(op.lhs),
588 rhs: self.compile_variable(op.rhs),
589 out: self.compile_variable(out),
590 },
591 };
592
593 instructions.push(wgsl::Instruction::Subgroup(op));
594 }
595
596 fn compile_branch(&mut self, instructions: &mut Vec<wgsl::Instruction>, branch: cube::Branch) {
597 match branch {
598 cube::Branch::If(mut op) => instructions.push(wgsl::Instruction::If {
599 cond: self.compile_variable(op.cond),
600 instructions: self.compile_scope(&mut op.scope),
601 }),
602 cube::Branch::IfElse(mut op) => instructions.push(wgsl::Instruction::IfElse {
603 cond: self.compile_variable(op.cond),
604 instructions_if: self.compile_scope(&mut op.scope_if),
605 instructions_else: self.compile_scope(&mut op.scope_else),
606 }),
607 cube::Branch::Switch(mut op) => instructions.push(wgsl::Instruction::Switch {
608 value: self.compile_variable(op.value),
609 instructions_default: self.compile_scope(&mut op.scope_default),
610 cases: op
611 .cases
612 .into_iter()
613 .map(|(val, mut scope)| {
614 (self.compile_variable(val), self.compile_scope(&mut scope))
615 })
616 .collect(),
617 }),
618 cube::Branch::Return => instructions.push(wgsl::Instruction::Return),
619 cube::Branch::Break => instructions.push(wgsl::Instruction::Break),
620 cube::Branch::RangeLoop(mut range_loop) => {
621 instructions.push(wgsl::Instruction::RangeLoop {
622 i: self.compile_variable(range_loop.i),
623 start: self.compile_variable(range_loop.start),
624 end: self.compile_variable(range_loop.end),
625 step: range_loop.step.map(|it| self.compile_variable(it)),
626 inclusive: range_loop.inclusive,
627 instructions: self.compile_scope(&mut range_loop.scope),
628 })
629 }
630 cube::Branch::Loop(mut op) => instructions.push(wgsl::Instruction::Loop {
631 instructions: self.compile_scope(&mut op.scope),
632 }),
633 };
634 }
635
636 fn compile_synchronization(
637 &mut self,
638 instructions: &mut Vec<wgsl::Instruction>,
639 synchronization: cube::Synchronization,
640 ) {
641 match synchronization {
642 cube::Synchronization::SyncCube => {
643 instructions.push(wgsl::Instruction::WorkgroupBarrier)
644 }
645 cube::Synchronization::SyncPlane => {
646 panic!("Synchronization within a plane is not supported in WGSL")
647 }
648 cube::Synchronization::SyncStorage => {
649 instructions.push(wgsl::Instruction::StorageBarrier)
650 }
651 cube::Synchronization::SyncAsyncProxyShared => panic!("TMA is not supported in WGSL"),
652 };
653 }
654
655 fn compile_comment(&mut self, instructions: &mut Vec<wgsl::Instruction>, content: String) {
656 instructions.push(wgsl::Instruction::Comment { content })
657 }
658
659 fn compile_metadata(
660 &mut self,
661 metadata: cube::Metadata,
662 out: Option<cube::Variable>,
663 ) -> wgsl::Instruction {
664 let out = out.unwrap();
665 match metadata {
666 cube::Metadata::Rank { var } => {
667 let position = self.ext_meta_pos(&var);
668 let offset = self.metadata.rank_index(position);
669 wgsl::Instruction::Metadata {
670 out: self.compile_variable(out),
671 info_offset: self.compile_variable(offset.into()),
672 }
673 }
674 cube::Metadata::Stride { dim, var } => {
675 let position = self.ext_meta_pos(&var);
676 let offset = self.metadata.stride_offset_index(position);
677 wgsl::Instruction::ExtendedMeta {
678 info_offset: self.compile_variable(offset.into()),
679 dim: self.compile_variable(dim),
680 out: self.compile_variable(out),
681 }
682 }
683 cube::Metadata::Shape { dim, var } => {
684 let position = self.ext_meta_pos(&var);
685 let offset = self.metadata.shape_offset_index(position);
686 wgsl::Instruction::ExtendedMeta {
687 info_offset: self.compile_variable(offset.into()),
688 dim: self.compile_variable(dim),
689 out: self.compile_variable(out),
690 }
691 }
692 cube::Metadata::Length { var } => match var.kind {
693 cube::VariableKind::GlobalInputArray(id) => {
694 let offset = self.metadata.len_index(id);
695 wgsl::Instruction::Metadata {
696 out: self.compile_variable(out),
697 info_offset: self.compile_variable(offset.into()),
698 }
699 }
700 cube::VariableKind::GlobalOutputArray(id) => {
701 let offset = self.metadata.len_index(id);
702 wgsl::Instruction::Metadata {
703 out: self.compile_variable(out),
704 info_offset: self.compile_variable(offset.into()),
705 }
706 }
707 _ => wgsl::Instruction::Length {
708 var: self.compile_variable(var),
709 out: self.compile_variable(out),
710 },
711 },
712 cube::Metadata::BufferLength { var } => match var.kind {
713 cube::VariableKind::GlobalInputArray(id) => {
714 let offset = self.metadata.buffer_len_index(id);
715 wgsl::Instruction::Metadata {
716 out: self.compile_variable(out),
717 info_offset: self.compile_variable(offset.into()),
718 }
719 }
720 cube::VariableKind::GlobalOutputArray(id) => {
721 let offset = self.metadata.buffer_len_index(id);
722 wgsl::Instruction::Metadata {
723 out: self.compile_variable(out),
724 info_offset: self.compile_variable(offset.into()),
725 }
726 }
727 _ => wgsl::Instruction::Length {
728 var: self.compile_variable(var),
729 out: self.compile_variable(out),
730 },
731 },
732 }
733 }
734
735 fn compile_arithmetic(
736 &mut self,
737 value: cube::Arithmetic,
738 out: Option<cube::Variable>,
739 instructions: &mut Vec<wgsl::Instruction>,
740 scope: &mut Scope,
741 ) {
742 let out = out.unwrap();
743 match value {
744 cube::Arithmetic::Max(op) => instructions.push(wgsl::Instruction::Max {
745 lhs: self.compile_variable(op.lhs),
746 rhs: self.compile_variable(op.rhs),
747 out: self.compile_variable(out),
748 }),
749 cube::Arithmetic::Min(op) => instructions.push(wgsl::Instruction::Min {
750 lhs: self.compile_variable(op.lhs),
751 rhs: self.compile_variable(op.rhs),
752 out: self.compile_variable(out),
753 }),
754 cube::Arithmetic::Add(op) => instructions.push(wgsl::Instruction::Add {
755 lhs: self.compile_variable(op.lhs),
756 rhs: self.compile_variable(op.rhs),
757 out: self.compile_variable(out),
758 }),
759 cube::Arithmetic::SaturatingAdd(_) => {
760 unreachable!("Saturating add should be removed by processor");
761 }
762 cube::Arithmetic::Fma(op) => instructions.push(wgsl::Instruction::Fma {
763 a: self.compile_variable(op.a),
764 b: self.compile_variable(op.b),
765 c: self.compile_variable(op.c),
766 out: self.compile_variable(out),
767 }),
768 cube::Arithmetic::Modulo(op) => instructions.push(wgsl::Instruction::Modulo {
769 lhs: self.compile_variable(op.lhs),
770 rhs: self.compile_variable(op.rhs),
771 out: self.compile_variable(out),
772 }),
773 cube::Arithmetic::Sub(op) => instructions.push(wgsl::Instruction::Sub {
774 lhs: self.compile_variable(op.lhs),
775 rhs: self.compile_variable(op.rhs),
776 out: self.compile_variable(out),
777 }),
778 cube::Arithmetic::SaturatingSub(_) => {
779 unreachable!("Saturating sub should be removed by processor");
780 }
781 cube::Arithmetic::Mul(op) => instructions.push(wgsl::Instruction::Mul {
782 lhs: self.compile_variable(op.lhs),
783 rhs: self.compile_variable(op.rhs),
784 out: self.compile_variable(out),
785 }),
786 cube::Arithmetic::Div(op) => instructions.push(wgsl::Instruction::Div {
787 lhs: self.compile_variable(op.lhs),
788 rhs: self.compile_variable(op.rhs),
789 out: self.compile_variable(out),
790 }),
791 cube::Arithmetic::Abs(op) => instructions.push(wgsl::Instruction::Abs {
792 input: self.compile_variable(op.input),
793 out: self.compile_variable(out),
794 }),
795 cube::Arithmetic::Exp(op) => instructions.push(wgsl::Instruction::Exp {
796 input: self.compile_variable(op.input),
797 out: self.compile_variable(out),
798 }),
799 cube::Arithmetic::Log(op) => instructions.push(wgsl::Instruction::Log {
800 input: self.compile_variable(op.input),
801 out: self.compile_variable(out),
802 }),
803 cube::Arithmetic::Log1p(op) => instructions.push(wgsl::Instruction::Log1p {
804 input: self.compile_variable(op.input),
805 out: self.compile_variable(out),
806 }),
807 cube::Arithmetic::Cos(op) => instructions.push(wgsl::Instruction::Cos {
808 input: self.compile_variable(op.input),
809 out: self.compile_variable(out),
810 }),
811 cube::Arithmetic::Sin(op) => instructions.push(wgsl::Instruction::Sin {
812 input: self.compile_variable(op.input),
813 out: self.compile_variable(out),
814 }),
815 cube::Arithmetic::Tan(op) => instructions.push(wgsl::Instruction::Tan {
816 input: self.compile_variable(op.input),
817 out: self.compile_variable(out),
818 }),
819 cube::Arithmetic::Tanh(op) => instructions.push(wgsl::Instruction::Tanh {
820 input: self.compile_variable(op.input),
821 out: self.compile_variable(out),
822 }),
823 cube::Arithmetic::Sinh(op) => instructions.push(wgsl::Instruction::Sinh {
824 input: self.compile_variable(op.input),
825 out: self.compile_variable(out),
826 }),
827 cube::Arithmetic::Cosh(op) => instructions.push(wgsl::Instruction::Cosh {
828 input: self.compile_variable(op.input),
829 out: self.compile_variable(out),
830 }),
831 cube::Arithmetic::ArcCos(op) => instructions.push(wgsl::Instruction::ArcCos {
832 input: self.compile_variable(op.input),
833 out: self.compile_variable(out),
834 }),
835 cube::Arithmetic::ArcSin(op) => instructions.push(wgsl::Instruction::ArcSin {
836 input: self.compile_variable(op.input),
837 out: self.compile_variable(out),
838 }),
839 cube::Arithmetic::ArcTan(op) => instructions.push(wgsl::Instruction::ArcTan {
840 input: self.compile_variable(op.input),
841 out: self.compile_variable(out),
842 }),
843 cube::Arithmetic::ArcSinh(op) => instructions.push(wgsl::Instruction::ArcSinh {
844 input: self.compile_variable(op.input),
845 out: self.compile_variable(out),
846 }),
847 cube::Arithmetic::ArcCosh(op) => instructions.push(wgsl::Instruction::ArcCosh {
848 input: self.compile_variable(op.input),
849 out: self.compile_variable(out),
850 }),
851 cube::Arithmetic::ArcTanh(op) => instructions.push(wgsl::Instruction::ArcTanh {
852 input: self.compile_variable(op.input),
853 out: self.compile_variable(out),
854 }),
855 cube::Arithmetic::Degrees(op) => instructions.push(wgsl::Instruction::Degrees {
856 input: self.compile_variable(op.input),
857 out: self.compile_variable(out),
858 }),
859 cube::Arithmetic::Radians(op) => instructions.push(wgsl::Instruction::Radians {
860 input: self.compile_variable(op.input),
861 out: self.compile_variable(out),
862 }),
863 cube::Arithmetic::ArcTan2(op) => instructions.push(wgsl::Instruction::ArcTan2 {
864 lhs: self.compile_variable(op.lhs),
865 rhs: self.compile_variable(op.rhs),
866 out: self.compile_variable(out),
867 }),
868 cube::Arithmetic::Powf(op) | cube::Arithmetic::Powi(op) => {
870 instructions.push(wgsl::Instruction::Powf {
871 lhs: self.compile_variable(op.lhs),
872 rhs: self.compile_variable(op.rhs),
873 out: self.compile_variable(out),
874 })
875 }
876 cube::Arithmetic::Hypot(op) => {
877 let mut scope = scope.child();
878 expand_hypot(&mut scope, op.lhs, op.rhs, out);
879 instructions.extend(self.compile_scope(&mut scope));
880 }
881 cube::Arithmetic::Rhypot(op) => {
882 let mut scope = scope.child();
883 expand_rhypot(&mut scope, op.lhs, op.rhs, out);
884 instructions.extend(self.compile_scope(&mut scope));
885 }
886
887 cube::Arithmetic::Sqrt(op) => instructions.push(wgsl::Instruction::Sqrt {
888 input: self.compile_variable(op.input),
889 out: self.compile_variable(out),
890 }),
891 cube::Arithmetic::InverseSqrt(op) => {
892 instructions.push(wgsl::Instruction::InverseSqrt {
893 input: self.compile_variable(op.input),
894 out: self.compile_variable(out),
895 })
896 }
897 cube::Arithmetic::Round(op) => instructions.push(wgsl::Instruction::Round {
898 input: self.compile_variable(op.input),
899 out: self.compile_variable(out),
900 }),
901 cube::Arithmetic::Floor(op) => instructions.push(wgsl::Instruction::Floor {
902 input: self.compile_variable(op.input),
903 out: self.compile_variable(out),
904 }),
905 cube::Arithmetic::Ceil(op) => instructions.push(wgsl::Instruction::Ceil {
906 input: self.compile_variable(op.input),
907 out: self.compile_variable(out),
908 }),
909 cube::Arithmetic::Trunc(op) => instructions.push(wgsl::Instruction::Trunc {
910 input: self.compile_variable(op.input),
911 out: self.compile_variable(out),
912 }),
913 cube::Arithmetic::Erf(op) => {
914 let mut scope = scope.child();
915 expand_erf(&mut scope, op.input, out);
916 instructions.extend(self.compile_scope(&mut scope));
917 }
918 cube::Arithmetic::MulHi(op) => {
919 let mut scope = scope.child();
920 match self.compilation_options.supports_u64 {
921 true => expand_himul_64(&mut scope, op.lhs, op.rhs, out),
922 false => expand_himul_sim(&mut scope, op.lhs, op.rhs, out),
923 }
924 instructions.extend(self.compile_scope(&mut scope));
925 }
926 cube::Arithmetic::Recip(op) => instructions.push(wgsl::Instruction::Recip {
927 input: self.compile_variable(op.input),
928 out: self.compile_variable(out),
929 }),
930 cube::Arithmetic::Clamp(op) => instructions.push(wgsl::Instruction::Clamp {
931 input: self.compile_variable(op.input),
932 min_value: self.compile_variable(op.min_value),
933 max_value: self.compile_variable(op.max_value),
934 out: self.compile_variable(out),
935 }),
936 cube::Arithmetic::Remainder(op) => instructions.push(wgsl::Instruction::Remainder {
937 lhs: self.compile_variable(op.lhs),
938 rhs: self.compile_variable(op.rhs),
939 out: self.compile_variable(out),
940 }),
941 cube::Arithmetic::Neg(op) => instructions.push(wgsl::Instruction::Negate {
942 input: self.compile_variable(op.input),
943 out: self.compile_variable(out),
944 }),
945 cube::Arithmetic::Magnitude(op) => instructions.push(wgsl::Instruction::Magnitude {
946 input: self.compile_variable(op.input),
947 out: self.compile_variable(out),
948 }),
949 cube::Arithmetic::Normalize(op) => instructions.push(wgsl::Instruction::Normalize {
950 input: self.compile_variable(op.input),
951 out: self.compile_variable(out),
952 }),
953 cube::Arithmetic::Dot(op) => instructions.push(wgsl::Instruction::Dot {
954 lhs: self.compile_variable(op.lhs),
955 rhs: self.compile_variable(op.rhs),
956 out: self.compile_variable(out),
957 }),
958 }
959 }
960
961 fn compile_cmp(
962 &mut self,
963 value: cube::Comparison,
964 out: Option<cube::Variable>,
965 instructions: &mut Vec<wgsl::Instruction>,
966 ) {
967 let out = out.unwrap();
968 match value {
969 cube::Comparison::Equal(op) => instructions.push(wgsl::Instruction::Equal {
970 lhs: self.compile_variable(op.lhs),
971 rhs: self.compile_variable(op.rhs),
972 out: self.compile_variable(out),
973 }),
974 cube::Comparison::Lower(op) => instructions.push(wgsl::Instruction::Lower {
975 lhs: self.compile_variable(op.lhs),
976 rhs: self.compile_variable(op.rhs),
977 out: self.compile_variable(out),
978 }),
979 cube::Comparison::Greater(op) => instructions.push(wgsl::Instruction::Greater {
980 lhs: self.compile_variable(op.lhs),
981 rhs: self.compile_variable(op.rhs),
982 out: self.compile_variable(out),
983 }),
984 cube::Comparison::LowerEqual(op) => instructions.push(wgsl::Instruction::LowerEqual {
985 lhs: self.compile_variable(op.lhs),
986 rhs: self.compile_variable(op.rhs),
987 out: self.compile_variable(out),
988 }),
989 cube::Comparison::GreaterEqual(op) => {
990 instructions.push(wgsl::Instruction::GreaterEqual {
991 lhs: self.compile_variable(op.lhs),
992 rhs: self.compile_variable(op.rhs),
993 out: self.compile_variable(out),
994 })
995 }
996 cube::Comparison::NotEqual(op) => instructions.push(wgsl::Instruction::NotEqual {
997 lhs: self.compile_variable(op.lhs),
998 rhs: self.compile_variable(op.rhs),
999 out: self.compile_variable(out),
1000 }),
1001 cube::Comparison::IsNan(op) => instructions.push(wgsl::Instruction::IsNan {
1002 input: self.compile_variable(op.input),
1003 out: self.compile_variable(out),
1004 }),
1005 cube::Comparison::IsInf(op) => instructions.push(wgsl::Instruction::IsInf {
1006 input: self.compile_variable(op.input),
1007 out: self.compile_variable(out),
1008 }),
1009 }
1010 }
1011
1012 fn compile_bitwise(
1013 &mut self,
1014 value: cube::Bitwise,
1015 out: Option<cube::Variable>,
1016 instructions: &mut Vec<wgsl::Instruction>,
1017 ) {
1018 let out = out.unwrap();
1019 match value {
1020 cube::Bitwise::BitwiseOr(op) => instructions.push(wgsl::Instruction::BitwiseOr {
1021 lhs: self.compile_variable(op.lhs),
1022 rhs: self.compile_variable(op.rhs),
1023 out: self.compile_variable(out),
1024 }),
1025 cube::Bitwise::BitwiseAnd(op) => instructions.push(wgsl::Instruction::BitwiseAnd {
1026 lhs: self.compile_variable(op.lhs),
1027 rhs: self.compile_variable(op.rhs),
1028 out: self.compile_variable(out),
1029 }),
1030 cube::Bitwise::BitwiseXor(op) => instructions.push(wgsl::Instruction::BitwiseXor {
1031 lhs: self.compile_variable(op.lhs),
1032 rhs: self.compile_variable(op.rhs),
1033 out: self.compile_variable(out),
1034 }),
1035 cube::Bitwise::CountOnes(op) => instructions.push(wgsl::Instruction::CountBits {
1036 input: self.compile_variable(op.input),
1037 out: self.compile_variable(out),
1038 }),
1039 cube::Bitwise::ReverseBits(op) => instructions.push(wgsl::Instruction::ReverseBits {
1040 input: self.compile_variable(op.input),
1041 out: self.compile_variable(out),
1042 }),
1043 cube::Bitwise::ShiftLeft(op) => instructions.push(wgsl::Instruction::ShiftLeft {
1044 lhs: self.compile_variable(op.lhs),
1045 rhs: self.compile_variable(op.rhs),
1046 out: self.compile_variable(out),
1047 }),
1048 cube::Bitwise::ShiftRight(op) => instructions.push(wgsl::Instruction::ShiftRight {
1049 lhs: self.compile_variable(op.lhs),
1050 rhs: self.compile_variable(op.rhs),
1051 out: self.compile_variable(out),
1052 }),
1053 cube::Bitwise::BitwiseNot(op) => instructions.push(wgsl::Instruction::BitwiseNot {
1054 input: self.compile_variable(op.input),
1055 out: self.compile_variable(out),
1056 }),
1057 cube::Bitwise::LeadingZeros(op) => instructions.push(wgsl::Instruction::LeadingZeros {
1058 input: self.compile_variable(op.input),
1059 out: self.compile_variable(out),
1060 }),
1061 cube::Bitwise::FindFirstSet(op) => instructions.push(wgsl::Instruction::FindFirstSet {
1062 input: self.compile_variable(op.input),
1063 out: self.compile_variable(out),
1064 }),
1065 }
1066 }
1067
1068 fn compile_operator(
1069 &mut self,
1070 value: cube::Operator,
1071 out: Option<cube::Variable>,
1072 instructions: &mut Vec<wgsl::Instruction>,
1073 ) {
1074 let out = out.unwrap();
1075 match value {
1076 cube::Operator::Cast(op) => instructions.push(wgsl::Instruction::Assign {
1077 input: self.compile_variable(op.input),
1078 out: self.compile_variable(out),
1079 }),
1080 cube::Operator::Index(op) | cube::Operator::UncheckedIndex(op) => {
1081 instructions.push(wgsl::Instruction::Index {
1082 lhs: self.compile_variable(op.list),
1083 rhs: self.compile_variable(op.index),
1084 out: self.compile_variable(out),
1085 });
1086 }
1087 cube::Operator::IndexAssign(op) | cube::Operator::UncheckedIndexAssign(op) => {
1088 instructions.push(wgsl::Instruction::IndexAssign {
1089 index: self.compile_variable(op.index),
1090 rhs: self.compile_variable(op.value),
1091 out: self.compile_variable(out),
1092 })
1093 }
1094 cube::Operator::And(op) => instructions.push(wgsl::Instruction::And {
1095 lhs: self.compile_variable(op.lhs),
1096 rhs: self.compile_variable(op.rhs),
1097 out: self.compile_variable(out),
1098 }),
1099 cube::Operator::Or(op) => instructions.push(wgsl::Instruction::Or {
1100 lhs: self.compile_variable(op.lhs),
1101 rhs: self.compile_variable(op.rhs),
1102 out: self.compile_variable(out),
1103 }),
1104 cube::Operator::Not(op) => instructions.push(wgsl::Instruction::Not {
1105 input: self.compile_variable(op.input),
1106 out: self.compile_variable(out),
1107 }),
1108 cube::Operator::Reinterpret(op) => instructions.push(wgsl::Instruction::Bitcast {
1109 input: self.compile_variable(op.input),
1110 out: self.compile_variable(out),
1111 }),
1112 cube::Operator::InitLine(op) => instructions.push(wgsl::Instruction::VecInit {
1113 inputs: op
1114 .inputs
1115 .into_iter()
1116 .map(|var| self.compile_variable(var))
1117 .collect(),
1118 out: self.compile_variable(out),
1119 }),
1120 cube::Operator::CopyMemory(op) => instructions.push(wgsl::Instruction::Copy {
1121 input: self.compile_variable(op.input),
1122 in_index: self.compile_variable(op.in_index),
1123 out: self.compile_variable(out),
1124 out_index: self.compile_variable(op.out_index),
1125 }),
1126 cube::Operator::CopyMemoryBulk(op) => instructions.push(wgsl::Instruction::CopyBulk {
1127 input: self.compile_variable(op.input),
1128 in_index: self.compile_variable(op.in_index),
1129 out: self.compile_variable(out),
1130 out_index: self.compile_variable(op.out_index),
1131 len: op.len,
1132 }),
1133 cube::Operator::Select(op) => instructions.push(wgsl::Instruction::Select {
1134 cond: self.compile_variable(op.cond),
1135 then: self.compile_variable(op.then),
1136 or_else: self.compile_variable(op.or_else),
1137 out: self.compile_variable(out),
1138 }),
1139 }
1140 }
1141
1142 fn compile_atomic(
1143 &mut self,
1144 atomic: cube::AtomicOp,
1145 out: Option<cube::Variable>,
1146 ) -> wgsl::Instruction {
1147 let out = out.unwrap();
1148 match atomic {
1149 cube::AtomicOp::Add(op) => wgsl::Instruction::AtomicAdd {
1150 lhs: self.compile_variable(op.lhs),
1151 rhs: self.compile_variable(op.rhs),
1152 out: self.compile_variable(out),
1153 },
1154 cube::AtomicOp::Sub(op) => wgsl::Instruction::AtomicSub {
1155 lhs: self.compile_variable(op.lhs),
1156 rhs: self.compile_variable(op.rhs),
1157 out: self.compile_variable(out),
1158 },
1159 cube::AtomicOp::Max(op) => wgsl::Instruction::AtomicMax {
1160 lhs: self.compile_variable(op.lhs),
1161 rhs: self.compile_variable(op.rhs),
1162 out: self.compile_variable(out),
1163 },
1164 cube::AtomicOp::Min(op) => wgsl::Instruction::AtomicMin {
1165 lhs: self.compile_variable(op.lhs),
1166 rhs: self.compile_variable(op.rhs),
1167 out: self.compile_variable(out),
1168 },
1169 cube::AtomicOp::And(op) => wgsl::Instruction::AtomicAnd {
1170 lhs: self.compile_variable(op.lhs),
1171 rhs: self.compile_variable(op.rhs),
1172 out: self.compile_variable(out),
1173 },
1174 cube::AtomicOp::Or(op) => wgsl::Instruction::AtomicOr {
1175 lhs: self.compile_variable(op.lhs),
1176 rhs: self.compile_variable(op.rhs),
1177 out: self.compile_variable(out),
1178 },
1179 cube::AtomicOp::Xor(op) => wgsl::Instruction::AtomicXor {
1180 lhs: self.compile_variable(op.lhs),
1181 rhs: self.compile_variable(op.rhs),
1182 out: self.compile_variable(out),
1183 },
1184 cube::AtomicOp::Load(op) => wgsl::Instruction::AtomicLoad {
1185 input: self.compile_variable(op.input),
1186 out: self.compile_variable(out),
1187 },
1188 cube::AtomicOp::Store(op) => wgsl::Instruction::AtomicStore {
1189 input: self.compile_variable(op.input),
1190 out: self.compile_variable(out),
1191 },
1192 cube::AtomicOp::Swap(op) => wgsl::Instruction::AtomicSwap {
1193 lhs: self.compile_variable(op.lhs),
1194 rhs: self.compile_variable(op.rhs),
1195 out: self.compile_variable(out),
1196 },
1197 cube::AtomicOp::CompareAndSwap(op) => wgsl::Instruction::AtomicCompareExchangeWeak {
1198 lhs: self.compile_variable(op.input),
1199 cmp: self.compile_variable(op.cmp),
1200 value: self.compile_variable(op.val),
1201 out: self.compile_variable(out),
1202 },
1203 }
1204 }
1205
1206 fn compile_location(value: kernel::Location) -> wgsl::Location {
1207 match value {
1208 kernel::Location::Storage => wgsl::Location::Storage,
1209 kernel::Location::Cube => wgsl::Location::Workgroup,
1210 }
1211 }
1212
1213 fn compile_binding(&mut self, value: kernel::Binding) -> wgsl::Binding {
1214 wgsl::Binding {
1215 id: value.id,
1216 visibility: value.visibility,
1217 location: Self::compile_location(value.location),
1218 item: self.compile_type(value.ty),
1219 size: value.size,
1220 }
1221 }
1222}
1223
1224fn register_extensions(instructions: &[wgsl::Instruction]) -> Vec<wgsl::Extension> {
1225 let mut extensions = Vec::new();
1226
1227 let mut register_extension = |extension: wgsl::Extension| {
1228 if !extensions.contains(&extension) {
1229 extensions.push(extension);
1230 }
1231 };
1232
1233 for instruction in instructions {
1235 match instruction {
1236 wgsl::Instruction::Powf { lhs: _, rhs, out } => {
1237 register_extension(wgsl::Extension::PowfPrimitive(out.elem()));
1238 register_extension(wgsl::powf_extension(rhs, out));
1239 }
1240 #[cfg(target_os = "macos")]
1241 wgsl::Instruction::Tanh { input, out: _ } => {
1242 register_extension(wgsl::Extension::SafeTanhPrimitive(input.elem()));
1243 register_extension(wgsl::Extension::SafeTanh(input.item()));
1244 }
1245 wgsl::Instruction::IsNan { input, out } => {
1246 register_extension(wgsl::Extension::IsNanPrimitive(input.elem()));
1247 register_extension(wgsl::Extension::IsNan(input.item(), out.item()));
1248 }
1249 wgsl::Instruction::IsInf { input, out } => {
1250 register_extension(wgsl::Extension::IsInfPrimitive(input.elem()));
1251 register_extension(wgsl::Extension::IsInf(input.item(), out.item()));
1252 }
1253 wgsl::Instruction::If { instructions, .. } => {
1254 for extension in register_extensions(instructions) {
1255 register_extension(extension);
1256 }
1257 }
1258 wgsl::Instruction::IfElse {
1259 instructions_if,
1260 instructions_else,
1261 ..
1262 } => {
1263 for extension in register_extensions(instructions_if) {
1264 register_extension(extension);
1265 }
1266 for extension in register_extensions(instructions_else) {
1267 register_extension(extension);
1268 }
1269 }
1270 wgsl::Instruction::Loop { instructions } => {
1271 for extension in register_extensions(instructions) {
1272 register_extension(extension);
1273 }
1274 }
1275 wgsl::Instruction::RangeLoop { instructions, .. } => {
1276 for extension in register_extensions(instructions) {
1277 register_extension(extension);
1278 }
1279 }
1280 _ => {}
1281 }
1282 }
1283
1284 extensions
1285}