1use crate::shared::FmtLeft;
2
3use super::{Component, Dialect, Elem, Item, Variable};
4use std::{
5 fmt::{Display, Formatter},
6 marker::PhantomData,
7};
8
9pub trait Binary<D: Dialect> {
10 fn format(
11 f: &mut Formatter<'_>,
12 lhs: &Variable<D>,
13 rhs: &Variable<D>,
14 out: &Variable<D>,
15 ) -> std::fmt::Result {
16 let out_item = out.item();
17 if out.item().vectorization == 1 {
18 let out = out.fmt_left();
19 write!(f, "{out} = ")?;
20 Self::format_scalar(f, *lhs, *rhs, out_item)?;
21 f.write_str(";\n")
22 } else {
23 Self::unroll_vec(f, lhs, rhs, out)
24 }
25 }
26
27 fn format_scalar<Lhs, Rhs>(
28 f: &mut Formatter<'_>,
29 lhs: Lhs,
30 rhs: Rhs,
31 item: Item<D>,
32 ) -> std::fmt::Result
33 where
34 Lhs: Component<D>,
35 Rhs: Component<D>;
36
37 fn unroll_vec(
38 f: &mut Formatter<'_>,
39 lhs: &Variable<D>,
40 rhs: &Variable<D>,
41 out: &Variable<D>,
42 ) -> core::fmt::Result {
43 let optimized = Variable::optimized_args([*lhs, *rhs, *out]);
44 let [lhs, rhs, out_optimized] = optimized.args;
45
46 let item_out_original = out.item();
47 let item_out_optimized = out_optimized.item();
48
49 let index = match optimized.optimization_factor {
50 Some(factor) => item_out_original.vectorization / factor,
51 None => item_out_optimized.vectorization,
52 };
53
54 let mut write_op =
55 |lhs: &Variable<D>, rhs: &Variable<D>, out: &Variable<D>, item_out: Item<D>| {
56 let out = out.fmt_left();
57 writeln!(f, "{out} = {item_out}{{")?;
58 for i in 0..index {
59 let lhsi = lhs.index(i);
60 let rhsi = rhs.index(i);
61
62 Self::format_scalar(f, lhsi, rhsi, item_out)?;
63 f.write_str(", ")?;
64 }
65
66 f.write_str("};\n")
67 };
68
69 if item_out_original == item_out_optimized {
70 write_op(&lhs, &rhs, out, item_out_optimized)
71 } else {
72 let out_tmp = Variable::tmp(item_out_optimized);
73
74 write_op(&lhs, &rhs, &out_tmp, item_out_optimized)?;
75
76 let out = out.fmt_left();
77
78 writeln!(
79 f,
80 "{out} = reinterpret_cast<{item_out_original}&>({out_tmp});\n"
81 )?;
82
83 Ok(())
84 }
85 }
86}
87
88macro_rules! operator {
89 ($name:ident, $op:expr) => {
90 pub struct $name;
91
92 impl<D: Dialect> Binary<D> for $name {
93 fn format_scalar<Lhs: Display, Rhs: Display>(
94 f: &mut std::fmt::Formatter<'_>,
95 lhs: Lhs,
96 rhs: Rhs,
97 _item: Item<D>,
98 ) -> std::fmt::Result {
99 write!(f, "{lhs} {} {rhs}", $op)
100 }
101 }
102 };
103}
104
105operator!(Add, "+");
106operator!(Sub, "-");
107operator!(Div, "/");
108operator!(Mul, "*");
109operator!(Modulo, "%");
110operator!(Equal, "==");
111operator!(NotEqual, "!=");
112operator!(Lower, "<");
113operator!(LowerEqual, "<=");
114operator!(Greater, ">");
115operator!(GreaterEqual, ">=");
116operator!(ShiftLeft, "<<");
117operator!(ShiftRight, ">>");
118operator!(BitwiseOr, "|");
119operator!(BitwiseAnd, "&");
120operator!(BitwiseXor, "^");
121operator!(Or, "||");
122operator!(And, "&&");
123
124pub struct Powf;
125
126impl<D: Dialect> Binary<D> for Powf {
127 fn format_scalar<Lhs: Display, Rhs: Display>(
129 f: &mut std::fmt::Formatter<'_>,
130 lhs: Lhs,
131 rhs: Rhs,
132 item: Item<D>,
133 ) -> std::fmt::Result {
134 let elem = item.elem;
135 match elem {
136 Elem::F16 | Elem::F162 | Elem::BF16 | Elem::BF162 => {
137 write!(f, "{elem}(powf(float({lhs}), float({rhs})))")
138 }
139 _ => write!(f, "powf({lhs}, {rhs})"),
140 }
141 }
142
143 fn unroll_vec(
145 f: &mut Formatter<'_>,
146 lhs: &Variable<D>,
147 rhs: &Variable<D>,
148 out: &Variable<D>,
149 ) -> core::fmt::Result {
150 let item_out = out.item();
151 let index = out.item().vectorization;
152
153 let out = out.fmt_left();
154 writeln!(f, "{out} = {item_out}{{")?;
155 for i in 0..index {
156 let lhsi = lhs.index(i);
157 let rhsi = rhs.index(i);
158
159 Self::format_scalar(f, lhsi, rhsi, item_out)?;
160 f.write_str(", ")?;
161 }
162
163 f.write_str("};\n")
164 }
165}
166
167pub struct Max;
168
169impl<D: Dialect> Binary<D> for Max {
170 fn format_scalar<Lhs: Display, Rhs: Display>(
171 f: &mut std::fmt::Formatter<'_>,
172 lhs: Lhs,
173 rhs: Rhs,
174 item: Item<D>,
175 ) -> std::fmt::Result {
176 let max = match item.elem() {
177 Elem::F16 | Elem::BF16 => "__hmax",
178 Elem::F162 | Elem::BF162 => "__hmax2",
179 _ => "max",
180 };
181
182 write!(f, "{max}({lhs}, {rhs})")
183 }
184}
185
186pub struct Min;
187
188impl<D: Dialect> Binary<D> for Min {
189 fn format_scalar<Lhs: Display, Rhs: Display>(
190 f: &mut std::fmt::Formatter<'_>,
191 lhs: Lhs,
192 rhs: Rhs,
193 item: Item<D>,
194 ) -> std::fmt::Result {
195 let min = match item.elem() {
196 Elem::F16 | Elem::BF16 => "__hmin",
197 Elem::F162 | Elem::BF162 => "__hmin2",
198 _ => "min",
199 };
200
201 write!(f, "{min}({lhs}, {rhs})")
202 }
203}
204
205pub struct IndexAssign;
206pub struct Index;
207
208impl<D: Dialect> Binary<D> for IndexAssign {
209 fn format_scalar<Lhs, Rhs>(
210 f: &mut Formatter<'_>,
211 _lhs: Lhs,
212 rhs: Rhs,
213 item_out: Item<D>,
214 ) -> std::fmt::Result
215 where
216 Lhs: Component<D>,
217 Rhs: Component<D>,
218 {
219 let item_rhs = rhs.item();
220
221 let format_vec = |f: &mut Formatter<'_>, cast: bool| {
222 writeln!(f, "{item_out}{{")?;
223 for i in 0..item_out.vectorization {
224 if cast {
225 writeln!(f, "{}({}),", item_out.elem, rhs.index(i))?;
226 } else {
227 writeln!(f, "{},", rhs.index(i))?;
228 }
229 }
230 f.write_str("}")?;
231
232 Ok(())
233 };
234
235 if item_out.vectorization != item_rhs.vectorization {
236 format_vec(f, item_out != item_rhs)
237 } else if item_out.elem != item_rhs.elem {
238 if item_out.vectorization > 1 {
239 format_vec(f, true)?;
240 } else {
241 write!(f, "{}({rhs})", item_out.elem)?;
242 }
243 Ok(())
244 } else if rhs.is_const() && item_rhs.vectorization > 1 {
245 write!(f, "reinterpret_cast<{item_out} const&>({rhs})")
247 } else {
248 write!(f, "{rhs}")
249 }
250 }
251
252 fn unroll_vec(
253 f: &mut Formatter<'_>,
254 lhs: &Variable<D>,
255 rhs: &Variable<D>,
256 out: &Variable<D>,
257 ) -> std::fmt::Result {
258 let item_lhs = lhs.item();
259 let out_item = out.item();
260 let out = out.fmt_left();
261
262 for i in 0..item_lhs.vectorization {
263 let lhsi = lhs.index(i);
264 let rhsi = rhs.index(i);
265 write!(f, "{out}[{lhs}] = ")?;
266 Self::format_scalar(f, lhsi, rhsi, out_item)?;
267 f.write_str(";\n")?;
268 }
269
270 Ok(())
271 }
272
273 fn format(
274 f: &mut Formatter<'_>,
275 lhs: &Variable<D>,
276 rhs: &Variable<D>,
277 out: &Variable<D>,
278 ) -> std::fmt::Result {
279 if matches!(out, Variable::LocalMut { .. } | Variable::LocalConst { .. }) {
280 return IndexAssignVector::format(f, lhs, rhs, out);
281 };
282
283 let out_item = out.item();
284
285 if lhs.item().vectorization == 1 {
286 write!(f, "{}[{lhs}] = ", out.fmt_left())?;
287 Self::format_scalar(f, *lhs, *rhs, out_item)?;
288 f.write_str(";\n")
289 } else {
290 Self::unroll_vec(f, lhs, rhs, out)
291 }
292 }
293}
294
295impl<D: Dialect> Binary<D> for Index {
296 fn format(
297 f: &mut Formatter<'_>,
298 lhs: &Variable<D>,
299 rhs: &Variable<D>,
300 out: &Variable<D>,
301 ) -> std::fmt::Result {
302 if matches!(lhs, Variable::LocalMut { .. } | Variable::LocalConst { .. }) {
303 return IndexVector::format(f, lhs, rhs, out);
304 }
305
306 let item_out = out.item();
307 if let Elem::Atomic(inner) = item_out.elem {
308 write!(f, "{inner}* {out} = &{lhs}[{rhs}];")
309 } else {
310 let out = out.fmt_left();
311 write!(f, "{out} = ")?;
312 Self::format_scalar(f, *lhs, *rhs, item_out)?;
313 f.write_str(";\n")
314 }
315 }
316
317 fn format_scalar<Lhs, Rhs>(
318 f: &mut Formatter<'_>,
319 lhs: Lhs,
320 rhs: Rhs,
321 item_out: Item<D>,
322 ) -> std::fmt::Result
323 where
324 Lhs: Component<D>,
325 Rhs: Component<D>,
326 {
327 let item_lhs = lhs.item();
328
329 let format_vec = |f: &mut Formatter<'_>| {
330 writeln!(f, "{item_out}{{")?;
331 for i in 0..item_out.vectorization {
332 write!(f, "{}({lhs}[{rhs}].i_{i}),", item_out.elem)?;
333 }
334 f.write_str("}")?;
335
336 Ok(())
337 };
338
339 if item_out.elem != item_lhs.elem {
340 if item_out.vectorization > 1 {
341 format_vec(f)
342 } else {
343 write!(f, "{}({lhs}[{rhs}])", item_out.elem)
344 }
345 } else {
346 write!(f, "{lhs}[{rhs}]")
347 }
348 }
349}
350
/// Formatter used by `Index` when the indexed operand is a `LocalMut` or `LocalConst`
/// variable.
///
/// It carries no data; `D` only selects the dialect of the associated `format` function.
struct IndexVector<D: Dialect> {
    _dialect: PhantomData<D>,
}

