1use crate::shared::FmtLeft;
2
3use super::{Component, Dialect, Elem, Item, Variable};
4use std::{
5 fmt::{Display, Formatter},
6 marker::PhantomData,
7};
8
9pub trait Binary<D: Dialect> {
10 fn format(
11 f: &mut Formatter<'_>,
12 lhs: &Variable<D>,
13 rhs: &Variable<D>,
14 out: &Variable<D>,
15 ) -> std::fmt::Result {
16 let out_item = out.item();
17 if out.item().vectorization == 1 {
18 let out = out.fmt_left();
19 write!(f, "{out} = ")?;
20 Self::format_scalar(f, *lhs, *rhs, out_item)?;
21 f.write_str(";\n")
22 } else {
23 Self::unroll_vec(f, lhs, rhs, out)
24 }
25 }
26
27 fn format_scalar<Lhs, Rhs>(
28 f: &mut Formatter<'_>,
29 lhs: Lhs,
30 rhs: Rhs,
31 item: Item<D>,
32 ) -> std::fmt::Result
33 where
34 Lhs: Component<D>,
35 Rhs: Component<D>;
36
37 fn unroll_vec(
38 f: &mut Formatter<'_>,
39 lhs: &Variable<D>,
40 rhs: &Variable<D>,
41 out: &Variable<D>,
42 ) -> core::fmt::Result {
43 let optimized = Variable::optimized_args([*lhs, *rhs, *out]);
44 let [lhs, rhs, out_optimized] = optimized.args;
45
46 let item_out_original = out.item();
47 let item_out_optimized = out_optimized.item();
48
49 let index = match optimized.optimization_factor {
50 Some(factor) => item_out_original.vectorization / factor,
51 None => item_out_optimized.vectorization,
52 };
53
54 let mut write_op =
55 |lhs: &Variable<D>, rhs: &Variable<D>, out: &Variable<D>, item_out: Item<D>| {
56 let out = out.fmt_left();
57 writeln!(f, "{out} = {item_out}{{")?;
58 for i in 0..index {
59 let lhsi = lhs.index(i);
60 let rhsi = rhs.index(i);
61
62 Self::format_scalar(f, lhsi, rhsi, item_out)?;
63 f.write_str(", ")?;
64 }
65
66 f.write_str("};\n")
67 };
68
69 if item_out_original == item_out_optimized {
70 write_op(&lhs, &rhs, out, item_out_optimized)
71 } else {
72 let out_tmp = Variable::tmp(item_out_optimized);
73 write_op(&lhs, &rhs, &out_tmp, item_out_optimized)?;
74 let addr_space = D::address_space_for_variable(out);
75 let out = out.fmt_left();
76
77 writeln!(
78 f,
79 "{out} = reinterpret_cast<{addr_space}{item_out_original}&>({out_tmp});\n"
80 )?;
81
82 Ok(())
83 }
84 }
85}
86
/// Declares a unit struct `$name` implementing [`Binary`] as the infix operator `$op`.
///
/// The generated `format_scalar` writes `lhs <op> rhs`. When the output element is an 8- or
/// 16-bit integer the expression is additionally wrapped as `<elem>(lhs <op> rhs)`
/// (presumably to narrow the result back after C++ integer promotion; confirm).
macro_rules! operator {
    ($name:ident, $op:expr) => {
        pub struct $name;

        impl<D: Dialect> Binary<D> for $name {
            fn format_scalar<Lhs, Rhs>(
                f: &mut std::fmt::Formatter<'_>,
                lhs: Lhs,
                rhs: Rhs,
                out_item: Item<D>,
            ) -> std::fmt::Result
            where
                Lhs: Display,
                Rhs: Display,
            {
                let out_elem = out_item.elem();
                let narrow_int = matches!(
                    out_elem,
                    Elem::<D>::I16 | Elem::<D>::U16 | Elem::<D>::I8 | Elem::<D>::U8
                );

                if narrow_int {
                    write!(f, "{out_elem}({lhs} {} {rhs})", $op)
                } else {
                    write!(f, "{lhs} {} {rhs}", $op)
                }
            }
        }
    };
}
112
113operator!(Add, "+");
114operator!(Sub, "-");
115operator!(Div, "/");
116operator!(Mul, "*");
117operator!(Modulo, "%");
118operator!(Equal, "==");
119operator!(NotEqual, "!=");
120operator!(Lower, "<");
121operator!(LowerEqual, "<=");
122operator!(Greater, ">");
123operator!(GreaterEqual, ">=");
124operator!(ShiftLeft, "<<");
125operator!(ShiftRight, ">>");
126operator!(BitwiseOr, "|");
127operator!(BitwiseAnd, "&");
128operator!(BitwiseXor, "^");
129operator!(Or, "||");
130operator!(And, "&&");
131
132pub struct FastDiv;
133
134impl<D: Dialect> Binary<D> for FastDiv {
135 fn format_scalar<Lhs: Display, Rhs: Display>(
136 f: &mut std::fmt::Formatter<'_>,
137 lhs: Lhs,
138 rhs: Rhs,
139 _out_item: Item<D>,
140 ) -> std::fmt::Result {
141 write!(f, "__fdividef({lhs}, {rhs})")
143 }
144}
145
146pub struct HiMul;
147
148impl<D: Dialect> Binary<D> for HiMul {
149 fn format_scalar<Lhs: Display, Rhs: Display>(
151 f: &mut std::fmt::Formatter<'_>,
152 lhs: Lhs,
153 rhs: Rhs,
154 out: Item<D>,
155 ) -> std::fmt::Result {
156 let out_elem = out.elem;
157 match out_elem {
158 Elem::I32 => write!(f, "__mulhi({lhs}, {rhs})"),
159 Elem::U32 => write!(f, "__umulhi({lhs}, {rhs})"),
160 Elem::I64 => write!(f, "__mul64hi({lhs}, {rhs})"),
161 Elem::U64 => write!(f, "__umul64hi({lhs}, {rhs})"),
162 _ => writeln!(f, "#error HiMul only supports 32 and 64 bit ints"),
163 }
164 }
165
166 fn unroll_vec(
168 f: &mut Formatter<'_>,
169 lhs: &Variable<D>,
170 rhs: &Variable<D>,
171 out: &Variable<D>,
172 ) -> core::fmt::Result {
173 let item_out = out.item();
174 let index = out.item().vectorization;
175
176 let out = out.fmt_left();
177 writeln!(f, "{out} = {item_out}{{")?;
178 for i in 0..index {
179 let lhsi = lhs.index(i);
180 let rhsi = rhs.index(i);
181
182 Self::format_scalar(f, lhsi, rhsi, item_out)?;
183 f.write_str(", ")?;
184 }
185
186 f.write_str("};\n")
187 }
188}
189
190pub struct SaturatingAdd;
191
192impl<D: Dialect> Binary<D> for SaturatingAdd {
193 fn format_scalar<Lhs: Display, Rhs: Display>(
194 f: &mut std::fmt::Formatter<'_>,
195 lhs: Lhs,
196 rhs: Rhs,
197 out: Item<D>,
198 ) -> std::fmt::Result {
199 D::compile_saturating_add(f, lhs, rhs, out)
200 }
201}
202
203pub struct SaturatingSub;
204
205impl<D: Dialect> Binary<D> for SaturatingSub {
206 fn format_scalar<Lhs: Display, Rhs: Display>(
207 f: &mut std::fmt::Formatter<'_>,
208 lhs: Lhs,
209 rhs: Rhs,
210 out: Item<D>,
211 ) -> std::fmt::Result {
212 D::compile_saturating_sub(f, lhs, rhs, out)
213 }
214}
215
216pub struct Powf;
217
218impl<D: Dialect> Binary<D> for Powf {
219 fn format_scalar<Lhs: Display, Rhs: Display>(
221 f: &mut std::fmt::Formatter<'_>,
222 lhs: Lhs,
223 rhs: Rhs,
224 item: Item<D>,
225 ) -> std::fmt::Result {
226 let elem = item.elem;
227 let lhs = lhs.to_string();
228 let rhs = rhs.to_string();
229 match elem {
230 Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
231 let lhs = format!("float({lhs})");
232 let rhs = format!("float({rhs})");
233 write!(f, "{elem}(")?;
234 D::compile_instruction_powf(f, &lhs, &rhs, Elem::F32)?;
235 write!(f, ")")
236 }
237 _ => D::compile_instruction_powf(f, &lhs, &rhs, elem),
238 }
239 }
240
241 fn unroll_vec(
243 f: &mut Formatter<'_>,
244 lhs: &Variable<D>,
245 rhs: &Variable<D>,
246 out: &Variable<D>,
247 ) -> core::fmt::Result {
248 let item_out = out.item();
249 let index = out.item().vectorization;
250
251 let out = out.fmt_left();
252 writeln!(f, "{out} = {item_out}{{")?;
253 for i in 0..index {
254 let lhsi = lhs.index(i);
255 let rhsi = rhs.index(i);
256
257 Self::format_scalar(f, lhsi, rhsi, item_out)?;
258 f.write_str(", ")?;
259 }
260
261 f.write_str("};\n")
262 }
263}
264
265pub struct FastPowf;
266
267impl<D: Dialect> Binary<D> for FastPowf {
268 fn format_scalar<Lhs: Display, Rhs: Display>(
270 f: &mut std::fmt::Formatter<'_>,
271 lhs: Lhs,
272 rhs: Rhs,
273 _item: Item<D>,
274 ) -> std::fmt::Result {
275 write!(f, "__powf({lhs}, {rhs})")
276 }
277}
278
279pub struct Powi;
280
281impl<D: Dialect> Binary<D> for Powi {
282 fn format_scalar<Lhs: Display, Rhs: Display>(
284 f: &mut std::fmt::Formatter<'_>,
285 lhs: Lhs,
286 rhs: Rhs,
287 item: Item<D>,
288 ) -> std::fmt::Result {
289 let elem = item.elem;
290 let lhs = lhs.to_string();
291 let rhs = rhs.to_string();
292 match elem {
293 Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
294 let lhs = format!("float({lhs})");
295
296 write!(f, "{elem}(")?;
297 D::compile_instruction_powf(f, &lhs, &rhs, Elem::F32)?;
298 write!(f, ")")
299 }
300 _ => D::compile_instruction_powf(f, &lhs, &rhs, elem),
301 }
302 }
303
304 fn unroll_vec(
306 f: &mut Formatter<'_>,
307 lhs: &Variable<D>,
308 rhs: &Variable<D>,
309 out: &Variable<D>,
310 ) -> core::fmt::Result {
311 let item_out = out.item();
312 let index = out.item().vectorization;
313
314 let out = out.fmt_left();
315 writeln!(f, "{out} = {item_out}{{")?;
316 for i in 0..index {
317 let lhsi = lhs.index(i);
318 let rhsi = rhs.index(i);
319
320 Self::format_scalar(f, lhsi, rhsi, item_out)?;
321 f.write_str(", ")?;
322 }
323
324 f.write_str("};\n")
325 }
326}
327
328pub struct ArcTan2;
329
330impl<D: Dialect> Binary<D> for ArcTan2 {
331 fn format_scalar<Lhs: Display, Rhs: Display>(
333 f: &mut std::fmt::Formatter<'_>,
334 lhs: Lhs,
335 rhs: Rhs,
336 item: Item<D>,
337 ) -> std::fmt::Result {
338 let elem = item.elem;
339 match elem {
340 Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
341 write!(f, "{elem}(atan2(float({lhs}), float({rhs})))")
342 }
343 _ => {
344 write!(f, "atan2({lhs}, {rhs})")
345 }
346 }
347 }
348
349 fn unroll_vec(
351 f: &mut Formatter<'_>,
352 lhs: &Variable<D>,
353 rhs: &Variable<D>,
354 out: &Variable<D>,
355 ) -> core::fmt::Result {
356 let item_out = out.item();
357 let index = out.item().vectorization;
358
359 let out = out.fmt_left();
360 writeln!(f, "{out} = {item_out}{{")?;
361 for i in 0..index {
362 let lhsi = lhs.index(i);
363 let rhsi = rhs.index(i);
364
365 Self::format_scalar(f, lhsi, rhsi, item_out)?;
366 f.write_str(", ")?;
367 }
368
369 f.write_str("};\n")
370 }
371}
372
373pub struct Max;
374
375impl<D: Dialect> Binary<D> for Max {
376 fn format_scalar<Lhs: Display, Rhs: Display>(
377 f: &mut std::fmt::Formatter<'_>,
378 lhs: Lhs,
379 rhs: Rhs,
380 item: Item<D>,
381 ) -> std::fmt::Result {
382 D::compile_instruction_max_function_name(f, item)?;
383 write!(f, "({lhs}, {rhs})")
384 }
385}
386
387pub struct Min;
388
389impl<D: Dialect> Binary<D> for Min {
390 fn format_scalar<Lhs: Display, Rhs: Display>(
391 f: &mut std::fmt::Formatter<'_>,
392 lhs: Lhs,
393 rhs: Rhs,
394 item: Item<D>,
395 ) -> std::fmt::Result {
396 D::compile_instruction_min_function_name(f, item)?;
397 write!(f, "({lhs}, {rhs})")
398 }
399}
400
/// Formatter for stores into an indexed position (`out[index] = value`).
pub struct IndexAssign;

