1use super::{
2 BinaryInstruction, Body, Component, ComputeKernel, ConstArray, Dialect, Elem, FP4Kind, FP6Kind,
3 FP8Kind, Fragment, FragmentIdent, FragmentLayout, IndexAssignInstruction, IndexInstruction,
4 Instruction, Item, KernelArg, LocalArray, SharedMemory, UnaryInstruction, Variable,
5 WarpInstruction, WmmaInstruction, barrier::BarrierOps, pipeline::PipelineOps,
6};
7use crate::shared::MmaShape;
8use cubecl_common::backtrace::BackTrace;
9use cubecl_core::{
10 CubeDim,
11 ir::{
12 self as gpu, DeviceProperties, ElemType, FloatKind, InstructionModes, OpaqueType,
13 Operation, Processor, SourceLoc, StorageType, Type,
14 features::{AtomicUsage, EnumSet, TypeUsage},
15 },
16 post_processing::checked_io::CheckedIoProcessor,
17 prelude::{FastMath, KernelDefinition},
18 server::ExecutionMode,
19};
20use cubecl_opt::{Optimizer, SharedLiveness};
21use cubecl_runtime::compiler::{CompilationError, Compiler};
22use std::{collections::HashSet, fmt::Debug};
23
/// Global counter used to generate unique names for temporary variables
/// during compilation. Reset to 0 at the end of `Compiler::compile` so each
/// kernel's temporaries start numbering from scratch.
pub(super) static COUNTER_TMP_VAR: std::sync::atomic::AtomicU32 =
    std::sync::atomic::AtomicU32::new(0);
