1use crate::shared::FmtLeft;
2
3use super::{Component, Dialect, Elem, Item, Variable};
4use std::{
5 fmt::{Display, Formatter},
6 marker::PhantomData,
7};
8
9pub trait Binary<D: Dialect> {
10 fn format(
11 f: &mut Formatter<'_>,
12 lhs: &Variable<D>,
13 rhs: &Variable<D>,
14 out: &Variable<D>,
15 ) -> std::fmt::Result {
16 let out_item = out.item();
17 if out.item().vectorization == 1 {
18 let out = out.fmt_left();
19 write!(f, "{out} = ")?;
20 Self::format_scalar(f, *lhs, *rhs, out_item)?;
21 f.write_str(";\n")
22 } else {
23 Self::unroll_vec(f, lhs, rhs, out)
24 }
25 }
26
27 fn format_scalar<Lhs, Rhs>(
28 f: &mut Formatter<'_>,
29 lhs: Lhs,
30 rhs: Rhs,
31 item: Item<D>,
32 ) -> std::fmt::Result
33 where
34 Lhs: Component<D>,
35 Rhs: Component<D>;
36
37 fn unroll_vec(
38 f: &mut Formatter<'_>,
39 lhs: &Variable<D>,
40 rhs: &Variable<D>,
41 out: &Variable<D>,
42 ) -> core::fmt::Result {
43 let optimized = Variable::optimized_args([*lhs, *rhs, *out]);
44 let [lhs, rhs, out_optimized] = optimized.args;
45
46 let item_out_original = out.item();
47 let item_out_optimized = out_optimized.item();
48
49 let index = match optimized.optimization_factor {
50 Some(factor) => item_out_original.vectorization / factor,
51 None => item_out_optimized.vectorization,
52 };
53
54 let mut write_op =
55 |lhs: &Variable<D>, rhs: &Variable<D>, out: &Variable<D>, item_out: Item<D>| {
56 let out = out.fmt_left();
57 writeln!(f, "{out} = {item_out}{{")?;
58 for i in 0..index {
59 let lhsi = lhs.index(i);
60 let rhsi = rhs.index(i);
61
62 Self::format_scalar(f, lhsi, rhsi, item_out)?;
63 f.write_str(", ")?;
64 }
65
66 f.write_str("};\n")
67 };
68
69 if item_out_original == item_out_optimized {
70 write_op(&lhs, &rhs, out, item_out_optimized)
71 } else {
72 let out_tmp = Variable::tmp(item_out_optimized);
73 write_op(&lhs, &rhs, &out_tmp, item_out_optimized)?;
74 let addr_space = D::address_space_for_variable(out);
75 let out = out.fmt_left();
76
77 writeln!(
78 f,
79 "{out} = reinterpret_cast<{addr_space}{item_out_original}&>({out_tmp});\n"
80 )?;
81
82 Ok(())
83 }
84 }
85}
86
/// Declares a unit struct `$name` and implements [`Binary`] for it by writing the infix
/// expression `<lhs> $op <rhs>`.
///
/// When the output element is `I16`, `U16`, `I8` or `U8`, the expression is additionally
/// wrapped as `<elem>(...)`.
macro_rules! operator {
    ($name:ident, $op:expr) => {
        pub struct $name;

        impl<D: Dialect> Binary<D> for $name {
            fn format_scalar<Lhs, Rhs>(
                f: &mut std::fmt::Formatter<'_>,
                lhs: Lhs,
                rhs: Rhs,
                out_item: Item<D>,
            ) -> std::fmt::Result
            where
                Lhs: Display,
                Rhs: Display,
            {
                let out_elem = out_item.elem();
                // NOTE(review): presumably the wrap undoes integer promotion of narrow
                // operands in the generated code — confirm.
                let wrap_in_elem = matches!(
                    out_elem,
                    Elem::<D>::I16 | Elem::<D>::U16 | Elem::<D>::I8 | Elem::<D>::U8
                );

                if wrap_in_elem {
                    write!(f, "{out_elem}({lhs} {} {rhs})", $op)
                } else {
                    write!(f, "{lhs} {} {rhs}", $op)
                }
            }
        }
    };
}
112
// Arithmetic operators.
operator!(Add, "+");
operator!(Sub, "-");
operator!(Div, "/");
operator!(Mul, "*");
operator!(Modulo, "%");
// Comparison operators.
operator!(Equal, "==");
operator!(NotEqual, "!=");
operator!(Lower, "<");
operator!(LowerEqual, "<=");
operator!(Greater, ">");
operator!(GreaterEqual, ">=");
// Shift and bitwise operators.
operator!(ShiftLeft, "<<");
operator!(ShiftRight, ">>");
operator!(BitwiseOr, "|");
operator!(BitwiseAnd, "&");
operator!(BitwiseXor, "^");
// Logical operators.
operator!(Or, "||");
operator!(And, "&&");
131
132pub struct FastDiv;
133
134impl<D: Dialect> Binary<D> for FastDiv {
135 fn format_scalar<Lhs: Display, Rhs: Display>(
136 f: &mut std::fmt::Formatter<'_>,
137 lhs: Lhs,
138 rhs: Rhs,
139 _out_item: Item<D>,
140 ) -> std::fmt::Result {
141 write!(f, "__fdividef({lhs}, {rhs})")
143 }
144}
145
146pub struct HiMul;
147
148impl<D: Dialect> Binary<D> for HiMul {
149 fn format_scalar<Lhs: Display, Rhs: Display>(
151 f: &mut std::fmt::Formatter<'_>,
152 lhs: Lhs,
153 rhs: Rhs,
154 out: Item<D>,
155 ) -> std::fmt::Result {
156 let out_elem = out.elem;
157 match out_elem {
158 Elem::I32 => write!(f, "__mulhi({lhs}, {rhs})"),
159 Elem::U32 => write!(f, "__umulhi({lhs}, {rhs})"),
160 Elem::I64 => write!(f, "__mul64hi({lhs}, {rhs})"),
161 Elem::U64 => write!(f, "__umul64hi({lhs}, {rhs})"),
162 _ => writeln!(f, "#error HiMul only supports 32 and 64 bit ints"),
163 }
164 }
165
166 fn unroll_vec(
168 f: &mut Formatter<'_>,
169 lhs: &Variable<D>,
170 rhs: &Variable<D>,
171 out: &Variable<D>,
172 ) -> core::fmt::Result {
173 let item_out = out.item();
174 let index = out.item().vectorization;
175
176 let out = out.fmt_left();
177 writeln!(f, "{out} = {item_out}{{")?;
178 for i in 0..index {
179 let lhsi = lhs.index(i);
180 let rhsi = rhs.index(i);
181
182 Self::format_scalar(f, lhsi, rhsi, item_out)?;
183 f.write_str(", ")?;
184 }
185
186 f.write_str("};\n")
187 }
188}
189
190pub struct SaturatingAdd;
191
192impl<D: Dialect> Binary<D> for SaturatingAdd {
193 fn format_scalar<Lhs: Display, Rhs: Display>(
194 f: &mut std::fmt::Formatter<'_>,
195 lhs: Lhs,
196 rhs: Rhs,
197 out: Item<D>,
198 ) -> std::fmt::Result {
199 D::compile_saturating_add(f, lhs, rhs, out)
200 }
201}
202
203pub struct SaturatingSub;
204
205impl<D: Dialect> Binary<D> for SaturatingSub {
206 fn format_scalar<Lhs: Display, Rhs: Display>(
207 f: &mut std::fmt::Formatter<'_>,
208 lhs: Lhs,
209 rhs: Rhs,
210 out: Item<D>,
211 ) -> std::fmt::Result {
212 D::compile_saturating_sub(f, lhs, rhs, out)
213 }
214}
215
216pub struct Powf;
217
218impl<D: Dialect> Binary<D> for Powf {
219 fn format_scalar<Lhs: Display, Rhs: Display>(
221 f: &mut std::fmt::Formatter<'_>,
222 lhs: Lhs,
223 rhs: Rhs,
224 item: Item<D>,
225 ) -> std::fmt::Result {
226 let elem = item.elem;
227 let lhs = lhs.to_string();
228 let rhs = rhs.to_string();
229 match elem {
230 Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
231 let lhs = format!("float({lhs})");
232 let rhs = format!("float({rhs})");
233 write!(f, "{elem}(")?;
234 D::compile_instruction_powf(f, &lhs, &rhs, Elem::F32)?;
235 write!(f, ")")
236 }
237 _ => D::compile_instruction_powf(f, &lhs, &rhs, elem),
238 }
239 }
240
241 fn unroll_vec(
243 f: &mut Formatter<'_>,
244 lhs: &Variable<D>,
245 rhs: &Variable<D>,
246 out: &Variable<D>,
247 ) -> core::fmt::Result {
248 let item_out = out.item();
249 let index = out.item().vectorization;
250
251 let out = out.fmt_left();
252 writeln!(f, "{out} = {item_out}{{")?;
253 for i in 0..index {
254 let lhsi = lhs.index(i);
255 let rhsi = rhs.index(i);
256
257 Self::format_scalar(f, lhsi, rhsi, item_out)?;
258 f.write_str(", ")?;
259 }
260
261 f.write_str("};\n")
262 }
263}
264
265pub struct FastPowf;
266
267impl<D: Dialect> Binary<D> for FastPowf {
268 fn format_scalar<Lhs: Display, Rhs: Display>(
270 f: &mut std::fmt::Formatter<'_>,
271 lhs: Lhs,
272 rhs: Rhs,
273 _item: Item<D>,
274 ) -> std::fmt::Result {
275 write!(f, "__powf({lhs}, {rhs})")
276 }
277}
278
279pub struct Powi;
280
281impl<D: Dialect> Binary<D> for Powi {
282 fn format_scalar<Lhs: Display, Rhs: Display>(
284 f: &mut std::fmt::Formatter<'_>,
285 lhs: Lhs,
286 rhs: Rhs,
287 item: Item<D>,
288 ) -> std::fmt::Result {
289 let elem = item.elem;
290 let lhs = lhs.to_string();
291 let rhs = rhs.to_string();
292 match elem {
293 Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
294 let lhs = format!("float({lhs})");
295
296 write!(f, "{elem}(")?;
297 D::compile_instruction_powf(f, &lhs, &rhs, Elem::F32)?;
298 write!(f, ")")
299 }
300 Elem::F64 => {
301 let rhs = format!("double({rhs})");
303
304 D::compile_instruction_powf(f, &lhs, &rhs, elem)
305 }
306 _ => D::compile_instruction_powf(f, &lhs, &rhs, elem),
307 }
308 }
309
310 fn unroll_vec(
312 f: &mut Formatter<'_>,
313 lhs: &Variable<D>,
314 rhs: &Variable<D>,
315 out: &Variable<D>,
316 ) -> core::fmt::Result {
317 let item_out = out.item();
318 let index = out.item().vectorization;
319
320 let out = out.fmt_left();
321 writeln!(f, "{out} = {item_out}{{")?;
322 for i in 0..index {
323 let lhsi = lhs.index(i);
324 let rhsi = rhs.index(i);
325
326 Self::format_scalar(f, lhsi, rhsi, item_out)?;
327 f.write_str(", ")?;
328 }
329
330 f.write_str("};\n")
331 }
332}
333pub struct ArcTan2;
334
335impl<D: Dialect> Binary<D> for ArcTan2 {
336 fn format_scalar<Lhs: Display, Rhs: Display>(
338 f: &mut std::fmt::Formatter<'_>,
339 lhs: Lhs,
340 rhs: Rhs,
341 item: Item<D>,
342 ) -> std::fmt::Result {
343 let elem = item.elem;
344 match elem {
345 Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
346 write!(f, "{elem}(atan2(float({lhs}), float({rhs})))")
347 }
348 _ => {
349 write!(f, "atan2({lhs}, {rhs})")
350 }
351 }
352 }
353
354 fn unroll_vec(
356 f: &mut Formatter<'_>,
357 lhs: &Variable<D>,
358 rhs: &Variable<D>,
359 out: &Variable<D>,
360 ) -> core::fmt::Result {
361 let item_out = out.item();
362 let index = out.item().vectorization;
363
364 let out = out.fmt_left();
365 writeln!(f, "{out} = {item_out}{{")?;
366 for i in 0..index {
367 let lhsi = lhs.index(i);
368 let rhsi = rhs.index(i);
369
370 Self::format_scalar(f, lhsi, rhsi, item_out)?;
371 f.write_str(", ")?;
372 }
373
374 f.write_str("};\n")
375 }
376}
377
378pub struct Hypot;
379
380impl<D: Dialect> Binary<D> for Hypot {
381 fn format_scalar<Lhs, Rhs>(
383 f: &mut Formatter<'_>,
384 lhs: Lhs,
385 rhs: Rhs,
386 item: Item<D>,
387 ) -> std::fmt::Result
388 where
389 Lhs: Component<D>,
390 Rhs: Component<D>,
391 {
392 let elem = item.elem;
393 let lhs = lhs.to_string();
394 let rhs = rhs.to_string();
395 match elem {
396 Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
397 let lhs = format!("float({lhs})");
398 let rhs = format!("float({rhs})");
399 write!(f, "{elem}(")?;
400 D::compile_instruction_hypot(f, &lhs, &rhs, Elem::F32)?;
401 write!(f, ")")
402 }
403 _ => D::compile_instruction_hypot(f, &lhs, &rhs, elem),
404 }
405 }
406
407 fn unroll_vec(
409 f: &mut Formatter<'_>,
410 lhs: &Variable<D>,
411 rhs: &Variable<D>,
412 out: &Variable<D>,
413 ) -> core::fmt::Result {
414 let item_out = out.item();
415 let index = out.item().vectorization;
416
417 let out = out.fmt_left();
418 writeln!(f, "{out} = {item_out}{{")?;
419 for i in 0..index {
420 let lhsi = lhs.index(i);
421 let rhsi = rhs.index(i);
422
423 Self::format_scalar(f, lhsi, rhsi, item_out)?;
424 f.write_str(", ")?;
425 }
426
427 f.write_str("};\n")
428 }
429}
430
431pub struct Rhypot;
432
433impl<D: Dialect> Binary<D> for Rhypot {
434 fn format_scalar<Lhs, Rhs>(
436 f: &mut Formatter<'_>,
437 lhs: Lhs,
438 rhs: Rhs,
439 item: Item<D>,
440 ) -> std::fmt::Result
441 where
442 Lhs: Component<D>,
443 Rhs: Component<D>,
444 {
445 let elem = item.elem;
446 let lhs = lhs.to_string();
447 let rhs = rhs.to_string();
448 match elem {
449 Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
450 let lhs = format!("float({lhs})");
451 let rhs = format!("float({rhs})");
452 write!(f, "{elem}(")?;
453 D::compile_instruction_rhypot(f, &lhs, &rhs, Elem::F32)?;
454 write!(f, ")")
455 }
456 _ => D::compile_instruction_rhypot(f, &lhs, &rhs, elem),
457 }
458 }
459
460 fn unroll_vec(
462 f: &mut Formatter<'_>,
463 lhs: &Variable<D>,
464 rhs: &Variable<D>,
465 out: &Variable<D>,
466 ) -> core::fmt::Result {
467 let item_out = out.item();
468 let index = out.item().vectorization;
469
470 let out = out.fmt_left();
471 writeln!(f, "{out} = {item_out}{{")?;
472 for i in 0..index {
473 let lhsi = lhs.index(i);
474 let rhsi = rhs.index(i);
475
476 Self::format_scalar(f, lhsi, rhsi, item_out)?;
477 f.write_str(", ")?;
478 }
479
480 f.write_str("};\n")
481 }
482}
483
484pub struct Max;
485
486impl<D: Dialect> Binary<D> for Max {
487 fn format_scalar<Lhs: Display, Rhs: Display>(
488 f: &mut std::fmt::Formatter<'_>,
489 lhs: Lhs,
490 rhs: Rhs,
491 item: Item<D>,
492 ) -> std::fmt::Result {
493 D::compile_instruction_max_function_name(f, item)?;
494 write!(f, "({lhs}, {rhs})")
495 }
496}
497
498pub struct Min;
499
500impl<D: Dialect> Binary<D> for Min {
501 fn format_scalar<Lhs: Display, Rhs: Display>(
502 f: &mut std::fmt::Formatter<'_>,
503 lhs: Lhs,
504 rhs: Rhs,
505 item: Item<D>,
506 ) -> std::fmt::Result {
507 D::compile_instruction_min_function_name(f, item)?;
508 write!(f, "({lhs}, {rhs})")
509 }
510}
511
/// Namespace type for writing indexed stores of the form `<list>[<index>] = <value>;`.
pub struct IndexAssign;
/// Namespace type for writing indexed loads of the form `<out> = <list>[<index>];`.
pub struct Index;
514
515impl IndexAssign {
516 pub fn format<D: Dialect>(
517 f: &mut Formatter<'_>,
518 index: &Variable<D>,
519 value: &Variable<D>,
520 out_list: &Variable<D>,
521 vector_size: u32,
522 ) -> std::fmt::Result {
523 if matches!(
524 out_list,
525 Variable::LocalMut { .. } | Variable::LocalConst { .. }
526 ) {
527 return IndexAssignVector::format(f, index, value, out_list);
528 };
529
530 if vector_size > 0 {
531 let mut item = out_list.item();
532 item.vectorization = vector_size as usize;
533 let addr_space = D::address_space_for_variable(out_list);
534 let qualifier = out_list.const_qualifier();
535 let tmp = Variable::tmp_declared(item);
536
537 writeln!(
538 f,
539 "{qualifier} {addr_space}{item} *{tmp} = reinterpret_cast<{qualifier} {item}*>({out_list});"
540 )?;
541
542 return IndexAssign::format(f, index, value, &tmp, 0);
543 }
544
545 let out_item = out_list.item();
546
547 if index.item().vectorization == 1 {
548 write!(f, "{}[{index}] = ", out_list.fmt_left())?;
549 Self::format_scalar(f, *index, *value, out_item)?;
550 f.write_str(";\n")
551 } else {
552 Self::unroll_vec(f, index, value, out_list)
553 }
554 }
555 fn format_scalar<D: Dialect, Lhs, Rhs>(
556 f: &mut Formatter<'_>,
557 _lhs: Lhs,
558 rhs: Rhs,
559 item_out: Item<D>,
560 ) -> std::fmt::Result
561 where
562 Lhs: Component<D>,
563 Rhs: Component<D>,
564 {
565 let item_rhs = rhs.item();
566
567 let format_vec = |f: &mut Formatter<'_>, cast: bool| {
568 writeln!(f, "{item_out}{{")?;
569 for i in 0..item_out.vectorization {
570 if cast {
571 writeln!(f, "{}({}),", item_out.elem, rhs.index(i))?;
572 } else {
573 writeln!(f, "{},", rhs.index(i))?;
574 }
575 }
576 f.write_str("}")?;
577
578 Ok(())
579 };
580
581 if item_out.vectorization != item_rhs.vectorization {
582 format_vec(f, item_out != item_rhs)
583 } else if item_out.elem != item_rhs.elem {
584 if item_out.vectorization > 1 {
585 format_vec(f, true)?;
586 } else {
587 write!(f, "{}({rhs})", item_out.elem)?;
588 }
589 Ok(())
590 } else if rhs.is_const() && item_rhs.vectorization > 1 {
591 write!(f, "reinterpret_cast<")?;
593 D::compile_local_memory_qualifier(f)?;
594 write!(f, " {item_out} const&>({rhs})")
595 } else {
596 write!(f, "{rhs}")
597 }
598 }
599
600 fn unroll_vec<D: Dialect>(
601 f: &mut Formatter<'_>,
602 lhs: &Variable<D>,
603 rhs: &Variable<D>,
604 out: &Variable<D>,
605 ) -> std::fmt::Result {
606 let item_lhs = lhs.item();
607 let out_item = out.item();
608 let out = out.fmt_left();
609
610 for i in 0..item_lhs.vectorization {
611 let lhsi = lhs.index(i);
612 let rhsi = rhs.index(i);
613 write!(f, "{out}[{lhs}] = ")?;
614 Self::format_scalar(f, lhsi, rhsi, out_item)?;
615 f.write_str(";\n")?;
616 }
617
618 Ok(())
619 }
620}
621
622impl Index {
623 pub(crate) fn format<D: Dialect>(
624 f: &mut Formatter<'_>,
625 list: &Variable<D>,
626 index: &Variable<D>,
627 out: &Variable<D>,
628 vector_size: u32,
629 ) -> std::fmt::Result {
630 if matches!(
631 list,
632 Variable::LocalMut { .. } | Variable::LocalConst { .. } | Variable::Constant(..)
633 ) {
634 return IndexVector::format(f, list, index, out);
635 }
636
637 if vector_size > 0 {
638 let mut item = list.item();
639 item.vectorization = vector_size as usize;
640 let addr_space = D::address_space_for_variable(list);
641 let qualifier = list.const_qualifier();
642 let tmp = Variable::tmp_declared(item);
643
644 writeln!(
645 f,
646 "{qualifier} {addr_space}{item} *{tmp} = reinterpret_cast<{qualifier} {item}*>({list});"
647 )?;
648
649 return Index::format(f, &tmp, index, out, 0);
650 }
651
652 let item_out = out.item();
653 if let Elem::Atomic(_) = item_out.elem {
654 let addr_space = D::address_space_for_variable(list);
655 writeln!(f, "{addr_space}{item_out}* {out} = &{list}[{index}];")
656 } else if matches!(item_out.elem, Elem::Barrier(_)) {
657 let addr_space = D::address_space_for_variable(list);
658 writeln!(f, "{addr_space}{}& {out} = {list}[{index}];", item_out.elem)
659 } else {
660 let out = out.fmt_left();
661 write!(f, "{out} = ")?;
662 Self::format_scalar(f, *list, *index, item_out)?;
663 f.write_str(";\n")
664 }
665 }
666
667 fn format_scalar<D: Dialect, Lhs, Rhs>(
668 f: &mut Formatter<'_>,
669 lhs: Lhs,
670 rhs: Rhs,
671 item_out: Item<D>,
672 ) -> std::fmt::Result
673 where
674 Lhs: Component<D>,
675 Rhs: Component<D>,
676 {
677 let item_lhs = lhs.item();
678
679 let format_vec = |f: &mut Formatter<'_>| {
680 writeln!(f, "{item_out}{{")?;
681 for i in 0..item_out.vectorization {
682 write!(f, "{}({lhs}[{rhs}].i_{i}),", item_out.elem)?;
683 }
684 f.write_str("}")?;
685
686 Ok(())
687 };
688
689 if item_out.elem != item_lhs.elem {
690 if item_out.vectorization > 1 {
691 format_vec(f)
692 } else {
693 write!(f, "{}({lhs}[{rhs}])", item_out.elem)
694 }
695 } else {
696 write!(f, "{lhs}[{rhs}]")
697 }
698 }
699}
700
/// Namespace type for the `Index` case that `Index::format` delegates to
/// `IndexVector::format`. Holds no data; `D` only selects the dialect.
struct IndexVector<D: Dialect> {
    _dialect: PhantomData<D>,
}

