use crate::{BufferPool, Error, IoBufs};
use std::{net::SocketAddr, time::Duration};
use tokio::{
io::{AsyncReadExt as _, AsyncWriteExt as _, BufReader},
net::{
tcp::{OwnedReadHalf, OwnedWriteHalf},
TcpListener, TcpStream,
},
time::timeout,
};
use tracing::warn;
pub struct Sink {
write_timeout: Duration,
sink: OwnedWriteHalf,
}
impl Sink {
async fn send_single(&mut self, buf: &[u8]) -> Result<(), Error> {
self.sink
.write_all(buf)
.await
.map_err(|_| Error::SendFailed)
}
async fn send_vectored(&mut self, bufs: &mut IoBufs) -> Result<(), Error> {
self.sink
.write_all_buf(bufs)
.await
.map_err(|_| Error::SendFailed)
}
}
impl crate::Sink for Sink {
async fn send(&mut self, bufs: impl Into<IoBufs> + Send) -> Result<(), Error> {
let write_timeout = self.write_timeout;
let bufs = bufs.into();
let send = async {
match bufs.try_into_single() {
Ok(buf) => self.send_single(buf.as_ref()).await,
Err(mut bufs) => self.send_vectored(&mut bufs).await,
}
};
timeout(write_timeout, send)
.await
.map_err(|_| Error::Timeout)?
}
}
pub struct Stream {
read_timeout: Duration,
stream: BufReader<OwnedReadHalf>,
pool: BufferPool,
}
impl crate::Stream for Stream {
async fn recv(&mut self, len: usize) -> Result<IoBufs, Error> {
let read_fut = async {
let mut buf = unsafe { self.pool.alloc_len(len) };
self.stream
.read_exact(buf.as_mut())
.await
.map_err(|_| Error::RecvFailed)?;
Ok(IoBufs::from(buf.freeze()))
};
timeout(self.read_timeout, read_fut)
.await
.map_err(|_| Error::Timeout)?
}
fn peek(&self, max_len: usize) -> &[u8] {
let buffered = self.stream.buffer();
let len = std::cmp::min(buffered.len(), max_len);
&buffered[..len]
}
}
pub struct Listener {
cfg: Config,
listener: TcpListener,
pool: BufferPool,
}
impl crate::Listener for Listener {
type Sink = Sink;
type Stream = Stream;
async fn accept(&mut self) -> Result<(SocketAddr, Self::Sink, Self::Stream), Error> {
let (stream, addr) = self.listener.accept().await.map_err(|_| Error::Closed)?;
if let Some(tcp_nodelay) = self.cfg.tcp_nodelay {
if let Err(err) = stream.set_nodelay(tcp_nodelay) {
warn!(?err, "failed to set TCP_NODELAY");
}
}
if self.cfg.zero_linger {
if let Err(err) = stream.set_zero_linger() {
warn!(?err, "failed to set SO_LINGER");
}
}
let (stream, sink) = stream.into_split();
Ok((
addr,
Sink {
write_timeout: self.cfg.write_timeout,
sink,
},
Stream {
read_timeout: self.cfg.read_timeout,
stream: BufReader::with_capacity(self.cfg.read_buffer_size, stream),
pool: self.pool.clone(),
},
))
}
fn local_addr(&self) -> Result<SocketAddr, std::io::Error> {
self.listener.local_addr()
}
}
/// Socket options, timeouts and read buffering for TCP connections.
///
/// These settings are applied to every connection created by `Network`,
/// whether dialed or accepted. Start from `Config::default()` and chain the
/// `with_*` builders to override individual settings.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Config {
    /// Value applied with `set_nodelay`. `None` leaves TCP_NODELAY at the OS default.
    tcp_nodelay: Option<bool>,
    /// When `true`, `set_zero_linger` is called on each socket (SO_LINGER with
    /// a zero timeout).
    zero_linger: bool,
    /// Maximum time one `recv` may take to read all requested bytes before
    /// it fails with `Error::Timeout`.
    read_timeout: Duration,
    /// Maximum time one `send` may take before it fails with `Error::Timeout`.
    write_timeout: Duration,
    /// Capacity of the `BufReader` wrapping each read half. `0` disables
    /// read-ahead buffering, so `peek` then always returns an empty slice.
    read_buffer_size: usize,
}

#[cfg_attr(feature = "iouring-network", allow(dead_code))]
impl Config {
    /// Sets the TCP_NODELAY value for new sockets. `None` keeps the OS default.
    #[must_use]
    pub const fn with_tcp_nodelay(mut self, tcp_nodelay: Option<bool>) -> Self {
        self.tcp_nodelay = tcp_nodelay;
        self
    }

    /// Enables or disables setting a zero SO_LINGER on new sockets.
    #[must_use]
    pub const fn with_zero_linger(mut self, zero_linger: bool) -> Self {
        self.zero_linger = zero_linger;
        self
    }

