1use std::{borrow::Cow, sync::Arc};
2
3use super::Subgroup;
4use super::{shader::ComputeShader, ConstantArray};
5use super::{Item, LocalArray, SharedMemory};
6use crate::{
7 compiler::{base::WgpuCompiler, wgsl},
8 WgpuServer,
9};
10
11use cubecl_core::ir::expand_checked_index_assign;
12use cubecl_core::{
13 ir::{self as cube, UIntKind},
14 prelude::CompiledKernel,
15 server::ComputeServer,
16 Feature, Metadata,
17};
18use cubecl_runtime::{DeviceProperties, ExecutionMode};
19use wgpu::{
20 BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingType, BufferBindingType,
21 ComputePipeline, DeviceDescriptor, PipelineLayoutDescriptor, ShaderModuleDescriptor,
22 ShaderStages,
23};
24
/// Options handed to the WGSL compiler when a kernel is compiled.
///
/// The struct currently declares no fields.
#[derive(Debug, Default, Clone)]
pub struct CompilationOptions {}
27
28#[derive(Clone, Default)]
30pub struct WgslCompiler {
31 num_inputs: usize,
32 num_outputs: usize,
33 metadata: Metadata,
34 ext_meta_pos: Vec<u32>,
35 local_invocation_index: bool,
36 local_invocation_id: bool,
37 global_invocation_id: bool,
38 workgroup_id: bool,
39 subgroup_size: bool,
40 subgroup_invocation_id: bool,
41 id: bool,
42 num_workgroups: bool,
43 workgroup_id_no_axis: bool,
44 workgroup_size_no_axis: bool,
45 num_workgroup_no_axis: bool,
46 shared_memories: Vec<SharedMemory>,
47 const_arrays: Vec<ConstantArray>,
48 local_arrays: Vec<LocalArray>,
49 #[allow(dead_code)]
50 compilation_options: CompilationOptions,
51 strategy: ExecutionMode,
52 subgroup_instructions_used: bool,
53}
54
55impl core::fmt::Debug for WgslCompiler {
56 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 f.write_str("WgslCompiler")
58 }
59}
60
61impl cubecl_core::Compiler for WgslCompiler {
62 type Representation = ComputeShader;
63 type CompilationOptions = CompilationOptions;
64
65 fn compile(
66 shader: cube::KernelDefinition,
67 compilation_options: &Self::CompilationOptions,
68 mode: ExecutionMode,
69 ) -> Self::Representation {
70 let mut compiler = Self {
71 compilation_options: compilation_options.clone(),
72 ..Self::default()
73 };
74 compiler.compile_shader(shader, mode)
75 }
76
77 fn elem_size(elem: cube::Elem) -> usize {
78 Self::compile_elem(elem).size()
79 }
80
81 fn max_shared_memory_size() -> usize {
82 32768
83 }
84}
85
86impl WgpuCompiler for WgslCompiler {
87 fn create_pipeline(
88 server: &mut WgpuServer<Self>,
89 kernel: CompiledKernel<Self>,
90 ) -> Arc<ComputePipeline> {
91 let source = &kernel.source;
92 let module = unsafe {
100 server
101 .device
102 .create_shader_module_unchecked(ShaderModuleDescriptor {
103 label: None,
104 source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)),
105 })
106 };
107
108 let layout = kernel.repr.map(|repr| {
109 let bindings = repr
110 .inputs
111 .iter()
112 .chain(repr.outputs.iter())
113 .chain(repr.named.iter().map(|it| &it.1))
114 .enumerate();
115 let bindings = bindings
116 .map(|(i, _binding)| BindGroupLayoutEntry {
117 binding: i as u32,
118 visibility: ShaderStages::COMPUTE,
119 ty: BindingType::Buffer {
120 #[cfg(not(exclusive_memory_only))]
121 ty: BufferBindingType::Storage { read_only: false },
122 #[cfg(exclusive_memory_only)]
123 ty: BufferBindingType::Storage {
124 read_only: matches!(
125 _binding.visibility,
126 super::shader::Visibility::Read
127 ),
128 },
129 has_dynamic_offset: false,
130 min_binding_size: None,
131 },
132 count: None,
133 })
134 .collect::<Vec<_>>();
135 let layout = server
136 .device
137 .create_bind_group_layout(&BindGroupLayoutDescriptor {
138 label: None,
139 entries: &bindings,
140 });
141 server
142 .device
143 .create_pipeline_layout(&PipelineLayoutDescriptor {
144 label: None,
145 bind_group_layouts: &[&layout],
146 push_constant_ranges: &[],
147 })
148 });
149
150 Arc::new(
151 server
152 .device
153 .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
154 label: None,
155 layout: layout.as_ref(),
156 module: &module,
157 entry_point: Some(&kernel.entrypoint_name),
158 compilation_options: wgpu::PipelineCompilationOptions {
159 zero_initialize_workgroup_memory: false,
160 ..Default::default()
161 },
162 cache: None,
163 }),
164 )
165 }
166
167 fn compile(
168 server: &mut WgpuServer<Self>,
169 kernel: <WgpuServer<Self> as ComputeServer>::Kernel,
170 mode: ExecutionMode,
171 ) -> CompiledKernel<Self> {
172 kernel.compile(&server.compilation_options, mode)
173 }
174
175 async fn request_device(adapter: &wgpu::Adapter) -> (wgpu::Device, wgpu::Queue) {
176 let limits = adapter.limits();
177 adapter
178 .request_device(
179 &DeviceDescriptor {
180 label: None,
181 required_features: adapter.features(),
182 required_limits: limits,
183 memory_hints: wgpu::MemoryHints::MemoryUsage,
187 },
188 None,
189 )
190 .await
191 .map_err(|err| {
192 format!(
193 "Unable to request the device with the adapter {:?}, err {:?}",
194 adapter.get_info(),
195 err
196 )
197 })
198 .unwrap()
199 }
200
201 fn register_features(
202 _adapter: &wgpu::Adapter,
203 _device: &wgpu::Device,
204 props: &mut DeviceProperties<Feature>,
205 ) {
206 register_types(props);
207 }
208}
209
210fn register_types(props: &mut DeviceProperties<Feature>) {
211 use cubecl_core::ir::{Elem, FloatKind, IntKind};
212
213 let supported_types = [
214 Elem::UInt(UIntKind::U32),
215 Elem::Int(IntKind::I32),
216 Elem::AtomicInt(IntKind::I32),
217 Elem::AtomicUInt(UIntKind::U32),
218 Elem::Float(FloatKind::F32),
219 Elem::Float(FloatKind::Flex32),
220 Elem::Bool,
221 ];
222
223 for ty in supported_types {
224 props.register_feature(Feature::Type(ty));
225 }
226}
227
228impl WgslCompiler {
229 fn compile_shader(
230 &mut self,
231 mut value: cube::KernelDefinition,
232 mode: ExecutionMode,
233 ) -> wgsl::ComputeShader {
234 self.strategy = mode;
235
236 self.num_inputs = value.inputs.len();
237 self.num_outputs = value.outputs.len();
238 let num_meta = value.inputs.len() + value.outputs.len();
239
240 self.ext_meta_pos = Vec::new();
241 let mut num_ext = 0;
242
243 for binding in value.inputs.iter().chain(value.outputs.iter()) {
244 self.ext_meta_pos.push(num_ext);
245 if binding.has_extended_meta {
246 num_ext += 1;
247 }
248 }
249
250 self.metadata = Metadata::new(num_meta as u32, num_ext);
251
252 let instructions = self.compile_scope(&mut value.body);
253 let extensions = register_extensions(&instructions);
254 let body = wgsl::Body {
255 instructions,
256 id: self.id,
257 };
258
259 wgsl::ComputeShader {
260 inputs: value
261 .inputs
262 .into_iter()
263 .map(Self::compile_binding)
264 .collect(),
265 outputs: value
266 .outputs
267 .into_iter()
268 .map(Self::compile_binding)
269 .collect(),
270 named: value
271 .named
272 .into_iter()
273 .map(|(name, binding)| (name, Self::compile_binding(binding)))
274 .collect(),
275 shared_memories: self.shared_memories.clone(),
276 constant_arrays: self.const_arrays.clone(),
277 local_arrays: self.local_arrays.clone(),
278 workgroup_size: value.cube_dim,
279 global_invocation_id: self.global_invocation_id || self.id,
280 local_invocation_index: self.local_invocation_index,
281 local_invocation_id: self.local_invocation_id,
282 num_workgroups: self.id
283 || self.num_workgroups
284 || self.num_workgroup_no_axis
285 || self.workgroup_id_no_axis,
286 workgroup_id: self.workgroup_id || self.workgroup_id_no_axis,
287 subgroup_size: self.subgroup_size,
288 subgroup_invocation_id: self.subgroup_invocation_id,
289 body,
290 extensions,
291 num_workgroups_no_axis: self.num_workgroup_no_axis,
292 workgroup_id_no_axis: self.workgroup_id_no_axis,
293 workgroup_size_no_axis: self.workgroup_size_no_axis,
294 subgroup_instructions_used: self.subgroup_instructions_used,
295 kernel_name: value.kernel_name,
296 }
297 }
298
299 fn compile_item(item: cube::Item) -> Item {
300 let elem = Self::compile_elem(item.elem);
301 match item.vectorization.map(|it| it.get()).unwrap_or(1) {
302 1 => wgsl::Item::Scalar(elem),
303 2 => wgsl::Item::Vec2(elem),
304 3 => wgsl::Item::Vec3(elem),
305 4 => wgsl::Item::Vec4(elem),
306 _ => panic!("Unsupported vectorizations scheme {:?}", item.vectorization),
307 }
308 }
309
310 fn compile_elem(value: cube::Elem) -> wgsl::Elem {
311 match value {
312 cube::Elem::Float(f) => match f {
313 cube::FloatKind::F16 => panic!("f16 is not yet supported"),
314 cube::FloatKind::BF16 => panic!("bf16 is not a valid WgpuElement"),
315 cube::FloatKind::TF32 => panic!("tf32 is not a valid WgpuElement"),
316 cube::FloatKind::Flex32 => wgsl::Elem::F32,
317 cube::FloatKind::F32 => wgsl::Elem::F32,
318 cube::FloatKind::F64 => panic!("f64 is not a valid WgpuElement"),
319 },
320 cube::Elem::Int(i) => match i {
321 cube::IntKind::I32 => wgsl::Elem::I32,
322 kind => panic!("{kind:?} is not a valid WgpuElement"),
323 },
324 cube::Elem::UInt(kind) => match kind {
325 cube::UIntKind::U32 => wgsl::Elem::U32,
326 kind => panic!("{kind:?} is not a valid WgpuElement"),
327 },
328 cube::Elem::Bool => wgsl::Elem::Bool,
329 cube::Elem::AtomicFloat(i) => match i {
330 cube::FloatKind::F32 => wgsl::Elem::AtomicF32,
331 kind => panic!("atomic<{kind:?}> is not a valid WgpuElement"),
332 },
333 cube::Elem::AtomicInt(i) => match i {
334 cube::IntKind::I32 => wgsl::Elem::AtomicI32,
335 kind => panic!("atomic<{kind:?}> is not a valid WgpuElement"),
336 },
337 cube::Elem::AtomicUInt(kind) => match kind {
338 cube::UIntKind::U32 => wgsl::Elem::AtomicU32,
339 kind => panic!("{kind:?} is not a valid WgpuElement"),
340 },
341 }
342 }
343
344 fn ext_meta_pos(&self, var: &cube::Variable) -> u32 {
345 let pos = match var.kind {
346 cube::VariableKind::GlobalInputArray(id) => id as usize,
347 cube::VariableKind::GlobalOutputArray(id) => self.num_inputs + id as usize,
348 _ => panic!("Only global arrays have metadata"),
349 };
350 self.ext_meta_pos[pos]
351 }
352
353 pub(crate) fn compile_variable(&mut self, value: cube::Variable) -> wgsl::Variable {
354 let item = value.item;
355 match value.kind {
356 cube::VariableKind::GlobalInputArray(id) => {
357 wgsl::Variable::GlobalInputArray(id, Self::compile_item(item))
358 }
359 cube::VariableKind::GlobalScalar(id) => {
360 wgsl::Variable::GlobalScalar(id, Self::compile_elem(item.elem), item.elem)
361 }
362 cube::VariableKind::LocalMut { id } | cube::VariableKind::Versioned { id, .. } => {
363 wgsl::Variable::LocalMut {
364 id,
365 item: Self::compile_item(item),
366 }
367 }
368 cube::VariableKind::LocalConst { id } => wgsl::Variable::LocalConst {
369 id,
370 item: Self::compile_item(item),
371 },
372 cube::VariableKind::Slice { id } => wgsl::Variable::Slice {
373 id,
374 item: Self::compile_item(item),
375 },
376 cube::VariableKind::GlobalOutputArray(id) => {
377 wgsl::Variable::GlobalOutputArray(id, Self::compile_item(item))
378 }
379 cube::VariableKind::ConstantScalar(value) => {
380 wgsl::Variable::ConstantScalar(value, Self::compile_elem(value.elem()))
381 }
382 cube::VariableKind::SharedMemory { id, length } => {
383 let item = Self::compile_item(item);
384 if !self.shared_memories.iter().any(|s| s.index == id) {
385 self.shared_memories
386 .push(SharedMemory::new(id, item, length));
387 }
388 wgsl::Variable::SharedMemory(id, item, length)
389 }
390 cube::VariableKind::ConstantArray { id, length } => {
391 let item = Self::compile_item(item);
392 wgsl::Variable::ConstantArray(id, item, length)
393 }
394 cube::VariableKind::LocalArray { id, length } => {
395 let item = Self::compile_item(item);
396 if !self.local_arrays.iter().any(|s| s.index == id) {
397 self.local_arrays.push(LocalArray::new(id, item, length));
398 }
399 wgsl::Variable::LocalArray(id, item, length)
400 }
401 cube::VariableKind::Builtin(builtin) => match builtin {
402 cube::Builtin::AbsolutePos => {
403 self.id = true;
404 wgsl::Variable::Id
405 }
406 cube::Builtin::UnitPos => {
407 self.local_invocation_index = true;
408 wgsl::Variable::LocalInvocationIndex
409 }
410 cube::Builtin::UnitPosX => {
411 self.local_invocation_id = true;
412 wgsl::Variable::LocalInvocationIdX
413 }
414 cube::Builtin::UnitPosY => {
415 self.local_invocation_id = true;
416 wgsl::Variable::LocalInvocationIdY
417 }
418 cube::Builtin::UnitPosZ => {
419 self.local_invocation_id = true;
420 wgsl::Variable::LocalInvocationIdZ
421 }
422 cube::Builtin::CubePosX => {
423 self.workgroup_id = true;
424 wgsl::Variable::WorkgroupIdX
425 }
426 cube::Builtin::CubePosY => {
427 self.workgroup_id = true;
428 wgsl::Variable::WorkgroupIdY
429 }
430 cube::Builtin::CubePosZ => {
431 self.workgroup_id = true;
432 wgsl::Variable::WorkgroupIdZ
433 }
434 cube::Builtin::AbsolutePosX => {
435 self.global_invocation_id = true;
436 wgsl::Variable::GlobalInvocationIdX
437 }
438 cube::Builtin::AbsolutePosY => {
439 self.global_invocation_id = true;
440 wgsl::Variable::GlobalInvocationIdY
441 }
442 cube::Builtin::AbsolutePosZ => {
443 self.global_invocation_id = true;
444 wgsl::Variable::GlobalInvocationIdZ
445 }
446 cube::Builtin::CubeDimX => wgsl::Variable::WorkgroupSizeX,
447 cube::Builtin::CubeDimY => wgsl::Variable::WorkgroupSizeY,
448 cube::Builtin::CubeDimZ => wgsl::Variable::WorkgroupSizeZ,
449 cube::Builtin::CubeCountX => {
450 self.num_workgroups = true;
451 wgsl::Variable::NumWorkgroupsX
452 }
453 cube::Builtin::CubeCountY => {
454 self.num_workgroups = true;
455 wgsl::Variable::NumWorkgroupsY
456 }
457 cube::Builtin::CubeCountZ => {
458 self.num_workgroups = true;
459 wgsl::Variable::NumWorkgroupsZ
460 }
461 cube::Builtin::CubePos => {
462 self.workgroup_id_no_axis = true;
463 wgsl::Variable::WorkgroupId
464 }
465 cube::Builtin::CubeDim => {
466 self.workgroup_size_no_axis = true;
467 wgsl::Variable::WorkgroupSize
468 }
469 cube::Builtin::CubeCount => {
470 self.num_workgroup_no_axis = true;
471 wgsl::Variable::NumWorkgroups
472 }
473 cube::Builtin::PlaneDim => {
474 self.subgroup_size = true;
475 wgsl::Variable::SubgroupSize
476 }
477 cube::Builtin::UnitPosPlane => {
478 self.subgroup_invocation_id = true;
479 wgsl::Variable::SubgroupInvocationId
480 }
481 },
482 cube::VariableKind::Matrix { .. } => {
483 panic!("Cooperative matrix-multiply and accumulate not supported.")
484 }
485 }
486 }
487
488 fn compile_scope(&mut self, scope: &mut cube::Scope) -> Vec<wgsl::Instruction> {
489 let mut instructions = Vec::new();
490
491 let const_arrays = scope
492 .const_arrays
493 .drain(..)
494 .map(|(var, values)| ConstantArray {
495 index: var.index().unwrap(),
496 item: Self::compile_item(var.item),
497 size: values.len() as u32,
498 values: values
499 .into_iter()
500 .map(|val| self.compile_variable(val))
501 .collect(),
502 })
503 .collect::<Vec<_>>();
504 self.const_arrays.extend(const_arrays);
505
506 let processing = scope.process();
507
508 for var in processing.variables {
509 if let cube::VariableKind::Slice { .. } = var.kind {
511 continue;
512 }
513
514 instructions.push(wgsl::Instruction::DeclareVariable {
515 var: self.compile_variable(var),
516 });
517 }
518
519 processing
520 .operations
521 .into_iter()
522 .for_each(|op| self.compile_operation(&mut instructions, op.operation, op.out, scope));
523
524 instructions
525 }
526
527 fn compile_operation(
528 &mut self,
529 instructions: &mut Vec<wgsl::Instruction>,
530 operation: cube::Operation,
531 out: Option<cube::Variable>,
532 scope: &mut cube::Scope,
533 ) {
534 match operation {
535 cube::Operation::Copy(variable) => instructions.push(wgsl::Instruction::Assign {
536 input: self.compile_variable(variable),
537 out: self.compile_variable(out.unwrap()),
538 }),
539 cube::Operation::Operator(op) => self.compile_instruction(op, out, instructions, scope),
540 cube::Operation::Atomic(op) => instructions.push(self.compile_atomic(op, out)),
541 cube::Operation::Metadata(op) => instructions.push(self.compile_metadata(op, out)),
542 cube::Operation::Branch(val) => self.compile_branch(instructions, val),
543 cube::Operation::Synchronization(val) => {
544 self.compile_synchronization(instructions, val)
545 }
546 cube::Operation::Plane(op) => self.compile_subgroup(instructions, op, out),
547 cube::Operation::CoopMma(_) => {
548 panic!("Cooperative matrix-multiply and accumulate isn't supported on wgpu.")
549 }
550 cube::Operation::NonSemantic(cube::NonSemantic::Comment { content }) => {
551 self.compile_comment(instructions, content)
552 }
553 cube::Operation::NonSemantic(_) => {}
555 }
556 }
557
558 fn compile_subgroup(
559 &mut self,
560 instructions: &mut Vec<wgsl::Instruction>,
561 subgroup: cube::Plane,
562 out: Option<cube::Variable>,
563 ) {
564 self.subgroup_instructions_used = true;
565
566 let out = out.unwrap();
567 let op = match subgroup {
568 cube::Plane::Elect => Subgroup::Elect {
569 out: self.compile_variable(out),
570 },
571 cube::Plane::All(op) => Subgroup::All {
572 input: self.compile_variable(op.input),
573 out: self.compile_variable(out),
574 },
575 cube::Plane::Any(op) => Subgroup::Any {
576 input: self.compile_variable(op.input),
577 out: self.compile_variable(out),
578 },
579 cube::Plane::Broadcast(op) => Subgroup::Broadcast {
580 lhs: self.compile_variable(op.lhs),
581 rhs: self.compile_variable(op.rhs),
582 out: self.compile_variable(out),
583 },
584 cube::Plane::Sum(op) => Subgroup::Sum {
585 input: self.compile_variable(op.input),
586 out: self.compile_variable(out),
587 },
588 cube::Plane::Prod(op) => Subgroup::Prod {
589 input: self.compile_variable(op.input),
590 out: self.compile_variable(out),
591 },
592 cube::Plane::Min(op) => Subgroup::Min {
593 input: self.compile_variable(op.input),
594 out: self.compile_variable(out),
595 },
596 cube::Plane::Max(op) => Subgroup::Max {
597 input: self.compile_variable(op.input),
598 out: self.compile_variable(out),
599 },
600 };
601
602 instructions.push(wgsl::Instruction::Subgroup(op));
603 }
604
605 fn compile_branch(&mut self, instructions: &mut Vec<wgsl::Instruction>, branch: cube::Branch) {
606 match branch {
607 cube::Branch::If(mut op) => instructions.push(wgsl::Instruction::If {
608 cond: self.compile_variable(op.cond),
609 instructions: self.compile_scope(&mut op.scope),
610 }),
611 cube::Branch::IfElse(mut op) => instructions.push(wgsl::Instruction::IfElse {
612 cond: self.compile_variable(op.cond),
613 instructions_if: self.compile_scope(&mut op.scope_if),
614 instructions_else: self.compile_scope(&mut op.scope_else),
615 }),
616 cube::Branch::Switch(mut op) => instructions.push(wgsl::Instruction::Switch {
617 value: self.compile_variable(op.value),
618 instructions_default: self.compile_scope(&mut op.scope_default),
619 cases: op
620 .cases
621 .into_iter()
622 .map(|(val, mut scope)| {
623 (self.compile_variable(val), self.compile_scope(&mut scope))
624 })
625 .collect(),
626 }),
627 cube::Branch::Return => instructions.push(wgsl::Instruction::Return),
628 cube::Branch::Break => instructions.push(wgsl::Instruction::Break),
629 cube::Branch::RangeLoop(mut range_loop) => {
630 instructions.push(wgsl::Instruction::RangeLoop {
631 i: self.compile_variable(range_loop.i),
632 start: self.compile_variable(range_loop.start),
633 end: self.compile_variable(range_loop.end),
634 step: range_loop.step.map(|it| self.compile_variable(it)),
635 inclusive: range_loop.inclusive,
636 instructions: self.compile_scope(&mut range_loop.scope),
637 })
638 }
639 cube::Branch::Loop(mut op) => instructions.push(wgsl::Instruction::Loop {
640 instructions: self.compile_scope(&mut op.scope),
641 }),
642 };
643 }
644
645 fn compile_synchronization(
646 &mut self,
647 instructions: &mut Vec<wgsl::Instruction>,
648 synchronization: cube::Synchronization,
649 ) {
650 match synchronization {
651 cube::Synchronization::SyncUnits => {
652 instructions.push(wgsl::Instruction::WorkgroupBarrier)
653 }
654 cube::Synchronization::SyncStorage => {
655 instructions.push(wgsl::Instruction::StorageBarrier)
656 }
657 };
658 }
659
660 fn compile_comment(&mut self, instructions: &mut Vec<wgsl::Instruction>, content: String) {
661 instructions.push(wgsl::Instruction::Comment { content })
662 }
663
664 fn compile_metadata(
665 &mut self,
666 metadata: cube::Metadata,
667 out: Option<cube::Variable>,
668 ) -> wgsl::Instruction {
669 let out = out.unwrap();
670 match metadata {
671 cube::Metadata::Rank { var } => {
672 let position = self.ext_meta_pos(&var);
673 let offset = self.metadata.rank_index(position);
674 wgsl::Instruction::Metadata {
675 out: self.compile_variable(out),
676 info_offset: self.compile_variable(offset.into()),
677 }
678 }
679 cube::Metadata::Stride { dim, var } => {
680 let position = self.ext_meta_pos(&var);
681 let offset = self.metadata.stride_offset_index(position);
682 wgsl::Instruction::ExtendedMeta {
683 info_offset: self.compile_variable(offset.into()),
684 dim: self.compile_variable(dim),
685 out: self.compile_variable(out),
686 }
687 }
688 cube::Metadata::Shape { dim, var } => {
689 let position = self.ext_meta_pos(&var);
690 let offset = self.metadata.shape_offset_index(position);
691 wgsl::Instruction::ExtendedMeta {
692 info_offset: self.compile_variable(offset.into()),
693 dim: self.compile_variable(dim),
694 out: self.compile_variable(out),
695 }
696 }
697 cube::Metadata::Length { var } => match var.kind {
698 cube::VariableKind::GlobalInputArray(id) => {
699 let offset = self.metadata.len_index(id);
700 wgsl::Instruction::Metadata {
701 out: self.compile_variable(out),
702 info_offset: self.compile_variable(offset.into()),
703 }
704 }
705 cube::VariableKind::GlobalOutputArray(id) => {
706 let offset = self.metadata.len_index(self.num_inputs as u32 + id);
707 wgsl::Instruction::Metadata {
708 out: self.compile_variable(out),
709 info_offset: self.compile_variable(offset.into()),
710 }
711 }
712 _ => wgsl::Instruction::Length {
713 var: self.compile_variable(var),
714 out: self.compile_variable(out),
715 },
716 },
717 cube::Metadata::BufferLength { var } => match var.kind {
718 cube::VariableKind::GlobalInputArray(id) => {
719 let offset = self.metadata.buffer_len_index(id);
720 wgsl::Instruction::Metadata {
721 out: self.compile_variable(out),
722 info_offset: self.compile_variable(offset.into()),
723 }
724 }
725 cube::VariableKind::GlobalOutputArray(id) => {
726 let id = self.num_inputs as u32 + id;
727 let offset = self.metadata.buffer_len_index(id);
728 wgsl::Instruction::Metadata {
729 out: self.compile_variable(out),
730 info_offset: self.compile_variable(offset.into()),
731 }
732 }
733 _ => wgsl::Instruction::Length {
734 var: self.compile_variable(var),
735 out: self.compile_variable(out),
736 },
737 },
738 }
739 }
740
741 fn compile_instruction(
742 &mut self,
743 value: cube::Operator,
744 out: Option<cube::Variable>,
745 instructions: &mut Vec<wgsl::Instruction>,
746 scope: &mut cube::Scope,
747 ) {
748 let out = out.unwrap();
749 match value {
750 cube::Operator::Max(op) => instructions.push(wgsl::Instruction::Max {
751 lhs: self.compile_variable(op.lhs),
752 rhs: self.compile_variable(op.rhs),
753 out: self.compile_variable(out),
754 }),
755 cube::Operator::Min(op) => instructions.push(wgsl::Instruction::Min {
756 lhs: self.compile_variable(op.lhs),
757 rhs: self.compile_variable(op.rhs),
758 out: self.compile_variable(out),
759 }),
760 cube::Operator::Add(op) => instructions.push(wgsl::Instruction::Add {
761 lhs: self.compile_variable(op.lhs),
762 rhs: self.compile_variable(op.rhs),
763 out: self.compile_variable(out),
764 }),
765 cube::Operator::Fma(op) => instructions.push(wgsl::Instruction::Fma {
766 a: self.compile_variable(op.a),
767 b: self.compile_variable(op.b),
768 c: self.compile_variable(op.c),
769 out: self.compile_variable(out),
770 }),
771 cube::Operator::Modulo(op) => instructions.push(wgsl::Instruction::Modulo {
772 lhs: self.compile_variable(op.lhs),
773 rhs: self.compile_variable(op.rhs),
774 out: self.compile_variable(out),
775 }),
776 cube::Operator::Sub(op) => instructions.push(wgsl::Instruction::Sub {
777 lhs: self.compile_variable(op.lhs),
778 rhs: self.compile_variable(op.rhs),
779 out: self.compile_variable(out),
780 }),
781 cube::Operator::Mul(op) => instructions.push(wgsl::Instruction::Mul {
782 lhs: self.compile_variable(op.lhs),
783 rhs: self.compile_variable(op.rhs),
784 out: self.compile_variable(out),
785 }),
786 cube::Operator::Div(op) => instructions.push(wgsl::Instruction::Div {
787 lhs: self.compile_variable(op.lhs),
788 rhs: self.compile_variable(op.rhs),
789 out: self.compile_variable(out),
790 }),
791 cube::Operator::Abs(op) => instructions.push(wgsl::Instruction::Abs {
792 input: self.compile_variable(op.input),
793 out: self.compile_variable(out),
794 }),
795 cube::Operator::Exp(op) => instructions.push(wgsl::Instruction::Exp {
796 input: self.compile_variable(op.input),
797 out: self.compile_variable(out),
798 }),
799 cube::Operator::Log(op) => instructions.push(wgsl::Instruction::Log {
800 input: self.compile_variable(op.input),
801 out: self.compile_variable(out),
802 }),
803 cube::Operator::Log1p(op) => instructions.push(wgsl::Instruction::Log1p {
804 input: self.compile_variable(op.input),
805 out: self.compile_variable(out),
806 }),
807 cube::Operator::Cos(op) => instructions.push(wgsl::Instruction::Cos {
808 input: self.compile_variable(op.input),
809 out: self.compile_variable(out),
810 }),
811 cube::Operator::Sin(op) => instructions.push(wgsl::Instruction::Sin {
812 input: self.compile_variable(op.input),
813 out: self.compile_variable(out),
814 }),
815 cube::Operator::Tanh(op) => instructions.push(wgsl::Instruction::Tanh {
816 input: self.compile_variable(op.input),
817 out: self.compile_variable(out),
818 }),
819 cube::Operator::Powf(op) => instructions.push(wgsl::Instruction::Powf {
820 lhs: self.compile_variable(op.lhs),
821 rhs: self.compile_variable(op.rhs),
822 out: self.compile_variable(out),
823 }),
824 cube::Operator::Sqrt(op) => instructions.push(wgsl::Instruction::Sqrt {
825 input: self.compile_variable(op.input),
826 out: self.compile_variable(out),
827 }),
828 cube::Operator::Round(op) => instructions.push(wgsl::Instruction::Round {
829 input: self.compile_variable(op.input),
830 out: self.compile_variable(out),
831 }),
832 cube::Operator::Floor(op) => instructions.push(wgsl::Instruction::Floor {
833 input: self.compile_variable(op.input),
834 out: self.compile_variable(out),
835 }),
836 cube::Operator::Ceil(op) => instructions.push(wgsl::Instruction::Ceil {
837 input: self.compile_variable(op.input),
838 out: self.compile_variable(out),
839 }),
840 cube::Operator::Erf(op) => instructions.push(wgsl::Instruction::Erf {
841 input: self.compile_variable(op.input),
842 out: self.compile_variable(out),
843 }),
844 cube::Operator::Recip(op) => instructions.push(wgsl::Instruction::Recip {
845 input: self.compile_variable(op.input),
846 out: self.compile_variable(out),
847 }),
848 cube::Operator::Equal(op) => instructions.push(wgsl::Instruction::Equal {
849 lhs: self.compile_variable(op.lhs),
850 rhs: self.compile_variable(op.rhs),
851 out: self.compile_variable(out),
852 }),
853 cube::Operator::Lower(op) => instructions.push(wgsl::Instruction::Lower {
854 lhs: self.compile_variable(op.lhs),
855 rhs: self.compile_variable(op.rhs),
856 out: self.compile_variable(out),
857 }),
858 cube::Operator::Clamp(op) => instructions.push(wgsl::Instruction::Clamp {
859 input: self.compile_variable(op.input),
860 min_value: self.compile_variable(op.min_value),
861 max_value: self.compile_variable(op.max_value),
862 out: self.compile_variable(out),
863 }),
864 cube::Operator::Greater(op) => instructions.push(wgsl::Instruction::Greater {
865 lhs: self.compile_variable(op.lhs),
866 rhs: self.compile_variable(op.rhs),
867 out: self.compile_variable(out),
868 }),
869 cube::Operator::LowerEqual(op) => instructions.push(wgsl::Instruction::LowerEqual {
870 lhs: self.compile_variable(op.lhs),
871 rhs: self.compile_variable(op.rhs),
872 out: self.compile_variable(out),
873 }),
874 cube::Operator::GreaterEqual(op) => {
875 instructions.push(wgsl::Instruction::GreaterEqual {
876 lhs: self.compile_variable(op.lhs),
877 rhs: self.compile_variable(op.rhs),
878 out: self.compile_variable(out),
879 })
880 }
881 cube::Operator::NotEqual(op) => instructions.push(wgsl::Instruction::NotEqual {
882 lhs: self.compile_variable(op.lhs),
883 rhs: self.compile_variable(op.rhs),
884 out: self.compile_variable(out),
885 }),
886 cube::Operator::Cast(op) => instructions.push(wgsl::Instruction::Assign {
887 input: self.compile_variable(op.input),
888 out: self.compile_variable(out),
889 }),
890 cube::Operator::Index(op) => {
891 if matches!(self.strategy, ExecutionMode::Checked) && op.lhs.has_length() {
892 let lhs = op.lhs;
893 let rhs = op.rhs;
894 let array_len =
895 scope.create_local(cube::Item::new(cube::Elem::UInt(cube::UIntKind::U32)));
896
897 instructions.extend(self.compile_scope(scope));
898
899 let length = match lhs.has_buffer_length() {
900 true => cube::Metadata::BufferLength { var: lhs },
901 false => cube::Metadata::Length { var: lhs },
902 };
903 instructions.push(self.compile_metadata(length, Some(array_len)));
904 instructions.push(wgsl::Instruction::CheckedIndex {
905 len: self.compile_variable(array_len),
906 lhs: self.compile_variable(lhs),
907 rhs: self.compile_variable(rhs),
908 out: self.compile_variable(out),
909 });
910 } else {
911 instructions.push(wgsl::Instruction::Index {
912 lhs: self.compile_variable(op.lhs),
913 rhs: self.compile_variable(op.rhs),
914 out: self.compile_variable(out),
915 });
916 }
917 }
918 cube::Operator::UncheckedIndex(op) => instructions.push(wgsl::Instruction::Index {
919 lhs: self.compile_variable(op.lhs),
920 rhs: self.compile_variable(op.rhs),
921 out: self.compile_variable(out),
922 }),
923 cube::Operator::IndexAssign(op) => {
924 if let ExecutionMode::Checked = self.strategy {
925 if out.has_length() {
926 expand_checked_index_assign(scope, op.lhs, op.rhs, out);
927 instructions.extend(self.compile_scope(scope));
928 return;
929 }
930 };
931 instructions.push(wgsl::Instruction::IndexAssign {
932 lhs: self.compile_variable(op.lhs),
933 rhs: self.compile_variable(op.rhs),
934 out: self.compile_variable(out),
935 })
936 }
937 cube::Operator::UncheckedIndexAssign(op) => {
938 instructions.push(wgsl::Instruction::IndexAssign {
939 lhs: self.compile_variable(op.lhs),
940 rhs: self.compile_variable(op.rhs),
941 out: self.compile_variable(out),
942 })
943 }
944 cube::Operator::And(op) => instructions.push(wgsl::Instruction::And {
945 lhs: self.compile_variable(op.lhs),
946 rhs: self.compile_variable(op.rhs),
947 out: self.compile_variable(out),
948 }),
949 cube::Operator::Or(op) => instructions.push(wgsl::Instruction::Or {
950 lhs: self.compile_variable(op.lhs),
951 rhs: self.compile_variable(op.rhs),
952 out: self.compile_variable(out),
953 }),
954 cube::Operator::Not(op) => instructions.push(wgsl::Instruction::Not {
955 input: self.compile_variable(op.input),
956 out: self.compile_variable(out),
957 }),
958 cube::Operator::BitwiseOr(op) => instructions.push(wgsl::Instruction::BitwiseOr {
959 lhs: self.compile_variable(op.lhs),
960 rhs: self.compile_variable(op.rhs),
961 out: self.compile_variable(out),
962 }),
963 cube::Operator::BitwiseAnd(op) => instructions.push(wgsl::Instruction::BitwiseAnd {
964 lhs: self.compile_variable(op.lhs),
965 rhs: self.compile_variable(op.rhs),
966 out: self.compile_variable(out),
967 }),
968 cube::Operator::BitwiseXor(op) => instructions.push(wgsl::Instruction::BitwiseXor {
969 lhs: self.compile_variable(op.lhs),
970 rhs: self.compile_variable(op.rhs),
971 out: self.compile_variable(out),
972 }),
973 cube::Operator::CountOnes(op) => instructions.push(wgsl::Instruction::CountBits {
974 input: self.compile_variable(op.input),
975 out: self.compile_variable(out),
976 }),
977 cube::Operator::ReverseBits(op) => instructions.push(wgsl::Instruction::ReverseBits {
978 input: self.compile_variable(op.input),
979 out: self.compile_variable(out),
980 }),
981 cube::Operator::ShiftLeft(op) => instructions.push(wgsl::Instruction::ShiftLeft {
982 lhs: self.compile_variable(op.lhs),
983 rhs: self.compile_variable(op.rhs),
984 out: self.compile_variable(out),
985 }),
986 cube::Operator::ShiftRight(op) => instructions.push(wgsl::Instruction::ShiftRight {
987 lhs: self.compile_variable(op.lhs),
988 rhs: self.compile_variable(op.rhs),
989 out: self.compile_variable(out),
990 }),
991 cube::Operator::Remainder(op) => instructions.push(wgsl::Instruction::Remainder {
992 lhs: self.compile_variable(op.lhs),
993 rhs: self.compile_variable(op.rhs),
994 out: self.compile_variable(out),
995 }),
996 cube::Operator::Slice(op) => {
997 if matches!(self.strategy, ExecutionMode::Checked) && op.input.has_length() {
998 let input = op.input;
999 let input_len = scope
1000 .create_local_mut(cube::Item::new(cube::Elem::UInt(cube::UIntKind::U32)));
1001 instructions.extend(self.compile_scope(scope));
1002
1003 let length = match input.has_buffer_length() {
1004 true => cube::Metadata::BufferLength { var: input },
1005 false => cube::Metadata::Length { var: input },
1006 };
1007
1008 instructions.push(self.compile_metadata(length, Some(input_len)));
1009 instructions.push(wgsl::Instruction::CheckedSlice {
1010 input: self.compile_variable(input),
1011 start: self.compile_variable(op.start),
1012 end: self.compile_variable(op.end),
1013 out: self.compile_variable(out),
1014 len: self.compile_variable(input_len),
1015 });
1016 } else {
1017 instructions.push(wgsl::Instruction::Slice {
1018 input: self.compile_variable(op.input),
1019 start: self.compile_variable(op.start),
1020 end: self.compile_variable(op.end),
1021 out: self.compile_variable(out),
1022 });
1023 }
1024 }
1025 cube::Operator::Bitcast(op) => instructions.push(wgsl::Instruction::Bitcast {
1026 input: self.compile_variable(op.input),
1027 out: self.compile_variable(out),
1028 }),
1029
1030 cube::Operator::Neg(op) => instructions.push(wgsl::Instruction::Negate {
1031 input: self.compile_variable(op.input),
1032 out: self.compile_variable(out),
1033 }),
1034 cube::Operator::Magnitude(op) => instructions.push(wgsl::Instruction::Magnitude {
1035 input: self.compile_variable(op.input),
1036 out: self.compile_variable(out),
1037 }),
1038 cube::Operator::Normalize(op) => instructions.push(wgsl::Instruction::Normalize {
1039 input: self.compile_variable(op.input),
1040 out: self.compile_variable(out),
1041 }),
1042 cube::Operator::Dot(op) => instructions.push(wgsl::Instruction::Dot {
1043 lhs: self.compile_variable(op.lhs),
1044 rhs: self.compile_variable(op.rhs),
1045 out: self.compile_variable(out),
1046 }),
1047 cube::Operator::InitLine(op) => instructions.push(wgsl::Instruction::VecInit {
1048 inputs: op
1049 .inputs
1050 .into_iter()
1051 .map(|var| self.compile_variable(var))
1052 .collect(),
1053 out: self.compile_variable(out),
1054 }),
1055 cube::Operator::CopyMemory(op) => instructions.push(wgsl::Instruction::Copy {
1056 input: self.compile_variable(op.input),
1057 in_index: self.compile_variable(op.in_index),
1058 out: self.compile_variable(out),
1059 out_index: self.compile_variable(op.out_index),
1060 }),
1061 cube::Operator::CopyMemoryBulk(op) => instructions.push(wgsl::Instruction::CopyBulk {
1062 input: self.compile_variable(op.input),
1063 in_index: self.compile_variable(op.in_index),
1064 out: self.compile_variable(out),
1065 out_index: self.compile_variable(op.out_index),
1066 len: op.len,
1067 }),
1068 cube::Operator::Select(op) => instructions.push(wgsl::Instruction::Select {
1069 cond: self.compile_variable(op.cond),
1070 then: self.compile_variable(op.then),
1071 or_else: self.compile_variable(op.or_else),
1072 out: self.compile_variable(out),
1073 }),
1074 }
1075 }
1076
1077 fn compile_atomic(
1078 &mut self,
1079 atomic: cube::AtomicOp,
1080 out: Option<cube::Variable>,
1081 ) -> wgsl::Instruction {
1082 let out = out.unwrap();
1083 match atomic {
1084 cube::AtomicOp::Add(op) => wgsl::Instruction::AtomicAdd {
1085 lhs: self.compile_variable(op.lhs),
1086 rhs: self.compile_variable(op.rhs),
1087 out: self.compile_variable(out),
1088 },
1089 cube::AtomicOp::Sub(op) => wgsl::Instruction::AtomicSub {
1090 lhs: self.compile_variable(op.lhs),
1091 rhs: self.compile_variable(op.rhs),
1092 out: self.compile_variable(out),
1093 },
1094 cube::AtomicOp::Max(op) => wgsl::Instruction::AtomicMax {
1095 lhs: self.compile_variable(op.lhs),
1096 rhs: self.compile_variable(op.rhs),
1097 out: self.compile_variable(out),
1098 },
1099 cube::AtomicOp::Min(op) => wgsl::Instruction::AtomicMin {
1100 lhs: self.compile_variable(op.lhs),
1101 rhs: self.compile_variable(op.rhs),
1102 out: self.compile_variable(out),
1103 },
1104 cube::AtomicOp::And(op) => wgsl::Instruction::AtomicAnd {
1105 lhs: self.compile_variable(op.lhs),
1106 rhs: self.compile_variable(op.rhs),
1107 out: self.compile_variable(out),
1108 },
1109 cube::AtomicOp::Or(op) => wgsl::Instruction::AtomicOr {
1110 lhs: self.compile_variable(op.lhs),
1111 rhs: self.compile_variable(op.rhs),
1112 out: self.compile_variable(out),
1113 },
1114 cube::AtomicOp::Xor(op) => wgsl::Instruction::AtomicXor {
1115 lhs: self.compile_variable(op.lhs),
1116 rhs: self.compile_variable(op.rhs),
1117 out: self.compile_variable(out),
1118 },
1119 cube::AtomicOp::Load(op) => wgsl::Instruction::AtomicLoad {
1120 input: self.compile_variable(op.input),
1121 out: self.compile_variable(out),
1122 },
1123 cube::AtomicOp::Store(op) => wgsl::Instruction::AtomicStore {
1124 input: self.compile_variable(op.input),
1125 out: self.compile_variable(out),
1126 },
1127 cube::AtomicOp::Swap(op) => wgsl::Instruction::AtomicSwap {
1128 lhs: self.compile_variable(op.lhs),
1129 rhs: self.compile_variable(op.rhs),
1130 out: self.compile_variable(out),
1131 },
1132 cube::AtomicOp::CompareAndSwap(op) => wgsl::Instruction::AtomicCompareExchangeWeak {
1133 lhs: self.compile_variable(op.input),
1134 cmp: self.compile_variable(op.cmp),
1135 value: self.compile_variable(op.val),
1136 out: self.compile_variable(out),
1137 },
1138 }
1139 }
1140
1141 fn compile_location(value: cube::Location) -> wgsl::Location {
1142 match value {
1143 cube::Location::Storage => wgsl::Location::Storage,
1144 cube::Location::Cube => wgsl::Location::Workgroup,
1145 }
1146 }
1147
1148 fn compile_visibility(value: cube::Visibility) -> wgsl::Visibility {
1149 match value {
1150 cube::Visibility::Read => wgsl::Visibility::Read,
1151 cube::Visibility::ReadWrite => wgsl::Visibility::ReadWrite,
1152 }
1153 }
1154
1155 fn compile_binding(value: cube::Binding) -> wgsl::Binding {
1156 wgsl::Binding {
1157 visibility: Self::compile_visibility(value.visibility),
1158 location: Self::compile_location(value.location),
1159 item: Self::compile_item(value.item),
1160 size: value.size,
1161 }
1162 }
1163}
1164
1165fn register_extensions(instructions: &[wgsl::Instruction]) -> Vec<wgsl::Extension> {
1166 let mut extensions = Vec::new();
1167
1168 let mut register_extension = |extension: wgsl::Extension| {
1169 if !extensions.contains(&extension) {
1170 extensions.push(extension);
1171 }
1172 };
1173
1174 for instruction in instructions {
1176 match instruction {
1177 wgsl::Instruction::Powf { lhs: _, rhs, out } => {
1178 register_extension(wgsl::Extension::PowfPrimitive(out.item()));
1179
1180 if rhs.is_always_scalar() || rhs.item().vectorization_factor() == 1 {
1181 register_extension(wgsl::Extension::PowfScalar(out.item()));
1182 } else {
1183 register_extension(wgsl::Extension::Powf(out.item()));
1184 }
1185 }
1186 wgsl::Instruction::Erf { input, out: _ } => {
1187 register_extension(wgsl::Extension::Erf(input.item()));
1188 }
1189 #[cfg(target_os = "macos")]
1190 wgsl::Instruction::Tanh { input, out: _ } => {
1191 register_extension(wgsl::Extension::SafeTanh(input.item()))
1192 }
1193 wgsl::Instruction::If { instructions, .. } => {
1194 for extension in register_extensions(instructions) {
1195 register_extension(extension);
1196 }
1197 }
1198 wgsl::Instruction::IfElse {
1199 instructions_if,
1200 instructions_else,
1201 ..
1202 } => {
1203 for extension in register_extensions(instructions_if) {
1204 register_extension(extension);
1205 }
1206 for extension in register_extensions(instructions_else) {
1207 register_extension(extension);
1208 }
1209 }
1210 wgsl::Instruction::Loop { instructions } => {
1211 for extension in register_extensions(instructions) {
1212 register_extension(extension);
1213 }
1214 }
1215 wgsl::Instruction::RangeLoop { instructions, .. } => {
1216 for extension in register_extensions(instructions) {
1217 register_extension(extension);
1218 }
1219 }
1220 _ => {}
1221 }
1222 }
1223
1224 extensions
1225}