cubecl_cpp/shared/
base.rs

1use super::{
2    BinaryInstruction, Binding, Body, Component, ComputeKernel, ConstArray, Dialect, Elem, FP4Kind,
3    FP6Kind, FP8Kind, Fragment, FragmentIdent, FragmentLayout, IndexAssignInstruction,
4    IndexInstruction, Instruction, Item, LocalArray, SharedMemory, UnaryInstruction, Variable,
5    WarpInstruction, WmmaInstruction, barrier::BarrierOps, pipeline::PipelineOps,
6};
7use crate::shared::MmaShape;
8use cubecl_common::backtrace::BackTrace;
9use cubecl_core::{
10    CubeDim,
11    ir::{
12        self as gpu, DeviceProperties, ElemType, FloatKind, InstructionModes, OpaqueType,
13        Operation, Processor, SourceLoc, StorageType,
14        features::{EnumSet, TypeUsage},
15    },
16    post_processing::checked_io::CheckedIoProcessor,
17    prelude::{FastMath, KernelDefinition},
18    server::ExecutionMode,
19};
20use cubecl_opt::{Optimizer, SharedLiveness};
21use cubecl_runtime::compiler::{CompilationError, Compiler};
22use std::{collections::HashSet, fmt::Debug};
23
/// Global counter used to number compiler-generated temporary variables.
/// Reset to 0 at the end of every `Compiler::compile` call so ids restart
/// for each kernel.
pub(super) static COUNTER_TMP_VAR: std::sync::atomic::AtomicU32 =
    std::sync::atomic::AtomicU32::new(0);
