1use crate::stream::{BoxStream, Flow, NotUsed, Sink, Source, StreamCompletion};
2use crate::{StreamError, StreamResult};
3use futures::{FutureExt, channel::oneshot};
4use std::future::Future;
5use std::net::SocketAddr;
6use std::panic::AssertUnwindSafe;
7use std::path::PathBuf;
8use std::sync::{
9 Arc, Mutex,
10 atomic::{AtomicBool, Ordering},
11 mpsc as std_mpsc,
12};
13use std::thread::{self, Thread};
14use std::time::Duration;
15use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
16use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
17use tokio::sync::{mpsc, watch};
18
/// Chunk size, in bytes, used by the `*_default` constructors.
const DEFAULT_CHUNK_SIZE: usize = 8 * 1024;

/// Size of the buffer handed to each read of a file (256 KiB).
const FILE_INTERNAL_READ_SIZE: usize = 256 * 1024;
/// Maximum number of chunks a file source buffers ahead of its consumer.
const FILE_READ_AHEAD_CHUNKS: usize = 8;

/// Maximum number of chunks a TCP source buffers ahead of its consumer.
const TCP_READ_AHEAD_CHUNKS: usize = 1;

/// Longest single blocking wait before a waiting thread re-checks its
/// channel and cancellation state.
const PARK_INTERVAL: Duration = Duration::from_micros(1_000);

/// `yield_now` rounds a consumer spends waiting for data before parking.
const READ_READY_SPINS: usize = 256;

/// `yield_now` rounds a sender spends on a full channel before parking.
const BACKPRESSURE_READY_SPINS: usize = 64;
/// Park duration used once the backpressure spin budget is exhausted.
const BACKPRESSURE_PARK: Duration = Duration::from_micros(10);
27
/// Shared handle that lets an async producer wake a blocking consumer thread.
///
/// The consumer registers its thread with [`ConsumerWaker::capture_current`]
/// before it waits. The producer calls [`ConsumerWaker::unpark`] after handing
/// over an item. All clones share the same slot.
#[derive(Clone)]
struct ConsumerWaker {
    thread: Arc<Mutex<Option<Thread>>>,
}

impl ConsumerWaker {
    /// Creates a waker with no consumer thread recorded yet.
    fn new() -> Self {
        Self {
            thread: Arc::new(Mutex::new(None)),
        }
    }

    /// Records the calling thread as the thread to unpark.
    ///
    /// The slot is refreshed whenever the caller differs from the recorded
    /// thread. A consumer that is polled from a different thread than on its
    /// first call is then still woken directly. Otherwise the unpark would hit
    /// a stale thread and the consumer would wait out its park timeout.
    fn capture_current(&self) {
        let current = thread::current();
        let mut slot = self.thread.lock().expect("consumer waker poisoned");
        if slot.as_ref().map(Thread::id) != Some(current.id()) {
            *slot = Some(current);
        }
    }

