use crate::{
iouring::{self, should_retry, OpBuffer, OpFd, OpIovecs},
Buf, BufferPool, Error, IoBuf, IoBufMut, IoBufs,
};
use commonware_utils::channel::oneshot;
use io_uring::{opcode, types::Fd};
use prometheus_client::registry::Registry;
use std::{
net::SocketAddr,
os::fd::{AsRawFd, OwnedFd},
sync::Arc,
time::Duration,
};
use tokio::net::{TcpListener, TcpStream};
use tracing::warn;
/// Default capacity, in bytes, of a stream's internal read buffer (64 KiB).
const DEFAULT_READ_BUFFER_SIZE: usize = 64 << 10;
/// Upper bound on the number of iovec entries built for a single vectored write.
const IOVEC_BATCH_SIZE: usize = 32;
/// Settings for the io_uring-backed network.
#[derive(Clone, Debug)]
pub struct Config {
/// When `Some`, the value applied as `TCP_NODELAY` to every dialed or accepted
/// connection. When `None`, the option is not touched.
pub tcp_nodelay: Option<bool>,
/// When `Some`, the duration applied as `SO_LINGER` to every dialed or accepted
/// connection. When `None`, the option is not touched.
pub so_linger: Option<Duration>,
/// Configuration used for both the send and the receive io_uring loops.
/// `Network::start` overrides `single_issuer` to `true`.
pub iouring_config: iouring::Config,
/// Capacity, in bytes, of each stream's internal read buffer. A value of `0`
/// makes `Stream::recv` read directly into the caller-facing buffer.
pub read_buffer_size: usize,
}
impl Default for Config {
    /// Builds a configuration that leaves `TCP_NODELAY` and `SO_LINGER`
    /// untouched, uses the default io_uring settings, and sizes the read buffer
    /// at `DEFAULT_READ_BUFFER_SIZE`.
    fn default() -> Self {
        let iouring_config = iouring::Config::default();
        Self {
            tcp_nodelay: None,
            so_linger: None,
            iouring_config,
            read_buffer_size: DEFAULT_READ_BUFFER_SIZE,
        }
    }
}
/// Handle to the io_uring-backed network, used to bind listeners and dial peers.
#[derive(Clone)]
pub struct Network {
/// `TCP_NODELAY` value applied to new connections, if any.
tcp_nodelay: Option<bool>,
/// `SO_LINGER` duration applied to new connections, if any.
so_linger: Option<Duration>,
/// Submits operations to the io_uring loop used for sends.
send_submitter: iouring::Submitter,
/// Submits operations to the io_uring loop used for receives.
recv_submitter: iouring::Submitter,
/// Capacity of the internal read buffer given to every new stream.
read_buffer_size: usize,
/// Buffer pool handed to every new stream.
pool: BufferPool,
}
impl Network {
pub(crate) fn start(
mut cfg: Config,
registry: &mut Registry,
pool: BufferPool,
) -> Result<Self, Error> {
cfg.iouring_config.single_issuer = true;
let sender_registry = registry.sub_registry_with_prefix("iouring_sender");
let (send_submitter, send_loop) =
iouring::IoUringLoop::new(cfg.iouring_config.clone(), sender_registry);
std::thread::spawn(move || send_loop.run());
let receiver_registry = registry.sub_registry_with_prefix("iouring_receiver");
let (recv_submitter, recv_loop) =
iouring::IoUringLoop::new(cfg.iouring_config, receiver_registry);
std::thread::spawn(move || recv_loop.run());
Ok(Self {
tcp_nodelay: cfg.tcp_nodelay,
so_linger: cfg.so_linger,
send_submitter,
recv_submitter,
read_buffer_size: cfg.read_buffer_size,
pool,
})
}
}
impl crate::Network for Network {
type Listener = Listener;
async fn bind(&self, socket: SocketAddr) -> Result<Self::Listener, Error> {
let listener = TcpListener::bind(socket)
.await
.map_err(|_| Error::BindFailed)?;
Ok(Listener {
tcp_nodelay: self.tcp_nodelay,
so_linger: self.so_linger,
inner: listener,
send_submitter: self.send_submitter.clone(),
recv_submitter: self.recv_submitter.clone(),
read_buffer_size: self.read_buffer_size,
pool: self.pool.clone(),
})
}
async fn dial(
&self,
socket: SocketAddr,
) -> Result<(crate::SinkOf<Self>, crate::StreamOf<Self>), Error> {
let stream = TcpStream::connect(socket)
.await
.map_err(|_| Error::ConnectionFailed)?;
if let Some(tcp_nodelay) = self.tcp_nodelay {
if let Err(err) = stream.set_nodelay(tcp_nodelay) {
warn!(?err, "failed to set TCP_NODELAY");
}
}
if let Some(so_linger) = self.so_linger {
if let Err(err) = stream.set_linger(Some(so_linger)) {
warn!(?err, "failed to set SO_LINGER");
}
}
let stream = stream.into_std().map_err(|_| Error::ConnectionFailed)?;
stream
.set_nonblocking(true)
.map_err(|_| Error::ConnectionFailed)?;
let fd = Arc::new(OwnedFd::from(stream));
Ok((
Sink::new(fd.clone(), self.send_submitter.clone()),
Stream::new(
fd,
self.recv_submitter.clone(),
self.read_buffer_size,
self.pool.clone(),
),
))
}
}
/// TCP listener whose accepted connections are driven through io_uring.
pub struct Listener {
/// `TCP_NODELAY` value applied to accepted connections, if any.
tcp_nodelay: Option<bool>,
/// `SO_LINGER` duration applied to accepted connections, if any.
so_linger: Option<Duration>,
/// The tokio listener that performs the actual accept.
inner: TcpListener,
/// Submitter cloned into the sink of every accepted connection.
send_submitter: iouring::Submitter,
/// Submitter cloned into the stream of every accepted connection.
recv_submitter: iouring::Submitter,
/// Capacity of the internal read buffer given to every accepted stream.
read_buffer_size: usize,
/// Buffer pool cloned into every accepted stream.
pool: BufferPool,
}
impl crate::Listener for Listener {
type Stream = Stream;
type Sink = Sink;
async fn accept(&mut self) -> Result<(SocketAddr, Self::Sink, Self::Stream), Error> {
let (stream, remote_addr) = self
.inner
.accept()
.await
.map_err(|_| Error::ConnectionFailed)?;
if let Some(tcp_nodelay) = self.tcp_nodelay {
if let Err(err) = stream.set_nodelay(tcp_nodelay) {
warn!(?err, "failed to set TCP_NODELAY");
}
}
if let Some(so_linger) = self.so_linger {
if let Err(err) = stream.set_linger(Some(so_linger)) {
warn!(?err, "failed to set SO_LINGER");
}
}
let stream = stream.into_std().map_err(|_| Error::ConnectionFailed)?;
stream
.set_nonblocking(true)
.map_err(|_| Error::ConnectionFailed)?;
let fd = Arc::new(OwnedFd::from(stream));
Ok((
remote_addr,
Sink::new(fd.clone(), self.send_submitter.clone()),
Stream::new(
fd,
self.recv_submitter.clone(),
self.read_buffer_size,
self.pool.clone(),
),
))
}
fn local_addr(&self) -> Result<SocketAddr, std::io::Error> {
self.inner.local_addr()
}
}
/// Write half of a connection; sends data by submitting io_uring operations.
pub struct Sink {
/// Socket descriptor, shared with the connection's `Stream` half.
fd: Arc<OwnedFd>,
/// Channel used to submit send operations to the io_uring loop.
submitter: iouring::Submitter,
}
impl Sink {
    /// Creates a sink that writes to `fd` by submitting operations through `submitter`.
    const fn new(fd: Arc<OwnedFd>, submitter: iouring::Submitter) -> Self {
        Self { fd, submitter }
    }

