use crate::Error;
use bytes::{Buf, BufMut};
use std::{net::SocketAddr, time::Duration};
use tokio::{
io::{AsyncReadExt as _, AsyncWriteExt as _, BufReader},
net::{
tcp::{OwnedReadHalf, OwnedWriteHalf},
TcpListener, TcpStream,
},
time::timeout,
};
use tracing::warn;
pub struct Sink {
write_timeout: Duration,
sink: OwnedWriteHalf,
}
impl crate::Sink for Sink {
async fn send(&mut self, mut msg: impl Buf + Send) -> Result<(), Error> {
timeout(self.write_timeout, self.sink.write_all_buf(&mut msg))
.await
.map_err(|_| Error::Timeout)?
.map_err(|_| Error::SendFailed)?;
Ok(())
}
}
pub struct Stream {
read_timeout: Duration,
stream: BufReader<OwnedReadHalf>,
}
impl crate::Stream for Stream {
async fn recv(&mut self, mut buf: impl BufMut + Send) -> Result<(), Error> {
let read_fut = async {
let mut read = 0;
let len = buf.remaining_mut();
while read < len {
let n = self
.stream
.read_buf(&mut buf)
.await
.map_err(|_| Error::RecvFailed)?;
if n == 0 {
return Err(Error::RecvFailed);
}
read += n;
}
Ok(())
};
timeout(self.read_timeout, read_fut)
.await
.map_err(|_| Error::Timeout)?
}
}
pub struct Listener {
cfg: Config,
listener: TcpListener,
}
impl crate::Listener for Listener {
type Sink = Sink;
type Stream = Stream;
async fn accept(&mut self) -> Result<(SocketAddr, Self::Sink, Self::Stream), Error> {
let (stream, addr) = self.listener.accept().await.map_err(|_| Error::Closed)?;
if let Some(tcp_nodelay) = self.cfg.tcp_nodelay {
if let Err(err) = stream.set_nodelay(tcp_nodelay) {
warn!(?err, "failed to set TCP_NODELAY");
}
}
let (stream, sink) = stream.into_split();
Ok((
addr,
Sink {
write_timeout: self.cfg.write_timeout,
sink,
},
Stream {
read_timeout: self.cfg.read_timeout,
stream: BufReader::with_capacity(self.cfg.read_buffer_size, stream),
},
))
}
fn local_addr(&self) -> Result<SocketAddr, std::io::Error> {
self.listener.local_addr()
}
}
/// Configuration for the tokio-backed TCP `Network`.
///
/// Start from `Config::default()` and adjust it with the `with_*` builder
/// methods, each of which returns the updated configuration.
#[derive(Clone, Debug)]
pub struct Config {
    /// Value passed to `set_nodelay` on every accepted or dialed connection.
    /// `None` leaves the socket's setting untouched.
    tcp_nodelay: Option<bool>,
    /// Maximum time a single `recv` call may take to fill its buffer.
    read_timeout: Duration,
    /// Maximum time a single `send` call may take to write its whole message.
    write_timeout: Duration,
    /// Capacity of the `BufReader` wrapping each connection's read half;
    /// `0` effectively disables read buffering.
    read_buffer_size: usize,
}
#[cfg_attr(feature = "iouring-network", allow(dead_code))]
impl Config {
    /// Sets the `TCP_NODELAY` value applied to new connections (`None` leaves it unset).
    ///
    /// Returns the updated configuration; `self` is taken by value, so the
    /// result must be used.
    #[must_use]
    pub const fn with_tcp_nodelay(mut self, tcp_nodelay: Option<bool>) -> Self {
        self.tcp_nodelay = tcp_nodelay;
        self
    }
    /// Sets the per-`recv` read timeout and returns the updated configuration.
    #[must_use]
    pub const fn with_read_timeout(mut self, read_timeout: Duration) -> Self {
        self.read_timeout = read_timeout;
        self
    }
    /// Sets the per-`send` write timeout and returns the updated configuration.
    #[must_use]
    pub const fn with_write_timeout(mut self, write_timeout: Duration) -> Self {
        self.write_timeout = write_timeout;
        self
    }
    /// Sets the read-buffer capacity (in bytes) and returns the updated configuration.
    #[must_use]
    pub const fn with_read_buffer_size(mut self, read_buffer_size: usize) -> Self {
        self.read_buffer_size = read_buffer_size;
        self
    }
    /// Returns the configured `TCP_NODELAY` value (`None` = not applied).
    pub const fn tcp_nodelay(&self) -> Option<bool> {
        self.tcp_nodelay
    }
    /// Returns the per-`recv` read timeout.
    pub const fn read_timeout(&self) -> Duration {
        self.read_timeout
    }
    /// Returns the per-`send` write timeout.
    pub const fn write_timeout(&self) -> Duration {
        self.write_timeout
    }
    /// Returns the read-buffer capacity in bytes.
    pub const fn read_buffer_size(&self) -> usize {
        self.read_buffer_size
    }
}
impl Default for Config {
    /// `TCP_NODELAY` left unset, 60 s read timeout, 30 s write timeout,
    /// and a 64 KiB read buffer.
    fn default() -> Self {
        Self {
            tcp_nodelay: None,
            read_timeout: Duration::from_secs(60),
            write_timeout: Duration::from_secs(30),
            read_buffer_size: 64 * 1024,
        }
    }
}
#[derive(Clone, Debug)]
pub struct Network {
cfg: Config,
}
impl From<Config> for Network {
fn from(cfg: Config) -> Self {
Self { cfg }
}
}
impl Default for Network {
fn default() -> Self {
Self::from(Config::default())
}
}
impl crate::Network for Network {
type Listener = Listener;
async fn bind(&self, socket: SocketAddr) -> Result<Self::Listener, crate::Error> {
TcpListener::bind(socket)
.await
.map_err(|_| Error::BindFailed)
.map(|listener| Listener {
cfg: self.cfg.clone(),
listener,
})
}
async fn dial(
&self,
socket: SocketAddr,
) -> Result<(crate::SinkOf<Self>, crate::StreamOf<Self>), crate::Error> {
let stream = TcpStream::connect(socket)
.await
.map_err(|_| Error::ConnectionFailed)?;
if let Some(tcp_nodelay) = self.cfg.tcp_nodelay {
if let Err(err) = stream.set_nodelay(tcp_nodelay) {
warn!(?err, "failed to set TCP_NODELAY");
}
}
let (stream, sink) = stream.into_split();
Ok((
Sink {
write_timeout: self.cfg.write_timeout,
sink,
},
Stream {
read_timeout: self.cfg.read_timeout,
stream: BufReader::with_capacity(self.cfg.read_buffer_size, stream),
},
))
}
}
#[cfg(test)]
mod tests {
    use crate::{
        network::{tests, tokio as TokioNetwork},
        Listener as _, Network as _, Sink as _, Stream as _,
    };
    use commonware_macros::test_group;
    use std::time::{Duration, Instant};

