cubecl_cpp/shared/instruction.rs

use crate::shared::FmtLeft;

use super::{
    Component, Dialect, Elem, Item, Variable, WarpInstruction, WmmaInstruction,
    barrier::BarrierOps, binary::*, pipeline::PipelineOps, unary::*,
};
use std::{
    borrow::Cow,
    fmt::{Display, Formatter, Write},
    marker::PhantomData,
};

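// Identifiers used in the generated source for the kernel metadata buffers; `static_info`
// holds the statically-known part that is read when metadata is split (`split_meta`).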
pub(crate) const INFO_NAME: &str = "info";
pub(crate) const STATIC_INFO_NAME: &str = "static_info";

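/// Operands of a two-input operation: `out = op(lhs, rhs)`.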
#[derive(Debug, Clone, Copy)]
pub struct BinaryInstruction<D: Dialect> {
    pub lhs: Variable<D>,
    pub rhs: Variable<D>,
    pub out: Variable<D>,
}

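/// Reads `list[index]` into `out`, where `line_size` is the line (vector) width of the access.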
#[derive(Debug, Clone)]
pub struct IndexInstruction<D: Dialect> {
    pub list: Variable<D>,
    pub index: Variable<D>,
    pub line_size: u32,
    pub out: Variable<D>,
}

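/// Writes `value` to `out[index]`, where `line_size` is the line (vector) width of the access.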
#[derive(Debug, Clone)]
pub struct IndexAssignInstruction<D: Dialect> {
    pub index: Variable<D>,
    pub value: Variable<D>,
    pub line_size: u32,
    pub out: Variable<D>,
}

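/// Operands of a single-input operation: `out = op(input)`.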
#[derive(Debug, Clone, Copy)]
pub struct UnaryInstruction<D: Dialect> {
    pub input: Variable<D>,
    pub out: Variable<D>,
}

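/// A single statement of the generated kernel body. Each variant is rendered to C++ source
/// by the `Display` implementation below, with dialect-specific syntax delegated to `D`.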
#[derive(Debug, Clone)]
pub enum Instruction<D: Dialect> {
    Metadata {
        info_offset: Variable<D>,
        split_meta: bool,
        out: Variable<D>,
    },
    ExtendedMetadata {
        info_offset: Variable<D>,
        dim: Variable<D>,
        split_meta: bool,
        static_offset: u32,
        out: Variable<D>,
    },
    ConstLength {
        length: u32,
        out: Variable<D>,
    },
    SliceLength {
        input: Variable<D>,
        out: Variable<D>,
    },
    DeclareVariable {
        var: Variable<D>,
    },
    Modulo(BinaryInstruction<D>),
    Remainder(BinaryInstruction<D>),
    Add(BinaryInstruction<D>),
    SaturatingAdd(BinaryInstruction<D>),
    Fma {
        a: Variable<D>,
        b: Variable<D>,
        c: Variable<D>,
        out: Variable<D>,
    },
    Div(BinaryInstruction<D>),
    FastDiv(BinaryInstruction<D>),
    FastRecip(UnaryInstruction<D>),
    Mul(BinaryInstruction<D>),
    Sub(BinaryInstruction<D>),
    SaturatingSub(BinaryInstruction<D>),
    HiMul(BinaryInstruction<D>),
    Index(IndexInstruction<D>),
    IndexAssign(IndexAssignInstruction<D>),
    Assign(UnaryInstruction<D>),
    SpecialCast(UnaryInstruction<D>),
    RangeLoop {
        i: Variable<D>,
        start: Variable<D>,
        end: Variable<D>,
        step: Option<Variable<D>>,
        inclusive: bool,
        instructions: Vec<Self>,
    },
    VecInit {
        inputs: Vec<Variable<D>>,
        out: Variable<D>,
    },
    Loop {
        instructions: Vec<Self>,
    },
    If {
        cond: Variable<D>,
        instructions: Vec<Self>,
    },
    IfElse {
        cond: Variable<D>,
        instructions_if: Vec<Self>,
        instructions_else: Vec<Self>,
    },
    Select {
        cond: Variable<D>,
        then: Variable<D>,
        or_else: Variable<D>,
        out: Variable<D>,
    },
    Switch {
        value: Variable<D>,
        instructions_default: Vec<Self>,
        instructions_cases: Vec<(Variable<D>, Vec<Self>)>,
    },
    Slice {
        input: Variable<D>,
        start: Variable<D>,
        end: Variable<D>,
        out: Variable<D>,
    },
    CheckedSlice {
        input: Variable<D>,
        start: Variable<D>,
        end: Variable<D>,
        out: Variable<D>,
        len: Variable<D>,
    },
    ReinterpretSlice {
        input: Variable<D>,
        line_size: u32,
        out: Variable<D>,
    },
    Return,
    Break,
    Equal(BinaryInstruction<D>),
    NotEqual(BinaryInstruction<D>),
    Lower(BinaryInstruction<D>),
    Greater(BinaryInstruction<D>),
    LowerEqual(BinaryInstruction<D>),
    GreaterEqual(BinaryInstruction<D>),
    Erf(UnaryInstruction<D>),
    BitwiseOr(BinaryInstruction<D>),
    BitwiseAnd(BinaryInstruction<D>),
    BitwiseXor(BinaryInstruction<D>),
    CountBits(UnaryInstruction<D>),
    ReverseBits(UnaryInstruction<D>),
    ShiftLeft(BinaryInstruction<D>),
    ShiftRight(BinaryInstruction<D>),
    BitwiseNot(UnaryInstruction<D>),
    LeadingZeros(UnaryInstruction<D>),
    FindFirstSet(UnaryInstruction<D>),
    Abs(UnaryInstruction<D>),
    Exp(UnaryInstruction<D>),
    FastExp(UnaryInstruction<D>),
    Log(UnaryInstruction<D>),
    FastLog(UnaryInstruction<D>),
    Log1p(UnaryInstruction<D>),
    Cos(UnaryInstruction<D>),
    FastCos(UnaryInstruction<D>),
    Sin(UnaryInstruction<D>),
    FastSin(UnaryInstruction<D>),
    Tanh(UnaryInstruction<D>),
    FastTanh(UnaryInstruction<D>),
    Powf(BinaryInstruction<D>),
    FastPowf(BinaryInstruction<D>),
    Powi(BinaryInstruction<D>),
    Sqrt(UnaryInstruction<D>),
    FastSqrt(UnaryInstruction<D>),
    InverseSqrt(UnaryInstruction<D>),
    FastInverseSqrt(UnaryInstruction<D>),
    Min(BinaryInstruction<D>),
    Max(BinaryInstruction<D>),
    Not(UnaryInstruction<D>),
    Or(BinaryInstruction<D>),
    And(BinaryInstruction<D>),
    Clamp {
        input: Variable<D>,
        min_value: Variable<D>,
        max_value: Variable<D>,
        out: Variable<D>,
    },
    IsNan(UnaryInstruction<D>),
    IsInf(UnaryInstruction<D>),
    SyncThreads,
    SyncWarp,
    ThreadFence,
    ProxyAsyncToSharedFence,
    BulkCommitGroup,
    BulkWaitGroup {
        max_pending: u32,
    },
    BulkWaitGroupRead {
        max_pending: u32,
    },
    TmaReplacePointer {
        buffer: Variable<D>,
        offset: Variable<D>,
        tensor_map: Variable<D>,
        out: Variable<D>,
    },
    Round(UnaryInstruction<D>),
    Ceil(UnaryInstruction<D>),
    Trunc(UnaryInstruction<D>),
    Floor(UnaryInstruction<D>),
    Warp(WarpInstruction<D>),
    Wmma(WmmaInstruction<D>),
    Bitcast(UnaryInstruction<D>),
    AtomicLoad(UnaryInstruction<D>),
    AtomicStore(UnaryInstruction<D>),
    AtomicSwap(BinaryInstruction<D>),
    AtomicAdd(BinaryInstruction<D>),
    AtomicSub(BinaryInstruction<D>),
    AtomicMax(BinaryInstruction<D>),
    AtomicMin(BinaryInstruction<D>),
    AtomicAnd(BinaryInstruction<D>),
    AtomicOr(BinaryInstruction<D>),
    AtomicXor(BinaryInstruction<D>),
    AtomicCAS {
        input: Variable<D>,
        cmp: Variable<D>,
        val: Variable<D>,
        out: Variable<D>,
    },
    Neg(UnaryInstruction<D>),
    Magnitude(UnaryInstruction<D>),
    FastMagnitude(UnaryInstruction<D>),
    Normalize(UnaryInstruction<D>),
    FastNormalize(UnaryInstruction<D>),
    Dot(BinaryInstruction<D>),
    Copy {
        input: Variable<D>,
        in_index: Variable<D>,
        out: Variable<D>,
        out_index: Variable<D>,
    },
    CopyBulk {
        input: Variable<D>,
        in_index: Variable<D>,
        out: Variable<D>,
        out_index: Variable<D>,
        len: u32,
    },
    Printf {
        format_string: String,
        args: Vec<Variable<D>>,
    },
    Comment {
        content: String,
    },
    Pipeline(PipelineOps<D>),
    Barrier(BarrierOps<D>),
    MemCopyAsyncTensorSharedToGlobal {
        smem_buffer: Variable<D>,
        smem_offset: Variable<D>,
        tensor_map: Variable<D>,
        indices: Vec<Variable<D>>,
    },
    Line {
        file: Cow<'static, str>,
        line: u32,
    },
}

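// Renders each instruction as C++ source. Backend-specific constructs (atomics, warp and
// WMMA operations, synchronization, printf, ...) are forwarded to the `Dialect` trait.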
impl<D: Dialect> Display for Instruction<D> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Instruction::Return => f.write_str("return;"),
            Instruction::Break => f.write_str("break;"),
            Instruction::DeclareVariable { var } => match var {
                Variable::WmmaFragment { .. } => D::compile_wmma_fragment_declaration(f, var),
                _ => {
                    let item = var.item();
                    writeln!(f, "{item} {var};")
                }
            },
            Instruction::Add(it) => Add::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::SaturatingAdd(it) => SaturatingAdd::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Slice {
                input,
                start,
                end,
                out,
            } => {
                let item = out.item();
                let addr_space = D::address_space_for_variable(input);
                writeln!(f, "const uint {out}_length = {end} - {start};")?;
                writeln!(f, "{addr_space}{item} *{out} = {input} + {start};")
            }
            Instruction::CheckedSlice {
                input,
                start,
                end,
                out,
                len,
            } => {
                let item = out.item();
                let addr_space = D::address_space_for_variable(input);
                writeln!(f, "const uint {out}_length = min({len}, {end}) - {start};")?;
                writeln!(f, "{addr_space}{item} *{out} = {input} + {start};")
            }
            Instruction::ReinterpretSlice {
                input,
                line_size,
                out,
            } => {
                let mut item = out.item();
                item.vectorization = *line_size as usize;
                let addr_space = D::address_space_for_variable(input);

                writeln!(
                    f,
                    "{addr_space}{item} *{out} = reinterpret_cast<{item}*>({input});"
                )
            }
            Instruction::Mul(it) => Mul::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Div(it) => Div::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::FastDiv(it) => FastDiv::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::FastRecip(it) => FastRecip::format(f, &it.input, &it.out),
            Instruction::Sub(it) => Sub::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::SaturatingSub(it) => SaturatingSub::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::HiMul(it) => HiMul::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Modulo(inst) => Modulo::format(f, &inst.lhs, &inst.rhs, &inst.out),
            Instruction::BitwiseOr(it) => BitwiseOr::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::BitwiseAnd(it) => BitwiseAnd::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::BitwiseXor(it) => BitwiseXor::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::CountBits(it) => CountBits::format(f, &it.input, &it.out),
            Instruction::ReverseBits(it) => ReverseBits::format(f, &it.input, &it.out),
            Instruction::LeadingZeros(it) => LeadingZeros::format(f, &it.input, &it.out),
            Instruction::FindFirstSet(it) => FindFirstSet::format(f, &it.input, &it.out),
            Instruction::ShiftLeft(it) => ShiftLeft::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::ShiftRight(it) => ShiftRight::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Index(it) => Index::format(f, &it.list, &it.index, &it.out, it.line_size),
            Instruction::IndexAssign(it) => {
                IndexAssign::format(f, &it.index, &it.value, &it.out, it.line_size)
            }
            Instruction::Copy {
                input,
                in_index,
                out,
                out_index,
            } => {
                writeln!(f, "{out}[{out_index}] = {input}[{in_index}];")
            }
            Instruction::CopyBulk {
                input,
                in_index,
                out,
                out_index,
                len,
            } => {
                for i in 0..*len {
                    writeln!(f, "{out}[{out_index} + {i}] = {input}[{in_index} + {i}];")?;
                }
                Ok(())
            }
            Instruction::Assign(it) => Assign::format(f, &it.input, &it.out),
            Instruction::RangeLoop {
                i,
                start,
                end,
                step,
                inclusive,
                instructions,
            } => {
                let increment = step
                    .map(|step| format!("{i} += {step}"))
                    .unwrap_or_else(|| format!("++{i}"));
                let cmp = if *inclusive { "<=" } else { "<" };
                let i_ty = i.item();

                write!(
                    f,
                    "
for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{
"
                )?;
                for instruction in instructions {
                    write!(f, "{instruction}")?;
                }

                f.write_str("}\n")
            }
            Instruction::Loop { instructions } => {
                writeln!(f, "while (true) {{")?;
                for i in instructions {
                    write!(f, "{i}")?;
                }
                f.write_str("}\n")
            }
            Instruction::If { cond, instructions } => {
                writeln!(f, "if ({cond}) {{")?;
                for i in instructions {
                    write!(f, "{i}")?;
                }
                f.write_str("}\n")
            }
            Instruction::IfElse {
                cond,
                instructions_if,
                instructions_else,
            } => {
                writeln!(f, "if ({cond}) {{")?;
                for i in instructions_if {
                    write!(f, "{i}")?;
                }
                f.write_str("} else {\n")?;
                for i in instructions_else {
                    write!(f, "{i}")?;
                }
                f.write_str("}\n")
            }
            Instruction::Select {
                cond,
                then,
                or_else,
                out,
            } => {
                let item_or_else = or_else.item();
                let item_then = then.item();
                let item_out = out.item();

                let vf_then = item_then.vectorization;
                let vf_or_else = item_or_else.vectorization;
                let vf_out = item_out.vectorization;
                let vf_cond = cond.item().vectorization;

                let item_out = out.item();
                let cond_elem = cond.item().elem;
                let out = out.fmt_left();

                let should_broadcast =
                    vf_cond > 1 || item_out != item_or_else || item_out != item_then;

                if should_broadcast {
                    let vf = usize::max(vf_cond, vf_out);
                    let vf = usize::max(vf, vf_then);
                    let vf = usize::max(vf, vf_or_else);

                    writeln!(f, "{out} = {item_out} {{")?;
                    for i in 0..vf {
                        let theni = then.index(i);
                        let or_elsei = or_else.index(i);
                        let condi = cond.index(i);
                        let condi = EnsureBoolArg {
                            var: &condi,
                            elem: &cond_elem,
                        };

                        writeln!(f, "({condi}) ? {theni} : {or_elsei},")?;
                    }

                    writeln!(f, "}};")
                } else {
                    let cond = EnsureBoolArg {
                        var: &cond,
                        elem: &cond_elem,
                    };
                    writeln!(f, "{out} = ({cond}) ? {then} : {or_else};")
                }
            }
            Instruction::Switch {
                value,
                instructions_default,
                instructions_cases,
            } => {
                writeln!(f, "switch({value}) {{")?;
                for (value, block) in instructions_cases {
                    write!(f, "case {value}:\n{{\n")?;
                    for i in block {
                        i.fmt(f)?;
                    }
                    f.write_str("break;\n}\n")?;
                }
                f.write_str("default:\n{")?;
                for i in instructions_default {
                    i.fmt(f)?;
                }
                f.write_str("}\n}\n")
            }
            Instruction::Metadata {
                info_offset,
                split_meta,
                out,
            } => {
                let out = out.fmt_left();
                match *split_meta {
                    true => writeln!(f, "{out} = {STATIC_INFO_NAME}.x[{info_offset}];"),
                    false => writeln!(f, "{out} = {INFO_NAME}[{info_offset}];"),
                }
            }
            Instruction::ExtendedMetadata {
                info_offset,
                dim,
                split_meta,
                static_offset,
                out,
            } => {
                let out = out.fmt_left();
                match *split_meta {
                    true => writeln!(
                        f,
                        "{out} = {INFO_NAME}[{STATIC_INFO_NAME}.x[{info_offset}] + {dim} - {static_offset}];"
                    ),
                    false => writeln!(
                        f,
                        "{out} = {INFO_NAME}[{INFO_NAME}[{info_offset}] + {dim}];"
                    ),
                }
            }
            Instruction::Equal(it) => Equal::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::NotEqual(it) => NotEqual::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Lower(it) => Lower::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Greater(it) => Greater::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::LowerEqual(it) => LowerEqual::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::GreaterEqual(it) => GreaterEqual::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Erf(it) => Erf::format(f, &it.input, &it.out),
            Instruction::Abs(it) => Abs::format(f, &it.input, &it.out),
            Instruction::Exp(it) => Exp::format(f, &it.input, &it.out),
            Instruction::FastExp(it) => FastExp::format(f, &it.input, &it.out),
            Instruction::Log(it) => Log::format(f, &it.input, &it.out),
            Instruction::FastLog(it) => FastLog::format(f, &it.input, &it.out),
            Instruction::Log1p(it) => Log1p::format(f, &it.input, &it.out),
            Instruction::Cos(it) => Cos::format(f, &it.input, &it.out),
            Instruction::FastCos(it) => FastCos::format(f, &it.input, &it.out),
            Instruction::Sin(it) => Sin::format(f, &it.input, &it.out),
            Instruction::FastSin(it) => FastSin::format(f, &it.input, &it.out),
            Instruction::Tanh(it) => Tanh::format(f, &it.input, &it.out),
            Instruction::FastTanh(it) => FastTanh::format(f, &it.input, &it.out),
            Instruction::Powf(it) => Powf::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::FastPowf(it) => FastPowf::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Powi(it) => Powi::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Sqrt(it) => Sqrt::format(f, &it.input, &it.out),
            Instruction::FastSqrt(it) => FastSqrt::format(f, &it.input, &it.out),
            Instruction::InverseSqrt(it) => InverseSqrt::format(f, &it.input, &it.out),
            Instruction::FastInverseSqrt(it) => FastInverseSqrt::format(f, &it.input, &it.out),
            Instruction::Max(it) => Max::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Min(it) => Min::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Not(it) => Not::format(f, &it.input, &it.out),
            Instruction::BitwiseNot(it) => BitwiseNot::format(f, &it.input, &it.out),
            Instruction::Or(it) => Or::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::And(it) => And::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Clamp {
                input,
                min_value,
                max_value,
                out,
            } => Clamp::format(f, input, min_value, max_value, out),
            Instruction::IsNan(it) => IsNan::format(f, &it.input, &it.out),
            Instruction::IsInf(it) => IsInf::format(f, &it.input, &it.out),
            Instruction::SyncThreads => D::compile_instruction_sync_threads(f),
            Instruction::SyncWarp => D::compile_instruction_sync_warp(f),
            Instruction::ThreadFence => f.write_str("__threadfence();\n"),
            Instruction::Round(it) => Round::format(f, &it.input, &it.out),
            Instruction::Ceil(it) => Ceil::format(f, &it.input, &it.out),
            Instruction::Trunc(it) => Trunc::format(f, &it.input, &it.out),
            Instruction::Floor(it) => Floor::format(f, &it.input, &it.out),
            Instruction::SliceLength { input, out } => {
                let out = out.fmt_left();
                writeln!(f, "{out} = {input}_length;")
            }
            Instruction::ConstLength { length, out } => {
                let out = out.fmt_left();
                writeln!(f, "{out} = {length};")
            }
            Instruction::Warp(it) => write!(f, "{it}"),
            Instruction::Fma { a, b, c, out } => Fma::format(f, a, b, c, out),
            Instruction::Wmma(it) => write!(f, "{it}"),
            Instruction::Bitcast(UnaryInstruction { input, out }) => {
                let qualifier = out.const_qualifier();
                let input_item = input.item();
                let out_item = out.item();

                if out_item.elem.size() * out_item.vectorization
                    != input.item().elem.size() * input.item().vectorization
                {
                    panic!("Unsupported type for bitcasting {out_item:?} from {input_item:?}");
                } else {
                    let out = out.fmt_left();
                    let addr_space = D::address_space_for_variable(input);
                    writeln!(
                        f,
                        "{out} = reinterpret_cast<{addr_space}{out_item}{qualifier}&>({input});"
                    )
                }
            }
            Instruction::AtomicAdd(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_add(f, lhs, rhs, out)
            }
            Instruction::AtomicAnd(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_and(f, lhs, rhs, out)
            }
            Instruction::AtomicCAS {
                input,
                cmp,
                val,
                out,
            } => D::compile_atomic_cas(f, input, cmp, val, out),
            Instruction::AtomicLoad(UnaryInstruction { input, out }) => {
                D::compile_atomic_load(f, input, out)
            }
            Instruction::AtomicMax(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_max(f, lhs, rhs, out)
            }
            Instruction::AtomicMin(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_min(f, lhs, rhs, out)
            }
            Instruction::AtomicOr(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_or(f, lhs, rhs, out)
            }
            Instruction::AtomicStore(UnaryInstruction { input, out }) => {
                D::compile_atomic_store(f, input, out)
            }
            Instruction::AtomicSub(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_sub(f, lhs, rhs, out)
            }
            Instruction::AtomicSwap(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_swap(f, lhs, rhs, out)
            }
            Instruction::AtomicXor(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_xor(f, lhs, rhs, out)
            }
            Instruction::Remainder(inst) => Remainder::format(f, &inst.lhs, &inst.rhs, &inst.out),
            Instruction::Neg(UnaryInstruction { input, out }) => {
                let out = out.fmt_left();
                writeln!(f, "{out} = -{input};")
            }
            Instruction::Normalize(inst) => {
                Normalize::<D, InverseSqrt>::format(f, &inst.input, &inst.out)
            }
            Instruction::FastNormalize(inst) => {
                Normalize::<D, FastInverseSqrt>::format(f, &inst.input, &inst.out)
            }
            Instruction::Magnitude(inst) => Magnitude::<D, Sqrt>::format(f, &inst.input, &inst.out),
            Instruction::FastMagnitude(inst) => {
                Magnitude::<D, FastSqrt>::format(f, &inst.input, &inst.out)
            }
            Instruction::Dot(inst) => Dot::format(f, &inst.lhs, &inst.rhs, &inst.out),
            Instruction::VecInit { inputs, out } => {
                let item = out.item();
                let inputs = inputs
                    .iter()
                    .map(|input| format!("{input}"))
                    .collect::<Vec<_>>();
                let out = out.fmt_left();
                writeln!(f, "{out} = {item}{{{}}};", inputs.join(","))
            }
            Instruction::Printf {
                format_string,
                args,
            } => D::compile_instruction_printf(f, format_string, args),
            Instruction::Comment { content } => {
                if content.contains('\n') {
                    writeln!(f, "/* {content} */")
                } else {
                    writeln!(f, "// {content}")
                }
            }
            Instruction::Pipeline(pipeline_ops) => write!(f, "{pipeline_ops}"),
            Instruction::Barrier(barrier_ops) => write!(f, "{barrier_ops}"),
            Instruction::Line { file, line } => writeln!(f, "#line {line} \"{file}\""),
            Instruction::ProxyAsyncToSharedFence => {
                writeln!(
                    f,
                    "cuda::device::experimental::fence_proxy_async_shared_cta();"
                )
            }
            Instruction::BulkCommitGroup => writeln!(
                f,
                "cuda::device::experimental::cp_async_bulk_commit_group();"
            ),
            Instruction::BulkWaitGroup { max_pending } => writeln!(
                f,
                "cuda::device::experimental::cp_async_bulk_wait_group<{max_pending}>();"
            ),
            Instruction::BulkWaitGroupRead { max_pending } => writeln!(
                f,
                "cuda::device::experimental::cp_async_bulk_wait_group_read<{max_pending}>();"
            ),
            Instruction::TmaReplacePointer {
                buffer,
                offset,
                tensor_map,
                out,
            } => {
                let pos = Variable::<D>::UnitPos;
                writeln!(f, "__shared__ alignas(128) CUtensorMap {out};")?;
                writeln!(
                    f,
                    "
if({pos} == 0) {{
    {out} = {tensor_map};
    tensormap_replace_global_address({out}, &{buffer}[{offset}]);
}}"
                )?;
                writeln!(f, "__syncthreads();")
            }
            Instruction::MemCopyAsyncTensorSharedToGlobal {
                smem_buffer,
                smem_offset,
                tensor_map,
                indices,
            } => {
                let rank = indices.len();
                let smem_ptr = smem_buffer.fmt_ptr();
                let indices = indices.iter().rev().fold(String::new(), |mut s, it| {
                    let _ = write!(s, "{it}, ");
                    s
                });
                writeln!(
                    f,
                    "cuda::device::experimental::cp_async_bulk_tensor_{rank}d_shared_to_global(&{tensor_map}, {indices} {smem_ptr} + {smem_offset});"
                )
            }
            Instruction::SpecialCast(UnaryInstruction { input, out }) => {
                // Only supported in CUDA so I'm putting it here. Move to dialect if necessary.
                #[cfg(not(feature = "cuda"))]
                {
                    let _ = (input, out);
                    unimplemented!("FP8/FP6/FP4 casting isn't supported outside of CUDA");
                }
                #[cfg(feature = "cuda")]
                crate::cuda::convert::special_cast::<D>(f, input, out)
            }
        }
    }
}

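/// Fused multiply-add: emits `fma(a, b, c)`, unrolled into a brace-initialized vector when
/// the output is vectorized.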
struct Fma<D: Dialect> {
    _dialect: PhantomData<D>,
}

impl<D: Dialect> Fma<D> {
    fn format(
        f: &mut core::fmt::Formatter<'_>,
        a: &Variable<D>,
        b: &Variable<D>,
        c: &Variable<D>,
        out: &Variable<D>,
    ) -> core::fmt::Result {
        let out_item = out.item();
        let num = out_item.vectorization;

        let out = out.fmt_left();
        if num == 1 {
            writeln!(f, "{out} = fma({a}, {b}, {c});")
        } else {
            writeln!(f, "{out} = {out_item}{{")?;

            for i in 0..num {
                let ai = a.index(i);
                let bi = b.index(i);
                let ci = c.index(i);

                writeln!(f, "fma({ai}, {bi}, {ci}),")?;
            }
            f.write_str("};\n")
        }
    }
}

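/// Clamp written as nested min/max: `max(min_value, min(max_value, input))`, using the
/// dialect's min/max function names and unrolling over vector lanes when needed.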
struct Clamp<D: Dialect> {
    _dialect: PhantomData<D>,
}

impl<D: Dialect> Clamp<D> {
    fn format(
        f: &mut core::fmt::Formatter<'_>,
        input: &Variable<D>,
        min_value: &Variable<D>,
        max_value: &Variable<D>,
        out: &Variable<D>,
    ) -> core::fmt::Result {
        let out_item = out.item();
        if out.item().vectorization == 1 {
            let out = out.fmt_left();
            write!(f, "{out} = ")?;
            Self::format_scalar(f, *input, *min_value, *max_value, out_item)?;
            f.write_str(";\n")
        } else {
            Self::unroll_vec(f, input, min_value, max_value, out)
        }
    }

    fn format_scalar(
        f: &mut Formatter<'_>,
        input: impl Component<D>,
        min_value: impl Component<D>,
        max_value: impl Component<D>,
        item: Item<D>,
    ) -> std::fmt::Result {
        D::compile_instruction_max_function_name(f, item)?;
        write!(f, "({min_value}, ")?;
        D::compile_instruction_min_function_name(f, item)?;
        write!(f, "({max_value}, {input}))")
    }

    fn unroll_vec(
        f: &mut core::fmt::Formatter<'_>,
        input: &Variable<D>,
        min_value: &Variable<D>,
        max_value: &Variable<D>,
        out: &Variable<D>,
    ) -> std::fmt::Result {
        let optimized = Variable::optimized_args([*input, *min_value, *max_value, *out]);
        let [input, min_value, max_value, out_optimized] = optimized.args;

        let item_out_original = out.item();
        let item_out_optimized = out_optimized.item();

        let index = match optimized.optimization_factor {
            Some(factor) => item_out_original.vectorization / factor,
            None => item_out_optimized.vectorization,
        };

        let mut write_op = |input: &Variable<D>,
                            min_value: &Variable<D>,
                            max_value: &Variable<D>,
                            out: &Variable<D>,
                            item_out: Item<D>| {
            let out = out.fmt_left();
            writeln!(f, "{out} = {item_out}{{")?;
            for i in 0..index {
                let inputi = input.index(i);
                let min_valuei = min_value.index(i);
                let max_valuei = max_value.index(i);

                Self::format_scalar(f, inputi, min_valuei, max_valuei, item_out)?;
                f.write_str(", ")?;
            }

            f.write_str("};\n")
        };

        if item_out_original == item_out_optimized {
            write_op(&input, &min_value, &max_value, out, item_out_optimized)
        } else {
            let out_tmp = Variable::tmp(item_out_optimized);
            write_op(&input, &min_value, &max_value, &out_tmp, item_out_optimized)?;
            let addr_space = D::address_space_for_variable(out);
            let out = out.fmt_left();

            writeln!(
                f,
                "{out} = reinterpret_cast<{addr_space}{item_out_original}&>({out_tmp});\n"
            )?;

            Ok(())
        }
    }
}

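/// Floating-point remainder computed as `lhs - rhs * floor(lhs / rhs)`, selecting the
/// dialect's half-precision `floor` prefix when required and unrolling over vector lanes.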
struct Remainder<D: Dialect> {
    _dialect: PhantomData<D>,
}

impl<D: Dialect> Remainder<D> {
    fn format(
        f: &mut core::fmt::Formatter<'_>,
        lhs: &Variable<D>,
        rhs: &Variable<D>,
        out: &Variable<D>,
    ) -> core::fmt::Result {
        let floor = |elem| {
            let prefix = match elem {
                Elem::F16 | Elem::BF16 => D::compile_instruction_half_function_name_prefix(),
                Elem::F16x2 | Elem::BF16x2 => D::compile_instruction_half2_function_name_prefix(),
                _ => "",
            };
            format!("{prefix}floor")
        };

        if out.item().vectorization == 1 {
            let floor = floor(out.elem());

            let out = out.fmt_left();
            return writeln!(f, "{out} = {lhs} - {rhs} * {floor}({lhs} / {rhs});");
        }

        let optimized = Variable::optimized_args([*lhs, *rhs, *out]);
        let [lhs, rhs, out_optimized] = optimized.args;

        let item_out_original = out.item();
        let item_out_optimized = out_optimized.item();

        let index = match optimized.optimization_factor {
            Some(factor) => item_out_original.vectorization / factor,
            None => item_out_optimized.vectorization,
        };

        let floor = floor(*item_out_optimized.elem());

        let mut write_op =
            |lhs: &Variable<D>, rhs: &Variable<D>, out: &Variable<D>, item_out: Item<D>| {
                let out = out.fmt_left();
                writeln!(f, "{out} = {item_out}{{")?;
                for i in 0..index {
                    let lhsi = lhs.index(i);
                    let rhsi = rhs.index(i);

                    writeln!(f, "{lhsi} - {rhsi} * {floor}({lhsi} / {rhsi})")?;
                    f.write_str(", ")?;
                }

                f.write_str("};\n")
            };

        if item_out_original == item_out_optimized {
            write_op(&lhs, &rhs, out, item_out_optimized)
        } else {
            let out_tmp = Variable::tmp(item_out_optimized);

            write_op(&lhs, &rhs, &out_tmp, item_out_optimized)?;

            let addr_space = D::address_space_for_variable(&out_tmp);
            let qualifier = out.const_qualifier();
            let out = out.fmt_left();

            writeln!(
                f,
                "{out} = reinterpret_cast<{addr_space}{item_out_original}{qualifier}&>({out_tmp});\n"
            )?;

            Ok(())
        }
    }
}

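/// Euclidean magnitude: accumulates the sum of squares over all lanes, then applies the
/// square-root function selected by `S` (e.g. `Sqrt` or `FastSqrt`).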
struct Magnitude<D: Dialect, S: FunctionFmt<D>> {
    _dialect: PhantomData<D>,
    _sqrt: PhantomData<S>,
}

impl<D: Dialect, S: FunctionFmt<D>> Magnitude<D, S> {
    fn format(
        f: &mut core::fmt::Formatter<'_>,
        input: &Variable<D>,
        out: &Variable<D>,
    ) -> core::fmt::Result {
        let num = input.item().vectorization;
        let elem = input.elem();

        let mag = format!("{out}_mag");

        writeln!(f, "{} {mag} = 0.0;", out.item())?;

        for i in 0..num {
            let input_i = input.index(i);
            writeln!(f, "{mag} += {input_i} * {input_i};")?;
        }

        let out = out.fmt_left();
        write!(f, "{out} = ")?;
        S::format_unary(f, &mag, elem)?;
        f.write_str(";\n")
    }
}

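/// Vector normalization: computes the inverse square root of the sum of squares via `InvS`
/// and scales every lane of the input by it.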
struct Normalize<D: Dialect, InvS: FunctionFmt<D>> {
    _dialect: PhantomData<D>,
    _rsqrt: PhantomData<InvS>,
}

impl<D: Dialect, InvS: FunctionFmt<D>> Normalize<D, InvS> {
    fn format(
        f: &mut core::fmt::Formatter<'_>,
        input: &Variable<D>,
        out: &Variable<D>,
    ) -> core::fmt::Result {
        let num = input.item().vectorization;
        let elem = input.elem();
        let norm = format!("{out}_norm");

        let out_item = out.item();
        let out = out.fmt_left();
        writeln!(f, "{elem} {norm} = 0.0;")?;

        for i in 0..num {
            let input_i = input.index(i);
            writeln!(f, "{norm} += {input_i} * {input_i};")?;
        }

        write!(f, "{norm} = ")?;
        InvS::format_unary(f, &norm, elem)?;
        f.write_str(";\n")?;

        if num == 1 {
            writeln!(f, "{out} = {input} * {norm};")
        } else {
            write!(f, "{out} = {out_item}{{")?;
            for i in 0..num {
                let input_i = input.index(i);

                writeln!(f, "{input_i} * {norm},")?;
            }

            f.write_str("};\n")
        }
    }
}

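/// Dot product expanded as a sum of per-lane products: `lhs[0] * rhs[0] + lhs[1] * rhs[1] + ...`.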
struct Dot<D: Dialect> {
    _dialect: PhantomData<D>,
}

impl<D: Dialect> Dot<D> {
    fn format(
        f: &mut core::fmt::Formatter<'_>,
        lhs: &Variable<D>,
        rhs: &Variable<D>,
        out: &Variable<D>,
    ) -> core::fmt::Result {
        let num = lhs.item().vectorization;

        let muls = (0..num)
            .map(|i| {
                let lhs_i = lhs.index(i);
                let rhs_i = rhs.index(i);
                format!("{lhs_i} * {rhs_i}")
            })
            .collect::<Vec<_>>();

        let out = out.fmt_left();
        writeln!(f, "{out} = {};", muls.join(" + "))
    }
}

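/// Wraps a condition in `bool(...)` when its element type is not already `bool`, so it can be
/// used directly in a ternary expression.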
struct EnsureBoolArg<'a, V: Display, D: Dialect> {
    var: &'a V,
    elem: &'a Elem<D>,
}

impl<V: Display, D: Dialect> Display for EnsureBoolArg<'_, V, D> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if self.elem != &Elem::Bool {
            write!(f, "bool({})", self.var)
        } else {
            write!(f, "{}", self.var)
        }
    }
}