cubecl_cpp/shared/
base.rs

1use std::{collections::HashSet, fmt::Debug};
2
3use cubecl_common::ExecutionMode;
4use cubecl_core::ir::{self as gpu};
5use cubecl_core::ir::{FloatKind, InstructionModes, Processor, UIntKind, VariableKind};
6use cubecl_core::post_processing::checked_io::CheckedIoProcessor;
7use cubecl_core::{CubeDim, ir::ElemType};
8use cubecl_core::{
9    ir::{Operation, SourceLoc},
10    prelude::{FastMath, KernelDefinition},
11};
12use cubecl_opt::{Optimizer, SharedLiveness};
13use cubecl_runtime::compiler::CompilationError;
14use cubecl_runtime::{DeviceProperties, EnumSet, TypeUsage, compiler::Compiler};
15
16use crate::shared::MmaShape;
17
18use super::{
19    BinaryInstruction, Binding, Body, Component, ComputeKernel, ConstArray, Dialect, Elem, FP6Kind,
20    Fragment, FragmentIdent, FragmentLayout, IndexAssignInstruction, IndexInstruction, Instruction,
21    Item, LocalArray, SharedMemory, UnaryInstruction, Variable, WarpInstruction, WmmaInstruction,
22};
23use super::{FP4Kind, barrier::BarrierOps};
24use super::{FP8Kind, pipeline::PipelineOps};
25
// Monotonic counter shared across the module, presumably used to number
// compiler-generated temporary variables (TODO confirm at use sites outside
// this chunk). It is reset to 0 at the end of every `Compiler::compile` call.
pub(super) static COUNTER_TMP_VAR: std::sync::atomic::AtomicU32 =
    std::sync::atomic::AtomicU32::new(0);
28
/// Options controlling how a kernel is lowered to C++ source.
#[derive(Clone, Debug)]
pub struct CompilationOptions {
    /// Number of threads in a warp/plane on the target device (defaults to 32).
    pub warp_size: u32,
    /// Optional hardware/toolchain features the target supports.
    pub supports_features: CppSupportedFeatures,
}
34
/// Feature switches for the target C++ dialect; all default to disabled.
#[derive(Clone, Debug, Default)]
pub struct CppSupportedFeatures {
    /// Kernel parameters can be passed as grid constants.
    pub grid_constants: bool,
    /// Thread-block clusters are available; when false, any requested
    /// cluster dimension is dropped during compilation.
    pub clusters: bool,
    /// Fast-math variants of float operations may be emitted.
    pub fast_math: bool,
    /// A fast tanh approximation may be emitted.
    pub fast_tanh: bool,
    /// `elect.sync` is available; otherwise a fallback election is emitted.
    pub elect_sync: bool,
}
43
44impl Default for CompilationOptions {
45    fn default() -> Self {
46        Self {
47            warp_size: 32,
48            supports_features: Default::default(),
49        }
50    }
51}
52
/// Cube indexes flags.
/// When true the corresponding index is declared and computed as needed in the kernel.
#[derive(Debug, Clone, Default)]
pub struct CubeIndexFlags {
    pub absolute_pos: bool,
    pub absolute_pos_tuple: bool,
    pub cube_count: bool,
    pub cube_count_tuple: bool,
    pub cube_dim: bool,
    pub cube_dim_tuple: bool,
    pub cube_pos: bool,
    pub cube_pos_tuple: bool,
    pub plane_dim: bool,
    // Set whenever any plane (warp) operation is compiled.
    pub plane_dim_checked: bool,
    pub plane_index: bool,
    pub unit_pos: bool,
    pub unit_pos_tuple: bool,
    // Set by inclusive/exclusive plane scan operations, which need the
    // unit's position within its plane.
    pub unit_pos_plane: bool,
    pub cluster_pos: bool,
}
73
/// Flags gathered during Cube IR translation for the kernel compilation.
#[derive(Debug, Clone, Default)]
pub struct Flags {
    // elem_* flags record which exotic element types appear in the kernel so
    // the generated source can include the matching type definitions.
    pub elem_fp4: bool,
    pub elem_fp6: bool,
    pub elem_fp8: bool,
    pub elem_bf16: bool,
    pub elem_f16: bool,
    pub elem_tf32: bool,
    /// Which builtin index variables the kernel body actually uses.
    pub indexes: CubeIndexFlags,
    /// Kernel uses barrier operations.
    pub op_barrier: bool,
    /// Kernel uses pipeline operations.
    pub op_pipeline: bool,
    /// Kernel uses TMA (tensor memory accelerator) instructions.
    pub inst_tma: bool,
    /// Kernel uses the TMA im2col load variant.
    pub inst_tma_im2col: bool,
    /// Kernel uses WMMA/CMMA matrix instructions.
    pub inst_wmma: bool,
    /// Kernel needs hand-written PTX wrapper helpers (e.g. for `elect.sync`).
    pub inst_ptx_wrappers: bool,
    /// Pass kernel parameters as grid constants (feature-gated).
    pub use_grid_constants: bool,
    /// Length of the statically known metadata section.
    pub static_meta_length: usize,
    /// Whether dynamic metadata must be passed to the kernel.
    pub has_dynamic_meta: bool,
    pub cube_dim: CubeDim,
    /// Thread-block cluster dimensions, when clusters are used and supported.
    pub cluster_dim: Option<CubeDim>,
}
96
/// State accumulated while translating a Cube IR kernel into the C++ AST.
///
/// The compiler is single-use per kernel: `Compiler::compile` clones it and
/// the clone is consumed by `compile_ir`.
// NOTE(review): `clippy::too_many_arguments` targets functions; on a struct
// this allow looks ineffective — confirm and consider removing.
#[allow(clippy::too_many_arguments)]
#[derive(Clone, Debug, Default)]
pub struct CppCompiler<D: Dialect> {
    // Barrier declarations collected from the kernel body.
    barriers: Vec<BarrierOps<D>>,
    compilation_options: CompilationOptions,
    // Constant arrays drained from compiled scopes.
    const_arrays: Vec<ConstArray<D>>,
    // For each buffer/tensor-map (ordered by binding id), the index of its
    // extended-metadata slot; built by `build_metadata`.
    ext_meta_positions: Vec<u32>,
    cluster_dim: CubeDim,
    // Dialect-specific extension snippets required by compiled instructions.
    extensions: Vec<D::Extension>,
    flags: Flags,
    // Every item (type/vectorization pair) used by the kernel.
    items: HashSet<Item<D>>,
    local_arrays: Vec<LocalArray<D>>,
    metadata: cubecl_core::Metadata,
    pipelines: Vec<PipelineOps<D>>,
    // Last emitted debug source location, to avoid duplicate line markers.
    source_loc: Option<SourceLoc>,
    strategy: ExecutionMode,
}
114
impl<D: Dialect> Compiler for CppCompiler<D> {
    type Representation = ComputeKernel<D>;
    type CompilationOptions = CompilationOptions;

    /// Compile a kernel definition into the C++ kernel representation.
    fn compile(
        &mut self,
        mut kernel: KernelDefinition,
        compilation_options: &Self::CompilationOptions,
        strategy: ExecutionMode,
    ) -> Result<Self::Representation, CompilationError> {
        self.compilation_options = compilation_options.clone();
        self.strategy = strategy;

        // Silently drop cluster dimensions when the target has no cluster
        // support, so the rest of compilation never sees them.
        if !self.compilation_options.supports_features.clusters {
            kernel.options.cluster_dim = None;
        }
        self.cluster_dim = kernel.options.cluster_dim.unwrap_or(CubeDim::new_single());

        // `compile_ir` consumes the compiler; clone so `self` (and the state
        // just stored on it) survives for subsequent `compile` calls.
        let ir = self.clone().compile_ir(kernel);
        // Reset the global temp-variable counter so each kernel starts from 0.
        COUNTER_TMP_VAR.store(0, std::sync::atomic::Ordering::Relaxed);
        Ok(ir)
    }

    /// Byte size of a scalar element type.
    fn elem_size(&self, elem: gpu::ElemType) -> usize {
        elem.size()
    }

