use alloc::boxed::Box;
use alloc::vec;
use alloc::vec::Vec;
use core::cmp;
use core::ops::Range;
use super::Connection;
use super::{new_io_slice, new_io_slice_mut};
use crate::connection::{IoSlice, IoSliceMut};
use crate::{Fd, Result};
cfg_std_unix! {
use std::os::unix::io::{AsRawFd, RawFd};
}
cfg_std_windows! {
use std::os::windows::io::{AsRawSocket, RawSocket};
}
/// Size in bytes of the read buffer created by [`BufConnection::new`].
const DEFAULT_READ_CAPACITY: usize = 4 * 1024;
/// Size in bytes of the write buffer created by [`BufConnection::new`].
const DEFAULT_WRITE_CAPACITY: usize = 16 * 1024;
/// A connection wrapper that batches outgoing bytes and file descriptors in a
/// write buffer and serves reads from an intermediate read buffer.
///
/// Writes that fit in the write buffer's capacity are accumulated and only
/// reach the inner connection on a flush, or when a new write would not fit in
/// the remaining space; larger writes are forwarded to the inner connection.
/// Reads are copied out of the read buffer, which is refilled from the inner
/// connection when it holds fewer bytes than the caller asked for.
pub struct BufConnection<C: ?Sized> {
// Received bytes and file descriptors not yet handed to a caller.
read_buf: ReadBuffer,
// Outgoing bytes and file descriptors not yet passed to `conn`.
write_buf: WriteBuffer,
// The wrapped connection; it is the last field so that `C` may be unsized.
conn: C,
}
/// Holds bytes received from the inner connection that have not been consumed.
struct ReadBuffer {
// Fixed-size backing storage, allocated once at construction.
buf: Box<[u8]>,
// Indices into `buf` that hold unread bytes; rewound to `0..0` once drained.
valid_range: Range<usize>,
// File descriptors received alongside the data, waiting to be handed out.
fds: Vec<Fd>,
}
/// Holds outgoing bytes and file descriptors until they are flushed.
struct WriteBuffer {
// Fixed-size backing storage, allocated once at construction.
buf: Box<[u8]>,
// Number of bytes at the start of `buf` that are queued for sending.
writable: usize,
// File descriptors queued to be sent with the next flush.
fds: Vec<Fd>,
}
impl<C> From<C> for BufConnection<C> {
fn from(conn: C) -> Self {
Self::new(conn)
}
}
impl<C> AsRef<C> for BufConnection<C>
where
    C: ?Sized,
{
    /// Borrows the wrapped connection.
    fn as_ref(&self) -> &C {
        let inner: &C = &self.conn;
        inner
    }
}
impl<C> AsMut<C> for BufConnection<C>
where
    C: ?Sized,
{
    /// Mutably borrows the wrapped connection.
    ///
    /// I/O performed through the returned reference does not pass through this
    /// wrapper's read or write buffers.
    fn as_mut(&mut self) -> &mut C {
        let inner: &mut C = &mut self.conn;
        inner
    }
}
cfg_std_unix! {
    impl<C> AsRawFd for BufConnection<C>
    where
        C: AsRawFd + ?Sized,
    {
        /// Returns the raw file descriptor of the wrapped connection.
        fn as_raw_fd(&self) -> RawFd {
            AsRawFd::as_raw_fd(&self.conn)
        }
    }
}
cfg_std_windows! {
    impl<C> AsRawSocket for BufConnection<C>
    where
        C: AsRawSocket + ?Sized,
    {
        /// Returns the raw socket handle of the wrapped connection.
        fn as_raw_socket(&self) -> RawSocket {
            AsRawSocket::as_raw_socket(&self.conn)
        }
    }
}
impl<C> BufConnection<C> {
    /// Wraps `conn` using `DEFAULT_READ_CAPACITY` and `DEFAULT_WRITE_CAPACITY`
    /// as the buffer sizes.
    pub fn new(conn: C) -> Self {
        let (read_capacity, write_capacity) = (DEFAULT_READ_CAPACITY, DEFAULT_WRITE_CAPACITY);
        Self::with_capacity(read_capacity, write_capacity, conn)
    }

