1use futures::{future::poll_fn, stream::FusedStream, Sink, SinkExt, Stream};
30use kodec::{Decode, Encode};
31use pin_project::pin_project;
32use serde::Serialize;
33use std::{
34 borrow::Borrow,
35 collections::VecDeque,
36 fmt::{Debug, Display},
37 io::ErrorKind,
38 marker::PhantomData,
39 net::SocketAddr,
40 pin::Pin,
41 task::{Context, Poll},
42};
43use tokio::{
44 io::ReadBuf,
45 net::{ToSocketAddrs, UdpSocket},
46};
47
48#[derive(Debug)]
49pub enum Error<SerializationError, DeserializationError> {
50 SendingError,
51 SerializationError(SerializationError),
52 DeserializationError(DeserializationError),
53 IoError(tokio::io::Error),
54}
55
56impl<SerializationError, DeserializationError> Display
57 for Error<SerializationError, DeserializationError>
58where
59 SerializationError: Display,
60 DeserializationError: Display,
61{
62 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63 match self {
64 Error::SendingError => write!(f, "not all bytes were sent"),
65 Error::SerializationError(error) => write!(f, "failed to serialize message: {error}"),
66 Error::DeserializationError(error) => {
67 write!(f, "failed to deserialize message: {error}")
68 }
69 Error::IoError(error) => write!(f, "IO error occurred: {error}"),
70 }
71 }
72}
73
74impl<SerializationError, DeserializationError> std::error::Error
75 for Error<SerializationError, DeserializationError>
76where
77 SerializationError: Debug + Display,
78 DeserializationError: Debug + Display,
79{
80}
81
82#[pin_project]
93pub struct Transport<U, Codec, Incoming, Outgoing>
94where
95 U: Borrow<UdpSocket>,
96 Codec: kodec::Codec,
97 for<'de> Incoming: serde::de::Deserialize<'de>,
98 Outgoing: Serialize,
99{
100 udp_socket: Option<U>,
101 codec: Codec,
102 send_queue: VecDeque<Outgoing>,
103 send_buffer: Vec<u8>,
104 message_pending: bool,
105 receive_buffer: Vec<u8>,
106 _incoming: PhantomData<Incoming>,
107}
108
109impl<U, Codec, Incoming, Outgoing> Transport<U, Codec, Incoming, Outgoing>
110where
111 U: Borrow<UdpSocket>,
112 Codec: kodec::Codec,
113 for<'de> Incoming: serde::de::Deserialize<'de>,
114 Outgoing: Serialize,
115{
116 pub fn new(udp_socket: U, codec: Codec) -> Self {
118 Transport {
119 udp_socket: Some(udp_socket),
120 codec,
121 send_queue: VecDeque::new(),
122 send_buffer: vec![],
123 message_pending: false,
124 receive_buffer: vec![0; 65536],
125 _incoming: PhantomData,
126 }
127 }
128
129 pub async fn send_to<A: ToSocketAddrs>(
131 &mut self,
132 message: Outgoing,
133 target: A,
134 ) -> Result<(), mezzenger::Error<Error<<Codec as Encode>::Error, <Codec as Decode>::Error>>>
135 {
136 self.flush().await?;
137 if let Some(udp_socket) = &self.udp_socket {
138 self.codec
139 .encode(&mut self.send_buffer, &message)
140 .map_err(
141 Error::<<Codec as Encode>::Error, <Codec as Decode>::Error>::SerializationError,
142 )
143 .map_err(mezzenger::Error::Other)?;
144 udp_socket
145 .borrow()
146 .send_to(&self.send_buffer, target)
147 .await
148 .map_err(Error::<<Codec as Encode>::Error, <Codec as Decode>::Error>::IoError)
149 .map_err(mezzenger::Error::Other)?;
150 self.send_buffer.clear();
151 Ok(())
152 } else {
153 Err(mezzenger::Error::Closed)
154 }
155 }
156
157 pub async fn receive_from(
161 &mut self,
162 ) -> Result<
163 (Incoming, SocketAddr),
164 mezzenger::Error<Error<<Codec as Encode>::Error, <Codec as Decode>::Error>>,
165 > {
166 if self.udp_socket.is_some() {
167 let result = poll_fn(|cx| self.poll_recv_from(cx)).await;
168 if let Some(result) = result {
169 result.map_err(mezzenger::Error::Other)
170 } else {
171 Err(mezzenger::Error::Closed)
172 }
173 } else {
174 Err(mezzenger::Error::Closed)
175 }
176 }
177
178 #[allow(clippy::type_complexity)]
179 fn poll_recv_from(
180 &mut self,
181 cx: &mut Context<'_>,
182 ) -> Poll<
183 Option<
184 Result<
185 (Incoming, SocketAddr),
186 Error<<Codec as Encode>::Error, <Codec as Decode>::Error>,
187 >,
188 >,
189 > {
190 if let Some(udp_socket) = &self.udp_socket {
191 let mut buf = ReadBuf::new(&mut self.receive_buffer);
192 match udp_socket.borrow().poll_recv_from(cx, &mut buf) {
193 Poll::Ready(result) => match result {
194 Ok(address) => {
195 let result: Result<Incoming, _> = self.codec.decode(buf.filled());
196 match result {
197 Ok(message) => Poll::Ready(Some(Ok((message, address)))),
198 Err(error) => {
199 Poll::Ready(Some(Err(Error::DeserializationError(error))))
200 }
201 }
202 }
203 Err(error) => match error.kind() {
204 ErrorKind::ConnectionReset | ErrorKind::ConnectionAborted => {
205 self.udp_socket = None;
206 Poll::Ready(None)
207 }
208 _ => Poll::Ready(Some(Err(Error::IoError(error)))),
209 },
210 },
211 Poll::Pending => Poll::Pending,
212 }
213 } else {
214 Poll::Ready(None)
215 }
216 }
217}
218
219impl<U, Codec, Incoming, Outgoing> Sink<Outgoing> for Transport<U, Codec, Incoming, Outgoing>
220where
221 U: Borrow<UdpSocket>,
222 Codec: kodec::Codec,
223 for<'de> Incoming: serde::de::Deserialize<'de>,
224 Outgoing: Serialize,
225{
226 type Error = mezzenger::Error<Error<<Codec as Encode>::Error, <Codec as Decode>::Error>>;
227
228 fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
229 Poll::Ready(Ok(()))
230 }
231
232 fn start_send(mut self: Pin<&mut Self>, item: Outgoing) -> Result<(), Self::Error> {
233 if self.udp_socket.is_some() {
234 self.send_queue.push_back(item);
235 Ok(())
236 } else {
237 Err(mezzenger::Error::Closed)
238 }
239 }
240
241 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
242 let me = self.project();
243 if me.send_queue.is_empty() && !*me.message_pending {
244 return Poll::Ready(Ok(()));
245 }
246 if let Some(udp_socket) = &me.udp_socket {
247 loop {
248 if *me.message_pending {
249 let bytes_to_send = me.send_buffer.len();
250 let result = udp_socket.borrow().poll_send(cx, me.send_buffer);
251 match result {
252 Poll::Ready(result) => {
253 *me.message_pending = false;
254 me.send_buffer.clear();
255 match result {
256 Ok(bytes_written) => {
257 if bytes_written != bytes_to_send {
258 return Poll::Ready(Err(mezzenger::Error::Other(
259 Error::SendingError,
260 )));
261 }
262 }
263 Err(error) => match error.kind() {
264 ErrorKind::ConnectionReset | ErrorKind::ConnectionAborted => {
265 *me.udp_socket = None;
266 return Poll::Ready(Err(mezzenger::Error::Closed));
267 }
268 _ => {
269 return Poll::Ready(Err(mezzenger::Error::Other(
270 Error::IoError(error),
271 )))
272 }
273 },
274 }
275 }
276 Poll::Pending => return Poll::Pending,
277 }
278 } else if let Some(message) = me.send_queue.pop_front() {
279 let result = me.codec.encode(&mut *me.send_buffer, &message);
280 if let Err(error) = result {
281 me.send_buffer.clear();
282 return Poll::Ready(Err(mezzenger::Error::Other(
283 Error::SerializationError(error),
284 )));
285 } else {
286 *me.message_pending = true;
287 }
288 } else {
289 return Poll::Ready(Ok(()));
290 }
291 }
292 } else {
293 Poll::Ready(Err(mezzenger::Error::Closed))
294 }
295 }
296
297 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
298 match self.poll_flush_unpin(cx) {
299 Poll::Ready(_) => {
300 self.udp_socket = None;
301 Poll::Ready(Ok(()))
302 }
303 Poll::Pending => Poll::Pending,
304 }
305 }
306}
307
308impl<U, Codec, Incoming, Outgoing> Stream for Transport<U, Codec, Incoming, Outgoing>
309where
310 U: Borrow<UdpSocket>,
311 Codec: kodec::Codec,
312 for<'de> Incoming: serde::de::Deserialize<'de>,
313 Outgoing: Serialize,
314{
315 type Item = Result<Incoming, Error<<Codec as Encode>::Error, <Codec as Decode>::Error>>;
316
317 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
318 match self.poll_recv_from(cx) {
319 Poll::Ready(result) => {
320 let result = result.map(|result| result.map(|(incoming, _)| incoming));
321 Poll::Ready(result)
322 }
323 Poll::Pending => Poll::Pending,
324 }
325 }
326}
327
328impl<U, Codec, Incoming, Outgoing> FusedStream for Transport<U, Codec, Incoming, Outgoing>
329where
330 U: Borrow<UdpSocket>,
331 Codec: kodec::Codec,
332 for<'de> Incoming: serde::de::Deserialize<'de>,
333 Outgoing: Serialize,
334{
335 fn is_terminated(&self) -> bool {
336 self.udp_socket.is_none()
337 }
338}
339
#[cfg(test)]
mod tests {
    use futures::SinkExt;
    use kodec::binary::Codec;
    use mezzenger::Receive;
    use tokio::net::UdpSocket;