    /// Sets the per-`recv` timeout.
    #[must_use]
    pub const fn with_read_timeout(mut self, read_timeout: Duration) -> Self {
        self.read_timeout = read_timeout;
        self
    }

    /// Sets the per-`send` timeout.
    #[must_use]
    pub const fn with_write_timeout(mut self, write_timeout: Duration) -> Self {
        self.write_timeout = write_timeout;
        self
    }

    /// Sets the read buffer capacity. `0` disables read-ahead buffering.
    #[must_use]
    pub const fn with_read_buffer_size(mut self, read_buffer_size: usize) -> Self {
        self.read_buffer_size = read_buffer_size;
        self
    }

    /// Configured TCP_NODELAY value (`None` = OS default).
    pub const fn tcp_nodelay(&self) -> Option<bool> {
        self.tcp_nodelay
    }

    /// Whether a zero SO_LINGER is applied to new sockets.
    pub const fn zero_linger(&self) -> bool {
        self.zero_linger
    }

    /// Per-`recv` timeout.
    pub const fn read_timeout(&self) -> Duration {
        self.read_timeout
    }

    /// Per-`send` timeout.
    pub const fn write_timeout(&self) -> Duration {
        self.write_timeout
    }

    /// Read buffer capacity in bytes.
    pub const fn read_buffer_size(&self) -> usize {
        self.read_buffer_size
    }
}

impl Default for Config {
    /// Returns the default settings:
    /// - TCP_NODELAY enabled;
    /// - zero SO_LINGER;
    /// - 60 s read and write timeouts;
    /// - a 64 KiB read buffer.
    fn default() -> Self {
        Self {
            tcp_nodelay: Some(true),
            zero_linger: true,
            read_timeout: Duration::from_secs(60),
            write_timeout: Duration::from_secs(60),
            read_buffer_size: 64 * 1024,
        }
    }
}
#[derive(Clone)]
pub struct Network {
cfg: Config,
pool: BufferPool,
}
impl Network {
pub const fn new(cfg: Config, pool: BufferPool) -> Self {
Self { cfg, pool }
}
}
impl crate::Network for Network {
type Listener = Listener;
async fn bind(&self, socket: SocketAddr) -> Result<Self::Listener, crate::Error> {
TcpListener::bind(socket)
.await
.map_err(|_| Error::BindFailed)
.map(|listener| Listener {
cfg: self.cfg.clone(),
listener,
pool: self.pool.clone(),
})
}
async fn dial(
&self,
socket: SocketAddr,
) -> Result<(crate::SinkOf<Self>, crate::StreamOf<Self>), crate::Error> {
let stream = TcpStream::connect(socket)
.await
.map_err(|_| Error::ConnectionFailed)?;
if let Some(tcp_nodelay) = self.cfg.tcp_nodelay {
if let Err(err) = stream.set_nodelay(tcp_nodelay) {
warn!(?err, "failed to set TCP_NODELAY");
}
}
if self.cfg.zero_linger {
if let Err(err) = stream.set_zero_linger() {
warn!(?err, "failed to set SO_LINGER");
}
}
let (stream, sink) = stream.into_split();
Ok((
Sink {
write_timeout: self.cfg.write_timeout,
sink,
},
Stream {
read_timeout: self.cfg.read_timeout,
stream: BufReader::with_capacity(self.cfg.read_buffer_size, stream),
pool: self.pool.clone(),
},
))
}
}
#[cfg(test)]
mod tests {
    use crate::{
        network::{tests, tokio as TokioNetwork},
        BufferPool, BufferPoolConfig, Listener as _, Network as _, Sink as _, Stream as _,
    };
    use commonware_macros::test_group;
    use prometheus_client::registry::Registry;
    use std::time::{Duration, Instant};

    /// Fresh network-sized buffer pool backed by a throwaway metrics registry.
    fn test_pool() -> BufferPool {
        BufferPool::new(BufferPoolConfig::for_network(), &mut Registry::default())
    }

    /// Builds a tokio network from `cfg` with a fresh test pool.
    fn network_from(cfg: TokioNetwork::Config) -> TokioNetwork::Network {
        TokioNetwork::Network::new(cfg, test_pool())
    }

    /// Builds a tokio network with default settings apart from the given timeouts.
    fn network_with(read_timeout: Duration, write_timeout: Duration) -> TokioNetwork::Network {
        network_from(
            TokioNetwork::Config::default()
                .with_read_timeout(read_timeout)
                .with_write_timeout(write_timeout),
        )
    }

    #[tokio::test]
    async fn test_trait() {
        let timeout = Duration::from_secs(15);
        tests::test_network_trait(|| network_with(timeout, timeout)).await;
    }