    /// Unparks the recorded consumer thread, if one has been captured.
    fn unpark(&self) {
        let slot = self.thread.lock().expect("consumer waker poisoned");
        if let Some(t) = slot.as_ref() {
            t.unpark();
        }
    }
}
61
62fn io_error(error: std::io::Error) -> StreamError {
63 StreamError::Failed(error.to_string())
64}
65
66fn write_zero_error() -> StreamError {
67 StreamError::Failed("async writer returned zero bytes".to_owned())
68}
69
70#[derive(Debug, Clone, PartialEq, Eq)]
77pub struct IoResult {
78 pub bytes: u64,
79 pub status: StreamResult<()>,
80}
81
82impl IoResult {
83 #[must_use]
84 pub fn succeeded(bytes: u64) -> Self {
85 Self {
86 bytes,
87 status: Ok(()),
88 }
89 }
90
91 #[must_use]
92 pub fn failed(bytes: u64, error: StreamError) -> Self {
93 Self {
94 bytes,
95 status: Err(error),
96 }
97 }
98
99 #[must_use]
100 pub fn bytes(&self) -> u64 {
101 self.bytes
102 }
103
104 pub fn status(&self) -> StreamResult<()> {
105 self.status.clone()
106 }
107
108 #[must_use]
109 pub fn is_success(&self) -> bool {
110 self.status.is_ok()
111 }
112}
113
114pub type TokioByteSource = Source<Vec<u8>, StreamCompletion<IoResult>>;
115pub type TokioByteSink = Sink<Vec<u8>, StreamCompletion<IoResult>>;
116
117#[derive(Clone)]
118enum DemandTerminal {
119 Complete,
120 Error(StreamError),
121}
122
123enum DemandResponse<T> {
124 Item(T),
125 Complete,
126 Error(StreamError),
127}
128
129struct DemandSourceStream<T> {
130 demands: mpsc::Sender<std_mpsc::Sender<DemandResponse<T>>>,
131 cancel: watch::Sender<bool>,
132 terminal: Arc<Mutex<Option<DemandTerminal>>>,
133 done: bool,
134}
135
136impl<T> DemandSourceStream<T> {
137 fn terminal_response(&self) -> Option<Option<StreamResult<T>>> {
138 self.terminal
139 .lock()
140 .expect("tokio source terminal poisoned")
141 .clone()
142 .map(|terminal| match terminal {
143 DemandTerminal::Complete => None,
144 DemandTerminal::Error(error) => Some(Err(error)),
145 })
146 }
147
148 fn mark_done(&mut self) {
149 self.done = true;
150 let _ = self.cancel.send(true);
151 }
152}
153
154impl<T: Send + 'static> Iterator for DemandSourceStream<T> {
155 type Item = StreamResult<T>;
156
157 fn next(&mut self) -> Option<Self::Item> {
158 if self.done {
159 return None;
160 }
161
162 let stream_cancelled = crate::stream::current_stream_cancelled();
163 let (reply_sender, reply_receiver) = std_mpsc::channel();
164 if !send_bounded_demand(&self.demands, reply_sender, &stream_cancelled) {
165 self.mark_done();
166 return self
167 .terminal_response()
168 .unwrap_or(Some(Err(StreamError::AbruptTermination)));
169 }
170
171 loop {
172 if stream_cancelled
173 .as_ref()
174 .is_some_and(|cancelled| cancelled.load(Ordering::SeqCst))
175 {
176 self.mark_done();
177 return Some(Err(StreamError::Cancelled));
178 }
179
180 match reply_receiver.recv_timeout(PARK_INTERVAL) {
181 Ok(DemandResponse::Item(item)) => return Some(Ok(item)),
182 Ok(DemandResponse::Complete) => {
183 self.mark_done();
184 return None;
185 }
186 Ok(DemandResponse::Error(error)) => {
187 self.mark_done();
188 return Some(Err(error));
189 }
190 Err(std_mpsc::RecvTimeoutError::Timeout) => {}
191 Err(std_mpsc::RecvTimeoutError::Disconnected) => {
192 self.mark_done();
193 return self
194 .terminal_response()
195 .unwrap_or(Some(Err(StreamError::AbruptTermination)));
196 }
197 }
198 }
199 }
200}
201
202impl<T> Drop for DemandSourceStream<T> {
203 fn drop(&mut self) {
204 let _ = self.cancel.send(true);
205 }
206}
207
208struct BoundedByteSourceStream {
209 receiver: mpsc::Receiver<DemandResponse<Vec<u8>>>,
210 cancel: watch::Sender<bool>,
211 terminal: Arc<Mutex<Option<DemandTerminal>>>,
212 done: bool,
213 waker: ConsumerWaker,
214}
215
216impl BoundedByteSourceStream {
217 fn terminal_response(&self) -> Option<Option<StreamResult<Vec<u8>>>> {
218 self.terminal
219 .lock()
220 .expect("tokio source terminal poisoned")
221 .clone()
222 .map(|terminal| match terminal {
223 DemandTerminal::Complete => None,
224 DemandTerminal::Error(error) => Some(Err(error)),
225 })
226 }
227
228 fn mark_done(&mut self) {
229 self.done = true;
230 let _ = self.cancel.send(true);
231 }
232}
233
234impl Iterator for BoundedByteSourceStream {
235 type Item = StreamResult<Vec<u8>>;
236
237 fn next(&mut self) -> Option<Self::Item> {
238 if self.done {
239 return None;
240 }
241
242 self.waker.capture_current();
246
247 let stream_cancelled = crate::stream::current_stream_cancelled();
248 let mut spins = 0usize;
249 loop {
250 if stream_cancelled
251 .as_ref()
252 .is_some_and(|cancelled| cancelled.load(Ordering::SeqCst))
253 {
254 self.mark_done();
255 return Some(Err(StreamError::Cancelled));
256 }
257
258 match self.receiver.try_recv() {
259 Ok(DemandResponse::Item(item)) => return Some(Ok(item)),
260 Ok(DemandResponse::Complete) => {
261 self.mark_done();
262 return None;
263 }
264 Ok(DemandResponse::Error(error)) => {
265 self.mark_done();
266 return Some(Err(error));
267 }
268 Err(mpsc::error::TryRecvError::Empty) => read_wait(&mut spins),
269 Err(mpsc::error::TryRecvError::Disconnected) => {
270 self.mark_done();
271 return self
272 .terminal_response()
273 .unwrap_or(Some(Err(StreamError::AbruptTermination)));
274 }
275 }
276 }
277 }
278}
279
280impl Drop for BoundedByteSourceStream {
281 fn drop(&mut self) {
282 let _ = self.cancel.send(true);
283 }
284}
285
286fn send_bounded_demand<T>(
287 sender: &mpsc::Sender<T>,
288 mut message: T,
289 stream_cancelled: &Option<Arc<AtomicBool>>,
290) -> bool {
291 let mut spins = 0usize;
292 loop {
293 if stream_cancelled
294 .as_ref()
295 .is_some_and(|cancelled| cancelled.load(Ordering::SeqCst))
296 {
297 return false;
298 }
299
300 match sender.try_send(message) {
301 Ok(()) => return true,
302 Err(mpsc::error::TrySendError::Full(returned)) => {
303 message = returned;
304 backpressure_wait(&mut spins);
305 }
306 Err(mpsc::error::TrySendError::Closed(_)) => return false,
307 }
308 }
309}
310
311fn finish_terminal(terminal: &Arc<Mutex<Option<DemandTerminal>>>, value: DemandTerminal) {
312 let mut slot = terminal.lock().expect("tokio source terminal poisoned");
313 if slot.is_none() {
314 *slot = Some(value);
315 }
316}
317
318async fn next_demand<T>(
319 demands: &mut mpsc::Receiver<std_mpsc::Sender<DemandResponse<T>>>,
320 cancel: &mut watch::Receiver<bool>,
321) -> Option<std_mpsc::Sender<DemandResponse<T>>> {
322 if *cancel.borrow() {
323 return None;
324 }
325
326 tokio::select! {
327 demand = demands.recv() => demand,
328 changed = cancel.changed() => {
329 let _ = changed;
330 None
331 }
332 }
333}
334
335fn async_read_source<R, Fut>(
336 open: impl FnOnce() -> Fut + Send + 'static,
337 chunk_size: usize,
338 internal_read_size: usize,
339 read_ahead_chunks: usize,
340) -> (BoxStream<Vec<u8>>, StreamCompletion<IoResult>)
341where
342 R: AsyncRead + Unpin + Send + 'static,
343 Fut: Future<Output = std::io::Result<R>> + Send + 'static,
344{
345 assert!(chunk_size > 0, "chunk size must be greater than zero");
346 assert!(
347 read_ahead_chunks > 0,
348 "read-ahead bound must be greater than zero"
349 );
350 let internal_read_size = internal_read_size.max(chunk_size);
351 let (item_sender, item_receiver) = mpsc::channel(read_ahead_chunks);
352 let (cancel_sender, cancel_receiver) = watch::channel(false);
353 let (mat_sender, mat_receiver) = oneshot::channel();
354 let terminal = Arc::new(Mutex::new(None));
355 let terminal_for_task = Arc::clone(&terminal);
356 let waker = ConsumerWaker::new();
357 let producer_waker = waker.clone();
358
359 crate::stream::stream_tokio_runtime().spawn(async move {
360 let result = AssertUnwindSafe(run_async_read_task(
361 open(),
362 chunk_size,
363 internal_read_size,
364 item_sender,
365 cancel_receiver,
366 Arc::clone(&terminal_for_task),
367 producer_waker,
368 ))
369 .catch_unwind()
370 .await
371 .unwrap_or_else(|_| {
372 finish_terminal(
373 &terminal_for_task,
374 DemandTerminal::Error(StreamError::AbruptTermination),
375 );
376 Err(StreamError::AbruptTermination)
377 });
378 let _ = mat_sender.send(result);
379 });
380
381 (
382 Box::new(BoundedByteSourceStream {
383 receiver: item_receiver,
384 cancel: cancel_sender,
385 terminal,
386 done: false,
387 waker,
388 }) as BoxStream<Vec<u8>>,
389 StreamCompletion::from_receiver(mat_receiver, None),
390 )
391}
392
/// Producer half of [`async_read_source`].
///
/// Opens the reader, then reads into an `internal_read_size` buffer and
/// forwards the data to `items` as `chunk_size`-byte chunks. A shorter final
/// chunk may be sent at EOF.
///
/// Every exit path records a terminal state in `terminal` via `finish_terminal`
/// before the task returns, and therefore before `items` is dropped. A consumer
/// that only observes the channel closing can still report the real outcome.
/// Any change on `cancel`, including its sender being dropped, is treated as
/// cancellation.
///
/// Always returns `Ok`; failures are carried inside the returned [`IoResult`].
async fn run_async_read_task<R, Fut>(
    open: Fut,
    chunk_size: usize,
    internal_read_size: usize,
    items: mpsc::Sender<DemandResponse<Vec<u8>>>,
    mut cancel: watch::Receiver<bool>,
    terminal: Arc<Mutex<Option<DemandTerminal>>>,
    waker: ConsumerWaker,
) -> StreamResult<IoResult>
where
    R: AsyncRead + Unpin + Send + 'static,
    Fut: Future<Output = std::io::Result<R>> + Send + 'static,
{
    // Counts bytes read from the reader. This includes bytes still buffered in
    // `pending_tail` and bytes the consumer never accepted.
    let mut bytes = 0_u64;
    // Race the open against cancellation so an early consumer exit does not
    // wait for a slow open.
    let mut reader = tokio::select! {
        reader = open => match reader {
            Ok(reader) => reader,
            Err(error) => {
                let error = io_error(error);
                // Record the terminal before notifying the consumer. If the
                // consumer only sees the channel close, it reports this error
                // rather than `AbruptTermination`.
                finish_terminal(&terminal, DemandTerminal::Error(error.clone()));
                let _ = send_read_item(&items, DemandResponse::Error(error.clone()), &mut cancel, &waker).await;
                return Ok(IoResult::failed(bytes, error));
            }
        },
        changed = cancel.changed() => {
            let _ = changed;
            finish_terminal(&terminal, DemandTerminal::Error(StreamError::Cancelled));
            return Ok(IoResult::failed(bytes, StreamError::Cancelled));
        }
    };

    let mut buffer = vec![0_u8; internal_read_size];
    // Bytes from earlier reads that do not yet fill a whole chunk.
    let mut pending_tail = Vec::with_capacity(chunk_size);
    loop {
        let read = tokio::select! {
            read = reader.read(&mut buffer) => read,
            changed = cancel.changed() => {
                let _ = changed;
                finish_terminal(&terminal, DemandTerminal::Error(StreamError::Cancelled));
                return Ok(IoResult::failed(bytes, StreamError::Cancelled));
            }
        };

        match read {
            // EOF: flush any partial chunk, then record and signal completion.
            Ok(0) => {
                if !pending_tail.is_empty()
                    && !send_read_item(
                        &items,
                        DemandResponse::Item(std::mem::take(&mut pending_tail)),
                        &mut cancel,
                        &waker,
                    )
                    .await
                {
                    finish_terminal(&terminal, DemandTerminal::Error(StreamError::Cancelled));
                    return Ok(IoResult::failed(bytes, StreamError::Cancelled));
                }
                finish_terminal(&terminal, DemandTerminal::Complete);
                let _ = send_read_item(&items, DemandResponse::Complete, &mut cancel, &waker).await;
                return Ok(IoResult::succeeded(bytes));
            }
            Ok(read) => {
                bytes += read as u64;
                // A failed send means the consumer cancelled or went away.
                if !send_read_chunks(
                    &items,
                    chunk_size,
                    &mut pending_tail,
                    &buffer[..read],
                    &mut cancel,
                    &waker,
                )
                .await
                {
                    finish_terminal(&terminal, DemandTerminal::Error(StreamError::Cancelled));
                    return Ok(IoResult::failed(bytes, StreamError::Cancelled));
                }
            }
            Err(error) => {
                let error = io_error(error);
                finish_terminal(&terminal, DemandTerminal::Error(error.clone()));
                let _ = send_read_item(
                    &items,
                    DemandResponse::Error(error.clone()),
                    &mut cancel,
                    &waker,
                )
                .await;
                return Ok(IoResult::failed(bytes, error));
            }
        }
    }
}
485
/// Splits `read_buffer` into `chunk_size`-byte items and sends them.
/// Leftover bytes are carried over in `pending_tail` for the next call.
///
/// On entry, `pending_tail` is expected to hold fewer than `chunk_size` bytes.
/// This function itself only ever leaves it shorter than that. The tail is
/// topped up first so chunk boundaries stay contiguous across reads. Bytes
/// left in `pending_tail` are not sent until a later read fills it or the
/// caller flushes it at EOF.
///
/// Returns `false` as soon as a send fails or cancellation is observed.
async fn send_read_chunks(
    sender: &mpsc::Sender<DemandResponse<Vec<u8>>>,
    chunk_size: usize,
    pending_tail: &mut Vec<u8>,
    read_buffer: &[u8],
    cancel: &mut watch::Receiver<bool>,
    waker: &ConsumerWaker,
) -> bool {
    let mut offset = 0;
    // Complete the partial chunk left over from the previous read first.
    if !pending_tail.is_empty() {
        let needed = chunk_size - pending_tail.len();
        let take = needed.min(read_buffer.len());
        pending_tail.extend_from_slice(&read_buffer[..take]);
        offset += take;
        if pending_tail.len() == chunk_size
            && !send_read_item(
                sender,
                DemandResponse::Item(std::mem::take(pending_tail)),
                cancel,
                waker,
            )
            .await
        {
            return false;
        }
    }

    // Emit every whole chunk that remains in this read.
    while offset + chunk_size <= read_buffer.len() {
        let next = offset + chunk_size;
        if !send_read_item(
            sender,
            DemandResponse::Item(read_buffer[offset..next].to_vec()),
            cancel,
            waker,
        )
        .await
        {
            return false;
        }
        offset = next;
    }

    // Keep the short remainder for the next call (or the EOF flush).
    if offset < read_buffer.len() {
        pending_tail.extend_from_slice(&read_buffer[offset..]);
    }
    true
}
533
534async fn send_read_item<T>(
535 sender: &mpsc::Sender<DemandResponse<T>>,
536 item: DemandResponse<T>,
537 cancel: &mut watch::Receiver<bool>,
538 waker: &ConsumerWaker,
539) -> bool
540where
541 T: Send + 'static,
542{
543 let result = tokio::select! {
544 result = sender.send(item) => result,
545 changed = cancel.changed() => {
546 let _ = changed;
547 return false;
548 }
549 };
550 if result.is_ok() {
551 waker.unpark();
552 }
553 result.is_ok()
554}
555
556enum WriteCommand {
557 Chunk(Vec<u8>),
558 Finish(StreamResult<()>),
559}
560
561struct TokioCancelGuard {
562 cancel: watch::Sender<bool>,
563 armed: bool,
564}
565
566impl TokioCancelGuard {
567 fn new(cancel: watch::Sender<bool>) -> Self {
568 Self {
569 cancel,
570 armed: true,
571 }
572 }
573
574 fn disarm(&mut self) {
575 self.armed = false;
576 }
577}
578
579impl Drop for TokioCancelGuard {
580 fn drop(&mut self) {
581 if self.armed {
582 let _ = self.cancel.send(true);
583 }
584 }
585}
586
587fn async_write_sink<W, F, Fut>(open: F) -> TokioByteSink
588where
589 W: AsyncWrite + Unpin + Send + 'static,
590 F: Fn() -> Fut + Send + Sync + 'static,
591 Fut: Future<Output = std::io::Result<W>> + Send + 'static,
592{
593 let open = Arc::new(open);
594 Sink::from_runner(move |input, materializer| {
595 let (command_sender, command_receiver) = mpsc::channel(1);
596 let (cancel_sender, cancel_receiver) = watch::channel(false);
597 let (done_sender, done_receiver) = std_mpsc::sync_channel(1);
598 let open = Arc::clone(&open);
599
600 crate::stream::stream_tokio_runtime().spawn(async move {
601 let result = AssertUnwindSafe(run_async_write_task(
602 open(),
603 command_receiver,
604 cancel_receiver,
605 ))
606 .catch_unwind()
607 .await
608 .unwrap_or(Err(StreamError::AbruptTermination));
609 let _ = done_sender.send(result);
610 });
611
612 Ok(materializer.spawn_stream(move |cancelled| {
613 let mut guard = TokioCancelGuard::new(cancel_sender.clone());
614 let result = feed_async_writer(
615 input,
616 command_sender,
617 done_receiver,
618 cancelled,
619 cancel_sender,
620 );
621 guard.disarm();
622 result
623 }))
624 })
625}
626
/// Writer half of [`async_write_sink`]. Opens the writer, then applies
/// `WriteCommand`s from `commands` until told to finish or cancelled.
///
/// - `Chunk` is written in full via `write_chunk`. `bytes` counts what the
///   writer accepted.
/// - `Finish(status)` flushes and shuts the writer down. The reported status
///   is the upstream error if there was one, otherwise the shutdown result.
/// - If the command channel closes without `Finish`, the task makes a
///   best-effort shutdown and reports `Cancelled`.
///
/// A change on `cancel` (or its sender being dropped) aborts the current step
/// with `Cancelled`. Always returns `Ok`; failures are carried in the returned
/// [`IoResult`].
async fn run_async_write_task<W, Fut>(
    open: Fut,
    mut commands: mpsc::Receiver<WriteCommand>,
    mut cancel: watch::Receiver<bool>,
) -> StreamResult<IoResult>
where
    W: AsyncWrite + Unpin + Send + 'static,
    Fut: Future<Output = std::io::Result<W>> + Send + 'static,
{
    let mut bytes = 0_u64;
    // Race the open against cancellation.
    let mut writer = tokio::select! {
        writer = open => match writer {
            Ok(writer) => writer,
            Err(error) => return Ok(IoResult::failed(bytes, io_error(error))),
        },
        changed = cancel.changed() => {
            let _ = changed;
            return Ok(IoResult::failed(bytes, StreamError::Cancelled));
        }
    };

    loop {
        let command = tokio::select! {
            command = commands.recv() => command,
            changed = cancel.changed() => {
                let _ = changed;
                return Ok(IoResult::failed(bytes, StreamError::Cancelled));
            }
        };

        match command {
            Some(WriteCommand::Chunk(chunk)) => {
                if let Err(error) = write_chunk(&mut writer, &chunk, &mut cancel, &mut bytes).await
                {
                    return Ok(IoResult::failed(bytes, error));
                }
            }
            Some(WriteCommand::Finish(upstream_status)) => {
                // Shut down even when upstream failed. `and` keeps the upstream
                // error when there is one, otherwise uses the shutdown outcome.
                let shutdown_status = shutdown_writer(&mut writer, &mut cancel).await;
                return Ok(IoResult {
                    bytes,
                    status: upstream_status.and(shutdown_status),
                });
            }
            // The feeder dropped the channel without `Finish` (cancelled or
            // failed to send it).
            None => {
                let _ = shutdown_writer(&mut writer, &mut cancel).await;
                return Ok(IoResult::failed(bytes, StreamError::Cancelled));
            }
        }
    }
}
678
679async fn write_chunk<W>(
680 writer: &mut W,
681 chunk: &[u8],
682 cancel: &mut watch::Receiver<bool>,
683 bytes: &mut u64,
684) -> StreamResult<()>
685where
686 W: AsyncWrite + Unpin,
687{
688 let mut offset = 0usize;
689 while offset < chunk.len() {
690 let written = tokio::select! {
691 written = writer.write(&chunk[offset..]) => written.map_err(io_error)?,
692 changed = cancel.changed() => {
693 let _ = changed;
694 return Err(StreamError::Cancelled);
695 }
696 };
697
698 if written == 0 {
699 return Err(write_zero_error());
700 }
701 offset += written;
702 *bytes += written as u64;
703 }
704 Ok(())
705}
706
707async fn shutdown_writer<W>(writer: &mut W, cancel: &mut watch::Receiver<bool>) -> StreamResult<()>
708where
709 W: AsyncWrite + Unpin,
710{
711 tokio::select! {
712 result = writer.flush() => result.map_err(io_error)?,
713 changed = cancel.changed() => {
714 let _ = changed;
715 return Err(StreamError::Cancelled);
716 }
717 }
718
719 tokio::select! {
720 result = writer.shutdown() => result.map_err(io_error),
721 changed = cancel.changed() => {
722 let _ = changed;
723 Err(StreamError::Cancelled)
724 }
725 }
726}
727
/// Blocking driver for an async write sink.
///
/// Pulls chunks from `input` and hands them to the writer task as
/// `WriteCommand::Chunk`, then waits for the task's result on `done_receiver`.
///
/// If `cancelled` is set, the writer task is signalled through `cancel_sender`
/// instead of receiving `Finish`. The flag is checked before each pull and
/// again while waiting. Otherwise the upstream outcome is forwarded with
/// `WriteCommand::Finish`: `Ok(())` or the upstream error.
///
/// If a chunk cannot be delivered because the writer task already stopped,
/// feeding stops. The reported result then comes from the writer task. If that
/// task vanishes without reporting, the result is `AbruptTermination`.
fn feed_async_writer(
    mut input: BoxStream<Vec<u8>>,
    command_sender: mpsc::Sender<WriteCommand>,
    done_receiver: std_mpsc::Receiver<StreamResult<IoResult>>,
    cancelled: Arc<AtomicBool>,
    cancel_sender: watch::Sender<bool>,
) -> StreamResult<IoResult> {
    // Upstream outcome forwarded to the writer with `Finish`.
    let mut terminal = Ok(());
    loop {
        if cancelled.load(Ordering::SeqCst) {
            terminal = Err(StreamError::Cancelled);
            break;
        }

        match input.next() {
            Some(Ok(chunk)) => {
                // `false`: cancelled, or the writer task is gone.
                if !send_write_command(&command_sender, WriteCommand::Chunk(chunk), &cancelled) {
                    break;
                }
            }
            Some(Err(error)) => {
                terminal = Err(error);
                break;
            }
            None => break,
        }
    }

    if cancelled.load(Ordering::SeqCst) {
        let _ = cancel_sender.send(true);
    } else {
        let _ = send_write_command(&command_sender, WriteCommand::Finish(terminal), &cancelled);
    }
    // Closing the channel lets the writer task observe `None` if `Finish`
    // was never delivered.
    drop(command_sender);

    // Wait for the writer's result, forwarding any cancellation that arrives
    // while waiting.
    loop {
        match done_receiver.recv_timeout(PARK_INTERVAL) {
            Ok(result) => return result,
            Err(std_mpsc::RecvTimeoutError::Timeout) => {
                if cancelled.load(Ordering::SeqCst) {
                    let _ = cancel_sender.send(true);
                }
            }
            Err(std_mpsc::RecvTimeoutError::Disconnected) => {
                return Err(StreamError::AbruptTermination);
            }
        }
    }
}
780
781fn send_write_command(
782 sender: &mpsc::Sender<WriteCommand>,
783 mut command: WriteCommand,
784 cancelled: &AtomicBool,
785) -> bool {
786 let mut spins = 0usize;
787 loop {
788 if cancelled.load(Ordering::SeqCst) {
789 return false;
790 }
791
792 match sender.try_send(command) {
793 Ok(()) => return true,
794 Err(mpsc::error::TrySendError::Full(returned)) => {
795 command = returned;
796 backpressure_wait(&mut spins);
797 }
798 Err(mpsc::error::TrySendError::Closed(_)) => return false,
799 }
800 }
801}
802
803fn backpressure_wait(spins: &mut usize) {
804 if *spins < BACKPRESSURE_READY_SPINS {
805 *spins += 1;
806 thread::yield_now();
807 } else {
808 thread::park_timeout(BACKPRESSURE_PARK);
809 }
810}
811
812fn read_wait(spins: &mut usize) {
813 if *spins < READ_READY_SPINS {
814 *spins += 1;
815 thread::yield_now();
816 } else {
817 thread::park_timeout(PARK_INTERVAL);
818 }
819}
820
821pub struct TokioFileIO;
822
823impl TokioFileIO {
824 #[must_use]
834 pub fn from_path(path: impl Into<PathBuf>, chunk_size: usize) -> TokioByteSource {
835 assert!(chunk_size > 0, "chunk size must be greater than zero");
836 let path = path.into();
837 Source::from_materialized_factory(move |_materializer| {
838 let path = path.clone();
839 Ok(async_read_source(
840 move || tokio::fs::File::open(path),
841 chunk_size,
842 FILE_INTERNAL_READ_SIZE,
843 FILE_READ_AHEAD_CHUNKS,
844 ))
845 })
846 }
847
848 #[must_use]
849 pub fn from_path_default(path: impl Into<PathBuf>) -> TokioByteSource {
850 Self::from_path(path, DEFAULT_CHUNK_SIZE)
851 }
852
853 #[must_use]
860 pub fn to_path(path: impl Into<PathBuf>) -> TokioByteSink {
861 let path = Arc::new(path.into());
862 async_write_sink(move || {
863 let path = Arc::clone(&path);
864 async move {
865 tokio::fs::OpenOptions::new()
866 .create(true)
867 .truncate(true)
868 .write(true)
869 .open(path.as_ref())
870 .await
871 }
872 })
873 }
874}
875
/// Socket addresses of an established TCP connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TcpConnection {
    /// Address of the local end of the socket.
    pub local_addr: SocketAddr,
    /// Address of the peer.
    pub remote_addr: SocketAddr,
}

