1use crate::shared::FmtLeft;
2
3use super::{Component, Dialect, Elem, Item, Variable};
4use std::{
5 fmt::{Display, Formatter},
6 marker::PhantomData,
7};
8
9pub trait Binary<D: Dialect> {
10 fn format(
11 f: &mut Formatter<'_>,
12 lhs: &Variable<D>,
13 rhs: &Variable<D>,
14 out: &Variable<D>,
15 ) -> std::fmt::Result {
16 let out_item = out.item();
17 if out.item().vectorization == 1 {
18 let out = out.fmt_left();
19 write!(f, "{out} = ")?;
20 Self::format_scalar(f, *lhs, *rhs, out_item)?;
21 f.write_str(";\n")
22 } else {
23 Self::unroll_vec(f, lhs, rhs, out)
24 }
25 }
26
27 fn format_scalar<Lhs, Rhs>(
28 f: &mut Formatter<'_>,
29 lhs: Lhs,
30 rhs: Rhs,
31 item: Item<D>,
32 ) -> std::fmt::Result
33 where
34 Lhs: Component<D>,
35 Rhs: Component<D>;
36
37 fn unroll_vec(
38 f: &mut Formatter<'_>,
39 lhs: &Variable<D>,
40 rhs: &Variable<D>,
41 out: &Variable<D>,
42 ) -> core::fmt::Result {
43 let optimized = Variable::optimized_args([*lhs, *rhs, *out]);
44 let [lhs, rhs, out_optimized] = optimized.args;
45
46 let item_out_original = out.item();
47 let item_out_optimized = out_optimized.item();
48
49 let index = match optimized.optimization_factor {
50 Some(factor) => item_out_original.vectorization / factor,
51 None => item_out_optimized.vectorization,
52 };
53
54 let mut write_op =
55 |lhs: &Variable<D>, rhs: &Variable<D>, out: &Variable<D>, item_out: Item<D>| {
56 let out = out.fmt_left();
57 writeln!(f, "{out} = {item_out}{{")?;
58 for i in 0..index {
59 let lhsi = lhs.index(i);
60 let rhsi = rhs.index(i);
61
62 Self::format_scalar(f, lhsi, rhsi, item_out)?;
63 f.write_str(", ")?;
64 }
65
66 f.write_str("};\n")
67 };
68
69 if item_out_original == item_out_optimized {
70 write_op(&lhs, &rhs, out, item_out_optimized)
71 } else {
72 let out_tmp = Variable::tmp(item_out_optimized);
73 write_op(&lhs, &rhs, &out_tmp, item_out_optimized)?;
74 let addr_space = D::address_space_for_variable(out);
75 let out = out.fmt_left();
76
77 writeln!(
78 f,
79 "{out} = reinterpret_cast<{addr_space}{item_out_original}&>({out_tmp});\n"
80 )?;
81
82 Ok(())
83 }
84 }
85}
86
// Declares a public unit struct `$name` and implements `Binary` for it so that
// a scalar operation is written as the infix expression `lhs $op rhs`.
macro_rules! operator {
    ($name:ident, $op:expr) => {
        pub struct $name;

        impl<D: Dialect> Binary<D> for $name {
            fn format_scalar<Lhs: Display, Rhs: Display>(
                f: &mut std::fmt::Formatter<'_>,
                lhs: Lhs,
                rhs: Rhs,
                out_item: Item<D>,
            ) -> std::fmt::Result {
                let out_elem = out_item.elem();

                // 8- and 16-bit integer outputs get the whole expression wrapped
                // in a cast to the output element type.
                // NOTE(review): presumably this counters integer promotion in the
                // generated code - confirm.
                let needs_cast = matches!(
                    out_elem,
                    Elem::<D>::I16 | Elem::<D>::U16 | Elem::<D>::I8 | Elem::<D>::U8
                );

                if needs_cast {
                    write!(f, "{out_elem}({lhs} {} {rhs})", $op)
                } else {
                    write!(f, "{lhs} {} {rhs}", $op)
                }
            }
        }
    };
}
112
// Each invocation below declares one unit struct implementing `Binary` that
// writes its operands around the given operator token.

