cubecl_cpp/shared/
base.rs

1use std::{collections::HashSet, fmt::Debug};
2
3use cubecl_common::ExecutionMode;
4use cubecl_core::ir::{FloatKind, InstructionModes, Processor, UIntKind, VariableKind};
5use cubecl_core::post_processing::checked_io::CheckedIoProcessor;
6use cubecl_core::{
7    Compiler,
8    ir::{self as gpu},
9};
10use cubecl_core::{CubeDim, ir::ElemType};
11use cubecl_core::{
12    ir::{Operation, SourceLoc},
13    prelude::{FastMath, KernelDefinition},
14};
15use cubecl_opt::{Optimizer, SharedLiveness};
16use cubecl_runtime::{DeviceProperties, EnumSet, TypeUsage};
17
18use crate::shared::MmaShape;
19
20use super::{
21    BinaryInstruction, Binding, Body, Component, ComputeKernel, ConstArray, Dialect, Elem, FP6Kind,
22    Fragment, FragmentIdent, FragmentLayout, IndexAssignInstruction, IndexInstruction, Instruction,
23    Item, LocalArray, SharedMemory, UnaryInstruction, Variable, WarpInstruction, WmmaInstruction,
24};
25use super::{FP4Kind, barrier::BarrierOps};
26use super::{FP8Kind, pipeline::PipelineOps};
27
/// Process-wide counter used to generate unique names for compiler-created
/// temporary variables. It is reset to 0 at the end of every `compile` call.
/// NOTE(review): as a global static, the counter is shared by all compilations
/// in the process — presumably kernels are compiled one at a time; confirm.
pub(super) static COUNTER_TMP_VAR: std::sync::atomic::AtomicU32 =
    std::sync::atomic::AtomicU32::new(0);
30
/// Target-specific options supplied by the runtime when compiling a kernel.
#[derive(Clone, Debug)]
pub struct CompilationOptions {
    /// Number of threads per warp/wavefront on the target (defaults to 32).
    pub warp_size: u32,
    /// Optional C++/hardware features available on the target.
    pub supports_features: CppSupportedFeatures,
}
36
/// Optional C++/hardware features the compilation target supports.
/// Every feature defaults to `false` (unsupported).
#[derive(Clone, Debug, Default)]
pub struct CppSupportedFeatures {
    /// Kernel metadata can be passed as grid constants
    /// (drives `Flags::use_grid_constants`).
    pub grid_constants: bool,
    /// Thread-block clusters; when `false`, any requested cluster dimension
    /// is dropped during compilation.
    pub clusters: bool,
    /// Fast-math lowering is available.
    /// NOTE(review): not referenced in this part of the file — presumably
    /// consumed when compiling arithmetic; confirm.
    pub fast_math: bool,
    /// A fast tanh approximation is available — TODO confirm where consumed.
    pub fast_tanh: bool,
    /// Hardware warp election (`elect`-style) is available; otherwise a
    /// software fallback (`WarpInstruction::ElectFallback`) is emitted.
    pub elect_sync: bool,
}
45
46impl Default for CompilationOptions {
47    fn default() -> Self {
48        Self {
49            warp_size: 32,
50            supports_features: Default::default(),
51        }
52    }
53}
54
/// Cube indexes flags.
/// When true the corresponding index is declared and computed as needed in the kernel.
#[derive(Debug, Clone, Default)]
pub struct CubeIndexFlags {
    pub absolute_pos: bool,
    pub absolute_pos_tuple: bool,
    pub cube_count: bool,
    pub cube_count_tuple: bool,
    pub cube_dim: bool,
    pub cube_dim_tuple: bool,
    pub cube_pos: bool,
    pub cube_pos_tuple: bool,
    pub plane_dim: bool,
    /// Requested whenever any plane (warp) operation is compiled.
    pub plane_dim_checked: bool,
    pub plane_index: bool,
    pub unit_pos: bool,
    pub unit_pos_tuple: bool,
    /// Requested by plane scan operations (inclusive/exclusive sum and prod).
    pub unit_pos_plane: bool,
    pub cluster_pos: bool,
}
75
/// Flags gathered during Cube IR translation for the kernel compilation.
#[derive(Debug, Clone, Default)]
pub struct Flags {
    // Element types encountered in the kernel (drive header/type emission).
    pub elem_fp4: bool,
    pub elem_fp6: bool,
    pub elem_fp8: bool,
    pub elem_bf16: bool,
    pub elem_f16: bool,
    pub elem_tf32: bool,
    /// Builtin indexes required by the kernel, after dialect rules are applied.
    pub indexes: CubeIndexFlags,
    /// Kernel uses barrier operations.
    pub op_barrier: bool,
    /// Kernel uses pipeline operations.
    pub op_pipeline: bool,
    /// Kernel uses TMA (tensor memory accelerator) instructions.
    pub inst_tma: bool,
    /// Kernel uses the TMA im2col load variant.
    pub inst_tma_im2col: bool,
    /// Kernel uses WMMA/CMMA fragment instructions.
    pub inst_wmma: bool,
    /// Kernel needs the inline-PTX wrapper helpers (e.g. for warp elect).
    pub inst_ptx_wrappers: bool,
    /// Pass kernel metadata as grid constants (target-dependent).
    pub use_grid_constants: bool,
    /// Number of statically known metadata entries.
    pub static_meta_length: usize,
    /// Whether dynamic metadata must be passed to the kernel.
    pub has_dynamic_meta: bool,
    pub cube_dim: CubeDim,
    /// Cluster dimensions, when clusters are requested and supported.
    pub cluster_dim: Option<CubeDim>,
}
98
99#[allow(clippy::too_many_arguments)]
100#[derive(Clone, Debug, Default)]
101pub struct CppCompiler<D: Dialect> {
102    barriers: Vec<BarrierOps<D>>,
103    compilation_options: CompilationOptions,
104    const_arrays: Vec<ConstArray<D>>,
105    ext_meta_positions: Vec<u32>,
106    cluster_dim: CubeDim,
107    extensions: Vec<D::Extension>,
108    flags: Flags,
109    items: HashSet<Item<D>>,
110    local_arrays: Vec<LocalArray<D>>,
111    metadata: cubecl_core::Metadata,
112    pipelines: Vec<PipelineOps<D>>,
113    source_loc: Option<SourceLoc>,
114    strategy: ExecutionMode,
115}
116
impl<D: Dialect> Compiler for CppCompiler<D> {
    type Representation = ComputeKernel<D>;
    type CompilationOptions = CompilationOptions;

    /// Compiles `kernel` into a C++ compute kernel representation.
    ///
    /// Cluster dimensions are discarded when the target lacks cluster
    /// support. `compile_ir` consumes a *clone* of `self`, so the per-kernel
    /// state it accumulates does not leak into later `compile` calls.
    fn compile(
        &mut self,
        mut kernel: KernelDefinition,
        compilation_options: &Self::CompilationOptions,
        strategy: ExecutionMode,
    ) -> Self::Representation {
        self.compilation_options = compilation_options.clone();
        self.strategy = strategy;

        if !self.compilation_options.supports_features.clusters {
            kernel.options.cluster_dim = None;
        }
        self.cluster_dim = kernel.options.cluster_dim.unwrap_or(CubeDim::new_single());

        let ir = self.clone().compile_ir(kernel);
        // Reset the global temp-variable counter for the next compilation.
        // NOTE(review): this assumes compilations don't run concurrently in
        // the same process — confirm.
        COUNTER_TMP_VAR.store(0, std::sync::atomic::Ordering::Relaxed);
        ir
    }

    /// Size in bytes of the element type on the target.
    fn elem_size(&self, elem: gpu::ElemType) -> usize {
        elem.size()
    }

