Skip to main content

cubecl_cpp/shared/
binary.rs

1use crate::shared::FmtLeft;
2
3use super::{Component, Dialect, Elem, Item, Variable};
4use std::{
5    fmt::{Display, Formatter},
6    marker::PhantomData,
7};
8
9pub trait Binary<D: Dialect> {
10    fn format(
11        f: &mut Formatter<'_>,
12        lhs: &Variable<D>,
13        rhs: &Variable<D>,
14        out: &Variable<D>,
15    ) -> std::fmt::Result {
16        let out_item = out.item();
17        if out.item().vectorization == 1 {
18            let out = out.fmt_left();
19            write!(f, "{out} = ")?;
20            Self::format_scalar(f, *lhs, *rhs, out_item)?;
21            f.write_str(";\n")
22        } else {
23            Self::unroll_vec(f, lhs, rhs, out)
24        }
25    }
26
27    fn format_scalar<Lhs, Rhs>(
28        f: &mut Formatter<'_>,
29        lhs: Lhs,
30        rhs: Rhs,
31        item: Item<D>,
32    ) -> std::fmt::Result
33    where
34        Lhs: Component<D>,
35        Rhs: Component<D>;
36
37    fn unroll_vec(
38        f: &mut Formatter<'_>,
39        lhs: &Variable<D>,
40        rhs: &Variable<D>,
41        out: &Variable<D>,
42    ) -> core::fmt::Result {
43        let optimized = Variable::optimized_args([*lhs, *rhs, *out]);
44        let [lhs, rhs, out_optimized] = optimized.args;
45
46        let item_out_original = out.item();
47        let item_out_optimized = out_optimized.item();
48
49        let index = match optimized.optimization_factor {
50            Some(factor) => item_out_original.vectorization / factor,
51            None => item_out_optimized.vectorization,
52        };
53
54        let mut write_op =
55            |lhs: &Variable<D>, rhs: &Variable<D>, out: &Variable<D>, item_out: Item<D>| {
56                let out = out.fmt_left();
57                writeln!(f, "{out} = {item_out}{{")?;
58                for i in 0..index {
59                    let lhsi = lhs.index(i);
60                    let rhsi = rhs.index(i);
61
62                    Self::format_scalar(f, lhsi, rhsi, item_out)?;
63                    f.write_str(", ")?;
64                }
65
66                f.write_str("};\n")
67            };
68
69        if item_out_original == item_out_optimized {
70            write_op(&lhs, &rhs, out, item_out_optimized)
71        } else {
72            let out_tmp = Variable::tmp(item_out_optimized);
73            write_op(&lhs, &rhs, &out_tmp, item_out_optimized)?;
74            let addr_space = D::address_space_for_variable(out);
75            let out = out.fmt_left();
76
77            writeln!(
78                f,
79                "{out} = reinterpret_cast<{addr_space}{item_out_original}&>({out_tmp});\n"
80            )?;
81
82            Ok(())
83        }
84    }
85}
86
/// Defines a unit struct `$name` implementing `Binary` as the infix operator `$op`.
///
/// For 8- and 16-bit integer outputs the expression is wrapped in a constructor cast to the
/// output element type. This prevents auto-promotion rules from widening the result: fusion
/// and vectorization perform element-wise operations on vectorized types, and the resulting
/// elements need to keep the same type.
macro_rules! operator {
    ($name:ident, $op:expr) => {
        pub struct $name;

        impl<D: Dialect> Binary<D> for $name {
            fn format_scalar<Lhs: Display, Rhs: Display>(
                f: &mut std::fmt::Formatter<'_>,
                lhs: Lhs,
                rhs: Rhs,
                out_item: Item<D>,
            ) -> std::fmt::Result {
                let elem = out_item.elem();
                let narrow_int = matches!(
                    elem,
                    Elem::<D>::I8 | Elem::<D>::U8 | Elem::<D>::I16 | Elem::<D>::U16
                );

                if narrow_int {
                    write!(f, "{elem}({lhs} {op} {rhs})", op = $op)
                } else {
                    write!(f, "{lhs} {op} {rhs}", op = $op)
                }
            }
        }
    };
}
112
113operator!(Add, "+");
114operator!(Sub, "-");
115operator!(Div, "/");
116operator!(Mul, "*");
117operator!(Modulo, "%");
118operator!(Equal, "==");
119operator!(NotEqual, "!=");
120operator!(Lower, "<");
121operator!(LowerEqual, "<=");
122operator!(Greater, ">");
123operator!(GreaterEqual, ">=");
124operator!(ShiftLeft, "<<");
125operator!(ShiftRight, ">>");
126operator!(BitwiseOr, "|");
127operator!(BitwiseAnd, "&");
128operator!(BitwiseXor, "^");
129operator!(Or, "||");
130operator!(And, "&&");
131
132pub struct FastDiv;
133
134impl<D: Dialect> Binary<D> for FastDiv {
135    fn format_scalar<Lhs: Display, Rhs: Display>(
136        f: &mut std::fmt::Formatter<'_>,
137        lhs: Lhs,
138        rhs: Rhs,
139        _out_item: Item<D>,
140    ) -> std::fmt::Result {
141        // f32 only
142        write!(f, "__fdividef({lhs}, {rhs})")
143    }
144}
145
146pub struct HiMul;
147
148impl<D: Dialect> Binary<D> for HiMul {
149    // Powf doesn't support half and no half equivalent exists
150    fn format_scalar<Lhs: Display, Rhs: Display>(
151        f: &mut std::fmt::Formatter<'_>,
152        lhs: Lhs,
153        rhs: Rhs,
154        out: Item<D>,
155    ) -> std::fmt::Result {
156        let out_elem = out.elem;
157        match out_elem {
158            Elem::I32 => write!(f, "__mulhi({lhs}, {rhs})"),
159            Elem::U32 => write!(f, "__umulhi({lhs}, {rhs})"),
160            Elem::I64 => write!(f, "__mul64hi({lhs}, {rhs})"),
161            Elem::U64 => write!(f, "__umul64hi({lhs}, {rhs})"),
162            _ => writeln!(f, "#error HiMul only supports 32 and 64 bit ints"),
163        }
164    }
165
166    // Powf doesn't support half and no half equivalent exists
167    fn unroll_vec(
168        f: &mut Formatter<'_>,
169        lhs: &Variable<D>,
170        rhs: &Variable<D>,
171        out: &Variable<D>,
172    ) -> core::fmt::Result {
173        let item_out = out.item();
174        let index = out.item().vectorization;
175
176        let out = out.fmt_left();
177        writeln!(f, "{out} = {item_out}{{")?;
178        for i in 0..index {
179            let lhsi = lhs.index(i);
180            let rhsi = rhs.index(i);
181
182            Self::format_scalar(f, lhsi, rhsi, item_out)?;
183            f.write_str(", ")?;
184        }
185
186        f.write_str("};\n")
187    }
188}
189
190pub struct SaturatingAdd;
191
192impl<D: Dialect> Binary<D> for SaturatingAdd {
193    fn format_scalar<Lhs: Display, Rhs: Display>(
194        f: &mut std::fmt::Formatter<'_>,
195        lhs: Lhs,
196        rhs: Rhs,
197        out: Item<D>,
198    ) -> std::fmt::Result {
199        D::compile_saturating_add(f, lhs, rhs, out)
200    }
201}
202
203pub struct SaturatingSub;
204
205impl<D: Dialect> Binary<D> for SaturatingSub {
206    fn format_scalar<Lhs: Display, Rhs: Display>(
207        f: &mut std::fmt::Formatter<'_>,
208        lhs: Lhs,
209        rhs: Rhs,
210        out: Item<D>,
211    ) -> std::fmt::Result {
212        D::compile_saturating_sub(f, lhs, rhs, out)
213    }
214}
215
216pub struct Powf;
217
218impl<D: Dialect> Binary<D> for Powf {
219    // Powf doesn't support half and no half equivalent exists
220    fn format_scalar<Lhs: Display, Rhs: Display>(
221        f: &mut std::fmt::Formatter<'_>,
222        lhs: Lhs,
223        rhs: Rhs,
224        item: Item<D>,
225    ) -> std::fmt::Result {
226        let elem = item.elem;
227        let lhs = lhs.to_string();
228        let rhs = rhs.to_string();
229        match elem {
230            Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
231                let lhs = format!("float({lhs})");
232                let rhs = format!("float({rhs})");
233                write!(f, "{elem}(")?;
234                D::compile_instruction_powf(f, &lhs, &rhs, Elem::F32)?;
235                write!(f, ")")
236            }
237            _ => D::compile_instruction_powf(f, &lhs, &rhs, elem),
238        }
239    }
240
241    // Powf doesn't support half and no half equivalent exists
242    fn unroll_vec(
243        f: &mut Formatter<'_>,
244        lhs: &Variable<D>,
245        rhs: &Variable<D>,
246        out: &Variable<D>,
247    ) -> core::fmt::Result {
248        let item_out = out.item();
249        let index = out.item().vectorization;
250
251        let out = out.fmt_left();
252        writeln!(f, "{out} = {item_out}{{")?;
253        for i in 0..index {
254            let lhsi = lhs.index(i);
255            let rhsi = rhs.index(i);
256
257            Self::format_scalar(f, lhsi, rhsi, item_out)?;
258            f.write_str(", ")?;
259        }
260
261        f.write_str("};\n")
262    }
263}
264
265pub struct FastPowf;
266
267impl<D: Dialect> Binary<D> for FastPowf {
268    // Only executed for f32
269    fn format_scalar<Lhs: Display, Rhs: Display>(
270        f: &mut std::fmt::Formatter<'_>,
271        lhs: Lhs,
272        rhs: Rhs,
273        _item: Item<D>,
274    ) -> std::fmt::Result {
275        write!(f, "__powf({lhs}, {rhs})")
276    }
277}
278
279pub struct Powi;
280
281impl<D: Dialect> Binary<D> for Powi {
282    // Powi doesn't support half and no half equivalent exists
283    fn format_scalar<Lhs: Display, Rhs: Display>(
284        f: &mut std::fmt::Formatter<'_>,
285        lhs: Lhs,
286        rhs: Rhs,
287        item: Item<D>,
288    ) -> std::fmt::Result {
289        let elem = item.elem;
290        let lhs = lhs.to_string();
291        let rhs = rhs.to_string();
292        match elem {
293            Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
294                let lhs = format!("float({lhs})");
295
296                write!(f, "{elem}(")?;
297                D::compile_instruction_powf(f, &lhs, &rhs, Elem::F32)?;
298                write!(f, ")")
299            }
300            Elem::F64 => {
301                // RHS needs to be a double.
302                let rhs = format!("double({rhs})");
303
304                D::compile_instruction_powf(f, &lhs, &rhs, elem)
305            }
306            _ => D::compile_instruction_powf(f, &lhs, &rhs, elem),
307        }
308    }
309
310    // Powi doesn't support half and no half equivalent exists
311    fn unroll_vec(
312        f: &mut Formatter<'_>,
313        lhs: &Variable<D>,
314        rhs: &Variable<D>,
315        out: &Variable<D>,
316    ) -> core::fmt::Result {
317        let item_out = out.item();
318        let index = out.item().vectorization;
319
320        let out = out.fmt_left();
321        writeln!(f, "{out} = {item_out}{{")?;
322        for i in 0..index {
323            let lhsi = lhs.index(i);
324            let rhsi = rhs.index(i);
325
326            Self::format_scalar(f, lhsi, rhsi, item_out)?;
327            f.write_str(", ")?;
328        }
329
330        f.write_str("};\n")
331    }
332}
333pub struct ArcTan2;
334
335impl<D: Dialect> Binary<D> for ArcTan2 {
336    // ArcTan2 doesn't support half and no half equivalent exists
337    fn format_scalar<Lhs: Display, Rhs: Display>(
338        f: &mut std::fmt::Formatter<'_>,
339        lhs: Lhs,
340        rhs: Rhs,
341        item: Item<D>,
342    ) -> std::fmt::Result {
343        let elem = item.elem;
344        match elem {
345            Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
346                write!(f, "{elem}(atan2(float({lhs}), float({rhs})))")
347            }
348            _ => {
349                write!(f, "atan2({lhs}, {rhs})")
350            }
351        }
352    }
353
354    // ArcTan2 doesn't support half and no half equivalent exists
355    fn unroll_vec(
356        f: &mut Formatter<'_>,
357        lhs: &Variable<D>,
358        rhs: &Variable<D>,
359        out: &Variable<D>,
360    ) -> core::fmt::Result {
361        let item_out = out.item();
362        let index = out.item().vectorization;
363
364        let out = out.fmt_left();
365        writeln!(f, "{out} = {item_out}{{")?;
366        for i in 0..index {
367            let lhsi = lhs.index(i);
368            let rhsi = rhs.index(i);
369
370            Self::format_scalar(f, lhsi, rhsi, item_out)?;
371            f.write_str(", ")?;
372        }
373
374        f.write_str("};\n")
375    }
376}
377
378pub struct Hypot;
379
380impl<D: Dialect> Binary<D> for Hypot {
381    // Hypot doesn't support half and no half equivalent exists
382    fn format_scalar<Lhs, Rhs>(
383        f: &mut Formatter<'_>,
384        lhs: Lhs,
385        rhs: Rhs,
386        item: Item<D>,
387    ) -> std::fmt::Result
388    where
389        Lhs: Component<D>,
390        Rhs: Component<D>,
391    {
392        let elem = item.elem;
393        let lhs = lhs.to_string();
394        let rhs = rhs.to_string();
395        match elem {
396            Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
397                let lhs = format!("float({lhs})");
398                let rhs = format!("float({rhs})");
399                write!(f, "{elem}(")?;
400                D::compile_instruction_hypot(f, &lhs, &rhs, Elem::F32)?;
401                write!(f, ")")
402            }
403            _ => D::compile_instruction_hypot(f, &lhs, &rhs, elem),
404        }
405    }
406
407    // Hypot doesn't support half and no half equivalent exists
408    fn unroll_vec(
409        f: &mut Formatter<'_>,
410        lhs: &Variable<D>,
411        rhs: &Variable<D>,
412        out: &Variable<D>,
413    ) -> core::fmt::Result {
414        let item_out = out.item();
415        let index = out.item().vectorization;
416
417        let out = out.fmt_left();
418        writeln!(f, "{out} = {item_out}{{")?;
419        for i in 0..index {
420            let lhsi = lhs.index(i);
421            let rhsi = rhs.index(i);
422
423            Self::format_scalar(f, lhsi, rhsi, item_out)?;
424            f.write_str(", ")?;
425        }
426
427        f.write_str("};\n")
428    }
429}
430
431pub struct Rhypot;
432
433impl<D: Dialect> Binary<D> for Rhypot {
434    // Rhypot doesn't support half and no half equivalent exists
435    fn format_scalar<Lhs, Rhs>(
436        f: &mut Formatter<'_>,
437        lhs: Lhs,
438        rhs: Rhs,
439        item: Item<D>,
440    ) -> std::fmt::Result
441    where
442        Lhs: Component<D>,
443        Rhs: Component<D>,
444    {
445        let elem = item.elem;
446        let lhs = lhs.to_string();
447        let rhs = rhs.to_string();
448        match elem {
449            Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
450                let lhs = format!("float({lhs})");
451                let rhs = format!("float({rhs})");
452                write!(f, "{elem}(")?;
453                D::compile_instruction_rhypot(f, &lhs, &rhs, Elem::F32)?;
454                write!(f, ")")
455            }
456            _ => D::compile_instruction_rhypot(f, &lhs, &rhs, elem),
457        }
458    }
459
460    // Rhypot doesn't support half and no half equivalent exists
461    fn unroll_vec(
462        f: &mut Formatter<'_>,
463        lhs: &Variable<D>,
464        rhs: &Variable<D>,
465        out: &Variable<D>,
466    ) -> core::fmt::Result {
467        let item_out = out.item();
468        let index = out.item().vectorization;
469
470        let out = out.fmt_left();
471        writeln!(f, "{out} = {item_out}{{")?;
472        for i in 0..index {
473            let lhsi = lhs.index(i);
474            let rhsi = rhs.index(i);
475
476            Self::format_scalar(f, lhsi, rhsi, item_out)?;
477            f.write_str(", ")?;
478        }
479
480        f.write_str("};\n")
481    }
482}
483
484pub struct Max;
485
486impl<D: Dialect> Binary<D> for Max {
487    fn format_scalar<Lhs: Display, Rhs: Display>(
488        f: &mut std::fmt::Formatter<'_>,
489        lhs: Lhs,
490        rhs: Rhs,
491        item: Item<D>,
492    ) -> std::fmt::Result {
493        D::compile_instruction_max_function_name(f, item)?;
494        write!(f, "({lhs}, {rhs})")
495    }
496}
497
498pub struct Min;
499
500impl<D: Dialect> Binary<D> for Min {
501    fn format_scalar<Lhs: Display, Rhs: Display>(
502        f: &mut std::fmt::Formatter<'_>,
503        lhs: Lhs,
504        rhs: Rhs,
505        item: Item<D>,
506    ) -> std::fmt::Result {
507        D::compile_instruction_min_function_name(f, item)?;
508        write!(f, "({lhs}, {rhs})")
509    }
510}
511
/// Emits an indexed read of the form `out = list[index]`.
pub struct Index;
/// Emits an indexed write of the form `out_list[index] = value`.
pub struct IndexAssign;
514
515impl IndexAssign {
516    pub fn format<D: Dialect>(
517        f: &mut Formatter<'_>,
518        index: &Variable<D>,
519        value: &Variable<D>,
520        out_list: &Variable<D>,
521        vector_size: u32,
522    ) -> std::fmt::Result {
523        if matches!(
524            out_list,
525            Variable::LocalMut { .. } | Variable::LocalConst { .. }
526        ) {
527            return IndexAssignVector::format(f, index, value, out_list);
528        };
529
530        if vector_size > 0 {
531            let mut item = out_list.item();
532            item.vectorization = vector_size as usize;
533            let addr_space = D::address_space_for_variable(out_list);
534            let qualifier = out_list.const_qualifier();
535            let tmp = Variable::tmp_declared(item);
536
537            writeln!(
538                f,
539                "{qualifier} {addr_space}{item} *{tmp} = reinterpret_cast<{qualifier} {item}*>({out_list});"
540            )?;
541
542            return IndexAssign::format(f, index, value, &tmp, 0);
543        }
544
545        let out_item = out_list.item();
546
547        if index.item().vectorization == 1 {
548            write!(f, "{}[{index}] = ", out_list.fmt_left())?;
549            Self::format_scalar(f, *index, *value, out_item)?;
550            f.write_str(";\n")
551        } else {
552            Self::unroll_vec(f, index, value, out_list)
553        }
554    }
555    fn format_scalar<D: Dialect, Lhs, Rhs>(
556        f: &mut Formatter<'_>,
557        _lhs: Lhs,
558        rhs: Rhs,
559        item_out: Item<D>,
560    ) -> std::fmt::Result
561    where
562        Lhs: Component<D>,
563        Rhs: Component<D>,
564    {
565        let item_rhs = rhs.item();
566
567        let format_vec = |f: &mut Formatter<'_>, cast: bool| {
568            writeln!(f, "{item_out}{{")?;
569            for i in 0..item_out.vectorization {
570                if cast {
571                    writeln!(f, "{}({}),", item_out.elem, rhs.index(i))?;
572                } else {
573                    writeln!(f, "{},", rhs.index(i))?;
574                }
575            }
576            f.write_str("}")?;
577
578            Ok(())
579        };
580
581        if item_out.vectorization != item_rhs.vectorization {
582            format_vec(f, item_out != item_rhs)
583        } else if item_out.elem != item_rhs.elem {
584            if item_out.vectorization > 1 {
585                format_vec(f, true)?;
586            } else {
587                write!(f, "{}({rhs})", item_out.elem)?;
588            }
589            Ok(())
590        } else if rhs.is_const() && item_rhs.vectorization > 1 {
591            // Reinterpret cast in case rhs is optimized
592            write!(f, "reinterpret_cast<")?;
593            D::compile_local_memory_qualifier(f)?;
594            write!(f, " {item_out} const&>({rhs})")
595        } else {
596            write!(f, "{rhs}")
597        }
598    }
599
600    fn unroll_vec<D: Dialect>(
601        f: &mut Formatter<'_>,
602        lhs: &Variable<D>,
603        rhs: &Variable<D>,
604        out: &Variable<D>,
605    ) -> std::fmt::Result {
606        let item_lhs = lhs.item();
607        let out_item = out.item();
608        let out = out.fmt_left();
609
610        for i in 0..item_lhs.vectorization {
611            let lhsi = lhs.index(i);
612            let rhsi = rhs.index(i);
613            write!(f, "{out}[{lhs}] = ")?;
614            Self::format_scalar(f, lhsi, rhsi, out_item)?;
615            f.write_str(";\n")?;
616        }
617
618        Ok(())
619    }
620}
621
622impl Index {
623    pub(crate) fn format<D: Dialect>(
624        f: &mut Formatter<'_>,
625        list: &Variable<D>,
626        index: &Variable<D>,
627        out: &Variable<D>,
628        vector_size: u32,
629    ) -> std::fmt::Result {
630        if matches!(
631            list,
632            Variable::LocalMut { .. } | Variable::LocalConst { .. } | Variable::Constant(..)
633        ) {
634            return IndexVector::format(f, list, index, out);
635        }
636
637        if vector_size > 0 {
638            let mut item = list.item();
639            item.vectorization = vector_size as usize;
640            let addr_space = D::address_space_for_variable(list);
641            let qualifier = list.const_qualifier();
642            let tmp = Variable::tmp_declared(item);
643
644            writeln!(
645                f,
646                "{qualifier} {addr_space}{item} *{tmp} = reinterpret_cast<{qualifier} {item}*>({list});"
647            )?;
648
649            return Index::format(f, &tmp, index, out, 0);
650        }
651
652        let item_out = out.item();
653        if let Elem::Atomic(_) = item_out.elem {
654            let addr_space = D::address_space_for_variable(list);
655            writeln!(f, "{addr_space}{item_out}* {out} = &{list}[{index}];")
656        } else if matches!(item_out.elem, Elem::Barrier(_)) {
657            let addr_space = D::address_space_for_variable(list);
658            writeln!(f, "{addr_space}{}& {out} = {list}[{index}];", item_out.elem)
659        } else {
660            let out = out.fmt_left();
661            write!(f, "{out} = ")?;
662            Self::format_scalar(f, *list, *index, item_out)?;
663            f.write_str(";\n")
664        }
665    }
666
667    fn format_scalar<D: Dialect, Lhs, Rhs>(
668        f: &mut Formatter<'_>,
669        lhs: Lhs,
670        rhs: Rhs,
671        item_out: Item<D>,
672    ) -> std::fmt::Result
673    where
674        Lhs: Component<D>,
675        Rhs: Component<D>,
676    {
677        let item_lhs = lhs.item();
678
679        let format_vec = |f: &mut Formatter<'_>| {
680            writeln!(f, "{item_out}{{")?;
681            for i in 0..item_out.vectorization {
682                write!(f, "{}({lhs}[{rhs}].i_{i}),", item_out.elem)?;
683            }
684            f.write_str("}")?;
685
686            Ok(())
687        };
688
689        if item_out.elem != item_lhs.elem {
690            if item_out.vectorization > 1 {
691                format_vec(f)
692            } else {
693                write!(f, "{}({lhs}[{rhs}])", item_out.elem)
694            }
695        } else {
696            write!(f, "{lhs}[{rhs}]")
697        }
698    }
699}
700
701/// The goal is to support indexing of vectorized types.
702///
703/// # Examples
704///
705/// ```c
706/// float4 rhs;
707/// float item = var[0]; // We want that.
708/// float item = var.x; // So we compile to that.
709/// ```
710struct IndexVector<D: Dialect> {
711    _dialect: PhantomData<D>,
712}
713
714/// The goal is to support indexing of vectorized types.
715///
716/// # Examples
717///
718/// ```c
719/// float4 var;
720///
721/// var[0] = 1.0; // We want that.
722/// var.x = 1.0;  // So we compile to that.
723/// ```
724struct IndexAssignVector<D: Dialect> {
725    _dialect: PhantomData<D>,
726}
727
728impl<D: Dialect> IndexVector<D> {
729    fn format(
730        f: &mut Formatter<'_>,
731        lhs: &Variable<D>,
732        rhs: &Variable<D>,
733        out: &Variable<D>,
734    ) -> std::fmt::Result {
735        match rhs {
736            Variable::Constant(value, _elem) => {
737                let index = value.as_usize();
738                let out = out.index(index);
739                let lhs = lhs.index(index);
740                let out = out.fmt_left();
741                writeln!(f, "{out} = {lhs};")
742            }
743            _ => {
744                let elem = out.elem();
745                let qualifier = out.const_qualifier();
746                let addr_space = D::address_space_for_variable(out);
747                let out = out.fmt_left();
748                writeln!(
749                    f,
750                    "{out} = reinterpret_cast<{addr_space}{elem}{qualifier}*>(&{lhs})[{rhs}];"
751                )
752            }
753        }
754    }
755}
756
757impl<D: Dialect> IndexAssignVector<D> {
758    fn format(
759        f: &mut Formatter<'_>,
760        lhs: &Variable<D>,
761        rhs: &Variable<D>,
762        out: &Variable<D>,
763    ) -> std::fmt::Result {
764        let index = match lhs {
765            Variable::Constant(value, _) => value.as_usize(),
766            _ => {
767                let elem = out.elem();
768                let addr_space = D::address_space_for_variable(out);
769                return writeln!(f, "*(({addr_space}{elem}*)&{out} + {lhs}) = {rhs};");
770            }
771        };
772
773        let out = out.index(index);
774        let rhs = rhs.index(index);
775
776        writeln!(f, "{out} = {rhs};")
777    }
778}