use super::{
    BinaryInstruction, Binding, Body, Component, ComputeKernel, ConstArray, Dialect, Elem, FP4Kind,
    FP6Kind, FP8Kind, Fragment, FragmentIdent, FragmentLayout, IndexAssignInstruction,
    IndexInstruction, Instruction, Item, LocalArray, SharedMemory, UnaryInstruction, Variable,
    WarpInstruction, WmmaInstruction, barrier::BarrierOps, pipeline::PipelineOps,
};
use crate::shared::MmaShape;
use cubecl_common::backtrace::BackTrace;
use cubecl_core::{
    CubeDim,
    ir::{
        self as gpu, DeviceProperties, ElemType, FloatKind, InstructionModes, OpaqueType,
        Operation, Processor, SourceLoc, StorageType,
        features::{EnumSet, TypeUsage},
    },
    post_processing::checked_io::CheckedIoProcessor,
    prelude::{FastMath, KernelDefinition},
    server::ExecutionMode,
};
use cubecl_opt::{Optimizer, SharedLiveness};
use cubecl_runtime::compiler::{CompilationError, Compiler};
use std::{collections::HashSet, fmt::Debug};

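/// Counter used to generate unique ids for temporary variables emitted during
/// compilation; `compile` resets it between kernels.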
pub(super) static COUNTER_TMP_VAR: std::sync::atomic::AtomicU32 =
    std::sync::atomic::AtomicU32::new(0);

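/// Options provided by the runtime that tune code generation for the target.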
#[derive(Clone, Debug)]
pub struct CompilationOptions {
    pub warp_size: u32,
    pub supports_features: CppSupportedFeatures,
}

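/// Features supported by the target hardware/toolchain. Unsupported features
/// are compiled to fallbacks (e.g. `ElectFallback`) or dropped (e.g. clusters).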
#[derive(Clone, Debug, Default)]
pub struct CppSupportedFeatures {
    pub grid_constants: bool,
    pub clusters: bool,
    pub fast_math: bool,
    pub fast_tanh: bool,
    pub elect_sync: bool,
}

impl Default for CompilationOptions {
    fn default() -> Self {
        Self {
            warp_size: 32,
            supports_features: Default::default(),
        }
    }
}

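/// Records which builtin index variables a kernel actually references, so the
/// generated source only defines the index computations that are needed.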
#[derive(Debug, Clone, Default)]
pub struct CubeIndexFlags {
    pub absolute_pos: bool,
    pub absolute_pos_tuple: bool,
    pub cube_count: bool,
    pub cube_count_tuple: bool,
    pub cube_dim: bool,
    pub cube_dim_tuple: bool,
    pub cube_pos: bool,
    pub cube_pos_tuple: bool,
    pub plane_dim: bool,
    pub plane_dim_checked: bool,
    pub plane_index: bool,
    pub unit_pos: bool,
    pub unit_pos_tuple: bool,
    pub unit_pos_plane: bool,
    pub cluster_pos: bool,
}

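/// Everything the code generator needs to know about what the kernel body
/// uses: element types, special instructions, builtin indexes and launch
/// configuration.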
#[derive(Debug, Clone)]
pub struct Flags<D: Dialect> {
    pub elem_fp4: bool,
    pub elem_fp6: bool,
    pub elem_fp8: bool,
    pub elem_bf16: bool,
    pub elem_f16: bool,
    pub elem_tf32: bool,
    pub indexes: CubeIndexFlags,
    pub op_barrier: bool,
    pub op_pipeline: bool,
    pub inst_tma: bool,
    pub inst_tma_im2col: bool,
    pub inst_wmma: bool,
    pub inst_ptx_wrappers: bool,
    pub inst_async_copy: bool,
    pub use_grid_constants: bool,
    pub static_meta_length: usize,
    pub has_dynamic_meta: bool,
    pub cube_dim: CubeDim,
    pub cluster_dim: Option<CubeDim>,
    pub address_type: Item<D>,
}

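/// Lowers CubeCL IR ([`KernelDefinition`]) into the C++ representation
/// ([`ComputeKernel`]) for a given [`Dialect`].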
#[allow(clippy::too_many_arguments)]
#[derive(Clone, Debug)]
pub struct CppCompiler<D: Dialect> {
    barriers: Vec<BarrierOps<D>>,
    compilation_options: CompilationOptions,
    const_arrays: Vec<ConstArray<D>>,
    ext_meta_positions: Vec<u32>,
    cluster_dim: CubeDim,
    extensions: Vec<D::Extension>,
    flags: Flags<D>,
    items: HashSet<Item<D>>,
    local_arrays: Vec<LocalArray<D>>,
    metadata: cubecl_core::Metadata,
    pipelines: Vec<PipelineOps<D>>,
    source_loc: Option<SourceLoc>,
    strategy: ExecutionMode,
    addr_type: Item<D>,
}

impl<D: Dialect> Default for Flags<D> {
    fn default() -> Self {
        Self {
            elem_fp4: Default::default(),
            elem_fp6: Default::default(),
            elem_fp8: Default::default(),
            elem_bf16: Default::default(),
            elem_f16: Default::default(),
            elem_tf32: Default::default(),
            indexes: Default::default(),
            op_barrier: Default::default(),
            op_pipeline: Default::default(),
            inst_tma: Default::default(),
            inst_tma_im2col: Default::default(),
            inst_wmma: Default::default(),
            inst_ptx_wrappers: Default::default(),
            inst_async_copy: Default::default(),
            use_grid_constants: Default::default(),
            static_meta_length: Default::default(),
            has_dynamic_meta: Default::default(),
            cube_dim: CubeDim::new_single(),
            cluster_dim: Default::default(),
            address_type: Item::scalar(Elem::U32, true),
        }
    }
}

impl<D: Dialect> Default for CppCompiler<D> {
    fn default() -> Self {
        Self {
            barriers: Default::default(),
            compilation_options: Default::default(),
            const_arrays: Default::default(),
            ext_meta_positions: Default::default(),
            cluster_dim: CubeDim::new_single(),
            extensions: Default::default(),
            flags: Flags::default(),
            items: Default::default(),
            local_arrays: Default::default(),
            metadata: Default::default(),
            pipelines: Default::default(),
            source_loc: Default::default(),
            strategy: Default::default(),
            addr_type: Item::scalar(Elem::U32, true),
        }
    }
}

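// Note: `compile` clones `self` before lowering (see `compile_ir`), presumably
// so that per-kernel state (flags, pipelines, arrays) accumulated during one
// compilation does not leak into the next one on the cached compiler.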
impl<D: Dialect> Compiler for CppCompiler<D> {
    type Representation = ComputeKernel<D>;
    type CompilationOptions = CompilationOptions;

    fn compile(
        &mut self,
        mut kernel: KernelDefinition,
        compilation_options: &Self::CompilationOptions,
        strategy: ExecutionMode,
        addr_type: StorageType,
    ) -> Result<Self::Representation, CompilationError> {
        let errors = kernel.body.pop_errors();
        if !errors.is_empty() {
            let mut reason = "Can't compile cpp kernel\nCaused by:\n ".to_string();
            for error in errors {
                reason += error.as_str();
                reason += "\n";
            }

            return Err(CompilationError::Validation {
                reason,
                backtrace: BackTrace::capture(),
            });
        }

        self.addr_type = self.compile_type(addr_type.into());
        self.compilation_options = compilation_options.clone();
        self.strategy = strategy;

        if !self.compilation_options.supports_features.clusters {
            kernel.options.cluster_dim = None;
        }
        self.cluster_dim = kernel.options.cluster_dim.unwrap_or(CubeDim::new_single());

        let ir = self.clone().compile_ir(kernel, addr_type);
        COUNTER_TMP_VAR.store(0, std::sync::atomic::Ordering::Relaxed);
        Ok(ir)
    }

    fn elem_size(&self, elem: gpu::ElemType) -> usize {
        elem.size()
    }

    fn extension(&self) -> &'static str {
        "cpp"
    }
}

impl<D: Dialect> CppCompiler<D> {
    fn compile_ir(
        mut self,
        value: KernelDefinition,
        address_type: StorageType,
    ) -> ComputeKernel<D> {
        self.build_metadata(&value);

        let instructions = self.compile_scope(&mut value.body.clone());
        let tensor_maps = value
            .tensor_maps
            .into_iter()
            .map(|b| self.compile_binding(b))
            .collect();
        let buffers = value
            .buffers
            .into_iter()
            .map(|b| self.compile_binding(b))
            .collect();
        let scalars = value
            .scalars
            .into_iter()
            .map(|binding| (self.compile_storage_type(binding.ty), binding.count))
            .collect();

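        // Fold the flags gathered while lowering the body into the final set;
        // the builtin index flags go through `D::builtin_rules` so each dialect
        // can apply its own rules to them.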
        let flags = Flags {
            indexes: D::builtin_rules(&self.flags.indexes),
            inst_wmma: self.flags.inst_wmma,
            op_pipeline: self.flags.op_pipeline,
            op_barrier: self.flags.op_barrier,
            elem_fp4: self.flags.elem_fp4,
            elem_fp6: self.flags.elem_fp6,
            elem_fp8: self.flags.elem_fp8,
            elem_bf16: self.flags.elem_bf16,
            elem_f16: self.flags.elem_f16,
            elem_tf32: self.flags.elem_tf32,
            inst_tma: self.flags.inst_tma,
            inst_tma_im2col: self.flags.inst_tma_im2col,
            inst_async_copy: self.flags.inst_async_copy,
            inst_ptx_wrappers: self.flags.inst_ptx_wrappers,
            use_grid_constants: self.compilation_options.supports_features.grid_constants,
            has_dynamic_meta: self.metadata.static_len() > 0,
            static_meta_length: self.metadata.static_len() as usize,
            cube_dim: value.cube_dim,
            cluster_dim: value.options.cluster_dim,
            address_type: self.compile_type(address_type.into()),
        };

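        // Shared-memory layout: a liveness analysis over the body assigns each
        // shared allocation an offset, presumably letting allocations that are
        // never live at the same time overlap in the same buffer.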
        let mut opt = Optimizer::shared_only(value.body, value.cube_dim);
        let shared_allocs = opt.analysis::<SharedLiveness>();
        let shared_memories = shared_allocs
            .allocations
            .values()
            .map(|alloc| match alloc.smem {
                cubecl_opt::SharedMemory::Array {
                    id,
                    length,
                    ty,
                    align,
                } => SharedMemory::Array {
                    index: id,
                    item: self.compile_type(ty),
                    length,
                    align,
                    offset: alloc.offset,
                },
                cubecl_opt::SharedMemory::Value { id, ty, align } => SharedMemory::Value {
                    index: id,
                    item: self.compile_type(ty),
                    align,
                    offset: alloc.offset,
                },
            })
            .collect();

        let body = Body {
            instructions,
            shared_memories,
            pipelines: self.pipelines,
            barriers: self.barriers,
            const_arrays: self.const_arrays,
            local_arrays: self.local_arrays,
        };

        let mut cluster_dim = value.options.cluster_dim;
        if !self.compilation_options.supports_features.clusters {
            cluster_dim = None;
        }

        ComputeKernel {
            tensor_maps,
            buffers,
            scalars,
            meta_static_len: self.metadata.static_len() as usize,
            cube_dim: value.cube_dim,
            body,
            extensions: self.extensions,
            flags,
            items: self.items,
            kernel_name: value.options.kernel_name,
            cluster_dim,
        }
    }

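    /// Assigns every buffer and tensor-map binding (ordered by id) a position
    /// in the extended-metadata table, then sizes the metadata layout from the
    /// total counts.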
    fn build_metadata(&mut self, value: &KernelDefinition) {
        let mut num_ext = 0;

        let mut all_meta: Vec<_> = value
            .buffers
            .iter()
            .chain(value.tensor_maps.iter())
            .map(|buf| (buf.id, buf.has_extended_meta))
            .collect();

        all_meta.sort_by_key(|(id, _)| *id);

        for (_, has_extended_meta) in &all_meta {
            self.ext_meta_positions.push(num_ext);
            if *has_extended_meta {
                num_ext += 1;
            }
        }

        let num_meta = all_meta.len();

        self.metadata = cubecl_core::Metadata::new(num_meta as u32, num_ext);
    }

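    /// Position of the variable's extended metadata (shape/stride block), as
    /// assigned by `build_metadata`.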
    pub(crate) fn ext_meta_position(&self, var: gpu::Variable) -> u32 {
        let id = var.index().expect("Variable should have index");
        self.ext_meta_positions[id as usize]
    }

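    /// Compiles a scope: hoists its const arrays, runs the IR processors
    /// (checked IO plus the dialect's own), declares the variables produced by
    /// processing, then lowers every instruction in order.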
    fn compile_scope(&mut self, scope: &mut gpu::Scope) -> Vec<Instruction<D>> {
        let mut instructions = Vec::new();

        let const_arrays = scope
            .const_arrays
            .drain(..)
            .map(|(var, values)| ConstArray {
                index: var.index().unwrap(),
                item: self.compile_type(var.ty),
                size: values.len() as u32,
                values: values
                    .into_iter()
                    .map(|val| self.compile_variable(val))
                    .collect(),
            })
            .collect::<Vec<_>>();
        self.const_arrays.extend(const_arrays);

        let checked_io: Box<dyn Processor> = Box::new(CheckedIoProcessor::new(self.strategy));
        let dialect_processors = D::processors();
        let mut processors: Vec<&dyn Processor> = vec![&*checked_io];
        processors.extend(dialect_processors.iter().map(|it| &**it));

        let processing = scope.process(processors);

        for var in processing.variables {
            instructions.push(Instruction::DeclareVariable {
                var: self.compile_variable(var),
            });
        }

        processing
            .instructions
            .into_iter()
            .for_each(|op| self.compile_instruction(&mut instructions, op));

        instructions
    }

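    /// Lowers a single IR instruction, flipping the feature flags that the
    /// generated code will depend on as a side effect.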
    fn compile_instruction(
        &mut self,
        instructions: &mut Vec<Instruction<D>>,
        instruction: gpu::Instruction,
    ) {
        self.update_debug_loc(instructions, &instruction);
        let out = instruction.out;
        match instruction.operation {
            gpu::Operation::Copy(variable) => {
                instructions.push(Instruction::Assign(UnaryInstruction {
                    input: self.compile_variable(variable),
                    out: self.compile_variable(out.unwrap()),
                }));
            }
            gpu::Operation::Arithmetic(op) => {
                self.compile_arithmetic(op, out, instruction.modes, instructions)
            }
            gpu::Operation::Comparison(op) => self.compile_comparison(op, out, instructions),
            gpu::Operation::Bitwise(op) => self.compile_bitwise(op, out, instructions),
            gpu::Operation::Operator(op) => self.compile_operator(op, out, instructions),
            gpu::Operation::Atomic(op) => self.compile_atomic(op, out, instructions),
            gpu::Operation::Metadata(op) => instructions.push(self.compile_metadata(op, out)),
            gpu::Operation::Branch(val) => self.compile_branch(instructions, val),
            gpu::Operation::Synchronization(val) => match val {
                gpu::Synchronization::SyncCube => instructions.push(Instruction::SyncThreads),
                gpu::Synchronization::SyncPlane => instructions.push(Instruction::SyncWarp),
                gpu::Synchronization::SyncStorage => instructions.push(Instruction::SyncThreads),
                gpu::Synchronization::SyncAsyncProxyShared => {
                    self.flags.inst_tma = true;
                    instructions.push(Instruction::ProxyAsyncToSharedFence)
                }
            },
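            // Plane (warp-level) ops: all of them require the checked plane
            // dim, and the scan variants also need the unit's index within the
            // plane.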
            gpu::Operation::Plane(op) => {
                self.flags.indexes.plane_dim_checked = true;
                let out = self.compile_variable(out.unwrap());
                match op {
                    gpu::Plane::Sum(op) => {
                        let instruction = WarpInstruction::ReduceSum {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction));
                    }
                    gpu::Plane::InclusiveSum(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::InclusiveSum {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::InclusiveProd(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::InclusiveProd {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::ExclusiveSum(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::ExclusiveSum {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::ExclusiveProd(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::ExclusiveProd {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Prod(op) => {
                        let instruction = WarpInstruction::ReduceProd {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Max(op) => {
                        let instruction = WarpInstruction::ReduceMax {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Min(op) => {
                        let instruction = WarpInstruction::ReduceMin {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Elect => {
                        if self.compilation_options.supports_features.elect_sync {
                            self.flags.inst_ptx_wrappers = true;
                            instructions.push(Instruction::Warp(WarpInstruction::Elect { out }))
                        } else {
                            instructions
                                .push(Instruction::Warp(WarpInstruction::ElectFallback { out }))
                        }
                    }
                    gpu::Plane::All(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::All {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Any(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Any {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Ballot(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Ballot {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Broadcast(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Broadcast {
                            input: self.compile_variable(op.lhs),
                            id: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::Shuffle(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Shuffle {
                            input: self.compile_variable(op.lhs),
                            src_lane: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::ShuffleXor(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleXor {
                            input: self.compile_variable(op.lhs),
                            mask: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::ShuffleUp(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleUp {
                            input: self.compile_variable(op.lhs),
                            delta: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::ShuffleDown(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleDown {
                            input: self.compile_variable(op.lhs),
                            delta: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                }
            }
            gpu::Operation::CoopMma(cmma) => instructions.push(self.compile_cmma(cmma, out)),
            gpu::Operation::NonSemantic(debug) => match debug {
                gpu::NonSemantic::Print {
                    format_string,
                    args,
                } => instructions.push(Instruction::Printf {
                    format_string,
                    args: args
                        .into_iter()
                        .map(|arg| self.compile_variable(arg))
                        .collect(),
                }),
                gpu::NonSemantic::Comment { content } => {
                    instructions.push(Instruction::Comment { content })
                }
                _ => {}
            },
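            // Barrier ops map one-to-one onto the dialect-level `BarrierOps`.
            // The instruction's `out` doubles as the destination buffer for the
            // async copies and as the token for the arrive variants.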
            gpu::Operation::Barrier(barrier_ops) => match barrier_ops {
                gpu::BarrierOps::Declare { barrier } => {
                    let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type()
                    else {
                        unreachable!()
                    };
                    let barrier = self.compile_variable(barrier);
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Declare {
                        barrier,
                        level,
                    }));
                }
                gpu::BarrierOps::Init {
                    barrier,
                    is_elected,
                    arrival_count,
                } => {
                    let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type()
                    else {
                        unreachable!()
                    };
                    let barrier = self.compile_variable(barrier);
                    let arrival_count = self.compile_variable(arrival_count);
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Init {
                        barrier,
                        is_elected: self.compile_variable(is_elected),
                        arrival_count,
                        level,
                    }));
                }
                gpu::BarrierOps::InitManual {
                    barrier,
                    arrival_count,
                } => {
                    let barrier = self.compile_variable(barrier);
                    let arrival_count = self.compile_variable(arrival_count);
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::InitManual {
                            barrier,
                            arrival_count,
                        },
                    ));
                }
                gpu::BarrierOps::MemCopyAsync {
                    barrier,
                    source,
                    source_length,
                    offset_source,
                    offset_out,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsync {
                            barrier: self.compile_variable(barrier),
                            source: self.compile_variable(source),
                            destination: self.compile_variable(out.unwrap()),
                            source_length: self.compile_variable(source_length),
                            offset_source: self.compile_variable(offset_source),
                            offset_out: self.compile_variable(offset_out),
                            cooperative: false,
                        },
                    ));
                }
                gpu::BarrierOps::MemCopyAsyncCooperative {
                    barrier,
                    source,
                    source_length,
                    offset_source,
                    offset_out,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsync {
                            barrier: self.compile_variable(barrier),
                            source: self.compile_variable(source),
                            destination: self.compile_variable(out.unwrap()),
                            source_length: self.compile_variable(source_length),
                            offset_source: self.compile_variable(offset_source),
                            offset_out: self.compile_variable(offset_out),
                            cooperative: true,
                        },
                    ));
                }
                gpu::BarrierOps::MemCopyAsyncTx {
                    barrier,
                    source,
                    source_length,
                    offset_source,
                    offset_out,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsyncTx {
                            barrier: self.compile_variable(barrier),
                            source: self.compile_variable(source),
                            destination: self.compile_variable(out.unwrap()),
                            source_length: self.compile_variable(source_length),
                            offset_source: self.compile_variable(offset_source),
                            offset_out: self.compile_variable(offset_out),
                        },
                    ));
                }
                gpu::BarrierOps::CopyAsync {
                    source,
                    source_length,
                    offset_source,
                    offset_out,
                    copy_length,
                    checked,
                } => {
                    self.flags.inst_async_copy = true;
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::CopyAsync {
                            source: self.compile_variable(source),
                            destination: self.compile_variable(out.unwrap()),
                            source_length: self.compile_variable(source_length),
                            offset_source: self.compile_variable(offset_source),
                            offset_out: self.compile_variable(offset_out),
                            copy_size: copy_length,
                            checked,
                        },
                    ));
                }
                gpu::BarrierOps::TmaLoad {
                    barrier,
                    tensor_map,
                    offset_out,
                    indices,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsyncTensorGlobalToShared {
                            barrier: self.compile_variable(barrier),
                            smem_buffer: self.compile_variable(out.unwrap()),
                            smem_offset: self.compile_variable(offset_out),
                            tensor_map: self.compile_variable(tensor_map),
                            indices: indices
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        },
                    ));
                }
                gpu::BarrierOps::TmaLoadIm2col {
                    barrier,
                    tensor_map,
                    offset_out,
                    indices,
                    offsets,
                } => {
                    self.flags.inst_tma_im2col = true;
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::TmaLoadIm2col {
                            barrier: self.compile_variable(barrier),
                            smem_buffer: self.compile_variable(out.unwrap()),
                            smem_offset: self.compile_variable(offset_out),
                            tensor_map: self.compile_variable(tensor_map),
                            indices: indices
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                            offsets: offsets
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        },
                    ));
                }
                gpu::BarrierOps::Arrive { barrier } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Arrive {
                        barrier: self.compile_variable(barrier),
                        token: self.compile_variable(out.unwrap()),
                    }))
                }
                gpu::BarrierOps::ArriveTx {
                    barrier,
                    arrive_count_update,
                    transaction_count_update,
                } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::ArriveTx {
                        barrier: self.compile_variable(barrier),
                        token: self.compile_variable(out.unwrap()),
                        arrive_count_update: self.compile_variable(arrive_count_update),
                        transaction_count_update: self.compile_variable(transaction_count_update),
                    }))
                }
                gpu::BarrierOps::CommitCopyAsync { barrier } => {
                    self.flags.inst_async_copy = true;
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::ArriveCopyAsync {
                            barrier: self.compile_variable(barrier),
                        },
                    ))
                }
                gpu::BarrierOps::ExpectTx {
                    barrier,
                    transaction_count_update,
                } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::ExpectTx {
                        barrier: self.compile_variable(barrier),
                        transaction_count_update: self.compile_variable(transaction_count_update),
                    }))
                }
                gpu::BarrierOps::Wait { barrier, token } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Wait {
                        barrier: self.compile_variable(barrier),
                        token: self.compile_variable(token),
                    }))
                }
                gpu::BarrierOps::WaitParity { barrier, phase } => instructions.push(
                    Instruction::Barrier(super::barrier::BarrierOps::WaitParity {
                        barrier: self.compile_variable(barrier),
                        phase: self.compile_variable(phase),
                    }),
                ),
                gpu::BarrierOps::ArriveAndWait { barrier } => {
                    let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type()
                    else {
                        unreachable!()
                    };
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::ArriveAndWait {
                            barrier: self.compile_variable(barrier),
                            level,
                        },
                    ))
                }
            },
            gpu::Operation::Tma(tma_ops) => {
                self.flags.inst_tma = true;
                match tma_ops {
                    gpu::TmaOps::TmaStore {
                        source,
                        coordinates,
                        offset_source,
                    } => {
                        instructions.push(Instruction::MemCopyAsyncTensorSharedToGlobal {
                            smem_buffer: self.compile_variable(source),
                            smem_offset: self.compile_variable(offset_source),
                            tensor_map: self.compile_variable(out.unwrap()),
                            indices: coordinates
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        });
                    }
                    gpu::TmaOps::CommitGroup => {
                        instructions.push(Instruction::BulkCommitGroup);
                    }
                    gpu::TmaOps::WaitGroup { max_pending } => {
                        instructions.push(Instruction::BulkWaitGroup { max_pending });
                    }
                    gpu::TmaOps::WaitGroupRead { max_pending } => {
                        instructions.push(Instruction::BulkWaitGroupRead { max_pending });
                    }
                }
            }
            gpu::Operation::Marker(_) => {}
        }
    }

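    /// Pushes a `Line` marker whenever the source location changes, skipping
    /// non-semantic instructions so debug locations keep pointing at real
    /// operations.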
    fn update_debug_loc(
        &mut self,
        instructions: &mut Vec<Instruction<D>>,
        inst: &gpu::Instruction,
    ) {
        if !matches!(inst.operation, Operation::NonSemantic(_)) {
            match &inst.source_loc {
                Some(loc) if Some(loc) != self.source_loc.as_ref() => {
                    self.source_loc = Some(loc.clone());
                    instructions.push(Instruction::Line {
                        file: loc.source.file.clone(),
                        line: loc.line,
                    });
                }
                _ => {}
            }
        }
    }

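    /// Lowers cooperative-matrix (CMMA) operations to WMMA instructions,
    /// registering any dialect extension the chosen instruction requires.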
    fn compile_cmma(&mut self, cmma: gpu::CoopMma, out: Option<gpu::Variable>) -> Instruction<D> {
        self.flags.inst_wmma = true;

        let out = self.compile_variable(out.unwrap());

        let inst = match cmma {
            gpu::CoopMma::Fill { value } => WmmaInstruction::Fill {
                frag: out,
                value: self.compile_variable(value),
            },
            gpu::CoopMma::Load {
                value,
                stride,
                offset,
                layout,
            } => WmmaInstruction::Load {
                frag: out,
                offset: self.compile_variable(offset),
                value: self.compile_variable(value),
                stride: self.compile_variable(stride),
                layout: layout.and_then(|l| self.compile_matrix_layout(l)),
            },
            gpu::CoopMma::Execute {
                mat_a,
                mat_b,
                mat_c,
            } => WmmaInstruction::Execute {
                frag_a: self.compile_variable(mat_a),
                frag_b: self.compile_variable(mat_b),
                frag_c: self.compile_variable(mat_c),
                frag_d: out,
                warp_size: self.compilation_options.warp_size,
            },
            gpu::CoopMma::ExecuteManual {
                matrix,
                registers_a,
                registers_b,
                registers_c,
            } => WmmaInstruction::ExecuteManual {
                shape: MmaShape::new(matrix.m as u32, matrix.n as u32, matrix.k as u32),
                frag_a: self.compile_variable(registers_a),
                frag_b: self.compile_variable(registers_b),
                frag_c: self.compile_variable(registers_c),
                frag_d: out,
            },
            gpu::CoopMma::ExecuteScaled {
                matrix,
                registers_a,
                registers_b,
                registers_c,
                scales_a,
                scales_b,
                scales_factor,
            } => WmmaInstruction::ExecuteScaled {
                shape: MmaShape::new(matrix.m as u32, matrix.n as u32, matrix.k as u32),
                frag_a: self.compile_variable(registers_a),
                frag_b: self.compile_variable(registers_b),
                frag_c: self.compile_variable(registers_c),
                frag_d: out,

                scales_a: self.compile_variable(scales_a),
                scales_b: self.compile_variable(scales_b),
                scales_factor: scales_factor as u32,
            },
            gpu::CoopMma::Store {
                mat,
                stride,
                offset,
                layout,
            } => {
                self.flags.indexes.unit_pos = true;
                self.flags.indexes.plane_index = true;
                WmmaInstruction::Store {
                    output: out,
                    offset: self.compile_variable(offset),
                    frag: self.compile_variable(mat),
                    stride: self.compile_variable(stride),
                    layout: self
                        .compile_matrix_layout(layout)
                        .expect("Layout required for store instruction"),
                }
            }
            gpu::CoopMma::LoadMatrix {
                buffer,
                offset,
                line_size,
                factor,
                transpose,
            } => WmmaInstruction::LdMatrix {
                output: out,
                buffer: self.compile_variable(buffer),
                offset: self.compile_variable(offset),
                line_size,
                factor: factor as u32,
                transpose,
            },
            gpu::CoopMma::StoreMatrix {
                offset,
                line_size,
                registers,
                factor,
                transpose,
            } => WmmaInstruction::StMatrix {
                registers: self.compile_variable(registers),
                buffer: out,
                offset: self.compile_variable(offset),
                line_size,
                factor: factor as u32,
                transpose,
            },
            gpu::CoopMma::Cast { input } => WmmaInstruction::Cast {
                input: self.compile_variable(input),
                output: out,
            },
            gpu::CoopMma::RowIndex { .. } | gpu::CoopMma::ColIndex { .. } => {
                panic!("Row/Col index should be handled by processors")
            }
        };

        D::register_wmma_instruction_extension(&mut self.extensions, &inst);

        Instruction::Wmma(inst)
    }

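    /// Lowers metadata queries (stride/shape/rank/length) into reads at the
    /// proper offset of the metadata buffer; slice and shared-array lengths
    /// are resolved directly without a metadata read.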
    fn compile_metadata(
        &mut self,
        metadata: gpu::Metadata,
        out: Option<gpu::Variable>,
    ) -> Instruction<D> {
        let out = out.unwrap();
        match metadata {
            gpu::Metadata::Stride { dim, var } => {
                let position = self.ext_meta_position(var);
                let offset = self.metadata.stride_offset_index(position);
                Instruction::ExtendedMetadata {
                    info_offset: self.compile_variable(offset.into()),
                    dim: self.compile_variable(dim),
                    split_meta: self.compilation_options.supports_features.grid_constants,
                    static_offset: self.metadata.static_len(),
                    out: self.compile_variable(out),
                }
            }
            gpu::Metadata::Shape { dim, var } => {
                let position = self.ext_meta_position(var);
                let offset = self.metadata.shape_offset_index(position);
                Instruction::ExtendedMetadata {
                    info_offset: self.compile_variable(offset.into()),
                    dim: self.compile_variable(dim),
                    split_meta: self.compilation_options.supports_features.grid_constants,
                    static_offset: self.metadata.static_len(),
                    out: self.compile_variable(out),
                }
            }
            gpu::Metadata::Rank { var } => {
                let out = self.compile_variable(out);
                let pos = self.ext_meta_position(var);
                let offset = self.metadata.rank_index(pos);
                super::Instruction::Metadata {
                    info_offset: self.compile_variable(offset.into()),
                    split_meta: self.compilation_options.supports_features.grid_constants,
                    out,
                }
            }
            gpu::Metadata::Length { var } => {
                let input = self.compile_variable(var);
                let out = self.compile_variable(out);

                match input {
                    Variable::Slice { .. } => Instruction::SliceLength { input, out },
                    Variable::SharedArray(_id, _item, length) => {
                        Instruction::ConstLength { length, out }
                    }
                    _ => {
                        let id = input.id().expect("Variable should have id");
                        let offset = self.metadata.len_index(id);
                        Instruction::Metadata {
                            info_offset: self.compile_variable(offset.into()),
                            split_meta: self.compilation_options.supports_features.grid_constants,
                            out,
                        }
                    }
                }
            }
            gpu::Metadata::BufferLength { var } => {
                let input = self.compile_variable(var);
                let out = self.compile_variable(out);

                match input {
                    Variable::Slice { .. } => Instruction::SliceLength { input, out },
                    _ => {
                        let id = input.id().expect("Variable should have id");
                        let offset = self.metadata.buffer_len_index(id);
                        Instruction::Metadata {
                            info_offset: self.compile_variable(offset.into()),
                            split_meta: self.compilation_options.supports_features.grid_constants,
                            out,
                        }
                    }
                }
            }
        }
    }

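    /// Lowers control flow, recursively compiling each nested scope.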
    fn compile_branch(&mut self, instructions: &mut Vec<Instruction<D>>, branch: gpu::Branch) {
        match branch {
            gpu::Branch::If(mut op) => instructions.push(Instruction::If {
                cond: self.compile_variable(op.cond),
                instructions: self.compile_scope(&mut op.scope),
            }),
            gpu::Branch::IfElse(mut op) => instructions.push(Instruction::IfElse {
                cond: self.compile_variable(op.cond),
                instructions_if: self.compile_scope(&mut op.scope_if),
                instructions_else: self.compile_scope(&mut op.scope_else),
            }),
            gpu::Branch::Switch(mut op) => instructions.push(Instruction::Switch {
                value: self.compile_variable(op.value),
                instructions_default: self.compile_scope(&mut op.scope_default),
                instructions_cases: op
                    .cases
                    .into_iter()
                    .map(|(val, mut block)| {
                        (self.compile_variable(val), self.compile_scope(&mut block))
                    })
                    .collect(),
            }),
            gpu::Branch::Return => instructions.push(Instruction::Return),
            gpu::Branch::Break => instructions.push(Instruction::Break),
            gpu::Branch::RangeLoop(mut range_loop) => instructions.push(Instruction::RangeLoop {
                i: self.compile_variable(range_loop.i),
                start: self.compile_variable(range_loop.start),
                end: self.compile_variable(range_loop.end),
                step: range_loop.step.map(|it| self.compile_variable(it)),
                inclusive: range_loop.inclusive,
                instructions: self.compile_scope(&mut range_loop.scope),
            }),
            gpu::Branch::Loop(mut op) => instructions.push(Instruction::Loop {
                instructions: self.compile_scope(&mut op.scope),
            }),
        };
    }

    fn compile_atomic(
        &mut self,
        value: gpu::AtomicOp,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::AtomicOp::Load(op) => {
                instructions.push(Instruction::AtomicLoad(self.compile_unary(op, out)))
            }
            gpu::AtomicOp::Store(op) => {
                instructions.push(Instruction::AtomicStore(self.compile_unary(op, out)))
            }
            gpu::AtomicOp::Swap(op) => {
                instructions.push(Instruction::AtomicSwap(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Add(op) => {
                instructions.push(Instruction::AtomicAdd(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Sub(op) => {
                instructions.push(Instruction::AtomicSub(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Max(op) => {
                instructions.push(Instruction::AtomicMax(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Min(op) => {
                instructions.push(Instruction::AtomicMin(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::And(op) => {
                instructions.push(Instruction::AtomicAnd(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Or(op) => {
                instructions.push(Instruction::AtomicOr(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Xor(op) => {
                instructions.push(Instruction::AtomicXor(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::CompareAndSwap(op) => instructions.push(Instruction::AtomicCAS {
                input: self.compile_variable(op.input),
                cmp: self.compile_variable(op.cmp),
                val: self.compile_variable(op.val),
                out: self.compile_variable(out),
            }),
        }
    }

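    /// Lowers arithmetic operations. Float ops with a fast approximation are
    /// routed through `select_fast_float`, which only picks the fast variant
    /// when the instruction's math modes allow the listed relaxations.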
    fn compile_arithmetic(
        &mut self,
        value: gpu::Arithmetic,
        out: Option<gpu::Variable>,
        modes: InstructionModes,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::Arithmetic::Add(op) => {
                instructions.push(Instruction::Add(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::SaturatingAdd(op) => {
                instructions.push(Instruction::SaturatingAdd(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Mul(op) => {
                instructions.push(Instruction::Mul(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Div(op) => {
                let op = self.compile_binary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::AllowReciprocal
                        | FastMath::ReducedPrecision
                        | FastMath::UnsignedZero
                        | FastMath::NotInf,
                    Instruction::Div(op),
                    Instruction::FastDiv(op),
                ))
            }
            gpu::Arithmetic::Sub(op) => {
                instructions.push(Instruction::Sub(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::SaturatingSub(op) => {
                instructions.push(Instruction::SaturatingSub(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::MulHi(op) => {
                let instruction = Instruction::HiMul(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Modulo(op) => {
                instructions.push(Instruction::Modulo(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Abs(op) => {
                instructions.push(Instruction::Abs(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Exp(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Exp(op),
                    Instruction::FastExp(op),
                ));
            }
            gpu::Arithmetic::Log(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Log(op),
                    Instruction::FastLog(op),
                ));
            }
            gpu::Arithmetic::Log1p(op) => {
                instructions.push(Instruction::Log1p(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Cos(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Cos(op),
                    Instruction::FastCos(op),
                ));
            }
            gpu::Arithmetic::Sin(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Sin(op),
                    Instruction::FastSin(op),
                ));
            }
            gpu::Arithmetic::Tan(op) => {
                instructions.push(Instruction::Tan(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Tanh(op) => {
                let op = self.compile_unary(op, out);
                let instruction = Instruction::Tanh(op);
                D::register_instruction_extension(&mut self.extensions, &instruction);
                if self.compilation_options.supports_features.fast_tanh {
                    instructions.push(self.select_fast_float(
                        out.ty,
                        modes,
                        FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                        instruction,
                        Instruction::FastTanh(op),
                    ))
                } else {
                    instructions.push(instruction);
                }
            }
            gpu::Arithmetic::Sinh(op) => {
                let instruction = Instruction::Sinh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Cosh(op) => {
                let instruction = Instruction::Cosh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcCos(op) => {
                let instruction = Instruction::ArcCos(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcSin(op) => {
                let instruction = Instruction::ArcSin(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcTan(op) => {
                let instruction = Instruction::ArcTan(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcSinh(op) => {
                let instruction = Instruction::ArcSinh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcCosh(op) => {
                let instruction = Instruction::ArcCosh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcTanh(op) => {
                let instruction = Instruction::ArcTanh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Degrees(op) => {
                let instruction = Instruction::Degrees(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Radians(op) => {
                let instruction = Instruction::Radians(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcTan2(op) => {
                let instruction = Instruction::ArcTan2(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Powf(op) => {
                let op = self.compile_binary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Powf(op),
                    Instruction::FastPowf(op),
                ))
            }
            gpu::Arithmetic::Powi(op) => {
                instructions.push(Instruction::Powi(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Hypot(op) => {
                instructions.push(Instruction::Hypot(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Rhypot(op) => {
                instructions.push(Instruction::Rhypot(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Sqrt(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Sqrt(op),
                    Instruction::FastSqrt(op),
                ))
            }
            gpu::Arithmetic::InverseSqrt(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::InverseSqrt(op),
                    Instruction::FastInverseSqrt(op),
                ))
            }
            gpu::Arithmetic::Erf(op) => {
                let instruction = Instruction::Erf(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Max(op) => {
                let instruction = Instruction::Max(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Min(op) => {
                let instruction = Instruction::Min(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Clamp(op) => instructions.push(Instruction::Clamp {
                input: self.compile_variable(op.input),
                min_value: self.compile_variable(op.min_value),
                max_value: self.compile_variable(op.max_value),
                out: self.compile_variable(out),
            }),
            gpu::Arithmetic::Recip(op) => {
                let elem = op.input.ty.elem_type();
                let input = self.compile_variable(op.input);
                let out = self.compile_variable(out);
                let lhs = match elem {
                    gpu::ElemType::Float(_) => gpu::ConstantValue::Float(1.0),
                    gpu::ElemType::Int(_) => gpu::ConstantValue::Int(1),
                    gpu::ElemType::UInt(_) => gpu::ConstantValue::UInt(1),
                    gpu::ElemType::Bool => gpu::ConstantValue::Bool(true),
                };
                let div = Instruction::Div(BinaryInstruction {
                    lhs: Variable::Constant(lhs, self.compile_type(op.input.ty)),
                    rhs: input,
                    out,
                });
                let recip = Instruction::FastRecip(UnaryInstruction { input, out });

                instructions.push(self.select_fast_float(
                    elem.into(),
                    modes,
                    FastMath::AllowReciprocal
                        | FastMath::ReducedPrecision
                        | FastMath::UnsignedZero
                        | FastMath::NotInf,
                    div,
                    recip,
                ))
            }
            gpu::Arithmetic::Round(op) => {
                instructions.push(Instruction::Round(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Floor(op) => {
                instructions.push(Instruction::Floor(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Ceil(op) => {
                instructions.push(Instruction::Ceil(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Trunc(op) => {
                instructions.push(Instruction::Trunc(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Remainder(op) => {
                instructions.push(Instruction::Remainder(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Fma(op) => instructions.push(Instruction::Fma {
                a: self.compile_variable(op.a),
                b: self.compile_variable(op.b),
                c: self.compile_variable(op.c),
                out: self.compile_variable(out),
            }),
            gpu::Arithmetic::Neg(op) => {
                instructions.push(Instruction::Neg(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Normalize(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Normalize(op),
                    Instruction::FastNormalize(op),
                ))
            }
            gpu::Arithmetic::Magnitude(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Magnitude(op),
                    Instruction::FastMagnitude(op),
                ))
            }
            gpu::Arithmetic::Dot(op) => {
                instructions.push(Instruction::Dot(self.compile_binary(op, out)))
            }
        };
    }

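    /// Picks between the precise and the fast variant of a float instruction:
    /// the fast one is used only for `f32`, only when the target supports fast
    /// math, and only when the instruction's math modes include every flag in
    /// `required_flags`.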
    fn select_fast_float(
        &self,
        ty: gpu::Type,
        modes: InstructionModes,
        required_flags: EnumSet<FastMath>,
        default: Instruction<D>,
        fast: Instruction<D>,
    ) -> Instruction<D> {
        if !self.compilation_options.supports_features.fast_math
            || !matches!(ty.elem_type(), ElemType::Float(FloatKind::F32))
        {
            return default;
        }

        if modes.fp_math_mode.is_superset(required_flags) {
            fast
        } else {
            default
        }
    }

    fn compile_comparison(
        &mut self,
        value: gpu::Comparison,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::Comparison::Equal(op) => {
                instructions.push(Instruction::Equal(self.compile_binary(op, out)))
            }
            gpu::Comparison::Lower(op) => {
                instructions.push(Instruction::Lower(self.compile_binary(op, out)))
            }
            gpu::Comparison::Greater(op) => {
                instructions.push(Instruction::Greater(self.compile_binary(op, out)))
            }
            gpu::Comparison::LowerEqual(op) => {
                instructions.push(Instruction::LowerEqual(self.compile_binary(op, out)))
            }
            gpu::Comparison::GreaterEqual(op) => {
                instructions.push(Instruction::GreaterEqual(self.compile_binary(op, out)))
            }
            gpu::Comparison::NotEqual(op) => {
                instructions.push(Instruction::NotEqual(self.compile_binary(op, out)))
            }
            gpu::Comparison::IsNan(op) => {
                instructions.push(Instruction::IsNan(self.compile_unary(op, out)))
            }
            gpu::Comparison::IsInf(op) => {
                instructions.push(Instruction::IsInf(self.compile_unary(op, out)))
            }
        };
    }

    fn compile_bitwise(
        &mut self,
        value: gpu::Bitwise,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::Bitwise::BitwiseOr(op) => {
                instructions.push(Instruction::BitwiseOr(self.compile_binary(op, out)))
            }
            gpu::Bitwise::BitwiseAnd(op) => {
                instructions.push(Instruction::BitwiseAnd(self.compile_binary(op, out)))
            }
            gpu::Bitwise::BitwiseXor(op) => {
                instructions.push(Instruction::BitwiseXor(self.compile_binary(op, out)))
            }
            gpu::Bitwise::CountOnes(op) => {
                instructions.push(Instruction::CountBits(self.compile_unary(op, out)))
            }
            gpu::Bitwise::ReverseBits(op) => {
                instructions.push(Instruction::ReverseBits(self.compile_unary(op, out)))
            }
            gpu::Bitwise::ShiftLeft(op) => {
                instructions.push(Instruction::ShiftLeft(self.compile_binary(op, out)))
            }
            gpu::Bitwise::ShiftRight(op) => {
                instructions.push(Instruction::ShiftRight(self.compile_binary(op, out)))
            }
            gpu::Bitwise::BitwiseNot(op) => {
                instructions.push(Instruction::BitwiseNot(self.compile_unary(op, out)))
            }
            gpu::Bitwise::LeadingZeros(op) => {
                instructions.push(Instruction::LeadingZeros(self.compile_unary(op, out)))
            }
            gpu::Bitwise::FindFirstSet(op) => {
                let instruction = Instruction::FindFirstSet(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
        };
    }

    fn compile_operator(
        &mut self,
        value: gpu::Operator,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::Operator::Index(op) | gpu::Operator::UncheckedIndex(op) => {
                instructions.push(Instruction::Index(self.compile_index(op, out)));
            }
            gpu::Operator::IndexAssign(op) | gpu::Operator::UncheckedIndexAssign(op) => {
                instructions.push(Instruction::IndexAssign(self.compile_index_assign(op, out)));
            }
            gpu::Operator::And(op) => {
                instructions.push(Instruction::And(self.compile_binary(op, out)))
            }
            gpu::Operator::Or(op) => {
                instructions.push(Instruction::Or(self.compile_binary(op, out)))
            }
            gpu::Operator::Not(op) => {
                instructions.push(Instruction::Not(self.compile_unary(op, out)))
            }
            gpu::Operator::InitLine(op) => instructions.push(Instruction::VecInit {
                inputs: op
                    .inputs
                    .into_iter()
                    .map(|it| self.compile_variable(it))
                    .collect(),
                out: self.compile_variable(out),
            }),
            gpu::Operator::CopyMemory(op) => instructions.push(Instruction::Copy {
                input: self.compile_variable(op.input),
                in_index: self.compile_variable(op.in_index),
                out: self.compile_variable(out),
                out_index: self.compile_variable(op.out_index),
            }),
            gpu::Operator::CopyMemoryBulk(op) => instructions.push(Instruction::CopyBulk {
                input: self.compile_variable(op.input),
                in_index: self.compile_variable(op.in_index),
                out: self.compile_variable(out),
                out_index: self.compile_variable(op.out_index),
                len: op.len as u32,
            }),
            gpu::Operator::Select(op) => instructions.push(Instruction::Select {
                cond: self.compile_variable(op.cond),
                then: self.compile_variable(op.then),
                or_else: self.compile_variable(op.or_else),
                out: self.compile_variable(out),
            }),
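            // Casts into or out of FP4/FP6/FP8 are lowered to a `SpecialCast`
            // that goes through half precision: flag f16/bf16 as used and
            // pre-compile the f16/bf16 line types involved (the results are
            // discarded; `compile_type` appears to record each item so its
            // definition is emitted).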
            gpu::Operator::Cast(op)
                if is_fp4_fp6_fp8(op.input.elem_type()) || is_fp4_fp6_fp8(out.elem_type()) =>
            {
                self.flags.elem_f16 = true;
                self.flags.elem_bf16 = true;
                let vec_in = op.input.ty.line_size();
                let packing = out.storage_type().packing_factor();
                self.compile_type(op.input.ty.line(packing));
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::F16)).line(vec_in),
                );
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::BF16)).line(vec_in),
                );
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::F16)).line(packing),
                );
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::BF16)).line(packing),
                );

                let inst = self.compile_unary(op, out);

                instructions.push(Instruction::SpecialCast(inst));
            }
            gpu::Operator::Cast(op) => {
                let op = self.compile_unary(op, out);

                if op.input.elem() == Elem::TF32 || op.out.elem() == Elem::TF32 {
                    self.flags.elem_tf32 = true;
                }

                instructions.push(Instruction::Assign(op))
            }
            gpu::Operator::Reinterpret(op) => {
                instructions.push(Instruction::Bitcast(self.compile_unary(op, out)))
            }
        };
    }

    fn compile_binary(
        &mut self,
        value: gpu::BinaryOperator,
        out: gpu::Variable,
    ) -> BinaryInstruction<D> {
        BinaryInstruction {
            lhs: self.compile_variable(value.lhs),
            rhs: self.compile_variable(value.rhs),
            out: self.compile_variable(out),
        }
    }

    fn compile_index_assign(
        &mut self,
        value: gpu::IndexAssignOperator,
        out: gpu::Variable,
    ) -> IndexAssignInstruction<D> {
        IndexAssignInstruction {
            index: self.compile_variable(value.index),
            value: self.compile_variable(value.value),
            line_size: value.line_size as u32,
            out: self.compile_variable(out),
        }
    }

    fn compile_index(
        &mut self,
        value: gpu::IndexOperator,
        out: gpu::Variable,
    ) -> IndexInstruction<D> {
        IndexInstruction {
            list: self.compile_variable(value.list),
            index: self.compile_variable(value.index),
            line_size: value.line_size as u32,
            out: self.compile_variable(out),
        }
    }

    fn compile_unary(
        &mut self,
        value: gpu::UnaryOperator,
        out: gpu::Variable,
    ) -> UnaryInstruction<D> {
        UnaryInstruction {
            input: self.compile_variable(value.input),
            out: self.compile_variable(out),
        }
    }

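    /// Maps an IR variable to its C++ counterpart, setting feature and index
    /// flags as a side effect and registering lazily-created resources (local
    /// arrays, pipelines) the first time they are referenced.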
1672 fn compile_variable(&mut self, value: gpu::Variable) -> Variable<D> {
1673 let item = value.ty;
1674 match value.kind {
1675 gpu::VariableKind::GlobalInputArray(id) => {
1676 Variable::GlobalInputArray(id, self.compile_type(item))
1677 }
1678 gpu::VariableKind::GlobalScalar(id) => Variable::GlobalScalar {
1679 id,
1680 elem: self.compile_storage_type(item.storage_type()),
1681 in_struct: self.compilation_options.supports_features.grid_constants,
1682 },
1683 gpu::VariableKind::TensorMapInput(id) => {
1684 self.flags.inst_tma = true;
1685 Variable::TensorMap(id)
1686 }
1687 gpu::VariableKind::TensorMapOutput(id) => {
1688 self.flags.inst_tma = true;
1689 Variable::TensorMap(id)
1690 }
1691 gpu::VariableKind::LocalMut { id } => Variable::LocalMut {
1692 id,
1693 item: self.compile_type(item),
1694 },
1695 gpu::VariableKind::Versioned { id, .. } => Variable::LocalMut {
1696 id,
1697 item: self.compile_type(item),
1698 },
1699 gpu::VariableKind::LocalConst { id } => Variable::LocalConst {
1700 id,
1701 item: self.compile_type(item),
1702 },
1703 gpu::VariableKind::GlobalOutputArray(id) => {
1704 Variable::GlobalOutputArray(id, self.compile_type(item))
1705 }
1706 gpu::VariableKind::Constant(value) => {
1707 Variable::Constant(value, self.compile_type(item))
1708 }
1709 gpu::VariableKind::SharedArray { id, length, .. } => {
1710 let item = self.compile_type(item);
1711 Variable::SharedArray(id, item, length)
1712 }
1713 gpu::VariableKind::Shared { id } => {
1714 let item = self.compile_type(item);
1715 Variable::Shared(id, item)
1716 }
1717 gpu::VariableKind::ConstantArray {
1718 id,
1719 length,
1720 unroll_factor,
1721 } => {
1722 let item = self.compile_type(item);
1723 Variable::ConstantArray(id, item, length * unroll_factor)
1724 }
1725 gpu::VariableKind::Builtin(builtin) => match builtin {
1726 gpu::Builtin::AbsolutePos => {
1727 self.flags.indexes.absolute_pos = true;
1728 let item = self.compile_type(item);
1729 Variable::AbsolutePos(item.elem)
1730 }
1731 gpu::Builtin::CubePosCluster
1732 if self.compilation_options.supports_features.clusters =>
1733 {
1734 self.flags.indexes.cluster_pos = true;
1735 Variable::ClusterRank
1736 }
1737 gpu::Builtin::CubePosClusterX
1738 if self.compilation_options.supports_features.clusters =>
1739 {
1740 self.flags.indexes.cluster_pos = true;
1741 Variable::ClusterIndexX
1742 }
1743 gpu::Builtin::CubePosClusterY
1744 if self.compilation_options.supports_features.clusters =>
1745 {
1746 self.flags.indexes.cluster_pos = true;
1747 Variable::ClusterIndexY
1748 }
1749 gpu::Builtin::CubePosClusterZ
1750 if self.compilation_options.supports_features.clusters =>
1751 {
1752 self.flags.indexes.cluster_pos = true;
1753 Variable::ClusterIndexZ
1754 }
                gpu::Builtin::CubePosCluster
                | gpu::Builtin::CubePosClusterX
                | gpu::Builtin::CubePosClusterY
                | gpu::Builtin::CubePosClusterZ => const_u32(0),
                gpu::Builtin::AbsolutePosX => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosX
                }
                gpu::Builtin::AbsolutePosY => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosY
                }
                gpu::Builtin::AbsolutePosZ => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosZ
                }
                gpu::Builtin::CubeDim => {
                    self.flags.indexes.cube_dim = true;
                    Variable::CubeDim
                }
                gpu::Builtin::CubeDimX => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimX
                }
                gpu::Builtin::CubeDimY => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimY
                }
                gpu::Builtin::CubeDimZ => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimZ
                }
                gpu::Builtin::CubeClusterDim => const_u32(self.cluster_dim.num_elems()),
                gpu::Builtin::CubeClusterDimX => const_u32(self.cluster_dim.x),
                gpu::Builtin::CubeClusterDimY => const_u32(self.cluster_dim.y),
                gpu::Builtin::CubeClusterDimZ => const_u32(self.cluster_dim.z),
                gpu::Builtin::CubePos => {
                    self.flags.indexes.cube_pos = true;
                    let item = self.compile_type(item);
                    Variable::CubePos(item.elem)
                }
                gpu::Builtin::CubePosX => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosX
                }
                gpu::Builtin::CubePosY => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosY
                }
                gpu::Builtin::CubePosZ => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosZ
                }
                gpu::Builtin::CubeCount => {
                    self.flags.indexes.cube_count = true;
                    let item = self.compile_type(item);
                    Variable::CubeCount(item.elem)
                }
                gpu::Builtin::CubeCountX => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountX
                }
                gpu::Builtin::CubeCountY => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountY
                }
                gpu::Builtin::CubeCountZ => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountZ
                }
                gpu::Builtin::UnitPos => {
                    self.flags.indexes.unit_pos = true;
                    Variable::UnitPos
                }
                gpu::Builtin::UnitPosX => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosX
                }
                gpu::Builtin::UnitPosY => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosY
                }
                gpu::Builtin::UnitPosZ => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosZ
                }
                gpu::Builtin::PlaneDim => {
                    self.flags.indexes.plane_dim = true;
                    Variable::PlaneDim
                }
                gpu::Builtin::UnitPosPlane => {
                    self.flags.indexes.unit_pos_plane = true;
                    Variable::UnitPosPlane
                }
            },
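            // Like constant arrays, local arrays are declared with storage
            // for every unrolled copy; each array is registered once per id.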
            gpu::VariableKind::LocalArray {
                id,
                length,
                unroll_factor,
            } => {
                let item = self.compile_type(item);
                if !self.local_arrays.iter().any(|s| s.index == id) {
                    self.local_arrays
                        .push(LocalArray::new(id, item, length * unroll_factor));
                }
                Variable::LocalArray(id, item, length)
            }
            gpu::VariableKind::Matrix { id, mat } => {
                self.flags.inst_wmma = true;
                Variable::WmmaFragment {
                    id,
                    frag: self.compile_matrix(mat),
                }
            }
            gpu::VariableKind::Pipeline { id, num_stages } => {
                self.flags.op_pipeline = true;
                let pipeline = Variable::Pipeline { id };
                if !self.pipelines.iter().any(|s| s.pipeline_id() == id) {
                    self.pipelines.push(PipelineOps::Init {
                        pipeline,
                        num_stages,
                    });
                }
                pipeline
            }
            gpu::VariableKind::BarrierToken { id, level } => {
                self.flags.op_barrier = true;
                Variable::BarrierToken { id, level }
            }
        }
    }

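    /// Lowers a `gpu::Matrix` into a WMMA fragment descriptor, translating
    /// the identifier, tile shape (`m`, `n`, `k`), storage type, and layout.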
    fn compile_matrix(&mut self, matrix: gpu::Matrix) -> Fragment<D> {
        Fragment {
            ident: self.compile_matrix_ident(matrix.ident),
            m: matrix.m as u32,
            n: matrix.n as u32,
            k: matrix.k as u32,
            elem: self.compile_storage_type(matrix.storage),
            layout: self.compile_matrix_layout(matrix.layout),
        }
    }

    fn compile_matrix_ident(&mut self, ident: gpu::MatrixIdent) -> FragmentIdent<D> {
        match ident {
            gpu::MatrixIdent::A => FragmentIdent::A,
            gpu::MatrixIdent::B => FragmentIdent::B,
            gpu::MatrixIdent::Accumulator => FragmentIdent::Accumulator,
        }
    }

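    /// Maps the IR matrix layout onto a fragment layout; `Undefined` becomes
    /// `None`, so no explicit layout is emitted.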
    fn compile_matrix_layout(&mut self, layout: gpu::MatrixLayout) -> Option<FragmentLayout<D>> {
        match layout {
            gpu::MatrixLayout::ColMajor => Some(FragmentLayout::ColMajor),
            gpu::MatrixLayout::RowMajor => Some(FragmentLayout::RowMajor),
            gpu::MatrixLayout::Undefined => None,
        }
    }

    fn compile_binding(&mut self, binding: cubecl_runtime::kernel::Binding) -> Binding<D> {
        Binding {
            id: binding.id,
            item: self.compile_type(binding.ty),
            location: binding.location,
            size: binding.size,
            vis: binding.visibility,
        }
    }

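    /// Lowers an IR type to an item, recording each distinct item in
    /// `self.items`. TF32 is recorded through its f32 representation.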
    fn compile_type(&mut self, ty: gpu::Type) -> Item<D> {
        let item = match ty {
            gpu::Type::Scalar(ty) => Item::new(self.compile_storage_type(ty), 1, false),
            gpu::Type::Line(ty, line_size) => {
                Item::new(self.compile_storage_type(ty), line_size, false)
            }
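            // Semantic types carry no runtime payload; a scalar bool item
            // stands in for them.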
            gpu::Type::Semantic(_) => Item::new(Elem::Bool, 1, true),
        };
        if item.elem != super::Elem::TF32 {
            self.items.insert(item);
            self.items.insert(item.optimized());
        } else {
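            // TF32 shares the in-memory representation of f32, so only the
            // f32 item needs to be registered.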
            let mut item = item;
            item.elem = super::Elem::F32;
            self.items.insert(item);
        }

        item
    }

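    /// Lowers an IR storage type to an element, flipping the matching
    /// element-availability flags (fp4/fp6/fp8, f16, bf16) as a side effect.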
    fn compile_storage_type(&mut self, value: gpu::StorageType) -> Elem<D> {
        match value {
            gpu::StorageType::Scalar(ty) => self.compile_elem(ty),
            gpu::StorageType::Atomic(ty) => Elem::Atomic(ty.into()),
            gpu::StorageType::Packed(gpu::ElemType::Float(kind), 2) => match kind {
                FloatKind::E2M1 => {
                    self.flags.elem_fp4 = true;
                    Elem::FP4x2(FP4Kind::E2M1)
                }
                FloatKind::E2M3 => {
                    self.flags.elem_fp6 = true;
                    Elem::FP6x2(FP6Kind::E2M3)
                }
                FloatKind::E3M2 => {
                    self.flags.elem_fp6 = true;
                    Elem::FP6x2(FP6Kind::E3M2)
                }
                FloatKind::E4M3 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8x2(FP8Kind::E4M3)
                }
                FloatKind::E5M2 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8x2(FP8Kind::E5M2)
                }
                FloatKind::UE8M0 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8x2(FP8Kind::UE8M0)
                }
                FloatKind::F16 => {
                    self.flags.elem_f16 = true;
                    Elem::F16x2
                }
                FloatKind::BF16 => {
                    self.flags.elem_bf16 = true;
                    Elem::BF16x2
                }
                other => unimplemented!("Unsupported storage type: packed<{other:?}, 2>"),
            },
            gpu::StorageType::Packed(other, factor) => {
                unimplemented!("Unsupported storage type: packed<{other}, {factor}>")
            }
            gpu::StorageType::Opaque(ty) => match ty {
                gpu::OpaqueType::Barrier(level) => {
                    self.flags.op_barrier = true;
                    Elem::Barrier(level)
                }
            },
        }
    }

    fn compile_elem(&mut self, value: gpu::ElemType) -> Elem<D> {
        match value {
            gpu::ElemType::Float(kind) => match kind {
                gpu::FloatKind::E2M1 => {
                    self.flags.elem_fp4 = true;
                    Elem::FP4(FP4Kind::E2M1)
                }
                gpu::FloatKind::E2M3 => {
                    self.flags.elem_fp6 = true;
                    Elem::FP6(FP6Kind::E2M3)
                }
                gpu::FloatKind::E3M2 => {
                    self.flags.elem_fp6 = true;
                    Elem::FP6(FP6Kind::E3M2)
                }
                gpu::FloatKind::E4M3 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8(FP8Kind::E4M3)
                }
                gpu::FloatKind::E5M2 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8(FP8Kind::E5M2)
                }
                gpu::FloatKind::UE8M0 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8(FP8Kind::UE8M0)
                }
                gpu::FloatKind::F16 => {
                    self.flags.elem_f16 = true;
                    Elem::F16
                }
                gpu::FloatKind::BF16 => {
                    self.flags.elem_bf16 = true;
                    Elem::BF16
                }
                gpu::FloatKind::TF32 => Elem::TF32,
                gpu::FloatKind::Flex32 => Elem::F32,
                gpu::FloatKind::F32 => Elem::F32,
                gpu::FloatKind::F64 => Elem::F64,
            },
            gpu::ElemType::Int(kind) => match kind {
                gpu::IntKind::I8 => Elem::I8,
                gpu::IntKind::I16 => Elem::I16,
                gpu::IntKind::I32 => Elem::I32,
                gpu::IntKind::I64 => Elem::I64,
            },
            gpu::ElemType::UInt(kind) => match kind {
                gpu::UIntKind::U8 => Elem::U8,
                gpu::UIntKind::U16 => Elem::U16,
                gpu::UIntKind::U32 => Elem::U32,
                gpu::UIntKind::U64 => Elem::U64,
            },
            gpu::ElemType::Bool => Elem::Bool,
        }
    }
}

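/// Returns `true` for the sub-byte float kinds (fp4, fp6, and fp8 variants).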
fn is_fp4_fp6_fp8(elem: gpu::ElemType) -> bool {
    match elem {
        gpu::ElemType::Float(kind) => matches!(
            kind,
            FloatKind::E2M1
                | FloatKind::E2M3
                | FloatKind::E3M2
                | FloatKind::E4M3
                | FloatKind::E5M2
                | FloatKind::UE8M0
        ),
        _ => false,
    }
}

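/// Builds a `u32` constant variable, used for builtins that reduce to
/// compile-time constants (e.g. cluster dimensions).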
fn const_u32<D: Dialect>(value: u32) -> Variable<D> {
    Variable::Constant(
        gpu::ConstantValue::UInt(value as u64),
        Item::new(Elem::U32, 1, true),
    )
}

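/// Registers the address types plus the scalar and atomic element types that
/// every C++ target is expected to support.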
pub fn register_supported_types(props: &mut DeviceProperties) {
    props.register_address_type(gpu::AddressType::U32);
    props.register_address_type(gpu::AddressType::U64);

    let supported_types = [
        gpu::ElemType::UInt(gpu::UIntKind::U8),
        gpu::ElemType::UInt(gpu::UIntKind::U16),
        gpu::ElemType::UInt(gpu::UIntKind::U32),
        gpu::ElemType::UInt(gpu::UIntKind::U64),
        gpu::ElemType::Int(gpu::IntKind::I8),
        gpu::ElemType::Int(gpu::IntKind::I16),
        gpu::ElemType::Int(gpu::IntKind::I32),
        gpu::ElemType::Int(gpu::IntKind::I64),
        gpu::ElemType::Float(gpu::FloatKind::BF16),
        gpu::ElemType::Float(gpu::FloatKind::F16),
        gpu::ElemType::Float(gpu::FloatKind::F32),
        gpu::ElemType::Float(gpu::FloatKind::Flex32),
        gpu::ElemType::Bool,
    ];

    let supported_atomic_types = [
        gpu::ElemType::Int(gpu::IntKind::I32),
        gpu::ElemType::Int(gpu::IntKind::I64),
        gpu::ElemType::UInt(gpu::UIntKind::U32),
        gpu::ElemType::UInt(gpu::UIntKind::U64),
        gpu::ElemType::Float(gpu::FloatKind::F32),
    ];

    for ty in supported_types {
        props.register_type_usage(ty, TypeUsage::all_scalar());
    }

    for ty in supported_atomic_types {
        props.register_type_usage(
            gpu::StorageType::Atomic(ty),
            TypeUsage::AtomicAdd | TypeUsage::AtomicLoadStore,
        );
    }
}
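
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal sanity-check sketch: `is_fp4_fp6_fp8` should accept every
    // sub-byte float kind and reject full-width floats and non-float types.
    #[test]
    fn detects_sub_byte_float_kinds() {
        let sub_byte = [
            FloatKind::E2M1,
            FloatKind::E2M3,
            FloatKind::E3M2,
            FloatKind::E4M3,
            FloatKind::E5M2,
            FloatKind::UE8M0,
        ];
        for kind in sub_byte {
            assert!(is_fp4_fp6_fp8(gpu::ElemType::Float(kind)));
        }
        assert!(!is_fp4_fp6_fp8(gpu::ElemType::Float(FloatKind::F32)));
        assert!(!is_fp4_fp6_fp8(gpu::ElemType::Bool));
    }
}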