1use std::ops::Deref;
2use std::sync::atomic::{AtomicBool, Ordering};
3use std::sync::{Arc, Condvar, Mutex, MutexGuard};
4use std::time::Duration;
5
/// Owns a value of type `T` and hands out counted, shareable references to it.
///
/// References are created with [`Carrier::create_ref`]; they can be moved to other
/// threads when `T` is `Send + Sync`. The carrier can be frozen so that no new
/// references are handed out, and it can be consumed with [`Carrier::wait`] (or one
/// of its variants) to block until every outstanding [`CarrierRef`] has been
/// dropped, at which point the owned value is moved back out.
#[derive(Debug, Default)]
pub struct Carrier<T> {
    /// Shared state: the carried value plus the live-reference counter.
    template: Arc<CarrierTarget<T>>,
    /// Set by [`Carrier::freeze`]; once set, `create_ref` returns `None`.
    shutdown: AtomicBool,
}

impl<T> Carrier<T> {
    /// Creates an unfrozen carrier that owns `target` and has no outstanding references.
    pub fn new(target: T) -> Self {
        Self {
            template: Arc::new(CarrierTarget {
                target,
                counter: Arc::new(RefCounter::default()),
            }),
            shutdown: AtomicBool::new(false),
        }
    }

    /// Hands out a new reference to the carried value.
    ///
    /// Returns `None` once the carrier has been frozen.
    pub fn create_ref(&self) -> Option<CarrierRef<T>> {
        if self.is_frozen() {
            None
        } else {
            Some(CarrierRef::new(&self.template))
        }
    }

    /// Returns the number of [`CarrierRef`]s created by this carrier that are still alive.
    ///
    /// This is a snapshot: other threads may drop references right after it is taken.
    pub fn ref_count(&self) -> usize {
        *self.template.counter.lock_count()
    }

    /// Stops the carrier from handing out new references. Calling it again is harmless.
    pub fn freeze(&self) {
        self.shutdown.store(true, Ordering::Release);
    }

    /// Returns `true` once [`Carrier::freeze`] has been called.
    pub fn is_frozen(&self) -> bool {
        self.shutdown.load(Ordering::Acquire)
    }

    /// Moves the carried value out of the shared state.
    ///
    /// Callers must first observe a reference count of zero (under the counter
    /// lock). Every `CarrierRef` releases its handle on the shared state *before*
    /// decrementing that count, so at that point this carrier is the sole owner.
    ///
    /// # Panics
    ///
    /// Panics if anything besides this carrier still holds the shared state.
    fn unwrap_or_panic(self) -> T {
        let arc = self.template;
        assert_eq!(
            Arc::strong_count(&arc),
            1,
            "The carrier should not have more than one outstanding Arc"
        );

        match Arc::try_unwrap(arc) {
            Ok(shared) => shared.target,
            Err(_arc) => panic!("The carrier should not have any outstanding references"),
        }
    }

    /// Blocks until every reference handed out by this carrier has been dropped,
    /// then returns the carried value.
    ///
    /// This does not set the frozen flag, but because it consumes `self`, no new
    /// references can be created while it waits.
    ///
    /// # Panics
    ///
    /// Panics if the counter lock is poisoned, or if the shared state is still held
    /// elsewhere (see `unwrap_or_panic`).
    pub fn wait(self) -> T {
        {
            let counter = &self.template.counter;
            let count = counter
                .condvar
                .wait_while(counter.lock_count(), |count| *count != 0)
                .expect("The carrier lock should not be poisoned");

            assert_eq!(*count, 0);
        }
        self.unwrap_or_panic()
    }

    /// Like [`Carrier::wait`], but gives up after `timeout`.
    ///
    /// Returns `Ok(value)` if the reference count reached zero in time, otherwise
    /// `Err(self)` so the caller can retry.
    ///
    /// # Panics
    ///
    /// Same conditions as [`Carrier::wait`].
    pub fn wait_timeout(self, timeout: Duration) -> Result<T, Self> {
        let remaining = {
            let counter = &self.template.counter;
            // `wait_timeout_while` re-checks the predicate after spurious wakeups
            // and bounds the *total* wait by `timeout`.
            let (count, _timed_out) = counter
                .condvar
                .wait_timeout_while(counter.lock_count(), timeout, |count| *count != 0)
                .expect("The carrier lock should not be poisoned");
            *count
        };

        if remaining == 0 {
            Ok(self.unwrap_or_panic())
        } else {
            Err(self)
        }
    }

    /// Freezes the carrier, then waits for all references to be dropped and
    /// returns the carried value.
    pub fn shutdown(self) -> T {
        self.freeze();
        self.wait()
    }

    /// Freezes the carrier, then waits at most `timeout` for all references to be
    /// dropped. On timeout the (now frozen) carrier is returned in `Err`.
    pub fn shutdown_timeout(self, timeout: Duration) -> Result<T, Self> {
        self.freeze();
        self.wait_timeout(timeout)
    }
}

impl<T> AsRef<T> for Carrier<T> {
    fn as_ref(&self) -> &T {
        &self.template.target
    }
}

impl<T> Deref for Carrier<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        &self.template.target
    }
}

/// State shared between a [`Carrier`] and all of its [`CarrierRef`]s.
#[derive(Debug, Default)]
struct CarrierTarget<T> {
    /// The carried value.
    target: T,
    /// Live-reference counter. It sits behind its own `Arc` so a dropping
    /// `CarrierRef` can still update it after it has released its handle on
    /// this struct.
    counter: Arc<RefCounter>,
}

