Skip to main content

cubecl_spirv/
atomic.rs

1use cubecl_core::ir::{AtomicOp, ElemType, InstructionModes, IntKind, UIntKind, Variable};
2use rspirv::spirv::{Capability, MemorySemantics, Scope, Word};
3
4use crate::{SpirvCompiler, SpirvTarget, item::Elem};
5
6impl<T: SpirvTarget> SpirvCompiler<T> {
7    pub fn compile_atomic(
8        &mut self,
9        atomic: AtomicOp,
10        out: Option<Variable>,
11        modes: InstructionModes,
12    ) {
13        let out = out.unwrap();
14
15        if matches!(
16            out.elem_type(),
17            ElemType::Int(IntKind::I64) | ElemType::UInt(UIntKind::U64)
18        ) {
19            self.capabilities.insert(Capability::Int64Atomics);
20        }
21
22        match atomic {
23            AtomicOp::Load(op) => {
24                let input = self.compile_variable(op.input);
25                let out = self.compile_variable(out);
26                let out_ty = out.item();
27
28                let input_id = input.id(self);
29                let out_id = self.write_id(&out);
30
31                let ty = out_ty.id(self);
32                let memory = self.scope(&input);
33                let semantics = self.semantics_r(&input);
34
35                self.atomic_load(ty, Some(out_id), input_id, memory, semantics)
36                    .unwrap();
37                self.write(&out, out_id);
38            }
39            AtomicOp::Store(op) => {
40                let input = self.compile_variable(op.input);
41                let out = self.compile_variable(out);
42
43                let input_id = self.read(&input);
44                let out_id = out.id(self);
45
46                let memory = self.scope(&out);
47                let semantics = self.semantics_w(&out);
48
49                self.atomic_store(out_id, memory, semantics, input_id)
50                    .unwrap();
51            }
52            AtomicOp::Swap(op) => {
53                let lhs = self.compile_variable(op.lhs);
54                let rhs = self.compile_variable(op.rhs);
55                let out = self.compile_variable(out);
56                let out_ty = out.item();
57
58                let lhs_id = lhs.id(self);
59                let rhs_id = self.read(&rhs);
60                let out_id = self.write_id(&out);
61
62                let ty = out_ty.id(self);
63                let memory = self.scope(&lhs);
64                let semantics = self.semantics_rw(&lhs);
65
66                self.atomic_exchange(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
67                    .unwrap();
68                self.write(&out, out_id);
69            }
70            AtomicOp::CompareAndSwap(op) => {
71                let atomic = self.compile_variable(op.input);
72                let cmp = self.compile_variable(op.cmp);
73                let val = self.compile_variable(op.val);
74                let out = self.compile_variable(out);
75                let out_ty = out.item();
76
77                let atomic_id = atomic.id(self);
78                let cmp_id = self.read(&cmp);
79                let val_id = self.read(&val);
80                let out_id = self.write_id(&out);
81
82                let ty = out_ty.id(self);
83                let memory = self.scope(&atomic);
84                let semantics_success = self.semantics_rw(&atomic);
85                let semantics_failure = self.semantics_r(&atomic);
86
87                assert!(
88                    matches!(out_ty.elem(), Elem::Int(_, _)),
89                    "compare and swap doesn't support float atomics"
90                );
91                self.atomic_compare_exchange(
92                    ty,
93                    Some(out_id),
94                    atomic_id,
95                    memory,
96                    semantics_success,
97                    semantics_failure,
98                    val_id,
99                    cmp_id,
100                )
101                .unwrap();
102                self.write(&out, out_id);
103            }
104            AtomicOp::Add(op) => {
105                let lhs = self.compile_variable(op.lhs);
106                let rhs = self.compile_variable(op.rhs);
107                let out = self.compile_variable(out);
108                let out_ty = out.item();
109
110                let lhs_id = lhs.id(self);
111                let rhs_id = self.read(&rhs);
112                let out_id = self.write_id(&out);
113
114                let ty = out_ty.id(self);
115                let memory = self.scope(&lhs);
116                let semantics = self.semantics_rw(&lhs);
117
118                match out_ty.elem() {
119                    Elem::Int(_, _) => self
120                        .atomic_i_add(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
121                        .unwrap(),
122                    Elem::Float(width, None) => {
123                        match width {
124                            16 if out_ty.vectorization() == 1 => {
125                                self.capabilities.insert(Capability::AtomicFloat16AddEXT)
126                            }
127                            16 => self.capabilities.insert(Capability::AtomicFloat16VectorNV),
128                            32 => self.capabilities.insert(Capability::AtomicFloat32AddEXT),
129                            64 => self.capabilities.insert(Capability::AtomicFloat64AddEXT),
130                            _ => unreachable!(),
131                        };
132                        self.atomic_f_add_ext(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
133                            .unwrap()
134                    }
135                    _ => unreachable!(),
136                };
137
138                self.write(&out, out_id);
139            }
140            AtomicOp::Sub(op) => {
141                let lhs = self.compile_variable(op.lhs);
142                let rhs = self.compile_variable(op.rhs);
143                let out = self.compile_variable(out);
144                let out_ty = out.item();
145
146                let lhs_id = lhs.id(self);
147                let rhs_id = self.read(&rhs);
148                let out_id = self.write_id(&out);
149
150                let ty = out_ty.id(self);
151                let memory = self.scope(&lhs);
152                let semantics = self.semantics_rw(&lhs);
153
154                assert!(
155                    matches!(out_ty.elem(), Elem::Int(_, _)),
156                    "sub doesn't support float atomics"
157                );
158                match out_ty.elem() {
159                    Elem::Int(_, _) => self
160                        .atomic_i_sub(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
161                        .unwrap(),
162                    Elem::Float(width, None) => {
163                        match width {
164                            16 if out_ty.vectorization() == 1 => {
165                                self.capabilities.insert(Capability::AtomicFloat16AddEXT)
166                            }
167                            16 => self.capabilities.insert(Capability::AtomicFloat16VectorNV),
168                            32 => self.capabilities.insert(Capability::AtomicFloat32AddEXT),
169                            64 => self.capabilities.insert(Capability::AtomicFloat64AddEXT),
170                            _ => unreachable!(),
171                        };
172                        let negated = self.f_negate(ty, None, rhs_id).unwrap();
173                        self.declare_math_mode(modes, negated);
174                        self.atomic_f_add_ext(ty, Some(out_id), lhs_id, memory, semantics, negated)
175                            .unwrap()
176                    }
177                    _ => unreachable!(),
178                };
179                self.atomic_i_sub(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
180                    .unwrap();
181                self.write(&out, out_id);
182            }
183            AtomicOp::Max(op) => {
184                let lhs = self.compile_variable(op.lhs);
185                let rhs = self.compile_variable(op.rhs);
186                let out = self.compile_variable(out);
187                let out_ty = out.item();
188
189                let lhs_id = lhs.id(self);
190                let rhs_id = self.read(&rhs);
191                let out_id = self.write_id(&out);
192
193                let ty = out_ty.id(self);
194                let memory = self.scope(&lhs);
195                let semantics = self.semantics_rw(&lhs);
196
197                match out_ty.elem() {
198                    Elem::Int(_, false) => self
199                        .atomic_u_max(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
200                        .unwrap(),
201                    Elem::Int(_, true) => self
202                        .atomic_s_max(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
203                        .unwrap(),
204                    Elem::Float(width, None) => {
205                        match width {
206                            16 if out_ty.vectorization() == 1 => {
207                                self.capabilities.insert(Capability::AtomicFloat16MinMaxEXT)
208                            }
209                            16 => self.capabilities.insert(Capability::AtomicFloat16VectorNV),
210                            32 => self.capabilities.insert(Capability::AtomicFloat32MinMaxEXT),
211                            64 => self.capabilities.insert(Capability::AtomicFloat64MinMaxEXT),
212                            _ => unreachable!(),
213                        };
214                        self.atomic_f_max_ext(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
215                            .unwrap()
216                    }
217                    _ => unreachable!(),
218                };
219                self.write(&out, out_id);
220            }
221            AtomicOp::Min(op) => {
222                let lhs = self.compile_variable(op.lhs);
223                let rhs = self.compile_variable(op.rhs);
224                let out = self.compile_variable(out);
225                let out_ty = out.item();
226
227                let lhs_id = lhs.id(self);
228                let rhs_id = self.read(&rhs);
229                let out_id = self.write_id(&out);
230
231                let ty = out_ty.id(self);
232                let memory = self.scope(&lhs);
233                let semantics = self.semantics_rw(&lhs);
234
235                match out_ty.elem() {
236                    Elem::Int(_, false) => self
237                        .atomic_u_min(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
238                        .unwrap(),
239                    Elem::Int(_, true) => self
240                        .atomic_s_min(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
241                        .unwrap(),
242                    Elem::Float(width, None) => {
243                        match width {
244                            16 if out_ty.vectorization() == 1 => {
245                                self.capabilities.insert(Capability::AtomicFloat16MinMaxEXT)
246                            }
247                            16 => self.capabilities.insert(Capability::AtomicFloat16VectorNV),
248                            32 => self.capabilities.insert(Capability::AtomicFloat32MinMaxEXT),
249                            64 => self.capabilities.insert(Capability::AtomicFloat64MinMaxEXT),
250                            _ => unreachable!(),
251                        };
252                        self.atomic_f_min_ext(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
253                            .unwrap()
254                    }
255                    _ => unreachable!(),
256                };
257                self.write(&out, out_id);
258            }
259            AtomicOp::And(op) => {
260                let lhs = self.compile_variable(op.lhs);
261                let rhs = self.compile_variable(op.rhs);
262                let out = self.compile_variable(out);
263                let out_ty = out.item();
264
265                let lhs_id = lhs.id(self);
266                let rhs_id = self.read(&rhs);
267                let out_id = self.write_id(&out);
268
269                let ty = out_ty.id(self);
270                let memory = self.scope(&lhs);
271                let semantics = self.semantics_rw(&lhs);
272
273                assert!(
274                    matches!(out_ty.elem(), Elem::Int(_, _)),
275                    "and doesn't support float atomics"
276                );
277                self.atomic_and(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
278                    .unwrap();
279                self.write(&out, out_id);
280            }
281            AtomicOp::Or(op) => {
282                let lhs = self.compile_variable(op.lhs);
283                let rhs = self.compile_variable(op.rhs);
284                let out = self.compile_variable(out);
285                let out_ty = out.item();
286
287                let lhs_id = lhs.id(self);
288                let rhs_id = self.read(&rhs);
289                let out_id = self.write_id(&out);
290
291                let ty = out_ty.id(self);
292                let memory = self.scope(&lhs);
293                let semantics = self.semantics_rw(&lhs);
294
295                assert!(
296                    matches!(out_ty.elem(), Elem::Int(_, _)),
297                    "or doesn't support float atomics"
298                );
299                self.atomic_or(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
300                    .unwrap();
301                self.write(&out, out_id);
302            }
303            AtomicOp::Xor(op) => {
304                let lhs = self.compile_variable(op.lhs);
305                let rhs = self.compile_variable(op.rhs);
306                let out = self.compile_variable(out);
307                let out_ty = out.item();
308
309                let lhs_id = lhs.id(self);
310                let rhs_id = self.read(&rhs);
311                let out_id = self.write_id(&out);
312
313                let ty = out_ty.id(self);
314                let memory = self.scope(&lhs);
315                let semantics = self.semantics_rw(&lhs);
316
317                assert!(
318                    matches!(out_ty.elem(), Elem::Int(_, _)),
319                    "xor doesn't support float atomics"
320                );
321                self.atomic_xor(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
322                    .unwrap();
323                self.write(&out, out_id);
324            }
325        }
326    }
327
328    fn scope(&mut self, var: &crate::variable::Variable) -> Word {
329        let value = self.scope_of(var) as u32;
330        self.const_u32(value)
331    }
332
333    fn semantics_r(&mut self, var: &crate::variable::Variable) -> Word {
334        let value = self.semantics_of(var) | MemorySemantics::ACQUIRE;
335        self.const_u32(value.bits())
336    }
337
338    fn semantics_w(&mut self, var: &crate::variable::Variable) -> Word {
339        let value = self.semantics_of(var) | MemorySemantics::RELEASE;
340        self.const_u32(value.bits())
341    }
342
343    fn semantics_rw(&mut self, var: &crate::variable::Variable) -> Word {
344        let value = self.semantics_of(var) | MemorySemantics::ACQUIRE_RELEASE;
345        self.const_u32(value.bits())
346    }
347
348    fn scope_of(&mut self, var: &crate::variable::Variable) -> Scope {
349        let id = var.id(self);
350        *self
351            .state
352            .atomic_scopes
353            .get(&id)
354            .expect("Atomic should have a scope registered")
355    }
356
357    fn semantics_of(&mut self, var: &crate::variable::Variable) -> MemorySemantics {
358        match self.scope_of(var) {
359            Scope::Device => MemorySemantics::UNIFORM_MEMORY,
360            Scope::Workgroup => MemorySemantics::WORKGROUP_MEMORY,
361            Scope::Subgroup => MemorySemantics::SUBGROUP_MEMORY,
362            other => unreachable!("Invalid scope for atomic operation, {other:?}"),
363        }
364    }
365}