1use std::fmt::Display;
2
3use crate::shared::{Component, FmtLeft};
4
5use super::{Dialect, IndexedVariable, Item, Variable};
6
/// Warp instructions that a [`Dialect`] can render to source text.
///
/// Every variant writes its result into `out`; variants that consume a value
/// read it from `input`. The text emitted for each variant is defined by this
/// type's `Display` implementation.
#[derive(Clone, Debug)]
pub enum WarpInstruction<D: Dialect> {
    /// Sum reduction; rendered by `Dialect::warp_reduce_sum`.
    ReduceSum {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Inclusive sum; rendered by `Dialect::warp_reduce_sum_inclusive`.
    InclusiveSum {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Exclusive sum; rendered by `Dialect::warp_reduce_sum_exclusive`.
    ExclusiveSum {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Product reduction; rendered by `Dialect::warp_reduce_prod`.
    ReduceProd {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Inclusive product; rendered by `Dialect::warp_reduce_prod_inclusive`.
    InclusiveProd {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Exclusive product; rendered by `Dialect::warp_reduce_prod_exclusive`.
    ExclusiveProd {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Maximum reduction; rendered by `Dialect::warp_reduce_max`.
    ReduceMax {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Minimum reduction; rendered by `Dialect::warp_reduce_min`.
    ReduceMin {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Elect rendered with hard-coded `__activemask()` / `__ffs()` calls
    /// instead of a dialect hook: `out` receives
    /// `threadIdx.x % warpSize == leader`, where `leader` is
    /// `__ffs(__activemask()) - 1`.
    ElectFallback {
        out: Variable<D>,
    },
    /// Elect rendered by `Dialect::compile_warp_elect`.
    Elect {
        out: Variable<D>,
    },
    /// "All" quantifier; rendered by `Dialect::warp_reduce_all`.
    All {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// "Any" quantifier; rendered by `Dialect::warp_reduce_any`.
    Any {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Ballot; rendered as a four-element initializer whose first element
    /// comes from `Dialect::compile_warp_ballot` and whose other three are `0`.
    /// Rendering panics unless `input` has a vectorization of 1.
    Ballot {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Broadcast; each component of `input` is rendered through
    /// `Dialect::compile_warp_shuffle` with `id` as the source expression.
    Broadcast {
        input: Variable<D>,
        id: Variable<D>,
        out: Variable<D>,
    },
    /// Shuffle; each component of `input` is rendered through
    /// `Dialect::compile_warp_shuffle` with `src_lane` as the source expression.
    Shuffle {
        input: Variable<D>,
        src_lane: Variable<D>,
        out: Variable<D>,
    },
    /// XOR shuffle; each component of `input` is rendered through
    /// `Dialect::compile_warp_shuffle_xor` with `mask` as the offset expression.
    ShuffleXor {
        input: Variable<D>,
        mask: Variable<D>,
        out: Variable<D>,
    },
    /// Upward shuffle; each component of `input` is rendered through
    /// `Dialect::compile_warp_shuffle_up` with `delta` as the offset expression.
    ShuffleUp {
        input: Variable<D>,
        delta: Variable<D>,
        out: Variable<D>,
    },
    /// Downward shuffle; each component of `input` is rendered through
    /// `Dialect::compile_warp_shuffle_down` with `delta` as the offset expression.
    ShuffleDown {
        input: Variable<D>,
        delta: Variable<D>,
        out: Variable<D>,
    },
}
85
86impl<D: Dialect> Display for WarpInstruction<D> {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 match self {
89 WarpInstruction::ReduceSum { input, out } => D::warp_reduce_sum(f, input, out),
90 WarpInstruction::ReduceProd { input, out } => D::warp_reduce_prod(f, input, out),
91 WarpInstruction::ReduceMax { input, out } => D::warp_reduce_max(f, input, out),
92 WarpInstruction::ReduceMin { input, out } => D::warp_reduce_min(f, input, out),
93 WarpInstruction::All { input, out } => D::warp_reduce_all(f, input, out),
94 WarpInstruction::Any { input, out } => D::warp_reduce_any(f, input, out),
95
96 WarpInstruction::InclusiveSum { input, out } => {
97 D::warp_reduce_sum_inclusive(f, input, out)
98 }
99 WarpInstruction::InclusiveProd { input, out } => {
100 D::warp_reduce_prod_inclusive(f, input, out)
101 }
102 WarpInstruction::ExclusiveSum { input, out } => {
103 D::warp_reduce_sum_exclusive(f, input, out)
104 }
105 WarpInstruction::ExclusiveProd { input, out } => {
106 D::warp_reduce_prod_exclusive(f, input, out)
107 }
108 WarpInstruction::Ballot { input, out } => {
109 assert_eq!(
110 input.item().vectorization,
111 1,
112 "Ballot can't support vectorized input"
113 );
114 let out_fmt = out.fmt_left();
115 write!(
116 f,
117 "
118{out_fmt} = {{ "
119 )?;
120 D::compile_warp_ballot(f, input, out.item().elem())?;
121 writeln!(f, ", 0, 0, 0 }};")
122 }
123 WarpInstruction::Broadcast { input, id, out } => reduce_broadcast(f, input, out, id),
124 WarpInstruction::Shuffle {
125 input,
126 src_lane,
127 out,
128 } => {
129 let out_fmt = out.fmt_left();
130 write!(f, "{out_fmt} = {{ ")?;
131 for i in 0..input.item().vectorization {
132 let comma = if i > 0 { ", " } else { "" };
133 write!(f, "{comma}")?;
134 D::compile_warp_shuffle(
135 f,
136 &format!("{}", input.index(i)),
137 &format!("{src_lane}"),
138 )?;
139 }
140 writeln!(f, " }};")
141 }
142 WarpInstruction::ShuffleXor { input, mask, out } => {
143 let out_fmt = out.fmt_left();
144 write!(f, "{out_fmt} = {{ ")?;
145 for i in 0..input.item().vectorization {
146 let comma = if i > 0 { ", " } else { "" };
147 write!(f, "{comma}")?;
148 D::compile_warp_shuffle_xor(
149 f,
150 &format!("{}", input.index(i)),
151 input.item().elem(),
152 &format!("{mask}"),
153 )?;
154 }
155 writeln!(f, " }};")
156 }
157 WarpInstruction::ShuffleUp { input, delta, out } => {
158 let out_fmt = out.fmt_left();
159 write!(f, "{out_fmt} = {{ ")?;
160 for i in 0..input.item().vectorization {
161 let comma = if i > 0 { ", " } else { "" };
162 write!(f, "{comma}")?;
163 D::compile_warp_shuffle_up(
164 f,
165 &format!("{}", input.index(i)),
166 &format!("{delta}"),
167 )?;
168 }
169 writeln!(f, " }};")
170 }
171 WarpInstruction::ShuffleDown { input, delta, out } => {
172 let out_fmt = out.fmt_left();
173 write!(f, "{out_fmt} = {{ ")?;
174 for i in 0..input.item().vectorization {
175 let comma = if i > 0 { ", " } else { "" };
176 write!(f, "{comma}")?;
177 D::compile_warp_shuffle_down(
178 f,
179 &format!("{}", input.index(i)),
180 &format!("{delta}"),
181 )?;
182 }
183 writeln!(f, " }};")
184 }
185 WarpInstruction::ElectFallback { out } => {
186 let out = out.fmt_left();
187 write!(
188 f,
189 "
190unsigned int mask = __activemask();
191unsigned int leader = __ffs(mask) - 1;
192{out} = threadIdx.x % warpSize == leader;
193 "
194 )
195 }
196 WarpInstruction::Elect { out } => {
197 let out = out.fmt_left();
198 D::compile_warp_elect(f, &out)
199 }
200 }
201 }
202}
203
204pub(crate) fn reduce_operator<D: Dialect>(
205 f: &mut core::fmt::Formatter<'_>,
206 input: &Variable<D>,
207 out: &Variable<D>,
208 op: &str,
209) -> core::fmt::Result {
210 let in_optimized = input.optimized();
211 let acc_item = in_optimized.item();
212
213 reduce_with_loop(f, input, out, acc_item, |f, acc, index| {
214 let acc_indexed = maybe_index(acc, index);
215 write!(f, "{acc_indexed} {op} ")?;
216 D::compile_warp_shuffle_xor(f, &acc_indexed, acc.item().elem(), "offset")?;
217 writeln!(f, ";")
218 })
219}
220
221pub(crate) fn reduce_comparison<
222 D: Dialect,
223 I: Fn(&mut core::fmt::Formatter<'_>, Item<D>) -> std::fmt::Result,
224>(
225 f: &mut core::fmt::Formatter<'_>,
226 input: &Variable<D>,
227 out: &Variable<D>,
228 instruction: I,
229) -> core::fmt::Result {
230 let in_optimized = input.optimized();
231 let acc_item = in_optimized.item();
232 reduce_with_loop(f, input, out, acc_item, |f, acc, index| {
233 let acc_indexed = maybe_index(acc, index);
234 let acc_elem = acc_item.elem();
235 write!(f, " {acc_indexed} = ")?;
236 instruction(f, in_optimized.item())?;
237 write!(f, "({acc_indexed}, ")?;
238 D::compile_warp_shuffle_xor(f, &acc_indexed, acc_elem, "offset")?;
239 writeln!(f, ");")
240 })
241}
242
243pub(crate) fn reduce_inclusive<D: Dialect>(
244 f: &mut core::fmt::Formatter<'_>,
245 input: &Variable<D>,
246 out: &Variable<D>,
247 op: &str,
248) -> core::fmt::Result {
249 let in_optimized = input.optimized();
250 let acc_item = in_optimized.item();
251
252 reduce_with_loop(f, input, out, acc_item, |f, acc, index| {
253 let acc_indexed = maybe_index(acc, index);
254 let tmp = Variable::tmp(Item::scalar(acc_item.elem, false));
255 let tmp_left = tmp.fmt_left();
256 let lane_id = Variable::<D>::UnitPosPlane;
257 write!(
258 f,
259 "
260{tmp_left} = "
261 )?;
262 D::compile_warp_shuffle_up(f, &acc_indexed, "offset")?;
263 write!(
264 f,
265 ";
266if({lane_id} >= offset) {{
267 {acc_indexed} {op} {tmp};
268}}
269"
270 )
271 })
272}
273
274pub(crate) fn reduce_exclusive<D: Dialect>(
275 f: &mut core::fmt::Formatter<'_>,
276 input: &Variable<D>,
277 out: &Variable<D>,
278 op: &str,
279 default: &str,
280) -> core::fmt::Result {
281 let in_optimized = input.optimized();
282 let acc_item = in_optimized.item();
283
284 let inclusive = Variable::tmp(acc_item);
285 reduce_inclusive(f, input, &inclusive, op)?;
286 let shfl = Variable::tmp(acc_item);
287 writeln!(f, "{} = {{", shfl.fmt_left())?;
288 for k in 0..acc_item.vectorization {
289 let inclusive_indexed = maybe_index(&inclusive, k);
290 let comma = if k > 0 { ", " } else { "" };
291 write!(f, "{comma}")?;
292 D::compile_warp_shuffle_up(f, &inclusive_indexed.to_string(), "1")?;
293 }
294 writeln!(f, "}};")?;
295 let lane_id = Variable::<D>::UnitPosPlane;
296
297 write!(
298 f,
299 "{} = ({lane_id} == 0) ? {}{{",
300 out.fmt_left(),
301 out.item(),
302 )?;
303 for _ in 0..out.item().vectorization {
304 write!(f, "{default},")?;
305 }
306 writeln!(f, "}} : {};", cast(&shfl, out.item()))
307}
308
309pub(crate) fn reduce_broadcast<D: Dialect>(
310 f: &mut core::fmt::Formatter<'_>,
311 input: &Variable<D>,
312 out: &Variable<D>,
313 id: &Variable<D>,
314) -> core::fmt::Result {
315 let out_fmt = out.fmt_left();
316 write!(f, "{out_fmt} = {{ ")?;
317 for i in 0..input.item().vectorization {
318 let comma = if i > 0 { ", " } else { "" };
319 write!(f, "{comma}")?;
320 D::compile_warp_shuffle(f, &format!("{}", input.index(i)), &format!("{id}"))?;
321 }
322 writeln!(f, " }};")
323}
324
325fn reduce_with_loop<
326 D: Dialect,
327 I: Fn(&mut core::fmt::Formatter<'_>, &Variable<D>, usize) -> std::fmt::Result,
328>(
329 f: &mut core::fmt::Formatter<'_>,
330 input: &Variable<D>,
331 out: &Variable<D>,
332 acc_item: Item<D>,
333 instruction: I,
334) -> core::fmt::Result {
335 let acc = Variable::Named {
336 name: "acc",
337 item: acc_item,
338 };
339 let vectorization = acc_item.vectorization;
340
341 writeln!(f, "auto plane_{out} = [&]() -> {} {{", out.item())?;
342 writeln!(f, " {} {} = {};", acc_item, acc, cast(input, acc_item))?;
343 write!(f, " for (uint offset = 1; offset < ")?;
344 D::compile_plane_dim_checked(f)?;
345 writeln!(f, "; offset *=2 ) {{")?;
346 for k in 0..vectorization {
347 instruction(f, &acc, k)?;
348 }
349 writeln!(f, " }};")?;
350 writeln!(f, " return {};", cast(&acc, out.item()))?;
351 writeln!(f, "}};")?;
352 writeln!(f, "{} = plane_{}();", out.fmt_left(), out)
353}
354
355pub(crate) fn reduce_quantifier<
356 D: Dialect,
357 Q: Fn(&mut core::fmt::Formatter<'_>, &IndexedVariable<D>) -> std::fmt::Result,
358>(
359 f: &mut core::fmt::Formatter<'_>,
360 input: &Variable<D>,
361 out: &Variable<D>,
362 quantifier: Q,
363) -> core::fmt::Result {
364 let out_fmt = out.fmt_left();
365 write!(f, "{out_fmt} = {{ ")?;
366 for i in 0..input.item().vectorization {
367 let comma = if i > 0 { ", " } else { "" };
368 write!(f, "{comma}")?;
369 quantifier(f, &input.index(i))?;
370 }
371 writeln!(f, "}};")
372}
373
374fn cast<D: Dialect>(input: &Variable<D>, target: Item<D>) -> String {
375 if target != input.item() {
376 let addr_space = D::address_space_for_variable(input);
377 let qualifier = input.const_qualifier();
378 format!("reinterpret_cast<{addr_space}{target}{qualifier}&>({input})")
379 } else {
380 format!("{input}")
381 }
382}
383
384fn maybe_index<D: Dialect>(var: &Variable<D>, k: usize) -> String {
385 if var.item().vectorization > 1 {
386 format!("{var}.i_{k}")
387 } else {
388 format!("{var}")
389 }
390}