1use crate::shared::FmtLeft;
2
3use super::{
4 binary::*, unary::*, Component, Dialect, Elem, Item, Variable, WarpInstruction, WmmaInstruction,
5};
6use std::{fmt::Display, marker::PhantomData};
7
8#[derive(Debug, Clone)]
9pub struct BinaryInstruction<D: Dialect> {
10 pub lhs: Variable<D>,
11 pub rhs: Variable<D>,
12 pub out: Variable<D>,
13}
14
15#[derive(Debug, Clone)]
16pub struct UnaryInstruction<D: Dialect> {
17 pub input: Variable<D>,
18 pub out: Variable<D>,
19}
20
/// One instruction of the C++-family kernel IR.
///
/// Each variant is turned into source text by this type's `Display`
/// implementation; `Vec<Self>` fields hold the bodies of control-flow constructs.
#[derive(Debug, Clone)]
pub enum Instruction<D: Dialect> {
    /// `out = info[info_offset]`.
    Metadata {
        info_offset: Variable<D>,
        out: Variable<D>,
    },
    /// `out = info[info[info_offset] + dim]`; the inner load presumably yields the
    /// start of a per-dimension table inside `info` (TODO confirm against the caller).
    ExtendedMetadata {
        info_offset: Variable<D>,
        dim: Variable<D>,
        out: Variable<D>,
    },
    /// `out = {input}_length`, reading the length constant emitted alongside a
    /// `Slice`/`CheckedSlice`.
    SliceLength {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Declaration of `var` without an initializer (WMMA fragments are declared
    /// with their fragment type instead of their item type).
    DeclareVariable {
        var: Variable<D>,
    },
    // Arithmetic; each is rendered by the operator type of the same name.
    Modulo(BinaryInstruction<D>),
    /// Floored remainder `lhs - rhs * floor(lhs / rhs)`.
    Remainder(BinaryInstruction<D>),
    Add(BinaryInstruction<D>),
    /// Fused multiply-add `fma(a, b, c)`, expanded lane by lane for vectorized outputs.
    Fma {
        a: Variable<D>,
        b: Variable<D>,
        c: Variable<D>,
        out: Variable<D>,
    },
    Div(BinaryInstruction<D>),
    Mul(BinaryInstruction<D>),
    Sub(BinaryInstruction<D>),
    /// Rendered by the `Index` operator type.
    Index(BinaryInstruction<D>),
    /// Rendered by the `IndexAssign` operator type.
    IndexAssign(BinaryInstruction<D>),
    /// Bounds-checked read of `lhs[rhs]` that yields a zero value when `rhs >= len`.
    /// Atomic outputs instead bind a pointer to `lhs[rhs]` without checking the bound.
    CheckedIndex {
        len: Variable<D>,
        lhs: Variable<D>,
        rhs: Variable<D>,
        out: Variable<D>,
    },
    /// Rendered by the `Assign` operator type.
    Assign(UnaryInstruction<D>),
    /// `for (i = start; i < end; ++i) { instructions }`; `inclusive` switches the
    /// comparison to `<=`, and `step`, when present, replaces `++i` with `i += step`.
    RangeLoop {
        i: Variable<D>,
        start: Variable<D>,
        end: Variable<D>,
        step: Option<Variable<D>>,
        inclusive: bool,
        instructions: Vec<Self>,
    },
    /// Builds the vector `out` from the lane values in `inputs`.
    VecInit {
        inputs: Vec<Variable<D>>,
        out: Variable<D>,
    },
    /// `while (true) { instructions }`; presumably exited via `Break`/`Return`.
    Loop {
        instructions: Vec<Self>,
    },
    /// `if (cond) { instructions }`.
    If {
        cond: Variable<D>,
        instructions: Vec<Self>,
    },
    /// `if (cond) { instructions_if } else { instructions_else }`.
    IfElse {
        cond: Variable<D>,
        instructions_if: Vec<Self>,
        instructions_else: Vec<Self>,
    },
    /// `out = cond ? then : or_else`, expanded lane by lane when `cond` is vectorized
    /// or the operand items differ from `out`'s item.
    Select {
        cond: Variable<D>,
        then: Variable<D>,
        or_else: Variable<D>,
        out: Variable<D>,
    },
    /// `switch (value)`: every case block is terminated with `break;` and
    /// `instructions_default` forms the `default:` block.
    Switch {
        value: Variable<D>,
        instructions_default: Vec<Self>,
        instructions_cases: Vec<(Variable<D>, Vec<Self>)>,
    },
    /// Pointer sub-slice of `input` beginning at `start`, emitted together with a
    /// `{out}_length` constant derived from `end`.
    Slice {
        input: Variable<D>,
        start: Variable<D>,
        end: Variable<D>,
        out: Variable<D>,
    },
    /// Like `Slice`, but `end` is clamped to `len` when computing the length.
    CheckedSlice {
        input: Variable<D>,
        start: Variable<D>,
        end: Variable<D>,
        out: Variable<D>,
        len: Variable<D>,
    },
    /// `return;`
    Return,
    /// `break;`
    Break,
    // Comparisons, rendered by the operator type of the same name.
    Equal(BinaryInstruction<D>),
    NotEqual(BinaryInstruction<D>),
    Lower(BinaryInstruction<D>),
    Greater(BinaryInstruction<D>),
    LowerEqual(BinaryInstruction<D>),
    GreaterEqual(BinaryInstruction<D>),
    // Element-wise math, bitwise and logical operations, rendered by the operator
    // type of the same name.
    Erf(UnaryInstruction<D>),
    BitwiseOr(BinaryInstruction<D>),
    BitwiseAnd(BinaryInstruction<D>),
    BitwiseXor(BinaryInstruction<D>),
    CountBits(UnaryInstruction<D>),
    ReverseBits(UnaryInstruction<D>),
    ShiftLeft(BinaryInstruction<D>),
    ShiftRight(BinaryInstruction<D>),
    Abs(UnaryInstruction<D>),
    Exp(UnaryInstruction<D>),
    Log(UnaryInstruction<D>),
    Log1p(UnaryInstruction<D>),
    Cos(UnaryInstruction<D>),
    Sin(UnaryInstruction<D>),
    Tanh(UnaryInstruction<D>),
    Powf(BinaryInstruction<D>),
    Sqrt(UnaryInstruction<D>),
    Min(BinaryInstruction<D>),
    Max(BinaryInstruction<D>),
    Not(UnaryInstruction<D>),
    Or(BinaryInstruction<D>),
    And(BinaryInstruction<D>),
    /// `out = max(min_value, min(max_value, input))`, using the `__hmin`/`__hmax`
    /// family of intrinsics for half-precision outputs.
    Clamp {
        input: Variable<D>,
        min_value: Variable<D>,
        max_value: Variable<D>,
        out: Variable<D>,
    },
    /// `__syncthreads();`
    SyncThreads,
    /// `__threadfence();`
    ThreadFence,
    Round(UnaryInstruction<D>),
    Ceil(UnaryInstruction<D>),
    Floor(UnaryInstruction<D>),
    /// Warp-level operation printed through `WarpInstruction`'s `Display`.
    /// NOTE(review): the variant is spelled `Wrap` although it holds a `WarpInstruction`.
    Wrap(WarpInstruction<D>),
    /// Printed through `WmmaInstruction`'s `Display`.
    Wmma(WmmaInstruction<D>),
    /// Reinterprets the bits of `input` as `out`'s element type. Only the element
    /// pairs handled by `Display` are supported; any other pair panics while formatting.
    Bitcast(UnaryInstruction<D>),
    /// Atomic read, emitted as `atomicAdd(input, 0)`.
    AtomicLoad(UnaryInstruction<D>),
    /// Atomic write, emitted as `atomicExch(out, input)` with the old value discarded.
    AtomicStore(UnaryInstruction<D>),
    /// `out = atomicExch(lhs, rhs)`.
    AtomicSwap(BinaryInstruction<D>),
    /// `out = atomicAdd(lhs, rhs)`; `i64` goes through a `u64` reinterpretation.
    AtomicAdd(BinaryInstruction<D>),
    /// `out = atomicSub(lhs, rhs)` for 32-bit integers, otherwise an `atomicAdd` of
    /// the negated value.
    AtomicSub(BinaryInstruction<D>),
    /// `out = atomicMax(lhs, rhs)`.
    AtomicMax(BinaryInstruction<D>),
    /// `out = atomicMin(lhs, rhs)`.
    AtomicMin(BinaryInstruction<D>),
    /// `out = atomicAnd(lhs, rhs)`.
    AtomicAnd(BinaryInstruction<D>),
    /// `out = atomicOr(lhs, rhs)`.
    AtomicOr(BinaryInstruction<D>),
    /// `out = atomicXor(lhs, rhs)`.
    AtomicXor(BinaryInstruction<D>),
    /// `out = atomicCAS(input, cmp, val)`.
    AtomicCAS {
        input: Variable<D>,
        cmp: Variable<D>,
        val: Variable<D>,
        out: Variable<D>,
    },
    /// `out = -input`.
    Neg(UnaryInstruction<D>),
    /// Square root of the sum of the squared lanes of `input`.
    Magnitude(UnaryInstruction<D>),
    /// `input` divided lane-wise by its magnitude.
    Normalize(UnaryInstruction<D>),
    /// Sum of the lane-wise products of `lhs` and `rhs`.
    Dot(BinaryInstruction<D>),
    /// `out[out_index] = input[in_index]`.
    Copy {
        input: Variable<D>,
        in_index: Variable<D>,
        out: Variable<D>,
        out_index: Variable<D>,
    },
    /// `len` consecutive element copies, unrolled when the source is generated.
    CopyBulk {
        input: Variable<D>,
        in_index: Variable<D>,
        out: Variable<D>,
        out_index: Variable<D>,
        len: u32,
    },
    /// `printf(format_string, args...)`; tab, newline and carriage-return characters
    /// in the format string are written as escape sequences.
    Printf {
        format_string: String,
        args: Vec<Variable<D>>,
    },
    /// Source comment: `// content` for single-line content, `/* content */` otherwise.
    Comment {
        content: String,
    },
}
193
194impl<D: Dialect> Display for Instruction<D> {
195 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
196 match self {
197 Instruction::Return => f.write_str("return;"),
198 Instruction::Break => f.write_str("break;"),
199 Instruction::DeclareVariable { var } => match var {
200 Variable::WmmaFragment { frag, .. } => writeln!(f, "{frag} {var};"),
201 _ => {
202 let item = var.item();
203 writeln!(f, "{item} {var};")
204 }
205 },
206 Instruction::Add(it) => Add::format(f, &it.lhs, &it.rhs, &it.out),
207 Instruction::Slice {
208 input,
209 start,
210 end,
211 out,
212 } => {
213 let item = out.item();
214 writeln!(f, "const uint {out}_length = {end};")?;
215 writeln!(f, "{item} *{out} = {input} + {start};")
216 }
217 Instruction::CheckedSlice {
218 input,
219 start,
220 end,
221 out,
222 len,
223 } => {
224 let item = out.item();
225 writeln!(f, "const uint {out}_length = min({len}, {end}) - {start};")?;
226 writeln!(f, "{item} *{out} = {input} + {start};")
227 }
228 Instruction::Mul(it) => Mul::format(f, &it.lhs, &it.rhs, &it.out),
229 Instruction::Div(it) => Div::format(f, &it.lhs, &it.rhs, &it.out),
230 Instruction::Sub(it) => Sub::format(f, &it.lhs, &it.rhs, &it.out),
231 Instruction::Modulo(inst) => Modulo::format(f, &inst.lhs, &inst.rhs, &inst.out),
232 Instruction::BitwiseOr(it) => BitwiseOr::format(f, &it.lhs, &it.rhs, &it.out),
233 Instruction::BitwiseAnd(it) => BitwiseAnd::format(f, &it.lhs, &it.rhs, &it.out),
234 Instruction::BitwiseXor(it) => BitwiseXor::format(f, &it.lhs, &it.rhs, &it.out),
235 Instruction::CountBits(it) => CountBits::format(f, &it.input, &it.out),
236 Instruction::ReverseBits(it) => ReverseBits::format(f, &it.input, &it.out),
237 Instruction::ShiftLeft(it) => ShiftLeft::format(f, &it.lhs, &it.rhs, &it.out),
238 Instruction::ShiftRight(it) => ShiftRight::format(f, &it.lhs, &it.rhs, &it.out),
239 Instruction::Index(it) => Index::format(f, &it.lhs, &it.rhs, &it.out),
240 Instruction::IndexAssign(it) => IndexAssign::format(f, &it.lhs, &it.rhs, &it.out),
241 Instruction::CheckedIndex { len, lhs, rhs, out } => {
242 let item_out = out.item();
243 if let Elem::Atomic(inner) = item_out.elem {
244 write!(f, "{inner}* {out} = &{lhs}[{rhs}];")
245 } else {
246 let out = out.fmt_left();
247 write!(f, "{out} = ({rhs} < {len}) ? ")?;
248 Index::format_scalar(f, *lhs, *rhs, item_out)?;
249 if item_out.vectorization == 1 {
250 writeln!(f, " : {item_out}(0);")
251 } else {
252 writeln!(f, " : {item_out}{{}};")
253 }
254 }
255 }
256 Instruction::Copy {
257 input,
258 in_index,
259 out,
260 out_index,
261 } => {
262 writeln!(f, "{out}[{out_index}] = {input}[{in_index}];")
263 }
264 Instruction::CopyBulk {
265 input,
266 in_index,
267 out,
268 out_index,
269 len,
270 } => {
271 for i in 0..*len {
272 writeln!(f, "{out}[{out_index} + {i}] = {input}[{in_index} + {i}];")?;
273 }
274 Ok(())
275 }
276 Instruction::Assign(it) => Assign::format(f, &it.input, &it.out),
277 Instruction::RangeLoop {
278 i,
279 start,
280 end,
281 step,
282 inclusive,
283 instructions,
284 } => {
285 let increment = step
286 .map(|step| format!("{i} += {step}"))
287 .unwrap_or_else(|| format!("++{i}"));
288 let cmp = if *inclusive { "<=" } else { "<" };
289 let i_ty = i.item();
290
291 write!(
292 f,
293 "
294for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{
295"
296 )?;
297 for instruction in instructions {
298 write!(f, "{instruction}")?;
299 }
300
301 f.write_str("}\n")
302 }
303
304 Instruction::Loop { instructions } => {
305 writeln!(f, "while (true) {{")?;
306 for i in instructions {
307 write!(f, "{i}")?;
308 }
309 f.write_str("}\n")
310 }
311 Instruction::If { cond, instructions } => {
312 writeln!(f, "if ({cond}) {{")?;
313 for i in instructions {
314 write!(f, "{i}")?;
315 }
316 f.write_str("}\n")
317 }
318 Instruction::IfElse {
319 cond,
320 instructions_if,
321 instructions_else,
322 } => {
323 writeln!(f, "if ({cond}) {{")?;
324 for i in instructions_if {
325 write!(f, "{i}")?;
326 }
327 f.write_str("} else {\n")?;
328 for i in instructions_else {
329 write!(f, "{i}")?;
330 }
331 f.write_str("}\n")
332 }
333 Instruction::Select {
334 cond,
335 then,
336 or_else,
337 out,
338 } => {
339 let item_or_else = or_else.item();
340 let item_then = then.item();
341 let item_out = out.item();
342
343 let vf_then = item_then.vectorization;
344 let vf_or_else = item_or_else.vectorization;
345 let vf_out = item_out.vectorization;
346 let vf_cond = cond.item().vectorization;
347
348 let item_out = out.item();
349 let cond_elem = cond.item().elem;
350 let out = out.fmt_left();
351
352 let should_broadcast =
353 vf_cond > 1 || item_out != item_or_else || item_out != item_then;
354
355 if should_broadcast {
356 let vf = usize::max(vf_cond, vf_out);
357 let vf = usize::max(vf, vf_then);
358 let vf = usize::max(vf, vf_or_else);
359
360 writeln!(f, "{out} = {item_out} {{")?;
361 for i in 0..vf {
362 let theni = then.index(i);
363 let or_elsei = or_else.index(i);
364 let condi = cond.index(i);
365 let condi = EnsureBoolArg {
366 var: &condi,
367 elem: &cond_elem,
368 };
369
370 writeln!(f, "({condi}) ? {theni} : {or_elsei},")?;
371 }
372
373 writeln!(f, "}};")
374 } else {
375 let cond = EnsureBoolArg {
376 var: &cond,
377 elem: &cond_elem,
378 };
379 writeln!(f, "{out} = ({cond}) ? {then} : {or_else};")
380 }
381 }
382 Instruction::Switch {
383 value,
384 instructions_default,
385 instructions_cases,
386 } => {
387 writeln!(f, "switch({value}) {{")?;
388 for (value, block) in instructions_cases {
389 write!(f, "case {value}:\n{{\n")?;
390 for i in block {
391 i.fmt(f)?;
392 }
393 f.write_str("break;\n}\n")?;
394 }
395 f.write_str("default:\n{")?;
396 for i in instructions_default {
397 i.fmt(f)?;
398 }
399 f.write_str("}\n}\n")
400 }
401 Instruction::Metadata { info_offset, out } => {
402 let out = out.fmt_left();
403 writeln!(f, "{out} = info[{info_offset}];")
404 }
405 Instruction::ExtendedMetadata {
406 info_offset,
407 dim,
408 out,
409 } => {
410 let out = out.fmt_left();
411 writeln!(f, "{out} = info[info[{info_offset}] + {dim}];")
412 }
413 Instruction::Equal(it) => Equal::format(f, &it.lhs, &it.rhs, &it.out),
414 Instruction::NotEqual(it) => NotEqual::format(f, &it.lhs, &it.rhs, &it.out),
415 Instruction::Lower(it) => Lower::format(f, &it.lhs, &it.rhs, &it.out),
416 Instruction::Greater(it) => Greater::format(f, &it.lhs, &it.rhs, &it.out),
417 Instruction::LowerEqual(it) => LowerEqual::format(f, &it.lhs, &it.rhs, &it.out),
418 Instruction::GreaterEqual(it) => GreaterEqual::format(f, &it.lhs, &it.rhs, &it.out),
419 Instruction::Erf(it) => Erf::format(f, &it.input, &it.out),
420 Instruction::Abs(it) => Abs::format(f, &it.input, &it.out),
421 Instruction::Exp(it) => Exp::format(f, &it.input, &it.out),
422 Instruction::Log(it) => Log::format(f, &it.input, &it.out),
423 Instruction::Log1p(it) => Log1p::format(f, &it.input, &it.out),
424 Instruction::Cos(it) => Cos::format(f, &it.input, &it.out),
425 Instruction::Sin(it) => Sin::format(f, &it.input, &it.out),
426 Instruction::Tanh(it) => Tanh::format(f, &it.input, &it.out),
427 Instruction::Powf(it) => Powf::format(f, &it.lhs, &it.rhs, &it.out),
428 Instruction::Sqrt(it) => Sqrt::format(f, &it.input, &it.out),
429 Instruction::Max(it) => Max::format(f, &it.lhs, &it.rhs, &it.out),
430 Instruction::Min(it) => Min::format(f, &it.lhs, &it.rhs, &it.out),
431 Instruction::Not(it) => Not::format(f, &it.input, &it.out),
432 Instruction::Or(it) => Or::format(f, &it.lhs, &it.rhs, &it.out),
433 Instruction::And(it) => And::format(f, &it.lhs, &it.rhs, &it.out),
434 Instruction::Clamp {
435 input,
436 min_value,
437 max_value,
438 out,
439 } => Clamp::format(f, input, min_value, max_value, out),
440 Instruction::SyncThreads => f.write_str("__syncthreads();\n"),
441 Instruction::ThreadFence => f.write_str("__threadfence();\n"),
442 Instruction::Round(it) => Round::format(f, &it.input, &it.out),
443 Instruction::Ceil(it) => Ceil::format(f, &it.input, &it.out),
444 Instruction::Floor(it) => Floor::format(f, &it.input, &it.out),
445 Instruction::SliceLength { input, out } => {
446 let out = out.fmt_left();
447 writeln!(f, "{out} = {input}_length;")
448 }
449 Instruction::Wrap(it) => write!(f, "{it}"),
450 Instruction::Fma { a, b, c, out } => Fma::format(f, a, b, c, out),
451 Instruction::Wmma(it) => write!(f, "{it}"),
452 Instruction::Bitcast(UnaryInstruction { input, out }) => {
453 let qualifier = out.const_qualifier();
454 let out_elem = out.elem();
455 let out = out.fmt_left();
456
457 match (input.elem(), out_elem) {
458 (Elem::F32, Elem::I32) => {
459 writeln!(f, "{out} = __float_as_int({input});")
460 }
461 (Elem::F32, Elem::U32) => {
462 writeln!(f, "{out} = __float_as_uint({input});")
463 }
464 (Elem::F16, Elem::I32) => {
465 writeln!(f, "{out} = __half_as_short({input});")
466 }
467 (Elem::F16, Elem::U32) => {
468 writeln!(f, "{out} = __half_as_ushort({input});")
469 }
470 (Elem::BF16, Elem::I32) => {
471 writeln!(f, "{out} = __bfloat16_as_short({input});")
472 }
473 (Elem::BF16, Elem::U32) => {
474 writeln!(f, "{out} = __bfloat16_as_ushort({input});")
475 }
476 (Elem::I32, Elem::F32) => {
477 writeln!(f, "{out} = __int_as_float({input});")
478 }
479 (Elem::I32, Elem::F16) => {
480 writeln!(f, "{out} = __short_as_half({input});")
481 }
482 (Elem::I32, Elem::BF16) => {
483 writeln!(f, "{out} = __short_as_bfloat16({input});")
484 }
485 (Elem::U32, Elem::F32) => {
486 writeln!(f, "{out} = __uint_as_float({input});")
487 }
488 (Elem::U32, Elem::F16) => {
489 writeln!(f, "{out} = __ushort_as_half({input});")
490 }
491 (Elem::U32, Elem::BF16) => {
492 writeln!(f, "{out} = __ushort_as_bfloat16({input});")
493 }
494 (Elem::I32, Elem::U32) => {
495 writeln!(f, "{out} = reinterpret_cast<uint{qualifier}&>({input});")
496 }
497 elem => panic!("Unsupported type for bitcasting {elem:?}"),
498 }
499 }
500 Instruction::AtomicCAS {
501 input,
502 cmp,
503 val,
504 out,
505 } => {
506 let out = out.fmt_left();
507 writeln!(f, "{out} = atomicCAS({input}, {cmp}, {val});")
508 }
509 Instruction::AtomicSwap(BinaryInstruction { lhs, rhs, out }) => {
510 let out = out.fmt_left();
511 writeln!(f, "{out} = atomicExch({lhs}, {rhs});")
512 }
513 Instruction::AtomicAdd(BinaryInstruction { lhs, rhs, out }) => {
514 let out = out.fmt_left();
515 match rhs.elem() {
516 Elem::I64 => {
517 writeln!(
518 f,
519 "{out} = atomicAdd(reinterpret_cast<{uint}*>({lhs}), {uint}({rhs}));",
520 uint = Elem::<D>::U64
521 )
522 }
523 _ => writeln!(f, "{out} = atomicAdd({lhs}, {rhs});"),
524 }
525 }
526 Instruction::AtomicSub(BinaryInstruction { lhs, rhs, out }) => {
527 let out = out.fmt_left();
528 match rhs.elem() {
529 Elem::U32 | Elem::I32 => {
530 writeln!(f, "{out} = atomicSub({lhs}, {rhs});")
531 }
532 Elem::U64 => {
533 writeln!(f, "{out} = atomicAdd({lhs}, -{rhs});",)
534 }
535 Elem::I64 => {
536 writeln!(
537 f,
538 "{out} = atomicAdd(reinterpret_cast<{uint}*>({lhs}), {uint}(-{rhs}));",
539 uint = Elem::<D>::U64
540 )
541 }
542 _ => writeln!(f, "{out} = atomicAdd({lhs}, -{rhs});"),
543 }
544 }
545 Instruction::AtomicMax(BinaryInstruction { lhs, rhs, out }) => {
546 let out = out.fmt_left();
547 writeln!(f, "{out} = atomicMax({lhs}, {rhs});")
548 }
549 Instruction::AtomicMin(BinaryInstruction { lhs, rhs, out }) => {
550 let out = out.fmt_left();
551 writeln!(f, "{out} = atomicMin({lhs}, {rhs});")
552 }
553 Instruction::AtomicAnd(BinaryInstruction { lhs, rhs, out }) => {
554 let out = out.fmt_left();
555 writeln!(f, "{out} = atomicAnd({lhs}, {rhs});")
556 }
557 Instruction::AtomicOr(BinaryInstruction { lhs, rhs, out }) => {
558 let out = out.fmt_left();
559 writeln!(f, "{out} = atomicOr({lhs}, {rhs});")
560 }
561 Instruction::AtomicXor(BinaryInstruction { lhs, rhs, out }) => {
562 let out = out.fmt_left();
563 writeln!(f, "{out} = atomicXor({lhs}, {rhs});")
564 }
565 Instruction::AtomicLoad(UnaryInstruction { input, out }) => {
566 let out = out.fmt_left();
567 writeln!(f, "{out} = atomicAdd({input}, 0);")
568 }
569 Instruction::AtomicStore(UnaryInstruction { input, out }) => {
570 let out = out.fmt_left();
571 writeln!(f, "atomicExch({out}, {input});")
572 }
573 Instruction::Remainder(inst) => Remainder::format(f, &inst.lhs, &inst.rhs, &inst.out),
574 Instruction::Neg(UnaryInstruction { input, out }) => {
575 let out = out.fmt_left();
576 writeln!(f, "{out} = -{input};")
577 }
578 Instruction::Normalize(inst) => Normalize::format(f, &inst.input, &inst.out),
579 Instruction::Magnitude(inst) => Magnitude::format(f, &inst.input, &inst.out),
580 Instruction::Dot(inst) => Dot::format(f, &inst.lhs, &inst.rhs, &inst.out),
581 Instruction::VecInit { inputs, out } => {
582 let item = out.item();
583 let inputs = inputs
584 .iter()
585 .map(|input| format!("{input}"))
586 .collect::<Vec<_>>();
587 let out = out.fmt_left();
588 writeln!(f, "{out} = {item}{{{}}};", inputs.join(","))
589 }
590 Instruction::Printf {
591 format_string,
592 args,
593 } => {
594 let format_string = escape_string(format_string);
595 let args = args.iter().map(|arg| format!("{arg}")).collect::<Vec<_>>();
596 let args = match args.is_empty() {
597 true => "".to_string(),
598 false => format!(", {}", args.join(",")),
599 };
600 writeln!(f, "printf(\"{format_string}\"{args});")
601 }
602 Instruction::Comment { content } => {
603 if content.contains('\n') {
604 writeln!(f, "/* {content} */")
605 } else {
606 writeln!(f, "// {content}")
607 }
608 }
609 }
610 }
611}
612
/// Escapes `format_string` so it can be placed verbatim between the double quotes
/// of the generated `printf("...")` call.
///
/// Backslashes and double quotes are escaped so they can neither start an unintended
/// C escape sequence nor terminate the literal early. Tab, newline and carriage
/// return are rewritten as `\t`, `\n` and `\r`. `%` is deliberately left untouched
/// because it introduces the `printf` conversion specifiers.
///
/// A single pass over the characters means each character is escaped exactly once.
/// Chained `replace` calls could double-escape the backslashes they introduce.
fn escape_string(format_string: &str) -> String {
    let mut escaped = String::with_capacity(format_string.len());
    for c in format_string.chars() {
        match c {
            '\\' => escaped.push_str("\\\\"),
            '"' => escaped.push_str("\\\""),
            '\t' => escaped.push_str("\\t"),
            '\n' => escaped.push_str("\\n"),
            '\r' => escaped.push_str("\\r"),
            other => escaped.push(other),
        }
    }
    escaped
}
619
620struct Fma<D: Dialect> {
621 _dialect: PhantomData<D>,
622}
623
624impl<D: Dialect> Fma<D> {
625 fn format(
626 f: &mut core::fmt::Formatter<'_>,
627 a: &Variable<D>,
628 b: &Variable<D>,
629 c: &Variable<D>,
630 out: &Variable<D>,
631 ) -> core::fmt::Result {
632 let out_item = out.item();
633 let num = out_item.vectorization;
634
635 let out = out.fmt_left();
636 if num == 1 {
637 writeln!(f, "{out} = fma({a}, {b}, {c});")
638 } else {
639 writeln!(f, "{out} = {out_item}{{")?;
640
641 for i in 0..num {
642 let ai = a.index(i);
643 let bi = b.index(i);
644 let ci = c.index(i);
645
646 writeln!(f, "fma({ai}, {bi}, {ci}),")?;
647 }
648 f.write_str("};\n")
649 }
650 }
651}
652
653struct Clamp<D: Dialect> {
654 _dialect: PhantomData<D>,
655}
656
657impl<D: Dialect> Clamp<D> {
658 fn format(
659 f: &mut core::fmt::Formatter<'_>,
660 input: &Variable<D>,
661 min_value: &Variable<D>,
662 max_value: &Variable<D>,
663 out: &Variable<D>,
664 ) -> core::fmt::Result {
665 let input = input.optimized();
666 let min_value = min_value.optimized();
667 let max_value = max_value.optimized();
668 let out = out.optimized();
669 let out_item = out.item();
670 let num = out_item.vectorization;
671
672 let (min, max) = match out.elem() {
673 Elem::F16 | Elem::BF16 => ("__hmin", "__hmax"),
674 Elem::F162 | Elem::BF162 => ("__hmin2", "__hmax2"),
675 _ => ("min", "max"),
676 };
677
678 let out = out.fmt_left();
679 if num == 1 {
680 writeln!(
681 f,
682 "{out} = {max}({min_value}, {min}({max_value}, {input}));"
683 )
684 } else {
685 writeln!(f, "{out} = {out_item}{{")?;
686 for i in 0..num {
687 let inputi = input.index(i);
688 let mini = min_value.index(i);
689 let maxi = max_value.index(i);
690
691 writeln!(f, "{max}({mini}, {min}({maxi}, {inputi})),")?;
692 }
693
694 f.write_str("};\n")
695 }
696 }
697}
698
699struct Remainder<D: Dialect> {
700 _dialect: PhantomData<D>,
701}
702
703impl<D: Dialect> Remainder<D> {
704 fn format(
705 f: &mut core::fmt::Formatter<'_>,
706 lhs: &Variable<D>,
707 rhs: &Variable<D>,
708 out: &Variable<D>,
709 ) -> core::fmt::Result {
710 let floor = |elem| match elem {
711 Elem::F16 | Elem::BF16 => "hfloor",
712 Elem::F162 | Elem::BF162 => "h2floor",
713 _ => "floor",
714 };
715
716 if out.item().vectorization == 1 {
717 let floor = floor(out.elem());
718
719 let out = out.fmt_left();
720 return writeln!(f, "{out} = {lhs} - {rhs} * {floor}({lhs} / {rhs});");
721 }
722
723 let optimized = Variable::optimized_args([*lhs, *rhs, *out]);
724 let [lhs, rhs, out_optimized] = optimized.args;
725
726 let item_out_original = out.item();
727 let item_out_optimized = out_optimized.item();
728
729 let index = match optimized.optimization_factor {
730 Some(factor) => item_out_original.vectorization / factor,
731 None => item_out_optimized.vectorization,
732 };
733
734 let floor = floor(*item_out_optimized.elem());
735
736 let mut write_op =
737 |lhs: &Variable<D>, rhs: &Variable<D>, out: &Variable<D>, item_out: Item<D>| {
738 let out = out.fmt_left();
739 writeln!(f, "{out} = {item_out}{{")?;
740 for i in 0..index {
741 let lhsi = lhs.index(i);
742 let rhsi = rhs.index(i);
743
744 writeln!(f, "{lhsi} - {rhsi} * {floor}({lhsi} / {rhsi})")?;
745 f.write_str(", ")?;
746 }
747
748 f.write_str("};\n")
749 };
750
751 if item_out_original == item_out_optimized {
752 write_op(&lhs, &rhs, out, item_out_optimized)
753 } else {
754 let out_tmp = Variable::tmp(item_out_optimized);
755
756 write_op(&lhs, &rhs, &out_tmp, item_out_optimized)?;
757
758 let qualifier = out.const_qualifier();
759 let out = out.fmt_left();
760
761 writeln!(
762 f,
763 "{out} = reinterpret_cast<{item_out_original}{qualifier}&>({out_tmp});\n"
764 )?;
765
766 Ok(())
767 }
768 }
769}
770
771struct Magnitude<D: Dialect> {
772 _dialect: PhantomData<D>,
773}
774
775impl<D: Dialect> Magnitude<D> {
776 fn format(
777 f: &mut core::fmt::Formatter<'_>,
778 input: &Variable<D>,
779 out: &Variable<D>,
780 ) -> core::fmt::Result {
781 let num = input.item().vectorization;
782 let elem = input.elem();
783
784 let mag = format!("{out}_mag");
785
786 writeln!(f, "{} {mag} = 0.0;", out.item())?;
787
788 for i in 0..num {
789 let input_i = input.index(i);
790 writeln!(f, "{mag} += {input_i} * {input_i};")?;
791 }
792
793 let out = out.fmt_left();
794 write!(f, "{out} = ")?;
795 Sqrt::format_unary(f, &mag, elem)?;
796 f.write_str(";\n")
797 }
798}
799
800struct Normalize<D: Dialect> {
801 _dialect: PhantomData<D>,
802}
803
804impl<D: Dialect> Normalize<D> {
805 fn format(
806 f: &mut core::fmt::Formatter<'_>,
807 input: &Variable<D>,
808 out: &Variable<D>,
809 ) -> core::fmt::Result {
810 let num = input.item().vectorization;
811 let elem = input.elem();
812 let norm = format!("{out}_norm");
813
814 let out_item = out.item();
815 let out = out.fmt_left();
816 writeln!(f, "{elem} {norm} = 0.0;")?;
817
818 for i in 0..num {
819 let input_i = input.index(i);
820 writeln!(f, "{norm} += {input_i} * {input_i};")?;
821 }
822
823 write!(f, "{norm} = ")?;
824 Sqrt::format_unary(f, &norm, elem)?;
825 f.write_str(";\n")?;
826
827 if num == 1 {
828 writeln!(f, "{out} = {input} / {norm};")
829 } else {
830 write!(f, "{out} = {out_item}{{")?;
831 for i in 0..num {
832 let input_i = input.index(i);
833
834 writeln!(f, "{input_i} / {norm},")?;
835 }
836
837 f.write_str("};\n")
838 }
839 }
840}
841
842struct Dot<D: Dialect> {
843 _dialect: PhantomData<D>,
844}
845
846impl<D: Dialect> Dot<D> {
847 fn format(
848 f: &mut core::fmt::Formatter<'_>,
849 lhs: &Variable<D>,
850 rhs: &Variable<D>,
851 out: &Variable<D>,
852 ) -> core::fmt::Result {
853 let num = lhs.item().vectorization;
854
855 let muls = (0..num)
856 .map(|i| {
857 let lhs_i = lhs.index(i);
858 let rhs_i = rhs.index(i);
859 format!("{lhs_i} * {rhs_i}")
860 })
861 .collect::<Vec<_>>();
862
863 let out = out.fmt_left();
864 writeln!(f, "{out} = {};", muls.join(" + "))
865 }
866}
867
868struct EnsureBoolArg<'a, V: Display, D: Dialect> {
869 var: &'a V,
870 elem: &'a Elem<D>,
871}
872
873impl<V: Display, D: Dialect> Display for EnsureBoolArg<'_, V, D> {
874 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
875 if self.elem != &Elem::Bool {
876 write!(f, "bool({})", self.var)
877 } else {
878 write!(f, "{}", self.var)
879 }
880 }
881}