1use std::fmt::Display;
2
3use serde::{Deserialize, Serialize};
4
5use super::{
6 init_expand_element, ExpandElementBaseInit, ExpandElementTyped, Int, IntoRuntime,
7 LaunchArgExpand, Numeric,
8};
9use crate::{
10 frontend::{CubeContext, CubePrimitive, CubeType, ExpandElement},
11 ir::{
12 BinaryOperator, CompareAndSwapOperator, Elem, Instruction, Item, Operation, UnaryOperator,
13 },
14 prelude::KernelBuilder,
15 unexpanded,
16};
17
18#[derive(Clone, Copy, Hash, PartialEq, Eq)]
22pub struct Atomic<Inner: CubePrimitive> {
23 pub val: Inner,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
27pub enum AtomicOp {
28 Load(UnaryOperator),
29 Store(UnaryOperator),
30 Swap(BinaryOperator),
31 Add(BinaryOperator),
32 Sub(BinaryOperator),
33 Max(BinaryOperator),
34 Min(BinaryOperator),
35 And(BinaryOperator),
36 Or(BinaryOperator),
37 Xor(BinaryOperator),
38 CompareAndSwap(CompareAndSwapOperator),
39}
40
41impl Display for AtomicOp {
42 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43 match self {
44 AtomicOp::Load(op) => write!(f, "atomic_load({})", op.input),
45 AtomicOp::Store(op) => write!(f, "atomic_store({})", op.input),
46 AtomicOp::Swap(op) => {
47 write!(f, "atomic_swap({}, {})", op.lhs, op.rhs)
48 }
49 AtomicOp::Add(op) => write!(f, "atomic_add({}, {})", op.lhs, op.rhs),
50 AtomicOp::Sub(op) => write!(f, "atomic_sub({}, {})", op.lhs, op.rhs),
51 AtomicOp::Max(op) => write!(f, "atomic_max({}, {})", op.lhs, op.rhs),
52 AtomicOp::Min(op) => write!(f, "atomic_min({}, {})", op.lhs, op.rhs),
53 AtomicOp::And(op) => write!(f, "atomic_and({}, {})", op.lhs, op.rhs),
54 AtomicOp::Or(op) => write!(f, "atomic_or({}, {})", op.lhs, op.rhs),
55 AtomicOp::Xor(op) => write!(f, "atomic_xor({}, {})", op.lhs, op.rhs),
56 AtomicOp::CompareAndSwap(op) => {
57 write!(f, "compare_and_swap({}, {}, {})", op.input, op.cmp, op.val)
58 }
59 }
60 }
61}
62
63impl<Inner: Numeric> Atomic<Inner> {
64 #[allow(unused_variables)]
66 pub fn load(pointer: &Self) -> Inner {
67 unexpanded!()
68 }
69
70 #[allow(unused_variables)]
72 pub fn store(pointer: &Self, value: Inner) {
73 unexpanded!()
74 }
75
76 #[allow(unused_variables)]
78 pub fn swap(pointer: &Self, value: Inner) -> Inner {
79 unexpanded!()
80 }
81
82 #[allow(unused_variables)]
84 pub fn add(pointer: &Self, value: Inner) -> Inner {
85 unexpanded!()
86 }
87
88 #[allow(unused_variables)]
91 pub fn max(pointer: &Self, value: Inner) -> Inner {
92 unexpanded!()
93 }
94
95 #[allow(unused_variables)]
98 pub fn min(pointer: &Self, value: Inner) -> Inner {
99 unexpanded!()
100 }
101
102 #[allow(unused_variables)]
104 pub fn sub(pointer: &Self, value: Inner) -> Inner {
105 unexpanded!()
106 }
107
108 pub fn __expand_load(
109 context: &mut CubeContext,
110 pointer: <Self as CubeType>::ExpandType,
111 ) -> <Inner as CubeType>::ExpandType {
112 let pointer: ExpandElement = pointer.into();
113 let new_var = context.create_local(Item::new(Inner::as_elem(context)));
114 context.register(Instruction::new(
115 AtomicOp::Load(UnaryOperator { input: *pointer }),
116 *new_var,
117 ));
118 new_var.into()
119 }
120
121 pub fn __expand_store(
122 context: &mut CubeContext,
123 pointer: <Self as CubeType>::ExpandType,
124 value: <Inner as CubeType>::ExpandType,
125 ) {
126 let ptr: ExpandElement = pointer.into();
127 let value: ExpandElement = value.into();
128 context.register(Instruction::new(
129 AtomicOp::Store(UnaryOperator { input: *value }),
130 *ptr,
131 ));
132 }
133
134 pub fn __expand_swap(
135 context: &mut CubeContext,
136 pointer: <Self as CubeType>::ExpandType,
137 value: <Inner as CubeType>::ExpandType,
138 ) -> <Inner as CubeType>::ExpandType {
139 let ptr: ExpandElement = pointer.into();
140 let value: ExpandElement = value.into();
141 let new_var = context.create_local(Item::new(Inner::as_elem(context)));
142 context.register(Instruction::new(
143 AtomicOp::Swap(BinaryOperator {
144 lhs: *ptr,
145 rhs: *value,
146 }),
147 *new_var,
148 ));
149 new_var.into()
150 }
151
152 pub fn __expand_add(
153 context: &mut CubeContext,
154 pointer: <Self as CubeType>::ExpandType,
155 value: <Inner as CubeType>::ExpandType,
156 ) -> <Inner as CubeType>::ExpandType {
157 let ptr: ExpandElement = pointer.into();
158 let value: ExpandElement = value.into();
159 let new_var = context.create_local(Item::new(Inner::as_elem(context)));
160 context.register(Instruction::new(
161 AtomicOp::Add(BinaryOperator {
162 lhs: *ptr,
163 rhs: *value,
164 }),
165 *new_var,
166 ));
167 new_var.into()
168 }
169
170 pub fn __expand_sub(
171 context: &mut CubeContext,
172 pointer: <Self as CubeType>::ExpandType,
173 value: <Inner as CubeType>::ExpandType,
174 ) -> <Inner as CubeType>::ExpandType {
175 let ptr: ExpandElement = pointer.into();
176 let value: ExpandElement = value.into();
177 let new_var = context.create_local(Item::new(Inner::as_elem(context)));
178 context.register(Instruction::new(
179 AtomicOp::Sub(BinaryOperator {
180 lhs: *ptr,
181 rhs: *value,
182 }),
183 *new_var,
184 ));
185 new_var.into()
186 }
187
188 pub fn __expand_max(
189 context: &mut CubeContext,
190 pointer: <Self as CubeType>::ExpandType,
191 value: <Inner as CubeType>::ExpandType,
192 ) -> <Inner as CubeType>::ExpandType {
193 let ptr: ExpandElement = pointer.into();
194 let value: ExpandElement = value.into();
195 let new_var = context.create_local(Item::new(Inner::as_elem(context)));
196 context.register(Instruction::new(
197 AtomicOp::Max(BinaryOperator {
198 lhs: *ptr,
199 rhs: *value,
200 }),
201 *new_var,
202 ));
203 new_var.into()
204 }
205
206 pub fn __expand_min(
207 context: &mut CubeContext,
208 pointer: <Self as CubeType>::ExpandType,
209 value: <Inner as CubeType>::ExpandType,
210 ) -> <Inner as CubeType>::ExpandType {
211 let ptr: ExpandElement = pointer.into();
212 let value: ExpandElement = value.into();
213 let new_var = context.create_local(Item::new(Inner::as_elem(context)));
214 context.register(Instruction::new(
215 AtomicOp::Min(BinaryOperator {
216 lhs: *ptr,
217 rhs: *value,
218 }),
219 *new_var,
220 ));
221 new_var.into()
222 }
223}
224
225impl<Inner: Int> Atomic<Inner> {
226 #[allow(unused_variables)]
232 pub fn compare_and_swap(pointer: &Self, cmp: Inner, value: Inner) -> Inner {
233 unexpanded!()
234 }
235
236 #[allow(unused_variables)]
238 pub fn and(pointer: &Self, value: Inner) -> Inner {
239 unexpanded!()
240 }
241
242 #[allow(unused_variables)]
244 pub fn or(pointer: &Self, value: Inner) -> Inner {
245 unexpanded!()
246 }
247
248 #[allow(unused_variables)]
250 pub fn xor(pointer: &Self, value: Inner) -> Inner {
251 unexpanded!()
252 }
253
254 pub fn __expand_compare_and_swap(
255 context: &mut CubeContext,
256 pointer: <Self as CubeType>::ExpandType,
257 cmp: <Inner as CubeType>::ExpandType,
258 value: <Inner as CubeType>::ExpandType,
259 ) -> <Inner as CubeType>::ExpandType {
260 let pointer: ExpandElement = pointer.into();
261 let cmp: ExpandElement = cmp.into();
262 let value: ExpandElement = value.into();
263 let new_var = context.create_local(Item::new(Inner::as_elem(context)));
264 context.register(Instruction::new(
265 AtomicOp::CompareAndSwap(CompareAndSwapOperator {
266 input: *pointer,
267 cmp: *cmp,
268 val: *value,
269 }),
270 *new_var,
271 ));
272 new_var.into()
273 }
274
275 pub fn __expand_and(
276 context: &mut CubeContext,
277 pointer: <Self as CubeType>::ExpandType,
278 value: <Inner as CubeType>::ExpandType,
279 ) -> <Inner as CubeType>::ExpandType {
280 let ptr: ExpandElement = pointer.into();
281 let value: ExpandElement = value.into();
282 let new_var = context.create_local(Item::new(Inner::as_elem(context)));
283 context.register(Instruction::new(
284 AtomicOp::And(BinaryOperator {
285 lhs: *ptr,
286 rhs: *value,
287 }),
288 *new_var,
289 ));
290 new_var.into()
291 }
292
293 pub fn __expand_or(
294 context: &mut CubeContext,
295 pointer: <Self as CubeType>::ExpandType,
296 value: <Inner as CubeType>::ExpandType,
297 ) -> <Inner as CubeType>::ExpandType {
298 let ptr: ExpandElement = pointer.into();
299 let value: ExpandElement = value.into();
300 let new_var = context.create_local(Item::new(Inner::as_elem(context)));
301 context.register(Instruction::new(
302 AtomicOp::Or(BinaryOperator {
303 lhs: *ptr,
304 rhs: *value,
305 }),
306 *new_var,
307 ));
308 new_var.into()
309 }
310
311 pub fn __expand_xor(
312 context: &mut CubeContext,
313 pointer: <Self as CubeType>::ExpandType,
314 value: <Inner as CubeType>::ExpandType,
315 ) -> <Inner as CubeType>::ExpandType {
316 let ptr: ExpandElement = pointer.into();
317 let value: ExpandElement = value.into();
318 let new_var = context.create_local(Item::new(Inner::as_elem(context)));
319 context.register(Instruction::new(
320 AtomicOp::Xor(BinaryOperator {
321 lhs: *ptr,
322 rhs: *value,
323 }),
324 *new_var,
325 ));
326 new_var.into()
327 }
328}
329
330impl<Inner: CubePrimitive> CubeType for Atomic<Inner> {
331 type ExpandType = ExpandElementTyped<Self>;
332}
333
334impl<Inner: CubePrimitive> IntoRuntime for Atomic<Inner> {
335 fn __expand_runtime_method(self, _context: &mut CubeContext) -> Self::ExpandType {
336 unimplemented!("Atomics don't exist at compile time")
337 }
338}
339
340impl<Inner: CubePrimitive> CubePrimitive for Atomic<Inner> {
341 fn as_elem_native() -> Option<Elem> {
342 match Inner::as_elem_native() {
343 Some(Elem::Float(kind)) => Some(Elem::AtomicFloat(kind)),
344 Some(Elem::Int(kind)) => Some(Elem::AtomicInt(kind)),
345 Some(Elem::UInt(kind)) => Some(Elem::AtomicUInt(kind)),
346 None => None,
347 _ => unreachable!("Atomics can only be float/int/uint"),
348 }
349 }
350
351 fn as_elem(context: &CubeContext) -> Elem {
352 match Inner::as_elem(context) {
353 Elem::Float(kind) => Elem::AtomicFloat(kind),
354 Elem::Int(kind) => Elem::AtomicInt(kind),
355 Elem::UInt(kind) => Elem::AtomicUInt(kind),
356 _ => unreachable!("Atomics can only be float/int/uint"),
357 }
358 }
359
360 fn as_elem_native_unchecked() -> Elem {
361 match Inner::as_elem_native_unchecked() {
362 Elem::Float(kind) => Elem::AtomicFloat(kind),
363 Elem::Int(kind) => Elem::AtomicInt(kind),
364 Elem::UInt(kind) => Elem::AtomicUInt(kind),
365 _ => unreachable!("Atomics can only be float/int/uint"),
366 }
367 }
368
369 fn size() -> Option<usize> {
370 Inner::size()
371 }
372
373 fn from_expand_elem(elem: ExpandElement) -> Self::ExpandType {
374 ExpandElementTyped::new(elem)
375 }
376}
377
378impl<Inner: CubePrimitive> ExpandElementBaseInit for Atomic<Inner> {
379 fn init_elem(context: &mut CubeContext, elem: ExpandElement) -> ExpandElement {
380 init_expand_element(context, elem)
381 }
382}
383
384impl<Inner: CubePrimitive> LaunchArgExpand for Atomic<Inner> {
385 type CompilationArg = ();
386
387 fn expand(_: &Self::CompilationArg, builder: &mut KernelBuilder) -> ExpandElementTyped<Self> {
388 builder.scalar(Self::as_elem_native_unchecked()).into()
389 }
390}
391
392impl From<AtomicOp> for Operation {
393 fn from(value: AtomicOp) -> Self {
394 Operation::Atomic(value)
395 }
396}