1use crate::shared::FmtLeft;
2
3use super::{Component, Dialect, Elem, Item, Variable};
4use std::{
5 fmt::{Display, Formatter},
6 marker::PhantomData,
7};
8
9pub trait Binary<D: Dialect> {
10 fn format(
11 f: &mut Formatter<'_>,
12 lhs: &Variable<D>,
13 rhs: &Variable<D>,
14 out: &Variable<D>,
15 ) -> std::fmt::Result {
16 let out_item = out.item();
17 if out.item().vectorization == 1 {
18 let out = out.fmt_left();
19 write!(f, "{out} = ")?;
20 Self::format_scalar(f, *lhs, *rhs, out_item)?;
21 f.write_str(";\n")
22 } else {
23 Self::unroll_vec(f, lhs, rhs, out)
24 }
25 }
26
27 fn format_scalar<Lhs, Rhs>(
28 f: &mut Formatter<'_>,
29 lhs: Lhs,
30 rhs: Rhs,
31 item: Item<D>,
32 ) -> std::fmt::Result
33 where
34 Lhs: Component<D>,
35 Rhs: Component<D>;
36
37 fn unroll_vec(
38 f: &mut Formatter<'_>,
39 lhs: &Variable<D>,
40 rhs: &Variable<D>,
41 out: &Variable<D>,
42 ) -> core::fmt::Result {
43 let optimized = Variable::optimized_args([*lhs, *rhs, *out]);
44 let [lhs, rhs, out_optimized] = optimized.args;
45
46 let item_out_original = out.item();
47 let item_out_optimized = out_optimized.item();
48
49 let index = match optimized.optimization_factor {
50 Some(factor) => item_out_original.vectorization / factor,
51 None => item_out_optimized.vectorization,
52 };
53
54 let mut write_op =
55 |lhs: &Variable<D>, rhs: &Variable<D>, out: &Variable<D>, item_out: Item<D>| {
56 let out = out.fmt_left();
57 writeln!(f, "{out} = {item_out}{{")?;
58 for i in 0..index {
59 let lhsi = lhs.index(i);
60 let rhsi = rhs.index(i);
61
62 Self::format_scalar(f, lhsi, rhsi, item_out)?;
63 f.write_str(", ")?;
64 }
65
66 f.write_str("};\n")
67 };
68
69 if item_out_original == item_out_optimized {
70 write_op(&lhs, &rhs, out, item_out_optimized)
71 } else {
72 let out_tmp = Variable::tmp(item_out_optimized);
73 write_op(&lhs, &rhs, &out_tmp, item_out_optimized)?;
74 let addr_space = D::address_space_for_variable(out);
75 let out = out.fmt_left();
76
77 writeln!(
78 f,
79 "{out} = reinterpret_cast<{addr_space}{item_out_original}&>({out_tmp});\n"
80 )?;
81
82 Ok(())
83 }
84 }
85}
86
// Defines a public unit struct `$name` and implements `Binary` for it so that
// its scalar expression is `lhs $op rhs`.
macro_rules! operator {
    ($name:ident, $op:expr) => {
        pub struct $name;

        impl<D: Dialect> Binary<D> for $name {
            fn format_scalar<Lhs, Rhs>(
                f: &mut std::fmt::Formatter<'_>,
                lhs: Lhs,
                rhs: Rhs,
                out_item: Item<D>,
            ) -> std::fmt::Result
            where
                Lhs: Display,
                Rhs: Display,
            {
                let elem = out_item.elem();
                // 8- and 16-bit integer outputs get the whole expression
                // wrapped in a cast to the output element type.
                let wrap_in_cast = matches!(
                    elem,
                    Elem::<D>::I16 | Elem::<D>::U16 | Elem::<D>::I8 | Elem::<D>::U8
                );

                if wrap_in_cast {
                    write!(f, "{elem}({lhs} {} {rhs})", $op)
                } else {
                    write!(f, "{lhs} {} {rhs}", $op)
                }
            }
        }
    };
}
112
// Each invocation below defines a unit struct implementing `Binary` whose
// scalar expression is `lhs <op> rhs` (see `operator!` for the cast applied
// to 8- and 16-bit integer outputs).

