// cubecl_cpp/shared/base.rs
1use std::{collections::HashSet, fmt::Debug};
2
3use cubecl_common::ExecutionMode;
4use cubecl_common::backtrace::BackTrace;
5use cubecl_core::ir::{self as gpu, OpaqueType, StorageType};
6use cubecl_core::ir::{FloatKind, InstructionModes, Processor, UIntKind};
7use cubecl_core::post_processing::checked_io::CheckedIoProcessor;
8use cubecl_core::{CubeDim, ir::ElemType};
9use cubecl_core::{
10    ir::{Operation, SourceLoc},
11    prelude::{FastMath, KernelDefinition},
12};
13use cubecl_opt::{Optimizer, SharedLiveness};
14use cubecl_runtime::compiler::CompilationError;
15use cubecl_runtime::{DeviceProperties, EnumSet, TypeUsage, compiler::Compiler};
16
17use crate::shared::MmaShape;
18
19use super::{
20    BinaryInstruction, Binding, Body, Component, ComputeKernel, ConstArray, Dialect, Elem, FP6Kind,
21    Fragment, FragmentIdent, FragmentLayout, IndexAssignInstruction, IndexInstruction, Instruction,
22    Item, LocalArray, SharedMemory, UnaryInstruction, Variable, WarpInstruction, WmmaInstruction,
23};
24use super::{FP4Kind, barrier::BarrierOps};
25use super::{FP8Kind, pipeline::PipelineOps};
26
/// Monotonic counter used to generate unique names for compiler-created
/// temporary variables. It is reset to zero at the end of every
/// `Compiler::compile` call.
///
/// NOTE(review): as a process-wide atomic, concurrent compilations would share
/// this counter; presumably temporary names only need to be unique within one
/// kernel, so that is harmless — confirm if compilations ever run in parallel.
pub(super) static COUNTER_TMP_VAR: std::sync::atomic::AtomicU32 =
    std::sync::atomic::AtomicU32::new(0);
