1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
use std::fmt::Display;

use super::Variable;

/// Warp-level reduction instructions emitted by the CUDA code generator.
///
/// Every variant combines `input` across the lanes of a warp with the operator named by the
/// variant and stores the result in `out`. The CUDA text for each variant is produced by this
/// type's `Display` implementation.
#[derive(Debug, Clone)]
pub enum WarpInstruction {
    /// Warp-wide sum (`+`) of `input`.
    ReduceSum {
        input: Variable,
        out: Variable,
    },
    /// Warp-wide product (`*`) of `input`.
    ReduceProd {
        input: Variable,
        out: Variable,
    },
    /// Warp-wide maximum (`max`) of `input`.
    ReduceMax {
        input: Variable,
        out: Variable,
    },
    /// Warp-wide minimum (`min`) of `input`.
    ReduceMin {
        input: Variable,
        out: Variable,
    },
}

impl Display for WarpInstruction {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            WarpInstruction::ReduceSum { input, out } => f.write_fmt(format_args!(
                "
{out} = {input};
                    {{
    for (int offset = warpSizeChecked / 2; offset > 0; offset /= 2) {{
        {out} += __shfl_down_sync(0xFFFFFFFF, {out}, offset);
    }}
}}
                        "
            )),
            WarpInstruction::ReduceProd { input, out } => f.write_fmt(format_args!(
                "
{out} = {input};
                    {{
    for (int offset = warpSizeChecked / 2; offset > 0; offset /= 2) {{
        {out} *= __shfl_down_sync(0xFFFFFFFF, {out}, offset);
    }}
}}
                        "
            )),
            WarpInstruction::ReduceMax { input, out } => f.write_fmt(format_args!(
                "
{out} = {input};
                {{
for (int offset = warpSizeChecked / 2; offset > 0; offset /= 2) {{
    {out} = max({out}, __shfl_down_sync(0xFFFFFFFF, {out}, offset));
}}
}}
                    "
            )),
            WarpInstruction::ReduceMin { input, out } => f.write_fmt(format_args!(
                "
{out} = {input};
                {{
for (int offset = warpSizeChecked / 2; offset > 0; offset /= 2) {{
    {out} = min({out}, __shfl_down_sync(0xFFFFFFFF, {out}, offset));
}}
}}
                    "
            )),
        }
    }
}