/// Formatter used by `IndexAssign` when the output is a `LocalMut` or `LocalConst`
/// variable.
///
/// It carries no data; `D` only selects the dialect of the associated `format` function.
struct IndexAssignVector<D: Dialect> {
    _dialect: PhantomData<D>,
}
377
378impl<D: Dialect> IndexVector<D> {
379 fn format(
380 f: &mut Formatter<'_>,
381 lhs: &Variable<D>,
382 rhs: &Variable<D>,
383 out: &Variable<D>,
384 ) -> std::fmt::Result {
385 let index = match rhs {
386 Variable::ConstantScalar(value, _elem) => value.as_usize(),
387 _ => {
388 let elem = out.elem();
389 let qualifier = out.const_qualifier();
390 let out = out.fmt_left();
391 return writeln!(
392 f,
393 "{out} = reinterpret_cast<{elem}{qualifier}*>(&{lhs})[{rhs}];"
394 );
395 }
396 };
397
398 let out = out.index(index);
399 let lhs = lhs.index(index);
400
401 let out = out.fmt_left();
402 writeln!(f, "{out} = {lhs};")
403 }
404}
405
406impl<D: Dialect> IndexAssignVector<D> {
407 fn format(
408 f: &mut Formatter<'_>,
409 lhs: &Variable<D>,
410 rhs: &Variable<D>,
411 out: &Variable<D>,
412 ) -> std::fmt::Result {
413 let index = match lhs {
414 Variable::ConstantScalar(value, _) => value.as_usize(),
415 _ => {
416 let elem = out.elem();
417 return writeln!(f, "*(({elem}*)&{out} + {lhs}) = {rhs};");
418 }
419 };
420
421 let out = out.index(index);
422 let rhs = rhs.index(index);
423
424 writeln!(f, "{out} = {rhs};")
425 }
426}