use std::{collections::HashSet, fmt::Debug};

use cubecl_common::ExecutionMode;
use cubecl_core::ir::{self as gpu};
use cubecl_core::ir::{FloatKind, InstructionModes, Processor, UIntKind, VariableKind};
use cubecl_core::post_processing::checked_io::CheckedIoProcessor;
use cubecl_core::{CubeDim, ir::ElemType};
use cubecl_core::{
    ir::{Operation, SourceLoc},
    prelude::{FastMath, KernelDefinition},
};
use cubecl_opt::{Optimizer, SharedLiveness};
use cubecl_runtime::compiler::CompilationError;
use cubecl_runtime::{DeviceProperties, EnumSet, TypeUsage, compiler::Compiler};

use crate::shared::MmaShape;

use super::{
    BinaryInstruction, Binding, Body, Component, ComputeKernel, ConstArray, Dialect, Elem, FP6Kind,
    Fragment, FragmentIdent, FragmentLayout, IndexAssignInstruction, IndexInstruction, Instruction,
    Item, LocalArray, SharedMemory, UnaryInstruction, Variable, WarpInstruction, WmmaInstruction,
};
use super::{FP4Kind, barrier::BarrierOps};
use super::{FP8Kind, pipeline::PipelineOps};

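/// Global counter used to generate unique names for temporary variables.
/// It is reset once each kernel compilation finishes.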
pub(super) static COUNTER_TMP_VAR: std::sync::atomic::AtomicU32 =
    std::sync::atomic::AtomicU32::new(0);

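/// Target-specific options that control how kernels are lowered to C++.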
#[derive(Clone, Debug)]
pub struct CompilationOptions {
    pub warp_size: u32,
    pub supports_features: CppSupportedFeatures,
}

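/// Optional hardware/toolchain features of the target. Unsupported features
/// are either disabled or lowered through fallback code paths.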
#[derive(Clone, Debug, Default)]
pub struct CppSupportedFeatures {
    pub grid_constants: bool,
    pub clusters: bool,
    pub fast_math: bool,
    pub fast_tanh: bool,
    pub elect_sync: bool,
}

impl Default for CompilationOptions {
    fn default() -> Self {
        Self {
            warp_size: 32,
            supports_features: Default::default(),
        }
    }
}

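/// Tracks which builtin index variables a kernel actually reads, so that only
/// the corresponding index computations are emitted in the generated source.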
#[derive(Debug, Clone, Default)]
pub struct CubeIndexFlags {
    pub absolute_pos: bool,
    pub absolute_pos_tuple: bool,
    pub cube_count: bool,
    pub cube_count_tuple: bool,
    pub cube_dim: bool,
    pub cube_dim_tuple: bool,
    pub cube_pos: bool,
    pub cube_pos_tuple: bool,
    pub plane_dim: bool,
    pub plane_dim_checked: bool,
    pub plane_index: bool,
    pub unit_pos: bool,
    pub unit_pos_tuple: bool,
    pub unit_pos_plane: bool,
    pub cluster_pos: bool,
}

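/// Feature flags collected while lowering a kernel. They record which element
/// types, builtins and instruction families the kernel uses, and drive what
/// the code generator has to declare or include.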
#[derive(Debug, Clone, Default)]
pub struct Flags {
    pub elem_fp4: bool,
    pub elem_fp6: bool,
    pub elem_fp8: bool,
    pub elem_bf16: bool,
    pub elem_f16: bool,
    pub elem_tf32: bool,
    pub indexes: CubeIndexFlags,
    pub op_barrier: bool,
    pub op_pipeline: bool,
    pub inst_tma: bool,
    pub inst_tma_im2col: bool,
    pub inst_wmma: bool,
    pub inst_ptx_wrappers: bool,
    pub use_grid_constants: bool,
    pub static_meta_length: usize,
    pub has_dynamic_meta: bool,
    pub cube_dim: CubeDim,
    pub cluster_dim: Option<CubeDim>,
}

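/// Lowers CubeCL IR into the C++ representation of dialect `D` (e.g. CUDA or
/// HIP).
///
/// A minimal sketch of driving it through the `Compiler` trait (not compiled
/// as a doctest), assuming a hypothetical concrete dialect `MyDialect` and a
/// `KernelDefinition` already in hand:
///
/// ```ignore
/// use cubecl_common::ExecutionMode;
/// use cubecl_runtime::compiler::Compiler;
///
/// let mut compiler = CppCompiler::<MyDialect>::default();
/// let kernel = compiler.compile(definition, &CompilationOptions::default(), ExecutionMode::Checked)?;
/// ```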
#[allow(clippy::too_many_arguments)]
#[derive(Clone, Debug, Default)]
pub struct CppCompiler<D: Dialect> {
    barriers: Vec<BarrierOps<D>>,
    compilation_options: CompilationOptions,
    const_arrays: Vec<ConstArray<D>>,
    ext_meta_positions: Vec<u32>,
    cluster_dim: CubeDim,
    extensions: Vec<D::Extension>,
    flags: Flags,
    items: HashSet<Item<D>>,
    local_arrays: Vec<LocalArray<D>>,
    metadata: cubecl_core::Metadata,
    pipelines: Vec<PipelineOps<D>>,
    source_loc: Option<SourceLoc>,
    strategy: ExecutionMode,
}

impl<D: Dialect> Compiler for CppCompiler<D> {
    type Representation = ComputeKernel<D>;
    type CompilationOptions = CompilationOptions;

    fn compile(
        &mut self,
        mut kernel: KernelDefinition,
        compilation_options: &Self::CompilationOptions,
        strategy: ExecutionMode,
    ) -> Result<Self::Representation, CompilationError> {
        self.compilation_options = compilation_options.clone();
        self.strategy = strategy;

        if !self.compilation_options.supports_features.clusters {
            kernel.options.cluster_dim = None;
        }
        self.cluster_dim = kernel.options.cluster_dim.unwrap_or(CubeDim::new_single());

        let ir = self.clone().compile_ir(kernel);
        COUNTER_TMP_VAR.store(0, std::sync::atomic::Ordering::Relaxed);
        Ok(ir)
    }

    fn elem_size(&self, elem: gpu::ElemType) -> usize {
        elem.size()
    }

    fn extension(&self) -> &'static str {
        "cpp"
    }
}

impl<D: Dialect> CppCompiler<D> {
    fn compile_ir(mut self, value: KernelDefinition) -> ComputeKernel<D> {
        self.build_metadata(&value);

        let instructions = self.compile_scope(&mut value.body.clone());
        let tensor_maps = value
            .tensor_maps
            .into_iter()
            .map(|b| self.compile_binding(b))
            .collect();
        let buffers = value
            .buffers
            .into_iter()
            .map(|b| self.compile_binding(b))
            .collect();
        let scalars = value
            .scalars
            .into_iter()
            .map(|binding| (self.compile_storage_type(binding.ty), binding.count))
            .collect();

        let flags = Flags {
            indexes: D::builtin_rules(&self.flags.indexes),
            inst_wmma: self.flags.inst_wmma,
            op_pipeline: self.flags.op_pipeline,
            op_barrier: self.flags.op_barrier,
            elem_fp4: self.flags.elem_fp4,
            elem_fp6: self.flags.elem_fp6,
            elem_fp8: self.flags.elem_fp8,
            elem_bf16: self.flags.elem_bf16,
            elem_f16: self.flags.elem_f16,
            elem_tf32: self.flags.elem_tf32,
            inst_tma: self.flags.inst_tma,
            inst_tma_im2col: self.flags.inst_tma_im2col,
            inst_ptx_wrappers: self.flags.inst_ptx_wrappers,
            use_grid_constants: self.compilation_options.supports_features.grid_constants,
            has_dynamic_meta: self.metadata.static_len() > 0,
            static_meta_length: self.metadata.static_len() as usize,
            cube_dim: value.cube_dim,
            cluster_dim: value.options.cluster_dim,
        };

        let mut opt = Optimizer::shared_only(value.body, value.cube_dim);
        let shared_allocs = opt.analysis::<SharedLiveness>();
        let shared_memories = shared_allocs
            .allocations
            .values()
            .map(|alloc| SharedMemory {
                index: alloc.smem.id,
                item: self.compile_type(alloc.smem.ty),
                length: alloc.smem.length,
                align: alloc.smem.align,
                offset: alloc.offset,
            })
            .collect();

        let body = Body {
            instructions,
            shared_memories,
            pipelines: self.pipelines,
            barriers: self.barriers,
            const_arrays: self.const_arrays,
            local_arrays: self.local_arrays,
        };

        let mut cluster_dim = value.options.cluster_dim;
        if !self.compilation_options.supports_features.clusters {
            cluster_dim = None;
        }

        ComputeKernel {
            tensor_maps,
            buffers,
            scalars,
            meta_static_len: self.metadata.static_len() as usize,
            cube_dim: value.cube_dim,
            body,
            extensions: self.extensions,
            flags,
            items: self.items,
            kernel_name: value.options.kernel_name,
            cluster_dim,
        }
    }

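    /// Builds the metadata table: assigns every buffer and tensor map an
    /// extended-metadata position (sorted by binding id) and records how many
    /// bindings carry extended metadata.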
    fn build_metadata(&mut self, value: &KernelDefinition) {
        let mut num_ext = 0;

        let mut all_meta: Vec<_> = value
            .buffers
            .iter()
            .chain(value.tensor_maps.iter())
            .map(|buf| (buf.id, buf.has_extended_meta))
            .collect();

        all_meta.sort_by_key(|(id, _)| *id);

        for (_, has_extended_meta) in &all_meta {
            self.ext_meta_positions.push(num_ext);
            if *has_extended_meta {
                num_ext += 1;
            }
        }

        let num_meta = all_meta.len();

        self.metadata = cubecl_core::Metadata::new(num_meta as u32, num_ext);
    }

    pub(crate) fn ext_meta_position(&self, var: gpu::Variable) -> u32 {
        let id = var.index().expect("Variable should have index");
        self.ext_meta_positions[id as usize]
    }

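    /// Compiles a scope into a list of instructions: hoists its constant
    /// arrays, runs the IR processors (checked IO plus dialect-specific
    /// ones), declares the resulting variables, then lowers every
    /// instruction.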
    fn compile_scope(&mut self, scope: &mut gpu::Scope) -> Vec<Instruction<D>> {
        let mut instructions = Vec::new();

        let const_arrays = scope
            .const_arrays
            .drain(..)
            .map(|(var, values)| ConstArray {
                index: var.index().unwrap(),
                item: self.compile_type(var.ty),
                size: values.len() as u32,
                values: values
                    .into_iter()
                    .map(|val| self.compile_variable(val))
                    .collect(),
            })
            .collect::<Vec<_>>();
        self.const_arrays.extend(const_arrays);

        let checked_io: Box<dyn Processor> = Box::new(CheckedIoProcessor::new(self.strategy));
        let dialect_processors = D::processors();
        let mut processors: Vec<&dyn Processor> = vec![&*checked_io];
        processors.extend(dialect_processors.iter().map(|it| &**it));

        let processing = scope.process(processors);

        for var in processing.variables {
            instructions.push(Instruction::DeclareVariable {
                var: self.compile_variable(var),
            });
        }

        processing
            .instructions
            .into_iter()
            .for_each(|op| self.compile_instruction(&mut instructions, op));

        instructions
    }

    fn compile_instruction(
        &mut self,
        instructions: &mut Vec<Instruction<D>>,
        instruction: gpu::Instruction,
    ) {
        self.update_debug_loc(instructions, &instruction);
        let out = instruction.out;
        match instruction.operation {
            gpu::Operation::Copy(variable) => {
                instructions.push(Instruction::Assign(UnaryInstruction {
                    input: self.compile_variable(variable),
                    out: self.compile_variable(out.unwrap()),
                }));
            }
            gpu::Operation::Arithmetic(op) => {
                self.compile_arithmetic(op, out, instruction.modes, instructions)
            }
            gpu::Operation::Comparison(op) => self.compile_comparison(op, out, instructions),
            gpu::Operation::Bitwise(op) => self.compile_bitwise(op, out, instructions),
            gpu::Operation::Operator(op) => self.compile_operator(op, out, instructions),
            gpu::Operation::Atomic(op) => self.compile_atomic(op, out, instructions),
            gpu::Operation::Metadata(op) => instructions.push(self.compile_metadata(op, out)),
            gpu::Operation::Branch(val) => self.compile_branch(instructions, val),
            gpu::Operation::Synchronization(val) => match val {
                gpu::Synchronization::SyncCube => instructions.push(Instruction::SyncThreads),
                gpu::Synchronization::SyncPlane => instructions.push(Instruction::SyncWarp),
                gpu::Synchronization::SyncStorage => instructions.push(Instruction::SyncThreads),
                gpu::Synchronization::SyncAsyncProxyShared => {
                    self.flags.inst_tma = true;
                    instructions.push(Instruction::ProxyAsyncToSharedFence)
                }
            },
            gpu::Operation::Plane(op) => {
                self.flags.indexes.plane_dim_checked = true;
                let out = self.compile_variable(out.unwrap());
                match op {
                    gpu::Plane::Sum(op) => {
                        let instruction = WarpInstruction::ReduceSum {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction));
                    }
                    gpu::Plane::InclusiveSum(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::InclusiveSum {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::InclusiveProd(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::InclusiveProd {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::ExclusiveSum(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::ExclusiveSum {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::ExclusiveProd(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::ExclusiveProd {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Prod(op) => {
                        let instruction = WarpInstruction::ReduceProd {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Max(op) => {
                        let instruction = WarpInstruction::ReduceMax {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Min(op) => {
                        let instruction = WarpInstruction::ReduceMin {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Elect => {
                        if self.compilation_options.supports_features.elect_sync {
                            self.flags.inst_ptx_wrappers = true;
                            instructions.push(Instruction::Warp(WarpInstruction::Elect { out }))
                        } else {
                            instructions
                                .push(Instruction::Warp(WarpInstruction::ElectFallback { out }))
                        }
                    }
                    gpu::Plane::All(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::All {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Any(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Any {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Ballot(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Ballot {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Broadcast(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Broadcast {
                            input: self.compile_variable(op.lhs),
                            id: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::Shuffle(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Shuffle {
                            input: self.compile_variable(op.lhs),
                            src_lane: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::ShuffleXor(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleXor {
                            input: self.compile_variable(op.lhs),
                            mask: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::ShuffleUp(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleUp {
                            input: self.compile_variable(op.lhs),
                            delta: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::ShuffleDown(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleDown {
                            input: self.compile_variable(op.lhs),
                            delta: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                }
            }
            gpu::Operation::CoopMma(cmma) => instructions.push(self.compile_cmma(cmma, out)),
            gpu::Operation::NonSemantic(debug) => match debug {
                gpu::NonSemantic::Print {
                    format_string,
                    args,
                } => instructions.push(Instruction::Printf {
                    format_string,
                    args: args
                        .into_iter()
                        .map(|arg| self.compile_variable(arg))
                        .collect(),
                }),
                gpu::NonSemantic::Comment { content } => {
                    instructions.push(Instruction::Comment { content })
                }
                _ => {}
            },
            gpu::Operation::Barrier(barrier_ops) => match barrier_ops {
                gpu::BarrierOps::Declare { barrier } => {
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    let barrier = self.compile_variable(barrier);
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Declare {
                        barrier,
                        level,
                    }));
                }
                gpu::BarrierOps::Init {
                    barrier,
                    is_elected,
                    arrival_count,
                    with_async_proxy_fence,
                } => {
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    let barrier = self.compile_variable(barrier);
                    let arrival_count = self.compile_variable(arrival_count);
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Init {
                        barrier,
                        is_elected: self.compile_variable(is_elected),
                        arrival_count,
                        level,
                        with_async_proxy_fence,
                    }));
                }
                gpu::BarrierOps::InitManual {
                    barrier,
                    arrival_count,
                } => {
                    let barrier = self.compile_variable(barrier);
                    let arrival_count = self.compile_variable(arrival_count);
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::InitManual {
                            barrier,
                            arrival_count,
                        },
                    ));
                }
                gpu::BarrierOps::MemCopyAsync {
                    barrier,
                    source,
                    source_length,
                    offset_source,
                    offset_out,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsync {
                            barrier: self.compile_variable(barrier),
                            source: self.compile_variable(source),
                            destination: self.compile_variable(out.unwrap()),
                            source_length: self.compile_variable(source_length),
                            offset_source: self.compile_variable(offset_source),
                            offset_out: self.compile_variable(offset_out),
                            cooperative: false,
                        },
                    ));
                }
                gpu::BarrierOps::MemCopyAsyncCooperative {
                    barrier,
                    source,
                    source_length,
                    offset_source,
                    offset_out,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsync {
                            barrier: self.compile_variable(barrier),
                            source: self.compile_variable(source),
                            destination: self.compile_variable(out.unwrap()),
                            source_length: self.compile_variable(source_length),
                            offset_source: self.compile_variable(offset_source),
                            offset_out: self.compile_variable(offset_out),
                            cooperative: true,
                        },
                    ));
                }
                gpu::BarrierOps::MemCopyAsyncTx {
                    barrier,
                    source,
                    source_length,
                    offset_source,
                    offset_out,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsyncTx {
                            barrier: self.compile_variable(barrier),
                            source: self.compile_variable(source),
                            destination: self.compile_variable(out.unwrap()),
                            source_length: self.compile_variable(source_length),
                            offset_source: self.compile_variable(offset_source),
                            offset_out: self.compile_variable(offset_out),
                        },
                    ));
                }
                gpu::BarrierOps::TmaLoad {
                    barrier,
                    tensor_map,
                    offset_out,
                    indices,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsyncTensorGlobalToShared {
                            barrier: self.compile_variable(barrier),
                            smem_buffer: self.compile_variable(out.unwrap()),
                            smem_offset: self.compile_variable(offset_out),
                            tensor_map: self.compile_variable(tensor_map),
                            indices: indices
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        },
                    ));
                }
                gpu::BarrierOps::TmaLoadIm2col {
                    barrier,
                    tensor_map,
                    offset_out,
                    indices,
                    offsets,
                } => {
                    self.flags.inst_tma_im2col = true;
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::TmaLoadIm2col {
                            barrier: self.compile_variable(barrier),
                            smem_buffer: self.compile_variable(out.unwrap()),
                            smem_offset: self.compile_variable(offset_out),
                            tensor_map: self.compile_variable(tensor_map),
                            indices: indices
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                            offsets: offsets
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        },
                    ));
                }
                gpu::BarrierOps::Arrive { barrier } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Arrive {
                        barrier: self.compile_variable(barrier),
                        token: self.compile_variable(out.unwrap()),
                    }))
                }
                gpu::BarrierOps::ArriveTx {
                    barrier,
                    arrive_count_update,
                    transaction_count_update,
                } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::ArriveTx {
                        barrier: self.compile_variable(barrier),
                        token: self.compile_variable(out.unwrap()),
                        arrive_count_update: self.compile_variable(arrive_count_update),
                        transaction_count_update: self.compile_variable(transaction_count_update),
                    }))
                }
                gpu::BarrierOps::ExpectTx {
                    barrier,
                    transaction_count_update,
                } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::ExpectTx {
                        barrier: self.compile_variable(barrier),
                        transaction_count_update: self.compile_variable(transaction_count_update),
                    }))
                }
                gpu::BarrierOps::Wait { barrier, token } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Wait {
                        barrier: self.compile_variable(barrier),
                        token: self.compile_variable(token),
                    }))
                }
                gpu::BarrierOps::WaitParity { barrier, phase } => instructions.push(
                    Instruction::Barrier(super::barrier::BarrierOps::WaitParity {
                        barrier: self.compile_variable(barrier),
                        phase: self.compile_variable(phase),
                    }),
                ),
                gpu::BarrierOps::ArriveAndWait { barrier } => {
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::ArriveAndWait {
                            barrier: self.compile_variable(barrier),
                            level,
                        },
                    ))
                }
            },
            gpu::Operation::Tma(tma_ops) => {
                self.flags.inst_tma = true;
                match tma_ops {
                    gpu::TmaOps::TmaStore {
                        source,
                        coordinates,
                        offset_source,
                    } => {
                        instructions.push(Instruction::MemCopyAsyncTensorSharedToGlobal {
                            smem_buffer: self.compile_variable(source),
                            smem_offset: self.compile_variable(offset_source),
                            tensor_map: self.compile_variable(out.unwrap()),
                            indices: coordinates
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        });
                    }
                    gpu::TmaOps::CommitGroup => {
                        instructions.push(Instruction::BulkCommitGroup);
                    }
                    gpu::TmaOps::WaitGroup { max_pending } => {
                        instructions.push(Instruction::BulkWaitGroup { max_pending });
                    }
                    gpu::TmaOps::WaitGroupRead { max_pending } => {
                        instructions.push(Instruction::BulkWaitGroupRead { max_pending });
                    }
                }
            }
            gpu::Operation::Marker(_) => {}
        }
    }

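    /// Emits a `Line` marker whenever the source location changes between
    /// instructions, so the generated source keeps debug line information.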
    fn update_debug_loc(
        &mut self,
        instructions: &mut Vec<Instruction<D>>,
        inst: &gpu::Instruction,
    ) {
        if !matches!(inst.operation, Operation::NonSemantic(_)) {
            match &inst.source_loc {
                Some(loc) if Some(loc) != self.source_loc.as_ref() => {
                    self.source_loc = Some(loc.clone());
                    instructions.push(Instruction::Line {
                        file: loc.source.file.clone(),
                        line: loc.line,
                    });
                }
                _ => {}
            }
        }
    }

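    /// Lowers a cooperative matrix (CMMA) operation to the corresponding WMMA
    /// instruction and registers any dialect extension it needs.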
    fn compile_cmma(&mut self, cmma: gpu::CoopMma, out: Option<gpu::Variable>) -> Instruction<D> {
        self.flags.inst_wmma = true;

        let out = self.compile_variable(out.unwrap());

        let inst = match cmma {
            gpu::CoopMma::Fill { value } => WmmaInstruction::Fill {
                frag: out,
                value: self.compile_variable(value),
            },
            gpu::CoopMma::Load {
                value,
                stride,
                offset,
                layout,
            } => WmmaInstruction::Load {
                frag: out,
                offset: self.compile_variable(offset),
                value: self.compile_variable(value),
                stride: self.compile_variable(stride),
                layout: layout.and_then(|l| self.compile_matrix_layout(l)),
            },
            gpu::CoopMma::Execute {
                mat_a,
                mat_b,
                mat_c,
            } => WmmaInstruction::Execute {
                frag_a: self.compile_variable(mat_a),
                frag_b: self.compile_variable(mat_b),
                frag_c: self.compile_variable(mat_c),
                frag_d: out,
                warp_size: self.compilation_options.warp_size,
            },
            gpu::CoopMma::ExecuteManual {
                matrix,
                registers_a,
                registers_b,
                registers_c,
            } => WmmaInstruction::ExecuteManual {
                shape: MmaShape::new(matrix.m, matrix.n, matrix.k),
                frag_a: self.compile_variable(registers_a),
                frag_b: self.compile_variable(registers_b),
                frag_c: self.compile_variable(registers_c),
                frag_d: out,
            },
            gpu::CoopMma::ExecuteScaled {
                matrix,
                registers_a,
                registers_b,
                registers_c,
                scales_a,
                scales_b,
                scales_factor,
            } => WmmaInstruction::ExecuteScaled {
                shape: MmaShape::new(matrix.m, matrix.n, matrix.k),
                frag_a: self.compile_variable(registers_a),
                frag_b: self.compile_variable(registers_b),
                frag_c: self.compile_variable(registers_c),
                frag_d: out,

                scales_a: self.compile_variable(scales_a),
                scales_b: self.compile_variable(scales_b),
                scales_factor,
            },
            gpu::CoopMma::Store {
                mat,
                stride,
                offset,
                layout,
            } => {
                self.flags.indexes.unit_pos = true;
                self.flags.indexes.plane_index = true;
                WmmaInstruction::Store {
                    output: out,
                    offset: self.compile_variable(offset),
                    frag: self.compile_variable(mat),
                    stride: self.compile_variable(stride),
                    layout: self
                        .compile_matrix_layout(layout)
                        .expect("Layout required for store instruction"),
                }
            }
            gpu::CoopMma::LoadMatrix {
                buffer,
                offset,
                line_size,
                factor,
                transpose,
            } => WmmaInstruction::LdMatrix {
                output: out,
                buffer: self.compile_variable(buffer),
                offset: self.compile_variable(offset),
                line_size,
                factor,
                transpose,
            },
            gpu::CoopMma::StoreMatrix {
                offset,
                line_size,
                registers,
                factor,
                transpose,
            } => WmmaInstruction::StMatrix {
                registers: self.compile_variable(registers),
                buffer: out,
                offset: self.compile_variable(offset),
                line_size,
                factor,
                transpose,
            },
            gpu::CoopMma::Cast { input } => WmmaInstruction::Cast {
                input: self.compile_variable(input),
                output: out,
            },
            gpu::CoopMma::RowIndex { .. } | gpu::CoopMma::ColIndex { .. } => {
                panic!("Row/Col index should be handled by processors")
            }
        };

        D::register_wmma_instruction_extension(&mut self.extensions, &inst);

        Instruction::Wmma(inst)
    }

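    /// Lowers a metadata query (stride, shape, rank, length, buffer length)
    /// into an indexed read of the kernel's metadata table.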
    fn compile_metadata(
        &mut self,
        metadata: gpu::Metadata,
        out: Option<gpu::Variable>,
    ) -> Instruction<D> {
        let out = out.unwrap();
        match metadata {
            gpu::Metadata::Stride { dim, var } => {
                let position = self.ext_meta_position(var);
                let offset = self.metadata.stride_offset_index(position);
                Instruction::ExtendedMetadata {
                    info_offset: self.compile_variable(offset.into()),
                    dim: self.compile_variable(dim),
                    split_meta: self.compilation_options.supports_features.grid_constants,
                    static_offset: self.metadata.static_len(),
                    out: self.compile_variable(out),
                }
            }
            gpu::Metadata::Shape { dim, var } => {
                let position = self.ext_meta_position(var);
                let offset = self.metadata.shape_offset_index(position);
                Instruction::ExtendedMetadata {
                    info_offset: self.compile_variable(offset.into()),
                    dim: self.compile_variable(dim),
                    split_meta: self.compilation_options.supports_features.grid_constants,
                    static_offset: self.metadata.static_len(),
                    out: self.compile_variable(out),
                }
            }
            gpu::Metadata::Rank { var } => {
                let out = self.compile_variable(out);
                let pos = self.ext_meta_position(var);
                let offset = self.metadata.rank_index(pos);
                super::Instruction::Metadata {
                    info_offset: self.compile_variable(offset.into()),
                    split_meta: self.compilation_options.supports_features.grid_constants,
                    out,
                }
            }
            gpu::Metadata::Length { var } => {
                let input = self.compile_variable(var);
                let out = self.compile_variable(out);

                match input {
                    Variable::Slice { .. } => Instruction::SliceLength { input, out },
                    Variable::SharedMemory(_id, _item, length) => {
                        Instruction::ConstLength { length, out }
                    }
                    _ => {
                        let id = input.id().expect("Variable should have id");
                        let offset = self.metadata.len_index(id);
                        Instruction::Metadata {
                            info_offset: self.compile_variable(offset.into()),
                            split_meta: self.compilation_options.supports_features.grid_constants,
                            out,
                        }
                    }
                }
            }
            gpu::Metadata::BufferLength { var } => {
                let input = self.compile_variable(var);
                let out = self.compile_variable(out);

                match input {
                    Variable::Slice { .. } => Instruction::SliceLength { input, out },
                    _ => {
                        let id = input.id().expect("Variable should have id");
                        let offset = self.metadata.buffer_len_index(id);
                        Instruction::Metadata {
                            info_offset: self.compile_variable(offset.into()),
                            split_meta: self.compilation_options.supports_features.grid_constants,
                            out,
                        }
                    }
                }
            }
        }
    }

    fn compile_branch(&mut self, instructions: &mut Vec<Instruction<D>>, branch: gpu::Branch) {
        match branch {
            gpu::Branch::If(mut op) => instructions.push(Instruction::If {
                cond: self.compile_variable(op.cond),
                instructions: self.compile_scope(&mut op.scope),
            }),
            gpu::Branch::IfElse(mut op) => instructions.push(Instruction::IfElse {
                cond: self.compile_variable(op.cond),
                instructions_if: self.compile_scope(&mut op.scope_if),
                instructions_else: self.compile_scope(&mut op.scope_else),
            }),
            gpu::Branch::Switch(mut op) => instructions.push(Instruction::Switch {
                value: self.compile_variable(op.value),
                instructions_default: self.compile_scope(&mut op.scope_default),
                instructions_cases: op
                    .cases
                    .into_iter()
                    .map(|(val, mut block)| {
                        (self.compile_variable(val), self.compile_scope(&mut block))
                    })
                    .collect(),
            }),
            gpu::Branch::Return => instructions.push(Instruction::Return),
            gpu::Branch::Break => instructions.push(Instruction::Break),
            gpu::Branch::RangeLoop(mut range_loop) => instructions.push(Instruction::RangeLoop {
                i: self.compile_variable(range_loop.i),
                start: self.compile_variable(range_loop.start),
                end: self.compile_variable(range_loop.end),
                step: range_loop.step.map(|it| self.compile_variable(it)),
                inclusive: range_loop.inclusive,
                instructions: self.compile_scope(&mut range_loop.scope),
            }),
            gpu::Branch::Loop(mut op) => instructions.push(Instruction::Loop {
                instructions: self.compile_scope(&mut op.scope),
            }),
        };
    }

    fn compile_atomic(
        &mut self,
        value: gpu::AtomicOp,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::AtomicOp::Load(op) => {
                instructions.push(Instruction::AtomicLoad(self.compile_unary(op, out)))
            }
            gpu::AtomicOp::Store(op) => {
                instructions.push(Instruction::AtomicStore(self.compile_unary(op, out)))
            }
            gpu::AtomicOp::Swap(op) => {
                instructions.push(Instruction::AtomicSwap(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Add(op) => {
                instructions.push(Instruction::AtomicAdd(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Sub(op) => {
                instructions.push(Instruction::AtomicSub(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Max(op) => {
                instructions.push(Instruction::AtomicMax(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Min(op) => {
                instructions.push(Instruction::AtomicMin(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::And(op) => {
                instructions.push(Instruction::AtomicAnd(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Or(op) => {
                instructions.push(Instruction::AtomicOr(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Xor(op) => {
                instructions.push(Instruction::AtomicXor(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::CompareAndSwap(op) => instructions.push(Instruction::AtomicCAS {
                input: self.compile_variable(op.input),
                cmp: self.compile_variable(op.cmp),
                val: self.compile_variable(op.val),
                out: self.compile_variable(out),
            }),
        }
    }

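    /// Lowers an arithmetic operation. For operations with a fast variant
    /// (division, reciprocal, transcendentals, square roots, ...), the
    /// instruction's fast-math modes decide whether the precise or the fast
    /// form is emitted; see `select_fast_float`.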
    fn compile_arithmetic(
        &mut self,
        value: gpu::Arithmetic,
        out: Option<gpu::Variable>,
        modes: InstructionModes,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::Arithmetic::Add(op) => {
                instructions.push(Instruction::Add(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::SaturatingAdd(op) => {
                instructions.push(Instruction::SaturatingAdd(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Mul(op) => {
                instructions.push(Instruction::Mul(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Div(op) => {
                let op = self.compile_binary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::AllowReciprocal
                        | FastMath::ReducedPrecision
                        | FastMath::UnsignedZero
                        | FastMath::NotInf,
                    Instruction::Div(op),
                    Instruction::FastDiv(op),
                ))
            }
            gpu::Arithmetic::Sub(op) => {
                instructions.push(Instruction::Sub(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::SaturatingSub(op) => {
                instructions.push(Instruction::SaturatingSub(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::MulHi(op) => {
                let instruction = Instruction::HiMul(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Modulo(op) => {
                instructions.push(Instruction::Modulo(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Abs(op) => {
                instructions.push(Instruction::Abs(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Exp(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Exp(op),
                    Instruction::FastExp(op),
                ));
            }
            gpu::Arithmetic::Log(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Log(op),
                    Instruction::FastLog(op),
                ));
            }
            gpu::Arithmetic::Log1p(op) => {
                instructions.push(Instruction::Log1p(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Cos(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Cos(op),
                    Instruction::FastCos(op),
                ));
            }
            gpu::Arithmetic::Sin(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Sin(op),
                    Instruction::FastSin(op),
                ));
            }
            gpu::Arithmetic::Tan(op) => {
                instructions.push(Instruction::Tan(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Tanh(op) => {
                let op = self.compile_unary(op, out);
                let instruction = Instruction::Tanh(op);
                D::register_instruction_extension(&mut self.extensions, &instruction);
                if self.compilation_options.supports_features.fast_tanh {
                    instructions.push(self.select_fast_float(
                        out.ty,
                        modes,
                        FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                        instruction,
                        Instruction::FastTanh(op),
                    ))
                } else {
                    instructions.push(instruction);
                }
            }
            gpu::Arithmetic::Sinh(op) => {
                let instruction = Instruction::Sinh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Cosh(op) => {
                let instruction = Instruction::Cosh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcCos(op) => {
                let instruction = Instruction::ArcCos(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcSin(op) => {
                let instruction = Instruction::ArcSin(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcTan(op) => {
                let instruction = Instruction::ArcTan(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcSinh(op) => {
                let instruction = Instruction::ArcSinh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcCosh(op) => {
                let instruction = Instruction::ArcCosh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcTanh(op) => {
                let instruction = Instruction::ArcTanh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Degrees(op) => {
                let instruction = Instruction::Degrees(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Radians(op) => {
                let instruction = Instruction::Radians(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcTan2(op) => {
                let instruction = Instruction::ArcTan2(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Powf(op) => {
                let op = self.compile_binary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Powf(op),
                    Instruction::FastPowf(op),
                ))
            }
            gpu::Arithmetic::Powi(op) => {
                instructions.push(Instruction::Powi(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Sqrt(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Sqrt(op),
                    Instruction::FastSqrt(op),
                ))
            }
            gpu::Arithmetic::InverseSqrt(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::InverseSqrt(op),
                    Instruction::FastInverseSqrt(op),
                ))
            }
            gpu::Arithmetic::Erf(op) => {
                let instruction = Instruction::Erf(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Max(op) => {
                let instruction = Instruction::Max(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Min(op) => {
                let instruction = Instruction::Min(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Clamp(op) => instructions.push(Instruction::Clamp {
                input: self.compile_variable(op.input),
                min_value: self.compile_variable(op.min_value),
                max_value: self.compile_variable(op.max_value),
                out: self.compile_variable(out),
            }),
            gpu::Arithmetic::Recip(op) => {
                let elem = op.input.ty.elem_type();
                let input = self.compile_variable(op.input);
                let out = self.compile_variable(out);
                let lhs = match elem {
                    gpu::ElemType::Float(kind) => gpu::ConstantScalarValue::Float(1.0, kind),
                    gpu::ElemType::Int(kind) => gpu::ConstantScalarValue::Int(1, kind),
                    gpu::ElemType::UInt(kind) => gpu::ConstantScalarValue::UInt(1, kind),
                    gpu::ElemType::Bool => gpu::ConstantScalarValue::Bool(true),
                };
                let div = Instruction::Div(BinaryInstruction {
                    lhs: Variable::ConstantScalar(lhs, self.compile_elem(elem)),
                    rhs: input,
                    out,
                });
                let recip = Instruction::FastRecip(UnaryInstruction { input, out });

                instructions.push(self.select_fast_float(
                    elem.into(),
                    modes,
                    FastMath::AllowReciprocal
                        | FastMath::ReducedPrecision
                        | FastMath::UnsignedZero
                        | FastMath::NotInf,
                    div,
                    recip,
                ))
            }
            gpu::Arithmetic::Round(op) => {
                instructions.push(Instruction::Round(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Floor(op) => {
                instructions.push(Instruction::Floor(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Ceil(op) => {
                instructions.push(Instruction::Ceil(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Trunc(op) => {
                instructions.push(Instruction::Trunc(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Remainder(op) => {
                instructions.push(Instruction::Remainder(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Fma(op) => instructions.push(Instruction::Fma {
                a: self.compile_variable(op.a),
                b: self.compile_variable(op.b),
                c: self.compile_variable(op.c),
                out: self.compile_variable(out),
            }),
            gpu::Arithmetic::Neg(op) => {
                instructions.push(Instruction::Neg(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Normalize(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Normalize(op),
                    Instruction::FastNormalize(op),
                ))
            }
            gpu::Arithmetic::Magnitude(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Magnitude(op),
                    Instruction::FastMagnitude(op),
                ))
            }
            gpu::Arithmetic::Dot(op) => {
                instructions.push(Instruction::Dot(self.compile_binary(op, out)))
            }
        };
    }

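    /// Picks between the precise and the fast variant of a floating-point
    /// instruction. The fast variant is emitted only when the target supports
    /// fast math, the element type is `f32`, and the instruction's fp-math
    /// modes include all of `required_flags`.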
    fn select_fast_float(
        &self,
        ty: gpu::Type,
        modes: InstructionModes,
        required_flags: EnumSet<FastMath>,
        default: Instruction<D>,
        fast: Instruction<D>,
    ) -> Instruction<D> {
        if !self.compilation_options.supports_features.fast_math
            || !matches!(ty.elem_type(), ElemType::Float(FloatKind::F32))
        {
            return default;
        }

        if modes.fp_math_mode.is_superset(required_flags) {
            fast
        } else {
            default
        }
    }

    fn compile_comparison(
        &mut self,
        value: gpu::Comparison,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::Comparison::Equal(op) => {
                instructions.push(Instruction::Equal(self.compile_binary(op, out)))
            }
            gpu::Comparison::Lower(op) => {
                instructions.push(Instruction::Lower(self.compile_binary(op, out)))
            }
            gpu::Comparison::Greater(op) => {
                instructions.push(Instruction::Greater(self.compile_binary(op, out)))
            }
            gpu::Comparison::LowerEqual(op) => {
                instructions.push(Instruction::LowerEqual(self.compile_binary(op, out)))
            }
            gpu::Comparison::GreaterEqual(op) => {
                instructions.push(Instruction::GreaterEqual(self.compile_binary(op, out)))
            }
            gpu::Comparison::NotEqual(op) => {
                instructions.push(Instruction::NotEqual(self.compile_binary(op, out)))
            }
            gpu::Comparison::IsNan(op) => {
                instructions.push(Instruction::IsNan(self.compile_unary(op, out)))
            }
            gpu::Comparison::IsInf(op) => {
                instructions.push(Instruction::IsInf(self.compile_unary(op, out)))
            }
        };
    }

    fn compile_bitwise(
        &mut self,
        value: gpu::Bitwise,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::Bitwise::BitwiseOr(op) => {
                instructions.push(Instruction::BitwiseOr(self.compile_binary(op, out)))
            }
            gpu::Bitwise::BitwiseAnd(op) => {
                instructions.push(Instruction::BitwiseAnd(self.compile_binary(op, out)))
            }
            gpu::Bitwise::BitwiseXor(op) => {
                instructions.push(Instruction::BitwiseXor(self.compile_binary(op, out)))
            }
            gpu::Bitwise::CountOnes(op) => {
                instructions.push(Instruction::CountBits(self.compile_unary(op, out)))
            }
            gpu::Bitwise::ReverseBits(op) => {
                instructions.push(Instruction::ReverseBits(self.compile_unary(op, out)))
            }
            gpu::Bitwise::ShiftLeft(op) => {
                instructions.push(Instruction::ShiftLeft(self.compile_binary(op, out)))
            }
            gpu::Bitwise::ShiftRight(op) => {
                instructions.push(Instruction::ShiftRight(self.compile_binary(op, out)))
            }
            gpu::Bitwise::BitwiseNot(op) => {
                instructions.push(Instruction::BitwiseNot(self.compile_unary(op, out)))
            }
            gpu::Bitwise::LeadingZeros(op) => {
                instructions.push(Instruction::LeadingZeros(self.compile_unary(op, out)))
            }
            gpu::Bitwise::FindFirstSet(op) => {
                let instruction = Instruction::FindFirstSet(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
        };
    }

    fn compile_operator(
        &mut self,
        value: gpu::Operator,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::Operator::Index(op) | gpu::Operator::UncheckedIndex(op) => {
                instructions.push(Instruction::Index(self.compile_index(op, out)));
            }
            gpu::Operator::IndexAssign(op) | gpu::Operator::UncheckedIndexAssign(op) => {
                instructions.push(Instruction::IndexAssign(self.compile_index_assign(op, out)));
            }
            gpu::Operator::And(op) => {
                instructions.push(Instruction::And(self.compile_binary(op, out)))
            }
            gpu::Operator::Or(op) => {
                instructions.push(Instruction::Or(self.compile_binary(op, out)))
            }
            gpu::Operator::Not(op) => {
                instructions.push(Instruction::Not(self.compile_unary(op, out)))
            }
            gpu::Operator::InitLine(op) => instructions.push(Instruction::VecInit {
                inputs: op
                    .inputs
                    .into_iter()
                    .map(|it| self.compile_variable(it))
                    .collect(),
                out: self.compile_variable(out),
            }),
            gpu::Operator::CopyMemory(op) => instructions.push(Instruction::Copy {
                input: self.compile_variable(op.input),
                in_index: self.compile_variable(op.in_index),
                out: self.compile_variable(out),
                out_index: self.compile_variable(op.out_index),
            }),
            gpu::Operator::CopyMemoryBulk(op) => instructions.push(Instruction::CopyBulk {
                input: self.compile_variable(op.input),
                in_index: self.compile_variable(op.in_index),
                out: self.compile_variable(out),
                out_index: self.compile_variable(op.out_index),
                len: op.len,
            }),
            gpu::Operator::Select(op) => instructions.push(Instruction::Select {
                cond: self.compile_variable(op.cond),
                then: self.compile_variable(op.then),
                or_else: self.compile_variable(op.or_else),
                out: self.compile_variable(out),
            }),
            gpu::Operator::Cast(op)
                if is_fp4_fp6_fp8(op.input.elem_type()) || is_fp4_fp6_fp8(out.elem_type()) =>
            {
                self.flags.elem_f16 = true;
                self.flags.elem_bf16 = true;
                let vec_in = op.input.ty.line_size();
                let packing = out.storage_type().packing_factor();
                self.compile_type(op.input.ty.line(packing));
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::F16)).line(vec_in),
                );
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::BF16)).line(vec_in),
                );
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::F16)).line(packing),
                );
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::BF16)).line(packing),
                );

                let inst = self.compile_unary(op, out);

                instructions.push(Instruction::SpecialCast(inst));
            }
            gpu::Operator::Cast(op) => {
                let op = self.compile_unary(op, out);

                if op.input.elem() == Elem::TF32 || op.out.elem() == Elem::TF32 {
                    self.flags.elem_tf32 = true;
                }

                instructions.push(Instruction::Assign(op))
            }
            gpu::Operator::Reinterpret(op) => {
                instructions.push(Instruction::Bitcast(self.compile_unary(op, out)))
            }
        };
    }

    fn compile_binary(
        &mut self,
        value: gpu::BinaryOperator,
        out: gpu::Variable,
    ) -> BinaryInstruction<D> {
        BinaryInstruction {
            lhs: self.compile_variable(value.lhs),
            rhs: self.compile_variable(value.rhs),
            out: self.compile_variable(out),
        }
    }

    fn compile_index_assign(
        &mut self,
        value: gpu::IndexAssignOperator,
        out: gpu::Variable,
    ) -> IndexAssignInstruction<D> {
        IndexAssignInstruction {
            index: self.compile_variable(value.index),
            value: self.compile_variable(value.value),
            line_size: value.line_size,
            out: self.compile_variable(out),
        }
    }

    fn compile_index(
        &mut self,
        value: gpu::IndexOperator,
        out: gpu::Variable,
    ) -> IndexInstruction<D> {
        IndexInstruction {
            list: self.compile_variable(value.list),
            index: self.compile_variable(value.index),
            line_size: value.line_size,
            out: self.compile_variable(out),
        }
    }

    fn compile_unary(
        &mut self,
        value: gpu::UnaryOperator,
        out: gpu::Variable,
    ) -> UnaryInstruction<D> {
        UnaryInstruction {
            input: self.compile_variable(value.input),
            out: self.compile_variable(out),
        }
    }

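    /// Maps an IR variable to its C++ counterpart, recording any element
    /// types, builtins or instruction families it requires along the way.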
    fn compile_variable(&mut self, value: gpu::Variable) -> Variable<D> {
        let item = value.ty;
        match value.kind {
            gpu::VariableKind::GlobalInputArray(id) => {
                Variable::GlobalInputArray(id, self.compile_type(item))
            }
            gpu::VariableKind::GlobalScalar(id) => Variable::GlobalScalar {
                id,
                elem: self.compile_storage_type(item.storage_type()),
                in_struct: self.compilation_options.supports_features.grid_constants,
            },
            gpu::VariableKind::TensorMapInput(id) => {
                self.flags.inst_tma = true;
                Variable::TensorMap(id)
            }
            gpu::VariableKind::TensorMapOutput(id) => {
                self.flags.inst_tma = true;
                Variable::TensorMap(id)
            }
            gpu::VariableKind::LocalMut { id } => Variable::LocalMut {
                id,
                item: self.compile_type(item),
            },
            gpu::VariableKind::Versioned { id, .. } => Variable::LocalMut {
                id,
                item: self.compile_type(item),
            },
            gpu::VariableKind::LocalConst { id } => Variable::LocalConst {
                id,
                item: self.compile_type(item),
            },
            gpu::VariableKind::GlobalOutputArray(id) => {
                Variable::GlobalOutputArray(id, self.compile_type(item))
            }
            gpu::VariableKind::ConstantScalar(value) => {
                Variable::ConstantScalar(value, self.compile_elem(value.elem_type()))
            }
            gpu::VariableKind::SharedMemory { id, length, .. } => {
                let item = self.compile_type(item);
                Variable::SharedMemory(id, item, length)
            }
            gpu::VariableKind::ConstantArray {
                id,
                length,
                unroll_factor,
            } => {
                let item = self.compile_type(item);
                Variable::ConstantArray(id, item, length * unroll_factor)
            }
            gpu::VariableKind::Builtin(builtin) => match builtin {
                gpu::Builtin::AbsolutePos => {
                    self.flags.indexes.absolute_pos = true;
                    Variable::AbsolutePos
                }
                gpu::Builtin::CubePosCluster
                    if self.compilation_options.supports_features.clusters =>
                {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterRank
                }
                gpu::Builtin::CubePosClusterX
                    if self.compilation_options.supports_features.clusters =>
                {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexX
                }
                gpu::Builtin::CubePosClusterY
                    if self.compilation_options.supports_features.clusters =>
                {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexY
                }
                gpu::Builtin::CubePosClusterZ
                    if self.compilation_options.supports_features.clusters =>
                {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexZ
                }
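                // The guarded arms above did not match: clusters are not
                // supported on this target, so the cluster position is a
                // constant 0.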
                gpu::Builtin::CubePosCluster
                | gpu::Builtin::CubePosClusterX
                | gpu::Builtin::CubePosClusterY
                | gpu::Builtin::CubePosClusterZ => const_u32(0),
                gpu::Builtin::AbsolutePosX => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosX
                }
                gpu::Builtin::AbsolutePosY => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosY
                }
                gpu::Builtin::AbsolutePosZ => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosZ
                }
                gpu::Builtin::CubeDim => {
                    self.flags.indexes.cube_dim = true;
                    Variable::CubeDim
                }
                gpu::Builtin::CubeDimX => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimX
                }
                gpu::Builtin::CubeDimY => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimY
                }
                gpu::Builtin::CubeDimZ => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimZ
                }
                gpu::Builtin::CubeClusterDim => const_u32(self.cluster_dim.num_elems()),
                gpu::Builtin::CubeClusterDimX => const_u32(self.cluster_dim.x),
                gpu::Builtin::CubeClusterDimY => const_u32(self.cluster_dim.y),
                gpu::Builtin::CubeClusterDimZ => const_u32(self.cluster_dim.z),
                gpu::Builtin::CubePos => {
                    self.flags.indexes.cube_pos = true;
                    Variable::CubePos
                }
                gpu::Builtin::CubePosX => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosX
                }
                gpu::Builtin::CubePosY => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosY
                }
                gpu::Builtin::CubePosZ => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosZ
                }
                gpu::Builtin::CubeCount => {
                    self.flags.indexes.cube_count = true;
                    Variable::CubeCount
                }
                gpu::Builtin::CubeCountX => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountX
                }
                gpu::Builtin::CubeCountY => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountY
                }
                gpu::Builtin::CubeCountZ => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountZ
                }
                gpu::Builtin::UnitPos => {
                    self.flags.indexes.unit_pos = true;
                    Variable::UnitPos
                }
                gpu::Builtin::UnitPosX => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosX
                }
                gpu::Builtin::UnitPosY => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosY
                }
                gpu::Builtin::UnitPosZ => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosZ
                }
                gpu::Builtin::PlaneDim => {
                    self.flags.indexes.plane_dim = true;
                    Variable::PlaneDim
                }
                gpu::Builtin::UnitPosPlane => {
                    self.flags.indexes.unit_pos_plane = true;
                    Variable::UnitPosPlane
                }
            },
            gpu::VariableKind::LocalArray {
                id,
                length,
                unroll_factor,
            } => {
                let item = self.compile_type(item);
                if !self.local_arrays.iter().any(|s| s.index == id) {
                    self.local_arrays
                        .push(LocalArray::new(id, item, length * unroll_factor));
                }
                Variable::LocalArray(id, item, length)
            }
            gpu::VariableKind::Matrix { id, mat } => {
                self.flags.inst_wmma = true;
                Variable::WmmaFragment {
                    id,
                    frag: self.compile_matrix(mat),
                }
            }
            gpu::VariableKind::Pipeline { id, num_stages } => {
                self.flags.op_pipeline = true;
                let pipeline = Variable::Pipeline { id };
                if !self.pipelines.iter().any(|s| s.pipeline_id() == id) {
                    self.pipelines.push(PipelineOps::Init {
                        pipeline,
                        num_stages,
                    });
                }
                pipeline
            }
            gpu::VariableKind::Barrier { id, level } => {
                self.flags.op_barrier = true;
                Variable::Barrier { id, level }
            }
            gpu::VariableKind::BarrierToken { id, level } => {
                self.flags.op_barrier = true;
                Variable::BarrierToken { id, level }
            }
        }
    }

    fn compile_matrix(&mut self, matrix: gpu::Matrix) -> Fragment<D> {
        Fragment {
            ident: self.compile_matrix_ident(matrix.ident),
            m: matrix.m,
            n: matrix.n,
            k: matrix.k,
            elem: self.compile_storage_type(matrix.storage),
            layout: self.compile_matrix_layout(matrix.layout),
        }
    }

    fn compile_matrix_ident(&mut self, ident: gpu::MatrixIdent) -> FragmentIdent<D> {
        match ident {
            gpu::MatrixIdent::A => FragmentIdent::A,
            gpu::MatrixIdent::B => FragmentIdent::B,
            gpu::MatrixIdent::Accumulator => FragmentIdent::Accumulator,
        }
    }

    fn compile_matrix_layout(&mut self, layout: gpu::MatrixLayout) -> Option<FragmentLayout<D>> {
        match layout {
            gpu::MatrixLayout::ColMajor => Some(FragmentLayout::ColMajor),
            gpu::MatrixLayout::RowMajor => Some(FragmentLayout::RowMajor),
            gpu::MatrixLayout::Undefined => None,
        }
    }

    fn compile_binding(&mut self, binding: cubecl_runtime::kernel::Binding) -> Binding<D> {
        Binding {
            id: binding.id,
            item: self.compile_type(binding.ty),
            location: binding.location,
            size: binding.size,
            vis: binding.visibility,
        }
    }

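    /// Compiles a type into an item and registers it (plus its optimized
    /// variant) for declaration in the kernel. TF32 items are registered as
    /// `f32`, since TF32 shares the f32 in-register representation outside of
    /// matrix fragments.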
    fn compile_type(&mut self, ty: gpu::Type) -> Item<D> {
        let item = match ty {
            gpu::Type::Scalar(ty) => Item::new(self.compile_storage_type(ty), 1, false),
            gpu::Type::Line(ty, line_size) => {
                Item::new(self.compile_storage_type(ty), line_size as usize, false)
            }
            gpu::Type::Semantic(_) => Item::new(Elem::Bool, 1, true),
        };
        if item.elem != super::Elem::TF32 {
            self.items.insert(item);
            self.items.insert(item.optimized());
        } else {
            let mut item = item;
            item.elem = super::Elem::F32;
            self.items.insert(item);
        }

        item
    }

    fn compile_storage_type(&mut self, value: gpu::StorageType) -> Elem<D> {
        match value {
            gpu::StorageType::Scalar(ty) => self.compile_elem(ty),
            gpu::StorageType::Atomic(ty) => Elem::Atomic(ty.into()),
            gpu::StorageType::Packed(gpu::ElemType::Float(kind), 2) => match kind {
                FloatKind::E2M1 => {
                    self.flags.elem_fp4 = true;
                    Elem::FP4x2(FP4Kind::E2M1)
                }
                FloatKind::E2M3 => {
                    self.flags.elem_fp6 = true;
                    Elem::FP6x2(FP6Kind::E2M3)
                }
                FloatKind::E3M2 => {
                    self.flags.elem_fp6 = true;
                    Elem::FP6x2(FP6Kind::E3M2)
                }
                FloatKind::E4M3 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8x2(FP8Kind::E4M3)
                }
                FloatKind::E5M2 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8x2(FP8Kind::E5M2)
                }
                FloatKind::UE8M0 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8x2(FP8Kind::UE8M0)
                }
                FloatKind::F16 => {
                    self.flags.elem_f16 = true;
                    Elem::F16x2
                }
                FloatKind::BF16 => {
                    self.flags.elem_bf16 = true;
                    Elem::BF16x2
                }
                other => unimplemented!("Unsupported storage type: packed<{other:?}, 2>"),
            },
            other => unimplemented!("Unsupported storage type: {other}"),
        }
    }

    fn compile_elem(&mut self, value: gpu::ElemType) -> Elem<D> {
        match value {
            gpu::ElemType::Float(kind) => match kind {
                gpu::FloatKind::E2M1 => {
                    self.flags.elem_fp4 = true;
                    Elem::FP4(FP4Kind::E2M1)
                }
                gpu::FloatKind::E2M3 => {
                    self.flags.elem_fp6 = true;
                    Elem::FP6(FP6Kind::E2M3)
                }
                gpu::FloatKind::E3M2 => {
                    self.flags.elem_fp6 = true;
                    Elem::FP6(FP6Kind::E3M2)
                }
                gpu::FloatKind::E4M3 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8(FP8Kind::E4M3)
                }
                gpu::FloatKind::E5M2 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8(FP8Kind::E5M2)
                }
                gpu::FloatKind::UE8M0 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8(FP8Kind::UE8M0)
                }
                gpu::FloatKind::F16 => {
                    self.flags.elem_f16 = true;
                    Elem::F16
                }
                gpu::FloatKind::BF16 => {
                    self.flags.elem_bf16 = true;
                    Elem::BF16
                }
                gpu::FloatKind::TF32 => Elem::TF32,
                gpu::FloatKind::Flex32 => Elem::F32,
                gpu::FloatKind::F32 => Elem::F32,
                gpu::FloatKind::F64 => Elem::F64,
            },
            gpu::ElemType::Int(kind) => match kind {
                gpu::IntKind::I8 => Elem::I8,
                gpu::IntKind::I16 => Elem::I16,
                gpu::IntKind::I32 => Elem::I32,
                gpu::IntKind::I64 => Elem::I64,
            },
            gpu::ElemType::UInt(kind) => match kind {
                gpu::UIntKind::U8 => Elem::U8,
                gpu::UIntKind::U16 => Elem::U16,
                gpu::UIntKind::U32 => Elem::U32,
                gpu::UIntKind::U64 => Elem::U64,
            },
            gpu::ElemType::Bool => Elem::Bool,
        }
    }
}

fn is_fp4_fp6_fp8(elem: gpu::ElemType) -> bool {
    match elem {
        gpu::ElemType::Float(kind) => matches!(
            kind,
            FloatKind::E2M1
                | FloatKind::E2M3
                | FloatKind::E3M2
                | FloatKind::E4M3
                | FloatKind::E5M2
                | FloatKind::UE8M0
        ),
        _ => false,
    }
}

fn const_u32<D: Dialect>(value: u32) -> Variable<D> {
    Variable::ConstantScalar(
        gpu::ConstantScalarValue::UInt(value as u64, UIntKind::U32),
        Elem::U32,
    )
}

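/// Registers the scalar and atomic types that every C++-based runtime
/// supports as a baseline on the given device properties; runtimes may
/// register further types on top of this.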
pub fn register_supported_types(props: &mut DeviceProperties) {
    let supported_types = [
        gpu::ElemType::UInt(gpu::UIntKind::U8),
        gpu::ElemType::UInt(gpu::UIntKind::U16),
        gpu::ElemType::UInt(gpu::UIntKind::U32),
        gpu::ElemType::UInt(gpu::UIntKind::U64),
        gpu::ElemType::Int(gpu::IntKind::I8),
        gpu::ElemType::Int(gpu::IntKind::I16),
        gpu::ElemType::Int(gpu::IntKind::I32),
        gpu::ElemType::Int(gpu::IntKind::I64),
        gpu::ElemType::Float(gpu::FloatKind::BF16),
        gpu::ElemType::Float(gpu::FloatKind::F16),
        gpu::ElemType::Float(gpu::FloatKind::F32),
        gpu::ElemType::Float(gpu::FloatKind::Flex32),
        gpu::ElemType::Bool,
    ];

    let supported_atomic_types = [
        gpu::ElemType::Int(gpu::IntKind::I32),
        gpu::ElemType::Int(gpu::IntKind::I64),
        gpu::ElemType::UInt(gpu::UIntKind::U32),
        gpu::ElemType::UInt(gpu::UIntKind::U64),
        gpu::ElemType::Float(gpu::FloatKind::F32),
    ];

    for ty in supported_types {
        props.register_type_usage(ty, TypeUsage::all_scalar());
    }

    for ty in supported_atomic_types {
        props.register_type_usage(
            gpu::StorageType::Atomic(ty),
            TypeUsage::AtomicAdd | TypeUsage::AtomicLoadStore,
        );
    }
}