cubecl_cpp/shared/binary.rs

1use crate::shared::FmtLeft;
2
3use super::{Component, Dialect, Elem, Item, Variable};
4use std::{
5    fmt::{Display, Formatter},
6    marker::PhantomData,
7};
8
9pub trait Binary<D: Dialect> {
10    fn format(
11        f: &mut Formatter<'_>,
12        lhs: &Variable<D>,
13        rhs: &Variable<D>,
14        out: &Variable<D>,
15    ) -> std::fmt::Result {
16        let out_item = out.item();
17        if out.item().vectorization == 1 {
18            let out = out.fmt_left();
19            write!(f, "{out} = ")?;
20            Self::format_scalar(f, *lhs, *rhs, out_item)?;
21            f.write_str(";\n")
22        } else {
23            Self::unroll_vec(f, lhs, rhs, out)
24        }
25    }
26
27    fn format_scalar<Lhs, Rhs>(
28        f: &mut Formatter<'_>,
29        lhs: Lhs,
30        rhs: Rhs,
31        item: Item<D>,
32    ) -> std::fmt::Result
33    where
34        Lhs: Component<D>,
35        Rhs: Component<D>;
36
37    fn unroll_vec(
38        f: &mut Formatter<'_>,
39        lhs: &Variable<D>,
40        rhs: &Variable<D>,
41        out: &Variable<D>,
42    ) -> core::fmt::Result {
43        let optimized = Variable::optimized_args([*lhs, *rhs, *out]);
44        let [lhs, rhs, out_optimized] = optimized.args;
45
46        let item_out_original = out.item();
47        let item_out_optimized = out_optimized.item();
48
49        let index = match optimized.optimization_factor {
50            Some(factor) => item_out_original.vectorization / factor,
51            None => item_out_optimized.vectorization,
52        };
53
54        let mut write_op =
55            |lhs: &Variable<D>, rhs: &Variable<D>, out: &Variable<D>, item_out: Item<D>| {
56                let out = out.fmt_left();
57                writeln!(f, "{out} = {item_out}{{")?;
58                for i in 0..index {
59                    let lhsi = lhs.index(i);
60                    let rhsi = rhs.index(i);
61
62                    Self::format_scalar(f, lhsi, rhsi, item_out)?;
63                    f.write_str(", ")?;
64                }
65
66                f.write_str("};\n")
67            };
68
69        if item_out_original == item_out_optimized {
70            write_op(&lhs, &rhs, out, item_out_optimized)
71        } else {
72            let out_tmp = Variable::tmp(item_out_optimized);
73            write_op(&lhs, &rhs, &out_tmp, item_out_optimized)?;
74            let addr_space = D::address_space_for_variable(out);
75            let out = out.fmt_left();
76
77            writeln!(
78                f,
79                "{out} = reinterpret_cast<{addr_space}{item_out_original}&>({out_tmp});\n"
80            )?;
81
82            Ok(())
83        }
84    }
85}
86
87macro_rules! operator {
88    ($name:ident, $op:expr) => {
89        pub struct $name;
90
91        impl<D: Dialect> Binary<D> for $name {
92            fn format_scalar<Lhs: Display, Rhs: Display>(
93                f: &mut std::fmt::Formatter<'_>,
94                lhs: Lhs,
95                rhs: Rhs,
96                out_item: Item<D>,
97            ) -> std::fmt::Result {
98                let out_elem = out_item.elem();
99                match out_elem {
100                    // prevent auto-promotion rules to kick-in in order to stay in the same type
101                    // this is because of fusion and vectorization that can do elemwise operations on vectorized type,
102                    // the resulting elements need to be of the same type.
103                    Elem::<D>::I16 | Elem::<D>::U16 | Elem::<D>::I8 | Elem::<D>::U8 => {
104                        write!(f, "{out_elem}({lhs} {} {rhs})", $op)
105                    }
106                    _ => write!(f, "{lhs} {} {rhs}", $op),
107                }
108            }
109        }
110    };
111}
112
113operator!(Add, "+");
114operator!(Sub, "-");
115operator!(Div, "/");
116operator!(Mul, "*");
117operator!(Modulo, "%");
118operator!(Equal, "==");
119operator!(NotEqual, "!=");
120operator!(Lower, "<");
121operator!(LowerEqual, "<=");
122operator!(Greater, ">");
123operator!(GreaterEqual, ">=");
124operator!(ShiftLeft, "<<");
125operator!(ShiftRight, ">>");
126operator!(BitwiseOr, "|");
127operator!(BitwiseAnd, "&");
128operator!(BitwiseXor, "^");
129operator!(Or, "||");
130operator!(And, "&&");
131
132pub struct HiMul;
133
134impl<D: Dialect> Binary<D> for HiMul {
135    // Powf doesn't support half and no half equivalent exists
136    fn format_scalar<Lhs: Display, Rhs: Display>(
137        f: &mut std::fmt::Formatter<'_>,
138        lhs: Lhs,
139        rhs: Rhs,
140        out: Item<D>,
141    ) -> std::fmt::Result {
142        let out_elem = out.elem;
143        match out_elem {
144            Elem::I32 => write!(f, "__mulhi({lhs}, {rhs})"),
145            Elem::U32 => write!(f, "__umulhi({lhs}, {rhs})"),
146            Elem::I64 => write!(f, "__mul64hi({lhs}, {rhs})"),
147            Elem::U64 => write!(f, "__umul64hi({lhs}, {rhs})"),
148            _ => unimplemented!("HiMul only supports 32 and 64 bit ints"),
149        }
150    }
151
152    // Powf doesn't support half and no half equivalent exists
153    fn unroll_vec(
154        f: &mut Formatter<'_>,
155        lhs: &Variable<D>,
156        rhs: &Variable<D>,
157        out: &Variable<D>,
158    ) -> core::fmt::Result {
159        let item_out = out.item();
160        let index = out.item().vectorization;
161
162        let out = out.fmt_left();
163        writeln!(f, "{out} = {item_out}{{")?;
164        for i in 0..index {
165            let lhsi = lhs.index(i);
166            let rhsi = rhs.index(i);
167
168            Self::format_scalar(f, lhsi, rhsi, item_out)?;
169            f.write_str(", ")?;
170        }
171
172        f.write_str("};\n")
173    }
174}
175
176pub struct SaturatingAdd;
177
178impl<D: Dialect> Binary<D> for SaturatingAdd {
179    fn format_scalar<Lhs: Display, Rhs: Display>(
180        f: &mut std::fmt::Formatter<'_>,
181        lhs: Lhs,
182        rhs: Rhs,
183        out: Item<D>,
184    ) -> std::fmt::Result {
185        D::compile_saturating_add(f, lhs, rhs, out)
186    }
187}
188
189pub struct SaturatingSub;
190
191impl<D: Dialect> Binary<D> for SaturatingSub {
192    fn format_scalar<Lhs: Display, Rhs: Display>(
193        f: &mut std::fmt::Formatter<'_>,
194        lhs: Lhs,
195        rhs: Rhs,
196        out: Item<D>,
197    ) -> std::fmt::Result {
198        D::compile_saturating_sub(f, lhs, rhs, out)
199    }
200}
201
202pub struct Powf;
203
204impl<D: Dialect> Binary<D> for Powf {
205    // Powf doesn't support half and no half equivalent exists
206    fn format_scalar<Lhs: Display, Rhs: Display>(
207        f: &mut std::fmt::Formatter<'_>,
208        lhs: Lhs,
209        rhs: Rhs,
210        item: Item<D>,
211    ) -> std::fmt::Result {
212        let elem = item.elem;
213        let lhs = lhs.to_string();
214        let rhs = rhs.to_string();
215        match elem {
216            Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
217                let lhs = format!("float({lhs})");
218                let rhs = format!("float({rhs})");
219                write!(f, "{elem}(")?;
220                D::compile_instruction_powf(f, &lhs, &rhs, Elem::F32)?;
221                write!(f, ")")
222            }
223            _ => D::compile_instruction_powf(f, &lhs, &rhs, elem),
224        }
225    }
226
227    // Powf doesn't support half and no half equivalent exists
228    fn unroll_vec(
229        f: &mut Formatter<'_>,
230        lhs: &Variable<D>,
231        rhs: &Variable<D>,
232        out: &Variable<D>,
233    ) -> core::fmt::Result {
234        let item_out = out.item();
235        let index = out.item().vectorization;
236
237        let out = out.fmt_left();
238        writeln!(f, "{out} = {item_out}{{")?;
239        for i in 0..index {
240            let lhsi = lhs.index(i);
241            let rhsi = rhs.index(i);
242
243            Self::format_scalar(f, lhsi, rhsi, item_out)?;
244            f.write_str(", ")?;
245        }
246
247        f.write_str("};\n")
248    }
249}
250
251pub struct Powi;
252
253impl<D: Dialect> Binary<D> for Powi {
254    // Powi doesn't support half and no half equivalent exists
255    fn format_scalar<Lhs: Display, Rhs: Display>(
256        f: &mut std::fmt::Formatter<'_>,
257        lhs: Lhs,
258        rhs: Rhs,
259        item: Item<D>,
260    ) -> std::fmt::Result {
261        let elem = item.elem;
262        let lhs = lhs.to_string();
263        let rhs = rhs.to_string();
264        match elem {
265            Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
266                let lhs = format!("float({lhs})");
267
268                write!(f, "{elem}(")?;
269                D::compile_instruction_powf(f, &lhs, &rhs, Elem::F32)?;
270                write!(f, ")")
271            }
272            _ => D::compile_instruction_powf(f, &lhs, &rhs, elem),
273        }
274    }
275
276    // Powi doesn't support half and no half equivalent exists
277    fn unroll_vec(
278        f: &mut Formatter<'_>,
279        lhs: &Variable<D>,
280        rhs: &Variable<D>,
281        out: &Variable<D>,
282    ) -> core::fmt::Result {
283        let item_out = out.item();
284        let index = out.item().vectorization;
285
286        let out = out.fmt_left();
287        writeln!(f, "{out} = {item_out}{{")?;
288        for i in 0..index {
289            let lhsi = lhs.index(i);
290            let rhsi = rhs.index(i);
291
292            Self::format_scalar(f, lhsi, rhsi, item_out)?;
293            f.write_str(", ")?;
294        }
295
296        f.write_str("};\n")
297    }
298}
299
300pub struct Max;
301
302impl<D: Dialect> Binary<D> for Max {
303    fn format_scalar<Lhs: Display, Rhs: Display>(
304        f: &mut std::fmt::Formatter<'_>,
305        lhs: Lhs,
306        rhs: Rhs,
307        item: Item<D>,
308    ) -> std::fmt::Result {
309        D::compile_instruction_max_function_name(f, item)?;
310        write!(f, "({lhs}, {rhs})")
311    }
312}
313
314pub struct Min;
315
316impl<D: Dialect> Binary<D> for Min {
317    fn format_scalar<Lhs: Display, Rhs: Display>(
318        f: &mut std::fmt::Formatter<'_>,
319        lhs: Lhs,
320        rhs: Rhs,
321        item: Item<D>,
322    ) -> std::fmt::Result {
323        D::compile_instruction_min_function_name(f, item)?;
324        write!(f, "({lhs}, {rhs})")
325    }
326}
327
328pub struct IndexAssign;
329pub struct Index;
330
331impl IndexAssign {
332    pub fn format<D: Dialect>(
333        f: &mut Formatter<'_>,
334        index: &Variable<D>,
335        value: &Variable<D>,
336        out_list: &Variable<D>,
337        line_size: u32,
338    ) -> std::fmt::Result {
339        if matches!(
340            out_list,
341            Variable::LocalMut { .. } | Variable::LocalConst { .. }
342        ) {
343            return IndexAssignVector::format(f, index, value, out_list);
344        };
345
346        if line_size > 0 {
347            let mut item = out_list.item();
348            item.vectorization = line_size as usize;
349            let addr_space = D::address_space_for_variable(out_list);
350            let qualifier = out_list.const_qualifier();
351            let tmp = Variable::tmp_declared(item);
352
353            writeln!(
354                f,
355                "{qualifier} {addr_space}{item} *{tmp} = reinterpret_cast<{qualifier} {item}*>({out_list});"
356            )?;
357
358            return IndexAssign::format(f, index, value, &tmp, 0);
359        }
360
361        let out_item = out_list.item();
362
363        if index.item().vectorization == 1 {
364            write!(f, "{}[{index}] = ", out_list.fmt_left())?;
365            Self::format_scalar(f, *index, *value, out_item)?;
366            f.write_str(";\n")
367        } else {
368            Self::unroll_vec(f, index, value, out_list)
369        }
370    }
371    fn format_scalar<D: Dialect, Lhs, Rhs>(
372        f: &mut Formatter<'_>,
373        _lhs: Lhs,
374        rhs: Rhs,
375        item_out: Item<D>,
376    ) -> std::fmt::Result
377    where
378        Lhs: Component<D>,
379        Rhs: Component<D>,
380    {
381        let item_rhs = rhs.item();
382
383        let format_vec = |f: &mut Formatter<'_>, cast: bool| {
384            writeln!(f, "{item_out}{{")?;
385            for i in 0..item_out.vectorization {
386                if cast {
387                    writeln!(f, "{}({}),", item_out.elem, rhs.index(i))?;
388                } else {
389                    writeln!(f, "{},", rhs.index(i))?;
390                }
391            }
392            f.write_str("}")?;
393
394            Ok(())
395        };
396
397        if item_out.vectorization != item_rhs.vectorization {
398            format_vec(f, item_out != item_rhs)
399        } else if item_out.elem != item_rhs.elem {
400            if item_out.vectorization > 1 {
401                format_vec(f, true)?;
402            } else {
403                write!(f, "{}({rhs})", item_out.elem)?;
404            }
405            Ok(())
406        } else if rhs.is_const() && item_rhs.vectorization > 1 {
407            // Reinterpret cast in case rhs is optimized
408            write!(f, "reinterpret_cast<")?;
409            D::compile_local_memory_qualifier(f)?;
410            write!(f, " {item_out} const&>({rhs})")
411        } else {
412            write!(f, "{rhs}")
413        }
414    }
415
416    fn unroll_vec<D: Dialect>(
417        f: &mut Formatter<'_>,
418        lhs: &Variable<D>,
419        rhs: &Variable<D>,
420        out: &Variable<D>,
421    ) -> std::fmt::Result {
422        let item_lhs = lhs.item();
423        let out_item = out.item();
424        let out = out.fmt_left();
425
426        for i in 0..item_lhs.vectorization {
427            let lhsi = lhs.index(i);
428            let rhsi = rhs.index(i);
429            write!(f, "{out}[{lhs}] = ")?;
430            Self::format_scalar(f, lhsi, rhsi, out_item)?;
431            f.write_str(";\n")?;
432        }
433
434        Ok(())
435    }
436}
437
438impl Index {
439    pub(crate) fn format<D: Dialect>(
440        f: &mut Formatter<'_>,
441        list: &Variable<D>,
442        index: &Variable<D>,
443        out: &Variable<D>,
444        line_size: u32,
445    ) -> std::fmt::Result {
446        if matches!(
447            list,
448            Variable::LocalMut { .. } | Variable::LocalConst { .. } | Variable::ConstantScalar(..)
449        ) {
450            return IndexVector::format(f, list, index, out);
451        }
452
453        if line_size > 0 {
454            let mut item = list.item();
455            item.vectorization = line_size as usize;
456            let addr_space = D::address_space_for_variable(list);
457            let qualifier = list.const_qualifier();
458            let tmp = Variable::tmp_declared(item);
459
460            writeln!(
461                f,
462                "{qualifier} {addr_space}{item} *{tmp} = reinterpret_cast<{qualifier} {item}*>({list});"
463            )?;
464
465            return Index::format(f, &tmp, index, out, 0);
466        }
467
468        let item_out = out.item();
469        if let Elem::Atomic(inner) = item_out.elem {
470            let addr_space = D::address_space_for_variable(list);
471            writeln!(f, "{addr_space}{inner}* {out} = &{list}[{index}];")
472        } else {
473            let out = out.fmt_left();
474            write!(f, "{out} = ")?;
475            Self::format_scalar(f, *list, *index, item_out)?;
476            f.write_str(";\n")
477        }
478    }
479
480    fn format_scalar<D: Dialect, Lhs, Rhs>(
481        f: &mut Formatter<'_>,
482        lhs: Lhs,
483        rhs: Rhs,
484        item_out: Item<D>,
485    ) -> std::fmt::Result
486    where
487        Lhs: Component<D>,
488        Rhs: Component<D>,
489    {
490        let item_lhs = lhs.item();
491
492        let format_vec = |f: &mut Formatter<'_>| {
493            writeln!(f, "{item_out}{{")?;
494            for i in 0..item_out.vectorization {
495                write!(f, "{}({lhs}[{rhs}].i_{i}),", item_out.elem)?;
496            }
497            f.write_str("}")?;
498
499            Ok(())
500        };
501
502        if item_out.elem != item_lhs.elem {
503            if item_out.vectorization > 1 {
504                format_vec(f)
505            } else {
506                write!(f, "{}({lhs}[{rhs}])", item_out.elem)
507            }
508        } else {
509            write!(f, "{lhs}[{rhs}]")
510        }
511    }
512}
513
514/// The goal is to support indexing of vectorized types.
515///
516/// # Examples
517///
518/// ```c
519/// float4 rhs;
520/// float item = var[0]; // We want that.
521/// float item = var.x; // So we compile to that.
522/// ```
523struct IndexVector<D: Dialect> {
524    _dialect: PhantomData<D>,
525}
526
527/// The goal is to support indexing of vectorized types.
528///
529/// # Examples
530///
531/// ```c
532/// float4 var;
533///
534/// var[0] = 1.0; // We want that.
535/// var.x = 1.0;  // So we compile to that.
536/// ```
537struct IndexAssignVector<D: Dialect> {
538    _dialect: PhantomData<D>,
539}
540
541impl<D: Dialect> IndexVector<D> {
542    fn format(
543        f: &mut Formatter<'_>,
544        lhs: &Variable<D>,
545        rhs: &Variable<D>,
546        out: &Variable<D>,
547    ) -> std::fmt::Result {
548        match rhs {
549            Variable::ConstantScalar(value, _elem) => {
550                let index = value.as_usize();
551                let out = out.index(index);
552                let lhs = lhs.index(index);
553                let out = out.fmt_left();
554                writeln!(f, "{out} = {lhs};")
555            }
556            _ => {
557                let elem = out.elem();
558                let qualifier = out.const_qualifier();
559                let addr_space = D::address_space_for_variable(out);
560                let out = out.fmt_left();
561                writeln!(
562                    f,
563                    "{out} = reinterpret_cast<{addr_space}{elem}{qualifier}*>(&{lhs})[{rhs}];"
564                )
565            }
566        }
567    }
568}
569
570impl<D: Dialect> IndexAssignVector<D> {
571    fn format(
572        f: &mut Formatter<'_>,
573        lhs: &Variable<D>,
574        rhs: &Variable<D>,
575        out: &Variable<D>,
576    ) -> std::fmt::Result {
577        let index = match lhs {
578            Variable::ConstantScalar(value, _) => value.as_usize(),
579            _ => {
580                let elem = out.elem();
581                let addr_space = D::address_space_for_variable(out);
582                return writeln!(f, "*(({addr_space}{elem}*)&{out} + {lhs}) = {rhs};");
583            }
584        };
585
586        let out = out.index(index);
587        let rhs = rhs.index(index);
588
589        writeln!(f, "{out} = {rhs};")
590    }
591}