    #[test_group("slow")]
    #[tokio::test]
    async fn test_stress_trait() {
        let timeout = Duration::from_secs(15);
        tests::stress_test_network_trait(|| network_with(timeout, timeout)).await;
    }

    #[tokio::test]
    async fn test_small_send_read_quickly() {
        let read_timeout = Duration::from_secs(30);
        let network = network_with(read_timeout, Duration::from_secs(5));
        let mut listener = network.bind("127.0.0.1:0".parse().unwrap()).await.unwrap();
        let bound = listener.local_addr().unwrap();

        let reader = tokio::spawn(async move {
            // Keep the accepted write half alive for the duration of the read.
            let (_peer, _writer, mut incoming) = listener.accept().await.unwrap();
            let began = Instant::now();
            let data = incoming.recv(10).await.unwrap();
            (data, began.elapsed())
        });

        let (mut outgoing, _reader_half) = network.dial(bound).await.unwrap();
        let payload = vec![1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        outgoing.send(payload.clone()).await.unwrap();

        let (data, took) = reader.await.unwrap();
        assert_eq!(data.coalesce(), payload.as_slice());
        assert!(took < read_timeout);
    }

    #[tokio::test]
    async fn test_read_timeout_with_partial_data() {
        let read_timeout = Duration::from_millis(100);
        let network = network_with(read_timeout, Duration::from_secs(5));
        let mut listener = network.bind("127.0.0.1:0".parse().unwrap()).await.unwrap();
        let bound = listener.local_addr().unwrap();

        let reader = tokio::spawn(async move {
            let (_peer, _writer, mut incoming) = listener.accept().await.unwrap();
            let began = Instant::now();
            // Only 5 of the requested 100 bytes will ever arrive.
            let outcome = incoming.recv(100).await;
            (outcome, began.elapsed())
        });

        let (mut outgoing, _reader_half) = network.dial(bound).await.unwrap();
        outgoing.send([1u8, 2, 3, 4, 5].as_slice()).await.unwrap();

        let (outcome, took) = reader.await.unwrap();
        assert!(matches!(outcome, Err(crate::Error::Timeout)));
        assert!(took >= read_timeout);
        assert!(took < read_timeout * 2);
    }

    #[tokio::test]
    async fn test_unbuffered_mode() {
        let network = network_from(
            TokioNetwork::Config::default()
                .with_read_buffer_size(0)
                .with_read_timeout(Duration::from_secs(5))
                .with_write_timeout(Duration::from_secs(5)),
        );
        let mut listener = network.bind("127.0.0.1:0".parse().unwrap()).await.unwrap();
        let bound = listener.local_addr().unwrap();

        let reader = tokio::spawn(async move {
            let (_peer, _writer, mut incoming) = listener.accept().await.unwrap();
            // With a zero-capacity buffer nothing is ever read ahead.
            assert!(incoming.peek(100).is_empty());
            let head = incoming.recv(5).await.unwrap();
            assert!(incoming.peek(100).is_empty());
            let tail = incoming.recv(5).await.unwrap();
            assert!(incoming.peek(100).is_empty());
            (head, tail)
        });

        let (mut outgoing, _reader_half) = network.dial(bound).await.unwrap();
        outgoing.send([1u8, 2, 3, 4, 5].as_slice()).await.unwrap();
        outgoing.send([6u8, 7, 8, 9, 10].as_slice()).await.unwrap();

        let (head, tail) = reader.await.unwrap();
        assert_eq!(head.coalesce(), &[1u8, 2, 3, 4, 5]);
        assert_eq!(tail.coalesce(), &[6u8, 7, 8, 9, 10]);
    }

    #[tokio::test]
    async fn test_peek_with_buffered_data() {
        let network = network_with(Duration::from_secs(5), Duration::from_secs(5));
        let mut listener = network.bind("127.0.0.1:0".parse().unwrap()).await.unwrap();
        let bound = listener.local_addr().unwrap();

        let reader = tokio::spawn(async move {
            let (_peer, _writer, mut incoming) = listener.accept().await.unwrap();
            assert!(incoming.peek(100).is_empty());

            let greeting = incoming.recv(5).await.unwrap();
            assert_eq!(greeting.coalesce(), b"hello");

            // The remainder of the single write should now be read ahead.
            let ahead = incoming.peek(100);
            assert!(!ahead.is_empty());
            assert_eq!(ahead, b" world");
            assert_eq!(incoming.peek(100), b" world");
            assert_eq!(incoming.peek(3), b" wo");

            let remainder = incoming.recv(6).await.unwrap();
            assert_eq!(remainder.coalesce(), b" world");
            assert!(incoming.peek(100).is_empty());
        });

        let (mut outgoing, _reader_half) = network.dial(bound).await.unwrap();
        outgoing.send(b"hello world").await.unwrap();
        reader.await.unwrap();
    }
}