    /// File extension for the generated source.
    fn extension(&self) -> &'static str {
        "cpp"
    }
}
146
147impl<D: Dialect> CppCompiler<D> {
    /// Consume the compiler and lower a full kernel definition into a
    /// [`ComputeKernel`]: bindings, compiled body, shared-memory layout,
    /// translation flags and collected extensions.
    fn compile_ir(mut self, value: KernelDefinition) -> ComputeKernel<D> {
        self.build_metadata(&value);

        // The body is cloned because `value.body` is consumed again below by
        // the shared-memory liveness optimizer.
        let instructions = self.compile_scope(&mut value.body.clone());
        let tensor_maps = value
            .tensor_maps
            .into_iter()
            .map(|b| self.compile_binding(b))
            .collect();
        let buffers = value
            .buffers
            .into_iter()
            .map(|b| self.compile_binding(b))
            .collect();
        let scalars = value
            .scalars
            .into_iter()
            .map(|binding| (self.compile_storage_type(binding.ty), binding.count))
            .collect();

        // translation flags
        // Snapshot the flags gathered while compiling the body, resolving the
        // builtin-index dependencies through the dialect's rules.
        let flags = Flags {
            indexes: D::builtin_rules(&self.flags.indexes),
            inst_wmma: self.flags.inst_wmma,
            op_pipeline: self.flags.op_pipeline,
            op_barrier: self.flags.op_barrier,
            elem_fp4: self.flags.elem_fp4,
            elem_fp6: self.flags.elem_fp6,
            elem_fp8: self.flags.elem_fp8,
            elem_bf16: self.flags.elem_bf16,
            elem_f16: self.flags.elem_f16,
            elem_tf32: self.flags.elem_tf32,
            inst_tma: self.flags.inst_tma,
            inst_tma_im2col: self.flags.inst_tma_im2col,
            inst_ptx_wrappers: self.flags.inst_ptx_wrappers,
            use_grid_constants: self.compilation_options.supports_features.grid_constants,
            // TODO: At some point we should only pass dynamic meta if tensors are present,
            // not if only arrays are present. For now, leave like this
            has_dynamic_meta: self.metadata.static_len() > 0,
            static_meta_length: self.metadata.static_len() as usize,
            cube_dim: value.cube_dim,
            cluster_dim: value.options.cluster_dim,
        };

        // Compute shared-memory allocations (with offsets) via the
        // shared-liveness analysis of the optimizer.
        let mut opt = Optimizer::shared_only(value.body, value.cube_dim);
        let shared_allocs = opt.analysis::<SharedLiveness>();
        let shared_memories = shared_allocs
            .allocations
            .values()
            .map(|alloc| SharedMemory {
                index: alloc.smem.id,
                item: self.compile_type(alloc.smem.ty),
                length: alloc.smem.length,
                align: alloc.smem.align,
                offset: alloc.offset,
            })
            .collect();

        let body = Body {
            instructions,
            shared_memories,
            pipelines: self.pipelines,
            barriers: self.barriers,
            const_arrays: self.const_arrays,
            local_arrays: self.local_arrays,
        };

        // Same guard as in `compile`: no clusters unless the target supports
        // them (defensive in case `compile_ir` is reached another way).
        let mut cluster_dim = value.options.cluster_dim;
        if !self.compilation_options.supports_features.clusters {
            cluster_dim = None;
        }

        ComputeKernel {
            tensor_maps,
            buffers,
            scalars,
            meta_static_len: self.metadata.static_len() as usize,
            cube_dim: value.cube_dim,
            body,
            extensions: self.extensions,
            flags,
            items: self.items,
            kernel_name: value.options.kernel_name,
            cluster_dim,
        }
    }
234
235    fn build_metadata(&mut self, value: &KernelDefinition) {
236        let mut num_ext = 0;
237
238        let mut all_meta: Vec<_> = value
239            .buffers
240            .iter()
241            .chain(value.tensor_maps.iter())
242            .map(|buf| (buf.id, buf.has_extended_meta))
243            .collect();
244
245        all_meta.sort_by_key(|(id, _)| *id);
246
247        for (_, has_extended_meta) in &all_meta {
248            self.ext_meta_positions.push(num_ext);
249            if *has_extended_meta {
250                num_ext += 1;
251            }
252        }
253
254        let num_meta = all_meta.len();
255
256        self.metadata = cubecl_core::Metadata::new(num_meta as u32, num_ext);
257    }
258
259    pub(crate) fn ext_meta_position(&self, var: gpu::Variable) -> u32 {
260        let id = var.index().expect("Variable should have index");
261        self.ext_meta_positions[id as usize]
262    }
263
    /// Compile all instructions of a scope, returning the generated C++
    /// instructions.
    ///
    /// Side effects on `self`: the scope's const arrays are drained into
    /// `self.const_arrays`, and translation flags are updated by the
    /// per-instruction compilation.
    fn compile_scope(&mut self, scope: &mut gpu::Scope) -> Vec<Instruction<D>> {
        let mut instructions = Vec::new();

        let const_arrays = scope
            .const_arrays
            .drain(..)
            .map(|(var, values)| ConstArray {
                index: var.index().unwrap(),
                item: self.compile_type(var.ty),
                size: values.len() as u32,
                values: values
                    .into_iter()
                    .map(|val| self.compile_variable(val))
                    .collect(),
            })
            .collect::<Vec<_>>();
        self.const_arrays.extend(const_arrays);

        // Run the checked-IO processor (bounds checking per execution mode)
        // first, then any dialect-specific processors, over the scope.
        let checked_io: Box<dyn Processor> = Box::new(CheckedIoProcessor::new(self.strategy));
        let dialect_processors = D::processors();
        let mut processors: Vec<&dyn Processor> = vec![&*checked_io];
        processors.extend(dialect_processors.iter().map(|it| &**it));

        let processing = scope.process(processors);

        // Declare every local variable the processed scope requires before
        // emitting the instructions that use them.
        for var in processing.variables {
            instructions.push(Instruction::DeclareVariable {
                var: self.compile_variable(var),
            });
        }

        processing
            .instructions
            .into_iter()
            .for_each(|op| self.compile_instruction(&mut instructions, op));

        instructions
    }
302
    /// Translate one Cube IR instruction into zero or more C++ instructions,
    /// appending them to `instructions`.
    ///
    /// Also records on `self.flags` which features the kernel needs (plane
    /// indexes, TMA, PTX wrappers, im2col, ...) as the corresponding
    /// operations are encountered.
    fn compile_instruction(
        &mut self,
        instructions: &mut Vec<Instruction<D>>,
        instruction: gpu::Instruction,
    ) {
        // Emit a line marker first if the debug source location changed.
        self.update_debug_loc(instructions, &instruction);
        let out = instruction.out;
        match instruction.operation {
            gpu::Operation::Copy(variable) => {
                instructions.push(Instruction::Assign(UnaryInstruction {
                    input: self.compile_variable(variable),
                    out: self.compile_variable(out.unwrap()),
                }));
            }
            // Simple operation families are delegated to dedicated helpers.
            gpu::Operation::Arithmetic(op) => {
                self.compile_arithmetic(op, out, instruction.modes, instructions)
            }
            gpu::Operation::Comparison(op) => self.compile_comparison(op, out, instructions),
            gpu::Operation::Bitwise(op) => self.compile_bitwise(op, out, instructions),
            gpu::Operation::Operator(op) => self.compile_operator(op, out, instructions),
            gpu::Operation::Atomic(op) => self.compile_atomic(op, out, instructions),
            gpu::Operation::Metadata(op) => instructions.push(self.compile_metadata(op, out)),
            gpu::Operation::Branch(val) => self.compile_branch(instructions, val),
            gpu::Operation::Synchronization(val) => match val {
                gpu::Synchronization::SyncCube => instructions.push(Instruction::SyncThreads),
                gpu::Synchronization::SyncPlane => instructions.push(Instruction::SyncWarp),
                // NOTE(review): SyncStorage lowers to the same SyncThreads as
                // SyncCube — confirm this is intended for this backend.
                gpu::Synchronization::SyncStorage => instructions.push(Instruction::SyncThreads),
                gpu::Synchronization::SyncAsyncProxyShared => {
                    self.flags.inst_tma = true;
                    instructions.push(Instruction::ProxyAsyncToSharedFence)
                }
            },
            gpu::Operation::Plane(op) => {
                // Every plane op needs the checked plane dimension builtin.
                self.flags.indexes.plane_dim_checked = true;
                let out = self.compile_variable(out.unwrap());
                match op {
                    gpu::Plane::Sum(op) => {
                        let instruction = WarpInstruction::ReduceSum {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction));
                    }
                    // Scan operations additionally need the unit's position
                    // within its plane.
                    gpu::Plane::InclusiveSum(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::InclusiveSum {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::InclusiveProd(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::InclusiveProd {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::ExclusiveSum(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::ExclusiveSum {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::ExclusiveProd(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::ExclusiveProd {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Prod(op) => {
                        let instruction = WarpInstruction::ReduceProd {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Max(op) => {
                        let instruction = WarpInstruction::ReduceMax {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Min(op) => {
                        let instruction = WarpInstruction::ReduceMin {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Elect => {
                        // Use the native elect instruction when supported
                        // (requires PTX wrappers), otherwise a fallback.
                        if self.compilation_options.supports_features.elect_sync {
                            self.flags.inst_ptx_wrappers = true;
                            instructions.push(Instruction::Warp(WarpInstruction::Elect { out }))
                        } else {
                            instructions
                                .push(Instruction::Warp(WarpInstruction::ElectFallback { out }))
                        }
                    }
                    gpu::Plane::All(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::All {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Any(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Any {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Ballot(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Ballot {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Broadcast(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Broadcast {
                            input: self.compile_variable(op.lhs),
                            id: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::Shuffle(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Shuffle {
                            input: self.compile_variable(op.lhs),
                            src_lane: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::ShuffleXor(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleXor {
                            input: self.compile_variable(op.lhs),
                            mask: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::ShuffleUp(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleUp {
                            input: self.compile_variable(op.lhs),
                            delta: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::ShuffleDown(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleDown {
                            input: self.compile_variable(op.lhs),
                            delta: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                }
            }
            gpu::Operation::CoopMma(cmma) => instructions.push(self.compile_cmma(cmma, out)),
            // Debug-only operations: printing and comments.
            gpu::Operation::NonSemantic(debug) => match debug {
                gpu::NonSemantic::Print {
                    format_string,
                    args,
                } => instructions.push(Instruction::Printf {
                    format_string,
                    args: args
                        .into_iter()
                        .map(|arg| self.compile_variable(arg))
                        .collect(),
                }),
                gpu::NonSemantic::Comment { content } => {
                    instructions.push(Instruction::Comment { content })
                }
                // Don't need to handle scopes
                _ => {}
            },
            gpu::Operation::Barrier(barrier_ops) => match barrier_ops {
                gpu::BarrierOps::Declare { barrier } => {
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    let barrier = self.compile_variable(barrier);
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Declare {
                        barrier,
                        level,
                    }));
                }
                gpu::BarrierOps::Init {
                    barrier,
                    is_elected,
                    arrival_count,
                    with_async_proxy_fence,
                } => {
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    let barrier = self.compile_variable(barrier);
                    let arrival_count = self.compile_variable(arrival_count);
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Init {
                        barrier,
                        is_elected: self.compile_variable(is_elected),
                        arrival_count,
                        level,
                        with_async_proxy_fence,
                    }));
                }
                gpu::BarrierOps::InitManual {
                    barrier,
                    arrival_count,
                } => {
                    let barrier = self.compile_variable(barrier);
                    let arrival_count = self.compile_variable(arrival_count);
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::InitManual {
                            barrier,
                            arrival_count,
                        },
                    ));
                }
                // Cooperative and non-cooperative async copies lower to the
                // same target op, distinguished only by the flag.
                gpu::BarrierOps::MemCopyAsync {
                    barrier,
                    source,
                    source_length,
                    offset_source,
                    offset_out,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsync {
                            barrier: self.compile_variable(barrier),
                            source: self.compile_variable(source),
                            destination: self.compile_variable(out.unwrap()),
                            source_length: self.compile_variable(source_length),
                            offset_source: self.compile_variable(offset_source),
                            offset_out: self.compile_variable(offset_out),
                            cooperative: false,
                        },
                    ));
                }
                gpu::BarrierOps::MemCopyAsyncCooperative {
                    barrier,
                    source,
                    source_length,
                    offset_source,
                    offset_out,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsync {
                            barrier: self.compile_variable(barrier),
                            source: self.compile_variable(source),
                            destination: self.compile_variable(out.unwrap()),
                            source_length: self.compile_variable(source_length),
                            offset_source: self.compile_variable(offset_source),
                            offset_out: self.compile_variable(offset_out),
                            cooperative: true,
                        },
                    ));
                }
                gpu::BarrierOps::MemCopyAsyncTx {
                    barrier,
                    source,
                    source_length,
                    offset_source,
                    offset_out,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsyncTx {
                            barrier: self.compile_variable(barrier),
                            source: self.compile_variable(source),
                            destination: self.compile_variable(out.unwrap()),
                            source_length: self.compile_variable(source_length),
                            offset_source: self.compile_variable(offset_source),
                            offset_out: self.compile_variable(offset_out),
                        },
                    ));
                }
                gpu::BarrierOps::TmaLoad {
                    barrier,
                    tensor_map,
                    offset_out,
                    indices,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsyncTensorGlobalToShared {
                            barrier: self.compile_variable(barrier),
                            smem_buffer: self.compile_variable(out.unwrap()),
                            smem_offset: self.compile_variable(offset_out),
                            tensor_map: self.compile_variable(tensor_map),
                            indices: indices
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        },
                    ));
                }
                gpu::BarrierOps::TmaLoadIm2col {
                    barrier,
                    tensor_map,
                    offset_out,
                    indices,
                    offsets,
                } => {
                    self.flags.inst_tma_im2col = true;
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::TmaLoadIm2col {
                            barrier: self.compile_variable(barrier),
                            smem_buffer: self.compile_variable(out.unwrap()),
                            smem_offset: self.compile_variable(offset_out),
                            tensor_map: self.compile_variable(tensor_map),
                            indices: indices
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                            offsets: offsets
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        },
                    ));
                }
                gpu::BarrierOps::Arrive { barrier } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Arrive {
                        barrier: self.compile_variable(barrier),
                        token: self.compile_variable(out.unwrap()),
                    }))
                }
                gpu::BarrierOps::ArriveTx {
                    barrier,
                    arrive_count_update,
                    transaction_count_update,
                } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::ArriveTx {
                        barrier: self.compile_variable(barrier),
                        token: self.compile_variable(out.unwrap()),
                        arrive_count_update: self.compile_variable(arrive_count_update),
                        transaction_count_update: self.compile_variable(transaction_count_update),
                    }))
                }
                gpu::BarrierOps::ExpectTx {
                    barrier,
                    transaction_count_update,
                } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::ExpectTx {
                        barrier: self.compile_variable(barrier),
                        transaction_count_update: self.compile_variable(transaction_count_update),
                    }))
                }
                gpu::BarrierOps::Wait { barrier, token } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Wait {
                        barrier: self.compile_variable(barrier),
                        token: self.compile_variable(token),
                    }))
                }
                gpu::BarrierOps::WaitParity { barrier, phase } => instructions.push(
                    Instruction::Barrier(super::barrier::BarrierOps::WaitParity {
                        barrier: self.compile_variable(barrier),
                        phase: self.compile_variable(phase),
                    }),
                ),
                gpu::BarrierOps::ArriveAndWait { barrier } => {
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::ArriveAndWait {
                            barrier: self.compile_variable(barrier),
                            level,
                        },
                    ))
                }
            },
            gpu::Operation::Tma(tma_ops) => {
                // Any TMA op requires the TMA feature flag.
                self.flags.inst_tma = true;
                match tma_ops {
                    gpu::TmaOps::TmaStore {
                        source,
                        coordinates,
                        offset_source,
                    } => {
                        instructions.push(Instruction::MemCopyAsyncTensorSharedToGlobal {
                            smem_buffer: self.compile_variable(source),
                            smem_offset: self.compile_variable(offset_source),
                            tensor_map: self.compile_variable(out.unwrap()),
                            indices: coordinates
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        });
                    }
                    gpu::TmaOps::CommitGroup => {
                        instructions.push(Instruction::BulkCommitGroup);
                    }
                    gpu::TmaOps::WaitGroup { max_pending } => {
                        instructions.push(Instruction::BulkWaitGroup { max_pending });
                    }
                    gpu::TmaOps::WaitGroupRead { max_pending } => {
                        instructions.push(Instruction::BulkWaitGroupRead { max_pending });
                    }
                }
            }
            // Markers carry no codegen.
            gpu::Operation::Marker(_) => {}
        }
    }