26
/// Target-specific options supplied by the runtime when compiling a kernel.
#[derive(Clone, Debug)]
pub struct CompilationOptions {
    // Number of threads per warp/wavefront on the target (defaults to 32).
    pub warp_size: u32,
    // Hardware/toolchain capabilities the compiler may rely on.
    pub supports_features: CppSupportedFeatures,
}
32
/// Capability flags for the C++ target; all default to `false` (unsupported).
/// Consulted throughout compilation to gate optional codegen paths
/// (e.g. clusters are dropped in `compile` when `clusters` is false).
#[derive(Clone, Debug, Default)]
pub struct CppSupportedFeatures {
    // Kernel parameters can be passed as grid constants.
    pub grid_constants: bool,
    // Thread-block clusters are available.
    pub clusters: bool,
    pub fast_math: bool,
    pub fast_tanh: bool,
    // `elect.sync`-style warp election is available (see Plane::Elect lowering).
    pub elect_sync: bool,
}
41
42impl Default for CompilationOptions {
43 fn default() -> Self {
44 Self {
45 warp_size: 32,
46 supports_features: Default::default(),
47 }
48 }
49}
50
/// Records which builtin index/dimension variables the compiled kernel
/// references, so presumably only the required ones are emitted in the
/// generated source (final set is produced by `D::builtin_rules` in
/// `compile_ir`) — TODO confirm against the dialect implementations.
#[derive(Debug, Clone, Default)]
pub struct CubeIndexFlags {
    pub absolute_pos: bool,
    pub absolute_pos_tuple: bool,
    pub cube_count: bool,
    pub cube_count_tuple: bool,
    pub cube_dim: bool,
    pub cube_dim_tuple: bool,
    pub cube_pos: bool,
    pub cube_pos_tuple: bool,
    pub plane_dim: bool,
    // Set whenever any plane (warp) operation is compiled.
    pub plane_dim_checked: bool,
    pub plane_pos: bool,
    pub unit_pos: bool,
    pub unit_pos_tuple: bool,
    // Required by the inclusive/exclusive scan plane operations.
    pub unit_pos_plane: bool,
    pub cluster_pos: bool,
}
71
/// Feature requirements accumulated while lowering a kernel. The compiler
/// sets these as it encounters the corresponding elements/instructions, and
/// `compile_ir` snapshots them into the final [`ComputeKernel`].
#[derive(Debug, Clone)]
pub struct Flags<D: Dialect> {
    // Element types used by the kernel (gate type declarations/headers).
    pub elem_fp4: bool,
    pub elem_fp6: bool,
    pub elem_fp8: bool,
    pub elem_bf16: bool,
    pub elem_f16: bool,
    pub elem_tf32: bool,
    // Builtin index variables referenced by the kernel.
    pub indexes: CubeIndexFlags,
    // Instruction families used by the kernel.
    pub op_barrier: bool,
    pub op_pipeline: bool,
    pub inst_tma: bool,
    pub inst_tma_im2col: bool,
    pub inst_wmma: bool,
    pub inst_ptx_wrappers: bool,
    pub inst_async_copy: bool,
    // Launch/metadata configuration copied from the kernel definition.
    pub use_grid_constants: bool,
    pub static_meta_length: usize,
    pub has_dynamic_meta: bool,
    pub has_info: bool,
    pub cube_dim: CubeDim,
    pub cluster_dim: Option<CubeDim>,
    pub address_type: Item<D>,
}
97
// NOTE(review): `too_many_arguments` targets functions; on a struct this
// allow looks vestigial — confirm before removing.
#[allow(clippy::too_many_arguments)]
/// Stateful lowering from cubecl IR to the C++ representation. One instance
/// compiles one kernel: `Compiler::compile` fills in the options, then
/// `compile_ir` consumes a clone of the compiler to produce the kernel.
#[derive(Clone, Debug)]
pub struct CppCompiler<D: Dialect> {
    kernel_name: String,
    // Barrier/pipeline declarations gathered while compiling the body.
    barriers: Vec<BarrierOps<D>>,
    compilation_options: CompilationOptions,
    const_arrays: Vec<ConstArray<D>>,
    // For each buffer id (sorted), its position among extended-metadata
    // buffers; built by `build_metadata`.
    ext_meta_positions: Vec<u32>,
    cluster_dim: CubeDim,
    // Dialect-specific helper code registered for emitted instructions.
    extensions: Vec<D::Extension>,
    flags: Flags<D>,
    items: HashSet<Item<D>>,
    local_arrays: Vec<LocalArray<D>>,
    info: cubecl_core::Info,
    pipelines: Vec<PipelineOps<D>>,
    // Last source location emitted, to avoid duplicate line markers.
    source_loc: Option<SourceLoc>,
    strategy: ExecutionMode,
    addr_type: Item<D>,
}
117
118impl<D: Dialect> Default for Flags<D> {
119 fn default() -> Self {
120 Self {
121 elem_fp4: Default::default(),
122 elem_fp6: Default::default(),
123 elem_fp8: Default::default(),
124 elem_bf16: Default::default(),
125 elem_f16: Default::default(),
126 elem_tf32: Default::default(),
127 indexes: Default::default(),
128 op_barrier: Default::default(),
129 op_pipeline: Default::default(),
130 inst_tma: Default::default(),
131 inst_tma_im2col: Default::default(),
132 inst_wmma: Default::default(),
133 inst_ptx_wrappers: Default::default(),
134 inst_async_copy: Default::default(),
135 use_grid_constants: Default::default(),
136 static_meta_length: Default::default(),
137 has_info: Default::default(),
138 has_dynamic_meta: Default::default(),
139 cube_dim: CubeDim::new_single(),
140 cluster_dim: Default::default(),
141 address_type: Item::scalar(Elem::U32, true),
142 }
143 }
144}
145
146impl<D: Dialect> Default for CppCompiler<D> {
147 fn default() -> Self {
148 Self {
149 kernel_name: Default::default(),
150 barriers: Default::default(),
151 compilation_options: Default::default(),
152 const_arrays: Default::default(),
153 ext_meta_positions: Default::default(),
154 cluster_dim: CubeDim::new_single(),
155 extensions: Default::default(),
156 flags: Flags::default(),
157 items: Default::default(),
158 local_arrays: Default::default(),
159 info: Default::default(),
160 pipelines: Default::default(),
161 source_loc: Default::default(),
162 strategy: Default::default(),
163 addr_type: Item::scalar(Elem::U32, true),
164 }
165 }
166}
167
168impl<D: Dialect> Compiler for CppCompiler<D> {
169 type Representation = ComputeKernel<D>;
170 type CompilationOptions = CompilationOptions;
171
172 fn compile(
173 &mut self,
174 mut kernel: KernelDefinition,
175 compilation_options: &Self::CompilationOptions,
176 strategy: ExecutionMode,
177 addr_type: StorageType,
178 ) -> Result<Self::Representation, CompilationError> {
179 let errors = kernel.body.pop_errors();
180 if !errors.is_empty() {
181 let mut reason = "Can't compile cpp kernel\nCaused by:\n ".to_string();
182 for error in errors {
183 reason += error.as_str();
184 reason += "\n";
185 }
186
187 return Err(CompilationError::Validation {
188 reason,
189 backtrace: BackTrace::capture(),
190 });
191 }
192
193 self.addr_type = self.compile_type(addr_type.into());
194 self.compilation_options = compilation_options.clone();
195 self.strategy = strategy;
196 self.kernel_name = kernel.options.kernel_name.clone();
197
198 if !self.compilation_options.supports_features.clusters {
199 kernel.options.cluster_dim = None;
200 }
201 self.cluster_dim = kernel.options.cluster_dim.unwrap_or(CubeDim::new_single());
202
203 let ir = self.clone().compile_ir(kernel, addr_type);
204 COUNTER_TMP_VAR.store(0, std::sync::atomic::Ordering::Relaxed);
205 Ok(ir)
206 }
207
208 fn elem_size(&self, elem: gpu::ElemType) -> usize {
209 elem.size()
210 }
211
212 fn extension(&self) -> &'static str {
213 "cpp"
214 }
215}
216
217impl<D: Dialect> CppCompiler<D> {
    /// Consumes the compiler and lowers a whole kernel definition into a
    /// [`ComputeKernel`].
    ///
    /// Order matters: `compile_scope` runs first so the feature flags
    /// (`self.flags`) and the auxiliary lists (pipelines, barriers, const
    /// and local arrays) are populated before the final `Flags` and `Body`
    /// are assembled from them.
    fn compile_ir(
        mut self,
        value: KernelDefinition,
        address_type: StorageType,
    ) -> ComputeKernel<D> {
        let metadata = self.build_metadata(&value);
        self.info = cubecl_core::Info::new(&value.scalars, metadata, address_type);

        // The body is cloned because it is consumed again below by the
        // shared-memory liveness analysis.
        let instructions = self.compile_scope(&mut value.body.clone());
        let tensor_maps = value
            .tensor_maps
            .into_iter()
            .map(|b| self.compile_binding(b))
            .collect();
        let buffers = value
            .buffers
            .into_iter()
            .map(|b| self.compile_binding(b))
            .collect();
        let scalars = value
            .scalars
            .into_iter()
            .map(|binding| (self.compile_storage_type(binding.ty), binding.count))
            .collect::<Vec<_>>();

        // Snapshot of the flags accumulated while compiling the scope; the
        // builtin-index set is finalized by the dialect's rules.
        let flags = Flags {
            indexes: D::builtin_rules(&self.flags.indexes),
            inst_wmma: self.flags.inst_wmma,
            op_pipeline: self.flags.op_pipeline,
            op_barrier: self.flags.op_barrier,
            elem_fp4: self.flags.elem_fp4,
            elem_fp6: self.flags.elem_fp6,
            elem_fp8: self.flags.elem_fp8,
            elem_bf16: self.flags.elem_bf16,
            elem_f16: self.flags.elem_f16,
            elem_tf32: self.flags.elem_tf32,
            inst_tma: self.flags.inst_tma,
            inst_tma_im2col: self.flags.inst_tma_im2col,
            inst_async_copy: self.flags.inst_async_copy,
            inst_ptx_wrappers: self.flags.inst_ptx_wrappers,
            use_grid_constants: self.compilation_options.supports_features.grid_constants,
            has_info: self.info.has_info(),
            has_dynamic_meta: self.info.has_dynamic_meta,
            static_meta_length: self.info.metadata.static_len() as usize,
            cube_dim: value.cube_dim,
            cluster_dim: value.options.cluster_dim,
            address_type: self.compile_type(address_type.into()),
        };

        // Shared-memory allocations and their offsets come from the
        // optimizer's liveness analysis over the (moved) original body.
        let mut opt = Optimizer::shared_only(value.body, value.cube_dim);
        let shared_allocs = opt.analysis::<SharedLiveness>();
        let shared_memories = shared_allocs
            .allocations
            .values()
            .map(|alloc| match alloc.smem {
                cubecl_opt::SharedMemory::Array {
                    id,
                    length,
                    ty,
                    align,
                } => SharedMemory::Array {
                    index: id,
                    item: self.compile_type(ty),
                    length,
                    align,
                    offset: alloc.offset,
                },
                cubecl_opt::SharedMemory::Value { id, ty, align } => SharedMemory::Value {
                    index: id,
                    item: self.compile_type(ty),
                    align,
                    offset: alloc.offset,
                },
            })
            .collect();

        let body = Body {
            instructions,
            shared_memories,
            pipelines: self.pipelines,
            barriers: self.barriers,
            const_arrays: self.const_arrays,
            local_arrays: self.local_arrays,
            // Without grid constants, the info buffer is passed by pointer.
            info_by_ptr: !self.compilation_options.supports_features.grid_constants,
            has_dynamic_meta: self.info.has_dynamic_meta,
            address_type: self.addr_type,
        };

        // NOTE(review): `compile` already clears cluster_dim when clusters
        // are unsupported; this re-check keeps `compile_ir` safe on its own.
        let mut cluster_dim = value.options.cluster_dim;
        if !self.compilation_options.supports_features.clusters {
            cluster_dim = None;
        }

        ComputeKernel {
            tensor_maps,
            buffers,
            scalars,
            meta_static_len: self.info.metadata.static_len() as usize,
            cube_dim: value.cube_dim,
            body,
            extensions: self.extensions,
            flags,
            items: self.items,
            kernel_name: value.options.kernel_name,
            cluster_dim,
            info: self.info.clone(),
        }
    }
327
328 fn build_metadata(&mut self, value: &KernelDefinition) -> cubecl_core::Metadata {
329 let mut num_ext = 0;
330
331 let mut all_meta: Vec<_> = value
332 .buffers
333 .iter()
334 .chain(value.tensor_maps.iter())
335 .map(|buf| (buf.id, buf.has_extended_meta))
336 .collect();
337
338 all_meta.sort_by_key(|(id, _)| *id);
339
340 for (_, has_extended_meta) in &all_meta {
341 self.ext_meta_positions.push(num_ext);
342 if *has_extended_meta {
343 num_ext += 1;
344 }
345 }
346
347 let num_meta = all_meta.len();
348
349 cubecl_core::Metadata::new(num_meta as u32, num_ext)
350 }
351
352 pub(crate) fn ext_meta_position(&self, var: gpu::Variable) -> u32 {
353 let id = var.index().expect("Variable should have index");
354 self.ext_meta_positions[id as usize]
355 }
356
357 fn compile_scope(&mut self, scope: &mut gpu::Scope) -> Vec<Instruction<D>> {
358 let mut instructions = Vec::new();
359
360 let const_arrays = scope
361 .const_arrays
362 .drain(..)
363 .map(|(var, values)| ConstArray {
364 index: var.index().unwrap(),
365 item: self.compile_type(var.ty),
366 size: values.len() as u32,
367 values: values
368 .into_iter()
369 .map(|val| self.compile_variable(val))
370 .collect(),
371 })
372 .collect::<Vec<_>>();
373 self.const_arrays.extend(const_arrays);
374
375 let checked_io: Box<dyn Processor> = Box::new(CheckedIoProcessor::new(
376 self.strategy,
377 self.kernel_name.clone(),
378 ));
379 let dialect_processors = D::processors();
380 let mut processors: Vec<&dyn Processor> = vec![&*checked_io];
381 processors.extend(dialect_processors.iter().map(|it| &**it));
382
383 let processing = scope.process(processors);
384
385 for var in processing.variables {
386 instructions.push(Instruction::DeclareVariable {
387 var: self.compile_variable(var),
388 });
389 }
390
391 processing
392 .instructions
393 .into_iter()
394 .for_each(|op| self.compile_instruction(&mut instructions, op));
395
396 instructions
397 }
398
    /// Lowers one IR instruction into zero or more C++ instructions appended
    /// to `instructions`, updating the feature flags (`self.flags`) that the
    /// final code generation depends on.
    fn compile_instruction(
        &mut self,
        instructions: &mut Vec<Instruction<D>>,
        instruction: gpu::Instruction,
    ) {
        // Emit a `#line`-style marker when the source location changes.
        self.update_debug_loc(instructions, &instruction);
        let out = instruction.out;
        match instruction.operation {
            gpu::Operation::Copy(variable) => {
                instructions.push(Instruction::Assign(UnaryInstruction {
                    input: self.compile_variable(variable),
                    out: self.compile_variable(out.unwrap()),
                }));
            }
            gpu::Operation::Arithmetic(op) => {
                self.compile_arithmetic(op, out, instruction.modes, instructions)
            }
            gpu::Operation::Comparison(op) => self.compile_comparison(op, out, instructions),
            gpu::Operation::Bitwise(op) => self.compile_bitwise(op, out, instructions),
            gpu::Operation::Operator(op) => self.compile_operator(op, out, instructions),
            gpu::Operation::Atomic(op) => self.compile_atomic(op, out, instructions),
            gpu::Operation::Metadata(op) => instructions.push(self.compile_metadata(op, out)),
            gpu::Operation::Branch(val) => self.compile_branch(instructions, val),
            gpu::Operation::Synchronization(val) => match val {
                gpu::Synchronization::SyncCube => instructions.push(Instruction::SyncThreads),
                gpu::Synchronization::SyncPlane => instructions.push(Instruction::SyncWarp),
                // Storage sync maps to a full thread sync on this target.
                gpu::Synchronization::SyncStorage => instructions.push(Instruction::SyncThreads),
                gpu::Synchronization::SyncAsyncProxyShared => {
                    self.flags.inst_tma = true;
                    instructions.push(Instruction::ProxyAsyncToSharedFence)
                }
            },
            gpu::Operation::Plane(op) => {
                // Every plane (warp) op needs the checked plane dim builtin.
                self.flags.indexes.plane_dim_checked = true;
                let out = self.compile_variable(out.unwrap());
                match op {
                    gpu::Plane::Sum(op) => {
                        let instruction = WarpInstruction::ReduceSum {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction));
                    }
                    // Scan operations additionally need the unit's lane index.
                    gpu::Plane::InclusiveSum(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::InclusiveSum {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::InclusiveProd(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::InclusiveProd {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::ExclusiveSum(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::ExclusiveSum {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::ExclusiveProd(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::ExclusiveProd {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Prod(op) => {
                        let instruction = WarpInstruction::ReduceProd {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Max(op) => {
                        let instruction = WarpInstruction::ReduceMax {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Min(op) => {
                        let instruction = WarpInstruction::ReduceMin {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    // Uses the hardware elect instruction when available,
                    // otherwise a software fallback.
                    gpu::Plane::Elect => {
                        if self.compilation_options.supports_features.elect_sync {
                            self.flags.inst_ptx_wrappers = true;
                            instructions.push(Instruction::Warp(WarpInstruction::Elect { out }))
                        } else {
                            instructions
                                .push(Instruction::Warp(WarpInstruction::ElectFallback { out }))
                        }
                    }
                    gpu::Plane::All(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::All {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Any(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Any {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Ballot(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Ballot {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Broadcast(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Broadcast {
                            input: self.compile_variable(op.lhs),
                            id: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::Shuffle(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Shuffle {
                            input: self.compile_variable(op.lhs),
                            src_lane: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::ShuffleXor(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleXor {
                            input: self.compile_variable(op.lhs),
                            mask: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::ShuffleUp(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleUp {
                            input: self.compile_variable(op.lhs),
                            delta: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::ShuffleDown(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleDown {
                            input: self.compile_variable(op.lhs),
                            delta: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                }
            }
            gpu::Operation::CoopMma(cmma) => instructions.push(self.compile_cmma(cmma, out)),
            gpu::Operation::NonSemantic(debug) => match debug {
                gpu::NonSemantic::Print {
                    format_string,
                    args,
                } => instructions.push(Instruction::Printf {
                    format_string,
                    args: args
                        .into_iter()
                        .map(|arg| self.compile_variable(arg))
                        .collect(),
                }),
                gpu::NonSemantic::Comment { content } => {
                    instructions.push(Instruction::Comment { content })
                }
                // Other non-semantic operations are dropped.
                _ => {}
            },
            gpu::Operation::Barrier(barrier_ops) => match barrier_ops {
                gpu::BarrierOps::Declare { barrier } => {
                    // The barrier level is encoded in the variable's type.
                    let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type()
                    else {
                        unreachable!()
                    };
                    let barrier = self.compile_variable(barrier);
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Declare {
                        barrier,
                        level,
                    }));
                }
                gpu::BarrierOps::Init {
                    barrier,
                    is_elected,
                    arrival_count,
                } => {
                    let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type()
                    else {
                        unreachable!()
                    };
                    let barrier = self.compile_variable(barrier);
                    let arrival_count = self.compile_variable(arrival_count);
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Init {
                        barrier,
                        is_elected: self.compile_variable(is_elected),
                        arrival_count,
                        level,
                    }));
                }
                gpu::BarrierOps::InitManual {
                    barrier,
                    arrival_count,
                } => {
                    let barrier = self.compile_variable(barrier);
                    let arrival_count = self.compile_variable(arrival_count);
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::InitManual {
                            barrier,
                            arrival_count,
                        },
                    ));
                }
                // The copy destination is the instruction's `out` variable.
                gpu::BarrierOps::MemCopyAsync {
                    barrier,
                    source,
                    source_length,
                    offset_source,
                    offset_out,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsync {
                            barrier: self.compile_variable(barrier),
                            source: self.compile_variable(source),
                            destination: self.compile_variable(out.unwrap()),
                            source_length: self.compile_variable(source_length),
                            offset_source: self.compile_variable(offset_source),
                            offset_out: self.compile_variable(offset_out),
                            cooperative: false,
                        },
                    ));
                }
                // Same lowering as MemCopyAsync, with `cooperative: true`.
                gpu::BarrierOps::MemCopyAsyncCooperative {
                    barrier,
                    source,
                    source_length,
                    offset_source,
                    offset_out,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsync {
                            barrier: self.compile_variable(barrier),
                            source: self.compile_variable(source),
                            destination: self.compile_variable(out.unwrap()),
                            source_length: self.compile_variable(source_length),
                            offset_source: self.compile_variable(offset_source),
                            offset_out: self.compile_variable(offset_out),
                            cooperative: true,
                        },
                    ));
                }
                gpu::BarrierOps::MemCopyAsyncTx {
                    barrier,
                    source,
                    source_length,
                    offset_source,
                    offset_out,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsyncTx {
                            barrier: self.compile_variable(barrier),
                            source: self.compile_variable(source),
                            destination: self.compile_variable(out.unwrap()),
                            source_length: self.compile_variable(source_length),
                            offset_source: self.compile_variable(offset_source),
                            offset_out: self.compile_variable(offset_out),
                        },
                    ));
                }
                gpu::BarrierOps::CopyAsync {
                    source,
                    source_length,
                    offset_source,
                    offset_out,
                    copy_length,
                    checked,
                } => {
                    self.flags.inst_async_copy = true;
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::CopyAsync {
                            source: self.compile_variable(source),
                            destination: self.compile_variable(out.unwrap()),
                            source_length: self.compile_variable(source_length),
                            offset_source: self.compile_variable(offset_source),
                            offset_out: self.compile_variable(offset_out),
                            copy_size: copy_length,
                            checked,
                        },
                    ));
                }
                gpu::BarrierOps::TmaLoad {
                    barrier,
                    tensor_map,
                    offset_out,
                    indices,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsyncTensorGlobalToShared {
                            barrier: self.compile_variable(barrier),
                            smem_buffer: self.compile_variable(out.unwrap()),
                            smem_offset: self.compile_variable(offset_out),
                            tensor_map: self.compile_variable(tensor_map),
                            indices: indices
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        },
                    ));
                }
                gpu::BarrierOps::TmaLoadIm2col {
                    barrier,
                    tensor_map,
                    offset_out,
                    indices,
                    offsets,
                } => {
                    self.flags.inst_tma_im2col = true;
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::TmaLoadIm2col {
                            barrier: self.compile_variable(barrier),
                            smem_buffer: self.compile_variable(out.unwrap()),
                            smem_offset: self.compile_variable(offset_out),
                            tensor_map: self.compile_variable(tensor_map),
                            indices: indices
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                            offsets: offsets
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        },
                    ));
                }
                gpu::BarrierOps::Arrive { barrier } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Arrive {
                        barrier: self.compile_variable(barrier),
                        token: self.compile_variable(out.unwrap()),
                    }))
                }
                gpu::BarrierOps::ArriveTx {
                    barrier,
                    arrive_count_update,
                    transaction_count_update,
                } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::ArriveTx {
                        barrier: self.compile_variable(barrier),
                        token: self.compile_variable(out.unwrap()),
                        arrive_count_update: self.compile_variable(arrive_count_update),
                        transaction_count_update: self.compile_variable(transaction_count_update),
                    }))
                }
                gpu::BarrierOps::CommitCopyAsync { barrier } => {
                    self.flags.inst_async_copy = true;
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::ArriveCopyAsync {
                            barrier: self.compile_variable(barrier),
                        },
                    ))
                }
                gpu::BarrierOps::ExpectTx {
                    barrier,
                    transaction_count_update,
                } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::ExpectTx {
                        barrier: self.compile_variable(barrier),
                        transaction_count_update: self.compile_variable(transaction_count_update),
                    }))
                }
                gpu::BarrierOps::Wait { barrier, token } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Wait {
                        barrier: self.compile_variable(barrier),
                        token: self.compile_variable(token),
                    }))
                }
                gpu::BarrierOps::WaitParity { barrier, phase } => instructions.push(
                    Instruction::Barrier(super::barrier::BarrierOps::WaitParity {
                        barrier: self.compile_variable(barrier),
                        phase: self.compile_variable(phase),
                    }),
                ),
                gpu::BarrierOps::ArriveAndWait { barrier } => {
                    let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type()
                    else {
                        unreachable!()
                    };
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::ArriveAndWait {
                            barrier: self.compile_variable(barrier),
                            level,
                        },
                    ))
                }
            },
            gpu::Operation::Tma(tma_ops) => {
                // Any TMA operation requires the TMA feature.
                self.flags.inst_tma = true;
                match tma_ops {
                    gpu::TmaOps::TmaStore {
                        source,
                        coordinates,
                        offset_source,
                    } => {
                        instructions.push(Instruction::MemCopyAsyncTensorSharedToGlobal {
                            smem_buffer: self.compile_variable(source),
                            smem_offset: self.compile_variable(offset_source),
                            tensor_map: self.compile_variable(out.unwrap()),
                            indices: coordinates
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        });
                    }
                    gpu::TmaOps::CommitGroup => {
                        instructions.push(Instruction::BulkCommitGroup);
                    }
                    gpu::TmaOps::WaitGroup { max_pending } => {
                        instructions.push(Instruction::BulkWaitGroup { max_pending });
                    }
                    gpu::TmaOps::WaitGroupRead { max_pending } => {
                        instructions.push(Instruction::BulkWaitGroupRead { max_pending });
                    }
                }
            }
            // Markers carry no code.
            gpu::Operation::Marker(_) => {}
        }
    }
