cubecl_cpp/shared/
base.rs

1use std::{collections::HashSet, fmt::Debug};
2
3use cubecl_common::ExecutionMode;
4use cubecl_core::ir::{FloatKind, InstructionModes, Processor, UIntKind, VariableKind};
5use cubecl_core::post_processing::checked_io::CheckedIoProcessor;
6use cubecl_core::{
7    Compiler,
8    ir::{self as gpu},
9};
10use cubecl_core::{CubeDim, ir::ElemType};
11use cubecl_core::{
12    ir::{Operation, SourceLoc},
13    prelude::{FastMath, KernelDefinition},
14};
15use cubecl_opt::{Optimizer, SharedLiveness};
16use cubecl_runtime::{DeviceProperties, EnumSet, TypeUsage};
17
18use crate::shared::MmaShape;
19
20use super::{
21    BinaryInstruction, Binding, Body, Component, ComputeKernel, ConstArray, Dialect, Elem, FP6Kind,
22    Fragment, FragmentIdent, FragmentLayout, IndexAssignInstruction, IndexInstruction, Instruction,
23    Item, LocalArray, SharedMemory, UnaryInstruction, Variable, WarpInstruction, WmmaInstruction,
24};
25use super::{FP4Kind, barrier::BarrierOps};
26use super::{FP8Kind, pipeline::PipelineOps};
27
/// Process-wide counter shared across the `shared` module.
///
/// Judging by its name it numbers temporary variables emitted during code generation
/// (NOTE(review): the consumers live outside this file section — confirm). It starts at
/// zero and `CppCompiler::compile` stores zero into it again once a kernel has been lowered.
pub(super) static COUNTER_TMP_VAR: std::sync::atomic::AtomicU32 =
    std::sync::atomic::AtomicU32::new(0);