    use crate::Transport;

    /// Binds two loopback sockets and connects each one to the other.
    ///
    /// Port 0 lets the operating system pick free ports, so the tests neither
    /// collide with each other when run in parallel nor fail because a fixed
    /// port happens to be in use on the machine.
    async fn connected_pair() -> (UdpSocket, UdpSocket) {
        let left = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let right = UdpSocket::bind("127.0.0.1:0").await.unwrap();

        left.connect(right.local_addr().unwrap()).await.unwrap();
        right.connect(left.local_addr().unwrap()).await.unwrap();

        (left, right)
    }

    #[tokio::test]
    async fn test_transport() {
        let (left, right) = connected_pair().await;

        let mut left: Transport<UdpSocket, Codec, u32, String> =
            Transport::new(left, Codec::default());
        let mut right: Transport<UdpSocket, Codec, String, u32> =
            Transport::new(right, Codec::default());

        left.send("Hello World!".to_string()).await.unwrap();
        left.send("Hello World again!".to_string()).await.unwrap();
        right.send(128).await.unwrap();
        right.send(1).await.unwrap();

        assert_eq!(right.receive().await.unwrap(), "Hello World!");
        assert_eq!(right.receive().await.unwrap(), "Hello World again!");
        assert_eq!(left.receive().await.unwrap(), 128);
        assert_eq!(left.receive().await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_unit_message() {
        let (left, right) = connected_pair().await;

        let mut left: Transport<UdpSocket, Codec, (), ()> = Transport::new(left, Codec::default());
        let mut right: Transport<UdpSocket, Codec, (), ()> =
            Transport::new(right, Codec::default());

        for _ in 0..4 {
            left.send(()).await.unwrap();
        }
        for _ in 0..3 {
            right.send(()).await.unwrap();
        }

        // The message type is `()`, so a successful receive is the whole
        // assertion; comparing the value against `()` would add nothing.
        for _ in 0..4 {
            right.receive().await.unwrap();
        }
        for _ in 0..3 {
            left.receive().await.unwrap();
        }
    }
}