use std::{future::Future, io::IoSlice, time::Instant};
use crate::{Errno, OwnedFd, operations, try_clone_owned_fd};
/// Asynchronous byte-stream reading with an optional deadline.
///
/// Every method accepts `deadline: Option<Instant>`; how (and whether) it is enforced
/// is left to the implementation.
pub trait AsyncStreamRead {
    /// Reads until `buffer` is completely filled.
    ///
    /// The `OwnedFd`-backed implementations in this module treat an end of stream
    /// reached before `buffer` is full as an `EPIPE` error.
    fn read<'a>(
        &'a mut self,
        buffer: &'a mut [u8],
        deadline: Option<Instant>,
    ) -> impl Future<Output = Result<(), Errno>> + 'a;

    /// Reads at most `buffer.len()` bytes and resolves to the number of bytes copied.
    fn try_read<'a>(
        &'a mut self,
        buffer: &'a mut [u8],
        deadline: Option<Instant>,
    ) -> impl Future<Output = Result<usize, Errno>> + 'a;
}
/// Asynchronous byte-stream writing with an optional deadline.
pub trait AsyncStreamWrite {
    /// Writes every byte of `buffer`.
    fn write<'a>(
        &'a mut self,
        buffer: &'a [u8],
        deadline: Option<Instant>,
    ) -> impl Future<Output = Result<(), Errno>> + 'a;

    /// Writes each slice of `buffers`, in order.
    ///
    /// The default implementation issues one [`AsyncStreamWrite::write`] call per slice
    /// and stops at the first error. Descriptor-backed implementors may override it with
    /// a single vectored write.
    fn writev<'a>(
        &'a mut self,
        buffers: &'a mut [IoSlice<'a>],
        deadline: Option<Instant>,
    ) -> impl Future<Output = Result<(), Errno>> + 'a {
        async move {
            for slice in buffers.iter() {
                // `&IoSlice` derefs to `&[u8]`.
                self.write(slice, deadline).await?;
            }
            Ok(())
        }
    }

    /// Shuts the stream down; the exact semantics are implementation-defined.
    fn shutdown(&mut self) -> impl Future<Output = Result<(), Errno>>;

    /// Closes the stream.
    fn close(&mut self) -> impl Future<Output = Result<(), Errno>>;
}
pub trait SplittableStream {
type ReadStream: AsyncStreamRead;
type WriteStream: AsyncStreamWrite;
async fn split(self) -> Result<(Self::ReadStream, Self::WriteStream), Errno>;
}
/// Read-ahead buffer used by the `OwnedFd`-backed readers.
///
/// The bytes not yet handed to a caller are
/// `read_buffer[read_used..read_used + read_available]`.
struct OwnedFdReadState {
    /// Storage the descriptor is read into when the buffer runs dry.
    read_buffer: [u8; 16384],
    /// Offset of the first byte not yet handed to a caller.
    read_used: usize,
    /// Number of buffered bytes, starting at `read_used`, still to be handed out.
    read_available: usize,
}
/// Read half returned by [`SplittableStream::split`] on an [`OwnedFdStream`].
///
/// It inherits whatever the original stream had already buffered, so no read-ahead
/// bytes are lost by splitting.
pub struct OwnedFdStreamRead {
    /// Read-ahead buffer carried over from the original stream.
    read_state: OwnedFdReadState,
    /// Descriptor to read from; `None` if the stream had been closed before splitting.
    fd: Option<OwnedFd>,
}
impl OwnedFdStreamRead {
pub async fn close(&self) -> Result<(), Errno> {
if let Some(fd) = &self.fd {
operations::close(fd.try_clone().unwrap()).await?;
}
Ok(())
}
}
/// Write half returned by [`SplittableStream::split`] on an [`OwnedFdStream`].
pub struct OwnedFdStreamWrite {
    /// Descriptor to write to. `AsyncStreamWrite::close` takes it, leaving `None`,
    /// after which every operation fails with `EPIPE`.
    fd: Option<OwnedFd>,
}
/// Bidirectional stream over an owned file descriptor, with a read-ahead buffer.
pub struct OwnedFdStream {
    /// Read-ahead buffer serving `try_read` / `read`.
    read_state: OwnedFdReadState,
    /// The wrapped descriptor; `None` once `AsyncStreamWrite::close` has taken it.
    fd: Option<OwnedFd>,
}
impl OwnedFdStream {
    /// Wraps `fd` in a stream whose read-ahead buffer starts out empty.
    pub fn new(fd: OwnedFd) -> Self {
        let read_state = OwnedFdReadState {
            read_buffer: [0; 16384],
            read_used: 0,
            read_available: 0,
        };
        Self {
            read_state,
            fd: Some(fd),
        }
    }