    /// Returns the socket's raw descriptor wrapped in the io_uring [`Fd`] type.
    fn as_raw_fd(&self) -> Fd {
        Fd(self.fd.as_raw_fd())
    }

    /// Sends one contiguous buffer, resubmitting until every byte has been written.
    ///
    /// Each iteration submits a `Send` operation for the unsent tail of `buf`.
    /// The buffer travels with the operation and is handed back on completion,
    /// so it can be reused for the next iteration.
    ///
    /// # Errors
    ///
    /// Returns `Error::SendFailed` if the operation cannot be submitted, if the
    /// completion is never delivered, if the operation reports a negative
    /// (non-retryable) result, or if it reports that zero bytes were sent.
    async fn send_single(&self, mut buf: IoBuf) -> Result<(), Error> {
        let mut bytes_sent = 0;
        let buf_len = buf.len();
        while bytes_sent < buf_len {
            // SAFETY: the loop condition guarantees `bytes_sent < buf_len`, so the
            // offset pointer stays within the buffer.
            let ptr = unsafe { buf.as_ptr().add(bytes_sent) };
            let remaining_len = buf_len - bytes_sent;
            // The opcode takes a `u32` length. Clamp instead of casting: an `as`
            // cast would silently truncate, and a remaining length that is a
            // multiple of 2^32 would become a zero-length send that is then
            // reported as a failure. Whatever does not fit is sent by a later
            // iteration of this loop.
            let op_len = u32::try_from(remaining_len).unwrap_or(u32::MAX);
            let op = opcode::Send::new(self.as_raw_fd(), ptr, op_len).build();
            let (sender, receiver) = oneshot::channel();
            self.submitter
                .send(iouring::Op {
                    work: op,
                    sender,
                    buffer: Some(OpBuffer::Write(buf)),
                    fd: Some(OpFd::Fd(self.fd.clone())),
                    iovecs: None,
                })
                .await
                .map_err(|_| Error::SendFailed)?;
            let (return_value, return_buf) = receiver.await.map_err(|_| Error::SendFailed)?;
            // Take the buffer back so the next iteration can resubmit it.
            buf = match return_buf {
                Some(OpBuffer::Write(b)) => b,
                _ => unreachable!("io_uring loop returns the same OpBuffer that was submitted"),
            };
            if should_retry(return_value) {
                continue;
            }
            // A negative result (an error code) fails the conversion to `usize`.
            let op_bytes_sent: usize = return_value.try_into().map_err(|_| Error::SendFailed)?;
            if op_bytes_sent == 0 {
                // No progress was made; bail out rather than loop forever.
                return Err(Error::SendFailed);
            }
            bytes_sent += op_bytes_sent;
        }
        Ok(())
    }

