cubecl_spirv/
atomic.rs

1use cubecl_core::ir::{AtomicOp, Variable};
2use rspirv::spirv::{Capability, MemorySemantics, Scope};
3
4use crate::{SpirvCompiler, SpirvTarget, item::Elem};
5
6impl<T: SpirvTarget> SpirvCompiler<T> {
7    pub fn compile_atomic(&mut self, atomic: AtomicOp, out: Option<Variable>) {
8        let out = out.unwrap();
9        match atomic {
10            AtomicOp::Load(op) => {
11                let input = self.compile_variable(op.input);
12                let out = self.compile_variable(out);
13                let out_ty = out.item();
14
15                let input_id = input.id(self);
16                let out_id = self.write_id(&out);
17
18                let ty = out_ty.id(self);
19                let memory = self.const_u32(Scope::Device as u32);
20                let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());
21
22                self.atomic_load(ty, Some(out_id), input_id, memory, semantics)
23                    .unwrap();
24                self.write(&out, out_id);
25            }
26            AtomicOp::Store(op) => {
27                let input = self.compile_variable(op.input);
28                let out = self.compile_variable(out);
29
30                let input_id = self.read(&input);
31                let out_id = out.id(self);
32
33                let memory = self.const_u32(Scope::Device as u32);
34                let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());
35
36                self.atomic_store(out_id, memory, semantics, input_id)
37                    .unwrap();
38            }
39            AtomicOp::Swap(op) => {
40                let lhs = self.compile_variable(op.lhs);
41                let rhs = self.compile_variable(op.rhs);
42                let out = self.compile_variable(out);
43                let out_ty = out.item();
44
45                let lhs_id = lhs.id(self);
46                let rhs_id = self.read(&rhs);
47                let out_id = self.write_id(&out);
48
49                let ty = out_ty.id(self);
50                let memory = self.const_u32(Scope::Device as u32);
51                let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());
52
53                self.atomic_exchange(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
54                    .unwrap();
55                self.write(&out, out_id);
56            }
57            AtomicOp::CompareAndSwap(op) => {
58                let atomic = self.compile_variable(op.input);
59                let cmp = self.compile_variable(op.cmp);
60                let val = self.compile_variable(op.val);
61                let out = self.compile_variable(out);
62                let out_ty = out.item();
63
64                let atomic_id = atomic.id(self);
65                let cmp_id = self.read(&cmp);
66                let val_id = self.read(&val);
67                let out_id = self.write_id(&out);
68
69                let ty = out_ty.id(self);
70                let memory = self.const_u32(Scope::Device as u32);
71                let semantics_success = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());
72                let semantics_failure = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());
73
74                assert!(
75                    matches!(out_ty.elem(), Elem::Int(_, _)),
76                    "compare and swap doesn't support float atomics"
77                );
78                self.atomic_compare_exchange(
79                    ty,
80                    Some(out_id),
81                    atomic_id,
82                    memory,
83                    semantics_success,
84                    semantics_failure,
85                    val_id,
86                    cmp_id,
87                )
88                .unwrap();
89                self.write(&out, out_id);
90            }
91            AtomicOp::Add(op) => {
92                let lhs = self.compile_variable(op.lhs);
93                let rhs = self.compile_variable(op.rhs);
94                let out = self.compile_variable(out);
95                let out_ty = out.item();
96
97                let lhs_id = lhs.id(self);
98                let rhs_id = self.read(&rhs);
99                let out_id = self.write_id(&out);
100
101                let ty = out_ty.id(self);
102                let memory = self.const_u32(Scope::Device as u32);
103                let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());
104
105                match out_ty.elem() {
106                    Elem::Int(_, _) => self
107                        .atomic_i_add(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
108                        .unwrap(),
109                    Elem::Float(width, None) => {
110                        match width {
111                            16 => self.capabilities.insert(Capability::AtomicFloat16AddEXT),
112                            32 => self.capabilities.insert(Capability::AtomicFloat32AddEXT),
113                            64 => self.capabilities.insert(Capability::AtomicFloat64AddEXT),
114                            _ => unreachable!(),
115                        };
116                        self.atomic_f_add_ext(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
117                            .unwrap()
118                    }
119                    _ => unreachable!(),
120                };
121
122                self.write(&out, out_id);
123            }
124            AtomicOp::Sub(op) => {
125                let lhs = self.compile_variable(op.lhs);
126                let rhs = self.compile_variable(op.rhs);
127                let out = self.compile_variable(out);
128                let out_ty = out.item();
129
130                let lhs_id = lhs.id(self);
131                let rhs_id = self.read(&rhs);
132                let out_id = self.write_id(&out);
133
134                let ty = out_ty.id(self);
135                let memory = self.const_u32(Scope::Device as u32);
136                let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());
137
138                assert!(
139                    matches!(out_ty.elem(), Elem::Int(_, _)),
140                    "sub doesn't support float atomics"
141                );
142                match out_ty.elem() {
143                    Elem::Int(_, _) => self
144                        .atomic_i_sub(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
145                        .unwrap(),
146                    Elem::Float(width, None) => {
147                        match width {
148                            16 => self.capabilities.insert(Capability::AtomicFloat16AddEXT),
149                            32 => self.capabilities.insert(Capability::AtomicFloat32AddEXT),
150                            64 => self.capabilities.insert(Capability::AtomicFloat64AddEXT),
151                            _ => unreachable!(),
152                        };
153                        let negated = self.f_negate(ty, None, rhs_id).unwrap();
154                        self.atomic_f_add_ext(ty, Some(out_id), lhs_id, memory, semantics, negated)
155                            .unwrap()
156                    }
157                    _ => unreachable!(),
158                };
159                self.atomic_i_sub(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
160                    .unwrap();
161                self.write(&out, out_id);
162            }
163            AtomicOp::Max(op) => {
164                let lhs = self.compile_variable(op.lhs);
165                let rhs = self.compile_variable(op.rhs);
166                let out = self.compile_variable(out);
167                let out_ty = out.item();
168
169                let lhs_id = lhs.id(self);
170                let rhs_id = self.read(&rhs);
171                let out_id = self.write_id(&out);
172
173                let ty = out_ty.id(self);
174                let memory = self.const_u32(Scope::Device as u32);
175                let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());
176
177                match out_ty.elem() {
178                    Elem::Int(_, false) => self
179                        .atomic_u_max(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
180                        .unwrap(),
181                    Elem::Int(_, true) => self
182                        .atomic_s_max(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
183                        .unwrap(),
184                    Elem::Float(width, None) => {
185                        match width {
186                            16 => self.capabilities.insert(Capability::AtomicFloat16MinMaxEXT),
187                            32 => self.capabilities.insert(Capability::AtomicFloat32MinMaxEXT),
188                            64 => self.capabilities.insert(Capability::AtomicFloat64MinMaxEXT),
189                            _ => unreachable!(),
190                        };
191                        self.atomic_f_max_ext(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
192                            .unwrap()
193                    }
194                    _ => unreachable!(),
195                };
196                self.write(&out, out_id);
197            }
198            AtomicOp::Min(op) => {
199                let lhs = self.compile_variable(op.lhs);
200                let rhs = self.compile_variable(op.rhs);
201                let out = self.compile_variable(out);
202                let out_ty = out.item();
203
204                let lhs_id = lhs.id(self);
205                let rhs_id = self.read(&rhs);
206                let out_id = self.write_id(&out);
207
208                let ty = out_ty.id(self);
209                let memory = self.const_u32(Scope::Device as u32);
210                let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());
211
212                match out_ty.elem() {
213                    Elem::Int(_, false) => self
214                        .atomic_u_min(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
215                        .unwrap(),
216                    Elem::Int(_, true) => self
217                        .atomic_s_min(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
218                        .unwrap(),
219                    Elem::Float(width, None) => {
220                        match width {
221                            16 => self.capabilities.insert(Capability::AtomicFloat16MinMaxEXT),
222                            32 => self.capabilities.insert(Capability::AtomicFloat32MinMaxEXT),
223                            64 => self.capabilities.insert(Capability::AtomicFloat64MinMaxEXT),
224                            _ => unreachable!(),
225                        };
226                        self.atomic_f_min_ext(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
227                            .unwrap()
228                    }
229                    _ => unreachable!(),
230                };
231                self.write(&out, out_id);
232            }
233            AtomicOp::And(op) => {
234                let lhs = self.compile_variable(op.lhs);
235                let rhs = self.compile_variable(op.rhs);
236                let out = self.compile_variable(out);
237                let out_ty = out.item();
238
239                let lhs_id = lhs.id(self);
240                let rhs_id = self.read(&rhs);
241                let out_id = self.write_id(&out);
242
243                let ty = out_ty.id(self);
244                let memory = self.const_u32(Scope::Device as u32);
245                let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());
246
247                assert!(
248                    matches!(out_ty.elem(), Elem::Int(_, _)),
249                    "and doesn't support float atomics"
250                );
251                self.atomic_and(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
252                    .unwrap();
253                self.write(&out, out_id);
254            }
255            AtomicOp::Or(op) => {
256                let lhs = self.compile_variable(op.lhs);
257                let rhs = self.compile_variable(op.rhs);
258                let out = self.compile_variable(out);
259                let out_ty = out.item();
260
261                let lhs_id = lhs.id(self);
262                let rhs_id = self.read(&rhs);
263                let out_id = self.write_id(&out);
264
265                let ty = out_ty.id(self);
266                let memory = self.const_u32(Scope::Device as u32);
267                let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());
268
269                assert!(
270                    matches!(out_ty.elem(), Elem::Int(_, _)),
271                    "or doesn't support float atomics"
272                );
273                self.atomic_or(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
274                    .unwrap();
275                self.write(&out, out_id);
276            }
277            AtomicOp::Xor(op) => {
278                let lhs = self.compile_variable(op.lhs);
279                let rhs = self.compile_variable(op.rhs);
280                let out = self.compile_variable(out);
281                let out_ty = out.item();
282
283                let lhs_id = lhs.id(self);
284                let rhs_id = self.read(&rhs);
285                let out_id = self.write_id(&out);
286
287                let ty = out_ty.id(self);
288                let memory = self.const_u32(Scope::Device as u32);
289                let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());
290
291                assert!(
292                    matches!(out_ty.elem(), Elem::Int(_, _)),
293                    "xor doesn't support float atomics"
294                );
295                self.atomic_xor(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
296                    .unwrap();
297                self.write(&out, out_id);
298            }
299        }
300    }
301}