use cubecl_common::backtrace::BackTrace;
use cubecl_core::ir::{self as gpu, OpaqueType, StorageType};
use cubecl_core::ir::{FloatKind, InstructionModes, Processor, UIntKind};
use cubecl_core::post_processing::checked_io::CheckedIoProcessor;
use cubecl_core::server::ExecutionMode;
use cubecl_core::{CubeDim, ir::ElemType};
use cubecl_core::{
    ir::{Operation, SourceLoc},
    prelude::{FastMath, KernelDefinition},
};
use cubecl_opt::{Optimizer, SharedLiveness};
use cubecl_runtime::compiler::CompilationError;
use cubecl_runtime::{DeviceProperties, EnumSet, TypeUsage, compiler::Compiler};
use std::{collections::HashSet, fmt::Debug};

use crate::shared::MmaShape;

use super::{
    BinaryInstruction, Binding, Body, Component, ComputeKernel, ConstArray, Dialect, Elem, FP6Kind,
    Fragment, FragmentIdent, FragmentLayout, IndexAssignInstruction, IndexInstruction, Instruction,
    Item, LocalArray, SharedMemory, UnaryInstruction, Variable, WarpInstruction, WmmaInstruction,
};
use super::{FP4Kind, barrier::BarrierOps};
use super::{FP8Kind, pipeline::PipelineOps};

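/// Counter used to generate unique names for compiler-created temporary
/// variables; it is reset to 0 at the end of each `compile` call.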
pub(super) static COUNTER_TMP_VAR: std::sync::atomic::AtomicU32 =
    std::sync::atomic::AtomicU32::new(0);

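/// Target-specific compilation options: the plane (warp) size and the set of
/// optional features the target supports.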
#[derive(Clone, Debug)]
pub struct CompilationOptions {
    pub warp_size: u32,
    pub supports_features: CppSupportedFeatures,
}

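/// Optional target features; each flag gates the corresponding code generation
/// path in the C++ backend.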
#[derive(Clone, Debug, Default)]
pub struct CppSupportedFeatures {
    pub grid_constants: bool,
    pub clusters: bool,
    pub fast_math: bool,
    pub fast_tanh: bool,
    pub elect_sync: bool,
}

impl Default for CompilationOptions {
    fn default() -> Self {
        Self {
            warp_size: 32,
            supports_features: Default::default(),
        }
    }
}

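/// Records which builtin index variables the kernel actually reads, so only the
/// needed indexes are declared in the generated source.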
#[derive(Debug, Clone, Default)]
pub struct CubeIndexFlags {
    pub absolute_pos: bool,
    pub absolute_pos_tuple: bool,
    pub cube_count: bool,
    pub cube_count_tuple: bool,
    pub cube_dim: bool,
    pub cube_dim_tuple: bool,
    pub cube_pos: bool,
    pub cube_pos_tuple: bool,
    pub plane_dim: bool,
    pub plane_dim_checked: bool,
    pub plane_index: bool,
    pub unit_pos: bool,
    pub unit_pos_tuple: bool,
    pub unit_pos_plane: bool,
    pub cluster_pos: bool,
}

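/// Kernel-wide flags gathered during lowering: which element types, barrier and
/// pipeline operations, and special instructions the kernel relies on.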
#[derive(Debug, Clone)]
pub struct Flags {
    pub elem_fp4: bool,
    pub elem_fp6: bool,
    pub elem_fp8: bool,
    pub elem_bf16: bool,
    pub elem_f16: bool,
    pub elem_tf32: bool,
    pub indexes: CubeIndexFlags,
    pub op_barrier: bool,
    pub op_pipeline: bool,
    pub inst_tma: bool,
    pub inst_tma_im2col: bool,
    pub inst_wmma: bool,
    pub inst_ptx_wrappers: bool,
    pub inst_async_copy: bool,
    pub use_grid_constants: bool,
    pub static_meta_length: usize,
    pub has_dynamic_meta: bool,
    pub cube_dim: CubeDim,
    pub cluster_dim: Option<CubeDim>,
}

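/// Compiler that lowers CubeCL IR into a C++ kernel representation for the
/// dialect `D`.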
#[allow(clippy::too_many_arguments)]
#[derive(Clone, Debug)]
pub struct CppCompiler<D: Dialect> {
    barriers: Vec<BarrierOps<D>>,
    compilation_options: CompilationOptions,
    const_arrays: Vec<ConstArray<D>>,
    ext_meta_positions: Vec<u32>,
    cluster_dim: CubeDim,
    extensions: Vec<D::Extension>,
    flags: Flags,
    items: HashSet<Item<D>>,
    local_arrays: Vec<LocalArray<D>>,
    metadata: cubecl_core::Metadata,
    pipelines: Vec<PipelineOps<D>>,
    source_loc: Option<SourceLoc>,
    strategy: ExecutionMode,
}

impl Default for Flags {
    fn default() -> Self {
        Self {
            elem_fp4: Default::default(),
            elem_fp6: Default::default(),
            elem_fp8: Default::default(),
            elem_bf16: Default::default(),
            elem_f16: Default::default(),
            elem_tf32: Default::default(),
            indexes: Default::default(),
            op_barrier: Default::default(),
            op_pipeline: Default::default(),
            inst_tma: Default::default(),
            inst_tma_im2col: Default::default(),
            inst_wmma: Default::default(),
            inst_ptx_wrappers: Default::default(),
            inst_async_copy: Default::default(),
            use_grid_constants: Default::default(),
            static_meta_length: Default::default(),
            has_dynamic_meta: Default::default(),
            cube_dim: CubeDim::new_single(),
            cluster_dim: Default::default(),
        }
    }
}

impl<D: Dialect> Default for CppCompiler<D> {
    fn default() -> Self {
        Self {
            barriers: Default::default(),
            compilation_options: Default::default(),
            const_arrays: Default::default(),
            ext_meta_positions: Default::default(),
            cluster_dim: CubeDim::new_single(),
            extensions: Default::default(),
            flags: Flags::default(),
            items: Default::default(),
            local_arrays: Default::default(),
            metadata: Default::default(),
            pipelines: Default::default(),
            source_loc: Default::default(),
            strategy: Default::default(),
        }
    }
}

impl<D: Dialect> Compiler for CppCompiler<D> {
    type Representation = ComputeKernel<D>;
    type CompilationOptions = CompilationOptions;

    fn compile(
        &mut self,
        mut kernel: KernelDefinition,
        compilation_options: &Self::CompilationOptions,
        strategy: ExecutionMode,
    ) -> Result<Self::Representation, CompilationError> {
        let errors = kernel.body.pop_errors();
        if !errors.is_empty() {
            let mut reason = "Can't compile cpp kernel\nCaused by:\n ".to_string();
            for error in errors {
                reason += error.as_str();
                reason += "\n";
            }

            return Err(CompilationError::Validation {
                reason,
                backtrace: BackTrace::capture(),
            });
        }

        self.compilation_options = compilation_options.clone();
        self.strategy = strategy;

        if !self.compilation_options.supports_features.clusters {
            kernel.options.cluster_dim = None;
        }
        self.cluster_dim = kernel.options.cluster_dim.unwrap_or(CubeDim::new_single());

        let ir = self.clone().compile_ir(kernel);
        COUNTER_TMP_VAR.store(0, std::sync::atomic::Ordering::Relaxed);
        Ok(ir)
    }

    fn elem_size(&self, elem: gpu::ElemType) -> usize {
        elem.size()
    }

    fn extension(&self) -> &'static str {
        "cpp"
    }
}

impl<D: Dialect> CppCompiler<D> {
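    /// Lowers a kernel definition into the C++ representation: compiles the body,
    /// bindings, and scalars, snapshots the gathered usage flags, and lays out
    /// shared memory according to the shared-liveness analysis.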
    fn compile_ir(mut self, value: KernelDefinition) -> ComputeKernel<D> {
        self.build_metadata(&value);

        let instructions = self.compile_scope(&mut value.body.clone());
        let tensor_maps = value
            .tensor_maps
            .into_iter()
            .map(|b| self.compile_binding(b))
            .collect();
        let buffers = value
            .buffers
            .into_iter()
            .map(|b| self.compile_binding(b))
            .collect();
        let scalars = value
            .scalars
            .into_iter()
            .map(|binding| (self.compile_storage_type(binding.ty), binding.count))
            .collect();

        let flags = Flags {
            indexes: D::builtin_rules(&self.flags.indexes),
            inst_wmma: self.flags.inst_wmma,
            op_pipeline: self.flags.op_pipeline,
            op_barrier: self.flags.op_barrier,
            elem_fp4: self.flags.elem_fp4,
            elem_fp6: self.flags.elem_fp6,
            elem_fp8: self.flags.elem_fp8,
            elem_bf16: self.flags.elem_bf16,
            elem_f16: self.flags.elem_f16,
            elem_tf32: self.flags.elem_tf32,
            inst_tma: self.flags.inst_tma,
            inst_tma_im2col: self.flags.inst_tma_im2col,
            inst_async_copy: self.flags.inst_async_copy,
            inst_ptx_wrappers: self.flags.inst_ptx_wrappers,
            use_grid_constants: self.compilation_options.supports_features.grid_constants,
            has_dynamic_meta: self.metadata.static_len() > 0,
            static_meta_length: self.metadata.static_len() as usize,
            cube_dim: value.cube_dim,
            cluster_dim: value.options.cluster_dim,
        };

        let mut opt = Optimizer::shared_only(value.body, value.cube_dim);
        let shared_allocs = opt.analysis::<SharedLiveness>();
        let shared_memories = shared_allocs
            .allocations
            .values()
            .map(|alloc| match alloc.smem {
                cubecl_opt::SharedMemory::Array {
                    id,
                    length,
                    ty,
                    align,
                } => SharedMemory::Array {
                    index: id,
                    item: self.compile_type(ty),
                    length,
                    align,
                    offset: alloc.offset,
                },
                cubecl_opt::SharedMemory::Value { id, ty, align } => SharedMemory::Value {
                    index: id,
                    item: self.compile_type(ty),
                    align,
                    offset: alloc.offset,
                },
            })
            .collect();

        let body = Body {
            instructions,
            shared_memories,
            pipelines: self.pipelines,
            barriers: self.barriers,
            const_arrays: self.const_arrays,
            local_arrays: self.local_arrays,
        };

        let mut cluster_dim = value.options.cluster_dim;
        if !self.compilation_options.supports_features.clusters {
            cluster_dim = None;
        }

        ComputeKernel {
            tensor_maps,
            buffers,
            scalars,
            meta_static_len: self.metadata.static_len() as usize,
            cube_dim: value.cube_dim,
            body,
            extensions: self.extensions,
            flags,
            items: self.items,
            kernel_name: value.options.kernel_name,
            cluster_dim,
        }
    }

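    /// Computes the extended-metadata position of every buffer and tensor-map
    /// binding, ordered by binding id, and sizes the metadata accordingly.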
    fn build_metadata(&mut self, value: &KernelDefinition) {
        let mut num_ext = 0;

        let mut all_meta: Vec<_> = value
            .buffers
            .iter()
            .chain(value.tensor_maps.iter())
            .map(|buf| (buf.id, buf.has_extended_meta))
            .collect();

        all_meta.sort_by_key(|(id, _)| *id);

        for (_, has_extended_meta) in &all_meta {
            self.ext_meta_positions.push(num_ext);
            if *has_extended_meta {
                num_ext += 1;
            }
        }

        let num_meta = all_meta.len();

        self.metadata = cubecl_core::Metadata::new(num_meta as u32, num_ext);
    }

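    /// Returns the extended-metadata position recorded by `build_metadata` for
    /// the binding behind `var`. Panics if the variable has no index.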
    pub(crate) fn ext_meta_position(&self, var: gpu::Variable) -> u32 {
        let id = var.index().expect("Variable should have index");
        self.ext_meta_positions[id as usize]
    }

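    /// Compiles a scope into instructions: hoists its constant arrays, runs the
    /// checked-IO processor followed by the dialect processors, declares the
    /// resulting variables, then lowers each processed instruction.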
    fn compile_scope(&mut self, scope: &mut gpu::Scope) -> Vec<Instruction<D>> {
        let mut instructions = Vec::new();

        let const_arrays = scope
            .const_arrays
            .drain(..)
            .map(|(var, values)| ConstArray {
                index: var.index().unwrap(),
                item: self.compile_type(var.ty),
                size: values.len() as u32,
                values: values
                    .into_iter()
                    .map(|val| self.compile_variable(val))
                    .collect(),
            })
            .collect::<Vec<_>>();
        self.const_arrays.extend(const_arrays);

        let checked_io: Box<dyn Processor> = Box::new(CheckedIoProcessor::new(self.strategy));
        let dialect_processors = D::processors();
        let mut processors: Vec<&dyn Processor> = vec![&*checked_io];
        processors.extend(dialect_processors.iter().map(|it| &**it));

        let processing = scope.process(processors);

        for var in processing.variables {
            instructions.push(Instruction::DeclareVariable {
                var: self.compile_variable(var),
            });
        }

        processing
            .instructions
            .into_iter()
            .for_each(|op| self.compile_instruction(&mut instructions, op));

        instructions
    }

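    /// Lowers a single IR instruction, dispatching on the operation kind and
    /// raising any feature flags (TMA, async copies, plane indexes, ...) that the
    /// emitted code depends on.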
    fn compile_instruction(
        &mut self,
        instructions: &mut Vec<Instruction<D>>,
        instruction: gpu::Instruction,
    ) {
        self.update_debug_loc(instructions, &instruction);
        let out = instruction.out;
        match instruction.operation {
            gpu::Operation::Copy(variable) => {
                instructions.push(Instruction::Assign(UnaryInstruction {
                    input: self.compile_variable(variable),
                    out: self.compile_variable(out.unwrap()),
                }));
            }
            gpu::Operation::Arithmetic(op) => {
                self.compile_arithmetic(op, out, instruction.modes, instructions)
            }
            gpu::Operation::Comparison(op) => self.compile_comparison(op, out, instructions),
            gpu::Operation::Bitwise(op) => self.compile_bitwise(op, out, instructions),
            gpu::Operation::Operator(op) => self.compile_operator(op, out, instructions),
            gpu::Operation::Atomic(op) => self.compile_atomic(op, out, instructions),
            gpu::Operation::Metadata(op) => instructions.push(self.compile_metadata(op, out)),
            gpu::Operation::Branch(val) => self.compile_branch(instructions, val),
            gpu::Operation::Synchronization(val) => match val {
                gpu::Synchronization::SyncCube => instructions.push(Instruction::SyncThreads),
                gpu::Synchronization::SyncPlane => instructions.push(Instruction::SyncWarp),
                gpu::Synchronization::SyncStorage => instructions.push(Instruction::SyncThreads),
                gpu::Synchronization::SyncAsyncProxyShared => {
                    self.flags.inst_tma = true;
                    instructions.push(Instruction::ProxyAsyncToSharedFence)
                }
            },
            gpu::Operation::Plane(op) => {
                self.flags.indexes.plane_dim_checked = true;
                let out = self.compile_variable(out.unwrap());
                match op {
                    gpu::Plane::Sum(op) => {
                        let instruction = WarpInstruction::ReduceSum {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction));
                    }
                    gpu::Plane::InclusiveSum(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::InclusiveSum {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::InclusiveProd(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::InclusiveProd {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::ExclusiveSum(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::ExclusiveSum {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::ExclusiveProd(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::ExclusiveProd {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Prod(op) => {
                        let instruction = WarpInstruction::ReduceProd {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Max(op) => {
                        let instruction = WarpInstruction::ReduceMax {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Min(op) => {
                        let instruction = WarpInstruction::ReduceMin {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Elect => {
                        if self.compilation_options.supports_features.elect_sync {
                            self.flags.inst_ptx_wrappers = true;
                            instructions.push(Instruction::Warp(WarpInstruction::Elect { out }))
                        } else {
                            instructions
                                .push(Instruction::Warp(WarpInstruction::ElectFallback { out }))
                        }
                    }
                    gpu::Plane::All(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::All {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Any(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Any {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Ballot(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Ballot {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Broadcast(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Broadcast {
                            input: self.compile_variable(op.lhs),
                            id: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::Shuffle(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Shuffle {
                            input: self.compile_variable(op.lhs),
                            src_lane: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::ShuffleXor(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleXor {
                            input: self.compile_variable(op.lhs),
                            mask: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::ShuffleUp(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleUp {
                            input: self.compile_variable(op.lhs),
                            delta: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::ShuffleDown(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleDown {
                            input: self.compile_variable(op.lhs),
                            delta: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                }
            }
            gpu::Operation::CoopMma(cmma) => instructions.push(self.compile_cmma(cmma, out)),
            gpu::Operation::NonSemantic(debug) => match debug {
                gpu::NonSemantic::Print {
                    format_string,
                    args,
                } => instructions.push(Instruction::Printf {
                    format_string,
                    args: args
                        .into_iter()
                        .map(|arg| self.compile_variable(arg))
                        .collect(),
                }),
                gpu::NonSemantic::Comment { content } => {
                    instructions.push(Instruction::Comment { content })
                }
                _ => {}
            },
            gpu::Operation::Barrier(barrier_ops) => match barrier_ops {
                gpu::BarrierOps::Declare { barrier } => {
                    let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type()
                    else {
                        unreachable!()
                    };
                    let barrier = self.compile_variable(barrier);
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Declare {
                        barrier,
                        level,
                    }));
                }
                gpu::BarrierOps::Init {
                    barrier,
                    is_elected,
                    arrival_count,
                } => {
                    let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type()
                    else {
                        unreachable!()
                    };
                    let barrier = self.compile_variable(barrier);
                    let arrival_count = self.compile_variable(arrival_count);
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Init {
                        barrier,
                        is_elected: self.compile_variable(is_elected),
                        arrival_count,
                        level,
                    }));
                }
                gpu::BarrierOps::InitManual {
                    barrier,
                    arrival_count,
                } => {
                    let barrier = self.compile_variable(barrier);
                    let arrival_count = self.compile_variable(arrival_count);
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::InitManual {
                            barrier,
                            arrival_count,
                        },
                    ));
                }
                gpu::BarrierOps::MemCopyAsync {
                    barrier,
                    source,
                    source_length,
                    offset_source,
                    offset_out,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsync {
                            barrier: self.compile_variable(barrier),
                            source: self.compile_variable(source),
                            destination: self.compile_variable(out.unwrap()),
                            source_length: self.compile_variable(source_length),
                            offset_source: self.compile_variable(offset_source),
                            offset_out: self.compile_variable(offset_out),
                            cooperative: false,
                        },
                    ));
                }
                gpu::BarrierOps::MemCopyAsyncCooperative {
                    barrier,
                    source,
                    source_length,
                    offset_source,
                    offset_out,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsync {
                            barrier: self.compile_variable(barrier),
                            source: self.compile_variable(source),
                            destination: self.compile_variable(out.unwrap()),
                            source_length: self.compile_variable(source_length),
                            offset_source: self.compile_variable(offset_source),
                            offset_out: self.compile_variable(offset_out),
                            cooperative: true,
                        },
                    ));
                }
                gpu::BarrierOps::MemCopyAsyncTx {
                    barrier,
                    source,
                    source_length,
                    offset_source,
                    offset_out,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsyncTx {
                            barrier: self.compile_variable(barrier),
                            source: self.compile_variable(source),
                            destination: self.compile_variable(out.unwrap()),
                            source_length: self.compile_variable(source_length),
                            offset_source: self.compile_variable(offset_source),
                            offset_out: self.compile_variable(offset_out),
                        },
                    ));
                }
                gpu::BarrierOps::CopyAsync {
                    source,
                    source_length,
                    offset_source,
                    offset_out,
                    copy_length,
                    checked,
                } => {
                    self.flags.inst_async_copy = true;
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::CopyAsync {
                            source: self.compile_variable(source),
                            destination: self.compile_variable(out.unwrap()),
                            source_length: self.compile_variable(source_length),
                            offset_source: self.compile_variable(offset_source),
                            offset_out: self.compile_variable(offset_out),
                            copy_size: copy_length,
                            checked,
                        },
                    ));
                }
                gpu::BarrierOps::TmaLoad {
                    barrier,
                    tensor_map,
                    offset_out,
                    indices,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsyncTensorGlobalToShared {
                            barrier: self.compile_variable(barrier),
                            smem_buffer: self.compile_variable(out.unwrap()),
                            smem_offset: self.compile_variable(offset_out),
                            tensor_map: self.compile_variable(tensor_map),
                            indices: indices
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        },
                    ));
                }
                gpu::BarrierOps::TmaLoadIm2col {
                    barrier,
                    tensor_map,
                    offset_out,
                    indices,
                    offsets,
                } => {
                    self.flags.inst_tma_im2col = true;
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::TmaLoadIm2col {
                            barrier: self.compile_variable(barrier),
                            smem_buffer: self.compile_variable(out.unwrap()),
                            smem_offset: self.compile_variable(offset_out),
                            tensor_map: self.compile_variable(tensor_map),
                            indices: indices
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                            offsets: offsets
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        },
                    ));
                }
                gpu::BarrierOps::Arrive { barrier } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Arrive {
                        barrier: self.compile_variable(barrier),
                        token: self.compile_variable(out.unwrap()),
                    }))
                }
                gpu::BarrierOps::ArriveTx {
                    barrier,
                    arrive_count_update,
                    transaction_count_update,
                } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::ArriveTx {
                        barrier: self.compile_variable(barrier),
                        token: self.compile_variable(out.unwrap()),
                        arrive_count_update: self.compile_variable(arrive_count_update),
                        transaction_count_update: self.compile_variable(transaction_count_update),
                    }))
                }
                gpu::BarrierOps::CommitCopyAsync { barrier } => {
                    self.flags.inst_async_copy = true;
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::ArriveCopyAsync {
                            barrier: self.compile_variable(barrier),
                        },
                    ))
                }
                gpu::BarrierOps::ExpectTx {
                    barrier,
                    transaction_count_update,
                } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::ExpectTx {
                        barrier: self.compile_variable(barrier),
                        transaction_count_update: self.compile_variable(transaction_count_update),
                    }))
                }
                gpu::BarrierOps::Wait { barrier, token } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Wait {
                        barrier: self.compile_variable(barrier),
                        token: self.compile_variable(token),
                    }))
                }
                gpu::BarrierOps::WaitParity { barrier, phase } => instructions.push(
                    Instruction::Barrier(super::barrier::BarrierOps::WaitParity {
                        barrier: self.compile_variable(barrier),
                        phase: self.compile_variable(phase),
                    }),
                ),
                gpu::BarrierOps::ArriveAndWait { barrier } => {
                    let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type()
                    else {
                        unreachable!()
                    };
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::ArriveAndWait {
                            barrier: self.compile_variable(barrier),
                            level,
                        },
                    ))
                }
            },
            gpu::Operation::Tma(tma_ops) => {
                self.flags.inst_tma = true;
                match tma_ops {
                    gpu::TmaOps::TmaStore {
                        source,
                        coordinates,
                        offset_source,
                    } => {
                        instructions.push(Instruction::MemCopyAsyncTensorSharedToGlobal {
                            smem_buffer: self.compile_variable(source),
                            smem_offset: self.compile_variable(offset_source),
                            tensor_map: self.compile_variable(out.unwrap()),
                            indices: coordinates
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        });
                    }
                    gpu::TmaOps::CommitGroup => {
                        instructions.push(Instruction::BulkCommitGroup);
                    }
                    gpu::TmaOps::WaitGroup { max_pending } => {
                        instructions.push(Instruction::BulkWaitGroup { max_pending });
                    }
                    gpu::TmaOps::WaitGroupRead { max_pending } => {
                        instructions.push(Instruction::BulkWaitGroupRead { max_pending });
                    }
                }
            }
            gpu::Operation::Marker(_) => {}
        }
    }

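    /// Emits a line marker whenever the source location changes between
    /// instructions; non-semantic operations are skipped so debug-only
    /// instructions do not disturb the tracked location.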
    fn update_debug_loc(
        &mut self,
        instructions: &mut Vec<Instruction<D>>,
        inst: &gpu::Instruction,
    ) {
        if !matches!(inst.operation, Operation::NonSemantic(_)) {
            match &inst.source_loc {
                Some(loc) if Some(loc) != self.source_loc.as_ref() => {
                    self.source_loc = Some(loc.clone());
                    instructions.push(Instruction::Line {
                        file: loc.source.file.clone(),
                        line: loc.line,
                    });
                }
                _ => {}
            }
        }
    }

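    /// Lowers a cooperative-matrix (CMMA) operation to a WMMA instruction and
    /// registers any dialect extension the instruction requires.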
    fn compile_cmma(&mut self, cmma: gpu::CoopMma, out: Option<gpu::Variable>) -> Instruction<D> {
        self.flags.inst_wmma = true;

        let out = self.compile_variable(out.unwrap());

        let inst = match cmma {
            gpu::CoopMma::Fill { value } => WmmaInstruction::Fill {
                frag: out,
                value: self.compile_variable(value),
            },
            gpu::CoopMma::Load {
                value,
                stride,
                offset,
                layout,
            } => WmmaInstruction::Load {
                frag: out,
                offset: self.compile_variable(offset),
                value: self.compile_variable(value),
                stride: self.compile_variable(stride),
                layout: layout.and_then(|l| self.compile_matrix_layout(l)),
            },
            gpu::CoopMma::Execute {
                mat_a,
                mat_b,
                mat_c,
            } => WmmaInstruction::Execute {
                frag_a: self.compile_variable(mat_a),
                frag_b: self.compile_variable(mat_b),
                frag_c: self.compile_variable(mat_c),
                frag_d: out,
                warp_size: self.compilation_options.warp_size,
            },
            gpu::CoopMma::ExecuteManual {
                matrix,
                registers_a,
                registers_b,
                registers_c,
            } => WmmaInstruction::ExecuteManual {
                shape: MmaShape::new(matrix.m, matrix.n, matrix.k),
                frag_a: self.compile_variable(registers_a),
                frag_b: self.compile_variable(registers_b),
                frag_c: self.compile_variable(registers_c),
                frag_d: out,
            },
            gpu::CoopMma::ExecuteScaled {
                matrix,
                registers_a,
                registers_b,
                registers_c,
                scales_a,
                scales_b,
                scales_factor,
            } => WmmaInstruction::ExecuteScaled {
                shape: MmaShape::new(matrix.m, matrix.n, matrix.k),
                frag_a: self.compile_variable(registers_a),
                frag_b: self.compile_variable(registers_b),
                frag_c: self.compile_variable(registers_c),
                frag_d: out,

                scales_a: self.compile_variable(scales_a),
                scales_b: self.compile_variable(scales_b),
                scales_factor,
            },
            gpu::CoopMma::Store {
                mat,
                stride,
                offset,
                layout,
            } => {
                self.flags.indexes.unit_pos = true;
                self.flags.indexes.plane_index = true;
                WmmaInstruction::Store {
                    output: out,
                    offset: self.compile_variable(offset),
                    frag: self.compile_variable(mat),
                    stride: self.compile_variable(stride),
                    layout: self
                        .compile_matrix_layout(layout)
                        .expect("Layout required for store instruction"),
                }
            }
            gpu::CoopMma::LoadMatrix {
                buffer,
                offset,
                line_size,
                factor,
                transpose,
            } => WmmaInstruction::LdMatrix {
                output: out,
                buffer: self.compile_variable(buffer),
                offset: self.compile_variable(offset),
                line_size,
                factor,
                transpose,
            },
            gpu::CoopMma::StoreMatrix {
                offset,
                line_size,
                registers,
                factor,
                transpose,
            } => WmmaInstruction::StMatrix {
                registers: self.compile_variable(registers),
                buffer: out,
                offset: self.compile_variable(offset),
                line_size,
                factor,
                transpose,
            },
            gpu::CoopMma::Cast { input } => WmmaInstruction::Cast {
                input: self.compile_variable(input),
                output: out,
            },
            gpu::CoopMma::RowIndex { .. } | gpu::CoopMma::ColIndex { .. } => {
                panic!("Row/Col index should be handled by processors")
            }
        };

        D::register_wmma_instruction_extension(&mut self.extensions, &inst);

        Instruction::Wmma(inst)
    }

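    /// Lowers a metadata query (rank, shape, stride, length, buffer length) into
    /// an instruction reading at the proper offset of the metadata info buffer.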
    fn compile_metadata(
        &mut self,
        metadata: gpu::Metadata,
        out: Option<gpu::Variable>,
    ) -> Instruction<D> {
        let out = out.unwrap();
        match metadata {
            gpu::Metadata::Stride { dim, var } => {
                let position = self.ext_meta_position(var);
                let offset = self.metadata.stride_offset_index(position);
                Instruction::ExtendedMetadata {
                    info_offset: self.compile_variable(offset.into()),
                    dim: self.compile_variable(dim),
                    split_meta: self.compilation_options.supports_features.grid_constants,
                    static_offset: self.metadata.static_len(),
                    out: self.compile_variable(out),
                }
            }
            gpu::Metadata::Shape { dim, var } => {
                let position = self.ext_meta_position(var);
                let offset = self.metadata.shape_offset_index(position);
                Instruction::ExtendedMetadata {
                    info_offset: self.compile_variable(offset.into()),
                    dim: self.compile_variable(dim),
                    split_meta: self.compilation_options.supports_features.grid_constants,
                    static_offset: self.metadata.static_len(),
                    out: self.compile_variable(out),
                }
            }
            gpu::Metadata::Rank { var } => {
                let out = self.compile_variable(out);
                let pos = self.ext_meta_position(var);
                let offset = self.metadata.rank_index(pos);
                super::Instruction::Metadata {
                    info_offset: self.compile_variable(offset.into()),
                    split_meta: self.compilation_options.supports_features.grid_constants,
                    out,
                }
            }
            gpu::Metadata::Length { var } => {
                let input = self.compile_variable(var);
                let out = self.compile_variable(out);

                match input {
                    Variable::Slice { .. } => Instruction::SliceLength { input, out },
                    Variable::SharedArray(_id, _item, length) => {
                        Instruction::ConstLength { length, out }
                    }
                    _ => {
                        let id = input.id().expect("Variable should have id");
                        let offset = self.metadata.len_index(id);
                        Instruction::Metadata {
                            info_offset: self.compile_variable(offset.into()),
                            split_meta: self.compilation_options.supports_features.grid_constants,
                            out,
                        }
                    }
                }
            }
            gpu::Metadata::BufferLength { var } => {
                let input = self.compile_variable(var);
                let out = self.compile_variable(out);

                match input {
                    Variable::Slice { .. } => Instruction::SliceLength { input, out },
                    _ => {
                        let id = input.id().expect("Variable should have id");
                        let offset = self.metadata.buffer_len_index(id);
                        Instruction::Metadata {
                            info_offset: self.compile_variable(offset.into()),
                            split_meta: self.compilation_options.supports_features.grid_constants,
                            out,
                        }
                    }
                }
            }
        }
    }

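    /// Lowers control flow by recursively compiling each nested scope.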
    fn compile_branch(&mut self, instructions: &mut Vec<Instruction<D>>, branch: gpu::Branch) {
        match branch {
            gpu::Branch::If(mut op) => instructions.push(Instruction::If {
                cond: self.compile_variable(op.cond),
                instructions: self.compile_scope(&mut op.scope),
            }),
            gpu::Branch::IfElse(mut op) => instructions.push(Instruction::IfElse {
                cond: self.compile_variable(op.cond),
                instructions_if: self.compile_scope(&mut op.scope_if),
                instructions_else: self.compile_scope(&mut op.scope_else),
            }),
            gpu::Branch::Switch(mut op) => instructions.push(Instruction::Switch {
                value: self.compile_variable(op.value),
                instructions_default: self.compile_scope(&mut op.scope_default),
                instructions_cases: op
                    .cases
                    .into_iter()
                    .map(|(val, mut block)| {
                        (self.compile_variable(val), self.compile_scope(&mut block))
                    })
                    .collect(),
            }),
            gpu::Branch::Return => instructions.push(Instruction::Return),
            gpu::Branch::Break => instructions.push(Instruction::Break),
            gpu::Branch::RangeLoop(mut range_loop) => instructions.push(Instruction::RangeLoop {
                i: self.compile_variable(range_loop.i),
                start: self.compile_variable(range_loop.start),
                end: self.compile_variable(range_loop.end),
                step: range_loop.step.map(|it| self.compile_variable(it)),
                inclusive: range_loop.inclusive,
                instructions: self.compile_scope(&mut range_loop.scope),
            }),
            gpu::Branch::Loop(mut op) => instructions.push(Instruction::Loop {
                instructions: self.compile_scope(&mut op.scope),
            }),
        };
    }

    fn compile_atomic(
        &mut self,
        value: gpu::AtomicOp,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::AtomicOp::Load(op) => {
                instructions.push(Instruction::AtomicLoad(self.compile_unary(op, out)))
            }
            gpu::AtomicOp::Store(op) => {
                instructions.push(Instruction::AtomicStore(self.compile_unary(op, out)))
            }
            gpu::AtomicOp::Swap(op) => {
                instructions.push(Instruction::AtomicSwap(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Add(op) => {
                instructions.push(Instruction::AtomicAdd(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Sub(op) => {
                instructions.push(Instruction::AtomicSub(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Max(op) => {
                instructions.push(Instruction::AtomicMax(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Min(op) => {
                instructions.push(Instruction::AtomicMin(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::And(op) => {
                instructions.push(Instruction::AtomicAnd(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Or(op) => {
                instructions.push(Instruction::AtomicOr(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Xor(op) => {
                instructions.push(Instruction::AtomicXor(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::CompareAndSwap(op) => instructions.push(Instruction::AtomicCAS {
                input: self.compile_variable(op.input),
                cmp: self.compile_variable(op.cmp),
                val: self.compile_variable(op.val),
                out: self.compile_variable(out),
            }),
        }
    }

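    /// Lowers arithmetic operations. Operations with a fast approximate variant
    /// go through `select_fast_float`, which picks the fast version only when the
    /// instruction's floating-point modes allow it.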
    fn compile_arithmetic(
        &mut self,
        value: gpu::Arithmetic,
        out: Option<gpu::Variable>,
        modes: InstructionModes,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::Arithmetic::Add(op) => {
                instructions.push(Instruction::Add(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::SaturatingAdd(op) => {
                instructions.push(Instruction::SaturatingAdd(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Mul(op) => {
                instructions.push(Instruction::Mul(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Div(op) => {
                let op = self.compile_binary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::AllowReciprocal
                        | FastMath::ReducedPrecision
                        | FastMath::UnsignedZero
                        | FastMath::NotInf,
                    Instruction::Div(op),
                    Instruction::FastDiv(op),
                ))
            }
            gpu::Arithmetic::Sub(op) => {
                instructions.push(Instruction::Sub(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::SaturatingSub(op) => {
                instructions.push(Instruction::SaturatingSub(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::MulHi(op) => {
                let instruction = Instruction::HiMul(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Modulo(op) => {
                instructions.push(Instruction::Modulo(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Abs(op) => {
                instructions.push(Instruction::Abs(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Exp(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Exp(op),
                    Instruction::FastExp(op),
                ));
            }
            gpu::Arithmetic::Log(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Log(op),
                    Instruction::FastLog(op),
                ));
            }
            gpu::Arithmetic::Log1p(op) => {
                instructions.push(Instruction::Log1p(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Cos(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Cos(op),
                    Instruction::FastCos(op),
                ));
            }
            gpu::Arithmetic::Sin(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Sin(op),
                    Instruction::FastSin(op),
                ));
            }
            gpu::Arithmetic::Tan(op) => {
                instructions.push(Instruction::Tan(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Tanh(op) => {
                let op = self.compile_unary(op, out);
                let instruction = Instruction::Tanh(op);
                D::register_instruction_extension(&mut self.extensions, &instruction);
                if self.compilation_options.supports_features.fast_tanh {
                    instructions.push(self.select_fast_float(
                        out.ty,
                        modes,
                        FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                        instruction,
                        Instruction::FastTanh(op),
                    ))
                } else {
                    instructions.push(instruction);
                }
            }
            gpu::Arithmetic::Sinh(op) => {
                let instruction = Instruction::Sinh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Cosh(op) => {
                let instruction = Instruction::Cosh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcCos(op) => {
                let instruction = Instruction::ArcCos(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcSin(op) => {
                let instruction = Instruction::ArcSin(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcTan(op) => {
                let instruction = Instruction::ArcTan(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcSinh(op) => {
                let instruction = Instruction::ArcSinh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcCosh(op) => {
                let instruction = Instruction::ArcCosh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcTanh(op) => {
                let instruction = Instruction::ArcTanh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Degrees(op) => {
                let instruction = Instruction::Degrees(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Radians(op) => {
                let instruction = Instruction::Radians(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcTan2(op) => {
                let instruction = Instruction::ArcTan2(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Powf(op) => {
                let op = self.compile_binary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Powf(op),
                    Instruction::FastPowf(op),
                ))
            }
            gpu::Arithmetic::Powi(op) => {
                instructions.push(Instruction::Powi(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Hypot(op) => {
                instructions.push(Instruction::Hypot(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Rhypot(op) => {
                instructions.push(Instruction::Rhypot(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Sqrt(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Sqrt(op),
                    Instruction::FastSqrt(op),
                ))
            }
            gpu::Arithmetic::InverseSqrt(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::InverseSqrt(op),
                    Instruction::FastInverseSqrt(op),
                ))
            }
            gpu::Arithmetic::Erf(op) => {
                let instruction = Instruction::Erf(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Max(op) => {
                let instruction = Instruction::Max(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Min(op) => {
                let instruction = Instruction::Min(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Clamp(op) => instructions.push(Instruction::Clamp {
                input: self.compile_variable(op.input),
                min_value: self.compile_variable(op.min_value),
                max_value: self.compile_variable(op.max_value),
                out: self.compile_variable(out),
            }),
            gpu::Arithmetic::Recip(op) => {
                let elem = op.input.ty.elem_type();
                let input = self.compile_variable(op.input);
                let out = self.compile_variable(out);
                let lhs = match elem {
                    gpu::ElemType::Float(kind) => gpu::ConstantScalarValue::Float(1.0, kind),
                    gpu::ElemType::Int(kind) => gpu::ConstantScalarValue::Int(1, kind),
                    gpu::ElemType::UInt(kind) => gpu::ConstantScalarValue::UInt(1, kind),
                    gpu::ElemType::Bool => gpu::ConstantScalarValue::Bool(true),
                };
                let div = Instruction::Div(BinaryInstruction {
                    lhs: Variable::ConstantScalar(lhs, self.compile_elem(elem)),
                    rhs: input,
                    out,
                });
                let recip = Instruction::FastRecip(UnaryInstruction { input, out });

                instructions.push(self.select_fast_float(
                    elem.into(),
                    modes,
                    FastMath::AllowReciprocal
                        | FastMath::ReducedPrecision
                        | FastMath::UnsignedZero
                        | FastMath::NotInf,
                    div,
                    recip,
                ))
            }
            gpu::Arithmetic::Round(op) => {
                instructions.push(Instruction::Round(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Floor(op) => {
                instructions.push(Instruction::Floor(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Ceil(op) => {
                instructions.push(Instruction::Ceil(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Trunc(op) => {
                instructions.push(Instruction::Trunc(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Remainder(op) => {
                instructions.push(Instruction::Remainder(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Fma(op) => instructions.push(Instruction::Fma {
                a: self.compile_variable(op.a),
                b: self.compile_variable(op.b),
                c: self.compile_variable(op.c),
                out: self.compile_variable(out),
            }),
            gpu::Arithmetic::Neg(op) => {
                instructions.push(Instruction::Neg(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Normalize(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Normalize(op),
                    Instruction::FastNormalize(op),
                ))
            }
            gpu::Arithmetic::Magnitude(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Magnitude(op),
                    Instruction::FastMagnitude(op),
                ))
            }
            gpu::Arithmetic::Dot(op) => {
                instructions.push(Instruction::Dot(self.compile_binary(op, out)))
            }
        };
    }

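    /// Picks between a precise instruction and its fast approximation: the fast
    /// variant is chosen only if the target supports fast math, the element type
    /// is `f32`, and the instruction's fp math modes include all `required_flags`.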
    fn select_fast_float(
        &self,
        ty: gpu::Type,
        modes: InstructionModes,
        required_flags: EnumSet<FastMath>,
        default: Instruction<D>,
        fast: Instruction<D>,
    ) -> Instruction<D> {
        if !self.compilation_options.supports_features.fast_math
            || !matches!(ty.elem_type(), ElemType::Float(FloatKind::F32))
        {
            return default;
        }

        if modes.fp_math_mode.is_superset(required_flags) {
            fast
        } else {
            default
        }
    }

    fn compile_comparison(
        &mut self,
        value: gpu::Comparison,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::Comparison::Equal(op) => {
                instructions.push(Instruction::Equal(self.compile_binary(op, out)))
            }
            gpu::Comparison::Lower(op) => {
                instructions.push(Instruction::Lower(self.compile_binary(op, out)))
            }
            gpu::Comparison::Greater(op) => {
                instructions.push(Instruction::Greater(self.compile_binary(op, out)))
            }
            gpu::Comparison::LowerEqual(op) => {
                instructions.push(Instruction::LowerEqual(self.compile_binary(op, out)))
            }
            gpu::Comparison::GreaterEqual(op) => {
                instructions.push(Instruction::GreaterEqual(self.compile_binary(op, out)))
            }
            gpu::Comparison::NotEqual(op) => {
                instructions.push(Instruction::NotEqual(self.compile_binary(op, out)))
            }
            gpu::Comparison::IsNan(op) => {
                instructions.push(Instruction::IsNan(self.compile_unary(op, out)))
            }
            gpu::Comparison::IsInf(op) => {
                instructions.push(Instruction::IsInf(self.compile_unary(op, out)))
            }
        };
    }

    fn compile_bitwise(
        &mut self,
        value: gpu::Bitwise,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::Bitwise::BitwiseOr(op) => {
                instructions.push(Instruction::BitwiseOr(self.compile_binary(op, out)))
            }
            gpu::Bitwise::BitwiseAnd(op) => {
                instructions.push(Instruction::BitwiseAnd(self.compile_binary(op, out)))
            }
            gpu::Bitwise::BitwiseXor(op) => {
                instructions.push(Instruction::BitwiseXor(self.compile_binary(op, out)))
            }
            gpu::Bitwise::CountOnes(op) => {
                instructions.push(Instruction::CountBits(self.compile_unary(op, out)))
            }
            gpu::Bitwise::ReverseBits(op) => {
                instructions.push(Instruction::ReverseBits(self.compile_unary(op, out)))
            }
            gpu::Bitwise::ShiftLeft(op) => {
                instructions.push(Instruction::ShiftLeft(self.compile_binary(op, out)))
            }
            gpu::Bitwise::ShiftRight(op) => {
                instructions.push(Instruction::ShiftRight(self.compile_binary(op, out)))
            }
            gpu::Bitwise::BitwiseNot(op) => {
                instructions.push(Instruction::BitwiseNot(self.compile_unary(op, out)))
            }
            gpu::Bitwise::LeadingZeros(op) => {
                instructions.push(Instruction::LeadingZeros(self.compile_unary(op, out)))
            }
            gpu::Bitwise::FindFirstSet(op) => {
                let instruction = Instruction::FindFirstSet(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
        };
    }

    fn compile_operator(
        &mut self,
        value: gpu::Operator,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::Operator::Index(op) | gpu::Operator::UncheckedIndex(op) => {
                instructions.push(Instruction::Index(self.compile_index(op, out)));
            }
            gpu::Operator::IndexAssign(op) | gpu::Operator::UncheckedIndexAssign(op) => {
                instructions.push(Instruction::IndexAssign(self.compile_index_assign(op, out)));
            }
            gpu::Operator::And(op) => {
                instructions.push(Instruction::And(self.compile_binary(op, out)))
            }
            gpu::Operator::Or(op) => {
                instructions.push(Instruction::Or(self.compile_binary(op, out)))
            }
            gpu::Operator::Not(op) => {
                instructions.push(Instruction::Not(self.compile_unary(op, out)))
            }
            gpu::Operator::InitLine(op) => instructions.push(Instruction::VecInit {
                inputs: op
                    .inputs
                    .into_iter()
                    .map(|it| self.compile_variable(it))
                    .collect(),
                out: self.compile_variable(out),
            }),
            gpu::Operator::CopyMemory(op) => instructions.push(Instruction::Copy {
                input: self.compile_variable(op.input),
                in_index: self.compile_variable(op.in_index),
                out: self.compile_variable(out),
                out_index: self.compile_variable(op.out_index),
            }),
            gpu::Operator::CopyMemoryBulk(op) => instructions.push(Instruction::CopyBulk {
                input: self.compile_variable(op.input),
                in_index: self.compile_variable(op.in_index),
                out: self.compile_variable(out),
                out_index: self.compile_variable(op.out_index),
                len: op.len,
            }),
            gpu::Operator::Select(op) => instructions.push(Instruction::Select {
                cond: self.compile_variable(op.cond),
                then: self.compile_variable(op.then),
                or_else: self.compile_variable(op.or_else),
                out: self.compile_variable(out),
            }),
            gpu::Operator::Cast(op)
                if is_fp4_fp6_fp8(op.input.elem_type()) || is_fp4_fp6_fp8(out.elem_type()) =>
            {
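                // Note: casts involving packed fp4/fp6/fp8 values are lowered
                // through f16/bf16 intermediates, so the intermediate line types
                // are compiled up front to make sure they get declared.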
                self.flags.elem_f16 = true;
                self.flags.elem_bf16 = true;
                let vec_in = op.input.ty.line_size();
                let packing = out.storage_type().packing_factor();
                self.compile_type(op.input.ty.line(packing));
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::F16)).line(vec_in),
                );
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::BF16)).line(vec_in),
                );
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::F16)).line(packing),
                );
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::BF16)).line(packing),
                );

                let inst = self.compile_unary(op, out);

                instructions.push(Instruction::SpecialCast(inst));
            }
            gpu::Operator::Cast(op) => {
                let op = self.compile_unary(op, out);

                if op.input.elem() == Elem::TF32 || op.out.elem() == Elem::TF32 {
                    self.flags.elem_tf32 = true;
                }

                instructions.push(Instruction::Assign(op))
            }
            gpu::Operator::Reinterpret(op) => {
                instructions.push(Instruction::Bitcast(self.compile_unary(op, out)))
            }
        };
    }

    fn compile_binary(
        &mut self,
        value: gpu::BinaryOperator,
        out: gpu::Variable,
    ) -> BinaryInstruction<D> {
        BinaryInstruction {
            lhs: self.compile_variable(value.lhs),
            rhs: self.compile_variable(value.rhs),
            out: self.compile_variable(out),
        }
    }

    fn compile_index_assign(
        &mut self,
        value: gpu::IndexAssignOperator,
        out: gpu::Variable,
    ) -> IndexAssignInstruction<D> {
        IndexAssignInstruction {
            index: self.compile_variable(value.index),
            value: self.compile_variable(value.value),
            line_size: value.line_size,
            out: self.compile_variable(out),
        }
    }

    fn compile_index(
        &mut self,
        value: gpu::IndexOperator,
        out: gpu::Variable,
    ) -> IndexInstruction<D> {
        IndexInstruction {
            list: self.compile_variable(value.list),
            index: self.compile_variable(value.index),
            line_size: value.line_size,
            out: self.compile_variable(out),
        }
    }

    fn compile_unary(
        &mut self,
        value: gpu::UnaryOperator,
        out: gpu::Variable,
    ) -> UnaryInstruction<D> {
        UnaryInstruction {
            input: self.compile_variable(value.input),
            out: self.compile_variable(out),
        }
    }

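    /// Maps an IR variable to its C++ counterpart, recording side effects on the
    /// way: builtins raise the matching index flag, local arrays and pipelines
    /// are registered once, and cluster builtins degrade to constants when
    /// clusters are unsupported.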
1663 fn compile_variable(&mut self, value: gpu::Variable) -> Variable<D> {
1664 let item = value.ty;
1665 match value.kind {
1666 gpu::VariableKind::GlobalInputArray(id) => {
1667 Variable::GlobalInputArray(id, self.compile_type(item))
1668 }
1669 gpu::VariableKind::GlobalScalar(id) => Variable::GlobalScalar {
1670 id,
1671 elem: self.compile_storage_type(item.storage_type()),
1672 in_struct: self.compilation_options.supports_features.grid_constants,
1673 },
1674 gpu::VariableKind::TensorMapInput(id) => {
1675 self.flags.inst_tma = true;
1676 Variable::TensorMap(id)
1677 }
1678 gpu::VariableKind::TensorMapOutput(id) => {
1679 self.flags.inst_tma = true;
1680 Variable::TensorMap(id)
1681 }
1682 gpu::VariableKind::LocalMut { id } => Variable::LocalMut {
1683 id,
1684 item: self.compile_type(item),
1685 },
1686 gpu::VariableKind::Versioned { id, .. } => Variable::LocalMut {
1687 id,
1688 item: self.compile_type(item),
1689 },
1690 gpu::VariableKind::LocalConst { id } => Variable::LocalConst {
1691 id,
1692 item: self.compile_type(item),
1693 },
1694 gpu::VariableKind::GlobalOutputArray(id) => {
1695 Variable::GlobalOutputArray(id, self.compile_type(item))
1696 }
1697 gpu::VariableKind::ConstantScalar(value) => {
1698 Variable::ConstantScalar(value, self.compile_elem(value.elem_type()))
1699 }
1700 gpu::VariableKind::SharedArray { id, length, .. } => {
1701 let item = self.compile_type(item);
1702 Variable::SharedArray(id, item, length)
1703 }
1704 gpu::VariableKind::Shared { id } => {
1705 let item = self.compile_type(item);
1706 Variable::Shared(id, item)
1707 }
1708 gpu::VariableKind::ConstantArray {
1709 id,
1710 length,
1711 unroll_factor,
1712 } => {
1713 let item = self.compile_type(item);
1714 Variable::ConstantArray(id, item, length * unroll_factor)
1715 }
1716 gpu::VariableKind::Builtin(builtin) => match builtin {
1717 gpu::Builtin::AbsolutePos => {
1718 self.flags.indexes.absolute_pos = true;
1719 Variable::AbsolutePos
1720 }
1721 gpu::Builtin::CubePosCluster
1722 if self.compilation_options.supports_features.clusters =>
1723 {
1724 self.flags.indexes.cluster_pos = true;
1725 Variable::ClusterRank
1726 }
1727 gpu::Builtin::CubePosClusterX
1728 if self.compilation_options.supports_features.clusters =>
1729 {
1730 self.flags.indexes.cluster_pos = true;
1731 Variable::ClusterIndexX
1732 }
1733 gpu::Builtin::CubePosClusterY
1734 if self.compilation_options.supports_features.clusters =>
1735 {
1736 self.flags.indexes.cluster_pos = true;
1737 Variable::ClusterIndexY
1738 }
1739 gpu::Builtin::CubePosClusterZ
1740 if self.compilation_options.supports_features.clusters =>
1741 {
1742 self.flags.indexes.cluster_pos = true;
1743 Variable::ClusterIndexZ
1744 }
1745 gpu::Builtin::CubePosCluster
1748 | gpu::Builtin::CubePosClusterX
1749 | gpu::Builtin::CubePosClusterY
1750 | gpu::Builtin::CubePosClusterZ => const_u32(0),
                gpu::Builtin::AbsolutePosX => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosX
                }
                gpu::Builtin::AbsolutePosY => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosY
                }
                gpu::Builtin::AbsolutePosZ => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosZ
                }
                gpu::Builtin::CubeDim => {
                    self.flags.indexes.cube_dim = true;
                    Variable::CubeDim
                }
                gpu::Builtin::CubeDimX => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimX
                }
                gpu::Builtin::CubeDimY => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimY
                }
                gpu::Builtin::CubeDimZ => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimZ
                }
                gpu::Builtin::CubeClusterDim => const_u32(self.cluster_dim.num_elems()),
                gpu::Builtin::CubeClusterDimX => const_u32(self.cluster_dim.x),
                gpu::Builtin::CubeClusterDimY => const_u32(self.cluster_dim.y),
                gpu::Builtin::CubeClusterDimZ => const_u32(self.cluster_dim.z),
                gpu::Builtin::CubePos => {
                    self.flags.indexes.cube_pos = true;
                    Variable::CubePos
                }
                gpu::Builtin::CubePosX => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosX
                }
                gpu::Builtin::CubePosY => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosY
                }
                gpu::Builtin::CubePosZ => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosZ
                }
                gpu::Builtin::CubeCount => {
                    self.flags.indexes.cube_count = true;
                    Variable::CubeCount
                }
                gpu::Builtin::CubeCountX => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountX
                }
                gpu::Builtin::CubeCountY => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountY
                }
                gpu::Builtin::CubeCountZ => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountZ
                }
                gpu::Builtin::UnitPos => {
                    self.flags.indexes.unit_pos = true;
                    Variable::UnitPos
                }
                gpu::Builtin::UnitPosX => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosX
                }
                gpu::Builtin::UnitPosY => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosY
                }
                gpu::Builtin::UnitPosZ => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosZ
                }
                gpu::Builtin::PlaneDim => {
                    self.flags.indexes.plane_dim = true;
                    Variable::PlaneDim
                }
                gpu::Builtin::UnitPosPlane => {
                    self.flags.indexes.unit_pos_plane = true;
                    Variable::UnitPosPlane
                }
            },
            gpu::VariableKind::LocalArray {
                id,
                length,
                unroll_factor,
            } => {
                let item = self.compile_type(item);
                if !self.local_arrays.iter().any(|s| s.index == id) {
                    self.local_arrays
                        .push(LocalArray::new(id, item, length * unroll_factor));
                }
                Variable::LocalArray(id, item, length)
            }
            gpu::VariableKind::Matrix { id, mat } => {
                self.flags.inst_wmma = true;
                Variable::WmmaFragment {
                    id,
                    frag: self.compile_matrix(mat),
                }
            }
            gpu::VariableKind::Pipeline { id, num_stages } => {
                self.flags.op_pipeline = true;
                let pipeline = Variable::Pipeline { id };
                if !self.pipelines.iter().any(|s| s.pipeline_id() == id) {
                    self.pipelines.push(PipelineOps::Init {
                        pipeline,
                        num_stages,
                    });
                }
                pipeline
            }
            gpu::VariableKind::BarrierToken { id, level } => {
                self.flags.op_barrier = true;
                Variable::BarrierToken { id, level }
            }
        }
    }

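    /// Lowers a [`gpu::Matrix`] into a WMMA [`Fragment`] for the target
    /// dialect, compiling its identifier, storage element and layout.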
    fn compile_matrix(&mut self, matrix: gpu::Matrix) -> Fragment<D> {
        Fragment {
            ident: self.compile_matrix_ident(matrix.ident),
            m: matrix.m,
            n: matrix.n,
            k: matrix.k,
            elem: self.compile_storage_type(matrix.storage),
            layout: self.compile_matrix_layout(matrix.layout),
        }
    }

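    /// Maps a matrix identifier (A, B or accumulator) onto the corresponding
    /// dialect fragment identifier.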
    fn compile_matrix_ident(&mut self, ident: gpu::MatrixIdent) -> FragmentIdent<D> {
        match ident {
            gpu::MatrixIdent::A => FragmentIdent::A,
            gpu::MatrixIdent::B => FragmentIdent::B,
            gpu::MatrixIdent::Accumulator => FragmentIdent::Accumulator,
        }
    }

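    /// Maps a matrix layout onto the dialect fragment layout; an undefined
    /// layout compiles to `None`.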
    fn compile_matrix_layout(&mut self, layout: gpu::MatrixLayout) -> Option<FragmentLayout<D>> {
        match layout {
            gpu::MatrixLayout::ColMajor => Some(FragmentLayout::ColMajor),
            gpu::MatrixLayout::RowMajor => Some(FragmentLayout::RowMajor),
            gpu::MatrixLayout::Undefined => None,
        }
    }

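    /// Compiles a runtime buffer binding into a dialect [`Binding`], lowering
    /// its type while keeping the id, location, size and visibility as-is.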
    fn compile_binding(&mut self, binding: cubecl_runtime::kernel::Binding) -> Binding<D> {
        Binding {
            id: binding.id,
            item: self.compile_type(binding.ty),
            location: binding.location,
            size: binding.size,
            vis: binding.visibility,
        }
    }

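    /// Compiles an IR type into a dialect [`Item`] and records it (along with
    /// its optimized form) in the set of items used by the kernel.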
    fn compile_type(&mut self, ty: gpu::Type) -> Item<D> {
        let item = match ty {
            gpu::Type::Scalar(ty) => Item::new(self.compile_storage_type(ty), 1, false),
            gpu::Type::Line(ty, line_size) => {
                Item::new(self.compile_storage_type(ty), line_size as usize, false)
            }
            gpu::Type::Semantic(_) => Item::new(Elem::Bool, 1, true),
        };
        if item.elem != super::Elem::TF32 {
            self.items.insert(item);
            self.items.insert(item.optimized());
        } else {
            // TF32 is represented as `f32` in the generated source, so
            // register the item under `f32` instead.
            let mut item = item;
            item.elem = super::Elem::F32;
            self.items.insert(item);
        }

        item
    }

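    /// Compiles an IR storage type into a dialect [`Elem`], setting the
    /// element feature flags as a side effect. Two-wide packed floats lower
    /// to their `x2` variants; other packing factors are unsupported.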
    fn compile_storage_type(&mut self, value: gpu::StorageType) -> Elem<D> {
        match value {
            gpu::StorageType::Scalar(ty) => self.compile_elem(ty),
            gpu::StorageType::Atomic(ty) => Elem::Atomic(ty.into()),
            gpu::StorageType::Packed(gpu::ElemType::Float(kind), 2) => match kind {
                FloatKind::E2M1 => {
                    self.flags.elem_fp4 = true;
                    Elem::FP4x2(FP4Kind::E2M1)
                }
                FloatKind::E2M3 => {
                    self.flags.elem_fp6 = true;
                    Elem::FP6x2(FP6Kind::E2M3)
                }
                FloatKind::E3M2 => {
                    self.flags.elem_fp6 = true;
                    Elem::FP6x2(FP6Kind::E3M2)
                }
                FloatKind::E4M3 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8x2(FP8Kind::E4M3)
                }
                FloatKind::E5M2 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8x2(FP8Kind::E5M2)
                }
                FloatKind::UE8M0 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8x2(FP8Kind::UE8M0)
                }
                FloatKind::F16 => {
                    self.flags.elem_f16 = true;
                    Elem::F16x2
                }
                FloatKind::BF16 => {
                    self.flags.elem_bf16 = true;
                    Elem::BF16x2
                }
                other => unimplemented!("Unsupported storage type: packed<{other:?}, 2>"),
            },
            gpu::StorageType::Packed(other, factor) => {
                unimplemented!("Unsupported storage type: packed<{other}, {factor}>")
            }
            gpu::StorageType::Opaque(ty) => match ty {
                gpu::OpaqueType::Barrier(level) => {
                    self.flags.op_barrier = true;
                    Elem::Barrier(level)
                }
            },
        }
    }

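    /// Compiles a scalar IR element type into a dialect [`Elem`], setting the
    /// feature flags for reduced-precision floats. `Flex32` is compiled as a
    /// plain `f32`.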
    fn compile_elem(&mut self, value: gpu::ElemType) -> Elem<D> {
        match value {
            gpu::ElemType::Float(kind) => match kind {
                gpu::FloatKind::E2M1 => {
                    self.flags.elem_fp4 = true;
                    Elem::FP4(FP4Kind::E2M1)
                }
                gpu::FloatKind::E2M3 => {
                    self.flags.elem_fp6 = true;
                    Elem::FP6(FP6Kind::E2M3)
                }
                gpu::FloatKind::E3M2 => {
                    self.flags.elem_fp6 = true;
                    Elem::FP6(FP6Kind::E3M2)
                }
                gpu::FloatKind::E4M3 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8(FP8Kind::E4M3)
                }
                gpu::FloatKind::E5M2 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8(FP8Kind::E5M2)
                }
                gpu::FloatKind::UE8M0 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8(FP8Kind::UE8M0)
                }
                gpu::FloatKind::F16 => {
                    self.flags.elem_f16 = true;
                    Elem::F16
                }
                gpu::FloatKind::BF16 => {
                    self.flags.elem_bf16 = true;
                    Elem::BF16
                }
                gpu::FloatKind::TF32 => Elem::TF32,
                gpu::FloatKind::Flex32 => Elem::F32,
                gpu::FloatKind::F32 => Elem::F32,
                gpu::FloatKind::F64 => Elem::F64,
            },
            gpu::ElemType::Int(kind) => match kind {
                gpu::IntKind::I8 => Elem::I8,
                gpu::IntKind::I16 => Elem::I16,
                gpu::IntKind::I32 => Elem::I32,
                gpu::IntKind::I64 => Elem::I64,
            },
            gpu::ElemType::UInt(kind) => match kind {
                gpu::UIntKind::U8 => Elem::U8,
                gpu::UIntKind::U16 => Elem::U16,
                gpu::UIntKind::U32 => Elem::U32,
                gpu::UIntKind::U64 => Elem::U64,
            },
            gpu::ElemType::Bool => Elem::Bool,
        }
    }
}

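/// Returns `true` when the element is one of the sub-byte float formats
/// (fp4, fp6 or fp8: E2M1, E2M3, E3M2, E4M3, E5M2, UE8M0).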
fn is_fp4_fp6_fp8(elem: gpu::ElemType) -> bool {
    match elem {
        gpu::ElemType::Float(kind) => matches!(
            kind,
            FloatKind::E2M1
                | FloatKind::E2M3
                | FloatKind::E3M2
                | FloatKind::E4M3
                | FloatKind::E5M2
                | FloatKind::UE8M0
        ),
        _ => false,
    }
}

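/// Builds a `u32` constant-scalar [`Variable`] for the given value.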
fn const_u32<D: Dialect>(value: u32) -> Variable<D> {
    Variable::ConstantScalar(
        gpu::ConstantScalarValue::UInt(value as u64, UIntKind::U32),
        Elem::U32,
    )
}

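/// Registers the baseline scalar and atomic types supported by the C++
/// backends on the given [`DeviceProperties`].
///
/// A minimal usage sketch, assuming `props` is the `DeviceProperties` being
/// built during device initialization:
///
/// ```ignore
/// register_supported_types(&mut props);
/// ```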
pub fn register_supported_types(props: &mut DeviceProperties) {
    let supported_types = [
        gpu::ElemType::UInt(gpu::UIntKind::U8),
        gpu::ElemType::UInt(gpu::UIntKind::U16),
        gpu::ElemType::UInt(gpu::UIntKind::U32),
        gpu::ElemType::UInt(gpu::UIntKind::U64),
        gpu::ElemType::Int(gpu::IntKind::I8),
        gpu::ElemType::Int(gpu::IntKind::I16),
        gpu::ElemType::Int(gpu::IntKind::I32),
        gpu::ElemType::Int(gpu::IntKind::I64),
        gpu::ElemType::Float(gpu::FloatKind::BF16),
        gpu::ElemType::Float(gpu::FloatKind::F16),
        gpu::ElemType::Float(gpu::FloatKind::F32),
        gpu::ElemType::Float(gpu::FloatKind::Flex32),
        gpu::ElemType::Bool,
    ];

    let supported_atomic_types = [
        gpu::ElemType::Int(gpu::IntKind::I32),
        gpu::ElemType::Int(gpu::IntKind::I64),
        gpu::ElemType::UInt(gpu::UIntKind::U32),
        gpu::ElemType::UInt(gpu::UIntKind::U64),
        gpu::ElemType::Float(gpu::FloatKind::F32),
    ];

    for ty in supported_types {
        props.register_type_usage(ty, TypeUsage::all_scalar());
    }

    for ty in supported_atomic_types {
        props.register_type_usage(
            gpu::StorageType::Atomic(ty),
            TypeUsage::AtomicAdd | TypeUsage::AtomicLoadStore,
        );
    }
}