use crate::RawStr;
use std::cell::Cell;
use std::fmt;
use std::future::Future;
use std::mem::ManuallyDrop;
use std::ops;
use std::pin::Pin;
use std::task::{Context, Poll};
use thiserror::Error;
/// Bit 0 of an access cell's raw value: set when the value is a reference.
const FLAG: isize = 1isize;
/// Shifted access state meaning "the value has been taken".
///
/// This is the largest state that still fits once it is shifted left past the
/// flag bit.
const TAKEN: isize = isize::MAX >> 1;
/// Most negative shifted access state that can be stored.
///
/// A shared access decrements the state, so this bounds the number of
/// concurrent shared accesses. `Access::shared` aborts the process rather than
/// decrement past it.
const MAX_USES: isize = isize::MIN >> 1;
/// Error produced when a value cannot be accessed in the requested way.
#[derive(Debug, Error)]
pub enum AccessError {
    /// The stored data has a different type than the one requested.
    #[error("expected data of type `{expected}`, but found `{actual}`")]
    UnexpectedType {
        /// Type that was requested.
        expected: RawStr,
        /// Type that was actually found.
        actual: RawStr,
    },
    /// A shared (read) access was refused.
    // `#[from]` already marks `error` as this variant's source.
    #[error("{error}")]
    NotAccessibleRef {
        #[from]
        error: NotAccessibleRef,
    },
    /// An exclusive (write) access was refused.
    // `#[from]` already marks `error` as this variant's source.
    #[error("{error}")]
    NotAccessibleMut {
        #[from]
        error: NotAccessibleMut,
    },
    /// Taking (moving out) the value was refused.
    // `#[from]` already marks `error` as this variant's source.
    #[error("{error}")]
    NotAccessibleTake {
        #[from]
        error: NotAccessibleTake,
    },
}
/// Which kinds of values an access request accepts.
///
/// A request made with `Owned` is refused when the access cell has its
/// reference flag set; `Any` does not look at that flag.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum AccessKind {
    /// Accept the value whether or not it is marked as a reference.
    Any,
    /// Accept the value only if it is not marked as a reference.
    Owned,
}
/// Error returned by `Access::shared` when a shared (read) access is refused.
///
/// Carries a snapshot of the access cell's raw state at the time of the refusal.
#[derive(Debug, Error)]
#[error("cannot read, value is {0}")]
pub struct NotAccessibleRef(Snapshot);
/// Error returned by `Access::exclusive` when an exclusive (write) access is refused.
///
/// Carries a snapshot of the access cell's raw state at the time of the refusal.
#[derive(Debug, Error)]
#[error("cannot write, value is {0}")]
pub struct NotAccessibleMut(Snapshot);
/// Error returned by `Access::take` when the value cannot be taken.
///
/// Carries a snapshot of the access cell's raw state at the time of the refusal.
#[derive(Debug, Error)]
#[error("cannot take, value is {0}")]
pub struct NotAccessibleTake(Snapshot);
#[derive(Debug)]
#[repr(transparent)]
pub struct Snapshot(isize);
impl fmt::Display for Snapshot {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self.0 >> 1 {
0 => write!(f, "fully accessible")?,
1 => write!(f, "exclusively accessed")?,
TAKEN => write!(f, "moved")?,
n if n < 0 => write!(f, "shared by {}", -n)?,
n => write!(f, "invalidly marked ({})", n)?,
}
if self.0 & FLAG == 1 {
write!(f, " (ref)")?;
}
Ok(())
}
}
/// Borrow-state cell for a value, consulted before handing out access guards.
///
/// The wrapped `isize` packs two fields:
///
/// * bit 0 (`FLAG`) is set when the value is a reference (see `is_ref`);
/// * the remaining bits, recovered with an arithmetic shift right by one, hold
///   the access state: `0` when unaccessed, `1` while exclusively accessed,
///   `TAKEN` once the value has been taken, and `-n` while `n` shared
///   accesses are outstanding.
///
/// NOTE(review): the derived `Clone` copies the current state verbatim,
/// including any outstanding access count. Confirm callers only clone idle cells.
#[derive(Clone)]
pub(crate) struct Access(Cell<isize>);

impl Access {
    /// Creates a cell with no outstanding access. The reference flag is set
    /// if and only if `is_ref` is `true`.
    pub(crate) const fn new(is_ref: bool) -> Self {
        let raw = if is_ref { FLAG } else { 0 };
        Self(Cell::new(raw))
    }

    /// Returns `true` when the reference flag bit is set.
    #[inline]
    pub(crate) fn is_ref(&self) -> bool {
        (self.0.get() & FLAG) == FLAG
    }

