1use std::ops::{Deref, DerefMut};
2use std::pin::Pin;
3use std::sync::{Condvar, Mutex, MutexGuard};
4
/// A [`Mutex`] wrapper that provides structural pinning of its contents.
///
/// Once a `PinnedMutex<T>` is pinned, the `T` it protects is treated as pinned too:
/// locking requires `Pin<&PinnedMutex<T>>`, and the resulting guard hands the value
/// out as `Pin<&T>` / `Pin<&mut T>`. Ordinary `&mut T` access is only offered when
/// `T: Unpin`.
#[derive(Default, Debug)]
pub struct PinnedMutex<T> {
    // The wrapped lock. Pinning of `T` relies on this field never being moved out of
    // or exposed by `&mut` while `self` is pinned; only `lock` should reach into it.
    inner: Mutex<T>,
}
12
13impl<T> PinnedMutex<T> {
14 pub fn new(init: T) -> Self {
15 Self {
16 inner: Mutex::new(init),
17 }
18 }
19
20 pub fn lock(self: Pin<&Self>) -> PinnedMutexGuard<'_, T> {
25 let guard = self
26 .get_ref()
27 .inner
28 .lock()
29 .expect("PinnedMutex does not expose poison");
30 PinnedMutexGuard { guard }
31 }
32}
33
/// RAII guard for a locked [`PinnedMutex`]; the lock is released when it is dropped.
///
/// The protected value is handed out pinned, via [`as_ref`](Self::as_ref) and
/// [`as_mut`](Self::as_mut). Plain `&mut T` access through [`DerefMut`] exists only
/// for `T: Unpin`.
#[derive(Debug)]
#[must_use = "if unused the PinnedMutex will immediately unlock"]
pub struct PinnedMutexGuard<'a, T: 'a> {
    guard: MutexGuard<'a, T>,
}

impl<T> PinnedMutexGuard<'_, T> {
    /// Returns a pinned shared reference to the protected value.
    pub fn as_ref(&self) -> Pin<&T> {
        let value: &T = &self.guard;
        // SAFETY: `guard` is private. In this module a `PinnedMutexGuard` is only
        // built by `PinnedMutex::lock`, which requires `Pin<&PinnedMutex<T>>`, or by
        // re-wrapping such a guard after a condvar wait. The value therefore lives
        // inside a pinned `PinnedMutex` whose inner `Mutex` is never moved out of, so
        // it stays at this address until dropped.
        unsafe { Pin::new_unchecked(value) }
    }

    /// Returns a pinned mutable reference to the protected value.
    ///
    /// Use this, rather than `DerefMut`, to call `self: Pin<&mut T>` methods on
    /// `!Unpin` contents.
    pub fn as_mut(&mut self) -> Pin<&mut T> {
        let value: &mut T = &mut self.guard;
        // SAFETY: same invariant as `as_ref`. The value is located inside a pinned
        // `PinnedMutex`. Callers only receive it wrapped in `Pin`, so for `!Unpin`
        // types they cannot move it out.
        unsafe { Pin::new_unchecked(value) }
    }
}

impl<T> Deref for PinnedMutexGuard<'_, T> {
    type Target = T;

    fn deref(&self) -> &T {
        &self.guard
    }
}

