more_sync/carrier/
mod.rs

1use std::ops::Deref;
2use std::sync::atomic::{AtomicBool, Ordering};
3use std::sync::{Arc, Condvar, Mutex, MutexGuard};
4use std::time::Duration;
5
6/// A Carrier that manages the lifetime of an instance of type `T`.
7///
8/// The carrier owns the instance (the `target`). References to the `target` can
9/// be obtained by calling the [`create_ref`](`Carrier::create_ref`) method. The
10/// references returned by the method will be valid as long as the reference is
11/// alive.
12///
13/// The carrier can be [*frozen*](`Carrier::freeze`), after which no new
14/// references can be obtained. The carrier can also [*wait*](`Carrier::wait`)
15/// for all references it gave out to be dropped. The ownership of `target` will
16/// be returned to the caller after the wait is complete. The caller can then
17/// carry out clean-ups or any other type of work that requires an owned
18/// instance of type `T`.
19///
20/// ```
21/// use more_sync::Carrier;
22///
23/// // Create a carrier that holds a mutex.
24/// let carrier = Carrier::new(std::sync::Mutex::new(7usize));
25///
26/// // Ask for a reference to the value held by the carrier.
27/// let ref_one = carrier.create_ref().unwrap();
28/// assert_eq!(*ref_one.lock().unwrap(), 7);
29///
30/// // Reference returned by Carrier can be sent to another thread.
31/// std::thread::spawn(move || *ref_one.lock().unwrap() = 8usize);
32///
33/// // Close the carrier, no new references can be created.
34/// carrier.freeze();
35/// assert!(carrier.create_ref().is_none());
36///
37/// // Shutdown the carrier and wait for all references to be dropped.
38/// // The value held by carrier is returned.
39/// let mutex_value = carrier.wait();
40/// // Destroy the mutex.
41/// assert!(matches!(mutex_value.into_inner(), Ok(8usize)));
42///
43#[derive(Debug, Default)]
44pub struct Carrier<T> {
45    // Visible to tests.
46    pub(self) template: Arc<CarrierTarget<T>>,
47    shutdown: AtomicBool,
48}
49
50impl<T> Carrier<T> {
51    /// Create a carrier that carries and owns `target`.
52    pub fn new(target: T) -> Self {
53        Self {
54            template: Arc::new(CarrierTarget {
55                target,
56                condvar: Default::default(),
57                count: Mutex::new(0),
58            }),
59            shutdown: AtomicBool::new(false),
60        }
61    }
62
63    /// Creates a reference to the owned instance. Returns `None` if the carrier
64    /// has been frozen.
65    pub fn create_ref(&self) -> Option<CarrierRef<T>> {
66        if !self.shutdown.load(Ordering::Acquire) {
67            Some(CarrierRef::new(&self.template))
68        } else {
69            None
70        }
71    }
72
73    /// Returns the number of outstanding references created by this carrier.
74    ///
75    /// The returned value is obsolete as soon as this method returns. The count
76    /// can change at any time.
77    pub fn ref_count(&self) -> usize {
78        *self.template.lock_count()
79    }
80
81    /// Closes this carrier.
82    ///
83    /// No new references can be created after the carrier is frozen. A frozen
84    /// carrier cannot be re-opened again.
85    pub fn freeze(&self) {
86        self.shutdown.store(true, Ordering::Release);
87    }
88
89    /// Returns `true` if the carrier has been frozen, `false` otherwise.
90    ///
91    /// For the same carrier, once this method returns `true` it will never
92    /// return `false` again.
93    pub fn is_frozen(&self) -> bool {
94        self.shutdown.load(Ordering::Acquire)
95    }
96
97    fn unwrap_or_panic(self) -> T {
98        let arc = self.template;
99        assert_eq!(
100            Arc::strong_count(&arc),
101            1,
102            "The carrier should not more than one outstanding Arc"
103        );
104
105        match Arc::try_unwrap(arc) {
106            Ok(t) => t.target,
107            Err(_arc) => {
108                panic!("The carrier should not have any outstanding references")
109            }
110        }
111    }
112
113    /// Blocks the current thread until all references created by this carrier
114    /// are dropped.
115    ///
116    /// [`wait()`](Carrier::wait) consumes the carrier and returns the owned
117    /// instance. It returns immediately if all references have been dropped.
118    pub fn wait(self) -> T {
119        {
120            let count = self.template.lock_count();
121            let count = self
122                .template
123                .condvar
124                .wait_while(count, |count| *count != 0)
125                .expect("The carrier lock should not be poisoned");
126
127            assert_eq!(*count, 0);
128        }
129        self.unwrap_or_panic()
130    }
131
132    /// Like [`wait()`](`Carrier::wait`), but waits for at most `timeout`.
133    ///
134    /// Returns `Ok` and the owned instance if the wait was successful. Returns
135    /// `Err(self)` if timed out. The reference count can change at any time. It
136    /// is **not** guaranteed that the number of references is greater than zero
137    /// when this method returns.
138    pub fn wait_timeout(self, timeout: Duration) -> Result<T, Self> {
139        let count = {
140            let count = self.template.lock_count();
141            let (count, _result) = self
142                .template
143                .condvar
144                .wait_timeout_while(count, timeout, |count| *count != 0)
145                .expect("The carrier lock should not be poisoned");
146            *count
147        };
148
149        if count == 0 {
150            Ok(self.unwrap_or_panic())
151        } else {
152            Err(self)
153        }
154    }
155
156    /// Closes the carrier and waits for all references to be dropped.
157    ///
158    /// A [`freeze()`](`Carrier::freeze`) followed by a
159    /// [`wait()`](`Carrier::wait`). See the comments in those two methods.
160    pub fn shutdown(self) -> T {
161        self.freeze();
162        self.wait()
163    }
164
165    /// Like [`shutdown()`](`Carrier::shutdown`), but waits for at most
166    /// `timeout`.
167    ///
168    /// A [`freeze()`](`Carrier::freeze`) followed by a
169    /// [`wait_timeout()`](`Carrier::wait_timeout`). See the comments in those
170    /// two methods.
171    pub fn shutdown_timeout(self, timeout: Duration) -> Result<T, Self> {
172        self.freeze();
173        self.wait_timeout(timeout)
174    }
175}
176
177impl<T> AsRef<T> for Carrier<T> {
178    fn as_ref(&self) -> &T {
179        &self.template.target
180    }
181}
182
183impl<T> Deref for Carrier<T> {
184    type Target = T;
185
186    fn deref(&self) -> &Self::Target {
187        &self.template.deref().target
188    }
189}
190
/// The carried instance together with its reference bookkeeping.
#[derive(Debug, Default)]
struct CarrierTarget<T> {
    /// The instance being carried.
    target: T,

