1use std::fmt::Display;
2
3use crate::shared::{Component, FmtLeft};
4
5use super::{Dialect, IndexedVariable, Item, Variable};
6
7#[derive(Clone, Debug)]
8pub enum WarpInstruction<D: Dialect> {
9 ReduceSum {
10 input: Variable<D>,
11 out: Variable<D>,
12 },
13 InclusiveSum {
14 input: Variable<D>,
15 out: Variable<D>,
16 },
17 ExclusiveSum {
18 input: Variable<D>,
19 out: Variable<D>,
20 },
21 ReduceProd {
22 input: Variable<D>,
23 out: Variable<D>,
24 },
25 InclusiveProd {
26 input: Variable<D>,
27 out: Variable<D>,
28 },
29 ExclusiveProd {
30 input: Variable<D>,
31 out: Variable<D>,
32 },
33 ReduceMax {
34 input: Variable<D>,
35 out: Variable<D>,
36 },
37 ReduceMin {
38 input: Variable<D>,
39 out: Variable<D>,
40 },
41 Elect {
42 out: Variable<D>,
43 },
44 All {
45 input: Variable<D>,
46 out: Variable<D>,
47 },
48 Any {
49 input: Variable<D>,
50 out: Variable<D>,
51 },
52 Ballot {
53 input: Variable<D>,
54 out: Variable<D>,
55 },
56 Broadcast {
57 input: Variable<D>,
58 id: Variable<D>,
59 out: Variable<D>,
60 },
61 Shuffle {
62 input: Variable<D>,
63 src_lane: Variable<D>,
64 out: Variable<D>,
65 },
66 ShuffleXor {
67 input: Variable<D>,
68 mask: Variable<D>,
69 out: Variable<D>,
70 },
71 ShuffleUp {
72 input: Variable<D>,
73 delta: Variable<D>,
74 out: Variable<D>,
75 },
76 ShuffleDown {
77 input: Variable<D>,
78 delta: Variable<D>,
79 out: Variable<D>,
80 },
81}
82
83impl<D: Dialect> Display for WarpInstruction<D> {
84 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85 match self {
86 WarpInstruction::ReduceSum { input, out } => D::warp_reduce_sum(f, input, out),
87 WarpInstruction::ReduceProd { input, out } => D::warp_reduce_prod(f, input, out),
88 WarpInstruction::ReduceMax { input, out } => D::warp_reduce_max(f, input, out),
89 WarpInstruction::ReduceMin { input, out } => D::warp_reduce_min(f, input, out),
90 WarpInstruction::All { input, out } => D::warp_reduce_all(f, input, out),
91 WarpInstruction::Any { input, out } => D::warp_reduce_any(f, input, out),
92
93 WarpInstruction::InclusiveSum { input, out } => {
94 D::warp_reduce_sum_inclusive(f, input, out)
95 }
96 WarpInstruction::InclusiveProd { input, out } => {
97 D::warp_reduce_prod_inclusive(f, input, out)
98 }
99 WarpInstruction::ExclusiveSum { input, out } => {
100 D::warp_reduce_sum_exclusive(f, input, out)
101 }
102 WarpInstruction::ExclusiveProd { input, out } => {
103 D::warp_reduce_prod_exclusive(f, input, out)
104 }
105 WarpInstruction::Ballot { input, out } => {
106 assert_eq!(
107 input.item().vectorization,
108 1,
109 "Ballot can't support vectorized input"
110 );
111 let out_fmt = out.fmt_left();
112 write!(
113 f,
114 "
115{out_fmt} = {{ "
116 )?;
117 D::compile_warp_ballot(f, input, out.item().elem())?;
118 writeln!(f, ", 0, 0, 0 }};")
119 }
120 WarpInstruction::Broadcast { input, id, out } => reduce_broadcast(f, input, out, id),
121 WarpInstruction::Shuffle {
122 input,
123 src_lane,
124 out,
125 } => {
126 let out_fmt = out.fmt_left();
127 write!(f, "{out_fmt} = {{ ")?;
128 for i in 0..input.item().vectorization {
129 let comma = if i > 0 { ", " } else { "" };
130 write!(f, "{comma}")?;
131 D::compile_warp_shuffle(
132 f,
133 &format!("{}", input.index(i)),
134 &format!("{src_lane}"),
135 )?;
136 }
137 writeln!(f, " }};")
138 }
139 WarpInstruction::ShuffleXor { input, mask, out } => {
140 let out_fmt = out.fmt_left();
141 write!(f, "{out_fmt} = {{ ")?;
142 for i in 0..input.item().vectorization {
143 let comma = if i > 0 { ", " } else { "" };
144 write!(f, "{comma}")?;
145 D::compile_warp_shuffle_xor(
146 f,
147 &format!("{}", input.index(i)),
148 input.item().elem(),
149 &format!("{mask}"),
150 )?;
151 }
152 writeln!(f, " }};")
153 }
154 WarpInstruction::ShuffleUp { input, delta, out } => {
155 let out_fmt = out.fmt_left();
156 write!(f, "{out_fmt} = {{ ")?;
157 for i in 0..input.item().vectorization {
158 let comma = if i > 0 { ", " } else { "" };
159 write!(f, "{comma}")?;
160 D::compile_warp_shuffle_up(
161 f,
162 &format!("{}", input.index(i)),
163 &format!("{delta}"),
164 )?;
165 }
166 writeln!(f, " }};")
167 }
168 WarpInstruction::ShuffleDown { input, delta, out } => {
169 let out_fmt = out.fmt_left();
170 write!(f, "{out_fmt} = {{ ")?;
171 for i in 0..input.item().vectorization {
172 let comma = if i > 0 { ", " } else { "" };
173 write!(f, "{comma}")?;
174 D::compile_warp_shuffle_down(
175 f,
176 &format!("{}", input.index(i)),
177 &format!("{delta}"),
178 )?;
179 }
180 writeln!(f, " }};")
181 }
182 WarpInstruction::Elect { out } => write!(
183 f,
184 "
185unsigned int mask = __activemask();
186unsigned int leader = __ffs(mask) - 1;
187{out} = threadIdx.x % warpSize == leader;
188 "
189 ),
190 }
191 }
192}
193
194pub(crate) fn reduce_operator<D: Dialect>(
195 f: &mut core::fmt::Formatter<'_>,
196 input: &Variable<D>,
197 out: &Variable<D>,
198 op: &str,
199) -> core::fmt::Result {
200 let in_optimized = input.optimized();
201 let acc_item = in_optimized.item();
202
203 reduce_with_loop(f, input, out, acc_item, |f, acc, index| {
204 let acc_indexed = maybe_index(acc, index);
205 write!(f, "{acc_indexed} {op} ")?;
206 D::compile_warp_shuffle_xor(f, &acc_indexed, acc.item().elem(), "offset")?;
207 writeln!(f, ";")
208 })
209}
210
211pub(crate) fn reduce_comparison<
212 D: Dialect,
213 I: Fn(&mut core::fmt::Formatter<'_>, Item<D>) -> std::fmt::Result,
214>(
215 f: &mut core::fmt::Formatter<'_>,
216 input: &Variable<D>,
217 out: &Variable<D>,
218 instruction: I,
219) -> core::fmt::Result {
220 let in_optimized = input.optimized();
221 let acc_item = in_optimized.item();
222 reduce_with_loop(f, input, out, acc_item, |f, acc, index| {
223 let acc_indexed = maybe_index(acc, index);
224 let acc_elem = acc_item.elem();
225 write!(f, " {acc_indexed} = ")?;
226 instruction(f, in_optimized.item())?;
227 write!(f, "({acc_indexed}, ")?;
228 D::compile_warp_shuffle_xor(f, &acc_indexed, acc_elem, "offset")?;
229 writeln!(f, ");")
230 })
231}
232
233pub(crate) fn reduce_inclusive<D: Dialect>(
234 f: &mut core::fmt::Formatter<'_>,
235 input: &Variable<D>,
236 out: &Variable<D>,
237 op: &str,
238) -> core::fmt::Result {
239 let in_optimized = input.optimized();
240 let acc_item = in_optimized.item();
241
242 reduce_with_loop(f, input, out, acc_item, |f, acc, index| {
243 let acc_indexed = maybe_index(acc, index);
244 let tmp = Variable::tmp(Item::scalar(acc_item.elem, false));
245 let tmp_left = tmp.fmt_left();
246 let lane_id = Variable::<D>::UnitPosPlane;
247 write!(
248 f,
249 "
250{tmp_left} = "
251 )?;
252 D::compile_warp_shuffle_up(f, &acc_indexed, "offset")?;
253 write!(
254 f,
255 ";
256if({lane_id} >= offset) {{
257 {acc_indexed} {op} {tmp};
258}}
259"
260 )
261 })
262}
263
264pub(crate) fn reduce_exclusive<D: Dialect>(
265 f: &mut core::fmt::Formatter<'_>,
266 input: &Variable<D>,
267 out: &Variable<D>,
268 op: &str,
269 default: &str,
270) -> core::fmt::Result {
271 let in_optimized = input.optimized();
272 let acc_item = in_optimized.item();
273
274 let inclusive = Variable::tmp(acc_item);
275 reduce_inclusive(f, input, &inclusive, op)?;
276 let shfl = Variable::tmp(acc_item);
277 writeln!(f, "{} = {{", shfl.fmt_left())?;
278 for k in 0..acc_item.vectorization {
279 let inclusive_indexed = maybe_index(&inclusive, k);
280 let comma = if k > 0 { ", " } else { "" };
281 write!(f, "{comma}")?;
282 D::compile_warp_shuffle_up(f, &inclusive_indexed.to_string(), "1")?;
283 }
284 writeln!(f, "}};")?;
285 let lane_id = Variable::<D>::UnitPosPlane;
286
287 write!(
288 f,
289 "{} = ({lane_id} == 0) ? {}{{",
290 out.fmt_left(),
291 out.item(),
292 )?;
293 for _ in 0..out.item().vectorization {
294 write!(f, "{default},")?;
295 }
296 writeln!(f, "}} : {};", cast(&shfl, out.item()))
297}
298
299pub(crate) fn reduce_broadcast<D: Dialect>(
300 f: &mut core::fmt::Formatter<'_>,
301 input: &Variable<D>,
302 out: &Variable<D>,
303 id: &Variable<D>,
304) -> core::fmt::Result {
305 let out_fmt = out.fmt_left();
306 write!(f, "{out_fmt} = {{ ")?;
307 for i in 0..input.item().vectorization {
308 let comma = if i > 0 { ", " } else { "" };
309 write!(f, "{comma}")?;
310 D::compile_warp_shuffle(f, &format!("{}", input.index(i)), &format!("{id}"))?;
311 }
312 writeln!(f, " }};")
313}
314
315fn reduce_with_loop<
316 D: Dialect,
317 I: Fn(&mut core::fmt::Formatter<'_>, &Variable<D>, usize) -> std::fmt::Result,
318>(
319 f: &mut core::fmt::Formatter<'_>,
320 input: &Variable<D>,
321 out: &Variable<D>,
322 acc_item: Item<D>,
323 instruction: I,
324) -> core::fmt::Result {
325 let acc = Variable::Named {
326 name: "acc",
327 item: acc_item,
328 };
329 let vectorization = acc_item.vectorization;
330
331 writeln!(f, "auto plane_{out} = [&]() -> {} {{", out.item())?;
332 writeln!(f, " {} {} = {};", acc_item, acc, cast(input, acc_item))?;
333 write!(f, " for (uint offset = 1; offset < ")?;
334 D::compile_plane_dim_checked(f)?;
335 writeln!(f, "; offset *=2 ) {{")?;
336 for k in 0..vectorization {
337 instruction(f, &acc, k)?;
338 }
339 writeln!(f, " }};")?;
340 writeln!(f, " return {};", cast(&acc, out.item()))?;
341 writeln!(f, "}};")?;
342 writeln!(f, "{} = plane_{}();", out.fmt_left(), out)
343}
344
345pub(crate) fn reduce_quantifier<
346 D: Dialect,
347 Q: Fn(&mut core::fmt::Formatter<'_>, &IndexedVariable<D>) -> std::fmt::Result,
348>(
349 f: &mut core::fmt::Formatter<'_>,
350 input: &Variable<D>,
351 out: &Variable<D>,
352 quantifier: Q,
353) -> core::fmt::Result {
354 let out_fmt = out.fmt_left();
355 write!(f, "{out_fmt} = {{ ")?;
356 for i in 0..input.item().vectorization {
357 let comma = if i > 0 { ", " } else { "" };
358 write!(f, "{comma}")?;
359 quantifier(f, &input.index(i))?;
360 }
361 writeln!(f, "}};")
362}
363
364fn cast<D: Dialect>(input: &Variable<D>, target: Item<D>) -> String {
365 if target != input.item() {
366 let addr_space = D::address_space_for_variable(input);
367 let qualifier = input.const_qualifier();
368 format!("reinterpret_cast<{addr_space}{target}{qualifier}&>({input})")
369 } else {
370 format!("{input}")
371 }
372}
373
374fn maybe_index<D: Dialect>(var: &Variable<D>, k: usize) -> String {
375 if var.item().vectorization > 1 {
376 format!("{var}.i_{k}")
377 } else {
378 format!("{var}")
379 }
380}