cubecl_cpp/shared/
binary.rs

1use crate::shared::FmtLeft;
2
3use super::{Component, Dialect, Elem, Item, Variable};
4use std::{
5    fmt::{Display, Formatter},
6    marker::PhantomData,
7};
8
9pub trait Binary<D: Dialect> {
10    fn format(
11        f: &mut Formatter<'_>,
12        lhs: &Variable<D>,
13        rhs: &Variable<D>,
14        out: &Variable<D>,
15    ) -> std::fmt::Result {
16        let out_item = out.item();
17        if out.item().vectorization == 1 {
18            let out = out.fmt_left();
19            write!(f, "{out} = ")?;
20            Self::format_scalar(f, *lhs, *rhs, out_item)?;
21            f.write_str(";\n")
22        } else {
23            Self::unroll_vec(f, lhs, rhs, out)
24        }
25    }
26
27    fn format_scalar<Lhs, Rhs>(
28        f: &mut Formatter<'_>,
29        lhs: Lhs,
30        rhs: Rhs,
31        item: Item<D>,
32    ) -> std::fmt::Result
33    where
34        Lhs: Component<D>,
35        Rhs: Component<D>;
36
37    fn unroll_vec(
38        f: &mut Formatter<'_>,
39        lhs: &Variable<D>,
40        rhs: &Variable<D>,
41        out: &Variable<D>,
42    ) -> core::fmt::Result {
43        let optimized = Variable::optimized_args([*lhs, *rhs, *out]);
44        let [lhs, rhs, out_optimized] = optimized.args;
45
46        let item_out_original = out.item();
47        let item_out_optimized = out_optimized.item();
48
49        let index = match optimized.optimization_factor {
50            Some(factor) => item_out_original.vectorization / factor,
51            None => item_out_optimized.vectorization,
52        };
53
54        let mut write_op =
55            |lhs: &Variable<D>, rhs: &Variable<D>, out: &Variable<D>, item_out: Item<D>| {
56                let out = out.fmt_left();
57                writeln!(f, "{out} = {item_out}{{")?;
58                for i in 0..index {
59                    let lhsi = lhs.index(i);
60                    let rhsi = rhs.index(i);
61
62                    Self::format_scalar(f, lhsi, rhsi, item_out)?;
63                    f.write_str(", ")?;
64                }
65
66                f.write_str("};\n")
67            };
68
69        if item_out_original == item_out_optimized {
70            write_op(&lhs, &rhs, out, item_out_optimized)
71        } else {
72            let out_tmp = Variable::tmp(item_out_optimized);
73            write_op(&lhs, &rhs, &out_tmp, item_out_optimized)?;
74            let addr_space = D::address_space_for_variable(out);
75            let out = out.fmt_left();
76
77            writeln!(
78                f,
79                "{out} = reinterpret_cast<{addr_space}{item_out_original}&>({out_tmp});\n"
80            )?;
81
82            Ok(())
83        }
84    }
85}
86
/// Declares a public unit struct `$name` and implements [`Binary`] for it so that the
/// scalar expression is `lhs $op rhs`, with `$op` written as an infix operator.
macro_rules! operator {
    ($name:ident, $op:expr) => {
        pub struct $name;

        impl<D: Dialect> Binary<D> for $name {
            fn format_scalar<Lhs, Rhs>(
                f: &mut std::fmt::Formatter<'_>,
                lhs: Lhs,
                rhs: Rhs,
                out_item: Item<D>,
            ) -> std::fmt::Result
            where
                Lhs: Display,
                Rhs: Display,
            {
                let out_elem = out_item.elem();

                // Wrap 8 and 16 bit integer results in an explicit conversion so that
                // auto-promotion does not change the type: fusion and vectorization can
                // apply elementwise operations on vectorized types, and the resulting
                // elements need to stay of the same type.
                let keep_narrow_type = matches!(
                    out_elem,
                    Elem::<D>::I16 | Elem::<D>::U16 | Elem::<D>::I8 | Elem::<D>::U8
                );

                if keep_narrow_type {
                    write!(f, "{out_elem}({lhs} {} {rhs})", $op)
                } else {
                    write!(f, "{lhs} {} {rhs}", $op)
                }
            }
        }
    };
}
112
// Each invocation below declares a unit struct implementing `Binary` whose scalar
// expression is `lhs <op> rhs`.

