cubecl_spirv/
subgroup.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
use cubecl_core::ir::Subcube;
use rspirv::spirv::{Capability, GroupOperation, Scope, Word};

use crate::{SpirvCompiler, SpirvTarget};

impl<T: SpirvTarget> SpirvCompiler<T> {
    pub fn compile_subcube(&mut self, subcube: Subcube) {
        self.capabilities
            .insert(Capability::GroupNonUniformArithmetic);
        let subgroup = self.subgroup();
        match subcube {
            Subcube::Elect(op) => {
                let out = self.compile_variable(op.out);
                let out_id = self.write_id(&out);
                let bool = self.type_bool();
                self.group_non_uniform_elect(bool, Some(out_id), subgroup)
                    .unwrap();
                self.write(&out, out_id);
            }
            Subcube::All(op) => {
                self.capabilities.insert(Capability::GroupNonUniformVote);
                self.compile_unary_op(op, |b, _, ty, input, out| {
                    b.group_non_uniform_all(ty, Some(out), subgroup, input)
                        .unwrap();
                });
            }
            Subcube::Any(op) => {
                self.capabilities.insert(Capability::GroupNonUniformVote);
                self.compile_unary_op(op, |b, _, ty, input, out| {
                    b.group_non_uniform_any(ty, Some(out), subgroup, input)
                        .unwrap();
                });
            }
            Subcube::Broadcast(op) => {
                self.capabilities.insert(Capability::GroupNonUniformBallot);
                self.compile_binary_op_no_cast(op, |b, _, ty, lhs, rhs, out| {
                    b.group_non_uniform_broadcast(ty, Some(out), subgroup, lhs, rhs)
                        .unwrap();
                });
            }
            Subcube::Sum(op) => {
                self.compile_unary_op(op, |b, out_ty, ty, input, out| {
                    match out_ty.elem() {
                        crate::item::Elem::Int(_, false) => b.group_non_uniform_i_add(
                            ty,
                            Some(out),
                            subgroup,
                            GroupOperation::Reduce,
                            input,
                            None,
                        ),
                        crate::item::Elem::Float(_) => b.group_non_uniform_f_add(
                            ty,
                            Some(out),
                            subgroup,
                            GroupOperation::Reduce,
                            input,
                            None,
                        ),
                        _ => unreachable!(),
                    }
                    .unwrap();
                });
            }
            Subcube::Prod(op) => {
                self.compile_unary_op(op, |b, out_ty, ty, input, out| {
                    match out_ty.elem() {
                        crate::item::Elem::Int(_, _) => b.group_non_uniform_i_mul(
                            ty,
                            Some(out),
                            subgroup,
                            GroupOperation::Reduce,
                            input,
                            None,
                        ),
                        crate::item::Elem::Float(_) => b.group_non_uniform_f_mul(
                            ty,
                            Some(out),
                            subgroup,
                            GroupOperation::Reduce,
                            input,
                            None,
                        ),
                        _ => unreachable!(),
                    }
                    .unwrap();
                });
            }
            Subcube::Min(op) => {
                self.compile_unary_op(op, |b, out_ty, ty, input, out| {
                    match out_ty.elem() {
                        crate::item::Elem::Int(_, false) => b.group_non_uniform_u_min(
                            ty,
                            Some(out),
                            subgroup,
                            GroupOperation::Reduce,
                            input,
                            None,
                        ),
                        crate::item::Elem::Int(_, true) => b.group_non_uniform_s_min(
                            ty,
                            Some(out),
                            subgroup,
                            GroupOperation::Reduce,
                            input,
                            None,
                        ),
                        crate::item::Elem::Float(_) => b.group_non_uniform_f_min(
                            ty,
                            Some(out),
                            subgroup,
                            GroupOperation::Reduce,
                            input,
                            None,
                        ),
                        _ => unreachable!(),
                    }
                    .unwrap();
                });
            }
            Subcube::Max(op) => {
                self.compile_unary_op(op, |b, out_ty, ty, input, out| {
                    match out_ty.elem() {
                        crate::item::Elem::Int(_, false) => b.group_non_uniform_u_max(
                            ty,
                            Some(out),
                            subgroup,
                            GroupOperation::Reduce,
                            input,
                            None,
                        ),
                        crate::item::Elem::Int(_, true) => b.group_non_uniform_s_max(
                            ty,
                            Some(out),
                            subgroup,
                            GroupOperation::Reduce,
                            input,
                            None,
                        ),
                        crate::item::Elem::Float(_) => b.group_non_uniform_f_max(
                            ty,
                            Some(out),
                            subgroup,
                            GroupOperation::Reduce,
                            input,
                            None,
                        ),
                        _ => unreachable!(),
                    }
                    .unwrap();
                });
            }
        }
    }

    fn subgroup(&mut self) -> Word {
        self.const_u32(Scope::Subgroup as u32)
    }
}