26
/// Options controlling how a kernel is compiled to C++.
#[derive(Clone, Debug)]
pub struct CompilationOptions {
    /// Number of units in a plane/warp on the target (defaults to 32).
    pub warp_size: u32,
    /// Optional hardware/toolchain features the target supports.
    pub supports_features: CppSupportedFeatures,
}
32
/// Feature switches describing what the target toolchain/hardware supports.
#[derive(Clone, Debug, Default)]
pub struct CppSupportedFeatures {
    /// Kernel metadata may be passed as grid constants
    /// (forwarded to `Flags::use_grid_constants`).
    pub grid_constants: bool,
    /// Thread-block clusters are available; when false any requested
    /// cluster dimension is dropped during compilation.
    pub clusters: bool,
    /// Fast-math optimizations may be enabled (not read in this file —
    /// presumably consumed by the code generator; confirm at call sites).
    pub fast_math: bool,
    /// A fast tanh approximation may be used (not read in this file —
    /// confirm at call sites).
    pub fast_tanh: bool,
    /// The `elect.sync` primitive is available; otherwise a fallback
    /// plane-elect implementation is emitted.
    pub elect_sync: bool,
}
41
42impl Default for CompilationOptions {
43    fn default() -> Self {
44        Self {
45            warp_size: 32,
46            supports_features: Default::default(),
47        }
48    }
49}
50
/// Cube indexes flags.
/// When true the corresponding index is declared and computed as needed in the kernel.
#[derive(Debug, Clone, Default)]
pub struct CubeIndexFlags {
    /// Linear absolute unit position.
    pub absolute_pos: bool,
    /// Per-axis absolute unit position.
    pub absolute_pos_tuple: bool,
    /// Linear cube count.
    pub cube_count: bool,
    /// Per-axis cube count.
    pub cube_count_tuple: bool,
    /// Linear cube dimension.
    pub cube_dim: bool,
    /// Per-axis cube dimension.
    pub cube_dim_tuple: bool,
    /// Linear cube position.
    pub cube_pos: bool,
    /// Per-axis cube position.
    pub cube_pos_tuple: bool,
    /// Plane (warp) dimension.
    pub plane_dim: bool,
    /// Checked variant of the plane dimension; enabled whenever any plane
    /// operation is compiled.
    pub plane_dim_checked: bool,
    /// Index of the plane within the cube.
    pub plane_index: bool,
    /// Linear unit position within the cube.
    pub unit_pos: bool,
    /// Per-axis unit position within the cube.
    pub unit_pos_tuple: bool,
    /// Unit position within its plane; required by plane scan operations
    /// (inclusive/exclusive sum and prod).
    pub unit_pos_plane: bool,
    /// Position within the thread-block cluster.
    pub cluster_pos: bool,
}
71
/// Flags gathered during Cube IR translation for the kernel compilation.
///
/// The `elem_*` flags record which element types appear in the kernel, while
/// the `op_*`/`inst_*` flags record which instruction families are used, so
/// the code generator can emit only the support code that is needed.
#[derive(Debug, Clone)]
pub struct Flags<D: Dialect> {
    /// FP4 element types are used.
    pub elem_fp4: bool,
    /// FP6 element types are used.
    pub elem_fp6: bool,
    /// FP8 element types are used.
    pub elem_fp8: bool,
    /// `bf16` element type is used.
    pub elem_bf16: bool,
    /// `f16` element type is used.
    pub elem_f16: bool,
    /// `tf32` element type is used.
    pub elem_tf32: bool,
    /// Builtin cube/plane indexes the kernel needs declared.
    pub indexes: CubeIndexFlags,
    /// Barrier operations are used.
    pub op_barrier: bool,
    /// Pipeline operations are used.
    pub op_pipeline: bool,
    /// TMA instructions are used (also set by the async-proxy fence).
    pub inst_tma: bool,
    /// TMA im2col loads are used.
    pub inst_tma_im2col: bool,
    /// WMMA/matrix fragment instructions are used.
    pub inst_wmma: bool,
    /// Inline-PTX wrapper helpers are required (e.g. for `elect_sync`).
    pub inst_ptx_wrappers: bool,
    /// Asynchronous copy instructions are used.
    pub inst_async_copy: bool,
    /// Pass static metadata as grid constants (target supports it).
    pub use_grid_constants: bool,
    /// Number of statically known metadata entries.
    pub static_meta_length: usize,
    /// Kernel carries dynamically sized metadata (currently true whenever any
    /// static metadata is present — see the TODO where this is computed).
    pub has_dynamic_meta: bool,
    /// Cube (block) dimensions of the kernel.
    pub cube_dim: CubeDim,
    /// Optional cluster dimensions, when clusters are requested and supported.
    pub cluster_dim: Option<CubeDim>,
    /// Item type used for address/index arithmetic (defaults to scalar `u32`).
    pub address_type: Item<D>,
}
96
97#[allow(clippy::too_many_arguments)]
98#[derive(Clone, Debug)]
99pub struct CppCompiler<D: Dialect> {
100    barriers: Vec<BarrierOps<D>>,
101    compilation_options: CompilationOptions,
102    const_arrays: Vec<ConstArray<D>>,
103    ext_meta_positions: Vec<u32>,
104    cluster_dim: CubeDim,
105    extensions: Vec<D::Extension>,
106    flags: Flags<D>,
107    items: HashSet<Item<D>>,
108    local_arrays: Vec<LocalArray<D>>,
109    metadata: cubecl_core::Metadata,
110    pipelines: Vec<PipelineOps<D>>,
111    source_loc: Option<SourceLoc>,
112    strategy: ExecutionMode,
113    addr_type: Item<D>,
114}
115
116impl<D: Dialect> Default for Flags<D> {
117    fn default() -> Self {
118        Self {
119            elem_fp4: Default::default(),
120            elem_fp6: Default::default(),
121            elem_fp8: Default::default(),
122            elem_bf16: Default::default(),
123            elem_f16: Default::default(),
124            elem_tf32: Default::default(),
125            indexes: Default::default(),
126            op_barrier: Default::default(),
127            op_pipeline: Default::default(),
128            inst_tma: Default::default(),
129            inst_tma_im2col: Default::default(),
130            inst_wmma: Default::default(),
131            inst_ptx_wrappers: Default::default(),
132            inst_async_copy: Default::default(),
133            use_grid_constants: Default::default(),
134            static_meta_length: Default::default(),
135            has_dynamic_meta: Default::default(),
136            cube_dim: CubeDim::new_single(),
137            cluster_dim: Default::default(),
138            address_type: Item::scalar(Elem::U32, true),
139        }
140    }
141}
142
143impl<D: Dialect> Default for CppCompiler<D> {
144    fn default() -> Self {
145        Self {
146            barriers: Default::default(),
147            compilation_options: Default::default(),
148            const_arrays: Default::default(),
149            ext_meta_positions: Default::default(),
150            cluster_dim: CubeDim::new_single(),
151            extensions: Default::default(),
152            flags: Flags::default(),
153            items: Default::default(),
154            local_arrays: Default::default(),
155            metadata: Default::default(),
156            pipelines: Default::default(),
157            source_loc: Default::default(),
158            strategy: Default::default(),
159            addr_type: Item::scalar(Elem::U32, true),
160        }
161    }
162}
163
164impl<D: Dialect> Compiler for CppCompiler<D> {
165    type Representation = ComputeKernel<D>;
166    type CompilationOptions = CompilationOptions;
167
168    fn compile(
169        &mut self,
170        mut kernel: KernelDefinition,
171        compilation_options: &Self::CompilationOptions,
172        strategy: ExecutionMode,
173        addr_type: StorageType,
174    ) -> Result<Self::Representation, CompilationError> {
175        let errors = kernel.body.pop_errors();
176        if !errors.is_empty() {
177            let mut reason = "Can't compile cpp kernel\nCaused by:\n  ".to_string();
178            for error in errors {
179                reason += error.as_str();
180                reason += "\n";
181            }
182
183            return Err(CompilationError::Validation {
184                reason,
185                backtrace: BackTrace::capture(),
186            });
187        }
188
189        self.addr_type = self.compile_type(addr_type.into());
190        self.compilation_options = compilation_options.clone();
191        self.strategy = strategy;
192
193        if !self.compilation_options.supports_features.clusters {
194            kernel.options.cluster_dim = None;
195        }
196        self.cluster_dim = kernel.options.cluster_dim.unwrap_or(CubeDim::new_single());
197
198        let ir = self.clone().compile_ir(kernel, addr_type);
199        COUNTER_TMP_VAR.store(0, std::sync::atomic::Ordering::Relaxed);
200        Ok(ir)
201    }
202
203    fn elem_size(&self, elem: gpu::ElemType) -> usize {
204        elem.size()
205    }
206
207    fn extension(&self) -> &'static str {
208        "cpp"
209    }
210}
211
212impl<D: Dialect> CppCompiler<D> {
    /// Lowers a [`KernelDefinition`] into the C++ [`ComputeKernel`] AST.
    ///
    /// Consumes the compiler: the IR walk both produces instructions and
    /// accumulates state on `self` (flags, extensions, pipelines, barriers,
    /// const/local arrays) that is moved into the resulting kernel.
    fn compile_ir(
        mut self,
        value: KernelDefinition,
        address_type: StorageType,
    ) -> ComputeKernel<D> {
        self.build_metadata(&value);

        // Compile the body first: the walk populates `self.flags` and the
        // other collections read below. The body is cloned because it is
        // consumed again by the shared-memory analysis further down.
        let instructions = self.compile_scope(&mut value.body.clone());
        let tensor_maps = value
            .tensor_maps
            .into_iter()
            .map(|b| self.compile_binding(b))
            .collect();
        let buffers = value
            .buffers
            .into_iter()
            .map(|b| self.compile_binding(b))
            .collect();
        let scalars = value
            .scalars
            .into_iter()
            .map(|binding| (self.compile_storage_type(binding.ty), binding.count))
            .collect();

        // translation flags
        // Snapshot the flags gathered during `compile_scope`, combined with
        // target capabilities and metadata sizes.
        let flags = Flags {
            indexes: D::builtin_rules(&self.flags.indexes),
            inst_wmma: self.flags.inst_wmma,
            op_pipeline: self.flags.op_pipeline,
            op_barrier: self.flags.op_barrier,
            elem_fp4: self.flags.elem_fp4,
            elem_fp6: self.flags.elem_fp6,
            elem_fp8: self.flags.elem_fp8,
            elem_bf16: self.flags.elem_bf16,
            elem_f16: self.flags.elem_f16,
            elem_tf32: self.flags.elem_tf32,
            inst_tma: self.flags.inst_tma,
            inst_tma_im2col: self.flags.inst_tma_im2col,
            inst_async_copy: self.flags.inst_async_copy,
            inst_ptx_wrappers: self.flags.inst_ptx_wrappers,
            use_grid_constants: self.compilation_options.supports_features.grid_constants,
            // TODO: At some point we should only pass dynamic meta if tensors are present,
            // not if only arrays are present. For now, leave like this
            has_dynamic_meta: self.metadata.static_len() > 0,
            static_meta_length: self.metadata.static_len() as usize,
            cube_dim: value.cube_dim,
            cluster_dim: value.options.cluster_dim,
            address_type: self.compile_type(address_type.into()),
        };

        // Run the shared-memory liveness analysis to obtain the allocations
        // (with their computed offsets), then lower each one into the AST.
        let mut opt = Optimizer::shared_only(value.body, value.cube_dim);
        let shared_allocs = opt.analysis::<SharedLiveness>();
        let shared_memories = shared_allocs
            .allocations
            .values()
            .map(|alloc| match alloc.smem {
                cubecl_opt::SharedMemory::Array {
                    id,
                    length,
                    ty,
                    align,
                } => SharedMemory::Array {
                    index: id,
                    item: self.compile_type(ty),
                    length,
                    align,
                    offset: alloc.offset,
                },
                cubecl_opt::SharedMemory::Value { id, ty, align } => SharedMemory::Value {
                    index: id,
                    item: self.compile_type(ty),
                    align,
                    offset: alloc.offset,
                },
            })
            .collect();

        let body = Body {
            instructions,
            shared_memories,
            pipelines: self.pipelines,
            barriers: self.barriers,
            const_arrays: self.const_arrays,
            local_arrays: self.local_arrays,
        };

        // A requested cluster layout is ignored on targets without support.
        let mut cluster_dim = value.options.cluster_dim;
        if !self.compilation_options.supports_features.clusters {
            cluster_dim = None;
        }

        ComputeKernel {
            tensor_maps,
            buffers,
            scalars,
            meta_static_len: self.metadata.static_len() as usize,
            cube_dim: value.cube_dim,
            body,
            extensions: self.extensions,
            flags,
            items: self.items,
            kernel_name: value.options.kernel_name,
            cluster_dim,
        }
    }
318
319    fn build_metadata(&mut self, value: &KernelDefinition) {
320        let mut num_ext = 0;
321
322        let mut all_meta: Vec<_> = value
323            .buffers
324            .iter()
325            .chain(value.tensor_maps.iter())
326            .map(|buf| (buf.id, buf.has_extended_meta))
327            .collect();
328
329        all_meta.sort_by_key(|(id, _)| *id);
330
331        for (_, has_extended_meta) in &all_meta {
332            self.ext_meta_positions.push(num_ext);
333            if *has_extended_meta {
334                num_ext += 1;
335            }
336        }
337
338        let num_meta = all_meta.len();
339
340        self.metadata = cubecl_core::Metadata::new(num_meta as u32, num_ext);
341    }
342
343    pub(crate) fn ext_meta_position(&self, var: gpu::Variable) -> u32 {
344        let id = var.index().expect("Variable should have index");
345        self.ext_meta_positions[id as usize]
346    }
347
348    fn compile_scope(&mut self, scope: &mut gpu::Scope) -> Vec<Instruction<D>> {
349        let mut instructions = Vec::new();
350
351        let const_arrays = scope
352            .const_arrays
353            .drain(..)
354            .map(|(var, values)| ConstArray {
355                index: var.index().unwrap(),
356                item: self.compile_type(var.ty),
357                size: values.len() as u32,
358                values: values
359                    .into_iter()
360                    .map(|val| self.compile_variable(val))
361                    .collect(),
362            })
363            .collect::<Vec<_>>();
364        self.const_arrays.extend(const_arrays);
365
366        let checked_io: Box<dyn Processor> = Box::new(CheckedIoProcessor::new(self.strategy));
367        let dialect_processors = D::processors();
368        let mut processors: Vec<&dyn Processor> = vec![&*checked_io];
369        processors.extend(dialect_processors.iter().map(|it| &**it));
370
371        let processing = scope.process(processors);
372
373        for var in processing.variables {
374            instructions.push(Instruction::DeclareVariable {
375                var: self.compile_variable(var),
376            });
377        }
378
379        processing
380            .instructions
381            .into_iter()
382            .for_each(|op| self.compile_instruction(&mut instructions, op));
383
384        instructions
385    }
386
387    fn compile_instruction(
388        &mut self,
389        instructions: &mut Vec<Instruction<D>>,
390        instruction: gpu::Instruction,
391    ) {
392        self.update_debug_loc(instructions, &instruction);
393        let out = instruction.out;
394        match instruction.operation {
395            gpu::Operation::Copy(variable) => {
396                instructions.push(Instruction::Assign(UnaryInstruction {
397                    input: self.compile_variable(variable),
398                    out: self.compile_variable(out.unwrap()),
399                }));
400            }
401            gpu::Operation::Arithmetic(op) => {
402                self.compile_arithmetic(op, out, instruction.modes, instructions)
403            }
404            gpu::Operation::Comparison(op) => self.compile_comparison(op, out, instructions),
405            gpu::Operation::Bitwise(op) => self.compile_bitwise(op, out, instructions),
406            gpu::Operation::Operator(op) => self.compile_operator(op, out, instructions),
407            gpu::Operation::Atomic(op) => self.compile_atomic(op, out, instructions),
408            gpu::Operation::Metadata(op) => instructions.push(self.compile_metadata(op, out)),
409            gpu::Operation::Branch(val) => self.compile_branch(instructions, val),
410            gpu::Operation::Synchronization(val) => match val {
411                gpu::Synchronization::SyncCube => instructions.push(Instruction::SyncThreads),
412                gpu::Synchronization::SyncPlane => instructions.push(Instruction::SyncWarp),
413                gpu::Synchronization::SyncStorage => instructions.push(Instruction::SyncThreads),
414                gpu::Synchronization::SyncAsyncProxyShared => {
415                    self.flags.inst_tma = true;
416                    instructions.push(Instruction::ProxyAsyncToSharedFence)
417                }
418            },
419            gpu::Operation::Plane(op) => {
420                self.flags.indexes.plane_dim_checked = true;
421                let out = self.compile_variable(out.unwrap());
422                match op {
423                    gpu::Plane::Sum(op) => {
424                        let instruction = WarpInstruction::ReduceSum {
425                            input: self.compile_variable(op.input),
426                            out,
427                        };
428                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
429                        instructions.push(Instruction::Warp(instruction));
430                    }
431                    gpu::Plane::InclusiveSum(op) => {
432                        self.flags.indexes.unit_pos_plane = true;
433                        instructions.push(Instruction::Warp(WarpInstruction::InclusiveSum {
434                            input: self.compile_variable(op.input),
435                            out,
436                        }))
437                    }
438                    gpu::Plane::InclusiveProd(op) => {
439                        self.flags.indexes.unit_pos_plane = true;
440                        instructions.push(Instruction::Warp(WarpInstruction::InclusiveProd {
441                            input: self.compile_variable(op.input),
442                            out,
443                        }))
444                    }
445                    gpu::Plane::ExclusiveSum(op) => {
446                        self.flags.indexes.unit_pos_plane = true;
447                        instructions.push(Instruction::Warp(WarpInstruction::ExclusiveSum {
448                            input: self.compile_variable(op.input),
449                            out,
450                        }))
451                    }
452                    gpu::Plane::ExclusiveProd(op) => {
453                        self.flags.indexes.unit_pos_plane = true;
454                        instructions.push(Instruction::Warp(WarpInstruction::ExclusiveProd {
455                            input: self.compile_variable(op.input),
456                            out,
457                        }))
458                    }
459                    gpu::Plane::Prod(op) => {
460                        let instruction = WarpInstruction::ReduceProd {
461                            input: self.compile_variable(op.input),
462                            out,
463                        };
464                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
465                        instructions.push(Instruction::Warp(instruction))
466                    }
467                    gpu::Plane::Max(op) => {
468                        let instruction = WarpInstruction::ReduceMax {
469                            input: self.compile_variable(op.input),
470                            out,
471                        };
472                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
473                        instructions.push(Instruction::Warp(instruction))
474                    }
475                    gpu::Plane::Min(op) => {
476                        let instruction = WarpInstruction::ReduceMin {
477                            input: self.compile_variable(op.input),
478                            out,
479                        };
480                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
481                        instructions.push(Instruction::Warp(instruction))
482                    }
483                    gpu::Plane::Elect => {
484                        if self.compilation_options.supports_features.elect_sync {
485                            self.flags.inst_ptx_wrappers = true;
486                            instructions.push(Instruction::Warp(WarpInstruction::Elect { out }))
487                        } else {
488                            instructions
489                                .push(Instruction::Warp(WarpInstruction::ElectFallback { out }))
490                        }
491                    }
492                    gpu::Plane::All(op) => {
493                        instructions.push(Instruction::Warp(WarpInstruction::All {
494                            input: self.compile_variable(op.input),
495                            out,
496                        }))
497                    }
498                    gpu::Plane::Any(op) => {
499                        instructions.push(Instruction::Warp(WarpInstruction::Any {
500                            input: self.compile_variable(op.input),
501                            out,
502                        }))
503                    }
504                    gpu::Plane::Ballot(op) => {
505                        instructions.push(Instruction::Warp(WarpInstruction::Ballot {
506                            input: self.compile_variable(op.input),
507                            out,
508                        }))
509                    }
510                    gpu::Plane::Broadcast(op) => {
511                        instructions.push(Instruction::Warp(WarpInstruction::Broadcast {
512                            input: self.compile_variable(op.lhs),
513                            id: self.compile_variable(op.rhs),
514                            out,
515                        }))
516                    }
517                    gpu::Plane::Shuffle(op) => {
518                        instructions.push(Instruction::Warp(WarpInstruction::Shuffle {
519                            input: self.compile_variable(op.lhs),
520                            src_lane: self.compile_variable(op.rhs),
521                            out,
522                        }))
523                    }
524                    gpu::Plane::ShuffleXor(op) => {
525                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleXor {
526                            input: self.compile_variable(op.lhs),
527                            mask: self.compile_variable(op.rhs),
528                            out,
529                        }))
530                    }
531                    gpu::Plane::ShuffleUp(op) => {
532                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleUp {
533                            input: self.compile_variable(op.lhs),
534                            delta: self.compile_variable(op.rhs),
535                            out,
536                        }))
537                    }
538                    gpu::Plane::ShuffleDown(op) => {
539                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleDown {
540                            input: self.compile_variable(op.lhs),
541                            delta: self.compile_variable(op.rhs),
542                            out,
543                        }))
544                    }
545                }
546            }
547            gpu::Operation::CoopMma(cmma) => instructions.push(self.compile_cmma(cmma, out)),
548            gpu::Operation::NonSemantic(debug) => match debug {
549                gpu::NonSemantic::Print {
550                    format_string,
551                    args,
552                } => instructions.push(Instruction::Printf {
553                    format_string,
554                    args: args
555                        .into_iter()
556                        .map(|arg| self.compile_variable(arg))
557                        .collect(),
558                }),
559                gpu::NonSemantic::Comment { content } => {
560                    instructions.push(Instruction::Comment { content })
561                }
562                // Don't need to handle scopes
563                _ => {}
564            },
565            gpu::Operation::Barrier(barrier_ops) => match barrier_ops {
566                gpu::BarrierOps::Declare { barrier } => {
567                    let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type()
568                    else {
569                        unreachable!()
570                    };
571                    let barrier = self.compile_variable(barrier);
572                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Declare {
573                        barrier,
574                        level,
575                    }));
576                }
577                gpu::BarrierOps::Init {
578                    barrier,
579                    is_elected,
580                    arrival_count,
581                } => {
582                    let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type()
583                    else {
584                        unreachable!()
585                    };
586                    let barrier = self.compile_variable(barrier);
587                    let arrival_count = self.compile_variable(arrival_count);
588                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Init {
589                        barrier,
590                        is_elected: self.compile_variable(is_elected),
591                        arrival_count,
592                        level,
593                    }));
594                }
595                gpu::BarrierOps::InitManual {
596                    barrier,
597                    arrival_count,
598                } => {
599                    let barrier = self.compile_variable(barrier);
600                    let arrival_count = self.compile_variable(arrival_count);
601                    instructions.push(Instruction::Barrier(
602                        super::barrier::BarrierOps::InitManual {
603                            barrier,
604                            arrival_count,
605                        },
606                    ));
607                }
608                gpu::BarrierOps::MemCopyAsync {
609                    barrier,
610                    source,
611                    source_length,
612                    offset_source,
613                    offset_out,
614                } => {
615                    instructions.push(Instruction::Barrier(
616                        super::barrier::BarrierOps::MemCopyAsync {
617                            barrier: self.compile_variable(barrier),
618                            source: self.compile_variable(source),
619                            destination: self.compile_variable(out.unwrap()),
620                            source_length: self.compile_variable(source_length),
621                            offset_source: self.compile_variable(offset_source),
622                            offset_out: self.compile_variable(offset_out),
623                            cooperative: false,
624                        },
625                    ));
626                }
627                gpu::BarrierOps::MemCopyAsyncCooperative {
628                    barrier,
629                    source,
630                    source_length,
631                    offset_source,
632                    offset_out,
633                } => {
634                    instructions.push(Instruction::Barrier(
635                        super::barrier::BarrierOps::MemCopyAsync {
636                            barrier: self.compile_variable(barrier),
637                            source: self.compile_variable(source),
638                            destination: self.compile_variable(out.unwrap()),
639                            source_length: self.compile_variable(source_length),
640                            offset_source: self.compile_variable(offset_source),
641                            offset_out: self.compile_variable(offset_out),
642                            cooperative: true,
643                        },
644                    ));
645                }
646                gpu::BarrierOps::MemCopyAsyncTx {
647                    barrier,
648                    source,
649                    source_length,
650                    offset_source,
651                    offset_out,
652                } => {
653                    instructions.push(Instruction::Barrier(
654                        super::barrier::BarrierOps::MemCopyAsyncTx {
655                            barrier: self.compile_variable(barrier),
656                            source: self.compile_variable(source),
657                            destination: self.compile_variable(out.unwrap()),
658                            source_length: self.compile_variable(source_length),
659                            offset_source: self.compile_variable(offset_source),
660                            offset_out: self.compile_variable(offset_out),
661                        },
662                    ));
663                }
664                gpu::BarrierOps::CopyAsync {
665                    source,
666                    source_length,
667                    offset_source,
668                    offset_out,
669                    copy_length,
670                    checked,
671                } => {
672                    self.flags.inst_async_copy = true;
673                    instructions.push(Instruction::Barrier(
674                        super::barrier::BarrierOps::CopyAsync {
675                            source: self.compile_variable(source),
676                            destination: self.compile_variable(out.unwrap()),
677                            source_length: self.compile_variable(source_length),
678                            offset_source: self.compile_variable(offset_source),
679                            offset_out: self.compile_variable(offset_out),
680                            copy_size: copy_length,
681                            checked,
682                        },
683                    ));
684                }
685                gpu::BarrierOps::TmaLoad {
686                    barrier,
687                    tensor_map,
688                    offset_out,
689                    indices,
690                } => {
691                    instructions.push(Instruction::Barrier(
692                        super::barrier::BarrierOps::MemCopyAsyncTensorGlobalToShared {
693                            barrier: self.compile_variable(barrier),
694                            smem_buffer: self.compile_variable(out.unwrap()),
695                            smem_offset: self.compile_variable(offset_out),
696                            tensor_map: self.compile_variable(tensor_map),
697                            indices: indices
698                                .into_iter()
699                                .map(|it| self.compile_variable(it))
700                                .collect(),
701                        },
702                    ));
703                }
704                gpu::BarrierOps::TmaLoadIm2col {
705                    barrier,
706                    tensor_map,
707                    offset_out,
708                    indices,
709                    offsets,
710                } => {
711                    self.flags.inst_tma_im2col = true;
712                    instructions.push(Instruction::Barrier(
713                        super::barrier::BarrierOps::TmaLoadIm2col {
714                            barrier: self.compile_variable(barrier),
715                            smem_buffer: self.compile_variable(out.unwrap()),
716                            smem_offset: self.compile_variable(offset_out),
717                            tensor_map: self.compile_variable(tensor_map),
718                            indices: indices
719                                .into_iter()
720                                .map(|it| self.compile_variable(it))
721                                .collect(),
722                            offsets: offsets
723                                .into_iter()
724                                .map(|it| self.compile_variable(it))
725                                .collect(),
726                        },
727                    ));
728                }
729                gpu::BarrierOps::Arrive { barrier } => {
730                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Arrive {
731                        barrier: self.compile_variable(barrier),
732                        token: self.compile_variable(out.unwrap()),
733                    }))
734                }
735                gpu::BarrierOps::ArriveTx {
736                    barrier,
737                    arrive_count_update,
738                    transaction_count_update,
739                } => {
740                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::ArriveTx {
741                        barrier: self.compile_variable(barrier),
742                        token: self.compile_variable(out.unwrap()),
743                        arrive_count_update: self.compile_variable(arrive_count_update),
744                        transaction_count_update: self.compile_variable(transaction_count_update),
745                    }))
746                }
747                gpu::BarrierOps::CommitCopyAsync { barrier } => {
748                    self.flags.inst_async_copy = true;
749                    instructions.push(Instruction::Barrier(
750                        super::barrier::BarrierOps::ArriveCopyAsync {
751                            barrier: self.compile_variable(barrier),
752                        },
753                    ))
754                }
755                gpu::BarrierOps::ExpectTx {
756                    barrier,
757                    transaction_count_update,
758                } => {
759                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::ExpectTx {
760                        barrier: self.compile_variable(barrier),
761                        transaction_count_update: self.compile_variable(transaction_count_update),
762                    }))
763                }
764                gpu::BarrierOps::Wait { barrier, token } => {
765                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Wait {
766                        barrier: self.compile_variable(barrier),
767                        token: self.compile_variable(token),
768                    }))
769                }
770                gpu::BarrierOps::WaitParity { barrier, phase } => instructions.push(
771                    Instruction::Barrier(super::barrier::BarrierOps::WaitParity {
772                        barrier: self.compile_variable(barrier),
773                        phase: self.compile_variable(phase),
774                    }),
775                ),
776                gpu::BarrierOps::ArriveAndWait { barrier } => {
777                    let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type()
778                    else {
779                        unreachable!()
780                    };
781                    instructions.push(Instruction::Barrier(
782                        super::barrier::BarrierOps::ArriveAndWait {
783                            barrier: self.compile_variable(barrier),
784                            level,
785                        },
786                    ))
787                }
788            },
789            gpu::Operation::Tma(tma_ops) => {
790                self.flags.inst_tma = true;
791                match tma_ops {
792                    gpu::TmaOps::TmaStore {
793                        source,
794                        coordinates,
795                        offset_source,
796                    } => {
797                        instructions.push(Instruction::MemCopyAsyncTensorSharedToGlobal {
798                            smem_buffer: self.compile_variable(source),
799                            smem_offset: self.compile_variable(offset_source),
800                            tensor_map: self.compile_variable(out.unwrap()),
801                            indices: coordinates
802                                .into_iter()
803                                .map(|it| self.compile_variable(it))
804                                .collect(),
805                        });
806                    }
807                    gpu::TmaOps::CommitGroup => {
808                        instructions.push(Instruction::BulkCommitGroup);
809                    }
810                    gpu::TmaOps::WaitGroup { max_pending } => {
811                        instructions.push(Instruction::BulkWaitGroup { max_pending });
812                    }
813                    gpu::TmaOps::WaitGroupRead { max_pending } => {
814                        instructions.push(Instruction::BulkWaitGroupRead { max_pending });
815                    }
816                }
817            }
818            gpu::Operation::Marker(_) => {}
819        }
820    }
821
822    fn update_debug_loc(
823        &mut self,
824        instructions: &mut Vec<Instruction<D>>,
825        inst: &gpu::Instruction,
826    ) {
827        if !matches!(inst.operation, Operation::NonSemantic(_)) {
828            match &inst.source_loc {
829                Some(loc) if Some(loc) != self.source_loc.as_ref() => {
830                    self.source_loc = Some(loc.clone());
831                    instructions.push(Instruction::Line {
832                        file: loc.source.file.clone(),
833                        line: loc.line,
834                    });
835                }
836                _ => {}
837            }
838        }
839    }
840
    /// Lowers a cooperative-matrix (`CoopMma`) IR operation into a dialect
    /// [`WmmaInstruction`].
    ///
    /// Sets the `inst_wmma` kernel flag so WMMA support is enabled in the
    /// generated code, and registers any dialect-specific extension required
    /// by the produced instruction before returning it.
    ///
    /// `out` must be present: every CMMA operation writes to a fragment or
    /// buffer. Panics on `RowIndex`/`ColIndex`, which are expected to be
    /// eliminated by IR processors before reaching this compiler.
    fn compile_cmma(&mut self, cmma: gpu::CoopMma, out: Option<gpu::Variable>) -> Instruction<D> {
        self.flags.inst_wmma = true;

        let out = self.compile_variable(out.unwrap());

        let inst = match cmma {
            // Fill the output fragment with a single value.
            gpu::CoopMma::Fill { value } => WmmaInstruction::Fill {
                frag: out,
                value: self.compile_variable(value),
            },
            // Load a fragment from memory; the layout is optional and only
            // compiled when the IR provides one.
            gpu::CoopMma::Load {
                value,
                stride,
                offset,
                layout,
            } => WmmaInstruction::Load {
                frag: out,
                offset: self.compile_variable(offset),
                value: self.compile_variable(value),
                stride: self.compile_variable(stride),
                layout: layout.and_then(|l| self.compile_matrix_layout(l)),
            },
            // D = A * B + C using the target's warp-level matrix intrinsics.
            gpu::CoopMma::Execute {
                mat_a,
                mat_b,
                mat_c,
            } => WmmaInstruction::Execute {
                frag_a: self.compile_variable(mat_a),
                frag_b: self.compile_variable(mat_b),
                frag_c: self.compile_variable(mat_c),
                frag_d: out,
                warp_size: self.compilation_options.warp_size,
            },
            // Manual MMA with explicitly managed registers; the shape comes
            // from the IR matrix descriptor (m, n, k).
            gpu::CoopMma::ExecuteManual {
                matrix,
                registers_a,
                registers_b,
                registers_c,
            } => WmmaInstruction::ExecuteManual {
                shape: MmaShape::new(matrix.m as u32, matrix.n as u32, matrix.k as u32),
                frag_a: self.compile_variable(registers_a),
                frag_b: self.compile_variable(registers_b),
                frag_c: self.compile_variable(registers_c),
                frag_d: out,
            },
            // Scaled MMA: like `ExecuteManual` but with per-operand scale
            // registers and a scaling factor.
            gpu::CoopMma::ExecuteScaled {
                matrix,
                registers_a,
                registers_b,
                registers_c,
                scales_a,
                scales_b,
                scales_factor,
            } => WmmaInstruction::ExecuteScaled {
                shape: MmaShape::new(matrix.m as u32, matrix.n as u32, matrix.k as u32),
                frag_a: self.compile_variable(registers_a),
                frag_b: self.compile_variable(registers_b),
                frag_c: self.compile_variable(registers_c),
                frag_d: out,

                scales_a: self.compile_variable(scales_a),
                scales_b: self.compile_variable(scales_b),
                scales_factor: scales_factor as u32,
            },
            gpu::CoopMma::Store {
                mat,
                stride,
                offset,
                layout,
            } => {
                // Store lowering indexes by unit/plane position, so request
                // that those builtin indexes be declared in the kernel.
                self.flags.indexes.unit_pos = true;
                self.flags.indexes.plane_index = true;
                WmmaInstruction::Store {
                    output: out,
                    offset: self.compile_variable(offset),
                    frag: self.compile_variable(mat),
                    stride: self.compile_variable(stride),
                    // Unlike `Load`, a store cannot infer the layout.
                    layout: self
                        .compile_matrix_layout(layout)
                        .expect("Layout required for store instruction"),
                }
            }
            // ldmatrix-style load of registers from a shared buffer.
            gpu::CoopMma::LoadMatrix {
                buffer,
                offset,
                line_size,
                factor,
                transpose,
            } => WmmaInstruction::LdMatrix {
                output: out,
                buffer: self.compile_variable(buffer),
                offset: self.compile_variable(offset),
                line_size,
                factor: factor as u32,
                transpose,
            },
            // stmatrix-style store of registers into the output buffer.
            gpu::CoopMma::StoreMatrix {
                offset,
                line_size,
                registers,
                factor,
                transpose,
            } => WmmaInstruction::StMatrix {
                registers: self.compile_variable(registers),
                buffer: out,
                offset: self.compile_variable(offset),
                line_size,
                factor: factor as u32,
                transpose,
            },
            // Element-type cast between fragments.
            gpu::CoopMma::Cast { input } => WmmaInstruction::Cast {
                input: self.compile_variable(input),
                output: out,
            },
            gpu::CoopMma::RowIndex { .. } | gpu::CoopMma::ColIndex { .. } => {
                panic!("Row/Col index should be handled by processors")
            }
        };

        // Some dialects need extra helper code (extensions) for WMMA ops.
        D::register_wmma_instruction_extension(&mut self.extensions, &inst);

        Instruction::Wmma(inst)
    }
