cubecl_cpp/shared/
base.rs

1use cubecl_common::backtrace::BackTrace;
2use cubecl_core::ir::{self as gpu, OpaqueType, StorageType};
3use cubecl_core::ir::{FloatKind, InstructionModes, Processor, UIntKind};
4use cubecl_core::post_processing::checked_io::CheckedIoProcessor;
5use cubecl_core::server::ExecutionMode;
6use cubecl_core::{CubeDim, ir::ElemType};
7use cubecl_core::{
8    ir::{Operation, SourceLoc},
9    prelude::{FastMath, KernelDefinition},
10};
11use cubecl_opt::{Optimizer, SharedLiveness};
12use cubecl_runtime::compiler::CompilationError;
13use cubecl_runtime::{DeviceProperties, EnumSet, TypeUsage, compiler::Compiler};
14use std::{collections::HashSet, fmt::Debug};
15
16use crate::shared::MmaShape;
17
18use super::{
19    BinaryInstruction, Binding, Body, Component, ComputeKernel, ConstArray, Dialect, Elem, FP6Kind,
20    Fragment, FragmentIdent, FragmentLayout, IndexAssignInstruction, IndexInstruction, Instruction,
21    Item, LocalArray, SharedMemory, UnaryInstruction, Variable, WarpInstruction, WmmaInstruction,
22};
23use super::{FP4Kind, barrier::BarrierOps};
24use super::{FP8Kind, pipeline::PipelineOps};
25
/// Process-wide counter used to generate unique names for compiler-created
/// temporary variables. Reset to 0 at the end of each `Compiler::compile` call.
pub(super) static COUNTER_TMP_VAR: std::sync::atomic::AtomicU32 =
    std::sync::atomic::AtomicU32::new(0);
