cubecl_core/frontend/element/atomic.rs

1use cubecl_ir::{AtomicOp, ConstantValue, ExpandElement, StorageType};
2use cubecl_macros::intrinsic;
3
4use super::{ExpandElementIntoMut, ExpandElementTyped, Int, Numeric, into_mut_expand_element};
5use crate::{
6    self as cubecl,
7    frontend::{CubePrimitive, CubeType},
8    ir::{BinaryOperator, CompareAndSwapOperator, Instruction, Scope, Type, UnaryOperator},
9    prelude::*,
10};
11
12/// An atomic numerical type wrapping a normal numeric primitive. Enables the use of atomic
13/// operations, while disabling normal operations. In WGSL, this is a separate type - on CUDA/SPIR-V
14/// it can theoretically be bitcast to a normal number, but this isn't recommended.
15#[derive(Clone, Copy, Hash, PartialEq, Eq)]
16pub struct Atomic<Inner: CubePrimitive> {
17    pub val: Inner,
18}
19
20type AtomicExpand<Inner> = ExpandElementTyped<Atomic<Inner>>;
21
22#[cube]
23impl<Inner: Numeric> Atomic<Inner> {
24    /// Load the value of the atomic.
25    #[allow(unused_variables)]
26    pub fn load(&self) -> Inner {
27        intrinsic!(|scope| {
28            let pointer: ExpandElement = self.into();
29            let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
30            scope.register(Instruction::new(
31                AtomicOp::Load(UnaryOperator { input: *pointer }),
32                *new_var,
33            ));
34            new_var.into()
35        })
36    }
37
38    /// Store the value of the atomic.
39    #[allow(unused_variables)]
40    pub fn store(&self, value: Inner) {
41        intrinsic!(|scope| {
42            let ptr: ExpandElement = self.into();
43            let value: ExpandElement = value.into();
44            scope.register(Instruction::new(
45                AtomicOp::Store(UnaryOperator { input: *value }),
46                *ptr,
47            ));
48        })
49    }
50
51    /// Atomically stores the value into the atomic and returns the old value.
52    #[allow(unused_variables)]
53    pub fn swap(&self, value: Inner) -> Inner {
54        intrinsic!(|scope| {
55            let ptr: ExpandElement = self.into();
56            let value: ExpandElement = value.into();
57            let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
58            scope.register(Instruction::new(
59                AtomicOp::Swap(BinaryOperator {
60                    lhs: *ptr,
61                    rhs: *value,
62                }),
63                *new_var,
64            ));
65            new_var.into()
66        })
67    }
68
69    /// Atomically add a number to the atomic variable. Returns the old value.
70    #[allow(unused_variables)]
71    pub fn fetch_add(&self, value: Inner) -> Inner {
72        intrinsic!(|scope| {
73            let ptr: ExpandElement = self.into();
74            let value: ExpandElement = value.into();
75            let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
76            scope.register(Instruction::new(
77                AtomicOp::Add(BinaryOperator {
78                    lhs: *ptr,
79                    rhs: *value,
80                }),
81                *new_var,
82            ));
83            new_var.into()
84        })
85    }
86
87    /// Atomically subtracts a number from the atomic variable. Returns the old value.
88    #[allow(unused_variables)]
89    pub fn fetch_sub(&self, value: Inner) -> Inner {
90        intrinsic!(|scope| {
91            let ptr: ExpandElement = self.into();
92            let value: ExpandElement = value.into();
93            let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
94            scope.register(Instruction::new(
95                AtomicOp::Sub(BinaryOperator {
96                    lhs: *ptr,
97                    rhs: *value,
98                }),
99                *new_var,
100            ));
101            new_var.into()
102        })
103    }
104
105    /// Atomically sets the value of the atomic variable to `max(current_value, value)`. Returns
106    /// the old value.
107    #[allow(unused_variables)]
108    pub fn fetch_max(&self, value: Inner) -> Inner {
109        intrinsic!(|scope| {
110            let ptr: ExpandElement = self.into();
111            let value: ExpandElement = value.into();
112            let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
113            scope.register(Instruction::new(
114                AtomicOp::Max(BinaryOperator {
115                    lhs: *ptr,
116                    rhs: *value,
117                }),
118                *new_var,
119            ));
120            new_var.into()
121        })
122    }
123
124    /// Atomically sets the value of the atomic variable to `min(current_value, value)`. Returns the
125    /// old value.
126    #[allow(unused_variables)]
127    pub fn fetch_min(&self, value: Inner) -> Inner {
128        intrinsic!(|scope| {
129            let ptr: ExpandElement = self.into();
130            let value: ExpandElement = value.into();
131            let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
132            scope.register(Instruction::new(
133                AtomicOp::Min(BinaryOperator {
134                    lhs: *ptr,
135                    rhs: *value,
136                }),
137                *new_var,
138            ));
139            new_var.into()
140        })
141    }
142}
143
144#[cube]
145impl<Inner: Int> Atomic<Inner> {
146    /// Compare the value at `pointer` to `cmp` and set it to `value` only if they are the same.
147    /// Returns the old value of the pointer before the store.
148    ///
149    /// ### Tip
150    /// Compare the returned value to `cmp` to determine whether the store was successful.
151    #[allow(unused_variables)]
152    pub fn compare_exchange_weak(&self, cmp: Inner, value: Inner) -> Inner {
153        intrinsic!(|scope| {
154            let pointer: ExpandElement = self.into();
155            let cmp: ExpandElement = cmp.into();
156            let value: ExpandElement = value.into();
157            let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
158            scope.register(Instruction::new(
159                AtomicOp::CompareAndSwap(CompareAndSwapOperator {
160                    input: *pointer,
161                    cmp: *cmp,
162                    val: *value,
163                }),
164                *new_var,
165            ));
166            new_var.into()
167        })
168    }
169
170    /// Executes an atomic bitwise and operation on the atomic variable. Returns the old value.
171    #[allow(unused_variables)]
172    pub fn fetch_and(&self, value: Inner) -> Inner {
173        intrinsic!(|scope| {
174            let ptr: ExpandElement = self.into();
175            let value: ExpandElement = value.into();
176            let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
177            scope.register(Instruction::new(
178                AtomicOp::And(BinaryOperator {
179                    lhs: *ptr,
180                    rhs: *value,
181                }),
182                *new_var,
183            ));
184            new_var.into()
185        })
186    }
187
188    /// Executes an atomic bitwise or operation on the atomic variable. Returns the old value.
189    #[allow(unused_variables)]
190    pub fn fetch_or(&self, value: Inner) -> Inner {
191        intrinsic!(|scope| {
192            let ptr: ExpandElement = self.into();
193            let value: ExpandElement = value.into();
194            let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
195            scope.register(Instruction::new(
196                AtomicOp::Or(BinaryOperator {
197                    lhs: *ptr,
198                    rhs: *value,
199                }),
200                *new_var,
201            ));
202            new_var.into()
203        })
204    }
205
206    /// Executes an atomic bitwise xor operation on the atomic variable. Returns the old value.
207    #[allow(unused_variables)]
208    pub fn fetch_xor(&self, value: Inner) -> Inner {
209        intrinsic!(|scope| {
210            let ptr: ExpandElement = self.into();
211            let value: ExpandElement = value.into();
212            let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
213            scope.register(Instruction::new(
214                AtomicOp::Xor(BinaryOperator {
215                    lhs: *ptr,
216                    rhs: *value,
217                }),
218                *new_var,
219            ));
220            new_var.into()
221        })
222    }
223}
224
225impl<Inner: CubePrimitive> CubeType for Atomic<Inner> {
226    type ExpandType = ExpandElementTyped<Self>;
227}
228
229impl<Inner: CubePrimitive> CubePrimitive for Atomic<Inner> {
230    fn as_type_native() -> Option<StorageType> {
231        Inner::as_type_native().map(|it| StorageType::Atomic(it.elem_type()))
232    }
233
234    fn as_type(scope: &Scope) -> StorageType {
235        StorageType::Atomic(Inner::as_type(scope).elem_type())
236    }
237
238    fn as_type_native_unchecked() -> StorageType {
239        StorageType::Atomic(Inner::as_type_native_unchecked().elem_type())
240    }
241
242    fn size() -> Option<usize> {
243        Inner::size()
244    }
245
246    fn from_expand_elem(elem: ExpandElement) -> Self::ExpandType {
247        ExpandElementTyped::new(elem)
248    }
249
250    fn from_const_value(_value: ConstantValue) -> Self {
251        panic!("Can't have constant atomic");
252    }
253}
254
255impl<Inner: CubePrimitive> ExpandElementIntoMut for Atomic<Inner> {
256    fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
257        into_mut_expand_element(scope, elem)
258    }
259}