1use std::fmt::Display;
2
3use crate::shared::{Component, FmtLeft};
4
5use super::{Dialect, IndexedVariable, Item, Variable};
6
/// A plane-level (warp/wavefront) collective instruction.
///
/// Each variant is lowered to dialect source text by its `Display`
/// implementation; every variant writes its result into `out`.
#[derive(Clone, Debug)]
pub enum WarpInstruction<D: Dialect> {
    /// Sum of `input` over the plane (xor-shuffle steps combined with `+=`).
    ReduceSum {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Inclusive prefix sum of `input` over the plane lanes.
    InclusiveSum {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Exclusive prefix sum of `input`; lane 0 receives `0`.
    ExclusiveSum {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Product of `input` over the plane (xor-shuffle steps combined with `*=`).
    ReduceProd {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Inclusive prefix product of `input` over the plane lanes.
    InclusiveProd {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Exclusive prefix product of `input`; lane 0 receives `1`.
    ExclusiveProd {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Maximum of `input` over the plane, using the dialect's max function.
    ReduceMax {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Minimum of `input` over the plane, using the dialect's min function.
    ReduceMin {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Sets `out` on the lowest-numbered active lane of the plane (the "leader").
    Elect {
        out: Variable<D>,
    },
    /// Per-component "all lanes" test of `input` via the dialect's warp-all primitive.
    All {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Per-component "any lane" test of `input` via the dialect's warp-any primitive.
    Any {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Ballot of a scalar (non-vectorized) `input`; the ballot result fills the
    /// first component of `out` and the next three components are zero.
    Ballot {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Every component of `out` takes the matching component of `input` as
    /// shuffled from `id` (presumably the source lane index — confirm in the dialect).
    Broadcast {
        input: Variable<D>,
        id: Variable<D>,
        out: Variable<D>,
    },
}
62
63impl<D: Dialect> Display for WarpInstruction<D> {
64 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65 match self {
66 WarpInstruction::ReduceSum { input, out } => reduce_operator(f, input, out, "+="),
67 WarpInstruction::ReduceProd { input, out } => reduce_operator(f, input, out, "*="),
68 WarpInstruction::ReduceMax { input, out } => {
69 reduce_comparison(f, input, out, D::compile_instruction_max_function_name)
70 }
71 WarpInstruction::ReduceMin { input, out } => {
72 reduce_comparison(f, input, out, D::compile_instruction_min_function_name)
73 }
74 WarpInstruction::All { input, out } => {
75 reduce_quantifier(f, input, out, D::compile_warp_all::<IndexedVariable<D>>)
76 }
77 WarpInstruction::Any { input, out } => {
78 reduce_quantifier(f, input, out, D::compile_warp_any::<IndexedVariable<D>>)
79 }
80 WarpInstruction::Ballot { input, out } => {
81 assert_eq!(
82 input.item().vectorization,
83 1,
84 "Ballot can't support vectorized input"
85 );
86 let out_fmt = out.fmt_left();
87 write!(
88 f,
89 "
90{out_fmt} = {{ "
91 )?;
92 D::compile_warp_ballot(f, input, out.item().elem())?;
93 writeln!(f, ", 0, 0, 0 }};")
94 }
95 WarpInstruction::Broadcast { input, id, out } => reduce_broadcast(f, input, out, id),
96 WarpInstruction::Elect { out } => write!(
97 f,
98 "
99unsigned int mask = __activemask();
100unsigned int leader = __ffs(mask) - 1;
101{out} = threadIdx.x % warpSize == leader;
102 "
103 ),
104 WarpInstruction::InclusiveSum { input, out } => reduce_inclusive(f, input, out, "+="),
105 WarpInstruction::InclusiveProd { input, out } => reduce_inclusive(f, input, out, "*="),
106 WarpInstruction::ExclusiveSum { input, out } => {
107 reduce_exclusive(f, input, out, "+=", "0")
108 }
109 WarpInstruction::ExclusiveProd { input, out } => {
110 reduce_exclusive(f, input, out, "*=", "1")
111 }
112 }
113 }
114}
115
116fn reduce_operator<D: Dialect>(
117 f: &mut core::fmt::Formatter<'_>,
118 input: &Variable<D>,
119 out: &Variable<D>,
120 op: &str,
121) -> core::fmt::Result {
122 let in_optimized = input.optimized();
123 let acc_item = in_optimized.item();
124
125 reduce_with_loop(f, input, out, acc_item, |f, acc, index| {
126 let acc_indexed = maybe_index(acc, index);
127 write!(f, "{acc_indexed} {op} ")?;
128 D::compile_warp_shuffle_xor(f, &acc_indexed, acc.item().elem(), "offset")?;
129 writeln!(f, ";")
130 })
131}
132
133fn reduce_comparison<
134 D: Dialect,
135 I: Fn(&mut core::fmt::Formatter<'_>, Item<D>) -> std::fmt::Result,
136>(
137 f: &mut core::fmt::Formatter<'_>,
138 input: &Variable<D>,
139 out: &Variable<D>,
140 instruction: I,
141) -> core::fmt::Result {
142 let in_optimized = input.optimized();
143 let acc_item = in_optimized.item();
144 reduce_with_loop(f, input, out, acc_item, |f, acc, index| {
145 let acc_indexed = maybe_index(acc, index);
146 let acc_elem = acc_item.elem();
147 write!(f, " {acc_indexed} = ")?;
148 instruction(f, in_optimized.item())?;
149 write!(f, "({acc_indexed}, ")?;
150 D::compile_warp_shuffle_xor(f, &acc_indexed, acc_elem, "offset")?;
151 writeln!(f, ");")
152 })
153}
154
155fn reduce_inclusive<D: Dialect>(
156 f: &mut core::fmt::Formatter<'_>,
157 input: &Variable<D>,
158 out: &Variable<D>,
159 op: &str,
160) -> core::fmt::Result {
161 let in_optimized = input.optimized();
162 let acc_item = in_optimized.item();
163
164 reduce_with_loop(f, input, out, acc_item, |f, acc, index| {
165 let acc_indexed = maybe_index(acc, index);
166 let tmp = Variable::tmp(Item::scalar(acc_item.elem, false));
167 let tmp_left = tmp.fmt_left();
168 let lane_id = Variable::<D>::UnitPosPlane;
169 write!(
170 f,
171 "
172{tmp_left} = "
173 )?;
174 D::compile_warp_shuffle_up(f, &acc_indexed, "offset")?;
175 write!(
176 f,
177 ";
178if({lane_id} >= offset) {{
179 {acc_indexed} {op} {tmp};
180}}
181"
182 )
183 })
184}
185
186fn reduce_exclusive<D: Dialect>(
187 f: &mut core::fmt::Formatter<'_>,
188 input: &Variable<D>,
189 out: &Variable<D>,
190 op: &str,
191 default: &str,
192) -> core::fmt::Result {
193 let in_optimized = input.optimized();
194 let acc_item = in_optimized.item();
195
196 let inclusive = Variable::tmp(acc_item);
197 reduce_inclusive(f, input, &inclusive, op)?;
198 let shfl = Variable::tmp(acc_item);
199 writeln!(f, "{} = {{", shfl.fmt_left())?;
200 for k in 0..acc_item.vectorization {
201 let inclusive_indexed = maybe_index(&inclusive, k);
202 let comma = if k > 0 { ", " } else { "" };
203 write!(f, "{comma}")?;
204 D::compile_warp_shuffle_up(f, &inclusive_indexed.to_string(), "1")?;
205 }
206 writeln!(f, "}};")?;
207 let lane_id = Variable::<D>::UnitPosPlane;
208
209 write!(
210 f,
211 "{} = ({lane_id} == 0) ? {}{{",
212 out.fmt_left(),
213 out.item(),
214 )?;
215 for _ in 0..out.item().vectorization {
216 write!(f, "{default},")?;
217 }
218 writeln!(f, "}} : {};", cast(&shfl, out.item()))
219}
220
221fn reduce_broadcast<D: Dialect>(
222 f: &mut core::fmt::Formatter<'_>,
223 input: &Variable<D>,
224 out: &Variable<D>,
225 id: &Variable<D>,
226) -> core::fmt::Result {
227 let out_fmt = out.fmt_left();
228 write!(f, "{out_fmt} = {{ ")?;
229 for i in 0..input.item().vectorization {
230 let comma = if i > 0 { ", " } else { "" };
231 write!(f, "{comma}")?;
232 D::compile_warp_shuffle(f, &format!("{}", input.index(i)), &format!("{id}"))?;
233 }
234 writeln!(f, " }};")
235}
236
237fn reduce_with_loop<
238 D: Dialect,
239 I: Fn(&mut core::fmt::Formatter<'_>, &Variable<D>, usize) -> std::fmt::Result,
240>(
241 f: &mut core::fmt::Formatter<'_>,
242 input: &Variable<D>,
243 out: &Variable<D>,
244 acc_item: Item<D>,
245 instruction: I,
246) -> core::fmt::Result {
247 let acc = Variable::Named {
248 name: "acc",
249 item: acc_item,
250 };
251 let vectorization = acc_item.vectorization;
252
253 writeln!(f, "auto plane_{out} = [&]() -> {} {{", out.item())?;
254 writeln!(f, " {} {} = {};", acc_item, acc, cast(input, acc_item))?;
255 write!(f, " for (uint offset = 1; offset < ")?;
256 D::compile_plane_dim_checked(f)?;
257 writeln!(f, "; offset *=2 ) {{")?;
258 for k in 0..vectorization {
259 instruction(f, &acc, k)?;
260 }
261 writeln!(f, " }};")?;
262 writeln!(f, " return {};", cast(&acc, out.item()))?;
263 writeln!(f, "}};")?;
264 writeln!(f, "{} = plane_{}();", out.fmt_left(), out)
265}
266
267fn reduce_quantifier<
268 D: Dialect,
269 Q: Fn(&mut core::fmt::Formatter<'_>, &IndexedVariable<D>) -> std::fmt::Result,
270>(
271 f: &mut core::fmt::Formatter<'_>,
272 input: &Variable<D>,
273 out: &Variable<D>,
274 quantifier: Q,
275) -> core::fmt::Result {
276 let out_fmt = out.fmt_left();
277 write!(f, "{out_fmt} = {{ ")?;
278 for i in 0..input.item().vectorization {
279 let comma = if i > 0 { ", " } else { "" };
280 write!(f, "{comma}")?;
281 quantifier(f, &input.index(i))?;
282 }
283 writeln!(f, "}};")
284}
285
286fn cast<D: Dialect>(input: &Variable<D>, target: Item<D>) -> String {
287 if target != input.item() {
288 let addr_space = D::address_space_for_variable(input);
289 let qualifier = input.const_qualifier();
290 format!("reinterpret_cast<{addr_space}{target}{qualifier}&>({input})")
291 } else {
292 format!("{}", input)
293 }
294}
295
296fn maybe_index<D: Dialect>(var: &Variable<D>, k: usize) -> String {
297 if var.item().vectorization > 1 {
298 format!("{var}.i_{k}")
299 } else {
300 format!("{var}")
301 }
302}