cubecl_cpp/shared/
warp.rs

1use std::fmt::Display;
2
3use crate::shared::{Component, FmtLeft};
4
5use super::{Dialect, IndexedVariable, Item, Variable};
6
/// Plane (warp) level instructions that are rendered to source text for a
/// [`Dialect`] `D` through this type's `Display` implementation.
///
/// Every variant names the variable that receives the result in `out`; most
/// also carry the `input` variable the instruction reads.
#[derive(Clone, Debug)]
pub enum WarpInstruction<D: Dialect> {
    /// Sum reduction, rendered by `Dialect::warp_reduce_sum`.
    ReduceSum {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Inclusive sum, rendered by `Dialect::warp_reduce_sum_inclusive`.
    InclusiveSum {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Exclusive sum, rendered by `Dialect::warp_reduce_sum_exclusive`.
    ExclusiveSum {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Product reduction, rendered by `Dialect::warp_reduce_prod`.
    ReduceProd {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Inclusive product, rendered by `Dialect::warp_reduce_prod_inclusive`.
    InclusiveProd {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Exclusive product, rendered by `Dialect::warp_reduce_prod_exclusive`.
    ExclusiveProd {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Maximum reduction, rendered by `Dialect::warp_reduce_max`.
    ReduceMax {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Minimum reduction, rendered by `Dialect::warp_reduce_min`.
    ReduceMin {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Leader election: the rendered code assigns to `out` whether
    /// `threadIdx.x % warpSize` equals the index of the first bit set in
    /// `__activemask()`.
    Elect {
        out: Variable<D>,
    },
    /// "All" quantifier, rendered by `Dialect::warp_reduce_all`.
    All {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// "Any" quantifier, rendered by `Dialect::warp_reduce_any`.
    Any {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Ballot, rendered by `Dialect::compile_warp_ballot`. Rendering panics if
    /// `input` is vectorized. The result is written to `out` as a four-element
    /// initializer whose last three elements are `0`.
    Ballot {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Broadcast: each component of `input` is passed with `id` to
    /// `Dialect::compile_warp_shuffle`.
    Broadcast {
        input: Variable<D>,
        id: Variable<D>,
        out: Variable<D>,
    },
    /// Shuffle: each component of `input` is passed with `src_lane` to
    /// `Dialect::compile_warp_shuffle`.
    Shuffle {
        input: Variable<D>,
        src_lane: Variable<D>,
        out: Variable<D>,
    },
    /// Xor shuffle: each component of `input` is passed with `mask` to
    /// `Dialect::compile_warp_shuffle_xor`.
    ShuffleXor {
        input: Variable<D>,
        mask: Variable<D>,
        out: Variable<D>,
    },
    /// Upward shuffle: each component of `input` is passed with `delta` to
    /// `Dialect::compile_warp_shuffle_up`.
    ShuffleUp {
        input: Variable<D>,
        delta: Variable<D>,
        out: Variable<D>,
    },
    /// Downward shuffle: each component of `input` is passed with `delta` to
    /// `Dialect::compile_warp_shuffle_down`.
    ShuffleDown {
        input: Variable<D>,
        delta: Variable<D>,
        out: Variable<D>,
    },
}
82
83impl<D: Dialect> Display for WarpInstruction<D> {
84    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85        match self {
86            WarpInstruction::ReduceSum { input, out } => D::warp_reduce_sum(f, input, out),
87            WarpInstruction::ReduceProd { input, out } => D::warp_reduce_prod(f, input, out),
88            WarpInstruction::ReduceMax { input, out } => D::warp_reduce_max(f, input, out),
89            WarpInstruction::ReduceMin { input, out } => D::warp_reduce_min(f, input, out),
90            WarpInstruction::All { input, out } => D::warp_reduce_all(f, input, out),
91            WarpInstruction::Any { input, out } => D::warp_reduce_any(f, input, out),
92
93            WarpInstruction::InclusiveSum { input, out } => {
94                D::warp_reduce_sum_inclusive(f, input, out)
95            }
96            WarpInstruction::InclusiveProd { input, out } => {
97                D::warp_reduce_prod_inclusive(f, input, out)
98            }
99            WarpInstruction::ExclusiveSum { input, out } => {
100                D::warp_reduce_sum_exclusive(f, input, out)
101            }
102            WarpInstruction::ExclusiveProd { input, out } => {
103                D::warp_reduce_prod_exclusive(f, input, out)
104            }
105            WarpInstruction::Ballot { input, out } => {
106                assert_eq!(
107                    input.item().vectorization,
108                    1,
109                    "Ballot can't support vectorized input"
110                );
111                let out_fmt = out.fmt_left();
112                write!(
113                    f,
114                    "
115{out_fmt} = {{ "
116                )?;
117                D::compile_warp_ballot(f, input, out.item().elem())?;
118                writeln!(f, ", 0, 0, 0 }};")
119            }
120            WarpInstruction::Broadcast { input, id, out } => reduce_broadcast(f, input, out, id),
121            WarpInstruction::Shuffle {
122                input,
123                src_lane,
124                out,
125            } => {
126                let out_fmt = out.fmt_left();
127                write!(f, "{out_fmt} = {{ ")?;
128                for i in 0..input.item().vectorization {
129                    let comma = if i > 0 { ", " } else { "" };
130                    write!(f, "{comma}")?;
131                    D::compile_warp_shuffle(
132                        f,
133                        &format!("{}", input.index(i)),
134                        &format!("{src_lane}"),
135                    )?;
136                }
137                writeln!(f, " }};")
138            }
139            WarpInstruction::ShuffleXor { input, mask, out } => {
140                let out_fmt = out.fmt_left();
141                write!(f, "{out_fmt} = {{ ")?;
142                for i in 0..input.item().vectorization {
143                    let comma = if i > 0 { ", " } else { "" };
144                    write!(f, "{comma}")?;
145                    D::compile_warp_shuffle_xor(
146                        f,
147                        &format!("{}", input.index(i)),
148                        input.item().elem(),
149                        &format!("{mask}"),
150                    )?;
151                }
152                writeln!(f, " }};")
153            }
154            WarpInstruction::ShuffleUp { input, delta, out } => {
155                let out_fmt = out.fmt_left();
156                write!(f, "{out_fmt} = {{ ")?;
157                for i in 0..input.item().vectorization {
158                    let comma = if i > 0 { ", " } else { "" };
159                    write!(f, "{comma}")?;
160                    D::compile_warp_shuffle_up(
161                        f,
162                        &format!("{}", input.index(i)),
163                        &format!("{delta}"),
164                    )?;
165                }
166                writeln!(f, " }};")
167            }
168            WarpInstruction::ShuffleDown { input, delta, out } => {
169                let out_fmt = out.fmt_left();
170                write!(f, "{out_fmt} = {{ ")?;
171                for i in 0..input.item().vectorization {
172                    let comma = if i > 0 { ", " } else { "" };
173                    write!(f, "{comma}")?;
174                    D::compile_warp_shuffle_down(
175                        f,
176                        &format!("{}", input.index(i)),
177                        &format!("{delta}"),
178                    )?;
179                }
180                writeln!(f, " }};")
181            }
182            WarpInstruction::Elect { out } => write!(
183                f,
184                "
185unsigned int mask = __activemask();
186unsigned int leader = __ffs(mask) - 1;
187{out} = threadIdx.x % warpSize == leader;
188            "
189            ),
190        }
191    }
192}
193
194pub(crate) fn reduce_operator<D: Dialect>(
195    f: &mut core::fmt::Formatter<'_>,
196    input: &Variable<D>,
197    out: &Variable<D>,
198    op: &str,
199) -> core::fmt::Result {
200    let in_optimized = input.optimized();
201    let acc_item = in_optimized.item();
202
203    reduce_with_loop(f, input, out, acc_item, |f, acc, index| {
204        let acc_indexed = maybe_index(acc, index);
205        write!(f, "{acc_indexed} {op} ")?;
206        D::compile_warp_shuffle_xor(f, &acc_indexed, acc.item().elem(), "offset")?;
207        writeln!(f, ";")
208    })
209}
210
211pub(crate) fn reduce_comparison<
212    D: Dialect,
213    I: Fn(&mut core::fmt::Formatter<'_>, Item<D>) -> std::fmt::Result,
214>(
215    f: &mut core::fmt::Formatter<'_>,
216    input: &Variable<D>,
217    out: &Variable<D>,
218    instruction: I,
219) -> core::fmt::Result {
220    let in_optimized = input.optimized();
221    let acc_item = in_optimized.item();
222    reduce_with_loop(f, input, out, acc_item, |f, acc, index| {
223        let acc_indexed = maybe_index(acc, index);
224        let acc_elem = acc_item.elem();
225        write!(f, "        {acc_indexed} = ")?;
226        instruction(f, in_optimized.item())?;
227        write!(f, "({acc_indexed}, ")?;
228        D::compile_warp_shuffle_xor(f, &acc_indexed, acc_elem, "offset")?;
229        writeln!(f, ");")
230    })
231}
232
233pub(crate) fn reduce_inclusive<D: Dialect>(
234    f: &mut core::fmt::Formatter<'_>,
235    input: &Variable<D>,
236    out: &Variable<D>,
237    op: &str,
238) -> core::fmt::Result {
239    let in_optimized = input.optimized();
240    let acc_item = in_optimized.item();
241
242    reduce_with_loop(f, input, out, acc_item, |f, acc, index| {
243        let acc_indexed = maybe_index(acc, index);
244        let tmp = Variable::tmp(Item::scalar(acc_item.elem, false));
245        let tmp_left = tmp.fmt_left();
246        let lane_id = Variable::<D>::UnitPosPlane;
247        write!(
248            f,
249            "
250{tmp_left} = "
251        )?;
252        D::compile_warp_shuffle_up(f, &acc_indexed, "offset")?;
253        write!(
254            f,
255            ";
256if({lane_id} >= offset) {{
257    {acc_indexed} {op} {tmp};
258}}
259"
260        )
261    })
262}
263
264pub(crate) fn reduce_exclusive<D: Dialect>(
265    f: &mut core::fmt::Formatter<'_>,
266    input: &Variable<D>,
267    out: &Variable<D>,
268    op: &str,
269    default: &str,
270) -> core::fmt::Result {
271    let in_optimized = input.optimized();
272    let acc_item = in_optimized.item();
273
274    let inclusive = Variable::tmp(acc_item);
275    reduce_inclusive(f, input, &inclusive, op)?;
276    let shfl = Variable::tmp(acc_item);
277    writeln!(f, "{} = {{", shfl.fmt_left())?;
278    for k in 0..acc_item.vectorization {
279        let inclusive_indexed = maybe_index(&inclusive, k);
280        let comma = if k > 0 { ", " } else { "" };
281        write!(f, "{comma}")?;
282        D::compile_warp_shuffle_up(f, &inclusive_indexed.to_string(), "1")?;
283    }
284    writeln!(f, "}};")?;
285    let lane_id = Variable::<D>::UnitPosPlane;
286
287    write!(
288        f,
289        "{} = ({lane_id} == 0) ? {}{{",
290        out.fmt_left(),
291        out.item(),
292    )?;
293    for _ in 0..out.item().vectorization {
294        write!(f, "{default},")?;
295    }
296    writeln!(f, "}} : {};", cast(&shfl, out.item()))
297}
298
299pub(crate) fn reduce_broadcast<D: Dialect>(
300    f: &mut core::fmt::Formatter<'_>,
301    input: &Variable<D>,
302    out: &Variable<D>,
303    id: &Variable<D>,
304) -> core::fmt::Result {
305    let out_fmt = out.fmt_left();
306    write!(f, "{out_fmt} = {{ ")?;
307    for i in 0..input.item().vectorization {
308        let comma = if i > 0 { ", " } else { "" };
309        write!(f, "{comma}")?;
310        D::compile_warp_shuffle(f, &format!("{}", input.index(i)), &format!("{id}"))?;
311    }
312    writeln!(f, " }};")
313}
314
315fn reduce_with_loop<
316    D: Dialect,
317    I: Fn(&mut core::fmt::Formatter<'_>, &Variable<D>, usize) -> std::fmt::Result,
318>(
319    f: &mut core::fmt::Formatter<'_>,
320    input: &Variable<D>,
321    out: &Variable<D>,
322    acc_item: Item<D>,
323    instruction: I,
324) -> core::fmt::Result {
325    let acc = Variable::Named {
326        name: "acc",
327        item: acc_item,
328    };
329    let vectorization = acc_item.vectorization;
330
331    writeln!(f, "auto plane_{out} = [&]() -> {} {{", out.item())?;
332    writeln!(f, "    {} {} = {};", acc_item, acc, cast(input, acc_item))?;
333    write!(f, "    for (uint offset = 1; offset < ")?;
334    D::compile_plane_dim_checked(f)?;
335    writeln!(f, "; offset *=2 ) {{")?;
336    for k in 0..vectorization {
337        instruction(f, &acc, k)?;
338    }
339    writeln!(f, "    }};")?;
340    writeln!(f, "    return {};", cast(&acc, out.item()))?;
341    writeln!(f, "}};")?;
342    writeln!(f, "{} = plane_{}();", out.fmt_left(), out)
343}
344
345pub(crate) fn reduce_quantifier<
346    D: Dialect,
347    Q: Fn(&mut core::fmt::Formatter<'_>, &IndexedVariable<D>) -> std::fmt::Result,
348>(
349    f: &mut core::fmt::Formatter<'_>,
350    input: &Variable<D>,
351    out: &Variable<D>,
352    quantifier: Q,
353) -> core::fmt::Result {
354    let out_fmt = out.fmt_left();
355    write!(f, "{out_fmt} = {{ ")?;
356    for i in 0..input.item().vectorization {
357        let comma = if i > 0 { ", " } else { "" };
358        write!(f, "{comma}")?;
359        quantifier(f, &input.index(i))?;
360    }
361    writeln!(f, "}};")
362}
363
364fn cast<D: Dialect>(input: &Variable<D>, target: Item<D>) -> String {
365    if target != input.item() {
366        let addr_space = D::address_space_for_variable(input);
367        let qualifier = input.const_qualifier();
368        format!("reinterpret_cast<{addr_space}{target}{qualifier}&>({input})")
369    } else {
370        format!("{input}")
371    }
372}
373
374fn maybe_index<D: Dialect>(var: &Variable<D>, k: usize) -> String {
375    if var.item().vectorization > 1 {
376        format!("{var}.i_{k}")
377    } else {
378        format!("{var}")
379    }
380}