cubecl_core/frontend/element/
atomic.rs1use cubecl_ir::{AtomicOp, ConstantScalarValue, ExpandElement, StorageType};
2
3use super::{ExpandElementIntoMut, ExpandElementTyped, Int, Numeric, into_mut_expand_element};
4use crate::{
5 frontend::{CubePrimitive, CubeType},
6 ir::{BinaryOperator, CompareAndSwapOperator, Instruction, Scope, Type, UnaryOperator},
7 unexpanded,
8};
9
10#[derive(Clone, Copy, Hash, PartialEq, Eq)]
14pub struct Atomic<Inner: CubePrimitive> {
15 pub val: Inner,
16}
17
18impl<Inner: Numeric> Atomic<Inner> {
19 #[allow(unused_variables)]
21 pub fn load(pointer: &Self) -> Inner {
22 unexpanded!()
23 }
24
25 #[allow(unused_variables)]
27 pub fn store(pointer: &Self, value: Inner) {
28 unexpanded!()
29 }
30
31 #[allow(unused_variables)]
33 pub fn swap(pointer: &Self, value: Inner) -> Inner {
34 unexpanded!()
35 }
36
37 #[allow(unused_variables)]
39 pub fn add(pointer: &Self, value: Inner) -> Inner {
40 unexpanded!()
41 }
42
43 #[allow(unused_variables)]
46 pub fn max(pointer: &Self, value: Inner) -> Inner {
47 unexpanded!()
48 }
49
50 #[allow(unused_variables)]
53 pub fn min(pointer: &Self, value: Inner) -> Inner {
54 unexpanded!()
55 }
56
57 #[allow(unused_variables)]
59 pub fn sub(pointer: &Self, value: Inner) -> Inner {
60 unexpanded!()
61 }
62
63 pub fn __expand_load(
64 scope: &mut Scope,
65 pointer: <Self as CubeType>::ExpandType,
66 ) -> <Inner as CubeType>::ExpandType {
67 let pointer: ExpandElement = pointer.into();
68 let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
69 scope.register(Instruction::new(
70 AtomicOp::Load(UnaryOperator { input: *pointer }),
71 *new_var,
72 ));
73 new_var.into()
74 }
75
76 pub fn __expand_store(
77 scope: &mut Scope,
78 pointer: <Self as CubeType>::ExpandType,
79 value: <Inner as CubeType>::ExpandType,
80 ) {
81 let ptr: ExpandElement = pointer.into();
82 let value: ExpandElement = value.into();
83 scope.register(Instruction::new(
84 AtomicOp::Store(UnaryOperator { input: *value }),
85 *ptr,
86 ));
87 }
88
89 pub fn __expand_swap(
90 scope: &mut Scope,
91 pointer: <Self as CubeType>::ExpandType,
92 value: <Inner as CubeType>::ExpandType,
93 ) -> <Inner as CubeType>::ExpandType {
94 let ptr: ExpandElement = pointer.into();
95 let value: ExpandElement = value.into();
96 let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
97 scope.register(Instruction::new(
98 AtomicOp::Swap(BinaryOperator {
99 lhs: *ptr,
100 rhs: *value,
101 }),
102 *new_var,
103 ));
104 new_var.into()
105 }
106
107 pub fn __expand_add(
108 scope: &mut Scope,
109 pointer: <Self as CubeType>::ExpandType,
110 value: <Inner as CubeType>::ExpandType,
111 ) -> <Inner as CubeType>::ExpandType {
112 let ptr: ExpandElement = pointer.into();
113 let value: ExpandElement = value.into();
114 let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
115 scope.register(Instruction::new(
116 AtomicOp::Add(BinaryOperator {
117 lhs: *ptr,
118 rhs: *value,
119 }),
120 *new_var,
121 ));
122 new_var.into()
123 }
124
125 pub fn __expand_sub(
126 scope: &mut Scope,
127 pointer: <Self as CubeType>::ExpandType,
128 value: <Inner as CubeType>::ExpandType,
129 ) -> <Inner as CubeType>::ExpandType {
130 let ptr: ExpandElement = pointer.into();
131 let value: ExpandElement = value.into();
132 let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
133 scope.register(Instruction::new(
134 AtomicOp::Sub(BinaryOperator {
135 lhs: *ptr,
136 rhs: *value,
137 }),
138 *new_var,
139 ));
140 new_var.into()
141 }
142
143 pub fn __expand_max(
144 scope: &mut Scope,
145 pointer: <Self as CubeType>::ExpandType,
146 value: <Inner as CubeType>::ExpandType,
147 ) -> <Inner as CubeType>::ExpandType {
148 let ptr: ExpandElement = pointer.into();
149 let value: ExpandElement = value.into();
150 let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
151 scope.register(Instruction::new(
152 AtomicOp::Max(BinaryOperator {
153 lhs: *ptr,
154 rhs: *value,
155 }),
156 *new_var,
157 ));
158 new_var.into()
159 }
160
161 pub fn __expand_min(
162 scope: &mut Scope,
163 pointer: <Self as CubeType>::ExpandType,
164 value: <Inner as CubeType>::ExpandType,
165 ) -> <Inner as CubeType>::ExpandType {
166 let ptr: ExpandElement = pointer.into();
167 let value: ExpandElement = value.into();
168 let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
169 scope.register(Instruction::new(
170 AtomicOp::Min(BinaryOperator {
171 lhs: *ptr,
172 rhs: *value,
173 }),
174 *new_var,
175 ));
176 new_var.into()
177 }
178}
179
180impl<Inner: Int> Atomic<Inner> {
181 #[allow(unused_variables)]
187 pub fn compare_and_swap(pointer: &Self, cmp: Inner, value: Inner) -> Inner {
188 unexpanded!()
189 }
190
191 #[allow(unused_variables)]
193 pub fn and(pointer: &Self, value: Inner) -> Inner {
194 unexpanded!()
195 }
196
197 #[allow(unused_variables)]
199 pub fn or(pointer: &Self, value: Inner) -> Inner {
200 unexpanded!()
201 }
202
203 #[allow(unused_variables)]
205 pub fn xor(pointer: &Self, value: Inner) -> Inner {
206 unexpanded!()
207 }
208
209 pub fn __expand_compare_and_swap(
210 scope: &mut Scope,
211 pointer: <Self as CubeType>::ExpandType,
212 cmp: <Inner as CubeType>::ExpandType,
213 value: <Inner as CubeType>::ExpandType,
214 ) -> <Inner as CubeType>::ExpandType {
215 let pointer: ExpandElement = pointer.into();
216 let cmp: ExpandElement = cmp.into();
217 let value: ExpandElement = value.into();
218 let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
219 scope.register(Instruction::new(
220 AtomicOp::CompareAndSwap(CompareAndSwapOperator {
221 input: *pointer,
222 cmp: *cmp,
223 val: *value,
224 }),
225 *new_var,
226 ));
227 new_var.into()
228 }
229
230 pub fn __expand_and(
231 scope: &mut Scope,
232 pointer: <Self as CubeType>::ExpandType,
233 value: <Inner as CubeType>::ExpandType,
234 ) -> <Inner as CubeType>::ExpandType {
235 let ptr: ExpandElement = pointer.into();
236 let value: ExpandElement = value.into();
237 let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
238 scope.register(Instruction::new(
239 AtomicOp::And(BinaryOperator {
240 lhs: *ptr,
241 rhs: *value,
242 }),
243 *new_var,
244 ));
245 new_var.into()
246 }
247
248 pub fn __expand_or(
249 scope: &mut Scope,
250 pointer: <Self as CubeType>::ExpandType,
251 value: <Inner as CubeType>::ExpandType,
252 ) -> <Inner as CubeType>::ExpandType {
253 let ptr: ExpandElement = pointer.into();
254 let value: ExpandElement = value.into();
255 let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
256 scope.register(Instruction::new(
257 AtomicOp::Or(BinaryOperator {
258 lhs: *ptr,
259 rhs: *value,
260 }),
261 *new_var,
262 ));
263 new_var.into()
264 }
265
266 pub fn __expand_xor(
267 scope: &mut Scope,
268 pointer: <Self as CubeType>::ExpandType,
269 value: <Inner as CubeType>::ExpandType,
270 ) -> <Inner as CubeType>::ExpandType {
271 let ptr: ExpandElement = pointer.into();
272 let value: ExpandElement = value.into();
273 let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
274 scope.register(Instruction::new(
275 AtomicOp::Xor(BinaryOperator {
276 lhs: *ptr,
277 rhs: *value,
278 }),
279 *new_var,
280 ));
281 new_var.into()
282 }
283}
284
285impl<Inner: CubePrimitive> CubeType for Atomic<Inner> {
286 type ExpandType = ExpandElementTyped<Self>;
287}
288
289impl<Inner: CubePrimitive> CubePrimitive for Atomic<Inner> {
290 fn as_type_native() -> Option<StorageType> {
291 Inner::as_type_native().map(|it| StorageType::Atomic(it.elem_type()))
292 }
293
294 fn as_type(scope: &Scope) -> StorageType {
295 StorageType::Atomic(Inner::as_type(scope).elem_type())
296 }
297
298 fn as_type_native_unchecked() -> StorageType {
299 StorageType::Atomic(Inner::as_type_native_unchecked().elem_type())
300 }
301
302 fn size() -> Option<usize> {
303 Inner::size()
304 }
305
306 fn from_expand_elem(elem: ExpandElement) -> Self::ExpandType {
307 ExpandElementTyped::new(elem)
308 }
309
310 fn from_const_value(_value: ConstantScalarValue) -> Self {
311 panic!("Can't have constant atomic");
312 }
313}
314
315impl<Inner: CubePrimitive> ExpandElementIntoMut for Atomic<Inner> {
316 fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
317 into_mut_expand_element(scope, elem)
318 }
319}