707
708    fn update_debug_loc(
709        &mut self,
710        instructions: &mut Vec<Instruction<D>>,
711        inst: &gpu::Instruction,
712    ) {
713        if !matches!(inst.operation, Operation::NonSemantic(_)) {
714            match &inst.source_loc {
715                Some(loc) if Some(loc) != self.source_loc.as_ref() => {
716                    self.source_loc = Some(loc.clone());
717                    instructions.push(Instruction::Line {
718                        file: loc.source.file.clone(),
719                        line: loc.line,
720                    });
721                }
722                _ => {}
723            }
724        }
725    }
726
    /// Lowers a cooperative-matrix (CMMA) operation into a dialect-level
    /// [`Instruction::Wmma`], flagging the kernel as requiring WMMA support and
    /// letting the dialect register any instruction extensions it needs.
    ///
    /// `out` is the operation's destination and must be present; depending on
    /// the variant it is the fragment being filled/loaded/computed into, or the
    /// buffer being stored to.
    ///
    /// # Panics
    /// Panics if `out` is `None`, if a `Store` carries no layout, or on
    /// `RowIndex`/`ColIndex` (those are expected to be eliminated by the
    /// processors before reaching this compiler).
    fn compile_cmma(&mut self, cmma: gpu::CoopMma, out: Option<gpu::Variable>) -> Instruction<D> {
        // Any CMMA op means the generated kernel must enable WMMA codegen.
        self.flags.inst_wmma = true;

        let out = self.compile_variable(out.unwrap());

        let inst = match cmma {
            // Broadcast a single value into every element of the fragment.
            gpu::CoopMma::Fill { value } => WmmaInstruction::Fill {
                frag: out,
                value: self.compile_variable(value),
            },
            gpu::CoopMma::Load {
                value,
                stride,
                offset,
                layout,
            } => WmmaInstruction::Load {
                frag: out,
                offset: self.compile_variable(offset),
                value: self.compile_variable(value),
                stride: self.compile_variable(stride),
                // A load layout is optional; it may be implied by the fragment.
                layout: layout.and_then(|l| self.compile_matrix_layout(l)),
            },
            // D = A * B + C using the builtin fragment-level MMA.
            gpu::CoopMma::Execute {
                mat_a,
                mat_b,
                mat_c,
            } => WmmaInstruction::Execute {
                frag_a: self.compile_variable(mat_a),
                frag_b: self.compile_variable(mat_b),
                frag_c: self.compile_variable(mat_c),
                frag_d: out,
                warp_size: self.compilation_options.warp_size,
            },
            // MMA on manually managed registers with an explicit m/n/k shape.
            gpu::CoopMma::ExecuteManual {
                matrix,
                registers_a,
                registers_b,
                registers_c,
            } => WmmaInstruction::ExecuteManual {
                shape: MmaShape::new(matrix.m, matrix.n, matrix.k),
                frag_a: self.compile_variable(registers_a),
                frag_b: self.compile_variable(registers_b),
                frag_c: self.compile_variable(registers_c),
                frag_d: out,
            },
            // Manual MMA with per-operand scaling factors (block-scaled MMA).
            gpu::CoopMma::ExecuteScaled {
                matrix,
                registers_a,
                registers_b,
                registers_c,
                scales_a,
                scales_b,
                scales_factor,
            } => WmmaInstruction::ExecuteScaled {
                shape: MmaShape::new(matrix.m, matrix.n, matrix.k),
                frag_a: self.compile_variable(registers_a),
                frag_b: self.compile_variable(registers_b),
                frag_c: self.compile_variable(registers_c),
                frag_d: out,

                scales_a: self.compile_variable(scales_a),
                scales_b: self.compile_variable(scales_b),
                scales_factor,
            },
            gpu::CoopMma::Store {
                mat,
                stride,
                offset,
                layout,
            } => {
                // NOTE(review): the store lowering appears to address memory per
                // unit/plane, hence these index flags — confirm in the dialect.
                self.flags.indexes.unit_pos = true;
                self.flags.indexes.plane_index = true;
                WmmaInstruction::Store {
                    output: out,
                    offset: self.compile_variable(offset),
                    frag: self.compile_variable(mat),
                    stride: self.compile_variable(stride),
                    // Unlike `Load`, a store's layout is mandatory.
                    layout: self
                        .compile_matrix_layout(layout)
                        .expect("Layout required for store instruction"),
                }
            }
            gpu::CoopMma::LoadMatrix {
                buffer,
                offset,
                line_size,
                factor,
                transpose,
            } => WmmaInstruction::LdMatrix {
                output: out,
                buffer: self.compile_variable(buffer),
                offset: self.compile_variable(offset),
                line_size,
                factor,
                transpose,
            },
            gpu::CoopMma::StoreMatrix {
                offset,
                line_size,
                registers,
                factor,
                transpose,
            } => WmmaInstruction::StMatrix {
                registers: self.compile_variable(registers),
                buffer: out,
                offset: self.compile_variable(offset),
                line_size,
                factor,
                transpose,
            },
            // Element-type conversion between fragments.
            gpu::CoopMma::Cast { input } => WmmaInstruction::Cast {
                input: self.compile_variable(input),
                output: out,
            },
            gpu::CoopMma::RowIndex { .. } | gpu::CoopMma::ColIndex { .. } => {
                panic!("Row/Col index should be handled by processors")
            }
        };

        // Give the dialect a chance to record required extensions (e.g. headers
        // or helper functions) for this WMMA instruction.
        D::register_wmma_instruction_extension(&mut self.extensions, &inst);

        Instruction::Wmma(inst)
    }
