cubecl_core/frontend/element/atomic.rs

1use cubecl_ir::{AtomicOp, ExpandElement};
2
3use super::{
4    ExpandElementIntoMut, ExpandElementTyped, Int, LaunchArgExpand, Numeric,
5    into_mut_expand_element,
6};
7use crate::{
8    frontend::{CubePrimitive, CubeType},
9    ir::{BinaryOperator, CompareAndSwapOperator, Elem, Instruction, Item, Scope, UnaryOperator},
10    prelude::KernelBuilder,
11    unexpanded,
12};
13
14/// An atomic numerical type wrapping a normal numeric primitive. Enables the use of atomic
15/// operations, while disabling normal operations. In WGSL, this is a separate type - on CUDA/SPIR-V
16/// it can theoretically be bitcast to a normal number, but this isn't recommended.
17#[derive(Clone, Copy, Hash, PartialEq, Eq)]
18pub struct Atomic<Inner: CubePrimitive> {
19    pub val: Inner,
20}
21
22impl<Inner: Numeric> Atomic<Inner> {
23    /// Load the value of the atomic.
24    #[allow(unused_variables)]
25    pub fn load(pointer: &Self) -> Inner {
26        unexpanded!()
27    }
28
29    /// Store the value of the atomic.
30    #[allow(unused_variables)]
31    pub fn store(pointer: &Self, value: Inner) {
32        unexpanded!()
33    }
34
35    /// Atomically stores the value into the atomic and returns the old value.
36    #[allow(unused_variables)]
37    pub fn swap(pointer: &Self, value: Inner) -> Inner {
38        unexpanded!()
39    }
40
41    /// Atomically add a number to the atomic variable. Returns the old value.
42    #[allow(unused_variables)]
43    pub fn add(pointer: &Self, value: Inner) -> Inner {
44        unexpanded!()
45    }
46
47    /// Atomically sets the value of the atomic variable to `max(current_value, value)`. Returns
48    /// the old value.
49    #[allow(unused_variables)]
50    pub fn max(pointer: &Self, value: Inner) -> Inner {
51        unexpanded!()
52    }
53
54    /// Atomically sets the value of the atomic variable to `min(current_value, value)`. Returns the
55    /// old value.
56    #[allow(unused_variables)]
57    pub fn min(pointer: &Self, value: Inner) -> Inner {
58        unexpanded!()
59    }
60
61    /// Atomically subtracts a number from the atomic variable. Returns the old value.
62    #[allow(unused_variables)]
63    pub fn sub(pointer: &Self, value: Inner) -> Inner {
64        unexpanded!()
65    }
66
67    pub fn __expand_load(
68        scope: &mut Scope,
69        pointer: <Self as CubeType>::ExpandType,
70    ) -> <Inner as CubeType>::ExpandType {
71        let pointer: ExpandElement = pointer.into();
72        let new_var = scope.create_local(Item::new(Inner::as_elem(scope)));
73        scope.register(Instruction::new(
74            AtomicOp::Load(UnaryOperator { input: *pointer }),
75            *new_var,
76        ));
77        new_var.into()
78    }
79
80    pub fn __expand_store(
81        scope: &mut Scope,
82        pointer: <Self as CubeType>::ExpandType,
83        value: <Inner as CubeType>::ExpandType,
84    ) {
85        let ptr: ExpandElement = pointer.into();
86        let value: ExpandElement = value.into();
87        scope.register(Instruction::new(
88            AtomicOp::Store(UnaryOperator { input: *value }),
89            *ptr,
90        ));
91    }
92
93    pub fn __expand_swap(
94        scope: &mut Scope,
95        pointer: <Self as CubeType>::ExpandType,
96        value: <Inner as CubeType>::ExpandType,
97    ) -> <Inner as CubeType>::ExpandType {
98        let ptr: ExpandElement = pointer.into();
99        let value: ExpandElement = value.into();
100        let new_var = scope.create_local(Item::new(Inner::as_elem(scope)));
101        scope.register(Instruction::new(
102            AtomicOp::Swap(BinaryOperator {
103                lhs: *ptr,
104                rhs: *value,
105            }),
106            *new_var,
107        ));
108        new_var.into()
109    }
110
111    pub fn __expand_add(
112        scope: &mut Scope,
113        pointer: <Self as CubeType>::ExpandType,
114        value: <Inner as CubeType>::ExpandType,
115    ) -> <Inner as CubeType>::ExpandType {
116        let ptr: ExpandElement = pointer.into();
117        let value: ExpandElement = value.into();
118        let new_var = scope.create_local(Item::new(Inner::as_elem(scope)));
119        scope.register(Instruction::new(
120            AtomicOp::Add(BinaryOperator {
121                lhs: *ptr,
122                rhs: *value,
123            }),
124            *new_var,
125        ));
126        new_var.into()
127    }
128
129    pub fn __expand_sub(
130        scope: &mut Scope,
131        pointer: <Self as CubeType>::ExpandType,
132        value: <Inner as CubeType>::ExpandType,
133    ) -> <Inner as CubeType>::ExpandType {
134        let ptr: ExpandElement = pointer.into();
135        let value: ExpandElement = value.into();
136        let new_var = scope.create_local(Item::new(Inner::as_elem(scope)));
137        scope.register(Instruction::new(
138            AtomicOp::Sub(BinaryOperator {
139                lhs: *ptr,
140                rhs: *value,
141            }),
142            *new_var,
143        ));
144        new_var.into()
145    }
146
147    pub fn __expand_max(
148        scope: &mut Scope,
149        pointer: <Self as CubeType>::ExpandType,
150        value: <Inner as CubeType>::ExpandType,
151    ) -> <Inner as CubeType>::ExpandType {
152        let ptr: ExpandElement = pointer.into();
153        let value: ExpandElement = value.into();
154        let new_var = scope.create_local(Item::new(Inner::as_elem(scope)));
155        scope.register(Instruction::new(
156            AtomicOp::Max(BinaryOperator {
157                lhs: *ptr,
158                rhs: *value,
159            }),
160            *new_var,
161        ));
162        new_var.into()
163    }
164
165    pub fn __expand_min(
166        scope: &mut Scope,
167        pointer: <Self as CubeType>::ExpandType,
168        value: <Inner as CubeType>::ExpandType,
169    ) -> <Inner as CubeType>::ExpandType {
170        let ptr: ExpandElement = pointer.into();
171        let value: ExpandElement = value.into();
172        let new_var = scope.create_local(Item::new(Inner::as_elem(scope)));
173        scope.register(Instruction::new(
174            AtomicOp::Min(BinaryOperator {
175                lhs: *ptr,
176                rhs: *value,
177            }),
178            *new_var,
179        ));
180        new_var.into()
181    }
182}
183
184impl<Inner: Int> Atomic<Inner> {
185    /// Compare the value at `pointer` to `cmp` and set it to `value` only if they are the same.
186    /// Returns the old value of the pointer before the store.
187    ///
188    /// ### Tip
189    /// Compare the returned value to `cmp` to determine whether the store was successful.
190    #[allow(unused_variables)]
191    pub fn compare_and_swap(pointer: &Self, cmp: Inner, value: Inner) -> Inner {
192        unexpanded!()
193    }
194
195    /// Executes an atomic bitwise and operation on the atomic variable. Returns the old value.
196    #[allow(unused_variables)]
197    pub fn and(pointer: &Self, value: Inner) -> Inner {
198        unexpanded!()
199    }
200
201    /// Executes an atomic bitwise or operation on the atomic variable. Returns the old value.
202    #[allow(unused_variables)]
203    pub fn or(pointer: &Self, value: Inner) -> Inner {
204        unexpanded!()
205    }
206
207    /// Executes an atomic bitwise xor operation on the atomic variable. Returns the old value.
208    #[allow(unused_variables)]
209    pub fn xor(pointer: &Self, value: Inner) -> Inner {
210        unexpanded!()
211    }
212
213    pub fn __expand_compare_and_swap(
214        scope: &mut Scope,
215        pointer: <Self as CubeType>::ExpandType,
216        cmp: <Inner as CubeType>::ExpandType,
217        value: <Inner as CubeType>::ExpandType,
218    ) -> <Inner as CubeType>::ExpandType {
219        let pointer: ExpandElement = pointer.into();
220        let cmp: ExpandElement = cmp.into();
221        let value: ExpandElement = value.into();
222        let new_var = scope.create_local(Item::new(Inner::as_elem(scope)));
223        scope.register(Instruction::new(
224            AtomicOp::CompareAndSwap(CompareAndSwapOperator {
225                input: *pointer,
226                cmp: *cmp,
227                val: *value,
228            }),
229            *new_var,
230        ));
231        new_var.into()
232    }
233
234    pub fn __expand_and(
235        scope: &mut Scope,
236        pointer: <Self as CubeType>::ExpandType,
237        value: <Inner as CubeType>::ExpandType,
238    ) -> <Inner as CubeType>::ExpandType {
239        let ptr: ExpandElement = pointer.into();
240        let value: ExpandElement = value.into();
241        let new_var = scope.create_local(Item::new(Inner::as_elem(scope)));
242        scope.register(Instruction::new(
243            AtomicOp::And(BinaryOperator {
244                lhs: *ptr,
245                rhs: *value,
246            }),
247            *new_var,
248        ));
249        new_var.into()
250    }
251
252    pub fn __expand_or(
253        scope: &mut Scope,
254        pointer: <Self as CubeType>::ExpandType,
255        value: <Inner as CubeType>::ExpandType,
256    ) -> <Inner as CubeType>::ExpandType {
257        let ptr: ExpandElement = pointer.into();
258        let value: ExpandElement = value.into();
259        let new_var = scope.create_local(Item::new(Inner::as_elem(scope)));
260        scope.register(Instruction::new(
261            AtomicOp::Or(BinaryOperator {
262                lhs: *ptr,
263                rhs: *value,
264            }),
265            *new_var,
266        ));
267        new_var.into()
268    }
269
270    pub fn __expand_xor(
271        scope: &mut Scope,
272        pointer: <Self as CubeType>::ExpandType,
273        value: <Inner as CubeType>::ExpandType,
274    ) -> <Inner as CubeType>::ExpandType {
275        let ptr: ExpandElement = pointer.into();
276        let value: ExpandElement = value.into();
277        let new_var = scope.create_local(Item::new(Inner::as_elem(scope)));
278        scope.register(Instruction::new(
279            AtomicOp::Xor(BinaryOperator {
280                lhs: *ptr,
281                rhs: *value,
282            }),
283            *new_var,
284        ));
285        new_var.into()
286    }
287}
288
289impl<Inner: CubePrimitive> CubeType for Atomic<Inner> {
290    type ExpandType = ExpandElementTyped<Self>;
291}
292
293impl<Inner: CubePrimitive> CubePrimitive for Atomic<Inner> {
294    fn as_elem_native() -> Option<Elem> {
295        match Inner::as_elem_native() {
296            Some(Elem::Float(kind)) => Some(Elem::AtomicFloat(kind)),
297            Some(Elem::Int(kind)) => Some(Elem::AtomicInt(kind)),
298            Some(Elem::UInt(kind)) => Some(Elem::AtomicUInt(kind)),
299            None => None,
300            _ => unreachable!("Atomics can only be float/int/uint"),
301        }
302    }
303
304    fn as_elem(scope: &Scope) -> Elem {
305        match Inner::as_elem(scope) {
306            Elem::Float(kind) => Elem::AtomicFloat(kind),
307            Elem::Int(kind) => Elem::AtomicInt(kind),
308            Elem::UInt(kind) => Elem::AtomicUInt(kind),
309            _ => unreachable!("Atomics can only be float/int/uint"),
310        }
311    }
312
313    fn as_elem_native_unchecked() -> Elem {
314        match Inner::as_elem_native_unchecked() {
315            Elem::Float(kind) => Elem::AtomicFloat(kind),
316            Elem::Int(kind) => Elem::AtomicInt(kind),
317            Elem::UInt(kind) => Elem::AtomicUInt(kind),
318            _ => unreachable!("Atomics can only be float/int/uint"),
319        }
320    }
321
322    fn size() -> Option<usize> {
323        Inner::size()
324    }
325
326    fn from_expand_elem(elem: ExpandElement) -> Self::ExpandType {
327        ExpandElementTyped::new(elem)
328    }
329}
330
331impl<Inner: CubePrimitive> ExpandElementIntoMut for Atomic<Inner> {
332    fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
333        into_mut_expand_element(scope, elem)
334    }
335}
336
337impl<Inner: CubePrimitive> LaunchArgExpand for Atomic<Inner> {
338    type CompilationArg = ();
339
340    fn expand(_: &Self::CompilationArg, builder: &mut KernelBuilder) -> ExpandElementTyped<Self> {
341        builder.scalar(Self::as_elem_native_unchecked()).into()
342    }
343}