833
834 fn update_debug_loc(
835 &mut self,
836 instructions: &mut Vec<Instruction<D>>,
837 inst: &gpu::Instruction,
838 ) {
839 if !matches!(inst.operation, Operation::NonSemantic(_)) {
840 match &inst.source_loc {
841 Some(loc) if Some(loc) != self.source_loc.as_ref() => {
842 self.source_loc = Some(loc.clone());
843 instructions.push(Instruction::Line {
844 file: loc.source.file.clone(),
845 line: loc.line,
846 });
847 }
848 _ => {}
849 }
850 }
851 }
852
    /// Lowers a cooperative-matrix (CMMA) operation into a WMMA instruction.
    /// Always marks the WMMA feature as required and registers any
    /// dialect-specific extension for the emitted instruction.
    fn compile_cmma(&mut self, cmma: gpu::CoopMma, out: Option<gpu::Variable>) -> Instruction<D> {
        self.flags.inst_wmma = true;

        let out = self.compile_variable(out.unwrap());

        let inst = match cmma {
            gpu::CoopMma::Fill { value } => WmmaInstruction::Fill {
                frag: out,
                value: self.compile_variable(value),
            },
            gpu::CoopMma::Load {
                value,
                stride,
                offset,
                layout,
            } => WmmaInstruction::Load {
                frag: out,
                offset: self.compile_variable(offset),
                value: self.compile_variable(value),
                stride: self.compile_variable(stride),
                // Layout is optional on load; unknown layouts stay `None`.
                layout: layout.and_then(|l| self.compile_matrix_layout(l)),
            },
            gpu::CoopMma::Execute {
                mat_a,
                mat_b,
                mat_c,
            } => WmmaInstruction::Execute {
                frag_a: self.compile_variable(mat_a),
                frag_b: self.compile_variable(mat_b),
                frag_c: self.compile_variable(mat_c),
                frag_d: out,
                warp_size: self.compilation_options.warp_size,
            },
            gpu::CoopMma::ExecuteManual {
                matrix,
                registers_a,
                registers_b,
                registers_c,
            } => WmmaInstruction::ExecuteManual {
                shape: MmaShape::new(matrix.m as u32, matrix.n as u32, matrix.k as u32),
                frag_a: self.compile_variable(registers_a),
                frag_b: self.compile_variable(registers_b),
                frag_c: self.compile_variable(registers_c),
                frag_d: out,
            },
            gpu::CoopMma::ExecuteScaled {
                matrix,
                registers_a,
                registers_b,
                registers_c,
                scales_a,
                scales_b,
                scales_factor,
            } => WmmaInstruction::ExecuteScaled {
                shape: MmaShape::new(matrix.m as u32, matrix.n as u32, matrix.k as u32),
                frag_a: self.compile_variable(registers_a),
                frag_b: self.compile_variable(registers_b),
                frag_c: self.compile_variable(registers_c),
                frag_d: out,

                scales_a: self.compile_variable(scales_a),
                scales_b: self.compile_variable(scales_b),
                scales_factor: scales_factor as u32,
            },
            gpu::CoopMma::Store {
                mat,
                stride,
                offset,
                layout,
            } => {
                // Store needs per-unit indexing builtins in the generated code.
                self.flags.indexes.unit_pos = true;
                self.flags.indexes.plane_pos = true;
                WmmaInstruction::Store {
                    output: out,
                    offset: self.compile_variable(offset),
                    frag: self.compile_variable(mat),
                    stride: self.compile_variable(stride),
                    // Unlike load, a store must have a known layout.
                    layout: self
                        .compile_matrix_layout(layout)
                        .expect("Layout required for store instruction"),
                }
            }
            gpu::CoopMma::LoadMatrix {
                buffer,
                offset,
                vector_size,
                factor,
                transpose,
            } => WmmaInstruction::LdMatrix {
                output: out,
                buffer: self.compile_variable(buffer),
                offset: self.compile_variable(offset),
                vector_size,
                factor: factor as u32,
                transpose,
            },
            gpu::CoopMma::StoreMatrix {
                offset,
                vector_size,
                registers,
                factor,
                transpose,
            } => WmmaInstruction::StMatrix {
                registers: self.compile_variable(registers),
                buffer: out,
                offset: self.compile_variable(offset),
                vector_size,
                factor: factor as u32,
                transpose,
            },
            gpu::CoopMma::Cast { input } => WmmaInstruction::Cast {
                input: self.compile_variable(input),
                output: out,
            },
            // These variants are expected to be rewritten away before
            // reaching the compiler (see the scope processors).
            gpu::CoopMma::RowIndex { .. } | gpu::CoopMma::ColIndex { .. } => {
                panic!("Row/Col index should be handled by processors")
            }
        };

        D::register_wmma_instruction_extension(&mut self.extensions, &inst);

        Instruction::Wmma(inst)
    }