    /// Consumes the stream and returns its descriptor, or `None` if it was closed.
    ///
    /// NOTE: any bytes already pulled into the read-ahead buffer but not yet returned
    /// by `try_read` / `read` are dropped along with the stream.
    pub fn into_inner(self) -> Option<OwnedFd> {
        let Self { fd, .. } = self;
        fd
    }
}
impl SplittableStream for OwnedFdStream {
    type ReadStream = OwnedFdStreamRead;
    type WriteStream = OwnedFdStreamWrite;

    /// Splits into a read half and a write half.
    ///
    /// The write half keeps the original descriptor. The read half gets a clone made by
    /// `try_clone_owned_fd` (presumably a `dup`) plus the existing read-ahead buffer.
    /// A closed stream (no descriptor) yields two halves without descriptors.
    ///
    /// # Errors
    /// Propagates a failure to clone the descriptor.
    async fn split(self) -> Result<(OwnedFdStreamRead, OwnedFdStreamWrite), Errno> {
        let Self { fd, read_state } = self;
        let read_fd = match fd.as_ref() {
            Some(original) => Some(try_clone_owned_fd(original)?),
            None => None,
        };
        let reader = OwnedFdStreamRead {
            fd: read_fd,
            read_state,
        };
        let writer = OwnedFdStreamWrite { fd };
        Ok((reader, writer))
    }
}
/// Reads up to `buffer.len()` bytes through the read-ahead buffer in `read_state`.
///
/// Buffered bytes are served first. Only when the read-ahead buffer is empty is `fd`
/// read into it (at most 16384 bytes per refill).
///
/// Returns the number of bytes copied into `buffer`. `Ok(0)` means either end of
/// stream (the refill read returned 0) or an empty `buffer`, in which case no read is
/// issued.
///
/// # Errors
/// `EPIPE` if `fd` is `None`; otherwise whatever `operations::read_with_deadline`
/// reports (presumably including deadline expiry; confirm in `operations`).
async fn try_read_impl(
    fd: &mut Option<OwnedFd>,
    buffer: &mut [u8],
    read_state: &mut OwnedFdReadState,
    deadline: Option<Instant>,
) -> Result<usize, Errno> {
    let Some(fd) = fd else {
        return Err(Errno::from_raw_os_error(crate::EPIPE));
    };
    // A zero-length request is trivially satisfied. Without this guard an empty
    // read-ahead buffer would trigger a blocking refill (which may hit the deadline)
    // and then report `Ok(0)`, indistinguishable from end of stream.
    if buffer.is_empty() {
        return Ok(0);
    }
    if read_state.read_available == 0 {
        let amount =
            operations::read_with_deadline(fd, &mut read_state.read_buffer, deadline).await?;
        if amount == 0 {
            return Ok(0);
        }
        read_state.read_used = 0;
        read_state.read_available = amount;
    }
    let start = read_state.read_used;
    let count = buffer.len().min(read_state.read_available);
    buffer[..count].copy_from_slice(&read_state.read_buffer[start..start + count]);
    read_state.read_used += count;
    read_state.read_available -= count;
    Ok(count)
}
/// Fills all of `buffer`, calling `try_read_impl` as many times as needed.
///
/// An empty `buffer` succeeds immediately without touching `fd`.
///
/// # Errors
/// `EPIPE` if the stream ends before `buffer` is full. Any error from
/// `try_read_impl` is propagated as-is.
async fn read_impl(
    fd: &mut Option<OwnedFd>,
    buffer: &mut [u8],
    read_state: &mut OwnedFdReadState,
    deadline: Option<Instant>,
) -> Result<(), Errno> {
    let mut filled = 0;
    while filled < buffer.len() {
        match try_read_impl(fd, &mut buffer[filled..], read_state, deadline).await? {
            0 => return Err(Errno::from_raw_os_error(crate::EPIPE)),
            copied => filled += copied,
        }
    }
    Ok(())
}
impl AsyncStreamRead for OwnedFdStreamRead {
    /// Serves from the inherited read-ahead buffer, refilling from the descriptor
    /// only when it is empty.
    fn try_read<'a>(
        &'a mut self,
        buffer: &'a mut [u8],
        deadline: Option<Instant>,
    ) -> impl Future<Output = Result<usize, Errno>> + 'a {
        let Self { fd, read_state } = self;
        try_read_impl(fd, buffer, read_state, deadline)
    }

    /// Fills `buffer` completely; a premature end of stream is reported as `EPIPE`.
    fn read<'a>(
        &'a mut self,
        buffer: &'a mut [u8],
        deadline: Option<Instant>,
    ) -> impl Future<Output = Result<(), Errno>> + 'a {
        let Self { fd, read_state } = self;
        read_impl(fd, buffer, read_state, deadline)
    }
}
impl AsyncStreamRead for OwnedFdStream {
    /// Serves from the read-ahead buffer, refilling from the descriptor only when it
    /// is empty.
    fn try_read<'a>(
        &'a mut self,
        buffer: &'a mut [u8],
        deadline: Option<Instant>,
    ) -> impl Future<Output = Result<usize, Errno>> + 'a {
        let Self { fd, read_state } = self;
        try_read_impl(fd, buffer, read_state, deadline)
    }

    /// Fills `buffer` completely; a premature end of stream is reported as `EPIPE`.
    fn read<'a>(
        &'a mut self,
        buffer: &'a mut [u8],
        deadline: Option<Instant>,
    ) -> impl Future<Output = Result<(), Errno>> + 'a {
        let Self { fd, read_state } = self;
        read_impl(fd, buffer, read_state, deadline)
    }
}
/// Writes all of `buffer` to `fd`, retrying after short writes.
///
/// # Errors
/// `EPIPE` if `fd` is `None` or a write reports 0 bytes. Otherwise returns the error
/// from `operations::write_with_deadline`.
async fn write_impl(
    fd: &mut Option<OwnedFd>,
    buffer: &[u8],
    deadline: Option<Instant>,
) -> Result<(), Errno> {
    let Some(fd) = fd else {
        return Err(Errno::from_raw_os_error(crate::EPIPE));
    };
    let mut offset = 0;
    while offset < buffer.len() {
        match operations::write_with_deadline(&fd, &buffer[offset..], deadline).await? {
            // A zero-byte write would loop forever; treat it as a broken pipe.
            0 => return Err(Errno::from_raw_os_error(crate::EPIPE)),
            written => offset += written,
        }
    }
    Ok(())
}
/// Writes every byte of `buffers` to `fd` using vectored writes, retrying after
/// short writes.
///
/// The caller's slices are advanced in place as data is written, so on return
/// their contents should not be relied upon.
///
/// # Errors
/// `EPIPE` if `fd` is `None` or a write with data pending reports 0 bytes.
/// Otherwise returns the error from `operations::writev_with_deadline`.
async fn writev_impl(
    fd: &mut Option<OwnedFd>,
    mut buffers: &mut [IoSlice<'_>],
    deadline: Option<Instant>,
) -> Result<(), Errno> {
    let Some(fd) = fd else {
        return Err(Errno::from_raw_os_error(crate::EPIPE));
    };
    // Drop leading zero-length slices first. Without this, a list made only of
    // empty slices makes `writev` return 0, which was misreported as `EPIPE`,
    // whereas `write_impl` and the trait's default `writev` accept empty input.
    // After this point `advance_slices` keeps the invariant that the first
    // remaining slice is non-empty.
    IoSlice::advance_slices(&mut buffers, 0);
    while !buffers.is_empty() {
        let written = operations::writev_with_deadline(&fd, buffers, None, deadline).await?;
        if written == 0 {
            return Err(Errno::from_raw_os_error(crate::EPIPE));
        }
        IoSlice::advance_slices(&mut buffers, written);
    }
    Ok(())
}
/// Shuts down both directions of `fd` (`SHUT_RDWR`).
///
/// NOTE(review): after `split`, the read half holds a clone of this descriptor.
/// Presumably shutting down the write half therefore also ends reads on the read
/// half; confirm whether `SHUT_WR` was intended for split write halves.
///
/// # Errors
/// `EPIPE` if `fd` is `None`; otherwise the error from `operations::shutdown`.
async fn shutdown_impl(fd: &mut Option<OwnedFd>) -> Result<(), Errno> {
    match fd {
        Some(fd) => {
            operations::shutdown(fd, libc::SHUT_RDWR).await?;
            Ok(())
        }
        None => Err(Errno::from_raw_os_error(crate::EPIPE)),
    }
}
/// Takes the descriptor out of `fd` and closes it, leaving `None` behind.
///
/// The descriptor is taken before closing, so `fd` is `None` afterwards even if
/// the close itself fails.
///
/// # Errors
/// `EPIPE` if there was no descriptor; otherwise the error from `operations::close`.
async fn close_impl(fd: &mut Option<OwnedFd>) -> Result<(), Errno> {
    let Some(owned) = fd.take() else {
        return Err(Errno::from_raw_os_error(crate::EPIPE));
    };
    operations::close(owned).await?;
    Ok(())
}
impl AsyncStreamWrite for OwnedFdStreamWrite {
    /// Writes all of `buffer`, retrying after short writes.
    fn write<'a>(
        &'a mut self,
        buffer: &'a [u8],
        deadline: Option<Instant>,
    ) -> impl Future<Output = Result<(), Errno>> + 'a {
        let Self { fd } = self;
        write_impl(fd, buffer, deadline)
    }

    /// Writes all of `buffers` with vectored writes instead of the per-slice default.
    fn writev<'a>(
        &'a mut self,
        buffers: &'a mut [IoSlice<'a>],
        deadline: Option<Instant>,
    ) -> impl Future<Output = Result<(), Errno>> + 'a {
        let Self { fd } = self;
        writev_impl(fd, buffers, deadline)
    }

    /// Shuts down both directions of the underlying descriptor (`SHUT_RDWR`).
    fn shutdown(&mut self) -> impl Future<Output = Result<(), Errno>> {
        let Self { fd } = self;
        shutdown_impl(fd)
    }

    /// Takes and closes the descriptor; later operations fail with `EPIPE`.
    fn close(&mut self) -> impl Future<Output = Result<(), Errno>> {
        let Self { fd } = self;
        close_impl(fd)
    }
}
impl AsyncStreamWrite for OwnedFdStream {
    /// Writes all of `buffer`, retrying after short writes.
    fn write<'a>(
        &'a mut self,
        buffer: &'a [u8],
        deadline: Option<Instant>,
    ) -> impl Future<Output = Result<(), Errno>> + 'a {
        let Self { fd, .. } = self;
        write_impl(fd, buffer, deadline)
    }

    /// Writes all of `buffers` with vectored writes instead of the per-slice default.
    fn writev<'a>(
        &'a mut self,
        buffers: &'a mut [IoSlice<'a>],
        deadline: Option<Instant>,
    ) -> impl Future<Output = Result<(), Errno>> + 'a {
        let Self { fd, .. } = self;
        writev_impl(fd, buffers, deadline)
    }

    /// Shuts down both directions of the underlying descriptor (`SHUT_RDWR`).
    fn shutdown(&mut self) -> impl Future<Output = Result<(), Errno>> {
        let Self { fd, .. } = self;
        shutdown_impl(fd)
    }

    /// Takes and closes the descriptor; later reads and writes fail with `EPIPE`.
    fn close(&mut self) -> impl Future<Output = Result<(), Errno>> {
        let Self { fd, .. } = self;
        close_impl(fd)
    }
}
#[cfg(test)]
mod test {
    use crate::{
        AsyncStreamRead, AsyncStreamWrite, OwnedFdStream, SplittableStream, operations,
        pipe::bipipe,
    };
    use rustix::fd::{FromRawFd, IntoRawFd, OwnedFd};
    use std::io::{IoSlice, Seek, Write};