    /// Returns `true` when the state is unaccessed or shared, i.e. the value
    /// is neither exclusively accessed nor taken.
    #[inline]
    pub(crate) fn is_shared(&self) -> bool {
        self.get() <= 0
    }

    /// Returns `true` when there is no outstanding access of any kind.
    #[inline]
    pub(crate) fn is_exclusive(&self) -> bool {
        self.get() == 0
    }

    /// Returns `true` when the value has been taken.
    #[inline]
    pub(crate) fn is_taken(&self) -> bool {
        self.get() == TAKEN
    }

    /// Records one more shared access and returns a guard that releases it on drop.
    ///
    /// # Errors
    ///
    /// Fails, leaving the state untouched, when `kind` is `AccessKind::Owned`
    /// and the value is a reference. It also fails when the state is positive
    /// (exclusively accessed or taken). The error carries the raw state.
    ///
    /// Aborts the process if the state has already reached `MAX_USES`.
    ///
    /// # Safety
    ///
    /// NOTE(review): the caller contract is not spelled out in this file.
    /// Presumably the caller may only use the guarded data as a shared access
    /// permits while the guard lives. Confirm.
    #[inline]
    pub(crate) unsafe fn shared(
        &self,
        kind: AccessKind,
    ) -> Result<SharedGuard<'_>, NotAccessibleRef> {
        if matches!(kind, AccessKind::Owned) && self.is_ref() {
            return Err(NotAccessibleRef(Snapshot(self.0.get())));
        }

        let current = self.get();

        if current == MAX_USES {
            // Decrementing further would no longer fit once shifted back in.
            std::process::abort();
        }

        let next = current.wrapping_sub(1);

        if next < 0 {
            self.set(next);
            Ok(SharedGuard(self))
        } else {
            Err(NotAccessibleRef(Snapshot(self.0.get())))
        }
    }

    /// Marks the value as exclusively accessed and returns a guard that
    /// releases it on drop.
    ///
    /// # Errors
    ///
    /// Fails, leaving the state untouched, when `kind` is `AccessKind::Owned`
    /// and the value is a reference. It also fails when any access is already
    /// outstanding. The error carries the raw state.
    ///
    /// # Safety
    ///
    /// NOTE(review): the caller contract is not spelled out in this file.
    /// Presumably the caller must treat the guarded data as exclusively
    /// borrowed while the guard lives. Confirm.
    #[inline]
    pub(crate) unsafe fn exclusive(
        &self,
        kind: AccessKind,
    ) -> Result<ExclusiveGuard<'_>, NotAccessibleMut> {
        if matches!(kind, AccessKind::Owned) && self.is_ref() {
            return Err(NotAccessibleMut(Snapshot(self.0.get())));
        }

        // Only an unaccessed (0) state steps to exclusive (1).
        let next = self.get().wrapping_add(1);

        if next == 1 {
            self.set(next);
            Ok(ExclusiveGuard(self))
        } else {
            Err(NotAccessibleMut(Snapshot(self.0.get())))
        }
    }

    /// Marks the value as taken and returns a guard that restores the
    /// unaccessed state on drop.
    ///
    /// # Errors
    ///
    /// Fails, leaving the state untouched, when `kind` is `AccessKind::Owned`
    /// and the value is a reference. It also fails when any access is
    /// outstanding. The error carries the raw state.
    ///
    /// # Safety
    ///
    /// The returned `RawTakeGuard` keeps only a raw pointer to `self` and
    /// dereferences it when dropped. The caller must keep this `Access` alive
    /// until that guard is dropped.
    #[inline]
    pub(crate) unsafe fn take(&self, kind: AccessKind) -> Result<RawTakeGuard, NotAccessibleTake> {
        if matches!(kind, AccessKind::Owned) && self.is_ref() {
            return Err(NotAccessibleTake(Snapshot(self.0.get())));
        }

        match self.get() {
            0 => {
                self.set(TAKEN);
                Ok(RawTakeGuard { access: self })
            }
            _ => Err(NotAccessibleTake(Snapshot(self.0.get()))),
        }
    }

    /// Undoes one `shared` acquisition.
    #[inline]
    fn release_shared(&self) {
        let next = self.get().wrapping_add(1);
        debug_assert!(next <= 0);
        self.set(next);
    }

    /// Undoes an `exclusive` acquisition.
    #[inline]
    fn release_exclusive(&self) {
        let next = self.get().wrapping_sub(1);
        debug_assert_eq!(next, 0, "borrow value should be exclusive (0)");
        self.set(next);
    }

    /// Undoes a `take`, returning the state to unaccessed.
    #[inline]
    fn release_take(&self) {
        let current = self.get();
        debug_assert_eq!(current, TAKEN, "borrow value should be TAKEN ({})", TAKEN);
        self.set(0);
    }

    /// Decodes the access state: the raw value with the flag bit shifted out.
    #[inline]
    fn get(&self) -> isize {
        self.0.get() >> 1
    }

    /// Stores `value` as the access state, preserving the flag bit.
    #[inline]
    fn set(&self, value: isize) {
        let flag = self.0.get() & FLAG;
        self.0.set((value << 1) | flag);
    }
}
impl fmt::Debug for Access {
    /// Formats the current access state, e.g. `shared by 2 (ref)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `Snapshot` expects the raw encoded value (flag bit included) and does
        // its own shift. Passing the already-shifted `self.get()` would shift
        // twice and misreport both the state and the ref flag.
        write!(f, "{}", Snapshot(self.0.get()))
    }
}
pub struct RawSharedGuard(*const Access);
impl Drop for RawSharedGuard {
fn drop(&mut self) {
unsafe { (*self.0).release_shared() };
}
}
/// Shared borrow of some data, paired with a guard that releases one shared
/// access on its `Access` when dropped.
pub struct BorrowRef<'a, T: ?Sized + 'a> {
    data: &'a T,
    guard: SharedGuard<'a>,
}

