cubecl_spirv/
atomic.rs

1use cubecl_core::ir::{AtomicOp, ElemType, InstructionModes, IntKind, UIntKind, Variable};
2use rspirv::spirv::{Capability, MemorySemantics, Scope, Word};
3
4use crate::{SpirvCompiler, SpirvTarget, item::Elem};
5
6impl<T: SpirvTarget> SpirvCompiler<T> {
7    pub fn compile_atomic(
8        &mut self,
9        atomic: AtomicOp,
10        out: Option<Variable>,
11        modes: InstructionModes,
12    ) {
13        let out = out.unwrap();
14
15        if matches!(
16            out.elem_type(),
17            ElemType::Int(IntKind::I64) | ElemType::UInt(UIntKind::U64)
18        ) {
19            self.capabilities.insert(Capability::Int64Atomics);
20        }
21
22        match atomic {
23            AtomicOp::Load(op) => {
24                let input = self.compile_variable(op.input);
25                let out = self.compile_variable(out);
26                let out_ty = out.item();
27
28                let input_id = input.id(self);
29                let out_id = self.write_id(&out);
30
31                let ty = out_ty.id(self);
32                let memory = self.scope(&input);
33                let semantics = self.semantics_r(&input);
34
35                self.atomic_load(ty, Some(out_id), input_id, memory, semantics)
36                    .unwrap();
37                self.write(&out, out_id);
38            }
39            AtomicOp::Store(op) => {
40                let input = self.compile_variable(op.input);
41                let out = self.compile_variable(out);
42
43                let input_id = self.read(&input);
44                let out_id = out.id(self);
45
46                let memory = self.scope(&out);
47                let semantics = self.semantics_w(&out);
48
49                self.atomic_store(out_id, memory, semantics, input_id)
50                    .unwrap();
51            }
52            AtomicOp::Swap(op) => {
53                let lhs = self.compile_variable(op.lhs);
54                let rhs = self.compile_variable(op.rhs);
55                let out = self.compile_variable(out);
56                let out_ty = out.item();
57
58                let lhs_id = lhs.id(self);
59                let rhs_id = self.read(&rhs);
60                let out_id = self.write_id(&out);
61
62                let ty = out_ty.id(self);
63                let memory = self.scope(&lhs);
64                let semantics = self.semantics_rw(&lhs);
65
66                self.atomic_exchange(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
67                    .unwrap();
68                self.write(&out, out_id);
69            }
70            AtomicOp::CompareAndSwap(op) => {
71                let atomic = self.compile_variable(op.input);
72                let cmp = self.compile_variable(op.cmp);
73                let val = self.compile_variable(op.val);
74                let out = self.compile_variable(out);
75                let out_ty = out.item();
76
77                let atomic_id = atomic.id(self);
78                let cmp_id = self.read(&cmp);
79                let val_id = self.read(&val);
80                let out_id = self.write_id(&out);
81
82                let ty = out_ty.id(self);
83                let memory = self.scope(&atomic);
84                let semantics_success = self.semantics_rw(&atomic);
85                let semantics_failure = self.semantics_r(&atomic);
86
87                assert!(
88                    matches!(out_ty.elem(), Elem::Int(_, _)),
89                    "compare and swap doesn't support float atomics"
90                );
91                self.atomic_compare_exchange(
92                    ty,
93                    Some(out_id),
94                    atomic_id,
95                    memory,
96                    semantics_success,
97                    semantics_failure,
98                    val_id,
99                    cmp_id,
100                )
101                .unwrap();
102                self.write(&out, out_id);
103            }
104            AtomicOp::Add(op) => {
105                let lhs = self.compile_variable(op.lhs);
106                let rhs = self.compile_variable(op.rhs);
107                let out = self.compile_variable(out);
108                let out_ty = out.item();
109
110                let lhs_id = lhs.id(self);
111                let rhs_id = self.read(&rhs);
112                let out_id = self.write_id(&out);
113
114                let ty = out_ty.id(self);
115                let memory = self.scope(&lhs);
116                let semantics = self.semantics_rw(&lhs);
117
118                match out_ty.elem() {
119                    Elem::Int(_, _) => self
120                        .atomic_i_add(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
121                        .unwrap(),
122                    Elem::Float(width, None) => {
123                        match width {
124                            16 => self.capabilities.insert(Capability::AtomicFloat16AddEXT),
125                            32 => self.capabilities.insert(Capability::AtomicFloat32AddEXT),
126                            64 => self.capabilities.insert(Capability::AtomicFloat64AddEXT),
127                            _ => unreachable!(),
128                        };
129                        self.atomic_f_add_ext(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
130                            .unwrap()
131                    }
132                    _ => unreachable!(),
133                };
134
135                self.write(&out, out_id);
136            }
137            AtomicOp::Sub(op) => {
138                let lhs = self.compile_variable(op.lhs);
139                let rhs = self.compile_variable(op.rhs);
140                let out = self.compile_variable(out);
141                let out_ty = out.item();
142
143                let lhs_id = lhs.id(self);
144                let rhs_id = self.read(&rhs);
145                let out_id = self.write_id(&out);
146
147                let ty = out_ty.id(self);
148                let memory = self.scope(&lhs);
149                let semantics = self.semantics_rw(&lhs);
150
151                assert!(
152                    matches!(out_ty.elem(), Elem::Int(_, _)),
153                    "sub doesn't support float atomics"
154                );
155                match out_ty.elem() {
156                    Elem::Int(_, _) => self
157                        .atomic_i_sub(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
158                        .unwrap(),
159                    Elem::Float(width, None) => {
160                        match width {
161                            16 => self.capabilities.insert(Capability::AtomicFloat16AddEXT),
162                            32 => self.capabilities.insert(Capability::AtomicFloat32AddEXT),
163                            64 => self.capabilities.insert(Capability::AtomicFloat64AddEXT),
164                            _ => unreachable!(),
165                        };
166                        let negated = self.f_negate(ty, None, rhs_id).unwrap();
167                        self.declare_math_mode(modes, negated);
168                        self.atomic_f_add_ext(ty, Some(out_id), lhs_id, memory, semantics, negated)
169                            .unwrap()
170                    }
171                    _ => unreachable!(),
172                };
173                self.atomic_i_sub(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
174                    .unwrap();
175                self.write(&out, out_id);
176            }
177            AtomicOp::Max(op) => {
178                let lhs = self.compile_variable(op.lhs);
179                let rhs = self.compile_variable(op.rhs);
180                let out = self.compile_variable(out);
181                let out_ty = out.item();
182
183                let lhs_id = lhs.id(self);
184                let rhs_id = self.read(&rhs);
185                let out_id = self.write_id(&out);
186
187                let ty = out_ty.id(self);
188                let memory = self.scope(&lhs);
189                let semantics = self.semantics_rw(&lhs);
190
191                match out_ty.elem() {
192                    Elem::Int(_, false) => self
193                        .atomic_u_max(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
194                        .unwrap(),
195                    Elem::Int(_, true) => self
196                        .atomic_s_max(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
197                        .unwrap(),
198                    Elem::Float(width, None) => {
199                        match width {
200                            16 => self.capabilities.insert(Capability::AtomicFloat16MinMaxEXT),
201                            32 => self.capabilities.insert(Capability::AtomicFloat32MinMaxEXT),
202                            64 => self.capabilities.insert(Capability::AtomicFloat64MinMaxEXT),
203                            _ => unreachable!(),
204                        };
205                        self.atomic_f_max_ext(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
206                            .unwrap()
207                    }
208                    _ => unreachable!(),
209                };
210                self.write(&out, out_id);
211            }
212            AtomicOp::Min(op) => {
213                let lhs = self.compile_variable(op.lhs);
214                let rhs = self.compile_variable(op.rhs);
215                let out = self.compile_variable(out);
216                let out_ty = out.item();
217
218                let lhs_id = lhs.id(self);
219                let rhs_id = self.read(&rhs);
220                let out_id = self.write_id(&out);
221
222                let ty = out_ty.id(self);
223                let memory = self.scope(&lhs);
224                let semantics = self.semantics_rw(&lhs);
225
226                match out_ty.elem() {
227                    Elem::Int(_, false) => self
228                        .atomic_u_min(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
229                        .unwrap(),
230                    Elem::Int(_, true) => self
231                        .atomic_s_min(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
232                        .unwrap(),
233                    Elem::Float(width, None) => {
234                        match width {
235                            16 => self.capabilities.insert(Capability::AtomicFloat16MinMaxEXT),
236                            32 => self.capabilities.insert(Capability::AtomicFloat32MinMaxEXT),
237                            64 => self.capabilities.insert(Capability::AtomicFloat64MinMaxEXT),
238                            _ => unreachable!(),
239                        };
240                        self.atomic_f_min_ext(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
241                            .unwrap()
242                    }
243                    _ => unreachable!(),
244                };
245                self.write(&out, out_id);
246            }
247            AtomicOp::And(op) => {
248                let lhs = self.compile_variable(op.lhs);
249                let rhs = self.compile_variable(op.rhs);
250                let out = self.compile_variable(out);
251                let out_ty = out.item();
252
253                let lhs_id = lhs.id(self);
254                let rhs_id = self.read(&rhs);
255                let out_id = self.write_id(&out);
256
257                let ty = out_ty.id(self);
258                let memory = self.scope(&lhs);
259                let semantics = self.semantics_rw(&lhs);
260
261                assert!(
262                    matches!(out_ty.elem(), Elem::Int(_, _)),
263                    "and doesn't support float atomics"
264                );
265                self.atomic_and(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
266                    .unwrap();
267                self.write(&out, out_id);
268            }
269            AtomicOp::Or(op) => {
270                let lhs = self.compile_variable(op.lhs);
271                let rhs = self.compile_variable(op.rhs);
272                let out = self.compile_variable(out);
273                let out_ty = out.item();
274
275                let lhs_id = lhs.id(self);
276                let rhs_id = self.read(&rhs);
277                let out_id = self.write_id(&out);
278
279                let ty = out_ty.id(self);
280                let memory = self.scope(&lhs);
281                let semantics = self.semantics_rw(&lhs);
282
283                assert!(
284                    matches!(out_ty.elem(), Elem::Int(_, _)),
285                    "or doesn't support float atomics"
286                );
287                self.atomic_or(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
288                    .unwrap();
289                self.write(&out, out_id);
290            }
291            AtomicOp::Xor(op) => {
292                let lhs = self.compile_variable(op.lhs);
293                let rhs = self.compile_variable(op.rhs);
294                let out = self.compile_variable(out);
295                let out_ty = out.item();
296
297                let lhs_id = lhs.id(self);
298                let rhs_id = self.read(&rhs);
299                let out_id = self.write_id(&out);
300
301                let ty = out_ty.id(self);
302                let memory = self.scope(&lhs);
303                let semantics = self.semantics_rw(&lhs);
304
305                assert!(
306                    matches!(out_ty.elem(), Elem::Int(_, _)),
307                    "xor doesn't support float atomics"
308                );
309                self.atomic_xor(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
310                    .unwrap();
311                self.write(&out, out_id);
312            }
313        }
314    }
315
316    fn scope(&mut self, var: &crate::variable::Variable) -> Word {
317        let value = self.scope_of(var) as u32;
318        self.const_u32(value)
319    }
320
321    fn semantics_r(&mut self, var: &crate::variable::Variable) -> Word {
322        let value = self.semantics_of(var) | MemorySemantics::ACQUIRE;
323        self.const_u32(value.bits())
324    }
325
326    fn semantics_w(&mut self, var: &crate::variable::Variable) -> Word {
327        let value = self.semantics_of(var) | MemorySemantics::RELEASE;
328        self.const_u32(value.bits())
329    }
330
331    fn semantics_rw(&mut self, var: &crate::variable::Variable) -> Word {
332        let value = self.semantics_of(var) | MemorySemantics::ACQUIRE_RELEASE;
333        self.const_u32(value.bits())
334    }
335
336    fn scope_of(&mut self, var: &crate::variable::Variable) -> Scope {
337        let id = var.id(self);
338        *self
339            .state
340            .atomic_scopes
341            .get(&id)
342            .expect("Atomic should have a scope registered")
343    }
344
345    fn semantics_of(&mut self, var: &crate::variable::Variable) -> MemorySemantics {
346        match self.scope_of(var) {
347            Scope::Device => MemorySemantics::UNIFORM_MEMORY,
348            Scope::Workgroup => MemorySemantics::WORKGROUP_MEMORY,
349            Scope::Subgroup => MemorySemantics::SUBGROUP_MEMORY,
350            other => unreachable!("Invalid scope for atomic operation, {other:?}"),
351        }
352    }
353}