use std::{collections::HashSet, fmt::Debug};

use cubecl_common::ExecutionMode;
use cubecl_core::ir::{FloatKind, InstructionModes, Processor, UIntKind, VariableKind};
use cubecl_core::post_processing::checked_io::CheckedIoProcessor;
use cubecl_core::{
    Compiler,
    ir::{self as gpu},
};
use cubecl_core::{CubeDim, ir::ElemType};
use cubecl_core::{
    ir::{Operation, SourceLoc},
    prelude::{FastMath, KernelDefinition},
};
use cubecl_opt::{Optimizer, SharedLiveness};
use cubecl_runtime::{DeviceProperties, EnumSet, TypeUsage};

use crate::shared::MmaShape;

use super::{
    BinaryInstruction, Binding, Body, Component, ComputeKernel, ConstArray, Dialect, Elem, FP6Kind,
    Fragment, FragmentIdent, FragmentLayout, IndexAssignInstruction, IndexInstruction, Instruction,
    Item, LocalArray, SharedMemory, UnaryInstruction, Variable, WarpInstruction, WmmaInstruction,
};
use super::{FP4Kind, barrier::BarrierOps};
use super::{FP8Kind, pipeline::PipelineOps};

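/// Monotonically increasing counter used to name compiler-generated
/// temporary variables. It is reset to zero at the end of every `compile`
/// call so that generated kernel sources stay deterministic.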
pub(super) static COUNTER_TMP_VAR: std::sync::atomic::AtomicU32 =
    std::sync::atomic::AtomicU32::new(0);

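/// Target-specific options handed to the compiler by the runtime.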
#[derive(Clone, Debug)]
pub struct CompilationOptions {
    pub warp_size: u32,
    pub supports_features: CppSupportedFeatures,
}

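/// Capabilities of the target hardware/toolchain that the generated C++ may
/// rely on. Features left at `false` are either dropped (e.g. cluster
/// dimensions) or lowered to a fallback (e.g. `Plane::Elect` without
/// `elect_sync`).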
#[derive(Clone, Debug, Default)]
pub struct CppSupportedFeatures {
    pub grid_constants: bool,
    pub clusters: bool,
    pub fast_math: bool,
    pub elect_sync: bool,
}

impl Default for CompilationOptions {
    fn default() -> Self {
        Self {
            warp_size: 32,
            supports_features: Default::default(),
        }
    }
}

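/// Tracks which built-in index variables a kernel actually reads, so that
/// only the corresponding declarations are emitted in the generated source.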
#[derive(Debug, Clone, Default)]
pub struct CubeIndexFlags {
    pub absolute_pos: bool,
    pub absolute_pos_tuple: bool,
    pub cube_count: bool,
    pub cube_count_tuple: bool,
    pub cube_dim: bool,
    pub cube_dim_tuple: bool,
    pub cube_pos: bool,
    pub cube_pos_tuple: bool,
    pub plane_dim: bool,
    pub plane_dim_checked: bool,
    pub plane_index: bool,
    pub unit_pos: bool,
    pub unit_pos_tuple: bool,
    pub unit_pos_plane: bool,
    pub cluster_pos: bool,
}

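/// Feature flags accumulated while lowering the IR; they decide which type
/// definitions, includes and helper declarations the final source needs.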
#[derive(Debug, Clone, Default)]
pub struct Flags {
    pub elem_fp4: bool,
    pub elem_fp6: bool,
    pub elem_fp8: bool,
    pub elem_bf16: bool,
    pub elem_f16: bool,
    pub elem_tf32: bool,
    pub indexes: CubeIndexFlags,
    pub op_barrier: bool,
    pub op_pipeline: bool,
    pub inst_tma: bool,
    pub inst_tma_im2col: bool,
    pub inst_wmma: bool,
    pub inst_ptx_wrappers: bool,
    pub use_grid_constants: bool,
    pub static_meta_length: usize,
    pub has_dynamic_meta: bool,
    pub cube_dim: CubeDim,
    pub cluster_dim: Option<CubeDim>,
}

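/// Lowers CubeCL IR into the C++ AST defined in this module, parameterized
/// by a [`Dialect`] (e.g. CUDA- or HIP-flavored output).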
#[allow(clippy::too_many_arguments)]
#[derive(Clone, Debug, Default)]
pub struct CppCompiler<D: Dialect> {
    barriers: Vec<BarrierOps<D>>,
    compilation_options: CompilationOptions,
    const_arrays: Vec<ConstArray<D>>,
    ext_meta_positions: Vec<u32>,
    cluster_dim: CubeDim,
    extensions: Vec<D::Extension>,
    flags: Flags,
    items: HashSet<Item<D>>,
    local_arrays: Vec<LocalArray<D>>,
    metadata: cubecl_core::Metadata,
    pipelines: Vec<PipelineOps<D>>,
    source_loc: Option<SourceLoc>,
    strategy: ExecutionMode,
}

impl<D: Dialect> Compiler for CppCompiler<D> {
    type Representation = ComputeKernel<D>;
    type CompilationOptions = CompilationOptions;

    fn compile(
        &mut self,
        mut kernel: KernelDefinition,
        compilation_options: &Self::CompilationOptions,
        strategy: ExecutionMode,
    ) -> Self::Representation {
        self.compilation_options = compilation_options.clone();
        self.strategy = strategy;

        if !self.compilation_options.supports_features.clusters {
            kernel.options.cluster_dim = None;
        }
        self.cluster_dim = kernel.options.cluster_dim.unwrap_or(CubeDim::new_single());

        let ir = self.clone().compile_ir(kernel);
        COUNTER_TMP_VAR.store(0, std::sync::atomic::Ordering::Relaxed);
        ir
    }

    fn elem_size(&self, elem: gpu::ElemType) -> usize {
        elem.size()
    }

    fn extension(&self) -> &'static str {
        "cpp"
    }
}

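// Usage sketch (illustrative only, not a doc-test): a runtime drives this
// type through the `Compiler` trait roughly as follows, where `MyDialect`
// stands in for any concrete `Dialect` implementation and `kernel` is a
// `KernelDefinition` produced by the frontend:
//
//     let mut compiler = CppCompiler::<MyDialect>::default();
//     let repr = compiler.compile(
//         kernel,
//         &CompilationOptions::default(),
//         ExecutionMode::Checked,
//     );
//
// `compile` clones the compiler internally, so one instance can be reused
// across kernels.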
impl<D: Dialect> CppCompiler<D> {
    fn compile_ir(mut self, value: KernelDefinition) -> ComputeKernel<D> {
        self.build_metadata(&value);

        let instructions = self.compile_scope(&mut value.body.clone());
        let tensor_maps = value
            .tensor_maps
            .into_iter()
            .map(|b| self.compile_binding(b))
            .collect();
        let buffers = value
            .buffers
            .into_iter()
            .map(|b| self.compile_binding(b))
            .collect();
        let scalars = value
            .scalars
            .into_iter()
            .map(|binding| (self.compile_storage_type(binding.ty), binding.count))
            .collect();

        let flags = Flags {
            indexes: D::builtin_rules(&self.flags.indexes),
            inst_wmma: self.flags.inst_wmma,
            op_pipeline: self.flags.op_pipeline,
            op_barrier: self.flags.op_barrier,
            elem_fp4: self.flags.elem_fp4,
            elem_fp6: self.flags.elem_fp6,
            elem_fp8: self.flags.elem_fp8,
            elem_bf16: self.flags.elem_bf16,
            elem_f16: self.flags.elem_f16,
            elem_tf32: self.flags.elem_tf32,
            inst_tma: self.flags.inst_tma,
            inst_tma_im2col: self.flags.inst_tma_im2col,
            inst_ptx_wrappers: self.flags.inst_ptx_wrappers,
            use_grid_constants: self.compilation_options.supports_features.grid_constants,
            has_dynamic_meta: self.metadata.static_len() > 0,
            static_meta_length: self.metadata.static_len() as usize,
            cube_dim: value.cube_dim,
            cluster_dim: value.options.cluster_dim,
        };

        let mut opt = Optimizer::shared_only(value.body, value.cube_dim);
        let shared_allocs = opt.analysis::<SharedLiveness>();
        let shared_memories = shared_allocs
            .allocations
            .values()
            .map(|alloc| SharedMemory {
                index: alloc.smem.id,
                item: self.compile_type(alloc.smem.ty),
                length: alloc.smem.length,
                align: alloc.smem.align,
                offset: alloc.offset,
            })
            .collect();

        let body = Body {
            instructions,
            shared_memories,
            pipelines: self.pipelines,
            barriers: self.barriers,
            const_arrays: self.const_arrays,
            local_arrays: self.local_arrays,
        };

        let mut cluster_dim = value.options.cluster_dim;
        if !self.compilation_options.supports_features.clusters {
            cluster_dim = None;
        }

        ComputeKernel {
            tensor_maps,
            buffers,
            scalars,
            meta_static_len: self.metadata.static_len() as usize,
            cube_dim: value.cube_dim,
            body,
            extensions: self.extensions,
            flags,
            items: self.items,
            kernel_name: value.options.kernel_name,
            cluster_dim,
        }
    }

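    /// Computes the position of each buffer's extended metadata (strides and
    /// shapes), ordered by binding id. For example, buffers with
    /// `has_extended_meta` of `[true, false, true]` yield positions
    /// `[0, 1, 1]` and two extended entries.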
    fn build_metadata(&mut self, value: &KernelDefinition) {
        let mut num_ext = 0;

        let mut all_meta: Vec<_> = value
            .buffers
            .iter()
            .chain(value.tensor_maps.iter())
            .map(|buf| (buf.id, buf.has_extended_meta))
            .collect();

        all_meta.sort_by_key(|(id, _)| *id);

        for (_, has_extended_meta) in &all_meta {
            self.ext_meta_positions.push(num_ext);
            if *has_extended_meta {
                num_ext += 1;
            }
        }

        let num_meta = all_meta.len();

        self.metadata = cubecl_core::Metadata::new(num_meta as u32, num_ext);
    }

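    /// Looks up the extended-metadata slot computed by `build_metadata` for
    /// a global variable; panics if the variable carries no binding index.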
    pub(crate) fn ext_meta_position(&self, var: gpu::Variable) -> u32 {
        let id = var.index().expect("Variable should have index");
        self.ext_meta_positions[id as usize]
    }

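    /// Compiles a scope into a flat instruction list: hoists constant
    /// arrays, runs the checked-IO and dialect processors, declares the
    /// resulting variables, then lowers each IR instruction.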
    fn compile_scope(&mut self, scope: &mut gpu::Scope) -> Vec<Instruction<D>> {
        let mut instructions = Vec::new();

        let const_arrays = scope
            .const_arrays
            .drain(..)
            .map(|(var, values)| ConstArray {
                index: var.index().unwrap(),
                item: self.compile_type(var.ty),
                size: values.len() as u32,
                values: values
                    .into_iter()
                    .map(|val| self.compile_variable(val))
                    .collect(),
            })
            .collect::<Vec<_>>();
        self.const_arrays.extend(const_arrays);

        let checked_io: Box<dyn Processor> = Box::new(CheckedIoProcessor::new(self.strategy));
        let dialect_processors = D::processors();
        let mut processors: Vec<&dyn Processor> = vec![&*checked_io];
        processors.extend(dialect_processors.iter().map(|it| &**it));

        let processing = scope.process(processors);

        for var in processing.variables {
            instructions.push(Instruction::DeclareVariable {
                var: self.compile_variable(var),
            });
        }

        processing
            .instructions
            .into_iter()
            .for_each(|op| self.compile_instruction(&mut instructions, op));

        instructions
    }

    fn compile_instruction(
        &mut self,
        instructions: &mut Vec<Instruction<D>>,
        instruction: gpu::Instruction,
    ) {
        self.update_debug_loc(instructions, &instruction);
        let out = instruction.out;
        match instruction.operation {
            gpu::Operation::Copy(variable) => {
                instructions.push(Instruction::Assign(UnaryInstruction {
                    input: self.compile_variable(variable),
                    out: self.compile_variable(out.unwrap()),
                }));
            }
            gpu::Operation::Arithmetic(op) => {
                self.compile_arithmetic(op, out, instruction.modes, instructions)
            }
            gpu::Operation::Comparison(op) => self.compile_comparison(op, out, instructions),
            gpu::Operation::Bitwise(op) => self.compile_bitwise(op, out, instructions),
            gpu::Operation::Operator(op) => self.compile_operator(op, out, instructions),
            gpu::Operation::Atomic(op) => self.compile_atomic(op, out, instructions),
            gpu::Operation::Metadata(op) => instructions.push(self.compile_metadata(op, out)),
            gpu::Operation::Branch(val) => self.compile_branch(instructions, val),
            gpu::Operation::Synchronization(val) => match val {
                gpu::Synchronization::SyncCube => instructions.push(Instruction::SyncThreads),
                gpu::Synchronization::SyncPlane => instructions.push(Instruction::SyncWarp),
                gpu::Synchronization::SyncStorage => instructions.push(Instruction::SyncThreads),
                gpu::Synchronization::SyncAsyncProxyShared => {
                    self.flags.inst_tma = true;
                    instructions.push(Instruction::ProxyAsyncToSharedFence)
                }
            },
            gpu::Operation::Plane(op) => {
                self.flags.indexes.plane_dim_checked = true;
                let out = self.compile_variable(out.unwrap());
                match op {
                    gpu::Plane::Sum(op) => {
                        let instruction = WarpInstruction::ReduceSum {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction));
                    }
                    gpu::Plane::InclusiveSum(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::InclusiveSum {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::InclusiveProd(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::InclusiveProd {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::ExclusiveSum(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::ExclusiveSum {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::ExclusiveProd(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::ExclusiveProd {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Prod(op) => {
                        let instruction = WarpInstruction::ReduceProd {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Max(op) => {
                        let instruction = WarpInstruction::ReduceMax {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Min(op) => {
                        let instruction = WarpInstruction::ReduceMin {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Elect => {
                        if self.compilation_options.supports_features.elect_sync {
                            self.flags.inst_ptx_wrappers = true;
                            instructions.push(Instruction::Warp(WarpInstruction::Elect { out }))
                        } else {
                            instructions
                                .push(Instruction::Warp(WarpInstruction::ElectFallback { out }))
                        }
                    }
                    gpu::Plane::All(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::All {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Any(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Any {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Ballot(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Ballot {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Broadcast(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Broadcast {
                            input: self.compile_variable(op.lhs),
                            id: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::Shuffle(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Shuffle {
                            input: self.compile_variable(op.lhs),
                            src_lane: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::ShuffleXor(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleXor {
                            input: self.compile_variable(op.lhs),
                            mask: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::ShuffleUp(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleUp {
                            input: self.compile_variable(op.lhs),
                            delta: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::ShuffleDown(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleDown {
                            input: self.compile_variable(op.lhs),
                            delta: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                }
            }
            gpu::Operation::CoopMma(cmma) => instructions.push(self.compile_cmma(cmma, out)),
            gpu::Operation::NonSemantic(debug) => match debug {
                gpu::NonSemantic::Print {
                    format_string,
                    args,
                } => instructions.push(Instruction::Printf {
                    format_string,
                    args: args
                        .into_iter()
                        .map(|arg| self.compile_variable(arg))
                        .collect(),
                }),
                gpu::NonSemantic::Comment { content } => {
                    instructions.push(Instruction::Comment { content })
                }
                _ => {}
            },
            gpu::Operation::Barrier(barrier_ops) => match barrier_ops {
                gpu::BarrierOps::Declare { barrier } => {
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    let barrier = self.compile_variable(barrier);
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Declare {
                        barrier,
                        level,
                    }));
                }
                gpu::BarrierOps::Init {
                    barrier,
                    is_elected,
                    arrival_count,
                    with_async_proxy_fence,
                } => {
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    let barrier = self.compile_variable(barrier);
                    let arrival_count = self.compile_variable(arrival_count);
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Init {
                        barrier,
                        is_elected: self.compile_variable(is_elected),
                        arrival_count,
                        level,
                        with_async_proxy_fence,
                    }));
                }
                gpu::BarrierOps::InitManual {
                    barrier,
                    arrival_count,
                } => {
                    let barrier = self.compile_variable(barrier);
                    let arrival_count = self.compile_variable(arrival_count);
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::InitManual {
                            barrier,
                            arrival_count,
                        },
                    ));
                }
                gpu::BarrierOps::MemCopyAsync {
                    barrier,
                    source,
                    source_length,
                    offset_source,
                    offset_out,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsync {
                            barrier: self.compile_variable(barrier),
                            source: self.compile_variable(source),
                            destination: self.compile_variable(out.unwrap()),
                            source_length: self.compile_variable(source_length),
                            offset_source: self.compile_variable(offset_source),
                            offset_out: self.compile_variable(offset_out),
                            cooperative: false,
                        },
                    ));
                }
                gpu::BarrierOps::MemCopyAsyncCooperative {
                    barrier,
                    source,
                    source_length,
                    offset_source,
                    offset_out,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsync {
                            barrier: self.compile_variable(barrier),
                            source: self.compile_variable(source),
                            destination: self.compile_variable(out.unwrap()),
                            source_length: self.compile_variable(source_length),
                            offset_source: self.compile_variable(offset_source),
                            offset_out: self.compile_variable(offset_out),
                            cooperative: true,
                        },
                    ));
                }
                gpu::BarrierOps::MemCopyAsyncTx {
                    barrier,
                    source,
                    source_length,
                    offset_source,
                    offset_out,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsyncTx {
                            barrier: self.compile_variable(barrier),
                            source: self.compile_variable(source),
                            destination: self.compile_variable(out.unwrap()),
                            source_length: self.compile_variable(source_length),
                            offset_source: self.compile_variable(offset_source),
                            offset_out: self.compile_variable(offset_out),
                        },
                    ));
                }
                gpu::BarrierOps::TmaLoad {
                    barrier,
                    tensor_map,
                    offset_out,
                    indices,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsyncTensorGlobalToShared {
                            barrier: self.compile_variable(barrier),
                            smem_buffer: self.compile_variable(out.unwrap()),
                            smem_offset: self.compile_variable(offset_out),
                            tensor_map: self.compile_variable(tensor_map),
                            indices: indices
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        },
                    ));
                }
                gpu::BarrierOps::TmaLoadIm2col {
                    barrier,
                    tensor_map,
                    offset_out,
                    indices,
                    offsets,
                } => {
                    self.flags.inst_tma_im2col = true;
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::TmaLoadIm2col {
                            barrier: self.compile_variable(barrier),
                            smem_buffer: self.compile_variable(out.unwrap()),
                            smem_offset: self.compile_variable(offset_out),
                            tensor_map: self.compile_variable(tensor_map),
                            indices: indices
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                            offsets: offsets
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        },
                    ));
                }
                gpu::BarrierOps::Arrive { barrier } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Arrive {
                        barrier: self.compile_variable(barrier),
                        token: self.compile_variable(out.unwrap()),
                    }))
                }
                gpu::BarrierOps::ArriveTx {
                    barrier,
                    arrive_count_update,
                    transaction_count_update,
                } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::ArriveTx {
                        barrier: self.compile_variable(barrier),
                        token: self.compile_variable(out.unwrap()),
                        arrive_count_update: self.compile_variable(arrive_count_update),
                        transaction_count_update: self.compile_variable(transaction_count_update),
                    }))
                }
                gpu::BarrierOps::ExpectTx {
                    barrier,
                    transaction_count_update,
                } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::ExpectTx {
                        barrier: self.compile_variable(barrier),
                        transaction_count_update: self.compile_variable(transaction_count_update),
                    }))
                }
                gpu::BarrierOps::Wait { barrier, token } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Wait {
                        barrier: self.compile_variable(barrier),
                        token: self.compile_variable(token),
                    }))
                }
                gpu::BarrierOps::WaitParity { barrier, phase } => instructions.push(
                    Instruction::Barrier(super::barrier::BarrierOps::WaitParity {
                        barrier: self.compile_variable(barrier),
                        phase: self.compile_variable(phase),
                    }),
                ),
                gpu::BarrierOps::ArriveAndWait { barrier } => {
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::ArriveAndWait {
                            barrier: self.compile_variable(barrier),
                            level,
                        },
                    ))
                }
            },
            gpu::Operation::Tma(tma_ops) => {
                self.flags.inst_tma = true;
                match tma_ops {
                    gpu::TmaOps::TmaStore {
                        source,
                        coordinates,
                        offset_source,
                    } => {
                        instructions.push(Instruction::MemCopyAsyncTensorSharedToGlobal {
                            smem_buffer: self.compile_variable(source),
                            smem_offset: self.compile_variable(offset_source),
                            tensor_map: self.compile_variable(out.unwrap()),
                            indices: coordinates
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        });
                    }
                    gpu::TmaOps::CommitGroup => {
                        instructions.push(Instruction::BulkCommitGroup);
                    }
                    gpu::TmaOps::WaitGroup { max_pending } => {
                        instructions.push(Instruction::BulkWaitGroup { max_pending });
                    }
                    gpu::TmaOps::WaitGroupRead { max_pending } => {
                        instructions.push(Instruction::BulkWaitGroupRead { max_pending });
                    }
                }
            }
            gpu::Operation::Marker(_) => {}
        }
    }

    fn update_debug_loc(
        &mut self,
        instructions: &mut Vec<Instruction<D>>,
        inst: &gpu::Instruction,
    ) {
        if !matches!(inst.operation, Operation::NonSemantic(_)) {
            match &inst.source_loc {
                Some(loc) if Some(loc) != self.source_loc.as_ref() => {
                    self.source_loc = Some(loc.clone());
                    instructions.push(Instruction::Line {
                        file: loc.source.file.clone(),
                        line: loc.line,
                    });
                }
                _ => {}
            }
        }
    }

    fn compile_cmma(&mut self, cmma: gpu::CoopMma, out: Option<gpu::Variable>) -> Instruction<D> {
        self.flags.inst_wmma = true;

        let out = self.compile_variable(out.unwrap());

        let inst = match cmma {
            gpu::CoopMma::Fill { value } => WmmaInstruction::Fill {
                frag: out,
                value: self.compile_variable(value),
            },
            gpu::CoopMma::Load {
                value,
                stride,
                offset,
                layout,
            } => WmmaInstruction::Load {
                frag: out,
                offset: self.compile_variable(offset),
                value: self.compile_variable(value),
                stride: self.compile_variable(stride),
                layout: layout.and_then(|l| self.compile_matrix_layout(l)),
            },
            gpu::CoopMma::Execute {
                mat_a,
                mat_b,
                mat_c,
            } => WmmaInstruction::Execute {
                frag_a: self.compile_variable(mat_a),
                frag_b: self.compile_variable(mat_b),
                frag_c: self.compile_variable(mat_c),
                frag_d: out,
                warp_size: self.compilation_options.warp_size,
            },
            gpu::CoopMma::ExecuteManual {
                matrix,
                registers_a,
                registers_b,
                registers_c,
            } => WmmaInstruction::ExecuteManual {
                shape: MmaShape::new(matrix.m, matrix.n, matrix.k),
                frag_a: registers_a
                    .into_iter()
                    .map(|it| self.compile_variable(it))
                    .collect(),
                frag_b: registers_b
                    .into_iter()
                    .map(|it| self.compile_variable(it))
                    .collect(),
                frag_c: registers_c
                    .into_iter()
                    .map(|it| self.compile_variable(it))
                    .collect(),
                frag_d: out,
            },
            gpu::CoopMma::ExecuteScaled {
                matrix,
                registers_a,
                registers_b,
                registers_c,
                scales_a,
                scales_b,
                scales_factor,
            } => WmmaInstruction::ExecuteScaled {
                shape: MmaShape::new(matrix.m, matrix.n, matrix.k),
                frag_a: registers_a
                    .into_iter()
                    .map(|it| self.compile_variable(it))
                    .collect(),
                frag_b: registers_b
                    .into_iter()
                    .map(|it| self.compile_variable(it))
                    .collect(),
                frag_c: registers_c
                    .into_iter()
                    .map(|it| self.compile_variable(it))
                    .collect(),
                frag_d: out,
                scales_a: self.compile_variable(scales_a),
                scales_b: self.compile_variable(scales_b),
                scales_factor,
            },
            gpu::CoopMma::Store {
                mat,
                stride,
                offset,
                layout,
            } => {
                self.flags.indexes.unit_pos = true;
                self.flags.indexes.plane_index = true;
                WmmaInstruction::Store {
                    output: out,
                    offset: self.compile_variable(offset),
                    frag: self.compile_variable(mat),
                    stride: self.compile_variable(stride),
                    layout: self
                        .compile_matrix_layout(layout)
                        .expect("Layout required for store instruction"),
                }
            }
            gpu::CoopMma::Cast { input } => WmmaInstruction::Cast {
                input: self.compile_variable(input),
                output: out,
            },
            gpu::CoopMma::RowIndex { .. } | gpu::CoopMma::ColIndex { .. } => {
                panic!("Row/Col index should be handled by processors")
            }
        };

        D::register_wmma_instruction_extension(&mut self.extensions, &inst);

        Instruction::Wmma(inst)
    }

    fn compile_metadata(
        &mut self,
        metadata: gpu::Metadata,
        out: Option<gpu::Variable>,
    ) -> Instruction<D> {
        let out = out.unwrap();
        match metadata {
            gpu::Metadata::Stride { dim, var } => {
                let position = self.ext_meta_position(var);
                let offset = self.metadata.stride_offset_index(position);
                Instruction::ExtendedMetadata {
                    info_offset: self.compile_variable(offset.into()),
                    dim: self.compile_variable(dim),
                    split_meta: self.compilation_options.supports_features.grid_constants,
                    static_offset: self.metadata.static_len(),
                    out: self.compile_variable(out),
                }
            }
            gpu::Metadata::Shape { dim, var } => {
                let position = self.ext_meta_position(var);
                let offset = self.metadata.shape_offset_index(position);
                Instruction::ExtendedMetadata {
                    info_offset: self.compile_variable(offset.into()),
                    dim: self.compile_variable(dim),
                    split_meta: self.compilation_options.supports_features.grid_constants,
                    static_offset: self.metadata.static_len(),
                    out: self.compile_variable(out),
                }
            }
            gpu::Metadata::Rank { var } => {
                let out = self.compile_variable(out);
                let pos = self.ext_meta_position(var);
                let offset = self.metadata.rank_index(pos);
                super::Instruction::Metadata {
                    info_offset: self.compile_variable(offset.into()),
                    split_meta: self.compilation_options.supports_features.grid_constants,
                    out,
                }
            }
            gpu::Metadata::Length { var } => {
                let input = self.compile_variable(var);
                let out = self.compile_variable(out);

                match input {
                    Variable::Slice { .. } => Instruction::SliceLength { input, out },
                    Variable::SharedMemory(_id, _item, length) => {
                        Instruction::ConstLength { length, out }
                    }
                    _ => {
                        let id = input.id().expect("Variable should have id");
                        let offset = self.metadata.len_index(id);
                        Instruction::Metadata {
                            info_offset: self.compile_variable(offset.into()),
                            split_meta: self.compilation_options.supports_features.grid_constants,
                            out,
                        }
                    }
                }
            }
            gpu::Metadata::BufferLength { var } => {
                let input = self.compile_variable(var);
                let out = self.compile_variable(out);

                match input {
                    Variable::Slice { .. } => Instruction::SliceLength { input, out },
                    _ => {
                        let id = input.id().expect("Variable should have id");
                        let offset = self.metadata.buffer_len_index(id);
                        Instruction::Metadata {
                            info_offset: self.compile_variable(offset.into()),
                            split_meta: self.compilation_options.supports_features.grid_constants,
                            out,
                        }
                    }
                }
            }
        }
    }

    fn compile_branch(&mut self, instructions: &mut Vec<Instruction<D>>, branch: gpu::Branch) {
        match branch {
            gpu::Branch::If(mut op) => instructions.push(Instruction::If {
                cond: self.compile_variable(op.cond),
                instructions: self.compile_scope(&mut op.scope),
            }),
            gpu::Branch::IfElse(mut op) => instructions.push(Instruction::IfElse {
                cond: self.compile_variable(op.cond),
                instructions_if: self.compile_scope(&mut op.scope_if),
                instructions_else: self.compile_scope(&mut op.scope_else),
            }),
            gpu::Branch::Switch(mut op) => instructions.push(Instruction::Switch {
                value: self.compile_variable(op.value),
                instructions_default: self.compile_scope(&mut op.scope_default),
                instructions_cases: op
                    .cases
                    .into_iter()
                    .map(|(val, mut block)| {
                        (self.compile_variable(val), self.compile_scope(&mut block))
                    })
                    .collect(),
            }),
            gpu::Branch::Return => instructions.push(Instruction::Return),
            gpu::Branch::Break => instructions.push(Instruction::Break),
            gpu::Branch::RangeLoop(mut range_loop) => instructions.push(Instruction::RangeLoop {
                i: self.compile_variable(range_loop.i),
                start: self.compile_variable(range_loop.start),
                end: self.compile_variable(range_loop.end),
                step: range_loop.step.map(|it| self.compile_variable(it)),
                inclusive: range_loop.inclusive,
                instructions: self.compile_scope(&mut range_loop.scope),
            }),
            gpu::Branch::Loop(mut op) => instructions.push(Instruction::Loop {
                instructions: self.compile_scope(&mut op.scope),
            }),
        };
    }

    fn compile_atomic(
        &mut self,
        value: gpu::AtomicOp,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::AtomicOp::Load(op) => {
                instructions.push(Instruction::AtomicLoad(self.compile_unary(op, out)))
            }
            gpu::AtomicOp::Store(op) => {
                instructions.push(Instruction::AtomicStore(self.compile_unary(op, out)))
            }
            gpu::AtomicOp::Swap(op) => {
                instructions.push(Instruction::AtomicSwap(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Add(op) => {
                instructions.push(Instruction::AtomicAdd(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Sub(op) => {
                instructions.push(Instruction::AtomicSub(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Max(op) => {
                instructions.push(Instruction::AtomicMax(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Min(op) => {
                instructions.push(Instruction::AtomicMin(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::And(op) => {
                instructions.push(Instruction::AtomicAnd(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Or(op) => {
                instructions.push(Instruction::AtomicOr(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Xor(op) => {
                instructions.push(Instruction::AtomicXor(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::CompareAndSwap(op) => instructions.push(Instruction::AtomicCAS {
                input: self.compile_variable(op.input),
                cmp: self.compile_variable(op.cmp),
                val: self.compile_variable(op.val),
                out: self.compile_variable(out),
            }),
        }
    }

    fn compile_arithmetic(
        &mut self,
        value: gpu::Arithmetic,
        out: Option<gpu::Variable>,
        modes: InstructionModes,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::Arithmetic::Add(op) => {
                instructions.push(Instruction::Add(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::SaturatingAdd(op) => {
                instructions.push(Instruction::SaturatingAdd(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Mul(op) => {
                instructions.push(Instruction::Mul(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Div(op) => {
                let op = self.compile_binary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::AllowReciprocal
                        | FastMath::ReducedPrecision
                        | FastMath::UnsignedZero
                        | FastMath::NotInf,
                    Instruction::Div(op),
                    Instruction::FastDiv(op),
                ))
            }
            gpu::Arithmetic::Sub(op) => {
                instructions.push(Instruction::Sub(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::SaturatingSub(op) => {
                instructions.push(Instruction::SaturatingSub(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::MulHi(op) => {
                let instruction = Instruction::HiMul(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Modulo(op) => {
                instructions.push(Instruction::Modulo(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Abs(op) => {
                instructions.push(Instruction::Abs(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Exp(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Exp(op),
                    Instruction::FastExp(op),
                ));
            }
            gpu::Arithmetic::Log(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Log(op),
                    Instruction::FastLog(op),
                ));
            }
            gpu::Arithmetic::Log1p(op) => {
                instructions.push(Instruction::Log1p(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Cos(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Cos(op),
                    Instruction::FastCos(op),
                ));
            }
            gpu::Arithmetic::Sin(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Sin(op),
                    Instruction::FastSin(op),
                ));
            }
            gpu::Arithmetic::Tanh(op) => {
                let op = self.compile_unary(op, out);
                let instruction = Instruction::Tanh(op);
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    instruction,
                    Instruction::FastTanh(op),
                ))
            }
            gpu::Arithmetic::Powf(op) => {
                let op = self.compile_binary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Powf(op),
                    Instruction::FastPowf(op),
                ))
            }
            gpu::Arithmetic::Powi(op) => {
                instructions.push(Instruction::Powi(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Sqrt(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Sqrt(op),
                    Instruction::FastSqrt(op),
                ))
            }
            gpu::Arithmetic::InverseSqrt(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::InverseSqrt(op),
                    Instruction::FastInverseSqrt(op),
                ))
            }
            gpu::Arithmetic::Erf(op) => {
                let instruction = Instruction::Erf(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Max(op) => {
                let instruction = Instruction::Max(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Min(op) => {
                let instruction = Instruction::Min(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Clamp(op) => instructions.push(Instruction::Clamp {
                input: self.compile_variable(op.input),
                min_value: self.compile_variable(op.min_value),
                max_value: self.compile_variable(op.max_value),
                out: self.compile_variable(out),
            }),
            gpu::Arithmetic::Recip(op) => {
                let elem = op.input.ty.elem_type();
                let input = self.compile_variable(op.input);
                let out = self.compile_variable(out);
                let lhs = match elem {
                    gpu::ElemType::Float(kind) => gpu::ConstantScalarValue::Float(1.0, kind),
                    gpu::ElemType::Int(kind) => gpu::ConstantScalarValue::Int(1, kind),
                    gpu::ElemType::UInt(kind) => gpu::ConstantScalarValue::UInt(1, kind),
                    gpu::ElemType::Bool => gpu::ConstantScalarValue::Bool(true),
                };
                let div = Instruction::Div(BinaryInstruction {
                    lhs: Variable::ConstantScalar(lhs, self.compile_elem(elem)),
                    rhs: input,
                    out,
                });
                let recip = Instruction::FastRecip(UnaryInstruction { input, out });

                instructions.push(self.select_fast_float(
                    elem.into(),
                    modes,
                    FastMath::AllowReciprocal
                        | FastMath::ReducedPrecision
                        | FastMath::UnsignedZero
                        | FastMath::NotInf,
                    div,
                    recip,
                ))
            }
            gpu::Arithmetic::Round(op) => {
                instructions.push(Instruction::Round(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Floor(op) => {
                instructions.push(Instruction::Floor(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Ceil(op) => {
                instructions.push(Instruction::Ceil(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Trunc(op) => {
                instructions.push(Instruction::Trunc(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Remainder(op) => {
                instructions.push(Instruction::Remainder(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Fma(op) => instructions.push(Instruction::Fma {
                a: self.compile_variable(op.a),
                b: self.compile_variable(op.b),
                c: self.compile_variable(op.c),
                out: self.compile_variable(out),
            }),
            gpu::Arithmetic::Neg(op) => {
                instructions.push(Instruction::Neg(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Normalize(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Normalize(op),
                    Instruction::FastNormalize(op),
                ))
            }
            gpu::Arithmetic::Magnitude(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Magnitude(op),
                    Instruction::FastMagnitude(op),
                ))
            }
            gpu::Arithmetic::Dot(op) => {
                instructions.push(Instruction::Dot(self.compile_binary(op, out)))
            }
        };
    }

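    /// Chooses between a precise instruction and its fast-math variant. The
    /// fast path is taken only for `f32`, only when the target supports fast
    /// math, and only when the instruction's `FastMath` modes include every
    /// flag the fast variant requires (e.g. `FastDiv` needs
    /// `AllowReciprocal | ReducedPrecision | UnsignedZero | NotInf`).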
    fn select_fast_float(
        &self,
        ty: gpu::Type,
        modes: InstructionModes,
        required_flags: EnumSet<FastMath>,
        default: Instruction<D>,
        fast: Instruction<D>,
    ) -> Instruction<D> {
        if !self.compilation_options.supports_features.fast_math
            || !matches!(ty.elem_type(), ElemType::Float(FloatKind::F32))
        {
            return default;
        }

        if modes.fp_math_mode.is_superset(required_flags) {
            fast
        } else {
            default
        }
    }

    fn compile_comparison(
        &mut self,
        value: gpu::Comparison,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::Comparison::Equal(op) => {
                instructions.push(Instruction::Equal(self.compile_binary(op, out)))
            }
            gpu::Comparison::Lower(op) => {
                instructions.push(Instruction::Lower(self.compile_binary(op, out)))
            }
            gpu::Comparison::Greater(op) => {
                instructions.push(Instruction::Greater(self.compile_binary(op, out)))
            }
            gpu::Comparison::LowerEqual(op) => {
                instructions.push(Instruction::LowerEqual(self.compile_binary(op, out)))
            }
            gpu::Comparison::GreaterEqual(op) => {
                instructions.push(Instruction::GreaterEqual(self.compile_binary(op, out)))
            }
            gpu::Comparison::NotEqual(op) => {
                instructions.push(Instruction::NotEqual(self.compile_binary(op, out)))
            }
            gpu::Comparison::IsNan(op) => {
                instructions.push(Instruction::IsNan(self.compile_unary(op, out)))
            }
            gpu::Comparison::IsInf(op) => {
                instructions.push(Instruction::IsInf(self.compile_unary(op, out)))
            }
        };
    }

    fn compile_bitwise(
        &mut self,
        value: gpu::Bitwise,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::Bitwise::BitwiseOr(op) => {
                instructions.push(Instruction::BitwiseOr(self.compile_binary(op, out)))
            }
            gpu::Bitwise::BitwiseAnd(op) => {
                instructions.push(Instruction::BitwiseAnd(self.compile_binary(op, out)))
            }
            gpu::Bitwise::BitwiseXor(op) => {
                instructions.push(Instruction::BitwiseXor(self.compile_binary(op, out)))
            }
            gpu::Bitwise::CountOnes(op) => {
                instructions.push(Instruction::CountBits(self.compile_unary(op, out)))
            }
            gpu::Bitwise::ReverseBits(op) => {
                instructions.push(Instruction::ReverseBits(self.compile_unary(op, out)))
            }
            gpu::Bitwise::ShiftLeft(op) => {
                instructions.push(Instruction::ShiftLeft(self.compile_binary(op, out)))
            }
            gpu::Bitwise::ShiftRight(op) => {
                instructions.push(Instruction::ShiftRight(self.compile_binary(op, out)))
            }
            gpu::Bitwise::BitwiseNot(op) => {
                instructions.push(Instruction::BitwiseNot(self.compile_unary(op, out)))
            }
            gpu::Bitwise::LeadingZeros(op) => {
                instructions.push(Instruction::LeadingZeros(self.compile_unary(op, out)))
            }
            gpu::Bitwise::FindFirstSet(op) => {
                let instruction = Instruction::FindFirstSet(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
        };
    }

    fn compile_operator(
        &mut self,
        value: gpu::Operator,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::Operator::Index(op) | gpu::Operator::UncheckedIndex(op) => {
                instructions.push(Instruction::Index(self.compile_index(op, out)));
            }
            gpu::Operator::IndexAssign(op) | gpu::Operator::UncheckedIndexAssign(op) => {
                instructions.push(Instruction::IndexAssign(self.compile_index_assign(op, out)));
            }
            gpu::Operator::And(op) => {
                instructions.push(Instruction::And(self.compile_binary(op, out)))
            }
            gpu::Operator::Or(op) => {
                instructions.push(Instruction::Or(self.compile_binary(op, out)))
            }
            gpu::Operator::Not(op) => {
                instructions.push(Instruction::Not(self.compile_unary(op, out)))
            }
            gpu::Operator::InitLine(op) => instructions.push(Instruction::VecInit {
                inputs: op
                    .inputs
                    .into_iter()
                    .map(|it| self.compile_variable(it))
                    .collect(),
                out: self.compile_variable(out),
            }),
            gpu::Operator::CopyMemory(op) => instructions.push(Instruction::Copy {
                input: self.compile_variable(op.input),
                in_index: self.compile_variable(op.in_index),
                out: self.compile_variable(out),
                out_index: self.compile_variable(op.out_index),
            }),
            gpu::Operator::CopyMemoryBulk(op) => instructions.push(Instruction::CopyBulk {
                input: self.compile_variable(op.input),
                in_index: self.compile_variable(op.in_index),
                out: self.compile_variable(out),
                out_index: self.compile_variable(op.out_index),
                len: op.len,
            }),
            gpu::Operator::Select(op) => instructions.push(Instruction::Select {
                cond: self.compile_variable(op.cond),
                then: self.compile_variable(op.then),
                or_else: self.compile_variable(op.or_else),
                out: self.compile_variable(out),
            }),
            gpu::Operator::Cast(op)
                if is_fp4_fp6_fp8(op.input.elem_type()) || is_fp4_fp6_fp8(out.elem_type()) =>
            {
                self.flags.elem_f16 = true;
                self.flags.elem_bf16 = true;
                let vec_in = op.input.ty.line_size();
                let packing = out.storage_type().packing_factor();
                self.compile_type(op.input.ty.line(packing));
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::F16)).line(vec_in),
                );
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::BF16)).line(vec_in),
                );
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::F16)).line(packing),
                );
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::BF16)).line(packing),
                );

                let inst = self.compile_unary(op, out);

                instructions.push(Instruction::SpecialCast(inst));
            }
            gpu::Operator::Cast(op) => {
                let op = self.compile_unary(op, out);

                if op.input.elem() == Elem::TF32 || op.out.elem() == Elem::TF32 {
                    self.flags.elem_tf32 = true;
                }

                instructions.push(Instruction::Assign(op))
            }
            gpu::Operator::Reinterpret(op) => {
                instructions.push(Instruction::Bitcast(self.compile_unary(op, out)))
            }
        };
    }

    fn compile_binary(
        &mut self,
        value: gpu::BinaryOperator,
        out: gpu::Variable,
    ) -> BinaryInstruction<D> {
        BinaryInstruction {
            lhs: self.compile_variable(value.lhs),
            rhs: self.compile_variable(value.rhs),
            out: self.compile_variable(out),
        }
    }

    fn compile_index_assign(
        &mut self,
        value: gpu::IndexAssignOperator,
        out: gpu::Variable,
    ) -> IndexAssignInstruction<D> {
        IndexAssignInstruction {
            index: self.compile_variable(value.index),
            value: self.compile_variable(value.value),
            line_size: value.line_size,
            out: self.compile_variable(out),
        }
    }

    fn compile_index(
        &mut self,
        value: gpu::IndexOperator,
        out: gpu::Variable,
    ) -> IndexInstruction<D> {
        IndexInstruction {
            list: self.compile_variable(value.list),
            index: self.compile_variable(value.index),
            line_size: value.line_size,
            out: self.compile_variable(out),
        }
    }

    fn compile_unary(
        &mut self,
        value: gpu::UnaryOperator,
        out: gpu::Variable,
    ) -> UnaryInstruction<D> {
        UnaryInstruction {
            input: self.compile_variable(value.input),
            out: self.compile_variable(out),
        }
    }

    fn compile_variable(&mut self, value: gpu::Variable) -> Variable<D> {
        let item = value.ty;
        match value.kind {
            gpu::VariableKind::GlobalInputArray(id) => {
                Variable::GlobalInputArray(id, self.compile_type(item))
            }
            gpu::VariableKind::GlobalScalar(id) => Variable::GlobalScalar {
                id,
                elem: self.compile_storage_type(item.storage_type()),
                in_struct: self.compilation_options.supports_features.grid_constants,
            },
            gpu::VariableKind::TensorMapInput(id) => {
                self.flags.inst_tma = true;
                Variable::TensorMap(id)
            }
            gpu::VariableKind::TensorMapOutput(id) => {
                self.flags.inst_tma = true;
                Variable::TensorMap(id)
            }
            gpu::VariableKind::LocalMut { id } => Variable::LocalMut {
                id,
                item: self.compile_type(item),
            },
            gpu::VariableKind::Versioned { id, .. } => Variable::LocalMut {
                id,
                item: self.compile_type(item),
            },
            gpu::VariableKind::LocalConst { id } => Variable::LocalConst {
                id,
                item: self.compile_type(item),
            },
            gpu::VariableKind::GlobalOutputArray(id) => {
                Variable::GlobalOutputArray(id, self.compile_type(item))
            }
            gpu::VariableKind::ConstantScalar(value) => {
                Variable::ConstantScalar(value, self.compile_elem(value.elem_type()))
            }
            gpu::VariableKind::SharedMemory { id, length, .. } => {
                let item = self.compile_type(item);
                Variable::SharedMemory(id, item, length)
            }
            gpu::VariableKind::ConstantArray {
                id,
                length,
                unroll_factor,
            } => {
                let item = self.compile_type(item);
                Variable::ConstantArray(id, item, length * unroll_factor)
            }
            gpu::VariableKind::Builtin(builtin) => match builtin {
                gpu::Builtin::AbsolutePos => {
                    self.flags.indexes.absolute_pos = true;
                    Variable::AbsolutePos
                }
                gpu::Builtin::CubePosCluster
                    if self.compilation_options.supports_features.clusters =>
                {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterRank
                }
                gpu::Builtin::CubePosClusterX
                    if self.compilation_options.supports_features.clusters =>
                {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexX
                }
                gpu::Builtin::CubePosClusterY
                    if self.compilation_options.supports_features.clusters =>
                {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexY
                }
                gpu::Builtin::CubePosClusterZ
                    if self.compilation_options.supports_features.clusters =>
                {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexZ
                }
                gpu::Builtin::CubePosCluster
                | gpu::Builtin::CubePosClusterX
                | gpu::Builtin::CubePosClusterY
                | gpu::Builtin::CubePosClusterZ => const_u32(0),
                gpu::Builtin::AbsolutePosX => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosX
                }
                gpu::Builtin::AbsolutePosY => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosY
                }
                gpu::Builtin::AbsolutePosZ => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosZ
                }
                gpu::Builtin::CubeDim => {
                    self.flags.indexes.cube_dim = true;
                    Variable::CubeDim
                }
                gpu::Builtin::CubeDimX => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimX
                }
                gpu::Builtin::CubeDimY => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimY
                }
                gpu::Builtin::CubeDimZ => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimZ
                }
                gpu::Builtin::CubeClusterDim => const_u32(self.cluster_dim.num_elems()),
                gpu::Builtin::CubeClusterDimX => const_u32(self.cluster_dim.x),
                gpu::Builtin::CubeClusterDimY => const_u32(self.cluster_dim.y),
                gpu::Builtin::CubeClusterDimZ => const_u32(self.cluster_dim.z),
                gpu::Builtin::CubePos => {
                    self.flags.indexes.cube_pos = true;
                    Variable::CubePos
                }
                gpu::Builtin::CubePosX => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosX
                }
                gpu::Builtin::CubePosY => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosY
                }
                gpu::Builtin::CubePosZ => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosZ
                }
                gpu::Builtin::CubeCount => {
                    self.flags.indexes.cube_count = true;
                    Variable::CubeCount
                }
                gpu::Builtin::CubeCountX => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountX
                }
                gpu::Builtin::CubeCountY => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountY
                }
                gpu::Builtin::CubeCountZ => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountZ
                }
                gpu::Builtin::UnitPos => {
                    self.flags.indexes.unit_pos = true;
                    Variable::UnitPos
                }
                gpu::Builtin::UnitPosX => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosX
                }
                gpu::Builtin::UnitPosY => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosY
                }
                gpu::Builtin::UnitPosZ => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosZ
                }
                gpu::Builtin::PlaneDim => {
                    self.flags.indexes.plane_dim = true;
                    Variable::PlaneDim
                }
                gpu::Builtin::UnitPosPlane => {
                    self.flags.indexes.unit_pos_plane = true;
                    Variable::UnitPosPlane
                }
            },
            gpu::VariableKind::LocalArray {
                id,
                length,
                unroll_factor,
            } => {
                let item = self.compile_type(item);
                if !self.local_arrays.iter().any(|s| s.index == id) {
                    self.local_arrays
                        .push(LocalArray::new(id, item, length * unroll_factor));
                }
                Variable::LocalArray(id, item, length)
            }
            gpu::VariableKind::Matrix { id, mat } => {
                self.flags.inst_wmma = true;
                Variable::WmmaFragment {
                    id,
                    frag: self.compile_matrix(mat),
                }
            }
            gpu::VariableKind::Pipeline { id, num_stages } => {
                self.flags.op_pipeline = true;
                let pipeline = Variable::Pipeline { id };
                if !self.pipelines.iter().any(|s| s.pipeline_id() == id) {
                    self.pipelines.push(PipelineOps::Init {
                        pipeline,
                        num_stages,
                    });
                }
                pipeline
            }
            gpu::VariableKind::Barrier { id, level } => {
                self.flags.op_barrier = true;
                Variable::Barrier { id, level }
            }
            gpu::VariableKind::BarrierToken { id, level } => {
                self.flags.op_barrier = true;
                Variable::BarrierToken { id, level }
            }
        }
    }

    fn compile_matrix(&mut self, matrix: gpu::Matrix) -> Fragment<D> {
        Fragment {
            ident: self.compile_matrix_ident(matrix.ident),
            m: matrix.m,
            n: matrix.n,
            k: matrix.k,
            elem: self.compile_storage_type(matrix.storage),
            layout: self.compile_matrix_layout(matrix.layout),
        }
    }

    fn compile_matrix_ident(&mut self, ident: gpu::MatrixIdent) -> FragmentIdent<D> {
        match ident {
            gpu::MatrixIdent::A => FragmentIdent::A,
            gpu::MatrixIdent::B => FragmentIdent::B,
            gpu::MatrixIdent::Accumulator => FragmentIdent::Accumulator,
        }
    }

    fn compile_matrix_layout(&mut self, layout: gpu::MatrixLayout) -> Option<FragmentLayout<D>> {
        match layout {
            gpu::MatrixLayout::ColMajor => Some(FragmentLayout::ColMajor),
            gpu::MatrixLayout::RowMajor => Some(FragmentLayout::RowMajor),
            gpu::MatrixLayout::Undefined => None,
        }
    }

    fn compile_binding(&mut self, binding: cubecl_core::compute::Binding) -> Binding<D> {
        Binding {
            id: binding.id,
            item: self.compile_type(binding.ty),
            location: binding.location,
            size: binding.size,
            vis: binding.visibility,
        }
    }

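    /// Lowers an IR type to an item and registers it (plus its optimized
    /// form) so the matching type definitions are emitted. `tf32` items are
    /// registered as `f32`, which shares the same in-register representation.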
    fn compile_type(&mut self, ty: gpu::Type) -> Item<D> {
        let item = match ty {
            gpu::Type::Scalar(ty) => Item::new(self.compile_storage_type(ty), 1, false),
            gpu::Type::Line(ty, line_size) => {
                Item::new(self.compile_storage_type(ty), line_size as usize, false)
            }
            gpu::Type::Semantic(_) => Item::new(Elem::Bool, 1, true),
        };
        if item.elem != super::Elem::TF32 {
            self.items.insert(item);
            self.items.insert(item.optimized());
        } else {
            let mut item = item;
            item.elem = super::Elem::F32;
            self.items.insert(item);
        }

        item
    }

    fn compile_storage_type(&mut self, value: gpu::StorageType) -> Elem<D> {
        match value {
            gpu::StorageType::Scalar(ty) => self.compile_elem(ty),
            gpu::StorageType::Atomic(ty) => Elem::Atomic(ty.into()),
            gpu::StorageType::Packed(gpu::ElemType::Float(kind), 2) => match kind {
                FloatKind::E2M1 => {
                    self.flags.elem_fp4 = true;
                    Elem::FP4x2(FP4Kind::E2M1)
                }
                FloatKind::E2M3 => {
                    self.flags.elem_fp6 = true;
                    Elem::FP6x2(FP6Kind::E2M3)
                }
                FloatKind::E3M2 => {
                    self.flags.elem_fp6 = true;
                    // Two fp6 values packed per element, matching the other
                    // `Packed(_, 2)` arms.
                    Elem::FP6x2(FP6Kind::E3M2)
                }
                FloatKind::E4M3 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8x2(FP8Kind::E4M3)
                }
                FloatKind::E5M2 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8x2(FP8Kind::E5M2)
                }
                FloatKind::UE8M0 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8x2(FP8Kind::UE8M0)
                }
                FloatKind::F16 => {
                    self.flags.elem_f16 = true;
                    Elem::F16x2
                }
                FloatKind::BF16 => {
                    self.flags.elem_bf16 = true;
                    Elem::BF16x2
                }
                other => unimplemented!("Unsupported storage type: packed<{other:?}, 2>"),
            },
            other => unimplemented!("Unsupported storage type: {other}"),
        }
    }

    fn compile_elem(&mut self, value: gpu::ElemType) -> Elem<D> {
        match value {
            gpu::ElemType::Float(kind) => match kind {
                gpu::FloatKind::E2M1 => {
                    self.flags.elem_fp4 = true;
                    Elem::FP4(FP4Kind::E2M1)
                }
                gpu::FloatKind::E2M3 => {
                    self.flags.elem_fp6 = true;
                    Elem::FP6(FP6Kind::E2M3)
                }
                gpu::FloatKind::E3M2 => {
                    self.flags.elem_fp6 = true;
                    Elem::FP6(FP6Kind::E3M2)
                }
                gpu::FloatKind::E4M3 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8(FP8Kind::E4M3)
                }
                gpu::FloatKind::E5M2 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8(FP8Kind::E5M2)
                }
                gpu::FloatKind::UE8M0 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8(FP8Kind::UE8M0)
                }
                gpu::FloatKind::F16 => {
                    self.flags.elem_f16 = true;
                    Elem::F16
                }
                gpu::FloatKind::BF16 => {
                    self.flags.elem_bf16 = true;
                    Elem::BF16
                }
                gpu::FloatKind::TF32 => Elem::TF32,
                gpu::FloatKind::Flex32 => Elem::F32,
                gpu::FloatKind::F32 => Elem::F32,
                gpu::FloatKind::F64 => Elem::F64,
            },
            gpu::ElemType::Int(kind) => match kind {
                gpu::IntKind::I8 => Elem::I8,
                gpu::IntKind::I16 => Elem::I16,
                gpu::IntKind::I32 => Elem::I32,
                gpu::IntKind::I64 => Elem::I64,
            },
            gpu::ElemType::UInt(kind) => match kind {
                gpu::UIntKind::U8 => Elem::U8,
                gpu::UIntKind::U16 => Elem::U16,
                gpu::UIntKind::U32 => Elem::U32,
                gpu::UIntKind::U64 => Elem::U64,
            },
            gpu::ElemType::Bool => Elem::Bool,
        }
    }
}

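/// Returns `true` for the narrow float formats (fp4/fp6/fp8) whose casts
/// must go through an `f16`/`bf16` conversion path rather than a plain cast.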
fn is_fp4_fp6_fp8(elem: gpu::ElemType) -> bool {
    match elem {
        gpu::ElemType::Float(kind) => matches!(
            kind,
            FloatKind::E2M1
                | FloatKind::E2M3
                | FloatKind::E3M2
                | FloatKind::E4M3
                | FloatKind::E5M2
                | FloatKind::UE8M0
        ),
        _ => false,
    }
}

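/// Builds a `u32` constant-scalar variable, used to materialize builtins
/// that are known at compile time (e.g. cluster dimensions).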
fn const_u32<D: Dialect>(value: u32) -> Variable<D> {
    Variable::ConstantScalar(
        gpu::ConstantScalarValue::UInt(value as u64, UIntKind::U32),
        Elem::U32,
    )
}

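/// Registers the scalar and atomic types every C++-based target is expected
/// to support. A minimal usage sketch (`ignore`d since constructing
/// `DeviceProperties` depends on the runtime):
///
/// ```ignore
/// register_supported_types(&mut props);
/// // f64, fp8 and other target-specific types are registered separately.
/// ```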
pub fn register_supported_types(props: &mut DeviceProperties) {
    let supported_types = [
        gpu::ElemType::UInt(gpu::UIntKind::U8),
        gpu::ElemType::UInt(gpu::UIntKind::U16),
        gpu::ElemType::UInt(gpu::UIntKind::U32),
        gpu::ElemType::UInt(gpu::UIntKind::U64),
        gpu::ElemType::Int(gpu::IntKind::I8),
        gpu::ElemType::Int(gpu::IntKind::I16),
        gpu::ElemType::Int(gpu::IntKind::I32),
        gpu::ElemType::Int(gpu::IntKind::I64),
        gpu::ElemType::Float(gpu::FloatKind::BF16),
        gpu::ElemType::Float(gpu::FloatKind::F16),
        gpu::ElemType::Float(gpu::FloatKind::F32),
        gpu::ElemType::Float(gpu::FloatKind::Flex32),
        gpu::ElemType::Bool,
    ];

    let supported_atomic_types = [
        gpu::ElemType::Int(gpu::IntKind::I32),
        gpu::ElemType::Int(gpu::IntKind::I64),
        gpu::ElemType::UInt(gpu::UIntKind::U32),
        gpu::ElemType::UInt(gpu::UIntKind::U64),
        gpu::ElemType::Float(gpu::FloatKind::F32),
    ];

    for ty in supported_types {
        props.register_type_usage(ty, TypeUsage::all_scalar());
    }

    for ty in supported_atomic_types {
        props.register_type_usage(
            gpu::StorageType::Atomic(ty),
            TypeUsage::AtomicAdd | TypeUsage::AtomicLoadStore,
        );
    }
}