cubecl_spirv/
subgroup.rs

1use cubecl_core::ir::{Plane, UnaryOperator, Variable};
2use rspirv::spirv::{Capability, GroupOperation, Scope, Word};
3
4use crate::{SpirvCompiler, SpirvTarget, item::Elem};
5
6impl<T: SpirvTarget> SpirvCompiler<T> {
7    pub fn compile_plane(&mut self, plane: Plane, out: Option<Variable>, uniform: bool) {
8        self.capabilities
9            .insert(Capability::GroupNonUniformArithmetic);
10        let subgroup = self.subgroup();
11        let out = out.unwrap();
12        match plane {
13            Plane::Elect => {
14                let out = self.compile_variable(out);
15                let out_id = self.write_id(&out);
16                let bool = self.type_bool();
17                self.group_non_uniform_elect(bool, Some(out_id), subgroup)
18                    .unwrap();
19                self.write(&out, out_id);
20            }
21            Plane::All(op) => {
22                self.capabilities.insert(Capability::GroupNonUniformVote);
23                match out.line_size() {
24                    1 => {
25                        self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
26                            b.group_non_uniform_all(ty, Some(out), subgroup, input)
27                                .unwrap();
28                        });
29                    }
30                    vec => {
31                        let elem_ty = self.compile_type(op.input.ty).elem().id(self);
32                        let bool_ty = self.type_bool();
33
34                        self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
35                            let ids = (0..vec as u32)
36                                .map(|i| {
37                                    let elem_i =
38                                        b.composite_extract(elem_ty, None, input, vec![i]).unwrap();
39                                    b.group_non_uniform_all(bool_ty, None, subgroup, elem_i)
40                                        .unwrap()
41                                })
42                                .collect::<Vec<_>>();
43                            b.composite_construct(ty, Some(out), ids).unwrap();
44                        });
45                    }
46                };
47            }
48            Plane::Any(op) => {
49                self.capabilities.insert(Capability::GroupNonUniformVote);
50                match out.line_size() {
51                    1 => {
52                        self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
53                            b.group_non_uniform_any(ty, Some(out), subgroup, input)
54                                .unwrap();
55                        });
56                    }
57                    vec => {
58                        let elem_ty = self.compile_type(op.input.ty).elem().id(self);
59                        let bool_ty = self.type_bool();
60
61                        self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
62                            let ids = (0..vec as u32)
63                                .map(|i| {
64                                    let elem_i =
65                                        b.composite_extract(elem_ty, None, input, vec![i]).unwrap();
66                                    b.group_non_uniform_any(bool_ty, None, subgroup, elem_i)
67                                        .unwrap()
68                                })
69                                .collect::<Vec<_>>();
70                            b.composite_construct(ty, Some(out), ids).unwrap();
71                        });
72                    }
73                };
74            }
75            Plane::Ballot(op) => {
76                self.capabilities.insert(Capability::GroupNonUniformBallot);
77                assert_eq!(
78                    op.input.line_size(),
79                    1,
80                    "plane_ballot can't work with vectorized values"
81                );
82                self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
83                    b.group_non_uniform_ballot(ty, Some(out), subgroup, input)
84                        .unwrap();
85                });
86            }
87            Plane::Broadcast(op) => {
88                let is_broadcast = self.uniformity.is_var_uniform(op.rhs);
89                self.compile_binary_op_no_cast(op, out, uniform, |b, _, ty, lhs, rhs, out| {
90                    match is_broadcast {
91                        true => {
92                            b.capabilities.insert(Capability::GroupNonUniformBallot);
93                            b.group_non_uniform_broadcast(ty, Some(out), subgroup, lhs, rhs)
94                                .unwrap();
95                        }
96                        false => {
97                            b.capabilities.insert(Capability::GroupNonUniformShuffle);
98                            b.group_non_uniform_shuffle(ty, Some(out), subgroup, lhs, rhs)
99                                .unwrap();
100                        }
101                    }
102                });
103            }
104            Plane::Sum(op) => {
105                self.plane_sum(op, out, GroupOperation::Reduce, uniform);
106            }
107            Plane::ExclusiveSum(op) => {
108                self.plane_sum(op, out, GroupOperation::ExclusiveScan, uniform);
109            }
110            Plane::InclusiveSum(op) => {
111                self.plane_sum(op, out, GroupOperation::InclusiveScan, uniform);
112            }
113            Plane::Prod(op) => {
114                self.plane_prod(op, out, GroupOperation::Reduce, uniform);
115            }
116            Plane::ExclusiveProd(op) => {
117                self.plane_prod(op, out, GroupOperation::ExclusiveScan, uniform);
118            }
119            Plane::InclusiveProd(op) => {
120                self.plane_prod(op, out, GroupOperation::InclusiveScan, uniform);
121            }
122            Plane::Min(op) => {
123                self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
124                    match out_ty.elem() {
125                        Elem::Int(_, false) => b.group_non_uniform_u_min(
126                            ty,
127                            Some(out),
128                            subgroup,
129                            GroupOperation::Reduce,
130                            input,
131                            None,
132                        ),
133                        Elem::Int(_, true) => b.group_non_uniform_s_min(
134                            ty,
135                            Some(out),
136                            subgroup,
137                            GroupOperation::Reduce,
138                            input,
139                            None,
140                        ),
141                        Elem::Float(..) | Elem::Relaxed => b.group_non_uniform_f_min(
142                            ty,
143                            Some(out),
144                            subgroup,
145                            GroupOperation::Reduce,
146                            input,
147                            None,
148                        ),
149                        _ => unreachable!(),
150                    }
151                    .unwrap();
152                });
153            }
154            Plane::Max(op) => {
155                self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
156                    match out_ty.elem() {
157                        Elem::Int(_, false) => b.group_non_uniform_u_max(
158                            ty,
159                            Some(out),
160                            subgroup,
161                            GroupOperation::Reduce,
162                            input,
163                            None,
164                        ),
165                        Elem::Int(_, true) => b.group_non_uniform_s_max(
166                            ty,
167                            Some(out),
168                            subgroup,
169                            GroupOperation::Reduce,
170                            input,
171                            None,
172                        ),
173                        Elem::Float(..) | Elem::Relaxed => b.group_non_uniform_f_max(
174                            ty,
175                            Some(out),
176                            subgroup,
177                            GroupOperation::Reduce,
178                            input,
179                            None,
180                        ),
181                        _ => unreachable!(),
182                    }
183                    .unwrap();
184                });
185            }
186            Plane::Shuffle(op) => {
187                self.capabilities.insert(Capability::GroupNonUniformShuffle);
188                self.compile_binary_op_no_cast(op, out, uniform, |b, _, ty, lhs, rhs, out| {
189                    b.group_non_uniform_shuffle(ty, Some(out), subgroup, lhs, rhs)
190                        .unwrap();
191                });
192            }
193            Plane::ShuffleXor(op) => {
194                self.capabilities.insert(Capability::GroupNonUniformShuffle);
195                self.compile_binary_op_no_cast(op, out, uniform, |b, _, ty, lhs, rhs, out| {
196                    b.group_non_uniform_shuffle_xor(ty, Some(out), subgroup, lhs, rhs)
197                        .unwrap();
198                });
199            }
200            Plane::ShuffleUp(op) => {
201                self.capabilities.insert(Capability::GroupNonUniformShuffle);
202                self.capabilities
203                    .insert(Capability::GroupNonUniformShuffleRelative);
204                self.compile_binary_op_no_cast(op, out, uniform, |b, _, ty, lhs, rhs, out| {
205                    b.group_non_uniform_shuffle_up(ty, Some(out), subgroup, lhs, rhs)
206                        .unwrap();
207                });
208            }
209            Plane::ShuffleDown(op) => {
210                self.capabilities.insert(Capability::GroupNonUniformShuffle);
211                self.capabilities
212                    .insert(Capability::GroupNonUniformShuffleRelative);
213                self.compile_binary_op_no_cast(op, out, uniform, |b, _, ty, lhs, rhs, out| {
214                    b.group_non_uniform_shuffle_down(ty, Some(out), subgroup, lhs, rhs)
215                        .unwrap();
216                });
217            }
218        }
219    }
220
221    fn plane_sum(
222        &mut self,
223        op: UnaryOperator,
224        out: Variable,
225        action: GroupOperation,
226        uniform: bool,
227    ) {
228        let subgroup = self.subgroup();
229        self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
230            match out_ty.elem() {
231                Elem::Int(_, _) => {
232                    b.group_non_uniform_i_add(ty, Some(out), subgroup, action, input, None)
233                }
234                Elem::Float(..) | Elem::Relaxed => {
235                    b.group_non_uniform_f_add(ty, Some(out), subgroup, action, input, None)
236                }
237                elem => unreachable!("{elem}"),
238            }
239            .unwrap();
240        });
241    }
242
243    fn plane_prod(
244        &mut self,
245        op: UnaryOperator,
246        out: Variable,
247        action: GroupOperation,
248        uniform: bool,
249    ) {
250        let subgroup = self.subgroup();
251        self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
252            match out_ty.elem() {
253                Elem::Int(_, _) => {
254                    b.group_non_uniform_i_mul(ty, Some(out), subgroup, action, input, None)
255                }
256                Elem::Float(..) | Elem::Relaxed => {
257                    b.group_non_uniform_f_mul(ty, Some(out), subgroup, action, input, None)
258                }
259                _ => unreachable!(),
260            }
261            .unwrap();
262        });
263    }
264
265    fn subgroup(&mut self) -> Word {
266        self.const_u32(Scope::Subgroup as u32)
267    }
268}