    /// Wraps `conn` with a read buffer of `read_capacity` bytes and a write
    /// buffer of `write_capacity` bytes, both zero-filled and initially empty.
    pub fn with_capacity(read_capacity: usize, write_capacity: usize, conn: C) -> Self {
        Self {
            read_buf: ReadBuffer {
                buf: vec![0; read_capacity].into_boxed_slice(),
                valid_range: 0..0,
                fds: Vec::new(),
            },
            write_buf: WriteBuffer {
                buf: vec![0; write_capacity].into_boxed_slice(),
                writable: 0,
                fds: Vec::new(),
            },
            conn,
        }
    }
}
impl<C: Connection + ?Sized> BufConnection<C> {
    /// Sends every buffered byte, together with any queued file descriptors,
    /// to the inner connection and then empties the write buffer.
    ///
    /// Returns the number of bytes the inner connection accepted.
    ///
    /// # Errors
    ///
    /// Returns the first error reported by the inner connection. Before
    /// returning, the bytes that were already accepted are removed from the
    /// buffer (the unsent tail is moved to the front), so a later flush resumes
    /// where this one stopped instead of sending the same bytes twice.
    fn flush_write_buffer(&mut self) -> Result<usize> {
        let writable = self.write_buf.writable;
        let mut nwritten = 0;
        // NOTE(review): an inner connection that keeps returning `Ok(0)` would
        // make this loop spin without progress; confirm inner impls never do.
        while nwritten < writable {
            let buffer = &self.write_buf.buf[nwritten..writable];
            let sent = if self.write_buf.fds.is_empty() {
                self.conn.send_slice(buffer)
            } else {
                self.conn
                    .send_slices_and_fds(&[new_io_slice(buffer)], &mut self.write_buf.fds)
            };
            match sent {
                Ok(amt) => nwritten += amt,
                Err(err) => {
                    // Drop the prefix that already went out so a retry does not
                    // resend it; keep the unsent tail at the front of the buffer.
                    self.write_buf.buf.copy_within(nwritten..writable, 0);
                    self.write_buf.writable = writable - nwritten;
                    return Err(err);
                }
            }
        }
        self.write_buf.flush();
        tracing::trace!("Flushed {} bytes to underlying connection", nwritten);
        Ok(nwritten)
    }

    /// Copies as much of `slice` as fits into the free tail of the write buffer
    /// and returns the number of bytes copied.
    ///
    /// Bytes beyond the spare capacity are not copied; callers are expected to
    /// make room (by flushing) before calling.
    fn copy_slice_to_buffer(&mut self, slice: &[u8]) -> usize {
        let amt = cmp::min(self.write_buf.spare_capacity(), slice.len());
        let out = self.write_buf.empty_slice();
        out[..amt].copy_from_slice(&slice[..amt]);
        self.write_buf.advance(amt);
        amt
    }

