pinned_mutex/
std.rs

1use std::ops::{Deref, DerefMut};
2use std::pin::Pin;
3use std::sync::{Condvar, Mutex, MutexGuard};
4
/// A [`Mutex`] wrapper that offers [structural
/// pinning](https://doc.rust-lang.org/std/pin/index.html#projections-and-structural-pinning)
/// of the value it protects.
///
/// Locking goes through `Pin<&PinnedMutex<T>>`, so the protected `T` can be
/// reached as `Pin<&T>` / `Pin<&mut T>` while the lock is held.
#[derive(Default, Debug)]
pub struct PinnedMutex<T> {
    /// The wrapped standard mutex. It is kept private so the contents can
    /// never be moved out from behind the pin.
    inner: Mutex<T>,
}
12
13impl<T> PinnedMutex<T> {
14    pub fn new(init: T) -> Self {
15        Self {
16            inner: Mutex::new(init),
17        }
18    }
19
20    /// Acquires the lock and returns a guard.
21    ///
22    /// Poisoning is not supported. If the underlying mutex is
23    /// poisoned, `lock` will panic.
24    pub fn lock(self: Pin<&Self>) -> PinnedMutexGuard<'_, T> {
25        let guard = self
26            .get_ref()
27            .inner
28            .lock()
29            .expect("PinnedMutex does not expose poison");
30        PinnedMutexGuard { guard }
31    }
32}
33
/// Provides access to a [`PinnedMutex`]'s contents while its lock is held.
///
/// [`Deref`] to `&T` is always possible. [`DerefMut`] to `&mut T` is only
/// possible if `T` is [`Unpin`].
///
/// `as_ref` and `as_mut` project structural pinning, yielding `Pin<&T>` and
/// `Pin<&mut T>`. The lock is released when the guard is dropped, because
/// the wrapped `MutexGuard` is dropped with it.
#[must_use = "if unused the PinnedMutex will immediately unlock"]
#[derive(Debug)]
pub struct PinnedMutexGuard<'a, T: 'a> {
    guard: MutexGuard<'a, T>,
}
42
43impl<'a, T> PinnedMutexGuard<'a, T> {
44    /// Provides pinned access to the underlying T.
45    pub fn as_ref(&self) -> Pin<&T> {
46        // PinnedMutex::lock requires the mutex is pinned.
47        unsafe { Pin::new_unchecked(&self.guard) }
48    }
49
50    /// Provides pinned mutable access to the underlying T.
51    pub fn as_mut(&mut self) -> Pin<&mut T> {
52        // PinnedMutex::lock requires the mutex is pinned.
53        // &mut self guarantees as_ref() cannot alias.
54        unsafe { Pin::new_unchecked(&mut self.guard) }
55    }
56}
57
58impl<'a, T> Deref for PinnedMutexGuard<'a, T> {
59    type Target = T;
60    fn deref(&self) -> &Self::Target {
61        &self.guard
62    }
63}
64
65impl<'a, T: Unpin> DerefMut for PinnedMutexGuard<'a, T> {
66    fn deref_mut(&mut self) -> &mut Self::Target {
67        // SAFETY: T is Unpin, so it's safe to move out of T.
68        &mut self.guard
69    }
70}
71
/// A condition variable for use with [`PinnedMutex`].
///
/// A thin wrapper around [`Condvar`] whose wait methods take and return
/// `PinnedMutexGuard`s instead of raw [`MutexGuard`]s.
#[derive(Default, Debug)]
pub struct PinnedCondvar(Condvar);
74
75impl PinnedCondvar {
76    pub fn new() -> PinnedCondvar {
77        Default::default()
78    }
79
80    pub fn wait<'a, T>(&self, guard: PinnedMutexGuard<'a, T>) -> PinnedMutexGuard<'a, T> {
81        PinnedMutexGuard {
82            guard: self
83                .0
84                .wait(guard.guard)
85                .expect("PinnedMutex does not expose poison"),
86        }
87    }
88
89    pub fn wait_while<'a, T, F>(
90        &self,
91        guard: PinnedMutexGuard<'a, T>,
92        mut condition: F,
93    ) -> PinnedMutexGuard<'a, T>
94    where
95        F: FnMut(Pin<&mut T>) -> bool,
96    {
97        PinnedMutexGuard {
98            guard: self
99                .0
100                .wait_while(guard.guard, move |v| {
101                    // SAFETY: v is never moved.
102                    condition(unsafe { Pin::new_unchecked(v) })
103                })
104                .expect("PinnedMutex does not expose poison"),
105        }
106    }
107
108    pub fn notify_one(&self) {
109        self.0.notify_one()
110    }
111
112    pub fn notify_all(&self) {
113        self.0.notify_all()
114    }
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use pin_project::pin_project;
    use std::{marker::PhantomPinned, pin::pin, thread};

    #[test]
    fn mutate_through_lock() {
        let pm = pin!(PinnedMutex::new(15));
        let mut locked = pm.as_ref().lock();
        *locked = 16;
        assert_eq!(16, *locked);
        // The write must be visible to the next holder of the lock.
        drop(locked);
        assert_eq!(16, *pm.as_ref().lock());
    }

    /// A `!Unpin` test type whose methods require pinned access.
    #[pin_project(UnsafeUnpin)]
    struct MustPin {
        value: u32,
        pinned: PhantomPinned,
    }

    impl MustPin {
        fn new() -> Self {
            Self {
                value: 0,
                pinned: PhantomPinned,
            }
        }

        /// Increments the counter, returning its previous value.
        fn inc(self: Pin<&mut Self>) -> u32 {
            let value = self.project().value;
            let prev = *value;
            *value += 1;
            prev
        }

        fn get(self: Pin<&Self>) -> u32 {
            *self.project_ref().value
        }
    }

    #[test]
    fn pinned_method() {
        let pm = pin!(PinnedMutex::new(MustPin::new()));
        let mut locked = pm.as_ref().lock();
        assert_eq!(0, locked.as_mut().inc());
        assert_eq!(1, locked.as_mut().inc());
        assert_eq!(2, locked.as_ref().get());
    }

    #[test]
    fn ref_alias() {
        let pm = pin!(PinnedMutex::new(MustPin::new()));
        let locked = pm.as_ref().lock();
        let a = locked.as_ref();
        let b = locked.as_ref();
        assert_eq!(a.value, b.value);
    }

    #[test]
    fn cond_var() {
        let cv = PinnedCondvar::new();
        let pm = pin!(PinnedMutex::new(MustPin::new()));
        let mut locked = pm.as_ref().lock();
        locked.as_mut().inc();
        // The condition is already false, so neither call blocks.
        let locked = cv.wait_while(locked, |pinned_contents| {
            pinned_contents.as_ref().get() == 0
        });
        assert_eq!(1, locked.as_ref().get());
        let locked = cv.wait_while(locked, |pinned_contents| {
            pinned_contents.as_ref().get() == 0
        });
        assert_eq!(1, locked.as_ref().get());
        cv.notify_one();
        cv.notify_all();
    }

    #[test]
    fn wait_wakes_on_notify() {
        let cv = PinnedCondvar::new();
        let pm = pin!(PinnedMutex::new(MustPin::new()));
        let pm = pm.as_ref();
        thread::scope(|s| {
            // Take the lock before spawning. The notifier then cannot run
            // until this thread has released the lock inside `wait`.
            let mut locked = pm.lock();
            s.spawn(|| {
                pm.lock().as_mut().inc();
                cv.notify_all();
            });
            // Loop to tolerate spurious wakeups.
            while locked.as_ref().get() == 0 {
                locked = cv.wait(locked);
            }
            assert_eq!(1, locked.as_ref().get());
        });
    }

    #[derive(Debug, Default)]
    struct DebugTest;

    #[test]
    fn default_and_debug() {
        let pm: PinnedMutex<DebugTest> = Default::default();
        let formatted = format!("{:?}", pm);
        assert!(formatted.starts_with("PinnedMutex"));
        assert!(formatted.contains("DebugTest"));
        assert!(format!("{:?}", PinnedCondvar::default()).starts_with("PinnedCondvar"));
    }
}