cubecl_cpp/shared/base.rs

1use std::{collections::HashSet, fmt::Debug, num::NonZero};
2
3use cubecl_common::ExecutionMode;
4use cubecl_core::CubeDim;
5use cubecl_core::ir::{FloatKind, Processor, UIntKind, VariableKind};
6use cubecl_core::post_processing::checked_io::CheckedIoProcessor;
7use cubecl_core::{
8    Compiler, Feature,
9    ir::{self as gpu},
10};
11use cubecl_core::{
12    ir::{Operation, SourceLoc},
13    prelude::{FastMath, KernelDefinition},
14};
15use cubecl_runtime::DeviceProperties;
16
17use super::{
18    AtomicKind, BinaryInstruction, Binding, Body, Component, ComputeKernel, ConstArray, Dialect,
19    Elem, FP6Kind, Fragment, FragmentIdent, FragmentLayout, IndexAssignInstruction,
20    IndexInstruction, Instruction, Item, LocalArray, SharedMemory, UnaryInstruction, Variable,
21    WarpInstruction, WmmaInstruction,
22};
23use super::{FP4Kind, barrier::BarrierOps};
24use super::{FP8Kind, pipeline::PipelineOps};
25
/// Global counter used to generate unique ids for temporary variables created
/// during compilation. Reset to zero at the end of every [`Compiler::compile`]
/// call so that successive compilations start from a clean slate.
pub(super) static COUNTER_TMP_VAR: std::sync::atomic::AtomicU32 =
    std::sync::atomic::AtomicU32::new(0);