976
977 fn compile_metadata(
978 &mut self,
979 metadata: gpu::Metadata,
980 out: Option<gpu::Variable>,
981 ) -> Instruction<D> {
982 let out = out.unwrap();
983 match metadata {
984 gpu::Metadata::Stride { dim, var } => {
985 let position = self.ext_meta_position(var);
986 let offset = self.info.metadata.stride_offset_index(position);
987 Instruction::ExtendedMetadata {
988 info_offset: self.compile_variable(offset.into()),
989 dim: self.compile_variable(dim),
990 out: self.compile_variable(out),
991 }
992 }
993 gpu::Metadata::Shape { dim, var } => {
994 let position = self.ext_meta_position(var);
995 let offset = self.info.metadata.shape_offset_index(position);
996 Instruction::ExtendedMetadata {
997 info_offset: self.compile_variable(offset.into()),
998 dim: self.compile_variable(dim),
999 out: self.compile_variable(out),
1000 }
1001 }
1002 gpu::Metadata::Rank { var } => {
1003 let out = self.compile_variable(out);
1004 let pos = self.ext_meta_position(var);
1005 let offset = self.info.metadata.rank_index(pos);
1006 super::Instruction::Metadata {
1007 info_offset: self.compile_variable(offset.into()),
1008 out,
1009 }
1010 }
1011 gpu::Metadata::Length { var } => {
1012 let input = self.compile_variable(var);
1013 let out = self.compile_variable(out);
1014
1015 match input {
1016 Variable::Slice { .. } => Instruction::SliceLength { input, out },
1017 Variable::SharedArray(_id, _item, length) => {
1018 Instruction::ConstLength { length, out }
1019 }
1020 _ => {
1021 let id = input.id().expect("Variable should have id");
1022 let offset = self.info.metadata.len_index(id);
1023 Instruction::Metadata {
1024 info_offset: self.compile_variable(offset.into()),
1025 out,
1026 }
1027 }
1028 }
1029 }
1030 gpu::Metadata::BufferLength { var } => {
1031 let input = self.compile_variable(var);
1032 let out = self.compile_variable(out);
1033
1034 match input {
1035 Variable::Slice { .. } => Instruction::SliceLength { input, out },
1036 _ => {
1037 let id = input.id().expect("Variable should have id");
1038 let offset = self.info.metadata.buffer_len_index(id);
1039 Instruction::Metadata {
1040 info_offset: self.compile_variable(offset.into()),
1041 out,
1042 }
1043 }
1044 }
1045 }
1046 }
1047 }
1048
1049 fn compile_branch(&mut self, instructions: &mut Vec<Instruction<D>>, branch: gpu::Branch) {
1050 match branch {
1051 gpu::Branch::If(mut op) => instructions.push(Instruction::If {
1052 cond: self.compile_variable(op.cond),
1053 instructions: self.compile_scope(&mut op.scope),
1054 }),
1055 gpu::Branch::IfElse(mut op) => instructions.push(Instruction::IfElse {
1056 cond: self.compile_variable(op.cond),
1057 instructions_if: self.compile_scope(&mut op.scope_if),
1058 instructions_else: self.compile_scope(&mut op.scope_else),
1059 }),
1060 gpu::Branch::Switch(mut op) => instructions.push(Instruction::Switch {
1061 value: self.compile_variable(op.value),
1062 instructions_default: self.compile_scope(&mut op.scope_default),
1063 instructions_cases: op
1064 .cases
1065 .into_iter()
1066 .map(|(val, mut block)| {
1067 (self.compile_variable(val), self.compile_scope(&mut block))
1068 })
1069 .collect(),
1070 }),
1071 gpu::Branch::Return => instructions.push(Instruction::Return),
1072 gpu::Branch::Break => instructions.push(Instruction::Break),
1073 gpu::Branch::Unreachable => instructions.push(Instruction::Unreachable),
1074 gpu::Branch::RangeLoop(mut range_loop) => instructions.push(Instruction::RangeLoop {
1075 i: self.compile_variable(range_loop.i),
1076 start: self.compile_variable(range_loop.start),
1077 end: self.compile_variable(range_loop.end),
1078 step: range_loop.step.map(|it| self.compile_variable(it)),
1079 inclusive: range_loop.inclusive,
1080 instructions: self.compile_scope(&mut range_loop.scope),
1081 }),
1082 gpu::Branch::Loop(mut op) => instructions.push(Instruction::Loop {
1083 instructions: self.compile_scope(&mut op.scope),
1084 }),
1085 };
1086 }
1087
1088 fn compile_atomic(
1089 &mut self,
1090 value: gpu::AtomicOp,
1091 out: Option<gpu::Variable>,
1092 instructions: &mut Vec<Instruction<D>>,
1093 ) {
1094 let out = out.unwrap();
1095 match value {
1096 gpu::AtomicOp::Load(op) => {
1097 instructions.push(Instruction::AtomicLoad(self.compile_unary(op, out)))
1098 }
1099 gpu::AtomicOp::Store(op) => {
1100 instructions.push(Instruction::AtomicStore(self.compile_unary(op, out)))
1101 }
1102 gpu::AtomicOp::Swap(op) => {
1103 instructions.push(Instruction::AtomicSwap(self.compile_binary(op, out)))
1104 }
1105 gpu::AtomicOp::Add(op) => {
1106 instructions.push(Instruction::AtomicAdd(self.compile_binary(op, out)))
1107 }
1108 gpu::AtomicOp::Sub(op) => {
1109 instructions.push(Instruction::AtomicSub(self.compile_binary(op, out)))
1110 }
1111 gpu::AtomicOp::Max(op) => {
1112 instructions.push(Instruction::AtomicMax(self.compile_binary(op, out)))
1113 }
1114 gpu::AtomicOp::Min(op) => {
1115 instructions.push(Instruction::AtomicMin(self.compile_binary(op, out)))
1116 }
1117 gpu::AtomicOp::And(op) => {
1118 instructions.push(Instruction::AtomicAnd(self.compile_binary(op, out)))
1119 }
1120 gpu::AtomicOp::Or(op) => {
1121 instructions.push(Instruction::AtomicOr(self.compile_binary(op, out)))
1122 }
1123 gpu::AtomicOp::Xor(op) => {
1124 instructions.push(Instruction::AtomicXor(self.compile_binary(op, out)))
1125 }
1126 gpu::AtomicOp::CompareAndSwap(op) => instructions.push(Instruction::AtomicCAS {
1127 input: self.compile_variable(op.input),
1128 cmp: self.compile_variable(op.cmp),
1129 val: self.compile_variable(op.val),
1130 out: self.compile_variable(out),
1131 }),
1132 }
1133 }
1134
    /// Lower a GPU arithmetic operation into one or more C++ instructions.
    ///
    /// Float-heavy operations (div, exp, log, trig, sqrt, pow, normalize,
    /// magnitude, recip) are routed through `select_fast_float`, which
    /// substitutes a `Fast*` intrinsic variant when the instruction's
    /// fast-math modes allow it. Several transcendental/helper ops also
    /// register a dialect extension so the backend can emit the required
    /// support code.
    fn compile_arithmetic(
        &mut self,
        value: gpu::Arithmetic,
        out: Option<gpu::Variable>,
        modes: InstructionModes,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        // Arithmetic always produces a result; `out` is never `None` here.
        let out = out.unwrap();
        match value {
            gpu::Arithmetic::Add(op) => {
                instructions.push(Instruction::Add(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::SaturatingAdd(op) => {
                instructions.push(Instruction::SaturatingAdd(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Mul(op) => {
                instructions.push(Instruction::Mul(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Div(op) => {
                let op = self.compile_binary(op, out);
                // Fast division needs all reciprocal-related relaxations enabled.
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::AllowReciprocal
                        | FastMath::ReducedPrecision
                        | FastMath::UnsignedZero
                        | FastMath::NotInf,
                    Instruction::Div(op),
                    Instruction::FastDiv(op),
                ))
            }
            gpu::Arithmetic::Sub(op) => {
                instructions.push(Instruction::Sub(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::SaturatingSub(op) => {
                instructions.push(Instruction::SaturatingSub(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::MulHi(op) => {
                let instruction = Instruction::HiMul(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Modulo(op) => {
                instructions.push(Instruction::Modulo(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Abs(op) => {
                instructions.push(Instruction::Abs(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Exp(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Exp(op),
                    Instruction::FastExp(op),
                ));
            }
            gpu::Arithmetic::Log(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Log(op),
                    Instruction::FastLog(op),
                ));
            }
            gpu::Arithmetic::Log1p(op) => {
                instructions.push(Instruction::Log1p(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Cos(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Cos(op),
                    Instruction::FastCos(op),
                ));
            }
            gpu::Arithmetic::Sin(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Sin(op),
                    Instruction::FastSin(op),
                ));
            }
            gpu::Arithmetic::Tan(op) => {
                instructions.push(Instruction::Tan(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Tanh(op) => {
                let op = self.compile_unary(op, out);
                let instruction = Instruction::Tanh(op);
                D::register_instruction_extension(&mut self.extensions, &instruction);
                // FastTanh is only considered when the target reports support
                // for it; otherwise the precise variant is emitted unconditionally.
                if self.compilation_options.supports_features.fast_tanh {
                    instructions.push(self.select_fast_float(
                        out.ty,
                        modes,
                        FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                        instruction,
                        Instruction::FastTanh(op),
                    ))
                } else {
                    instructions.push(instruction);
                }
            }
            gpu::Arithmetic::Sinh(op) => {
                let instruction = Instruction::Sinh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Cosh(op) => {
                let instruction = Instruction::Cosh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcCos(op) => {
                let instruction = Instruction::ArcCos(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcSin(op) => {
                let instruction = Instruction::ArcSin(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcTan(op) => {
                let instruction = Instruction::ArcTan(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcSinh(op) => {
                let instruction = Instruction::ArcSinh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcCosh(op) => {
                let instruction = Instruction::ArcCosh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcTanh(op) => {
                let instruction = Instruction::ArcTanh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Degrees(op) => {
                let instruction = Instruction::Degrees(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Radians(op) => {
                let instruction = Instruction::Radians(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcTan2(op) => {
                let instruction = Instruction::ArcTan2(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Powf(op) => {
                let op = self.compile_binary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Powf(op),
                    Instruction::FastPowf(op),
                ))
            }
            gpu::Arithmetic::Powi(op) => {
                instructions.push(Instruction::Powi(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Hypot(op) => {
                instructions.push(Instruction::Hypot(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Rhypot(op) => {
                instructions.push(Instruction::Rhypot(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Sqrt(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Sqrt(op),
                    Instruction::FastSqrt(op),
                ))
            }
            gpu::Arithmetic::InverseSqrt(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::InverseSqrt(op),
                    Instruction::FastInverseSqrt(op),
                ))
            }
            gpu::Arithmetic::Erf(op) => {
                let instruction = Instruction::Erf(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Max(op) => {
                let instruction = Instruction::Max(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Min(op) => {
                let instruction = Instruction::Min(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Clamp(op) => instructions.push(Instruction::Clamp {
                input: self.compile_variable(op.input),
                min_value: self.compile_variable(op.min_value),
                max_value: self.compile_variable(op.max_value),
                out: self.compile_variable(out),
            }),
            gpu::Arithmetic::Recip(op) => {
                // Reciprocal is lowered either as a precise `1 / x` division or,
                // under the reciprocal fast-math flags, as a FastRecip intrinsic.
                let elem = op.input.ty.elem_type();
                let input = self.compile_variable(op.input);
                let out = self.compile_variable(out);
                // Build the type-appropriate constant "one" for the numerator.
                let lhs = match elem {
                    gpu::ElemType::Float(_) => gpu::ConstantValue::Float(1.0),
                    gpu::ElemType::Int(_) => gpu::ConstantValue::Int(1),
                    gpu::ElemType::UInt(_) => gpu::ConstantValue::UInt(1),
                    gpu::ElemType::Bool => gpu::ConstantValue::Bool(true),
                };
                let div = Instruction::Div(BinaryInstruction {
                    lhs: Variable::Constant(lhs, self.compile_type(op.input.ty)),
                    rhs: input,
                    out,
                });
                let recip = Instruction::FastRecip(UnaryInstruction { input, out });

                instructions.push(self.select_fast_float(
                    elem.into(),
                    modes,
                    FastMath::AllowReciprocal
                        | FastMath::ReducedPrecision
                        | FastMath::UnsignedZero
                        | FastMath::NotInf,
                    div,
                    recip,
                ))
            }
            gpu::Arithmetic::Round(op) => {
                instructions.push(Instruction::Round(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Floor(op) => {
                instructions.push(Instruction::Floor(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Ceil(op) => {
                instructions.push(Instruction::Ceil(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Trunc(op) => {
                instructions.push(Instruction::Trunc(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Remainder(op) => {
                instructions.push(Instruction::Remainder(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Fma(op) => instructions.push(Instruction::Fma {
                a: self.compile_variable(op.a),
                b: self.compile_variable(op.b),
                c: self.compile_variable(op.c),
                out: self.compile_variable(out),
            }),
            gpu::Arithmetic::Neg(op) => {
                instructions.push(Instruction::Neg(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Normalize(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Normalize(op),
                    Instruction::FastNormalize(op),
                ))
            }
            gpu::Arithmetic::Magnitude(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Magnitude(op),
                    Instruction::FastMagnitude(op),
                ))
            }
            gpu::Arithmetic::Dot(op) => {
                instructions.push(Instruction::Dot(self.compile_binary(op, out)))
            }
        };
    }
1437
1438 fn select_fast_float(
1439 &self,
1440 ty: gpu::Type,
1441 modes: InstructionModes,
1442 required_flags: EnumSet<FastMath>,
1443 default: Instruction<D>,
1444 fast: Instruction<D>,
1445 ) -> Instruction<D> {
1446 if !self.compilation_options.supports_features.fast_math
1447 || !matches!(ty.elem_type(), ElemType::Float(FloatKind::F32))
1448 {
1449 return default;
1450 }
1451
1452 if modes.fp_math_mode.is_superset(required_flags) {
1453 fast
1454 } else {
1455 default
1456 }
1457 }
1458
1459 fn compile_comparison(
1460 &mut self,
1461 value: gpu::Comparison,
1462 out: Option<gpu::Variable>,
1463 instructions: &mut Vec<Instruction<D>>,
1464 ) {
1465 let out = out.unwrap();
1466 match value {
1467 gpu::Comparison::Equal(op) => {
1468 instructions.push(Instruction::Equal(self.compile_binary(op, out)))
1469 }
1470 gpu::Comparison::Lower(op) => {
1471 instructions.push(Instruction::Lower(self.compile_binary(op, out)))
1472 }
1473 gpu::Comparison::Greater(op) => {
1474 instructions.push(Instruction::Greater(self.compile_binary(op, out)))
1475 }
1476 gpu::Comparison::LowerEqual(op) => {
1477 instructions.push(Instruction::LowerEqual(self.compile_binary(op, out)))
1478 }
1479 gpu::Comparison::GreaterEqual(op) => {
1480 instructions.push(Instruction::GreaterEqual(self.compile_binary(op, out)))
1481 }
1482 gpu::Comparison::NotEqual(op) => {
1483 instructions.push(Instruction::NotEqual(self.compile_binary(op, out)))
1484 }
1485 gpu::Comparison::IsNan(op) => {
1486 instructions.push(Instruction::IsNan(self.compile_unary(op, out)))
1487 }
1488 gpu::Comparison::IsInf(op) => {
1489 instructions.push(Instruction::IsInf(self.compile_unary(op, out)))
1490 }
1491 };
1492 }
1493
1494 fn compile_bitwise(
1495 &mut self,
1496 value: gpu::Bitwise,
1497 out: Option<gpu::Variable>,
1498 instructions: &mut Vec<Instruction<D>>,
1499 ) {
1500 let out = out.unwrap();
1501 match value {
1502 gpu::Bitwise::BitwiseOr(op) => {
1503 instructions.push(Instruction::BitwiseOr(self.compile_binary(op, out)))
1504 }
1505 gpu::Bitwise::BitwiseAnd(op) => {
1506 instructions.push(Instruction::BitwiseAnd(self.compile_binary(op, out)))
1507 }
1508 gpu::Bitwise::BitwiseXor(op) => {
1509 instructions.push(Instruction::BitwiseXor(self.compile_binary(op, out)))
1510 }
1511 gpu::Bitwise::CountOnes(op) => {
1512 instructions.push(Instruction::CountBits(self.compile_unary(op, out)))
1513 }
1514 gpu::Bitwise::ReverseBits(op) => {
1515 instructions.push(Instruction::ReverseBits(self.compile_unary(op, out)))
1516 }
1517 gpu::Bitwise::ShiftLeft(op) => {
1518 instructions.push(Instruction::ShiftLeft(self.compile_binary(op, out)))
1519 }
1520 gpu::Bitwise::ShiftRight(op) => {
1521 instructions.push(Instruction::ShiftRight(self.compile_binary(op, out)))
1522 }
1523 gpu::Bitwise::BitwiseNot(op) => {
1524 instructions.push(Instruction::BitwiseNot(self.compile_unary(op, out)))
1525 }
1526 gpu::Bitwise::LeadingZeros(op) => {
1527 instructions.push(Instruction::LeadingZeros(self.compile_unary(op, out)))
1528 }
1529 gpu::Bitwise::TrailingZeros(op) => {
1530 instructions.push(Instruction::TrailingZeros(self.compile_unary(op, out)))
1531 }
1532 gpu::Bitwise::FindFirstSet(op) => {
1533 let instruction = Instruction::FindFirstSet(self.compile_unary(op, out));
1534 D::register_instruction_extension(&mut self.extensions, &instruction);
1535 instructions.push(instruction)
1536 }
1537 };
1538 }
1539
    /// Lower a general GPU operator (indexing, logic, vector init, memory
    /// copies, select, casts, reinterpret) into C++ instructions.
    ///
    /// Checked and unchecked index variants compile identically here; bounds
    /// checking is assumed to be handled before this stage (e.g. by the
    /// checked-IO processor) — TODO confirm against the pipeline setup.
    fn compile_operator(
        &mut self,
        value: gpu::Operator,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        // Operators always produce a result; `out` is never `None` here.
        let out = out.unwrap();
        match value {
            gpu::Operator::Index(op) | gpu::Operator::UncheckedIndex(op) => {
                instructions.push(Instruction::Index(self.compile_index(op, out)));
            }
            gpu::Operator::IndexAssign(op) | gpu::Operator::UncheckedIndexAssign(op) => {
                instructions.push(Instruction::IndexAssign(self.compile_index_assign(op, out)));
            }
            gpu::Operator::And(op) => {
                instructions.push(Instruction::And(self.compile_binary(op, out)))
            }
            gpu::Operator::Or(op) => {
                instructions.push(Instruction::Or(self.compile_binary(op, out)))
            }
            gpu::Operator::Not(op) => {
                instructions.push(Instruction::Not(self.compile_unary(op, out)))
            }
            gpu::Operator::InitVector(op) => instructions.push(Instruction::VecInit {
                inputs: op
                    .inputs
                    .into_iter()
                    .map(|it| self.compile_variable(it))
                    .collect(),
                out: self.compile_variable(out),
            }),
            gpu::Operator::CopyMemory(op) => instructions.push(Instruction::Copy {
                input: self.compile_variable(op.input),
                in_index: self.compile_variable(op.in_index),
                out: self.compile_variable(out),
                out_index: self.compile_variable(op.out_index),
            }),
            gpu::Operator::CopyMemoryBulk(op) => instructions.push(Instruction::CopyBulk {
                input: self.compile_variable(op.input),
                in_index: self.compile_variable(op.in_index),
                out: self.compile_variable(out),
                out_index: self.compile_variable(op.out_index),
                len: op.len as u32,
            }),
            gpu::Operator::Select(op) => instructions.push(Instruction::Select {
                cond: self.compile_variable(op.cond),
                then: self.compile_variable(op.then),
                or_else: self.compile_variable(op.or_else),
                out: self.compile_variable(out),
            }),
            // Conversions to/from sub-byte floats (fp4/fp6/fp8) cannot be
            // expressed as a plain assignment cast: they are emitted as a
            // SpecialCast, and all half-precision item types the conversion
            // code may need (f16/bf16 at both the input width and the packing
            // width) are pre-registered here. This guarded arm must stay
            // before the general `Cast` arm below.
            gpu::Operator::Cast(op)
                if (is_fp4_fp6_fp8(op.input.elem_type()) || is_fp4_fp6_fp8(out.elem_type()))
                    && op.input.elem_type() != out.elem_type() =>
            {
                self.flags.elem_f16 = true;
                self.flags.elem_bf16 = true;
                let vec_in = op.input.ty.vector_size();
                let packing = out.storage_type().packing_factor();
                self.compile_type(op.input.ty.with_vector_size(packing));
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::F16))
                        .with_vector_size(vec_in),
                );
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::BF16))
                        .with_vector_size(vec_in),
                );
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::F16))
                        .with_vector_size(packing),
                );
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::BF16))
                        .with_vector_size(packing),
                );

                let inst = self.compile_unary(op, out);

                instructions.push(Instruction::SpecialCast(inst));
            }
            // Ordinary casts compile down to an assignment; touching TF32 on
            // either side flips the flag so TF32 support gets emitted.
            gpu::Operator::Cast(op) => {
                let op = self.compile_unary(op, out);

                if op.input.elem() == Elem::TF32 || op.out.elem() == Elem::TF32 {
                    self.flags.elem_tf32 = true;
                }

                instructions.push(Instruction::Assign(op))
            }
            gpu::Operator::Reinterpret(op) => {
                instructions.push(Instruction::Bitcast(self.compile_unary(op, out)))
            }
        };
    }
1637
1638 fn compile_binary(
1639 &mut self,
1640 value: gpu::BinaryOperator,
1641 out: gpu::Variable,
1642 ) -> BinaryInstruction<D> {
1643 BinaryInstruction {
1644 lhs: self.compile_variable(value.lhs),
1645 rhs: self.compile_variable(value.rhs),
1646 out: self.compile_variable(out),
1647 }
1648 }
1649
1650 fn compile_index_assign(
1651 &mut self,
1652 value: gpu::IndexAssignOperator,
1653 out: gpu::Variable,
1654 ) -> IndexAssignInstruction<D> {
1655 IndexAssignInstruction {
1656 index: self.compile_variable(value.index),
1657 value: self.compile_variable(value.value),
1658 vector_size: value.vector_size as u32,
1659 out: self.compile_variable(out),
1660 }
1661 }
1662
1663 fn compile_index(
1664 &mut self,
1665 value: gpu::IndexOperator,
1666 out: gpu::Variable,
1667 ) -> IndexInstruction<D> {
1668 IndexInstruction {
1669 list: self.compile_variable(value.list),
1670 index: self.compile_variable(value.index),
1671 vector_size: value.vector_size as u32,
1672 out: self.compile_variable(out),
1673 }
1674 }
1675
1676 fn compile_unary(
1677 &mut self,
1678 value: gpu::UnaryOperator,
1679 out: gpu::Variable,
1680 ) -> UnaryInstruction<D> {
1681 UnaryInstruction {
1682 input: self.compile_variable(value.input),
1683 out: self.compile_variable(out),
1684 }
1685 }
1686
    /// Translate an IR variable into the dialect's variable representation.
    ///
    /// Besides the mapping itself, this records side effects on `self`:
    /// kernel flags (TMA, WMMA, pipelines, barriers, builtin-index usage),
    /// item registration via `compile_type`, and one-time declarations for
    /// local arrays and pipelines.
    fn compile_variable(&mut self, value: gpu::Variable) -> Variable<D> {
        let item = value.ty;
        match value.kind {
            gpu::VariableKind::GlobalInputArray(id) => {
                Variable::GlobalInputArray(id, self.compile_type(item))
            }
            gpu::VariableKind::GlobalScalar(id) => Variable::GlobalScalar {
                id,
                elem: self.compile_storage_type(item.storage_type()),
            },
            gpu::VariableKind::TensorMapInput(id) => {
                self.flags.inst_tma = true;
                Variable::TensorMap(id)
            }
            gpu::VariableKind::TensorMapOutput(id) => {
                self.flags.inst_tma = true;
                Variable::TensorMap(id)
            }
            gpu::VariableKind::LocalMut { id } => Variable::LocalMut {
                id,
                item: self.compile_type(item),
            },
            // SSA-versioned variables are flattened back to plain mutable locals.
            gpu::VariableKind::Versioned { id, .. } => Variable::LocalMut {
                id,
                item: self.compile_type(item),
            },
            gpu::VariableKind::LocalConst { id } => Variable::LocalConst {
                id,
                item: self.compile_type(item),
            },
            gpu::VariableKind::GlobalOutputArray(id) => {
                Variable::GlobalOutputArray(id, self.compile_type(item))
            }
            gpu::VariableKind::Constant(value) => {
                Variable::Constant(value, self.compile_type(item))
            }
            gpu::VariableKind::SharedArray { id, length, .. } => {
                let item = self.compile_type(item);
                Variable::SharedArray(id, item, length)
            }
            gpu::VariableKind::Shared { id } => {
                let item = self.compile_type(item);
                Variable::Shared(id, item)
            }
            gpu::VariableKind::ConstantArray {
                id,
                length,
                unroll_factor,
            } => {
                let item = self.compile_type(item);
                Variable::ConstantArray(id, item, length * unroll_factor)
            }
            gpu::VariableKind::Builtin(builtin) => match builtin {
                // Each builtin flips the matching index flag so the kernel
                // prelude only computes the indices that are actually used.
                gpu::Builtin::AbsolutePos => {
                    self.flags.indexes.absolute_pos = true;
                    let item = self.compile_type(item);
                    Variable::AbsolutePos(item.elem)
                }
                // Cluster builtins: the guarded arms apply when the target
                // supports thread-block clusters; otherwise the fallback arm
                // below lowers them to the constant 0.
                gpu::Builtin::CubePosCluster
                    if self.compilation_options.supports_features.clusters =>
                {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterRank
                }
                gpu::Builtin::CubePosClusterX
                    if self.compilation_options.supports_features.clusters =>
                {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexX
                }
                gpu::Builtin::CubePosClusterY
                    if self.compilation_options.supports_features.clusters =>
                {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexY
                }
                gpu::Builtin::CubePosClusterZ
                    if self.compilation_options.supports_features.clusters =>
                {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexZ
                }
                gpu::Builtin::CubePosCluster
                | gpu::Builtin::CubePosClusterX
                | gpu::Builtin::CubePosClusterY
                | gpu::Builtin::CubePosClusterZ => const_u32(0),
                gpu::Builtin::AbsolutePosX => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosX
                }
                gpu::Builtin::AbsolutePosY => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosY
                }
                gpu::Builtin::AbsolutePosZ => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosZ
                }
                gpu::Builtin::CubeDim => {
                    self.flags.indexes.cube_dim = true;
                    Variable::CubeDim
                }
                gpu::Builtin::CubeDimX => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimX
                }
                gpu::Builtin::CubeDimY => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimY
                }
                gpu::Builtin::CubeDimZ => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimZ
                }
                // Cluster dimensions are compile-time constants baked from the
                // configured cluster shape.
                gpu::Builtin::CubeClusterDim => const_u32(self.cluster_dim.num_elems()),
                gpu::Builtin::CubeClusterDimX => const_u32(self.cluster_dim.x),
                gpu::Builtin::CubeClusterDimY => const_u32(self.cluster_dim.y),
                gpu::Builtin::CubeClusterDimZ => const_u32(self.cluster_dim.z),
                gpu::Builtin::CubePos => {
                    self.flags.indexes.cube_pos = true;
                    let item = self.compile_type(item);
                    Variable::CubePos(item.elem)
                }
                gpu::Builtin::CubePosX => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosX
                }
                gpu::Builtin::CubePosY => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosY
                }
                gpu::Builtin::CubePosZ => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosZ
                }
                gpu::Builtin::CubeCount => {
                    self.flags.indexes.cube_count = true;
                    let item = self.compile_type(item);
                    Variable::CubeCount(item.elem)
                }
                gpu::Builtin::CubeCountX => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountX
                }
                gpu::Builtin::CubeCountY => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountY
                }
                gpu::Builtin::CubeCountZ => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountZ
                }
                gpu::Builtin::UnitPos => {
                    self.flags.indexes.unit_pos = true;
                    Variable::UnitPos
                }
                gpu::Builtin::UnitPosX => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosX
                }
                gpu::Builtin::UnitPosY => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosY
                }
                gpu::Builtin::UnitPosZ => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosZ
                }
                gpu::Builtin::PlaneDim => {
                    self.flags.indexes.plane_dim = true;
                    Variable::PlaneDim
                }
                gpu::Builtin::PlanePos => {
                    self.flags.indexes.plane_pos = true;
                    Variable::PlanePos
                }
                gpu::Builtin::UnitPosPlane => {
                    self.flags.indexes.unit_pos_plane = true;
                    Variable::UnitPosPlane
                }
            },
            gpu::VariableKind::LocalArray {
                id,
                length,
                unroll_factor,
            } => {
                let item = self.compile_type(item);
                // Declare the backing array once; the declaration is sized
                // with the unroll factor while the returned variable keeps the
                // logical length — presumably indexing accounts for the
                // unrolling downstream; TODO confirm.
                if !self.local_arrays.iter().any(|s| s.index == id) {
                    self.local_arrays
                        .push(LocalArray::new(id, item, length * unroll_factor));
                }
                Variable::LocalArray(id, item, length)
            }
            gpu::VariableKind::Matrix { id, mat } => {
                self.flags.inst_wmma = true;
                Variable::WmmaFragment {
                    id,
                    frag: self.compile_matrix(mat),
                }
            }
            gpu::VariableKind::Pipeline { id, num_stages } => {
                self.flags.op_pipeline = true;
                let pipeline = Variable::Pipeline { id };
                // Register the pipeline initialization exactly once per id.
                if !self.pipelines.iter().any(|s| s.pipeline_id() == id) {
                    self.pipelines.push(PipelineOps::Init {
                        pipeline,
                        num_stages,
                    });
                }
                pipeline
            }
            gpu::VariableKind::BarrierToken { id, level } => {
                self.flags.op_barrier = true;
                Variable::BarrierToken { id, level }
            }
        }
    }
1906
1907 fn compile_matrix(&mut self, matrix: gpu::Matrix) -> Fragment<D> {
1908 Fragment {
1909 ident: self.compile_matrix_ident(matrix.ident),
1910 m: matrix.m as u32,
1911 n: matrix.n as u32,
1912 k: matrix.k as u32,
1913 elem: self.compile_storage_type(matrix.storage),
1914 layout: self.compile_matrix_layout(matrix.layout),
1915 }
1916 }
1917
1918 fn compile_matrix_ident(&mut self, ident: gpu::MatrixIdent) -> FragmentIdent<D> {
1919 match ident {
1920 gpu::MatrixIdent::A => FragmentIdent::A,
1921 gpu::MatrixIdent::B => FragmentIdent::B,
1922 gpu::MatrixIdent::Accumulator => FragmentIdent::Accumulator,
1923 }
1924 }
1925
1926 fn compile_matrix_layout(&mut self, layout: gpu::MatrixLayout) -> Option<FragmentLayout<D>> {
1927 match layout {
1928 gpu::MatrixLayout::ColMajor => Some(FragmentLayout::ColMajor),
1929 gpu::MatrixLayout::RowMajor => Some(FragmentLayout::RowMajor),
1930 gpu::MatrixLayout::Undefined => None,
1931 }
1932 }
1933
1934 fn compile_binding(&mut self, binding: cubecl_runtime::kernel::KernelArg) -> KernelArg<D> {
1935 KernelArg {
1936 id: binding.id,
1937 item: self.compile_type(binding.ty),
1938 size: binding.size,
1939 vis: binding.visibility,
1940 }
1941 }
1942
1943 fn compile_type(&mut self, ty: gpu::Type) -> Item<D> {
1944 let item = match ty {
1945 gpu::Type::Scalar(ty) => Item::new(self.compile_storage_type(ty), 1, false),
1946 gpu::Type::Vector(ty, vector_size) => {
1947 Item::new(self.compile_storage_type(ty), vector_size, false)
1948 }
1949 gpu::Type::Semantic(_) => Item::new(Elem::Bool, 1, true),
1950 };
1951 if item.elem != super::Elem::TF32 {
1952 self.items.insert(item);
1953 self.items.insert(item.optimized());
1954 } else {
1955 let mut item = item;
1957 item.elem = super::Elem::F32;
1958 self.items.insert(item);
1959 }
1960
1961 item
1962 }
1963
1964 fn compile_storage_type(&mut self, value: gpu::StorageType) -> Elem<D> {
1965 match value {
1966 gpu::StorageType::Scalar(ty) => self.compile_elem(ty),
1967 gpu::StorageType::Atomic(ty) => Elem::Atomic(ty.into()),
1968 gpu::StorageType::Packed(gpu::ElemType::Float(kind), 2) => match kind {
1969 FloatKind::E2M1 => {
1970 self.flags.elem_fp4 = true;
1971 Elem::FP4x2(FP4Kind::E2M1)
1972 }
1973 FloatKind::E2M3 => {
1974 self.flags.elem_fp6 = true;
1975 Elem::FP6x2(FP6Kind::E2M3)
1976 }
1977 FloatKind::E3M2 => {
1978 self.flags.elem_fp6 = true;
1979 Elem::FP6(FP6Kind::E3M2)
1980 }
1981 FloatKind::E4M3 => {
1982 self.flags.elem_fp8 = true;
1983 Elem::FP8x2(FP8Kind::E4M3)
1984 }
1985 FloatKind::E5M2 => {
1986 self.flags.elem_fp8 = true;
1987 Elem::FP8x2(FP8Kind::E5M2)
1988 }
1989 FloatKind::UE8M0 => {
1990 self.flags.elem_fp8 = true;
1991 Elem::FP8x2(FP8Kind::UE8M0)
1992 }
1993 FloatKind::F16 => {
1994 self.flags.elem_f16 = true;
1995 Elem::F16x2
1996 }
1997 FloatKind::BF16 => {
1998 self.flags.elem_bf16 = true;
1999 Elem::BF16x2
2000 }
2001 other => unimplemented!("Unsupported storage type: packed<{other:?}, 2>"),
2002 },
2003 gpu::StorageType::Packed(other, factor) => {
2004 unimplemented!("Unsupported storage type: packed<{other}, {factor}>")
2005 }
2006 gpu::StorageType::Opaque(ty) => match ty {
2007 gpu::OpaqueType::Barrier(level) => {
2008 self.flags.op_barrier = true;
2009 Elem::Barrier(level)
2010 }
2011 },
2012 }
2013 }
2014
2015 fn compile_elem(&mut self, value: gpu::ElemType) -> Elem<D> {
2016 match value {
2017 gpu::ElemType::Float(kind) => match kind {
2018 gpu::FloatKind::E2M1 => {
2019 self.flags.elem_fp4 = true;
2020 Elem::FP4(FP4Kind::E2M1)
2021 }
2022 gpu::FloatKind::E2M3 => {
2023 self.flags.elem_fp6 = true;
2024 Elem::FP6(FP6Kind::E2M3)
2025 }
2026 gpu::FloatKind::E3M2 => {
2027 self.flags.elem_fp6 = true;
2028 Elem::FP6(FP6Kind::E3M2)
2029 }
2030 gpu::FloatKind::E4M3 => {
2031 self.flags.elem_fp8 = true;
2032 Elem::FP8(FP8Kind::E4M3)
2033 }
2034 gpu::FloatKind::E5M2 => {
2035 self.flags.elem_fp8 = true;
2036 Elem::FP8(FP8Kind::E5M2)
2037 }
2038 gpu::FloatKind::UE8M0 => {
2039 self.flags.elem_fp8 = true;
2040 Elem::FP8(FP8Kind::UE8M0)
2041 }
2042 gpu::FloatKind::F16 => {
2043 self.flags.elem_f16 = true;
2044 Elem::F16
2045 }
2046 gpu::FloatKind::BF16 => {
2047 self.flags.elem_bf16 = true;
2048 Elem::BF16
2049 }
2050 gpu::FloatKind::TF32 => Elem::TF32,
2051 gpu::FloatKind::Flex32 => Elem::F32,
2052 gpu::FloatKind::F32 => Elem::F32,
2053 gpu::FloatKind::F64 => Elem::F64,
2054 },
2055 gpu::ElemType::Int(kind) => match kind {
2056 gpu::IntKind::I8 => Elem::I8,
2057 gpu::IntKind::I16 => Elem::I16,
2058 gpu::IntKind::I32 => Elem::I32,
2059 gpu::IntKind::I64 => Elem::I64,
2060 },
2061 gpu::ElemType::UInt(kind) => match kind {
2062 gpu::UIntKind::U8 => Elem::U8,
2063 gpu::UIntKind::U16 => Elem::U16,
2064 gpu::UIntKind::U32 => Elem::U32,
2065 gpu::UIntKind::U64 => Elem::U64,
2066 },
2067 gpu::ElemType::Bool => Elem::Bool,
2068 }
2069 }
2070}
2071
2072fn is_fp4_fp6_fp8(elem: gpu::ElemType) -> bool {
2073 match elem {
2074 gpu::ElemType::Float(kind) => matches!(
2075 kind,
2076 FloatKind::E2M1
2077 | FloatKind::E2M3
2078 | FloatKind::E3M2
2079 | FloatKind::E4M3
2080 | FloatKind::E5M2
2081 | FloatKind::UE8M0
2082 ),
2083 _ => false,
2084 }
2085}
2086
2087fn const_u32<D: Dialect>(value: u32) -> Variable<D> {
2088 Variable::Constant(
2089 gpu::ConstantValue::UInt(value as u64),
2090 Item::new(Elem::U32, 1, true),
2091 )
2092}
2093
2094pub fn register_supported_types(props: &mut DeviceProperties) {
2095 props.register_address_type(gpu::AddressType::U32);
2096 props.register_address_type(gpu::AddressType::U64);
2097
2098 let supported_types = [
2099 gpu::ElemType::UInt(gpu::UIntKind::U8),
2100 gpu::ElemType::UInt(gpu::UIntKind::U16),
2101 gpu::ElemType::UInt(gpu::UIntKind::U32),
2102 gpu::ElemType::UInt(gpu::UIntKind::U64),
2103 gpu::ElemType::Int(gpu::IntKind::I8),
2104 gpu::ElemType::Int(gpu::IntKind::I16),
2105 gpu::ElemType::Int(gpu::IntKind::I32),
2106 gpu::ElemType::Int(gpu::IntKind::I64),
2107 gpu::ElemType::Float(gpu::FloatKind::BF16),
2108 gpu::ElemType::Float(gpu::FloatKind::F16),
2109 gpu::ElemType::Float(gpu::FloatKind::F32),
2110 gpu::ElemType::Float(gpu::FloatKind::Flex32),
2111 gpu::ElemType::Bool,
2114 ];
2115
2116 let supported_atomic_types = [
2117 gpu::ElemType::Int(gpu::IntKind::I32),
2118 gpu::ElemType::Int(gpu::IntKind::I64),
2119 gpu::ElemType::UInt(gpu::UIntKind::U32),
2120 gpu::ElemType::UInt(gpu::UIntKind::U64),
2121 gpu::ElemType::Float(gpu::FloatKind::F32),
2122 ];
2123
2124 for ty in supported_types {
2125 props.register_type_usage(ty, TypeUsage::all());
2126 }
2127
2128 for ty in supported_atomic_types {
2129 props.register_atomic_type_usage(
2130 Type::new(gpu::StorageType::Atomic(ty)),
2131 AtomicUsage::Add | AtomicUsage::LoadStore,
2132 );
2133 }
2134}