cubecl_cpp/shared/warp.rs

1use std::fmt::Display;
2
3use crate::shared::{Component, Elem, FmtLeft};
4
5use super::{Dialect, Item, Variable};
6
7#[derive(Clone, Debug)]
8pub enum WarpInstruction<D: Dialect> {
9    ReduceSum {
10        input: Variable<D>,
11        out: Variable<D>,
12    },
13    ReduceProd {
14        input: Variable<D>,
15        out: Variable<D>,
16    },
17    ReduceMax {
18        input: Variable<D>,
19        out: Variable<D>,
20    },
21    ReduceMin {
22        input: Variable<D>,
23        out: Variable<D>,
24    },
25    Elect {
26        out: Variable<D>,
27    },
28    All {
29        input: Variable<D>,
30        out: Variable<D>,
31    },
32    Any {
33        input: Variable<D>,
34        out: Variable<D>,
35    },
36    Broadcast {
37        input: Variable<D>,
38        id: Variable<D>,
39        out: Variable<D>,
40    },
41}
42
43impl<D: Dialect> Display for WarpInstruction<D> {
44    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45        match self {
46            WarpInstruction::ReduceSum { input, out } => reduce_operator(f, input, out, "+="),
47            WarpInstruction::ReduceProd { input, out } => reduce_operator(f, input, out, "*="),
48            WarpInstruction::ReduceMax { input, out } => reduce_comparison(f, input, out, "max"),
49            WarpInstruction::ReduceMin { input, out } => reduce_comparison(f, input, out, "min"),
50            WarpInstruction::All { input, out } => reduce_quantifier(f, input, out, D::warp_all),
51            WarpInstruction::Any { input, out } => reduce_quantifier(f, input, out, D::warp_any),
52            WarpInstruction::Broadcast { input, id, out } => reduce_broadcast(f, input, out, id),
53            WarpInstruction::Elect { out } => write!(
54                f,
55                "
56unsigned int mask = __activemask();
57unsigned int leader = __ffs(mask) - 1;
58{out} = threadIdx.x % warpSize == leader;
59            "
60            ),
61        }
62    }
63}
64
65fn reduce_operator<D: Dialect>(
66    f: &mut core::fmt::Formatter<'_>,
67    input: &Variable<D>,
68    out: &Variable<D>,
69    op: &str,
70) -> core::fmt::Result {
71    let in_optimized = input.optimized();
72    let acc_item = in_optimized.item();
73
74    reduce_with_loop(f, input, out, acc_item, |acc, index| {
75        let acc_indexed = maybe_index(acc, index);
76        let shfl_xor = D::warp_shuffle_xor(&acc_indexed, "offset");
77        format!("{acc_indexed} {op} {shfl_xor};")
78    })
79}
80
81fn reduce_comparison<D: Dialect>(
82    f: &mut core::fmt::Formatter<'_>,
83    input: &Variable<D>,
84    out: &Variable<D>,
85    cmp: &str,
86) -> core::fmt::Result {
87    let in_optimized = input.optimized();
88    let acc_item = in_optimized.item();
89    let instruction = match in_optimized.elem() {
90        Elem::F16 | Elem::BF16 => format!("__h{cmp}"),
91        Elem::F162 | Elem::BF162 => format!("__h{cmp}2"),
92        _ => cmp.to_string(),
93    };
94
95    reduce_with_loop(f, input, out, acc_item, |acc, index| {
96        let acc_indexed = maybe_index(acc, index);
97        let shfl_xor = D::warp_shuffle_xor(&acc_indexed, "offset");
98        format!("{acc_indexed} = {instruction}({acc_indexed}, {shfl_xor});")
99    })
100}
101
102fn reduce_broadcast<D: Dialect>(
103    f: &mut core::fmt::Formatter<'_>,
104    input: &Variable<D>,
105    out: &Variable<D>,
106    id: &Variable<D>,
107) -> core::fmt::Result {
108    let rhs = (0..input.item().vectorization)
109        .map(|k| D::warp_shuffle(&format!("{}", input.index(k)), &format!("{id}")))
110        .collect::<Vec<_>>()
111        .join(",");
112    let out_fmt = out.fmt_left();
113    writeln!(f, "{out_fmt} = {{ {rhs} }};")
114}
115
116fn reduce_with_loop<D: Dialect, I: Fn(&Variable<D>, usize) -> String>(
117    f: &mut core::fmt::Formatter<'_>,
118    input: &Variable<D>,
119    out: &Variable<D>,
120    acc_item: Item<D>,
121    instruction: I,
122) -> core::fmt::Result {
123    let acc = Variable::Named {
124        name: "acc",
125        item: acc_item,
126    };
127
128    writeln!(f, "auto plane_{out} = [&]() -> {} {{", out.item())?;
129    writeln!(f, "    {} {} = {};", acc_item, acc, cast(input, acc_item))?;
130    writeln!(
131        f,
132        "    for (int offset = 1; offset < warpSizeChecked; offset *=2 ) {{"
133    )?;
134    for k in 0..acc_item.vectorization {
135        writeln!(f, "        {}", instruction(&acc, k))?;
136    }
137    writeln!(f, "    }};")?;
138    writeln!(f, "    return {};", cast(&acc, out.item()))?;
139    writeln!(f, "}};")?;
140    writeln!(f, "{} = plane_{}();", out.fmt_left(), out)
141}
142
143fn reduce_quantifier<D: Dialect, Q: Fn(&str) -> String>(
144    f: &mut core::fmt::Formatter<'_>,
145    input: &Variable<D>,
146    out: &Variable<D>,
147    quantifier: Q,
148) -> core::fmt::Result {
149    let rhs = (0..input.item().vectorization)
150        .map(|k| quantifier(&format!("{}", input.index(k))))
151        .collect::<Vec<_>>()
152        .join(",");
153    let out_fmt = out.fmt_left();
154    writeln!(f, "{out_fmt} = {{ {rhs} }};")
155}
156
157fn cast<D: Dialect>(input: &Variable<D>, target: Item<D>) -> String {
158    if target != input.item() {
159        let qualifier = input.const_qualifier();
160        format!("reinterpret_cast<{}{}&>({})", target, qualifier, input)
161    } else {
162        format!("{}", input)
163    }
164}
165
166fn maybe_index<D: Dialect>(var: &Variable<D>, k: usize) -> String {
167    if var.item().vectorization > 1 {
168        format!("{var}.i_{k}")
169    } else {
170        format!("{var}")
171    }
172}