850
851    fn compile_metadata(
852        &mut self,
853        metadata: gpu::Metadata,
854        out: Option<gpu::Variable>,
855    ) -> Instruction<D> {
856        let out = out.unwrap();
857        match metadata {
858            gpu::Metadata::Stride { dim, var } => {
859                let position = self.ext_meta_position(var);
860                let offset = self.metadata.stride_offset_index(position);
861                Instruction::ExtendedMetadata {
862                    info_offset: self.compile_variable(offset.into()),
863                    dim: self.compile_variable(dim),
864                    split_meta: self.compilation_options.supports_features.grid_constants,
865                    static_offset: self.metadata.static_len(),
866                    out: self.compile_variable(out),
867                }
868            }
869            gpu::Metadata::Shape { dim, var } => {
870                let position = self.ext_meta_position(var);
871                let offset = self.metadata.shape_offset_index(position);
872                Instruction::ExtendedMetadata {
873                    info_offset: self.compile_variable(offset.into()),
874                    dim: self.compile_variable(dim),
875                    split_meta: self.compilation_options.supports_features.grid_constants,
876                    static_offset: self.metadata.static_len(),
877                    out: self.compile_variable(out),
878                }
879            }
880            gpu::Metadata::Rank { var } => {
881                let out = self.compile_variable(out);
882                let pos = self.ext_meta_position(var);
883                let offset = self.metadata.rank_index(pos);
884                super::Instruction::Metadata {
885                    info_offset: self.compile_variable(offset.into()),
886                    split_meta: self.compilation_options.supports_features.grid_constants,
887                    out,
888                }
889            }
890            gpu::Metadata::Length { var } => {
891                let input = self.compile_variable(var);
892                let out = self.compile_variable(out);
893
894                match input {
895                    Variable::Slice { .. } => Instruction::SliceLength { input, out },
896                    Variable::SharedMemory(_id, _item, length) => {
897                        Instruction::ConstLength { length, out }
898                    }
899                    _ => {
900                        let id = input.id().expect("Variable should have id");
901                        let offset = self.metadata.len_index(id);
902                        Instruction::Metadata {
903                            info_offset: self.compile_variable(offset.into()),
904                            split_meta: self.compilation_options.supports_features.grid_constants,
905                            out,
906                        }
907                    }
908                }
909            }
910            gpu::Metadata::BufferLength { var } => {
911                let input = self.compile_variable(var);
912                let out = self.compile_variable(out);
913
914                match input {
915                    Variable::Slice { .. } => Instruction::SliceLength { input, out },
916                    _ => {
917                        let id = input.id().expect("Variable should have id");
918                        let offset = self.metadata.buffer_len_index(id);
919                        Instruction::Metadata {
920                            info_offset: self.compile_variable(offset.into()),
921                            split_meta: self.compilation_options.supports_features.grid_constants,
922                            out,
923                        }
924                    }
925                }
926            }
927        }
928    }
929
930    fn compile_branch(&mut self, instructions: &mut Vec<Instruction<D>>, branch: gpu::Branch) {
931        match branch {
932            gpu::Branch::If(mut op) => instructions.push(Instruction::If {
933                cond: self.compile_variable(op.cond),
934                instructions: self.compile_scope(&mut op.scope),
935            }),
936            gpu::Branch::IfElse(mut op) => instructions.push(Instruction::IfElse {
937                cond: self.compile_variable(op.cond),
938                instructions_if: self.compile_scope(&mut op.scope_if),
939                instructions_else: self.compile_scope(&mut op.scope_else),
940            }),
941            gpu::Branch::Switch(mut op) => instructions.push(Instruction::Switch {
942                value: self.compile_variable(op.value),
943                instructions_default: self.compile_scope(&mut op.scope_default),
944                instructions_cases: op
945                    .cases
946                    .into_iter()
947                    .map(|(val, mut block)| {
948                        (self.compile_variable(val), self.compile_scope(&mut block))
949                    })
950                    .collect(),
951            }),
952            gpu::Branch::Return => instructions.push(Instruction::Return),
953            gpu::Branch::Break => instructions.push(Instruction::Break),
954            gpu::Branch::RangeLoop(mut range_loop) => instructions.push(Instruction::RangeLoop {
955                i: self.compile_variable(range_loop.i),
956                start: self.compile_variable(range_loop.start),
957                end: self.compile_variable(range_loop.end),
958                step: range_loop.step.map(|it| self.compile_variable(it)),
959                inclusive: range_loop.inclusive,
960                instructions: self.compile_scope(&mut range_loop.scope),
961            }),
962            gpu::Branch::Loop(mut op) => instructions.push(Instruction::Loop {
963                instructions: self.compile_scope(&mut op.scope),
964            }),
965        };
966    }
967
968    fn compile_atomic(
969        &mut self,
970        value: gpu::AtomicOp,
971        out: Option<gpu::Variable>,
972        instructions: &mut Vec<Instruction<D>>,
973    ) {
974        let out = out.unwrap();
975        match value {
976            gpu::AtomicOp::Load(op) => {
977                instructions.push(Instruction::AtomicLoad(self.compile_unary(op, out)))
978            }
979            gpu::AtomicOp::Store(op) => {
980                instructions.push(Instruction::AtomicStore(self.compile_unary(op, out)))
981            }
982            gpu::AtomicOp::Swap(op) => {
983                instructions.push(Instruction::AtomicSwap(self.compile_binary(op, out)))
984            }
985            gpu::AtomicOp::Add(op) => {
986                instructions.push(Instruction::AtomicAdd(self.compile_binary(op, out)))
987            }
988            gpu::AtomicOp::Sub(op) => {
989                instructions.push(Instruction::AtomicSub(self.compile_binary(op, out)))
990            }
991            gpu::AtomicOp::Max(op) => {
992                instructions.push(Instruction::AtomicMax(self.compile_binary(op, out)))
993            }
994            gpu::AtomicOp::Min(op) => {
995                instructions.push(Instruction::AtomicMin(self.compile_binary(op, out)))
996            }
997            gpu::AtomicOp::And(op) => {
998                instructions.push(Instruction::AtomicAnd(self.compile_binary(op, out)))
999            }
1000            gpu::AtomicOp::Or(op) => {
1001                instructions.push(Instruction::AtomicOr(self.compile_binary(op, out)))
1002            }
1003            gpu::AtomicOp::Xor(op) => {
1004                instructions.push(Instruction::AtomicXor(self.compile_binary(op, out)))
1005            }
1006            gpu::AtomicOp::CompareAndSwap(op) => instructions.push(Instruction::AtomicCAS {
1007                input: self.compile_variable(op.input),
1008                cmp: self.compile_variable(op.cmp),
1009                val: self.compile_variable(op.val),
1010                out: self.compile_variable(out),
1011            }),
1012        }
1013    }
1014
    /// Lowers an arithmetic operation into one or more dialect instructions.
    ///
    /// For float operations with a fast-math variant, [`Self::select_fast_float`]
    /// picks between the precise and the fast lowering based on the
    /// instruction's fp-math `modes` and the target's capabilities. Instructions
    /// that may need dialect support code are passed to
    /// `D::register_instruction_extension`.
    ///
    /// # Panics
    /// Panics if `out` is `None`.
    fn compile_arithmetic(
        &mut self,
        value: gpu::Arithmetic,
        out: Option<gpu::Variable>,
        modes: InstructionModes,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::Arithmetic::Add(op) => {
                instructions.push(Instruction::Add(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::SaturatingAdd(op) => {
                instructions.push(Instruction::SaturatingAdd(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Mul(op) => {
                instructions.push(Instruction::Mul(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Div(op) => {
                let op = self.compile_binary(op, out);
                // Fast division implies reciprocal-based, reduced-precision math.
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::AllowReciprocal
                        | FastMath::ReducedPrecision
                        | FastMath::UnsignedZero
                        | FastMath::NotInf,
                    Instruction::Div(op),
                    Instruction::FastDiv(op),
                ))
            }
            gpu::Arithmetic::Sub(op) => {
                instructions.push(Instruction::Sub(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::SaturatingSub(op) => {
                instructions.push(Instruction::SaturatingSub(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::MulHi(op) => {
                let instruction = Instruction::HiMul(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Modulo(op) => {
                instructions.push(Instruction::Modulo(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Abs(op) => {
                instructions.push(Instruction::Abs(self.compile_unary(op, out)))
            }
            // Transcendental functions with fast-math counterparts all require
            // the same relaxation flags: reduced precision, no NaN, no Inf.
            gpu::Arithmetic::Exp(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Exp(op),
                    Instruction::FastExp(op),
                ));
            }
            gpu::Arithmetic::Log(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Log(op),
                    Instruction::FastLog(op),
                ));
            }
            gpu::Arithmetic::Log1p(op) => {
                instructions.push(Instruction::Log1p(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Cos(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Cos(op),
                    Instruction::FastCos(op),
                ));
            }
            gpu::Arithmetic::Sin(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Sin(op),
                    Instruction::FastSin(op),
                ));
            }
            gpu::Arithmetic::Tan(op) => {
                instructions.push(Instruction::Tan(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Tanh(op) => {
                let op = self.compile_unary(op, out);
                let instruction = Instruction::Tanh(op);
                // The extension is registered for the precise variant even if
                // the fast one ends up selected below.
                D::register_instruction_extension(&mut self.extensions, &instruction);
                // Fast tanh is only available on targets that advertise it.
                if self.compilation_options.supports_features.fast_tanh {
                    instructions.push(self.select_fast_float(
                        out.ty,
                        modes,
                        FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                        instruction,
                        Instruction::FastTanh(op),
                    ))
                } else {
                    instructions.push(instruction);
                }
            }
            gpu::Arithmetic::Sinh(op) => {
                let instruction = Instruction::Sinh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Cosh(op) => {
                let instruction = Instruction::Cosh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcCos(op) => {
                let instruction = Instruction::ArcCos(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcSin(op) => {
                let instruction = Instruction::ArcSin(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcTan(op) => {
                let instruction = Instruction::ArcTan(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcSinh(op) => {
                let instruction = Instruction::ArcSinh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcCosh(op) => {
                let instruction = Instruction::ArcCosh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcTanh(op) => {
                let instruction = Instruction::ArcTanh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Degrees(op) => {
                let instruction = Instruction::Degrees(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Radians(op) => {
                let instruction = Instruction::Radians(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::ArcTan2(op) => {
                let instruction = Instruction::ArcTan2(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Powf(op) => {
                let op = self.compile_binary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Powf(op),
                    Instruction::FastPowf(op),
                ))
            }
            gpu::Arithmetic::Powi(op) => {
                instructions.push(Instruction::Powi(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Sqrt(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Sqrt(op),
                    Instruction::FastSqrt(op),
                ))
            }
            gpu::Arithmetic::InverseSqrt(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::InverseSqrt(op),
                    Instruction::FastInverseSqrt(op),
                ))
            }
            gpu::Arithmetic::Erf(op) => {
                let instruction = Instruction::Erf(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Max(op) => {
                let instruction = Instruction::Max(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Min(op) => {
                let instruction = Instruction::Min(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Clamp(op) => instructions.push(Instruction::Clamp {
                input: self.compile_variable(op.input),
                min_value: self.compile_variable(op.min_value),
                max_value: self.compile_variable(op.max_value),
                out: self.compile_variable(out),
            }),
            gpu::Arithmetic::Recip(op) => {
                // Recip is lowered as either `1 / x` (precise, for any element
                // type) or a dedicated fast reciprocal when fast-math allows it.
                let elem = op.input.ty.elem_type();
                let input = self.compile_variable(op.input);
                let out = self.compile_variable(out);
                // Build the typed constant `1` matching the input element type.
                let lhs = match elem {
                    gpu::ElemType::Float(kind) => gpu::ConstantScalarValue::Float(1.0, kind),
                    gpu::ElemType::Int(kind) => gpu::ConstantScalarValue::Int(1, kind),
                    gpu::ElemType::UInt(kind) => gpu::ConstantScalarValue::UInt(1, kind),
                    gpu::ElemType::Bool => gpu::ConstantScalarValue::Bool(true),
                };
                let div = Instruction::Div(BinaryInstruction {
                    lhs: Variable::ConstantScalar(lhs, self.compile_elem(elem)),
                    rhs: input,
                    out,
                });
                let recip = Instruction::FastRecip(UnaryInstruction { input, out });

                instructions.push(self.select_fast_float(
                    elem.into(),
                    modes,
                    FastMath::AllowReciprocal
                        | FastMath::ReducedPrecision
                        | FastMath::UnsignedZero
                        | FastMath::NotInf,
                    div,
                    recip,
                ))
            }
            gpu::Arithmetic::Round(op) => {
                instructions.push(Instruction::Round(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Floor(op) => {
                instructions.push(Instruction::Floor(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Ceil(op) => {
                instructions.push(Instruction::Ceil(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Trunc(op) => {
                instructions.push(Instruction::Trunc(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Remainder(op) => {
                instructions.push(Instruction::Remainder(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Fma(op) => instructions.push(Instruction::Fma {
                a: self.compile_variable(op.a),
                b: self.compile_variable(op.b),
                c: self.compile_variable(op.c),
                out: self.compile_variable(out),
            }),
            gpu::Arithmetic::Neg(op) => {
                instructions.push(Instruction::Neg(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Normalize(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Normalize(op),
                    Instruction::FastNormalize(op),
                ))
            }
            gpu::Arithmetic::Magnitude(op) => {
                let op = self.compile_unary(op, out);
                instructions.push(self.select_fast_float(
                    out.ty,
                    modes,
                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
                    Instruction::Magnitude(op),
                    Instruction::FastMagnitude(op),
                ))
            }
            gpu::Arithmetic::Dot(op) => {
                instructions.push(Instruction::Dot(self.compile_binary(op, out)))
            }
        };
    }
1311
1312    fn select_fast_float(
1313        &self,
1314        ty: gpu::Type,
1315        modes: InstructionModes,
1316        required_flags: EnumSet<FastMath>,
1317        default: Instruction<D>,
1318        fast: Instruction<D>,
1319    ) -> Instruction<D> {
1320        if !self.compilation_options.supports_features.fast_math
1321            || !matches!(ty.elem_type(), ElemType::Float(FloatKind::F32))
1322        {
1323            return default;
1324        }
1325
1326        if modes.fp_math_mode.is_superset(required_flags) {
1327            fast
1328        } else {
1329            default
1330        }
1331    }
1332
1333    fn compile_comparison(
1334        &mut self,
1335        value: gpu::Comparison,
1336        out: Option<gpu::Variable>,
1337        instructions: &mut Vec<Instruction<D>>,
1338    ) {
1339        let out = out.unwrap();
1340        match value {
1341            gpu::Comparison::Equal(op) => {
1342                instructions.push(Instruction::Equal(self.compile_binary(op, out)))
1343            }
1344            gpu::Comparison::Lower(op) => {
1345                instructions.push(Instruction::Lower(self.compile_binary(op, out)))
1346            }
1347            gpu::Comparison::Greater(op) => {
1348                instructions.push(Instruction::Greater(self.compile_binary(op, out)))
1349            }
1350            gpu::Comparison::LowerEqual(op) => {
1351                instructions.push(Instruction::LowerEqual(self.compile_binary(op, out)))
1352            }
1353            gpu::Comparison::GreaterEqual(op) => {
1354                instructions.push(Instruction::GreaterEqual(self.compile_binary(op, out)))
1355            }
1356            gpu::Comparison::NotEqual(op) => {
1357                instructions.push(Instruction::NotEqual(self.compile_binary(op, out)))
1358            }
1359            gpu::Comparison::IsNan(op) => {
1360                instructions.push(Instruction::IsNan(self.compile_unary(op, out)))
1361            }
1362            gpu::Comparison::IsInf(op) => {
1363                instructions.push(Instruction::IsInf(self.compile_unary(op, out)))
1364            }
1365        };
1366    }
1367
1368    fn compile_bitwise(
1369        &mut self,
1370        value: gpu::Bitwise,
1371        out: Option<gpu::Variable>,
1372        instructions: &mut Vec<Instruction<D>>,
1373    ) {
1374        let out = out.unwrap();
1375        match value {
1376            gpu::Bitwise::BitwiseOr(op) => {
1377                instructions.push(Instruction::BitwiseOr(self.compile_binary(op, out)))
1378            }
1379            gpu::Bitwise::BitwiseAnd(op) => {
1380                instructions.push(Instruction::BitwiseAnd(self.compile_binary(op, out)))
1381            }
1382            gpu::Bitwise::BitwiseXor(op) => {
1383                instructions.push(Instruction::BitwiseXor(self.compile_binary(op, out)))
1384            }
1385            gpu::Bitwise::CountOnes(op) => {
1386                instructions.push(Instruction::CountBits(self.compile_unary(op, out)))
1387            }
1388            gpu::Bitwise::ReverseBits(op) => {
1389                instructions.push(Instruction::ReverseBits(self.compile_unary(op, out)))
1390            }
1391            gpu::Bitwise::ShiftLeft(op) => {
1392                instructions.push(Instruction::ShiftLeft(self.compile_binary(op, out)))
1393            }
1394            gpu::Bitwise::ShiftRight(op) => {
1395                instructions.push(Instruction::ShiftRight(self.compile_binary(op, out)))
1396            }
1397            gpu::Bitwise::BitwiseNot(op) => {
1398                instructions.push(Instruction::BitwiseNot(self.compile_unary(op, out)))
1399            }
1400            gpu::Bitwise::LeadingZeros(op) => {
1401                instructions.push(Instruction::LeadingZeros(self.compile_unary(op, out)))
1402            }
1403            gpu::Bitwise::FindFirstSet(op) => {
1404                let instruction = Instruction::FindFirstSet(self.compile_unary(op, out));
1405                D::register_instruction_extension(&mut self.extensions, &instruction);
1406                instructions.push(instruction)
1407            }
1408        };
1409    }
1410
    /// Lower a general IR operator into one or more C++ instructions.
    ///
    /// `out` is the destination variable; every operator handled here
    /// produces a value, hence the `unwrap`.
    fn compile_operator(
        &mut self,
        value: gpu::Operator,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            // Checked and unchecked indexing lower to the same instruction
            // here — bounds checks are presumably injected earlier by the
            // checked-IO processor (see `CheckedIoProcessor` import); confirm.
            gpu::Operator::Index(op) | gpu::Operator::UncheckedIndex(op) => {
                instructions.push(Instruction::Index(self.compile_index(op, out)));
            }
            gpu::Operator::IndexAssign(op) | gpu::Operator::UncheckedIndexAssign(op) => {
                instructions.push(Instruction::IndexAssign(self.compile_index_assign(op, out)));
            }
            gpu::Operator::And(op) => {
                instructions.push(Instruction::And(self.compile_binary(op, out)))
            }
            gpu::Operator::Or(op) => {
                instructions.push(Instruction::Or(self.compile_binary(op, out)))
            }
            gpu::Operator::Not(op) => {
                instructions.push(Instruction::Not(self.compile_unary(op, out)))
            }
            // Construct a vectorized value from individual components.
            gpu::Operator::InitLine(op) => instructions.push(Instruction::VecInit {
                inputs: op
                    .inputs
                    .into_iter()
                    .map(|it| self.compile_variable(it))
                    .collect(),
                out: self.compile_variable(out),
            }),
            gpu::Operator::CopyMemory(op) => instructions.push(Instruction::Copy {
                input: self.compile_variable(op.input),
                in_index: self.compile_variable(op.in_index),
                out: self.compile_variable(out),
                out_index: self.compile_variable(op.out_index),
            }),
            gpu::Operator::CopyMemoryBulk(op) => instructions.push(Instruction::CopyBulk {
                input: self.compile_variable(op.input),
                in_index: self.compile_variable(op.in_index),
                out: self.compile_variable(out),
                out_index: self.compile_variable(op.out_index),
                len: op.len,
            }),
            gpu::Operator::Select(op) => instructions.push(Instruction::Select {
                cond: self.compile_variable(op.cond),
                then: self.compile_variable(op.then),
                or_else: self.compile_variable(op.or_else),
                out: self.compile_variable(out),
            }),
            // Needs special conversion semantics
            gpu::Operator::Cast(op)
                if is_fp4_fp6_fp8(op.input.elem_type()) || is_fp4_fp6_fp8(out.elem_type()) =>
            {
                // We may need these for intermediates
                self.flags.elem_f16 = true;
                self.flags.elem_bf16 = true;
                let vec_in = op.input.ty.line_size();
                let packing = out.storage_type().packing_factor();
                // Pre-register every line type the conversion sequence may
                // touch so the kernel declares the matching vector typedefs.
                self.compile_type(op.input.ty.line(packing));
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::F16)).line(vec_in),
                );
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::BF16)).line(vec_in),
                );
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::F16)).line(packing),
                );
                self.compile_type(
                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::BF16)).line(packing),
                );

                let inst = self.compile_unary(op, out);

                instructions.push(Instruction::SpecialCast(inst));
            }
            // Ordinary cast: compiles to a plain assignment (the C++ types
            // convert implicitly or via the dialect's conversion helpers).
            gpu::Operator::Cast(op) => {
                let op = self.compile_unary(op, out);

                // TF32 involvement requires the tf32 type support flag.
                if op.input.elem() == Elem::TF32 || op.out.elem() == Elem::TF32 {
                    self.flags.elem_tf32 = true;
                }

                instructions.push(Instruction::Assign(op))
            }
            // Reinterpret the raw bits as another type (no value conversion).
            gpu::Operator::Reinterpret(op) => {
                instructions.push(Instruction::Bitcast(self.compile_unary(op, out)))
            }
        };
    }
1502
1503    fn compile_binary(
1504        &mut self,
1505        value: gpu::BinaryOperator,
1506        out: gpu::Variable,
1507    ) -> BinaryInstruction<D> {
1508        BinaryInstruction {
1509            lhs: self.compile_variable(value.lhs),
1510            rhs: self.compile_variable(value.rhs),
1511            out: self.compile_variable(out),
1512        }
1513    }
1514
1515    fn compile_index_assign(
1516        &mut self,
1517        value: gpu::IndexAssignOperator,
1518        out: gpu::Variable,
1519    ) -> IndexAssignInstruction<D> {
1520        IndexAssignInstruction {
1521            index: self.compile_variable(value.index),
1522            value: self.compile_variable(value.value),
1523            line_size: value.line_size,
1524            out: self.compile_variable(out),
1525        }
1526    }
1527
1528    fn compile_index(
1529        &mut self,
1530        value: gpu::IndexOperator,
1531        out: gpu::Variable,
1532    ) -> IndexInstruction<D> {
1533        IndexInstruction {
1534            list: self.compile_variable(value.list),
1535            index: self.compile_variable(value.index),
1536            line_size: value.line_size,
1537            out: self.compile_variable(out),
1538        }
1539    }
1540
1541    fn compile_unary(
1542        &mut self,
1543        value: gpu::UnaryOperator,
1544        out: gpu::Variable,
1545    ) -> UnaryInstruction<D> {
1546        UnaryInstruction {
1547            input: self.compile_variable(value.input),
1548            out: self.compile_variable(out),
1549        }
1550    }
1551
    /// Translate an IR variable into its dialect representation.
    ///
    /// Besides the pure mapping, this registers every type it encounters
    /// (via `compile_type`/`compile_elem`), toggles the kernel feature flags
    /// (TMA, WMMA, pipelines, barriers, builtin index flags), and records
    /// side tables such as local array declarations and pipeline inits.
    fn compile_variable(&mut self, value: gpu::Variable) -> Variable<D> {
        let item = value.ty;
        match value.kind {
            gpu::VariableKind::GlobalInputArray(id) => {
                Variable::GlobalInputArray(id, self.compile_type(item))
            }
            gpu::VariableKind::GlobalScalar(id) => Variable::GlobalScalar {
                id,
                elem: self.compile_storage_type(item.storage_type()),
                // When grid constants are supported, scalars are grouped in a
                // struct rather than passed individually.
                in_struct: self.compilation_options.supports_features.grid_constants,
            },
            gpu::VariableKind::TensorMapInput(id) => {
                self.flags.inst_tma = true;
                Variable::TensorMap(id)
            }
            gpu::VariableKind::TensorMapOutput(id) => {
                self.flags.inst_tma = true;
                Variable::TensorMap(id)
            }
            gpu::VariableKind::LocalMut { id } => Variable::LocalMut {
                id,
                item: self.compile_type(item),
            },
            // SSA-versioned variables are flattened back to plain mutable
            // locals in the C++ output; the version is dropped.
            gpu::VariableKind::Versioned { id, .. } => Variable::LocalMut {
                id,
                item: self.compile_type(item),
            },
            gpu::VariableKind::LocalConst { id } => Variable::LocalConst {
                id,
                item: self.compile_type(item),
            },
            gpu::VariableKind::GlobalOutputArray(id) => {
                Variable::GlobalOutputArray(id, self.compile_type(item))
            }
            gpu::VariableKind::ConstantScalar(value) => {
                Variable::ConstantScalar(value, self.compile_elem(value.elem_type()))
            }
            gpu::VariableKind::SharedMemory { id, length, .. } => {
                let item = self.compile_type(item);
                Variable::SharedMemory(id, item, length)
            }
            gpu::VariableKind::ConstantArray {
                id,
                length,
                unroll_factor,
            } => {
                let item = self.compile_type(item);
                // The declared array holds `unroll_factor` copies of the data.
                Variable::ConstantArray(id, item, length * unroll_factor)
            }
            // Builtins: record which index computations the kernel prelude
            // must emit via the `indexes` flag set.
            gpu::VariableKind::Builtin(builtin) => match builtin {
                gpu::Builtin::AbsolutePos => {
                    self.flags.indexes.absolute_pos = true;
                    Variable::AbsolutePos
                }
                gpu::Builtin::CubePosCluster
                    if self.compilation_options.supports_features.clusters =>
                {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterRank
                }
                gpu::Builtin::CubePosClusterX
                    if self.compilation_options.supports_features.clusters =>
                {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexX
                }
                gpu::Builtin::CubePosClusterY
                    if self.compilation_options.supports_features.clusters =>
                {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexY
                }
                gpu::Builtin::CubePosClusterZ
                    if self.compilation_options.supports_features.clusters =>
                {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexZ
                }
                // Fallback if clusters aren't supported, ID is always 0 since clusters are always
                // (1, 1, 1) if unsupported
                gpu::Builtin::CubePosCluster
                | gpu::Builtin::CubePosClusterX
                | gpu::Builtin::CubePosClusterY
                | gpu::Builtin::CubePosClusterZ => const_u32(0),
                gpu::Builtin::AbsolutePosX => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosX
                }
                gpu::Builtin::AbsolutePosY => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosY
                }
                gpu::Builtin::AbsolutePosZ => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosZ
                }
                gpu::Builtin::CubeDim => {
                    self.flags.indexes.cube_dim = true;
                    Variable::CubeDim
                }
                gpu::Builtin::CubeDimX => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimX
                }
                gpu::Builtin::CubeDimY => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimY
                }
                gpu::Builtin::CubeDimZ => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimZ
                }
                // Cluster dimensions are compile-time constants taken from
                // the configured cluster dim.
                gpu::Builtin::CubeClusterDim => const_u32(self.cluster_dim.num_elems()),
                gpu::Builtin::CubeClusterDimX => const_u32(self.cluster_dim.x),
                gpu::Builtin::CubeClusterDimY => const_u32(self.cluster_dim.y),
                gpu::Builtin::CubeClusterDimZ => const_u32(self.cluster_dim.z),
                gpu::Builtin::CubePos => {
                    self.flags.indexes.cube_pos = true;
                    Variable::CubePos
                }
                gpu::Builtin::CubePosX => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosX
                }
                gpu::Builtin::CubePosY => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosY
                }
                gpu::Builtin::CubePosZ => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosZ
                }
                gpu::Builtin::CubeCount => {
                    self.flags.indexes.cube_count = true;
                    Variable::CubeCount
                }
                gpu::Builtin::CubeCountX => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountX
                }
                gpu::Builtin::CubeCountY => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountY
                }
                gpu::Builtin::CubeCountZ => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountZ
                }
                gpu::Builtin::UnitPos => {
                    self.flags.indexes.unit_pos = true;
                    Variable::UnitPos
                }
                gpu::Builtin::UnitPosX => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosX
                }
                gpu::Builtin::UnitPosY => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosY
                }
                gpu::Builtin::UnitPosZ => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosZ
                }
                gpu::Builtin::PlaneDim => {
                    self.flags.indexes.plane_dim = true;
                    Variable::PlaneDim
                }
                gpu::Builtin::UnitPosPlane => {
                    self.flags.indexes.unit_pos_plane = true;
                    Variable::UnitPosPlane
                }
            },
            gpu::VariableKind::LocalArray {
                id,
                length,
                unroll_factor,
            } => {
                let item = self.compile_type(item);
                // Record the declaration once per array id; the declaration
                // size accounts for unrolling, the usage size does not.
                if !self.local_arrays.iter().any(|s| s.index == id) {
                    self.local_arrays
                        .push(LocalArray::new(id, item, length * unroll_factor));
                }
                Variable::LocalArray(id, item, length)
            }
            gpu::VariableKind::Matrix { id, mat } => {
                self.flags.inst_wmma = true;
                Variable::WmmaFragment {
                    id,
                    frag: self.compile_matrix(mat),
                }
            }
            gpu::VariableKind::Pipeline { id, num_stages } => {
                self.flags.op_pipeline = true;
                let pipeline = Variable::Pipeline { id };
                // Emit the pipeline initialization only once per pipeline id.
                if !self.pipelines.iter().any(|s| s.pipeline_id() == id) {
                    self.pipelines.push(PipelineOps::Init {
                        pipeline,
                        num_stages,
                    });
                }
                pipeline
            }
            gpu::VariableKind::Barrier { id, level } => {
                self.flags.op_barrier = true;
                Variable::Barrier { id, level }
            }
            gpu::VariableKind::BarrierToken { id, level } => {
                self.flags.op_barrier = true;
                Variable::BarrierToken { id, level }
            }
        }
    }
1765
1766    fn compile_matrix(&mut self, matrix: gpu::Matrix) -> Fragment<D> {
1767        Fragment {
1768            ident: self.compile_matrix_ident(matrix.ident),
1769            m: matrix.m,
1770            n: matrix.n,
1771            k: matrix.k,
1772            elem: self.compile_storage_type(matrix.storage),
1773            layout: self.compile_matrix_layout(matrix.layout),
1774        }
1775    }
1776
1777    fn compile_matrix_ident(&mut self, ident: gpu::MatrixIdent) -> FragmentIdent<D> {
1778        match ident {
1779            gpu::MatrixIdent::A => FragmentIdent::A,
1780            gpu::MatrixIdent::B => FragmentIdent::B,
1781            gpu::MatrixIdent::Accumulator => FragmentIdent::Accumulator,
1782        }
1783    }
1784
1785    fn compile_matrix_layout(&mut self, layout: gpu::MatrixLayout) -> Option<FragmentLayout<D>> {
1786        match layout {
1787            gpu::MatrixLayout::ColMajor => Some(FragmentLayout::ColMajor),
1788            gpu::MatrixLayout::RowMajor => Some(FragmentLayout::RowMajor),
1789            gpu::MatrixLayout::Undefined => None,
1790        }
1791    }
1792
1793    fn compile_binding(&mut self, binding: cubecl_runtime::kernel::Binding) -> Binding<D> {
1794        Binding {
1795            id: binding.id,
1796            item: self.compile_type(binding.ty),
1797            location: binding.location,
1798            size: binding.size,
1799            vis: binding.visibility,
1800        }
1801    }
1802
1803    fn compile_type(&mut self, ty: gpu::Type) -> Item<D> {
1804        let item = match ty {
1805            gpu::Type::Scalar(ty) => Item::new(self.compile_storage_type(ty), 1, false),
1806            gpu::Type::Line(ty, line_size) => {
1807                Item::new(self.compile_storage_type(ty), line_size as usize, false)
1808            }
1809            gpu::Type::Semantic(_) => Item::new(Elem::Bool, 1, true),
1810        };
1811        if item.elem != super::Elem::TF32 {
1812            self.items.insert(item);
1813            self.items.insert(item.optimized());
1814        } else {
1815            // TF32 is represented as `float` in C++
1816            let mut item = item;
1817            item.elem = super::Elem::F32;
1818            self.items.insert(item);
1819        }
1820
1821        item
1822    }
1823
1824    fn compile_storage_type(&mut self, value: gpu::StorageType) -> Elem<D> {
1825        match value {
1826            gpu::StorageType::Scalar(ty) => self.compile_elem(ty),
1827            gpu::StorageType::Atomic(ty) => Elem::Atomic(ty.into()),
1828            gpu::StorageType::Packed(gpu::ElemType::Float(kind), 2) => match kind {
1829                FloatKind::E2M1 => {
1830                    self.flags.elem_fp4 = true;
1831                    Elem::FP4x2(FP4Kind::E2M1)
1832                }
1833                FloatKind::E2M3 => {
1834                    self.flags.elem_fp6 = true;
1835                    Elem::FP6x2(FP6Kind::E2M3)
1836                }
1837                FloatKind::E3M2 => {
1838                    self.flags.elem_fp6 = true;
1839                    Elem::FP6(FP6Kind::E3M2)
1840                }
1841                FloatKind::E4M3 => {
1842                    self.flags.elem_fp8 = true;
1843                    Elem::FP8x2(FP8Kind::E4M3)
1844                }
1845                FloatKind::E5M2 => {
1846                    self.flags.elem_fp8 = true;
1847                    Elem::FP8x2(FP8Kind::E5M2)
1848                }
1849                FloatKind::UE8M0 => {
1850                    self.flags.elem_fp8 = true;
1851                    Elem::FP8x2(FP8Kind::UE8M0)
1852                }
1853                FloatKind::F16 => {
1854                    self.flags.elem_f16 = true;
1855                    Elem::F16x2
1856                }
1857                FloatKind::BF16 => {
1858                    self.flags.elem_bf16 = true;
1859                    Elem::BF16x2
1860                }
1861                other => unimplemented!("Unsupported storage type: packed<{other:?}, 2>"),
1862            },
1863            other => unimplemented!("Unsupported storage type: {other}"),
1864        }
1865    }
1866
1867    fn compile_elem(&mut self, value: gpu::ElemType) -> Elem<D> {
1868        match value {
1869            gpu::ElemType::Float(kind) => match kind {
1870                gpu::FloatKind::E2M1 => {
1871                    self.flags.elem_fp4 = true;
1872                    Elem::FP4(FP4Kind::E2M1)
1873                }
1874                gpu::FloatKind::E2M3 => {
1875                    self.flags.elem_fp6 = true;
1876                    Elem::FP6(FP6Kind::E2M3)
1877                }
1878                gpu::FloatKind::E3M2 => {
1879                    self.flags.elem_fp6 = true;
1880                    Elem::FP6(FP6Kind::E3M2)
1881                }
1882                gpu::FloatKind::E4M3 => {
1883                    self.flags.elem_fp8 = true;
1884                    Elem::FP8(FP8Kind::E4M3)
1885                }
1886                gpu::FloatKind::E5M2 => {
1887                    self.flags.elem_fp8 = true;
1888                    Elem::FP8(FP8Kind::E5M2)
1889                }
1890                gpu::FloatKind::UE8M0 => {
1891                    self.flags.elem_fp8 = true;
1892                    Elem::FP8(FP8Kind::UE8M0)
1893                }
1894                gpu::FloatKind::F16 => {
1895                    self.flags.elem_f16 = true;
1896                    Elem::F16
1897                }
1898                gpu::FloatKind::BF16 => {
1899                    self.flags.elem_bf16 = true;
1900                    Elem::BF16
1901                }
1902                gpu::FloatKind::TF32 => Elem::TF32,
1903                gpu::FloatKind::Flex32 => Elem::F32,
1904                gpu::FloatKind::F32 => Elem::F32,
1905                gpu::FloatKind::F64 => Elem::F64,
1906            },
1907            gpu::ElemType::Int(kind) => match kind {
1908                gpu::IntKind::I8 => Elem::I8,
1909                gpu::IntKind::I16 => Elem::I16,
1910                gpu::IntKind::I32 => Elem::I32,
1911                gpu::IntKind::I64 => Elem::I64,
1912            },
1913            gpu::ElemType::UInt(kind) => match kind {
1914                gpu::UIntKind::U8 => Elem::U8,
1915                gpu::UIntKind::U16 => Elem::U16,
1916                gpu::UIntKind::U32 => Elem::U32,
1917                gpu::UIntKind::U64 => Elem::U64,
1918            },
1919            gpu::ElemType::Bool => Elem::Bool,
1920        }
1921    }
1922}
1923
1924fn is_fp4_fp6_fp8(elem: gpu::ElemType) -> bool {
1925    match elem {
1926        gpu::ElemType::Float(kind) => matches!(
1927            kind,
1928            FloatKind::E2M1
1929                | FloatKind::E2M3
1930                | FloatKind::E3M2
1931                | FloatKind::E4M3
1932                | FloatKind::E5M2
1933                | FloatKind::UE8M0
1934        ),
1935        _ => false,
1936    }
1937}
1938
1939fn const_u32<D: Dialect>(value: u32) -> Variable<D> {
1940    Variable::ConstantScalar(
1941        gpu::ConstantScalarValue::UInt(value as u64, UIntKind::U32),
1942        Elem::U32,
1943    )
1944}
1945
1946pub fn register_supported_types(props: &mut DeviceProperties) {
1947    let supported_types = [
1948        gpu::ElemType::UInt(gpu::UIntKind::U8),
1949        gpu::ElemType::UInt(gpu::UIntKind::U16),
1950        gpu::ElemType::UInt(gpu::UIntKind::U32),
1951        gpu::ElemType::UInt(gpu::UIntKind::U64),
1952        gpu::ElemType::Int(gpu::IntKind::I8),
1953        gpu::ElemType::Int(gpu::IntKind::I16),
1954        gpu::ElemType::Int(gpu::IntKind::I32),
1955        gpu::ElemType::Int(gpu::IntKind::I64),
1956        gpu::ElemType::Float(gpu::FloatKind::BF16),
1957        gpu::ElemType::Float(gpu::FloatKind::F16),
1958        gpu::ElemType::Float(gpu::FloatKind::F32),
1959        gpu::ElemType::Float(gpu::FloatKind::Flex32),
1960        // Causes CUDA_ERROR_INVALID_VALUE for matmul, disabling until that can be investigated
1961        //gpu::Elem::Float(gpu::FloatKind::F64),
1962        gpu::ElemType::Bool,
1963    ];
1964
1965    let supported_atomic_types = [
1966        gpu::ElemType::Int(gpu::IntKind::I32),
1967        gpu::ElemType::Int(gpu::IntKind::I64),
1968        gpu::ElemType::UInt(gpu::UIntKind::U32),
1969        gpu::ElemType::UInt(gpu::UIntKind::U64),
1970        gpu::ElemType::Float(gpu::FloatKind::F32),
1971    ];
1972
1973    for ty in supported_types {
1974        props.register_type_usage(ty, TypeUsage::all_scalar());
1975    }
1976
1977    for ty in supported_atomic_types {
1978        props.register_type_usage(
1979            gpu::StorageType::Atomic(ty),
1980            TypeUsage::AtomicAdd | TypeUsage::AtomicLoadStore,
1981        );
1982    }
1983}