#![no_std]
use core::cell::{Cell, UnsafeCell};
use core::fmt::Debug;
use core::marker::PhantomData;
use core::mem::{transmute, MaybeUninit};
use core::ops::Shl;
use core::pin::Pin;
use core::sync::atomic::Ordering::{Acquire, Relaxed, Release};
use radium::marker::{Nuclear, NumericOps};
use radium::{Isotope, Radium};
use scopeguard::defer;
/// A reference-like type whose lifetime a [`BorrowChannel`] can erase and restore.
///
/// The implementing type is the `'static` form (for example `&'static T`);
/// [`Reborrowable::Borrowed`] names the same type at an arbitrary lifetime `'a`.
///
/// # Safety
///
/// NOTE(review): the requirements below are inferred from how `BorrowChannel`
/// uses this trait; confirm before adding implementations.
///
/// - The channel transmutes between `Borrowed<'a>` and `Borrowed<'static>`, so every
///   `Borrowed<'a>` must be the same type up to lifetimes.
/// - Guards hand out bitwise copies of the stored value. Exclusivity is enforced only
///   when `IS_SHARED` is `false`, so `IS_SHARED` may be `true` only if duplicate copies
///   can soundly coexist.
/// - The value is kept in `MaybeUninit` and is not dropped by the channel.
pub unsafe trait Reborrowable {
    /// This type, reborrowed for lifetime `'a`.
    type Borrowed<'a>;
    /// `true` when several guards may hand out the borrow simultaneously.
    const IS_SHARED: bool;
}
/// Channel lock whose counter lives in a [`Cell`]. Because `Cell` is not `Sync`,
/// channels using this lock are not `Sync` either.
pub struct UnsyncLock<T>(Cell<T>)
where
    T: Nuclear;
/// Channel lock whose counter lives in radium's `Isotope`, intended for channels
/// shared between threads.
pub struct SyncLock<T>(Isotope<T>)
where
    T: Nuclear;
macro_rules! impl_channel_lock {
($T:ty) => {
impl CounterInnerPriv for $T {}
impl Lock for UnsyncLock<$T> {}
impl ChannelLockPriv for UnsyncLock<$T> {
type CounterInner = $T;
type Counter = Cell<$T>;
fn counter(&self) -> &Self::Counter {
&self.0
}
fn new() -> Self {
Self(Radium::new(0))
}
}
impl Lock for SyncLock<$T> {}
impl ChannelLockPriv for SyncLock<$T> {
type CounterInner = $T;
type Counter = Isotope<$T>;
fn counter(&self) -> &Self::Counter {
&self.0
}
fn new() -> Self {
Self(Radium::new(0))
}
}
};
}
impl_channel_lock!(u64);
impl_channel_lock!(u32);
impl_channel_lock!(u16);
impl_channel_lock!(u8);
/// Counter strategy for a [`BorrowChannel`]. It is implemented for [`UnsyncLock`] and
/// [`SyncLock`] over `u8`, `u16`, `u32` and `u64`.
///
/// The supertrait is private, so code outside this crate cannot implement `Lock`.
// `private_bounds` is allowed on purpose: the private supertrait is what seals the trait.
#[allow(private_bounds)]
pub trait Lock
where
    Self: ChannelLockPriv,
{
}
/// Sealed internals of [`Lock`]: the counter cell and the integer it holds.
trait ChannelLockPriv {
    /// Integer stored in the counter. Its two most significant bits encode the
    /// channel state; the remaining bits count outstanding guards.
    type CounterInner: CounterInnerPriv;
    /// Radium cell holding a `CounterInner`.
    type Counter: Radium<Item = Self::CounterInner>;

    /// Builds a lock whose counter is zero, which is the channel's empty state.
    fn new() -> Self;

    /// Borrows the counter cell.
    fn counter(&self) -> &Self::Counter;
}
/// Integer types that can serve as a channel counter.
///
/// `From<u8>` and `Shl<u32>` are used to build the state constants. The `+` and `&`
/// applied to counter values presumably come from radium's `NumericOps`.
trait CounterInnerPriv
where
    Self: Copy + Debug + Nuclear + NumericOps + From<u8> + Shl<u32, Output = Self>,
{
}
/// Width of the counter integer `C`, in bits.
fn counter_bits<C: CounterInnerPriv>() -> u32 {
    let bytes = size_of::<C>() as u32;
    bytes * 8
}

/// Places the 2-bit state `tag` in the two most significant bits of `C`.
fn state_tag<C: CounterInnerPriv>(tag: u8) -> C {
    let shift = counter_bits::<C>() - 2;
    C::from(tag) << shift
}

/// Counter value of a channel with nothing lent (all bits clear).
fn state_empty<C: CounterInnerPriv>() -> C {
    C::from(0u8)
}

/// State `0b01`: `lend` has claimed the channel but has not yet published the borrow.
fn state_locked<C: CounterInnerPriv>() -> C {
    state_tag::<C>(0b01)
}

