cubecl_cpp/shared/
warp.rs

1use std::fmt::Display;
2
3use crate::shared::{Component, FmtLeft};
4
5use super::{Dialect, IndexedVariable, Item, Variable};
6
7#[derive(Clone, Debug)]
8pub enum WarpInstruction<D: Dialect> {
9    ReduceSum {
10        input: Variable<D>,
11        out: Variable<D>,
12    },
13    InclusiveSum {
14        input: Variable<D>,
15        out: Variable<D>,
16    },
17    ExclusiveSum {
18        input: Variable<D>,
19        out: Variable<D>,
20    },
21    ReduceProd {
22        input: Variable<D>,
23        out: Variable<D>,
24    },
25    InclusiveProd {
26        input: Variable<D>,
27        out: Variable<D>,
28    },
29    ExclusiveProd {
30        input: Variable<D>,
31        out: Variable<D>,
32    },
33    ReduceMax {
34        input: Variable<D>,
35        out: Variable<D>,
36    },
37    ReduceMin {
38        input: Variable<D>,
39        out: Variable<D>,
40    },
41    ElectFallback {
42        out: Variable<D>,
43    },
44    Elect {
45        out: Variable<D>,
46    },
47    All {
48        input: Variable<D>,
49        out: Variable<D>,
50    },
51    Any {
52        input: Variable<D>,
53        out: Variable<D>,
54    },
55    Ballot {
56        input: Variable<D>,
57        out: Variable<D>,
58    },
59    Broadcast {
60        input: Variable<D>,
61        id: Variable<D>,
62        out: Variable<D>,
63    },
64    Shuffle {
65        input: Variable<D>,
66        src_lane: Variable<D>,
67        out: Variable<D>,
68    },
69    ShuffleXor {
70        input: Variable<D>,
71        mask: Variable<D>,
72        out: Variable<D>,
73    },
74    ShuffleUp {
75        input: Variable<D>,
76        delta: Variable<D>,
77        out: Variable<D>,
78    },
79    ShuffleDown {
80        input: Variable<D>,
81        delta: Variable<D>,
82        out: Variable<D>,
83    },
84}
85
86impl<D: Dialect> Display for WarpInstruction<D> {
87    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88        match self {
89            WarpInstruction::ReduceSum { input, out } => D::warp_reduce_sum(f, input, out),
90            WarpInstruction::ReduceProd { input, out } => D::warp_reduce_prod(f, input, out),
91            WarpInstruction::ReduceMax { input, out } => D::warp_reduce_max(f, input, out),
92            WarpInstruction::ReduceMin { input, out } => D::warp_reduce_min(f, input, out),
93            WarpInstruction::All { input, out } => D::warp_reduce_all(f, input, out),
94            WarpInstruction::Any { input, out } => D::warp_reduce_any(f, input, out),
95
96            WarpInstruction::InclusiveSum { input, out } => {
97                D::warp_reduce_sum_inclusive(f, input, out)
98            }
99            WarpInstruction::InclusiveProd { input, out } => {
100                D::warp_reduce_prod_inclusive(f, input, out)
101            }
102            WarpInstruction::ExclusiveSum { input, out } => {
103                D::warp_reduce_sum_exclusive(f, input, out)
104            }
105            WarpInstruction::ExclusiveProd { input, out } => {
106                D::warp_reduce_prod_exclusive(f, input, out)
107            }
108            WarpInstruction::Ballot { input, out } => {
109                assert_eq!(
110                    input.item().vectorization,
111                    1,
112                    "Ballot can't support vectorized input"
113                );
114                let out_fmt = out.fmt_left();
115                write!(
116                    f,
117                    "
118{out_fmt} = {{ "
119                )?;
120                D::compile_warp_ballot(f, input, out.item().elem())?;
121                writeln!(f, ", 0, 0, 0 }};")
122            }
123            WarpInstruction::Broadcast { input, id, out } => reduce_broadcast(f, input, out, id),
124            WarpInstruction::Shuffle {
125                input,
126                src_lane,
127                out,
128            } => {
129                let out_fmt = out.fmt_left();
130                write!(f, "{out_fmt} = {{ ")?;
131                for i in 0..input.item().vectorization {
132                    let comma = if i > 0 { ", " } else { "" };
133                    write!(f, "{comma}")?;
134                    D::compile_warp_shuffle(
135                        f,
136                        &format!("{}", input.index(i)),
137                        &format!("{src_lane}"),
138                    )?;
139                }
140                writeln!(f, " }};")
141            }
142            WarpInstruction::ShuffleXor { input, mask, out } => {
143                let out_fmt = out.fmt_left();
144                write!(f, "{out_fmt} = {{ ")?;
145                for i in 0..input.item().vectorization {
146                    let comma = if i > 0 { ", " } else { "" };
147                    write!(f, "{comma}")?;
148                    D::compile_warp_shuffle_xor(
149                        f,
150                        &format!("{}", input.index(i)),
151                        input.item().elem(),
152                        &format!("{mask}"),
153                    )?;
154                }
155                writeln!(f, " }};")
156            }
157            WarpInstruction::ShuffleUp { input, delta, out } => {
158                let out_fmt = out.fmt_left();
159                write!(f, "{out_fmt} = {{ ")?;
160                for i in 0..input.item().vectorization {
161                    let comma = if i > 0 { ", " } else { "" };
162                    write!(f, "{comma}")?;
163                    D::compile_warp_shuffle_up(
164                        f,
165                        &format!("{}", input.index(i)),
166                        &format!("{delta}"),
167                    )?;
168                }
169                writeln!(f, " }};")
170            }
171            WarpInstruction::ShuffleDown { input, delta, out } => {
172                let out_fmt = out.fmt_left();
173                write!(f, "{out_fmt} = {{ ")?;
174                for i in 0..input.item().vectorization {
175                    let comma = if i > 0 { ", " } else { "" };
176                    write!(f, "{comma}")?;
177                    D::compile_warp_shuffle_down(
178                        f,
179                        &format!("{}", input.index(i)),
180                        &format!("{delta}"),
181                    )?;
182                }
183                writeln!(f, " }};")
184            }
185            WarpInstruction::ElectFallback { out } => {
186                let out = out.fmt_left();
187                write!(
188                    f,
189                    "
190unsigned int mask = __activemask();
191unsigned int leader = __ffs(mask) - 1;
192{out} = threadIdx.x % warpSize == leader;
193            "
194                )
195            }
196            WarpInstruction::Elect { out } => {
197                let out = out.fmt_left();
198                D::compile_warp_elect(f, &out)
199            }
200        }
201    }
202}
203
204pub(crate) fn reduce_operator<D: Dialect>(
205    f: &mut core::fmt::Formatter<'_>,
206    input: &Variable<D>,
207    out: &Variable<D>,
208    op: &str,
209) -> core::fmt::Result {
210    let in_optimized = input.optimized();
211    let acc_item = in_optimized.item();
212
213    reduce_with_loop(f, input, out, acc_item, |f, acc, index| {
214        let acc_indexed = maybe_index(acc, index);
215        write!(f, "{acc_indexed} {op} ")?;
216        D::compile_warp_shuffle_xor(f, &acc_indexed, acc.item().elem(), "offset")?;
217        writeln!(f, ";")
218    })
219}
220
221pub(crate) fn reduce_comparison<
222    D: Dialect,
223    I: Fn(&mut core::fmt::Formatter<'_>, Item<D>) -> std::fmt::Result,
224>(
225    f: &mut core::fmt::Formatter<'_>,
226    input: &Variable<D>,
227    out: &Variable<D>,
228    instruction: I,
229) -> core::fmt::Result {
230    let in_optimized = input.optimized();
231    let acc_item = in_optimized.item();
232    reduce_with_loop(f, input, out, acc_item, |f, acc, index| {
233        let acc_indexed = maybe_index(acc, index);
234        let acc_elem = acc_item.elem();
235        write!(f, "        {acc_indexed} = ")?;
236        instruction(f, in_optimized.item())?;
237        write!(f, "({acc_indexed}, ")?;
238        D::compile_warp_shuffle_xor(f, &acc_indexed, acc_elem, "offset")?;
239        writeln!(f, ");")
240    })
241}
242
243pub(crate) fn reduce_inclusive<D: Dialect>(
244    f: &mut core::fmt::Formatter<'_>,
245    input: &Variable<D>,
246    out: &Variable<D>,
247    op: &str,
248) -> core::fmt::Result {
249    let in_optimized = input.optimized();
250    let acc_item = in_optimized.item();
251
252    reduce_with_loop(f, input, out, acc_item, |f, acc, index| {
253        let acc_indexed = maybe_index(acc, index);
254        let tmp = Variable::tmp(Item::scalar(acc_item.elem, false));
255        let tmp_left = tmp.fmt_left();
256        let lane_id = Variable::<D>::UnitPosPlane;
257        write!(
258            f,
259            "
260{tmp_left} = "
261        )?;
262        D::compile_warp_shuffle_up(f, &acc_indexed, "offset")?;
263        write!(
264            f,
265            ";
266if({lane_id} >= offset) {{
267    {acc_indexed} {op} {tmp};
268}}
269"
270        )
271    })
272}
273
274pub(crate) fn reduce_exclusive<D: Dialect>(
275    f: &mut core::fmt::Formatter<'_>,
276    input: &Variable<D>,
277    out: &Variable<D>,
278    op: &str,
279    default: &str,
280) -> core::fmt::Result {
281    let in_optimized = input.optimized();
282    let acc_item = in_optimized.item();
283
284    let inclusive = Variable::tmp(acc_item);
285    reduce_inclusive(f, input, &inclusive, op)?;
286    let shfl = Variable::tmp(acc_item);
287    writeln!(f, "{} = {{", shfl.fmt_left())?;
288    for k in 0..acc_item.vectorization {
289        let inclusive_indexed = maybe_index(&inclusive, k);
290        let comma = if k > 0 { ", " } else { "" };
291        write!(f, "{comma}")?;
292        D::compile_warp_shuffle_up(f, &inclusive_indexed.to_string(), "1")?;
293    }
294    writeln!(f, "}};")?;
295    let lane_id = Variable::<D>::UnitPosPlane;
296
297    write!(
298        f,
299        "{} = ({lane_id} == 0) ? {}{{",
300        out.fmt_left(),
301        out.item(),
302    )?;
303    for _ in 0..out.item().vectorization {
304        write!(f, "{default},")?;
305    }
306    writeln!(f, "}} : {};", cast(&shfl, out.item()))
307}
308
309pub(crate) fn reduce_broadcast<D: Dialect>(
310    f: &mut core::fmt::Formatter<'_>,
311    input: &Variable<D>,
312    out: &Variable<D>,
313    id: &Variable<D>,
314) -> core::fmt::Result {
315    let out_fmt = out.fmt_left();
316    write!(f, "{out_fmt} = {{ ")?;
317    for i in 0..input.item().vectorization {
318        let comma = if i > 0 { ", " } else { "" };
319        write!(f, "{comma}")?;
320        D::compile_warp_shuffle(f, &format!("{}", input.index(i)), &format!("{id}"))?;
321    }
322    writeln!(f, " }};")
323}
324
325fn reduce_with_loop<
326    D: Dialect,
327    I: Fn(&mut core::fmt::Formatter<'_>, &Variable<D>, usize) -> std::fmt::Result,
328>(
329    f: &mut core::fmt::Formatter<'_>,
330    input: &Variable<D>,
331    out: &Variable<D>,
332    acc_item: Item<D>,
333    instruction: I,
334) -> core::fmt::Result {
335    let acc = Variable::Named {
336        name: "acc",
337        item: acc_item,
338    };
339    let vectorization = acc_item.vectorization;
340
341    writeln!(f, "auto plane_{out} = [&]() -> {} {{", out.item())?;
342    writeln!(f, "    {} {} = {};", acc_item, acc, cast(input, acc_item))?;
343    write!(f, "    for (uint offset = 1; offset < ")?;
344    D::compile_plane_dim_checked(f)?;
345    writeln!(f, "; offset *=2 ) {{")?;
346    for k in 0..vectorization {
347        instruction(f, &acc, k)?;
348    }
349    writeln!(f, "    }};")?;
350    writeln!(f, "    return {};", cast(&acc, out.item()))?;
351    writeln!(f, "}};")?;
352    writeln!(f, "{} = plane_{}();", out.fmt_left(), out)
353}
354
355pub(crate) fn reduce_quantifier<
356    D: Dialect,
357    Q: Fn(&mut core::fmt::Formatter<'_>, &IndexedVariable<D>) -> std::fmt::Result,
358>(
359    f: &mut core::fmt::Formatter<'_>,
360    input: &Variable<D>,
361    out: &Variable<D>,
362    quantifier: Q,
363) -> core::fmt::Result {
364    let out_fmt = out.fmt_left();
365    write!(f, "{out_fmt} = {{ ")?;
366    for i in 0..input.item().vectorization {
367        let comma = if i > 0 { ", " } else { "" };
368        write!(f, "{comma}")?;
369        quantifier(f, &input.index(i))?;
370    }
371    writeln!(f, "}};")
372}
373
374fn cast<D: Dialect>(input: &Variable<D>, target: Item<D>) -> String {
375    if target != input.item() {
376        let addr_space = D::address_space_for_variable(input);
377        let qualifier = input.const_qualifier();
378        format!("reinterpret_cast<{addr_space}{target}{qualifier}&>({input})")
379    } else {
380        format!("{input}")
381    }
382}
383
384fn maybe_index<D: Dialect>(var: &Variable<D>, k: usize) -> String {
385    if var.item().vectorization > 1 {
386        format!("{var}.i_{k}")
387    } else {
388        format!("{var}")
389    }
390}