29
/// Dialect-independent options controlling how a kernel is lowered to C++.
#[derive(Clone, Debug)]
pub struct CompilationOptions {
    /// Number of units in a plane/warp (32 by default, see `Default`).
    pub warp_size: u32,
    /// Features the target toolchain/hardware supports.
    pub supports_features: CppSupportedFeatures,
}
35
/// Feature switches describing what the target C++ toolchain/hardware
/// supports. Every feature defaults to unsupported (`false`).
#[derive(Clone, Debug, Default)]
pub struct CppSupportedFeatures {
    /// Grid-constant kernel parameters (forwarded to `Flags::use_grid_constants`).
    pub grid_constants: bool,
    /// Thread-block clusters; when `false`, any requested cluster dimension is dropped.
    pub clusters: bool,
    // NOTE(review): `fast_math` and `fast_tanh` are not read in this part of
    // the file; presumably consumed by dialect-specific code generation — confirm.
    pub fast_math: bool,
    pub fast_tanh: bool,
    /// Hardware `elect` support for `Plane::Elect`; a software fallback is used otherwise.
    pub elect_sync: bool,
}
44
45impl Default for CompilationOptions {
46    fn default() -> Self {
47        Self {
48            warp_size: 32,
49            supports_features: Default::default(),
50        }
51    }
52}
53
/// Cube indexes flags.
/// When true the corresponding index is declared and computed as needed in the kernel.
#[derive(Debug, Clone, Default)]
pub struct CubeIndexFlags {
    // Grid-wide unit position (linear and per-axis).
    pub absolute_pos: bool,
    pub absolute_pos_tuple: bool,
    // Number of cubes in the grid (linear and per-axis).
    pub cube_count: bool,
    pub cube_count_tuple: bool,
    // Dimensions of a cube (linear and per-axis).
    pub cube_dim: bool,
    pub cube_dim_tuple: bool,
    // Position of the cube in the grid (linear and per-axis).
    pub cube_pos: bool,
    pub cube_pos_tuple: bool,
    // Plane (warp) size — `plane_dim_checked` also guards partial planes —
    // and the plane's index within the cube.
    pub plane_dim: bool,
    pub plane_dim_checked: bool,
    pub plane_index: bool,
    // Unit position within the cube (linear and per-axis) and within its plane.
    pub unit_pos: bool,
    pub unit_pos_tuple: bool,
    pub unit_pos_plane: bool,
    // Position within a thread-block cluster (requires cluster support).
    pub cluster_pos: bool,
}
74
/// Flags gathered during Cube IR translation for the kernel compilation.
#[derive(Debug, Clone, Default)]
pub struct Flags {
    // Element types encountered while compiling the kernel; each may require
    // dedicated support (headers/typedefs) in the generated source.
    pub elem_fp4: bool,
    pub elem_fp6: bool,
    pub elem_fp8: bool,
    pub elem_bf16: bool,
    pub elem_f16: bool,
    pub elem_tf32: bool,
    /// Which builtin index variables the kernel body actually uses.
    pub indexes: CubeIndexFlags,
    // Operation families present in the kernel body.
    pub op_barrier: bool,
    pub op_pipeline: bool,
    // Instruction families present in the kernel body.
    pub inst_tma: bool,
    pub inst_tma_im2col: bool,
    pub inst_wmma: bool,
    pub inst_ptx_wrappers: bool,
    pub inst_async_copy: bool,
    /// Use grid-constant kernel parameters (set from
    /// `CppSupportedFeatures::grid_constants`).
    pub use_grid_constants: bool,
    /// Length of the statically-known metadata section.
    pub static_meta_length: usize,
    /// Whether runtime-sized metadata must be passed to the kernel.
    pub has_dynamic_meta: bool,
    /// Cube (block) dimensions of the kernel launch.
    pub cube_dim: CubeDim,
    /// Cluster dimensions, when requested by the kernel options.
    pub cluster_dim: Option<CubeDim>,
}
98
99#[allow(clippy::too_many_arguments)]
100#[derive(Clone, Debug, Default)]
101pub struct CppCompiler<D: Dialect> {
102    barriers: Vec<BarrierOps<D>>,
103    compilation_options: CompilationOptions,
104    const_arrays: Vec<ConstArray<D>>,
105    ext_meta_positions: Vec<u32>,
106    cluster_dim: CubeDim,
107    extensions: Vec<D::Extension>,
108    flags: Flags,
109    items: HashSet<Item<D>>,
110    local_arrays: Vec<LocalArray<D>>,
111    metadata: cubecl_core::Metadata,
112    pipelines: Vec<PipelineOps<D>>,
113    source_loc: Option<SourceLoc>,
114    strategy: ExecutionMode,
115}
116
117impl<D: Dialect> Compiler for CppCompiler<D> {
118    type Representation = ComputeKernel<D>;
119    type CompilationOptions = CompilationOptions;
120
121    fn compile(
122        &mut self,
123        mut kernel: KernelDefinition,
124        compilation_options: &Self::CompilationOptions,
125        strategy: ExecutionMode,
126    ) -> Result<Self::Representation, CompilationError> {
127        let errors = kernel.body.pop_errors();
128        if !errors.is_empty() {
129            let mut reason = "Can't compile cpp kernel\nCaused by:\n  ".to_string();
130            for error in errors {
131                reason += error.as_str();
132                reason += "\n";
133            }
134
135            return Err(CompilationError::Validation {
136                reason,
137                backtrace: BackTrace::capture(),
138            });
139        }
140
141        self.compilation_options = compilation_options.clone();
142        self.strategy = strategy;
143
144        if !self.compilation_options.supports_features.clusters {
145            kernel.options.cluster_dim = None;
146        }
147        self.cluster_dim = kernel.options.cluster_dim.unwrap_or(CubeDim::new_single());
148
149        let ir = self.clone().compile_ir(kernel);
150        COUNTER_TMP_VAR.store(0, std::sync::atomic::Ordering::Relaxed);
151        Ok(ir)
152    }
153
154    fn elem_size(&self, elem: gpu::ElemType) -> usize {
155        elem.size()
156    }
157
158    fn extension(&self) -> &'static str {
159        "cpp"
160    }
161}
162
163impl<D: Dialect> CppCompiler<D> {
    /// Consume the compiler to lower a full kernel definition into a
    /// `ComputeKernel` for dialect `D`.
    ///
    /// Statement order matters here: `compile_scope` must run first because it
    /// populates `self.flags`, `self.const_arrays`, `self.extensions`, etc.,
    /// which are read when building the `Flags` snapshot and the `Body` below.
    fn compile_ir(mut self, value: KernelDefinition) -> ComputeKernel<D> {
        self.build_metadata(&value);

        // Compile on a clone of the body: the original is consumed further
        // down by the shared-memory liveness optimizer.
        let instructions = self.compile_scope(&mut value.body.clone());
        let tensor_maps = value
            .tensor_maps
            .into_iter()
            .map(|b| self.compile_binding(b))
            .collect();
        let buffers = value
            .buffers
            .into_iter()
            .map(|b| self.compile_binding(b))
            .collect();
        // Scalars are passed as (type, count) pairs rather than full bindings.
        let scalars = value
            .scalars
            .into_iter()
            .map(|binding| (self.compile_storage_type(binding.ty), binding.count))
            .collect();

        // translation flags
        // Snapshot of the flags accumulated by `compile_scope` above, combined
        // with options/metadata known up front.
        let flags = Flags {
            indexes: D::builtin_rules(&self.flags.indexes),
            inst_wmma: self.flags.inst_wmma,
            op_pipeline: self.flags.op_pipeline,
            op_barrier: self.flags.op_barrier,
            elem_fp4: self.flags.elem_fp4,
            elem_fp6: self.flags.elem_fp6,
            elem_fp8: self.flags.elem_fp8,
            elem_bf16: self.flags.elem_bf16,
            elem_f16: self.flags.elem_f16,
            elem_tf32: self.flags.elem_tf32,
            inst_tma: self.flags.inst_tma,
            inst_tma_im2col: self.flags.inst_tma_im2col,
            inst_async_copy: self.flags.inst_async_copy,
            inst_ptx_wrappers: self.flags.inst_ptx_wrappers,
            use_grid_constants: self.compilation_options.supports_features.grid_constants,
            // TODO: At some point we should only pass dynamic meta if tensors are present,
            // not if only arrays are present. For now, leave like this
            has_dynamic_meta: self.metadata.static_len() > 0,
            static_meta_length: self.metadata.static_len() as usize,
            cube_dim: value.cube_dim,
            cluster_dim: value.options.cluster_dim,
        };

        // Run shared-memory liveness analysis on the (unmodified) body to
        // compute allocation offsets, then translate each allocation.
        let mut opt = Optimizer::shared_only(value.body, value.cube_dim);
        let shared_allocs = opt.analysis::<SharedLiveness>();
        let shared_memories = shared_allocs
            .allocations
            .values()
            .map(|alloc| match alloc.smem {
                cubecl_opt::SharedMemory::Array {
                    id,
                    length,
                    ty,
                    align,
                } => SharedMemory::Array {
                    index: id,
                    item: self.compile_type(ty),
                    length,
                    align,
                    offset: alloc.offset,
                },
                cubecl_opt::SharedMemory::Value { id, ty, align } => SharedMemory::Value {
                    index: id,
                    item: self.compile_type(ty),
                    align,
                    offset: alloc.offset,
                },
            })
            .collect();

        let body = Body {
            instructions,
            shared_memories,
            pipelines: self.pipelines,
            barriers: self.barriers,
            const_arrays: self.const_arrays,
            local_arrays: self.local_arrays,
        };

        // Mirror the cluster handling done in `compile`: no cluster dimension
        // is emitted when the target does not support clusters.
        let mut cluster_dim = value.options.cluster_dim;
        if !self.compilation_options.supports_features.clusters {
            cluster_dim = None;
        }

        ComputeKernel {
            tensor_maps,
            buffers,
            scalars,
            meta_static_len: self.metadata.static_len() as usize,
            cube_dim: value.cube_dim,
            body,
            extensions: self.extensions,
            flags,
            items: self.items,
            kernel_name: value.options.kernel_name,
            cluster_dim,
        }
    }
264
265    fn build_metadata(&mut self, value: &KernelDefinition) {
266        let mut num_ext = 0;
267
268        let mut all_meta: Vec<_> = value
269            .buffers
270            .iter()
271            .chain(value.tensor_maps.iter())
272            .map(|buf| (buf.id, buf.has_extended_meta))
273            .collect();
274
275        all_meta.sort_by_key(|(id, _)| *id);
276
277        for (_, has_extended_meta) in &all_meta {
278            self.ext_meta_positions.push(num_ext);
279            if *has_extended_meta {
280                num_ext += 1;
281            }
282        }
283
284        let num_meta = all_meta.len();
285
286        self.metadata = cubecl_core::Metadata::new(num_meta as u32, num_ext);
287    }
288
289    pub(crate) fn ext_meta_position(&self, var: gpu::Variable) -> u32 {
290        let id = var.index().expect("Variable should have index");
291        self.ext_meta_positions[id as usize]
292    }
293
294    fn compile_scope(&mut self, scope: &mut gpu::Scope) -> Vec<Instruction<D>> {
295        let mut instructions = Vec::new();
296
297        let const_arrays = scope
298            .const_arrays
299            .drain(..)
300            .map(|(var, values)| ConstArray {
301                index: var.index().unwrap(),
302                item: self.compile_type(var.ty),
303                size: values.len() as u32,
304                values: values
305                    .into_iter()
306                    .map(|val| self.compile_variable(val))
307                    .collect(),
308            })
309            .collect::<Vec<_>>();
310        self.const_arrays.extend(const_arrays);
311
312        let checked_io: Box<dyn Processor> = Box::new(CheckedIoProcessor::new(self.strategy));
313        let dialect_processors = D::processors();
314        let mut processors: Vec<&dyn Processor> = vec![&*checked_io];
315        processors.extend(dialect_processors.iter().map(|it| &**it));
316
317        let processing = scope.process(processors);
318
319        for var in processing.variables {
320            instructions.push(Instruction::DeclareVariable {
321                var: self.compile_variable(var),
322            });
323        }
324
325        processing
326            .instructions
327            .into_iter()
328            .for_each(|op| self.compile_instruction(&mut instructions, op));
329
330        instructions
331    }
332
333    fn compile_instruction(
334        &mut self,
335        instructions: &mut Vec<Instruction<D>>,
336        instruction: gpu::Instruction,
337    ) {
338        self.update_debug_loc(instructions, &instruction);
339        let out = instruction.out;
340        match instruction.operation {
341            gpu::Operation::Copy(variable) => {
342                instructions.push(Instruction::Assign(UnaryInstruction {
343                    input: self.compile_variable(variable),
344                    out: self.compile_variable(out.unwrap()),
345                }));
346            }
347            gpu::Operation::Arithmetic(op) => {
348                self.compile_arithmetic(op, out, instruction.modes, instructions)
349            }
350            gpu::Operation::Comparison(op) => self.compile_comparison(op, out, instructions),
351            gpu::Operation::Bitwise(op) => self.compile_bitwise(op, out, instructions),
352            gpu::Operation::Operator(op) => self.compile_operator(op, out, instructions),
353            gpu::Operation::Atomic(op) => self.compile_atomic(op, out, instructions),
354            gpu::Operation::Metadata(op) => instructions.push(self.compile_metadata(op, out)),
355            gpu::Operation::Branch(val) => self.compile_branch(instructions, val),
356            gpu::Operation::Synchronization(val) => match val {
357                gpu::Synchronization::SyncCube => instructions.push(Instruction::SyncThreads),
358                gpu::Synchronization::SyncPlane => instructions.push(Instruction::SyncWarp),
359                gpu::Synchronization::SyncStorage => instructions.push(Instruction::SyncThreads),
360                gpu::Synchronization::SyncAsyncProxyShared => {
361                    self.flags.inst_tma = true;
362                    instructions.push(Instruction::ProxyAsyncToSharedFence)
363                }
364            },
365            gpu::Operation::Plane(op) => {
366                self.flags.indexes.plane_dim_checked = true;
367                let out = self.compile_variable(out.unwrap());
368                match op {
369                    gpu::Plane::Sum(op) => {
370                        let instruction = WarpInstruction::ReduceSum {
371                            input: self.compile_variable(op.input),
372                            out,
373                        };
374                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
375                        instructions.push(Instruction::Warp(instruction));
376                    }
377                    gpu::Plane::InclusiveSum(op) => {
378                        self.flags.indexes.unit_pos_plane = true;
379                        instructions.push(Instruction::Warp(WarpInstruction::InclusiveSum {
380                            input: self.compile_variable(op.input),
381                            out,
382                        }))
383                    }
384                    gpu::Plane::InclusiveProd(op) => {
385                        self.flags.indexes.unit_pos_plane = true;
386                        instructions.push(Instruction::Warp(WarpInstruction::InclusiveProd {
387                            input: self.compile_variable(op.input),
388                            out,
389                        }))
390                    }
391                    gpu::Plane::ExclusiveSum(op) => {
392                        self.flags.indexes.unit_pos_plane = true;
393                        instructions.push(Instruction::Warp(WarpInstruction::ExclusiveSum {
394                            input: self.compile_variable(op.input),
395                            out,
396                        }))
397                    }
398                    gpu::Plane::ExclusiveProd(op) => {
399                        self.flags.indexes.unit_pos_plane = true;
400                        instructions.push(Instruction::Warp(WarpInstruction::ExclusiveProd {
401                            input: self.compile_variable(op.input),
402                            out,
403                        }))
404                    }
405                    gpu::Plane::Prod(op) => {
406                        let instruction = WarpInstruction::ReduceProd {
407                            input: self.compile_variable(op.input),
408                            out,
409                        };
410                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
411                        instructions.push(Instruction::Warp(instruction))
412                    }
413                    gpu::Plane::Max(op) => {
414                        let instruction = WarpInstruction::ReduceMax {
415                            input: self.compile_variable(op.input),
416                            out,
417                        };
418                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
419                        instructions.push(Instruction::Warp(instruction))
420                    }
421                    gpu::Plane::Min(op) => {
422                        let instruction = WarpInstruction::ReduceMin {
423                            input: self.compile_variable(op.input),
424                            out,
425                        };
426                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
427                        instructions.push(Instruction::Warp(instruction))
428                    }
429                    gpu::Plane::Elect => {
430                        if self.compilation_options.supports_features.elect_sync {
431                            self.flags.inst_ptx_wrappers = true;
432                            instructions.push(Instruction::Warp(WarpInstruction::Elect { out }))
433                        } else {
434                            instructions
435                                .push(Instruction::Warp(WarpInstruction::ElectFallback { out }))
436                        }
437                    }
438                    gpu::Plane::All(op) => {
439                        instructions.push(Instruction::Warp(WarpInstruction::All {
440                            input: self.compile_variable(op.input),
441                            out,
442                        }))
443                    }
444                    gpu::Plane::Any(op) => {
445                        instructions.push(Instruction::Warp(WarpInstruction::Any {
446                            input: self.compile_variable(op.input),
447                            out,
448                        }))
449                    }
450                    gpu::Plane::Ballot(op) => {
451                        instructions.push(Instruction::Warp(WarpInstruction::Ballot {
452                            input: self.compile_variable(op.input),
453                            out,
454                        }))
455                    }
456                    gpu::Plane::Broadcast(op) => {
457                        instructions.push(Instruction::Warp(WarpInstruction::Broadcast {
458                            input: self.compile_variable(op.lhs),
459                            id: self.compile_variable(op.rhs),
460                            out,
461                        }))
462                    }
463                    gpu::Plane::Shuffle(op) => {
464                        instructions.push(Instruction::Warp(WarpInstruction::Shuffle {
465                            input: self.compile_variable(op.lhs),
466                            src_lane: self.compile_variable(op.rhs),
467                            out,
468                        }))
469                    }
470                    gpu::Plane::ShuffleXor(op) => {
471                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleXor {
472                            input: self.compile_variable(op.lhs),
473                            mask: self.compile_variable(op.rhs),
474                            out,
475                        }))
476                    }
477                    gpu::Plane::ShuffleUp(op) => {
478                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleUp {
479                            input: self.compile_variable(op.lhs),
480                            delta: self.compile_variable(op.rhs),
481                            out,
482                        }))
483                    }
484                    gpu::Plane::ShuffleDown(op) => {
485                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleDown {
486                            input: self.compile_variable(op.lhs),
487                            delta: self.compile_variable(op.rhs),
488                            out,
489                        }))
490                    }
491                }
492            }
493            gpu::Operation::CoopMma(cmma) => instructions.push(self.compile_cmma(cmma, out)),
494            gpu::Operation::NonSemantic(debug) => match debug {
495                gpu::NonSemantic::Print {
496                    format_string,
497                    args,
498                } => instructions.push(Instruction::Printf {
499                    format_string,
500                    args: args
501                        .into_iter()
502                        .map(|arg| self.compile_variable(arg))
503                        .collect(),
504                }),
505                gpu::NonSemantic::Comment { content } => {
506                    instructions.push(Instruction::Comment { content })
507                }
508                // Don't need to handle scopes
509                _ => {}
510            },
511            gpu::Operation::Barrier(barrier_ops) => match barrier_ops {
512                gpu::BarrierOps::Declare { barrier } => {
513                    let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type()
514                    else {
515                        unreachable!()
516                    };
517                    let barrier = self.compile_variable(barrier);
518                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Declare {
519                        barrier,
520                        level,
521                    }));
522                }
523                gpu::BarrierOps::Init {
524                    barrier,
525                    is_elected,
526                    arrival_count,
527                } => {
528                    let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type()
529                    else {
530                        unreachable!()
531                    };
532                    let barrier = self.compile_variable(barrier);
533                    let arrival_count = self.compile_variable(arrival_count);
534                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Init {
535                        barrier,
536                        is_elected: self.compile_variable(is_elected),
537                        arrival_count,
538                        level,
539                    }));
540                }
541                gpu::BarrierOps::InitManual {
542                    barrier,
543                    arrival_count,
544                } => {
545                    let barrier = self.compile_variable(barrier);
546                    let arrival_count = self.compile_variable(arrival_count);
547                    instructions.push(Instruction::Barrier(
548                        super::barrier::BarrierOps::InitManual {
549                            barrier,
550                            arrival_count,
551                        },
552                    ));
553                }
554                gpu::BarrierOps::MemCopyAsync {
555                    barrier,
556                    source,
557                    source_length,
558                    offset_source,
559                    offset_out,
560                } => {
561                    instructions.push(Instruction::Barrier(
562                        super::barrier::BarrierOps::MemCopyAsync {
563                            barrier: self.compile_variable(barrier),
564                            source: self.compile_variable(source),
565                            destination: self.compile_variable(out.unwrap()),
566                            source_length: self.compile_variable(source_length),
567                            offset_source: self.compile_variable(offset_source),
568                            offset_out: self.compile_variable(offset_out),
569                            cooperative: false,
570                        },
571                    ));
572                }
573                gpu::BarrierOps::MemCopyAsyncCooperative {
574                    barrier,
575                    source,
576                    source_length,
577                    offset_source,
578                    offset_out,
579                } => {
580                    instructions.push(Instruction::Barrier(
581                        super::barrier::BarrierOps::MemCopyAsync {
582                            barrier: self.compile_variable(barrier),
583                            source: self.compile_variable(source),
584                            destination: self.compile_variable(out.unwrap()),
585                            source_length: self.compile_variable(source_length),
586                            offset_source: self.compile_variable(offset_source),
587                            offset_out: self.compile_variable(offset_out),
588                            cooperative: true,
589                        },
590                    ));
591                }
592                gpu::BarrierOps::MemCopyAsyncTx {
593                    barrier,
594                    source,
595                    source_length,
596                    offset_source,
597                    offset_out,
598                } => {
599                    instructions.push(Instruction::Barrier(
600                        super::barrier::BarrierOps::MemCopyAsyncTx {
601                            barrier: self.compile_variable(barrier),
602                            source: self.compile_variable(source),
603                            destination: self.compile_variable(out.unwrap()),
604                            source_length: self.compile_variable(source_length),
605                            offset_source: self.compile_variable(offset_source),
606                            offset_out: self.compile_variable(offset_out),
607                        },
608                    ));
609                }
610                gpu::BarrierOps::CopyAsync {
611                    source,
612                    source_length,
613                    offset_source,
614                    offset_out,
615                    copy_length,
616                    checked,
617                } => {
618                    self.flags.inst_async_copy = true;
619                    instructions.push(Instruction::Barrier(
620                        super::barrier::BarrierOps::CopyAsync {
621                            source: self.compile_variable(source),
622                            destination: self.compile_variable(out.unwrap()),
623                            source_length: self.compile_variable(source_length),
624                            offset_source: self.compile_variable(offset_source),
625                            offset_out: self.compile_variable(offset_out),
626                            copy_size: copy_length,
627                            checked,
628                        },
629                    ));
630                }
631                gpu::BarrierOps::TmaLoad {
632                    barrier,
633                    tensor_map,
634                    offset_out,
635                    indices,
636                } => {
637                    instructions.push(Instruction::Barrier(
638                        super::barrier::BarrierOps::MemCopyAsyncTensorGlobalToShared {
639                            barrier: self.compile_variable(barrier),
640                            smem_buffer: self.compile_variable(out.unwrap()),
641                            smem_offset: self.compile_variable(offset_out),
642                            tensor_map: self.compile_variable(tensor_map),
643                            indices: indices
644                                .into_iter()
645                                .map(|it| self.compile_variable(it))
646                                .collect(),
647                        },
648                    ));
649                }
650                gpu::BarrierOps::TmaLoadIm2col {
651                    barrier,
652                    tensor_map,
653                    offset_out,
654                    indices,
655                    offsets,
656                } => {
657                    self.flags.inst_tma_im2col = true;
658                    instructions.push(Instruction::Barrier(
659                        super::barrier::BarrierOps::TmaLoadIm2col {
660                            barrier: self.compile_variable(barrier),
661                            smem_buffer: self.compile_variable(out.unwrap()),
662                            smem_offset: self.compile_variable(offset_out),
663                            tensor_map: self.compile_variable(tensor_map),
664                            indices: indices
665                                .into_iter()
666                                .map(|it| self.compile_variable(it))
667                                .collect(),
668                            offsets: offsets
669                                .into_iter()
670                                .map(|it| self.compile_variable(it))
671                                .collect(),
672                        },
673                    ));
674                }
675                gpu::BarrierOps::Arrive { barrier } => {
676                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Arrive {
677                        barrier: self.compile_variable(barrier),
678                        token: self.compile_variable(out.unwrap()),
679                    }))
680                }
681                gpu::BarrierOps::ArriveTx {
682                    barrier,
683                    arrive_count_update,
684                    transaction_count_update,
685                } => {
686                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::ArriveTx {
687                        barrier: self.compile_variable(barrier),
688                        token: self.compile_variable(out.unwrap()),
689                        arrive_count_update: self.compile_variable(arrive_count_update),
690                        transaction_count_update: self.compile_variable(transaction_count_update),
691                    }))
692                }
693                gpu::BarrierOps::CommitCopyAsync { barrier } => {
694                    self.flags.inst_async_copy = true;
695                    instructions.push(Instruction::Barrier(
696                        super::barrier::BarrierOps::ArriveCopyAsync {
697                            barrier: self.compile_variable(barrier),
698                        },
699                    ))
700                }
701                gpu::BarrierOps::ExpectTx {
702                    barrier,
703                    transaction_count_update,
704                } => {
705                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::ExpectTx {
706                        barrier: self.compile_variable(barrier),
707                        transaction_count_update: self.compile_variable(transaction_count_update),
708                    }))
709                }
710                gpu::BarrierOps::Wait { barrier, token } => {
711                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Wait {
712                        barrier: self.compile_variable(barrier),
713                        token: self.compile_variable(token),
714                    }))
715                }
716                gpu::BarrierOps::WaitParity { barrier, phase } => instructions.push(
717                    Instruction::Barrier(super::barrier::BarrierOps::WaitParity {
718                        barrier: self.compile_variable(barrier),
719                        phase: self.compile_variable(phase),
720                    }),
721                ),
722                gpu::BarrierOps::ArriveAndWait { barrier } => {
723                    let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type()
724                    else {
725                        unreachable!()
726                    };
727                    instructions.push(Instruction::Barrier(
728                        super::barrier::BarrierOps::ArriveAndWait {
729                            barrier: self.compile_variable(barrier),
730                            level,
731                        },
732                    ))
733                }
734            },
735            gpu::Operation::Tma(tma_ops) => {
736                self.flags.inst_tma = true;
737                match tma_ops {
738                    gpu::TmaOps::TmaStore {
739                        source,
740                        coordinates,
741                        offset_source,
742                    } => {
743                        instructions.push(Instruction::MemCopyAsyncTensorSharedToGlobal {
744                            smem_buffer: self.compile_variable(source),
745                            smem_offset: self.compile_variable(offset_source),
746                            tensor_map: self.compile_variable(out.unwrap()),
747                            indices: coordinates
748                                .into_iter()
749                                .map(|it| self.compile_variable(it))
750                                .collect(),
751                        });
752                    }
753                    gpu::TmaOps::CommitGroup => {
754                        instructions.push(Instruction::BulkCommitGroup);
755                    }
756                    gpu::TmaOps::WaitGroup { max_pending } => {
757                        instructions.push(Instruction::BulkWaitGroup { max_pending });
758                    }
759                    gpu::TmaOps::WaitGroupRead { max_pending } => {
760                        instructions.push(Instruction::BulkWaitGroupRead { max_pending });
761                    }
762                }
763            }
764            gpu::Operation::Marker(_) => {}
765        }
766    }
767
768    fn update_debug_loc(
769        &mut self,
770        instructions: &mut Vec<Instruction<D>>,
771        inst: &gpu::Instruction,
772    ) {
773        if !matches!(inst.operation, Operation::NonSemantic(_)) {
774            match &inst.source_loc {
775                Some(loc) if Some(loc) != self.source_loc.as_ref() => {
776                    self.source_loc = Some(loc.clone());
777                    instructions.push(Instruction::Line {
778                        file: loc.source.file.clone(),
779                        line: loc.line,
780                    });
781                }
782                _ => {}
783            }
784        }
785    }
786
    /// Lowers a cooperative-matrix (CMMA) operation to a single WMMA
    /// instruction and registers any dialect-specific extension it needs.
    ///
    /// Sets `flags.inst_wmma` so the kernel knows WMMA support must be
    /// emitted. Every CMMA variant produces an output, hence `out.unwrap()`.
    fn compile_cmma(&mut self, cmma: gpu::CoopMma, out: Option<gpu::Variable>) -> Instruction<D> {
        self.flags.inst_wmma = true;

        let out = self.compile_variable(out.unwrap());

        let inst = match cmma {
            // Fill the whole output fragment with a single value.
            gpu::CoopMma::Fill { value } => WmmaInstruction::Fill {
                frag: out,
                value: self.compile_variable(value),
            },
            // Load a fragment from memory; layout is optional and only
            // forwarded when it maps to a supported matrix layout.
            gpu::CoopMma::Load {
                value,
                stride,
                offset,
                layout,
            } => WmmaInstruction::Load {
                frag: out,
                offset: self.compile_variable(offset),
                value: self.compile_variable(value),
                stride: self.compile_variable(stride),
                layout: layout.and_then(|l| self.compile_matrix_layout(l)),
            },
            // D = A * B + C on whole fragments.
            gpu::CoopMma::Execute {
                mat_a,
                mat_b,
                mat_c,
            } => WmmaInstruction::Execute {
                frag_a: self.compile_variable(mat_a),
                frag_b: self.compile_variable(mat_b),
                frag_c: self.compile_variable(mat_c),
                frag_d: out,
                warp_size: self.compilation_options.warp_size,
            },
            // Manual MMA on raw registers with an explicit m/n/k shape.
            gpu::CoopMma::ExecuteManual {
                matrix,
                registers_a,
                registers_b,
                registers_c,
            } => WmmaInstruction::ExecuteManual {
                shape: MmaShape::new(matrix.m, matrix.n, matrix.k),
                frag_a: self.compile_variable(registers_a),
                frag_b: self.compile_variable(registers_b),
                frag_c: self.compile_variable(registers_c),
                frag_d: out,
            },
            // Manual MMA with per-operand scaling factors (block-scaled MMA).
            gpu::CoopMma::ExecuteScaled {
                matrix,
                registers_a,
                registers_b,
                registers_c,
                scales_a,
                scales_b,
                scales_factor,
            } => WmmaInstruction::ExecuteScaled {
                shape: MmaShape::new(matrix.m, matrix.n, matrix.k),
                frag_a: self.compile_variable(registers_a),
                frag_b: self.compile_variable(registers_b),
                frag_c: self.compile_variable(registers_c),
                frag_d: out,

                scales_a: self.compile_variable(scales_a),
                scales_b: self.compile_variable(scales_b),
                scales_factor,
            },
            gpu::CoopMma::Store {
                mat,
                stride,
                offset,
                layout,
            } => {
                // The store lowering uses per-unit/plane indices, so make sure
                // those builtins get declared in the kernel prelude.
                self.flags.indexes.unit_pos = true;
                self.flags.indexes.plane_index = true;
                WmmaInstruction::Store {
                    output: out,
                    offset: self.compile_variable(offset),
                    frag: self.compile_variable(mat),
                    stride: self.compile_variable(stride),
                    // Unlike `Load`, a store has no default layout to fall
                    // back on.
                    layout: self
                        .compile_matrix_layout(layout)
                        .expect("Layout required for store instruction"),
                }
            }
            // Load matrix registers from a buffer (ldmatrix-style).
            gpu::CoopMma::LoadMatrix {
                buffer,
                offset,
                line_size,
                factor,
                transpose,
            } => WmmaInstruction::LdMatrix {
                output: out,
                buffer: self.compile_variable(buffer),
                offset: self.compile_variable(offset),
                line_size,
                factor,
                transpose,
            },
            // Store matrix registers to a buffer (stmatrix-style); `out` is
            // the destination buffer here.
            gpu::CoopMma::StoreMatrix {
                offset,
                line_size,
                registers,
                factor,
                transpose,
            } => WmmaInstruction::StMatrix {
                registers: self.compile_variable(registers),
                buffer: out,
                offset: self.compile_variable(offset),
                line_size,
                factor,
                transpose,
            },
            // Fragment-to-fragment element type conversion.
            gpu::CoopMma::Cast { input } => WmmaInstruction::Cast {
                input: self.compile_variable(input),
                output: out,
            },
            // These variants must have been rewritten away by an earlier
            // processing pass; reaching them here is a compiler bug.
            gpu::CoopMma::RowIndex { .. } | gpu::CoopMma::ColIndex { .. } => {
                panic!("Row/Col index should be handled by processors")
            }
        };

        // Some dialects need helper code (extensions) for certain WMMA
        // instructions; record that requirement now.
        D::register_wmma_instruction_extension(&mut self.extensions, &inst);

        Instruction::Wmma(inst)
    }