/// Formatter for loads from an indexed position (`out = list[index]`).
pub struct Index;
403
404impl IndexAssign {
405 pub fn format<D: Dialect>(
406 f: &mut Formatter<'_>,
407 index: &Variable<D>,
408 value: &Variable<D>,
409 out_list: &Variable<D>,
410 line_size: u32,
411 ) -> std::fmt::Result {
412 if matches!(
413 out_list,
414 Variable::LocalMut { .. } | Variable::LocalConst { .. }
415 ) {
416 return IndexAssignVector::format(f, index, value, out_list);
417 };
418
419 if line_size > 0 {
420 let mut item = out_list.item();
421 item.vectorization = line_size as usize;
422 let addr_space = D::address_space_for_variable(out_list);
423 let qualifier = out_list.const_qualifier();
424 let tmp = Variable::tmp_declared(item);
425
426 writeln!(
427 f,
428 "{qualifier} {addr_space}{item} *{tmp} = reinterpret_cast<{qualifier} {item}*>({out_list});"
429 )?;
430
431 return IndexAssign::format(f, index, value, &tmp, 0);
432 }
433
434 let out_item = out_list.item();
435
436 if index.item().vectorization == 1 {
437 write!(f, "{}[{index}] = ", out_list.fmt_left())?;
438 Self::format_scalar(f, *index, *value, out_item)?;
439 f.write_str(";\n")
440 } else {
441 Self::unroll_vec(f, index, value, out_list)
442 }
443 }
444 fn format_scalar<D: Dialect, Lhs, Rhs>(
445 f: &mut Formatter<'_>,
446 _lhs: Lhs,
447 rhs: Rhs,
448 item_out: Item<D>,
449 ) -> std::fmt::Result
450 where
451 Lhs: Component<D>,
452 Rhs: Component<D>,
453 {
454 let item_rhs = rhs.item();
455
456 let format_vec = |f: &mut Formatter<'_>, cast: bool| {
457 writeln!(f, "{item_out}{{")?;
458 for i in 0..item_out.vectorization {
459 if cast {
460 writeln!(f, "{}({}),", item_out.elem, rhs.index(i))?;
461 } else {
462 writeln!(f, "{},", rhs.index(i))?;
463 }
464 }
465 f.write_str("}")?;
466
467 Ok(())
468 };
469
470 if item_out.vectorization != item_rhs.vectorization {
471 format_vec(f, item_out != item_rhs)
472 } else if item_out.elem != item_rhs.elem {
473 if item_out.vectorization > 1 {
474 format_vec(f, true)?;
475 } else {
476 write!(f, "{}({rhs})", item_out.elem)?;
477 }
478 Ok(())
479 } else if rhs.is_const() && item_rhs.vectorization > 1 {
480 write!(f, "reinterpret_cast<")?;
482 D::compile_local_memory_qualifier(f)?;
483 write!(f, " {item_out} const&>({rhs})")
484 } else {
485 write!(f, "{rhs}")
486 }
487 }
488
489 fn unroll_vec<D: Dialect>(
490 f: &mut Formatter<'_>,
491 lhs: &Variable<D>,
492 rhs: &Variable<D>,
493 out: &Variable<D>,
494 ) -> std::fmt::Result {
495 let item_lhs = lhs.item();
496 let out_item = out.item();
497 let out = out.fmt_left();
498
499 for i in 0..item_lhs.vectorization {
500 let lhsi = lhs.index(i);
501 let rhsi = rhs.index(i);
502 write!(f, "{out}[{lhs}] = ")?;
503 Self::format_scalar(f, lhsi, rhsi, out_item)?;
504 f.write_str(";\n")?;
505 }
506
507 Ok(())
508 }
509}
510
511impl Index {
512 pub(crate) fn format<D: Dialect>(
513 f: &mut Formatter<'_>,
514 list: &Variable<D>,
515 index: &Variable<D>,
516 out: &Variable<D>,
517 line_size: u32,
518 ) -> std::fmt::Result {
519 if matches!(
520 list,
521 Variable::LocalMut { .. } | Variable::LocalConst { .. } | Variable::ConstantScalar(..)
522 ) {
523 return IndexVector::format(f, list, index, out);
524 }
525
526 if line_size > 0 {
527 let mut item = list.item();
528 item.vectorization = line_size as usize;
529 let addr_space = D::address_space_for_variable(list);
530 let qualifier = list.const_qualifier();
531 let tmp = Variable::tmp_declared(item);
532
533 writeln!(
534 f,
535 "{qualifier} {addr_space}{item} *{tmp} = reinterpret_cast<{qualifier} {item}*>({list});"
536 )?;
537
538 return Index::format(f, &tmp, index, out, 0);
539 }
540
541 let item_out = out.item();
542 if let Elem::Atomic(inner) = item_out.elem {
543 let addr_space = D::address_space_for_variable(list);
544 writeln!(f, "{addr_space}{inner}* {out} = &{list}[{index}];")
545 } else {
546 let out = out.fmt_left();
547 write!(f, "{out} = ")?;
548 Self::format_scalar(f, *list, *index, item_out)?;
549 f.write_str(";\n")
550 }
551 }
552
553 fn format_scalar<D: Dialect, Lhs, Rhs>(
554 f: &mut Formatter<'_>,
555 lhs: Lhs,
556 rhs: Rhs,
557 item_out: Item<D>,
558 ) -> std::fmt::Result
559 where
560 Lhs: Component<D>,
561 Rhs: Component<D>,
562 {
563 let item_lhs = lhs.item();
564
565 let format_vec = |f: &mut Formatter<'_>| {
566 writeln!(f, "{item_out}{{")?;
567 for i in 0..item_out.vectorization {
568 write!(f, "{}({lhs}[{rhs}].i_{i}),", item_out.elem)?;
569 }
570 f.write_str("}")?;
571
572 Ok(())
573 };
574
575 if item_out.elem != item_lhs.elem {
576 if item_out.vectorization > 1 {
577 format_vec(f)
578 } else {
579 write!(f, "{}({lhs}[{rhs}])", item_out.elem)
580 }
581 } else {
582 write!(f, "{lhs}[{rhs}]")
583 }
584 }
585}
586
587struct IndexVector<D: Dialect> {
597 _dialect: PhantomData<D>,
598}
599
600struct IndexAssignVector<D: Dialect> {
611 _dialect: PhantomData<D>,
612}
613
614impl<D: Dialect> IndexVector<D> {
615 fn format(
616 f: &mut Formatter<'_>,
617 lhs: &Variable<D>,
618 rhs: &Variable<D>,
619 out: &Variable<D>,
620 ) -> std::fmt::Result {
621 match rhs {
622 Variable::ConstantScalar(value, _elem) => {
623 let index = value.as_usize();
624 let out = out.index(index);
625 let lhs = lhs.index(index);
626 let out = out.fmt_left();
627 writeln!(f, "{out} = {lhs};")
628 }
629 _ => {
630 let elem = out.elem();
631 let qualifier = out.const_qualifier();
632 let addr_space = D::address_space_for_variable(out);
633 let out = out.fmt_left();
634 writeln!(
635 f,
636 "{out} = reinterpret_cast<{addr_space}{elem}{qualifier}*>(&{lhs})[{rhs}];"
637 )
638 }
639 }
640 }
641}
642
643impl<D: Dialect> IndexAssignVector<D> {
644 fn format(
645 f: &mut Formatter<'_>,
646 lhs: &Variable<D>,
647 rhs: &Variable<D>,
648 out: &Variable<D>,
649 ) -> std::fmt::Result {
650 let index = match lhs {
651 Variable::ConstantScalar(value, _) => value.as_usize(),
652 _ => {
653 let elem = out.elem();
654 let addr_space = D::address_space_for_variable(out);
655 return writeln!(f, "*(({addr_space}{elem}*)&{out} + {lhs}) = {rhs};");
656 }
657 };
658
659 let out = out.index(index);
660 let rhs = rhs.index(index);
661
662 writeln!(f, "{out} = {rhs};")
663 }
664}