cubecl_core/frontend/element/
atomic.rs1use cubecl_ir::{AtomicOp, ConstantValue, ManagedVariable, StorageType};
2use cubecl_macros::intrinsic;
3
4use super::{NativeAssign, NativeExpand, Numeric};
5use crate::{
6 self as cubecl,
7 frontend::{CubePrimitive, CubeType},
8 ir::{BinaryOperator, CompareAndSwapOperator, Instruction, Scope, Type, UnaryOperator},
9 prelude::*,
10};
11
12#[derive(Clone, Copy, Hash, PartialEq, Eq)]
16pub struct Atomic<Inner: CubePrimitive> {
17 pub val: Inner,
18}
19
20type AtomicExpand<Inner> = NativeExpand<Atomic<Inner>>;
21
22#[cube]
23impl<Inner: CubePrimitive<Scalar: Numeric>> Atomic<Inner> {
24 #[allow(unused_variables)]
26 pub fn load(&self) -> Inner {
27 intrinsic!(|scope| {
28 let pointer: ManagedVariable = self.into();
29 let new_var = scope.create_local(Inner::as_type(scope));
30 scope.register(Instruction::new(
31 AtomicOp::Load(UnaryOperator { input: *pointer }),
32 *new_var,
33 ));
34 new_var.into()
35 })
36 }
37
38 #[allow(unused_variables)]
40 pub fn store(&self, value: Inner) {
41 intrinsic!(|scope| {
42 let ptr: ManagedVariable = self.into();
43 let value: ManagedVariable = value.into();
44 scope.register(Instruction::new(
45 AtomicOp::Store(UnaryOperator { input: *value }),
46 *ptr,
47 ));
48 })
49 }
50
51 #[allow(unused_variables)]
53 pub fn swap(&self, value: Inner) -> Inner {
54 intrinsic!(|scope| {
55 let ptr: ManagedVariable = self.into();
56 let value: ManagedVariable = value.into();
57 let new_var = scope.create_local(Inner::as_type(scope));
58 scope.register(Instruction::new(
59 AtomicOp::Swap(BinaryOperator {
60 lhs: *ptr,
61 rhs: *value,
62 }),
63 *new_var,
64 ));
65 new_var.into()
66 })
67 }
68
69 #[allow(unused_variables)]
71 pub fn fetch_add(&self, value: Inner) -> Inner {
72 intrinsic!(|scope| {
73 let ptr: ManagedVariable = self.into();
74 let value: ManagedVariable = value.into();
75 let new_var = scope.create_local(Inner::as_type(scope));
76 scope.register(Instruction::new(
77 AtomicOp::Add(BinaryOperator {
78 lhs: *ptr,
79 rhs: *value,
80 }),
81 *new_var,
82 ));
83 new_var.into()
84 })
85 }
86
87 #[allow(unused_variables)]
89 pub fn fetch_sub(&self, value: Inner) -> Inner {
90 intrinsic!(|scope| {
91 let ptr: ManagedVariable = self.into();
92 let value: ManagedVariable = value.into();
93 let new_var = scope.create_local(Inner::as_type(scope));
94 scope.register(Instruction::new(
95 AtomicOp::Sub(BinaryOperator {
96 lhs: *ptr,
97 rhs: *value,
98 }),
99 *new_var,
100 ));
101 new_var.into()
102 })
103 }
104
105 #[allow(unused_variables)]
108 pub fn fetch_max(&self, value: Inner) -> Inner {
109 intrinsic!(|scope| {
110 let ptr: ManagedVariable = self.into();
111 let value: ManagedVariable = value.into();
112 let new_var = scope.create_local(Inner::as_type(scope));
113 scope.register(Instruction::new(
114 AtomicOp::Max(BinaryOperator {
115 lhs: *ptr,
116 rhs: *value,
117 }),
118 *new_var,
119 ));
120 new_var.into()
121 })
122 }
123
124 #[allow(unused_variables)]
127 pub fn fetch_min(&self, value: Inner) -> Inner {
128 intrinsic!(|scope| {
129 let ptr: ManagedVariable = self.into();
130 let value: ManagedVariable = value.into();
131 let new_var = scope.create_local(Inner::as_type(scope));
132 scope.register(Instruction::new(
133 AtomicOp::Min(BinaryOperator {
134 lhs: *ptr,
135 rhs: *value,
136 }),
137 *new_var,
138 ));
139 new_var.into()
140 })
141 }
142}
143
144#[cube]
145impl<Inner: CubePrimitive<Scalar: Int>> Atomic<Inner> {
146 #[allow(unused_variables)]
152 pub fn compare_exchange_weak(&self, cmp: Inner, value: Inner) -> Inner {
153 intrinsic!(|scope| {
154 let pointer: ManagedVariable = self.into();
155 let cmp: ManagedVariable = cmp.into();
156 let value: ManagedVariable = value.into();
157 let new_var = scope.create_local(Inner::as_type(scope));
158 scope.register(Instruction::new(
159 AtomicOp::CompareAndSwap(CompareAndSwapOperator {
160 input: *pointer,
161 cmp: *cmp,
162 val: *value,
163 }),
164 *new_var,
165 ));
166 new_var.into()
167 })
168 }
169
170 #[allow(unused_variables)]
172 pub fn fetch_and(&self, value: Inner) -> Inner {
173 intrinsic!(|scope| {
174 let ptr: ManagedVariable = self.into();
175 let value: ManagedVariable = value.into();
176 let new_var = scope.create_local(Inner::as_type(scope));
177 scope.register(Instruction::new(
178 AtomicOp::And(BinaryOperator {
179 lhs: *ptr,
180 rhs: *value,
181 }),
182 *new_var,
183 ));
184 new_var.into()
185 })
186 }
187
188 #[allow(unused_variables)]
190 pub fn fetch_or(&self, value: Inner) -> Inner {
191 intrinsic!(|scope| {
192 let ptr: ManagedVariable = self.into();
193 let value: ManagedVariable = value.into();
194 let new_var = scope.create_local(Inner::as_type(scope));
195 scope.register(Instruction::new(
196 AtomicOp::Or(BinaryOperator {
197 lhs: *ptr,
198 rhs: *value,
199 }),
200 *new_var,
201 ));
202 new_var.into()
203 })
204 }
205
206 #[allow(unused_variables)]
208 pub fn fetch_xor(&self, value: Inner) -> Inner {
209 intrinsic!(|scope| {
210 let ptr: ManagedVariable = self.into();
211 let value: ManagedVariable = value.into();
212 let new_var = scope.create_local(Inner::as_type(scope));
213 scope.register(Instruction::new(
214 AtomicOp::Xor(BinaryOperator {
215 lhs: *ptr,
216 rhs: *value,
217 }),
218 *new_var,
219 ));
220 new_var.into()
221 })
222 }
223}
224
225impl<Inner: CubePrimitive> CubeType for Atomic<Inner> {
226 type ExpandType = NativeExpand<Self>;
227}
228
229impl<Inner: CubePrimitive> CubePrimitive for Atomic<Inner> {
230 type Scalar = Inner::Scalar;
231 type Size = Const<1>;
232 type WithScalar<S: Scalar> = Atomic<S>;
233
234 fn as_type_native() -> Option<Type> {
235 Inner::as_type_native().map(|it| it.with_storage_type(StorageType::Atomic(it.elem_type())))
236 }
237
238 fn as_type(scope: &Scope) -> Type {
239 let inner = Inner::as_type(scope);
240 inner.with_storage_type(StorageType::Atomic(inner.elem_type()))
241 }
242
243 fn as_type_native_unchecked() -> Type {
244 let inner = Inner::as_type_native_unchecked();
245 inner.with_storage_type(StorageType::Atomic(inner.elem_type()))
246 }
247
248 fn size() -> Option<usize> {
249 Inner::size()
250 }
251
252 fn from_expand_elem(elem: ManagedVariable) -> Self::ExpandType {
253 NativeExpand::new(elem)
254 }
255
256 fn from_const_value(_value: ConstantValue) -> Self {
257 panic!("Can't have constant atomic");
258 }
259}
260
261impl<Inner: CubePrimitive> NativeAssign for Atomic<Inner> {}