// cubecl_cpp/shared/base.rs

1use std::{collections::HashSet, fmt::Debug};
2
3use cubecl_common::ExecutionMode;
4use cubecl_core::CubeDim;
5use cubecl_core::ir::{FloatKind, Processor, UIntKind, VariableKind};
6use cubecl_core::post_processing::checked_io::CheckedIoProcessor;
7use cubecl_core::{
8    Compiler,
9    ir::{self as gpu},
10};
11use cubecl_core::{
12    ir::{Operation, SourceLoc},
13    prelude::{FastMath, KernelDefinition},
14};
15use cubecl_opt::{Optimizer, SharedLiveness};
16use cubecl_runtime::{DeviceProperties, TypeUsage};
17
18use crate::shared::MmaShape;
19
20use super::{
21    BinaryInstruction, Binding, Body, Component, ComputeKernel, ConstArray, Dialect, Elem, FP6Kind,
22    Fragment, FragmentIdent, FragmentLayout, IndexAssignInstruction, IndexInstruction, Instruction,
23    Item, LocalArray, SharedMemory, UnaryInstruction, Variable, WarpInstruction, WmmaInstruction,
24};
25use super::{FP4Kind, barrier::BarrierOps};
26use super::{FP8Kind, pipeline::PipelineOps};
27
/// Global counter used to generate unique identifiers for compiler-created
/// temporary variables; reset to 0 at the end of every `compile` call.
pub(super) static COUNTER_TMP_VAR: std::sync::atomic::AtomicU32 =
    std::sync::atomic::AtomicU32::new(0);
30
/// Target-specific options controlling how kernels are compiled.
#[derive(Clone, Debug)]
pub struct CompilationOptions {
    /// Number of threads in a warp/plane on the target hardware.
    pub warp_size: u32,
    /// Enable grid-constant kernel arguments; also affects how kernel
    /// metadata is split (see the `split_meta` usage in metadata compilation).
    pub grid_constants: bool,
    /// Whether the target supports thread-block clusters; when false, any
    /// requested cluster dimension is discarded during compilation.
    pub supports_clusters: bool,
}

