use std::{collections::HashSet, fmt::Debug, num::NonZero};

use cubecl_common::ExecutionMode;
use cubecl_core::CubeDim;
use cubecl_core::ir::{FloatKind, Processor, UIntKind, VariableKind};
use cubecl_core::post_processing::checked_io::CheckedIoProcessor;
use cubecl_core::{
    Compiler, Feature,
    ir::{self as gpu},
};
use cubecl_core::{
    ir::{Operation, SourceLoc},
    prelude::{FastMath, KernelDefinition},
};
use cubecl_runtime::DeviceProperties;

use super::{
    AtomicKind, BinaryInstruction, Binding, Body, Component, ComputeKernel, ConstArray, Dialect,
    Elem, FP6Kind, Fragment, FragmentIdent, FragmentLayout, IndexAssignInstruction,
    IndexInstruction, Instruction, Item, LocalArray, SharedMemory, UnaryInstruction, Variable,
    WarpInstruction, WmmaInstruction,
};
use super::{FP4Kind, barrier::BarrierOps};
use super::{FP8Kind, pipeline::PipelineOps};

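/// Counter for temporary variables created during compilation; reset to zero at the end of
/// [`Compiler::compile`].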
pub(super) static COUNTER_TMP_VAR: std::sync::atomic::AtomicU32 =
    std::sync::atomic::AtomicU32::new(0);

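/// Target-specific options that control code generation: the plane (warp) size, whether grid
/// constants are used, and whether thread block clusters are supported.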
#[derive(Clone, Debug)]
pub struct CompilationOptions {
    pub warp_size: u32,
    pub grid_constants: bool,
    pub supports_clusters: bool,
}

impl Default for CompilationOptions {
    fn default() -> Self {
        Self {
            warp_size: 32,
            grid_constants: false,
            supports_clusters: false,
        }
    }
}

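/// Tracks which builtin index variables a kernel actually uses, so only the needed
/// declarations are emitted.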
#[derive(Debug, Clone, Default)]
pub struct CubeIndexFlags {
    pub absolute_pos: bool,
    pub absolute_pos_tuple: bool,
    pub cube_count: bool,
    pub cube_count_tuple: bool,
    pub cube_dim: bool,
    pub cube_dim_tuple: bool,
    pub cube_pos: bool,
    pub cube_pos_tuple: bool,
    pub plane_dim: bool,
    pub plane_dim_checked: bool,
    pub plane_index: bool,
    pub unit_pos: bool,
    pub unit_pos_tuple: bool,
    pub unit_pos_plane: bool,
    pub cluster_pos: bool,
}

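/// Features detected while compiling a kernel: element types in use, special instruction sets
/// (WMMA, TMA, pipelines, barriers), and metadata layout information.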
#[derive(Debug, Clone, Default)]
pub struct Flags {
    pub elem_fp4: bool,
    pub elem_fp6: bool,
    pub elem_fp8: bool,
    pub elem_bf16: bool,
    pub elem_f16: bool,
    pub indexes: CubeIndexFlags,
    pub op_barrier: bool,
    pub op_pipeline: bool,
    pub inst_fast_math: bool,
    pub inst_tma: bool,
    pub inst_tma_im2col: bool,
    pub inst_wmma: bool,
    pub use_grid_constants: bool,
    pub static_meta_length: usize,
    pub has_dynamic_meta: bool,
    pub cluster_dim: Option<CubeDim>,
}

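/// Lowers `cubecl` IR into a C++ [`ComputeKernel`] for the dialect `D`, collecting the feature
/// [`Flags`] and resources (shared memories, pipelines, barriers) along the way.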
#[allow(clippy::too_many_arguments)]
#[derive(Clone, Debug, Default)]
pub struct CppCompiler<D: Dialect> {
    barriers: Vec<BarrierOps<D>>,
    compilation_options: CompilationOptions,
    const_arrays: Vec<ConstArray<D>>,
    ext_meta_positions: Vec<u32>,
    cluster_dim: CubeDim,
    extensions: Vec<D::Extension>,
    flags: Flags,
    items: HashSet<Item<D>>,
    local_arrays: Vec<LocalArray<D>>,
    metadata: cubecl_core::Metadata,
    pipelines: Vec<PipelineOps<D>>,
    shared_memories: Vec<SharedMemory<D>>,
    source_loc: Option<SourceLoc>,
    strategy: ExecutionMode,
}

impl<D: Dialect> Compiler for CppCompiler<D> {
    type Representation = ComputeKernel<D>;
    type CompilationOptions = CompilationOptions;

    fn compile(
        &mut self,
        mut kernel: KernelDefinition,
        compilation_options: &Self::CompilationOptions,
        strategy: ExecutionMode,
    ) -> Self::Representation {
        self.compilation_options = compilation_options.clone();
        self.strategy = strategy;

        if !self.compilation_options.supports_clusters {
            kernel.options.cluster_dim = None;
        }
        self.cluster_dim = kernel.options.cluster_dim.unwrap_or(CubeDim::new_single());

        let ir = self.clone().compile_ir(kernel);
        COUNTER_TMP_VAR.store(0, std::sync::atomic::Ordering::Relaxed);
        ir
    }

    fn elem_size(&self, elem: gpu::Elem) -> usize {
        elem.size()
    }

    fn extension(&self) -> &'static str {
        "cpp"
    }
}

impl<D: Dialect> CppCompiler<D> {
    fn compile_ir(mut self, mut value: KernelDefinition) -> ComputeKernel<D> {
        self.build_metadata(&value);

        let instructions = self.compile_scope(&mut value.body);
        let buffers = value
            .buffers
            .into_iter()
            .map(|b| self.compile_binding(b))
            .collect();
        let scalars = value
            .scalars
            .into_iter()
            .map(|binding| (self.compile_elem(binding.elem), binding.count))
            .collect();

        let flags = Flags {
            indexes: D::builtin_rules(&self.flags.indexes),
            inst_wmma: self.flags.inst_wmma,
            op_pipeline: self.flags.op_pipeline,
            op_barrier: self.flags.op_barrier,
            elem_fp4: self.flags.elem_fp4,
            elem_fp6: self.flags.elem_fp6,
            elem_fp8: self.flags.elem_fp8,
            elem_bf16: self.flags.elem_bf16,
            elem_f16: self.flags.elem_f16,
            inst_fast_math: value
                .options
                .fp_math_mode
                .contains(FastMath::ReducedPrecision),
            inst_tma: self.flags.inst_tma,
            inst_tma_im2col: self.flags.inst_tma_im2col,
            use_grid_constants: self.compilation_options.grid_constants,
            has_dynamic_meta: self.metadata.static_len() > 0,
            static_meta_length: self.metadata.static_len() as usize,
            cluster_dim: value.options.cluster_dim,
        };

        let body = Body {
            instructions,
            shared_memories: self.shared_memories,
            pipelines: self.pipelines,
            barriers: self.barriers,
            const_arrays: self.const_arrays,
            local_arrays: self.local_arrays,
        };

        let mut cluster_dim = value.options.cluster_dim;
        if !self.compilation_options.supports_clusters {
            cluster_dim = None;
        }

        ComputeKernel {
            tensor_maps: value.tensor_maps,
            buffers,
            scalars,
            meta_static_len: self.metadata.static_len() as usize,
            cube_dim: value.cube_dim,
            body,
            extensions: self.extensions,
            flags,
            items: self.items,
            kernel_name: value.options.kernel_name,
            cluster_dim,
        }
    }

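    /// Builds the metadata layout: assigns every buffer and tensor map (sorted by id) its
    /// position in the extended metadata section. Tensor maps always carry extended metadata.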
    fn build_metadata(&mut self, value: &KernelDefinition) {
        let mut num_ext = 0;

        let mut all_meta: Vec<_> = value
            .buffers
            .iter()
            .map(|buf| (buf.id, buf.has_extended_meta))
            .chain(value.tensor_maps.iter().map(|i| (*i, true)))
            .collect();

        all_meta.sort_by_key(|(id, _)| *id);

        for (_, has_extended_meta) in &all_meta {
            self.ext_meta_positions.push(num_ext);
            if *has_extended_meta {
                num_ext += 1;
            }
        }

        let num_meta = all_meta.len();

        self.metadata = cubecl_core::Metadata::new(num_meta as u32, num_ext);
    }

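    /// Returns the position of a variable in the extended metadata section, as assigned by
    /// [`Self::build_metadata`].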
    pub(crate) fn ext_meta_position(&self, var: gpu::Variable) -> u32 {
        let id = var.index().expect("Variable should have index");
        self.ext_meta_positions[id as usize]
    }

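    /// Compiles a scope into instructions: hoists its constant arrays, applies the checked-IO
    /// post-processing, declares the scope's variables, then lowers each instruction in order.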
    fn compile_scope(&mut self, scope: &mut gpu::Scope) -> Vec<Instruction<D>> {
        let mut instructions = Vec::new();

        let const_arrays = scope
            .const_arrays
            .drain(..)
            .map(|(var, values)| ConstArray {
                index: var.index().unwrap(),
                item: self.compile_item(var.item),
                size: values.len() as u32,
                values: values
                    .into_iter()
                    .map(|val| self.compile_variable(val))
                    .collect(),
            })
            .collect::<Vec<_>>();
        self.const_arrays.extend(const_arrays);

        let checked_io: Box<dyn Processor> = Box::new(CheckedIoProcessor::new(self.strategy));
        let processing = scope.process([checked_io]);

        for var in processing.variables {
            instructions.push(Instruction::DeclareVariable {
                var: self.compile_variable(var),
            });
        }

        processing
            .instructions
            .into_iter()
            .for_each(|op| self.compile_instruction(&mut instructions, op));

        instructions
    }

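    /// Lowers a single IR instruction, dispatching on the operation kind and recording any
    /// feature flags (TMA, plane checks, ...) that the emitted instructions depend on.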
    fn compile_instruction(
        &mut self,
        instructions: &mut Vec<Instruction<D>>,
        instruction: gpu::Instruction,
    ) {
        self.update_debug_loc(instructions, &instruction);
        let out = instruction.out;
        match instruction.operation {
            gpu::Operation::Copy(variable) => {
                instructions.push(Instruction::Assign(UnaryInstruction {
                    input: self.compile_variable(variable),
                    out: self.compile_variable(out.unwrap()),
                }));
            }
            gpu::Operation::Arithmetic(op) => self.compile_arithmetic(op, out, instructions),
            gpu::Operation::Comparison(op) => self.compile_comparison(op, out, instructions),
            gpu::Operation::Bitwise(op) => self.compile_bitwise(op, out, instructions),
            gpu::Operation::Operator(op) => self.compile_operator(op, out, instructions),
            gpu::Operation::Atomic(op) => self.compile_atomic(op, out, instructions),
            gpu::Operation::Metadata(op) => instructions.push(self.compile_metadata(op, out)),
            gpu::Operation::Branch(val) => self.compile_branch(instructions, val),
            gpu::Operation::Synchronization(val) => match val {
                gpu::Synchronization::SyncCube => instructions.push(Instruction::SyncThreads),
                gpu::Synchronization::SyncPlane => instructions.push(Instruction::SyncWarp),
                gpu::Synchronization::SyncStorage => instructions.push(Instruction::SyncThreads),
                gpu::Synchronization::SyncProxyShared => {
                    self.flags.inst_tma = true;
                    instructions.push(Instruction::ProxySharedFence)
                }
            },
            gpu::Operation::Plane(op) => {
                self.flags.indexes.plane_dim_checked = true;
                let out = self.compile_variable(out.unwrap());
                match op {
                    gpu::Plane::Sum(op) => {
                        let instruction = WarpInstruction::ReduceSum {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction));
                    }
                    gpu::Plane::InclusiveSum(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::InclusiveSum {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::InclusiveProd(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::InclusiveProd {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::ExclusiveSum(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::ExclusiveSum {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::ExclusiveProd(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::ExclusiveProd {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Prod(op) => {
                        let instruction = WarpInstruction::ReduceProd {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Max(op) => {
                        let instruction = WarpInstruction::ReduceMax {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Min(op) => {
                        let instruction = WarpInstruction::ReduceMin {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Elect => {
                        instructions.push(Instruction::Warp(WarpInstruction::Elect { out }))
                    }
                    gpu::Plane::All(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::All {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Any(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Any {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Ballot(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Ballot {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Broadcast(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Broadcast {
                            input: self.compile_variable(op.lhs),
                            id: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                }
            }
            gpu::Operation::CoopMma(cmma) => instructions.push(self.compile_cmma(cmma, out)),
            gpu::Operation::NonSemantic(debug) => match debug {
                gpu::NonSemantic::Print {
                    format_string,
                    args,
                } => instructions.push(Instruction::Printf {
                    format_string,
                    args: args
                        .into_iter()
                        .map(|arg| self.compile_variable(arg))
                        .collect(),
                }),
                gpu::NonSemantic::Comment { content } => {
                    instructions.push(Instruction::Comment { content })
                }
                _ => {}
            },
            gpu::Operation::Barrier(barrier_ops) => match barrier_ops {
                gpu::BarrierOps::Init {
                    barrier,
                    with_cta_fence,
                } => {
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    let barrier = self.compile_variable(barrier);
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Init {
                        barrier,
                        level,
                        with_cta_fence,
                    }));
                }
                gpu::BarrierOps::MemCopyAsync {
                    barrier,
                    source,
                    source_length,
                    offset_source,
                    offset_out,
                } => {
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsync {
                            barrier: self.compile_variable(barrier),
                            source: self.compile_variable(source),
                            destination: self.compile_variable(out.unwrap()),
                            source_length: self.compile_variable(source_length),
                            offset_source: self.compile_variable(offset_source),
                            offset_out: self.compile_variable(offset_out),
                            level,
                        },
                    ));
                }
                gpu::BarrierOps::TmaLoad {
                    barrier,
                    tensor_map,
                    offset_out,
                    indices,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsyncTensorGlobalToShared {
                            barrier: self.compile_variable(barrier),
                            smem_buffer: self.compile_variable(out.unwrap()),
                            smem_offset: self.compile_variable(offset_out),
                            tensor_map: self.compile_variable(tensor_map),
                            indices: indices
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        },
                    ));
                }
                gpu::BarrierOps::TmaLoadIm2col {
                    barrier,
                    tensor_map,
                    offset_out,
                    indices,
                    offsets,
                } => {
                    self.flags.inst_tma_im2col = true;
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::TmaLoadIm2col {
                            barrier: self.compile_variable(barrier),
                            smem_buffer: self.compile_variable(out.unwrap()),
                            smem_offset: self.compile_variable(offset_out),
                            tensor_map: self.compile_variable(tensor_map),
                            indices: indices
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                            offsets: offsets
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        },
                    ));
                }
                gpu::BarrierOps::Arrive { barrier } => {
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Arrive {
                        barrier: self.compile_variable(barrier),
                        level,
                    }))
                }
                gpu::BarrierOps::ArriveTx {
                    barrier,
                    arrive_count_update,
                    transaction_count_update,
                } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::ArriveTx {
                        barrier: self.compile_variable(barrier),
                        arrive_count_update: self.compile_variable(arrive_count_update),
                        transaction_count_update: self.compile_variable(transaction_count_update),
                    }))
                }
                gpu::BarrierOps::ExpectTx {
                    barrier,
                    transaction_count_update,
                } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::ExpectTx {
                        barrier: self.compile_variable(barrier),
                        transaction_count_update: self.compile_variable(transaction_count_update),
                    }))
                }
                gpu::BarrierOps::Wait { barrier } => {
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Wait {
                        barrier: self.compile_variable(barrier),
                        level,
                    }))
                }
                gpu::BarrierOps::ArriveAndWait { barrier } => {
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::ArriveAndWait {
                            barrier: self.compile_variable(barrier),
                            level,
                        },
                    ))
                }
            },
            gpu::Operation::Tma(tma_ops) => {
                self.flags.inst_tma = true;
                match tma_ops {
                    gpu::TmaOps::TmaStore {
                        source,
                        coordinates,
                        offset_source,
                    } => {
                        instructions.push(Instruction::MemCopyAsyncTensorSharedToGlobal {
                            smem_buffer: self.compile_variable(source),
                            smem_offset: self.compile_variable(offset_source),
                            tensor_map: self.compile_variable(out.unwrap()),
                            indices: coordinates
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        });
                    }
                    gpu::TmaOps::CommitGroup => {
                        instructions.push(Instruction::BulkCommitGroup);
                    }
                    gpu::TmaOps::WaitGroup { max_pending } => {
                        instructions.push(Instruction::BulkWaitGroup { max_pending });
                    }
                    gpu::TmaOps::WaitGroupRead { max_pending } => {
                        instructions.push(Instruction::BulkWaitGroupRead { max_pending });
                    }
                }
            }
        }
    }

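    /// Pushes a `Line` marker whenever the source location changes, so the generated code can
    /// be traced back to the kernel source; non-semantic instructions never update the location.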
    fn update_debug_loc(
        &mut self,
        instructions: &mut Vec<Instruction<D>>,
        inst: &gpu::Instruction,
    ) {
        if !matches!(inst.operation, Operation::NonSemantic(_)) {
            match &inst.source_loc {
                Some(loc) if Some(loc) != self.source_loc.as_ref() => {
                    self.source_loc = Some(loc.clone());
                    instructions.push(Instruction::Line {
                        file: loc.source.file.clone(),
                        line: loc.line,
                    });
                }
                _ => {}
            }
        }
    }

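    /// Lowers a cooperative matrix (CMMA) operation to the corresponding WMMA instruction.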
    fn compile_cmma(&mut self, cmma: gpu::CoopMma, out: Option<gpu::Variable>) -> Instruction<D> {
        let out = self.compile_variable(out.unwrap());
        match cmma {
            gpu::CoopMma::Fill { value } => Instruction::Wmma(WmmaInstruction::Fill {
                frag: out,
                value: self.compile_variable(value),
            }),
            gpu::CoopMma::Load {
                value,
                stride,
                offset,
                layout,
            } => Instruction::Wmma(WmmaInstruction::Load {
                frag: out,
                offset: self.compile_variable(offset),
                value: self.compile_variable(value),
                stride: self.compile_variable(stride),
                layout: layout.and_then(|l| self.compile_matrix_layout(l)),
            }),
            gpu::CoopMma::Execute {
                mat_a,
                mat_b,
                mat_c,
            } => Instruction::Wmma(WmmaInstruction::Execute {
                frag_a: self.compile_variable(mat_a),
                frag_b: self.compile_variable(mat_b),
                frag_c: self.compile_variable(mat_c),
                frag_d: out,
                warp_size: self.compilation_options.warp_size,
            }),
            gpu::CoopMma::Store {
                mat,
                stride,
                offset,
                layout,
            } => {
                self.flags.indexes.unit_pos = true;
                self.flags.indexes.plane_index = true;
                Instruction::Wmma(WmmaInstruction::Store {
                    output: out,
                    offset: self.compile_variable(offset),
                    frag: self.compile_variable(mat),
                    stride: self.compile_variable(stride),
                    layout: self
                        .compile_matrix_layout(layout)
                        .expect("Layout required for store instruction"),
                })
            }
            gpu::CoopMma::Cast { input } => Instruction::Wmma(WmmaInstruction::Cast {
                input: self.compile_variable(input),
                output: out,
            }),
        }
    }

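    /// Lowers metadata queries (stride, shape, rank, length, buffer length) to reads at the
    /// right offset of the metadata; slice and shared-memory lengths are resolved directly.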
    fn compile_metadata(
        &mut self,
        metadata: gpu::Metadata,
        out: Option<gpu::Variable>,
    ) -> Instruction<D> {
        let out = out.unwrap();
        match metadata {
            gpu::Metadata::Stride { dim, var } => {
                let position = self.ext_meta_position(var);
                let offset = self.metadata.stride_offset_index(position);
                Instruction::ExtendedMetadata {
                    info_offset: self.compile_variable(offset.into()),
                    dim: self.compile_variable(dim),
                    split_meta: self.compilation_options.grid_constants,
                    static_offset: self.metadata.static_len(),
                    out: self.compile_variable(out),
                }
            }
            gpu::Metadata::Shape { dim, var } => {
                let position = self.ext_meta_position(var);
                let offset = self.metadata.shape_offset_index(position);
                Instruction::ExtendedMetadata {
                    info_offset: self.compile_variable(offset.into()),
                    dim: self.compile_variable(dim),
                    split_meta: self.compilation_options.grid_constants,
                    static_offset: self.metadata.static_len(),
                    out: self.compile_variable(out),
                }
            }
            gpu::Metadata::Rank { var } => {
                let out = self.compile_variable(out);
                let pos = self.ext_meta_position(var);
                let offset = self.metadata.rank_index(pos);
                super::Instruction::Metadata {
                    info_offset: self.compile_variable(offset.into()),
                    split_meta: self.compilation_options.grid_constants,
                    out,
                }
            }
            gpu::Metadata::Length { var } => {
                let input = self.compile_variable(var);
                let out = self.compile_variable(out);

                match input {
                    Variable::Slice { .. } => Instruction::SliceLength { input, out },
                    Variable::SharedMemory(_id, _item, length) => {
                        Instruction::ConstLength { length, out }
                    }
                    _ => {
                        let id = input.id().expect("Variable should have id");
                        let offset = self.metadata.len_index(id);
                        Instruction::Metadata {
                            info_offset: self.compile_variable(offset.into()),
                            split_meta: self.compilation_options.grid_constants,
                            out,
                        }
                    }
                }
            }
            gpu::Metadata::BufferLength { var } => {
                let input = self.compile_variable(var);
                let out = self.compile_variable(out);

                match input {
                    Variable::Slice { .. } => Instruction::SliceLength { input, out },
                    _ => {
                        let id = input.id().expect("Variable should have id");
                        let offset = self.metadata.buffer_len_index(id);
                        Instruction::Metadata {
                            info_offset: self.compile_variable(offset.into()),
                            split_meta: self.compilation_options.grid_constants,
                            out,
                        }
                    }
                }
            }
        }
    }

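    /// Lowers control flow, recursively compiling every nested scope.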
    fn compile_branch(&mut self, instructions: &mut Vec<Instruction<D>>, branch: gpu::Branch) {
        match branch {
            gpu::Branch::If(mut op) => instructions.push(Instruction::If {
                cond: self.compile_variable(op.cond),
                instructions: self.compile_scope(&mut op.scope),
            }),
            gpu::Branch::IfElse(mut op) => instructions.push(Instruction::IfElse {
                cond: self.compile_variable(op.cond),
                instructions_if: self.compile_scope(&mut op.scope_if),
                instructions_else: self.compile_scope(&mut op.scope_else),
            }),
            gpu::Branch::Switch(mut op) => instructions.push(Instruction::Switch {
                value: self.compile_variable(op.value),
                instructions_default: self.compile_scope(&mut op.scope_default),
                instructions_cases: op
                    .cases
                    .into_iter()
                    .map(|(val, mut block)| {
                        (self.compile_variable(val), self.compile_scope(&mut block))
                    })
                    .collect(),
            }),
            gpu::Branch::Return => instructions.push(Instruction::Return),
            gpu::Branch::Break => instructions.push(Instruction::Break),
            gpu::Branch::RangeLoop(mut range_loop) => instructions.push(Instruction::RangeLoop {
                i: self.compile_variable(range_loop.i),
                start: self.compile_variable(range_loop.start),
                end: self.compile_variable(range_loop.end),
                step: range_loop.step.map(|it| self.compile_variable(it)),
                inclusive: range_loop.inclusive,
                instructions: self.compile_scope(&mut range_loop.scope),
            }),
            gpu::Branch::Loop(mut op) => instructions.push(Instruction::Loop {
                instructions: self.compile_scope(&mut op.scope),
            }),
        };
    }

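    /// Lowers atomic operations to their dedicated atomic instructions.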
    fn compile_atomic(
        &mut self,
        value: gpu::AtomicOp,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::AtomicOp::Load(op) => {
                instructions.push(Instruction::AtomicLoad(self.compile_unary(op, out)))
            }
            gpu::AtomicOp::Store(op) => {
                instructions.push(Instruction::AtomicStore(self.compile_unary(op, out)))
            }
            gpu::AtomicOp::Swap(op) => {
                instructions.push(Instruction::AtomicSwap(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Add(op) => {
                instructions.push(Instruction::AtomicAdd(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Sub(op) => {
                instructions.push(Instruction::AtomicSub(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Max(op) => {
                instructions.push(Instruction::AtomicMax(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Min(op) => {
                instructions.push(Instruction::AtomicMin(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::And(op) => {
                instructions.push(Instruction::AtomicAnd(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Or(op) => {
                instructions.push(Instruction::AtomicOr(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Xor(op) => {
                instructions.push(Instruction::AtomicXor(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::CompareAndSwap(op) => instructions.push(Instruction::AtomicCAS {
                input: self.compile_variable(op.input),
                cmp: self.compile_variable(op.cmp),
                val: self.compile_variable(op.val),
                out: self.compile_variable(out),
            }),
        }
    }

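    /// Lowers arithmetic. `Recip(x)` is rewritten as `1 / x`, and instructions that may need
    /// dialect-specific support (`MulHi`, `Tanh`, `Erf`, `Max`, `Min`) register an extension.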
    fn compile_arithmetic(
        &mut self,
        value: gpu::Arithmetic,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::Arithmetic::Add(op) => {
                instructions.push(Instruction::Add(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Mul(op) => {
                instructions.push(Instruction::Mul(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Div(op) => {
                instructions.push(Instruction::Div(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Sub(op) => {
                instructions.push(Instruction::Sub(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::MulHi(op) => {
                let instruction = Instruction::HiMul(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Modulo(op) => {
                instructions.push(Instruction::Modulo(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Abs(op) => {
                instructions.push(Instruction::Abs(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Exp(op) => {
                instructions.push(Instruction::Exp(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Log(op) => {
                instructions.push(Instruction::Log(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Log1p(op) => {
                instructions.push(Instruction::Log1p(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Cos(op) => {
                instructions.push(Instruction::Cos(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Sin(op) => {
                instructions.push(Instruction::Sin(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Tanh(op) => {
                let instruction = Instruction::Tanh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Powf(op) => {
                instructions.push(Instruction::Powf(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Sqrt(op) => {
                instructions.push(Instruction::Sqrt(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Erf(op) => {
                let instruction = Instruction::Erf(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Max(op) => {
                let instruction = Instruction::Max(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Min(op) => {
                let instruction = Instruction::Min(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Clamp(op) => instructions.push(Instruction::Clamp {
                input: self.compile_variable(op.input),
                min_value: self.compile_variable(op.min_value),
                max_value: self.compile_variable(op.max_value),
                out: self.compile_variable(out),
            }),
            gpu::Arithmetic::Recip(op) => {
                let elem = op.input.item.elem();
                let lhs = match elem {
                    gpu::Elem::Float(kind) => gpu::ConstantScalarValue::Float(1.0, kind),
                    gpu::Elem::Int(kind) => gpu::ConstantScalarValue::Int(1, kind),
                    gpu::Elem::UInt(kind) => gpu::ConstantScalarValue::UInt(1, kind),
                    gpu::Elem::Bool => gpu::ConstantScalarValue::Bool(true),
                    gpu::Elem::AtomicInt(_)
                    | gpu::Elem::AtomicUInt(_)
                    | gpu::Elem::AtomicFloat(_) => {
                        panic!("Cannot use recip with atomics")
                    }
                };

                instructions.push(Instruction::Div(BinaryInstruction {
                    lhs: Variable::ConstantScalar(lhs, self.compile_elem(elem)),
                    rhs: self.compile_variable(op.input),
                    out: self.compile_variable(out),
                }))
            }
            gpu::Arithmetic::Round(op) => {
                instructions.push(Instruction::Round(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Floor(op) => {
                instructions.push(Instruction::Floor(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Ceil(op) => {
                instructions.push(Instruction::Ceil(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Remainder(op) => {
                instructions.push(Instruction::Remainder(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Fma(op) => instructions.push(Instruction::Fma {
                a: self.compile_variable(op.a),
                b: self.compile_variable(op.b),
                c: self.compile_variable(op.c),
                out: self.compile_variable(out),
            }),
            gpu::Arithmetic::Neg(op) => {
                instructions.push(Instruction::Neg(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Normalize(op) => {
                instructions.push(Instruction::Normalize(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Magnitude(op) => {
                instructions.push(Instruction::Magnitude(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Dot(op) => {
                instructions.push(Instruction::Dot(self.compile_binary(op, out)))
            }
        };
    }

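    /// Lowers comparisons; each variant maps one-to-one to an instruction.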
    fn compile_comparison(
        &mut self,
        value: gpu::Comparison,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::Comparison::Equal(op) => {
                instructions.push(Instruction::Equal(self.compile_binary(op, out)))
            }
            gpu::Comparison::Lower(op) => {
                instructions.push(Instruction::Lower(self.compile_binary(op, out)))
            }
            gpu::Comparison::Greater(op) => {
                instructions.push(Instruction::Greater(self.compile_binary(op, out)))
            }
            gpu::Comparison::LowerEqual(op) => {
                instructions.push(Instruction::LowerEqual(self.compile_binary(op, out)))
            }
            gpu::Comparison::GreaterEqual(op) => {
                instructions.push(Instruction::GreaterEqual(self.compile_binary(op, out)))
            }
            gpu::Comparison::NotEqual(op) => {
                instructions.push(Instruction::NotEqual(self.compile_binary(op, out)))
            }
        };
    }

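    /// Lowers bitwise operations; `FindFirstSet` may require a dialect extension.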
    fn compile_bitwise(
        &mut self,
        value: gpu::Bitwise,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::Bitwise::BitwiseOr(op) => {
                instructions.push(Instruction::BitwiseOr(self.compile_binary(op, out)))
            }
            gpu::Bitwise::BitwiseAnd(op) => {
                instructions.push(Instruction::BitwiseAnd(self.compile_binary(op, out)))
            }
            gpu::Bitwise::BitwiseXor(op) => {
                instructions.push(Instruction::BitwiseXor(self.compile_binary(op, out)))
            }
            gpu::Bitwise::CountOnes(op) => {
                instructions.push(Instruction::CountBits(self.compile_unary(op, out)))
            }
            gpu::Bitwise::ReverseBits(op) => {
                instructions.push(Instruction::ReverseBits(self.compile_unary(op, out)))
            }
            gpu::Bitwise::ShiftLeft(op) => {
                instructions.push(Instruction::ShiftLeft(self.compile_binary(op, out)))
            }
            gpu::Bitwise::ShiftRight(op) => {
                instructions.push(Instruction::ShiftRight(self.compile_binary(op, out)))
            }
            gpu::Bitwise::BitwiseNot(op) => {
                instructions.push(Instruction::BitwiseNot(self.compile_unary(op, out)))
            }
            gpu::Bitwise::LeadingZeros(op) => {
                instructions.push(Instruction::LeadingZeros(self.compile_unary(op, out)))
            }
            gpu::Bitwise::FindFirstSet(op) => {
                let instruction = Instruction::FindFirstSet(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
        };
    }

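    /// Lowers operators: indexing, logic, casts, copies and vector construction. Casts
    /// involving fp4/fp6/fp8 take a special path that also pulls in the f16 and bf16 types.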
    fn compile_operator(
        &mut self,
        value: gpu::Operator,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::Operator::Index(op) | gpu::Operator::UncheckedIndex(op) => {
                instructions.push(Instruction::Index(self.compile_index(op, out)));
            }
            gpu::Operator::IndexAssign(op) | gpu::Operator::UncheckedIndexAssign(op) => {
                instructions.push(Instruction::IndexAssign(self.compile_index_assign(op, out)));
            }
            gpu::Operator::And(op) => {
                instructions.push(Instruction::And(self.compile_binary(op, out)))
            }
            gpu::Operator::Or(op) => {
                instructions.push(Instruction::Or(self.compile_binary(op, out)))
            }
            gpu::Operator::Not(op) => {
                instructions.push(Instruction::Not(self.compile_unary(op, out)))
            }
            gpu::Operator::InitLine(op) => instructions.push(Instruction::VecInit {
                inputs: op
                    .inputs
                    .into_iter()
                    .map(|it| self.compile_variable(it))
                    .collect(),
                out: self.compile_variable(out),
            }),
            gpu::Operator::CopyMemory(op) => instructions.push(Instruction::Copy {
                input: self.compile_variable(op.input),
                in_index: self.compile_variable(op.in_index),
                out: self.compile_variable(out),
                out_index: self.compile_variable(op.out_index),
            }),
            gpu::Operator::CopyMemoryBulk(op) => instructions.push(Instruction::CopyBulk {
                input: self.compile_variable(op.input),
                in_index: self.compile_variable(op.in_index),
                out: self.compile_variable(out),
                out_index: self.compile_variable(op.out_index),
                len: op.len.as_const().unwrap().as_u32(),
            }),
            gpu::Operator::Select(op) => instructions.push(Instruction::Select {
                cond: self.compile_variable(op.cond),
                then: self.compile_variable(op.then),
                or_else: self.compile_variable(op.or_else),
                out: self.compile_variable(out),
            }),
            gpu::Operator::Cast(op)
                if is_fp4_fp6_fp8(op.input.elem()) || is_fp4_fp6_fp8(out.elem()) =>
            {
                let inst = self.compile_unary(op, out);

                self.flags.elem_f16 = true;
                self.flags.elem_bf16 = true;
                let vec = NonZero::new(inst.input.item().vectorization as u8);
                self.compile_item(gpu::Item::vectorized(gpu::Elem::Float(FloatKind::F16), vec));
                self.compile_item(gpu::Item::vectorized(
                    gpu::Elem::Float(FloatKind::BF16),
                    vec,
                ));

                instructions.push(Instruction::SpecialCast(inst));
            }
            gpu::Operator::Cast(op) => {
                instructions.push(Instruction::Assign(self.compile_unary(op, out)))
            }
            gpu::Operator::Reinterpret(op) => {
                instructions.push(Instruction::Bitcast(self.compile_unary(op, out)))
            }
        };
    }

    fn compile_binary(
        &mut self,
        value: gpu::BinaryOperator,
        out: gpu::Variable,
    ) -> BinaryInstruction<D> {
        BinaryInstruction {
            lhs: self.compile_variable(value.lhs),
            rhs: self.compile_variable(value.rhs),
            out: self.compile_variable(out),
        }
    }

    fn compile_index_assign(
        &mut self,
        value: gpu::IndexAssignOperator,
        out: gpu::Variable,
    ) -> IndexAssignInstruction<D> {
        IndexAssignInstruction {
            index: self.compile_variable(value.index),
            value: self.compile_variable(value.value),
            line_size: value.line_size,
            out: self.compile_variable(out),
        }
    }

    fn compile_index(
        &mut self,
        value: gpu::IndexOperator,
        out: gpu::Variable,
    ) -> IndexInstruction<D> {
        IndexInstruction {
            list: self.compile_variable(value.list),
            index: self.compile_variable(value.index),
            line_size: value.line_size,
            out: self.compile_variable(out),
        }
    }

    fn compile_unary(
        &mut self,
        value: gpu::UnaryOperator,
        out: gpu::Variable,
    ) -> UnaryInstruction<D> {
        UnaryInstruction {
            input: self.compile_variable(value.input),
            out: self.compile_variable(out),
        }
    }

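    /// Lowers an IR variable to a dialect variable. As a side effect this registers shared
    /// memories, local arrays and pipelines, and records which builtin indexes are used.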
    fn compile_variable(&mut self, value: gpu::Variable) -> Variable<D> {
        let item = value.item;
        match value.kind {
            gpu::VariableKind::GlobalInputArray(id) => {
                Variable::GlobalInputArray(id, self.compile_item(item))
            }
            gpu::VariableKind::GlobalScalar(id) => Variable::GlobalScalar {
                id,
                elem: self.compile_elem(item.elem),
                in_struct: self.compilation_options.grid_constants,
            },
            gpu::VariableKind::TensorMap(id) => {
                self.flags.inst_tma = true;
                Variable::TensorMap(id)
            }
            gpu::VariableKind::LocalMut { id } => Variable::LocalMut {
                id,
                item: self.compile_item(item),
            },
            gpu::VariableKind::Versioned { id, .. } => Variable::LocalMut {
                id,
                item: self.compile_item(item),
            },
            gpu::VariableKind::LocalConst { id } => Variable::LocalConst {
                id,
                item: self.compile_item(item),
            },
            gpu::VariableKind::GlobalOutputArray(id) => {
                Variable::GlobalOutputArray(id, self.compile_item(item))
            }
            gpu::VariableKind::ConstantScalar(value) => {
                Variable::ConstantScalar(value, self.compile_elem(value.elem()))
            }
            gpu::VariableKind::SharedMemory {
                id,
                length,
                alignment,
            } => {
                let item = self.compile_item(item);
                if !self.shared_memories.iter().any(|s| s.index == id) {
                    self.shared_memories
                        .push(SharedMemory::new(id, item, length, alignment));
                }
                Variable::SharedMemory(id, item, length)
            }
            gpu::VariableKind::ConstantArray { id, length } => {
                let item = self.compile_item(item);
                Variable::ConstantArray(id, item, length)
            }
            gpu::VariableKind::Builtin(builtin) => match builtin {
                gpu::Builtin::AbsolutePos => {
                    self.flags.indexes.absolute_pos = true;
                    Variable::AbsolutePos
                }
                gpu::Builtin::CubePosCluster if self.compilation_options.supports_clusters => {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterRank
                }
                gpu::Builtin::CubePosClusterX if self.compilation_options.supports_clusters => {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexX
                }
                gpu::Builtin::CubePosClusterY if self.compilation_options.supports_clusters => {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexY
                }
                gpu::Builtin::CubePosClusterZ if self.compilation_options.supports_clusters => {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexZ
                }
                gpu::Builtin::CubePosCluster
                | gpu::Builtin::CubePosClusterX
                | gpu::Builtin::CubePosClusterY
                | gpu::Builtin::CubePosClusterZ => const_u32(0),
                gpu::Builtin::AbsolutePosX => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosX
                }
                gpu::Builtin::AbsolutePosY => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosY
                }
                gpu::Builtin::AbsolutePosZ => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosZ
                }
                gpu::Builtin::CubeDim => {
                    self.flags.indexes.cube_dim = true;
                    Variable::CubeDim
                }
                gpu::Builtin::CubeDimX => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimX
                }
                gpu::Builtin::CubeDimY => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimY
                }
                gpu::Builtin::CubeDimZ => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimZ
                }
                gpu::Builtin::CubeClusterDim => const_u32(self.cluster_dim.num_elems()),
                gpu::Builtin::CubeClusterDimX => const_u32(self.cluster_dim.x),
                gpu::Builtin::CubeClusterDimY => const_u32(self.cluster_dim.y),
                gpu::Builtin::CubeClusterDimZ => const_u32(self.cluster_dim.z),
                gpu::Builtin::CubePos => {
                    self.flags.indexes.cube_pos = true;
                    Variable::CubePos
                }
                gpu::Builtin::CubePosX => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosX
                }
                gpu::Builtin::CubePosY => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosY
                }
                gpu::Builtin::CubePosZ => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosZ
                }
                gpu::Builtin::CubeCount => {
                    self.flags.indexes.cube_count = true;
                    Variable::CubeCount
                }
                gpu::Builtin::CubeCountX => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountX
                }
                gpu::Builtin::CubeCountY => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountY
                }
                gpu::Builtin::CubeCountZ => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountZ
                }
                gpu::Builtin::UnitPos => {
                    self.flags.indexes.unit_pos = true;
                    Variable::UnitPos
                }
                gpu::Builtin::UnitPosX => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosX
                }
                gpu::Builtin::UnitPosY => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosY
                }
                gpu::Builtin::UnitPosZ => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosZ
                }
                gpu::Builtin::PlaneDim => {
                    self.flags.indexes.plane_dim = true;
                    Variable::PlaneDim
                }
                gpu::Builtin::UnitPosPlane => {
                    self.flags.indexes.unit_pos_plane = true;
                    Variable::UnitPosPlane
                }
            },
            gpu::VariableKind::LocalArray { id, length } => {
                let item = self.compile_item(item);
                if !self.local_arrays.iter().any(|s| s.index == id) {
                    self.local_arrays.push(LocalArray::new(id, item, length));
                }
                Variable::LocalArray(id, item, length)
            }
            gpu::VariableKind::Matrix { id, mat } => {
                self.flags.inst_wmma = true;
                Variable::WmmaFragment {
                    id,
                    frag: self.compile_matrix(mat),
                }
            }
            gpu::VariableKind::Pipeline {
                id,
                item,
                num_stages,
            } => {
                self.flags.op_pipeline = true;
                let pipeline = Variable::Pipeline {
                    id,
                    item: self.compile_item(item),
                };
                if !self.pipelines.iter().any(|s| s.pipeline_id() == id) {
                    self.pipelines.push(PipelineOps::Init {
                        pipeline,
                        num_stages,
                    });
                }
                pipeline
            }
            gpu::VariableKind::Barrier { id, item, level } => {
                self.flags.op_barrier = true;
                match level {
                    gpu::BarrierLevel::CubeCoop(_) | gpu::BarrierLevel::CubeManual(_) => {
                        self.flags.indexes.cube_dim = true;
                        self.flags.indexes.unit_pos = true;
                    }
                    _ => {}
                }
                Variable::Barrier {
                    id,
                    item: self.compile_item(item),
                    level,
                }
            }
        }
    }

    fn compile_matrix(&mut self, matrix: gpu::Matrix) -> Fragment<D> {
        Fragment {
            ident: self.compile_matrix_ident(matrix.ident),
            m: matrix.m,
            n: matrix.n,
            k: matrix.k,
            elem: self.compile_elem(matrix.elem),
            layout: self.compile_matrix_layout(matrix.layout),
        }
    }

    fn compile_matrix_ident(&mut self, ident: gpu::MatrixIdent) -> FragmentIdent<D> {
        match ident {
            gpu::MatrixIdent::A => FragmentIdent::A,
            gpu::MatrixIdent::B => FragmentIdent::B,
            gpu::MatrixIdent::Accumulator => FragmentIdent::Accumulator,
        }
    }

    fn compile_matrix_layout(&mut self, layout: gpu::MatrixLayout) -> Option<FragmentLayout<D>> {
        match layout {
            gpu::MatrixLayout::ColMajor => Some(FragmentLayout::ColMajor),
            gpu::MatrixLayout::RowMajor => Some(FragmentLayout::RowMajor),
            gpu::MatrixLayout::Undefined => None,
        }
    }

    fn compile_binding(&mut self, binding: cubecl_core::compute::Binding) -> Binding<D> {
        Binding {
            id: binding.id,
            item: self.compile_item(binding.item),
            location: binding.location,
            size: binding.size,
            vis: binding.visibility,
        }
    }

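    /// Compiles an item (element type plus vectorization) and records it so its type
    /// definition can be emitted; TF32 items are recorded as F32.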
    fn compile_item(&mut self, item: gpu::Item) -> Item<D> {
        let item = Item::new(
            self.compile_elem(item.elem),
            item.vectorization.map(NonZero::get).unwrap_or(1).into(),
            false,
        );
        if item.elem != super::Elem::TF32 {
            self.items.insert(item);
            self.items.insert(item.optimized());
        } else {
            let mut item = item;
            item.elem = super::Elem::F32;
            self.items.insert(item);
        }

        item
    }

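    /// Maps an IR element type to the dialect element type, flagging reduced-precision floats
    /// (fp4/fp6/fp8, f16, bf16) as used; `Flex32` lowers to plain `F32`.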
    fn compile_elem(&mut self, value: gpu::Elem) -> Elem<D> {
        match value {
            gpu::Elem::Float(kind) => match kind {
                gpu::FloatKind::E2M1 => {
                    self.flags.elem_fp4 = true;
                    Elem::FP4(FP4Kind::E2M1)
                }
                gpu::FloatKind::E2M3 => {
                    self.flags.elem_fp6 = true;
                    Elem::FP6(FP6Kind::E2M3)
                }
                gpu::FloatKind::E3M2 => {
                    self.flags.elem_fp6 = true;
                    Elem::FP6(FP6Kind::E3M2)
                }
                gpu::FloatKind::E4M3 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8(FP8Kind::E4M3)
                }
                gpu::FloatKind::E5M2 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8(FP8Kind::E5M2)
                }
                gpu::FloatKind::UE8M0 => {
                    self.flags.elem_fp8 = true;
                    Elem::FP8(FP8Kind::UE8M0)
                }
                gpu::FloatKind::F16 => {
                    self.flags.elem_f16 = true;
                    Elem::F16
                }
                gpu::FloatKind::BF16 => {
                    self.flags.elem_bf16 = true;
                    Elem::BF16
                }
                gpu::FloatKind::TF32 => Elem::TF32,
                gpu::FloatKind::Flex32 => Elem::F32,
                gpu::FloatKind::F32 => Elem::F32,
                gpu::FloatKind::F64 => Elem::F64,
            },
            gpu::Elem::AtomicFloat(kind) => match kind {
                gpu::FloatKind::F16 => Elem::Atomic(AtomicKind::F16),
                gpu::FloatKind::BF16 => Elem::Atomic(AtomicKind::BF16),
                gpu::FloatKind::F32 => Elem::Atomic(AtomicKind::F32),
                gpu::FloatKind::F64 => Elem::Atomic(AtomicKind::F64),
                kind => unimplemented!("atomic<{kind:?}> not yet supported"),
            },
            gpu::Elem::Int(kind) => match kind {
                gpu::IntKind::I8 => Elem::I8,
                gpu::IntKind::I16 => Elem::I16,
                gpu::IntKind::I32 => Elem::I32,
                gpu::IntKind::I64 => Elem::I64,
            },
            gpu::Elem::AtomicInt(kind) => match kind {
                gpu::IntKind::I32 => Elem::Atomic(AtomicKind::I32),
                gpu::IntKind::I64 => Elem::Atomic(AtomicKind::I64),
                kind => panic!("atomic<{kind:?}> isn't supported yet"),
            },
            gpu::Elem::UInt(kind) => match kind {
                gpu::UIntKind::U8 => Elem::U8,
                gpu::UIntKind::U16 => Elem::U16,
                gpu::UIntKind::U32 => Elem::U32,
                gpu::UIntKind::U64 => Elem::U64,
            },
            gpu::Elem::AtomicUInt(kind) => match kind {
                gpu::UIntKind::U32 => Elem::Atomic(AtomicKind::U32),
                gpu::UIntKind::U64 => Elem::Atomic(AtomicKind::U64),
                kind => unimplemented!("atomic<{kind:?}> not yet supported"),
            },
            gpu::Elem::Bool => Elem::Bool,
        }
    }
}

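/// Whether the element is an fp4/fp6/fp8 minifloat, which must go through the special cast
/// path rather than a plain assignment.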
fn is_fp4_fp6_fp8(elem: gpu::Elem) -> bool {
    match elem {
        gpu::Elem::Float(kind) => matches!(
            kind,
            FloatKind::E2M1
                | FloatKind::E2M3
                | FloatKind::E3M2
                | FloatKind::E4M3
                | FloatKind::E5M2
                | FloatKind::UE8M0
        ),
        _ => false,
    }
}

fn const_u32<D: Dialect>(value: u32) -> Variable<D> {
    Variable::ConstantScalar(
        gpu::ConstantScalarValue::UInt(value as u64, UIntKind::U32),
        Elem::U32,
    )
}

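/// Registers the element types supported by all C++-based targets on the device properties.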
pub fn register_supported_types(props: &mut DeviceProperties<Feature>) {
    let supported_types = [
        gpu::Elem::UInt(gpu::UIntKind::U8),
        gpu::Elem::UInt(gpu::UIntKind::U16),
        gpu::Elem::UInt(gpu::UIntKind::U32),
        gpu::Elem::UInt(gpu::UIntKind::U64),
        gpu::Elem::Int(gpu::IntKind::I8),
        gpu::Elem::Int(gpu::IntKind::I16),
        gpu::Elem::Int(gpu::IntKind::I32),
        gpu::Elem::Int(gpu::IntKind::I64),
        gpu::Elem::AtomicInt(gpu::IntKind::I32),
        gpu::Elem::AtomicInt(gpu::IntKind::I64),
        gpu::Elem::AtomicUInt(gpu::UIntKind::U32),
        gpu::Elem::AtomicUInt(gpu::UIntKind::U64),
        gpu::Elem::Float(gpu::FloatKind::BF16),
        gpu::Elem::Float(gpu::FloatKind::F16),
        gpu::Elem::Float(gpu::FloatKind::F32),
        gpu::Elem::Float(gpu::FloatKind::Flex32),
        gpu::Elem::AtomicFloat(gpu::FloatKind::F32),
        gpu::Elem::Bool,
    ];

    for ty in supported_types {
        props.register_feature(Feature::Type(ty));
    }
}