cubecl_spirv/
subgroup.rs

1use cubecl_core::ir::{Plane, UnaryOperator, Variable};
2use rspirv::spirv::{Capability, GroupOperation, Scope, Word};
3
4use crate::{SpirvCompiler, SpirvTarget, item::Elem};
5
6impl<T: SpirvTarget> SpirvCompiler<T> {
7    pub fn compile_plane(&mut self, plane: Plane, out: Option<Variable>, uniform: bool) {
8        self.capabilities
9            .insert(Capability::GroupNonUniformArithmetic);
10        let subgroup = self.subgroup();
11        let out = out.unwrap();
12        match plane {
13            Plane::Elect => {
14                let out = self.compile_variable(out);
15                let out_id = self.write_id(&out);
16                let bool = self.type_bool();
17                self.group_non_uniform_elect(bool, Some(out_id), subgroup)
18                    .unwrap();
19                self.write(&out, out_id);
20            }
21            Plane::All(op) => {
22                self.capabilities.insert(Capability::GroupNonUniformVote);
23                match out.line_size() {
24                    1 => {
25                        self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
26                            b.group_non_uniform_all(ty, Some(out), subgroup, input)
27                                .unwrap();
28                        });
29                    }
30                    vec => {
31                        let elem_ty = self.compile_type(op.input.ty).elem().id(self);
32                        let bool_ty = self.type_bool();
33
34                        self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
35                            let ids = (0..vec)
36                                .map(|i| {
37                                    let elem_i =
38                                        b.composite_extract(elem_ty, None, input, vec![i]).unwrap();
39                                    b.group_non_uniform_all(bool_ty, None, subgroup, elem_i)
40                                        .unwrap()
41                                })
42                                .collect::<Vec<_>>();
43                            b.composite_construct(ty, Some(out), ids).unwrap();
44                        });
45                    }
46                };
47            }
48            Plane::Any(op) => {
49                self.capabilities.insert(Capability::GroupNonUniformVote);
50                match out.line_size() {
51                    1 => {
52                        self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
53                            b.group_non_uniform_any(ty, Some(out), subgroup, input)
54                                .unwrap();
55                        });
56                    }
57                    vec => {
58                        let elem_ty = self.compile_type(op.input.ty).elem().id(self);
59                        let bool_ty = self.type_bool();
60
61                        self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
62                            let ids = (0..vec)
63                                .map(|i| {
64                                    let elem_i =
65                                        b.composite_extract(elem_ty, None, input, vec![i]).unwrap();
66                                    b.group_non_uniform_any(bool_ty, None, subgroup, elem_i)
67                                        .unwrap()
68                                })
69                                .collect::<Vec<_>>();
70                            b.composite_construct(ty, Some(out), ids).unwrap();
71                        });
72                    }
73                };
74            }
75            Plane::Ballot(op) => {
76                self.capabilities.insert(Capability::GroupNonUniformBallot);
77                assert_eq!(
78                    op.input.line_size(),
79                    1,
80                    "plane_ballot can't work with vectorized values"
81                );
82                self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
83                    b.group_non_uniform_ballot(ty, Some(out), subgroup, input)
84                        .unwrap();
85                });
86            }
87            Plane::Broadcast(op) => {
88                let is_broadcast = self.uniformity.is_var_uniform(op.rhs);
89                self.compile_binary_op_no_cast(op, out, uniform, |b, _, ty, lhs, rhs, out| {
90                    match is_broadcast {
91                        true => {
92                            b.capabilities.insert(Capability::GroupNonUniformBallot);
93                            b.group_non_uniform_broadcast(ty, Some(out), subgroup, lhs, rhs)
94                                .unwrap();
95                        }
96                        false => {
97                            b.capabilities.insert(Capability::GroupNonUniformShuffle);
98                            b.group_non_uniform_shuffle(ty, Some(out), subgroup, lhs, rhs)
99                                .unwrap();
100                        }
101                    }
102                });
103            }
104            Plane::Sum(op) => {
105                self.plane_sum(op, out, GroupOperation::Reduce, uniform);
106            }
107            Plane::ExclusiveSum(op) => {
108                self.plane_sum(op, out, GroupOperation::ExclusiveScan, uniform);
109            }
110            Plane::InclusiveSum(op) => {
111                self.plane_sum(op, out, GroupOperation::InclusiveScan, uniform);
112            }
113            Plane::Prod(op) => {
114                self.plane_prod(op, out, GroupOperation::Reduce, uniform);
115            }
116            Plane::ExclusiveProd(op) => {
117                self.plane_prod(op, out, GroupOperation::ExclusiveScan, uniform);
118            }
119            Plane::InclusiveProd(op) => {
120                self.plane_prod(op, out, GroupOperation::InclusiveScan, uniform);
121            }
122            Plane::Min(op) => {
123                self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
124                    match out_ty.elem() {
125                        Elem::Int(_, false) => b.group_non_uniform_u_min(
126                            ty,
127                            Some(out),
128                            subgroup,
129                            GroupOperation::Reduce,
130                            input,
131                            None,
132                        ),
133                        Elem::Int(_, true) => b.group_non_uniform_s_min(
134                            ty,
135                            Some(out),
136                            subgroup,
137                            GroupOperation::Reduce,
138                            input,
139                            None,
140                        ),
141                        Elem::Float(..) | Elem::Relaxed => b.group_non_uniform_f_min(
142                            ty,
143                            Some(out),
144                            subgroup,
145                            GroupOperation::Reduce,
146                            input,
147                            None,
148                        ),
149                        _ => unreachable!(),
150                    }
151                    .unwrap();
152                });
153            }
154            Plane::Max(op) => {
155                self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
156                    match out_ty.elem() {
157                        Elem::Int(_, false) => b.group_non_uniform_u_max(
158                            ty,
159                            Some(out),
160                            subgroup,
161                            GroupOperation::Reduce,
162                            input,
163                            None,
164                        ),
165                        Elem::Int(_, true) => b.group_non_uniform_s_max(
166                            ty,
167                            Some(out),
168                            subgroup,
169                            GroupOperation::Reduce,
170                            input,
171                            None,
172                        ),
173                        Elem::Float(..) | Elem::Relaxed => b.group_non_uniform_f_max(
174                            ty,
175                            Some(out),
176                            subgroup,
177                            GroupOperation::Reduce,
178                            input,
179                            None,
180                        ),
181                        _ => unreachable!(),
182                    }
183                    .unwrap();
184                });
185            }
186            Plane::Shuffle(op) => {
187                self.capabilities.insert(Capability::GroupNonUniformShuffle);
188                self.compile_binary_op_no_cast(op, out, uniform, |b, _, ty, lhs, rhs, out| {
189                    b.group_non_uniform_shuffle(ty, Some(out), subgroup, lhs, rhs)
190                        .unwrap();
191                });
192            }
193            Plane::ShuffleXor(op) => {
194                self.capabilities.insert(Capability::GroupNonUniformShuffle);
195                self.compile_binary_op_no_cast(op, out, uniform, |b, _, ty, lhs, rhs, out| {
196                    b.group_non_uniform_shuffle_xor(ty, Some(out), subgroup, lhs, rhs)
197                        .unwrap();
198                });
199            }
200            Plane::ShuffleUp(op) => {
201                self.capabilities.insert(Capability::GroupNonUniformShuffle);
202                self.compile_binary_op_no_cast(op, out, uniform, |b, _, ty, lhs, rhs, out| {
203                    b.group_non_uniform_shuffle_up(ty, Some(out), subgroup, lhs, rhs)
204                        .unwrap();
205                });
206            }
207            Plane::ShuffleDown(op) => {
208                self.capabilities.insert(Capability::GroupNonUniformShuffle);
209                self.compile_binary_op_no_cast(op, out, uniform, |b, _, ty, lhs, rhs, out| {
210                    b.group_non_uniform_shuffle_down(ty, Some(out), subgroup, lhs, rhs)
211                        .unwrap();
212                });
213            }
214        }
215    }
216
217    fn plane_sum(
218        &mut self,
219        op: UnaryOperator,
220        out: Variable,
221        action: GroupOperation,
222        uniform: bool,
223    ) {
224        let subgroup = self.subgroup();
225        self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
226            match out_ty.elem() {
227                Elem::Int(_, _) => {
228                    b.group_non_uniform_i_add(ty, Some(out), subgroup, action, input, None)
229                }
230                Elem::Float(..) | Elem::Relaxed => {
231                    b.group_non_uniform_f_add(ty, Some(out), subgroup, action, input, None)
232                }
233                elem => unreachable!("{elem}"),
234            }
235            .unwrap();
236        });
237    }
238
239    fn plane_prod(
240        &mut self,
241        op: UnaryOperator,
242        out: Variable,
243        action: GroupOperation,
244        uniform: bool,
245    ) {
246        let subgroup = self.subgroup();
247        self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
248            match out_ty.elem() {
249                Elem::Int(_, _) => {
250                    b.group_non_uniform_i_mul(ty, Some(out), subgroup, action, input, None)
251                }
252                Elem::Float(..) | Elem::Relaxed => {
253                    b.group_non_uniform_f_mul(ty, Some(out), subgroup, action, input, None)
254                }
255                _ => unreachable!(),
256            }
257            .unwrap();
258        });
259    }
260
261    fn subgroup(&mut self) -> Word {
262        self.const_u32(Scope::Subgroup as u32)
263    }
264}