use crate::stream::{BoxStream, Flow, NotUsed, Sink, Source, StreamCompletion};
use crate::{StreamError, StreamResult};
use futures::{FutureExt, channel::oneshot};
use std::future::Future;
use std::net::SocketAddr;
use std::panic::AssertUnwindSafe;
use std::path::PathBuf;
use std::sync::{
Arc, Mutex,
atomic::{AtomicBool, Ordering},
mpsc as std_mpsc,
};
use std::thread::{self, Thread};
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
use tokio::sync::{mpsc, watch};
/// Default chunk size in bytes used by the `*_default` constructors.
const DEFAULT_CHUNK_SIZE: usize = 8192;
/// Capacity, in chunks, of the producer-to-consumer channel for file sources.
const FILE_READ_AHEAD_CHUNKS: usize = 8;
/// Size of the scratch buffer for each file read; every read is re-sliced into
/// `chunk_size`-byte items before being emitted.
const FILE_INTERNAL_READ_SIZE: usize = 256 * 1024;
/// Capacity, in chunks, of the producer-to-consumer channel for TCP sources.
const TCP_READ_AHEAD_CHUNKS: usize = 1;
/// Park timeout for a waiting reader, and the poll period of blocking
/// `recv_timeout` loops on reply/completion channels.
const PARK_INTERVAL: Duration = Duration::from_millis(1);
/// Number of `yield_now` spins a waiting reader performs before it starts parking.
const READ_READY_SPINS: usize = 256;
/// Number of `yield_now` spins a blocked sender performs before it starts parking.
const BACKPRESSURE_READY_SPINS: usize = 64;
/// Park timeout used by a blocked sender once its spin budget is exhausted.
const BACKPRESSURE_PARK: Duration = Duration::from_micros(10);
/// Lets an async producer wake the synchronous consumer thread that is waiting
/// for data.
///
/// The consumer registers itself with `capture_current` before it polls. The
/// producer calls `unpark` after each successful send, so a parked consumer does
/// not have to wait out its full park timeout.
#[derive(Clone)]
struct ConsumerWaker {
    /// The most recently registered consumer thread, if any.
    thread: Arc<Mutex<Option<Thread>>>,
}
impl ConsumerWaker {
    /// Creates a waker with no registered consumer; `unpark` is a no-op until
    /// `capture_current` is called.
    fn new() -> Self {
        Self {
            thread: Arc::new(Mutex::new(None)),
        }
    }
    /// Registers the calling thread as the one `unpark` should wake.
    ///
    /// If the consumer has moved to a different thread since the last call, the
    /// registration is refreshed. Otherwise wake-ups would keep targeting a stale
    /// thread, and the real consumer would only notice new data when its park
    /// timeout expires.
    ///
    /// # Panics
    /// Panics if the internal mutex is poisoned.
    fn capture_current(&self) {
        let current = thread::current();
        let mut slot = self.thread.lock().expect("consumer waker poisoned");
        if slot.as_ref().map(Thread::id) != Some(current.id()) {
            *slot = Some(current);
        }
    }
    /// Wakes the registered consumer thread, if one has been captured.
    ///
    /// # Panics
    /// Panics if the internal mutex is poisoned.
    fn unpark(&self) {
        let slot = self.thread.lock().expect("consumer waker poisoned");
        if let Some(t) = slot.as_ref() {
            t.unpark();
        }
    }
}
/// Maps an I/O error onto the generic stream failure, keeping only its display text.
fn io_error(error: std::io::Error) -> StreamError {
    let message = format!("{error}");
    StreamError::Failed(message)
}
/// Failure reported when an `AsyncWrite::write` call makes no progress (returns `Ok(0)`).
fn write_zero_error() -> StreamError {
    const MESSAGE: &str = "async writer returned zero bytes";
    StreamError::Failed(String::from(MESSAGE))
}
/// Outcome of a byte-oriented I/O stage: how many bytes were transferred and
/// whether the stage finished cleanly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IoResult {
    /// Number of bytes transferred before the stage terminated.
    pub bytes: u64,
    /// `Ok(())` on clean completion, otherwise the error that ended the stage.
    pub status: StreamResult<()>,
}
impl IoResult {
    /// Result for a stage that transferred `bytes` bytes and completed cleanly.
    #[must_use]
    pub fn succeeded(bytes: u64) -> Self {
        Self::with_status(bytes, Ok(()))
    }
    /// Result for a stage that transferred `bytes` bytes and then failed with `error`.
    #[must_use]
    pub fn failed(bytes: u64, error: StreamError) -> Self {
        Self::with_status(bytes, Err(error))
    }
    /// Shared constructor behind `succeeded` and `failed`.
    fn with_status(bytes: u64, status: StreamResult<()>) -> Self {
        Self { bytes, status }
    }
    /// Number of bytes transferred.
    #[must_use]
    pub fn bytes(&self) -> u64 {
        let Self { bytes, .. } = self;
        *bytes
    }
    /// A copy of the terminal status.
    pub fn status(&self) -> StreamResult<()> {
        let Self { status, .. } = self;
        status.clone()
    }
    /// `true` when the stage completed without error.
    #[must_use]
    pub fn is_success(&self) -> bool {
        matches!(self.status, Ok(_))
    }
}
/// Byte source whose materialized value completes with the read `IoResult`.
pub type TokioByteSource = Source<Vec<u8>, StreamCompletion<IoResult>>;
/// Byte sink whose materialized value completes with the write `IoResult`.
pub type TokioByteSink = Sink<Vec<u8>, StreamCompletion<IoResult>>;
/// Terminal outcome recorded by a background task. A consumer that finds the
/// channel closed reads this to decide what to report.
#[derive(Clone)]
enum DemandTerminal {
    /// The task finished normally.
    Complete,
    /// The task ended with this error (including `Cancelled` when asked to stop).
    Error(StreamError),
}
/// Message sent from a background task to the blocking consumer.
enum DemandResponse<T> {
    /// One element of the stream.
    Item(T),
    /// End of stream.
    Complete,
    /// The stream failed with this error.
    Error(StreamError),
}
struct DemandSourceStream<T> {
demands: mpsc::Sender<std_mpsc::Sender<DemandResponse<T>>>,
cancel: watch::Sender<bool>,
terminal: Arc<Mutex<Option<DemandTerminal>>>,
done: bool,
}
impl<T> DemandSourceStream<T> {
fn terminal_response(&self) -> Option<Option<StreamResult<T>>> {
self.terminal
.lock()
.expect("tokio source terminal poisoned")
.clone()
.map(|terminal| match terminal {
DemandTerminal::Complete => None,
DemandTerminal::Error(error) => Some(Err(error)),
})
}
fn mark_done(&mut self) {
self.done = true;
let _ = self.cancel.send(true);
}
}
impl<T: Send + 'static> Iterator for DemandSourceStream<T> {
    type Item = StreamResult<T>;
    /// Sends one demand (carrying a fresh reply channel) to the background task and
    /// blocks until it answers.
    ///
    /// The reply channel is polled every `PARK_INTERVAL` so that stream cancellation
    /// is noticed. Cancellation is reported as `Cancelled`, whether it is observed
    /// while enqueuing the demand or while waiting for the reply. When the task goes
    /// away without replying, its recorded terminal is reported, or
    /// `AbruptTermination` if it recorded none.
    fn next(&mut self) -> Option<Self::Item> {
        if self.done {
            return None;
        }
        let stream_cancelled = crate::stream::current_stream_cancelled();
        let is_cancelled = || {
            stream_cancelled
                .as_ref()
                .is_some_and(|cancelled| cancelled.load(Ordering::SeqCst))
        };
        let (reply_sender, reply_receiver) = std_mpsc::channel();
        if !send_bounded_demand(&self.demands, reply_sender, &stream_cancelled) {
            self.mark_done();
            // `send_bounded_demand` also gives up on cancellation. In that case the
            // task has typically not recorded a terminal yet, so report cancellation
            // explicitly instead of falling through to `AbruptTermination`.
            if is_cancelled() {
                return Some(Err(StreamError::Cancelled));
            }
            return self
                .terminal_response()
                .unwrap_or(Some(Err(StreamError::AbruptTermination)));
        }
        loop {
            if is_cancelled() {
                self.mark_done();
                return Some(Err(StreamError::Cancelled));
            }
            match reply_receiver.recv_timeout(PARK_INTERVAL) {
                Ok(DemandResponse::Item(item)) => return Some(Ok(item)),
                Ok(DemandResponse::Complete) => {
                    self.mark_done();
                    return None;
                }
                Ok(DemandResponse::Error(error)) => {
                    self.mark_done();
                    return Some(Err(error));
                }
                // Poll again so cancellation is re-checked.
                Err(std_mpsc::RecvTimeoutError::Timeout) => {}
                // The task dropped the reply channel without answering.
                Err(std_mpsc::RecvTimeoutError::Disconnected) => {
                    self.mark_done();
                    return self
                        .terminal_response()
                        .unwrap_or(Some(Err(StreamError::AbruptTermination)));
                }
            }
        }
    }
}
impl<T> Drop for DemandSourceStream<T> {
    /// Ensures the background task is told to stop when the iterator is discarded early.
    fn drop(&mut self) {
        let _ = self.cancel.send(true);
    }
}
struct BoundedByteSourceStream {
receiver: mpsc::Receiver<DemandResponse<Vec<u8>>>,
cancel: watch::Sender<bool>,
terminal: Arc<Mutex<Option<DemandTerminal>>>,
done: bool,
waker: ConsumerWaker,
}
impl BoundedByteSourceStream {
fn terminal_response(&self) -> Option<Option<StreamResult<Vec<u8>>>> {
self.terminal
.lock()
.expect("tokio source terminal poisoned")
.clone()
.map(|terminal| match terminal {
DemandTerminal::Complete => None,
DemandTerminal::Error(error) => Some(Err(error)),
})
}
fn mark_done(&mut self) {
self.done = true;
let _ = self.cancel.send(true);
}
}
impl Iterator for BoundedByteSourceStream {
type Item = StreamResult<Vec<u8>>;
fn next(&mut self) -> Option<Self::Item> {
if self.done {
return None;
}
self.waker.capture_current();
let stream_cancelled = crate::stream::current_stream_cancelled();
let mut spins = 0usize;
loop {
if stream_cancelled
.as_ref()
.is_some_and(|cancelled| cancelled.load(Ordering::SeqCst))
{
self.mark_done();
return Some(Err(StreamError::Cancelled));
}
match self.receiver.try_recv() {
Ok(DemandResponse::Item(item)) => return Some(Ok(item)),
Ok(DemandResponse::Complete) => {
self.mark_done();
return None;
}
Ok(DemandResponse::Error(error)) => {
self.mark_done();
return Some(Err(error));
}
Err(mpsc::error::TryRecvError::Empty) => read_wait(&mut spins),
Err(mpsc::error::TryRecvError::Disconnected) => {
self.mark_done();
return self
.terminal_response()
.unwrap_or(Some(Err(StreamError::AbruptTermination)));
}
}
}
}
}
impl Drop for BoundedByteSourceStream {
fn drop(&mut self) {
let _ = self.cancel.send(true);
}
}
fn send_bounded_demand<T>(
sender: &mpsc::Sender<T>,
mut message: T,
stream_cancelled: &Option<Arc<AtomicBool>>,
) -> bool {
let mut spins = 0usize;
loop {
if stream_cancelled
.as_ref()
.is_some_and(|cancelled| cancelled.load(Ordering::SeqCst))
{
return false;
}
match sender.try_send(message) {
Ok(()) => return true,
Err(mpsc::error::TrySendError::Full(returned)) => {
message = returned;
backpressure_wait(&mut spins);
}
Err(mpsc::error::TrySendError::Closed(_)) => return false,
}
}
}
/// Records `value` as the terminal outcome unless one is already recorded (first writer wins).
fn finish_terminal(terminal: &Arc<Mutex<Option<DemandTerminal>>>, value: DemandTerminal) {
    let mut recorded = terminal.lock().expect("tokio source terminal poisoned");
    recorded.get_or_insert(value);
}
/// Waits for the next demand.
///
/// Returns `None` if cancellation is already requested or arrives while waiting
/// (including the cancel sender being dropped), or if the demand channel closes.
async fn next_demand<T>(
    demands: &mut mpsc::Receiver<std_mpsc::Sender<DemandResponse<T>>>,
    cancel: &mut watch::Receiver<bool>,
) -> Option<std_mpsc::Sender<DemandResponse<T>>> {
    let already_cancelled = *cancel.borrow();
    if already_cancelled {
        return None;
    }
    tokio::select! {
        demand = demands.recv() => demand,
        _ = cancel.changed() => None,
    }
}
/// Builds a blocking byte stream fed by an `AsyncRead` that runs on the shared tokio runtime.
///
/// `open` is invoked on the runtime to obtain the reader. Data is read into a buffer of
/// `max(internal_read_size, chunk_size)` bytes and re-sliced into `chunk_size`-byte items.
/// At most `read_ahead_chunks` items are queued ahead of the consumer. The returned
/// completion receives the read task's result: bytes read plus final status, or
/// `AbruptTermination` if the task panicked.
///
/// # Panics
/// Panics if `chunk_size` or `read_ahead_chunks` is zero.
fn async_read_source<R, Fut>(
    open: impl FnOnce() -> Fut + Send + 'static,
    chunk_size: usize,
    internal_read_size: usize,
    read_ahead_chunks: usize,
) -> (BoxStream<Vec<u8>>, StreamCompletion<IoResult>)
where
    R: AsyncRead + Unpin + Send + 'static,
    Fut: Future<Output = std::io::Result<R>> + Send + 'static,
{
    assert!(chunk_size > 0, "chunk size must be greater than zero");
    assert!(
        read_ahead_chunks > 0,
        "read-ahead bound must be greater than zero"
    );
    // Never read less than one chunk's worth per call.
    let internal_read_size = internal_read_size.max(chunk_size);
    // Bounded: the producer's send waits once `read_ahead_chunks` items are queued.
    let (item_sender, item_receiver) = mpsc::channel(read_ahead_chunks);
    // Set to `true` by the consumer (done, cancelled, or dropped) to stop the task.
    let (cancel_sender, cancel_receiver) = watch::channel(false);
    let (mat_sender, mat_receiver) = oneshot::channel();
    // Shared terminal slot, so the consumer can recover the outcome after the channel closes.
    let terminal = Arc::new(Mutex::new(None));
    let terminal_for_task = Arc::clone(&terminal);
    // The producer unparks the consumer thread after each successful send.
    let waker = ConsumerWaker::new();
    let producer_waker = waker.clone();
    crate::stream::stream_tokio_runtime().spawn(async move {
        // A panic in the read task becomes an `AbruptTermination` outcome, for both
        // the consumer (via `terminal`) and the completion (via `mat_sender`).
        let result = AssertUnwindSafe(run_async_read_task(
            open(),
            chunk_size,
            internal_read_size,
            item_sender,
            cancel_receiver,
            Arc::clone(&terminal_for_task),
            producer_waker,
        ))
        .catch_unwind()
        .await
        .unwrap_or_else(|_| {
            finish_terminal(
                &terminal_for_task,
                DemandTerminal::Error(StreamError::AbruptTermination),
            );
            Err(StreamError::AbruptTermination)
        });
        // The completion handle may already be dropped; ignore that.
        let _ = mat_sender.send(result);
    });
    (
        Box::new(BoundedByteSourceStream {
            receiver: item_receiver,
            cancel: cancel_sender,
            terminal,
            done: false,
            waker,
        }) as BoxStream<Vec<u8>>,
        StreamCompletion::from_receiver(mat_receiver, None),
    )
}
/// Body of the read task spawned by `async_read_source`.
///
/// Opens the reader, then loops: read into `buffer`, re-chunk through
/// `send_read_chunks`, and on EOF flush the partial tail and send `Complete`. Every
/// exit path records a terminal in `terminal` before (best-effort) notifying the
/// consumer. A consumer that only sees the channel close can therefore still report
/// the right outcome. Always returns `Ok`; failures are carried in the `IoResult`
/// status together with the number of bytes read so far.
async fn run_async_read_task<R, Fut>(
    open: Fut,
    chunk_size: usize,
    internal_read_size: usize,
    items: mpsc::Sender<DemandResponse<Vec<u8>>>,
    mut cancel: watch::Receiver<bool>,
    terminal: Arc<Mutex<Option<DemandTerminal>>>,
    waker: ConsumerWaker,
) -> StreamResult<IoResult>
where
    R: AsyncRead + Unpin + Send + 'static,
    Fut: Future<Output = std::io::Result<R>> + Send + 'static,
{
    let mut bytes = 0_u64;
    let mut reader = tokio::select! {
        reader = open => match reader {
            Ok(reader) => reader,
            Err(error) => {
                let error = io_error(error);
                finish_terminal(&terminal, DemandTerminal::Error(error.clone()));
                let _ = send_read_item(&items, DemandResponse::Error(error.clone()), &mut cancel, &waker).await;
                return Ok(IoResult::failed(bytes, error));
            }
        },
        // Cancelled (or cancel sender dropped) while opening. No item is sent; the
        // consumer reads `terminal` once the channel closes.
        changed = cancel.changed() => {
            let _ = changed;
            finish_terminal(&terminal, DemandTerminal::Error(StreamError::Cancelled));
            return Ok(IoResult::failed(bytes, StreamError::Cancelled));
        }
    };
    let mut buffer = vec![0_u8; internal_read_size];
    // Bytes from earlier reads that do not yet fill a whole chunk.
    let mut pending_tail = Vec::with_capacity(chunk_size);
    loop {
        let read = tokio::select! {
            read = reader.read(&mut buffer) => read,
            changed = cancel.changed() => {
                let _ = changed;
                finish_terminal(&terminal, DemandTerminal::Error(StreamError::Cancelled));
                return Ok(IoResult::failed(bytes, StreamError::Cancelled));
            }
        };
        match read {
            // EOF: emit the short final chunk (if any), then signal completion.
            Ok(0) => {
                if !pending_tail.is_empty()
                    && !send_read_item(
                        &items,
                        DemandResponse::Item(std::mem::take(&mut pending_tail)),
                        &mut cancel,
                        &waker,
                    )
                    .await
                {
                    finish_terminal(&terminal, DemandTerminal::Error(StreamError::Cancelled));
                    return Ok(IoResult::failed(bytes, StreamError::Cancelled));
                }
                finish_terminal(&terminal, DemandTerminal::Complete);
                let _ = send_read_item(&items, DemandResponse::Complete, &mut cancel, &waker).await;
                return Ok(IoResult::succeeded(bytes));
            }
            Ok(read) => {
                // Counts bytes read from the reader, including any still held in `pending_tail`.
                bytes += read as u64;
                if !send_read_chunks(
                    &items,
                    chunk_size,
                    &mut pending_tail,
                    &buffer[..read],
                    &mut cancel,
                    &waker,
                )
                .await
                {
                    finish_terminal(&terminal, DemandTerminal::Error(StreamError::Cancelled));
                    return Ok(IoResult::failed(bytes, StreamError::Cancelled));
                }
            }
            Err(error) => {
                let error = io_error(error);
                finish_terminal(&terminal, DemandTerminal::Error(error.clone()));
                let _ = send_read_item(
                    &items,
                    DemandResponse::Error(error.clone()),
                    &mut cancel,
                    &waker,
                )
                .await;
                return Ok(IoResult::failed(bytes, error));
            }
        }
    }
}
/// Re-slices one raw read into `chunk_size`-byte items and sends them downstream.
///
/// Bytes left over from earlier reads (`pending_tail`) are topped up first and sent
/// as soon as they form a full chunk. Whole chunks of `read_buffer` follow, and any
/// remainder shorter than `chunk_size` is appended to `pending_tail` for the next
/// call. Returns `false` as soon as a send is aborted (cancellation, or the consumer
/// is gone).
///
/// `chunk_size` is non-zero: `async_read_source` asserts it before spawning the task.
async fn send_read_chunks(
    sender: &mpsc::Sender<DemandResponse<Vec<u8>>>,
    chunk_size: usize,
    pending_tail: &mut Vec<u8>,
    read_buffer: &[u8],
    cancel: &mut watch::Receiver<bool>,
    waker: &ConsumerWaker,
) -> bool {
    let mut rest = read_buffer;
    if !pending_tail.is_empty() {
        let fill = (chunk_size - pending_tail.len()).min(rest.len());
        let (head, remaining) = rest.split_at(fill);
        pending_tail.extend_from_slice(head);
        rest = remaining;
        if pending_tail.len() == chunk_size {
            let completed = std::mem::take(pending_tail);
            if !send_read_item(sender, DemandResponse::Item(completed), cancel, waker).await {
                return false;
            }
        }
    }
    let mut whole_chunks = rest.chunks_exact(chunk_size);
    for chunk in whole_chunks.by_ref() {
        if !send_read_item(sender, DemandResponse::Item(chunk.to_vec()), cancel, waker).await {
            return false;
        }
    }
    pending_tail.extend_from_slice(whole_chunks.remainder());
    true
}
/// Sends one message to the consumer, waiting for channel capacity if needed.
///
/// Returns `false` if cancellation is signalled first (including the cancel sender
/// being dropped) or if the receiver is gone. On success, the consumer thread is
/// unparked so it sees the item promptly.
async fn send_read_item<T>(
    sender: &mpsc::Sender<DemandResponse<T>>,
    item: DemandResponse<T>,
    cancel: &mut watch::Receiver<bool>,
    waker: &ConsumerWaker,
) -> bool
where
    T: Send + 'static,
{
    let sent = tokio::select! {
        outcome = sender.send(item) => outcome.is_ok(),
        _ = cancel.changed() => return false,
    };
    if sent {
        waker.unpark();
    }
    sent
}
/// Instruction sent from the blocking feeder to the async write task.
enum WriteCommand {
    /// Write these bytes in full.
    Chunk(Vec<u8>),
    /// Upstream finished with this status; flush and shut down the writer.
    Finish(StreamResult<()>),
}
/// Drop guard that signals cancellation to a tokio task unless it is disarmed first.
///
/// Used so that an unwinding feeder still tells the write task to stop.
struct TokioCancelGuard {
    /// Channel used to signal cancellation.
    cancel: watch::Sender<bool>,
    /// While `true`, dropping the guard sends `true` on `cancel`.
    armed: bool,
}
impl TokioCancelGuard {
    /// Creates an armed guard.
    fn new(cancel: watch::Sender<bool>) -> Self {
        Self {
            armed: true,
            cancel,
        }
    }
    /// Prevents the guard from signalling cancellation on drop.
    fn disarm(&mut self) {
        self.armed = false;
    }
}
impl Drop for TokioCancelGuard {
    fn drop(&mut self) {
        if !self.armed {
            return;
        }
        // No receivers left is fine: there is nobody to cancel.
        let _ = self.cancel.send(true);
    }
}
/// Builds a byte sink that writes to an `AsyncWrite` on the shared tokio runtime.
///
/// Each time the runner closure runs (presumably once per materialization), `open` is
/// called on the runtime to obtain a writer. The blocking stream side
/// (`feed_async_writer`) forwards chunks over a one-slot command channel. The write
/// task's final result is handed back over `done_sender`, and a panicking task is
/// reported as `AbruptTermination`.
fn async_write_sink<W, F, Fut>(open: F) -> TokioByteSink
where
    W: AsyncWrite + Unpin + Send + 'static,
    F: Fn() -> Fut + Send + Sync + 'static,
    Fut: Future<Output = std::io::Result<W>> + Send + 'static,
{
    let open = Arc::new(open);
    Sink::from_runner(move |input, materializer| {
        // Capacity 1: the feeder blocks until the task has taken the previous command.
        let (command_sender, command_receiver) = mpsc::channel(1);
        let (cancel_sender, cancel_receiver) = watch::channel(false);
        let (done_sender, done_receiver) = std_mpsc::sync_channel(1);
        let open = Arc::clone(&open);
        crate::stream::stream_tokio_runtime().spawn(async move {
            let result = AssertUnwindSafe(run_async_write_task(
                open(),
                command_receiver,
                cancel_receiver,
            ))
            .catch_unwind()
            .await
            .unwrap_or(Err(StreamError::AbruptTermination));
            let _ = done_sender.send(result);
        });
        Ok(materializer.spawn_stream(move |cancelled| {
            // If `feed_async_writer` unwinds, the still-armed guard signals cancellation
            // to the write task; it is disarmed only after a normal return.
            let mut guard = TokioCancelGuard::new(cancel_sender.clone());
            let result = feed_async_writer(
                input,
                command_sender,
                done_receiver,
                cancelled,
                cancel_sender,
            );
            guard.disarm();
            result
        }))
    })
}
/// Body of the write task spawned by `async_write_sink`.
///
/// Opens the writer, then applies commands until one of the following happens:
/// - `Finish`: flush and shut down, combining the upstream status with the shutdown status.
/// - The command channel closes: best-effort shutdown, reported as `Cancelled`.
/// - Cancellation is signalled.
///
/// Always returns `Ok`; failures are carried in the `IoResult` status together with
/// the number of bytes written so far.
async fn run_async_write_task<W, Fut>(
    open: Fut,
    mut commands: mpsc::Receiver<WriteCommand>,
    mut cancel: watch::Receiver<bool>,
) -> StreamResult<IoResult>
where
    W: AsyncWrite + Unpin + Send + 'static,
    Fut: Future<Output = std::io::Result<W>> + Send + 'static,
{
    let mut bytes = 0_u64;
    let mut writer = tokio::select! {
        writer = open => match writer {
            Ok(writer) => writer,
            Err(error) => return Ok(IoResult::failed(bytes, io_error(error))),
        },
        changed = cancel.changed() => {
            let _ = changed;
            return Ok(IoResult::failed(bytes, StreamError::Cancelled));
        }
    };
    loop {
        let command = tokio::select! {
            command = commands.recv() => command,
            changed = cancel.changed() => {
                let _ = changed;
                return Ok(IoResult::failed(bytes, StreamError::Cancelled));
            }
        };
        match command {
            Some(WriteCommand::Chunk(chunk)) => {
                if let Err(error) = write_chunk(&mut writer, &chunk, &mut cancel, &mut bytes).await
                {
                    return Ok(IoResult::failed(bytes, error));
                }
            }
            Some(WriteCommand::Finish(upstream_status)) => {
                // An upstream error takes precedence over the shutdown outcome.
                let shutdown_status = shutdown_writer(&mut writer, &mut cancel).await;
                return Ok(IoResult {
                    bytes,
                    status: upstream_status.and(shutdown_status),
                });
            }
            // Feeder dropped the sender without `Finish`: close the writer best-effort.
            None => {
                let _ = shutdown_writer(&mut writer, &mut cancel).await;
                return Ok(IoResult::failed(bytes, StreamError::Cancelled));
            }
        }
    }
}
/// Writes all of `chunk`, adding each partial write to `*bytes` as it happens.
///
/// Fails with `Cancelled` if cancellation is signalled mid-write, with the I/O error
/// if a write fails, or with `write_zero_error` if the writer makes no progress.
async fn write_chunk<W>(
    writer: &mut W,
    chunk: &[u8],
    cancel: &mut watch::Receiver<bool>,
    bytes: &mut u64,
) -> StreamResult<()>
where
    W: AsyncWrite + Unpin,
{
    let mut offset = 0usize;
    while let Some(unwritten) = chunk.get(offset..).filter(|rest| !rest.is_empty()) {
        let written = tokio::select! {
            outcome = writer.write(unwritten) => outcome.map_err(io_error)?,
            _ = cancel.changed() => return Err(StreamError::Cancelled),
        };
        if written == 0 {
            return Err(write_zero_error());
        }
        offset += written;
        *bytes += written as u64;
    }
    Ok(())
}
async fn shutdown_writer<W>(writer: &mut W, cancel: &mut watch::Receiver<bool>) -> StreamResult<()>
where
W: AsyncWrite + Unpin,
{
tokio::select! {
result = writer.flush() => result.map_err(io_error)?,
changed = cancel.changed() => {
let _ = changed;
return Err(StreamError::Cancelled);
}
}
tokio::select! {
result = writer.shutdown() => result.map_err(io_error),
changed = cancel.changed() => {
let _ = changed;
Err(StreamError::Cancelled)
}
}
}
/// Blocking half of `async_write_sink`: drains `input` into the write task and waits
/// for the task's result.
///
/// Chunks are forwarded until the input ends or fails, the task stops accepting
/// commands, or `cancelled` is set. Unless cancelled, a `Finish` carrying the upstream
/// outcome is sent afterwards. `done_receiver` is then polled every `PARK_INTERVAL`,
/// and the cancel signal is re-sent if cancellation arrives while waiting. Returns
/// the task's result, or `AbruptTermination` if the task went away without reporting.
fn feed_async_writer(
    mut input: BoxStream<Vec<u8>>,
    command_sender: mpsc::Sender<WriteCommand>,
    done_receiver: std_mpsc::Receiver<StreamResult<IoResult>>,
    cancelled: Arc<AtomicBool>,
    cancel_sender: watch::Sender<bool>,
) -> StreamResult<IoResult> {
    let mut terminal = Ok(());
    loop {
        if cancelled.load(Ordering::SeqCst) {
            // NOTE(review): this value appears unused — the cancelled branch below
            // signals via `cancel_sender` instead of sending `Finish`.
            terminal = Err(StreamError::Cancelled);
            break;
        }
        match input.next() {
            Some(Ok(chunk)) => {
                // `false` means the task closed its receiver or we were cancelled.
                if !send_write_command(&command_sender, WriteCommand::Chunk(chunk), &cancelled) {
                    break;
                }
            }
            Some(Err(error)) => {
                terminal = Err(error);
                break;
            }
            None => break,
        }
    }
    if cancelled.load(Ordering::SeqCst) {
        let _ = cancel_sender.send(true);
    } else {
        let _ = send_write_command(&command_sender, WriteCommand::Finish(terminal), &cancelled);
    }
    // Closing the command channel lets a task waiting in `recv()` observe `None`.
    drop(command_sender);
    loop {
        match done_receiver.recv_timeout(PARK_INTERVAL) {
            Ok(result) => return result,
            Err(std_mpsc::RecvTimeoutError::Timeout) => {
                if cancelled.load(Ordering::SeqCst) {
                    let _ = cancel_sender.send(true);
                }
            }
            Err(std_mpsc::RecvTimeoutError::Disconnected) => {
                return Err(StreamError::AbruptTermination);
            }
        }
    }
}
/// Hands `command` to the write task from a synchronous thread, waiting out backpressure.
///
/// Returns `true` once accepted, or `false` if the task closed its receiver or
/// `cancelled` is set first.
fn send_write_command(
    sender: &mpsc::Sender<WriteCommand>,
    mut command: WriteCommand,
    cancelled: &AtomicBool,
) -> bool {
    let mut spins = 0usize;
    while !cancelled.load(Ordering::SeqCst) {
        match sender.try_send(command) {
            Ok(()) => return true,
            Err(mpsc::error::TrySendError::Closed(_)) => return false,
            Err(mpsc::error::TrySendError::Full(returned)) => {
                command = returned;
                backpressure_wait(&mut spins);
            }
        }
    }
    false
}
/// One backoff step for a sender blocked on a full channel.
///
/// Yields the thread for the first `BACKPRESSURE_READY_SPINS` calls (counted in
/// `*spins`), then parks for `BACKPRESSURE_PARK` on each further call.
fn backpressure_wait(spins: &mut usize) {
    if *spins >= BACKPRESSURE_READY_SPINS {
        thread::park_timeout(BACKPRESSURE_PARK);
        return;
    }
    *spins += 1;
    thread::yield_now();
}
/// One backoff step for a consumer waiting on an empty channel.
///
/// Yields for the first `READ_READY_SPINS` calls (counted in `*spins`), then parks
/// for up to `PARK_INTERVAL`. A producer's `unpark` ends the park early.
fn read_wait(spins: &mut usize) {
    let still_spinning = *spins < READ_READY_SPINS;
    if still_spinning {
        *spins += 1;
        thread::yield_now();
    } else {
        thread::park_timeout(PARK_INTERVAL);
    }
}
/// File-backed byte sources and sinks built on tokio's async file API.
pub struct TokioFileIO;
impl TokioFileIO {
    /// Streams the file at `path` in chunks of exactly `chunk_size` bytes; only the
    /// final chunk may be shorter. The file is opened each time the source is
    /// materialized.
    ///
    /// # Panics
    /// Panics if `chunk_size` is zero.
    #[must_use]
    pub fn from_path(path: impl Into<PathBuf>, chunk_size: usize) -> TokioByteSource {
        assert!(chunk_size > 0, "chunk size must be greater than zero");
        let path: PathBuf = path.into();
        Source::from_materialized_factory(move |_materializer| {
            let file_path = path.clone();
            let open = move || tokio::fs::File::open(file_path);
            Ok(async_read_source(
                open,
                chunk_size,
                FILE_INTERNAL_READ_SIZE,
                FILE_READ_AHEAD_CHUNKS,
            ))
        })
    }
    /// `from_path` with `DEFAULT_CHUNK_SIZE`.
    #[must_use]
    pub fn from_path_default(path: impl Into<PathBuf>) -> TokioByteSource {
        Self::from_path(path, DEFAULT_CHUNK_SIZE)
    }
    /// Writes incoming bytes to `path`, creating the file or truncating an existing one.
    #[must_use]
    pub fn to_path(path: impl Into<PathBuf>) -> TokioByteSink {
        let target: Arc<PathBuf> = Arc::new(path.into());
        async_write_sink(move || {
            let target = Arc::clone(&target);
            async move {
                let mut options = tokio::fs::OpenOptions::new();
                options.write(true).create(true).truncate(true);
                options.open(target.as_ref()).await
            }
        })
    }
}
/// Local and remote addresses of an established TCP connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TcpConnection {
    /// Address of this end of the connection.
    pub local_addr: SocketAddr,
    /// Address of the peer.
    pub remote_addr: SocketAddr,
}
impl TcpConnection {
    /// Address of this end of the connection.
    #[must_use]
    pub fn local_addr(&self) -> SocketAddr {
        let Self { local_addr, .. } = *self;
        local_addr
    }
    /// Address of the peer.
    #[must_use]
    pub fn remote_addr(&self) -> SocketAddr {
        let Self { remote_addr, .. } = *self;
        remote_addr
    }
}
/// Address a TCP listener is bound to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TcpBinding {
    /// Bound local address (with the OS-assigned port when binding to port 0).
    pub local_addr: SocketAddr,
}
impl TcpBinding {
    /// Bound local address.
    #[must_use]
    pub fn local_addr(&self) -> SocketAddr {
        let Self { local_addr } = *self;
        local_addr
    }
}
/// A connection accepted by `TokioTcp::bind`: its addresses plus single-use byte
/// source and sink for its read and write halves.
pub struct TcpIncomingConnection {
    connection: TcpConnection,
    source: TokioByteSource,
    sink: TokioByteSink,
}
impl TcpIncomingConnection {
    /// Local address of the accepted connection.
    #[must_use]
    pub fn local_addr(&self) -> SocketAddr {
        self.connection().local_addr()
    }
    /// Peer address of the accepted connection.
    #[must_use]
    pub fn remote_addr(&self) -> SocketAddr {
        self.connection().remote_addr()
    }
    /// Both addresses of the accepted connection.
    #[must_use]
    pub fn connection(&self) -> TcpConnection {
        self.connection
    }
    /// Splits into the byte source (bytes from the peer) and byte sink (bytes to the peer).
    #[must_use]
    pub fn into_parts(self) -> (TokioByteSource, TokioByteSink) {
        let Self { source, sink, .. } = self;
        (source, sink)
    }
    /// Turns the connection into a flow: input goes to the peer, output is what the peer sends.
    #[must_use]
    pub fn into_flow(self) -> Flow<Vec<u8>, Vec<u8>, NotUsed> {
        let (source, sink) = self.into_parts();
        Flow::from_sink_and_source_coupled(sink, source).map_materialized_value(|_| NotUsed)
    }
}
/// Tokio-backed TCP client and server stream constructors.
pub struct TokioTcp;
impl TokioTcp {
    /// Creates a flow that connects to `addr` when materialized. Input bytes are sent
    /// to the peer, and bytes received from the peer are emitted in `chunk_size`-byte
    /// chunks.
    ///
    /// Connect failures, and failures to query the connected socket's addresses, are
    /// surfaced as stream errors rather than panics.
    ///
    /// # Panics
    /// Panics if `chunk_size` is zero.
    #[must_use]
    pub fn outgoing_connection<A>(
        addr: A,
        chunk_size: usize,
    ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<TcpConnection>>
    where
        A: ToSocketAddrs + Clone + Send + Sync + 'static,
    {
        assert!(chunk_size > 0, "chunk size must be greater than zero");
        Flow::future_flow(move || {
            let addr = addr.clone();
            async move {
                let stream = TcpStream::connect(addr).await.map_err(io_error)?;
                let flow = tcp_flow_from_stream(stream, chunk_size)?;
                Ok(flow)
            }
        })
    }
    /// `outgoing_connection` with `DEFAULT_CHUNK_SIZE`.
    #[must_use]
    pub fn outgoing_connection_default<A>(
        addr: A,
    ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<TcpConnection>>
    where
        A: ToSocketAddrs + Clone + Send + Sync + 'static,
    {
        Self::outgoing_connection(addr, DEFAULT_CHUNK_SIZE)
    }
    /// Creates a source that binds a listener on `addr` when materialized and emits one
    /// `TcpIncomingConnection` per accepted connection. Connections are accepted only
    /// when downstream demands one.
    ///
    /// The materialized value reports the bound address (useful with port 0), or the
    /// bind error.
    ///
    /// # Panics
    /// Panics if `chunk_size` is zero.
    #[must_use]
    pub fn bind<A>(
        addr: A,
        chunk_size: usize,
    ) -> Source<TcpIncomingConnection, StreamCompletion<TcpBinding>>
    where
        A: ToSocketAddrs + Clone + Send + Sync + 'static,
    {
        assert!(chunk_size > 0, "chunk size must be greater than zero");
        Source::from_materialized_factory(move |_materializer| {
            let (demand_sender, demand_receiver) = mpsc::channel(1);
            let (cancel_sender, cancel_receiver) = watch::channel(false);
            let (binding_sender, binding_receiver) = oneshot::channel();
            let terminal = Arc::new(Mutex::new(None));
            let terminal_for_task = Arc::clone(&terminal);
            let addr = addr.clone();
            crate::stream::stream_tokio_runtime().spawn(async move {
                let result = AssertUnwindSafe(run_tcp_bind_task(
                    addr,
                    chunk_size,
                    demand_receiver,
                    cancel_receiver,
                    binding_sender,
                    Arc::clone(&terminal_for_task),
                ))
                .catch_unwind()
                .await;
                // A panicking listener task is reported to the consumer as abrupt termination.
                if result.is_err() {
                    finish_terminal(
                        &terminal_for_task,
                        DemandTerminal::Error(StreamError::AbruptTermination),
                    );
                }
            });
            Ok((
                Box::new(DemandSourceStream {
                    demands: demand_sender,
                    cancel: cancel_sender,
                    terminal,
                    done: false,
                }) as BoxStream<TcpIncomingConnection>,
                StreamCompletion::from_receiver(binding_receiver, None),
            ))
        })
    }
    /// `bind` with `DEFAULT_CHUNK_SIZE`.
    #[must_use]
    pub fn bind_default<A>(addr: A) -> Source<TcpIncomingConnection, StreamCompletion<TcpBinding>>
    where
        A: ToSocketAddrs + Clone + Send + Sync + 'static,
    {
        Self::bind(addr, DEFAULT_CHUNK_SIZE)
    }
}
/// Wraps a connected stream as a flow: the sink writes to the peer, the source reads from it.
///
/// # Errors
/// Fails if the socket's local or peer address cannot be queried. Both are syscalls
/// (`getsockname`/`getpeername`); `peer_addr` in particular fails with "not connected"
/// if the peer resets the connection right after it is established. Reporting this as
/// a stream error avoids panicking inside the runtime task.
fn tcp_flow_from_stream(
    stream: TcpStream,
    chunk_size: usize,
) -> StreamResult<Flow<Vec<u8>, Vec<u8>, TcpConnection>> {
    let connection = TcpConnection {
        local_addr: stream.local_addr().map_err(io_error)?,
        remote_addr: stream.peer_addr().map_err(io_error)?,
    };
    let (read_half, write_half) = stream.into_split();
    let source = single_use_async_read_source(read_half, chunk_size);
    let sink = single_use_async_write_sink(write_half);
    Ok(Flow::from_sink_and_source(sink, source).map_materialized_value(move |_| connection))
}
/// Wraps an already-open reader in a source that can be materialized only once.
///
/// The first materialization takes the reader. Later ones fail their open step with
/// "async reader already materialized". Reads use a `chunk_size` buffer, and the
/// channel queues at most `TCP_READ_AHEAD_CHUNKS` chunks.
///
/// NOTE(review): `send_read_chunks` keeps a partial read in `pending_tail` until a full
/// `chunk_size` chunk accumulates or EOF arrives. On a TCP stream, a request/response
/// protocol whose messages are shorter than `chunk_size` may therefore not see data
/// until the peer closes its side — confirm this is intended.
fn single_use_async_read_source<R>(reader: R, chunk_size: usize) -> TokioByteSource
where
    R: AsyncRead + Unpin + Send + 'static,
{
    let slot = Arc::new(Mutex::new(Some(reader)));
    Source::from_materialized_factory(move |_materializer| {
        let slot = Arc::clone(&slot);
        let open = move || async move {
            let taken = slot
                .lock()
                .expect("single-use async reader poisoned")
                .take();
            taken.ok_or_else(|| std::io::Error::other("async reader already materialized"))
        };
        Ok(async_read_source(
            open,
            chunk_size,
            chunk_size,
            TCP_READ_AHEAD_CHUNKS,
        ))
    })
}
/// Wraps an already-open writer in a sink that can be materialized only once.
///
/// Later materializations fail their open step with "async writer already materialized".
fn single_use_async_write_sink<W>(writer: W) -> TokioByteSink
where
    W: AsyncWrite + Unpin + Send + 'static,
{
    let slot = Arc::new(Mutex::new(Some(writer)));
    async_write_sink(move || {
        let slot = Arc::clone(&slot);
        async move {
            let taken = slot
                .lock()
                .expect("single-use async writer poisoned")
                .take();
            taken.ok_or_else(|| std::io::Error::other("async writer already materialized"))
        }
    })
}
/// Body of the listener task spawned by `TokioTcp::bind`.
///
/// Binds, then reports the `TcpBinding` (or the bind error) through `binding_sender`.
/// After that it accepts one connection per demand received on `demands`, answering
/// on the demand's own reply channel. Transient accept errors are retried. Any other
/// accept error, cancellation, or a consumer that stopped waiting for its reply ends
/// the task after a terminal is recorded.
async fn run_tcp_bind_task<A>(
    addr: A,
    chunk_size: usize,
    mut demands: mpsc::Receiver<std_mpsc::Sender<DemandResponse<TcpIncomingConnection>>>,
    mut cancel: watch::Receiver<bool>,
    binding_sender: oneshot::Sender<StreamResult<TcpBinding>>,
    terminal: Arc<Mutex<Option<DemandTerminal>>>,
) where
    A: ToSocketAddrs + Send + 'static,
{
    let listener = match TcpListener::bind(addr).await {
        Ok(listener) => listener,
        Err(error) => {
            let error = io_error(error);
            finish_terminal(&terminal, DemandTerminal::Error(error.clone()));
            let _ = binding_sender.send(Err(error));
            return;
        }
    };
    let local_addr = match listener.local_addr() {
        Ok(local_addr) => local_addr,
        Err(error) => {
            let error = io_error(error);
            finish_terminal(&terminal, DemandTerminal::Error(error.clone()));
            let _ = binding_sender.send(Err(error));
            return;
        }
    };
    let _ = binding_sender.send(Ok(TcpBinding { local_addr }));
    loop {
        // Accept only after a demand arrives, so no connection is accepted ahead of the consumer.
        let Some(reply) = next_demand(&mut demands, &mut cancel).await else {
            finish_terminal(&terminal, DemandTerminal::Error(StreamError::Cancelled));
            return;
        };
        let (stream, remote_addr) = loop {
            let accepted = tokio::select! {
                accepted = listener.accept() => accepted,
                changed = cancel.changed() => {
                    let _ = changed;
                    finish_terminal(&terminal, DemandTerminal::Error(StreamError::Cancelled));
                    return;
                }
            };
            match accepted {
                Ok(accepted) => break accepted,
                // A client gave up mid-handshake (or a signal interrupted accept): try again.
                Err(error) if is_transient_accept_error(&error) => continue,
                Err(error) => {
                    let error = io_error(error);
                    finish_terminal(&terminal, DemandTerminal::Error(error.clone()));
                    let _ = reply.send(DemandResponse::Error(error));
                    return;
                }
            }
        };
        let incoming = tcp_incoming_connection(stream, remote_addr, local_addr, chunk_size);
        // The reply receiver is gone: the consumer stopped waiting, so shut down.
        if reply.send(DemandResponse::Item(incoming)).is_err() {
            finish_terminal(&terminal, DemandTerminal::Error(StreamError::Cancelled));
            return;
        }
    }
}
fn is_transient_accept_error(error: &std::io::Error) -> bool {
matches!(
error.kind(),
std::io::ErrorKind::Interrupted
| std::io::ErrorKind::ConnectionAborted
| std::io::ErrorKind::ConnectionReset
) || error.raw_os_error().is_some_and(is_transient_accept_errno)
}
/// Raw-errno fallback for `is_transient_accept_error` on Linux.
#[cfg(target_os = "linux")]
fn is_transient_accept_errno(code: i32) -> bool {
    // asm-generic Linux errno values (x86, ARM, ...).
    const EINTR: i32 = 4;
    const ECONNABORTED: i32 = 103;
    const ECONNRESET: i32 = 104;
    [EINTR, ECONNABORTED, ECONNRESET].contains(&code)
}
/// Off Linux there is no raw-errno fallback; classification relies on `ErrorKind` alone.
#[cfg(not(target_os = "linux"))]
fn is_transient_accept_errno(_code: i32) -> bool {
    false
}
/// Wraps an accepted stream as a `TcpIncomingConnection` with single-use source and sink halves.
fn tcp_incoming_connection(
    stream: TcpStream,
    remote_addr: SocketAddr,
    local_addr: SocketAddr,
    chunk_size: usize,
) -> TcpIncomingConnection {
    let (reader, writer) = stream.into_split();
    TcpIncomingConnection {
        connection: TcpConnection {
            local_addr,
            remote_addr,
        },
        source: single_use_async_read_source(reader, chunk_size),
        sink: single_use_async_write_sink(writer),
    }
}
// Tests for the tokio-backed stages. They exercise real temp files, loopback TCP
// sockets and the shared tokio runtime rather than mocks.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Framing, Keep, Sink, Source};
    use std::pin::Pin;
    use std::sync::atomic::{AtomicBool as StdAtomicBool, Ordering as StdOrdering};
    use std::task::{Context, Poll};
    use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
    // Temp-file path made unique by test name, process id and current time.
    fn unique_temp_path(name: &str) -> PathBuf {
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("clock after epoch")
            .as_nanos();
        std::env::temp_dir().join(format!(
            "datum-wp12b-{name}-{}-{nanos}.bin",
            std::process::id()
        ))
    }
    // Polls `condition` every 5 ms until it holds or `timeout` elapses; returns the last result.
    fn wait_until(timeout: Duration, condition: impl Fn() -> bool) -> bool {
        let deadline = Instant::now() + timeout;
        while Instant::now() < deadline {
            if condition() {
                return true;
            }
            thread::sleep(Duration::from_millis(5));
        }
        condition()
    }
    // `AsyncWrite` that never makes progress; records that it was polled and that it was dropped.
    struct PendingWriter {
        polled: Arc<StdAtomicBool>,
        dropped: Arc<StdAtomicBool>,
    }
    impl AsyncWrite for PendingWriter {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &[u8],
        ) -> Poll<std::io::Result<usize>> {
            self.polled.store(true, StdOrdering::SeqCst);
            Poll::Pending
        }
        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            self.polled.store(true, StdOrdering::SeqCst);
            Poll::Pending
        }
        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            self.polled.store(true, StdOrdering::SeqCst);
            Poll::Pending
        }
    }
    impl Drop for PendingWriter {
        fn drop(&mut self) {
            self.dropped.store(true, StdOrdering::SeqCst);
        }
    }
    // Writes two chunks to a file, reads them back with chunk size 2, and checks the byte counts.
    #[test]
    fn tokio_file_io_round_trips_bytes_and_reports_counts() {
        let path = unique_temp_path("roundtrip");
        let write_completion = Source::from_iter([b"ab".to_vec(), b"cd".to_vec()])
            .run_with(TokioFileIO::to_path(path.clone()))
            .expect("tokio file sink materializes");
        let write_result = write_completion.wait().expect("tokio file write completes");
        assert_eq!(write_result.bytes(), 4);
        assert_eq!(write_result.status(), Ok(()));
        let (read_completion, collected) = TokioFileIO::from_path(path.clone(), 2)
            .to_mat(Sink::collect(), Keep::both)
            .run()
            .expect("tokio file source materializes");
        assert_eq!(
            collected.wait().expect("collect completes"),
            vec![b"ab".to_vec(), b"cd".to_vec()]
        );
        let read_result = read_completion.wait().expect("read completion available");
        assert_eq!(read_result.bytes(), 4);
        assert_eq!(read_result.status(), Ok(()));
        std::fs::remove_file(path).expect("remove roundtrip file");
    }
    // Opening a missing file fails both the stream and the IoResult, with zero bytes read.
    #[test]
    fn tokio_file_source_surfaces_open_failure() {
        let missing = unique_temp_path("missing");
        let (read_completion, collected) = TokioFileIO::from_path(missing, 4)
            .to_mat(Sink::collect(), Keep::both)
            .run()
            .expect("tokio file source materializes despite open failure");
        let stream_error = collected.wait().expect_err("collect fails");
        assert!(matches!(stream_error, StreamError::Failed(_)));
        let read_result = read_completion.wait().expect("io result available");
        assert_eq!(read_result.bytes(), 0);
        assert!(matches!(read_result.status(), Err(StreamError::Failed(_))));
    }
    // File chunks that split lines mid-way are reassembled by a downstream delimiter framing stage.
    #[test]
    fn tokio_file_source_composes_with_framing_and_sink() {
        let path = unique_temp_path("framing");
        std::fs::write(&path, b"alpha\nbeta\ngamma\n").expect("write framed seed file");
        let frames = TokioFileIO::from_path(path.clone(), 5)
            .via(Framing::delimiter(b"\n".to_vec(), 64, true))
            .run_with(Sink::collect())
            .expect("framed file stream materializes")
            .wait()
            .expect("framed file stream completes");
        assert_eq!(
            frames,
            vec![b"alpha".to_vec(), b"beta".to_vec(), b"gamma".to_vec()]
        );
        std::fs::remove_file(path).expect("remove framed file");
    }
    // A file one internal read buffer plus a short tail long: every chunk except the
    // last has exactly `chunk_size` bytes, and the data round-trips intact.
    #[test]
    fn tokio_file_source_preserves_requested_chunk_boundaries() {
        let path = unique_temp_path("chunk-boundaries");
        let chunk_size = 8192;
        let tail_size = 13;
        let data_len = FILE_INTERNAL_READ_SIZE + tail_size;
        let data: Vec<u8> = (0..data_len).map(|index| (index % 251) as u8).collect();
        std::fs::write(&path, &data).expect("write chunk boundary seed file");
        let (read_completion, chunks) = TokioFileIO::from_path(path.clone(), chunk_size)
            .to_mat(Sink::collect(), Keep::both)
            .run()
            .expect("tokio file source materializes");
        let chunks = chunks.wait().expect("chunk boundary stream completes");
        assert!(chunks.len() > 1);
        for chunk in &chunks[..chunks.len() - 1] {
            assert_eq!(chunk.len(), chunk_size);
        }
        assert_eq!(chunks.last().expect("tail chunk exists").len(), tail_size);
        let reassembled: Vec<u8> = chunks.into_iter().flatten().collect();
        assert_eq!(reassembled, data);
        assert_eq!(
            read_completion
                .wait()
                .expect("read completion available")
                .bytes(),
            data_len as u64
        );
        std::fs::remove_file(path).expect("remove chunk boundary file");
    }
    // The writer never completes a write. Dropping the completion is expected to cancel
    // the stream (presumably via the materializer); the cancellation must reach the
    // write task, which then drops the writer.
    #[test]
    fn tokio_sink_cancellation_unblocks_pending_writer_completion_wait() {
        let polled = Arc::new(StdAtomicBool::new(false));
        let dropped = Arc::new(StdAtomicBool::new(false));
        let completion = Source::single(b"blocked".to_vec())
            .run_with(async_write_sink({
                let polled = Arc::clone(&polled);
                let dropped = Arc::clone(&dropped);
                move || {
                    let polled = Arc::clone(&polled);
                    let dropped = Arc::clone(&dropped);
                    async move { Ok(PendingWriter { polled, dropped }) }
                }
            }))
            .expect("pending writer sink materializes");
        assert!(wait_until(Duration::from_secs(1), || {
            polled.load(StdOrdering::SeqCst)
        }));
        drop(completion);
        assert!(wait_until(Duration::from_secs(1), || {
            dropped.load(StdOrdering::SeqCst)
        }));
    }
    // Only per-connection races are retried; other accept errors are treated as fatal.
    #[test]
    fn tokio_tcp_accept_error_classifier_retries_only_connection_races() {
        assert!(is_transient_accept_error(&std::io::Error::new(
            std::io::ErrorKind::Interrupted,
            "interrupted"
        )));
        assert!(is_transient_accept_error(&std::io::Error::new(
            std::io::ErrorKind::ConnectionAborted,
            "aborted before accept"
        )));
        assert!(is_transient_accept_error(&std::io::Error::new(
            std::io::ErrorKind::ConnectionReset,
            "reset before accept"
        )));
        assert!(!is_transient_accept_error(&std::io::Error::other(
            "fd pressure"
        )));
    }
    // Reads a 64 MiB file and drops the read completion 5 ms in, from another thread;
    // the collecting side must finish within 500 ms of starting to wait.
    #[test]
    fn tokio_source_cancellation_observed_promptly_under_wake_on_send() {
        let path = unique_temp_path("cancel-prompt");
        let payload: Vec<u8> = (0..(64 * 1024 * 1024)).map(|i| (i % 251) as u8).collect();
        std::fs::write(&path, &payload).expect("write large source file");
        let (read_completion, collected) = TokioFileIO::from_path(path.clone(), 8 * 1024)
            .to_mat(Sink::collect(), Keep::both)
            .run()
            .expect("tokio file source materializes");
        let cancellation_thread = thread::spawn(move || {
            thread::sleep(Duration::from_millis(5));
            drop(read_completion);
        });
        let started = Instant::now();
        let _ = collected.wait();
        let elapsed = started.elapsed();
        cancellation_thread
            .join()
            .expect("cancellation thread joins");
        std::fs::remove_file(path).expect("remove large source file");
        assert!(
            elapsed < Duration::from_millis(500),
            "cancellation should propagate well under 500 ms; took {:?}",
            elapsed
        );
    }
    // Loopback echo: the client sends "ping", the server reads it and writes it back.
    // Each side's read only completes once the other side finishes writing and closes,
    // because partial chunks are held back until EOF.
    #[test]
    fn tokio_tcp_bind_and_outgoing_connection_echo_round_trip() {
        let (binding_completion, incoming_completion) = TokioTcp::bind("127.0.0.1:0", 1024)
            .to_mat(Sink::head(), Keep::both)
            .run()
            .expect("tcp bind source materializes");
        let binding = binding_completion.wait().expect("tcp binding succeeds");
        let client_completion = Source::single(b"ping".to_vec())
            .via(TokioTcp::outgoing_connection(binding.local_addr(), 1024))
            .run_with(Sink::head())
            .expect("client stream materializes");
        let incoming = incoming_completion
            .wait()
            .expect("incoming connection accepted");
        let (incoming_source, incoming_sink) = incoming.into_parts();
        let server_read = incoming_source
            .run_with(Sink::head())
            .expect("server read materializes")
            .wait()
            .expect("server reads request");
        assert_eq!(server_read, b"ping".to_vec());
        let server_write = Source::single(server_read)
            .run_with(incoming_sink)
            .expect("server write materializes");
        let write_result = server_write.wait().expect("server write completes");
        assert_eq!(write_result.bytes(), 4);
        assert_eq!(write_result.status(), Ok(()));
        assert_eq!(
            client_completion.wait().expect("client receives echo"),
            b"ping".to_vec()
        );
    }
}