cubecl_cpp/shared/
binary.rs

1use crate::shared::FmtLeft;
2
3use super::{Component, Dialect, Elem, Item, Variable};
4use std::{
5    fmt::{Display, Formatter},
6    marker::PhantomData,
7};
8
9pub trait Binary<D: Dialect> {
10    fn format(
11        f: &mut Formatter<'_>,
12        lhs: &Variable<D>,
13        rhs: &Variable<D>,
14        out: &Variable<D>,
15    ) -> std::fmt::Result {
16        let out_item = out.item();
17        if out.item().vectorization == 1 {
18            let out = out.fmt_left();
19            write!(f, "{out} = ")?;
20            Self::format_scalar(f, *lhs, *rhs, out_item)?;
21            f.write_str(";\n")
22        } else {
23            Self::unroll_vec(f, lhs, rhs, out)
24        }
25    }
26
27    fn format_scalar<Lhs, Rhs>(
28        f: &mut Formatter<'_>,
29        lhs: Lhs,
30        rhs: Rhs,
31        item: Item<D>,
32    ) -> std::fmt::Result
33    where
34        Lhs: Component<D>,
35        Rhs: Component<D>;
36
37    fn unroll_vec(
38        f: &mut Formatter<'_>,
39        lhs: &Variable<D>,
40        rhs: &Variable<D>,
41        out: &Variable<D>,
42    ) -> core::fmt::Result {
43        let optimized = Variable::optimized_args([*lhs, *rhs, *out]);
44        let [lhs, rhs, out_optimized] = optimized.args;
45
46        let item_out_original = out.item();
47        let item_out_optimized = out_optimized.item();
48
49        let index = match optimized.optimization_factor {
50            Some(factor) => item_out_original.vectorization / factor,
51            None => item_out_optimized.vectorization,
52        };
53
54        let mut write_op =
55            |lhs: &Variable<D>, rhs: &Variable<D>, out: &Variable<D>, item_out: Item<D>| {
56                let out = out.fmt_left();
57                writeln!(f, "{out} = {item_out}{{")?;
58                for i in 0..index {
59                    let lhsi = lhs.index(i);
60                    let rhsi = rhs.index(i);
61
62                    Self::format_scalar(f, lhsi, rhsi, item_out)?;
63                    f.write_str(", ")?;
64                }
65
66                f.write_str("};\n")
67            };
68
69        if item_out_original == item_out_optimized {
70            write_op(&lhs, &rhs, out, item_out_optimized)
71        } else {
72            let out_tmp = Variable::tmp(item_out_optimized);
73
74            write_op(&lhs, &rhs, &out_tmp, item_out_optimized)?;
75
76            let out = out.fmt_left();
77
78            writeln!(
79                f,
80                "{out} = reinterpret_cast<{item_out_original}&>({out_tmp});\n"
81            )?;
82
83            Ok(())
84        }
85    }
86}
87
88macro_rules! operator {
89    ($name:ident, $op:expr) => {
90        pub struct $name;
91
92        impl<D: Dialect> Binary<D> for $name {
93            fn format_scalar<Lhs: Display, Rhs: Display>(
94                f: &mut std::fmt::Formatter<'_>,
95                lhs: Lhs,
96                rhs: Rhs,
97                _item: Item<D>,
98            ) -> std::fmt::Result {
99                write!(f, "{lhs} {} {rhs}", $op)
100            }
101        }
102    };
103}
104
105operator!(Add, "+");
106operator!(Sub, "-");
107operator!(Div, "/");
108operator!(Mul, "*");
109operator!(Modulo, "%");
110operator!(Equal, "==");
111operator!(NotEqual, "!=");
112operator!(Lower, "<");
113operator!(LowerEqual, "<=");
114operator!(Greater, ">");
115operator!(GreaterEqual, ">=");
116operator!(ShiftLeft, "<<");
117operator!(ShiftRight, ">>");
118operator!(BitwiseOr, "|");
119operator!(BitwiseAnd, "&");
120operator!(BitwiseXor, "^");
121operator!(Or, "||");
122operator!(And, "&&");
123
124pub struct Powf;
125
126impl<D: Dialect> Binary<D> for Powf {
127    // Powf doesn't support half and no half equivalent exists
128    fn format_scalar<Lhs: Display, Rhs: Display>(
129        f: &mut std::fmt::Formatter<'_>,
130        lhs: Lhs,
131        rhs: Rhs,
132        item: Item<D>,
133    ) -> std::fmt::Result {
134        let elem = item.elem;
135        match elem {
136            Elem::F16 | Elem::F162 | Elem::BF16 | Elem::BF162 => {
137                write!(f, "{elem}(powf(float({lhs}), float({rhs})))")
138            }
139            _ => write!(f, "powf({lhs}, {rhs})"),
140        }
141    }
142
143    // Powf doesn't support half and no half equivalent exists
144    fn unroll_vec(
145        f: &mut Formatter<'_>,
146        lhs: &Variable<D>,
147        rhs: &Variable<D>,
148        out: &Variable<D>,
149    ) -> core::fmt::Result {
150        let item_out = out.item();
151        let index = out.item().vectorization;
152
153        let out = out.fmt_left();
154        writeln!(f, "{out} = {item_out}{{")?;
155        for i in 0..index {
156            let lhsi = lhs.index(i);
157            let rhsi = rhs.index(i);
158
159            Self::format_scalar(f, lhsi, rhsi, item_out)?;
160            f.write_str(", ")?;
161        }
162
163        f.write_str("};\n")
164    }
165}
166
167pub struct Max;
168
169impl<D: Dialect> Binary<D> for Max {
170    fn format_scalar<Lhs: Display, Rhs: Display>(
171        f: &mut std::fmt::Formatter<'_>,
172        lhs: Lhs,
173        rhs: Rhs,
174        item: Item<D>,
175    ) -> std::fmt::Result {
176        let max = match item.elem() {
177            Elem::F16 | Elem::BF16 => "__hmax",
178            Elem::F162 | Elem::BF162 => "__hmax2",
179            _ => "max",
180        };
181
182        write!(f, "{max}({lhs}, {rhs})")
183    }
184}
185
186pub struct Min;
187
188impl<D: Dialect> Binary<D> for Min {
189    fn format_scalar<Lhs: Display, Rhs: Display>(
190        f: &mut std::fmt::Formatter<'_>,
191        lhs: Lhs,
192        rhs: Rhs,
193        item: Item<D>,
194    ) -> std::fmt::Result {
195        let min = match item.elem() {
196            Elem::F16 | Elem::BF16 => "__hmin",
197            Elem::F162 | Elem::BF162 => "__hmin2",
198            _ => "min",
199        };
200
201        write!(f, "{min}({lhs}, {rhs})")
202    }
203}
204
205pub struct IndexAssign;
206pub struct Index;
207
208impl<D: Dialect> Binary<D> for IndexAssign {
209    fn format_scalar<Lhs, Rhs>(
210        f: &mut Formatter<'_>,
211        _lhs: Lhs,
212        rhs: Rhs,
213        item_out: Item<D>,
214    ) -> std::fmt::Result
215    where
216        Lhs: Component<D>,
217        Rhs: Component<D>,
218    {
219        let item_rhs = rhs.item();
220
221        let format_vec = |f: &mut Formatter<'_>, cast: bool| {
222            writeln!(f, "{item_out}{{")?;
223            for i in 0..item_out.vectorization {
224                if cast {
225                    writeln!(f, "{}({}),", item_out.elem, rhs.index(i))?;
226                } else {
227                    writeln!(f, "{},", rhs.index(i))?;
228                }
229            }
230            f.write_str("}")?;
231
232            Ok(())
233        };
234
235        if item_out.vectorization != item_rhs.vectorization {
236            format_vec(f, item_out != item_rhs)
237        } else if item_out.elem != item_rhs.elem {
238            if item_out.vectorization > 1 {
239                format_vec(f, true)?;
240            } else {
241                write!(f, "{}({rhs})", item_out.elem)?;
242            }
243            Ok(())
244        } else if rhs.is_const() && item_rhs.vectorization > 1 {
245            // Reinterpret cast in case rhs is optimized
246            write!(f, "reinterpret_cast<{item_out} const&>({rhs})")
247        } else {
248            write!(f, "{rhs}")
249        }
250    }
251
252    fn unroll_vec(
253        f: &mut Formatter<'_>,
254        lhs: &Variable<D>,
255        rhs: &Variable<D>,
256        out: &Variable<D>,
257    ) -> std::fmt::Result {
258        let item_lhs = lhs.item();
259        let out_item = out.item();
260        let out = out.fmt_left();
261
262        for i in 0..item_lhs.vectorization {
263            let lhsi = lhs.index(i);
264            let rhsi = rhs.index(i);
265            write!(f, "{out}[{lhs}] = ")?;
266            Self::format_scalar(f, lhsi, rhsi, out_item)?;
267            f.write_str(";\n")?;
268        }
269
270        Ok(())
271    }
272
273    fn format(
274        f: &mut Formatter<'_>,
275        lhs: &Variable<D>,
276        rhs: &Variable<D>,
277        out: &Variable<D>,
278    ) -> std::fmt::Result {
279        if matches!(out, Variable::LocalMut { .. } | Variable::LocalConst { .. }) {
280            return IndexAssignVector::format(f, lhs, rhs, out);
281        };
282
283        let out_item = out.item();
284
285        if lhs.item().vectorization == 1 {
286            write!(f, "{}[{lhs}] = ", out.fmt_left())?;
287            Self::format_scalar(f, *lhs, *rhs, out_item)?;
288            f.write_str(";\n")
289        } else {
290            Self::unroll_vec(f, lhs, rhs, out)
291        }
292    }
293}
294
295impl<D: Dialect> Binary<D> for Index {
296    fn format(
297        f: &mut Formatter<'_>,
298        lhs: &Variable<D>,
299        rhs: &Variable<D>,
300        out: &Variable<D>,
301    ) -> std::fmt::Result {
302        if matches!(lhs, Variable::LocalMut { .. } | Variable::LocalConst { .. }) {
303            return IndexVector::format(f, lhs, rhs, out);
304        }
305
306        let item_out = out.item();
307        if let Elem::Atomic(inner) = item_out.elem {
308            write!(f, "{inner}* {out} = &{lhs}[{rhs}];")
309        } else {
310            let out = out.fmt_left();
311            write!(f, "{out} = ")?;
312            Self::format_scalar(f, *lhs, *rhs, item_out)?;
313            f.write_str(";\n")
314        }
315    }
316
317    fn format_scalar<Lhs, Rhs>(
318        f: &mut Formatter<'_>,
319        lhs: Lhs,
320        rhs: Rhs,
321        item_out: Item<D>,
322    ) -> std::fmt::Result
323    where
324        Lhs: Component<D>,
325        Rhs: Component<D>,
326    {
327        let item_lhs = lhs.item();
328
329        let format_vec = |f: &mut Formatter<'_>| {
330            writeln!(f, "{item_out}{{")?;
331            for i in 0..item_out.vectorization {
332                write!(f, "{}({lhs}[{rhs}].i_{i}),", item_out.elem)?;
333            }
334            f.write_str("}")?;
335
336            Ok(())
337        };
338
339        if item_out.elem != item_lhs.elem {
340            if item_out.vectorization > 1 {
341                format_vec(f)
342            } else {
343                write!(f, "{}({lhs}[{rhs}])", item_out.elem)
344            }
345        } else {
346            write!(f, "{lhs}[{rhs}]")
347        }
348    }
349}
350
351/// The goal is to support indexing of vectorized types.
352///
353/// # Examples
354///
355/// ```c
356/// float4 rhs;
357/// float item = var[0]; // We want that.
358/// float item = var.x; // So we compile to that.
359/// ```
360struct IndexVector<D: Dialect> {
361    _dialect: PhantomData<D>,
362}
363
364/// The goal is to support indexing of vectorized types.
365///
366/// # Examples
367///
368/// ```c
369/// float4 var;
370///
371/// var[0] = 1.0; // We want that.
372/// var.x = 1.0;  // So we compile to that.
373/// ```
374struct IndexAssignVector<D: Dialect> {
375    _dialect: PhantomData<D>,
376}
377
378impl<D: Dialect> IndexVector<D> {
379    fn format(
380        f: &mut Formatter<'_>,
381        lhs: &Variable<D>,
382        rhs: &Variable<D>,
383        out: &Variable<D>,
384    ) -> std::fmt::Result {
385        let index = match rhs {
386            Variable::ConstantScalar(value, _elem) => value.as_usize(),
387            _ => {
388                let elem = out.elem();
389                let qualifier = out.const_qualifier();
390                let out = out.fmt_left();
391                return writeln!(
392                    f,
393                    "{out} = reinterpret_cast<{elem}{qualifier}*>(&{lhs})[{rhs}];"
394                );
395            }
396        };
397
398        let out = out.index(index);
399        let lhs = lhs.index(index);
400
401        let out = out.fmt_left();
402        writeln!(f, "{out} = {lhs};")
403    }
404}
405
406impl<D: Dialect> IndexAssignVector<D> {
407    fn format(
408        f: &mut Formatter<'_>,
409        lhs: &Variable<D>,
410        rhs: &Variable<D>,
411        out: &Variable<D>,
412    ) -> std::fmt::Result {
413        let index = match lhs {
414            Variable::ConstantScalar(value, _) => value.as_usize(),
415            _ => {
416                let elem = out.elem();
417                return writeln!(f, "*(({elem}*)&{out} + {lhs}) = {rhs};");
418            }
419        };
420
421        let out = out.index(index);
422        let rhs = rhs.index(index);
423
424        writeln!(f, "{out} = {rhs};")
425    }
426}