use std::{
io::{Cursor, ErrorKind, Read, Write},
mem,
net::{SocketAddr, TcpStream},
sync::{
Arc, Mutex,
mpsc::{Receiver, Sender, TryRecvError, channel},
},
thread::JoinHandle,
time::Duration,
};
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use mio::{Events, Interest, Poll, Token};
use serde::{Deserialize, Serialize, de::DeserializeOwned};
/// Marker trait for values that can be sent through a [`NonBlockStream`].
///
/// It declares no methods; it only bundles the bounds
/// `Serialize + DeserializeOwned + Send + Sync + 'static`.
pub trait Message: Serialize + DeserializeOwned + Send + Sync + 'static {}

// Blanket implementation: every type that is serializable, deserializable for any
// lifetime, `Send`, `Sync` and `'static` is automatically a `Message`.
impl<T> Message for T where T: Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static {}
/// Handle to a TCP connection that is serviced by a background thread.
///
/// Messages are exchanged with that thread over `std::sync::mpsc` channels, so the
/// methods on this handle never touch the socket directly. Cloning the handle shares
/// the same receivers (they sit behind `Arc<Mutex<..>>`) and the same worker thread.
#[derive(Clone)]
pub struct NonBlockStream<M: Message> {
/// Receiving end of the channel carrying decoded inbound messages.
rx_reader: Arc<Mutex<Receiver<Box<M>>>>,
/// Sending end of the channel carrying messages queued for transmission.
tx_writer: Sender<Box<M>>,
/// Receiving end of the channel on which the worker reports the error that stopped it.
rx_err: Arc<Mutex<Receiver<std::io::Error>>>,
/// Local socket address, captured when the handle was created.
local_addr: SocketAddr,
/// Peer socket address, captured when the handle was created.
remote_addr: SocketAddr,
/// Join handle of the worker thread; stored but never joined in this file.
_handle: Arc<JoinHandle<()>>,
}
/// Reason a step of the stream state machine returned early.
enum ShortCircuit {
    /// No progress is possible right now; try again later.
    Yield,
    /// An I/O error occurred.
    Err(std::io::Error),
}

/// Lets `?` turn an `std::io::Error` into [`ShortCircuit::Err`].
impl From<std::io::Error> for ShortCircuit {
    fn from(value: std::io::Error) -> Self {
        Self::Err(value)
    }
}
impl<M: Message> From<TcpStream> for NonBlockStream<M> {
    /// Takes ownership of `stream`, switches it to non-blocking mode and spawns the
    /// worker thread that runs `StreamLooper::stream_loop` on it.
    ///
    /// # Panics
    ///
    /// Panics if the socket cannot be made non-blocking or if its local or peer
    /// address cannot be queried.
    fn from(stream: TcpStream) -> Self {
        stream
            .set_nonblocking(true)
            .expect("Could not set socket to nonblocking. It is required for communication.");
        let local_addr = stream
            .local_addr()
            .expect("Could not obtain local_addr from stream");
        let remote_addr = stream
            .peer_addr()
            .expect("Could not obtain peer_addr from stream");

        // Three channels connect this handle to the worker: inbound messages,
        // outbound messages and the worker's terminal error.
        let (inbound_tx, inbound_rx) = channel::<Box<M>>();
        let (outbound_tx, outbound_rx) = channel::<Box<M>>();
        let (fault_tx, fault_rx) = channel::<std::io::Error>();

        let worker = StreamLooper::<M>::new(stream, inbound_tx, outbound_rx, fault_tx);
        let handle = std::thread::spawn(move || worker.stream_loop());

        Self {
            rx_reader: Arc::new(Mutex::new(inbound_rx)),
            tx_writer: outbound_tx,
            rx_err: Arc::new(Mutex::new(fault_rx)),
            local_addr,
            remote_addr,
            _handle: Arc::new(handle),
        }
    }
}
impl<M: Message> NonBlockStream<M> {
    /// Local socket address recorded when the handle was created.
    pub fn local_addr(&self) -> SocketAddr {
        self.local_addr
    }

    /// Peer socket address recorded when the handle was created.
    pub fn remote_addr(&self) -> SocketAddr {
        self.remote_addr
    }

