Skip to main content

cubecl_cpp/shared/
base.rs

1use super::{
2    BinaryInstruction, Body, Component, ComputeKernel, ConstArray, Dialect, Elem, FP4Kind, FP6Kind,
3    FP8Kind, Fragment, FragmentIdent, FragmentLayout, IndexAssignInstruction, IndexInstruction,
4    Instruction, Item, KernelArg, LocalArray, SharedMemory, UnaryInstruction, Variable,
5    WarpInstruction, WmmaInstruction, barrier::BarrierOps, pipeline::PipelineOps,
6};
7use crate::shared::MmaShape;
8use cubecl_common::backtrace::BackTrace;
9use cubecl_core::{
10    CubeDim,
11    ir::{
12        self as gpu, DeviceProperties, ElemType, FloatKind, InstructionModes, OpaqueType,
13        Operation, Processor, SourceLoc, StorageType, Type,
14        features::{AtomicUsage, EnumSet, TypeUsage},
15    },
16    post_processing::checked_io::CheckedIoProcessor,
17    prelude::{FastMath, KernelDefinition},
18    server::ExecutionMode,
19};
20use cubecl_opt::{Optimizer, SharedLiveness};
21use cubecl_runtime::compiler::{CompilationError, Compiler};
22use std::{collections::HashSet, fmt::Debug};
23
/// Global counter used to derive unique names for temporary variables during
/// code generation. Reset to 0 at the end of every `Compiler::compile` call.
pub(super) static COUNTER_TMP_VAR: std::sync::atomic::AtomicU32 =
    std::sync::atomic::AtomicU32::new(0);
