// cubecl_cpp/shared/base.rs
use super::{
    BinaryInstruction, Body, Component, ComputeKernel, ConstArray, Dialect, Elem, FP4Kind, FP6Kind,
    FP8Kind, Fragment, FragmentIdent, FragmentLayout, IndexAssignInstruction, IndexInstruction,
    Instruction, Item, KernelArg, LocalArray, SharedMemory, UnaryInstruction, Variable,
    WarpInstruction, WmmaInstruction, barrier::BarrierOps, pipeline::PipelineOps,
};
use crate::shared::MmaShape;
use cubecl_common::backtrace::BackTrace;
use cubecl_core::{
    CubeDim,
    ir::{
        self as gpu, DeviceProperties, ElemType, FloatKind, InstructionModes, OpaqueType,
        Operation, Processor, SourceLoc, StorageType, Type,
        features::{AtomicUsage, EnumSet, TypeUsage},
    },
    post_processing::checked_io::CheckedIoProcessor,
    prelude::{FastMath, KernelDefinition},
    server::ExecutionMode,
};
use cubecl_opt::{Optimizer, SharedLiveness};
use cubecl_runtime::compiler::{CompilationError, Compiler};
use std::{collections::HashSet, fmt::Debug};
23
/// Global counter for compiler-generated temporary variables (presumably used to
/// derive unique names for them — confirm at the increment sites). Reset to 0 at
/// the end of each successful `Compiler::compile`, so numbering restarts per kernel.
pub(super) static COUNTER_TMP_VAR: std::sync::atomic::AtomicU32 =
    std::sync::atomic::AtomicU32::new(0);
