cubecl_cpp/shared/
instruction.rs

1use crate::shared::FmtLeft;
2
3use super::{
4    Component, Dialect, Elem, Item, Variable, WarpInstruction, WmmaInstruction,
5    barrier::BarrierOps, binary::*, pipeline::PipelineOps, unary::*,
6};
7use std::{
8    borrow::Cow,
9    fmt::{Display, Write},
10    marker::PhantomData,
11};
12
/// Identifier of the metadata array indexed by the `Metadata` / `ExtendedMetadata` instructions.
pub(crate) const INFO_NAME: &str = "info";
/// Identifier of the statically-indexed metadata value; its `.x` member is indexed when
/// metadata is split (`split_meta == true`).
pub(crate) const STATIC_INFO_NAME: &str = "static_info";
15
16#[derive(Debug, Clone)]
17pub struct BinaryInstruction<D: Dialect> {
18    pub lhs: Variable<D>,
19    pub rhs: Variable<D>,
20    pub out: Variable<D>,
21}
22
23#[derive(Debug, Clone)]
24pub struct IndexInstruction<D: Dialect> {
25    pub list: Variable<D>,
26    pub index: Variable<D>,
27    pub line_size: u32,
28    pub out: Variable<D>,
29}
30
31#[derive(Debug, Clone)]
32pub struct IndexAssignInstruction<D: Dialect> {
33    pub index: Variable<D>,
34    pub value: Variable<D>,
35    pub line_size: u32,
36    pub out: Variable<D>,
37}
38
39#[derive(Debug, Clone)]
40pub struct UnaryInstruction<D: Dialect> {
41    pub input: Variable<D>,
42    pub out: Variable<D>,
43}
44
45#[derive(Debug, Clone)]
46pub enum Instruction<D: Dialect> {
47    Metadata {
48        info_offset: Variable<D>,
49        split_meta: bool,
50        out: Variable<D>,
51    },
52    ExtendedMetadata {
53        info_offset: Variable<D>,
54        dim: Variable<D>,
55        split_meta: bool,
56        static_offset: u32,
57        out: Variable<D>,
58    },
59    ConstLength {
60        length: u32,
61        out: Variable<D>,
62    },
63    SliceLength {
64        input: Variable<D>,
65        out: Variable<D>,
66    },
67    DeclareVariable {
68        var: Variable<D>,
69    },
70    Modulo(BinaryInstruction<D>),
71    Remainder(BinaryInstruction<D>),
72    Add(BinaryInstruction<D>),
73    Fma {
74        a: Variable<D>,
75        b: Variable<D>,
76        c: Variable<D>,
77        out: Variable<D>,
78    },
79    Div(BinaryInstruction<D>),
80    Mul(BinaryInstruction<D>),
81    Sub(BinaryInstruction<D>),
82    HiMul(BinaryInstruction<D>),
83    Index(IndexInstruction<D>),
84    IndexAssign(IndexAssignInstruction<D>),
85    Assign(UnaryInstruction<D>),
86    SpecialCast(UnaryInstruction<D>),
87    RangeLoop {
88        i: Variable<D>,
89        start: Variable<D>,
90        end: Variable<D>,
91        step: Option<Variable<D>>,
92        inclusive: bool,
93        instructions: Vec<Self>,
94    },
95    VecInit {
96        inputs: Vec<Variable<D>>,
97        out: Variable<D>,
98    },
99    Loop {
100        instructions: Vec<Self>,
101    },
102    If {
103        cond: Variable<D>,
104        instructions: Vec<Self>,
105    },
106    IfElse {
107        cond: Variable<D>,
108        instructions_if: Vec<Self>,
109        instructions_else: Vec<Self>,
110    },
111    Select {
112        cond: Variable<D>,
113        then: Variable<D>,
114        or_else: Variable<D>,
115        out: Variable<D>,
116    },
117    Switch {
118        value: Variable<D>,
119        instructions_default: Vec<Self>,
120        instructions_cases: Vec<(Variable<D>, Vec<Self>)>,
121    },
122    Slice {
123        input: Variable<D>,
124        start: Variable<D>,
125        end: Variable<D>,
126        out: Variable<D>,
127    },
128    CheckedSlice {
129        input: Variable<D>,
130        start: Variable<D>,
131        end: Variable<D>,
132        out: Variable<D>,
133        len: Variable<D>,
134    },
135    ReinterpretSlice {
136        input: Variable<D>,
137        line_size: u32,
138        out: Variable<D>,
139    },
140    Return,
141    Break,
142    Equal(BinaryInstruction<D>),
143    NotEqual(BinaryInstruction<D>),
144    Lower(BinaryInstruction<D>),
145    Greater(BinaryInstruction<D>),
146    LowerEqual(BinaryInstruction<D>),
147    GreaterEqual(BinaryInstruction<D>),
148    Erf(UnaryInstruction<D>),
149    BitwiseOr(BinaryInstruction<D>),
150    BitwiseAnd(BinaryInstruction<D>),
151    BitwiseXor(BinaryInstruction<D>),
152    CountBits(UnaryInstruction<D>),
153    ReverseBits(UnaryInstruction<D>),
154    ShiftLeft(BinaryInstruction<D>),
155    ShiftRight(BinaryInstruction<D>),
156    BitwiseNot(UnaryInstruction<D>),
157    LeadingZeros(UnaryInstruction<D>),
158    FindFirstSet(UnaryInstruction<D>),
159    Abs(UnaryInstruction<D>),
160    Exp(UnaryInstruction<D>),
161    Log(UnaryInstruction<D>),
162    Log1p(UnaryInstruction<D>),
163    Cos(UnaryInstruction<D>),
164    Sin(UnaryInstruction<D>),
165    Tanh(UnaryInstruction<D>),
166    Powf(BinaryInstruction<D>),
167    Sqrt(UnaryInstruction<D>),
168    Min(BinaryInstruction<D>),
169    Max(BinaryInstruction<D>),
170    Not(UnaryInstruction<D>),
171    Or(BinaryInstruction<D>),
172    And(BinaryInstruction<D>),
173    Clamp {
174        input: Variable<D>,
175        min_value: Variable<D>,
176        max_value: Variable<D>,
177        out: Variable<D>,
178    },
179    SyncThreads,
180    SyncWarp,
181    ThreadFence,
182    ProxySharedFence,
183    BulkCommitGroup,
184    BulkWaitGroup {
185        max_pending: u32,
186    },
187    BulkWaitGroupRead {
188        max_pending: u32,
189    },
190    Round(UnaryInstruction<D>),
191    Ceil(UnaryInstruction<D>),
192    Floor(UnaryInstruction<D>),
193    Warp(WarpInstruction<D>),
194    Wmma(WmmaInstruction<D>),
195    Bitcast(UnaryInstruction<D>),
196    AtomicLoad(UnaryInstruction<D>),
197    AtomicStore(UnaryInstruction<D>),
198    AtomicSwap(BinaryInstruction<D>),
199    AtomicAdd(BinaryInstruction<D>),
200    AtomicSub(BinaryInstruction<D>),
201    AtomicMax(BinaryInstruction<D>),
202    AtomicMin(BinaryInstruction<D>),
203    AtomicAnd(BinaryInstruction<D>),
204    AtomicOr(BinaryInstruction<D>),
205    AtomicXor(BinaryInstruction<D>),
206    AtomicCAS {
207        input: Variable<D>,
208        cmp: Variable<D>,
209        val: Variable<D>,
210        out: Variable<D>,
211    },
212    Neg(UnaryInstruction<D>),
213    Magnitude(UnaryInstruction<D>),
214    Normalize(UnaryInstruction<D>),
215    Dot(BinaryInstruction<D>),
216    Copy {
217        input: Variable<D>,
218        in_index: Variable<D>,
219        out: Variable<D>,
220        out_index: Variable<D>,
221    },
222    CopyBulk {
223        input: Variable<D>,
224        in_index: Variable<D>,
225        out: Variable<D>,
226        out_index: Variable<D>,
227        len: u32,
228    },
229    Printf {
230        format_string: String,
231        args: Vec<Variable<D>>,
232    },
233    Comment {
234        content: String,
235    },
236    Pipeline(PipelineOps<D>),
237    Barrier(BarrierOps<D>),
238    MemCopyAsyncTensorSharedToGlobal {
239        smem_buffer: Variable<D>,
240        smem_offset: Variable<D>,
241        tensor_map: Variable<D>,
242        indices: Vec<Variable<D>>,
243    },
244    Line {
245        file: Cow<'static, str>,
246        line: u32,
247    },
248}
249
250impl<D: Dialect> Display for Instruction<D> {
251    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
252        match self {
253            Instruction::Return => f.write_str("return;"),
254            Instruction::Break => f.write_str("break;"),
255            Instruction::DeclareVariable { var } => match var {
256                Variable::WmmaFragment { .. } => D::compile_wmma_fragment_declaration(f, var),
257                _ => {
258                    let item = var.item();
259                    writeln!(f, "{item} {var};")
260                }
261            },
262            Instruction::Add(it) => Add::format(f, &it.lhs, &it.rhs, &it.out),
263            Instruction::Slice {
264                input,
265                start,
266                end,
267                out,
268            } => {
269                let item = out.item();
270                let addr_space = D::address_space_for_variable(input);
271                writeln!(f, "const uint {out}_length = {end} - {start};")?;
272                writeln!(f, "{addr_space}{item} *{out} = {input} + {start};")
273            }
274            Instruction::CheckedSlice {
275                input,
276                start,
277                end,
278                out,
279                len,
280            } => {
281                let item = out.item();
282                let addr_space = D::address_space_for_variable(input);
283                writeln!(f, "const uint {out}_length = min({len}, {end}) - {start};")?;
284                writeln!(f, "{addr_space}{item} *{out} = {input} + {start};")
285            }
286            Instruction::ReinterpretSlice {
287                input,
288                line_size,
289                out,
290            } => {
291                let mut item = out.item();
292                item.vectorization = *line_size as usize;
293                let addr_space = D::address_space_for_variable(input);
294
295                writeln!(
296                    f,
297                    "{addr_space}{item} *{out} = reinterpret_cast<{item}*>({input});"
298                )
299            }
300            Instruction::Mul(it) => Mul::format(f, &it.lhs, &it.rhs, &it.out),
301            Instruction::Div(it) => Div::format(f, &it.lhs, &it.rhs, &it.out),
302            Instruction::Sub(it) => Sub::format(f, &it.lhs, &it.rhs, &it.out),
303            Instruction::HiMul(it) => HiMul::format(f, &it.lhs, &it.rhs, &it.out),
304            Instruction::Modulo(inst) => Modulo::format(f, &inst.lhs, &inst.rhs, &inst.out),
305            Instruction::BitwiseOr(it) => BitwiseOr::format(f, &it.lhs, &it.rhs, &it.out),
306            Instruction::BitwiseAnd(it) => BitwiseAnd::format(f, &it.lhs, &it.rhs, &it.out),
307            Instruction::BitwiseXor(it) => BitwiseXor::format(f, &it.lhs, &it.rhs, &it.out),
308            Instruction::CountBits(it) => CountBits::format(f, &it.input, &it.out),
309            Instruction::ReverseBits(it) => ReverseBits::format(f, &it.input, &it.out),
310            Instruction::LeadingZeros(it) => LeadingZeros::format(f, &it.input, &it.out),
311            Instruction::FindFirstSet(it) => FindFirstSet::format(f, &it.input, &it.out),
312            Instruction::ShiftLeft(it) => ShiftLeft::format(f, &it.lhs, &it.rhs, &it.out),
313            Instruction::ShiftRight(it) => ShiftRight::format(f, &it.lhs, &it.rhs, &it.out),
314            Instruction::Index(it) => Index::format(f, &it.list, &it.index, &it.out, it.line_size),
315            Instruction::IndexAssign(it) => {
316                IndexAssign::format(f, &it.index, &it.value, &it.out, it.line_size)
317            }
318            Instruction::Copy {
319                input,
320                in_index,
321                out,
322                out_index,
323            } => {
324                writeln!(f, "{out}[{out_index}] = {input}[{in_index}];")
325            }
326            Instruction::CopyBulk {
327                input,
328                in_index,
329                out,
330                out_index,
331                len,
332            } => {
333                for i in 0..*len {
334                    writeln!(f, "{out}[{out_index} + {i}] = {input}[{in_index} + {i}];")?;
335                }
336                Ok(())
337            }
338            Instruction::Assign(it) => Assign::format(f, &it.input, &it.out),
339            Instruction::RangeLoop {
340                i,
341                start,
342                end,
343                step,
344                inclusive,
345                instructions,
346            } => {
347                let increment = step
348                    .map(|step| format!("{i} += {step}"))
349                    .unwrap_or_else(|| format!("++{i}"));
350                let cmp = if *inclusive { "<=" } else { "<" };
351                let i_ty = i.item();
352
353                write!(
354                    f,
355                    "
356for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{
357"
358                )?;
359                for instruction in instructions {
360                    write!(f, "{instruction}")?;
361                }
362
363                f.write_str("}\n")
364            }
365            Instruction::Loop { instructions } => {
366                writeln!(f, "while (true) {{")?;
367                for i in instructions {
368                    write!(f, "{i}")?;
369                }
370                f.write_str("}\n")
371            }
372            Instruction::If { cond, instructions } => {
373                writeln!(f, "if ({cond}) {{")?;
374                for i in instructions {
375                    write!(f, "{i}")?;
376                }
377                f.write_str("}\n")
378            }
379            Instruction::IfElse {
380                cond,
381                instructions_if,
382                instructions_else,
383            } => {
384                writeln!(f, "if ({cond}) {{")?;
385                for i in instructions_if {
386                    write!(f, "{i}")?;
387                }
388                f.write_str("} else {\n")?;
389                for i in instructions_else {
390                    write!(f, "{i}")?;
391                }
392                f.write_str("}\n")
393            }
394            Instruction::Select {
395                cond,
396                then,
397                or_else,
398                out,
399            } => {
400                let item_or_else = or_else.item();
401                let item_then = then.item();
402                let item_out = out.item();
403
404                let vf_then = item_then.vectorization;
405                let vf_or_else = item_or_else.vectorization;
406                let vf_out = item_out.vectorization;
407                let vf_cond = cond.item().vectorization;
408
409                let item_out = out.item();
410                let cond_elem = cond.item().elem;
411                let out = out.fmt_left();
412
413                let should_broadcast =
414                    vf_cond > 1 || item_out != item_or_else || item_out != item_then;
415
416                if should_broadcast {
417                    let vf = usize::max(vf_cond, vf_out);
418                    let vf = usize::max(vf, vf_then);
419                    let vf = usize::max(vf, vf_or_else);
420
421                    writeln!(f, "{out} = {item_out} {{")?;
422                    for i in 0..vf {
423                        let theni = then.index(i);
424                        let or_elsei = or_else.index(i);
425                        let condi = cond.index(i);
426                        let condi = EnsureBoolArg {
427                            var: &condi,
428                            elem: &cond_elem,
429                        };
430
431                        writeln!(f, "({condi}) ? {theni} : {or_elsei},")?;
432                    }
433
434                    writeln!(f, "}};")
435                } else {
436                    let cond = EnsureBoolArg {
437                        var: &cond,
438                        elem: &cond_elem,
439                    };
440                    writeln!(f, "{out} = ({cond}) ? {then} : {or_else};")
441                }
442            }
443            Instruction::Switch {
444                value,
445                instructions_default,
446                instructions_cases,
447            } => {
448                writeln!(f, "switch({value}) {{")?;
449                for (value, block) in instructions_cases {
450                    write!(f, "case {value}:\n{{\n")?;
451                    for i in block {
452                        i.fmt(f)?;
453                    }
454                    f.write_str("break;\n}\n")?;
455                }
456                f.write_str("default:\n{")?;
457                for i in instructions_default {
458                    i.fmt(f)?;
459                }
460                f.write_str("}\n}\n")
461            }
462            Instruction::Metadata {
463                info_offset,
464                split_meta,
465                out,
466            } => {
467                let out = out.fmt_left();
468                match *split_meta {
469                    true => writeln!(f, "{out} = static_info.x[{info_offset}];"),
470                    false => writeln!(f, "{out} = {INFO_NAME}[{info_offset}];"),
471                }
472            }
473            Instruction::ExtendedMetadata {
474                info_offset,
475                dim,
476                split_meta,
477                static_offset,
478                out,
479            } => {
480                let out = out.fmt_left();
481                match *split_meta {
482                    true => writeln!(
483                        f,
484                        "{out} = {INFO_NAME}[{STATIC_INFO_NAME}.x[{info_offset}] + {dim} - {static_offset}];"
485                    ),
486                    false => writeln!(
487                        f,
488                        "{out} = {INFO_NAME}[{INFO_NAME}[{info_offset}] + {dim}];"
489                    ),
490                }
491            }
492            Instruction::Equal(it) => Equal::format(f, &it.lhs, &it.rhs, &it.out),
493            Instruction::NotEqual(it) => NotEqual::format(f, &it.lhs, &it.rhs, &it.out),
494            Instruction::Lower(it) => Lower::format(f, &it.lhs, &it.rhs, &it.out),
495            Instruction::Greater(it) => Greater::format(f, &it.lhs, &it.rhs, &it.out),
496            Instruction::LowerEqual(it) => LowerEqual::format(f, &it.lhs, &it.rhs, &it.out),
497            Instruction::GreaterEqual(it) => GreaterEqual::format(f, &it.lhs, &it.rhs, &it.out),
498            Instruction::Erf(it) => Erf::format(f, &it.input, &it.out),
499            Instruction::Abs(it) => Abs::format(f, &it.input, &it.out),
500            Instruction::Exp(it) => Exp::format(f, &it.input, &it.out),
501            Instruction::Log(it) => Log::format(f, &it.input, &it.out),
502            Instruction::Log1p(it) => Log1p::format(f, &it.input, &it.out),
503            Instruction::Cos(it) => Cos::format(f, &it.input, &it.out),
504            Instruction::Sin(it) => Sin::format(f, &it.input, &it.out),
505            Instruction::Tanh(it) => Tanh::format(f, &it.input, &it.out),
506            Instruction::Powf(it) => Powf::format(f, &it.lhs, &it.rhs, &it.out),
507            Instruction::Sqrt(it) => Sqrt::format(f, &it.input, &it.out),
508            Instruction::Max(it) => Max::format(f, &it.lhs, &it.rhs, &it.out),
509            Instruction::Min(it) => Min::format(f, &it.lhs, &it.rhs, &it.out),
510            Instruction::Not(it) => Not::format(f, &it.input, &it.out),
511            Instruction::BitwiseNot(it) => BitwiseNot::format(f, &it.input, &it.out),
512            Instruction::Or(it) => Or::format(f, &it.lhs, &it.rhs, &it.out),
513            Instruction::And(it) => And::format(f, &it.lhs, &it.rhs, &it.out),
514            Instruction::Clamp {
515                input,
516                min_value,
517                max_value,
518                out,
519            } => Clamp::format(f, input, min_value, max_value, out),
520            Instruction::SyncThreads => D::compile_instruction_sync_threads(f),
521            Instruction::SyncWarp => D::compile_instruction_sync_warp(f),
522            Instruction::ThreadFence => f.write_str("__threadfence();\n"),
523            Instruction::Round(it) => Round::format(f, &it.input, &it.out),
524            Instruction::Ceil(it) => Ceil::format(f, &it.input, &it.out),
525            Instruction::Floor(it) => Floor::format(f, &it.input, &it.out),
526            Instruction::SliceLength { input, out } => {
527                let out = out.fmt_left();
528                writeln!(f, "{out} = {input}_length;")
529            }
530            Instruction::ConstLength { length, out } => {
531                let out = out.fmt_left();
532                writeln!(f, "{out} = {length};")
533            }
534            Instruction::Warp(it) => write!(f, "{it}"),
535            Instruction::Fma { a, b, c, out } => Fma::format(f, a, b, c, out),
536            Instruction::Wmma(it) => write!(f, "{it}"),
537            Instruction::Bitcast(UnaryInstruction { input, out }) => {
538                let qualifier = out.const_qualifier();
539                let input_item = input.item();
540                let out_item = out.item();
541
542                if out_item.elem.size() * out_item.vectorization
543                    != input.item().elem.size() * input.item().vectorization
544                {
545                    panic!("Unsupported type for bitcasting {out_item:?} from {input_item:?}");
546                } else {
547                    let out = out.fmt_left();
548                    let addr_space = D::address_space_for_variable(input);
549                    writeln!(
550                        f,
551                        "{out} = reinterpret_cast<{addr_space}{out_item}{qualifier}&>({input});"
552                    )
553                }
554            }
555            Instruction::AtomicAdd(BinaryInstruction { lhs, rhs, out }) => {
556                D::compile_atomic_add(f, lhs, rhs, out)
557            }
558            Instruction::AtomicAnd(BinaryInstruction { lhs, rhs, out }) => {
559                D::compile_atomic_and(f, lhs, rhs, out)
560            }
561            Instruction::AtomicCAS {
562                input,
563                cmp,
564                val,
565                out,
566            } => D::compile_atomic_cas(f, input, cmp, val, out),
567            Instruction::AtomicLoad(UnaryInstruction { input, out }) => {
568                D::compile_atomic_load(f, input, out)
569            }
570            Instruction::AtomicMax(BinaryInstruction { lhs, rhs, out }) => {
571                D::compile_atomic_max(f, lhs, rhs, out)
572            }
573            Instruction::AtomicMin(BinaryInstruction { lhs, rhs, out }) => {
574                D::compile_atomic_min(f, lhs, rhs, out)
575            }
576            Instruction::AtomicOr(BinaryInstruction { lhs, rhs, out }) => {
577                D::compile_atomic_or(f, lhs, rhs, out)
578            }
579            Instruction::AtomicStore(UnaryInstruction { input, out }) => {
580                D::compile_atomic_store(f, input, out)
581            }
582            Instruction::AtomicSub(BinaryInstruction { lhs, rhs, out }) => {
583                D::compile_atomic_sub(f, lhs, rhs, out)
584            }
585            Instruction::AtomicSwap(BinaryInstruction { lhs, rhs, out }) => {
586                D::compile_atomic_swap(f, lhs, rhs, out)
587            }
588            Instruction::AtomicXor(BinaryInstruction { lhs, rhs, out }) => {
589                D::compile_atomic_xor(f, lhs, rhs, out)
590            }
591            Instruction::Remainder(inst) => Remainder::format(f, &inst.lhs, &inst.rhs, &inst.out),
592            Instruction::Neg(UnaryInstruction { input, out }) => {
593                let out = out.fmt_left();
594                writeln!(f, "{out} = -{input};")
595            }
596            Instruction::Normalize(inst) => Normalize::format(f, &inst.input, &inst.out),
597            Instruction::Magnitude(inst) => Magnitude::format(f, &inst.input, &inst.out),
598            Instruction::Dot(inst) => Dot::format(f, &inst.lhs, &inst.rhs, &inst.out),
599            Instruction::VecInit { inputs, out } => {
600                let item = out.item();
601                let inputs = inputs
602                    .iter()
603                    .map(|input| format!("{input}"))
604                    .collect::<Vec<_>>();
605                let out = out.fmt_left();
606                writeln!(f, "{out} = {item}{{{}}};", inputs.join(","))
607            }
608            Instruction::Printf {
609                format_string,
610                args,
611            } => D::compile_instruction_printf(f, format_string, args),
612            Instruction::Comment { content } => {
613                if content.contains('\n') {
614                    writeln!(f, "/* {content} */")
615                } else {
616                    writeln!(f, "// {content}")
617                }
618            }
619            Instruction::Pipeline(pipeline_ops) => write!(f, "{pipeline_ops}"),
620            Instruction::Barrier(barrier_ops) => write!(f, "{barrier_ops}"),
621            Instruction::Line { file, line } => writeln!(f, "#line {line} \"{file}\""),
622            Instruction::ProxySharedFence => {
623                writeln!(
624                    f,
625                    "cuda::device::experimental::fence_proxy_async_shared_cta();"
626                )
627            }
628            Instruction::BulkCommitGroup => writeln!(
629                f,
630                "cuda::device::experimental::cp_async_bulk_commit_group();"
631            ),
632            Instruction::BulkWaitGroup { max_pending } => writeln!(
633                f,
634                "cuda::device::experimental::cp_async_bulk_wait_group<{max_pending}>();"
635            ),
636            Instruction::BulkWaitGroupRead { max_pending } => writeln!(
637                f,
638                "cuda::device::experimental::cp_async_bulk_wait_group_read<{max_pending}>();"
639            ),
640            Instruction::MemCopyAsyncTensorSharedToGlobal {
641                smem_buffer,
642                smem_offset,
643                tensor_map,
644                indices,
645            } => {
646                let rank = indices.len();
647                let smem_ptr = smem_buffer.fmt_ptr();
648                let indices = indices.iter().rev().fold(String::new(), |mut s, it| {
649                    let _ = write!(s, "{it}, ");
650                    s
651                });
652                writeln!(
653                    f,
654                    "cuda::device::experimental::cp_async_bulk_tensor_{rank}d_shared_to_global(&{tensor_map}, {indices} {smem_ptr} + {smem_offset});"
655                )
656            }
657            Instruction::SpecialCast(UnaryInstruction { input, out }) => {
658                // Only supported in CUDA so I'm putting it here. Move to dialect if necessary.
659                #[cfg(not(feature = "cuda"))]
660                {
661                    let _ = (input, out);
662                    unimplemented!("FP8/FP6/FP4 casting isn't supported outside of CUDA");
663                }
664                #[cfg(feature = "cuda")]
665                crate::cuda::convert::special_cast::<D>(f, input, out)
666            }
667        }
668    }
669}
670
671struct Fma<D: Dialect> {
672    _dialect: PhantomData<D>,
673}
674
675impl<D: Dialect> Fma<D> {
676    fn format(
677        f: &mut core::fmt::Formatter<'_>,
678        a: &Variable<D>,
679        b: &Variable<D>,
680        c: &Variable<D>,
681        out: &Variable<D>,
682    ) -> core::fmt::Result {
683        let out_item = out.item();
684        let num = out_item.vectorization;
685
686        let out = out.fmt_left();
687        if num == 1 {
688            writeln!(f, "{out} = fma({a}, {b}, {c});")
689        } else {
690            writeln!(f, "{out} = {out_item}{{")?;
691
692            for i in 0..num {
693                let ai = a.index(i);
694                let bi = b.index(i);
695                let ci = c.index(i);
696
697                writeln!(f, "fma({ai}, {bi}, {ci}),")?;
698            }
699            f.write_str("};\n")
700        }
701    }
702}
703
704struct Clamp<D: Dialect> {
705    _dialect: PhantomData<D>,
706}
707
708impl<D: Dialect> Clamp<D> {
709    fn format(
710        f: &mut core::fmt::Formatter<'_>,
711        input: &Variable<D>,
712        min_value: &Variable<D>,
713        max_value: &Variable<D>,
714        out: &Variable<D>,
715    ) -> core::fmt::Result {
716        let input = input.optimized();
717        let min_value = min_value.optimized();
718        let max_value = max_value.optimized();
719        let out = out.optimized();
720        let out_item = out.item();
721        let num = out_item.vectorization;
722
723        let out_fmt = out.fmt_left();
724        if num == 1 {
725            writeln!(f, "{out_fmt} = ")?;
726            D::compile_instruction_max_function_name(f, out.item())?;
727            writeln!(f, "({min_value}, ")?;
728            D::compile_instruction_min_function_name(f, out.item())?;
729            writeln!(f, "({max_value}, {input}));")
730        } else {
731            writeln!(f, "{out_fmt} = {out_item}{{")?;
732            let mut item = out.item();
733            item.vectorization = 1;
734
735            for i in 0..num {
736                let inputi = input.index(i);
737                let mini = min_value.index(i);
738                let maxi = max_value.index(i);
739
740                D::compile_instruction_max_function_name(f, item)?;
741                writeln!(f, "({mini}, ")?;
742                D::compile_instruction_min_function_name(f, item)?;
743                writeln!(f, "({maxi}, {inputi})),")?;
744            }
745
746            f.write_str("};\n")
747        }
748    }
749}
750
751struct Remainder<D: Dialect> {
752    _dialect: PhantomData<D>,
753}
754
755impl<D: Dialect> Remainder<D> {
756    fn format(
757        f: &mut core::fmt::Formatter<'_>,
758        lhs: &Variable<D>,
759        rhs: &Variable<D>,
760        out: &Variable<D>,
761    ) -> core::fmt::Result {
762        let floor = |elem| {
763            let prefix = match elem {
764                Elem::F16 | Elem::BF16 => D::compile_instruction_half_function_name_prefix(),
765                Elem::F16x2 | Elem::BF16x2 => D::compile_instruction_half2_function_name_prefix(),
766                _ => "",
767            };
768            format!("{prefix}floor")
769        };
770
771        if out.item().vectorization == 1 {
772            let floor = floor(out.elem());
773
774            let out = out.fmt_left();
775            return writeln!(f, "{out} = {lhs} - {rhs} * {floor}({lhs} / {rhs});");
776        }
777
778        let optimized = Variable::optimized_args([*lhs, *rhs, *out]);
779        let [lhs, rhs, out_optimized] = optimized.args;
780
781        let item_out_original = out.item();
782        let item_out_optimized = out_optimized.item();
783
784        let index = match optimized.optimization_factor {
785            Some(factor) => item_out_original.vectorization / factor,
786            None => item_out_optimized.vectorization,
787        };
788
789        let floor = floor(*item_out_optimized.elem());
790
791        let mut write_op =
792            |lhs: &Variable<D>, rhs: &Variable<D>, out: &Variable<D>, item_out: Item<D>| {
793                let out = out.fmt_left();
794                writeln!(f, "{out} = {item_out}{{")?;
795                for i in 0..index {
796                    let lhsi = lhs.index(i);
797                    let rhsi = rhs.index(i);
798
799                    writeln!(f, "{lhsi} - {rhsi} * {floor}({lhsi} / {rhsi})")?;
800                    f.write_str(", ")?;
801                }
802
803                f.write_str("};\n")
804            };
805
806        if item_out_original == item_out_optimized {
807            write_op(&lhs, &rhs, out, item_out_optimized)
808        } else {
809            let out_tmp = Variable::tmp(item_out_optimized);
810
811            write_op(&lhs, &rhs, &out_tmp, item_out_optimized)?;
812
813            let addr_space = D::address_space_for_variable(&out_tmp);
814            let qualifier = out.const_qualifier();
815            let out = out.fmt_left();
816
817            writeln!(
818                f,
819                "{out} = reinterpret_cast<{addr_space}{item_out_original}{qualifier}&>({out_tmp});\n"
820            )?;
821
822            Ok(())
823        }
824    }
825}
826
827struct Magnitude<D: Dialect> {
828    _dialect: PhantomData<D>,
829}
830
831impl<D: Dialect> Magnitude<D> {
832    fn format(
833        f: &mut core::fmt::Formatter<'_>,
834        input: &Variable<D>,
835        out: &Variable<D>,
836    ) -> core::fmt::Result {
837        let num = input.item().vectorization;
838        let elem = input.elem();
839
840        let mag = format!("{out}_mag");
841
842        writeln!(f, "{} {mag} = 0.0;", out.item())?;
843
844        for i in 0..num {
845            let input_i = input.index(i);
846            writeln!(f, "{mag} += {input_i} * {input_i};")?;
847        }
848
849        let out = out.fmt_left();
850        write!(f, "{out} = ")?;
851        Sqrt::format_unary(f, &mag, elem)?;
852        f.write_str(";\n")
853    }
854}
855
856struct Normalize<D: Dialect> {
857    _dialect: PhantomData<D>,
858}
859
860impl<D: Dialect> Normalize<D> {
861    fn format(
862        f: &mut core::fmt::Formatter<'_>,
863        input: &Variable<D>,
864        out: &Variable<D>,
865    ) -> core::fmt::Result {
866        let num = input.item().vectorization;
867        let elem = input.elem();
868        let norm = format!("{out}_norm");
869
870        let out_item = out.item();
871        let out = out.fmt_left();
872        writeln!(f, "{elem} {norm} = 0.0;")?;
873
874        for i in 0..num {
875            let input_i = input.index(i);
876            writeln!(f, "{norm} += {input_i} * {input_i};")?;
877        }
878
879        write!(f, "{norm} = ")?;
880        Sqrt::format_unary(f, &norm, elem)?;
881        f.write_str(";\n")?;
882
883        if num == 1 {
884            writeln!(f, "{out} = {input} / {norm};")
885        } else {
886            write!(f, "{out} = {out_item}{{")?;
887            for i in 0..num {
888                let input_i = input.index(i);
889
890                writeln!(f, "{input_i} / {norm},")?;
891            }
892
893            f.write_str("};\n")
894        }
895    }
896}
897
898struct Dot<D: Dialect> {
899    _dialect: PhantomData<D>,
900}
901
902impl<D: Dialect> Dot<D> {
903    fn format(
904        f: &mut core::fmt::Formatter<'_>,
905        lhs: &Variable<D>,
906        rhs: &Variable<D>,
907        out: &Variable<D>,
908    ) -> core::fmt::Result {
909        let num = lhs.item().vectorization;
910
911        let muls = (0..num)
912            .map(|i| {
913                let lhs_i = lhs.index(i);
914                let rhs_i = rhs.index(i);
915                format!("{lhs_i} * {rhs_i}")
916            })
917            .collect::<Vec<_>>();
918
919        let out = out.fmt_left();
920        writeln!(f, "{out} = {};", muls.join(" + "))
921    }
922}
923
924struct EnsureBoolArg<'a, V: Display, D: Dialect> {
925    var: &'a V,
926    elem: &'a Elem<D>,
927}
928
929impl<V: Display, D: Dialect> Display for EnsureBoolArg<'_, V, D> {
930    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
931        if self.elem != &Elem::Bool {
932            write!(f, "bool({})", self.var)
933        } else {
934            write!(f, "{}", self.var)
935        }
936    }
937}