// Arithmetic operators.
operator!(Add, "+");
operator!(Sub, "-");
operator!(Div, "/");
operator!(Mul, "*");
operator!(Modulo, "%");
// Comparison operators.
operator!(Equal, "==");
operator!(NotEqual, "!=");
operator!(Lower, "<");
operator!(LowerEqual, "<=");
operator!(Greater, ">");
operator!(GreaterEqual, ">=");
// Shift and bitwise operators.
operator!(ShiftLeft, "<<");
operator!(ShiftRight, ">>");
operator!(BitwiseOr, "|");
operator!(BitwiseAnd, "&");
operator!(BitwiseXor, "^");
// Logical operators.
operator!(Or, "||");
operator!(And, "&&");
131
132pub struct FastDiv;
133
134impl<D: Dialect> Binary<D> for FastDiv {
135 fn format_scalar<Lhs: Display, Rhs: Display>(
136 f: &mut std::fmt::Formatter<'_>,
137 lhs: Lhs,
138 rhs: Rhs,
139 _out_item: Item<D>,
140 ) -> std::fmt::Result {
141 write!(f, "__fdividef({lhs}, {rhs})")
143 }
144}
145
146pub struct HiMul;
147
148impl<D: Dialect> Binary<D> for HiMul {
149 fn format_scalar<Lhs: Display, Rhs: Display>(
151 f: &mut std::fmt::Formatter<'_>,
152 lhs: Lhs,
153 rhs: Rhs,
154 out: Item<D>,
155 ) -> std::fmt::Result {
156 let out_elem = out.elem;
157 match out_elem {
158 Elem::I32 => write!(f, "__mulhi({lhs}, {rhs})"),
159 Elem::U32 => write!(f, "__umulhi({lhs}, {rhs})"),
160 Elem::I64 => write!(f, "__mul64hi({lhs}, {rhs})"),
161 Elem::U64 => write!(f, "__umul64hi({lhs}, {rhs})"),
162 _ => unimplemented!("HiMul only supports 32 and 64 bit ints"),
163 }
164 }
165
166 fn unroll_vec(
168 f: &mut Formatter<'_>,
169 lhs: &Variable<D>,
170 rhs: &Variable<D>,
171 out: &Variable<D>,
172 ) -> core::fmt::Result {
173 let item_out = out.item();
174 let index = out.item().vectorization;
175
176 let out = out.fmt_left();
177 writeln!(f, "{out} = {item_out}{{")?;
178 for i in 0..index {
179 let lhsi = lhs.index(i);
180 let rhsi = rhs.index(i);
181
182 Self::format_scalar(f, lhsi, rhsi, item_out)?;
183 f.write_str(", ")?;
184 }
185
186 f.write_str("};\n")
187 }
188}
189
190pub struct SaturatingAdd;
191
192impl<D: Dialect> Binary<D> for SaturatingAdd {
193 fn format_scalar<Lhs: Display, Rhs: Display>(
194 f: &mut std::fmt::Formatter<'_>,
195 lhs: Lhs,
196 rhs: Rhs,
197 out: Item<D>,
198 ) -> std::fmt::Result {
199 D::compile_saturating_add(f, lhs, rhs, out)
200 }
201}
202
203pub struct SaturatingSub;
204
205impl<D: Dialect> Binary<D> for SaturatingSub {
206 fn format_scalar<Lhs: Display, Rhs: Display>(
207 f: &mut std::fmt::Formatter<'_>,
208 lhs: Lhs,
209 rhs: Rhs,
210 out: Item<D>,
211 ) -> std::fmt::Result {
212 D::compile_saturating_sub(f, lhs, rhs, out)
213 }
214}
215
216pub struct Powf;
217
218impl<D: Dialect> Binary<D> for Powf {
219 fn format_scalar<Lhs: Display, Rhs: Display>(
221 f: &mut std::fmt::Formatter<'_>,
222 lhs: Lhs,
223 rhs: Rhs,
224 item: Item<D>,
225 ) -> std::fmt::Result {
226 let elem = item.elem;
227 let lhs = lhs.to_string();
228 let rhs = rhs.to_string();
229 match elem {
230 Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
231 let lhs = format!("float({lhs})");
232 let rhs = format!("float({rhs})");
233 write!(f, "{elem}(")?;
234 D::compile_instruction_powf(f, &lhs, &rhs, Elem::F32)?;
235 write!(f, ")")
236 }
237 _ => D::compile_instruction_powf(f, &lhs, &rhs, elem),
238 }
239 }
240
241 fn unroll_vec(
243 f: &mut Formatter<'_>,
244 lhs: &Variable<D>,
245 rhs: &Variable<D>,
246 out: &Variable<D>,
247 ) -> core::fmt::Result {
248 let item_out = out.item();
249 let index = out.item().vectorization;
250
251 let out = out.fmt_left();
252 writeln!(f, "{out} = {item_out}{{")?;
253 for i in 0..index {
254 let lhsi = lhs.index(i);
255 let rhsi = rhs.index(i);
256
257 Self::format_scalar(f, lhsi, rhsi, item_out)?;
258 f.write_str(", ")?;
259 }
260
261 f.write_str("};\n")
262 }
263}
264
265pub struct FastPowf;
266
267impl<D: Dialect> Binary<D> for FastPowf {
268 fn format_scalar<Lhs: Display, Rhs: Display>(
270 f: &mut std::fmt::Formatter<'_>,
271 lhs: Lhs,
272 rhs: Rhs,
273 _item: Item<D>,
274 ) -> std::fmt::Result {
275 write!(f, "__powf({lhs}, {rhs})")
276 }
277}
278
279pub struct Powi;
280
281impl<D: Dialect> Binary<D> for Powi {
282 fn format_scalar<Lhs: Display, Rhs: Display>(
284 f: &mut std::fmt::Formatter<'_>,
285 lhs: Lhs,
286 rhs: Rhs,
287 item: Item<D>,
288 ) -> std::fmt::Result {
289 let elem = item.elem;
290 let lhs = lhs.to_string();
291 let rhs = rhs.to_string();
292 match elem {
293 Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
294 let lhs = format!("float({lhs})");
295
296 write!(f, "{elem}(")?;
297 D::compile_instruction_powf(f, &lhs, &rhs, Elem::F32)?;
298 write!(f, ")")
299 }
300 _ => D::compile_instruction_powf(f, &lhs, &rhs, elem),
301 }
302 }
303
304 fn unroll_vec(
306 f: &mut Formatter<'_>,
307 lhs: &Variable<D>,
308 rhs: &Variable<D>,
309 out: &Variable<D>,
310 ) -> core::fmt::Result {
311 let item_out = out.item();
312 let index = out.item().vectorization;
313
314 let out = out.fmt_left();
315 writeln!(f, "{out} = {item_out}{{")?;
316 for i in 0..index {
317 let lhsi = lhs.index(i);
318 let rhsi = rhs.index(i);
319
320 Self::format_scalar(f, lhsi, rhsi, item_out)?;
321 f.write_str(", ")?;
322 }
323
324 f.write_str("};\n")
325 }
326}
327
328pub struct Max;
329
330impl<D: Dialect> Binary<D> for Max {
331 fn format_scalar<Lhs: Display, Rhs: Display>(
332 f: &mut std::fmt::Formatter<'_>,
333 lhs: Lhs,
334 rhs: Rhs,
335 item: Item<D>,
336 ) -> std::fmt::Result {
337 D::compile_instruction_max_function_name(f, item)?;
338 write!(f, "({lhs}, {rhs})")
339 }
340}
341
342pub struct Min;
343
344impl<D: Dialect> Binary<D> for Min {
345 fn format_scalar<Lhs: Display, Rhs: Display>(
346 f: &mut std::fmt::Formatter<'_>,
347 lhs: Lhs,
348 rhs: Rhs,
349 item: Item<D>,
350 ) -> std::fmt::Result {
351 D::compile_instruction_min_function_name(f, item)?;
352 write!(f, "({lhs}, {rhs})")
353 }
354}
355
/// Unit type whose `format` function writes an indexed store of the form
/// `out_list[index] = value;`.
pub struct IndexAssign;
/// Unit type whose `format` function writes an indexed load of the form
/// `out = list[index];`.
pub struct Index;
358
359impl IndexAssign {
360 pub fn format<D: Dialect>(
361 f: &mut Formatter<'_>,
362 index: &Variable<D>,
363 value: &Variable<D>,
364 out_list: &Variable<D>,
365 line_size: u32,
366 ) -> std::fmt::Result {
367 if matches!(
368 out_list,
369 Variable::LocalMut { .. } | Variable::LocalConst { .. }
370 ) {
371 return IndexAssignVector::format(f, index, value, out_list);
372 };
373
374 if line_size > 0 {
375 let mut item = out_list.item();
376 item.vectorization = line_size as usize;
377 let addr_space = D::address_space_for_variable(out_list);
378 let qualifier = out_list.const_qualifier();
379 let tmp = Variable::tmp_declared(item);
380
381 writeln!(
382 f,
383 "{qualifier} {addr_space}{item} *{tmp} = reinterpret_cast<{qualifier} {item}*>({out_list});"
384 )?;
385
386 return IndexAssign::format(f, index, value, &tmp, 0);
387 }
388
389 let out_item = out_list.item();
390
391 if index.item().vectorization == 1 {
392 write!(f, "{}[{index}] = ", out_list.fmt_left())?;
393 Self::format_scalar(f, *index, *value, out_item)?;
394 f.write_str(";\n")
395 } else {
396 Self::unroll_vec(f, index, value, out_list)
397 }
398 }
399 fn format_scalar<D: Dialect, Lhs, Rhs>(
400 f: &mut Formatter<'_>,
401 _lhs: Lhs,
402 rhs: Rhs,
403 item_out: Item<D>,
404 ) -> std::fmt::Result
405 where
406 Lhs: Component<D>,
407 Rhs: Component<D>,
408 {
409 let item_rhs = rhs.item();
410
411 let format_vec = |f: &mut Formatter<'_>, cast: bool| {
412 writeln!(f, "{item_out}{{")?;
413 for i in 0..item_out.vectorization {
414 if cast {
415 writeln!(f, "{}({}),", item_out.elem, rhs.index(i))?;
416 } else {
417 writeln!(f, "{},", rhs.index(i))?;
418 }
419 }
420 f.write_str("}")?;
421
422 Ok(())
423 };
424
425 if item_out.vectorization != item_rhs.vectorization {
426 format_vec(f, item_out != item_rhs)
427 } else if item_out.elem != item_rhs.elem {
428 if item_out.vectorization > 1 {
429 format_vec(f, true)?;
430 } else {
431 write!(f, "{}({rhs})", item_out.elem)?;
432 }
433 Ok(())
434 } else if rhs.is_const() && item_rhs.vectorization > 1 {
435 write!(f, "reinterpret_cast<")?;
437 D::compile_local_memory_qualifier(f)?;
438 write!(f, " {item_out} const&>({rhs})")
439 } else {
440 write!(f, "{rhs}")
441 }
442 }
443
444 fn unroll_vec<D: Dialect>(
445 f: &mut Formatter<'_>,
446 lhs: &Variable<D>,
447 rhs: &Variable<D>,
448 out: &Variable<D>,
449 ) -> std::fmt::Result {
450 let item_lhs = lhs.item();
451 let out_item = out.item();
452 let out = out.fmt_left();
453
454 for i in 0..item_lhs.vectorization {
455 let lhsi = lhs.index(i);
456 let rhsi = rhs.index(i);
457 write!(f, "{out}[{lhs}] = ")?;
458 Self::format_scalar(f, lhsi, rhsi, out_item)?;
459 f.write_str(";\n")?;
460 }
461
462 Ok(())
463 }
464}
465
466impl Index {
467 pub(crate) fn format<D: Dialect>(
468 f: &mut Formatter<'_>,
469 list: &Variable<D>,
470 index: &Variable<D>,
471 out: &Variable<D>,
472 line_size: u32,
473 ) -> std::fmt::Result {
474 if matches!(
475 list,
476 Variable::LocalMut { .. } | Variable::LocalConst { .. } | Variable::ConstantScalar(..)
477 ) {
478 return IndexVector::format(f, list, index, out);
479 }
480
481 if line_size > 0 {
482 let mut item = list.item();
483 item.vectorization = line_size as usize;
484 let addr_space = D::address_space_for_variable(list);
485 let qualifier = list.const_qualifier();
486 let tmp = Variable::tmp_declared(item);
487
488 writeln!(
489 f,
490 "{qualifier} {addr_space}{item} *{tmp} = reinterpret_cast<{qualifier} {item}*>({list});"
491 )?;
492
493 return Index::format(f, &tmp, index, out, 0);
494 }
495
496 let item_out = out.item();
497 if let Elem::Atomic(inner) = item_out.elem {
498 let addr_space = D::address_space_for_variable(list);
499 writeln!(f, "{addr_space}{inner}* {out} = &{list}[{index}];")
500 } else {
501 let out = out.fmt_left();
502 write!(f, "{out} = ")?;
503 Self::format_scalar(f, *list, *index, item_out)?;
504 f.write_str(";\n")
505 }
506 }
507
508 fn format_scalar<D: Dialect, Lhs, Rhs>(
509 f: &mut Formatter<'_>,
510 lhs: Lhs,
511 rhs: Rhs,
512 item_out: Item<D>,
513 ) -> std::fmt::Result
514 where
515 Lhs: Component<D>,
516 Rhs: Component<D>,
517 {
518 let item_lhs = lhs.item();
519
520 let format_vec = |f: &mut Formatter<'_>| {
521 writeln!(f, "{item_out}{{")?;
522 for i in 0..item_out.vectorization {
523 write!(f, "{}({lhs}[{rhs}].i_{i}),", item_out.elem)?;
524 }
525 f.write_str("}")?;
526
527 Ok(())
528 };
529
530 if item_out.elem != item_lhs.elem {
531 if item_out.vectorization > 1 {
532 format_vec(f)
533 } else {
534 write!(f, "{}({lhs}[{rhs}])", item_out.elem)
535 }
536 } else {
537 write!(f, "{lhs}[{rhs}]")
538 }
539 }
540}
541
/// Holder for the `format` function that `Index::format` delegates to when
/// the indexed variable is a `LocalMut`, `LocalConst` or `ConstantScalar`.
struct IndexVector<D: Dialect> {
    // Marker tying the type to a dialect; no data is stored.
    _dialect: PhantomData<D>,
}

