cubecl_cpp/shared/
binary.rs

1use crate::shared::FmtLeft;
2
3use super::{Component, Dialect, Elem, Item, Variable};
4use std::{
5    fmt::{Display, Formatter},
6    marker::PhantomData,
7};
8
9pub trait Binary<D: Dialect> {
10    fn format(
11        f: &mut Formatter<'_>,
12        lhs: &Variable<D>,
13        rhs: &Variable<D>,
14        out: &Variable<D>,
15    ) -> std::fmt::Result {
16        let out_item = out.item();
17        if out.item().vectorization == 1 {
18            let out = out.fmt_left();
19            write!(f, "{out} = ")?;
20            Self::format_scalar(f, *lhs, *rhs, out_item)?;
21            f.write_str(";\n")
22        } else {
23            Self::unroll_vec(f, lhs, rhs, out)
24        }
25    }
26
27    fn format_scalar<Lhs, Rhs>(
28        f: &mut Formatter<'_>,
29        lhs: Lhs,
30        rhs: Rhs,
31        item: Item<D>,
32    ) -> std::fmt::Result
33    where
34        Lhs: Component<D>,
35        Rhs: Component<D>;
36
37    fn unroll_vec(
38        f: &mut Formatter<'_>,
39        lhs: &Variable<D>,
40        rhs: &Variable<D>,
41        out: &Variable<D>,
42    ) -> core::fmt::Result {
43        let optimized = Variable::optimized_args([*lhs, *rhs, *out]);
44        let [lhs, rhs, out_optimized] = optimized.args;
45
46        let item_out_original = out.item();
47        let item_out_optimized = out_optimized.item();
48
49        let index = match optimized.optimization_factor {
50            Some(factor) => item_out_original.vectorization / factor,
51            None => item_out_optimized.vectorization,
52        };
53
54        let mut write_op =
55            |lhs: &Variable<D>, rhs: &Variable<D>, out: &Variable<D>, item_out: Item<D>| {
56                let out = out.fmt_left();
57                writeln!(f, "{out} = {item_out}{{")?;
58                for i in 0..index {
59                    let lhsi = lhs.index(i);
60                    let rhsi = rhs.index(i);
61
62                    Self::format_scalar(f, lhsi, rhsi, item_out)?;
63                    f.write_str(", ")?;
64                }
65
66                f.write_str("};\n")
67            };
68
69        if item_out_original == item_out_optimized {
70            write_op(&lhs, &rhs, out, item_out_optimized)
71        } else {
72            let out_tmp = Variable::tmp(item_out_optimized);
73            write_op(&lhs, &rhs, &out_tmp, item_out_optimized)?;
74            let addr_space = D::address_space_for_variable(out);
75            let out = out.fmt_left();
76
77            writeln!(
78                f,
79                "{out} = reinterpret_cast<{addr_space}{item_out_original}&>({out_tmp});\n"
80            )?;
81
82            Ok(())
83        }
84    }
85}
86
87macro_rules! operator {
88    ($name:ident, $op:expr) => {
89        pub struct $name;
90
91        impl<D: Dialect> Binary<D> for $name {
92            fn format_scalar<Lhs: Display, Rhs: Display>(
93                f: &mut std::fmt::Formatter<'_>,
94                lhs: Lhs,
95                rhs: Rhs,
96                out_item: Item<D>,
97            ) -> std::fmt::Result {
98                let out_elem = out_item.elem();
99                match out_elem {
100                    // prevent auto-promotion rules to kick-in in order to stay in the same type
101                    // this is because of fusion and vectorization that can do elemwise operations on vectorized type,
102                    // the resulting elements need to be of the same type.
103                    Elem::<D>::I16 | Elem::<D>::U16 | Elem::<D>::I8 | Elem::<D>::U8 => {
104                        write!(f, "{out_elem}({lhs} {} {rhs})", $op)
105                    }
106                    _ => write!(f, "{lhs} {} {rhs}", $op),
107                }
108            }
109        }
110    };
111}
112
113operator!(Add, "+");
114operator!(Sub, "-");
115operator!(Div, "/");
116operator!(Mul, "*");
117operator!(Modulo, "%");
118operator!(Equal, "==");
119operator!(NotEqual, "!=");
120operator!(Lower, "<");
121operator!(LowerEqual, "<=");
122operator!(Greater, ">");
123operator!(GreaterEqual, ">=");
124operator!(ShiftLeft, "<<");
125operator!(ShiftRight, ">>");
126operator!(BitwiseOr, "|");
127operator!(BitwiseAnd, "&");
128operator!(BitwiseXor, "^");
129operator!(Or, "||");
130operator!(And, "&&");
131
132pub struct HiMul;
133
134impl<D: Dialect> Binary<D> for HiMul {
135    // Powf doesn't support half and no half equivalent exists
136    fn format_scalar<Lhs: Display, Rhs: Display>(
137        f: &mut std::fmt::Formatter<'_>,
138        lhs: Lhs,
139        rhs: Rhs,
140        out: Item<D>,
141    ) -> std::fmt::Result {
142        let out_elem = out.elem;
143        match out_elem {
144            Elem::I32 => write!(f, "__mulhi({lhs}, {rhs})"),
145            Elem::U32 => write!(f, "__umulhi({lhs}, {rhs})"),
146            Elem::I64 => write!(f, "__mul64hi({lhs}, {rhs})"),
147            Elem::U64 => write!(f, "__umul64hi({lhs}, {rhs})"),
148            _ => unimplemented!("HiMul only supports 32 and 64 bit ints"),
149        }
150    }
151
152    // Powf doesn't support half and no half equivalent exists
153    fn unroll_vec(
154        f: &mut Formatter<'_>,
155        lhs: &Variable<D>,
156        rhs: &Variable<D>,
157        out: &Variable<D>,
158    ) -> core::fmt::Result {
159        let item_out = out.item();
160        let index = out.item().vectorization;
161
162        let out = out.fmt_left();
163        writeln!(f, "{out} = {item_out}{{")?;
164        for i in 0..index {
165            let lhsi = lhs.index(i);
166            let rhsi = rhs.index(i);
167
168            Self::format_scalar(f, lhsi, rhsi, item_out)?;
169            f.write_str(", ")?;
170        }
171
172        f.write_str("};\n")
173    }
174}
175
176pub struct Powf;
177
178impl<D: Dialect> Binary<D> for Powf {
179    // Powf doesn't support half and no half equivalent exists
180    fn format_scalar<Lhs: Display, Rhs: Display>(
181        f: &mut std::fmt::Formatter<'_>,
182        lhs: Lhs,
183        rhs: Rhs,
184        item: Item<D>,
185    ) -> std::fmt::Result {
186        let elem = item.elem;
187        match elem {
188            Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
189                write!(f, "{elem}(")?;
190                D::compile_instruction_powf(f)?;
191                write!(f, "(float({lhs}), float({rhs})))")
192            }
193            _ => {
194                D::compile_instruction_powf(f)?;
195                write!(f, "({lhs}, {rhs})")
196            }
197        }
198    }
199
200    // Powf doesn't support half and no half equivalent exists
201    fn unroll_vec(
202        f: &mut Formatter<'_>,
203        lhs: &Variable<D>,
204        rhs: &Variable<D>,
205        out: &Variable<D>,
206    ) -> core::fmt::Result {
207        let item_out = out.item();
208        let index = out.item().vectorization;
209
210        let out = out.fmt_left();
211        writeln!(f, "{out} = {item_out}{{")?;
212        for i in 0..index {
213            let lhsi = lhs.index(i);
214            let rhsi = rhs.index(i);
215
216            Self::format_scalar(f, lhsi, rhsi, item_out)?;
217            f.write_str(", ")?;
218        }
219
220        f.write_str("};\n")
221    }
222}
223
224pub struct Max;
225
226impl<D: Dialect> Binary<D> for Max {
227    fn format_scalar<Lhs: Display, Rhs: Display>(
228        f: &mut std::fmt::Formatter<'_>,
229        lhs: Lhs,
230        rhs: Rhs,
231        item: Item<D>,
232    ) -> std::fmt::Result {
233        D::compile_instruction_max_function_name(f, item)?;
234        write!(f, "({lhs}, {rhs})")
235    }
236}
237
238pub struct Min;
239
240impl<D: Dialect> Binary<D> for Min {
241    fn format_scalar<Lhs: Display, Rhs: Display>(
242        f: &mut std::fmt::Formatter<'_>,
243        lhs: Lhs,
244        rhs: Rhs,
245        item: Item<D>,
246    ) -> std::fmt::Result {
247        D::compile_instruction_min_function_name(f, item)?;
248        write!(f, "({lhs}, {rhs})")
249    }
250}
251
/// Code generation for indexed stores of the form `out_list[index] = value;`.
pub struct IndexAssign;
/// Code generation for indexed loads of the form `out = list[index];`.
pub struct Index;
254
255impl IndexAssign {
256    pub fn format<D: Dialect>(
257        f: &mut Formatter<'_>,
258        index: &Variable<D>,
259        value: &Variable<D>,
260        out_list: &Variable<D>,
261        line_size: u32,
262    ) -> std::fmt::Result {
263        if matches!(
264            out_list,
265            Variable::LocalMut { .. } | Variable::LocalConst { .. }
266        ) {
267            return IndexAssignVector::format(f, index, value, out_list);
268        };
269
270        if line_size > 0 {
271            let mut item = out_list.item();
272            item.vectorization = line_size as usize;
273            let addr_space = D::address_space_for_variable(out_list);
274            let qualifier = out_list.const_qualifier();
275            let tmp = Variable::tmp_declared(item);
276
277            writeln!(
278                f,
279                "{qualifier} {addr_space}{item} *{tmp} = reinterpret_cast<{qualifier} {item}*>({out_list});"
280            )?;
281
282            return IndexAssign::format(f, index, value, &tmp, 0);
283        }
284
285        let out_item = out_list.item();
286
287        if index.item().vectorization == 1 {
288            write!(f, "{}[{index}] = ", out_list.fmt_left())?;
289            Self::format_scalar(f, *index, *value, out_item)?;
290            f.write_str(";\n")
291        } else {
292            Self::unroll_vec(f, index, value, out_list)
293        }
294    }
295    fn format_scalar<D: Dialect, Lhs, Rhs>(
296        f: &mut Formatter<'_>,
297        _lhs: Lhs,
298        rhs: Rhs,
299        item_out: Item<D>,
300    ) -> std::fmt::Result
301    where
302        Lhs: Component<D>,
303        Rhs: Component<D>,
304    {
305        let item_rhs = rhs.item();
306
307        let format_vec = |f: &mut Formatter<'_>, cast: bool| {
308            writeln!(f, "{item_out}{{")?;
309            for i in 0..item_out.vectorization {
310                if cast {
311                    writeln!(f, "{}({}),", item_out.elem, rhs.index(i))?;
312                } else {
313                    writeln!(f, "{},", rhs.index(i))?;
314                }
315            }
316            f.write_str("}")?;
317
318            Ok(())
319        };
320
321        if item_out.vectorization != item_rhs.vectorization {
322            format_vec(f, item_out != item_rhs)
323        } else if item_out.elem != item_rhs.elem {
324            if item_out.vectorization > 1 {
325                format_vec(f, true)?;
326            } else {
327                write!(f, "{}({rhs})", item_out.elem)?;
328            }
329            Ok(())
330        } else if rhs.is_const() && item_rhs.vectorization > 1 {
331            // Reinterpret cast in case rhs is optimized
332            write!(f, "reinterpret_cast<")?;
333            D::compile_local_memory_qualifier(f)?;
334            write!(f, " {item_out} const&>({rhs})")
335        } else {
336            write!(f, "{rhs}")
337        }
338    }
339
340    fn unroll_vec<D: Dialect>(
341        f: &mut Formatter<'_>,
342        lhs: &Variable<D>,
343        rhs: &Variable<D>,
344        out: &Variable<D>,
345    ) -> std::fmt::Result {
346        let item_lhs = lhs.item();
347        let out_item = out.item();
348        let out = out.fmt_left();
349
350        for i in 0..item_lhs.vectorization {
351            let lhsi = lhs.index(i);
352            let rhsi = rhs.index(i);
353            write!(f, "{out}[{lhs}] = ")?;
354            Self::format_scalar(f, lhsi, rhsi, out_item)?;
355            f.write_str(";\n")?;
356        }
357
358        Ok(())
359    }
360}
361
362impl Index {
363    pub(crate) fn format<D: Dialect>(
364        f: &mut Formatter<'_>,
365        list: &Variable<D>,
366        index: &Variable<D>,
367        out: &Variable<D>,
368        line_size: u32,
369    ) -> std::fmt::Result {
370        if matches!(
371            list,
372            Variable::LocalMut { .. } | Variable::LocalConst { .. } | Variable::ConstantScalar(..)
373        ) {
374            return IndexVector::format(f, list, index, out);
375        }
376
377        if line_size > 0 {
378            let mut item = list.item();
379            item.vectorization = line_size as usize;
380            let addr_space = D::address_space_for_variable(list);
381            let qualifier = list.const_qualifier();
382            let tmp = Variable::tmp_declared(item);
383
384            writeln!(
385                f,
386                "{qualifier} {addr_space}{item} *{tmp} = reinterpret_cast<{qualifier} {item}*>({list});"
387            )?;
388
389            return Index::format(f, &tmp, index, out, 0);
390        }
391
392        let item_out = out.item();
393        if let Elem::Atomic(inner) = item_out.elem {
394            let addr_space = D::address_space_for_variable(list);
395            writeln!(f, "{addr_space}{inner}* {out} = &{list}[{index}];")
396        } else {
397            let out = out.fmt_left();
398            write!(f, "{out} = ")?;
399            Self::format_scalar(f, *list, *index, item_out)?;
400            f.write_str(";\n")
401        }
402    }
403
404    fn format_scalar<D: Dialect, Lhs, Rhs>(
405        f: &mut Formatter<'_>,
406        lhs: Lhs,
407        rhs: Rhs,
408        item_out: Item<D>,
409    ) -> std::fmt::Result
410    where
411        Lhs: Component<D>,
412        Rhs: Component<D>,
413    {
414        let item_lhs = lhs.item();
415
416        let format_vec = |f: &mut Formatter<'_>| {
417            writeln!(f, "{item_out}{{")?;
418            for i in 0..item_out.vectorization {
419                write!(f, "{}({lhs}[{rhs}].i_{i}),", item_out.elem)?;
420            }
421            f.write_str("}")?;
422
423            Ok(())
424        };
425
426        if item_out.elem != item_lhs.elem {
427            if item_out.vectorization > 1 {
428                format_vec(f)
429            } else {
430                write!(f, "{}({lhs}[{rhs}])", item_out.elem)
431            }
432        } else {
433            write!(f, "{lhs}[{rhs}]")
434        }
435    }
436}
437
438/// The goal is to support indexing of vectorized types.
439///
440/// # Examples
441///
442/// ```c
443/// float4 rhs;
444/// float item = var[0]; // We want that.
445/// float item = var.x; // So we compile to that.
446/// ```
447struct IndexVector<D: Dialect> {
448    _dialect: PhantomData<D>,
449}
450
451/// The goal is to support indexing of vectorized types.
452///
453/// # Examples
454///
455/// ```c
456/// float4 var;
457///
458/// var[0] = 1.0; // We want that.
459/// var.x = 1.0;  // So we compile to that.
460/// ```
461struct IndexAssignVector<D: Dialect> {
462    _dialect: PhantomData<D>,
463}
464
465impl<D: Dialect> IndexVector<D> {
466    fn format(
467        f: &mut Formatter<'_>,
468        lhs: &Variable<D>,
469        rhs: &Variable<D>,
470        out: &Variable<D>,
471    ) -> std::fmt::Result {
472        match rhs {
473            Variable::ConstantScalar(value, _elem) => {
474                let index = value.as_usize();
475                let out = out.index(index);
476                let lhs = lhs.index(index);
477                let out = out.fmt_left();
478                writeln!(f, "{out} = {lhs};")
479            }
480            _ => {
481                let elem = out.elem();
482                let qualifier = out.const_qualifier();
483                let addr_space = D::address_space_for_variable(out);
484                let out = out.fmt_left();
485                writeln!(
486                    f,
487                    "{out} = reinterpret_cast<{addr_space}{elem}{qualifier}*>(&{lhs})[{rhs}];"
488                )
489            }
490        }
491    }
492}
493
494impl<D: Dialect> IndexAssignVector<D> {
495    fn format(
496        f: &mut Formatter<'_>,
497        lhs: &Variable<D>,
498        rhs: &Variable<D>,
499        out: &Variable<D>,
500    ) -> std::fmt::Result {
501        let index = match lhs {
502            Variable::ConstantScalar(value, _) => value.as_usize(),
503            _ => {
504                let elem = out.elem();
505                let addr_space = D::address_space_for_variable(out);
506                return writeln!(f, "*(({addr_space}{elem}*)&{out} + {lhs}) = {rhs};");
507            }
508        };
509
510        let out = out.index(index);
511        let rhs = rhs.index(index);
512
513        writeln!(f, "{out} = {rhs};")
514    }
515}