    /// Creates an anonymous temporary file holding `contents`, rewound to offset 0,
    /// and returns its descriptor.
    fn temp_file_fd(contents: &[u8]) -> OwnedFd {
        let mut file = tempfile::tempfile().unwrap();
        file.write_all(contents).expect("Failed to write to file");
        file.seek(std::io::SeekFrom::Start(0)).unwrap();
        // SAFETY: `into_raw_fd` relinquishes ownership of a valid, open descriptor,
        // so exactly one `OwnedFd` takes it over here.
        unsafe { OwnedFd::from_raw_fd(file.into_raw_fd()) }
    }

    /// Sends length-prefixed frames across a pipe pair and parses them on the far
    /// side. Each frame is read in two pieces, so the second piece is served from the
    /// read-ahead buffer.
    #[crate::test]
    async fn manual_channel_test() {
        let (client, server) = bipipe();
        let mut stream1 = OwnedFdStream::new(client);
        let mut stream2 = OwnedFdStream::new(server);
        // `u64`, not `usize`: the payload must be 8 bytes on every target. A `usize`
        // counter is 4 bytes on 32-bit and would panic in `copy_from_slice` below.
        let iters = 200u64;
        for iter in 0..iters {
            // Frame: [length: u32][id: u32][payload: u64], where `length` counts the
            // id + payload bytes that follow the prefix.
            let mut buffer = [0; 128];
            let id = 1u32;
            let length: u32 = 12;
            buffer[0..4].copy_from_slice(&length.to_le_bytes());
            buffer[4..8].copy_from_slice(&id.to_le_bytes());
            buffer[8..16].copy_from_slice(&iter.to_le_bytes());
            stream1.write(&buffer[0..16], None).await.unwrap();

            let mut response = [0; 128];
            stream2.read(&mut response[0..4], None).await.unwrap();
            let length = u32::from_le_bytes(response[0..4].try_into().unwrap()) as usize;
            stream2.read(&mut response[0..length], None).await.unwrap();
            let id = u32::from_le_bytes(response[0..4].try_into().unwrap()) as usize;
            assert_eq!(id, 1);
            let response = &response[4..length];
            assert_eq!(&iter.to_le_bytes(), response);
        }
        stream1.shutdown().await.unwrap();
        stream2.shutdown().await.unwrap();
    }