    /// Queues `msg` for transmission by the worker thread.
    ///
    /// # Errors
    ///
    /// Returns the worker's reported error if one is pending, or a
    /// `ConnectionAborted` error if the worker has terminated and the message
    /// could therefore not be queued.
    pub fn write(&mut self, msg: Box<M>) -> Result<(), std::io::Error> {
        self.trap_fault()?;
        // A failed send means the worker dropped its receiver; report it instead of
        // silently discarding the message.
        self.tx_writer.send(msg).map_err(|_| {
            std::io::Error::new(
                ErrorKind::ConnectionAborted,
                "stream worker has terminated; message was not queued",
            )
        })
    }

    /// Returns the next received message, or `Ok(None)` if none is available yet.
    ///
    /// Messages that are already queued are returned before any pending worker
    /// error is reported, so data received ahead of a failure is not lost.
    ///
    /// # Errors
    ///
    /// Once the queue is empty, returns the worker's reported error if one is
    /// pending, or a `ConnectionAborted` error if the worker has terminated.
    pub fn read(&mut self) -> Result<Option<Box<M>>, std::io::Error> {
        // The lock guard is a temporary and is released at the end of this statement.
        let fetch = self.rx_reader.lock().unwrap().try_recv();
        match fetch {
            Ok(msg) => Ok(Some(msg)),
            Err(TryRecvError::Empty) => {
                self.trap_fault()?;
                Ok(None)
            }
            Err(e @ TryRecvError::Disconnected) => {
                // Prefer the worker's own error over the generic disconnect.
                self.trap_fault()?;
                Err(std::io::Error::new(ErrorKind::ConnectionAborted, e))
            }
        }
    }

    /// Checks the error channel without blocking.
    ///
    /// Returns `Ok(())` when no error is pending, the reported error when there is
    /// one, and a `ConnectionAborted` error when the sending side has been dropped.
    fn trap_fault(&mut self) -> Result<(), std::io::Error> {
        let fetch = self.rx_err.lock().unwrap().try_recv();
        match fetch {
            Ok(fault) => Err(fault),
            Err(TryRecvError::Empty) => Ok(()),
            Err(e @ TryRecvError::Disconnected) => {
                Err(std::io::Error::new(ErrorKind::ConnectionAborted, e))
            }
        }
    }
}
/// bincode configuration used for every message in this file: the standard
/// configuration switched to big-endian byte order and fixed-width integer encoding.
const BINCODE_CONF: bincode::config::Configuration<
bincode::config::BigEndian,
bincode::config::Fixint,
> = bincode::config::standard()
.with_big_endian()
.with_fixed_int_encoding();
/// Fixed-size frame header that precedes every encoded message on the wire.
///
/// Wire layout (10 bytes, big-endian): bytes `0..2` hold `version`, bytes `2..10`
/// hold `size`.
struct MessageHeader {
    /// Header format version; `from_slice` always writes `1`.
    version: u16,
    /// Length in bytes of the payload that follows the header.
    size: u64,
}

impl MessageHeader {
    /// Builds the header describing `bytes` as a payload: version `1` and
    /// `size` equal to `bytes.len()`.
    fn from_slice(bytes: &[u8]) -> Self {
        Self {
            version: 1,
            // `usize` is at most 64 bits wide on supported targets, so this is lossless.
            size: bytes.len() as u64,
        }
    }
}

/// Serializes the header into its 10-byte big-endian wire form.
impl From<MessageHeader> for [u8; 10] {
    fn from(value: MessageHeader) -> Self {
        // Fixed-size array plus `to_be_bytes`: no heap allocation and no fallible step.
        let mut out = [0u8; 10];
        out[..2].copy_from_slice(&value.version.to_be_bytes());
        out[2..].copy_from_slice(&value.size.to_be_bytes());
        out
    }
}

