oneshot_handshake/lib.rs
1#![doc = include_str!("../README.md")]
2use std::{fmt::Debug, ptr::NonNull, sync::Mutex};
3
/// An empty marker error signalling cancellation of a [`Handshake`].
///
/// A [`channel`] can only be cancelled by a call to [`Drop::drop`] or [`take`] on one of its
/// handles while no value has been pushed; the remaining handle then reports the cancellation
/// (for example [`Handshake::try_pull`] returns `Err(Cancelled)`).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub struct Cancelled;

impl std::fmt::Display for Cancelled {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("handshake cancelled")
    }
}

// Lets `Cancelled` flow through `?` into `Box<dyn Error>` and other error-handling code.
impl std::error::Error for Cancelled {}
9
/// Payload slot of a channel: either still empty or holding the single pushed value.
#[derive(Debug)]
enum Inner<T> {
    /// Nothing has been pushed yet.
    Unset,
    /// A pushed value waiting to be pulled.
    Set(T),
}
15
16/// A joint sender and receiver for a symmetric one time use channel.
17///
18/// # Examples
19///
20/// Using [`join`]:
21///
22/// ```
23/// let (u, v) = oneshot_handshake::channel::<u8>();
24///
25/// '_task_a: {
26/// let fst = u.join(1, std::ops::Add::add).unwrap();
27/// assert_eq!(fst, None)
28/// }
29///
30/// '_task_b: {
31/// let snd = v.join(2, std::ops::Add::add).unwrap();
32/// assert_eq!(snd, Some(3))
33/// }
34/// ```
35///
36/// Using [`try_push`] and [`try_pull`]:
37///
38/// ```
39/// let (u, v) = oneshot_handshake::channel::<u8>();
40///
41/// let a = u.try_push(3).unwrap();
42/// assert_eq!(a, Ok(()));
43///
44/// let b = v.try_pull().unwrap();
45/// assert_eq!(b, Ok(3))
46/// ```
47///
48/// [`join`]: Handshake::join
49/// [`try_push`]: Handshake::try_push
50/// [`try_pull`]: Handshake::try_pull
51#[derive(PartialEq, Eq, PartialOrd, Ord)]
52pub struct Handshake<T> {
53 common: NonNull<Mutex<Option<Inner<T>>>>
54}
55
56/// Creates a symmetric one time use channel.
57///
58/// Allows each end of the handshake to send or receive information for bi-directional movement of data.
59///
60/// # Examples
61///
62/// Using [`join`]:
63///
64/// ```
65/// let (u, v) = oneshot_handshake::channel::<u8>();
66///
67/// '_task_a: {
68/// let fst = u.join(1, std::ops::Add::add).unwrap();
69/// assert_eq!(fst, None)
70/// }
71///
72/// '_task_b: {
73/// let snd = v.join(2, std::ops::Add::add).unwrap();
74/// assert_eq!(snd, Some(3))
75/// }
76/// ```
77///
78/// Using [`try_push`] and [`try_pull`]:
79///
80/// ```
81/// let (u, v) = oneshot_handshake::channel::<u8>();
82///
83/// let a = u.try_push(3).unwrap();
84/// assert_eq!(a, Ok(()));
85///
86/// let b = v.try_pull().unwrap();
87/// assert_eq!(b, Ok(3))
88/// ```
89///
90/// [`join`]: Handshake::join
91/// [`try_push`]: Handshake::try_push
92/// [`try_pull`]: Handshake::try_pull
93pub fn channel<T>() -> (Handshake<T>, Handshake<T>) {
94 // check expected to be elided during compilation
95 let common = unsafe { NonNull::new_unchecked(Box::into_raw(
96 Box::new(Mutex::new(Some(Inner::Unset)))
97 ))};
98 (Handshake {common}, Handshake {common})
99}
100
101
102impl<T> Handshake<T> {
103
104 /// Creates a channel that has already been pushed to.
105 ///
106 /// The expression:
107 /// ```
108 /// let _ = oneshot_handshake::Handshake::<u8>::wrap(1);
109 /// ```
110 ///
111 /// Is the same as the expression:
112 /// ```
113 /// let _ = {
114 /// let (u, v) = oneshot_handshake::channel::<u8>();
115 /// u.try_push(1).unwrap().unwrap();
116 /// v
117 /// };
118 /// ```
119 pub fn wrap(value: T) -> Handshake<T> {
120 Handshake { common: unsafe {
121 NonNull::new_unchecked(Box::into_raw(
122 Box::new(Mutex::new(Some(Inner::Set(value))))
123 ))
124 } }
125 }
126
127 /// Pulls and pushes at the same time, garunteeing consumption of `self`.
128 ///
129 /// If `self` is [`Unset`] `f` will not be ran and `value` will be stored returning `Ok(None)`,
130 /// if `self` is [`Set`] with some `other` instance then `f` will be called with `other` and `value`
131 /// returning `Ok(return_value)`.
132 ///
133 /// Otherwise on cancellation `Err(value)` will be returned.
134 ///
135 /// If you only need to send or receive `value`, instead call [`try_push`] or [`try_pull`] respectively.
136 ///
137 /// [`try_push`]: Handshake::try_push
138 /// [`try_pull`]: Handshake::try_pull
139 ///
140 /// [`Set`]: Handshake::Set
141 /// [`Unset`]: Handshake::Unset
142 ///
143 /// # Example
144 ///
145 /// ```
146 /// let (u, v) = oneshot_handshake::channel::<u8>();
147 ///
148 /// '_task_a: {
149 /// let fst = u.join(1, std::ops::Add::add).unwrap();
150 /// assert_eq!(fst, None)
151 /// }
152 ///
153 /// '_task_b: {
154 /// let snd = v.join(2, std::ops::Add::add).unwrap();
155 /// assert_eq!(snd, Some(3))
156 /// }
157 /// ```
158 pub fn join<U, F: FnOnce(T, T) -> U>(self, value: T, f: F) -> Result<Option<U>, T> {
159 let common = self.common;
160 let last;
161 let res = '_lock: {
162 let mut lock = unsafe { common.as_ref() }.lock().unwrap();
163 match lock.take() {
164 Some(Inner::Unset) => {
165 // consumes `self`
166 std::mem::forget(self);
167 last = false;
168 let _ = lock.insert(Inner::Set(value));
169 Ok(None)
170 },
171 Some(Inner::Set(other)) => {
172 // consumes `self`
173 std::mem::forget(self);
174 last = true;
175 let _ = lock.insert(Inner::Unset);
176 Ok(Some((other, value)))
177 },
178 None => {
179 // consumes `self`
180 std::mem::forget(self);
181 last = true;
182 Err(value)
183 },
184 }
185 };
186 if last {
187 // last reference, drop pointer
188 drop(unsafe { Box::from_raw(common.as_ptr()) })
189 };
190 // isolate potential panic
191 res.map(|opt| opt.map(|(x, y)| (f)(x, y)))
192 }
193
194 /// Attempts to send a value through the channel.
195 ///
196 /// If `self` is [`Unset`] `value` will be stored returning `Ok(Ok(()))`,
197 /// if `self` is [`Set`] with some `other` instance then pushing will fail
198 /// and `Ok(Err((self, value)))` will be returned.
199 ///
200 /// Otherwise on cancellation `Err(value)` will be returned.
201 ///
202 /// If you are handling `value` symetrically, consider calling [`join`].
203 ///
204 /// [`join`]: Handshake::join
205 ///
206 /// [`Set`]: Handshake::Set
207 /// [`Unset`]: Handshake::Unset
208 ///
209 /// # Example
210 ///
211 /// ```
212 /// let (u, v) = oneshot_handshake::channel::<u8>();
213 ///
214 /// let a = u.try_push(3).unwrap();
215 /// assert_eq!(a, Ok(()));
216 ///
217 /// let b = v.try_pull().unwrap();
218 /// assert_eq!(b, Ok(3))
219 /// ```
220 pub fn try_push(self, value: T) -> Result<Result<(), (Self, T)>, T> {
221 let common = self.common;
222 let last;
223 let res = '_lock: {
224 let mut lock = unsafe { common.as_ref() }.lock().unwrap();
225 match lock.take() {
226 Some(Inner::Unset) => {
227 // consumes `self`
228 std::mem::forget(self);
229 last = false;
230 let _ = lock.insert(Inner::Set(value));
231 Ok(Ok(()))
232 },
233 Some(Inner::Set(other)) => {
234 last = false;
235 let _ = lock.insert(Inner::Set(other));
236 Ok(Err((self, value)))
237 },
238 None => {
239 // consumes `self`
240 std::mem::forget(self);
241 last = true;
242 Err(value)
243 },
244 }
245 };
246 if last {
247 // last reference, drop pointer
248 drop(unsafe { Box::from_raw(common.as_ptr()) })
249 };
250 res
251 }
252
253 /// Attempts to receive a value through the channel.
254 ///
255 /// If `self` is [`Unset`] then pulling will fail returning `Ok(Err(self))`,
256 /// if `self` is [`Set`] with some `value` then `Ok(Ok(value))` will be returned.
257 ///
258 /// Otherwise on cancellation `Err(Cancelled)` will be returned.
259 ///
260 /// If you are handling `value` symetrically, consider calling [`join`].
261 ///
262 /// [`join`]: Handshake::join
263 ///
264 /// [`Set`]: Handshake::Set
265 /// [`Unset`]: Handshake::Unset
266 ///
267 /// # Example
268 ///
269 /// ```
270 /// let (u, v) = oneshot_handshake::channel::<u8>();
271 ///
272 /// let a = u.try_push(3).unwrap();
273 /// assert_eq!(a, Ok(()));
274 ///
275 /// let b = v.try_pull().unwrap();
276 /// assert_eq!(b, Ok(3))
277 /// ```
278 pub fn try_pull(self) -> Result<Result<T, Self>, Cancelled> {
279 let common = self.common;
280 let last;
281 let res = '_lock: {
282 let mut lock = unsafe { common.as_ref() }.lock().unwrap();
283 match lock.take() {
284 Some(Inner::Unset) => {
285 last = false;
286 let _ = lock.insert(Inner::Unset);
287 Ok(Err(self))
288 },
289 Some(Inner::Set(value)) => {
290 // consumes `self`
291 std::mem::forget(self);
292 last = true;
293 let _ = lock.insert(Inner::Unset);
294 Ok(Ok(value))
295 },
296 None => {
297 // consumes `self`
298 std::mem::forget(self);
299 last = true;
300 Err(Cancelled)
301 },
302 }
303 };
304 if last {
305 // last reference, drop pointer
306 drop(unsafe { Box::from_raw(common.as_ptr()) })
307 };
308 res
309 }
310
311 /// Checks the channel to see if there is a value present.
312 ///
313 /// If the channel is cancelled then `Err(Cancelled)` will be returned, otherwise
314 /// a boolean value will be returned indicating whether or not the channel is set.
315 ///
316 /// # Example
317 ///
318 /// ```
319 /// let (u, v) = oneshot_handshake::channel::<u8>();
320 ///
321 /// assert_eq!(v.is_set().unwrap(), false);
322 /// let _ = u.try_push(3).unwrap();
323 /// assert_eq!(v.is_set().unwrap(), true)
324 /// ```
325 pub fn is_set(&self) -> Result<bool, Cancelled> {
326 '_lock: {
327 match &mut* unsafe { self.common.as_ref() }.lock().unwrap() {
328 Some(Inner::Unset) => Ok(false),
329 Some(Inner::Set(_)) => Ok(true),
330 None => Err(Cancelled),
331 }
332 }
333 }
334}
335
336/// Pulls a value "now or never" garunteeing consumption of `self`.
337/// The channel will be cancelled if no value is set.
338///
339/// If you do not handle cancellation on the other side of the handshake
340/// and have no garuntees that both parts will be cancelled in unison then use [`try_pull`] instead.
341///
342/// This function is provided as an alternative to [`Drop::drop`]
343/// that prevents blowing the stack from deeply nested channels.
344///
345/// [`try_pull`]: Handshake::try_pull
346///
347/// # Example
348///
349/// Without using [`take`]:
350///
351/// ```
352/// enum MyRecursiveType {
353/// // recursive channel
354/// Channel(std::mem::ManuallyDrop<oneshot_handshake::Handshake<MyRecursiveType>>),
355/// Data(Box<[u8]>)
356/// }
357///
358/// impl Drop for MyRecursiveType {
359/// // a recursive drop implementaiton is unavoidable
360/// fn drop(&mut self) {
361/// match self {
362/// MyRecursiveType::Channel(channel) => {
363/// let channel = unsafe { std::mem::ManuallyDrop::take(channel) };
364/// // forced to call `Drop::drop` to garuntee consumption
365/// std::mem::drop(channel)
366/// },
367/// MyRecursiveType::Data(_) => ()
368/// };
369/// }
370/// }
371/// ```
372///
373/// Using [`take`]:
374///
375/// ```
376/// enum MyRecursiveType {
377/// // recursive channel
378/// Channel(std::mem::ManuallyDrop<oneshot_handshake::Handshake<MyRecursiveType>>),
379/// Data(Box<[u8]>)
380/// }
381///
382/// impl Drop for MyRecursiveType {
383/// fn drop(&mut self) {
384/// // handling dropping by ref
385/// match self {
386/// MyRecursiveType::Channel(channel) => {
387/// let channel = unsafe { std::mem::ManuallyDrop::take(channel) };
388/// // handling dropping by value
389/// let mut next = oneshot_handshake::take(channel);
390/// // iterative drop
391/// while let Some(mut obj) = next.take() {
392/// match &mut obj {
393/// MyRecursiveType::Channel(channel) =>
394/// next = oneshot_handshake::take(unsafe {
395/// std::mem::ManuallyDrop::take(channel)
396/// }), // avoids recursion
397/// MyRecursiveType::Data(_) => (),
398/// }
399/// }
400/// },
401/// MyRecursiveType::Data(_) => ()
402/// };
403/// }
404/// }
405/// ```
406pub fn take<T>(handshake: Handshake<T>) -> Option<T> {
407 let value;
408 if match unsafe { handshake.common.as_ref() }.lock().unwrap().take() {
409 Some(Inner::Unset) => { value = None; false },
410 Some(Inner::Set(inner_value)) => { value = Some(inner_value); true },
411 None => {value = None; true },
412 } {
413 // last reference, drop pointer
414 drop(unsafe { Box::from_raw(handshake.common.as_ptr()) })
415 };
416 // avoid double drop
417 std::mem::forget(handshake);
418 value
419}
420
421impl<T> Drop for Handshake<T> {
422 fn drop(&mut self) {
423 if match unsafe { self.common.as_ref() }.lock().unwrap().take() {
424 Some(Inner::Unset) => false,
425 Some(Inner::Set(value)) => { drop(value); true },
426 None => true,
427 } {
428 // last reference, drop pointer
429 drop(unsafe { Box::from_raw(self.common.as_ptr()) })
430 }
431 }
432}
433
434unsafe impl<T: Send> Sync for Handshake<T> {}
435
436unsafe impl<T: Send> Send for Handshake<T> {}
437
438impl<T: Debug> Debug for Handshake<T> {
439 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
440 f.debug_struct("Handshake").field("common", unsafe { self.common.as_ref() }).finish()
441 }
442}
443
#[cfg(test)]
mod test {
    use std::convert::identity;
    use super::*;