    /// File extension used for the generated source.
    fn extension(&self) -> &'static str {
        "cpp"
    }
}
148
149impl<D: Dialect> CppCompiler<D> {
    /// Consumes the compiler and lowers a full kernel definition into a
    /// `ComputeKernel`: bindings, scalars, body, shared memories and the
    /// feature flags collected during translation.
    fn compile_ir(mut self, value: KernelDefinition) -> ComputeKernel<D> {
        self.build_metadata(&value);

        // The body is cloned here because `Optimizer::shared_only` below
        // consumes the original for shared-memory liveness analysis.
        let instructions = self.compile_scope(&mut value.body.clone());
        let tensor_maps = value
            .tensor_maps
            .into_iter()
            .map(|b| self.compile_binding(b))
            .collect();
        let buffers = value
            .buffers
            .into_iter()
            .map(|b| self.compile_binding(b))
            .collect();
        let scalars = value
            .scalars
            .into_iter()
            .map(|binding| (self.compile_storage_type(binding.ty), binding.count))
            .collect();

        // translation flags
        // Snapshot the flags accumulated while compiling the body; builtin
        // index requirements are post-processed by the dialect's rules.
        let flags = Flags {
            indexes: D::builtin_rules(&self.flags.indexes),
            inst_wmma: self.flags.inst_wmma,
            op_pipeline: self.flags.op_pipeline,
            op_barrier: self.flags.op_barrier,
            elem_fp4: self.flags.elem_fp4,
            elem_fp6: self.flags.elem_fp6,
            elem_fp8: self.flags.elem_fp8,
            elem_bf16: self.flags.elem_bf16,
            elem_f16: self.flags.elem_f16,
            elem_tf32: self.flags.elem_tf32,
            inst_tma: self.flags.inst_tma,
            inst_tma_im2col: self.flags.inst_tma_im2col,
            inst_ptx_wrappers: self.flags.inst_ptx_wrappers,
            use_grid_constants: self.compilation_options.supports_features.grid_constants,
            // TODO: At some point we should only pass dynamic meta if tensors are present,
            // not if only arrays are present. For now, leave like this
            has_dynamic_meta: self.metadata.static_len() > 0,
            static_meta_length: self.metadata.static_len() as usize,
            cube_dim: value.cube_dim,
            cluster_dim: value.options.cluster_dim,
        };

        // Derive the offset/size of every shared-memory allocation from the
        // shared-liveness analysis of the (unmodified) body.
        let mut opt = Optimizer::shared_only(value.body, value.cube_dim);
        let shared_allocs = opt.analysis::<SharedLiveness>();
        let shared_memories = shared_allocs
            .allocations
            .values()
            .map(|alloc| SharedMemory {
                index: alloc.smem.id,
                item: self.compile_type(alloc.smem.ty),
                length: alloc.smem.length,
                align: alloc.smem.align,
                offset: alloc.offset,
            })
            .collect();

        let body = Body {
            instructions,
            shared_memories,
            pipelines: self.pipelines,
            barriers: self.barriers,
            const_arrays: self.const_arrays,
            local_arrays: self.local_arrays,
        };

        // Drop cluster dims on targets without cluster support (mirrors the
        // gating done in `compile`, in case `compile_ir` is reached directly).
        let mut cluster_dim = value.options.cluster_dim;
        if !self.compilation_options.supports_features.clusters {
            cluster_dim = None;
        }

        ComputeKernel {
            tensor_maps,
            buffers,
            scalars,
            meta_static_len: self.metadata.static_len() as usize,
            cube_dim: value.cube_dim,
            body,
            extensions: self.extensions,
            flags,
            items: self.items,
            kernel_name: value.options.kernel_name,
            cluster_dim,
        }
    }
236
237    fn build_metadata(&mut self, value: &KernelDefinition) {
238        let mut num_ext = 0;
239
240        let mut all_meta: Vec<_> = value
241            .buffers
242            .iter()
243            .chain(value.tensor_maps.iter())
244            .map(|buf| (buf.id, buf.has_extended_meta))
245            .collect();
246
247        all_meta.sort_by_key(|(id, _)| *id);
248
249        for (_, has_extended_meta) in &all_meta {
250            self.ext_meta_positions.push(num_ext);
251            if *has_extended_meta {
252                num_ext += 1;
253            }
254        }
255
256        let num_meta = all_meta.len();
257
258        self.metadata = cubecl_core::Metadata::new(num_meta as u32, num_ext);
259    }
260
261    pub(crate) fn ext_meta_position(&self, var: gpu::Variable) -> u32 {
262        let id = var.index().expect("Variable should have index");
263        self.ext_meta_positions[id as usize]
264    }
265
266    fn compile_scope(&mut self, scope: &mut gpu::Scope) -> Vec<Instruction<D>> {
267        let mut instructions = Vec::new();
268
269        let const_arrays = scope
270            .const_arrays
271            .drain(..)
272            .map(|(var, values)| ConstArray {
273                index: var.index().unwrap(),
274                item: self.compile_type(var.ty),
275                size: values.len() as u32,
276                values: values
277                    .into_iter()
278                    .map(|val| self.compile_variable(val))
279                    .collect(),
280            })
281            .collect::<Vec<_>>();
282        self.const_arrays.extend(const_arrays);
283
284        let checked_io: Box<dyn Processor> = Box::new(CheckedIoProcessor::new(self.strategy));
285        let dialect_processors = D::processors();
286        let mut processors: Vec<&dyn Processor> = vec![&*checked_io];
287        processors.extend(dialect_processors.iter().map(|it| &**it));
288
289        let processing = scope.process(processors);
290
291        for var in processing.variables {
292            instructions.push(Instruction::DeclareVariable {
293                var: self.compile_variable(var),
294            });
295        }
296
297        processing
298            .instructions
299            .into_iter()
300            .for_each(|op| self.compile_instruction(&mut instructions, op));
301
302        instructions
303    }
304
    /// Lowers a single IR instruction, appending the resulting C++
    /// instruction(s) to `instructions`.
    ///
    /// As a side effect, records feature flags (TMA, WMMA, plane indexes,
    /// PTX wrappers, …) when the corresponding operations are encountered,
    /// and registers dialect extensions for warp reductions.
    fn compile_instruction(
        &mut self,
        instructions: &mut Vec<Instruction<D>>,
        instruction: gpu::Instruction,
    ) {
        // Emit a debug line marker if the source location changed.
        self.update_debug_loc(instructions, &instruction);
        let out = instruction.out;
        match instruction.operation {
            gpu::Operation::Copy(variable) => {
                instructions.push(Instruction::Assign(UnaryInstruction {
                    input: self.compile_variable(variable),
                    out: self.compile_variable(out.unwrap()),
                }));
            }
            gpu::Operation::Arithmetic(op) => {
                self.compile_arithmetic(op, out, instruction.modes, instructions)
            }
            gpu::Operation::Comparison(op) => self.compile_comparison(op, out, instructions),
            gpu::Operation::Bitwise(op) => self.compile_bitwise(op, out, instructions),
            gpu::Operation::Operator(op) => self.compile_operator(op, out, instructions),
            gpu::Operation::Atomic(op) => self.compile_atomic(op, out, instructions),
            gpu::Operation::Metadata(op) => instructions.push(self.compile_metadata(op, out)),
            gpu::Operation::Branch(val) => self.compile_branch(instructions, val),
            gpu::Operation::Synchronization(val) => match val {
                gpu::Synchronization::SyncCube => instructions.push(Instruction::SyncThreads),
                gpu::Synchronization::SyncPlane => instructions.push(Instruction::SyncWarp),
                // Storage-level sync maps onto a full thread sync here.
                gpu::Synchronization::SyncStorage => instructions.push(Instruction::SyncThreads),
                gpu::Synchronization::SyncAsyncProxyShared => {
                    self.flags.inst_tma = true;
                    instructions.push(Instruction::ProxyAsyncToSharedFence)
                }
            },
            gpu::Operation::Plane(op) => {
                // All plane (warp) ops need the checked plane-dim builtin.
                self.flags.indexes.plane_dim_checked = true;
                let out = self.compile_variable(out.unwrap());
                match op {
                    gpu::Plane::Sum(op) => {
                        let instruction = WarpInstruction::ReduceSum {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        // Reductions may need dialect-specific helpers.
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction));
                    }
                    // Scan operations additionally require the unit's
                    // position within the plane.
                    gpu::Plane::InclusiveSum(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::InclusiveSum {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::InclusiveProd(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::InclusiveProd {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::ExclusiveSum(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::ExclusiveSum {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::ExclusiveProd(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::ExclusiveProd {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Prod(op) => {
                        let instruction = WarpInstruction::ReduceProd {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Max(op) => {
                        let instruction = WarpInstruction::ReduceMax {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Min(op) => {
                        let instruction = WarpInstruction::ReduceMin {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Elect => {
                        // Use hardware election when available (needs the PTX
                        // wrapper helpers), otherwise a software fallback.
                        if self.compilation_options.supports_features.elect_sync {
                            self.flags.inst_ptx_wrappers = true;
                            instructions.push(Instruction::Warp(WarpInstruction::Elect { out }))
                        } else {
                            instructions
                                .push(Instruction::Warp(WarpInstruction::ElectFallback { out }))
                        }
                    }
                    gpu::Plane::All(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::All {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Any(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Any {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Ballot(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Ballot {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    // Binary plane ops carry the value in `lhs` and the
                    // lane/mask/delta operand in `rhs`.
                    gpu::Plane::Broadcast(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Broadcast {
                            input: self.compile_variable(op.lhs),
                            id: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::Shuffle(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Shuffle {
                            input: self.compile_variable(op.lhs),
                            src_lane: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::ShuffleXor(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleXor {
                            input: self.compile_variable(op.lhs),
                            mask: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::ShuffleUp(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleUp {
                            input: self.compile_variable(op.lhs),
                            delta: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::ShuffleDown(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleDown {
                            input: self.compile_variable(op.lhs),
                            delta: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                }
            }
            gpu::Operation::CoopMma(cmma) => instructions.push(self.compile_cmma(cmma, out)),
            gpu::Operation::NonSemantic(debug) => match debug {
                gpu::NonSemantic::Print {
                    format_string,
                    args,
                } => instructions.push(Instruction::Printf {
                    format_string,
                    args: args
                        .into_iter()
                        .map(|arg| self.compile_variable(arg))
                        .collect(),
                }),
                gpu::NonSemantic::Comment { content } => {
                    instructions.push(Instruction::Comment { content })
                }
                // Don't need to handle scopes
                _ => {}
            },
            gpu::Operation::Barrier(barrier_ops) => match barrier_ops {
                gpu::BarrierOps::Declare { barrier } => {
                    // The barrier level is encoded in the variable kind.
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    let barrier = self.compile_variable(barrier);
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Declare {
                        barrier,
                        level,
                    }));
                }
                gpu::BarrierOps::Init {
                    barrier,
                    is_elected,
                    arrival_count,
                    with_async_proxy_fence,
                } => {
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    let barrier = self.compile_variable(barrier);
                    let arrival_count = self.compile_variable(arrival_count);
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Init {
                        barrier,
                        is_elected: self.compile_variable(is_elected),
                        arrival_count,
                        level,
                        with_async_proxy_fence,
                    }));
                }
                gpu::BarrierOps::InitManual {
                    barrier,
                    arrival_count,
                } => {
                    let barrier = self.compile_variable(barrier);
                    let arrival_count = self.compile_variable(arrival_count);
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::InitManual {
                            barrier,
                            arrival_count,
                        },
                    ));
                }
                // The async copy destination is the instruction's `out`.
                gpu::BarrierOps::MemCopyAsync {
                    barrier,
                    source,
                    source_length,
                    offset_source,
                    offset_out,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsync {
                            barrier: self.compile_variable(barrier),
                            source: self.compile_variable(source),
                            destination: self.compile_variable(out.unwrap()),
                            source_length: self.compile_variable(source_length),
                            offset_source: self.compile_variable(offset_source),
                            offset_out: self.compile_variable(offset_out),
                            cooperative: false,
                        },
                    ));
                }
                // Same lowering as MemCopyAsync but flagged cooperative.
                gpu::BarrierOps::MemCopyAsyncCooperative {
                    barrier,
                    source,
                    source_length,
                    offset_source,
                    offset_out,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsync {
                            barrier: self.compile_variable(barrier),
                            source: self.compile_variable(source),
                            destination: self.compile_variable(out.unwrap()),
                            source_length: self.compile_variable(source_length),
                            offset_source: self.compile_variable(offset_source),
                            offset_out: self.compile_variable(offset_out),
                            cooperative: true,
                        },
                    ));
                }
                gpu::BarrierOps::MemCopyAsyncTx {
                    barrier,
                    source,
                    source_length,
                    offset_source,
                    offset_out,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsyncTx {
                            barrier: self.compile_variable(barrier),
                            source: self.compile_variable(source),
                            destination: self.compile_variable(out.unwrap()),
                            source_length: self.compile_variable(source_length),
                            offset_source: self.compile_variable(offset_source),
                            offset_out: self.compile_variable(offset_out),
                        },
                    ));
                }
                // TMA tensor load into shared memory; `out` is the smem buffer.
                gpu::BarrierOps::TmaLoad {
                    barrier,
                    tensor_map,
                    offset_out,
                    indices,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsyncTensorGlobalToShared {
                            barrier: self.compile_variable(barrier),
                            smem_buffer: self.compile_variable(out.unwrap()),
                            smem_offset: self.compile_variable(offset_out),
                            tensor_map: self.compile_variable(tensor_map),
                            indices: indices
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        },
                    ));
                }
                gpu::BarrierOps::TmaLoadIm2col {
                    barrier,
                    tensor_map,
                    offset_out,
                    indices,
                    offsets,
                } => {
                    self.flags.inst_tma_im2col = true;
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::TmaLoadIm2col {
                            barrier: self.compile_variable(barrier),
                            smem_buffer: self.compile_variable(out.unwrap()),
                            smem_offset: self.compile_variable(offset_out),
                            tensor_map: self.compile_variable(tensor_map),
                            indices: indices
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                            offsets: offsets
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        },
                    ));
                }
                gpu::BarrierOps::Arrive { barrier } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Arrive {
                        barrier: self.compile_variable(barrier),
                        token: self.compile_variable(out.unwrap()),
                    }))
                }
                gpu::BarrierOps::ArriveTx {
                    barrier,
                    arrive_count_update,
                    transaction_count_update,
                } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::ArriveTx {
                        barrier: self.compile_variable(barrier),
                        token: self.compile_variable(out.unwrap()),
                        arrive_count_update: self.compile_variable(arrive_count_update),
                        transaction_count_update: self.compile_variable(transaction_count_update),
                    }))
                }
                gpu::BarrierOps::ExpectTx {
                    barrier,
                    transaction_count_update,
                } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::ExpectTx {
                        barrier: self.compile_variable(barrier),
                        transaction_count_update: self.compile_variable(transaction_count_update),
                    }))
                }
                gpu::BarrierOps::Wait { barrier, token } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Wait {
                        barrier: self.compile_variable(barrier),
                        token: self.compile_variable(token),
                    }))
                }
                gpu::BarrierOps::WaitParity { barrier, phase } => instructions.push(
                    Instruction::Barrier(super::barrier::BarrierOps::WaitParity {
                        barrier: self.compile_variable(barrier),
                        phase: self.compile_variable(phase),
                    }),
                ),
                gpu::BarrierOps::ArriveAndWait { barrier } => {
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::ArriveAndWait {
                            barrier: self.compile_variable(barrier),
                            level,
                        },
                    ))
                }
            },
            gpu::Operation::Tma(tma_ops) => {
                // Any TMA op requires the TMA feature in the generated source.
                self.flags.inst_tma = true;
                match tma_ops {
                    gpu::TmaOps::TmaStore {
                        source,
                        coordinates,
                        offset_source,
                    } => {
                        instructions.push(Instruction::MemCopyAsyncTensorSharedToGlobal {
                            smem_buffer: self.compile_variable(source),
                            smem_offset: self.compile_variable(offset_source),
                            tensor_map: self.compile_variable(out.unwrap()),
                            indices: coordinates
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        });
                    }
                    gpu::TmaOps::CommitGroup => {
                        instructions.push(Instruction::BulkCommitGroup);
                    }
                    gpu::TmaOps::WaitGroup { max_pending } => {
                        instructions.push(Instruction::BulkWaitGroup { max_pending });
                    }
                    gpu::TmaOps::WaitGroupRead { max_pending } => {
                        instructions.push(Instruction::BulkWaitGroupRead { max_pending });
                    }
                }
            }
            // Markers carry no code.
            gpu::Operation::Marker(_) => {}
        }
    }
