// cubecl_cpp/shared/base.rs

1use std::{collections::HashSet, fmt::Debug, num::NonZero};
2
3use cubecl_common::ExecutionMode;
4use cubecl_core::ir::{ExpandElement, UIntKind, VariableKind};
5use cubecl_core::prelude::{FloatExpand, Line};
6use cubecl_core::{
7    Compiler, Feature,
8    ir::{self as gpu},
9};
10use cubecl_core::{CubeDim, io::read_tensor_checked};
11use cubecl_core::{
12    ir::{Operation, SourceLoc},
13    prelude::{FastMath, KernelDefinition, expand_checked_index_assign},
14};
15use cubecl_runtime::DeviceProperties;
16
17use super::barrier::BarrierOps;
18use super::pipeline::PipelineOps;
19use super::{
20    AtomicKind, BinaryInstruction, Binding, Body, ComputeKernel, ConstArray, Dialect, Elem,
21    Fragment, FragmentIdent, FragmentLayout, Instruction, Item, LocalArray, SharedMemory,
22    UnaryInstruction, Variable, WarpInstruction, WmmaInstruction,
23};
24
/// Process-wide counter, presumably used to give temporary variables unique
/// names during code generation — TODO confirm at usage sites.
/// Reset to 0 at the end of every `Compiler::compile` call.
pub(super) static COUNTER_TMP_VAR: std::sync::atomic::AtomicU32 =
    std::sync::atomic::AtomicU32::new(0);
27
/// Target-specific options that drive C++ code generation.
#[derive(Clone, Debug)]
pub struct CompilationOptions {
    /// Number of units in a plane (warp/wavefront) on the target. Defaults to 32.
    pub warp_size: u32,
    /// Whether kernel metadata is passed through grid constants. Defaults to `false`.
    pub grid_constants: bool,
    /// Whether the target supports thread-block clusters; when `false` any
    /// requested cluster dimension is discarded at compile time. Defaults to `false`.
    pub supports_clusters: bool,
}
34
35impl Default for CompilationOptions {
36    fn default() -> Self {
37        Self {
38            warp_size: 32,
39            grid_constants: false,
40            supports_clusters: false,
41        }
42    }
43}
44
/// Cube indexes flags.
///
/// When a field is `true`, the corresponding builtin index is declared and
/// computed as needed in the generated kernel. Set while translating the IR
/// (see `compile_instruction`) and post-processed by `Dialect::builtin_rules`.
#[derive(Debug, Clone, Default)]
pub struct CubeIndexFlags {
    pub absolute_pos: bool,
    pub absolute_pos_tuple: bool,
    pub cube_count: bool,
    pub cube_count_tuple: bool,
    pub cube_dim: bool,
    pub cube_dim_tuple: bool,
    pub cube_pos: bool,
    pub cube_pos_tuple: bool,
    pub plane_dim: bool,
    // Required by every plane (warp) operation; see `compile_instruction`.
    pub plane_dim_checked: bool,
    pub plane_index: bool,
    pub unit_pos: bool,
    pub unit_pos_tuple: bool,
    // Required by the inclusive/exclusive plane scan operations.
    pub unit_pos_plane: bool,
    pub cluster_pos: bool,
}
65
/// Flags gathered during Cube IR translation for the kernel compilation.
///
/// Filled incrementally while compiling instructions, then finalized in
/// `compile_ir` before being attached to the [`ComputeKernel`].
#[derive(Debug, Clone, Default)]
pub struct Flags {
    /// The kernel uses bf16 elements.
    pub elem_bf16: bool,
    /// The kernel uses f16 elements.
    pub elem_f16: bool,
    /// Builtin indexes the kernel needs declared.
    pub indexes: CubeIndexFlags,
    /// The kernel uses barrier operations.
    pub op_barrier: bool,
    /// The kernel uses pipeline operations.
    pub op_pipeline: bool,
    /// Derived from `FastMath::ReducedPrecision` in the kernel options.
    pub inst_fast_math: bool,
    /// The kernel uses TMA (tensor memory accelerator) instructions.
    pub inst_tma: bool,
    /// The kernel uses the TMA im2col load variant.
    pub inst_tma_im2col: bool,
    /// The kernel uses WMMA (matrix multiply-accumulate) instructions.
    pub inst_wmma: bool,
    /// Mirrors `CompilationOptions::grid_constants`.
    pub use_grid_constants: bool,
    /// Length of the static part of the metadata buffer.
    pub static_meta_length: usize,
    /// Whether dynamic metadata must be passed to the kernel.
    pub has_dynamic_meta: bool,
    /// Cluster dimensions, when clusters are enabled and supported.
    pub cluster_dim: Option<CubeDim>,
}
83
84#[allow(clippy::too_many_arguments)]
85#[derive(Clone, Debug, Default)]
86pub struct CppCompiler<D: Dialect> {
87    barriers: Vec<BarrierOps<D>>,
88    compilation_options: CompilationOptions,
89    const_arrays: Vec<ConstArray<D>>,
90    ext_meta_positions: Vec<u32>,
91    cluster_dim: CubeDim,
92    extensions: Vec<D::Extension>,
93    flags: Flags,
94    items: HashSet<Item<D>>,
95    local_arrays: Vec<LocalArray<D>>,
96    metadata: cubecl_core::Metadata,
97    pipelines: Vec<PipelineOps<D>>,
98    shared_memories: Vec<SharedMemory<D>>,
99    source_loc: Option<SourceLoc>,
100    strategy: ExecutionMode,
101}
102
impl<D: Dialect> Compiler for CppCompiler<D> {
    type Representation = ComputeKernel<D>;
    type CompilationOptions = CompilationOptions;

    /// Compiles a kernel definition into its C++ representation.
    ///
    /// Translation runs on a clone of `self`, so per-kernel state does not
    /// leak into subsequent `compile` calls; only the configuration stored
    /// here (options, strategy, cluster dim) persists.
    fn compile(
        &mut self,
        mut kernel: KernelDefinition,
        compilation_options: &Self::CompilationOptions,
        strategy: ExecutionMode,
    ) -> Self::Representation {
        self.compilation_options = compilation_options.clone();
        self.strategy = strategy;

        // Drop any requested cluster configuration when the target cannot use it.
        if !self.compilation_options.supports_clusters {
            kernel.options.cluster_dim = None;
        }
        self.cluster_dim = kernel.options.cluster_dim.unwrap_or(CubeDim::new_single());

        let ir = self.clone().compile_ir(kernel);
        // Reset the global temp-variable counter only after translation, so ids
        // stay unique within a single kernel.
        COUNTER_TMP_VAR.store(0, std::sync::atomic::Ordering::Relaxed);
        ir
    }

    /// Size in bytes of one element of `elem`.
    fn elem_size(&self, elem: gpu::Elem) -> usize {
        elem.size()
    }

