cubecl_cpp/shared/
binary.rs

1use crate::shared::FmtLeft;
2
3use super::{Component, Dialect, Elem, Item, Variable};
4use std::{
5    fmt::{Display, Formatter},
6    marker::PhantomData,
7};
8
9pub trait Binary<D: Dialect> {
10    fn format(
11        f: &mut Formatter<'_>,
12        lhs: &Variable<D>,
13        rhs: &Variable<D>,
14        out: &Variable<D>,
15    ) -> std::fmt::Result {
16        let out_item = out.item();
17        if out.item().vectorization == 1 {
18            let out = out.fmt_left();
19            write!(f, "{out} = ")?;
20            Self::format_scalar(f, *lhs, *rhs, out_item)?;
21            f.write_str(";\n")
22        } else {
23            Self::unroll_vec(f, lhs, rhs, out)
24        }
25    }
26
27    fn format_scalar<Lhs, Rhs>(
28        f: &mut Formatter<'_>,
29        lhs: Lhs,
30        rhs: Rhs,
31        item: Item<D>,
32    ) -> std::fmt::Result
33    where
34        Lhs: Component<D>,
35        Rhs: Component<D>;
36
37    fn unroll_vec(
38        f: &mut Formatter<'_>,
39        lhs: &Variable<D>,
40        rhs: &Variable<D>,
41        out: &Variable<D>,
42    ) -> core::fmt::Result {
43        let optimized = Variable::optimized_args([*lhs, *rhs, *out]);
44        let [lhs, rhs, out_optimized] = optimized.args;
45
46        let item_out_original = out.item();
47        let item_out_optimized = out_optimized.item();
48
49        let index = match optimized.optimization_factor {
50            Some(factor) => item_out_original.vectorization / factor,
51            None => item_out_optimized.vectorization,
52        };
53
54        let mut write_op =
55            |lhs: &Variable<D>, rhs: &Variable<D>, out: &Variable<D>, item_out: Item<D>| {
56                let out = out.fmt_left();
57                writeln!(f, "{out} = {item_out}{{")?;
58                for i in 0..index {
59                    let lhsi = lhs.index(i);
60                    let rhsi = rhs.index(i);
61
62                    Self::format_scalar(f, lhsi, rhsi, item_out)?;
63                    f.write_str(", ")?;
64                }
65
66                f.write_str("};\n")
67            };
68
69        if item_out_original == item_out_optimized {
70            write_op(&lhs, &rhs, out, item_out_optimized)
71        } else {
72            let out_tmp = Variable::tmp(item_out_optimized);
73            write_op(&lhs, &rhs, &out_tmp, item_out_optimized)?;
74            let addr_space = D::address_space_for_variable(out);
75            let out = out.fmt_left();
76
77            writeln!(
78                f,
79                "{out} = reinterpret_cast<{addr_space}{item_out_original}&>({out_tmp});\n"
80            )?;
81
82            Ok(())
83        }
84    }
85}
86
87macro_rules! operator {
88    ($name:ident, $op:expr) => {
89        pub struct $name;
90
91        impl<D: Dialect> Binary<D> for $name {
92            fn format_scalar<Lhs: Display, Rhs: Display>(
93                f: &mut std::fmt::Formatter<'_>,
94                lhs: Lhs,
95                rhs: Rhs,
96                out_item: Item<D>,
97            ) -> std::fmt::Result {
98                let out_elem = out_item.elem();
99                match out_elem {
100                    // prevent auto-promotion rules to kick-in in order to stay in the same type
101                    // this is because of fusion and vectorization that can do elemwise operations on vectorized type,
102                    // the resulting elements need to be of the same type.
103                    Elem::<D>::I16 | Elem::<D>::U16 | Elem::<D>::I8 | Elem::<D>::U8 => {
104                        write!(f, "{out_elem}({lhs} {} {rhs})", $op)
105                    }
106                    _ => write!(f, "{lhs} {} {rhs}", $op),
107                }
108            }
109        }
110    };
111}
112
113operator!(Add, "+");
114operator!(Sub, "-");
115operator!(Div, "/");
116operator!(Mul, "*");
117operator!(Modulo, "%");
118operator!(Equal, "==");
119operator!(NotEqual, "!=");
120operator!(Lower, "<");
121operator!(LowerEqual, "<=");
122operator!(Greater, ">");
123operator!(GreaterEqual, ">=");
124operator!(ShiftLeft, "<<");
125operator!(ShiftRight, ">>");
126operator!(BitwiseOr, "|");
127operator!(BitwiseAnd, "&");
128operator!(BitwiseXor, "^");
129operator!(Or, "||");
130operator!(And, "&&");
131
132pub struct FastDiv;
133
134impl<D: Dialect> Binary<D> for FastDiv {
135    fn format_scalar<Lhs: Display, Rhs: Display>(
136        f: &mut std::fmt::Formatter<'_>,
137        lhs: Lhs,
138        rhs: Rhs,
139        _out_item: Item<D>,
140    ) -> std::fmt::Result {
141        // f32 only
142        write!(f, "__fdividef({lhs}, {rhs})")
143    }
144}
145
146pub struct HiMul;
147
148impl<D: Dialect> Binary<D> for HiMul {
149    // Powf doesn't support half and no half equivalent exists
150    fn format_scalar<Lhs: Display, Rhs: Display>(
151        f: &mut std::fmt::Formatter<'_>,
152        lhs: Lhs,
153        rhs: Rhs,
154        out: Item<D>,
155    ) -> std::fmt::Result {
156        let out_elem = out.elem;
157        match out_elem {
158            Elem::I32 => write!(f, "__mulhi({lhs}, {rhs})"),
159            Elem::U32 => write!(f, "__umulhi({lhs}, {rhs})"),
160            Elem::I64 => write!(f, "__mul64hi({lhs}, {rhs})"),
161            Elem::U64 => write!(f, "__umul64hi({lhs}, {rhs})"),
162            _ => writeln!(f, "#error HiMul only supports 32 and 64 bit ints"),
163        }
164    }
165
166    // Powf doesn't support half and no half equivalent exists
167    fn unroll_vec(
168        f: &mut Formatter<'_>,
169        lhs: &Variable<D>,
170        rhs: &Variable<D>,
171        out: &Variable<D>,
172    ) -> core::fmt::Result {
173        let item_out = out.item();
174        let index = out.item().vectorization;
175
176        let out = out.fmt_left();
177        writeln!(f, "{out} = {item_out}{{")?;
178        for i in 0..index {
179            let lhsi = lhs.index(i);
180            let rhsi = rhs.index(i);
181
182            Self::format_scalar(f, lhsi, rhsi, item_out)?;
183            f.write_str(", ")?;
184        }
185
186        f.write_str("};\n")
187    }
188}
189
190pub struct SaturatingAdd;
191
192impl<D: Dialect> Binary<D> for SaturatingAdd {
193    fn format_scalar<Lhs: Display, Rhs: Display>(
194        f: &mut std::fmt::Formatter<'_>,
195        lhs: Lhs,
196        rhs: Rhs,
197        out: Item<D>,
198    ) -> std::fmt::Result {
199        D::compile_saturating_add(f, lhs, rhs, out)
200    }
201}
202
203pub struct SaturatingSub;
204
205impl<D: Dialect> Binary<D> for SaturatingSub {
206    fn format_scalar<Lhs: Display, Rhs: Display>(
207        f: &mut std::fmt::Formatter<'_>,
208        lhs: Lhs,
209        rhs: Rhs,
210        out: Item<D>,
211    ) -> std::fmt::Result {
212        D::compile_saturating_sub(f, lhs, rhs, out)
213    }
214}
215
216pub struct Powf;
217
218impl<D: Dialect> Binary<D> for Powf {
219    // Powf doesn't support half and no half equivalent exists
220    fn format_scalar<Lhs: Display, Rhs: Display>(
221        f: &mut std::fmt::Formatter<'_>,
222        lhs: Lhs,
223        rhs: Rhs,
224        item: Item<D>,
225    ) -> std::fmt::Result {
226        let elem = item.elem;
227        let lhs = lhs.to_string();
228        let rhs = rhs.to_string();
229        match elem {
230            Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
231                let lhs = format!("float({lhs})");
232                let rhs = format!("float({rhs})");
233                write!(f, "{elem}(")?;
234                D::compile_instruction_powf(f, &lhs, &rhs, Elem::F32)?;
235                write!(f, ")")
236            }
237            _ => D::compile_instruction_powf(f, &lhs, &rhs, elem),
238        }
239    }
240
241    // Powf doesn't support half and no half equivalent exists
242    fn unroll_vec(
243        f: &mut Formatter<'_>,
244        lhs: &Variable<D>,
245        rhs: &Variable<D>,
246        out: &Variable<D>,
247    ) -> core::fmt::Result {
248        let item_out = out.item();
249        let index = out.item().vectorization;
250
251        let out = out.fmt_left();
252        writeln!(f, "{out} = {item_out}{{")?;
253        for i in 0..index {
254            let lhsi = lhs.index(i);
255            let rhsi = rhs.index(i);
256
257            Self::format_scalar(f, lhsi, rhsi, item_out)?;
258            f.write_str(", ")?;
259        }
260
261        f.write_str("};\n")
262    }
263}
264
265pub struct FastPowf;
266
267impl<D: Dialect> Binary<D> for FastPowf {
268    // Only executed for f32
269    fn format_scalar<Lhs: Display, Rhs: Display>(
270        f: &mut std::fmt::Formatter<'_>,
271        lhs: Lhs,
272        rhs: Rhs,
273        _item: Item<D>,
274    ) -> std::fmt::Result {
275        write!(f, "__powf({lhs}, {rhs})")
276    }
277}
278
279pub struct Powi;
280
281impl<D: Dialect> Binary<D> for Powi {
282    // Powi doesn't support half and no half equivalent exists
283    fn format_scalar<Lhs: Display, Rhs: Display>(
284        f: &mut std::fmt::Formatter<'_>,
285        lhs: Lhs,
286        rhs: Rhs,
287        item: Item<D>,
288    ) -> std::fmt::Result {
289        let elem = item.elem;
290        let lhs = lhs.to_string();
291        let rhs = rhs.to_string();
292        match elem {
293            Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
294                let lhs = format!("float({lhs})");
295
296                write!(f, "{elem}(")?;
297                D::compile_instruction_powf(f, &lhs, &rhs, Elem::F32)?;
298                write!(f, ")")
299            }
300            _ => D::compile_instruction_powf(f, &lhs, &rhs, elem),
301        }
302    }
303
304    // Powi doesn't support half and no half equivalent exists
305    fn unroll_vec(
306        f: &mut Formatter<'_>,
307        lhs: &Variable<D>,
308        rhs: &Variable<D>,
309        out: &Variable<D>,
310    ) -> core::fmt::Result {
311        let item_out = out.item();
312        let index = out.item().vectorization;
313
314        let out = out.fmt_left();
315        writeln!(f, "{out} = {item_out}{{")?;
316        for i in 0..index {
317            let lhsi = lhs.index(i);
318            let rhsi = rhs.index(i);
319
320            Self::format_scalar(f, lhsi, rhsi, item_out)?;
321            f.write_str(", ")?;
322        }
323
324        f.write_str("};\n")
325    }
326}
327
328pub struct ArcTan2;
329
330impl<D: Dialect> Binary<D> for ArcTan2 {
331    // ArcTan2 doesn't support half and no half equivalent exists
332    fn format_scalar<Lhs: Display, Rhs: Display>(
333        f: &mut std::fmt::Formatter<'_>,
334        lhs: Lhs,
335        rhs: Rhs,
336        item: Item<D>,
337    ) -> std::fmt::Result {
338        let elem = item.elem;
339        match elem {
340            Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
341                write!(f, "{elem}(atan2(float({lhs}), float({rhs})))")
342            }
343            _ => {
344                write!(f, "atan2({lhs}, {rhs})")
345            }
346        }
347    }
348
349    // ArcTan2 doesn't support half and no half equivalent exists
350    fn unroll_vec(
351        f: &mut Formatter<'_>,
352        lhs: &Variable<D>,
353        rhs: &Variable<D>,
354        out: &Variable<D>,
355    ) -> core::fmt::Result {
356        let item_out = out.item();
357        let index = out.item().vectorization;
358
359        let out = out.fmt_left();
360        writeln!(f, "{out} = {item_out}{{")?;
361        for i in 0..index {
362            let lhsi = lhs.index(i);
363            let rhsi = rhs.index(i);
364
365            Self::format_scalar(f, lhsi, rhsi, item_out)?;
366            f.write_str(", ")?;
367        }
368
369        f.write_str("};\n")
370    }
371}
372
373pub struct Max;
374
375impl<D: Dialect> Binary<D> for Max {
376    fn format_scalar<Lhs: Display, Rhs: Display>(
377        f: &mut std::fmt::Formatter<'_>,
378        lhs: Lhs,
379        rhs: Rhs,
380        item: Item<D>,
381    ) -> std::fmt::Result {
382        D::compile_instruction_max_function_name(f, item)?;
383        write!(f, "({lhs}, {rhs})")
384    }
385}
386
387pub struct Min;
388
389impl<D: Dialect> Binary<D> for Min {
390    fn format_scalar<Lhs: Display, Rhs: Display>(
391        f: &mut std::fmt::Formatter<'_>,
392        lhs: Lhs,
393        rhs: Rhs,
394        item: Item<D>,
395    ) -> std::fmt::Result {
396        D::compile_instruction_min_function_name(f, item)?;
397        write!(f, "({lhs}, {rhs})")
398    }
399}
400
401pub struct IndexAssign;
402pub struct Index;
403
404impl IndexAssign {
405    pub fn format<D: Dialect>(
406        f: &mut Formatter<'_>,
407        index: &Variable<D>,
408        value: &Variable<D>,
409        out_list: &Variable<D>,
410        line_size: u32,
411    ) -> std::fmt::Result {
412        if matches!(
413            out_list,
414            Variable::LocalMut { .. } | Variable::LocalConst { .. }
415        ) {
416            return IndexAssignVector::format(f, index, value, out_list);
417        };
418
419        if line_size > 0 {
420            let mut item = out_list.item();
421            item.vectorization = line_size as usize;
422            let addr_space = D::address_space_for_variable(out_list);
423            let qualifier = out_list.const_qualifier();
424            let tmp = Variable::tmp_declared(item);
425
426            writeln!(
427                f,
428                "{qualifier} {addr_space}{item} *{tmp} = reinterpret_cast<{qualifier} {item}*>({out_list});"
429            )?;
430
431            return IndexAssign::format(f, index, value, &tmp, 0);
432        }
433
434        let out_item = out_list.item();
435
436        if index.item().vectorization == 1 {
437            write!(f, "{}[{index}] = ", out_list.fmt_left())?;
438            Self::format_scalar(f, *index, *value, out_item)?;
439            f.write_str(";\n")
440        } else {
441            Self::unroll_vec(f, index, value, out_list)
442        }
443    }
444    fn format_scalar<D: Dialect, Lhs, Rhs>(
445        f: &mut Formatter<'_>,
446        _lhs: Lhs,
447        rhs: Rhs,
448        item_out: Item<D>,
449    ) -> std::fmt::Result
450    where
451        Lhs: Component<D>,
452        Rhs: Component<D>,
453    {
454        let item_rhs = rhs.item();
455
456        let format_vec = |f: &mut Formatter<'_>, cast: bool| {
457            writeln!(f, "{item_out}{{")?;
458            for i in 0..item_out.vectorization {
459                if cast {
460                    writeln!(f, "{}({}),", item_out.elem, rhs.index(i))?;
461                } else {
462                    writeln!(f, "{},", rhs.index(i))?;
463                }
464            }
465            f.write_str("}")?;
466
467            Ok(())
468        };
469
470        if item_out.vectorization != item_rhs.vectorization {
471            format_vec(f, item_out != item_rhs)
472        } else if item_out.elem != item_rhs.elem {
473            if item_out.vectorization > 1 {
474                format_vec(f, true)?;
475            } else {
476                write!(f, "{}({rhs})", item_out.elem)?;
477            }
478            Ok(())
479        } else if rhs.is_const() && item_rhs.vectorization > 1 {
480            // Reinterpret cast in case rhs is optimized
481            write!(f, "reinterpret_cast<")?;
482            D::compile_local_memory_qualifier(f)?;
483            write!(f, " {item_out} const&>({rhs})")
484        } else {
485            write!(f, "{rhs}")
486        }
487    }
488
489    fn unroll_vec<D: Dialect>(
490        f: &mut Formatter<'_>,
491        lhs: &Variable<D>,
492        rhs: &Variable<D>,
493        out: &Variable<D>,
494    ) -> std::fmt::Result {
495        let item_lhs = lhs.item();
496        let out_item = out.item();
497        let out = out.fmt_left();
498
499        for i in 0..item_lhs.vectorization {
500            let lhsi = lhs.index(i);
501            let rhsi = rhs.index(i);
502            write!(f, "{out}[{lhs}] = ")?;
503            Self::format_scalar(f, lhsi, rhsi, out_item)?;
504            f.write_str(";\n")?;
505        }
506
507        Ok(())
508    }
509}
510
511impl Index {
512    pub(crate) fn format<D: Dialect>(
513        f: &mut Formatter<'_>,
514        list: &Variable<D>,
515        index: &Variable<D>,
516        out: &Variable<D>,
517        line_size: u32,
518    ) -> std::fmt::Result {
519        if matches!(
520            list,
521            Variable::LocalMut { .. } | Variable::LocalConst { .. } | Variable::ConstantScalar(..)
522        ) {
523            return IndexVector::format(f, list, index, out);
524        }
525
526        if line_size > 0 {
527            let mut item = list.item();
528            item.vectorization = line_size as usize;
529            let addr_space = D::address_space_for_variable(list);
530            let qualifier = list.const_qualifier();
531            let tmp = Variable::tmp_declared(item);
532
533            writeln!(
534                f,
535                "{qualifier} {addr_space}{item} *{tmp} = reinterpret_cast<{qualifier} {item}*>({list});"
536            )?;
537
538            return Index::format(f, &tmp, index, out, 0);
539        }
540
541        let item_out = out.item();
542        if let Elem::Atomic(inner) = item_out.elem {
543            let addr_space = D::address_space_for_variable(list);
544            writeln!(f, "{addr_space}{inner}* {out} = &{list}[{index}];")
545        } else {
546            let out = out.fmt_left();
547            write!(f, "{out} = ")?;
548            Self::format_scalar(f, *list, *index, item_out)?;
549            f.write_str(";\n")
550        }
551    }
552
553    fn format_scalar<D: Dialect, Lhs, Rhs>(
554        f: &mut Formatter<'_>,
555        lhs: Lhs,
556        rhs: Rhs,
557        item_out: Item<D>,
558    ) -> std::fmt::Result
559    where
560        Lhs: Component<D>,
561        Rhs: Component<D>,
562    {
563        let item_lhs = lhs.item();
564
565        let format_vec = |f: &mut Formatter<'_>| {
566            writeln!(f, "{item_out}{{")?;
567            for i in 0..item_out.vectorization {
568                write!(f, "{}({lhs}[{rhs}].i_{i}),", item_out.elem)?;
569            }
570            f.write_str("}")?;
571
572            Ok(())
573        };
574
575        if item_out.elem != item_lhs.elem {
576            if item_out.vectorization > 1 {
577                format_vec(f)
578            } else {
579                write!(f, "{}({lhs}[{rhs}])", item_out.elem)
580            }
581        } else {
582            write!(f, "{lhs}[{rhs}]")
583        }
584    }
585}
586
587/// The goal is to support indexing of vectorized types.
588///
589/// # Examples
590///
591/// ```c
592/// float4 rhs;
593/// float item = var[0]; // We want that.
594/// float item = var.x; // So we compile to that.
595/// ```
596struct IndexVector<D: Dialect> {
597    _dialect: PhantomData<D>,
598}
599
600/// The goal is to support indexing of vectorized types.
601///
602/// # Examples
603///
604/// ```c
605/// float4 var;
606///
607/// var[0] = 1.0; // We want that.
608/// var.x = 1.0;  // So we compile to that.
609/// ```
610struct IndexAssignVector<D: Dialect> {
611    _dialect: PhantomData<D>,
612}
613
614impl<D: Dialect> IndexVector<D> {
615    fn format(
616        f: &mut Formatter<'_>,
617        lhs: &Variable<D>,
618        rhs: &Variable<D>,
619        out: &Variable<D>,
620    ) -> std::fmt::Result {
621        match rhs {
622            Variable::ConstantScalar(value, _elem) => {
623                let index = value.as_usize();
624                let out = out.index(index);
625                let lhs = lhs.index(index);
626                let out = out.fmt_left();
627                writeln!(f, "{out} = {lhs};")
628            }
629            _ => {
630                let elem = out.elem();
631                let qualifier = out.const_qualifier();
632                let addr_space = D::address_space_for_variable(out);
633                let out = out.fmt_left();
634                writeln!(
635                    f,
636                    "{out} = reinterpret_cast<{addr_space}{elem}{qualifier}*>(&{lhs})[{rhs}];"
637                )
638            }
639        }
640    }
641}
642
643impl<D: Dialect> IndexAssignVector<D> {
644    fn format(
645        f: &mut Formatter<'_>,
646        lhs: &Variable<D>,
647        rhs: &Variable<D>,
648        out: &Variable<D>,
649    ) -> std::fmt::Result {
650        let index = match lhs {
651            Variable::ConstantScalar(value, _) => value.as_usize(),
652            _ => {
653                let elem = out.elem();
654                let addr_space = D::address_space_for_variable(out);
655                return writeln!(f, "*(({addr_space}{elem}*)&{out} + {lhs}) = {rhs};");
656            }
657        };
658
659        let out = out.index(index);
660        let rhs = rhs.index(index);
661
662        writeln!(f, "{out} = {rhs};")
663    }
664}