1use crate::shared::FmtLeft;
2
3use super::{Component, Dialect, Elem, Item, Variable};
4use std::{
5 fmt::{Display, Formatter},
6 marker::PhantomData,
7};
8
9pub trait Binary<D: Dialect> {
10 fn format(
11 f: &mut Formatter<'_>,
12 lhs: &Variable<D>,
13 rhs: &Variable<D>,
14 out: &Variable<D>,
15 ) -> std::fmt::Result {
16 let out_item = out.item();
17 if out.item().vectorization == 1 {
18 let out = out.fmt_left();
19 write!(f, "{out} = ")?;
20 Self::format_scalar(f, *lhs, *rhs, out_item)?;
21 f.write_str(";\n")
22 } else {
23 Self::unroll_vec(f, lhs, rhs, out)
24 }
25 }
26
27 fn format_scalar<Lhs, Rhs>(
28 f: &mut Formatter<'_>,
29 lhs: Lhs,
30 rhs: Rhs,
31 item: Item<D>,
32 ) -> std::fmt::Result
33 where
34 Lhs: Component<D>,
35 Rhs: Component<D>;
36
37 fn unroll_vec(
38 f: &mut Formatter<'_>,
39 lhs: &Variable<D>,
40 rhs: &Variable<D>,
41 out: &Variable<D>,
42 ) -> core::fmt::Result {
43 let optimized = Variable::optimized_args([*lhs, *rhs, *out]);
44 let [lhs, rhs, out_optimized] = optimized.args;
45
46 let item_out_original = out.item();
47 let item_out_optimized = out_optimized.item();
48
49 let index = match optimized.optimization_factor {
50 Some(factor) => item_out_original.vectorization / factor,
51 None => item_out_optimized.vectorization,
52 };
53
54 let mut write_op =
55 |lhs: &Variable<D>, rhs: &Variable<D>, out: &Variable<D>, item_out: Item<D>| {
56 let out = out.fmt_left();
57 writeln!(f, "{out} = {item_out}{{")?;
58 for i in 0..index {
59 let lhsi = lhs.index(i);
60 let rhsi = rhs.index(i);
61
62 Self::format_scalar(f, lhsi, rhsi, item_out)?;
63 f.write_str(", ")?;
64 }
65
66 f.write_str("};\n")
67 };
68
69 if item_out_original == item_out_optimized {
70 write_op(&lhs, &rhs, out, item_out_optimized)
71 } else {
72 let out_tmp = Variable::tmp(item_out_optimized);
73 write_op(&lhs, &rhs, &out_tmp, item_out_optimized)?;
74 let addr_space = D::address_space_for_variable(out);
75 let out = out.fmt_left();
76
77 writeln!(
78 f,
79 "{out} = reinterpret_cast<{addr_space}{item_out_original}&>({out_tmp});\n"
80 )?;
81
82 Ok(())
83 }
84 }
85}
86
/// Declares a public unit struct `$name` whose [`Binary`] implementation renders the infix
/// expression `lhs $op rhs`.
macro_rules! operator {
    ($name:ident, $op:expr) => {
        pub struct $name;

        impl<D: Dialect> Binary<D> for $name {
            fn format_scalar<Lhs: Display, Rhs: Display>(
                f: &mut std::fmt::Formatter<'_>,
                lhs: Lhs,
                rhs: Rhs,
                out_item: Item<D>,
            ) -> std::fmt::Result {
                let out_elem = out_item.elem();
                let is_narrow_int = matches!(
                    out_elem,
                    Elem::<D>::I16 | Elem::<D>::U16 | Elem::<D>::I8 | Elem::<D>::U8
                );

                if is_narrow_int {
                    // 8- and 16-bit integer outputs get the whole expression wrapped in a
                    // conversion to the output element type.
                    write!(f, "{out_elem}({lhs} {} {rhs})", $op)
                } else {
                    write!(f, "{lhs} {} {rhs}", $op)
                }
            }
        }
    };
}
112
// Infix operators generated by `operator!`: each line declares a unit struct that
// implements `Binary` by writing `lhs <token> rhs` with the token given here.

