cubecl_cpp/shared/instruction.rs

1use crate::shared::FmtLeft;
2
3use super::{
4    Component, Dialect, Elem, Item, Variable, WarpInstruction, WmmaInstruction,
5    barrier::BarrierOps, binary::*, pipeline::PipelineOps, unary::*,
6};
7use std::{
8    borrow::Cow,
9    fmt::{Display, Formatter, Write},
10    marker::PhantomData,
11};
12
/// Identifier of the info (metadata) buffer in the generated source; it is indexed
/// directly as an array, e.g. `info[offset]`.
pub(crate) const INFO_NAME: &str = "info";
/// Identifier of the static info buffer in the generated source; when metadata is split it
/// is read through its `.x` member, e.g. `static_info.x[offset]`.
pub(crate) const STATIC_INFO_NAME: &str = "static_info";
15
/// Operands of a two-input instruction (arithmetic, comparison, bitwise, most atomics).
///
/// The fields are forwarded positionally as `(lhs, rhs, out)` to the operator formatter or
/// to the dialect's `compile_atomic_*` hook, which decides the exact emitted code.
#[derive(Debug, Clone, Copy)]
pub struct BinaryInstruction<D: Dialect> {
    /// First (left-hand) operand.
    pub lhs: Variable<D>,
    /// Second (right-hand) operand.
    pub rhs: Variable<D>,
    /// Destination variable; for atomics its role is defined by the dialect.
    pub out: Variable<D>,
}
22
/// Operands of [`Instruction::Index`].
///
/// Forwarded to `Index::format` as `(list, index, out, line_size)`. Presumably this reads
/// `list` at `index` into `out`; the exact role of `line_size` is defined by
/// `Index::format` (TODO confirm).
#[derive(Debug, Clone)]
pub struct IndexInstruction<D: Dialect> {
    /// Variable being indexed.
    pub list: Variable<D>,
    /// Index expression.
    pub index: Variable<D>,
    /// Line (vector) size passed through to `Index::format`.
    pub line_size: u32,
    /// Destination variable.
    pub out: Variable<D>,
}
30
/// Operands of [`Instruction::IndexAssign`].
///
/// Forwarded to `IndexAssign::format` as `(index, value, out, line_size)`. Presumably this
/// stores `value` into `out` at `index`; the exact role of `line_size` is defined by
/// `IndexAssign::format` (TODO confirm).
#[derive(Debug, Clone)]
pub struct IndexAssignInstruction<D: Dialect> {
    /// Index expression.
    pub index: Variable<D>,
    /// Value to store.
    pub value: Variable<D>,
    /// Line (vector) size passed through to `IndexAssign::format`.
    pub line_size: u32,
    /// Variable being written into.
    pub out: Variable<D>,
}
38
/// Operands of a one-input instruction (math functions, casts, assignments, atomic
/// load/store).
///
/// The fields are forwarded as `(input, out)` to the operator formatter or dialect hook.
#[derive(Debug, Clone, Copy)]
pub struct UnaryInstruction<D: Dialect> {
    /// Operand.
    pub input: Variable<D>,
    /// Destination variable; for atomic load/store its role is defined by the dialect.
    pub out: Variable<D>,
}
44
/// One statement of the generated C++ kernel source.
///
/// Rendered to text by the `Display` impl. Most variants delegate to an operator type of the
/// same name (e.g. `Add::format`) or to a `Dialect` hook. Control-flow variants hold their
/// nested statements as `Vec<Self>`.
#[derive(Debug, Clone)]
pub enum Instruction<D: Dialect> {
    /// Reads one metadata value into `out`: `static_info.x[info_offset]` when `split_meta`
    /// is set, otherwise `info[info_offset]`.
    Metadata {
        info_offset: Variable<D>,
        split_meta: bool,
        out: Variable<D>,
    },
    /// Reads the `dim`-th value of an extended metadata entry from `info` into `out`.
    ///
    /// The entry's base index is itself read at `info_offset`. When `split_meta` is set it
    /// comes from `static_info.x`, and `static_offset` is subtracted. Otherwise it comes
    /// from `info`.
    ExtendedMetadata {
        info_offset: Variable<D>,
        dim: Variable<D>,
        split_meta: bool,
        static_offset: u32,
        out: Variable<D>,
    },
    /// Assigns the constant `length` to `out`.
    ConstLength {
        length: u32,
        out: Variable<D>,
    },
    /// Assigns `{input}_length` (the companion declared by `Slice`/`CheckedSlice`) to `out`.
    SliceLength {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Declares `var` with its item type; WMMA fragments use the dialect's declaration hook.
    DeclareVariable {
        var: Variable<D>,
    },
    Modulo(BinaryInstruction<D>),
    Remainder(BinaryInstruction<D>),
    Add(BinaryInstruction<D>),
    SaturatingAdd(BinaryInstruction<D>),
    /// `out = fma(a, b, c)`, expanded lane by lane when `out` is vectorized.
    Fma {
        a: Variable<D>,
        b: Variable<D>,
        c: Variable<D>,
        out: Variable<D>,
    },
    Div(BinaryInstruction<D>),
    FastDiv(BinaryInstruction<D>),
    FastRecip(UnaryInstruction<D>),
    Mul(BinaryInstruction<D>),
    Sub(BinaryInstruction<D>),
    SaturatingSub(BinaryInstruction<D>),
    HiMul(BinaryInstruction<D>),
    /// Indexed read; see [`IndexInstruction`].
    Index(IndexInstruction<D>),
    /// Indexed write; see [`IndexAssignInstruction`].
    IndexAssign(IndexAssignInstruction<D>),
    Assign(UnaryInstruction<D>),
    /// Cast implemented only by the CUDA backend. Without the `cuda` feature an `#error`
    /// about FP8/FP6/FP4 casting is emitted instead.
    SpecialCast(UnaryInstruction<D>),
    /// C `for` loop declaring `i` from `start` while `i < end` (`<=` when `inclusive`),
    /// advancing by `step` if given and by one otherwise.
    RangeLoop {
        i: Variable<D>,
        start: Variable<D>,
        end: Variable<D>,
        step: Option<Variable<D>>,
        inclusive: bool,
        instructions: Vec<Self>,
    },
    /// Brace-initializes `out` (using its item type) from `inputs`.
    VecInit {
        inputs: Vec<Variable<D>>,
        out: Variable<D>,
    },
    /// `while (true)` loop around `instructions`.
    Loop {
        instructions: Vec<Self>,
    },
    If {
        cond: Variable<D>,
        instructions: Vec<Self>,
    },
    IfElse {
        cond: Variable<D>,
        instructions_if: Vec<Self>,
        instructions_else: Vec<Self>,
    },
    /// `out = cond ? then : or_else`. This is expanded per lane when `cond` is vectorized
    /// or when the operand items differ from `out`'s item.
    Select {
        cond: Variable<D>,
        then: Variable<D>,
        or_else: Variable<D>,
        out: Variable<D>,
    },
    /// C `switch` on `value`. Each case body is followed by `break;`, and
    /// `instructions_default` forms the `default` block.
    Switch {
        value: Variable<D>,
        instructions_default: Vec<Self>,
        instructions_cases: Vec<(Variable<D>, Vec<Self>)>,
    },
    /// Declares pointer `out = input + start` and `{out}_length = end - start`.
    Slice {
        input: Variable<D>,
        start: Variable<D>,
        end: Variable<D>,
        out: Variable<D>,
    },
    /// Like `Slice`, but the length is `min(len, end) - start`.
    CheckedSlice {
        input: Variable<D>,
        start: Variable<D>,
        end: Variable<D>,
        out: Variable<D>,
        len: Variable<D>,
    },
    /// Declares `out` as `input` reinterpreted as a pointer to `out`'s element type with
    /// vectorization `line_size`.
    ReinterpretSlice {
        input: Variable<D>,
        line_size: u32,
        out: Variable<D>,
    },
    Return,
    Break,
    Equal(BinaryInstruction<D>),
    NotEqual(BinaryInstruction<D>),
    Lower(BinaryInstruction<D>),
    Greater(BinaryInstruction<D>),
    LowerEqual(BinaryInstruction<D>),
    GreaterEqual(BinaryInstruction<D>),
    Erf(UnaryInstruction<D>),
    BitwiseOr(BinaryInstruction<D>),
    BitwiseAnd(BinaryInstruction<D>),
    BitwiseXor(BinaryInstruction<D>),
    CountBits(UnaryInstruction<D>),
    ReverseBits(UnaryInstruction<D>),
    ShiftLeft(BinaryInstruction<D>),
    ShiftRight(BinaryInstruction<D>),
    BitwiseNot(UnaryInstruction<D>),
    LeadingZeros(UnaryInstruction<D>),
    FindFirstSet(UnaryInstruction<D>),
    Abs(UnaryInstruction<D>),
    Exp(UnaryInstruction<D>),
    FastExp(UnaryInstruction<D>),
    Log(UnaryInstruction<D>),
    FastLog(UnaryInstruction<D>),
    Log1p(UnaryInstruction<D>),
    Cos(UnaryInstruction<D>),
    Sin(UnaryInstruction<D>),
    Tan(UnaryInstruction<D>),
    Tanh(UnaryInstruction<D>),
    Sinh(UnaryInstruction<D>),
    Cosh(UnaryInstruction<D>),
    ArcCos(UnaryInstruction<D>),
    ArcSin(UnaryInstruction<D>),
    ArcTan(UnaryInstruction<D>),
    ArcSinh(UnaryInstruction<D>),
    ArcCosh(UnaryInstruction<D>),
    ArcTanh(UnaryInstruction<D>),
    Degrees(UnaryInstruction<D>),
    Radians(UnaryInstruction<D>),
    ArcTan2(BinaryInstruction<D>),
    FastSin(UnaryInstruction<D>),
    FastCos(UnaryInstruction<D>),
    FastTanh(UnaryInstruction<D>),
    Powf(BinaryInstruction<D>),
    FastPowf(BinaryInstruction<D>),
    Powi(BinaryInstruction<D>),
    Sqrt(UnaryInstruction<D>),
    FastSqrt(UnaryInstruction<D>),
    InverseSqrt(UnaryInstruction<D>),
    FastInverseSqrt(UnaryInstruction<D>),
    Min(BinaryInstruction<D>),
    Max(BinaryInstruction<D>),
    Not(UnaryInstruction<D>),
    Or(BinaryInstruction<D>),
    And(BinaryInstruction<D>),
    /// Clamps `input` to `[min_value, max_value]`, emitted as `max(min_value, min(max_value,
    /// input))` using the dialect's min/max function names.
    Clamp {
        input: Variable<D>,
        min_value: Variable<D>,
        max_value: Variable<D>,
        out: Variable<D>,
    },
    IsNan(UnaryInstruction<D>),
    IsInf(UnaryInstruction<D>),
    /// Block-level barrier; the exact call is chosen by the dialect.
    SyncThreads,
    /// Warp-level barrier; the exact call is chosen by the dialect.
    SyncWarp,
    /// Emits `__threadfence();`.
    ThreadFence,
    /// Emits `cuda::device::experimental::fence_proxy_async_shared_cta();`.
    ProxyAsyncToSharedFence,
    /// Emits `cuda::device::experimental::cp_async_bulk_commit_group();`.
    BulkCommitGroup,
    /// Emits `cuda::device::experimental::cp_async_bulk_wait_group<max_pending>();`.
    BulkWaitGroup {
        max_pending: u32,
    },
    /// Emits `cuda::device::experimental::cp_async_bulk_wait_group_read<max_pending>();`.
    BulkWaitGroupRead {
        max_pending: u32,
    },
    /// Declares a `__shared__ alignas(128) CUtensorMap out`. The unit at position 0 copies
    /// `tensor_map` into it and points it at `&buffer[offset]` via
    /// `tensormap_replace_global_address`, then `__syncthreads()` is emitted.
    TmaReplacePointer {
        buffer: Variable<D>,
        offset: Variable<D>,
        tensor_map: Variable<D>,
        out: Variable<D>,
    },
    Round(UnaryInstruction<D>),
    Ceil(UnaryInstruction<D>),
    Trunc(UnaryInstruction<D>),
    Floor(UnaryInstruction<D>),
    /// Warp-level operation, rendered by `WarpInstruction`'s `Display` impl.
    Warp(WarpInstruction<D>),
    /// Matrix (WMMA) operation, rendered by `WmmaInstruction`'s `Display` impl.
    Wmma(WmmaInstruction<D>),
    /// Reinterprets `input` as `out`'s item type via `reinterpret_cast`. Formatting panics
    /// if the two items differ in total byte size.
    Bitcast(UnaryInstruction<D>),
    // Atomic operations: the emitted code comes entirely from `D::compile_atomic_*`.
    AtomicLoad(UnaryInstruction<D>),
    AtomicStore(UnaryInstruction<D>),
    AtomicSwap(BinaryInstruction<D>),
    AtomicAdd(BinaryInstruction<D>),
    AtomicSub(BinaryInstruction<D>),
    AtomicMax(BinaryInstruction<D>),
    AtomicMin(BinaryInstruction<D>),
    AtomicAnd(BinaryInstruction<D>),
    AtomicOr(BinaryInstruction<D>),
    AtomicXor(BinaryInstruction<D>),
    AtomicCAS {
        input: Variable<D>,
        cmp: Variable<D>,
        val: Variable<D>,
        out: Variable<D>,
    },
    /// `out = -input`.
    Neg(UnaryInstruction<D>),
    /// Vector magnitude, formatted by `Magnitude` with `Sqrt`.
    Magnitude(UnaryInstruction<D>),
    /// Vector magnitude, formatted by `Magnitude` with `FastSqrt`.
    FastMagnitude(UnaryInstruction<D>),
    /// Normalization, formatted by `Normalize` with `InverseSqrt`.
    Normalize(UnaryInstruction<D>),
    /// Normalization, formatted by `Normalize` with `FastInverseSqrt`.
    FastNormalize(UnaryInstruction<D>),
    Dot(BinaryInstruction<D>),
    /// `out[out_index] = input[in_index];`
    Copy {
        input: Variable<D>,
        in_index: Variable<D>,
        out: Variable<D>,
        out_index: Variable<D>,
    },
    /// `len` consecutive element copies, unrolled at generation time.
    CopyBulk {
        input: Variable<D>,
        in_index: Variable<D>,
        out: Variable<D>,
        out_index: Variable<D>,
        len: u32,
    },
    /// Formatted print; the emitted call is chosen by the dialect.
    Printf {
        format_string: String,
        args: Vec<Variable<D>>,
    },
    /// Source comment: `//` for single-line content, a `/* */` block otherwise.
    Comment {
        content: String,
    },
    Pipeline(PipelineOps<D>),
    Barrier(BarrierOps<D>),
    /// Emits `cp_async_bulk_tensor_{rank}d_shared_to_global`, where `rank` is the number of
    /// `indices`. The indices are written in reverse order, followed by the shared-memory
    /// pointer plus `smem_offset`.
    MemCopyAsyncTensorSharedToGlobal {
        smem_buffer: Variable<D>,
        smem_offset: Variable<D>,
        tensor_map: Variable<D>,
        indices: Vec<Variable<D>>,
    },
    /// `#line` directive attributing the following code to `file`:`line`.
    Line {
        file: Cow<'static, str>,
        line: u32,
    },
}
286
287impl<D: Dialect> Display for Instruction<D> {
288    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
289        match self {
290            Instruction::Return => f.write_str("return;"),
291            Instruction::Break => f.write_str("break;"),
292            Instruction::DeclareVariable { var } => match var {
293                Variable::WmmaFragment { .. } => D::compile_wmma_fragment_declaration(f, var),
294                _ => {
295                    let item = var.item();
296                    writeln!(f, "{item} {var};")
297                }
298            },
299            Instruction::Add(it) => Add::format(f, &it.lhs, &it.rhs, &it.out),
300            Instruction::SaturatingAdd(it) => SaturatingAdd::format(f, &it.lhs, &it.rhs, &it.out),
301            Instruction::Slice {
302                input,
303                start,
304                end,
305                out,
306            } => {
307                let item = out.item();
308                let addr_space = D::address_space_for_variable(input);
309                writeln!(f, "const uint {out}_length = {end} - {start};")?;
310                writeln!(f, "{addr_space}{item} *{out} = {input} + {start};")
311            }
312            Instruction::CheckedSlice {
313                input,
314                start,
315                end,
316                out,
317                len,
318            } => {
319                let item = out.item();
320                let addr_space = D::address_space_for_variable(input);
321                writeln!(f, "const uint {out}_length = min({len}, {end}) - {start};")?;
322                writeln!(f, "{addr_space}{item} *{out} = {input} + {start};")
323            }
324            Instruction::ReinterpretSlice {
325                input,
326                line_size,
327                out,
328            } => {
329                let mut item = out.item();
330                item.vectorization = *line_size as usize;
331                let addr_space = D::address_space_for_variable(input);
332
333                writeln!(
334                    f,
335                    "{addr_space}{item} *{out} = reinterpret_cast<{item}*>({input});"
336                )
337            }
338            Instruction::Mul(it) => Mul::format(f, &it.lhs, &it.rhs, &it.out),
339            Instruction::Div(it) => Div::format(f, &it.lhs, &it.rhs, &it.out),
340            Instruction::FastDiv(it) => FastDiv::format(f, &it.lhs, &it.rhs, &it.out),
341            Instruction::FastRecip(it) => FastRecip::format(f, &it.input, &it.out),
342            Instruction::Sub(it) => Sub::format(f, &it.lhs, &it.rhs, &it.out),
343            Instruction::SaturatingSub(it) => SaturatingSub::format(f, &it.lhs, &it.rhs, &it.out),
344            Instruction::HiMul(it) => HiMul::format(f, &it.lhs, &it.rhs, &it.out),
345            Instruction::Modulo(inst) => Modulo::format(f, &inst.lhs, &inst.rhs, &inst.out),
346            Instruction::BitwiseOr(it) => BitwiseOr::format(f, &it.lhs, &it.rhs, &it.out),
347            Instruction::BitwiseAnd(it) => BitwiseAnd::format(f, &it.lhs, &it.rhs, &it.out),
348            Instruction::BitwiseXor(it) => BitwiseXor::format(f, &it.lhs, &it.rhs, &it.out),
349            Instruction::CountBits(it) => CountBits::format(f, &it.input, &it.out),
350            Instruction::ReverseBits(it) => ReverseBits::format(f, &it.input, &it.out),
351            Instruction::LeadingZeros(it) => LeadingZeros::format(f, &it.input, &it.out),
352            Instruction::FindFirstSet(it) => FindFirstSet::format(f, &it.input, &it.out),
353            Instruction::ShiftLeft(it) => ShiftLeft::format(f, &it.lhs, &it.rhs, &it.out),
354            Instruction::ShiftRight(it) => ShiftRight::format(f, &it.lhs, &it.rhs, &it.out),
355            Instruction::Index(it) => Index::format(f, &it.list, &it.index, &it.out, it.line_size),
356            Instruction::IndexAssign(it) => {
357                IndexAssign::format(f, &it.index, &it.value, &it.out, it.line_size)
358            }
359            Instruction::Copy {
360                input,
361                in_index,
362                out,
363                out_index,
364            } => {
365                writeln!(f, "{out}[{out_index}] = {input}[{in_index}];")
366            }
367            Instruction::CopyBulk {
368                input,
369                in_index,
370                out,
371                out_index,
372                len,
373            } => {
374                for i in 0..*len {
375                    writeln!(f, "{out}[{out_index} + {i}] = {input}[{in_index} + {i}];")?;
376                }
377                Ok(())
378            }
379            Instruction::Assign(it) => Assign::format(f, &it.input, &it.out),
380            Instruction::RangeLoop {
381                i,
382                start,
383                end,
384                step,
385                inclusive,
386                instructions,
387            } => {
388                let increment = step
389                    .map(|step| format!("{i} += {step}"))
390                    .unwrap_or_else(|| format!("++{i}"));
391                let cmp = if *inclusive { "<=" } else { "<" };
392                let i_ty = i.item();
393
394                write!(
395                    f,
396                    "
397for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{
398"
399                )?;
400                for instruction in instructions {
401                    write!(f, "{instruction}")?;
402                }
403
404                f.write_str("}\n")
405            }
406            Instruction::Loop { instructions } => {
407                writeln!(f, "while (true) {{")?;
408                for i in instructions {
409                    write!(f, "{i}")?;
410                }
411                f.write_str("}\n")
412            }
413            Instruction::If { cond, instructions } => {
414                writeln!(f, "if ({cond}) {{")?;
415                for i in instructions {
416                    write!(f, "{i}")?;
417                }
418                f.write_str("}\n")
419            }
420            Instruction::IfElse {
421                cond,
422                instructions_if,
423                instructions_else,
424            } => {
425                writeln!(f, "if ({cond}) {{")?;
426                for i in instructions_if {
427                    write!(f, "{i}")?;
428                }
429                f.write_str("} else {\n")?;
430                for i in instructions_else {
431                    write!(f, "{i}")?;
432                }
433                f.write_str("}\n")
434            }
435            Instruction::Select {
436                cond,
437                then,
438                or_else,
439                out,
440            } => {
441                let item_or_else = or_else.item();
442                let item_then = then.item();
443                let item_out = out.item();
444
445                let vf_then = item_then.vectorization;
446                let vf_or_else = item_or_else.vectorization;
447                let vf_out = item_out.vectorization;
448                let vf_cond = cond.item().vectorization;
449
450                let item_out = out.item();
451                let cond_elem = cond.item().elem;
452                let out = out.fmt_left();
453
454                let should_broadcast =
455                    vf_cond > 1 || item_out != item_or_else || item_out != item_then;
456
457                if should_broadcast {
458                    let vf = usize::max(vf_cond, vf_out);
459                    let vf = usize::max(vf, vf_then);
460                    let vf = usize::max(vf, vf_or_else);
461
462                    writeln!(f, "{out} = {item_out} {{")?;
463                    for i in 0..vf {
464                        let theni = then.index(i);
465                        let or_elsei = or_else.index(i);
466                        let condi = cond.index(i);
467                        let condi = EnsureBoolArg {
468                            var: &condi,
469                            elem: &cond_elem,
470                        };
471
472                        writeln!(f, "({condi}) ? {theni} : {or_elsei},")?;
473                    }
474
475                    writeln!(f, "}};")
476                } else {
477                    let cond = EnsureBoolArg {
478                        var: &cond,
479                        elem: &cond_elem,
480                    };
481                    writeln!(f, "{out} = ({cond}) ? {then} : {or_else};")
482                }
483            }
484            Instruction::Switch {
485                value,
486                instructions_default,
487                instructions_cases,
488            } => {
489                writeln!(f, "switch({value}) {{")?;
490                for (value, block) in instructions_cases {
491                    write!(f, "case {value}:\n{{\n")?;
492                    for i in block {
493                        i.fmt(f)?;
494                    }
495                    f.write_str("break;\n}\n")?;
496                }
497                f.write_str("default:\n{")?;
498                for i in instructions_default {
499                    i.fmt(f)?;
500                }
501                f.write_str("}\n}\n")
502            }
503            Instruction::Metadata {
504                info_offset,
505                split_meta,
506                out,
507            } => {
508                let out = out.fmt_left();
509                match *split_meta {
510                    true => writeln!(f, "{out} = {STATIC_INFO_NAME}.x[{info_offset}];"),
511                    false => writeln!(f, "{out} = {INFO_NAME}[{info_offset}];"),
512                }
513            }
514            Instruction::ExtendedMetadata {
515                info_offset,
516                dim,
517                split_meta,
518                static_offset,
519                out,
520            } => {
521                let out = out.fmt_left();
522                match *split_meta {
523                    true => writeln!(
524                        f,
525                        "{out} = {INFO_NAME}[{STATIC_INFO_NAME}.x[{info_offset}] + {dim} - {static_offset}];"
526                    ),
527                    false => writeln!(
528                        f,
529                        "{out} = {INFO_NAME}[{INFO_NAME}[{info_offset}] + {dim}];"
530                    ),
531                }
532            }
533            Instruction::Equal(it) => Equal::format(f, &it.lhs, &it.rhs, &it.out),
534            Instruction::NotEqual(it) => NotEqual::format(f, &it.lhs, &it.rhs, &it.out),
535            Instruction::Lower(it) => Lower::format(f, &it.lhs, &it.rhs, &it.out),
536            Instruction::Greater(it) => Greater::format(f, &it.lhs, &it.rhs, &it.out),
537            Instruction::LowerEqual(it) => LowerEqual::format(f, &it.lhs, &it.rhs, &it.out),
538            Instruction::GreaterEqual(it) => GreaterEqual::format(f, &it.lhs, &it.rhs, &it.out),
539            Instruction::Erf(it) => Erf::format(f, &it.input, &it.out),
540            Instruction::Abs(it) => Abs::format(f, &it.input, &it.out),
541            Instruction::Exp(it) => Exp::format(f, &it.input, &it.out),
542            Instruction::FastExp(it) => FastExp::format(f, &it.input, &it.out),
543            Instruction::Log(it) => Log::format(f, &it.input, &it.out),
544            Instruction::FastLog(it) => FastLog::format(f, &it.input, &it.out),
545            Instruction::Log1p(it) => Log1p::format(f, &it.input, &it.out),
546            Instruction::Cos(it) => Cos::format(f, &it.input, &it.out),
547            Instruction::FastCos(it) => FastCos::format(f, &it.input, &it.out),
548            Instruction::Sin(it) => Sin::format(f, &it.input, &it.out),
549            Instruction::Tan(it) => Tan::format(f, &it.input, &it.out),
550            Instruction::Tanh(it) => Tanh::format(f, &it.input, &it.out),
551            Instruction::Sinh(it) => Sinh::format(f, &it.input, &it.out),
552            Instruction::Cosh(it) => Cosh::format(f, &it.input, &it.out),
553            Instruction::ArcCos(it) => ArcCos::format(f, &it.input, &it.out),
554            Instruction::ArcSin(it) => ArcSin::format(f, &it.input, &it.out),
555            Instruction::ArcTan(it) => ArcTan::format(f, &it.input, &it.out),
556            Instruction::ArcSinh(it) => ArcSinh::format(f, &it.input, &it.out),
557            Instruction::ArcCosh(it) => ArcCosh::format(f, &it.input, &it.out),
558            Instruction::ArcTanh(it) => ArcTanh::format(f, &it.input, &it.out),
559            Instruction::Degrees(it) => Degrees::format(f, &it.input, &it.out),
560            Instruction::Radians(it) => Radians::format(f, &it.input, &it.out),
561            Instruction::ArcTan2(it) => ArcTan2::format(f, &it.lhs, &it.rhs, &it.out),
562            Instruction::FastSin(it) => FastSin::format(f, &it.input, &it.out),
563            Instruction::FastTanh(it) => FastTanh::format(f, &it.input, &it.out),
564            Instruction::Powf(it) => Powf::format(f, &it.lhs, &it.rhs, &it.out),
565            Instruction::FastPowf(it) => FastPowf::format(f, &it.lhs, &it.rhs, &it.out),
566            Instruction::Powi(it) => Powi::format(f, &it.lhs, &it.rhs, &it.out),
567            Instruction::Sqrt(it) => Sqrt::format(f, &it.input, &it.out),
568            Instruction::FastSqrt(it) => FastSqrt::format(f, &it.input, &it.out),
569            Instruction::InverseSqrt(it) => InverseSqrt::format(f, &it.input, &it.out),
570            Instruction::FastInverseSqrt(it) => FastInverseSqrt::format(f, &it.input, &it.out),
571            Instruction::Max(it) => Max::format(f, &it.lhs, &it.rhs, &it.out),
572            Instruction::Min(it) => Min::format(f, &it.lhs, &it.rhs, &it.out),
573            Instruction::Not(it) => Not::format(f, &it.input, &it.out),
574            Instruction::BitwiseNot(it) => BitwiseNot::format(f, &it.input, &it.out),
575            Instruction::Or(it) => Or::format(f, &it.lhs, &it.rhs, &it.out),
576            Instruction::And(it) => And::format(f, &it.lhs, &it.rhs, &it.out),
577            Instruction::Clamp {
578                input,
579                min_value,
580                max_value,
581                out,
582            } => Clamp::format(f, input, min_value, max_value, out),
583            Instruction::IsNan(it) => IsNan::format(f, &it.input, &it.out),
584            Instruction::IsInf(it) => IsInf::format(f, &it.input, &it.out),
585            Instruction::SyncThreads => D::compile_instruction_sync_threads(f),
586            Instruction::SyncWarp => D::compile_instruction_sync_warp(f),
587            Instruction::ThreadFence => f.write_str("__threadfence();\n"),
588            Instruction::Round(it) => Round::format(f, &it.input, &it.out),
589            Instruction::Ceil(it) => Ceil::format(f, &it.input, &it.out),
590            Instruction::Trunc(it) => Trunc::format(f, &it.input, &it.out),
591            Instruction::Floor(it) => Floor::format(f, &it.input, &it.out),
592            Instruction::SliceLength { input, out } => {
593                let out = out.fmt_left();
594                writeln!(f, "{out} = {input}_length;")
595            }
596            Instruction::ConstLength { length, out } => {
597                let out = out.fmt_left();
598                writeln!(f, "{out} = {length};")
599            }
600            Instruction::Warp(it) => write!(f, "{it}"),
601            Instruction::Fma { a, b, c, out } => Fma::format(f, a, b, c, out),
602            Instruction::Wmma(it) => write!(f, "{it}"),
603            Instruction::Bitcast(UnaryInstruction { input, out }) => {
604                let qualifier = out.const_qualifier();
605                let input_item = input.item();
606                let out_item = out.item();
607
608                if out_item.elem.size() * out_item.vectorization
609                    != input.item().elem.size() * input.item().vectorization
610                {
611                    panic!("Unsupported type for bitcasting {out_item:?} from {input_item:?}");
612                } else {
613                    let out = out.fmt_left();
614                    let addr_space = D::address_space_for_variable(input);
615                    writeln!(
616                        f,
617                        "{out} = reinterpret_cast<{addr_space}{out_item}{qualifier}&>({input});"
618                    )
619                }
620            }
621            Instruction::AtomicAdd(BinaryInstruction { lhs, rhs, out }) => {
622                D::compile_atomic_add(f, lhs, rhs, out)
623            }
624            Instruction::AtomicAnd(BinaryInstruction { lhs, rhs, out }) => {
625                D::compile_atomic_and(f, lhs, rhs, out)
626            }
627            Instruction::AtomicCAS {
628                input,
629                cmp,
630                val,
631                out,
632            } => D::compile_atomic_cas(f, input, cmp, val, out),
633            Instruction::AtomicLoad(UnaryInstruction { input, out }) => {
634                D::compile_atomic_load(f, input, out)
635            }
636            Instruction::AtomicMax(BinaryInstruction { lhs, rhs, out }) => {
637                D::compile_atomic_max(f, lhs, rhs, out)
638            }
639            Instruction::AtomicMin(BinaryInstruction { lhs, rhs, out }) => {
640                D::compile_atomic_min(f, lhs, rhs, out)
641            }
642            Instruction::AtomicOr(BinaryInstruction { lhs, rhs, out }) => {
643                D::compile_atomic_or(f, lhs, rhs, out)
644            }
645            Instruction::AtomicStore(UnaryInstruction { input, out }) => {
646                D::compile_atomic_store(f, input, out)
647            }
648            Instruction::AtomicSub(BinaryInstruction { lhs, rhs, out }) => {
649                D::compile_atomic_sub(f, lhs, rhs, out)
650            }
651            Instruction::AtomicSwap(BinaryInstruction { lhs, rhs, out }) => {
652                D::compile_atomic_swap(f, lhs, rhs, out)
653            }
654            Instruction::AtomicXor(BinaryInstruction { lhs, rhs, out }) => {
655                D::compile_atomic_xor(f, lhs, rhs, out)
656            }
657            Instruction::Remainder(inst) => Remainder::format(f, &inst.lhs, &inst.rhs, &inst.out),
658            Instruction::Neg(UnaryInstruction { input, out }) => {
659                let out = out.fmt_left();
660                writeln!(f, "{out} = -{input};")
661            }
662            Instruction::Normalize(inst) => {
663                Normalize::<D, InverseSqrt>::format(f, &inst.input, &inst.out)
664            }
665            Instruction::FastNormalize(inst) => {
666                Normalize::<D, FastInverseSqrt>::format(f, &inst.input, &inst.out)
667            }
668            Instruction::Magnitude(inst) => Magnitude::<D, Sqrt>::format(f, &inst.input, &inst.out),
669            Instruction::FastMagnitude(inst) => {
670                Magnitude::<D, FastSqrt>::format(f, &inst.input, &inst.out)
671            }
672            Instruction::Dot(inst) => Dot::format(f, &inst.lhs, &inst.rhs, &inst.out),
673            Instruction::VecInit { inputs, out } => {
674                let item = out.item();
675                let inputs = inputs
676                    .iter()
677                    .map(|input| format!("{input}"))
678                    .collect::<Vec<_>>();
679                let out = out.fmt_left();
680                writeln!(f, "{out} = {item}{{{}}};", inputs.join(","))
681            }
682            Instruction::Printf {
683                format_string,
684                args,
685            } => D::compile_instruction_printf(f, format_string, args),
686            Instruction::Comment { content } => {
687                if content.contains('\n') {
688                    writeln!(f, "/* {content} */")
689                } else {
690                    writeln!(f, "// {content}")
691                }
692            }
693            Instruction::Pipeline(pipeline_ops) => write!(f, "{pipeline_ops}"),
694            Instruction::Barrier(barrier_ops) => write!(f, "{barrier_ops}"),
695            Instruction::Line { file, line } => writeln!(f, "#line {line} \"{file}\""),
696            Instruction::ProxyAsyncToSharedFence => {
697                writeln!(
698                    f,
699                    "cuda::device::experimental::fence_proxy_async_shared_cta();"
700                )
701            }
702            Instruction::BulkCommitGroup => writeln!(
703                f,
704                "cuda::device::experimental::cp_async_bulk_commit_group();"
705            ),
706            Instruction::BulkWaitGroup { max_pending } => writeln!(
707                f,
708                "cuda::device::experimental::cp_async_bulk_wait_group<{max_pending}>();"
709            ),
710            Instruction::BulkWaitGroupRead { max_pending } => writeln!(
711                f,
712                "cuda::device::experimental::cp_async_bulk_wait_group_read<{max_pending}>();"
713            ),
714            Instruction::TmaReplacePointer {
715                buffer,
716                offset,
717                tensor_map,
718                out,
719            } => {
720                let pos = Variable::<D>::UnitPos;
721                writeln!(f, "__shared__ alignas(128) CUtensorMap {out};")?;
722                writeln!(
723                    f,
724                    "
725if({pos} == 0) {{
726    {out} = {tensor_map};
727    tensormap_replace_global_address({out}, &{buffer}[{offset}]);
728}}"
729                )?;
730                writeln!(f, "__syncthreads();")
731            }
732            Instruction::MemCopyAsyncTensorSharedToGlobal {
733                smem_buffer,
734                smem_offset,
735                tensor_map,
736                indices,
737            } => {
738                let rank = indices.len();
739                let smem_ptr = smem_buffer.fmt_ptr();
740                let indices = indices.iter().rev().fold(String::new(), |mut s, it| {
741                    let _ = write!(s, "{it}, ");
742                    s
743                });
744                writeln!(
745                    f,
746                    "cuda::device::experimental::cp_async_bulk_tensor_{rank}d_shared_to_global(&{tensor_map}, {indices} {smem_ptr} + {smem_offset});"
747                )
748            }
749            Instruction::SpecialCast(UnaryInstruction { input, out }) => {
750                // Only supported in CUDA so I'm putting it here. Move to dialect if necessary.
751                #[cfg(not(feature = "cuda"))]
752                {
753                    let _ = (input, out);
754                    writeln!(
755                        f,
756                        "#error FP8/FP6/FP4 casting isn't supported outside of CUDA"
757                    )
758                }
759                #[cfg(feature = "cuda")]
760                crate::cuda::convert::special_cast::<D>(f, input, out)
761            }
762        }
763    }
764}
765
/// Formatter for [`Instruction::Fma`]. It carries no data; `D` only selects the dialect.
struct Fma<D: Dialect> {
    _dialect: PhantomData<D>,
}
769
770impl<D: Dialect> Fma<D> {
771    fn format(
772        f: &mut core::fmt::Formatter<'_>,
773        a: &Variable<D>,
774        b: &Variable<D>,
775        c: &Variable<D>,
776        out: &Variable<D>,
777    ) -> core::fmt::Result {
778        let out_item = out.item();
779        let num = out_item.vectorization;
780
781        let out = out.fmt_left();
782        if num == 1 {
783            writeln!(f, "{out} = fma({a}, {b}, {c});")
784        } else {
785            writeln!(f, "{out} = {out_item}{{")?;
786
787            for i in 0..num {
788                let ai = a.index(i);
789                let bi = b.index(i);
790                let ci = c.index(i);
791
792                writeln!(f, "fma({ai}, {bi}, {ci}),")?;
793            }
794            f.write_str("};\n")
795        }
796    }
797}
798
799struct Clamp<D: Dialect> {
800    _dialect: PhantomData<D>,
801}
802
803impl<D: Dialect> Clamp<D> {
804    fn format(
805        f: &mut core::fmt::Formatter<'_>,
806        input: &Variable<D>,
807        min_value: &Variable<D>,
808        max_value: &Variable<D>,
809        out: &Variable<D>,
810    ) -> core::fmt::Result {
811        let out_item = out.item();
812        if out.item().vectorization == 1 {
813            let out = out.fmt_left();
814            write!(f, "{out} = ")?;
815            Self::format_scalar(f, *input, *min_value, *max_value, out_item)?;
816            f.write_str(";\n")
817        } else {
818            Self::unroll_vec(f, input, min_value, max_value, out)
819        }
820    }
821
822    fn format_scalar(
823        f: &mut Formatter<'_>,
824        input: impl Component<D>,
825        min_value: impl Component<D>,
826        max_value: impl Component<D>,
827        item: Item<D>,
828    ) -> std::fmt::Result {
829        D::compile_instruction_max_function_name(f, item)?;
830        write!(f, "({min_value}, ")?;
831        D::compile_instruction_min_function_name(f, item)?;
832        write!(f, "({max_value}, {input}))")
833    }
834
835    fn unroll_vec(
836        f: &mut core::fmt::Formatter<'_>,
837        input: &Variable<D>,
838        min_value: &Variable<D>,
839        max_value: &Variable<D>,
840        out: &Variable<D>,
841    ) -> std::fmt::Result {
842        let optimized = Variable::optimized_args([*input, *min_value, *max_value, *out]);
843        let [input, min_value, max_value, out_optimized] = optimized.args;
844
845        let item_out_original = out.item();
846        let item_out_optimized = out_optimized.item();
847
848        let index = match optimized.optimization_factor {
849            Some(factor) => item_out_original.vectorization / factor,
850            None => item_out_optimized.vectorization,
851        };
852
853        let mut write_op = |input: &Variable<D>,
854                            min_value: &Variable<D>,
855                            max_value: &Variable<D>,
856                            out: &Variable<D>,
857                            item_out: Item<D>| {
858            let out = out.fmt_left();
859            writeln!(f, "{out} = {item_out}{{")?;
860            for i in 0..index {
861                let inputi = input.index(i);
862                let min_valuei = min_value.index(i);
863                let max_valuei = max_value.index(i);
864
865                Self::format_scalar(f, inputi, min_valuei, max_valuei, item_out)?;
866                f.write_str(", ")?;
867            }
868
869            f.write_str("};\n")
870        };
871
872        if item_out_original == item_out_optimized {
873            write_op(&input, &min_value, &max_value, out, item_out_optimized)
874        } else {
875            let out_tmp = Variable::tmp(item_out_optimized);
876            write_op(&input, &min_value, &max_value, &out_tmp, item_out_optimized)?;
877            let addr_space = D::address_space_for_variable(out);
878            let out = out.fmt_left();
879
880            writeln!(
881                f,
882                "{out} = reinterpret_cast<{addr_space}{item_out_original}&>({out_tmp});\n"
883            )?;
884
885            Ok(())
886        }
887    }
888}
889
890struct Remainder<D: Dialect> {
891    _dialect: PhantomData<D>,
892}
893
894impl<D: Dialect> Remainder<D> {
895    fn format(
896        f: &mut core::fmt::Formatter<'_>,
897        lhs: &Variable<D>,
898        rhs: &Variable<D>,
899        out: &Variable<D>,
900    ) -> core::fmt::Result {
901        let floor = |elem| {
902            let prefix = match elem {
903                Elem::F16 | Elem::BF16 => D::compile_instruction_half_function_name_prefix(),
904                Elem::F16x2 | Elem::BF16x2 => D::compile_instruction_half2_function_name_prefix(),
905                _ => "",
906            };
907            format!("{prefix}floor")
908        };
909
910        let is_int = matches!(
911            out.elem(),
912            Elem::I8 | Elem::I16 | Elem::I32 | Elem::U8 | Elem::U16 | Elem::U32 | Elem::U64
913        );
914        let rem_expr = |lhs, rhs, floor| {
915            if is_int {
916                format!("{lhs} - {rhs} * {floor}((float){lhs} / (float){rhs})")
917            } else {
918                format!("{lhs} - {rhs} * {floor}({lhs} / {rhs})")
919            }
920        };
921
922        if out.item().vectorization == 1 {
923            let floor = floor(out.elem());
924
925            let out = out.fmt_left();
926            let rem = rem_expr(lhs.to_string(), rhs.to_string(), &floor);
927            return writeln!(f, "{out} = {rem};");
928        }
929
930        let optimized = Variable::optimized_args([*lhs, *rhs, *out]);
931        let [lhs, rhs, out_optimized] = optimized.args;
932
933        let item_out_original = out.item();
934        let item_out_optimized = out_optimized.item();
935
936        let index = match optimized.optimization_factor {
937            Some(factor) => item_out_original.vectorization / factor,
938            None => item_out_optimized.vectorization,
939        };
940
941        let floor = floor(*item_out_optimized.elem());
942
943        let mut write_op =
944            |lhs: &Variable<D>, rhs: &Variable<D>, out: &Variable<D>, item_out: Item<D>| {
945                let out = out.fmt_left();
946                writeln!(f, "{out} = {item_out}{{")?;
947                for i in 0..index {
948                    let lhsi = lhs.index(i);
949                    let rhsi = rhs.index(i);
950
951                    let rem = rem_expr(lhsi.to_string(), rhsi.to_string(), &floor);
952                    writeln!(f, "{rem}")?;
953                    f.write_str(", ")?;
954                }
955
956                f.write_str("};\n")
957            };
958
959        if item_out_original == item_out_optimized {
960            write_op(&lhs, &rhs, out, item_out_optimized)
961        } else {
962            let out_tmp = Variable::tmp(item_out_optimized);
963
964            write_op(&lhs, &rhs, &out_tmp, item_out_optimized)?;
965
966            let addr_space = D::address_space_for_variable(&out_tmp);
967            let qualifier = out.const_qualifier();
968            let out = out.fmt_left();
969
970            writeln!(
971                f,
972                "{out} = reinterpret_cast<{addr_space}{item_out_original}{qualifier}&>({out_tmp});\n"
973            )?;
974
975            Ok(())
976        }
977    }
978}
979
980struct Magnitude<D: Dialect, S: FunctionFmt<D>> {
981    _dialect: PhantomData<D>,
982    _sqrt: PhantomData<S>,
983}
984
985impl<D: Dialect, S: FunctionFmt<D>> Magnitude<D, S> {
986    fn format(
987        f: &mut core::fmt::Formatter<'_>,
988        input: &Variable<D>,
989        out: &Variable<D>,
990    ) -> core::fmt::Result {
991        let num = input.item().vectorization;
992        let elem = input.elem();
993
994        let mag = format!("{out}_mag");
995
996        writeln!(f, "{} {mag} = 0.0;", out.item())?;
997
998        for i in 0..num {
999            let input_i = input.index(i);
1000            writeln!(f, "{mag} += {input_i} * {input_i};")?;
1001        }
1002
1003        let out = out.fmt_left();
1004        write!(f, "{out} = ")?;
1005        S::format_unary(f, &mag, elem)?;
1006        f.write_str(";\n")
1007    }
1008}
1009
1010struct Normalize<D: Dialect, InvS: FunctionFmt<D>> {
1011    _dialect: PhantomData<D>,
1012    _rsqrt: PhantomData<InvS>,
1013}
1014
1015impl<D: Dialect, InvS: FunctionFmt<D>> Normalize<D, InvS> {
1016    fn format(
1017        f: &mut core::fmt::Formatter<'_>,
1018        input: &Variable<D>,
1019        out: &Variable<D>,
1020    ) -> core::fmt::Result {
1021        let num = input.item().vectorization;
1022        let elem = input.elem();
1023        let norm = format!("{out}_norm");
1024
1025        let out_item = out.item();
1026        let out = out.fmt_left();
1027        writeln!(f, "{elem} {norm} = 0.0;")?;
1028
1029        for i in 0..num {
1030            let input_i = input.index(i);
1031            writeln!(f, "{norm} += {input_i} * {input_i};")?;
1032        }
1033
1034        write!(f, "{norm} = ")?;
1035        InvS::format_unary(f, &norm, elem)?;
1036        f.write_str(";\n")?;
1037
1038        if num == 1 {
1039            writeln!(f, "{out} = {input} * {norm};")
1040        } else {
1041            write!(f, "{out} = {out_item}{{")?;
1042            for i in 0..num {
1043                let input_i = input.index(i);
1044
1045                writeln!(f, "{input_i} * {norm},")?;
1046            }
1047
1048            f.write_str("};\n")
1049        }
1050    }
1051}
1052
1053struct Dot<D: Dialect> {
1054    _dialect: PhantomData<D>,
1055}
1056
1057impl<D: Dialect> Dot<D> {
1058    fn format(
1059        f: &mut core::fmt::Formatter<'_>,
1060        lhs: &Variable<D>,
1061        rhs: &Variable<D>,
1062        out: &Variable<D>,
1063    ) -> core::fmt::Result {
1064        let num = lhs.item().vectorization;
1065
1066        let muls = (0..num)
1067            .map(|i| {
1068                let lhs_i = lhs.index(i);
1069                let rhs_i = rhs.index(i);
1070                format!("{lhs_i} * {rhs_i}")
1071            })
1072            .collect::<Vec<_>>();
1073
1074        let out = out.fmt_left();
1075        writeln!(f, "{out} = {};", muls.join(" + "))
1076    }
1077}
1078
1079struct EnsureBoolArg<'a, V: Display, D: Dialect> {
1080    var: &'a V,
1081    elem: &'a Elem<D>,
1082}
1083
1084impl<V: Display, D: Dialect> Display for EnsureBoolArg<'_, V, D> {
1085    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1086        if self.elem != &Elem::Bool {
1087            write!(f, "bool({})", self.var)
1088        } else {
1089            write!(f, "{}", self.var)
1090        }
1091    }
1092}