    /// Sends a multi-chunk buffer with vectored writes until nothing remains.
    ///
    /// Each iteration describes up to `IOVEC_BATCH_SIZE` chunks of `bufs` in an
    /// iovec array and submits a `Writev` operation. Both the buffers and the
    /// iovec array travel with the operation. After completion `bufs` is
    /// advanced by the number of bytes written.
    ///
    /// # Errors
    ///
    /// Returns `Error::SendFailed` if the operation cannot be submitted, if the
    /// completion is never delivered, if the operation reports a negative
    /// (non-retryable) result, or if it reports that zero bytes were sent.
    ///
    /// # Panics
    ///
    /// Panics if `bufs` reports remaining data but yields no chunks.
    async fn send_vectored(&self, mut bufs: IoBufs) -> Result<(), Error> {
        while bufs.has_remaining() {
            let (iovecs, iovecs_len) = {
                let max_iovecs = bufs.chunk_count().min(IOVEC_BATCH_SIZE);
                assert!(
                    max_iovecs > 0,
                    "chunk_count should be > 0 if bufs.has_remaining() is true"
                );
                // Pre-fill with empty entries (dangling, zero length) that
                // `chunks_vectored` overwrites below.
                let mut iovecs: Box<[libc::iovec]> = std::iter::repeat_n(
                    libc::iovec {
                        iov_base: std::ptr::NonNull::<u8>::dangling().as_ptr().cast(),
                        iov_len: 0,
                    },
                    max_iovecs,
                )
                .collect();
                // SAFETY: `IoSlice` is documented as ABI-compatible with `iovec`
                // on Unix, and the pointer and length come from the boxed slice
                // just allocated, which outlives this borrow.
                let io_slices: &mut [std::io::IoSlice<'_>] = unsafe {
                    std::slice::from_raw_parts_mut(
                        iovecs.as_mut_ptr().cast::<std::io::IoSlice<'_>>(),
                        iovecs.len(),
                    )
                };
                let io_slices_len = bufs.chunks_vectored(io_slices);
                assert!(
                    io_slices_len > 0,
                    "chunks_vectored should produce at least one slice when bufs has remaining"
                );
                (OpIovecs::new(iovecs), io_slices_len)
            };
            let op =
                opcode::Writev::new(self.as_raw_fd(), iovecs.as_ptr(), iovecs_len as _).build();
            let (sender, receiver) = oneshot::channel();
            self.submitter
                .send(iouring::Op {
                    work: op,
                    sender,
                    buffer: Some(OpBuffer::WriteVectored(bufs)),
                    fd: Some(OpFd::Fd(self.fd.clone())),
                    iovecs: Some(iovecs),
                })
                .await
                .map_err(|_| Error::SendFailed)?;
            let (return_value, return_bufs) = receiver.await.map_err(|_| Error::SendFailed)?;
            // Take the buffers back so the next iteration can resubmit them.
            bufs = match return_bufs {
                Some(OpBuffer::WriteVectored(b)) => b,
                _ => unreachable!("io_uring loop returns the same OpBuffer that was submitted"),
            };
            if should_retry(return_value) {
                continue;
            }
            // A negative result (an error code) fails the conversion to `usize`.
            let op_bytes_sent: usize = return_value.try_into().map_err(|_| Error::SendFailed)?;
            if op_bytes_sent == 0 {
                // No progress was made; bail out rather than loop forever.
                return Err(Error::SendFailed);
            }
            bufs.advance(op_bytes_sent);
        }
        Ok(())
    }
}
impl crate::Sink for Sink {
    /// Sends `bufs` over the connection.
    ///
    /// Input that collapses to a single buffer goes through the plain send
    /// path; anything else goes through the vectored path.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by the chosen send path.
    async fn send(&mut self, bufs: impl Into<IoBufs> + Send) -> Result<(), Error> {
        let bufs: IoBufs = bufs.into();
        match bufs.try_into_single() {
            Ok(single) => self.send_single(single).await,
            Err(multiple) => self.send_vectored(multiple).await,
        }
    }
}
/// Read half of a connection; receives data by submitting io_uring operations.
pub struct Stream {
/// Socket descriptor, shared with the connection's `Sink` half.
fd: Arc<OwnedFd>,
/// Channel used to submit receive operations to the io_uring loop.
submitter: iouring::Submitter,
/// Internal read buffer; a capacity of zero disables buffering.
buffer: IoBufMut,
/// Index in `buffer` of the next byte not yet handed to a caller.
buffer_pos: usize,
/// Number of valid bytes currently held in `buffer`.
buffer_len: usize,
/// Pool from which the buffers returned by `recv` are allocated.
pool: BufferPool,
}
impl Stream {
    /// Creates a stream that reads from `fd` through `submitter`.
    ///
    /// `buffer_capacity` sizes the internal read buffer, which starts empty.
    /// `pool` supplies the buffers returned by `recv`.
    fn new(
        fd: Arc<OwnedFd>,
        submitter: iouring::Submitter,
        buffer_capacity: usize,
        pool: BufferPool,
    ) -> Self {
        Self {
            fd,
            submitter,
            buffer: IoBufMut::with_capacity(buffer_capacity),
            buffer_pos: 0,
            buffer_len: 0,
            pool,
        }
    }

