use cubecl_core::ir::{AtomicOp, InstructionModes, Variable};
use rspirv::spirv::{Capability, MemorySemantics, Scope};

use crate::{SpirvCompiler, SpirvTarget, item::Elem};

impl<T: SpirvTarget> SpirvCompiler<T> {
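    /// Compiles a CubeCL atomic operation into the corresponding SPIR-V atomic
    /// instruction. All atomics use `Device` scope with `UniformMemory` semantics;
    /// float variants additionally enable the required `AtomicFloat*EXT` capabilities.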
    pub fn compile_atomic(
        &mut self,
        atomic: AtomicOp,
        out: Option<Variable>,
        modes: InstructionModes,
    ) {
        let out = out.unwrap();
        match atomic {
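            // OpAtomicLoad: read the current value of the atomic into `out`.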
            AtomicOp::Load(op) => {
                let input = self.compile_variable(op.input);
                let out = self.compile_variable(out);
                let out_ty = out.item();

                let input_id = input.id(self);
                let out_id = self.write_id(&out);

                let ty = out_ty.id(self);
                let memory = self.const_u32(Scope::Device as u32);
                let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());

                self.atomic_load(ty, Some(out_id), input_id, memory, semantics)
                    .unwrap();
                self.write(&out, out_id);
            }
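            // OpAtomicStore: write `input` into the atomic; produces no result value.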
            AtomicOp::Store(op) => {
                let input = self.compile_variable(op.input);
                let out = self.compile_variable(out);

                let input_id = self.read(&input);
                let out_id = out.id(self);

                let memory = self.const_u32(Scope::Device as u32);
                let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());

                self.atomic_store(out_id, memory, semantics, input_id)
                    .unwrap();
            }
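            // OpAtomicExchange: store `rhs` and return the previous value.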
            AtomicOp::Swap(op) => {
                let lhs = self.compile_variable(op.lhs);
                let rhs = self.compile_variable(op.rhs);
                let out = self.compile_variable(out);
                let out_ty = out.item();

                let lhs_id = lhs.id(self);
                let rhs_id = self.read(&rhs);
                let out_id = self.write_id(&out);

                let ty = out_ty.id(self);
                let memory = self.const_u32(Scope::Device as u32);
                let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());

                self.atomic_exchange(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
                    .unwrap();
                self.write(&out, out_id);
            }
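            // OpAtomicCompareExchange: store `val` only if the current value equals
            // `cmp`, returning the original value. Integer-only.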
            AtomicOp::CompareAndSwap(op) => {
                let atomic = self.compile_variable(op.input);
                let cmp = self.compile_variable(op.cmp);
                let val = self.compile_variable(op.val);
                let out = self.compile_variable(out);
                let out_ty = out.item();

                let atomic_id = atomic.id(self);
                let cmp_id = self.read(&cmp);
                let val_id = self.read(&val);
                let out_id = self.write_id(&out);

                let ty = out_ty.id(self);
                let memory = self.const_u32(Scope::Device as u32);
                let semantics_success = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());
                let semantics_failure = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());

                assert!(
                    matches!(out_ty.elem(), Elem::Int(_, _)),
                    "compare and swap doesn't support float atomics"
                );
                self.atomic_compare_exchange(
                    ty,
                    Some(out_id),
                    atomic_id,
                    memory,
                    semantics_success,
                    semantics_failure,
                    val_id,
                    cmp_id,
                )
                .unwrap();
                self.write(&out, out_id);
            }
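            // Integer add maps to OpAtomicIAdd; float add maps to OpAtomicFAddEXT and
            // enables the AtomicFloat{16,32,64}AddEXT capability for the element width.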
            AtomicOp::Add(op) => {
                let lhs = self.compile_variable(op.lhs);
                let rhs = self.compile_variable(op.rhs);
                let out = self.compile_variable(out);
                let out_ty = out.item();

                let lhs_id = lhs.id(self);
                let rhs_id = self.read(&rhs);
                let out_id = self.write_id(&out);

                let ty = out_ty.id(self);
                let memory = self.const_u32(Scope::Device as u32);
                let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());

                match out_ty.elem() {
                    Elem::Int(_, _) => self
                        .atomic_i_add(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
                        .unwrap(),
                    Elem::Float(width, None) => {
                        match width {
                            16 => self.capabilities.insert(Capability::AtomicFloat16AddEXT),
                            32 => self.capabilities.insert(Capability::AtomicFloat32AddEXT),
                            64 => self.capabilities.insert(Capability::AtomicFloat64AddEXT),
                            _ => unreachable!(),
                        };
                        self.atomic_f_add_ext(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
                            .unwrap()
                    }
                    _ => unreachable!(),
                };

                self.write(&out, out_id);
            }
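            // Integer sub maps to OpAtomicISub. SPIR-V has no dedicated float atomic
            // sub, so floats negate `rhs` and reuse OpAtomicFAddEXT instead.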
            AtomicOp::Sub(op) => {
                let lhs = self.compile_variable(op.lhs);
                let rhs = self.compile_variable(op.rhs);
                let out = self.compile_variable(out);
                let out_ty = out.item();

                let lhs_id = lhs.id(self);
                let rhs_id = self.read(&rhs);
                let out_id = self.write_id(&out);

                let ty = out_ty.id(self);
                let memory = self.const_u32(Scope::Device as u32);
                let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());

                match out_ty.elem() {
                    Elem::Int(_, _) => self
                        .atomic_i_sub(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
                        .unwrap(),
                    Elem::Float(width, None) => {
                        match width {
                            16 => self.capabilities.insert(Capability::AtomicFloat16AddEXT),
                            32 => self.capabilities.insert(Capability::AtomicFloat32AddEXT),
                            64 => self.capabilities.insert(Capability::AtomicFloat64AddEXT),
                            _ => unreachable!(),
                        };
                        let negated = self.f_negate(ty, None, rhs_id).unwrap();
                        self.declare_math_mode(modes, negated);
                        self.atomic_f_add_ext(ty, Some(out_id), lhs_id, memory, semantics, negated)
                            .unwrap()
                    }
                    _ => unreachable!(),
                };
                self.write(&out, out_id);
            }
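            // Unsigned/signed max map to OpAtomicUMax/OpAtomicSMax; floats use
            // OpAtomicFMaxEXT with the AtomicFloat{16,32,64}MinMaxEXT capability.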
            AtomicOp::Max(op) => {
                let lhs = self.compile_variable(op.lhs);
                let rhs = self.compile_variable(op.rhs);
                let out = self.compile_variable(out);
                let out_ty = out.item();

                let lhs_id = lhs.id(self);
                let rhs_id = self.read(&rhs);
                let out_id = self.write_id(&out);

                let ty = out_ty.id(self);
                let memory = self.const_u32(Scope::Device as u32);
                let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());

                match out_ty.elem() {
                    Elem::Int(_, false) => self
                        .atomic_u_max(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
                        .unwrap(),
                    Elem::Int(_, true) => self
                        .atomic_s_max(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
                        .unwrap(),
                    Elem::Float(width, None) => {
                        match width {
                            16 => self.capabilities.insert(Capability::AtomicFloat16MinMaxEXT),
                            32 => self.capabilities.insert(Capability::AtomicFloat32MinMaxEXT),
                            64 => self.capabilities.insert(Capability::AtomicFloat64MinMaxEXT),
                            _ => unreachable!(),
                        };
                        self.atomic_f_max_ext(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
                            .unwrap()
                    }
                    _ => unreachable!(),
                };
                self.write(&out, out_id);
            }
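            // Unsigned/signed min map to OpAtomicUMin/OpAtomicSMin; floats use
            // OpAtomicFMinEXT with the AtomicFloat{16,32,64}MinMaxEXT capability.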
            AtomicOp::Min(op) => {
                let lhs = self.compile_variable(op.lhs);
                let rhs = self.compile_variable(op.rhs);
                let out = self.compile_variable(out);
                let out_ty = out.item();

                let lhs_id = lhs.id(self);
                let rhs_id = self.read(&rhs);
                let out_id = self.write_id(&out);

                let ty = out_ty.id(self);
                let memory = self.const_u32(Scope::Device as u32);
                let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());

                match out_ty.elem() {
                    Elem::Int(_, false) => self
                        .atomic_u_min(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
                        .unwrap(),
                    Elem::Int(_, true) => self
                        .atomic_s_min(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
                        .unwrap(),
                    Elem::Float(width, None) => {
                        match width {
                            16 => self.capabilities.insert(Capability::AtomicFloat16MinMaxEXT),
                            32 => self.capabilities.insert(Capability::AtomicFloat32MinMaxEXT),
                            64 => self.capabilities.insert(Capability::AtomicFloat64MinMaxEXT),
                            _ => unreachable!(),
                        };
                        self.atomic_f_min_ext(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
                            .unwrap()
                    }
                    _ => unreachable!(),
                };
                self.write(&out, out_id);
            }
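            // OpAtomicAnd: bitwise AND, integer-only.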
            AtomicOp::And(op) => {
                let lhs = self.compile_variable(op.lhs);
                let rhs = self.compile_variable(op.rhs);
                let out = self.compile_variable(out);
                let out_ty = out.item();

                let lhs_id = lhs.id(self);
                let rhs_id = self.read(&rhs);
                let out_id = self.write_id(&out);

                let ty = out_ty.id(self);
                let memory = self.const_u32(Scope::Device as u32);
                let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());

                assert!(
                    matches!(out_ty.elem(), Elem::Int(_, _)),
                    "and doesn't support float atomics"
                );
                self.atomic_and(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
                    .unwrap();
                self.write(&out, out_id);
            }
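            // OpAtomicOr: bitwise OR, integer-only.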
            AtomicOp::Or(op) => {
                let lhs = self.compile_variable(op.lhs);
                let rhs = self.compile_variable(op.rhs);
                let out = self.compile_variable(out);
                let out_ty = out.item();

                let lhs_id = lhs.id(self);
                let rhs_id = self.read(&rhs);
                let out_id = self.write_id(&out);

                let ty = out_ty.id(self);
                let memory = self.const_u32(Scope::Device as u32);
                let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());

                assert!(
                    matches!(out_ty.elem(), Elem::Int(_, _)),
                    "or doesn't support float atomics"
                );
                self.atomic_or(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
                    .unwrap();
                self.write(&out, out_id);
            }
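            // OpAtomicXor: bitwise XOR, integer-only.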
            AtomicOp::Xor(op) => {
                let lhs = self.compile_variable(op.lhs);
                let rhs = self.compile_variable(op.rhs);
                let out = self.compile_variable(out);
                let out_ty = out.item();

                let lhs_id = lhs.id(self);
                let rhs_id = self.read(&rhs);
                let out_id = self.write_id(&out);

                let ty = out_ty.id(self);
                let memory = self.const_u32(Scope::Device as u32);
                let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());

                assert!(
                    matches!(out_ty.elem(), Elem::Int(_, _)),
                    "xor doesn't support float atomics"
                );
                self.atomic_xor(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
                    .unwrap();
                self.write(&out, out_id);
            }
        }
    }
}