use crate::iouring::{self, should_retry};
use commonware_utils::StableBuf;
use futures::{
channel::{mpsc, oneshot},
executor::block_on,
SinkExt as _,
};
use io_uring::types::Fd;
use prometheus_client::registry::Registry;
use std::{
net::SocketAddr,
os::fd::{AsRawFd, OwnedFd},
sync::Arc,
};
use tokio::net::{TcpListener, TcpStream};
use tracing::warn;
/// Configuration for the io_uring-backed [`Network`].
#[derive(Clone, Debug, Default)]
pub struct Config {
/// When `Some`, the value passed to `set_nodelay` on every dialed or
/// accepted TCP stream; when `None`, the option is left untouched.
pub tcp_nodelay: Option<bool>,
/// Configuration handed to each `iouring::run` event loop. Its `size` is
/// also used as the capacity of the operation submission channels, and
/// `Network::start` overrides `single_issuer` to `true`.
pub iouring_config: iouring::Config,
}
/// Network implementation that performs socket sends and receives through
/// two `iouring::run` event loops, one for each direction.
#[derive(Clone, Debug)]
pub struct Network {
/// When `Some`, the TCP_NODELAY value applied to new connections.
tcp_nodelay: Option<bool>,
/// Submits operations to the event loop started for sends.
send_submitter: mpsc::Sender<iouring::Op>,
/// Submits operations to the event loop started for receives.
recv_submitter: mpsc::Sender<iouring::Op>,
}
impl Network {
pub(crate) fn start(mut cfg: Config, registry: &mut Registry) -> Result<Self, crate::Error> {
let (send_submitter, rx) = mpsc::channel(cfg.iouring_config.size as usize);
cfg.iouring_config.single_issuer = true;
std::thread::spawn({
let cfg = cfg.clone();
let registry = registry.sub_registry_with_prefix("iouring_sender");
let metrics = Arc::new(iouring::Metrics::new(registry));
move || block_on(iouring::run(cfg.iouring_config, metrics, rx))
});
let (recv_submitter, rx) = mpsc::channel(cfg.iouring_config.size as usize);
let registry = registry.sub_registry_with_prefix("iouring_receiver");
let metrics = Arc::new(iouring::Metrics::new(registry));
std::thread::spawn(|| block_on(iouring::run(cfg.iouring_config, metrics, rx)));
Ok(Self {
tcp_nodelay: cfg.tcp_nodelay,
send_submitter,
recv_submitter,
})
}
}
impl crate::Network for Network {
type Listener = Listener;
async fn bind(&self, socket: SocketAddr) -> Result<Self::Listener, crate::Error> {
let listener = TcpListener::bind(socket)
.await
.map_err(|_| crate::Error::BindFailed)?;
Ok(Listener {
tcp_nodelay: self.tcp_nodelay,
inner: listener,
send_submitter: self.send_submitter.clone(),
recv_submitter: self.recv_submitter.clone(),
})
}
async fn dial(
&self,
socket: SocketAddr,
) -> Result<(crate::SinkOf<Self>, crate::StreamOf<Self>), crate::Error> {
let stream = TcpStream::connect(socket)
.await
.map_err(|_| crate::Error::ConnectionFailed)?
.into_std()
.map_err(|_| crate::Error::ConnectionFailed)?;
if let Some(tcp_nodelay) = self.tcp_nodelay {
if let Err(err) = stream.set_nodelay(tcp_nodelay) {
warn!(?err, "failed to set TCP_NODELAY");
}
}
stream
.set_nonblocking(true)
.map_err(|_| crate::Error::ConnectionFailed)?;
let fd = Arc::new(OwnedFd::from(stream));
Ok((
Sink {
fd: fd.clone(),
submitter: self.send_submitter.clone(),
},
Stream {
fd,
submitter: self.recv_submitter.clone(),
},
))
}
}
/// TCP listener whose accepted connections are serviced through the
/// io_uring operation submitters it carries.
pub struct Listener {
/// When `Some`, the TCP_NODELAY value applied to accepted connections.
tcp_nodelay: Option<bool>,
/// Underlying tokio listener used to accept connections.
inner: TcpListener,
/// Submitter cloned into the `Sink` of every accepted connection.
send_submitter: mpsc::Sender<iouring::Op>,
/// Submitter cloned into the `Stream` of every accepted connection.
recv_submitter: mpsc::Sender<iouring::Op>,
}
impl crate::Listener for Listener {
type Stream = Stream;
type Sink = Sink;
async fn accept(&mut self) -> Result<(SocketAddr, Self::Sink, Self::Stream), crate::Error> {
let (stream, remote_addr) = self
.inner
.accept()
.await
.map_err(|_| crate::Error::ConnectionFailed)?;
let stream = stream
.into_std()
.map_err(|_| crate::Error::ConnectionFailed)?;
if let Some(tcp_nodelay) = self.tcp_nodelay {
if let Err(err) = stream.set_nodelay(tcp_nodelay) {
warn!(?err, "failed to set TCP_NODELAY");
}
}
stream
.set_nonblocking(true)
.map_err(|_| crate::Error::ConnectionFailed)?;
let fd = Arc::new(OwnedFd::from(stream));
Ok((
remote_addr,
Sink {
fd: fd.clone(),
submitter: self.send_submitter.clone(),
},
Stream {
fd,
submitter: self.recv_submitter.clone(),
},
))
}
fn local_addr(&self) -> Result<SocketAddr, std::io::Error> {
self.inner.local_addr()
}
}
/// Sending half of a connection.
pub struct Sink {
/// Socket file descriptor, shared with the paired `Stream`.
fd: Arc<OwnedFd>,
/// Channel used to submit send operations to the io_uring event loop.
submitter: mpsc::Sender<iouring::Op>,
}
impl Sink {
    /// Wraps the socket's raw file descriptor in the io_uring `Fd` type so it
    /// can be used when building submission entries.
    fn as_raw_fd(&self) -> Fd {
        let raw_fd = self.fd.as_raw_fd();
        Fd(raw_fd)
    }
}
impl crate::Sink for Sink {
async fn send(&mut self, msg: impl Into<StableBuf> + Send) -> Result<(), crate::Error> {
let mut msg = msg.into();
let mut bytes_sent = 0;
let msg_len = msg.len();
while bytes_sent < msg_len {
let remaining = unsafe {
std::slice::from_raw_parts(
msg.as_mut_ptr().add(bytes_sent) as *const u8,
msg_len - bytes_sent,
)
};
let op = io_uring::opcode::Send::new(
self.as_raw_fd(),
remaining.as_ptr(),
remaining.len() as u32,
)
.build();
let (tx, rx) = oneshot::channel();
self.submitter
.send(crate::iouring::Op {
work: op,
sender: tx,
buffer: Some(msg),
})
.await
.map_err(|_| crate::Error::SendFailed)?;
let (result, got_msg) = rx.await.map_err(|_| crate::Error::SendFailed)?;
msg = got_msg.unwrap();
if should_retry(result) {
continue;
}
if result <= 0 {
return Err(crate::Error::SendFailed);
}
bytes_sent += result as usize;
}
Ok(())
}
}
/// Receiving half of a connection.
pub struct Stream {
/// Socket file descriptor, shared with the paired `Sink`.
fd: Arc<OwnedFd>,
/// Channel used to submit recv operations to the io_uring event loop.
submitter: mpsc::Sender<iouring::Op>,
}
impl Stream {
    /// Wraps the socket's raw file descriptor in the io_uring `Fd` type so it
    /// can be used when building submission entries.
    fn as_raw_fd(&self) -> Fd {
        let raw_fd = self.fd.as_raw_fd();
        Fd(raw_fd)
    }
}
impl crate::Stream for Stream {
async fn recv(&mut self, buf: impl Into<StableBuf> + Send) -> Result<StableBuf, crate::Error> {
let mut bytes_received = 0;
let mut buf = buf.into();
let buf_len = buf.len();
while bytes_received < buf_len {
let remaining = unsafe {
std::slice::from_raw_parts_mut(
buf.as_mut_ptr().add(bytes_received),
buf_len - bytes_received,
)
};
let op = io_uring::opcode::Recv::new(
self.as_raw_fd(),
remaining.as_mut_ptr(),
remaining.len() as u32,
)
.build();
let (tx, rx) = oneshot::channel();
self.submitter
.send(crate::iouring::Op {
work: op,
sender: tx,
buffer: Some(buf),
})
.await
.map_err(|_| crate::Error::RecvFailed)?;
let (result, got_buf) = rx.await.map_err(|_| crate::Error::RecvFailed)?;
buf = got_buf.unwrap();
if should_retry(result) {
continue;
}
if result <= 0 {
return Err(crate::Error::RecvFailed);
}
bytes_received += result as usize;
}
Ok(buf)
}
}
#[cfg(test)]
mod tests {
    use crate::{
        iouring,
        network::{
            iouring::{Config, Network},
            tests,
        },
    };
    use prometheus_client::registry::Registry;
    use std::time::Duration;

    /// Value used for `force_poll` in every test network.
    const FORCE_POLL: Duration = Duration::from_millis(100);

    /// Starts a network with the given ring configuration, default settings
    /// for everything else, and a throwaway metrics registry.
    fn start_network(iouring_config: iouring::Config) -> Network {
        let cfg = Config {
            iouring_config,
            ..Default::default()
        };
        Network::start(cfg, &mut Registry::default()).expect("Failed to start io_uring")
    }

    /// Runs the shared network trait test suite against the io_uring network.
    #[tokio::test]
    async fn test_trait() {
        tests::test_network_trait(|| {
            start_network(iouring::Config {
                force_poll: Some(FORCE_POLL),
                ..Default::default()
            })
        })
        .await;
    }

    /// Runs the shared stress suite with a ring size of 256; ignored by
    /// default.
    #[tokio::test]
    #[ignore]
    async fn stress_test_trait() {
        tests::stress_test_network_trait(|| {
            start_network(iouring::Config {
                size: 256,
                force_poll: Some(FORCE_POLL),
                ..Default::default()
            })
        })
        .await;
    }
}