// Arithmetic.
operator!(Add, "+");
operator!(Sub, "-");
operator!(Div, "/");
operator!(Mul, "*");
operator!(Modulo, "%");
// Comparisons.
operator!(Equal, "==");
operator!(NotEqual, "!=");
operator!(Lower, "<");
operator!(LowerEqual, "<=");
operator!(Greater, ">");
operator!(GreaterEqual, ">=");
// Shifts and bitwise operators.
operator!(ShiftLeft, "<<");
operator!(ShiftRight, ">>");
operator!(BitwiseOr, "|");
operator!(BitwiseAnd, "&");
operator!(BitwiseXor, "^");
// Logical operators.
operator!(Or, "||");
operator!(And, "&&");
131
132pub struct FastDiv;
133
134impl<D: Dialect> Binary<D> for FastDiv {
135 fn format_scalar<Lhs: Display, Rhs: Display>(
136 f: &mut std::fmt::Formatter<'_>,
137 lhs: Lhs,
138 rhs: Rhs,
139 _out_item: Item<D>,
140 ) -> std::fmt::Result {
141 write!(f, "__fdividef({lhs}, {rhs})")
143 }
144}
145
146pub struct HiMul;
147
148impl<D: Dialect> Binary<D> for HiMul {
149 fn format_scalar<Lhs: Display, Rhs: Display>(
151 f: &mut std::fmt::Formatter<'_>,
152 lhs: Lhs,
153 rhs: Rhs,
154 out: Item<D>,
155 ) -> std::fmt::Result {
156 let out_elem = out.elem;
157 match out_elem {
158 Elem::I32 => write!(f, "__mulhi({lhs}, {rhs})"),
159 Elem::U32 => write!(f, "__umulhi({lhs}, {rhs})"),
160 Elem::I64 => write!(f, "__mul64hi({lhs}, {rhs})"),
161 Elem::U64 => write!(f, "__umul64hi({lhs}, {rhs})"),
162 _ => writeln!(f, "#error HiMul only supports 32 and 64 bit ints"),
163 }
164 }
165
166 fn unroll_vec(
168 f: &mut Formatter<'_>,
169 lhs: &Variable<D>,
170 rhs: &Variable<D>,
171 out: &Variable<D>,
172 ) -> core::fmt::Result {
173 let item_out = out.item();
174 let index = out.item().vectorization;
175
176 let out = out.fmt_left();
177 writeln!(f, "{out} = {item_out}{{")?;
178 for i in 0..index {
179 let lhsi = lhs.index(i);
180 let rhsi = rhs.index(i);
181
182 Self::format_scalar(f, lhsi, rhsi, item_out)?;
183 f.write_str(", ")?;
184 }
185
186 f.write_str("};\n")
187 }
188}
189
/// Binary instruction whose scalar expression is produced by the dialect's
/// `compile_saturating_add`.
pub struct SaturatingAdd;

impl<D: Dialect> Binary<D> for SaturatingAdd {
    /// Forwards the formatter, both operands and the output item unchanged to
    /// `D::compile_saturating_add` and returns its result.
    fn format_scalar<Lhs: Display, Rhs: Display>(
        f: &mut std::fmt::Formatter<'_>,
        lhs: Lhs,
        rhs: Rhs,
        out: Item<D>,
    ) -> std::fmt::Result {
        D::compile_saturating_add(f, lhs, rhs, out)
    }
}
202
/// Binary instruction whose scalar expression is produced by the dialect's
/// `compile_saturating_sub`.
pub struct SaturatingSub;