// Arithmetic operators.
operator!(Add, "+");
operator!(Sub, "-");
operator!(Div, "/");
operator!(Mul, "*");
operator!(Modulo, "%");
// Comparison operators.
operator!(Equal, "==");
operator!(NotEqual, "!=");
operator!(Lower, "<");
operator!(LowerEqual, "<=");
operator!(Greater, ">");
operator!(GreaterEqual, ">=");
// Shift and bitwise operators.
operator!(ShiftLeft, "<<");
operator!(ShiftRight, ">>");
operator!(BitwiseOr, "|");
operator!(BitwiseAnd, "&");
operator!(BitwiseXor, "^");
// Logical operators.
operator!(Or, "||");
operator!(And, "&&");
131
132pub struct FastDiv;
133
134impl<D: Dialect> Binary<D> for FastDiv {
135    fn format_scalar<Lhs: Display, Rhs: Display>(
136        f: &mut std::fmt::Formatter<'_>,
137        lhs: Lhs,
138        rhs: Rhs,
139        _out_item: Item<D>,
140    ) -> std::fmt::Result {
141        // f32 only
142        write!(f, "__fdividef({lhs}, {rhs})")
143    }
144}
145
146pub struct HiMul;
147
148impl<D: Dialect> Binary<D> for HiMul {
149    // Powf doesn't support half and no half equivalent exists
150    fn format_scalar<Lhs: Display, Rhs: Display>(
151        f: &mut std::fmt::Formatter<'_>,
152        lhs: Lhs,
153        rhs: Rhs,
154        out: Item<D>,
155    ) -> std::fmt::Result {
156        let out_elem = out.elem;
157        match out_elem {
158            Elem::I32 => write!(f, "__mulhi({lhs}, {rhs})"),
159            Elem::U32 => write!(f, "__umulhi({lhs}, {rhs})"),
160            Elem::I64 => write!(f, "__mul64hi({lhs}, {rhs})"),
161            Elem::U64 => write!(f, "__umul64hi({lhs}, {rhs})"),
162            _ => writeln!(f, "#error HiMul only supports 32 and 64 bit ints"),
163        }
164    }
165
166    // Powf doesn't support half and no half equivalent exists
167    fn unroll_vec(
168        f: &mut Formatter<'_>,
169        lhs: &Variable<D>,
170        rhs: &Variable<D>,
171        out: &Variable<D>,
172    ) -> core::fmt::Result {
173        let item_out = out.item();
174        let index = out.item().vectorization;
175
176        let out = out.fmt_left();
177        writeln!(f, "{out} = {item_out}{{")?;
178        for i in 0..index {
179            let lhsi = lhs.index(i);
180            let rhsi = rhs.index(i);
181
182            Self::format_scalar(f, lhsi, rhsi, item_out)?;
183            f.write_str(", ")?;
184        }
185
186        f.write_str("};\n")
187    }
188}
189
190pub struct SaturatingAdd;
191
192impl<D: Dialect> Binary<D> for SaturatingAdd {
193    fn format_scalar<Lhs: Display, Rhs: Display>(
194        f: &mut std::fmt::Formatter<'_>,
195        lhs: Lhs,
196        rhs: Rhs,
197        out: Item<D>,
198    ) -> std::fmt::Result {
199        D::compile_saturating_add(f, lhs, rhs, out)
200    }
201}
202
203pub struct SaturatingSub;
204
205impl<D: Dialect> Binary<D> for SaturatingSub {
206    fn format_scalar<Lhs: Display, Rhs: Display>(
207        f: &mut std::fmt::Formatter<'_>,
208        lhs: Lhs,
209        rhs: Rhs,
210        out: Item<D>,
211    ) -> std::fmt::Result {
212        D::compile_saturating_sub(f, lhs, rhs, out)
213    }
214}
215
216pub struct Powf;
217
218impl<D: Dialect> Binary<D> for Powf {
219    // Powf doesn't support half and no half equivalent exists
220    fn format_scalar<Lhs: Display, Rhs: Display>(
221        f: &mut std::fmt::Formatter<'_>,
222        lhs: Lhs,
223        rhs: Rhs,
224        item: Item<D>,
225    ) -> std::fmt::Result {
226        let elem = item.elem;
227        let lhs = lhs.to_string();
228        let rhs = rhs.to_string();
229        match elem {
230            Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
231                let lhs = format!("float({lhs})");
232                let rhs = format!("float({rhs})");
233                write!(f, "{elem}(")?;
234                D::compile_instruction_powf(f, &lhs, &rhs, Elem::F32)?;
235                write!(f, ")")
236            }
237            _ => D::compile_instruction_powf(f, &lhs, &rhs, elem),
238        }
239    }
240
241    // Powf doesn't support half and no half equivalent exists
242    fn unroll_vec(
243        f: &mut Formatter<'_>,
244        lhs: &Variable<D>,
245        rhs: &Variable<D>,
246        out: &Variable<D>,
247    ) -> core::fmt::Result {
248        let item_out = out.item();
249        let index = out.item().vectorization;
250
251        let out = out.fmt_left();
252        writeln!(f, "{out} = {item_out}{{")?;
253        for i in 0..index {
254            let lhsi = lhs.index(i);
255            let rhsi = rhs.index(i);
256
257            Self::format_scalar(f, lhsi, rhsi, item_out)?;
258            f.write_str(", ")?;
259        }
260
261        f.write_str("};\n")
262    }
263}
264
265pub struct FastPowf;
266
267impl<D: Dialect> Binary<D> for FastPowf {
268    // Only executed for f32
269    fn format_scalar<Lhs: Display, Rhs: Display>(
270        f: &mut std::fmt::Formatter<'_>,
271        lhs: Lhs,
272        rhs: Rhs,
273        _item: Item<D>,
274    ) -> std::fmt::Result {
275        write!(f, "__powf({lhs}, {rhs})")
276    }
277}
278
279pub struct Powi;
280
281impl<D: Dialect> Binary<D> for Powi {
282    // Powi doesn't support half and no half equivalent exists
283    fn format_scalar<Lhs: Display, Rhs: Display>(
284        f: &mut std::fmt::Formatter<'_>,
285        lhs: Lhs,
286        rhs: Rhs,
287        item: Item<D>,
288    ) -> std::fmt::Result {
289        let elem = item.elem;
290        let lhs = lhs.to_string();
291        let rhs = rhs.to_string();
292        match elem {
293            Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
294                let lhs = format!("float({lhs})");
295
296                write!(f, "{elem}(")?;
297                D::compile_instruction_powf(f, &lhs, &rhs, Elem::F32)?;
298                write!(f, ")")
299            }
300            _ => D::compile_instruction_powf(f, &lhs, &rhs, elem),
301        }
302    }
303
304    // Powi doesn't support half and no half equivalent exists
305    fn unroll_vec(
306        f: &mut Formatter<'_>,
307        lhs: &Variable<D>,
308        rhs: &Variable<D>,
309        out: &Variable<D>,
310    ) -> core::fmt::Result {
311        let item_out = out.item();
312        let index = out.item().vectorization;
313
314        let out = out.fmt_left();
315        writeln!(f, "{out} = {item_out}{{")?;
316        for i in 0..index {
317            let lhsi = lhs.index(i);
318            let rhsi = rhs.index(i);
319
320            Self::format_scalar(f, lhsi, rhsi, item_out)?;
321            f.write_str(", ")?;
322        }
323
324        f.write_str("};\n")
325    }
326}
327pub struct ArcTan2;
328
329impl<D: Dialect> Binary<D> for ArcTan2 {
330    // ArcTan2 doesn't support half and no half equivalent exists
331    fn format_scalar<Lhs: Display, Rhs: Display>(
332        f: &mut std::fmt::Formatter<'_>,
333        lhs: Lhs,
334        rhs: Rhs,
335        item: Item<D>,
336    ) -> std::fmt::Result {
337        let elem = item.elem;
338        match elem {
339            Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
340                write!(f, "{elem}(atan2(float({lhs}), float({rhs})))")
341            }
342            _ => {
343                write!(f, "atan2({lhs}, {rhs})")
344            }
345        }
346    }
347
348    // ArcTan2 doesn't support half and no half equivalent exists
349    fn unroll_vec(
350        f: &mut Formatter<'_>,
351        lhs: &Variable<D>,
352        rhs: &Variable<D>,
353        out: &Variable<D>,
354    ) -> core::fmt::Result {
355        let item_out = out.item();
356        let index = out.item().vectorization;
357
358        let out = out.fmt_left();
359        writeln!(f, "{out} = {item_out}{{")?;
360        for i in 0..index {
361            let lhsi = lhs.index(i);
362            let rhsi = rhs.index(i);
363
364            Self::format_scalar(f, lhsi, rhsi, item_out)?;
365            f.write_str(", ")?;
366        }
367
368        f.write_str("};\n")
369    }
370}
371
372pub struct Hypot;
373
374impl<D: Dialect> Binary<D> for Hypot {
375    // Hypot doesn't support half and no half equivalent exists
376    fn format_scalar<Lhs, Rhs>(
377        f: &mut Formatter<'_>,
378        lhs: Lhs,
379        rhs: Rhs,
380        item: Item<D>,
381    ) -> std::fmt::Result
382    where
383        Lhs: Component<D>,
384        Rhs: Component<D>,
385    {
386        let elem = item.elem;
387        let lhs = lhs.to_string();
388        let rhs = rhs.to_string();
389        match elem {
390            Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
391                let lhs = format!("float({lhs})");
392                let rhs = format!("float({rhs})");
393                write!(f, "{elem}(")?;
394                D::compile_instruction_hypot(f, &lhs, &rhs, Elem::F32)?;
395                write!(f, ")")
396            }
397            _ => D::compile_instruction_hypot(f, &lhs, &rhs, elem),
398        }
399    }
400
401    // Hypot doesn't support half and no half equivalent exists
402    fn unroll_vec(
403        f: &mut Formatter<'_>,
404        lhs: &Variable<D>,
405        rhs: &Variable<D>,
406        out: &Variable<D>,
407    ) -> core::fmt::Result {
408        let item_out = out.item();
409        let index = out.item().vectorization;
410
411        let out = out.fmt_left();
412        writeln!(f, "{out} = {item_out}{{")?;
413        for i in 0..index {
414            let lhsi = lhs.index(i);
415            let rhsi = rhs.index(i);
416
417            Self::format_scalar(f, lhsi, rhsi, item_out)?;
418            f.write_str(", ")?;
419        }
420
421        f.write_str("};\n")
422    }
423}
424
425pub struct Rhypot;
426
427impl<D: Dialect> Binary<D> for Rhypot {
428    // Rhypot doesn't support half and no half equivalent exists
429    fn format_scalar<Lhs, Rhs>(
430        f: &mut Formatter<'_>,
431        lhs: Lhs,
432        rhs: Rhs,
433        item: Item<D>,
434    ) -> std::fmt::Result
435    where
436        Lhs: Component<D>,
437        Rhs: Component<D>,
438    {
439        let elem = item.elem;
440        let lhs = lhs.to_string();
441        let rhs = rhs.to_string();
442        match elem {
443            Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
444                let lhs = format!("float({lhs})");
445                let rhs = format!("float({rhs})");
446                write!(f, "{elem}(")?;
447                D::compile_instruction_rhypot(f, &lhs, &rhs, Elem::F32)?;
448                write!(f, ")")
449            }
450            _ => D::compile_instruction_rhypot(f, &lhs, &rhs, elem),
451        }
452    }
453
454    // Rhypot doesn't support half and no half equivalent exists
455    fn unroll_vec(
456        f: &mut Formatter<'_>,
457        lhs: &Variable<D>,
458        rhs: &Variable<D>,
459        out: &Variable<D>,
460    ) -> core::fmt::Result {
461        let item_out = out.item();
462        let index = out.item().vectorization;
463
464        let out = out.fmt_left();
465        writeln!(f, "{out} = {item_out}{{")?;
466        for i in 0..index {
467            let lhsi = lhs.index(i);
468            let rhsi = rhs.index(i);
469
470            Self::format_scalar(f, lhsi, rhsi, item_out)?;
471            f.write_str(", ")?;
472        }
473
474        f.write_str("};\n")
475    }
476}
477
478pub struct Max;
479
480impl<D: Dialect> Binary<D> for Max {
481    fn format_scalar<Lhs: Display, Rhs: Display>(
482        f: &mut std::fmt::Formatter<'_>,
483        lhs: Lhs,
484        rhs: Rhs,
485        item: Item<D>,
486    ) -> std::fmt::Result {
487        D::compile_instruction_max_function_name(f, item)?;
488        write!(f, "({lhs}, {rhs})")
489    }
490}
491
492pub struct Min;
493
494impl<D: Dialect> Binary<D> for Min {
495    fn format_scalar<Lhs: Display, Rhs: Display>(
496        f: &mut std::fmt::Formatter<'_>,
497        lhs: Lhs,
498        rhs: Rhs,
499        item: Item<D>,
500    ) -> std::fmt::Result {
501        D::compile_instruction_min_function_name(f, item)?;
502        write!(f, "({lhs}, {rhs})")
503    }
504}
505
/// Formatter for indexed stores of the shape `out_list[index] = value`.
pub struct IndexAssign;
/// Formatter for indexed loads of the shape `out = list[index]`.
pub struct Index;
508
509impl IndexAssign {
510    pub fn format<D: Dialect>(
511        f: &mut Formatter<'_>,
512        index: &Variable<D>,
513        value: &Variable<D>,
514        out_list: &Variable<D>,
515        line_size: u32,
516    ) -> std::fmt::Result {
517        if matches!(
518            out_list,
519            Variable::LocalMut { .. } | Variable::LocalConst { .. }
520        ) {
521            return IndexAssignVector::format(f, index, value, out_list);
522        };
523
524        if line_size > 0 {
525            let mut item = out_list.item();
526            item.vectorization = line_size as usize;
527            let addr_space = D::address_space_for_variable(out_list);
528            let qualifier = out_list.const_qualifier();
529            let tmp = Variable::tmp_declared(item);
530
531            writeln!(
532                f,
533                "{qualifier} {addr_space}{item} *{tmp} = reinterpret_cast<{qualifier} {item}*>({out_list});"
534            )?;
535
536            return IndexAssign::format(f, index, value, &tmp, 0);
537        }
538
539        let out_item = out_list.item();
540
541        if index.item().vectorization == 1 {
542            write!(f, "{}[{index}] = ", out_list.fmt_left())?;
543            Self::format_scalar(f, *index, *value, out_item)?;
544            f.write_str(";\n")
545        } else {
546            Self::unroll_vec(f, index, value, out_list)
547        }
548    }
549    fn format_scalar<D: Dialect, Lhs, Rhs>(
550        f: &mut Formatter<'_>,
551        _lhs: Lhs,
552        rhs: Rhs,
553        item_out: Item<D>,
554    ) -> std::fmt::Result
555    where
556        Lhs: Component<D>,
557        Rhs: Component<D>,
558    {
559        let item_rhs = rhs.item();
560
561        let format_vec = |f: &mut Formatter<'_>, cast: bool| {
562            writeln!(f, "{item_out}{{")?;
563            for i in 0..item_out.vectorization {
564                if cast {
565                    writeln!(f, "{}({}),", item_out.elem, rhs.index(i))?;
566                } else {
567                    writeln!(f, "{},", rhs.index(i))?;
568                }
569            }
570            f.write_str("}")?;
571
572            Ok(())
573        };
574
575        if item_out.vectorization != item_rhs.vectorization {
576            format_vec(f, item_out != item_rhs)
577        } else if item_out.elem != item_rhs.elem {
578            if item_out.vectorization > 1 {
579                format_vec(f, true)?;
580            } else {
581                write!(f, "{}({rhs})", item_out.elem)?;
582            }
583            Ok(())
584        } else if rhs.is_const() && item_rhs.vectorization > 1 {
585            // Reinterpret cast in case rhs is optimized
586            write!(f, "reinterpret_cast<")?;
587            D::compile_local_memory_qualifier(f)?;
588            write!(f, " {item_out} const&>({rhs})")
589        } else {
590            write!(f, "{rhs}")
591        }
592    }
593
594    fn unroll_vec<D: Dialect>(
595        f: &mut Formatter<'_>,
596        lhs: &Variable<D>,
597        rhs: &Variable<D>,
598        out: &Variable<D>,
599    ) -> std::fmt::Result {
600        let item_lhs = lhs.item();
601        let out_item = out.item();
602        let out = out.fmt_left();
603
604        for i in 0..item_lhs.vectorization {
605            let lhsi = lhs.index(i);
606            let rhsi = rhs.index(i);
607            write!(f, "{out}[{lhs}] = ")?;
608            Self::format_scalar(f, lhsi, rhsi, out_item)?;
609            f.write_str(";\n")?;
610        }
611
612        Ok(())
613    }
614}
615
616impl Index {
617    pub(crate) fn format<D: Dialect>(
618        f: &mut Formatter<'_>,
619        list: &Variable<D>,
620        index: &Variable<D>,
621        out: &Variable<D>,
622        line_size: u32,
623    ) -> std::fmt::Result {
624        if matches!(
625            list,
626            Variable::LocalMut { .. } | Variable::LocalConst { .. } | Variable::ConstantScalar(..)
627        ) {
628            return IndexVector::format(f, list, index, out);
629        }
630
631        if line_size > 0 {
632            let mut item = list.item();
633            item.vectorization = line_size as usize;
634            let addr_space = D::address_space_for_variable(list);
635            let qualifier = list.const_qualifier();
636            let tmp = Variable::tmp_declared(item);
637
638            writeln!(
639                f,
640                "{qualifier} {addr_space}{item} *{tmp} = reinterpret_cast<{qualifier} {item}*>({list});"
641            )?;
642
643            return Index::format(f, &tmp, index, out, 0);
644        }
645
646        let item_out = out.item();
647        if let Elem::Atomic(inner) = item_out.elem {
648            let addr_space = D::address_space_for_variable(list);
649            writeln!(f, "{addr_space}{inner}* {out} = &{list}[{index}];")
650        } else if matches!(item_out.elem, Elem::Barrier(_)) {
651            let addr_space = D::address_space_for_variable(list);
652            writeln!(f, "{addr_space}{}& {out} = {list}[{index}];", item_out.elem)
653        } else {
654            let out = out.fmt_left();
655            write!(f, "{out} = ")?;
656            Self::format_scalar(f, *list, *index, item_out)?;
657            f.write_str(";\n")
658        }
659    }
660
661    fn format_scalar<D: Dialect, Lhs, Rhs>(
662        f: &mut Formatter<'_>,
663        lhs: Lhs,
664        rhs: Rhs,
665        item_out: Item<D>,
666    ) -> std::fmt::Result
667    where
668        Lhs: Component<D>,
669        Rhs: Component<D>,
670    {
671        let item_lhs = lhs.item();
672
673        let format_vec = |f: &mut Formatter<'_>| {
674            writeln!(f, "{item_out}{{")?;
675            for i in 0..item_out.vectorization {
676                write!(f, "{}({lhs}[{rhs}].i_{i}),", item_out.elem)?;
677            }
678            f.write_str("}")?;
679
680            Ok(())
681        };
682
683        if item_out.elem != item_lhs.elem {
684            if item_out.vectorization > 1 {
685                format_vec(f)
686            } else {
687                write!(f, "{}({lhs}[{rhs}])", item_out.elem)
688            }
689        } else {
690            write!(f, "{lhs}[{rhs}]")
691        }
692    }
693}
694
/// The goal is to support indexing of vectorized types.
///
/// A constant index is resolved while generating the code, through `Variable::index`;
/// any other index is written as a `reinterpret_cast` of the variable's address to an
/// element pointer that is then subscripted.
///
/// # Examples
///
/// ```c
/// float4 var;
/// float item = var[0]; // We want that.
/// float item = var.x; // So we compile to that.
/// ```
struct IndexVector<D: Dialect> {
    _dialect: PhantomData<D>,
}
707
/// The goal is to support indexing of vectorized types.
///
/// A constant index is resolved while generating the code, through `Variable::index`;
/// any other index is written as pointer arithmetic on the variable's address cast to an
/// element pointer.
///
/// # Examples
///
/// ```c
/// float4 var;
///
/// var[0] = 1.0; // We want that.
/// var.x = 1.0;  // So we compile to that.
/// ```
struct IndexAssignVector<D: Dialect> {
    _dialect: PhantomData<D>,
}
721
722impl<D: Dialect> IndexVector<D> {
723    fn format(
724        f: &mut Formatter<'_>,
725        lhs: &Variable<D>,
726        rhs: &Variable<D>,
727        out: &Variable<D>,
728    ) -> std::fmt::Result {
729        match rhs {
730            Variable::ConstantScalar(value, _elem) => {
731                let index = value.as_usize();
732                let out = out.index(index);
733                let lhs = lhs.index(index);
734                let out = out.fmt_left();
735                writeln!(f, "{out} = {lhs};")
736            }
737            _ => {
738                let elem = out.elem();
739                let qualifier = out.const_qualifier();
740                let addr_space = D::address_space_for_variable(out);
741                let out = out.fmt_left();
742                writeln!(
743                    f,
744                    "{out} = reinterpret_cast<{addr_space}{elem}{qualifier}*>(&{lhs})[{rhs}];"
745                )
746            }
747        }
748    }
749}
750
751impl<D: Dialect> IndexAssignVector<D> {
752    fn format(
753        f: &mut Formatter<'_>,
754        lhs: &Variable<D>,
755        rhs: &Variable<D>,
756        out: &Variable<D>,
757    ) -> std::fmt::Result {
758        let index = match lhs {
759            Variable::ConstantScalar(value, _) => value.as_usize(),
760            _ => {
761                let elem = out.elem();
762                let addr_space = D::address_space_for_variable(out);
763                return writeln!(f, "*(({addr_space}{elem}*)&{out} + {lhs}) = {rhs};");
764            }
765        };
766
767        let out = out.index(index);
768        let rhs = rhs.index(index);
769
770        writeln!(f, "{out} = {rhs};")
771    }
772}