1use crate::shared::FmtLeft;
2
3use super::{Component, Dialect, Elem, Item, Variable};
4use std::{
5 fmt::{Display, Formatter},
6 marker::PhantomData,
7};
8
9pub trait Binary<D: Dialect> {
10 fn format(
11 f: &mut Formatter<'_>,
12 lhs: &Variable<D>,
13 rhs: &Variable<D>,
14 out: &Variable<D>,
15 ) -> std::fmt::Result {
16 let out_item = out.item();
17 if out.item().vectorization == 1 {
18 let out = out.fmt_left();
19 write!(f, "{out} = ")?;
20 Self::format_scalar(f, *lhs, *rhs, out_item)?;
21 f.write_str(";\n")
22 } else {
23 Self::unroll_vec(f, lhs, rhs, out)
24 }
25 }
26
27 fn format_scalar<Lhs, Rhs>(
28 f: &mut Formatter<'_>,
29 lhs: Lhs,
30 rhs: Rhs,
31 item: Item<D>,
32 ) -> std::fmt::Result
33 where
34 Lhs: Component<D>,
35 Rhs: Component<D>;
36
37 fn unroll_vec(
38 f: &mut Formatter<'_>,
39 lhs: &Variable<D>,
40 rhs: &Variable<D>,
41 out: &Variable<D>,
42 ) -> core::fmt::Result {
43 let optimized = Variable::optimized_args([*lhs, *rhs, *out]);
44 let [lhs, rhs, out_optimized] = optimized.args;
45
46 let item_out_original = out.item();
47 let item_out_optimized = out_optimized.item();
48
49 let index = match optimized.optimization_factor {
50 Some(factor) => item_out_original.vectorization / factor,
51 None => item_out_optimized.vectorization,
52 };
53
54 let mut write_op =
55 |lhs: &Variable<D>, rhs: &Variable<D>, out: &Variable<D>, item_out: Item<D>| {
56 let out = out.fmt_left();
57 writeln!(f, "{out} = {item_out}{{")?;
58 for i in 0..index {
59 let lhsi = lhs.index(i);
60 let rhsi = rhs.index(i);
61
62 Self::format_scalar(f, lhsi, rhsi, item_out)?;
63 f.write_str(", ")?;
64 }
65
66 f.write_str("};\n")
67 };
68
69 if item_out_original == item_out_optimized {
70 write_op(&lhs, &rhs, out, item_out_optimized)
71 } else {
72 let out_tmp = Variable::tmp(item_out_optimized);
73 write_op(&lhs, &rhs, &out_tmp, item_out_optimized)?;
74 let addr_space = D::address_space_for_variable(out);
75 let out = out.fmt_left();
76
77 writeln!(
78 f,
79 "{out} = reinterpret_cast<{addr_space}{item_out_original}&>({out_tmp});\n"
80 )?;
81
82 Ok(())
83 }
84 }
85}
86
// Declares a unit struct `$name` implementing `Binary` by emitting the infix
// operator `$op` between the two operands.
macro_rules! operator {
    ($name:ident, $op:expr) => {
        pub struct $name;

        impl<D: Dialect> Binary<D> for $name {
            fn format_scalar<Lhs: Display, Rhs: Display>(
                f: &mut std::fmt::Formatter<'_>,
                lhs: Lhs,
                rhs: Rhs,
                out_item: Item<D>,
            ) -> std::fmt::Result {
                let elem = out_item.elem();
                // C++ promotes operands narrower than `int` before applying an
                // operator, so for 8/16-bit outputs the result is wrapped in a cast
                // back to the output element type.
                let needs_cast = matches!(
                    elem,
                    Elem::<D>::I16 | Elem::<D>::U16 | Elem::<D>::I8 | Elem::<D>::U8
                );
                if needs_cast {
                    write!(f, "{elem}({lhs} {} {rhs})", $op)
                } else {
                    write!(f, "{lhs} {} {rhs}", $op)
                }
            }
        }
    };
}
112
// Arithmetic operators.
operator!(Add, "+");
operator!(Sub, "-");
operator!(Div, "/");
operator!(Mul, "*");
operator!(Modulo, "%");
// Comparison operators.
operator!(Equal, "==");
operator!(NotEqual, "!=");
operator!(Lower, "<");
operator!(LowerEqual, "<=");
operator!(Greater, ">");
operator!(GreaterEqual, ">=");
// Shift and bitwise operators.
operator!(ShiftLeft, "<<");
operator!(ShiftRight, ">>");
operator!(BitwiseOr, "|");
operator!(BitwiseAnd, "&");
operator!(BitwiseXor, "^");
// Logical operators.
operator!(Or, "||");
operator!(And, "&&");
131
132pub struct HiMul;
133
134impl<D: Dialect> Binary<D> for HiMul {
135 fn format_scalar<Lhs: Display, Rhs: Display>(
137 f: &mut std::fmt::Formatter<'_>,
138 lhs: Lhs,
139 rhs: Rhs,
140 out: Item<D>,
141 ) -> std::fmt::Result {
142 let out_elem = out.elem;
143 match out_elem {
144 Elem::I32 => write!(f, "__mulhi({lhs}, {rhs})"),
145 Elem::U32 => write!(f, "__umulhi({lhs}, {rhs})"),
146 Elem::I64 => write!(f, "__mul64hi({lhs}, {rhs})"),
147 Elem::U64 => write!(f, "__umul64hi({lhs}, {rhs})"),
148 _ => unimplemented!("HiMul only supports 32 and 64 bit ints"),
149 }
150 }
151
152 fn unroll_vec(
154 f: &mut Formatter<'_>,
155 lhs: &Variable<D>,
156 rhs: &Variable<D>,
157 out: &Variable<D>,
158 ) -> core::fmt::Result {
159 let item_out = out.item();
160 let index = out.item().vectorization;
161
162 let out = out.fmt_left();
163 writeln!(f, "{out} = {item_out}{{")?;
164 for i in 0..index {
165 let lhsi = lhs.index(i);
166 let rhsi = rhs.index(i);
167
168 Self::format_scalar(f, lhsi, rhsi, item_out)?;
169 f.write_str(", ")?;
170 }
171
172 f.write_str("};\n")
173 }
174}
175
176pub struct SaturatingAdd;
177
178impl<D: Dialect> Binary<D> for SaturatingAdd {
179 fn format_scalar<Lhs: Display, Rhs: Display>(
180 f: &mut std::fmt::Formatter<'_>,
181 lhs: Lhs,
182 rhs: Rhs,
183 out: Item<D>,
184 ) -> std::fmt::Result {
185 D::compile_saturating_add(f, lhs, rhs, out)
186 }
187}
188
189pub struct SaturatingSub;
190
191impl<D: Dialect> Binary<D> for SaturatingSub {
192 fn format_scalar<Lhs: Display, Rhs: Display>(
193 f: &mut std::fmt::Formatter<'_>,
194 lhs: Lhs,
195 rhs: Rhs,
196 out: Item<D>,
197 ) -> std::fmt::Result {
198 D::compile_saturating_sub(f, lhs, rhs, out)
199 }
200}
201
202pub struct Powf;
203
204impl<D: Dialect> Binary<D> for Powf {
205 fn format_scalar<Lhs: Display, Rhs: Display>(
207 f: &mut std::fmt::Formatter<'_>,
208 lhs: Lhs,
209 rhs: Rhs,
210 item: Item<D>,
211 ) -> std::fmt::Result {
212 let elem = item.elem;
213 let lhs = lhs.to_string();
214 let rhs = rhs.to_string();
215 match elem {
216 Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
217 let lhs = format!("float({lhs})");
218 let rhs = format!("float({rhs})");
219 write!(f, "{elem}(")?;
220 D::compile_instruction_powf(f, &lhs, &rhs, Elem::F32)?;
221 write!(f, ")")
222 }
223 _ => D::compile_instruction_powf(f, &lhs, &rhs, elem),
224 }
225 }
226
227 fn unroll_vec(
229 f: &mut Formatter<'_>,
230 lhs: &Variable<D>,
231 rhs: &Variable<D>,
232 out: &Variable<D>,
233 ) -> core::fmt::Result {
234 let item_out = out.item();
235 let index = out.item().vectorization;
236
237 let out = out.fmt_left();
238 writeln!(f, "{out} = {item_out}{{")?;
239 for i in 0..index {
240 let lhsi = lhs.index(i);
241 let rhsi = rhs.index(i);
242
243 Self::format_scalar(f, lhsi, rhsi, item_out)?;
244 f.write_str(", ")?;
245 }
246
247 f.write_str("};\n")
248 }
249}
250
251pub struct Powi;
252
253impl<D: Dialect> Binary<D> for Powi {
254 fn format_scalar<Lhs: Display, Rhs: Display>(
256 f: &mut std::fmt::Formatter<'_>,
257 lhs: Lhs,
258 rhs: Rhs,
259 item: Item<D>,
260 ) -> std::fmt::Result {
261 let elem = item.elem;
262 let lhs = lhs.to_string();
263 let rhs = rhs.to_string();
264 match elem {
265 Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
266 let lhs = format!("float({lhs})");
267
268 write!(f, "{elem}(")?;
269 D::compile_instruction_powf(f, &lhs, &rhs, Elem::F32)?;
270 write!(f, ")")
271 }
272 _ => D::compile_instruction_powf(f, &lhs, &rhs, elem),
273 }
274 }
275
276 fn unroll_vec(
278 f: &mut Formatter<'_>,
279 lhs: &Variable<D>,
280 rhs: &Variable<D>,
281 out: &Variable<D>,
282 ) -> core::fmt::Result {
283 let item_out = out.item();
284 let index = out.item().vectorization;
285
286 let out = out.fmt_left();
287 writeln!(f, "{out} = {item_out}{{")?;
288 for i in 0..index {
289 let lhsi = lhs.index(i);
290 let rhsi = rhs.index(i);
291
292 Self::format_scalar(f, lhsi, rhsi, item_out)?;
293 f.write_str(", ")?;
294 }
295
296 f.write_str("};\n")
297 }
298}
299
300pub struct Max;
301
302impl<D: Dialect> Binary<D> for Max {
303 fn format_scalar<Lhs: Display, Rhs: Display>(
304 f: &mut std::fmt::Formatter<'_>,
305 lhs: Lhs,
306 rhs: Rhs,
307 item: Item<D>,
308 ) -> std::fmt::Result {
309 D::compile_instruction_max_function_name(f, item)?;
310 write!(f, "({lhs}, {rhs})")
311 }
312}
313
314pub struct Min;
315
316impl<D: Dialect> Binary<D> for Min {
317 fn format_scalar<Lhs: Display, Rhs: Display>(
318 f: &mut std::fmt::Formatter<'_>,
319 lhs: Lhs,
320 rhs: Rhs,
321 item: Item<D>,
322 ) -> std::fmt::Result {
323 D::compile_instruction_min_function_name(f, item)?;
324 write!(f, "({lhs}, {rhs})")
325 }
326}
327
/// Store into an indexed list: emits `out_list[index] = value`.
pub struct IndexAssign;
/// Load from an indexed list: emits `out = list[index]`.
pub struct Index;
330
331impl IndexAssign {
332 pub fn format<D: Dialect>(
333 f: &mut Formatter<'_>,
334 index: &Variable<D>,
335 value: &Variable<D>,
336 out_list: &Variable<D>,
337 line_size: u32,
338 ) -> std::fmt::Result {
339 if matches!(
340 out_list,
341 Variable::LocalMut { .. } | Variable::LocalConst { .. }
342 ) {
343 return IndexAssignVector::format(f, index, value, out_list);
344 };
345
346 if line_size > 0 {
347 let mut item = out_list.item();
348 item.vectorization = line_size as usize;
349 let addr_space = D::address_space_for_variable(out_list);
350 let qualifier = out_list.const_qualifier();
351 let tmp = Variable::tmp_declared(item);
352
353 writeln!(
354 f,
355 "{qualifier} {addr_space}{item} *{tmp} = reinterpret_cast<{qualifier} {item}*>({out_list});"
356 )?;
357
358 return IndexAssign::format(f, index, value, &tmp, 0);
359 }
360
361 let out_item = out_list.item();
362
363 if index.item().vectorization == 1 {
364 write!(f, "{}[{index}] = ", out_list.fmt_left())?;
365 Self::format_scalar(f, *index, *value, out_item)?;
366 f.write_str(";\n")
367 } else {
368 Self::unroll_vec(f, index, value, out_list)
369 }
370 }
371 fn format_scalar<D: Dialect, Lhs, Rhs>(
372 f: &mut Formatter<'_>,
373 _lhs: Lhs,
374 rhs: Rhs,
375 item_out: Item<D>,
376 ) -> std::fmt::Result
377 where
378 Lhs: Component<D>,
379 Rhs: Component<D>,
380 {
381 let item_rhs = rhs.item();
382
383 let format_vec = |f: &mut Formatter<'_>, cast: bool| {
384 writeln!(f, "{item_out}{{")?;
385 for i in 0..item_out.vectorization {
386 if cast {
387 writeln!(f, "{}({}),", item_out.elem, rhs.index(i))?;
388 } else {
389 writeln!(f, "{},", rhs.index(i))?;
390 }
391 }
392 f.write_str("}")?;
393
394 Ok(())
395 };
396
397 if item_out.vectorization != item_rhs.vectorization {
398 format_vec(f, item_out != item_rhs)
399 } else if item_out.elem != item_rhs.elem {
400 if item_out.vectorization > 1 {
401 format_vec(f, true)?;
402 } else {
403 write!(f, "{}({rhs})", item_out.elem)?;
404 }
405 Ok(())
406 } else if rhs.is_const() && item_rhs.vectorization > 1 {
407 write!(f, "reinterpret_cast<")?;
409 D::compile_local_memory_qualifier(f)?;
410 write!(f, " {item_out} const&>({rhs})")
411 } else {
412 write!(f, "{rhs}")
413 }
414 }
415
416 fn unroll_vec<D: Dialect>(
417 f: &mut Formatter<'_>,
418 lhs: &Variable<D>,
419 rhs: &Variable<D>,
420 out: &Variable<D>,
421 ) -> std::fmt::Result {
422 let item_lhs = lhs.item();
423 let out_item = out.item();
424 let out = out.fmt_left();
425
426 for i in 0..item_lhs.vectorization {
427 let lhsi = lhs.index(i);
428 let rhsi = rhs.index(i);
429 write!(f, "{out}[{lhs}] = ")?;
430 Self::format_scalar(f, lhsi, rhsi, out_item)?;
431 f.write_str(";\n")?;
432 }
433
434 Ok(())
435 }
436}
437
438impl Index {
439 pub(crate) fn format<D: Dialect>(
440 f: &mut Formatter<'_>,
441 list: &Variable<D>,
442 index: &Variable<D>,
443 out: &Variable<D>,
444 line_size: u32,
445 ) -> std::fmt::Result {
446 if matches!(
447 list,
448 Variable::LocalMut { .. } | Variable::LocalConst { .. } | Variable::ConstantScalar(..)
449 ) {
450 return IndexVector::format(f, list, index, out);
451 }
452
453 if line_size > 0 {
454 let mut item = list.item();
455 item.vectorization = line_size as usize;
456 let addr_space = D::address_space_for_variable(list);
457 let qualifier = list.const_qualifier();
458 let tmp = Variable::tmp_declared(item);
459
460 writeln!(
461 f,
462 "{qualifier} {addr_space}{item} *{tmp} = reinterpret_cast<{qualifier} {item}*>({list});"
463 )?;
464
465 return Index::format(f, &tmp, index, out, 0);
466 }
467
468 let item_out = out.item();
469 if let Elem::Atomic(inner) = item_out.elem {
470 let addr_space = D::address_space_for_variable(list);
471 writeln!(f, "{addr_space}{inner}* {out} = &{list}[{index}];")
472 } else {
473 let out = out.fmt_left();
474 write!(f, "{out} = ")?;
475 Self::format_scalar(f, *list, *index, item_out)?;
476 f.write_str(";\n")
477 }
478 }
479
480 fn format_scalar<D: Dialect, Lhs, Rhs>(
481 f: &mut Formatter<'_>,
482 lhs: Lhs,
483 rhs: Rhs,
484 item_out: Item<D>,
485 ) -> std::fmt::Result
486 where
487 Lhs: Component<D>,
488 Rhs: Component<D>,
489 {
490 let item_lhs = lhs.item();
491
492 let format_vec = |f: &mut Formatter<'_>| {
493 writeln!(f, "{item_out}{{")?;
494 for i in 0..item_out.vectorization {
495 write!(f, "{}({lhs}[{rhs}].i_{i}),", item_out.elem)?;
496 }
497 f.write_str("}")?;
498
499 Ok(())
500 };
501
502 if item_out.elem != item_lhs.elem {
503 if item_out.vectorization > 1 {
504 format_vec(f)
505 } else {
506 write!(f, "{}({lhs}[{rhs}])", item_out.elem)
507 }
508 } else {
509 write!(f, "{lhs}[{rhs}]")
510 }
511 }
512}
513
/// Reads one element of a local or constant variable; used by `Index::format`.
struct IndexVector<D: Dialect> {
    // Only ties the helper to a dialect; the struct holds no data.
    _dialect: PhantomData<D>,
}

