1use cubecl_core::ir::{AtomicOp, ElemType, InstructionModes, IntKind, UIntKind, Variable};
2use rspirv::spirv::{Capability, MemorySemantics, Scope, Word};
3
4use crate::{SpirvCompiler, SpirvTarget, item::Elem};
5
6impl<T: SpirvTarget> SpirvCompiler<T> {
7 pub fn compile_atomic(
8 &mut self,
9 atomic: AtomicOp,
10 out: Option<Variable>,
11 modes: InstructionModes,
12 ) {
13 let out = out.unwrap();
14
15 if matches!(
16 out.elem_type(),
17 ElemType::Int(IntKind::I64) | ElemType::UInt(UIntKind::U64)
18 ) {
19 self.capabilities.insert(Capability::Int64Atomics);
20 }
21
22 match atomic {
23 AtomicOp::Load(op) => {
24 let input = self.compile_variable(op.input);
25 let out = self.compile_variable(out);
26 let out_ty = out.item();
27
28 let input_id = input.id(self);
29 let out_id = self.write_id(&out);
30
31 let ty = out_ty.id(self);
32 let memory = self.scope(&input);
33 let semantics = self.semantics_r(&input);
34
35 self.atomic_load(ty, Some(out_id), input_id, memory, semantics)
36 .unwrap();
37 self.write(&out, out_id);
38 }
39 AtomicOp::Store(op) => {
40 let input = self.compile_variable(op.input);
41 let out = self.compile_variable(out);
42
43 let input_id = self.read(&input);
44 let out_id = out.id(self);
45
46 let memory = self.scope(&out);
47 let semantics = self.semantics_w(&out);
48
49 self.atomic_store(out_id, memory, semantics, input_id)
50 .unwrap();
51 }
52 AtomicOp::Swap(op) => {
53 let lhs = self.compile_variable(op.lhs);
54 let rhs = self.compile_variable(op.rhs);
55 let out = self.compile_variable(out);
56 let out_ty = out.item();
57
58 let lhs_id = lhs.id(self);
59 let rhs_id = self.read(&rhs);
60 let out_id = self.write_id(&out);
61
62 let ty = out_ty.id(self);
63 let memory = self.scope(&lhs);
64 let semantics = self.semantics_rw(&lhs);
65
66 self.atomic_exchange(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
67 .unwrap();
68 self.write(&out, out_id);
69 }
70 AtomicOp::CompareAndSwap(op) => {
71 let atomic = self.compile_variable(op.input);
72 let cmp = self.compile_variable(op.cmp);
73 let val = self.compile_variable(op.val);
74 let out = self.compile_variable(out);
75 let out_ty = out.item();
76
77 let atomic_id = atomic.id(self);
78 let cmp_id = self.read(&cmp);
79 let val_id = self.read(&val);
80 let out_id = self.write_id(&out);
81
82 let ty = out_ty.id(self);
83 let memory = self.scope(&atomic);
84 let semantics_success = self.semantics_rw(&atomic);
85 let semantics_failure = self.semantics_r(&atomic);
86
87 assert!(
88 matches!(out_ty.elem(), Elem::Int(_, _)),
89 "compare and swap doesn't support float atomics"
90 );
91 self.atomic_compare_exchange(
92 ty,
93 Some(out_id),
94 atomic_id,
95 memory,
96 semantics_success,
97 semantics_failure,
98 val_id,
99 cmp_id,
100 )
101 .unwrap();
102 self.write(&out, out_id);
103 }
104 AtomicOp::Add(op) => {
105 let lhs = self.compile_variable(op.lhs);
106 let rhs = self.compile_variable(op.rhs);
107 let out = self.compile_variable(out);
108 let out_ty = out.item();
109
110 let lhs_id = lhs.id(self);
111 let rhs_id = self.read(&rhs);
112 let out_id = self.write_id(&out);
113
114 let ty = out_ty.id(self);
115 let memory = self.scope(&lhs);
116 let semantics = self.semantics_rw(&lhs);
117
118 match out_ty.elem() {
119 Elem::Int(_, _) => self
120 .atomic_i_add(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
121 .unwrap(),
122 Elem::Float(width, None) => {
123 match width {
124 16 => self.capabilities.insert(Capability::AtomicFloat16AddEXT),
125 32 => self.capabilities.insert(Capability::AtomicFloat32AddEXT),
126 64 => self.capabilities.insert(Capability::AtomicFloat64AddEXT),
127 _ => unreachable!(),
128 };
129 self.atomic_f_add_ext(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
130 .unwrap()
131 }
132 _ => unreachable!(),
133 };
134
135 self.write(&out, out_id);
136 }
137 AtomicOp::Sub(op) => {
138 let lhs = self.compile_variable(op.lhs);
139 let rhs = self.compile_variable(op.rhs);
140 let out = self.compile_variable(out);
141 let out_ty = out.item();
142
143 let lhs_id = lhs.id(self);
144 let rhs_id = self.read(&rhs);
145 let out_id = self.write_id(&out);
146
147 let ty = out_ty.id(self);
148 let memory = self.scope(&lhs);
149 let semantics = self.semantics_rw(&lhs);
150
151 assert!(
152 matches!(out_ty.elem(), Elem::Int(_, _)),
153 "sub doesn't support float atomics"
154 );
155 match out_ty.elem() {
156 Elem::Int(_, _) => self
157 .atomic_i_sub(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
158 .unwrap(),
159 Elem::Float(width, None) => {
160 match width {
161 16 => self.capabilities.insert(Capability::AtomicFloat16AddEXT),
162 32 => self.capabilities.insert(Capability::AtomicFloat32AddEXT),
163 64 => self.capabilities.insert(Capability::AtomicFloat64AddEXT),
164 _ => unreachable!(),
165 };
166 let negated = self.f_negate(ty, None, rhs_id).unwrap();
167 self.declare_math_mode(modes, negated);
168 self.atomic_f_add_ext(ty, Some(out_id), lhs_id, memory, semantics, negated)
169 .unwrap()
170 }
171 _ => unreachable!(),
172 };
173 self.atomic_i_sub(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
174 .unwrap();
175 self.write(&out, out_id);
176 }
177 AtomicOp::Max(op) => {
178 let lhs = self.compile_variable(op.lhs);
179 let rhs = self.compile_variable(op.rhs);
180 let out = self.compile_variable(out);
181 let out_ty = out.item();
182
183 let lhs_id = lhs.id(self);
184 let rhs_id = self.read(&rhs);
185 let out_id = self.write_id(&out);
186
187 let ty = out_ty.id(self);
188 let memory = self.scope(&lhs);
189 let semantics = self.semantics_rw(&lhs);
190
191 match out_ty.elem() {
192 Elem::Int(_, false) => self
193 .atomic_u_max(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
194 .unwrap(),
195 Elem::Int(_, true) => self
196 .atomic_s_max(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
197 .unwrap(),
198 Elem::Float(width, None) => {
199 match width {
200 16 => self.capabilities.insert(Capability::AtomicFloat16MinMaxEXT),
201 32 => self.capabilities.insert(Capability::AtomicFloat32MinMaxEXT),
202 64 => self.capabilities.insert(Capability::AtomicFloat64MinMaxEXT),
203 _ => unreachable!(),
204 };
205 self.atomic_f_max_ext(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
206 .unwrap()
207 }
208 _ => unreachable!(),
209 };
210 self.write(&out, out_id);
211 }
212 AtomicOp::Min(op) => {
213 let lhs = self.compile_variable(op.lhs);
214 let rhs = self.compile_variable(op.rhs);
215 let out = self.compile_variable(out);
216 let out_ty = out.item();
217
218 let lhs_id = lhs.id(self);
219 let rhs_id = self.read(&rhs);
220 let out_id = self.write_id(&out);
221
222 let ty = out_ty.id(self);
223 let memory = self.scope(&lhs);
224 let semantics = self.semantics_rw(&lhs);
225
226 match out_ty.elem() {
227 Elem::Int(_, false) => self
228 .atomic_u_min(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
229 .unwrap(),
230 Elem::Int(_, true) => self
231 .atomic_s_min(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
232 .unwrap(),
233 Elem::Float(width, None) => {
234 match width {
235 16 => self.capabilities.insert(Capability::AtomicFloat16MinMaxEXT),
236 32 => self.capabilities.insert(Capability::AtomicFloat32MinMaxEXT),
237 64 => self.capabilities.insert(Capability::AtomicFloat64MinMaxEXT),
238 _ => unreachable!(),
239 };
240 self.atomic_f_min_ext(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
241 .unwrap()
242 }
243 _ => unreachable!(),
244 };
245 self.write(&out, out_id);
246 }
247 AtomicOp::And(op) => {
248 let lhs = self.compile_variable(op.lhs);
249 let rhs = self.compile_variable(op.rhs);
250 let out = self.compile_variable(out);
251 let out_ty = out.item();
252
253 let lhs_id = lhs.id(self);
254 let rhs_id = self.read(&rhs);
255 let out_id = self.write_id(&out);
256
257 let ty = out_ty.id(self);
258 let memory = self.scope(&lhs);
259 let semantics = self.semantics_rw(&lhs);
260
261 assert!(
262 matches!(out_ty.elem(), Elem::Int(_, _)),
263 "and doesn't support float atomics"
264 );
265 self.atomic_and(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
266 .unwrap();
267 self.write(&out, out_id);
268 }
269 AtomicOp::Or(op) => {
270 let lhs = self.compile_variable(op.lhs);
271 let rhs = self.compile_variable(op.rhs);
272 let out = self.compile_variable(out);
273 let out_ty = out.item();
274
275 let lhs_id = lhs.id(self);
276 let rhs_id = self.read(&rhs);
277 let out_id = self.write_id(&out);
278
279 let ty = out_ty.id(self);
280 let memory = self.scope(&lhs);
281 let semantics = self.semantics_rw(&lhs);
282
283 assert!(
284 matches!(out_ty.elem(), Elem::Int(_, _)),
285 "or doesn't support float atomics"
286 );
287 self.atomic_or(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
288 .unwrap();
289 self.write(&out, out_id);
290 }
291 AtomicOp::Xor(op) => {
292 let lhs = self.compile_variable(op.lhs);
293 let rhs = self.compile_variable(op.rhs);
294 let out = self.compile_variable(out);
295 let out_ty = out.item();
296
297 let lhs_id = lhs.id(self);
298 let rhs_id = self.read(&rhs);
299 let out_id = self.write_id(&out);
300
301 let ty = out_ty.id(self);
302 let memory = self.scope(&lhs);
303 let semantics = self.semantics_rw(&lhs);
304
305 assert!(
306 matches!(out_ty.elem(), Elem::Int(_, _)),
307 "xor doesn't support float atomics"
308 );
309 self.atomic_xor(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
310 .unwrap();
311 self.write(&out, out_id);
312 }
313 }
314 }
315
316 fn scope(&mut self, var: &crate::variable::Variable) -> Word {
317 let value = self.scope_of(var) as u32;
318 self.const_u32(value)
319 }
320
321 fn semantics_r(&mut self, var: &crate::variable::Variable) -> Word {
322 let value = self.semantics_of(var) | MemorySemantics::ACQUIRE;
323 self.const_u32(value.bits())
324 }
325
326 fn semantics_w(&mut self, var: &crate::variable::Variable) -> Word {
327 let value = self.semantics_of(var) | MemorySemantics::RELEASE;
328 self.const_u32(value.bits())
329 }
330
331 fn semantics_rw(&mut self, var: &crate::variable::Variable) -> Word {
332 let value = self.semantics_of(var) | MemorySemantics::ACQUIRE_RELEASE;
333 self.const_u32(value.bits())
334 }
335
336 fn scope_of(&mut self, var: &crate::variable::Variable) -> Scope {
337 let id = var.id(self);
338 *self
339 .state
340 .atomic_scopes
341 .get(&id)
342 .expect("Atomic should have a scope registered")
343 }
344
345 fn semantics_of(&mut self, var: &crate::variable::Variable) -> MemorySemantics {
346 match self.scope_of(var) {
347 Scope::Device => MemorySemantics::UNIFORM_MEMORY,
348 Scope::Workgroup => MemorySemantics::WORKGROUP_MEMORY,
349 Scope::Subgroup => MemorySemantics::SUBGROUP_MEMORY,
350 other => unreachable!("Invalid scope for atomic operation, {other:?}"),
351 }
352 }
353}