async_compatibility_layer/channel/
bounded.rs1use std::pin::Pin;
2
3use futures::Stream;
4
#[cfg(async_channel_impl = "tokio")]
mod inner {
    pub use tokio::sync::mpsc::error::{SendError, TryRecvError, TrySendError};

    use tokio::sync::mpsc::{Receiver as InnerReceiver, Sender as InnerSender};

    /// Error returned by `Receiver::recv` once the channel is closed.
    ///
    /// tokio's `mpsc::Receiver::recv` reports closure by returning `None`
    /// rather than an error value, so this backend defines its own unit error
    /// to keep the `Result<T, RecvError>` signature shared with the other
    /// backends.
    ///
    /// `Clone` and `Copy` are derived so code written against the flume or
    /// async-std backends, whose `RecvError` types are `Copy`, also compiles
    /// with this one. NOTE(review): keep in sync if those derives change.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub struct RecvError;

    impl std::fmt::Display for RecvError {
        fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            fmt.write_str("RecvError")
        }
    }

    impl std::error::Error for RecvError {}

    /// Sending half of a bounded channel, backed by a tokio `mpsc::Sender`.
    pub struct Sender<T>(pub(super) InnerSender<T>);

    /// Receiving half of a bounded channel, backed by a tokio `mpsc::Receiver`.
    pub struct Receiver<T>(pub(super) InnerReceiver<T>);

    /// Stream of received messages, built by `Receiver::into_stream`.
    pub struct BoundedStream<T>(pub(super) tokio_stream::wrappers::ReceiverStream<T>);

    /// Classifies a failed `try_recv`.
    ///
    /// Returns `None` when the channel is merely empty and `Some(RecvError)`
    /// once it is disconnected.
    pub(super) fn try_recv_error_to_recv_error(e: TryRecvError) -> Option<RecvError> {
        match e {
            TryRecvError::Empty => None,
            TryRecvError::Disconnected => Some(RecvError),
        }
    }

    /// Creates a bounded tokio channel whose buffer holds `len` messages.
    ///
    /// # Panics
    ///
    /// Panics if `len` is 0: tokio's `mpsc::channel` does not accept a
    /// zero-capacity buffer.
    #[must_use]
    pub fn bounded<T>(len: usize) -> (Sender<T>, Receiver<T>) {
        let (sender, receiver) = tokio::sync::mpsc::channel(len);
        (Sender(sender), Receiver(receiver))
    }
}
46
#[cfg(async_channel_impl = "flume")]
mod inner {
    pub use flume::{RecvError, SendError, TryRecvError, TrySendError};

    use flume::{r#async::RecvStream, Receiver as InnerReceiver, Sender as InnerSender};

    /// Sending half of a bounded channel, backed by a flume sender.
    pub struct Sender<T>(pub(super) InnerSender<T>);

    /// Receiving half of a bounded channel, backed by a flume receiver.
    pub struct Receiver<T>(pub(super) InnerReceiver<T>);

    /// Stream of received messages, built by `Receiver::into_stream`.
    ///
    /// flume's `RecvStream` carries a lifetime parameter. It is fixed to
    /// `'static` here, which is why `T: 'static` is required.
    pub struct BoundedStream<T: 'static>(pub(super) RecvStream<'static, T>);

    /// Classifies a failed `try_recv`: a disconnected channel maps to
    /// `Some(RecvError::Disconnected)`, an empty one to `None`.
    pub(super) fn try_recv_error_to_recv_error(err: TryRecvError) -> Option<RecvError> {
        match err {
            TryRecvError::Disconnected => Some(RecvError::Disconnected),
            TryRecvError::Empty => None,
        }
    }

