1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
use std::io::{ErrorKind, Read, Write};
use std::net::{SocketAddr, TcpStream};
use std::time::{Duration, Instant};

use anyhow::Result;
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use serde::de::DeserializeOwned;
use serde::Serialize;

pub mod client;
pub mod server;

/// Result of waiting for an incoming packet; returned by `block_until_receive`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum PacketReceiveStatus {
  /// More than the 4-byte length prefix is buffered on the stream, so a packet can be read.
  Received,
  /// The timeout elapsed before a packet became available.
  TimedOut,
}

/// A TCP stream bundled with a socket address.
#[derive(Debug)]
pub struct Connection {
  /// Socket address associated with this connection (presumably the remote peer's — confirm
  /// against the `client`/`server` modules).
  pub addr: SocketAddr,
  /// The underlying TCP socket used for sending and receiving packets.
  pub stream: TcpStream,
}

pub(crate) fn block_until_receive(stream: &mut TcpStream, timeout: Duration) -> Result<PacketReceiveStatus> {
  // Make sure we have a non-blocking TcpStream. We can't use a blocking TcpStream as it does not
  // support timeouts. So we have to poll the stream.
  stream.set_nonblocking(true)?;

  let start_time = Instant::now();

  // The size of the buffer has to be more than 4 bytes, otherwise we can't peek and see if more
  // than 4 bytes are in the buffer.
  // If there are 4 or fewer bytes in the buffer we don't want to read the packet yet because only
  // the size descriptor has been received. We want at least 1 byte of the packet to have been
  // received before we retrieve it.
  let mut buf = [0u8; 5];

  loop {
    if start_time.elapsed() > timeout {
      break;
    }

    match stream.peek(&mut buf) {
      Ok(peeked) => {
        if peeked > 4 {
          return Ok(PacketReceiveStatus::Received);
        } else {
          continue;
        }
      }
      Err(err) => {
        if err.kind() == ErrorKind::WouldBlock {
          continue;
        } else {
          return Err(anyhow::Error::from(err));
        }
      }
    }
  }

  Ok(PacketReceiveStatus::TimedOut)
}

pub(crate) fn read_packet<A: Serialize + DeserializeOwned>(stream: &mut TcpStream, blocking: bool) -> Result<Option<A>> {
  let mut buf = [0u8; 5];
  stream.set_nonblocking(!blocking)?;

  let peek_bytes_res = stream.peek(&mut buf);
  let peek_bytes = match peek_bytes_res {
    Ok(peek_bytes) => peek_bytes,
    Err(err) => {
      return if err.kind() == ErrorKind::WouldBlock {
        // We can't peek 8 bytes
        Ok(None)
      } else {
        Result::Err(anyhow::Error::from(err))
      }
    }
  };

  let mut result = Ok(None);

  // The size marker is 4 bytes, if we have more than the size marker then we want to read the
  // entire packet.
  if peek_bytes > 4 {

    // We set nonblocking to false so that we can block until the entire packet has been read.
    stream.set_nonblocking(false)?;

    let bytes = stream.read_u32::<LittleEndian>()? as usize;

    // Initialize a vector with the exact right size for us to read from the packet.
    let mut packet_bytes = vec![0; bytes];

    match stream.read_exact(&mut packet_bytes) {
      Ok(_) => {}
      Err(err) => {
        let kind = err.kind();
        return if kind == ErrorKind::WouldBlock {
          Ok(None)
        } else {
          Err(anyhow::Error::from(err))
        }
      }
    }

    let packet = bincode::deserialize(&packet_bytes)?;
    result = Ok(Some(packet))
  }

  result
}

pub(crate) fn write_packet<A: Serialize + DeserializeOwned>(stream: &mut TcpStream, packet: &A) -> Result<()> {
  let bytes = bincode::serialize(packet)?;
  stream.set_nonblocking(false)?;
  stream.write_u32::<LittleEndian>(bytes.len() as u32)?;
  stream.write_all(&bytes)?;
  stream.set_nonblocking(true)?;
  Ok(())
}