use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::sync::{Condvar, Mutex, MutexGuard};
/// A [`Mutex`] whose protected value is never moved once the `PinnedMutex`
/// itself has been pinned.
///
/// Locking requires `Pin<&Self>`. The returned [`PinnedMutexGuard`] exposes the
/// value only as `&T`, as `Pin<&mut T>`, or as `&mut T` when `T: Unpin`. A
/// `!Unpin` value behind the lock therefore cannot be moved out through this
/// API.
#[derive(Debug, Default)]
pub struct PinnedMutex<T> {
    inner: Mutex<T>,
}

impl<T> PinnedMutex<T> {
    /// Wraps `init` in a new, unlocked mutex.
    ///
    /// The value is not pinned yet. Pin the returned `PinnedMutex` (for
    /// example with `std::pin::pin!` or `Box::pin`) before calling
    /// [`PinnedMutex::lock`], which requires `Pin<&Self>`.
    pub fn new(init: T) -> Self {
        let inner = Mutex::new(init);
        Self { inner }
    }

    /// Blocks until the lock is acquired, then returns a guard that holds it.
    ///
    /// # Panics
    ///
    /// Panics if the underlying mutex is poisoned, i.e. another thread
    /// panicked while holding the lock. Poisoning is never reported as an
    /// error value.
    pub fn lock(self: Pin<&Self>) -> PinnedMutexGuard<'_, T> {
        let mutex: &Mutex<T> = &Pin::get_ref(self).inner;
        let locked = mutex.lock().expect("PinnedMutex does not expose poison");
        PinnedMutexGuard { guard: locked }
    }
}
/// Guard returned by [`PinnedMutex::lock`] and by the [`PinnedCondvar`] wait
/// methods.
///
/// The lock is held for as long as the guard lives and is released when the
/// guard is dropped. The protected value is exposed in three ways:
/// - as `&T`, through [`Deref`] or [`PinnedMutexGuard::as_ref`];
/// - as `Pin<&mut T>`, through [`PinnedMutexGuard::as_mut`];
/// - as a plain `&mut T`, only when `T: Unpin`.
///
/// A `!Unpin` value therefore never escapes the lock in a movable form.
#[derive(Debug)]
#[must_use = "if unused the PinnedMutex will immediately unlock"]
pub struct PinnedMutexGuard<'a, T: 'a> {
    guard: MutexGuard<'a, T>,
}

impl<'a, T> PinnedMutexGuard<'a, T> {
    /// Returns the protected value as a pinned shared reference.
    pub fn as_ref(&self) -> Pin<&T> {
        // SAFETY: guards are only built from a lock taken through
        // `Pin<&PinnedMutex<T>>`, either directly or re-wrapped by
        // `PinnedCondvar`. `Mutex<T>` stores its `T` inline, so the value is
        // already pinned in place. A shared reference to it cannot move it.
        unsafe { Pin::new_unchecked(&*self.guard) }
    }

    /// Returns the protected value as a pinned mutable reference.
    ///
    /// This allows `self: Pin<&mut T>` methods to be called while the lock is
    /// held.
    pub fn as_mut(&mut self) -> Pin<&mut T> {
        // SAFETY: the value is already pinned, for the same reason given in
        // `as_ref`. Only a `Pin<&mut T>` escapes, and that cannot move a
        // `!Unpin` value without further `unsafe` code.
        unsafe { Pin::new_unchecked(&mut *self.guard) }
    }
}

impl<'a, T> Deref for PinnedMutexGuard<'a, T> {
    type Target = T;

    /// Shared access is always compatible with pinning.
    fn deref(&self) -> &Self::Target {
        &self.guard
    }
}

/// Plain `&mut T` access is offered only for `T: Unpin`.
///
/// For a `!Unpin` value, a `&mut T` would allow it to be moved, for example
/// with `std::mem::swap`, which would break the pinning guarantee.
impl<'a, T: Unpin> DerefMut for PinnedMutexGuard<'a, T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.guard
    }
}
/// A [`Condvar`] that waits on a [`PinnedMutexGuard`].
///
/// It delegates to a standard condition variable and re-wraps the reacquired
/// lock in a `PinnedMutexGuard`. The protected value therefore stays behind
/// the pinning API across waits.
#[derive(Debug, Default)]
pub struct PinnedCondvar(Condvar);

impl PinnedCondvar {
    /// Creates a new condition variable.
    pub fn new() -> PinnedCondvar {
        Self::default()
    }