709
710    fn update_debug_loc(
711        &mut self,
712        instructions: &mut Vec<Instruction<D>>,
713        inst: &gpu::Instruction,
714    ) {
715        if !matches!(inst.operation, Operation::NonSemantic(_)) {
716            match &inst.source_loc {
717                Some(loc) if Some(loc) != self.source_loc.as_ref() => {
718                    self.source_loc = Some(loc.clone());
719                    instructions.push(Instruction::Line {
720                        file: loc.source.file.clone(),
721                        line: loc.line,
722                    });
723                }
724                _ => {}
725            }
726        }
727    }
728
    /// Lowers a cooperative matrix (CMMA) operation to a WMMA instruction.
    ///
    /// Sets `flags.inst_wmma` so WMMA support is emitted for the kernel, and
    /// registers any dialect-specific extension the resulting instruction
    /// needs via `D::register_wmma_instruction_extension`.
    ///
    /// # Panics
    /// Panics when `out` is `None`, when a `Store` is missing its layout, on
    /// the unimplemented `StoreMatrix` variant, and on `RowIndex`/`ColIndex`
    /// (those are expected to be eliminated by earlier processors).
    fn compile_cmma(&mut self, cmma: gpu::CoopMma, out: Option<gpu::Variable>) -> Instruction<D> {
        self.flags.inst_wmma = true;

        // `out` is the destination fragment/buffer for every variant.
        let out = self.compile_variable(out.unwrap());

        let inst = match cmma {
            // Broadcast a single value into every element of the fragment.
            gpu::CoopMma::Fill { value } => WmmaInstruction::Fill {
                frag: out,
                value: self.compile_variable(value),
            },
            // Load a fragment from memory; layout may be inferred (None).
            gpu::CoopMma::Load {
                value,
                stride,
                offset,
                layout,
            } => WmmaInstruction::Load {
                frag: out,
                offset: self.compile_variable(offset),
                value: self.compile_variable(value),
                stride: self.compile_variable(stride),
                layout: layout.and_then(|l| self.compile_matrix_layout(l)),
            },
            // D = A * B + C on whole fragments.
            gpu::CoopMma::Execute {
                mat_a,
                mat_b,
                mat_c,
            } => WmmaInstruction::Execute {
                frag_a: self.compile_variable(mat_a),
                frag_b: self.compile_variable(mat_b),
                frag_c: self.compile_variable(mat_c),
                frag_d: out,
                warp_size: self.compilation_options.warp_size,
            },
            // MMA with explicitly managed registers and an explicit shape.
            gpu::CoopMma::ExecuteManual {
                matrix,
                registers_a,
                registers_b,
                registers_c,
            } => WmmaInstruction::ExecuteManual {
                shape: MmaShape::new(matrix.m, matrix.n, matrix.k),
                frag_a: self.compile_variable(registers_a),
                frag_b: self.compile_variable(registers_b),
                frag_c: self.compile_variable(registers_c),
                frag_d: out,
            },
            // Manual MMA with per-block scaling factors applied to A and B.
            gpu::CoopMma::ExecuteScaled {
                matrix,
                registers_a,
                registers_b,
                registers_c,
                scales_a,
                scales_b,
                scales_factor,
            } => WmmaInstruction::ExecuteScaled {
                shape: MmaShape::new(matrix.m, matrix.n, matrix.k),
                frag_a: self.compile_variable(registers_a),
                frag_b: self.compile_variable(registers_b),
                frag_c: self.compile_variable(registers_c),
                frag_d: out,

                scales_a: self.compile_variable(scales_a),
                scales_b: self.compile_variable(scales_b),
                scales_factor,
            },
            // Store a fragment to memory; the lowering references the unit
            // position and plane index builtins, so flag them as needed.
            gpu::CoopMma::Store {
                mat,
                stride,
                offset,
                layout,
            } => {
                self.flags.indexes.unit_pos = true;
                self.flags.indexes.plane_index = true;
                WmmaInstruction::Store {
                    output: out,
                    offset: self.compile_variable(offset),
                    frag: self.compile_variable(mat),
                    stride: self.compile_variable(stride),
                    // Unlike `Load`, a store cannot infer its layout.
                    layout: self
                        .compile_matrix_layout(layout)
                        .expect("Layout required for store instruction"),
                }
            }
            // ldmatrix-style load of registers from a buffer.
            gpu::CoopMma::LoadMatrix {
                buffer,
                offset,
                line_size,
                factor,
                transpose,
            } => WmmaInstruction::LdMatrix {
                output: out,
                buffer: self.compile_variable(buffer),
                offset: self.compile_variable(offset),
                line_size,
                factor,
                transpose,
            },
            gpu::CoopMma::StoreMatrix { .. } => {
                // Not yet supported by this backend.
                todo!()
            }
            // Fragment-to-fragment element type conversion.
            gpu::CoopMma::Cast { input } => WmmaInstruction::Cast {
                input: self.compile_variable(input),
                output: out,
            },
            gpu::CoopMma::RowIndex { .. } | gpu::CoopMma::ColIndex { .. } => {
                panic!("Row/Col index should be handled by processors")
            }
        };

        D::register_wmma_instruction_extension(&mut self.extensions, &inst);

        Instruction::Wmma(inst)
    }
