#![allow(unused_unsafe)]
pub mod intrinsics;
pub mod mid;
use core::cell::UnsafeCell;
use core::sync::atomic::Ordering;
use paste::paste;
/// Derives the failure ordering for a compare-exchange loop from the
/// caller-requested success ordering: a failed exchange performs no store,
/// so any release component of the ordering must be dropped.
///
/// Only referenced by the host (non-CUDA) fallback path in `update_with`,
/// hence the `dead_code` allowance for CUDA builds.
#[allow(dead_code)]
fn fail_order(order: Ordering) -> Ordering {
    // Only the two orderings with a store-release component need weakening;
    // Relaxed, Acquire, SeqCst — and any future variants of this
    // non-exhaustive enum — are already valid load orderings and pass
    // through unchanged.
    match order {
        Ordering::Release => Ordering::Relaxed,
        Ordering::AcqRel => Ordering::Acquire,
        other => other,
    }
}
// Expands to the tail of a doc sentence describing the synchronization
// scope of a generated atomic type ("...synchronizes across <this text>").
// Consumed by `atomic_float!` below through its `$scope` parameter.
macro_rules! scope_doc {
    (device) => {
        "a single device (GPU)."
    };
    (block) => {
        "a single thread block (also called a CTA, cooperative thread array)."
    };
    (system) => {
        "the entire system."
    };
}
// Expands to a `# Safety` doc section for methods that are generated as
// `unsafe`. The whole body sits inside `$(...)?`, so when no `$unsafety`
// token is supplied the macro expands to nothing and safe methods get no
// extra doc text.
macro_rules! safety_doc {
    ($($unsafety:ident)?) => {
        $(
            concat!(
                "# Safety\n",
                concat!("This function is ", stringify!($unsafety), " because it does not synchronize\n"),
                "across the entire GPU or System, which leaves it open for data races if used incorrectly"
            )
        )?
    };
}
// Generates one float atomic wrapper struct plus its full method set.
//
// Macro parameters:
//   $float_ty  - the underlying float type (f32 / f64)
//   $atomic_ty - name of the generated struct
//   $align     - byte alignment for `repr(align(..))` (4 for f32, 8 for f64)
//   $scope     - synchronization scope ident (device / block / system);
//                selects the doc text via `scope_doc!` and the
//                `mid::*_$scope` intrinsic used on the CUDA path
//   $width     - bit width (32 / 64); selects `AtomicU32`/`AtomicU64` on the
//                host path and the `atomic_load_/store_$width` intrinsics
//   $unsafety  - optional `unsafe` token: when present, every generated
//                method becomes an `unsafe fn` and gains a `# Safety`
//                section from `safety_doc!`
macro_rules! atomic_float {
    ($float_ty:ident, $atomic_ty:ident, $align:tt, $scope:ident, $width:tt $(,$unsafety:ident)?) => {
        #[doc = concat!("A ", stringify!($width), "-bit float type which can be safely shared between threads and synchronizes across ", scope_doc!($scope))]
        #[doc = stringify!($float_ty)]
        // repr(C) plus explicit alignment keeps the layout identical to the
        // matching AtomicU{32,64}; the host-side transmute below relies on it.
        #[repr(C, align($align))]
        pub struct $atomic_ty {
            v: UnsafeCell<$float_ty>,
        }
        // SAFETY: shared access to `v` only ever goes through atomic
        // operations — `mid::*` intrinsics on CUDA, `AtomicU*` on the host.
        unsafe impl Sync for $atomic_ty {}
        impl $atomic_ty {
            // `paste!` splices `$float_ty`/`$scope`/`$width` into the
            // `[<...>]` identifier concatenations used below.
            paste! {
                /// Creates a new atomic float initialized to `v`.
                pub const fn new(v: $float_ty) -> $atomic_ty {
                    Self {
                        v: UnsafeCell::new(v),
                    }
                }
                /// Consumes the atomic and returns the contained value.
                pub fn into_inner(self) -> $float_ty {
                    self.v.into_inner()
                }
                // Host-only: views `self` as the same-width atomic integer so
                // std atomics can emulate the GPU intrinsics. Sound because
                // of the `repr(C, align(..))` layout guarantee above.
                #[cfg(not(target_os = "cuda"))]
                fn as_atomic_bits(&self) -> &core::sync::atomic::[<AtomicU $width>] {
                    unsafe {
                        core::mem::transmute(self)
                    }
                }
                // Host-only CAS loop: repeatedly applies `func` to the
                // current value (reinterpreted as a float) until the
                // compare-exchange succeeds; returns the previous value.
                // Failure ordering is derived via `fail_order`. The closure
                // always returns `Some`, so `fetch_update` cannot yield
                // `Err` — the `unwrap` is infallible.
                #[cfg(not(target_os = "cuda"))]
                fn update_with(&self, order: Ordering, mut func: impl FnMut($float_ty) -> $float_ty) -> $float_ty {
                    let res = self
                        .as_atomic_bits()
                        .fetch_update(order, fail_order(order), |prev| {
                            Some(func($float_ty::from_bits(prev))).map($float_ty::to_bits)
                        }).unwrap();
                    $float_ty::from_bits(res)
                }
                /// Adds `val` to the current value, returning the previous value.
                $(#[doc = safety_doc!($unsafety)])?
                pub $($unsafety)? fn fetch_add(&self, val: $float_ty, order: Ordering) -> $float_ty {
                    // NOTE(review): the `mid::atomic_*` functions live in the
                    // sibling `mid` module (not visible here); assumed to
                    // perform the operation at `$scope` scope — confirm there.
                    #[cfg(target_os = "cuda")]
                    unsafe {
                        mid::[<atomic_fetch_add_ $float_ty _ $scope>](self.v.get(), order, val)
                    }
                    #[cfg(not(target_os = "cuda"))]
                    self.update_with(order, |v| v + val)
                }
                /// Subtracts `val` from the current value, returning the previous value.
                $(#[doc = safety_doc!($unsafety)])?
                pub $($unsafety)? fn fetch_sub(&self, val: $float_ty, order: Ordering) -> $float_ty {
                    #[cfg(target_os = "cuda")]
                    unsafe {
                        mid::[<atomic_fetch_sub_ $float_ty _ $scope>](self.v.get(), order, val)
                    }
                    #[cfg(not(target_os = "cuda"))]
                    self.update_with(order, |v| v - val)
                }
                /// Bitwise-ANDs `val` with the current value, returning the previous value.
                /// Operates on the raw IEEE-754 bit patterns, not numerically.
                $(#[doc = safety_doc!($unsafety)])?
                pub $($unsafety)? fn fetch_and(&self, val: $float_ty, order: Ordering) -> $float_ty {
                    #[cfg(target_os = "cuda")]
                    unsafe {
                        mid::[<atomic_fetch_and_ $float_ty _ $scope>](self.v.get(), order, val)
                    }
                    #[cfg(not(target_os = "cuda"))]
                    self.update_with(order, |v| $float_ty::from_bits(v.to_bits() & val.to_bits()))
                }
                /// Bitwise-ORs `val` with the current value, returning the previous value.
                /// Operates on the raw IEEE-754 bit patterns, not numerically.
                $(#[doc = safety_doc!($unsafety)])?
                pub $($unsafety)? fn fetch_or(&self, val: $float_ty, order: Ordering) -> $float_ty {
                    #[cfg(target_os = "cuda")]
                    unsafe {
                        mid::[<atomic_fetch_or_ $float_ty _ $scope>](self.v.get(), order, val)
                    }
                    #[cfg(not(target_os = "cuda"))]
                    self.update_with(order, |v| $float_ty::from_bits(v.to_bits() | val.to_bits()))
                }
                /// Bitwise-XORs `val` with the current value, returning the previous value.
                /// Operates on the raw IEEE-754 bit patterns, not numerically.
                $(#[doc = safety_doc!($unsafety)])?
                pub $($unsafety)? fn fetch_xor(&self, val: $float_ty, order: Ordering) -> $float_ty {
                    #[cfg(target_os = "cuda")]
                    unsafe {
                        mid::[<atomic_fetch_xor_ $float_ty _ $scope>](self.v.get(), order, val)
                    }
                    #[cfg(not(target_os = "cuda"))]
                    self.update_with(order, |v| $float_ty::from_bits(v.to_bits() ^ val.to_bits()))
                }
                /// Loads the current value with the given memory ordering.
                $(#[doc = safety_doc!($unsafety)])?
                pub $($unsafety)? fn load(&self, order: Ordering) -> $float_ty {
                    #[cfg(target_os = "cuda")]
                    unsafe {
                        // Load the raw bits at the configured scope, then
                        // reinterpret them as a float.
                        let val = mid::[<atomic_load_ $width _ $scope>](self.v.get().cast(), order);
                        $float_ty::from_bits(val)
                    }
                    #[cfg(not(target_os = "cuda"))]
                    {
                        let val = self.as_atomic_bits().load(order);
                        $float_ty::from_bits(val)
                    }
                }
                /// Stores `val` with the given memory ordering.
                $(#[doc = safety_doc!($unsafety)])?
                pub $($unsafety)? fn store(&self, val: $float_ty, order: Ordering) {
                    #[cfg(target_os = "cuda")]
                    unsafe {
                        mid::[<atomic_store_ $width _ $scope>](self.v.get().cast(), order, val.to_bits());
                    }
                    #[cfg(not(target_os = "cuda"))]
                    self.as_atomic_bits().store(val.to_bits(), order);
                }
            }
        }
    };
}
// Concrete atomic float types, one pair per synchronization scope.
//
// Device-scope atomics synchronize across one GPU and are safe; block-scope
// atomics synchronize only within a thread block (CTA), so their methods are
// generated as `unsafe` (see `safety_doc!`).
atomic_float!(f32, AtomicF32, 4, device, 32);
atomic_float!(f64, AtomicF64, 8, device, 64);
atomic_float!(f32, BlockAtomicF32, 4, block, 32, unsafe);
atomic_float!(f64, BlockAtomicF64, 8, block, 64, unsafe);
// BUGFIX: these previously passed `device` as the scope, so the System*
// types synchronized only within a single GPU despite their names, and their
// generated docs claimed "a single device (GPU)." — the `system` arm of
// `scope_doc!` was never used. They now use true system-wide scope.
// NOTE(review): assumes `mid` provides `*_system` intrinsic variants
// mirroring the `device`/`block` ones used above — confirm in `mid`.
atomic_float!(f32, SystemAtomicF32, 4, system, 32);
atomic_float!(f64, SystemAtomicF64, 8, system, 64);