26
/// Target-specific options controlling C++ kernel compilation.
#[derive(Clone, Debug)]
pub struct CompilationOptions {
    /// Number of threads in a warp/plane on the target device (defaults to 32).
    pub warp_size: u32,
    /// Optional hardware/toolchain features available on the target.
    pub supports_features: CppSupportedFeatures,
}
32
/// Feature switches describing what the target toolchain/hardware supports.
/// All default to `false` via the derived `Default`.
#[derive(Clone, Debug, Default)]
pub struct CppSupportedFeatures {
    /// Grid-constant kernel parameters are supported (when false, info is
    /// passed by pointer — see `Body::info_by_ptr`).
    pub grid_constants: bool,
    /// Thread-block clusters are supported; when false, `cluster_dim` is
    /// stripped from the kernel options during compilation.
    pub clusters: bool,
    // Fast-math toggles; set-sites are outside this file — confirm semantics there.
    pub fast_math: bool,
    pub fast_tanh: bool,
    /// `elect` with sync is available for `Plane::Elect`; otherwise a fallback
    /// warp instruction is emitted.
    pub elect_sync: bool,
}
41
42impl Default for CompilationOptions {
43    fn default() -> Self {
44        Self {
45            warp_size: 32,
46            supports_features: Default::default(),
47        }
48    }
49}
50
/// Cube indexes flags.
/// When true the corresponding index is declared and computed as needed in the kernel.
///
/// These flags are gathered while translating the Cube IR (e.g. compiling a
/// plane operation sets `plane_dim_checked`) and are post-processed by
/// `Dialect::builtin_rules` before code generation.
#[derive(Debug, Clone, Default)]
pub struct CubeIndexFlags {
    pub absolute_pos: bool,
    pub absolute_pos_tuple: bool,
    pub cube_count: bool,
    pub cube_count_tuple: bool,
    pub cube_dim: bool,
    pub cube_dim_tuple: bool,
    pub cube_pos: bool,
    pub cube_pos_tuple: bool,
    pub plane_dim: bool,
    pub plane_dim_checked: bool,
    pub plane_pos: bool,
    pub unit_pos: bool,
    pub unit_pos_tuple: bool,
    // Set for plane scan operations (inclusive/exclusive sum and prod).
    pub unit_pos_plane: bool,
    pub cluster_pos: bool,
}
71
/// Flags gathered during Cube IR translation for the kernel compilation.
///
/// The `elem_*` flags record which element types the kernel uses, the
/// `op_*`/`inst_*` flags record which operation/instruction families were
/// emitted, and the remaining fields carry launch and metadata information
/// needed by the code generator.
#[derive(Debug, Clone)]
pub struct Flags<D: Dialect> {
    pub elem_fp4: bool,
    pub elem_fp6: bool,
    pub elem_fp8: bool,
    pub elem_bf16: bool,
    pub elem_f16: bool,
    pub elem_tf32: bool,
    /// Which built-in index variables must be declared/computed in the kernel.
    pub indexes: CubeIndexFlags,
    /// Barrier operations are used.
    pub op_barrier: bool,
    /// Pipeline operations are used.
    pub op_pipeline: bool,
    /// TMA instructions are used (also set by the async-proxy shared fence).
    pub inst_tma: bool,
    /// TMA im2col loads are used.
    pub inst_tma_im2col: bool,
    /// WMMA/matrix instructions are used.
    pub inst_wmma: bool,
    /// Inline-PTX wrapper helpers are needed (e.g. for plane elect).
    pub inst_ptx_wrappers: bool,
    /// Asynchronous copy instructions are used.
    pub inst_async_copy: bool,
    /// Kernel info is passed as grid constants rather than by pointer.
    pub use_grid_constants: bool,
    /// Length of the static part of the metadata.
    pub static_meta_length: usize,
    /// The kernel has dynamically-sized metadata.
    pub has_dynamic_meta: bool,
    /// The kernel carries any info/metadata at all.
    pub has_info: bool,
    /// Cube (thread-block) dimensions of the kernel.
    pub cube_dim: CubeDim,
    /// Cluster dimensions, when clusters are used.
    pub cluster_dim: Option<CubeDim>,
    /// Item type used for addresses/indexing (defaults to `u32`).
    pub address_type: Item<D>,
}
97
/// Stateful translator from Cube IR ([`KernelDefinition`]) to a dialect-specific
/// C++ [`ComputeKernel`].
///
/// Per-kernel state (flags, barriers, const arrays, ...) accumulates while the
/// IR is walked; `Compiler::compile` runs `compile_ir` on a clone so the same
/// instance can be reused across kernels.
#[allow(clippy::too_many_arguments)]
#[derive(Clone, Debug)]
pub struct CppCompiler<D: Dialect> {
    /// Name of the kernel currently being compiled.
    kernel_name: String,
    /// Barrier operations gathered from the body.
    barriers: Vec<BarrierOps<D>>,
    /// Target options provided by the runtime (warp size, supported features).
    compilation_options: CompilationOptions,
    /// Compile-time constant arrays drained from scopes.
    const_arrays: Vec<ConstArray<D>>,
    /// For each buffer id (in sorted id order), its position among
    /// extended-metadata buffers; built by `build_metadata`.
    ext_meta_positions: Vec<u32>,
    /// Cluster dimensions; `CubeDim::new_single()` when clusters are unused.
    cluster_dim: CubeDim,
    /// Dialect-specific extensions registered while compiling instructions.
    extensions: Vec<D::Extension>,
    /// Translation flags accumulated during compilation.
    flags: Flags<D>,
    /// Set of items used by the kernel.
    items: HashSet<Item<D>>,
    /// Local arrays declared in the kernel.
    local_arrays: Vec<LocalArray<D>>,
    /// Scalar/metadata layout information for the kernel.
    info: cubecl_core::Info,
    /// Pipeline operations gathered from the body.
    pipelines: Vec<PipelineOps<D>>,
    // Presumably tracks the last source location for debug output
    // (see `update_debug_loc`) — confirm at its write sites.
    source_loc: Option<SourceLoc>,
    /// Execution mode; forwarded to the checked-IO processor.
    strategy: ExecutionMode,
    /// Item used for addresses/indexing.
    addr_type: Item<D>,
}
117
118impl<D: Dialect> Default for Flags<D> {
119    fn default() -> Self {
120        Self {
121            elem_fp4: Default::default(),
122            elem_fp6: Default::default(),
123            elem_fp8: Default::default(),
124            elem_bf16: Default::default(),
125            elem_f16: Default::default(),
126            elem_tf32: Default::default(),
127            indexes: Default::default(),
128            op_barrier: Default::default(),
129            op_pipeline: Default::default(),
130            inst_tma: Default::default(),
131            inst_tma_im2col: Default::default(),
132            inst_wmma: Default::default(),
133            inst_ptx_wrappers: Default::default(),
134            inst_async_copy: Default::default(),
135            use_grid_constants: Default::default(),
136            static_meta_length: Default::default(),
137            has_info: Default::default(),
138            has_dynamic_meta: Default::default(),
139            cube_dim: CubeDim::new_single(),
140            cluster_dim: Default::default(),
141            address_type: Item::scalar(Elem::U32, true),
142        }
143    }
144}
145
146impl<D: Dialect> Default for CppCompiler<D> {
147    fn default() -> Self {
148        Self {
149            kernel_name: Default::default(),
150            barriers: Default::default(),
151            compilation_options: Default::default(),
152            const_arrays: Default::default(),
153            ext_meta_positions: Default::default(),
154            cluster_dim: CubeDim::new_single(),
155            extensions: Default::default(),
156            flags: Flags::default(),
157            items: Default::default(),
158            local_arrays: Default::default(),
159            info: Default::default(),
160            pipelines: Default::default(),
161            source_loc: Default::default(),
162            strategy: Default::default(),
163            addr_type: Item::scalar(Elem::U32, true),
164        }
165    }
166}
167
168impl<D: Dialect> Compiler for CppCompiler<D> {
169    type Representation = ComputeKernel<D>;
170    type CompilationOptions = CompilationOptions;
171
172    fn compile(
173        &mut self,
174        mut kernel: KernelDefinition,
175        compilation_options: &Self::CompilationOptions,
176        strategy: ExecutionMode,
177        addr_type: StorageType,
178    ) -> Result<Self::Representation, CompilationError> {
179        let errors = kernel.body.pop_errors();
180        if !errors.is_empty() {
181            let mut reason = "Can't compile cpp kernel\nCaused by:\n  ".to_string();
182            for error in errors {
183                reason += error.as_str();
184                reason += "\n";
185            }
186
187            return Err(CompilationError::Validation {
188                reason,
189                backtrace: BackTrace::capture(),
190            });
191        }
192
193        self.addr_type = self.compile_type(addr_type.into());
194        self.compilation_options = compilation_options.clone();
195        self.strategy = strategy;
196        self.kernel_name = kernel.options.kernel_name.clone();
197
198        if !self.compilation_options.supports_features.clusters {
199            kernel.options.cluster_dim = None;
200        }
201        self.cluster_dim = kernel.options.cluster_dim.unwrap_or(CubeDim::new_single());
202
203        let ir = self.clone().compile_ir(kernel, addr_type);
204        COUNTER_TMP_VAR.store(0, std::sync::atomic::Ordering::Relaxed);
205        Ok(ir)
206    }
207
208    fn elem_size(&self, elem: gpu::ElemType) -> usize {
209        elem.size()
210    }
211
212    fn extension(&self) -> &'static str {
213        "cpp"
214    }
215}
216
217impl<D: Dialect> CppCompiler<D> {
    /// Lowers a [`KernelDefinition`] into a C++ [`ComputeKernel`] for dialect `D`.
    ///
    /// Consumes the compiler; everything accumulated while compiling the body
    /// (barriers, pipelines, const/local arrays, flags, ...) is moved into the
    /// returned kernel.
    fn compile_ir(
        mut self,
        value: KernelDefinition,
        address_type: StorageType,
    ) -> ComputeKernel<D> {
        // Metadata layout is computed before compiling the body so that
        // extended-metadata positions are available (see `ext_meta_position`).
        let metadata = self.build_metadata(&value);
        self.info = cubecl_core::Info::new(&value.scalars, metadata, address_type);

        // Compile a clone of the body: `compile_scope` mutates the scope, while
        // the original `value.body` is still needed below for the shared-memory
        // liveness analysis.
        let instructions = self.compile_scope(&mut value.body.clone());
        let tensor_maps = value
            .tensor_maps
            .into_iter()
            .map(|b| self.compile_binding(b))
            .collect();
        let buffers = value
            .buffers
            .into_iter()
            .map(|b| self.compile_binding(b))
            .collect();
        let scalars = value
            .scalars
            .into_iter()
            .map(|binding| (self.compile_storage_type(binding.ty), binding.count))
            .collect::<Vec<_>>();

        // Snapshot of the translation flags gathered while compiling the body;
        // index flags are post-processed by the dialect's builtin rules.
        let flags = Flags {
            indexes: D::builtin_rules(&self.flags.indexes),
            inst_wmma: self.flags.inst_wmma,
            op_pipeline: self.flags.op_pipeline,
            op_barrier: self.flags.op_barrier,
            elem_fp4: self.flags.elem_fp4,
            elem_fp6: self.flags.elem_fp6,
            elem_fp8: self.flags.elem_fp8,
            elem_bf16: self.flags.elem_bf16,
            elem_f16: self.flags.elem_f16,
            elem_tf32: self.flags.elem_tf32,
            inst_tma: self.flags.inst_tma,
            inst_tma_im2col: self.flags.inst_tma_im2col,
            inst_async_copy: self.flags.inst_async_copy,
            inst_ptx_wrappers: self.flags.inst_ptx_wrappers,
            use_grid_constants: self.compilation_options.supports_features.grid_constants,
            has_info: self.info.has_info(),
            has_dynamic_meta: self.info.has_dynamic_meta,
            static_meta_length: self.info.metadata.static_len() as usize,
            cube_dim: value.cube_dim,
            cluster_dim: value.options.cluster_dim,
            address_type: self.compile_type(address_type.into()),
        };

        // Shared-memory allocations (and their offsets) come from a liveness
        // analysis run over the unmodified body.
        let mut opt = Optimizer::shared_only(value.body, value.cube_dim);
        let shared_allocs = opt.analysis::<SharedLiveness>();
        let shared_memories = shared_allocs
            .allocations
            .values()
            .map(|alloc| match alloc.smem {
                cubecl_opt::SharedMemory::Array {
                    id,
                    length,
                    ty,
                    align,
                } => SharedMemory::Array {
                    index: id,
                    item: self.compile_type(ty),
                    length,
                    align,
                    offset: alloc.offset,
                },
                cubecl_opt::SharedMemory::Value { id, ty, align } => SharedMemory::Value {
                    index: id,
                    item: self.compile_type(ty),
                    align,
                    offset: alloc.offset,
                },
            })
            .collect();

        let body = Body {
            instructions,
            shared_memories,
            pipelines: self.pipelines,
            barriers: self.barriers,
            const_arrays: self.const_arrays,
            local_arrays: self.local_arrays,
            // Without grid-constant support, kernel info is passed by pointer.
            info_by_ptr: !self.compilation_options.supports_features.grid_constants,
            has_dynamic_meta: self.info.has_dynamic_meta,
            address_type: self.addr_type,
        };

        // `compile` already strips `cluster_dim` when clusters are unsupported;
        // re-checked here defensively since `compile_ir` runs on a clone.
        let mut cluster_dim = value.options.cluster_dim;
        if !self.compilation_options.supports_features.clusters {
            cluster_dim = None;
        }

        ComputeKernel {
            tensor_maps,
            buffers,
            scalars,
            meta_static_len: self.info.metadata.static_len() as usize,
            cube_dim: value.cube_dim,
            body,
            extensions: self.extensions,
            flags,
            items: self.items,
            kernel_name: value.options.kernel_name,
            cluster_dim,
            info: self.info.clone(),
        }
    }
327
328    fn build_metadata(&mut self, value: &KernelDefinition) -> cubecl_core::Metadata {
329        let mut num_ext = 0;
330
331        let mut all_meta: Vec<_> = value
332            .buffers
333            .iter()
334            .chain(value.tensor_maps.iter())
335            .map(|buf| (buf.id, buf.has_extended_meta))
336            .collect();
337
338        all_meta.sort_by_key(|(id, _)| *id);
339
340        for (_, has_extended_meta) in &all_meta {
341            self.ext_meta_positions.push(num_ext);
342            if *has_extended_meta {
343                num_ext += 1;
344            }
345        }
346
347        let num_meta = all_meta.len();
348
349        cubecl_core::Metadata::new(num_meta as u32, num_ext)
350    }
351
352    pub(crate) fn ext_meta_position(&self, var: gpu::Variable) -> u32 {
353        let id = var.index().expect("Variable should have index");
354        self.ext_meta_positions[id as usize]
355    }
356
357    fn compile_scope(&mut self, scope: &mut gpu::Scope) -> Vec<Instruction<D>> {
358        let mut instructions = Vec::new();
359
360        let const_arrays = scope
361            .const_arrays
362            .drain(..)
363            .map(|(var, values)| ConstArray {
364                index: var.index().unwrap(),
365                item: self.compile_type(var.ty),
366                size: values.len() as u32,
367                values: values
368                    .into_iter()
369                    .map(|val| self.compile_variable(val))
370                    .collect(),
371            })
372            .collect::<Vec<_>>();
373        self.const_arrays.extend(const_arrays);
374
375        let checked_io: Box<dyn Processor> = Box::new(CheckedIoProcessor::new(
376            self.strategy,
377            self.kernel_name.clone(),
378        ));
379        let dialect_processors = D::processors();
380        let mut processors: Vec<&dyn Processor> = vec![&*checked_io];
381        processors.extend(dialect_processors.iter().map(|it| &**it));
382
383        let processing = scope.process(processors);
384
385        for var in processing.variables {
386            instructions.push(Instruction::DeclareVariable {
387                var: self.compile_variable(var),
388            });
389        }
390
391        processing
392            .instructions
393            .into_iter()
394            .for_each(|op| self.compile_instruction(&mut instructions, op));
395
396        instructions
397    }
398
399    fn compile_instruction(
400        &mut self,
401        instructions: &mut Vec<Instruction<D>>,
402        instruction: gpu::Instruction,
403    ) {
404        self.update_debug_loc(instructions, &instruction);
405        let out = instruction.out;
406        match instruction.operation {
407            gpu::Operation::Copy(variable) => {
408                instructions.push(Instruction::Assign(UnaryInstruction {
409                    input: self.compile_variable(variable),
410                    out: self.compile_variable(out.unwrap()),
411                }));
412            }
413            gpu::Operation::Arithmetic(op) => {
414                self.compile_arithmetic(op, out, instruction.modes, instructions)
415            }
416            gpu::Operation::Comparison(op) => self.compile_comparison(op, out, instructions),
417            gpu::Operation::Bitwise(op) => self.compile_bitwise(op, out, instructions),
418            gpu::Operation::Operator(op) => self.compile_operator(op, out, instructions),
419            gpu::Operation::Atomic(op) => self.compile_atomic(op, out, instructions),
420            gpu::Operation::Metadata(op) => instructions.push(self.compile_metadata(op, out)),
421            gpu::Operation::Branch(val) => self.compile_branch(instructions, val),
422            gpu::Operation::Synchronization(val) => match val {
423                gpu::Synchronization::SyncCube => instructions.push(Instruction::SyncThreads),
424                gpu::Synchronization::SyncPlane => instructions.push(Instruction::SyncWarp),
425                gpu::Synchronization::SyncStorage => instructions.push(Instruction::SyncThreads),
426                gpu::Synchronization::SyncAsyncProxyShared => {
427                    self.flags.inst_tma = true;
428                    instructions.push(Instruction::ProxyAsyncToSharedFence)
429                }
430            },
431            gpu::Operation::Plane(op) => {
432                self.flags.indexes.plane_dim_checked = true;
433                let out = self.compile_variable(out.unwrap());
434                match op {
435                    gpu::Plane::Sum(op) => {
436                        let instruction = WarpInstruction::ReduceSum {
437                            input: self.compile_variable(op.input),
438                            out,
439                        };
440                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
441                        instructions.push(Instruction::Warp(instruction));
442                    }
443                    gpu::Plane::InclusiveSum(op) => {
444                        self.flags.indexes.unit_pos_plane = true;
445                        instructions.push(Instruction::Warp(WarpInstruction::InclusiveSum {
446                            input: self.compile_variable(op.input),
447                            out,
448                        }))
449                    }
450                    gpu::Plane::InclusiveProd(op) => {
451                        self.flags.indexes.unit_pos_plane = true;
452                        instructions.push(Instruction::Warp(WarpInstruction::InclusiveProd {
453                            input: self.compile_variable(op.input),
454                            out,
455                        }))
456                    }
457                    gpu::Plane::ExclusiveSum(op) => {
458                        self.flags.indexes.unit_pos_plane = true;
459                        instructions.push(Instruction::Warp(WarpInstruction::ExclusiveSum {
460                            input: self.compile_variable(op.input),
461                            out,
462                        }))
463                    }
464                    gpu::Plane::ExclusiveProd(op) => {
465                        self.flags.indexes.unit_pos_plane = true;
466                        instructions.push(Instruction::Warp(WarpInstruction::ExclusiveProd {
467                            input: self.compile_variable(op.input),
468                            out,
469                        }))
470                    }
471                    gpu::Plane::Prod(op) => {
472                        let instruction = WarpInstruction::ReduceProd {
473                            input: self.compile_variable(op.input),
474                            out,
475                        };
476                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
477                        instructions.push(Instruction::Warp(instruction))
478                    }
479                    gpu::Plane::Max(op) => {
480                        let instruction = WarpInstruction::ReduceMax {
481                            input: self.compile_variable(op.input),
482                            out,
483                        };
484                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
485                        instructions.push(Instruction::Warp(instruction))
486                    }
487                    gpu::Plane::Min(op) => {
488                        let instruction = WarpInstruction::ReduceMin {
489                            input: self.compile_variable(op.input),
490                            out,
491                        };
492                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
493                        instructions.push(Instruction::Warp(instruction))
494                    }
495                    gpu::Plane::Elect => {
496                        if self.compilation_options.supports_features.elect_sync {
497                            self.flags.inst_ptx_wrappers = true;
498                            instructions.push(Instruction::Warp(WarpInstruction::Elect { out }))
499                        } else {
500                            instructions
501                                .push(Instruction::Warp(WarpInstruction::ElectFallback { out }))
502                        }
503                    }
504                    gpu::Plane::All(op) => {
505                        instructions.push(Instruction::Warp(WarpInstruction::All {
506                            input: self.compile_variable(op.input),
507                            out,
508                        }))
509                    }
510                    gpu::Plane::Any(op) => {
511                        instructions.push(Instruction::Warp(WarpInstruction::Any {
512                            input: self.compile_variable(op.input),
513                            out,
514                        }))
515                    }
516                    gpu::Plane::Ballot(op) => {
517                        instructions.push(Instruction::Warp(WarpInstruction::Ballot {
518                            input: self.compile_variable(op.input),
519                            out,
520                        }))
521                    }
522                    gpu::Plane::Broadcast(op) => {
523                        instructions.push(Instruction::Warp(WarpInstruction::Broadcast {
524                            input: self.compile_variable(op.lhs),
525                            id: self.compile_variable(op.rhs),
526                            out,
527                        }))
528                    }
529                    gpu::Plane::Shuffle(op) => {
530                        instructions.push(Instruction::Warp(WarpInstruction::Shuffle {
531                            input: self.compile_variable(op.lhs),
532                            src_lane: self.compile_variable(op.rhs),
533                            out,
534                        }))
535                    }
536                    gpu::Plane::ShuffleXor(op) => {
537                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleXor {
538                            input: self.compile_variable(op.lhs),
539                            mask: self.compile_variable(op.rhs),
540                            out,
541                        }))
542                    }
543                    gpu::Plane::ShuffleUp(op) => {
544                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleUp {
545                            input: self.compile_variable(op.lhs),
546                            delta: self.compile_variable(op.rhs),
547                            out,
548                        }))
549                    }
550                    gpu::Plane::ShuffleDown(op) => {
551                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleDown {
552                            input: self.compile_variable(op.lhs),
553                            delta: self.compile_variable(op.rhs),
554                            out,
555                        }))
556                    }
557                }
558            }
559            gpu::Operation::CoopMma(cmma) => instructions.push(self.compile_cmma(cmma, out)),
560            gpu::Operation::NonSemantic(debug) => match debug {
561                gpu::NonSemantic::Print {
562                    format_string,
563                    args,
564                } => instructions.push(Instruction::Printf {
565                    format_string,
566                    args: args
567                        .into_iter()
568                        .map(|arg| self.compile_variable(arg))
569                        .collect(),
570                }),
571                gpu::NonSemantic::Comment { content } => {
572                    instructions.push(Instruction::Comment { content })
573                }
574                // Don't need to handle scopes
575                _ => {}
576            },
577            gpu::Operation::Barrier(barrier_ops) => match barrier_ops {
578                gpu::BarrierOps::Declare { barrier } => {
579                    let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type()
580                    else {
581                        unreachable!()
582                    };
583                    let barrier = self.compile_variable(barrier);
584                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Declare {
585                        barrier,
586                        level,
587                    }));
588                }
589                gpu::BarrierOps::Init {
590                    barrier,
591                    is_elected,
592                    arrival_count,
593                } => {
594                    let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type()
595                    else {
596                        unreachable!()
597                    };
598                    let barrier = self.compile_variable(barrier);
599                    let arrival_count = self.compile_variable(arrival_count);
600                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Init {
601                        barrier,
602                        is_elected: self.compile_variable(is_elected),
603                        arrival_count,
604                        level,
605                    }));
606                }
607                gpu::BarrierOps::InitManual {
608                    barrier,
609                    arrival_count,
610                } => {
611                    let barrier = self.compile_variable(barrier);
612                    let arrival_count = self.compile_variable(arrival_count);
613                    instructions.push(Instruction::Barrier(
614                        super::barrier::BarrierOps::InitManual {
615                            barrier,
616                            arrival_count,
617                        },
618                    ));
619                }
620                gpu::BarrierOps::MemCopyAsync {
621                    barrier,
622                    source,
623                    source_length,
624                    offset_source,
625                    offset_out,
626                } => {
627                    instructions.push(Instruction::Barrier(
628                        super::barrier::BarrierOps::MemCopyAsync {
629                            barrier: self.compile_variable(barrier),
630                            source: self.compile_variable(source),
631                            destination: self.compile_variable(out.unwrap()),
632                            source_length: self.compile_variable(source_length),
633                            offset_source: self.compile_variable(offset_source),
634                            offset_out: self.compile_variable(offset_out),
635                            cooperative: false,
636                        },
637                    ));
638                }
639                gpu::BarrierOps::MemCopyAsyncCooperative {
640                    barrier,
641                    source,
642                    source_length,
643                    offset_source,
644                    offset_out,
645                } => {
646                    instructions.push(Instruction::Barrier(
647                        super::barrier::BarrierOps::MemCopyAsync {
648                            barrier: self.compile_variable(barrier),
649                            source: self.compile_variable(source),
650                            destination: self.compile_variable(out.unwrap()),
651                            source_length: self.compile_variable(source_length),
652                            offset_source: self.compile_variable(offset_source),
653                            offset_out: self.compile_variable(offset_out),
654                            cooperative: true,
655                        },
656                    ));
657                }
658                gpu::BarrierOps::MemCopyAsyncTx {
659                    barrier,
660                    source,
661                    source_length,
662                    offset_source,
663                    offset_out,
664                } => {
665                    instructions.push(Instruction::Barrier(
666                        super::barrier::BarrierOps::MemCopyAsyncTx {
667                            barrier: self.compile_variable(barrier),
668                            source: self.compile_variable(source),
669                            destination: self.compile_variable(out.unwrap()),
670                            source_length: self.compile_variable(source_length),
671                            offset_source: self.compile_variable(offset_source),
672                            offset_out: self.compile_variable(offset_out),
673                        },
674                    ));
675                }
676                gpu::BarrierOps::CopyAsync {
677                    source,
678                    source_length,
679                    offset_source,
680                    offset_out,
681                    copy_length,
682                    checked,
683                } => {
684                    self.flags.inst_async_copy = true;
685                    instructions.push(Instruction::Barrier(
686                        super::barrier::BarrierOps::CopyAsync {
687                            source: self.compile_variable(source),
688                            destination: self.compile_variable(out.unwrap()),
689                            source_length: self.compile_variable(source_length),
690                            offset_source: self.compile_variable(offset_source),
691                            offset_out: self.compile_variable(offset_out),
692                            copy_size: copy_length,
693                            checked,
694                        },
695                    ));
696                }
697                gpu::BarrierOps::TmaLoad {
698                    barrier,
699                    tensor_map,
700                    offset_out,
701                    indices,
702                } => {
703                    instructions.push(Instruction::Barrier(
704                        super::barrier::BarrierOps::MemCopyAsyncTensorGlobalToShared {
705                            barrier: self.compile_variable(barrier),
706                            smem_buffer: self.compile_variable(out.unwrap()),
707                            smem_offset: self.compile_variable(offset_out),
708                            tensor_map: self.compile_variable(tensor_map),
709                            indices: indices
710                                .into_iter()
711                                .map(|it| self.compile_variable(it))
712                                .collect(),
713                        },
714                    ));
715                }
716                gpu::BarrierOps::TmaLoadIm2col {
717                    barrier,
718                    tensor_map,
719                    offset_out,
720                    indices,
721                    offsets,
722                } => {
723                    self.flags.inst_tma_im2col = true;
724                    instructions.push(Instruction::Barrier(
725                        super::barrier::BarrierOps::TmaLoadIm2col {
726                            barrier: self.compile_variable(barrier),
727                            smem_buffer: self.compile_variable(out.unwrap()),
728                            smem_offset: self.compile_variable(offset_out),
729                            tensor_map: self.compile_variable(tensor_map),
730                            indices: indices
731                                .into_iter()
732                                .map(|it| self.compile_variable(it))
733                                .collect(),
734                            offsets: offsets
735                                .into_iter()
736                                .map(|it| self.compile_variable(it))
737                                .collect(),
738                        },
739                    ));
740                }
741                gpu::BarrierOps::Arrive { barrier } => {
742                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Arrive {
743                        barrier: self.compile_variable(barrier),
744                        token: self.compile_variable(out.unwrap()),
745                    }))
746                }
747                gpu::BarrierOps::ArriveTx {
748                    barrier,
749                    arrive_count_update,
750                    transaction_count_update,
751                } => {
752                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::ArriveTx {
753                        barrier: self.compile_variable(barrier),
754                        token: self.compile_variable(out.unwrap()),
755                        arrive_count_update: self.compile_variable(arrive_count_update),
756                        transaction_count_update: self.compile_variable(transaction_count_update),
757                    }))
758                }
759                gpu::BarrierOps::CommitCopyAsync { barrier } => {
760                    self.flags.inst_async_copy = true;
761                    instructions.push(Instruction::Barrier(
762                        super::barrier::BarrierOps::ArriveCopyAsync {
763                            barrier: self.compile_variable(barrier),
764                        },
765                    ))
766                }
767                gpu::BarrierOps::ExpectTx {
768                    barrier,
769                    transaction_count_update,
770                } => {
771                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::ExpectTx {
772                        barrier: self.compile_variable(barrier),
773                        transaction_count_update: self.compile_variable(transaction_count_update),
774                    }))
775                }
776                gpu::BarrierOps::Wait { barrier, token } => {
777                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Wait {
778                        barrier: self.compile_variable(barrier),
779                        token: self.compile_variable(token),
780                    }))
781                }
782                gpu::BarrierOps::WaitParity { barrier, phase } => instructions.push(
783                    Instruction::Barrier(super::barrier::BarrierOps::WaitParity {
784                        barrier: self.compile_variable(barrier),
785                        phase: self.compile_variable(phase),
786                    }),
787                ),
788                gpu::BarrierOps::ArriveAndWait { barrier } => {
789                    let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type()
790                    else {
791                        unreachable!()
792                    };
793                    instructions.push(Instruction::Barrier(
794                        super::barrier::BarrierOps::ArriveAndWait {
795                            barrier: self.compile_variable(barrier),
796                            level,
797                        },
798                    ))
799                }
800            },
801            gpu::Operation::Tma(tma_ops) => {
802                self.flags.inst_tma = true;
803                match tma_ops {
804                    gpu::TmaOps::TmaStore {
805                        source,
806                        coordinates,
807                        offset_source,
808                    } => {
809                        instructions.push(Instruction::MemCopyAsyncTensorSharedToGlobal {
810                            smem_buffer: self.compile_variable(source),
811                            smem_offset: self.compile_variable(offset_source),
812                            tensor_map: self.compile_variable(out.unwrap()),
813                            indices: coordinates
814                                .into_iter()
815                                .map(|it| self.compile_variable(it))
816                                .collect(),
817                        });
818                    }
819                    gpu::TmaOps::CommitGroup => {
820                        instructions.push(Instruction::BulkCommitGroup);
821                    }
822                    gpu::TmaOps::WaitGroup { max_pending } => {
823                        instructions.push(Instruction::BulkWaitGroup { max_pending });
824                    }
825                    gpu::TmaOps::WaitGroupRead { max_pending } => {
826                        instructions.push(Instruction::BulkWaitGroupRead { max_pending });
827                    }
828                }
829            }
830            gpu::Operation::Marker(_) => {}
831        }
832    }
833
834    fn update_debug_loc(
835        &mut self,
836        instructions: &mut Vec<Instruction<D>>,
837        inst: &gpu::Instruction,
838    ) {
839        if !matches!(inst.operation, Operation::NonSemantic(_)) {
840            match &inst.source_loc {
841                Some(loc) if Some(loc) != self.source_loc.as_ref() => {
842                    self.source_loc = Some(loc.clone());
843                    instructions.push(Instruction::Line {
844                        file: loc.source.file.clone(),
845                        line: loc.line,
846                    });
847                }
848                _ => {}
849            }
850        }
851    }
852
    /// Lower a cooperative-matrix (CMMA) IR operation into a WMMA dialect
    /// instruction.
    ///
    /// Sets `flags.inst_wmma` so the generated kernel enables WMMA support,
    /// and registers any dialect extension the resulting instruction needs.
    /// `out` is the destination (fragment or buffer) and must be present for
    /// every variant.
    fn compile_cmma(&mut self, cmma: gpu::CoopMma, out: Option<gpu::Variable>) -> Instruction<D> {
        self.flags.inst_wmma = true;

        let out = self.compile_variable(out.unwrap());

        let inst = match cmma {
            // Fill every element of the output fragment with a single value.
            gpu::CoopMma::Fill { value } => WmmaInstruction::Fill {
                frag: out,
                value: self.compile_variable(value),
            },
            // Load a fragment from `value` at `offset` with the given stride;
            // the layout is optional and may be inferred by the dialect.
            gpu::CoopMma::Load {
                value,
                stride,
                offset,
                layout,
            } => WmmaInstruction::Load {
                frag: out,
                offset: self.compile_variable(offset),
                value: self.compile_variable(value),
                stride: self.compile_variable(stride),
                layout: layout.and_then(|l| self.compile_matrix_layout(l)),
            },
            // Fragment-based matrix-multiply-accumulate: D = A * B + C.
            gpu::CoopMma::Execute {
                mat_a,
                mat_b,
                mat_c,
            } => WmmaInstruction::Execute {
                frag_a: self.compile_variable(mat_a),
                frag_b: self.compile_variable(mat_b),
                frag_c: self.compile_variable(mat_c),
                frag_d: out,
                warp_size: self.compilation_options.warp_size,
            },
            // MMA with explicitly managed registers instead of opaque fragments.
            gpu::CoopMma::ExecuteManual {
                matrix,
                registers_a,
                registers_b,
                registers_c,
            } => WmmaInstruction::ExecuteManual {
                shape: MmaShape::new(matrix.m as u32, matrix.n as u32, matrix.k as u32),
                frag_a: self.compile_variable(registers_a),
                frag_b: self.compile_variable(registers_b),
                frag_c: self.compile_variable(registers_c),
                frag_d: out,
            },
            // Manual MMA with additional scale operands applied to A and B.
            gpu::CoopMma::ExecuteScaled {
                matrix,
                registers_a,
                registers_b,
                registers_c,
                scales_a,
                scales_b,
                scales_factor,
            } => WmmaInstruction::ExecuteScaled {
                shape: MmaShape::new(matrix.m as u32, matrix.n as u32, matrix.k as u32),
                frag_a: self.compile_variable(registers_a),
                frag_b: self.compile_variable(registers_b),
                frag_c: self.compile_variable(registers_c),
                frag_d: out,

                scales_a: self.compile_variable(scales_a),
                scales_b: self.compile_variable(scales_b),
                scales_factor: scales_factor as u32,
            },
            // Store a fragment back to memory; a concrete layout is mandatory.
            gpu::CoopMma::Store {
                mat,
                stride,
                offset,
                layout,
            } => {
                // The store lowering uses per-lane addressing, so the kernel
                // must declare the unit/plane position builtins.
                self.flags.indexes.unit_pos = true;
                self.flags.indexes.plane_pos = true;
                WmmaInstruction::Store {
                    output: out,
                    offset: self.compile_variable(offset),
                    frag: self.compile_variable(mat),
                    stride: self.compile_variable(stride),
                    layout: self
                        .compile_matrix_layout(layout)
                        .expect("Layout required for store instruction"),
                }
            }
            // ldmatrix-style cooperative load from `buffer` into registers.
            gpu::CoopMma::LoadMatrix {
                buffer,
                offset,
                vector_size,
                factor,
                transpose,
            } => WmmaInstruction::LdMatrix {
                output: out,
                buffer: self.compile_variable(buffer),
                offset: self.compile_variable(offset),
                vector_size,
                factor: factor as u32,
                transpose,
            },
            // stmatrix-style cooperative store from registers into `out`.
            gpu::CoopMma::StoreMatrix {
                offset,
                vector_size,
                registers,
                factor,
                transpose,
            } => WmmaInstruction::StMatrix {
                registers: self.compile_variable(registers),
                buffer: out,
                offset: self.compile_variable(offset),
                vector_size,
                factor: factor as u32,
                transpose,
            },
            // Convert a fragment between element types.
            gpu::CoopMma::Cast { input } => WmmaInstruction::Cast {
                input: self.compile_variable(input),
                output: out,
            },
            // These variants are expanded by IR processors before codegen and
            // must never reach the backend.
            gpu::CoopMma::RowIndex { .. } | gpu::CoopMma::ColIndex { .. } => {
                panic!("Row/Col index should be handled by processors")
            }
        };

        D::register_wmma_instruction_extension(&mut self.extensions, &inst);

        Instruction::Wmma(inst)
    }
