cubecl_cpp/shared/
binary.rs

1use crate::shared::FmtLeft;
2
3use super::{Component, Dialect, Elem, Item, Variable};
4use std::{
5    fmt::{Display, Formatter},
6    marker::PhantomData,
7};
8
9pub trait Binary<D: Dialect> {
10    fn format(
11        f: &mut Formatter<'_>,
12        lhs: &Variable<D>,
13        rhs: &Variable<D>,
14        out: &Variable<D>,
15    ) -> std::fmt::Result {
16        let out_item = out.item();
17        if out.item().vectorization == 1 {
18            let out = out.fmt_left();
19            write!(f, "{out} = ")?;
20            Self::format_scalar(f, *lhs, *rhs, out_item)?;
21            f.write_str(";\n")
22        } else {
23            Self::unroll_vec(f, lhs, rhs, out)
24        }
25    }
26
27    fn format_scalar<Lhs, Rhs>(
28        f: &mut Formatter<'_>,
29        lhs: Lhs,
30        rhs: Rhs,
31        item: Item<D>,
32    ) -> std::fmt::Result
33    where
34        Lhs: Component<D>,
35        Rhs: Component<D>;
36
37    fn unroll_vec(
38        f: &mut Formatter<'_>,
39        lhs: &Variable<D>,
40        rhs: &Variable<D>,
41        out: &Variable<D>,
42    ) -> core::fmt::Result {
43        let optimized = Variable::optimized_args([*lhs, *rhs, *out]);
44        let [lhs, rhs, out_optimized] = optimized.args;
45
46        let item_out_original = out.item();
47        let item_out_optimized = out_optimized.item();
48
49        let index = match optimized.optimization_factor {
50            Some(factor) => item_out_original.vectorization / factor,
51            None => item_out_optimized.vectorization,
52        };
53
54        let mut write_op =
55            |lhs: &Variable<D>, rhs: &Variable<D>, out: &Variable<D>, item_out: Item<D>| {
56                let out = out.fmt_left();
57                writeln!(f, "{out} = {item_out}{{")?;
58                for i in 0..index {
59                    let lhsi = lhs.index(i);
60                    let rhsi = rhs.index(i);
61
62                    Self::format_scalar(f, lhsi, rhsi, item_out)?;
63                    f.write_str(", ")?;
64                }
65
66                f.write_str("};\n")
67            };
68
69        if item_out_original == item_out_optimized {
70            write_op(&lhs, &rhs, out, item_out_optimized)
71        } else {
72            let out_tmp = Variable::tmp(item_out_optimized);
73            write_op(&lhs, &rhs, &out_tmp, item_out_optimized)?;
74            let addr_space = D::address_space_for_variable(out);
75            let out = out.fmt_left();
76
77            writeln!(
78                f,
79                "{out} = reinterpret_cast<{addr_space}{item_out_original}&>({out_tmp});\n"
80            )?;
81
82            Ok(())
83        }
84    }
85}
86
/// Declares a unit struct `$name` implementing [`Binary`] as the infix operator `$op`.
///
/// For 8- and 16-bit integer outputs the expression is wrapped in an explicit conversion
/// to the output element type. Integer auto-promotion would otherwise change the result
/// type, and fused/vectorized element-wise code needs every resulting element to keep the
/// output's type.
macro_rules! operator {
    ($name:ident, $op:expr) => {
        pub struct $name;

        impl<D: Dialect> Binary<D> for $name {
            fn format_scalar<Lhs: Display, Rhs: Display>(
                f: &mut std::fmt::Formatter<'_>,
                lhs: Lhs,
                rhs: Rhs,
                out_item: Item<D>,
            ) -> std::fmt::Result {
                let out_elem = out_item.elem();
                let needs_narrowing = matches!(
                    out_elem,
                    Elem::<D>::I16 | Elem::<D>::U16 | Elem::<D>::I8 | Elem::<D>::U8
                );

                if needs_narrowing {
                    write!(f, "{out_elem}({lhs} {} {rhs})", $op)
                } else {
                    write!(f, "{lhs} {} {rhs}", $op)
                }
            }
        }
    };
}
112
113operator!(Add, "+");
114operator!(Sub, "-");
115operator!(Div, "/");
116operator!(Mul, "*");
117operator!(Modulo, "%");
118operator!(Equal, "==");
119operator!(NotEqual, "!=");
120operator!(Lower, "<");
121operator!(LowerEqual, "<=");
122operator!(Greater, ">");
123operator!(GreaterEqual, ">=");
124operator!(ShiftLeft, "<<");
125operator!(ShiftRight, ">>");
126operator!(BitwiseOr, "|");
127operator!(BitwiseAnd, "&");
128operator!(BitwiseXor, "^");
129operator!(Or, "||");
130operator!(And, "&&");
131
132pub struct HiMul;
133
134impl<D: Dialect> Binary<D> for HiMul {
135    // Powf doesn't support half and no half equivalent exists
136    fn format_scalar<Lhs: Display, Rhs: Display>(
137        f: &mut std::fmt::Formatter<'_>,
138        lhs: Lhs,
139        rhs: Rhs,
140        out: Item<D>,
141    ) -> std::fmt::Result {
142        let out_elem = out.elem;
143        match out_elem {
144            Elem::I32 => write!(f, "__mulhi({lhs}, {rhs})"),
145            Elem::U32 => write!(f, "__umulhi({lhs}, {rhs})"),
146            Elem::I64 => write!(f, "__mul64hi({lhs}, {rhs})"),
147            Elem::U64 => write!(f, "__umul64hi({lhs}, {rhs})"),
148            _ => unimplemented!("HiMul only supports 32 and 64 bit ints"),
149        }
150    }
151
152    // Powf doesn't support half and no half equivalent exists
153    fn unroll_vec(
154        f: &mut Formatter<'_>,
155        lhs: &Variable<D>,
156        rhs: &Variable<D>,
157        out: &Variable<D>,
158    ) -> core::fmt::Result {
159        let item_out = out.item();
160        let index = out.item().vectorization;
161
162        let out = out.fmt_left();
163        writeln!(f, "{out} = {item_out}{{")?;
164        for i in 0..index {
165            let lhsi = lhs.index(i);
166            let rhsi = rhs.index(i);
167
168            Self::format_scalar(f, lhsi, rhsi, item_out)?;
169            f.write_str(", ")?;
170        }
171
172        f.write_str("};\n")
173    }
174}
175
176pub struct Powf;
177
178impl<D: Dialect> Binary<D> for Powf {
179    // Powf doesn't support half and no half equivalent exists
180    fn format_scalar<Lhs: Display, Rhs: Display>(
181        f: &mut std::fmt::Formatter<'_>,
182        lhs: Lhs,
183        rhs: Rhs,
184        item: Item<D>,
185    ) -> std::fmt::Result {
186        let elem = item.elem;
187        match elem {
188            Elem::F16 | Elem::F162 | Elem::BF16 | Elem::BF162 => {
189                write!(f, "{elem}(")?;
190                D::compile_instruction_powf(f)?;
191                write!(f, "(float({lhs}), float({rhs})))")
192            }
193            _ => {
194                D::compile_instruction_powf(f)?;
195                write!(f, "({lhs}, {rhs})")
196            }
197        }
198    }
199
200    // Powf doesn't support half and no half equivalent exists
201    fn unroll_vec(
202        f: &mut Formatter<'_>,
203        lhs: &Variable<D>,
204        rhs: &Variable<D>,
205        out: &Variable<D>,
206    ) -> core::fmt::Result {
207        let item_out = out.item();
208        let index = out.item().vectorization;
209
210        let out = out.fmt_left();
211        writeln!(f, "{out} = {item_out}{{")?;
212        for i in 0..index {
213            let lhsi = lhs.index(i);
214            let rhsi = rhs.index(i);
215
216            Self::format_scalar(f, lhsi, rhsi, item_out)?;
217            f.write_str(", ")?;
218        }
219
220        f.write_str("};\n")
221    }
222}
223
224pub struct Max;
225
226impl<D: Dialect> Binary<D> for Max {
227    fn format_scalar<Lhs: Display, Rhs: Display>(
228        f: &mut std::fmt::Formatter<'_>,
229        lhs: Lhs,
230        rhs: Rhs,
231        item: Item<D>,
232    ) -> std::fmt::Result {
233        D::compile_instruction_max_function_name(f, item)?;
234        write!(f, "({lhs}, {rhs})")
235    }
236}
237
238pub struct Min;
239
240impl<D: Dialect> Binary<D> for Min {
241    fn format_scalar<Lhs: Display, Rhs: Display>(
242        f: &mut std::fmt::Formatter<'_>,
243        lhs: Lhs,
244        rhs: Rhs,
245        item: Item<D>,
246    ) -> std::fmt::Result {
247        D::compile_instruction_min_function_name(f, item)?;
248        write!(f, "({lhs}, {rhs})")
249    }
250}
251
/// Indexed store, emitted as `out[lhs] = rhs;`.
pub struct IndexAssign;

