// cuda_std/atomic.rs
1//! Atomic Types for modification of numbers in multiple threads in a sound way.
2//!
3//! # Core Interop
4//!
5//! Every type in this module works on the CPU (targets outside of nvptx). However, [`core::sync::atomic`] types
6//! do **NOT** work on the GPU currently. This is because CUDA atomics have some fundamental differences
7//! that make representing them fully with existing core types impossible:
8//!
9//! - CUDA has block-scoped, device-scoped, and system-scoped atomics, core does not make such a distinction (obviously).
10//! - CUDA trivially supports relaxed/acquire/release orderings on most architectures, but SeqCst and other orderings use
11//! specialized instructions on compute capabilities 7.x+, but can be emulated with fences/membars on 7.x >. This makes it difficult
12//! to hide away such details in the codegen.
13//! - CUDA has hardware atomic floats, core does not.
14//! - CUDA makes the distinction between "fetch, do operation, read" (`atom`) and "do operation" (`red`).
15//! - Core thinks CUDA supports 8 and 16 bit atomics, this is a bug in the nvptx target but it is nevertheless an annoying detail
16//! to silently trap on.
17//!
18//! Therefore we chose to go with the approach of implementing all atomics inside cuda_std. In the future, we may support
19//! a subset of core atomics, but for now, you will have to use cuda_std atomics.
20
21#![allow(unused_unsafe)]
22
23pub mod intrinsics;
24pub mod mid;
25
26use core::cell::UnsafeCell;
27use core::sync::atomic::Ordering;
28use paste::paste;
29
/// Derives a valid *failure* ordering from a success ordering for
/// compare-exchange style operations (used by the CPU CAS-loop fallback).
///
/// Failure orderings may not contain a release component, so `Release`
/// weakens to `Relaxed` and `AcqRel` weakens to `Acquire`; every other
/// ordering is already a legal failure ordering and passes through.
#[allow(dead_code)]
fn fail_order(order: Ordering) -> Ordering {
    if order == Ordering::Release {
        Ordering::Relaxed
    } else if order == Ordering::AcqRel {
        Ordering::Acquire
    } else {
        // Relaxed, Acquire, SeqCst (and any future variants) are unchanged.
        order
    }
}
39
/// Expands to the trailing sentence fragment of a generated doc comment,
/// describing the synchronization scope of an atomic type.
macro_rules! scope_doc {
    (system) => {
        "the entire system."
    };
    (block) => {
        "a single thread block (also called a CTA, cooperative thread array)."
    };
    (device) => {
        "a single device (GPU)."
    };
}
51
/// Expands to a `# Safety` doc section when an unsafety marker token is
/// supplied (e.g. `unsafe`), and to nothing at all when called with no
/// arguments — so safe scopes generate no safety section.
macro_rules! safety_doc {
    () => {};
    ($unsafety:ident) => {
        concat!(
            "# Safety\n",
            "This function is ",
            stringify!($unsafety),
            " because it does not synchronize\n",
            "across the entire GPU or System, which leaves it open for data races if used incorrectly"
        )
    };
}
63
// Generates one atomic float wrapper type per invocation.
//
// Arguments:
// - `$float_ty`:  the underlying float primitive (`f32`/`f64`)
// - `$atomic_ty`: the name of the generated wrapper type
// - `$align`:     alignment in bytes for the `#[repr]` attribute
// - `$scope`:     scope suffix selecting the `mid::*_<scope>` intrinsics (device/block/system)
// - `$width`:     bit width; selects `AtomicU32`/`AtomicU64` for the CPU fallback
// - `$unsafety`:  optional `unsafe` marker — when present, every operation becomes
//   an `unsafe fn` and gets a generated `# Safety` doc section via `safety_doc!`
macro_rules! atomic_float {
    ($float_ty:ident, $atomic_ty:ident, $align:tt, $scope:ident, $width:tt $(,$unsafety:ident)?) => {
        #[doc = concat!("A ", stringify!($width), "-bit float type which can be safely shared between threads and synchronizes across ", scope_doc!($scope))]
        ///
        /// This type is guaranteed to have the same memory representation as the underlying float
        /// type [`
        #[doc = stringify!($float_ty)]
        /// `].
        ///
        /// The functions on this type map to hardware instructions on CUDA platforms, and are emulated
        /// with a CAS loop on the CPU (non-CUDA targets).
        #[repr(C, align($align))]
        pub struct $atomic_ty {
            v: UnsafeCell<$float_ty>,
        }

        // SAFETY: atomic ops make sure this is fine
        unsafe impl Sync for $atomic_ty {}

        impl $atomic_ty {
            paste! {
                /// Creates a new atomic float.
                pub const fn new(v: $float_ty) -> $atomic_ty {
                    Self {
                        v: UnsafeCell::new(v),
                    }
                }

                /// Consumes the atomic and returns the contained value.
                pub fn into_inner(self) -> $float_ty {
                    self.v.into_inner()
                }

                // CPU fallback only: reinterpret `&self` as the integer atomic of the
                // same width so core's atomics can drive the bitwise emulation.
                #[cfg(not(target_os = "cuda"))]
                fn as_atomic_bits(&self) -> &core::sync::atomic::[<AtomicU $width>] {
                    // SAFETY: AtomicU32/U64 pointers are compatible with UnsafeCell<u32/u64>.
                    unsafe {
                        core::mem::transmute(self)
                    }
                }

                // CPU fallback only: run `func` over the float value inside a CAS
                // (`fetch_update`) loop on the bit pattern; returns the value observed
                // *before* the successful update.
                #[cfg(not(target_os = "cuda"))]
                fn update_with(&self, order: Ordering, mut func: impl FnMut($float_ty) -> $float_ty) -> $float_ty {
                    let res = self
                        .as_atomic_bits()
                        .fetch_update(order, fail_order(order), |prev| {
                            // The closure always returns `Some`, so `fetch_update`
                            // can never return `Err` and the unwrap below is safe.
                            Some(func($float_ty::from_bits(prev))).map($float_ty::to_bits)
                        }).unwrap();
                    $float_ty::from_bits(res)
                }

                /// Adds to the current value, returning the previous value **before** the addition.
                ///
                $(#[doc = safety_doc!($unsafety)])?
                pub $($unsafety)? fn fetch_add(&self, val: $float_ty, order: Ordering) -> $float_ty {
                    #[cfg(target_os = "cuda")]
                    // SAFETY: data races are prevented by atomic intrinsics and the pointer we get is valid.
                    unsafe {
                        mid::[<atomic_fetch_add_ $float_ty _ $scope>](self.v.get(), order, val)
                    }
                    #[cfg(not(target_os = "cuda"))]
                    self.update_with(order, |v| v + val)
                }

                /// Subtracts from the current value, returning the previous value **before** the subtraction.
                ///
                /// Note, this is actually implemented as `old + (-new)`, CUDA does not have a specialized sub instruction.
                ///
                $(#[doc = safety_doc!($unsafety)])?
                pub $($unsafety)? fn fetch_sub(&self, val: $float_ty, order: Ordering) -> $float_ty {
                    #[cfg(target_os = "cuda")]
                    // SAFETY: data races are prevented by atomic intrinsics and the pointer we get is valid.
                    unsafe {
                        mid::[<atomic_fetch_sub_ $float_ty _ $scope>](self.v.get(), order, val)
                    }
                    #[cfg(not(target_os = "cuda"))]
                    self.update_with(order, |v| v - val)
                }

                /// Bitwise "and" with the current value. Returns the value **before** the "and".
                ///
                $(#[doc = safety_doc!($unsafety)])?
                pub $($unsafety)? fn fetch_and(&self, val: $float_ty, order: Ordering) -> $float_ty {
                    #[cfg(target_os = "cuda")]
                    // SAFETY: data races are prevented by atomic intrinsics and the pointer we get is valid.
                    unsafe {
                        mid::[<atomic_fetch_and_ $float_ty _ $scope>](self.v.get(), order, val)
                    }
                    // CPU fallback: the "and" is performed on the raw bit pattern.
                    #[cfg(not(target_os = "cuda"))]
                    self.update_with(order, |v| $float_ty::from_bits(v.to_bits() & val.to_bits()))
                }

                /// Bitwise "or" with the current value. Returns the value **before** the "or".
                ///
                $(#[doc = safety_doc!($unsafety)])?
                pub $($unsafety)? fn fetch_or(&self, val: $float_ty, order: Ordering) -> $float_ty {
                    #[cfg(target_os = "cuda")]
                    // SAFETY: data races are prevented by atomic intrinsics and the pointer we get is valid.
                    unsafe {
                        mid::[<atomic_fetch_or_ $float_ty _ $scope>](self.v.get(), order, val)
                    }
                    // CPU fallback: the "or" is performed on the raw bit pattern.
                    #[cfg(not(target_os = "cuda"))]
                    self.update_with(order, |v| $float_ty::from_bits(v.to_bits() | val.to_bits()))
                }

                /// Bitwise "xor" with the current value. Returns the value **before** the "xor".
                ///
                $(#[doc = safety_doc!($unsafety)])?
                pub $($unsafety)? fn fetch_xor(&self, val: $float_ty, order: Ordering) -> $float_ty {
                    #[cfg(target_os = "cuda")]
                    // SAFETY: data races are prevented by atomic intrinsics and the pointer we get is valid.
                    unsafe {
                        mid::[<atomic_fetch_xor_ $float_ty _ $scope>](self.v.get(), order, val)
                    }
                    // CPU fallback: the "xor" is performed on the raw bit pattern.
                    #[cfg(not(target_os = "cuda"))]
                    self.update_with(order, |v| $float_ty::from_bits(v.to_bits() ^ val.to_bits()))
                }

                /// Atomically loads the value behind this atomic.
                ///
                /// `load` takes an [`Ordering`] argument which describes the memory ordering of this operation.
                /// Possible values are [`Ordering::SeqCst`], [`Ordering::Acquire`], and [`Ordering::Relaxed`].
                ///
                /// # Panics
                ///
                /// Panics if `order` is [`Ordering::Release`] or [`Ordering::AcqRel`].
                ///
                $(#[doc = safety_doc!($unsafety)])?
                pub $($unsafety)? fn load(&self, order: Ordering) -> $float_ty {
                    #[cfg(target_os = "cuda")]
                    // SAFETY: the pointer from UnsafeCell is valid, and the load is done
                    // through the width-matched atomic intrinsic (loaded as raw bits).
                    unsafe {
                        let val = mid::[<atomic_load_ $width _ $scope>](self.v.get().cast(), order);
                        $float_ty::from_bits(val)
                    }
                    #[cfg(not(target_os = "cuda"))]
                    {
                        let val = self.as_atomic_bits().load(order);
                        $float_ty::from_bits(val)
                    }
                }

                /// Atomically stores a value into this atomic.
                ///
                /// `store` takes an [`Ordering`] argument which describes the memory ordering of this operation.
                /// Possible values are [`Ordering::SeqCst`], [`Ordering::Release`], and [`Ordering::Relaxed`].
                ///
                /// # Panics
                ///
                /// Panics if `order` is [`Ordering::Acquire`] or [`Ordering::AcqRel`].
                ///
                $(#[doc = safety_doc!($unsafety)])?
                pub $($unsafety)? fn store(&self, val: $float_ty, order: Ordering) {
                    #[cfg(target_os = "cuda")]
                    // SAFETY: the pointer from UnsafeCell is valid, and the store is done
                    // through the width-matched atomic intrinsic (stored as raw bits).
                    unsafe {
                        mid::[<atomic_store_ $width _ $scope>](self.v.get().cast(), order, val.to_bits());
                    }
                    #[cfg(not(target_os = "cuda"))]
                    self.as_atomic_bits().store(val.to_bits(), order);
                }
            }
        }
    };
}
227
228atomic_float!(f32, AtomicF32, 4, device, 32);
229atomic_float!(f64, AtomicF64, 8, device, 64);
230atomic_float!(f32, BlockAtomicF32, 4, block, 32, unsafe);
231atomic_float!(f64, BlockAtomicF64, 8, block, 64, unsafe);
232atomic_float!(f32, SystemAtomicF32, 4, device, 32);
233atomic_float!(f64, SystemAtomicF64, 8, device, 64);