Skip to main content

cubecl_cpp/shared/
instruction.rs

1use crate::shared::FmtLeft;
2
3use super::{
4    Component, Dialect, Elem, Item, Variable, WarpInstruction, WmmaInstruction,
5    barrier::BarrierOps, binary::*, pipeline::PipelineOps, unary::*,
6};
7use std::{
8    borrow::Cow,
9    fmt::{Display, Formatter, Write},
10    marker::PhantomData,
11};
12
/// The identifier `info`, as spelled in emitted code.
pub(crate) const INFO_NAME: &str = "info";
/// Identifier of the array indexed when emitting `Instruction::ExtendedMetadata`.
pub(crate) const DYNAMIC_META_NAME: &str = "dynamic_meta";
/// Member-access expression indexed when emitting `Instruction::Metadata`
/// (and used as the base offset lookup for `Instruction::ExtendedMetadata`).
pub(crate) const STATIC_META_NAME: &str = "info.static_meta";
16
17#[derive(Debug, Clone, Copy)]
18pub struct BinaryInstruction<D: Dialect> {
19    pub lhs: Variable<D>,
20    pub rhs: Variable<D>,
21    pub out: Variable<D>,
22}
23
24#[derive(Debug, Clone)]
25pub struct IndexInstruction<D: Dialect> {
26    pub list: Variable<D>,
27    pub index: Variable<D>,
28    pub vector_size: u32,
29    pub out: Variable<D>,
30}
31
32#[derive(Debug, Clone)]
33pub struct IndexAssignInstruction<D: Dialect> {
34    pub index: Variable<D>,
35    pub value: Variable<D>,
36    pub vector_size: u32,
37    pub out: Variable<D>,
38}
39
40#[derive(Debug, Clone, Copy)]
41pub struct UnaryInstruction<D: Dialect> {
42    pub input: Variable<D>,
43    pub out: Variable<D>,
44}
45
46#[derive(Debug, Clone)]
47pub enum Instruction<D: Dialect> {
48    Metadata {
49        info_offset: Variable<D>,
50        out: Variable<D>,
51    },
52    ExtendedMetadata {
53        info_offset: Variable<D>,
54        dim: Variable<D>,
55        out: Variable<D>,
56    },
57    ConstLength {
58        length: usize,
59        out: Variable<D>,
60    },
61    SliceLength {
62        input: Variable<D>,
63        out: Variable<D>,
64    },
65    DeclareVariable {
66        var: Variable<D>,
67    },
68    Modulo(BinaryInstruction<D>),
69    Remainder(BinaryInstruction<D>),
70    Add(BinaryInstruction<D>),
71    SaturatingAdd(BinaryInstruction<D>),
72    Fma {
73        a: Variable<D>,
74        b: Variable<D>,
75        c: Variable<D>,
76        out: Variable<D>,
77    },
78    Div(BinaryInstruction<D>),
79    FastDiv(BinaryInstruction<D>),
80    FastRecip(UnaryInstruction<D>),
81    Mul(BinaryInstruction<D>),
82    Sub(BinaryInstruction<D>),
83    SaturatingSub(BinaryInstruction<D>),
84    HiMul(BinaryInstruction<D>),
85    Index(IndexInstruction<D>),
86    IndexAssign(IndexAssignInstruction<D>),
87    Assign(UnaryInstruction<D>),
88    SpecialCast(UnaryInstruction<D>),
89    RangeLoop {
90        i: Variable<D>,
91        start: Variable<D>,
92        end: Variable<D>,
93        step: Option<Variable<D>>,
94        inclusive: bool,
95        instructions: Vec<Self>,
96    },
97    VecInit {
98        inputs: Vec<Variable<D>>,
99        out: Variable<D>,
100    },
101    Loop {
102        instructions: Vec<Self>,
103    },
104    If {
105        cond: Variable<D>,
106        instructions: Vec<Self>,
107    },
108    IfElse {
109        cond: Variable<D>,
110        instructions_if: Vec<Self>,
111        instructions_else: Vec<Self>,
112    },
113    Select {
114        cond: Variable<D>,
115        then: Variable<D>,
116        or_else: Variable<D>,
117        out: Variable<D>,
118    },
119    Switch {
120        value: Variable<D>,
121        instructions_default: Vec<Self>,
122        instructions_cases: Vec<(Variable<D>, Vec<Self>)>,
123    },
124    Slice {
125        input: Variable<D>,
126        start: Variable<D>,
127        end: Variable<D>,
128        out: Variable<D>,
129    },
130    CheckedSlice {
131        input: Variable<D>,
132        start: Variable<D>,
133        end: Variable<D>,
134        out: Variable<D>,
135        len: Variable<D>,
136    },
137    ReinterpretSlice {
138        input: Variable<D>,
139        vector_size: u32,
140        out: Variable<D>,
141    },
142    Return,
143    Break,
144    Unreachable,
145    Equal(BinaryInstruction<D>),
146    NotEqual(BinaryInstruction<D>),
147    Lower(BinaryInstruction<D>),
148    Greater(BinaryInstruction<D>),
149    LowerEqual(BinaryInstruction<D>),
150    GreaterEqual(BinaryInstruction<D>),
151    Erf(UnaryInstruction<D>),
152    BitwiseOr(BinaryInstruction<D>),
153    BitwiseAnd(BinaryInstruction<D>),
154    BitwiseXor(BinaryInstruction<D>),
155    CountBits(UnaryInstruction<D>),
156    ReverseBits(UnaryInstruction<D>),
157    ShiftLeft(BinaryInstruction<D>),
158    ShiftRight(BinaryInstruction<D>),
159    BitwiseNot(UnaryInstruction<D>),
160    LeadingZeros(UnaryInstruction<D>),
161    TrailingZeros(UnaryInstruction<D>),
162    FindFirstSet(UnaryInstruction<D>),
163    Abs(UnaryInstruction<D>),
164    Exp(UnaryInstruction<D>),
165    FastExp(UnaryInstruction<D>),
166    Log(UnaryInstruction<D>),
167    FastLog(UnaryInstruction<D>),
168    Log1p(UnaryInstruction<D>),
169    Cos(UnaryInstruction<D>),
170    Sin(UnaryInstruction<D>),
171    Tan(UnaryInstruction<D>),
172    Tanh(UnaryInstruction<D>),
173    Sinh(UnaryInstruction<D>),
174    Cosh(UnaryInstruction<D>),
175    ArcCos(UnaryInstruction<D>),
176    ArcSin(UnaryInstruction<D>),
177    ArcTan(UnaryInstruction<D>),
178    ArcSinh(UnaryInstruction<D>),
179    ArcCosh(UnaryInstruction<D>),
180    ArcTanh(UnaryInstruction<D>),
181    Degrees(UnaryInstruction<D>),
182    Radians(UnaryInstruction<D>),
183    ArcTan2(BinaryInstruction<D>),
184    FastSin(UnaryInstruction<D>),
185    FastCos(UnaryInstruction<D>),
186    FastTanh(UnaryInstruction<D>),
187    Powf(BinaryInstruction<D>),
188    FastPowf(BinaryInstruction<D>),
189    Powi(BinaryInstruction<D>),
190    Hypot(BinaryInstruction<D>),
191    Rhypot(BinaryInstruction<D>),
192    Sqrt(UnaryInstruction<D>),
193    FastSqrt(UnaryInstruction<D>),
194    InverseSqrt(UnaryInstruction<D>),
195    FastInverseSqrt(UnaryInstruction<D>),
196    Min(BinaryInstruction<D>),
197    Max(BinaryInstruction<D>),
198    Not(UnaryInstruction<D>),
199    Or(BinaryInstruction<D>),
200    And(BinaryInstruction<D>),
201    Clamp {
202        input: Variable<D>,
203        min_value: Variable<D>,
204        max_value: Variable<D>,
205        out: Variable<D>,
206    },
207    IsNan(UnaryInstruction<D>),
208    IsInf(UnaryInstruction<D>),
209    SyncThreads,
210    SyncWarp,
211    ThreadFence,
212    ProxyAsyncToSharedFence,
213    BulkCommitGroup,
214    BulkWaitGroup {
215        max_pending: u32,
216    },
217    BulkWaitGroupRead {
218        max_pending: u32,
219    },
220    TmaReplacePointer {
221        buffer: Variable<D>,
222        offset: Variable<D>,
223        tensor_map: Variable<D>,
224        out: Variable<D>,
225    },
226    Round(UnaryInstruction<D>),
227    Ceil(UnaryInstruction<D>),
228    Trunc(UnaryInstruction<D>),
229    Floor(UnaryInstruction<D>),
230    Warp(WarpInstruction<D>),
231    Wmma(WmmaInstruction<D>),
232    Bitcast(UnaryInstruction<D>),
233    AtomicLoad(UnaryInstruction<D>),
234    AtomicStore(UnaryInstruction<D>),
235    AtomicSwap(BinaryInstruction<D>),
236    AtomicAdd(BinaryInstruction<D>),
237    AtomicSub(BinaryInstruction<D>),
238    AtomicMax(BinaryInstruction<D>),
239    AtomicMin(BinaryInstruction<D>),
240    AtomicAnd(BinaryInstruction<D>),
241    AtomicOr(BinaryInstruction<D>),
242    AtomicXor(BinaryInstruction<D>),
243    AtomicCAS {
244        input: Variable<D>,
245        cmp: Variable<D>,
246        val: Variable<D>,
247        out: Variable<D>,
248    },
249    Neg(UnaryInstruction<D>),
250    Magnitude(UnaryInstruction<D>),
251    FastMagnitude(UnaryInstruction<D>),
252    Normalize(UnaryInstruction<D>),
253    FastNormalize(UnaryInstruction<D>),
254    Dot(BinaryInstruction<D>),
255    Copy {
256        input: Variable<D>,
257        in_index: Variable<D>,
258        out: Variable<D>,
259        out_index: Variable<D>,
260    },
261    CopyBulk {
262        input: Variable<D>,
263        in_index: Variable<D>,
264        out: Variable<D>,
265        out_index: Variable<D>,
266        len: u32,
267    },
268    Printf {
269        format_string: String,
270        args: Vec<Variable<D>>,
271    },
272    Comment {
273        content: String,
274    },
275    Pipeline(PipelineOps<D>),
276    Barrier(BarrierOps<D>),
277    MemCopyAsyncTensorSharedToGlobal {
278        smem_buffer: Variable<D>,
279        smem_offset: Variable<D>,
280        tensor_map: Variable<D>,
281        indices: Vec<Variable<D>>,
282    },
283    Line {
284        file: Cow<'static, str>,
285        line: u32,
286    },
287}
288
289impl<D: Dialect> Display for Instruction<D> {
290    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
291        match self {
292            Instruction::Return => f.write_str("return;"),
293            Instruction::Break => f.write_str("break;"),
294            Instruction::Unreachable => D::compile_unreachable(f),
295            Instruction::DeclareVariable { var } => match var {
296                Variable::WmmaFragment { .. } => D::compile_wmma_fragment_declaration(f, var),
297                _ => {
298                    let item = var.item();
299                    writeln!(f, "{item} {var};")
300                }
301            },
302            Instruction::Add(it) => Add::format(f, &it.lhs, &it.rhs, &it.out),
303            Instruction::SaturatingAdd(it) => SaturatingAdd::format(f, &it.lhs, &it.rhs, &it.out),
304            Instruction::Slice {
305                input,
306                start,
307                end,
308                out,
309            } => {
310                let item = out.item();
311                let addr_space = D::address_space_for_variable(input);
312                writeln!(f, "const uint {out}_length = {end} - {start};")?;
313                writeln!(f, "{addr_space}{item} *{out} = {input} + {start};")
314            }
315            Instruction::CheckedSlice {
316                input,
317                start,
318                end,
319                out,
320                len,
321            } => {
322                let item = out.item();
323                let addr_space = D::address_space_for_variable(input);
324                writeln!(f, "const uint {out}_length = min({len}, {end}) - {start};")?;
325                writeln!(f, "{addr_space}{item} *{out} = {input} + {start};")
326            }
327            Instruction::ReinterpretSlice {
328                input,
329                vector_size,
330                out,
331            } => {
332                let mut item = out.item();
333                item.vectorization = *vector_size as usize;
334                let addr_space = D::address_space_for_variable(input);
335
336                writeln!(
337                    f,
338                    "{addr_space}{item} *{out} = reinterpret_cast<{item}*>({input});"
339                )
340            }
341            Instruction::Mul(it) => Mul::format(f, &it.lhs, &it.rhs, &it.out),
342            Instruction::Div(it) => Div::format(f, &it.lhs, &it.rhs, &it.out),
343            Instruction::FastDiv(it) => FastDiv::format(f, &it.lhs, &it.rhs, &it.out),
344            Instruction::FastRecip(it) => FastRecip::format(f, &it.input, &it.out),
345            Instruction::Sub(it) => Sub::format(f, &it.lhs, &it.rhs, &it.out),
346            Instruction::SaturatingSub(it) => SaturatingSub::format(f, &it.lhs, &it.rhs, &it.out),
347            Instruction::HiMul(it) => HiMul::format(f, &it.lhs, &it.rhs, &it.out),
348            Instruction::Modulo(inst) => Modulo::format(f, &inst.lhs, &inst.rhs, &inst.out),
349            Instruction::BitwiseOr(it) => BitwiseOr::format(f, &it.lhs, &it.rhs, &it.out),
350            Instruction::BitwiseAnd(it) => BitwiseAnd::format(f, &it.lhs, &it.rhs, &it.out),
351            Instruction::BitwiseXor(it) => BitwiseXor::format(f, &it.lhs, &it.rhs, &it.out),
352            Instruction::CountBits(it) => CountBits::format(f, &it.input, &it.out),
353            Instruction::ReverseBits(it) => ReverseBits::format(f, &it.input, &it.out),
354            Instruction::LeadingZeros(it) => LeadingZeros::format(f, &it.input, &it.out),
355            Instruction::TrailingZeros(it) => TrailingZeros::format(f, &it.input, &it.out),
356            Instruction::FindFirstSet(it) => FindFirstSet::format(f, &it.input, &it.out),
357            Instruction::ShiftLeft(it) => ShiftLeft::format(f, &it.lhs, &it.rhs, &it.out),
358            Instruction::ShiftRight(it) => ShiftRight::format(f, &it.lhs, &it.rhs, &it.out),
359            Instruction::Index(it) => {
360                Index::format(f, &it.list, &it.index, &it.out, it.vector_size)
361            }
362            Instruction::IndexAssign(it) => {
363                IndexAssign::format(f, &it.index, &it.value, &it.out, it.vector_size)
364            }
365            Instruction::Copy {
366                input,
367                in_index,
368                out,
369                out_index,
370            } => {
371                writeln!(f, "{out}[{out_index}] = {input}[{in_index}];")
372            }
373            Instruction::CopyBulk {
374                input,
375                in_index,
376                out,
377                out_index,
378                len,
379            } => {
380                for i in 0..*len {
381                    writeln!(f, "{out}[{out_index} + {i}] = {input}[{in_index} + {i}];")?;
382                }
383                Ok(())
384            }
385            Instruction::Assign(it) => Assign::format(f, &it.input, &it.out),
386            Instruction::RangeLoop {
387                i,
388                start,
389                end,
390                step,
391                inclusive,
392                instructions,
393            } => {
394                let increment = step
395                    .map(|step| format!("{i} += {step}"))
396                    .unwrap_or_else(|| format!("++{i}"));
397                let cmp = if *inclusive { "<=" } else { "<" };
398                let i_ty = i.item();
399
400                write!(
401                    f,
402                    "
403for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{
404"
405                )?;
406                for instruction in instructions {
407                    write!(f, "{instruction}")?;
408                }
409
410                f.write_str("}\n")
411            }
412            Instruction::Loop { instructions } => {
413                writeln!(f, "while (true) {{")?;
414                for i in instructions {
415                    write!(f, "{i}")?;
416                }
417                f.write_str("}\n")
418            }
419            Instruction::If { cond, instructions } => {
420                writeln!(f, "if ({cond}) {{")?;
421                for i in instructions {
422                    write!(f, "{i}")?;
423                }
424                f.write_str("}\n")
425            }
426            Instruction::IfElse {
427                cond,
428                instructions_if,
429                instructions_else,
430            } => {
431                writeln!(f, "if ({cond}) {{")?;
432                for i in instructions_if {
433                    write!(f, "{i}")?;
434                }
435                f.write_str("} else {\n")?;
436                for i in instructions_else {
437                    write!(f, "{i}")?;
438                }
439                f.write_str("}\n")
440            }
441            Instruction::Select {
442                cond,
443                then,
444                or_else,
445                out,
446            } => {
447                let item_or_else = or_else.item();
448                let item_then = then.item();
449                let item_out = out.item();
450
451                let vf_then = item_then.vectorization;
452                let vf_or_else = item_or_else.vectorization;
453                let vf_out = item_out.vectorization;
454                let vf_cond = cond.item().vectorization;
455
456                let item_out = out.item();
457                let cond_elem = cond.item().elem;
458                let out = out.fmt_left();
459
460                let should_broadcast =
461                    vf_cond > 1 || item_out != item_or_else || item_out != item_then;
462
463                if should_broadcast {
464                    let vf = usize::max(vf_cond, vf_out);
465                    let vf = usize::max(vf, vf_then);
466                    let vf = usize::max(vf, vf_or_else);
467
468                    writeln!(f, "{out} = {item_out} {{")?;
469                    for i in 0..vf {
470                        let theni = then.index(i);
471                        let or_elsei = or_else.index(i);
472                        let condi = cond.index(i);
473                        let condi = EnsureBoolArg {
474                            var: &condi,
475                            elem: &cond_elem,
476                        };
477
478                        writeln!(f, "({condi}) ? {theni} : {or_elsei},")?;
479                    }
480
481                    writeln!(f, "}};")
482                } else {
483                    let cond = EnsureBoolArg {
484                        var: &cond,
485                        elem: &cond_elem,
486                    };
487                    writeln!(f, "{out} = ({cond}) ? {then} : {or_else};")
488                }
489            }
490            Instruction::Switch {
491                value,
492                instructions_default,
493                instructions_cases,
494            } => {
495                writeln!(f, "switch({value}) {{")?;
496                for (value, block) in instructions_cases {
497                    write!(f, "case {value}:\n{{\n")?;
498                    for i in block {
499                        i.fmt(f)?;
500                    }
501                    f.write_str("break;\n}\n")?;
502                }
503                f.write_str("default:\n{")?;
504                for i in instructions_default {
505                    i.fmt(f)?;
506                }
507                f.write_str("break;\n}\n}\n")
508            }
509            Instruction::Metadata { info_offset, out } => {
510                let out = out.fmt_left();
511                writeln!(f, "{out} = {STATIC_META_NAME}[{info_offset}];")
512            }
513            Instruction::ExtendedMetadata {
514                info_offset,
515                dim,
516                out,
517            } => {
518                let out = out.fmt_left();
519                writeln!(
520                    f,
521                    "{out} = {DYNAMIC_META_NAME}[{STATIC_META_NAME}[{info_offset}] + {dim}];"
522                )
523            }
524            Instruction::Equal(it) => Equal::format(f, &it.lhs, &it.rhs, &it.out),
525            Instruction::NotEqual(it) => NotEqual::format(f, &it.lhs, &it.rhs, &it.out),
526            Instruction::Lower(it) => Lower::format(f, &it.lhs, &it.rhs, &it.out),
527            Instruction::Greater(it) => Greater::format(f, &it.lhs, &it.rhs, &it.out),
528            Instruction::LowerEqual(it) => LowerEqual::format(f, &it.lhs, &it.rhs, &it.out),
529            Instruction::GreaterEqual(it) => GreaterEqual::format(f, &it.lhs, &it.rhs, &it.out),
530            Instruction::Erf(it) => Erf::format(f, &it.input, &it.out),
531            Instruction::Abs(it) => Abs::format(f, &it.input, &it.out),
532            Instruction::Exp(it) => Exp::format(f, &it.input, &it.out),
533            Instruction::FastExp(it) => FastExp::format(f, &it.input, &it.out),
534            Instruction::Log(it) => Log::format(f, &it.input, &it.out),
535            Instruction::FastLog(it) => FastLog::format(f, &it.input, &it.out),
536            Instruction::Log1p(it) => Log1p::format(f, &it.input, &it.out),
537            Instruction::Cos(it) => Cos::format(f, &it.input, &it.out),
538            Instruction::FastCos(it) => FastCos::format(f, &it.input, &it.out),
539            Instruction::Sin(it) => Sin::format(f, &it.input, &it.out),
540            Instruction::Tan(it) => Tan::format(f, &it.input, &it.out),
541            Instruction::Tanh(it) => Tanh::format(f, &it.input, &it.out),
542            Instruction::Sinh(it) => Sinh::format(f, &it.input, &it.out),
543            Instruction::Cosh(it) => Cosh::format(f, &it.input, &it.out),
544            Instruction::ArcCos(it) => ArcCos::format(f, &it.input, &it.out),
545            Instruction::ArcSin(it) => ArcSin::format(f, &it.input, &it.out),
546            Instruction::ArcTan(it) => ArcTan::format(f, &it.input, &it.out),
547            Instruction::ArcSinh(it) => ArcSinh::format(f, &it.input, &it.out),
548            Instruction::ArcCosh(it) => ArcCosh::format(f, &it.input, &it.out),
549            Instruction::ArcTanh(it) => ArcTanh::format(f, &it.input, &it.out),
550            Instruction::Degrees(it) => Degrees::format(f, &it.input, &it.out),
551            Instruction::Radians(it) => Radians::format(f, &it.input, &it.out),
552            Instruction::ArcTan2(it) => ArcTan2::format(f, &it.lhs, &it.rhs, &it.out),
553            Instruction::FastSin(it) => FastSin::format(f, &it.input, &it.out),
554            Instruction::FastTanh(it) => FastTanh::format(f, &it.input, &it.out),
555            Instruction::Powf(it) => Powf::format(f, &it.lhs, &it.rhs, &it.out),
556            Instruction::FastPowf(it) => FastPowf::format(f, &it.lhs, &it.rhs, &it.out),
557            Instruction::Powi(it) => Powi::format(f, &it.lhs, &it.rhs, &it.out),
558            Instruction::Hypot(it) => Hypot::format(f, &it.lhs, &it.rhs, &it.out),
559            Instruction::Rhypot(it) => Rhypot::format(f, &it.lhs, &it.rhs, &it.out),
560            Instruction::Sqrt(it) => Sqrt::format(f, &it.input, &it.out),
561            Instruction::FastSqrt(it) => FastSqrt::format(f, &it.input, &it.out),
562            Instruction::InverseSqrt(it) => InverseSqrt::format(f, &it.input, &it.out),
563            Instruction::FastInverseSqrt(it) => FastInverseSqrt::format(f, &it.input, &it.out),
564            Instruction::Max(it) => Max::format(f, &it.lhs, &it.rhs, &it.out),
565            Instruction::Min(it) => Min::format(f, &it.lhs, &it.rhs, &it.out),
566            Instruction::Not(it) => Not::format(f, &it.input, &it.out),
567            Instruction::BitwiseNot(it) => BitwiseNot::format(f, &it.input, &it.out),
568            Instruction::Or(it) => Or::format(f, &it.lhs, &it.rhs, &it.out),
569            Instruction::And(it) => And::format(f, &it.lhs, &it.rhs, &it.out),
570            Instruction::Clamp {
571                input,
572                min_value,
573                max_value,
574                out,
575            } => Clamp::format(f, input, min_value, max_value, out),
576            Instruction::IsNan(it) => IsNan::format(f, &it.input, &it.out),
577            Instruction::IsInf(it) => IsInf::format(f, &it.input, &it.out),
578            Instruction::SyncThreads => D::compile_instruction_sync_threads(f),
579            Instruction::SyncWarp => D::compile_instruction_sync_warp(f),
580            Instruction::ThreadFence => f.write_str("__threadfence();\n"),
581            Instruction::Round(it) => Round::format(f, &it.input, &it.out),
582            Instruction::Ceil(it) => Ceil::format(f, &it.input, &it.out),
583            Instruction::Trunc(it) => Trunc::format(f, &it.input, &it.out),
584            Instruction::Floor(it) => Floor::format(f, &it.input, &it.out),
585            Instruction::SliceLength { input, out } => {
586                let out = out.fmt_left();
587                writeln!(f, "{out} = {input}_length;")
588            }
589            Instruction::ConstLength { length, out } => {
590                let out = out.fmt_left();
591                writeln!(f, "{out} = {length};")
592            }
593            Instruction::Warp(it) => write!(f, "{it}"),
594            Instruction::Fma { a, b, c, out } => Fma::format(f, a, b, c, out),
595            Instruction::Wmma(it) => write!(f, "{it}"),
596            Instruction::Bitcast(UnaryInstruction { input, out }) => {
597                let qualifier = out.const_qualifier();
598                let input_item = input.item();
599                let out_item = out.item();
600
601                if out_item.elem.size() * out_item.vectorization
602                    != input.item().elem.size() * input.item().vectorization
603                {
604                    panic!("Unsupported type for bitcasting {out_item:?} from {input_item:?}");
605                } else {
606                    let out = out.fmt_left();
607                    let addr_space = D::address_space_for_variable(input);
608                    writeln!(
609                        f,
610                        "{out} = reinterpret_cast<{addr_space}{out_item}{qualifier}&>({input});"
611                    )
612                }
613            }
614            Instruction::AtomicAdd(BinaryInstruction { lhs, rhs, out }) => {
615                D::compile_atomic_add(f, lhs, rhs, out)
616            }
617            Instruction::AtomicAnd(BinaryInstruction { lhs, rhs, out }) => {
618                D::compile_atomic_and(f, lhs, rhs, out)
619            }
620            Instruction::AtomicCAS {
621                input,
622                cmp,
623                val,
624                out,
625            } => D::compile_atomic_cas(f, input, cmp, val, out),
626            Instruction::AtomicLoad(UnaryInstruction { input, out }) => {
627                D::compile_atomic_load(f, input, out)
628            }
629            Instruction::AtomicMax(BinaryInstruction { lhs, rhs, out }) => {
630                D::compile_atomic_max(f, lhs, rhs, out)
631            }
632            Instruction::AtomicMin(BinaryInstruction { lhs, rhs, out }) => {
633                D::compile_atomic_min(f, lhs, rhs, out)
634            }
635            Instruction::AtomicOr(BinaryInstruction { lhs, rhs, out }) => {
636                D::compile_atomic_or(f, lhs, rhs, out)
637            }
638            Instruction::AtomicStore(UnaryInstruction { input, out }) => {
639                D::compile_atomic_store(f, input, out)
640            }
641            Instruction::AtomicSub(BinaryInstruction { lhs, rhs, out }) => {
642                D::compile_atomic_sub(f, lhs, rhs, out)
643            }
644            Instruction::AtomicSwap(BinaryInstruction { lhs, rhs, out }) => {
645                D::compile_atomic_swap(f, lhs, rhs, out)
646            }
647            Instruction::AtomicXor(BinaryInstruction { lhs, rhs, out }) => {
648                D::compile_atomic_xor(f, lhs, rhs, out)
649            }
650            Instruction::Remainder(inst) => Remainder::format(f, &inst.lhs, &inst.rhs, &inst.out),
651            Instruction::Neg(UnaryInstruction { input, out }) => {
652                let out = out.fmt_left();
653                writeln!(f, "{out} = -{input};")
654            }
655            Instruction::Normalize(inst) => {
656                Normalize::<D, InverseSqrt>::format(f, &inst.input, &inst.out)
657            }
658            Instruction::FastNormalize(inst) => {
659                Normalize::<D, FastInverseSqrt>::format(f, &inst.input, &inst.out)
660            }
661            Instruction::Magnitude(inst) => Magnitude::<D, Sqrt>::format(f, &inst.input, &inst.out),
662            Instruction::FastMagnitude(inst) => {
663                Magnitude::<D, FastSqrt>::format(f, &inst.input, &inst.out)
664            }
665            Instruction::Dot(inst) => Dot::format(f, &inst.lhs, &inst.rhs, &inst.out),
666            Instruction::VecInit { inputs, out } => {
667                let item = out.item();
668                let inputs = inputs
669                    .iter()
670                    .map(|input| format!("{input}"))
671                    .collect::<Vec<_>>();
672                let out = out.fmt_left();
673                writeln!(f, "{out} = {item}{{{}}};", inputs.join(","))
674            }
675            Instruction::Printf {
676                format_string,
677                args,
678            } => D::compile_instruction_printf(f, format_string, args),
679            Instruction::Comment { content } => {
680                if content.contains('\n') {
681                    writeln!(f, "/* {content} */")
682                } else {
683                    writeln!(f, "// {content}")
684                }
685            }
686            Instruction::Pipeline(pipeline_ops) => write!(f, "{pipeline_ops}"),
687            Instruction::Barrier(barrier_ops) => write!(f, "{barrier_ops}"),
688            Instruction::Line { file, line } => writeln!(f, "#line {line} \"{file}\""),
689            Instruction::ProxyAsyncToSharedFence => {
690                writeln!(
691                    f,
692                    "cuda::device::experimental::fence_proxy_async_shared_cta();"
693                )
694            }
695            Instruction::BulkCommitGroup => writeln!(
696                f,
697                "cuda::device::experimental::cp_async_bulk_commit_group();"
698            ),
699            Instruction::BulkWaitGroup { max_pending } => writeln!(
700                f,
701                "cuda::device::experimental::cp_async_bulk_wait_group<{max_pending}>();"
702            ),
703            Instruction::BulkWaitGroupRead { max_pending } => writeln!(
704                f,
705                "cuda::device::experimental::cp_async_bulk_wait_group_read<{max_pending}>();"
706            ),
707            Instruction::TmaReplacePointer {
708                buffer,
709                offset,
710                tensor_map,
711                out,
712            } => {
713                let pos = Variable::<D>::UnitPos;
714                writeln!(f, "__shared__ alignas(128) CUtensorMap {out};")?;
715                writeln!(
716                    f,
717                    "
718if({pos} == 0) {{
719    {out} = {tensor_map};
720    tensormap_replace_global_address({out}, &{buffer}[{offset}]);
721}}"
722                )?;
723                writeln!(f, "__syncthreads();")
724            }
725            Instruction::MemCopyAsyncTensorSharedToGlobal {
726                smem_buffer,
727                smem_offset,
728                tensor_map,
729                indices,
730            } => {
731                let rank = indices.len();
732                let smem_ptr = smem_buffer.fmt_ptr();
733                let indices = indices.iter().rev().fold(String::new(), |mut s, it| {
734                    let _ = write!(s, "{it}, ");
735                    s
736                });
737                writeln!(
738                    f,
739                    "cuda::device::experimental::cp_async_bulk_tensor_{rank}d_shared_to_global(&{tensor_map}, {indices} {smem_ptr} + {smem_offset});"
740                )
741            }
742            Instruction::SpecialCast(UnaryInstruction { input, out }) => {
743                // Only supported in CUDA so I'm putting it here. Move to dialect if necessary.
744                #[cfg(not(feature = "cuda"))]
745                {
746                    let _ = (input, out);
747                    writeln!(
748                        f,
749                        "#error FP8/FP6/FP4 casting isn't supported outside of CUDA"
750                    )
751                }
752                #[cfg(feature = "cuda")]
753                crate::cuda::convert::special_cast::<D>(f, input, out)
754            }
755        }
756    }
757}
758
759struct Fma<D: Dialect> {
760    _dialect: PhantomData<D>,
761}
762
763impl<D: Dialect> Fma<D> {
764    fn format(
765        f: &mut core::fmt::Formatter<'_>,
766        a: &Variable<D>,
767        b: &Variable<D>,
768        c: &Variable<D>,
769        out: &Variable<D>,
770    ) -> core::fmt::Result {
771        let out_item = out.item();
772        let num = out_item.vectorization;
773
774        let out = out.fmt_left();
775        if num == 1 {
776            writeln!(f, "{out} = fma({a}, {b}, {c});")
777        } else {
778            writeln!(f, "{out} = {out_item}{{")?;
779
780            for i in 0..num {
781                let ai = a.index(i);
782                let bi = b.index(i);
783                let ci = c.index(i);
784
785                writeln!(f, "fma({ai}, {bi}, {ci}),")?;
786            }
787            f.write_str("};\n")
788        }
789    }
790}
791
792struct Clamp<D: Dialect> {
793    _dialect: PhantomData<D>,
794}
795
796impl<D: Dialect> Clamp<D> {
797    fn format(
798        f: &mut core::fmt::Formatter<'_>,
799        input: &Variable<D>,
800        min_value: &Variable<D>,
801        max_value: &Variable<D>,
802        out: &Variable<D>,
803    ) -> core::fmt::Result {
804        let out_item = out.item();
805        if out.item().vectorization == 1 {
806            let out = out.fmt_left();
807            write!(f, "{out} = ")?;
808            Self::format_scalar(f, *input, *min_value, *max_value, out_item)?;
809            f.write_str(";\n")
810        } else {
811            Self::unroll_vec(f, input, min_value, max_value, out)
812        }
813    }
814
815    fn format_scalar(
816        f: &mut Formatter<'_>,
817        input: impl Component<D>,
818        min_value: impl Component<D>,
819        max_value: impl Component<D>,
820        item: Item<D>,
821    ) -> std::fmt::Result {
822        D::compile_instruction_max_function_name(f, item)?;
823        write!(f, "({min_value}, ")?;
824        D::compile_instruction_min_function_name(f, item)?;
825        write!(f, "({max_value}, {input}))")
826    }
827
828    fn unroll_vec(
829        f: &mut core::fmt::Formatter<'_>,
830        input: &Variable<D>,
831        min_value: &Variable<D>,
832        max_value: &Variable<D>,
833        out: &Variable<D>,
834    ) -> std::fmt::Result {
835        let optimized = Variable::optimized_args([*input, *min_value, *max_value, *out]);
836        let [input, min_value, max_value, out_optimized] = optimized.args;
837
838        let item_out_original = out.item();
839        let item_out_optimized = out_optimized.item();
840
841        let index = match optimized.optimization_factor {
842            Some(factor) => item_out_original.vectorization / factor,
843            None => item_out_optimized.vectorization,
844        };
845
846        let mut write_op = |input: &Variable<D>,
847                            min_value: &Variable<D>,
848                            max_value: &Variable<D>,
849                            out: &Variable<D>,
850                            item_out: Item<D>| {
851            let out = out.fmt_left();
852            writeln!(f, "{out} = {item_out}{{")?;
853            for i in 0..index {
854                let inputi = input.index(i);
855                let min_valuei = min_value.index(i);
856                let max_valuei = max_value.index(i);
857
858                Self::format_scalar(f, inputi, min_valuei, max_valuei, item_out)?;
859                f.write_str(", ")?;
860            }
861
862            f.write_str("};\n")
863        };
864
865        if item_out_original == item_out_optimized {
866            write_op(&input, &min_value, &max_value, out, item_out_optimized)
867        } else {
868            let out_tmp = Variable::tmp(item_out_optimized);
869            write_op(&input, &min_value, &max_value, &out_tmp, item_out_optimized)?;
870            let addr_space = D::address_space_for_variable(out);
871            let out = out.fmt_left();
872
873            writeln!(
874                f,
875                "{out} = reinterpret_cast<{addr_space}{item_out_original}&>({out_tmp});\n"
876            )?;
877
878            Ok(())
879        }
880    }
881}
882
883struct Remainder<D: Dialect> {
884    _dialect: PhantomData<D>,
885}
886
887impl<D: Dialect> Remainder<D> {
888    fn format(
889        f: &mut core::fmt::Formatter<'_>,
890        lhs: &Variable<D>,
891        rhs: &Variable<D>,
892        out: &Variable<D>,
893    ) -> core::fmt::Result {
894        let floor = |elem| {
895            let prefix = match elem {
896                Elem::F16 | Elem::BF16 => D::compile_instruction_half_function_name_prefix(),
897                Elem::F16x2 | Elem::BF16x2 => D::compile_instruction_half2_function_name_prefix(),
898                _ => "",
899            };
900            format!("{prefix}floor")
901        };
902
903        let is_int = matches!(
904            out.elem(),
905            Elem::I8 | Elem::I16 | Elem::I32 | Elem::U8 | Elem::U16 | Elem::U32 | Elem::U64
906        );
907        let out_elem = out.elem();
908        let rem_expr = |lhs, rhs, floor: &str| {
909            if is_int {
910                format!("{lhs} - {rhs} * ({out_elem}){floor}((float){lhs} / (float){rhs})")
911            } else {
912                format!("{lhs} - {rhs} * {floor}({lhs} / {rhs})")
913            }
914        };
915
916        if out.item().vectorization == 1 {
917            let floor = floor(out.elem());
918
919            let out = out.fmt_left();
920            let rem = rem_expr(lhs.to_string(), rhs.to_string(), &floor);
921            return writeln!(f, "{out} = {rem};");
922        }
923
924        let optimized = Variable::optimized_args([*lhs, *rhs, *out]);
925        let [lhs, rhs, out_optimized] = optimized.args;
926
927        let item_out_original = out.item();
928        let item_out_optimized = out_optimized.item();
929
930        let index = match optimized.optimization_factor {
931            Some(factor) => item_out_original.vectorization / factor,
932            None => item_out_optimized.vectorization,
933        };
934
935        let floor = floor(*item_out_optimized.elem());
936
937        let mut write_op =
938            |lhs: &Variable<D>, rhs: &Variable<D>, out: &Variable<D>, item_out: Item<D>| {
939                let out = out.fmt_left();
940                writeln!(f, "{out} = {item_out}{{")?;
941                for i in 0..index {
942                    let lhsi = lhs.index(i);
943                    let rhsi = rhs.index(i);
944
945                    let rem = rem_expr(lhsi.to_string(), rhsi.to_string(), &floor);
946                    writeln!(f, "{rem}")?;
947                    f.write_str(", ")?;
948                }
949
950                f.write_str("};\n")
951            };
952
953        if item_out_original == item_out_optimized {
954            write_op(&lhs, &rhs, out, item_out_optimized)
955        } else {
956            let out_tmp = Variable::tmp(item_out_optimized);
957
958            write_op(&lhs, &rhs, &out_tmp, item_out_optimized)?;
959
960            let addr_space = D::address_space_for_variable(&out_tmp);
961            let qualifier = out.const_qualifier();
962            let out = out.fmt_left();
963
964            writeln!(
965                f,
966                "{out} = reinterpret_cast<{addr_space}{item_out_original}{qualifier}&>({out_tmp});\n"
967            )?;
968
969            Ok(())
970        }
971    }
972}
973
974struct Magnitude<D: Dialect, S: FunctionFmt<D>> {
975    _dialect: PhantomData<D>,
976    _sqrt: PhantomData<S>,
977}
978
979impl<D: Dialect, S: FunctionFmt<D>> Magnitude<D, S> {
980    fn format(
981        f: &mut core::fmt::Formatter<'_>,
982        input: &Variable<D>,
983        out: &Variable<D>,
984    ) -> core::fmt::Result {
985        let num = input.item().vectorization;
986        let elem = input.elem();
987
988        let mag = format!("{out}_mag");
989
990        writeln!(f, "{} {mag} = 0.0;", out.item())?;
991
992        for i in 0..num {
993            let input_i = input.index(i);
994            writeln!(f, "{mag} += {input_i} * {input_i};")?;
995        }
996
997        let out = out.fmt_left();
998        write!(f, "{out} = ")?;
999        S::format_unary(f, &mag, elem)?;
1000        f.write_str(";\n")
1001    }
1002}
1003
1004struct Normalize<D: Dialect, InvS: FunctionFmt<D>> {
1005    _dialect: PhantomData<D>,
1006    _rsqrt: PhantomData<InvS>,
1007}
1008
1009impl<D: Dialect, InvS: FunctionFmt<D>> Normalize<D, InvS> {
1010    fn format(
1011        f: &mut core::fmt::Formatter<'_>,
1012        input: &Variable<D>,
1013        out: &Variable<D>,
1014    ) -> core::fmt::Result {
1015        let num = input.item().vectorization;
1016        let elem = input.elem();
1017        let norm = format!("{out}_norm");
1018
1019        let out_item = out.item();
1020        let out = out.fmt_left();
1021        writeln!(f, "{elem} {norm} = 0.0;")?;
1022
1023        for i in 0..num {
1024            let input_i = input.index(i);
1025            writeln!(f, "{norm} += {input_i} * {input_i};")?;
1026        }
1027
1028        write!(f, "{norm} = ")?;
1029        InvS::format_unary(f, &norm, elem)?;
1030        f.write_str(";\n")?;
1031
1032        if num == 1 {
1033            writeln!(f, "{out} = {input} * {norm};")
1034        } else {
1035            write!(f, "{out} = {out_item}{{")?;
1036            for i in 0..num {
1037                let input_i = input.index(i);
1038
1039                writeln!(f, "{input_i} * {norm},")?;
1040            }
1041
1042            f.write_str("};\n")
1043        }
1044    }
1045}
1046
1047struct Dot<D: Dialect> {
1048    _dialect: PhantomData<D>,
1049}
1050
1051impl<D: Dialect> Dot<D> {
1052    fn format(
1053        f: &mut core::fmt::Formatter<'_>,
1054        lhs: &Variable<D>,
1055        rhs: &Variable<D>,
1056        out: &Variable<D>,
1057    ) -> core::fmt::Result {
1058        let num = lhs.item().vectorization;
1059
1060        let muls = (0..num)
1061            .map(|i| {
1062                let lhs_i = lhs.index(i);
1063                let rhs_i = rhs.index(i);
1064                format!("{lhs_i} * {rhs_i}")
1065            })
1066            .collect::<Vec<_>>();
1067
1068        let out = out.fmt_left();
1069        writeln!(f, "{out} = {};", muls.join(" + "))
1070    }
1071}
1072
1073struct EnsureBoolArg<'a, V: Display, D: Dialect> {
1074    var: &'a V,
1075    elem: &'a Elem<D>,
1076}
1077
1078impl<V: Display, D: Dialect> Display for EnsureBoolArg<'_, V, D> {
1079    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1080        if self.elem != &Elem::Bool {
1081            write!(f, "bool({})", self.var)
1082        } else {
1083            write!(f, "{}", self.var)
1084        }
1085    }
1086}