common/
network.rs

1use std::convert::TryFrom;
2use std::error::Error;
3use std::net::IpAddr;
4
5use async_channel::Sender;
6use async_trait::async_trait;
7use tokio::net;
8use tokio::time::{timeout as tokio_timeout, Duration};
9
10use crate::io::{AsyncReader, AsyncWriter, Stream};
11
12#[async_trait]
13pub trait Listener {
14    async fn accept_clients(&mut self, new_clients: Sender<Stream>) -> Result<(), Box<dyn Error>>;
15}
16
17pub struct TcpListener {
18    listener: net::TcpListener,
19}
20
21impl TcpListener {
22    pub async fn new(local_address: IpAddr, local_port: u16) -> Result<Self, Box<dyn Error>> {
23        let listener_address = format!("{}:{}", local_address, local_port);
24        log::info!("start listening on {}", listener_address);
25        let listener = net::TcpListener::bind(listener_address).await?;
26        Ok(Self { listener })
27    }
28}
29
30#[async_trait]
31impl Listener for TcpListener {
32    async fn accept_clients(&mut self, new_clients: Sender<Stream>) -> Result<(), Box<dyn Error>> {
33        while let Ok((client_stream, client_address)) = self.listener.accept().await {
34            log::debug!("got connection from {}", client_address);
35            let (client_reader, client_writer) = client_stream.into_split();
36            new_clients
37                .send(Stream::new(client_reader, client_writer))
38                .await?;
39        }
40        Ok(())
41    }
42}
43
/// Largest length representable by the 16-bit packet length header (65535).
pub const MAX_UDP_PACKET_SIZE: usize = u16::MAX as usize;
/// Width in bytes of the big-endian `u16` length prefix that precedes each
/// streamed packet.
pub const STREAMED_UDP_PACKET_HEADER_SIZE: usize = std::mem::size_of::<u16>();
46
47pub async fn stream_udp_packet(payload: &[u8], size: usize, writer: &mut Box<dyn AsyncWriter>) {
48    if payload.len() < size {
49        log::error!(
50            "payload {:?} is too small (expecting size {})",
51            payload,
52            size
53        );
54        return;
55    }
56
57    let size_u16 = match u16::try_from(size) {
58        Ok(s) => s,
59        Err(e) => {
60            log::error!("size {} can't fit in a u16: {}", size, e);
61            return;
62        }
63    };
64
65    if let Err(e) = writer.write(&size_u16.to_be_bytes()).await {
66        log::error!("failed to write header: {}", e);
67        return;
68    };
69
70    if let Err(e) = writer.write(&payload[..size]).await {
71        log::error!("failed to write payload: {}", e);
72        return;
73    };
74}
75
/// Outcome of reading one length-prefixed packet with [`unstream_udp_packet`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum UnstreamPacketResult {
    /// Reading the header or the payload failed; the cause has been logged.
    Error,
    /// The header did not arrive within the requested timeout.
    Timeout,
    /// A complete packet payload, with the length header stripped.
    Payload(Vec<u8>),
}
82
83pub async fn unstream_udp_packet(
84    reader: &mut Box<dyn AsyncReader>,
85    timeout: Option<Duration>,
86) -> UnstreamPacketResult {
87    let mut header_bytes = [0; STREAMED_UDP_PACKET_HEADER_SIZE];
88    let read_header_future = reader.read_exact(&mut header_bytes);
89    let header_size_result = match timeout {
90        None => read_header_future.await,
91        Some(duration) => match tokio_timeout(duration, read_header_future).await {
92            Ok(size_result) => size_result,
93            Err(_) => return UnstreamPacketResult::Timeout,
94        },
95    };
96
97    let header_size = match header_size_result {
98        Ok(size) => size,
99        Err(e) => {
100            log::error!("failed to read header: {}", e);
101            return UnstreamPacketResult::Error;
102        }
103    };
104
105    if header_size != STREAMED_UDP_PACKET_HEADER_SIZE {
106        log::error!("got unexpected header size in bytes {}", header_size);
107        return UnstreamPacketResult::Error;
108    }
109
110    let header = u16::from_be_bytes(header_bytes);
111    let header_usize = header as usize;
112    let mut payload = vec![0; header_usize];
113    let size = match reader.read_exact(&mut payload).await {
114        Ok(size) => size,
115        Err(e) => {
116            log::error!("failed to read payload: {}", e);
117            return UnstreamPacketResult::Error;
118        }
119    };
120
121    if size != header_usize {
122        log::error!("got unexpected data size in bytes {}", header_size);
123        return UnstreamPacketResult::Error;
124    }
125
126    UnstreamPacketResult::Payload(payload)
127}
128
#[cfg(test)]
mod tests {
    use std::io::ErrorKind;

    use tokio::io;
    use tokio_test::io::Builder;