    /// Returns the socket's raw descriptor wrapped in the io_uring [`Fd`] type.
    fn as_raw_fd(&self) -> Fd {
        Fd(self.fd.as_raw_fd())
    }

    /// Submits a `Recv` that writes up to `len` bytes into `buffer` starting at `offset`.
    ///
    /// The caller must ensure `offset + len` does not exceed the memory owned by
    /// `buffer`. Retryable results are resubmitted transparently.
    ///
    /// Returns the buffer together with the outcome, so the caller keeps
    /// ownership even when the receive fails. If the operation cannot be
    /// submitted or its completion is never delivered, the submitted buffer is
    /// not recoverable and an empty default buffer is returned instead.
    ///
    /// # Errors
    ///
    /// * `Error::Timeout` if the operation reports `-ETIMEDOUT`.
    /// * `Error::RecvFailed` for any other non-positive result (including a
    ///   result of zero) and for submission or completion-channel failures.
    async fn submit_recv(
        &mut self,
        mut buffer: IoBufMut,
        offset: usize,
        len: usize,
    ) -> (IoBufMut, Result<usize, Error>) {
        // The opcode takes a `u32` length. Clamp instead of casting: an `as` cast
        // would silently truncate, and a length that is a multiple of 2^32 would
        // become a zero-length receive whose zero result is reported as a
        // failure. A short read is fine because callers loop until satisfied.
        let op_len = u32::try_from(len).unwrap_or(u32::MAX);
        loop {
            // SAFETY: the caller guarantees `offset` lies within the buffer's
            // allocation (see the function contract above).
            let ptr = unsafe { buffer.as_mut_ptr().add(offset) };
            let op = opcode::Recv::new(self.as_raw_fd(), ptr, op_len).build();
            let (sender, receiver) = oneshot::channel();
            if self
                .submitter
                .send(iouring::Op {
                    work: op,
                    sender,
                    buffer: Some(OpBuffer::Read(buffer)),
                    fd: Some(OpFd::Fd(self.fd.clone())),
                    iovecs: None,
                })
                .await
                .is_err()
            {
                return (IoBufMut::default(), Err(Error::RecvFailed));
            }
            let Ok((return_value, return_buf)) = receiver.await else {
                return (IoBufMut::default(), Err(Error::RecvFailed));
            };
            // Take the buffer back so it can be resubmitted or returned.
            buffer = match return_buf {
                Some(OpBuffer::Read(b)) => b,
                _ => unreachable!("io_uring loop returns the same OpBuffer that was submitted"),
            };
            if should_retry(return_value) {
                continue;
            }
            if return_value <= 0 {
                let err = if return_value == -libc::ETIMEDOUT {
                    Error::Timeout
                } else {
                    Error::RecvFailed
                };
                return (buffer, Err(err));
            }
            return (buffer, Ok(return_value as usize));
        }
    }

