use std::{collections::HashSet, fmt::Debug};

use cubecl_common::ExecutionMode;
use cubecl_common::backtrace::BackTrace;
use cubecl_core::ir::{self as gpu, OpaqueType, StorageType};
use cubecl_core::ir::{FloatKind, InstructionModes, Processor, UIntKind};
use cubecl_core::post_processing::checked_io::CheckedIoProcessor;
use cubecl_core::{CubeDim, ir::ElemType};
use cubecl_core::{
    ir::{Operation, SourceLoc},
    prelude::{FastMath, KernelDefinition},
};
use cubecl_opt::{Optimizer, SharedLiveness};
use cubecl_runtime::compiler::CompilationError;
use cubecl_runtime::{DeviceProperties, EnumSet, TypeUsage, compiler::Compiler};

use crate::shared::MmaShape;

use super::{
    BinaryInstruction, Binding, Body, Component, ComputeKernel, ConstArray, Dialect, Elem, FP6Kind,
    Fragment, FragmentIdent, FragmentLayout, IndexAssignInstruction, IndexInstruction, Instruction,
    Item, LocalArray, SharedMemory, UnaryInstruction, Variable, WarpInstruction, WmmaInstruction,
};
use super::{FP4Kind, barrier::BarrierOps};
use super::{FP8Kind, pipeline::PipelineOps};

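/// Counter used to generate unique names for temporary variables; reset to zero
/// after each kernel compilation.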
pub(super) static COUNTER_TMP_VAR: std::sync::atomic::AtomicU32 =
    std::sync::atomic::AtomicU32::new(0);

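/// Options controlling how CubeCL IR is lowered to C++ source for a target.
///
/// A minimal sketch of overriding the warp size (e.g. for a 64-wide target),
/// leaving the supported feature set at its defaults:
///
/// ```ignore
/// let options = CompilationOptions {
///     warp_size: 64,
///     ..Default::default()
/// };
/// ```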
#[derive(Clone, Debug)]
pub struct CompilationOptions {
    pub warp_size: u32,
    pub supports_features: CppSupportedFeatures,
}

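/// Features supported by the target hardware and toolchain. Anything left
/// `false` makes the compiler fall back to a portable code path (e.g. plane
/// elect uses a software fallback, cluster positions compile to `0`).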
#[derive(Clone, Debug, Default)]
pub struct CppSupportedFeatures {
    pub grid_constants: bool,
    pub clusters: bool,
    pub fast_math: bool,
    pub fast_tanh: bool,
    pub elect_sync: bool,
}

impl Default for CompilationOptions {
    fn default() -> Self {
        Self {
            warp_size: 32,
            supports_features: Default::default(),
        }
    }
}

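/// Tracks which builtin index variables (unit, cube, plane and cluster
/// positions and dimensions) a kernel uses, so that only the required index
/// computations are emitted in the generated source.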
#[derive(Debug, Clone, Default)]
pub struct CubeIndexFlags {
    pub absolute_pos: bool,
    pub absolute_pos_tuple: bool,
    pub cube_count: bool,
    pub cube_count_tuple: bool,
    pub cube_dim: bool,
    pub cube_dim_tuple: bool,
    pub cube_pos: bool,
    pub cube_pos_tuple: bool,
    pub plane_dim: bool,
    pub plane_dim_checked: bool,
    pub plane_index: bool,
    pub unit_pos: bool,
    pub unit_pos_tuple: bool,
    pub unit_pos_plane: bool,
    pub cluster_pos: bool,
}

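/// Feature, element-type and metadata flags gathered while compiling a kernel;
/// they gate the includes and helper definitions emitted in the final source.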
#[derive(Debug, Clone, Default)]
pub struct Flags {
    pub elem_fp4: bool,
    pub elem_fp6: bool,
    pub elem_fp8: bool,
    pub elem_bf16: bool,
    pub elem_f16: bool,
    pub elem_tf32: bool,
    pub indexes: CubeIndexFlags,
    pub op_barrier: bool,
    pub op_pipeline: bool,
    pub inst_tma: bool,
    pub inst_tma_im2col: bool,
    pub inst_wmma: bool,
    pub inst_ptx_wrappers: bool,
    pub inst_async_copy: bool,
    pub use_grid_constants: bool,
    pub static_meta_length: usize,
    pub has_dynamic_meta: bool,
    pub cube_dim: CubeDim,
    pub cluster_dim: Option<CubeDim>,
}

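/// Compiler that lowers CubeCL IR into a C++ [`ComputeKernel`] for a given
/// [`Dialect`], accumulating feature flags and resources along the way.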
#[allow(clippy::too_many_arguments)]
#[derive(Clone, Debug, Default)]
pub struct CppCompiler<D: Dialect> {
    barriers: Vec<BarrierOps<D>>,
    compilation_options: CompilationOptions,
    const_arrays: Vec<ConstArray<D>>,
    ext_meta_positions: Vec<u32>,
    cluster_dim: CubeDim,
    extensions: Vec<D::Extension>,
    flags: Flags,
    items: HashSet<Item<D>>,
    local_arrays: Vec<LocalArray<D>>,
    metadata: cubecl_core::Metadata,
    pipelines: Vec<PipelineOps<D>>,
    source_loc: Option<SourceLoc>,
    strategy: ExecutionMode,
}

impl<D: Dialect> Compiler for CppCompiler<D> {
    type Representation = ComputeKernel<D>;
    type CompilationOptions = CompilationOptions;

    fn compile(
        &mut self,
        mut kernel: KernelDefinition,
        compilation_options: &Self::CompilationOptions,
        strategy: ExecutionMode,
    ) -> Result<Self::Representation, CompilationError> {
        let errors = kernel.body.pop_errors();
        if !errors.is_empty() {
            let mut reason = "Can't compile cpp kernel\nCaused by:\n ".to_string();
            for error in errors {
                reason += error.as_str();
                reason += "\n";
            }

            return Err(CompilationError::Validation {
                reason,
                backtrace: BackTrace::capture(),
            });
        }

        self.compilation_options = compilation_options.clone();
        self.strategy = strategy;

        if !self.compilation_options.supports_features.clusters {
            kernel.options.cluster_dim = None;
        }
        self.cluster_dim = kernel.options.cluster_dim.unwrap_or(CubeDim::new_single());

        let ir = self.clone().compile_ir(kernel);
        COUNTER_TMP_VAR.store(0, std::sync::atomic::Ordering::Relaxed);
        Ok(ir)
    }

    fn elem_size(&self, elem: gpu::ElemType) -> usize {
        elem.size()
    }

    fn extension(&self) -> &'static str {
        "cpp"
    }
}

impl<D: Dialect> CppCompiler<D> {
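    /// Lowers a validated kernel definition into a [`ComputeKernel`]: compiles
    /// the body, bindings and scalars, resolves shared-memory allocations, and
    /// assembles the feature flags collected during compilation.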
    fn compile_ir(mut self, value: KernelDefinition) -> ComputeKernel<D> {
        self.build_metadata(&value);

        let instructions = self.compile_scope(&mut value.body.clone());
        let tensor_maps = value
            .tensor_maps
            .into_iter()
            .map(|b| self.compile_binding(b))
            .collect();
        let buffers = value
            .buffers
            .into_iter()
            .map(|b| self.compile_binding(b))
            .collect();
        let scalars = value
            .scalars
            .into_iter()
            .map(|binding| (self.compile_storage_type(binding.ty), binding.count))
            .collect();

        let flags = Flags {
            indexes: D::builtin_rules(&self.flags.indexes),
            inst_wmma: self.flags.inst_wmma,
            op_pipeline: self.flags.op_pipeline,
            op_barrier: self.flags.op_barrier,
            elem_fp4: self.flags.elem_fp4,
            elem_fp6: self.flags.elem_fp6,
            elem_fp8: self.flags.elem_fp8,
            elem_bf16: self.flags.elem_bf16,
            elem_f16: self.flags.elem_f16,
            elem_tf32: self.flags.elem_tf32,
            inst_tma: self.flags.inst_tma,
            inst_tma_im2col: self.flags.inst_tma_im2col,
            inst_async_copy: self.flags.inst_async_copy,
            inst_ptx_wrappers: self.flags.inst_ptx_wrappers,
            use_grid_constants: self.compilation_options.supports_features.grid_constants,
            has_dynamic_meta: self.metadata.static_len() > 0,
            static_meta_length: self.metadata.static_len() as usize,
            cube_dim: value.cube_dim,
            cluster_dim: value.options.cluster_dim,
        };

        let mut opt = Optimizer::shared_only(value.body, value.cube_dim);
        let shared_allocs = opt.analysis::<SharedLiveness>();
        let shared_memories = shared_allocs
            .allocations
            .values()
            .map(|alloc| match alloc.smem {
                cubecl_opt::SharedMemory::Array {
                    id,
                    length,
                    ty,
                    align,
                } => SharedMemory::Array {
                    index: id,
                    item: self.compile_type(ty),
                    length,
                    align,
                    offset: alloc.offset,
                },
                cubecl_opt::SharedMemory::Value { id, ty, align } => SharedMemory::Value {
                    index: id,
                    item: self.compile_type(ty),
                    align,
                    offset: alloc.offset,
                },
            })
            .collect();

        let body = Body {
            instructions,
            shared_memories,
            pipelines: self.pipelines,
            barriers: self.barriers,
            const_arrays: self.const_arrays,
            local_arrays: self.local_arrays,
        };

        let mut cluster_dim = value.options.cluster_dim;
        if !self.compilation_options.supports_features.clusters {
            cluster_dim = None;
        }

        ComputeKernel {
            tensor_maps,
            buffers,
            scalars,
            meta_static_len: self.metadata.static_len() as usize,
            cube_dim: value.cube_dim,
            body,
            extensions: self.extensions,
            flags,
            items: self.items,
            kernel_name: value.options.kernel_name,
            cluster_dim,
        }
    }

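    /// Assigns every buffer (and tensor map) its position in the extended
    /// metadata table and sizes the kernel metadata accordingly.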
    fn build_metadata(&mut self, value: &KernelDefinition) {
        let mut num_ext = 0;

        let mut all_meta: Vec<_> = value
            .buffers
            .iter()
            .chain(value.tensor_maps.iter())
            .map(|buf| (buf.id, buf.has_extended_meta))
            .collect();

        all_meta.sort_by_key(|(id, _)| *id);

        for (_, has_extended_meta) in &all_meta {
            self.ext_meta_positions.push(num_ext);
            if *has_extended_meta {
                num_ext += 1;
            }
        }

        let num_meta = all_meta.len();

        self.metadata = cubecl_core::Metadata::new(num_meta as u32, num_ext);
    }

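    /// Returns the extended-metadata position computed for `var` by
    /// [`Self::build_metadata`].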
    pub(crate) fn ext_meta_position(&self, var: gpu::Variable) -> u32 {
        let id = var.index().expect("Variable should have index");
        self.ext_meta_positions[id as usize]
    }

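    /// Compiles a scope into instructions: hoists its constant arrays, runs the
    /// checked-IO and dialect processors, declares the variables they produce,
    /// then lowers each remaining instruction.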
    fn compile_scope(&mut self, scope: &mut gpu::Scope) -> Vec<Instruction<D>> {
        let mut instructions = Vec::new();

        let const_arrays = scope
            .const_arrays
            .drain(..)
            .map(|(var, values)| ConstArray {
                index: var.index().unwrap(),
                item: self.compile_type(var.ty),
                size: values.len() as u32,
                values: values
                    .into_iter()
                    .map(|val| self.compile_variable(val))
                    .collect(),
            })
            .collect::<Vec<_>>();
        self.const_arrays.extend(const_arrays);

        let checked_io: Box<dyn Processor> = Box::new(CheckedIoProcessor::new(self.strategy));
        let dialect_processors = D::processors();
        let mut processors: Vec<&dyn Processor> = vec![&*checked_io];
        processors.extend(dialect_processors.iter().map(|it| &**it));

        let processing = scope.process(processors);

        for var in processing.variables {
            instructions.push(Instruction::DeclareVariable {
                var: self.compile_variable(var),
            });
        }

        processing
            .instructions
            .into_iter()
            .for_each(|op| self.compile_instruction(&mut instructions, op));

        instructions
    }

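    /// Lowers a single IR instruction, dispatching on the operation kind and
    /// recording any feature flags (TMA, async copy, PTX wrappers, ...) that
    /// the generated code will rely on.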
    fn compile_instruction(
        &mut self,
        instructions: &mut Vec<Instruction<D>>,
        instruction: gpu::Instruction,
    ) {
        self.update_debug_loc(instructions, &instruction);
        let out = instruction.out;
        match instruction.operation {
            gpu::Operation::Copy(variable) => {
                instructions.push(Instruction::Assign(UnaryInstruction {
                    input: self.compile_variable(variable),
                    out: self.compile_variable(out.unwrap()),
                }));
            }
            gpu::Operation::Arithmetic(op) => {
                self.compile_arithmetic(op, out, instruction.modes, instructions)
            }
            gpu::Operation::Comparison(op) => self.compile_comparison(op, out, instructions),
            gpu::Operation::Bitwise(op) => self.compile_bitwise(op, out, instructions),
            gpu::Operation::Operator(op) => self.compile_operator(op, out, instructions),
            gpu::Operation::Atomic(op) => self.compile_atomic(op, out, instructions),
            gpu::Operation::Metadata(op) => instructions.push(self.compile_metadata(op, out)),
            gpu::Operation::Branch(val) => self.compile_branch(instructions, val),
            gpu::Operation::Synchronization(val) => match val {
                gpu::Synchronization::SyncCube => instructions.push(Instruction::SyncThreads),
                gpu::Synchronization::SyncPlane => instructions.push(Instruction::SyncWarp),
                gpu::Synchronization::SyncStorage => instructions.push(Instruction::SyncThreads),
                gpu::Synchronization::SyncAsyncProxyShared => {
                    self.flags.inst_tma = true;
                    instructions.push(Instruction::ProxyAsyncToSharedFence)
                }
            },
            gpu::Operation::Plane(op) => {
                self.flags.indexes.plane_dim_checked = true;
                let out = self.compile_variable(out.unwrap());
                match op {
                    gpu::Plane::Sum(op) => {
                        let instruction = WarpInstruction::ReduceSum {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction));
                    }
                    gpu::Plane::InclusiveSum(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::InclusiveSum {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::InclusiveProd(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::InclusiveProd {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::ExclusiveSum(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::ExclusiveSum {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::ExclusiveProd(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::ExclusiveProd {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Prod(op) => {
                        let instruction = WarpInstruction::ReduceProd {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Max(op) => {
                        let instruction = WarpInstruction::ReduceMax {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Min(op) => {
                        let instruction = WarpInstruction::ReduceMin {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Elect => {
                        if self.compilation_options.supports_features.elect_sync {
                            self.flags.inst_ptx_wrappers = true;
                            instructions.push(Instruction::Warp(WarpInstruction::Elect { out }))
                        } else {
                            instructions
                                .push(Instruction::Warp(WarpInstruction::ElectFallback { out }))
                        }
                    }
                    gpu::Plane::All(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::All {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Any(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Any {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Ballot(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Ballot {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Broadcast(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Broadcast {
                            input: self.compile_variable(op.lhs),
                            id: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::Shuffle(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Shuffle {
                            input: self.compile_variable(op.lhs),
                            src_lane: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::ShuffleXor(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleXor {
                            input: self.compile_variable(op.lhs),
                            mask: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::ShuffleUp(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleUp {
                            input: self.compile_variable(op.lhs),
                            delta: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::ShuffleDown(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleDown {
                            input: self.compile_variable(op.lhs),
                            delta: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                }
            }
            gpu::Operation::CoopMma(cmma) => instructions.push(self.compile_cmma(cmma, out)),
            gpu::Operation::NonSemantic(debug) => match debug {
                gpu::NonSemantic::Print {
                    format_string,
                    args,
                } => instructions.push(Instruction::Printf {
                    format_string,
                    args: args
                        .into_iter()
                        .map(|arg| self.compile_variable(arg))
                        .collect(),
                }),
                gpu::NonSemantic::Comment { content } => {
                    instructions.push(Instruction::Comment { content })
                }
                _ => {}
            },
            gpu::Operation::Barrier(barrier_ops) => match barrier_ops {
                gpu::BarrierOps::Declare { barrier } => {
                    let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type()
                    else {
                        unreachable!()
                    };
                    let barrier = self.compile_variable(barrier);
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Declare {
                        barrier,
                        level,
                    }));
                }
                gpu::BarrierOps::Init {
                    barrier,
                    is_elected,
                    arrival_count,
                } => {
                    let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type()
                    else {
                        unreachable!()
                    };
                    let barrier = self.compile_variable(barrier);
                    let arrival_count = self.compile_variable(arrival_count);
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Init {
                        barrier,
                        is_elected: self.compile_variable(is_elected),
                        arrival_count,
                        level,
                    }));
                }
                gpu::BarrierOps::InitManual {
                    barrier,
                    arrival_count,
                } => {
                    let barrier = self.compile_variable(barrier);
                    let arrival_count = self.compile_variable(arrival_count);
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::InitManual {
                            barrier,
                            arrival_count,
                        },
                    ));
                }
                gpu::BarrierOps::MemCopyAsync {
                    barrier,
                    source,
                    source_length,
                    offset_source,
                    offset_out,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsync {
                            barrier: self.compile_variable(barrier),
                            source: self.compile_variable(source),
                            destination: self.compile_variable(out.unwrap()),
                            source_length: self.compile_variable(source_length),
                            offset_source: self.compile_variable(offset_source),
                            offset_out: self.compile_variable(offset_out),
                            cooperative: false,
                        },
                    ));
                }
                gpu::BarrierOps::MemCopyAsyncCooperative {
                    barrier,
                    source,
                    source_length,
                    offset_source,
                    offset_out,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsync {
                            barrier: self.compile_variable(barrier),
                            source: self.compile_variable(source),
                            destination: self.compile_variable(out.unwrap()),
                            source_length: self.compile_variable(source_length),
                            offset_source: self.compile_variable(offset_source),
                            offset_out: self.compile_variable(offset_out),
                            cooperative: true,
                        },
                    ));
                }
                gpu::BarrierOps::MemCopyAsyncTx {
                    barrier,
                    source,
                    source_length,
                    offset_source,
                    offset_out,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsyncTx {
                            barrier: self.compile_variable(barrier),
                            source: self.compile_variable(source),
                            destination: self.compile_variable(out.unwrap()),
                            source_length: self.compile_variable(source_length),
                            offset_source: self.compile_variable(offset_source),
                            offset_out: self.compile_variable(offset_out),
                        },
                    ));
                }
                gpu::BarrierOps::CopyAsync {
                    source,
                    source_length,
                    offset_source,
                    offset_out,
                    copy_length,
                    checked,
                } => {
                    self.flags.inst_async_copy = true;
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::CopyAsync {
                            source: self.compile_variable(source),
                            destination: self.compile_variable(out.unwrap()),
                            source_length: self.compile_variable(source_length),
                            offset_source: self.compile_variable(offset_source),
                            offset_out: self.compile_variable(offset_out),
                            copy_size: copy_length,
                            checked,
                        },
                    ));
                }
                gpu::BarrierOps::TmaLoad {
                    barrier,
                    tensor_map,
                    offset_out,
                    indices,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsyncTensorGlobalToShared {
                            barrier: self.compile_variable(barrier),
                            smem_buffer: self.compile_variable(out.unwrap()),
                            smem_offset: self.compile_variable(offset_out),
                            tensor_map: self.compile_variable(tensor_map),
                            indices: indices
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        },
                    ));
                }
                gpu::BarrierOps::TmaLoadIm2col {
                    barrier,
                    tensor_map,
                    offset_out,
                    indices,
                    offsets,
                } => {
                    self.flags.inst_tma_im2col = true;
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::TmaLoadIm2col {
                            barrier: self.compile_variable(barrier),
                            smem_buffer: self.compile_variable(out.unwrap()),
                            smem_offset: self.compile_variable(offset_out),
                            tensor_map: self.compile_variable(tensor_map),
                            indices: indices
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                            offsets: offsets
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        },
                    ));
                }
                gpu::BarrierOps::Arrive { barrier } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Arrive {
                        barrier: self.compile_variable(barrier),
                        token: self.compile_variable(out.unwrap()),
                    }))
                }
                gpu::BarrierOps::ArriveTx {
                    barrier,
                    arrive_count_update,
                    transaction_count_update,
                } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::ArriveTx {
                        barrier: self.compile_variable(barrier),
                        token: self.compile_variable(out.unwrap()),
                        arrive_count_update: self.compile_variable(arrive_count_update),
                        transaction_count_update: self.compile_variable(transaction_count_update),
                    }))
                }
                gpu::BarrierOps::CommitCopyAsync { barrier } => {
                    self.flags.inst_async_copy = true;
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::ArriveCopyAsync {
                            barrier: self.compile_variable(barrier),
                        },
                    ))
                }
                gpu::BarrierOps::ExpectTx {
                    barrier,
                    transaction_count_update,
                } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::ExpectTx {
                        barrier: self.compile_variable(barrier),
                        transaction_count_update: self.compile_variable(transaction_count_update),
                    }))
                }
                gpu::BarrierOps::Wait { barrier, token } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Wait {
                        barrier: self.compile_variable(barrier),
                        token: self.compile_variable(token),
                    }))
                }
                gpu::BarrierOps::WaitParity { barrier, phase } => instructions.push(
                    Instruction::Barrier(super::barrier::BarrierOps::WaitParity {
                        barrier: self.compile_variable(barrier),
                        phase: self.compile_variable(phase),
                    }),
                ),
                gpu::BarrierOps::ArriveAndWait { barrier } => {
                    let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type()
                    else {
                        unreachable!()
                    };
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::ArriveAndWait {
                            barrier: self.compile_variable(barrier),
                            level,
                        },
                    ))
                }
            },
            gpu::Operation::Tma(tma_ops) => {
                self.flags.inst_tma = true;
                match tma_ops {
                    gpu::TmaOps::TmaStore {
                        source,
                        coordinates,
                        offset_source,
                    } => {
                        instructions.push(Instruction::MemCopyAsyncTensorSharedToGlobal {
                            smem_buffer: self.compile_variable(source),
                            smem_offset: self.compile_variable(offset_source),
                            tensor_map: self.compile_variable(out.unwrap()),
                            indices: coordinates
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        });
                    }
                    gpu::TmaOps::CommitGroup => {
                        instructions.push(Instruction::BulkCommitGroup);
                    }
                    gpu::TmaOps::WaitGroup { max_pending } => {
                        instructions.push(Instruction::BulkWaitGroup { max_pending });
                    }
                    gpu::TmaOps::WaitGroupRead { max_pending } => {
                        instructions.push(Instruction::BulkWaitGroupRead { max_pending });
                    }
                }
            }
            gpu::Operation::Marker(_) => {}
        }
    }

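    /// Emits a `Line` instruction whenever the source location changes, so the
    /// generated C++ can carry debug line information; non-semantic
    /// instructions are skipped.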
    fn update_debug_loc(
        &mut self,
        instructions: &mut Vec<Instruction<D>>,
        inst: &gpu::Instruction,
    ) {
        if !matches!(inst.operation, Operation::NonSemantic(_)) {
            match &inst.source_loc {
                Some(loc) if Some(loc) != self.source_loc.as_ref() => {
                    self.source_loc = Some(loc.clone());
                    instructions.push(Instruction::Line {
                        file: loc.source.file.clone(),
                        line: loc.line,
                    });
                }
                _ => {}
            }
        }
    }

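    /// Lowers cooperative matrix (CMMA) operations to WMMA instructions and
    /// registers any dialect extensions they need.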
    fn compile_cmma(&mut self, cmma: gpu::CoopMma, out: Option<gpu::Variable>) -> Instruction<D> {
        self.flags.inst_wmma = true;

        let out = self.compile_variable(out.unwrap());

        let inst = match cmma {
            gpu::CoopMma::Fill { value } => WmmaInstruction::Fill {
                frag: out,
                value: self.compile_variable(value),
            },
            gpu::CoopMma::Load {
                value,
                stride,
                offset,
                layout,
            } => WmmaInstruction::Load {
                frag: out,
                offset: self.compile_variable(offset),
                value: self.compile_variable(value),
                stride: self.compile_variable(stride),
                layout: layout.and_then(|l| self.compile_matrix_layout(l)),
            },
            gpu::CoopMma::Execute {
                mat_a,
                mat_b,
                mat_c,
            } => WmmaInstruction::Execute {
                frag_a: self.compile_variable(mat_a),
                frag_b: self.compile_variable(mat_b),
                frag_c: self.compile_variable(mat_c),
                frag_d: out,
                warp_size: self.compilation_options.warp_size,
            },
            gpu::CoopMma::ExecuteManual {
                matrix,
                registers_a,
                registers_b,
                registers_c,
            } => WmmaInstruction::ExecuteManual {
                shape: MmaShape::new(matrix.m, matrix.n, matrix.k),
                frag_a: self.compile_variable(registers_a),
                frag_b: self.compile_variable(registers_b),
                frag_c: self.compile_variable(registers_c),
                frag_d: out,
            },
            gpu::CoopMma::ExecuteScaled {
                matrix,
                registers_a,
                registers_b,
                registers_c,
                scales_a,
                scales_b,
                scales_factor,
            } => WmmaInstruction::ExecuteScaled {
                shape: MmaShape::new(matrix.m, matrix.n, matrix.k),
                frag_a: self.compile_variable(registers_a),
                frag_b: self.compile_variable(registers_b),
                frag_c: self.compile_variable(registers_c),
                frag_d: out,

                scales_a: self.compile_variable(scales_a),
                scales_b: self.compile_variable(scales_b),
                scales_factor,
            },
            gpu::CoopMma::Store {
                mat,
                stride,
                offset,
                layout,
            } => {
                self.flags.indexes.unit_pos = true;
                self.flags.indexes.plane_index = true;
                WmmaInstruction::Store {
                    output: out,
                    offset: self.compile_variable(offset),
                    frag: self.compile_variable(mat),
                    stride: self.compile_variable(stride),
                    layout: self
                        .compile_matrix_layout(layout)
                        .expect("Layout required for store instruction"),
                }
            }
            gpu::CoopMma::LoadMatrix {
                buffer,
                offset,
                line_size,
                factor,
                transpose,
            } => WmmaInstruction::LdMatrix {
                output: out,
                buffer: self.compile_variable(buffer),
                offset: self.compile_variable(offset),
                line_size,
                factor,
                transpose,
            },
            gpu::CoopMma::StoreMatrix {
                offset,
                line_size,
                registers,
                factor,
                transpose,
            } => WmmaInstruction::StMatrix {
                registers: self.compile_variable(registers),
                buffer: out,
                offset: self.compile_variable(offset),
                line_size,
                factor,
                transpose,
            },
            gpu::CoopMma::Cast { input } => WmmaInstruction::Cast {
                input: self.compile_variable(input),
                output: out,
            },
            gpu::CoopMma::RowIndex { .. } | gpu::CoopMma::ColIndex { .. } => {
                panic!("Row/Col index should be handled by processors")
            }
        };

        D::register_wmma_instruction_extension(&mut self.extensions, &inst);

        Instruction::Wmma(inst)
    }

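    /// Lowers metadata queries (stride, shape, rank, length, buffer length)
    /// into offset lookups in the kernel's metadata table.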
    fn compile_metadata(
        &mut self,
        metadata: gpu::Metadata,
        out: Option<gpu::Variable>,
    ) -> Instruction<D> {
        let out = out.unwrap();
        match metadata {
            gpu::Metadata::Stride { dim, var } => {
                let position = self.ext_meta_position(var);
                let offset = self.metadata.stride_offset_index(position);
                Instruction::ExtendedMetadata {
                    info_offset: self.compile_variable(offset.into()),
                    dim: self.compile_variable(dim),
                    split_meta: self.compilation_options.supports_features.grid_constants,
                    static_offset: self.metadata.static_len(),
                    out: self.compile_variable(out),
                }
            }
            gpu::Metadata::Shape { dim, var } => {
                let position = self.ext_meta_position(var);
                let offset = self.metadata.shape_offset_index(position);
                Instruction::ExtendedMetadata {
                    info_offset: self.compile_variable(offset.into()),
                    dim: self.compile_variable(dim),
                    split_meta: self.compilation_options.supports_features.grid_constants,
                    static_offset: self.metadata.static_len(),
                    out: self.compile_variable(out),
                }
            }
            gpu::Metadata::Rank { var } => {
                let out = self.compile_variable(out);
                let pos = self.ext_meta_position(var);
                let offset = self.metadata.rank_index(pos);
                super::Instruction::Metadata {
                    info_offset: self.compile_variable(offset.into()),
                    split_meta: self.compilation_options.supports_features.grid_constants,
                    out,
                }
            }
            gpu::Metadata::Length { var } => {
                let input = self.compile_variable(var);
                let out = self.compile_variable(out);

                match input {
                    Variable::Slice { .. } => Instruction::SliceLength { input, out },
                    Variable::SharedArray(_id, _item, length) => {
                        Instruction::ConstLength { length, out }
                    }
                    _ => {
                        let id = input.id().expect("Variable should have id");
                        let offset = self.metadata.len_index(id);
                        Instruction::Metadata {
                            info_offset: self.compile_variable(offset.into()),
                            split_meta: self.compilation_options.supports_features.grid_constants,
                            out,
                        }
                    }
                }
            }
            gpu::Metadata::BufferLength { var } => {
                let input = self.compile_variable(var);
                let out = self.compile_variable(out);

                match input {
                    Variable::Slice { .. } => Instruction::SliceLength { input, out },
                    _ => {
                        let id = input.id().expect("Variable should have id");
                        let offset = self.metadata.buffer_len_index(id);
                        Instruction::Metadata {
                            info_offset: self.compile_variable(offset.into()),
                            split_meta: self.compilation_options.supports_features.grid_constants,
                            out,
                        }
                    }
                }
            }
        }
    }

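    /// Lowers control flow, recursively compiling the nested scopes of each
    /// branch, loop and switch case.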
    fn compile_branch(&mut self, instructions: &mut Vec<Instruction<D>>, branch: gpu::Branch) {
        match branch {
            gpu::Branch::If(mut op) => instructions.push(Instruction::If {
                cond: self.compile_variable(op.cond),
                instructions: self.compile_scope(&mut op.scope),
            }),
            gpu::Branch::IfElse(mut op) => instructions.push(Instruction::IfElse {
                cond: self.compile_variable(op.cond),
                instructions_if: self.compile_scope(&mut op.scope_if),
                instructions_else: self.compile_scope(&mut op.scope_else),
            }),
            gpu::Branch::Switch(mut op) => instructions.push(Instruction::Switch {
                value: self.compile_variable(op.value),
                instructions_default: self.compile_scope(&mut op.scope_default),
                instructions_cases: op
                    .cases
                    .into_iter()
                    .map(|(val, mut block)| {
                        (self.compile_variable(val), self.compile_scope(&mut block))
                    })
                    .collect(),
            }),
            gpu::Branch::Return => instructions.push(Instruction::Return),
            gpu::Branch::Break => instructions.push(Instruction::Break),
            gpu::Branch::RangeLoop(mut range_loop) => instructions.push(Instruction::RangeLoop {
                i: self.compile_variable(range_loop.i),
                start: self.compile_variable(range_loop.start),
                end: self.compile_variable(range_loop.end),
                step: range_loop.step.map(|it| self.compile_variable(it)),
                inclusive: range_loop.inclusive,
                instructions: self.compile_scope(&mut range_loop.scope),
            }),
            gpu::Branch::Loop(mut op) => instructions.push(Instruction::Loop {
                instructions: self.compile_scope(&mut op.scope),
            }),
        };
    }

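    /// Lowers atomic operations to their dedicated atomic instructions.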
    fn compile_atomic(
        &mut self,
        value: gpu::AtomicOp,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::AtomicOp::Load(op) => {
                instructions.push(Instruction::AtomicLoad(self.compile_unary(op, out)))
            }
            gpu::AtomicOp::Store(op) => {
                instructions.push(Instruction::AtomicStore(self.compile_unary(op, out)))
            }
            gpu::AtomicOp::Swap(op) => {
                instructions.push(Instruction::AtomicSwap(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Add(op) => {
                instructions.push(Instruction::AtomicAdd(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Sub(op) => {
                instructions.push(Instruction::AtomicSub(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Max(op) => {
                instructions.push(Instruction::AtomicMax(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Min(op) => {
                instructions.push(Instruction::AtomicMin(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::And(op) => {
                instructions.push(Instruction::AtomicAnd(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Or(op) => {
                instructions.push(Instruction::AtomicOr(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Xor(op) => {
                instructions.push(Instruction::AtomicXor(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::CompareAndSwap(op) => instructions.push(Instruction::AtomicCAS {
                input: self.compile_variable(op.input),
                cmp: self.compile_variable(op.cmp),
                val: self.compile_variable(op.val),
                out: self.compile_variable(out),
            }),
        }
    }

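    /// Lowers arithmetic operations, choosing fast-math variants via
    /// [`Self::select_fast_float`] where the instruction modes allow it.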
    fn compile_arithmetic(
        &mut self,
        value: gpu::Arithmetic,
        out: Option<gpu::Variable>,
        modes: InstructionModes,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::Arithmetic::Add(op) => {
                instructions.push(Instruction::Add(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::SaturatingAdd(op) => {
                instructions.push(Instruction::SaturatingAdd(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Mul(op) => {
                instructions.push(Instruction::Mul(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Div(op) => {
                let op = self.compile_binary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::AllowReciprocal
                        | FastMath::ReducedPrecision
                        | FastMath::UnsignedZero
                        | FastMath::NotInf,
                    Instruction::Div(op),
                    Instruction::FastDiv(op),
                ))
            }
            gpu::Arithmetic::Sub(op) => {
                instructions.push(Instruction::Sub(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::SaturatingSub(op) => {
                instructions.push(Instruction::SaturatingSub(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::MulHi(op) => {
                let instruction = Instruction::HiMul(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Modulo(op) => {
                instructions.push(Instruction::Modulo(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Abs(op) => {
                instructions.push(Instruction::Abs(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Exp(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Exp(op),
                    Instruction::FastExp(op),
                ));
            }
            gpu::Arithmetic::Log(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Log(op),
                    Instruction::FastLog(op),
                ));
            }
            gpu::Arithmetic::Log1p(op) => {
                instructions.push(Instruction::Log1p(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Cos(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Cos(op),
                    Instruction::FastCos(op),
                ));
            }
            gpu::Arithmetic::Sin(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Sin(op),
                    Instruction::FastSin(op),
                ));
            }
            gpu::Arithmetic::Tan(op) => {
                instructions.push(Instruction::Tan(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Tanh(op) => {
                let op = self.compile_unary(op, out);
                let instruction = Instruction::Tanh(op);
                D::register_instruction_extension(&mut self.extensions, &instruction);
                if self.compilation_options.supports_features.fast_tanh {
                    instructions.push(self.select_fast_float(
                        out.ty,
                        modes,
                        FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                        instruction,
                        Instruction::FastTanh(op),
                    ))
                } else {
                    instructions.push(instruction);
                }
            }
            gpu::Arithmetic::Sinh(op) => {
                let instruction = Instruction::Sinh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Cosh(op) => {
                let instruction = Instruction::Cosh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcCos(op) => {
                let instruction = Instruction::ArcCos(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcSin(op) => {
                let instruction = Instruction::ArcSin(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcTan(op) => {
                let instruction = Instruction::ArcTan(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcSinh(op) => {
                let instruction = Instruction::ArcSinh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcCosh(op) => {
                let instruction = Instruction::ArcCosh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcTanh(op) => {
                let instruction = Instruction::ArcTanh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Degrees(op) => {
                let instruction = Instruction::Degrees(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Radians(op) => {
                let instruction = Instruction::Radians(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcTan2(op) => {
                let instruction = Instruction::ArcTan2(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Powf(op) => {
                let op = self.compile_binary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Powf(op),
                    Instruction::FastPowf(op),
                ))
            }
            gpu::Arithmetic::Powi(op) => {
                instructions.push(Instruction::Powi(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Sqrt(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Sqrt(op),
                    Instruction::FastSqrt(op),
                ))
            }
            gpu::Arithmetic::InverseSqrt(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::InverseSqrt(op),
                    Instruction::FastInverseSqrt(op),
                ))
            }
            gpu::Arithmetic::Erf(op) => {
                let instruction = Instruction::Erf(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Max(op) => {
                let instruction = Instruction::Max(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Min(op) => {
                let instruction = Instruction::Min(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Clamp(op) => instructions.push(Instruction::Clamp {
                input: self.compile_variable(op.input),
                min_value: self.compile_variable(op.min_value),
                max_value: self.compile_variable(op.max_value),
                out: self.compile_variable(out),
            }),
            gpu::Arithmetic::Recip(op) => {
                let elem = op.input.ty.elem_type();
                let input = self.compile_variable(op.input);
                let out = self.compile_variable(out);
                let lhs = match elem {
                    gpu::ElemType::Float(kind) => gpu::ConstantScalarValue::Float(1.0, kind),
                    gpu::ElemType::Int(kind) => gpu::ConstantScalarValue::Int(1, kind),
                    gpu::ElemType::UInt(kind) => gpu::ConstantScalarValue::UInt(1, kind),
                    gpu::ElemType::Bool => gpu::ConstantScalarValue::Bool(true),
                };
                let div = Instruction::Div(BinaryInstruction {
                    lhs: Variable::ConstantScalar(lhs, self.compile_elem(elem)),
                    rhs: input,
                    out,
                });
                let recip = Instruction::FastRecip(UnaryInstruction { input, out });

                instructions.push(self.select_fast_float(
                    elem.into(),
                    modes,
                    FastMath::AllowReciprocal
                        | FastMath::ReducedPrecision
                        | FastMath::UnsignedZero
                        | FastMath::NotInf,
                    div,
                    recip,
                ))
            }
            gpu::Arithmetic::Round(op) => {
                instructions.push(Instruction::Round(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Floor(op) => {
                instructions.push(Instruction::Floor(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Ceil(op) => {
                instructions.push(Instruction::Ceil(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Trunc(op) => {
                instructions.push(Instruction::Trunc(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Remainder(op) => {
                instructions.push(Instruction::Remainder(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Fma(op) => instructions.push(Instruction::Fma {
                a: self.compile_variable(op.a),
                b: self.compile_variable(op.b),
                c: self.compile_variable(op.c),
                out: self.compile_variable(out),
            }),
            gpu::Arithmetic::Neg(op) => {
                instructions.push(Instruction::Neg(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Normalize(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Normalize(op),
                    Instruction::FastNormalize(op),
                ))
            }
            gpu::Arithmetic::Magnitude(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Magnitude(op),
                    Instruction::FastMagnitude(op),
                ))
            }
            gpu::Arithmetic::Dot(op) => {
                instructions.push(Instruction::Dot(self.compile_binary(op, out)))
            }
        };
    }

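    /// Returns `fast` only when the target supports fast math, the type is
    /// `f32`, and the instruction's float modes include every required
    /// [`FastMath`] relaxation; otherwise returns `default`.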
    fn select_fast_float(
        &self,
        ty: gpu::Type,
        modes: InstructionModes,
        required_flags: EnumSet<FastMath>,
        default: Instruction<D>,
        fast: Instruction<D>,
    ) -> Instruction<D> {
        if !self.compilation_options.supports_features.fast_math
            || !matches!(ty.elem_type(), ElemType::Float(FloatKind::F32))
        {
            return default;
        }

        if modes.fp_math_mode.is_superset(required_flags) {
            fast
        } else {
            default
        }
    }

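    /// Lowers comparison operations.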
    fn compile_comparison(
        &mut self,
        value: gpu::Comparison,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::Comparison::Equal(op) => {
                instructions.push(Instruction::Equal(self.compile_binary(op, out)))
            }
            gpu::Comparison::Lower(op) => {
                instructions.push(Instruction::Lower(self.compile_binary(op, out)))
            }
            gpu::Comparison::Greater(op) => {
                instructions.push(Instruction::Greater(self.compile_binary(op, out)))
            }
            gpu::Comparison::LowerEqual(op) => {
                instructions.push(Instruction::LowerEqual(self.compile_binary(op, out)))
            }
            gpu::Comparison::GreaterEqual(op) => {
                instructions.push(Instruction::GreaterEqual(self.compile_binary(op, out)))
            }
            gpu::Comparison::NotEqual(op) => {
                instructions.push(Instruction::NotEqual(self.compile_binary(op, out)))
            }
            gpu::Comparison::IsNan(op) => {
                instructions.push(Instruction::IsNan(self.compile_unary(op, out)))
            }
            gpu::Comparison::IsInf(op) => {
                instructions.push(Instruction::IsInf(self.compile_unary(op, out)))
            }
        };
    }

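    /// Lowers bitwise operations, registering dialect extensions where needed
    /// (e.g. for `FindFirstSet`).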
    fn compile_bitwise(
        &mut self,
        value: gpu::Bitwise,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::Bitwise::BitwiseOr(op) => {
                instructions.push(Instruction::BitwiseOr(self.compile_binary(op, out)))
            }
            gpu::Bitwise::BitwiseAnd(op) => {
                instructions.push(Instruction::BitwiseAnd(self.compile_binary(op, out)))
            }
            gpu::Bitwise::BitwiseXor(op) => {
                instructions.push(Instruction::BitwiseXor(self.compile_binary(op, out)))
            }
            gpu::Bitwise::CountOnes(op) => {
                instructions.push(Instruction::CountBits(self.compile_unary(op, out)))
            }
            gpu::Bitwise::ReverseBits(op) => {
                instructions.push(Instruction::ReverseBits(self.compile_unary(op, out)))
            }
            gpu::Bitwise::ShiftLeft(op) => {
                instructions.push(Instruction::ShiftLeft(self.compile_binary(op, out)))
            }
            gpu::Bitwise::ShiftRight(op) => {
                instructions.push(Instruction::ShiftRight(self.compile_binary(op, out)))
            }
            gpu::Bitwise::BitwiseNot(op) => {
                instructions.push(Instruction::BitwiseNot(self.compile_unary(op, out)))
            }
            gpu::Bitwise::LeadingZeros(op) => {
                instructions.push(Instruction::LeadingZeros(self.compile_unary(op, out)))
            }
            gpu::Bitwise::FindFirstSet(op) => {
                let instruction = Instruction::FindFirstSet(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
        };
    }

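    /// Lowers the remaining operators: indexing, logic, memory copies, selects
    /// and casts, including the special packed-float (fp4/fp6/fp8) cast path.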
    fn compile_operator(
        &mut self,
        value: gpu::Operator,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::Operator::Index(op) | gpu::Operator::UncheckedIndex(op) => {
                instructions.push(Instruction::Index(self.compile_index(op, out)));
            }
            gpu::Operator::IndexAssign(op) | gpu::Operator::UncheckedIndexAssign(op) => {
                instructions.push(Instruction::IndexAssign(self.compile_index_assign(op, out)));
            }
            gpu::Operator::And(op) => {
                instructions.push(Instruction::And(self.compile_binary(op, out)))
            }
            gpu::Operator::Or(op) => {
                instructions.push(Instruction::Or(self.compile_binary(op, out)))
            }
            gpu::Operator::Not(op) => {
                instructions.push(Instruction::Not(self.compile_unary(op, out)))
            }
            gpu::Operator::InitLine(op) => instructions.push(Instruction::VecInit {
                inputs: op
                    .inputs
                    .into_iter()
                    .map(|it| self.compile_variable(it))
                    .collect(),
                out: self.compile_variable(out),
            }),
            gpu::Operator::CopyMemory(op) => instructions.push(Instruction::Copy {
                input: self.compile_variable(op.input),
                in_index: self.compile_variable(op.in_index),
                out: self.compile_variable(out),
                out_index: self.compile_variable(op.out_index),
            }),
            gpu::Operator::CopyMemoryBulk(op) => instructions.push(Instruction::CopyBulk {
                input: self.compile_variable(op.input),
                in_index: self.compile_variable(op.in_index),
                out: self.compile_variable(out),
                out_index: self.compile_variable(op.out_index),
                len: op.len,
            }),
            gpu::Operator::Select(op) => instructions.push(Instruction::Select {
                cond: self.compile_variable(op.cond),
                then: self.compile_variable(op.then),
                or_else: self.compile_variable(op.or_else),
                out: self.compile_variable(out),
            }),
            gpu::Operator::Cast(op)
                if is_fp4_fp6_fp8(op.input.elem_type()) || is_fp4_fp6_fp8(out.elem_type()) =>
            {
                self.flags.elem_f16 = true;
                self.flags.elem_bf16 = true;
                let vec_in = op.input.ty.line_size();
                let packing = out.storage_type().packing_factor();
                self.compile_type(op.input.ty.line(packing));
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::F16)).line(vec_in),
                );
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::BF16)).line(vec_in),
                );
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::F16)).line(packing),
                );
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::BF16)).line(packing),
                );

                let inst = self.compile_unary(op, out);

                instructions.push(Instruction::SpecialCast(inst));
            }
            gpu::Operator::Cast(op) => {
                let op = self.compile_unary(op, out);

                if op.input.elem() == Elem::TF32 || op.out.elem() == Elem::TF32 {
                    self.flags.elem_tf32 = true;
                }

                instructions.push(Instruction::Assign(op))
            }
            gpu::Operator::Reinterpret(op) => {
                instructions.push(Instruction::Bitcast(self.compile_unary(op, out)))
            }
        };
    }

    fn compile_binary(
        &mut self,
        value: gpu::BinaryOperator,
        out: gpu::Variable,
    ) -> BinaryInstruction<D> {
        BinaryInstruction {
            lhs: self.compile_variable(value.lhs),
            rhs: self.compile_variable(value.rhs),
            out: self.compile_variable(out),
        }
    }

    fn compile_index_assign(
        &mut self,
        value: gpu::IndexAssignOperator,
        out: gpu::Variable,
    ) -> IndexAssignInstruction<D> {
        IndexAssignInstruction {
            index: self.compile_variable(value.index),
            value: self.compile_variable(value.value),
            line_size: value.line_size,
            out: self.compile_variable(out),
        }
    }

    fn compile_index(
        &mut self,
        value: gpu::IndexOperator,
        out: gpu::Variable,
    ) -> IndexInstruction<D> {
        IndexInstruction {
            list: self.compile_variable(value.list),
            index: self.compile_variable(value.index),
            line_size: value.line_size,
            out: self.compile_variable(out),
        }
    }

    fn compile_unary(
        &mut self,
        value: gpu::UnaryOperator,
        out: gpu::Variable,
    ) -> UnaryInstruction<D> {
        UnaryInstruction {
            input: self.compile_variable(value.input),
            out: self.compile_variable(out),
        }
    }

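    /// Lowers an IR variable to a dialect variable, setting the builtin-index
    /// and feature flags implied by its kind and registering local arrays and
    /// pipelines on first use.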
1612 fn compile_variable(&mut self, value: gpu::Variable) -> Variable<D> {
1613 let item = value.ty;
1614 match value.kind {
1615 gpu::VariableKind::GlobalInputArray(id) => {
1616 Variable::GlobalInputArray(id, self.compile_type(item))
1617 }
1618 gpu::VariableKind::GlobalScalar(id) => Variable::GlobalScalar {
1619 id,
1620 elem: self.compile_storage_type(item.storage_type()),
1621 in_struct: self.compilation_options.supports_features.grid_constants,
1622 },
1623 gpu::VariableKind::TensorMapInput(id) => {
1624 self.flags.inst_tma = true;
1625 Variable::TensorMap(id)
1626 }
1627 gpu::VariableKind::TensorMapOutput(id) => {
1628 self.flags.inst_tma = true;
1629 Variable::TensorMap(id)
1630 }
1631 gpu::VariableKind::LocalMut { id } => Variable::LocalMut {
1632 id,
1633 item: self.compile_type(item),
1634 },
1635 gpu::VariableKind::Versioned { id, .. } => Variable::LocalMut {
1636 id,
1637 item: self.compile_type(item),
1638 },
1639 gpu::VariableKind::LocalConst { id } => Variable::LocalConst {
1640 id,
1641 item: self.compile_type(item),
1642 },
1643 gpu::VariableKind::GlobalOutputArray(id) => {
1644 Variable::GlobalOutputArray(id, self.compile_type(item))
1645 }
1646 gpu::VariableKind::ConstantScalar(value) => {
1647 Variable::ConstantScalar(value, self.compile_elem(value.elem_type()))
1648 }
1649 gpu::VariableKind::SharedArray { id, length, .. } => {
1650 let item = self.compile_type(item);
1651 Variable::SharedArray(id, item, length)
1652 }
1653 gpu::VariableKind::Shared { id } => {
1654 let item = self.compile_type(item);
1655 Variable::Shared(id, item)
1656 }
1657 gpu::VariableKind::ConstantArray {
1658 id,
1659 length,
1660 unroll_factor,
1661 } => {
1662 let item = self.compile_type(item);
1663 Variable::ConstantArray(id, item, length * unroll_factor)
1664 }
1665 gpu::VariableKind::Builtin(builtin) => match builtin {
1666 gpu::Builtin::AbsolutePos => {
1667 self.flags.indexes.absolute_pos = true;
1668 Variable::AbsolutePos
1669 }
1670 gpu::Builtin::CubePosCluster
1671 if self.compilation_options.supports_features.clusters =>
1672 {
1673 self.flags.indexes.cluster_pos = true;
1674 Variable::ClusterRank
1675 }
1676 gpu::Builtin::CubePosClusterX
1677 if self.compilation_options.supports_features.clusters =>
1678 {
1679 self.flags.indexes.cluster_pos = true;
1680 Variable::ClusterIndexX
1681 }
1682 gpu::Builtin::CubePosClusterY
1683 if self.compilation_options.supports_features.clusters =>
1684 {
1685 self.flags.indexes.cluster_pos = true;
1686 Variable::ClusterIndexY
1687 }
1688 gpu::Builtin::CubePosClusterZ
1689 if self.compilation_options.supports_features.clusters =>
1690 {
1691 self.flags.indexes.cluster_pos = true;
1692 Variable::ClusterIndexZ
1693 }
                gpu::Builtin::CubePosCluster
                | gpu::Builtin::CubePosClusterX
                | gpu::Builtin::CubePosClusterY
                | gpu::Builtin::CubePosClusterZ => const_u32(0),
                gpu::Builtin::AbsolutePosX => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosX
                }
                gpu::Builtin::AbsolutePosY => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosY
                }
                gpu::Builtin::AbsolutePosZ => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosZ
                }
                gpu::Builtin::CubeDim => {
                    self.flags.indexes.cube_dim = true;
                    Variable::CubeDim
                }
                gpu::Builtin::CubeDimX => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimX
                }
                gpu::Builtin::CubeDimY => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimY
                }
                gpu::Builtin::CubeDimZ => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimZ
                }
                gpu::Builtin::CubeClusterDim => const_u32(self.cluster_dim.num_elems()),
                gpu::Builtin::CubeClusterDimX => const_u32(self.cluster_dim.x),
                gpu::Builtin::CubeClusterDimY => const_u32(self.cluster_dim.y),
                gpu::Builtin::CubeClusterDimZ => const_u32(self.cluster_dim.z),
                gpu::Builtin::CubePos => {
                    self.flags.indexes.cube_pos = true;
                    Variable::CubePos
                }
                gpu::Builtin::CubePosX => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosX
                }
                gpu::Builtin::CubePosY => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosY
                }
                gpu::Builtin::CubePosZ => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosZ
                }
                gpu::Builtin::CubeCount => {
                    self.flags.indexes.cube_count = true;
                    Variable::CubeCount
                }
                gpu::Builtin::CubeCountX => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountX
                }
                gpu::Builtin::CubeCountY => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountY
                }
                gpu::Builtin::CubeCountZ => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountZ
                }
                gpu::Builtin::UnitPos => {
                    self.flags.indexes.unit_pos = true;
                    Variable::UnitPos
                }
                gpu::Builtin::UnitPosX => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosX
                }
                gpu::Builtin::UnitPosY => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosY
                }
                gpu::Builtin::UnitPosZ => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosZ
                }
                gpu::Builtin::PlaneDim => {
                    self.flags.indexes.plane_dim = true;
                    Variable::PlaneDim
                }
                gpu::Builtin::UnitPosPlane => {
                    self.flags.indexes.unit_pos_plane = true;
                    Variable::UnitPosPlane
                }
            },
            gpu::VariableKind::LocalArray {
                id,
                length,
                unroll_factor,
            } => {
                let item = self.compile_type(item);
                if !self.local_arrays.iter().any(|s| s.index == id) {
                    self.local_arrays
                        .push(LocalArray::new(id, item, length * unroll_factor));
                }
                Variable::LocalArray(id, item, length)
            }
            gpu::VariableKind::Matrix { id, mat } => {
                self.flags.inst_wmma = true;
                Variable::WmmaFragment {
                    id,
                    frag: self.compile_matrix(mat),
                }
            }
            gpu::VariableKind::Pipeline { id, num_stages } => {
                self.flags.op_pipeline = true;
                let pipeline = Variable::Pipeline { id };
                if !self.pipelines.iter().any(|s| s.pipeline_id() == id) {
                    self.pipelines.push(PipelineOps::Init {
                        pipeline,
                        num_stages,
                    });
                }
                pipeline
            }
            gpu::VariableKind::BarrierToken { id, level } => {
                self.flags.op_barrier = true;
                Variable::BarrierToken { id, level }
            }
        }
    }

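    /// Lowers a [`gpu::Matrix`] descriptor into a dialect [`Fragment`],
    /// compiling its identity, storage element and layout.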
    fn compile_matrix(&mut self, matrix: gpu::Matrix) -> Fragment<D> {
        Fragment {
            ident: self.compile_matrix_ident(matrix.ident),
            m: matrix.m,
            n: matrix.n,
            k: matrix.k,
            elem: self.compile_storage_type(matrix.storage),
            layout: self.compile_matrix_layout(matrix.layout),
        }
    }

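    /// Maps a matrix identity (`A`, `B` or accumulator) to the dialect's
    /// fragment identifier.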
    fn compile_matrix_ident(&mut self, ident: gpu::MatrixIdent) -> FragmentIdent<D> {
        match ident {
            gpu::MatrixIdent::A => FragmentIdent::A,
            gpu::MatrixIdent::B => FragmentIdent::B,
            gpu::MatrixIdent::Accumulator => FragmentIdent::Accumulator,
        }
    }

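    /// Maps a matrix layout to the dialect's fragment layout; an undefined
    /// layout becomes `None`.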
    fn compile_matrix_layout(&mut self, layout: gpu::MatrixLayout) -> Option<FragmentLayout<D>> {
        match layout {
            gpu::MatrixLayout::ColMajor => Some(FragmentLayout::ColMajor),
            gpu::MatrixLayout::RowMajor => Some(FragmentLayout::RowMajor),
            gpu::MatrixLayout::Undefined => None,
        }
    }

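    /// Translates a runtime buffer binding into a kernel [`Binding`],
    /// compiling its type into a dialect item.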
    fn compile_binding(&mut self, binding: cubecl_runtime::kernel::Binding) -> Binding<D> {
        Binding {
            id: binding.id,
            item: self.compile_type(binding.ty),
            location: binding.location,
            size: binding.size,
            vis: binding.visibility,
        }
    }

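    /// Compiles a frontend type into a dialect [`Item`]: scalars become
    /// single-element items, lines become vectorized items, and semantic
    /// types are lowered to a single `bool`. Each compiled item and its
    /// optimized variant are recorded in `self.items`; TF32 items are
    /// recorded as `f32`.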
    fn compile_type(&mut self, ty: gpu::Type) -> Item<D> {
        let item = match ty {
            gpu::Type::Scalar(ty) => Item::new(self.compile_storage_type(ty), 1, false),
            gpu::Type::Line(ty, line_size) => {
                Item::new(self.compile_storage_type(ty), line_size as usize, false)
            }
            gpu::Type::Semantic(_) => Item::new(Elem::Bool, 1, true),
        };
        if item.elem != super::Elem::TF32 {
            self.items.insert(item);
            self.items.insert(item.optimized());
        } else {
            let mut item = item;
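            // TF32 has no dedicated C++ type, so record its item as f32.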
            item.elem = super::Elem::F32;
            self.items.insert(item);
        }

        item
    }

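    /// Compiles a storage type into a dialect [`Elem`]. Scalars defer to
    /// [`Self::compile_elem`], atomics wrap their element type, and 2-wide
    /// packed float types map to the corresponding `x2` element while
    /// setting the matching feature flag; other packing factors are
    /// unsupported.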
    fn compile_storage_type(&mut self, value: gpu::StorageType) -> Elem<D> {
        match value {
            gpu::StorageType::Scalar(ty) => self.compile_elem(ty),
            gpu::StorageType::Atomic(ty) => Elem::Atomic(ty.into()),
            gpu::StorageType::Packed(gpu::ElemType::Float(kind), 2) => match kind {
                FloatKind::E2M1 => {
                    self.flags.elem_fp4 = true;
                    Elem::FP4x2(FP4Kind::E2M1)
                }
                FloatKind::E2M3 => {
                    self.flags.elem_fp6 = true;
                    Elem::FP6x2(FP6Kind::E2M3)
                }
                FloatKind::E3M2 => {
                    self.flags.elem_fp6 = true;
                    Elem::FP6x2(FP6Kind::E3M2)
                }
                FloatKind::E4M3 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8x2(FP8Kind::E4M3)
                }
                FloatKind::E5M2 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8x2(FP8Kind::E5M2)
                }
                FloatKind::UE8M0 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8x2(FP8Kind::UE8M0)
                }
                FloatKind::F16 => {
                    self.flags.elem_f16 = true;
                    Elem::F16x2
                }
                FloatKind::BF16 => {
                    self.flags.elem_bf16 = true;
                    Elem::BF16x2
                }
                other => unimplemented!("Unsupported storage type: packed<{other:?}, 2>"),
            },
            gpu::StorageType::Packed(other, factor) => {
                unimplemented!("Unsupported storage type: packed<{other}, {factor}>")
            }
            gpu::StorageType::Opaque(ty) => match ty {
                gpu::OpaqueType::Barrier(level) => {
                    self.flags.op_barrier = true;
                    Elem::Barrier(level)
                }
            },
        }
    }

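    /// Compiles an element type into a dialect [`Elem`], setting the flags
    /// that record which extended float types (fp4/fp6/fp8/f16/bf16) the
    /// kernel uses. `Flex32` is lowered to `f32`.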
    fn compile_elem(&mut self, value: gpu::ElemType) -> Elem<D> {
        match value {
            gpu::ElemType::Float(kind) => match kind {
                gpu::FloatKind::E2M1 => {
                    self.flags.elem_fp4 = true;
                    Elem::FP4(FP4Kind::E2M1)
                }
                gpu::FloatKind::E2M3 => {
                    self.flags.elem_fp6 = true;
                    Elem::FP6(FP6Kind::E2M3)
                }
                gpu::FloatKind::E3M2 => {
                    self.flags.elem_fp6 = true;
                    Elem::FP6(FP6Kind::E3M2)
                }
                gpu::FloatKind::E4M3 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8(FP8Kind::E4M3)
                }
                gpu::FloatKind::E5M2 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8(FP8Kind::E5M2)
                }
                gpu::FloatKind::UE8M0 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8(FP8Kind::UE8M0)
                }
                gpu::FloatKind::F16 => {
                    self.flags.elem_f16 = true;
                    Elem::F16
                }
                gpu::FloatKind::BF16 => {
                    self.flags.elem_bf16 = true;
                    Elem::BF16
                }
                gpu::FloatKind::TF32 => Elem::TF32,
                gpu::FloatKind::Flex32 => Elem::F32,
                gpu::FloatKind::F32 => Elem::F32,
                gpu::FloatKind::F64 => Elem::F64,
            },
            gpu::ElemType::Int(kind) => match kind {
                gpu::IntKind::I8 => Elem::I8,
                gpu::IntKind::I16 => Elem::I16,
                gpu::IntKind::I32 => Elem::I32,
                gpu::IntKind::I64 => Elem::I64,
            },
            gpu::ElemType::UInt(kind) => match kind {
                gpu::UIntKind::U8 => Elem::U8,
                gpu::UIntKind::U16 => Elem::U16,
                gpu::UIntKind::U32 => Elem::U32,
                gpu::UIntKind::U64 => Elem::U64,
            },
            gpu::ElemType::Bool => Elem::Bool,
        }
    }
}

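/// Returns `true` for the fp4/fp6/fp8 minifloat kinds (e2m1, e2m3, e3m2,
/// e4m3, e5m2, ue8m0).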
fn is_fp4_fp6_fp8(elem: gpu::ElemType) -> bool {
    matches!(
        elem,
        gpu::ElemType::Float(
            FloatKind::E2M1
                | FloatKind::E2M3
                | FloatKind::E3M2
                | FloatKind::E4M3
                | FloatKind::E5M2
                | FloatKind::UE8M0
        )
    )
}

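/// Wraps a `u32` value in a constant-scalar [`Variable`].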
fn const_u32<D: Dialect>(value: u32) -> Variable<D> {
    Variable::ConstantScalar(
        gpu::ConstantScalarValue::UInt(value as u64, UIntKind::U32),
        Elem::U32,
    )
}

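/// Registers the scalar and atomic types that every C++ target is expected
/// to support into the device properties. Atomic types are registered with
/// add and load/store usage.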
pub fn register_supported_types(props: &mut DeviceProperties) {
    let supported_types = [
        gpu::ElemType::UInt(gpu::UIntKind::U8),
        gpu::ElemType::UInt(gpu::UIntKind::U16),
        gpu::ElemType::UInt(gpu::UIntKind::U32),
        gpu::ElemType::UInt(gpu::UIntKind::U64),
        gpu::ElemType::Int(gpu::IntKind::I8),
        gpu::ElemType::Int(gpu::IntKind::I16),
        gpu::ElemType::Int(gpu::IntKind::I32),
        gpu::ElemType::Int(gpu::IntKind::I64),
        gpu::ElemType::Float(gpu::FloatKind::BF16),
        gpu::ElemType::Float(gpu::FloatKind::F16),
        gpu::ElemType::Float(gpu::FloatKind::F32),
        gpu::ElemType::Float(gpu::FloatKind::Flex32),
        gpu::ElemType::Bool,
    ];

    let supported_atomic_types = [
        gpu::ElemType::Int(gpu::IntKind::I32),
        gpu::ElemType::Int(gpu::IntKind::I64),
        gpu::ElemType::UInt(gpu::UIntKind::U32),
        gpu::ElemType::UInt(gpu::UIntKind::U64),
        gpu::ElemType::Float(gpu::FloatKind::F32),
    ];

    for ty in supported_types {
        props.register_type_usage(ty, TypeUsage::all_scalar());
    }

    for ty in supported_atomic_types {
        props.register_type_usage(
            gpu::StorageType::Atomic(ty),
            TypeUsage::AtomicAdd | TypeUsage::AtomicLoadStore,
        );
    }
}