26
/// Options controlling how Cube IR is compiled to C++.
#[derive(Clone, Debug)]
pub struct CompilationOptions {
    /// Number of units in a plane (warp) on the target device. Defaults to 32.
    pub warp_size: u32,
    /// Optional hardware/toolchain features available to the generated code.
    pub supports_features: CppSupportedFeatures,
}
32
/// Optional C++/GPU features supported by the target compiler and hardware.
/// All flags default to `false`.
#[derive(Clone, Debug, Default)]
pub struct CppSupportedFeatures {
    /// Static metadata can be passed as grid constants; when `false`,
    /// info is passed by pointer instead.
    pub grid_constants: bool,
    /// Thread-block clusters are supported; when `false`, any requested
    /// cluster dimensions are dropped.
    pub clusters: bool,
    /// Fast-math optimizations may be enabled. (Not used in this chunk;
    /// presumably consumed elsewhere in the crate.)
    pub fast_math: bool,
    /// A fast `tanh` implementation is available. (Not used in this chunk.)
    pub fast_tanh: bool,
    /// A native elect-sync instruction is available for plane elect;
    /// otherwise a software fallback is emitted.
    pub elect_sync: bool,
}
41
42impl Default for CompilationOptions {
43    fn default() -> Self {
44        Self {
45            warp_size: 32,
46            supports_features: Default::default(),
47        }
48    }
49}
50
/// Cube indexes flags.
/// When true the corresponding index is declared and computed as needed in the kernel.
#[derive(Debug, Clone, Default)]
pub struct CubeIndexFlags {
    pub absolute_pos: bool,
    pub absolute_pos_tuple: bool,
    pub cube_count: bool,
    pub cube_count_tuple: bool,
    pub cube_dim: bool,
    pub cube_dim_tuple: bool,
    pub cube_pos: bool,
    pub cube_pos_tuple: bool,
    pub plane_dim: bool,
    // Set whenever any plane (warp) operation is compiled.
    pub plane_dim_checked: bool,
    pub plane_pos: bool,
    pub unit_pos: bool,
    pub unit_pos_tuple: bool,
    // Lane index within the plane; required by plane scan operations
    // (inclusive/exclusive sum and prod).
    pub unit_pos_plane: bool,
    pub cluster_pos: bool,
}
71
/// Flags gathered during Cube IR translation for the kernel compilation.
#[derive(Debug, Clone)]
pub struct Flags<D: Dialect> {
    // Element types encountered while translating the kernel.
    pub elem_fp4: bool,
    pub elem_fp6: bool,
    pub elem_fp8: bool,
    pub elem_bf16: bool,
    pub elem_f16: bool,
    pub elem_tf32: bool,
    /// Built-in index variables the kernel needs declared.
    pub indexes: CubeIndexFlags,
    /// Barrier operations are used.
    pub op_barrier: bool,
    /// Pipeline operations are used.
    pub op_pipeline: bool,
    // Instruction families required by the kernel body.
    pub inst_tma: bool,
    pub inst_tma_im2col: bool,
    pub inst_wmma: bool,
    pub inst_ptx_wrappers: bool,
    pub inst_async_copy: bool,
    /// Pass static metadata as grid constants rather than by pointer.
    pub use_grid_constants: bool,
    /// Length of the static metadata section.
    pub static_meta_length: usize,
    pub has_dynamic_meta: bool,
    pub has_info: bool,
    pub cube_dim: CubeDim,
    /// Cluster dimensions, when requested by the kernel options.
    pub cluster_dim: Option<CubeDim>,
    /// Item used for addressing/indexing.
    pub address_type: Item<D>,
}
97
98#[allow(clippy::too_many_arguments)]
99#[derive(Clone, Debug)]
100pub struct CppCompiler<D: Dialect> {
101    kernel_name: String,
102    barriers: Vec<BarrierOps<D>>,
103    compilation_options: CompilationOptions,
104    const_arrays: Vec<ConstArray<D>>,
105    ext_meta_positions: Vec<u32>,
106    cluster_dim: CubeDim,
107    extensions: Vec<D::Extension>,
108    flags: Flags<D>,
109    items: HashSet<Item<D>>,
110    local_arrays: Vec<LocalArray<D>>,
111    info: cubecl_core::Info,
112    pipelines: Vec<PipelineOps<D>>,
113    source_loc: Option<SourceLoc>,
114    strategy: ExecutionMode,
115    addr_type: Item<D>,
116}
117
118impl<D: Dialect> Default for Flags<D> {
119    fn default() -> Self {
120        Self {
121            elem_fp4: Default::default(),
122            elem_fp6: Default::default(),
123            elem_fp8: Default::default(),
124            elem_bf16: Default::default(),
125            elem_f16: Default::default(),
126            elem_tf32: Default::default(),
127            indexes: Default::default(),
128            op_barrier: Default::default(),
129            op_pipeline: Default::default(),
130            inst_tma: Default::default(),
131            inst_tma_im2col: Default::default(),
132            inst_wmma: Default::default(),
133            inst_ptx_wrappers: Default::default(),
134            inst_async_copy: Default::default(),
135            use_grid_constants: Default::default(),
136            static_meta_length: Default::default(),
137            has_info: Default::default(),
138            has_dynamic_meta: Default::default(),
139            cube_dim: CubeDim::new_single(),
140            cluster_dim: Default::default(),
141            address_type: Item::scalar(Elem::U32, true),
142        }
143    }
144}
145
146impl<D: Dialect> Default for CppCompiler<D> {
147    fn default() -> Self {
148        Self {
149            kernel_name: Default::default(),
150            barriers: Default::default(),
151            compilation_options: Default::default(),
152            const_arrays: Default::default(),
153            ext_meta_positions: Default::default(),
154            cluster_dim: CubeDim::new_single(),
155            extensions: Default::default(),
156            flags: Flags::default(),
157            items: Default::default(),
158            local_arrays: Default::default(),
159            info: Default::default(),
160            pipelines: Default::default(),
161            source_loc: Default::default(),
162            strategy: Default::default(),
163            addr_type: Item::scalar(Elem::U32, true),
164        }
165    }
166}
167
168impl<D: Dialect> Compiler for CppCompiler<D> {
169    type Representation = ComputeKernel<D>;
170    type CompilationOptions = CompilationOptions;
171
172    fn compile(
173        &mut self,
174        mut kernel: KernelDefinition,
175        compilation_options: &Self::CompilationOptions,
176        strategy: ExecutionMode,
177        addr_type: StorageType,
178    ) -> Result<Self::Representation, CompilationError> {
179        let errors = kernel.body.pop_errors();
180        if !errors.is_empty() {
181            let mut reason = "Can't compile cpp kernel\nCaused by:\n  ".to_string();
182            for error in errors {
183                reason += error.as_str();
184                reason += "\n";
185            }
186
187            return Err(CompilationError::Validation {
188                reason,
189                backtrace: BackTrace::capture(),
190            });
191        }
192
193        self.addr_type = self.compile_type(addr_type.into());
194        self.compilation_options = compilation_options.clone();
195        self.strategy = strategy;
196        self.kernel_name = kernel.options.kernel_name.clone();
197
198        if !self.compilation_options.supports_features.clusters {
199            kernel.options.cluster_dim = None;
200        }
201        self.cluster_dim = kernel.options.cluster_dim.unwrap_or(CubeDim::new_single());
202
203        let ir = self.clone().compile_ir(kernel, addr_type);
204        COUNTER_TMP_VAR.store(0, std::sync::atomic::Ordering::Relaxed);
205        Ok(ir)
206    }
207
208    fn elem_size(&self, elem: gpu::ElemType) -> usize {
209        elem.size()
210    }
211
212    fn extension(&self) -> &'static str {
213        "cpp"
214    }
215}
216
217impl<D: Dialect> CppCompiler<D> {
    /// Consume the compiler and lower the full [`KernelDefinition`] into a
    /// dialect-specific [`ComputeKernel`].
    ///
    /// Ordering matters here: `compile_scope` populates `self.flags` (and
    /// other compiler state) as a side effect, so the translation [`Flags`]
    /// are only assembled afterwards.
    fn compile_ir(
        mut self,
        value: KernelDefinition,
        address_type: StorageType,
    ) -> ComputeKernel<D> {
        let metadata = self.build_metadata(&value);
        self.info = cubecl_core::Info::new(&value.scalars, metadata, address_type);

        // The body is cloned because the shared-memory optimizer below takes
        // ownership of the original `value.body`.
        let instructions = self.compile_scope(&mut value.body.clone());
        let tensor_maps = value
            .tensor_maps
            .into_iter()
            .map(|b| self.compile_binding(b))
            .collect();
        let buffers = value
            .buffers
            .into_iter()
            .map(|b| self.compile_binding(b))
            .collect();
        let scalars = value
            .scalars
            .into_iter()
            .map(|binding| (self.compile_storage_type(binding.ty), binding.count))
            .collect::<Vec<_>>();

        // Translation flags gathered while compiling the scope above, merged
        // with capability information from the compilation options.
        let flags = Flags {
            indexes: D::builtin_rules(&self.flags.indexes),
            inst_wmma: self.flags.inst_wmma,
            op_pipeline: self.flags.op_pipeline,
            op_barrier: self.flags.op_barrier,
            elem_fp4: self.flags.elem_fp4,
            elem_fp6: self.flags.elem_fp6,
            elem_fp8: self.flags.elem_fp8,
            elem_bf16: self.flags.elem_bf16,
            elem_f16: self.flags.elem_f16,
            elem_tf32: self.flags.elem_tf32,
            inst_tma: self.flags.inst_tma,
            inst_tma_im2col: self.flags.inst_tma_im2col,
            inst_async_copy: self.flags.inst_async_copy,
            inst_ptx_wrappers: self.flags.inst_ptx_wrappers,
            use_grid_constants: self.compilation_options.supports_features.grid_constants,
            has_info: self.info.has_info(),
            has_dynamic_meta: self.info.has_dynamic_meta,
            static_meta_length: self.info.metadata.static_len() as usize,
            cube_dim: value.cube_dim,
            cluster_dim: value.options.cluster_dim,
            address_type: self.compile_type(address_type.into()),
        };

        // Shared-memory liveness analysis assigns an offset to every shared
        // allocation; translate each into the C++ representation.
        let mut opt = Optimizer::shared_only(value.body, value.cube_dim);
        let shared_allocs = opt.analysis::<SharedLiveness>();
        let shared_memories = shared_allocs
            .allocations
            .values()
            .map(|alloc| match alloc.smem {
                cubecl_opt::SharedMemory::Array {
                    id,
                    length,
                    ty,
                    align,
                } => SharedMemory::Array {
                    index: id,
                    item: self.compile_type(ty),
                    length,
                    align,
                    offset: alloc.offset,
                },
                cubecl_opt::SharedMemory::Value { id, ty, align } => SharedMemory::Value {
                    index: id,
                    item: self.compile_type(ty),
                    align,
                    offset: alloc.offset,
                },
            })
            .collect();

        let body = Body {
            instructions,
            shared_memories,
            pipelines: self.pipelines,
            barriers: self.barriers,
            const_arrays: self.const_arrays,
            local_arrays: self.local_arrays,
            // Without grid-constant support, info must be passed by pointer.
            info_by_ptr: !self.compilation_options.supports_features.grid_constants,
            has_dynamic_meta: self.info.has_dynamic_meta,
            address_type: self.addr_type,
        };

        // Clusters are silently disabled on targets that do not support them.
        let mut cluster_dim = value.options.cluster_dim;
        if !self.compilation_options.supports_features.clusters {
            cluster_dim = None;
        }

        ComputeKernel {
            tensor_maps,
            buffers,
            scalars,
            meta_static_len: self.info.metadata.static_len() as usize,
            cube_dim: value.cube_dim,
            body,
            extensions: self.extensions,
            flags,
            items: self.items,
            kernel_name: value.options.kernel_name,
            cluster_dim,
            info: self.info.clone(),
        }
    }
327
328    fn build_metadata(&mut self, value: &KernelDefinition) -> cubecl_core::Metadata {
329        let mut num_ext = 0;
330
331        let mut all_meta: Vec<_> = value
332            .buffers
333            .iter()
334            .chain(value.tensor_maps.iter())
335            .map(|buf| (buf.id, buf.has_extended_meta))
336            .collect();
337
338        all_meta.sort_by_key(|(id, _)| *id);
339
340        for (_, has_extended_meta) in &all_meta {
341            self.ext_meta_positions.push(num_ext);
342            if *has_extended_meta {
343                num_ext += 1;
344            }
345        }
346
347        let num_meta = all_meta.len();
348
349        cubecl_core::Metadata::new(num_meta as u32, num_ext)
350    }
351
352    pub(crate) fn ext_meta_position(&self, var: gpu::Variable) -> u32 {
353        let id = var.index().expect("Variable should have index");
354        self.ext_meta_positions[id as usize]
355    }
356
    /// Compile every instruction in `scope` into the C++ instruction list.
    ///
    /// Also drains the scope's constant arrays into `self.const_arrays` and
    /// runs the IR processors (checked-I/O plus any dialect-specific ones)
    /// before translating instructions.
    fn compile_scope(&mut self, scope: &mut gpu::Scope) -> Vec<Instruction<D>> {
        let mut instructions = Vec::new();

        // Move the scope's constant arrays into compiler state, compiling
        // their items and values along the way.
        let const_arrays = scope
            .const_arrays
            .drain(..)
            .map(|(var, values)| ConstArray {
                index: var.index().unwrap(),
                item: self.compile_type(var.ty),
                size: values.len() as u32,
                values: values
                    .into_iter()
                    .map(|val| self.compile_variable(val))
                    .collect(),
            })
            .collect::<Vec<_>>();
        self.const_arrays.extend(const_arrays);

        // IR processors: the checked-I/O processor first, followed by any
        // processors the dialect registers.
        let checked_io: Box<dyn Processor> = Box::new(CheckedIoProcessor::new(
            self.strategy,
            self.kernel_name.clone(),
        ));
        let dialect_processors = D::processors();
        let mut processors: Vec<&dyn Processor> = vec![&*checked_io];
        processors.extend(dialect_processors.iter().map(|it| &**it));

        let processing = scope.process(processors);

        // Declare every local variable up front, before any instruction.
        for var in processing.variables {
            instructions.push(Instruction::DeclareVariable {
                var: self.compile_variable(var),
            });
        }

        processing
            .instructions
            .into_iter()
            .for_each(|op| self.compile_instruction(&mut instructions, op));

        instructions
    }
398
399    fn compile_instruction(
400        &mut self,
401        instructions: &mut Vec<Instruction<D>>,
402        instruction: gpu::Instruction,
403    ) {
404        self.update_debug_loc(instructions, &instruction);
405        let out = instruction.out;
406        match instruction.operation {
407            gpu::Operation::Copy(variable) => {
408                instructions.push(Instruction::Assign(UnaryInstruction {
409                    input: self.compile_variable(variable),
410                    out: self.compile_variable(out.unwrap()),
411                }));
412            }
413            gpu::Operation::Arithmetic(op) => {
414                self.compile_arithmetic(op, out, instruction.modes, instructions)
415            }
416            gpu::Operation::Comparison(op) => self.compile_comparison(op, out, instructions),
417            gpu::Operation::Bitwise(op) => self.compile_bitwise(op, out, instructions),
418            gpu::Operation::Operator(op) => self.compile_operator(op, out, instructions),
419            gpu::Operation::Atomic(op) => self.compile_atomic(op, out, instructions),
420            gpu::Operation::Metadata(op) => instructions.push(self.compile_metadata(op, out)),
421            gpu::Operation::Branch(val) => self.compile_branch(instructions, val),
422            gpu::Operation::Synchronization(val) => match val {
423                gpu::Synchronization::SyncCube => instructions.push(Instruction::SyncThreads),
424                gpu::Synchronization::SyncPlane => instructions.push(Instruction::SyncWarp),
425                gpu::Synchronization::SyncStorage => instructions.push(Instruction::SyncThreads),
426                gpu::Synchronization::SyncAsyncProxyShared => {
427                    self.flags.inst_tma = true;
428                    instructions.push(Instruction::ProxyAsyncToSharedFence)
429                }
430            },
431            gpu::Operation::Plane(op) => {
432                self.flags.indexes.plane_dim_checked = true;
433                let out = self.compile_variable(out.unwrap());
434                match op {
435                    gpu::Plane::Sum(op) => {
436                        let instruction = WarpInstruction::ReduceSum {
437                            input: self.compile_variable(op.input),
438                            out,
439                        };
440                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
441                        instructions.push(Instruction::Warp(instruction));
442                    }
443                    gpu::Plane::InclusiveSum(op) => {
444                        self.flags.indexes.unit_pos_plane = true;
445                        instructions.push(Instruction::Warp(WarpInstruction::InclusiveSum {
446                            input: self.compile_variable(op.input),
447                            out,
448                        }))
449                    }
450                    gpu::Plane::InclusiveProd(op) => {
451                        self.flags.indexes.unit_pos_plane = true;
452                        instructions.push(Instruction::Warp(WarpInstruction::InclusiveProd {
453                            input: self.compile_variable(op.input),
454                            out,
455                        }))
456                    }
457                    gpu::Plane::ExclusiveSum(op) => {
458                        self.flags.indexes.unit_pos_plane = true;
459                        instructions.push(Instruction::Warp(WarpInstruction::ExclusiveSum {
460                            input: self.compile_variable(op.input),
461                            out,
462                        }))
463                    }
464                    gpu::Plane::ExclusiveProd(op) => {
465                        self.flags.indexes.unit_pos_plane = true;
466                        instructions.push(Instruction::Warp(WarpInstruction::ExclusiveProd {
467                            input: self.compile_variable(op.input),
468                            out,
469                        }))
470                    }
471                    gpu::Plane::Prod(op) => {
472                        let instruction = WarpInstruction::ReduceProd {
473                            input: self.compile_variable(op.input),
474                            out,
475                        };
476                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
477                        instructions.push(Instruction::Warp(instruction))
478                    }
479                    gpu::Plane::Max(op) => {
480                        let instruction = WarpInstruction::ReduceMax {
481                            input: self.compile_variable(op.input),
482                            out,
483                        };
484                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
485                        instructions.push(Instruction::Warp(instruction))
486                    }
487                    gpu::Plane::Min(op) => {
488                        let instruction = WarpInstruction::ReduceMin {
489                            input: self.compile_variable(op.input),
490                            out,
491                        };
492                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
493                        instructions.push(Instruction::Warp(instruction))
494                    }
495                    gpu::Plane::Elect => {
496                        if self.compilation_options.supports_features.elect_sync {
497                            self.flags.inst_ptx_wrappers = true;
498                            instructions.push(Instruction::Warp(WarpInstruction::Elect { out }))
499                        } else {
500                            instructions
501                                .push(Instruction::Warp(WarpInstruction::ElectFallback { out }))
502                        }
503                    }
504                    gpu::Plane::All(op) => {
505                        instructions.push(Instruction::Warp(WarpInstruction::All {
506                            input: self.compile_variable(op.input),
507                            out,
508                        }))
509                    }
510                    gpu::Plane::Any(op) => {
511                        instructions.push(Instruction::Warp(WarpInstruction::Any {
512                            input: self.compile_variable(op.input),
513                            out,
514                        }))
515                    }
516                    gpu::Plane::Ballot(op) => {
517                        instructions.push(Instruction::Warp(WarpInstruction::Ballot {
518                            input: self.compile_variable(op.input),
519                            out,
520                        }))
521                    }
522                    gpu::Plane::Broadcast(op) => {
523                        instructions.push(Instruction::Warp(WarpInstruction::Broadcast {
524                            input: self.compile_variable(op.lhs),
525                            id: self.compile_variable(op.rhs),
526                            out,
527                        }))
528                    }
529                    gpu::Plane::Shuffle(op) => {
530                        instructions.push(Instruction::Warp(WarpInstruction::Shuffle {
531                            input: self.compile_variable(op.lhs),
532                            src_lane: self.compile_variable(op.rhs),
533                            out,
534                        }))
535                    }
536                    gpu::Plane::ShuffleXor(op) => {
537                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleXor {
538                            input: self.compile_variable(op.lhs),
539                            mask: self.compile_variable(op.rhs),
540                            out,
541                        }))
542                    }
543                    gpu::Plane::ShuffleUp(op) => {
544                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleUp {
545                            input: self.compile_variable(op.lhs),
546                            delta: self.compile_variable(op.rhs),
547                            out,
548                        }))
549                    }
550                    gpu::Plane::ShuffleDown(op) => {
551                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleDown {
552                            input: self.compile_variable(op.lhs),
553                            delta: self.compile_variable(op.rhs),
554                            out,
555                        }))
556                    }
557                }
558            }
559            gpu::Operation::CoopMma(cmma) => instructions.push(self.compile_cmma(cmma, out)),
560            gpu::Operation::NonSemantic(debug) => match debug {
561                gpu::NonSemantic::Print {
562                    format_string,
563                    args,
564                } => instructions.push(Instruction::Printf {
565                    format_string,
566                    args: args
567                        .into_iter()
568                        .map(|arg| self.compile_variable(arg))
569                        .collect(),
570                }),
571                gpu::NonSemantic::Comment { content } => {
572                    instructions.push(Instruction::Comment { content })
573                }
574                // Don't need to handle scopes
575                _ => {}
576            },
577            gpu::Operation::Barrier(barrier_ops) => match barrier_ops {
578                gpu::BarrierOps::Declare { barrier } => {
579                    let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type()
580                    else {
581                        unreachable!()
582                    };
583                    let barrier = self.compile_variable(barrier);
584                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Declare {
585                        barrier,
586                        level,
587                    }));
588                }
589                gpu::BarrierOps::Init {
590                    barrier,
591                    is_elected,
592                    arrival_count,
593                } => {
594                    let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type()
595                    else {
596                        unreachable!()
597                    };
598                    let barrier = self.compile_variable(barrier);
599                    let arrival_count = self.compile_variable(arrival_count);
600                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Init {
601                        barrier,
602                        is_elected: self.compile_variable(is_elected),
603                        arrival_count,
604                        level,
605                    }));
606                }
607                gpu::BarrierOps::InitManual {
608                    barrier,
609                    arrival_count,
610                } => {
611                    let barrier = self.compile_variable(barrier);
612                    let arrival_count = self.compile_variable(arrival_count);
613                    instructions.push(Instruction::Barrier(
614                        super::barrier::BarrierOps::InitManual {
615                            barrier,
616                            arrival_count,
617                        },
618                    ));
619                }
620                gpu::BarrierOps::MemCopyAsync {
621                    barrier,
622                    source,
623                    source_length,
624                    offset_source,
625                    offset_out,
626                } => {
627                    instructions.push(Instruction::Barrier(
628                        super::barrier::BarrierOps::MemCopyAsync {
629                            barrier: self.compile_variable(barrier),
630                            source: self.compile_variable(source),
631                            destination: self.compile_variable(out.unwrap()),
632                            source_length: self.compile_variable(source_length),
633                            offset_source: self.compile_variable(offset_source),
634                            offset_out: self.compile_variable(offset_out),
635                            cooperative: false,
636                        },
637                    ));
638                }
639                gpu::BarrierOps::MemCopyAsyncCooperative {
640                    barrier,
641                    source,
642                    source_length,
643                    offset_source,
644                    offset_out,
645                } => {
646                    instructions.push(Instruction::Barrier(
647                        super::barrier::BarrierOps::MemCopyAsync {
648                            barrier: self.compile_variable(barrier),
649                            source: self.compile_variable(source),
650                            destination: self.compile_variable(out.unwrap()),
651                            source_length: self.compile_variable(source_length),
652                            offset_source: self.compile_variable(offset_source),
653                            offset_out: self.compile_variable(offset_out),
654                            cooperative: true,
655                        },
656                    ));
657                }
658                gpu::BarrierOps::MemCopyAsyncTx {
659                    barrier,
660                    source,
661                    source_length,
662                    offset_source,
663                    offset_out,
664                } => {
665                    instructions.push(Instruction::Barrier(
666                        super::barrier::BarrierOps::MemCopyAsyncTx {
667                            barrier: self.compile_variable(barrier),
668                            source: self.compile_variable(source),
669                            destination: self.compile_variable(out.unwrap()),
670                            source_length: self.compile_variable(source_length),
671                            offset_source: self.compile_variable(offset_source),
672                            offset_out: self.compile_variable(offset_out),
673                        },
674                    ));
675                }
676                gpu::BarrierOps::CopyAsync {
677                    source,
678                    source_length,
679                    offset_source,
680                    offset_out,
681                    copy_length,
682                    checked,
683                } => {
684                    self.flags.inst_async_copy = true;
685                    instructions.push(Instruction::Barrier(
686                        super::barrier::BarrierOps::CopyAsync {
687                            source: self.compile_variable(source),
688                            destination: self.compile_variable(out.unwrap()),
689                            source_length: self.compile_variable(source_length),
690                            offset_source: self.compile_variable(offset_source),
691                            offset_out: self.compile_variable(offset_out),
692                            copy_size: copy_length,
693                            checked,
694                        },
695                    ));
696                }
697                gpu::BarrierOps::TmaLoad {
698                    barrier,
699                    tensor_map,
700                    offset_out,
701                    indices,
702                } => {
703                    instructions.push(Instruction::Barrier(
704                        super::barrier::BarrierOps::MemCopyAsyncTensorGlobalToShared {
705                            barrier: self.compile_variable(barrier),
706                            smem_buffer: self.compile_variable(out.unwrap()),
707                            smem_offset: self.compile_variable(offset_out),
708                            tensor_map: self.compile_variable(tensor_map),
709                            indices: indices
710                                .into_iter()
711                                .map(|it| self.compile_variable(it))
712                                .collect(),
713                        },
714                    ));
715                }
716                gpu::BarrierOps::TmaLoadIm2col {
717                    barrier,
718                    tensor_map,
719                    offset_out,
720                    indices,
721                    offsets,
722                } => {
723                    self.flags.inst_tma_im2col = true;
724                    instructions.push(Instruction::Barrier(
725                        super::barrier::BarrierOps::TmaLoadIm2col {
726                            barrier: self.compile_variable(barrier),
727                            smem_buffer: self.compile_variable(out.unwrap()),
728                            smem_offset: self.compile_variable(offset_out),
729                            tensor_map: self.compile_variable(tensor_map),
730                            indices: indices
731                                .into_iter()
732                                .map(|it| self.compile_variable(it))
733                                .collect(),
734                            offsets: offsets
735                                .into_iter()
736                                .map(|it| self.compile_variable(it))
737                                .collect(),
738                        },
739                    ));
740                }
741                gpu::BarrierOps::Arrive { barrier } => {
742                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Arrive {
743                        barrier: self.compile_variable(barrier),
744                        token: self.compile_variable(out.unwrap()),
745                    }))
746                }
747                gpu::BarrierOps::ArriveTx {
748                    barrier,
749                    arrive_count_update,
750                    transaction_count_update,
751                } => {
752                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::ArriveTx {
753                        barrier: self.compile_variable(barrier),
754                        token: self.compile_variable(out.unwrap()),
755                        arrive_count_update: self.compile_variable(arrive_count_update),
756                        transaction_count_update: self.compile_variable(transaction_count_update),
757                    }))
758                }
759                gpu::BarrierOps::CommitCopyAsync { barrier } => {
760                    self.flags.inst_async_copy = true;
761                    instructions.push(Instruction::Barrier(
762                        super::barrier::BarrierOps::ArriveCopyAsync {
763                            barrier: self.compile_variable(barrier),
764                        },
765                    ))
766                }
767                gpu::BarrierOps::ExpectTx {
768                    barrier,
769                    transaction_count_update,
770                } => {
771                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::ExpectTx {
772                        barrier: self.compile_variable(barrier),
773                        transaction_count_update: self.compile_variable(transaction_count_update),
774                    }))
775                }
776                gpu::BarrierOps::Wait { barrier, token } => {
777                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Wait {
778                        barrier: self.compile_variable(barrier),
779                        token: self.compile_variable(token),
780                    }))
781                }
782                gpu::BarrierOps::WaitParity { barrier, phase } => instructions.push(
783                    Instruction::Barrier(super::barrier::BarrierOps::WaitParity {
784                        barrier: self.compile_variable(barrier),
785                        phase: self.compile_variable(phase),
786                    }),
787                ),
788                gpu::BarrierOps::ArriveAndWait { barrier } => {
789                    let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type()
790                    else {
791                        unreachable!()
792                    };
793                    instructions.push(Instruction::Barrier(
794                        super::barrier::BarrierOps::ArriveAndWait {
795                            barrier: self.compile_variable(barrier),
796                            level,
797                        },
798                    ))
799                }
800            },
801            gpu::Operation::Tma(tma_ops) => {
802                self.flags.inst_tma = true;
803                match tma_ops {
804                    gpu::TmaOps::TmaStore {
805                        source,
806                        coordinates,
807                        offset_source,
808                    } => {
809                        instructions.push(Instruction::MemCopyAsyncTensorSharedToGlobal {
810                            smem_buffer: self.compile_variable(source),
811                            smem_offset: self.compile_variable(offset_source),
812                            tensor_map: self.compile_variable(out.unwrap()),
813                            indices: coordinates
814                                .into_iter()
815                                .map(|it| self.compile_variable(it))
816                                .collect(),
817                        });
818                    }
819                    gpu::TmaOps::CommitGroup => {
820                        instructions.push(Instruction::BulkCommitGroup);
821                    }
822                    gpu::TmaOps::WaitGroup { max_pending } => {
823                        instructions.push(Instruction::BulkWaitGroup { max_pending });
824                    }
825                    gpu::TmaOps::WaitGroupRead { max_pending } => {
826                        instructions.push(Instruction::BulkWaitGroupRead { max_pending });
827                    }
828                }
829            }
830            gpu::Operation::Marker(_) => {}
831        }
832    }
833
834    fn update_debug_loc(
835        &mut self,
836        instructions: &mut Vec<Instruction<D>>,
837        inst: &gpu::Instruction,
838    ) {
839        if !matches!(inst.operation, Operation::NonSemantic(_)) {
840            match &inst.source_loc {
841                Some(loc) if Some(loc) != self.source_loc.as_ref() => {
842                    self.source_loc = Some(loc.clone());
843                    instructions.push(Instruction::Line {
844                        file: loc.source.file.clone(),
845                        line: loc.line,
846                    });
847                }
848                _ => {}
849            }
850        }
851    }
852
    /// Compiles a cooperative-matrix (CMMA) operation into a [`WmmaInstruction`],
    /// registering any dialect-specific extension the instruction requires.
    ///
    /// Sets `flags.inst_wmma` so the kernel emits the WMMA support code.
    /// `out` is the destination variable and must be present for all variants.
    fn compile_cmma(&mut self, cmma: gpu::CoopMma, out: Option<gpu::Variable>) -> Instruction<D> {
        self.flags.inst_wmma = true;

        let out = self.compile_variable(out.unwrap());

        let inst = match cmma {
            // Fill every element of the output fragment with a single value.
            gpu::CoopMma::Fill { value } => WmmaInstruction::Fill {
                frag: out,
                value: self.compile_variable(value),
            },
            // Load a fragment from memory; layout is optional here (it may be
            // inferred), unlike Store which requires one.
            gpu::CoopMma::Load {
                value,
                stride,
                offset,
                layout,
            } => WmmaInstruction::Load {
                frag: out,
                offset: self.compile_variable(offset),
                value: self.compile_variable(value),
                stride: self.compile_variable(stride),
                layout: layout.and_then(|l| self.compile_matrix_layout(l)),
            },
            // d = a * b + c using fragment-level WMMA.
            gpu::CoopMma::Execute {
                mat_a,
                mat_b,
                mat_c,
            } => WmmaInstruction::Execute {
                frag_a: self.compile_variable(mat_a),
                frag_b: self.compile_variable(mat_b),
                frag_c: self.compile_variable(mat_c),
                frag_d: out,
                warp_size: self.compilation_options.warp_size,
            },
            // MMA with explicitly managed registers and an explicit m/n/k shape.
            gpu::CoopMma::ExecuteManual {
                matrix,
                registers_a,
                registers_b,
                registers_c,
            } => WmmaInstruction::ExecuteManual {
                shape: MmaShape::new(matrix.m as u32, matrix.n as u32, matrix.k as u32),
                frag_a: self.compile_variable(registers_a),
                frag_b: self.compile_variable(registers_b),
                frag_c: self.compile_variable(registers_c),
                frag_d: out,
            },
            // Manual MMA with per-block scaling factors applied to a and b.
            gpu::CoopMma::ExecuteScaled {
                matrix,
                registers_a,
                registers_b,
                registers_c,
                scales_a,
                scales_b,
                scales_factor,
            } => WmmaInstruction::ExecuteScaled {
                shape: MmaShape::new(matrix.m as u32, matrix.n as u32, matrix.k as u32),
                frag_a: self.compile_variable(registers_a),
                frag_b: self.compile_variable(registers_b),
                frag_c: self.compile_variable(registers_c),
                frag_d: out,

                scales_a: self.compile_variable(scales_a),
                scales_b: self.compile_variable(scales_b),
                scales_factor: scales_factor as u32,
            },
            gpu::CoopMma::Store {
                mat,
                stride,
                offset,
                layout,
            } => {
                // Store lowering uses unit/plane position indexes, so make sure
                // they are declared in the kernel prelude.
                self.flags.indexes.unit_pos = true;
                self.flags.indexes.plane_pos = true;
                WmmaInstruction::Store {
                    output: out,
                    offset: self.compile_variable(offset),
                    frag: self.compile_variable(mat),
                    stride: self.compile_variable(stride),
                    // Unlike Load, a store cannot infer the layout.
                    layout: self
                        .compile_matrix_layout(layout)
                        .expect("Layout required for store instruction"),
                }
            }
            // ldmatrix-style register load from a buffer.
            gpu::CoopMma::LoadMatrix {
                buffer,
                offset,
                vector_size,
                factor,
                transpose,
            } => WmmaInstruction::LdMatrix {
                output: out,
                buffer: self.compile_variable(buffer),
                offset: self.compile_variable(offset),
                vector_size,
                factor: factor as u32,
                transpose,
            },
            // stmatrix-style register store into a buffer; `out` is the buffer.
            gpu::CoopMma::StoreMatrix {
                offset,
                vector_size,
                registers,
                factor,
                transpose,
            } => WmmaInstruction::StMatrix {
                registers: self.compile_variable(registers),
                buffer: out,
                offset: self.compile_variable(offset),
                vector_size,
                factor: factor as u32,
                transpose,
            },
            // Element-type cast between fragments.
            gpu::CoopMma::Cast { input } => WmmaInstruction::Cast {
                input: self.compile_variable(input),
                output: out,
            },
            // These variants are rewritten earlier by IR processors and must
            // never reach the backend.
            gpu::CoopMma::RowIndex { .. } | gpu::CoopMma::ColIndex { .. } => {
                panic!("Row/Col index should be handled by processors")
            }
        };

        D::register_wmma_instruction_extension(&mut self.extensions, &inst);

        Instruction::Wmma(inst)
    }