    /// File extension used for the generated source.
    fn extension(&self) -> &'static str {
        "cpp"
    }
}
134
135impl<D: Dialect> CppCompiler<D> {
    /// Consumes the compiler and the kernel definition, producing the final
    /// [`ComputeKernel`].
    ///
    /// Ordering matters: the body must be compiled first, since translating
    /// instructions populates `self.flags`, shared memories, pipelines, etc.
    fn compile_ir(mut self, mut value: KernelDefinition) -> ComputeKernel<D> {
        self.build_metadata(&value);

        let instructions = self.compile_scope(&mut value.body);
        let buffers = value
            .buffers
            .into_iter()
            .map(|b| self.compile_binding(b))
            .collect();
        let scalars = value
            .scalars
            .into_iter()
            .map(|binding| (self.compile_elem(binding.elem), binding.count))
            .collect();

        // translation flags
        let flags = Flags {
            // Let the dialect expand/adjust which builtin indexes are required.
            indexes: D::builtin_rules(&self.flags.indexes),
            inst_wmma: self.flags.inst_wmma,
            op_pipeline: self.flags.op_pipeline,
            op_barrier: self.flags.op_barrier,
            elem_bf16: self.flags.elem_bf16,
            elem_f16: self.flags.elem_f16,
            inst_fast_math: value
                .options
                .fp_math_mode
                .contains(FastMath::ReducedPrecision),
            inst_tma: self.flags.inst_tma,
            inst_tma_im2col: self.flags.inst_tma_im2col,
            use_grid_constants: self.compilation_options.grid_constants,
            // TODO: At some point we should only pass dynamic meta if tensors are present,
            // not if only arrays are present. For now, leave like this
            has_dynamic_meta: self.metadata.static_len() > 0,
            static_meta_length: self.metadata.static_len() as usize,
            cluster_dim: value.options.cluster_dim,
        };

        let body = Body {
            instructions,
            shared_memories: self.shared_memories,
            pipelines: self.pipelines,
            barriers: self.barriers,
            const_arrays: self.const_arrays,
            local_arrays: self.local_arrays,
        };

        // NOTE(review): `compile` already clears `kernel.options.cluster_dim`
        // when clusters are unsupported, so this re-check is a defensive repeat
        // that keeps `compile_ir` self-contained.
        let mut cluster_dim = value.options.cluster_dim;
        if !self.compilation_options.supports_clusters {
            cluster_dim = None;
        }

        ComputeKernel {
            tensor_maps: value.tensor_maps,
            buffers,
            scalars,
            meta_static_len: self.metadata.static_len() as usize,
            cube_dim: value.cube_dim,
            body,
            extensions: self.extensions,
            flags,
            items: self.items,
            kernel_name: value.options.kernel_name,
            cluster_dim,
        }
    }
201
202    fn build_metadata(&mut self, value: &KernelDefinition) {
203        let mut num_ext = 0;
204
205        let mut all_meta: Vec<_> = value
206            .buffers
207            .iter()
208            .map(|buf| (buf.id, buf.has_extended_meta))
209            .chain(value.tensor_maps.iter().map(|i| (*i, true)))
210            .collect();
211
212        all_meta.sort_by_key(|(id, _)| *id);
213
214        for (_, has_extended_meta) in &all_meta {
215            self.ext_meta_positions.push(num_ext);
216            if *has_extended_meta {
217                num_ext += 1;
218            }
219        }
220
221        let num_meta = all_meta.len();
222
223        self.metadata = cubecl_core::Metadata::new(num_meta as u32, num_ext);
224    }
225
226    pub(crate) fn ext_meta_position(&self, var: gpu::Variable) -> u32 {
227        let id = var.index().expect("Variable should have index");
228        self.ext_meta_positions[id as usize]
229    }
230
231    fn compile_scope(&mut self, scope: &mut gpu::Scope) -> Vec<Instruction<D>> {
232        let mut instructions = Vec::new();
233
234        let const_arrays = scope
235            .const_arrays
236            .drain(..)
237            .map(|(var, values)| ConstArray {
238                index: var.index().unwrap(),
239                item: self.compile_item(var.item),
240                size: values.len() as u32,
241                values: values
242                    .into_iter()
243                    .map(|val| self.compile_variable(val))
244                    .collect(),
245            })
246            .collect::<Vec<_>>();
247        self.const_arrays.extend(const_arrays);
248
249        let processing = scope.process();
250
251        for var in processing.variables {
252            if let gpu::VariableKind::Slice { .. } = var.kind {
253                continue;
254            }
255            instructions.push(Instruction::DeclareVariable {
256                var: self.compile_variable(var),
257            });
258        }
259
260        processing
261            .instructions
262            .into_iter()
263            .for_each(|op| self.compile_instruction(&mut instructions, op, scope));
264
265        instructions
266    }
267
    /// Translates one Cube IR instruction into dialect instructions, appending
    /// the result to `instructions`.
    ///
    /// Also records side effects on `self.flags` (required builtin indexes,
    /// TMA/pipeline/barrier usage) so the kernel prelude can be generated.
    fn compile_instruction(
        &mut self,
        instructions: &mut Vec<Instruction<D>>,
        instruction: gpu::Instruction,
        scope: &mut gpu::Scope,
    ) {
        // Emit a `Line` marker when the source location changes (debug info).
        self.update_debug_loc(instructions, &instruction);
        let out = instruction.out;
        match instruction.operation {
            gpu::Operation::Copy(variable) => {
                instructions.push(Instruction::Assign(UnaryInstruction {
                    input: self.compile_variable(variable),
                    out: self.compile_variable(out.unwrap()),
                }));
            }
            gpu::Operation::Arithmetic(op) => self.compile_arithmetic(op, out, instructions),
            gpu::Operation::Comparison(op) => self.compile_comparison(op, out, instructions),
            gpu::Operation::Bitwise(op) => self.compile_bitwise(op, out, instructions),
            gpu::Operation::Operator(op) => self.compile_operator(op, out, instructions, scope),
            gpu::Operation::Atomic(op) => self.compile_atomic(op, out, instructions),
            gpu::Operation::Metadata(op) => instructions.push(self.compile_metadata(op, out)),
            gpu::Operation::Branch(val) => self.compile_branch(instructions, val),
            gpu::Operation::Synchronization(val) => match val {
                // Unit- and storage-level syncs both lower to a thread barrier.
                gpu::Synchronization::SyncUnits => instructions.push(Instruction::SyncThreads),
                gpu::Synchronization::SyncStorage => instructions.push(Instruction::SyncThreads),
                gpu::Synchronization::SyncProxyShared => {
                    self.flags.inst_tma = true;
                    instructions.push(Instruction::ProxySharedFence)
                }
            },
            gpu::Operation::Plane(op) => {
                // Every plane (warp) op needs the checked plane dimension.
                self.flags.indexes.plane_dim_checked = true;
                let out = self.compile_variable(out.unwrap());
                match op {
                    gpu::Plane::Sum(op) => {
                        let instruction = WarpInstruction::ReduceSum {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction));
                    }
                    // Scan operations additionally need the unit's position
                    // within its plane.
                    gpu::Plane::InclusiveSum(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::InclusiveSum {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::InclusiveProd(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::InclusiveProd {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::ExclusiveSum(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::ExclusiveSum {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::ExclusiveProd(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::ExclusiveProd {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Prod(op) => {
                        let instruction = WarpInstruction::ReduceProd {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Max(op) => {
                        let instruction = WarpInstruction::ReduceMax {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Min(op) => {
                        let instruction = WarpInstruction::ReduceMin {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Elect => {
                        instructions.push(Instruction::Warp(WarpInstruction::Elect { out }))
                    }
                    gpu::Plane::All(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::All {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Any(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Any {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Ballot(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Ballot {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Broadcast(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Broadcast {
                            input: self.compile_variable(op.lhs),
                            id: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                }
            }
            gpu::Operation::CoopMma(cmma) => instructions.push(self.compile_cmma(cmma, out)),
            gpu::Operation::NonSemantic(debug) => match debug {
                gpu::NonSemantic::Print {
                    format_string,
                    args,
                } => instructions.push(Instruction::Printf {
                    format_string,
                    args: args
                        .into_iter()
                        .map(|arg| self.compile_variable(arg))
                        .collect(),
                }),
                gpu::NonSemantic::Comment { content } => {
                    instructions.push(Instruction::Comment { content })
                }
                // Don't need to handle scopes
                _ => {}
            },
            gpu::Operation::Pipeline(pipeline_ops) => match pipeline_ops {
                gpu::PipelineOps::MemCopyAsync {
                    pipeline,
                    source,
                    destination,
                } => {
                    instructions.push(Instruction::Pipeline(
                        super::pipeline::PipelineOps::MemCopyAsync {
                            pipeline: self.compile_variable(pipeline),
                            source: self.compile_variable(source),
                            destination: self.compile_variable(destination),
                        },
                    ));
                }
                gpu::PipelineOps::ProducerAcquire { pipeline } => instructions.push(
                    Instruction::Pipeline(super::pipeline::PipelineOps::ProducerAcquire {
                        pipeline: self.compile_variable(pipeline),
                    }),
                ),
                gpu::PipelineOps::ProducerCommit { pipeline } => instructions.push(
                    Instruction::Pipeline(super::pipeline::PipelineOps::ProducerCommit {
                        pipeline: self.compile_variable(pipeline),
                    }),
                ),

                gpu::PipelineOps::ConsumerWait { pipeline } => instructions.push(
                    Instruction::Pipeline(super::pipeline::PipelineOps::ConsumerWait {
                        pipeline: self.compile_variable(pipeline),
                    }),
                ),

                gpu::PipelineOps::ConsumerRelease { pipeline } => instructions.push(
                    Instruction::Pipeline(super::pipeline::PipelineOps::ConsumerRelease {
                        pipeline: self.compile_variable(pipeline),
                    }),
                ),
            },
            gpu::Operation::Barrier(barrier_ops) => match barrier_ops {
                gpu::BarrierOps::Init {
                    barrier,
                    with_cta_fence,
                } => {
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    let barrier = self.compile_variable(barrier);
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Init {
                        barrier,
                        level,
                        with_cta_fence,
                    }));
                }
                gpu::BarrierOps::MemCopyAsync { barrier, source } => {
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsync {
                            barrier: self.compile_variable(barrier),
                            source: self.compile_variable(source),
                            destination: self.compile_variable(out.unwrap()),
                            level,
                        },
                    ));
                }
                // NOTE(review): unlike `TmaLoadIm2col` below, this arm sets no
                // TMA flag — presumably `inst_tma` is raised elsewhere (e.g.
                // when compiling the tensor map variable); verify.
                gpu::BarrierOps::TmaLoad {
                    barrier,
                    tensor_map,
                    indices,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsyncTensorGlobalToShared {
                            barrier: self.compile_variable(barrier),
                            smem_buffer: self.compile_variable(out.unwrap()),
                            tensor_map: self.compile_variable(tensor_map),
                            indices: indices
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        },
                    ));
                }
                gpu::BarrierOps::TmaLoadIm2col {
                    barrier,
                    tensor_map,
                    indices,
                    offsets,
                } => {
                    self.flags.inst_tma_im2col = true;
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::TmaLoadIm2col {
                            barrier: self.compile_variable(barrier),
                            smem_buffer: self.compile_variable(out.unwrap()),
                            tensor_map: self.compile_variable(tensor_map),
                            indices: indices
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                            offsets: offsets
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        },
                    ));
                }
                gpu::BarrierOps::Arrive { barrier } => {
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Arrive {
                        barrier: self.compile_variable(barrier),
                        level,
                    }))
                }
                gpu::BarrierOps::ArriveTx {
                    barrier,
                    arrive_count_update,
                    transaction_count_update,
                } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::ArriveTx {
                        barrier: self.compile_variable(barrier),
                        arrive_count_update: self.compile_variable(arrive_count_update),
                        transaction_count_update: self.compile_variable(transaction_count_update),
                    }))
                }
                gpu::BarrierOps::ExpectTx {
                    barrier,
                    transaction_count_update,
                } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::ExpectTx {
                        barrier: self.compile_variable(barrier),
                        transaction_count_update: self.compile_variable(transaction_count_update),
                    }))
                }
                gpu::BarrierOps::Wait { barrier } => {
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Wait {
                        barrier: self.compile_variable(barrier),
                        level,
                    }))
                }
                gpu::BarrierOps::ArriveAndWait { barrier } => {
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::ArriveAndWait {
                            barrier: self.compile_variable(barrier),
                            level,
                        },
                    ))
                }
            },
            gpu::Operation::Tma(tma_ops) => {
                // Any TMA operation requires the TMA feature in the prelude.
                self.flags.inst_tma = true;
                match tma_ops {
                    gpu::TmaOps::TmaStore {
                        source,
                        coordinates,
                    } => {
                        instructions.push(Instruction::MemCopyAsyncTensorSharedToGlobal {
                            smem_buffer: self.compile_variable(source),
                            tensor_map: self.compile_variable(out.unwrap()),
                            indices: coordinates
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        });
                    }
                    gpu::TmaOps::CommitGroup => {
                        instructions.push(Instruction::BulkCommitGroup);
                    }
                    gpu::TmaOps::WaitGroup { max_pending } => {
                        instructions.push(Instruction::BulkWaitGroup { max_pending });
                    }
                    gpu::TmaOps::WaitGroupRead { max_pending } => {
                        instructions.push(Instruction::BulkWaitGroupRead { max_pending });
                    }
                }
            }
        }
    }
594
595    fn update_debug_loc(
596        &mut self,
597        instructions: &mut Vec<Instruction<D>>,
598        inst: &gpu::Instruction,
599    ) {
600        if !matches!(inst.operation, Operation::NonSemantic(_)) {
601            match &inst.source_loc {
602                Some(loc) if Some(loc) != self.source_loc.as_ref() => {
603                    self.source_loc = Some(loc.clone());
604                    instructions.push(Instruction::Line {
605                        file: loc.source.file.clone(),
606                        line: loc.line,
607                    });
608                }
609                _ => {}
610            }
611        }
612    }
613
614    fn compile_cmma(&mut self, cmma: gpu::CoopMma, out: Option<gpu::Variable>) -> Instruction<D> {
615        let out = self.compile_variable(out.unwrap());
616        match cmma {
617            gpu::CoopMma::Fill { value } => Instruction::Wmma(WmmaInstruction::Fill {
618                frag: out,
619                value: self.compile_variable(value),
620            }),
621            gpu::CoopMma::Load {
622                value,
623                stride,
624                layout,
625            } => Instruction::Wmma(WmmaInstruction::Load {
626                frag: out,
627                value: self.compile_variable(value),
628                stride: self.compile_variable(stride),
629                layout: layout.and_then(|l| self.compile_matrix_layout(l)),
630            }),
631            gpu::CoopMma::Execute {
632                mat_a,
633                mat_b,
634                mat_c,
635            } => Instruction::Wmma(WmmaInstruction::Execute {
636                frag_a: self.compile_variable(mat_a),
637                frag_b: self.compile_variable(mat_b),
638                frag_c: self.compile_variable(mat_c),
639                frag_d: out,
640                warp_size: self.compilation_options.warp_size,
641            }),
642            gpu::CoopMma::Store {
643                mat,
644                stride,
645                layout,
646            } => {
647                self.flags.indexes.unit_pos = true;
648                self.flags.indexes.plane_index = true;
649                Instruction::Wmma(WmmaInstruction::Store {
650                    output: out,
651                    frag: self.compile_variable(mat),
652                    stride: self.compile_variable(stride),
653                    layout: self
654                        .compile_matrix_layout(layout)
655                        .expect("Layout required for store instruction"),
656                })
657            }
658            gpu::CoopMma::Cast { input } => Instruction::Wmma(WmmaInstruction::Cast {
659                input: self.compile_variable(input),
660                output: out,
661            }),
662        }
663    }
664
665    fn compile_metadata(
666        &mut self,
667        metadata: gpu::Metadata,
668        out: Option<gpu::Variable>,
669    ) -> Instruction<D> {
670        let out = out.unwrap();
671        match metadata {
672            gpu::Metadata::Stride { dim, var } => {
673                let position = self.ext_meta_position(var);
674                let offset = self.metadata.stride_offset_index(position);
675                Instruction::ExtendedMetadata {
676                    info_offset: self.compile_variable(offset.into()),
677                    dim: self.compile_variable(dim),
678                    split_meta: self.compilation_options.grid_constants,
679                    static_offset: self.metadata.static_len(),
680                    out: self.compile_variable(out),
681                }
682            }
683            gpu::Metadata::Shape { dim, var } => {
684                let position = self.ext_meta_position(var);
685                let offset = self.metadata.shape_offset_index(position);
686                Instruction::ExtendedMetadata {
687                    info_offset: self.compile_variable(offset.into()),
688                    dim: self.compile_variable(dim),
689                    split_meta: self.compilation_options.grid_constants,
690                    static_offset: self.metadata.static_len(),
691                    out: self.compile_variable(out),
692                }
693            }
694            gpu::Metadata::Rank { var } => {
695                let out = self.compile_variable(out);
696                let pos = self.ext_meta_position(var);
697                let offset = self.metadata.rank_index(pos);
698                super::Instruction::Metadata {
699                    info_offset: self.compile_variable(offset.into()),
700                    split_meta: self.compilation_options.grid_constants,
701                    out,
702                }
703            }
704            gpu::Metadata::Length { var } => {
705                let input = self.compile_variable(var);
706                let out = self.compile_variable(out);
707
708                match input {
709                    Variable::Slice { .. } => Instruction::SliceLength { input, out },
710                    Variable::SharedMemory(_id, _item, length) => {
711                        Instruction::ConstLength { length, out }
712                    }
713                    _ => {
714                        let id = input.id().expect("Variable should have id");
715                        let offset = self.metadata.len_index(id);
716                        Instruction::Metadata {
717                            info_offset: self.compile_variable(offset.into()),
718                            split_meta: self.compilation_options.grid_constants,
719                            out,
720                        }
721                    }
722                }
723            }
724            gpu::Metadata::BufferLength { var } => {
725                let input = self.compile_variable(var);
726                let out = self.compile_variable(out);
727
728                match input {
729                    Variable::Slice { .. } => Instruction::SliceLength { input, out },
730                    _ => {
731                        let id = input.id().expect("Variable should have id");
732                        let offset = self.metadata.buffer_len_index(id);
733                        Instruction::Metadata {
734                            info_offset: self.compile_variable(offset.into()),
735                            split_meta: self.compilation_options.grid_constants,
736                            out,
737                        }
738                    }
739                }
740            }
741        }
742    }
743
744    fn compile_branch(&mut self, instructions: &mut Vec<Instruction<D>>, branch: gpu::Branch) {
745        match branch {
746            gpu::Branch::If(mut op) => instructions.push(Instruction::If {
747                cond: self.compile_variable(op.cond),
748                instructions: self.compile_scope(&mut op.scope),
749            }),
750            gpu::Branch::IfElse(mut op) => instructions.push(Instruction::IfElse {
751                cond: self.compile_variable(op.cond),
752                instructions_if: self.compile_scope(&mut op.scope_if),
753                instructions_else: self.compile_scope(&mut op.scope_else),
754            }),
755            gpu::Branch::Switch(mut op) => instructions.push(Instruction::Switch {
756                value: self.compile_variable(op.value),
757                instructions_default: self.compile_scope(&mut op.scope_default),
758                instructions_cases: op
759                    .cases
760                    .into_iter()
761                    .map(|(val, mut block)| {
762                        (self.compile_variable(val), self.compile_scope(&mut block))
763                    })
764                    .collect(),
765            }),
766            gpu::Branch::Return => instructions.push(Instruction::Return),
767            gpu::Branch::Break => instructions.push(Instruction::Break),
768            gpu::Branch::RangeLoop(mut range_loop) => instructions.push(Instruction::RangeLoop {
769                i: self.compile_variable(range_loop.i),
770                start: self.compile_variable(range_loop.start),
771                end: self.compile_variable(range_loop.end),
772                step: range_loop.step.map(|it| self.compile_variable(it)),
773                inclusive: range_loop.inclusive,
774                instructions: self.compile_scope(&mut range_loop.scope),
775            }),
776            gpu::Branch::Loop(mut op) => instructions.push(Instruction::Loop {
777                instructions: self.compile_scope(&mut op.scope),
778            }),
779        };
780    }
781
782    fn compile_atomic(
783        &mut self,
784        value: gpu::AtomicOp,
785        out: Option<gpu::Variable>,
786        instructions: &mut Vec<Instruction<D>>,
787    ) {
788        let out = out.unwrap();
789        match value {
790            gpu::AtomicOp::Load(op) => {
791                instructions.push(Instruction::AtomicLoad(self.compile_unary(op, out)))
792            }
793            gpu::AtomicOp::Store(op) => {
794                instructions.push(Instruction::AtomicStore(self.compile_unary(op, out)))
795            }
796            gpu::AtomicOp::Swap(op) => {
797                instructions.push(Instruction::AtomicSwap(self.compile_binary(op, out)))
798            }
799            gpu::AtomicOp::Add(op) => {
800                instructions.push(Instruction::AtomicAdd(self.compile_binary(op, out)))
801            }
802            gpu::AtomicOp::Sub(op) => {
803                instructions.push(Instruction::AtomicSub(self.compile_binary(op, out)))
804            }
805            gpu::AtomicOp::Max(op) => {
806                instructions.push(Instruction::AtomicMax(self.compile_binary(op, out)))
807            }
808            gpu::AtomicOp::Min(op) => {
809                instructions.push(Instruction::AtomicMin(self.compile_binary(op, out)))
810            }
811            gpu::AtomicOp::And(op) => {
812                instructions.push(Instruction::AtomicAnd(self.compile_binary(op, out)))
813            }
814            gpu::AtomicOp::Or(op) => {
815                instructions.push(Instruction::AtomicOr(self.compile_binary(op, out)))
816            }
817            gpu::AtomicOp::Xor(op) => {
818                instructions.push(Instruction::AtomicXor(self.compile_binary(op, out)))
819            }
820            gpu::AtomicOp::CompareAndSwap(op) => instructions.push(Instruction::AtomicCAS {
821                input: self.compile_variable(op.input),
822                cmp: self.compile_variable(op.cmp),
823                val: self.compile_variable(op.val),
824                out: self.compile_variable(out),
825            }),
826        }
827    }
828
829    fn compile_arithmetic(
830        &mut self,
831        value: gpu::Arithmetic,
832        out: Option<gpu::Variable>,
833        instructions: &mut Vec<Instruction<D>>,
834    ) {
835        let out = out.unwrap();
836        match value {
837            gpu::Arithmetic::Add(op) => {
838                instructions.push(Instruction::Add(self.compile_binary(op, out)))
839            }
840            gpu::Arithmetic::Mul(op) => {
841                instructions.push(Instruction::Mul(self.compile_binary(op, out)))
842            }
843            gpu::Arithmetic::Div(op) => {
844                instructions.push(Instruction::Div(self.compile_binary(op, out)))
845            }
846            gpu::Arithmetic::Sub(op) => {
847                instructions.push(Instruction::Sub(self.compile_binary(op, out)))
848            }
849            gpu::Arithmetic::MulHi(op) => {
850                let instruction = Instruction::HiMul(self.compile_binary(op, out));
851                D::register_instruction_extension(&mut self.extensions, &instruction);
852                instructions.push(instruction)
853            }
854            gpu::Arithmetic::Modulo(op) => {
855                instructions.push(Instruction::Modulo(self.compile_binary(op, out)))
856            }
857            gpu::Arithmetic::Abs(op) => {
858                instructions.push(Instruction::Abs(self.compile_unary(op, out)))
859            }
860            gpu::Arithmetic::Exp(op) => {
861                instructions.push(Instruction::Exp(self.compile_unary(op, out)))
862            }
863            gpu::Arithmetic::Log(op) => {
864                instructions.push(Instruction::Log(self.compile_unary(op, out)))
865            }
866            gpu::Arithmetic::Log1p(op) => {
867                instructions.push(Instruction::Log1p(self.compile_unary(op, out)))
868            }
869            gpu::Arithmetic::Cos(op) => {
870                instructions.push(Instruction::Cos(self.compile_unary(op, out)))
871            }
872            gpu::Arithmetic::Sin(op) => {
873                instructions.push(Instruction::Sin(self.compile_unary(op, out)))
874            }
875            gpu::Arithmetic::Tanh(op) => {
876                let instruction = Instruction::Tanh(self.compile_unary(op, out));
877                D::register_instruction_extension(&mut self.extensions, &instruction);
878                instructions.push(instruction)
879            }
880            gpu::Arithmetic::Powf(op) => {
881                instructions.push(Instruction::Powf(self.compile_binary(op, out)))
882            }
883            gpu::Arithmetic::Sqrt(op) => {
884                instructions.push(Instruction::Sqrt(self.compile_unary(op, out)))
885            }
886            gpu::Arithmetic::Erf(op) => {
887                let instruction = Instruction::Erf(self.compile_unary(op, out));
888                D::register_instruction_extension(&mut self.extensions, &instruction);
889                instructions.push(instruction)
890            }
891            gpu::Arithmetic::Max(op) => {
892                let instruction = Instruction::Max(self.compile_binary(op, out));
893                D::register_instruction_extension(&mut self.extensions, &instruction);
894                instructions.push(instruction)
895            }
896            gpu::Arithmetic::Min(op) => {
897                let instruction = Instruction::Min(self.compile_binary(op, out));
898                D::register_instruction_extension(&mut self.extensions, &instruction);
899                instructions.push(instruction)
900            }
901            gpu::Arithmetic::Clamp(op) => instructions.push(Instruction::Clamp {
902                input: self.compile_variable(op.input),
903                min_value: self.compile_variable(op.min_value),
904                max_value: self.compile_variable(op.max_value),
905                out: self.compile_variable(out),
906            }),
907            gpu::Arithmetic::Recip(op) => {
908                let elem = op.input.item.elem();
909                let lhs = match elem {
910                    gpu::Elem::Float(kind) => gpu::ConstantScalarValue::Float(1.0, kind),
911                    gpu::Elem::Int(kind) => gpu::ConstantScalarValue::Int(1, kind),
912                    gpu::Elem::UInt(kind) => gpu::ConstantScalarValue::UInt(1, kind),
913                    gpu::Elem::Bool => gpu::ConstantScalarValue::Bool(true),
914                    gpu::Elem::AtomicInt(_)
915                    | gpu::Elem::AtomicUInt(_)
916                    | gpu::Elem::AtomicFloat(_) => {
917                        panic!("Cannot use recip with atomics")
918                    }
919                };
920
921                instructions.push(Instruction::Div(BinaryInstruction {
922                    lhs: Variable::ConstantScalar(lhs, self.compile_elem(elem)),
923                    rhs: self.compile_variable(op.input),
924                    out: self.compile_variable(out),
925                }))
926            }
927            gpu::Arithmetic::Round(op) => {
928                instructions.push(Instruction::Round(self.compile_unary(op, out)))
929            }
930            gpu::Arithmetic::Floor(op) => {
931                instructions.push(Instruction::Floor(self.compile_unary(op, out)))
932            }
933            gpu::Arithmetic::Ceil(op) => {
934                instructions.push(Instruction::Ceil(self.compile_unary(op, out)))
935            }
936            gpu::Arithmetic::Remainder(op) => {
937                instructions.push(Instruction::Remainder(self.compile_binary(op, out)))
938            }
939            gpu::Arithmetic::Fma(op) => instructions.push(Instruction::Fma {
940                a: self.compile_variable(op.a),
941                b: self.compile_variable(op.b),
942                c: self.compile_variable(op.c),
943                out: self.compile_variable(out),
944            }),
945            gpu::Arithmetic::Neg(op) => {
946                instructions.push(Instruction::Neg(self.compile_unary(op, out)))
947            }
948            gpu::Arithmetic::Normalize(op) => {
949                instructions.push(Instruction::Normalize(self.compile_unary(op, out)))
950            }
951            gpu::Arithmetic::Magnitude(op) => {
952                instructions.push(Instruction::Magnitude(self.compile_unary(op, out)))
953            }
954            gpu::Arithmetic::Dot(op) => {
955                instructions.push(Instruction::Dot(self.compile_binary(op, out)))
956            }
957        };
958    }
959
960    fn compile_comparison(
961        &mut self,
962        value: gpu::Comparison,
963        out: Option<gpu::Variable>,
964        instructions: &mut Vec<Instruction<D>>,
965    ) {
966        let out = out.unwrap();
967        match value {
968            gpu::Comparison::Equal(op) => {
969                instructions.push(Instruction::Equal(self.compile_binary(op, out)))
970            }
971            gpu::Comparison::Lower(op) => {
972                instructions.push(Instruction::Lower(self.compile_binary(op, out)))
973            }
974            gpu::Comparison::Greater(op) => {
975                instructions.push(Instruction::Greater(self.compile_binary(op, out)))
976            }
977            gpu::Comparison::LowerEqual(op) => {
978                instructions.push(Instruction::LowerEqual(self.compile_binary(op, out)))
979            }
980            gpu::Comparison::GreaterEqual(op) => {
981                instructions.push(Instruction::GreaterEqual(self.compile_binary(op, out)))
982            }
983            gpu::Comparison::NotEqual(op) => {
984                instructions.push(Instruction::NotEqual(self.compile_binary(op, out)))
985            }
986        };
987    }
988
989    fn compile_bitwise(
990        &mut self,
991        value: gpu::Bitwise,
992        out: Option<gpu::Variable>,
993        instructions: &mut Vec<Instruction<D>>,
994    ) {
995        let out = out.unwrap();
996        match value {
997            gpu::Bitwise::BitwiseOr(op) => {
998                instructions.push(Instruction::BitwiseOr(self.compile_binary(op, out)))
999            }
1000            gpu::Bitwise::BitwiseAnd(op) => {
1001                instructions.push(Instruction::BitwiseAnd(self.compile_binary(op, out)))
1002            }
1003            gpu::Bitwise::BitwiseXor(op) => {
1004                instructions.push(Instruction::BitwiseXor(self.compile_binary(op, out)))
1005            }
1006            gpu::Bitwise::CountOnes(op) => {
1007                instructions.push(Instruction::CountBits(self.compile_unary(op, out)))
1008            }
1009            gpu::Bitwise::ReverseBits(op) => {
1010                instructions.push(Instruction::ReverseBits(self.compile_unary(op, out)))
1011            }
1012            gpu::Bitwise::ShiftLeft(op) => {
1013                instructions.push(Instruction::ShiftLeft(self.compile_binary(op, out)))
1014            }
1015            gpu::Bitwise::ShiftRight(op) => {
1016                instructions.push(Instruction::ShiftRight(self.compile_binary(op, out)))
1017            }
1018            gpu::Bitwise::BitwiseNot(op) => {
1019                instructions.push(Instruction::BitwiseNot(self.compile_unary(op, out)))
1020            }
1021            gpu::Bitwise::LeadingZeros(op) => {
1022                instructions.push(Instruction::LeadingZeros(self.compile_unary(op, out)))
1023            }
1024            gpu::Bitwise::FindFirstSet(op) => {
1025                let instruction = Instruction::FindFirstSet(self.compile_unary(op, out));
1026                D::register_instruction_extension(&mut self.extensions, &instruction);
1027                instructions.push(instruction)
1028            }
1029        };
1030    }
1031
    /// Lower a GPU IR operator into one or more dialect instructions.
    ///
    /// Unlike the other `compile_*` dispatchers, this one also receives the
    /// current `scope` because checked-mode lowering of `Slice`, `Index` and
    /// `IndexAssign` expands extra IR into the scope (bounds reads, length
    /// lookups) that must be compiled before the instruction itself.
    /// `out` is mandatory for operators; the IR always provides it.
    fn compile_operator(
        &mut self,
        value: gpu::Operator,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
        scope: &mut gpu::Scope,
    ) {
        let out = out.unwrap();
        match value {
            gpu::Operator::Slice(op) => {
                // Checked mode: materialize the input's length into a fresh
                // u32 local so the slice bounds can be clamped at runtime.
                if matches!(self.strategy, ExecutionMode::Checked) && op.input.has_length() {
                    let input = op.input;
                    let input_len = *scope
                        .create_local_mut(gpu::Item::new(gpu::Elem::UInt(gpu::UIntKind::U32)));
                    // Flush any IR already pending in the scope before
                    // appending the length lookup and the slice itself.
                    instructions.extend(self.compile_scope(scope));

                    // Prefer the buffer length when available; fall back to
                    // the logical length otherwise.
                    let length = match input.has_buffer_length() {
                        true => gpu::Metadata::BufferLength { var: input },
                        false => gpu::Metadata::Length { var: input },
                    };

                    instructions.push(self.compile_metadata(length, Some(input_len)));
                    instructions.push(Instruction::CheckedSlice {
                        input: self.compile_variable(op.input),
                        start: self.compile_variable(op.start),
                        end: self.compile_variable(op.end),
                        out: self.compile_variable(out),
                        len: self.compile_variable(input_len),
                    });
                } else {
                    instructions.push(Instruction::Slice {
                        input: self.compile_variable(op.input),
                        start: self.compile_variable(op.start),
                        end: self.compile_variable(op.end),
                        out: self.compile_variable(out),
                    })
                }
            }
            gpu::Operator::ReinterpretSlice(op) => {
                // TODO Do we need to add special behavior in checked mode?
                instructions.push(Instruction::ReinterpretSlice {
                    input: self.compile_variable(op.input),
                    line_size: op.line_size,
                    out: self.compile_variable(out),
                })
            }
            gpu::Operator::Index(op) => {
                // Checked mode (non-atomic outputs only): route the read
                // through `read_tensor_checked`, expanded in a child scope,
                // then assign its result to `out`.
                if matches!(self.strategy, ExecutionMode::Checked)
                    && op.lhs.has_length()
                    && !out.elem().is_atomic()
                {
                    let list = ExpandElement::Plain(op.lhs);
                    let index = ExpandElement::Plain(op.rhs);
                    // FloatExpand<0> acts as a type placeholder bound to the
                    // input's element type for the generic expansion below.
                    scope.register_elem::<FloatExpand<0>>(op.lhs.elem());

                    let mut child_scope = scope.child();
                    let input = read_tensor_checked::expand::<Line<FloatExpand<0>>>(
                        &mut child_scope,
                        list.into(),
                        index.into(),
                    );

                    // Compile the expanded bounds-checked read ahead of the
                    // final assignment.
                    for inst in self.compile_scope(&mut child_scope) {
                        instructions.push(inst);
                    }

                    instructions.push(Instruction::Assign(UnaryInstruction {
                        input: self.compile_variable(input.into_variable()),
                        out: self.compile_variable(out),
                    }))
                } else {
                    instructions.push(Instruction::Index(self.compile_binary(op, out)));
                }
            }
            gpu::Operator::UncheckedIndex(op) => {
                instructions.push(Instruction::Index(self.compile_binary(op, out)))
            }
            gpu::Operator::IndexAssign(op) => {
                // Checked mode: expand the guarded store into the scope and
                // compile that instead; the early return skips the plain
                // (unguarded) IndexAssign below.
                if let ExecutionMode::Checked = self.strategy {
                    if out.has_length() {
                        expand_checked_index_assign(scope, op.lhs, op.rhs, out);
                        instructions.extend(self.compile_scope(scope));
                        return;
                    }
                };
                instructions.push(Instruction::IndexAssign(self.compile_binary(op, out)));
            }
            gpu::Operator::UncheckedIndexAssign(op) => {
                instructions.push(Instruction::IndexAssign(self.compile_binary(op, out)))
            }
            gpu::Operator::And(op) => {
                instructions.push(Instruction::And(self.compile_binary(op, out)))
            }
            gpu::Operator::Or(op) => {
                instructions.push(Instruction::Or(self.compile_binary(op, out)))
            }
            gpu::Operator::Not(op) => {
                instructions.push(Instruction::Not(self.compile_unary(op, out)))
            }
            // Build a vector value from its per-lane inputs.
            gpu::Operator::InitLine(op) => instructions.push(Instruction::VecInit {
                inputs: op
                    .inputs
                    .into_iter()
                    .map(|it| self.compile_variable(it))
                    .collect(),
                out: self.compile_variable(out),
            }),
            gpu::Operator::CopyMemory(op) => instructions.push(Instruction::Copy {
                input: self.compile_variable(op.input),
                in_index: self.compile_variable(op.in_index),
                out: self.compile_variable(out),
                out_index: self.compile_variable(op.out_index),
            }),
            gpu::Operator::CopyMemoryBulk(op) => instructions.push(Instruction::CopyBulk {
                input: self.compile_variable(op.input),
                in_index: self.compile_variable(op.in_index),
                out: self.compile_variable(out),
                out_index: self.compile_variable(op.out_index),
                // Bulk copies require a compile-time-constant length.
                len: op.len.as_const().unwrap().as_u32(),
            }),
            gpu::Operator::Select(op) => instructions.push(Instruction::Select {
                cond: self.compile_variable(op.cond),
                then: self.compile_variable(op.then),
                or_else: self.compile_variable(op.or_else),
                out: self.compile_variable(out),
            }),
            // Value-converting cast lowers to a plain assignment; the dialect
            // handles conversion via the differing item types.
            gpu::Operator::Cast(op) => {
                instructions.push(Instruction::Assign(self.compile_unary(op, out)))
            }
            // Bit-preserving reinterpretation of the underlying storage.
            gpu::Operator::Reinterpret(op) => {
                instructions.push(Instruction::Bitcast(self.compile_unary(op, out)))
            }
        };
    }
1166
1167    fn compile_binary(
1168        &mut self,
1169        value: gpu::BinaryOperator,
1170        out: gpu::Variable,
1171    ) -> BinaryInstruction<D> {
1172        BinaryInstruction {
1173            lhs: self.compile_variable(value.lhs),
1174            rhs: self.compile_variable(value.rhs),
1175            out: self.compile_variable(out),
1176        }
1177    }
1178
1179    fn compile_unary(
1180        &mut self,
1181        value: gpu::UnaryOperator,
1182        out: gpu::Variable,
1183    ) -> UnaryInstruction<D> {
1184        UnaryInstruction {
1185            input: self.compile_variable(value.input),
1186            out: self.compile_variable(out),
1187        }
1188    }
1189
    /// Translate a GPU IR variable into the dialect's [`Variable`].
    ///
    /// Besides the 1:1 mapping, this method records side effects on `self`:
    /// it registers shared memories, local arrays and pipelines on first
    /// sight (deduplicated by id), and sets the feature/index flags
    /// (`self.flags`) that later drive what gets declared in the kernel
    /// prelude (see `CubeIndexFlags` at the top of the file).
    fn compile_variable(&mut self, value: gpu::Variable) -> Variable<D> {
        let item = value.item;
        match value.kind {
            gpu::VariableKind::GlobalInputArray(id) => {
                Variable::GlobalInputArray(id, self.compile_item(item))
            }
            gpu::VariableKind::GlobalScalar(id) => Variable::GlobalScalar {
                id,
                // With grid constants enabled, scalars live inside a struct
                // passed as a kernel parameter rather than as loose bindings.
                in_struct: self.compilation_options.grid_constants,
                elem: self.compile_elem(item.elem),
            },
            gpu::VariableKind::TensorMap(id) => {
                self.flags.inst_tma = true;
                Variable::TensorMap(id)
            }
            gpu::VariableKind::LocalMut { id } => Variable::LocalMut {
                id,
                item: self.compile_item(item),
            },
            // Versioned (SSA-style) locals are lowered as plain mutable
            // locals; the version is dropped here.
            gpu::VariableKind::Versioned { id, .. } => Variable::LocalMut {
                id,
                item: self.compile_item(item),
            },
            gpu::VariableKind::LocalConst { id } => Variable::LocalConst {
                id,
                item: self.compile_item(item),
            },
            gpu::VariableKind::Slice { id } => Variable::Slice {
                id,
                item: self.compile_item(item),
            },
            gpu::VariableKind::GlobalOutputArray(id) => {
                Variable::GlobalOutputArray(id, self.compile_item(item))
            }
            gpu::VariableKind::ConstantScalar(value) => {
                Variable::ConstantScalar(value, self.compile_elem(value.elem()))
            }
            gpu::VariableKind::SharedMemory {
                id,
                length,
                alignment,
            } => {
                let item = self.compile_item(item);
                // Register each shared memory once so it is declared a
                // single time in the generated kernel.
                if !self.shared_memories.iter().any(|s| s.index == id) {
                    self.shared_memories
                        .push(SharedMemory::new(id, item, length, alignment));
                }
                Variable::SharedMemory(id, item, length)
            }
            gpu::VariableKind::ConstantArray { id, length } => {
                let item = self.compile_item(item);
                Variable::ConstantArray(id, item, length)
            }
            // Builtins: each arm flips the flag that causes the matching
            // index computation to be emitted in the kernel prelude.
            gpu::VariableKind::Builtin(builtin) => match builtin {
                gpu::Builtin::AbsolutePos => {
                    self.flags.indexes.absolute_pos = true;
                    Variable::AbsolutePos
                }
                // Cluster builtins are only meaningful when the target
                // supports thread block clusters (guarded arms below).
                gpu::Builtin::CubePosCluster if self.compilation_options.supports_clusters => {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterRank
                }
                gpu::Builtin::CubePosClusterX if self.compilation_options.supports_clusters => {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexX
                }
                gpu::Builtin::CubePosClusterY if self.compilation_options.supports_clusters => {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexY
                }
                gpu::Builtin::CubePosClusterZ if self.compilation_options.supports_clusters => {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexZ
                }
                // Fallback if clusters aren't supported, ID is always 0 since clusters are always
                // (1, 1, 1) if unsupported
                gpu::Builtin::CubePosCluster
                | gpu::Builtin::CubePosClusterX
                | gpu::Builtin::CubePosClusterY
                | gpu::Builtin::CubePosClusterZ => const_u32(0),
                gpu::Builtin::AbsolutePosX => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosX
                }
                gpu::Builtin::AbsolutePosY => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosY
                }
                gpu::Builtin::AbsolutePosZ => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosZ
                }
                gpu::Builtin::CubeDim => {
                    self.flags.indexes.cube_dim = true;
                    Variable::CubeDim
                }
                gpu::Builtin::CubeDimX => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimX
                }
                gpu::Builtin::CubeDimY => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimY
                }
                gpu::Builtin::CubeDimZ => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimZ
                }
                // Cluster dimensions are compile-time constants taken from
                // `self.cluster_dim`, so they lower to constant u32 values.
                gpu::Builtin::CubeClusterDim => const_u32(self.cluster_dim.num_elems()),
                gpu::Builtin::CubeClusterDimX => const_u32(self.cluster_dim.x),
                gpu::Builtin::CubeClusterDimY => const_u32(self.cluster_dim.y),
                gpu::Builtin::CubeClusterDimZ => const_u32(self.cluster_dim.z),
                gpu::Builtin::CubePos => {
                    self.flags.indexes.cube_pos = true;
                    Variable::CubePos
                }
                gpu::Builtin::CubePosX => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosX
                }
                gpu::Builtin::CubePosY => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosY
                }
                gpu::Builtin::CubePosZ => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosZ
                }
                gpu::Builtin::CubeCount => {
                    self.flags.indexes.cube_count = true;
                    Variable::CubeCount
                }
                gpu::Builtin::CubeCountX => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountX
                }
                gpu::Builtin::CubeCountY => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountY
                }
                gpu::Builtin::CubeCountZ => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountZ
                }
                gpu::Builtin::UnitPos => {
                    self.flags.indexes.unit_pos = true;
                    Variable::UnitPos
                }
                gpu::Builtin::UnitPosX => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosX
                }
                gpu::Builtin::UnitPosY => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosY
                }
                gpu::Builtin::UnitPosZ => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosZ
                }
                gpu::Builtin::PlaneDim => {
                    self.flags.indexes.plane_dim = true;
                    Variable::PlaneDim
                }
                gpu::Builtin::UnitPosPlane => {
                    self.flags.indexes.unit_pos_plane = true;
                    Variable::UnitPosPlane
                }
            },
            gpu::VariableKind::LocalArray { id, length } => {
                let item = self.compile_item(item);
                // Register each local array once, like shared memories above.
                if !self.local_arrays.iter().any(|s| s.index == id) {
                    self.local_arrays.push(LocalArray::new(id, item, length));
                }
                Variable::LocalArray(id, item, length)
            }
            gpu::VariableKind::Matrix { id, mat } => {
                self.flags.inst_wmma = true;
                Variable::WmmaFragment {
                    id,
                    frag: self.compile_matrix(mat),
                }
            }
            gpu::VariableKind::Pipeline {
                id,
                item,
                num_stages,
            } => {
                self.flags.op_pipeline = true;
                let pipeline = Variable::Pipeline {
                    id,
                    item: self.compile_item(item),
                };
                // Register the pipeline's Init op once per id.
                if !self.pipelines.iter().any(|s| s.pipeline_id() == id) {
                    self.pipelines.push(PipelineOps::Init {
                        pipeline,
                        num_stages,
                    });
                }
                pipeline
            }
            gpu::VariableKind::Barrier { id, item, level } => {
                self.flags.op_barrier = true;
                // Cube-level barriers need the cube dim and unit position to
                // be available in the kernel prelude.
                match level {
                    gpu::BarrierLevel::CubeCoop(_) | gpu::BarrierLevel::CubeManual(_) => {
                        self.flags.indexes.cube_dim = true;
                        self.flags.indexes.unit_pos = true;
                    }
                    _ => {}
                }
                Variable::Barrier {
                    id,
                    item: self.compile_item(item),
                    level,
                }
            }
        }
    }
