cubecl_cpp/shared/
instruction.rs

1use crate::shared::FmtLeft;
2
3use super::{
4    Component, Dialect, Elem, Item, Variable, WarpInstruction, WmmaInstruction,
5    barrier::BarrierOps, binary::*, pipeline::PipelineOps, unary::*,
6};
7use std::{
8    borrow::Cow,
9    fmt::{Display, Formatter, Write},
10    marker::PhantomData,
11};
12
/// Identifier of the metadata array that `Metadata` / `ExtendedMetadata` instructions index
/// into directly (`info[...]`).
pub(crate) const INFO_NAME: &str = "info";
/// Identifier of the split-out metadata object; instructions read it through its `.x` member
/// (`static_info.x[...]`) when `split_meta` is enabled.
pub(crate) const STATIC_INFO_NAME: &str = "static_info";
15
16#[derive(Debug, Clone, Copy)]
17pub struct BinaryInstruction<D: Dialect> {
18    pub lhs: Variable<D>,
19    pub rhs: Variable<D>,
20    pub out: Variable<D>,
21}
22
23#[derive(Debug, Clone)]
24pub struct IndexInstruction<D: Dialect> {
25    pub list: Variable<D>,
26    pub index: Variable<D>,
27    pub line_size: u32,
28    pub out: Variable<D>,
29}
30
31#[derive(Debug, Clone)]
32pub struct IndexAssignInstruction<D: Dialect> {
33    pub index: Variable<D>,
34    pub value: Variable<D>,
35    pub line_size: u32,
36    pub out: Variable<D>,
37}
38
39#[derive(Debug, Clone, Copy)]
40pub struct UnaryInstruction<D: Dialect> {
41    pub input: Variable<D>,
42    pub out: Variable<D>,
43}
44
45#[derive(Debug, Clone)]
46pub enum Instruction<D: Dialect> {
47    Metadata {
48        info_offset: Variable<D>,
49        split_meta: bool,
50        out: Variable<D>,
51    },
52    ExtendedMetadata {
53        info_offset: Variable<D>,
54        dim: Variable<D>,
55        split_meta: bool,
56        static_offset: u32,
57        out: Variable<D>,
58    },
59    ConstLength {
60        length: u32,
61        out: Variable<D>,
62    },
63    SliceLength {
64        input: Variable<D>,
65        out: Variable<D>,
66    },
67    DeclareVariable {
68        var: Variable<D>,
69    },
70    Modulo(BinaryInstruction<D>),
71    Remainder(BinaryInstruction<D>),
72    Add(BinaryInstruction<D>),
73    SaturatingAdd(BinaryInstruction<D>),
74    Fma {
75        a: Variable<D>,
76        b: Variable<D>,
77        c: Variable<D>,
78        out: Variable<D>,
79    },
80    Div(BinaryInstruction<D>),
81    FastDiv(BinaryInstruction<D>),
82    FastRecip(UnaryInstruction<D>),
83    Mul(BinaryInstruction<D>),
84    Sub(BinaryInstruction<D>),
85    SaturatingSub(BinaryInstruction<D>),
86    HiMul(BinaryInstruction<D>),
87    Index(IndexInstruction<D>),
88    IndexAssign(IndexAssignInstruction<D>),
89    Assign(UnaryInstruction<D>),
90    SpecialCast(UnaryInstruction<D>),
91    RangeLoop {
92        i: Variable<D>,
93        start: Variable<D>,
94        end: Variable<D>,
95        step: Option<Variable<D>>,
96        inclusive: bool,
97        instructions: Vec<Self>,
98    },
99    VecInit {
100        inputs: Vec<Variable<D>>,
101        out: Variable<D>,
102    },
103    Loop {
104        instructions: Vec<Self>,
105    },
106    If {
107        cond: Variable<D>,
108        instructions: Vec<Self>,
109    },
110    IfElse {
111        cond: Variable<D>,
112        instructions_if: Vec<Self>,
113        instructions_else: Vec<Self>,
114    },
115    Select {
116        cond: Variable<D>,
117        then: Variable<D>,
118        or_else: Variable<D>,
119        out: Variable<D>,
120    },
121    Switch {
122        value: Variable<D>,
123        instructions_default: Vec<Self>,
124        instructions_cases: Vec<(Variable<D>, Vec<Self>)>,
125    },
126    Slice {
127        input: Variable<D>,
128        start: Variable<D>,
129        end: Variable<D>,
130        out: Variable<D>,
131    },
132    CheckedSlice {
133        input: Variable<D>,
134        start: Variable<D>,
135        end: Variable<D>,
136        out: Variable<D>,
137        len: Variable<D>,
138    },
139    ReinterpretSlice {
140        input: Variable<D>,
141        line_size: u32,
142        out: Variable<D>,
143    },
144    Return,
145    Break,
146    Equal(BinaryInstruction<D>),
147    NotEqual(BinaryInstruction<D>),
148    Lower(BinaryInstruction<D>),
149    Greater(BinaryInstruction<D>),
150    LowerEqual(BinaryInstruction<D>),
151    GreaterEqual(BinaryInstruction<D>),
152    Erf(UnaryInstruction<D>),
153    BitwiseOr(BinaryInstruction<D>),
154    BitwiseAnd(BinaryInstruction<D>),
155    BitwiseXor(BinaryInstruction<D>),
156    CountBits(UnaryInstruction<D>),
157    ReverseBits(UnaryInstruction<D>),
158    ShiftLeft(BinaryInstruction<D>),
159    ShiftRight(BinaryInstruction<D>),
160    BitwiseNot(UnaryInstruction<D>),
161    LeadingZeros(UnaryInstruction<D>),
162    FindFirstSet(UnaryInstruction<D>),
163    Abs(UnaryInstruction<D>),
164    Exp(UnaryInstruction<D>),
165    FastExp(UnaryInstruction<D>),
166    Log(UnaryInstruction<D>),
167    FastLog(UnaryInstruction<D>),
168    Log1p(UnaryInstruction<D>),
169    Cos(UnaryInstruction<D>),
170    Sin(UnaryInstruction<D>),
171    Tan(UnaryInstruction<D>),
172    Tanh(UnaryInstruction<D>),
173    Sinh(UnaryInstruction<D>),
174    Cosh(UnaryInstruction<D>),
175    ArcCos(UnaryInstruction<D>),
176    ArcSin(UnaryInstruction<D>),
177    ArcTan(UnaryInstruction<D>),
178    ArcSinh(UnaryInstruction<D>),
179    ArcCosh(UnaryInstruction<D>),
180    ArcTanh(UnaryInstruction<D>),
181    Degrees(UnaryInstruction<D>),
182    Radians(UnaryInstruction<D>),
183    ArcTan2(BinaryInstruction<D>),
184    FastSin(UnaryInstruction<D>),
185    FastCos(UnaryInstruction<D>),
186    FastTanh(UnaryInstruction<D>),
187    Powf(BinaryInstruction<D>),
188    FastPowf(BinaryInstruction<D>),
189    Powi(BinaryInstruction<D>),
190    Hypot(BinaryInstruction<D>),
191    Rhypot(BinaryInstruction<D>),
192    Sqrt(UnaryInstruction<D>),
193    FastSqrt(UnaryInstruction<D>),
194    InverseSqrt(UnaryInstruction<D>),
195    FastInverseSqrt(UnaryInstruction<D>),
196    Min(BinaryInstruction<D>),
197    Max(BinaryInstruction<D>),
198    Not(UnaryInstruction<D>),
199    Or(BinaryInstruction<D>),
200    And(BinaryInstruction<D>),
201    Clamp {
202        input: Variable<D>,
203        min_value: Variable<D>,
204        max_value: Variable<D>,
205        out: Variable<D>,
206    },
207    IsNan(UnaryInstruction<D>),
208    IsInf(UnaryInstruction<D>),
209    SyncThreads,
210    SyncWarp,
211    ThreadFence,
212    ProxyAsyncToSharedFence,
213    BulkCommitGroup,
214    BulkWaitGroup {
215        max_pending: u32,
216    },
217    BulkWaitGroupRead {
218        max_pending: u32,
219    },
220    TmaReplacePointer {
221        buffer: Variable<D>,
222        offset: Variable<D>,
223        tensor_map: Variable<D>,
224        out: Variable<D>,
225    },
226    Round(UnaryInstruction<D>),
227    Ceil(UnaryInstruction<D>),
228    Trunc(UnaryInstruction<D>),
229    Floor(UnaryInstruction<D>),
230    Warp(WarpInstruction<D>),
231    Wmma(WmmaInstruction<D>),
232    Bitcast(UnaryInstruction<D>),
233    AtomicLoad(UnaryInstruction<D>),
234    AtomicStore(UnaryInstruction<D>),
235    AtomicSwap(BinaryInstruction<D>),
236    AtomicAdd(BinaryInstruction<D>),
237    AtomicSub(BinaryInstruction<D>),
238    AtomicMax(BinaryInstruction<D>),
239    AtomicMin(BinaryInstruction<D>),
240    AtomicAnd(BinaryInstruction<D>),
241    AtomicOr(BinaryInstruction<D>),
242    AtomicXor(BinaryInstruction<D>),
243    AtomicCAS {
244        input: Variable<D>,
245        cmp: Variable<D>,
246        val: Variable<D>,
247        out: Variable<D>,
248    },
249    Neg(UnaryInstruction<D>),
250    Magnitude(UnaryInstruction<D>),
251    FastMagnitude(UnaryInstruction<D>),
252    Normalize(UnaryInstruction<D>),
253    FastNormalize(UnaryInstruction<D>),
254    Dot(BinaryInstruction<D>),
255    Copy {
256        input: Variable<D>,
257        in_index: Variable<D>,
258        out: Variable<D>,
259        out_index: Variable<D>,
260    },
261    CopyBulk {
262        input: Variable<D>,
263        in_index: Variable<D>,
264        out: Variable<D>,
265        out_index: Variable<D>,
266        len: u32,
267    },
268    Printf {
269        format_string: String,
270        args: Vec<Variable<D>>,
271    },
272    Comment {
273        content: String,
274    },
275    Pipeline(PipelineOps<D>),
276    Barrier(BarrierOps<D>),
277    MemCopyAsyncTensorSharedToGlobal {
278        smem_buffer: Variable<D>,
279        smem_offset: Variable<D>,
280        tensor_map: Variable<D>,
281        indices: Vec<Variable<D>>,
282    },
283    Line {
284        file: Cow<'static, str>,
285        line: u32,
286    },
287}
288
289impl<D: Dialect> Display for Instruction<D> {
290    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
291        match self {
292            Instruction::Return => f.write_str("return;"),
293            Instruction::Break => f.write_str("break;"),
294            Instruction::DeclareVariable { var } => match var {
295                Variable::WmmaFragment { .. } => D::compile_wmma_fragment_declaration(f, var),
296                _ => {
297                    let item = var.item();
298                    writeln!(f, "{item} {var};")
299                }
300            },
301            Instruction::Add(it) => Add::format(f, &it.lhs, &it.rhs, &it.out),
302            Instruction::SaturatingAdd(it) => SaturatingAdd::format(f, &it.lhs, &it.rhs, &it.out),
303            Instruction::Slice {
304                input,
305                start,
306                end,
307                out,
308            } => {
309                let item = out.item();
310                let addr_space = D::address_space_for_variable(input);
311                writeln!(f, "const uint {out}_length = {end} - {start};")?;
312                writeln!(f, "{addr_space}{item} *{out} = {input} + {start};")
313            }
314            Instruction::CheckedSlice {
315                input,
316                start,
317                end,
318                out,
319                len,
320            } => {
321                let item = out.item();
322                let addr_space = D::address_space_for_variable(input);
323                writeln!(f, "const uint {out}_length = min({len}, {end}) - {start};")?;
324                writeln!(f, "{addr_space}{item} *{out} = {input} + {start};")
325            }
326            Instruction::ReinterpretSlice {
327                input,
328                line_size,
329                out,
330            } => {
331                let mut item = out.item();
332                item.vectorization = *line_size as usize;
333                let addr_space = D::address_space_for_variable(input);
334
335                writeln!(
336                    f,
337                    "{addr_space}{item} *{out} = reinterpret_cast<{item}*>({input});"
338                )
339            }
340            Instruction::Mul(it) => Mul::format(f, &it.lhs, &it.rhs, &it.out),
341            Instruction::Div(it) => Div::format(f, &it.lhs, &it.rhs, &it.out),
342            Instruction::FastDiv(it) => FastDiv::format(f, &it.lhs, &it.rhs, &it.out),
343            Instruction::FastRecip(it) => FastRecip::format(f, &it.input, &it.out),
344            Instruction::Sub(it) => Sub::format(f, &it.lhs, &it.rhs, &it.out),
345            Instruction::SaturatingSub(it) => SaturatingSub::format(f, &it.lhs, &it.rhs, &it.out),
346            Instruction::HiMul(it) => HiMul::format(f, &it.lhs, &it.rhs, &it.out),
347            Instruction::Modulo(inst) => Modulo::format(f, &inst.lhs, &inst.rhs, &inst.out),
348            Instruction::BitwiseOr(it) => BitwiseOr::format(f, &it.lhs, &it.rhs, &it.out),
349            Instruction::BitwiseAnd(it) => BitwiseAnd::format(f, &it.lhs, &it.rhs, &it.out),
350            Instruction::BitwiseXor(it) => BitwiseXor::format(f, &it.lhs, &it.rhs, &it.out),
351            Instruction::CountBits(it) => CountBits::format(f, &it.input, &it.out),
352            Instruction::ReverseBits(it) => ReverseBits::format(f, &it.input, &it.out),
353            Instruction::LeadingZeros(it) => LeadingZeros::format(f, &it.input, &it.out),
354            Instruction::FindFirstSet(it) => FindFirstSet::format(f, &it.input, &it.out),
355            Instruction::ShiftLeft(it) => ShiftLeft::format(f, &it.lhs, &it.rhs, &it.out),
356            Instruction::ShiftRight(it) => ShiftRight::format(f, &it.lhs, &it.rhs, &it.out),
357            Instruction::Index(it) => Index::format(f, &it.list, &it.index, &it.out, it.line_size),
358            Instruction::IndexAssign(it) => {
359                IndexAssign::format(f, &it.index, &it.value, &it.out, it.line_size)
360            }
361            Instruction::Copy {
362                input,
363                in_index,
364                out,
365                out_index,
366            } => {
367                writeln!(f, "{out}[{out_index}] = {input}[{in_index}];")
368            }
369            Instruction::CopyBulk {
370                input,
371                in_index,
372                out,
373                out_index,
374                len,
375            } => {
376                for i in 0..*len {
377                    writeln!(f, "{out}[{out_index} + {i}] = {input}[{in_index} + {i}];")?;
378                }
379                Ok(())
380            }
381            Instruction::Assign(it) => Assign::format(f, &it.input, &it.out),
382            Instruction::RangeLoop {
383                i,
384                start,
385                end,
386                step,
387                inclusive,
388                instructions,
389            } => {
390                let increment = step
391                    .map(|step| format!("{i} += {step}"))
392                    .unwrap_or_else(|| format!("++{i}"));
393                let cmp = if *inclusive { "<=" } else { "<" };
394                let i_ty = i.item();
395
396                write!(
397                    f,
398                    "
399for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{
400"
401                )?;
402                for instruction in instructions {
403                    write!(f, "{instruction}")?;
404                }
405
406                f.write_str("}\n")
407            }
408            Instruction::Loop { instructions } => {
409                writeln!(f, "while (true) {{")?;
410                for i in instructions {
411                    write!(f, "{i}")?;
412                }
413                f.write_str("}\n")
414            }
415            Instruction::If { cond, instructions } => {
416                writeln!(f, "if ({cond}) {{")?;
417                for i in instructions {
418                    write!(f, "{i}")?;
419                }
420                f.write_str("}\n")
421            }
422            Instruction::IfElse {
423                cond,
424                instructions_if,
425                instructions_else,
426            } => {
427                writeln!(f, "if ({cond}) {{")?;
428                for i in instructions_if {
429                    write!(f, "{i}")?;
430                }
431                f.write_str("} else {\n")?;
432                for i in instructions_else {
433                    write!(f, "{i}")?;
434                }
435                f.write_str("}\n")
436            }
437            Instruction::Select {
438                cond,
439                then,
440                or_else,
441                out,
442            } => {
443                let item_or_else = or_else.item();
444                let item_then = then.item();
445                let item_out = out.item();
446
447                let vf_then = item_then.vectorization;
448                let vf_or_else = item_or_else.vectorization;
449                let vf_out = item_out.vectorization;
450                let vf_cond = cond.item().vectorization;
451
452                let item_out = out.item();
453                let cond_elem = cond.item().elem;
454                let out = out.fmt_left();
455
456                let should_broadcast =
457                    vf_cond > 1 || item_out != item_or_else || item_out != item_then;
458
459                if should_broadcast {
460                    let vf = usize::max(vf_cond, vf_out);
461                    let vf = usize::max(vf, vf_then);
462                    let vf = usize::max(vf, vf_or_else);
463
464                    writeln!(f, "{out} = {item_out} {{")?;
465                    for i in 0..vf {
466                        let theni = then.index(i);
467                        let or_elsei = or_else.index(i);
468                        let condi = cond.index(i);
469                        let condi = EnsureBoolArg {
470                            var: &condi,
471                            elem: &cond_elem,
472                        };
473
474                        writeln!(f, "({condi}) ? {theni} : {or_elsei},")?;
475                    }
476
477                    writeln!(f, "}};")
478                } else {
479                    let cond = EnsureBoolArg {
480                        var: &cond,
481                        elem: &cond_elem,
482                    };
483                    writeln!(f, "{out} = ({cond}) ? {then} : {or_else};")
484                }
485            }
486            Instruction::Switch {
487                value,
488                instructions_default,
489                instructions_cases,
490            } => {
491                writeln!(f, "switch({value}) {{")?;
492                for (value, block) in instructions_cases {
493                    write!(f, "case {value}:\n{{\n")?;
494                    for i in block {
495                        i.fmt(f)?;
496                    }
497                    f.write_str("break;\n}\n")?;
498                }
499                f.write_str("default:\n{")?;
500                for i in instructions_default {
501                    i.fmt(f)?;
502                }
503                f.write_str("}\n}\n")
504            }
505            Instruction::Metadata {
506                info_offset,
507                split_meta,
508                out,
509            } => {
510                let out = out.fmt_left();
511                match *split_meta {
512                    true => writeln!(f, "{out} = {STATIC_INFO_NAME}.x[{info_offset}];"),
513                    false => writeln!(f, "{out} = {INFO_NAME}[{info_offset}];"),
514                }
515            }
516            Instruction::ExtendedMetadata {
517                info_offset,
518                dim,
519                split_meta,
520                static_offset,
521                out,
522            } => {
523                let out = out.fmt_left();
524                match *split_meta {
525                    true => writeln!(
526                        f,
527                        "{out} = {INFO_NAME}[{STATIC_INFO_NAME}.x[{info_offset}] + {dim} - {static_offset}];"
528                    ),
529                    false => writeln!(
530                        f,
531                        "{out} = {INFO_NAME}[{INFO_NAME}[{info_offset}] + {dim}];"
532                    ),
533                }
534            }
535            Instruction::Equal(it) => Equal::format(f, &it.lhs, &it.rhs, &it.out),
536            Instruction::NotEqual(it) => NotEqual::format(f, &it.lhs, &it.rhs, &it.out),
537            Instruction::Lower(it) => Lower::format(f, &it.lhs, &it.rhs, &it.out),
538            Instruction::Greater(it) => Greater::format(f, &it.lhs, &it.rhs, &it.out),
539            Instruction::LowerEqual(it) => LowerEqual::format(f, &it.lhs, &it.rhs, &it.out),
540            Instruction::GreaterEqual(it) => GreaterEqual::format(f, &it.lhs, &it.rhs, &it.out),
541            Instruction::Erf(it) => Erf::format(f, &it.input, &it.out),
542            Instruction::Abs(it) => Abs::format(f, &it.input, &it.out),
543            Instruction::Exp(it) => Exp::format(f, &it.input, &it.out),
544            Instruction::FastExp(it) => FastExp::format(f, &it.input, &it.out),
545            Instruction::Log(it) => Log::format(f, &it.input, &it.out),
546            Instruction::FastLog(it) => FastLog::format(f, &it.input, &it.out),
547            Instruction::Log1p(it) => Log1p::format(f, &it.input, &it.out),
548            Instruction::Cos(it) => Cos::format(f, &it.input, &it.out),
549            Instruction::FastCos(it) => FastCos::format(f, &it.input, &it.out),
550            Instruction::Sin(it) => Sin::format(f, &it.input, &it.out),
551            Instruction::Tan(it) => Tan::format(f, &it.input, &it.out),
552            Instruction::Tanh(it) => Tanh::format(f, &it.input, &it.out),
553            Instruction::Sinh(it) => Sinh::format(f, &it.input, &it.out),
554            Instruction::Cosh(it) => Cosh::format(f, &it.input, &it.out),
555            Instruction::ArcCos(it) => ArcCos::format(f, &it.input, &it.out),
556            Instruction::ArcSin(it) => ArcSin::format(f, &it.input, &it.out),
557            Instruction::ArcTan(it) => ArcTan::format(f, &it.input, &it.out),
558            Instruction::ArcSinh(it) => ArcSinh::format(f, &it.input, &it.out),
559            Instruction::ArcCosh(it) => ArcCosh::format(f, &it.input, &it.out),
560            Instruction::ArcTanh(it) => ArcTanh::format(f, &it.input, &it.out),
561            Instruction::Degrees(it) => Degrees::format(f, &it.input, &it.out),
562            Instruction::Radians(it) => Radians::format(f, &it.input, &it.out),
563            Instruction::ArcTan2(it) => ArcTan2::format(f, &it.lhs, &it.rhs, &it.out),
564            Instruction::FastSin(it) => FastSin::format(f, &it.input, &it.out),
565            Instruction::FastTanh(it) => FastTanh::format(f, &it.input, &it.out),
566            Instruction::Powf(it) => Powf::format(f, &it.lhs, &it.rhs, &it.out),
567            Instruction::FastPowf(it) => FastPowf::format(f, &it.lhs, &it.rhs, &it.out),
568            Instruction::Powi(it) => Powi::format(f, &it.lhs, &it.rhs, &it.out),
569            Instruction::Hypot(it) => Hypot::format(f, &it.lhs, &it.rhs, &it.out),
570            Instruction::Rhypot(it) => Rhypot::format(f, &it.lhs, &it.rhs, &it.out),
571            Instruction::Sqrt(it) => Sqrt::format(f, &it.input, &it.out),
572            Instruction::FastSqrt(it) => FastSqrt::format(f, &it.input, &it.out),
573            Instruction::InverseSqrt(it) => InverseSqrt::format(f, &it.input, &it.out),
574            Instruction::FastInverseSqrt(it) => FastInverseSqrt::format(f, &it.input, &it.out),
575            Instruction::Max(it) => Max::format(f, &it.lhs, &it.rhs, &it.out),
576            Instruction::Min(it) => Min::format(f, &it.lhs, &it.rhs, &it.out),
577            Instruction::Not(it) => Not::format(f, &it.input, &it.out),
578            Instruction::BitwiseNot(it) => BitwiseNot::format(f, &it.input, &it.out),
579            Instruction::Or(it) => Or::format(f, &it.lhs, &it.rhs, &it.out),
580            Instruction::And(it) => And::format(f, &it.lhs, &it.rhs, &it.out),
581            Instruction::Clamp {
582                input,
583                min_value,
584                max_value,
585                out,
586            } => Clamp::format(f, input, min_value, max_value, out),
587            Instruction::IsNan(it) => IsNan::format(f, &it.input, &it.out),
588            Instruction::IsInf(it) => IsInf::format(f, &it.input, &it.out),
589            Instruction::SyncThreads => D::compile_instruction_sync_threads(f),
590            Instruction::SyncWarp => D::compile_instruction_sync_warp(f),
591            Instruction::ThreadFence => f.write_str("__threadfence();\n"),
592            Instruction::Round(it) => Round::format(f, &it.input, &it.out),
593            Instruction::Ceil(it) => Ceil::format(f, &it.input, &it.out),
594            Instruction::Trunc(it) => Trunc::format(f, &it.input, &it.out),
595            Instruction::Floor(it) => Floor::format(f, &it.input, &it.out),
596            Instruction::SliceLength { input, out } => {
597                let out = out.fmt_left();
598                writeln!(f, "{out} = {input}_length;")
599            }
600            Instruction::ConstLength { length, out } => {
601                let out = out.fmt_left();
602                writeln!(f, "{out} = {length};")
603            }
604            Instruction::Warp(it) => write!(f, "{it}"),
605            Instruction::Fma { a, b, c, out } => Fma::format(f, a, b, c, out),
606            Instruction::Wmma(it) => write!(f, "{it}"),
607            Instruction::Bitcast(UnaryInstruction { input, out }) => {
608                let qualifier = out.const_qualifier();
609                let input_item = input.item();
610                let out_item = out.item();
611
612                if out_item.elem.size() * out_item.vectorization
613                    != input.item().elem.size() * input.item().vectorization
614                {
615                    panic!("Unsupported type for bitcasting {out_item:?} from {input_item:?}");
616                } else {
617                    let out = out.fmt_left();
618                    let addr_space = D::address_space_for_variable(input);
619                    writeln!(
620                        f,
621                        "{out} = reinterpret_cast<{addr_space}{out_item}{qualifier}&>({input});"
622                    )
623                }
624            }
625            Instruction::AtomicAdd(BinaryInstruction { lhs, rhs, out }) => {
626                D::compile_atomic_add(f, lhs, rhs, out)
627            }
628            Instruction::AtomicAnd(BinaryInstruction { lhs, rhs, out }) => {
629                D::compile_atomic_and(f, lhs, rhs, out)
630            }
631            Instruction::AtomicCAS {
632                input,
633                cmp,
634                val,
635                out,
636            } => D::compile_atomic_cas(f, input, cmp, val, out),
637            Instruction::AtomicLoad(UnaryInstruction { input, out }) => {
638                D::compile_atomic_load(f, input, out)
639            }
640            Instruction::AtomicMax(BinaryInstruction { lhs, rhs, out }) => {
641                D::compile_atomic_max(f, lhs, rhs, out)
642            }
643            Instruction::AtomicMin(BinaryInstruction { lhs, rhs, out }) => {
644                D::compile_atomic_min(f, lhs, rhs, out)
645            }
646            Instruction::AtomicOr(BinaryInstruction { lhs, rhs, out }) => {
647                D::compile_atomic_or(f, lhs, rhs, out)
648            }
649            Instruction::AtomicStore(UnaryInstruction { input, out }) => {
650                D::compile_atomic_store(f, input, out)
651            }
652            Instruction::AtomicSub(BinaryInstruction { lhs, rhs, out }) => {
653                D::compile_atomic_sub(f, lhs, rhs, out)
654            }
655            Instruction::AtomicSwap(BinaryInstruction { lhs, rhs, out }) => {
656                D::compile_atomic_swap(f, lhs, rhs, out)
657            }
658            Instruction::AtomicXor(BinaryInstruction { lhs, rhs, out }) => {
659                D::compile_atomic_xor(f, lhs, rhs, out)
660            }
661            Instruction::Remainder(inst) => Remainder::format(f, &inst.lhs, &inst.rhs, &inst.out),
662            Instruction::Neg(UnaryInstruction { input, out }) => {
663                let out = out.fmt_left();
664                writeln!(f, "{out} = -{input};")
665            }
666            Instruction::Normalize(inst) => {
667                Normalize::<D, InverseSqrt>::format(f, &inst.input, &inst.out)
668            }
669            Instruction::FastNormalize(inst) => {
670                Normalize::<D, FastInverseSqrt>::format(f, &inst.input, &inst.out)
671            }
672            Instruction::Magnitude(inst) => Magnitude::<D, Sqrt>::format(f, &inst.input, &inst.out),
673            Instruction::FastMagnitude(inst) => {
674                Magnitude::<D, FastSqrt>::format(f, &inst.input, &inst.out)
675            }
676            Instruction::Dot(inst) => Dot::format(f, &inst.lhs, &inst.rhs, &inst.out),
677            Instruction::VecInit { inputs, out } => {
678                let item = out.item();
679                let inputs = inputs
680                    .iter()
681                    .map(|input| format!("{input}"))
682                    .collect::<Vec<_>>();
683                let out = out.fmt_left();
684                writeln!(f, "{out} = {item}{{{}}};", inputs.join(","))
685            }
686            Instruction::Printf {
687                format_string,
688                args,
689            } => D::compile_instruction_printf(f, format_string, args),
690            Instruction::Comment { content } => {
691                if content.contains('\n') {
692                    writeln!(f, "/* {content} */")
693                } else {
694                    writeln!(f, "// {content}")
695                }
696            }
697            Instruction::Pipeline(pipeline_ops) => write!(f, "{pipeline_ops}"),
698            Instruction::Barrier(barrier_ops) => write!(f, "{barrier_ops}"),
699            Instruction::Line { file, line } => writeln!(f, "#line {line} \"{file}\""),
700            Instruction::ProxyAsyncToSharedFence => {
701                writeln!(
702                    f,
703                    "cuda::device::experimental::fence_proxy_async_shared_cta();"
704                )
705            }
706            Instruction::BulkCommitGroup => writeln!(
707                f,
708                "cuda::device::experimental::cp_async_bulk_commit_group();"
709            ),
710            Instruction::BulkWaitGroup { max_pending } => writeln!(
711                f,
712                "cuda::device::experimental::cp_async_bulk_wait_group<{max_pending}>();"
713            ),
714            Instruction::BulkWaitGroupRead { max_pending } => writeln!(
715                f,
716                "cuda::device::experimental::cp_async_bulk_wait_group_read<{max_pending}>();"
717            ),
718            Instruction::TmaReplacePointer {
719                buffer,
720                offset,
721                tensor_map,
722                out,
723            } => {
724                let pos = Variable::<D>::UnitPos;
725                writeln!(f, "__shared__ alignas(128) CUtensorMap {out};")?;
726                writeln!(
727                    f,
728                    "
729if({pos} == 0) {{
730    {out} = {tensor_map};
731    tensormap_replace_global_address({out}, &{buffer}[{offset}]);
732}}"
733                )?;
734                writeln!(f, "__syncthreads();")
735            }
736            Instruction::MemCopyAsyncTensorSharedToGlobal {
737                smem_buffer,
738                smem_offset,
739                tensor_map,
740                indices,
741            } => {
742                let rank = indices.len();
743                let smem_ptr = smem_buffer.fmt_ptr();
744                let indices = indices.iter().rev().fold(String::new(), |mut s, it| {
745                    let _ = write!(s, "{it}, ");
746                    s
747                });
748                writeln!(
749                    f,
750                    "cuda::device::experimental::cp_async_bulk_tensor_{rank}d_shared_to_global(&{tensor_map}, {indices} {smem_ptr} + {smem_offset});"
751                )
752            }
753            Instruction::SpecialCast(UnaryInstruction { input, out }) => {
754                // Only supported in CUDA so I'm putting it here. Move to dialect if necessary.
755                #[cfg(not(feature = "cuda"))]
756                {
757                    let _ = (input, out);
758                    writeln!(
759                        f,
760                        "#error FP8/FP6/FP4 casting isn't supported outside of CUDA"
761                    )
762                }
763                #[cfg(feature = "cuda")]
764                crate::cuda::convert::special_cast::<D>(f, input, out)
765            }
766        }
767    }
768}
769
770struct Fma<D: Dialect> {
771    _dialect: PhantomData<D>,
772}
773
774impl<D: Dialect> Fma<D> {
775    fn format(
776        f: &mut core::fmt::Formatter<'_>,
777        a: &Variable<D>,
778        b: &Variable<D>,
779        c: &Variable<D>,
780        out: &Variable<D>,
781    ) -> core::fmt::Result {
782        let out_item = out.item();
783        let num = out_item.vectorization;
784
785        let out = out.fmt_left();
786        if num == 1 {
787            writeln!(f, "{out} = fma({a}, {b}, {c});")
788        } else {
789            writeln!(f, "{out} = {out_item}{{")?;
790
791            for i in 0..num {
792                let ai = a.index(i);
793                let bi = b.index(i);
794                let ci = c.index(i);
795
796                writeln!(f, "fma({ai}, {bi}, {ci}),")?;
797            }
798            f.write_str("};\n")
799        }
800    }
801}
802
803struct Clamp<D: Dialect> {
804    _dialect: PhantomData<D>,
805}
806
807impl<D: Dialect> Clamp<D> {
808    fn format(
809        f: &mut core::fmt::Formatter<'_>,
810        input: &Variable<D>,
811        min_value: &Variable<D>,
812        max_value: &Variable<D>,
813        out: &Variable<D>,
814    ) -> core::fmt::Result {
815        let out_item = out.item();
816        if out.item().vectorization == 1 {
817            let out = out.fmt_left();
818            write!(f, "{out} = ")?;
819            Self::format_scalar(f, *input, *min_value, *max_value, out_item)?;
820            f.write_str(";\n")
821        } else {
822            Self::unroll_vec(f, input, min_value, max_value, out)
823        }
824    }
825
826    fn format_scalar(
827        f: &mut Formatter<'_>,
828        input: impl Component<D>,
829        min_value: impl Component<D>,
830        max_value: impl Component<D>,
831        item: Item<D>,
832    ) -> std::fmt::Result {
833        D::compile_instruction_max_function_name(f, item)?;
834        write!(f, "({min_value}, ")?;
835        D::compile_instruction_min_function_name(f, item)?;
836        write!(f, "({max_value}, {input}))")
837    }
838
839    fn unroll_vec(
840        f: &mut core::fmt::Formatter<'_>,
841        input: &Variable<D>,
842        min_value: &Variable<D>,
843        max_value: &Variable<D>,
844        out: &Variable<D>,
845    ) -> std::fmt::Result {
846        let optimized = Variable::optimized_args([*input, *min_value, *max_value, *out]);
847        let [input, min_value, max_value, out_optimized] = optimized.args;
848
849        let item_out_original = out.item();
850        let item_out_optimized = out_optimized.item();
851
852        let index = match optimized.optimization_factor {
853            Some(factor) => item_out_original.vectorization / factor,
854            None => item_out_optimized.vectorization,
855        };
856
857        let mut write_op = |input: &Variable<D>,
858                            min_value: &Variable<D>,
859                            max_value: &Variable<D>,
860                            out: &Variable<D>,
861                            item_out: Item<D>| {
862            let out = out.fmt_left();
863            writeln!(f, "{out} = {item_out}{{")?;
864            for i in 0..index {
865                let inputi = input.index(i);
866                let min_valuei = min_value.index(i);
867                let max_valuei = max_value.index(i);
868
869                Self::format_scalar(f, inputi, min_valuei, max_valuei, item_out)?;
870                f.write_str(", ")?;
871            }
872
873            f.write_str("};\n")
874        };
875
876        if item_out_original == item_out_optimized {
877            write_op(&input, &min_value, &max_value, out, item_out_optimized)
878        } else {
879            let out_tmp = Variable::tmp(item_out_optimized);
880            write_op(&input, &min_value, &max_value, &out_tmp, item_out_optimized)?;
881            let addr_space = D::address_space_for_variable(out);
882            let out = out.fmt_left();
883
884            writeln!(
885                f,
886                "{out} = reinterpret_cast<{addr_space}{item_out_original}&>({out_tmp});\n"
887            )?;
888
889            Ok(())
890        }
891    }
892}
893
894struct Remainder<D: Dialect> {
895    _dialect: PhantomData<D>,
896}
897
898impl<D: Dialect> Remainder<D> {
899    fn format(
900        f: &mut core::fmt::Formatter<'_>,
901        lhs: &Variable<D>,
902        rhs: &Variable<D>,
903        out: &Variable<D>,
904    ) -> core::fmt::Result {
905        let floor = |elem| {
906            let prefix = match elem {
907                Elem::F16 | Elem::BF16 => D::compile_instruction_half_function_name_prefix(),
908                Elem::F16x2 | Elem::BF16x2 => D::compile_instruction_half2_function_name_prefix(),
909                _ => "",
910            };
911            format!("{prefix}floor")
912        };
913
914        let is_int = matches!(
915            out.elem(),
916            Elem::I8 | Elem::I16 | Elem::I32 | Elem::U8 | Elem::U16 | Elem::U32 | Elem::U64
917        );
918        let rem_expr = |lhs, rhs, floor| {
919            if is_int {
920                format!("{lhs} - {rhs} * {floor}((float){lhs} / (float){rhs})")
921            } else {
922                format!("{lhs} - {rhs} * {floor}({lhs} / {rhs})")
923            }
924        };
925
926        if out.item().vectorization == 1 {
927            let floor = floor(out.elem());
928
929            let out = out.fmt_left();
930            let rem = rem_expr(lhs.to_string(), rhs.to_string(), &floor);
931            return writeln!(f, "{out} = {rem};");
932        }
933
934        let optimized = Variable::optimized_args([*lhs, *rhs, *out]);
935        let [lhs, rhs, out_optimized] = optimized.args;
936
937        let item_out_original = out.item();
938        let item_out_optimized = out_optimized.item();
939
940        let index = match optimized.optimization_factor {
941            Some(factor) => item_out_original.vectorization / factor,
942            None => item_out_optimized.vectorization,
943        };
944
945        let floor = floor(*item_out_optimized.elem());
946
947        let mut write_op =
948            |lhs: &Variable<D>, rhs: &Variable<D>, out: &Variable<D>, item_out: Item<D>| {
949                let out = out.fmt_left();
950                writeln!(f, "{out} = {item_out}{{")?;
951                for i in 0..index {
952                    let lhsi = lhs.index(i);
953                    let rhsi = rhs.index(i);
954
955                    let rem = rem_expr(lhsi.to_string(), rhsi.to_string(), &floor);
956                    writeln!(f, "{rem}")?;
957                    f.write_str(", ")?;
958                }
959
960                f.write_str("};\n")
961            };
962
963        if item_out_original == item_out_optimized {
964            write_op(&lhs, &rhs, out, item_out_optimized)
965        } else {
966            let out_tmp = Variable::tmp(item_out_optimized);
967
968            write_op(&lhs, &rhs, &out_tmp, item_out_optimized)?;
969
970            let addr_space = D::address_space_for_variable(&out_tmp);
971            let qualifier = out.const_qualifier();
972            let out = out.fmt_left();
973
974            writeln!(
975                f,
976                "{out} = reinterpret_cast<{addr_space}{item_out_original}{qualifier}&>({out_tmp});\n"
977            )?;
978
979            Ok(())
980        }
981    }
982}
983
984struct Magnitude<D: Dialect, S: FunctionFmt<D>> {
985    _dialect: PhantomData<D>,
986    _sqrt: PhantomData<S>,
987}
988
989impl<D: Dialect, S: FunctionFmt<D>> Magnitude<D, S> {
990    fn format(
991        f: &mut core::fmt::Formatter<'_>,
992        input: &Variable<D>,
993        out: &Variable<D>,
994    ) -> core::fmt::Result {
995        let num = input.item().vectorization;
996        let elem = input.elem();
997
998        let mag = format!("{out}_mag");
999
1000        writeln!(f, "{} {mag} = 0.0;", out.item())?;
1001
1002        for i in 0..num {
1003            let input_i = input.index(i);
1004            writeln!(f, "{mag} += {input_i} * {input_i};")?;
1005        }
1006
1007        let out = out.fmt_left();
1008        write!(f, "{out} = ")?;
1009        S::format_unary(f, &mag, elem)?;
1010        f.write_str(";\n")
1011    }
1012}
1013
1014struct Normalize<D: Dialect, InvS: FunctionFmt<D>> {
1015    _dialect: PhantomData<D>,
1016    _rsqrt: PhantomData<InvS>,
1017}
1018
1019impl<D: Dialect, InvS: FunctionFmt<D>> Normalize<D, InvS> {
1020    fn format(
1021        f: &mut core::fmt::Formatter<'_>,
1022        input: &Variable<D>,
1023        out: &Variable<D>,
1024    ) -> core::fmt::Result {
1025        let num = input.item().vectorization;
1026        let elem = input.elem();
1027        let norm = format!("{out}_norm");
1028
1029        let out_item = out.item();
1030        let out = out.fmt_left();
1031        writeln!(f, "{elem} {norm} = 0.0;")?;
1032
1033        for i in 0..num {
1034            let input_i = input.index(i);
1035            writeln!(f, "{norm} += {input_i} * {input_i};")?;
1036        }
1037
1038        write!(f, "{norm} = ")?;
1039        InvS::format_unary(f, &norm, elem)?;
1040        f.write_str(";\n")?;
1041
1042        if num == 1 {
1043            writeln!(f, "{out} = {input} * {norm};")
1044        } else {
1045            write!(f, "{out} = {out_item}{{")?;
1046            for i in 0..num {
1047                let input_i = input.index(i);
1048
1049                writeln!(f, "{input_i} * {norm},")?;
1050            }
1051
1052            f.write_str("};\n")
1053        }
1054    }
1055}
1056
1057struct Dot<D: Dialect> {
1058    _dialect: PhantomData<D>,
1059}
1060
1061impl<D: Dialect> Dot<D> {
1062    fn format(
1063        f: &mut core::fmt::Formatter<'_>,
1064        lhs: &Variable<D>,
1065        rhs: &Variable<D>,
1066        out: &Variable<D>,
1067    ) -> core::fmt::Result {
1068        let num = lhs.item().vectorization;
1069
1070        let muls = (0..num)
1071            .map(|i| {
1072                let lhs_i = lhs.index(i);
1073                let rhs_i = rhs.index(i);
1074                format!("{lhs_i} * {rhs_i}")
1075            })
1076            .collect::<Vec<_>>();
1077
1078        let out = out.fmt_left();
1079        writeln!(f, "{out} = {};", muls.join(" + "))
1080    }
1081}
1082
1083struct EnsureBoolArg<'a, V: Display, D: Dialect> {
1084    var: &'a V,
1085    elem: &'a Elem<D>,
1086}
1087
1088impl<V: Display, D: Dialect> Display for EnsureBoolArg<'_, V, D> {
1089    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1090        if self.elem != &Elem::Bool {
1091            write!(f, "bool({})", self.var)
1092        } else {
1093            write!(f, "{}", self.var)
1094        }
1095    }
1096}