/// Indexed load, emitted as `out = lhs[rhs];`.
pub struct Index;
254
255impl<D: Dialect> Binary<D> for IndexAssign {
256    fn format_scalar<Lhs, Rhs>(
257        f: &mut Formatter<'_>,
258        _lhs: Lhs,
259        rhs: Rhs,
260        item_out: Item<D>,
261    ) -> std::fmt::Result
262    where
263        Lhs: Component<D>,
264        Rhs: Component<D>,
265    {
266        let item_rhs = rhs.item();
267
268        let format_vec = |f: &mut Formatter<'_>, cast: bool| {
269            writeln!(f, "{item_out}{{")?;
270            for i in 0..item_out.vectorization {
271                if cast {
272                    writeln!(f, "{}({}),", item_out.elem, rhs.index(i))?;
273                } else {
274                    writeln!(f, "{},", rhs.index(i))?;
275                }
276            }
277            f.write_str("}")?;
278
279            Ok(())
280        };
281
282        if item_out.vectorization != item_rhs.vectorization {
283            format_vec(f, item_out != item_rhs)
284        } else if item_out.elem != item_rhs.elem {
285            if item_out.vectorization > 1 {
286                format_vec(f, true)?;
287            } else {
288                write!(f, "{}({rhs})", item_out.elem)?;
289            }
290            Ok(())
291        } else if rhs.is_const() && item_rhs.vectorization > 1 {
292            // Reinterpret cast in case rhs is optimized
293            write!(f, "reinterpret_cast<")?;
294            D::compile_local_memory_qualifier(f)?;
295            write!(f, " {item_out} const&>({rhs})")
296        } else {
297            write!(f, "{rhs}")
298        }
299    }
300
301    fn unroll_vec(
302        f: &mut Formatter<'_>,
303        lhs: &Variable<D>,
304        rhs: &Variable<D>,
305        out: &Variable<D>,
306    ) -> std::fmt::Result {
307        let item_lhs = lhs.item();
308        let out_item = out.item();
309        let out = out.fmt_left();
310
311        for i in 0..item_lhs.vectorization {
312            let lhsi = lhs.index(i);
313            let rhsi = rhs.index(i);
314            write!(f, "{out}[{lhs}] = ")?;
315            Self::format_scalar(f, lhsi, rhsi, out_item)?;
316            f.write_str(";\n")?;
317        }
318
319        Ok(())
320    }
321
322    fn format(
323        f: &mut Formatter<'_>,
324        lhs: &Variable<D>,
325        rhs: &Variable<D>,
326        out: &Variable<D>,
327    ) -> std::fmt::Result {
328        if matches!(out, Variable::LocalMut { .. } | Variable::LocalConst { .. }) {
329            return IndexAssignVector::format(f, lhs, rhs, out);
330        };
331
332        let out_item = out.item();
333
334        if lhs.item().vectorization == 1 {
335            write!(f, "{}[{lhs}] = ", out.fmt_left())?;
336            Self::format_scalar(f, *lhs, *rhs, out_item)?;
337            f.write_str(";\n")
338        } else {
339            Self::unroll_vec(f, lhs, rhs, out)
340        }
341    }
342}
343
344impl<D: Dialect> Binary<D> for Index {
345    fn format(
346        f: &mut Formatter<'_>,
347        lhs: &Variable<D>,
348        rhs: &Variable<D>,
349        out: &Variable<D>,
350    ) -> std::fmt::Result {
351        if matches!(lhs, Variable::LocalMut { .. } | Variable::LocalConst { .. }) {
352            return IndexVector::format(f, lhs, rhs, out);
353        }
354
355        let item_out = out.item();
356        if let Elem::Atomic(inner) = item_out.elem {
357            let addr_space = D::address_space_for_variable(lhs);
358            writeln!(f, "{addr_space}{inner}* {out} = &{lhs}[{rhs}];")
359        } else {
360            let out = out.fmt_left();
361            write!(f, "{out} = ")?;
362            Self::format_scalar(f, *lhs, *rhs, item_out)?;
363            f.write_str(";\n")
364        }
365    }
366
367    fn format_scalar<Lhs, Rhs>(
368        f: &mut Formatter<'_>,
369        lhs: Lhs,
370        rhs: Rhs,
371        item_out: Item<D>,
372    ) -> std::fmt::Result
373    where
374        Lhs: Component<D>,
375        Rhs: Component<D>,
376    {
377        let item_lhs = lhs.item();
378
379        let format_vec = |f: &mut Formatter<'_>| {
380            writeln!(f, "{item_out}{{")?;
381            for i in 0..item_out.vectorization {
382                write!(f, "{}({lhs}[{rhs}].i_{i}),", item_out.elem)?;
383            }
384            f.write_str("}")?;
385
386            Ok(())
387        };
388
389        if item_out.elem != item_lhs.elem {
390            if item_out.vectorization > 1 {
391                format_vec(f)
392            } else {
393                write!(f, "{}({lhs}[{rhs}])", item_out.elem)
394            }
395        } else {
396            write!(f, "{lhs}[{rhs}]")
397        }
398    }
399}
400
401/// The goal is to support indexing of vectorized types.
402///
403/// # Examples
404///
405/// ```c
406/// float4 rhs;
407/// float item = var[0]; // We want that.
408/// float item = var.x; // So we compile to that.
409/// ```
410struct IndexVector<D: Dialect> {
411    _dialect: PhantomData<D>,
412}
413
414/// The goal is to support indexing of vectorized types.
415///
416/// # Examples
417///
418/// ```c
419/// float4 var;
420///
421/// var[0] = 1.0; // We want that.
422/// var.x = 1.0;  // So we compile to that.
423/// ```
424struct IndexAssignVector<D: Dialect> {
425    _dialect: PhantomData<D>,
426}
427
428impl<D: Dialect> IndexVector<D> {
429    fn format(
430        f: &mut Formatter<'_>,
431        lhs: &Variable<D>,
432        rhs: &Variable<D>,
433        out: &Variable<D>,
434    ) -> std::fmt::Result {
435        match rhs {
436            Variable::ConstantScalar(value, _elem) => {
437                let index = value.as_usize();
438                let out = out.index(index);
439                let lhs = lhs.index(index);
440                let out = out.fmt_left();
441                writeln!(f, "{out} = {lhs};")
442            }
443            _ => {
444                let elem = out.elem();
445                let qualifier = out.const_qualifier();
446                let addr_space = D::address_space_for_variable(out);
447                let out = out.fmt_left();
448                writeln!(
449                    f,
450                    "{out} = reinterpret_cast<{addr_space}{elem}{qualifier}*>(&{lhs})[{rhs}];"
451                )
452            }
453        }
454    }
455}
456
457impl<D: Dialect> IndexAssignVector<D> {
458    fn format(
459        f: &mut Formatter<'_>,
460        lhs: &Variable<D>,
461        rhs: &Variable<D>,
462        out: &Variable<D>,
463    ) -> std::fmt::Result {
464        let index = match lhs {
465            Variable::ConstantScalar(value, _) => value.as_usize(),
466            _ => {
467                let elem = out.elem();
468                let addr_space = D::address_space_for_variable(out);
469                return writeln!(f, "*(({addr_space}{elem}*)&{out} + {lhs}) = {rhs};");
470            }
471        };
472
473        let out = out.index(index);
474        let rhs = rhs.index(index);
475
476        writeln!(f, "{out} = {rhs};")
477    }
478}