28
/// Options controlling how a kernel is translated to C++ source.
#[derive(Clone, Debug)]
pub struct CompilationOptions {
    /// Number of units in a plane (warp); forwarded to WMMA execute
    /// instructions during CMMA compilation.
    pub warp_size: u32,
    /// When true, static metadata is split out (`split_meta` on the emitted
    /// metadata instructions) and grid constants are used.
    pub grid_constants: bool,
    /// Whether the target supports thread-block clusters; when false, any
    /// requested cluster dimension is dropped at compile time.
    pub supports_clusters: bool,
}
35
36impl Default for CompilationOptions {
37    fn default() -> Self {
38        Self {
39            warp_size: 32,
40            grid_constants: false,
41            supports_clusters: false,
42        }
43    }
44}
45
/// Cube indexes flags.
/// When true the corresponding index is declared and computed as needed in the kernel.
#[derive(Debug, Clone, Default)]
pub struct CubeIndexFlags {
    pub absolute_pos: bool,
    pub absolute_pos_tuple: bool,
    pub cube_count: bool,
    pub cube_count_tuple: bool,
    pub cube_dim: bool,
    pub cube_dim_tuple: bool,
    pub cube_pos: bool,
    pub cube_pos_tuple: bool,
    pub plane_dim: bool,
    // Set whenever any plane (warp) operation is compiled.
    pub plane_dim_checked: bool,
    // Required (along with `unit_pos`) by WMMA store instructions.
    pub plane_index: bool,
    pub unit_pos: bool,
    pub unit_pos_tuple: bool,
    // Required by plane inclusive/exclusive scan operations.
    pub unit_pos_plane: bool,
    pub cluster_pos: bool,
}
66
/// Flags gathered during Cube IR translation for the kernel compilation.
#[derive(Debug, Clone, Default)]
pub struct Flags {
    // Element types encountered while compiling the kernel body.
    pub elem_fp4: bool,
    pub elem_fp6: bool,
    pub elem_fp8: bool,
    pub elem_bf16: bool,
    pub elem_f16: bool,
    /// Builtin indexes that must be declared/computed in the kernel,
    /// post-processed through `Dialect::builtin_rules`.
    pub indexes: CubeIndexFlags,
    pub op_barrier: bool,
    pub op_pipeline: bool,
    /// Set from the kernel's `FastMath::ReducedPrecision` math mode.
    pub inst_fast_math: bool,
    /// Tensor Memory Accelerator instructions are used.
    pub inst_tma: bool,
    pub inst_tma_im2col: bool,
    pub inst_wmma: bool,
    /// Mirrors `CompilationOptions::grid_constants`.
    pub use_grid_constants: bool,
    /// Length of the static metadata section (`Metadata::static_len`).
    pub static_meta_length: usize,
    pub has_dynamic_meta: bool,
    /// Cluster dimensions requested by the kernel options, if any.
    pub cluster_dim: Option<CubeDim>,
}
87
/// Compiler translating Cube IR kernel definitions into C++ compute kernels
/// for a given [`Dialect`].
///
/// State is accumulated while walking the IR and moved into the final
/// [`ComputeKernel`] at the end of compilation.
#[allow(clippy::too_many_arguments)]
#[derive(Clone, Debug, Default)]
pub struct CppCompiler<D: Dialect> {
    barriers: Vec<BarrierOps<D>>,
    compilation_options: CompilationOptions,
    const_arrays: Vec<ConstArray<D>>,
    // For each binding id (sorted), the position of its extended metadata
    // (strides/shapes); see `build_metadata`.
    ext_meta_positions: Vec<u32>,
    cluster_dim: CubeDim,
    extensions: Vec<D::Extension>,
    flags: Flags,
    items: HashSet<Item<D>>,
    local_arrays: Vec<LocalArray<D>>,
    metadata: cubecl_core::Metadata,
    pipelines: Vec<PipelineOps<D>>,
    shared_memories: Vec<SharedMemory<D>>,
    // Last emitted source location, used to avoid duplicate line markers.
    source_loc: Option<SourceLoc>,
    strategy: ExecutionMode,
}
106
107impl<D: Dialect> Compiler for CppCompiler<D> {
108    type Representation = ComputeKernel<D>;
109    type CompilationOptions = CompilationOptions;
110
111    fn compile(
112        &mut self,
113        mut kernel: KernelDefinition,
114        compilation_options: &Self::CompilationOptions,
115        strategy: ExecutionMode,
116    ) -> Self::Representation {
117        self.compilation_options = compilation_options.clone();
118        self.strategy = strategy;
119
120        if !self.compilation_options.supports_clusters {
121            kernel.options.cluster_dim = None;
122        }
123        self.cluster_dim = kernel.options.cluster_dim.unwrap_or(CubeDim::new_single());
124
125        let ir = self.clone().compile_ir(kernel);
126        COUNTER_TMP_VAR.store(0, std::sync::atomic::Ordering::Relaxed);
127        ir
128    }
129
130    fn elem_size(&self, elem: gpu::Elem) -> usize {
131        elem.size()
132    }
133
134    fn extension(&self) -> &'static str {
135        "cpp"
136    }
137}
138
139impl<D: Dialect> CppCompiler<D> {
    /// Compiles a full kernel definition into a [`ComputeKernel`].
    ///
    /// Consumes `self`: compiling the body accumulates state (shared memories,
    /// pipelines, barriers, flags, ...) which is then moved into the result.
    fn compile_ir(mut self, mut value: KernelDefinition) -> ComputeKernel<D> {
        self.build_metadata(&value);

        // Compile the body first: walking the IR populates `self.flags` and the
        // other accumulated state read below.
        let instructions = self.compile_scope(&mut value.body);
        let buffers = value
            .buffers
            .into_iter()
            .map(|b| self.compile_binding(b))
            .collect();
        let scalars = value
            .scalars
            .into_iter()
            .map(|binding| (self.compile_elem(binding.elem), binding.count))
            .collect();

        // Translation flags gathered while compiling the body, combined with
        // the kernel options and compilation options.
        let flags = Flags {
            indexes: D::builtin_rules(&self.flags.indexes),
            inst_wmma: self.flags.inst_wmma,
            op_pipeline: self.flags.op_pipeline,
            op_barrier: self.flags.op_barrier,
            elem_fp4: self.flags.elem_fp4,
            elem_fp6: self.flags.elem_fp6,
            elem_fp8: self.flags.elem_fp8,
            elem_bf16: self.flags.elem_bf16,
            elem_f16: self.flags.elem_f16,
            inst_fast_math: value
                .options
                .fp_math_mode
                .contains(FastMath::ReducedPrecision),
            inst_tma: self.flags.inst_tma,
            inst_tma_im2col: self.flags.inst_tma_im2col,
            use_grid_constants: self.compilation_options.grid_constants,
            // TODO: At some point we should only pass dynamic meta if tensors are present,
            // not if only arrays are present. For now, leave like this
            has_dynamic_meta: self.metadata.static_len() > 0,
            static_meta_length: self.metadata.static_len() as usize,
            cluster_dim: value.options.cluster_dim,
        };

        let body = Body {
            instructions,
            shared_memories: self.shared_memories,
            pipelines: self.pipelines,
            barriers: self.barriers,
            const_arrays: self.const_arrays,
            local_arrays: self.local_arrays,
        };

        // Clusters are only emitted when supported by the target.
        let mut cluster_dim = value.options.cluster_dim;
        if !self.compilation_options.supports_clusters {
            cluster_dim = None;
        }

        ComputeKernel {
            tensor_maps: value.tensor_maps,
            buffers,
            scalars,
            meta_static_len: self.metadata.static_len() as usize,
            cube_dim: value.cube_dim,
            body,
            extensions: self.extensions,
            flags,
            items: self.items,
            kernel_name: value.options.kernel_name,
            cluster_dim,
        }
    }
208
209    fn build_metadata(&mut self, value: &KernelDefinition) {
210        let mut num_ext = 0;
211
212        let mut all_meta: Vec<_> = value
213            .buffers
214            .iter()
215            .map(|buf| (buf.id, buf.has_extended_meta))
216            .chain(value.tensor_maps.iter().map(|i| (*i, true)))
217            .collect();
218
219        all_meta.sort_by_key(|(id, _)| *id);
220
221        for (_, has_extended_meta) in &all_meta {
222            self.ext_meta_positions.push(num_ext);
223            if *has_extended_meta {
224                num_ext += 1;
225            }
226        }
227
228        let num_meta = all_meta.len();
229
230        self.metadata = cubecl_core::Metadata::new(num_meta as u32, num_ext);
231    }
232
233    pub(crate) fn ext_meta_position(&self, var: gpu::Variable) -> u32 {
234        let id = var.index().expect("Variable should have index");
235        self.ext_meta_positions[id as usize]
236    }
237
238    fn compile_scope(&mut self, scope: &mut gpu::Scope) -> Vec<Instruction<D>> {
239        let mut instructions = Vec::new();
240
241        let const_arrays = scope
242            .const_arrays
243            .drain(..)
244            .map(|(var, values)| ConstArray {
245                index: var.index().unwrap(),
246                item: self.compile_item(var.item),
247                size: values.len() as u32,
248                values: values
249                    .into_iter()
250                    .map(|val| self.compile_variable(val))
251                    .collect(),
252            })
253            .collect::<Vec<_>>();
254        self.const_arrays.extend(const_arrays);
255
256        let checked_io: Box<dyn Processor> = Box::new(CheckedIoProcessor::new(self.strategy));
257        let processing = scope.process([checked_io]);
258
259        for var in processing.variables {
260            instructions.push(Instruction::DeclareVariable {
261                var: self.compile_variable(var),
262            });
263        }
264
265        processing
266            .instructions
267            .into_iter()
268            .for_each(|op| self.compile_instruction(&mut instructions, op));
269
270        instructions
271    }
272
    /// Lowers a single IR instruction, appending the resulting C++
    /// instruction(s) to `instructions`.
    ///
    /// Also updates `self.flags` as a side effect (TMA usage, required builtin
    /// indexes, ...) depending on the operations encountered.
    fn compile_instruction(
        &mut self,
        instructions: &mut Vec<Instruction<D>>,
        instruction: gpu::Instruction,
    ) {
        // Emit a line marker first if the source location changed.
        self.update_debug_loc(instructions, &instruction);
        let out = instruction.out;
        match instruction.operation {
            gpu::Operation::Copy(variable) => {
                instructions.push(Instruction::Assign(UnaryInstruction {
                    input: self.compile_variable(variable),
                    out: self.compile_variable(out.unwrap()),
                }));
            }
            // Operation families delegated to dedicated compile methods.
            gpu::Operation::Arithmetic(op) => self.compile_arithmetic(op, out, instructions),
            gpu::Operation::Comparison(op) => self.compile_comparison(op, out, instructions),
            gpu::Operation::Bitwise(op) => self.compile_bitwise(op, out, instructions),
            gpu::Operation::Operator(op) => self.compile_operator(op, out, instructions),
            gpu::Operation::Atomic(op) => self.compile_atomic(op, out, instructions),
            gpu::Operation::Metadata(op) => instructions.push(self.compile_metadata(op, out)),
            gpu::Operation::Branch(val) => self.compile_branch(instructions, val),
            gpu::Operation::Synchronization(val) => match val {
                gpu::Synchronization::SyncCube => instructions.push(Instruction::SyncThreads),
                gpu::Synchronization::SyncPlane => instructions.push(Instruction::SyncWarp),
                // Storage synchronization is lowered to a thread sync as well.
                gpu::Synchronization::SyncStorage => instructions.push(Instruction::SyncThreads),
                gpu::Synchronization::SyncProxyShared => {
                    self.flags.inst_tma = true;
                    instructions.push(Instruction::ProxySharedFence)
                }
            },
            gpu::Operation::Plane(op) => {
                // Every plane (warp) operation requires the checked plane dim builtin.
                self.flags.indexes.plane_dim_checked = true;
                let out = self.compile_variable(out.unwrap());
                match op {
                    gpu::Plane::Sum(op) => {
                        let instruction = WarpInstruction::ReduceSum {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        // Reductions may require dialect-specific extensions.
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction));
                    }
                    // Scan operations additionally need the unit position
                    // within the plane.
                    gpu::Plane::InclusiveSum(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::InclusiveSum {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::InclusiveProd(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::InclusiveProd {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::ExclusiveSum(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::ExclusiveSum {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::ExclusiveProd(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::ExclusiveProd {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Prod(op) => {
                        let instruction = WarpInstruction::ReduceProd {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Max(op) => {
                        let instruction = WarpInstruction::ReduceMax {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Min(op) => {
                        let instruction = WarpInstruction::ReduceMin {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Elect => {
                        instructions.push(Instruction::Warp(WarpInstruction::Elect { out }))
                    }
                    gpu::Plane::All(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::All {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Any(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Any {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Ballot(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Ballot {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Broadcast(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Broadcast {
                            input: self.compile_variable(op.lhs),
                            id: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                }
            }
            gpu::Operation::CoopMma(cmma) => instructions.push(self.compile_cmma(cmma, out)),
            gpu::Operation::NonSemantic(debug) => match debug {
                gpu::NonSemantic::Print {
                    format_string,
                    args,
                } => instructions.push(Instruction::Printf {
                    format_string,
                    args: args
                        .into_iter()
                        .map(|arg| self.compile_variable(arg))
                        .collect(),
                }),
                gpu::NonSemantic::Comment { content } => {
                    instructions.push(Instruction::Comment { content })
                }
                // Don't need to handle scopes
                _ => {}
            },
            gpu::Operation::Barrier(barrier_ops) => match barrier_ops {
                gpu::BarrierOps::Init {
                    barrier,
                    with_cta_fence,
                } => {
                    // The barrier level is carried by the variable kind.
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    let barrier = self.compile_variable(barrier);
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Init {
                        barrier,
                        level,
                        with_cta_fence,
                    }));
                }
                gpu::BarrierOps::MemCopyAsync {
                    barrier,
                    source,
                    source_length,
                    offset_source,
                    offset_out,
                } => {
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsync {
                            barrier: self.compile_variable(barrier),
                            source: self.compile_variable(source),
                            destination: self.compile_variable(out.unwrap()),
                            source_length: self.compile_variable(source_length),
                            offset_source: self.compile_variable(offset_source),
                            offset_out: self.compile_variable(offset_out),
                            level,
                        },
                    ));
                }
                // NOTE(review): unlike `Operation::Tma` below, this arm does
                // not set `flags.inst_tma` — confirm TMA support is flagged
                // elsewhere for barrier-based TMA loads.
                gpu::BarrierOps::TmaLoad {
                    barrier,
                    tensor_map,
                    offset_out,
                    indices,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsyncTensorGlobalToShared {
                            barrier: self.compile_variable(barrier),
                            smem_buffer: self.compile_variable(out.unwrap()),
                            smem_offset: self.compile_variable(offset_out),
                            tensor_map: self.compile_variable(tensor_map),
                            indices: indices
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        },
                    ));
                }
                gpu::BarrierOps::TmaLoadIm2col {
                    barrier,
                    tensor_map,
                    offset_out,
                    indices,
                    offsets,
                } => {
                    self.flags.inst_tma_im2col = true;
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::TmaLoadIm2col {
                            barrier: self.compile_variable(barrier),
                            smem_buffer: self.compile_variable(out.unwrap()),
                            smem_offset: self.compile_variable(offset_out),
                            tensor_map: self.compile_variable(tensor_map),
                            indices: indices
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                            offsets: offsets
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        },
                    ));
                }
                gpu::BarrierOps::Arrive { barrier } => {
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Arrive {
                        barrier: self.compile_variable(barrier),
                        level,
                    }))
                }
                gpu::BarrierOps::ArriveTx {
                    barrier,
                    arrive_count_update,
                    transaction_count_update,
                } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::ArriveTx {
                        barrier: self.compile_variable(barrier),
                        arrive_count_update: self.compile_variable(arrive_count_update),
                        transaction_count_update: self.compile_variable(transaction_count_update),
                    }))
                }
                gpu::BarrierOps::ExpectTx {
                    barrier,
                    transaction_count_update,
                } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::ExpectTx {
                        barrier: self.compile_variable(barrier),
                        transaction_count_update: self.compile_variable(transaction_count_update),
                    }))
                }
                gpu::BarrierOps::Wait { barrier } => {
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Wait {
                        barrier: self.compile_variable(barrier),
                        level,
                    }))
                }
                gpu::BarrierOps::ArriveAndWait { barrier } => {
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::ArriveAndWait {
                            barrier: self.compile_variable(barrier),
                            level,
                        },
                    ))
                }
            },
            gpu::Operation::Tma(tma_ops) => {
                // Any TMA operation requires TMA instruction support.
                self.flags.inst_tma = true;
                match tma_ops {
                    gpu::TmaOps::TmaStore {
                        source,
                        coordinates,
                        offset_source,
                    } => {
                        instructions.push(Instruction::MemCopyAsyncTensorSharedToGlobal {
                            smem_buffer: self.compile_variable(source),
                            smem_offset: self.compile_variable(offset_source),
                            tensor_map: self.compile_variable(out.unwrap()),
                            indices: coordinates
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        });
                    }
                    gpu::TmaOps::CommitGroup => {
                        instructions.push(Instruction::BulkCommitGroup);
                    }
                    gpu::TmaOps::WaitGroup { max_pending } => {
                        instructions.push(Instruction::BulkWaitGroup { max_pending });
                    }
                    gpu::TmaOps::WaitGroupRead { max_pending } => {
                        instructions.push(Instruction::BulkWaitGroupRead { max_pending });
                    }
                }
            }
        }
    }
577
578    fn update_debug_loc(
579        &mut self,
580        instructions: &mut Vec<Instruction<D>>,
581        inst: &gpu::Instruction,
582    ) {
583        if !matches!(inst.operation, Operation::NonSemantic(_)) {
584            match &inst.source_loc {
585                Some(loc) if Some(loc) != self.source_loc.as_ref() => {
586                    self.source_loc = Some(loc.clone());
587                    instructions.push(Instruction::Line {
588                        file: loc.source.file.clone(),
589                        line: loc.line,
590                    });
591                }
592                _ => {}
593            }
594        }
595    }
596
    /// Compiles a cooperative matrix-multiply-accumulate (CMMA) operation
    /// into the corresponding WMMA instruction.
    fn compile_cmma(&mut self, cmma: gpu::CoopMma, out: Option<gpu::Variable>) -> Instruction<D> {
        let out = self.compile_variable(out.unwrap());
        match cmma {
            gpu::CoopMma::Fill { value } => Instruction::Wmma(WmmaInstruction::Fill {
                frag: out,
                value: self.compile_variable(value),
            }),
            gpu::CoopMma::Load {
                value,
                stride,
                offset,
                layout,
            } => Instruction::Wmma(WmmaInstruction::Load {
                frag: out,
                offset: self.compile_variable(offset),
                value: self.compile_variable(value),
                stride: self.compile_variable(stride),
                // Layout is optional for loads; the dialect may not map it.
                layout: layout.and_then(|l| self.compile_matrix_layout(l)),
            }),
            gpu::CoopMma::Execute {
                mat_a,
                mat_b,
                mat_c,
            } => Instruction::Wmma(WmmaInstruction::Execute {
                frag_a: self.compile_variable(mat_a),
                frag_b: self.compile_variable(mat_b),
                frag_c: self.compile_variable(mat_c),
                frag_d: out,
                warp_size: self.compilation_options.warp_size,
            }),
            gpu::CoopMma::Store {
                mat,
                stride,
                offset,
                layout,
            } => {
                // Stores need the unit position and plane index builtins.
                self.flags.indexes.unit_pos = true;
                self.flags.indexes.plane_index = true;
                Instruction::Wmma(WmmaInstruction::Store {
                    output: out,
                    offset: self.compile_variable(offset),
                    frag: self.compile_variable(mat),
                    stride: self.compile_variable(stride),
                    // Unlike loads, the layout is mandatory here.
                    layout: self
                        .compile_matrix_layout(layout)
                        .expect("Layout required for store instruction"),
                })
            }
            gpu::CoopMma::Cast { input } => Instruction::Wmma(WmmaInstruction::Cast {
                input: self.compile_variable(input),
                output: out,
            }),
        }
    }
651
    /// Compiles a metadata query (stride, shape, rank, length, buffer length)
    /// into an instruction reading from the kernel's metadata.
    fn compile_metadata(
        &mut self,
        metadata: gpu::Metadata,
        out: Option<gpu::Variable>,
    ) -> Instruction<D> {
        let out = out.unwrap();
        match metadata {
            // Strides and shapes live in the extended metadata section; the
            // variable's position there was recorded by `build_metadata`.
            gpu::Metadata::Stride { dim, var } => {
                let position = self.ext_meta_position(var);
                let offset = self.metadata.stride_offset_index(position);
                Instruction::ExtendedMetadata {
                    info_offset: self.compile_variable(offset.into()),
                    dim: self.compile_variable(dim),
                    split_meta: self.compilation_options.grid_constants,
                    static_offset: self.metadata.static_len(),
                    out: self.compile_variable(out),
                }
            }
            gpu::Metadata::Shape { dim, var } => {
                let position = self.ext_meta_position(var);
                let offset = self.metadata.shape_offset_index(position);
                Instruction::ExtendedMetadata {
                    info_offset: self.compile_variable(offset.into()),
                    dim: self.compile_variable(dim),
                    split_meta: self.compilation_options.grid_constants,
                    static_offset: self.metadata.static_len(),
                    out: self.compile_variable(out),
                }
            }
            gpu::Metadata::Rank { var } => {
                let out = self.compile_variable(out);
                let pos = self.ext_meta_position(var);
                let offset = self.metadata.rank_index(pos);
                super::Instruction::Metadata {
                    info_offset: self.compile_variable(offset.into()),
                    split_meta: self.compilation_options.grid_constants,
                    out,
                }
            }
            gpu::Metadata::Length { var } => {
                let input = self.compile_variable(var);
                let out = self.compile_variable(out);

                match input {
                    // Slices and shared memories know their length without a
                    // metadata lookup.
                    Variable::Slice { .. } => Instruction::SliceLength { input, out },
                    Variable::SharedMemory(_id, _item, length) => {
                        Instruction::ConstLength { length, out }
                    }
                    _ => {
                        let id = input.id().expect("Variable should have id");
                        let offset = self.metadata.len_index(id);
                        Instruction::Metadata {
                            info_offset: self.compile_variable(offset.into()),
                            split_meta: self.compilation_options.grid_constants,
                            out,
                        }
                    }
                }
            }
            gpu::Metadata::BufferLength { var } => {
                let input = self.compile_variable(var);
                let out = self.compile_variable(out);

                match input {
                    Variable::Slice { .. } => Instruction::SliceLength { input, out },
                    _ => {
                        let id = input.id().expect("Variable should have id");
                        let offset = self.metadata.buffer_len_index(id);
                        Instruction::Metadata {
                            info_offset: self.compile_variable(offset.into()),
                            split_meta: self.compilation_options.grid_constants,
                            out,
                        }
                    }
                }
            }
        }
    }
730
731    fn compile_branch(&mut self, instructions: &mut Vec<Instruction<D>>, branch: gpu::Branch) {
732        match branch {
733            gpu::Branch::If(mut op) => instructions.push(Instruction::If {
734                cond: self.compile_variable(op.cond),
735                instructions: self.compile_scope(&mut op.scope),
736            }),
737            gpu::Branch::IfElse(mut op) => instructions.push(Instruction::IfElse {
738                cond: self.compile_variable(op.cond),
739                instructions_if: self.compile_scope(&mut op.scope_if),
740                instructions_else: self.compile_scope(&mut op.scope_else),
741            }),
742            gpu::Branch::Switch(mut op) => instructions.push(Instruction::Switch {
743                value: self.compile_variable(op.value),
744                instructions_default: self.compile_scope(&mut op.scope_default),
745                instructions_cases: op
746                    .cases
747                    .into_iter()
748                    .map(|(val, mut block)| {
749                        (self.compile_variable(val), self.compile_scope(&mut block))
750                    })
751                    .collect(),
752            }),
753            gpu::Branch::Return => instructions.push(Instruction::Return),
754            gpu::Branch::Break => instructions.push(Instruction::Break),
755            gpu::Branch::RangeLoop(mut range_loop) => instructions.push(Instruction::RangeLoop {
756                i: self.compile_variable(range_loop.i),
757                start: self.compile_variable(range_loop.start),
758                end: self.compile_variable(range_loop.end),
759                step: range_loop.step.map(|it| self.compile_variable(it)),
760                inclusive: range_loop.inclusive,
761                instructions: self.compile_scope(&mut range_loop.scope),
762            }),
763            gpu::Branch::Loop(mut op) => instructions.push(Instruction::Loop {
764                instructions: self.compile_scope(&mut op.scope),
765            }),
766        };
767    }
768
769    fn compile_atomic(
770        &mut self,
771        value: gpu::AtomicOp,
772        out: Option<gpu::Variable>,
773        instructions: &mut Vec<Instruction<D>>,
774    ) {
775        let out = out.unwrap();
776        match value {
777            gpu::AtomicOp::Load(op) => {
778                instructions.push(Instruction::AtomicLoad(self.compile_unary(op, out)))
779            }
780            gpu::AtomicOp::Store(op) => {
781                instructions.push(Instruction::AtomicStore(self.compile_unary(op, out)))
782            }
783            gpu::AtomicOp::Swap(op) => {
784                instructions.push(Instruction::AtomicSwap(self.compile_binary(op, out)))
785            }
786            gpu::AtomicOp::Add(op) => {
787                instructions.push(Instruction::AtomicAdd(self.compile_binary(op, out)))
788            }
789            gpu::AtomicOp::Sub(op) => {
790                instructions.push(Instruction::AtomicSub(self.compile_binary(op, out)))
791            }
792            gpu::AtomicOp::Max(op) => {
793                instructions.push(Instruction::AtomicMax(self.compile_binary(op, out)))
794            }
795            gpu::AtomicOp::Min(op) => {
796                instructions.push(Instruction::AtomicMin(self.compile_binary(op, out)))
797            }
798            gpu::AtomicOp::And(op) => {
799                instructions.push(Instruction::AtomicAnd(self.compile_binary(op, out)))
800            }
801            gpu::AtomicOp::Or(op) => {
802                instructions.push(Instruction::AtomicOr(self.compile_binary(op, out)))
803            }
804            gpu::AtomicOp::Xor(op) => {
805                instructions.push(Instruction::AtomicXor(self.compile_binary(op, out)))
806            }
807            gpu::AtomicOp::CompareAndSwap(op) => instructions.push(Instruction::AtomicCAS {
808                input: self.compile_variable(op.input),
809                cmp: self.compile_variable(op.cmp),
810                val: self.compile_variable(op.val),
811                out: self.compile_variable(out),
812            }),
813        }
814    }
815
816    fn compile_arithmetic(
817        &mut self,
818        value: gpu::Arithmetic,
819        out: Option<gpu::Variable>,
820        instructions: &mut Vec<Instruction<D>>,
821    ) {
822        let out = out.unwrap();
823        match value {
824            gpu::Arithmetic::Add(op) => {
825                instructions.push(Instruction::Add(self.compile_binary(op, out)))
826            }
827            gpu::Arithmetic::Mul(op) => {
828                instructions.push(Instruction::Mul(self.compile_binary(op, out)))
829            }
830            gpu::Arithmetic::Div(op) => {
831                instructions.push(Instruction::Div(self.compile_binary(op, out)))
832            }
833            gpu::Arithmetic::Sub(op) => {
834                instructions.push(Instruction::Sub(self.compile_binary(op, out)))
835            }
836            gpu::Arithmetic::MulHi(op) => {
837                let instruction = Instruction::HiMul(self.compile_binary(op, out));
838                D::register_instruction_extension(&mut self.extensions, &instruction);
839                instructions.push(instruction)
840            }
841            gpu::Arithmetic::Modulo(op) => {
842                instructions.push(Instruction::Modulo(self.compile_binary(op, out)))
843            }
844            gpu::Arithmetic::Abs(op) => {
845                instructions.push(Instruction::Abs(self.compile_unary(op, out)))
846            }
847            gpu::Arithmetic::Exp(op) => {
848                instructions.push(Instruction::Exp(self.compile_unary(op, out)))
849            }
850            gpu::Arithmetic::Log(op) => {
851                instructions.push(Instruction::Log(self.compile_unary(op, out)))
852            }
853            gpu::Arithmetic::Log1p(op) => {
854                instructions.push(Instruction::Log1p(self.compile_unary(op, out)))
855            }
856            gpu::Arithmetic::Cos(op) => {
857                instructions.push(Instruction::Cos(self.compile_unary(op, out)))
858            }
859            gpu::Arithmetic::Sin(op) => {
860                instructions.push(Instruction::Sin(self.compile_unary(op, out)))
861            }
862            gpu::Arithmetic::Tanh(op) => {
863                let instruction = Instruction::Tanh(self.compile_unary(op, out));
864                D::register_instruction_extension(&mut self.extensions, &instruction);
865                instructions.push(instruction)
866            }
867            gpu::Arithmetic::Powf(op) => {
868                instructions.push(Instruction::Powf(self.compile_binary(op, out)))
869            }
870            gpu::Arithmetic::Sqrt(op) => {
871                instructions.push(Instruction::Sqrt(self.compile_unary(op, out)))
872            }
873            gpu::Arithmetic::Erf(op) => {
874                let instruction = Instruction::Erf(self.compile_unary(op, out));
875                D::register_instruction_extension(&mut self.extensions, &instruction);
876                instructions.push(instruction)
877            }
878            gpu::Arithmetic::Max(op) => {
879                let instruction = Instruction::Max(self.compile_binary(op, out));
880                D::register_instruction_extension(&mut self.extensions, &instruction);
881                instructions.push(instruction)
882            }
883            gpu::Arithmetic::Min(op) => {
884                let instruction = Instruction::Min(self.compile_binary(op, out));
885                D::register_instruction_extension(&mut self.extensions, &instruction);
886                instructions.push(instruction)
887            }
888            gpu::Arithmetic::Clamp(op) => instructions.push(Instruction::Clamp {
889                input: self.compile_variable(op.input),
890                min_value: self.compile_variable(op.min_value),
891                max_value: self.compile_variable(op.max_value),
892                out: self.compile_variable(out),
893            }),
894            gpu::Arithmetic::Recip(op) => {
895                let elem = op.input.item.elem();
896                let lhs = match elem {
897                    gpu::Elem::Float(kind) => gpu::ConstantScalarValue::Float(1.0, kind),
898                    gpu::Elem::Int(kind) => gpu::ConstantScalarValue::Int(1, kind),
899                    gpu::Elem::UInt(kind) => gpu::ConstantScalarValue::UInt(1, kind),
900                    gpu::Elem::Bool => gpu::ConstantScalarValue::Bool(true),
901                    gpu::Elem::AtomicInt(_)
902                    | gpu::Elem::AtomicUInt(_)
903                    | gpu::Elem::AtomicFloat(_) => {
904                        panic!("Cannot use recip with atomics")
905                    }
906                };
907
908                instructions.push(Instruction::Div(BinaryInstruction {
909                    lhs: Variable::ConstantScalar(lhs, self.compile_elem(elem)),
910                    rhs: self.compile_variable(op.input),
911                    out: self.compile_variable(out),
912                }))
913            }
914            gpu::Arithmetic::Round(op) => {
915                instructions.push(Instruction::Round(self.compile_unary(op, out)))
916            }
917            gpu::Arithmetic::Floor(op) => {
918                instructions.push(Instruction::Floor(self.compile_unary(op, out)))
919            }
920            gpu::Arithmetic::Ceil(op) => {
921                instructions.push(Instruction::Ceil(self.compile_unary(op, out)))
922            }
923            gpu::Arithmetic::Remainder(op) => {
924                instructions.push(Instruction::Remainder(self.compile_binary(op, out)))
925            }
926            gpu::Arithmetic::Fma(op) => instructions.push(Instruction::Fma {
927                a: self.compile_variable(op.a),
928                b: self.compile_variable(op.b),
929                c: self.compile_variable(op.c),
930                out: self.compile_variable(out),
931            }),
932            gpu::Arithmetic::Neg(op) => {
933                instructions.push(Instruction::Neg(self.compile_unary(op, out)))
934            }
935            gpu::Arithmetic::Normalize(op) => {
936                instructions.push(Instruction::Normalize(self.compile_unary(op, out)))
937            }
938            gpu::Arithmetic::Magnitude(op) => {
939                instructions.push(Instruction::Magnitude(self.compile_unary(op, out)))
940            }
941            gpu::Arithmetic::Dot(op) => {
942                instructions.push(Instruction::Dot(self.compile_binary(op, out)))
943            }
944        };
945    }
946
947    fn compile_comparison(
948        &mut self,
949        value: gpu::Comparison,
950        out: Option<gpu::Variable>,
951        instructions: &mut Vec<Instruction<D>>,
952    ) {
953        let out = out.unwrap();
954        match value {
955            gpu::Comparison::Equal(op) => {
956                instructions.push(Instruction::Equal(self.compile_binary(op, out)))
957            }
958            gpu::Comparison::Lower(op) => {
959                instructions.push(Instruction::Lower(self.compile_binary(op, out)))
960            }
961            gpu::Comparison::Greater(op) => {
962                instructions.push(Instruction::Greater(self.compile_binary(op, out)))
963            }
964            gpu::Comparison::LowerEqual(op) => {
965                instructions.push(Instruction::LowerEqual(self.compile_binary(op, out)))
966            }
967            gpu::Comparison::GreaterEqual(op) => {
968                instructions.push(Instruction::GreaterEqual(self.compile_binary(op, out)))
969            }
970            gpu::Comparison::NotEqual(op) => {
971                instructions.push(Instruction::NotEqual(self.compile_binary(op, out)))
972            }
973        };
974    }
975
976    fn compile_bitwise(
977        &mut self,
978        value: gpu::Bitwise,
979        out: Option<gpu::Variable>,
980        instructions: &mut Vec<Instruction<D>>,
981    ) {
982        let out = out.unwrap();
983        match value {
984            gpu::Bitwise::BitwiseOr(op) => {
985                instructions.push(Instruction::BitwiseOr(self.compile_binary(op, out)))
986            }
987            gpu::Bitwise::BitwiseAnd(op) => {
988                instructions.push(Instruction::BitwiseAnd(self.compile_binary(op, out)))
989            }
990            gpu::Bitwise::BitwiseXor(op) => {
991                instructions.push(Instruction::BitwiseXor(self.compile_binary(op, out)))
992            }
993            gpu::Bitwise::CountOnes(op) => {
994                instructions.push(Instruction::CountBits(self.compile_unary(op, out)))
995            }
996            gpu::Bitwise::ReverseBits(op) => {
997                instructions.push(Instruction::ReverseBits(self.compile_unary(op, out)))
998            }
999            gpu::Bitwise::ShiftLeft(op) => {
1000                instructions.push(Instruction::ShiftLeft(self.compile_binary(op, out)))
1001            }
1002            gpu::Bitwise::ShiftRight(op) => {
1003                instructions.push(Instruction::ShiftRight(self.compile_binary(op, out)))
1004            }
1005            gpu::Bitwise::BitwiseNot(op) => {
1006                instructions.push(Instruction::BitwiseNot(self.compile_unary(op, out)))
1007            }
1008            gpu::Bitwise::LeadingZeros(op) => {
1009                instructions.push(Instruction::LeadingZeros(self.compile_unary(op, out)))
1010            }
1011            gpu::Bitwise::FindFirstSet(op) => {
1012                let instruction = Instruction::FindFirstSet(self.compile_unary(op, out));
1013                D::register_instruction_extension(&mut self.extensions, &instruction);
1014                instructions.push(instruction)
1015            }
1016        };
1017    }
1018
1019    fn compile_operator(
1020        &mut self,
1021        value: gpu::Operator,
1022        out: Option<gpu::Variable>,
1023        instructions: &mut Vec<Instruction<D>>,
1024    ) {
1025        let out = out.unwrap();
1026        match value {
1027            gpu::Operator::Index(op) | gpu::Operator::UncheckedIndex(op) => {
1028                instructions.push(Instruction::Index(self.compile_index(op, out)));
1029            }
1030            gpu::Operator::IndexAssign(op) | gpu::Operator::UncheckedIndexAssign(op) => {
1031                instructions.push(Instruction::IndexAssign(self.compile_index_assign(op, out)));
1032            }
1033            gpu::Operator::And(op) => {
1034                instructions.push(Instruction::And(self.compile_binary(op, out)))
1035            }
1036            gpu::Operator::Or(op) => {
1037                instructions.push(Instruction::Or(self.compile_binary(op, out)))
1038            }
1039            gpu::Operator::Not(op) => {
1040                instructions.push(Instruction::Not(self.compile_unary(op, out)))
1041            }
1042            gpu::Operator::InitLine(op) => instructions.push(Instruction::VecInit {
1043                inputs: op
1044                    .inputs
1045                    .into_iter()
1046                    .map(|it| self.compile_variable(it))
1047                    .collect(),
1048                out: self.compile_variable(out),
1049            }),
1050            gpu::Operator::CopyMemory(op) => instructions.push(Instruction::Copy {
1051                input: self.compile_variable(op.input),
1052                in_index: self.compile_variable(op.in_index),
1053                out: self.compile_variable(out),
1054                out_index: self.compile_variable(op.out_index),
1055            }),
1056            gpu::Operator::CopyMemoryBulk(op) => instructions.push(Instruction::CopyBulk {
1057                input: self.compile_variable(op.input),
1058                in_index: self.compile_variable(op.in_index),
1059                out: self.compile_variable(out),
1060                out_index: self.compile_variable(op.out_index),
1061                len: op.len.as_const().unwrap().as_u32(),
1062            }),
1063            gpu::Operator::Select(op) => instructions.push(Instruction::Select {
1064                cond: self.compile_variable(op.cond),
1065                then: self.compile_variable(op.then),
1066                or_else: self.compile_variable(op.or_else),
1067                out: self.compile_variable(out),
1068            }),
1069            // Needs special conversion semantics
1070            gpu::Operator::Cast(op)
1071                if is_fp4_fp6_fp8(op.input.elem()) || is_fp4_fp6_fp8(out.elem()) =>
1072            {
1073                let inst = self.compile_unary(op, out);
1074
1075                // We may need these for intermediates
1076                self.flags.elem_f16 = true;
1077                self.flags.elem_bf16 = true;
1078                let vec = NonZero::new(inst.input.item().vectorization as u8);
1079                self.compile_item(gpu::Item::vectorized(gpu::Elem::Float(FloatKind::F16), vec));
1080                self.compile_item(gpu::Item::vectorized(
1081                    gpu::Elem::Float(FloatKind::BF16),
1082                    vec,
1083                ));
1084
1085                instructions.push(Instruction::SpecialCast(inst));
1086            }
1087            gpu::Operator::Cast(op) => {
1088                instructions.push(Instruction::Assign(self.compile_unary(op, out)))
1089            }
1090            gpu::Operator::Reinterpret(op) => {
1091                instructions.push(Instruction::Bitcast(self.compile_unary(op, out)))
1092            }
1093        };
1094    }
1095
1096    fn compile_binary(
1097        &mut self,
1098        value: gpu::BinaryOperator,
1099        out: gpu::Variable,
1100    ) -> BinaryInstruction<D> {
1101        BinaryInstruction {
1102            lhs: self.compile_variable(value.lhs),
1103            rhs: self.compile_variable(value.rhs),
1104            out: self.compile_variable(out),
1105        }
1106    }
1107
1108    fn compile_index_assign(
1109        &mut self,
1110        value: gpu::IndexAssignOperator,
1111        out: gpu::Variable,
1112    ) -> IndexAssignInstruction<D> {
1113        IndexAssignInstruction {
1114            index: self.compile_variable(value.index),
1115            value: self.compile_variable(value.value),
1116            line_size: value.line_size,
1117            out: self.compile_variable(out),
1118        }
1119    }
1120
1121    fn compile_index(
1122        &mut self,
1123        value: gpu::IndexOperator,
1124        out: gpu::Variable,
1125    ) -> IndexInstruction<D> {
1126        IndexInstruction {
1127            list: self.compile_variable(value.list),
1128            index: self.compile_variable(value.index),
1129            line_size: value.line_size,
1130            out: self.compile_variable(out),
1131        }
1132    }
1133
1134    fn compile_unary(
1135        &mut self,
1136        value: gpu::UnaryOperator,
1137        out: gpu::Variable,
1138    ) -> UnaryInstruction<D> {
1139        UnaryInstruction {
1140            input: self.compile_variable(value.input),
1141            out: self.compile_variable(out),
1142        }
1143    }
1144
    /// Translate an IR variable into a backend [`Variable`].
    ///
    /// Besides the pure mapping, this records compilation side effects:
    /// - sets element/feature flags (`inst_tma`, `inst_wmma`, `op_pipeline`,
    ///   `op_barrier`) and cube-index flags for builtins;
    /// - registers items via `compile_item` / `compile_elem`;
    /// - declares shared memories, local arrays and pipelines on first use
    ///   (keyed by id, so repeated compilation of the same variable is
    ///   idempotent).
    fn compile_variable(&mut self, value: gpu::Variable) -> Variable<D> {
        let item = value.item;
        match value.kind {
            gpu::VariableKind::GlobalInputArray(id) => {
                Variable::GlobalInputArray(id, self.compile_item(item))
            }
            gpu::VariableKind::GlobalScalar(id) => Variable::GlobalScalar {
                id,
                elem: self.compile_elem(item.elem),
                // With grid constants enabled, scalars live inside a struct.
                in_struct: self.compilation_options.grid_constants,
            },
            gpu::VariableKind::TensorMap(id) => {
                self.flags.inst_tma = true;
                Variable::TensorMap(id)
            }
            gpu::VariableKind::LocalMut { id } => Variable::LocalMut {
                id,
                item: self.compile_item(item),
            },
            // Versioned (SSA) locals are lowered like plain mutable locals.
            gpu::VariableKind::Versioned { id, .. } => Variable::LocalMut {
                id,
                item: self.compile_item(item),
            },
            gpu::VariableKind::LocalConst { id } => Variable::LocalConst {
                id,
                item: self.compile_item(item),
            },
            gpu::VariableKind::GlobalOutputArray(id) => {
                Variable::GlobalOutputArray(id, self.compile_item(item))
            }
            gpu::VariableKind::ConstantScalar(value) => {
                Variable::ConstantScalar(value, self.compile_elem(value.elem()))
            }
            gpu::VariableKind::SharedMemory {
                id,
                length,
                alignment,
            } => {
                let item = self.compile_item(item);
                // Declare this shared memory the first time its id is seen.
                if !self.shared_memories.iter().any(|s| s.index == id) {
                    self.shared_memories
                        .push(SharedMemory::new(id, item, length, alignment));
                }
                Variable::SharedMemory(id, item, length)
            }
            gpu::VariableKind::ConstantArray { id, length } => {
                let item = self.compile_item(item);
                Variable::ConstantArray(id, item, length)
            }
            // Builtins set the matching index flag so the kernel prelude
            // declares/computes only the indices actually used.
            gpu::VariableKind::Builtin(builtin) => match builtin {
                gpu::Builtin::AbsolutePos => {
                    self.flags.indexes.absolute_pos = true;
                    Variable::AbsolutePos
                }
                gpu::Builtin::CubePosCluster if self.compilation_options.supports_clusters => {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterRank
                }
                gpu::Builtin::CubePosClusterX if self.compilation_options.supports_clusters => {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexX
                }
                gpu::Builtin::CubePosClusterY if self.compilation_options.supports_clusters => {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexY
                }
                gpu::Builtin::CubePosClusterZ if self.compilation_options.supports_clusters => {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexZ
                }
                // Fallback if clusters aren't supported, ID is always 0 since clusters are always
                // (1, 1, 1) if unsupported
                gpu::Builtin::CubePosCluster
                | gpu::Builtin::CubePosClusterX
                | gpu::Builtin::CubePosClusterY
                | gpu::Builtin::CubePosClusterZ => const_u32(0),
                gpu::Builtin::AbsolutePosX => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosX
                }
                gpu::Builtin::AbsolutePosY => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosY
                }
                gpu::Builtin::AbsolutePosZ => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosZ
                }
                gpu::Builtin::CubeDim => {
                    self.flags.indexes.cube_dim = true;
                    Variable::CubeDim
                }
                gpu::Builtin::CubeDimX => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimX
                }
                gpu::Builtin::CubeDimY => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimY
                }
                gpu::Builtin::CubeDimZ => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimZ
                }
                // Cluster dimensions are known at compile time and lowered to
                // constants.
                gpu::Builtin::CubeClusterDim => const_u32(self.cluster_dim.num_elems()),
                gpu::Builtin::CubeClusterDimX => const_u32(self.cluster_dim.x),
                gpu::Builtin::CubeClusterDimY => const_u32(self.cluster_dim.y),
                gpu::Builtin::CubeClusterDimZ => const_u32(self.cluster_dim.z),
                gpu::Builtin::CubePos => {
                    self.flags.indexes.cube_pos = true;
                    Variable::CubePos
                }
                gpu::Builtin::CubePosX => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosX
                }
                gpu::Builtin::CubePosY => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosY
                }
                gpu::Builtin::CubePosZ => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosZ
                }
                gpu::Builtin::CubeCount => {
                    self.flags.indexes.cube_count = true;
                    Variable::CubeCount
                }
                gpu::Builtin::CubeCountX => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountX
                }
                gpu::Builtin::CubeCountY => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountY
                }
                gpu::Builtin::CubeCountZ => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountZ
                }
                gpu::Builtin::UnitPos => {
                    self.flags.indexes.unit_pos = true;
                    Variable::UnitPos
                }
                gpu::Builtin::UnitPosX => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosX
                }
                gpu::Builtin::UnitPosY => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosY
                }
                gpu::Builtin::UnitPosZ => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosZ
                }
                gpu::Builtin::PlaneDim => {
                    self.flags.indexes.plane_dim = true;
                    Variable::PlaneDim
                }
                gpu::Builtin::UnitPosPlane => {
                    self.flags.indexes.unit_pos_plane = true;
                    Variable::UnitPosPlane
                }
            },
            gpu::VariableKind::LocalArray { id, length } => {
                let item = self.compile_item(item);
                // Declare this local array the first time its id is seen.
                if !self.local_arrays.iter().any(|s| s.index == id) {
                    self.local_arrays.push(LocalArray::new(id, item, length));
                }
                Variable::LocalArray(id, item, length)
            }
            gpu::VariableKind::Matrix { id, mat } => {
                self.flags.inst_wmma = true;
                Variable::WmmaFragment {
                    id,
                    frag: self.compile_matrix(mat),
                }
            }
            gpu::VariableKind::Pipeline {
                id,
                item,
                num_stages,
            } => {
                self.flags.op_pipeline = true;
                let pipeline = Variable::Pipeline {
                    id,
                    item: self.compile_item(item),
                };
                // Emit the pipeline initialization only once per id.
                if !self.pipelines.iter().any(|s| s.pipeline_id() == id) {
                    self.pipelines.push(PipelineOps::Init {
                        pipeline,
                        num_stages,
                    });
                }
                pipeline
            }
            gpu::VariableKind::Barrier { id, item, level } => {
                self.flags.op_barrier = true;
                match level {
                    // Cube-level barriers need cube_dim/unit_pos declared for
                    // their initialization.
                    gpu::BarrierLevel::CubeCoop(_) | gpu::BarrierLevel::CubeManual(_) => {
                        self.flags.indexes.cube_dim = true;
                        self.flags.indexes.unit_pos = true;
                    }
                    _ => {}
                }
                Variable::Barrier {
                    id,
                    item: self.compile_item(item),
                    level,
                }
            }
        }
    }