    #[crate::test]
    async fn try_read_test() {
        let mut stream = OwnedFdStream::new(temp_file_fd(b"hello world"));
        let mut buffer = [0; 100];
        let amount = stream.try_read(&mut buffer, None).await.unwrap();
        assert_eq!(amount, 11);
        // End of file is reported as a zero-length read, not an error.
        let amount = stream.try_read(&mut buffer, None).await.unwrap();
        assert_eq!(amount, 0);
        assert_eq!(&buffer[0..11], b"hello world");
    }

    #[crate::test]
    async fn short_read_test() {
        let stream = OwnedFdStream::new(temp_file_fd(b"hello world"));
        // Round-trip the descriptor through `into_inner` before reading.
        let fd = stream.into_inner().unwrap();
        let mut stream = OwnedFdStream::new(fd);
        let mut buffer = [0; 11];
        stream.read(&mut buffer, None).await.unwrap();
        assert_eq!(&buffer[0..11], b"hello world");
        // The file is exhausted, so a full read cannot be satisfied.
        let e = stream.read(&mut buffer, None).await;
        assert!(e.is_err());
    }

    #[crate::test]
    async fn read_failed_test() {
        let (client, server) = bipipe();
        operations::close(server).await.unwrap();
        let mut stream = OwnedFdStream::new(client);
        stream.read(&mut [0; 11], None).await.unwrap_err();
    }