impl TcpConnection {
    /// Local end of the connection.
    #[must_use]
    pub fn local_addr(&self) -> SocketAddr {
        let Self { local_addr, .. } = *self;
        local_addr
    }

    /// Peer end of the connection.
    #[must_use]
    pub fn remote_addr(&self) -> SocketAddr {
        let Self { remote_addr, .. } = *self;
        remote_addr
    }
}

/// Address a TCP listener was bound to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TcpBinding {
    /// Local address actually bound (resolves a requested port 0, for example).
    pub local_addr: SocketAddr,
}

impl TcpBinding {
    /// The bound local address.
    #[must_use]
    pub fn local_addr(&self) -> SocketAddr {
        let Self { local_addr } = *self;
        local_addr
    }
}
905
906pub struct TcpIncomingConnection {
912 connection: TcpConnection,
913 source: TokioByteSource,
914 sink: TokioByteSink,
915}
916
917impl TcpIncomingConnection {
918 #[must_use]
919 pub fn local_addr(&self) -> SocketAddr {
920 self.connection.local_addr
921 }
922
923 #[must_use]
924 pub fn remote_addr(&self) -> SocketAddr {
925 self.connection.remote_addr
926 }
927
928 #[must_use]
929 pub fn connection(&self) -> TcpConnection {
930 self.connection
931 }
932
933 #[must_use]
934 pub fn into_parts(self) -> (TokioByteSource, TokioByteSink) {
935 (self.source, self.sink)
936 }
937
938 #[must_use]
939 pub fn into_flow(self) -> Flow<Vec<u8>, Vec<u8>, NotUsed> {
940 Flow::from_sink_and_source_coupled(self.sink, self.source)
941 .map_materialized_value(|_| NotUsed)
942 }
943}
944
945pub struct TokioTcp;
946
947impl TokioTcp {
948 #[must_use]
955 pub fn outgoing_connection<A>(
956 addr: A,
957 chunk_size: usize,
958 ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<TcpConnection>>
959 where
960 A: ToSocketAddrs + Clone + Send + Sync + 'static,
961 {
962 assert!(chunk_size > 0, "chunk size must be greater than zero");
963 Flow::future_flow(move || {
964 let addr = addr.clone();
965 async move {
966 let stream = TcpStream::connect(addr).await.map_err(io_error)?;
967 Ok(tcp_flow_from_stream(stream, chunk_size))
968 }
969 })
970 }
971
972 #[must_use]
973 pub fn outgoing_connection_default<A>(
974 addr: A,
975 ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<TcpConnection>>
976 where
977 A: ToSocketAddrs + Clone + Send + Sync + 'static,
978 {
979 Self::outgoing_connection(addr, DEFAULT_CHUNK_SIZE)
980 }
981
982 #[must_use]
988 pub fn bind<A>(
989 addr: A,
990 chunk_size: usize,
991 ) -> Source<TcpIncomingConnection, StreamCompletion<TcpBinding>>
992 where
993 A: ToSocketAddrs + Clone + Send + Sync + 'static,
994 {
995 assert!(chunk_size > 0, "chunk size must be greater than zero");
996 Source::from_materialized_factory(move |_materializer| {
997 let (demand_sender, demand_receiver) = mpsc::channel(1);
998 let (cancel_sender, cancel_receiver) = watch::channel(false);
999 let (binding_sender, binding_receiver) = oneshot::channel();
1000 let terminal = Arc::new(Mutex::new(None));
1001 let terminal_for_task = Arc::clone(&terminal);
1002 let addr = addr.clone();
1003
1004 crate::stream::stream_tokio_runtime().spawn(async move {
1005 let result = AssertUnwindSafe(run_tcp_bind_task(
1006 addr,
1007 chunk_size,
1008 demand_receiver,
1009 cancel_receiver,
1010 binding_sender,
1011 Arc::clone(&terminal_for_task),
1012 ))
1013 .catch_unwind()
1014 .await;
1015 if result.is_err() {
1016 finish_terminal(
1017 &terminal_for_task,
1018 DemandTerminal::Error(StreamError::AbruptTermination),
1019 );
1020 }
1021 });
1022
1023 Ok((
1024 Box::new(DemandSourceStream {
1025 demands: demand_sender,
1026 cancel: cancel_sender,
1027 terminal,
1028 done: false,
1029 }) as BoxStream<TcpIncomingConnection>,
1030 StreamCompletion::from_receiver(binding_receiver, None),
1031 ))
1032 })
1033 }
1034
1035 #[must_use]
1036 pub fn bind_default<A>(addr: A) -> Source<TcpIncomingConnection, StreamCompletion<TcpBinding>>
1037 where
1038 A: ToSocketAddrs + Clone + Send + Sync + 'static,
1039 {
1040 Self::bind(addr, DEFAULT_CHUNK_SIZE)
1041 }
1042}
1043
1044fn tcp_flow_from_stream(
1045 stream: TcpStream,
1046 chunk_size: usize,
1047) -> Flow<Vec<u8>, Vec<u8>, TcpConnection> {
1048 let connection = TcpConnection {
1049 local_addr: stream
1050 .local_addr()
1051 .expect("connected TCP stream has local address"),
1052 remote_addr: stream
1053 .peer_addr()
1054 .expect("connected TCP stream has peer address"),
1055 };
1056 let (read_half, write_half) = stream.into_split();
1057 let source = single_use_async_read_source(read_half, chunk_size);
1058 let sink = single_use_async_write_sink(write_half);
1059 Flow::from_sink_and_source(sink, source).map_materialized_value(move |_| connection)
1066}
1067
1068fn single_use_async_read_source<R>(reader: R, chunk_size: usize) -> TokioByteSource
1069where
1070 R: AsyncRead + Unpin + Send + 'static,
1071{
1072 let reader = Arc::new(Mutex::new(Some(reader)));
1073 Source::from_materialized_factory(move |_materializer| {
1074 let reader = Arc::clone(&reader);
1075 Ok(async_read_source(
1076 move || async move {
1077 reader
1078 .lock()
1079 .expect("single-use async reader poisoned")
1080 .take()
1081 .ok_or_else(|| std::io::Error::other("async reader already materialized"))
1082 },
1083 chunk_size,
1084 chunk_size,
1085 TCP_READ_AHEAD_CHUNKS,
1086 ))
1087 })
1088}
1089
1090fn single_use_async_write_sink<W>(writer: W) -> TokioByteSink
1091where
1092 W: AsyncWrite + Unpin + Send + 'static,
1093{
1094 let writer = Arc::new(Mutex::new(Some(writer)));
1095 async_write_sink(move || {
1096 let writer = Arc::clone(&writer);
1097 async move {
1098 writer
1099 .lock()
1100 .expect("single-use async writer poisoned")
1101 .take()
1102 .ok_or_else(|| std::io::Error::other("async writer already materialized"))
1103 }
1104 })
1105}
1106
/// Listener task behind [`TokioTcp::bind`].
///
/// Binds to `addr` and reports the outcome once on `binding_sender`. Then, for
/// each demand received, it accepts one connection and replies with a
/// `TcpIncomingConnection`. Accept errors that `is_transient_accept_error`
/// classifies as transient are retried.
///
/// Every exit path records a terminal state in `terminal` before returning.
/// The following all end the task as `Cancelled`:
/// - cancellation;
/// - a closed demand channel;
/// - a consumer that dropped its reply receiver.
async fn run_tcp_bind_task<A>(
    addr: A,
    chunk_size: usize,
    mut demands: mpsc::Receiver<std_mpsc::Sender<DemandResponse<TcpIncomingConnection>>>,
    mut cancel: watch::Receiver<bool>,
    binding_sender: oneshot::Sender<StreamResult<TcpBinding>>,
    terminal: Arc<Mutex<Option<DemandTerminal>>>,
) where
    A: ToSocketAddrs + Send + 'static,
{
    let listener = match TcpListener::bind(addr).await {
        Ok(listener) => listener,
        Err(error) => {
            let error = io_error(error);
            finish_terminal(&terminal, DemandTerminal::Error(error.clone()));
            let _ = binding_sender.send(Err(error));
            return;
        }
    };
    // The actual bound address (e.g. the concrete port when binding port 0).
    let local_addr = match listener.local_addr() {
        Ok(local_addr) => local_addr,
        Err(error) => {
            let error = io_error(error);
            finish_terminal(&terminal, DemandTerminal::Error(error.clone()));
            let _ = binding_sender.send(Err(error));
            return;
        }
    };
    let _ = binding_sender.send(Ok(TcpBinding { local_addr }));

    loop {
        // Accept only when the consumer asks for the next connection.
        let Some(reply) = next_demand(&mut demands, &mut cancel).await else {
            finish_terminal(&terminal, DemandTerminal::Error(StreamError::Cancelled));
            return;
        };

        let (stream, remote_addr) = loop {
            let accepted = tokio::select! {
                accepted = listener.accept() => accepted,
                changed = cancel.changed() => {
                    let _ = changed;
                    finish_terminal(&terminal, DemandTerminal::Error(StreamError::Cancelled));
                    return;
                }
            };

            match accepted {
                Ok(accepted) => break accepted,
                // Per-connection failures: keep listening for the same demand.
                Err(error) if is_transient_accept_error(&error) => continue,
                Err(error) => {
                    let error = io_error(error);
                    finish_terminal(&terminal, DemandTerminal::Error(error.clone()));
                    let _ = reply.send(DemandResponse::Error(error));
                    return;
                }
            }
        };

        let incoming = tcp_incoming_connection(stream, remote_addr, local_addr, chunk_size);
        // The consumer stopped waiting for this reply; the accepted
        // connection is dropped here.
        if reply.send(DemandResponse::Item(incoming)).is_err() {
            finish_terminal(&terminal, DemandTerminal::Error(StreamError::Cancelled));
            return;
        }
    }
}
1175
/// Whether an `accept` failure concerns only the one pending connection, so
/// the listener should keep accepting.
///
/// Checks the portable error kind first, then the raw OS error code via
/// `is_transient_accept_errno`.
fn is_transient_accept_error(error: &std::io::Error) -> bool {
    use std::io::ErrorKind;

    let transient_kind = matches!(
        error.kind(),
        ErrorKind::Interrupted | ErrorKind::ConnectionAborted | ErrorKind::ConnectionReset
    );
    transient_kind || error.raw_os_error().is_some_and(is_transient_accept_errno)
}

/// Linux errno values treated as transient for `accept`.
#[cfg(target_os = "linux")]
fn is_transient_accept_errno(code: i32) -> bool {
    const EINTR: i32 = 4;
    const ECONNABORTED: i32 = 103;
    const ECONNRESET: i32 = 104;
    matches!(code, EINTR | ECONNABORTED | ECONNRESET)
}

/// On other platforms, only the `ErrorKind` check applies.
#[cfg(not(target_os = "linux"))]
fn is_transient_accept_errno(_code: i32) -> bool {
    false
}
1194
1195fn tcp_incoming_connection(
1196 stream: TcpStream,
1197 remote_addr: SocketAddr,
1198 local_addr: SocketAddr,
1199 chunk_size: usize,
1200) -> TcpIncomingConnection {
1201 let connection = TcpConnection {
1202 local_addr,
1203 remote_addr,
1204 };
1205 let (read_half, write_half) = stream.into_split();
1206 let source = single_use_async_read_source(read_half, chunk_size);
1207 let sink = single_use_async_write_sink(write_half);
1208 TcpIncomingConnection {
1209 connection,
1210 source,
1211 sink,
1212 }
1213}
1214
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Framing, Keep, Sink, Source};
    use std::pin::Pin;
    use std::sync::atomic::{AtomicBool as StdAtomicBool, Ordering as StdOrdering};
    use std::task::{Context, Poll};
    use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};

    /// Builds a path in the system temp directory that embeds `name`, the
    /// process id, and the current wall-clock nanoseconds, so different tests
    /// (and different test processes) do not collide on the same file.
    fn unique_temp_path(name: &str) -> PathBuf {
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("clock after epoch")
            .as_nanos();
        std::env::temp_dir().join(format!(
            "datum-wp12b-{name}-{}-{nanos}.bin",
            std::process::id()
        ))
    }

    /// Owns a unique temp-file path and deletes the file when dropped.
    ///
    /// Without this guard, a failing assertion would unwind past a test's
    /// trailing `remove_file` call and leak the fixture (up to 64 MiB here)
    /// into the system temp directory.
    struct TempFile {
        path: PathBuf,
    }

    impl TempFile {
        /// Reserves a fresh unique path; the file itself is not created.
        fn new(name: &str) -> Self {
            Self {
                path: unique_temp_path(name),
            }
        }

        /// Deletes the file now, panicking with `context` if deletion fails.
        ///
        /// This preserves the explicit "file can be removed" assertion the
        /// tests make on their success path, and disarms the drop guard by
        /// leaving an empty path behind.
        fn remove(mut self, context: &str) {
            let path = std::mem::take(&mut self.path);
            std::fs::remove_file(path).expect(context);
        }
    }

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best effort: the file may never have been created, and an empty
            // path means `remove` already deleted it.
            if !self.path.as_os_str().is_empty() {
                let _ = std::fs::remove_file(&self.path);
            }
        }
    }

    /// Re-evaluates `condition` every 5 ms until it returns `true` or
    /// `timeout` elapses; returns the final evaluation.
    fn wait_until(timeout: Duration, condition: impl Fn() -> bool) -> bool {
        let deadline = Instant::now() + timeout;
        while Instant::now() < deadline {
            if condition() {
                return true;
            }
            thread::sleep(Duration::from_millis(5));
        }
        condition()
    }

    /// An `AsyncWrite` whose write/flush/shutdown never complete.
    ///
    /// It records that it was polled and that it was dropped, so tests can
    /// observe cancellation of a sink stuck on a pending writer.
    struct PendingWriter {
        polled: Arc<StdAtomicBool>,
        dropped: Arc<StdAtomicBool>,
    }

    impl AsyncWrite for PendingWriter {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &[u8],
        ) -> Poll<std::io::Result<usize>> {
            self.polled.store(true, StdOrdering::SeqCst);
            Poll::Pending
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            self.polled.store(true, StdOrdering::SeqCst);
            Poll::Pending
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            self.polled.store(true, StdOrdering::SeqCst);
            Poll::Pending
        }
    }

    impl Drop for PendingWriter {
        fn drop(&mut self) {
            self.dropped.store(true, StdOrdering::SeqCst);
        }
    }

    #[test]
    fn tokio_file_io_round_trips_bytes_and_reports_counts() {
        let temp = TempFile::new("roundtrip");
        let write_completion = Source::from_iter([b"ab".to_vec(), b"cd".to_vec()])
            .run_with(TokioFileIO::to_path(temp.path.clone()))
            .expect("tokio file sink materializes");
        let write_result = write_completion.wait().expect("tokio file write completes");
        assert_eq!(write_result.bytes(), 4);
        assert_eq!(write_result.status(), Ok(()));

        let (read_completion, collected) = TokioFileIO::from_path(temp.path.clone(), 2)
            .to_mat(Sink::collect(), Keep::both)
            .run()
            .expect("tokio file source materializes");
        assert_eq!(
            collected.wait().expect("collect completes"),
            vec![b"ab".to_vec(), b"cd".to_vec()]
        );
        let read_result = read_completion.wait().expect("read completion available");
        assert_eq!(read_result.bytes(), 4);
        assert_eq!(read_result.status(), Ok(()));

        temp.remove("remove roundtrip file");
    }

    #[test]
    fn tokio_file_source_surfaces_open_failure() {
        // Never created, so no cleanup guard is needed.
        let missing = unique_temp_path("missing");
        let (read_completion, collected) = TokioFileIO::from_path(missing, 4)
            .to_mat(Sink::collect(), Keep::both)
            .run()
            .expect("tokio file source materializes despite open failure");
        let stream_error = collected.wait().expect_err("collect fails");
        assert!(matches!(stream_error, StreamError::Failed(_)));
        let read_result = read_completion.wait().expect("io result available");
        assert_eq!(read_result.bytes(), 0);
        assert!(matches!(read_result.status(), Err(StreamError::Failed(_))));
    }

    #[test]
    fn tokio_file_source_composes_with_framing_and_sink() {
        let temp = TempFile::new("framing");
        std::fs::write(&temp.path, b"alpha\nbeta\ngamma\n").expect("write framed seed file");

        let frames = TokioFileIO::from_path(temp.path.clone(), 5)
            .via(Framing::delimiter(b"\n".to_vec(), 64, true))
            .run_with(Sink::collect())
            .expect("framed file stream materializes")
            .wait()
            .expect("framed file stream completes");

        assert_eq!(
            frames,
            vec![b"alpha".to_vec(), b"beta".to_vec(), b"gamma".to_vec()]
        );
        temp.remove("remove framed file");
    }

    #[test]
    fn tokio_file_source_preserves_requested_chunk_boundaries() {
        let temp = TempFile::new("chunk-boundaries");
        let chunk_size = 8192;
        let tail_size = 13;
        // One full internal read plus a short tail, so chunking must span
        // internal read boundaries and end with a partial chunk.
        let data_len = FILE_INTERNAL_READ_SIZE + tail_size;
        let data: Vec<u8> = (0..data_len).map(|index| (index % 251) as u8).collect();
        std::fs::write(&temp.path, &data).expect("write chunk boundary seed file");

        let (read_completion, chunks) = TokioFileIO::from_path(temp.path.clone(), chunk_size)
            .to_mat(Sink::collect(), Keep::both)
            .run()
            .expect("tokio file source materializes");
        let chunks = chunks.wait().expect("chunk boundary stream completes");

        assert!(chunks.len() > 1);
        for chunk in &chunks[..chunks.len() - 1] {
            assert_eq!(chunk.len(), chunk_size);
        }
        assert_eq!(chunks.last().expect("tail chunk exists").len(), tail_size);
        let reassembled: Vec<u8> = chunks.into_iter().flatten().collect();
        assert_eq!(reassembled, data);
        assert_eq!(
            read_completion
                .wait()
                .expect("read completion available")
                .bytes(),
            data_len as u64
        );

        temp.remove("remove chunk boundary file");
    }

    #[test]
    fn tokio_sink_cancellation_unblocks_pending_writer_completion_wait() {
        let polled = Arc::new(StdAtomicBool::new(false));
        let dropped = Arc::new(StdAtomicBool::new(false));
        let completion = Source::single(b"blocked".to_vec())
            .run_with(async_write_sink({
                let polled = Arc::clone(&polled);
                let dropped = Arc::clone(&dropped);
                move || {
                    let polled = Arc::clone(&polled);
                    let dropped = Arc::clone(&dropped);
                    async move { Ok(PendingWriter { polled, dropped }) }
                }
            }))
            .expect("pending writer sink materializes");

        assert!(wait_until(Duration::from_secs(1), || {
            polled.load(StdOrdering::SeqCst)
        }));
        // Dropping the completion handle must eventually drop the stuck writer.
        drop(completion);
        assert!(wait_until(Duration::from_secs(1), || {
            dropped.load(StdOrdering::SeqCst)
        }));
    }

    #[test]
    fn tokio_tcp_accept_error_classifier_retries_only_connection_races() {
        assert!(is_transient_accept_error(&std::io::Error::new(
            std::io::ErrorKind::Interrupted,
            "interrupted"
        )));
        assert!(is_transient_accept_error(&std::io::Error::new(
            std::io::ErrorKind::ConnectionAborted,
            "aborted before accept"
        )));
        assert!(is_transient_accept_error(&std::io::Error::new(
            std::io::ErrorKind::ConnectionReset,
            "reset before accept"
        )));
        assert!(!is_transient_accept_error(&std::io::Error::other(
            "fd pressure"
        )));
    }

    #[test]
    fn tokio_source_cancellation_observed_promptly_under_wake_on_send() {
        let temp = TempFile::new("cancel-prompt");
        let payload: Vec<u8> = (0..(64 * 1024 * 1024)).map(|i| (i % 251) as u8).collect();
        // Move the 64 MiB payload into `write` so it is freed right after seeding.
        std::fs::write(&temp.path, payload).expect("write large source file");

        let (read_completion, collected) = TokioFileIO::from_path(temp.path.clone(), 8 * 1024)
            .to_mat(Sink::collect(), Keep::both)
            .run()
            .expect("tokio file source materializes");

        let cancellation_thread = thread::spawn(move || {
            thread::sleep(Duration::from_millis(5));
            drop(read_completion);
        });

        let started = Instant::now();
        // The outcome may be success or a cancellation error; only latency matters.
        let _ = collected.wait();
        let elapsed = started.elapsed();
        cancellation_thread
            .join()
            .expect("cancellation thread joins");
        temp.remove("remove large source file");

        assert!(
            elapsed < Duration::from_millis(500),
            "cancellation should propagate well under 500 ms; took {elapsed:?}"
        );
    }

    #[test]
    fn tokio_tcp_bind_and_outgoing_connection_echo_round_trip() {
        let (binding_completion, incoming_completion) = TokioTcp::bind("127.0.0.1:0", 1024)
            .to_mat(Sink::head(), Keep::both)
            .run()
            .expect("tcp bind source materializes");
        let binding = binding_completion.wait().expect("tcp binding succeeds");

        let client_completion = Source::single(b"ping".to_vec())
            .via(TokioTcp::outgoing_connection(binding.local_addr(), 1024))
            .run_with(Sink::head())
            .expect("client stream materializes");

        let incoming = incoming_completion
            .wait()
            .expect("incoming connection accepted");
        let (incoming_source, incoming_sink) = incoming.into_parts();
        let server_read = incoming_source
            .run_with(Sink::head())
            .expect("server read materializes")
            .wait()
            .expect("server reads request");
        assert_eq!(server_read, b"ping".to_vec());

        let server_write = Source::single(server_read)
            .run_with(incoming_sink)
            .expect("server write materializes");
        let write_result = server_write.wait().expect("server write completes");
        assert_eq!(write_result.bytes(), 4);
        assert_eq!(write_result.status(), Ok(()));

        assert_eq!(
            client_completion.wait().expect("client receives echo"),
            b"ping".to_vec()
        );
    }
}