cubecl_cpp/shared/
binary.rs

1use crate::shared::FmtLeft;
2
3use super::{Component, Dialect, Elem, Item, Variable};
4use std::{
5    fmt::{Display, Formatter},
6    marker::PhantomData,
7};
8
9pub trait Binary<D: Dialect> {
10    fn format(
11        f: &mut Formatter<'_>,
12        lhs: &Variable<D>,
13        rhs: &Variable<D>,
14        out: &Variable<D>,
15    ) -> std::fmt::Result {
16        let out_item = out.item();
17        if out.item().vectorization == 1 {
18            let out = out.fmt_left();
19            write!(f, "{out} = ")?;
20            Self::format_scalar(f, *lhs, *rhs, out_item)?;
21            f.write_str(";\n")
22        } else {
23            Self::unroll_vec(f, lhs, rhs, out)
24        }
25    }
26
27    fn format_scalar<Lhs, Rhs>(
28        f: &mut Formatter<'_>,
29        lhs: Lhs,
30        rhs: Rhs,
31        item: Item<D>,
32    ) -> std::fmt::Result
33    where
34        Lhs: Component<D>,
35        Rhs: Component<D>;
36
37    fn unroll_vec(
38        f: &mut Formatter<'_>,
39        lhs: &Variable<D>,
40        rhs: &Variable<D>,
41        out: &Variable<D>,
42    ) -> core::fmt::Result {
43        let optimized = Variable::optimized_args([*lhs, *rhs, *out]);
44        let [lhs, rhs, out_optimized] = optimized.args;
45
46        let item_out_original = out.item();
47        let item_out_optimized = out_optimized.item();
48
49        let index = match optimized.optimization_factor {
50            Some(factor) => item_out_original.vectorization / factor,
51            None => item_out_optimized.vectorization,
52        };
53
54        let mut write_op =
55            |lhs: &Variable<D>, rhs: &Variable<D>, out: &Variable<D>, item_out: Item<D>| {
56                let out = out.fmt_left();
57                writeln!(f, "{out} = {item_out}{{")?;
58                for i in 0..index {
59                    let lhsi = lhs.index(i);
60                    let rhsi = rhs.index(i);
61
62                    Self::format_scalar(f, lhsi, rhsi, item_out)?;
63                    f.write_str(", ")?;
64                }
65
66                f.write_str("};\n")
67            };
68
69        if item_out_original == item_out_optimized {
70            write_op(&lhs, &rhs, out, item_out_optimized)
71        } else {
72            let out_tmp = Variable::tmp(item_out_optimized);
73            write_op(&lhs, &rhs, &out_tmp, item_out_optimized)?;
74            let addr_space = D::address_space_for_variable(out);
75            let out = out.fmt_left();
76
77            writeln!(
78                f,
79                "{out} = reinterpret_cast<{addr_space}{item_out_original}&>({out_tmp});\n"
80            )?;
81
82            Ok(())
83        }
84    }
85}
86
// Declares a unit struct `$name` and implements `Binary` for it as the infix
// operator `$op`, i.e. the scalar expression `lhs $op rhs`.
macro_rules! operator {
    ($name:ident, $op:expr) => {
        pub struct $name;

        impl<D: Dialect> Binary<D> for $name {
            fn format_scalar<Lhs: Display, Rhs: Display>(
                f: &mut std::fmt::Formatter<'_>,
                lhs: Lhs,
                rhs: Rhs,
                out_item: Item<D>,
            ) -> std::fmt::Result {
                let elem = out_item.elem();
                // 8 and 16 bit integers are wrapped in an explicit cast to
                // keep auto-promotion from changing the type: fusion and
                // vectorization perform element-wise operations on vectorized
                // types, and the resulting elements must keep the same type.
                let narrow_int = matches!(
                    elem,
                    Elem::<D>::I16 | Elem::<D>::U16 | Elem::<D>::I8 | Elem::<D>::U8
                );

                if narrow_int {
                    write!(f, "{elem}({lhs} {} {rhs})", $op)
                } else {
                    write!(f, "{lhs} {} {rhs}", $op)
                }
            }
        }
    };
}
112
// Arithmetic operators.
operator!(Add, "+");
operator!(Sub, "-");
operator!(Div, "/");
operator!(Mul, "*");
operator!(Modulo, "%");
// Comparison operators.
operator!(Equal, "==");
operator!(NotEqual, "!=");
operator!(Lower, "<");
operator!(LowerEqual, "<=");
operator!(Greater, ">");
operator!(GreaterEqual, ">=");
// Shift and bitwise operators.
operator!(ShiftLeft, "<<");
operator!(ShiftRight, ">>");
operator!(BitwiseOr, "|");
operator!(BitwiseAnd, "&");
operator!(BitwiseXor, "^");
// Logical operators.
operator!(Or, "||");
operator!(And, "&&");
131
132pub struct FastDiv;
133
134impl<D: Dialect> Binary<D> for FastDiv {
135    fn format_scalar<Lhs: Display, Rhs: Display>(
136        f: &mut std::fmt::Formatter<'_>,
137        lhs: Lhs,
138        rhs: Rhs,
139        _out_item: Item<D>,
140    ) -> std::fmt::Result {
141        // f32 only
142        write!(f, "__fdividef({lhs}, {rhs})")
143    }
144}
145
146pub struct HiMul;
147
148impl<D: Dialect> Binary<D> for HiMul {
149    // Powf doesn't support half and no half equivalent exists
150    fn format_scalar<Lhs: Display, Rhs: Display>(
151        f: &mut std::fmt::Formatter<'_>,
152        lhs: Lhs,
153        rhs: Rhs,
154        out: Item<D>,
155    ) -> std::fmt::Result {
156        let out_elem = out.elem;
157        match out_elem {
158            Elem::I32 => write!(f, "__mulhi({lhs}, {rhs})"),
159            Elem::U32 => write!(f, "__umulhi({lhs}, {rhs})"),
160            Elem::I64 => write!(f, "__mul64hi({lhs}, {rhs})"),
161            Elem::U64 => write!(f, "__umul64hi({lhs}, {rhs})"),
162            _ => unimplemented!("HiMul only supports 32 and 64 bit ints"),
163        }
164    }
165
166    // Powf doesn't support half and no half equivalent exists
167    fn unroll_vec(
168        f: &mut Formatter<'_>,
169        lhs: &Variable<D>,
170        rhs: &Variable<D>,
171        out: &Variable<D>,
172    ) -> core::fmt::Result {
173        let item_out = out.item();
174        let index = out.item().vectorization;
175
176        let out = out.fmt_left();
177        writeln!(f, "{out} = {item_out}{{")?;
178        for i in 0..index {
179            let lhsi = lhs.index(i);
180            let rhsi = rhs.index(i);
181
182            Self::format_scalar(f, lhsi, rhsi, item_out)?;
183            f.write_str(", ")?;
184        }
185
186        f.write_str("};\n")
187    }
188}
189
190pub struct SaturatingAdd;
191
192impl<D: Dialect> Binary<D> for SaturatingAdd {
193    fn format_scalar<Lhs: Display, Rhs: Display>(
194        f: &mut std::fmt::Formatter<'_>,
195        lhs: Lhs,
196        rhs: Rhs,
197        out: Item<D>,
198    ) -> std::fmt::Result {
199        D::compile_saturating_add(f, lhs, rhs, out)
200    }
201}
202
203pub struct SaturatingSub;
204
205impl<D: Dialect> Binary<D> for SaturatingSub {
206    fn format_scalar<Lhs: Display, Rhs: Display>(
207        f: &mut std::fmt::Formatter<'_>,
208        lhs: Lhs,
209        rhs: Rhs,
210        out: Item<D>,
211    ) -> std::fmt::Result {
212        D::compile_saturating_sub(f, lhs, rhs, out)
213    }
214}
215
216pub struct Powf;
217
218impl<D: Dialect> Binary<D> for Powf {
219    // Powf doesn't support half and no half equivalent exists
220    fn format_scalar<Lhs: Display, Rhs: Display>(
221        f: &mut std::fmt::Formatter<'_>,
222        lhs: Lhs,
223        rhs: Rhs,
224        item: Item<D>,
225    ) -> std::fmt::Result {
226        let elem = item.elem;
227        let lhs = lhs.to_string();
228        let rhs = rhs.to_string();
229        match elem {
230            Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
231                let lhs = format!("float({lhs})");
232                let rhs = format!("float({rhs})");
233                write!(f, "{elem}(")?;
234                D::compile_instruction_powf(f, &lhs, &rhs, Elem::F32)?;
235                write!(f, ")")
236            }
237            _ => D::compile_instruction_powf(f, &lhs, &rhs, elem),
238        }
239    }
240
241    // Powf doesn't support half and no half equivalent exists
242    fn unroll_vec(
243        f: &mut Formatter<'_>,
244        lhs: &Variable<D>,
245        rhs: &Variable<D>,
246        out: &Variable<D>,
247    ) -> core::fmt::Result {
248        let item_out = out.item();
249        let index = out.item().vectorization;
250
251        let out = out.fmt_left();
252        writeln!(f, "{out} = {item_out}{{")?;
253        for i in 0..index {
254            let lhsi = lhs.index(i);
255            let rhsi = rhs.index(i);
256
257            Self::format_scalar(f, lhsi, rhsi, item_out)?;
258            f.write_str(", ")?;
259        }
260
261        f.write_str("};\n")
262    }
263}
264
265pub struct FastPowf;
266
267impl<D: Dialect> Binary<D> for FastPowf {
268    // Only executed for f32
269    fn format_scalar<Lhs: Display, Rhs: Display>(
270        f: &mut std::fmt::Formatter<'_>,
271        lhs: Lhs,
272        rhs: Rhs,
273        _item: Item<D>,
274    ) -> std::fmt::Result {
275        write!(f, "__powf({lhs}, {rhs})")
276    }
277}
278
279pub struct Powi;
280
281impl<D: Dialect> Binary<D> for Powi {
282    // Powi doesn't support half and no half equivalent exists
283    fn format_scalar<Lhs: Display, Rhs: Display>(
284        f: &mut std::fmt::Formatter<'_>,
285        lhs: Lhs,
286        rhs: Rhs,
287        item: Item<D>,
288    ) -> std::fmt::Result {
289        let elem = item.elem;
290        let lhs = lhs.to_string();
291        let rhs = rhs.to_string();
292        match elem {
293            Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
294                let lhs = format!("float({lhs})");
295
296                write!(f, "{elem}(")?;
297                D::compile_instruction_powf(f, &lhs, &rhs, Elem::F32)?;
298                write!(f, ")")
299            }
300            _ => D::compile_instruction_powf(f, &lhs, &rhs, elem),
301        }
302    }
303
304    // Powi doesn't support half and no half equivalent exists
305    fn unroll_vec(
306        f: &mut Formatter<'_>,
307        lhs: &Variable<D>,
308        rhs: &Variable<D>,
309        out: &Variable<D>,
310    ) -> core::fmt::Result {
311        let item_out = out.item();
312        let index = out.item().vectorization;
313
314        let out = out.fmt_left();
315        writeln!(f, "{out} = {item_out}{{")?;
316        for i in 0..index {
317            let lhsi = lhs.index(i);
318            let rhsi = rhs.index(i);
319
320            Self::format_scalar(f, lhsi, rhsi, item_out)?;
321            f.write_str(", ")?;
322        }
323
324        f.write_str("};\n")
325    }
326}
327
328pub struct Max;
329
330impl<D: Dialect> Binary<D> for Max {
331    fn format_scalar<Lhs: Display, Rhs: Display>(
332        f: &mut std::fmt::Formatter<'_>,
333        lhs: Lhs,
334        rhs: Rhs,
335        item: Item<D>,
336    ) -> std::fmt::Result {
337        D::compile_instruction_max_function_name(f, item)?;
338        write!(f, "({lhs}, {rhs})")
339    }
340}
341
342pub struct Min;
343
344impl<D: Dialect> Binary<D> for Min {
345    fn format_scalar<Lhs: Display, Rhs: Display>(
346        f: &mut std::fmt::Formatter<'_>,
347        lhs: Lhs,
348        rhs: Rhs,
349        item: Item<D>,
350    ) -> std::fmt::Result {
351        D::compile_instruction_min_function_name(f, item)?;
352        write!(f, "({lhs}, {rhs})")
353    }
354}
355
/// Formatter for indexed stores of the form `out_list[index] = value`.
pub struct IndexAssign;
/// Formatter for indexed loads of the form `out = list[index]`.
pub struct Index;
358
359impl IndexAssign {
360    pub fn format<D: Dialect>(
361        f: &mut Formatter<'_>,
362        index: &Variable<D>,
363        value: &Variable<D>,
364        out_list: &Variable<D>,
365        line_size: u32,
366    ) -> std::fmt::Result {
367        if matches!(
368            out_list,
369            Variable::LocalMut { .. } | Variable::LocalConst { .. }
370        ) {
371            return IndexAssignVector::format(f, index, value, out_list);
372        };
373
374        if line_size > 0 {
375            let mut item = out_list.item();
376            item.vectorization = line_size as usize;
377            let addr_space = D::address_space_for_variable(out_list);
378            let qualifier = out_list.const_qualifier();
379            let tmp = Variable::tmp_declared(item);
380
381            writeln!(
382                f,
383                "{qualifier} {addr_space}{item} *{tmp} = reinterpret_cast<{qualifier} {item}*>({out_list});"
384            )?;
385
386            return IndexAssign::format(f, index, value, &tmp, 0);
387        }
388
389        let out_item = out_list.item();
390
391        if index.item().vectorization == 1 {
392            write!(f, "{}[{index}] = ", out_list.fmt_left())?;
393            Self::format_scalar(f, *index, *value, out_item)?;
394            f.write_str(";\n")
395        } else {
396            Self::unroll_vec(f, index, value, out_list)
397        }
398    }
399    fn format_scalar<D: Dialect, Lhs, Rhs>(
400        f: &mut Formatter<'_>,
401        _lhs: Lhs,
402        rhs: Rhs,
403        item_out: Item<D>,
404    ) -> std::fmt::Result
405    where
406        Lhs: Component<D>,
407        Rhs: Component<D>,
408    {
409        let item_rhs = rhs.item();
410
411        let format_vec = |f: &mut Formatter<'_>, cast: bool| {
412            writeln!(f, "{item_out}{{")?;
413            for i in 0..item_out.vectorization {
414                if cast {
415                    writeln!(f, "{}({}),", item_out.elem, rhs.index(i))?;
416                } else {
417                    writeln!(f, "{},", rhs.index(i))?;
418                }
419            }
420            f.write_str("}")?;
421
422            Ok(())
423        };
424
425        if item_out.vectorization != item_rhs.vectorization {
426            format_vec(f, item_out != item_rhs)
427        } else if item_out.elem != item_rhs.elem {
428            if item_out.vectorization > 1 {
429                format_vec(f, true)?;
430            } else {
431                write!(f, "{}({rhs})", item_out.elem)?;
432            }
433            Ok(())
434        } else if rhs.is_const() && item_rhs.vectorization > 1 {
435            // Reinterpret cast in case rhs is optimized
436            write!(f, "reinterpret_cast<")?;
437            D::compile_local_memory_qualifier(f)?;
438            write!(f, " {item_out} const&>({rhs})")
439        } else {
440            write!(f, "{rhs}")
441        }
442    }
443
444    fn unroll_vec<D: Dialect>(
445        f: &mut Formatter<'_>,
446        lhs: &Variable<D>,
447        rhs: &Variable<D>,
448        out: &Variable<D>,
449    ) -> std::fmt::Result {
450        let item_lhs = lhs.item();
451        let out_item = out.item();
452        let out = out.fmt_left();
453
454        for i in 0..item_lhs.vectorization {
455            let lhsi = lhs.index(i);
456            let rhsi = rhs.index(i);
457            write!(f, "{out}[{lhs}] = ")?;
458            Self::format_scalar(f, lhsi, rhsi, out_item)?;
459            f.write_str(";\n")?;
460        }
461
462        Ok(())
463    }
464}
465
466impl Index {
467    pub(crate) fn format<D: Dialect>(
468        f: &mut Formatter<'_>,
469        list: &Variable<D>,
470        index: &Variable<D>,
471        out: &Variable<D>,
472        line_size: u32,
473    ) -> std::fmt::Result {
474        if matches!(
475            list,
476            Variable::LocalMut { .. } | Variable::LocalConst { .. } | Variable::ConstantScalar(..)
477        ) {
478            return IndexVector::format(f, list, index, out);
479        }
480
481        if line_size > 0 {
482            let mut item = list.item();
483            item.vectorization = line_size as usize;
484            let addr_space = D::address_space_for_variable(list);
485            let qualifier = list.const_qualifier();
486            let tmp = Variable::tmp_declared(item);
487
488            writeln!(
489                f,
490                "{qualifier} {addr_space}{item} *{tmp} = reinterpret_cast<{qualifier} {item}*>({list});"
491            )?;
492
493            return Index::format(f, &tmp, index, out, 0);
494        }
495
496        let item_out = out.item();
497        if let Elem::Atomic(inner) = item_out.elem {
498            let addr_space = D::address_space_for_variable(list);
499            writeln!(f, "{addr_space}{inner}* {out} = &{list}[{index}];")
500        } else {
501            let out = out.fmt_left();
502            write!(f, "{out} = ")?;
503            Self::format_scalar(f, *list, *index, item_out)?;
504            f.write_str(";\n")
505        }
506    }
507
508    fn format_scalar<D: Dialect, Lhs, Rhs>(
509        f: &mut Formatter<'_>,
510        lhs: Lhs,
511        rhs: Rhs,
512        item_out: Item<D>,
513    ) -> std::fmt::Result
514    where
515        Lhs: Component<D>,
516        Rhs: Component<D>,
517    {
518        let item_lhs = lhs.item();
519
520        let format_vec = |f: &mut Formatter<'_>| {
521            writeln!(f, "{item_out}{{")?;
522            for i in 0..item_out.vectorization {
523                write!(f, "{}({lhs}[{rhs}].i_{i}),", item_out.elem)?;
524            }
525            f.write_str("}")?;
526
527            Ok(())
528        };
529
530        if item_out.elem != item_lhs.elem {
531            if item_out.vectorization > 1 {
532                format_vec(f)
533            } else {
534                write!(f, "{}({lhs}[{rhs}])", item_out.elem)
535            }
536        } else {
537            write!(f, "{lhs}[{rhs}]")
538        }
539    }
540}
541
/// The goal is to support indexing of vectorized types.
///
/// # Examples
///
/// ```c
/// float4 var;
/// float item = var[0]; // We want that.
/// float item = var.x; // So we compile to that.
/// ```
struct IndexVector<D: Dialect> {
    // Ties the type to a dialect; no value is stored.
    _dialect: PhantomData<D>,
}
554
/// The goal is to support index assignment on vectorized types.
///
/// # Examples
///
/// ```c
/// float4 var;
///
/// var[0] = 1.0; // We want that.
/// var.x = 1.0;  // So we compile to that.
/// ```
struct IndexAssignVector<D: Dialect> {
    // Ties the type to a dialect; no value is stored.
    _dialect: PhantomData<D>,
}
568
569impl<D: Dialect> IndexVector<D> {
570    fn format(
571        f: &mut Formatter<'_>,
572        lhs: &Variable<D>,
573        rhs: &Variable<D>,
574        out: &Variable<D>,
575    ) -> std::fmt::Result {
576        match rhs {
577            Variable::ConstantScalar(value, _elem) => {
578                let index = value.as_usize();
579                let out = out.index(index);
580                let lhs = lhs.index(index);
581                let out = out.fmt_left();
582                writeln!(f, "{out} = {lhs};")
583            }
584            _ => {
585                let elem = out.elem();
586                let qualifier = out.const_qualifier();
587                let addr_space = D::address_space_for_variable(out);
588                let out = out.fmt_left();
589                writeln!(
590                    f,
591                    "{out} = reinterpret_cast<{addr_space}{elem}{qualifier}*>(&{lhs})[{rhs}];"
592                )
593            }
594        }
595    }
596}
597
598impl<D: Dialect> IndexAssignVector<D> {
599    fn format(
600        f: &mut Formatter<'_>,
601        lhs: &Variable<D>,
602        rhs: &Variable<D>,
603        out: &Variable<D>,
604    ) -> std::fmt::Result {
605        let index = match lhs {
606            Variable::ConstantScalar(value, _) => value.as_usize(),
607            _ => {
608                let elem = out.elem();
609                let addr_space = D::address_space_for_variable(out);
610                return writeln!(f, "*(({addr_space}{elem}*)&{out} + {lhs}) = {rhs};");
611            }
612        };
613
614        let out = out.index(index);
615        let rhs = rhs.index(index);
616
617        writeln!(f, "{out} = {rhs};")
618    }
619}