async_compatibility_layer/channel/
bounded.rs

1use std::pin::Pin;
2
3use futures::Stream;
4
/// Channel primitives backed by `tokio::sync::mpsc`; compiled only when
/// `async_channel_impl = "tokio"` is set.
#[cfg(async_channel_impl = "tokio")]
mod inner {
    pub use tokio::sync::mpsc::error::{SendError, TryRecvError, TrySendError};

    use tokio::sync::mpsc::{Receiver as InnerReceiver, Sender as InnerSender};

    /// A receiver error returned from [`Receiver`]'s `recv`.
    ///
    /// tokio's `recv` yields an `Option`, so this unit type stands in for the
    /// `RecvError` that the other backends re-export. It derives `Clone` and
    /// `Copy` so that code handling the error by value builds identically on
    /// every backend.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub struct RecvError;

    impl std::fmt::Display for RecvError {
        fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            fmt.write_str("RecvError")
        }
    }

    impl std::error::Error for RecvError {}

    /// A bounded sender, created with [`bounded`]
    pub struct Sender<T>(pub(super) InnerSender<T>);
    /// A bounded receiver, created with [`bounded`]
    pub struct Receiver<T>(pub(super) InnerReceiver<T>);
    /// A bounded stream, created by turning a [`Receiver`] into a stream
    pub struct BoundedStream<T>(pub(super) tokio_stream::wrappers::ReceiverStream<T>);

    /// Turn a `TryRecvError` into a `RecvError` if it's not `Empty`.
    ///
    /// `Empty` only means "nothing available right now" and maps to `None`;
    /// `Disconnected` maps to `Some(RecvError)`.
    pub(super) fn try_recv_error_to_recv_error(e: TryRecvError) -> Option<RecvError> {
        match e {
            TryRecvError::Empty => None,
            TryRecvError::Disconnected => Some(RecvError),
        }
    }

    /// Create a bounded sender/receiver pair, limited to `len` messages at a time.
    ///
    /// # Panics
    ///
    /// Panics if `len` is `0`: tokio's `mpsc::channel` rejects a zero capacity.
    #[must_use]
    pub fn bounded<T>(len: usize) -> (Sender<T>, Receiver<T>) {
        let (sender, receiver) = tokio::sync::mpsc::channel(len);
        (Sender(sender), Receiver(receiver))
    }
}
46
/// Channel primitives backed by `flume`; compiled only when
/// `async_channel_impl = "flume"` is set.
#[cfg(async_channel_impl = "flume")]
mod inner {
    pub use flume::{RecvError, SendError, TryRecvError, TrySendError};

    use flume::{r#async::RecvStream, Receiver as InnerReceiver, Sender as InnerSender};

    /// Sending half of a bounded channel; obtain one from [`bounded`]
    pub struct Sender<T>(pub(super) InnerSender<T>);
    /// Receiving half of a bounded channel; obtain one from [`bounded`]
    pub struct Receiver<T>(pub(super) InnerReceiver<T>);
    /// Stream wrapper produced by turning a [`Receiver`] into a stream
    pub struct BoundedStream<T: 'static>(pub(super) RecvStream<'static, T>);

    /// Map a non-blocking receive failure onto a [`RecvError`].
    ///
    /// Yields `None` for `Empty` (nothing available right now) and
    /// `Some(RecvError::Disconnected)` for `Disconnected`.
    pub(super) fn try_recv_error_to_recv_error(e: TryRecvError) -> Option<RecvError> {
        match e {
            TryRecvError::Disconnected => Some(RecvError::Disconnected),
            TryRecvError::Empty => None,
        }
    }

