use std::{collections::HashSet, fmt::Debug};

use cubecl_common::ExecutionMode;
use cubecl_core::ir::{FloatKind, InstructionModes, Processor, UIntKind, VariableKind};
use cubecl_core::post_processing::checked_io::CheckedIoProcessor;
use cubecl_core::{
    Compiler,
    ir::{self as gpu},
};
use cubecl_core::{CubeDim, ir::ElemType};
use cubecl_core::{
    ir::{Operation, SourceLoc},
    prelude::{FastMath, KernelDefinition},
};
use cubecl_opt::{Optimizer, SharedLiveness};
use cubecl_runtime::{DeviceProperties, EnumSet, TypeUsage};

use crate::shared::MmaShape;

use super::{
    BinaryInstruction, Binding, Body, Component, ComputeKernel, ConstArray, Dialect, Elem, FP6Kind,
    Fragment, FragmentIdent, FragmentLayout, IndexAssignInstruction, IndexInstruction, Instruction,
    Item, LocalArray, SharedMemory, UnaryInstruction, Variable, WarpInstruction, WmmaInstruction,
};
use super::{FP4Kind, barrier::BarrierOps};
use super::{FP8Kind, pipeline::PipelineOps};

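/// Counter used to name compiler-generated temporary variables; reset at the
/// end of each call to [`Compiler::compile`].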
pub(super) static COUNTER_TMP_VAR: std::sync::atomic::AtomicU32 =
    std::sync::atomic::AtomicU32::new(0);

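/// Device-level options that shape code generation for every kernel compiled
/// against the same target, independently of any single kernel definition.
///
/// A minimal sketch of constructing the options by hand (the values are
/// illustrative, not queried from a real device):
///
/// ```ignore
/// let options = CompilationOptions {
///     warp_size: 32,
///     supports_features: CppSupportedFeatures {
///         fast_math: true,
///         ..Default::default()
///     },
/// };
/// ```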
#[derive(Clone, Debug)]
pub struct CompilationOptions {
    pub warp_size: u32,
    pub supports_features: CppSupportedFeatures,
}

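/// Optional hardware/toolchain features of the target; each flag gates whether
/// the corresponding instructions or builtins may be emitted.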
#[derive(Clone, Debug, Default)]
pub struct CppSupportedFeatures {
    pub grid_constants: bool,
    pub clusters: bool,
    pub fast_math: bool,
    pub fast_tanh: bool,
    pub elect_sync: bool,
}

impl Default for CompilationOptions {
    fn default() -> Self {
        Self {
            warp_size: 32,
            supports_features: Default::default(),
        }
    }
}

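/// Records which builtin index variables a kernel actually reads, so only the
/// required index computations are declared in the generated source.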
#[derive(Debug, Clone, Default)]
pub struct CubeIndexFlags {
    pub absolute_pos: bool,
    pub absolute_pos_tuple: bool,
    pub cube_count: bool,
    pub cube_count_tuple: bool,
    pub cube_dim: bool,
    pub cube_dim_tuple: bool,
    pub cube_pos: bool,
    pub cube_pos_tuple: bool,
    pub plane_dim: bool,
    pub plane_dim_checked: bool,
    pub plane_index: bool,
    pub unit_pos: bool,
    pub unit_pos_tuple: bool,
    pub unit_pos_plane: bool,
    pub cluster_pos: bool,
}

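/// Summary of everything codegen needs to know about a compiled kernel: the
/// element types, instruction families, and builtins it uses, along with its
/// metadata layout and launch dimensions.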
#[derive(Debug, Clone, Default)]
pub struct Flags {
    pub elem_fp4: bool,
    pub elem_fp6: bool,
    pub elem_fp8: bool,
    pub elem_bf16: bool,
    pub elem_f16: bool,
    pub elem_tf32: bool,
    pub indexes: CubeIndexFlags,
    pub op_barrier: bool,
    pub op_pipeline: bool,
    pub inst_tma: bool,
    pub inst_tma_im2col: bool,
    pub inst_wmma: bool,
    pub inst_ptx_wrappers: bool,
    pub use_grid_constants: bool,
    pub static_meta_length: usize,
    pub has_dynamic_meta: bool,
    pub cube_dim: CubeDim,
    pub cluster_dim: Option<CubeDim>,
}

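/// Lowers the CubeCL IR into the C++ representation of a concrete [`Dialect`]
/// (e.g. CUDA or HIP), accumulating the flags and declarations the generated
/// kernel will need along the way.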
#[allow(clippy::too_many_arguments)]
#[derive(Clone, Debug, Default)]
pub struct CppCompiler<D: Dialect> {
    barriers: Vec<BarrierOps<D>>,
    compilation_options: CompilationOptions,
    const_arrays: Vec<ConstArray<D>>,
    ext_meta_positions: Vec<u32>,
    cluster_dim: CubeDim,
    extensions: Vec<D::Extension>,
    flags: Flags,
    items: HashSet<Item<D>>,
    local_arrays: Vec<LocalArray<D>>,
    metadata: cubecl_core::Metadata,
    pipelines: Vec<PipelineOps<D>>,
    source_loc: Option<SourceLoc>,
    strategy: ExecutionMode,
}

impl<D: Dialect> Compiler for CppCompiler<D> {
    type Representation = ComputeKernel<D>;
    type CompilationOptions = CompilationOptions;

    fn compile(
        &mut self,
        mut kernel: KernelDefinition,
        compilation_options: &Self::CompilationOptions,
        strategy: ExecutionMode,
    ) -> Self::Representation {
        self.compilation_options = compilation_options.clone();
        self.strategy = strategy;

        if !self.compilation_options.supports_features.clusters {
            kernel.options.cluster_dim = None;
        }
        self.cluster_dim = kernel.options.cluster_dim.unwrap_or(CubeDim::new_single());

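        // Compile on a clone so per-kernel state (flags, items, arrays) does
        // not leak into later compilations, then reset the temp-var counter.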
        let ir = self.clone().compile_ir(kernel);
        COUNTER_TMP_VAR.store(0, std::sync::atomic::Ordering::Relaxed);
        ir
    }

    fn elem_size(&self, elem: gpu::ElemType) -> usize {
        elem.size()
    }

    fn extension(&self) -> &'static str {
        "cpp"
    }
}

impl<D: Dialect> CppCompiler<D> {
    fn compile_ir(mut self, value: KernelDefinition) -> ComputeKernel<D> {
        self.build_metadata(&value);

        let instructions = self.compile_scope(&mut value.body.clone());
        let tensor_maps = value
            .tensor_maps
            .into_iter()
            .map(|b| self.compile_binding(b))
            .collect();
        let buffers = value
            .buffers
            .into_iter()
            .map(|b| self.compile_binding(b))
            .collect();
        let scalars = value
            .scalars
            .into_iter()
            .map(|binding| (self.compile_storage_type(binding.ty), binding.count))
            .collect();

        let flags = Flags {
            indexes: D::builtin_rules(&self.flags.indexes),
            inst_wmma: self.flags.inst_wmma,
            op_pipeline: self.flags.op_pipeline,
            op_barrier: self.flags.op_barrier,
            elem_fp4: self.flags.elem_fp4,
            elem_fp6: self.flags.elem_fp6,
            elem_fp8: self.flags.elem_fp8,
            elem_bf16: self.flags.elem_bf16,
            elem_f16: self.flags.elem_f16,
            elem_tf32: self.flags.elem_tf32,
            inst_tma: self.flags.inst_tma,
            inst_tma_im2col: self.flags.inst_tma_im2col,
            inst_ptx_wrappers: self.flags.inst_ptx_wrappers,
            use_grid_constants: self.compilation_options.supports_features.grid_constants,
            has_dynamic_meta: self.metadata.static_len() > 0,
            static_meta_length: self.metadata.static_len() as usize,
            cube_dim: value.cube_dim,
            cluster_dim: value.options.cluster_dim,
        };

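        // The shared-liveness analysis assigns an offset to every shared
        // memory allocation so allocations can share one backing buffer.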
        let mut opt = Optimizer::shared_only(value.body, value.cube_dim);
        let shared_allocs = opt.analysis::<SharedLiveness>();
        let shared_memories = shared_allocs
            .allocations
            .values()
            .map(|alloc| SharedMemory {
                index: alloc.smem.id,
                item: self.compile_type(alloc.smem.ty),
                length: alloc.smem.length,
                align: alloc.smem.align,
                offset: alloc.offset,
            })
            .collect();

        let body = Body {
            instructions,
            shared_memories,
            pipelines: self.pipelines,
            barriers: self.barriers,
            const_arrays: self.const_arrays,
            local_arrays: self.local_arrays,
        };

        let mut cluster_dim = value.options.cluster_dim;
        if !self.compilation_options.supports_features.clusters {
            cluster_dim = None;
        }

        ComputeKernel {
            tensor_maps,
            buffers,
            scalars,
            meta_static_len: self.metadata.static_len() as usize,
            cube_dim: value.cube_dim,
            body,
            extensions: self.extensions,
            flags,
            items: self.items,
            kernel_name: value.options.kernel_name,
            cluster_dim,
        }
    }

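    /// Builds the metadata layout: buffers and tensor maps are ordered by id,
    /// and each binding with extended metadata (shapes/strides) is assigned a
    /// position in the extended metadata table.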
    fn build_metadata(&mut self, value: &KernelDefinition) {
        let mut num_ext = 0;

        let mut all_meta: Vec<_> = value
            .buffers
            .iter()
            .chain(value.tensor_maps.iter())
            .map(|buf| (buf.id, buf.has_extended_meta))
            .collect();

        all_meta.sort_by_key(|(id, _)| *id);

        for (_, has_extended_meta) in &all_meta {
            self.ext_meta_positions.push(num_ext);
            if *has_extended_meta {
                num_ext += 1;
            }
        }

        let num_meta = all_meta.len();

        self.metadata = cubecl_core::Metadata::new(num_meta as u32, num_ext);
    }

    pub(crate) fn ext_meta_position(&self, var: gpu::Variable) -> u32 {
        let id = var.index().expect("Variable should have index");
        self.ext_meta_positions[id as usize]
    }

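    /// Compiles a scope into instructions: hoists its const arrays, runs the
    /// checked-IO and dialect processors, declares the resulting variables,
    /// then lowers each processed instruction.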
    fn compile_scope(&mut self, scope: &mut gpu::Scope) -> Vec<Instruction<D>> {
        let mut instructions = Vec::new();

        let const_arrays = scope
            .const_arrays
            .drain(..)
            .map(|(var, values)| ConstArray {
                index: var.index().unwrap(),
                item: self.compile_type(var.ty),
                size: values.len() as u32,
                values: values
                    .into_iter()
                    .map(|val| self.compile_variable(val))
                    .collect(),
            })
            .collect::<Vec<_>>();
        self.const_arrays.extend(const_arrays);

        let checked_io: Box<dyn Processor> = Box::new(CheckedIoProcessor::new(self.strategy));
        let dialect_processors = D::processors();
        let mut processors: Vec<&dyn Processor> = vec![&*checked_io];
        processors.extend(dialect_processors.iter().map(|it| &**it));

        let processing = scope.process(processors);

        for var in processing.variables {
            instructions.push(Instruction::DeclareVariable {
                var: self.compile_variable(var),
            });
        }

        processing
            .instructions
            .into_iter()
            .for_each(|op| self.compile_instruction(&mut instructions, op));

        instructions
    }

    fn compile_instruction(
        &mut self,
        instructions: &mut Vec<Instruction<D>>,
        instruction: gpu::Instruction,
    ) {
        self.update_debug_loc(instructions, &instruction);
        let out = instruction.out;
        match instruction.operation {
            gpu::Operation::Copy(variable) => {
                instructions.push(Instruction::Assign(UnaryInstruction {
                    input: self.compile_variable(variable),
                    out: self.compile_variable(out.unwrap()),
                }));
            }
            gpu::Operation::Arithmetic(op) => {
                self.compile_arithmetic(op, out, instruction.modes, instructions)
            }
            gpu::Operation::Comparison(op) => self.compile_comparison(op, out, instructions),
            gpu::Operation::Bitwise(op) => self.compile_bitwise(op, out, instructions),
            gpu::Operation::Operator(op) => self.compile_operator(op, out, instructions),
            gpu::Operation::Atomic(op) => self.compile_atomic(op, out, instructions),
            gpu::Operation::Metadata(op) => instructions.push(self.compile_metadata(op, out)),
            gpu::Operation::Branch(val) => self.compile_branch(instructions, val),
            gpu::Operation::Synchronization(val) => match val {
                gpu::Synchronization::SyncCube => instructions.push(Instruction::SyncThreads),
                gpu::Synchronization::SyncPlane => instructions.push(Instruction::SyncWarp),
                gpu::Synchronization::SyncStorage => instructions.push(Instruction::SyncThreads),
                gpu::Synchronization::SyncAsyncProxyShared => {
                    self.flags.inst_tma = true;
                    instructions.push(Instruction::ProxyAsyncToSharedFence)
                }
            },
            gpu::Operation::Plane(op) => {
                self.flags.indexes.plane_dim_checked = true;
                let out = self.compile_variable(out.unwrap());
                match op {
                    gpu::Plane::Sum(op) => {
                        let instruction = WarpInstruction::ReduceSum {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction));
                    }
                    gpu::Plane::InclusiveSum(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::InclusiveSum {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::InclusiveProd(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::InclusiveProd {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::ExclusiveSum(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::ExclusiveSum {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::ExclusiveProd(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::ExclusiveProd {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Prod(op) => {
                        let instruction = WarpInstruction::ReduceProd {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Max(op) => {
                        let instruction = WarpInstruction::ReduceMax {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Min(op) => {
                        let instruction = WarpInstruction::ReduceMin {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Elect => {
                        if self.compilation_options.supports_features.elect_sync {
                            self.flags.inst_ptx_wrappers = true;
                            instructions.push(Instruction::Warp(WarpInstruction::Elect { out }))
                        } else {
                            instructions
                                .push(Instruction::Warp(WarpInstruction::ElectFallback { out }))
                        }
                    }
                    gpu::Plane::All(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::All {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Any(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Any {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Ballot(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Ballot {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Broadcast(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Broadcast {
                            input: self.compile_variable(op.lhs),
                            id: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::Shuffle(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Shuffle {
                            input: self.compile_variable(op.lhs),
                            src_lane: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::ShuffleXor(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleXor {
                            input: self.compile_variable(op.lhs),
                            mask: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::ShuffleUp(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleUp {
                            input: self.compile_variable(op.lhs),
                            delta: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::ShuffleDown(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleDown {
                            input: self.compile_variable(op.lhs),
                            delta: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                }
            }
            gpu::Operation::CoopMma(cmma) => instructions.push(self.compile_cmma(cmma, out)),
            gpu::Operation::NonSemantic(debug) => match debug {
                gpu::NonSemantic::Print {
                    format_string,
                    args,
                } => instructions.push(Instruction::Printf {
                    format_string,
                    args: args
                        .into_iter()
                        .map(|arg| self.compile_variable(arg))
                        .collect(),
                }),
                gpu::NonSemantic::Comment { content } => {
                    instructions.push(Instruction::Comment { content })
                }
                _ => {}
            },
            gpu::Operation::Barrier(barrier_ops) => match barrier_ops {
                gpu::BarrierOps::Declare { barrier } => {
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    let barrier = self.compile_variable(barrier);
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Declare {
                        barrier,
                        level,
                    }));
                }
                gpu::BarrierOps::Init {
                    barrier,
                    is_elected,
                    arrival_count,
                    with_async_proxy_fence,
                } => {
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    let barrier = self.compile_variable(barrier);
                    let arrival_count = self.compile_variable(arrival_count);
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Init {
                        barrier,
                        is_elected: self.compile_variable(is_elected),
                        arrival_count,
                        level,
                        with_async_proxy_fence,
                    }));
                }
                gpu::BarrierOps::InitManual {
                    barrier,
                    arrival_count,
                } => {
                    let barrier = self.compile_variable(barrier);
                    let arrival_count = self.compile_variable(arrival_count);
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::InitManual {
                            barrier,
                            arrival_count,
                        },
                    ));
                }
                gpu::BarrierOps::MemCopyAsync {
                    barrier,
                    source,
                    source_length,
                    offset_source,
                    offset_out,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsync {
                            barrier: self.compile_variable(barrier),
                            source: self.compile_variable(source),
                            destination: self.compile_variable(out.unwrap()),
                            source_length: self.compile_variable(source_length),
                            offset_source: self.compile_variable(offset_source),
                            offset_out: self.compile_variable(offset_out),
                            cooperative: false,
                        },
                    ));
                }
                gpu::BarrierOps::MemCopyAsyncCooperative {
                    barrier,
                    source,
                    source_length,
                    offset_source,
                    offset_out,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsync {
                            barrier: self.compile_variable(barrier),
                            source: self.compile_variable(source),
                            destination: self.compile_variable(out.unwrap()),
                            source_length: self.compile_variable(source_length),
                            offset_source: self.compile_variable(offset_source),
                            offset_out: self.compile_variable(offset_out),
                            cooperative: true,
                        },
                    ));
                }
                gpu::BarrierOps::MemCopyAsyncTx {
                    barrier,
                    source,
                    source_length,
                    offset_source,
                    offset_out,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsyncTx {
                            barrier: self.compile_variable(barrier),
                            source: self.compile_variable(source),
                            destination: self.compile_variable(out.unwrap()),
                            source_length: self.compile_variable(source_length),
                            offset_source: self.compile_variable(offset_source),
                            offset_out: self.compile_variable(offset_out),
                        },
                    ));
                }
                gpu::BarrierOps::TmaLoad {
                    barrier,
                    tensor_map,
                    offset_out,
                    indices,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsyncTensorGlobalToShared {
                            barrier: self.compile_variable(barrier),
                            smem_buffer: self.compile_variable(out.unwrap()),
                            smem_offset: self.compile_variable(offset_out),
                            tensor_map: self.compile_variable(tensor_map),
                            indices: indices
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        },
                    ));
                }
                gpu::BarrierOps::TmaLoadIm2col {
                    barrier,
                    tensor_map,
                    offset_out,
                    indices,
                    offsets,
                } => {
                    self.flags.inst_tma_im2col = true;
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::TmaLoadIm2col {
                            barrier: self.compile_variable(barrier),
                            smem_buffer: self.compile_variable(out.unwrap()),
                            smem_offset: self.compile_variable(offset_out),
                            tensor_map: self.compile_variable(tensor_map),
                            indices: indices
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                            offsets: offsets
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        },
                    ));
                }
                gpu::BarrierOps::Arrive { barrier } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Arrive {
                        barrier: self.compile_variable(barrier),
                        token: self.compile_variable(out.unwrap()),
                    }))
                }
                gpu::BarrierOps::ArriveTx {
                    barrier,
                    arrive_count_update,
                    transaction_count_update,
                } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::ArriveTx {
                        barrier: self.compile_variable(barrier),
                        token: self.compile_variable(out.unwrap()),
                        arrive_count_update: self.compile_variable(arrive_count_update),
                        transaction_count_update: self.compile_variable(transaction_count_update),
                    }))
                }
                gpu::BarrierOps::ExpectTx {
                    barrier,
                    transaction_count_update,
                } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::ExpectTx {
                        barrier: self.compile_variable(barrier),
                        transaction_count_update: self.compile_variable(transaction_count_update),
                    }))
                }
                gpu::BarrierOps::Wait { barrier, token } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Wait {
                        barrier: self.compile_variable(barrier),
                        token: self.compile_variable(token),
                    }))
                }
                gpu::BarrierOps::WaitParity { barrier, phase } => instructions.push(
                    Instruction::Barrier(super::barrier::BarrierOps::WaitParity {
                        barrier: self.compile_variable(barrier),
                        phase: self.compile_variable(phase),
                    }),
                ),
                gpu::BarrierOps::ArriveAndWait { barrier } => {
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::ArriveAndWait {
                            barrier: self.compile_variable(barrier),
                            level,
                        },
                    ))
                }
            },
            gpu::Operation::Tma(tma_ops) => {
                self.flags.inst_tma = true;
                match tma_ops {
                    gpu::TmaOps::TmaStore {
                        source,
                        coordinates,
                        offset_source,
                    } => {
                        instructions.push(Instruction::MemCopyAsyncTensorSharedToGlobal {
                            smem_buffer: self.compile_variable(source),
                            smem_offset: self.compile_variable(offset_source),
                            tensor_map: self.compile_variable(out.unwrap()),
                            indices: coordinates
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        });
                    }
                    gpu::TmaOps::CommitGroup => {
                        instructions.push(Instruction::BulkCommitGroup);
                    }
                    gpu::TmaOps::WaitGroup { max_pending } => {
                        instructions.push(Instruction::BulkWaitGroup { max_pending });
                    }
                    gpu::TmaOps::WaitGroupRead { max_pending } => {
                        instructions.push(Instruction::BulkWaitGroupRead { max_pending });
                    }
                }
            }
            gpu::Operation::Marker(_) => {}
        }
    }

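    /// Emits a `Line` instruction whenever the source location changes, so
    /// generated code can be traced back to its origin; non-semantic (debug)
    /// operations are ignored.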
    fn update_debug_loc(
        &mut self,
        instructions: &mut Vec<Instruction<D>>,
        inst: &gpu::Instruction,
    ) {
        if !matches!(inst.operation, Operation::NonSemantic(_)) {
            match &inst.source_loc {
                Some(loc) if Some(loc) != self.source_loc.as_ref() => {
                    self.source_loc = Some(loc.clone());
                    instructions.push(Instruction::Line {
                        file: loc.source.file.clone(),
                        line: loc.line,
                    });
                }
                _ => {}
            }
        }
    }

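    /// Lowers a cooperative-matrix (CMMA) operation to a WMMA instruction,
    /// registering any dialect extension the instruction requires.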
    fn compile_cmma(&mut self, cmma: gpu::CoopMma, out: Option<gpu::Variable>) -> Instruction<D> {
        self.flags.inst_wmma = true;

        let out = self.compile_variable(out.unwrap());

        let inst = match cmma {
            gpu::CoopMma::Fill { value } => WmmaInstruction::Fill {
                frag: out,
                value: self.compile_variable(value),
            },
            gpu::CoopMma::Load {
                value,
                stride,
                offset,
                layout,
            } => WmmaInstruction::Load {
                frag: out,
                offset: self.compile_variable(offset),
                value: self.compile_variable(value),
                stride: self.compile_variable(stride),
                layout: layout.and_then(|l| self.compile_matrix_layout(l)),
            },
            gpu::CoopMma::Execute {
                mat_a,
                mat_b,
                mat_c,
            } => WmmaInstruction::Execute {
                frag_a: self.compile_variable(mat_a),
                frag_b: self.compile_variable(mat_b),
                frag_c: self.compile_variable(mat_c),
                frag_d: out,
                warp_size: self.compilation_options.warp_size,
            },
            gpu::CoopMma::ExecuteManual {
                matrix,
                registers_a,
                registers_b,
                registers_c,
            } => WmmaInstruction::ExecuteManual {
                shape: MmaShape::new(matrix.m, matrix.n, matrix.k),
                frag_a: self.compile_variable(registers_a),
                frag_b: self.compile_variable(registers_b),
                frag_c: self.compile_variable(registers_c),
                frag_d: out,
            },
            gpu::CoopMma::ExecuteScaled {
                matrix,
                registers_a,
                registers_b,
                registers_c,
                scales_a,
                scales_b,
                scales_factor,
            } => WmmaInstruction::ExecuteScaled {
                shape: MmaShape::new(matrix.m, matrix.n, matrix.k),
                frag_a: self.compile_variable(registers_a),
                frag_b: self.compile_variable(registers_b),
                frag_c: self.compile_variable(registers_c),
                frag_d: out,

                scales_a: self.compile_variable(scales_a),
                scales_b: self.compile_variable(scales_b),
                scales_factor,
            },
            gpu::CoopMma::Store {
                mat,
                stride,
                offset,
                layout,
            } => {
                self.flags.indexes.unit_pos = true;
                self.flags.indexes.plane_index = true;
                WmmaInstruction::Store {
                    output: out,
                    offset: self.compile_variable(offset),
                    frag: self.compile_variable(mat),
                    stride: self.compile_variable(stride),
                    layout: self
                        .compile_matrix_layout(layout)
                        .expect("Layout required for store instruction"),
                }
            }
            gpu::CoopMma::LoadMatrix {
                buffer,
                offset,
                line_size,
                factor,
                transpose,
            } => WmmaInstruction::LdMatrix {
                output: out,
                buffer: self.compile_variable(buffer),
                offset: self.compile_variable(offset),
                line_size,
                factor,
                transpose,
            },
            gpu::CoopMma::StoreMatrix { .. } => {
                todo!()
            }
            gpu::CoopMma::Cast { input } => WmmaInstruction::Cast {
                input: self.compile_variable(input),
                output: out,
            },
            gpu::CoopMma::RowIndex { .. } | gpu::CoopMma::ColIndex { .. } => {
                panic!("Row/Col index should be handled by processors")
            }
        };

        D::register_wmma_instruction_extension(&mut self.extensions, &inst);

        Instruction::Wmma(inst)
    }

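    /// Lowers a metadata query (stride/shape/rank/length) into an indexed read
    /// of the metadata buffer; slice and shared-memory lengths are resolved
    /// without a metadata lookup.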
    fn compile_metadata(
        &mut self,
        metadata: gpu::Metadata,
        out: Option<gpu::Variable>,
    ) -> Instruction<D> {
        let out = out.unwrap();
        match metadata {
            gpu::Metadata::Stride { dim, var } => {
                let position = self.ext_meta_position(var);
                let offset = self.metadata.stride_offset_index(position);
                Instruction::ExtendedMetadata {
                    info_offset: self.compile_variable(offset.into()),
                    dim: self.compile_variable(dim),
                    split_meta: self.compilation_options.supports_features.grid_constants,
                    static_offset: self.metadata.static_len(),
                    out: self.compile_variable(out),
                }
            }
            gpu::Metadata::Shape { dim, var } => {
                let position = self.ext_meta_position(var);
                let offset = self.metadata.shape_offset_index(position);
                Instruction::ExtendedMetadata {
                    info_offset: self.compile_variable(offset.into()),
                    dim: self.compile_variable(dim),
                    split_meta: self.compilation_options.supports_features.grid_constants,
                    static_offset: self.metadata.static_len(),
                    out: self.compile_variable(out),
                }
            }
            gpu::Metadata::Rank { var } => {
                let out = self.compile_variable(out);
                let pos = self.ext_meta_position(var);
                let offset = self.metadata.rank_index(pos);
                super::Instruction::Metadata {
                    info_offset: self.compile_variable(offset.into()),
                    split_meta: self.compilation_options.supports_features.grid_constants,
                    out,
                }
            }
            gpu::Metadata::Length { var } => {
                let input = self.compile_variable(var);
                let out = self.compile_variable(out);

                match input {
                    Variable::Slice { .. } => Instruction::SliceLength { input, out },
                    Variable::SharedMemory(_id, _item, length) => {
                        Instruction::ConstLength { length, out }
                    }
                    _ => {
                        let id = input.id().expect("Variable should have id");
                        let offset = self.metadata.len_index(id);
                        Instruction::Metadata {
                            info_offset: self.compile_variable(offset.into()),
                            split_meta: self.compilation_options.supports_features.grid_constants,
                            out,
                        }
                    }
                }
            }
            gpu::Metadata::BufferLength { var } => {
                let input = self.compile_variable(var);
                let out = self.compile_variable(out);

                match input {
                    Variable::Slice { .. } => Instruction::SliceLength { input, out },
                    _ => {
                        let id = input.id().expect("Variable should have id");
                        let offset = self.metadata.buffer_len_index(id);
                        Instruction::Metadata {
                            info_offset: self.compile_variable(offset.into()),
                            split_meta: self.compilation_options.supports_features.grid_constants,
                            out,
                        }
                    }
                }
            }
        }
    }

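    /// Lowers control flow, recursively compiling each nested scope.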
    fn compile_branch(&mut self, instructions: &mut Vec<Instruction<D>>, branch: gpu::Branch) {
        match branch {
            gpu::Branch::If(mut op) => instructions.push(Instruction::If {
                cond: self.compile_variable(op.cond),
                instructions: self.compile_scope(&mut op.scope),
            }),
            gpu::Branch::IfElse(mut op) => instructions.push(Instruction::IfElse {
                cond: self.compile_variable(op.cond),
                instructions_if: self.compile_scope(&mut op.scope_if),
                instructions_else: self.compile_scope(&mut op.scope_else),
            }),
            gpu::Branch::Switch(mut op) => instructions.push(Instruction::Switch {
                value: self.compile_variable(op.value),
                instructions_default: self.compile_scope(&mut op.scope_default),
                instructions_cases: op
                    .cases
                    .into_iter()
                    .map(|(val, mut block)| {
                        (self.compile_variable(val), self.compile_scope(&mut block))
                    })
                    .collect(),
            }),
            gpu::Branch::Return => instructions.push(Instruction::Return),
            gpu::Branch::Break => instructions.push(Instruction::Break),
            gpu::Branch::RangeLoop(mut range_loop) => instructions.push(Instruction::RangeLoop {
                i: self.compile_variable(range_loop.i),
                start: self.compile_variable(range_loop.start),
                end: self.compile_variable(range_loop.end),
                step: range_loop.step.map(|it| self.compile_variable(it)),
                inclusive: range_loop.inclusive,
                instructions: self.compile_scope(&mut range_loop.scope),
            }),
            gpu::Branch::Loop(mut op) => instructions.push(Instruction::Loop {
                instructions: self.compile_scope(&mut op.scope),
            }),
        };
    }

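    /// Lowers an atomic operation onto the matching atomic instruction.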
    fn compile_atomic(
        &mut self,
        value: gpu::AtomicOp,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::AtomicOp::Load(op) => {
                instructions.push(Instruction::AtomicLoad(self.compile_unary(op, out)))
            }
            gpu::AtomicOp::Store(op) => {
                instructions.push(Instruction::AtomicStore(self.compile_unary(op, out)))
            }
            gpu::AtomicOp::Swap(op) => {
                instructions.push(Instruction::AtomicSwap(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Add(op) => {
                instructions.push(Instruction::AtomicAdd(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Sub(op) => {
                instructions.push(Instruction::AtomicSub(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Max(op) => {
                instructions.push(Instruction::AtomicMax(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Min(op) => {
                instructions.push(Instruction::AtomicMin(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::And(op) => {
                instructions.push(Instruction::AtomicAnd(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Or(op) => {
                instructions.push(Instruction::AtomicOr(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Xor(op) => {
                instructions.push(Instruction::AtomicXor(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::CompareAndSwap(op) => instructions.push(Instruction::AtomicCAS {
                input: self.compile_variable(op.input),
                cmp: self.compile_variable(op.cmp),
                val: self.compile_variable(op.val),
                out: self.compile_variable(out),
            }),
        }
    }

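    /// Lowers an arithmetic operation. For float ops with a fast variant, the
    /// instruction's fast-math modes decide (via `select_fast_float`) whether
    /// the precise or the fast intrinsic is emitted.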
    fn compile_arithmetic(
        &mut self,
        value: gpu::Arithmetic,
        out: Option<gpu::Variable>,
        modes: InstructionModes,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::Arithmetic::Add(op) => {
                instructions.push(Instruction::Add(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::SaturatingAdd(op) => {
                instructions.push(Instruction::SaturatingAdd(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Mul(op) => {
                instructions.push(Instruction::Mul(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Div(op) => {
                let op = self.compile_binary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::AllowReciprocal
                        | FastMath::ReducedPrecision
                        | FastMath::UnsignedZero
                        | FastMath::NotInf,
                    Instruction::Div(op),
                    Instruction::FastDiv(op),
                ))
            }
            gpu::Arithmetic::Sub(op) => {
                instructions.push(Instruction::Sub(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::SaturatingSub(op) => {
                instructions.push(Instruction::SaturatingSub(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::MulHi(op) => {
                let instruction = Instruction::HiMul(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Modulo(op) => {
                instructions.push(Instruction::Modulo(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Abs(op) => {
                instructions.push(Instruction::Abs(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Exp(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Exp(op),
                    Instruction::FastExp(op),
                ));
            }
            gpu::Arithmetic::Log(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Log(op),
                    Instruction::FastLog(op),
                ));
            }
            gpu::Arithmetic::Log1p(op) => {
                instructions.push(Instruction::Log1p(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Cos(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Cos(op),
                    Instruction::FastCos(op),
                ));
            }
            gpu::Arithmetic::Sin(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Sin(op),
                    Instruction::FastSin(op),
                ));
            }
            gpu::Arithmetic::Tan(op) => {
                instructions.push(Instruction::Tan(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Tanh(op) => {
                let op = self.compile_unary(op, out);
                let instruction = Instruction::Tanh(op);
                D::register_instruction_extension(&mut self.extensions, &instruction);
                if self.compilation_options.supports_features.fast_tanh {
                    instructions.push(self.select_fast_float(
                        out.ty,
                        modes,
                        FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                        instruction,
                        Instruction::FastTanh(op),
                    ))
                } else {
                    instructions.push(instruction);
                }
            }
            gpu::Arithmetic::Sinh(op) => {
                let instruction = Instruction::Sinh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Cosh(op) => {
                let instruction = Instruction::Cosh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcCos(op) => {
                let instruction = Instruction::ArcCos(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcSin(op) => {
                let instruction = Instruction::ArcSin(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcTan(op) => {
                let instruction = Instruction::ArcTan(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcSinh(op) => {
                let instruction = Instruction::ArcSinh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcCosh(op) => {
                let instruction = Instruction::ArcCosh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcTanh(op) => {
                let instruction = Instruction::ArcTanh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Degrees(op) => {
                let instruction = Instruction::Degrees(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Radians(op) => {
                let instruction = Instruction::Radians(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcTan2(op) => {
                let instruction = Instruction::ArcTan2(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Powf(op) => {
                let op = self.compile_binary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Powf(op),
                    Instruction::FastPowf(op),
                ))
            }
            gpu::Arithmetic::Powi(op) => {
                instructions.push(Instruction::Powi(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Sqrt(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Sqrt(op),
                    Instruction::FastSqrt(op),
                ))
            }
            gpu::Arithmetic::InverseSqrt(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::InverseSqrt(op),
                    Instruction::FastInverseSqrt(op),
                ))
            }
            gpu::Arithmetic::Erf(op) => {
                let instruction = Instruction::Erf(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Max(op) => {
                let instruction = Instruction::Max(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Min(op) => {
                let instruction = Instruction::Min(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Clamp(op) => instructions.push(Instruction::Clamp {
                input: self.compile_variable(op.input),
                min_value: self.compile_variable(op.min_value),
                max_value: self.compile_variable(op.max_value),
                out: self.compile_variable(out),
            }),
            gpu::Arithmetic::Recip(op) => {
                let elem = op.input.ty.elem_type();
                let input = self.compile_variable(op.input);
                let out = self.compile_variable(out);
                let lhs = match elem {
                    gpu::ElemType::Float(kind) => gpu::ConstantScalarValue::Float(1.0, kind),
                    gpu::ElemType::Int(kind) => gpu::ConstantScalarValue::Int(1, kind),
                    gpu::ElemType::UInt(kind) => gpu::ConstantScalarValue::UInt(1, kind),
                    gpu::ElemType::Bool => gpu::ConstantScalarValue::Bool(true),
                };
                let div = Instruction::Div(BinaryInstruction {
                    lhs: Variable::ConstantScalar(lhs, self.compile_elem(elem)),
                    rhs: input,
                    out,
                });
                let recip = Instruction::FastRecip(UnaryInstruction { input, out });

                instructions.push(self.select_fast_float(
                    elem.into(),
                    modes,
                    FastMath::AllowReciprocal
                        | FastMath::ReducedPrecision
                        | FastMath::UnsignedZero
                        | FastMath::NotInf,
                    div,
                    recip,
                ))
            }
            gpu::Arithmetic::Round(op) => {
                instructions.push(Instruction::Round(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Floor(op) => {
                instructions.push(Instruction::Floor(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Ceil(op) => {
                instructions.push(Instruction::Ceil(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Trunc(op) => {
                instructions.push(Instruction::Trunc(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Remainder(op) => {
                instructions.push(Instruction::Remainder(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Fma(op) => instructions.push(Instruction::Fma {
                a: self.compile_variable(op.a),
                b: self.compile_variable(op.b),
                c: self.compile_variable(op.c),
                out: self.compile_variable(out),
            }),
            gpu::Arithmetic::Neg(op) => {
                instructions.push(Instruction::Neg(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Normalize(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Normalize(op),
                    Instruction::FastNormalize(op),
                ))
            }
            gpu::Arithmetic::Magnitude(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Magnitude(op),
                    Instruction::FastMagnitude(op),
                ))
            }
            gpu::Arithmetic::Dot(op) => {
                instructions.push(Instruction::Dot(self.compile_binary(op, out)))
            }
        };
    }

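    /// Picks between the precise and the fast variant of a float instruction.
    /// The fast variant is chosen only for `f32`, only when the target
    /// supports fast math, and only when the instruction's fp-math modes
    /// include every flag in `required_flags`.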
    fn select_fast_float(
        &self,
        ty: gpu::Type,
        modes: InstructionModes,
        required_flags: EnumSet<FastMath>,
        default: Instruction<D>,
        fast: Instruction<D>,
    ) -> Instruction<D> {
        if !self.compilation_options.supports_features.fast_math
            || !matches!(ty.elem_type(), ElemType::Float(FloatKind::F32))
        {
            return default;
        }

        if modes.fp_math_mode.is_superset(required_flags) {
            fast
        } else {
            default
        }
    }

    fn compile_comparison(
        &mut self,
        value: gpu::Comparison,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::Comparison::Equal(op) => {
                instructions.push(Instruction::Equal(self.compile_binary(op, out)))
            }
            gpu::Comparison::Lower(op) => {
                instructions.push(Instruction::Lower(self.compile_binary(op, out)))
            }
            gpu::Comparison::Greater(op) => {
                instructions.push(Instruction::Greater(self.compile_binary(op, out)))
            }
            gpu::Comparison::LowerEqual(op) => {
                instructions.push(Instruction::LowerEqual(self.compile_binary(op, out)))
            }
            gpu::Comparison::GreaterEqual(op) => {
                instructions.push(Instruction::GreaterEqual(self.compile_binary(op, out)))
            }
            gpu::Comparison::NotEqual(op) => {
                instructions.push(Instruction::NotEqual(self.compile_binary(op, out)))
            }
            gpu::Comparison::IsNan(op) => {
                instructions.push(Instruction::IsNan(self.compile_unary(op, out)))
            }
            gpu::Comparison::IsInf(op) => {
                instructions.push(Instruction::IsInf(self.compile_unary(op, out)))
            }
        };
    }

    fn compile_bitwise(
        &mut self,
        value: gpu::Bitwise,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::Bitwise::BitwiseOr(op) => {
                instructions.push(Instruction::BitwiseOr(self.compile_binary(op, out)))
            }
            gpu::Bitwise::BitwiseAnd(op) => {
                instructions.push(Instruction::BitwiseAnd(self.compile_binary(op, out)))
            }
            gpu::Bitwise::BitwiseXor(op) => {
                instructions.push(Instruction::BitwiseXor(self.compile_binary(op, out)))
            }
            gpu::Bitwise::CountOnes(op) => {
                instructions.push(Instruction::CountBits(self.compile_unary(op, out)))
            }
            gpu::Bitwise::ReverseBits(op) => {
                instructions.push(Instruction::ReverseBits(self.compile_unary(op, out)))
            }
            gpu::Bitwise::ShiftLeft(op) => {
                instructions.push(Instruction::ShiftLeft(self.compile_binary(op, out)))
            }
            gpu::Bitwise::ShiftRight(op) => {
                instructions.push(Instruction::ShiftRight(self.compile_binary(op, out)))
            }
            gpu::Bitwise::BitwiseNot(op) => {
                instructions.push(Instruction::BitwiseNot(self.compile_unary(op, out)))
            }
            gpu::Bitwise::LeadingZeros(op) => {
                instructions.push(Instruction::LeadingZeros(self.compile_unary(op, out)))
            }
            gpu::Bitwise::FindFirstSet(op) => {
                let instruction = Instruction::FindFirstSet(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
        };
    }

    fn compile_operator(
        &mut self,
        value: gpu::Operator,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::Operator::Index(op) | gpu::Operator::UncheckedIndex(op) => {
                instructions.push(Instruction::Index(self.compile_index(op, out)));
            }
            gpu::Operator::IndexAssign(op) | gpu::Operator::UncheckedIndexAssign(op) => {
                instructions.push(Instruction::IndexAssign(self.compile_index_assign(op, out)));
            }
            gpu::Operator::And(op) => {
                instructions.push(Instruction::And(self.compile_binary(op, out)))
            }
            gpu::Operator::Or(op) => {
                instructions.push(Instruction::Or(self.compile_binary(op, out)))
            }
            gpu::Operator::Not(op) => {
                instructions.push(Instruction::Not(self.compile_unary(op, out)))
            }
            gpu::Operator::InitLine(op) => instructions.push(Instruction::VecInit {
                inputs: op
                    .inputs
                    .into_iter()
                    .map(|it| self.compile_variable(it))
                    .collect(),
                out: self.compile_variable(out),
            }),
            gpu::Operator::CopyMemory(op) => instructions.push(Instruction::Copy {
                input: self.compile_variable(op.input),
                in_index: self.compile_variable(op.in_index),
                out: self.compile_variable(out),
                out_index: self.compile_variable(op.out_index),
            }),
            gpu::Operator::CopyMemoryBulk(op) => instructions.push(Instruction::CopyBulk {
                input: self.compile_variable(op.input),
                in_index: self.compile_variable(op.in_index),
                out: self.compile_variable(out),
                out_index: self.compile_variable(op.out_index),
                len: op.len,
            }),
            gpu::Operator::Select(op) => instructions.push(Instruction::Select {
                cond: self.compile_variable(op.cond),
                then: self.compile_variable(op.then),
                or_else: self.compile_variable(op.or_else),
                out: self.compile_variable(out),
            }),
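            // Casts involving fp4/fp6/fp8 go through f16/bf16 intermediates,
            // so pre-register the line types those conversions rely on before
            // emitting the specialized cast.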
            gpu::Operator::Cast(op)
                if is_fp4_fp6_fp8(op.input.elem_type()) || is_fp4_fp6_fp8(out.elem_type()) =>
            {
                self.flags.elem_f16 = true;
                self.flags.elem_bf16 = true;
                let vec_in = op.input.ty.line_size();
                let packing = out.storage_type().packing_factor();
                self.compile_type(op.input.ty.line(packing));
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::F16)).line(vec_in),
                );
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::BF16)).line(vec_in),
                );
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::F16)).line(packing),
                );
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::BF16)).line(packing),
                );

                let inst = self.compile_unary(op, out);

                instructions.push(Instruction::SpecialCast(inst));
            }
            gpu::Operator::Cast(op) => {
                let op = self.compile_unary(op, out);

                if op.input.elem() == Elem::TF32 || op.out.elem() == Elem::TF32 {
                    self.flags.elem_tf32 = true;
                }

                instructions.push(Instruction::Assign(op))
            }
            gpu::Operator::Reinterpret(op) => {
                instructions.push(Instruction::Bitcast(self.compile_unary(op, out)))
            }
        };
    }

    fn compile_binary(
        &mut self,
        value: gpu::BinaryOperator,
        out: gpu::Variable,
    ) -> BinaryInstruction<D> {
        BinaryInstruction {
            lhs: self.compile_variable(value.lhs),
            rhs: self.compile_variable(value.rhs),
            out: self.compile_variable(out),
        }
    }

    fn compile_index_assign(
        &mut self,
        value: gpu::IndexAssignOperator,
        out: gpu::Variable,
    ) -> IndexAssignInstruction<D> {
        IndexAssignInstruction {
            index: self.compile_variable(value.index),
            value: self.compile_variable(value.value),
            line_size: value.line_size,
            out: self.compile_variable(out),
        }
    }

    fn compile_index(
        &mut self,
        value: gpu::IndexOperator,
        out: gpu::Variable,
    ) -> IndexInstruction<D> {
        IndexInstruction {
            list: self.compile_variable(value.list),
            index: self.compile_variable(value.index),
            line_size: value.line_size,
            out: self.compile_variable(out),
        }
    }

    fn compile_unary(
        &mut self,
        value: gpu::UnaryOperator,
        out: gpu::Variable,
    ) -> UnaryInstruction<D> {
        UnaryInstruction {
            input: self.compile_variable(value.input),
            out: self.compile_variable(out),
        }
    }

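    /// Lowers an IR variable, setting any feature flags it implies and
    /// registering auxiliary declarations (local arrays, pipelines) on first
    /// use.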
    fn compile_variable(&mut self, value: gpu::Variable) -> Variable<D> {
        let item = value.ty;
        match value.kind {
            gpu::VariableKind::GlobalInputArray(id) => {
                Variable::GlobalInputArray(id, self.compile_type(item))
            }
            gpu::VariableKind::GlobalScalar(id) => Variable::GlobalScalar {
                id,
                elem: self.compile_storage_type(item.storage_type()),
                in_struct: self.compilation_options.supports_features.grid_constants,
            },
            gpu::VariableKind::TensorMapInput(id) => {
                self.flags.inst_tma = true;
                Variable::TensorMap(id)
            }
            gpu::VariableKind::TensorMapOutput(id) => {
                self.flags.inst_tma = true;
                Variable::TensorMap(id)
            }
            gpu::VariableKind::LocalMut { id } => Variable::LocalMut {
                id,
                item: self.compile_type(item),
            },
            gpu::VariableKind::Versioned { id, .. } => Variable::LocalMut {
                id,
                item: self.compile_type(item),
            },
            gpu::VariableKind::LocalConst { id } => Variable::LocalConst {
                id,
                item: self.compile_type(item),
            },
            gpu::VariableKind::GlobalOutputArray(id) => {
                Variable::GlobalOutputArray(id, self.compile_type(item))
            }
            gpu::VariableKind::ConstantScalar(value) => {
                Variable::ConstantScalar(value, self.compile_elem(value.elem_type()))
            }
            gpu::VariableKind::SharedMemory { id, length, .. } => {
                let item = self.compile_type(item);
                Variable::SharedMemory(id, item, length)
            }
            gpu::VariableKind::ConstantArray {
                id,
                length,
                unroll_factor,
            } => {
                let item = self.compile_type(item);
                Variable::ConstantArray(id, item, length * unroll_factor)
            }
            gpu::VariableKind::Builtin(builtin) => match builtin {
                gpu::Builtin::AbsolutePos => {
                    self.flags.indexes.absolute_pos = true;
                    Variable::AbsolutePos
                }
                gpu::Builtin::CubePosCluster
                    if self.compilation_options.supports_features.clusters =>
                {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterRank
                }
                gpu::Builtin::CubePosClusterX
                    if self.compilation_options.supports_features.clusters =>
                {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexX
                }
                gpu::Builtin::CubePosClusterY
                    if self.compilation_options.supports_features.clusters =>
                {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexY
                }
                gpu::Builtin::CubePosClusterZ
                    if self.compilation_options.supports_features.clusters =>
                {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexZ
                }
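                // Without cluster support, each cube is implicitly its own
                // cluster, so cluster-position builtins lower to the constant
                // 0.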
                gpu::Builtin::CubePosCluster
                | gpu::Builtin::CubePosClusterX
                | gpu::Builtin::CubePosClusterY
                | gpu::Builtin::CubePosClusterZ => const_u32(0),
                gpu::Builtin::AbsolutePosX => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosX
                }
                gpu::Builtin::AbsolutePosY => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosY
                }
                gpu::Builtin::AbsolutePosZ => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosZ
                }
                gpu::Builtin::CubeDim => {
                    self.flags.indexes.cube_dim = true;
                    Variable::CubeDim
                }
                gpu::Builtin::CubeDimX => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimX
                }
                gpu::Builtin::CubeDimY => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimY
                }
                gpu::Builtin::CubeDimZ => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimZ
                }
                gpu::Builtin::CubeClusterDim => const_u32(self.cluster_dim.num_elems()),
                gpu::Builtin::CubeClusterDimX => const_u32(self.cluster_dim.x),
                gpu::Builtin::CubeClusterDimY => const_u32(self.cluster_dim.y),
                gpu::Builtin::CubeClusterDimZ => const_u32(self.cluster_dim.z),
                gpu::Builtin::CubePos => {
                    self.flags.indexes.cube_pos = true;
                    Variable::CubePos
                }
                gpu::Builtin::CubePosX => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosX
                }
                gpu::Builtin::CubePosY => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosY
                }
                gpu::Builtin::CubePosZ => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosZ
                }
                gpu::Builtin::CubeCount => {
                    self.flags.indexes.cube_count = true;
                    Variable::CubeCount
                }
                gpu::Builtin::CubeCountX => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountX
                }
                gpu::Builtin::CubeCountY => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountY
                }
                gpu::Builtin::CubeCountZ => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountZ
                }
                gpu::Builtin::UnitPos => {
                    self.flags.indexes.unit_pos = true;
                    Variable::UnitPos
                }
                gpu::Builtin::UnitPosX => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosX
                }
                gpu::Builtin::UnitPosY => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosY
                }
                gpu::Builtin::UnitPosZ => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosZ
                }
                gpu::Builtin::PlaneDim => {
                    self.flags.indexes.plane_dim = true;
                    Variable::PlaneDim
                }
                gpu::Builtin::UnitPosPlane => {
                    self.flags.indexes.unit_pos_plane = true;
                    Variable::UnitPosPlane
                }
            },
            gpu::VariableKind::LocalArray {
                id,
                length,
                unroll_factor,
            } => {
                let item = self.compile_type(item);
                if !self.local_arrays.iter().any(|s| s.index == id) {
                    self.local_arrays
                        .push(LocalArray::new(id, item, length * unroll_factor));
                }
                Variable::LocalArray(id, item, length)
            }
            gpu::VariableKind::Matrix { id, mat } => {
                self.flags.inst_wmma = true;
                Variable::WmmaFragment {
                    id,
                    frag: self.compile_matrix(mat),
                }
            }
            gpu::VariableKind::Pipeline { id, num_stages } => {
                self.flags.op_pipeline = true;
                let pipeline = Variable::Pipeline { id };
                if !self.pipelines.iter().any(|s| s.pipeline_id() == id) {
                    self.pipelines.push(PipelineOps::Init {
                        pipeline,
                        num_stages,
                    });
                }
                pipeline
            }
            gpu::VariableKind::Barrier { id, level } => {
                self.flags.op_barrier = true;
                Variable::Barrier { id, level }
            }
            gpu::VariableKind::BarrierToken { id, level } => {
                self.flags.op_barrier = true;
                Variable::BarrierToken { id, level }
            }
        }
    }

    fn compile_matrix(&mut self, matrix: gpu::Matrix) -> Fragment<D> {
        Fragment {
            ident: self.compile_matrix_ident(matrix.ident),
            m: matrix.m,
            n: matrix.n,
            k: matrix.k,
            elem: self.compile_storage_type(matrix.storage),
            layout: self.compile_matrix_layout(matrix.layout),
        }
    }

    fn compile_matrix_ident(&mut self, ident: gpu::MatrixIdent) -> FragmentIdent<D> {
        match ident {
            gpu::MatrixIdent::A => FragmentIdent::A,
            gpu::MatrixIdent::B => FragmentIdent::B,
            gpu::MatrixIdent::Accumulator => FragmentIdent::Accumulator,
        }
    }

    fn compile_matrix_layout(&mut self, layout: gpu::MatrixLayout) -> Option<FragmentLayout<D>> {
        match layout {
            gpu::MatrixLayout::ColMajor => Some(FragmentLayout::ColMajor),
            gpu::MatrixLayout::RowMajor => Some(FragmentLayout::RowMajor),
            gpu::MatrixLayout::Undefined => None,
        }
    }

    fn compile_binding(&mut self, binding: cubecl_core::compute::Binding) -> Binding<D> {
        Binding {
            id: binding.id,
            item: self.compile_type(binding.ty),
            location: binding.location,
            size: binding.size,
            vis: binding.visibility,
        }
    }

    fn compile_type(&mut self, ty: gpu::Type) -> Item<D> {
        let item = match ty {
            gpu::Type::Scalar(ty) => Item::new(self.compile_storage_type(ty), 1, false),
            gpu::Type::Line(ty, line_size) => {
                Item::new(self.compile_storage_type(ty), line_size as usize, false)
            }
            gpu::Type::Semantic(_) => Item::new(Elem::Bool, 1, true),
        };
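        // TF32 values use f32 storage, so register the item as f32 rather
        // than emitting a dedicated tf32 item.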
        if item.elem != super::Elem::TF32 {
            self.items.insert(item);
            self.items.insert(item.optimized());
        } else {
            let mut item = item;
            item.elem = super::Elem::F32;
            self.items.insert(item);
        }

        item
    }

    fn compile_storage_type(&mut self, value: gpu::StorageType) -> Elem<D> {
        match value {
            gpu::StorageType::Scalar(ty) => self.compile_elem(ty),
            gpu::StorageType::Atomic(ty) => Elem::Atomic(ty.into()),
            gpu::StorageType::Packed(gpu::ElemType::Float(kind), 2) => match kind {
                FloatKind::E2M1 => {
                    self.flags.elem_fp4 = true;
                    Elem::FP4x2(FP4Kind::E2M1)
                }
                FloatKind::E2M3 => {
                    self.flags.elem_fp6 = true;
                    Elem::FP6x2(FP6Kind::E2M3)
                }
                FloatKind::E3M2 => {
                    self.flags.elem_fp6 = true;
                    Elem::FP6x2(FP6Kind::E3M2)
                }
                FloatKind::E4M3 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8x2(FP8Kind::E4M3)
                }
                FloatKind::E5M2 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8x2(FP8Kind::E5M2)
                }
                FloatKind::UE8M0 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8x2(FP8Kind::UE8M0)
                }
                FloatKind::F16 => {
                    self.flags.elem_f16 = true;
                    Elem::F16x2
                }
                FloatKind::BF16 => {
                    self.flags.elem_bf16 = true;
                    Elem::BF16x2
                }
                other => unimplemented!("Unsupported storage type: packed<{other:?}, 2>"),
            },
            other => unimplemented!("Unsupported storage type: {other}"),
        }
    }

    fn compile_elem(&mut self, value: gpu::ElemType) -> Elem<D> {
        match value {
            gpu::ElemType::Float(kind) => match kind {
                gpu::FloatKind::E2M1 => {
                    self.flags.elem_fp4 = true;
                    Elem::FP4(FP4Kind::E2M1)
                }
                gpu::FloatKind::E2M3 => {
                    self.flags.elem_fp6 = true;
                    Elem::FP6(FP6Kind::E2M3)
                }
                gpu::FloatKind::E3M2 => {
                    self.flags.elem_fp6 = true;
                    Elem::FP6(FP6Kind::E3M2)
                }
                gpu::FloatKind::E4M3 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8(FP8Kind::E4M3)
                }
                gpu::FloatKind::E5M2 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8(FP8Kind::E5M2)
                }
                gpu::FloatKind::UE8M0 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8(FP8Kind::UE8M0)
                }
                gpu::FloatKind::F16 => {
                    self.flags.elem_f16 = true;
                    Elem::F16
                }
                gpu::FloatKind::BF16 => {
                    self.flags.elem_bf16 = true;
                    Elem::BF16
                }
                gpu::FloatKind::TF32 => Elem::TF32,
                gpu::FloatKind::Flex32 => Elem::F32,
                gpu::FloatKind::F32 => Elem::F32,
                gpu::FloatKind::F64 => Elem::F64,
            },
            gpu::ElemType::Int(kind) => match kind {
                gpu::IntKind::I8 => Elem::I8,
                gpu::IntKind::I16 => Elem::I16,
                gpu::IntKind::I32 => Elem::I32,
                gpu::IntKind::I64 => Elem::I64,
            },
            gpu::ElemType::UInt(kind) => match kind {
                gpu::UIntKind::U8 => Elem::U8,
                gpu::UIntKind::U16 => Elem::U16,
                gpu::UIntKind::U32 => Elem::U32,
                gpu::UIntKind::U64 => Elem::U64,
            },
            gpu::ElemType::Bool => Elem::Bool,
        }
    }
}

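/// Returns `true` for the minifloat kinds (fp4/fp6/fp8) that need the
/// specialized cast path.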
fn is_fp4_fp6_fp8(elem: gpu::ElemType) -> bool {
    match elem {
        gpu::ElemType::Float(kind) => matches!(
            kind,
            FloatKind::E2M1
                | FloatKind::E2M3
                | FloatKind::E3M2
                | FloatKind::E4M3
                | FloatKind::E5M2
                | FloatKind::UE8M0
        ),
        _ => false,
    }
}

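/// Builds a `u32` constant scalar variable.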
fn const_u32<D: Dialect>(value: u32) -> Variable<D> {
    Variable::ConstantScalar(
        gpu::ConstantScalarValue::UInt(value as u64, UIntKind::U32),
        Elem::U32,
    )
}

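/// Registers the baseline scalar and atomic type support shared by the
/// C++-based runtimes.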
pub fn register_supported_types(props: &mut DeviceProperties) {
    let supported_types = [
        gpu::ElemType::UInt(gpu::UIntKind::U8),
        gpu::ElemType::UInt(gpu::UIntKind::U16),
        gpu::ElemType::UInt(gpu::UIntKind::U32),
        gpu::ElemType::UInt(gpu::UIntKind::U64),
        gpu::ElemType::Int(gpu::IntKind::I8),
        gpu::ElemType::Int(gpu::IntKind::I16),
        gpu::ElemType::Int(gpu::IntKind::I32),
        gpu::ElemType::Int(gpu::IntKind::I64),
        gpu::ElemType::Float(gpu::FloatKind::BF16),
        gpu::ElemType::Float(gpu::FloatKind::F16),
        gpu::ElemType::Float(gpu::FloatKind::F32),
        gpu::ElemType::Float(gpu::FloatKind::Flex32),
        gpu::ElemType::Bool,
    ];

    let supported_atomic_types = [
        gpu::ElemType::Int(gpu::IntKind::I32),
        gpu::ElemType::Int(gpu::IntKind::I64),
        gpu::ElemType::UInt(gpu::UIntKind::U32),
        gpu::ElemType::UInt(gpu::UIntKind::U64),
        gpu::ElemType::Float(gpu::FloatKind::F32),
    ];

    for ty in supported_types {
        props.register_type_usage(ty, TypeUsage::all_scalar());
    }

    for ty in supported_atomic_types {
        props.register_type_usage(
            gpu::StorageType::Atomic(ty),
            TypeUsage::AtomicAdd | TypeUsage::AtomicLoadStore,
        );
    }
}