cubecl_core/frontend/element/
atomic.rs

1use std::fmt::Display;
2
3use serde::{Deserialize, Serialize};
4
5use super::{
6    init_expand_element, ExpandElementBaseInit, ExpandElementTyped, Int, IntoRuntime,
7    LaunchArgExpand, Numeric,
8};
9use crate::{
10    frontend::{CubeContext, CubePrimitive, CubeType, ExpandElement},
11    ir::{
12        BinaryOperator, CompareAndSwapOperator, Elem, Instruction, Item, Operation, UnaryOperator,
13    },
14    prelude::KernelBuilder,
15    unexpanded,
16};
17
18/// An atomic numerical type wrapping a normal numeric primitive. Enables the use of atomic
19/// operations, while disabling normal operations. In WGSL, this is a separate type - on CUDA/SPIR-V
20/// it can theoretically be bitcast to a normal number, but this isn't recommended.
21#[derive(Clone, Copy, Hash, PartialEq, Eq)]
22pub struct Atomic<Inner: CubePrimitive> {
23    pub val: Inner,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
27pub enum AtomicOp {
28    Load(UnaryOperator),
29    Store(UnaryOperator),
30    Swap(BinaryOperator),
31    Add(BinaryOperator),
32    Sub(BinaryOperator),
33    Max(BinaryOperator),
34    Min(BinaryOperator),
35    And(BinaryOperator),
36    Or(BinaryOperator),
37    Xor(BinaryOperator),
38    CompareAndSwap(CompareAndSwapOperator),
39}
40
41impl Display for AtomicOp {
42    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43        match self {
44            AtomicOp::Load(op) => write!(f, "atomic_load({})", op.input),
45            AtomicOp::Store(op) => write!(f, "atomic_store({})", op.input),
46            AtomicOp::Swap(op) => {
47                write!(f, "atomic_swap({}, {})", op.lhs, op.rhs)
48            }
49            AtomicOp::Add(op) => write!(f, "atomic_add({}, {})", op.lhs, op.rhs),
50            AtomicOp::Sub(op) => write!(f, "atomic_sub({}, {})", op.lhs, op.rhs),
51            AtomicOp::Max(op) => write!(f, "atomic_max({}, {})", op.lhs, op.rhs),
52            AtomicOp::Min(op) => write!(f, "atomic_min({}, {})", op.lhs, op.rhs),
53            AtomicOp::And(op) => write!(f, "atomic_and({}, {})", op.lhs, op.rhs),
54            AtomicOp::Or(op) => write!(f, "atomic_or({}, {})", op.lhs, op.rhs),
55            AtomicOp::Xor(op) => write!(f, "atomic_xor({}, {})", op.lhs, op.rhs),
56            AtomicOp::CompareAndSwap(op) => {
57                write!(f, "compare_and_swap({}, {}, {})", op.input, op.cmp, op.val)
58            }
59        }
60    }
61}
62
63impl<Inner: Numeric> Atomic<Inner> {
64    /// Load the value of the atomic.
65    #[allow(unused_variables)]
66    pub fn load(pointer: &Self) -> Inner {
67        unexpanded!()
68    }
69
70    /// Store the value of the atomic.
71    #[allow(unused_variables)]
72    pub fn store(pointer: &Self, value: Inner) {
73        unexpanded!()
74    }
75
76    /// Atomically stores the value into the atomic and returns the old value.
77    #[allow(unused_variables)]
78    pub fn swap(pointer: &Self, value: Inner) -> Inner {
79        unexpanded!()
80    }
81
82    /// Atomically add a number to the atomic variable. Returns the old value.
83    #[allow(unused_variables)]
84    pub fn add(pointer: &Self, value: Inner) -> Inner {
85        unexpanded!()
86    }
87
88    /// Atomically sets the value of the atomic variable to `max(current_value, value)`. Returns
89    /// the old value.
90    #[allow(unused_variables)]
91    pub fn max(pointer: &Self, value: Inner) -> Inner {
92        unexpanded!()
93    }
94
95    /// Atomically sets the value of the atomic variable to `min(current_value, value)`. Returns the
96    /// old value.
97    #[allow(unused_variables)]
98    pub fn min(pointer: &Self, value: Inner) -> Inner {
99        unexpanded!()
100    }
101
102    /// Atomically subtracts a number from the atomic variable. Returns the old value.
103    #[allow(unused_variables)]
104    pub fn sub(pointer: &Self, value: Inner) -> Inner {
105        unexpanded!()
106    }
107
108    pub fn __expand_load(
109        context: &mut CubeContext,
110        pointer: <Self as CubeType>::ExpandType,
111    ) -> <Inner as CubeType>::ExpandType {
112        let pointer: ExpandElement = pointer.into();
113        let new_var = context.create_local(Item::new(Inner::as_elem(context)));
114        context.register(Instruction::new(
115            AtomicOp::Load(UnaryOperator { input: *pointer }),
116            *new_var,
117        ));
118        new_var.into()
119    }
120
121    pub fn __expand_store(
122        context: &mut CubeContext,
123        pointer: <Self as CubeType>::ExpandType,
124        value: <Inner as CubeType>::ExpandType,
125    ) {
126        let ptr: ExpandElement = pointer.into();
127        let value: ExpandElement = value.into();
128        context.register(Instruction::new(
129            AtomicOp::Store(UnaryOperator { input: *value }),
130            *ptr,
131        ));
132    }
133
134    pub fn __expand_swap(
135        context: &mut CubeContext,
136        pointer: <Self as CubeType>::ExpandType,
137        value: <Inner as CubeType>::ExpandType,
138    ) -> <Inner as CubeType>::ExpandType {
139        let ptr: ExpandElement = pointer.into();
140        let value: ExpandElement = value.into();
141        let new_var = context.create_local(Item::new(Inner::as_elem(context)));
142        context.register(Instruction::new(
143            AtomicOp::Swap(BinaryOperator {
144                lhs: *ptr,
145                rhs: *value,
146            }),
147            *new_var,
148        ));
149        new_var.into()
150    }
151
152    pub fn __expand_add(
153        context: &mut CubeContext,
154        pointer: <Self as CubeType>::ExpandType,
155        value: <Inner as CubeType>::ExpandType,
156    ) -> <Inner as CubeType>::ExpandType {
157        let ptr: ExpandElement = pointer.into();
158        let value: ExpandElement = value.into();
159        let new_var = context.create_local(Item::new(Inner::as_elem(context)));
160        context.register(Instruction::new(
161            AtomicOp::Add(BinaryOperator {
162                lhs: *ptr,
163                rhs: *value,
164            }),
165            *new_var,
166        ));
167        new_var.into()
168    }
169
170    pub fn __expand_sub(
171        context: &mut CubeContext,
172        pointer: <Self as CubeType>::ExpandType,
173        value: <Inner as CubeType>::ExpandType,
174    ) -> <Inner as CubeType>::ExpandType {
175        let ptr: ExpandElement = pointer.into();
176        let value: ExpandElement = value.into();
177        let new_var = context.create_local(Item::new(Inner::as_elem(context)));
178        context.register(Instruction::new(
179            AtomicOp::Sub(BinaryOperator {
180                lhs: *ptr,
181                rhs: *value,
182            }),
183            *new_var,
184        ));
185        new_var.into()
186    }
187
188    pub fn __expand_max(
189        context: &mut CubeContext,
190        pointer: <Self as CubeType>::ExpandType,
191        value: <Inner as CubeType>::ExpandType,
192    ) -> <Inner as CubeType>::ExpandType {
193        let ptr: ExpandElement = pointer.into();
194        let value: ExpandElement = value.into();
195        let new_var = context.create_local(Item::new(Inner::as_elem(context)));
196        context.register(Instruction::new(
197            AtomicOp::Max(BinaryOperator {
198                lhs: *ptr,
199                rhs: *value,
200            }),
201            *new_var,
202        ));
203        new_var.into()
204    }
205
206    pub fn __expand_min(
207        context: &mut CubeContext,
208        pointer: <Self as CubeType>::ExpandType,
209        value: <Inner as CubeType>::ExpandType,
210    ) -> <Inner as CubeType>::ExpandType {
211        let ptr: ExpandElement = pointer.into();
212        let value: ExpandElement = value.into();
213        let new_var = context.create_local(Item::new(Inner::as_elem(context)));
214        context.register(Instruction::new(
215            AtomicOp::Min(BinaryOperator {
216                lhs: *ptr,
217                rhs: *value,
218            }),
219            *new_var,
220        ));
221        new_var.into()
222    }
223}
224
225impl<Inner: Int> Atomic<Inner> {
226    /// Compare the value at `pointer` to `cmp` and set it to `value` only if they are the same.
227    /// Returns the old value of the pointer before the store.
228    ///
229    /// ### Tip
230    /// Compare the returned value to `cmp` to determine whether the store was successful.
231    #[allow(unused_variables)]
232    pub fn compare_and_swap(pointer: &Self, cmp: Inner, value: Inner) -> Inner {
233        unexpanded!()
234    }
235
236    /// Executes an atomic bitwise and operation on the atomic variable. Returns the old value.
237    #[allow(unused_variables)]
238    pub fn and(pointer: &Self, value: Inner) -> Inner {
239        unexpanded!()
240    }
241
242    /// Executes an atomic bitwise or operation on the atomic variable. Returns the old value.
243    #[allow(unused_variables)]
244    pub fn or(pointer: &Self, value: Inner) -> Inner {
245        unexpanded!()
246    }
247
248    /// Executes an atomic bitwise xor operation on the atomic variable. Returns the old value.
249    #[allow(unused_variables)]
250    pub fn xor(pointer: &Self, value: Inner) -> Inner {
251        unexpanded!()
252    }
253
254    pub fn __expand_compare_and_swap(
255        context: &mut CubeContext,
256        pointer: <Self as CubeType>::ExpandType,
257        cmp: <Inner as CubeType>::ExpandType,
258        value: <Inner as CubeType>::ExpandType,
259    ) -> <Inner as CubeType>::ExpandType {
260        let pointer: ExpandElement = pointer.into();
261        let cmp: ExpandElement = cmp.into();
262        let value: ExpandElement = value.into();
263        let new_var = context.create_local(Item::new(Inner::as_elem(context)));
264        context.register(Instruction::new(
265            AtomicOp::CompareAndSwap(CompareAndSwapOperator {
266                input: *pointer,
267                cmp: *cmp,
268                val: *value,
269            }),
270            *new_var,
271        ));
272        new_var.into()
273    }
274
275    pub fn __expand_and(
276        context: &mut CubeContext,
277        pointer: <Self as CubeType>::ExpandType,
278        value: <Inner as CubeType>::ExpandType,
279    ) -> <Inner as CubeType>::ExpandType {
280        let ptr: ExpandElement = pointer.into();
281        let value: ExpandElement = value.into();
282        let new_var = context.create_local(Item::new(Inner::as_elem(context)));
283        context.register(Instruction::new(
284            AtomicOp::And(BinaryOperator {
285                lhs: *ptr,
286                rhs: *value,
287            }),
288            *new_var,
289        ));
290        new_var.into()
291    }
292
293    pub fn __expand_or(
294        context: &mut CubeContext,
295        pointer: <Self as CubeType>::ExpandType,
296        value: <Inner as CubeType>::ExpandType,
297    ) -> <Inner as CubeType>::ExpandType {
298        let ptr: ExpandElement = pointer.into();
299        let value: ExpandElement = value.into();
300        let new_var = context.create_local(Item::new(Inner::as_elem(context)));
301        context.register(Instruction::new(
302            AtomicOp::Or(BinaryOperator {
303                lhs: *ptr,
304                rhs: *value,
305            }),
306            *new_var,
307        ));
308        new_var.into()
309    }
310
311    pub fn __expand_xor(
312        context: &mut CubeContext,
313        pointer: <Self as CubeType>::ExpandType,
314        value: <Inner as CubeType>::ExpandType,
315    ) -> <Inner as CubeType>::ExpandType {
316        let ptr: ExpandElement = pointer.into();
317        let value: ExpandElement = value.into();
318        let new_var = context.create_local(Item::new(Inner::as_elem(context)));
319        context.register(Instruction::new(
320            AtomicOp::Xor(BinaryOperator {
321                lhs: *ptr,
322                rhs: *value,
323            }),
324            *new_var,
325        ));
326        new_var.into()
327    }
328}
329
330impl<Inner: CubePrimitive> CubeType for Atomic<Inner> {
331    type ExpandType = ExpandElementTyped<Self>;
332}
333
334impl<Inner: CubePrimitive> IntoRuntime for Atomic<Inner> {
335    fn __expand_runtime_method(self, _context: &mut CubeContext) -> Self::ExpandType {
336        unimplemented!("Atomics don't exist at compile time")
337    }
338}
339
340impl<Inner: CubePrimitive> CubePrimitive for Atomic<Inner> {
341    fn as_elem_native() -> Option<Elem> {
342        match Inner::as_elem_native() {
343            Some(Elem::Float(kind)) => Some(Elem::AtomicFloat(kind)),
344            Some(Elem::Int(kind)) => Some(Elem::AtomicInt(kind)),
345            Some(Elem::UInt(kind)) => Some(Elem::AtomicUInt(kind)),
346            None => None,
347            _ => unreachable!("Atomics can only be float/int/uint"),
348        }
349    }
350
351    fn as_elem(context: &CubeContext) -> Elem {
352        match Inner::as_elem(context) {
353            Elem::Float(kind) => Elem::AtomicFloat(kind),
354            Elem::Int(kind) => Elem::AtomicInt(kind),
355            Elem::UInt(kind) => Elem::AtomicUInt(kind),
356            _ => unreachable!("Atomics can only be float/int/uint"),
357        }
358    }
359
360    fn as_elem_native_unchecked() -> Elem {
361        match Inner::as_elem_native_unchecked() {
362            Elem::Float(kind) => Elem::AtomicFloat(kind),
363            Elem::Int(kind) => Elem::AtomicInt(kind),
364            Elem::UInt(kind) => Elem::AtomicUInt(kind),
365            _ => unreachable!("Atomics can only be float/int/uint"),
366        }
367    }
368
369    fn size() -> Option<usize> {
370        Inner::size()
371    }
372
373    fn from_expand_elem(elem: ExpandElement) -> Self::ExpandType {
374        ExpandElementTyped::new(elem)
375    }
376}
377
378impl<Inner: CubePrimitive> ExpandElementBaseInit for Atomic<Inner> {
379    fn init_elem(context: &mut CubeContext, elem: ExpandElement) -> ExpandElement {
380        init_expand_element(context, elem)
381    }
382}
383
384impl<Inner: CubePrimitive> LaunchArgExpand for Atomic<Inner> {
385    type CompilationArg = ();
386
387    fn expand(_: &Self::CompilationArg, builder: &mut KernelBuilder) -> ExpandElementTyped<Self> {
388        builder.scalar(Self::as_elem_native_unchecked()).into()
389    }
390}
391
392impl From<AtomicOp> for Operation {
393    fn from(value: AtomicOp) -> Self {
394        Operation::Atomic(value)
395    }
396}