30
/// Options handed to [`CppCompiler`] for a single kernel compilation.
#[derive(Clone, Debug)]
pub struct CompilationOptions {
    /// Warp size forwarded to WMMA execute instructions; the `Default` impl uses 32.
    pub warp_size: u32,
    /// Optional target features the compiler is allowed to rely on.
    pub supports_features: CppSupportedFeatures,
}
36
/// Feature switches of the compilation target. Everything defaults to `false`.
#[derive(Clone, Debug, Default)]
pub struct CppSupportedFeatures {
    /// Copied into [`Flags::use_grid_constants`] when a kernel is compiled.
    pub grid_constants: bool,
    /// When `false`, any cluster dimension requested by the kernel is dropped.
    pub clusters: bool,
    /// NOTE(review): not read in this part of the file; presumably gates fast-math
    /// code generation — confirm against the arithmetic lowering.
    pub fast_math: bool,
    /// When `true`, plane election lowers to `WarpInstruction::Elect`; otherwise
    /// `WarpInstruction::ElectFallback` is emitted.
    pub elect_sync: bool,
}
44
45impl Default for CompilationOptions {
46    fn default() -> Self {
47        Self {
48            warp_size: 32,
49            supports_features: Default::default(),
50        }
51    }
52}
53
/// Cube indexes flags.
/// When true the corresponding index is declared and computed as needed in the kernel.
///
/// All flags default to `false`; they are raised while the kernel IR is translated.
#[derive(Debug, Clone, Default)]
pub struct CubeIndexFlags {
    pub absolute_pos: bool,
    pub absolute_pos_tuple: bool,
    pub cube_count: bool,
    pub cube_count_tuple: bool,
    pub cube_dim: bool,
    pub cube_dim_tuple: bool,
    pub cube_pos: bool,
    pub cube_pos_tuple: bool,
    pub plane_dim: bool,
    /// Raised for every plane operation that gets compiled.
    pub plane_dim_checked: bool,
    pub plane_index: bool,
    pub unit_pos: bool,
    pub unit_pos_tuple: bool,
    /// Raised by the inclusive/exclusive plane scans (sum and product).
    pub unit_pos_plane: bool,
    pub cluster_pos: bool,
}
74
/// Flags gathered during Cube IR translation for the kernel compilation.
///
/// The `elem_*`, `op_*` and `inst_*` switches record which element types, operations and
/// instruction families were encountered; the remaining fields carry kernel-level
/// information filled in when the final kernel is assembled.
#[derive(Debug, Clone, Default)]
pub struct Flags {
    pub elem_fp4: bool,
    pub elem_fp6: bool,
    pub elem_fp8: bool,
    pub elem_bf16: bool,
    pub elem_f16: bool,
    pub elem_tf32: bool,
    /// Builtin index flags; the final value is produced by `Dialect::builtin_rules`.
    pub indexes: CubeIndexFlags,
    pub op_barrier: bool,
    pub op_pipeline: bool,
    /// Raised e.g. by TMA operations and by the async-proxy shared fence.
    pub inst_tma: bool,
    /// Raised e.g. when an im2col TMA load is compiled.
    pub inst_tma_im2col: bool,
    /// Raised e.g. when a cooperative matrix operation is compiled.
    pub inst_wmma: bool,
    /// Raised e.g. when plane election uses the `elect_sync` path.
    pub inst_ptx_wrappers: bool,
    /// Mirrors `CppSupportedFeatures::grid_constants`.
    pub use_grid_constants: bool,
    /// Static length reported by the kernel metadata.
    pub static_meta_length: usize,
    /// `true` when the static metadata length is non-zero.
    pub has_dynamic_meta: bool,
    /// Cube dimensions taken from the kernel definition.
    pub cube_dim: CubeDim,
    /// Cluster dimensions taken from the kernel options, if any.
    pub cluster_dim: Option<CubeDim>,
}
97
/// Compiler lowering Cube IR kernels to a [`ComputeKernel`] for the dialect `D`.
///
/// The fields hold the state accumulated while a kernel is translated; most of them are
/// moved into the resulting kernel once translation is done.
#[allow(clippy::too_many_arguments)]
#[derive(Clone, Debug, Default)]
pub struct CppCompiler<D: Dialect> {
    // Barrier declarations, moved into the kernel body.
    barriers: Vec<BarrierOps<D>>,
    // Options of the compilation currently in progress.
    compilation_options: CompilationOptions,
    // Constant arrays collected from the compiled scopes.
    const_arrays: Vec<ConstArray<D>>,
    // One entry per binding in ascending id order: the number of preceding bindings
    // that carry extended metadata.
    ext_meta_positions: Vec<u32>,
    // Cluster dimensions of the kernel, or a single-cube dimension when none apply.
    cluster_dim: CubeDim,
    // Dialect extensions registered while compiling instructions.
    extensions: Vec<D::Extension>,
    // Translation flags gathered so far.
    flags: Flags,
    // Items handed over to the resulting kernel.
    items: HashSet<Item<D>>,
    // Local arrays, moved into the kernel body.
    local_arrays: Vec<LocalArray<D>>,
    // Metadata layout derived from the kernel bindings.
    metadata: cubecl_core::Metadata,
    // Pipeline declarations, moved into the kernel body.
    pipelines: Vec<PipelineOps<D>>,
    // Last source location for which a line marker was emitted.
    source_loc: Option<SourceLoc>,
    // Execution mode, forwarded to the checked-IO processor.
    strategy: ExecutionMode,
}
115
116impl<D: Dialect> Compiler for CppCompiler<D> {
117    type Representation = ComputeKernel<D>;
118    type CompilationOptions = CompilationOptions;
119
120    fn compile(
121        &mut self,
122        mut kernel: KernelDefinition,
123        compilation_options: &Self::CompilationOptions,
124        strategy: ExecutionMode,
125    ) -> Self::Representation {
126        self.compilation_options = compilation_options.clone();
127        self.strategy = strategy;
128
129        if !self.compilation_options.supports_features.clusters {
130            kernel.options.cluster_dim = None;
131        }
132        self.cluster_dim = kernel.options.cluster_dim.unwrap_or(CubeDim::new_single());
133
134        let ir = self.clone().compile_ir(kernel);
135        COUNTER_TMP_VAR.store(0, std::sync::atomic::Ordering::Relaxed);
136        ir
137    }
138
139    fn elem_size(&self, elem: gpu::ElemType) -> usize {
140        elem.size()
141    }
142
143    fn extension(&self) -> &'static str {
144        "cpp"
145    }
146}
147
148impl<D: Dialect> CppCompiler<D> {
149    fn compile_ir(mut self, value: KernelDefinition) -> ComputeKernel<D> {
150        self.build_metadata(&value);
151
152        let instructions = self.compile_scope(&mut value.body.clone());
153        let tensor_maps = value
154            .tensor_maps
155            .into_iter()
156            .map(|b| self.compile_binding(b))
157            .collect();
158        let buffers = value
159            .buffers
160            .into_iter()
161            .map(|b| self.compile_binding(b))
162            .collect();
163        let scalars = value
164            .scalars
165            .into_iter()
166            .map(|binding| (self.compile_storage_type(binding.ty), binding.count))
167            .collect();
168
169        // translation flags
170        let flags = Flags {
171            indexes: D::builtin_rules(&self.flags.indexes),
172            inst_wmma: self.flags.inst_wmma,
173            op_pipeline: self.flags.op_pipeline,
174            op_barrier: self.flags.op_barrier,
175            elem_fp4: self.flags.elem_fp4,
176            elem_fp6: self.flags.elem_fp6,
177            elem_fp8: self.flags.elem_fp8,
178            elem_bf16: self.flags.elem_bf16,
179            elem_f16: self.flags.elem_f16,
180            elem_tf32: self.flags.elem_tf32,
181            inst_tma: self.flags.inst_tma,
182            inst_tma_im2col: self.flags.inst_tma_im2col,
183            inst_ptx_wrappers: self.flags.inst_ptx_wrappers,
184            use_grid_constants: self.compilation_options.supports_features.grid_constants,
185            // TODO: At some point we should only pass dynamic meta if tensors are present,
186            // not if only arrays are present. For now, leave like this
187            has_dynamic_meta: self.metadata.static_len() > 0,
188            static_meta_length: self.metadata.static_len() as usize,
189            cube_dim: value.cube_dim,
190            cluster_dim: value.options.cluster_dim,
191        };
192
193        let mut opt = Optimizer::shared_only(value.body, value.cube_dim);
194        let shared_allocs = opt.analysis::<SharedLiveness>();
195        let shared_memories = shared_allocs
196            .allocations
197            .values()
198            .map(|alloc| SharedMemory {
199                index: alloc.smem.id,
200                item: self.compile_type(alloc.smem.ty),
201                length: alloc.smem.length,
202                align: alloc.smem.align,
203                offset: alloc.offset,
204            })
205            .collect();
206
207        let body = Body {
208            instructions,
209            shared_memories,
210            pipelines: self.pipelines,
211            barriers: self.barriers,
212            const_arrays: self.const_arrays,
213            local_arrays: self.local_arrays,
214        };
215
216        let mut cluster_dim = value.options.cluster_dim;
217        if !self.compilation_options.supports_features.clusters {
218            cluster_dim = None;
219        }
220
221        ComputeKernel {
222            tensor_maps,
223            buffers,
224            scalars,
225            meta_static_len: self.metadata.static_len() as usize,
226            cube_dim: value.cube_dim,
227            body,
228            extensions: self.extensions,
229            flags,
230            items: self.items,
231            kernel_name: value.options.kernel_name,
232            cluster_dim,
233        }
234    }
235
236    fn build_metadata(&mut self, value: &KernelDefinition) {
237        let mut num_ext = 0;
238
239        let mut all_meta: Vec<_> = value
240            .buffers
241            .iter()
242            .chain(value.tensor_maps.iter())
243            .map(|buf| (buf.id, buf.has_extended_meta))
244            .collect();
245
246        all_meta.sort_by_key(|(id, _)| *id);
247
248        for (_, has_extended_meta) in &all_meta {
249            self.ext_meta_positions.push(num_ext);
250            if *has_extended_meta {
251                num_ext += 1;
252            }
253        }
254
255        let num_meta = all_meta.len();
256
257        self.metadata = cubecl_core::Metadata::new(num_meta as u32, num_ext);
258    }
259
260    pub(crate) fn ext_meta_position(&self, var: gpu::Variable) -> u32 {
261        let id = var.index().expect("Variable should have index");
262        self.ext_meta_positions[id as usize]
263    }
264
265    fn compile_scope(&mut self, scope: &mut gpu::Scope) -> Vec<Instruction<D>> {
266        let mut instructions = Vec::new();
267
268        let const_arrays = scope
269            .const_arrays
270            .drain(..)
271            .map(|(var, values)| ConstArray {
272                index: var.index().unwrap(),
273                item: self.compile_type(var.ty),
274                size: values.len() as u32,
275                values: values
276                    .into_iter()
277                    .map(|val| self.compile_variable(val))
278                    .collect(),
279            })
280            .collect::<Vec<_>>();
281        self.const_arrays.extend(const_arrays);
282
283        let checked_io: Box<dyn Processor> = Box::new(CheckedIoProcessor::new(self.strategy));
284        let dialect_processors = D::processors();
285        let mut processors: Vec<&dyn Processor> = vec![&*checked_io];
286        processors.extend(dialect_processors.iter().map(|it| &**it));
287
288        let processing = scope.process(processors);
289
290        for var in processing.variables {
291            instructions.push(Instruction::DeclareVariable {
292                var: self.compile_variable(var),
293            });
294        }
295
296        processing
297            .instructions
298            .into_iter()
299            .for_each(|op| self.compile_instruction(&mut instructions, op));
300
301        instructions
302    }
303
    /// Translates one IR instruction and appends the result to `instructions`.
    ///
    /// A line marker may be emitted first when the source location changed. The
    /// operation is then dispatched on its kind; several arms also raise translation
    /// flags or register dialect extensions as a side effect. Marker operations and
    /// non-semantic operations other than prints and comments produce no instruction.
    ///
    /// # Panics
    ///
    /// Panics when an operation that writes a result has no `out` variable, and hits
    /// `unreachable!()` when a barrier operation that needs the barrier level is given a
    /// variable whose kind is not `VariableKind::Barrier`.
    fn compile_instruction(
        &mut self,
        instructions: &mut Vec<Instruction<D>>,
        instruction: gpu::Instruction,
    ) {
        self.update_debug_loc(instructions, &instruction);
        let out = instruction.out;
        match instruction.operation {
            gpu::Operation::Copy(variable) => {
                instructions.push(Instruction::Assign(UnaryInstruction {
                    input: self.compile_variable(variable),
                    out: self.compile_variable(out.unwrap()),
                }));
            }
            gpu::Operation::Arithmetic(op) => {
                self.compile_arithmetic(op, out, instruction.modes, instructions)
            }
            gpu::Operation::Comparison(op) => self.compile_comparison(op, out, instructions),
            gpu::Operation::Bitwise(op) => self.compile_bitwise(op, out, instructions),
            gpu::Operation::Operator(op) => self.compile_operator(op, out, instructions),
            gpu::Operation::Atomic(op) => self.compile_atomic(op, out, instructions),
            gpu::Operation::Metadata(op) => instructions.push(self.compile_metadata(op, out)),
            gpu::Operation::Branch(val) => self.compile_branch(instructions, val),
            // Cube-level and storage-level synchronization both lower to `SyncThreads`.
            gpu::Operation::Synchronization(val) => match val {
                gpu::Synchronization::SyncCube => instructions.push(Instruction::SyncThreads),
                gpu::Synchronization::SyncPlane => instructions.push(Instruction::SyncWarp),
                gpu::Synchronization::SyncStorage => instructions.push(Instruction::SyncThreads),
                gpu::Synchronization::SyncAsyncProxyShared => {
                    self.flags.inst_tma = true;
                    instructions.push(Instruction::ProxyAsyncToSharedFence)
                }
            },
            // Every plane operation raises `plane_dim_checked` and writes to `out`.
            gpu::Operation::Plane(op) => {
                self.flags.indexes.plane_dim_checked = true;
                let out = self.compile_variable(out.unwrap());
                match op {
                    // Reductions (sum, prod, max, min) give the dialect a chance to
                    // register an extension for the instruction.
                    gpu::Plane::Sum(op) => {
                        let instruction = WarpInstruction::ReduceSum {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction));
                    }
                    // Scans (inclusive/exclusive sum and prod) raise `unit_pos_plane`.
                    gpu::Plane::InclusiveSum(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::InclusiveSum {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::InclusiveProd(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::InclusiveProd {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::ExclusiveSum(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::ExclusiveSum {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::ExclusiveProd(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::ExclusiveProd {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Prod(op) => {
                        let instruction = WarpInstruction::ReduceProd {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Max(op) => {
                        let instruction = WarpInstruction::ReduceMax {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Min(op) => {
                        let instruction = WarpInstruction::ReduceMin {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    // Election uses the dedicated instruction (and the PTX wrappers) when
                    // the target supports `elect_sync`, and a fallback otherwise.
                    gpu::Plane::Elect => {
                        if self.compilation_options.supports_features.elect_sync {
                            self.flags.inst_ptx_wrappers = true;
                            instructions.push(Instruction::Warp(WarpInstruction::Elect { out }))
                        } else {
                            instructions
                                .push(Instruction::Warp(WarpInstruction::ElectFallback { out }))
                        }
                    }
                    gpu::Plane::All(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::All {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Any(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Any {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Ballot(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Ballot {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    // For the binary plane operations below, `lhs` is the value and
                    // `rhs` is the lane id / mask / delta operand.
                    gpu::Plane::Broadcast(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Broadcast {
                            input: self.compile_variable(op.lhs),
                            id: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::Shuffle(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Shuffle {
                            input: self.compile_variable(op.lhs),
                            src_lane: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::ShuffleXor(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleXor {
                            input: self.compile_variable(op.lhs),
                            mask: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::ShuffleUp(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleUp {
                            input: self.compile_variable(op.lhs),
                            delta: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                    gpu::Plane::ShuffleDown(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::ShuffleDown {
                            input: self.compile_variable(op.lhs),
                            delta: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                }
            }
            gpu::Operation::CoopMma(cmma) => instructions.push(self.compile_cmma(cmma, out)),
            gpu::Operation::NonSemantic(debug) => match debug {
                gpu::NonSemantic::Print {
                    format_string,
                    args,
                } => instructions.push(Instruction::Printf {
                    format_string,
                    args: args
                        .into_iter()
                        .map(|arg| self.compile_variable(arg))
                        .collect(),
                }),
                gpu::NonSemantic::Comment { content } => {
                    instructions.push(Instruction::Comment { content })
                }
                // Don't need to handle scopes
                _ => {}
            },
            gpu::Operation::Barrier(barrier_ops) => match barrier_ops {
                gpu::BarrierOps::Declare { barrier } => {
                    // The barrier level is read from the IR variable kind before the
                    // variable itself is compiled.
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    let barrier = self.compile_variable(barrier);
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Declare {
                        barrier,
                        level,
                    }));
                }
                gpu::BarrierOps::Init {
                    barrier,
                    is_elected,
                    arrival_count,
                    with_async_proxy_fence,
                } => {
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    let barrier = self.compile_variable(barrier);
                    let arrival_count = self.compile_variable(arrival_count);
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Init {
                        barrier,
                        is_elected: self.compile_variable(is_elected),
                        arrival_count,
                        level,
                        with_async_proxy_fence,
                    }));
                }
                gpu::BarrierOps::InitManual {
                    barrier,
                    arrival_count,
                } => {
                    let barrier = self.compile_variable(barrier);
                    let arrival_count = self.compile_variable(arrival_count);
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::InitManual {
                            barrier,
                            arrival_count,
                        },
                    ));
                }
                // The plain and cooperative async copies share one target operation and
                // differ only in the `cooperative` flag.
                gpu::BarrierOps::MemCopyAsync {
                    barrier,
                    source,
                    source_length,
                    offset_source,
                    offset_out,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsync {
                            barrier: self.compile_variable(barrier),
                            source: self.compile_variable(source),
                            destination: self.compile_variable(out.unwrap()),
                            source_length: self.compile_variable(source_length),
                            offset_source: self.compile_variable(offset_source),
                            offset_out: self.compile_variable(offset_out),
                            cooperative: false,
                        },
                    ));
                }
                gpu::BarrierOps::MemCopyAsyncCooperative {
                    barrier,
                    source,
                    source_length,
                    offset_source,
                    offset_out,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsync {
                            barrier: self.compile_variable(barrier),
                            source: self.compile_variable(source),
                            destination: self.compile_variable(out.unwrap()),
                            source_length: self.compile_variable(source_length),
                            offset_source: self.compile_variable(offset_source),
                            offset_out: self.compile_variable(offset_out),
                            cooperative: true,
                        },
                    ));
                }
                gpu::BarrierOps::MemCopyAsyncTx {
                    barrier,
                    source,
                    source_length,
                    offset_source,
                    offset_out,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsyncTx {
                            barrier: self.compile_variable(barrier),
                            source: self.compile_variable(source),
                            destination: self.compile_variable(out.unwrap()),
                            source_length: self.compile_variable(source_length),
                            offset_source: self.compile_variable(offset_source),
                            offset_out: self.compile_variable(offset_out),
                        },
                    ));
                }
                // TMA loads write into the `out` variable, used here as the shared
                // memory buffer operand.
                gpu::BarrierOps::TmaLoad {
                    barrier,
                    tensor_map,
                    offset_out,
                    indices,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsyncTensorGlobalToShared {
                            barrier: self.compile_variable(barrier),
                            smem_buffer: self.compile_variable(out.unwrap()),
                            smem_offset: self.compile_variable(offset_out),
                            tensor_map: self.compile_variable(tensor_map),
                            indices: indices
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        },
                    ));
                }
                gpu::BarrierOps::TmaLoadIm2col {
                    barrier,
                    tensor_map,
                    offset_out,
                    indices,
                    offsets,
                } => {
                    self.flags.inst_tma_im2col = true;
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::TmaLoadIm2col {
                            barrier: self.compile_variable(barrier),
                            smem_buffer: self.compile_variable(out.unwrap()),
                            smem_offset: self.compile_variable(offset_out),
                            tensor_map: self.compile_variable(tensor_map),
                            indices: indices
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                            offsets: offsets
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        },
                    ));
                }
                // Arrivals write their token into the `out` variable.
                gpu::BarrierOps::Arrive { barrier } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Arrive {
                        barrier: self.compile_variable(barrier),
                        token: self.compile_variable(out.unwrap()),
                    }))
                }
                gpu::BarrierOps::ArriveTx {
                    barrier,
                    arrive_count_update,
                    transaction_count_update,
                } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::ArriveTx {
                        barrier: self.compile_variable(barrier),
                        token: self.compile_variable(out.unwrap()),
                        arrive_count_update: self.compile_variable(arrive_count_update),
                        transaction_count_update: self.compile_variable(transaction_count_update),
                    }))
                }
                gpu::BarrierOps::ExpectTx {
                    barrier,
                    transaction_count_update,
                } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::ExpectTx {
                        barrier: self.compile_variable(barrier),
                        transaction_count_update: self.compile_variable(transaction_count_update),
                    }))
                }
                gpu::BarrierOps::Wait { barrier, token } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Wait {
                        barrier: self.compile_variable(barrier),
                        token: self.compile_variable(token),
                    }))
                }
                gpu::BarrierOps::WaitParity { barrier, phase } => instructions.push(
                    Instruction::Barrier(super::barrier::BarrierOps::WaitParity {
                        barrier: self.compile_variable(barrier),
                        phase: self.compile_variable(phase),
                    }),
                ),
                gpu::BarrierOps::ArriveAndWait { barrier } => {
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::ArriveAndWait {
                            barrier: self.compile_variable(barrier),
                            level,
                        },
                    ))
                }
            },
            // Any TMA operation raises the `inst_tma` flag.
            gpu::Operation::Tma(tma_ops) => {
                self.flags.inst_tma = true;
                match tma_ops {
                    // For a store, the `out` variable is the destination tensor map.
                    gpu::TmaOps::TmaStore {
                        source,
                        coordinates,
                        offset_source,
                    } => {
                        instructions.push(Instruction::MemCopyAsyncTensorSharedToGlobal {
                            smem_buffer: self.compile_variable(source),
                            smem_offset: self.compile_variable(offset_source),
                            tensor_map: self.compile_variable(out.unwrap()),
                            indices: coordinates
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        });
                    }
                    gpu::TmaOps::CommitGroup => {
                        instructions.push(Instruction::BulkCommitGroup);
                    }
                    gpu::TmaOps::WaitGroup { max_pending } => {
                        instructions.push(Instruction::BulkWaitGroup { max_pending });
                    }
                    gpu::TmaOps::WaitGroupRead { max_pending } => {
                        instructions.push(Instruction::BulkWaitGroupRead { max_pending });
                    }
                }
            }
            // Markers produce no output.
            gpu::Operation::Marker(_) => {}
        }
    }
