1use std::{
21 fmt::Display,
22 marker::PhantomData,
23 pin::Pin,
24 task::{Context, Poll},
25};
26
27use futures::{
28 channel::mpsc::{unbounded, SendError, UnboundedReceiver, UnboundedSender},
29 stream::{Fuse, FusedStream},
30 Sink, Stream, StreamExt,
31};
32use pin_project::pin_project;
33
/// Error type specific to this channel-backed transport.
///
/// Besides `Debug`, the type is `Clone`, `PartialEq` and `Eq` so callers can
/// copy it around and compare it directly (e.g. in assertions).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Error {
    /// The sending half reported that the channel is full.
    ChannelIsFull,
}

impl Display for Error {
    /// Writes the human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive match: adding a variant later forces a message to be
        // chosen for it instead of silently reusing this one.
        match self {
            Error::ChannelIsFull => f.write_str("channel is full"),
        }
    }
}

impl std::error::Error for Error {}
46
47#[pin_project]
49pub struct Transport<Receiver, Sender, Incoming, Outgoing>
50where
51 Receiver: Stream<Item = Incoming>,
52 Sender: Sink<Outgoing, Error = SendError>,
53{
54 #[pin]
55 receiver: Fuse<Receiver>,
56 #[pin]
57 sender: Sender,
58 _incoming: PhantomData<Incoming>,
59 _outgoing: PhantomData<Outgoing>,
60}
61
62impl<Receiver, Sender, Incoming, Outgoing> Transport<Receiver, Sender, Incoming, Outgoing>
63where
64 Receiver: Stream<Item = Incoming>,
65 Sender: Sink<Outgoing, Error = SendError>,
66{
67 pub fn new(sender: Sender, receiver: Receiver) -> Self {
69 Transport {
70 receiver: receiver.fuse(),
71 sender,
72 _incoming: PhantomData,
73 _outgoing: PhantomData,
74 }
75 }
76}
77
78impl<Receiver, Sender, Incoming, Outgoing> Sink<Outgoing>
79 for Transport<Receiver, Sender, Incoming, Outgoing>
80where
81 Receiver: Stream<Item = Incoming>,
82 Sender: Sink<Outgoing, Error = SendError>,
83{
84 type Error = mezzenger::Error<Error>;
85
86 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
87 let me = self.project();
88 me.sender.poll_ready(cx).map_err(map_error)
89 }
90
91 fn start_send(self: Pin<&mut Self>, item: Outgoing) -> Result<(), Self::Error> {
92 let me = self.project();
93 me.sender.start_send(item).map_err(map_error)
94 }
95
96 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
97 let me = self.project();
98 me.sender.poll_flush(cx).map_err(map_error)
99 }
100
101 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
102 let me = self.project();
103 me.sender.poll_close(cx).map_err(map_error)
104 }
105}
106
107impl<Receiver, Sender, Incoming, Outgoing> Stream
108 for Transport<Receiver, Sender, Incoming, Outgoing>
109where
110 Receiver: Stream<Item = Incoming>,
111 Sender: Sink<Outgoing, Error = SendError>,
112{
113 type Item = Result<Incoming, Error>;
114
115 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
116 let me = self.project();
117 me.receiver
118 .poll_next(cx)
119 .map(|item_option| item_option.map(Ok))
120 }
121}
122
123fn map_error(error: SendError) -> mezzenger::Error<Error> {
124 if error.is_full() {
125 mezzenger::Error::Other(Error::ChannelIsFull)
126 } else {
127 mezzenger::Error::Closed
128 }
129}
130
131impl<Receiver, Sender, Incoming, Outgoing> FusedStream
132 for Transport<Receiver, Sender, Incoming, Outgoing>
133where
134 Receiver: Stream<Item = Incoming>,
135 Sender: Sink<Outgoing, Error = SendError>,
136{
137 fn is_terminated(&self) -> bool {
138 self.receiver.is_terminated()
139 }
140}
141
142impl<Receiver, Sender, Incoming, Outgoing> mezzenger::Reliable
143 for Transport<Receiver, Sender, Incoming, Outgoing>
144where
145 Receiver: Stream<Item = Incoming>,
146 Sender: Sink<Outgoing, Error = SendError>,
147{
148}
149
150impl<Receiver, Sender, Incoming, Outgoing> mezzenger::Order
151 for Transport<Receiver, Sender, Incoming, Outgoing>
152where
153 Receiver: Stream<Item = Incoming>,
154 Sender: Sink<Outgoing, Error = SendError>,
155{
156}
157
158#[allow(clippy::type_complexity)]
160pub fn transports<Incoming, Outgoing>() -> (
161 Transport<UnboundedReceiver<Incoming>, UnboundedSender<Outgoing>, Incoming, Outgoing>,
162 Transport<UnboundedReceiver<Outgoing>, UnboundedSender<Incoming>, Outgoing, Incoming>,
163) {
164 let (left_sender, right_receiver) = unbounded();
165 let (right_sender, left_receiver) = unbounded();
166
167 let left = Transport::new(left_sender, left_receiver);
168 let right = Transport::new(right_sender, right_receiver);
169
170 (left, right)
171}
172
#[cfg(test)]
mod tests {
    use futures::{stream, SinkExt, StreamExt};