1359
1360    fn compile_matrix(&mut self, matrix: gpu::Matrix) -> Fragment<D> {
1361        Fragment {
1362            ident: self.compile_matrix_ident(matrix.ident),
1363            m: matrix.m,
1364            n: matrix.n,
1365            k: matrix.k,
1366            elem: self.compile_elem(matrix.elem),
1367            layout: self.compile_matrix_layout(matrix.layout),
1368        }
1369    }
1370
1371    fn compile_matrix_ident(&mut self, ident: gpu::MatrixIdent) -> FragmentIdent<D> {
1372        match ident {
1373            gpu::MatrixIdent::A => FragmentIdent::A,
1374            gpu::MatrixIdent::B => FragmentIdent::B,
1375            gpu::MatrixIdent::Accumulator => FragmentIdent::Accumulator,
1376        }
1377    }
1378
1379    fn compile_matrix_layout(&mut self, layout: gpu::MatrixLayout) -> Option<FragmentLayout<D>> {
1380        match layout {
1381            gpu::MatrixLayout::ColMajor => Some(FragmentLayout::ColMajor),
1382            gpu::MatrixLayout::RowMajor => Some(FragmentLayout::RowMajor),
1383            gpu::MatrixLayout::Undefined => None,
1384        }
1385    }
1386
1387    fn compile_binding(&mut self, binding: cubecl_core::compute::Binding) -> Binding<D> {
1388        Binding {
1389            id: binding.id,
1390            item: self.compile_item(binding.item),
1391            location: binding.location,
1392            size: binding.size,
1393            vis: binding.visibility,
1394        }
1395    }
1396
1397    fn compile_item(&mut self, item: gpu::Item) -> Item<D> {
1398        let item = Item::new(
1399            self.compile_elem(item.elem),
1400            item.vectorization.map(NonZero::get).unwrap_or(1).into(),
1401            false,
1402        );
1403        if item.elem != super::Elem::TF32 {
1404            self.items.insert(item);
1405            self.items.insert(item.optimized());
1406        } else {
1407            // TF32 is represented as `float` in C++
1408            let mut item = item;
1409            item.elem = super::Elem::F32;
1410            self.items.insert(item);
1411        }
1412
1413        item
1414    }
1415
1416    fn compile_elem(&mut self, value: gpu::Elem) -> Elem<D> {
1417        match value {
1418            gpu::Elem::Float(kind) => match kind {
1419                gpu::FloatKind::E2M1 => {
1420                    self.flags.elem_fp4 = true;
1421                    Elem::FP4(FP4Kind::E2M1)
1422                }
1423                gpu::FloatKind::E2M3 => {
1424                    self.flags.elem_fp6 = true;
1425                    Elem::FP6(FP6Kind::E2M3)
1426                }
1427                gpu::FloatKind::E3M2 => {
1428                    self.flags.elem_fp6 = true;
1429                    Elem::FP6(FP6Kind::E3M2)
1430                }
1431                gpu::FloatKind::E4M3 => {
1432                    self.flags.elem_fp8 = true;
1433                    Elem::FP8(FP8Kind::E4M3)
1434                }
1435                gpu::FloatKind::E5M2 => {
1436                    self.flags.elem_fp8 = true;
1437                    Elem::FP8(FP8Kind::E5M2)
1438                }
1439                gpu::FloatKind::UE8M0 => {
1440                    self.flags.elem_fp8 = true;
1441                    Elem::FP8(FP8Kind::UE8M0)
1442                }
1443                gpu::FloatKind::F16 => {
1444                    self.flags.elem_f16 = true;
1445                    Elem::F16
1446                }
1447                gpu::FloatKind::BF16 => {
1448                    self.flags.elem_bf16 = true;
1449                    Elem::BF16
1450                }
1451                gpu::FloatKind::TF32 => Elem::TF32,
1452                gpu::FloatKind::Flex32 => Elem::F32,
1453                gpu::FloatKind::F32 => Elem::F32,
1454                gpu::FloatKind::F64 => Elem::F64,
1455            },
1456            gpu::Elem::AtomicFloat(kind) => match kind {
1457                gpu::FloatKind::F16 => Elem::Atomic(AtomicKind::F16),
1458                gpu::FloatKind::BF16 => Elem::Atomic(AtomicKind::BF16),
1459                gpu::FloatKind::F32 => Elem::Atomic(AtomicKind::F32),
1460                gpu::FloatKind::F64 => Elem::Atomic(AtomicKind::F64),
1461                kind => unimplemented!("atomic<{kind:?}> not yet supported"),
1462            },
1463            gpu::Elem::Int(kind) => match kind {
1464                gpu::IntKind::I8 => Elem::I8,
1465                gpu::IntKind::I16 => Elem::I16,
1466                gpu::IntKind::I32 => Elem::I32,
1467                gpu::IntKind::I64 => Elem::I64,
1468            },
1469            gpu::Elem::AtomicInt(kind) => match kind {
1470                gpu::IntKind::I32 => Elem::Atomic(AtomicKind::I32),
1471                gpu::IntKind::I64 => Elem::Atomic(AtomicKind::I64),
1472                kind => panic!("atomic<{kind:?}> isn't supported yet"),
1473            },
1474            gpu::Elem::UInt(kind) => match kind {
1475                gpu::UIntKind::U8 => Elem::U8,
1476                gpu::UIntKind::U16 => Elem::U16,
1477                gpu::UIntKind::U32 => Elem::U32,
1478                gpu::UIntKind::U64 => Elem::U64,
1479            },
1480            gpu::Elem::AtomicUInt(kind) => match kind {
1481                gpu::UIntKind::U32 => Elem::Atomic(AtomicKind::U32),
1482                gpu::UIntKind::U64 => Elem::Atomic(AtomicKind::U64),
1483                kind => unimplemented!("atomic<{kind:?}> not yet supported"),
1484            },
1485            gpu::Elem::Bool => Elem::Bool,
1486        }
1487    }
1488}
1489
1490fn is_fp4_fp6_fp8(elem: gpu::Elem) -> bool {
1491    match elem {
1492        gpu::Elem::Float(kind) => matches!(
1493            kind,
1494            FloatKind::E2M1
1495                | FloatKind::E2M3
1496                | FloatKind::E3M2
1497                | FloatKind::E4M3
1498                | FloatKind::E5M2
1499                | FloatKind::UE8M0
1500        ),
1501        _ => false,
1502    }
1503}
1504
1505fn const_u32<D: Dialect>(value: u32) -> Variable<D> {
1506    Variable::ConstantScalar(
1507        gpu::ConstantScalarValue::UInt(value as u64, UIntKind::U32),
1508        Elem::U32,
1509    )
1510}
1511
1512pub fn register_supported_types(props: &mut DeviceProperties<Feature>) {
1513    let supported_types = [
1514        gpu::Elem::UInt(gpu::UIntKind::U8),
1515        gpu::Elem::UInt(gpu::UIntKind::U16),
1516        gpu::Elem::UInt(gpu::UIntKind::U32),
1517        gpu::Elem::UInt(gpu::UIntKind::U64),
1518        gpu::Elem::Int(gpu::IntKind::I8),
1519        gpu::Elem::Int(gpu::IntKind::I16),
1520        gpu::Elem::Int(gpu::IntKind::I32),
1521        gpu::Elem::Int(gpu::IntKind::I64),
1522        gpu::Elem::AtomicInt(gpu::IntKind::I32),
1523        gpu::Elem::AtomicInt(gpu::IntKind::I64),
1524        gpu::Elem::AtomicUInt(gpu::UIntKind::U32),
1525        gpu::Elem::AtomicUInt(gpu::UIntKind::U64),
1526        gpu::Elem::Float(gpu::FloatKind::BF16),
1527        gpu::Elem::Float(gpu::FloatKind::F16),
1528        gpu::Elem::Float(gpu::FloatKind::F32),
1529        gpu::Elem::Float(gpu::FloatKind::Flex32),
1530        gpu::Elem::AtomicFloat(gpu::FloatKind::F32),
1531        // Causes CUDA_ERROR_INVALID_VALUE for matmul, disabling until that can be investigated
1532        //gpu::Elem::Float(gpu::FloatKind::F64),
1533        gpu::Elem::Bool,
1534    ];
1535
1536    for ty in supported_types {
1537        props.register_feature(Feature::Type(ty));
1538    }
1539}