use futures::Async;
use semaphore;
use std::cell::UnsafeCell;
use std::fmt;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
/// A handle to a lock around a value of type `T`.
///
/// Every clone of a handle shares the same protected value and semaphore. Each handle
/// carries its own semaphore permit, which `poll_lock` tries to acquire; once acquired,
/// that permit is moved into the returned `LockGuard`.
#[derive(Debug)]
pub struct Lock<T> {
/// Shared state: the protected value and the semaphore guarding it.
inner: Arc<State<T>>,
/// This handle's permit; replaced by a fresh one when moved into a `LockGuard`.
permit: semaphore::Permit,
}
/// Access to the value protected by a `Lock`, returned by `Lock::poll_lock`.
///
/// Wraps a `Lock` handle whose permit has been acquired. The value is reached through
/// `Deref`/`DerefMut`, and the permit is released back to the semaphore on drop.
#[derive(Debug)]
pub struct LockGuard<T>(Lock<T>);
// SAFETY: `State<T>` holds an `UnsafeCell<T>`, so `Arc<State<T>>` is neither `Send` nor
// `Sync` on its own. The value is only reached through a `LockGuard` whose permit is
// acquired (asserted in `Deref`/`DerefMut`). Moving a handle to another thread therefore
// only moves the ability to lock, which needs `T: Send` -- the same bound as a mutex.
unsafe impl<T> Send for Lock<T> where T: Send {}
// SAFETY: `&Lock<T>` only allows cloning the handle (and `Debug`); locking requires
// `&mut self`. Sharing a handle therefore also only transfers the ability to lock, so
// `T: Send` suffices.
unsafe impl<T> Sync for Lock<T> where T: Send {}
// SAFETY: `&LockGuard<T>` hands out `&T` via `Deref`, so sharing a guard across threads
// needs `T: Sync`. `T: Send` is required as well, matching the bound on `Lock`.
// Note that `LockGuard<T>: Send` is not declared here; it follows automatically from its
// `Lock<T>` field.
unsafe impl<T> Sync for LockGuard<T> where T: Send + Sync {}
/// State shared by all `Lock` handles cloned from the same `Lock::new` call.
#[derive(Debug)]
struct State<T> {
/// The protected value; only dereferenced by a `LockGuard` holding an acquired permit.
c: UnsafeCell<T>,
/// Created with a single permit (`Semaphore::new(1)` in `Lock::new`); holding that
/// permit is what gives a `LockGuard` its exclusive access to `c`.
s: semaphore::Semaphore,
}
/// Compile-time check that the handle and guard have the thread-safety bounds that the
/// hand-written `unsafe impl`s (and the auto `Send` derived from them) are meant to give.
#[test]
fn bounds() {
    fn check_send<T: Send>() {}
    fn check_sync<T: Sync>() {}

    // `Lock: Send` and `Lock: Sync` come from explicit `unsafe impl`s.
    check_send::<Lock<u32>>();
    check_sync::<Lock<u32>>();
    // `LockGuard: Send` is auto-derived from its `Lock<T>` field; `Sync` is explicit.
    check_send::<LockGuard<u32>>();
    check_sync::<LockGuard<u32>>();
}
impl<T> Lock<T> {
pub fn new(t: T) -> Self {
Self {
inner: Arc::new(State {
c: UnsafeCell::new(t),
s: semaphore::Semaphore::new(1),
}),
permit: semaphore::Permit::new(),
}
}
pub fn poll_lock(&mut self) -> Async<LockGuard<T>> {
if let Async::NotReady = self.permit.poll_acquire(&self.inner.s).unwrap_or_else(|_| {
unreachable!()
}) {
return Async::NotReady;
}
let acquired = Self {
inner: self.inner.clone(),
permit: ::std::mem::replace(&mut self.permit, semaphore::Permit::new()),
};
Async::Ready(LockGuard(acquired))
}
}
impl<T> Drop for LockGuard<T> {
    /// Releases the guard's permit back to the shared semaphore.
    ///
    /// # Panics
    ///
    /// Panics if the permit is not held, unless the thread is already panicking.
    fn drop(&mut self) {
        let lock = &mut self.0;
        if lock.permit.is_acquired() {
            lock.permit.release(&lock.inner.s);
            return;
        }
        // A guard should always hold its permit. While unwinding, stay quiet instead:
        // panicking inside `drop` during another panic aborts the process.
        if !::std::thread::panicking() {
            unreachable!("Permit not held when LockGuard was dropped")
        }
    }
}
impl<T> From<T> for Lock<T> {
fn from(s: T) -> Self {
Self::new(s)
}
}
impl<T> Clone for Lock<T> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
permit: semaphore::Permit::new(),
}
}
}
impl<T> Default for Lock<T>
where
T: Default,
{
fn default() -> Self {
Self::new(T::default())
}
}
impl<T> Deref for LockGuard<T> {
    type Target = T;

    /// Borrows the protected value.
    ///
    /// # Panics
    ///
    /// Panics if the guard's permit is not acquired.
    fn deref(&self) -> &T {
        let lock = &self.0;
        assert!(lock.permit.is_acquired());
        // SAFETY: the guard holds an acquired permit (asserted above) of a semaphore created
        // with a single permit in `Lock::new`, so no other guard can access the value while
        // this borrow lives.
        unsafe { &*lock.inner.c.get() }
    }
}
impl<T> DerefMut for LockGuard<T> {
    /// Mutably borrows the protected value.
    ///
    /// # Panics
    ///
    /// Panics if the guard's permit is not acquired.
    fn deref_mut(&mut self) -> &mut T {
        let lock = &self.0;
        assert!(lock.permit.is_acquired());
        // SAFETY: the guard holds an acquired permit (asserted above) of a single-permit
        // semaphore, and `&mut self` rules out other borrows through this guard, so this
        // is the only live reference to the value.
        unsafe { &mut *lock.inner.c.get() }
    }
}
impl<T: fmt::Display> fmt::Display for LockGuard<T> {
    /// Formats the protected value, forwarding the formatter (and its flags) unchanged.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Deref coercion goes through `LockGuard::deref`, including its permit assertion.
        let value: &T = self;
        fmt::Display::fmt(value, f)
    }
}