964
965    fn compile_metadata(
966        &mut self,
967        metadata: gpu::Metadata,
968        out: Option<gpu::Variable>,
969    ) -> Instruction<D> {
970        let out = out.unwrap();
971        match metadata {
972            gpu::Metadata::Stride { dim, var } => {
973                let position = self.ext_meta_position(var);
974                let offset = self.metadata.stride_offset_index(position);
975                Instruction::ExtendedMetadata {
976                    info_offset: self.compile_variable(offset.into()),
977                    dim: self.compile_variable(dim),
978                    split_meta: self.compilation_options.supports_features.grid_constants,
979                    static_offset: self.metadata.static_len(),
980                    out: self.compile_variable(out),
981                }
982            }
983            gpu::Metadata::Shape { dim, var } => {
984                let position = self.ext_meta_position(var);
985                let offset = self.metadata.shape_offset_index(position);
986                Instruction::ExtendedMetadata {
987                    info_offset: self.compile_variable(offset.into()),
988                    dim: self.compile_variable(dim),
989                    split_meta: self.compilation_options.supports_features.grid_constants,
990                    static_offset: self.metadata.static_len(),
991                    out: self.compile_variable(out),
992                }
993            }
994            gpu::Metadata::Rank { var } => {
995                let out = self.compile_variable(out);
996                let pos = self.ext_meta_position(var);
997                let offset = self.metadata.rank_index(pos);
998                super::Instruction::Metadata {
999                    info_offset: self.compile_variable(offset.into()),
1000                    split_meta: self.compilation_options.supports_features.grid_constants,
1001                    out,
1002                }
1003            }
1004            gpu::Metadata::Length { var } => {
1005                let input = self.compile_variable(var);
1006                let out = self.compile_variable(out);
1007
1008                match input {
1009                    Variable::Slice { .. } => Instruction::SliceLength { input, out },
1010                    Variable::SharedArray(_id, _item, length) => {
1011                        Instruction::ConstLength { length, out }
1012                    }
1013                    _ => {
1014                        let id = input.id().expect("Variable should have id");
1015                        let offset = self.metadata.len_index(id);
1016                        Instruction::Metadata {
1017                            info_offset: self.compile_variable(offset.into()),
1018                            split_meta: self.compilation_options.supports_features.grid_constants,
1019                            out,
1020                        }
1021                    }
1022                }
1023            }
1024            gpu::Metadata::BufferLength { var } => {
1025                let input = self.compile_variable(var);
1026                let out = self.compile_variable(out);
1027
1028                match input {
1029                    Variable::Slice { .. } => Instruction::SliceLength { input, out },
1030                    _ => {
1031                        let id = input.id().expect("Variable should have id");
1032                        let offset = self.metadata.buffer_len_index(id);
1033                        Instruction::Metadata {
1034                            info_offset: self.compile_variable(offset.into()),
1035                            split_meta: self.compilation_options.supports_features.grid_constants,
1036                            out,
1037                        }
1038                    }
1039                }
1040            }
1041        }
1042    }
1043
1044    fn compile_branch(&mut self, instructions: &mut Vec<Instruction<D>>, branch: gpu::Branch) {
1045        match branch {
1046            gpu::Branch::If(mut op) => instructions.push(Instruction::If {
1047                cond: self.compile_variable(op.cond),
1048                instructions: self.compile_scope(&mut op.scope),
1049            }),
1050            gpu::Branch::IfElse(mut op) => instructions.push(Instruction::IfElse {
1051                cond: self.compile_variable(op.cond),
1052                instructions_if: self.compile_scope(&mut op.scope_if),
1053                instructions_else: self.compile_scope(&mut op.scope_else),
1054            }),
1055            gpu::Branch::Switch(mut op) => instructions.push(Instruction::Switch {
1056                value: self.compile_variable(op.value),
1057                instructions_default: self.compile_scope(&mut op.scope_default),
1058                instructions_cases: op
1059                    .cases
1060                    .into_iter()
1061                    .map(|(val, mut block)| {
1062                        (self.compile_variable(val), self.compile_scope(&mut block))
1063                    })
1064                    .collect(),
1065            }),
1066            gpu::Branch::Return => instructions.push(Instruction::Return),
1067            gpu::Branch::Break => instructions.push(Instruction::Break),
1068            gpu::Branch::RangeLoop(mut range_loop) => instructions.push(Instruction::RangeLoop {
1069                i: self.compile_variable(range_loop.i),
1070                start: self.compile_variable(range_loop.start),
1071                end: self.compile_variable(range_loop.end),
1072                step: range_loop.step.map(|it| self.compile_variable(it)),
1073                inclusive: range_loop.inclusive,
1074                instructions: self.compile_scope(&mut range_loop.scope),
1075            }),
1076            gpu::Branch::Loop(mut op) => instructions.push(Instruction::Loop {
1077                instructions: self.compile_scope(&mut op.scope),
1078            }),
1079        };
1080    }
1081
1082    fn compile_atomic(
1083        &mut self,
1084        value: gpu::AtomicOp,
1085        out: Option<gpu::Variable>,
1086        instructions: &mut Vec<Instruction<D>>,
1087    ) {
1088        let out = out.unwrap();
1089        match value {
1090            gpu::AtomicOp::Load(op) => {
1091                instructions.push(Instruction::AtomicLoad(self.compile_unary(op, out)))
1092            }
1093            gpu::AtomicOp::Store(op) => {
1094                instructions.push(Instruction::AtomicStore(self.compile_unary(op, out)))
1095            }
1096            gpu::AtomicOp::Swap(op) => {
1097                instructions.push(Instruction::AtomicSwap(self.compile_binary(op, out)))
1098            }
1099            gpu::AtomicOp::Add(op) => {
1100                instructions.push(Instruction::AtomicAdd(self.compile_binary(op, out)))
1101            }
1102            gpu::AtomicOp::Sub(op) => {
1103                instructions.push(Instruction::AtomicSub(self.compile_binary(op, out)))
1104            }
1105            gpu::AtomicOp::Max(op) => {
1106                instructions.push(Instruction::AtomicMax(self.compile_binary(op, out)))
1107            }
1108            gpu::AtomicOp::Min(op) => {
1109                instructions.push(Instruction::AtomicMin(self.compile_binary(op, out)))
1110            }
1111            gpu::AtomicOp::And(op) => {
1112                instructions.push(Instruction::AtomicAnd(self.compile_binary(op, out)))
1113            }
1114            gpu::AtomicOp::Or(op) => {
1115                instructions.push(Instruction::AtomicOr(self.compile_binary(op, out)))
1116            }
1117            gpu::AtomicOp::Xor(op) => {
1118                instructions.push(Instruction::AtomicXor(self.compile_binary(op, out)))
1119            }
1120            gpu::AtomicOp::CompareAndSwap(op) => instructions.push(Instruction::AtomicCAS {
1121                input: self.compile_variable(op.input),
1122                cmp: self.compile_variable(op.cmp),
1123                val: self.compile_variable(op.val),
1124                out: self.compile_variable(out),
1125            }),
1126        }
1127    }
1128
1129    fn compile_arithmetic(
1130        &mut self,
1131        value: gpu::Arithmetic,
1132        out: Option<gpu::Variable>,
1133        modes: InstructionModes,
1134        instructions: &mut Vec<Instruction<D>>,
1135    ) {
1136        let out = out.unwrap();
1137        match value {
1138            gpu::Arithmetic::Add(op) => {
1139                instructions.push(Instruction::Add(self.compile_binary(op, out)))
1140            }
1141            gpu::Arithmetic::SaturatingAdd(op) => {
1142                instructions.push(Instruction::SaturatingAdd(self.compile_binary(op, out)))
1143            }
1144            gpu::Arithmetic::Mul(op) => {
1145                instructions.push(Instruction::Mul(self.compile_binary(op, out)))
1146            }
1147            gpu::Arithmetic::Div(op) => {
1148                let op = self.compile_binary(op, out);
1149                instructions.push(self.select_fast_float(
1150                    out.ty,
1151                    modes,
1152                    FastMath::AllowReciprocal
1153                        | FastMath::ReducedPrecision
1154                        | FastMath::UnsignedZero
1155                        | FastMath::NotInf,
1156                    Instruction::Div(op),
1157                    Instruction::FastDiv(op),
1158                ))
1159            }
1160            gpu::Arithmetic::Sub(op) => {
1161                instructions.push(Instruction::Sub(self.compile_binary(op, out)))
1162            }
1163            gpu::Arithmetic::SaturatingSub(op) => {
1164                instructions.push(Instruction::SaturatingSub(self.compile_binary(op, out)))
1165            }
1166            gpu::Arithmetic::MulHi(op) => {
1167                let instruction = Instruction::HiMul(self.compile_binary(op, out));
1168                D::register_instruction_extension(&mut self.extensions, &instruction);
1169                instructions.push(instruction)
1170            }
1171            gpu::Arithmetic::Modulo(op) => {
1172                instructions.push(Instruction::Modulo(self.compile_binary(op, out)))
1173            }
1174            gpu::Arithmetic::Abs(op) => {
1175                instructions.push(Instruction::Abs(self.compile_unary(op, out)))
1176            }
1177            gpu::Arithmetic::Exp(op) => {
1178                let op = self.compile_unary(op, out);
1179                instructions.push(self.select_fast_float(
1180                    out.ty,
1181                    modes,
1182                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
1183                    Instruction::Exp(op),
1184                    Instruction::FastExp(op),
1185                ));
1186            }
1187            gpu::Arithmetic::Log(op) => {
1188                let op = self.compile_unary(op, out);
1189                instructions.push(self.select_fast_float(
1190                    out.ty,
1191                    modes,
1192                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
1193                    Instruction::Log(op),
1194                    Instruction::FastLog(op),
1195                ));
1196            }
1197            gpu::Arithmetic::Log1p(op) => {
1198                instructions.push(Instruction::Log1p(self.compile_unary(op, out)))
1199            }
1200            gpu::Arithmetic::Cos(op) => {
1201                let op = self.compile_unary(op, out);
1202                instructions.push(self.select_fast_float(
1203                    out.ty,
1204                    modes,
1205                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
1206                    Instruction::Cos(op),
1207                    Instruction::FastCos(op),
1208                ));
1209            }
1210            gpu::Arithmetic::Sin(op) => {
1211                let op = self.compile_unary(op, out);
1212                instructions.push(self.select_fast_float(
1213                    out.ty,
1214                    modes,
1215                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
1216                    Instruction::Sin(op),
1217                    Instruction::FastSin(op),
1218                ));
1219            }
1220            gpu::Arithmetic::Tan(op) => {
1221                instructions.push(Instruction::Tan(self.compile_unary(op, out)))
1222            }
1223            gpu::Arithmetic::Tanh(op) => {
1224                let op = self.compile_unary(op, out);
1225                let instruction = Instruction::Tanh(op);
1226                D::register_instruction_extension(&mut self.extensions, &instruction);
1227                if self.compilation_options.supports_features.fast_tanh {
1228                    instructions.push(self.select_fast_float(
1229                        out.ty,
1230                        modes,
1231                        FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
1232                        instruction,
1233                        Instruction::FastTanh(op),
1234                    ))
1235                } else {
1236                    instructions.push(instruction);
1237                }
1238            }
1239            gpu::Arithmetic::Sinh(op) => {
1240                let instruction = Instruction::Sinh(self.compile_unary(op, out));
1241                D::register_instruction_extension(&mut self.extensions, &instruction);
1242                instructions.push(instruction)
1243            }
1244            gpu::Arithmetic::Cosh(op) => {
1245                let instruction = Instruction::Cosh(self.compile_unary(op, out));
1246                D::register_instruction_extension(&mut self.extensions, &instruction);
1247                instructions.push(instruction)
1248            }
1249            gpu::Arithmetic::ArcCos(op) => {
1250                let instruction = Instruction::ArcCos(self.compile_unary(op, out));
1251                D::register_instruction_extension(&mut self.extensions, &instruction);
1252                instructions.push(instruction)
1253            }
1254            gpu::Arithmetic::ArcSin(op) => {
1255                let instruction = Instruction::ArcSin(self.compile_unary(op, out));
1256                D::register_instruction_extension(&mut self.extensions, &instruction);
1257                instructions.push(instruction)
1258            }
1259            gpu::Arithmetic::ArcTan(op) => {
1260                let instruction = Instruction::ArcTan(self.compile_unary(op, out));
1261                D::register_instruction_extension(&mut self.extensions, &instruction);
1262                instructions.push(instruction)
1263            }
1264            gpu::Arithmetic::ArcSinh(op) => {
1265                let instruction = Instruction::ArcSinh(self.compile_unary(op, out));
1266                D::register_instruction_extension(&mut self.extensions, &instruction);
1267                instructions.push(instruction)
1268            }
1269            gpu::Arithmetic::ArcCosh(op) => {
1270                let instruction = Instruction::ArcCosh(self.compile_unary(op, out));
1271                D::register_instruction_extension(&mut self.extensions, &instruction);
1272                instructions.push(instruction)
1273            }
1274            gpu::Arithmetic::ArcTanh(op) => {
1275                let instruction = Instruction::ArcTanh(self.compile_unary(op, out));
1276                D::register_instruction_extension(&mut self.extensions, &instruction);
1277                instructions.push(instruction)
1278            }
1279            gpu::Arithmetic::Degrees(op) => {
1280                let instruction = Instruction::Degrees(self.compile_unary(op, out));
1281                D::register_instruction_extension(&mut self.extensions, &instruction);
1282                instructions.push(instruction)
1283            }
1284            gpu::Arithmetic::Radians(op) => {
1285                let instruction = Instruction::Radians(self.compile_unary(op, out));
1286                D::register_instruction_extension(&mut self.extensions, &instruction);
1287                instructions.push(instruction)
1288            }
1289            gpu::Arithmetic::ArcTan2(op) => {
1290                let instruction = Instruction::ArcTan2(self.compile_binary(op, out));
1291                D::register_instruction_extension(&mut self.extensions, &instruction);
1292                instructions.push(instruction)
1293            }
1294            gpu::Arithmetic::Powf(op) => {
1295                let op = self.compile_binary(op, out);
1296                instructions.push(self.select_fast_float(
1297                    out.ty,
1298                    modes,
1299                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
1300                    Instruction::Powf(op),
1301                    Instruction::FastPowf(op),
1302                ))
1303            }
1304            gpu::Arithmetic::Powi(op) => {
1305                instructions.push(Instruction::Powi(self.compile_binary(op, out)))
1306            }
1307            gpu::Arithmetic::Hypot(op) => {
1308                instructions.push(Instruction::Hypot(self.compile_binary(op, out)))
1309            }
1310            gpu::Arithmetic::Rhypot(op) => {
1311                instructions.push(Instruction::Rhypot(self.compile_binary(op, out)))
1312            }
1313            gpu::Arithmetic::Sqrt(op) => {
1314                let op = self.compile_unary(op, out);
1315                instructions.push(self.select_fast_float(
1316                    out.ty,
1317                    modes,
1318                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
1319                    Instruction::Sqrt(op),
1320                    Instruction::FastSqrt(op),
1321                ))
1322            }
1323            gpu::Arithmetic::InverseSqrt(op) => {
1324                let op = self.compile_unary(op, out);
1325                instructions.push(self.select_fast_float(
1326                    out.ty,
1327                    modes,
1328                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
1329                    Instruction::InverseSqrt(op),
1330                    Instruction::FastInverseSqrt(op),
1331                ))
1332            }
1333            gpu::Arithmetic::Erf(op) => {
1334                let instruction = Instruction::Erf(self.compile_unary(op, out));
1335                D::register_instruction_extension(&mut self.extensions, &instruction);
1336                instructions.push(instruction)
1337            }
1338            gpu::Arithmetic::Max(op) => {
1339                let instruction = Instruction::Max(self.compile_binary(op, out));
1340                D::register_instruction_extension(&mut self.extensions, &instruction);
1341                instructions.push(instruction)
1342            }
1343            gpu::Arithmetic::Min(op) => {
1344                let instruction = Instruction::Min(self.compile_binary(op, out));
1345                D::register_instruction_extension(&mut self.extensions, &instruction);
1346                instructions.push(instruction)
1347            }
1348            gpu::Arithmetic::Clamp(op) => instructions.push(Instruction::Clamp {
1349                input: self.compile_variable(op.input),
1350                min_value: self.compile_variable(op.min_value),
1351                max_value: self.compile_variable(op.max_value),
1352                out: self.compile_variable(out),
1353            }),
1354            gpu::Arithmetic::Recip(op) => {
1355                let elem = op.input.ty.elem_type();
1356                let input = self.compile_variable(op.input);
1357                let out = self.compile_variable(out);
1358                let lhs = match elem {
1359                    gpu::ElemType::Float(_) => gpu::ConstantValue::Float(1.0),
1360                    gpu::ElemType::Int(_) => gpu::ConstantValue::Int(1),
1361                    gpu::ElemType::UInt(_) => gpu::ConstantValue::UInt(1),
1362                    gpu::ElemType::Bool => gpu::ConstantValue::Bool(true),
1363                };
1364                let div = Instruction::Div(BinaryInstruction {
1365                    lhs: Variable::Constant(lhs, self.compile_type(op.input.ty)),
1366                    rhs: input,
1367                    out,
1368                });
1369                let recip = Instruction::FastRecip(UnaryInstruction { input, out });
1370
1371                instructions.push(self.select_fast_float(
1372                    elem.into(),
1373                    modes,
1374                    FastMath::AllowReciprocal
1375                        | FastMath::ReducedPrecision
1376                        | FastMath::UnsignedZero
1377                        | FastMath::NotInf,
1378                    div,
1379                    recip,
1380                ))
1381            }
1382            gpu::Arithmetic::Round(op) => {
1383                instructions.push(Instruction::Round(self.compile_unary(op, out)))
1384            }
1385            gpu::Arithmetic::Floor(op) => {
1386                instructions.push(Instruction::Floor(self.compile_unary(op, out)))
1387            }
1388            gpu::Arithmetic::Ceil(op) => {
1389                instructions.push(Instruction::Ceil(self.compile_unary(op, out)))
1390            }
1391            gpu::Arithmetic::Trunc(op) => {
1392                instructions.push(Instruction::Trunc(self.compile_unary(op, out)))
1393            }
1394            gpu::Arithmetic::Remainder(op) => {
1395                instructions.push(Instruction::Remainder(self.compile_binary(op, out)))
1396            }
1397            gpu::Arithmetic::Fma(op) => instructions.push(Instruction::Fma {
1398                a: self.compile_variable(op.a),
1399                b: self.compile_variable(op.b),
1400                c: self.compile_variable(op.c),
1401                out: self.compile_variable(out),
1402            }),
1403            gpu::Arithmetic::Neg(op) => {
1404                instructions.push(Instruction::Neg(self.compile_unary(op, out)))
1405            }
1406            gpu::Arithmetic::Normalize(op) => {
1407                let op = self.compile_unary(op, out);
1408                instructions.push(self.select_fast_float(
1409                    out.ty,
1410                    modes,
1411                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
1412                    Instruction::Normalize(op),
1413                    Instruction::FastNormalize(op),
1414                ))
1415            }
1416            gpu::Arithmetic::Magnitude(op) => {
1417                let op = self.compile_unary(op, out);
1418                instructions.push(self.select_fast_float(
1419                    out.ty,
1420                    modes,
1421                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
1422                    Instruction::Magnitude(op),
1423                    Instruction::FastMagnitude(op),
1424                ))
1425            }
1426            gpu::Arithmetic::Dot(op) => {
1427                instructions.push(Instruction::Dot(self.compile_binary(op, out)))
1428            }
1429        };
1430    }
1431
1432    fn select_fast_float(
1433        &self,
1434        ty: gpu::Type,
1435        modes: InstructionModes,
1436        required_flags: EnumSet<FastMath>,
1437        default: Instruction<D>,
1438        fast: Instruction<D>,
1439    ) -> Instruction<D> {
1440        if !self.compilation_options.supports_features.fast_math
1441            || !matches!(ty.elem_type(), ElemType::Float(FloatKind::F32))
1442        {
1443            return default;
1444        }
1445
1446        if modes.fp_math_mode.is_superset(required_flags) {
1447            fast
1448        } else {
1449            default
1450        }
1451    }
1452
1453    fn compile_comparison(
1454        &mut self,
1455        value: gpu::Comparison,
1456        out: Option<gpu::Variable>,
1457        instructions: &mut Vec<Instruction<D>>,
1458    ) {
1459        let out = out.unwrap();
1460        match value {
1461            gpu::Comparison::Equal(op) => {
1462                instructions.push(Instruction::Equal(self.compile_binary(op, out)))
1463            }
1464            gpu::Comparison::Lower(op) => {
1465                instructions.push(Instruction::Lower(self.compile_binary(op, out)))
1466            }
1467            gpu::Comparison::Greater(op) => {
1468                instructions.push(Instruction::Greater(self.compile_binary(op, out)))
1469            }
1470            gpu::Comparison::LowerEqual(op) => {
1471                instructions.push(Instruction::LowerEqual(self.compile_binary(op, out)))
1472            }
1473            gpu::Comparison::GreaterEqual(op) => {
1474                instructions.push(Instruction::GreaterEqual(self.compile_binary(op, out)))
1475            }
1476            gpu::Comparison::NotEqual(op) => {
1477                instructions.push(Instruction::NotEqual(self.compile_binary(op, out)))
1478            }
1479            gpu::Comparison::IsNan(op) => {
1480                instructions.push(Instruction::IsNan(self.compile_unary(op, out)))
1481            }
1482            gpu::Comparison::IsInf(op) => {
1483                instructions.push(Instruction::IsInf(self.compile_unary(op, out)))
1484            }
1485        };
1486    }
1487
1488    fn compile_bitwise(
1489        &mut self,
1490        value: gpu::Bitwise,
1491        out: Option<gpu::Variable>,
1492        instructions: &mut Vec<Instruction<D>>,
1493    ) {
1494        let out = out.unwrap();
1495        match value {
1496            gpu::Bitwise::BitwiseOr(op) => {
1497                instructions.push(Instruction::BitwiseOr(self.compile_binary(op, out)))
1498            }
1499            gpu::Bitwise::BitwiseAnd(op) => {
1500                instructions.push(Instruction::BitwiseAnd(self.compile_binary(op, out)))
1501            }
1502            gpu::Bitwise::BitwiseXor(op) => {
1503                instructions.push(Instruction::BitwiseXor(self.compile_binary(op, out)))
1504            }
1505            gpu::Bitwise::CountOnes(op) => {
1506                instructions.push(Instruction::CountBits(self.compile_unary(op, out)))
1507            }
1508            gpu::Bitwise::ReverseBits(op) => {
1509                instructions.push(Instruction::ReverseBits(self.compile_unary(op, out)))
1510            }
1511            gpu::Bitwise::ShiftLeft(op) => {
1512                instructions.push(Instruction::ShiftLeft(self.compile_binary(op, out)))
1513            }
1514            gpu::Bitwise::ShiftRight(op) => {
1515                instructions.push(Instruction::ShiftRight(self.compile_binary(op, out)))
1516            }
1517            gpu::Bitwise::BitwiseNot(op) => {
1518                instructions.push(Instruction::BitwiseNot(self.compile_unary(op, out)))
1519            }
1520            gpu::Bitwise::LeadingZeros(op) => {
1521                instructions.push(Instruction::LeadingZeros(self.compile_unary(op, out)))
1522            }
1523            gpu::Bitwise::FindFirstSet(op) => {
1524                let instruction = Instruction::FindFirstSet(self.compile_unary(op, out));
1525                D::register_instruction_extension(&mut self.extensions, &instruction);
1526                instructions.push(instruction)
1527            }
1528        };
1529    }
1530
    /// Lower an IR operator to one C++ instruction and append it to
    /// `instructions`.
    ///
    /// `out` must be present (`unwrap` enforces it). Checked and unchecked
    /// index variants lower to the same instruction here; bounds checking is
    /// handled elsewhere in the pipeline — TODO confirm against the
    /// `CheckedIoProcessor` post-processing step.
    fn compile_operator(
        &mut self,
        value: gpu::Operator,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::Operator::Index(op) | gpu::Operator::UncheckedIndex(op) => {
                instructions.push(Instruction::Index(self.compile_index(op, out)));
            }
            gpu::Operator::IndexAssign(op) | gpu::Operator::UncheckedIndexAssign(op) => {
                instructions.push(Instruction::IndexAssign(self.compile_index_assign(op, out)));
            }
            gpu::Operator::And(op) => {
                instructions.push(Instruction::And(self.compile_binary(op, out)))
            }
            gpu::Operator::Or(op) => {
                instructions.push(Instruction::Or(self.compile_binary(op, out)))
            }
            gpu::Operator::Not(op) => {
                instructions.push(Instruction::Not(self.compile_unary(op, out)))
            }
            // Construct a vectorized value from its individual components.
            gpu::Operator::InitLine(op) => instructions.push(Instruction::VecInit {
                inputs: op
                    .inputs
                    .into_iter()
                    .map(|it| self.compile_variable(it))
                    .collect(),
                out: self.compile_variable(out),
            }),
            gpu::Operator::CopyMemory(op) => instructions.push(Instruction::Copy {
                input: self.compile_variable(op.input),
                in_index: self.compile_variable(op.in_index),
                out: self.compile_variable(out),
                out_index: self.compile_variable(op.out_index),
            }),
            gpu::Operator::CopyMemoryBulk(op) => instructions.push(Instruction::CopyBulk {
                input: self.compile_variable(op.input),
                in_index: self.compile_variable(op.in_index),
                out: self.compile_variable(out),
                out_index: self.compile_variable(op.out_index),
                len: op.len as u32,
            }),
            gpu::Operator::Select(op) => instructions.push(Instruction::Select {
                cond: self.compile_variable(op.cond),
                then: self.compile_variable(op.then),
                or_else: self.compile_variable(op.or_else),
                out: self.compile_variable(out),
            }),
            // Needs special conversion semantics
            gpu::Operator::Cast(op)
                if is_fp4_fp6_fp8(op.input.elem_type()) || is_fp4_fp6_fp8(out.elem_type()) =>
            {
                // We may need these for intermediates
                self.flags.elem_f16 = true;
                self.flags.elem_bf16 = true;
                let vec_in = op.input.ty.line_size();
                let packing = out.storage_type().packing_factor();
                // Pre-register every item the cast may need at both the input
                // line size and the output packing factor, so the type
                // declarations exist in the generated source.
                self.compile_type(op.input.ty.line(packing));
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::F16)).line(vec_in),
                );
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::BF16)).line(vec_in),
                );
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::F16)).line(packing),
                );
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::BF16)).line(packing),
                );

                let inst = self.compile_unary(op, out);

                instructions.push(Instruction::SpecialCast(inst));
            }
            // Regular casts compile to a plain assignment (C++ handles the
            // conversion); TF32 involvement flips the flag that pulls in its
            // type declaration.
            gpu::Operator::Cast(op) => {
                let op = self.compile_unary(op, out);

                if op.input.elem() == Elem::TF32 || op.out.elem() == Elem::TF32 {
                    self.flags.elem_tf32 = true;
                }

                instructions.push(Instruction::Assign(op))
            }
            gpu::Operator::Reinterpret(op) => {
                instructions.push(Instruction::Bitcast(self.compile_unary(op, out)))
            }
        };
    }