    #[crate::test]
    async fn owned_fd_stream_split_test() {
        let (client, server) = bipipe();
        let stream = OwnedFdStream::new(client);
        let (_read_stream, mut write_stream) = stream.split().await.unwrap();
        write_stream.write(b"hello split", None).await.unwrap();
        let mut server_stream = OwnedFdStream::new(server);
        let mut buffer = [0u8; 11];
        server_stream.read(&mut buffer, None).await.unwrap();
        assert_eq!(&buffer, b"hello split");
    }

    #[crate::test]
    async fn owned_fd_stream_read_stream_try_read_test() {
        let (client, server) = bipipe();
        let stream = OwnedFdStream::new(client);
        let (mut read_stream, _write_stream) = stream.split().await.unwrap();
        let mut server_stream = OwnedFdStream::new(server);
        server_stream.write(b"test try_read", None).await.unwrap();
        let mut buffer = [0u8; 5];
        let amount = read_stream.try_read(&mut buffer, None).await.unwrap();
        assert_eq!(amount, 5);
        assert_eq!(&buffer, b"test ");
        let mut remaining = [0u8; 8];
        let remaining_amount = read_stream.try_read(&mut remaining, None).await.unwrap();
        assert_eq!(remaining_amount, 8);
        assert_eq!(&remaining, b"try_read");
    }

