// cubecl_cpp/shared/instruction.rs

use crate::shared::FmtLeft;

use super::{
    Component, Dialect, Elem, Item, Variable, WarpInstruction, WmmaInstruction,
    barrier::BarrierOps, binary::*, pipeline::PipelineOps, unary::*,
};
use std::{
    borrow::Cow,
    fmt::{Display, Formatter, Write},
    marker::PhantomData,
};
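
// Variable names used when emitting accesses to the dynamic and static metadata buffers.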
pub(crate) const INFO_NAME: &str = "info";
pub(crate) const STATIC_INFO_NAME: &str = "static_info";
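
/// Operands of a binary operation: left-hand side, right-hand side and output variable.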
#[derive(Debug, Clone)]
pub struct BinaryInstruction<D: Dialect> {
    pub lhs: Variable<D>,
    pub rhs: Variable<D>,
    pub out: Variable<D>,
}
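
/// Indexed read: loads `list[index]` into `out` using the given line size.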
#[derive(Debug, Clone)]
pub struct IndexInstruction<D: Dialect> {
    pub list: Variable<D>,
    pub index: Variable<D>,
    pub line_size: u32,
    pub out: Variable<D>,
}
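
/// Indexed write: stores `value` at `out[index]` using the given line size.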
#[derive(Debug, Clone)]
pub struct IndexAssignInstruction<D: Dialect> {
    pub index: Variable<D>,
    pub value: Variable<D>,
    pub line_size: u32,
    pub out: Variable<D>,
}
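
/// Operands of a unary operation: a single input and an output variable.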
#[derive(Debug, Clone)]
pub struct UnaryInstruction<D: Dialect> {
    pub input: Variable<D>,
    pub out: Variable<D>,
}
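
/// A single statement (or block of statements) of the generated C++ kernel.
/// Each variant is rendered to source code by the `Display` implementation below.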
#[derive(Debug, Clone)]
pub enum Instruction<D: Dialect> {
    Metadata {
        info_offset: Variable<D>,
        split_meta: bool,
        out: Variable<D>,
    },
    ExtendedMetadata {
        info_offset: Variable<D>,
        dim: Variable<D>,
        split_meta: bool,
        static_offset: u32,
        out: Variable<D>,
    },
    ConstLength {
        length: u32,
        out: Variable<D>,
    },
    SliceLength {
        input: Variable<D>,
        out: Variable<D>,
    },
    DeclareVariable {
        var: Variable<D>,
    },
    Modulo(BinaryInstruction<D>),
    Remainder(BinaryInstruction<D>),
    Add(BinaryInstruction<D>),
    SaturatingAdd(BinaryInstruction<D>),
    Fma {
        a: Variable<D>,
        b: Variable<D>,
        c: Variable<D>,
        out: Variable<D>,
    },
    Div(BinaryInstruction<D>),
    Mul(BinaryInstruction<D>),
    Sub(BinaryInstruction<D>),
    SaturatingSub(BinaryInstruction<D>),
    HiMul(BinaryInstruction<D>),
    Index(IndexInstruction<D>),
    IndexAssign(IndexAssignInstruction<D>),
    Assign(UnaryInstruction<D>),
    SpecialCast(UnaryInstruction<D>),
    RangeLoop {
        i: Variable<D>,
        start: Variable<D>,
        end: Variable<D>,
        step: Option<Variable<D>>,
        inclusive: bool,
        instructions: Vec<Self>,
    },
    VecInit {
        inputs: Vec<Variable<D>>,
        out: Variable<D>,
    },
    Loop {
        instructions: Vec<Self>,
    },
    If {
        cond: Variable<D>,
        instructions: Vec<Self>,
    },
    IfElse {
        cond: Variable<D>,
        instructions_if: Vec<Self>,
        instructions_else: Vec<Self>,
    },
    Select {
        cond: Variable<D>,
        then: Variable<D>,
        or_else: Variable<D>,
        out: Variable<D>,
    },
    Switch {
        value: Variable<D>,
        instructions_default: Vec<Self>,
        instructions_cases: Vec<(Variable<D>, Vec<Self>)>,
    },
    Slice {
        input: Variable<D>,
        start: Variable<D>,
        end: Variable<D>,
        out: Variable<D>,
    },
    CheckedSlice {
        input: Variable<D>,
        start: Variable<D>,
        end: Variable<D>,
        out: Variable<D>,
        len: Variable<D>,
    },
    ReinterpretSlice {
        input: Variable<D>,
        line_size: u32,
        out: Variable<D>,
    },
    Return,
    Break,
    Equal(BinaryInstruction<D>),
    NotEqual(BinaryInstruction<D>),
    Lower(BinaryInstruction<D>),
    Greater(BinaryInstruction<D>),
    LowerEqual(BinaryInstruction<D>),
    GreaterEqual(BinaryInstruction<D>),
    Erf(UnaryInstruction<D>),
    BitwiseOr(BinaryInstruction<D>),
    BitwiseAnd(BinaryInstruction<D>),
    BitwiseXor(BinaryInstruction<D>),
    CountBits(UnaryInstruction<D>),
    ReverseBits(UnaryInstruction<D>),
    ShiftLeft(BinaryInstruction<D>),
    ShiftRight(BinaryInstruction<D>),
    BitwiseNot(UnaryInstruction<D>),
    LeadingZeros(UnaryInstruction<D>),
    FindFirstSet(UnaryInstruction<D>),
    Abs(UnaryInstruction<D>),
    Exp(UnaryInstruction<D>),
    Log(UnaryInstruction<D>),
    Log1p(UnaryInstruction<D>),
    Cos(UnaryInstruction<D>),
    Sin(UnaryInstruction<D>),
    Tanh(UnaryInstruction<D>),
    Powf(BinaryInstruction<D>),
    Powi(BinaryInstruction<D>),
    Sqrt(UnaryInstruction<D>),
    Min(BinaryInstruction<D>),
    Max(BinaryInstruction<D>),
    Not(UnaryInstruction<D>),
    Or(BinaryInstruction<D>),
    And(BinaryInstruction<D>),
    Clamp {
        input: Variable<D>,
        min_value: Variable<D>,
        max_value: Variable<D>,
        out: Variable<D>,
    },
    IsNan(UnaryInstruction<D>),
    IsInf(UnaryInstruction<D>),
    SyncThreads,
    SyncWarp,
    ThreadFence,
    ProxySharedFence,
    BulkCommitGroup,
    BulkWaitGroup {
        max_pending: u32,
    },
    BulkWaitGroupRead {
        max_pending: u32,
    },
    TmaReplacePointer {
        buffer: Variable<D>,
        offset: Variable<D>,
        tensor_map: Variable<D>,
        out: Variable<D>,
    },
    Round(UnaryInstruction<D>),
    Ceil(UnaryInstruction<D>),
    Trunc(UnaryInstruction<D>),
    Floor(UnaryInstruction<D>),
    Warp(WarpInstruction<D>),
    Wmma(WmmaInstruction<D>),
    Bitcast(UnaryInstruction<D>),
    AtomicLoad(UnaryInstruction<D>),
    AtomicStore(UnaryInstruction<D>),
    AtomicSwap(BinaryInstruction<D>),
    AtomicAdd(BinaryInstruction<D>),
    AtomicSub(BinaryInstruction<D>),
    AtomicMax(BinaryInstruction<D>),
    AtomicMin(BinaryInstruction<D>),
    AtomicAnd(BinaryInstruction<D>),
    AtomicOr(BinaryInstruction<D>),
    AtomicXor(BinaryInstruction<D>),
    AtomicCAS {
        input: Variable<D>,
        cmp: Variable<D>,
        val: Variable<D>,
        out: Variable<D>,
    },
    Neg(UnaryInstruction<D>),
    Magnitude(UnaryInstruction<D>),
    Normalize(UnaryInstruction<D>),
    Dot(BinaryInstruction<D>),
    Copy {
        input: Variable<D>,
        in_index: Variable<D>,
        out: Variable<D>,
        out_index: Variable<D>,
    },
    CopyBulk {
        input: Variable<D>,
        in_index: Variable<D>,
        out: Variable<D>,
        out_index: Variable<D>,
        len: u32,
    },
    Printf {
        format_string: String,
        args: Vec<Variable<D>>,
    },
    Comment {
        content: String,
    },
    Pipeline(PipelineOps<D>),
    Barrier(BarrierOps<D>),
    MemCopyAsyncTensorSharedToGlobal {
        smem_buffer: Variable<D>,
        smem_offset: Variable<D>,
        tensor_map: Variable<D>,
        indices: Vec<Variable<D>>,
    },
    Line {
        file: Cow<'static, str>,
        line: u32,
    },
}
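
// Renders each instruction as C++ source, delegating dialect-specific pieces
// (atomics, synchronization, printf, special casts, ...) to the `Dialect` `D`.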
impl<D: Dialect> Display for Instruction<D> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Instruction::Return => f.write_str("return;"),
            Instruction::Break => f.write_str("break;"),
            Instruction::DeclareVariable { var } => match var {
                Variable::WmmaFragment { .. } => D::compile_wmma_fragment_declaration(f, var),
                _ => {
                    let item = var.item();
                    writeln!(f, "{item} {var};")
                }
            },
            Instruction::Add(it) => Add::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::SaturatingAdd(it) => SaturatingAdd::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Slice {
                input,
                start,
                end,
                out,
            } => {
                let item = out.item();
                let addr_space = D::address_space_for_variable(input);
                writeln!(f, "const uint {out}_length = {end} - {start};")?;
                writeln!(f, "{addr_space}{item} *{out} = {input} + {start};")
            }
            Instruction::CheckedSlice {
                input,
                start,
                end,
                out,
                len,
            } => {
                let item = out.item();
                let addr_space = D::address_space_for_variable(input);
                writeln!(f, "const uint {out}_length = min({len}, {end}) - {start};")?;
                writeln!(f, "{addr_space}{item} *{out} = {input} + {start};")
            }
            Instruction::ReinterpretSlice {
                input,
                line_size,
                out,
            } => {
                let mut item = out.item();
                item.vectorization = *line_size as usize;
                let addr_space = D::address_space_for_variable(input);

                writeln!(
                    f,
                    "{addr_space}{item} *{out} = reinterpret_cast<{item}*>({input});"
                )
            }
            Instruction::Mul(it) => Mul::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Div(it) => Div::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Sub(it) => Sub::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::SaturatingSub(it) => SaturatingSub::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::HiMul(it) => HiMul::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Modulo(inst) => Modulo::format(f, &inst.lhs, &inst.rhs, &inst.out),
            Instruction::BitwiseOr(it) => BitwiseOr::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::BitwiseAnd(it) => BitwiseAnd::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::BitwiseXor(it) => BitwiseXor::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::CountBits(it) => CountBits::format(f, &it.input, &it.out),
            Instruction::ReverseBits(it) => ReverseBits::format(f, &it.input, &it.out),
            Instruction::LeadingZeros(it) => LeadingZeros::format(f, &it.input, &it.out),
            Instruction::FindFirstSet(it) => FindFirstSet::format(f, &it.input, &it.out),
            Instruction::ShiftLeft(it) => ShiftLeft::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::ShiftRight(it) => ShiftRight::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Index(it) => Index::format(f, &it.list, &it.index, &it.out, it.line_size),
            Instruction::IndexAssign(it) => {
                IndexAssign::format(f, &it.index, &it.value, &it.out, it.line_size)
            }
            Instruction::Copy {
                input,
                in_index,
                out,
                out_index,
            } => {
                writeln!(f, "{out}[{out_index}] = {input}[{in_index}];")
            }
            Instruction::CopyBulk {
                input,
                in_index,
                out,
                out_index,
                len,
            } => {
                for i in 0..*len {
                    writeln!(f, "{out}[{out_index} + {i}] = {input}[{in_index} + {i}];")?;
                }
                Ok(())
            }
            Instruction::Assign(it) => Assign::format(f, &it.input, &it.out),
            Instruction::RangeLoop {
                i,
                start,
                end,
                step,
                inclusive,
                instructions,
            } => {
                let increment = step
                    .map(|step| format!("{i} += {step}"))
                    .unwrap_or_else(|| format!("++{i}"));
                let cmp = if *inclusive { "<=" } else { "<" };
                let i_ty = i.item();

                write!(
                    f,
                    "
for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{
"
                )?;
                for instruction in instructions {
                    write!(f, "{instruction}")?;
                }

                f.write_str("}\n")
            }
            Instruction::Loop { instructions } => {
                writeln!(f, "while (true) {{")?;
                for i in instructions {
                    write!(f, "{i}")?;
                }
                f.write_str("}\n")
            }
            Instruction::If { cond, instructions } => {
                writeln!(f, "if ({cond}) {{")?;
                for i in instructions {
                    write!(f, "{i}")?;
                }
                f.write_str("}\n")
            }
            Instruction::IfElse {
                cond,
                instructions_if,
                instructions_else,
            } => {
                writeln!(f, "if ({cond}) {{")?;
                for i in instructions_if {
                    write!(f, "{i}")?;
                }
                f.write_str("} else {\n")?;
                for i in instructions_else {
                    write!(f, "{i}")?;
                }
                f.write_str("}\n")
            }
            Instruction::Select {
                cond,
                then,
                or_else,
                out,
            } => {
                let item_or_else = or_else.item();
                let item_then = then.item();
                let item_out = out.item();

                let vf_then = item_then.vectorization;
                let vf_or_else = item_or_else.vectorization;
                let vf_out = item_out.vectorization;
                let vf_cond = cond.item().vectorization;

                let item_out = out.item();
                let cond_elem = cond.item().elem;
                let out = out.fmt_left();

                let should_broadcast =
                    vf_cond > 1 || item_out != item_or_else || item_out != item_then;

                if should_broadcast {
                    let vf = usize::max(vf_cond, vf_out);
                    let vf = usize::max(vf, vf_then);
                    let vf = usize::max(vf, vf_or_else);

                    writeln!(f, "{out} = {item_out} {{")?;
                    for i in 0..vf {
                        let theni = then.index(i);
                        let or_elsei = or_else.index(i);
                        let condi = cond.index(i);
                        let condi = EnsureBoolArg {
                            var: &condi,
                            elem: &cond_elem,
                        };

                        writeln!(f, "({condi}) ? {theni} : {or_elsei},")?;
                    }

                    writeln!(f, "}};")
                } else {
                    let cond = EnsureBoolArg {
                        var: &cond,
                        elem: &cond_elem,
                    };
                    writeln!(f, "{out} = ({cond}) ? {then} : {or_else};")
                }
            }
            Instruction::Switch {
                value,
                instructions_default,
                instructions_cases,
            } => {
                writeln!(f, "switch({value}) {{")?;
                for (value, block) in instructions_cases {
                    write!(f, "case {value}:\n{{\n")?;
                    for i in block {
                        i.fmt(f)?;
                    }
                    f.write_str("break;\n}\n")?;
                }
                f.write_str("default:\n{")?;
                for i in instructions_default {
                    i.fmt(f)?;
                }
                f.write_str("}\n}\n")
            }
            Instruction::Metadata {
                info_offset,
                split_meta,
                out,
            } => {
                let out = out.fmt_left();
                match *split_meta {
                    true => writeln!(f, "{out} = {STATIC_INFO_NAME}.x[{info_offset}];"),
                    false => writeln!(f, "{out} = {INFO_NAME}[{info_offset}];"),
                }
            }
            Instruction::ExtendedMetadata {
                info_offset,
                dim,
                split_meta,
                static_offset,
                out,
            } => {
                let out = out.fmt_left();
                match *split_meta {
                    true => writeln!(
                        f,
                        "{out} = {INFO_NAME}[{STATIC_INFO_NAME}.x[{info_offset}] + {dim} - {static_offset}];"
                    ),
                    false => writeln!(
                        f,
                        "{out} = {INFO_NAME}[{INFO_NAME}[{info_offset}] + {dim}];"
                    ),
                }
            }
            Instruction::Equal(it) => Equal::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::NotEqual(it) => NotEqual::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Lower(it) => Lower::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Greater(it) => Greater::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::LowerEqual(it) => LowerEqual::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::GreaterEqual(it) => GreaterEqual::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Erf(it) => Erf::format(f, &it.input, &it.out),
            Instruction::Abs(it) => Abs::format(f, &it.input, &it.out),
            Instruction::Exp(it) => Exp::format(f, &it.input, &it.out),
            Instruction::Log(it) => Log::format(f, &it.input, &it.out),
            Instruction::Log1p(it) => Log1p::format(f, &it.input, &it.out),
            Instruction::Cos(it) => Cos::format(f, &it.input, &it.out),
            Instruction::Sin(it) => Sin::format(f, &it.input, &it.out),
            Instruction::Tanh(it) => Tanh::format(f, &it.input, &it.out),
            Instruction::Powf(it) => Powf::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Powi(it) => Powi::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Sqrt(it) => Sqrt::format(f, &it.input, &it.out),
            Instruction::Max(it) => Max::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Min(it) => Min::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Not(it) => Not::format(f, &it.input, &it.out),
            Instruction::BitwiseNot(it) => BitwiseNot::format(f, &it.input, &it.out),
            Instruction::Or(it) => Or::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::And(it) => And::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Clamp {
                input,
                min_value,
                max_value,
                out,
            } => Clamp::format(f, input, min_value, max_value, out),
            Instruction::IsNan(it) => IsNan::format(f, &it.input, &it.out),
            Instruction::IsInf(it) => IsInf::format(f, &it.input, &it.out),
            Instruction::SyncThreads => D::compile_instruction_sync_threads(f),
            Instruction::SyncWarp => D::compile_instruction_sync_warp(f),
            Instruction::ThreadFence => f.write_str("__threadfence();\n"),
            Instruction::Round(it) => Round::format(f, &it.input, &it.out),
            Instruction::Ceil(it) => Ceil::format(f, &it.input, &it.out),
            Instruction::Trunc(it) => Trunc::format(f, &it.input, &it.out),
            Instruction::Floor(it) => Floor::format(f, &it.input, &it.out),
            Instruction::SliceLength { input, out } => {
                let out = out.fmt_left();
                writeln!(f, "{out} = {input}_length;")
            }
            Instruction::ConstLength { length, out } => {
                let out = out.fmt_left();
                writeln!(f, "{out} = {length};")
            }
            Instruction::Warp(it) => write!(f, "{it}"),
            Instruction::Fma { a, b, c, out } => Fma::format(f, a, b, c, out),
            Instruction::Wmma(it) => write!(f, "{it}"),
            Instruction::Bitcast(UnaryInstruction { input, out }) => {
                let qualifier = out.const_qualifier();
                let input_item = input.item();
                let out_item = out.item();

                if out_item.elem.size() * out_item.vectorization
                    != input.item().elem.size() * input.item().vectorization
                {
                    panic!("Unsupported type for bitcasting {out_item:?} from {input_item:?}");
                } else {
                    let out = out.fmt_left();
                    let addr_space = D::address_space_for_variable(input);
                    writeln!(
                        f,
                        "{out} = reinterpret_cast<{addr_space}{out_item}{qualifier}&>({input});"
                    )
                }
            }
            Instruction::AtomicAdd(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_add(f, lhs, rhs, out)
            }
            Instruction::AtomicAnd(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_and(f, lhs, rhs, out)
            }
            Instruction::AtomicCAS {
                input,
                cmp,
                val,
                out,
            } => D::compile_atomic_cas(f, input, cmp, val, out),
            Instruction::AtomicLoad(UnaryInstruction { input, out }) => {
                D::compile_atomic_load(f, input, out)
            }
            Instruction::AtomicMax(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_max(f, lhs, rhs, out)
            }
            Instruction::AtomicMin(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_min(f, lhs, rhs, out)
            }
            Instruction::AtomicOr(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_or(f, lhs, rhs, out)
            }
            Instruction::AtomicStore(UnaryInstruction { input, out }) => {
                D::compile_atomic_store(f, input, out)
            }
            Instruction::AtomicSub(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_sub(f, lhs, rhs, out)
            }
            Instruction::AtomicSwap(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_swap(f, lhs, rhs, out)
            }
            Instruction::AtomicXor(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_xor(f, lhs, rhs, out)
            }
            Instruction::Remainder(inst) => Remainder::format(f, &inst.lhs, &inst.rhs, &inst.out),
            Instruction::Neg(UnaryInstruction { input, out }) => {
                let out = out.fmt_left();
                writeln!(f, "{out} = -{input};")
            }
            Instruction::Normalize(inst) => Normalize::format(f, &inst.input, &inst.out),
            Instruction::Magnitude(inst) => Magnitude::format(f, &inst.input, &inst.out),
            Instruction::Dot(inst) => Dot::format(f, &inst.lhs, &inst.rhs, &inst.out),
            Instruction::VecInit { inputs, out } => {
                let item = out.item();
                let inputs = inputs
                    .iter()
                    .map(|input| format!("{input}"))
                    .collect::<Vec<_>>();
                let out = out.fmt_left();
                writeln!(f, "{out} = {item}{{{}}};", inputs.join(","))
            }
            Instruction::Printf {
                format_string,
                args,
            } => D::compile_instruction_printf(f, format_string, args),
            Instruction::Comment { content } => {
                if content.contains('\n') {
                    writeln!(f, "/* {content} */")
                } else {
                    writeln!(f, "// {content}")
                }
            }
            Instruction::Pipeline(pipeline_ops) => write!(f, "{pipeline_ops}"),
            Instruction::Barrier(barrier_ops) => write!(f, "{barrier_ops}"),
            Instruction::Line { file, line } => writeln!(f, "#line {line} \"{file}\""),
            Instruction::ProxySharedFence => {
                writeln!(
                    f,
                    "cuda::device::experimental::fence_proxy_async_shared_cta();"
                )
            }
            Instruction::BulkCommitGroup => writeln!(
                f,
                "cuda::device::experimental::cp_async_bulk_commit_group();"
            ),
            Instruction::BulkWaitGroup { max_pending } => writeln!(
                f,
                "cuda::device::experimental::cp_async_bulk_wait_group<{max_pending}>();"
            ),
            Instruction::BulkWaitGroupRead { max_pending } => writeln!(
                f,
                "cuda::device::experimental::cp_async_bulk_wait_group_read<{max_pending}>();"
            ),
            Instruction::TmaReplacePointer {
                buffer,
                offset,
                tensor_map,
                out,
            } => {
                let pos = Variable::<D>::UnitPos;
                writeln!(f, "__shared__ alignas(128) CUtensorMap {out};")?;
                writeln!(
                    f,
                    "
if({pos} == 0) {{
    {out} = {tensor_map};
    tensormap_replace_global_address({out}, &{buffer}[{offset}]);
}}"
                )?;
                writeln!(f, "__syncthreads();")
            }
            Instruction::MemCopyAsyncTensorSharedToGlobal {
                smem_buffer,
                smem_offset,
                tensor_map,
                indices,
            } => {
                let rank = indices.len();
                let smem_ptr = smem_buffer.fmt_ptr();
                let indices = indices.iter().rev().fold(String::new(), |mut s, it| {
                    let _ = write!(s, "{it}, ");
                    s
                });
                writeln!(
                    f,
                    "cuda::device::experimental::cp_async_bulk_tensor_{rank}d_shared_to_global(&{tensor_map}, {indices} {smem_ptr} + {smem_offset});"
                )
            }
            Instruction::SpecialCast(UnaryInstruction { input, out }) => {
                // Only supported in CUDA so I'm putting it here. Move to dialect if necessary.
                #[cfg(not(feature = "cuda"))]
                {
                    let _ = (input, out);
                    unimplemented!("FP8/FP6/FP4 casting isn't supported outside of CUDA");
                }
                #[cfg(feature = "cuda")]
                crate::cuda::convert::special_cast::<D>(f, input, out)
            }
        }
    }
}
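
/// Formats a fused multiply-add (`fma(a, b, c)`), unrolling per component when
/// the output item is vectorized.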
struct Fma<D: Dialect> {
    _dialect: PhantomData<D>,
}

impl<D: Dialect> Fma<D> {
    fn format(
        f: &mut core::fmt::Formatter<'_>,
        a: &Variable<D>,
        b: &Variable<D>,
        c: &Variable<D>,
        out: &Variable<D>,
    ) -> core::fmt::Result {
        let out_item = out.item();
        let num = out_item.vectorization;

        let out = out.fmt_left();
        if num == 1 {
            writeln!(f, "{out} = fma({a}, {b}, {c});")
        } else {
            writeln!(f, "{out} = {out_item}{{")?;

            for i in 0..num {
                let ai = a.index(i);
                let bi = b.index(i);
                let ci = c.index(i);

                writeln!(f, "fma({ai}, {bi}, {ci}),")?;
            }
            f.write_str("};\n")
        }
    }
}
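
/// Formats a clamp as nested dialect `max`/`min` calls, unrolling per component
/// for vectorized outputs.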
struct Clamp<D: Dialect> {
    _dialect: PhantomData<D>,
}

impl<D: Dialect> Clamp<D> {
    fn format(
        f: &mut core::fmt::Formatter<'_>,
        input: &Variable<D>,
        min_value: &Variable<D>,
        max_value: &Variable<D>,
        out: &Variable<D>,
    ) -> core::fmt::Result {
        let out_item = out.item();
        if out.item().vectorization == 1 {
            let out = out.fmt_left();
            write!(f, "{out} = ")?;
            Self::format_scalar(f, *input, *min_value, *max_value, out_item)?;
            f.write_str(";\n")
        } else {
            Self::unroll_vec(f, input, min_value, max_value, out)
        }
    }

    fn format_scalar(
        f: &mut Formatter<'_>,
        input: impl Component<D>,
        min_value: impl Component<D>,
        max_value: impl Component<D>,
        item: Item<D>,
    ) -> std::fmt::Result {
        D::compile_instruction_max_function_name(f, item)?;
        write!(f, "({min_value}, ")?;
        D::compile_instruction_min_function_name(f, item)?;
        write!(f, "({max_value}, {input}))")
    }

    fn unroll_vec(
        f: &mut core::fmt::Formatter<'_>,
        input: &Variable<D>,
        min_value: &Variable<D>,
        max_value: &Variable<D>,
        out: &Variable<D>,
    ) -> std::fmt::Result {
        let optimized = Variable::optimized_args([*input, *min_value, *max_value, *out]);
        let [input, min_value, max_value, out_optimized] = optimized.args;

        let item_out_original = out.item();
        let item_out_optimized = out_optimized.item();

        let index = match optimized.optimization_factor {
            Some(factor) => item_out_original.vectorization / factor,
            None => item_out_optimized.vectorization,
        };

        let mut write_op = |input: &Variable<D>,
                            min_value: &Variable<D>,
                            max_value: &Variable<D>,
                            out: &Variable<D>,
                            item_out: Item<D>| {
            let out = out.fmt_left();
            writeln!(f, "{out} = {item_out}{{")?;
            for i in 0..index {
                let inputi = input.index(i);
                let min_valuei = min_value.index(i);
                let max_valuei = max_value.index(i);

                Self::format_scalar(f, inputi, min_valuei, max_valuei, item_out)?;
                f.write_str(", ")?;
            }

            f.write_str("};\n")
        };

        if item_out_original == item_out_optimized {
            write_op(&input, &min_value, &max_value, out, item_out_optimized)
        } else {
            let out_tmp = Variable::tmp(item_out_optimized);
            write_op(&input, &min_value, &max_value, &out_tmp, item_out_optimized)?;
            let addr_space = D::address_space_for_variable(out);
            let out = out.fmt_left();

            writeln!(
                f,
                "{out} = reinterpret_cast<{addr_space}{item_out_original}&>({out_tmp});\n"
            )?;

            Ok(())
        }
    }
}
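
/// Formats the floating-point remainder `lhs - rhs * floor(lhs / rhs)`, using the
/// dialect's half-precision `floor` prefix when needed and unrolling vectorized outputs.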
struct Remainder<D: Dialect> {
    _dialect: PhantomData<D>,
}

impl<D: Dialect> Remainder<D> {
    fn format(
        f: &mut core::fmt::Formatter<'_>,
        lhs: &Variable<D>,
        rhs: &Variable<D>,
        out: &Variable<D>,
    ) -> core::fmt::Result {
        let floor = |elem| {
            let prefix = match elem {
                Elem::F16 | Elem::BF16 => D::compile_instruction_half_function_name_prefix(),
                Elem::F16x2 | Elem::BF16x2 => D::compile_instruction_half2_function_name_prefix(),
                _ => "",
            };
            format!("{prefix}floor")
        };

        if out.item().vectorization == 1 {
            let floor = floor(out.elem());

            let out = out.fmt_left();
            return writeln!(f, "{out} = {lhs} - {rhs} * {floor}({lhs} / {rhs});");
        }

        let optimized = Variable::optimized_args([*lhs, *rhs, *out]);
        let [lhs, rhs, out_optimized] = optimized.args;

        let item_out_original = out.item();
        let item_out_optimized = out_optimized.item();

        let index = match optimized.optimization_factor {
            Some(factor) => item_out_original.vectorization / factor,
            None => item_out_optimized.vectorization,
        };

        let floor = floor(*item_out_optimized.elem());

        let mut write_op =
            |lhs: &Variable<D>, rhs: &Variable<D>, out: &Variable<D>, item_out: Item<D>| {
                let out = out.fmt_left();
                writeln!(f, "{out} = {item_out}{{")?;
                for i in 0..index {
                    let lhsi = lhs.index(i);
                    let rhsi = rhs.index(i);

                    writeln!(f, "{lhsi} - {rhsi} * {floor}({lhsi} / {rhsi})")?;
                    f.write_str(", ")?;
                }

                f.write_str("};\n")
            };

        if item_out_original == item_out_optimized {
            write_op(&lhs, &rhs, out, item_out_optimized)
        } else {
            let out_tmp = Variable::tmp(item_out_optimized);

            write_op(&lhs, &rhs, &out_tmp, item_out_optimized)?;

            let addr_space = D::address_space_for_variable(&out_tmp);
            let qualifier = out.const_qualifier();
            let out = out.fmt_left();

            writeln!(
                f,
                "{out} = reinterpret_cast<{addr_space}{item_out_original}{qualifier}&>({out_tmp});\n"
            )?;

            Ok(())
        }
    }
}
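
/// Formats the Euclidean norm of a vector: sums the squared components into a
/// temporary, then takes its square root.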
struct Magnitude<D: Dialect> {
    _dialect: PhantomData<D>,
}

impl<D: Dialect> Magnitude<D> {
    fn format(
        f: &mut core::fmt::Formatter<'_>,
        input: &Variable<D>,
        out: &Variable<D>,
    ) -> core::fmt::Result {
        let num = input.item().vectorization;
        let elem = input.elem();

        let mag = format!("{out}_mag");

        writeln!(f, "{} {mag} = 0.0;", out.item())?;

        for i in 0..num {
            let input_i = input.index(i);
            writeln!(f, "{mag} += {input_i} * {input_i};")?;
        }

        let out = out.fmt_left();
        write!(f, "{out} = ")?;
        Sqrt::format_unary(f, &mag, elem)?;
        f.write_str(";\n")
    }
}
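
/// Formats vector normalization: computes the norm, then divides each component by it.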
struct Normalize<D: Dialect> {
    _dialect: PhantomData<D>,
}

impl<D: Dialect> Normalize<D> {
    fn format(
        f: &mut core::fmt::Formatter<'_>,
        input: &Variable<D>,
        out: &Variable<D>,
    ) -> core::fmt::Result {
        let num = input.item().vectorization;
        let elem = input.elem();
        let norm = format!("{out}_norm");

        let out_item = out.item();
        let out = out.fmt_left();
        writeln!(f, "{elem} {norm} = 0.0;")?;

        for i in 0..num {
            let input_i = input.index(i);
            writeln!(f, "{norm} += {input_i} * {input_i};")?;
        }

        write!(f, "{norm} = ")?;
        Sqrt::format_unary(f, &norm, elem)?;
        f.write_str(";\n")?;

        if num == 1 {
            writeln!(f, "{out} = {input} / {norm};")
        } else {
            write!(f, "{out} = {out_item}{{")?;
            for i in 0..num {
                let input_i = input.index(i);

                writeln!(f, "{input_i} / {norm},")?;
            }

            f.write_str("};\n")
        }
    }
}
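
/// Formats a dot product as the sum of the component-wise products.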
struct Dot<D: Dialect> {
    _dialect: PhantomData<D>,
}

impl<D: Dialect> Dot<D> {
    fn format(
        f: &mut core::fmt::Formatter<'_>,
        lhs: &Variable<D>,
        rhs: &Variable<D>,
        out: &Variable<D>,
    ) -> core::fmt::Result {
        let num = lhs.item().vectorization;

        let muls = (0..num)
            .map(|i| {
                let lhs_i = lhs.index(i);
                let rhs_i = rhs.index(i);
                format!("{lhs_i} * {rhs_i}")
            })
            .collect::<Vec<_>>();

        let out = out.fmt_left();
        writeln!(f, "{out} = {};", muls.join(" + "))
    }
}
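
/// Wraps a condition operand, inserting a `bool(...)` cast when its element type is not `bool`.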
struct EnsureBoolArg<'a, V: Display, D: Dialect> {
    var: &'a V,
    elem: &'a Elem<D>,
}

impl<V: Display, D: Dialect> Display for EnsureBoolArg<'_, V, D> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if self.elem != &Elem::Bool {
            write!(f, "bool({})", self.var)
        } else {
            write!(f, "{}", self.var)
        }
    }
}
1017}