    /// Sets its flag when dropped, so tests can observe that the channel destroyed a value.
    #[derive(Debug)]
    struct Loud<'a> {
        flag: &'a mut bool,
    }

    impl Drop for Loud<'_> {
        fn drop(&mut self) {
            *self.flag = true;
        }
    }

    #[test]
    fn drop_test() {
        let (u, v) = channel::<()>();
        drop(u);
        drop(v);

        let (u, v) = channel::<()>();
        drop(v);
        drop(u)
    }

    #[test]
    fn push_drop_test() {
        let mut dropped = false;
        let (u, v) = channel::<Loud>();
        u.try_push(Loud { flag: &mut dropped }).unwrap().unwrap();
        drop(v);

        assert!(dropped);
    }

    #[test]
    fn wrap_drop_test() {
        let mut dropped = false;
        let u = Handshake::wrap(Loud { flag: &mut dropped });
        drop(u);

        assert!(dropped);
    }

    #[test]
    fn pull_test() {
        let (u, v) = channel::<()>();
        assert_eq!(u.try_pull(), Ok(Err(v)));

        let (u, v) = channel::<()>();
        assert_eq!(v.try_pull(), Ok(Err(u)))
    }

    #[test]
    fn push_test() {
        let (u, v) = channel::<()>();
        assert_eq!(u.try_push(()), Ok(Ok(())));
        drop(v);

        let (u, v) = channel::<()>();
        assert_eq!(v.try_push(()), Ok(Ok(())));
        drop(u)
    }

    #[test]
    fn double_push_test() {
        let (u, v) = channel::<()>();
        u.try_push(()).unwrap().unwrap();
        drop(v.try_push(()).unwrap().err().unwrap());

        let (u, v) = channel::<()>();
        v.try_push(()).unwrap().unwrap();
        drop(u.try_push(()).unwrap().err().unwrap())
    }

    #[test]
    fn pull_cancel_test() {
        let (u, v) = channel::<()>();
        drop(u);
        assert_eq!(v.try_pull(), Err(Cancelled));

        let (u, v) = channel::<()>();
        drop(v);
        assert_eq!(u.try_pull(), Err(Cancelled));
    }

    #[test]
    fn push_cancel_test() {
        let (u, v) = channel::<()>();
        drop(u);
        assert_eq!(v.try_push(()), Err(()));

        let (u, v) = channel::<()>();
        drop(v);
        assert_eq!(u.try_push(()), Err(()));
    }

    #[test]
    fn push_pull_test() {
        let (u, v) = channel::<()>();
        u.try_push(()).unwrap().unwrap();
        v.try_pull().unwrap().unwrap();

        let (u, v) = channel::<()>();
        v.try_push(()).unwrap().unwrap();
        u.try_pull().unwrap().unwrap()
    }

    #[test]
    fn wrap_pull_test() {
        let u = Handshake::wrap(());
        u.try_pull().unwrap().unwrap()
    }

    #[test]
    fn join_test() {
        let (u, v) = channel::<()>();
        assert_eq!(u.join((), |_, _| ()).unwrap(), None);
        assert_eq!(v.join((), |_, _| ()).unwrap(), Some(()));

        let (u, v) = channel::<()>();
        assert_eq!(v.join((), |_, _| ()).unwrap(), None);
        assert_eq!(u.join((), |_, _| ()).unwrap(), Some(()))
    }

    #[test]
    fn join_cancel_test() {
        let (u, v) = channel::<u8>();
        drop(u);
        assert_eq!(v.join(5, |x, y| x + y), Err(5));
    }

    #[test]
    fn wrap_join_test() {
        // the waiting value is passed first, the joiner's value second
        let u = Handshake::wrap(2u8);
        assert_eq!(u.join(3, |x, y| x * 10 + y), Ok(Some(23)));
    }

    #[test]
    fn is_set_test() {
        let (u, v) = channel::<u8>();
        assert_eq!(v.is_set(), Ok(false));
        u.try_push(1).unwrap().unwrap();
        assert_eq!(v.is_set(), Ok(true));

        let (u, v) = channel::<u8>();
        drop(u);
        assert_eq!(v.is_set(), Err(Cancelled));
    }

    #[test]
    fn take_test() {
        let (u, v) = channel::<u8>();
        u.try_push(7).unwrap().unwrap();
        assert_eq!(take(v), Some(7));

        // taking an empty channel cancels it for the other end
        let (u, v) = channel::<u8>();
        assert_eq!(take(u), None);
        assert_eq!(v.try_pull(), Err(Cancelled));

        assert_eq!(take(Handshake::wrap(9u8)), Some(9));
    }

    #[test]
    fn collision_check() {
        use rand::prelude::*;
        const N: usize = 64;

        let mut left: Vec<Handshake<usize>> = vec![];
        let mut right: Vec<Handshake<usize>> = vec![];
        for _ in 0..N {
            let (u, v) = channel::<usize>();
            left.push(u);
            right.push(v)
        }
        let mut rng = rand::thread_rng();
        left.shuffle(&mut rng);
        right.shuffle(&mut rng);
        let left_thread = std::thread::spawn(|| left
            .into_iter()
            .enumerate()
            .map(|(n, u)| {u.join(n, |x, y| (x, y)).unwrap()})
            .filter_map(identity).collect::<Vec<(usize, usize)>>()
        );
        let right_thread = std::thread::spawn(|| right
            .into_iter()
            .enumerate()
            .map(|(n, v)| {v.join(n, |x, y| (x, y)).unwrap()})
            .filter_map(identity).collect::<Vec<(usize, usize)>>()
        );
        let total = left_thread.join().unwrap().len() + right_thread.join().unwrap().len();
        assert_eq!(total, N)
    }
}