packets/
lib.rs

1use std::io::{ErrorKind, Read, Write};
2use std::net::{SocketAddr, TcpStream};
3use std::time::{Duration, Instant};
4
5use anyhow::Result;
6use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
7use serde::de::DeserializeOwned;
8use serde::Serialize;
9
10pub mod client;
11pub mod server;
12
/// Result of waiting for an incoming packet on a stream.
///
/// The derived `PartialOrd`/`Ord` compare by declaration order, so `Received < TimedOut`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum PacketReceiveStatus {
  /// More than the 4-byte length prefix is buffered, so a packet can be read.
  Received,
  /// The wait ended without enough bytes becoming available.
  TimedOut,
}
18
/// A TCP stream bundled with a socket address.
#[derive(Debug)]
pub struct Connection {
  /// Socket address associated with this connection (presumably the remote peer's — confirm at
  /// the construction sites in `client`/`server`).
  pub addr: SocketAddr,
  /// The TCP stream that packets are read from and written to.
  pub stream: TcpStream,
}
24
25pub(crate) fn block_until_receive(stream: &mut TcpStream, timeout: Duration) -> Result<PacketReceiveStatus> {
26  // Make sure we have a non-blocking TcpStream. We can't use a blocking TcpStream as it does not
27  // support timeouts. So we have to poll the stream.
28  stream.set_nonblocking(true)?;
29
30  let start_time = Instant::now();
31
32  // The size of the buffer has to be more than 4 bytes, otherwise we can't peek and see if more
33  // than 4 bytes are in the buffer.
34  // If there are 4 or fewer bytes in the buffer we don't want to read the packet yet because only
35  // the size descriptor has been received. We want at least 1 byte of the packet to have been
36  // received before we retrieve it.
37  let mut buf = [0u8; 5];
38
39  loop {
40    if start_time.elapsed() > timeout {
41      break;
42    }
43
44    match stream.peek(&mut buf) {
45      Ok(peeked) => {
46        if peeked > 4 {
47          return Ok(PacketReceiveStatus::Received);
48        } else {
49          continue;
50        }
51      }
52      Err(err) => {
53        if err.kind() == ErrorKind::WouldBlock {
54          continue;
55        } else {
56          return Err(anyhow::Error::from(err));
57        }
58      }
59    }
60  }
61
62  Ok(PacketReceiveStatus::TimedOut)
63}
64
65pub(crate) fn read_packet<A: Serialize + DeserializeOwned>(stream: &mut TcpStream, blocking: bool) -> Result<Option<A>> {
66  let mut buf = [0u8; 5];
67  stream.set_nonblocking(!blocking)?;
68
69  let peek_bytes_res = stream.peek(&mut buf);
70  let peek_bytes = match peek_bytes_res {
71    Ok(peek_bytes) => peek_bytes,
72    Err(err) => {
73      return if err.kind() == ErrorKind::WouldBlock {
74        // We can't peek 8 bytes
75        Ok(None)
76      } else {
77        Result::Err(anyhow::Error::from(err))
78      }
79    }
80  };
81
82  let mut result = Ok(None);
83
84  // The size marker is 4 bytes, if we have more than the size marker then we want to read the
85  // entire packet.
86  if peek_bytes > 4 {
87
88    // We set nonblocking to false so that we can block until the entire packet has been read.
89    stream.set_nonblocking(false)?;
90
91    let bytes = stream.read_u32::<LittleEndian>()? as usize;
92
93    // Initialize a vector with the exact right size for us to read from the packet.
94    let mut packet_bytes = vec![0; bytes];
95
96    match stream.read_exact(&mut packet_bytes) {
97      Ok(_) => {}
98      Err(err) => {
99        let kind = err.kind();
100        return if kind == ErrorKind::WouldBlock {
101          Ok(None)
102        } else {
103          Err(anyhow::Error::from(err))
104        }
105      }
106    }
107
108    let packet = bincode::deserialize(&packet_bytes)?;
109    result = Ok(Some(packet))
110  }
111
112  result
113}
114
115pub(crate) fn write_packet<A: Serialize + DeserializeOwned>(stream: &mut TcpStream, packet: &A) -> Result<()> {
116  let bytes = bincode::serialize(packet)?;
117  stream.set_nonblocking(false)?;
118  stream.write_u32::<LittleEndian>(bytes.len() as u32)?;
119  stream.write_all(&bytes)?;
120  stream.set_nonblocking(true)?;
121  Ok(())
122}