    #[crate::test]
    async fn owned_fd_stream_read_stream_read_test() {
        let (client, server) = bipipe();
        let stream = OwnedFdStream::new(client);
        let (mut read_stream, _write_stream) = stream.split().await.unwrap();
        let write_task = {
            let mut server_stream = OwnedFdStream::new(server);
            crate::operations::spawn_task(async move {
                server_stream.write(b"exact", None).await.unwrap();
                server_stream.write(b"concat", None).await.unwrap();
                server_stream
            })
        };
        let mut buffer = [0u8; 5];
        read_stream.read(&mut buffer, None).await.unwrap();
        assert_eq!(&buffer, b"exact");
        let mut buffer2 = [0u8; 6];
        read_stream.read(&mut buffer2, None).await.unwrap();
        assert_eq!(&buffer2, b"concat");
        write_task.await.unwrap();
    }

    #[crate::test]
    async fn owned_fd_stream_write_stream_write_test() {
        let (client, server) = bipipe();
        let stream = OwnedFdStream::new(client);
        let (_read_stream, mut write_stream) = stream.split().await.unwrap();
        write_stream.write(b"small", None).await.unwrap();
        write_stream
            .write(b" large data chunk", None)
            .await
            .unwrap();
        let mut server_stream = OwnedFdStream::new(server);
        let mut buffer = [0u8; 22];
        server_stream.read(&mut buffer, None).await.unwrap();
        assert_eq!(&buffer, b"small large data chunk");
    }

    #[crate::test]
    async fn owned_fd_stream_write_stream_writev_test() {
        let (client, server) = bipipe();
        let stream = OwnedFdStream::new(client);
        let (_read_stream, mut write_stream) = stream.split().await.unwrap();
        let buf1 = b"hello";
        let buf2 = b" ";
        let buf3 = b"vectored";
        let buf4 = b" write";
        let mut buffers = [
            IoSlice::new(buf1.as_slice()),
            IoSlice::new(buf2.as_slice()),
            IoSlice::new(buf3.as_slice()),
            IoSlice::new(buf4.as_slice()),
        ];
        write_stream.writev(&mut buffers, None).await.unwrap();
        let mut server_stream = OwnedFdStream::new(server);
        let mut buffer = [0u8; 20];
        server_stream.read(&mut buffer, None).await.unwrap();
        assert_eq!(&buffer, b"hello vectored write");
    }

    #[crate::test]
    async fn owned_fd_stream_write_stream_shutdown_test() {
        let (client, server) = bipipe();
        let stream = OwnedFdStream::new(client);
        let (_read_stream, mut write_stream) = stream.split().await.unwrap();
        write_stream.write(b"before shutdown", None).await.unwrap();
        write_stream.shutdown().await.unwrap();
        let mut server_stream = OwnedFdStream::new(server);
        let mut buffer = [0u8; 15];
        server_stream.read(&mut buffer, None).await.unwrap();
        assert_eq!(&buffer, b"before shutdown");
        let result = write_stream.write(b"after shutdown", None).await;
        assert!(result.is_err());
    }

