cubecl_cpp/shared/
instruction.rs

1use crate::shared::FmtLeft;
2
3use super::{
4    binary::*, unary::*, Component, Dialect, Elem, Item, Variable, WarpInstruction, WmmaInstruction,
5};
6use std::{fmt::Display, marker::PhantomData};
7
/// Operands of a two-input instruction.
///
/// `lhs` and `rhs` are the inputs; `out` is the variable the rendered
/// statement writes to.
#[derive(Debug, Clone)]
pub struct BinaryInstruction<D: Dialect> {
    pub lhs: Variable<D>,
    pub rhs: Variable<D>,
    pub out: Variable<D>,
}
14
/// Operands of a single-input instruction.
///
/// `input` is the operand; `out` is the variable the rendered statement
/// writes to.
#[derive(Debug, Clone)]
pub struct UnaryInstruction<D: Dialect> {
    pub input: Variable<D>,
    pub out: Variable<D>,
}
20
/// A single statement of the generated source code for dialect `D`.
///
/// The [`Display`] implementation of this type writes the statement as source
/// text. Block-like variants (`RangeLoop`, `Loop`, `If`, `IfElse`, `Switch`)
/// own their nested instructions and render them between braces.
#[derive(Debug, Clone)]
pub enum Instruction<D: Dialect> {
    /// Loads `info[info_offset]` into `out`.
    Metadata {
        info_offset: Variable<D>,
        out: Variable<D>,
    },
    /// Loads `info[info[info_offset] + dim]` into `out`.
    ExtendedMetadata {
        info_offset: Variable<D>,
        dim: Variable<D>,
        out: Variable<D>,
    },
    /// Reads the `<input>_length` constant that `Slice` / `CheckedSlice`
    /// declare next to the slice pointer.
    SliceLength {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Declares `var` without initializing it.
    DeclareVariable {
        var: Variable<D>,
    },
    Modulo(BinaryInstruction<D>),
    Remainder(BinaryInstruction<D>),
    Add(BinaryInstruction<D>),
    /// Rendered as `fma(a, b, c)`, once per lane when `out` is vectorized.
    Fma {
        a: Variable<D>,
        b: Variable<D>,
        c: Variable<D>,
        out: Variable<D>,
    },
    Div(BinaryInstruction<D>),
    Mul(BinaryInstruction<D>),
    Sub(BinaryInstruction<D>),
    Index(BinaryInstruction<D>),
    IndexAssign(BinaryInstruction<D>),
    /// Read of `lhs[rhs]` guarded by `rhs < len`; an out-of-range index yields
    /// a zero / empty-initialized value instead. When `out` has an atomic
    /// element type the element address is taken without a guard.
    CheckedIndex {
        len: Variable<D>,
        lhs: Variable<D>,
        rhs: Variable<D>,
        out: Variable<D>,
    },
    Assign(UnaryInstruction<D>),
    /// `for` loop over `i` from `start` to `end` (`<=` when `inclusive`,
    /// otherwise `<`), advancing by `step`, or by one when `step` is `None`.
    RangeLoop {
        i: Variable<D>,
        start: Variable<D>,
        end: Variable<D>,
        step: Option<Variable<D>>,
        inclusive: bool,
        instructions: Vec<Self>,
    },
    /// Brace-initializes `out` from `inputs`.
    VecInit {
        inputs: Vec<Variable<D>>,
        out: Variable<D>,
    },
    /// Unconditional `while (true)` loop around `instructions`.
    Loop {
        instructions: Vec<Self>,
    },
    /// `if (cond) { instructions }`.
    If {
        cond: Variable<D>,
        instructions: Vec<Self>,
    },
    /// `if (cond) { instructions_if } else { instructions_else }`.
    IfElse {
        cond: Variable<D>,
        instructions_if: Vec<Self>,
        instructions_else: Vec<Self>,
    },
    /// Ternary select `out = cond ? then : or_else`, expanded lane by lane
    /// when the operands do not all share the item type of `out`.
    Select {
        cond: Variable<D>,
        then: Variable<D>,
        or_else: Variable<D>,
        out: Variable<D>,
    },
    /// `switch` on `value`; every case body is followed by a `break`.
    Switch {
        value: Variable<D>,
        instructions_default: Vec<Self>,
        instructions_cases: Vec<(Variable<D>, Vec<Self>)>,
    },
    /// Declares `out` as a pointer to `input + start`, together with an
    /// `<out>_length` constant.
    Slice {
        input: Variable<D>,
        start: Variable<D>,
        end: Variable<D>,
        out: Variable<D>,
    },
    /// Like `Slice`, but `end` is clamped to `len` when computing the length.
    CheckedSlice {
        input: Variable<D>,
        start: Variable<D>,
        end: Variable<D>,
        out: Variable<D>,
        len: Variable<D>,
    },
    /// `return;`
    Return,
    /// `break;`
    Break,
    Equal(BinaryInstruction<D>),
    NotEqual(BinaryInstruction<D>),
    Lower(BinaryInstruction<D>),
    Greater(BinaryInstruction<D>),
    LowerEqual(BinaryInstruction<D>),
    GreaterEqual(BinaryInstruction<D>),
    Erf(UnaryInstruction<D>),
    BitwiseOr(BinaryInstruction<D>),
    BitwiseAnd(BinaryInstruction<D>),
    BitwiseXor(BinaryInstruction<D>),
    CountBits(UnaryInstruction<D>),
    ReverseBits(UnaryInstruction<D>),
    ShiftLeft(BinaryInstruction<D>),
    ShiftRight(BinaryInstruction<D>),
    Abs(UnaryInstruction<D>),
    Exp(UnaryInstruction<D>),
    Log(UnaryInstruction<D>),
    Log1p(UnaryInstruction<D>),
    Cos(UnaryInstruction<D>),
    Sin(UnaryInstruction<D>),
    Tanh(UnaryInstruction<D>),
    Powf(BinaryInstruction<D>),
    Sqrt(UnaryInstruction<D>),
    Min(BinaryInstruction<D>),
    Max(BinaryInstruction<D>),
    Not(UnaryInstruction<D>),
    Or(BinaryInstruction<D>),
    And(BinaryInstruction<D>),
    /// Clamps `input` to the range `[min_value, max_value]`.
    Clamp {
        input: Variable<D>,
        min_value: Variable<D>,
        max_value: Variable<D>,
        out: Variable<D>,
    },
    /// Rendered as `__syncthreads();`.
    SyncThreads,
    /// Rendered as `__threadfence();`.
    ThreadFence,
    Round(UnaryInstruction<D>),
    Ceil(UnaryInstruction<D>),
    Floor(UnaryInstruction<D>),
    /// Holds a [`WarpInstruction`]; rendered through that type's own `Display`.
    Wrap(WarpInstruction<D>),
    /// Holds a [`WmmaInstruction`]; rendered through that type's own `Display`.
    Wmma(WmmaInstruction<D>),
    /// Reinterprets the bits of `input` as the element type of `out`.
    /// Rendering panics for element-type pairs without a known conversion.
    Bitcast(UnaryInstruction<D>),
    /// Rendered as `out = atomicAdd(input, 0)`.
    AtomicLoad(UnaryInstruction<D>),
    /// Rendered as `atomicExch(out, input)`; the call's result is not assigned.
    AtomicStore(UnaryInstruction<D>),
    /// Rendered as `out = atomicExch(lhs, rhs)`.
    AtomicSwap(BinaryInstruction<D>),
    AtomicAdd(BinaryInstruction<D>),
    AtomicSub(BinaryInstruction<D>),
    AtomicMax(BinaryInstruction<D>),
    AtomicMin(BinaryInstruction<D>),
    AtomicAnd(BinaryInstruction<D>),
    AtomicOr(BinaryInstruction<D>),
    AtomicXor(BinaryInstruction<D>),
    /// Rendered as `out = atomicCAS(input, cmp, val)`.
    AtomicCAS {
        input: Variable<D>,
        cmp: Variable<D>,
        val: Variable<D>,
        out: Variable<D>,
    },
    Neg(UnaryInstruction<D>),
    Magnitude(UnaryInstruction<D>),
    Normalize(UnaryInstruction<D>),
    Dot(BinaryInstruction<D>),
    /// Single element copy: `out[out_index] = input[in_index]`.
    Copy {
        input: Variable<D>,
        in_index: Variable<D>,
        out: Variable<D>,
        out_index: Variable<D>,
    },
    /// Copies `len` consecutive elements, unrolled into `len` assignments.
    CopyBulk {
        input: Variable<D>,
        in_index: Variable<D>,
        out: Variable<D>,
        out_index: Variable<D>,
        len: u32,
    },
    /// `printf` call with `format_string` (escaped when rendered) and `args`.
    Printf {
        format_string: String,
        args: Vec<Variable<D>>,
    },
    /// Source comment; rendered as a block comment when `content` spans
    /// several lines, otherwise as a line comment.
    Comment {
        content: String,
    },
}
193
194impl<D: Dialect> Display for Instruction<D> {
195    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
196        match self {
197            Instruction::Return => f.write_str("return;"),
198            Instruction::Break => f.write_str("break;"),
199            Instruction::DeclareVariable { var } => match var {
200                Variable::WmmaFragment { frag, .. } => writeln!(f, "{frag} {var};"),
201                _ => {
202                    let item = var.item();
203                    writeln!(f, "{item} {var};")
204                }
205            },
206            Instruction::Add(it) => Add::format(f, &it.lhs, &it.rhs, &it.out),
207            Instruction::Slice {
208                input,
209                start,
210                end,
211                out,
212            } => {
213                let item = out.item();
214                writeln!(f, "const uint {out}_length = {end};")?;
215                writeln!(f, "{item} *{out} = {input} + {start};")
216            }
217            Instruction::CheckedSlice {
218                input,
219                start,
220                end,
221                out,
222                len,
223            } => {
224                let item = out.item();
225                writeln!(f, "const uint {out}_length = min({len}, {end}) - {start};")?;
226                writeln!(f, "{item} *{out} = {input} + {start};")
227            }
228            Instruction::Mul(it) => Mul::format(f, &it.lhs, &it.rhs, &it.out),
229            Instruction::Div(it) => Div::format(f, &it.lhs, &it.rhs, &it.out),
230            Instruction::Sub(it) => Sub::format(f, &it.lhs, &it.rhs, &it.out),
231            Instruction::Modulo(inst) => Modulo::format(f, &inst.lhs, &inst.rhs, &inst.out),
232            Instruction::BitwiseOr(it) => BitwiseOr::format(f, &it.lhs, &it.rhs, &it.out),
233            Instruction::BitwiseAnd(it) => BitwiseAnd::format(f, &it.lhs, &it.rhs, &it.out),
234            Instruction::BitwiseXor(it) => BitwiseXor::format(f, &it.lhs, &it.rhs, &it.out),
235            Instruction::CountBits(it) => CountBits::format(f, &it.input, &it.out),
236            Instruction::ReverseBits(it) => ReverseBits::format(f, &it.input, &it.out),
237            Instruction::ShiftLeft(it) => ShiftLeft::format(f, &it.lhs, &it.rhs, &it.out),
238            Instruction::ShiftRight(it) => ShiftRight::format(f, &it.lhs, &it.rhs, &it.out),
239            Instruction::Index(it) => Index::format(f, &it.lhs, &it.rhs, &it.out),
240            Instruction::IndexAssign(it) => IndexAssign::format(f, &it.lhs, &it.rhs, &it.out),
241            Instruction::CheckedIndex { len, lhs, rhs, out } => {
242                let item_out = out.item();
243                if let Elem::Atomic(inner) = item_out.elem {
244                    write!(f, "{inner}* {out} = &{lhs}[{rhs}];")
245                } else {
246                    let out = out.fmt_left();
247                    write!(f, "{out} = ({rhs} < {len}) ? ")?;
248                    Index::format_scalar(f, *lhs, *rhs, item_out)?;
249                    if item_out.vectorization == 1 {
250                        writeln!(f, " : {item_out}(0);")
251                    } else {
252                        writeln!(f, " : {item_out}{{}};")
253                    }
254                }
255            }
256            Instruction::Copy {
257                input,
258                in_index,
259                out,
260                out_index,
261            } => {
262                writeln!(f, "{out}[{out_index}] = {input}[{in_index}];")
263            }
264            Instruction::CopyBulk {
265                input,
266                in_index,
267                out,
268                out_index,
269                len,
270            } => {
271                for i in 0..*len {
272                    writeln!(f, "{out}[{out_index} + {i}] = {input}[{in_index} + {i}];")?;
273                }
274                Ok(())
275            }
276            Instruction::Assign(it) => Assign::format(f, &it.input, &it.out),
277            Instruction::RangeLoop {
278                i,
279                start,
280                end,
281                step,
282                inclusive,
283                instructions,
284            } => {
285                let increment = step
286                    .map(|step| format!("{i} += {step}"))
287                    .unwrap_or_else(|| format!("++{i}"));
288                let cmp = if *inclusive { "<=" } else { "<" };
289                let i_ty = i.item();
290
291                write!(
292                    f,
293                    "
294for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{
295"
296                )?;
297                for instruction in instructions {
298                    write!(f, "{instruction}")?;
299                }
300
301                f.write_str("}\n")
302            }
303
304            Instruction::Loop { instructions } => {
305                writeln!(f, "while (true) {{")?;
306                for i in instructions {
307                    write!(f, "{i}")?;
308                }
309                f.write_str("}\n")
310            }
311            Instruction::If { cond, instructions } => {
312                writeln!(f, "if ({cond}) {{")?;
313                for i in instructions {
314                    write!(f, "{i}")?;
315                }
316                f.write_str("}\n")
317            }
318            Instruction::IfElse {
319                cond,
320                instructions_if,
321                instructions_else,
322            } => {
323                writeln!(f, "if ({cond}) {{")?;
324                for i in instructions_if {
325                    write!(f, "{i}")?;
326                }
327                f.write_str("} else {\n")?;
328                for i in instructions_else {
329                    write!(f, "{i}")?;
330                }
331                f.write_str("}\n")
332            }
333            Instruction::Select {
334                cond,
335                then,
336                or_else,
337                out,
338            } => {
339                let item_or_else = or_else.item();
340                let item_then = then.item();
341                let item_out = out.item();
342
343                let vf_then = item_then.vectorization;
344                let vf_or_else = item_or_else.vectorization;
345                let vf_out = item_out.vectorization;
346                let vf_cond = cond.item().vectorization;
347
348                let item_out = out.item();
349                let cond_elem = cond.item().elem;
350                let out = out.fmt_left();
351
352                let should_broadcast =
353                    vf_cond > 1 || item_out != item_or_else || item_out != item_then;
354
355                if should_broadcast {
356                    let vf = usize::max(vf_cond, vf_out);
357                    let vf = usize::max(vf, vf_then);
358                    let vf = usize::max(vf, vf_or_else);
359
360                    writeln!(f, "{out} = {item_out} {{")?;
361                    for i in 0..vf {
362                        let theni = then.index(i);
363                        let or_elsei = or_else.index(i);
364                        let condi = cond.index(i);
365                        let condi = EnsureBoolArg {
366                            var: &condi,
367                            elem: &cond_elem,
368                        };
369
370                        writeln!(f, "({condi}) ? {theni} : {or_elsei},")?;
371                    }
372
373                    writeln!(f, "}};")
374                } else {
375                    let cond = EnsureBoolArg {
376                        var: &cond,
377                        elem: &cond_elem,
378                    };
379                    writeln!(f, "{out} = ({cond}) ? {then} : {or_else};")
380                }
381            }
382            Instruction::Switch {
383                value,
384                instructions_default,
385                instructions_cases,
386            } => {
387                writeln!(f, "switch({value}) {{")?;
388                for (value, block) in instructions_cases {
389                    write!(f, "case {value}:\n{{\n")?;
390                    for i in block {
391                        i.fmt(f)?;
392                    }
393                    f.write_str("break;\n}\n")?;
394                }
395                f.write_str("default:\n{")?;
396                for i in instructions_default {
397                    i.fmt(f)?;
398                }
399                f.write_str("}\n}\n")
400            }
401            Instruction::Metadata { info_offset, out } => {
402                let out = out.fmt_left();
403                writeln!(f, "{out} = info[{info_offset}];")
404            }
405            Instruction::ExtendedMetadata {
406                info_offset,
407                dim,
408                out,
409            } => {
410                let out = out.fmt_left();
411                writeln!(f, "{out} = info[info[{info_offset}] + {dim}];")
412            }
413            Instruction::Equal(it) => Equal::format(f, &it.lhs, &it.rhs, &it.out),
414            Instruction::NotEqual(it) => NotEqual::format(f, &it.lhs, &it.rhs, &it.out),
415            Instruction::Lower(it) => Lower::format(f, &it.lhs, &it.rhs, &it.out),
416            Instruction::Greater(it) => Greater::format(f, &it.lhs, &it.rhs, &it.out),
417            Instruction::LowerEqual(it) => LowerEqual::format(f, &it.lhs, &it.rhs, &it.out),
418            Instruction::GreaterEqual(it) => GreaterEqual::format(f, &it.lhs, &it.rhs, &it.out),
419            Instruction::Erf(it) => Erf::format(f, &it.input, &it.out),
420            Instruction::Abs(it) => Abs::format(f, &it.input, &it.out),
421            Instruction::Exp(it) => Exp::format(f, &it.input, &it.out),
422            Instruction::Log(it) => Log::format(f, &it.input, &it.out),
423            Instruction::Log1p(it) => Log1p::format(f, &it.input, &it.out),
424            Instruction::Cos(it) => Cos::format(f, &it.input, &it.out),
425            Instruction::Sin(it) => Sin::format(f, &it.input, &it.out),
426            Instruction::Tanh(it) => Tanh::format(f, &it.input, &it.out),
427            Instruction::Powf(it) => Powf::format(f, &it.lhs, &it.rhs, &it.out),
428            Instruction::Sqrt(it) => Sqrt::format(f, &it.input, &it.out),
429            Instruction::Max(it) => Max::format(f, &it.lhs, &it.rhs, &it.out),
430            Instruction::Min(it) => Min::format(f, &it.lhs, &it.rhs, &it.out),
431            Instruction::Not(it) => Not::format(f, &it.input, &it.out),
432            Instruction::Or(it) => Or::format(f, &it.lhs, &it.rhs, &it.out),
433            Instruction::And(it) => And::format(f, &it.lhs, &it.rhs, &it.out),
434            Instruction::Clamp {
435                input,
436                min_value,
437                max_value,
438                out,
439            } => Clamp::format(f, input, min_value, max_value, out),
440            Instruction::SyncThreads => f.write_str("__syncthreads();\n"),
441            Instruction::ThreadFence => f.write_str("__threadfence();\n"),
442            Instruction::Round(it) => Round::format(f, &it.input, &it.out),
443            Instruction::Ceil(it) => Ceil::format(f, &it.input, &it.out),
444            Instruction::Floor(it) => Floor::format(f, &it.input, &it.out),
445            Instruction::SliceLength { input, out } => {
446                let out = out.fmt_left();
447                writeln!(f, "{out} = {input}_length;")
448            }
449            Instruction::Wrap(it) => write!(f, "{it}"),
450            Instruction::Fma { a, b, c, out } => Fma::format(f, a, b, c, out),
451            Instruction::Wmma(it) => write!(f, "{it}"),
452            Instruction::Bitcast(UnaryInstruction { input, out }) => {
453                let qualifier = out.const_qualifier();
454                let out_elem = out.elem();
455                let out = out.fmt_left();
456
457                match (input.elem(), out_elem) {
458                    (Elem::F32, Elem::I32) => {
459                        writeln!(f, "{out} = __float_as_int({input});")
460                    }
461                    (Elem::F32, Elem::U32) => {
462                        writeln!(f, "{out} = __float_as_uint({input});")
463                    }
464                    (Elem::F16, Elem::I32) => {
465                        writeln!(f, "{out} = __half_as_short({input});")
466                    }
467                    (Elem::F16, Elem::U32) => {
468                        writeln!(f, "{out} = __half_as_ushort({input});")
469                    }
470                    (Elem::BF16, Elem::I32) => {
471                        writeln!(f, "{out} = __bfloat16_as_short({input});")
472                    }
473                    (Elem::BF16, Elem::U32) => {
474                        writeln!(f, "{out} = __bfloat16_as_ushort({input});")
475                    }
476                    (Elem::I32, Elem::F32) => {
477                        writeln!(f, "{out} = __int_as_float({input});")
478                    }
479                    (Elem::I32, Elem::F16) => {
480                        writeln!(f, "{out} = __short_as_half({input});")
481                    }
482                    (Elem::I32, Elem::BF16) => {
483                        writeln!(f, "{out} = __short_as_bfloat16({input});")
484                    }
485                    (Elem::U32, Elem::F32) => {
486                        writeln!(f, "{out} = __uint_as_float({input});")
487                    }
488                    (Elem::U32, Elem::F16) => {
489                        writeln!(f, "{out} = __ushort_as_half({input});")
490                    }
491                    (Elem::U32, Elem::BF16) => {
492                        writeln!(f, "{out} = __ushort_as_bfloat16({input});")
493                    }
494                    (Elem::I32, Elem::U32) => {
495                        writeln!(f, "{out} = reinterpret_cast<uint{qualifier}&>({input});")
496                    }
497                    elem => panic!("Unsupported type for bitcasting {elem:?}"),
498                }
499            }
500            Instruction::AtomicCAS {
501                input,
502                cmp,
503                val,
504                out,
505            } => {
506                let out = out.fmt_left();
507                writeln!(f, "{out} = atomicCAS({input}, {cmp}, {val});")
508            }
509            Instruction::AtomicSwap(BinaryInstruction { lhs, rhs, out }) => {
510                let out = out.fmt_left();
511                writeln!(f, "{out} = atomicExch({lhs}, {rhs});")
512            }
513            Instruction::AtomicAdd(BinaryInstruction { lhs, rhs, out }) => {
514                let out = out.fmt_left();
515                match rhs.elem() {
516                    Elem::I64 => {
517                        writeln!(
518                            f,
519                            "{out} = atomicAdd(reinterpret_cast<{uint}*>({lhs}), {uint}({rhs}));",
520                            uint = Elem::<D>::U64
521                        )
522                    }
523                    _ => writeln!(f, "{out} = atomicAdd({lhs}, {rhs});"),
524                }
525            }
526            Instruction::AtomicSub(BinaryInstruction { lhs, rhs, out }) => {
527                let out = out.fmt_left();
528                match rhs.elem() {
529                    Elem::U32 | Elem::I32 => {
530                        writeln!(f, "{out} = atomicSub({lhs}, {rhs});")
531                    }
532                    Elem::U64 => {
533                        writeln!(f, "{out} = atomicAdd({lhs}, -{rhs});",)
534                    }
535                    Elem::I64 => {
536                        writeln!(
537                            f,
538                            "{out} = atomicAdd(reinterpret_cast<{uint}*>({lhs}), {uint}(-{rhs}));",
539                            uint = Elem::<D>::U64
540                        )
541                    }
542                    _ => writeln!(f, "{out} = atomicAdd({lhs}, -{rhs});"),
543                }
544            }
545            Instruction::AtomicMax(BinaryInstruction { lhs, rhs, out }) => {
546                let out = out.fmt_left();
547                writeln!(f, "{out} = atomicMax({lhs}, {rhs});")
548            }
549            Instruction::AtomicMin(BinaryInstruction { lhs, rhs, out }) => {
550                let out = out.fmt_left();
551                writeln!(f, "{out} = atomicMin({lhs}, {rhs});")
552            }
553            Instruction::AtomicAnd(BinaryInstruction { lhs, rhs, out }) => {
554                let out = out.fmt_left();
555                writeln!(f, "{out} = atomicAnd({lhs}, {rhs});")
556            }
557            Instruction::AtomicOr(BinaryInstruction { lhs, rhs, out }) => {
558                let out = out.fmt_left();
559                writeln!(f, "{out} = atomicOr({lhs}, {rhs});")
560            }
561            Instruction::AtomicXor(BinaryInstruction { lhs, rhs, out }) => {
562                let out = out.fmt_left();
563                writeln!(f, "{out} = atomicXor({lhs}, {rhs});")
564            }
565            Instruction::AtomicLoad(UnaryInstruction { input, out }) => {
566                let out = out.fmt_left();
567                writeln!(f, "{out} = atomicAdd({input}, 0);")
568            }
569            Instruction::AtomicStore(UnaryInstruction { input, out }) => {
570                let out = out.fmt_left();
571                writeln!(f, "atomicExch({out}, {input});")
572            }
573            Instruction::Remainder(inst) => Remainder::format(f, &inst.lhs, &inst.rhs, &inst.out),
574            Instruction::Neg(UnaryInstruction { input, out }) => {
575                let out = out.fmt_left();
576                writeln!(f, "{out} = -{input};")
577            }
578            Instruction::Normalize(inst) => Normalize::format(f, &inst.input, &inst.out),
579            Instruction::Magnitude(inst) => Magnitude::format(f, &inst.input, &inst.out),
580            Instruction::Dot(inst) => Dot::format(f, &inst.lhs, &inst.rhs, &inst.out),
581            Instruction::VecInit { inputs, out } => {
582                let item = out.item();
583                let inputs = inputs
584                    .iter()
585                    .map(|input| format!("{input}"))
586                    .collect::<Vec<_>>();
587                let out = out.fmt_left();
588                writeln!(f, "{out} = {item}{{{}}};", inputs.join(","))
589            }
590            Instruction::Printf {
591                format_string,
592                args,
593            } => {
594                let format_string = escape_string(format_string);
595                let args = args.iter().map(|arg| format!("{arg}")).collect::<Vec<_>>();
596                let args = match args.is_empty() {
597                    true => "".to_string(),
598                    false => format!(", {}", args.join(",")),
599                };
600                writeln!(f, "printf(\"{format_string}\"{args});")
601            }
602            Instruction::Comment { content } => {
603                if content.contains('\n') {
604                    writeln!(f, "/* {content} */")
605                } else {
606                    writeln!(f, "// {content}")
607                }
608            }
609        }
610    }
611}
612
/// Escapes `format_string` so that it can be placed between double quotes in
/// the generated `printf` call.
///
/// Backslashes and double quotes are escaped in addition to tab, newline and
/// carriage return; otherwise a `"` in the input would terminate the emitted
/// string literal early and a `\` would be read as the start of an escape
/// sequence. The string is processed in a single pass, so characters that are
/// produced by escaping are never escaped a second time.
fn escape_string(format_string: &str) -> String {
    let mut escaped = String::with_capacity(format_string.len());
    for ch in format_string.chars() {
        match ch {
            '\\' => escaped.push_str("\\\\"),
            '"' => escaped.push_str("\\\""),
            '\t' => escaped.push_str("\\t"),
            '\n' => escaped.push_str("\\n"),
            '\r' => escaped.push_str("\\r"),
            other => escaped.push(other),
        }
    }
    escaped
}
619
620struct Fma<D: Dialect> {
621    _dialect: PhantomData<D>,
622}
623
624impl<D: Dialect> Fma<D> {
625    fn format(
626        f: &mut core::fmt::Formatter<'_>,
627        a: &Variable<D>,
628        b: &Variable<D>,
629        c: &Variable<D>,
630        out: &Variable<D>,
631    ) -> core::fmt::Result {
632        let out_item = out.item();
633        let num = out_item.vectorization;
634
635        let out = out.fmt_left();
636        if num == 1 {
637            writeln!(f, "{out} = fma({a}, {b}, {c});")
638        } else {
639            writeln!(f, "{out} = {out_item}{{")?;
640
641            for i in 0..num {
642                let ai = a.index(i);
643                let bi = b.index(i);
644                let ci = c.index(i);
645
646                writeln!(f, "fma({ai}, {bi}, {ci}),")?;
647            }
648            f.write_str("};\n")
649        }
650    }
651}
652
653struct Clamp<D: Dialect> {
654    _dialect: PhantomData<D>,
655}
656
657impl<D: Dialect> Clamp<D> {
658    fn format(
659        f: &mut core::fmt::Formatter<'_>,
660        input: &Variable<D>,
661        min_value: &Variable<D>,
662        max_value: &Variable<D>,
663        out: &Variable<D>,
664    ) -> core::fmt::Result {
665        let input = input.optimized();
666        let min_value = min_value.optimized();
667        let max_value = max_value.optimized();
668        let out = out.optimized();
669        let out_item = out.item();
670        let num = out_item.vectorization;
671
672        let (min, max) = match out.elem() {
673            Elem::F16 | Elem::BF16 => ("__hmin", "__hmax"),
674            Elem::F162 | Elem::BF162 => ("__hmin2", "__hmax2"),
675            _ => ("min", "max"),
676        };
677
678        let out = out.fmt_left();
679        if num == 1 {
680            writeln!(
681                f,
682                "{out} = {max}({min_value}, {min}({max_value}, {input}));"
683            )
684        } else {
685            writeln!(f, "{out} = {out_item}{{")?;
686            for i in 0..num {
687                let inputi = input.index(i);
688                let mini = min_value.index(i);
689                let maxi = max_value.index(i);
690
691                writeln!(f, "{max}({mini}, {min}({maxi}, {inputi})),")?;
692            }
693
694            f.write_str("};\n")
695        }
696    }
697}
698
699struct Remainder<D: Dialect> {
700    _dialect: PhantomData<D>,
701}
702
703impl<D: Dialect> Remainder<D> {
704    fn format(
705        f: &mut core::fmt::Formatter<'_>,
706        lhs: &Variable<D>,
707        rhs: &Variable<D>,
708        out: &Variable<D>,
709    ) -> core::fmt::Result {
710        let floor = |elem| match elem {
711            Elem::F16 | Elem::BF16 => "hfloor",
712            Elem::F162 | Elem::BF162 => "h2floor",
713            _ => "floor",
714        };
715
716        if out.item().vectorization == 1 {
717            let floor = floor(out.elem());
718
719            let out = out.fmt_left();
720            return writeln!(f, "{out} = {lhs} - {rhs} * {floor}({lhs} / {rhs});");
721        }
722
723        let optimized = Variable::optimized_args([*lhs, *rhs, *out]);
724        let [lhs, rhs, out_optimized] = optimized.args;
725
726        let item_out_original = out.item();
727        let item_out_optimized = out_optimized.item();
728
729        let index = match optimized.optimization_factor {
730            Some(factor) => item_out_original.vectorization / factor,
731            None => item_out_optimized.vectorization,
732        };
733
734        let floor = floor(*item_out_optimized.elem());
735
736        let mut write_op =
737            |lhs: &Variable<D>, rhs: &Variable<D>, out: &Variable<D>, item_out: Item<D>| {
738                let out = out.fmt_left();
739                writeln!(f, "{out} = {item_out}{{")?;
740                for i in 0..index {
741                    let lhsi = lhs.index(i);
742                    let rhsi = rhs.index(i);
743
744                    writeln!(f, "{lhsi} - {rhsi} * {floor}({lhsi} / {rhsi})")?;
745                    f.write_str(", ")?;
746                }
747
748                f.write_str("};\n")
749            };
750
751        if item_out_original == item_out_optimized {
752            write_op(&lhs, &rhs, out, item_out_optimized)
753        } else {
754            let out_tmp = Variable::tmp(item_out_optimized);
755
756            write_op(&lhs, &rhs, &out_tmp, item_out_optimized)?;
757
758            let qualifier = out.const_qualifier();
759            let out = out.fmt_left();
760
761            writeln!(
762                f,
763                "{out} = reinterpret_cast<{item_out_original}{qualifier}&>({out_tmp});\n"
764            )?;
765
766            Ok(())
767        }
768    }
769}
770
771struct Magnitude<D: Dialect> {
772    _dialect: PhantomData<D>,
773}
774
775impl<D: Dialect> Magnitude<D> {
776    fn format(
777        f: &mut core::fmt::Formatter<'_>,
778        input: &Variable<D>,
779        out: &Variable<D>,
780    ) -> core::fmt::Result {
781        let num = input.item().vectorization;
782        let elem = input.elem();
783
784        let mag = format!("{out}_mag");
785
786        writeln!(f, "{} {mag} = 0.0;", out.item())?;
787
788        for i in 0..num {
789            let input_i = input.index(i);
790            writeln!(f, "{mag} += {input_i} * {input_i};")?;
791        }
792
793        let out = out.fmt_left();
794        write!(f, "{out} = ")?;
795        Sqrt::format_unary(f, &mag, elem)?;
796        f.write_str(";\n")
797    }
798}
799
800struct Normalize<D: Dialect> {
801    _dialect: PhantomData<D>,
802}
803
804impl<D: Dialect> Normalize<D> {
805    fn format(
806        f: &mut core::fmt::Formatter<'_>,
807        input: &Variable<D>,
808        out: &Variable<D>,
809    ) -> core::fmt::Result {
810        let num = input.item().vectorization;
811        let elem = input.elem();
812        let norm = format!("{out}_norm");
813
814        let out_item = out.item();
815        let out = out.fmt_left();
816        writeln!(f, "{elem} {norm} = 0.0;")?;
817
818        for i in 0..num {
819            let input_i = input.index(i);
820            writeln!(f, "{norm} += {input_i} * {input_i};")?;
821        }
822
823        write!(f, "{norm} = ")?;
824        Sqrt::format_unary(f, &norm, elem)?;
825        f.write_str(";\n")?;
826
827        if num == 1 {
828            writeln!(f, "{out} = {input} / {norm};")
829        } else {
830            write!(f, "{out} = {out_item}{{")?;
831            for i in 0..num {
832                let input_i = input.index(i);
833
834                writeln!(f, "{input_i} / {norm},")?;
835            }
836
837            f.write_str("};\n")
838        }
839    }
840}
841
842struct Dot<D: Dialect> {
843    _dialect: PhantomData<D>,
844}
845
846impl<D: Dialect> Dot<D> {
847    fn format(
848        f: &mut core::fmt::Formatter<'_>,
849        lhs: &Variable<D>,
850        rhs: &Variable<D>,
851        out: &Variable<D>,
852    ) -> core::fmt::Result {
853        let num = lhs.item().vectorization;
854
855        let muls = (0..num)
856            .map(|i| {
857                let lhs_i = lhs.index(i);
858                let rhs_i = rhs.index(i);
859                format!("{lhs_i} * {rhs_i}")
860            })
861            .collect::<Vec<_>>();
862
863        let out = out.fmt_left();
864        writeln!(f, "{out} = {};", muls.join(" + "))
865    }
866}
867
868struct EnsureBoolArg<'a, V: Display, D: Dialect> {
869    var: &'a V,
870    elem: &'a Elem<D>,
871}
872
873impl<V: Display, D: Dialect> Display for EnsureBoolArg<'_, V, D> {
874    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
875        if self.elem != &Elem::Bool {
876            write!(f, "bool({})", self.var)
877        } else {
878            write!(f, "{}", self.var)
879        }
880    }
881}