1use crate::shared::FmtLeft;
2
3use super::{Component, Dialect, Elem, Item, Variable};
4use std::{
5 fmt::{Display, Formatter},
6 marker::PhantomData,
7};
8
9pub trait Binary<D: Dialect> {
10 fn format(
11 f: &mut Formatter<'_>,
12 lhs: &Variable<D>,
13 rhs: &Variable<D>,
14 out: &Variable<D>,
15 ) -> std::fmt::Result {
16 let out_item = out.item();
17 if out.item().vectorization == 1 {
18 let out = out.fmt_left();
19 write!(f, "{out} = ")?;
20 Self::format_scalar(f, *lhs, *rhs, out_item)?;
21 f.write_str(";\n")
22 } else {
23 Self::unroll_vec(f, lhs, rhs, out)
24 }
25 }
26
27 fn format_scalar<Lhs, Rhs>(
28 f: &mut Formatter<'_>,
29 lhs: Lhs,
30 rhs: Rhs,
31 item: Item<D>,
32 ) -> std::fmt::Result
33 where
34 Lhs: Component<D>,
35 Rhs: Component<D>;
36
37 fn unroll_vec(
38 f: &mut Formatter<'_>,
39 lhs: &Variable<D>,
40 rhs: &Variable<D>,
41 out: &Variable<D>,
42 ) -> core::fmt::Result {
43 let optimized = Variable::optimized_args([*lhs, *rhs, *out]);
44 let [lhs, rhs, out_optimized] = optimized.args;
45
46 let item_out_original = out.item();
47 let item_out_optimized = out_optimized.item();
48
49 let index = match optimized.optimization_factor {
50 Some(factor) => item_out_original.vectorization / factor,
51 None => item_out_optimized.vectorization,
52 };
53
54 let mut write_op =
55 |lhs: &Variable<D>, rhs: &Variable<D>, out: &Variable<D>, item_out: Item<D>| {
56 let out = out.fmt_left();
57 writeln!(f, "{out} = {item_out}{{")?;
58 for i in 0..index {
59 let lhsi = lhs.index(i);
60 let rhsi = rhs.index(i);
61
62 Self::format_scalar(f, lhsi, rhsi, item_out)?;
63 f.write_str(", ")?;
64 }
65
66 f.write_str("};\n")
67 };
68
69 if item_out_original == item_out_optimized {
70 write_op(&lhs, &rhs, out, item_out_optimized)
71 } else {
72 let out_tmp = Variable::tmp(item_out_optimized);
73 write_op(&lhs, &rhs, &out_tmp, item_out_optimized)?;
74 let addr_space = D::address_space_for_variable(out);
75 let out = out.fmt_left();
76
77 writeln!(
78 f,
79 "{out} = reinterpret_cast<{addr_space}{item_out_original}&>({out_tmp});\n"
80 )?;
81
82 Ok(())
83 }
84 }
85}
86
/// Declares a unit struct `$name` and implements [`Binary`] for it so that its
/// scalar expression is the infix form `lhs $op rhs`.
///
/// When the output element is `I8`, `U8`, `I16` or `U16`, the expression is
/// wrapped in a conversion to that element type: `<elem>(lhs $op rhs)`.
/// NOTE(review): presumably this counters C/C++ integer promotion so the
/// result keeps the narrow type — confirm against the dialect back ends.
macro_rules! operator {
    ($name:ident, $op:expr) => {
        pub struct $name;

        impl<D: Dialect> Binary<D> for $name {
            fn format_scalar<Lhs: Display, Rhs: Display>(
                f: &mut std::fmt::Formatter<'_>,
                lhs: Lhs,
                rhs: Rhs,
                out_item: Item<D>,
            ) -> std::fmt::Result {
                let symbol = $op;
                let elem = out_item.elem();
                let is_narrow_int = matches!(
                    elem,
                    Elem::<D>::I16 | Elem::<D>::U16 | Elem::<D>::I8 | Elem::<D>::U8
                );

                if is_narrow_int {
                    write!(f, "{elem}({lhs} {symbol} {rhs})")
                } else {
                    write!(f, "{lhs} {symbol} {rhs}")
                }
            }
        }
    };
}
112
// Arithmetic operators.
operator!(Add, "+");
operator!(Sub, "-");
operator!(Div, "/");
operator!(Mul, "*");
operator!(Modulo, "%");
// Comparison operators.
operator!(Equal, "==");
operator!(NotEqual, "!=");
operator!(Lower, "<");
operator!(LowerEqual, "<=");
operator!(Greater, ">");
operator!(GreaterEqual, ">=");
// Shift and bitwise operators.
operator!(ShiftLeft, "<<");
operator!(ShiftRight, ">>");
operator!(BitwiseOr, "|");
operator!(BitwiseAnd, "&");
operator!(BitwiseXor, "^");
// Logical operators.
operator!(Or, "||");
operator!(And, "&&");
131
132pub struct HiMul;
133
134impl<D: Dialect> Binary<D> for HiMul {
135 fn format_scalar<Lhs: Display, Rhs: Display>(
137 f: &mut std::fmt::Formatter<'_>,
138 lhs: Lhs,
139 rhs: Rhs,
140 out: Item<D>,
141 ) -> std::fmt::Result {
142 let out_elem = out.elem;
143 match out_elem {
144 Elem::I32 => write!(f, "__mulhi({lhs}, {rhs})"),
145 Elem::U32 => write!(f, "__umulhi({lhs}, {rhs})"),
146 Elem::I64 => write!(f, "__mul64hi({lhs}, {rhs})"),
147 Elem::U64 => write!(f, "__umul64hi({lhs}, {rhs})"),
148 _ => unimplemented!("HiMul only supports 32 and 64 bit ints"),
149 }
150 }
151
152 fn unroll_vec(
154 f: &mut Formatter<'_>,
155 lhs: &Variable<D>,
156 rhs: &Variable<D>,
157 out: &Variable<D>,
158 ) -> core::fmt::Result {
159 let item_out = out.item();
160 let index = out.item().vectorization;
161
162 let out = out.fmt_left();
163 writeln!(f, "{out} = {item_out}{{")?;
164 for i in 0..index {
165 let lhsi = lhs.index(i);
166 let rhsi = rhs.index(i);
167
168 Self::format_scalar(f, lhsi, rhsi, item_out)?;
169 f.write_str(", ")?;
170 }
171
172 f.write_str("};\n")
173 }
174}
175
176pub struct Powf;
177
178impl<D: Dialect> Binary<D> for Powf {
179 fn format_scalar<Lhs: Display, Rhs: Display>(
181 f: &mut std::fmt::Formatter<'_>,
182 lhs: Lhs,
183 rhs: Rhs,
184 item: Item<D>,
185 ) -> std::fmt::Result {
186 let elem = item.elem;
187 match elem {
188 Elem::F16 | Elem::F162 | Elem::BF16 | Elem::BF162 => {
189 write!(f, "{elem}(")?;
190 D::compile_instruction_powf(f)?;
191 write!(f, "(float({lhs}), float({rhs})))")
192 }
193 _ => {
194 D::compile_instruction_powf(f)?;
195 write!(f, "({lhs}, {rhs})")
196 }
197 }
198 }
199
200 fn unroll_vec(
202 f: &mut Formatter<'_>,
203 lhs: &Variable<D>,
204 rhs: &Variable<D>,
205 out: &Variable<D>,
206 ) -> core::fmt::Result {
207 let item_out = out.item();
208 let index = out.item().vectorization;
209
210 let out = out.fmt_left();
211 writeln!(f, "{out} = {item_out}{{")?;
212 for i in 0..index {
213 let lhsi = lhs.index(i);
214 let rhsi = rhs.index(i);
215
216 Self::format_scalar(f, lhsi, rhsi, item_out)?;
217 f.write_str(", ")?;
218 }
219
220 f.write_str("};\n")
221 }
222}
223
224pub struct Max;
225
226impl<D: Dialect> Binary<D> for Max {
227 fn format_scalar<Lhs: Display, Rhs: Display>(
228 f: &mut std::fmt::Formatter<'_>,
229 lhs: Lhs,
230 rhs: Rhs,
231 item: Item<D>,
232 ) -> std::fmt::Result {
233 D::compile_instruction_max_function_name(f, item)?;
234 write!(f, "({lhs}, {rhs})")
235 }
236}
237
238pub struct Min;
239
240impl<D: Dialect> Binary<D> for Min {
241 fn format_scalar<Lhs: Display, Rhs: Display>(
242 f: &mut std::fmt::Formatter<'_>,
243 lhs: Lhs,
244 rhs: Rhs,
245 item: Item<D>,
246 ) -> std::fmt::Result {
247 D::compile_instruction_min_function_name(f, item)?;
248 write!(f, "({lhs}, {rhs})")
249 }
250}
251
/// Marker type whose [`Binary`] implementation emits an indexed store of the
/// form `out[lhs] = rhs` (`lhs` is the index, `rhs` the stored value).
pub struct IndexAssign;
/// Marker type whose [`Binary`] implementation emits an indexed load of the
/// form `out = lhs[rhs]` (`lhs` is the indexed variable, `rhs` the index).
pub struct Index;
254
255impl<D: Dialect> Binary<D> for IndexAssign {
256 fn format_scalar<Lhs, Rhs>(
257 f: &mut Formatter<'_>,
258 _lhs: Lhs,
259 rhs: Rhs,
260 item_out: Item<D>,
261 ) -> std::fmt::Result
262 where
263 Lhs: Component<D>,
264 Rhs: Component<D>,
265 {
266 let item_rhs = rhs.item();
267
268 let format_vec = |f: &mut Formatter<'_>, cast: bool| {
269 writeln!(f, "{item_out}{{")?;
270 for i in 0..item_out.vectorization {
271 if cast {
272 writeln!(f, "{}({}),", item_out.elem, rhs.index(i))?;
273 } else {
274 writeln!(f, "{},", rhs.index(i))?;
275 }
276 }
277 f.write_str("}")?;
278
279 Ok(())
280 };
281
282 if item_out.vectorization != item_rhs.vectorization {
283 format_vec(f, item_out != item_rhs)
284 } else if item_out.elem != item_rhs.elem {
285 if item_out.vectorization > 1 {
286 format_vec(f, true)?;
287 } else {
288 write!(f, "{}({rhs})", item_out.elem)?;
289 }
290 Ok(())
291 } else if rhs.is_const() && item_rhs.vectorization > 1 {
292 write!(f, "reinterpret_cast<")?;
294 D::compile_local_memory_qualifier(f)?;
295 write!(f, " {item_out} const&>({rhs})")
296 } else {
297 write!(f, "{rhs}")
298 }
299 }
300
301 fn unroll_vec(
302 f: &mut Formatter<'_>,
303 lhs: &Variable<D>,
304 rhs: &Variable<D>,
305 out: &Variable<D>,
306 ) -> std::fmt::Result {
307 let item_lhs = lhs.item();
308 let out_item = out.item();
309 let out = out.fmt_left();
310
311 for i in 0..item_lhs.vectorization {
312 let lhsi = lhs.index(i);
313 let rhsi = rhs.index(i);
314 write!(f, "{out}[{lhs}] = ")?;
315 Self::format_scalar(f, lhsi, rhsi, out_item)?;
316 f.write_str(";\n")?;
317 }
318
319 Ok(())
320 }
321
322 fn format(
323 f: &mut Formatter<'_>,
324 lhs: &Variable<D>,
325 rhs: &Variable<D>,
326 out: &Variable<D>,
327 ) -> std::fmt::Result {
328 if matches!(out, Variable::LocalMut { .. } | Variable::LocalConst { .. }) {
329 return IndexAssignVector::format(f, lhs, rhs, out);
330 };
331
332 let out_item = out.item();
333
334 if lhs.item().vectorization == 1 {
335 write!(f, "{}[{lhs}] = ", out.fmt_left())?;
336 Self::format_scalar(f, *lhs, *rhs, out_item)?;
337 f.write_str(";\n")
338 } else {
339 Self::unroll_vec(f, lhs, rhs, out)
340 }
341 }
342}
343
344impl<D: Dialect> Binary<D> for Index {
345 fn format(
346 f: &mut Formatter<'_>,
347 lhs: &Variable<D>,
348 rhs: &Variable<D>,
349 out: &Variable<D>,
350 ) -> std::fmt::Result {
351 if matches!(lhs, Variable::LocalMut { .. } | Variable::LocalConst { .. }) {
352 return IndexVector::format(f, lhs, rhs, out);
353 }
354
355 let item_out = out.item();
356 if let Elem::Atomic(inner) = item_out.elem {
357 let addr_space = D::address_space_for_variable(lhs);
358 writeln!(f, "{addr_space}{inner}* {out} = &{lhs}[{rhs}];")
359 } else {
360 let out = out.fmt_left();
361 write!(f, "{out} = ")?;
362 Self::format_scalar(f, *lhs, *rhs, item_out)?;
363 f.write_str(";\n")
364 }
365 }
366
367 fn format_scalar<Lhs, Rhs>(
368 f: &mut Formatter<'_>,
369 lhs: Lhs,
370 rhs: Rhs,
371 item_out: Item<D>,
372 ) -> std::fmt::Result
373 where
374 Lhs: Component<D>,
375 Rhs: Component<D>,
376 {
377 let item_lhs = lhs.item();
378
379 let format_vec = |f: &mut Formatter<'_>| {
380 writeln!(f, "{item_out}{{")?;
381 for i in 0..item_out.vectorization {
382 write!(f, "{}({lhs}[{rhs}].i_{i}),", item_out.elem)?;
383 }
384 f.write_str("}")?;
385
386 Ok(())
387 };
388
389 if item_out.elem != item_lhs.elem {
390 if item_out.vectorization > 1 {
391 format_vec(f)
392 } else {
393 write!(f, "{}({lhs}[{rhs}])", item_out.elem)
394 }
395 } else {
396 write!(f, "{lhs}[{rhs}]")
397 }
398 }
399}
400
/// Helper used by the `Index` formatter when the indexed variable is a
/// `Variable::LocalMut` or `Variable::LocalConst`.
///
/// It stores no data; the `PhantomData` field only ties the type to the
/// dialect `D`.
struct IndexVector<D: Dialect> {
    _dialect: PhantomData<D>,
}
413
/// Helper used by the `IndexAssign` formatter when the destination variable is
/// a `Variable::LocalMut` or `Variable::LocalConst`.
///
/// It stores no data; the `PhantomData` field only ties the type to the
/// dialect `D`.
struct IndexAssignVector<D: Dialect> {
    _dialect: PhantomData<D>,
}
427
428impl<D: Dialect> IndexVector<D> {
429 fn format(
430 f: &mut Formatter<'_>,
431 lhs: &Variable<D>,
432 rhs: &Variable<D>,
433 out: &Variable<D>,
434 ) -> std::fmt::Result {
435 match rhs {
436 Variable::ConstantScalar(value, _elem) => {
437 let index = value.as_usize();
438 let out = out.index(index);
439 let lhs = lhs.index(index);
440 let out = out.fmt_left();
441 writeln!(f, "{out} = {lhs};")
442 }
443 _ => {
444 let elem = out.elem();
445 let qualifier = out.const_qualifier();
446 let addr_space = D::address_space_for_variable(out);
447 let out = out.fmt_left();
448 writeln!(
449 f,
450 "{out} = reinterpret_cast<{addr_space}{elem}{qualifier}*>(&{lhs})[{rhs}];"
451 )
452 }
453 }
454 }
455}
456
457impl<D: Dialect> IndexAssignVector<D> {
458 fn format(
459 f: &mut Formatter<'_>,
460 lhs: &Variable<D>,
461 rhs: &Variable<D>,
462 out: &Variable<D>,
463 ) -> std::fmt::Result {
464 let index = match lhs {
465 Variable::ConstantScalar(value, _) => value.as_usize(),
466 _ => {
467 let elem = out.elem();
468 let addr_space = D::address_space_for_variable(out);
469 return writeln!(f, "*(({addr_space}{elem}*)&{out} + {lhs}) = {rhs};");
470 }
471 };
472
473 let out = out.index(index);
474 let rhs = rhs.index(index);
475
476 writeln!(f, "{out} = {rhs};")
477 }
478}