    #[crate::test]
    async fn owned_fd_stream_write_stream_close_test() {
        let (client, server) = bipipe();
        let stream = OwnedFdStream::new(client);
        let (_read_stream, mut write_stream) = stream.split().await.unwrap();
        write_stream.write(b"before close", None).await.unwrap();
        write_stream.close().await.unwrap();
        let mut server_stream = OwnedFdStream::new(server);
        let mut buffer = [0u8; 12];
        server_stream.read(&mut buffer, None).await.unwrap();
        assert_eq!(&buffer, b"before close");
        let result = write_stream.write(b"after close", None).await;
        assert!(result.is_err());
    }

    #[crate::test]
    async fn owned_fd_stream_read_stream_close_test() {
        let (client, server) = bipipe();
        let stream = OwnedFdStream::new(client);
        let (mut read_stream, _write_stream) = stream.split().await.unwrap();
        let mut server_stream = OwnedFdStream::new(server);
        server_stream.write(b"test data", None).await.unwrap();
        server_stream.close().await.unwrap();
        let mut buffer = [0u8; 9];
        read_stream.read(&mut buffer, None).await.unwrap();
        assert_eq!(&buffer, b"test data");
        read_stream.close().await.unwrap();
        let mut empty_buffer = [0u8; 1];
        let result = read_stream.read(&mut empty_buffer, None).await;
        assert!(result.is_err());
    }

    #[crate::test]
    async fn owned_fd_stream_split_concurrent_operations_test() {
        let (client, server) = bipipe();
        let stream1 = OwnedFdStream::new(client);
        let stream2 = OwnedFdStream::new(server);
        let (read1, write1) = stream1.split().await.unwrap();
        let (read2, write2) = stream2.split().await.unwrap();
        let write_task = {
            let mut write1 = write1;
            let mut write2 = write2;
            crate::operations::spawn_task(async move {
                write1.write(b"from1to2", None).await.unwrap();
                write2.write(b"from2to1", None).await.unwrap();
                write1.shutdown().await.unwrap();
                write2.shutdown().await.unwrap();
            })
        };
        let read_task = {
            let mut read1 = read1;
            let mut read2 = read2;
            crate::operations::spawn_task(async move {
                let mut buffer1 = [0u8; 8];
                let mut buffer2 = [0u8; 8];
                read1.read(&mut buffer1, None).await.unwrap();
                read2.read(&mut buffer2, None).await.unwrap();
                (buffer1, buffer2)
            })
        };
        let (buffer1, buffer2) = read_task.await.unwrap();
        write_task.await.unwrap();
        assert_eq!(&buffer1, b"from2to1");
        assert_eq!(&buffer2, b"from1to2");
    }

    #[crate::test]
    async fn owned_fd_stream_split_writev_large_test() {
        let (client, server) = bipipe();
        let stream = OwnedFdStream::new(client);
        let (_read_stream, mut write_stream) = stream.split().await.unwrap();
        let buf1 = vec![65u8; 4096];
        let buf2 = vec![66u8; 4096];
        let buf3 = b"end";
        let mut buffers = [
            IoSlice::new(buf1.as_slice()),
            IoSlice::new(buf2.as_slice()),
            IoSlice::new(buf3.as_slice()),
        ];
        let read_task = {
            let mut server_stream = OwnedFdStream::new(server);
            crate::operations::spawn_task(async move {
                let mut result = Vec::new();
                let mut temp_buffer = [0u8; 1024];
                for _ in 0..8 {
                    server_stream.read(&mut temp_buffer, None).await.unwrap();
                    result.extend_from_slice(&temp_buffer);
                }
                let mut end_buffer = [0u8; 3];
                server_stream.read(&mut end_buffer, None).await.unwrap();
                result.extend_from_slice(&end_buffer);
                result
            })
        };
        write_stream.writev(&mut buffers, None).await.unwrap();
        let result = read_task.await.unwrap();
        assert_eq!(result.len(), 8195);
        assert!(result[0..4096].iter().all(|&x| x == 65));
        assert!(result[4096..8192].iter().all(|&x| x == 66));
        assert_eq!(&result[8192..8195], b"end");
    }
}