/// Unpinned mutable access, offered only for `T: Unpin`. For such types, moving out
/// of a `&mut T` (e.g. via `mem::swap`) cannot violate pinning. `!Unpin` contents
/// must go through [`PinnedMutexGuard::as_mut`] instead.
impl<T: Unpin> DerefMut for PinnedMutexGuard<'_, T> {
    fn deref_mut(&mut self) -> &mut T {
        &mut self.guard
    }
}
71
/// A condition variable that works with [`PinnedMutexGuard`]s.
///
/// This is a thin wrapper over [`Condvar`]. Its `wait_while` passes the protected
/// value to the condition as `Pin<&mut T>`, so it can be used with `!Unpin` contents.
#[derive(Default, Debug)]
pub struct PinnedCondvar(Condvar);
74
75impl PinnedCondvar {
76 pub fn new() -> PinnedCondvar {
77 Default::default()
78 }
79
80 pub fn wait<'a, T>(&self, guard: PinnedMutexGuard<'a, T>) -> PinnedMutexGuard<'a, T> {
81 PinnedMutexGuard {
82 guard: self
83 .0
84 .wait(guard.guard)
85 .expect("PinnedMutex does not expose poison"),
86 }
87 }
88
89 pub fn wait_while<'a, T, F>(
90 &self,
91 guard: PinnedMutexGuard<'a, T>,
92 mut condition: F,
93 ) -> PinnedMutexGuard<'a, T>
94 where
95 F: FnMut(Pin<&mut T>) -> bool,
96 {
97 PinnedMutexGuard {
98 guard: self
99 .0
100 .wait_while(guard.guard, move |v| {
101 condition(unsafe { Pin::new_unchecked(v) })
103 })
104 .expect("PinnedMutex does not expose poison"),
105 }
106 }
107
108 pub fn notify_one(&self) {
109 self.0.notify_one()
110 }
111
112 pub fn notify_all(&self) {
113 self.0.notify_all()
114 }
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use pin_project::pin_project;
    use std::{marker::PhantomPinned, pin::pin, thread};

    #[test]
    fn mutate_through_lock() {
        let pm = pin!(PinnedMutex::new(15));
        let mut locked = pm.as_ref().lock();
        *locked = 16;
        assert_eq!(16, *locked);
    }

    // Deliberately `!Unpin` (via `PhantomPinned`), so it can only be mutated
    // through `Pin<&mut Self>` projections, never through `DerefMut`.
    #[pin_project(UnsafeUnpin)]
    struct MustPin {
        value: u32,
        pinned: PhantomPinned,
    }

    impl MustPin {
        fn new() -> Self {
            Self {
                value: 0,
                pinned: PhantomPinned,
            }
        }

        // Increments the counter and returns its previous value.
        fn inc(self: Pin<&mut Self>) -> u32 {
            let value = self.project().value;
            let prev = *value;
            *value += 1;
            prev
        }

        fn get(self: Pin<&Self>) -> u32 {
            *self.project_ref().value
        }
    }

    #[test]
    fn pinned_method() {
        let pm = pin!(PinnedMutex::new(MustPin::new()));
        let mut locked = pm.as_ref().lock();
        assert_eq!(0, locked.as_mut().inc());
        assert_eq!(1, locked.as_mut().inc());
        assert_eq!(2, locked.as_ref().get());
    }

    #[test]
    fn ref_alias() {
        let pm = pin!(PinnedMutex::new(MustPin::new()));
        let locked = pm.as_ref().lock();
        let a = locked.as_ref();
        let b = locked.as_ref();
        assert_eq!(a.value, b.value);
    }

    #[test]
    fn cond_var() {
        let cv = PinnedCondvar::new();
        let pm = pin!(PinnedMutex::new(MustPin::new()));
        let mut locked = pm.as_ref().lock();
        locked.as_mut().inc();
        let locked = cv.wait_while(locked, |pinned_contents| {
            pinned_contents.as_ref().get() == 0
        });
        // Keep the re-acquired guard bound instead of silently dropping it.
        let _locked = cv.wait_while(locked, |pinned_contents| {
            pinned_contents.as_ref().get() == 0
        });
        cv.notify_one();
        cv.notify_all();
    }

    // Exercises a real wakeup: the main thread blocks in `wait_while` until the
    // spawned thread bumps the counter and notifies. This holds whichever thread
    // runs first, because the condition is re-checked under the lock.
    #[test]
    fn cond_var_wakes_waiter() {
        let cv = PinnedCondvar::new();
        let pm = pin!(PinnedMutex::new(MustPin::new()));
        let pm = pm.as_ref();
        thread::scope(|s| {
            s.spawn(|| {
                pm.lock().as_mut().inc();
                cv.notify_all();
            });
            let locked = cv.wait_while(pm.lock(), |contents| contents.as_ref().get() == 0);
            assert_eq!(1, locked.as_ref().get());
        });
    }

    // Locking a poisoned mutex panics rather than returning an error.
    #[test]
    #[should_panic(expected = "PinnedMutex does not expose poison")]
    fn poisoned_lock_panics() {
        let pm = pin!(PinnedMutex::new(0u32));
        let pm = pm.as_ref();
        thread::scope(|s| {
            let _ = s
                .spawn(|| {
                    let _locked = pm.lock();
                    panic!("poison the mutex");
                })
                .join();
        });
        let _ = pm.lock();
    }

    #[derive(Debug, Default)]
    struct DebugTest;

    #[test]
    fn default_and_debug() {
        let pm: PinnedMutex<DebugTest> = Default::default();
        _ = format!("{:?}", pm);
    }
}