/// Counts live references and wakes the waiting carrier when the count hits zero.
#[derive(Debug, Default)]
struct RefCounter {
    condvar: Condvar,
    count: Mutex<usize>,
}

impl RefCounter {
    /// Locks the reference count.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned.
    fn lock_count(&self) -> MutexGuard<'_, usize> {
        self.count
            .lock()
            .expect("The carrier lock should not be poisoned")
    }
}

/// One registration in a [`RefCounter`].
///
/// Creating it increments the count. Dropping it decrements the count and, when
/// the count reaches zero, notifies the waiter.
#[derive(Debug)]
struct RefToken {
    counter: Arc<RefCounter>,
}

impl RefToken {
    /// Increments `counter` and returns a token that undoes this on drop.
    fn acquire(counter: &Arc<RefCounter>) -> Self {
        *counter.lock_count() += 1;
        RefToken {
            counter: Arc::clone(counter),
        }
    }
}

impl Drop for RefToken {
    fn drop(&mut self) {
        let mut count = self.counter.lock_count();
        *count -= 1;

        if *count == 0 {
            // Only the owning `Carrier` ever waits on this condvar, and its wait
            // methods consume it, so there is at most one waiter to wake.
            self.counter.condvar.notify_one();
        }
    }
}

/// A counted, shared reference to the value held by a [`Carrier`].
///
/// While any `CarrierRef` created by a carrier is alive, that carrier's
/// [`Carrier::wait`] keeps blocking. Dereferences to `T`.
#[derive(Debug)]
pub struct CarrierRef<T> {
    // Field order is load-bearing: struct fields drop in declaration order, so
    // `inner` gives up its strong count on the shared state *before* `_token`
    // decrements the reference count and wakes the carrier. With the opposite
    // order the carrier could observe a zero count while this handle still
    // pins the shared state, and its uniqueness check would panic.
    inner: Arc<CarrierTarget<T>>,
    _token: RefToken,
}

impl<T> CarrierRef<T> {
    /// Registers a new reference to `inner`.
    fn new(inner: &Arc<CarrierTarget<T>>) -> Self {
        CarrierRef {
            inner: Arc::clone(inner),
            _token: RefToken::acquire(&inner.counter),
        }
    }
}

impl<T: Default> Default for CarrierRef<T> {
    /// Creates a reference to a fresh `T::default()` that belongs to no carrier.
    ///
    /// It is registered with its own counter, so dropping it leaves that counter
    /// balanced instead of underflowing below zero.
    fn default() -> Self {
        CarrierRef::new(&Arc::new(CarrierTarget::default()))
    }
}

impl<T> AsRef<T> for CarrierRef<T> {
    fn as_ref(&self) -> &T {
        &self.inner.target
    }
}

impl<T> Deref for CarrierRef<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        &self.inner.target
    }
}
254
#[cfg(test)]
mod tests {
    use crate::Carrier;
    use std::cell::RefCell;
    use std::time::Duration;

    /// Exercises reference creation (including from another thread), freezing,
    /// a timed-out wait, and a final successful wait once every reference is gone.
    #[test]
    fn test_basics() {
        let carrier = Carrier::new(7usize);
        assert_eq!(*carrier, 7usize);

        let ref_one = carrier.create_ref().unwrap();
        let ref_two = carrier.create_ref().unwrap();
        // Moving the carrier into a spawned thread and back also checks that it is `Send`.
        let (ref_three, carrier) = std::thread::spawn(|| (carrier.create_ref(), carrier))
            .join()
            .expect("Thread creation should never fail");
        let ref_three = ref_three.unwrap();

        assert_eq!(*ref_one, 7usize);
        assert_eq!(*ref_two, 7usize);
        assert_eq!(*ref_three, 7usize);

        carrier.freeze();
        assert!(carrier.is_frozen());
        // Freezing a second time is harmless.
        carrier.freeze();
        assert!(carrier.is_frozen());

        assert!(carrier.create_ref().is_none());
        assert!(carrier.create_ref().is_none());

        assert_eq!(carrier.ref_count(), 3);

        let carrier = carrier.wait_timeout(Duration::from_micros(1)).expect_err(
            "Wait should not be successful \
             since there are outstanding references",
        );

        drop(ref_one);
        assert_eq!(carrier.ref_count(), 2);
        drop(ref_two);
        assert_eq!(carrier.ref_count(), 1);
        drop(ref_three);
        assert_eq!(carrier.ref_count(), 0);
        assert_eq!(carrier.wait(), 7usize);
    }

    /// `wait` must panic when something other than the carrier still holds the
    /// shared state, even though no `CarrierRef` is alive. `expected` pins the
    /// panic to that uniqueness check, so an unrelated panic such as a poisoned
    /// lock does not make the test pass.
    #[test]
    #[should_panic(expected = "outstanding")]
    fn test_panic_outstanding_arc() {
        let carrier = Carrier::new(7usize);
        let _outstanding_ref = carrier.template.clone();

        carrier.wait();
    }

    /// References share one underlying value: a mutation made through one of them
    /// is visible through the others and through the carrier itself.
    #[test]
    fn test_ref() {
        let carrier = Carrier::new(RefCell::new(7usize));
        let ref_one = carrier.create_ref().unwrap();
        let ref_two = carrier.create_ref().unwrap();

        *ref_two.borrow_mut() += 1;
        assert_eq!(8, *ref_one.borrow());
        assert_eq!(8, *carrier.borrow());

        *ref_one.borrow_mut() += 1;
        assert_eq!(9, *ref_two.borrow());
        assert_eq!(9, *carrier.borrow());
    }
}