    /// Refills the internal buffer with a single receive of up to its full capacity.
    ///
    /// Any previously buffered bytes are discarded first, so callers should
    /// drain the buffer before calling this. Returns the number of bytes now
    /// buffered.
    ///
    /// # Errors
    ///
    /// Propagates the error from the receive. In that case the buffer is left
    /// logically empty (`buffer_pos == buffer_len == 0`).
    async fn fill_buffer(&mut self) -> Result<usize, Error> {
        self.buffer_pos = 0;
        self.buffer_len = 0;
        // Move the buffer out so it can travel with the operation.
        let buffer = std::mem::take(&mut self.buffer);
        let len = buffer.capacity();
        let (buffer, result) = self.submit_recv(buffer, 0, len).await;
        // Store whatever buffer came back before inspecting the result.
        self.buffer = buffer;
        self.buffer_len = result?;
        // SAFETY: the receive reported `buffer_len` bytes written from offset 0,
        // and at most `capacity` bytes were requested.
        unsafe { self.buffer.set_len(self.buffer_len) };
        Ok(self.buffer_len)
    }
}
impl crate::Stream for Stream {
    /// Receives exactly `len` bytes and returns them in a pool-allocated buffer.
    ///
    /// Bytes already sitting in the internal buffer are consumed first. Once
    /// that is empty, a request at least as large as the internal buffer (or
    /// any request when buffering is disabled) is read straight into the
    /// output buffer; a smaller request refills the internal buffer instead.
    ///
    /// # Errors
    ///
    /// Propagates the error from the underlying receive.
    async fn recv(&mut self, len: usize) -> Result<IoBufs, Error> {
        // SAFETY: the loop below only returns `Ok` once all `len` bytes have
        // been written, and the buffer is never read before being written.
        let mut out = unsafe { self.pool.alloc_len(len) };
        let mut filled = 0;
        while filled < len {
            let available = self.buffer_len - self.buffer_pos;
            if available == 0 {
                // Nothing buffered: go to the socket.
                let wanted = len - filled;
                let capacity = self.buffer.capacity();
                if capacity != 0 && wanted < capacity {
                    // Small read: refill the internal buffer, then copy out of
                    // it on the next iteration.
                    self.fill_buffer().await?;
                } else {
                    // Buffering disabled or large read: receive directly into
                    // the output buffer.
                    let (returned, outcome) = self.submit_recv(out, filled, wanted).await;
                    out = returned;
                    filled += outcome?;
                }
                continue;
            }

            // Copy as much buffered data as the caller still needs.
            let take = std::cmp::min(available, len - filled);
            let src_start = self.buffer_pos;
            out.as_mut()[filled..filled + take]
                .copy_from_slice(&self.buffer.as_ref()[src_start..src_start + take]);
            self.buffer_pos += take;
            filled += take;
        }
        Ok(IoBufs::from(out.freeze()))
    }