    /// Releases the lock held by `guard`, blocks until this condition variable
    /// is notified, and returns a guard for the reacquired lock.
    ///
    /// As with [`Condvar::wait`], wakeups may be spurious. Prefer
    /// [`PinnedCondvar::wait_while`] when waiting for a condition.
    ///
    /// # Panics
    ///
    /// Panics if the mutex is poisoned when the lock is reacquired.
    pub fn wait<'a, T>(&self, guard: PinnedMutexGuard<'a, T>) -> PinnedMutexGuard<'a, T> {
        let std_guard = guard.guard;
        let reacquired = self
            .0
            .wait(std_guard)
            .expect("PinnedMutex does not expose poison");
        PinnedMutexGuard { guard: reacquired }
    }

    /// Blocks for as long as `condition` returns `true`, then returns the
    /// guard.
    ///
    /// `condition` is re-evaluated with the lock held after every wakeup.
    /// It receives the protected value as `Pin<&mut T>`, so it may call pinned
    /// methods on the value without moving it.
    ///
    /// # Panics
    ///
    /// Panics if the mutex is poisoned when the lock is reacquired.
    pub fn wait_while<'a, T, F>(
        &self,
        guard: PinnedMutexGuard<'a, T>,
        mut condition: F,
    ) -> PinnedMutexGuard<'a, T>
    where
        F: FnMut(Pin<&mut T>) -> bool,
    {
        let reacquired = self
            .0
            .wait_while(guard.guard, move |value: &mut T| {
                // SAFETY: `value` lives inside a `PinnedMutex` that was locked
                // through `Pin<&PinnedMutex<T>>`, so it is already pinned. The
                // reference is only lent to `condition` as `Pin<&mut T>` and is
                // never moved here.
                let pinned = unsafe { Pin::new_unchecked(value) };
                condition(pinned)
            })
            .expect("PinnedMutex does not expose poison");
        PinnedMutexGuard { guard: reacquired }
    }

    /// Wakes up one thread blocked on this condition variable, if any.
    pub fn notify_one(&self) {
        self.0.notify_one();
    }

    /// Wakes up all threads blocked on this condition variable.
    pub fn notify_all(&self) {
        self.0.notify_all();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use pin_project::pin_project;
    use std::{marker::PhantomPinned, pin::pin};

    // For an `Unpin` value the guard supports plain `&mut T` writes
    // through `DerefMut`.
    #[test]
    fn mutate_through_lock() {
        let pm = pin!(PinnedMutex::new(15));
        let mut locked = pm.as_ref().lock();
        *locked = 16;
        assert_eq!(16, *locked);
    }

    // A `!Unpin` type whose methods take `Pin<&mut Self>` / `Pin<&Self>`.
    // `UnsafeUnpin` (with no impl) plus `PhantomPinned` keeps it `!Unpin`.
    #[pin_project(UnsafeUnpin)]
    struct MustPin {
        value: u32,
        pinned: PhantomPinned,
    }

    impl MustPin {
        fn new() -> Self {
            Self {
                value: 0,
                pinned: PhantomPinned,
            }
        }

        // Increments the counter and returns its previous value.
        fn inc(self: Pin<&mut Self>) -> u32 {
            let value = self.project().value;
            let prev = *value;
            *value += 1;
            prev
        }

        fn get(self: Pin<&Self>) -> u32 {
            *self.project_ref().value
        }
    }

    #[test]
    fn pinned_method() {
        let pm = pin!(PinnedMutex::new(MustPin::new()));
        let mut locked = pm.as_ref().lock();
        assert_eq!(0, locked.as_mut().inc());
        assert_eq!(1, locked.as_mut().inc());
        assert_eq!(2, locked.as_ref().get());
    }

    // Several pinned shared references may coexist.
    #[test]
    fn ref_alias() {
        let pm = pin!(PinnedMutex::new(MustPin::new()));
        let locked = pm.as_ref().lock();
        let a = locked.as_ref();
        let b = locked.as_ref();
        assert_eq!(a.value, b.value);
    }

    // Single-threaded check: the condition is already false, so neither wait
    // blocks. Blocking and wakeup are covered by `cond_var_wakes_waiter`.
    #[test]
    fn cond_var() {
        let cv = PinnedCondvar::new();
        let pm = pin!(PinnedMutex::new(MustPin::new()));
        let mut locked = pm.as_ref().lock();
        locked.as_mut().inc();
        let locked = cv.wait_while(locked, |pinned_contents| {
            pinned_contents.as_ref().get() == 0
        });
        let locked = cv.wait_while(locked, |pinned_contents| {
            pinned_contents.as_ref().get() == 0
        });
        assert_eq!(1, locked.as_ref().get());
        cv.notify_one();
        cv.notify_all();
    }

    // A waiter blocked in `wait_while` is woken once another thread mutates
    // the pinned value and notifies.
    #[test]
    fn cond_var_wakes_waiter() {
        let cv = PinnedCondvar::new();
        let pm = pin!(PinnedMutex::new(MustPin::new()));
        let pm = pm.as_ref();
        std::thread::scope(|s| {
            let waker = s.spawn(|| {
                pm.lock().as_mut().inc();
                cv.notify_all();
            });
            let locked = cv.wait_while(pm.lock(), |contents| contents.as_ref().get() == 0);
            assert_eq!(1, locked.as_ref().get());
            drop(locked);
            waker.join().expect("waker thread panicked");
        });
    }

    #[derive(Debug, Default)]
    struct DebugTest;

    #[test]
    fn default_and_debug() {
        let pm: PinnedMutex<DebugTest> = Default::default();
        let rendered = format!("{:?}", pm);
        assert!(rendered.contains("DebugTest"));
    }
}