cubecl_cpp/shared/
instruction.rs

1use crate::shared::FmtLeft;
2
3use super::{
4    Component, Dialect, Elem, Item, Variable, WarpInstruction, WmmaInstruction,
5    barrier::BarrierOps, binary::*, pipeline::PipelineOps, unary::*,
6};
7use std::{
8    borrow::Cow,
9    fmt::{Display, Write},
10    marker::PhantomData,
11};
12
/// Identifier of the metadata buffer in the generated source; interpolated
/// into the text emitted for `Instruction::Metadata` and
/// `Instruction::ExtendedMetadata`.
pub(crate) const INFO_NAME: &str = "info";
/// Identifier of the static metadata object in the generated source; it is
/// indexed through its `.x` member when metadata is split.
pub(crate) const STATIC_INFO_NAME: &str = "static_info";
15
16#[derive(Debug, Clone)]
17pub struct BinaryInstruction<D: Dialect> {
18    pub lhs: Variable<D>,
19    pub rhs: Variable<D>,
20    pub out: Variable<D>,
21}
22
23#[derive(Debug, Clone)]
24pub struct UnaryInstruction<D: Dialect> {
25    pub input: Variable<D>,
26    pub out: Variable<D>,
27}
28
29#[derive(Debug, Clone)]
30pub enum Instruction<D: Dialect> {
31    Metadata {
32        info_offset: Variable<D>,
33        split_meta: bool,
34        out: Variable<D>,
35    },
36    ExtendedMetadata {
37        info_offset: Variable<D>,
38        dim: Variable<D>,
39        split_meta: bool,
40        static_offset: u32,
41        out: Variable<D>,
42    },
43    ConstLength {
44        length: u32,
45        out: Variable<D>,
46    },
47    SliceLength {
48        input: Variable<D>,
49        out: Variable<D>,
50    },
51    DeclareVariable {
52        var: Variable<D>,
53    },
54    Modulo(BinaryInstruction<D>),
55    Remainder(BinaryInstruction<D>),
56    Add(BinaryInstruction<D>),
57    Fma {
58        a: Variable<D>,
59        b: Variable<D>,
60        c: Variable<D>,
61        out: Variable<D>,
62    },
63    Div(BinaryInstruction<D>),
64    Mul(BinaryInstruction<D>),
65    Sub(BinaryInstruction<D>),
66    HiMul(BinaryInstruction<D>),
67    Index(BinaryInstruction<D>),
68    IndexAssign(BinaryInstruction<D>),
69    Assign(UnaryInstruction<D>),
70    RangeLoop {
71        i: Variable<D>,
72        start: Variable<D>,
73        end: Variable<D>,
74        step: Option<Variable<D>>,
75        inclusive: bool,
76        instructions: Vec<Self>,
77    },
78    VecInit {
79        inputs: Vec<Variable<D>>,
80        out: Variable<D>,
81    },
82    Loop {
83        instructions: Vec<Self>,
84    },
85    If {
86        cond: Variable<D>,
87        instructions: Vec<Self>,
88    },
89    IfElse {
90        cond: Variable<D>,
91        instructions_if: Vec<Self>,
92        instructions_else: Vec<Self>,
93    },
94    Select {
95        cond: Variable<D>,
96        then: Variable<D>,
97        or_else: Variable<D>,
98        out: Variable<D>,
99    },
100    Switch {
101        value: Variable<D>,
102        instructions_default: Vec<Self>,
103        instructions_cases: Vec<(Variable<D>, Vec<Self>)>,
104    },
105    Slice {
106        input: Variable<D>,
107        start: Variable<D>,
108        end: Variable<D>,
109        out: Variable<D>,
110    },
111    CheckedSlice {
112        input: Variable<D>,
113        start: Variable<D>,
114        end: Variable<D>,
115        out: Variable<D>,
116        len: Variable<D>,
117    },
118    ReinterpretSlice {
119        input: Variable<D>,
120        line_size: u32,
121        out: Variable<D>,
122    },
123    Return,
124    Break,
125    Equal(BinaryInstruction<D>),
126    NotEqual(BinaryInstruction<D>),
127    Lower(BinaryInstruction<D>),
128    Greater(BinaryInstruction<D>),
129    LowerEqual(BinaryInstruction<D>),
130    GreaterEqual(BinaryInstruction<D>),
131    Erf(UnaryInstruction<D>),
132    BitwiseOr(BinaryInstruction<D>),
133    BitwiseAnd(BinaryInstruction<D>),
134    BitwiseXor(BinaryInstruction<D>),
135    CountBits(UnaryInstruction<D>),
136    ReverseBits(UnaryInstruction<D>),
137    ShiftLeft(BinaryInstruction<D>),
138    ShiftRight(BinaryInstruction<D>),
139    BitwiseNot(UnaryInstruction<D>),
140    LeadingZeros(UnaryInstruction<D>),
141    FindFirstSet(UnaryInstruction<D>),
142    Abs(UnaryInstruction<D>),
143    Exp(UnaryInstruction<D>),
144    Log(UnaryInstruction<D>),
145    Log1p(UnaryInstruction<D>),
146    Cos(UnaryInstruction<D>),
147    Sin(UnaryInstruction<D>),
148    Tanh(UnaryInstruction<D>),
149    Powf(BinaryInstruction<D>),
150    Sqrt(UnaryInstruction<D>),
151    Min(BinaryInstruction<D>),
152    Max(BinaryInstruction<D>),
153    Not(UnaryInstruction<D>),
154    Or(BinaryInstruction<D>),
155    And(BinaryInstruction<D>),
156    Clamp {
157        input: Variable<D>,
158        min_value: Variable<D>,
159        max_value: Variable<D>,
160        out: Variable<D>,
161    },
162    SyncThreads,
163    ThreadFence,
164    ProxySharedFence,
165    BulkCommitGroup,
166    BulkWaitGroup {
167        max_pending: u32,
168    },
169    BulkWaitGroupRead {
170        max_pending: u32,
171    },
172    Round(UnaryInstruction<D>),
173    Ceil(UnaryInstruction<D>),
174    Floor(UnaryInstruction<D>),
175    Warp(WarpInstruction<D>),
176    Wmma(WmmaInstruction<D>),
177    Bitcast(UnaryInstruction<D>),
178    AtomicLoad(UnaryInstruction<D>),
179    AtomicStore(UnaryInstruction<D>),
180    AtomicSwap(BinaryInstruction<D>),
181    AtomicAdd(BinaryInstruction<D>),
182    AtomicSub(BinaryInstruction<D>),
183    AtomicMax(BinaryInstruction<D>),
184    AtomicMin(BinaryInstruction<D>),
185    AtomicAnd(BinaryInstruction<D>),
186    AtomicOr(BinaryInstruction<D>),
187    AtomicXor(BinaryInstruction<D>),
188    AtomicCAS {
189        input: Variable<D>,
190        cmp: Variable<D>,
191        val: Variable<D>,
192        out: Variable<D>,
193    },
194    Neg(UnaryInstruction<D>),
195    Magnitude(UnaryInstruction<D>),
196    Normalize(UnaryInstruction<D>),
197    Dot(BinaryInstruction<D>),
198    Copy {
199        input: Variable<D>,
200        in_index: Variable<D>,
201        out: Variable<D>,
202        out_index: Variable<D>,
203    },
204    CopyBulk {
205        input: Variable<D>,
206        in_index: Variable<D>,
207        out: Variable<D>,
208        out_index: Variable<D>,
209        len: u32,
210    },
211    Printf {
212        format_string: String,
213        args: Vec<Variable<D>>,
214    },
215    Comment {
216        content: String,
217    },
218    Pipeline(PipelineOps<D>),
219    Barrier(BarrierOps<D>),
220    MemCopyAsyncTensorSharedToGlobal {
221        smem_buffer: Variable<D>,
222        tensor_map: Variable<D>,
223        indices: Vec<Variable<D>>,
224    },
225    Line {
226        file: Cow<'static, str>,
227        line: u32,
228    },
229}
230
231impl<D: Dialect> Display for Instruction<D> {
232    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
233        match self {
234            Instruction::Return => f.write_str("return;"),
235            Instruction::Break => f.write_str("break;"),
236            Instruction::DeclareVariable { var } => match var {
237                Variable::WmmaFragment { frag, .. } => writeln!(f, "{frag} {var};"),
238                _ => {
239                    let item = var.item();
240                    writeln!(f, "{item} {var};")
241                }
242            },
243            Instruction::Add(it) => Add::format(f, &it.lhs, &it.rhs, &it.out),
244            Instruction::Slice {
245                input,
246                start,
247                end,
248                out,
249            } => {
250                let item = out.item();
251                let addr_space = D::address_space_for_variable(input);
252                writeln!(f, "const uint {out}_length = {end} - {start};")?;
253                writeln!(f, "{addr_space}{item} *{out} = {input} + {start};")
254            }
255            Instruction::CheckedSlice {
256                input,
257                start,
258                end,
259                out,
260                len,
261            } => {
262                let item = out.item();
263                let addr_space = D::address_space_for_variable(input);
264                writeln!(f, "const uint {out}_length = min({len}, {end}) - {start};")?;
265                writeln!(f, "{addr_space}{item} *{out} = {input} + {start};")
266            }
267            Instruction::ReinterpretSlice {
268                input,
269                line_size,
270                out,
271            } => {
272                let mut item = out.item();
273                item.vectorization = *line_size as usize;
274                let addr_space = D::address_space_for_variable(input);
275
276                writeln!(
277                    f,
278                    "{addr_space}{item} *{out} = reinterpret_cast<{item}*>({input});"
279                )
280            }
281            Instruction::Mul(it) => Mul::format(f, &it.lhs, &it.rhs, &it.out),
282            Instruction::Div(it) => Div::format(f, &it.lhs, &it.rhs, &it.out),
283            Instruction::Sub(it) => Sub::format(f, &it.lhs, &it.rhs, &it.out),
284            Instruction::HiMul(it) => HiMul::format(f, &it.lhs, &it.rhs, &it.out),
285            Instruction::Modulo(inst) => Modulo::format(f, &inst.lhs, &inst.rhs, &inst.out),
286            Instruction::BitwiseOr(it) => BitwiseOr::format(f, &it.lhs, &it.rhs, &it.out),
287            Instruction::BitwiseAnd(it) => BitwiseAnd::format(f, &it.lhs, &it.rhs, &it.out),
288            Instruction::BitwiseXor(it) => BitwiseXor::format(f, &it.lhs, &it.rhs, &it.out),
289            Instruction::CountBits(it) => CountBits::format(f, &it.input, &it.out),
290            Instruction::ReverseBits(it) => ReverseBits::format(f, &it.input, &it.out),
291            Instruction::LeadingZeros(it) => LeadingZeros::format(f, &it.input, &it.out),
292            Instruction::FindFirstSet(it) => FindFirstSet::format(f, &it.input, &it.out),
293            Instruction::ShiftLeft(it) => ShiftLeft::format(f, &it.lhs, &it.rhs, &it.out),
294            Instruction::ShiftRight(it) => ShiftRight::format(f, &it.lhs, &it.rhs, &it.out),
295            Instruction::Index(it) => Index::format(f, &it.lhs, &it.rhs, &it.out),
296            Instruction::IndexAssign(it) => IndexAssign::format(f, &it.lhs, &it.rhs, &it.out),
297            Instruction::Copy {
298                input,
299                in_index,
300                out,
301                out_index,
302            } => {
303                writeln!(f, "{out}[{out_index}] = {input}[{in_index}];")
304            }
305            Instruction::CopyBulk {
306                input,
307                in_index,
308                out,
309                out_index,
310                len,
311            } => {
312                for i in 0..*len {
313                    writeln!(f, "{out}[{out_index} + {i}] = {input}[{in_index} + {i}];")?;
314                }
315                Ok(())
316            }
317            Instruction::Assign(it) => Assign::format(f, &it.input, &it.out),
318            Instruction::RangeLoop {
319                i,
320                start,
321                end,
322                step,
323                inclusive,
324                instructions,
325            } => {
326                let increment = step
327                    .map(|step| format!("{i} += {step}"))
328                    .unwrap_or_else(|| format!("++{i}"));
329                let cmp = if *inclusive { "<=" } else { "<" };
330                let i_ty = i.item();
331
332                write!(
333                    f,
334                    "
335for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{
336"
337                )?;
338                for instruction in instructions {
339                    write!(f, "{instruction}")?;
340                }
341
342                f.write_str("}\n")
343            }
344            Instruction::Loop { instructions } => {
345                writeln!(f, "while (true) {{")?;
346                for i in instructions {
347                    write!(f, "{i}")?;
348                }
349                f.write_str("}\n")
350            }
351            Instruction::If { cond, instructions } => {
352                writeln!(f, "if ({cond}) {{")?;
353                for i in instructions {
354                    write!(f, "{i}")?;
355                }
356                f.write_str("}\n")
357            }
358            Instruction::IfElse {
359                cond,
360                instructions_if,
361                instructions_else,
362            } => {
363                writeln!(f, "if ({cond}) {{")?;
364                for i in instructions_if {
365                    write!(f, "{i}")?;
366                }
367                f.write_str("} else {\n")?;
368                for i in instructions_else {
369                    write!(f, "{i}")?;
370                }
371                f.write_str("}\n")
372            }
373            Instruction::Select {
374                cond,
375                then,
376                or_else,
377                out,
378            } => {
379                let item_or_else = or_else.item();
380                let item_then = then.item();
381                let item_out = out.item();
382
383                let vf_then = item_then.vectorization;
384                let vf_or_else = item_or_else.vectorization;
385                let vf_out = item_out.vectorization;
386                let vf_cond = cond.item().vectorization;
387
388                let item_out = out.item();
389                let cond_elem = cond.item().elem;
390                let out = out.fmt_left();
391
392                let should_broadcast =
393                    vf_cond > 1 || item_out != item_or_else || item_out != item_then;
394
395                if should_broadcast {
396                    let vf = usize::max(vf_cond, vf_out);
397                    let vf = usize::max(vf, vf_then);
398                    let vf = usize::max(vf, vf_or_else);
399
400                    writeln!(f, "{out} = {item_out} {{")?;
401                    for i in 0..vf {
402                        let theni = then.index(i);
403                        let or_elsei = or_else.index(i);
404                        let condi = cond.index(i);
405                        let condi = EnsureBoolArg {
406                            var: &condi,
407                            elem: &cond_elem,
408                        };
409
410                        writeln!(f, "({condi}) ? {theni} : {or_elsei},")?;
411                    }
412
413                    writeln!(f, "}};")
414                } else {
415                    let cond = EnsureBoolArg {
416                        var: &cond,
417                        elem: &cond_elem,
418                    };
419                    writeln!(f, "{out} = ({cond}) ? {then} : {or_else};")
420                }
421            }
422            Instruction::Switch {
423                value,
424                instructions_default,
425                instructions_cases,
426            } => {
427                writeln!(f, "switch({value}) {{")?;
428                for (value, block) in instructions_cases {
429                    write!(f, "case {value}:\n{{\n")?;
430                    for i in block {
431                        i.fmt(f)?;
432                    }
433                    f.write_str("break;\n}\n")?;
434                }
435                f.write_str("default:\n{")?;
436                for i in instructions_default {
437                    i.fmt(f)?;
438                }
439                f.write_str("}\n}\n")
440            }
441            Instruction::Metadata {
442                info_offset,
443                split_meta,
444                out,
445            } => {
446                let out = out.fmt_left();
447                match *split_meta {
448                    true => writeln!(f, "{out} = static_info.x[{info_offset}];"),
449                    false => writeln!(f, "{out} = {INFO_NAME}[{info_offset}];"),
450                }
451            }
452            Instruction::ExtendedMetadata {
453                info_offset,
454                dim,
455                split_meta,
456                static_offset,
457                out,
458            } => {
459                let out = out.fmt_left();
460                match *split_meta {
461                    true => writeln!(
462                        f,
463                        "{out} = {INFO_NAME}[{STATIC_INFO_NAME}.x[{info_offset}] + {dim} - {static_offset}];"
464                    ),
465                    false => writeln!(
466                        f,
467                        "{out} = {INFO_NAME}[{INFO_NAME}[{info_offset}] + {dim}];"
468                    ),
469                }
470            }
471            Instruction::Equal(it) => Equal::format(f, &it.lhs, &it.rhs, &it.out),
472            Instruction::NotEqual(it) => NotEqual::format(f, &it.lhs, &it.rhs, &it.out),
473            Instruction::Lower(it) => Lower::format(f, &it.lhs, &it.rhs, &it.out),
474            Instruction::Greater(it) => Greater::format(f, &it.lhs, &it.rhs, &it.out),
475            Instruction::LowerEqual(it) => LowerEqual::format(f, &it.lhs, &it.rhs, &it.out),
476            Instruction::GreaterEqual(it) => GreaterEqual::format(f, &it.lhs, &it.rhs, &it.out),
477            Instruction::Erf(it) => Erf::format(f, &it.input, &it.out),
478            Instruction::Abs(it) => Abs::format(f, &it.input, &it.out),
479            Instruction::Exp(it) => Exp::format(f, &it.input, &it.out),
480            Instruction::Log(it) => Log::format(f, &it.input, &it.out),
481            Instruction::Log1p(it) => Log1p::format(f, &it.input, &it.out),
482            Instruction::Cos(it) => Cos::format(f, &it.input, &it.out),
483            Instruction::Sin(it) => Sin::format(f, &it.input, &it.out),
484            Instruction::Tanh(it) => Tanh::format(f, &it.input, &it.out),
485            Instruction::Powf(it) => Powf::format(f, &it.lhs, &it.rhs, &it.out),
486            Instruction::Sqrt(it) => Sqrt::format(f, &it.input, &it.out),
487            Instruction::Max(it) => Max::format(f, &it.lhs, &it.rhs, &it.out),
488            Instruction::Min(it) => Min::format(f, &it.lhs, &it.rhs, &it.out),
489            Instruction::Not(it) => Not::format(f, &it.input, &it.out),
490            Instruction::BitwiseNot(it) => BitwiseNot::format(f, &it.input, &it.out),
491            Instruction::Or(it) => Or::format(f, &it.lhs, &it.rhs, &it.out),
492            Instruction::And(it) => And::format(f, &it.lhs, &it.rhs, &it.out),
493            Instruction::Clamp {
494                input,
495                min_value,
496                max_value,
497                out,
498            } => Clamp::format(f, input, min_value, max_value, out),
499            Instruction::SyncThreads => D::compile_instruction_sync_threads(f),
500            Instruction::ThreadFence => f.write_str("__threadfence();\n"),
501            Instruction::Round(it) => Round::format(f, &it.input, &it.out),
502            Instruction::Ceil(it) => Ceil::format(f, &it.input, &it.out),
503            Instruction::Floor(it) => Floor::format(f, &it.input, &it.out),
504            Instruction::SliceLength { input, out } => {
505                let out = out.fmt_left();
506                writeln!(f, "{out} = {input}_length;")
507            }
508            Instruction::ConstLength { length, out } => {
509                let out = out.fmt_left();
510                writeln!(f, "{out} = {length};")
511            }
512            Instruction::Warp(it) => write!(f, "{it}"),
513            Instruction::Fma { a, b, c, out } => Fma::format(f, a, b, c, out),
514            Instruction::Wmma(it) => write!(f, "{it}"),
515            Instruction::Bitcast(UnaryInstruction { input, out }) => {
516                let qualifier = out.const_qualifier();
517                let input_item = input.item();
518                let out_item = out.item();
519
520                if out_item.elem.size() * out_item.vectorization
521                    != input.item().elem.size() * input.item().vectorization
522                {
523                    panic!("Unsupported type for bitcasting {out_item:?} from {input_item:?}");
524                } else {
525                    let out = out.fmt_left();
526                    let addr_space = D::address_space_for_variable(input);
527                    writeln!(
528                        f,
529                        "{out} = reinterpret_cast<{addr_space}{out_item}{qualifier}&>({input});"
530                    )
531                }
532            }
533            Instruction::AtomicAdd(BinaryInstruction { lhs, rhs, out }) => {
534                D::compile_atomic_add(f, lhs, rhs, out)
535            }
536            Instruction::AtomicAnd(BinaryInstruction { lhs, rhs, out }) => {
537                D::compile_atomic_and(f, lhs, rhs, out)
538            }
539            Instruction::AtomicCAS {
540                input,
541                cmp,
542                val,
543                out,
544            } => D::compile_atomic_cas(f, input, cmp, val, out),
545            Instruction::AtomicLoad(UnaryInstruction { input, out }) => {
546                D::compile_atomic_load(f, input, out)
547            }
548            Instruction::AtomicMax(BinaryInstruction { lhs, rhs, out }) => {
549                D::compile_atomic_max(f, lhs, rhs, out)
550            }
551            Instruction::AtomicMin(BinaryInstruction { lhs, rhs, out }) => {
552                D::compile_atomic_min(f, lhs, rhs, out)
553            }
554            Instruction::AtomicOr(BinaryInstruction { lhs, rhs, out }) => {
555                D::compile_atomic_or(f, lhs, rhs, out)
556            }
557            Instruction::AtomicStore(UnaryInstruction { input, out }) => {
558                D::compile_atomic_store(f, input, out)
559            }
560            Instruction::AtomicSub(BinaryInstruction { lhs, rhs, out }) => {
561                D::compile_atomic_sub(f, lhs, rhs, out)
562            }
563            Instruction::AtomicSwap(BinaryInstruction { lhs, rhs, out }) => {
564                D::compile_atomic_swap(f, lhs, rhs, out)
565            }
566            Instruction::AtomicXor(BinaryInstruction { lhs, rhs, out }) => {
567                D::compile_atomic_xor(f, lhs, rhs, out)
568            }
569            Instruction::Remainder(inst) => Remainder::format(f, &inst.lhs, &inst.rhs, &inst.out),
570            Instruction::Neg(UnaryInstruction { input, out }) => {
571                let out = out.fmt_left();
572                writeln!(f, "{out} = -{input};")
573            }
574            Instruction::Normalize(inst) => Normalize::format(f, &inst.input, &inst.out),
575            Instruction::Magnitude(inst) => Magnitude::format(f, &inst.input, &inst.out),
576            Instruction::Dot(inst) => Dot::format(f, &inst.lhs, &inst.rhs, &inst.out),
577            Instruction::VecInit { inputs, out } => {
578                let item = out.item();
579                let inputs = inputs
580                    .iter()
581                    .map(|input| format!("{input}"))
582                    .collect::<Vec<_>>();
583                let out = out.fmt_left();
584                writeln!(f, "{out} = {item}{{{}}};", inputs.join(","))
585            }
586            Instruction::Printf {
587                format_string,
588                args,
589            } => D::compile_instruction_printf(f, format_string, args),
590            Instruction::Comment { content } => {
591                if content.contains('\n') {
592                    writeln!(f, "/* {content} */")
593                } else {
594                    writeln!(f, "// {content}")
595                }
596            }
597            Instruction::Pipeline(pipeline_ops) => write!(f, "{pipeline_ops}"),
598            Instruction::Barrier(barrier_ops) => write!(f, "{barrier_ops}"),
599            Instruction::Line { file, line } => writeln!(f, "#line {line} \"{file}\""),
600            Instruction::ProxySharedFence => {
601                writeln!(
602                    f,
603                    "cuda::device::experimental::fence_proxy_async_shared_cta();"
604                )
605            }
606            Instruction::BulkCommitGroup => writeln!(
607                f,
608                "cuda::device::experimental::cp_async_bulk_commit_group();"
609            ),
610            Instruction::BulkWaitGroup { max_pending } => writeln!(
611                f,
612                "cuda::device::experimental::cp_async_bulk_wait_group<{max_pending}>();"
613            ),
614            Instruction::BulkWaitGroupRead { max_pending } => writeln!(
615                f,
616                "cuda::device::experimental::cp_async_bulk_wait_group_read<{max_pending}>();"
617            ),
618            Instruction::MemCopyAsyncTensorSharedToGlobal {
619                smem_buffer,
620                tensor_map,
621                indices,
622            } => {
623                let rank = indices.len();
624                let smem_ptr = smem_buffer.fmt_ptr();
625                let indices = indices.iter().rev().fold(String::new(), |mut s, it| {
626                    let _ = write!(s, "{it}, ");
627                    s
628                });
629                writeln!(
630                    f,
631                    "cuda::device::experimental::cp_async_bulk_tensor_{rank}d_shared_to_global(&{tensor_map}, {indices} {smem_ptr});"
632                )
633            }
634        }
635    }
636}
637
638struct Fma<D: Dialect> {
639    _dialect: PhantomData<D>,
640}
641
642impl<D: Dialect> Fma<D> {
643    fn format(
644        f: &mut core::fmt::Formatter<'_>,
645        a: &Variable<D>,
646        b: &Variable<D>,
647        c: &Variable<D>,
648        out: &Variable<D>,
649    ) -> core::fmt::Result {
650        let out_item = out.item();
651        let num = out_item.vectorization;
652
653        let out = out.fmt_left();
654        if num == 1 {
655            writeln!(f, "{out} = fma({a}, {b}, {c});")
656        } else {
657            writeln!(f, "{out} = {out_item}{{")?;
658
659            for i in 0..num {
660                let ai = a.index(i);
661                let bi = b.index(i);
662                let ci = c.index(i);
663
664                writeln!(f, "fma({ai}, {bi}, {ci}),")?;
665            }
666            f.write_str("};\n")
667        }
668    }
669}
670
671struct Clamp<D: Dialect> {
672    _dialect: PhantomData<D>,
673}
674
675impl<D: Dialect> Clamp<D> {
676    fn format(
677        f: &mut core::fmt::Formatter<'_>,
678        input: &Variable<D>,
679        min_value: &Variable<D>,
680        max_value: &Variable<D>,
681        out: &Variable<D>,
682    ) -> core::fmt::Result {
683        let input = input.optimized();
684        let min_value = min_value.optimized();
685        let max_value = max_value.optimized();
686        let out = out.optimized();
687        let out_item = out.item();
688        let num = out_item.vectorization;
689
690        let (min, max) = match out.elem() {
691            Elem::F16 | Elem::BF16 => ("__hmin", "__hmax"),
692            Elem::F162 | Elem::BF162 => ("__hmin2", "__hmax2"),
693            _ => ("min", "max"),
694        };
695
696        let out_fmt = out.fmt_left();
697        if num == 1 {
698            writeln!(f, "{out_fmt} = ")?;
699            D::compile_instruction_max_function_name(f, out.item())?;
700            writeln!(f, "({min_value}, ")?;
701            D::compile_instruction_min_function_name(f, out.item())?;
702            writeln!(f, "({max_value}, {input}));")
703        } else {
704            writeln!(f, "{out_fmt} = {out_item}{{")?;
705            for i in 0..num {
706                let inputi = input.index(i);
707                let mini = min_value.index(i);
708                let maxi = max_value.index(i);
709
710                writeln!(f, "{max}({mini}, {min}({maxi}, {inputi})),")?;
711            }
712
713            f.write_str("};\n")
714        }
715    }
716}
717
718struct Remainder<D: Dialect> {
719    _dialect: PhantomData<D>,
720}
721
722impl<D: Dialect> Remainder<D> {
723    fn format(
724        f: &mut core::fmt::Formatter<'_>,
725        lhs: &Variable<D>,
726        rhs: &Variable<D>,
727        out: &Variable<D>,
728    ) -> core::fmt::Result {
729        let floor = |elem| {
730            let prefix = match elem {
731                Elem::F16 | Elem::BF16 => D::compile_instruction_half_function_name_prefix(),
732                Elem::F162 | Elem::BF162 => D::compile_instruction_half2_function_name_prefix(),
733                _ => "",
734            };
735            format!("{prefix}floor")
736        };
737
738        if out.item().vectorization == 1 {
739            let floor = floor(out.elem());
740
741            let out = out.fmt_left();
742            return writeln!(f, "{out} = {lhs} - {rhs} * {floor}({lhs} / {rhs});");
743        }
744
745        let optimized = Variable::optimized_args([*lhs, *rhs, *out]);
746        let [lhs, rhs, out_optimized] = optimized.args;
747
748        let item_out_original = out.item();
749        let item_out_optimized = out_optimized.item();
750
751        let index = match optimized.optimization_factor {
752            Some(factor) => item_out_original.vectorization / factor,
753            None => item_out_optimized.vectorization,
754        };
755
756        let floor = floor(*item_out_optimized.elem());
757
758        let mut write_op =
759            |lhs: &Variable<D>, rhs: &Variable<D>, out: &Variable<D>, item_out: Item<D>| {
760                let out = out.fmt_left();
761                writeln!(f, "{out} = {item_out}{{")?;
762                for i in 0..index {
763                    let lhsi = lhs.index(i);
764                    let rhsi = rhs.index(i);
765
766                    writeln!(f, "{lhsi} - {rhsi} * {floor}({lhsi} / {rhsi})")?;
767                    f.write_str(", ")?;
768                }
769
770                f.write_str("};\n")
771            };
772
773        if item_out_original == item_out_optimized {
774            write_op(&lhs, &rhs, out, item_out_optimized)
775        } else {
776            let out_tmp = Variable::tmp(item_out_optimized);
777
778            write_op(&lhs, &rhs, &out_tmp, item_out_optimized)?;
779
780            let addr_space = D::address_space_for_variable(&out_tmp);
781            let qualifier = out.const_qualifier();
782            let out = out.fmt_left();
783
784            writeln!(
785                f,
786                "{out} = reinterpret_cast<{addr_space}{item_out_original}{qualifier}&>({out_tmp});\n"
787            )?;
788
789            Ok(())
790        }
791    }
792}
793
794struct Magnitude<D: Dialect> {
795    _dialect: PhantomData<D>,
796}
797
798impl<D: Dialect> Magnitude<D> {
799    fn format(
800        f: &mut core::fmt::Formatter<'_>,
801        input: &Variable<D>,
802        out: &Variable<D>,
803    ) -> core::fmt::Result {
804        let num = input.item().vectorization;
805        let elem = input.elem();
806
807        let mag = format!("{out}_mag");
808
809        writeln!(f, "{} {mag} = 0.0;", out.item())?;
810
811        for i in 0..num {
812            let input_i = input.index(i);
813            writeln!(f, "{mag} += {input_i} * {input_i};")?;
814        }
815
816        let out = out.fmt_left();
817        write!(f, "{out} = ")?;
818        Sqrt::format_unary(f, &mag, elem)?;
819        f.write_str(";\n")
820    }
821}
822
823struct Normalize<D: Dialect> {
824    _dialect: PhantomData<D>,
825}
826
827impl<D: Dialect> Normalize<D> {
828    fn format(
829        f: &mut core::fmt::Formatter<'_>,
830        input: &Variable<D>,
831        out: &Variable<D>,
832    ) -> core::fmt::Result {
833        let num = input.item().vectorization;
834        let elem = input.elem();
835        let norm = format!("{out}_norm");
836
837        let out_item = out.item();
838        let out = out.fmt_left();
839        writeln!(f, "{elem} {norm} = 0.0;")?;
840
841        for i in 0..num {
842            let input_i = input.index(i);
843            writeln!(f, "{norm} += {input_i} * {input_i};")?;
844        }
845
846        write!(f, "{norm} = ")?;
847        Sqrt::format_unary(f, &norm, elem)?;
848        f.write_str(";\n")?;
849
850        if num == 1 {
851            writeln!(f, "{out} = {input} / {norm};")
852        } else {
853            write!(f, "{out} = {out_item}{{")?;
854            for i in 0..num {
855                let input_i = input.index(i);
856
857                writeln!(f, "{input_i} / {norm},")?;
858            }
859
860            f.write_str("};\n")
861        }
862    }
863}
864
865struct Dot<D: Dialect> {
866    _dialect: PhantomData<D>,
867}
868
869impl<D: Dialect> Dot<D> {
870    fn format(
871        f: &mut core::fmt::Formatter<'_>,
872        lhs: &Variable<D>,
873        rhs: &Variable<D>,
874        out: &Variable<D>,
875    ) -> core::fmt::Result {
876        let num = lhs.item().vectorization;
877
878        let muls = (0..num)
879            .map(|i| {
880                let lhs_i = lhs.index(i);
881                let rhs_i = rhs.index(i);
882                format!("{lhs_i} * {rhs_i}")
883            })
884            .collect::<Vec<_>>();
885
886        let out = out.fmt_left();
887        writeln!(f, "{out} = {};", muls.join(" + "))
888    }
889}
890
891struct EnsureBoolArg<'a, V: Display, D: Dialect> {
892    var: &'a V,
893    elem: &'a Elem<D>,
894}
895
896impl<V: Display, D: Dialect> Display for EnsureBoolArg<'_, V, D> {
897    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
898        if self.elem != &Elem::Bool {
899            write!(f, "bool({})", self.var)
900        } else {
901            write!(f, "{}", self.var)
902        }
903    }
904}