use cubecl_core::ir::{AtomicOp, Variable};
use rspirv::spirv::{Capability, MemorySemantics, Scope};

use crate::{SpirvCompiler, SpirvTarget, item::Elem};

impl<T: SpirvTarget> SpirvCompiler<T> {
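    /// Compiles a CubeCL atomic IR operation into the corresponding SPIR-V
    /// atomic instruction. All atomics are emitted with `Device` scope and
    /// `UniformMemory` semantics.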
    pub fn compile_atomic(&mut self, atomic: AtomicOp, out: Option<Variable>) {
        let out = out.unwrap();
        match atomic {
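            // Atomic read from the pointer, lowered to `OpAtomicLoad`.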
            AtomicOp::Load(op) => {
                let input = self.compile_variable(op.input);
                let out = self.compile_variable(out);
                let out_ty = out.item();

                let input_id = input.id(self);
                let out_id = self.write_id(&out);

                let ty = out_ty.id(self);
                let memory = self.const_u32(Scope::Device as u32);
                let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());

                self.atomic_load(ty, Some(out_id), input_id, memory, semantics)
                    .unwrap();
                self.write(&out, out_id);
            }
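            // Atomic write to the pointer, lowered to `OpAtomicStore`.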
            AtomicOp::Store(op) => {
                let input = self.compile_variable(op.input);
                let out = self.compile_variable(out);

                let input_id = self.read(&input);
                let out_id = out.id(self);

                let memory = self.const_u32(Scope::Device as u32);
                let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());

                self.atomic_store(out_id, memory, semantics, input_id)
                    .unwrap();
            }
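            // Atomic exchange: stores `rhs` and returns the previous value
            // via `OpAtomicExchange`.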
            AtomicOp::Swap(op) => {
                let lhs = self.compile_variable(op.lhs);
                let rhs = self.compile_variable(op.rhs);
                let out = self.compile_variable(out);
                let out_ty = out.item();

                let lhs_id = lhs.id(self);
                let rhs_id = self.read(&rhs);
                let out_id = self.write_id(&out);

                let ty = out_ty.id(self);
                let memory = self.const_u32(Scope::Device as u32);
                let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());

                self.atomic_exchange(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
                    .unwrap();
                self.write(&out, out_id);
            }
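            // Compare-and-swap, lowered to `OpAtomicCompareExchange`; only integer
            // operands are supported.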
            AtomicOp::CompareAndSwap(op) => {
                let atomic = self.compile_variable(op.input);
                let cmp = self.compile_variable(op.cmp);
                let val = self.compile_variable(op.val);
                let out = self.compile_variable(out);
                let out_ty = out.item();

                let atomic_id = atomic.id(self);
                let cmp_id = self.read(&cmp);
                let val_id = self.read(&val);
                let out_id = self.write_id(&out);

                let ty = out_ty.id(self);
                let memory = self.const_u32(Scope::Device as u32);
                let semantics_success = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());
                let semantics_failure = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());

                assert!(
                    matches!(out_ty.elem(), Elem::Int(_, _)),
                    "compare and swap doesn't support float atomics"
                );
                self.atomic_compare_exchange(
                    ty,
                    Some(out_id),
                    atomic_id,
                    memory,
                    semantics_success,
                    semantics_failure,
                    val_id,
                    cmp_id,
                )
                .unwrap();
                self.write(&out, out_id);
            }
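            // Atomic add: `OpAtomicIAdd` for integers, `OpAtomicFAddEXT` for floats
            // (requires the matching `AtomicFloat*AddEXT` capability).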
            AtomicOp::Add(op) => {
                let lhs = self.compile_variable(op.lhs);
                let rhs = self.compile_variable(op.rhs);
                let out = self.compile_variable(out);
                let out_ty = out.item();

                let lhs_id = lhs.id(self);
                let rhs_id = self.read(&rhs);
                let out_id = self.write_id(&out);

                let ty = out_ty.id(self);
                let memory = self.const_u32(Scope::Device as u32);
                let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());

                match out_ty.elem() {
                    Elem::Int(_, _) => self
                        .atomic_i_add(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
                        .unwrap(),
                    Elem::Float(width, None) => {
                        match width {
                            16 => self.capabilities.insert(Capability::AtomicFloat16AddEXT),
                            32 => self.capabilities.insert(Capability::AtomicFloat32AddEXT),
                            64 => self.capabilities.insert(Capability::AtomicFloat64AddEXT),
                            _ => unreachable!(),
                        };
                        self.atomic_f_add_ext(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
                            .unwrap()
                    }
                    _ => unreachable!(),
                };

                self.write(&out, out_id);
            }
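            // Atomic subtract: `OpAtomicISub` for integers. SPIR-V has no float
            // atomic subtract, so floats negate `rhs` and use `OpAtomicFAddEXT`.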
            AtomicOp::Sub(op) => {
                let lhs = self.compile_variable(op.lhs);
                let rhs = self.compile_variable(op.rhs);
                let out = self.compile_variable(out);
                let out_ty = out.item();

                let lhs_id = lhs.id(self);
                let rhs_id = self.read(&rhs);
                let out_id = self.write_id(&out);

                let ty = out_ty.id(self);
                let memory = self.const_u32(Scope::Device as u32);
                let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());

                match out_ty.elem() {
                    Elem::Int(_, _) => self
                        .atomic_i_sub(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
                        .unwrap(),
                    Elem::Float(width, None) => {
                        match width {
                            16 => self.capabilities.insert(Capability::AtomicFloat16AddEXT),
                            32 => self.capabilities.insert(Capability::AtomicFloat32AddEXT),
                            64 => self.capabilities.insert(Capability::AtomicFloat64AddEXT),
                            _ => unreachable!(),
                        };
                        // Negate the operand and emit an atomic float add, since
                        // there is no dedicated atomic float subtract instruction.
                        let negated = self.f_negate(ty, None, rhs_id).unwrap();
                        self.atomic_f_add_ext(ty, Some(out_id), lhs_id, memory, semantics, negated)
                            .unwrap()
                    }
                    _ => unreachable!(),
                };
                self.write(&out, out_id);
            }
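            // Atomic max: unsigned, signed, and float variants map to
            // `OpAtomicUMax`, `OpAtomicSMax`, and `OpAtomicFMaxEXT` respectively.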
            AtomicOp::Max(op) => {
                let lhs = self.compile_variable(op.lhs);
                let rhs = self.compile_variable(op.rhs);
                let out = self.compile_variable(out);
                let out_ty = out.item();

                let lhs_id = lhs.id(self);
                let rhs_id = self.read(&rhs);
                let out_id = self.write_id(&out);

                let ty = out_ty.id(self);
                let memory = self.const_u32(Scope::Device as u32);
                let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());

                match out_ty.elem() {
                    Elem::Int(_, false) => self
                        .atomic_u_max(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
                        .unwrap(),
                    Elem::Int(_, true) => self
                        .atomic_s_max(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
                        .unwrap(),
                    Elem::Float(width, None) => {
                        match width {
                            16 => self.capabilities.insert(Capability::AtomicFloat16MinMaxEXT),
                            32 => self.capabilities.insert(Capability::AtomicFloat32MinMaxEXT),
                            64 => self.capabilities.insert(Capability::AtomicFloat64MinMaxEXT),
                            _ => unreachable!(),
                        };
                        self.atomic_f_max_ext(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
                            .unwrap()
                    }
                    _ => unreachable!(),
                };
                self.write(&out, out_id);
            }
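            // Atomic min: unsigned, signed, and float variants map to
            // `OpAtomicUMin`, `OpAtomicSMin`, and `OpAtomicFMinEXT` respectively.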
            AtomicOp::Min(op) => {
                let lhs = self.compile_variable(op.lhs);
                let rhs = self.compile_variable(op.rhs);
                let out = self.compile_variable(out);
                let out_ty = out.item();

                let lhs_id = lhs.id(self);
                let rhs_id = self.read(&rhs);
                let out_id = self.write_id(&out);

                let ty = out_ty.id(self);
                let memory = self.const_u32(Scope::Device as u32);
                let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());

                match out_ty.elem() {
                    Elem::Int(_, false) => self
                        .atomic_u_min(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
                        .unwrap(),
                    Elem::Int(_, true) => self
                        .atomic_s_min(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
                        .unwrap(),
                    Elem::Float(width, None) => {
                        match width {
                            16 => self.capabilities.insert(Capability::AtomicFloat16MinMaxEXT),
                            32 => self.capabilities.insert(Capability::AtomicFloat32MinMaxEXT),
                            64 => self.capabilities.insert(Capability::AtomicFloat64MinMaxEXT),
                            _ => unreachable!(),
                        };
                        self.atomic_f_min_ext(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
                            .unwrap()
                    }
                    _ => unreachable!(),
                };
                self.write(&out, out_id);
            }
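            // Atomic bitwise AND: integer-only, lowered to `OpAtomicAnd`.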
            AtomicOp::And(op) => {
                let lhs = self.compile_variable(op.lhs);
                let rhs = self.compile_variable(op.rhs);
                let out = self.compile_variable(out);
                let out_ty = out.item();

                let lhs_id = lhs.id(self);
                let rhs_id = self.read(&rhs);
                let out_id = self.write_id(&out);

                let ty = out_ty.id(self);
                let memory = self.const_u32(Scope::Device as u32);
                let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());

                assert!(
                    matches!(out_ty.elem(), Elem::Int(_, _)),
                    "and doesn't support float atomics"
                );
                self.atomic_and(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
                    .unwrap();
                self.write(&out, out_id);
            }
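            // Atomic bitwise OR: integer-only, lowered to `OpAtomicOr`.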
            AtomicOp::Or(op) => {
                let lhs = self.compile_variable(op.lhs);
                let rhs = self.compile_variable(op.rhs);
                let out = self.compile_variable(out);
                let out_ty = out.item();

                let lhs_id = lhs.id(self);
                let rhs_id = self.read(&rhs);
                let out_id = self.write_id(&out);

                let ty = out_ty.id(self);
                let memory = self.const_u32(Scope::Device as u32);
                let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());

                assert!(
                    matches!(out_ty.elem(), Elem::Int(_, _)),
                    "or doesn't support float atomics"
                );
                self.atomic_or(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
                    .unwrap();
                self.write(&out, out_id);
            }
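            // Atomic bitwise XOR: integer-only, lowered to `OpAtomicXor`.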
            AtomicOp::Xor(op) => {
                let lhs = self.compile_variable(op.lhs);
                let rhs = self.compile_variable(op.rhs);
                let out = self.compile_variable(out);
                let out_ty = out.item();

                let lhs_id = lhs.id(self);
                let rhs_id = self.read(&rhs);
                let out_id = self.write_id(&out);

                let ty = out_ty.id(self);
                let memory = self.const_u32(Scope::Device as u32);
                let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits());

                assert!(
                    matches!(out_ty.elem(), Elem::Int(_, _)),
                    "xor doesn't support float atomics"
                );
                self.atomic_xor(ty, Some(out_id), lhs_id, memory, semantics, rhs_id)
                    .unwrap();
                self.write(&out, out_id);
            }
        }
    }
}