// Arithmetic operators.
operator!(Add, "+");
operator!(Sub, "-");
operator!(Div, "/");
operator!(Mul, "*");
operator!(Modulo, "%");
// Comparison operators.
operator!(Equal, "==");
operator!(NotEqual, "!=");
operator!(Lower, "<");
operator!(LowerEqual, "<=");
operator!(Greater, ">");
operator!(GreaterEqual, ">=");
// Shift and bitwise operators.
operator!(ShiftLeft, "<<");
operator!(ShiftRight, ">>");
operator!(BitwiseOr, "|");
operator!(BitwiseAnd, "&");
operator!(BitwiseXor, "^");
// Logical operators.
operator!(Or, "||");
operator!(And, "&&");
131
132pub struct HiMul;
133
134impl<D: Dialect> Binary<D> for HiMul {
135 fn format_scalar<Lhs: Display, Rhs: Display>(
137 f: &mut std::fmt::Formatter<'_>,
138 lhs: Lhs,
139 rhs: Rhs,
140 out: Item<D>,
141 ) -> std::fmt::Result {
142 let out_elem = out.elem;
143 match out_elem {
144 Elem::I32 => write!(f, "__mulhi({lhs}, {rhs})"),
145 Elem::U32 => write!(f, "__umulhi({lhs}, {rhs})"),
146 Elem::I64 => write!(f, "__mul64hi({lhs}, {rhs})"),
147 Elem::U64 => write!(f, "__umul64hi({lhs}, {rhs})"),
148 _ => unimplemented!("HiMul only supports 32 and 64 bit ints"),
149 }
150 }
151
152 fn unroll_vec(
154 f: &mut Formatter<'_>,
155 lhs: &Variable<D>,
156 rhs: &Variable<D>,
157 out: &Variable<D>,
158 ) -> core::fmt::Result {
159 let item_out = out.item();
160 let index = out.item().vectorization;
161
162 let out = out.fmt_left();
163 writeln!(f, "{out} = {item_out}{{")?;
164 for i in 0..index {
165 let lhsi = lhs.index(i);
166 let rhsi = rhs.index(i);
167
168 Self::format_scalar(f, lhsi, rhsi, item_out)?;
169 f.write_str(", ")?;
170 }
171
172 f.write_str("};\n")
173 }
174}
175
176pub struct Powf;
177
178impl<D: Dialect> Binary<D> for Powf {
179 fn format_scalar<Lhs: Display, Rhs: Display>(
181 f: &mut std::fmt::Formatter<'_>,
182 lhs: Lhs,
183 rhs: Rhs,
184 item: Item<D>,
185 ) -> std::fmt::Result {
186 let elem = item.elem;
187 match elem {
188 Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
189 write!(f, "{elem}(")?;
190 D::compile_instruction_powf(f)?;
191 write!(f, "(float({lhs}), float({rhs})))")
192 }
193 _ => {
194 D::compile_instruction_powf(f)?;
195 write!(f, "({lhs}, {rhs})")
196 }
197 }
198 }
199
200 fn unroll_vec(
202 f: &mut Formatter<'_>,
203 lhs: &Variable<D>,
204 rhs: &Variable<D>,
205 out: &Variable<D>,
206 ) -> core::fmt::Result {
207 let item_out = out.item();
208 let index = out.item().vectorization;
209
210 let out = out.fmt_left();
211 writeln!(f, "{out} = {item_out}{{")?;
212 for i in 0..index {
213 let lhsi = lhs.index(i);
214 let rhsi = rhs.index(i);
215
216 Self::format_scalar(f, lhsi, rhsi, item_out)?;
217 f.write_str(", ")?;
218 }
219
220 f.write_str("};\n")
221 }
222}
223
224pub struct Max;
225
226impl<D: Dialect> Binary<D> for Max {
227 fn format_scalar<Lhs: Display, Rhs: Display>(
228 f: &mut std::fmt::Formatter<'_>,
229 lhs: Lhs,
230 rhs: Rhs,
231 item: Item<D>,
232 ) -> std::fmt::Result {
233 D::compile_instruction_max_function_name(f, item)?;
234 write!(f, "({lhs}, {rhs})")
235 }
236}
237
238pub struct Min;
239
240impl<D: Dialect> Binary<D> for Min {
241 fn format_scalar<Lhs: Display, Rhs: Display>(
242 f: &mut std::fmt::Formatter<'_>,
243 lhs: Lhs,
244 rhs: Rhs,
245 item: Item<D>,
246 ) -> std::fmt::Result {
247 D::compile_instruction_min_function_name(f, item)?;
248 write!(f, "({lhs}, {rhs})")
249 }
250}
251
/// Formatter for stores into an indexed list, i.e. `out_list[index] = value`.
pub struct IndexAssign;
/// Formatter for loads from an indexed list, i.e. `out = list[index]`.
pub struct Index;
254
255impl IndexAssign {
256 pub fn format<D: Dialect>(
257 f: &mut Formatter<'_>,
258 index: &Variable<D>,
259 value: &Variable<D>,
260 out_list: &Variable<D>,
261 line_size: u32,
262 ) -> std::fmt::Result {
263 if matches!(
264 out_list,
265 Variable::LocalMut { .. } | Variable::LocalConst { .. }
266 ) {
267 return IndexAssignVector::format(f, index, value, out_list);
268 };
269
270 if line_size > 0 {
271 let mut item = out_list.item();
272 item.vectorization = line_size as usize;
273 let addr_space = D::address_space_for_variable(out_list);
274 let qualifier = out_list.const_qualifier();
275 let tmp = Variable::tmp_declared(item);
276
277 writeln!(
278 f,
279 "{qualifier} {addr_space}{item} *{tmp} = reinterpret_cast<{qualifier} {item}*>({out_list});"
280 )?;
281
282 return IndexAssign::format(f, index, value, &tmp, 0);
283 }
284
285 let out_item = out_list.item();
286
287 if index.item().vectorization == 1 {
288 write!(f, "{}[{index}] = ", out_list.fmt_left())?;
289 Self::format_scalar(f, *index, *value, out_item)?;
290 f.write_str(";\n")
291 } else {
292 Self::unroll_vec(f, index, value, out_list)
293 }
294 }
295 fn format_scalar<D: Dialect, Lhs, Rhs>(
296 f: &mut Formatter<'_>,
297 _lhs: Lhs,
298 rhs: Rhs,
299 item_out: Item<D>,
300 ) -> std::fmt::Result
301 where
302 Lhs: Component<D>,
303 Rhs: Component<D>,
304 {
305 let item_rhs = rhs.item();
306
307 let format_vec = |f: &mut Formatter<'_>, cast: bool| {
308 writeln!(f, "{item_out}{{")?;
309 for i in 0..item_out.vectorization {
310 if cast {
311 writeln!(f, "{}({}),", item_out.elem, rhs.index(i))?;
312 } else {
313 writeln!(f, "{},", rhs.index(i))?;
314 }
315 }
316 f.write_str("}")?;
317
318 Ok(())
319 };
320
321 if item_out.vectorization != item_rhs.vectorization {
322 format_vec(f, item_out != item_rhs)
323 } else if item_out.elem != item_rhs.elem {
324 if item_out.vectorization > 1 {
325 format_vec(f, true)?;
326 } else {
327 write!(f, "{}({rhs})", item_out.elem)?;
328 }
329 Ok(())
330 } else if rhs.is_const() && item_rhs.vectorization > 1 {
331 write!(f, "reinterpret_cast<")?;
333 D::compile_local_memory_qualifier(f)?;
334 write!(f, " {item_out} const&>({rhs})")
335 } else {
336 write!(f, "{rhs}")
337 }
338 }
339
340 fn unroll_vec<D: Dialect>(
341 f: &mut Formatter<'_>,
342 lhs: &Variable<D>,
343 rhs: &Variable<D>,
344 out: &Variable<D>,
345 ) -> std::fmt::Result {
346 let item_lhs = lhs.item();
347 let out_item = out.item();
348 let out = out.fmt_left();
349
350 for i in 0..item_lhs.vectorization {
351 let lhsi = lhs.index(i);
352 let rhsi = rhs.index(i);
353 write!(f, "{out}[{lhs}] = ")?;
354 Self::format_scalar(f, lhsi, rhsi, out_item)?;
355 f.write_str(";\n")?;
356 }
357
358 Ok(())
359 }
360}
361
362impl Index {
363 pub(crate) fn format<D: Dialect>(
364 f: &mut Formatter<'_>,
365 list: &Variable<D>,
366 index: &Variable<D>,
367 out: &Variable<D>,
368 line_size: u32,
369 ) -> std::fmt::Result {
370 if matches!(
371 list,
372 Variable::LocalMut { .. } | Variable::LocalConst { .. } | Variable::ConstantScalar(..)
373 ) {
374 return IndexVector::format(f, list, index, out);
375 }
376
377 if line_size > 0 {
378 let mut item = list.item();
379 item.vectorization = line_size as usize;
380 let addr_space = D::address_space_for_variable(list);
381 let qualifier = list.const_qualifier();
382 let tmp = Variable::tmp_declared(item);
383
384 writeln!(
385 f,
386 "{qualifier} {addr_space}{item} *{tmp} = reinterpret_cast<{qualifier} {item}*>({list});"
387 )?;
388
389 return Index::format(f, &tmp, index, out, 0);
390 }
391
392 let item_out = out.item();
393 if let Elem::Atomic(inner) = item_out.elem {
394 let addr_space = D::address_space_for_variable(list);
395 writeln!(f, "{addr_space}{inner}* {out} = &{list}[{index}];")
396 } else {
397 let out = out.fmt_left();
398 write!(f, "{out} = ")?;
399 Self::format_scalar(f, *list, *index, item_out)?;
400 f.write_str(";\n")
401 }
402 }
403
404 fn format_scalar<D: Dialect, Lhs, Rhs>(
405 f: &mut Formatter<'_>,
406 lhs: Lhs,
407 rhs: Rhs,
408 item_out: Item<D>,
409 ) -> std::fmt::Result
410 where
411 Lhs: Component<D>,
412 Rhs: Component<D>,
413 {
414 let item_lhs = lhs.item();
415
416 let format_vec = |f: &mut Formatter<'_>| {
417 writeln!(f, "{item_out}{{")?;
418 for i in 0..item_out.vectorization {
419 write!(f, "{}({lhs}[{rhs}].i_{i}),", item_out.elem)?;
420 }
421 f.write_str("}")?;
422
423 Ok(())
424 };
425
426 if item_out.elem != item_lhs.elem {
427 if item_out.vectorization > 1 {
428 format_vec(f)
429 } else {
430 write!(f, "{}({lhs}[{rhs}])", item_out.elem)
431 }
432 } else {
433 write!(f, "{lhs}[{rhs}]")
434 }
435 }
436}
437
/// Helper used by `Index::format` when the indexed variable is a `LocalMut`,
/// `LocalConst` or `ConstantScalar` variable.
struct IndexVector<D: Dialect> {
    // Zero-sized marker tying the helper to a dialect.
    _dialect: PhantomData<D>,
}