impl<'a, T: ?Sized> BorrowRef<'a, T> {
    /// Wraps `data` together with a guard over `access`.
    ///
    /// This does not record a shared access itself, yet the guard built here
    /// releases one when dropped. NOTE(review): callers are presumably expected
    /// to have acquired the shared access already. Confirm.
    pub(crate) fn new(data: &'a T, access: &'a Access) -> Self {
        let guard = SharedGuard(access);
        Self { data, guard }
    }

    /// Projects the borrow onto part of the data through `m`, keeping the same guard.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `m`. In that case the guard is dropped,
    /// releasing the shared access.
    pub fn try_map<M, U: ?Sized, E>(this: Self, m: M) -> Result<BorrowRef<'a, U>, E>
    where
        M: FnOnce(&T) -> Result<&U, E>,
    {
        let BorrowRef { data, guard } = this;
        m(data).map(|projected| BorrowRef {
            data: projected,
            guard,
        })
    }
}
impl<T: ?Sized> ops::Deref for BorrowRef<'_, T> {
    type Target = T;

    /// Gives read access to the borrowed data.
    fn deref(&self) -> &T {
        self.data
    }
}

impl<T: ?Sized> fmt::Debug for BorrowRef<'_, T>
where
    T: fmt::Debug,
{
    /// Formats exactly like the borrowed data, forwarding formatter flags.
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        let inner: &T = self.data;
        fmt::Debug::fmt(inner, fmt)
    }
}
pub struct SharedGuard<'a>(&'a Access);
impl SharedGuard<'_> {
pub unsafe fn into_raw(self) -> RawSharedGuard {
RawSharedGuard(ManuallyDrop::new(self).0)
}
}
impl Drop for SharedGuard<'_> {
fn drop(&mut self) {
self.0.release_shared();
}
}
pub struct RawExclusiveGuard(*const Access);
impl Drop for RawExclusiveGuard {
fn drop(&mut self) {
unsafe { (*self.0).release_exclusive() }
}
}
pub(crate) struct RawTakeGuard {
access: *const Access,
}
impl Drop for RawTakeGuard {
fn drop(&mut self) {
unsafe { (*self.access).release_take() }
}
}
/// Exclusive borrow of some data, paired with a guard that releases the
/// exclusive access on its `Access` when dropped.
pub struct BorrowMut<'a, T: ?Sized> {
    data: &'a mut T,
    guard: ExclusiveGuard<'a>,
}

impl<'a, T: ?Sized> BorrowMut<'a, T> {
    /// Wraps `data` together with a guard over `access`.
    ///
    /// # Safety
    ///
    /// This does not record an exclusive access itself, yet the guard built
    /// here releases one when dropped. NOTE(review): callers presumably must
    /// already hold the exclusive access on `access`. Confirm.
    pub(crate) unsafe fn new(data: &'a mut T, access: &'a Access) -> Self {
        let guard = ExclusiveGuard(access);
        Self { data, guard }
    }

    /// Projects the borrow onto part of the data through `m`, keeping the same guard.
    pub fn map<M, U: ?Sized>(this: Self, m: M) -> BorrowMut<'a, U>
    where
        M: FnOnce(&mut T) -> &mut U,
    {
        let BorrowMut { data, guard } = this;
        let projected = m(data);
        BorrowMut {
            data: projected,
            guard,
        }
    }

