1use cubecl_core::ir::{AtomicOp, ElemType, InstructionModes, IntKind, UIntKind, Variable};
2use rspirv::spirv::{Capability, MemorySemantics, Scope, Word};
3
4use crate::{SpirvCompiler, SpirvTarget, item::Elem};
5
6impl<T: SpirvTarget> SpirvCompiler<T> {
7 pub fn compile_atomic(
8 &mut self,
9 atomic: AtomicOp,
10 out: Option<Variable>,
11 modes: InstructionModes,
12 ) {
13 let out = out.unwrap();
14
15 if matches!(
16 out.elem_type(),
17 ElemType::Int(IntKind::I64) | ElemType::UInt(UIntKind::U64)
18 ) {
19 self.capabilities.insert(Capability::Int64Atomics);
20 }
21
22 match atomic {
23 AtomicOp::Load(op) => {
24 let input = self.compile_variable(op.input);
25 let out = self.compile_variable(out);
26 let out_ty = out.item();
27
28 let input_id = input.id(self);
29 let out_id = self.write_id(&out);
30
31 let ty = out_ty.id(self);
32 let memory = self.scope(&input);
33 let semantics = self.semantics_r(&input);
34
35 self.atomic_load(ty, Some(out_id), input_id, memory, semantics)
36 .unwrap();
37 self.write(&out, out_id);
38 }
39 AtomicOp::Store(op) => {
40 let input = self.compile_variable(op.input);
41 let out = self.compile_variable(out);
42
43 let input_id = self.read(&input);
44 let out_id = out.id(self);
45
46 let memory = self.scope(&out);
47 let semantics = self.semantics_w(&out);
48
49 self.atomic_store(out_id, memory, semantics, input_id)
50 .unwrap();
51 }
52 AtomicOp::Swap(op) => {
53 let lhs = self.compile_variable(op.lhs);
54 let rhs = self.compile_variable(op.rhs);
55 let out = self.compile_variable(out);
56 let out_ty = out.item();
57
58 let lhs_id = lhs.id(self);
59 let rhs_id = self.read(&rhs);
60 let out_id = self.write_id(&out);
61
62 let ty = out_ty.id(self);
63 let memory = self.scope(&lhs);
64 let semantics = self.semantics_rw(&lhs);
65
66 self.atomic_exchange(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
67 .unwrap();
68 self.write(&out, out_id);
69 }
70 AtomicOp::CompareAndSwap(op) => {
71 let atomic = self.compile_variable(op.input);
72 let cmp = self.compile_variable(op.cmp);
73 let val = self.compile_variable(op.val);
74 let out = self.compile_variable(out);
75 let out_ty = out.item();
76
77 let atomic_id = atomic.id(self);
78 let cmp_id = self.read(&cmp);
79 let val_id = self.read(&val);
80 let out_id = self.write_id(&out);
81
82 let ty = out_ty.id(self);
83 let memory = self.scope(&atomic);
84 let semantics_success = self.semantics_rw(&atomic);
85 let semantics_failure = self.semantics_r(&atomic);
86
87 assert!(
88 matches!(out_ty.elem(), Elem::Int(_, _)),
89 "compare and swap doesn't support float atomics"
90 );
91 self.atomic_compare_exchange(
92 ty,
93 Some(out_id),
94 atomic_id,
95 memory,
96 semantics_success,
97 semantics_failure,
98 val_id,
99 cmp_id,
100 )
101 .unwrap();
102 self.write(&out, out_id);
103 }
104 AtomicOp::Add(op) => {
105 let lhs = self.compile_variable(op.lhs);
106 let rhs = self.compile_variable(op.rhs);
107 let out = self.compile_variable(out);
108 let out_ty = out.item();
109
110 let lhs_id = lhs.id(self);
111 let rhs_id = self.read(&rhs);
112 let out_id = self.write_id(&out);
113
114 let ty = out_ty.id(self);
115 let memory = self.scope(&lhs);
116 let semantics = self.semantics_rw(&lhs);
117
118 match out_ty.elem() {
119 Elem::Int(_, _) => self
120 .atomic_i_add(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
121 .unwrap(),
122 Elem::Float(width, None) => {
123 match width {
124 16 if out_ty.vectorization() == 1 => {
125 self.capabilities.insert(Capability::AtomicFloat16AddEXT)
126 }
127 16 => self.capabilities.insert(Capability::AtomicFloat16VectorNV),
128 32 => self.capabilities.insert(Capability::AtomicFloat32AddEXT),
129 64 => self.capabilities.insert(Capability::AtomicFloat64AddEXT),
130 _ => unreachable!(),
131 };
132 self.atomic_f_add_ext(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
133 .unwrap()
134 }
135 _ => unreachable!(),
136 };
137
138 self.write(&out, out_id);
139 }
140 AtomicOp::Sub(op) => {
141 let lhs = self.compile_variable(op.lhs);
142 let rhs = self.compile_variable(op.rhs);
143 let out = self.compile_variable(out);
144 let out_ty = out.item();
145
146 let lhs_id = lhs.id(self);
147 let rhs_id = self.read(&rhs);
148 let out_id = self.write_id(&out);
149
150 let ty = out_ty.id(self);
151 let memory = self.scope(&lhs);
152 let semantics = self.semantics_rw(&lhs);
153
154 assert!(
155 matches!(out_ty.elem(), Elem::Int(_, _)),
156 "sub doesn't support float atomics"
157 );
158 match out_ty.elem() {
159 Elem::Int(_, _) => self
160 .atomic_i_sub(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
161 .unwrap(),
162 Elem::Float(width, None) => {
163 match width {
164 16 if out_ty.vectorization() == 1 => {
165 self.capabilities.insert(Capability::AtomicFloat16AddEXT)
166 }
167 16 => self.capabilities.insert(Capability::AtomicFloat16VectorNV),
168 32 => self.capabilities.insert(Capability::AtomicFloat32AddEXT),
169 64 => self.capabilities.insert(Capability::AtomicFloat64AddEXT),
170 _ => unreachable!(),
171 };
172 let negated = self.f_negate(ty, None, rhs_id).unwrap();
173 self.declare_math_mode(modes, negated);
174 self.atomic_f_add_ext(ty, Some(out_id), lhs_id, memory, semantics, negated)
175 .unwrap()
176 }
177 _ => unreachable!(),
178 };
179 self.atomic_i_sub(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
180 .unwrap();
181 self.write(&out, out_id);
182 }
183 AtomicOp::Max(op) => {
184 let lhs = self.compile_variable(op.lhs);
185 let rhs = self.compile_variable(op.rhs);
186 let out = self.compile_variable(out);
187 let out_ty = out.item();
188
189 let lhs_id = lhs.id(self);
190 let rhs_id = self.read(&rhs);
191 let out_id = self.write_id(&out);
192
193 let ty = out_ty.id(self);
194 let memory = self.scope(&lhs);
195 let semantics = self.semantics_rw(&lhs);
196
197 match out_ty.elem() {
198 Elem::Int(_, false) => self
199 .atomic_u_max(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
200 .unwrap(),
201 Elem::Int(_, true) => self
202 .atomic_s_max(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
203 .unwrap(),
204 Elem::Float(width, None) => {
205 match width {
206 16 if out_ty.vectorization() == 1 => {
207 self.capabilities.insert(Capability::AtomicFloat16MinMaxEXT)
208 }
209 16 => self.capabilities.insert(Capability::AtomicFloat16VectorNV),
210 32 => self.capabilities.insert(Capability::AtomicFloat32MinMaxEXT),
211 64 => self.capabilities.insert(Capability::AtomicFloat64MinMaxEXT),
212 _ => unreachable!(),
213 };
214 self.atomic_f_max_ext(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
215 .unwrap()
216 }
217 _ => unreachable!(),
218 };
219 self.write(&out, out_id);
220 }
221 AtomicOp::Min(op) => {
222 let lhs = self.compile_variable(op.lhs);
223 let rhs = self.compile_variable(op.rhs);
224 let out = self.compile_variable(out);
225 let out_ty = out.item();
226
227 let lhs_id = lhs.id(self);
228 let rhs_id = self.read(&rhs);
229 let out_id = self.write_id(&out);
230
231 let ty = out_ty.id(self);
232 let memory = self.scope(&lhs);
233 let semantics = self.semantics_rw(&lhs);
234
235 match out_ty.elem() {
236 Elem::Int(_, false) => self
237 .atomic_u_min(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
238 .unwrap(),
239 Elem::Int(_, true) => self
240 .atomic_s_min(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
241 .unwrap(),
242 Elem::Float(width, None) => {
243 match width {
244 16 if out_ty.vectorization() == 1 => {
245 self.capabilities.insert(Capability::AtomicFloat16MinMaxEXT)
246 }
247 16 => self.capabilities.insert(Capability::AtomicFloat16VectorNV),
248 32 => self.capabilities.insert(Capability::AtomicFloat32MinMaxEXT),
249 64 => self.capabilities.insert(Capability::AtomicFloat64MinMaxEXT),
250 _ => unreachable!(),
251 };
252 self.atomic_f_min_ext(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
253 .unwrap()
254 }
255 _ => unreachable!(),
256 };
257 self.write(&out, out_id);
258 }
259 AtomicOp::And(op) => {
260 let lhs = self.compile_variable(op.lhs);
261 let rhs = self.compile_variable(op.rhs);
262 let out = self.compile_variable(out);
263 let out_ty = out.item();
264
265 let lhs_id = lhs.id(self);
266 let rhs_id = self.read(&rhs);
267 let out_id = self.write_id(&out);
268
269 let ty = out_ty.id(self);
270 let memory = self.scope(&lhs);
271 let semantics = self.semantics_rw(&lhs);
272
273 assert!(
274 matches!(out_ty.elem(), Elem::Int(_, _)),
275 "and doesn't support float atomics"
276 );
277 self.atomic_and(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
278 .unwrap();
279 self.write(&out, out_id);
280 }
281 AtomicOp::Or(op) => {
282 let lhs = self.compile_variable(op.lhs);
283 let rhs = self.compile_variable(op.rhs);
284 let out = self.compile_variable(out);
285 let out_ty = out.item();
286
287 let lhs_id = lhs.id(self);
288 let rhs_id = self.read(&rhs);
289 let out_id = self.write_id(&out);
290
291 let ty = out_ty.id(self);
292 let memory = self.scope(&lhs);
293 let semantics = self.semantics_rw(&lhs);
294
295 assert!(
296 matches!(out_ty.elem(), Elem::Int(_, _)),
297 "or doesn't support float atomics"
298 );
299 self.atomic_or(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
300 .unwrap();
301 self.write(&out, out_id);
302 }
303 AtomicOp::Xor(op) => {
304 let lhs = self.compile_variable(op.lhs);
305 let rhs = self.compile_variable(op.rhs);
306 let out = self.compile_variable(out);
307 let out_ty = out.item();
308
309 let lhs_id = lhs.id(self);
310 let rhs_id = self.read(&rhs);
311 let out_id = self.write_id(&out);
312
313 let ty = out_ty.id(self);
314 let memory = self.scope(&lhs);
315 let semantics = self.semantics_rw(&lhs);
316
317 assert!(
318 matches!(out_ty.elem(), Elem::Int(_, _)),
319 "xor doesn't support float atomics"
320 );
321 self.atomic_xor(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
322 .unwrap();
323 self.write(&out, out_id);
324 }
325 }
326 }
327
328 fn scope(&mut self, var: &crate::variable::Variable) -> Word {
329 let value = self.scope_of(var) as u32;
330 self.const_u32(value)
331 }
332
333 fn semantics_r(&mut self, var: &crate::variable::Variable) -> Word {
334 let value = self.semantics_of(var) | MemorySemantics::ACQUIRE;
335 self.const_u32(value.bits())
336 }
337
338 fn semantics_w(&mut self, var: &crate::variable::Variable) -> Word {
339 let value = self.semantics_of(var) | MemorySemantics::RELEASE;
340 self.const_u32(value.bits())
341 }
342
343 fn semantics_rw(&mut self, var: &crate::variable::Variable) -> Word {
344 let value = self.semantics_of(var) | MemorySemantics::ACQUIRE_RELEASE;
345 self.const_u32(value.bits())
346 }
347
348 fn scope_of(&mut self, var: &crate::variable::Variable) -> Scope {
349 let id = var.id(self);
350 *self
351 .state
352 .atomic_scopes
353 .get(&id)
354 .expect("Atomic should have a scope registered")
355 }
356
357 fn semantics_of(&mut self, var: &crate::variable::Variable) -> MemorySemantics {
358 match self.scope_of(var) {
359 Scope::Device => MemorySemantics::UNIFORM_MEMORY,
360 Scope::Workgroup => MemorySemantics::WORKGROUP_MEMORY,
361 Scope::Subgroup => MemorySemantics::SUBGROUP_MEMORY,
362 other => unreachable!("Invalid scope for atomic operation, {other:?}"),
363 }
364 }
365}