impl<D: Dialect> Binary<D> for SaturatingSub {
    /// Forwards the formatter, both operands and the output item unchanged to
    /// `D::compile_saturating_sub` and returns its result.
    fn format_scalar<Lhs: Display, Rhs: Display>(
        f: &mut std::fmt::Formatter<'_>,
        lhs: Lhs,
        rhs: Rhs,
        out: Item<D>,
    ) -> std::fmt::Result {
        D::compile_saturating_sub(f, lhs, rhs, out)
    }
}
215
216pub struct Powf;
217
218impl<D: Dialect> Binary<D> for Powf {
219 fn format_scalar<Lhs: Display, Rhs: Display>(
221 f: &mut std::fmt::Formatter<'_>,
222 lhs: Lhs,
223 rhs: Rhs,
224 item: Item<D>,
225 ) -> std::fmt::Result {
226 let elem = item.elem;
227 let lhs = lhs.to_string();
228 let rhs = rhs.to_string();
229 match elem {
230 Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
231 let lhs = format!("float({lhs})");
232 let rhs = format!("float({rhs})");
233 write!(f, "{elem}(")?;
234 D::compile_instruction_powf(f, &lhs, &rhs, Elem::F32)?;
235 write!(f, ")")
236 }
237 _ => D::compile_instruction_powf(f, &lhs, &rhs, elem),
238 }
239 }
240
241 fn unroll_vec(
243 f: &mut Formatter<'_>,
244 lhs: &Variable<D>,
245 rhs: &Variable<D>,
246 out: &Variable<D>,
247 ) -> core::fmt::Result {
248 let item_out = out.item();
249 let index = out.item().vectorization;
250
251 let out = out.fmt_left();
252 writeln!(f, "{out} = {item_out}{{")?;
253 for i in 0..index {
254 let lhsi = lhs.index(i);
255 let rhsi = rhs.index(i);
256
257 Self::format_scalar(f, lhsi, rhsi, item_out)?;
258 f.write_str(", ")?;
259 }
260
261 f.write_str("};\n")
262 }
263}
264
265pub struct FastPowf;
266
267impl<D: Dialect> Binary<D> for FastPowf {
268 fn format_scalar<Lhs: Display, Rhs: Display>(
270 f: &mut std::fmt::Formatter<'_>,
271 lhs: Lhs,
272 rhs: Rhs,
273 _item: Item<D>,
274 ) -> std::fmt::Result {
275 write!(f, "__powf({lhs}, {rhs})")
276 }
277}
278
279pub struct Powi;
280
281impl<D: Dialect> Binary<D> for Powi {
282 fn format_scalar<Lhs: Display, Rhs: Display>(
284 f: &mut std::fmt::Formatter<'_>,
285 lhs: Lhs,
286 rhs: Rhs,
287 item: Item<D>,
288 ) -> std::fmt::Result {
289 let elem = item.elem;
290 let lhs = lhs.to_string();
291 let rhs = rhs.to_string();
292 match elem {
293 Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
294 let lhs = format!("float({lhs})");
295
296 write!(f, "{elem}(")?;
297 D::compile_instruction_powf(f, &lhs, &rhs, Elem::F32)?;
298 write!(f, ")")
299 }
300 _ => D::compile_instruction_powf(f, &lhs, &rhs, elem),
301 }
302 }
303
304 fn unroll_vec(
306 f: &mut Formatter<'_>,
307 lhs: &Variable<D>,
308 rhs: &Variable<D>,
309 out: &Variable<D>,
310 ) -> core::fmt::Result {
311 let item_out = out.item();
312 let index = out.item().vectorization;
313
314 let out = out.fmt_left();
315 writeln!(f, "{out} = {item_out}{{")?;
316 for i in 0..index {
317 let lhsi = lhs.index(i);
318 let rhsi = rhs.index(i);
319
320 Self::format_scalar(f, lhsi, rhsi, item_out)?;
321 f.write_str(", ")?;
322 }
323
324 f.write_str("};\n")
325 }
326}
327pub struct ArcTan2;
328
329impl<D: Dialect> Binary<D> for ArcTan2 {
330 fn format_scalar<Lhs: Display, Rhs: Display>(
332 f: &mut std::fmt::Formatter<'_>,
333 lhs: Lhs,
334 rhs: Rhs,
335 item: Item<D>,
336 ) -> std::fmt::Result {
337 let elem = item.elem;
338 match elem {
339 Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
340 write!(f, "{elem}(atan2(float({lhs}), float({rhs})))")
341 }
342 _ => {
343 write!(f, "atan2({lhs}, {rhs})")
344 }
345 }
346 }
347
348 fn unroll_vec(
350 f: &mut Formatter<'_>,
351 lhs: &Variable<D>,
352 rhs: &Variable<D>,
353 out: &Variable<D>,
354 ) -> core::fmt::Result {
355 let item_out = out.item();
356 let index = out.item().vectorization;
357
358 let out = out.fmt_left();
359 writeln!(f, "{out} = {item_out}{{")?;
360 for i in 0..index {
361 let lhsi = lhs.index(i);
362 let rhsi = rhs.index(i);
363
364 Self::format_scalar(f, lhsi, rhsi, item_out)?;
365 f.write_str(", ")?;
366 }
367
368 f.write_str("};\n")
369 }
370}
371
372pub struct Hypot;
373
374impl<D: Dialect> Binary<D> for Hypot {
375 fn format_scalar<Lhs, Rhs>(
377 f: &mut Formatter<'_>,
378 lhs: Lhs,
379 rhs: Rhs,
380 item: Item<D>,
381 ) -> std::fmt::Result
382 where
383 Lhs: Component<D>,
384 Rhs: Component<D>,
385 {
386 let elem = item.elem;
387 let lhs = lhs.to_string();
388 let rhs = rhs.to_string();
389 match elem {
390 Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
391 let lhs = format!("float({lhs})");
392 let rhs = format!("float({rhs})");
393 write!(f, "{elem}(")?;
394 D::compile_instruction_hypot(f, &lhs, &rhs, Elem::F32)?;
395 write!(f, ")")
396 }
397 _ => D::compile_instruction_hypot(f, &lhs, &rhs, elem),
398 }
399 }
400
401 fn unroll_vec(
403 f: &mut Formatter<'_>,
404 lhs: &Variable<D>,
405 rhs: &Variable<D>,
406 out: &Variable<D>,
407 ) -> core::fmt::Result {
408 let item_out = out.item();
409 let index = out.item().vectorization;
410
411 let out = out.fmt_left();
412 writeln!(f, "{out} = {item_out}{{")?;
413 for i in 0..index {
414 let lhsi = lhs.index(i);
415 let rhsi = rhs.index(i);
416
417 Self::format_scalar(f, lhsi, rhsi, item_out)?;
418 f.write_str(", ")?;
419 }
420
421 f.write_str("};\n")
422 }
423}
424
425pub struct Rhypot;
426
427impl<D: Dialect> Binary<D> for Rhypot {
428 fn format_scalar<Lhs, Rhs>(
430 f: &mut Formatter<'_>,
431 lhs: Lhs,
432 rhs: Rhs,
433 item: Item<D>,
434 ) -> std::fmt::Result
435 where
436 Lhs: Component<D>,
437 Rhs: Component<D>,
438 {
439 let elem = item.elem;
440 let lhs = lhs.to_string();
441 let rhs = rhs.to_string();
442 match elem {
443 Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
444 let lhs = format!("float({lhs})");
445 let rhs = format!("float({rhs})");
446 write!(f, "{elem}(")?;
447 D::compile_instruction_rhypot(f, &lhs, &rhs, Elem::F32)?;
448 write!(f, ")")
449 }
450 _ => D::compile_instruction_rhypot(f, &lhs, &rhs, elem),
451 }
452 }
453
454 fn unroll_vec(
456 f: &mut Formatter<'_>,
457 lhs: &Variable<D>,
458 rhs: &Variable<D>,
459 out: &Variable<D>,
460 ) -> core::fmt::Result {
461 let item_out = out.item();
462 let index = out.item().vectorization;
463
464 let out = out.fmt_left();
465 writeln!(f, "{out} = {item_out}{{")?;
466 for i in 0..index {
467 let lhsi = lhs.index(i);
468 let rhsi = rhs.index(i);
469
470 Self::format_scalar(f, lhsi, rhsi, item_out)?;
471 f.write_str(", ")?;
472 }
473
474 f.write_str("};\n")
475 }
476}
477
478pub struct Max;
479
480impl<D: Dialect> Binary<D> for Max {
481 fn format_scalar<Lhs: Display, Rhs: Display>(
482 f: &mut std::fmt::Formatter<'_>,
483 lhs: Lhs,
484 rhs: Rhs,
485 item: Item<D>,
486 ) -> std::fmt::Result {
487 D::compile_instruction_max_function_name(f, item)?;
488 write!(f, "({lhs}, {rhs})")
489 }
490}
491
492pub struct Min;
493
494impl<D: Dialect> Binary<D> for Min {
495 fn format_scalar<Lhs: Display, Rhs: Display>(
496 f: &mut std::fmt::Formatter<'_>,
497 lhs: Lhs,
498 rhs: Rhs,
499 item: Item<D>,
500 ) -> std::fmt::Result {
501 D::compile_instruction_min_function_name(f, item)?;
502 write!(f, "({lhs}, {rhs})")
503 }
504}
505
/// Formatter for indexed stores: its `format` entry point takes an index, a value and the
/// destination list, and writes statements of the form `list[index] = value;`.
pub struct IndexAssign;
/// Formatter for indexed loads: its `format` entry point takes a list, an index and an output
/// variable, and writes statements of the form `out = list[index];`.
pub struct Index;
508
509impl IndexAssign {
510 pub fn format<D: Dialect>(
511 f: &mut Formatter<'_>,
512 index: &Variable<D>,
513 value: &Variable<D>,
514 out_list: &Variable<D>,
515 line_size: u32,
516 ) -> std::fmt::Result {
517 if matches!(
518 out_list,
519 Variable::LocalMut { .. } | Variable::LocalConst { .. }
520 ) {
521 return IndexAssignVector::format(f, index, value, out_list);
522 };
523
524 if line_size > 0 {
525 let mut item = out_list.item();
526 item.vectorization = line_size as usize;
527 let addr_space = D::address_space_for_variable(out_list);
528 let qualifier = out_list.const_qualifier();
529 let tmp = Variable::tmp_declared(item);
530
531 writeln!(
532 f,
533 "{qualifier} {addr_space}{item} *{tmp} = reinterpret_cast<{qualifier} {item}*>({out_list});"
534 )?;
535
536 return IndexAssign::format(f, index, value, &tmp, 0);
537 }
538
539 let out_item = out_list.item();
540
541 if index.item().vectorization == 1 {
542 write!(f, "{}[{index}] = ", out_list.fmt_left())?;
543 Self::format_scalar(f, *index, *value, out_item)?;
544 f.write_str(";\n")
545 } else {
546 Self::unroll_vec(f, index, value, out_list)
547 }
548 }
549 fn format_scalar<D: Dialect, Lhs, Rhs>(
550 f: &mut Formatter<'_>,
551 _lhs: Lhs,
552 rhs: Rhs,
553 item_out: Item<D>,
554 ) -> std::fmt::Result
555 where
556 Lhs: Component<D>,
557 Rhs: Component<D>,
558 {
559 let item_rhs = rhs.item();
560
561 let format_vec = |f: &mut Formatter<'_>, cast: bool| {
562 writeln!(f, "{item_out}{{")?;
563 for i in 0..item_out.vectorization {
564 if cast {
565 writeln!(f, "{}({}),", item_out.elem, rhs.index(i))?;
566 } else {
567 writeln!(f, "{},", rhs.index(i))?;
568 }
569 }
570 f.write_str("}")?;
571
572 Ok(())
573 };
574
575 if item_out.vectorization != item_rhs.vectorization {
576 format_vec(f, item_out != item_rhs)
577 } else if item_out.elem != item_rhs.elem {
578 if item_out.vectorization > 1 {
579 format_vec(f, true)?;
580 } else {
581 write!(f, "{}({rhs})", item_out.elem)?;
582 }
583 Ok(())
584 } else if rhs.is_const() && item_rhs.vectorization > 1 {
585 write!(f, "reinterpret_cast<")?;
587 D::compile_local_memory_qualifier(f)?;
588 write!(f, " {item_out} const&>({rhs})")
589 } else {
590 write!(f, "{rhs}")
591 }
592 }
593
594 fn unroll_vec<D: Dialect>(
595 f: &mut Formatter<'_>,
596 lhs: &Variable<D>,
597 rhs: &Variable<D>,
598 out: &Variable<D>,
599 ) -> std::fmt::Result {
600 let item_lhs = lhs.item();
601 let out_item = out.item();
602 let out = out.fmt_left();
603
604 for i in 0..item_lhs.vectorization {
605 let lhsi = lhs.index(i);
606 let rhsi = rhs.index(i);
607 write!(f, "{out}[{lhs}] = ")?;
608 Self::format_scalar(f, lhsi, rhsi, out_item)?;
609 f.write_str(";\n")?;
610 }
611
612 Ok(())
613 }
614}
615
616impl Index {
617 pub(crate) fn format<D: Dialect>(
618 f: &mut Formatter<'_>,
619 list: &Variable<D>,
620 index: &Variable<D>,
621 out: &Variable<D>,
622 line_size: u32,
623 ) -> std::fmt::Result {
624 if matches!(
625 list,
626 Variable::LocalMut { .. } | Variable::LocalConst { .. } | Variable::ConstantScalar(..)
627 ) {
628 return IndexVector::format(f, list, index, out);
629 }
630
631 if line_size > 0 {
632 let mut item = list.item();
633 item.vectorization = line_size as usize;
634 let addr_space = D::address_space_for_variable(list);
635 let qualifier = list.const_qualifier();
636 let tmp = Variable::tmp_declared(item);
637
638 writeln!(
639 f,
640 "{qualifier} {addr_space}{item} *{tmp} = reinterpret_cast<{qualifier} {item}*>({list});"
641 )?;
642
643 return Index::format(f, &tmp, index, out, 0);
644 }
645
646 let item_out = out.item();
647 if let Elem::Atomic(inner) = item_out.elem {
648 let addr_space = D::address_space_for_variable(list);
649 writeln!(f, "{addr_space}{inner}* {out} = &{list}[{index}];")
650 } else if matches!(item_out.elem, Elem::Barrier(_)) {
651 let addr_space = D::address_space_for_variable(list);
652 writeln!(f, "{addr_space}{}& {out} = {list}[{index}];", item_out.elem)
653 } else {
654 let out = out.fmt_left();
655 write!(f, "{out} = ")?;
656 Self::format_scalar(f, *list, *index, item_out)?;
657 f.write_str(";\n")
658 }
659 }
660
661 fn format_scalar<D: Dialect, Lhs, Rhs>(
662 f: &mut Formatter<'_>,
663 lhs: Lhs,
664 rhs: Rhs,
665 item_out: Item<D>,
666 ) -> std::fmt::Result
667 where
668 Lhs: Component<D>,
669 Rhs: Component<D>,
670 {
671 let item_lhs = lhs.item();
672
673 let format_vec = |f: &mut Formatter<'_>| {
674 writeln!(f, "{item_out}{{")?;
675 for i in 0..item_out.vectorization {
676 write!(f, "{}({lhs}[{rhs}].i_{i}),", item_out.elem)?;
677 }
678 f.write_str("}")?;
679
680 Ok(())
681 };
682
683 if item_out.elem != item_lhs.elem {
684 if item_out.vectorization > 1 {
685 format_vec(f)
686 } else {
687 write!(f, "{}({lhs}[{rhs}])", item_out.elem)
688 }
689 } else {
690 write!(f, "{lhs}[{rhs}]")
691 }
692 }
693}
694
/// Formatter for reading one component out of a variable; its `format` function writes
/// `out = <component of lhs>;`.
struct IndexVector<D: Dialect> {
    // Zero-sized marker: only ties the type to a dialect, no data is stored.
    _dialect: PhantomData<D>,
}
707
/// Formatter for writing one component of a variable; its `format` function writes
/// `<component of out> = <rhs>;`.
struct IndexAssignVector<D: Dialect> {
    // Zero-sized marker: only ties the type to a dialect, no data is stored.
    _dialect: PhantomData<D>,
}
721
722impl<D: Dialect> IndexVector<D> {
723 fn format(
724 f: &mut Formatter<'_>,
725 lhs: &Variable<D>,
726 rhs: &Variable<D>,
727 out: &Variable<D>,
728 ) -> std::fmt::Result {
729 match rhs {
730 Variable::ConstantScalar(value, _elem) => {
731 let index = value.as_usize();
732 let out = out.index(index);
733 let lhs = lhs.index(index);
734 let out = out.fmt_left();
735 writeln!(f, "{out} = {lhs};")
736 }
737 _ => {
738 let elem = out.elem();
739 let qualifier = out.const_qualifier();
740 let addr_space = D::address_space_for_variable(out);
741 let out = out.fmt_left();
742 writeln!(
743 f,
744 "{out} = reinterpret_cast<{addr_space}{elem}{qualifier}*>(&{lhs})[{rhs}];"
745 )
746 }
747 }
748 }
749}
750
751impl<D: Dialect> IndexAssignVector<D> {
752 fn format(
753 f: &mut Formatter<'_>,
754 lhs: &Variable<D>,
755 rhs: &Variable<D>,
756 out: &Variable<D>,
757 ) -> std::fmt::Result {
758 let index = match lhs {
759 Variable::ConstantScalar(value, _) => value.as_usize(),
760 _ => {
761 let elem = out.elem();
762 let addr_space = D::address_space_for_variable(out);
763 return writeln!(f, "*(({addr_space}{elem}*)&{out} + {lhs}) = {rhs};");
764 }
765 };
766
767 let out = out.index(index);
768 let rhs = rhs.index(index);
769
770 writeln!(f, "{out} = {rhs};")
771 }
772}