1use std::fmt::Display;
2
3use crate::shared::{Component, Elem, FmtLeft};
4
5use super::{Dialect, Item, Variable};
6
/// A warp-level (plane) instruction, rendered to dialect C++ source by its `Display` impl.
///
/// Every variant writes its result into `out`.
#[derive(Clone, Debug)]
pub enum WarpInstruction<D: Dialect> {
    /// Sum of `input` across the warp, accumulated with `+=` over XOR-shuffle steps.
    ReduceSum {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Product of `input` across the warp, accumulated with `*=` over XOR-shuffle steps.
    ReduceProd {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Maximum of `input` across the warp (`max`, or `__hmax`/`__hmax2` for half types).
    ReduceMax {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Minimum of `input` across the warp (`min`, or `__hmin`/`__hmin2` for half types).
    ReduceMin {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Sets `out` to a boolean that is intended to be true on exactly one active lane:
    /// the one matching the lowest set bit of `__activemask()`.
    Elect {
        out: Variable<D>,
    },
    /// Per-component warp "all" vote, built with the dialect's `warp_all`.
    All {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Per-component warp "any" vote, built with the dialect's `warp_any`.
    Any {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Per-component dialect `warp_shuffle` of `input`, with `id` selecting the source lane.
    Broadcast {
        input: Variable<D>,
        id: Variable<D>,
        out: Variable<D>,
    },
}
42
43impl<D: Dialect> Display for WarpInstruction<D> {
44 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45 match self {
46 WarpInstruction::ReduceSum { input, out } => reduce_operator(f, input, out, "+="),
47 WarpInstruction::ReduceProd { input, out } => reduce_operator(f, input, out, "*="),
48 WarpInstruction::ReduceMax { input, out } => reduce_comparison(f, input, out, "max"),
49 WarpInstruction::ReduceMin { input, out } => reduce_comparison(f, input, out, "min"),
50 WarpInstruction::All { input, out } => reduce_quantifier(f, input, out, D::warp_all),
51 WarpInstruction::Any { input, out } => reduce_quantifier(f, input, out, D::warp_any),
52 WarpInstruction::Broadcast { input, id, out } => reduce_broadcast(f, input, out, id),
53 WarpInstruction::Elect { out } => write!(
54 f,
55 "
56unsigned int mask = __activemask();
57unsigned int leader = __ffs(mask) - 1;
58{out} = threadIdx.x % warpSize == leader;
59 "
60 ),
61 }
62 }
63}
64
65fn reduce_operator<D: Dialect>(
66 f: &mut core::fmt::Formatter<'_>,
67 input: &Variable<D>,
68 out: &Variable<D>,
69 op: &str,
70) -> core::fmt::Result {
71 let in_optimized = input.optimized();
72 let acc_item = in_optimized.item();
73
74 reduce_with_loop(f, input, out, acc_item, |acc, index| {
75 let acc_indexed = maybe_index(acc, index);
76 let shfl_xor = D::warp_shuffle_xor(&acc_indexed, "offset");
77 format!("{acc_indexed} {op} {shfl_xor};")
78 })
79}
80
81fn reduce_comparison<D: Dialect>(
82 f: &mut core::fmt::Formatter<'_>,
83 input: &Variable<D>,
84 out: &Variable<D>,
85 cmp: &str,
86) -> core::fmt::Result {
87 let in_optimized = input.optimized();
88 let acc_item = in_optimized.item();
89 let instruction = match in_optimized.elem() {
90 Elem::F16 | Elem::BF16 => format!("__h{cmp}"),
91 Elem::F162 | Elem::BF162 => format!("__h{cmp}2"),
92 _ => cmp.to_string(),
93 };
94
95 reduce_with_loop(f, input, out, acc_item, |acc, index| {
96 let acc_indexed = maybe_index(acc, index);
97 let shfl_xor = D::warp_shuffle_xor(&acc_indexed, "offset");
98 format!("{acc_indexed} = {instruction}({acc_indexed}, {shfl_xor});")
99 })
100}
101
102fn reduce_broadcast<D: Dialect>(
103 f: &mut core::fmt::Formatter<'_>,
104 input: &Variable<D>,
105 out: &Variable<D>,
106 id: &Variable<D>,
107) -> core::fmt::Result {
108 let rhs = (0..input.item().vectorization)
109 .map(|k| D::warp_shuffle(&format!("{}", input.index(k)), &format!("{id}")))
110 .collect::<Vec<_>>()
111 .join(",");
112 let out_fmt = out.fmt_left();
113 writeln!(f, "{out_fmt} = {{ {rhs} }};")
114}
115
116fn reduce_with_loop<D: Dialect, I: Fn(&Variable<D>, usize) -> String>(
117 f: &mut core::fmt::Formatter<'_>,
118 input: &Variable<D>,
119 out: &Variable<D>,
120 acc_item: Item<D>,
121 instruction: I,
122) -> core::fmt::Result {
123 let acc = Variable::Named {
124 name: "acc",
125 item: acc_item,
126 };
127
128 writeln!(f, "auto plane_{out} = [&]() -> {} {{", out.item())?;
129 writeln!(f, " {} {} = {};", acc_item, acc, cast(input, acc_item))?;
130 writeln!(
131 f,
132 " for (int offset = 1; offset < warpSizeChecked; offset *=2 ) {{"
133 )?;
134 for k in 0..acc_item.vectorization {
135 writeln!(f, " {}", instruction(&acc, k))?;
136 }
137 writeln!(f, " }};")?;
138 writeln!(f, " return {};", cast(&acc, out.item()))?;
139 writeln!(f, "}};")?;
140 writeln!(f, "{} = plane_{}();", out.fmt_left(), out)
141}
142
143fn reduce_quantifier<D: Dialect, Q: Fn(&str) -> String>(
144 f: &mut core::fmt::Formatter<'_>,
145 input: &Variable<D>,
146 out: &Variable<D>,
147 quantifier: Q,
148) -> core::fmt::Result {
149 let rhs = (0..input.item().vectorization)
150 .map(|k| quantifier(&format!("{}", input.index(k))))
151 .collect::<Vec<_>>()
152 .join(",");
153 let out_fmt = out.fmt_left();
154 writeln!(f, "{out_fmt} = {{ {rhs} }};")
155}
156
157fn cast<D: Dialect>(input: &Variable<D>, target: Item<D>) -> String {
158 if target != input.item() {
159 let qualifier = input.const_qualifier();
160 format!("reinterpret_cast<{}{}&>({})", target, qualifier, input)
161 } else {
162 format!("{}", input)
163 }
164}
165
166fn maybe_index<D: Dialect>(var: &Variable<D>, k: usize) -> String {
167 if var.item().vectorization > 1 {
168 format!("{var}.i_{k}")
169 } else {
170 format!("{var}")
171 }
172}