    use crate::io::{AsyncReadWrapper, AsyncWriteWrapper};

    use super::*;

    // `stream_udp_packet` returns `()`, so its tests rely on the tokio-test
    // mock panicking on any write it does not expect, and on expected writes
    // left unconsumed when it is dropped.

    #[tokio::test]
    async fn stream_udp_packet_payload_too_small() {
        let payload = vec![1, 2, 3];
        let mut writer: Box<dyn AsyncWriter> =
            Box::new(AsyncWriteWrapper::new(Builder::new().build()));

        stream_udp_packet(&payload, 7, &mut writer).await;
    }

    #[tokio::test]
    async fn stream_udp_packet_size_not_fit_in_u16() {
        let payload = vec![0; u16::MAX as usize + 7];
        let mut writer: Box<dyn AsyncWriter> =
            Box::new(AsyncWriteWrapper::new(Builder::new().build()));

        stream_udp_packet(&payload, payload.len(), &mut writer).await;
    }

    #[tokio::test]
    async fn stream_udp_packet_write_header_failed() {
        let payload = vec![1, 2, 3];
        let mut writer: Box<dyn AsyncWriter> = Box::new(AsyncWriteWrapper::new(
            Builder::new()
                .write_error(io::Error::new(ErrorKind::Other, "oh no!"))
                .build(),
        ));

        stream_udp_packet(&payload, payload.len(), &mut writer).await;
    }

    #[tokio::test]
    async fn stream_udp_packet_write_payload_failed() {
        let payload = vec![1, 2, 3];
        let mut writer: Box<dyn AsyncWriter> = Box::new(AsyncWriteWrapper::new(
            Builder::new()
                .write(&[0u8, 3])
                .write_error(io::Error::new(ErrorKind::Other, "oh no!"))
                .build(),
        ));

        stream_udp_packet(&payload, payload.len(), &mut writer).await;
    }

    #[tokio::test]
    async fn stream_udp_packet_success() {
        let payload = vec![1u8, 2, 3];
        let mut writer: Box<dyn AsyncWriter> = Box::new(AsyncWriteWrapper::new(
            Builder::new().write(&[0u8, 3]).write(&payload).build(),
        ));

        stream_udp_packet(&payload, payload.len(), &mut writer).await;
    }

    #[tokio::test]
    async fn stream_udp_packet_sends_only_first_size_bytes() {
        let payload = vec![1u8, 2, 3, 4];
        let mut writer: Box<dyn AsyncWriter> = Box::new(AsyncWriteWrapper::new(
            Builder::new().write(&[0u8, 2]).write(&[1u8, 2]).build(),
        ));

        stream_udp_packet(&payload, 2, &mut writer).await;
    }

    #[tokio::test]
    async fn unstream_udp_packet_timeout() {
        let mut reader: Box<dyn AsyncReader> = Box::new(AsyncReadWrapper::new(
            Builder::new().wait(Duration::from_secs(5)).build(),
        ));

        let res = unstream_udp_packet(&mut reader, Some(Duration::from_millis(1))).await;
        assert_eq!(res, UnstreamPacketResult::Timeout);
    }

    #[tokio::test]
    async fn unstream_udp_packet_read_header_failed() {
        let mut reader: Box<dyn AsyncReader> =
            Box::new(AsyncReadWrapper::new(Builder::new().build()));

        let res = unstream_udp_packet(&mut reader, None).await;
        assert_eq!(res, UnstreamPacketResult::Error);
    }

    #[tokio::test]
    async fn unstream_udp_packet_read_payload_failed() {
        let mut reader: Box<dyn AsyncReader> = Box::new(AsyncReadWrapper::new(
            Builder::new().read(&[0u8, 3]).build(),
        ));

        let res = unstream_udp_packet(&mut reader, None).await;
        assert_eq!(res, UnstreamPacketResult::Error);
    }

    #[tokio::test]
    async fn unstream_udp_packet_success() {
        let payload = vec![1u8, 2, 3];
        let mut reader: Box<dyn AsyncReader> = Box::new(AsyncReadWrapper::new(
            Builder::new().read(&[0u8, 3]).read(&payload).build(),
        ));

        let res = unstream_udp_packet(&mut reader, None).await;
        assert_eq!(res, UnstreamPacketResult::Payload(payload));
    }

    #[tokio::test]
    async fn unstream_udp_packet_success_within_timeout() {
        let payload = vec![1u8, 2, 3];
        let mut reader: Box<dyn AsyncReader> = Box::new(AsyncReadWrapper::new(
            Builder::new().read(&[0u8, 3]).read(&payload).build(),
        ));

        let res = unstream_udp_packet(&mut reader, Some(Duration::from_secs(1))).await;
        assert_eq!(res, UnstreamPacketResult::Payload(payload));
    }
}