976
977    fn compile_metadata(
978        &mut self,
979        metadata: gpu::Metadata,
980        out: Option<gpu::Variable>,
981    ) -> Instruction<D> {
982        let out = out.unwrap();
983        match metadata {
984            gpu::Metadata::Stride { dim, var } => {
985                let position = self.ext_meta_position(var);
986                let offset = self.info.metadata.stride_offset_index(position);
987                Instruction::ExtendedMetadata {
988                    info_offset: self.compile_variable(offset.into()),
989                    dim: self.compile_variable(dim),
990                    out: self.compile_variable(out),
991                }
992            }
993            gpu::Metadata::Shape { dim, var } => {
994                let position = self.ext_meta_position(var);
995                let offset = self.info.metadata.shape_offset_index(position);
996                Instruction::ExtendedMetadata {
997                    info_offset: self.compile_variable(offset.into()),
998                    dim: self.compile_variable(dim),
999                    out: self.compile_variable(out),
1000                }
1001            }
1002            gpu::Metadata::Rank { var } => {
1003                let out = self.compile_variable(out);
1004                let pos = self.ext_meta_position(var);
1005                let offset = self.info.metadata.rank_index(pos);
1006                super::Instruction::Metadata {
1007                    info_offset: self.compile_variable(offset.into()),
1008                    out,
1009                }
1010            }
1011            gpu::Metadata::Length { var } => {
1012                let input = self.compile_variable(var);
1013                let out = self.compile_variable(out);
1014
1015                match input {
1016                    Variable::Slice { .. } => Instruction::SliceLength { input, out },
1017                    Variable::SharedArray(_id, _item, length) => {
1018                        Instruction::ConstLength { length, out }
1019                    }
1020                    _ => {
1021                        let id = input.id().expect("Variable should have id");
1022                        let offset = self.info.metadata.len_index(id);
1023                        Instruction::Metadata {
1024                            info_offset: self.compile_variable(offset.into()),
1025                            out,
1026                        }
1027                    }
1028                }
1029            }
1030            gpu::Metadata::BufferLength { var } => {
1031                let input = self.compile_variable(var);
1032                let out = self.compile_variable(out);
1033
1034                match input {
1035                    Variable::Slice { .. } => Instruction::SliceLength { input, out },
1036                    _ => {
1037                        let id = input.id().expect("Variable should have id");
1038                        let offset = self.info.metadata.buffer_len_index(id);
1039                        Instruction::Metadata {
1040                            info_offset: self.compile_variable(offset.into()),
1041                            out,
1042                        }
1043                    }
1044                }
1045            }
1046        }
1047    }
1048
1049    fn compile_branch(&mut self, instructions: &mut Vec<Instruction<D>>, branch: gpu::Branch) {
1050        match branch {
1051            gpu::Branch::If(mut op) => instructions.push(Instruction::If {
1052                cond: self.compile_variable(op.cond),
1053                instructions: self.compile_scope(&mut op.scope),
1054            }),
1055            gpu::Branch::IfElse(mut op) => instructions.push(Instruction::IfElse {
1056                cond: self.compile_variable(op.cond),
1057                instructions_if: self.compile_scope(&mut op.scope_if),
1058                instructions_else: self.compile_scope(&mut op.scope_else),
1059            }),
1060            gpu::Branch::Switch(mut op) => instructions.push(Instruction::Switch {
1061                value: self.compile_variable(op.value),
1062                instructions_default: self.compile_scope(&mut op.scope_default),
1063                instructions_cases: op
1064                    .cases
1065                    .into_iter()
1066                    .map(|(val, mut block)| {
1067                        (self.compile_variable(val), self.compile_scope(&mut block))
1068                    })
1069                    .collect(),
1070            }),
1071            gpu::Branch::Return => instructions.push(Instruction::Return),
1072            gpu::Branch::Break => instructions.push(Instruction::Break),
1073            gpu::Branch::Unreachable => instructions.push(Instruction::Unreachable),
1074            gpu::Branch::RangeLoop(mut range_loop) => instructions.push(Instruction::RangeLoop {
1075                i: self.compile_variable(range_loop.i),
1076                start: self.compile_variable(range_loop.start),
1077                end: self.compile_variable(range_loop.end),
1078                step: range_loop.step.map(|it| self.compile_variable(it)),
1079                inclusive: range_loop.inclusive,
1080                instructions: self.compile_scope(&mut range_loop.scope),
1081            }),
1082            gpu::Branch::Loop(mut op) => instructions.push(Instruction::Loop {
1083                instructions: self.compile_scope(&mut op.scope),
1084            }),
1085        };
1086    }
1087
1088    fn compile_atomic(
1089        &mut self,
1090        value: gpu::AtomicOp,
1091        out: Option<gpu::Variable>,
1092        instructions: &mut Vec<Instruction<D>>,
1093    ) {
1094        let out = out.unwrap();
1095        match value {
1096            gpu::AtomicOp::Load(op) => {
1097                instructions.push(Instruction::AtomicLoad(self.compile_unary(op, out)))
1098            }
1099            gpu::AtomicOp::Store(op) => {
1100                instructions.push(Instruction::AtomicStore(self.compile_unary(op, out)))
1101            }
1102            gpu::AtomicOp::Swap(op) => {
1103                instructions.push(Instruction::AtomicSwap(self.compile_binary(op, out)))
1104            }
1105            gpu::AtomicOp::Add(op) => {
1106                instructions.push(Instruction::AtomicAdd(self.compile_binary(op, out)))
1107            }
1108            gpu::AtomicOp::Sub(op) => {
1109                instructions.push(Instruction::AtomicSub(self.compile_binary(op, out)))
1110            }
1111            gpu::AtomicOp::Max(op) => {
1112                instructions.push(Instruction::AtomicMax(self.compile_binary(op, out)))
1113            }
1114            gpu::AtomicOp::Min(op) => {
1115                instructions.push(Instruction::AtomicMin(self.compile_binary(op, out)))
1116            }
1117            gpu::AtomicOp::And(op) => {
1118                instructions.push(Instruction::AtomicAnd(self.compile_binary(op, out)))
1119            }
1120            gpu::AtomicOp::Or(op) => {
1121                instructions.push(Instruction::AtomicOr(self.compile_binary(op, out)))
1122            }
1123            gpu::AtomicOp::Xor(op) => {
1124                instructions.push(Instruction::AtomicXor(self.compile_binary(op, out)))
1125            }
1126            gpu::AtomicOp::CompareAndSwap(op) => instructions.push(Instruction::AtomicCAS {
1127                input: self.compile_variable(op.input),
1128                cmp: self.compile_variable(op.cmp),
1129                val: self.compile_variable(op.val),
1130                out: self.compile_variable(out),
1131            }),
1132        }
1133    }
1134
    /// Compiles an arithmetic operation into instructions.
    ///
    /// Precision-sensitive float ops (div, exp, log, sin/cos, sqrt, powf, …) go
    /// through `select_fast_float`, which chooses between the exact instruction
    /// and a fast approximate variant based on the instruction's fast-math
    /// `modes` and the listed required relaxation flags. Several transcendental
    /// ops also register a dialect extension (helper code emitted with the
    /// kernel) via `D::register_instruction_extension`.
    fn compile_arithmetic(
        &mut self,
        value: gpu::Arithmetic,
        out: Option<gpu::Variable>,
        modes: InstructionModes,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::Arithmetic::Add(op) => {
                instructions.push(Instruction::Add(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::SaturatingAdd(op) => {
                instructions.push(Instruction::SaturatingAdd(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Mul(op) => {
                instructions.push(Instruction::Mul(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Div(op) => {
                // Fast division additionally requires reciprocal rewriting to
                // be allowed, on top of the usual precision relaxations.
                let op = self.compile_binary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::AllowReciprocal
                        | FastMath::ReducedPrecision
                        | FastMath::UnsignedZero
                        | FastMath::NotInf,
                    Instruction::Div(op),
                    Instruction::FastDiv(op),
                ))
            }
            gpu::Arithmetic::Sub(op) => {
                instructions.push(Instruction::Sub(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::SaturatingSub(op) => {
                instructions.push(Instruction::SaturatingSub(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::MulHi(op) => {
                let instruction = Instruction::HiMul(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Modulo(op) => {
                instructions.push(Instruction::Modulo(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Abs(op) => {
                instructions.push(Instruction::Abs(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Exp(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Exp(op),
                    Instruction::FastExp(op),
                ));
            }
            gpu::Arithmetic::Log(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Log(op),
                    Instruction::FastLog(op),
                ));
            }
            gpu::Arithmetic::Log1p(op) => {
                instructions.push(Instruction::Log1p(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Cos(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Cos(op),
                    Instruction::FastCos(op),
                ));
            }
            gpu::Arithmetic::Sin(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Sin(op),
                    Instruction::FastSin(op),
                ));
            }
            gpu::Arithmetic::Tan(op) => {
                instructions.push(Instruction::Tan(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Tanh(op) => {
                let op = self.compile_unary(op, out);
                let instruction = Instruction::Tanh(op);
                // NOTE: the extension is registered even when the fast path is
                // later selected; presumably harmless since extensions are a
                // set of helpers, not emitted per-call.
                D::register_instruction_extension(&mut self.extensions, &instruction);
                // The fast tanh variant is only valid on targets that declare
                // support for it (e.g. via compilation options).
                if self.compilation_options.supports_features.fast_tanh {
                    instructions.push(self.select_fast_float(
                        out.ty,
                        modes,
                        FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                        instruction,
                        Instruction::FastTanh(op),
                    ))
                } else {
                    instructions.push(instruction);
                }
            }
            gpu::Arithmetic::Sinh(op) => {
                let instruction = Instruction::Sinh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Cosh(op) => {
                let instruction = Instruction::Cosh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcCos(op) => {
                let instruction = Instruction::ArcCos(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcSin(op) => {
                let instruction = Instruction::ArcSin(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcTan(op) => {
                let instruction = Instruction::ArcTan(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcSinh(op) => {
                let instruction = Instruction::ArcSinh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcCosh(op) => {
                let instruction = Instruction::ArcCosh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcTanh(op) => {
                let instruction = Instruction::ArcTanh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Degrees(op) => {
                let instruction = Instruction::Degrees(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Radians(op) => {
                let instruction = Instruction::Radians(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcTan2(op) => {
                let instruction = Instruction::ArcTan2(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Powf(op) => {
                let op = self.compile_binary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Powf(op),
                    Instruction::FastPowf(op),
                ))
            }
            gpu::Arithmetic::Powi(op) => {
                instructions.push(Instruction::Powi(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Hypot(op) => {
                instructions.push(Instruction::Hypot(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Rhypot(op) => {
                instructions.push(Instruction::Rhypot(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Sqrt(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Sqrt(op),
                    Instruction::FastSqrt(op),
                ))
            }
            gpu::Arithmetic::InverseSqrt(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::InverseSqrt(op),
                    Instruction::FastInverseSqrt(op),
                ))
            }
            gpu::Arithmetic::Erf(op) => {
                let instruction = Instruction::Erf(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Max(op) => {
                let instruction = Instruction::Max(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Min(op) => {
                let instruction = Instruction::Min(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Clamp(op) => instructions.push(Instruction::Clamp {
                input: self.compile_variable(op.input),
                min_value: self.compile_variable(op.min_value),
                max_value: self.compile_variable(op.max_value),
                out: self.compile_variable(out),
            }),
            gpu::Arithmetic::Recip(op) => {
                // Recip lowers to either an exact `1/x` division or the fast
                // reciprocal instruction, depending on fast-math modes. The
                // `1` constant is built to match the input's element type.
                let elem = op.input.ty.elem_type();
                let input = self.compile_variable(op.input);
                let out = self.compile_variable(out);
                let lhs = match elem {
                    gpu::ElemType::Float(_) => gpu::ConstantValue::Float(1.0),
                    gpu::ElemType::Int(_) => gpu::ConstantValue::Int(1),
                    gpu::ElemType::UInt(_) => gpu::ConstantValue::UInt(1),
                    gpu::ElemType::Bool => gpu::ConstantValue::Bool(true),
                };
                let div = Instruction::Div(BinaryInstruction {
                    lhs: Variable::Constant(lhs, self.compile_type(op.input.ty)),
                    rhs: input,
                    out,
                });
                let recip = Instruction::FastRecip(UnaryInstruction { input, out });

                instructions.push(self.select_fast_float(
                    elem.into(),
                    modes,
                    FastMath::AllowReciprocal
                        | FastMath::ReducedPrecision
                        | FastMath::UnsignedZero
                        | FastMath::NotInf,
                    div,
                    recip,
                ))
            }
            gpu::Arithmetic::Round(op) => {
                instructions.push(Instruction::Round(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Floor(op) => {
                instructions.push(Instruction::Floor(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Ceil(op) => {
                instructions.push(Instruction::Ceil(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Trunc(op) => {
                instructions.push(Instruction::Trunc(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Remainder(op) => {
                instructions.push(Instruction::Remainder(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Fma(op) => instructions.push(Instruction::Fma {
                a: self.compile_variable(op.a),
                b: self.compile_variable(op.b),
                c: self.compile_variable(op.c),
                out: self.compile_variable(out),
            }),
            gpu::Arithmetic::Neg(op) => {
                instructions.push(Instruction::Neg(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Normalize(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Normalize(op),
                    Instruction::FastNormalize(op),
                ))
            }
            gpu::Arithmetic::Magnitude(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Magnitude(op),
                    Instruction::FastMagnitude(op),
                ))
            }
            gpu::Arithmetic::Dot(op) => {
                instructions.push(Instruction::Dot(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::VectorSum(op) => {
                instructions.push(Instruction::VectorSum(self.compile_unary(op, out)))
            }
        };
    }
1440
1441    fn select_fast_float(
1442        &self,
1443        ty: gpu::Type,
1444        modes: InstructionModes,
1445        required_flags: EnumSet<FastMath>,
1446        default: Instruction<D>,
1447        fast: Instruction<D>,
1448    ) -> Instruction<D> {
1449        if !self.compilation_options.supports_features.fast_math
1450            || !matches!(ty.elem_type(), ElemType::Float(FloatKind::F32))
1451        {
1452            return default;
1453        }
1454
1455        if modes.fp_math_mode.is_superset(required_flags) {
1456            fast
1457        } else {
1458            default
1459        }
1460    }
1461
1462    fn compile_comparison(
1463        &mut self,
1464        value: gpu::Comparison,
1465        out: Option<gpu::Variable>,
1466        instructions: &mut Vec<Instruction<D>>,
1467    ) {
1468        let out = out.unwrap();
1469        match value {
1470            gpu::Comparison::Equal(op) => {
1471                instructions.push(Instruction::Equal(self.compile_binary(op, out)))
1472            }
1473            gpu::Comparison::Lower(op) => {
1474                instructions.push(Instruction::Lower(self.compile_binary(op, out)))
1475            }
1476            gpu::Comparison::Greater(op) => {
1477                instructions.push(Instruction::Greater(self.compile_binary(op, out)))
1478            }
1479            gpu::Comparison::LowerEqual(op) => {
1480                instructions.push(Instruction::LowerEqual(self.compile_binary(op, out)))
1481            }
1482            gpu::Comparison::GreaterEqual(op) => {
1483                instructions.push(Instruction::GreaterEqual(self.compile_binary(op, out)))
1484            }
1485            gpu::Comparison::NotEqual(op) => {
1486                instructions.push(Instruction::NotEqual(self.compile_binary(op, out)))
1487            }
1488            gpu::Comparison::IsNan(op) => {
1489                instructions.push(Instruction::IsNan(self.compile_unary(op, out)))
1490            }
1491            gpu::Comparison::IsInf(op) => {
1492                instructions.push(Instruction::IsInf(self.compile_unary(op, out)))
1493            }
1494        };
1495    }
1496
1497    fn compile_bitwise(
1498        &mut self,
1499        value: gpu::Bitwise,
1500        out: Option<gpu::Variable>,
1501        instructions: &mut Vec<Instruction<D>>,
1502    ) {
1503        let out = out.unwrap();
1504        match value {
1505            gpu::Bitwise::BitwiseOr(op) => {
1506                instructions.push(Instruction::BitwiseOr(self.compile_binary(op, out)))
1507            }
1508            gpu::Bitwise::BitwiseAnd(op) => {
1509                instructions.push(Instruction::BitwiseAnd(self.compile_binary(op, out)))
1510            }
1511            gpu::Bitwise::BitwiseXor(op) => {
1512                instructions.push(Instruction::BitwiseXor(self.compile_binary(op, out)))
1513            }
1514            gpu::Bitwise::CountOnes(op) => {
1515                instructions.push(Instruction::CountBits(self.compile_unary(op, out)))
1516            }
1517            gpu::Bitwise::ReverseBits(op) => {
1518                instructions.push(Instruction::ReverseBits(self.compile_unary(op, out)))
1519            }
1520            gpu::Bitwise::ShiftLeft(op) => {
1521                instructions.push(Instruction::ShiftLeft(self.compile_binary(op, out)))
1522            }
1523            gpu::Bitwise::ShiftRight(op) => {
1524                instructions.push(Instruction::ShiftRight(self.compile_binary(op, out)))
1525            }
1526            gpu::Bitwise::BitwiseNot(op) => {
1527                instructions.push(Instruction::BitwiseNot(self.compile_unary(op, out)))
1528            }
1529            gpu::Bitwise::LeadingZeros(op) => {
1530                instructions.push(Instruction::LeadingZeros(self.compile_unary(op, out)))
1531            }
1532            gpu::Bitwise::TrailingZeros(op) => {
1533                instructions.push(Instruction::TrailingZeros(self.compile_unary(op, out)))
1534            }
1535            gpu::Bitwise::FindFirstSet(op) => {
1536                let instruction = Instruction::FindFirstSet(self.compile_unary(op, out));
1537                D::register_instruction_extension(&mut self.extensions, &instruction);
1538                instructions.push(instruction)
1539            }
1540        };
1541    }
1542
    /// Lowers a generic operator (indexing, logic, memory copies, selects and
    /// casts) into C++ instructions.
    ///
    /// Checked and unchecked index variants lower identically here; bounds
    /// checking is presumably handled earlier (the file imports
    /// `CheckedIoProcessor`) — TODO confirm.
    fn compile_operator(
        &mut self,
        value: gpu::Operator,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::Operator::Index(op) | gpu::Operator::UncheckedIndex(op) => {
                instructions.push(Instruction::Index(self.compile_index(op, out)));
            }
            gpu::Operator::IndexAssign(op) | gpu::Operator::UncheckedIndexAssign(op) => {
                instructions.push(Instruction::IndexAssign(self.compile_index_assign(op, out)));
            }
            gpu::Operator::And(op) => {
                instructions.push(Instruction::And(self.compile_binary(op, out)))
            }
            gpu::Operator::Or(op) => {
                instructions.push(Instruction::Or(self.compile_binary(op, out)))
            }
            gpu::Operator::Not(op) => {
                instructions.push(Instruction::Not(self.compile_unary(op, out)))
            }
            gpu::Operator::InitVector(op) => instructions.push(Instruction::VecInit {
                inputs: op
                    .inputs
                    .into_iter()
                    .map(|it| self.compile_variable(it))
                    .collect(),
                out: self.compile_variable(out),
            }),
            gpu::Operator::CopyMemory(op) => instructions.push(Instruction::Copy {
                input: self.compile_variable(op.input),
                in_index: self.compile_variable(op.in_index),
                out: self.compile_variable(out),
                out_index: self.compile_variable(op.out_index),
            }),
            gpu::Operator::CopyMemoryBulk(op) => instructions.push(Instruction::CopyBulk {
                input: self.compile_variable(op.input),
                in_index: self.compile_variable(op.in_index),
                out: self.compile_variable(out),
                out_index: self.compile_variable(op.out_index),
                len: op.len as u32,
            }),
            gpu::Operator::Select(op) => instructions.push(Instruction::Select {
                cond: self.compile_variable(op.cond),
                then: self.compile_variable(op.then),
                or_else: self.compile_variable(op.or_else),
                out: self.compile_variable(out),
            }),
            // Needs special conversion semantics
            gpu::Operator::Cast(op)
                if (is_fp4_fp6_fp8(op.input.elem_type()) || is_fp4_fp6_fp8(out.elem_type()))
                // Trivial broadcast shouldn't use special cast logic
                    && op.input.elem_type() != out.elem_type() =>
            {
                // We may need these for intermediates
                self.flags.elem_f16 = true;
                self.flags.elem_bf16 = true;
                let vec_in = op.input.ty.vector_size();
                let packing = out.storage_type().packing_factor();
                // Pre-register every type the special cast might materialize
                // (input repacked to the output's packing factor, plus f16/bf16
                // intermediates at both the input and packed vector widths), so
                // their C++ typedefs are declared in the kernel.
                self.compile_type(op.input.ty.with_vector_size(packing));
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::F16))
                        .with_vector_size(vec_in),
                );
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::BF16))
                        .with_vector_size(vec_in),
                );
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::F16))
                        .with_vector_size(packing),
                );
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::BF16))
                        .with_vector_size(packing),
                );

                let inst = self.compile_unary(op, out);

                instructions.push(Instruction::SpecialCast(inst));
            }
            gpu::Operator::Cast(op) => {
                let op = self.compile_unary(op, out);

                // TF32 on either side needs the tf32 type declarations emitted.
                if op.input.elem() == Elem::TF32 || op.out.elem() == Elem::TF32 {
                    self.flags.elem_tf32 = true;
                }

                // Ordinary casts lower to a plain assignment; the C++ side
                // performs the conversion.
                instructions.push(Instruction::Assign(op))
            }
            gpu::Operator::Reinterpret(op) => {
                instructions.push(Instruction::Bitcast(self.compile_unary(op, out)))
            }
        };
    }
1640
1641    fn compile_binary(
1642        &mut self,
1643        value: gpu::BinaryOperator,
1644        out: gpu::Variable,
1645    ) -> BinaryInstruction<D> {
1646        BinaryInstruction {
1647            lhs: self.compile_variable(value.lhs),
1648            rhs: self.compile_variable(value.rhs),
1649            out: self.compile_variable(out),
1650        }
1651    }
1652
1653    fn compile_index_assign(
1654        &mut self,
1655        value: gpu::IndexAssignOperator,
1656        out: gpu::Variable,
1657    ) -> IndexAssignInstruction<D> {
1658        IndexAssignInstruction {
1659            index: self.compile_variable(value.index),
1660            value: self.compile_variable(value.value),
1661            vector_size: value.vector_size as u32,
1662            out: self.compile_variable(out),
1663        }
1664    }
1665
1666    fn compile_index(
1667        &mut self,
1668        value: gpu::IndexOperator,
1669        out: gpu::Variable,
1670    ) -> IndexInstruction<D> {
1671        IndexInstruction {
1672            list: self.compile_variable(value.list),
1673            index: self.compile_variable(value.index),
1674            vector_size: value.vector_size as u32,
1675            out: self.compile_variable(out),
1676        }
1677    }
1678
1679    fn compile_unary(
1680        &mut self,
1681        value: gpu::UnaryOperator,
1682        out: gpu::Variable,
1683    ) -> UnaryInstruction<D> {
1684        UnaryInstruction {
1685            input: self.compile_variable(value.input),
1686            out: self.compile_variable(out),
1687        }
1688    }
1689
    /// Lowers a GPU IR variable into its C++ representation.
    ///
    /// Lowering has side effects on the compiler state: referenced types are
    /// registered via `compile_type`, feature flags (TMA, WMMA, pipelines,
    /// barriers) are switched on, each builtin toggles its index flag so the
    /// kernel prologue computes it, and local arrays / pipelines are recorded
    /// on first use.
    fn compile_variable(&mut self, value: gpu::Variable) -> Variable<D> {
        let item = value.ty;
        match value.kind {
            gpu::VariableKind::GlobalInputArray(id) => {
                Variable::GlobalInputArray(id, self.compile_type(item))
            }
            gpu::VariableKind::GlobalScalar(id) => Variable::GlobalScalar {
                id,
                elem: self.compile_storage_type(item.storage_type()),
            },
            gpu::VariableKind::TensorMapInput(id) => {
                self.flags.inst_tma = true;
                Variable::TensorMap(id)
            }
            gpu::VariableKind::TensorMapOutput(id) => {
                self.flags.inst_tma = true;
                Variable::TensorMap(id)
            }
            gpu::VariableKind::LocalMut { id } => Variable::LocalMut {
                id,
                item: self.compile_type(item),
            },
            // Versioned (SSA) variables are lowered to plain mutable locals.
            gpu::VariableKind::Versioned { id, .. } => Variable::LocalMut {
                id,
                item: self.compile_type(item),
            },
            gpu::VariableKind::LocalConst { id } => Variable::LocalConst {
                id,
                item: self.compile_type(item),
            },
            gpu::VariableKind::GlobalOutputArray(id) => {
                Variable::GlobalOutputArray(id, self.compile_type(item))
            }
            gpu::VariableKind::Constant(value) => {
                Variable::Constant(value, self.compile_type(item))
            }
            gpu::VariableKind::SharedArray { id, length, .. } => {
                let item = self.compile_type(item);
                Variable::SharedArray(id, item, length)
            }
            gpu::VariableKind::Shared { id } => {
                let item = self.compile_type(item);
                Variable::Shared(id, item)
            }
            // Unrolled constant arrays store `unroll_factor` copies back to back.
            gpu::VariableKind::ConstantArray {
                id,
                length,
                unroll_factor,
            } => {
                let item = self.compile_type(item);
                Variable::ConstantArray(id, item, length * unroll_factor)
            }
            gpu::VariableKind::Builtin(builtin) => match builtin {
                gpu::Builtin::AbsolutePos => {
                    self.flags.indexes.absolute_pos = true;
                    let item = self.compile_type(item);
                    Variable::AbsolutePos(item.elem)
                }
                gpu::Builtin::CubePosCluster
                    if self.compilation_options.supports_features.clusters =>
                {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterRank
                }
                gpu::Builtin::CubePosClusterX
                    if self.compilation_options.supports_features.clusters =>
                {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexX
                }
                gpu::Builtin::CubePosClusterY
                    if self.compilation_options.supports_features.clusters =>
                {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexY
                }
                gpu::Builtin::CubePosClusterZ
                    if self.compilation_options.supports_features.clusters =>
                {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexZ
                }
                // Fallback if clusters aren't supported, ID is always 0 since clusters are always
                // (1, 1, 1) if unsupported
                gpu::Builtin::CubePosCluster
                | gpu::Builtin::CubePosClusterX
                | gpu::Builtin::CubePosClusterY
                | gpu::Builtin::CubePosClusterZ => const_u32(0),
                gpu::Builtin::AbsolutePosX => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosX
                }
                gpu::Builtin::AbsolutePosY => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosY
                }
                gpu::Builtin::AbsolutePosZ => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosZ
                }
                gpu::Builtin::CubeDim => {
                    self.flags.indexes.cube_dim = true;
                    Variable::CubeDim
                }
                gpu::Builtin::CubeDimX => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimX
                }
                gpu::Builtin::CubeDimY => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimY
                }
                gpu::Builtin::CubeDimZ => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimZ
                }
                // Cluster dimensions are compile-time constants taken from the
                // configured cluster dim.
                gpu::Builtin::CubeClusterDim => const_u32(self.cluster_dim.num_elems()),
                gpu::Builtin::CubeClusterDimX => const_u32(self.cluster_dim.x),
                gpu::Builtin::CubeClusterDimY => const_u32(self.cluster_dim.y),
                gpu::Builtin::CubeClusterDimZ => const_u32(self.cluster_dim.z),
                gpu::Builtin::CubePos => {
                    self.flags.indexes.cube_pos = true;
                    let item = self.compile_type(item);
                    Variable::CubePos(item.elem)
                }
                gpu::Builtin::CubePosX => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosX
                }
                gpu::Builtin::CubePosY => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosY
                }
                gpu::Builtin::CubePosZ => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosZ
                }
                gpu::Builtin::CubeCount => {
                    self.flags.indexes.cube_count = true;
                    let item = self.compile_type(item);
                    Variable::CubeCount(item.elem)
                }
                gpu::Builtin::CubeCountX => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountX
                }
                gpu::Builtin::CubeCountY => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountY
                }
                gpu::Builtin::CubeCountZ => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountZ
                }
                gpu::Builtin::UnitPos => {
                    self.flags.indexes.unit_pos = true;
                    Variable::UnitPos
                }
                gpu::Builtin::UnitPosX => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosX
                }
                gpu::Builtin::UnitPosY => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosY
                }
                gpu::Builtin::UnitPosZ => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosZ
                }
                gpu::Builtin::PlaneDim => {
                    self.flags.indexes.plane_dim = true;
                    Variable::PlaneDim
                }
                gpu::Builtin::PlanePos => {
                    self.flags.indexes.plane_pos = true;
                    Variable::PlanePos
                }
                gpu::Builtin::UnitPosPlane => {
                    self.flags.indexes.unit_pos_plane = true;
                    Variable::UnitPosPlane
                }
            },
            // Local arrays are declared once; later references reuse the
            // recorded declaration.
            gpu::VariableKind::LocalArray {
                id,
                length,
                unroll_factor,
            } => {
                let item = self.compile_type(item);
                if !self.local_arrays.iter().any(|s| s.index == id) {
                    self.local_arrays
                        .push(LocalArray::new(id, item, length * unroll_factor));
                }
                Variable::LocalArray(id, item, length)
            }
            gpu::VariableKind::Matrix { id, mat } => {
                self.flags.inst_wmma = true;
                Variable::WmmaFragment {
                    id,
                    frag: self.compile_matrix(mat),
                }
            }
            // Pipelines get a single Init op the first time they are seen.
            gpu::VariableKind::Pipeline { id, num_stages } => {
                self.flags.op_pipeline = true;
                let pipeline = Variable::Pipeline { id };
                if !self.pipelines.iter().any(|s| s.pipeline_id() == id) {
                    self.pipelines.push(PipelineOps::Init {
                        pipeline,
                        num_stages,
                    });
                }
                pipeline
            }
            gpu::VariableKind::BarrierToken { id, level } => {
                self.flags.op_barrier = true;
                Variable::BarrierToken { id, level }
            }
        }
    }
1909
1910    fn compile_matrix(&mut self, matrix: gpu::Matrix) -> Fragment<D> {
1911        Fragment {
1912            ident: self.compile_matrix_ident(matrix.ident),
1913            m: matrix.m as u32,
1914            n: matrix.n as u32,
1915            k: matrix.k as u32,
1916            elem: self.compile_storage_type(matrix.storage),
1917            layout: self.compile_matrix_layout(matrix.layout),
1918        }
1919    }
1920
1921    fn compile_matrix_ident(&mut self, ident: gpu::MatrixIdent) -> FragmentIdent<D> {
1922        match ident {
1923            gpu::MatrixIdent::A => FragmentIdent::A,
1924            gpu::MatrixIdent::B => FragmentIdent::B,
1925            gpu::MatrixIdent::Accumulator => FragmentIdent::Accumulator,
1926        }
1927    }
1928
1929    fn compile_matrix_layout(&mut self, layout: gpu::MatrixLayout) -> Option<FragmentLayout<D>> {
1930        match layout {
1931            gpu::MatrixLayout::ColMajor => Some(FragmentLayout::ColMajor),
1932            gpu::MatrixLayout::RowMajor => Some(FragmentLayout::RowMajor),
1933            gpu::MatrixLayout::Undefined => None,
1934        }
1935    }
1936
1937    fn compile_binding(&mut self, binding: cubecl_runtime::kernel::KernelArg) -> KernelArg<D> {
1938        KernelArg {
1939            id: binding.id,
1940            item: self.compile_type(binding.ty),
1941            size: binding.size,
1942            vis: binding.visibility,
1943        }
1944    }
1945
1946    fn compile_type(&mut self, ty: gpu::Type) -> Item<D> {
1947        let item = match ty {
1948            gpu::Type::Scalar(ty) => Item::new(self.compile_storage_type(ty), 1, false),
1949            gpu::Type::Vector(ty, vector_size) => {
1950                Item::new(self.compile_storage_type(ty), vector_size, false)
1951            }
1952            gpu::Type::Semantic(_) => Item::new(Elem::Bool, 1, true),
1953        };
1954        if item.elem != super::Elem::TF32 {
1955            self.items.insert(item);
1956            self.items.insert(item.optimized());
1957        } else {
1958            // TF32 is represented as `float` in C++
1959            let mut item = item;
1960            item.elem = super::Elem::F32;
1961            self.items.insert(item);
1962        }
1963
1964        item
1965    }
1966
1967    fn compile_storage_type(&mut self, value: gpu::StorageType) -> Elem<D> {
1968        match value {
1969            gpu::StorageType::Scalar(ty) => self.compile_elem(ty),
1970            gpu::StorageType::Atomic(ty) => Elem::Atomic(ty.into()),
1971            gpu::StorageType::Packed(gpu::ElemType::Float(kind), 2) => match kind {
1972                FloatKind::E2M1 => {
1973                    self.flags.elem_fp4 = true;
1974                    Elem::FP4x2(FP4Kind::E2M1)
1975                }
1976                FloatKind::E2M3 => {
1977                    self.flags.elem_fp6 = true;
1978                    Elem::FP6x2(FP6Kind::E2M3)
1979                }
1980                FloatKind::E3M2 => {
1981                    self.flags.elem_fp6 = true;
1982                    Elem::FP6(FP6Kind::E3M2)
1983                }
1984                FloatKind::E4M3 => {
1985                    self.flags.elem_fp8 = true;
1986                    Elem::FP8x2(FP8Kind::E4M3)
1987                }
1988                FloatKind::E5M2 => {
1989                    self.flags.elem_fp8 = true;
1990                    Elem::FP8x2(FP8Kind::E5M2)
1991                }
1992                FloatKind::UE8M0 => {
1993                    self.flags.elem_fp8 = true;
1994                    Elem::FP8x2(FP8Kind::UE8M0)
1995                }
1996                FloatKind::F16 => {
1997                    self.flags.elem_f16 = true;
1998                    Elem::F16x2
1999                }
2000                FloatKind::BF16 => {
2001                    self.flags.elem_bf16 = true;
2002                    Elem::BF16x2
2003                }
2004                other => unimplemented!("Unsupported storage type: packed<{other:?}, 2>"),
2005            },
2006            gpu::StorageType::Packed(other, factor) => {
2007                unimplemented!("Unsupported storage type: packed<{other}, {factor}>")
2008            }
2009            gpu::StorageType::Opaque(ty) => match ty {
2010                gpu::OpaqueType::Barrier(level) => {
2011                    self.flags.op_barrier = true;
2012                    Elem::Barrier(level)
2013                }
2014            },
2015        }
2016    }
2017
2018    fn compile_elem(&mut self, value: gpu::ElemType) -> Elem<D> {
2019        match value {
2020            gpu::ElemType::Float(kind) => match kind {
2021                gpu::FloatKind::E2M1 => {
2022                    self.flags.elem_fp4 = true;
2023                    Elem::FP4(FP4Kind::E2M1)
2024                }
2025                gpu::FloatKind::E2M3 => {
2026                    self.flags.elem_fp6 = true;
2027                    Elem::FP6(FP6Kind::E2M3)
2028                }
2029                gpu::FloatKind::E3M2 => {
2030                    self.flags.elem_fp6 = true;
2031                    Elem::FP6(FP6Kind::E3M2)
2032                }
2033                gpu::FloatKind::E4M3 => {
2034                    self.flags.elem_fp8 = true;
2035                    Elem::FP8(FP8Kind::E4M3)
2036                }
2037                gpu::FloatKind::E5M2 => {
2038                    self.flags.elem_fp8 = true;
2039                    Elem::FP8(FP8Kind::E5M2)
2040                }
2041                gpu::FloatKind::UE8M0 => {
2042                    self.flags.elem_fp8 = true;
2043                    Elem::FP8(FP8Kind::UE8M0)
2044                }
2045                gpu::FloatKind::F16 => {
2046                    self.flags.elem_f16 = true;
2047                    Elem::F16
2048                }
2049                gpu::FloatKind::BF16 => {
2050                    self.flags.elem_bf16 = true;
2051                    Elem::BF16
2052                }
2053                gpu::FloatKind::TF32 => Elem::TF32,
2054                gpu::FloatKind::Flex32 => Elem::F32,
2055                gpu::FloatKind::F32 => Elem::F32,
2056                gpu::FloatKind::F64 => Elem::F64,
2057            },
2058            gpu::ElemType::Int(kind) => match kind {
2059                gpu::IntKind::I8 => Elem::I8,
2060                gpu::IntKind::I16 => Elem::I16,
2061                gpu::IntKind::I32 => Elem::I32,
2062                gpu::IntKind::I64 => Elem::I64,
2063            },
2064            gpu::ElemType::UInt(kind) => match kind {
2065                gpu::UIntKind::U8 => Elem::U8,
2066                gpu::UIntKind::U16 => Elem::U16,
2067                gpu::UIntKind::U32 => Elem::U32,
2068                gpu::UIntKind::U64 => Elem::U64,
2069            },
2070            gpu::ElemType::Bool => Elem::Bool,
2071        }
2072    }
2073}
2074
2075fn is_fp4_fp6_fp8(elem: gpu::ElemType) -> bool {
2076    match elem {
2077        gpu::ElemType::Float(kind) => matches!(
2078            kind,
2079            FloatKind::E2M1
2080                | FloatKind::E2M3
2081                | FloatKind::E3M2
2082                | FloatKind::E4M3
2083                | FloatKind::E5M2
2084                | FloatKind::UE8M0
2085        ),
2086        _ => false,
2087    }
2088}
2089
2090fn const_u32<D: Dialect>(value: u32) -> Variable<D> {
2091    Variable::Constant(
2092        gpu::ConstantValue::UInt(value as u64),
2093        Item::new(Elem::U32, 1, true),
2094    )
2095}
2096
2097pub fn register_supported_types(props: &mut DeviceProperties) {
2098    props.register_address_type(gpu::AddressType::U32);
2099    props.register_address_type(gpu::AddressType::U64);
2100
2101    let supported_types = [
2102        gpu::ElemType::UInt(gpu::UIntKind::U8),
2103        gpu::ElemType::UInt(gpu::UIntKind::U16),
2104        gpu::ElemType::UInt(gpu::UIntKind::U32),
2105        gpu::ElemType::UInt(gpu::UIntKind::U64),
2106        gpu::ElemType::Int(gpu::IntKind::I8),
2107        gpu::ElemType::Int(gpu::IntKind::I16),
2108        gpu::ElemType::Int(gpu::IntKind::I32),
2109        gpu::ElemType::Int(gpu::IntKind::I64),
2110        gpu::ElemType::Float(gpu::FloatKind::BF16),
2111        gpu::ElemType::Float(gpu::FloatKind::F16),
2112        gpu::ElemType::Float(gpu::FloatKind::F32),
2113        gpu::ElemType::Float(gpu::FloatKind::Flex32),
2114        // Causes CUDA_ERROR_INVALID_VALUE for matmul, disabling until that can be investigated
2115        //gpu::Elem::Float(gpu::FloatKind::F64),
2116        gpu::ElemType::Bool,
2117    ];
2118
2119    let supported_atomic_types = [
2120        gpu::ElemType::Int(gpu::IntKind::I32),
2121        gpu::ElemType::Int(gpu::IntKind::I64),
2122        gpu::ElemType::UInt(gpu::UIntKind::U32),
2123        gpu::ElemType::UInt(gpu::UIntKind::U64),
2124        gpu::ElemType::Float(gpu::FloatKind::F32),
2125    ];
2126
2127    for ty in supported_types {
2128        props.register_type_usage(ty, TypeUsage::all());
2129    }
2130
2131    for ty in supported_atomic_types {
2132        props.register_atomic_type_usage(
2133            Type::new(gpu::StorageType::Atomic(ty)),
2134            AtomicUsage::Add | AtomicUsage::LoadStore,
2135        );
2136    }
2137}