1408
1409    fn compile_matrix(&mut self, matrix: gpu::Matrix) -> Fragment<D> {
1410        Fragment {
1411            ident: self.compile_matrix_ident(matrix.ident),
1412            m: matrix.m,
1413            n: matrix.n,
1414            k: matrix.k,
1415            elem: self.compile_elem(matrix.elem),
1416            layout: self.compile_matrix_layout(matrix.layout),
1417        }
1418    }
1419
1420    fn compile_matrix_ident(&mut self, ident: gpu::MatrixIdent) -> FragmentIdent<D> {
1421        match ident {
1422            gpu::MatrixIdent::A => FragmentIdent::A,
1423            gpu::MatrixIdent::B => FragmentIdent::B,
1424            gpu::MatrixIdent::Accumulator => FragmentIdent::Accumulator,
1425        }
1426    }
1427
1428    fn compile_matrix_layout(&mut self, layout: gpu::MatrixLayout) -> Option<FragmentLayout<D>> {
1429        match layout {
1430            gpu::MatrixLayout::ColMajor => Some(FragmentLayout::ColMajor),
1431            gpu::MatrixLayout::RowMajor => Some(FragmentLayout::RowMajor),
1432            gpu::MatrixLayout::Undefined => None,
1433        }
1434    }
1435
1436    fn compile_binding(&mut self, binding: cubecl_core::compute::Binding) -> Binding<D> {
1437        Binding {
1438            id: binding.id,
1439            item: self.compile_item(binding.item),
1440            location: binding.location,
1441            size: binding.size,
1442            vis: binding.visibility,
1443        }
1444    }
1445
1446    fn compile_item(&mut self, item: gpu::Item) -> Item<D> {
1447        let item = Item::new(
1448            self.compile_elem(item.elem),
1449            item.vectorization.map(NonZero::get).unwrap_or(1).into(),
1450            false,
1451        );
1452        if item.elem != super::Elem::TF32 {
1453            self.items.insert(item);
1454            self.items.insert(item.optimized());
1455        } else {
1456            // TF32 is represented as `float` in C++
1457            let mut item = item;
1458            item.elem = super::Elem::F32;
1459            self.items.insert(item);
1460        }
1461
1462        item
1463    }
1464
1465    fn compile_elem(&mut self, value: gpu::Elem) -> Elem<D> {
1466        match value {
1467            gpu::Elem::Float(kind) => match kind {
1468                gpu::FloatKind::F16 => {
1469                    self.flags.elem_f16 = true;
1470                    Elem::F16
1471                }
1472                gpu::FloatKind::BF16 => {
1473                    self.flags.elem_bf16 = true;
1474                    Elem::BF16
1475                }
1476                gpu::FloatKind::TF32 => Elem::TF32,
1477                gpu::FloatKind::Flex32 => Elem::F32,
1478                gpu::FloatKind::F32 => Elem::F32,
1479                gpu::FloatKind::F64 => Elem::F64,
1480            },
1481            gpu::Elem::AtomicFloat(kind) => match kind {
1482                gpu::FloatKind::F16 => Elem::Atomic(AtomicKind::F16),
1483                gpu::FloatKind::BF16 => Elem::Atomic(AtomicKind::BF16),
1484                gpu::FloatKind::F32 => Elem::Atomic(AtomicKind::F32),
1485                gpu::FloatKind::F64 => Elem::Atomic(AtomicKind::F64),
1486                kind => unimplemented!("atomic<{kind:?}> not yet supported"),
1487            },
1488            gpu::Elem::Int(kind) => match kind {
1489                gpu::IntKind::I8 => Elem::I8,
1490                gpu::IntKind::I16 => Elem::I16,
1491                gpu::IntKind::I32 => Elem::I32,
1492                gpu::IntKind::I64 => Elem::I64,
1493            },
1494            gpu::Elem::AtomicInt(kind) => match kind {
1495                gpu::IntKind::I32 => Elem::Atomic(AtomicKind::I32),
1496                gpu::IntKind::I64 => Elem::Atomic(AtomicKind::I64),
1497                kind => panic!("atomic<{kind:?}> isn't supported yet"),
1498            },
1499            gpu::Elem::UInt(kind) => match kind {
1500                gpu::UIntKind::U8 => Elem::U8,
1501                gpu::UIntKind::U16 => Elem::U16,
1502                gpu::UIntKind::U32 => Elem::U32,
1503                gpu::UIntKind::U64 => Elem::U64,
1504            },
1505            gpu::Elem::AtomicUInt(kind) => match kind {
1506                gpu::UIntKind::U32 => Elem::Atomic(AtomicKind::U32),
1507                gpu::UIntKind::U64 => Elem::Atomic(AtomicKind::U64),
1508                kind => unimplemented!("atomic<{kind:?}> not yet supported"),
1509            },
1510            gpu::Elem::Bool => Elem::Bool,
1511        }
1512    }
1513}
1514
1515fn const_u32<D: Dialect>(value: u32) -> Variable<D> {
1516    Variable::ConstantScalar(
1517        gpu::ConstantScalarValue::UInt(value as u64, UIntKind::U32),
1518        Elem::U32,
1519    )
1520}
1521
1522pub fn register_supported_types(props: &mut DeviceProperties<Feature>) {
1523    let supported_types = [
1524        gpu::Elem::UInt(gpu::UIntKind::U8),
1525        gpu::Elem::UInt(gpu::UIntKind::U16),
1526        gpu::Elem::UInt(gpu::UIntKind::U32),
1527        gpu::Elem::UInt(gpu::UIntKind::U64),
1528        gpu::Elem::Int(gpu::IntKind::I8),
1529        gpu::Elem::Int(gpu::IntKind::I16),
1530        gpu::Elem::Int(gpu::IntKind::I32),
1531        gpu::Elem::Int(gpu::IntKind::I64),
1532        gpu::Elem::AtomicInt(gpu::IntKind::I32),
1533        gpu::Elem::AtomicInt(gpu::IntKind::I64),
1534        gpu::Elem::AtomicUInt(gpu::UIntKind::U32),
1535        gpu::Elem::AtomicUInt(gpu::UIntKind::U64),
1536        gpu::Elem::Float(gpu::FloatKind::BF16),
1537        gpu::Elem::Float(gpu::FloatKind::F16),
1538        gpu::Elem::Float(gpu::FloatKind::F32),
1539        gpu::Elem::Float(gpu::FloatKind::Flex32),
1540        gpu::Elem::AtomicFloat(gpu::FloatKind::F32),
1541        // Causes CUDA_ERROR_INVALID_VALUE for matmul, disabling until that can be investigated
1542        //gpu::Elem::Float(gpu::FloatKind::F64),
1543        gpu::Elem::Bool,
1544    ];
1545
1546    for ty in supported_types {
1547        props.register_feature(Feature::Type(ty));
1548    }
1549}