28
/// Target-specific options controlling how Cube IR is lowered to C++.
#[derive(Clone, Debug)]
pub struct CompilationOptions {
    /// Warp/plane width of the target device (defaults to 32).
    pub warp_size: u32,
    /// Optional hardware/toolchain features available to the generated code.
    pub supports_features: CppSupportedFeatures,
}
34
/// Optional capabilities of the target; each flag gates emission of the
/// corresponding feature during compilation. All default to `false`.
#[derive(Clone, Debug, Default)]
pub struct CppSupportedFeatures {
    // Mirrored into `Flags::use_grid_constants` when building the kernel.
    pub grid_constants: bool,
    // Thread-block clusters; when false any requested `cluster_dim` is dropped.
    pub clusters: bool,
    pub fast_math: bool,
    pub fast_tanh: bool,
    // Hardware leader election; when false `Plane::Elect` lowers to a fallback.
    pub elect_sync: bool,
}
43
44impl Default for CompilationOptions {
45    fn default() -> Self {
46        Self {
47            warp_size: 32,
48            supports_features: Default::default(),
49        }
50    }
51}
52
/// Cube indexes flags.
/// When true the corresponding built-in index is declared and computed as
/// needed in the generated kernel.
// NOTE(review): the `_tuple` variants presumably denote the per-axis form of
// their linear counterpart — confirm against the dialect's builtin rules.
#[derive(Debug, Clone, Default)]
pub struct CubeIndexFlags {
    pub absolute_pos: bool,
    pub absolute_pos_tuple: bool,
    pub cube_count: bool,
    pub cube_count_tuple: bool,
    pub cube_dim: bool,
    pub cube_dim_tuple: bool,
    pub cube_pos: bool,
    pub cube_pos_tuple: bool,
    pub plane_dim: bool,
    // Checked variant; enabled whenever a plane operation is compiled.
    pub plane_dim_checked: bool,
    pub plane_index: bool,
    pub unit_pos: bool,
    pub unit_pos_tuple: bool,
    // Enabled by plane scan operations (inclusive/exclusive sum and prod).
    pub unit_pos_plane: bool,
    pub cluster_pos: bool,
}
73
/// Flags gathered during Cube IR translation for the kernel compilation.
#[derive(Debug, Clone)]
pub struct Flags {
    // `elem_*`: element types encountered while compiling the kernel.
    pub elem_fp4: bool,
    pub elem_fp6: bool,
    pub elem_fp8: bool,
    pub elem_bf16: bool,
    pub elem_f16: bool,
    pub elem_tf32: bool,
    /// Built-in cube index variables the kernel needs declared.
    pub indexes: CubeIndexFlags,
    // `op_*` / `inst_*`: operation and instruction families used by the
    // kernel body.
    pub op_barrier: bool,
    pub op_pipeline: bool,
    pub inst_tma: bool,
    pub inst_tma_im2col: bool,
    pub inst_wmma: bool,
    pub inst_ptx_wrappers: bool,
    pub inst_async_copy: bool,
    /// Mirrors `CompilationOptions::supports_features.grid_constants`.
    pub use_grid_constants: bool,
    pub static_meta_length: usize,
    pub has_dynamic_meta: bool,
    pub cube_dim: CubeDim,
    /// Cluster dimensions, when thread-block clusters are requested.
    pub cluster_dim: Option<CubeDim>,
}
97
98#[allow(clippy::too_many_arguments)]
99#[derive(Clone, Debug)]
100pub struct CppCompiler<D: Dialect> {
101    barriers: Vec<BarrierOps<D>>,
102    compilation_options: CompilationOptions,
103    const_arrays: Vec<ConstArray<D>>,
104    ext_meta_positions: Vec<u32>,
105    cluster_dim: CubeDim,
106    extensions: Vec<D::Extension>,
107    flags: Flags,
108    items: HashSet<Item<D>>,
109    local_arrays: Vec<LocalArray<D>>,
110    metadata: cubecl_core::Metadata,
111    pipelines: Vec<PipelineOps<D>>,
112    source_loc: Option<SourceLoc>,
113    strategy: ExecutionMode,
114}
115
116impl Default for Flags {
117    fn default() -> Self {
118        Self {
119            elem_fp4: Default::default(),
120            elem_fp6: Default::default(),
121            elem_fp8: Default::default(),
122            elem_bf16: Default::default(),
123            elem_f16: Default::default(),
124            elem_tf32: Default::default(),
125            indexes: Default::default(),
126            op_barrier: Default::default(),
127            op_pipeline: Default::default(),
128            inst_tma: Default::default(),
129            inst_tma_im2col: Default::default(),
130            inst_wmma: Default::default(),
131            inst_ptx_wrappers: Default::default(),
132            inst_async_copy: Default::default(),
133            use_grid_constants: Default::default(),
134            static_meta_length: Default::default(),
135            has_dynamic_meta: Default::default(),
136            cube_dim: CubeDim::new_single(),
137            cluster_dim: Default::default(),
138        }
139    }
140}
141
142impl<D: Dialect> Default for CppCompiler<D> {
143    fn default() -> Self {
144        Self {
145            barriers: Default::default(),
146            compilation_options: Default::default(),
147            const_arrays: Default::default(),
148            ext_meta_positions: Default::default(),
149            cluster_dim: CubeDim::new_single(),
150            extensions: Default::default(),
151            flags: Flags::default(),
152            items: Default::default(),
153            local_arrays: Default::default(),
154            metadata: Default::default(),
155            pipelines: Default::default(),
156            source_loc: Default::default(),
157            strategy: Default::default(),
158        }
159    }
160}
161
162impl<D: Dialect> Compiler for CppCompiler<D> {
163    type Representation = ComputeKernel<D>;
164    type CompilationOptions = CompilationOptions;
165
166    fn compile(
167        &mut self,
168        mut kernel: KernelDefinition,
169        compilation_options: &Self::CompilationOptions,
170        strategy: ExecutionMode,
171    ) -> Result<Self::Representation, CompilationError> {
172        let errors = kernel.body.pop_errors();
173        if !errors.is_empty() {
174            let mut reason = "Can't compile cpp kernel\nCaused by:\n  ".to_string();
175            for error in errors {
176                reason += error.as_str();
177                reason += "\n";
178            }
179
180            return Err(CompilationError::Validation {
181                reason,
182                backtrace: BackTrace::capture(),
183            });
184        }
185
186        self.compilation_options = compilation_options.clone();
187        self.strategy = strategy;
188
189        if !self.compilation_options.supports_features.clusters {
190            kernel.options.cluster_dim = None;
191        }
192        self.cluster_dim = kernel.options.cluster_dim.unwrap_or(CubeDim::new_single());
193
194        let ir = self.clone().compile_ir(kernel);
195        COUNTER_TMP_VAR.store(0, std::sync::atomic::Ordering::Relaxed);
196        Ok(ir)
197    }
198
199    fn elem_size(&self, elem: gpu::ElemType) -> usize {
200        elem.size()
201    }
202
203    fn extension(&self) -> &'static str {
204        "cpp"
205    }
206}
207
208impl<D: Dialect> CppCompiler<D> {
    /// Lowers a full [`KernelDefinition`] into a dialect-specific
    /// [`ComputeKernel`], consuming the compiler state.
    ///
    /// Ordering matters: `compile_scope` runs first and populates
    /// `self.flags`, `self.pipelines`, `self.barriers`, `self.const_arrays`,
    /// etc., which are read when assembling the flags and body below.
    fn compile_ir(mut self, value: KernelDefinition) -> ComputeKernel<D> {
        self.build_metadata(&value);

        // The body is cloned because the original is consumed by the
        // shared-memory optimizer further down.
        let instructions = self.compile_scope(&mut value.body.clone());
        let tensor_maps = value
            .tensor_maps
            .into_iter()
            .map(|b| self.compile_binding(b))
            .collect();
        let buffers = value
            .buffers
            .into_iter()
            .map(|b| self.compile_binding(b))
            .collect();
        let scalars = value
            .scalars
            .into_iter()
            .map(|binding| (self.compile_storage_type(binding.ty), binding.count))
            .collect();

        // translation flags — copied from the state gathered during
        // `compile_scope`, plus metadata/options known up front.
        let flags = Flags {
            indexes: D::builtin_rules(&self.flags.indexes),
            inst_wmma: self.flags.inst_wmma,
            op_pipeline: self.flags.op_pipeline,
            op_barrier: self.flags.op_barrier,
            elem_fp4: self.flags.elem_fp4,
            elem_fp6: self.flags.elem_fp6,
            elem_fp8: self.flags.elem_fp8,
            elem_bf16: self.flags.elem_bf16,
            elem_f16: self.flags.elem_f16,
            elem_tf32: self.flags.elem_tf32,
            inst_tma: self.flags.inst_tma,
            inst_tma_im2col: self.flags.inst_tma_im2col,
            inst_async_copy: self.flags.inst_async_copy,
            inst_ptx_wrappers: self.flags.inst_ptx_wrappers,
            use_grid_constants: self.compilation_options.supports_features.grid_constants,
            // TODO: At some point we should only pass dynamic meta if tensors are present,
            // not if only arrays are present. For now, leave like this
            has_dynamic_meta: self.metadata.static_len() > 0,
            static_meta_length: self.metadata.static_len() as usize,
            cube_dim: value.cube_dim,
            cluster_dim: value.options.cluster_dim,
        };

        // Run the shared-memory-only optimizer to compute liveness-based
        // allocations (each allocation carries an offset into the pool).
        let mut opt = Optimizer::shared_only(value.body, value.cube_dim);
        let shared_allocs = opt.analysis::<SharedLiveness>();
        let shared_memories = shared_allocs
            .allocations
            .values()
            .map(|alloc| match alloc.smem {
                cubecl_opt::SharedMemory::Array {
                    id,
                    length,
                    ty,
                    align,
                } => SharedMemory::Array {
                    index: id,
                    item: self.compile_type(ty),
                    length,
                    align,
                    offset: alloc.offset,
                },
                cubecl_opt::SharedMemory::Value { id, ty, align } => SharedMemory::Value {
                    index: id,
                    item: self.compile_type(ty),
                    align,
                    offset: alloc.offset,
                },
            })
            .collect();

        let body = Body {
            instructions,
            shared_memories,
            pipelines: self.pipelines,
            barriers: self.barriers,
            const_arrays: self.const_arrays,
            local_arrays: self.local_arrays,
        };

        // Clusters are dropped when unsupported (mirrors the check in
        // `Compiler::compile`, which may not have run when called directly).
        let mut cluster_dim = value.options.cluster_dim;
        if !self.compilation_options.supports_features.clusters {
            cluster_dim = None;
        }

        ComputeKernel {
            tensor_maps,
            buffers,
            scalars,
            meta_static_len: self.metadata.static_len() as usize,
            cube_dim: value.cube_dim,
            body,
            extensions: self.extensions,
            flags,
            items: self.items,
            kernel_name: value.options.kernel_name,
            cluster_dim,
        }
    }
309
310    fn build_metadata(&mut self, value: &KernelDefinition) {
311        let mut num_ext = 0;
312
313        let mut all_meta: Vec<_> = value
314            .buffers
315            .iter()
316            .chain(value.tensor_maps.iter())
317            .map(|buf| (buf.id, buf.has_extended_meta))
318            .collect();
319
320        all_meta.sort_by_key(|(id, _)| *id);
321
322        for (_, has_extended_meta) in &all_meta {
323            self.ext_meta_positions.push(num_ext);
324            if *has_extended_meta {
325                num_ext += 1;
326            }
327        }
328
329        let num_meta = all_meta.len();
330
331        self.metadata = cubecl_core::Metadata::new(num_meta as u32, num_ext);
332    }
333
334    pub(crate) fn ext_meta_position(&self, var: gpu::Variable) -> u32 {
335        let id = var.index().expect("Variable should have index");
336        self.ext_meta_positions[id as usize]
337    }
338
339    fn compile_scope(&mut self, scope: &mut gpu::Scope) -> Vec<Instruction<D>> {
340        let mut instructions = Vec::new();
341
342        let const_arrays = scope
343            .const_arrays
344            .drain(..)
345            .map(|(var, values)| ConstArray {
346                index: var.index().unwrap(),
347                item: self.compile_type(var.ty),
348                size: values.len() as u32,
349                values: values
350                    .into_iter()
351                    .map(|val| self.compile_variable(val))
352                    .collect(),
353            })
354            .collect::<Vec<_>>();
355        self.const_arrays.extend(const_arrays);
356
357        let checked_io: Box<dyn Processor> = Box::new(CheckedIoProcessor::new(self.strategy));
358        let dialect_processors = D::processors();
359        let mut processors: Vec<&dyn Processor> = vec![&*checked_io];
360        processors.extend(dialect_processors.iter().map(|it| &**it));
361
362        let processing = scope.process(processors);
363
364        for var in processing.variables {
365            instructions.push(Instruction::DeclareVariable {
366                var: self.compile_variable(var),
367            });
368        }
369
370        processing
371            .instructions
372            .into_iter()
373            .for_each(|op| self.compile_instruction(&mut instructions, op));
374
375        instructions
376    }
377
378    fn compile_instruction(
379        &mut self,
380        instructions: &mut Vec<Instruction<D>>,
381        instruction: gpu::Instruction,
382    ) {
383        self.update_debug_loc(instructions, &instruction);
384        let out = instruction.out;
385        match instruction.operation {
386            gpu::Operation::Copy(variable) => {
387                instructions.push(Instruction::Assign(UnaryInstruction {
388                    input: self.compile_variable(variable),
389                    out: self.compile_variable(out.unwrap()),
390                }));
391            }
392            gpu::Operation::Arithmetic(op) => {
393                self.compile_arithmetic(op, out, instruction.modes, instructions)
394            }
395            gpu::Operation::Comparison(op) => self.compile_comparison(op, out, instructions),
396            gpu::Operation::Bitwise(op) => self.compile_bitwise(op, out, instructions),
397            gpu::Operation::Operator(op) => self.compile_operator(op, out, instructions),
398            gpu::Operation::Atomic(op) => self.compile_atomic(op, out, instructions),
399            gpu::Operation::Metadata(op) => instructions.push(self.compile_metadata(op, out)),
400            gpu::Operation::Branch(val) => self.compile_branch(instructions, val),
401            gpu::Operation::Synchronization(val) => match val {
402                gpu::Synchronization::SyncCube => instructions.push(Instruction::SyncThreads),
403                gpu::Synchronization::SyncPlane => instructions.push(Instruction::SyncWarp),
404                gpu::Synchronization::SyncStorage => instructions.push(Instruction::SyncThreads),
405                gpu::Synchronization::SyncAsyncProxyShared => {
406                    self.flags.inst_tma = true;
407                    instructions.push(Instruction::ProxyAsyncToSharedFence)
408                }
409            },
410            gpu::Operation::Plane(op) => {
411                self.flags.indexes.plane_dim_checked = true;
412                let out = self.compile_variable(out.unwrap());
413                match op {
414                    gpu::Plane::Sum(op) => {
415                        let instruction = WarpInstruction::ReduceSum {
416                            input: self.compile_variable(op.input),
417                            out,
418                        };
419                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
420                        instructions.push(Instruction::Warp(instruction));
421                    }
422                    gpu::Plane::InclusiveSum(op) => {
423                        self.flags.indexes.unit_pos_plane = true;
424                        instructions.push(Instruction::Warp(WarpInstruction::InclusiveSum {
425                            input: self.compile_variable(op.input),
426                            out,
427                        }))
428                    }
429                    gpu::Plane::InclusiveProd(op) => {
430                        self.flags.indexes.unit_pos_plane = true;
431                        instructions.push(Instruction::Warp(WarpInstruction::InclusiveProd {
432                            input: self.compile_variable(op.input),
433                            out,
434                        }))
435                    }
436                    gpu::Plane::ExclusiveSum(op) => {
437                        self.flags.indexes.unit_pos_plane = true;
438                        instructions.push(Instruction::Warp(WarpInstruction::ExclusiveSum {
439                            input: self.compile_variable(op.input),
440                            out,
441                        }))
442                    }
443                    gpu::Plane::ExclusiveProd(op) => {
444                        self.flags.indexes.unit_pos_plane = true;
445                        instructions.push(Instruction::Warp(WarpInstruction::ExclusiveProd {
446                            input: self.compile_variable(op.input),
447                            out,
448                        }))
449                    }
450                    gpu::Plane::Prod(op) => {
451                        let instruction = WarpInstruction::ReduceProd {
452                            input: self.compile_variable(op.input),
453                            out,
454                        };
455                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
456                        instructions.push(Instruction::Warp(instruction))
457                    }
458                    gpu::Plane::Max(op) => {
459                        let instruction = WarpInstruction::ReduceMax {
460                            input: self.compile_variable(op.input),
461                            out,
462                        };
463                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
464                        instructions.push(Instruction::Warp(instruction))
465                    }
466                    gpu::Plane::Min(op) => {
467                        let instruction = WarpInstruction::ReduceMin {
468                            input: self.compile_variable(op.input),
469                            out,
470                        };
471                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
472                        instructions.push(Instruction::Warp(instruction))
473                    }
474                    gpu::Plane::Elect => {
475                        if self.compilation_options.supports_features.elect_sync {
476                            self.flags.inst_ptx_wrappers = true;
477                            instructions.push(Instruction::Warp(WarpInstruction::Elect { out }))
478                        } else {
479                            instructions
480                                .push(Instruction::Warp(WarpInstruction::ElectFallback { out }))
481                        }
482                    }
483                    gpu::Plane::All(op) => {
484                        instructions.push(Instruction::Warp(WarpInstruction::All {
485                            input: self.compile_variable(op.input),
486                            out,
487                        }))
488                    }
489                    gpu::Plane::Any(op) => {
490                        instructions.push(Instruction::Warp(WarpInstruction::Any {
491                            input: self.compile_variable(op.input),
492                            out,
493                        }))
494                    }
495                    gpu::Plane::Ballot(op) => {
496                        instructions.push(Instruction::Warp(WarpInstruction::Ballot {
497                            input: self.compile_variable(op.input),
498                            out,
499                        }))
500                    }
501                    gpu::Plane::Broadcast(op) => {
502                        instructions.push(Instruction::Warp(WarpInstruction::Broadcast {
503                            input: self.compile_variable(op.lhs),
504                            id: self.compile_variable(op.rhs),
505                            out,
506                        }))
507                    }
508                    gpu::Plane::Shuffle(op) => {
509                        instructions.push(Instruction::Warp(WarpInstruction::Shuffle {
510                            input: self.compile_variable(op.lhs),
511                            src_lane: self.compile_variable(op.rhs),
512                            out,
513                        }))
514                    }
515                    gpu::Plane::ShuffleXor(op) => {
516                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleXor {
517                            input: self.compile_variable(op.lhs),
518                            mask: self.compile_variable(op.rhs),
519                            out,
520                        }))
521                    }
522                    gpu::Plane::ShuffleUp(op) => {
523                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleUp {
524                            input: self.compile_variable(op.lhs),
525                            delta: self.compile_variable(op.rhs),
526                            out,
527                        }))
528                    }
529                    gpu::Plane::ShuffleDown(op) => {
530                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleDown {
531                            input: self.compile_variable(op.lhs),
532                            delta: self.compile_variable(op.rhs),
533                            out,
534                        }))
535                    }
536                }
537            }
538            gpu::Operation::CoopMma(cmma) => instructions.push(self.compile_cmma(cmma, out)),
539            gpu::Operation::NonSemantic(debug) => match debug {
540                gpu::NonSemantic::Print {
541                    format_string,
542                    args,
543                } => instructions.push(Instruction::Printf {
544                    format_string,
545                    args: args
546                        .into_iter()
547                        .map(|arg| self.compile_variable(arg))
548                        .collect(),
549                }),
550                gpu::NonSemantic::Comment { content } => {
551                    instructions.push(Instruction::Comment { content })
552                }
553                // Don't need to handle scopes
554                _ => {}
555            },
556            gpu::Operation::Barrier(barrier_ops) => match barrier_ops {
557                gpu::BarrierOps::Declare { barrier } => {
558                    let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type()
559                    else {
560                        unreachable!()
561                    };
562                    let barrier = self.compile_variable(barrier);
563                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Declare {
564                        barrier,
565                        level,
566                    }));
567                }
568                gpu::BarrierOps::Init {
569                    barrier,
570                    is_elected,
571                    arrival_count,
572                } => {
573                    let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type()
574                    else {
575                        unreachable!()
576                    };
577                    let barrier = self.compile_variable(barrier);
578                    let arrival_count = self.compile_variable(arrival_count);
579                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Init {
580                        barrier,
581                        is_elected: self.compile_variable(is_elected),
582                        arrival_count,
583                        level,
584                    }));
585                }
586                gpu::BarrierOps::InitManual {
587                    barrier,
588                    arrival_count,
589                } => {
590                    let barrier = self.compile_variable(barrier);
591                    let arrival_count = self.compile_variable(arrival_count);
592                    instructions.push(Instruction::Barrier(
593                        super::barrier::BarrierOps::InitManual {
594                            barrier,
595                            arrival_count,
596                        },
597                    ));
598                }
599                gpu::BarrierOps::MemCopyAsync {
600                    barrier,
601                    source,
602                    source_length,
603                    offset_source,
604                    offset_out,
605                } => {
606                    instructions.push(Instruction::Barrier(
607                        super::barrier::BarrierOps::MemCopyAsync {
608                            barrier: self.compile_variable(barrier),
609                            source: self.compile_variable(source),
610                            destination: self.compile_variable(out.unwrap()),
611                            source_length: self.compile_variable(source_length),
612                            offset_source: self.compile_variable(offset_source),
613                            offset_out: self.compile_variable(offset_out),
614                            cooperative: false,
615                        },
616                    ));
617                }
618                gpu::BarrierOps::MemCopyAsyncCooperative {
619                    barrier,
620                    source,
621                    source_length,
622                    offset_source,
623                    offset_out,
624                } => {
625                    instructions.push(Instruction::Barrier(
626                        super::barrier::BarrierOps::MemCopyAsync {
627                            barrier: self.compile_variable(barrier),
628                            source: self.compile_variable(source),
629                            destination: self.compile_variable(out.unwrap()),
630                            source_length: self.compile_variable(source_length),
631                            offset_source: self.compile_variable(offset_source),
632                            offset_out: self.compile_variable(offset_out),
633                            cooperative: true,
634                        },
635                    ));
636                }
637                gpu::BarrierOps::MemCopyAsyncTx {
638                    barrier,
639                    source,
640                    source_length,
641                    offset_source,
642                    offset_out,
643                } => {
644                    instructions.push(Instruction::Barrier(
645                        super::barrier::BarrierOps::MemCopyAsyncTx {
646                            barrier: self.compile_variable(barrier),
647                            source: self.compile_variable(source),
648                            destination: self.compile_variable(out.unwrap()),
649                            source_length: self.compile_variable(source_length),
650                            offset_source: self.compile_variable(offset_source),
651                            offset_out: self.compile_variable(offset_out),
652                        },
653                    ));
654                }
655                gpu::BarrierOps::CopyAsync {
656                    source,
657                    source_length,
658                    offset_source,
659                    offset_out,
660                    copy_length,
661                    checked,
662                } => {
663                    self.flags.inst_async_copy = true;
664                    instructions.push(Instruction::Barrier(
665                        super::barrier::BarrierOps::CopyAsync {
666                            source: self.compile_variable(source),
667                            destination: self.compile_variable(out.unwrap()),
668                            source_length: self.compile_variable(source_length),
669                            offset_source: self.compile_variable(offset_source),
670                            offset_out: self.compile_variable(offset_out),
671                            copy_size: copy_length,
672                            checked,
673                        },
674                    ));
675                }
676                gpu::BarrierOps::TmaLoad {
677                    barrier,
678                    tensor_map,
679                    offset_out,
680                    indices,
681                } => {
682                    instructions.push(Instruction::Barrier(
683                        super::barrier::BarrierOps::MemCopyAsyncTensorGlobalToShared {
684                            barrier: self.compile_variable(barrier),
685                            smem_buffer: self.compile_variable(out.unwrap()),
686                            smem_offset: self.compile_variable(offset_out),
687                            tensor_map: self.compile_variable(tensor_map),
688                            indices: indices
689                                .into_iter()
690                                .map(|it| self.compile_variable(it))
691                                .collect(),
692                        },
693                    ));
694                }
695                gpu::BarrierOps::TmaLoadIm2col {
696                    barrier,
697                    tensor_map,
698                    offset_out,
699                    indices,
700                    offsets,
701                } => {
702                    self.flags.inst_tma_im2col = true;
703                    instructions.push(Instruction::Barrier(
704                        super::barrier::BarrierOps::TmaLoadIm2col {
705                            barrier: self.compile_variable(barrier),
706                            smem_buffer: self.compile_variable(out.unwrap()),
707                            smem_offset: self.compile_variable(offset_out),
708                            tensor_map: self.compile_variable(tensor_map),
709                            indices: indices
710                                .into_iter()
711                                .map(|it| self.compile_variable(it))
712                                .collect(),
713                            offsets: offsets
714                                .into_iter()
715                                .map(|it| self.compile_variable(it))
716                                .collect(),
717                        },
718                    ));
719                }
720                gpu::BarrierOps::Arrive { barrier } => {
721                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Arrive {
722                        barrier: self.compile_variable(barrier),
723                        token: self.compile_variable(out.unwrap()),
724                    }))
725                }
726                gpu::BarrierOps::ArriveTx {
727                    barrier,
728                    arrive_count_update,
729                    transaction_count_update,
730                } => {
731                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::ArriveTx {
732                        barrier: self.compile_variable(barrier),
733                        token: self.compile_variable(out.unwrap()),
734                        arrive_count_update: self.compile_variable(arrive_count_update),
735                        transaction_count_update: self.compile_variable(transaction_count_update),
736                    }))
737                }
738                gpu::BarrierOps::CommitCopyAsync { barrier } => {
739                    self.flags.inst_async_copy = true;
740                    instructions.push(Instruction::Barrier(
741                        super::barrier::BarrierOps::ArriveCopyAsync {
742                            barrier: self.compile_variable(barrier),
743                        },
744                    ))
745                }
746                gpu::BarrierOps::ExpectTx {
747                    barrier,
748                    transaction_count_update,
749                } => {
750                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::ExpectTx {
751                        barrier: self.compile_variable(barrier),
752                        transaction_count_update: self.compile_variable(transaction_count_update),
753                    }))
754                }
755                gpu::BarrierOps::Wait { barrier, token } => {
756                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Wait {
757                        barrier: self.compile_variable(barrier),
758                        token: self.compile_variable(token),
759                    }))
760                }
761                gpu::BarrierOps::WaitParity { barrier, phase } => instructions.push(
762                    Instruction::Barrier(super::barrier::BarrierOps::WaitParity {
763                        barrier: self.compile_variable(barrier),
764                        phase: self.compile_variable(phase),
765                    }),
766                ),
767                gpu::BarrierOps::ArriveAndWait { barrier } => {
768                    let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type()
769                    else {
770                        unreachable!()
771                    };
772                    instructions.push(Instruction::Barrier(
773                        super::barrier::BarrierOps::ArriveAndWait {
774                            barrier: self.compile_variable(barrier),
775                            level,
776                        },
777                    ))
778                }
779            },
780            gpu::Operation::Tma(tma_ops) => {
781                self.flags.inst_tma = true;
782                match tma_ops {
783                    gpu::TmaOps::TmaStore {
784                        source,
785                        coordinates,
786                        offset_source,
787                    } => {
788                        instructions.push(Instruction::MemCopyAsyncTensorSharedToGlobal {
789                            smem_buffer: self.compile_variable(source),
790                            smem_offset: self.compile_variable(offset_source),
791                            tensor_map: self.compile_variable(out.unwrap()),
792                            indices: coordinates
793                                .into_iter()
794                                .map(|it| self.compile_variable(it))
795                                .collect(),
796                        });
797                    }
798                    gpu::TmaOps::CommitGroup => {
799                        instructions.push(Instruction::BulkCommitGroup);
800                    }
801                    gpu::TmaOps::WaitGroup { max_pending } => {
802                        instructions.push(Instruction::BulkWaitGroup { max_pending });
803                    }
804                    gpu::TmaOps::WaitGroupRead { max_pending } => {
805                        instructions.push(Instruction::BulkWaitGroupRead { max_pending });
806                    }
807                }
808            }
809            gpu::Operation::Marker(_) => {}
810        }
811    }
812
813    fn update_debug_loc(
814        &mut self,
815        instructions: &mut Vec<Instruction<D>>,
816        inst: &gpu::Instruction,
817    ) {
818        if !matches!(inst.operation, Operation::NonSemantic(_)) {
819            match &inst.source_loc {
820                Some(loc) if Some(loc) != self.source_loc.as_ref() => {
821                    self.source_loc = Some(loc.clone());
822                    instructions.push(Instruction::Line {
823                        file: loc.source.file.clone(),
824                        line: loc.line,
825                    });
826                }
827                _ => {}
828            }
829        }
830    }
831
    /// Lowers a cooperative-matrix (CMMA) operation to a WMMA instruction.
    ///
    /// Sets the `inst_wmma` flag so the kernel enables WMMA support, and
    /// registers any dialect-specific extension the produced instruction
    /// requires via `D::register_wmma_instruction_extension`.
    ///
    /// # Panics
    /// Panics when `out` is `None`, when a `Store` has no layout, or on
    /// `RowIndex`/`ColIndex` (those variants must be eliminated by processors
    /// before compilation reaches this point).
    fn compile_cmma(&mut self, cmma: gpu::CoopMma, out: Option<gpu::Variable>) -> Instruction<D> {
        self.flags.inst_wmma = true;

        // Every CMMA op writes to a destination; compile it once up front.
        let out = self.compile_variable(out.unwrap());

        let inst = match cmma {
            gpu::CoopMma::Fill { value } => WmmaInstruction::Fill {
                frag: out,
                value: self.compile_variable(value),
            },
            gpu::CoopMma::Load {
                value,
                stride,
                offset,
                layout,
            } => WmmaInstruction::Load {
                frag: out,
                offset: self.compile_variable(offset),
                value: self.compile_variable(value),
                stride: self.compile_variable(stride),
                // Layout is optional for loads; compile it only when present.
                layout: layout.and_then(|l| self.compile_matrix_layout(l)),
            },
            gpu::CoopMma::Execute {
                mat_a,
                mat_b,
                mat_c,
            } => WmmaInstruction::Execute {
                frag_a: self.compile_variable(mat_a),
                frag_b: self.compile_variable(mat_b),
                frag_c: self.compile_variable(mat_c),
                frag_d: out,
                warp_size: self.compilation_options.warp_size,
            },
            gpu::CoopMma::ExecuteManual {
                matrix,
                registers_a,
                registers_b,
                registers_c,
            } => WmmaInstruction::ExecuteManual {
                shape: MmaShape::new(matrix.m, matrix.n, matrix.k),
                frag_a: self.compile_variable(registers_a),
                frag_b: self.compile_variable(registers_b),
                frag_c: self.compile_variable(registers_c),
                frag_d: out,
            },
            gpu::CoopMma::ExecuteScaled {
                matrix,
                registers_a,
                registers_b,
                registers_c,
                scales_a,
                scales_b,
                scales_factor,
            } => WmmaInstruction::ExecuteScaled {
                shape: MmaShape::new(matrix.m, matrix.n, matrix.k),
                frag_a: self.compile_variable(registers_a),
                frag_b: self.compile_variable(registers_b),
                frag_c: self.compile_variable(registers_c),
                frag_d: out,

                scales_a: self.compile_variable(scales_a),
                scales_b: self.compile_variable(scales_b),
                scales_factor,
            },
            gpu::CoopMma::Store {
                mat,
                stride,
                offset,
                layout,
            } => {
                // The store lowering references the unit position and plane
                // index, so make sure those cube indexes get declared.
                self.flags.indexes.unit_pos = true;
                self.flags.indexes.plane_index = true;
                WmmaInstruction::Store {
                    output: out,
                    offset: self.compile_variable(offset),
                    frag: self.compile_variable(mat),
                    stride: self.compile_variable(stride),
                    // Unlike Load, Store requires an explicit layout.
                    layout: self
                        .compile_matrix_layout(layout)
                        .expect("Layout required for store instruction"),
                }
            }
            gpu::CoopMma::LoadMatrix {
                buffer,
                offset,
                line_size,
                factor,
                transpose,
            } => WmmaInstruction::LdMatrix {
                output: out,
                buffer: self.compile_variable(buffer),
                offset: self.compile_variable(offset),
                line_size,
                factor,
                transpose,
            },
            gpu::CoopMma::StoreMatrix {
                offset,
                line_size,
                registers,
                factor,
                transpose,
            } => WmmaInstruction::StMatrix {
                registers: self.compile_variable(registers),
                buffer: out,
                offset: self.compile_variable(offset),
                line_size,
                factor,
                transpose,
            },
            gpu::CoopMma::Cast { input } => WmmaInstruction::Cast {
                input: self.compile_variable(input),
                output: out,
            },
            gpu::CoopMma::RowIndex { .. } | gpu::CoopMma::ColIndex { .. } => {
                panic!("Row/Col index should be handled by processors")
            }
        };

        D::register_wmma_instruction_extension(&mut self.extensions, &inst);

        Instruction::Wmma(inst)
    }
