Skip to main content

cubecl_cpp/shared/
instruction.rs

1use crate::shared::FmtLeft;
2
3use super::{
4    Component, Dialect, Elem, Item, Variable, WarpInstruction, WmmaInstruction,
5    barrier::BarrierOps, binary::*, pipeline::PipelineOps, unary::*,
6};
7use std::{
8    borrow::Cow,
9    fmt::{Display, Formatter, Write},
10    marker::PhantomData,
11};
12
/// Identifier `info` as it appears in generated code. Its declaration is emitted elsewhere
/// (not visible in this part of the file).
pub(crate) const INFO_NAME: &str = "info";
/// Name of the array that [`Instruction::ExtendedMetadata`] indexes with
/// `static_meta[info_offset] + dim`.
pub(crate) const DYNAMIC_META_NAME: &str = "dynamic_meta";
/// Name of the array that [`Instruction::Metadata`] indexes with `info_offset`.
///
/// The literal is spelled as a member of [`INFO_NAME`]; keep the `info.` prefix in sync if
/// that name ever changes.
pub(crate) const STATIC_META_NAME: &str = "info.static_meta";
16
17#[derive(Debug, Clone, Copy)]
18pub struct BinaryInstruction<D: Dialect> {
19    pub lhs: Variable<D>,
20    pub rhs: Variable<D>,
21    pub out: Variable<D>,
22}
23
24#[derive(Debug, Clone)]
25pub struct IndexInstruction<D: Dialect> {
26    pub list: Variable<D>,
27    pub index: Variable<D>,
28    pub vector_size: u32,
29    pub out: Variable<D>,
30}
31
32#[derive(Debug, Clone)]
33pub struct IndexAssignInstruction<D: Dialect> {
34    pub index: Variable<D>,
35    pub value: Variable<D>,
36    pub vector_size: u32,
37    pub out: Variable<D>,
38}
39
40#[derive(Debug, Clone, Copy)]
41pub struct UnaryInstruction<D: Dialect> {
42    pub input: Variable<D>,
43    pub out: Variable<D>,
44}
45
46#[derive(Debug, Clone)]
47pub enum Instruction<D: Dialect> {
48    Metadata {
49        info_offset: Variable<D>,
50        out: Variable<D>,
51    },
52    ExtendedMetadata {
53        info_offset: Variable<D>,
54        dim: Variable<D>,
55        out: Variable<D>,
56    },
57    ConstLength {
58        length: usize,
59        out: Variable<D>,
60    },
61    SliceLength {
62        input: Variable<D>,
63        out: Variable<D>,
64    },
65    DeclareVariable {
66        var: Variable<D>,
67    },
68    Modulo(BinaryInstruction<D>),
69    Remainder(BinaryInstruction<D>),
70    Add(BinaryInstruction<D>),
71    SaturatingAdd(BinaryInstruction<D>),
72    Fma {
73        a: Variable<D>,
74        b: Variable<D>,
75        c: Variable<D>,
76        out: Variable<D>,
77    },
78    Div(BinaryInstruction<D>),
79    FastDiv(BinaryInstruction<D>),
80    FastRecip(UnaryInstruction<D>),
81    Mul(BinaryInstruction<D>),
82    Sub(BinaryInstruction<D>),
83    SaturatingSub(BinaryInstruction<D>),
84    HiMul(BinaryInstruction<D>),
85    Index(IndexInstruction<D>),
86    IndexAssign(IndexAssignInstruction<D>),
87    Assign(UnaryInstruction<D>),
88    SpecialCast(UnaryInstruction<D>),
89    RangeLoop {
90        i: Variable<D>,
91        start: Variable<D>,
92        end: Variable<D>,
93        step: Option<Variable<D>>,
94        inclusive: bool,
95        instructions: Vec<Self>,
96    },
97    VecInit {
98        inputs: Vec<Variable<D>>,
99        out: Variable<D>,
100    },
101    Loop {
102        instructions: Vec<Self>,
103    },
104    If {
105        cond: Variable<D>,
106        instructions: Vec<Self>,
107    },
108    IfElse {
109        cond: Variable<D>,
110        instructions_if: Vec<Self>,
111        instructions_else: Vec<Self>,
112    },
113    Select {
114        cond: Variable<D>,
115        then: Variable<D>,
116        or_else: Variable<D>,
117        out: Variable<D>,
118    },
119    Switch {
120        value: Variable<D>,
121        instructions_default: Vec<Self>,
122        instructions_cases: Vec<(Variable<D>, Vec<Self>)>,
123    },
124    Slice {
125        input: Variable<D>,
126        start: Variable<D>,
127        end: Variable<D>,
128        out: Variable<D>,
129    },
130    CheckedSlice {
131        input: Variable<D>,
132        start: Variable<D>,
133        end: Variable<D>,
134        out: Variable<D>,
135        len: Variable<D>,
136    },
137    ReinterpretSlice {
138        input: Variable<D>,
139        vector_size: u32,
140        out: Variable<D>,
141    },
142    Return,
143    Break,
144    Unreachable,
145    Equal(BinaryInstruction<D>),
146    NotEqual(BinaryInstruction<D>),
147    Lower(BinaryInstruction<D>),
148    Greater(BinaryInstruction<D>),
149    LowerEqual(BinaryInstruction<D>),
150    GreaterEqual(BinaryInstruction<D>),
151    Erf(UnaryInstruction<D>),
152    BitwiseOr(BinaryInstruction<D>),
153    BitwiseAnd(BinaryInstruction<D>),
154    BitwiseXor(BinaryInstruction<D>),
155    CountBits(UnaryInstruction<D>),
156    ReverseBits(UnaryInstruction<D>),
157    ShiftLeft(BinaryInstruction<D>),
158    ShiftRight(BinaryInstruction<D>),
159    BitwiseNot(UnaryInstruction<D>),
160    LeadingZeros(UnaryInstruction<D>),
161    TrailingZeros(UnaryInstruction<D>),
162    FindFirstSet(UnaryInstruction<D>),
163    Abs(UnaryInstruction<D>),
164    Exp(UnaryInstruction<D>),
165    FastExp(UnaryInstruction<D>),
166    Log(UnaryInstruction<D>),
167    FastLog(UnaryInstruction<D>),
168    Log1p(UnaryInstruction<D>),
169    Cos(UnaryInstruction<D>),
170    Sin(UnaryInstruction<D>),
171    Tan(UnaryInstruction<D>),
172    Tanh(UnaryInstruction<D>),
173    Sinh(UnaryInstruction<D>),
174    Cosh(UnaryInstruction<D>),
175    ArcCos(UnaryInstruction<D>),
176    ArcSin(UnaryInstruction<D>),
177    ArcTan(UnaryInstruction<D>),
178    ArcSinh(UnaryInstruction<D>),
179    ArcCosh(UnaryInstruction<D>),
180    ArcTanh(UnaryInstruction<D>),
181    Degrees(UnaryInstruction<D>),
182    Radians(UnaryInstruction<D>),
183    ArcTan2(BinaryInstruction<D>),
184    FastSin(UnaryInstruction<D>),
185    FastCos(UnaryInstruction<D>),
186    FastTanh(UnaryInstruction<D>),
187    Powf(BinaryInstruction<D>),
188    FastPowf(BinaryInstruction<D>),
189    Powi(BinaryInstruction<D>),
190    Hypot(BinaryInstruction<D>),
191    Rhypot(BinaryInstruction<D>),
192    Sqrt(UnaryInstruction<D>),
193    FastSqrt(UnaryInstruction<D>),
194    InverseSqrt(UnaryInstruction<D>),
195    FastInverseSqrt(UnaryInstruction<D>),
196    Min(BinaryInstruction<D>),
197    Max(BinaryInstruction<D>),
198    Not(UnaryInstruction<D>),
199    Or(BinaryInstruction<D>),
200    And(BinaryInstruction<D>),
201    Clamp {
202        input: Variable<D>,
203        min_value: Variable<D>,
204        max_value: Variable<D>,
205        out: Variable<D>,
206    },
207    IsNan(UnaryInstruction<D>),
208    IsInf(UnaryInstruction<D>),
209    SyncThreads,
210    SyncWarp,
211    ThreadFence,
212    ProxyAsyncToSharedFence,
213    BulkCommitGroup,
214    BulkWaitGroup {
215        max_pending: u32,
216    },
217    BulkWaitGroupRead {
218        max_pending: u32,
219    },
220    TmaReplacePointer {
221        buffer: Variable<D>,
222        offset: Variable<D>,
223        tensor_map: Variable<D>,
224        out: Variable<D>,
225    },
226    Round(UnaryInstruction<D>),
227    Ceil(UnaryInstruction<D>),
228    Trunc(UnaryInstruction<D>),
229    Floor(UnaryInstruction<D>),
230    Warp(WarpInstruction<D>),
231    Wmma(WmmaInstruction<D>),
232    Bitcast(UnaryInstruction<D>),
233    AtomicLoad(UnaryInstruction<D>),
234    AtomicStore(UnaryInstruction<D>),
235    AtomicSwap(BinaryInstruction<D>),
236    AtomicAdd(BinaryInstruction<D>),
237    AtomicSub(BinaryInstruction<D>),
238    AtomicMax(BinaryInstruction<D>),
239    AtomicMin(BinaryInstruction<D>),
240    AtomicAnd(BinaryInstruction<D>),
241    AtomicOr(BinaryInstruction<D>),
242    AtomicXor(BinaryInstruction<D>),
243    AtomicCAS {
244        input: Variable<D>,
245        cmp: Variable<D>,
246        val: Variable<D>,
247        out: Variable<D>,
248    },
249    Neg(UnaryInstruction<D>),
250    Magnitude(UnaryInstruction<D>),
251    FastMagnitude(UnaryInstruction<D>),
252    Normalize(UnaryInstruction<D>),
253    FastNormalize(UnaryInstruction<D>),
254    Dot(BinaryInstruction<D>),
255    VectorSum(UnaryInstruction<D>),
256    Copy {
257        input: Variable<D>,
258        in_index: Variable<D>,
259        out: Variable<D>,
260        out_index: Variable<D>,
261    },
262    CopyBulk {
263        input: Variable<D>,
264        in_index: Variable<D>,
265        out: Variable<D>,
266        out_index: Variable<D>,
267        len: u32,
268    },
269    Printf {
270        format_string: String,
271        args: Vec<Variable<D>>,
272    },
273    Comment {
274        content: String,
275    },
276    Pipeline(PipelineOps<D>),
277    Barrier(BarrierOps<D>),
278    MemCopyAsyncTensorSharedToGlobal {
279        smem_buffer: Variable<D>,
280        smem_offset: Variable<D>,
281        tensor_map: Variable<D>,
282        indices: Vec<Variable<D>>,
283    },
284    Line {
285        file: Cow<'static, str>,
286        line: u32,
287    },
288}
289
290impl<D: Dialect> Display for Instruction<D> {
291    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
292        match self {
293            Instruction::Return => f.write_str("return;"),
294            Instruction::Break => f.write_str("break;"),
295            Instruction::Unreachable => D::compile_unreachable(f),
296            Instruction::DeclareVariable { var } => match var {
297                Variable::WmmaFragment { .. } => D::compile_wmma_fragment_declaration(f, var),
298                _ => {
299                    let item = var.item();
300                    writeln!(f, "{item} {var};")
301                }
302            },
303            Instruction::Add(it) => Add::format(f, &it.lhs, &it.rhs, &it.out),
304            Instruction::SaturatingAdd(it) => SaturatingAdd::format(f, &it.lhs, &it.rhs, &it.out),
305            Instruction::Slice {
306                input,
307                start,
308                end,
309                out,
310            } => {
311                let item = out.item();
312                let addr_space = D::address_space_for_variable(input);
313                writeln!(f, "const uint {out}_length = {end} - {start};")?;
314                writeln!(f, "{addr_space}{item} *{out} = {input} + {start};")
315            }
316            Instruction::CheckedSlice {
317                input,
318                start,
319                end,
320                out,
321                len,
322            } => {
323                let item = out.item();
324                let addr_space = D::address_space_for_variable(input);
325                writeln!(f, "const uint {out}_length = min({len}, {end}) - {start};")?;
326                writeln!(f, "{addr_space}{item} *{out} = {input} + {start};")
327            }
328            Instruction::ReinterpretSlice {
329                input,
330                vector_size,
331                out,
332            } => {
333                let mut item = out.item();
334                item.vectorization = *vector_size as usize;
335                let addr_space = D::address_space_for_variable(input);
336
337                writeln!(
338                    f,
339                    "{addr_space}{item} *{out} = reinterpret_cast<{item}*>({input});"
340                )
341            }
342            Instruction::Mul(it) => Mul::format(f, &it.lhs, &it.rhs, &it.out),
343            Instruction::Div(it) => Div::format(f, &it.lhs, &it.rhs, &it.out),
344            Instruction::FastDiv(it) => FastDiv::format(f, &it.lhs, &it.rhs, &it.out),
345            Instruction::FastRecip(it) => FastRecip::format(f, &it.input, &it.out),
346            Instruction::Sub(it) => Sub::format(f, &it.lhs, &it.rhs, &it.out),
347            Instruction::SaturatingSub(it) => SaturatingSub::format(f, &it.lhs, &it.rhs, &it.out),
348            Instruction::HiMul(it) => HiMul::format(f, &it.lhs, &it.rhs, &it.out),
349            Instruction::Modulo(inst) => Modulo::format(f, &inst.lhs, &inst.rhs, &inst.out),
350            Instruction::BitwiseOr(it) => BitwiseOr::format(f, &it.lhs, &it.rhs, &it.out),
351            Instruction::BitwiseAnd(it) => BitwiseAnd::format(f, &it.lhs, &it.rhs, &it.out),
352            Instruction::BitwiseXor(it) => BitwiseXor::format(f, &it.lhs, &it.rhs, &it.out),
353            Instruction::CountBits(it) => CountBits::format(f, &it.input, &it.out),
354            Instruction::ReverseBits(it) => ReverseBits::format(f, &it.input, &it.out),
355            Instruction::LeadingZeros(it) => LeadingZeros::format(f, &it.input, &it.out),
356            Instruction::TrailingZeros(it) => TrailingZeros::format(f, &it.input, &it.out),
357            Instruction::FindFirstSet(it) => FindFirstSet::format(f, &it.input, &it.out),
358            Instruction::ShiftLeft(it) => ShiftLeft::format(f, &it.lhs, &it.rhs, &it.out),
359            Instruction::ShiftRight(it) => ShiftRight::format(f, &it.lhs, &it.rhs, &it.out),
360            Instruction::Index(it) => {
361                Index::format(f, &it.list, &it.index, &it.out, it.vector_size)
362            }
363            Instruction::IndexAssign(it) => {
364                IndexAssign::format(f, &it.index, &it.value, &it.out, it.vector_size)
365            }
366            Instruction::Copy {
367                input,
368                in_index,
369                out,
370                out_index,
371            } => {
372                writeln!(f, "{out}[{out_index}] = {input}[{in_index}];")
373            }
374            Instruction::CopyBulk {
375                input,
376                in_index,
377                out,
378                out_index,
379                len,
380            } => {
381                for i in 0..*len {
382                    writeln!(f, "{out}[{out_index} + {i}] = {input}[{in_index} + {i}];")?;
383                }
384                Ok(())
385            }
386            Instruction::Assign(it) => Assign::format(f, &it.input, &it.out),
387            Instruction::RangeLoop {
388                i,
389                start,
390                end,
391                step,
392                inclusive,
393                instructions,
394            } => {
395                let increment = step
396                    .map(|step| format!("{i} += {step}"))
397                    .unwrap_or_else(|| format!("++{i}"));
398                let cmp = if *inclusive { "<=" } else { "<" };
399                let i_ty = i.item();
400
401                write!(
402                    f,
403                    "
404for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{
405"
406                )?;
407                for instruction in instructions {
408                    write!(f, "{instruction}")?;
409                }
410
411                f.write_str("}\n")
412            }
413            Instruction::Loop { instructions } => {
414                writeln!(f, "while (true) {{")?;
415                for i in instructions {
416                    write!(f, "{i}")?;
417                }
418                f.write_str("}\n")
419            }
420            Instruction::If { cond, instructions } => {
421                writeln!(f, "if ({cond}) {{")?;
422                for i in instructions {
423                    write!(f, "{i}")?;
424                }
425                f.write_str("}\n")
426            }
427            Instruction::IfElse {
428                cond,
429                instructions_if,
430                instructions_else,
431            } => {
432                writeln!(f, "if ({cond}) {{")?;
433                for i in instructions_if {
434                    write!(f, "{i}")?;
435                }
436                f.write_str("} else {\n")?;
437                for i in instructions_else {
438                    write!(f, "{i}")?;
439                }
440                f.write_str("}\n")
441            }
442            Instruction::Select {
443                cond,
444                then,
445                or_else,
446                out,
447            } => {
448                let item_or_else = or_else.item();
449                let item_then = then.item();
450                let item_out = out.item();
451
452                let vf_then = item_then.vectorization;
453                let vf_or_else = item_or_else.vectorization;
454                let vf_out = item_out.vectorization;
455                let vf_cond = cond.item().vectorization;
456
457                let item_out = out.item();
458                let cond_elem = cond.item().elem;
459                let out = out.fmt_left();
460
461                // It seems to always be faster to broadcast the select, because the compiler is
462                // able to output branchless instructions when the ternary is done on native types
463                // rather than cubecl defined types.
464
465                let vf = usize::max(vf_cond, vf_out);
466                let vf = usize::max(vf, vf_then);
467                let vf = usize::max(vf, vf_or_else);
468                let should_broadcast = vf > 1;
469
470                // Keep the condition here for future testing.
471                //
472                // let should_broadcast =
473                //     vf_cond > 1 || item_out != item_or_else || item_out != item_then;
474
475                if should_broadcast {
476                    writeln!(f, "{out} = {item_out} {{")?;
477                    for i in 0..vf {
478                        let theni = then.index(i);
479                        let or_elsei = or_else.index(i);
480                        let condi = cond.index(i);
481                        let condi = EnsureBoolArg {
482                            var: &condi,
483                            elem: &cond_elem,
484                        };
485
486                        writeln!(f, "({condi}) ? {theni} : {or_elsei},")?;
487                    }
488
489                    writeln!(f, "}};")
490                } else {
491                    let cond = EnsureBoolArg {
492                        var: &cond,
493                        elem: &cond_elem,
494                    };
495                    writeln!(f, "{out} = ({cond}) ? {then} : {or_else};")
496                }
497            }
498            Instruction::Switch {
499                value,
500                instructions_default,
501                instructions_cases,
502            } => {
503                writeln!(f, "switch({value}) {{")?;
504                for (value, block) in instructions_cases {
505                    write!(f, "case {value}:\n{{\n")?;
506                    for i in block {
507                        i.fmt(f)?;
508                    }
509                    f.write_str("break;\n}\n")?;
510                }
511                f.write_str("default:\n{")?;
512                for i in instructions_default {
513                    i.fmt(f)?;
514                }
515                f.write_str("break;\n}\n}\n")
516            }
517            Instruction::Metadata { info_offset, out } => {
518                let out = out.fmt_left();
519                writeln!(f, "{out} = {STATIC_META_NAME}[{info_offset}];")
520            }
521            Instruction::ExtendedMetadata {
522                info_offset,
523                dim,
524                out,
525            } => {
526                let out = out.fmt_left();
527                writeln!(
528                    f,
529                    "{out} = {DYNAMIC_META_NAME}[{STATIC_META_NAME}[{info_offset}] + {dim}];"
530                )
531            }
532            Instruction::Equal(it) => Equal::format(f, &it.lhs, &it.rhs, &it.out),
533            Instruction::NotEqual(it) => NotEqual::format(f, &it.lhs, &it.rhs, &it.out),
534            Instruction::Lower(it) => Lower::format(f, &it.lhs, &it.rhs, &it.out),
535            Instruction::Greater(it) => Greater::format(f, &it.lhs, &it.rhs, &it.out),
536            Instruction::LowerEqual(it) => LowerEqual::format(f, &it.lhs, &it.rhs, &it.out),
537            Instruction::GreaterEqual(it) => GreaterEqual::format(f, &it.lhs, &it.rhs, &it.out),
538            Instruction::Erf(it) => Erf::format(f, &it.input, &it.out),
539            Instruction::Abs(it) => Abs::format(f, &it.input, &it.out),
540            Instruction::Exp(it) => Exp::format(f, &it.input, &it.out),
541            Instruction::FastExp(it) => FastExp::format(f, &it.input, &it.out),
542            Instruction::Log(it) => Log::format(f, &it.input, &it.out),
543            Instruction::FastLog(it) => FastLog::format(f, &it.input, &it.out),
544            Instruction::Log1p(it) => Log1p::format(f, &it.input, &it.out),
545            Instruction::Cos(it) => Cos::format(f, &it.input, &it.out),
546            Instruction::FastCos(it) => FastCos::format(f, &it.input, &it.out),
547            Instruction::Sin(it) => Sin::format(f, &it.input, &it.out),
548            Instruction::Tan(it) => Tan::format(f, &it.input, &it.out),
549            Instruction::Tanh(it) => Tanh::format(f, &it.input, &it.out),
550            Instruction::Sinh(it) => Sinh::format(f, &it.input, &it.out),
551            Instruction::Cosh(it) => Cosh::format(f, &it.input, &it.out),
552            Instruction::ArcCos(it) => ArcCos::format(f, &it.input, &it.out),
553            Instruction::ArcSin(it) => ArcSin::format(f, &it.input, &it.out),
554            Instruction::ArcTan(it) => ArcTan::format(f, &it.input, &it.out),
555            Instruction::ArcSinh(it) => ArcSinh::format(f, &it.input, &it.out),
556            Instruction::ArcCosh(it) => ArcCosh::format(f, &it.input, &it.out),
557            Instruction::ArcTanh(it) => ArcTanh::format(f, &it.input, &it.out),
558            Instruction::Degrees(it) => Degrees::format(f, &it.input, &it.out),
559            Instruction::Radians(it) => Radians::format(f, &it.input, &it.out),
560            Instruction::ArcTan2(it) => ArcTan2::format(f, &it.lhs, &it.rhs, &it.out),
561            Instruction::FastSin(it) => FastSin::format(f, &it.input, &it.out),
562            Instruction::FastTanh(it) => FastTanh::format(f, &it.input, &it.out),
563            Instruction::Powf(it) => Powf::format(f, &it.lhs, &it.rhs, &it.out),
564            Instruction::FastPowf(it) => FastPowf::format(f, &it.lhs, &it.rhs, &it.out),
565            Instruction::Powi(it) => Powi::format(f, &it.lhs, &it.rhs, &it.out),
566            Instruction::Hypot(it) => Hypot::format(f, &it.lhs, &it.rhs, &it.out),
567            Instruction::Rhypot(it) => Rhypot::format(f, &it.lhs, &it.rhs, &it.out),
568            Instruction::Sqrt(it) => Sqrt::format(f, &it.input, &it.out),
569            Instruction::FastSqrt(it) => FastSqrt::format(f, &it.input, &it.out),
570            Instruction::InverseSqrt(it) => InverseSqrt::format(f, &it.input, &it.out),
571            Instruction::FastInverseSqrt(it) => FastInverseSqrt::format(f, &it.input, &it.out),
572            Instruction::Max(it) => Max::format(f, &it.lhs, &it.rhs, &it.out),
573            Instruction::Min(it) => Min::format(f, &it.lhs, &it.rhs, &it.out),
574            Instruction::Not(it) => Not::format(f, &it.input, &it.out),
575            Instruction::BitwiseNot(it) => BitwiseNot::format(f, &it.input, &it.out),
576            Instruction::Or(it) => Or::format(f, &it.lhs, &it.rhs, &it.out),
577            Instruction::And(it) => And::format(f, &it.lhs, &it.rhs, &it.out),
578            Instruction::Clamp {
579                input,
580                min_value,
581                max_value,
582                out,
583            } => Clamp::format(f, input, min_value, max_value, out),
584            Instruction::IsNan(it) => IsNan::format(f, &it.input, &it.out),
585            Instruction::IsInf(it) => IsInf::format(f, &it.input, &it.out),
586            Instruction::SyncThreads => D::compile_instruction_sync_threads(f),
587            Instruction::SyncWarp => D::compile_instruction_sync_warp(f),
588            Instruction::ThreadFence => f.write_str("__threadfence();\n"),
589            Instruction::Round(it) => Round::format(f, &it.input, &it.out),
590            Instruction::Ceil(it) => Ceil::format(f, &it.input, &it.out),
591            Instruction::Trunc(it) => Trunc::format(f, &it.input, &it.out),
592            Instruction::Floor(it) => Floor::format(f, &it.input, &it.out),
593            Instruction::SliceLength { input, out } => {
594                let out = out.fmt_left();
595                writeln!(f, "{out} = {input}_length;")
596            }
597            Instruction::ConstLength { length, out } => {
598                let out = out.fmt_left();
599                writeln!(f, "{out} = {length};")
600            }
601            Instruction::Warp(it) => write!(f, "{it}"),
602            Instruction::Fma { a, b, c, out } => Fma::format(f, a, b, c, out),
603            Instruction::Wmma(it) => write!(f, "{it}"),
604            Instruction::Bitcast(UnaryInstruction { input, out }) => {
605                let qualifier = out.const_qualifier();
606                let input_item = input.item();
607                let out_item = out.item();
608
609                if out_item.elem.size() * out_item.vectorization
610                    != input.item().elem.size() * input.item().vectorization
611                {
612                    panic!("Unsupported type for bitcasting {out_item:?} from {input_item:?}");
613                } else {
614                    let out = out.fmt_left();
615                    let addr_space = D::address_space_for_variable(input);
616                    writeln!(
617                        f,
618                        "{out} = reinterpret_cast<{addr_space}{out_item}{qualifier}&>({input});"
619                    )
620                }
621            }
622            Instruction::AtomicAdd(BinaryInstruction { lhs, rhs, out }) => {
623                D::compile_atomic_add(f, lhs, rhs, out)
624            }
625            Instruction::AtomicAnd(BinaryInstruction { lhs, rhs, out }) => {
626                D::compile_atomic_and(f, lhs, rhs, out)
627            }
628            Instruction::AtomicCAS {
629                input,
630                cmp,
631                val,
632                out,
633            } => D::compile_atomic_cas(f, input, cmp, val, out),
634            Instruction::AtomicLoad(UnaryInstruction { input, out }) => {
635                D::compile_atomic_load(f, input, out)
636            }
637            Instruction::AtomicMax(BinaryInstruction { lhs, rhs, out }) => {
638                D::compile_atomic_max(f, lhs, rhs, out)
639            }
640            Instruction::AtomicMin(BinaryInstruction { lhs, rhs, out }) => {
641                D::compile_atomic_min(f, lhs, rhs, out)
642            }
643            Instruction::AtomicOr(BinaryInstruction { lhs, rhs, out }) => {
644                D::compile_atomic_or(f, lhs, rhs, out)
645            }
646            Instruction::AtomicStore(UnaryInstruction { input, out }) => {
647                D::compile_atomic_store(f, input, out)
648            }
649            Instruction::AtomicSub(BinaryInstruction { lhs, rhs, out }) => {
650                D::compile_atomic_sub(f, lhs, rhs, out)
651            }
652            Instruction::AtomicSwap(BinaryInstruction { lhs, rhs, out }) => {
653                D::compile_atomic_swap(f, lhs, rhs, out)
654            }
655            Instruction::AtomicXor(BinaryInstruction { lhs, rhs, out }) => {
656                D::compile_atomic_xor(f, lhs, rhs, out)
657            }
658            Instruction::Remainder(inst) => Remainder::format(f, &inst.lhs, &inst.rhs, &inst.out),
659            Instruction::Neg(UnaryInstruction { input, out }) => {
660                let out = out.fmt_left();
661                writeln!(f, "{out} = -{input};")
662            }
663            Instruction::Normalize(inst) => {
664                Normalize::<D, InverseSqrt>::format(f, &inst.input, &inst.out)
665            }
666            Instruction::FastNormalize(inst) => {
667                Normalize::<D, FastInverseSqrt>::format(f, &inst.input, &inst.out)
668            }
669            Instruction::Magnitude(inst) => Magnitude::<D, Sqrt>::format(f, &inst.input, &inst.out),
670            Instruction::FastMagnitude(inst) => {
671                Magnitude::<D, FastSqrt>::format(f, &inst.input, &inst.out)
672            }
673            Instruction::Dot(inst) => Dot::format(f, &inst.lhs, &inst.rhs, &inst.out),
674            Instruction::VectorSum(inst) => VectorSumFmt::<D>::format(f, &inst.input, &inst.out),
675            Instruction::VecInit { inputs, out } => {
676                let item = out.item();
677                let inputs = inputs
678                    .iter()
679                    .map(|input| format!("{input}"))
680                    .collect::<Vec<_>>();
681                let out = out.fmt_left();
682                writeln!(f, "{out} = {item}{{{}}};", inputs.join(","))
683            }
684            Instruction::Printf {
685                format_string,
686                args,
687            } => D::compile_instruction_printf(f, format_string, args),
688            Instruction::Comment { content } => {
689                if content.contains('\n') {
690                    writeln!(f, "/* {content} */")
691                } else {
692                    writeln!(f, "// {content}")
693                }
694            }
695            Instruction::Pipeline(pipeline_ops) => write!(f, "{pipeline_ops}"),
696            Instruction::Barrier(barrier_ops) => write!(f, "{barrier_ops}"),
697            Instruction::Line { file, line } => writeln!(f, "#line {line} \"{file}\""),
698            Instruction::ProxyAsyncToSharedFence => {
699                writeln!(
700                    f,
701                    "cuda::device::experimental::fence_proxy_async_shared_cta();"
702                )
703            }
704            Instruction::BulkCommitGroup => writeln!(
705                f,
706                "cuda::device::experimental::cp_async_bulk_commit_group();"
707            ),
708            Instruction::BulkWaitGroup { max_pending } => writeln!(
709                f,
710                "cuda::device::experimental::cp_async_bulk_wait_group<{max_pending}>();"
711            ),
712            Instruction::BulkWaitGroupRead { max_pending } => writeln!(
713                f,
714                "cuda::device::experimental::cp_async_bulk_wait_group_read<{max_pending}>();"
715            ),
716            Instruction::TmaReplacePointer {
717                buffer,
718                offset,
719                tensor_map,
720                out,
721            } => {
722                let pos = Variable::<D>::UnitPos;
723                writeln!(f, "__shared__ alignas(128) CUtensorMap {out};")?;
724                writeln!(
725                    f,
726                    "
727if({pos} == 0) {{
728    {out} = {tensor_map};
729    tensormap_replace_global_address({out}, &{buffer}[{offset}]);
730}}"
731                )?;
732                writeln!(f, "__syncthreads();")
733            }
734            Instruction::MemCopyAsyncTensorSharedToGlobal {
735                smem_buffer,
736                smem_offset,
737                tensor_map,
738                indices,
739            } => {
740                let rank = indices.len();
741                let smem_ptr = smem_buffer.fmt_ptr();
742                let indices = indices.iter().rev().fold(String::new(), |mut s, it| {
743                    let _ = write!(s, "{it}, ");
744                    s
745                });
746                writeln!(
747                    f,
748                    "cuda::device::experimental::cp_async_bulk_tensor_{rank}d_shared_to_global(&{tensor_map}, {indices} {smem_ptr} + {smem_offset});"
749                )
750            }
751            Instruction::SpecialCast(UnaryInstruction { input, out }) => {
752                // Only supported in CUDA so I'm putting it here. Move to dialect if necessary.
753                #[cfg(not(feature = "cuda"))]
754                {
755                    let _ = (input, out);
756                    writeln!(
757                        f,
758                        "#error FP8/FP6/FP4 casting isn't supported outside of CUDA"
759                    )
760                }
761                #[cfg(feature = "cuda")]
762                crate::cuda::convert::special_cast::<D>(f, input, out)
763            }
764        }
765    }
766}
767
768struct Fma<D: Dialect> {
769    _dialect: PhantomData<D>,
770}
771
772impl<D: Dialect> Fma<D> {
773    fn format(
774        f: &mut core::fmt::Formatter<'_>,
775        a: &Variable<D>,
776        b: &Variable<D>,
777        c: &Variable<D>,
778        out: &Variable<D>,
779    ) -> core::fmt::Result {
780        let out_item = out.item();
781        let num = out_item.vectorization;
782
783        let out = out.fmt_left();
784        if num == 1 {
785            writeln!(f, "{out} = fma({a}, {b}, {c});")
786        } else {
787            writeln!(f, "{out} = {out_item}{{")?;
788
789            for i in 0..num {
790                let ai = a.index(i);
791                let bi = b.index(i);
792                let ci = c.index(i);
793
794                writeln!(f, "fma({ai}, {bi}, {ci}),")?;
795            }
796            f.write_str("};\n")
797        }
798    }
799}
800
801struct Clamp<D: Dialect> {
802    _dialect: PhantomData<D>,
803}
804
805impl<D: Dialect> Clamp<D> {
806    fn format(
807        f: &mut core::fmt::Formatter<'_>,
808        input: &Variable<D>,
809        min_value: &Variable<D>,
810        max_value: &Variable<D>,
811        out: &Variable<D>,
812    ) -> core::fmt::Result {
813        let out_item = out.item();
814        if out.item().vectorization == 1 {
815            let out = out.fmt_left();
816            write!(f, "{out} = ")?;
817            Self::format_scalar(f, *input, *min_value, *max_value, out_item)?;
818            f.write_str(";\n")
819        } else {
820            Self::unroll_vec(f, input, min_value, max_value, out)
821        }
822    }
823
824    fn format_scalar(
825        f: &mut Formatter<'_>,
826        input: impl Component<D>,
827        min_value: impl Component<D>,
828        max_value: impl Component<D>,
829        item: Item<D>,
830    ) -> std::fmt::Result {
831        D::compile_instruction_max_function_name(f, item)?;
832        write!(f, "({min_value}, ")?;
833        D::compile_instruction_min_function_name(f, item)?;
834        write!(f, "({max_value}, {input}))")
835    }
836
837    fn unroll_vec(
838        f: &mut core::fmt::Formatter<'_>,
839        input: &Variable<D>,
840        min_value: &Variable<D>,
841        max_value: &Variable<D>,
842        out: &Variable<D>,
843    ) -> std::fmt::Result {
844        let optimized = Variable::optimized_args([*input, *min_value, *max_value, *out]);
845        let [input, min_value, max_value, out_optimized] = optimized.args;
846
847        let item_out_original = out.item();
848        let item_out_optimized = out_optimized.item();
849
850        let index = match optimized.optimization_factor {
851            Some(factor) => item_out_original.vectorization / factor,
852            None => item_out_optimized.vectorization,
853        };
854
855        let mut write_op = |input: &Variable<D>,
856                            min_value: &Variable<D>,
857                            max_value: &Variable<D>,
858                            out: &Variable<D>,
859                            item_out: Item<D>| {
860            let out = out.fmt_left();
861            writeln!(f, "{out} = {item_out}{{")?;
862            for i in 0..index {
863                let inputi = input.index(i);
864                let min_valuei = min_value.index(i);
865                let max_valuei = max_value.index(i);
866
867                Self::format_scalar(f, inputi, min_valuei, max_valuei, item_out)?;
868                f.write_str(", ")?;
869            }
870
871            f.write_str("};\n")
872        };
873
874        if item_out_original == item_out_optimized {
875            write_op(&input, &min_value, &max_value, out, item_out_optimized)
876        } else {
877            let out_tmp = Variable::tmp(item_out_optimized);
878            write_op(&input, &min_value, &max_value, &out_tmp, item_out_optimized)?;
879            let addr_space = D::address_space_for_variable(out);
880            let out = out.fmt_left();
881
882            writeln!(
883                f,
884                "{out} = reinterpret_cast<{addr_space}{item_out_original}&>({out_tmp});\n"
885            )?;
886
887            Ok(())
888        }
889    }
890}
891
892struct Remainder<D: Dialect> {
893    _dialect: PhantomData<D>,
894}
895
896impl<D: Dialect> Remainder<D> {
897    fn format(
898        f: &mut core::fmt::Formatter<'_>,
899        lhs: &Variable<D>,
900        rhs: &Variable<D>,
901        out: &Variable<D>,
902    ) -> core::fmt::Result {
903        let floor = |elem| {
904            let prefix = match elem {
905                Elem::F16 | Elem::BF16 => D::compile_instruction_half_function_name_prefix(),
906                Elem::F16x2 | Elem::BF16x2 => D::compile_instruction_half2_function_name_prefix(),
907                _ => "",
908            };
909            format!("{prefix}floor")
910        };
911
912        let is_int = matches!(
913            out.elem(),
914            Elem::I8 | Elem::I16 | Elem::I32 | Elem::U8 | Elem::U16 | Elem::U32 | Elem::U64
915        );
916        let out_elem = out.elem();
917        let rem_expr = |lhs, rhs, floor: &str| {
918            if is_int {
919                format!("{lhs} - {rhs} * ({out_elem}){floor}((float){lhs} / (float){rhs})")
920            } else {
921                format!("{lhs} - {rhs} * {floor}({lhs} / {rhs})")
922            }
923        };
924
925        if out.item().vectorization == 1 {
926            let floor = floor(out.elem());
927
928            let out = out.fmt_left();
929            let rem = rem_expr(lhs.to_string(), rhs.to_string(), &floor);
930            return writeln!(f, "{out} = {rem};");
931        }
932
933        let optimized = Variable::optimized_args([*lhs, *rhs, *out]);
934        let [lhs, rhs, out_optimized] = optimized.args;
935
936        let item_out_original = out.item();
937        let item_out_optimized = out_optimized.item();
938
939        let index = match optimized.optimization_factor {
940            Some(factor) => item_out_original.vectorization / factor,
941            None => item_out_optimized.vectorization,
942        };
943
944        let floor = floor(*item_out_optimized.elem());
945
946        let mut write_op =
947            |lhs: &Variable<D>, rhs: &Variable<D>, out: &Variable<D>, item_out: Item<D>| {
948                let out = out.fmt_left();
949                writeln!(f, "{out} = {item_out}{{")?;
950                for i in 0..index {
951                    let lhsi = lhs.index(i);
952                    let rhsi = rhs.index(i);
953
954                    let rem = rem_expr(lhsi.to_string(), rhsi.to_string(), &floor);
955                    writeln!(f, "{rem}")?;
956                    f.write_str(", ")?;
957                }
958
959                f.write_str("};\n")
960            };
961
962        if item_out_original == item_out_optimized {
963            write_op(&lhs, &rhs, out, item_out_optimized)
964        } else {
965            let out_tmp = Variable::tmp(item_out_optimized);
966
967            write_op(&lhs, &rhs, &out_tmp, item_out_optimized)?;
968
969            let addr_space = D::address_space_for_variable(&out_tmp);
970            let qualifier = out.const_qualifier();
971            let out = out.fmt_left();
972
973            writeln!(
974                f,
975                "{out} = reinterpret_cast<{addr_space}{item_out_original}{qualifier}&>({out_tmp});\n"
976            )?;
977
978            Ok(())
979        }
980    }
981}
982
983struct Magnitude<D: Dialect, S: FunctionFmt<D>> {
984    _dialect: PhantomData<D>,
985    _sqrt: PhantomData<S>,
986}
987
988impl<D: Dialect, S: FunctionFmt<D>> Magnitude<D, S> {
989    fn format(
990        f: &mut core::fmt::Formatter<'_>,
991        input: &Variable<D>,
992        out: &Variable<D>,
993    ) -> core::fmt::Result {
994        let num = input.item().vectorization;
995        let elem = input.elem();
996
997        let mag = format!("{out}_mag");
998
999        writeln!(f, "{} {mag} = 0.0;", out.item())?;
1000
1001        for i in 0..num {
1002            let input_i = input.index(i);
1003            writeln!(f, "{mag} += {input_i} * {input_i};")?;
1004        }
1005
1006        let out = out.fmt_left();
1007        write!(f, "{out} = ")?;
1008        S::format_unary(f, &mag, elem)?;
1009        f.write_str(";\n")
1010    }
1011}
1012
1013struct Normalize<D: Dialect, InvS: FunctionFmt<D>> {
1014    _dialect: PhantomData<D>,
1015    _rsqrt: PhantomData<InvS>,
1016}
1017
1018impl<D: Dialect, InvS: FunctionFmt<D>> Normalize<D, InvS> {
1019    fn format(
1020        f: &mut core::fmt::Formatter<'_>,
1021        input: &Variable<D>,
1022        out: &Variable<D>,
1023    ) -> core::fmt::Result {
1024        let num = input.item().vectorization;
1025        let elem = input.elem();
1026        let norm = format!("{out}_norm");
1027
1028        let out_item = out.item();
1029        let out = out.fmt_left();
1030        writeln!(f, "{elem} {norm} = 0.0;")?;
1031
1032        for i in 0..num {
1033            let input_i = input.index(i);
1034            writeln!(f, "{norm} += {input_i} * {input_i};")?;
1035        }
1036
1037        write!(f, "{norm} = ")?;
1038        InvS::format_unary(f, &norm, elem)?;
1039        f.write_str(";\n")?;
1040
1041        if num == 1 {
1042            writeln!(f, "{out} = {input} * {norm};")
1043        } else {
1044            write!(f, "{out} = {out_item}{{")?;
1045            for i in 0..num {
1046                let input_i = input.index(i);
1047
1048                writeln!(f, "{input_i} * {norm},")?;
1049            }
1050
1051            f.write_str("};\n")
1052        }
1053    }
1054}
1055
1056struct Dot<D: Dialect> {
1057    _dialect: PhantomData<D>,
1058}
1059
1060impl<D: Dialect> Dot<D> {
1061    fn format(
1062        f: &mut core::fmt::Formatter<'_>,
1063        lhs: &Variable<D>,
1064        rhs: &Variable<D>,
1065        out: &Variable<D>,
1066    ) -> core::fmt::Result {
1067        let num = lhs.item().vectorization;
1068
1069        let muls = (0..num)
1070            .map(|i| {
1071                let lhs_i = lhs.index(i);
1072                let rhs_i = rhs.index(i);
1073                format!("{lhs_i} * {rhs_i}")
1074            })
1075            .collect::<Vec<_>>();
1076
1077        let out = out.fmt_left();
1078        writeln!(f, "{out} = {};", muls.join(" + "))
1079    }
1080}
1081
1082struct VectorSumFmt<D: Dialect> {
1083    _dialect: PhantomData<D>,
1084}
1085
1086impl<D: Dialect> VectorSumFmt<D> {
1087    fn format(
1088        f: &mut core::fmt::Formatter<'_>,
1089        input: &Variable<D>,
1090        out: &Variable<D>,
1091    ) -> core::fmt::Result {
1092        let num = input.item().vectorization;
1093
1094        let elems = (0..num)
1095            .map(|i| format!("{}", input.index(i)))
1096            .collect::<Vec<_>>();
1097
1098        let out = out.fmt_left();
1099        writeln!(f, "{out} = {};", elems.join(" + "))
1100    }
1101}
1102
1103struct EnsureBoolArg<'a, V: Display, D: Dialect> {
1104    var: &'a V,
1105    elem: &'a Elem<D>,
1106}
1107
1108impl<V: Display, D: Dialect> Display for EnsureBoolArg<'_, V, D> {
1109    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1110        if self.elem != &Elem::Bool {
1111            write!(f, "bool({})", self.var)
1112        } else {
1113            write!(f, "{}", self.var)
1114        }
1115    }
1116}