cubecl_core/frontend/element/
atomic.rs1use cubecl_ir::{AtomicOp, ConstantValue, ExpandElement, StorageType};
2use cubecl_macros::intrinsic;
3
4use super::{ExpandElementIntoMut, ExpandElementTyped, Int, Numeric, into_mut_expand_element};
5use crate::{
6 self as cubecl,
7 frontend::{CubePrimitive, CubeType},
8 ir::{BinaryOperator, CompareAndSwapOperator, Instruction, Scope, Type, UnaryOperator},
9 prelude::*,
10};
11
12#[derive(Clone, Copy, Hash, PartialEq, Eq)]
16pub struct Atomic<Inner: CubePrimitive> {
17 pub val: Inner,
18}
19
20type AtomicExpand<Inner> = ExpandElementTyped<Atomic<Inner>>;
21
22#[cube]
23impl<Inner: Numeric> Atomic<Inner> {
24 #[allow(unused_variables)]
26 pub fn load(&self) -> Inner {
27 intrinsic!(|scope| {
28 let pointer: ExpandElement = self.into();
29 let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
30 scope.register(Instruction::new(
31 AtomicOp::Load(UnaryOperator { input: *pointer }),
32 *new_var,
33 ));
34 new_var.into()
35 })
36 }
37
38 #[allow(unused_variables)]
40 pub fn store(&self, value: Inner) {
41 intrinsic!(|scope| {
42 let ptr: ExpandElement = self.into();
43 let value: ExpandElement = value.into();
44 scope.register(Instruction::new(
45 AtomicOp::Store(UnaryOperator { input: *value }),
46 *ptr,
47 ));
48 })
49 }
50
51 #[allow(unused_variables)]
53 pub fn swap(&self, value: Inner) -> Inner {
54 intrinsic!(|scope| {
55 let ptr: ExpandElement = self.into();
56 let value: ExpandElement = value.into();
57 let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
58 scope.register(Instruction::new(
59 AtomicOp::Swap(BinaryOperator {
60 lhs: *ptr,
61 rhs: *value,
62 }),
63 *new_var,
64 ));
65 new_var.into()
66 })
67 }
68
69 #[allow(unused_variables)]
71 pub fn fetch_add(&self, value: Inner) -> Inner {
72 intrinsic!(|scope| {
73 let ptr: ExpandElement = self.into();
74 let value: ExpandElement = value.into();
75 let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
76 scope.register(Instruction::new(
77 AtomicOp::Add(BinaryOperator {
78 lhs: *ptr,
79 rhs: *value,
80 }),
81 *new_var,
82 ));
83 new_var.into()
84 })
85 }
86
87 #[allow(unused_variables)]
89 pub fn fetch_sub(&self, value: Inner) -> Inner {
90 intrinsic!(|scope| {
91 let ptr: ExpandElement = self.into();
92 let value: ExpandElement = value.into();
93 let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
94 scope.register(Instruction::new(
95 AtomicOp::Sub(BinaryOperator {
96 lhs: *ptr,
97 rhs: *value,
98 }),
99 *new_var,
100 ));
101 new_var.into()
102 })
103 }
104
105 #[allow(unused_variables)]
108 pub fn fetch_max(&self, value: Inner) -> Inner {
109 intrinsic!(|scope| {
110 let ptr: ExpandElement = self.into();
111 let value: ExpandElement = value.into();
112 let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
113 scope.register(Instruction::new(
114 AtomicOp::Max(BinaryOperator {
115 lhs: *ptr,
116 rhs: *value,
117 }),
118 *new_var,
119 ));
120 new_var.into()
121 })
122 }
123
124 #[allow(unused_variables)]
127 pub fn fetch_min(&self, value: Inner) -> Inner {
128 intrinsic!(|scope| {
129 let ptr: ExpandElement = self.into();
130 let value: ExpandElement = value.into();
131 let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
132 scope.register(Instruction::new(
133 AtomicOp::Min(BinaryOperator {
134 lhs: *ptr,
135 rhs: *value,
136 }),
137 *new_var,
138 ));
139 new_var.into()
140 })
141 }
142}
143
144#[cube]
145impl<Inner: Int> Atomic<Inner> {
146 #[allow(unused_variables)]
152 pub fn compare_exchange_weak(&self, cmp: Inner, value: Inner) -> Inner {
153 intrinsic!(|scope| {
154 let pointer: ExpandElement = self.into();
155 let cmp: ExpandElement = cmp.into();
156 let value: ExpandElement = value.into();
157 let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
158 scope.register(Instruction::new(
159 AtomicOp::CompareAndSwap(CompareAndSwapOperator {
160 input: *pointer,
161 cmp: *cmp,
162 val: *value,
163 }),
164 *new_var,
165 ));
166 new_var.into()
167 })
168 }
169
170 #[allow(unused_variables)]
172 pub fn fetch_and(&self, value: Inner) -> Inner {
173 intrinsic!(|scope| {
174 let ptr: ExpandElement = self.into();
175 let value: ExpandElement = value.into();
176 let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
177 scope.register(Instruction::new(
178 AtomicOp::And(BinaryOperator {
179 lhs: *ptr,
180 rhs: *value,
181 }),
182 *new_var,
183 ));
184 new_var.into()
185 })
186 }
187
188 #[allow(unused_variables)]
190 pub fn fetch_or(&self, value: Inner) -> Inner {
191 intrinsic!(|scope| {
192 let ptr: ExpandElement = self.into();
193 let value: ExpandElement = value.into();
194 let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
195 scope.register(Instruction::new(
196 AtomicOp::Or(BinaryOperator {
197 lhs: *ptr,
198 rhs: *value,
199 }),
200 *new_var,
201 ));
202 new_var.into()
203 })
204 }
205
206 #[allow(unused_variables)]
208 pub fn fetch_xor(&self, value: Inner) -> Inner {
209 intrinsic!(|scope| {
210 let ptr: ExpandElement = self.into();
211 let value: ExpandElement = value.into();
212 let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
213 scope.register(Instruction::new(
214 AtomicOp::Xor(BinaryOperator {
215 lhs: *ptr,
216 rhs: *value,
217 }),
218 *new_var,
219 ));
220 new_var.into()
221 })
222 }
223}
224
225impl<Inner: CubePrimitive> CubeType for Atomic<Inner> {
226 type ExpandType = ExpandElementTyped<Self>;
227}
228
229impl<Inner: CubePrimitive> CubePrimitive for Atomic<Inner> {
230 fn as_type_native() -> Option<StorageType> {
231 Inner::as_type_native().map(|it| StorageType::Atomic(it.elem_type()))
232 }
233
234 fn as_type(scope: &Scope) -> StorageType {
235 StorageType::Atomic(Inner::as_type(scope).elem_type())
236 }
237
238 fn as_type_native_unchecked() -> StorageType {
239 StorageType::Atomic(Inner::as_type_native_unchecked().elem_type())
240 }
241
242 fn size() -> Option<usize> {
243 Inner::size()
244 }
245
246 fn from_expand_elem(elem: ExpandElement) -> Self::ExpandType {
247 ExpandElementTyped::new(elem)
248 }
249
250 fn from_const_value(_value: ConstantValue) -> Self {
251 panic!("Can't have constant atomic");
252 }
253}
254
255impl<Inner: CubePrimitive> ExpandElementIntoMut for Atomic<Inner> {
256 fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
257 into_mut_expand_element(scope, elem)
258 }
259}