976
977    fn compile_metadata(
978        &mut self,
979        metadata: gpu::Metadata,
980        out: Option<gpu::Variable>,
981    ) -> Instruction<D> {
982        let out = out.unwrap();
983        match metadata {
984            gpu::Metadata::Stride { dim, var } => {
985                let position = self.ext_meta_position(var);
986                let offset = self.info.metadata.stride_offset_index(position);
987                Instruction::ExtendedMetadata {
988                    info_offset: self.compile_variable(offset.into()),
989                    dim: self.compile_variable(dim),
990                    out: self.compile_variable(out),
991                }
992            }
993            gpu::Metadata::Shape { dim, var } => {
994                let position = self.ext_meta_position(var);
995                let offset = self.info.metadata.shape_offset_index(position);
996                Instruction::ExtendedMetadata {
997                    info_offset: self.compile_variable(offset.into()),
998                    dim: self.compile_variable(dim),
999                    out: self.compile_variable(out),
1000                }
1001            }
1002            gpu::Metadata::Rank { var } => {
1003                let out = self.compile_variable(out);
1004                let pos = self.ext_meta_position(var);
1005                let offset = self.info.metadata.rank_index(pos);
1006                super::Instruction::Metadata {
1007                    info_offset: self.compile_variable(offset.into()),
1008                    out,
1009                }
1010            }
1011            gpu::Metadata::Length { var } => {
1012                let input = self.compile_variable(var);
1013                let out = self.compile_variable(out);
1014
1015                match input {
1016                    Variable::Slice { .. } => Instruction::SliceLength { input, out },
1017                    Variable::SharedArray(_id, _item, length) => {
1018                        Instruction::ConstLength { length, out }
1019                    }
1020                    _ => {
1021                        let id = input.id().expect("Variable should have id");
1022                        let offset = self.info.metadata.len_index(id);
1023                        Instruction::Metadata {
1024                            info_offset: self.compile_variable(offset.into()),
1025                            out,
1026                        }
1027                    }
1028                }
1029            }
1030            gpu::Metadata::BufferLength { var } => {
1031                let input = self.compile_variable(var);
1032                let out = self.compile_variable(out);
1033
1034                match input {
1035                    Variable::Slice { .. } => Instruction::SliceLength { input, out },
1036                    _ => {
1037                        let id = input.id().expect("Variable should have id");
1038                        let offset = self.info.metadata.buffer_len_index(id);
1039                        Instruction::Metadata {
1040                            info_offset: self.compile_variable(offset.into()),
1041                            out,
1042                        }
1043                    }
1044                }
1045            }
1046        }
1047    }
1048
1049    fn compile_branch(&mut self, instructions: &mut Vec<Instruction<D>>, branch: gpu::Branch) {
1050        match branch {
1051            gpu::Branch::If(mut op) => instructions.push(Instruction::If {
1052                cond: self.compile_variable(op.cond),
1053                instructions: self.compile_scope(&mut op.scope),
1054            }),
1055            gpu::Branch::IfElse(mut op) => instructions.push(Instruction::IfElse {
1056                cond: self.compile_variable(op.cond),
1057                instructions_if: self.compile_scope(&mut op.scope_if),
1058                instructions_else: self.compile_scope(&mut op.scope_else),
1059            }),
1060            gpu::Branch::Switch(mut op) => instructions.push(Instruction::Switch {
1061                value: self.compile_variable(op.value),
1062                instructions_default: self.compile_scope(&mut op.scope_default),
1063                instructions_cases: op
1064                    .cases
1065                    .into_iter()
1066                    .map(|(val, mut block)| {
1067                        (self.compile_variable(val), self.compile_scope(&mut block))
1068                    })
1069                    .collect(),
1070            }),
1071            gpu::Branch::Return => instructions.push(Instruction::Return),
1072            gpu::Branch::Break => instructions.push(Instruction::Break),
1073            gpu::Branch::Unreachable => instructions.push(Instruction::Unreachable),
1074            gpu::Branch::RangeLoop(mut range_loop) => instructions.push(Instruction::RangeLoop {
1075                i: self.compile_variable(range_loop.i),
1076                start: self.compile_variable(range_loop.start),
1077                end: self.compile_variable(range_loop.end),
1078                step: range_loop.step.map(|it| self.compile_variable(it)),
1079                inclusive: range_loop.inclusive,
1080                instructions: self.compile_scope(&mut range_loop.scope),
1081            }),
1082            gpu::Branch::Loop(mut op) => instructions.push(Instruction::Loop {
1083                instructions: self.compile_scope(&mut op.scope),
1084            }),
1085        };
1086    }
1087
1088    fn compile_atomic(
1089        &mut self,
1090        value: gpu::AtomicOp,
1091        out: Option<gpu::Variable>,
1092        instructions: &mut Vec<Instruction<D>>,
1093    ) {
1094        let out = out.unwrap();
1095        match value {
1096            gpu::AtomicOp::Load(op) => {
1097                instructions.push(Instruction::AtomicLoad(self.compile_unary(op, out)))
1098            }
1099            gpu::AtomicOp::Store(op) => {
1100                instructions.push(Instruction::AtomicStore(self.compile_unary(op, out)))
1101            }
1102            gpu::AtomicOp::Swap(op) => {
1103                instructions.push(Instruction::AtomicSwap(self.compile_binary(op, out)))
1104            }
1105            gpu::AtomicOp::Add(op) => {
1106                instructions.push(Instruction::AtomicAdd(self.compile_binary(op, out)))
1107            }
1108            gpu::AtomicOp::Sub(op) => {
1109                instructions.push(Instruction::AtomicSub(self.compile_binary(op, out)))
1110            }
1111            gpu::AtomicOp::Max(op) => {
1112                instructions.push(Instruction::AtomicMax(self.compile_binary(op, out)))
1113            }
1114            gpu::AtomicOp::Min(op) => {
1115                instructions.push(Instruction::AtomicMin(self.compile_binary(op, out)))
1116            }
1117            gpu::AtomicOp::And(op) => {
1118                instructions.push(Instruction::AtomicAnd(self.compile_binary(op, out)))
1119            }
1120            gpu::AtomicOp::Or(op) => {
1121                instructions.push(Instruction::AtomicOr(self.compile_binary(op, out)))
1122            }
1123            gpu::AtomicOp::Xor(op) => {
1124                instructions.push(Instruction::AtomicXor(self.compile_binary(op, out)))
1125            }
1126            gpu::AtomicOp::CompareAndSwap(op) => instructions.push(Instruction::AtomicCAS {
1127                input: self.compile_variable(op.input),
1128                cmp: self.compile_variable(op.cmp),
1129                val: self.compile_variable(op.val),
1130                out: self.compile_variable(out),
1131            }),
1132        }
1133    }
1134
    /// Lower an arithmetic IR operation into dialect instructions.
    ///
    /// Several float operations (div, exp, log, trig, sqrt, pow, recip,
    /// normalize, magnitude) are routed through `select_fast_float`, which
    /// substitutes an approximate "fast" variant when the instruction's
    /// fast-math `modes` permit the listed relaxations and the target supports
    /// fast math (see `select_fast_float` for the exact gating). Operations
    /// that may need dialect-specific helper code register an instruction
    /// extension via `D::register_instruction_extension`.
    ///
    /// `out` is the destination variable and must be present.
    fn compile_arithmetic(
        &mut self,
        value: gpu::Arithmetic,
        out: Option<gpu::Variable>,
        modes: InstructionModes,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::Arithmetic::Add(op) => {
                instructions.push(Instruction::Add(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::SaturatingAdd(op) => {
                instructions.push(Instruction::SaturatingAdd(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Mul(op) => {
                instructions.push(Instruction::Mul(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Div(op) => {
                // Fast division may be lowered as a reciprocal multiply, hence
                // the extra AllowReciprocal/UnsignedZero requirements.
                let op = self.compile_binary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::AllowReciprocal
                        | FastMath::ReducedPrecision
                        | FastMath::UnsignedZero
                        | FastMath::NotInf,
                    Instruction::Div(op),
                    Instruction::FastDiv(op),
                ))
            }
            gpu::Arithmetic::Sub(op) => {
                instructions.push(Instruction::Sub(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::SaturatingSub(op) => {
                instructions.push(Instruction::SaturatingSub(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::MulHi(op) => {
                let instruction = Instruction::HiMul(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Modulo(op) => {
                instructions.push(Instruction::Modulo(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Abs(op) => {
                instructions.push(Instruction::Abs(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Exp(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Exp(op),
                    Instruction::FastExp(op),
                ));
            }
            gpu::Arithmetic::Log(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Log(op),
                    Instruction::FastLog(op),
                ));
            }
            gpu::Arithmetic::Log1p(op) => {
                instructions.push(Instruction::Log1p(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Cos(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Cos(op),
                    Instruction::FastCos(op),
                ));
            }
            gpu::Arithmetic::Sin(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Sin(op),
                    Instruction::FastSin(op),
                ));
            }
            gpu::Arithmetic::Tan(op) => {
                instructions.push(Instruction::Tan(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Tanh(op) => {
                // The fast tanh variant is only emitted when the target
                // hardware supports it, in addition to the fast-math gating.
                let op = self.compile_unary(op, out);
                let instruction = Instruction::Tanh(op);
                D::register_instruction_extension(&mut self.extensions, &instruction);
                if self.compilation_options.supports_features.fast_tanh {
                    instructions.push(self.select_fast_float(
                        out.ty,
                        modes,
                        FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                        instruction,
                        Instruction::FastTanh(op),
                    ))
                } else {
                    instructions.push(instruction);
                }
            }
            gpu::Arithmetic::Sinh(op) => {
                let instruction = Instruction::Sinh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Cosh(op) => {
                let instruction = Instruction::Cosh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcCos(op) => {
                let instruction = Instruction::ArcCos(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcSin(op) => {
                let instruction = Instruction::ArcSin(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcTan(op) => {
                let instruction = Instruction::ArcTan(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcSinh(op) => {
                let instruction = Instruction::ArcSinh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcCosh(op) => {
                let instruction = Instruction::ArcCosh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcTanh(op) => {
                let instruction = Instruction::ArcTanh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Degrees(op) => {
                let instruction = Instruction::Degrees(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Radians(op) => {
                let instruction = Instruction::Radians(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcTan2(op) => {
                let instruction = Instruction::ArcTan2(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Powf(op) => {
                let op = self.compile_binary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Powf(op),
                    Instruction::FastPowf(op),
                ))
            }
            gpu::Arithmetic::Powi(op) => {
                instructions.push(Instruction::Powi(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Hypot(op) => {
                instructions.push(Instruction::Hypot(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Rhypot(op) => {
                instructions.push(Instruction::Rhypot(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Sqrt(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Sqrt(op),
                    Instruction::FastSqrt(op),
                ))
            }
            gpu::Arithmetic::InverseSqrt(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::InverseSqrt(op),
                    Instruction::FastInverseSqrt(op),
                ))
            }
            gpu::Arithmetic::Erf(op) => {
                let instruction = Instruction::Erf(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Max(op) => {
                let instruction = Instruction::Max(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Min(op) => {
                let instruction = Instruction::Min(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Clamp(op) => instructions.push(Instruction::Clamp {
                input: self.compile_variable(op.input),
                min_value: self.compile_variable(op.min_value),
                max_value: self.compile_variable(op.max_value),
                out: self.compile_variable(out),
            }),
            gpu::Arithmetic::Recip(op) => {
                // Reciprocal is lowered as `1 / x` (constant `1` of the input
                // element type), or as a fast reciprocal intrinsic when the
                // fast-math modes allow it.
                let elem = op.input.ty.elem_type();
                let input = self.compile_variable(op.input);
                let out = self.compile_variable(out);
                let lhs = match elem {
                    gpu::ElemType::Float(_) => gpu::ConstantValue::Float(1.0),
                    gpu::ElemType::Int(_) => gpu::ConstantValue::Int(1),
                    gpu::ElemType::UInt(_) => gpu::ConstantValue::UInt(1),
                    gpu::ElemType::Bool => gpu::ConstantValue::Bool(true),
                };
                let div = Instruction::Div(BinaryInstruction {
                    lhs: Variable::Constant(lhs, self.compile_type(op.input.ty)),
                    rhs: input,
                    out,
                });
                let recip = Instruction::FastRecip(UnaryInstruction { input, out });

                instructions.push(self.select_fast_float(
                    elem.into(),
                    modes,
                    FastMath::AllowReciprocal
                        | FastMath::ReducedPrecision
                        | FastMath::UnsignedZero
                        | FastMath::NotInf,
                    div,
                    recip,
                ))
            }
            gpu::Arithmetic::Round(op) => {
                instructions.push(Instruction::Round(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Floor(op) => {
                instructions.push(Instruction::Floor(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Ceil(op) => {
                instructions.push(Instruction::Ceil(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Trunc(op) => {
                instructions.push(Instruction::Trunc(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Remainder(op) => {
                instructions.push(Instruction::Remainder(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Fma(op) => instructions.push(Instruction::Fma {
                a: self.compile_variable(op.a),
                b: self.compile_variable(op.b),
                c: self.compile_variable(op.c),
                out: self.compile_variable(out),
            }),
            gpu::Arithmetic::Neg(op) => {
                instructions.push(Instruction::Neg(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Normalize(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Normalize(op),
                    Instruction::FastNormalize(op),
                ))
            }
            gpu::Arithmetic::Magnitude(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Magnitude(op),
                    Instruction::FastMagnitude(op),
                ))
            }
            gpu::Arithmetic::Dot(op) => {
                instructions.push(Instruction::Dot(self.compile_binary(op, out)))
            }
        };
    }
1437
1438    fn select_fast_float(
1439        &self,
1440        ty: gpu::Type,
1441        modes: InstructionModes,
1442        required_flags: EnumSet<FastMath>,
1443        default: Instruction<D>,
1444        fast: Instruction<D>,
1445    ) -> Instruction<D> {
1446        if !self.compilation_options.supports_features.fast_math
1447            || !matches!(ty.elem_type(), ElemType::Float(FloatKind::F32))
1448        {
1449            return default;
1450        }
1451
1452        if modes.fp_math_mode.is_superset(required_flags) {
1453            fast
1454        } else {
1455            default
1456        }
1457    }
1458
1459    fn compile_comparison(
1460        &mut self,
1461        value: gpu::Comparison,
1462        out: Option<gpu::Variable>,
1463        instructions: &mut Vec<Instruction<D>>,
1464    ) {
1465        let out = out.unwrap();
1466        match value {
1467            gpu::Comparison::Equal(op) => {
1468                instructions.push(Instruction::Equal(self.compile_binary(op, out)))
1469            }
1470            gpu::Comparison::Lower(op) => {
1471                instructions.push(Instruction::Lower(self.compile_binary(op, out)))
1472            }
1473            gpu::Comparison::Greater(op) => {
1474                instructions.push(Instruction::Greater(self.compile_binary(op, out)))
1475            }
1476            gpu::Comparison::LowerEqual(op) => {
1477                instructions.push(Instruction::LowerEqual(self.compile_binary(op, out)))
1478            }
1479            gpu::Comparison::GreaterEqual(op) => {
1480                instructions.push(Instruction::GreaterEqual(self.compile_binary(op, out)))
1481            }
1482            gpu::Comparison::NotEqual(op) => {
1483                instructions.push(Instruction::NotEqual(self.compile_binary(op, out)))
1484            }
1485            gpu::Comparison::IsNan(op) => {
1486                instructions.push(Instruction::IsNan(self.compile_unary(op, out)))
1487            }
1488            gpu::Comparison::IsInf(op) => {
1489                instructions.push(Instruction::IsInf(self.compile_unary(op, out)))
1490            }
1491        };
1492    }
1493
1494    fn compile_bitwise(
1495        &mut self,
1496        value: gpu::Bitwise,
1497        out: Option<gpu::Variable>,
1498        instructions: &mut Vec<Instruction<D>>,
1499    ) {
1500        let out = out.unwrap();
1501        match value {
1502            gpu::Bitwise::BitwiseOr(op) => {
1503                instructions.push(Instruction::BitwiseOr(self.compile_binary(op, out)))
1504            }
1505            gpu::Bitwise::BitwiseAnd(op) => {
1506                instructions.push(Instruction::BitwiseAnd(self.compile_binary(op, out)))
1507            }
1508            gpu::Bitwise::BitwiseXor(op) => {
1509                instructions.push(Instruction::BitwiseXor(self.compile_binary(op, out)))
1510            }
1511            gpu::Bitwise::CountOnes(op) => {
1512                instructions.push(Instruction::CountBits(self.compile_unary(op, out)))
1513            }
1514            gpu::Bitwise::ReverseBits(op) => {
1515                instructions.push(Instruction::ReverseBits(self.compile_unary(op, out)))
1516            }
1517            gpu::Bitwise::ShiftLeft(op) => {
1518                instructions.push(Instruction::ShiftLeft(self.compile_binary(op, out)))
1519            }
1520            gpu::Bitwise::ShiftRight(op) => {
1521                instructions.push(Instruction::ShiftRight(self.compile_binary(op, out)))
1522            }
1523            gpu::Bitwise::BitwiseNot(op) => {
1524                instructions.push(Instruction::BitwiseNot(self.compile_unary(op, out)))
1525            }
1526            gpu::Bitwise::LeadingZeros(op) => {
1527                instructions.push(Instruction::LeadingZeros(self.compile_unary(op, out)))
1528            }
1529            gpu::Bitwise::TrailingZeros(op) => {
1530                instructions.push(Instruction::TrailingZeros(self.compile_unary(op, out)))
1531            }
1532            gpu::Bitwise::FindFirstSet(op) => {
1533                let instruction = Instruction::FindFirstSet(self.compile_unary(op, out));
1534                D::register_instruction_extension(&mut self.extensions, &instruction);
1535                instructions.push(instruction)
1536            }
1537        };
1538    }
1539
    /// Lower a general operator (indexing, logic, vector init, memory copies,
    /// select, casts) into dialect instructions.
    ///
    /// Note: checked and unchecked index variants lower to the same instruction
    /// here; bounds handling is expected to have been resolved earlier
    /// (see the `CheckedIoProcessor` used at kernel compilation).
    fn compile_operator(
        &mut self,
        value: gpu::Operator,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        // Operators always produce a result variable.
        let out = out.unwrap();
        match value {
            gpu::Operator::Index(op) | gpu::Operator::UncheckedIndex(op) => {
                instructions.push(Instruction::Index(self.compile_index(op, out)));
            }
            gpu::Operator::IndexAssign(op) | gpu::Operator::UncheckedIndexAssign(op) => {
                instructions.push(Instruction::IndexAssign(self.compile_index_assign(op, out)));
            }
            gpu::Operator::And(op) => {
                instructions.push(Instruction::And(self.compile_binary(op, out)))
            }
            gpu::Operator::Or(op) => {
                instructions.push(Instruction::Or(self.compile_binary(op, out)))
            }
            gpu::Operator::Not(op) => {
                instructions.push(Instruction::Not(self.compile_unary(op, out)))
            }
            gpu::Operator::InitVector(op) => instructions.push(Instruction::VecInit {
                inputs: op
                    .inputs
                    .into_iter()
                    .map(|it| self.compile_variable(it))
                    .collect(),
                out: self.compile_variable(out),
            }),
            gpu::Operator::CopyMemory(op) => instructions.push(Instruction::Copy {
                input: self.compile_variable(op.input),
                in_index: self.compile_variable(op.in_index),
                out: self.compile_variable(out),
                out_index: self.compile_variable(op.out_index),
            }),
            gpu::Operator::CopyMemoryBulk(op) => instructions.push(Instruction::CopyBulk {
                input: self.compile_variable(op.input),
                in_index: self.compile_variable(op.in_index),
                out: self.compile_variable(out),
                out_index: self.compile_variable(op.out_index),
                len: op.len as u32,
            }),
            gpu::Operator::Select(op) => instructions.push(Instruction::Select {
                cond: self.compile_variable(op.cond),
                then: self.compile_variable(op.then),
                or_else: self.compile_variable(op.or_else),
                out: self.compile_variable(out),
            }),
            // Casts to/from fp4/fp6/fp8 need special conversion semantics.
            gpu::Operator::Cast(op)
                if (is_fp4_fp6_fp8(op.input.elem_type()) || is_fp4_fp6_fp8(out.elem_type()))
                // Trivial broadcast shouldn't use special cast logic
                    && op.input.elem_type() != out.elem_type() =>
            {
                // We may need f16/bf16 for intermediate conversion values.
                self.flags.elem_f16 = true;
                self.flags.elem_bf16 = true;
                let vec_in = op.input.ty.vector_size();
                let packing = out.storage_type().packing_factor();
                // The `compile_type` calls below are for their side effect:
                // they register every item the generated conversion code may
                // reference (input at the packed width, plus f16/bf16 at both
                // the input's and the output's vector widths).
                self.compile_type(op.input.ty.with_vector_size(packing));
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::F16))
                        .with_vector_size(vec_in),
                );
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::BF16))
                        .with_vector_size(vec_in),
                );
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::F16))
                        .with_vector_size(packing),
                );
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::BF16))
                        .with_vector_size(packing),
                );

                let inst = self.compile_unary(op, out);

                instructions.push(Instruction::SpecialCast(inst));
            }
            gpu::Operator::Cast(op) => {
                let op = self.compile_unary(op, out);

                // TF32 casts require the tf32 type to be declared in the kernel.
                if op.input.elem() == Elem::TF32 || op.out.elem() == Elem::TF32 {
                    self.flags.elem_tf32 = true;
                }

                // Ordinary casts lower to a plain assignment; C++ conversion
                // rules handle the element change.
                instructions.push(Instruction::Assign(op))
            }
            gpu::Operator::Reinterpret(op) => {
                instructions.push(Instruction::Bitcast(self.compile_unary(op, out)))
            }
        };
    }
1637
1638    fn compile_binary(
1639        &mut self,
1640        value: gpu::BinaryOperator,
1641        out: gpu::Variable,
1642    ) -> BinaryInstruction<D> {
1643        BinaryInstruction {
1644            lhs: self.compile_variable(value.lhs),
1645            rhs: self.compile_variable(value.rhs),
1646            out: self.compile_variable(out),
1647        }
1648    }
1649
1650    fn compile_index_assign(
1651        &mut self,
1652        value: gpu::IndexAssignOperator,
1653        out: gpu::Variable,
1654    ) -> IndexAssignInstruction<D> {
1655        IndexAssignInstruction {
1656            index: self.compile_variable(value.index),
1657            value: self.compile_variable(value.value),
1658            vector_size: value.vector_size as u32,
1659            out: self.compile_variable(out),
1660        }
1661    }
1662
1663    fn compile_index(
1664        &mut self,
1665        value: gpu::IndexOperator,
1666        out: gpu::Variable,
1667    ) -> IndexInstruction<D> {
1668        IndexInstruction {
1669            list: self.compile_variable(value.list),
1670            index: self.compile_variable(value.index),
1671            vector_size: value.vector_size as u32,
1672            out: self.compile_variable(out),
1673        }
1674    }
1675
1676    fn compile_unary(
1677        &mut self,
1678        value: gpu::UnaryOperator,
1679        out: gpu::Variable,
1680    ) -> UnaryInstruction<D> {
1681        UnaryInstruction {
1682            input: self.compile_variable(value.input),
1683            out: self.compile_variable(out),
1684        }
1685    }
1686
    /// Lower an IR variable into a dialect variable.
    ///
    /// This is the central registration point for kernel requirements: many
    /// arms set `self.flags.*` as a side effect so that the kernel source
    /// later declares only the builtins, indexes, and features actually used
    /// (TMA, WMMA, pipelines, barriers, cube/cluster indexes, ...).
    fn compile_variable(&mut self, value: gpu::Variable) -> Variable<D> {
        let item = value.ty;
        match value.kind {
            gpu::VariableKind::GlobalInputArray(id) => {
                Variable::GlobalInputArray(id, self.compile_type(item))
            }
            // Scalars are passed as storage-typed values, not items.
            gpu::VariableKind::GlobalScalar(id) => Variable::GlobalScalar {
                id,
                elem: self.compile_storage_type(item.storage_type()),
            },
            // Tensor-map inputs and outputs lower to the same variable kind;
            // both require TMA support in the generated kernel.
            gpu::VariableKind::TensorMapInput(id) => {
                self.flags.inst_tma = true;
                Variable::TensorMap(id)
            }
            gpu::VariableKind::TensorMapOutput(id) => {
                self.flags.inst_tma = true;
                Variable::TensorMap(id)
            }
            gpu::VariableKind::LocalMut { id } => Variable::LocalMut {
                id,
                item: self.compile_type(item),
            },
            // SSA-versioned locals collapse back to a single mutable local in C++.
            gpu::VariableKind::Versioned { id, .. } => Variable::LocalMut {
                id,
                item: self.compile_type(item),
            },
            gpu::VariableKind::LocalConst { id } => Variable::LocalConst {
                id,
                item: self.compile_type(item),
            },
            gpu::VariableKind::GlobalOutputArray(id) => {
                Variable::GlobalOutputArray(id, self.compile_type(item))
            }
            gpu::VariableKind::Constant(value) => {
                Variable::Constant(value, self.compile_type(item))
            }
            gpu::VariableKind::SharedArray { id, length, .. } => {
                let item = self.compile_type(item);
                Variable::SharedArray(id, item, length)
            }
            gpu::VariableKind::Shared { id } => {
                let item = self.compile_type(item);
                Variable::Shared(id, item)
            }
            // Constant arrays are flattened: total length is length * unroll_factor.
            gpu::VariableKind::ConstantArray {
                id,
                length,
                unroll_factor,
            } => {
                let item = self.compile_type(item);
                Variable::ConstantArray(id, item, length * unroll_factor)
            }
            // Builtins: each arm flags the index computation it needs so the
            // kernel prologue only declares what is referenced.
            gpu::VariableKind::Builtin(builtin) => match builtin {
                gpu::Builtin::AbsolutePos => {
                    self.flags.indexes.absolute_pos = true;
                    let item = self.compile_type(item);
                    Variable::AbsolutePos(item.elem)
                }
                gpu::Builtin::CubePosCluster
                    if self.compilation_options.supports_features.clusters =>
                {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterRank
                }
                gpu::Builtin::CubePosClusterX
                    if self.compilation_options.supports_features.clusters =>
                {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexX
                }
                gpu::Builtin::CubePosClusterY
                    if self.compilation_options.supports_features.clusters =>
                {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexY
                }
                gpu::Builtin::CubePosClusterZ
                    if self.compilation_options.supports_features.clusters =>
                {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexZ
                }
                // Fallback if clusters aren't supported, ID is always 0 since clusters are always
                // (1, 1, 1) if unsupported
                gpu::Builtin::CubePosCluster
                | gpu::Builtin::CubePosClusterX
                | gpu::Builtin::CubePosClusterY
                | gpu::Builtin::CubePosClusterZ => const_u32(0),
                gpu::Builtin::AbsolutePosX => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosX
                }
                gpu::Builtin::AbsolutePosY => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosY
                }
                gpu::Builtin::AbsolutePosZ => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosZ
                }
                gpu::Builtin::CubeDim => {
                    self.flags.indexes.cube_dim = true;
                    Variable::CubeDim
                }
                gpu::Builtin::CubeDimX => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimX
                }
                gpu::Builtin::CubeDimY => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimY
                }
                gpu::Builtin::CubeDimZ => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimZ
                }
                // Cluster dimensions are compile-time constants taken from the
                // configured cluster size.
                gpu::Builtin::CubeClusterDim => const_u32(self.cluster_dim.num_elems()),
                gpu::Builtin::CubeClusterDimX => const_u32(self.cluster_dim.x),
                gpu::Builtin::CubeClusterDimY => const_u32(self.cluster_dim.y),
                gpu::Builtin::CubeClusterDimZ => const_u32(self.cluster_dim.z),
                gpu::Builtin::CubePos => {
                    self.flags.indexes.cube_pos = true;
                    let item = self.compile_type(item);
                    Variable::CubePos(item.elem)
                }
                gpu::Builtin::CubePosX => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosX
                }
                gpu::Builtin::CubePosY => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosY
                }
                gpu::Builtin::CubePosZ => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosZ
                }
                gpu::Builtin::CubeCount => {
                    self.flags.indexes.cube_count = true;
                    let item = self.compile_type(item);
                    Variable::CubeCount(item.elem)
                }
                gpu::Builtin::CubeCountX => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountX
                }
                gpu::Builtin::CubeCountY => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountY
                }
                gpu::Builtin::CubeCountZ => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountZ
                }
                gpu::Builtin::UnitPos => {
                    self.flags.indexes.unit_pos = true;
                    Variable::UnitPos
                }
                gpu::Builtin::UnitPosX => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosX
                }
                gpu::Builtin::UnitPosY => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosY
                }
                gpu::Builtin::UnitPosZ => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosZ
                }
                gpu::Builtin::PlaneDim => {
                    self.flags.indexes.plane_dim = true;
                    Variable::PlaneDim
                }
                gpu::Builtin::PlanePos => {
                    self.flags.indexes.plane_pos = true;
                    Variable::PlanePos
                }
                gpu::Builtin::UnitPosPlane => {
                    self.flags.indexes.unit_pos_plane = true;
                    Variable::UnitPosPlane
                }
            },
            gpu::VariableKind::LocalArray {
                id,
                length,
                unroll_factor,
            } => {
                let item = self.compile_type(item);
                // Register the array declaration once; subsequent references
                // only produce the variable.
                if !self.local_arrays.iter().any(|s| s.index == id) {
                    self.local_arrays
                        .push(LocalArray::new(id, item, length * unroll_factor));
                }
                Variable::LocalArray(id, item, length)
            }
            gpu::VariableKind::Matrix { id, mat } => {
                self.flags.inst_wmma = true;
                Variable::WmmaFragment {
                    id,
                    frag: self.compile_matrix(mat),
                }
            }
            gpu::VariableKind::Pipeline { id, num_stages } => {
                self.flags.op_pipeline = true;
                let pipeline = Variable::Pipeline { id };
                // Emit the pipeline initialization only on first reference.
                if !self.pipelines.iter().any(|s| s.pipeline_id() == id) {
                    self.pipelines.push(PipelineOps::Init {
                        pipeline,
                        num_stages,
                    });
                }
                pipeline
            }
            gpu::VariableKind::BarrierToken { id, level } => {
                self.flags.op_barrier = true;
                Variable::BarrierToken { id, level }
            }
        }
    }
1906
1907    fn compile_matrix(&mut self, matrix: gpu::Matrix) -> Fragment<D> {
1908        Fragment {
1909            ident: self.compile_matrix_ident(matrix.ident),
1910            m: matrix.m as u32,
1911            n: matrix.n as u32,
1912            k: matrix.k as u32,
1913            elem: self.compile_storage_type(matrix.storage),
1914            layout: self.compile_matrix_layout(matrix.layout),
1915        }
1916    }
1917
1918    fn compile_matrix_ident(&mut self, ident: gpu::MatrixIdent) -> FragmentIdent<D> {
1919        match ident {
1920            gpu::MatrixIdent::A => FragmentIdent::A,
1921            gpu::MatrixIdent::B => FragmentIdent::B,
1922            gpu::MatrixIdent::Accumulator => FragmentIdent::Accumulator,
1923        }
1924    }
1925
1926    fn compile_matrix_layout(&mut self, layout: gpu::MatrixLayout) -> Option<FragmentLayout<D>> {
1927        match layout {
1928            gpu::MatrixLayout::ColMajor => Some(FragmentLayout::ColMajor),
1929            gpu::MatrixLayout::RowMajor => Some(FragmentLayout::RowMajor),
1930            gpu::MatrixLayout::Undefined => None,
1931        }
1932    }
1933
1934    fn compile_binding(&mut self, binding: cubecl_runtime::kernel::KernelArg) -> KernelArg<D> {
1935        KernelArg {
1936            id: binding.id,
1937            item: self.compile_type(binding.ty),
1938            size: binding.size,
1939            vis: binding.visibility,
1940        }
1941    }
1942
1943    fn compile_type(&mut self, ty: gpu::Type) -> Item<D> {
1944        let item = match ty {
1945            gpu::Type::Scalar(ty) => Item::new(self.compile_storage_type(ty), 1, false),
1946            gpu::Type::Vector(ty, vector_size) => {
1947                Item::new(self.compile_storage_type(ty), vector_size, false)
1948            }
1949            gpu::Type::Semantic(_) => Item::new(Elem::Bool, 1, true),
1950        };
1951        if item.elem != super::Elem::TF32 {
1952            self.items.insert(item);
1953            self.items.insert(item.optimized());
1954        } else {
1955            // TF32 is represented as `float` in C++
1956            let mut item = item;
1957            item.elem = super::Elem::F32;
1958            self.items.insert(item);
1959        }
1960
1961        item
1962    }
1963
1964    fn compile_storage_type(&mut self, value: gpu::StorageType) -> Elem<D> {
1965        match value {
1966            gpu::StorageType::Scalar(ty) => self.compile_elem(ty),
1967            gpu::StorageType::Atomic(ty) => Elem::Atomic(ty.into()),
1968            gpu::StorageType::Packed(gpu::ElemType::Float(kind), 2) => match kind {
1969                FloatKind::E2M1 => {
1970                    self.flags.elem_fp4 = true;
1971                    Elem::FP4x2(FP4Kind::E2M1)
1972                }
1973                FloatKind::E2M3 => {
1974                    self.flags.elem_fp6 = true;
1975                    Elem::FP6x2(FP6Kind::E2M3)
1976                }
1977                FloatKind::E3M2 => {
1978                    self.flags.elem_fp6 = true;
1979                    Elem::FP6(FP6Kind::E3M2)
1980                }
1981                FloatKind::E4M3 => {
1982                    self.flags.elem_fp8 = true;
1983                    Elem::FP8x2(FP8Kind::E4M3)
1984                }
1985                FloatKind::E5M2 => {
1986                    self.flags.elem_fp8 = true;
1987                    Elem::FP8x2(FP8Kind::E5M2)
1988                }
1989                FloatKind::UE8M0 => {
1990                    self.flags.elem_fp8 = true;
1991                    Elem::FP8x2(FP8Kind::UE8M0)
1992                }
1993                FloatKind::F16 => {
1994                    self.flags.elem_f16 = true;
1995                    Elem::F16x2
1996                }
1997                FloatKind::BF16 => {
1998                    self.flags.elem_bf16 = true;
1999                    Elem::BF16x2
2000                }
2001                other => unimplemented!("Unsupported storage type: packed<{other:?}, 2>"),
2002            },
2003            gpu::StorageType::Packed(other, factor) => {
2004                unimplemented!("Unsupported storage type: packed<{other}, {factor}>")
2005            }
2006            gpu::StorageType::Opaque(ty) => match ty {
2007                gpu::OpaqueType::Barrier(level) => {
2008                    self.flags.op_barrier = true;
2009                    Elem::Barrier(level)
2010                }
2011            },
2012        }
2013    }
2014
2015    fn compile_elem(&mut self, value: gpu::ElemType) -> Elem<D> {
2016        match value {
2017            gpu::ElemType::Float(kind) => match kind {
2018                gpu::FloatKind::E2M1 => {
2019                    self.flags.elem_fp4 = true;
2020                    Elem::FP4(FP4Kind::E2M1)
2021                }
2022                gpu::FloatKind::E2M3 => {
2023                    self.flags.elem_fp6 = true;
2024                    Elem::FP6(FP6Kind::E2M3)
2025                }
2026                gpu::FloatKind::E3M2 => {
2027                    self.flags.elem_fp6 = true;
2028                    Elem::FP6(FP6Kind::E3M2)
2029                }
2030                gpu::FloatKind::E4M3 => {
2031                    self.flags.elem_fp8 = true;
2032                    Elem::FP8(FP8Kind::E4M3)
2033                }
2034                gpu::FloatKind::E5M2 => {
2035                    self.flags.elem_fp8 = true;
2036                    Elem::FP8(FP8Kind::E5M2)
2037                }
2038                gpu::FloatKind::UE8M0 => {
2039                    self.flags.elem_fp8 = true;
2040                    Elem::FP8(FP8Kind::UE8M0)
2041                }
2042                gpu::FloatKind::F16 => {
2043                    self.flags.elem_f16 = true;
2044                    Elem::F16
2045                }
2046                gpu::FloatKind::BF16 => {
2047                    self.flags.elem_bf16 = true;
2048                    Elem::BF16
2049                }
2050                gpu::FloatKind::TF32 => Elem::TF32,
2051                gpu::FloatKind::Flex32 => Elem::F32,
2052                gpu::FloatKind::F32 => Elem::F32,
2053                gpu::FloatKind::F64 => Elem::F64,
2054            },
2055            gpu::ElemType::Int(kind) => match kind {
2056                gpu::IntKind::I8 => Elem::I8,
2057                gpu::IntKind::I16 => Elem::I16,
2058                gpu::IntKind::I32 => Elem::I32,
2059                gpu::IntKind::I64 => Elem::I64,
2060            },
2061            gpu::ElemType::UInt(kind) => match kind {
2062                gpu::UIntKind::U8 => Elem::U8,
2063                gpu::UIntKind::U16 => Elem::U16,
2064                gpu::UIntKind::U32 => Elem::U32,
2065                gpu::UIntKind::U64 => Elem::U64,
2066            },
2067            gpu::ElemType::Bool => Elem::Bool,
2068        }
2069    }
2070}
2071
2072fn is_fp4_fp6_fp8(elem: gpu::ElemType) -> bool {
2073    match elem {
2074        gpu::ElemType::Float(kind) => matches!(
2075            kind,
2076            FloatKind::E2M1
2077                | FloatKind::E2M3
2078                | FloatKind::E3M2
2079                | FloatKind::E4M3
2080                | FloatKind::E5M2
2081                | FloatKind::UE8M0
2082        ),
2083        _ => false,
2084    }
2085}
2086
2087fn const_u32<D: Dialect>(value: u32) -> Variable<D> {
2088    Variable::Constant(
2089        gpu::ConstantValue::UInt(value as u64),
2090        Item::new(Elem::U32, 1, true),
2091    )
2092}
2093
2094pub fn register_supported_types(props: &mut DeviceProperties) {
2095    props.register_address_type(gpu::AddressType::U32);
2096    props.register_address_type(gpu::AddressType::U64);
2097
2098    let supported_types = [
2099        gpu::ElemType::UInt(gpu::UIntKind::U8),
2100        gpu::ElemType::UInt(gpu::UIntKind::U16),
2101        gpu::ElemType::UInt(gpu::UIntKind::U32),
2102        gpu::ElemType::UInt(gpu::UIntKind::U64),
2103        gpu::ElemType::Int(gpu::IntKind::I8),
2104        gpu::ElemType::Int(gpu::IntKind::I16),
2105        gpu::ElemType::Int(gpu::IntKind::I32),
2106        gpu::ElemType::Int(gpu::IntKind::I64),
2107        gpu::ElemType::Float(gpu::FloatKind::BF16),
2108        gpu::ElemType::Float(gpu::FloatKind::F16),
2109        gpu::ElemType::Float(gpu::FloatKind::F32),
2110        gpu::ElemType::Float(gpu::FloatKind::Flex32),
2111        // Causes CUDA_ERROR_INVALID_VALUE for matmul, disabling until that can be investigated
2112        //gpu::Elem::Float(gpu::FloatKind::F64),
2113        gpu::ElemType::Bool,
2114    ];
2115
2116    let supported_atomic_types = [
2117        gpu::ElemType::Int(gpu::IntKind::I32),
2118        gpu::ElemType::Int(gpu::IntKind::I64),
2119        gpu::ElemType::UInt(gpu::UIntKind::U32),
2120        gpu::ElemType::UInt(gpu::UIntKind::U64),
2121        gpu::ElemType::Float(gpu::FloatKind::F32),
2122    ];
2123
2124    for ty in supported_types {
2125        props.register_type_usage(ty, TypeUsage::all());
2126    }
2127
2128    for ty in supported_atomic_types {
2129        props.register_atomic_type_usage(
2130            Type::new(gpu::StorageType::Atomic(ty)),
2131            AtomicUsage::Add | AtomicUsage::LoadStore,
2132        );
2133    }
2134}