use crate::error::handle_accept_error;
use crate::Error as TransportError;
use std::ops::ControlFlow;
use std::pin::pin;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_stream::{Stream, StreamExt};
/// Adapts a stream of accepted connections into a stream of transport results,
/// dropping accept errors that [`handle_accept_error`] says to continue past.
///
/// Every `Ok(io)` from `incoming` is forwarded unchanged. Every `Err` is converted
/// into [`TransportError`] and handed to [`handle_accept_error`]:
///
/// - `ControlFlow::Continue(())`: the error is discarded and the next item is polled.
/// - `ControlFlow::Break(e)`: `e` is yielded once and the stream then ends. No further
///   items are pulled from `incoming`.
///
/// # Errors
///
/// The returned stream yields at most one `Err`, and when it does, that `Err` is its
/// final item.
#[inline]
pub fn serve_tcp_incoming<IO, IE>(
    incoming: impl Stream<Item = Result<IO, IE>> + Send + 'static,
) -> impl Stream<Item = Result<IO, TransportError>>
where
    IO: AsyncRead + AsyncWrite + Unpin + Send + 'static,
    IE: Into<TransportError> + Send + 'static,
{
    async_stream::stream! {
        let mut incoming = pin!(incoming);
        while let Some(item) = incoming.next().await {
            match item {
                Ok(io) => yield Ok(io),
                Err(e) => match handle_accept_error(e.into()) {
                    ControlFlow::Continue(()) => continue,
                    ControlFlow::Break(e) => {
                        // A fatal accept error terminates the stream: report it once and
                        // stop polling, instead of carrying on with a broken listener.
                        yield Err(e);
                        break;
                    }
                },
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream::iter;
    use std::io;
    use std::pin::Pin;
    use std::task::{Context, Poll};
    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
    use tokio_stream::Stream;

    /// In-memory stand-in for an accepted connection.
    ///
    /// Each `*_result` is returned by the first matching poll. Later polls fall back
    /// to `Ok(())` for reads and `Ok(0)` for writes. All fields are `Unpin`, so the
    /// type is automatically `Unpin` as `serve_tcp_incoming` requires; no manual impl
    /// is needed.
    struct MockIO {
        read_result: Option<io::Result<()>>,
        write_result: Option<io::Result<usize>>,
    }

    impl MockIO {
        /// A mock whose first read succeeds and whose first write reports `written` bytes.
        fn ok(written: usize) -> Self {
            Self {
                read_result: Some(Ok(())),
                write_result: Some(Ok(written)),
            }
        }
    }

    impl AsyncRead for MockIO {
        fn poll_read(
            mut self: Pin<&mut Self>,
            _: &mut Context<'_>,
            _: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            Poll::Ready(self.read_result.take().unwrap_or(Ok(())))
        }
    }

    impl AsyncWrite for MockIO {
        fn poll_write(
            mut self: Pin<&mut Self>,
            _: &mut Context<'_>,
            _: &[u8],
        ) -> Poll<io::Result<usize>> {
            Poll::Ready(self.write_result.take().unwrap_or(Ok(0)))
        }

        fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    /// Wraps a fixed list of accept results in a stream, as a listener would produce.
    fn mock_incoming(
        results: Vec<Result<MockIO, io::Error>>,
    ) -> impl Stream<Item = Result<MockIO, io::Error>> {
        iter(results)
    }

    // `handle_accept_error` is synchronous, so these two tests need no async runtime.
    #[test]
    fn test_handle_accept_error_non_fatal() {
        let non_fatal_kinds = [
            io::ErrorKind::ConnectionAborted,
            io::ErrorKind::Interrupted,
            io::ErrorKind::InvalidData,
            io::ErrorKind::WouldBlock,
        ];
        for kind in non_fatal_kinds {
            let error = io::Error::new(kind, "Test non-fatal error");
            assert!(
                matches!(handle_accept_error(error), ControlFlow::Continue(())),
                "{kind:?} should be treated as non-fatal"
            );
        }
    }

    #[test]
    fn test_handle_accept_error_fatal() {
        let fatal_error = io::Error::new(io::ErrorKind::PermissionDenied, "Test fatal error");
        assert!(matches!(
            handle_accept_error(fatal_error),
            ControlFlow::Break(_)
        ));
    }

    #[tokio::test]
    async fn test_serve_tcp_incoming_success() {
        let incoming = mock_incoming(vec![Ok(MockIO::ok(5))]);
        let mut stream = Box::pin(serve_tcp_incoming(incoming));
        let result = stream.next().await.expect("Stream ended unexpectedly");
        assert!(result.is_ok());
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_serve_tcp_incoming_non_fatal_error() {
        let non_fatal_error = io::Error::new(io::ErrorKind::WouldBlock, "Would block");
        let incoming = mock_incoming(vec![Err(non_fatal_error), Ok(MockIO::ok(5))]);
        let mut stream = Box::pin(serve_tcp_incoming(incoming));
        // The non-fatal error is swallowed, so the first item seen is the connection.
        let result = stream.next().await.expect("Stream ended unexpectedly");
        assert!(result.is_ok());
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_serve_tcp_incoming_fatal_error() {
        let fatal_error = io::Error::new(io::ErrorKind::PermissionDenied, "Permission denied");
        let incoming = mock_incoming(vec![Err(fatal_error)]);
        let mut stream = Box::pin(serve_tcp_incoming(incoming));
        let result = stream.next().await.expect("Stream ended unexpectedly");
        assert!(result.is_err());
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_serve_tcp_incoming_stops_after_fatal_error() {
        let incoming = mock_incoming(vec![
            Err(io::Error::new(
                io::ErrorKind::PermissionDenied,
                "Permission denied",
            )),
            Ok(MockIO::ok(5)),
        ]);
        let mut stream = Box::pin(serve_tcp_incoming(incoming));
        assert!(stream.next().await.unwrap().is_err());
        // A connection queued behind a fatal error must not be yielded.
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_serve_tcp_incoming_mixed_results() {
        let incoming = mock_incoming(vec![
            Ok(MockIO::ok(5)),
            Err(io::Error::new(io::ErrorKind::WouldBlock, "Would block")),
            Ok(MockIO::ok(3)),
            Err(io::Error::new(
                io::ErrorKind::PermissionDenied,
                "Permission denied",
            )),
        ]);
        let mut stream = Box::pin(serve_tcp_incoming(incoming));
        assert!(stream.next().await.unwrap().is_ok());
        assert!(stream.next().await.unwrap().is_ok());
        assert!(stream.next().await.unwrap().is_err());
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_serve_tcp_incoming_empty_stream() {
        let incoming = mock_incoming(vec![]);
        let mut stream = Box::pin(serve_tcp_incoming(incoming));
        assert!(stream.next().await.is_none());
    }
}