    /// Shared implementation of the vectored send methods.
    ///
    /// Flushes the buffer first if `slices` would not fit in the remaining
    /// space. If the data is larger than the whole buffer, it is forwarded via
    /// `write_handler(self, slices, true)` and that call's result is returned.
    /// Otherwise the data is copied into the buffer, `write_handler(self, &[],
    /// false)` is invoked so the caller can queue anything else (such as file
    /// descriptors) alongside it, and the full length is reported as written.
    fn send_slices_impl(
        &mut self,
        slices: &[IoSlice<'_>],
        write_handler: impl FnOnce(&mut Self, &[IoSlice<'_>], bool) -> Result<usize>,
    ) -> Result<usize> {
        let total_len = slices
            .iter()
            .map(|s| s.len())
            .fold(0usize, usize::saturating_add);
        let span = tracing::debug_span!(
            "BufConnection::send_slices_impl",
            num_slices = slices.len(),
            total_len = total_len
        );
        let _enter = span.enter();
        if self.write_buf.spare_capacity() <= total_len {
            tracing::trace!("flushing write buffer");
            self.flush_write_buffer()?;
        }
        if total_len > self.write_buf.capacity() {
            tracing::debug!(
                "write is too large for buffer, \
                forwarding to inner impl"
            );
            return write_handler(self, slices, true);
        }
        // After a successful flush the buffer is empty, so everything fits.
        let mut nwritten = 0;
        for slice in slices {
            nwritten += self.copy_slice_to_buffer(slice);
        }
        tracing::trace!("wrote {} bytes to buffer", nwritten);
        write_handler(self, &[], false)?;
        Ok(total_len)
    }
}
impl<C: Connection + ?Sized> BufConnection<C> {
    /// Copies as many unread bytes as fit into `slice`, marks them consumed,
    /// and returns how many were copied.
    fn copy_into_slice(&mut self, slice: &mut [u8]) -> usize {
        let readable = self.read_buf.readable_slice();
        let amt = cmp::min(slice.len(), readable.len());
        slice[..amt].copy_from_slice(&readable[..amt]);
        self.read_buf.advance_read(amt);
        amt
    }

    /// Fills `slices` in order from the read buffer and returns the total
    /// number of bytes copied.
    fn copy_into_slices(&mut self, slices: &mut [IoSliceMut<'_>]) -> usize {
        let mut amt_copied = 0;
        for slice in slices {
            amt_copied += self.copy_into_slice(&mut *slice);
        }
        amt_copied
    }

    /// Moves the unread bytes to the front of the read buffer so that the whole
    /// remainder of the buffer is available to the next refill.
    ///
    /// Without this, a partially consumed buffer whose unread bytes sit at the
    /// very end would be refilled through an empty slice.
    fn compact_read_buffer(&mut self) {
        let unread = self.read_buf.valid_range.clone();
        if unread.start > 0 {
            self.read_buf.buf.copy_within(unread.start..unread.end, 0);
            self.read_buf.valid_range = 0..(unread.end - unread.start);
        }
    }

    /// Shared implementation of the vectored receive methods.
    ///
    /// If the buffer holds fewer bytes than `slices` can take, performs one
    /// refill by calling `read_handler` with the inner connection, the free
    /// tail of the read buffer and the buffer's fd queue. It then copies as much
    /// as possible into `slices` and moves every queued fd into `fds`.
    /// Returns the number of bytes copied, which may be less than requested.
    ///
    /// # Errors
    ///
    /// Propagates any error from `read_handler`.
    fn recv_slices_impl(
        &mut self,
        slices: &mut [IoSliceMut<'_>],
        fds: &mut Vec<Fd>,
        mut read_handler: impl FnMut(&mut C, &mut [u8], &mut Vec<Fd>) -> Result<usize>,
    ) -> Result<usize> {
        let total_len = slices
            .iter()
            .map(|s| s.len())
            .fold(0usize, usize::saturating_add);
        let span = tracing::debug_span!(
            "BufConnection::recv_slices_impl",
            num_slices = slices.len(),
            total_len = total_len,
        );
        let _enter = span.enter();
        if total_len > self.read_buf.readable() {
            tracing::debug!(
                "total length {} does not fit in buffer of size {}, \
                forwarding to read_handler",
                total_len,
                self.read_buf.readable()
            );
            self.compact_read_buffer();
            let amt = read_handler(
                &mut self.conn,
                &mut self.read_buf.buf[self.read_buf.valid_range.end..],
                &mut self.read_buf.fds,
            )?;
            self.read_buf.advance_write(amt);
        }
        let amt_copied = self.copy_into_slices(slices);
        fds.append(&mut self.read_buf.fds);
        tracing::trace!(
            "copied amt {} of {} bytes into buffer",
            amt_copied,
            total_len
        );
        Ok(amt_copied)
    }

    /// Shared implementation of the single-slice receive methods.
    ///
    /// Behaves like `recv_slices_impl`, except that the refill passes the free
    /// tail of the read buffer as a one-element iovec. Queued fds are moved into
    /// `fds` only when it is `Some`; otherwise they stay queued in the buffer.
    ///
    /// # Errors
    ///
    /// Propagates any error from `read_handler`.
    fn recv_slice_impl(
        &mut self,
        slice: &mut [u8],
        fds: Option<&mut Vec<Fd>>,
        mut read_handler: impl FnMut(&mut C, &mut [IoSliceMut<'_>], &mut Vec<Fd>) -> Result<usize>,
    ) -> Result<usize> {
        let span = tracing::debug_span!("BufConnection::recv_slice_impl", len = slice.len(),);
        let _enter = span.enter();
        if slice.len() > self.read_buf.readable() {
            self.compact_read_buffer();
            let mut iov = [new_io_slice_mut(
                &mut self.read_buf.buf[self.read_buf.valid_range.end..],
            )];
            let amt = read_handler(&mut self.conn, &mut iov, &mut self.read_buf.fds)?;
            self.read_buf.advance_write(amt);
        }
        let amt = self.copy_into_slice(slice);
        if let Some(fds) = fds {
            fds.append(&mut self.read_buf.fds);
        }
        tracing::trace!("copied amt {} of {} bytes into buffer", amt, slice.len());
        Ok(amt)
    }
}
impl<Conn: Connection + ?Sized> Connection for BufConnection<Conn> {
    fn recv_slices_and_fds(
        &mut self,
        slices: &mut [IoSliceMut<'_>],
        fds: &mut Vec<Fd>,
    ) -> Result<usize> {
        // Refills land in one contiguous region of the read buffer, so the
        // inner connection is driven through its single-slice method.
        self.recv_slices_impl(slices, fds, |inner, region, inner_fds| {
            inner.recv_slice_and_fds(region, inner_fds)
        })
    }

    fn recv_slice_and_fds(&mut self, slice: &mut [u8], fds: &mut Vec<Fd>) -> Result<usize> {
        // Refills are issued through a one-element iovec over the read buffer.
        self.recv_slice_impl(slice, Some(fds), |inner, iov, inner_fds| {
            inner.recv_slices_and_fds(iov, inner_fds)
        })
    }

    fn recv_slice(&mut self, slice: &mut [u8]) -> Result<usize> {
        // No caller-side fd vector: received fds stay queued in the read buffer.
        self.recv_slice_impl(slice, None, |inner, iov, inner_fds| {
            inner.recv_slices_and_fds(iov, inner_fds)
        })
    }

    fn non_blocking_recv_slices_and_fds(
        &mut self,
        slices: &mut [IoSliceMut<'_>],
        fds: &mut Vec<Fd>,
    ) -> Result<usize> {
        self.recv_slices_impl(slices, fds, |inner, region, inner_fds| {
            inner.non_blocking_recv_slice_and_fds(region, inner_fds)
        })
    }

    fn non_blocking_recv_slice_and_fds(
        &mut self,
        slice: &mut [u8],
        fds: &mut Vec<Fd>,
    ) -> Result<usize> {
        self.recv_slice_impl(slice, Some(fds), |inner, iov, inner_fds| {
            inner.non_blocking_recv_slices_and_fds(iov, inner_fds)
        })
    }

    fn send_slices_and_fds(&mut self, slices: &[IoSlice<'_>], fds: &mut Vec<Fd>) -> Result<usize> {
        self.send_slices_impl(slices, move |this, pending, true_write| {
            if !true_write {
                // The bytes were buffered; queue the fds to go out with them.
                this.write_buf.fds.append(fds);
                return Ok(0);
            }
            this.conn.send_slices_and_fds(pending, fds)
        })
    }

    fn send_slices(&mut self, slices: &[IoSlice<'_>]) -> Result<usize> {
        self.send_slices_impl(slices, |this, pending, true_write| match true_write {
            true => this.conn.send_slices(pending),
            false => Ok(0),
        })
    }

    fn send_slice(&mut self, slice: &[u8]) -> Result<usize> {
        let len = slice.len();
        if len >= self.write_buf.spare_capacity() {
            self.flush_write_buffer()?;
        }
        if len <= self.write_buf.capacity() {
            self.copy_slice_to_buffer(slice);
            Ok(len)
        } else {
            // Too large to ever fit in the buffer: hand it straight through.
            self.conn.send_slice(slice)
        }
    }

    fn flush(&mut self) -> Result<()> {
        self.flush_write_buffer()?;
        self.conn.flush()
    }

    fn shutdown(&self) -> Result<()> {
        self.conn.shutdown()
    }
}
impl ReadBuffer {
    /// Borrows the bytes that have been received but not yet consumed.
    fn readable_slice(&self) -> &[u8] {
        let Range { start, end } = self.valid_range;
        &self.buf[start..end]
    }

    /// Number of unread bytes currently buffered.
    fn readable(&self) -> usize {
        let unread = self.readable_slice();
        unread.len()
    }

    /// Records that `n` more bytes were stored just past the unread region.
    fn advance_write(&mut self, n: usize) {
        let end = self.valid_range.end + n;
        self.valid_range.end = end;
        debug_assert!(end <= self.buf.len());
    }

    /// Marks `n` unread bytes as consumed; once nothing is left unread the
    /// range is rewound to the start of the buffer.
    fn advance_read(&mut self, n: usize) {
        let range = &mut self.valid_range;
        range.start += n;
        debug_assert!(range.start <= range.end);
        if range.start >= range.end {
            *range = 0..0;
        }
    }
}
impl WriteBuffer {
    /// Free tail of the buffer that new outgoing bytes are copied into.
    fn empty_slice(&mut self) -> &mut [u8] {
        let filled = self.writable;
        &mut self.buf[filled..]
    }

    /// Records that `n` more bytes were copied into the free tail.
    fn advance(&mut self, n: usize) {
        self.writable += n;
        debug_assert!(self.writable <= self.capacity());
    }

    /// Forgets all queued bytes, making the whole buffer free again.
    fn flush(&mut self) {
        self.writable = 0;
    }

    /// Number of bytes that can still be queued before the buffer is full.
    fn spare_capacity(&self) -> usize {
        self.capacity() - self.writable
    }

    /// Total size of the backing storage.
    fn capacity(&self) -> usize {
        self.buf.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{connection::with_test_connection, utils::setup_tracing};
    #[cfg(all(feature = "std", unix))]
    use core::mem::ManuallyDrop;

    #[cfg(all(feature = "std", unix))]
    #[test]
    fn test_write_vectored() {
        setup_tracing();
        for buf_size in [16384, 5].iter().copied() {
            with_test_connection(
                &[],
                vec![],
                |conn| {
                    let mut bc = BufConnection::with_capacity(buf_size, buf_size, conn);
                    let iov = [new_io_slice(b"Hello,"), new_io_slice(b" world!")];
                    let mut fds = (15i32..20).map(Fd::new).collect::<Vec<_>>();
                    let amt = bc.send_slices_and_fds(&iov, &mut fds).unwrap();
                    assert_eq!(amt, 13);
                    bc.flush().unwrap();
                },
                |write_bytes, write_fds| {
                    assert_eq!(write_bytes.as_slice(), b"Hello, world!".as_ref());
                    assert_eq!(write_fds, vec![15, 16, 17, 18, 19]);
                },
            );
        }
    }

    #[cfg(all(feature = "std", unix))]
    #[test]
    fn test_read_vectored() {
        setup_tracing();
        for buf_size in [16384].iter().copied() {
            with_test_connection(
                b"Hello, world!",
                vec![15, 16, 17, 18, 19],
                |conn| {
                    let mut bc = BufConnection::with_capacity(buf_size, buf_size, conn);
                    let mut buffer = [0; 13];
                    let (buf1, buf2) = buffer.split_at_mut(3);
                    let (buf2, buf3) = buf2.split_at_mut(3);
                    let buf3 = &mut buf3[..3];
                    let mut iov = [
                        new_io_slice_mut(buf1),
                        new_io_slice_mut(buf2),
                        new_io_slice_mut(buf3),
                    ];
                    // The fds are fake; never drop (close) them.
                    let mut fds = ManuallyDrop::new(vec![]);
                    let amt = bc.recv_slices_and_fds(&mut iov, &mut fds).unwrap();
                    let fds = ManuallyDrop::into_inner(fds)
                        .into_iter()
                        .map(|f| f.as_raw_fd())
                        .collect::<Vec<_>>();
                    assert_eq!(amt, 9);
                    assert_eq!(&buffer[..9], b"Hello, wo".as_ref());
                    assert_eq!(fds, vec![15, 16, 17, 18, 19]);
                    let (buf1, buf2) = buffer.split_at_mut(2);
                    let buf2 = &mut buf2[..2];
                    let mut iov = [new_io_slice_mut(buf1), new_io_slice_mut(buf2)];
                    let mut fds = vec![];
                    assert_eq!(bc.recv_slices_and_fds(&mut iov, &mut fds).unwrap(), 4);
                    assert_eq!(&buffer[..4], b"rld!".as_ref());
                },
                |_, _| {},
            );
        }
    }

    #[test]
    fn test_write_vectored_without_fds() {
        setup_tracing();
        for buf_size in [16384, 5].iter().copied() {
            with_test_connection(
                &[],
                vec![],
                |conn| {
                    let mut bc = BufConnection::with_capacity(buf_size, buf_size, conn);
                    let iov = [new_io_slice(b"Hello,"), new_io_slice(b" world!")];
                    let amt = bc.send_slices(&iov).unwrap();
                    assert_eq!(amt, 13);
                    bc.flush().unwrap();
                },
                |write_bytes, write_fds| {
                    assert_eq!(write_bytes.as_slice(), b"Hello, world!".as_ref());
                    assert_eq!(write_fds, vec![]);
                },
            );
        }
    }

    #[test]
    fn test_write_buffer() {
        setup_tracing();
        for buf_size in [16384, 5].iter().copied() {
            with_test_connection(
                &[],
                vec![],
                |conn| {
                    let mut bc = BufConnection::with_capacity(buf_size, buf_size, conn);
                    assert_eq!(bc.send_slice(b"Hello, world!").unwrap(), 13);
                    bc.flush().unwrap();
                },
                |write_bytes, write_fds| {
                    assert_eq!(&write_bytes, b"Hello, world!".as_ref());
                    assert_eq!(write_fds, vec![]);
                },
            );
        }
    }

    #[test]
    fn test_read_buffer() {
        setup_tracing();
        // 16384 matches the large buffer size used by the other tests.
        for buf_size in [16384, 6].iter().copied() {
            with_test_connection(
                b"Hello, world!",
                vec![],
                |conn| {
                    let mut bc = BufConnection::with_capacity(buf_size, buf_size, conn);
                    let mut buf = [0; 5];
                    assert_eq!(bc.recv_slice(&mut buf).unwrap(), 5);
                    assert_eq!(buf, b"Hello".as_ref());
                    let mut buf = [0; 8];
                    let mut nread = 0;
                    while nread < 8 {
                        nread += bc.recv_slice(&mut buf[nread..]).unwrap();
                    }
                    assert_eq!(buf, b", world!".as_ref());
                    assert_eq!(bc.recv_slice(&mut buf).unwrap(), 0);
                },
                |_, _| {},
            );
        }
    }

    /// Reads through a tiny buffer in chunks that leave unread bytes at the
    /// buffer's end, exercising refills of a partially consumed buffer.
    #[test]
    fn test_read_buffer_small_chunks() {
        setup_tracing();
        with_test_connection(
            b"Hello, world!",
            vec![],
            |conn| {
                let mut bc = BufConnection::with_capacity(4, 4, conn);
                let mut out = [0u8; 13];
                let mut nread = 0;
                while nread < out.len() {
                    let end = (nread + 3).min(out.len());
                    nread += bc.recv_slice(&mut out[nread..end]).unwrap();
                }
                assert_eq!(&out, b"Hello, world!");
            },
            |_, _| {},
        );
    }
}