use crate::Error;
use commonware_utils::StableBuf;
use std::{net::SocketAddr, time::Duration};
use tokio::{
io::{AsyncReadExt as _, AsyncWriteExt as _},
net::{
tcp::{OwnedReadHalf, OwnedWriteHalf},
TcpListener, TcpStream,
},
time::timeout,
};
use tracing::warn;
pub struct Sink {
write_timeout: Duration,
sink: OwnedWriteHalf,
}
impl crate::Sink for Sink {
async fn send(&mut self, msg: impl Into<StableBuf> + Send) -> Result<(), Error> {
timeout(self.write_timeout, self.sink.write_all(msg.into().as_ref()))
.await
.map_err(|_| Error::Timeout)?
.map_err(|_| Error::SendFailed)?;
Ok(())
}
}
pub struct Stream {
read_timeout: Duration,
stream: OwnedReadHalf,
}
impl crate::Stream for Stream {
async fn recv(&mut self, buf: impl Into<StableBuf> + Send) -> Result<StableBuf, Error> {
let mut buf = buf.into();
if buf.is_empty() {
return Ok(buf);
}
timeout(self.read_timeout, self.stream.read_exact(buf.as_mut()))
.await
.map_err(|_| Error::Timeout)?
.map_err(|_| Error::RecvFailed)?;
Ok(buf)
}
}
pub struct Listener {
cfg: Config,
listener: TcpListener,
}
impl crate::Listener for Listener {
type Sink = Sink;
type Stream = Stream;
async fn accept(&mut self) -> Result<(SocketAddr, Self::Sink, Self::Stream), Error> {
let (stream, addr) = self.listener.accept().await.map_err(|_| Error::Closed)?;
if let Some(tcp_nodelay) = self.cfg.tcp_nodelay {
if let Err(err) = stream.set_nodelay(tcp_nodelay) {
warn!(?err, "failed to set TCP_NODELAY");
}
}
let (stream, sink) = stream.into_split();
Ok((
addr,
Sink {
write_timeout: self.cfg.write_timeout,
sink,
},
Stream {
read_timeout: self.cfg.read_timeout,
stream,
},
))
}
fn local_addr(&self) -> Result<SocketAddr, std::io::Error> {
self.listener.local_addr()
}
}
/// Socket options and I/O timeouts applied to every TCP connection this
/// module creates, whether dialed or accepted.
///
/// Start from `Config::default()` and refine it with the `with_*` builders.
#[derive(Clone, Debug)]
pub struct Config {
    /// `Some(v)` sets `TCP_NODELAY` to `v` on each new connection; `None`
    /// leaves the socket's setting unchanged.
    tcp_nodelay: Option<bool>,
    /// Deadline applied to each `recv` on a connection.
    read_timeout: Duration,
    /// Deadline applied to each `send` on a connection.
    write_timeout: Duration,
}

#[cfg_attr(feature = "iouring-network", allow(dead_code))]
impl Config {
    /// Returns a copy with `TCP_NODELAY` forced on (`Some(true)`), forced off
    /// (`Some(false)`), or left unchanged on new sockets (`None`).
    // `#[must_use]`: the builder consumes `self`, so ignoring the result
    // silently discards the setting.
    #[must_use]
    pub fn with_tcp_nodelay(mut self, tcp_nodelay: Option<bool>) -> Self {
        self.tcp_nodelay = tcp_nodelay;
        self
    }

    /// Returns a copy with the per-`recv` deadline set to `read_timeout`.
    #[must_use]
    pub fn with_read_timeout(mut self, read_timeout: Duration) -> Self {
        self.read_timeout = read_timeout;
        self
    }

    /// Returns a copy with the per-`send` deadline set to `write_timeout`.
    #[must_use]
    pub fn with_write_timeout(mut self, write_timeout: Duration) -> Self {
        self.write_timeout = write_timeout;
        self
    }

    /// Configured `TCP_NODELAY` override, if any.
    pub fn tcp_nodelay(&self) -> Option<bool> {
        self.tcp_nodelay
    }

    /// Configured per-`recv` deadline.
    pub fn read_timeout(&self) -> Duration {
        self.read_timeout
    }

    /// Configured per-`send` deadline.
    pub fn write_timeout(&self) -> Duration {
        self.write_timeout
    }
}

impl Default for Config {
    /// `TCP_NODELAY` left unchanged, 60 s read timeout, 30 s write timeout.
    fn default() -> Self {
        Self {
            tcp_nodelay: None,
            read_timeout: Duration::from_secs(60),
            write_timeout: Duration::from_secs(30),
        }
    }
}
#[derive(Clone, Debug)]
pub struct Network {
cfg: Config,
}
impl From<Config> for Network {
fn from(cfg: Config) -> Self {
Self { cfg }
}
}
impl Default for Network {
fn default() -> Self {
Self::from(Config::default())
}
}
impl crate::Network for Network {
type Listener = Listener;
async fn bind(&self, socket: SocketAddr) -> Result<Self::Listener, crate::Error> {
TcpListener::bind(socket)
.await
.map_err(|_| Error::BindFailed)
.map(|listener| Listener {
cfg: self.cfg.clone(),
listener,
})
}
async fn dial(
&self,
socket: SocketAddr,
) -> Result<(crate::SinkOf<Self>, crate::StreamOf<Self>), crate::Error> {
let stream = TcpStream::connect(socket)
.await
.map_err(|_| Error::ConnectionFailed)?;
if let Some(tcp_nodelay) = self.cfg.tcp_nodelay {
if let Err(err) = stream.set_nodelay(tcp_nodelay) {
warn!(?err, "failed to set TCP_NODELAY");
}
}
let (stream, sink) = stream.into_split();
Ok((
Sink {
write_timeout: self.cfg.write_timeout,
sink,
},
Stream {
read_timeout: self.cfg.read_timeout,
stream,
},
))
}
}
#[cfg(test)]
mod tests {
    use crate::network::{tests, tokio as TokioNetwork};
    use std::time::Duration;

    /// Read and write deadline used by both trait test-suites.
    const IO_TIMEOUT: Duration = Duration::from_secs(15);

    /// Builds a tokio-backed network with the shared test timeouts.
    fn network() -> TokioNetwork::Network {
        let cfg = TokioNetwork::Config::default()
            .with_read_timeout(IO_TIMEOUT)
            .with_write_timeout(IO_TIMEOUT);
        TokioNetwork::Network::from(cfg)
    }

    #[tokio::test]
    async fn test_trait() {
        tests::test_network_trait(network).await;
    }

    #[tokio::test]
    #[ignore]
    async fn stress_test_trait() {
        tests::stress_test_network_trait(network).await;
    }
}