cubecl_core/frontend/element/
atomic.rs

1use cubecl_ir::{AtomicOp, ExpandElement};
2
3use super::{
4    ExpandElementBaseInit, ExpandElementTyped, Int, LaunchArgExpand, Numeric, init_expand_element,
5};
6use crate::{
7    frontend::{CubePrimitive, CubeType},
8    ir::{BinaryOperator, CompareAndSwapOperator, Elem, Instruction, Item, Scope, UnaryOperator},
9    prelude::KernelBuilder,
10    unexpanded,
11};
12
13/// An atomic numerical type wrapping a normal numeric primitive. Enables the use of atomic
14/// operations, while disabling normal operations. In WGSL, this is a separate type - on CUDA/SPIR-V
15/// it can theoretically be bitcast to a normal number, but this isn't recommended.
16#[derive(Clone, Copy, Hash, PartialEq, Eq)]
17pub struct Atomic<Inner: CubePrimitive> {
18    pub val: Inner,
19}
20
21impl<Inner: Numeric> Atomic<Inner> {
22    /// Load the value of the atomic.
23    #[allow(unused_variables)]
24    pub fn load(pointer: &Self) -> Inner {
25        unexpanded!()
26    }
27
28    /// Store the value of the atomic.
29    #[allow(unused_variables)]
30    pub fn store(pointer: &Self, value: Inner) {
31        unexpanded!()
32    }
33
34    /// Atomically stores the value into the atomic and returns the old value.
35    #[allow(unused_variables)]
36    pub fn swap(pointer: &Self, value: Inner) -> Inner {
37        unexpanded!()
38    }
39
40    /// Atomically add a number to the atomic variable. Returns the old value.
41    #[allow(unused_variables)]
42    pub fn add(pointer: &Self, value: Inner) -> Inner {
43        unexpanded!()
44    }
45
46    /// Atomically sets the value of the atomic variable to `max(current_value, value)`. Returns
47    /// the old value.
48    #[allow(unused_variables)]
49    pub fn max(pointer: &Self, value: Inner) -> Inner {
50        unexpanded!()
51    }
52
53    /// Atomically sets the value of the atomic variable to `min(current_value, value)`. Returns the
54    /// old value.
55    #[allow(unused_variables)]
56    pub fn min(pointer: &Self, value: Inner) -> Inner {
57        unexpanded!()
58    }
59
60    /// Atomically subtracts a number from the atomic variable. Returns the old value.
61    #[allow(unused_variables)]
62    pub fn sub(pointer: &Self, value: Inner) -> Inner {
63        unexpanded!()
64    }
65
66    pub fn __expand_load(
67        scope: &mut Scope,
68        pointer: <Self as CubeType>::ExpandType,
69    ) -> <Inner as CubeType>::ExpandType {
70        let pointer: ExpandElement = pointer.into();
71        let new_var = scope.create_local(Item::new(Inner::as_elem(scope)));
72        scope.register(Instruction::new(
73            AtomicOp::Load(UnaryOperator { input: *pointer }),
74            *new_var,
75        ));
76        new_var.into()
77    }
78
79    pub fn __expand_store(
80        scope: &mut Scope,
81        pointer: <Self as CubeType>::ExpandType,
82        value: <Inner as CubeType>::ExpandType,
83    ) {
84        let ptr: ExpandElement = pointer.into();
85        let value: ExpandElement = value.into();
86        scope.register(Instruction::new(
87            AtomicOp::Store(UnaryOperator { input: *value }),
88            *ptr,
89        ));
90    }
91
92    pub fn __expand_swap(
93        scope: &mut Scope,
94        pointer: <Self as CubeType>::ExpandType,
95        value: <Inner as CubeType>::ExpandType,
96    ) -> <Inner as CubeType>::ExpandType {
97        let ptr: ExpandElement = pointer.into();
98        let value: ExpandElement = value.into();
99        let new_var = scope.create_local(Item::new(Inner::as_elem(scope)));
100        scope.register(Instruction::new(
101            AtomicOp::Swap(BinaryOperator {
102                lhs: *ptr,
103                rhs: *value,
104            }),
105            *new_var,
106        ));
107        new_var.into()
108    }
109
110    pub fn __expand_add(
111        scope: &mut Scope,
112        pointer: <Self as CubeType>::ExpandType,
113        value: <Inner as CubeType>::ExpandType,
114    ) -> <Inner as CubeType>::ExpandType {
115        let ptr: ExpandElement = pointer.into();
116        let value: ExpandElement = value.into();
117        let new_var = scope.create_local(Item::new(Inner::as_elem(scope)));
118        scope.register(Instruction::new(
119            AtomicOp::Add(BinaryOperator {
120                lhs: *ptr,
121                rhs: *value,
122            }),
123            *new_var,
124        ));
125        new_var.into()
126    }
127
128    pub fn __expand_sub(
129        scope: &mut Scope,
130        pointer: <Self as CubeType>::ExpandType,
131        value: <Inner as CubeType>::ExpandType,
132    ) -> <Inner as CubeType>::ExpandType {
133        let ptr: ExpandElement = pointer.into();
134        let value: ExpandElement = value.into();
135        let new_var = scope.create_local(Item::new(Inner::as_elem(scope)));
136        scope.register(Instruction::new(
137            AtomicOp::Sub(BinaryOperator {
138                lhs: *ptr,
139                rhs: *value,
140            }),
141            *new_var,
142        ));
143        new_var.into()
144    }
145
146    pub fn __expand_max(
147        scope: &mut Scope,
148        pointer: <Self as CubeType>::ExpandType,
149        value: <Inner as CubeType>::ExpandType,
150    ) -> <Inner as CubeType>::ExpandType {
151        let ptr: ExpandElement = pointer.into();
152        let value: ExpandElement = value.into();
153        let new_var = scope.create_local(Item::new(Inner::as_elem(scope)));
154        scope.register(Instruction::new(
155            AtomicOp::Max(BinaryOperator {
156                lhs: *ptr,
157                rhs: *value,
158            }),
159            *new_var,
160        ));
161        new_var.into()
162    }
163
164    pub fn __expand_min(
165        scope: &mut Scope,
166        pointer: <Self as CubeType>::ExpandType,
167        value: <Inner as CubeType>::ExpandType,
168    ) -> <Inner as CubeType>::ExpandType {
169        let ptr: ExpandElement = pointer.into();
170        let value: ExpandElement = value.into();
171        let new_var = scope.create_local(Item::new(Inner::as_elem(scope)));
172        scope.register(Instruction::new(
173            AtomicOp::Min(BinaryOperator {
174                lhs: *ptr,
175                rhs: *value,
176            }),
177            *new_var,
178        ));
179        new_var.into()
180    }
181}
182
183impl<Inner: Int> Atomic<Inner> {
184    /// Compare the value at `pointer` to `cmp` and set it to `value` only if they are the same.
185    /// Returns the old value of the pointer before the store.
186    ///
187    /// ### Tip
188    /// Compare the returned value to `cmp` to determine whether the store was successful.
189    #[allow(unused_variables)]
190    pub fn compare_and_swap(pointer: &Self, cmp: Inner, value: Inner) -> Inner {
191        unexpanded!()
192    }
193
194    /// Executes an atomic bitwise and operation on the atomic variable. Returns the old value.
195    #[allow(unused_variables)]
196    pub fn and(pointer: &Self, value: Inner) -> Inner {
197        unexpanded!()
198    }
199
200    /// Executes an atomic bitwise or operation on the atomic variable. Returns the old value.
201    #[allow(unused_variables)]
202    pub fn or(pointer: &Self, value: Inner) -> Inner {
203        unexpanded!()
204    }
205
206    /// Executes an atomic bitwise xor operation on the atomic variable. Returns the old value.
207    #[allow(unused_variables)]
208    pub fn xor(pointer: &Self, value: Inner) -> Inner {
209        unexpanded!()
210    }
211
212    pub fn __expand_compare_and_swap(
213        scope: &mut Scope,
214        pointer: <Self as CubeType>::ExpandType,
215        cmp: <Inner as CubeType>::ExpandType,
216        value: <Inner as CubeType>::ExpandType,
217    ) -> <Inner as CubeType>::ExpandType {
218        let pointer: ExpandElement = pointer.into();
219        let cmp: ExpandElement = cmp.into();
220        let value: ExpandElement = value.into();
221        let new_var = scope.create_local(Item::new(Inner::as_elem(scope)));
222        scope.register(Instruction::new(
223            AtomicOp::CompareAndSwap(CompareAndSwapOperator {
224                input: *pointer,
225                cmp: *cmp,
226                val: *value,
227            }),
228            *new_var,
229        ));
230        new_var.into()
231    }
232
233    pub fn __expand_and(
234        scope: &mut Scope,
235        pointer: <Self as CubeType>::ExpandType,
236        value: <Inner as CubeType>::ExpandType,
237    ) -> <Inner as CubeType>::ExpandType {
238        let ptr: ExpandElement = pointer.into();
239        let value: ExpandElement = value.into();
240        let new_var = scope.create_local(Item::new(Inner::as_elem(scope)));
241        scope.register(Instruction::new(
242            AtomicOp::And(BinaryOperator {
243                lhs: *ptr,
244                rhs: *value,
245            }),
246            *new_var,
247        ));
248        new_var.into()
249    }
250
251    pub fn __expand_or(
252        scope: &mut Scope,
253        pointer: <Self as CubeType>::ExpandType,
254        value: <Inner as CubeType>::ExpandType,
255    ) -> <Inner as CubeType>::ExpandType {
256        let ptr: ExpandElement = pointer.into();
257        let value: ExpandElement = value.into();
258        let new_var = scope.create_local(Item::new(Inner::as_elem(scope)));
259        scope.register(Instruction::new(
260            AtomicOp::Or(BinaryOperator {
261                lhs: *ptr,
262                rhs: *value,
263            }),
264            *new_var,
265        ));
266        new_var.into()
267    }
268
269    pub fn __expand_xor(
270        scope: &mut Scope,
271        pointer: <Self as CubeType>::ExpandType,
272        value: <Inner as CubeType>::ExpandType,
273    ) -> <Inner as CubeType>::ExpandType {
274        let ptr: ExpandElement = pointer.into();
275        let value: ExpandElement = value.into();
276        let new_var = scope.create_local(Item::new(Inner::as_elem(scope)));
277        scope.register(Instruction::new(
278            AtomicOp::Xor(BinaryOperator {
279                lhs: *ptr,
280                rhs: *value,
281            }),
282            *new_var,
283        ));
284        new_var.into()
285    }
286}
287
288impl<Inner: CubePrimitive> CubeType for Atomic<Inner> {
289    type ExpandType = ExpandElementTyped<Self>;
290}
291
292impl<Inner: CubePrimitive> CubePrimitive for Atomic<Inner> {
293    fn as_elem_native() -> Option<Elem> {
294        match Inner::as_elem_native() {
295            Some(Elem::Float(kind)) => Some(Elem::AtomicFloat(kind)),
296            Some(Elem::Int(kind)) => Some(Elem::AtomicInt(kind)),
297            Some(Elem::UInt(kind)) => Some(Elem::AtomicUInt(kind)),
298            None => None,
299            _ => unreachable!("Atomics can only be float/int/uint"),
300        }
301    }
302
303    fn as_elem(scope: &Scope) -> Elem {
304        match Inner::as_elem(scope) {
305            Elem::Float(kind) => Elem::AtomicFloat(kind),
306            Elem::Int(kind) => Elem::AtomicInt(kind),
307            Elem::UInt(kind) => Elem::AtomicUInt(kind),
308            _ => unreachable!("Atomics can only be float/int/uint"),
309        }
310    }
311
312    fn as_elem_native_unchecked() -> Elem {
313        match Inner::as_elem_native_unchecked() {
314            Elem::Float(kind) => Elem::AtomicFloat(kind),
315            Elem::Int(kind) => Elem::AtomicInt(kind),
316            Elem::UInt(kind) => Elem::AtomicUInt(kind),
317            _ => unreachable!("Atomics can only be float/int/uint"),
318        }
319    }
320
321    fn size() -> Option<usize> {
322        Inner::size()
323    }
324
325    fn from_expand_elem(elem: ExpandElement) -> Self::ExpandType {
326        ExpandElementTyped::new(elem)
327    }
328}
329
330impl<Inner: CubePrimitive> ExpandElementBaseInit for Atomic<Inner> {
331    fn init_elem(scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
332        init_expand_element(scope, elem)
333    }
334}
335
336impl<Inner: CubePrimitive> LaunchArgExpand for Atomic<Inner> {
337    type CompilationArg = ();
338
339    fn expand(_: &Self::CompilationArg, builder: &mut KernelBuilder) -> ExpandElementTyped<Self> {
340        builder.scalar(Self::as_elem_native_unchecked()).into()
341    }
342}