955
956    fn compile_metadata(
957        &mut self,
958        metadata: gpu::Metadata,
959        out: Option<gpu::Variable>,
960    ) -> Instruction<D> {
961        let out = out.unwrap();
962        match metadata {
963            gpu::Metadata::Stride { dim, var } => {
964                let position = self.ext_meta_position(var);
965                let offset = self.metadata.stride_offset_index(position);
966                Instruction::ExtendedMetadata {
967                    info_offset: self.compile_variable(offset.into()),
968                    dim: self.compile_variable(dim),
969                    split_meta: self.compilation_options.supports_features.grid_constants,
970                    static_offset: self.metadata.static_len(),
971                    out: self.compile_variable(out),
972                }
973            }
974            gpu::Metadata::Shape { dim, var } => {
975                let position = self.ext_meta_position(var);
976                let offset = self.metadata.shape_offset_index(position);
977                Instruction::ExtendedMetadata {
978                    info_offset: self.compile_variable(offset.into()),
979                    dim: self.compile_variable(dim),
980                    split_meta: self.compilation_options.supports_features.grid_constants,
981                    static_offset: self.metadata.static_len(),
982                    out: self.compile_variable(out),
983                }
984            }
985            gpu::Metadata::Rank { var } => {
986                let out = self.compile_variable(out);
987                let pos = self.ext_meta_position(var);
988                let offset = self.metadata.rank_index(pos);
989                super::Instruction::Metadata {
990                    info_offset: self.compile_variable(offset.into()),
991                    split_meta: self.compilation_options.supports_features.grid_constants,
992                    out,
993                }
994            }
995            gpu::Metadata::Length { var } => {
996                let input = self.compile_variable(var);
997                let out = self.compile_variable(out);
998
999                match input {
1000                    Variable::Slice { .. } => Instruction::SliceLength { input, out },
1001                    Variable::SharedArray(_id, _item, length) => {
1002                        Instruction::ConstLength { length, out }
1003                    }
1004                    _ => {
1005                        let id = input.id().expect("Variable should have id");
1006                        let offset = self.metadata.len_index(id);
1007                        Instruction::Metadata {
1008                            info_offset: self.compile_variable(offset.into()),
1009                            split_meta: self.compilation_options.supports_features.grid_constants,
1010                            out,
1011                        }
1012                    }
1013                }
1014            }
1015            gpu::Metadata::BufferLength { var } => {
1016                let input = self.compile_variable(var);
1017                let out = self.compile_variable(out);
1018
1019                match input {
1020                    Variable::Slice { .. } => Instruction::SliceLength { input, out },
1021                    _ => {
1022                        let id = input.id().expect("Variable should have id");
1023                        let offset = self.metadata.buffer_len_index(id);
1024                        Instruction::Metadata {
1025                            info_offset: self.compile_variable(offset.into()),
1026                            split_meta: self.compilation_options.supports_features.grid_constants,
1027                            out,
1028                        }
1029                    }
1030                }
1031            }
1032        }
1033    }
1034
1035    fn compile_branch(&mut self, instructions: &mut Vec<Instruction<D>>, branch: gpu::Branch) {
1036        match branch {
1037            gpu::Branch::If(mut op) => instructions.push(Instruction::If {
1038                cond: self.compile_variable(op.cond),
1039                instructions: self.compile_scope(&mut op.scope),
1040            }),
1041            gpu::Branch::IfElse(mut op) => instructions.push(Instruction::IfElse {
1042                cond: self.compile_variable(op.cond),
1043                instructions_if: self.compile_scope(&mut op.scope_if),
1044                instructions_else: self.compile_scope(&mut op.scope_else),
1045            }),
1046            gpu::Branch::Switch(mut op) => instructions.push(Instruction::Switch {
1047                value: self.compile_variable(op.value),
1048                instructions_default: self.compile_scope(&mut op.scope_default),
1049                instructions_cases: op
1050                    .cases
1051                    .into_iter()
1052                    .map(|(val, mut block)| {
1053                        (self.compile_variable(val), self.compile_scope(&mut block))
1054                    })
1055                    .collect(),
1056            }),
1057            gpu::Branch::Return => instructions.push(Instruction::Return),
1058            gpu::Branch::Break => instructions.push(Instruction::Break),
1059            gpu::Branch::RangeLoop(mut range_loop) => instructions.push(Instruction::RangeLoop {
1060                i: self.compile_variable(range_loop.i),
1061                start: self.compile_variable(range_loop.start),
1062                end: self.compile_variable(range_loop.end),
1063                step: range_loop.step.map(|it| self.compile_variable(it)),
1064                inclusive: range_loop.inclusive,
1065                instructions: self.compile_scope(&mut range_loop.scope),
1066            }),
1067            gpu::Branch::Loop(mut op) => instructions.push(Instruction::Loop {
1068                instructions: self.compile_scope(&mut op.scope),
1069            }),
1070        };
1071    }
1072
1073    fn compile_atomic(
1074        &mut self,
1075        value: gpu::AtomicOp,
1076        out: Option<gpu::Variable>,
1077        instructions: &mut Vec<Instruction<D>>,
1078    ) {
1079        let out = out.unwrap();
1080        match value {
1081            gpu::AtomicOp::Load(op) => {
1082                instructions.push(Instruction::AtomicLoad(self.compile_unary(op, out)))
1083            }
1084            gpu::AtomicOp::Store(op) => {
1085                instructions.push(Instruction::AtomicStore(self.compile_unary(op, out)))
1086            }
1087            gpu::AtomicOp::Swap(op) => {
1088                instructions.push(Instruction::AtomicSwap(self.compile_binary(op, out)))
1089            }
1090            gpu::AtomicOp::Add(op) => {
1091                instructions.push(Instruction::AtomicAdd(self.compile_binary(op, out)))
1092            }
1093            gpu::AtomicOp::Sub(op) => {
1094                instructions.push(Instruction::AtomicSub(self.compile_binary(op, out)))
1095            }
1096            gpu::AtomicOp::Max(op) => {
1097                instructions.push(Instruction::AtomicMax(self.compile_binary(op, out)))
1098            }
1099            gpu::AtomicOp::Min(op) => {
1100                instructions.push(Instruction::AtomicMin(self.compile_binary(op, out)))
1101            }
1102            gpu::AtomicOp::And(op) => {
1103                instructions.push(Instruction::AtomicAnd(self.compile_binary(op, out)))
1104            }
1105            gpu::AtomicOp::Or(op) => {
1106                instructions.push(Instruction::AtomicOr(self.compile_binary(op, out)))
1107            }
1108            gpu::AtomicOp::Xor(op) => {
1109                instructions.push(Instruction::AtomicXor(self.compile_binary(op, out)))
1110            }
1111            gpu::AtomicOp::CompareAndSwap(op) => instructions.push(Instruction::AtomicCAS {
1112                input: self.compile_variable(op.input),
1113                cmp: self.compile_variable(op.cmp),
1114                val: self.compile_variable(op.val),
1115                out: self.compile_variable(out),
1116            }),
1117        }
1118    }
1119
    /// Lowers an arithmetic operation into one or more instructions.
    ///
    /// Transcendental and division-like ops go through `select_fast_float`,
    /// which swaps in a fast approximate variant when the instruction's
    /// `modes` grant the required `FastMath` flags (f32-only, and only when
    /// the target supports fast math). Instructions that may need a
    /// dialect-specific extension are registered via
    /// `D::register_instruction_extension`.
    ///
    /// # Panics
    /// Panics when `out` is `None`.
    fn compile_arithmetic(
        &mut self,
        value: gpu::Arithmetic,
        out: Option<gpu::Variable>,
        modes: InstructionModes,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::Arithmetic::Add(op) => {
                instructions.push(Instruction::Add(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::SaturatingAdd(op) => {
                instructions.push(Instruction::SaturatingAdd(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Mul(op) => {
                instructions.push(Instruction::Mul(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Div(op) => {
                let op = self.compile_binary(op, out);
                // Fast division trades precision and inf/zero-sign handling
                // for speed, so all four flags must be granted.
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::AllowReciprocal
                        | FastMath::ReducedPrecision
                        | FastMath::UnsignedZero
                        | FastMath::NotInf,
                    Instruction::Div(op),
                    Instruction::FastDiv(op),
                ))
            }
            gpu::Arithmetic::Sub(op) => {
                instructions.push(Instruction::Sub(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::SaturatingSub(op) => {
                instructions.push(Instruction::SaturatingSub(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::MulHi(op) => {
                let instruction = Instruction::HiMul(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Modulo(op) => {
                instructions.push(Instruction::Modulo(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Abs(op) => {
                instructions.push(Instruction::Abs(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Exp(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Exp(op),
                    Instruction::FastExp(op),
                ));
            }
            gpu::Arithmetic::Log(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Log(op),
                    Instruction::FastLog(op),
                ));
            }
            gpu::Arithmetic::Log1p(op) => {
                instructions.push(Instruction::Log1p(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Cos(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Cos(op),
                    Instruction::FastCos(op),
                ));
            }
            gpu::Arithmetic::Sin(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Sin(op),
                    Instruction::FastSin(op),
                ));
            }
            gpu::Arithmetic::Tan(op) => {
                instructions.push(Instruction::Tan(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Tanh(op) => {
                let op = self.compile_unary(op, out);
                let instruction = Instruction::Tanh(op);
                D::register_instruction_extension(&mut self.extensions, &instruction);
                // Only targets that expose a fast tanh can take the fast
                // path; otherwise always emit the precise instruction.
                if self.compilation_options.supports_features.fast_tanh {
                    instructions.push(self.select_fast_float(
                        out.ty,
                        modes,
                        FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                        instruction,
                        Instruction::FastTanh(op),
                    ))
                } else {
                    instructions.push(instruction);
                }
            }
            gpu::Arithmetic::Sinh(op) => {
                let instruction = Instruction::Sinh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Cosh(op) => {
                let instruction = Instruction::Cosh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcCos(op) => {
                let instruction = Instruction::ArcCos(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcSin(op) => {
                let instruction = Instruction::ArcSin(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcTan(op) => {
                let instruction = Instruction::ArcTan(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcSinh(op) => {
                let instruction = Instruction::ArcSinh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcCosh(op) => {
                let instruction = Instruction::ArcCosh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcTanh(op) => {
                let instruction = Instruction::ArcTanh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Degrees(op) => {
                let instruction = Instruction::Degrees(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Radians(op) => {
                let instruction = Instruction::Radians(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcTan2(op) => {
                let instruction = Instruction::ArcTan2(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Powf(op) => {
                let op = self.compile_binary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Powf(op),
                    Instruction::FastPowf(op),
                ))
            }
            gpu::Arithmetic::Powi(op) => {
                instructions.push(Instruction::Powi(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Hypot(op) => {
                instructions.push(Instruction::Hypot(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Rhypot(op) => {
                instructions.push(Instruction::Rhypot(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Sqrt(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Sqrt(op),
                    Instruction::FastSqrt(op),
                ))
            }
            gpu::Arithmetic::InverseSqrt(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::InverseSqrt(op),
                    Instruction::FastInverseSqrt(op),
                ))
            }
            gpu::Arithmetic::Erf(op) => {
                let instruction = Instruction::Erf(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Max(op) => {
                let instruction = Instruction::Max(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Min(op) => {
                let instruction = Instruction::Min(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Clamp(op) => instructions.push(Instruction::Clamp {
                input: self.compile_variable(op.input),
                min_value: self.compile_variable(op.min_value),
                max_value: self.compile_variable(op.max_value),
                out: self.compile_variable(out),
            }),
            gpu::Arithmetic::Recip(op) => {
                let elem = op.input.ty.elem_type();
                let input = self.compile_variable(op.input);
                let out = self.compile_variable(out);
                // Build the constant `1` of the matching element type so the
                // precise fallback can be expressed as `1 / x`.
                let lhs = match elem {
                    gpu::ElemType::Float(kind) => gpu::ConstantScalarValue::Float(1.0, kind),
                    gpu::ElemType::Int(kind) => gpu::ConstantScalarValue::Int(1, kind),
                    gpu::ElemType::UInt(kind) => gpu::ConstantScalarValue::UInt(1, kind),
                    gpu::ElemType::Bool => gpu::ConstantScalarValue::Bool(true),
                };
                let div = Instruction::Div(BinaryInstruction {
                    lhs: Variable::ConstantScalar(lhs, self.compile_elem(elem)),
                    rhs: input,
                    out,
                });
                let recip = Instruction::FastRecip(UnaryInstruction { input, out });

                instructions.push(self.select_fast_float(
                    elem.into(),
                    modes,
                    FastMath::AllowReciprocal
                        | FastMath::ReducedPrecision
                        | FastMath::UnsignedZero
                        | FastMath::NotInf,
                    div,
                    recip,
                ))
            }
            gpu::Arithmetic::Round(op) => {
                instructions.push(Instruction::Round(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Floor(op) => {
                instructions.push(Instruction::Floor(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Ceil(op) => {
                instructions.push(Instruction::Ceil(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Trunc(op) => {
                instructions.push(Instruction::Trunc(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Remainder(op) => {
                instructions.push(Instruction::Remainder(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Fma(op) => instructions.push(Instruction::Fma {
                a: self.compile_variable(op.a),
                b: self.compile_variable(op.b),
                c: self.compile_variable(op.c),
                out: self.compile_variable(out),
            }),
            gpu::Arithmetic::Neg(op) => {
                instructions.push(Instruction::Neg(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Normalize(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Normalize(op),
                    Instruction::FastNormalize(op),
                ))
            }
            gpu::Arithmetic::Magnitude(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Magnitude(op),
                    Instruction::FastMagnitude(op),
                ))
            }
            gpu::Arithmetic::Dot(op) => {
                instructions.push(Instruction::Dot(self.compile_binary(op, out)))
            }
        };
    }
1422
1423    fn select_fast_float(
1424        &self,
1425        ty: gpu::Type,
1426        modes: InstructionModes,
1427        required_flags: EnumSet<FastMath>,
1428        default: Instruction<D>,
1429        fast: Instruction<D>,
1430    ) -> Instruction<D> {
1431        if !self.compilation_options.supports_features.fast_math
1432            || !matches!(ty.elem_type(), ElemType::Float(FloatKind::F32))
1433        {
1434            return default;
1435        }
1436
1437        if modes.fp_math_mode.is_superset(required_flags) {
1438            fast
1439        } else {
1440            default
1441        }
1442    }
1443
1444    fn compile_comparison(
1445        &mut self,
1446        value: gpu::Comparison,
1447        out: Option<gpu::Variable>,
1448        instructions: &mut Vec<Instruction<D>>,
1449    ) {
1450        let out = out.unwrap();
1451        match value {
1452            gpu::Comparison::Equal(op) => {
1453                instructions.push(Instruction::Equal(self.compile_binary(op, out)))
1454            }
1455            gpu::Comparison::Lower(op) => {
1456                instructions.push(Instruction::Lower(self.compile_binary(op, out)))
1457            }
1458            gpu::Comparison::Greater(op) => {
1459                instructions.push(Instruction::Greater(self.compile_binary(op, out)))
1460            }
1461            gpu::Comparison::LowerEqual(op) => {
1462                instructions.push(Instruction::LowerEqual(self.compile_binary(op, out)))
1463            }
1464            gpu::Comparison::GreaterEqual(op) => {
1465                instructions.push(Instruction::GreaterEqual(self.compile_binary(op, out)))
1466            }
1467            gpu::Comparison::NotEqual(op) => {
1468                instructions.push(Instruction::NotEqual(self.compile_binary(op, out)))
1469            }
1470            gpu::Comparison::IsNan(op) => {
1471                instructions.push(Instruction::IsNan(self.compile_unary(op, out)))
1472            }
1473            gpu::Comparison::IsInf(op) => {
1474                instructions.push(Instruction::IsInf(self.compile_unary(op, out)))
1475            }
1476        };
1477    }
1478
1479    fn compile_bitwise(
1480        &mut self,
1481        value: gpu::Bitwise,
1482        out: Option<gpu::Variable>,
1483        instructions: &mut Vec<Instruction<D>>,
1484    ) {
1485        let out = out.unwrap();
1486        match value {
1487            gpu::Bitwise::BitwiseOr(op) => {
1488                instructions.push(Instruction::BitwiseOr(self.compile_binary(op, out)))
1489            }
1490            gpu::Bitwise::BitwiseAnd(op) => {
1491                instructions.push(Instruction::BitwiseAnd(self.compile_binary(op, out)))
1492            }
1493            gpu::Bitwise::BitwiseXor(op) => {
1494                instructions.push(Instruction::BitwiseXor(self.compile_binary(op, out)))
1495            }
1496            gpu::Bitwise::CountOnes(op) => {
1497                instructions.push(Instruction::CountBits(self.compile_unary(op, out)))
1498            }
1499            gpu::Bitwise::ReverseBits(op) => {
1500                instructions.push(Instruction::ReverseBits(self.compile_unary(op, out)))
1501            }
1502            gpu::Bitwise::ShiftLeft(op) => {
1503                instructions.push(Instruction::ShiftLeft(self.compile_binary(op, out)))
1504            }
1505            gpu::Bitwise::ShiftRight(op) => {
1506                instructions.push(Instruction::ShiftRight(self.compile_binary(op, out)))
1507            }
1508            gpu::Bitwise::BitwiseNot(op) => {
1509                instructions.push(Instruction::BitwiseNot(self.compile_unary(op, out)))
1510            }
1511            gpu::Bitwise::LeadingZeros(op) => {
1512                instructions.push(Instruction::LeadingZeros(self.compile_unary(op, out)))
1513            }
1514            gpu::Bitwise::FindFirstSet(op) => {
1515                let instruction = Instruction::FindFirstSet(self.compile_unary(op, out));
1516                D::register_instruction_extension(&mut self.extensions, &instruction);
1517                instructions.push(instruction)
1518            }
1519        };
1520    }
1521
1522    fn compile_operator(
1523        &mut self,
1524        value: gpu::Operator,
1525        out: Option<gpu::Variable>,
1526        instructions: &mut Vec<Instruction<D>>,
1527    ) {
1528        let out = out.unwrap();
1529        match value {
1530            gpu::Operator::Index(op) | gpu::Operator::UncheckedIndex(op) => {
1531                instructions.push(Instruction::Index(self.compile_index(op, out)));
1532            }
1533            gpu::Operator::IndexAssign(op) | gpu::Operator::UncheckedIndexAssign(op) => {
1534                instructions.push(Instruction::IndexAssign(self.compile_index_assign(op, out)));
1535            }
1536            gpu::Operator::And(op) => {
1537                instructions.push(Instruction::And(self.compile_binary(op, out)))
1538            }
1539            gpu::Operator::Or(op) => {
1540                instructions.push(Instruction::Or(self.compile_binary(op, out)))
1541            }
1542            gpu::Operator::Not(op) => {
1543                instructions.push(Instruction::Not(self.compile_unary(op, out)))
1544            }
1545            gpu::Operator::InitLine(op) => instructions.push(Instruction::VecInit {
1546                inputs: op
1547                    .inputs
1548                    .into_iter()
1549                    .map(|it| self.compile_variable(it))
1550                    .collect(),
1551                out: self.compile_variable(out),
1552            }),
1553            gpu::Operator::CopyMemory(op) => instructions.push(Instruction::Copy {
1554                input: self.compile_variable(op.input),
1555                in_index: self.compile_variable(op.in_index),
1556                out: self.compile_variable(out),
1557                out_index: self.compile_variable(op.out_index),
1558            }),
1559            gpu::Operator::CopyMemoryBulk(op) => instructions.push(Instruction::CopyBulk {
1560                input: self.compile_variable(op.input),
1561                in_index: self.compile_variable(op.in_index),
1562                out: self.compile_variable(out),
1563                out_index: self.compile_variable(op.out_index),
1564                len: op.len,
1565            }),
1566            gpu::Operator::Select(op) => instructions.push(Instruction::Select {
1567                cond: self.compile_variable(op.cond),
1568                then: self.compile_variable(op.then),
1569                or_else: self.compile_variable(op.or_else),
1570                out: self.compile_variable(out),
1571            }),
1572            // Needs special conversion semantics
1573            gpu::Operator::Cast(op)
1574                if is_fp4_fp6_fp8(op.input.elem_type()) || is_fp4_fp6_fp8(out.elem_type()) =>
1575            {
1576                // We may need these for intermediates
1577                self.flags.elem_f16 = true;
1578                self.flags.elem_bf16 = true;
1579                let vec_in = op.input.ty.line_size();
1580                let packing = out.storage_type().packing_factor();
1581                self.compile_type(op.input.ty.line(packing));
1582                self.compile_type(
1583                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::F16)).line(vec_in),
1584                );
1585                self.compile_type(
1586                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::BF16)).line(vec_in),
1587                );
1588                self.compile_type(
1589                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::F16)).line(packing),
1590                );
1591                self.compile_type(
1592                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::BF16)).line(packing),
1593                );
1594
1595                let inst = self.compile_unary(op, out);
1596
1597                instructions.push(Instruction::SpecialCast(inst));
1598            }
1599            gpu::Operator::Cast(op) => {
1600                let op = self.compile_unary(op, out);
1601
1602                if op.input.elem() == Elem::TF32 || op.out.elem() == Elem::TF32 {
1603                    self.flags.elem_tf32 = true;
1604                }
1605
1606                instructions.push(Instruction::Assign(op))
1607            }
1608            gpu::Operator::Reinterpret(op) => {
1609                instructions.push(Instruction::Bitcast(self.compile_unary(op, out)))
1610            }
1611        };
1612    }
1613
1614    fn compile_binary(
1615        &mut self,
1616        value: gpu::BinaryOperator,
1617        out: gpu::Variable,
1618    ) -> BinaryInstruction<D> {
1619        BinaryInstruction {
1620            lhs: self.compile_variable(value.lhs),
1621            rhs: self.compile_variable(value.rhs),
1622            out: self.compile_variable(out),
1623        }
1624    }
1625
1626    fn compile_index_assign(
1627        &mut self,
1628        value: gpu::IndexAssignOperator,
1629        out: gpu::Variable,
1630    ) -> IndexAssignInstruction<D> {
1631        IndexAssignInstruction {
1632            index: self.compile_variable(value.index),
1633            value: self.compile_variable(value.value),
1634            line_size: value.line_size,
1635            out: self.compile_variable(out),
1636        }
1637    }
1638
1639    fn compile_index(
1640        &mut self,
1641        value: gpu::IndexOperator,
1642        out: gpu::Variable,
1643    ) -> IndexInstruction<D> {
1644        IndexInstruction {
1645            list: self.compile_variable(value.list),
1646            index: self.compile_variable(value.index),
1647            line_size: value.line_size,
1648            out: self.compile_variable(out),
1649        }
1650    }
1651
1652    fn compile_unary(
1653        &mut self,
1654        value: gpu::UnaryOperator,
1655        out: gpu::Variable,
1656    ) -> UnaryInstruction<D> {
1657        UnaryInstruction {
1658            input: self.compile_variable(value.input),
1659            out: self.compile_variable(out),
1660        }
1661    }
1662
    /// Lowers an IR variable into its dialect representation.
    ///
    /// Besides the pure translation, this records compilation side effects:
    /// the types encountered (via `compile_type` / `compile_storage_type`),
    /// feature flags (TMA, WMMA, pipelines, barriers), and which builtin
    /// indexes the kernel body needs so their declarations can be emitted.
    fn compile_variable(&mut self, value: gpu::Variable) -> Variable<D> {
        let item = value.ty;
        match value.kind {
            gpu::VariableKind::GlobalInputArray(id) => {
                Variable::GlobalInputArray(id, self.compile_type(item))
            }
            gpu::VariableKind::GlobalScalar(id) => Variable::GlobalScalar {
                id,
                elem: self.compile_storage_type(item.storage_type()),
                // Scalars live in a struct when the target supports grid
                // constants.
                in_struct: self.compilation_options.supports_features.grid_constants,
            },
            // Tensor map inputs and outputs both lower to the same TMA handle.
            gpu::VariableKind::TensorMapInput(id) => {
                self.flags.inst_tma = true;
                Variable::TensorMap(id)
            }
            gpu::VariableKind::TensorMapOutput(id) => {
                self.flags.inst_tma = true;
                Variable::TensorMap(id)
            }
            gpu::VariableKind::LocalMut { id } => Variable::LocalMut {
                id,
                item: self.compile_type(item),
            },
            // Versioned (SSA) variables compile to plain mutable locals; the
            // version number is dropped.
            gpu::VariableKind::Versioned { id, .. } => Variable::LocalMut {
                id,
                item: self.compile_type(item),
            },
            gpu::VariableKind::LocalConst { id } => Variable::LocalConst {
                id,
                item: self.compile_type(item),
            },
            gpu::VariableKind::GlobalOutputArray(id) => {
                Variable::GlobalOutputArray(id, self.compile_type(item))
            }
            gpu::VariableKind::ConstantScalar(value) => {
                Variable::ConstantScalar(value, self.compile_elem(value.elem_type()))
            }
            gpu::VariableKind::SharedArray { id, length, .. } => {
                let item = self.compile_type(item);
                Variable::SharedArray(id, item, length)
            }
            gpu::VariableKind::Shared { id } => {
                let item = self.compile_type(item);
                Variable::Shared(id, item)
            }
            gpu::VariableKind::ConstantArray {
                id,
                length,
                unroll_factor,
            } => {
                let item = self.compile_type(item);
                // The declared length accounts for loop unrolling.
                Variable::ConstantArray(id, item, length * unroll_factor)
            }
            gpu::VariableKind::Builtin(builtin) => match builtin {
                gpu::Builtin::AbsolutePos => {
                    self.flags.indexes.absolute_pos = true;
                    Variable::AbsolutePos
                }
                // Cluster position builtins are only meaningful when the
                // target supports thread block clusters.
                gpu::Builtin::CubePosCluster
                    if self.compilation_options.supports_features.clusters =>
                {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterRank
                }
                gpu::Builtin::CubePosClusterX
                    if self.compilation_options.supports_features.clusters =>
                {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexX
                }
                gpu::Builtin::CubePosClusterY
                    if self.compilation_options.supports_features.clusters =>
                {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexY
                }
                gpu::Builtin::CubePosClusterZ
                    if self.compilation_options.supports_features.clusters =>
                {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexZ
                }
                // Fallback if clusters aren't supported, ID is always 0 since clusters are always
                // (1, 1, 1) if unsupported
                gpu::Builtin::CubePosCluster
                | gpu::Builtin::CubePosClusterX
                | gpu::Builtin::CubePosClusterY
                | gpu::Builtin::CubePosClusterZ => const_u32(0),
                gpu::Builtin::AbsolutePosX => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosX
                }
                gpu::Builtin::AbsolutePosY => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosY
                }
                gpu::Builtin::AbsolutePosZ => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosZ
                }
                gpu::Builtin::CubeDim => {
                    self.flags.indexes.cube_dim = true;
                    Variable::CubeDim
                }
                gpu::Builtin::CubeDimX => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimX
                }
                gpu::Builtin::CubeDimY => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimY
                }
                gpu::Builtin::CubeDimZ => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimZ
                }
                // Cluster dimensions are compile-time constants taken from the
                // configured cluster dim.
                gpu::Builtin::CubeClusterDim => const_u32(self.cluster_dim.num_elems()),
                gpu::Builtin::CubeClusterDimX => const_u32(self.cluster_dim.x),
                gpu::Builtin::CubeClusterDimY => const_u32(self.cluster_dim.y),
                gpu::Builtin::CubeClusterDimZ => const_u32(self.cluster_dim.z),
                gpu::Builtin::CubePos => {
                    self.flags.indexes.cube_pos = true;
                    Variable::CubePos
                }
                gpu::Builtin::CubePosX => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosX
                }
                gpu::Builtin::CubePosY => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosY
                }
                gpu::Builtin::CubePosZ => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosZ
                }
                gpu::Builtin::CubeCount => {
                    self.flags.indexes.cube_count = true;
                    Variable::CubeCount
                }
                gpu::Builtin::CubeCountX => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountX
                }
                gpu::Builtin::CubeCountY => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountY
                }
                gpu::Builtin::CubeCountZ => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountZ
                }
                gpu::Builtin::UnitPos => {
                    self.flags.indexes.unit_pos = true;
                    Variable::UnitPos
                }
                gpu::Builtin::UnitPosX => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosX
                }
                gpu::Builtin::UnitPosY => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosY
                }
                gpu::Builtin::UnitPosZ => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosZ
                }
                gpu::Builtin::PlaneDim => {
                    self.flags.indexes.plane_dim = true;
                    Variable::PlaneDim
                }
                gpu::Builtin::UnitPosPlane => {
                    self.flags.indexes.unit_pos_plane = true;
                    Variable::UnitPosPlane
                }
            },
            gpu::VariableKind::LocalArray {
                id,
                length,
                unroll_factor,
            } => {
                let item = self.compile_type(item);
                // Declare the array once; subsequent references reuse it.
                if !self.local_arrays.iter().any(|s| s.index == id) {
                    self.local_arrays
                        .push(LocalArray::new(id, item, length * unroll_factor));
                }
                Variable::LocalArray(id, item, length)
            }
            gpu::VariableKind::Matrix { id, mat } => {
                self.flags.inst_wmma = true;
                Variable::WmmaFragment {
                    id,
                    frag: self.compile_matrix(mat),
                }
            }
            gpu::VariableKind::Pipeline { id, num_stages } => {
                self.flags.op_pipeline = true;
                let pipeline = Variable::Pipeline { id };
                // Initialize each pipeline exactly once.
                if !self.pipelines.iter().any(|s| s.pipeline_id() == id) {
                    self.pipelines.push(PipelineOps::Init {
                        pipeline,
                        num_stages,
                    });
                }
                pipeline
            }
            gpu::VariableKind::BarrierToken { id, level } => {
                self.flags.op_barrier = true;
                Variable::BarrierToken { id, level }
            }
        }
    }
1876
1877    fn compile_matrix(&mut self, matrix: gpu::Matrix) -> Fragment<D> {
1878        Fragment {
1879            ident: self.compile_matrix_ident(matrix.ident),
1880            m: matrix.m,
1881            n: matrix.n,
1882            k: matrix.k,
1883            elem: self.compile_storage_type(matrix.storage),
1884            layout: self.compile_matrix_layout(matrix.layout),
1885        }
1886    }
1887
1888    fn compile_matrix_ident(&mut self, ident: gpu::MatrixIdent) -> FragmentIdent<D> {
1889        match ident {
1890            gpu::MatrixIdent::A => FragmentIdent::A,
1891            gpu::MatrixIdent::B => FragmentIdent::B,
1892            gpu::MatrixIdent::Accumulator => FragmentIdent::Accumulator,
1893        }
1894    }
1895
1896    fn compile_matrix_layout(&mut self, layout: gpu::MatrixLayout) -> Option<FragmentLayout<D>> {
1897        match layout {
1898            gpu::MatrixLayout::ColMajor => Some(FragmentLayout::ColMajor),
1899            gpu::MatrixLayout::RowMajor => Some(FragmentLayout::RowMajor),
1900            gpu::MatrixLayout::Undefined => None,
1901        }
1902    }
1903
1904    fn compile_binding(&mut self, binding: cubecl_runtime::kernel::Binding) -> Binding<D> {
1905        Binding {
1906            id: binding.id,
1907            item: self.compile_type(binding.ty),
1908            location: binding.location,
1909            size: binding.size,
1910            vis: binding.visibility,
1911        }
1912    }
1913
1914    fn compile_type(&mut self, ty: gpu::Type) -> Item<D> {
1915        let item = match ty {
1916            gpu::Type::Scalar(ty) => Item::new(self.compile_storage_type(ty), 1, false),
1917            gpu::Type::Line(ty, line_size) => {
1918                Item::new(self.compile_storage_type(ty), line_size as usize, false)
1919            }
1920            gpu::Type::Semantic(_) => Item::new(Elem::Bool, 1, true),
1921        };
1922        if item.elem != super::Elem::TF32 {
1923            self.items.insert(item);
1924            self.items.insert(item.optimized());
1925        } else {
1926            // TF32 is represented as `float` in C++
1927            let mut item = item;
1928            item.elem = super::Elem::F32;
1929            self.items.insert(item);
1930        }
1931
1932        item
1933    }
1934
1935    fn compile_storage_type(&mut self, value: gpu::StorageType) -> Elem<D> {
1936        match value {
1937            gpu::StorageType::Scalar(ty) => self.compile_elem(ty),
1938            gpu::StorageType::Atomic(ty) => Elem::Atomic(ty.into()),
1939            gpu::StorageType::Packed(gpu::ElemType::Float(kind), 2) => match kind {
1940                FloatKind::E2M1 => {
1941                    self.flags.elem_fp4 = true;
1942                    Elem::FP4x2(FP4Kind::E2M1)
1943                }
1944                FloatKind::E2M3 => {
1945                    self.flags.elem_fp6 = true;
1946                    Elem::FP6x2(FP6Kind::E2M3)
1947                }
1948                FloatKind::E3M2 => {
1949                    self.flags.elem_fp6 = true;
1950                    Elem::FP6(FP6Kind::E3M2)
1951                }
1952                FloatKind::E4M3 => {
1953                    self.flags.elem_fp8 = true;
1954                    Elem::FP8x2(FP8Kind::E4M3)
1955                }
1956                FloatKind::E5M2 => {
1957                    self.flags.elem_fp8 = true;
1958                    Elem::FP8x2(FP8Kind::E5M2)
1959                }
1960                FloatKind::UE8M0 => {
1961                    self.flags.elem_fp8 = true;
1962                    Elem::FP8x2(FP8Kind::UE8M0)
1963                }
1964                FloatKind::F16 => {
1965                    self.flags.elem_f16 = true;
1966                    Elem::F16x2
1967                }
1968                FloatKind::BF16 => {
1969                    self.flags.elem_bf16 = true;
1970                    Elem::BF16x2
1971                }
1972                other => unimplemented!("Unsupported storage type: packed<{other:?}, 2>"),
1973            },
1974            gpu::StorageType::Packed(other, factor) => {
1975                unimplemented!("Unsupported storage type: packed<{other}, {factor}>")
1976            }
1977            gpu::StorageType::Opaque(ty) => match ty {
1978                gpu::OpaqueType::Barrier(level) => {
1979                    self.flags.op_barrier = true;
1980                    Elem::Barrier(level)
1981                }
1982            },
1983        }
1984    }
1985
1986    fn compile_elem(&mut self, value: gpu::ElemType) -> Elem<D> {
1987        match value {
1988            gpu::ElemType::Float(kind) => match kind {
1989                gpu::FloatKind::E2M1 => {
1990                    self.flags.elem_fp4 = true;
1991                    Elem::FP4(FP4Kind::E2M1)
1992                }
1993                gpu::FloatKind::E2M3 => {
1994                    self.flags.elem_fp6 = true;
1995                    Elem::FP6(FP6Kind::E2M3)
1996                }
1997                gpu::FloatKind::E3M2 => {
1998                    self.flags.elem_fp6 = true;
1999                    Elem::FP6(FP6Kind::E3M2)
2000                }
2001                gpu::FloatKind::E4M3 => {
2002                    self.flags.elem_fp8 = true;
2003                    Elem::FP8(FP8Kind::E4M3)
2004                }
2005                gpu::FloatKind::E5M2 => {
2006                    self.flags.elem_fp8 = true;
2007                    Elem::FP8(FP8Kind::E5M2)
2008                }
2009                gpu::FloatKind::UE8M0 => {
2010                    self.flags.elem_fp8 = true;
2011                    Elem::FP8(FP8Kind::UE8M0)
2012                }
2013                gpu::FloatKind::F16 => {
2014                    self.flags.elem_f16 = true;
2015                    Elem::F16
2016                }
2017                gpu::FloatKind::BF16 => {
2018                    self.flags.elem_bf16 = true;
2019                    Elem::BF16
2020                }
2021                gpu::FloatKind::TF32 => Elem::TF32,
2022                gpu::FloatKind::Flex32 => Elem::F32,
2023                gpu::FloatKind::F32 => Elem::F32,
2024                gpu::FloatKind::F64 => Elem::F64,
2025            },
2026            gpu::ElemType::Int(kind) => match kind {
2027                gpu::IntKind::I8 => Elem::I8,
2028                gpu::IntKind::I16 => Elem::I16,
2029                gpu::IntKind::I32 => Elem::I32,
2030                gpu::IntKind::I64 => Elem::I64,
2031            },
2032            gpu::ElemType::UInt(kind) => match kind {
2033                gpu::UIntKind::U8 => Elem::U8,
2034                gpu::UIntKind::U16 => Elem::U16,
2035                gpu::UIntKind::U32 => Elem::U32,
2036                gpu::UIntKind::U64 => Elem::U64,
2037            },
2038            gpu::ElemType::Bool => Elem::Bool,
2039        }
2040    }
2041}
2042
2043fn is_fp4_fp6_fp8(elem: gpu::ElemType) -> bool {
2044    match elem {
2045        gpu::ElemType::Float(kind) => matches!(
2046            kind,
2047            FloatKind::E2M1
2048                | FloatKind::E2M3
2049                | FloatKind::E3M2
2050                | FloatKind::E4M3
2051                | FloatKind::E5M2
2052                | FloatKind::UE8M0
2053        ),
2054        _ => false,
2055    }
2056}
2057
2058fn const_u32<D: Dialect>(value: u32) -> Variable<D> {
2059    Variable::ConstantScalar(
2060        gpu::ConstantScalarValue::UInt(value as u64, UIntKind::U32),
2061        Elem::U32,
2062    )
2063}
2064
2065pub fn register_supported_types(props: &mut DeviceProperties) {
2066    let supported_types = [
2067        gpu::ElemType::UInt(gpu::UIntKind::U8),
2068        gpu::ElemType::UInt(gpu::UIntKind::U16),
2069        gpu::ElemType::UInt(gpu::UIntKind::U32),
2070        gpu::ElemType::UInt(gpu::UIntKind::U64),
2071        gpu::ElemType::Int(gpu::IntKind::I8),
2072        gpu::ElemType::Int(gpu::IntKind::I16),
2073        gpu::ElemType::Int(gpu::IntKind::I32),
2074        gpu::ElemType::Int(gpu::IntKind::I64),
2075        gpu::ElemType::Float(gpu::FloatKind::BF16),
2076        gpu::ElemType::Float(gpu::FloatKind::F16),
2077        gpu::ElemType::Float(gpu::FloatKind::F32),
2078        gpu::ElemType::Float(gpu::FloatKind::Flex32),
2079        // Causes CUDA_ERROR_INVALID_VALUE for matmul, disabling until that can be investigated
2080        //gpu::Elem::Float(gpu::FloatKind::F64),
2081        gpu::ElemType::Bool,
2082    ];
2083
2084    let supported_atomic_types = [
2085        gpu::ElemType::Int(gpu::IntKind::I32),
2086        gpu::ElemType::Int(gpu::IntKind::I64),
2087        gpu::ElemType::UInt(gpu::UIntKind::U32),
2088        gpu::ElemType::UInt(gpu::UIntKind::U64),
2089        gpu::ElemType::Float(gpu::FloatKind::F32),
2090    ];
2091
2092    for ty in supported_types {
2093        props.register_type_usage(ty, TypeUsage::all_scalar());
2094    }
2095
2096    for ty in supported_atomic_types {
2097        props.register_type_usage(
2098            gpu::StorageType::Atomic(ty),
2099            TypeUsage::AtomicAdd | TypeUsage::AtomicLoadStore,
2100        );
2101    }
2102}