/// Writes one element of a local variable; used by `IndexAssign::format`.
struct IndexAssignVector<D: Dialect> {
    // Only ties the helper to a dialect; the struct holds no data.
    _dialect: PhantomData<D>,
}
540
541impl<D: Dialect> IndexVector<D> {
542 fn format(
543 f: &mut Formatter<'_>,
544 lhs: &Variable<D>,
545 rhs: &Variable<D>,
546 out: &Variable<D>,
547 ) -> std::fmt::Result {
548 match rhs {
549 Variable::ConstantScalar(value, _elem) => {
550 let index = value.as_usize();
551 let out = out.index(index);
552 let lhs = lhs.index(index);
553 let out = out.fmt_left();
554 writeln!(f, "{out} = {lhs};")
555 }
556 _ => {
557 let elem = out.elem();
558 let qualifier = out.const_qualifier();
559 let addr_space = D::address_space_for_variable(out);
560 let out = out.fmt_left();
561 writeln!(
562 f,
563 "{out} = reinterpret_cast<{addr_space}{elem}{qualifier}*>(&{lhs})[{rhs}];"
564 )
565 }
566 }
567 }
568}
569
570impl<D: Dialect> IndexAssignVector<D> {
571 fn format(
572 f: &mut Formatter<'_>,
573 lhs: &Variable<D>,
574 rhs: &Variable<D>,
575 out: &Variable<D>,
576 ) -> std::fmt::Result {
577 let index = match lhs {
578 Variable::ConstantScalar(value, _) => value.as_usize(),
579 _ => {
580 let elem = out.elem();
581 let addr_space = D::address_space_for_variable(out);
582 return writeln!(f, "*(({addr_space}{elem}*)&{out} + {lhs}) = {rhs};");
583 }
584 };
585
586 let out = out.index(index);
587 let rhs = rhs.index(index);
588
589 writeln!(f, "{out} = {rhs};")
590 }
591}