cubecl_cpp/shared/base.rs

use std::hash::Hash;
use std::{collections::HashSet, fmt::Debug, num::NonZero};

use cubecl_core::ir::expand_checked_index_assign;
use cubecl_core::{
    ir::{self as gpu},
    Compiler, Feature,
};
use cubecl_runtime::{DeviceProperties, ExecutionMode};

use super::{
    AtomicKind, BinaryInstruction, Binding, Body, ComputeKernel, ConstArray, Elem, Fragment,
    FragmentIdent, FragmentLayout, Instruction, Item, LocalArray, SharedMemory, UnaryInstruction,
    Variable, VariableSettings, WarpInstruction, WmmaCompiler, WmmaInstruction,
};

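/// Counter used to generate unique ids for compiler-created temporary variables;
/// reset to zero after each kernel compilation.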
pub(super) static COUNTER_TMP_VAR: std::sync::atomic::AtomicU32 =
    std::sync::atomic::AtomicU32::new(0);

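/// A C++ dialect (e.g. CUDA or HIP) supplies the includes, type names and warp
/// intrinsics that differ between targets, on top of its WMMA support.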
pub trait Dialect:
    WmmaCompiler<Self> + Default + Clone + Copy + Debug + Send + Sync + Eq + Hash + 'static
{
    // includes
    fn include_f16(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
    fn include_bf16(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
    fn include_runtime(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
    // types
    fn bfloat16_type_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
    fn bfloat162_type_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
    // warp instructions (all threads participating)
    fn warp_shuffle(var: &str, source: &str) -> String;
    fn warp_shuffle_xor(var: &str, offset: &str) -> String;
    fn warp_shuffle_down(var: &str, offset: &str) -> String;
    fn warp_all(var: &str) -> String;
    fn warp_any(var: &str) -> String;
}

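/// Options forwarded by the runtime that influence code generation, currently the
/// warp (plane) size of the target device.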
#[derive(Clone, Debug)]
pub struct CompilationOptions {
    pub warp_size: u32,
}

impl Default for CompilationOptions {
    fn default() -> Self {
        Self { warp_size: 32 }
    }
}

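/// Compiler that lowers CubeCL IR into a C++ [`ComputeKernel`] for a concrete
/// [`Dialect`], tracking which features (f16, bf16, WMMA, printf, ...) the
/// generated source requires.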
#[allow(clippy::too_many_arguments)]
#[derive(Clone, Debug, Default)]
pub struct CppCompiler<D: Dialect> {
    shared_memories: Vec<SharedMemory<D>>,
    const_arrays: Vec<ConstArray<D>>,
    local_arrays: Vec<LocalArray<D>>,
    metadata: cubecl_core::Metadata,
    warp_size_checked: bool,
    wmma: bool,
    bf16: bool,
    f16: bool,
    printf: bool,
    num_inputs: usize,
    num_outputs: usize,
    ext_meta_positions: Vec<u32>,
    items: HashSet<Item<D>>,
    strategy: ExecutionMode,
    settings: VariableSettings,
    compilation_options: CompilationOptions,
}

impl<D: Dialect> Compiler for CppCompiler<D> {
    type Representation = ComputeKernel<D>;
    type CompilationOptions = CompilationOptions;

    fn compile(
        kernel: cubecl_core::ir::KernelDefinition,
        compilation_options: &Self::CompilationOptions,
        strategy: ExecutionMode,
    ) -> Self::Representation {
        let compiler = Self {
            compilation_options: compilation_options.clone(),
            strategy,
            ..Self::default()
        };
        let ir = compiler.compile_ir(kernel);
        COUNTER_TMP_VAR.store(0, std::sync::atomic::Ordering::Relaxed);
        ir
    }

    fn elem_size(elem: gpu::Elem) -> usize {
        elem.size()
    }

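    /// Maximum amount of shared memory a kernel may use, in bytes (48 KiB).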
    fn max_shared_memory_size() -> usize {
        49152
    }
}

impl<D: Dialect> CppCompiler<D> {
    fn compile_ir(mut self, mut value: gpu::KernelDefinition) -> ComputeKernel<D> {
        self.build_metadata(&value);

        let instructions = self.compile_scope(&mut value.body);
        let inputs = value
            .inputs
            .into_iter()
            .map(|b| self.compile_binding(b))
            .collect();
        let outputs = value
            .outputs
            .into_iter()
            .map(|b| self.compile_binding(b))
            .collect();
        let named = value
            .named
            .into_iter()
            .map(|(name, binding)| (name, self.compile_binding(binding)))
            .collect();

        let body = Body {
            instructions,
            shared_memories: self.shared_memories,
            const_arrays: self.const_arrays,
            local_arrays: self.local_arrays,
            warp_size_checked: self.warp_size_checked,
            settings: self.settings,
        };

        ComputeKernel {
            inputs,
            outputs,
            named,
            cube_dim: value.cube_dim,
            body,
            wmma_activated: self.wmma,
            bf16: self.bf16,
            f16: self.f16,
            items: self.items,
            kernel_name: value.kernel_name,
        }
    }

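    /// Counts the kernel bindings and assigns each global buffer its position in
    /// the extended metadata (used for strides and shapes).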
    fn build_metadata(&mut self, value: &gpu::KernelDefinition) {
        self.num_inputs = value.inputs.len();
        self.num_outputs = value.outputs.len();

        let mut num_ext = 0;

        for binding in value.inputs.iter().chain(value.outputs.iter()) {
            self.ext_meta_positions.push(num_ext);
            if binding.has_extended_meta {
                num_ext += 1;
            }
        }

        let num_meta = self.num_inputs + self.num_outputs;

        self.metadata = cubecl_core::Metadata::new(num_meta as u32, num_ext);
    }

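    /// Returns the extended-metadata position of a global input or output array.
    ///
    /// Panics if the variable is not a global array.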
    pub(crate) fn ext_meta_position(&self, var: gpu::Variable) -> u32 {
        let pos = match var.kind {
            gpu::VariableKind::GlobalInputArray(id) => id as usize,
            gpu::VariableKind::GlobalOutputArray(id) => self.num_inputs + id as usize,
            other => panic!("Only global arrays have metadata, got {other:?}"),
        };
        self.ext_meta_positions[pos]
    }

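    /// Compiles a scope into a list of instructions, hoisting its constant arrays
    /// and declaring its local variables first.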
    fn compile_scope(&mut self, scope: &mut gpu::Scope) -> Vec<Instruction<D>> {
        let mut instructions = Vec::new();

        let const_arrays = scope
            .const_arrays
            .drain(..)
            .map(|(var, values)| ConstArray {
                index: var.index().unwrap(),
                item: self.compile_item(var.item),
                size: values.len() as u32,
                values: values
                    .into_iter()
                    .map(|val| self.compile_variable(val))
                    .collect(),
            })
            .collect::<Vec<_>>();
        self.const_arrays.extend(const_arrays);

        let processing = scope.process();

        for var in processing.variables {
            if let gpu::VariableKind::Slice { .. } = var.kind {
                continue;
            }
            instructions.push(Instruction::DeclareVariable {
                var: self.compile_variable(var),
            });
        }

        processing
            .operations
            .into_iter()
            .for_each(|op| self.compile_operation(&mut instructions, op, scope));

        instructions
    }

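    /// Lowers a single IR instruction, appending the resulting C++ instructions.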
    fn compile_operation(
        &mut self,
        instructions: &mut Vec<Instruction<D>>,
        instruction: gpu::Instruction,
        scope: &mut gpu::Scope,
    ) {
        let out = instruction.out;
        match instruction.operation {
            gpu::Operation::Copy(variable) => {
                instructions.push(Instruction::Assign(UnaryInstruction {
                    input: self.compile_variable(variable),
                    out: self.compile_variable(out.unwrap()),
                }));
            }
            gpu::Operation::Operator(op) => self.compile_instruction(op, out, instructions, scope),
            gpu::Operation::Atomic(op) => self.compile_atomic(op, out, instructions),
            gpu::Operation::Metadata(op) => instructions.push(self.compile_metadata(op, out)),
            gpu::Operation::Branch(val) => self.compile_branch(instructions, val),
            gpu::Operation::Synchronization(val) => match val {
                gpu::Synchronization::SyncUnits => instructions.push(Instruction::SyncThreads),
                gpu::Synchronization::SyncStorage => instructions.push(Instruction::SyncThreads),
            },
            gpu::Operation::Plane(op) => {
                self.warp_size_checked = true;
                let out = self.compile_variable(out.unwrap());
                match op {
                    gpu::Plane::Sum(op) => {
                        instructions.push(Instruction::Wrap(WarpInstruction::ReduceSum {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Prod(op) => {
                        instructions.push(Instruction::Wrap(WarpInstruction::ReduceProd {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Max(op) => {
                        instructions.push(Instruction::Wrap(WarpInstruction::ReduceMax {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Min(op) => {
                        instructions.push(Instruction::Wrap(WarpInstruction::ReduceMin {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Elect => {
                        instructions.push(Instruction::Wrap(WarpInstruction::Elect { out }))
                    }
                    gpu::Plane::All(op) => {
                        instructions.push(Instruction::Wrap(WarpInstruction::All {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Any(op) => {
                        instructions.push(Instruction::Wrap(WarpInstruction::Any {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Broadcast(op) => {
                        instructions.push(Instruction::Wrap(WarpInstruction::Broadcast {
                            input: self.compile_variable(op.lhs),
                            id: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                }
            }
            gpu::Operation::CoopMma(cmma) => instructions.push(self.compile_cmma(cmma, out)),
            gpu::Operation::NonSemantic(debug) => match debug {
                // No good way to attach debug info
                gpu::NonSemantic::BeginCall { .. }
                | gpu::NonSemantic::EndCall
                | gpu::NonSemantic::Source { .. }
                | gpu::NonSemantic::Line { .. } => {}
                gpu::NonSemantic::Print {
                    format_string,
                    args,
                } => {
                    self.printf = true;
                    instructions.push(Instruction::Printf {
                        format_string,
                        args: args
                            .into_iter()
                            .map(|arg| self.compile_variable(arg))
                            .collect(),
                    })
                }
                gpu::NonSemantic::Comment { content } => {
                    instructions.push(Instruction::Comment { content })
                }
            },
        }
    }

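    /// Lowers a cooperative matrix (CMMA) operation to the dialect's WMMA instruction.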
    fn compile_cmma(&mut self, cmma: gpu::CoopMma, out: Option<gpu::Variable>) -> Instruction<D> {
        let out = self.compile_variable(out.unwrap());
        match cmma {
            gpu::CoopMma::Fill { value } => Instruction::Wmma(WmmaInstruction::Fill {
                frag: out,
                value: self.compile_variable(value),
            }),
            gpu::CoopMma::Load {
                value,
                stride,
                layout,
            } => Instruction::Wmma(WmmaInstruction::Load {
                frag: out,
                value: self.compile_variable(value),
                stride: self.compile_variable(stride),
                layout: layout.and_then(|l| self.compile_matrix_layout(l)),
            }),
            gpu::CoopMma::Execute {
                mat_a,
                mat_b,
                mat_c,
            } => Instruction::Wmma(WmmaInstruction::Execute {
                frag_a: self.compile_variable(mat_a),
                frag_b: self.compile_variable(mat_b),
                frag_c: self.compile_variable(mat_c),
                frag_d: out,
                warp_size: self.compilation_options.warp_size,
            }),
            gpu::CoopMma::Store {
                mat,
                stride,
                layout,
            } => Instruction::Wmma(WmmaInstruction::Store {
                output: out,
                frag: self.compile_variable(mat),
                stride: self.compile_variable(stride),
                layout: self
                    .compile_matrix_layout(layout)
                    .expect("Layout required for store instruction"),
            }),
            gpu::CoopMma::Cast { input } => Instruction::Wmma(WmmaInstruction::Cast {
                input: self.compile_variable(input),
                output: out,
            }),
        }
    }

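    /// Lowers a metadata query (stride, shape, rank, length, buffer length) into a
    /// lookup at the right offset of the kernel's info buffer.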
    fn compile_metadata(
        &mut self,
        metadata: gpu::Metadata,
        out: Option<gpu::Variable>,
    ) -> Instruction<D> {
        let out = out.unwrap();
        match metadata {
            gpu::Metadata::Stride { dim, var } => {
                let position = self.ext_meta_position(var);
                let offset = self.metadata.stride_offset_index(position);
                Instruction::ExtendedMetadata {
                    info_offset: self.compile_variable(offset.into()),
                    dim: self.compile_variable(dim),
                    out: self.compile_variable(out),
                }
            }
            gpu::Metadata::Shape { dim, var } => {
                let position = self.ext_meta_position(var);
                let offset = self.metadata.shape_offset_index(position);
                Instruction::ExtendedMetadata {
                    info_offset: self.compile_variable(offset.into()),
                    dim: self.compile_variable(dim),
                    out: self.compile_variable(out),
                }
            }
            gpu::Metadata::Rank { var } => {
                let out = self.compile_variable(out);
                let pos = self.ext_meta_position(var);
                let offset = self.metadata.rank_index(pos);
                Instruction::Metadata {
                    info_offset: self.compile_variable(offset.into()),
                    out,
                }
            }
            gpu::Metadata::Length { var } => {
                let input = self.compile_variable(var);
                let out = self.compile_variable(out);

                match input {
                    Variable::Slice { .. } => Instruction::SliceLength { input, out },
                    _ => {
                        let id = match input {
                            Variable::GlobalInputArray(id, _) => id,
                            Variable::GlobalOutputArray(id, _) => self.num_inputs as u32 + id,
                            _ => panic!("Can only get length of global array"),
                        };
                        let offset = self.metadata.len_index(id);
                        Instruction::Metadata {
                            info_offset: self.compile_variable(offset.into()),
                            out,
                        }
                    }
                }
            }
            gpu::Metadata::BufferLength { var } => {
                let input = self.compile_variable(var);
                let out = self.compile_variable(out);

                match input {
                    Variable::Slice { .. } => Instruction::SliceLength { input, out },
                    _ => {
                        let id = match input {
                            Variable::GlobalInputArray(id, _) => id,
                            Variable::GlobalOutputArray(id, _) => self.num_inputs as u32 + id,
                            _ => panic!("Can only get buffer length of global array"),
                        };
                        let offset = self.metadata.buffer_len_index(id);
                        Instruction::Metadata {
                            info_offset: self.compile_variable(offset.into()),
                            out,
                        }
                    }
                }
            }
        }
    }

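    /// Lowers control flow (if/else, switch, loops, return, break).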
    fn compile_branch(&mut self, instructions: &mut Vec<Instruction<D>>, branch: gpu::Branch) {
        match branch {
            gpu::Branch::If(mut op) => instructions.push(Instruction::If {
                cond: self.compile_variable(op.cond),
                instructions: self.compile_scope(&mut op.scope),
            }),
            gpu::Branch::IfElse(mut op) => instructions.push(Instruction::IfElse {
                cond: self.compile_variable(op.cond),
                instructions_if: self.compile_scope(&mut op.scope_if),
                instructions_else: self.compile_scope(&mut op.scope_else),
            }),
            gpu::Branch::Switch(mut op) => instructions.push(Instruction::Switch {
                value: self.compile_variable(op.value),
                instructions_default: self.compile_scope(&mut op.scope_default),
                instructions_cases: op
                    .cases
                    .into_iter()
                    .map(|(val, mut block)| {
                        (self.compile_variable(val), self.compile_scope(&mut block))
                    })
                    .collect(),
            }),
            gpu::Branch::Return => instructions.push(Instruction::Return),
            gpu::Branch::Break => instructions.push(Instruction::Break),
            gpu::Branch::RangeLoop(mut range_loop) => instructions.push(Instruction::RangeLoop {
                i: self.compile_variable(range_loop.i),
                start: self.compile_variable(range_loop.start),
                end: self.compile_variable(range_loop.end),
                step: range_loop.step.map(|it| self.compile_variable(it)),
                inclusive: range_loop.inclusive,
                instructions: self.compile_scope(&mut range_loop.scope),
            }),
            gpu::Branch::Loop(mut op) => instructions.push(Instruction::Loop {
                instructions: self.compile_scope(&mut op.scope),
            }),
        };
    }

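    /// Lowers atomic operations (load/store, swap, arithmetic, bitwise, compare-and-swap).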
    fn compile_atomic(
        &mut self,
        value: gpu::AtomicOp,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::AtomicOp::Load(op) => {
                instructions.push(Instruction::AtomicLoad(self.compile_unary(op, out)))
            }
            gpu::AtomicOp::Store(op) => {
                instructions.push(Instruction::AtomicStore(self.compile_unary(op, out)))
            }
            gpu::AtomicOp::Swap(op) => {
                instructions.push(Instruction::AtomicSwap(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Add(op) => {
                instructions.push(Instruction::AtomicAdd(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Sub(op) => {
                instructions.push(Instruction::AtomicSub(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Max(op) => {
                instructions.push(Instruction::AtomicMax(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Min(op) => {
                instructions.push(Instruction::AtomicMin(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::And(op) => {
                instructions.push(Instruction::AtomicAnd(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Or(op) => {
                instructions.push(Instruction::AtomicOr(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::Xor(op) => {
                instructions.push(Instruction::AtomicXor(self.compile_binary(op, out)))
            }
            gpu::AtomicOp::CompareAndSwap(op) => instructions.push(Instruction::AtomicCAS {
                input: self.compile_variable(op.input),
                cmp: self.compile_variable(op.cmp),
                val: self.compile_variable(op.val),
                out: self.compile_variable(out),
            }),
        }
    }

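    /// Lowers an operator; in [`ExecutionMode::Checked`], indexing and slicing of
    /// buffers with a known length are expanded into bounds-checked variants.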
    fn compile_instruction(
        &mut self,
        value: gpu::Operator,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
        scope: &mut gpu::Scope,
    ) {
        let out = out.unwrap();
        match value {
            gpu::Operator::Add(op) => {
                instructions.push(Instruction::Add(self.compile_binary(op, out)))
            }
            gpu::Operator::Mul(op) => {
                instructions.push(Instruction::Mul(self.compile_binary(op, out)))
            }
            gpu::Operator::Div(op) => {
                instructions.push(Instruction::Div(self.compile_binary(op, out)))
            }
            gpu::Operator::Sub(op) => {
                instructions.push(Instruction::Sub(self.compile_binary(op, out)))
            }
            gpu::Operator::Slice(op) => {
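                // Checked mode: materialize the input's (buffer) length so the
                // slice can be bounds-checked at runtime.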
                if matches!(self.strategy, ExecutionMode::Checked) && op.input.has_length() {
                    let input = op.input;
                    let input_len =
                        scope.create_local_mut(gpu::Item::new(gpu::Elem::UInt(gpu::UIntKind::U32)));
                    instructions.extend(self.compile_scope(scope));

                    let length = match input.has_buffer_length() {
                        true => gpu::Metadata::BufferLength { var: input },
                        false => gpu::Metadata::Length { var: input },
                    };

                    instructions.push(self.compile_metadata(length, Some(input_len)));
                    instructions.push(Instruction::CheckedSlice {
                        input: self.compile_variable(op.input),
                        start: self.compile_variable(op.start),
                        end: self.compile_variable(op.end),
                        out: self.compile_variable(out),
                        len: self.compile_variable(input_len),
                    });
                } else {
                    instructions.push(Instruction::Slice {
                        input: self.compile_variable(op.input),
                        start: self.compile_variable(op.start),
                        end: self.compile_variable(op.end),
                        out: self.compile_variable(out),
                    })
                }
            }
            gpu::Operator::Index(op) => {
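                // Checked mode: read the array length and emit a bounds-checked index.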
                if matches!(self.strategy, ExecutionMode::Checked) && op.lhs.has_length() {
                    let lhs = op.lhs;
                    let rhs = op.rhs;

                    let array_len =
                        scope.create_local(gpu::Item::new(gpu::Elem::UInt(gpu::UIntKind::U32)));

                    instructions.extend(self.compile_scope(scope));

                    let length = match lhs.has_buffer_length() {
                        true => gpu::Metadata::BufferLength { var: lhs },
                        false => gpu::Metadata::Length { var: lhs },
                    };
                    instructions.push(self.compile_metadata(length, Some(array_len)));
                    instructions.push(Instruction::CheckedIndex {
                        len: self.compile_variable(array_len),
                        lhs: self.compile_variable(lhs),
                        rhs: self.compile_variable(rhs),
                        out: self.compile_variable(out),
                    });
                } else {
                    instructions.push(Instruction::Index(self.compile_binary(op, out)));
                }
            }
            gpu::Operator::UncheckedIndex(op) => {
                instructions.push(Instruction::Index(self.compile_binary(op, out)))
            }
            gpu::Operator::IndexAssign(op) => {
                if let ExecutionMode::Checked = self.strategy {
                    if out.has_length() {
                        expand_checked_index_assign(scope, op.lhs, op.rhs, out);
                        instructions.extend(self.compile_scope(scope));
                        return;
                    }
                };
                instructions.push(Instruction::IndexAssign(self.compile_binary(op, out)));
            }
            gpu::Operator::UncheckedIndexAssign(op) => {
                instructions.push(Instruction::IndexAssign(self.compile_binary(op, out)))
            }
            gpu::Operator::Modulo(op) => {
                instructions.push(Instruction::Modulo(self.compile_binary(op, out)))
            }
            gpu::Operator::Equal(op) => {
                instructions.push(Instruction::Equal(self.compile_binary(op, out)))
            }
            gpu::Operator::Lower(op) => {
                instructions.push(Instruction::Lower(self.compile_binary(op, out)))
            }
            gpu::Operator::Greater(op) => {
                instructions.push(Instruction::Greater(self.compile_binary(op, out)))
            }
            gpu::Operator::LowerEqual(op) => {
                instructions.push(Instruction::LowerEqual(self.compile_binary(op, out)))
            }
            gpu::Operator::GreaterEqual(op) => {
                instructions.push(Instruction::GreaterEqual(self.compile_binary(op, out)))
            }
            gpu::Operator::Abs(op) => {
                instructions.push(Instruction::Abs(self.compile_unary(op, out)))
            }
            gpu::Operator::Exp(op) => {
                instructions.push(Instruction::Exp(self.compile_unary(op, out)))
            }
            gpu::Operator::Log(op) => {
                instructions.push(Instruction::Log(self.compile_unary(op, out)))
            }
            gpu::Operator::Log1p(op) => {
                instructions.push(Instruction::Log1p(self.compile_unary(op, out)))
            }
            gpu::Operator::Cos(op) => {
                instructions.push(Instruction::Cos(self.compile_unary(op, out)))
            }
            gpu::Operator::Sin(op) => {
                instructions.push(Instruction::Sin(self.compile_unary(op, out)))
            }
            gpu::Operator::Tanh(op) => {
                instructions.push(Instruction::Tanh(self.compile_unary(op, out)))
            }
            gpu::Operator::Powf(op) => {
                instructions.push(Instruction::Powf(self.compile_binary(op, out)))
            }
            gpu::Operator::Sqrt(op) => {
                instructions.push(Instruction::Sqrt(self.compile_unary(op, out)))
            }
            gpu::Operator::Erf(op) => {
                instructions.push(Instruction::Erf(self.compile_unary(op, out)))
            }
            gpu::Operator::And(op) => {
                instructions.push(Instruction::And(self.compile_binary(op, out)))
            }
            gpu::Operator::Or(op) => {
                instructions.push(Instruction::Or(self.compile_binary(op, out)))
            }
            gpu::Operator::Not(op) => {
                instructions.push(Instruction::Not(self.compile_unary(op, out)))
            }
            gpu::Operator::Max(op) => {
                instructions.push(Instruction::Max(self.compile_binary(op, out)))
            }
            gpu::Operator::Min(op) => {
                instructions.push(Instruction::Min(self.compile_binary(op, out)))
            }
            gpu::Operator::NotEqual(op) => {
                instructions.push(Instruction::NotEqual(self.compile_binary(op, out)))
            }
            gpu::Operator::BitwiseOr(op) => {
                instructions.push(Instruction::BitwiseOr(self.compile_binary(op, out)))
            }
            gpu::Operator::BitwiseAnd(op) => {
                instructions.push(Instruction::BitwiseAnd(self.compile_binary(op, out)))
            }
            gpu::Operator::BitwiseXor(op) => {
                instructions.push(Instruction::BitwiseXor(self.compile_binary(op, out)))
            }
            gpu::Operator::CountOnes(op) => {
                instructions.push(Instruction::CountBits(self.compile_unary(op, out)))
            }
            gpu::Operator::ReverseBits(op) => {
                instructions.push(Instruction::ReverseBits(self.compile_unary(op, out)))
            }
            gpu::Operator::ShiftLeft(op) => {
                instructions.push(Instruction::ShiftLeft(self.compile_binary(op, out)))
            }
            gpu::Operator::ShiftRight(op) => {
                instructions.push(Instruction::ShiftRight(self.compile_binary(op, out)))
            }
            gpu::Operator::Clamp(op) => instructions.push(Instruction::Clamp {
                input: self.compile_variable(op.input),
                min_value: self.compile_variable(op.min_value),
                max_value: self.compile_variable(op.max_value),
                out: self.compile_variable(out),
            }),
            gpu::Operator::Recip(op) => {
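                // recip(x) is lowered to a division of the type's "one" constant by x.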
                let elem = op.input.item.elem();
                let lhs = match elem {
                    gpu::Elem::Float(kind) => gpu::ConstantScalarValue::Float(1.0, kind),
                    gpu::Elem::Int(kind) => gpu::ConstantScalarValue::Int(1, kind),
                    gpu::Elem::UInt(kind) => gpu::ConstantScalarValue::UInt(1, kind),
                    gpu::Elem::Bool => gpu::ConstantScalarValue::Bool(true),
                    gpu::Elem::AtomicInt(_)
                    | gpu::Elem::AtomicUInt(_)
                    | gpu::Elem::AtomicFloat(_) => {
                        panic!("Cannot use recip with atomics")
                    }
                };

                instructions.push(Instruction::Div(BinaryInstruction {
                    lhs: Variable::ConstantScalar(lhs, self.compile_elem(elem)),
                    rhs: self.compile_variable(op.input),
                    out: self.compile_variable(out),
                }))
            }
            gpu::Operator::Round(op) => {
                instructions.push(Instruction::Round(self.compile_unary(op, out)))
            }
            gpu::Operator::Floor(op) => {
                instructions.push(Instruction::Floor(self.compile_unary(op, out)))
            }
            gpu::Operator::Ceil(op) => {
                instructions.push(Instruction::Ceil(self.compile_unary(op, out)))
            }
            gpu::Operator::Remainder(op) => {
                instructions.push(Instruction::Remainder(self.compile_binary(op, out)))
            }
            gpu::Operator::Fma(op) => instructions.push(Instruction::Fma {
                a: self.compile_variable(op.a),
                b: self.compile_variable(op.b),
                c: self.compile_variable(op.c),
                out: self.compile_variable(out),
            }),
            gpu::Operator::Bitcast(op) => {
                instructions.push(Instruction::Bitcast(self.compile_unary(op, out)))
            }
            gpu::Operator::Neg(op) => {
                instructions.push(Instruction::Neg(self.compile_unary(op, out)))
            }
            gpu::Operator::Normalize(op) => {
                instructions.push(Instruction::Normalize(self.compile_unary(op, out)))
            }
            gpu::Operator::Magnitude(op) => {
                instructions.push(Instruction::Magnitude(self.compile_unary(op, out)))
            }
            gpu::Operator::Dot(op) => {
                instructions.push(Instruction::Dot(self.compile_binary(op, out)))
            }
            gpu::Operator::InitLine(op) => instructions.push(Instruction::VecInit {
                inputs: op
                    .inputs
                    .into_iter()
                    .map(|it| self.compile_variable(it))
                    .collect(),
                out: self.compile_variable(out),
            }),
            gpu::Operator::CopyMemory(op) => instructions.push(Instruction::Copy {
                input: self.compile_variable(op.input),
                in_index: self.compile_variable(op.in_index),
                out: self.compile_variable(out),
                out_index: self.compile_variable(op.out_index),
            }),
            gpu::Operator::CopyMemoryBulk(op) => instructions.push(Instruction::CopyBulk {
                input: self.compile_variable(op.input),
                in_index: self.compile_variable(op.in_index),
                out: self.compile_variable(out),
                out_index: self.compile_variable(op.out_index),
                len: op.len,
            }),
            gpu::Operator::Select(op) => instructions.push(Instruction::Select {
                cond: self.compile_variable(op.cond),
                then: self.compile_variable(op.then),
                or_else: self.compile_variable(op.or_else),
                out: self.compile_variable(out),
            }),
            gpu::Operator::Cast(op) => {
                instructions.push(Instruction::Assign(self.compile_unary(op, out)))
            }
        };
    }

    fn compile_binary(
        &mut self,
        value: gpu::BinaryOperator,
        out: gpu::Variable,
    ) -> BinaryInstruction<D> {
        BinaryInstruction {
            lhs: self.compile_variable(value.lhs),
            rhs: self.compile_variable(value.rhs),
            out: self.compile_variable(out),
        }
    }

    fn compile_unary(
        &mut self,
        value: gpu::UnaryOperator,
        out: gpu::Variable,
    ) -> UnaryInstruction<D> {
        UnaryInstruction {
            input: self.compile_variable(value.input),
            out: self.compile_variable(out),
        }
    }

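    /// Maps an IR variable to its C++ counterpart, registering shared memories,
    /// local arrays and WMMA fragments on first use and recording which builtin
    /// indices the kernel needs.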
    fn compile_variable(&mut self, value: gpu::Variable) -> Variable<D> {
        let item = value.item;
        match value.kind {
            gpu::VariableKind::GlobalInputArray(id) => {
                Variable::GlobalInputArray(id, self.compile_item(item))
            }
            gpu::VariableKind::GlobalScalar(id) => {
                Variable::GlobalScalar(id, self.compile_item(item).elem, item.elem)
            }
            gpu::VariableKind::LocalMut { id } => Variable::LocalMut {
                id,
                item: self.compile_item(item),
            },
            gpu::VariableKind::Versioned { id, .. } => Variable::LocalMut {
                id,
                item: self.compile_item(item),
            },
            gpu::VariableKind::LocalConst { id } => Variable::LocalConst {
                id,
                item: self.compile_item(item),
            },
            gpu::VariableKind::Slice { id } => Variable::Slice {
                id,
                item: self.compile_item(item),
            },
            gpu::VariableKind::GlobalOutputArray(id) => {
                Variable::GlobalOutputArray(id, self.compile_item(item))
            }
            gpu::VariableKind::ConstantScalar(value) => {
                Variable::ConstantScalar(value, self.compile_elem(value.elem()))
            }
            gpu::VariableKind::SharedMemory { id, length } => {
                let item = self.compile_item(item);
                if !self.shared_memories.iter().any(|s| s.index == id) {
                    self.shared_memories
                        .push(SharedMemory::new(id, item, length));
                }
                Variable::SharedMemory(id, item, length)
            }
            gpu::VariableKind::ConstantArray { id, length } => {
                let item = self.compile_item(item);
                Variable::ConstantArray(id, item, length)
            }
            gpu::VariableKind::Builtin(builtin) => match builtin {
                gpu::Builtin::AbsolutePos => {
                    self.settings.idx_global = true;
                    Variable::IdxGlobal
                }
                gpu::Builtin::UnitPos => {
                    self.settings.thread_idx_global = true;
                    Variable::ThreadIdxGlobal
                }
                gpu::Builtin::UnitPosX => Variable::ThreadIdxX,
                gpu::Builtin::UnitPosY => Variable::ThreadIdxY,
                gpu::Builtin::UnitPosZ => Variable::ThreadIdxZ,
                gpu::Builtin::CubePosX => Variable::BlockIdxX,
                gpu::Builtin::CubePosY => Variable::BlockIdxY,
                gpu::Builtin::CubePosZ => Variable::BlockIdxZ,
                gpu::Builtin::AbsolutePosX => {
                    self.settings.absolute_idx.0 = true;
                    Variable::AbsoluteIdxX
                }
                gpu::Builtin::AbsolutePosY => {
                    self.settings.absolute_idx.1 = true;
                    Variable::AbsoluteIdxY
                }
                gpu::Builtin::AbsolutePosZ => {
                    self.settings.absolute_idx.2 = true;
                    Variable::AbsoluteIdxZ
                }
                gpu::Builtin::CubeDimX => Variable::BlockDimX,
                gpu::Builtin::CubeDimY => Variable::BlockDimY,
                gpu::Builtin::CubeDimZ => Variable::BlockDimZ,
                gpu::Builtin::CubeCountX => Variable::GridDimX,
                gpu::Builtin::CubeCountY => Variable::GridDimY,
                gpu::Builtin::CubeCountZ => Variable::GridDimZ,
                gpu::Builtin::CubePos => {
                    self.settings.block_idx_global = true;
                    Variable::BlockIdxGlobal
                }
                gpu::Builtin::CubeDim => {
                    self.settings.block_dim_global = true;
                    Variable::BlockDimGlobal
                }
                gpu::Builtin::CubeCount => {
                    self.settings.grid_dim_global = true;
                    Variable::GridDimGlobal
                }
                gpu::Builtin::PlaneDim => Variable::WarpSize,
                gpu::Builtin::UnitPosPlane => Variable::ThreadIdxWarp,
            },
            gpu::VariableKind::LocalArray { id, length } => {
                let item = self.compile_item(item);
                if !self.local_arrays.iter().any(|s| s.index == id) {
                    self.local_arrays.push(LocalArray::new(id, item, length));
                }
                Variable::LocalArray(id, item, length)
            }
            gpu::VariableKind::Matrix { id, mat } => {
                self.wmma = true;
                Variable::WmmaFragment {
                    id,
                    frag: self.compile_matrix(mat),
                }
            }
        }
    }

    fn compile_matrix(&mut self, matrix: gpu::Matrix) -> Fragment<D> {
        Fragment {
            ident: self.compile_matrix_ident(matrix.ident),
            m: matrix.m,
            n: matrix.n,
            k: matrix.k,
            elem: self.compile_elem(matrix.elem),
            layout: self.compile_matrix_layout(matrix.layout),
        }
    }

    fn compile_matrix_ident(&mut self, ident: gpu::MatrixIdent) -> FragmentIdent<D> {
        match ident {
            gpu::MatrixIdent::A => FragmentIdent::A,
            gpu::MatrixIdent::B => FragmentIdent::B,
            gpu::MatrixIdent::Accumulator => FragmentIdent::Accumulator,
        }
    }

    fn compile_matrix_layout(&mut self, layout: gpu::MatrixLayout) -> Option<FragmentLayout<D>> {
        match layout {
            gpu::MatrixLayout::ColMajor => Some(FragmentLayout::ColMajor),
            gpu::MatrixLayout::RowMajor => Some(FragmentLayout::RowMajor),
            gpu::MatrixLayout::Undefined => None,
        }
    }

    fn compile_binding(&mut self, binding: gpu::Binding) -> Binding<D> {
        Binding {
            item: self.compile_item(binding.item),
            size: binding.size,
            vis: binding.visibility,
        }
    }

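    /// Compiles an item (element type + vectorization factor) and records every
    /// item the kernel uses so the matching vector types can be declared.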
    fn compile_item(&mut self, item: gpu::Item) -> Item<D> {
        let item = Item::new(
            self.compile_elem(item.elem),
            item.vectorization.map(NonZero::get).unwrap_or(1).into(),
        );
        if item.elem != super::Elem::TF32 {
            self.items.insert(item);
            self.items.insert(item.optimized());
        } else {
            // TF32 is represented as `float` in C++
            let mut item = item;
            item.elem = super::Elem::F32;
            self.items.insert(item);
        }

        item
    }

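    /// Maps an IR element type to its C++ element, flagging f16/bf16 usage so the
    /// corresponding includes are emitted.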
    fn compile_elem(&mut self, value: gpu::Elem) -> Elem<D> {
        match value {
            gpu::Elem::Float(kind) => match kind {
                gpu::FloatKind::F16 => {
                    self.f16 = true;
                    Elem::F16
                }
                gpu::FloatKind::BF16 => {
                    self.bf16 = true;
                    Elem::BF16
                }
                gpu::FloatKind::TF32 => Elem::TF32,
                gpu::FloatKind::Flex32 => Elem::F32,
                gpu::FloatKind::F32 => Elem::F32,
                gpu::FloatKind::F64 => Elem::F64,
            },
            gpu::Elem::AtomicFloat(kind) => match kind {
                gpu::FloatKind::F16 => Elem::Atomic(AtomicKind::F16),
                gpu::FloatKind::BF16 => Elem::Atomic(AtomicKind::BF16),
                gpu::FloatKind::F32 => Elem::Atomic(AtomicKind::F32),
                gpu::FloatKind::F64 => Elem::Atomic(AtomicKind::F64),
                kind => unimplemented!("atomic<{kind:?}> not yet supported"),
            },
            gpu::Elem::Int(kind) => match kind {
                gpu::IntKind::I8 => Elem::I8,
                gpu::IntKind::I16 => Elem::I16,
                gpu::IntKind::I32 => Elem::I32,
                gpu::IntKind::I64 => Elem::I64,
            },
            gpu::Elem::AtomicInt(kind) => match kind {
                gpu::IntKind::I32 => Elem::Atomic(AtomicKind::I32),
                gpu::IntKind::I64 => Elem::Atomic(AtomicKind::I64),
                kind => panic!("atomic<{kind:?}> isn't supported yet"),
            },
            gpu::Elem::UInt(kind) => match kind {
                gpu::UIntKind::U8 => Elem::U8,
                gpu::UIntKind::U16 => Elem::U16,
                gpu::UIntKind::U32 => Elem::U32,
                gpu::UIntKind::U64 => Elem::U64,
            },
            gpu::Elem::AtomicUInt(kind) => match kind {
                gpu::UIntKind::U32 => Elem::Atomic(AtomicKind::U32),
                gpu::UIntKind::U64 => Elem::Atomic(AtomicKind::U64),
                kind => unimplemented!("atomic<{kind:?}> not yet supported"),
            },
            gpu::Elem::Bool => Elem::Bool,
        }
    }
}

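/// Registers the element types supported by the C++ targets on the device properties.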
pub fn register_supported_types(props: &mut DeviceProperties<Feature>) {
    let supported_types = [
        gpu::Elem::UInt(gpu::UIntKind::U8),
        gpu::Elem::UInt(gpu::UIntKind::U16),
        gpu::Elem::UInt(gpu::UIntKind::U32),
        gpu::Elem::UInt(gpu::UIntKind::U64),
        gpu::Elem::Int(gpu::IntKind::I8),
        gpu::Elem::Int(gpu::IntKind::I16),
        gpu::Elem::Int(gpu::IntKind::I32),
        gpu::Elem::Int(gpu::IntKind::I64),
        gpu::Elem::AtomicInt(gpu::IntKind::I32),
        gpu::Elem::AtomicInt(gpu::IntKind::I64),
        gpu::Elem::AtomicUInt(gpu::UIntKind::U32),
        gpu::Elem::AtomicUInt(gpu::UIntKind::U64),
        gpu::Elem::Float(gpu::FloatKind::BF16),
        gpu::Elem::Float(gpu::FloatKind::F16),
        gpu::Elem::Float(gpu::FloatKind::F32),
        gpu::Elem::Float(gpu::FloatKind::Flex32),
        gpu::Elem::AtomicFloat(gpu::FloatKind::F32),
        // Causes CUDA_ERROR_INVALID_VALUE for matmul, disabling until that can be investigated
        //gpu::Elem::Float(gpu::FloatKind::F64),
        gpu::Elem::Bool,
    ];

    for ty in supported_types {
        props.register_feature(Feature::Type(ty));
    }
}