1622
1623    fn compile_binary(
1624        &mut self,
1625        value: gpu::BinaryOperator,
1626        out: gpu::Variable,
1627    ) -> BinaryInstruction<D> {
1628        BinaryInstruction {
1629            lhs: self.compile_variable(value.lhs),
1630            rhs: self.compile_variable(value.rhs),
1631            out: self.compile_variable(out),
1632        }
1633    }
1634
1635    fn compile_index_assign(
1636        &mut self,
1637        value: gpu::IndexAssignOperator,
1638        out: gpu::Variable,
1639    ) -> IndexAssignInstruction<D> {
1640        IndexAssignInstruction {
1641            index: self.compile_variable(value.index),
1642            value: self.compile_variable(value.value),
1643            line_size: value.line_size as u32,
1644            out: self.compile_variable(out),
1645        }
1646    }
1647
1648    fn compile_index(
1649        &mut self,
1650        value: gpu::IndexOperator,
1651        out: gpu::Variable,
1652    ) -> IndexInstruction<D> {
1653        IndexInstruction {
1654            list: self.compile_variable(value.list),
1655            index: self.compile_variable(value.index),
1656            line_size: value.line_size as u32,
1657            out: self.compile_variable(out),
1658        }
1659    }
1660
1661    fn compile_unary(
1662        &mut self,
1663        value: gpu::UnaryOperator,
1664        out: gpu::Variable,
1665    ) -> UnaryInstruction<D> {
1666        UnaryInstruction {
1667            input: self.compile_variable(value.input),
1668            out: self.compile_variable(out),
1669        }
1670    }
1671
    /// Translate an IR variable into its C++ representation.
    ///
    /// Side effects: registers the variable's type (via `compile_type` /
    /// `compile_storage_type`), flips feature flags (TMA, WMMA, pipelines,
    /// barriers) and the builtin index flags so the generated kernel only
    /// declares and computes what is actually used, and records local
    /// arrays / pipeline init ops the first time their id is seen.
    fn compile_variable(&mut self, value: gpu::Variable) -> Variable<D> {
        let item = value.ty;
        match value.kind {
            gpu::VariableKind::GlobalInputArray(id) => {
                Variable::GlobalInputArray(id, self.compile_type(item))
            }
            gpu::VariableKind::GlobalScalar(id) => Variable::GlobalScalar {
                id,
                elem: self.compile_storage_type(item.storage_type()),
                // Scalars live in a struct when grid constants are supported.
                in_struct: self.compilation_options.supports_features.grid_constants,
            },
            gpu::VariableKind::TensorMapInput(id) => {
                self.flags.inst_tma = true;
                Variable::TensorMap(id)
            }
            gpu::VariableKind::TensorMapOutput(id) => {
                self.flags.inst_tma = true;
                Variable::TensorMap(id)
            }
            gpu::VariableKind::LocalMut { id } => Variable::LocalMut {
                id,
                item: self.compile_type(item),
            },
            // Versioned values lower to plain mutable locals; the version
            // number is dropped here.
            gpu::VariableKind::Versioned { id, .. } => Variable::LocalMut {
                id,
                item: self.compile_type(item),
            },
            gpu::VariableKind::LocalConst { id } => Variable::LocalConst {
                id,
                item: self.compile_type(item),
            },
            gpu::VariableKind::GlobalOutputArray(id) => {
                Variable::GlobalOutputArray(id, self.compile_type(item))
            }
            gpu::VariableKind::Constant(value) => {
                Variable::Constant(value, self.compile_type(item))
            }
            gpu::VariableKind::SharedArray { id, length, .. } => {
                let item = self.compile_type(item);
                Variable::SharedArray(id, item, length)
            }
            gpu::VariableKind::Shared { id } => {
                let item = self.compile_type(item);
                Variable::Shared(id, item)
            }
            // Unrolled constant arrays are flattened: total length is
            // `length * unroll_factor`.
            gpu::VariableKind::ConstantArray {
                id,
                length,
                unroll_factor,
            } => {
                let item = self.compile_type(item);
                Variable::ConstantArray(id, item, length * unroll_factor)
            }
            // Builtins set the matching flag in `CubeIndexFlags` so only the
            // indexes in use get declared and computed in the kernel.
            gpu::VariableKind::Builtin(builtin) => match builtin {
                gpu::Builtin::AbsolutePos => {
                    self.flags.indexes.absolute_pos = true;
                    let item = self.compile_type(item);
                    Variable::AbsolutePos(item.elem)
                }
                gpu::Builtin::CubePosCluster
                    if self.compilation_options.supports_features.clusters =>
                {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterRank
                }
                gpu::Builtin::CubePosClusterX
                    if self.compilation_options.supports_features.clusters =>
                {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexX
                }
                gpu::Builtin::CubePosClusterY
                    if self.compilation_options.supports_features.clusters =>
                {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexY
                }
                gpu::Builtin::CubePosClusterZ
                    if self.compilation_options.supports_features.clusters =>
                {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexZ
                }
                // Fallback if clusters aren't supported, ID is always 0 since clusters are always
                // (1, 1, 1) if unsupported
                gpu::Builtin::CubePosCluster
                | gpu::Builtin::CubePosClusterX
                | gpu::Builtin::CubePosClusterY
                | gpu::Builtin::CubePosClusterZ => const_u32(0),
                gpu::Builtin::AbsolutePosX => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosX
                }
                gpu::Builtin::AbsolutePosY => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosY
                }
                gpu::Builtin::AbsolutePosZ => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosZ
                }
                gpu::Builtin::CubeDim => {
                    self.flags.indexes.cube_dim = true;
                    Variable::CubeDim
                }
                gpu::Builtin::CubeDimX => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimX
                }
                gpu::Builtin::CubeDimY => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimY
                }
                gpu::Builtin::CubeDimZ => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimZ
                }
                // Cluster dimensions are compile-time constants of the kernel.
                gpu::Builtin::CubeClusterDim => const_u32(self.cluster_dim.num_elems()),
                gpu::Builtin::CubeClusterDimX => const_u32(self.cluster_dim.x),
                gpu::Builtin::CubeClusterDimY => const_u32(self.cluster_dim.y),
                gpu::Builtin::CubeClusterDimZ => const_u32(self.cluster_dim.z),
                gpu::Builtin::CubePos => {
                    self.flags.indexes.cube_pos = true;
                    let item = self.compile_type(item);
                    Variable::CubePos(item.elem)
                }
                gpu::Builtin::CubePosX => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosX
                }
                gpu::Builtin::CubePosY => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosY
                }
                gpu::Builtin::CubePosZ => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosZ
                }
                gpu::Builtin::CubeCount => {
                    self.flags.indexes.cube_count = true;
                    let item = self.compile_type(item);
                    Variable::CubeCount(item.elem)
                }
                gpu::Builtin::CubeCountX => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountX
                }
                gpu::Builtin::CubeCountY => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountY
                }
                gpu::Builtin::CubeCountZ => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountZ
                }
                gpu::Builtin::UnitPos => {
                    self.flags.indexes.unit_pos = true;
                    Variable::UnitPos
                }
                gpu::Builtin::UnitPosX => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosX
                }
                gpu::Builtin::UnitPosY => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosY
                }
                gpu::Builtin::UnitPosZ => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosZ
                }
                gpu::Builtin::PlaneDim => {
                    self.flags.indexes.plane_dim = true;
                    Variable::PlaneDim
                }
                gpu::Builtin::UnitPosPlane => {
                    self.flags.indexes.unit_pos_plane = true;
                    Variable::UnitPosPlane
                }
            },
            gpu::VariableKind::LocalArray {
                id,
                length,
                unroll_factor,
            } => {
                let item = self.compile_type(item);
                // Register the array declaration only once per id.
                if !self.local_arrays.iter().any(|s| s.index == id) {
                    self.local_arrays
                        .push(LocalArray::new(id, item, length * unroll_factor));
                }
                Variable::LocalArray(id, item, length)
            }
            gpu::VariableKind::Matrix { id, mat } => {
                self.flags.inst_wmma = true;
                Variable::WmmaFragment {
                    id,
                    frag: self.compile_matrix(mat),
                }
            }
            gpu::VariableKind::Pipeline { id, num_stages } => {
                self.flags.op_pipeline = true;
                let pipeline = Variable::Pipeline { id };
                // Emit the init op only the first time this pipeline id is seen.
                if !self.pipelines.iter().any(|s| s.pipeline_id() == id) {
                    self.pipelines.push(PipelineOps::Init {
                        pipeline,
                        num_stages,
                    });
                }
                pipeline
            }
            gpu::VariableKind::BarrierToken { id, level } => {
                self.flags.op_barrier = true;
                Variable::BarrierToken { id, level }
            }
        }
    }