/// Holder for the `format` function that `IndexAssign::format` delegates to
/// when the assigned variable is a `LocalMut` or `LocalConst`.
struct IndexAssignVector<D: Dialect> {
    // Marker tying the type to a dialect; no data is stored.
    _dialect: PhantomData<D>,
}
568
569impl<D: Dialect> IndexVector<D> {
570 fn format(
571 f: &mut Formatter<'_>,
572 lhs: &Variable<D>,
573 rhs: &Variable<D>,
574 out: &Variable<D>,
575 ) -> std::fmt::Result {
576 match rhs {
577 Variable::ConstantScalar(value, _elem) => {
578 let index = value.as_usize();
579 let out = out.index(index);
580 let lhs = lhs.index(index);
581 let out = out.fmt_left();
582 writeln!(f, "{out} = {lhs};")
583 }
584 _ => {
585 let elem = out.elem();
586 let qualifier = out.const_qualifier();
587 let addr_space = D::address_space_for_variable(out);
588 let out = out.fmt_left();
589 writeln!(
590 f,
591 "{out} = reinterpret_cast<{addr_space}{elem}{qualifier}*>(&{lhs})[{rhs}];"
592 )
593 }
594 }
595 }
596}
597
598impl<D: Dialect> IndexAssignVector<D> {
599 fn format(
600 f: &mut Formatter<'_>,
601 lhs: &Variable<D>,
602 rhs: &Variable<D>,
603 out: &Variable<D>,
604 ) -> std::fmt::Result {
605 let index = match lhs {
606 Variable::ConstantScalar(value, _) => value.as_usize(),
607 _ => {
608 let elem = out.elem();
609 let addr_space = D::address_space_for_variable(out);
610 return writeln!(f, "*(({addr_space}{elem}*)&{out} + {lhs}) = {rhs};");
611 }
612 };
613
614 let out = out.index(index);
615 let rhs = rhs.index(index);
616
617 writeln!(f, "{out} = {rhs};")
618 }
619}