    /// Notified when `count` drops to zero.
    condvar: Condvar,
    /// Number of outstanding references; the mutex `condvar` is used with.
    count: Mutex<usize>,
}

impl<T> CarrierTarget<T> {
    /// Locks the reference counter and returns the guard.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned, i.e. a thread panicked while holding it.
    fn lock_count(&self) -> MutexGuard<'_, usize> {
        self.count
            .lock()
            .expect("The carrier lock should not be poisoned")
    }
}
206
207/// A reference to an object owned by a [`Carrier`](`Carrier`).
208///
209/// The target will be alive for as long as this reference is alive.
210#[derive(Default)]
211pub struct CarrierRef<T> {
212    inner: Arc<CarrierTarget<T>>,
213}
214
215impl<T> CarrierRef<T> {
216    fn new(inner: &Arc<CarrierTarget<T>>) -> Self {
217        let mut count = inner.lock_count();
218        *count += 1;
219
220        CarrierRef {
221            inner: inner.clone(),
222        }
223    }
224
225    fn delete(&self) {
226        let mut count = self.inner.lock_count();
227        *count -= 1;
228
229        if *count == 0 {
230            self.inner.condvar.notify_one();
231        }
232    }
233}
234
235impl<T> AsRef<T> for CarrierRef<T> {
236    fn as_ref(&self) -> &T {
237        &self.inner.target
238    }
239}
240
241impl<T> Deref for CarrierRef<T> {
242    type Target = T;
243
244    fn deref(&self) -> &Self::Target {
245        &self.inner.deref().target
246    }
247}
248
249impl<T> Drop for CarrierRef<T> {
250    fn drop(&mut self) {
251        self.delete()
252    }
253}
254
#[cfg(test)]
mod tests {
    use crate::Carrier;
    use std::cell::RefCell;
    use std::time::Duration;

    /// Walks through the whole life cycle: creating references, freezing,
    /// a timed-out wait, dropping the references and the final wait.
    #[test]
    fn test_basics() {
        let carrier = Carrier::new(7usize);
        assert_eq!(*carrier, 7usize);

        let first = carrier.create_ref().unwrap();
        let second = carrier.create_ref().unwrap();
        // Carrier should be send: move it into a thread and back again.
        let worker = std::thread::spawn(move || {
            let third = carrier.create_ref();
            (third, carrier)
        });
        let (third, carrier) =
            worker.join().expect("Thread creation should never fail");
        let third = third.unwrap();

        assert_eq!(*first, 7usize);
        assert_eq!(*second, 7usize);
        assert_eq!(*third, 7usize);

        // Freezing twice is OK; the carrier simply stays frozen.
        for _ in 0..2 {
            carrier.freeze();
            assert!(carrier.is_frozen());
        }

        // Once frozen, every attempt to create a reference fails.
        for _ in 0..2 {
            assert!(carrier.create_ref().is_none());
        }

        assert_eq!(carrier.ref_count(), 3);

        let timed_out = carrier.wait_timeout(Duration::from_micros(1));
        let carrier = timed_out.expect_err(
            "Wait should not be successful \
            since there are outstanding references",
        );

        // Each drop gives exactly one reference back.
        drop(first);
        assert_eq!(carrier.ref_count(), 2);
        drop(second);
        assert_eq!(carrier.ref_count(), 1);
        drop(third);
        assert_eq!(carrier.ref_count(), 0);

        let value = carrier.wait();
        assert_eq!(value, 7usize);
    }

    /// A handle to the shared allocation that is not a `CarrierRef` is not
    /// counted, so the carrier has to panic when it tries to take the target.
    #[test]
    #[should_panic]
    fn test_panic_outstanding_arc() {
        let carrier = Carrier::new(7usize);
        let _leaked_handle = std::sync::Arc::clone(&carrier.template);

        // Carrier should panic when it sees an outstanding Arc.
        carrier.wait();
    }

    /// The carrier and all of its references observe the same target.
    #[test]
    fn test_ref() {
        let carrier = Carrier::new(RefCell::new(7usize));
        let first = carrier.create_ref().unwrap();
        let second = carrier.create_ref().unwrap();

        // Mutate through the second reference, read through the others.
        *second.borrow_mut() += 1;
        assert_eq!(8, *first.borrow());
        assert_eq!(8, *carrier.borrow());

        // And the other way around.
        *first.borrow_mut() += 1;
        assert_eq!(9, *second.borrow());
        assert_eq!(9, *carrier.borrow());
    }
}