1use std::{cell::Cell, rc::Rc};
3use std::{fmt, future::Future, future::poll_fn, io, pin::Pin, task::Context, task::Poll};
4
5use ntex_util::task::LocalWaker;
6
7pub fn create<T>() -> (Sender<T>, Receiver<T>) {
9 let inner = Rc::new(Inner {
10 value: Cell::new(None),
11 rx_task: LocalWaker::new(),
12 });
13 let tx = Sender {
14 inner: inner.clone(),
15 };
16 let rx = Receiver { inner };
17 (tx, rx)
18}
19
/// Sending half of a one-shot channel created by [`create`].
///
/// The value is delivered by [`Sender::send`], which consumes the sender.
/// Dropping the sender wakes the receiver's task.
#[derive(Debug)]
pub struct Sender<T> {
    // Slot shared with the paired `Receiver`.
    inner: Rc<Inner<T>>,
}
26
/// Receiving half of a one-shot channel.
///
/// It can be awaited directly (it implements `Future`) or through
/// [`Receiver::recv`]. It resolves to the value sent by the paired [`Sender`],
/// or to an error once no sender remains and no value is pending.
#[derive(Debug)]
#[must_use = "futures do nothing unless polled"]
pub struct Receiver<T> {
    // Slot shared with the paired `Sender` (if any).
    inner: Rc<Inner<T>>,
}
34
// Both halves keep all their state behind an `Rc`, so moving them after pinning
// is harmless. `Rc<_>` is already `Unpin`, so these impls mostly make the intent
// explicit regardless of `T`.
impl<T> Unpin for Receiver<T> {}
impl<T> Unpin for Sender<T> {}
38
/// State shared between a [`Sender`] and its [`Receiver`].
struct Inner<T> {
    // The sent result. It is stored here by `Sender::send` and taken out by the
    // receiver's poll.
    value: Cell<Option<io::Result<T>>>,
    // Waker registered by a pending receiver. The sender wakes it on `send` and
    // on drop.
    rx_task: LocalWaker,
}
43
44impl<T> Sender<T> {
45 pub fn send(self, val: io::Result<T>) -> Result<(), io::Result<T>> {
56 if Rc::strong_count(&self.inner) == 2 {
57 self.inner.value.set(Some(val));
58 self.inner.rx_task.wake();
59 Ok(())
60 } else {
61 Err(val)
62 }
63 }
64}
65
66impl<T> Drop for Sender<T> {
67 fn drop(&mut self) {
68 self.inner.rx_task.wake();
69 }
70}
71
72impl<T> Receiver<T> {
73 pub fn new(val: io::Result<T>) -> Self {
74 let inner = Rc::new(Inner {
75 value: Cell::new(Some(val)),
76 rx_task: LocalWaker::new(),
77 });
78 Receiver { inner }
79 }
80
81 pub async fn recv(&self) -> io::Result<T> {
83 poll_fn(|cx| self.poll_recv(cx)).await
84 }
85
86 fn poll_recv(&self, cx: &mut Context<'_>) -> Poll<io::Result<T>> {
88 if let Some(val) = self.inner.value.take() {
90 return Poll::Ready(val);
91 }
92
93 if Rc::strong_count(&self.inner) == 1 {
95 Poll::Ready(Err(io::Error::other("IO Driver is gone")))
96 } else {
97 self.inner.rx_task.register(cx.waker());
98 Poll::Pending
99 }
100 }
101}
102
103impl<T> Future for Receiver<T> {
104 type Output = io::Result<T>;
105
106 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
107 self.poll_recv(cx)
108 }
109}
110
111impl<T: fmt::Debug> fmt::Debug for Inner<T> {
112 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
113 let val = self.value.take();
114 let result = f.debug_struct("Inner").field("value", &val).finish();
115 self.value.set(val);
116 result
117 }
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    /// Basic send/receive paths through both the `Future` impl and `recv`.
    #[ntex::test]
    async fn test_oneshot() {
        // A value sent before the first poll is delivered when awaited directly...
        let (tx, rx) = create();
        tx.send(Ok("test")).unwrap();
        assert_eq!(rx.await.unwrap(), "test");

        // ...and through `recv`.
        let (tx, rx) = create();
        tx.send(Ok("test")).unwrap();
        assert_eq!(rx.recv().await.unwrap(), "test");

        // Sending to a dropped receiver fails and hands the value back.
        let (tx, rx) = create();
        drop(rx);
        assert_eq!(tx.send(Ok("test")).unwrap_err().unwrap(), "test");

        // Dropping the sender without sending resolves the receiver to an error.
        let (tx, rx) = create::<&'static str>();
        drop(tx);
        assert!(rx.await.is_err());
    }

    /// An `Err` payload from the sender reaches the receiver unchanged.
    #[ntex::test]
    async fn test_error_payload() {
        let (tx, rx) = create::<()>();
        tx.send(Err(io::Error::other("boom"))).unwrap();
        assert_eq!(rx.await.unwrap_err().to_string(), "boom");
    }

    /// A pre-resolved receiver yields its value once, then reports that the
    /// channel is closed (it has no sender).
    #[ntex::test]
    async fn test_receiver_new() {
        let rx = Receiver::new(Ok("ready"));
        assert_eq!(rx.recv().await.unwrap(), "ready");
        assert!(rx.recv().await.is_err());
    }
}