    /// Default tokio config with the given read and write timeouts.
    fn config_with_timeouts(read: Duration, write: Duration) -> TokioNetwork::Config {
        TokioNetwork::Config::default()
            .with_read_timeout(read)
            .with_write_timeout(write)
    }

    /// Shared network-trait conformance suite run against the tokio backend.
    #[tokio::test]
    async fn test_trait() {
        tests::test_network_trait(|| {
            TokioNetwork::Network::from(config_with_timeouts(
                Duration::from_secs(15),
                Duration::from_secs(15),
            ))
        })
        .await;
    }

    /// Shared stress suite; only runs as part of the "slow" test group.
    #[test_group("slow")]
    #[tokio::test]
    async fn test_stress_trait() {
        tests::stress_test_network_trait(|| {
            TokioNetwork::Network::from(config_with_timeouts(
                Duration::from_secs(15),
                Duration::from_secs(15),
            ))
        })
        .await;
    }

    /// A small message that arrives promptly is delivered intact well before the read timeout.
    #[tokio::test]
    async fn test_small_send_read_quickly() {
        let read_timeout = Duration::from_secs(30);
        let network = TokioNetwork::Network::from(config_with_timeouts(
            read_timeout,
            Duration::from_secs(5),
        ));
        let mut listener = network.bind("127.0.0.1:0".parse().unwrap()).await.unwrap();
        let addr = listener.local_addr().unwrap();

        let reader = tokio::spawn(async move {
            let (_addr, _sink, mut stream) = listener.accept().await.unwrap();
            let started = Instant::now();
            let mut received = [0u8; 10];
            stream.recv(&mut received[..]).await.unwrap();
            (received, started.elapsed())
        });

        let (mut sink, _stream) = network.dial(addr).await.unwrap();
        let msg: Vec<u8> = (1u8..=10).collect();
        sink.send(msg.as_slice()).await.unwrap();

        let (received, elapsed) = reader.await.unwrap();
        assert_eq!(&received[..], msg.as_slice());
        assert!(elapsed < read_timeout);
    }

    /// When only part of the requested bytes arrive, `recv` times out after roughly `read_timeout`.
    #[tokio::test]
    async fn test_read_timeout_with_partial_data() {
        let read_timeout = Duration::from_millis(100);
        let network = TokioNetwork::Network::from(config_with_timeouts(
            read_timeout,
            Duration::from_secs(5),
        ));
        let mut listener = network.bind("127.0.0.1:0".parse().unwrap()).await.unwrap();
        let addr = listener.local_addr().unwrap();

        let reader = tokio::spawn(async move {
            let (_addr, _sink, mut stream) = listener.accept().await.unwrap();
            let started = Instant::now();
            let mut wanted = [0u8; 100];
            let outcome = stream.recv(&mut wanted[..]).await;
            (outcome, started.elapsed())
        });

        // Keep the dialer's sink alive so the reader sees a stall, not end-of-stream.
        let (mut sink, _stream) = network.dial(addr).await.unwrap();
        sink.send([1u8, 2, 3, 4, 5].as_slice()).await.unwrap();

        let (outcome, elapsed) = reader.await.unwrap();
        assert!(matches!(outcome, Err(crate::Error::Timeout)));
        assert!((read_timeout..read_timeout * 2).contains(&elapsed));
    }

    /// With a zero-sized read buffer, consecutive messages are still read back in order.
    #[tokio::test]
    async fn test_unbuffered_mode() {
        let network = TokioNetwork::Network::from(
            config_with_timeouts(Duration::from_secs(5), Duration::from_secs(5))
                .with_read_buffer_size(0),
        );
        let mut listener = network.bind("127.0.0.1:0".parse().unwrap()).await.unwrap();
        let addr = listener.local_addr().unwrap();

        let reader = tokio::spawn(async move {
            let (_addr, _sink, mut stream) = listener.accept().await.unwrap();
            let mut first = [0u8; 5];
            let mut second = [0u8; 5];
            stream.recv(&mut first[..]).await.unwrap();
            stream.recv(&mut second[..]).await.unwrap();
            (first, second)
        });

        let (mut sink, _stream) = network.dial(addr).await.unwrap();
        sink.send([1u8, 2, 3, 4, 5].as_slice()).await.unwrap();
        sink.send([6u8, 7, 8, 9, 10].as_slice()).await.unwrap();

        let (first, second) = reader.await.unwrap();
        assert_eq!(first, [1u8, 2, 3, 4, 5]);
        assert_eq!(second, [6u8, 7, 8, 9, 10]);
    }
}