/// Namespace type for the `IndexAssign` case that `IndexAssign::format` delegates to
/// `IndexAssignVector::format`. Holds no data; `D` only selects the dialect.
struct IndexAssignVector<D: Dialect> {
    _dialect: PhantomData<D>,
}
727
728impl<D: Dialect> IndexVector<D> {
729 fn format(
730 f: &mut Formatter<'_>,
731 lhs: &Variable<D>,
732 rhs: &Variable<D>,
733 out: &Variable<D>,
734 ) -> std::fmt::Result {
735 match rhs {
736 Variable::Constant(value, _elem) => {
737 let index = value.as_usize();
738 let out = out.index(index);
739 let lhs = lhs.index(index);
740 let out = out.fmt_left();
741 writeln!(f, "{out} = {lhs};")
742 }
743 _ => {
744 let elem = out.elem();
745 let qualifier = out.const_qualifier();
746 let addr_space = D::address_space_for_variable(out);
747 let out = out.fmt_left();
748 writeln!(
749 f,
750 "{out} = reinterpret_cast<{addr_space}{elem}{qualifier}*>(&{lhs})[{rhs}];"
751 )
752 }
753 }
754 }
755}
756
757impl<D: Dialect> IndexAssignVector<D> {
758 fn format(
759 f: &mut Formatter<'_>,
760 lhs: &Variable<D>,
761 rhs: &Variable<D>,
762 out: &Variable<D>,
763 ) -> std::fmt::Result {
764 let index = match lhs {
765 Variable::Constant(value, _) => value.as_usize(),
766 _ => {
767 let elem = out.elem();
768 let addr_space = D::address_space_for_variable(out);
769 return writeln!(f, "*(({addr_space}{elem}*)&{out} + {lhs}) = {rhs};");
770 }
771 };
772
773 let out = out.index(index);
774 let rhs = rhs.index(index);
775
776 writeln!(f, "{out} = {rhs};")
777 }
778}