cubecl_spirv/
atomic.rs

1use cubecl_core::ir::{AtomicOp, InstructionModes, Variable};
2use rspirv::spirv::{Capability, MemorySemantics, Scope};
3
4use crate::{SpirvCompiler, SpirvTarget, item::Elem};
5
6impl<T: SpirvTarget> SpirvCompiler<T> {
7    pub fn compile_atomic(
8        &mut self,
9        atomic: AtomicOp,
10        out: Option<Variable>,
11        modes: InstructionModes,
12    ) {
13        let out = out.unwrap();
14        match atomic {
15            AtomicOp::Load(op) => {
16                let input = self.compile_variable(op.input);
17                let out = self.compile_variable(out);
18                let out_ty = out.item();
19
20                let input_id = input.id(self);
21                let out_id = self.write_id(&out);
22
23                let ty = out_ty.id(self);
24                let memory = self.const_u32(Scope::Device as u32);
25                let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());
26
27                self.atomic_load(ty, Some(out_id), input_id, memory, semantics)
28                    .unwrap();
29                self.write(&out, out_id);
30            }
31            AtomicOp::Store(op) => {
32                let input = self.compile_variable(op.input);
33                let out = self.compile_variable(out);
34
35                let input_id = self.read(&input);
36                let out_id = out.id(self);
37
38                let memory = self.const_u32(Scope::Device as u32);
39                let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());
40
41                self.atomic_store(out_id, memory, semantics, input_id)
42                    .unwrap();
43            }
44            AtomicOp::Swap(op) => {
45                let lhs = self.compile_variable(op.lhs);
46                let rhs = self.compile_variable(op.rhs);
47                let out = self.compile_variable(out);
48                let out_ty = out.item();
49
50                let lhs_id = lhs.id(self);
51                let rhs_id = self.read(&rhs);
52                let out_id = self.write_id(&out);
53
54                let ty = out_ty.id(self);
55                let memory = self.const_u32(Scope::Device as u32);
56                let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());
57
58                self.atomic_exchange(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
59                    .unwrap();
60                self.write(&out, out_id);
61            }
62            AtomicOp::CompareAndSwap(op) => {
63                let atomic = self.compile_variable(op.input);
64                let cmp = self.compile_variable(op.cmp);
65                let val = self.compile_variable(op.val);
66                let out = self.compile_variable(out);
67                let out_ty = out.item();
68
69                let atomic_id = atomic.id(self);
70                let cmp_id = self.read(&cmp);
71                let val_id = self.read(&val);
72                let out_id = self.write_id(&out);
73
74                let ty = out_ty.id(self);
75                let memory = self.const_u32(Scope::Device as u32);
76                let semantics_success = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());
77                let semantics_failure = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());
78
79                assert!(
80                    matches!(out_ty.elem(), Elem::Int(_, _)),
81                    "compare and swap doesn't support float atomics"
82                );
83                self.atomic_compare_exchange(
84                    ty,
85                    Some(out_id),
86                    atomic_id,
87                    memory,
88                    semantics_success,
89                    semantics_failure,
90                    val_id,
91                    cmp_id,
92                )
93                .unwrap();
94                self.write(&out, out_id);
95            }
96            AtomicOp::Add(op) => {
97                let lhs = self.compile_variable(op.lhs);
98                let rhs = self.compile_variable(op.rhs);
99                let out = self.compile_variable(out);
100                let out_ty = out.item();
101
102                let lhs_id = lhs.id(self);
103                let rhs_id = self.read(&rhs);
104                let out_id = self.write_id(&out);
105
106                let ty = out_ty.id(self);
107                let memory = self.const_u32(Scope::Device as u32);
108                let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());
109
110                match out_ty.elem() {
111                    Elem::Int(_, _) => self
112                        .atomic_i_add(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
113                        .unwrap(),
114                    Elem::Float(width, None) => {
115                        match width {
116                            16 => self.capabilities.insert(Capability::AtomicFloat16AddEXT),
117                            32 => self.capabilities.insert(Capability::AtomicFloat32AddEXT),
118                            64 => self.capabilities.insert(Capability::AtomicFloat64AddEXT),
119                            _ => unreachable!(),
120                        };
121                        self.atomic_f_add_ext(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
122                            .unwrap()
123                    }
124                    _ => unreachable!(),
125                };
126
127                self.write(&out, out_id);
128            }
129            AtomicOp::Sub(op) => {
130                let lhs = self.compile_variable(op.lhs);
131                let rhs = self.compile_variable(op.rhs);
132                let out = self.compile_variable(out);
133                let out_ty = out.item();
134
135                let lhs_id = lhs.id(self);
136                let rhs_id = self.read(&rhs);
137                let out_id = self.write_id(&out);
138
139                let ty = out_ty.id(self);
140                let memory = self.const_u32(Scope::Device as u32);
141                let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());
142
143                assert!(
144                    matches!(out_ty.elem(), Elem::Int(_, _)),
145                    "sub doesn't support float atomics"
146                );
147                match out_ty.elem() {
148                    Elem::Int(_, _) => self
149                        .atomic_i_sub(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
150                        .unwrap(),
151                    Elem::Float(width, None) => {
152                        match width {
153                            16 => self.capabilities.insert(Capability::AtomicFloat16AddEXT),
154                            32 => self.capabilities.insert(Capability::AtomicFloat32AddEXT),
155                            64 => self.capabilities.insert(Capability::AtomicFloat64AddEXT),
156                            _ => unreachable!(),
157                        };
158                        let negated = self.f_negate(ty, None, rhs_id).unwrap();
159                        self.declare_math_mode(modes, negated);
160                        self.atomic_f_add_ext(ty, Some(out_id), lhs_id, memory, semantics, negated)
161                            .unwrap()
162                    }
163                    _ => unreachable!(),
164                };
165                self.atomic_i_sub(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
166                    .unwrap();
167                self.write(&out, out_id);
168            }
169            AtomicOp::Max(op) => {
170                let lhs = self.compile_variable(op.lhs);
171                let rhs = self.compile_variable(op.rhs);
172                let out = self.compile_variable(out);
173                let out_ty = out.item();
174
175                let lhs_id = lhs.id(self);
176                let rhs_id = self.read(&rhs);
177                let out_id = self.write_id(&out);
178
179                let ty = out_ty.id(self);
180                let memory = self.const_u32(Scope::Device as u32);
181                let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());
182
183                match out_ty.elem() {
184                    Elem::Int(_, false) => self
185                        .atomic_u_max(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
186                        .unwrap(),
187                    Elem::Int(_, true) => self
188                        .atomic_s_max(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
189                        .unwrap(),
190                    Elem::Float(width, None) => {
191                        match width {
192                            16 => self.capabilities.insert(Capability::AtomicFloat16MinMaxEXT),
193                            32 => self.capabilities.insert(Capability::AtomicFloat32MinMaxEXT),
194                            64 => self.capabilities.insert(Capability::AtomicFloat64MinMaxEXT),
195                            _ => unreachable!(),
196                        };
197                        self.atomic_f_max_ext(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
198                            .unwrap()
199                    }
200                    _ => unreachable!(),
201                };
202                self.write(&out, out_id);
203            }
204            AtomicOp::Min(op) => {
205                let lhs = self.compile_variable(op.lhs);
206                let rhs = self.compile_variable(op.rhs);
207                let out = self.compile_variable(out);
208                let out_ty = out.item();
209
210                let lhs_id = lhs.id(self);
211                let rhs_id = self.read(&rhs);
212                let out_id = self.write_id(&out);
213
214                let ty = out_ty.id(self);
215                let memory = self.const_u32(Scope::Device as u32);
216                let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());
217
218                match out_ty.elem() {
219                    Elem::Int(_, false) => self
220                        .atomic_u_min(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
221                        .unwrap(),
222                    Elem::Int(_, true) => self
223                        .atomic_s_min(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
224                        .unwrap(),
225                    Elem::Float(width, None) => {
226                        match width {
227                            16 => self.capabilities.insert(Capability::AtomicFloat16MinMaxEXT),
228                            32 => self.capabilities.insert(Capability::AtomicFloat32MinMaxEXT),
229                            64 => self.capabilities.insert(Capability::AtomicFloat64MinMaxEXT),
230                            _ => unreachable!(),
231                        };
232                        self.atomic_f_min_ext(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
233                            .unwrap()
234                    }
235                    _ => unreachable!(),
236                };
237                self.write(&out, out_id);
238            }
239            AtomicOp::And(op) => {
240                let lhs = self.compile_variable(op.lhs);
241                let rhs = self.compile_variable(op.rhs);
242                let out = self.compile_variable(out);
243                let out_ty = out.item();
244
245                let lhs_id = lhs.id(self);
246                let rhs_id = self.read(&rhs);
247                let out_id = self.write_id(&out);
248
249                let ty = out_ty.id(self);
250                let memory = self.const_u32(Scope::Device as u32);
251                let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());
252
253                assert!(
254                    matches!(out_ty.elem(), Elem::Int(_, _)),
255                    "and doesn't support float atomics"
256                );
257                self.atomic_and(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
258                    .unwrap();
259                self.write(&out, out_id);
260            }
261            AtomicOp::Or(op) => {
262                let lhs = self.compile_variable(op.lhs);
263                let rhs = self.compile_variable(op.rhs);
264                let out = self.compile_variable(out);
265                let out_ty = out.item();
266
267                let lhs_id = lhs.id(self);
268                let rhs_id = self.read(&rhs);
269                let out_id = self.write_id(&out);
270
271                let ty = out_ty.id(self);
272                let memory = self.const_u32(Scope::Device as u32);
273                let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());
274
275                assert!(
276                    matches!(out_ty.elem(), Elem::Int(_, _)),
277                    "or doesn't support float atomics"
278                );
279                self.atomic_or(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
280                    .unwrap();
281                self.write(&out, out_id);
282            }
283            AtomicOp::Xor(op) => {
284                let lhs = self.compile_variable(op.lhs);
285                let rhs = self.compile_variable(op.rhs);
286                let out = self.compile_variable(out);
287                let out_ty = out.item();
288
289                let lhs_id = lhs.id(self);
290                let rhs_id = self.read(&rhs);
291                let out_id = self.write_id(&out);
292
293                let ty = out_ty.id(self);
294                let memory = self.const_u32(Scope::Device as u32);
295                let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());
296
297                assert!(
298                    matches!(out_ty.elem(), Elem::Int(_, _)),
299                    "xor doesn't support float atomics"
300                );
301                self.atomic_xor(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
302                    .unwrap();
303                self.write(&out, out_id);
304            }
305        }
306    }
307}