/// Helper used by `IndexAssign::format` when the assigned variable is a
/// `LocalMut` or `LocalConst` variable.
struct IndexAssignVector<D: Dialect> {
    // Zero-sized marker tying the helper to a dialect.
    _dialect: PhantomData<D>,
}
464
465impl<D: Dialect> IndexVector<D> {
466 fn format(
467 f: &mut Formatter<'_>,
468 lhs: &Variable<D>,
469 rhs: &Variable<D>,
470 out: &Variable<D>,
471 ) -> std::fmt::Result {
472 match rhs {
473 Variable::ConstantScalar(value, _elem) => {
474 let index = value.as_usize();
475 let out = out.index(index);
476 let lhs = lhs.index(index);
477 let out = out.fmt_left();
478 writeln!(f, "{out} = {lhs};")
479 }
480 _ => {
481 let elem = out.elem();
482 let qualifier = out.const_qualifier();
483 let addr_space = D::address_space_for_variable(out);
484 let out = out.fmt_left();
485 writeln!(
486 f,
487 "{out} = reinterpret_cast<{addr_space}{elem}{qualifier}*>(&{lhs})[{rhs}];"
488 )
489 }
490 }
491 }
492}
493
494impl<D: Dialect> IndexAssignVector<D> {
495 fn format(
496 f: &mut Formatter<'_>,
497 lhs: &Variable<D>,
498 rhs: &Variable<D>,
499 out: &Variable<D>,
500 ) -> std::fmt::Result {
501 let index = match lhs {
502 Variable::ConstantScalar(value, _) => value.as_usize(),
503 _ => {
504 let elem = out.elem();
505 let addr_space = D::address_space_for_variable(out);
506 return writeln!(f, "*(({addr_space}{elem}*)&{out} + {lhs}) = {rhs};");
507 }
508 };
509
510 let out = out.index(index);
511 let rhs = rhs.index(index);
512
513 writeln!(f, "{out} = {rhs};")
514 }
515}