841
842    fn compile_metadata(
843        &mut self,
844        metadata: gpu::Metadata,
845        out: Option<gpu::Variable>,
846    ) -> Instruction<D> {
847        let out = out.unwrap();
848        match metadata {
849            gpu::Metadata::Stride { dim, var } => {
850                let position = self.ext_meta_position(var);
851                let offset = self.metadata.stride_offset_index(position);
852                Instruction::ExtendedMetadata {
853                    info_offset: self.compile_variable(offset.into()),
854                    dim: self.compile_variable(dim),
855                    split_meta: self.compilation_options.supports_features.grid_constants,
856                    static_offset: self.metadata.static_len(),
857                    out: self.compile_variable(out),
858                }
859            }
860            gpu::Metadata::Shape { dim, var } => {
861                let position = self.ext_meta_position(var);
862                let offset = self.metadata.shape_offset_index(position);
863                Instruction::ExtendedMetadata {
864                    info_offset: self.compile_variable(offset.into()),
865                    dim: self.compile_variable(dim),
866                    split_meta: self.compilation_options.supports_features.grid_constants,
867                    static_offset: self.metadata.static_len(),
868                    out: self.compile_variable(out),
869                }
870            }
871            gpu::Metadata::Rank { var } => {
872                let out = self.compile_variable(out);
873                let pos = self.ext_meta_position(var);
874                let offset = self.metadata.rank_index(pos);
875                super::Instruction::Metadata {
876                    info_offset: self.compile_variable(offset.into()),
877                    split_meta: self.compilation_options.supports_features.grid_constants,
878                    out,
879                }
880            }
881            gpu::Metadata::Length { var } => {
882                let input = self.compile_variable(var);
883                let out = self.compile_variable(out);
884
885                match input {
886                    Variable::Slice { .. } => Instruction::SliceLength { input, out },
887                    Variable::SharedMemory(_id, _item, length) => {
888                        Instruction::ConstLength { length, out }
889                    }
890                    _ => {
891                        let id = input.id().expect("Variable should have id");
892                        let offset = self.metadata.len_index(id);
893                        Instruction::Metadata {
894                            info_offset: self.compile_variable(offset.into()),
895                            split_meta: self.compilation_options.supports_features.grid_constants,
896                            out,
897                        }
898                    }
899                }
900            }
901            gpu::Metadata::BufferLength { var } => {
902                let input = self.compile_variable(var);
903                let out = self.compile_variable(out);
904
905                match input {
906                    Variable::Slice { .. } => Instruction::SliceLength { input, out },
907                    _ => {
908                        let id = input.id().expect("Variable should have id");
909                        let offset = self.metadata.buffer_len_index(id);
910                        Instruction::Metadata {
911                            info_offset: self.compile_variable(offset.into()),
912                            split_meta: self.compilation_options.supports_features.grid_constants,
913                            out,
914                        }
915                    }
916                }
917            }
918        }
919    }
920
921    fn compile_branch(&mut self, instructions: &mut Vec<Instruction<D>>, branch: gpu::Branch) {
922        match branch {
923            gpu::Branch::If(mut op) => instructions.push(Instruction::If {
924                cond: self.compile_variable(op.cond),
925                instructions: self.compile_scope(&mut op.scope),
926            }),
927            gpu::Branch::IfElse(mut op) => instructions.push(Instruction::IfElse {
928                cond: self.compile_variable(op.cond),
929                instructions_if: self.compile_scope(&mut op.scope_if),
930                instructions_else: self.compile_scope(&mut op.scope_else),
931            }),
932            gpu::Branch::Switch(mut op) => instructions.push(Instruction::Switch {
933                value: self.compile_variable(op.value),
934                instructions_default: self.compile_scope(&mut op.scope_default),
935                instructions_cases: op
936                    .cases
937                    .into_iter()
938                    .map(|(val, mut block)| {
939                        (self.compile_variable(val), self.compile_scope(&mut block))
940                    })
941                    .collect(),
942            }),
943            gpu::Branch::Return => instructions.push(Instruction::Return),
944            gpu::Branch::Break => instructions.push(Instruction::Break),
945            gpu::Branch::RangeLoop(mut range_loop) => instructions.push(Instruction::RangeLoop {
946                i: self.compile_variable(range_loop.i),
947                start: self.compile_variable(range_loop.start),
948                end: self.compile_variable(range_loop.end),
949                step: range_loop.step.map(|it| self.compile_variable(it)),
950                inclusive: range_loop.inclusive,
951                instructions: self.compile_scope(&mut range_loop.scope),
952            }),
953            gpu::Branch::Loop(mut op) => instructions.push(Instruction::Loop {
954                instructions: self.compile_scope(&mut op.scope),
955            }),
956        };
957    }
958
959    fn compile_atomic(
960        &mut self,
961        value: gpu::AtomicOp,
962        out: Option<gpu::Variable>,
963        instructions: &mut Vec<Instruction<D>>,
964    ) {
965        let out = out.unwrap();
966        match value {
967            gpu::AtomicOp::Load(op) => {
968                instructions.push(Instruction::AtomicLoad(self.compile_unary(op, out)))
969            }
970            gpu::AtomicOp::Store(op) => {
971                instructions.push(Instruction::AtomicStore(self.compile_unary(op, out)))
972            }
973            gpu::AtomicOp::Swap(op) => {
974                instructions.push(Instruction::AtomicSwap(self.compile_binary(op, out)))
975            }
976            gpu::AtomicOp::Add(op) => {
977                instructions.push(Instruction::AtomicAdd(self.compile_binary(op, out)))
978            }
979            gpu::AtomicOp::Sub(op) => {
980                instructions.push(Instruction::AtomicSub(self.compile_binary(op, out)))
981            }
982            gpu::AtomicOp::Max(op) => {
983                instructions.push(Instruction::AtomicMax(self.compile_binary(op, out)))
984            }
985            gpu::AtomicOp::Min(op) => {
986                instructions.push(Instruction::AtomicMin(self.compile_binary(op, out)))
987            }
988            gpu::AtomicOp::And(op) => {
989                instructions.push(Instruction::AtomicAnd(self.compile_binary(op, out)))
990            }
991            gpu::AtomicOp::Or(op) => {
992                instructions.push(Instruction::AtomicOr(self.compile_binary(op, out)))
993            }
994            gpu::AtomicOp::Xor(op) => {
995                instructions.push(Instruction::AtomicXor(self.compile_binary(op, out)))
996            }
997            gpu::AtomicOp::CompareAndSwap(op) => instructions.push(Instruction::AtomicCAS {
998                input: self.compile_variable(op.input),
999                cmp: self.compile_variable(op.cmp),
1000                val: self.compile_variable(op.val),
1001                out: self.compile_variable(out),
1002            }),
1003        }
1004    }
1005
    /// Lowers an arithmetic operation to an instruction and appends it to
    /// `instructions`.
    ///
    /// Several float operations (div, exp, log, cos, sin, tanh, powf, sqrt,
    /// inverse-sqrt, recip, normalize, magnitude) have fast approximate
    /// variants; [`Self::select_fast_float`] uses the instruction's fast-math
    /// `modes` to decide which form is emitted. Operations needing a
    /// dialect-specific helper register it via
    /// `D::register_instruction_extension`.
    ///
    /// # Panics
    /// Panics when `out` is `None`.
    fn compile_arithmetic(
        &mut self,
        value: gpu::Arithmetic,
        out: Option<gpu::Variable>,
        modes: InstructionModes,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::Arithmetic::Add(op) => {
                instructions.push(Instruction::Add(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::SaturatingAdd(op) => {
                instructions.push(Instruction::SaturatingAdd(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Mul(op) => {
                instructions.push(Instruction::Mul(self.compile_binary(op, out)))
            }
            // Division may be replaced by a fast approximate division when the
            // fast-math modes permit reciprocal/reduced-precision rewriting.
            gpu::Arithmetic::Div(op) => {
                let op = self.compile_binary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::AllowReciprocal
                        | FastMath::ReducedPrecision
                        | FastMath::UnsignedZero
                        | FastMath::NotInf,
                    Instruction::Div(op),
                    Instruction::FastDiv(op),
                ))
            }
            gpu::Arithmetic::Sub(op) => {
                instructions.push(Instruction::Sub(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::SaturatingSub(op) => {
                instructions.push(Instruction::SaturatingSub(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::MulHi(op) => {
                let instruction = Instruction::HiMul(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Modulo(op) => {
                instructions.push(Instruction::Modulo(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Abs(op) => {
                instructions.push(Instruction::Abs(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Exp(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Exp(op),
                    Instruction::FastExp(op),
                ));
            }
            gpu::Arithmetic::Log(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Log(op),
                    Instruction::FastLog(op),
                ));
            }
            gpu::Arithmetic::Log1p(op) => {
                instructions.push(Instruction::Log1p(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Cos(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Cos(op),
                    Instruction::FastCos(op),
                ));
            }
            gpu::Arithmetic::Sin(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Sin(op),
                    Instruction::FastSin(op),
                ));
            }
            gpu::Arithmetic::Tan(op) => {
                instructions.push(Instruction::Tan(self.compile_unary(op, out)))
            }
            // Tanh has a fast path only on targets that support it; the
            // extension for the precise form is registered either way.
            gpu::Arithmetic::Tanh(op) => {
                let op = self.compile_unary(op, out);
                let instruction = Instruction::Tanh(op);
                D::register_instruction_extension(&mut self.extensions, &instruction);
                if self.compilation_options.supports_features.fast_tanh {
                    instructions.push(self.select_fast_float(
                        out.ty,
                        modes,
                        FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                        instruction,
                        Instruction::FastTanh(op),
                    ))
                } else {
                    instructions.push(instruction);
                }
            }
            gpu::Arithmetic::Sinh(op) => {
                let instruction = Instruction::Sinh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Cosh(op) => {
                let instruction = Instruction::Cosh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcCos(op) => {
                let instruction = Instruction::ArcCos(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcSin(op) => {
                let instruction = Instruction::ArcSin(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcTan(op) => {
                let instruction = Instruction::ArcTan(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcSinh(op) => {
                let instruction = Instruction::ArcSinh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcCosh(op) => {
                let instruction = Instruction::ArcCosh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcTanh(op) => {
                let instruction = Instruction::ArcTanh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Degrees(op) => {
                let instruction = Instruction::Degrees(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Radians(op) => {
                let instruction = Instruction::Radians(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcTan2(op) => {
                let instruction = Instruction::ArcTan2(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Powf(op) => {
                let op = self.compile_binary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Powf(op),
                    Instruction::FastPowf(op),
                ))
            }
            gpu::Arithmetic::Powi(op) => {
                instructions.push(Instruction::Powi(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Sqrt(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Sqrt(op),
                    Instruction::FastSqrt(op),
                ))
            }
            gpu::Arithmetic::InverseSqrt(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::InverseSqrt(op),
                    Instruction::FastInverseSqrt(op),
                ))
            }
            gpu::Arithmetic::Erf(op) => {
                let instruction = Instruction::Erf(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Max(op) => {
                let instruction = Instruction::Max(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Min(op) => {
                let instruction = Instruction::Min(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Clamp(op) => instructions.push(Instruction::Clamp {
                input: self.compile_variable(op.input),
                min_value: self.compile_variable(op.min_value),
                max_value: self.compile_variable(op.max_value),
                out: self.compile_variable(out),
            }),
            // Recip lowers to `1 / x` by default, or a fast hardware
            // reciprocal when the fast-math modes allow it.
            gpu::Arithmetic::Recip(op) => {
                let elem = op.input.ty.elem_type();
                let input = self.compile_variable(op.input);
                let out = self.compile_variable(out);
                // Constant `1` of the input's element type (`true` for bool).
                let lhs = match elem {
                    gpu::ElemType::Float(kind) => gpu::ConstantScalarValue::Float(1.0, kind),
                    gpu::ElemType::Int(kind) => gpu::ConstantScalarValue::Int(1, kind),
                    gpu::ElemType::UInt(kind) => gpu::ConstantScalarValue::UInt(1, kind),
                    gpu::ElemType::Bool => gpu::ConstantScalarValue::Bool(true),
                };
                let div = Instruction::Div(BinaryInstruction {
                    lhs: Variable::ConstantScalar(lhs, self.compile_elem(elem)),
                    rhs: input,
                    out,
                });
                let recip = Instruction::FastRecip(UnaryInstruction { input, out });

                instructions.push(self.select_fast_float(
                    elem.into(),
                    modes,
                    FastMath::AllowReciprocal
                        | FastMath::ReducedPrecision
                        | FastMath::UnsignedZero
                        | FastMath::NotInf,
                    div,
                    recip,
                ))
            }
            gpu::Arithmetic::Round(op) => {
                instructions.push(Instruction::Round(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Floor(op) => {
                instructions.push(Instruction::Floor(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Ceil(op) => {
                instructions.push(Instruction::Ceil(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Trunc(op) => {
                instructions.push(Instruction::Trunc(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Remainder(op) => {
                instructions.push(Instruction::Remainder(self.compile_binary(op, out)))
            }
            // Fused multiply-add: out = a * b + c.
            gpu::Arithmetic::Fma(op) => instructions.push(Instruction::Fma {
                a: self.compile_variable(op.a),
                b: self.compile_variable(op.b),
                c: self.compile_variable(op.c),
                out: self.compile_variable(out),
            }),
            gpu::Arithmetic::Neg(op) => {
                instructions.push(Instruction::Neg(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Normalize(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Normalize(op),
                    Instruction::FastNormalize(op),
                ))
            }
            gpu::Arithmetic::Magnitude(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Magnitude(op),
                    Instruction::FastMagnitude(op),
                ))
            }
            gpu::Arithmetic::Dot(op) => {
                instructions.push(Instruction::Dot(self.compile_binary(op, out)))
            }
        };
    }
1302
1303    fn select_fast_float(
1304        &self,
1305        ty: gpu::Type,
1306        modes: InstructionModes,
1307        required_flags: EnumSet<FastMath>,
1308        default: Instruction<D>,
1309        fast: Instruction<D>,
1310    ) -> Instruction<D> {
1311        if !self.compilation_options.supports_features.fast_math
1312            || !matches!(ty.elem_type(), ElemType::Float(FloatKind::F32))
1313        {
1314            return default;
1315        }
1316
1317        if modes.fp_math_mode.is_superset(required_flags) {
1318            fast
1319        } else {
1320            default
1321        }
1322    }
1323
1324    fn compile_comparison(
1325        &mut self,
1326        value: gpu::Comparison,
1327        out: Option<gpu::Variable>,
1328        instructions: &mut Vec<Instruction<D>>,
1329    ) {
1330        let out = out.unwrap();
1331        match value {
1332            gpu::Comparison::Equal(op) => {
1333                instructions.push(Instruction::Equal(self.compile_binary(op, out)))
1334            }
1335            gpu::Comparison::Lower(op) => {
1336                instructions.push(Instruction::Lower(self.compile_binary(op, out)))
1337            }
1338            gpu::Comparison::Greater(op) => {
1339                instructions.push(Instruction::Greater(self.compile_binary(op, out)))
1340            }
1341            gpu::Comparison::LowerEqual(op) => {
1342                instructions.push(Instruction::LowerEqual(self.compile_binary(op, out)))
1343            }
1344            gpu::Comparison::GreaterEqual(op) => {
1345                instructions.push(Instruction::GreaterEqual(self.compile_binary(op, out)))
1346            }
1347            gpu::Comparison::NotEqual(op) => {
1348                instructions.push(Instruction::NotEqual(self.compile_binary(op, out)))
1349            }
1350            gpu::Comparison::IsNan(op) => {
1351                instructions.push(Instruction::IsNan(self.compile_unary(op, out)))
1352            }
1353            gpu::Comparison::IsInf(op) => {
1354                instructions.push(Instruction::IsInf(self.compile_unary(op, out)))
1355            }
1356        };
1357    }
1358
1359    fn compile_bitwise(
1360        &mut self,
1361        value: gpu::Bitwise,
1362        out: Option<gpu::Variable>,
1363        instructions: &mut Vec<Instruction<D>>,
1364    ) {
1365        let out = out.unwrap();
1366        match value {
1367            gpu::Bitwise::BitwiseOr(op) => {
1368                instructions.push(Instruction::BitwiseOr(self.compile_binary(op, out)))
1369            }
1370            gpu::Bitwise::BitwiseAnd(op) => {
1371                instructions.push(Instruction::BitwiseAnd(self.compile_binary(op, out)))
1372            }
1373            gpu::Bitwise::BitwiseXor(op) => {
1374                instructions.push(Instruction::BitwiseXor(self.compile_binary(op, out)))
1375            }
1376            gpu::Bitwise::CountOnes(op) => {
1377                instructions.push(Instruction::CountBits(self.compile_unary(op, out)))
1378            }
1379            gpu::Bitwise::ReverseBits(op) => {
1380                instructions.push(Instruction::ReverseBits(self.compile_unary(op, out)))
1381            }
1382            gpu::Bitwise::ShiftLeft(op) => {
1383                instructions.push(Instruction::ShiftLeft(self.compile_binary(op, out)))
1384            }
1385            gpu::Bitwise::ShiftRight(op) => {
1386                instructions.push(Instruction::ShiftRight(self.compile_binary(op, out)))
1387            }
1388            gpu::Bitwise::BitwiseNot(op) => {
1389                instructions.push(Instruction::BitwiseNot(self.compile_unary(op, out)))
1390            }
1391            gpu::Bitwise::LeadingZeros(op) => {
1392                instructions.push(Instruction::LeadingZeros(self.compile_unary(op, out)))
1393            }
1394            gpu::Bitwise::FindFirstSet(op) => {
1395                let instruction = Instruction::FindFirstSet(self.compile_unary(op, out));
1396                D::register_instruction_extension(&mut self.extensions, &instruction);
1397                instructions.push(instruction)
1398            }
1399        };
1400    }
1401
    /// Lowers a [`gpu::Operator`] into one or more dialect instructions.
    ///
    /// `out` must be `Some`: every operator handled here writes a result.
    /// Most arms delegate to the generic unary/binary/index helpers; casts
    /// involving fp4/fp6/fp8 take a dedicated path (see below).
    fn compile_operator(
        &mut self,
        value: gpu::Operator,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            // Checked and unchecked indexing compile to the same instruction here;
            // bounds handling is decided elsewhere (e.g. by an IR processor).
            gpu::Operator::Index(op) | gpu::Operator::UncheckedIndex(op) => {
                instructions.push(Instruction::Index(self.compile_index(op, out)));
            }
            gpu::Operator::IndexAssign(op) | gpu::Operator::UncheckedIndexAssign(op) => {
                instructions.push(Instruction::IndexAssign(self.compile_index_assign(op, out)));
            }
            gpu::Operator::And(op) => {
                instructions.push(Instruction::And(self.compile_binary(op, out)))
            }
            gpu::Operator::Or(op) => {
                instructions.push(Instruction::Or(self.compile_binary(op, out)))
            }
            gpu::Operator::Not(op) => {
                instructions.push(Instruction::Not(self.compile_unary(op, out)))
            }
            // Construct a vectorized value from its individual components.
            gpu::Operator::InitLine(op) => instructions.push(Instruction::VecInit {
                inputs: op
                    .inputs
                    .into_iter()
                    .map(|it| self.compile_variable(it))
                    .collect(),
                out: self.compile_variable(out),
            }),
            gpu::Operator::CopyMemory(op) => instructions.push(Instruction::Copy {
                input: self.compile_variable(op.input),
                in_index: self.compile_variable(op.in_index),
                out: self.compile_variable(out),
                out_index: self.compile_variable(op.out_index),
            }),
            // Bulk variant additionally carries the number of elements to copy.
            gpu::Operator::CopyMemoryBulk(op) => instructions.push(Instruction::CopyBulk {
                input: self.compile_variable(op.input),
                in_index: self.compile_variable(op.in_index),
                out: self.compile_variable(out),
                out_index: self.compile_variable(op.out_index),
                len: op.len,
            }),
            gpu::Operator::Select(op) => instructions.push(Instruction::Select {
                cond: self.compile_variable(op.cond),
                then: self.compile_variable(op.then),
                or_else: self.compile_variable(op.or_else),
                out: self.compile_variable(out),
            }),
            // Needs special conversion semantics
            // Casts to/from fp4/fp6/fp8 cannot be plain assignments: they go
            // through `SpecialCast`, which may route via f16/bf16 intermediates.
            gpu::Operator::Cast(op)
                if is_fp4_fp6_fp8(op.input.elem_type()) || is_fp4_fp6_fp8(out.elem_type()) =>
            {
                // We may need these for intermediates
                self.flags.elem_f16 = true;
                self.flags.elem_bf16 = true;
                let vec_in = op.input.ty.line_size();
                let packing = out.storage_type().packing_factor();
                // Pre-register every line width of f16/bf16 the conversion could
                // use, so the matching C++ type definitions are emitted.
                self.compile_type(op.input.ty.line(packing));
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::F16)).line(vec_in),
                );
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::BF16)).line(vec_in),
                );
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::F16)).line(packing),
                );
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::BF16)).line(packing),
                );

                let inst = self.compile_unary(op, out);

                instructions.push(Instruction::SpecialCast(inst));
            }
            // Ordinary casts lower to an assignment; C++ conversion rules apply.
            gpu::Operator::Cast(op) => {
                let op = self.compile_unary(op, out);

                // TF32 needs its type support emitted even for plain casts.
                if op.input.elem() == Elem::TF32 || op.out.elem() == Elem::TF32 {
                    self.flags.elem_tf32 = true;
                }

                instructions.push(Instruction::Assign(op))
            }
            // Reinterpret the bits without numeric conversion.
            gpu::Operator::Reinterpret(op) => {
                instructions.push(Instruction::Bitcast(self.compile_unary(op, out)))
            }
        };
    }
1493
1494    fn compile_binary(
1495        &mut self,
1496        value: gpu::BinaryOperator,
1497        out: gpu::Variable,
1498    ) -> BinaryInstruction<D> {
1499        BinaryInstruction {
1500            lhs: self.compile_variable(value.lhs),
1501            rhs: self.compile_variable(value.rhs),
1502            out: self.compile_variable(out),
1503        }
1504    }
1505
1506    fn compile_index_assign(
1507        &mut self,
1508        value: gpu::IndexAssignOperator,
1509        out: gpu::Variable,
1510    ) -> IndexAssignInstruction<D> {
1511        IndexAssignInstruction {
1512            index: self.compile_variable(value.index),
1513            value: self.compile_variable(value.value),
1514            line_size: value.line_size,
1515            out: self.compile_variable(out),
1516        }
1517    }
1518
1519    fn compile_index(
1520        &mut self,
1521        value: gpu::IndexOperator,
1522        out: gpu::Variable,
1523    ) -> IndexInstruction<D> {
1524        IndexInstruction {
1525            list: self.compile_variable(value.list),
1526            index: self.compile_variable(value.index),
1527            line_size: value.line_size,
1528            out: self.compile_variable(out),
1529        }
1530    }
1531
1532    fn compile_unary(
1533        &mut self,
1534        value: gpu::UnaryOperator,
1535        out: gpu::Variable,
1536    ) -> UnaryInstruction<D> {
1537        UnaryInstruction {
1538            input: self.compile_variable(value.input),
1539            out: self.compile_variable(out),
1540        }
1541    }
1542
    /// Lowers an IR [`gpu::Variable`] into the dialect representation.
    ///
    /// Besides mapping the variable kind, this method records side effects on
    /// `self`: feature flags (TMA, WMMA, pipelines, barriers, builtin index
    /// usage), one-time registration of local arrays and pipelines, and type
    /// collection via `compile_type`/`compile_storage_type`.
    fn compile_variable(&mut self, value: gpu::Variable) -> Variable<D> {
        let item = value.ty;
        match value.kind {
            gpu::VariableKind::GlobalInputArray(id) => {
                Variable::GlobalInputArray(id, self.compile_type(item))
            }
            gpu::VariableKind::GlobalScalar(id) => Variable::GlobalScalar {
                id,
                elem: self.compile_storage_type(item.storage_type()),
                // Scalars are grouped in a struct only when grid constants are
                // supported by the target.
                in_struct: self.compilation_options.supports_features.grid_constants,
            },
            // Tensor-map input/output both lower to `TensorMap` and require TMA
            // support in the generated kernel.
            gpu::VariableKind::TensorMapInput(id) => {
                self.flags.inst_tma = true;
                Variable::TensorMap(id)
            }
            gpu::VariableKind::TensorMapOutput(id) => {
                self.flags.inst_tma = true;
                Variable::TensorMap(id)
            }
            gpu::VariableKind::LocalMut { id } => Variable::LocalMut {
                id,
                item: self.compile_type(item),
            },
            // SSA-versioned variables collapse to a plain mutable local; the
            // version component is dropped.
            gpu::VariableKind::Versioned { id, .. } => Variable::LocalMut {
                id,
                item: self.compile_type(item),
            },
            gpu::VariableKind::LocalConst { id } => Variable::LocalConst {
                id,
                item: self.compile_type(item),
            },
            gpu::VariableKind::GlobalOutputArray(id) => {
                Variable::GlobalOutputArray(id, self.compile_type(item))
            }
            gpu::VariableKind::ConstantScalar(value) => {
                Variable::ConstantScalar(value, self.compile_elem(value.elem_type()))
            }
            gpu::VariableKind::SharedMemory { id, length, .. } => {
                let item = self.compile_type(item);
                Variable::SharedMemory(id, item, length)
            }
            // Constant arrays store the full unrolled length.
            gpu::VariableKind::ConstantArray {
                id,
                length,
                unroll_factor,
            } => {
                let item = self.compile_type(item);
                Variable::ConstantArray(id, item, length * unroll_factor)
            }
            // Builtins flag which index computations the kernel prelude must emit.
            gpu::VariableKind::Builtin(builtin) => match builtin {
                gpu::Builtin::AbsolutePos => {
                    self.flags.indexes.absolute_pos = true;
                    Variable::AbsolutePos
                }
                gpu::Builtin::CubePosCluster
                    if self.compilation_options.supports_features.clusters =>
                {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterRank
                }
                gpu::Builtin::CubePosClusterX
                    if self.compilation_options.supports_features.clusters =>
                {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexX
                }
                gpu::Builtin::CubePosClusterY
                    if self.compilation_options.supports_features.clusters =>
                {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexY
                }
                gpu::Builtin::CubePosClusterZ
                    if self.compilation_options.supports_features.clusters =>
                {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexZ
                }
                // Fallback if clusters aren't supported, ID is always 0 since clusters are always
                // (1, 1, 1) if unsupported
                gpu::Builtin::CubePosCluster
                | gpu::Builtin::CubePosClusterX
                | gpu::Builtin::CubePosClusterY
                | gpu::Builtin::CubePosClusterZ => const_u32(0),
                gpu::Builtin::AbsolutePosX => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosX
                }
                gpu::Builtin::AbsolutePosY => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosY
                }
                gpu::Builtin::AbsolutePosZ => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosZ
                }
                gpu::Builtin::CubeDim => {
                    self.flags.indexes.cube_dim = true;
                    Variable::CubeDim
                }
                gpu::Builtin::CubeDimX => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimX
                }
                gpu::Builtin::CubeDimY => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimY
                }
                gpu::Builtin::CubeDimZ => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimZ
                }
                // Cluster dimensions are compile-time constants taken from the
                // configured cluster dim.
                gpu::Builtin::CubeClusterDim => const_u32(self.cluster_dim.num_elems()),
                gpu::Builtin::CubeClusterDimX => const_u32(self.cluster_dim.x),
                gpu::Builtin::CubeClusterDimY => const_u32(self.cluster_dim.y),
                gpu::Builtin::CubeClusterDimZ => const_u32(self.cluster_dim.z),
                gpu::Builtin::CubePos => {
                    self.flags.indexes.cube_pos = true;
                    Variable::CubePos
                }
                gpu::Builtin::CubePosX => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosX
                }
                gpu::Builtin::CubePosY => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosY
                }
                gpu::Builtin::CubePosZ => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosZ
                }
                gpu::Builtin::CubeCount => {
                    self.flags.indexes.cube_count = true;
                    Variable::CubeCount
                }
                gpu::Builtin::CubeCountX => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountX
                }
                gpu::Builtin::CubeCountY => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountY
                }
                gpu::Builtin::CubeCountZ => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountZ
                }
                gpu::Builtin::UnitPos => {
                    self.flags.indexes.unit_pos = true;
                    Variable::UnitPos
                }
                gpu::Builtin::UnitPosX => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosX
                }
                gpu::Builtin::UnitPosY => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosY
                }
                gpu::Builtin::UnitPosZ => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosZ
                }
                gpu::Builtin::PlaneDim => {
                    self.flags.indexes.plane_dim = true;
                    Variable::PlaneDim
                }
                gpu::Builtin::UnitPosPlane => {
                    self.flags.indexes.unit_pos_plane = true;
                    Variable::UnitPosPlane
                }
            },
            gpu::VariableKind::LocalArray {
                id,
                length,
                unroll_factor,
            } => {
                let item = self.compile_type(item);
                // Register each local array once; later references reuse the entry.
                if !self.local_arrays.iter().any(|s| s.index == id) {
                    self.local_arrays
                        .push(LocalArray::new(id, item, length * unroll_factor));
                }
                Variable::LocalArray(id, item, length)
            }
            gpu::VariableKind::Matrix { id, mat } => {
                self.flags.inst_wmma = true;
                Variable::WmmaFragment {
                    id,
                    frag: self.compile_matrix(mat),
                }
            }
            gpu::VariableKind::Pipeline { id, num_stages } => {
                self.flags.op_pipeline = true;
                let pipeline = Variable::Pipeline { id };
                // Emit the pipeline initialization only on first encounter.
                if !self.pipelines.iter().any(|s| s.pipeline_id() == id) {
                    self.pipelines.push(PipelineOps::Init {
                        pipeline,
                        num_stages,
                    });
                }
                pipeline
            }
            gpu::VariableKind::Barrier { id, level } => {
                self.flags.op_barrier = true;
                Variable::Barrier { id, level }
            }
            gpu::VariableKind::BarrierToken { id, level } => {
                self.flags.op_barrier = true;
                Variable::BarrierToken { id, level }
            }
        }
    }
1756
1757    fn compile_matrix(&mut self, matrix: gpu::Matrix) -> Fragment<D> {
1758        Fragment {
1759            ident: self.compile_matrix_ident(matrix.ident),
1760            m: matrix.m,
1761            n: matrix.n,
1762            k: matrix.k,
1763            elem: self.compile_storage_type(matrix.storage),
1764            layout: self.compile_matrix_layout(matrix.layout),
1765        }
1766    }
1767
1768    fn compile_matrix_ident(&mut self, ident: gpu::MatrixIdent) -> FragmentIdent<D> {
1769        match ident {
1770            gpu::MatrixIdent::A => FragmentIdent::A,
1771            gpu::MatrixIdent::B => FragmentIdent::B,
1772            gpu::MatrixIdent::Accumulator => FragmentIdent::Accumulator,
1773        }
1774    }
1775
1776    fn compile_matrix_layout(&mut self, layout: gpu::MatrixLayout) -> Option<FragmentLayout<D>> {
1777        match layout {
1778            gpu::MatrixLayout::ColMajor => Some(FragmentLayout::ColMajor),
1779            gpu::MatrixLayout::RowMajor => Some(FragmentLayout::RowMajor),
1780            gpu::MatrixLayout::Undefined => None,
1781        }
1782    }
1783
1784    fn compile_binding(&mut self, binding: cubecl_core::compute::Binding) -> Binding<D> {
1785        Binding {
1786            id: binding.id,
1787            item: self.compile_type(binding.ty),
1788            location: binding.location,
1789            size: binding.size,
1790            vis: binding.visibility,
1791        }
1792    }
1793
1794    fn compile_type(&mut self, ty: gpu::Type) -> Item<D> {
1795        let item = match ty {
1796            gpu::Type::Scalar(ty) => Item::new(self.compile_storage_type(ty), 1, false),
1797            gpu::Type::Line(ty, line_size) => {
1798                Item::new(self.compile_storage_type(ty), line_size as usize, false)
1799            }
1800            gpu::Type::Semantic(_) => Item::new(Elem::Bool, 1, true),
1801        };
1802        if item.elem != super::Elem::TF32 {
1803            self.items.insert(item);
1804            self.items.insert(item.optimized());
1805        } else {
1806            // TF32 is represented as `float` in C++
1807            let mut item = item;
1808            item.elem = super::Elem::F32;
1809            self.items.insert(item);
1810        }
1811
1812        item
1813    }
1814
1815    fn compile_storage_type(&mut self, value: gpu::StorageType) -> Elem<D> {
1816        match value {
1817            gpu::StorageType::Scalar(ty) => self.compile_elem(ty),
1818            gpu::StorageType::Atomic(ty) => Elem::Atomic(ty.into()),
1819            gpu::StorageType::Packed(gpu::ElemType::Float(kind), 2) => match kind {
1820                FloatKind::E2M1 => {
1821                    self.flags.elem_fp4 = true;
1822                    Elem::FP4x2(FP4Kind::E2M1)
1823                }
1824                FloatKind::E2M3 => {
1825                    self.flags.elem_fp6 = true;
1826                    Elem::FP6x2(FP6Kind::E2M3)
1827                }
1828                FloatKind::E3M2 => {
1829                    self.flags.elem_fp6 = true;
1830                    Elem::FP6(FP6Kind::E3M2)
1831                }
1832                FloatKind::E4M3 => {
1833                    self.flags.elem_fp8 = true;
1834                    Elem::FP8x2(FP8Kind::E4M3)
1835                }
1836                FloatKind::E5M2 => {
1837                    self.flags.elem_fp8 = true;
1838                    Elem::FP8x2(FP8Kind::E5M2)
1839                }
1840                FloatKind::UE8M0 => {
1841                    self.flags.elem_fp8 = true;
1842                    Elem::FP8x2(FP8Kind::UE8M0)
1843                }
1844                FloatKind::F16 => {
1845                    self.flags.elem_f16 = true;
1846                    Elem::F16x2
1847                }
1848                FloatKind::BF16 => {
1849                    self.flags.elem_bf16 = true;
1850                    Elem::BF16x2
1851                }
1852                other => unimplemented!("Unsupported storage type: packed<{other:?}, 2>"),
1853            },
1854            other => unimplemented!("Unsupported storage type: {other}"),
1855        }
1856    }
1857
1858    fn compile_elem(&mut self, value: gpu::ElemType) -> Elem<D> {
1859        match value {
1860            gpu::ElemType::Float(kind) => match kind {
1861                gpu::FloatKind::E2M1 => {
1862                    self.flags.elem_fp4 = true;
1863                    Elem::FP4(FP4Kind::E2M1)
1864                }
1865                gpu::FloatKind::E2M3 => {
1866                    self.flags.elem_fp6 = true;
1867                    Elem::FP6(FP6Kind::E2M3)
1868                }
1869                gpu::FloatKind::E3M2 => {
1870                    self.flags.elem_fp6 = true;
1871                    Elem::FP6(FP6Kind::E3M2)
1872                }
1873                gpu::FloatKind::E4M3 => {
1874                    self.flags.elem_fp8 = true;
1875                    Elem::FP8(FP8Kind::E4M3)
1876                }
1877                gpu::FloatKind::E5M2 => {
1878                    self.flags.elem_fp8 = true;
1879                    Elem::FP8(FP8Kind::E5M2)
1880                }
1881                gpu::FloatKind::UE8M0 => {
1882                    self.flags.elem_fp8 = true;
1883                    Elem::FP8(FP8Kind::UE8M0)
1884                }
1885                gpu::FloatKind::F16 => {
1886                    self.flags.elem_f16 = true;
1887                    Elem::F16
1888                }
1889                gpu::FloatKind::BF16 => {
1890                    self.flags.elem_bf16 = true;
1891                    Elem::BF16
1892                }
1893                gpu::FloatKind::TF32 => Elem::TF32,
1894                gpu::FloatKind::Flex32 => Elem::F32,
1895                gpu::FloatKind::F32 => Elem::F32,
1896                gpu::FloatKind::F64 => Elem::F64,
1897            },
1898            gpu::ElemType::Int(kind) => match kind {
1899                gpu::IntKind::I8 => Elem::I8,
1900                gpu::IntKind::I16 => Elem::I16,
1901                gpu::IntKind::I32 => Elem::I32,
1902                gpu::IntKind::I64 => Elem::I64,
1903            },
1904            gpu::ElemType::UInt(kind) => match kind {
1905                gpu::UIntKind::U8 => Elem::U8,
1906                gpu::UIntKind::U16 => Elem::U16,
1907                gpu::UIntKind::U32 => Elem::U32,
1908                gpu::UIntKind::U64 => Elem::U64,
1909            },
1910            gpu::ElemType::Bool => Elem::Bool,
1911        }
1912    }
1913}
1914
1915fn is_fp4_fp6_fp8(elem: gpu::ElemType) -> bool {
1916    match elem {
1917        gpu::ElemType::Float(kind) => matches!(
1918            kind,
1919            FloatKind::E2M1
1920                | FloatKind::E2M3
1921                | FloatKind::E3M2
1922                | FloatKind::E4M3
1923                | FloatKind::E5M2
1924                | FloatKind::UE8M0
1925        ),
1926        _ => false,
1927    }
1928}
1929
1930fn const_u32<D: Dialect>(value: u32) -> Variable<D> {
1931    Variable::ConstantScalar(
1932        gpu::ConstantScalarValue::UInt(value as u64, UIntKind::U32),
1933        Elem::U32,
1934    )
1935}
1936
1937pub fn register_supported_types(props: &mut DeviceProperties) {
1938    let supported_types = [
1939        gpu::ElemType::UInt(gpu::UIntKind::U8),
1940        gpu::ElemType::UInt(gpu::UIntKind::U16),
1941        gpu::ElemType::UInt(gpu::UIntKind::U32),
1942        gpu::ElemType::UInt(gpu::UIntKind::U64),
1943        gpu::ElemType::Int(gpu::IntKind::I8),
1944        gpu::ElemType::Int(gpu::IntKind::I16),
1945        gpu::ElemType::Int(gpu::IntKind::I32),
1946        gpu::ElemType::Int(gpu::IntKind::I64),
1947        gpu::ElemType::Float(gpu::FloatKind::BF16),
1948        gpu::ElemType::Float(gpu::FloatKind::F16),
1949        gpu::ElemType::Float(gpu::FloatKind::F32),
1950        gpu::ElemType::Float(gpu::FloatKind::Flex32),
1951        // Causes CUDA_ERROR_INVALID_VALUE for matmul, disabling until that can be investigated
1952        //gpu::Elem::Float(gpu::FloatKind::F64),
1953        gpu::ElemType::Bool,
1954    ];
1955
1956    let supported_atomic_types = [
1957        gpu::ElemType::Int(gpu::IntKind::I32),
1958        gpu::ElemType::Int(gpu::IntKind::I64),
1959        gpu::ElemType::UInt(gpu::UIntKind::U32),
1960        gpu::ElemType::UInt(gpu::UIntKind::U64),
1961        gpu::ElemType::Float(gpu::FloatKind::F32),
1962    ];
1963
1964    for ty in supported_types {
1965        props.register_type_usage(ty, TypeUsage::all_scalar());
1966    }
1967
1968    for ty in supported_atomic_types {
1969        props.register_type_usage(
1970            gpu::StorageType::Atomic(ty),
1971            TypeUsage::AtomicAdd | TypeUsage::AtomicLoadStore,
1972        );
1973    }
1974}