oneshot_handshake/
lib.rs

1#![doc = include_str!("../README.md")]
2use std::{fmt::Debug, ptr::NonNull, sync::Mutex};
3
/// An empty struct signalling cancellation for [`Handshake`].
///
/// A [`channel`] can only be cancelled by a call to [`Drop::drop`] or [`take`].
///
/// `Cancelled` implements [`std::error::Error`], so it can be propagated with `?`
/// into `Box<dyn std::error::Error>` and similar error containers.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
pub struct Cancelled;

impl std::fmt::Display for Cancelled {
    /// Writes a short, human readable description of the cancellation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("handshake cancelled")
    }
}

impl std::error::Error for Cancelled {}
9
/// Contents of the shared slot while the channel has not been cancelled.
#[derive(Debug)]
enum Inner<T> {
    /// Nothing has been stored in the slot.
    Unset,
    /// A value is stored in the slot, waiting to be taken out.
    Set(T),
}
15
16/// A joint sender and receiver for a symmetric one time use channel.
17/// 
18/// # Examples
19/// 
20/// Using [`join`]:
21/// 
22/// ```
23/// let (u, v) = oneshot_handshake::channel::<u8>();
24/// 
25/// '_task_a: {
26///     let fst = u.join(1, std::ops::Add::add).unwrap();
27///     assert_eq!(fst, None)
28/// }
29///
30/// '_task_b: {
31///     let snd = v.join(2, std::ops::Add::add).unwrap();
32///     assert_eq!(snd, Some(3))
33/// }
34/// ```
35/// 
36/// Using [`try_push`] and [`try_pull`]:
37/// 
38/// ```
39/// let (u, v) = oneshot_handshake::channel::<u8>();
40/// 
41/// let a = u.try_push(3).unwrap();
42/// assert_eq!(a, Ok(()));
43///
44/// let b = v.try_pull().unwrap();
45/// assert_eq!(b, Ok(3))
46/// ```
47/// 
48/// [`join`]: Handshake::join
49/// [`try_push`]: Handshake::try_push
50/// [`try_pull`]: Handshake::try_pull
51#[derive(PartialEq, Eq, PartialOrd, Ord)]
52pub struct Handshake<T> {
53    common: NonNull<Mutex<Option<Inner<T>>>>
54}
55
56/// Creates a symmetric one time use channel.
57/// 
58/// Allows each end of the handshake to send or receive information for bi-directional movement of data.
59/// 
60/// # Examples
61/// 
62/// Using [`join`]:
63/// 
64/// ```
65/// let (u, v) = oneshot_handshake::channel::<u8>();
66/// 
67/// '_task_a: {
68///     let fst = u.join(1, std::ops::Add::add).unwrap();
69///     assert_eq!(fst, None)
70/// }
71///
72/// '_task_b: {
73///     let snd = v.join(2, std::ops::Add::add).unwrap();
74///     assert_eq!(snd, Some(3))
75/// }
76/// ```
77/// 
78/// Using [`try_push`] and [`try_pull`]:
79/// 
80/// ```
81/// let (u, v) = oneshot_handshake::channel::<u8>();
82/// 
83/// let a = u.try_push(3).unwrap();
84/// assert_eq!(a, Ok(()));
85///
86/// let b = v.try_pull().unwrap();
87/// assert_eq!(b, Ok(3))
88/// ```
89/// 
90/// [`join`]: Handshake::join
91/// [`try_push`]: Handshake::try_push
92/// [`try_pull`]: Handshake::try_pull
93pub fn channel<T>() -> (Handshake<T>, Handshake<T>) {
94    // check expected to be elided during compilation
95    let common = unsafe { NonNull::new_unchecked(Box::into_raw(
96        Box::new(Mutex::new(Some(Inner::Unset)))
97    ))};
98    (Handshake {common}, Handshake {common})
99}
100
101
102impl<T> Handshake<T> {
103
104    /// Creates a channel that has already been pushed to.
105    /// 
106    /// The expression:
107    /// ```
108    /// let _ = oneshot_handshake::Handshake::<u8>::wrap(1);
109    /// ```
110    /// 
111    /// Is the same as the expression:
112    /// ```
113    /// let _ = {
114    ///     let (u, v) = oneshot_handshake::channel::<u8>();
115    ///     u.try_push(1).unwrap().unwrap();
116    ///     v
117    /// };
118    /// ```
119    pub fn wrap(value: T) -> Handshake<T> {
120        Handshake { common: unsafe {
121            NonNull::new_unchecked(Box::into_raw(
122                Box::new(Mutex::new(Some(Inner::Set(value))))
123            ))
124        } }
125    }
126
127    /// Pulls and pushes at the same time, garunteeing consumption of `self`.
128    /// 
129    /// If `self` is [`Unset`] `f` will not be ran and `value` will be stored returning `Ok(None)`,
130    /// if `self` is [`Set`] with some `other` instance then `f` will be called with `other` and `value`
131    /// returning `Ok(return_value)`.
132    /// 
133    /// Otherwise on cancellation `Err(value)` will be returned.
134    /// 
135    /// If you only need to send or receive `value`, instead call [`try_push`] or [`try_pull`] respectively.
136    /// 
137    /// [`try_push`]: Handshake::try_push
138    /// [`try_pull`]: Handshake::try_pull
139    /// 
140    /// [`Set`]: Handshake::Set
141    /// [`Unset`]: Handshake::Unset
142    /// 
143    /// # Example
144    /// 
145    /// ```
146    /// let (u, v) = oneshot_handshake::channel::<u8>();
147    /// 
148    /// '_task_a: {
149    ///     let fst = u.join(1, std::ops::Add::add).unwrap();
150    ///     assert_eq!(fst, None)
151    /// }
152    ///
153    /// '_task_b: {
154    ///     let snd = v.join(2, std::ops::Add::add).unwrap();
155    ///     assert_eq!(snd, Some(3))
156    /// }
157    /// ```
158    pub fn join<U, F: FnOnce(T, T) -> U>(self, value: T, f: F) -> Result<Option<U>, T> {
159        let common = self.common;
160        let last;
161        let res = '_lock: {
162            let mut lock = unsafe { common.as_ref() }.lock().unwrap();
163            match lock.take() {
164                Some(Inner::Unset) => {
165                    // consumes `self`
166                    std::mem::forget(self);
167                    last = false;
168                    let _ = lock.insert(Inner::Set(value));
169                    Ok(None)
170                },
171                Some(Inner::Set(other)) => {
172                    // consumes `self`
173                    std::mem::forget(self);
174                    last = true;
175                    let _ = lock.insert(Inner::Unset);
176                    Ok(Some((other, value)))
177                },
178                None => {
179                    // consumes `self`
180                    std::mem::forget(self);
181                    last = true;
182                    Err(value)
183                },
184            }
185        };
186        if last {
187            // last reference, drop pointer
188            drop(unsafe { Box::from_raw(common.as_ptr()) })
189        };
190        // isolate potential panic
191        res.map(|opt| opt.map(|(x, y)| (f)(x, y)))
192    }
193
194    /// Attempts to send a value through the channel.
195    /// 
196    /// If `self` is [`Unset`] `value` will be stored returning `Ok(Ok(()))`,
197    /// if `self` is [`Set`] with some `other` instance then pushing will fail
198    /// and `Ok(Err((self, value)))` will be returned.
199    /// 
200    /// Otherwise on cancellation `Err(value)` will be returned.
201    /// 
202    /// If you are handling `value` symetrically, consider calling [`join`].
203    /// 
204    /// [`join`]: Handshake::join
205    /// 
206    /// [`Set`]: Handshake::Set
207    /// [`Unset`]: Handshake::Unset
208    /// 
209    /// # Example
210    /// 
211    /// ```
212    /// let (u, v) = oneshot_handshake::channel::<u8>();
213    /// 
214    /// let a = u.try_push(3).unwrap();
215    /// assert_eq!(a, Ok(()));
216    ///
217    /// let b = v.try_pull().unwrap();
218    /// assert_eq!(b, Ok(3))
219    /// ```
220    pub fn try_push(self, value: T) -> Result<Result<(), (Self, T)>, T> {
221        let common = self.common;
222        let last;
223        let res = '_lock: {
224            let mut lock = unsafe { common.as_ref() }.lock().unwrap();
225            match lock.take() {
226                Some(Inner::Unset) => {
227                    // consumes `self`
228                    std::mem::forget(self);
229                    last = false;
230                    let _ = lock.insert(Inner::Set(value));
231                    Ok(Ok(()))
232                },
233                Some(Inner::Set(other)) => {
234                    last = false;
235                    let _ = lock.insert(Inner::Set(other));
236                    Ok(Err((self, value)))
237                },
238                None => {
239                    // consumes `self`
240                    std::mem::forget(self);
241                    last = true;
242                    Err(value)
243                },
244            }
245        };
246        if last {
247            // last reference, drop pointer
248            drop(unsafe { Box::from_raw(common.as_ptr()) })
249        };
250        res
251    }
252
253    /// Attempts to receive a value through the channel.
254    /// 
255    /// If `self` is [`Unset`] then pulling will fail returning `Ok(Err(self))`,
256    /// if `self` is [`Set`] with some `value` then `Ok(Ok(value))` will be returned.
257    /// 
258    /// Otherwise on cancellation `Err(Cancelled)` will be returned.
259    /// 
260    /// If you are handling `value` symetrically, consider calling [`join`].
261    /// 
262    /// [`join`]: Handshake::join
263    /// 
264    /// [`Set`]: Handshake::Set
265    /// [`Unset`]: Handshake::Unset
266    /// 
267    /// # Example
268    /// 
269    /// ```
270    /// let (u, v) = oneshot_handshake::channel::<u8>();
271    /// 
272    /// let a = u.try_push(3).unwrap();
273    /// assert_eq!(a, Ok(()));
274    ///
275    /// let b = v.try_pull().unwrap();
276    /// assert_eq!(b, Ok(3))
277    /// ```
278    pub fn try_pull(self) -> Result<Result<T, Self>, Cancelled> {
279        let common = self.common;
280        let last;
281        let res = '_lock: {
282            let mut lock = unsafe { common.as_ref() }.lock().unwrap();
283            match lock.take() {
284                Some(Inner::Unset) => {
285                    last = false;
286                    let _ = lock.insert(Inner::Unset);
287                    Ok(Err(self))
288                },
289                Some(Inner::Set(value)) => {
290                    // consumes `self`
291                    std::mem::forget(self);
292                    last = true;
293                    let _ = lock.insert(Inner::Unset);
294                    Ok(Ok(value))
295                },
296                None => {
297                    // consumes `self`
298                    std::mem::forget(self);
299                    last = true;
300                    Err(Cancelled)
301                },
302            }
303        };
304        if last {
305            // last reference, drop pointer
306            drop(unsafe { Box::from_raw(common.as_ptr()) })
307        };
308        res
309    }
310
311    /// Checks the channel to see if there is a value present.
312    /// 
313    /// If the channel is cancelled then `Err(Cancelled)` will be returned, otherwise
314    /// a boolean value will be returned indicating whether or not the channel is set.
315    /// 
316    /// # Example
317    /// 
318    /// ```
319    /// let (u, v) = oneshot_handshake::channel::<u8>();
320    /// 
321    /// assert_eq!(v.is_set().unwrap(), false);
322    /// let _ = u.try_push(3).unwrap();
323    /// assert_eq!(v.is_set().unwrap(), true)
324    /// ```
325    pub fn is_set(&self) -> Result<bool, Cancelled> {
326        '_lock: {
327            match &mut* unsafe { self.common.as_ref() }.lock().unwrap() {
328                Some(Inner::Unset) => Ok(false),
329                Some(Inner::Set(_)) => Ok(true),
330                None => Err(Cancelled),
331            }
332        }
333    }
334}
335
336/// Pulls a value "now or never" garunteeing consumption of `self`.
337/// The channel will be cancelled if no value is set.
338/// 
339/// If you do not handle cancellation on the other side of the handshake
340/// and have no garuntees that both parts will be cancelled in unison then use [`try_pull`] instead.
341/// 
342/// This function is provided as an alternative to [`Drop::drop`]
343/// that prevents blowing the stack from deeply nested channels.
344/// 
345/// [`try_pull`]: Handshake::try_pull
346/// 
347/// # Example
348/// 
349/// Without using [`take`]:
350/// 
351/// ```
352/// enum MyRecursiveType {
353///     // recursive channel
354///     Channel(std::mem::ManuallyDrop<oneshot_handshake::Handshake<MyRecursiveType>>),
355///     Data(Box<[u8]>)
356/// }
357/// 
358/// impl Drop for MyRecursiveType {
359///     // a recursive drop implementaiton is unavoidable
360///     fn drop(&mut self) {
361///         match self {
362///             MyRecursiveType::Channel(channel) => {
363///                 let channel = unsafe { std::mem::ManuallyDrop::take(channel) };
364///                 // forced to call `Drop::drop` to garuntee consumption
365///                 std::mem::drop(channel)
366///             },
367///             MyRecursiveType::Data(_) => ()
368///         };
369///     }
370/// }
371/// ```
372/// 
373/// Using [`take`]:
374/// 
375/// ```
376/// enum MyRecursiveType {
377///     // recursive channel
378///     Channel(std::mem::ManuallyDrop<oneshot_handshake::Handshake<MyRecursiveType>>),
379///     Data(Box<[u8]>)
380/// }
381/// 
382/// impl Drop for MyRecursiveType {
383///     fn drop(&mut self) {
384///         // handling dropping by ref
385///         match self {
386///             MyRecursiveType::Channel(channel) => {
387///                 let channel = unsafe { std::mem::ManuallyDrop::take(channel) };
388///                 // handling dropping by value
389///                 let mut next = oneshot_handshake::take(channel);
390///                 // iterative drop
391///                 while let Some(mut obj) = next.take() {
392///                     match &mut obj {
393///                         MyRecursiveType::Channel(channel) =>
394///                             next = oneshot_handshake::take(unsafe {
395///                                 std::mem::ManuallyDrop::take(channel) 
396///                             }), // avoids recursion
397///                         MyRecursiveType::Data(_) => (),
398///                     }
399///                 }
400///             },
401///             MyRecursiveType::Data(_) => ()
402///         };
403///     }
404/// }
405/// ```
406pub fn take<T>(handshake: Handshake<T>) -> Option<T> {
407    let value;
408    if match unsafe { handshake.common.as_ref() }.lock().unwrap().take() {
409        Some(Inner::Unset) => { value = None; false },
410        Some(Inner::Set(inner_value)) => { value = Some(inner_value); true },
411        None => {value = None; true },
412    } {
413        // last reference, drop pointer
414        drop(unsafe { Box::from_raw(handshake.common.as_ptr()) })
415    };
416    // avoid double drop
417    std::mem::forget(handshake);
418    value
419}
420
421impl<T> Drop for Handshake<T> {
422    fn drop(&mut self) {
423        if match unsafe { self.common.as_ref() }.lock().unwrap().take() {
424            Some(Inner::Unset) => false,
425            Some(Inner::Set(value)) => { drop(value); true },
426            None => true,
427        } {
428            // last reference, drop pointer
429            drop(unsafe { Box::from_raw(self.common.as_ptr()) })
430        }
431    }
432}
433
// SAFETY (rationale as far as this file shows): the only operations reachable
// through `&Handshake<T>` are `is_set`, the `Debug` impl and the pointer
// comparisons. `is_set` locks the mutex before looking at the slot and the
// comparisons never dereference the pointer. `Debug` hands the `Mutex` itself to
// the formatter; presumably `Mutex`'s own `Debug` impl only reads the data under
// the lock — NOTE(review): confirm against the std documentation.
unsafe impl<T: Send> Sync for Handshake<T> {}