910
911    fn compile_metadata(
912        &mut self,
913        metadata: gpu::Metadata,
914        out: Option<gpu::Variable>,
915    ) -> Instruction<D> {
916        let out = out.unwrap();
917        match metadata {
918            gpu::Metadata::Stride { dim, var } => {
919                let position = self.ext_meta_position(var);
920                let offset = self.metadata.stride_offset_index(position);
921                Instruction::ExtendedMetadata {
922                    info_offset: self.compile_variable(offset.into()),
923                    dim: self.compile_variable(dim),
924                    split_meta: self.compilation_options.supports_features.grid_constants,
925                    static_offset: self.metadata.static_len(),
926                    out: self.compile_variable(out),
927                }
928            }
929            gpu::Metadata::Shape { dim, var } => {
930                let position = self.ext_meta_position(var);
931                let offset = self.metadata.shape_offset_index(position);
932                Instruction::ExtendedMetadata {
933                    info_offset: self.compile_variable(offset.into()),
934                    dim: self.compile_variable(dim),
935                    split_meta: self.compilation_options.supports_features.grid_constants,
936                    static_offset: self.metadata.static_len(),
937                    out: self.compile_variable(out),
938                }
939            }
940            gpu::Metadata::Rank { var } => {
941                let out = self.compile_variable(out);
942                let pos = self.ext_meta_position(var);
943                let offset = self.metadata.rank_index(pos);
944                super::Instruction::Metadata {
945                    info_offset: self.compile_variable(offset.into()),
946                    split_meta: self.compilation_options.supports_features.grid_constants,
947                    out,
948                }
949            }
950            gpu::Metadata::Length { var } => {
951                let input = self.compile_variable(var);
952                let out = self.compile_variable(out);
953
954                match input {
955                    Variable::Slice { .. } => Instruction::SliceLength { input, out },
956                    Variable::SharedArray(_id, _item, length) => {
957                        Instruction::ConstLength { length, out }
958                    }
959                    _ => {
960                        let id = input.id().expect("Variable should have id");
961                        let offset = self.metadata.len_index(id);
962                        Instruction::Metadata {
963                            info_offset: self.compile_variable(offset.into()),
964                            split_meta: self.compilation_options.supports_features.grid_constants,
965                            out,
966                        }
967                    }
968                }
969            }
970            gpu::Metadata::BufferLength { var } => {
971                let input = self.compile_variable(var);
972                let out = self.compile_variable(out);
973
974                match input {
975                    Variable::Slice { .. } => Instruction::SliceLength { input, out },
976                    _ => {
977                        let id = input.id().expect("Variable should have id");
978                        let offset = self.metadata.buffer_len_index(id);
979                        Instruction::Metadata {
980                            info_offset: self.compile_variable(offset.into()),
981                            split_meta: self.compilation_options.supports_features.grid_constants,
982                            out,
983                        }
984                    }
985                }
986            }
987        }
988    }
989
990    fn compile_branch(&mut self, instructions: &mut Vec<Instruction<D>>, branch: gpu::Branch) {
991        match branch {
992            gpu::Branch::If(mut op) => instructions.push(Instruction::If {
993                cond: self.compile_variable(op.cond),
994                instructions: self.compile_scope(&mut op.scope),
995            }),
996            gpu::Branch::IfElse(mut op) => instructions.push(Instruction::IfElse {
997                cond: self.compile_variable(op.cond),
998                instructions_if: self.compile_scope(&mut op.scope_if),
999                instructions_else: self.compile_scope(&mut op.scope_else),
1000            }),
1001            gpu::Branch::Switch(mut op) => instructions.push(Instruction::Switch {
1002                value: self.compile_variable(op.value),
1003                instructions_default: self.compile_scope(&mut op.scope_default),
1004                instructions_cases: op
1005                    .cases
1006                    .into_iter()
1007                    .map(|(val, mut block)| {
1008                        (self.compile_variable(val), self.compile_scope(&mut block))
1009                    })
1010                    .collect(),
1011            }),
1012            gpu::Branch::Return => instructions.push(Instruction::Return),
1013            gpu::Branch::Break => instructions.push(Instruction::Break),
1014            gpu::Branch::RangeLoop(mut range_loop) => instructions.push(Instruction::RangeLoop {
1015                i: self.compile_variable(range_loop.i),
1016                start: self.compile_variable(range_loop.start),
1017                end: self.compile_variable(range_loop.end),
1018                step: range_loop.step.map(|it| self.compile_variable(it)),
1019                inclusive: range_loop.inclusive,
1020                instructions: self.compile_scope(&mut range_loop.scope),
1021            }),
1022            gpu::Branch::Loop(mut op) => instructions.push(Instruction::Loop {
1023                instructions: self.compile_scope(&mut op.scope),
1024            }),
1025        };
1026    }
1027
1028    fn compile_atomic(
1029        &mut self,
1030        value: gpu::AtomicOp,
1031        out: Option<gpu::Variable>,
1032        instructions: &mut Vec<Instruction<D>>,
1033    ) {
1034        let out = out.unwrap();
1035        match value {
1036            gpu::AtomicOp::Load(op) => {
1037                instructions.push(Instruction::AtomicLoad(self.compile_unary(op, out)))
1038            }
1039            gpu::AtomicOp::Store(op) => {
1040                instructions.push(Instruction::AtomicStore(self.compile_unary(op, out)))
1041            }
1042            gpu::AtomicOp::Swap(op) => {
1043                instructions.push(Instruction::AtomicSwap(self.compile_binary(op, out)))
1044            }
1045            gpu::AtomicOp::Add(op) => {
1046                instructions.push(Instruction::AtomicAdd(self.compile_binary(op, out)))
1047            }
1048            gpu::AtomicOp::Sub(op) => {
1049                instructions.push(Instruction::AtomicSub(self.compile_binary(op, out)))
1050            }
1051            gpu::AtomicOp::Max(op) => {
1052                instructions.push(Instruction::AtomicMax(self.compile_binary(op, out)))
1053            }
1054            gpu::AtomicOp::Min(op) => {
1055                instructions.push(Instruction::AtomicMin(self.compile_binary(op, out)))
1056            }
1057            gpu::AtomicOp::And(op) => {
1058                instructions.push(Instruction::AtomicAnd(self.compile_binary(op, out)))
1059            }
1060            gpu::AtomicOp::Or(op) => {
1061                instructions.push(Instruction::AtomicOr(self.compile_binary(op, out)))
1062            }
1063            gpu::AtomicOp::Xor(op) => {
1064                instructions.push(Instruction::AtomicXor(self.compile_binary(op, out)))
1065            }
1066            gpu::AtomicOp::CompareAndSwap(op) => instructions.push(Instruction::AtomicCAS {
1067                input: self.compile_variable(op.input),
1068                cmp: self.compile_variable(op.cmp),
1069                val: self.compile_variable(op.val),
1070                out: self.compile_variable(out),
1071            }),
1072        }
1073    }
1074
    /// Lowers an arithmetic operation into the instruction stream.
    ///
    /// For several float operations a fast (approximate) variant may be
    /// selected through `select_fast_float` when the instruction's fast-math
    /// `modes` allow it; otherwise the precise instruction is emitted. Some
    /// arms additionally register dialect-specific extensions via
    /// `D::register_instruction_extension` before pushing the instruction.
    fn compile_arithmetic(
        &mut self,
        value: gpu::Arithmetic,
        out: Option<gpu::Variable>,
        modes: InstructionModes,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        // Arithmetic ops always produce an output.
        let out = out.unwrap();
        match value {
            gpu::Arithmetic::Add(op) => {
                instructions.push(Instruction::Add(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::SaturatingAdd(op) => {
                instructions.push(Instruction::SaturatingAdd(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Mul(op) => {
                instructions.push(Instruction::Mul(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Div(op) => {
                // Fast division is only legal when the mode permits reciprocal
                // approximation and relaxed zero/inf semantics.
                let op = self.compile_binary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::AllowReciprocal
                        | FastMath::ReducedPrecision
                        | FastMath::UnsignedZero
                        | FastMath::NotInf,
                    Instruction::Div(op),
                    Instruction::FastDiv(op),
                ))
            }
            gpu::Arithmetic::Sub(op) => {
                instructions.push(Instruction::Sub(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::SaturatingSub(op) => {
                instructions.push(Instruction::SaturatingSub(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::MulHi(op) => {
                let instruction = Instruction::HiMul(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Modulo(op) => {
                instructions.push(Instruction::Modulo(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Abs(op) => {
                instructions.push(Instruction::Abs(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Exp(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Exp(op),
                    Instruction::FastExp(op),
                ));
            }
            gpu::Arithmetic::Log(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Log(op),
                    Instruction::FastLog(op),
                ));
            }
            gpu::Arithmetic::Log1p(op) => {
                instructions.push(Instruction::Log1p(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Cos(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Cos(op),
                    Instruction::FastCos(op),
                ));
            }
            gpu::Arithmetic::Sin(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Sin(op),
                    Instruction::FastSin(op),
                ));
            }
            gpu::Arithmetic::Tan(op) => {
                instructions.push(Instruction::Tan(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Tanh(op) => {
                let op = self.compile_unary(op, out);
                let instruction = Instruction::Tanh(op);
                D::register_instruction_extension(&mut self.extensions, &instruction);
                // The fast tanh builtin is only present on some targets, so
                // it is additionally gated on a support flag.
                if self.compilation_options.supports_features.fast_tanh {
                    instructions.push(self.select_fast_float(
                        out.ty,
                        modes,
                        FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                        instruction,
                        Instruction::FastTanh(op),
                    ))
                } else {
                    instructions.push(instruction);
                }
            }
            gpu::Arithmetic::Sinh(op) => {
                let instruction = Instruction::Sinh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Cosh(op) => {
                let instruction = Instruction::Cosh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcCos(op) => {
                let instruction = Instruction::ArcCos(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcSin(op) => {
                let instruction = Instruction::ArcSin(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcTan(op) => {
                let instruction = Instruction::ArcTan(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcSinh(op) => {
                let instruction = Instruction::ArcSinh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcCosh(op) => {
                let instruction = Instruction::ArcCosh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcTanh(op) => {
                let instruction = Instruction::ArcTanh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Degrees(op) => {
                let instruction = Instruction::Degrees(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Radians(op) => {
                let instruction = Instruction::Radians(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcTan2(op) => {
                let instruction = Instruction::ArcTan2(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Powf(op) => {
                let op = self.compile_binary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Powf(op),
                    Instruction::FastPowf(op),
                ))
            }
            gpu::Arithmetic::Powi(op) => {
                instructions.push(Instruction::Powi(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Sqrt(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Sqrt(op),
                    Instruction::FastSqrt(op),
                ))
            }
            gpu::Arithmetic::InverseSqrt(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::InverseSqrt(op),
                    Instruction::FastInverseSqrt(op),
                ))
            }
            gpu::Arithmetic::Erf(op) => {
                let instruction = Instruction::Erf(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Max(op) => {
                let instruction = Instruction::Max(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Min(op) => {
                let instruction = Instruction::Min(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Clamp(op) => instructions.push(Instruction::Clamp {
                input: self.compile_variable(op.input),
                min_value: self.compile_variable(op.min_value),
                max_value: self.compile_variable(op.max_value),
                out: self.compile_variable(out),
            }),
            gpu::Arithmetic::Recip(op) => {
                // Reciprocal is lowered as either an exact `1 / x` division or
                // a fast reciprocal intrinsic, depending on fast-math modes.
                let elem = op.input.ty.elem_type();
                let input = self.compile_variable(op.input);
                let out = self.compile_variable(out);
                // Build the constant `1` with the same element type as input.
                let lhs = match elem {
                    gpu::ElemType::Float(kind) => gpu::ConstantScalarValue::Float(1.0, kind),
                    gpu::ElemType::Int(kind) => gpu::ConstantScalarValue::Int(1, kind),
                    gpu::ElemType::UInt(kind) => gpu::ConstantScalarValue::UInt(1, kind),
                    gpu::ElemType::Bool => gpu::ConstantScalarValue::Bool(true),
                };
                let div = Instruction::Div(BinaryInstruction {
                    lhs: Variable::ConstantScalar(lhs, self.compile_elem(elem)),
                    rhs: input,
                    out,
                });
                let recip = Instruction::FastRecip(UnaryInstruction { input, out });

                instructions.push(self.select_fast_float(
                    elem.into(),
                    modes,
                    FastMath::AllowReciprocal
                        | FastMath::ReducedPrecision
                        | FastMath::UnsignedZero
                        | FastMath::NotInf,
                    div,
                    recip,
                ))
            }
            gpu::Arithmetic::Round(op) => {
                instructions.push(Instruction::Round(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Floor(op) => {
                instructions.push(Instruction::Floor(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Ceil(op) => {
                instructions.push(Instruction::Ceil(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Trunc(op) => {
                instructions.push(Instruction::Trunc(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Remainder(op) => {
                instructions.push(Instruction::Remainder(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Fma(op) => instructions.push(Instruction::Fma {
                a: self.compile_variable(op.a),
                b: self.compile_variable(op.b),
                c: self.compile_variable(op.c),
                out: self.compile_variable(out),
            }),
            gpu::Arithmetic::Neg(op) => {
                instructions.push(Instruction::Neg(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Normalize(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Normalize(op),
                    Instruction::FastNormalize(op),
                ))
            }
            gpu::Arithmetic::Magnitude(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Magnitude(op),
                    Instruction::FastMagnitude(op),
                ))
            }
            gpu::Arithmetic::Dot(op) => {
                instructions.push(Instruction::Dot(self.compile_binary(op, out)))
            }
        };
    }
1371
1372    fn select_fast_float(
1373        &self,
1374        ty: gpu::Type,
1375        modes: InstructionModes,
1376        required_flags: EnumSet<FastMath>,
1377        default: Instruction<D>,
1378        fast: Instruction<D>,
1379    ) -> Instruction<D> {
1380        if !self.compilation_options.supports_features.fast_math
1381            || !matches!(ty.elem_type(), ElemType::Float(FloatKind::F32))
1382        {
1383            return default;
1384        }
1385
1386        if modes.fp_math_mode.is_superset(required_flags) {
1387            fast
1388        } else {
1389            default
1390        }
1391    }
1392
1393    fn compile_comparison(
1394        &mut self,
1395        value: gpu::Comparison,
1396        out: Option<gpu::Variable>,
1397        instructions: &mut Vec<Instruction<D>>,
1398    ) {
1399        let out = out.unwrap();
1400        match value {
1401            gpu::Comparison::Equal(op) => {
1402                instructions.push(Instruction::Equal(self.compile_binary(op, out)))
1403            }
1404            gpu::Comparison::Lower(op) => {
1405                instructions.push(Instruction::Lower(self.compile_binary(op, out)))
1406            }
1407            gpu::Comparison::Greater(op) => {
1408                instructions.push(Instruction::Greater(self.compile_binary(op, out)))
1409            }
1410            gpu::Comparison::LowerEqual(op) => {
1411                instructions.push(Instruction::LowerEqual(self.compile_binary(op, out)))
1412            }
1413            gpu::Comparison::GreaterEqual(op) => {
1414                instructions.push(Instruction::GreaterEqual(self.compile_binary(op, out)))
1415            }
1416            gpu::Comparison::NotEqual(op) => {
1417                instructions.push(Instruction::NotEqual(self.compile_binary(op, out)))
1418            }
1419            gpu::Comparison::IsNan(op) => {
1420                instructions.push(Instruction::IsNan(self.compile_unary(op, out)))
1421            }
1422            gpu::Comparison::IsInf(op) => {
1423                instructions.push(Instruction::IsInf(self.compile_unary(op, out)))
1424            }
1425        };
1426    }
1427
1428    fn compile_bitwise(
1429        &mut self,
1430        value: gpu::Bitwise,
1431        out: Option<gpu::Variable>,
1432        instructions: &mut Vec<Instruction<D>>,
1433    ) {
1434        let out = out.unwrap();
1435        match value {
1436            gpu::Bitwise::BitwiseOr(op) => {
1437                instructions.push(Instruction::BitwiseOr(self.compile_binary(op, out)))
1438            }
1439            gpu::Bitwise::BitwiseAnd(op) => {
1440                instructions.push(Instruction::BitwiseAnd(self.compile_binary(op, out)))
1441            }
1442            gpu::Bitwise::BitwiseXor(op) => {
1443                instructions.push(Instruction::BitwiseXor(self.compile_binary(op, out)))
1444            }
1445            gpu::Bitwise::CountOnes(op) => {
1446                instructions.push(Instruction::CountBits(self.compile_unary(op, out)))
1447            }
1448            gpu::Bitwise::ReverseBits(op) => {
1449                instructions.push(Instruction::ReverseBits(self.compile_unary(op, out)))
1450            }
1451            gpu::Bitwise::ShiftLeft(op) => {
1452                instructions.push(Instruction::ShiftLeft(self.compile_binary(op, out)))
1453            }
1454            gpu::Bitwise::ShiftRight(op) => {
1455                instructions.push(Instruction::ShiftRight(self.compile_binary(op, out)))
1456            }
1457            gpu::Bitwise::BitwiseNot(op) => {
1458                instructions.push(Instruction::BitwiseNot(self.compile_unary(op, out)))
1459            }
1460            gpu::Bitwise::LeadingZeros(op) => {
1461                instructions.push(Instruction::LeadingZeros(self.compile_unary(op, out)))
1462            }
1463            gpu::Bitwise::FindFirstSet(op) => {
1464                let instruction = Instruction::FindFirstSet(self.compile_unary(op, out));
1465                D::register_instruction_extension(&mut self.extensions, &instruction);
1466                instructions.push(instruction)
1467            }
1468        };
1469    }
1470
    /// Lower a general operator (indexing, logic, vector init, memory copies,
    /// select, casts, reinterpret) into one or more C++ instructions appended
    /// to `instructions`.
    ///
    /// Panics if `out` is `None`: every operator handled here writes a result.
    fn compile_operator(
        &mut self,
        value: gpu::Operator,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            // NOTE(review): checked and unchecked variants lower identically here —
            // presumably bounds checks are inserted by an earlier pass (see the
            // `CheckedIoProcessor` import); confirm before relying on it.
            gpu::Operator::Index(op) | gpu::Operator::UncheckedIndex(op) => {
                instructions.push(Instruction::Index(self.compile_index(op, out)));
            }
            gpu::Operator::IndexAssign(op) | gpu::Operator::UncheckedIndexAssign(op) => {
                instructions.push(Instruction::IndexAssign(self.compile_index_assign(op, out)));
            }
            gpu::Operator::And(op) => {
                instructions.push(Instruction::And(self.compile_binary(op, out)))
            }
            gpu::Operator::Or(op) => {
                instructions.push(Instruction::Or(self.compile_binary(op, out)))
            }
            gpu::Operator::Not(op) => {
                instructions.push(Instruction::Not(self.compile_unary(op, out)))
            }
            // Build a vector value from its individual components.
            gpu::Operator::InitLine(op) => instructions.push(Instruction::VecInit {
                inputs: op
                    .inputs
                    .into_iter()
                    .map(|it| self.compile_variable(it))
                    .collect(),
                out: self.compile_variable(out),
            }),
            gpu::Operator::CopyMemory(op) => instructions.push(Instruction::Copy {
                input: self.compile_variable(op.input),
                in_index: self.compile_variable(op.in_index),
                out: self.compile_variable(out),
                out_index: self.compile_variable(op.out_index),
            }),
            // Bulk copy: same as `Copy` but moves `len` elements at once.
            gpu::Operator::CopyMemoryBulk(op) => instructions.push(Instruction::CopyBulk {
                input: self.compile_variable(op.input),
                in_index: self.compile_variable(op.in_index),
                out: self.compile_variable(out),
                out_index: self.compile_variable(op.out_index),
                len: op.len,
            }),
            gpu::Operator::Select(op) => instructions.push(Instruction::Select {
                cond: self.compile_variable(op.cond),
                then: self.compile_variable(op.then),
                or_else: self.compile_variable(op.or_else),
                out: self.compile_variable(out),
            }),
            // Needs special conversion semantics
            gpu::Operator::Cast(op)
                if is_fp4_fp6_fp8(op.input.elem_type()) || is_fp4_fp6_fp8(out.elem_type()) =>
            {
                // We may need these for intermediates
                self.flags.elem_f16 = true;
                self.flags.elem_bf16 = true;
                let vec_in = op.input.ty.line_size();
                let packing = out.storage_type().packing_factor();
                // Pre-register every item shape the conversion may route through:
                // the input line at the output's packing factor, plus f16/bf16
                // lines at both the input and output widths.
                self.compile_type(op.input.ty.line(packing));
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::F16)).line(vec_in),
                );
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::BF16)).line(vec_in),
                );
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::F16)).line(packing),
                );
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::BF16)).line(packing),
                );

                let inst = self.compile_unary(op, out);

                instructions.push(Instruction::SpecialCast(inst));
            }
            // Ordinary casts compile to a plain assignment; the C++ side relies
            // on implicit/constructor conversion between the two element types.
            gpu::Operator::Cast(op) => {
                let op = self.compile_unary(op, out);

                // TF32 involvement requires the tf32 type support flag.
                if op.input.elem() == Elem::TF32 || op.out.elem() == Elem::TF32 {
                    self.flags.elem_tf32 = true;
                }

                instructions.push(Instruction::Assign(op))
            }
            // Reinterpret is a bit-level cast: no value conversion.
            gpu::Operator::Reinterpret(op) => {
                instructions.push(Instruction::Bitcast(self.compile_unary(op, out)))
            }
        };
    }
1562
1563    fn compile_binary(
1564        &mut self,
1565        value: gpu::BinaryOperator,
1566        out: gpu::Variable,
1567    ) -> BinaryInstruction<D> {
1568        BinaryInstruction {
1569            lhs: self.compile_variable(value.lhs),
1570            rhs: self.compile_variable(value.rhs),
1571            out: self.compile_variable(out),
1572        }
1573    }
1574
1575    fn compile_index_assign(
1576        &mut self,
1577        value: gpu::IndexAssignOperator,
1578        out: gpu::Variable,
1579    ) -> IndexAssignInstruction<D> {
1580        IndexAssignInstruction {
1581            index: self.compile_variable(value.index),
1582            value: self.compile_variable(value.value),
1583            line_size: value.line_size,
1584            out: self.compile_variable(out),
1585        }
1586    }
1587
1588    fn compile_index(
1589        &mut self,
1590        value: gpu::IndexOperator,
1591        out: gpu::Variable,
1592    ) -> IndexInstruction<D> {
1593        IndexInstruction {
1594            list: self.compile_variable(value.list),
1595            index: self.compile_variable(value.index),
1596            line_size: value.line_size,
1597            out: self.compile_variable(out),
1598        }
1599    }
1600
1601    fn compile_unary(
1602        &mut self,
1603        value: gpu::UnaryOperator,
1604        out: gpu::Variable,
1605    ) -> UnaryInstruction<D> {
1606        UnaryInstruction {
1607            input: self.compile_variable(value.input),
1608            out: self.compile_variable(out),
1609        }
1610    }
1611
    /// Lower an IR variable to its C++ representation.
    ///
    /// Besides the pure mapping, this records side information needed to emit
    /// the kernel: item/type registration (via `compile_type`), feature flags
    /// (TMA, WMMA, pipelines, barriers), the set of builtin indexes that must
    /// be computed in the kernel prelude, and one-time declarations for local
    /// arrays and pipelines.
    fn compile_variable(&mut self, value: gpu::Variable) -> Variable<D> {
        let item = value.ty;
        match value.kind {
            gpu::VariableKind::GlobalInputArray(id) => {
                Variable::GlobalInputArray(id, self.compile_type(item))
            }
            gpu::VariableKind::GlobalScalar(id) => Variable::GlobalScalar {
                id,
                elem: self.compile_storage_type(item.storage_type()),
                // Scalars are grouped in a struct when grid constants are supported.
                in_struct: self.compilation_options.supports_features.grid_constants,
            },
            gpu::VariableKind::TensorMapInput(id) => {
                self.flags.inst_tma = true;
                Variable::TensorMap(id)
            }
            gpu::VariableKind::TensorMapOutput(id) => {
                self.flags.inst_tma = true;
                Variable::TensorMap(id)
            }
            gpu::VariableKind::LocalMut { id } => Variable::LocalMut {
                id,
                item: self.compile_type(item),
            },
            // SSA-versioned variables all collapse onto one mutable local.
            gpu::VariableKind::Versioned { id, .. } => Variable::LocalMut {
                id,
                item: self.compile_type(item),
            },
            gpu::VariableKind::LocalConst { id } => Variable::LocalConst {
                id,
                item: self.compile_type(item),
            },
            gpu::VariableKind::GlobalOutputArray(id) => {
                Variable::GlobalOutputArray(id, self.compile_type(item))
            }
            gpu::VariableKind::ConstantScalar(value) => {
                Variable::ConstantScalar(value, self.compile_elem(value.elem_type()))
            }
            gpu::VariableKind::SharedArray { id, length, .. } => {
                let item = self.compile_type(item);
                Variable::SharedArray(id, item, length)
            }
            gpu::VariableKind::Shared { id } => {
                let item = self.compile_type(item);
                Variable::Shared(id, item)
            }
            gpu::VariableKind::ConstantArray {
                id,
                length,
                unroll_factor,
            } => {
                let item = self.compile_type(item);
                // The physical array holds `unroll_factor` copies of the data.
                Variable::ConstantArray(id, item, length * unroll_factor)
            }
            // Builtins flip the matching index flag so the kernel prelude only
            // declares and computes the indexes that are actually used.
            gpu::VariableKind::Builtin(builtin) => match builtin {
                gpu::Builtin::AbsolutePos => {
                    self.flags.indexes.absolute_pos = true;
                    Variable::AbsolutePos
                }
                gpu::Builtin::CubePosCluster
                    if self.compilation_options.supports_features.clusters =>
                {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterRank
                }
                gpu::Builtin::CubePosClusterX
                    if self.compilation_options.supports_features.clusters =>
                {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexX
                }
                gpu::Builtin::CubePosClusterY
                    if self.compilation_options.supports_features.clusters =>
                {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexY
                }
                gpu::Builtin::CubePosClusterZ
                    if self.compilation_options.supports_features.clusters =>
                {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexZ
                }
                // Fallback if clusters aren't supported, ID is always 0 since clusters are always
                // (1, 1, 1) if unsupported
                gpu::Builtin::CubePosCluster
                | gpu::Builtin::CubePosClusterX
                | gpu::Builtin::CubePosClusterY
                | gpu::Builtin::CubePosClusterZ => const_u32(0),
                gpu::Builtin::AbsolutePosX => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosX
                }
                gpu::Builtin::AbsolutePosY => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosY
                }
                gpu::Builtin::AbsolutePosZ => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosZ
                }
                gpu::Builtin::CubeDim => {
                    self.flags.indexes.cube_dim = true;
                    Variable::CubeDim
                }
                gpu::Builtin::CubeDimX => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimX
                }
                gpu::Builtin::CubeDimY => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimY
                }
                gpu::Builtin::CubeDimZ => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimZ
                }
                // Cluster dimensions are compile-time constants of the launch config.
                gpu::Builtin::CubeClusterDim => const_u32(self.cluster_dim.num_elems()),
                gpu::Builtin::CubeClusterDimX => const_u32(self.cluster_dim.x),
                gpu::Builtin::CubeClusterDimY => const_u32(self.cluster_dim.y),
                gpu::Builtin::CubeClusterDimZ => const_u32(self.cluster_dim.z),
                gpu::Builtin::CubePos => {
                    self.flags.indexes.cube_pos = true;
                    Variable::CubePos
                }
                gpu::Builtin::CubePosX => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosX
                }
                gpu::Builtin::CubePosY => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosY
                }
                gpu::Builtin::CubePosZ => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosZ
                }
                gpu::Builtin::CubeCount => {
                    self.flags.indexes.cube_count = true;
                    Variable::CubeCount
                }
                gpu::Builtin::CubeCountX => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountX
                }
                gpu::Builtin::CubeCountY => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountY
                }
                gpu::Builtin::CubeCountZ => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountZ
                }
                gpu::Builtin::UnitPos => {
                    self.flags.indexes.unit_pos = true;
                    Variable::UnitPos
                }
                gpu::Builtin::UnitPosX => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosX
                }
                gpu::Builtin::UnitPosY => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosY
                }
                gpu::Builtin::UnitPosZ => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosZ
                }
                gpu::Builtin::PlaneDim => {
                    self.flags.indexes.plane_dim = true;
                    Variable::PlaneDim
                }
                gpu::Builtin::UnitPosPlane => {
                    self.flags.indexes.unit_pos_plane = true;
                    Variable::UnitPosPlane
                }
            },
            gpu::VariableKind::LocalArray {
                id,
                length,
                unroll_factor,
            } => {
                let item = self.compile_type(item);
                // Declare the array once; later references reuse the declaration.
                if !self.local_arrays.iter().any(|s| s.index == id) {
                    self.local_arrays
                        .push(LocalArray::new(id, item, length * unroll_factor));
                }
                Variable::LocalArray(id, item, length)
            }
            gpu::VariableKind::Matrix { id, mat } => {
                self.flags.inst_wmma = true;
                Variable::WmmaFragment {
                    id,
                    frag: self.compile_matrix(mat),
                }
            }
            gpu::VariableKind::Pipeline { id, num_stages } => {
                self.flags.op_pipeline = true;
                let pipeline = Variable::Pipeline { id };
                // Emit the pipeline initialization only once per id.
                if !self.pipelines.iter().any(|s| s.pipeline_id() == id) {
                    self.pipelines.push(PipelineOps::Init {
                        pipeline,
                        num_stages,
                    });
                }
                pipeline
            }
            gpu::VariableKind::BarrierToken { id, level } => {
                self.flags.op_barrier = true;
                Variable::BarrierToken { id, level }
            }
        }
    }
1825
1826    fn compile_matrix(&mut self, matrix: gpu::Matrix) -> Fragment<D> {
1827        Fragment {
1828            ident: self.compile_matrix_ident(matrix.ident),
1829            m: matrix.m,
1830            n: matrix.n,
1831            k: matrix.k,
1832            elem: self.compile_storage_type(matrix.storage),
1833            layout: self.compile_matrix_layout(matrix.layout),
1834        }
1835    }
1836
1837    fn compile_matrix_ident(&mut self, ident: gpu::MatrixIdent) -> FragmentIdent<D> {
1838        match ident {
1839            gpu::MatrixIdent::A => FragmentIdent::A,
1840            gpu::MatrixIdent::B => FragmentIdent::B,
1841            gpu::MatrixIdent::Accumulator => FragmentIdent::Accumulator,
1842        }
1843    }
1844
1845    fn compile_matrix_layout(&mut self, layout: gpu::MatrixLayout) -> Option<FragmentLayout<D>> {
1846        match layout {
1847            gpu::MatrixLayout::ColMajor => Some(FragmentLayout::ColMajor),
1848            gpu::MatrixLayout::RowMajor => Some(FragmentLayout::RowMajor),
1849            gpu::MatrixLayout::Undefined => None,
1850        }
1851    }
1852
1853    fn compile_binding(&mut self, binding: cubecl_runtime::kernel::Binding) -> Binding<D> {
1854        Binding {
1855            id: binding.id,
1856            item: self.compile_type(binding.ty),
1857            location: binding.location,
1858            size: binding.size,
1859            vis: binding.visibility,
1860        }
1861    }
1862
1863    fn compile_type(&mut self, ty: gpu::Type) -> Item<D> {
1864        let item = match ty {
1865            gpu::Type::Scalar(ty) => Item::new(self.compile_storage_type(ty), 1, false),
1866            gpu::Type::Line(ty, line_size) => {
1867                Item::new(self.compile_storage_type(ty), line_size as usize, false)
1868            }
1869            gpu::Type::Semantic(_) => Item::new(Elem::Bool, 1, true),
1870        };
1871        if item.elem != super::Elem::TF32 {
1872            self.items.insert(item);
1873            self.items.insert(item.optimized());
1874        } else {
1875            // TF32 is represented as `float` in C++
1876            let mut item = item;
1877            item.elem = super::Elem::F32;
1878            self.items.insert(item);
1879        }
1880
1881        item
1882    }
1883
1884    fn compile_storage_type(&mut self, value: gpu::StorageType) -> Elem<D> {
1885        match value {
1886            gpu::StorageType::Scalar(ty) => self.compile_elem(ty),
1887            gpu::StorageType::Atomic(ty) => Elem::Atomic(ty.into()),
1888            gpu::StorageType::Packed(gpu::ElemType::Float(kind), 2) => match kind {
1889                FloatKind::E2M1 => {
1890                    self.flags.elem_fp4 = true;
1891                    Elem::FP4x2(FP4Kind::E2M1)
1892                }
1893                FloatKind::E2M3 => {
1894                    self.flags.elem_fp6 = true;
1895                    Elem::FP6x2(FP6Kind::E2M3)
1896                }
1897                FloatKind::E3M2 => {
1898                    self.flags.elem_fp6 = true;
1899                    Elem::FP6(FP6Kind::E3M2)
1900                }
1901                FloatKind::E4M3 => {
1902                    self.flags.elem_fp8 = true;
1903                    Elem::FP8x2(FP8Kind::E4M3)
1904                }
1905                FloatKind::E5M2 => {
1906                    self.flags.elem_fp8 = true;
1907                    Elem::FP8x2(FP8Kind::E5M2)
1908                }
1909                FloatKind::UE8M0 => {
1910                    self.flags.elem_fp8 = true;
1911                    Elem::FP8x2(FP8Kind::UE8M0)
1912                }
1913                FloatKind::F16 => {
1914                    self.flags.elem_f16 = true;
1915                    Elem::F16x2
1916                }
1917                FloatKind::BF16 => {
1918                    self.flags.elem_bf16 = true;
1919                    Elem::BF16x2
1920                }
1921                other => unimplemented!("Unsupported storage type: packed<{other:?}, 2>"),
1922            },
1923            gpu::StorageType::Packed(other, factor) => {
1924                unimplemented!("Unsupported storage type: packed<{other}, {factor}>")
1925            }
1926            gpu::StorageType::Opaque(ty) => match ty {
1927                gpu::OpaqueType::Barrier(level) => {
1928                    self.flags.op_barrier = true;
1929                    Elem::Barrier(level)
1930                }
1931            },
1932        }
1933    }
1934
1935    fn compile_elem(&mut self, value: gpu::ElemType) -> Elem<D> {
1936        match value {
1937            gpu::ElemType::Float(kind) => match kind {
1938                gpu::FloatKind::E2M1 => {
1939                    self.flags.elem_fp4 = true;
1940                    Elem::FP4(FP4Kind::E2M1)
1941                }
1942                gpu::FloatKind::E2M3 => {
1943                    self.flags.elem_fp6 = true;
1944                    Elem::FP6(FP6Kind::E2M3)
1945                }
1946                gpu::FloatKind::E3M2 => {
1947                    self.flags.elem_fp6 = true;
1948                    Elem::FP6(FP6Kind::E3M2)
1949                }
1950                gpu::FloatKind::E4M3 => {
1951                    self.flags.elem_fp8 = true;
1952                    Elem::FP8(FP8Kind::E4M3)
1953                }
1954                gpu::FloatKind::E5M2 => {
1955                    self.flags.elem_fp8 = true;
1956                    Elem::FP8(FP8Kind::E5M2)
1957                }
1958                gpu::FloatKind::UE8M0 => {
1959                    self.flags.elem_fp8 = true;
1960                    Elem::FP8(FP8Kind::UE8M0)
1961                }
1962                gpu::FloatKind::F16 => {
1963                    self.flags.elem_f16 = true;
1964                    Elem::F16
1965                }
1966                gpu::FloatKind::BF16 => {
1967                    self.flags.elem_bf16 = true;
1968                    Elem::BF16
1969                }
1970                gpu::FloatKind::TF32 => Elem::TF32,
1971                gpu::FloatKind::Flex32 => Elem::F32,
1972                gpu::FloatKind::F32 => Elem::F32,
1973                gpu::FloatKind::F64 => Elem::F64,
1974            },
1975            gpu::ElemType::Int(kind) => match kind {
1976                gpu::IntKind::I8 => Elem::I8,
1977                gpu::IntKind::I16 => Elem::I16,
1978                gpu::IntKind::I32 => Elem::I32,
1979                gpu::IntKind::I64 => Elem::I64,
1980            },
1981            gpu::ElemType::UInt(kind) => match kind {
1982                gpu::UIntKind::U8 => Elem::U8,
1983                gpu::UIntKind::U16 => Elem::U16,
1984                gpu::UIntKind::U32 => Elem::U32,
1985                gpu::UIntKind::U64 => Elem::U64,
1986            },
1987            gpu::ElemType::Bool => Elem::Bool,
1988        }
1989    }
1990}
1991
1992fn is_fp4_fp6_fp8(elem: gpu::ElemType) -> bool {
1993    match elem {
1994        gpu::ElemType::Float(kind) => matches!(
1995            kind,
1996            FloatKind::E2M1
1997                | FloatKind::E2M3
1998                | FloatKind::E3M2
1999                | FloatKind::E4M3
2000                | FloatKind::E5M2
2001                | FloatKind::UE8M0
2002        ),
2003        _ => false,
2004    }
2005}
2006
2007fn const_u32<D: Dialect>(value: u32) -> Variable<D> {
2008    Variable::ConstantScalar(
2009        gpu::ConstantScalarValue::UInt(value as u64, UIntKind::U32),
2010        Elem::U32,
2011    )
2012}
2013
2014pub fn register_supported_types(props: &mut DeviceProperties) {
2015    let supported_types = [
2016        gpu::ElemType::UInt(gpu::UIntKind::U8),
2017        gpu::ElemType::UInt(gpu::UIntKind::U16),
2018        gpu::ElemType::UInt(gpu::UIntKind::U32),
2019        gpu::ElemType::UInt(gpu::UIntKind::U64),
2020        gpu::ElemType::Int(gpu::IntKind::I8),
2021        gpu::ElemType::Int(gpu::IntKind::I16),
2022        gpu::ElemType::Int(gpu::IntKind::I32),
2023        gpu::ElemType::Int(gpu::IntKind::I64),
2024        gpu::ElemType::Float(gpu::FloatKind::BF16),
2025        gpu::ElemType::Float(gpu::FloatKind::F16),
2026        gpu::ElemType::Float(gpu::FloatKind::F32),
2027        gpu::ElemType::Float(gpu::FloatKind::Flex32),
2028        // Causes CUDA_ERROR_INVALID_VALUE for matmul, disabling until that can be investigated
2029        //gpu::Elem::Float(gpu::FloatKind::F64),
2030        gpu::ElemType::Bool,
2031    ];
2032
2033    let supported_atomic_types = [
2034        gpu::ElemType::Int(gpu::IntKind::I32),
2035        gpu::ElemType::Int(gpu::IntKind::I64),
2036        gpu::ElemType::UInt(gpu::UIntKind::U32),
2037        gpu::ElemType::UInt(gpu::UIntKind::U64),
2038        gpu::ElemType::Float(gpu::FloatKind::F32),
2039    ];
2040
2041    for ty in supported_types {
2042        props.register_type_usage(ty, TypeUsage::all_scalar());
2043    }
2044
2045    for ty in supported_atomic_types {
2046        props.register_type_usage(
2047            gpu::StorageType::Atomic(ty),
2048            TypeUsage::AtomicAdd | TypeUsage::AtomicLoadStore,
2049        );
2050    }
2051}