    /// Build a connected sender/receiver pair that holds at most `len` messages.
    #[must_use]
    pub fn bounded<T>(len: usize) -> (Sender<T>, Receiver<T>) {
        let (tx, rx) = flume::bounded(len);
        (Sender(tx), Receiver(rx))
    }
}
76
77/// inner module, used to group feature-specific imports
78#[cfg(not(any(async_channel_impl = "flume", async_channel_impl = "tokio")))]
79mod inner {
80    pub use async_std::channel::{RecvError, SendError, TryRecvError, TrySendError};
81
82    use async_std::channel::{Receiver as InnerReceiver, Sender as InnerSender};
83
84    /// A bounded sender, created with [`channel`]
85    pub struct Sender<T>(pub(super) InnerSender<T>);
86    /// A bounded receiver, created with [`channel`]
87    pub struct Receiver<T>(pub(super) InnerReceiver<T>);
88    /// A bounded stream, created with a channel
89    pub struct BoundedStream<T>(pub(super) InnerReceiver<T>);
90
91    /// Turn a `TryRecvError` into a `RecvError` if it's not `Empty`
92    pub(super) fn try_recv_error_to_recv_error(e: TryRecvError) -> Option<RecvError> {
93        match e {
94            TryRecvError::Empty => None,
95            TryRecvError::Closed => Some(RecvError),
96        }
97    }
98
99    /// Create a bounded sender/receiver pair, limited to `len` messages at a time.
100    #[must_use]
101    pub fn bounded<T>(len: usize) -> (Sender<T>, Receiver<T>) {
102        let (sender, receiver) = async_std::channel::bounded(len);
103
104        (Sender(sender), Receiver(receiver))
105    }
106}
107
108pub use inner::*;
109
110impl<T> Sender<T> {
111    /// Send a value to the channel. May return a [`SendError`] if the receiver is dropped
112    ///
113    /// # Errors
114    ///
115    /// Will return an error if the receiver is dropped
116    pub async fn send(&self, msg: T) -> Result<(), SendError<T>> {
117        #[cfg(async_channel_impl = "flume")]
118        let result = self.0.send_async(msg).await;
119        #[cfg(not(all(async_channel_impl = "flume")))]
120        let result = self.0.send(msg).await;
121
122        result
123    }
124
125    /// Try to send a value over the channel. Will return immediately if the channel is full.
126    ///
127    /// # Errors
128    /// - If the channel is full
129    /// - If the channel is dropped
130    pub fn try_send(&self, msg: T) -> Result<(), TrySendError<T>> {
131        self.0.try_send(msg)
132    }
133}
134
135impl<T> Receiver<T> {
136    /// Receive a value from te channel. This will async block until a value is received, or until a [`RecvError`] is encountered.
137    ///
138    /// # Errors
139    ///
140    /// Will return an error if the sender is dropped
141    pub async fn recv(&mut self) -> Result<T, RecvError> {
142        #[cfg(async_channel_impl = "flume")]
143        let result = self.0.recv_async().await;
144        #[cfg(async_channel_impl = "tokio")]
145        let result = self.0.recv().await.ok_or(RecvError);
146        #[cfg(not(any(async_channel_impl = "flume", async_channel_impl = "tokio")))]
147        let result = self.0.recv().await;
148
149        result
150    }
151    /// Turn this recever into a stream. This may fail on some implementations if multiple references of a receiver exist
152    pub fn into_stream(self) -> BoundedStream<T>
153    where
154        T: 'static,
155    {
156        #[cfg(not(any(async_channel_impl = "flume", async_channel_impl = "tokio")))]
157        let result = self.0;
158        #[cfg(async_channel_impl = "tokio")]
159        let result = tokio_stream::wrappers::ReceiverStream::new(self.0);
160        #[cfg(async_channel_impl = "flume")]
161        let result = self.0.into_stream();
162
163        BoundedStream(result)
164    }
165    /// Try to receive a channel from the receiver. Will return immediately if there is no value available.
166    ///
167    /// # Errors
168    ///
169    /// Will return an error if the sender is dropped
170    pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
171        self.0.try_recv()
172    }
173    /// Asynchronously wait for at least 1 value to show up, then will greedily try to receive values until this receiver would block. The resulting values are returned.
174    ///
175    /// It is guaranteed that the returning vec contains at least 1 value
176    ///
177    /// # Errors
178    ///
179    /// Will return an error if the sender is dropped
180    pub async fn drain_at_least_one(&mut self) -> Result<Vec<T>, RecvError> {
181        // Wait for the first message to come up
182        let first = self.recv().await?;
183        let mut ret = vec![first];
184        loop {
185            match self.try_recv() {
186                Ok(x) => ret.push(x),
187                Err(e) => {
188                    if let Some(e) = try_recv_error_to_recv_error(e) {
189                        tracing::error!(
190                            "Tried to empty {:?} queue but it disconnected while we were emptying it ({} items are being dropped)",
191                            std::any::type_name::<Self>(),
192                            ret.len()
193                        );
194                        return Err(e);
195                    }
196                    break;
197                }
198            }
199        }
200        Ok(ret)
201    }
202    /// Drains the receiver from all messages in the queue, but will not poll for more messages
203    ///
204    /// # Errors
205    ///
206    /// Will return an error if the sender is dropped
207    pub fn drain(&mut self) -> Result<Vec<T>, RecvError> {
208        let mut result = Vec::new();
209        loop {
210            match self.try_recv() {
211                Ok(t) => result.push(t),
212                Err(e) => {
213                    if let Some(e) = try_recv_error_to_recv_error(e) {
214                        return Err(e);
215                    }
216                    break;
217                }
218            }
219        }
220        Ok(result)
221    }
222}
223
224impl<T> Stream for BoundedStream<T> {
225    type Item = T;
226
227    fn poll_next(
228        mut self: std::pin::Pin<&mut Self>,
229        cx: &mut std::task::Context<'_>,
230    ) -> std::task::Poll<Option<Self::Item>> {
231        #[cfg(async_channel_impl = "flume")]
232        return <flume::r#async::RecvStream<T>>::poll_next(Pin::new(&mut self.0), cx);
233        #[cfg(async_channel_impl = "tokio")]
234        return <tokio_stream::wrappers::ReceiverStream<T> as Stream>::poll_next(
235            Pin::new(&mut self.0),
236            cx,
237        );
238        #[cfg(not(any(async_channel_impl = "flume", async_channel_impl = "tokio")))]
239        return <async_std::channel::Receiver<T> as Stream>::poll_next(Pin::new(&mut self.0), cx);
240    }
241}
242
243// Clone impl
244impl<T> Clone for Sender<T> {
245    fn clone(&self) -> Self {
246        Self(self.0.clone())
247    }
248}
249
250// Debug impl
251impl<T> std::fmt::Debug for Sender<T> {
252    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
253        f.debug_struct("Sender").finish()
254    }
255}
256impl<T> std::fmt::Debug for Receiver<T> {
257    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
258        f.debug_struct("Receiver").finish()
259    }
260}