mod packs;
pub use packs::Packs;
mod looper;
mod metrics;
#[cfg(test)]
mod test;
use std::{
io::{ErrorKind, Read, Write}, net::{SocketAddr, TcpStream}, num::NonZero, sync::{
Arc, Mutex,
mpsc::{Receiver, Sender, TryRecvError, channel},
}, thread::JoinHandle
};
use crate::metrics::Metrics;
/// Three protocol version numbers (`cur`, `min`, `max`) handed to the stream worker.
///
/// `Default` sets all three to `1`; `From<u16>` sets all three to the given value.
/// NOTE(review): presumably `cur` is the version spoken and `min..=max` the range accepted
/// from the peer — inferred from the field names, confirm against `looper`.
///
/// `Eq` is derived together with `Hash` so the type can actually be used as a
/// `HashMap`/`HashSet` key (both require `Eq`, not just `PartialEq`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct Versions {
    cur: u16,
    min: u16,
    max: u16,
}
/// The baseline descriptor: `cur`, `min` and `max` are all `1`.
impl Default for Versions {
    fn default() -> Self {
        // Same as a single-version descriptor for version 1.
        Self::from(1)
    }
}
/// Builds a descriptor pinned to a single version: `cur`, `min` and `max` all equal `value`.
impl From<u16> for Versions {
    fn from(value: u16) -> Self {
        Self::new(value, value, value)
    }
}
impl Versions {
    /// Creates a descriptor from explicit `cur`, `min` and `max` values.
    ///
    /// The values are stored as given; no ordering between them is checked here.
    /// Declared `const` so descriptors can be built in `const`/`static` items.
    pub const fn new(cur: u16, min: u16, max: u16) -> Self {
        Self { cur, min, max }
    }
}
/// Message-oriented, non-blocking wrapper around a [`TcpStream`].
///
/// The socket itself is moved to a background worker thread at construction; this handle
/// only exchanges whole `Vec<u8>` messages and `std::io::Error` faults with that worker
/// over `mpsc` channels.
///
/// `MS` is forwarded to the worker as `NonZero::new(MS)`, so the default `0` becomes `None`.
/// NOTE(review): presumably a message-size limit where `0` means "no limit" — confirm in
/// `looper`.
///
/// Clones share the same worker, the same incoming-message queue and the same fault queue,
/// so each received message or fault goes to whichever clone polls for it first.
#[derive(Clone)]
pub struct NonBlockStream<const MS: usize = 0> {
    /// Messages produced by the worker. `Receiver` is not `Clone`, hence the shared
    /// `Arc<Mutex<_>>` so every clone polls the same receiver.
    rx_reader: Arc<Mutex<Receiver<Vec<u8>>>>,
    /// Outgoing messages consumed by the worker.
    tx_writer: Sender<Vec<u8>>,
    /// Faults reported by the worker; shared across clones like `rx_reader`.
    rx_err: Arc<Mutex<Receiver<std::io::Error>>>,
    /// Local address captured when the stream was wrapped.
    local_addr: SocketAddr,
    /// Peer address captured when the stream was wrapped.
    remote_addr: SocketAddr,
    /// Worker thread handle. No visible code joins it; dropping the last clone drops the
    /// handle, which detaches the thread rather than stopping it.
    _handle: Arc<JoinHandle<()>>,
    /// Counters shared with the worker, exposed via `total_read` / `total_sent`.
    metrics: Metrics,
}
impl From<TcpStream> for NonBlockStream {
fn from(stream: TcpStream) -> Self {
NonBlockStream::from_versions(Versions::default(), stream)
}
}
impl<const MS: usize> NonBlockStream<MS> {
    /// Wraps `stream`, moving it to a newly spawned worker thread.
    ///
    /// The socket is switched to non-blocking mode and moved, together with `v`, `packs`,
    /// the worker-side channel ends and the metrics sender, into a `looper::StreamLooper`
    /// whose `stream_loop` runs on the new thread. `MS` is forwarded as `NonZero::new(MS)`,
    /// i.e. `None` when `MS == 0`.
    ///
    /// # Panics
    ///
    /// Panics if the local or peer address of `stream` cannot be obtained (for example
    /// `peer_addr` fails on a socket that is not connected), if the socket cannot be put
    /// into non-blocking mode, or if the OS fails to spawn the worker thread.
    pub fn from_version_packs(v: Versions, packs: Packs, stream: TcpStream) -> Self {
        let local_addr = stream
            .local_addr()
            .expect("Could not obtain local_addr from stream");
        let remote_addr = stream
            .peer_addr()
            .expect("Could not obtain peer_addr from stream");
        // The worker relies on a non-blocking socket (see the panic message).
        stream
            .set_nonblocking(true)
            .expect("Could not set socket to nonblocking. It is required for communication.");
        let (tx_reader, rx_reader) = channel::<Vec<u8>>();
        let (tx_writer, rx_writer) = channel::<Vec<u8>>();
        let (tx_err, rx_err) = channel::<std::io::Error>();
        let (metrics, metrics_tx) = Metrics::new();
        let looper = looper::StreamLooper::new(
            v,
            NonZero::new(MS),
            packs,
            stream,
            tx_reader,
            rx_writer,
            tx_err,
            metrics_tx,
        );
        let handle = std::thread::spawn(move || {
            looper.stream_loop();
        });
        Self {
            rx_reader: Arc::new(Mutex::new(rx_reader)),
            tx_writer,
            rx_err: Arc::new(Mutex::new(rx_err)),
            local_addr,
            remote_addr,
            _handle: Arc::new(handle),
            metrics,
        }
    }