/// Parses a header from its 10-byte big-endian wire form.
impl From<[u8; 10]> for MessageHeader {
    fn from(value: [u8; 10]) -> Self {
        // Destructure into the 2-byte version and the remaining 8-byte size.
        let [v0, v1, size @ ..] = value;
        Self {
            version: u16::from_be_bytes([v0, v1]),
            size: u64::from_be_bytes(size),
        }
    }
}
/// State of the worker that moves messages between the socket and the channels.
///
/// Reading and writing are each tracked as a small state machine (`reading` /
/// `writing` flag, buffer, position and target length) so that a partially
/// transferred payload can be resumed on a later call.
struct StreamLooper<M: Message> {
/// The socket, wrapped as a mio stream.
stream: mio::net::TcpStream,
/// Channel on which decoded inbound messages are delivered.
tx_reader: Sender<Box<M>>,
/// Channel from which outbound messages are taken.
rx_writer: Receiver<Box<M>>,
/// Channel on which the error that ends the loop is reported.
tx_term: Sender<std::io::Error>,
/// `true` while a payload has been started but not yet fully received.
reading: bool,
/// Buffer receiving the current inbound payload.
read_buf: Vec<u8>,
/// Number of payload bytes received so far.
read_pos: usize,
/// Total length of the current inbound payload, taken from its header.
read_target: usize,
/// `true` while a payload has been started but not yet fully sent.
writing: bool,
/// Encoded bytes of the current outbound payload.
write_buf: Vec<u8>,
/// Number of payload bytes sent so far.
write_pos: usize,
/// Total length of the current outbound payload.
write_target: usize,
}
impl<M: Message> Drop for StreamLooper<M> {
    /// Shuts down both directions of the socket; a failure to do so is ignored.
    fn drop(&mut self) {
        self.stream.shutdown(std::net::Shutdown::Both).ok();
    }
}
impl<M: Message> StreamLooper<M> {
/// Timeout passed to each `poll` call in `stream_loop` (1000 ms).
const MAX_WAIT: Option<Duration> = Some(Duration::from_millis(1000));
/// Creates a looper around `stream` with both the read and the write state
/// machine idle and their buffers empty.
///
/// `stream` is converted with `mio::net::TcpStream::from_std`.
fn new(
    stream: TcpStream,
    tx_reader: Sender<Box<M>>,
    rx_writer: Receiver<Box<M>>,
    tx_term: Sender<std::io::Error>,
) -> Self {
    let stream = mio::net::TcpStream::from_std(stream);
    Self {
        stream,
        tx_reader,
        rx_writer,
        tx_term,
        // Read side: idle.
        reading: false,
        read_buf: Vec::new(),
        read_pos: 0,
        read_target: 0,
        // Write side: idle.
        writing: false,
        write_buf: Vec::new(),
        write_pos: 0,
        write_target: 0,
    }
}
/// Runs the worker loop until an error occurs, consuming the looper.
///
/// The socket is registered with a fresh `Poll` for readable and writable
/// interest, then `try_process_buffers` is called repeatedly. The error that ends
/// the loop is sent on `tx_term`; the result of that send is ignored.
///
/// Setup failures (`Poll::new`, `register`) and `poll` failures are reported as
/// `ConnectionAborted` errors wrapping the underlying error. A `poll` call that
/// fails with `ErrorKind::Interrupted` is retried instead of ending the loop.
fn stream_loop(mut self) {
    let mut poll = match Poll::new() {
        Ok(poll) => poll,
        Err(e) => {
            let _ = self
                .tx_term
                .send(std::io::Error::new(ErrorKind::ConnectionAborted, e));
            return;
        }
    };
    if let Err(e) = poll.registry().register(
        &mut self.stream,
        Token(0),
        Interest::READABLE | Interest::WRITABLE,
    ) {
        let _ = self
            .tx_term
            .send(std::io::Error::new(ErrorKind::ConnectionAborted, e));
        return;
    }
    let mut events = Events::with_capacity(1024);
    loop {
        match self.try_process_buffers() {
            Ok(()) => {}
            Err(ShortCircuit::Err(e)) => {
                let _ = self.tx_term.send(e);
                return;
            }
            Err(ShortCircuit::Yield) => {
                // NOTE(review): nothing in this loop clears `events`, so once a poll
                // has reported an event `events.is_empty()` stays false, later yields
                // skip this wait and the outer loop spins. Letting `poll` block here
                // would need a way to wake it when a message is queued on
                // `rx_writer` — confirm the intended design before changing this.
                while events.is_empty() {
                    match poll.poll(&mut events, Self::MAX_WAIT) {
                        Ok(()) => {}
                        // A signal interrupted the wait; that is not a connection failure.
                        Err(e) if e.kind() == ErrorKind::Interrupted => {}
                        Err(e) => {
                            let _ = self
                                .tx_term
                                .send(std::io::Error::new(ErrorKind::ConnectionAborted, e));
                            return;
                        }
                    }
                }
            }
        }
    }
}
/// Runs one read step followed by one write step.
///
/// Returns `Err(ShortCircuit::Yield)` only when both steps yielded. An I/O error
/// from either step is returned, the read error taking precedence. Every other
/// combination is `Ok(())`.
fn try_process_buffers(&mut self) -> Result<(), ShortCircuit> {
    // Tuple fields are evaluated left to right: the read step runs first.
    match (self.read(), self.write()) {
        (Err(ShortCircuit::Yield), Err(ShortCircuit::Yield)) => Err(ShortCircuit::Yield),
        (Err(ShortCircuit::Err(e)), _) => Err(ShortCircuit::Err(e)),
        (_, Err(ShortCircuit::Err(e))) => Err(ShortCircuit::Err(e)),
        _ => Ok(()),
    }
}
/// Advances the read state machine by one step: resumes the payload in progress
/// if there is one, otherwise tries to start a new message.
fn read(&mut self) -> Result<(), ShortCircuit> {
    if self.reading {
        self.read_continue()
    } else {
        self.read_start()
    }
}
fn read_start(&mut self) -> Result<(), ShortCircuit> {
let Some(header) = self.check_for_header()? else {
return Ok(());
};
let size = header.size as usize;
self.reading = true;
self.read_buf = vec![0; size];
self.read_target = size;
self.read_pos = 0;
self.read_continue()?;
Ok(())
}
/// Reads more bytes of the current payload and finishes the message once
/// `read_pos` reaches `read_target`.
///
/// # Errors
///
/// * `ShortCircuit::Yield` when the socket has no data available right now.
/// * An `UnexpectedEof` error when the socket reports end-of-stream before the
///   payload is complete.
/// * Any other read error, unchanged.
///
/// A read that fails with `ErrorKind::Interrupted` returns `Ok(())` so the caller
/// simply retries on its next step.
fn read_continue(&mut self) -> Result<(), ShortCircuit> {
    // A zero-length payload has nothing to read and falls through to `read_end`.
    if self.read_pos < self.read_target {
        let buf = &mut self.read_buf[self.read_pos..self.read_target];
        match self.stream.read(buf) {
            // With a non-empty buffer, zero bytes means the stream has ended;
            // without this check the loop would retry forever.
            Ok(0) => {
                return Err(std::io::Error::new(
                    ErrorKind::UnexpectedEof,
                    "connection closed in the middle of a message",
                )
                .into());
            }
            Ok(n) => self.read_pos += n,
            Err(e) => match e.kind() {
                ErrorKind::WouldBlock => return Err(ShortCircuit::Yield),
                ErrorKind::Interrupted => return Ok(()),
                _ => return Err(e.into()),
            },
        }
    }
    if self.read_pos == self.read_target {
        self.read_end()?;
    }
    Ok(())
}
/// Finishes the current inbound message: resets the read state, decodes the
/// buffered payload with bincode and forwards the message on `tx_reader`.
///
/// The result of the channel send is ignored.
///
/// # Errors
///
/// Returns a `ConnectionAborted` error wrapping the bincode error when the
/// payload cannot be decoded.
fn read_end(&mut self) -> Result<(), ShortCircuit> {
    // Take the payload out, leaving an empty buffer behind.
    let payload = mem::take(&mut self.read_buf);
    self.read_pos = 0;
    self.read_target = 0;
    self.reading = false;

    let (decoded, _) = bincode::serde::decode_from_slice::<M, _>(&payload, BINCODE_CONF)
        .map_err(|e| std::io::Error::new(ErrorKind::ConnectionAborted, e))?;
    let _ = self.tx_reader.send(Box::new(decoded));
    Ok(())
}
/// Consumes and returns the next 10-byte message header once all of it is
/// available on the socket.
///
/// The socket is peeked first so that an incomplete header is left unread.
///
/// # Errors
///
/// * `ShortCircuit::Yield` when no data, or only part of a header, is available.
/// * A `ConnectionAborted` error when the peek reports end-of-stream (zero bytes)
///   or fails for any reason other than `WouldBlock`.
/// * The underlying error if consuming the peeked header fails.
fn check_for_header(&mut self) -> Result<Option<MessageHeader>, ShortCircuit> {
    let mut buf = [0; 10];
    match self.stream.peek(&mut buf) {
        // Zero bytes from a peek into a non-empty buffer means the stream has ended;
        // report it instead of retrying forever.
        Ok(0) => Err(std::io::Error::new(
            ErrorKind::ConnectionAborted,
            "connection closed by peer",
        )
        .into()),
        // Partial header: nothing can be done until more bytes arrive.
        Ok(n) if n < buf.len() => Err(ShortCircuit::Yield),
        Ok(_) => {
            // The peek saw a complete header, so this read consumes exactly those bytes.
            self.stream.read_exact(&mut buf)?;
            Ok(Some(buf.into()))
        }
        Err(e) => match e.kind() {
            ErrorKind::WouldBlock => Err(ShortCircuit::Yield),
            _ => Err(std::io::Error::new(ErrorKind::ConnectionAborted, e).into()),
        },
    }
}
/// Advances the write state machine by one step: resumes the payload in progress
/// if there is one, otherwise tries to start sending a new message.
fn write(&mut self) -> Result<(), ShortCircuit> {
    if self.writing {
        self.write_continue()
    } else {
        self.write_start()
    }
}
fn write_start(&mut self) -> Result<(), ShortCircuit> {
let fetch = self.rx_writer.try_recv();
let msg = match fetch {
Ok(msg) => msg,
Err(e) => match e {
TryRecvError::Empty => return Err(ShortCircuit::Yield),
TryRecvError::Disconnected => {
return Err(std::io::Error::new(ErrorKind::ConnectionAborted, e).into());
}
},
};
let buf = match bincode::serde::encode_to_vec(msg, BINCODE_CONF) {
Ok(buf) => buf,
Err(e) => return Err(std::io::Error::new(ErrorKind::ConnectionAborted, e).into()),
};
self.write_buf = buf;
self.writing = true;
self.write_pos = 0;
self.write_target = self.write_buf.len();
self.write_header()?;
self.write_continue()?;
Ok(())
}
/// Writes the 10-byte header describing the current `write_buf` to the socket
/// via `write_all_blocking`.
fn write_header(&mut self) -> Result<(), ShortCircuit> {
    let header = MessageHeader::from_slice(&self.write_buf);
    let bytes = <[u8; 10]>::from(header);
    self.write_all_blocking(&bytes)
}
/// Writes all of `buf` to the socket, retrying until every byte is accepted.
///
/// `WouldBlock` is retried after yielding the CPU to other threads, and
/// `Interrupted` is retried immediately.
///
/// NOTE(review): while this loop waits for the socket to accept data, no reads
/// are serviced by this looper — confirm that is acceptable for callers.
///
/// # Errors
///
/// * A `BrokenPipe` error when the socket accepts zero bytes.
/// * A `ConnectionAborted` error wrapping any other write error.
fn write_all_blocking(&mut self, mut buf: &[u8]) -> Result<(), ShortCircuit> {
    while !buf.is_empty() {
        match self.stream.write(buf) {
            Ok(0) => {
                return Err(std::io::Error::new(ErrorKind::BrokenPipe, "").into());
            }
            Ok(n) => buf = &buf[n..],
            Err(e) => match e.kind() {
                // Give other threads a chance to run instead of hot-spinning.
                ErrorKind::WouldBlock => std::thread::yield_now(),
                // Non-fatal: retry immediately.
                ErrorKind::Interrupted => {}
                _ => return Err(std::io::Error::new(ErrorKind::ConnectionAborted, e).into()),
            },
        }
    }
    Ok(())
}
/// Writes more bytes of the current payload and finishes the message once
/// `write_pos` reaches `write_target`.
///
/// # Errors
///
/// * `ShortCircuit::Yield` when the socket cannot accept data right now.
/// * A `BrokenPipe` error when the socket accepts zero bytes of a non-empty
///   remainder, matching how `write_all_blocking` treats that case.
/// * A `ConnectionAborted` error wrapping any other write error.
///
/// A write that fails with `ErrorKind::Interrupted` returns `Ok(())` so the
/// caller simply retries on its next step.
fn write_continue(&mut self) -> Result<(), ShortCircuit> {
    // A zero-length payload has nothing to write and falls through to `write_end`.
    if self.write_pos < self.write_target {
        let buf = &self.write_buf[self.write_pos..self.write_target];
        match self.stream.write(buf) {
            // Zero bytes accepted for a non-empty buffer would otherwise be
            // retried forever without progress.
            Ok(0) => {
                return Err(std::io::Error::new(
                    ErrorKind::BrokenPipe,
                    "socket accepted zero bytes",
                )
                .into());
            }
            Ok(n) => self.write_pos += n,
            Err(e) => match e.kind() {
                ErrorKind::WouldBlock => return Err(ShortCircuit::Yield),
                ErrorKind::Interrupted => return Ok(()),
                _ => return Err(std::io::Error::new(ErrorKind::ConnectionAborted, e).into()),
            },
        }
    }
    if self.write_pos == self.write_target {
        self.write_end();
    }
    Ok(())
}
/// Returns the write state machine to idle and replaces the payload buffer with
/// a fresh empty one.
fn write_end(&mut self) {
    self.write_buf = Vec::new();
    self.write_target = 0;
    self.write_pos = 0;
    self.writing = false;
}
}
#[cfg(test)]
mod test {
use super::*;
use serde::{Deserialize, Serialize};
use serial_test::serial;
use std::{
net::{SocketAddr, TcpListener, TcpStream},
thread::sleep,
time::Duration,
};
/// Test message type: a single byte-vector payload.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
struct Msg {
/// Payload bytes.
data: Vec<u8>,
}
/// Payload length in bytes used by `test_big_payload` (10,000,000).
const TEST_SIZE: usize = 10_000_000;
/// Sends one `TEST_SIZE`-byte message from the accepted side of a connection and
/// checks that the connecting side receives it unchanged.
#[test]
#[serial]
fn test_big_payload() {
    let (mut host, mut client) = build_stream_pair();
    let sent = Msg {
        data: vec![1; TEST_SIZE],
    };
    host.write(Box::new(sent.clone())).unwrap();
    let received = wait_msg(&mut client).unwrap();
    assert_eq!(sent, *received);
}
/// Sends the same 100-byte message over two independent connections, ten times
/// in a row, and checks that each client receives it every time.
#[test]
#[serial]
fn test_multichannels() {
    let (mut host1, mut host2, mut client1, mut client2) = build_stream_triple();
    for _ in 0..10 {
        let sent = Msg { data: vec![1; 100] };
        host1.write(Box::new(sent.clone())).unwrap();
        host2.write(Box::new(sent.clone())).unwrap();
        let received1 = wait_msg(&mut client1).unwrap();
        let received2 = wait_msg(&mut client2).unwrap();
        assert_eq!(sent, *received1);
        assert_eq!(sent, *received2);
    }
}
/// Polls `c` for a message, sleeping 100 ms before each attempt.
///
/// Makes one initial attempt plus up to 100 retries and returns `None` if no
/// message arrived. Panics if `read` returns an error.
fn wait_msg(c: &mut NonBlockStream<Msg>) -> Option<Box<Msg>> {
    for _ in 0..=100 {
        sleep(Duration::from_millis(100));
        if let Some(msg) = c.read().unwrap() {
            return Some(msg);
        }
    }
    None
}
fn build_stream_pair() -> (NonBlockStream<Msg>, NonBlockStream<Msg>) {
let p = find_port();
let s = SocketAddr::from(([127, 0, 0, 1], p));
let l = TcpListener::bind(s).unwrap();
let c = TcpStream::connect(s).unwrap();
let (h, _) = l.accept().unwrap();
(h.into(), c.into())
}
fn build_stream_triple() -> (
NonBlockStream<Msg>,
NonBlockStream<Msg>,
NonBlockStream<Msg>,
NonBlockStream<Msg>,
) {
let p = find_port();
let s = SocketAddr::from(([127, 0, 0, 1], p));
let l = TcpListener::bind(s).unwrap();
let c1 = TcpStream::connect(s).unwrap();
let (h_to_c1, _) = l.accept().unwrap();
let c2 = TcpStream::connect(s).unwrap();
let (h_to_c2, _) = l.accept().unwrap();
(h_to_c1.into(), h_to_c2.into(), c1.into(), c2.into())
}
/// Returns a loopback TCP port that was free at the time of the call.
///
/// Binds a throwaway listener to port 0 so the operating system picks an unused
/// port, instead of probing a fixed range one port at a time. The listener is
/// dropped before returning so the caller can bind the port itself.
///
/// Panics if no listener can be bound on the loopback interface.
fn find_port() -> u16 {
    TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0)))
        .and_then(|probe| probe.local_addr())
        .map(|addr| addr.port())
        .expect("no free loopback port available")
}
}