// SAFETY (rationale as far as this file shows): a handle is a pointer to a heap
// allocated `Mutex`, and every access to the stored `T` in this file goes through
// that mutex. Moving a handle to another thread can move the `T` with it, which is
// why `T: Send` is required.
unsafe impl<T: Send> Send for Handshake<T> {}
437
438impl<T: Debug> Debug for Handshake<T> {
439    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
440        f.debug_struct("Handshake").field("common", unsafe { self.common.as_ref() }).finish()
441    }
442}
443
#[cfg(test)]
mod test {
    use super::*;

    /// Sets its flag when dropped so a test can observe that the destructor ran.
    #[derive(Debug)]
    struct Loud<'a> {
        flag: &'a mut bool
    }

    impl<'a> Drop for Loud<'a> {
        fn drop(&mut self) {
            *self.flag = true;
        }
    }

    /// Both ends can be dropped in either order.
    #[test]
    fn drop_test() {
        {
            let (u, v) = channel::<()>();
            drop(u);
            drop(v);
        }
        {
            let (u, v) = channel::<()>();
            drop(v);
            drop(u);
        }
    }

    /// Dropping the receiving end drops a value that was pushed but never pulled.
    #[test]
    fn push_drop_test() {
        let mut dropped = false;
        let (u, v) = channel::<Loud>();
        let pushed = u.try_push(Loud { flag: &mut dropped });
        pushed.unwrap().unwrap();
        drop(v);

        assert!(dropped);
    }

    /// Dropping a wrapped handshake drops the wrapped value.
    #[test]
    fn wrap_drop_test() {
        let mut dropped = false;
        let wrapped = Handshake::wrap(Loud { flag: &mut dropped });
        drop(wrapped);

        assert!(dropped);
    }

    /// Pulling from an empty channel hands the pulling end back.
    #[test]
    fn pull_test() {
        let (u, v) = channel::<()>();
        let pulled = u.try_pull();
        assert_eq!(pulled, Ok(Err(v)));

        let (u, v) = channel::<()>();
        let pulled = v.try_pull();
        assert_eq!(pulled, Ok(Err(u)));
    }

    /// Pushing into an empty channel succeeds from either end.
    #[test]
    fn push_test() {
        let (u, v) = channel::<()>();
        let pushed = u.try_push(());
        assert_eq!(pushed, Ok(Ok(())));
        drop(v);

        let (u, v) = channel::<()>();
        let pushed = v.try_push(());
        assert_eq!(pushed, Ok(Ok(())));
        drop(u);
    }

    /// A second push is rejected and returns the end together with the value.
    #[test]
    fn double_push_test() {
        let (u, v) = channel::<()>();
        u.try_push(()).unwrap().unwrap();
        let rejected = v.try_push(()).unwrap().err().unwrap();
        drop(rejected);

        let (u, v) = channel::<()>();
        v.try_push(()).unwrap().unwrap();
        let rejected = u.try_push(()).unwrap().err().unwrap();
        drop(rejected);
    }

    /// Pulling after the other end was dropped reports cancellation.
    #[test]
    fn pull_cancel_test() {
        let (u, v) = channel::<()>();
        drop(u);
        assert_eq!(v.try_pull(), Err(Cancelled));

        let (u, v) = channel::<()>();
        drop(v);
        assert_eq!(u.try_pull(), Err(Cancelled));
    }

    /// Pushing after the other end was dropped returns the value as an error.
    #[test]
    fn push_cancel_test() {
        let (u, v) = channel::<()>();
        drop(u);
        assert_eq!(v.try_push(()), Err(()));

        let (u, v) = channel::<()>();
        drop(v);
        assert_eq!(u.try_push(()), Err(()));
    }

    /// A pushed value can be pulled from the opposite end.
    #[test]
    fn push_pull_test() {
        let (u, v) = channel::<()>();
        u.try_push(()).unwrap().unwrap();
        v.try_pull().unwrap().unwrap();

        let (u, v) = channel::<()>();
        v.try_push(()).unwrap().unwrap();
        u.try_pull().unwrap().unwrap();
    }

    /// A wrapped value can be pulled straight away.
    #[test]
    fn wrap_pull_test() {
        let wrapped = Handshake::wrap(());
        wrapped.try_pull().unwrap().unwrap();
    }

    /// The first `join` stores its value, the second one runs the closure.
    #[test]
    fn join_test() {
        let (u, v) = channel::<()>();
        assert_eq!(u.join((), |_, _| ()).unwrap(), None);
        assert_eq!(v.join((), |_, _| ()).unwrap(), Some(()));

        let (u, v) = channel::<()>();
        assert_eq!(v.join((), |_, _| ()).unwrap(), None);
        assert_eq!(u.join((), |_, _| ()).unwrap(), Some(()));
    }

    /// Joins shuffled ends from two threads; every channel must produce exactly
    /// one combined pair across both threads.
    #[test]
    fn collision_check() {
        use rand::prelude::*;
        const N: usize = 64;

        /// Joins every end with its position and keeps the pairs that completed.
        fn join_all(ends: Vec<Handshake<usize>>) -> Vec<(usize, usize)> {
            ends.into_iter()
                .enumerate()
                .filter_map(|(n, end)| end.join(n, |x, y| (x, y)).unwrap())
                .collect()
        }

        let mut left: Vec<Handshake<usize>> = vec![];
        let mut right: Vec<Handshake<usize>> = vec![];
        for _ in 0..N {
            let (u, v) = channel::<usize>();
            left.push(u);
            right.push(v);
        }

        let mut rng = rand::thread_rng();
        left.shuffle(&mut rng);
        right.shuffle(&mut rng);

        let left_thread = std::thread::spawn(move || join_all(left));
        let right_thread = std::thread::spawn(move || join_all(right));

        let left_pairs = left_thread.join().unwrap();
        let right_pairs = right_thread.join().unwrap();
        assert_eq!(left_pairs.len() + right_pairs.len(), N);
    }
}