    /// Like `map`, but `m` may decline by returning `None`. In that case the
    /// guard is dropped, releasing the exclusive access.
    pub fn try_map<M, U: ?Sized>(this: Self, m: M) -> Option<BorrowMut<'a, U>>
    where
        M: FnOnce(&mut T) -> Option<&mut U>,
    {
        let BorrowMut { data, guard } = this;
        m(data).map(|projected| BorrowMut {
            data: projected,
            guard,
        })
    }
}
impl<T: ?Sized> ops::Deref for BorrowMut<'_, T> {
    type Target = T;

    /// Gives read access to the exclusively borrowed data.
    fn deref(&self) -> &T {
        &*self.data
    }
}

impl<T: ?Sized> ops::DerefMut for BorrowMut<'_, T> {
    /// Gives write access to the exclusively borrowed data.
    fn deref_mut(&mut self) -> &mut T {
        &mut *self.data
    }
}

impl<T: ?Sized> fmt::Debug for BorrowMut<'_, T>
where
    T: fmt::Debug,
{
    /// Formats exactly like the borrowed data, forwarding formatter flags.
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        let inner: &T = &*self.data;
        fmt::Debug::fmt(inner, fmt)
    }
}
/// Polling a `BorrowMut` over an `Unpin` future polls the borrowed future in place.
impl<F> Future for BorrowMut<'_, F>
where
    F: Unpin + Future,
{
    type Output = F::Output;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `F: Unpin`, so the inner future can be re-pinned through a plain `&mut`.
        let inner: &mut F = &mut *Pin::into_inner(self).data;
        Pin::new(inner).poll(cx)
    }
}
pub struct ExclusiveGuard<'a>(&'a Access);
impl ExclusiveGuard<'_> {
pub unsafe fn into_raw(self) -> RawExclusiveGuard {
RawExclusiveGuard(ManuallyDrop::new(self).0)
}
}
impl Drop for ExclusiveGuard<'_> {
fn drop(&mut self) {
self.0.release_exclusive();
}
}
#[cfg(test)]
mod tests {
    use super::{Access, AccessKind};

    /// Asserts the three observable predicates of `access` together.
    fn assert_state(access: &Access, is_ref: bool, shared: bool, exclusive: bool) {
        assert_eq!(access.is_ref(), is_ref);
        assert_eq!(access.is_shared(), shared);
        assert_eq!(access.is_exclusive(), exclusive);
    }

    /// Walks a fresh cell through a shared and an exclusive access and back,
    /// checking the predicates at every step.
    fn check_round_trip(is_ref: bool) {
        unsafe {
            let access = Access::new(is_ref);
            assert_state(&access, is_ref, true, true);

            let guard = access.shared(AccessKind::Any).unwrap();
            assert_state(&access, is_ref, true, false);
            drop(guard);
            assert_state(&access, is_ref, true, true);

            let guard = access.exclusive(AccessKind::Any).unwrap();
            assert_state(&access, is_ref, false, false);
            drop(guard);
            assert_state(&access, is_ref, true, true);
        }
    }

    #[test]
    fn test_non_ref() {
        check_round_trip(false);
    }

    #[test]
    fn test_ref() {
        check_round_trip(true);
    }

    #[test]
    fn test_take() {
        unsafe {
            let access = Access::new(false);
            let taken = access.take(AccessKind::Any).unwrap();
            assert!(access.is_taken());
            assert_state(&access, false, false, false);
            assert!(access.shared(AccessKind::Any).is_err());
            assert!(access.exclusive(AccessKind::Any).is_err());
            assert!(access.take(AccessKind::Any).is_err());
            drop(taken);
            assert!(!access.is_taken());
            assert_state(&access, false, true, true);
        }
    }

    #[test]
    fn test_owned_rejects_ref() {
        unsafe {
            let access = Access::new(true);
            assert!(access.shared(AccessKind::Owned).is_err());
            assert!(access.exclusive(AccessKind::Owned).is_err());
            assert!(access.take(AccessKind::Owned).is_err());
            // A refused request must leave the state untouched.
            assert_state(&access, true, true, true);

            let owned = Access::new(false);
            let guard = owned.shared(AccessKind::Owned).unwrap();
            assert_state(&owned, false, true, false);
            drop(guard);
            assert_state(&owned, false, true, true);
        }
    }

    #[test]
    fn test_multiple_shared() {
        unsafe {
            let access = Access::new(false);
            let first = access.shared(AccessKind::Any).unwrap();
            let second = access.shared(AccessKind::Any).unwrap();
            assert!(access.exclusive(AccessKind::Any).is_err());
            drop(first);
            // One shared access is still outstanding.
            assert!(access.exclusive(AccessKind::Any).is_err());
            drop(second);
            assert_state(&access, false, true, true);
        }
    }
}