1888
1889    fn compile_matrix(&mut self, matrix: gpu::Matrix) -> Fragment<D> {
1890        Fragment {
1891            ident: self.compile_matrix_ident(matrix.ident),
1892            m: matrix.m as u32,
1893            n: matrix.n as u32,
1894            k: matrix.k as u32,
1895            elem: self.compile_storage_type(matrix.storage),
1896            layout: self.compile_matrix_layout(matrix.layout),
1897        }
1898    }
1899
1900    fn compile_matrix_ident(&mut self, ident: gpu::MatrixIdent) -> FragmentIdent<D> {
1901        match ident {
1902            gpu::MatrixIdent::A => FragmentIdent::A,
1903            gpu::MatrixIdent::B => FragmentIdent::B,
1904            gpu::MatrixIdent::Accumulator => FragmentIdent::Accumulator,
1905        }
1906    }
1907
1908    fn compile_matrix_layout(&mut self, layout: gpu::MatrixLayout) -> Option<FragmentLayout<D>> {
1909        match layout {
1910            gpu::MatrixLayout::ColMajor => Some(FragmentLayout::ColMajor),
1911            gpu::MatrixLayout::RowMajor => Some(FragmentLayout::RowMajor),
1912            gpu::MatrixLayout::Undefined => None,
1913        }
1914    }
1915
1916    fn compile_binding(&mut self, binding: cubecl_runtime::kernel::Binding) -> Binding<D> {
1917        Binding {
1918            id: binding.id,
1919            item: self.compile_type(binding.ty),
1920            location: binding.location,
1921            size: binding.size,
1922            vis: binding.visibility,
1923        }
1924    }
1925
1926    fn compile_type(&mut self, ty: gpu::Type) -> Item<D> {
1927        let item = match ty {
1928            gpu::Type::Scalar(ty) => Item::new(self.compile_storage_type(ty), 1, false),
1929            gpu::Type::Line(ty, line_size) => {
1930                Item::new(self.compile_storage_type(ty), line_size, false)
1931            }
1932            gpu::Type::Semantic(_) => Item::new(Elem::Bool, 1, true),
1933        };
1934        if item.elem != super::Elem::TF32 {
1935            self.items.insert(item);
1936            self.items.insert(item.optimized());
1937        } else {
1938            // TF32 is represented as `float` in C++
1939            let mut item = item;
1940            item.elem = super::Elem::F32;
1941            self.items.insert(item);
1942        }
1943
1944        item
1945    }
1946
1947    fn compile_storage_type(&mut self, value: gpu::StorageType) -> Elem<D> {
1948        match value {
1949            gpu::StorageType::Scalar(ty) => self.compile_elem(ty),
1950            gpu::StorageType::Atomic(ty) => Elem::Atomic(ty.into()),
1951            gpu::StorageType::Packed(gpu::ElemType::Float(kind), 2) => match kind {
1952                FloatKind::E2M1 => {
1953                    self.flags.elem_fp4 = true;
1954                    Elem::FP4x2(FP4Kind::E2M1)
1955                }
1956                FloatKind::E2M3 => {
1957                    self.flags.elem_fp6 = true;
1958                    Elem::FP6x2(FP6Kind::E2M3)
1959                }
1960                FloatKind::E3M2 => {
1961                    self.flags.elem_fp6 = true;
1962                    Elem::FP6(FP6Kind::E3M2)
1963                }
1964                FloatKind::E4M3 => {
1965                    self.flags.elem_fp8 = true;
1966                    Elem::FP8x2(FP8Kind::E4M3)
1967                }
1968                FloatKind::E5M2 => {
1969                    self.flags.elem_fp8 = true;
1970                    Elem::FP8x2(FP8Kind::E5M2)
1971                }
1972                FloatKind::UE8M0 => {
1973                    self.flags.elem_fp8 = true;
1974                    Elem::FP8x2(FP8Kind::UE8M0)
1975                }
1976                FloatKind::F16 => {
1977                    self.flags.elem_f16 = true;
1978                    Elem::F16x2
1979                }
1980                FloatKind::BF16 => {
1981                    self.flags.elem_bf16 = true;
1982                    Elem::BF16x2
1983                }
1984                other => unimplemented!("Unsupported storage type: packed<{other:?}, 2>"),
1985            },
1986            gpu::StorageType::Packed(other, factor) => {
1987                unimplemented!("Unsupported storage type: packed<{other}, {factor}>")
1988            }
1989            gpu::StorageType::Opaque(ty) => match ty {
1990                gpu::OpaqueType::Barrier(level) => {
1991                    self.flags.op_barrier = true;
1992                    Elem::Barrier(level)
1993                }
1994            },
1995        }
1996    }
1997
1998    fn compile_elem(&mut self, value: gpu::ElemType) -> Elem<D> {
1999        match value {
2000            gpu::ElemType::Float(kind) => match kind {
2001                gpu::FloatKind::E2M1 => {
2002                    self.flags.elem_fp4 = true;
2003                    Elem::FP4(FP4Kind::E2M1)
2004                }
2005                gpu::FloatKind::E2M3 => {
2006                    self.flags.elem_fp6 = true;
2007                    Elem::FP6(FP6Kind::E2M3)
2008                }
2009                gpu::FloatKind::E3M2 => {
2010                    self.flags.elem_fp6 = true;
2011                    Elem::FP6(FP6Kind::E3M2)
2012                }
2013                gpu::FloatKind::E4M3 => {
2014                    self.flags.elem_fp8 = true;
2015                    Elem::FP8(FP8Kind::E4M3)
2016                }
2017                gpu::FloatKind::E5M2 => {
2018                    self.flags.elem_fp8 = true;
2019                    Elem::FP8(FP8Kind::E5M2)
2020                }
2021                gpu::FloatKind::UE8M0 => {
2022                    self.flags.elem_fp8 = true;
2023                    Elem::FP8(FP8Kind::UE8M0)
2024                }
2025                gpu::FloatKind::F16 => {
2026                    self.flags.elem_f16 = true;
2027                    Elem::F16
2028                }
2029                gpu::FloatKind::BF16 => {
2030                    self.flags.elem_bf16 = true;
2031                    Elem::BF16
2032                }
2033                gpu::FloatKind::TF32 => Elem::TF32,
2034                gpu::FloatKind::Flex32 => Elem::F32,
2035                gpu::FloatKind::F32 => Elem::F32,
2036                gpu::FloatKind::F64 => Elem::F64,
2037            },
2038            gpu::ElemType::Int(kind) => match kind {
2039                gpu::IntKind::I8 => Elem::I8,
2040                gpu::IntKind::I16 => Elem::I16,
2041                gpu::IntKind::I32 => Elem::I32,
2042                gpu::IntKind::I64 => Elem::I64,
2043            },
2044            gpu::ElemType::UInt(kind) => match kind {
2045                gpu::UIntKind::U8 => Elem::U8,
2046                gpu::UIntKind::U16 => Elem::U16,
2047                gpu::UIntKind::U32 => Elem::U32,
2048                gpu::UIntKind::U64 => Elem::U64,
2049            },
2050            gpu::ElemType::Bool => Elem::Bool,
2051        }
2052    }
2053}
2054
2055fn is_fp4_fp6_fp8(elem: gpu::ElemType) -> bool {
2056    match elem {
2057        gpu::ElemType::Float(kind) => matches!(
2058            kind,
2059            FloatKind::E2M1
2060                | FloatKind::E2M3
2061                | FloatKind::E3M2
2062                | FloatKind::E4M3
2063                | FloatKind::E5M2
2064                | FloatKind::UE8M0
2065        ),
2066        _ => false,
2067    }
2068}
2069
2070fn const_u32<D: Dialect>(value: u32) -> Variable<D> {
2071    Variable::Constant(
2072        gpu::ConstantValue::UInt(value as u64),
2073        Item::new(Elem::U32, 1, true),
2074    )
2075}
2076
2077pub fn register_supported_types(props: &mut DeviceProperties) {
2078    props.register_address_type(gpu::AddressType::U32);
2079    props.register_address_type(gpu::AddressType::U64);
2080
2081    let supported_types = [
2082        gpu::ElemType::UInt(gpu::UIntKind::U8),
2083        gpu::ElemType::UInt(gpu::UIntKind::U16),
2084        gpu::ElemType::UInt(gpu::UIntKind::U32),
2085        gpu::ElemType::UInt(gpu::UIntKind::U64),
2086        gpu::ElemType::Int(gpu::IntKind::I8),
2087        gpu::ElemType::Int(gpu::IntKind::I16),
2088        gpu::ElemType::Int(gpu::IntKind::I32),
2089        gpu::ElemType::Int(gpu::IntKind::I64),
2090        gpu::ElemType::Float(gpu::FloatKind::BF16),
2091        gpu::ElemType::Float(gpu::FloatKind::F16),
2092        gpu::ElemType::Float(gpu::FloatKind::F32),
2093        gpu::ElemType::Float(gpu::FloatKind::Flex32),
2094        // Causes CUDA_ERROR_INVALID_VALUE for matmul, disabling until that can be investigated
2095        //gpu::Elem::Float(gpu::FloatKind::F64),
2096        gpu::ElemType::Bool,
2097    ];
2098
2099    let supported_atomic_types = [
2100        gpu::ElemType::Int(gpu::IntKind::I32),
2101        gpu::ElemType::Int(gpu::IntKind::I64),
2102        gpu::ElemType::UInt(gpu::UIntKind::U32),
2103        gpu::ElemType::UInt(gpu::UIntKind::U64),
2104        gpu::ElemType::Float(gpu::FloatKind::F32),
2105    ];
2106
2107    for ty in supported_types {
2108        props.register_type_usage(ty, TypeUsage::all_scalar());
2109    }
2110
2111    for ty in supported_atomic_types {
2112        props.register_type_usage(
2113            gpu::StorageType::Atomic(ty),
2114            TypeUsage::AtomicAdd | TypeUsage::AtomicLoadStore,
2115        );
2116    }
2117}