use std::{collections::HashSet, fmt::Debug};

use cubecl_common::ExecutionMode;
use cubecl_core::CubeDim;
use cubecl_core::ir::{FloatKind, Processor, UIntKind, VariableKind};
use cubecl_core::post_processing::checked_io::CheckedIoProcessor;
use cubecl_core::{
    Compiler,
    ir::{self as gpu},
};
use cubecl_core::{
    ir::{Operation, SourceLoc},
    prelude::{FastMath, KernelDefinition},
};
use cubecl_opt::{Optimizer, SharedLiveness};
use cubecl_runtime::{DeviceProperties, TypeUsage};

use crate::shared::MmaShape;

use super::{
    BinaryInstruction, Binding, Body, Component, ComputeKernel, ConstArray, Dialect, Elem, FP6Kind,
    Fragment, FragmentIdent, FragmentLayout, IndexAssignInstruction, IndexInstruction, Instruction,
    Item, LocalArray, SharedMemory, UnaryInstruction, Variable, WarpInstruction, WmmaInstruction,
};
use super::{FP4Kind, barrier::BarrierOps};
use super::{FP8Kind, pipeline::PipelineOps};

pub(super) static COUNTER_TMP_VAR: std::sync::atomic::AtomicU32 =
    std::sync::atomic::AtomicU32::new(0);

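/// Target-specific options supplied by the runtime that influence code
/// generation.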
#[derive(Clone, Debug)]
pub struct CompilationOptions {
    pub warp_size: u32,
    pub grid_constants: bool,
    pub supports_clusters: bool,
}

impl Default for CompilationOptions {
    fn default() -> Self {
        Self {
            warp_size: 32,
            grid_constants: false,
            supports_clusters: false,
        }
    }
}

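/// Tracks which builtin index variables (absolute position, cube position,
/// plane index, ...) a kernel actually reads, so the generated source only
/// computes the ones it needs.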
#[derive(Debug, Clone, Default)]
pub struct CubeIndexFlags {
    pub absolute_pos: bool,
    pub absolute_pos_tuple: bool,
    pub cube_count: bool,
    pub cube_count_tuple: bool,
    pub cube_dim: bool,
    pub cube_dim_tuple: bool,
    pub cube_pos: bool,
    pub cube_pos_tuple: bool,
    pub plane_dim: bool,
    pub plane_dim_checked: bool,
    pub plane_index: bool,
    pub unit_pos: bool,
    pub unit_pos_tuple: bool,
    pub unit_pos_plane: bool,
    pub cluster_pos: bool,
}

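/// Feature flags collected while lowering a kernel; they decide which type
/// definitions, headers, and index computations the generated C++ source
/// must include.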
#[derive(Debug, Clone, Default)]
pub struct Flags {
    pub elem_fp4: bool,
    pub elem_fp6: bool,
    pub elem_fp8: bool,
    pub elem_bf16: bool,
    pub elem_f16: bool,
    pub elem_tf32: bool,
    pub indexes: CubeIndexFlags,
    pub op_barrier: bool,
    pub op_pipeline: bool,
    pub inst_fast_math: bool,
    pub inst_tma: bool,
    pub inst_tma_im2col: bool,
    pub inst_wmma: bool,
    pub use_grid_constants: bool,
    pub static_meta_length: usize,
    pub has_dynamic_meta: bool,
    pub cube_dim: CubeDim,
    pub cluster_dim: Option<CubeDim>,
}

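/// Compiler that lowers the `cubecl` IR into a C++ [`ComputeKernel`] for a
/// concrete [`Dialect`] (e.g. CUDA or HIP).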
#[allow(clippy::too_many_arguments)]
#[derive(Clone, Debug, Default)]
pub struct CppCompiler<D: Dialect> {
    barriers: Vec<BarrierOps<D>>,
    compilation_options: CompilationOptions,
    const_arrays: Vec<ConstArray<D>>,
    ext_meta_positions: Vec<u32>,
    cluster_dim: CubeDim,
    extensions: Vec<D::Extension>,
    flags: Flags,
    items: HashSet<Item<D>>,
    local_arrays: Vec<LocalArray<D>>,
    metadata: cubecl_core::Metadata,
    pipelines: Vec<PipelineOps<D>>,
    source_loc: Option<SourceLoc>,
    strategy: ExecutionMode,
}

impl<D: Dialect> Compiler for CppCompiler<D> {
    type Representation = ComputeKernel<D>;
    type CompilationOptions = CompilationOptions;

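    /// Lowers `kernel` into its C++ representation.
    ///
    /// Cluster dimensions are stripped when the target does not support
    /// thread-block clusters, and the temporary-variable counter is reset
    /// afterwards so successive compilations start fresh.
    ///
    /// A minimal usage sketch, assuming some `MyDialect` type implementing
    /// [`Dialect`] (not a real type in this crate):
    ///
    /// ```ignore
    /// let mut compiler = CppCompiler::<MyDialect>::default();
    /// let kernel = compiler.compile(
    ///     definition,
    ///     &CompilationOptions::default(),
    ///     ExecutionMode::Checked,
    /// );
    /// ```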
    fn compile(
        &mut self,
        mut kernel: KernelDefinition,
        compilation_options: &Self::CompilationOptions,
        strategy: ExecutionMode,
    ) -> Self::Representation {
        self.compilation_options = compilation_options.clone();
        self.strategy = strategy;

        if !self.compilation_options.supports_clusters {
            kernel.options.cluster_dim = None;
        }
        self.cluster_dim = kernel.options.cluster_dim.unwrap_or(CubeDim::new_single());

        let ir = self.clone().compile_ir(kernel);
        COUNTER_TMP_VAR.store(0, std::sync::atomic::Ordering::Relaxed);
        ir
    }

    fn elem_size(&self, elem: gpu::ElemType) -> usize {
        elem.size()
    }

    fn extension(&self) -> &'static str {
        "cpp"
    }
}

impl<D: Dialect> CppCompiler<D> {
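    /// Core lowering pass: builds the metadata layout, compiles the body and
    /// bindings, runs shared-memory liveness analysis to lay out shared
    /// allocations, and assembles the final [`ComputeKernel`].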
    fn compile_ir(mut self, value: KernelDefinition) -> ComputeKernel<D> {
        self.build_metadata(&value);

        let instructions = self.compile_scope(&mut value.body.clone());
        let tensor_maps = value
            .tensor_maps
            .into_iter()
            .map(|b| self.compile_binding(b))
            .collect();
        let buffers = value
            .buffers
            .into_iter()
            .map(|b| self.compile_binding(b))
            .collect();
        let scalars = value
            .scalars
            .into_iter()
            .map(|binding| (self.compile_storage_type(binding.ty), binding.count))
            .collect();

        let flags = Flags {
            indexes: D::builtin_rules(&self.flags.indexes),
            inst_wmma: self.flags.inst_wmma,
            op_pipeline: self.flags.op_pipeline,
            op_barrier: self.flags.op_barrier,
            elem_fp4: self.flags.elem_fp4,
            elem_fp6: self.flags.elem_fp6,
            elem_fp8: self.flags.elem_fp8,
            elem_bf16: self.flags.elem_bf16,
            elem_f16: self.flags.elem_f16,
            elem_tf32: self.flags.elem_tf32,
            inst_fast_math: value
                .options
                .fp_math_mode
                .contains(FastMath::ReducedPrecision),
            inst_tma: self.flags.inst_tma,
            inst_tma_im2col: self.flags.inst_tma_im2col,
            use_grid_constants: self.compilation_options.grid_constants,
            has_dynamic_meta: self.metadata.static_len() > 0,
            static_meta_length: self.metadata.static_len() as usize,
            cube_dim: value.cube_dim,
            cluster_dim: value.options.cluster_dim,
        };

        let mut opt = Optimizer::shared_only(value.body, value.cube_dim);
        let shared_allocs = opt.analysis::<SharedLiveness>();
        let shared_memories = shared_allocs
            .allocations
            .values()
            .map(|alloc| SharedMemory {
                index: alloc.smem.id,
                item: self.compile_type(alloc.smem.ty),
                length: alloc.smem.length,
                align: alloc.smem.align,
                offset: alloc.offset,
            })
            .collect();

        let body = Body {
            instructions,
            shared_memories,
            pipelines: self.pipelines,
            barriers: self.barriers,
            const_arrays: self.const_arrays,
            local_arrays: self.local_arrays,
        };

        let mut cluster_dim = value.options.cluster_dim;
        if !self.compilation_options.supports_clusters {
            cluster_dim = None;
        }

        ComputeKernel {
            tensor_maps,
            buffers,
            scalars,
            meta_static_len: self.metadata.static_len() as usize,
            cube_dim: value.cube_dim,
            body,
            extensions: self.extensions,
            flags,
            items: self.items,
            kernel_name: value.options.kernel_name,
            cluster_dim,
        }
    }

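    /// Computes the extended-metadata slot of every buffer and tensor map,
    /// ordered by binding id. Only bindings with `has_extended_meta` consume
    /// a slot; e.g. with `has_extended_meta = [true, false, true]` the
    /// positions come out as `[0, 1, 1]` with two extended entries in total.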
    fn build_metadata(&mut self, value: &KernelDefinition) {
        let mut num_ext = 0;

        let mut all_meta: Vec<_> = value
            .buffers
            .iter()
            .chain(value.tensor_maps.iter())
            .map(|buf| (buf.id, buf.has_extended_meta))
            .collect();

        all_meta.sort_by_key(|(id, _)| *id);

        for (_, has_extended_meta) in &all_meta {
            self.ext_meta_positions.push(num_ext);
            if *has_extended_meta {
                num_ext += 1;
            }
        }

        let num_meta = all_meta.len();

        self.metadata = cubecl_core::Metadata::new(num_meta as u32, num_ext);
    }

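    /// Returns the extended-metadata slot computed by [`Self::build_metadata`]
    /// for the buffer backing `var`.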
    pub(crate) fn ext_meta_position(&self, var: gpu::Variable) -> u32 {
        let id = var.index().expect("Variable should have index");
        self.ext_meta_positions[id as usize]
    }

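    /// Compiles a scope: hoists its constant arrays, runs the checked-IO
    /// processor and the dialect's own processors over it, declares the
    /// resulting variables, then lowers every remaining instruction.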
    fn compile_scope(&mut self, scope: &mut gpu::Scope) -> Vec<Instruction<D>> {
        let mut instructions = Vec::new();

        let const_arrays = scope
            .const_arrays
            .drain(..)
            .map(|(var, values)| ConstArray {
                index: var.index().unwrap(),
                item: self.compile_type(var.ty),
                size: values.len() as u32,
                values: values
                    .into_iter()
                    .map(|val| self.compile_variable(val))
                    .collect(),
            })
            .collect::<Vec<_>>();
        self.const_arrays.extend(const_arrays);

        let checked_io: Box<dyn Processor> = Box::new(CheckedIoProcessor::new(self.strategy));
        let dialect_processors = D::processors();
        let mut processors: Vec<&dyn Processor> = vec![&*checked_io];
        processors.extend(dialect_processors.iter().map(|it| &**it));

        let processing = scope.process(processors);

        for var in processing.variables {
            instructions.push(Instruction::DeclareVariable {
                var: self.compile_variable(var),
            });
        }

        processing
            .instructions
            .into_iter()
            .for_each(|op| self.compile_instruction(&mut instructions, op));

        instructions
    }

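    /// Lowers a single IR instruction, setting feature flags (TMA, WMMA,
    /// plane indices, ...) as a side effect whenever an operation needs them.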
    fn compile_instruction(
        &mut self,
        instructions: &mut Vec<Instruction<D>>,
        instruction: gpu::Instruction,
    ) {
        self.update_debug_loc(instructions, &instruction);
        let out = instruction.out;
        match instruction.operation {
            gpu::Operation::Copy(variable) => {
                instructions.push(Instruction::Assign(UnaryInstruction {
                    input: self.compile_variable(variable),
                    out: self.compile_variable(out.unwrap()),
                }));
            }
            gpu::Operation::Arithmetic(op) => self.compile_arithmetic(op, out, instructions),
            gpu::Operation::Comparison(op) => self.compile_comparison(op, out, instructions),
            gpu::Operation::Bitwise(op) => self.compile_bitwise(op, out, instructions),
            gpu::Operation::Operator(op) => self.compile_operator(op, out, instructions),
            gpu::Operation::Atomic(op) => self.compile_atomic(op, out, instructions),
            gpu::Operation::Metadata(op) => instructions.push(self.compile_metadata(op, out)),
            gpu::Operation::Branch(val) => self.compile_branch(instructions, val),
            gpu::Operation::Synchronization(val) => match val {
                gpu::Synchronization::SyncCube => instructions.push(Instruction::SyncThreads),
                gpu::Synchronization::SyncPlane => instructions.push(Instruction::SyncWarp),
                gpu::Synchronization::SyncStorage => instructions.push(Instruction::SyncThreads),
                gpu::Synchronization::SyncProxyShared => {
                    self.flags.inst_tma = true;
                    instructions.push(Instruction::ProxySharedFence)
                }
            },
            gpu::Operation::Plane(op) => {
                self.flags.indexes.plane_dim_checked = true;
                let out = self.compile_variable(out.unwrap());
                match op {
                    gpu::Plane::Sum(op) => {
                        let instruction = WarpInstruction::ReduceSum {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction));
                    }
                    gpu::Plane::InclusiveSum(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::InclusiveSum {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::InclusiveProd(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::InclusiveProd {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::ExclusiveSum(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::ExclusiveSum {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::ExclusiveProd(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::ExclusiveProd {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Prod(op) => {
                        let instruction = WarpInstruction::ReduceProd {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Max(op) => {
                        let instruction = WarpInstruction::ReduceMax {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Min(op) => {
                        let instruction = WarpInstruction::ReduceMin {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Elect => {
                        instructions.push(Instruction::Warp(WarpInstruction::Elect { out }))
                    }
                    gpu::Plane::All(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::All {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Any(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Any {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Ballot(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Ballot {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Broadcast(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Broadcast {
                            input: self.compile_variable(op.lhs),
                            id: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::Shuffle(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Shuffle {
                            input: self.compile_variable(op.lhs),
                            src_lane: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::ShuffleXor(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleXor {
                            input: self.compile_variable(op.lhs),
                            mask: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::ShuffleUp(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleUp {
                            input: self.compile_variable(op.lhs),
                            delta: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::ShuffleDown(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleDown {
                            input: self.compile_variable(op.lhs),
                            delta: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                }
            }
            gpu::Operation::CoopMma(cmma) => instructions.push(self.compile_cmma(cmma, out)),
            gpu::Operation::NonSemantic(debug) => match debug {
                gpu::NonSemantic::Print {
                    format_string,
                    args,
                } => instructions.push(Instruction::Printf {
                    format_string,
                    args: args
                        .into_iter()
                        .map(|arg| self.compile_variable(arg))
                        .collect(),
                }),
                gpu::NonSemantic::Comment { content } => {
                    instructions.push(Instruction::Comment { content })
                }
                _ => {}
            },
            gpu::Operation::Barrier(barrier_ops) => match barrier_ops {
                gpu::BarrierOps::Init {
                    barrier,
                    with_cta_fence,
                } => {
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    let barrier = self.compile_variable(barrier);
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Init {
                        barrier,
                        level,
                        with_cta_fence,
                    }));
                }
                gpu::BarrierOps::MemCopyAsync {
                    barrier,
                    source,
                    source_length,
                    offset_source,
                    offset_out,
                } => {
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsync {
                            barrier: self.compile_variable(barrier),
                            source: self.compile_variable(source),
                            destination: self.compile_variable(out.unwrap()),
                            source_length: self.compile_variable(source_length),
                            offset_source: self.compile_variable(offset_source),
                            offset_out: self.compile_variable(offset_out),
                            level,
                        },
                    ));
                }
                gpu::BarrierOps::TmaLoad {
                    barrier,
                    tensor_map,
                    offset_out,
                    indices,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsyncTensorGlobalToShared {
                            barrier: self.compile_variable(barrier),
                            smem_buffer: self.compile_variable(out.unwrap()),
                            smem_offset: self.compile_variable(offset_out),
                            tensor_map: self.compile_variable(tensor_map),
                            indices: indices
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        },
                    ));
                }
                gpu::BarrierOps::TmaLoadIm2col {
                    barrier,
                    tensor_map,
                    offset_out,
                    indices,
                    offsets,
                } => {
                    self.flags.inst_tma_im2col = true;
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::TmaLoadIm2col {
                            barrier: self.compile_variable(barrier),
                            smem_buffer: self.compile_variable(out.unwrap()),
                            smem_offset: self.compile_variable(offset_out),
                            tensor_map: self.compile_variable(tensor_map),
                            indices: indices
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                            offsets: offsets
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        },
                    ));
                }
                gpu::BarrierOps::Arrive { barrier } => {
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Arrive {
                        barrier: self.compile_variable(barrier),
                        level,
                    }))
                }
                gpu::BarrierOps::ArriveTx {
                    barrier,
                    arrive_count_update,
                    transaction_count_update,
                } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::ArriveTx {
                        barrier: self.compile_variable(barrier),
                        arrive_count_update: self.compile_variable(arrive_count_update),
                        transaction_count_update: self.compile_variable(transaction_count_update),
                    }))
                }
                gpu::BarrierOps::ExpectTx {
                    barrier,
                    transaction_count_update,
                } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::ExpectTx {
                        barrier: self.compile_variable(barrier),
                        transaction_count_update: self.compile_variable(transaction_count_update),
                    }))
                }
                gpu::BarrierOps::Wait { barrier } => {
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Wait {
                        barrier: self.compile_variable(barrier),
                        level,
                    }))
                }
                gpu::BarrierOps::ArriveAndWait { barrier } => {
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::ArriveAndWait {
                            barrier: self.compile_variable(barrier),
                            level,
                        },
                    ))
                }
            },
            gpu::Operation::Tma(tma_ops) => {
                self.flags.inst_tma = true;
                match tma_ops {
                    gpu::TmaOps::TmaStore {
                        source,
                        coordinates,
                        offset_source,
                    } => {
                        instructions.push(Instruction::MemCopyAsyncTensorSharedToGlobal {
                            smem_buffer: self.compile_variable(source),
                            smem_offset: self.compile_variable(offset_source),
                            tensor_map: self.compile_variable(out.unwrap()),
                            indices: coordinates
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        });
                    }
                    gpu::TmaOps::CommitGroup => {
                        instructions.push(Instruction::BulkCommitGroup);
                    }
                    gpu::TmaOps::WaitGroup { max_pending } => {
                        instructions.push(Instruction::BulkWaitGroup { max_pending });
                    }
                    gpu::TmaOps::WaitGroupRead { max_pending } => {
                        instructions.push(Instruction::BulkWaitGroupRead { max_pending });
                    }
                }
            }
            gpu::Operation::Free(_) => {}
        }
    }

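    /// Emits a `Line` marker whenever the source location changes between
    /// instructions so the generated C++ can be traced back to the original
    /// kernel source; non-semantic instructions are skipped.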
    fn update_debug_loc(
        &mut self,
        instructions: &mut Vec<Instruction<D>>,
        inst: &gpu::Instruction,
    ) {
        if !matches!(inst.operation, Operation::NonSemantic(_)) {
            match &inst.source_loc {
                Some(loc) if Some(loc) != self.source_loc.as_ref() => {
                    self.source_loc = Some(loc.clone());
                    instructions.push(Instruction::Line {
                        file: loc.source.file.clone(),
                        line: loc.line,
                    });
                }
                _ => {}
            }
        }
    }

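    /// Lowers cooperative-matrix (CMMA) operations to WMMA instructions and
    /// registers any dialect extension the chosen instruction requires.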
    fn compile_cmma(&mut self, cmma: gpu::CoopMma, out: Option<gpu::Variable>) -> Instruction<D> {
        self.flags.inst_wmma = true;

        let out = self.compile_variable(out.unwrap());

        let inst = match cmma {
            gpu::CoopMma::Fill { value } => WmmaInstruction::Fill {
                frag: out,
                value: self.compile_variable(value),
            },
            gpu::CoopMma::Load {
                value,
                stride,
                offset,
                layout,
            } => WmmaInstruction::Load {
                frag: out,
                offset: self.compile_variable(offset),
                value: self.compile_variable(value),
                stride: self.compile_variable(stride),
                layout: layout.and_then(|l| self.compile_matrix_layout(l)),
            },
            gpu::CoopMma::Execute {
                mat_a,
                mat_b,
                mat_c,
            } => WmmaInstruction::Execute {
                frag_a: self.compile_variable(mat_a),
                frag_b: self.compile_variable(mat_b),
                frag_c: self.compile_variable(mat_c),
                frag_d: out,
                warp_size: self.compilation_options.warp_size,
            },
            gpu::CoopMma::ExecuteManual {
                matrix,
                registers_a,
                registers_b,
                registers_c,
            } => WmmaInstruction::ExecuteManual {
                shape: MmaShape::new(matrix.m, matrix.n, matrix.k),
                frag_a: registers_a
                    .into_iter()
                    .map(|it| self.compile_variable(it))
                    .collect(),
                frag_b: registers_b
                    .into_iter()
                    .map(|it| self.compile_variable(it))
                    .collect(),
                frag_c: registers_c
                    .into_iter()
                    .map(|it| self.compile_variable(it))
                    .collect(),
                frag_d: out,
            },
            gpu::CoopMma::ExecuteScaled {
                matrix,
                registers_a,
                registers_b,
                registers_c,
                scales_a,
                scales_b,
                scales_factor,
            } => WmmaInstruction::ExecuteScaled {
                shape: MmaShape::new(matrix.m, matrix.n, matrix.k),
                frag_a: registers_a
                    .into_iter()
                    .map(|it| self.compile_variable(it))
                    .collect(),
                frag_b: registers_b
                    .into_iter()
                    .map(|it| self.compile_variable(it))
                    .collect(),
                frag_c: registers_c
                    .into_iter()
                    .map(|it| self.compile_variable(it))
                    .collect(),
                frag_d: out,
                scales_a: self.compile_variable(scales_a),
                scales_b: self.compile_variable(scales_b),
                scales_factor,
            },
            gpu::CoopMma::Store {
                mat,
                stride,
                offset,
                layout,
            } => {
                self.flags.indexes.unit_pos = true;
                self.flags.indexes.plane_index = true;
                WmmaInstruction::Store {
                    output: out,
                    offset: self.compile_variable(offset),
                    frag: self.compile_variable(mat),
                    stride: self.compile_variable(stride),
                    layout: self
                        .compile_matrix_layout(layout)
                        .expect("Layout required for store instruction"),
                }
            }
            gpu::CoopMma::Cast { input } => WmmaInstruction::Cast {
                input: self.compile_variable(input),
                output: out,
            },
            gpu::CoopMma::RowIndex { .. } | gpu::CoopMma::ColIndex { .. } => {
                panic!("Row/Col index should be handled by processors")
            }
        };

        D::register_wmma_instruction_extension(&mut self.extensions, &inst);

        Instruction::Wmma(inst)
    }

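    /// Lowers metadata queries (stride, shape, rank, lengths) into reads at
    /// the matching offset inside the metadata buffer. Slice and shared
    /// memory lengths are resolved without a buffer read.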
    fn compile_metadata(
        &mut self,
        metadata: gpu::Metadata,
        out: Option<gpu::Variable>,
    ) -> Instruction<D> {
        let out = out.unwrap();
        match metadata {
            gpu::Metadata::Stride { dim, var } => {
                let position = self.ext_meta_position(var);
                let offset = self.metadata.stride_offset_index(position);
                Instruction::ExtendedMetadata {
                    info_offset: self.compile_variable(offset.into()),
                    dim: self.compile_variable(dim),
                    split_meta: self.compilation_options.grid_constants,
                    static_offset: self.metadata.static_len(),
                    out: self.compile_variable(out),
                }
            }
            gpu::Metadata::Shape { dim, var } => {
                let position = self.ext_meta_position(var);
                let offset = self.metadata.shape_offset_index(position);
                Instruction::ExtendedMetadata {
                    info_offset: self.compile_variable(offset.into()),
                    dim: self.compile_variable(dim),
                    split_meta: self.compilation_options.grid_constants,
                    static_offset: self.metadata.static_len(),
                    out: self.compile_variable(out),
                }
            }
            gpu::Metadata::Rank { var } => {
                let out = self.compile_variable(out);
                let pos = self.ext_meta_position(var);
                let offset = self.metadata.rank_index(pos);
                super::Instruction::Metadata {
                    info_offset: self.compile_variable(offset.into()),
                    split_meta: self.compilation_options.grid_constants,
                    out,
                }
            }
            gpu::Metadata::Length { var } => {
                let input = self.compile_variable(var);
                let out = self.compile_variable(out);

                match input {
                    Variable::Slice { .. } => Instruction::SliceLength { input, out },
                    Variable::SharedMemory(_id, _item, length) => {
                        Instruction::ConstLength { length, out }
                    }
                    _ => {
                        let id = input.id().expect("Variable should have id");
                        let offset = self.metadata.len_index(id);
                        Instruction::Metadata {
                            info_offset: self.compile_variable(offset.into()),
                            split_meta: self.compilation_options.grid_constants,
                            out,
                        }
                    }
                }
            }
            gpu::Metadata::BufferLength { var } => {
                let input = self.compile_variable(var);
                let out = self.compile_variable(out);

                match input {
                    Variable::Slice { .. } => Instruction::SliceLength { input, out },
                    _ => {
                        let id = input.id().expect("Variable should have id");
                        let offset = self.metadata.buffer_len_index(id);
                        Instruction::Metadata {
                            info_offset: self.compile_variable(offset.into()),
                            split_meta: self.compilation_options.grid_constants,
                            out,
                        }
                    }
                }
            }
        }
    }

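    /// Lowers control flow by recursively compiling each nested scope.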
    fn compile_branch(&mut self, instructions: &mut Vec<Instruction<D>>, branch: gpu::Branch) {
        match branch {
            gpu::Branch::If(mut op) => instructions.push(Instruction::If {
                cond: self.compile_variable(op.cond),
                instructions: self.compile_scope(&mut op.scope),
            }),
            gpu::Branch::IfElse(mut op) => instructions.push(Instruction::IfElse {
                cond: self.compile_variable(op.cond),
                instructions_if: self.compile_scope(&mut op.scope_if),
                instructions_else: self.compile_scope(&mut op.scope_else),
            }),
            gpu::Branch::Switch(mut op) => instructions.push(Instruction::Switch {
                value: self.compile_variable(op.value),
                instructions_default: self.compile_scope(&mut op.scope_default),
                instructions_cases: op
                    .cases
                    .into_iter()
                    .map(|(val, mut block)| {
                        (self.compile_variable(val), self.compile_scope(&mut block))
                    })
                    .collect(),
            }),
            gpu::Branch::Return => instructions.push(Instruction::Return),
            gpu::Branch::Break => instructions.push(Instruction::Break),
            gpu::Branch::RangeLoop(mut range_loop) => instructions.push(Instruction::RangeLoop {
                i: self.compile_variable(range_loop.i),
                start: self.compile_variable(range_loop.start),
                end: self.compile_variable(range_loop.end),
                step: range_loop.step.map(|it| self.compile_variable(it)),
                inclusive: range_loop.inclusive,
                instructions: self.compile_scope(&mut range_loop.scope),
            }),
            gpu::Branch::Loop(mut op) => instructions.push(Instruction::Loop {
                instructions: self.compile_scope(&mut op.scope),
            }),
        };
    }

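    /// Lowers atomic operations; each maps one-to-one onto a dedicated
    /// instruction.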
    fn compile_atomic(
        &mut self,
        value: gpu::AtomicOp,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::AtomicOp::Load(op) => {
                instructions.push(Instruction::AtomicLoad(self.compile_unary(op, out)))
            }
            gpu::AtomicOp::Store(op) => {
                instructions.push(Instruction::AtomicStore(self.compile_unary(op, out)))
            }
            gpu::AtomicOp::Swap(op) => {
                instructions.push(Instruction::AtomicSwap(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Add(op) => {
                instructions.push(Instruction::AtomicAdd(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Sub(op) => {
                instructions.push(Instruction::AtomicSub(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Max(op) => {
                instructions.push(Instruction::AtomicMax(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Min(op) => {
                instructions.push(Instruction::AtomicMin(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::And(op) => {
                instructions.push(Instruction::AtomicAnd(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Or(op) => {
                instructions.push(Instruction::AtomicOr(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Xor(op) => {
                instructions.push(Instruction::AtomicXor(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::CompareAndSwap(op) => instructions.push(Instruction::AtomicCAS {
                input: self.compile_variable(op.input),
                cmp: self.compile_variable(op.cmp),
                val: self.compile_variable(op.val),
                out: self.compile_variable(out),
            }),
        }
    }

    fn compile_arithmetic(
        &mut self,
        value: gpu::Arithmetic,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::Arithmetic::Add(op) => {
                instructions.push(Instruction::Add(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::SaturatingAdd(op) => {
                instructions.push(Instruction::SaturatingAdd(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Mul(op) => {
                instructions.push(Instruction::Mul(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Div(op) => {
                instructions.push(Instruction::Div(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Sub(op) => {
                instructions.push(Instruction::Sub(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::SaturatingSub(op) => {
                instructions.push(Instruction::SaturatingSub(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::MulHi(op) => {
                let instruction = Instruction::HiMul(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Modulo(op) => {
                instructions.push(Instruction::Modulo(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Abs(op) => {
                instructions.push(Instruction::Abs(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Exp(op) => {
                instructions.push(Instruction::Exp(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Log(op) => {
                instructions.push(Instruction::Log(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Log1p(op) => {
                instructions.push(Instruction::Log1p(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Cos(op) => {
                instructions.push(Instruction::Cos(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Sin(op) => {
                instructions.push(Instruction::Sin(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Tanh(op) => {
                let instruction = Instruction::Tanh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Powf(op) => {
                instructions.push(Instruction::Powf(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Powi(op) => {
                instructions.push(Instruction::Powi(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Sqrt(op) => {
                instructions.push(Instruction::Sqrt(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Erf(op) => {
                let instruction = Instruction::Erf(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Max(op) => {
                let instruction = Instruction::Max(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Min(op) => {
                let instruction = Instruction::Min(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Clamp(op) => instructions.push(Instruction::Clamp {
                input: self.compile_variable(op.input),
                min_value: self.compile_variable(op.min_value),
                max_value: self.compile_variable(op.max_value),
                out: self.compile_variable(out),
            }),
            gpu::Arithmetic::Recip(op) => {
                // Lower `recip(x)` as `1 / x`, with the constant `1` typed after the input.
                let elem = op.input.ty.elem_type();
                let lhs = match elem {
                    gpu::ElemType::Float(kind) => gpu::ConstantScalarValue::Float(1.0, kind),
                    gpu::ElemType::Int(kind) => gpu::ConstantScalarValue::Int(1, kind),
                    gpu::ElemType::UInt(kind) => gpu::ConstantScalarValue::UInt(1, kind),
                    gpu::ElemType::Bool => gpu::ConstantScalarValue::Bool(true),
                };

                instructions.push(Instruction::Div(BinaryInstruction {
                    lhs: Variable::ConstantScalar(lhs, self.compile_elem(elem)),
                    rhs: self.compile_variable(op.input),
                    out: self.compile_variable(out),
                }))
            }
            gpu::Arithmetic::Round(op) => {
                instructions.push(Instruction::Round(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Floor(op) => {
                instructions.push(Instruction::Floor(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Ceil(op) => {
                instructions.push(Instruction::Ceil(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Trunc(op) => {
                instructions.push(Instruction::Trunc(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Remainder(op) => {
                instructions.push(Instruction::Remainder(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Fma(op) => instructions.push(Instruction::Fma {
                a: self.compile_variable(op.a),
                b: self.compile_variable(op.b),
                c: self.compile_variable(op.c),
                out: self.compile_variable(out),
            }),
            gpu::Arithmetic::Neg(op) => {
                instructions.push(Instruction::Neg(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Normalize(op) => {
                instructions.push(Instruction::Normalize(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Magnitude(op) => {
                instructions.push(Instruction::Magnitude(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Dot(op) => {
                instructions.push(Instruction::Dot(self.compile_binary(op, out)))
            }
        };
    }

    fn compile_comparison(
        &mut self,
        value: gpu::Comparison,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::Comparison::Equal(op) => {
                instructions.push(Instruction::Equal(self.compile_binary(op, out)))
            }
            gpu::Comparison::Lower(op) => {
                instructions.push(Instruction::Lower(self.compile_binary(op, out)))
            }
            gpu::Comparison::Greater(op) => {
                instructions.push(Instruction::Greater(self.compile_binary(op, out)))
            }
            gpu::Comparison::LowerEqual(op) => {
                instructions.push(Instruction::LowerEqual(self.compile_binary(op, out)))
            }
            gpu::Comparison::GreaterEqual(op) => {
                instructions.push(Instruction::GreaterEqual(self.compile_binary(op, out)))
            }
            gpu::Comparison::NotEqual(op) => {
                instructions.push(Instruction::NotEqual(self.compile_binary(op, out)))
            }
            gpu::Comparison::IsNan(op) => {
                instructions.push(Instruction::IsNan(self.compile_unary(op, out)))
            }
            gpu::Comparison::IsInf(op) => {
                instructions.push(Instruction::IsInf(self.compile_unary(op, out)))
            }
        };
    }

    fn compile_bitwise(
        &mut self,
        value: gpu::Bitwise,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::Bitwise::BitwiseOr(op) => {
                instructions.push(Instruction::BitwiseOr(self.compile_binary(op, out)))
            }
            gpu::Bitwise::BitwiseAnd(op) => {
                instructions.push(Instruction::BitwiseAnd(self.compile_binary(op, out)))
            }
            gpu::Bitwise::BitwiseXor(op) => {
                instructions.push(Instruction::BitwiseXor(self.compile_binary(op, out)))
            }
            gpu::Bitwise::CountOnes(op) => {
                instructions.push(Instruction::CountBits(self.compile_unary(op, out)))
            }
            gpu::Bitwise::ReverseBits(op) => {
                instructions.push(Instruction::ReverseBits(self.compile_unary(op, out)))
            }
            gpu::Bitwise::ShiftLeft(op) => {
                instructions.push(Instruction::ShiftLeft(self.compile_binary(op, out)))
            }
            gpu::Bitwise::ShiftRight(op) => {
                instructions.push(Instruction::ShiftRight(self.compile_binary(op, out)))
            }
            gpu::Bitwise::BitwiseNot(op) => {
                instructions.push(Instruction::BitwiseNot(self.compile_unary(op, out)))
            }
            gpu::Bitwise::LeadingZeros(op) => {
                instructions.push(Instruction::LeadingZeros(self.compile_unary(op, out)))
            }
            gpu::Bitwise::FindFirstSet(op) => {
                let instruction = Instruction::FindFirstSet(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
        };
    }

    fn compile_operator(
        &mut self,
        value: gpu::Operator,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::Operator::Index(op) | gpu::Operator::UncheckedIndex(op) => {
                instructions.push(Instruction::Index(self.compile_index(op, out)));
            }
            gpu::Operator::IndexAssign(op) | gpu::Operator::UncheckedIndexAssign(op) => {
                instructions.push(Instruction::IndexAssign(self.compile_index_assign(op, out)));
            }
            gpu::Operator::And(op) => {
                instructions.push(Instruction::And(self.compile_binary(op, out)))
            }
            gpu::Operator::Or(op) => {
                instructions.push(Instruction::Or(self.compile_binary(op, out)))
            }
            gpu::Operator::Not(op) => {
                instructions.push(Instruction::Not(self.compile_unary(op, out)))
            }
            gpu::Operator::InitLine(op) => instructions.push(Instruction::VecInit {
                inputs: op
                    .inputs
                    .into_iter()
                    .map(|it| self.compile_variable(it))
                    .collect(),
                out: self.compile_variable(out),
            }),
            gpu::Operator::CopyMemory(op) => instructions.push(Instruction::Copy {
                input: self.compile_variable(op.input),
                in_index: self.compile_variable(op.in_index),
                out: self.compile_variable(out),
                out_index: self.compile_variable(op.out_index),
            }),
            gpu::Operator::CopyMemoryBulk(op) => instructions.push(Instruction::CopyBulk {
                input: self.compile_variable(op.input),
                in_index: self.compile_variable(op.in_index),
                out: self.compile_variable(out),
                out_index: self.compile_variable(op.out_index),
                len: op.len,
            }),
            gpu::Operator::Select(op) => instructions.push(Instruction::Select {
                cond: self.compile_variable(op.cond),
                then: self.compile_variable(op.then),
                or_else: self.compile_variable(op.or_else),
                out: self.compile_variable(out),
            }),
            gpu::Operator::Cast(op)
                if is_fp4_fp6_fp8(op.input.elem_type()) || is_fp4_fp6_fp8(out.elem_type()) =>
            {
                // Minifloat casts are lowered through `f16`/`bf16` intermediates, so set
                // both element flags and pre-register every intermediate line type the
                // generated conversion code may reference.
                self.flags.elem_f16 = true;
                self.flags.elem_bf16 = true;
                let vec_in = op.input.ty.line_size();
                let packing = out.storage_type().packing_factor();
                self.compile_type(op.input.ty.line(packing));
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::F16)).line(vec_in),
                );
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::BF16)).line(vec_in),
                );
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::F16)).line(packing),
                );
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::BF16)).line(packing),
                );

                let inst = self.compile_unary(op, out);

                instructions.push(Instruction::SpecialCast(inst));
            }
            gpu::Operator::Cast(op) => {
                let op = self.compile_unary(op, out);

                if op.input.elem() == Elem::TF32 || op.out.elem() == Elem::TF32 {
                    self.flags.elem_tf32 = true;
                }

                instructions.push(Instruction::Assign(op))
            }
            gpu::Operator::Reinterpret(op) => {
                instructions.push(Instruction::Bitcast(self.compile_unary(op, out)))
            }
        };
    }

    fn compile_binary(
        &mut self,
        value: gpu::BinaryOperator,
        out: gpu::Variable,
    ) -> BinaryInstruction<D> {
        BinaryInstruction {
            lhs: self.compile_variable(value.lhs),
            rhs: self.compile_variable(value.rhs),
            out: self.compile_variable(out),
        }
    }

    fn compile_index_assign(
        &mut self,
        value: gpu::IndexAssignOperator,
        out: gpu::Variable,
    ) -> IndexAssignInstruction<D> {
        IndexAssignInstruction {
            index: self.compile_variable(value.index),
            value: self.compile_variable(value.value),
            line_size: value.line_size,
            out: self.compile_variable(out),
        }
    }

    fn compile_index(
        &mut self,
        value: gpu::IndexOperator,
        out: gpu::Variable,
    ) -> IndexInstruction<D> {
        IndexInstruction {
            list: self.compile_variable(value.list),
            index: self.compile_variable(value.index),
            line_size: value.line_size,
            out: self.compile_variable(out),
        }
    }

    fn compile_unary(
        &mut self,
        value: gpu::UnaryOperator,
        out: gpu::Variable,
    ) -> UnaryInstruction<D> {
        UnaryInstruction {
            input: self.compile_variable(value.input),
            out: self.compile_variable(out),
        }
    }

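    /// Lowers an IR variable, recording builtin-index usage and lazily
    /// registering local arrays, pipelines, and barriers on first use.
    /// Cluster builtins fold to the constant `0` when clusters are not
    /// supported by the target.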
    fn compile_variable(&mut self, value: gpu::Variable) -> Variable<D> {
        let item = value.ty;
        match value.kind {
            gpu::VariableKind::GlobalInputArray(id) => {
                Variable::GlobalInputArray(id, self.compile_type(item))
            }
            gpu::VariableKind::GlobalScalar(id) => Variable::GlobalScalar {
                id,
                elem: self.compile_storage_type(item.storage_type()),
                in_struct: self.compilation_options.grid_constants,
            },
            gpu::VariableKind::TensorMapInput(id) => {
                self.flags.inst_tma = true;
                Variable::TensorMap(id)
            }
            gpu::VariableKind::TensorMapOutput(id) => {
                self.flags.inst_tma = true;
                Variable::TensorMap(id)
            }
            gpu::VariableKind::LocalMut { id } => Variable::LocalMut {
                id,
                item: self.compile_type(item),
            },
            gpu::VariableKind::Versioned { id, .. } => Variable::LocalMut {
                id,
                item: self.compile_type(item),
            },
            gpu::VariableKind::LocalConst { id } => Variable::LocalConst {
                id,
                item: self.compile_type(item),
            },
            gpu::VariableKind::GlobalOutputArray(id) => {
                Variable::GlobalOutputArray(id, self.compile_type(item))
            }
            gpu::VariableKind::ConstantScalar(value) => {
                Variable::ConstantScalar(value, self.compile_elem(value.elem_type()))
            }
            gpu::VariableKind::SharedMemory { id, length, .. } => {
                let item = self.compile_type(item);
                Variable::SharedMemory(id, item, length)
            }
            gpu::VariableKind::ConstantArray {
                id,
                length,
                unroll_factor,
            } => {
                let item = self.compile_type(item);
                Variable::ConstantArray(id, item, length * unroll_factor)
            }
            gpu::VariableKind::Builtin(builtin) => match builtin {
                gpu::Builtin::AbsolutePos => {
                    self.flags.indexes.absolute_pos = true;
                    Variable::AbsolutePos
                }
                gpu::Builtin::CubePosCluster if self.compilation_options.supports_clusters => {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterRank
                }
                gpu::Builtin::CubePosClusterX if self.compilation_options.supports_clusters => {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexX
                }
                gpu::Builtin::CubePosClusterY if self.compilation_options.supports_clusters => {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexY
                }
                gpu::Builtin::CubePosClusterZ if self.compilation_options.supports_clusters => {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexZ
                }
                gpu::Builtin::CubePosCluster
                | gpu::Builtin::CubePosClusterX
                | gpu::Builtin::CubePosClusterY
                | gpu::Builtin::CubePosClusterZ => const_u32(0),
                gpu::Builtin::AbsolutePosX => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosX
                }
                gpu::Builtin::AbsolutePosY => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosY
                }
                gpu::Builtin::AbsolutePosZ => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosZ
                }
                gpu::Builtin::CubeDim => {
                    self.flags.indexes.cube_dim = true;
                    Variable::CubeDim
                }
                gpu::Builtin::CubeDimX => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimX
                }
                gpu::Builtin::CubeDimY => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimY
                }
                gpu::Builtin::CubeDimZ => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimZ
                }
                gpu::Builtin::CubeClusterDim => const_u32(self.cluster_dim.num_elems()),
                gpu::Builtin::CubeClusterDimX => const_u32(self.cluster_dim.x),
                gpu::Builtin::CubeClusterDimY => const_u32(self.cluster_dim.y),
                gpu::Builtin::CubeClusterDimZ => const_u32(self.cluster_dim.z),
                gpu::Builtin::CubePos => {
                    self.flags.indexes.cube_pos = true;
                    Variable::CubePos
                }
                gpu::Builtin::CubePosX => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosX
                }
                gpu::Builtin::CubePosY => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosY
                }
                gpu::Builtin::CubePosZ => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosZ
                }
                gpu::Builtin::CubeCount => {
                    self.flags.indexes.cube_count = true;
                    Variable::CubeCount
                }
                gpu::Builtin::CubeCountX => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountX
                }
                gpu::Builtin::CubeCountY => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountY
                }
                gpu::Builtin::CubeCountZ => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountZ
                }
                gpu::Builtin::UnitPos => {
                    self.flags.indexes.unit_pos = true;
                    Variable::UnitPos
                }
                gpu::Builtin::UnitPosX => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosX
                }
                gpu::Builtin::UnitPosY => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosY
                }
                gpu::Builtin::UnitPosZ => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosZ
                }
                gpu::Builtin::PlaneDim => {
                    self.flags.indexes.plane_dim = true;
                    Variable::PlaneDim
                }
                gpu::Builtin::UnitPosPlane => {
                    self.flags.indexes.unit_pos_plane = true;
                    Variable::UnitPosPlane
                }
            },
            gpu::VariableKind::LocalArray {
                id,
                length,
                unroll_factor,
            } => {
                let item = self.compile_type(item);
                if !self.local_arrays.iter().any(|s| s.index == id) {
                    self.local_arrays
                        .push(LocalArray::new(id, item, length * unroll_factor));
                }
                Variable::LocalArray(id, item, length)
            }
            gpu::VariableKind::Matrix { id, mat } => {
                self.flags.inst_wmma = true;
                Variable::WmmaFragment {
                    id,
                    frag: self.compile_matrix(mat),
                }
            }
            gpu::VariableKind::Pipeline { id, num_stages } => {
                self.flags.op_pipeline = true;
                let pipeline = Variable::Pipeline { id };
                if !self.pipelines.iter().any(|s| s.pipeline_id() == id) {
                    self.pipelines.push(PipelineOps::Init {
                        pipeline,
                        num_stages,
                    });
                }
                pipeline
            }
            gpu::VariableKind::Barrier { id, level } => {
                self.flags.op_barrier = true;
                match level {
                    gpu::BarrierLevel::CubeCoop(_) | gpu::BarrierLevel::CubeManual(_) => {
                        self.flags.indexes.cube_dim = true;
                        self.flags.indexes.unit_pos = true;
                    }
                    _ => {}
                }
                Variable::Barrier { id, level }
            }
        }
    }

    fn compile_matrix(&mut self, matrix: gpu::Matrix) -> Fragment<D> {
        Fragment {
            ident: self.compile_matrix_ident(matrix.ident),
            m: matrix.m,
            n: matrix.n,
            k: matrix.k,
            elem: self.compile_storage_type(matrix.storage),
            layout: self.compile_matrix_layout(matrix.layout),
        }
    }

    fn compile_matrix_ident(&mut self, ident: gpu::MatrixIdent) -> FragmentIdent<D> {
        match ident {
            gpu::MatrixIdent::A => FragmentIdent::A,
            gpu::MatrixIdent::B => FragmentIdent::B,
            gpu::MatrixIdent::Accumulator => FragmentIdent::Accumulator,
        }
    }

    fn compile_matrix_layout(&mut self, layout: gpu::MatrixLayout) -> Option<FragmentLayout<D>> {
        match layout {
            gpu::MatrixLayout::ColMajor => Some(FragmentLayout::ColMajor),
            gpu::MatrixLayout::RowMajor => Some(FragmentLayout::RowMajor),
            gpu::MatrixLayout::Undefined => None,
        }
    }

    fn compile_binding(&mut self, binding: cubecl_core::compute::Binding) -> Binding<D> {
        Binding {
            id: binding.id,
            item: self.compile_type(binding.ty),
            location: binding.location,
            size: binding.size,
            vis: binding.visibility,
        }
    }

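    /// Compiles a type into an [`Item`], recording it so the kernel prologue
    /// can emit the matching vector type definitions. `tf32` items are
    /// registered as `f32`, which shares the same 32-bit storage layout.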
    fn compile_type(&mut self, ty: gpu::Type) -> Item<D> {
        let item = match ty {
            gpu::Type::Scalar(ty) => Item::new(self.compile_storage_type(ty), 1, false),
            gpu::Type::Line(ty, line_size) => {
                Item::new(self.compile_storage_type(ty), line_size as usize, false)
            }
            gpu::Type::Semantic(_) => unimplemented!("Can't compile semantic type"),
        };
        if item.elem != super::Elem::TF32 {
            self.items.insert(item);
            self.items.insert(item.optimized());
        } else {
            let mut item = item;
            item.elem = super::Elem::F32;
            self.items.insert(item);
        }

        item
    }

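    /// Compiles a storage type, setting the element flags that pull in the
    /// required type definitions. Two-element packed storage maps to the
    /// paired forms of the minifloat and half-precision types (`fp4x2`,
    /// `fp6x2`, `fp8x2`, `f16x2`, `bf16x2`).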
    fn compile_storage_type(&mut self, value: gpu::StorageType) -> Elem<D> {
        match value {
            gpu::StorageType::Scalar(ty) => self.compile_elem(ty),
            gpu::StorageType::Atomic(ty) => Elem::Atomic(ty.into()),
            gpu::StorageType::Packed(gpu::ElemType::Float(kind), 2) => match kind {
                FloatKind::E2M1 => {
                    self.flags.elem_fp4 = true;
                    Elem::FP4x2(FP4Kind::E2M1)
                }
                FloatKind::E2M3 => {
                    self.flags.elem_fp6 = true;
                    Elem::FP6x2(FP6Kind::E2M3)
                }
                FloatKind::E3M2 => {
                    self.flags.elem_fp6 = true;
                    Elem::FP6x2(FP6Kind::E3M2)
                }
                FloatKind::E4M3 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8x2(FP8Kind::E4M3)
                }
                FloatKind::E5M2 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8x2(FP8Kind::E5M2)
                }
                FloatKind::UE8M0 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8x2(FP8Kind::UE8M0)
                }
                FloatKind::F16 => {
                    self.flags.elem_f16 = true;
                    Elem::F16x2
                }
                FloatKind::BF16 => {
                    self.flags.elem_bf16 = true;
                    Elem::BF16x2
                }
                other => unimplemented!("Unsupported storage type: packed<{other:?}, 2>"),
            },
            other => unimplemented!("Unsupported storage type: {other}"),
        }
    }

    fn compile_elem(&mut self, value: gpu::ElemType) -> Elem<D> {
        match value {
            gpu::ElemType::Float(kind) => match kind {
                gpu::FloatKind::E2M1 => {
                    self.flags.elem_fp4 = true;
                    Elem::FP4(FP4Kind::E2M1)
                }
                gpu::FloatKind::E2M3 => {
                    self.flags.elem_fp6 = true;
                    Elem::FP6(FP6Kind::E2M3)
                }
                gpu::FloatKind::E3M2 => {
                    self.flags.elem_fp6 = true;
                    Elem::FP6(FP6Kind::E3M2)
                }
                gpu::FloatKind::E4M3 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8(FP8Kind::E4M3)
                }
                gpu::FloatKind::E5M2 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8(FP8Kind::E5M2)
                }
                gpu::FloatKind::UE8M0 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8(FP8Kind::UE8M0)
                }
                gpu::FloatKind::F16 => {
                    self.flags.elem_f16 = true;
                    Elem::F16
                }
                gpu::FloatKind::BF16 => {
                    self.flags.elem_bf16 = true;
                    Elem::BF16
                }
                gpu::FloatKind::TF32 => Elem::TF32,
                gpu::FloatKind::Flex32 => Elem::F32,
                gpu::FloatKind::F32 => Elem::F32,
                gpu::FloatKind::F64 => Elem::F64,
            },
            gpu::ElemType::Int(kind) => match kind {
                gpu::IntKind::I8 => Elem::I8,
                gpu::IntKind::I16 => Elem::I16,
                gpu::IntKind::I32 => Elem::I32,
                gpu::IntKind::I64 => Elem::I64,
            },
            gpu::ElemType::UInt(kind) => match kind {
                gpu::UIntKind::U8 => Elem::U8,
                gpu::UIntKind::U16 => Elem::U16,
                gpu::UIntKind::U32 => Elem::U32,
                gpu::UIntKind::U64 => Elem::U64,
            },
            gpu::ElemType::Bool => Elem::Bool,
        }
    }
}

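/// Returns `true` for the minifloat formats (fp4, fp6, fp8) whose casts must
/// go through the special `f16`/`bf16` conversion path.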
fn is_fp4_fp6_fp8(elem: gpu::ElemType) -> bool {
    match elem {
        gpu::ElemType::Float(kind) => matches!(
            kind,
            FloatKind::E2M1
                | FloatKind::E2M3
                | FloatKind::E3M2
                | FloatKind::E4M3
                | FloatKind::E5M2
                | FloatKind::UE8M0
        ),
        _ => false,
    }
}

/// Builds a `u32` constant scalar variable.
fn const_u32<D: Dialect>(value: u32) -> Variable<D> {
    Variable::ConstantScalar(
        gpu::ConstantScalarValue::UInt(value as u64, UIntKind::U32),
        Elem::U32,
    )
}

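/// Registers the baseline scalar and atomic types shared by the C++ targets
/// in the device properties.
///
/// A usage sketch, assuming `props` was obtained from the runtime:
///
/// ```ignore
/// register_supported_types(&mut props);
/// ```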
pub fn register_supported_types(props: &mut DeviceProperties) {
    let supported_types = [
        gpu::ElemType::UInt(gpu::UIntKind::U8),
        gpu::ElemType::UInt(gpu::UIntKind::U16),
        gpu::ElemType::UInt(gpu::UIntKind::U32),
        gpu::ElemType::UInt(gpu::UIntKind::U64),
        gpu::ElemType::Int(gpu::IntKind::I8),
        gpu::ElemType::Int(gpu::IntKind::I16),
        gpu::ElemType::Int(gpu::IntKind::I32),
        gpu::ElemType::Int(gpu::IntKind::I64),
        gpu::ElemType::Float(gpu::FloatKind::BF16),
        gpu::ElemType::Float(gpu::FloatKind::F16),
        gpu::ElemType::Float(gpu::FloatKind::F32),
        gpu::ElemType::Float(gpu::FloatKind::Flex32),
        gpu::ElemType::Bool,
    ];

    let supported_atomic_types = [
        gpu::ElemType::Int(gpu::IntKind::I32),
        gpu::ElemType::Int(gpu::IntKind::I64),
        gpu::ElemType::UInt(gpu::UIntKind::U32),
        gpu::ElemType::UInt(gpu::UIntKind::U64),
        gpu::ElemType::Float(gpu::FloatKind::F32),
    ];

    for ty in supported_types {
        props.register_type_usage(ty, TypeUsage::all_scalar());
    }

    for ty in supported_atomic_types {
        props.register_type_usage(
            gpu::StorageType::Atomic(ty),
            TypeUsage::AtomicAdd | TypeUsage::AtomicLoadStore,
        );
    }
}