cuda_std/
atomic.rs

1//! Atomic Types for modification of numbers in multiple threads in a sound way.
2//!
3//! # Core Interop
4//!
5//! Every type in this module works on the CPU (targets outside of nvptx). However, [`core::sync::atomic`] types
6//! do **NOT** work on the GPU currently. This is because CUDA atomics have some fundamental differences
7//! that make representing them fully with existing core types impossible:
8//!
9//! - CUDA has block-scoped, device-scoped, and system-scoped atomics, core does not make such a distinction (obviously).
//! - CUDA trivially supports relaxed/acquire/release orderings on most architectures. SeqCst and other orderings use
//! specialized instructions on compute capabilities 7.x and above, and have to be emulated with fences/membars on
//! compute capabilities below 7.x. This makes it difficult to hide away such details in the codegen.
13//! - CUDA has hardware atomic floats, core does not.
14//! - CUDA makes the distinction between "fetch, do operation, read" (`atom`) and "do operation" (`red`).
//! - Core thinks CUDA supports 8 and 16 bit atomics. This is a bug in the nvptx target, but it is nevertheless an
//! annoying detail to silently trap on.
17//!
18//! Therefore we chose to go with the approach of implementing all atomics inside cuda_std. In the future, we may support
19//! a subset of core atomics, but for now, you will have to use cuda_std atomics.
20
21#![allow(unused_unsafe)]
22
23pub mod intrinsics;
24pub mod mid;
25
26use core::cell::UnsafeCell;
27use core::sync::atomic::Ordering;
28use paste::paste;
29
/// Derives the failure ordering to pair with a given success ordering.
///
/// The result never contains a "release" component: `Release` and `Relaxed` become `Relaxed`,
/// `Acquire` and `AcqRel` become `Acquire`, and every other ordering (including `SeqCst`) is
/// returned unchanged.
// Only called from code compiled for non-CUDA targets, hence the `dead_code` allowance.
#[allow(dead_code)]
fn fail_order(order: Ordering) -> Ordering {
    if matches!(order, Ordering::Release | Ordering::Relaxed) {
        return Ordering::Relaxed;
    }
    if matches!(order, Ordering::Acquire | Ordering::AcqRel) {
        return Ordering::Acquire;
    }
    // `SeqCst`, plus any variant added to the non-exhaustive enum later, passes through as-is.
    order
}
39
/// Expands to the string literal that finishes the sentence
/// "... synchronizes across " for the given synchronization scope.
///
/// Accepted scopes are `block`, `device` and `system`; anything else fails to compile.
macro_rules! scope_doc {
    (block) => { "a single thread block (also called a CTA, cooperative thread array)." };
    (device) => { "a single device (GPU)." };
    (system) => { "the entire system." };
}
51
/// Expands to the text of a `# Safety` doc section for scoped atomic operations.
///
/// Called with a single identifier (in practice `unsafe`), it yields a string literal that
/// embeds that identifier in the explanation. Called with nothing, it expands to nothing.
macro_rules! safety_doc {
    () => {};
    ($unsafety:ident) => {
        concat!(
            "# Safety\n",
            "This function is ",
            stringify!($unsafety),
            " because it does not synchronize\n",
            "across the entire GPU or System, which leaves it open for data races if used incorrectly"
        )
    };
}
63
/// Generates an atomic float type together with its inherent methods.
///
/// Arguments, in order:
/// - `$float_ty`: the wrapped float primitive (`f32` or `f64`).
/// - `$atomic_ty`: the name of the generated struct.
/// - `$align`: alignment in bytes applied through `#[repr(C, align(..))]`. It must equal the
///   alignment of the same-width unsigned atomic, because non-CUDA targets reinterpret the
///   struct as that atomic.
/// - `$scope`: synchronization scope (`device`, `block` or `system`). It selects the
///   `mid::*_<scope>` functions used on CUDA and the scope sentence in the type docs.
/// - `$width`: bit width of the float (`32` or `64`). It selects the `AtomicU<width>` type used
///   on non-CUDA targets and the `mid::atomic_{load,store}_<width>_<scope>` functions on CUDA.
/// - `$unsafety` (optional): pass `unsafe` to turn every operation into an `unsafe fn` carrying
///   a generated `# Safety` section.
macro_rules! atomic_float {
    ($float_ty:ident, $atomic_ty:ident, $align:tt, $scope:ident, $width:tt $(,$unsafety:ident)?) => {
        #[doc = concat!("A ", stringify!($width), "-bit float type which can be safely shared between threads and synchronizes across ", scope_doc!($scope))]
        ///
        /// This type is guaranteed to have the same memory representation as the underlying float
        #[doc = concat!("type [`", stringify!($float_ty), "`].")]
        ///
        /// The functions on this type map to hardware instructions on CUDA platforms, and are emulated
        /// with a CAS loop on the CPU (non-CUDA targets).
        #[repr(C, align($align))]
        pub struct $atomic_ty {
            v: UnsafeCell<$float_ty>,
        }

        // SAFETY: the value is only ever read or written through atomic operations.
        unsafe impl Sync for $atomic_ty {}

        impl $atomic_ty {
            paste! {
                /// Creates a new atomic float.
                pub const fn new(v: $float_ty) -> $atomic_ty {
                    Self {
                        v: UnsafeCell::new(v),
                    }
                }

                /// Consumes the atomic and returns the contained value.
                pub fn into_inner(self) -> $float_ty {
                    self.v.into_inner()
                }

                /// Views this atomic float as the unsigned integer atomic of the same width.
                #[cfg(not(target_os = "cuda"))]
                fn as_atomic_bits(&self) -> &core::sync::atomic::[<AtomicU $width>] {
                    let bits: *const core::sync::atomic::[<AtomicU $width>] = (self as *const Self).cast();
                    // SAFETY: `Self` is `repr(C, align(..))` around a single `UnsafeCell` holding a
                    // float of the same width, so it has the size and alignment of the integer
                    // atomic. Every bit pattern is valid for both types, both permit interior
                    // mutation, and the pointer comes from a live `&self`.
                    unsafe { &*bits }
                }

                /// Applies `func` to the current value in a compare-and-swap loop and returns the
                /// value that was stored **before** the update.
                #[cfg(not(target_os = "cuda"))]
                fn update_with(&self, order: Ordering, mut func: impl FnMut($float_ty) -> $float_ty) -> $float_ty {
                    let res = self
                        .as_atomic_bits()
                        .fetch_update(order, fail_order(order), |prev| {
                            Some(func($float_ty::from_bits(prev)).to_bits())
                        })
                        // The closure always returns `Some`, so `fetch_update` cannot return `Err`.
                        .unwrap();
                    $float_ty::from_bits(res)
                }

                /// Adds to the current value, returning the previous value **before** the addition.
                ///
                $(#[doc = safety_doc!($unsafety)])?
                pub $($unsafety)? fn fetch_add(&self, val: $float_ty, order: Ordering) -> $float_ty {
                    #[cfg(target_os = "cuda")]
                    // SAFETY: data races are prevented by atomic intrinsics and the pointer we get is valid.
                    unsafe {
                        mid::[<atomic_fetch_add_ $float_ty _ $scope>](self.v.get(), order, val)
                    }
                    #[cfg(not(target_os = "cuda"))]
                    self.update_with(order, |v| v + val)
                }

                /// Subtracts from the current value, returning the previous value **before** the subtraction.
                ///
                /// Note, this is actually implemented as `old + (-new)`, CUDA does not have a specialized sub instruction.
                ///
                $(#[doc = safety_doc!($unsafety)])?
                pub $($unsafety)? fn fetch_sub(&self, val: $float_ty, order: Ordering) -> $float_ty {
                    #[cfg(target_os = "cuda")]
                    // SAFETY: data races are prevented by atomic intrinsics and the pointer we get is valid.
                    unsafe {
                        mid::[<atomic_fetch_sub_ $float_ty _ $scope>](self.v.get(), order, val)
                    }
                    #[cfg(not(target_os = "cuda"))]
                    self.update_with(order, |v| v - val)
                }

                /// Bitwise "and" with the current value. Returns the value **before** the "and".
                ///
                $(#[doc = safety_doc!($unsafety)])?
                pub $($unsafety)? fn fetch_and(&self, val: $float_ty, order: Ordering) -> $float_ty {
                    #[cfg(target_os = "cuda")]
                    // SAFETY: data races are prevented by atomic intrinsics and the pointer we get is valid.
                    unsafe {
                        mid::[<atomic_fetch_and_ $float_ty _ $scope>](self.v.get(), order, val)
                    }
                    #[cfg(not(target_os = "cuda"))]
                    self.update_with(order, |v| $float_ty::from_bits(v.to_bits() & val.to_bits()))
                }

                /// Bitwise "or" with the current value. Returns the value **before** the "or".
                ///
                $(#[doc = safety_doc!($unsafety)])?
                pub $($unsafety)? fn fetch_or(&self, val: $float_ty, order: Ordering) -> $float_ty {
                    #[cfg(target_os = "cuda")]
                    // SAFETY: data races are prevented by atomic intrinsics and the pointer we get is valid.
                    unsafe {
                        mid::[<atomic_fetch_or_ $float_ty _ $scope>](self.v.get(), order, val)
                    }
                    #[cfg(not(target_os = "cuda"))]
                    self.update_with(order, |v| $float_ty::from_bits(v.to_bits() | val.to_bits()))
                }

                /// Bitwise "xor" with the current value. Returns the value **before** the "xor".
                ///
                $(#[doc = safety_doc!($unsafety)])?
                pub $($unsafety)? fn fetch_xor(&self, val: $float_ty, order: Ordering) -> $float_ty {
                    #[cfg(target_os = "cuda")]
                    // SAFETY: data races are prevented by atomic intrinsics and the pointer we get is valid.
                    unsafe {
                        mid::[<atomic_fetch_xor_ $float_ty _ $scope>](self.v.get(), order, val)
                    }
                    #[cfg(not(target_os = "cuda"))]
                    self.update_with(order, |v| $float_ty::from_bits(v.to_bits() ^ val.to_bits()))
                }

                /// Atomically loads the value behind this atomic.
                ///
                /// `load` takes an [`Ordering`] argument which describes the memory ordering of this operation.
                /// Possible values are [`Ordering::SeqCst`], [`Ordering::Acquire`], and [`Ordering::Relaxed`].
                ///
                /// # Panics
                ///
                /// Panics if `order` is [`Ordering::Release`] or [`Ordering::AcqRel`].
                ///
                $(#[doc = safety_doc!($unsafety)])?
                pub $($unsafety)? fn load(&self, order: Ordering) -> $float_ty {
                    #[cfg(target_os = "cuda")]
                    // SAFETY: data races are prevented by atomic intrinsics and the pointer we get is valid.
                    unsafe {
                        // The float is loaded as raw bits of the same width and converted back.
                        let val = mid::[<atomic_load_ $width _ $scope>](self.v.get().cast(), order);
                        $float_ty::from_bits(val)
                    }
                    #[cfg(not(target_os = "cuda"))]
                    {
                        let val = self.as_atomic_bits().load(order);
                        $float_ty::from_bits(val)
                    }
                }

                /// Atomically stores a value into this atomic.
                ///
                /// `store` takes an [`Ordering`] argument which describes the memory ordering of this operation.
                /// Possible values are [`Ordering::SeqCst`], [`Ordering::Release`], and [`Ordering::Relaxed`].
                ///
                /// # Panics
                ///
                /// Panics if `order` is [`Ordering::Acquire`] or [`Ordering::AcqRel`].
                ///
                $(#[doc = safety_doc!($unsafety)])?
                pub $($unsafety)? fn store(&self, val: $float_ty, order: Ordering) {
                    #[cfg(target_os = "cuda")]
                    // SAFETY: data races are prevented by atomic intrinsics and the pointer we get is valid.
                    unsafe {
                        // The float is stored as raw bits of the same width.
                        mid::[<atomic_store_ $width _ $scope>](self.v.get().cast(), order, val.to_bits());
                    }
                    #[cfg(not(target_os = "cuda"))]
                    self.as_atomic_bits().store(val.to_bits(), order);
                }
            }
        }
    };
}
227
228atomic_float!(f32, AtomicF32, 4, device, 32);
229atomic_float!(f64, AtomicF64, 8, device, 64);
230atomic_float!(f32, BlockAtomicF32, 4, block, 32, unsafe);
231atomic_float!(f64, BlockAtomicF64, 8, block, 64, unsafe);
232atomic_float!(f32, SystemAtomicF32, 4, device, 32);
233atomic_float!(f64, SystemAtomicF64, 8, device, 64);