impl Default for CompilationOptions {
    fn default() -> Self {
        // Conservative baseline: 32-wide warps, no grid constants, no clusters.
        CompilationOptions {
            warp_size: 32,
            grid_constants: false,
            supports_clusters: false,
        }
    }
}
47
/// Cube indexes flags.
/// When true the corresponding index is declared and computed as needed in the kernel.
#[derive(Debug, Clone, Default)]
pub struct CubeIndexFlags {
    /// Linear absolute unit position builtin.
    pub absolute_pos: bool,
    /// Per-axis absolute unit position builtin.
    pub absolute_pos_tuple: bool,
    /// Linear cube count builtin.
    pub cube_count: bool,
    /// Per-axis cube count builtin.
    pub cube_count_tuple: bool,
    /// Linear cube dimension builtin.
    pub cube_dim: bool,
    /// Per-axis cube dimension builtin.
    pub cube_dim_tuple: bool,
    /// Linear cube position builtin.
    pub cube_pos: bool,
    /// Per-axis cube position builtin.
    pub cube_pos_tuple: bool,
    /// Plane (warp) dimension builtin.
    pub plane_dim: bool,
    /// Plane dimension with runtime checking; required by all plane ops.
    pub plane_dim_checked: bool,
    /// Index of the plane within the cube; required by WMMA stores.
    pub plane_index: bool,
    /// Linear unit position within the cube.
    pub unit_pos: bool,
    /// Per-axis unit position within the cube.
    pub unit_pos_tuple: bool,
    /// Unit position within its plane; required by plane scan operations.
    pub unit_pos_plane: bool,
    /// Position of the cube within its cluster.
    pub cluster_pos: bool,
}
68
/// Flags gathered during Cube IR translation for the kernel compilation.
#[derive(Debug, Clone, Default)]
pub struct Flags {
    /// FP4 element types are used by the kernel.
    pub elem_fp4: bool,
    /// FP6 element types are used by the kernel.
    pub elem_fp6: bool,
    /// FP8 element types are used by the kernel.
    pub elem_fp8: bool,
    /// bf16 element types are used by the kernel.
    pub elem_bf16: bool,
    /// f16 element types are used by the kernel.
    pub elem_f16: bool,
    /// tf32 element types are used by the kernel.
    pub elem_tf32: bool,
    /// Builtin indexes the kernel needs declared/computed.
    pub indexes: CubeIndexFlags,
    /// Barrier operations are present.
    pub op_barrier: bool,
    /// Pipeline operations are present.
    pub op_pipeline: bool,
    /// Reduced-precision fast math was requested (from `FastMath::ReducedPrecision`).
    pub inst_fast_math: bool,
    /// TMA (tensor memory accelerator) instructions are present.
    pub inst_tma: bool,
    /// TMA im2col loads are present.
    pub inst_tma_im2col: bool,
    /// WMMA / cooperative matrix instructions are present.
    pub inst_wmma: bool,
    /// Pass kernel metadata through grid constants (from compilation options).
    pub use_grid_constants: bool,
    /// Length of the statically-known metadata section.
    pub static_meta_length: usize,
    /// Kernel receives dynamically-sized metadata (tensors/arrays present).
    pub has_dynamic_meta: bool,
    /// Cube (thread-block) dimensions of the kernel.
    pub cube_dim: CubeDim,
    /// Optional cluster dimensions; `None` when clusters are unsupported or unused.
    pub cluster_dim: Option<CubeDim>,
}
91
/// Compiler state accumulated while translating a Cube IR kernel into a
/// C++ [`ComputeKernel`] for the dialect `D`.
#[allow(clippy::too_many_arguments)]
#[derive(Clone, Debug, Default)]
pub struct CppCompiler<D: Dialect> {
    /// Barrier operations collected for the kernel body.
    barriers: Vec<BarrierOps<D>>,
    /// Target-specific options supplied by the caller of `compile`.
    compilation_options: CompilationOptions,
    /// Constant arrays drained from scopes, declared once in the body.
    const_arrays: Vec<ConstArray<D>>,
    /// Extended-metadata slot per binding id, built by `build_metadata`.
    ext_meta_positions: Vec<u32>,
    /// Cluster dimensions resolved in `compile` (single cube when unset).
    cluster_dim: CubeDim,
    /// Dialect extensions registered while compiling instructions.
    extensions: Vec<D::Extension>,
    /// Feature flags gathered during translation.
    flags: Flags,
    /// All items (element/vectorization combinations) seen in the kernel.
    items: HashSet<Item<D>>,
    /// Local arrays collected for the kernel body.
    local_arrays: Vec<LocalArray<D>>,
    /// Metadata layout descriptor, built by `build_metadata`.
    metadata: cubecl_core::Metadata,
    /// Pipeline operations collected for the kernel body.
    pipelines: Vec<PipelineOps<D>>,
    /// Last emitted source location, used to deduplicate `Line` markers.
    source_loc: Option<SourceLoc>,
    /// Execution mode (e.g. checked vs unchecked I/O).
    strategy: ExecutionMode,
}
109
110impl<D: Dialect> Compiler for CppCompiler<D> {
111    type Representation = ComputeKernel<D>;
112    type CompilationOptions = CompilationOptions;
113
114    fn compile(
115        &mut self,
116        mut kernel: KernelDefinition,
117        compilation_options: &Self::CompilationOptions,
118        strategy: ExecutionMode,
119    ) -> Self::Representation {
120        self.compilation_options = compilation_options.clone();
121        self.strategy = strategy;
122
123        if !self.compilation_options.supports_clusters {
124            kernel.options.cluster_dim = None;
125        }
126        self.cluster_dim = kernel.options.cluster_dim.unwrap_or(CubeDim::new_single());
127
128        let ir = self.clone().compile_ir(kernel);
129        COUNTER_TMP_VAR.store(0, std::sync::atomic::Ordering::Relaxed);
130        ir
131    }
132
133    fn elem_size(&self, elem: gpu::ElemType) -> usize {
134        elem.size()
135    }
136
137    fn extension(&self) -> &'static str {
138        "cpp"
139    }
140}
141
142impl<D: Dialect> CppCompiler<D> {
    /// Consumes the compiler and a kernel definition to build the final
    /// [`ComputeKernel`] representation.
    ///
    /// Ordering matters: `compile_scope` must run before the `Flags` are
    /// assembled, because translation populates `self.flags` as operations
    /// are encountered.
    fn compile_ir(mut self, value: KernelDefinition) -> ComputeKernel<D> {
        self.build_metadata(&value);

        // The body is cloned because the original is needed intact below for
        // the shared-memory liveness analysis.
        let instructions = self.compile_scope(&mut value.body.clone());
        let tensor_maps = value
            .tensor_maps
            .into_iter()
            .map(|b| self.compile_binding(b))
            .collect();
        let buffers = value
            .buffers
            .into_iter()
            .map(|b| self.compile_binding(b))
            .collect();
        let scalars = value
            .scalars
            .into_iter()
            .map(|binding| (self.compile_storage_type(binding.ty), binding.count))
            .collect();

        // Flags gathered during translation, combined with kernel options.
        let flags = Flags {
            indexes: D::builtin_rules(&self.flags.indexes),
            inst_wmma: self.flags.inst_wmma,
            op_pipeline: self.flags.op_pipeline,
            op_barrier: self.flags.op_barrier,
            elem_fp4: self.flags.elem_fp4,
            elem_fp6: self.flags.elem_fp6,
            elem_fp8: self.flags.elem_fp8,
            elem_bf16: self.flags.elem_bf16,
            elem_f16: self.flags.elem_f16,
            elem_tf32: self.flags.elem_tf32,
            inst_fast_math: value
                .options
                .fp_math_mode
                .contains(FastMath::ReducedPrecision),
            inst_tma: self.flags.inst_tma,
            inst_tma_im2col: self.flags.inst_tma_im2col,
            use_grid_constants: self.compilation_options.grid_constants,
            // TODO: At some point we should only pass dynamic meta if tensors are present,
            // not if only arrays are present. For now, leave like this
            has_dynamic_meta: self.metadata.static_len() > 0,
            static_meta_length: self.metadata.static_len() as usize,
            cube_dim: value.cube_dim,
            cluster_dim: value.options.cluster_dim,
        };

        // Shared-memory liveness determines each allocation's offset so
        // allocations with disjoint lifetimes can overlap.
        let mut opt = Optimizer::shared_only(value.body, value.cube_dim);
        let shared_allocs = opt.analysis::<SharedLiveness>();
        let shared_memories = shared_allocs
            .allocations
            .values()
            .map(|alloc| SharedMemory {
                index: alloc.smem.id,
                item: self.compile_type(alloc.smem.ty),
                length: alloc.smem.length,
                align: alloc.smem.align,
                offset: alloc.offset,
            })
            .collect();

        let body = Body {
            instructions,
            shared_memories,
            pipelines: self.pipelines,
            barriers: self.barriers,
            const_arrays: self.const_arrays,
            local_arrays: self.local_arrays,
        };

        // Drop the cluster dimension entirely when the target can't use it.
        let mut cluster_dim = value.options.cluster_dim;
        if !self.compilation_options.supports_clusters {
            cluster_dim = None;
        }

        ComputeKernel {
            tensor_maps,
            buffers,
            scalars,
            meta_static_len: self.metadata.static_len() as usize,
            cube_dim: value.cube_dim,
            body,
            extensions: self.extensions,
            flags,
            items: self.items,
            kernel_name: value.options.kernel_name,
            cluster_dim,
        }
    }
232
233    fn build_metadata(&mut self, value: &KernelDefinition) {
234        let mut num_ext = 0;
235
236        let mut all_meta: Vec<_> = value
237            .buffers
238            .iter()
239            .chain(value.tensor_maps.iter())
240            .map(|buf| (buf.id, buf.has_extended_meta))
241            .collect();
242
243        all_meta.sort_by_key(|(id, _)| *id);
244
245        for (_, has_extended_meta) in &all_meta {
246            self.ext_meta_positions.push(num_ext);
247            if *has_extended_meta {
248                num_ext += 1;
249            }
250        }
251
252        let num_meta = all_meta.len();
253
254        self.metadata = cubecl_core::Metadata::new(num_meta as u32, num_ext);
255    }
256
257    pub(crate) fn ext_meta_position(&self, var: gpu::Variable) -> u32 {
258        let id = var.index().expect("Variable should have index");
259        self.ext_meta_positions[id as usize]
260    }
261
    /// Compiles a scope's instructions after running IR processors over it.
    ///
    /// Const arrays are drained out of the scope first and hoisted to the
    /// compiler so they can be declared once in the kernel body.
    fn compile_scope(&mut self, scope: &mut gpu::Scope) -> Vec<Instruction<D>> {
        let mut instructions = Vec::new();

        // Hoist const arrays out of the scope into compiler-level storage.
        let const_arrays = scope
            .const_arrays
            .drain(..)
            .map(|(var, values)| ConstArray {
                index: var.index().unwrap(),
                item: self.compile_type(var.ty),
                size: values.len() as u32,
                values: values
                    .into_iter()
                    .map(|val| self.compile_variable(val))
                    .collect(),
            })
            .collect::<Vec<_>>();
        self.const_arrays.extend(const_arrays);

        // Run IR processors: checked I/O first (bounds checking according to
        // the execution mode), then any dialect-specific processors.
        let checked_io: Box<dyn Processor> = Box::new(CheckedIoProcessor::new(self.strategy));
        let dialect_processors = D::processors();
        let mut processors: Vec<&dyn Processor> = vec![&*checked_io];
        processors.extend(dialect_processors.iter().map(|it| &**it));

        let processing = scope.process(processors);

        // Declare every variable up front, then compile the instructions.
        for var in processing.variables {
            instructions.push(Instruction::DeclareVariable {
                var: self.compile_variable(var),
            });
        }

        processing
            .instructions
            .into_iter()
            .for_each(|op| self.compile_instruction(&mut instructions, op));

        instructions
    }
300
    /// Translates one Cube IR instruction, appending the resulting C++
    /// instruction(s) to `instructions`.
    ///
    /// Also records the feature flags (TMA, plane indexes, ...) required by
    /// the operations encountered.
    fn compile_instruction(
        &mut self,
        instructions: &mut Vec<Instruction<D>>,
        instruction: gpu::Instruction,
    ) {
        // Emit a `Line` marker first if the source location changed.
        self.update_debug_loc(instructions, &instruction);
        let out = instruction.out;
        match instruction.operation {
            gpu::Operation::Copy(variable) => {
                instructions.push(Instruction::Assign(UnaryInstruction {
                    input: self.compile_variable(variable),
                    out: self.compile_variable(out.unwrap()),
                }));
            }
            // Operation families with dedicated compile methods.
            gpu::Operation::Arithmetic(op) => self.compile_arithmetic(op, out, instructions),
            gpu::Operation::Comparison(op) => self.compile_comparison(op, out, instructions),
            gpu::Operation::Bitwise(op) => self.compile_bitwise(op, out, instructions),
            gpu::Operation::Operator(op) => self.compile_operator(op, out, instructions),
            gpu::Operation::Atomic(op) => self.compile_atomic(op, out, instructions),
            gpu::Operation::Metadata(op) => instructions.push(self.compile_metadata(op, out)),
            gpu::Operation::Branch(val) => self.compile_branch(instructions, val),
            gpu::Operation::Synchronization(val) => match val {
                gpu::Synchronization::SyncCube => instructions.push(Instruction::SyncThreads),
                gpu::Synchronization::SyncPlane => instructions.push(Instruction::SyncWarp),
                // Storage-level sync maps to a full thread sync here.
                gpu::Synchronization::SyncStorage => instructions.push(Instruction::SyncThreads),
                gpu::Synchronization::SyncProxyShared => {
                    self.flags.inst_tma = true;
                    instructions.push(Instruction::ProxySharedFence)
                }
            },
            gpu::Operation::Plane(op) => {
                // All plane (warp) operations require the checked plane dimension.
                self.flags.indexes.plane_dim_checked = true;
                let out = self.compile_variable(out.unwrap());
                match op {
                    gpu::Plane::Sum(op) => {
                        let instruction = WarpInstruction::ReduceSum {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction));
                    }
                    // Scan operations additionally need the unit's position
                    // within its plane.
                    gpu::Plane::InclusiveSum(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::InclusiveSum {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::InclusiveProd(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::InclusiveProd {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::ExclusiveSum(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::ExclusiveSum {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::ExclusiveProd(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::ExclusiveProd {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Prod(op) => {
                        let instruction = WarpInstruction::ReduceProd {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Max(op) => {
                        let instruction = WarpInstruction::ReduceMax {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Min(op) => {
                        let instruction = WarpInstruction::ReduceMin {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Elect => {
                        instructions.push(Instruction::Warp(WarpInstruction::Elect { out }))
                    }
                    gpu::Plane::All(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::All {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Any(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Any {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Ballot(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Ballot {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    // Binary plane ops: `lhs` is the value, `rhs` is the
                    // lane/mask/delta operand.
                    gpu::Plane::Broadcast(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Broadcast {
                            input: self.compile_variable(op.lhs),
                            id: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::Shuffle(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Shuffle {
                            input: self.compile_variable(op.lhs),
                            src_lane: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::ShuffleXor(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleXor {
                            input: self.compile_variable(op.lhs),
                            mask: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::ShuffleUp(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleUp {
                            input: self.compile_variable(op.lhs),
                            delta: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::ShuffleDown(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleDown {
                            input: self.compile_variable(op.lhs),
                            delta: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                }
            }
            gpu::Operation::CoopMma(cmma) => instructions.push(self.compile_cmma(cmma, out)),
            gpu::Operation::NonSemantic(debug) => match debug {
                gpu::NonSemantic::Print {
                    format_string,
                    args,
                } => instructions.push(Instruction::Printf {
                    format_string,
                    args: args
                        .into_iter()
                        .map(|arg| self.compile_variable(arg))
                        .collect(),
                }),
                gpu::NonSemantic::Comment { content } => {
                    instructions.push(Instruction::Comment { content })
                }
                // Don't need to handle scopes
                _ => {}
            },
            gpu::Operation::Barrier(barrier_ops) => match barrier_ops {
                gpu::BarrierOps::Init {
                    barrier,
                    with_cta_fence,
                } => {
                    // The barrier level is encoded in the variable kind.
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    let barrier = self.compile_variable(barrier);
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Init {
                        barrier,
                        level,
                        with_cta_fence,
                    }));
                }
                gpu::BarrierOps::MemCopyAsync {
                    barrier,
                    source,
                    source_length,
                    offset_source,
                    offset_out,
                } => {
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    // `out` is the destination of the async copy.
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsync {
                            barrier: self.compile_variable(barrier),
                            source: self.compile_variable(source),
                            destination: self.compile_variable(out.unwrap()),
                            source_length: self.compile_variable(source_length),
                            offset_source: self.compile_variable(offset_source),
                            offset_out: self.compile_variable(offset_out),
                            level,
                        },
                    ));
                }
                gpu::BarrierOps::TmaLoad {
                    barrier,
                    tensor_map,
                    offset_out,
                    indices,
                } => {
                    // `out` is the shared-memory buffer the TMA load fills.
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsyncTensorGlobalToShared {
                            barrier: self.compile_variable(barrier),
                            smem_buffer: self.compile_variable(out.unwrap()),
                            smem_offset: self.compile_variable(offset_out),
                            tensor_map: self.compile_variable(tensor_map),
                            indices: indices
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        },
                    ));
                }
                gpu::BarrierOps::TmaLoadIm2col {
                    barrier,
                    tensor_map,
                    offset_out,
                    indices,
                    offsets,
                } => {
                    self.flags.inst_tma_im2col = true;
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::TmaLoadIm2col {
                            barrier: self.compile_variable(barrier),
                            smem_buffer: self.compile_variable(out.unwrap()),
                            smem_offset: self.compile_variable(offset_out),
                            tensor_map: self.compile_variable(tensor_map),
                            indices: indices
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                            offsets: offsets
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        },
                    ));
                }
                gpu::BarrierOps::Arrive { barrier } => {
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Arrive {
                        barrier: self.compile_variable(barrier),
                        level,
                    }))
                }
                gpu::BarrierOps::ArriveTx {
                    barrier,
                    arrive_count_update,
                    transaction_count_update,
                } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::ArriveTx {
                        barrier: self.compile_variable(barrier),
                        arrive_count_update: self.compile_variable(arrive_count_update),
                        transaction_count_update: self.compile_variable(transaction_count_update),
                    }))
                }
                gpu::BarrierOps::ExpectTx {
                    barrier,
                    transaction_count_update,
                } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::ExpectTx {
                        barrier: self.compile_variable(barrier),
                        transaction_count_update: self.compile_variable(transaction_count_update),
                    }))
                }
                gpu::BarrierOps::Wait { barrier } => {
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Wait {
                        barrier: self.compile_variable(barrier),
                        level,
                    }))
                }
                gpu::BarrierOps::ArriveAndWait { barrier } => {
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::ArriveAndWait {
                            barrier: self.compile_variable(barrier),
                            level,
                        },
                    ))
                }
            },
            gpu::Operation::Tma(tma_ops) => {
                // Any TMA op requires the TMA feature flag.
                self.flags.inst_tma = true;
                match tma_ops {
                    gpu::TmaOps::TmaStore {
                        source,
                        coordinates,
                        offset_source,
                    } => {
                        // `out` is the tensor map written by the store.
                        instructions.push(Instruction::MemCopyAsyncTensorSharedToGlobal {
                            smem_buffer: self.compile_variable(source),
                            smem_offset: self.compile_variable(offset_source),
                            tensor_map: self.compile_variable(out.unwrap()),
                            indices: coordinates
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        });
                    }
                    gpu::TmaOps::CommitGroup => {
                        instructions.push(Instruction::BulkCommitGroup);
                    }
                    gpu::TmaOps::WaitGroup { max_pending } => {
                        instructions.push(Instruction::BulkWaitGroup { max_pending });
                    }
                    gpu::TmaOps::WaitGroupRead { max_pending } => {
                        instructions.push(Instruction::BulkWaitGroupRead { max_pending });
                    }
                }
            }
            // Frees have no C++ counterpart; memory is managed elsewhere.
            gpu::Operation::Free(_) => {}
        }
    }
634
635    fn update_debug_loc(
636        &mut self,
637        instructions: &mut Vec<Instruction<D>>,
638        inst: &gpu::Instruction,
639    ) {
640        if !matches!(inst.operation, Operation::NonSemantic(_)) {
641            match &inst.source_loc {
642                Some(loc) if Some(loc) != self.source_loc.as_ref() => {
643                    self.source_loc = Some(loc.clone());
644                    instructions.push(Instruction::Line {
645                        file: loc.source.file.clone(),
646                        line: loc.line,
647                    });
648                }
649                _ => {}
650            }
651        }
652    }
653
    /// Translates a cooperative matrix (CMMA) operation into a WMMA
    /// instruction, registering any dialect extension it requires.
    ///
    /// Sets the WMMA feature flag; `out` is the fragment or output of the
    /// operation and must be present.
    fn compile_cmma(&mut self, cmma: gpu::CoopMma, out: Option<gpu::Variable>) -> Instruction<D> {
        self.flags.inst_wmma = true;

        let out = self.compile_variable(out.unwrap());

        let inst = match cmma {
            gpu::CoopMma::Fill { value } => WmmaInstruction::Fill {
                frag: out,
                value: self.compile_variable(value),
            },
            gpu::CoopMma::Load {
                value,
                stride,
                offset,
                layout,
            } => WmmaInstruction::Load {
                frag: out,
                offset: self.compile_variable(offset),
                value: self.compile_variable(value),
                stride: self.compile_variable(stride),
                // Layout is optional for loads; `None` means unspecified.
                layout: layout.and_then(|l| self.compile_matrix_layout(l)),
            },
            gpu::CoopMma::Execute {
                mat_a,
                mat_b,
                mat_c,
            } => WmmaInstruction::Execute {
                frag_a: self.compile_variable(mat_a),
                frag_b: self.compile_variable(mat_b),
                frag_c: self.compile_variable(mat_c),
                frag_d: out,
                warp_size: self.compilation_options.warp_size,
            },
            // Manual MMA: operands are explicit per-lane register lists.
            gpu::CoopMma::ExecuteManual {
                matrix,
                registers_a,
                registers_b,
                registers_c,
            } => WmmaInstruction::ExecuteManual {
                shape: MmaShape::new(matrix.m, matrix.n, matrix.k),
                frag_a: registers_a
                    .into_iter()
                    .map(|it| self.compile_variable(it))
                    .collect(),
                frag_b: registers_b
                    .into_iter()
                    .map(|it| self.compile_variable(it))
                    .collect(),
                frag_c: registers_c
                    .into_iter()
                    .map(|it| self.compile_variable(it))
                    .collect(),
                frag_d: out,
            },
            // Scaled MMA: manual MMA plus per-operand scale factors.
            gpu::CoopMma::ExecuteScaled {
                matrix,
                registers_a,
                registers_b,
                registers_c,
                scales_a,
                scales_b,
                scales_factor,
            } => WmmaInstruction::ExecuteScaled {
                shape: MmaShape::new(matrix.m, matrix.n, matrix.k),
                frag_a: registers_a
                    .into_iter()
                    .map(|it| self.compile_variable(it))
                    .collect(),
                frag_b: registers_b
                    .into_iter()
                    .map(|it| self.compile_variable(it))
                    .collect(),
                frag_c: registers_c
                    .into_iter()
                    .map(|it| self.compile_variable(it))
                    .collect(),
                frag_d: out,

                scales_a: self.compile_variable(scales_a),
                scales_b: self.compile_variable(scales_b),
                scales_factor,
            },
            gpu::CoopMma::Store {
                mat,
                stride,
                offset,
                layout,
            } => {
                // Stores need the unit and plane positions to address the output.
                self.flags.indexes.unit_pos = true;
                self.flags.indexes.plane_index = true;
                WmmaInstruction::Store {
                    output: out,
                    offset: self.compile_variable(offset),
                    frag: self.compile_variable(mat),
                    stride: self.compile_variable(stride),
                    layout: self
                        .compile_matrix_layout(layout)
                        .expect("Layout required for store instruction"),
                }
            }
            gpu::CoopMma::Cast { input } => WmmaInstruction::Cast {
                input: self.compile_variable(input),
                output: out,
            },
            gpu::CoopMma::RowIndex { .. } | gpu::CoopMma::ColIndex { .. } => {
                panic!("Row/Col index should be handled by processors")
            }
        };

        D::register_wmma_instruction_extension(&mut self.extensions, &inst);

        Instruction::Wmma(inst)
    }
767
768    fn compile_metadata(
769        &mut self,
770        metadata: gpu::Metadata,
771        out: Option<gpu::Variable>,
772    ) -> Instruction<D> {
773        let out = out.unwrap();
774        match metadata {
775            gpu::Metadata::Stride { dim, var } => {
776                let position = self.ext_meta_position(var);
777                let offset = self.metadata.stride_offset_index(position);
778                Instruction::ExtendedMetadata {
779                    info_offset: self.compile_variable(offset.into()),
780                    dim: self.compile_variable(dim),
781                    split_meta: self.compilation_options.grid_constants,
782                    static_offset: self.metadata.static_len(),
783                    out: self.compile_variable(out),
784                }
785            }
786            gpu::Metadata::Shape { dim, var } => {
787                let position = self.ext_meta_position(var);
788                let offset = self.metadata.shape_offset_index(position);
789                Instruction::ExtendedMetadata {
790                    info_offset: self.compile_variable(offset.into()),
791                    dim: self.compile_variable(dim),
792                    split_meta: self.compilation_options.grid_constants,
793                    static_offset: self.metadata.static_len(),
794                    out: self.compile_variable(out),
795                }
796            }
797            gpu::Metadata::Rank { var } => {
798                let out = self.compile_variable(out);
799                let pos = self.ext_meta_position(var);
800                let offset = self.metadata.rank_index(pos);
801                super::Instruction::Metadata {
802                    info_offset: self.compile_variable(offset.into()),
803                    split_meta: self.compilation_options.grid_constants,
804                    out,
805                }
806            }
807            gpu::Metadata::Length { var } => {
808                let input = self.compile_variable(var);
809                let out = self.compile_variable(out);
810
811                match input {
812                    Variable::Slice { .. } => Instruction::SliceLength { input, out },
813                    Variable::SharedMemory(_id, _item, length) => {
814                        Instruction::ConstLength { length, out }
815                    }
816                    _ => {
817                        let id = input.id().expect("Variable should have id");
818                        let offset = self.metadata.len_index(id);
819                        Instruction::Metadata {
820                            info_offset: self.compile_variable(offset.into()),
821                            split_meta: self.compilation_options.grid_constants,
822                            out,
823                        }
824                    }
825                }
826            }
827            gpu::Metadata::BufferLength { var } => {
828                let input = self.compile_variable(var);
829                let out = self.compile_variable(out);
830
831                match input {
832                    Variable::Slice { .. } => Instruction::SliceLength { input, out },
833                    _ => {
834                        let id = input.id().expect("Variable should have id");
835                        let offset = self.metadata.buffer_len_index(id);
836                        Instruction::Metadata {
837                            info_offset: self.compile_variable(offset.into()),
838                            split_meta: self.compilation_options.grid_constants,
839                            out,
840                        }
841                    }
842                }
843            }
844        }
845    }
846
847    fn compile_branch(&mut self, instructions: &mut Vec<Instruction<D>>, branch: gpu::Branch) {
848        match branch {
849            gpu::Branch::If(mut op) => instructions.push(Instruction::If {
850                cond: self.compile_variable(op.cond),
851                instructions: self.compile_scope(&mut op.scope),
852            }),
853            gpu::Branch::IfElse(mut op) => instructions.push(Instruction::IfElse {
854                cond: self.compile_variable(op.cond),
855                instructions_if: self.compile_scope(&mut op.scope_if),
856                instructions_else: self.compile_scope(&mut op.scope_else),
857            }),
858            gpu::Branch::Switch(mut op) => instructions.push(Instruction::Switch {
859                value: self.compile_variable(op.value),
860                instructions_default: self.compile_scope(&mut op.scope_default),
861                instructions_cases: op
862                    .cases
863                    .into_iter()
864                    .map(|(val, mut block)| {
865                        (self.compile_variable(val), self.compile_scope(&mut block))
866                    })
867                    .collect(),
868            }),
869            gpu::Branch::Return => instructions.push(Instruction::Return),
870            gpu::Branch::Break => instructions.push(Instruction::Break),
871            gpu::Branch::RangeLoop(mut range_loop) => instructions.push(Instruction::RangeLoop {
872                i: self.compile_variable(range_loop.i),
873                start: self.compile_variable(range_loop.start),
874                end: self.compile_variable(range_loop.end),
875                step: range_loop.step.map(|it| self.compile_variable(it)),
876                inclusive: range_loop.inclusive,
877                instructions: self.compile_scope(&mut range_loop.scope),
878            }),
879            gpu::Branch::Loop(mut op) => instructions.push(Instruction::Loop {
880                instructions: self.compile_scope(&mut op.scope),
881            }),
882        };
883    }
884
885    fn compile_atomic(
886        &mut self,
887        value: gpu::AtomicOp,
888        out: Option<gpu::Variable>,
889        instructions: &mut Vec<Instruction<D>>,
890    ) {
891        let out = out.unwrap();
892        match value {
893            gpu::AtomicOp::Load(op) => {
894                instructions.push(Instruction::AtomicLoad(self.compile_unary(op, out)))
895            }
896            gpu::AtomicOp::Store(op) => {
897                instructions.push(Instruction::AtomicStore(self.compile_unary(op, out)))
898            }
899            gpu::AtomicOp::Swap(op) => {
900                instructions.push(Instruction::AtomicSwap(self.compile_binary(op, out)))
901            }
902            gpu::AtomicOp::Add(op) => {
903                instructions.push(Instruction::AtomicAdd(self.compile_binary(op, out)))
904            }
905            gpu::AtomicOp::Sub(op) => {
906                instructions.push(Instruction::AtomicSub(self.compile_binary(op, out)))
907            }
908            gpu::AtomicOp::Max(op) => {
909                instructions.push(Instruction::AtomicMax(self.compile_binary(op, out)))
910            }
911            gpu::AtomicOp::Min(op) => {
912                instructions.push(Instruction::AtomicMin(self.compile_binary(op, out)))
913            }
914            gpu::AtomicOp::And(op) => {
915                instructions.push(Instruction::AtomicAnd(self.compile_binary(op, out)))
916            }
917            gpu::AtomicOp::Or(op) => {
918                instructions.push(Instruction::AtomicOr(self.compile_binary(op, out)))
919            }
920            gpu::AtomicOp::Xor(op) => {
921                instructions.push(Instruction::AtomicXor(self.compile_binary(op, out)))
922            }
923            gpu::AtomicOp::CompareAndSwap(op) => instructions.push(Instruction::AtomicCAS {
924                input: self.compile_variable(op.input),
925                cmp: self.compile_variable(op.cmp),
926                val: self.compile_variable(op.val),
927                out: self.compile_variable(out),
928            }),
929        }
930    }
931
932    fn compile_arithmetic(
933        &mut self,
934        value: gpu::Arithmetic,
935        out: Option<gpu::Variable>,
936        instructions: &mut Vec<Instruction<D>>,
937    ) {
938        let out = out.unwrap();
939        match value {
940            gpu::Arithmetic::Add(op) => {
941                instructions.push(Instruction::Add(self.compile_binary(op, out)))
942            }
943            gpu::Arithmetic::SaturatingAdd(op) => {
944                instructions.push(Instruction::SaturatingAdd(self.compile_binary(op, out)))
945            }
946            gpu::Arithmetic::Mul(op) => {
947                instructions.push(Instruction::Mul(self.compile_binary(op, out)))
948            }
949            gpu::Arithmetic::Div(op) => {
950                instructions.push(Instruction::Div(self.compile_binary(op, out)))
951            }
952            gpu::Arithmetic::Sub(op) => {
953                instructions.push(Instruction::Sub(self.compile_binary(op, out)))
954            }
955            gpu::Arithmetic::SaturatingSub(op) => {
956                instructions.push(Instruction::SaturatingSub(self.compile_binary(op, out)))
957            }
958            gpu::Arithmetic::MulHi(op) => {
959                let instruction = Instruction::HiMul(self.compile_binary(op, out));
960                D::register_instruction_extension(&mut self.extensions, &instruction);
961                instructions.push(instruction)
962            }
963            gpu::Arithmetic::Modulo(op) => {
964                instructions.push(Instruction::Modulo(self.compile_binary(op, out)))
965            }
966            gpu::Arithmetic::Abs(op) => {
967                instructions.push(Instruction::Abs(self.compile_unary(op, out)))
968            }
969            gpu::Arithmetic::Exp(op) => {
970                instructions.push(Instruction::Exp(self.compile_unary(op, out)))
971            }
972            gpu::Arithmetic::Log(op) => {
973                instructions.push(Instruction::Log(self.compile_unary(op, out)))
974            }
975            gpu::Arithmetic::Log1p(op) => {
976                instructions.push(Instruction::Log1p(self.compile_unary(op, out)))
977            }
978            gpu::Arithmetic::Cos(op) => {
979                instructions.push(Instruction::Cos(self.compile_unary(op, out)))
980            }
981            gpu::Arithmetic::Sin(op) => {
982                instructions.push(Instruction::Sin(self.compile_unary(op, out)))
983            }
984            gpu::Arithmetic::Tanh(op) => {
985                let instruction = Instruction::Tanh(self.compile_unary(op, out));
986                D::register_instruction_extension(&mut self.extensions, &instruction);
987                instructions.push(instruction)
988            }
989            gpu::Arithmetic::Powf(op) => {
990                instructions.push(Instruction::Powf(self.compile_binary(op, out)))
991            }
992            gpu::Arithmetic::Powi(op) => {
993                instructions.push(Instruction::Powi(self.compile_binary(op, out)))
994            }
995            gpu::Arithmetic::Sqrt(op) => {
996                instructions.push(Instruction::Sqrt(self.compile_unary(op, out)))
997            }
998            gpu::Arithmetic::Erf(op) => {
999                let instruction = Instruction::Erf(self.compile_unary(op, out));
1000                D::register_instruction_extension(&mut self.extensions, &instruction);
1001                instructions.push(instruction)
1002            }
1003            gpu::Arithmetic::Max(op) => {
1004                let instruction = Instruction::Max(self.compile_binary(op, out));
1005                D::register_instruction_extension(&mut self.extensions, &instruction);
1006                instructions.push(instruction)
1007            }
1008            gpu::Arithmetic::Min(op) => {
1009                let instruction = Instruction::Min(self.compile_binary(op, out));
1010                D::register_instruction_extension(&mut self.extensions, &instruction);
1011                instructions.push(instruction)
1012            }
1013            gpu::Arithmetic::Clamp(op) => instructions.push(Instruction::Clamp {
1014                input: self.compile_variable(op.input),
1015                min_value: self.compile_variable(op.min_value),
1016                max_value: self.compile_variable(op.max_value),
1017                out: self.compile_variable(out),
1018            }),
1019            gpu::Arithmetic::Recip(op) => {
1020                let elem = op.input.ty.elem_type();
1021                let lhs = match elem {
1022                    gpu::ElemType::Float(kind) => gpu::ConstantScalarValue::Float(1.0, kind),
1023                    gpu::ElemType::Int(kind) => gpu::ConstantScalarValue::Int(1, kind),
1024                    gpu::ElemType::UInt(kind) => gpu::ConstantScalarValue::UInt(1, kind),
1025                    gpu::ElemType::Bool => gpu::ConstantScalarValue::Bool(true),
1026                };
1027
1028                instructions.push(Instruction::Div(BinaryInstruction {
1029                    lhs: Variable::ConstantScalar(lhs, self.compile_elem(elem)),
1030                    rhs: self.compile_variable(op.input),
1031                    out: self.compile_variable(out),
1032                }))
1033            }
1034            gpu::Arithmetic::Round(op) => {
1035                instructions.push(Instruction::Round(self.compile_unary(op, out)))
1036            }
1037            gpu::Arithmetic::Floor(op) => {
1038                instructions.push(Instruction::Floor(self.compile_unary(op, out)))
1039            }
1040            gpu::Arithmetic::Ceil(op) => {
1041                instructions.push(Instruction::Ceil(self.compile_unary(op, out)))
1042            }
1043            gpu::Arithmetic::Trunc(op) => {
1044                instructions.push(Instruction::Trunc(self.compile_unary(op, out)))
1045            }
1046            gpu::Arithmetic::Remainder(op) => {
1047                instructions.push(Instruction::Remainder(self.compile_binary(op, out)))
1048            }
1049            gpu::Arithmetic::Fma(op) => instructions.push(Instruction::Fma {
1050                a: self.compile_variable(op.a),
1051                b: self.compile_variable(op.b),
1052                c: self.compile_variable(op.c),
1053                out: self.compile_variable(out),
1054            }),
1055            gpu::Arithmetic::Neg(op) => {
1056                instructions.push(Instruction::Neg(self.compile_unary(op, out)))
1057            }
1058            gpu::Arithmetic::Normalize(op) => {
1059                instructions.push(Instruction::Normalize(self.compile_unary(op, out)))
1060            }
1061            gpu::Arithmetic::Magnitude(op) => {
1062                instructions.push(Instruction::Magnitude(self.compile_unary(op, out)))
1063            }
1064            gpu::Arithmetic::Dot(op) => {
1065                instructions.push(Instruction::Dot(self.compile_binary(op, out)))
1066            }
1067        };
1068    }
1069
1070    fn compile_comparison(
1071        &mut self,
1072        value: gpu::Comparison,
1073        out: Option<gpu::Variable>,
1074        instructions: &mut Vec<Instruction<D>>,
1075    ) {
1076        let out = out.unwrap();
1077        match value {
1078            gpu::Comparison::Equal(op) => {
1079                instructions.push(Instruction::Equal(self.compile_binary(op, out)))
1080            }
1081            gpu::Comparison::Lower(op) => {
1082                instructions.push(Instruction::Lower(self.compile_binary(op, out)))
1083            }
1084            gpu::Comparison::Greater(op) => {
1085                instructions.push(Instruction::Greater(self.compile_binary(op, out)))
1086            }
1087            gpu::Comparison::LowerEqual(op) => {
1088                instructions.push(Instruction::LowerEqual(self.compile_binary(op, out)))
1089            }
1090            gpu::Comparison::GreaterEqual(op) => {
1091                instructions.push(Instruction::GreaterEqual(self.compile_binary(op, out)))
1092            }
1093            gpu::Comparison::NotEqual(op) => {
1094                instructions.push(Instruction::NotEqual(self.compile_binary(op, out)))
1095            }
1096            gpu::Comparison::IsNan(op) => {
1097                instructions.push(Instruction::IsNan(self.compile_unary(op, out)))
1098            }
1099            gpu::Comparison::IsInf(op) => {
1100                instructions.push(Instruction::IsInf(self.compile_unary(op, out)))
1101            }
1102        };
1103    }
1104
1105    fn compile_bitwise(
1106        &mut self,
1107        value: gpu::Bitwise,
1108        out: Option<gpu::Variable>,
1109        instructions: &mut Vec<Instruction<D>>,
1110    ) {
1111        let out = out.unwrap();
1112        match value {
1113            gpu::Bitwise::BitwiseOr(op) => {
1114                instructions.push(Instruction::BitwiseOr(self.compile_binary(op, out)))
1115            }
1116            gpu::Bitwise::BitwiseAnd(op) => {
1117                instructions.push(Instruction::BitwiseAnd(self.compile_binary(op, out)))
1118            }
1119            gpu::Bitwise::BitwiseXor(op) => {
1120                instructions.push(Instruction::BitwiseXor(self.compile_binary(op, out)))
1121            }
1122            gpu::Bitwise::CountOnes(op) => {
1123                instructions.push(Instruction::CountBits(self.compile_unary(op, out)))
1124            }
1125            gpu::Bitwise::ReverseBits(op) => {
1126                instructions.push(Instruction::ReverseBits(self.compile_unary(op, out)))
1127            }
1128            gpu::Bitwise::ShiftLeft(op) => {
1129                instructions.push(Instruction::ShiftLeft(self.compile_binary(op, out)))
1130            }
1131            gpu::Bitwise::ShiftRight(op) => {
1132                instructions.push(Instruction::ShiftRight(self.compile_binary(op, out)))
1133            }
1134            gpu::Bitwise::BitwiseNot(op) => {
1135                instructions.push(Instruction::BitwiseNot(self.compile_unary(op, out)))
1136            }
1137            gpu::Bitwise::LeadingZeros(op) => {
1138                instructions.push(Instruction::LeadingZeros(self.compile_unary(op, out)))
1139            }
1140            gpu::Bitwise::FindFirstSet(op) => {
1141                let instruction = Instruction::FindFirstSet(self.compile_unary(op, out));
1142                D::register_instruction_extension(&mut self.extensions, &instruction);
1143                instructions.push(instruction)
1144            }
1145        };
1146    }
1147
1148    fn compile_operator(
1149        &mut self,
1150        value: gpu::Operator,
1151        out: Option<gpu::Variable>,
1152        instructions: &mut Vec<Instruction<D>>,
1153    ) {
1154        let out = out.unwrap();
1155        match value {
1156            gpu::Operator::Index(op) | gpu::Operator::UncheckedIndex(op) => {
1157                instructions.push(Instruction::Index(self.compile_index(op, out)));
1158            }
1159            gpu::Operator::IndexAssign(op) | gpu::Operator::UncheckedIndexAssign(op) => {
1160                instructions.push(Instruction::IndexAssign(self.compile_index_assign(op, out)));
1161            }
1162            gpu::Operator::And(op) => {
1163                instructions.push(Instruction::And(self.compile_binary(op, out)))
1164            }
1165            gpu::Operator::Or(op) => {
1166                instructions.push(Instruction::Or(self.compile_binary(op, out)))
1167            }
1168            gpu::Operator::Not(op) => {
1169                instructions.push(Instruction::Not(self.compile_unary(op, out)))
1170            }
1171            gpu::Operator::InitLine(op) => instructions.push(Instruction::VecInit {
1172                inputs: op
1173                    .inputs
1174                    .into_iter()
1175                    .map(|it| self.compile_variable(it))
1176                    .collect(),
1177                out: self.compile_variable(out),
1178            }),
1179            gpu::Operator::CopyMemory(op) => instructions.push(Instruction::Copy {
1180                input: self.compile_variable(op.input),
1181                in_index: self.compile_variable(op.in_index),
1182                out: self.compile_variable(out),
1183                out_index: self.compile_variable(op.out_index),
1184            }),
1185            gpu::Operator::CopyMemoryBulk(op) => instructions.push(Instruction::CopyBulk {
1186                input: self.compile_variable(op.input),
1187                in_index: self.compile_variable(op.in_index),
1188                out: self.compile_variable(out),
1189                out_index: self.compile_variable(op.out_index),
1190                len: op.len,
1191            }),
1192            gpu::Operator::Select(op) => instructions.push(Instruction::Select {
1193                cond: self.compile_variable(op.cond),
1194                then: self.compile_variable(op.then),
1195                or_else: self.compile_variable(op.or_else),
1196                out: self.compile_variable(out),
1197            }),
1198            // Needs special conversion semantics
1199            gpu::Operator::Cast(op)
1200                if is_fp4_fp6_fp8(op.input.elem_type()) || is_fp4_fp6_fp8(out.elem_type()) =>
1201            {
1202                // We may need these for intermediates
1203                self.flags.elem_f16 = true;
1204                self.flags.elem_bf16 = true;
1205                let vec_in = op.input.ty.line_size();
1206                let packing = out.storage_type().packing_factor();
1207                self.compile_type(op.input.ty.line(packing));
1208                self.compile_type(
1209                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::F16)).line(vec_in),
1210                );
1211                self.compile_type(
1212                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::BF16)).line(vec_in),
1213                );
1214                self.compile_type(
1215                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::F16)).line(packing),
1216                );
1217                self.compile_type(
1218                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::BF16)).line(packing),
1219                );
1220
1221                let inst = self.compile_unary(op, out);
1222
1223                instructions.push(Instruction::SpecialCast(inst));
1224            }
1225            gpu::Operator::Cast(op) => {
1226                let op = self.compile_unary(op, out);
1227
1228                if op.input.elem() == Elem::TF32 || op.out.elem() == Elem::TF32 {
1229                    self.flags.elem_tf32 = true;
1230                }
1231
1232                instructions.push(Instruction::Assign(op))
1233            }
1234            gpu::Operator::Reinterpret(op) => {
1235                instructions.push(Instruction::Bitcast(self.compile_unary(op, out)))
1236            }
1237        };
1238    }
1239
1240    fn compile_binary(
1241        &mut self,
1242        value: gpu::BinaryOperator,
1243        out: gpu::Variable,
1244    ) -> BinaryInstruction<D> {
1245        BinaryInstruction {
1246            lhs: self.compile_variable(value.lhs),
1247            rhs: self.compile_variable(value.rhs),
1248            out: self.compile_variable(out),
1249        }
1250    }
1251
1252    fn compile_index_assign(
1253        &mut self,
1254        value: gpu::IndexAssignOperator,
1255        out: gpu::Variable,
1256    ) -> IndexAssignInstruction<D> {
1257        IndexAssignInstruction {
1258            index: self.compile_variable(value.index),
1259            value: self.compile_variable(value.value),
1260            line_size: value.line_size,
1261            out: self.compile_variable(out),
1262        }
1263    }
1264
1265    fn compile_index(
1266        &mut self,
1267        value: gpu::IndexOperator,
1268        out: gpu::Variable,
1269    ) -> IndexInstruction<D> {
1270        IndexInstruction {
1271            list: self.compile_variable(value.list),
1272            index: self.compile_variable(value.index),
1273            line_size: value.line_size,
1274            out: self.compile_variable(out),
1275        }
1276    }
1277
1278    fn compile_unary(
1279        &mut self,
1280        value: gpu::UnaryOperator,
1281        out: gpu::Variable,
1282    ) -> UnaryInstruction<D> {
1283        UnaryInstruction {
1284            input: self.compile_variable(value.input),
1285            out: self.compile_variable(out),
1286        }
1287    }
1288
1289    fn compile_variable(&mut self, value: gpu::Variable) -> Variable<D> {
1290        let item = value.ty;
1291        match value.kind {
1292            gpu::VariableKind::GlobalInputArray(id) => {
1293                Variable::GlobalInputArray(id, self.compile_type(item))
1294            }
1295            gpu::VariableKind::GlobalScalar(id) => Variable::GlobalScalar {
1296                id,
1297                elem: self.compile_storage_type(item.storage_type()),
1298                in_struct: self.compilation_options.grid_constants,
1299            },
1300            gpu::VariableKind::TensorMapInput(id) => {
1301                self.flags.inst_tma = true;
1302                Variable::TensorMap(id)
1303            }
1304            gpu::VariableKind::TensorMapOutput(id) => {
1305                self.flags.inst_tma = true;
1306                Variable::TensorMap(id)
1307            }
1308            gpu::VariableKind::LocalMut { id } => Variable::LocalMut {
1309                id,
1310                item: self.compile_type(item),
1311            },
1312            gpu::VariableKind::Versioned { id, .. } => Variable::LocalMut {
1313                id,
1314                item: self.compile_type(item),
1315            },
1316            gpu::VariableKind::LocalConst { id } => Variable::LocalConst {
1317                id,
1318                item: self.compile_type(item),
1319            },
1320            gpu::VariableKind::GlobalOutputArray(id) => {
1321                Variable::GlobalOutputArray(id, self.compile_type(item))
1322            }
1323            gpu::VariableKind::ConstantScalar(value) => {
1324                Variable::ConstantScalar(value, self.compile_elem(value.elem_type()))
1325            }
1326            gpu::VariableKind::SharedMemory { id, length, .. } => {
1327                let item = self.compile_type(item);
1328                Variable::SharedMemory(id, item, length)
1329            }
1330            gpu::VariableKind::ConstantArray {
1331                id,
1332                length,
1333                unroll_factor,
1334            } => {
1335                let item = self.compile_type(item);
1336                Variable::ConstantArray(id, item, length * unroll_factor)
1337            }
1338            gpu::VariableKind::Builtin(builtin) => match builtin {
1339                gpu::Builtin::AbsolutePos => {
1340                    self.flags.indexes.absolute_pos = true;
1341                    Variable::AbsolutePos
1342                }
1343                gpu::Builtin::CubePosCluster if self.compilation_options.supports_clusters => {
1344                    self.flags.indexes.cluster_pos = true;
1345                    Variable::ClusterRank
1346                }
1347                gpu::Builtin::CubePosClusterX if self.compilation_options.supports_clusters => {
1348                    self.flags.indexes.cluster_pos = true;
1349                    Variable::ClusterIndexX
1350                }
1351                gpu::Builtin::CubePosClusterY if self.compilation_options.supports_clusters => {
1352                    self.flags.indexes.cluster_pos = true;
1353                    Variable::ClusterIndexY
1354                }
1355                gpu::Builtin::CubePosClusterZ if self.compilation_options.supports_clusters => {
1356                    self.flags.indexes.cluster_pos = true;
1357                    Variable::ClusterIndexZ
1358                }
1359                // Fallback if clusters aren't supported, ID is always 0 since clusters are always
1360                // (1, 1, 1) if unsupported
1361                gpu::Builtin::CubePosCluster
1362                | gpu::Builtin::CubePosClusterX
1363                | gpu::Builtin::CubePosClusterY
1364                | gpu::Builtin::CubePosClusterZ => const_u32(0),
1365                gpu::Builtin::AbsolutePosX => {
1366                    self.flags.indexes.absolute_pos_tuple = true;
1367                    Variable::AbsolutePosX
1368                }
1369                gpu::Builtin::AbsolutePosY => {
1370                    self.flags.indexes.absolute_pos_tuple = true;
1371                    Variable::AbsolutePosY
1372                }
1373                gpu::Builtin::AbsolutePosZ => {
1374                    self.flags.indexes.absolute_pos_tuple = true;
1375                    Variable::AbsolutePosZ
1376                }
1377                gpu::Builtin::CubeDim => {
1378                    self.flags.indexes.cube_dim = true;
1379                    Variable::CubeDim
1380                }
1381                gpu::Builtin::CubeDimX => {
1382                    self.flags.indexes.cube_dim_tuple = true;
1383                    Variable::CubeDimX
1384                }
1385                gpu::Builtin::CubeDimY => {
1386                    self.flags.indexes.cube_dim_tuple = true;
1387                    Variable::CubeDimY
1388                }
1389                gpu::Builtin::CubeDimZ => {
1390                    self.flags.indexes.cube_dim_tuple = true;
1391                    Variable::CubeDimZ
1392                }
1393                gpu::Builtin::CubeClusterDim => const_u32(self.cluster_dim.num_elems()),
1394                gpu::Builtin::CubeClusterDimX => const_u32(self.cluster_dim.x),
1395                gpu::Builtin::CubeClusterDimY => const_u32(self.cluster_dim.y),
1396                gpu::Builtin::CubeClusterDimZ => const_u32(self.cluster_dim.z),
1397                gpu::Builtin::CubePos => {
1398                    self.flags.indexes.cube_pos = true;
1399                    Variable::CubePos
1400                }
1401                gpu::Builtin::CubePosX => {
1402                    self.flags.indexes.cube_pos_tuple = true;
1403                    Variable::CubePosX
1404                }
1405                gpu::Builtin::CubePosY => {
1406                    self.flags.indexes.cube_pos_tuple = true;
1407                    Variable::CubePosY
1408                }
1409                gpu::Builtin::CubePosZ => {
1410                    self.flags.indexes.cube_pos_tuple = true;
1411                    Variable::CubePosZ
1412                }
1413                gpu::Builtin::CubeCount => {
1414                    self.flags.indexes.cube_count = true;
1415                    Variable::CubeCount
1416                }
1417                gpu::Builtin::CubeCountX => {
1418                    self.flags.indexes.cube_count_tuple = true;
1419                    Variable::CubeCountX
1420                }
1421                gpu::Builtin::CubeCountY => {
1422                    self.flags.indexes.cube_count_tuple = true;
1423                    Variable::CubeCountY
1424                }
1425                gpu::Builtin::CubeCountZ => {
1426                    self.flags.indexes.cube_count_tuple = true;
1427                    Variable::CubeCountZ
1428                }
1429                gpu::Builtin::UnitPos => {
1430                    self.flags.indexes.unit_pos = true;
1431                    Variable::UnitPos
1432                }
1433                gpu::Builtin::UnitPosX => {
1434                    self.flags.indexes.unit_pos_tuple = true;
1435                    Variable::UnitPosX
1436                }
1437                gpu::Builtin::UnitPosY => {
1438                    self.flags.indexes.unit_pos_tuple = true;
1439                    Variable::UnitPosY
1440                }
1441                gpu::Builtin::UnitPosZ => {
1442                    self.flags.indexes.unit_pos_tuple = true;
1443                    Variable::UnitPosZ
1444                }
1445                gpu::Builtin::PlaneDim => {
1446                    self.flags.indexes.plane_dim = true;
1447                    Variable::PlaneDim
1448                }
1449                gpu::Builtin::UnitPosPlane => {
1450                    self.flags.indexes.unit_pos_plane = true;
1451                    Variable::UnitPosPlane
1452                }
1453            },
1454            gpu::VariableKind::LocalArray {
1455                id,
1456                length,
1457                unroll_factor,
1458            } => {
1459                let item = self.compile_type(item);
1460                if !self.local_arrays.iter().any(|s| s.index == id) {
1461                    self.local_arrays
1462                        .push(LocalArray::new(id, item, length * unroll_factor));
1463                }
1464                Variable::LocalArray(id, item, length)
1465            }
1466            gpu::VariableKind::Matrix { id, mat } => {
1467                self.flags.inst_wmma = true;
1468                Variable::WmmaFragment {
1469                    id,
1470                    frag: self.compile_matrix(mat),
1471                }
1472            }
1473            gpu::VariableKind::Pipeline { id, num_stages } => {
1474                self.flags.op_pipeline = true;
1475                let pipeline = Variable::Pipeline { id };
1476                if !self.pipelines.iter().any(|s| s.pipeline_id() == id) {
1477                    self.pipelines.push(PipelineOps::Init {
1478                        pipeline,
1479                        num_stages,
1480                    });
1481                }
1482                pipeline
1483            }
1484            gpu::VariableKind::Barrier { id, level } => {
1485                self.flags.op_barrier = true;
1486                match level {
1487                    gpu::BarrierLevel::CubeCoop(_) | gpu::BarrierLevel::CubeManual(_) => {
1488                        self.flags.indexes.cube_dim = true;
1489                        self.flags.indexes.unit_pos = true;
1490                    }
1491                    _ => {}
1492                }
1493                Variable::Barrier { id, level }
1494            }
1495        }
1496    }
1497
1498    fn compile_matrix(&mut self, matrix: gpu::Matrix) -> Fragment<D> {
1499        Fragment {
1500            ident: self.compile_matrix_ident(matrix.ident),
1501            m: matrix.m,
1502            n: matrix.n,
1503            k: matrix.k,
1504            elem: self.compile_storage_type(matrix.storage),
1505            layout: self.compile_matrix_layout(matrix.layout),
1506        }
1507    }
1508
1509    fn compile_matrix_ident(&mut self, ident: gpu::MatrixIdent) -> FragmentIdent<D> {
1510        match ident {
1511            gpu::MatrixIdent::A => FragmentIdent::A,
1512            gpu::MatrixIdent::B => FragmentIdent::B,
1513            gpu::MatrixIdent::Accumulator => FragmentIdent::Accumulator,
1514        }
1515    }
1516
1517    fn compile_matrix_layout(&mut self, layout: gpu::MatrixLayout) -> Option<FragmentLayout<D>> {
1518        match layout {
1519            gpu::MatrixLayout::ColMajor => Some(FragmentLayout::ColMajor),
1520            gpu::MatrixLayout::RowMajor => Some(FragmentLayout::RowMajor),
1521            gpu::MatrixLayout::Undefined => None,
1522        }
1523    }
1524
1525    fn compile_binding(&mut self, binding: cubecl_core::compute::Binding) -> Binding<D> {
1526        Binding {
1527            id: binding.id,
1528            item: self.compile_type(binding.ty),
1529            location: binding.location,
1530            size: binding.size,
1531            vis: binding.visibility,
1532        }
1533    }
1534
1535    fn compile_type(&mut self, ty: gpu::Type) -> Item<D> {
1536        let item = match ty {
1537            gpu::Type::Scalar(ty) => Item::new(self.compile_storage_type(ty), 1, false),
1538            gpu::Type::Line(ty, line_size) => {
1539                Item::new(self.compile_storage_type(ty), line_size as usize, false)
1540            }
1541            gpu::Type::Semantic(_) => unimplemented!("Can't compile semantic type"),
1542        };
1543        if item.elem != super::Elem::TF32 {
1544            self.items.insert(item);
1545            self.items.insert(item.optimized());
1546        } else {
1547            // TF32 is represented as `float` in C++
1548            let mut item = item;
1549            item.elem = super::Elem::F32;
1550            self.items.insert(item);
1551        }
1552
1553        item
1554    }
1555
1556    fn compile_storage_type(&mut self, value: gpu::StorageType) -> Elem<D> {
1557        match value {
1558            gpu::StorageType::Scalar(ty) => self.compile_elem(ty),
1559            gpu::StorageType::Atomic(ty) => Elem::Atomic(ty.into()),
1560            gpu::StorageType::Packed(gpu::ElemType::Float(kind), 2) => match kind {
1561                FloatKind::E2M1 => {
1562                    self.flags.elem_fp4 = true;
1563                    Elem::FP4x2(FP4Kind::E2M1)
1564                }
1565                FloatKind::E2M3 => {
1566                    self.flags.elem_fp6 = true;
1567                    Elem::FP6x2(FP6Kind::E2M3)
1568                }
1569                FloatKind::E3M2 => {
1570                    self.flags.elem_fp6 = true;
1571                    Elem::FP6(FP6Kind::E3M2)
1572                }
1573                FloatKind::E4M3 => {
1574                    self.flags.elem_fp8 = true;
1575                    Elem::FP8x2(FP8Kind::E4M3)
1576                }
1577                FloatKind::E5M2 => {
1578                    self.flags.elem_fp8 = true;
1579                    Elem::FP8x2(FP8Kind::E5M2)
1580                }
1581                FloatKind::UE8M0 => {
1582                    self.flags.elem_fp8 = true;
1583                    Elem::FP8x2(FP8Kind::UE8M0)
1584                }
1585                FloatKind::F16 => {
1586                    self.flags.elem_f16 = true;
1587                    Elem::F16x2
1588                }
1589                FloatKind::BF16 => {
1590                    self.flags.elem_bf16 = true;
1591                    Elem::BF16x2
1592                }
1593                other => unimplemented!("Unsupported storage type: packed<{other:?}, 2>"),
1594            },
1595            other => unimplemented!("Unsupported storage type: {other}"),
1596        }
1597    }
1598
1599    fn compile_elem(&mut self, value: gpu::ElemType) -> Elem<D> {
1600        match value {
1601            gpu::ElemType::Float(kind) => match kind {
1602                gpu::FloatKind::E2M1 => {
1603                    self.flags.elem_fp4 = true;
1604                    Elem::FP4(FP4Kind::E2M1)
1605                }
1606                gpu::FloatKind::E2M3 => {
1607                    self.flags.elem_fp6 = true;
1608                    Elem::FP6(FP6Kind::E2M3)
1609                }
1610                gpu::FloatKind::E3M2 => {
1611                    self.flags.elem_fp6 = true;
1612                    Elem::FP6(FP6Kind::E3M2)
1613                }
1614                gpu::FloatKind::E4M3 => {
1615                    self.flags.elem_fp8 = true;
1616                    Elem::FP8(FP8Kind::E4M3)
1617                }
1618                gpu::FloatKind::E5M2 => {
1619                    self.flags.elem_fp8 = true;
1620                    Elem::FP8(FP8Kind::E5M2)
1621                }
1622                gpu::FloatKind::UE8M0 => {
1623                    self.flags.elem_fp8 = true;
1624                    Elem::FP8(FP8Kind::UE8M0)
1625                }
1626                gpu::FloatKind::F16 => {
1627                    self.flags.elem_f16 = true;
1628                    Elem::F16
1629                }
1630                gpu::FloatKind::BF16 => {
1631                    self.flags.elem_bf16 = true;
1632                    Elem::BF16
1633                }
1634                gpu::FloatKind::TF32 => Elem::TF32,
1635                gpu::FloatKind::Flex32 => Elem::F32,
1636                gpu::FloatKind::F32 => Elem::F32,
1637                gpu::FloatKind::F64 => Elem::F64,
1638            },
1639            gpu::ElemType::Int(kind) => match kind {
1640                gpu::IntKind::I8 => Elem::I8,
1641                gpu::IntKind::I16 => Elem::I16,
1642                gpu::IntKind::I32 => Elem::I32,
1643                gpu::IntKind::I64 => Elem::I64,
1644            },
1645            gpu::ElemType::UInt(kind) => match kind {
1646                gpu::UIntKind::U8 => Elem::U8,
1647                gpu::UIntKind::U16 => Elem::U16,
1648                gpu::UIntKind::U32 => Elem::U32,
1649                gpu::UIntKind::U64 => Elem::U64,
1650            },
1651            gpu::ElemType::Bool => Elem::Bool,
1652        }
1653    }
1654}
1655
1656fn is_fp4_fp6_fp8(elem: gpu::ElemType) -> bool {
1657    match elem {
1658        gpu::ElemType::Float(kind) => matches!(
1659            kind,
1660            FloatKind::E2M1
1661                | FloatKind::E2M3
1662                | FloatKind::E3M2
1663                | FloatKind::E4M3
1664                | FloatKind::E5M2
1665                | FloatKind::UE8M0
1666        ),
1667        _ => false,
1668    }
1669}
1670
1671fn const_u32<D: Dialect>(value: u32) -> Variable<D> {
1672    Variable::ConstantScalar(
1673        gpu::ConstantScalarValue::UInt(value as u64, UIntKind::U32),
1674        Elem::U32,
1675    )
1676}
1677
1678pub fn register_supported_types(props: &mut DeviceProperties) {
1679    let supported_types = [
1680        gpu::ElemType::UInt(gpu::UIntKind::U8),
1681        gpu::ElemType::UInt(gpu::UIntKind::U16),
1682        gpu::ElemType::UInt(gpu::UIntKind::U32),
1683        gpu::ElemType::UInt(gpu::UIntKind::U64),
1684        gpu::ElemType::Int(gpu::IntKind::I8),
1685        gpu::ElemType::Int(gpu::IntKind::I16),
1686        gpu::ElemType::Int(gpu::IntKind::I32),
1687        gpu::ElemType::Int(gpu::IntKind::I64),
1688        gpu::ElemType::Float(gpu::FloatKind::BF16),
1689        gpu::ElemType::Float(gpu::FloatKind::F16),
1690        gpu::ElemType::Float(gpu::FloatKind::F32),
1691        gpu::ElemType::Float(gpu::FloatKind::Flex32),
1692        // Causes CUDA_ERROR_INVALID_VALUE for matmul, disabling until that can be investigated
1693        //gpu::Elem::Float(gpu::FloatKind::F64),
1694        gpu::ElemType::Bool,
1695    ];
1696
1697    let supported_atomic_types = [
1698        gpu::ElemType::Int(gpu::IntKind::I32),
1699        gpu::ElemType::Int(gpu::IntKind::I64),
1700        gpu::ElemType::UInt(gpu::UIntKind::U32),
1701        gpu::ElemType::UInt(gpu::UIntKind::U64),
1702        gpu::ElemType::Float(gpu::FloatKind::F32),
1703    ];
1704
1705    for ty in supported_types {
1706        props.register_type_usage(ty, TypeUsage::all_scalar());
1707    }
1708
1709    for ty in supported_atomic_types {
1710        props.register_type_usage(
1711            gpu::StorageType::Atomic(ty),
1712            TypeUsage::AtomicAdd | TypeUsage::AtomicLoadStore,
1713        );
1714    }
1715}