cubecl_core/frontend/element/
atomic.rs1use cubecl_ir::{AtomicOp, ConstantScalarValue, ExpandElement, StorageType};
2
3use super::{ExpandElementIntoMut, ExpandElementTyped, Int, Numeric, into_mut_expand_element};
4use crate::{
5    frontend::{CubePrimitive, CubeType},
6    ir::{BinaryOperator, CompareAndSwapOperator, Instruction, Scope, Type, UnaryOperator},
7    unexpanded,
8};
9
10#[derive(Clone, Copy, Hash, PartialEq, Eq)]
14pub struct Atomic<Inner: CubePrimitive> {
15    pub val: Inner,
16}
17
18impl<Inner: Numeric> Atomic<Inner> {
19    #[allow(unused_variables)]
21    pub fn load(pointer: &Self) -> Inner {
22        unexpanded!()
23    }
24
25    #[allow(unused_variables)]
27    pub fn store(pointer: &Self, value: Inner) {
28        unexpanded!()
29    }
30
31    #[allow(unused_variables)]
33    pub fn swap(pointer: &Self, value: Inner) -> Inner {
34        unexpanded!()
35    }
36
37    #[allow(unused_variables)]
39    pub fn add(pointer: &Self, value: Inner) -> Inner {
40        unexpanded!()
41    }
42
43    #[allow(unused_variables)]
46    pub fn max(pointer: &Self, value: Inner) -> Inner {
47        unexpanded!()
48    }
49
50    #[allow(unused_variables)]
53    pub fn min(pointer: &Self, value: Inner) -> Inner {
54        unexpanded!()
55    }
56
57    #[allow(unused_variables)]
59    pub fn sub(pointer: &Self, value: Inner) -> Inner {
60        unexpanded!()
61    }
62
63    pub fn __expand_load(
64        scope: &mut Scope,
65        pointer: <Self as CubeType>::ExpandType,
66    ) -> <Inner as CubeType>::ExpandType {
67        let pointer: ExpandElement = pointer.into();
68        let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
69        scope.register(Instruction::new(
70            AtomicOp::Load(UnaryOperator { input: *pointer }),
71            *new_var,
72        ));
73        new_var.into()
74    }
75
76    pub fn __expand_store(
77        scope: &mut Scope,
78        pointer: <Self as CubeType>::ExpandType,
79        value: <Inner as CubeType>::ExpandType,
80    ) {
81        let ptr: ExpandElement = pointer.into();
82        let value: ExpandElement = value.into();
83        scope.register(Instruction::new(
84            AtomicOp::Store(UnaryOperator { input: *value }),
85            *ptr,
86        ));
87    }
88
89    pub fn __expand_swap(
90        scope: &mut Scope,
91        pointer: <Self as CubeType>::ExpandType,
92        value: <Inner as CubeType>::ExpandType,
93    ) -> <Inner as CubeType>::ExpandType {
94        let ptr: ExpandElement = pointer.into();
95        let value: ExpandElement = value.into();
96        let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
97        scope.register(Instruction::new(
98            AtomicOp::Swap(BinaryOperator {
99                lhs: *ptr,
100                rhs: *value,
101            }),
102            *new_var,
103        ));
104        new_var.into()
105    }
106
107    pub fn __expand_add(
108        scope: &mut Scope,
109        pointer: <Self as CubeType>::ExpandType,
110        value: <Inner as CubeType>::ExpandType,
111    ) -> <Inner as CubeType>::ExpandType {
112        let ptr: ExpandElement = pointer.into();
113        let value: ExpandElement = value.into();
114        let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
115        scope.register(Instruction::new(
116            AtomicOp::Add(BinaryOperator {
117                lhs: *ptr,
118                rhs: *value,
119            }),
120            *new_var,
121        ));
122        new_var.into()
123    }
124
125    pub fn __expand_sub(
126        scope: &mut Scope,
127        pointer: <Self as CubeType>::ExpandType,
128        value: <Inner as CubeType>::ExpandType,
129    ) -> <Inner as CubeType>::ExpandType {
130        let ptr: ExpandElement = pointer.into();
131        let value: ExpandElement = value.into();
132        let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
133        scope.register(Instruction::new(
134            AtomicOp::Sub(BinaryOperator {
135                lhs: *ptr,
136                rhs: *value,
137            }),
138            *new_var,
139        ));
140        new_var.into()
141    }
142
143    pub fn __expand_max(
144        scope: &mut Scope,
145        pointer: <Self as CubeType>::ExpandType,
146        value: <Inner as CubeType>::ExpandType,
147    ) -> <Inner as CubeType>::ExpandType {
148        let ptr: ExpandElement = pointer.into();
149        let value: ExpandElement = value.into();
150        let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
151        scope.register(Instruction::new(
152            AtomicOp::Max(BinaryOperator {
153                lhs: *ptr,
154                rhs: *value,
155            }),
156            *new_var,
157        ));
158        new_var.into()
159    }
160
161    pub fn __expand_min(
162        scope: &mut Scope,
163        pointer: <Self as CubeType>::ExpandType,
164        value: <Inner as CubeType>::ExpandType,
165    ) -> <Inner as CubeType>::ExpandType {
166        let ptr: ExpandElement = pointer.into();
167        let value: ExpandElement = value.into();
168        let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
169        scope.register(Instruction::new(
170            AtomicOp::Min(BinaryOperator {
171                lhs: *ptr,
172                rhs: *value,
173            }),
174            *new_var,
175        ));
176        new_var.into()
177    }
178}
179
180impl<Inner: Int> Atomic<Inner> {
181    #[allow(unused_variables)]
187    pub fn compare_and_swap(pointer: &Self, cmp: Inner, value: Inner) -> Inner {
188        unexpanded!()
189    }
190
191    #[allow(unused_variables)]
193    pub fn and(pointer: &Self, value: Inner) -> Inner {
194        unexpanded!()
195    }
196
197    #[allow(unused_variables)]
199    pub fn or(pointer: &Self, value: Inner) -> Inner {
200        unexpanded!()
201    }
202
203    #[allow(unused_variables)]
205    pub fn xor(pointer: &Self, value: Inner) -> Inner {
206        unexpanded!()
207    }
208
209    pub fn __expand_compare_and_swap(
210        scope: &mut Scope,
211        pointer: <Self as CubeType>::ExpandType,
212        cmp: <Inner as CubeType>::ExpandType,
213        value: <Inner as CubeType>::ExpandType,
214    ) -> <Inner as CubeType>::ExpandType {
215        let pointer: ExpandElement = pointer.into();
216        let cmp: ExpandElement = cmp.into();
217        let value: ExpandElement = value.into();
218        let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
219        scope.register(Instruction::new(
220            AtomicOp::CompareAndSwap(CompareAndSwapOperator {
221                input: *pointer,
222                cmp: *cmp,
223                val: *value,
224            }),
225            *new_var,
226        ));
227        new_var.into()
228    }
229
230    pub fn __expand_and(
231        scope: &mut Scope,
232        pointer: <Self as CubeType>::ExpandType,
233        value: <Inner as CubeType>::ExpandType,
234    ) -> <Inner as CubeType>::ExpandType {
235        let ptr: ExpandElement = pointer.into();
236        let value: ExpandElement = value.into();
237        let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
238        scope.register(Instruction::new(
239            AtomicOp::And(BinaryOperator {
240                lhs: *ptr,
241                rhs: *value,
242            }),
243            *new_var,
244        ));
245        new_var.into()
246    }
247
248    pub fn __expand_or(
249        scope: &mut Scope,
250        pointer: <Self as CubeType>::ExpandType,
251        value: <Inner as CubeType>::ExpandType,
252    ) -> <Inner as CubeType>::ExpandType {
253        let ptr: ExpandElement = pointer.into();
254        let value: ExpandElement = value.into();
255        let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
256        scope.register(Instruction::new(
257            AtomicOp::Or(BinaryOperator {
258                lhs: *ptr,
259                rhs: *value,
260            }),
261            *new_var,
262        ));
263        new_var.into()
264    }
265
266    pub fn __expand_xor(
267        scope: &mut Scope,
268        pointer: <Self as CubeType>::ExpandType,
269        value: <Inner as CubeType>::ExpandType,
270    ) -> <Inner as CubeType>::ExpandType {
271        let ptr: ExpandElement = pointer.into();
272        let value: ExpandElement = value.into();
273        let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
274        scope.register(Instruction::new(
275            AtomicOp::Xor(BinaryOperator {
276                lhs: *ptr,
277                rhs: *value,
278            }),
279            *new_var,
280        ));
281        new_var.into()
282    }
283}
284
285impl<Inner: CubePrimitive> CubeType for Atomic<Inner> {
286    type ExpandType = ExpandElementTyped<Self>;
287}
288
289impl<Inner: CubePrimitive> CubePrimitive for Atomic<Inner> {
290    fn as_type_native() -> Option<StorageType> {
291        Inner::as_type_native().map(|it| StorageType::Atomic(it.elem_type()))
292    }
293
294    fn as_type(scope: &Scope) -> StorageType {
295        StorageType::Atomic(Inner::as_type(scope).elem_type())
296    }
297
298    fn as_type_native_unchecked() -> StorageType {
299        StorageType::Atomic(Inner::as_type_native_unchecked().elem_type())
300    }
301
302    fn size() -> Option<usize> {
303        Inner::size()
304    }
305
306    fn from_expand_elem(elem: ExpandElement) -> Self::ExpandType {
307        ExpandElementTyped::new(elem)
308    }
309
310    fn from_const_value(_value: ConstantScalarValue) -> Self {
311        panic!("Can't have constant atomic");
312    }
313}
314
315impl<Inner: CubePrimitive> ExpandElementIntoMut for Atomic<Inner> {
316    fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
317        into_mut_expand_element(scope, elem)
318    }
319}