/// State `0b10`: a borrow is published; the low bits count live guards.
fn state_filled<C: CounterInnerPriv>() -> C {
    state_tag::<C>(0b10)
}

/// Mask selecting the two state bits.
fn state_mask<C: CounterInnerPriv>() -> C {
    state_tag::<C>(0b11)
}
/// A slot through which [`BorrowChannel::lend`] temporarily publishes a borrow.
/// [`BorrowChannel::borrow`] can pick it up while the lend is active.
///
/// The borrow is stored with its lifetime erased to `'static`. The state and guard
/// count in `count` keep every access inside the real lifetime.
pub struct BorrowChannel<T, L>
where
    T: Reborrowable,
    L: Lock,
{
    /// Lifetime-erased copy of the lent borrow. It is written by `lend` and is only
    /// meaningful while that lend is active.
    data: UnsafeCell<MaybeUninit<T::Borrowed<'static>>>,
    /// Channel state bits plus the number of outstanding guards.
    count: L,
    _p: PhantomData<T>,
}
/// RAII guard returned by [`BorrowChannel::borrow`].
///
/// It gives access to the lent value and releases its slot in the channel's guard
/// count when dropped.
// `#[must_use]`: a guard that is discarded immediately releases the borrow at once,
// which is almost certainly a caller mistake (same convention as std's lock guards).
#[must_use = "dropping the guard immediately releases the borrow"]
pub struct BorrowChannelGuard<'a, T: Reborrowable, L: Lock> {
    channel: &'a BorrowChannel<T, L>,
}
impl<T: Reborrowable, L: Lock> Drop for BorrowChannelGuard<'_, T, L> {
    /// Gives back this guard's slot in the counter.
    ///
    /// The `Release` ordering pairs with the `Acquire` compare-exchange in `lend`'s exit
    /// check. Every access made through this guard therefore happens-before the lender
    /// withdraws the borrow.
    fn drop(&mut self) {
        let one: L::CounterInner = 1u8.into();
        self.channel.count.counter().fetch_sub(one, Release);
    }
}
impl<T: Reborrowable, L: Lock> BorrowChannelGuard<'_, T, L> {
pub fn get_mut(&mut self) -> T::Borrowed<'_> {
unsafe {
let ptr: *const MaybeUninit<T::Borrowed<'static>> = self.channel.data.get();
ptr.cast::<T::Borrowed<'_>>().read()
}
}
pub fn get(&self) -> &T::Borrowed<'_> {
unsafe {
transmute::<&MaybeUninit<T::Borrowed<'static>>, &T::Borrowed<'_>>(
&*self.channel.data.get(),
)
}
}
}
// Refuse to build optimized binaries with `unsafe_disable_abort`. That feature turns
// the process abort in `abort()` into an ordinary panic. An outstanding guard could
// then survive past the end of its lend.
const _: () = assert!(
    !cfg!(all(feature = "unsafe_disable_abort", not(debug_assertions))),
    "The unsafe_disable_abort feature is intended only for testing. It makes BorrowChannel unsound."
);
impl<T: Reborrowable, L: Lock> BorrowChannel<T, L> {
/// Creates an empty channel.
///
/// # Safety
///
/// NOTE(review): no contract was documented. Safe constructors exist only for
/// `u64` counters, and `borrow` checks guard-count overflow only with
/// `debug_assert!`. Presumably the caller must therefore ensure that a narrower
/// counter never holds more than 2^(bits-2) - 1 guards at once, which would carry
/// into the state bits. Confirm.
pub unsafe fn new() -> Self {
BorrowChannel {
data: UnsafeCell::new(MaybeUninit::uninit()),
count: L::new(),
_p: PhantomData,
}
}
/// Shorthand for the lock's counter cell.
fn count(&self) -> &L::Counter {
self.count.counter()
}
/// Publishes `borrow` through the channel while `f` runs, then withdraws it and
/// returns `f`'s result.
///
/// While `f` runs, [`borrow`](Self::borrow) on this channel yields guards that
/// access `borrow`.
///
/// # Panics
///
/// Panics with "borrow channel is not empty" if the counter is not exactly the
/// empty value on entry (for example the channel is already lent).
/// NOTE(review): `borrow` increments the counter before checking the state. A
/// concurrent failing `borrow` can therefore briefly leave a non-zero count on an
/// empty channel and make this fail spuriously.
///
/// When `f` returns or unwinds, the process is aborted via `abort()` if any guard
/// is still outstanding, because the lifetime-erased borrow is about to expire.
pub fn lend<R>(&self, borrow: T::Borrowed<'_>, f: impl FnOnce() -> R) -> R {
// Claim the slot: empty -> locked. Acquire presumably orders the previous lend's
// and its guards' accesses to `data` (published by Release operations on this
// counter) before the write below.
self.count()
.compare_exchange(state_empty(), state_locked(), Acquire, Relaxed)
.expect("borrow channel is not empty");
// Registered only after the claim succeeded, so it runs solely for a lend that
// owns the slot. On exit, including unwinding, filled -> empty succeeds only with
// zero outstanding guards. Otherwise a guard would outlive `borrow`, so abort.
defer! {
if self.count().compare_exchange(state_filled(),state_empty(),Acquire,Relaxed).is_err(){
abort();
}
}
// SAFETY: the state is `locked`. `borrow()` fails in this state, so no guard
// exists, and another `lend` cannot pass its empty -> locked claim. The transmute
// only erases the lifetime. The `defer!` above aborts rather than let the erased
// value be reached through a guard after this call ends.
unsafe {
self.data.get().write(MaybeUninit::new(transmute::<
T::Borrowed<'_>,
T::Borrowed<'static>,
>(borrow)));
}
// Release publishes the write above to borrowers' Acquire `fetch_add`. A plain
// store also discards increments made by borrowers that failed while locked.
self.count().store(state_filled(), Release);
f()
}
/// Takes a guard giving access to the currently lent borrow.
///
/// # Panics
///
/// - "channel is empty" if the channel is not in the filled state (nothing lent,
///   or a lend has not finished publishing). The increment is rolled back first.
/// - "channel is already borrowed from" if `T` is exclusive (`!T::IS_SHARED`) and
///   another guard exists. The guard created just before the check is dropped
///   during unwinding, which undoes the increment.
/// - In debug builds only, "channel counter overflow" if the guard count would
///   carry into the state bits.
pub fn borrow(&self) -> BorrowChannelGuard<'_, T, L> {
// Optimistically count this guard. Acquire pairs with `lend`'s Release store, so
// a successful borrower observes the written `data`.
let old_count = self.count().fetch_add(1u8.into(), Acquire);
if old_count & state_mask() != state_filled() {
// Roll back by clearing the count bits. If the channel became filled in the
// meantime, `lend`'s store has already overwritten the increment, so leave it.
self.count()
.fetch_update(Relaxed, Relaxed, |x| {
if x & state_mask() != state_filled() {
Some(x & state_mask())
} else {
None
}
})
.ok();
panic!("channel is empty");
};
// Built before the checks below so a failing assertion's unwinding drops it,
// which decrements the counter again.
let guard = BorrowChannelGuard { channel: self };
// Exclusive borrows require that no other guard was counted before this one.
if !T::IS_SHARED {
assert!(
old_count == state_filled(),
"channel is already borrowed from"
);
}
// NOTE(review): checked only in debug builds. In release, exceeding
// 2^(bits-2) - 1 live guards corrupts the state bits.
debug_assert!(
(old_count + 1u8.into()) & state_mask() == state_filled(),
"channel counter overflow"
);
guard
}
}
impl<T: Reborrowable> BorrowChannel<T, UnsyncLock<u64>> {
pub fn new_unsync() -> Self {
unsafe { Self::new() }
}
}
impl<T: Reborrowable> BorrowChannel<T, SyncLock<u64>> {
pub fn new_sync() -> Self {
unsafe { Self::new() }
}
}
/// Terminates the process. `lend` calls this when a guard is still outstanding as
/// its borrow is withdrawn.
///
/// A panic cannot unwind out of an `extern "C"` function, so panicking inside
/// `inner_abort` aborts instead. NOTE(review): this is guaranteed since Rust 1.81;
/// confirm the crate's MSRV. With the testing-only `unsafe_disable_abort` feature,
/// this is an ordinary (unwinding) panic.
fn abort() -> ! {
    extern "C" fn inner_abort() -> ! {
        panic!("abort");
    }
    if !cfg!(feature = "unsafe_disable_abort") {
        inner_abort()
    }
    panic!("abort")
}
/// Shared references: any number of guards may hand out `&T` at the same time.
unsafe impl<T> Reborrowable for &'static T
where
    T: Sized,
{
    type Borrowed<'a> = &'a T;
    const IS_SHARED: bool = true;
}
/// Exclusive references: `borrow` permits only one guard at a time.
unsafe impl<T> Reborrowable for &'static mut T
where
    T: Sized,
{
    type Borrowed<'a> = &'a mut T;
    const IS_SHARED: bool = false;
}
// SAFETY: any thread holding `&BorrowChannel` may take a guard and use the lent
// borrow, so the borrow (at every lifetime) and `T` must be both `Send` and `Sync`.
// The lock is mutated through `&self` from several threads, so it must be `Sync`.
unsafe impl<T, L> Sync for BorrowChannel<T, L>
where
    T: Reborrowable + Send + Sync,
    L: Lock + Sync,
    for<'a> T::Borrowed<'a>: Send + Sync,
{
}
/// Pinned shared references: guards hand out `Pin<&T>` copies, any number at once.
unsafe impl<T> Reborrowable for Pin<&'static T>
where
    T: Sized,
{
    type Borrowed<'a> = Pin<&'a T>;
    const IS_SHARED: bool = true;
}
/// Pinned exclusive references: at most one guard hands out `Pin<&mut T>` at a time.
unsafe impl<T> Reborrowable for Pin<&'static mut T>
where
    T: Sized,
{
    type Borrowed<'a> = Pin<&'a mut T>;
    const IS_SHARED: bool = false;
}