ntex_net/
channel.rs

1//! A one-shot, futures-aware channel.
2use std::{cell::Cell, rc::Rc};
3use std::{fmt, future::Future, future::poll_fn, io, pin::Pin, task::Context, task::Poll};
4
5use ntex_util::task::LocalWaker;
6
7/// Creates a new futures-aware, one-shot channel.
8pub fn create<T>() -> (Sender<T>, Receiver<T>) {
9    let inner = Rc::new(Inner {
10        value: Cell::new(None),
11        rx_task: LocalWaker::new(),
12    });
13    let tx = Sender {
14        inner: inner.clone(),
15    };
16    let rx = Receiver { inner };
17    (tx, rx)
18}
19
20#[derive(Debug)]
21/// Represents the completion half of a oneshot through which the result of a
22/// computation is signaled.
23pub struct Sender<T> {
24    inner: Rc<Inner<T>>,
25}
26
27#[derive(Debug)]
28/// A future representing the completion of a computation happening elsewhere in
29/// memory.
30#[must_use = "futures do nothing unless polled"]
31pub struct Receiver<T> {
32    inner: Rc<Inner<T>>,
33}
34
35// The channels do not ever project Pin to the inner T
36impl<T> Unpin for Receiver<T> {}
37impl<T> Unpin for Sender<T> {}
38
39struct Inner<T> {
40    value: Cell<Option<io::Result<T>>>,
41    rx_task: LocalWaker,
42}
43
44impl<T> Sender<T> {
45    /// Completes this oneshot with a successful result.
46    ///
47    /// This function will consume `self` and indicate to the other end, the
48    /// `Receiver`, that the error provided is the result of the computation this
49    /// represents.
50    ///
51    /// If the value is successfully enqueued for the remote end to receive,
52    /// then `Ok(())` is returned. If the receiving end was dropped before
53    /// this function was called, however, then `Err` is returned with the value
54    /// provided.
55    pub fn send(self, val: io::Result<T>) -> Result<(), io::Result<T>> {
56        if Rc::strong_count(&self.inner) == 2 {
57            self.inner.value.set(Some(val));
58            self.inner.rx_task.wake();
59            Ok(())
60        } else {
61            Err(val)
62        }
63    }
64}
65
66impl<T> Drop for Sender<T> {
67    fn drop(&mut self) {
68        self.inner.rx_task.wake();
69    }
70}
71
72impl<T> Receiver<T> {
73    pub fn new(val: io::Result<T>) -> Self {
74        let inner = Rc::new(Inner {
75            value: Cell::new(Some(val)),
76            rx_task: LocalWaker::new(),
77        });
78        Receiver { inner }
79    }
80
81    /// Wait until the oneshot is ready and return value
82    pub async fn recv(&self) -> io::Result<T> {
83        poll_fn(|cx| self.poll_recv(cx)).await
84    }
85
86    /// Polls the oneshot to determine if value is ready
87    fn poll_recv(&self, cx: &mut Context<'_>) -> Poll<io::Result<T>> {
88        // If we've got a value, then skip the logic below as we're done.
89        if let Some(val) = self.inner.value.take() {
90            return Poll::Ready(val);
91        }
92
93        // Check if sender is dropped and return error if it is.
94        if Rc::strong_count(&self.inner) == 1 {
95            Poll::Ready(Err(io::Error::other("IO Driver is gone")))
96        } else {
97            self.inner.rx_task.register(cx.waker());
98            Poll::Pending
99        }
100    }
101}
102
103impl<T> Future for Receiver<T> {
104    type Output = io::Result<T>;
105
106    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
107        self.poll_recv(cx)
108    }
109}
110
111impl<T: fmt::Debug> fmt::Debug for Inner<T> {
112    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
113        let val = self.value.take();
114        let result = f.debug_struct("Inner").field("value", &val).finish();
115        self.value.set(val);
116        result
117    }
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    #[ntex::test]
    async fn test_oneshot() {
        // Debug output names the channel halves.
        let (tx, rx) = create::<&'static str>();
        assert!(format!("{tx:?}").contains("Sender"));
        assert!(format!("{rx:?}").contains("Receiver"));

        // A sent value is observed by awaiting the receiver directly...
        tx.send(Ok("test")).unwrap();
        assert_eq!(rx.await.unwrap(), "test");

        // ...and through `recv`.
        let (tx, rx) = create();
        tx.send(Ok("test")).unwrap();
        assert_eq!(rx.recv().await.unwrap(), "test");

        // Sending after the receiver is gone hands the value back.
        let (tx, rx) = create();
        drop(rx);
        assert!(matches!(tx.send(Ok("test")), Err(Ok("test"))));

        // Errors are delivered to the receiver unchanged.
        let (tx, rx) = create::<&'static str>();
        tx.send(Err(io::Error::new(io::ErrorKind::TimedOut, "timeout")))
            .unwrap();
        assert_eq!(rx.await.unwrap_err().kind(), io::ErrorKind::TimedOut);

        // Dropping the sender without sending fails the receiver...
        let (tx, rx) = create::<&'static str>();
        drop(tx);
        assert!(rx.await.is_err());

        // ...and the same holds through `recv`.
        let (tx, rx) = create::<&'static str>();
        drop(tx);
        assert!(rx.recv().await.is_err());
    }

    #[ntex::test]
    async fn test_ready_receiver() {
        let rx = Receiver::new(Ok("ready"));
        assert_eq!(rx.recv().await.unwrap(), "ready");
        // The value can be taken only once; afterwards no sender remains.
        assert!(rx.recv().await.is_err());
    }
}
155}