    use mezzenger::{Messages, Receive};
    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::{wasm_bindgen_test, wasm_bindgen_test_configure};
    #[cfg(target_arch = "wasm32")]
    wasm_bindgen_test_configure!(run_in_browser);

    use crate::transports;

    // The `*_inner` functions hold the actual test logic; the thin wrappers
    // further down pick the test harness that matches the target.

    /// Sends a batch through `send_all`, drops the sender side, then collects
    /// everything the other side receives.
    async fn test_stream_inner() {
        let (mut left, right) = transports::<String, u32>();

        let mut outgoing = stream::iter(vec![1, 2, 3].into_iter().map(Ok));
        left.send_all(&mut outgoing).await.unwrap();
        // Dropping `left` lets the collection below run to completion.
        drop(left);

        let received = right.messages().collect::<Vec<u32>>().await;
        assert_eq!(received, vec![1, 2, 3]);
    }

    /// Exchanges messages in both directions and checks they arrive in the
    /// order they were sent.
    async fn test_transport_inner() {
        let (mut left, mut right) = transports();

        for text in ["Hello World!", "Hello World again!"] {
            left.send(text.to_string()).await.unwrap();
        }
        for number in [128, 1] {
            right.send(number).await.unwrap();
        }

        for expected in ["Hello World!", "Hello World again!"] {
            assert_eq!(right.receive().await.unwrap(), expected);
        }
        for expected in [128, 1] {
            assert_eq!(left.receive().await.unwrap(), expected);
        }
    }

    /// Same round trip as above, using the unit type as the message.
    async fn test_unit_message_inner() {
        let (mut left, mut right) = transports();

        for _ in 0..2 {
            left.send(()).await.unwrap();
        }
        for _ in 0..2 {
            right.send(()).await.unwrap();
        }

        for _ in 0..2 {
            assert_eq!(right.receive().await.unwrap(), ());
        }
        for _ in 0..2 {
            assert_eq!(left.receive().await.unwrap(), ());
        }
    }

    // Native targets: run under tokio.

    #[cfg(not(target_arch = "wasm32"))]
    #[tokio::test]
    async fn test_transport() {
        test_transport_inner().await
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[tokio::test]
    async fn test_unit_message() {
        test_unit_message_inner().await
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[tokio::test]
    async fn test_stream() {
        test_stream_inner().await
    }

    // wasm32 targets: run under wasm-bindgen-test.

    #[cfg(target_arch = "wasm32")]
    #[wasm_bindgen_test]
    async fn test_transport() {
        test_transport_inner().await
    }

    #[cfg(target_arch = "wasm32")]
    #[wasm_bindgen_test]
    async fn test_unit_message() {
        test_unit_message_inner().await
    }

    #[cfg(target_arch = "wasm32")]
    #[wasm_bindgen_test]
    async fn test_stream() {
        test_stream_inner().await
    }
}