    /// Returns up to `max_len` buffered bytes without consuming them.
    ///
    /// Only data already held in the internal buffer is visible; no I/O is
    /// performed, so the result may be empty.
    fn peek(&self, max_len: usize) -> &[u8] {
        let start = self.buffer_pos;
        let visible = std::cmp::min(self.buffer_len - start, max_len);
        &self.buffer.as_ref()[start..start + visible]
    }
}
// Tests for the io_uring network: the shared network-trait suites plus
// io_uring-specific checks for buffering, timeouts and descriptor lifetime.
#[cfg(test)]
mod tests {
use crate::{
iouring,
network::{
iouring::{Config, Network},
tests,
},
BufferPool, BufferPoolConfig, Error, Listener as _, Network as _, Sink as _, Stream as _,
};
use commonware_macros::{select, test_group};
use prometheus_client::registry::Registry;
use std::{
sync::Arc,
time::{Duration, Instant},
};
/// Builds a buffer pool with the network preset, registered on a throwaway registry.
fn test_pool() -> BufferPool {
BufferPool::new(BufferPoolConfig::for_network(), &mut Registry::default())
}
/// Runs the shared network-trait test suite against the io_uring network.
#[tokio::test]
async fn test_trait() {
tests::test_network_trait(|| {
Network::start(Config::default(), &mut Registry::default(), test_pool())
.expect("Failed to start io_uring")
})
.await;
}
/// Runs the shared stress suite with an io_uring `size` of 256.
#[test_group("slow")]
#[tokio::test]
async fn test_stress_trait() {
tests::stress_test_network_trait(|| {
Network::start(
Config {
iouring_config: iouring::Config {
size: 256,
..Default::default()
},
..Default::default()
},
&mut Registry::default(),
test_pool(),
)
.expect("Failed to start io_uring")
})
.await;
}
/// A 10-byte message, far smaller than the default read buffer, is received
/// in full by a `recv(10)`.
#[tokio::test]
async fn test_small_send_read_quickly() {
let network = Network::start(Config::default(), &mut Registry::default(), test_pool())
.expect("Failed to start io_uring");
let mut listener = network.bind("127.0.0.1:0".parse().unwrap()).await.unwrap();
let addr = listener.local_addr().unwrap();
// Server side: accept one connection and read exactly 10 bytes.
let reader = tokio::spawn(async move {
let (_addr, _sink, mut stream) = listener.accept().await.unwrap();
stream.recv(10).await.unwrap()
});
let (mut sink, _stream) = network.dial(addr).await.unwrap();
let msg = vec![1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10];
sink.send(msg.clone()).await.unwrap();
let received = reader.await.unwrap();
assert_eq!(received.coalesce(), msg.as_slice());
}
/// A `recv(100)` that only ever gets 5 bytes fails with `Error::Timeout`
/// once the configured operation timeout has elapsed.
#[tokio::test]
async fn test_read_timeout_with_partial_data() {
let op_timeout = Duration::from_millis(100);
let network = Network::start(
Config {
iouring_config: iouring::Config {
op_timeout: Some(op_timeout),
..Default::default()
},
..Default::default()
},
&mut Registry::default(),
test_pool(),
)
.expect("Failed to start io_uring");
let mut listener = network.bind("127.0.0.1:0".parse().unwrap()).await.unwrap();
let addr = listener.local_addr().unwrap();
// Server side: time how long the oversized read takes to fail.
let reader = tokio::spawn(async move {
let (_addr, _sink, mut stream) = listener.accept().await.unwrap();
let start = Instant::now();
let result = stream.recv(100).await;
let elapsed = start.elapsed();
(result, elapsed)
});
let (mut sink, _stream) = network.dial(addr).await.unwrap();
// Send fewer bytes than the reader asked for.
sink.send([1u8, 2, 3, 4, 5].as_slice()).await.unwrap();
let (result, elapsed) = reader.await.unwrap();
assert!(matches!(result, Err(Error::Timeout)));
// The failure must not come before the timeout, nor much later than it.
assert!(elapsed >= op_timeout);
assert!(elapsed < op_timeout * 3);
}
/// With `read_buffer_size: 0` nothing is ever buffered, so `peek` stays
/// empty before, between and after reads.
#[tokio::test]
async fn test_unbuffered_mode() {
let network = Network::start(
Config {
read_buffer_size: 0,
..Default::default()
},
&mut Registry::default(),
test_pool(),
)
.expect("Failed to start io_uring");
let mut listener = network.bind("127.0.0.1:0".parse().unwrap()).await.unwrap();
let addr = listener.local_addr().unwrap();
let reader = tokio::spawn(async move {
let (_addr, _sink, mut stream) = listener.accept().await.unwrap();
assert!(stream.peek(100).is_empty());
let buf1 = stream.recv(5).await.unwrap();
assert!(stream.peek(100).is_empty());
let buf2 = stream.recv(5).await.unwrap();
assert!(stream.peek(100).is_empty());
(buf1, buf2)
});
let (mut sink, _stream) = network.dial(addr).await.unwrap();
sink.send([1u8, 2, 3, 4, 5].as_slice()).await.unwrap();
sink.send([6u8, 7, 8, 9, 10].as_slice()).await.unwrap();
let (buf1, buf2) = reader.await.unwrap();
assert_eq!(buf1.coalesce(), &[1u8, 2, 3, 4, 5]);
assert_eq!(buf2.coalesce(), &[6u8, 7, 8, 9, 10]);
}
/// Tracks the descriptor's `Arc` strong count to check that a submitted
/// operation holds its own reference to the descriptor.
#[tokio::test]
async fn test_op_fd_keeps_descriptor_alive() {
let op_timeout = Duration::from_millis(200);
let network = Network::start(
Config {
iouring_config: iouring::Config {
op_timeout: Some(op_timeout),
..Default::default()
},
..Default::default()
},
&mut Registry::default(),
test_pool(),
)
.expect("Failed to start io_uring");
let mut listener = network.bind("127.0.0.1:0".parse().unwrap()).await.unwrap();
let addr = listener.local_addr().unwrap();
let (client_sink, mut client_stream) = network.dial(addr).await.unwrap();
let (_addr, _server_sink, _server_stream) = listener.accept().await.unwrap();
let fd = client_stream.fd.clone();
// Three holders: the sink, the stream and the local clone.
assert_eq!(Arc::strong_count(&fd), 3);
// Start a receive that cannot complete, then abandon the future after 50 ms.
select! {
_ = client_stream.recv(1) => unreachable!("no data was sent"),
_ = tokio::time::sleep(Duration::from_millis(50)) => {},
}
// The extra reference is presumably the clone attached to the receive
// operation that is still in flight.
assert_eq!(Arc::strong_count(&fd), 4);
drop(client_sink);
drop(client_stream);
// Remaining holders: the local clone and, presumably, the in-flight operation.
assert_eq!(Arc::strong_count(&fd), 2);
// Wait out the operation timeout; the operation's reference is expected
// to be released by then.
tokio::time::sleep(op_timeout).await;
assert_eq!(Arc::strong_count(&fd), 1);
}
/// With buffering enabled, a short `recv` leaves the rest of the message in
/// the internal buffer, where `peek` can see it without consuming it.
#[tokio::test]
async fn test_peek_with_buffered_data() {
let network = Network::start(Config::default(), &mut Registry::default(), test_pool())
.expect("Failed to start io_uring");
let mut listener = network.bind("127.0.0.1:0".parse().unwrap()).await.unwrap();
let addr = listener.local_addr().unwrap();
let reader = tokio::spawn(async move {
let (_addr, _sink, mut stream) = listener.accept().await.unwrap();
// Nothing has been read yet, so nothing is buffered.
assert!(stream.peek(100).is_empty());
let first = stream.recv(5).await.unwrap();
assert_eq!(first.coalesce(), b"hello");
let peeked = stream.peek(100);
assert!(!peeked.is_empty());
assert_eq!(peeked, b" world");
// Peeking again returns the same bytes; `max_len` caps the slice.
assert_eq!(stream.peek(100), b" world");
assert_eq!(stream.peek(3), b" wo");
let rest = stream.recv(6).await.unwrap();
assert_eq!(rest.coalesce(), b" world");
// The buffered data has now been consumed.
assert!(stream.peek(100).is_empty());
});
let (mut sink, _stream) = network.dial(addr).await.unwrap();
sink.send(b"hello world").await.unwrap();
reader.await.unwrap();
}
}