    /// Same as [`Self::from_version_packs`] with default packs.
    ///
    /// # Panics
    ///
    /// Under the same conditions as [`Self::from_version_packs`].
    pub fn from_versions(v: Versions, stream: TcpStream) -> Self {
        Self::from_version_packs(v, Default::default(), stream)
    }

    /// Local address of the wrapped socket, captured at construction.
    pub fn local_addr(&self) -> SocketAddr {
        self.local_addr
    }

    /// Peer address of the wrapped socket, captured at construction.
    pub fn remote_addr(&self) -> SocketAddr {
        self.remote_addr
    }

    /// Queues `msg` for the worker thread to send. Never blocks.
    ///
    /// Success only means the message was queued, not that it reached the peer.
    ///
    /// # Errors
    ///
    /// Returns a pending fault reported by the worker (consuming it). Returns an error of
    /// kind [`ErrorKind::ConnectionAborted`] once the worker has dropped its channel ends
    /// (e.g. after it has exited).
    pub fn write(&mut self, msg: Vec<u8>) -> Result<(), std::io::Error> {
        self.trap_fault()?;
        self.trap_write(msg)
    }

    /// Returns the next message received by the worker, if one is queued. Never blocks.
    ///
    /// `Ok(None)` means nothing is queued right now.
    ///
    /// # Errors
    ///
    /// A pending fault reported by the worker is returned first, and only once. After the
    /// worker has dropped its fault channel (e.g. on exit), messages it queued before that
    /// point are still returned. Only once they are drained is an error of kind
    /// [`ErrorKind::ConnectionAborted`] reported.
    pub fn read(&mut self) -> Result<Option<Vec<u8>>, std::io::Error> {
        let fault = self.rx_err.lock().unwrap().try_recv();
        match fault {
            Ok(f) => Err(f),
            // Worker still holds its fault sender: just poll the message queue.
            Err(TryRecvError::Empty) => self.trap_recv(),
            // Worker is gone. mpsc keeps messages sent before the disconnect, so hand
            // those out before reporting it. Failing here right away (as before) made
            // them unreachable forever.
            Err(e @ TryRecvError::Disconnected) => match self.trap_recv()? {
                Some(msg) => Ok(Some(msg)),
                None => Err(Self::aborted(e)),
            },
        }
    }

    /// Running "read" total from the metrics shared with the worker.
    /// NOTE(review): presumably bytes received — confirm in `metrics`.
    pub fn total_read(&self) -> usize {
        self.metrics.read()
    }

    /// Running "sent" total from the metrics shared with the worker.
    /// NOTE(review): presumably bytes sent — confirm in `metrics`.
    pub fn total_sent(&self) -> usize {
        self.metrics.sent()
    }

    /// Pushes `msg` onto the worker's outgoing queue.
    /// Fails with `ConnectionAborted` if the worker dropped the receiving end.
    fn trap_write(&self, msg: Vec<u8>) -> Result<(), std::io::Error> {
        self.tx_writer.send(msg).map_err(Self::aborted)
    }

    /// Polls the incoming queue once without blocking.
    ///
    /// `Ok(None)` when it is empty. `ConnectionAborted` once it is empty *and* the worker
    /// has dropped its sending end.
    fn trap_recv(&self) -> Result<Option<Vec<u8>>, std::io::Error> {
        // The guard is a temporary and is released at the end of this statement.
        let polled = self.rx_reader.lock().unwrap().try_recv();
        match polled {
            Ok(msg) => Ok(Some(msg)),
            Err(TryRecvError::Empty) => Ok(None),
            Err(e @ TryRecvError::Disconnected) => Err(Self::aborted(e)),
        }
    }

    /// Polls the fault queue once without blocking.
    ///
    /// Returns a reported fault as the error. Returns `ConnectionAborted` once the worker
    /// has dropped its fault sender.
    fn trap_fault(&self) -> Result<(), std::io::Error> {
        let polled = self.rx_err.lock().unwrap().try_recv();
        match polled {
            Ok(fault) => Err(fault),
            Err(TryRecvError::Empty) => Ok(()),
            Err(e @ TryRecvError::Disconnected) => Err(Self::aborted(e)),
        }
    }

    /// Wraps a channel failure as an I/O error of kind `ConnectionAborted`.
    fn aborted<E>(err: E) -> std::io::Error
    where
        E: Into<Box<dyn std::error::Error + Send + Sync>>,
    {
        std::io::Error::new(ErrorKind::ConnectionAborted, err)
    }
}