708
709    fn update_debug_loc(
710        &mut self,
711        instructions: &mut Vec<Instruction<D>>,
712        inst: &gpu::Instruction,
713    ) {
714        if !matches!(inst.operation, Operation::NonSemantic(_)) {
715            match &inst.source_loc {
716                Some(loc) if Some(loc) != self.source_loc.as_ref() => {
717                    self.source_loc = Some(loc.clone());
718                    instructions.push(Instruction::Line {
719                        file: loc.source.file.clone(),
720                        line: loc.line,
721                    });
722                }
723                _ => {}
724            }
725        }
726    }
727
728    fn compile_cmma(&mut self, cmma: gpu::CoopMma, out: Option<gpu::Variable>) -> Instruction<D> {
729        self.flags.inst_wmma = true;
730
731        let out = self.compile_variable(out.unwrap());
732
733        let inst = match cmma {
734            gpu::CoopMma::Fill { value } => WmmaInstruction::Fill {
735                frag: out,
736                value: self.compile_variable(value),
737            },
738            gpu::CoopMma::Load {
739                value,
740                stride,
741                offset,
742                layout,
743            } => WmmaInstruction::Load {
744                frag: out,
745                offset: self.compile_variable(offset),
746                value: self.compile_variable(value),
747                stride: self.compile_variable(stride),
748                layout: layout.and_then(|l| self.compile_matrix_layout(l)),
749            },
750            gpu::CoopMma::Execute {
751                mat_a,
752                mat_b,
753                mat_c,
754            } => WmmaInstruction::Execute {
755                frag_a: self.compile_variable(mat_a),
756                frag_b: self.compile_variable(mat_b),
757                frag_c: self.compile_variable(mat_c),
758                frag_d: out,
759                warp_size: self.compilation_options.warp_size,
760            },
761            gpu::CoopMma::ExecuteManual {
762                matrix,
763                registers_a,
764                registers_b,
765                registers_c,
766            } => WmmaInstruction::ExecuteManual {
767                shape: MmaShape::new(matrix.m, matrix.n, matrix.k),
768                frag_a: registers_a
769                    .into_iter()
770                    .map(|it| self.compile_variable(it))
771                    .collect(),
772                frag_b: registers_b
773                    .into_iter()
774                    .map(|it| self.compile_variable(it))
775                    .collect(),
776                frag_c: registers_c
777                    .into_iter()
778                    .map(|it| self.compile_variable(it))
779                    .collect(),
780                frag_d: out,
781            },
782            gpu::CoopMma::ExecuteScaled {
783                matrix,
784                registers_a,
785                registers_b,
786                registers_c,
787                scales_a,
788                scales_b,
789                scales_factor,
790            } => WmmaInstruction::ExecuteScaled {
791                shape: MmaShape::new(matrix.m, matrix.n, matrix.k),
792                frag_a: registers_a
793                    .into_iter()
794                    .map(|it| self.compile_variable(it))
795                    .collect(),
796                frag_b: registers_b
797                    .into_iter()
798                    .map(|it| self.compile_variable(it))
799                    .collect(),
800                frag_c: registers_c
801                    .into_iter()
802                    .map(|it| self.compile_variable(it))
803                    .collect(),
804                frag_d: out,
805
806                scales_a: self.compile_variable(scales_a),
807                scales_b: self.compile_variable(scales_b),
808                scales_factor,
809            },
810            gpu::CoopMma::Store {
811                mat,
812                stride,
813                offset,
814                layout,
815            } => {
816                self.flags.indexes.unit_pos = true;
817                self.flags.indexes.plane_index = true;
818                WmmaInstruction::Store {
819                    output: out,
820                    offset: self.compile_variable(offset),
821                    frag: self.compile_variable(mat),
822                    stride: self.compile_variable(stride),
823                    layout: self
824                        .compile_matrix_layout(layout)
825                        .expect("Layout required for store instruction"),
826                }
827            }
828            gpu::CoopMma::Cast { input } => WmmaInstruction::Cast {
829                input: self.compile_variable(input),
830                output: out,
831            },
832            gpu::CoopMma::RowIndex { .. } | gpu::CoopMma::ColIndex { .. } => {
833                panic!("Row/Col index should be handled by processors")
834            }
835        };
836
837        D::register_wmma_instruction_extension(&mut self.extensions, &inst);
838
839        Instruction::Wmma(inst)
840    }
841
842    fn compile_metadata(
843        &mut self,
844        metadata: gpu::Metadata,
845        out: Option<gpu::Variable>,
846    ) -> Instruction<D> {
847        let out = out.unwrap();
848        match metadata {
849            gpu::Metadata::Stride { dim, var } => {
850                let position = self.ext_meta_position(var);
851                let offset = self.metadata.stride_offset_index(position);
852                Instruction::ExtendedMetadata {
853                    info_offset: self.compile_variable(offset.into()),
854                    dim: self.compile_variable(dim),
855                    split_meta: self.compilation_options.supports_features.grid_constants,
856                    static_offset: self.metadata.static_len(),
857                    out: self.compile_variable(out),
858                }
859            }
860            gpu::Metadata::Shape { dim, var } => {
861                let position = self.ext_meta_position(var);
862                let offset = self.metadata.shape_offset_index(position);
863                Instruction::ExtendedMetadata {
864                    info_offset: self.compile_variable(offset.into()),
865                    dim: self.compile_variable(dim),
866                    split_meta: self.compilation_options.supports_features.grid_constants,
867                    static_offset: self.metadata.static_len(),
868                    out: self.compile_variable(out),
869                }
870            }
871            gpu::Metadata::Rank { var } => {
872                let out = self.compile_variable(out);
873                let pos = self.ext_meta_position(var);
874                let offset = self.metadata.rank_index(pos);
875                super::Instruction::Metadata {
876                    info_offset: self.compile_variable(offset.into()),
877                    split_meta: self.compilation_options.supports_features.grid_constants,
878                    out,
879                }
880            }
881            gpu::Metadata::Length { var } => {
882                let input = self.compile_variable(var);
883                let out = self.compile_variable(out);
884
885                match input {
886                    Variable::Slice { .. } => Instruction::SliceLength { input, out },
887                    Variable::SharedMemory(_id, _item, length) => {
888                        Instruction::ConstLength { length, out }
889                    }
890                    _ => {
891                        let id = input.id().expect("Variable should have id");
892                        let offset = self.metadata.len_index(id);
893                        Instruction::Metadata {
894                            info_offset: self.compile_variable(offset.into()),
895                            split_meta: self.compilation_options.supports_features.grid_constants,
896                            out,
897                        }
898                    }
899                }
900            }
901            gpu::Metadata::BufferLength { var } => {
902                let input = self.compile_variable(var);
903                let out = self.compile_variable(out);
904
905                match input {
906                    Variable::Slice { .. } => Instruction::SliceLength { input, out },
907                    _ => {
908                        let id = input.id().expect("Variable should have id");
909                        let offset = self.metadata.buffer_len_index(id);
910                        Instruction::Metadata {
911                            info_offset: self.compile_variable(offset.into()),
912                            split_meta: self.compilation_options.supports_features.grid_constants,
913                            out,
914                        }
915                    }
916                }
917            }
918        }
919    }
920
921    fn compile_branch(&mut self, instructions: &mut Vec<Instruction<D>>, branch: gpu::Branch) {
922        match branch {
923            gpu::Branch::If(mut op) => instructions.push(Instruction::If {
924                cond: self.compile_variable(op.cond),
925                instructions: self.compile_scope(&mut op.scope),
926            }),
927            gpu::Branch::IfElse(mut op) => instructions.push(Instruction::IfElse {
928                cond: self.compile_variable(op.cond),
929                instructions_if: self.compile_scope(&mut op.scope_if),
930                instructions_else: self.compile_scope(&mut op.scope_else),
931            }),
932            gpu::Branch::Switch(mut op) => instructions.push(Instruction::Switch {
933                value: self.compile_variable(op.value),
934                instructions_default: self.compile_scope(&mut op.scope_default),
935                instructions_cases: op
936                    .cases
937                    .into_iter()
938                    .map(|(val, mut block)| {
939                        (self.compile_variable(val), self.compile_scope(&mut block))
940                    })
941                    .collect(),
942            }),
943            gpu::Branch::Return => instructions.push(Instruction::Return),
944            gpu::Branch::Break => instructions.push(Instruction::Break),
945            gpu::Branch::RangeLoop(mut range_loop) => instructions.push(Instruction::RangeLoop {
946                i: self.compile_variable(range_loop.i),
947                start: self.compile_variable(range_loop.start),
948                end: self.compile_variable(range_loop.end),
949                step: range_loop.step.map(|it| self.compile_variable(it)),
950                inclusive: range_loop.inclusive,
951                instructions: self.compile_scope(&mut range_loop.scope),
952            }),
953            gpu::Branch::Loop(mut op) => instructions.push(Instruction::Loop {
954                instructions: self.compile_scope(&mut op.scope),
955            }),
956        };
957    }
958
959    fn compile_atomic(
960        &mut self,
961        value: gpu::AtomicOp,
962        out: Option<gpu::Variable>,
963        instructions: &mut Vec<Instruction<D>>,
964    ) {
965        let out = out.unwrap();
966        match value {
967            gpu::AtomicOp::Load(op) => {
968                instructions.push(Instruction::AtomicLoad(self.compile_unary(op, out)))
969            }
970            gpu::AtomicOp::Store(op) => {
971                instructions.push(Instruction::AtomicStore(self.compile_unary(op, out)))
972            }
973            gpu::AtomicOp::Swap(op) => {
974                instructions.push(Instruction::AtomicSwap(self.compile_binary(op, out)))
975            }
976            gpu::AtomicOp::Add(op) => {
977                instructions.push(Instruction::AtomicAdd(self.compile_binary(op, out)))
978            }
979            gpu::AtomicOp::Sub(op) => {
980                instructions.push(Instruction::AtomicSub(self.compile_binary(op, out)))
981            }
982            gpu::AtomicOp::Max(op) => {
983                instructions.push(Instruction::AtomicMax(self.compile_binary(op, out)))
984            }
985            gpu::AtomicOp::Min(op) => {
986                instructions.push(Instruction::AtomicMin(self.compile_binary(op, out)))
987            }
988            gpu::AtomicOp::And(op) => {
989                instructions.push(Instruction::AtomicAnd(self.compile_binary(op, out)))
990            }
991            gpu::AtomicOp::Or(op) => {
992                instructions.push(Instruction::AtomicOr(self.compile_binary(op, out)))
993            }
994            gpu::AtomicOp::Xor(op) => {
995                instructions.push(Instruction::AtomicXor(self.compile_binary(op, out)))
996            }
997            gpu::AtomicOp::CompareAndSwap(op) => instructions.push(Instruction::AtomicCAS {
998                input: self.compile_variable(op.input),
999                cmp: self.compile_variable(op.cmp),
1000                val: self.compile_variable(op.val),
1001                out: self.compile_variable(out),
1002            }),
1003        }
1004    }
1005
1006    fn compile_arithmetic(
1007        &mut self,
1008        value: gpu::Arithmetic,
1009        out: Option<gpu::Variable>,
1010        modes: InstructionModes,
1011        instructions: &mut Vec<Instruction<D>>,
1012    ) {
1013        let out = out.unwrap();
1014        match value {
1015            gpu::Arithmetic::Add(op) => {
1016                instructions.push(Instruction::Add(self.compile_binary(op, out)))
1017            }
1018            gpu::Arithmetic::SaturatingAdd(op) => {
1019                instructions.push(Instruction::SaturatingAdd(self.compile_binary(op, out)))
1020            }
1021            gpu::Arithmetic::Mul(op) => {
1022                instructions.push(Instruction::Mul(self.compile_binary(op, out)))
1023            }
1024            gpu::Arithmetic::Div(op) => {
1025                let op = self.compile_binary(op, out);
1026                instructions.push(self.select_fast_float(
1027                    out.ty,
1028                    modes,
1029                    FastMath::AllowReciprocal
1030                        | FastMath::ReducedPrecision
1031                        | FastMath::UnsignedZero
1032                        | FastMath::NotInf,
1033                    Instruction::Div(op),
1034                    Instruction::FastDiv(op),
1035                ))
1036            }
1037            gpu::Arithmetic::Sub(op) => {
1038                instructions.push(Instruction::Sub(self.compile_binary(op, out)))
1039            }
1040            gpu::Arithmetic::SaturatingSub(op) => {
1041                instructions.push(Instruction::SaturatingSub(self.compile_binary(op, out)))
1042            }
1043            gpu::Arithmetic::MulHi(op) => {
1044                let instruction = Instruction::HiMul(self.compile_binary(op, out));
1045                D::register_instruction_extension(&mut self.extensions, &instruction);
1046                instructions.push(instruction)
1047            }
1048            gpu::Arithmetic::Modulo(op) => {
1049                instructions.push(Instruction::Modulo(self.compile_binary(op, out)))
1050            }
1051            gpu::Arithmetic::Abs(op) => {
1052                instructions.push(Instruction::Abs(self.compile_unary(op, out)))
1053            }
1054            gpu::Arithmetic::Exp(op) => {
1055                let op = self.compile_unary(op, out);
1056                instructions.push(self.select_fast_float(
1057                    out.ty,
1058                    modes,
1059                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
1060                    Instruction::Exp(op),
1061                    Instruction::FastExp(op),
1062                ));
1063            }
1064            gpu::Arithmetic::Log(op) => {
1065                let op = self.compile_unary(op, out);
1066                instructions.push(self.select_fast_float(
1067                    out.ty,
1068                    modes,
1069                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
1070                    Instruction::Log(op),
1071                    Instruction::FastLog(op),
1072                ));
1073            }
1074            gpu::Arithmetic::Log1p(op) => {
1075                instructions.push(Instruction::Log1p(self.compile_unary(op, out)))
1076            }
1077            gpu::Arithmetic::Cos(op) => {
1078                let op = self.compile_unary(op, out);
1079                instructions.push(self.select_fast_float(
1080                    out.ty,
1081                    modes,
1082                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
1083                    Instruction::Cos(op),
1084                    Instruction::FastCos(op),
1085                ));
1086            }
1087            gpu::Arithmetic::Sin(op) => {
1088                let op = self.compile_unary(op, out);
1089                instructions.push(self.select_fast_float(
1090                    out.ty,
1091                    modes,
1092                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
1093                    Instruction::Sin(op),
1094                    Instruction::FastSin(op),
1095                ));
1096            }
1097            gpu::Arithmetic::Tanh(op) => {
1098                let op = self.compile_unary(op, out);
1099                let instruction = Instruction::Tanh(op);
1100                D::register_instruction_extension(&mut self.extensions, &instruction);
1101                instructions.push(self.select_fast_float(
1102                    out.ty,
1103                    modes,
1104                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
1105                    instruction,
1106                    Instruction::FastTanh(op),
1107                ))
1108            }
1109            gpu::Arithmetic::Powf(op) => {
1110                let op = self.compile_binary(op, out);
1111                instructions.push(self.select_fast_float(
1112                    out.ty,
1113                    modes,
1114                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
1115                    Instruction::Powf(op),
1116                    Instruction::FastPowf(op),
1117                ))
1118            }
1119            gpu::Arithmetic::Powi(op) => {
1120                instructions.push(Instruction::Powi(self.compile_binary(op, out)))
1121            }
1122            gpu::Arithmetic::Sqrt(op) => {
1123                let op = self.compile_unary(op, out);
1124                instructions.push(self.select_fast_float(
1125                    out.ty,
1126                    modes,
1127                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
1128                    Instruction::Sqrt(op),
1129                    Instruction::FastSqrt(op),
1130                ))
1131            }
1132            gpu::Arithmetic::InverseSqrt(op) => {
1133                let op = self.compile_unary(op, out);
1134                instructions.push(self.select_fast_float(
1135                    out.ty,
1136                    modes,
1137                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
1138                    Instruction::InverseSqrt(op),
1139                    Instruction::FastInverseSqrt(op),
1140                ))
1141            }
1142            gpu::Arithmetic::Erf(op) => {
1143                let instruction = Instruction::Erf(self.compile_unary(op, out));
1144                D::register_instruction_extension(&mut self.extensions, &instruction);
1145                instructions.push(instruction)
1146            }
1147            gpu::Arithmetic::Max(op) => {
1148                let instruction = Instruction::Max(self.compile_binary(op, out));
1149                D::register_instruction_extension(&mut self.extensions, &instruction);
1150                instructions.push(instruction)
1151            }
1152            gpu::Arithmetic::Min(op) => {
1153                let instruction = Instruction::Min(self.compile_binary(op, out));
1154                D::register_instruction_extension(&mut self.extensions, &instruction);
1155                instructions.push(instruction)
1156            }
1157            gpu::Arithmetic::Clamp(op) => instructions.push(Instruction::Clamp {
1158                input: self.compile_variable(op.input),
1159                min_value: self.compile_variable(op.min_value),
1160                max_value: self.compile_variable(op.max_value),
1161                out: self.compile_variable(out),
1162            }),
1163            gpu::Arithmetic::Recip(op) => {
1164                let elem = op.input.ty.elem_type();
1165                let input = self.compile_variable(op.input);
1166                let out = self.compile_variable(out);
1167                let lhs = match elem {
1168                    gpu::ElemType::Float(kind) => gpu::ConstantScalarValue::Float(1.0, kind),
1169                    gpu::ElemType::Int(kind) => gpu::ConstantScalarValue::Int(1, kind),
1170                    gpu::ElemType::UInt(kind) => gpu::ConstantScalarValue::UInt(1, kind),
1171                    gpu::ElemType::Bool => gpu::ConstantScalarValue::Bool(true),
1172                };
1173                let div = Instruction::Div(BinaryInstruction {
1174                    lhs: Variable::ConstantScalar(lhs, self.compile_elem(elem)),
1175                    rhs: input,
1176                    out,
1177                });
1178                let recip = Instruction::FastRecip(UnaryInstruction { input, out });
1179
1180                instructions.push(self.select_fast_float(
1181                    elem.into(),
1182                    modes,
1183                    FastMath::AllowReciprocal
1184                        | FastMath::ReducedPrecision
1185                        | FastMath::UnsignedZero
1186                        | FastMath::NotInf,
1187                    div,
1188                    recip,
1189                ))
1190            }
1191            gpu::Arithmetic::Round(op) => {
1192                instructions.push(Instruction::Round(self.compile_unary(op, out)))
1193            }
1194            gpu::Arithmetic::Floor(op) => {
1195                instructions.push(Instruction::Floor(self.compile_unary(op, out)))
1196            }
1197            gpu::Arithmetic::Ceil(op) => {
1198                instructions.push(Instruction::Ceil(self.compile_unary(op, out)))
1199            }
1200            gpu::Arithmetic::Trunc(op) => {
1201                instructions.push(Instruction::Trunc(self.compile_unary(op, out)))
1202            }
1203            gpu::Arithmetic::Remainder(op) => {
1204                instructions.push(Instruction::Remainder(self.compile_binary(op, out)))
1205            }
1206            gpu::Arithmetic::Fma(op) => instructions.push(Instruction::Fma {
1207                a: self.compile_variable(op.a),
1208                b: self.compile_variable(op.b),
1209                c: self.compile_variable(op.c),
1210                out: self.compile_variable(out),
1211            }),
1212            gpu::Arithmetic::Neg(op) => {
1213                instructions.push(Instruction::Neg(self.compile_unary(op, out)))
1214            }
1215            gpu::Arithmetic::Normalize(op) => {
1216                let op = self.compile_unary(op, out);
1217                instructions.push(self.select_fast_float(
1218                    out.ty,
1219                    modes,
1220                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
1221                    Instruction::Normalize(op),
1222                    Instruction::FastNormalize(op),
1223                ))
1224            }
1225            gpu::Arithmetic::Magnitude(op) => {
1226                let op = self.compile_unary(op, out);
1227                instructions.push(self.select_fast_float(
1228                    out.ty,
1229                    modes,
1230                    FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf,
1231                    Instruction::Magnitude(op),
1232                    Instruction::FastMagnitude(op),
1233                ))
1234            }
1235            gpu::Arithmetic::Dot(op) => {
1236                instructions.push(Instruction::Dot(self.compile_binary(op, out)))
1237            }
1238        };
1239    }
1240
1241    fn select_fast_float(
1242        &self,
1243        ty: gpu::Type,
1244        modes: InstructionModes,
1245        required_flags: EnumSet<FastMath>,
1246        default: Instruction<D>,
1247        fast: Instruction<D>,
1248    ) -> Instruction<D> {
1249        if !self.compilation_options.supports_features.fast_math
1250            || !matches!(ty.elem_type(), ElemType::Float(FloatKind::F32))
1251        {
1252            return default;
1253        }
1254
1255        if modes.fp_math_mode.is_superset(required_flags) {
1256            fast
1257        } else {
1258            default
1259        }
1260    }
1261
1262    fn compile_comparison(
1263        &mut self,
1264        value: gpu::Comparison,
1265        out: Option<gpu::Variable>,
1266        instructions: &mut Vec<Instruction<D>>,
1267    ) {
1268        let out = out.unwrap();
1269        match value {
1270            gpu::Comparison::Equal(op) => {
1271                instructions.push(Instruction::Equal(self.compile_binary(op, out)))
1272            }
1273            gpu::Comparison::Lower(op) => {
1274                instructions.push(Instruction::Lower(self.compile_binary(op, out)))
1275            }
1276            gpu::Comparison::Greater(op) => {
1277                instructions.push(Instruction::Greater(self.compile_binary(op, out)))
1278            }
1279            gpu::Comparison::LowerEqual(op) => {
1280                instructions.push(Instruction::LowerEqual(self.compile_binary(op, out)))
1281            }
1282            gpu::Comparison::GreaterEqual(op) => {
1283                instructions.push(Instruction::GreaterEqual(self.compile_binary(op, out)))
1284            }
1285            gpu::Comparison::NotEqual(op) => {
1286                instructions.push(Instruction::NotEqual(self.compile_binary(op, out)))
1287            }
1288            gpu::Comparison::IsNan(op) => {
1289                instructions.push(Instruction::IsNan(self.compile_unary(op, out)))
1290            }
1291            gpu::Comparison::IsInf(op) => {
1292                instructions.push(Instruction::IsInf(self.compile_unary(op, out)))
1293            }
1294        };
1295    }
1296
1297    fn compile_bitwise(
1298        &mut self,
1299        value: gpu::Bitwise,
1300        out: Option<gpu::Variable>,
1301        instructions: &mut Vec<Instruction<D>>,
1302    ) {
1303        let out = out.unwrap();
1304        match value {
1305            gpu::Bitwise::BitwiseOr(op) => {
1306                instructions.push(Instruction::BitwiseOr(self.compile_binary(op, out)))
1307            }
1308            gpu::Bitwise::BitwiseAnd(op) => {
1309                instructions.push(Instruction::BitwiseAnd(self.compile_binary(op, out)))
1310            }
1311            gpu::Bitwise::BitwiseXor(op) => {
1312                instructions.push(Instruction::BitwiseXor(self.compile_binary(op, out)))
1313            }
1314            gpu::Bitwise::CountOnes(op) => {
1315                instructions.push(Instruction::CountBits(self.compile_unary(op, out)))
1316            }
1317            gpu::Bitwise::ReverseBits(op) => {
1318                instructions.push(Instruction::ReverseBits(self.compile_unary(op, out)))
1319            }
1320            gpu::Bitwise::ShiftLeft(op) => {
1321                instructions.push(Instruction::ShiftLeft(self.compile_binary(op, out)))
1322            }
1323            gpu::Bitwise::ShiftRight(op) => {
1324                instructions.push(Instruction::ShiftRight(self.compile_binary(op, out)))
1325            }
1326            gpu::Bitwise::BitwiseNot(op) => {
1327                instructions.push(Instruction::BitwiseNot(self.compile_unary(op, out)))
1328            }
1329            gpu::Bitwise::LeadingZeros(op) => {
1330                instructions.push(Instruction::LeadingZeros(self.compile_unary(op, out)))
1331            }
1332            gpu::Bitwise::FindFirstSet(op) => {
1333                let instruction = Instruction::FindFirstSet(self.compile_unary(op, out));
1334                D::register_instruction_extension(&mut self.extensions, &instruction);
1335                instructions.push(instruction)
1336            }
1337        };
1338    }
1339
1340    fn compile_operator(
1341        &mut self,
1342        value: gpu::Operator,
1343        out: Option<gpu::Variable>,
1344        instructions: &mut Vec<Instruction<D>>,
1345    ) {
1346        let out = out.unwrap();
1347        match value {
1348            gpu::Operator::Index(op) | gpu::Operator::UncheckedIndex(op) => {
1349                instructions.push(Instruction::Index(self.compile_index(op, out)));
1350            }
1351            gpu::Operator::IndexAssign(op) | gpu::Operator::UncheckedIndexAssign(op) => {
1352                instructions.push(Instruction::IndexAssign(self.compile_index_assign(op, out)));
1353            }
1354            gpu::Operator::And(op) => {
1355                instructions.push(Instruction::And(self.compile_binary(op, out)))
1356            }
1357            gpu::Operator::Or(op) => {
1358                instructions.push(Instruction::Or(self.compile_binary(op, out)))
1359            }
1360            gpu::Operator::Not(op) => {
1361                instructions.push(Instruction::Not(self.compile_unary(op, out)))
1362            }
1363            gpu::Operator::InitLine(op) => instructions.push(Instruction::VecInit {
1364                inputs: op
1365                    .inputs
1366                    .into_iter()
1367                    .map(|it| self.compile_variable(it))
1368                    .collect(),
1369                out: self.compile_variable(out),
1370            }),
1371            gpu::Operator::CopyMemory(op) => instructions.push(Instruction::Copy {
1372                input: self.compile_variable(op.input),
1373                in_index: self.compile_variable(op.in_index),
1374                out: self.compile_variable(out),
1375                out_index: self.compile_variable(op.out_index),
1376            }),
1377            gpu::Operator::CopyMemoryBulk(op) => instructions.push(Instruction::CopyBulk {
1378                input: self.compile_variable(op.input),
1379                in_index: self.compile_variable(op.in_index),
1380                out: self.compile_variable(out),
1381                out_index: self.compile_variable(op.out_index),
1382                len: op.len,
1383            }),
1384            gpu::Operator::Select(op) => instructions.push(Instruction::Select {
1385                cond: self.compile_variable(op.cond),
1386                then: self.compile_variable(op.then),
1387                or_else: self.compile_variable(op.or_else),
1388                out: self.compile_variable(out),
1389            }),
1390            // Needs special conversion semantics
1391            gpu::Operator::Cast(op)
1392                if is_fp4_fp6_fp8(op.input.elem_type()) || is_fp4_fp6_fp8(out.elem_type()) =>
1393            {
1394                // We may need these for intermediates
1395                self.flags.elem_f16 = true;
1396                self.flags.elem_bf16 = true;
1397                let vec_in = op.input.ty.line_size();
1398                let packing = out.storage_type().packing_factor();
1399                self.compile_type(op.input.ty.line(packing));
1400                self.compile_type(
1401                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::F16)).line(vec_in),
1402                );
1403                self.compile_type(
1404                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::BF16)).line(vec_in),
1405                );
1406                self.compile_type(
1407                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::F16)).line(packing),
1408                );
1409                self.compile_type(
1410                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::BF16)).line(packing),
1411                );
1412
1413                let inst = self.compile_unary(op, out);
1414
1415                instructions.push(Instruction::SpecialCast(inst));
1416            }
1417            gpu::Operator::Cast(op) => {
1418                let op = self.compile_unary(op, out);
1419
1420                if op.input.elem() == Elem::TF32 || op.out.elem() == Elem::TF32 {
1421                    self.flags.elem_tf32 = true;
1422                }
1423
1424                instructions.push(Instruction::Assign(op))
1425            }
1426            gpu::Operator::Reinterpret(op) => {
1427                instructions.push(Instruction::Bitcast(self.compile_unary(op, out)))
1428            }
1429        };
1430    }
1431
1432    fn compile_binary(
1433        &mut self,
1434        value: gpu::BinaryOperator,
1435        out: gpu::Variable,
1436    ) -> BinaryInstruction<D> {
1437        BinaryInstruction {
1438            lhs: self.compile_variable(value.lhs),
1439            rhs: self.compile_variable(value.rhs),
1440            out: self.compile_variable(out),
1441        }
1442    }
1443
1444    fn compile_index_assign(
1445        &mut self,
1446        value: gpu::IndexAssignOperator,
1447        out: gpu::Variable,
1448    ) -> IndexAssignInstruction<D> {
1449        IndexAssignInstruction {
1450            index: self.compile_variable(value.index),
1451            value: self.compile_variable(value.value),
1452            line_size: value.line_size,
1453            out: self.compile_variable(out),
1454        }
1455    }
1456
1457    fn compile_index(
1458        &mut self,
1459        value: gpu::IndexOperator,
1460        out: gpu::Variable,
1461    ) -> IndexInstruction<D> {
1462        IndexInstruction {
1463            list: self.compile_variable(value.list),
1464            index: self.compile_variable(value.index),
1465            line_size: value.line_size,
1466            out: self.compile_variable(out),
1467        }
1468    }
1469
1470    fn compile_unary(
1471        &mut self,
1472        value: gpu::UnaryOperator,
1473        out: gpu::Variable,
1474    ) -> UnaryInstruction<D> {
1475        UnaryInstruction {
1476            input: self.compile_variable(value.input),
1477            out: self.compile_variable(out),
1478        }
1479    }
1480
1481    fn compile_variable(&mut self, value: gpu::Variable) -> Variable<D> {
1482        let item = value.ty;
1483        match value.kind {
1484            gpu::VariableKind::GlobalInputArray(id) => {
1485                Variable::GlobalInputArray(id, self.compile_type(item))
1486            }
1487            gpu::VariableKind::GlobalScalar(id) => Variable::GlobalScalar {
1488                id,
1489                elem: self.compile_storage_type(item.storage_type()),
1490                in_struct: self.compilation_options.supports_features.grid_constants,
1491            },
1492            gpu::VariableKind::TensorMapInput(id) => {
1493                self.flags.inst_tma = true;
1494                Variable::TensorMap(id)
1495            }
1496            gpu::VariableKind::TensorMapOutput(id) => {
1497                self.flags.inst_tma = true;
1498                Variable::TensorMap(id)
1499            }
1500            gpu::VariableKind::LocalMut { id } => Variable::LocalMut {
1501                id,
1502                item: self.compile_type(item),
1503            },
1504            gpu::VariableKind::Versioned { id, .. } => Variable::LocalMut {
1505                id,
1506                item: self.compile_type(item),
1507            },
1508            gpu::VariableKind::LocalConst { id } => Variable::LocalConst {
1509                id,
1510                item: self.compile_type(item),
1511            },
1512            gpu::VariableKind::GlobalOutputArray(id) => {
1513                Variable::GlobalOutputArray(id, self.compile_type(item))
1514            }
1515            gpu::VariableKind::ConstantScalar(value) => {
1516                Variable::ConstantScalar(value, self.compile_elem(value.elem_type()))
1517            }
1518            gpu::VariableKind::SharedMemory { id, length, .. } => {
1519                let item = self.compile_type(item);
1520                Variable::SharedMemory(id, item, length)
1521            }
1522            gpu::VariableKind::ConstantArray {
1523                id,
1524                length,
1525                unroll_factor,
1526            } => {
1527                let item = self.compile_type(item);
1528                Variable::ConstantArray(id, item, length * unroll_factor)
1529            }
1530            gpu::VariableKind::Builtin(builtin) => match builtin {
1531                gpu::Builtin::AbsolutePos => {
1532                    self.flags.indexes.absolute_pos = true;
1533                    Variable::AbsolutePos
1534                }
1535                gpu::Builtin::CubePosCluster
1536                    if self.compilation_options.supports_features.clusters =>
1537                {
1538                    self.flags.indexes.cluster_pos = true;
1539                    Variable::ClusterRank
1540                }
1541                gpu::Builtin::CubePosClusterX
1542                    if self.compilation_options.supports_features.clusters =>
1543                {
1544                    self.flags.indexes.cluster_pos = true;
1545                    Variable::ClusterIndexX
1546                }
1547                gpu::Builtin::CubePosClusterY
1548                    if self.compilation_options.supports_features.clusters =>
1549                {
1550                    self.flags.indexes.cluster_pos = true;
1551                    Variable::ClusterIndexY
1552                }
1553                gpu::Builtin::CubePosClusterZ
1554                    if self.compilation_options.supports_features.clusters =>
1555                {
1556                    self.flags.indexes.cluster_pos = true;
1557                    Variable::ClusterIndexZ
1558                }
1559                // Fallback if clusters aren't supported, ID is always 0 since clusters are always
1560                // (1, 1, 1) if unsupported
1561                gpu::Builtin::CubePosCluster
1562                | gpu::Builtin::CubePosClusterX
1563                | gpu::Builtin::CubePosClusterY
1564                | gpu::Builtin::CubePosClusterZ => const_u32(0),
1565                gpu::Builtin::AbsolutePosX => {
1566                    self.flags.indexes.absolute_pos_tuple = true;
1567                    Variable::AbsolutePosX
1568                }
1569                gpu::Builtin::AbsolutePosY => {
1570                    self.flags.indexes.absolute_pos_tuple = true;
1571                    Variable::AbsolutePosY
1572                }
1573                gpu::Builtin::AbsolutePosZ => {
1574                    self.flags.indexes.absolute_pos_tuple = true;
1575                    Variable::AbsolutePosZ
1576                }
1577                gpu::Builtin::CubeDim => {
1578                    self.flags.indexes.cube_dim = true;
1579                    Variable::CubeDim
1580                }
1581                gpu::Builtin::CubeDimX => {
1582                    self.flags.indexes.cube_dim_tuple = true;
1583                    Variable::CubeDimX
1584                }
1585                gpu::Builtin::CubeDimY => {
1586                    self.flags.indexes.cube_dim_tuple = true;
1587                    Variable::CubeDimY
1588                }
1589                gpu::Builtin::CubeDimZ => {
1590                    self.flags.indexes.cube_dim_tuple = true;
1591                    Variable::CubeDimZ
1592                }
1593                gpu::Builtin::CubeClusterDim => const_u32(self.cluster_dim.num_elems()),
1594                gpu::Builtin::CubeClusterDimX => const_u32(self.cluster_dim.x),
1595                gpu::Builtin::CubeClusterDimY => const_u32(self.cluster_dim.y),
1596                gpu::Builtin::CubeClusterDimZ => const_u32(self.cluster_dim.z),
1597                gpu::Builtin::CubePos => {
1598                    self.flags.indexes.cube_pos = true;
1599                    Variable::CubePos
1600                }
1601                gpu::Builtin::CubePosX => {
1602                    self.flags.indexes.cube_pos_tuple = true;
1603                    Variable::CubePosX
1604                }
1605                gpu::Builtin::CubePosY => {
1606                    self.flags.indexes.cube_pos_tuple = true;
1607                    Variable::CubePosY
1608                }
1609                gpu::Builtin::CubePosZ => {
1610                    self.flags.indexes.cube_pos_tuple = true;
1611                    Variable::CubePosZ
1612                }
1613                gpu::Builtin::CubeCount => {
1614                    self.flags.indexes.cube_count = true;
1615                    Variable::CubeCount
1616                }
1617                gpu::Builtin::CubeCountX => {
1618                    self.flags.indexes.cube_count_tuple = true;
1619                    Variable::CubeCountX
1620                }
1621                gpu::Builtin::CubeCountY => {
1622                    self.flags.indexes.cube_count_tuple = true;
1623                    Variable::CubeCountY
1624                }
1625                gpu::Builtin::CubeCountZ => {
1626                    self.flags.indexes.cube_count_tuple = true;
1627                    Variable::CubeCountZ
1628                }
1629                gpu::Builtin::UnitPos => {
1630                    self.flags.indexes.unit_pos = true;
1631                    Variable::UnitPos
1632                }
1633                gpu::Builtin::UnitPosX => {
1634                    self.flags.indexes.unit_pos_tuple = true;
1635                    Variable::UnitPosX
1636                }
1637                gpu::Builtin::UnitPosY => {
1638                    self.flags.indexes.unit_pos_tuple = true;
1639                    Variable::UnitPosY
1640                }
1641                gpu::Builtin::UnitPosZ => {
1642                    self.flags.indexes.unit_pos_tuple = true;
1643                    Variable::UnitPosZ
1644                }
1645                gpu::Builtin::PlaneDim => {
1646                    self.flags.indexes.plane_dim = true;
1647                    Variable::PlaneDim
1648                }
1649                gpu::Builtin::UnitPosPlane => {
1650                    self.flags.indexes.unit_pos_plane = true;
1651                    Variable::UnitPosPlane
1652                }
1653            },
1654            gpu::VariableKind::LocalArray {
1655                id,
1656                length,
1657                unroll_factor,
1658            } => {
1659                let item = self.compile_type(item);
1660                if !self.local_arrays.iter().any(|s| s.index == id) {
1661                    self.local_arrays
1662                        .push(LocalArray::new(id, item, length * unroll_factor));
1663                }
1664                Variable::LocalArray(id, item, length)
1665            }
1666            gpu::VariableKind::Matrix { id, mat } => {
1667                self.flags.inst_wmma = true;
1668                Variable::WmmaFragment {
1669                    id,
1670                    frag: self.compile_matrix(mat),
1671                }
1672            }
1673            gpu::VariableKind::Pipeline { id, num_stages } => {
1674                self.flags.op_pipeline = true;
1675                let pipeline = Variable::Pipeline { id };
1676                if !self.pipelines.iter().any(|s| s.pipeline_id() == id) {
1677                    self.pipelines.push(PipelineOps::Init {
1678                        pipeline,
1679                        num_stages,
1680                    });
1681                }
1682                pipeline
1683            }
1684            gpu::VariableKind::Barrier { id, level } => {
1685                self.flags.op_barrier = true;
1686                Variable::Barrier { id, level }
1687            }
1688            gpu::VariableKind::BarrierToken { id, level } => {
1689                self.flags.op_barrier = true;
1690                Variable::BarrierToken { id, level }
1691            }
1692        }
1693    }
1694
1695    fn compile_matrix(&mut self, matrix: gpu::Matrix) -> Fragment<D> {
1696        Fragment {
1697            ident: self.compile_matrix_ident(matrix.ident),
1698            m: matrix.m,
1699            n: matrix.n,
1700            k: matrix.k,
1701            elem: self.compile_storage_type(matrix.storage),
1702            layout: self.compile_matrix_layout(matrix.layout),
1703        }
1704    }
1705
1706    fn compile_matrix_ident(&mut self, ident: gpu::MatrixIdent) -> FragmentIdent<D> {
1707        match ident {
1708            gpu::MatrixIdent::A => FragmentIdent::A,
1709            gpu::MatrixIdent::B => FragmentIdent::B,
1710            gpu::MatrixIdent::Accumulator => FragmentIdent::Accumulator,
1711        }
1712    }
1713
1714    fn compile_matrix_layout(&mut self, layout: gpu::MatrixLayout) -> Option<FragmentLayout<D>> {
1715        match layout {
1716            gpu::MatrixLayout::ColMajor => Some(FragmentLayout::ColMajor),
1717            gpu::MatrixLayout::RowMajor => Some(FragmentLayout::RowMajor),
1718            gpu::MatrixLayout::Undefined => None,
1719        }
1720    }
1721
1722    fn compile_binding(&mut self, binding: cubecl_core::compute::Binding) -> Binding<D> {
1723        Binding {
1724            id: binding.id,
1725            item: self.compile_type(binding.ty),
1726            location: binding.location,
1727            size: binding.size,
1728            vis: binding.visibility,
1729        }
1730    }
1731
1732    fn compile_type(&mut self, ty: gpu::Type) -> Item<D> {
1733        let item = match ty {
1734            gpu::Type::Scalar(ty) => Item::new(self.compile_storage_type(ty), 1, false),
1735            gpu::Type::Line(ty, line_size) => {
1736                Item::new(self.compile_storage_type(ty), line_size as usize, false)
1737            }
1738            gpu::Type::Semantic(_) => Item::new(Elem::Bool, 1, true),
1739        };
1740        if item.elem != super::Elem::TF32 {
1741            self.items.insert(item);
1742            self.items.insert(item.optimized());
1743        } else {
1744            // TF32 is represented as `float` in C++
1745            let mut item = item;
1746            item.elem = super::Elem::F32;
1747            self.items.insert(item);
1748        }
1749
1750        item
1751    }
1752
1753    fn compile_storage_type(&mut self, value: gpu::StorageType) -> Elem<D> {
1754        match value {
1755            gpu::StorageType::Scalar(ty) => self.compile_elem(ty),
1756            gpu::StorageType::Atomic(ty) => Elem::Atomic(ty.into()),
1757            gpu::StorageType::Packed(gpu::ElemType::Float(kind), 2) => match kind {
1758                FloatKind::E2M1 => {
1759                    self.flags.elem_fp4 = true;
1760                    Elem::FP4x2(FP4Kind::E2M1)
1761                }
1762                FloatKind::E2M3 => {
1763                    self.flags.elem_fp6 = true;
1764                    Elem::FP6x2(FP6Kind::E2M3)
1765                }
1766                FloatKind::E3M2 => {
1767                    self.flags.elem_fp6 = true;
1768                    Elem::FP6(FP6Kind::E3M2)
1769                }
1770                FloatKind::E4M3 => {
1771                    self.flags.elem_fp8 = true;
1772                    Elem::FP8x2(FP8Kind::E4M3)
1773                }
1774                FloatKind::E5M2 => {
1775                    self.flags.elem_fp8 = true;
1776                    Elem::FP8x2(FP8Kind::E5M2)
1777                }
1778                FloatKind::UE8M0 => {
1779                    self.flags.elem_fp8 = true;
1780                    Elem::FP8x2(FP8Kind::UE8M0)
1781                }
1782                FloatKind::F16 => {
1783                    self.flags.elem_f16 = true;
1784                    Elem::F16x2
1785                }
1786                FloatKind::BF16 => {
1787                    self.flags.elem_bf16 = true;
1788                    Elem::BF16x2
1789                }
1790                other => unimplemented!("Unsupported storage type: packed<{other:?}, 2>"),
1791            },
1792            other => unimplemented!("Unsupported storage type: {other}"),
1793        }
1794    }
1795
1796    fn compile_elem(&mut self, value: gpu::ElemType) -> Elem<D> {
1797        match value {
1798            gpu::ElemType::Float(kind) => match kind {
1799                gpu::FloatKind::E2M1 => {
1800                    self.flags.elem_fp4 = true;
1801                    Elem::FP4(FP4Kind::E2M1)
1802                }
1803                gpu::FloatKind::E2M3 => {
1804                    self.flags.elem_fp6 = true;
1805                    Elem::FP6(FP6Kind::E2M3)
1806                }
1807                gpu::FloatKind::E3M2 => {
1808                    self.flags.elem_fp6 = true;
1809                    Elem::FP6(FP6Kind::E3M2)
1810                }
1811                gpu::FloatKind::E4M3 => {
1812                    self.flags.elem_fp8 = true;
1813                    Elem::FP8(FP8Kind::E4M3)
1814                }
1815                gpu::FloatKind::E5M2 => {
1816                    self.flags.elem_fp8 = true;
1817                    Elem::FP8(FP8Kind::E5M2)
1818                }
1819                gpu::FloatKind::UE8M0 => {
1820                    self.flags.elem_fp8 = true;
1821                    Elem::FP8(FP8Kind::UE8M0)
1822                }
1823                gpu::FloatKind::F16 => {
1824                    self.flags.elem_f16 = true;
1825                    Elem::F16
1826                }
1827                gpu::FloatKind::BF16 => {
1828                    self.flags.elem_bf16 = true;
1829                    Elem::BF16
1830                }
1831                gpu::FloatKind::TF32 => Elem::TF32,
1832                gpu::FloatKind::Flex32 => Elem::F32,
1833                gpu::FloatKind::F32 => Elem::F32,
1834                gpu::FloatKind::F64 => Elem::F64,
1835            },
1836            gpu::ElemType::Int(kind) => match kind {
1837                gpu::IntKind::I8 => Elem::I8,
1838                gpu::IntKind::I16 => Elem::I16,
1839                gpu::IntKind::I32 => Elem::I32,
1840                gpu::IntKind::I64 => Elem::I64,
1841            },
1842            gpu::ElemType::UInt(kind) => match kind {
1843                gpu::UIntKind::U8 => Elem::U8,
1844                gpu::UIntKind::U16 => Elem::U16,
1845                gpu::UIntKind::U32 => Elem::U32,
1846                gpu::UIntKind::U64 => Elem::U64,
1847            },
1848            gpu::ElemType::Bool => Elem::Bool,
1849        }
1850    }
1851}
1852
1853fn is_fp4_fp6_fp8(elem: gpu::ElemType) -> bool {
1854    match elem {
1855        gpu::ElemType::Float(kind) => matches!(
1856            kind,
1857            FloatKind::E2M1
1858                | FloatKind::E2M3
1859                | FloatKind::E3M2
1860                | FloatKind::E4M3
1861                | FloatKind::E5M2
1862                | FloatKind::UE8M0
1863        ),
1864        _ => false,
1865    }
1866}
1867
1868fn const_u32<D: Dialect>(value: u32) -> Variable<D> {
1869    Variable::ConstantScalar(
1870        gpu::ConstantScalarValue::UInt(value as u64, UIntKind::U32),
1871        Elem::U32,
1872    )
1873}
1874
1875pub fn register_supported_types(props: &mut DeviceProperties) {
1876    let supported_types = [
1877        gpu::ElemType::UInt(gpu::UIntKind::U8),
1878        gpu::ElemType::UInt(gpu::UIntKind::U16),
1879        gpu::ElemType::UInt(gpu::UIntKind::U32),
1880        gpu::ElemType::UInt(gpu::UIntKind::U64),
1881        gpu::ElemType::Int(gpu::IntKind::I8),
1882        gpu::ElemType::Int(gpu::IntKind::I16),
1883        gpu::ElemType::Int(gpu::IntKind::I32),
1884        gpu::ElemType::Int(gpu::IntKind::I64),
1885        gpu::ElemType::Float(gpu::FloatKind::BF16),
1886        gpu::ElemType::Float(gpu::FloatKind::F16),
1887        gpu::ElemType::Float(gpu::FloatKind::F32),
1888        gpu::ElemType::Float(gpu::FloatKind::Flex32),
1889        // Causes CUDA_ERROR_INVALID_VALUE for matmul, disabling until that can be investigated
1890        //gpu::Elem::Float(gpu::FloatKind::F64),
1891        gpu::ElemType::Bool,
1892    ];
1893
1894    let supported_atomic_types = [
1895        gpu::ElemType::Int(gpu::IntKind::I32),
1896        gpu::ElemType::Int(gpu::IntKind::I64),
1897        gpu::ElemType::UInt(gpu::UIntKind::U32),
1898        gpu::ElemType::UInt(gpu::UIntKind::U64),
1899        gpu::ElemType::Float(gpu::FloatKind::F32),
1900    ];
1901
1902    for ty in supported_types {
1903        props.register_type_usage(ty, TypeUsage::all_scalar());
1904    }
1905
1906    for ty in supported_atomic_types {
1907        props.register_type_usage(
1908            gpu::StorageType::Atomic(ty),
1909            TypeUsage::AtomicAdd | TypeUsage::AtomicLoadStore,
1910        );
1911    }
1912}