cubecl_cpp/shared/warp.rs

1use std::fmt::Display;
2
3use crate::shared::{Component, FmtLeft};
4
5use super::{Dialect, IndexedVariable, Item, Variable};
6
7#[derive(Clone, Debug)]
8pub enum WarpInstruction<D: Dialect> {
9    ReduceSum {
10        input: Variable<D>,
11        out: Variable<D>,
12    },
13    InclusiveSum {
14        input: Variable<D>,
15        out: Variable<D>,
16    },
17    ExclusiveSum {
18        input: Variable<D>,
19        out: Variable<D>,
20    },
21    ReduceProd {
22        input: Variable<D>,
23        out: Variable<D>,
24    },
25    InclusiveProd {
26        input: Variable<D>,
27        out: Variable<D>,
28    },
29    ExclusiveProd {
30        input: Variable<D>,
31        out: Variable<D>,
32    },
33    ReduceMax {
34        input: Variable<D>,
35        out: Variable<D>,
36    },
37    ReduceMin {
38        input: Variable<D>,
39        out: Variable<D>,
40    },
41    Elect {
42        out: Variable<D>,
43    },
44    All {
45        input: Variable<D>,
46        out: Variable<D>,
47    },
48    Any {
49        input: Variable<D>,
50        out: Variable<D>,
51    },
52    Ballot {
53        input: Variable<D>,
54        out: Variable<D>,
55    },
56    Broadcast {
57        input: Variable<D>,
58        id: Variable<D>,
59        out: Variable<D>,
60    },
61}
62
63impl<D: Dialect> Display for WarpInstruction<D> {
64    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65        match self {
66            WarpInstruction::ReduceSum { input, out } => reduce_operator(f, input, out, "+="),
67            WarpInstruction::ReduceProd { input, out } => reduce_operator(f, input, out, "*="),
68            WarpInstruction::ReduceMax { input, out } => {
69                reduce_comparison(f, input, out, D::compile_instruction_max_function_name)
70            }
71            WarpInstruction::ReduceMin { input, out } => {
72                reduce_comparison(f, input, out, D::compile_instruction_min_function_name)
73            }
74            WarpInstruction::All { input, out } => {
75                reduce_quantifier(f, input, out, D::compile_warp_all::<IndexedVariable<D>>)
76            }
77            WarpInstruction::Any { input, out } => {
78                reduce_quantifier(f, input, out, D::compile_warp_any::<IndexedVariable<D>>)
79            }
80            WarpInstruction::Ballot { input, out } => {
81                assert_eq!(
82                    input.item().vectorization,
83                    1,
84                    "Ballot can't support vectorized input"
85                );
86                let out_fmt = out.fmt_left();
87                write!(
88                    f,
89                    "
90{out_fmt} = {{ "
91                )?;
92                D::compile_warp_ballot(f, input, out.item().elem())?;
93                writeln!(f, ", 0, 0, 0 }};")
94            }
95            WarpInstruction::Broadcast { input, id, out } => reduce_broadcast(f, input, out, id),
96            WarpInstruction::Elect { out } => write!(
97                f,
98                "
99unsigned int mask = __activemask();
100unsigned int leader = __ffs(mask) - 1;
101{out} = threadIdx.x % warpSize == leader;
102            "
103            ),
104            WarpInstruction::InclusiveSum { input, out } => reduce_inclusive(f, input, out, "+="),
105            WarpInstruction::InclusiveProd { input, out } => reduce_inclusive(f, input, out, "*="),
106            WarpInstruction::ExclusiveSum { input, out } => {
107                reduce_exclusive(f, input, out, "+=", "0")
108            }
109            WarpInstruction::ExclusiveProd { input, out } => {
110                reduce_exclusive(f, input, out, "*=", "1")
111            }
112        }
113    }
114}
115
116fn reduce_operator<D: Dialect>(
117    f: &mut core::fmt::Formatter<'_>,
118    input: &Variable<D>,
119    out: &Variable<D>,
120    op: &str,
121) -> core::fmt::Result {
122    let in_optimized = input.optimized();
123    let acc_item = in_optimized.item();
124
125    reduce_with_loop(f, input, out, acc_item, |f, acc, index| {
126        let acc_indexed = maybe_index(acc, index);
127        write!(f, "{acc_indexed} {op} ")?;
128        D::compile_warp_shuffle_xor(f, &acc_indexed, acc.item().elem(), "offset")?;
129        writeln!(f, ";")
130    })
131}
132
133fn reduce_comparison<
134    D: Dialect,
135    I: Fn(&mut core::fmt::Formatter<'_>, Item<D>) -> std::fmt::Result,
136>(
137    f: &mut core::fmt::Formatter<'_>,
138    input: &Variable<D>,
139    out: &Variable<D>,
140    instruction: I,
141) -> core::fmt::Result {
142    let in_optimized = input.optimized();
143    let acc_item = in_optimized.item();
144    reduce_with_loop(f, input, out, acc_item, |f, acc, index| {
145        let acc_indexed = maybe_index(acc, index);
146        let acc_elem = acc_item.elem();
147        write!(f, "        {acc_indexed} = ")?;
148        instruction(f, in_optimized.item())?;
149        write!(f, "({acc_indexed}, ")?;
150        D::compile_warp_shuffle_xor(f, &acc_indexed, acc_elem, "offset")?;
151        writeln!(f, ");")
152    })
153}
154
155fn reduce_inclusive<D: Dialect>(
156    f: &mut core::fmt::Formatter<'_>,
157    input: &Variable<D>,
158    out: &Variable<D>,
159    op: &str,
160) -> core::fmt::Result {
161    let in_optimized = input.optimized();
162    let acc_item = in_optimized.item();
163
164    reduce_with_loop(f, input, out, acc_item, |f, acc, index| {
165        let acc_indexed = maybe_index(acc, index);
166        let tmp = Variable::tmp(Item::scalar(acc_item.elem, false));
167        let tmp_left = tmp.fmt_left();
168        let lane_id = Variable::<D>::UnitPosPlane;
169        write!(
170            f,
171            "
172{tmp_left} = "
173        )?;
174        D::compile_warp_shuffle_up(f, &acc_indexed, "offset")?;
175        write!(
176            f,
177            ";
178if({lane_id} >= offset) {{
179    {acc_indexed} {op} {tmp};
180}}
181"
182        )
183    })
184}
185
186fn reduce_exclusive<D: Dialect>(
187    f: &mut core::fmt::Formatter<'_>,
188    input: &Variable<D>,
189    out: &Variable<D>,
190    op: &str,
191    default: &str,
192) -> core::fmt::Result {
193    let in_optimized = input.optimized();
194    let acc_item = in_optimized.item();
195
196    let inclusive = Variable::tmp(acc_item);
197    reduce_inclusive(f, input, &inclusive, op)?;
198    let shfl = Variable::tmp(acc_item);
199    writeln!(f, "{} = {{", shfl.fmt_left())?;
200    for k in 0..acc_item.vectorization {
201        let inclusive_indexed = maybe_index(&inclusive, k);
202        let comma = if k > 0 { ", " } else { "" };
203        write!(f, "{comma}")?;
204        D::compile_warp_shuffle_up(f, &inclusive_indexed.to_string(), "1")?;
205    }
206    writeln!(f, "}};")?;
207    let lane_id = Variable::<D>::UnitPosPlane;
208
209    write!(
210        f,
211        "{} = ({lane_id} == 0) ? {}{{",
212        out.fmt_left(),
213        out.item(),
214    )?;
215    for _ in 0..out.item().vectorization {
216        write!(f, "{default},")?;
217    }
218    writeln!(f, "}} : {};", cast(&shfl, out.item()))
219}
220
221fn reduce_broadcast<D: Dialect>(
222    f: &mut core::fmt::Formatter<'_>,
223    input: &Variable<D>,
224    out: &Variable<D>,
225    id: &Variable<D>,
226) -> core::fmt::Result {
227    let out_fmt = out.fmt_left();
228    write!(f, "{out_fmt} = {{ ")?;
229    for i in 0..input.item().vectorization {
230        let comma = if i > 0 { ", " } else { "" };
231        write!(f, "{comma}")?;
232        D::compile_warp_shuffle(f, &format!("{}", input.index(i)), &format!("{id}"))?;
233    }
234    writeln!(f, " }};")
235}
236
237fn reduce_with_loop<
238    D: Dialect,
239    I: Fn(&mut core::fmt::Formatter<'_>, &Variable<D>, usize) -> std::fmt::Result,
240>(
241    f: &mut core::fmt::Formatter<'_>,
242    input: &Variable<D>,
243    out: &Variable<D>,
244    acc_item: Item<D>,
245    instruction: I,
246) -> core::fmt::Result {
247    let acc = Variable::Named {
248        name: "acc",
249        item: acc_item,
250    };
251    let vectorization = acc_item.vectorization;
252
253    writeln!(f, "auto plane_{out} = [&]() -> {} {{", out.item())?;
254    writeln!(f, "    {} {} = {};", acc_item, acc, cast(input, acc_item))?;
255    write!(f, "    for (uint offset = 1; offset < ")?;
256    D::compile_plane_dim_checked(f)?;
257    writeln!(f, "; offset *=2 ) {{")?;
258    for k in 0..vectorization {
259        instruction(f, &acc, k)?;
260    }
261    writeln!(f, "    }};")?;
262    writeln!(f, "    return {};", cast(&acc, out.item()))?;
263    writeln!(f, "}};")?;
264    writeln!(f, "{} = plane_{}();", out.fmt_left(), out)
265}
266
267fn reduce_quantifier<
268    D: Dialect,
269    Q: Fn(&mut core::fmt::Formatter<'_>, &IndexedVariable<D>) -> std::fmt::Result,
270>(
271    f: &mut core::fmt::Formatter<'_>,
272    input: &Variable<D>,
273    out: &Variable<D>,
274    quantifier: Q,
275) -> core::fmt::Result {
276    let out_fmt = out.fmt_left();
277    write!(f, "{out_fmt} = {{ ")?;
278    for i in 0..input.item().vectorization {
279        let comma = if i > 0 { ", " } else { "" };
280        write!(f, "{comma}")?;
281        quantifier(f, &input.index(i))?;
282    }
283    writeln!(f, "}};")
284}
285
286fn cast<D: Dialect>(input: &Variable<D>, target: Item<D>) -> String {
287    if target != input.item() {
288        let addr_space = D::address_space_for_variable(input);
289        let qualifier = input.const_qualifier();
290        format!("reinterpret_cast<{addr_space}{target}{qualifier}&>({input})")
291    } else {
292        format!("{}", input)
293    }
294}
295
296fn maybe_index<D: Dialect>(var: &Variable<D>, k: usize) -> String {
297    if var.item().vectorization > 1 {
298        format!("{var}.i_{k}")
299    } else {
300        format!("{var}")
301    }
302}