    /// Creates a bounded flume channel with capacity `len`.
    #[must_use]
    pub fn bounded<T>(len: usize) -> (Sender<T>, Receiver<T>) {
        let (tx, rx) = flume::bounded(len);
        (Sender(tx), Receiver(rx))
    }
}
76
77#[cfg(not(any(async_channel_impl = "flume", async_channel_impl = "tokio")))]
79mod inner {
80 pub use async_std::channel::{RecvError, SendError, TryRecvError, TrySendError};
81
82 use async_std::channel::{Receiver as InnerReceiver, Sender as InnerSender};
83
84 pub struct Sender<T>(pub(super) InnerSender<T>);
86 pub struct Receiver<T>(pub(super) InnerReceiver<T>);
88 pub struct BoundedStream<T>(pub(super) InnerReceiver<T>);
90
91 pub(super) fn try_recv_error_to_recv_error(e: TryRecvError) -> Option<RecvError> {
93 match e {
94 TryRecvError::Empty => None,
95 TryRecvError::Closed => Some(RecvError),
96 }
97 }
98
99 #[must_use]
101 pub fn bounded<T>(len: usize) -> (Sender<T>, Receiver<T>) {
102 let (sender, receiver) = async_std::channel::bounded(len);
103
104 (Sender(sender), Receiver(receiver))
105 }
106}
107
108pub use inner::*;
109
110impl<T> Sender<T> {
111 pub async fn send(&self, msg: T) -> Result<(), SendError<T>> {
117 #[cfg(async_channel_impl = "flume")]
118 let result = self.0.send_async(msg).await;
119 #[cfg(not(all(async_channel_impl = "flume")))]
120 let result = self.0.send(msg).await;
121
122 result
123 }
124
125 pub fn try_send(&self, msg: T) -> Result<(), TrySendError<T>> {
131 self.0.try_send(msg)
132 }
133}
134
135impl<T> Receiver<T> {
136 pub async fn recv(&mut self) -> Result<T, RecvError> {
142 #[cfg(async_channel_impl = "flume")]
143 let result = self.0.recv_async().await;
144 #[cfg(async_channel_impl = "tokio")]
145 let result = self.0.recv().await.ok_or(RecvError);
146 #[cfg(not(any(async_channel_impl = "flume", async_channel_impl = "tokio")))]
147 let result = self.0.recv().await;
148
149 result
150 }
151 pub fn into_stream(self) -> BoundedStream<T>
153 where
154 T: 'static,
155 {
156 #[cfg(not(any(async_channel_impl = "flume", async_channel_impl = "tokio")))]
157 let result = self.0;
158 #[cfg(async_channel_impl = "tokio")]
159 let result = tokio_stream::wrappers::ReceiverStream::new(self.0);
160 #[cfg(async_channel_impl = "flume")]
161 let result = self.0.into_stream();
162
163 BoundedStream(result)
164 }
165 pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
171 self.0.try_recv()
172 }
173 pub async fn drain_at_least_one(&mut self) -> Result<Vec<T>, RecvError> {
181 let first = self.recv().await?;
183 let mut ret = vec![first];
184 loop {
185 match self.try_recv() {
186 Ok(x) => ret.push(x),
187 Err(e) => {
188 if let Some(e) = try_recv_error_to_recv_error(e) {
189 tracing::error!(
190 "Tried to empty {:?} queue but it disconnected while we were emptying it ({} items are being dropped)",
191 std::any::type_name::<Self>(),
192 ret.len()
193 );
194 return Err(e);
195 }
196 break;
197 }
198 }
199 }
200 Ok(ret)
201 }
202 pub fn drain(&mut self) -> Result<Vec<T>, RecvError> {
208 let mut result = Vec::new();
209 loop {
210 match self.try_recv() {
211 Ok(t) => result.push(t),
212 Err(e) => {
213 if let Some(e) = try_recv_error_to_recv_error(e) {
214 return Err(e);
215 }
216 break;
217 }
218 }
219 }
220 Ok(result)
221 }
222}
223
224impl<T> Stream for BoundedStream<T> {
225 type Item = T;
226
227 fn poll_next(
228 mut self: std::pin::Pin<&mut Self>,
229 cx: &mut std::task::Context<'_>,
230 ) -> std::task::Poll<Option<Self::Item>> {
231 #[cfg(async_channel_impl = "flume")]
232 return <flume::r#async::RecvStream<T>>::poll_next(Pin::new(&mut self.0), cx);
233 #[cfg(async_channel_impl = "tokio")]
234 return <tokio_stream::wrappers::ReceiverStream<T> as Stream>::poll_next(
235 Pin::new(&mut self.0),
236 cx,
237 );
238 #[cfg(not(any(async_channel_impl = "flume", async_channel_impl = "tokio")))]
239 return <async_std::channel::Receiver<T> as Stream>::poll_next(Pin::new(&mut self.0), cx);
240 }
241}
242
243impl<T> Clone for Sender<T> {
245 fn clone(&self) -> Self {
246 Self(self.0.clone())
247 }
248}
249
250impl<T> std::fmt::Debug for Sender<T> {
252 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
253 f.debug_struct("Sender").finish()
254 }
255}
256impl<T> std::fmt::Debug for Receiver<T> {
257 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
258 f.debug_struct("Receiver").finish()
259 }
260}